Skip to content
Snippets Groups Projects
Select Git revision
  • 60d22c547770feff9074ca4678dae3a35b706344
  • master default protected
  • v1.0.0
3 results

DeepLearningVersionSelector.java

Blame
  • DeepLearningVersionSelector.java 6.42 KiB
    package plugins.danyfel80.deeplearningdownloader;
    
    import java.awt.event.ActionListener;
    import java.awt.event.WindowListener;
    import java.io.IOException;
    import java.util.Map;
    import java.util.function.Function;
    import java.util.stream.Collectors;
    
    import org.bioimageanalysis.icy.deeplearning.versionmanager.loading.LibraryLoadingStatus;
    import org.bioimageanalysis.icy.deeplearning.versionmanager.version.AvailableDeepLearningVersions;
    import org.bioimageanalysis.icy.deeplearning.versionmanager.version.DeepLearningVersion;
    
    import icy.gui.dialog.MessageDialog;
    import icy.main.Icy;
    import icy.plugin.PluginLauncher;
    import icy.plugin.PluginLoader;
    import icy.preferences.XMLPreferences;
    import icy.system.IcyHandledException;
    import icy.util.StringUtil;
    import plugins.adufour.blocks.lang.Block;
    import plugins.adufour.blocks.util.VarList;
    import plugins.adufour.ezplug.EzPlug;
    import plugins.adufour.ezplug.EzStoppable;
    import plugins.adufour.ezplug.EzVarText;
    import plugins.adufour.vars.lang.VarBoolean;
    import plugins.adufour.vars.lang.VarString;
    
    /**
     * This plugin allows users to select the TensorFlow version to load in memory.
     * Just select one of the available versions and click the "play" button to load the library.
     * Please keep in mind that only one version of TensorFlow can be loaded per Icy
     * execution. An Exception will be thrown if a TensorFlow version tries to be loaded after another version has
     * already been loaded.
     * <p>
     * When used as a Protocols block, the version to load is read from the "Library version" input and the "Loaded"
     * output reports whether the library was loaded without error.
     * 
     * @author Daniel Felipe Gonzalez Obando and Carlos Garcia Lopez de Haro
     */
    public class DeepLearningVersionSelector extends EzPlug implements EzStoppable, Block
    {
        /** Preference key read to pre-select the last used version identifier. */
        private static final String LAST_USED_PREFERENCE = "lastUsed";

        private EzVarText varInVersion;

        /**
         * Compatible library versions indexed by their display identifier
         * ({@code <engine>-<pythonVersion>-cpu} or {@code <engine>-<pythonVersion>-cpu-gpu}).
         */
        private final static Map<String, DeepLearningVersion> versions;
        static
        {
            // Every non-CPU mode collapses to "cpu-gpu", so two versions can share an identifier. Without a merge
            // function Collectors.toMap throws IllegalStateException and this class fails to initialize; keep the
            // first version encountered instead.
            versions = AvailableDeepLearningVersions.loadCompatibleOnly().getVersions().stream()
                    .collect(Collectors.toMap(DeepLearningVersionSelector::toVersionKey, Function.identity(),
                            (first, second) -> first));
        }

        /**
         * Builds the identifier shown to users for a library version.
         * 
         * @param v
         *        Library version.
         * @return {@code <engine>-<pythonVersion>-cpu} when the mode is "cpu" (case-insensitive),
         *         {@code <engine>-<pythonVersion>-cpu-gpu} otherwise.
         */
        private static String toVersionKey(DeepLearningVersion v)
        {
            // equalsIgnoreCase is locale-independent and null-safe, unlike getMode().toLowerCase().equals(...).
            String mode = "cpu".equalsIgnoreCase(v.getMode()) ? "cpu" : "cpu-gpu";
            return v.getEngine() + "-" + v.getPythonVersion() + "-" + mode;
        }

        /**
         * Main routine launching this plugin.
         * 
         * @param args
         *        Arguments forwarded to Icy.
         */
        public static void main(String[] args)
        {
            Icy.main(args);
            PluginLauncher.start(PluginLoader.getPlugin(DeepLearningVersionSelector.class.getName()));
        }

        /**
         * Builds the GUI: a combo box listing the available version identifiers in sorted order, pre-selecting the
         * last used one when it is still available.
         */
        @Override
        protected void initialize()
        {
            String[] versionStrings = versions.keySet().stream().sorted().toArray(String[]::new);
            varInVersion = new EzVarText("Version", versionStrings, getDefaultVersionIndex(versionStrings), false);
            addEzComponent(varInVersion);
        }

        private VarString varInBlockVersion;

        /**
         * Declares the block input "Library version". Its default is the last used version when it is still
         * available, otherwise the first identifier in sorted order (or "" when no version is available).
         */
        @Override
        public void declareInput(VarList inputMap)
        {
            String lastUsedVersion = getLastUsedVersion();
            // Same fallback rule as the GUI (getDefaultVersionIndex): only reuse the preference if it still exists,
            // and use sorted order so the default is deterministic rather than HashMap iteration order.
            String defaultVersion = versions.containsKey(lastUsedVersion)
                    ? lastUsedVersion
                    : versions.keySet().stream().sorted().findFirst().orElse("");

            varInBlockVersion = new VarString("Library version", defaultVersion);
            inputMap.add("Library version", varInBlockVersion);
        }

        VarBoolean varOutLoaded;

        /**
         * Declares the block output "Loaded", set after execution to whether the library was loaded without error.
         */
        @Override
        public void declareOutput(VarList outputMap)
        {
            varOutLoaded = new VarBoolean("Loaded", false);
            outputMap.add("Loaded", varOutLoaded);
        }

        /**
         * Finds the index of the last used version in the given identifiers.
         * 
         * @param versionStrings
         *        Identifiers as displayed in the combo box.
         * @return Index of the last used version, or 0 when it is unknown or no longer available.
         */
        private int getDefaultVersionIndex(String[] versionStrings)
        {
            String lastUsedVersion = getLastUsedVersion();
            int lastIndex = 0;
            if (versions.containsKey(lastUsedVersion))
            {
                for (int i = 0; i < versionStrings.length; i++)
                {
                    if (versionStrings[i].equals(lastUsedVersion))
                    {
                        lastIndex = i;
                        break;
                    }
                }
            }
            return lastIndex;
        }

        /**
         * @return The last used version identifier stored in this plugin's preferences, or "" when none is stored.
         */
        private String getLastUsedVersion()
        {
            XMLPreferences prefs = getPreferencesRoot();
            return prefs.get(LAST_USED_PREFERENCE, "");
        }

        /**
         * Loads the selected library version, taken from the GUI or from the block input when running headless.
         * 
         * @throws IcyHandledException
         *         If the selected version is not available or the library could not be loaded due to an I/O error.
         */
        @Override
        protected void execute()
        {
            String targetVersion = isHeadLess() ? varInBlockVersion.getValue(true) : varInVersion.getValue(true);

            DeepLearningVersion version = versions.get(targetVersion);
            if (version == null)
            {
                // Previously this fell through to a NullPointerException on version.getVersion().
                if (isHeadLess())
                    varOutLoaded.setValue(false);
                throw new IcyHandledException("Unknown or unavailable library version: " + targetVersion);
            }

            String loadingMessage = "Loading TensorFlow " + version.getVersion() + "-" + version.getOs() + "-"
                    + version.getMode() + "...";
            try
            {
                notifyProgress(Double.NaN, loadingMessage);

                LibraryLoadingStatus status = DeepLearningDownloader.loadLibrary(version, false,
                        (p, t) -> notifyProgress(p / t, loadingMessage));
                boolean loaded = !status.getStatus().equals(LibraryLoadingStatus.ERROR_LOADING);
                if (loaded)
                    notifyInfo("Loaded: version=" + version.getVersion() + ", mode=" + version.getMode()
                            + ", for TF version=" + version.getPythonVersion());
                else
                    notifyError("Error loading TensorFlow: " + status.getMessage());

                if (!isHeadLess())
                    this.getUI().close();
                else
                    // Report the actual outcome; previously "Loaded" was true even after ERROR_LOADING.
                    varOutLoaded.setValue(loaded);
            }
            catch (IOException e)
            {
                if (isHeadLess())
                    varOutLoaded.setValue(false);
                e.printStackTrace();
                throw new IcyHandledException("Could not load TensorFlow library: " + e.getMessage(), e);
            }
        }

        /**
         * Reports loading progress on standard output (headless) or on the plugin's progress bar (GUI).
         * 
         * @param progress
         *        Progress fraction; non-finite values are printed as 0% in headless mode.
         * @param message
         *        Progress message.
         * @return Always {@code false}.
         */
        private boolean notifyProgress(double progress, String message)
        {
            if (isHeadLess())
            {
                System.out.println(
                        "(" + StringUtil.toString(Double.isFinite(progress) ? (progress * 100d) : 0d, 2) + "%)" + message);
            }
            else
            {
                getUI().setProgressBarValue(progress);
                getUI().setProgressBarMessage(message);
            }
            return false;
        }

        /**
         * Reports an error on standard error (headless) or in an error dialog (GUI).
         */
        private void notifyError(String message)
        {
            if (isHeadLess())
            {
                System.err.println(message);
            }
            else
            {
                MessageDialog.showDialog("Error", message, MessageDialog.ERROR_MESSAGE);
            }
        }

        /**
         * Reports an informational message on standard output (headless) or in a "Success" dialog (GUI).
         */
        private void notifyInfo(String message)
        {
            if (isHeadLess())
            {
                System.out.println(message);
            }
            else
            {
                MessageDialog.showDialog("Success", message, MessageDialog.INFORMATION_MESSAGE);
            }
        }

        @Override
        public void clean()
        {
            System.out.println("Closing plugin");
        }

        /**
         * Calls {@code getUI().onClosed()}.
         * <p>
         * NOTE(review): the {@code listener} argument is ignored and nothing is registered; confirm this is
         * intended.
         */
        public void addPluginListener(WindowListener listener)
        {
            this.getUI().onClosed();
        }

    }