Select Git revision
DeepLearningVersionSelector.java
DeepLearningVersionSelector.java 6.42 KiB
package plugins.danyfel80.deeplearningdownloader;
import java.awt.event.ActionListener;
import java.awt.event.WindowListener;
import java.io.IOException;
import java.util.Map;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.bioimageanalysis.icy.deeplearning.versionmanager.loading.LibraryLoadingStatus;
import org.bioimageanalysis.icy.deeplearning.versionmanager.version.AvailableDeepLearningVersions;
import org.bioimageanalysis.icy.deeplearning.versionmanager.version.DeepLearningVersion;
import icy.gui.dialog.MessageDialog;
import icy.main.Icy;
import icy.plugin.PluginLauncher;
import icy.plugin.PluginLoader;
import icy.preferences.XMLPreferences;
import icy.system.IcyHandledException;
import icy.util.StringUtil;
import plugins.adufour.blocks.lang.Block;
import plugins.adufour.blocks.util.VarList;
import plugins.adufour.ezplug.EzPlug;
import plugins.adufour.ezplug.EzStoppable;
import plugins.adufour.ezplug.EzVarText;
import plugins.adufour.vars.lang.VarBoolean;
import plugins.adufour.vars.lang.VarString;
/**
 * This plugin allows users to select the deep learning library (TensorFlow) version to load in memory.
 * Just select one of the available versions and click the "play" button to load the library.
 * Please keep in mind that only one version of TensorFlow can be loaded per Icy
 * execution. An exception will be thrown if a TensorFlow version is loaded after another version has already been
 * loaded.
 * <p>
 * When used as a protocol block, the "Library version" input selects the version key and the "Loaded" output
 * reports whether the library was loaded successfully.
 *
 * @author Daniel Felipe Gonzalez Obando and Carlos Garcia Lopez de Haro
 */
public class DeepLearningVersionSelector extends EzPlug implements EzStoppable, Block
{
    /** Preference key under which the last successfully loaded version key is stored. */
    private static final String LAST_USED_KEY = "lastUsed";

    /**
     * Compatible versions indexed by their display key, as built by {@link #versionKey(DeepLearningVersion)}.
     */
    private final static Map<String, DeepLearningVersion> versions;

    static
    {
        versions = AvailableDeepLearningVersions.loadCompatibleOnly().getVersions().stream()
                .collect(Collectors.toMap(DeepLearningVersionSelector::versionKey, Function.identity(),
                        // The key omits the OS and collapses every non-"cpu" mode to "cpu-gpu", so two versions
                        // may share a key. Keep the first one instead of failing class initialization with an
                        // IllegalStateException (which would make the whole plugin unusable).
                        (first, duplicate) -> first));
    }

    /** Version selector shown in the graphical interface. */
    private EzVarText varInVersion;
    /** Version input used when the plugin runs as a protocol block. */
    private VarString varInBlockVersion;
    /** Block output telling whether the library was loaded. Only created by {@link #declareOutput(VarList)}. */
    VarBoolean varOutLoaded;

    /**
     * Builds the key identifying a version in the selection list.
     *
     * @param version
     *        Version to describe.
     * @return {@code engine-pythonVersion-cpu} for CPU-only versions, {@code engine-pythonVersion-cpu-gpu} otherwise.
     */
    private static String versionKey(DeepLearningVersion version)
    {
        // equalsIgnoreCase is locale-independent and tolerates a null mode, unlike getMode().toLowerCase().
        String mode = "cpu".equalsIgnoreCase(version.getMode()) ? "cpu" : "cpu-gpu";
        return version.getEngine() + "-" + version.getPythonVersion() + "-" + mode;
    }

    /**
     * Main routine launching this plugin.
     *
     * @param args
     *        Arguments forwarded to Icy.
     */
    public static void main(String[] args)
    {
        Icy.main(args);
        PluginLauncher.start(PluginLoader.getPlugin(DeepLearningVersionSelector.class.getName()));
    }

    /**
     * Builds the graphical interface: a combo box listing the available version keys in alphabetical order, with
     * the last used version preselected when still available.
     */
    @Override
    protected void initialize()
    {
        String[] versionStrings = versions.keySet().stream().sorted().toArray(String[]::new);
        varInVersion = new EzVarText("Version", versionStrings, getDefaultVersionIndex(versionStrings), false);
        addEzComponent(varInVersion);
    }

    /**
     * Declares the block input "Library version". Its default is the last used version if it is still available,
     * otherwise the alphabetically first available version (the same default as the graphical interface).
     *
     * @param inputMap
     *        Block input list to register the version input in.
     */
    @Override
    public void declareInput(VarList inputMap)
    {
        String lastUsedVersion = getLastUsedVersion();
        String defaultVersion = versions.containsKey(lastUsedVersion)
                ? lastUsedVersion
                : versions.keySet().stream().sorted().findFirst().orElse("");
        varInBlockVersion = new VarString("Library version", defaultVersion);
        // Without registering the variable the block exposes no input and always uses the default version.
        inputMap.add("Library version", varInBlockVersion);
    }

    /**
     * Declares the block output "Loaded", initially {@code false}.
     *
     * @param outputMap
     *        Block output list to register the output in.
     */
    @Override
    public void declareOutput(VarList outputMap)
    {
        varOutLoaded = new VarBoolean("Loaded", false);
        outputMap.add("Loaded", varOutLoaded);
    }

