Skip to content
Snippets Groups Projects
Select Git revision
  • f2e45da91c5e80888c31f727417713bc677f79dd
  • master default protected
  • dev
  • score_test
4 results

path_optimization.py

Blame
  • DeepLearningVersionSelector.java 9.47 KiB
    package plugins.danyfel80.deeplearningdownloader;
    
    import java.awt.event.WindowListener;
    import java.util.List;
    import java.util.Map;
    import java.util.function.Function;
    import java.util.stream.Collectors;
    
    import org.bioimageanalysis.icy.deeplearning.system.PlatformDetection;
    import org.bioimageanalysis.icy.deeplearning.utils.Constants;
    import org.bioimageanalysis.icy.deeplearning.versionmanagement.AvailableDeepLearningVersions;
    import org.bioimageanalysis.icy.deeplearning.versionmanagement.DeepLearningVersion;
    
    import icy.gui.dialog.MessageDialog;
    import icy.main.Icy;
    import icy.plugin.PluginLauncher;
    import icy.plugin.PluginLoader;
    import icy.preferences.XMLPreferences;
    import icy.util.StringUtil;
    import plugins.adufour.blocks.lang.Block;
    import plugins.adufour.blocks.util.VarList;
    import plugins.adufour.ezplug.EzPlug;
    import plugins.adufour.ezplug.EzStoppable;
    import plugins.adufour.ezplug.EzVarText;
    import plugins.adufour.vars.lang.VarBoolean;
    import plugins.adufour.vars.lang.VarString;
    
    /**
     * Icy plugin (also usable as a Protocols block) that downloads one of the
     * Deep Learning engines compatible with the current system.
     * <p>
     * Select one of the available engine versions and click the "play" button to
     * download it. The last successfully downloaded version is remembered in the
     * plugin preferences and pre-selected the next time the plugin is opened.
     * <p>
     * Please keep in mind that only one version of TensorFlow can be loaded per Icy
     * execution. An exception will be thrown if a TensorFlow version is loaded after
     * another version has already been loaded.
     * 
     * @author Daniel Felipe Gonzalez Obando and Carlos Garcia Lopez de Haro
     */
    public class DeepLearningVersionSelector extends EzPlug implements EzStoppable, Block
    {
        /** Preference key under which the last successfully downloaded version name is stored. */
        private static final String LAST_USED_KEY = "lastUsed";

        /**
         * Compatible engines, indexed by their display name
         * ({@code engine-pythonVersion}, followed by the CPU and/or GPU keys when applicable).
         */
        private static final Map<String, DeepLearningVersion> versions = getDisplayMapWithoutPythonRepeated();

        /** Version selector shown in the graphical interface. */
        private EzVarText varInVersion;
        /** Version input used when the plugin runs as a Protocols block. */
        private VarString varInBlockVersion;
        /** Block output, set to {@code true} on a successful download and {@code false} on failure. */
        VarBoolean varOutLoaded;

        /**
         * Create map to display the available Deep Learning versions for naive users.
         * Each key only contains the DL framework, the python version and whether it is GPU or CPU.
         * 
         * @return map to let naive users select the wanted DL version
         */
        private static Map<String, DeepLearningVersion> getDisplayMapWithoutPythonRepeated()
        {
            List<DeepLearningVersion> allVersions = AvailableDeepLearningVersions.loadCompatibleOnly().getVersions();
            List<DeepLearningVersion> vList = AvailableDeepLearningVersions.removeRepeatedPythonVersions(allVersions);
            // The merge function keeps the first entry: without it, two versions sharing the same display
            // name would make toMap throw IllegalStateException and break the static initialization of
            // this class (ExceptionInInitializerError).
            return vList.stream()
                .collect(Collectors.toMap(DeepLearningVersionSelector::displayName, Function.identity(),
                    (first, second) -> first));
        }

        /**
         * Builds the name shown to the user for a given version.
         * 
         * @param v version to describe
         * @return {@code engine-pythonVersion} followed by the CPU and/or GPU keys when the version supports them
         */
        private static String displayName(DeepLearningVersion v)
        {
            return v.getEngine() + "-" + v.getPythonVersion()
                + (v.getCPU() ? "-" + DeepLearningVersion.cpuKey : "")
                + (v.getGPU() ? "-" + DeepLearningVersion.gpuKey : "");
        }

        /**
         * @return the display names of all available versions, sorted alphabetically; empty if none are available
         */
        private static String[] sortedVersionNames()
        {
            return versions.keySet().stream().sorted().toArray(String[]::new);
        }

        /**
         * Main routine launching this plugin.
         * 
         * @param args
         *        command line arguments forwarded to Icy
         */
        public static void main(String[] args)
        {
            Icy.main(args);
            PluginLauncher.start(PluginLoader.getPlugin(DeepLearningVersionSelector.class.getName()));
        }

        /**
         * Builds the graphical interface. Before doing so, warns the user when running on an ARM64
         * JVM or under Rosetta, since fewer engines are available in those cases.
         */
        @Override
        protected void initialize()
        {
            PlatformDetection system = new PlatformDetection();
            if (system.getArch().equals(PlatformDetection.ARCH_ARM64))
            {
                arm64InfoMessage();
            }
            else if (system.isUsingRosseta())
            {
                rosettaInfoMessage();
            }
            String[] versionStrings = sortedVersionNames();
            varInVersion = new EzVarText("Version", versionStrings, getDefaultVersionIndex(versionStrings), false);
            addEzComponent(varInVersion);
        }

        /**
         * Display information message about the inconveniences of using ARM64 chips.
         */
        private void arm64InfoMessage()
        {
            MessageDialog.showDialog("ARM64 chips and compatibility with Deep Learning engines",
                "This computer uses the ARM64 chip architecture. This architecture" + System.lineSeparator()
                    + "is relatively recent, therefore many of the existing Deep Learning" + System.lineSeparator()
                    + "engines will not be available on your computer." + System.lineSeparator()
                    + "ARM64 chips also provide the possibility of running some x86_64" + System.lineSeparator()
                    + "compiled programs using Rosetta. In order to enable Rosetta, change" + System.lineSeparator()
                    + "the JAVA_HOME variable to a Java 8 or lower. Currently, the JAVA_HOME" + System.lineSeparator()
                    + "points to:" + System.lineSeparator() + System.getProperty("java.home") + System.lineSeparator()
                    + "Using Rosetta will enable more Deep Learning engines, although some" + System.lineSeparator()
                    + "will still be missing." + System.lineSeparator()
                    + "For more information, go to the Wiki: " + System.lineSeparator() + Constants.WIKI_LINK,
                MessageDialog.INFORMATION_MESSAGE);
        }

        /**
         * Display information message explaining that Rosetta (x86_64 emulation on ARM64) is in use,
         * and what that implies for the available Deep Learning engines.
         */
        private void rosettaInfoMessage()
        {
            MessageDialog.showDialog("Rosetta is being used",
                "This computer uses the ARM64 chip architecture. Arm64 chips are" + System.lineSeparator()
                    + "compatible with fewer Deep Learning engines. However, ARM64 chips" + System.lineSeparator()
                    + "can replicate the \"traditional\" x86_64 architecture with Rosetta." + System.lineSeparator()
                    + "Rosetta is currently enabled, allowing the use of certain Deep Learning" + System.lineSeparator()
                    + "engines that would not be available for ARM64 architecture systems." + System.lineSeparator()
                    + "The list of installable Deep Learning engines will be longer than the" + System.lineSeparator()
                    + "list if the system was not using Rosetta, although Tensorflow 1 will" + System.lineSeparator()
                    + "still be unavailable." + System.lineSeparator()
                    + "In order to disable Rosetta and use the full capabilities of the ARM64" + System.lineSeparator()
                    + "chip, change the JAVA_HOME to a Java distribution compatible with ARM64." + System.lineSeparator()
                    + "Currently, the JAVA_HOME variable points to:" + System.lineSeparator()
                    + System.getProperty("java.home") + System.lineSeparator()
                    + "For more information, go to the Wiki: " + System.lineSeparator() + Constants.WIKI_LINK,
                MessageDialog.INFORMATION_MESSAGE);
        }

        /**
         * Declares the block input holding the name of the version to download. The default is the last
         * successfully downloaded version when it is still available, otherwise the first available
         * version in alphabetical order (the same order used by the graphical selector).
         */
        @Override
        public void declareInput(VarList inputMap)
        {
            String[] names = sortedVersionNames();
            String lastUsedVersion = getLastUsedVersion();
            String defaultVersion = versions.containsKey(lastUsedVersion)
                ? lastUsedVersion
                : (names.length > 0 ? names[0] : "");

            varInBlockVersion = new VarString("Library version", defaultVersion);
            // Without registering the variable, the block exposes no version input at all.
            inputMap.add("Library version", varInBlockVersion);
        }

        /**
         * Declares the "Loaded" block output reporting whether the download succeeded.
         */
        @Override
        public void declareOutput(VarList outputMap)
        {
            varOutLoaded = new VarBoolean("Loaded", false);
            outputMap.add("Loaded", varOutLoaded);
        }

        /**
         * @param versionStrings
         *        names shown in the selector
         * @return index of the last used version within {@code versionStrings}, or 0 if it is not listed
         */
        private int getDefaultVersionIndex(String[] versionStrings)
        {
            String lastUsedVersion = getLastUsedVersion();
            for (int i = 0; i < versionStrings.length; i++)
            {
                if (versionStrings[i].equals(lastUsedVersion))
                    return i;
            }
            return 0;
        }

        /**
         * @return the last successfully downloaded version name stored in the preferences, or "" if none
         */
        private String getLastUsedVersion()
        {
            XMLPreferences prefs = getPreferencesRoot();
            return prefs.get(LAST_USED_KEY, "");
        }

        /**
         * Downloads the selected engine. On success, remembers the version, closes the interface (or sets the
         * "Loaded" output to {@code true} in headless mode) and reports where the engine was stored. On failure,
         * reports the error and, in headless mode, sets the "Loaded" output to {@code false}.
         */
        @Override
        protected void execute()
        {
            String targetVersion = (!isHeadLess()) ? varInVersion.getValue(true) : varInBlockVersion.getValue(true);

            DeepLearningVersion version = versions.get(targetVersion);
            if (version == null)
            {
                // Can only happen with a free-text block input that names no available version.
                failDownload("Unknown Deep Learning engine version: " + targetVersion);
                return;
            }

            String progressMessage = "Downloading DL engine " + targetVersion + "...";
            notifyProgress(Double.NaN, progressMessage);

            try
            {
                // NOTE(review): assumes p is the amount downloaded so far and t the total amount.
                // Casting before dividing prevents integral counters from truncating the ratio to 0.
                DeepLearningDownloader.downloadLibrary(version, false,
                    (p, t) -> notifyProgress((double) p / t, progressMessage));
            }
            catch (Exception ex)
            {
                ex.printStackTrace();
                failDownload("Error downloading Deep Learning framework: " + ex.getMessage());
                // Do not fall through to the success path below.
                return;
            }

            // Remember the choice so that the selector and the block default to it next time.
            getPreferencesRoot().put(LAST_USED_KEY, targetVersion);

            if (!isHeadLess())
                this.getUI().close();
            else
                varOutLoaded.setValue(true);
            String msg = "Success downloading the " + targetVersion + " engine at: \n"
                + version.folderName();
            notifyInfo(msg);
        }

        /**
         * Reports a failed download: clears the "Loaded" output in headless mode and shows the error.
         * 
         * @param message
         *        error description shown to the user
         */
        private void failDownload(String message)
        {
            if (isHeadLess())
                varOutLoaded.setValue(false);
            notifyError(message);
        }

        /**
         * Reports download progress on the console (headless) or on the plugin progress bar.
         * 
         * @param progress
         *        fraction in [0, 1], or NaN when unknown
         * @param message
         *        message to display
         * @return always {@code false}
         */
        private boolean notifyProgress(double progress, String message)
        {
            if (isHeadLess())
            {
                System.out.println(
                    "(" + StringUtil.toString(Double.isFinite(progress) ? (progress * 100d) : 0d, 2) + "%)" + message);
            }
            else
            {
                getUI().setProgressBarValue(progress);
                getUI().setProgressBarMessage(message);
            }
            return false;
        }

        /**
         * Reports an error on stderr (headless) or in an error dialog.
         */
        private void notifyError(String message)
        {
            if (isHeadLess())
            {
                System.err.println(message);
            }
            else
            {
                MessageDialog.showDialog("Error", message, MessageDialog.ERROR_MESSAGE);
            }
        }

        /**
         * Reports an informative message on stdout (headless) or in an information dialog.
         */
        private void notifyInfo(String message)
        {
            if (isHeadLess())
            {
                System.out.println(message);
            }
            else
            {
                MessageDialog.showDialog("Success", message, MessageDialog.INFORMATION_MESSAGE);
            }
        }

        @Override
        public void clean()
        {
            System.out.println("Closing plugin");
        }

        /**
         * NOTE(review): despite its name, this method ignores {@code listener} and only calls
         * {@code getUI().onClosed()}; confirm with callers whether the listener should be registered instead.
         * 
         * @param listener
         *        currently unused
         */
        public void addPluginListener(WindowListener listener)
        {
            this.getUI().onClosed();
        }

    }