Skip to content
Snippets Groups Projects
Select Git revision
  • ffa07a2c46192458e500a303f339c6584c8b1334
  • master default protected
  • v1.0.0
3 results

DeepLearningVersionSelector.java

Blame
  • DeepLearningVersionSelector.java 6.76 KiB
    package plugins.danyfel80.deeplearningdownloader;
    
    import java.awt.event.WindowListener;
    import java.io.IOException;
    import java.util.List;
    import java.util.Map;
    import java.util.function.Function;
    import java.util.stream.Collectors;
    
    import org.bioimageanalysis.icy.deeplearning.versionmanager.version.AvailableDeepLearningVersions;
    import org.bioimageanalysis.icy.deeplearning.versionmanager.version.DeepLearningVersion;
    
    import icy.gui.dialog.MessageDialog;
    import icy.main.Icy;
    import icy.plugin.PluginLauncher;
    import icy.plugin.PluginLoader;
    import icy.preferences.XMLPreferences;
    import icy.system.IcyHandledException;
    import icy.util.StringUtil;
    import plugins.adufour.blocks.lang.Block;
    import plugins.adufour.blocks.util.VarList;
    import plugins.adufour.ezplug.EzPlug;
    import plugins.adufour.ezplug.EzStoppable;
    import plugins.adufour.ezplug.EzVarText;
    import plugins.adufour.vars.lang.VarBoolean;
    import plugins.adufour.vars.lang.VarString;
    
    /**
     * This plugin allows users to select a Deep Learning engine version (for example a TensorFlow version) and
     * download it through {@link DeepLearningDownloader}.
     * Just select one of the available versions and click the "play" button to load the library.
     * Please keep in mind that only one version of TensorFlow can be loaded per Icy execution. An Exception will be
     * thrown if a TensorFlow version tries to be loaded after another version has already been loaded.
     * <p>
     * The plugin can be used interactively (EzPlug GUI) or as a protocol {@link Block}; in block mode the version is
     * read from the "Library version" input and the outcome is published on the "Loaded" output.
     * 
     * @author Daniel Felipe Gonzalez Obando and Carlos Garcia Lopez de Haro
     */
    public class DeepLearningVersionSelector extends EzPlug implements EzStoppable, Block
    {
        /** Preference key holding the display key of the last successfully downloaded version. */
        private static final String LAST_USED_PREF_KEY = "lastUsed";

        /** Display key (engine-python[-cpu][-gpu]) to version, built once when the class is loaded. */
        private final static Map<String, DeepLearningVersion> versions;
        static
        {
            versions = getDisplayMapWithoutPythonRepeated();
        }

        /** Version chooser shown in the graphical interface. */
        private EzVarText varInVersion;
        /** Version input used when the plugin runs as a protocol block. */
        private VarString varInBlockVersion;
        /** Block output: true when the selected version was downloaded without error. */
        VarBoolean varOutLoaded;

        /**
         * Create map to display the available Deep Learning versions for naive users:
         * only the DL framework, python version and whether it is GPU or CPU.
         * 
         * @return map to let naive users select the wanted DL version
         */
        private static Map<String, DeepLearningVersion> getDisplayMapWithoutPythonRepeated()
        {
            List<DeepLearningVersion> allVersions = AvailableDeepLearningVersions.loadCompatibleOnly().getVersions();
            List<DeepLearningVersion> vList = AvailableDeepLearningVersions.removeRepeatedPythonVersions(allVersions);
            // Keep the first entry if two versions share a display key: the two-argument toMap would throw
            // IllegalStateException, and inside the static initializer that would make the class unloadable.
            return vList.stream()
                .collect(Collectors.toMap(DeepLearningVersionSelector::toDisplayKey, Function.identity(),
                    (first, duplicate) -> first));
        }

        /**
         * Builds the short key shown to users for a version.
         * 
         * @param v version to describe
         * @return "engine-pythonVersion" followed by the CPU/GPU suffix
         */
        private static String toDisplayKey(DeepLearningVersion v)
        {
            return v.getEngine() + "-" + v.getPythonVersion() + hardwareSuffix(v);
        }

        /**
         * Builds the detailed description used in progress and result messages.
         * 
         * @param v version to describe
         * @return "version-os" followed by the CPU/GPU suffix
         */
        private static String describe(DeepLearningVersion v)
        {
            return v.getVersion() + "-" + v.getOs() + hardwareSuffix(v);
        }

        /**
         * @param v version to inspect
         * @return "-" + cpuKey when the version supports CPU, then "-" + gpuKey when it supports GPU
         */
        private static String hardwareSuffix(DeepLearningVersion v)
        {
            return (v.getCPU() ? "-" + DeepLearningVersion.cpuKey : "")
                + (v.getGPU() ? "-" + DeepLearningVersion.gpuKey : "");
        }

        /**
         * Main routine launching this plugin.
         * 
         * @param args
         *        command line arguments forwarded to Icy
         */
        public static void main(String[] args)
        {
            Icy.main(args);
            PluginLauncher.start(PluginLoader.getPlugin(DeepLearningVersionSelector.class.getName()));
        }

        /**
         * Builds the GUI: a combo box listing the available versions in sorted order, preselecting the last used one.
         */
        @Override
        protected void initialize()
        {
            String[] versionStrings = versions.keySet().stream().sorted().toArray(String[]::new);
            varInVersion = new EzVarText("Version", versionStrings, getDefaultVersionIndex(versionStrings), false);
            addEzComponent(varInVersion);
        }

        /**
         * Declares the "Library version" block input, defaulting to the last used version when it is still
         * available, otherwise to the first version in sorted (GUI) order.
         */
        @Override
        public void declareInput(VarList inputMap)
        {
            String lastUsedVersion = getLastUsedVersion();
            // A stale preference (version no longer available) would make execute() fail, so fall back to the first
            // version in the same order as the GUI combo box.
            String defaultVersion = versions.containsKey(lastUsedVersion)
                ? lastUsedVersion
                : versions.keySet().stream().sorted().findFirst().orElse("");

            varInBlockVersion = new VarString("Library version", defaultVersion);
            // Without registering it, the block exposes no input and the user cannot choose a version.
            inputMap.add("Library version", varInBlockVersion);
        }

        /**
         * Declares the "Loaded" boolean output telling whether the download succeeded.
         */
        @Override
        public void declareOutput(VarList outputMap)
        {
            varOutLoaded = new VarBoolean("Loaded", false);
            outputMap.add("Loaded", varOutLoaded);
        }

        /**
         * @param versionStrings
         *        sorted display keys shown in the GUI
         * @return index of the last used version in {@code versionStrings}, or 0 when it is unknown
         */
        private int getDefaultVersionIndex(String[] versionStrings)
        {
            String lastUsedVersion = getLastUsedVersion();
            if (versions.containsKey(lastUsedVersion))
            {
                for (int i = 0; i < versionStrings.length; i++)
                {
                    if (versionStrings[i].equals(lastUsedVersion))
                        return i;
                }
            }
            return 0;
        }

        /**
         * @return display key stored under the "lastUsed" preference, or an empty string if none was stored
         */
        private String getLastUsedVersion()
        {
            XMLPreferences prefs = getPreferencesRoot();
            return prefs.get(LAST_USED_PREF_KEY, "");
        }

        /**
         * Downloads the selected version. On failure the error is reported, the "Loaded" output is set to false and
         * the GUI stays open; on success the version is remembered as the last used one, the GUI is closed (or the
         * "Loaded" output set to true in block mode) and a success message is shown.
         */
        @Override
        protected void execute()
        {
            String targetVersion = isHeadLess() ? varInBlockVersion.getValue(true) : varInVersion.getValue(true);

            DeepLearningVersion version = versions.get(targetVersion);
            if (version == null)
            {
                // Possible when a protocol provides a version key that is not in the available list.
                reportFailure("Unknown Deep Learning version: " + targetVersion);
                return;
            }

            String description = describe(version);
            notifyProgress(Double.NaN, "Loading DL engine " + description + "...");

            try
            {
                // Cast before dividing so integral progress counts neither truncate to 0 nor throw on a zero total.
                DeepLearningDownloader.downloadLibrary(version, false,
                    (p, t) -> notifyProgress((double) p / t, "Downloading DL engine " + description + "..."));
            }
            catch (Exception ex)
            {
                ex.printStackTrace();
                reportFailure("Error downloading Deep Learning framework: " + ex.getMessage());
                // Do not fall through: the success path below must not run after a failed download.
                return;
            }

            getPreferencesRoot().put(LAST_USED_PREF_KEY, targetVersion);

            if (isHeadLess())
                varOutLoaded.setValue(true);
            else
                getUI().close();
            notifyInfo("Deep Learning framework " + description + " downloaded successfully.");
        }

        /**
         * Marks the block output as not loaded (in block mode) and reports the error to the user.
         * 
         * @param message
         *        error description
         */
        private void reportFailure(String message)
        {
            if (isHeadLess())
                varOutLoaded.setValue(false);
            notifyError(message);
        }

        /**
         * Reports progress on standard output (block mode) or on the GUI progress bar.
         * 
         * @param progress
         *        fraction in [0, 1]; non-finite values are printed as 0% in block mode
         * @param message
         *        progress message
         * @return always false
         */
        private boolean notifyProgress(double progress, String message)
        {
            if (isHeadLess())
            {
                System.out.println(
                        "(" + StringUtil.toString(Double.isFinite(progress) ? (progress * 100d) : 0d, 2) + "%)" + message);
            }
            else
            {
                getUI().setProgressBarValue(progress);
                getUI().setProgressBarMessage(message);
            }
            return false;
        }

        /**
         * Reports an error on standard error (block mode) or in an error dialog.
         * 
         * @param message
         *        error message
         */
        private void notifyError(String message)
        {
            if (isHeadLess())
            {
                System.err.println(message);
            }
            else
            {
                MessageDialog.showDialog("Error", message, MessageDialog.ERROR_MESSAGE);
            }
        }

        /**
         * Reports a success message on standard output (block mode) or in an information dialog.
         * 
         * @param message
         *        information message
         */
        private void notifyInfo(String message)
        {
            if (isHeadLess())
            {
                System.out.println(message);
            }
            else
            {
                MessageDialog.showDialog("Success", message, MessageDialog.INFORMATION_MESSAGE);
            }
        }

        @Override
        public void clean()
        {
            System.out.println("Closing plugin");
        }

        /**
         * NOTE(review): despite its name, this method does not register {@code listener}; it only invokes
         * {@code getUI().onClosed()}. Confirm with callers whether the listener should be attached instead.
         * 
         * @param listener
         *        currently unused
         */
        public void addPluginListener(WindowListener listener)
        {
            this.getUI().onClosed();
        }

    }