    /**
     * Finds the position of the last used version in the given list.
     *
     * @param versionStrings
     *        Version keys as displayed in the selector.
     * @return Index of the last used version, or 0 if it is unknown or not in the list.
     */
    private int getDefaultVersionIndex(String[] versionStrings)
    {
        String lastUsedVersion = getLastUsedVersion();
        int lastIndex = 0;
        if (versions.containsKey(lastUsedVersion))
        {
            for (int i = 0; i < versionStrings.length; i++)
            {
                if (versionStrings[i].equals(lastUsedVersion))
                {
                    lastIndex = i;
                    break;
                }
            }
        }
        return lastIndex;
    }

    /**
     * @return The version key stored in this plugin's preferences, or an empty string if none was stored.
     */
    private String getLastUsedVersion()
    {
        XMLPreferences prefs = getPreferencesRoot();
        return prefs.get(LAST_USED_KEY, "");
    }

    /**
     * Loads the selected library version, reports the outcome to the user, and remembers the version on success.
     * In graphical mode the plugin window is closed afterwards; in headless mode the "Loaded" output is updated.
     *
     * @throws IcyHandledException
     *         If an I/O error occurs while loading the library.
     */
    @Override
    protected void execute()
    {
        String targetVersion = (!isHeadLess()) ? varInVersion.getValue(true) : varInBlockVersion.getValue(true);
        DeepLearningVersion version = versions.get(targetVersion);
        if (version == null)
        {
            // The block input is free text, so the key may not match any compatible version.
            setLoaded(false);
            notifyError("Unknown library version: " + targetVersion + ". Available versions: " + versions.keySet());
            return;
        }

        String loadingMessage = "Loading TensorFlow " + version.getVersion() + "-" + version.getOs() + "-"
                + version.getMode() + "...";
        try
        {
            notifyProgress(Double.NaN, loadingMessage);
            LibraryLoadingStatus status = DeepLearningDownloader.loadLibrary(version, false,
                    (p, t) -> notifyProgress(p / t, loadingMessage));
            boolean loaded = !status.getStatus().equals(LibraryLoadingStatus.ERROR_LOADING);
            if (loaded)
            {
                notifyInfo("Loaded: version=" + version.getVersion() + ", mode=" + version.getMode()
                        + ", for TF version=" + version.getPythonVersion());
                // Persist the choice so initialize()/declareInput() can preselect it next time.
                getPreferencesRoot().put(LAST_USED_KEY, targetVersion);
            }
            else
            {
                notifyError("Error loading TensorFlow: " + status.getMessage());
            }

            if (!isHeadLess())
                this.getUI().close();
            else
                setLoaded(loaded);
        }
        catch (IOException e)
        {
            setLoaded(false);
            // Kept because the handled exception below may only show its message to the user.
            e.printStackTrace();
            throw new IcyHandledException("Could not load TensorFlow library: " + e.getMessage(), e);
        }
    }

    /**
     * Updates the "Loaded" block output when running headless.
     *
     * @param loaded
     *        Value to publish.
     */
    private void setLoaded(boolean loaded)
    {
        // varOutLoaded only exists once declareOutput() has been called.
        if (isHeadLess() && varOutLoaded != null)
            varOutLoaded.setValue(loaded);
    }

    /**
     * Reports loading progress on the console (headless) or on the plugin progress bar.
     *
     * @param progress
     *        Progress ratio; non-finite values are printed as 0%.
     * @param message
     *        Progress message.
     * @return Always {@code false}. NOTE(review): presumably means "do not cancel" to the downloader callback;
     *         confirm against DeepLearningDownloader.
     */
    private boolean notifyProgress(double progress, String message)
    {
        if (isHeadLess())
        {
            System.out.println(
                    "(" + StringUtil.toString(Double.isFinite(progress) ? (progress * 100d) : 0d, 2) + "%)" + message);
        }
        else
        {
            getUI().setProgressBarValue(progress);
            getUI().setProgressBarMessage(message);
        }
        return false;
    }

    /**
     * Shows an error on standard error (headless) or in an error dialog.
     *
     * @param message
     *        Error message.
     */
    private void notifyError(String message)
    {
        if (isHeadLess())
        {
            System.err.println(message);
        }
        else
        {
            MessageDialog.showDialog("Error", message, MessageDialog.ERROR_MESSAGE);
        }
    }

    /**
     * Shows an informative message on standard output (headless) or in an information dialog.
     *
     * @param message
     *        Message to show.
     */
    private void notifyInfo(String message)
    {
        if (isHeadLess())
        {
            System.out.println(message);
        }
        else
        {
            MessageDialog.showDialog("Success", message, MessageDialog.INFORMATION_MESSAGE);
        }
    }

    @Override
    public void clean()
    {
        System.out.println("Closing plugin");
    }

    /**
     * Invokes the UI's {@code onClosed()} hook.
     * <p>
     * NOTE(review): the {@code listener} argument is ignored and nothing is registered; confirm whether the intent
     * was to attach the listener to the plugin frame instead.
     *
     * @param listener
     *        Currently unused.
     */
    public void addPluginListener(WindowListener listener)
    {
        this.getUI().onClosed();
    }
}