diff --git a/src/main/java/org/bioimageanalysis/icy/deeplearning/tensorflow/v1/Tensorflow1Interface.java b/src/main/java/org/bioimageanalysis/icy/deeplearning/tensorflow/v1/Tensorflow1Interface.java index 18e5e5dc3224d65fbd23a6ecd5d340f79cddf2e0..9907ed3de9f85eda71ab1964cc37aec1b1430210 100644 --- a/src/main/java/org/bioimageanalysis/icy/deeplearning/tensorflow/v1/Tensorflow1Interface.java +++ b/src/main/java/org/bioimageanalysis/icy/deeplearning/tensorflow/v1/Tensorflow1Interface.java @@ -1,13 +1,28 @@ package org.bioimageanalysis.icy.deeplearning.tensorflow.v1; +import java.io.File; +import java.io.IOException; +import java.io.RandomAccessFile; +import java.nio.ByteBuffer; +import java.nio.MappedByteBuffer; +import java.nio.channels.FileChannel; +import java.nio.file.Files; +import java.nio.file.Paths; +import java.security.CodeSource; +import java.security.ProtectionDomain; import java.util.ArrayList; +import java.util.HashMap; +import java.util.LinkedList; import java.util.List; import org.bioimageanalysis.icy.deeplearning.engine.DeepLearningEngineInterface; import org.bioimageanalysis.icy.deeplearning.exceptions.LoadModelException; import org.bioimageanalysis.icy.deeplearning.exceptions.RunModelException; +import org.bioimageanalysis.icy.deeplearning.system.PlatformDetection; import org.bioimageanalysis.icy.deeplearning.tensor.Tensor; import org.bioimageanalysis.icy.deeplearning.tensorflow.v1.tensor.ImgLib2Builder; +import org.bioimageanalysis.icy.deeplearning.tensorflow.v1.tensor.MappedBufferToImgLib2; +import org.bioimageanalysis.icy.deeplearning.tensorflow.v1.tensor.MappedFileBuilder; import org.bioimageanalysis.icy.deeplearning.tensorflow.v1.tensor.TensorBuilder; import org.tensorflow.SavedModelBundle; import org.tensorflow.Session; @@ -55,6 +70,28 @@ public class Tensorflow1Interface implements DeepLearningEngineInterface "tf.saved_model.signature_constants.DEFAULT_EVAL_SIGNATURE_DEF_KEY", "tf.saved_model.signature_constants.SUPERVISED_TRAIN_METHOD_NAME", "tf.saved_model.signature_constants.SUPERVISED_EVAL_METHOD_NAME"}; + + /** + * Idetifier for the files that contain the data of the inputs + */ + final private static String INPUT_FILE_TERMINATION = "_model_input"; + + /** + * Idetifier for the files that contain the data of the outputs + */ + final private static String OUTPUT_FILE_TERMINATION = "_model_output"; + /** + * Key for the inputs in the map that retrieves the file names for interprocess communication + */ + final private static String INPUTS_MAP_KEY = "inputs"; + /** + * Key for the outputs in the map that retrieves the file names for interprocess communication + */ + final private static String OUTPUTS_MAP_KEY = "outputs"; + /** + * File extension for the temporal files used for interprocessing + */ + final private static String FILE_EXTENSION = ".dat"; /** * The loaded Tensorflow 1 model @@ -62,20 +99,90 @@ public class Tensorflow1Interface implements DeepLearningEngineInterface private static SavedModelBundle model; private static SignatureDef sig; - public Tensorflow1Interface() + private boolean interprocessing = false; + + private String tmpDir; + + private String modelFolder; + + public Tensorflow1Interface() throws IOException + { + boolean isMac = PlatformDetection.isMacOS(); + boolean isIntel = new PlatformDetection().getArch().equals(PlatformDetection.ARCH_X86_64); + if (isMac && isIntel) { + interprocessing = true; + tmpDir = getTemporaryDir(); + + } + } + + public Tensorflow1Interface(boolean doInterprocessing) throws IOException { + if (!doInterprocessing) { + interprocessing = false; + } else { + boolean isMac = PlatformDetection.isMacOS(); + boolean isIntel = new PlatformDetection().getArch().equals(PlatformDetection.ARCH_X86_64); + if (isMac && isIntel) { + interprocessing = true; + tmpDir = getTemporaryDir(); + + } + } } public static void main(String[] args) throws LoadModelException { - Tensorflow1Interface intf = new Tensorflow1Interface(); - intf.loadModel("C:\\Users\\angel\\OneDrive\\Documentos\\pasteur\\git\\deep-icy\\models\\Neuron Segmentation in 2D EM (Membrane)_02022023_175546", ""); - String aa = intf.getModelInputName("input"); - String bb = intf.getModelOutputName("output"); - System.out.print(false); + // Unpack the args needed + if (args.length < 4) + throw new IllegalArgumentException("Error exectuting Tensorflow 1, " + + "at least arguments are required:" + System.lineSeparator() + + " - Folder where the model is located" + System.lineSeparator() + + " - Temporary dir where the memory mapped files are located" + System.lineSeparator() + + " - Name of the model input followed by the String + '_model_input'" + System.lineSeparator() + + " - Name of the second model input (if it exists) followed by the String + '_model_input'" + System.lineSeparator() + + " - ...." + System.lineSeparator() + + " - Name of the nth model input (if it exists) followed by the String + '_model_input'" + System.lineSeparator() + + " - Name of the model output followed by the String + '_model_output'" + System.lineSeparator() + + " - Name of the second model output (if it exists) followed by the String + '_model_output'" + System.lineSeparator() + + " - ...." + System.lineSeparator() + + " - Name of the nth model output (if it exists) followed by the String + '_model_output'" + System.lineSeparator() + ); + String modelFolder = args[0]; + if (new File(modelFolder).isDirectory()) { + throw new IllegalArgumentException("Argument 0 of the main method, '" + modelFolder + "' " + + "should be an existing directory containing a Tensorflow 1 model."); + } + + Tensorflow1Interface tfInterface = new Tensorflow1Interface(); + tfInterface.tmpDir = args[1]; + if (new File(args[1]).isDirectory()) { + throw new IllegalArgumentException("Argument 1 of the main method, '" + args[1] + "' " + + "should be an existing directory."); + } + + tfInterface.loadModel(modelFolder, modelFolder); + + HashMap<String, List<String>> map = getInputTensorsFileNames(args); + List<String> inputNames = map.get(INPUTS_MAP_KEY); + List<String> outputNames = map.get(OUTPUTS_MAP_KEY); + for + + for (String inName : inputNames) { + try (RandomAccessFile rd = + new RandomAccessFile(tfInterface.tmpDir + File.separator + inName, "r"); + FileChannel fc = rd.getChannel();) { + int initialLenghToRead = inName.length() + DATATYPE_KEY_LENGTH + MAGIC_WORD_LENGTH + TYPICAL_SHAPE_LENTH; + MappedByteBuffer mem = fc.map(FileChannel.MapMode.READ_WRITE, 0, fc.size()); + } + } } @Override public void loadModel(String modelFolder, String modelSource) throws LoadModelException { + if (interprocessing) { + this.modelFolder = modelFolder; + return; + } model = SavedModelBundle.load(modelFolder, "serve"); byte[] byteGraph = model.metaGraphDef(); try { @@ -87,6 +194,10 @@ public class Tensorflow1Interface implements DeepLearningEngineInterface @Override public void run(List<Tensor<?>> inputTensors, List<Tensor<?>> outputTensors) throws RunModelException { + if (interprocessing) { + runInterprocessing(inputTensors, outputTensors); + return; + } Session session = model.session(); Session.Runner runner = session.runner(); List<String> inputListNames = new ArrayList<String>(); @@ -113,6 +224,30 @@ public class Tensorflow1Interface implements DeepLearningEngineInterface } } + public void runInterprocessing(List<Tensor<?>> inputTensors, List<Tensor<?>> outputTensors) throws RunModelException { + createTensorsForInterprocessing(inputTensors); + List<String> args = getProcessCommandsWithoutArgs(); + args.add(modelFolder); + args.add(this.tmpDir); + for (Tensor tensor : inputTensors) {args.add(tensor.getName() + INPUT_FILE_TERMINATION);} + for (Tensor tensor : outputTensors) {args.add(tensor.getName() + OUTPUT_FILE_TERMINATION);} + + ProcessBuilder builder = new ProcessBuilder(args); + Process process; + try { + process = builder.inheritIO().start(); + if (process.waitFor() != 0) + throw new RunModelException("Error executing the Tensorflow 1 model in" + + " a separate process. The process was not terminated correctly."); + } catch (RunModelException e) { + throw e; + } catch (Exception e) { + throw new RunModelException(e.getCause().toString()); + } + + retrieveInterprocessingTensors(outputTensors); + } + /** * Create the list a list of output tensors agnostic to the Deep Learning engine * that can be readable by Deep Icy @@ -142,6 +277,61 @@ public class Tensorflow1Interface implements DeepLearningEngineInterface } model = null; } + + private void createTensorsForInterprocessing(List<Tensor<?>> tensors) throws RunModelException{ + for (Tensor<?> tensor : tensors) { + long lenFile = MappedFileBuilder.findTotalLengthFile(tensor); + try (RandomAccessFile rd = + new RandomAccessFile(tmpDir + File.separator + tensor.getName() + FILE_EXTENSION, "rw"); + FileChannel fc = rd.getChannel();) { + MappedByteBuffer mem = fc.map(FileChannel.MapMode.READ_WRITE, 0, lenFile); + ByteBuffer byteBuffer = mem.duplicate(); + byteBuffer.put(MappedFileBuilder.createFileHeader(tensor)); + MappedFileBuilder.build(tensor, byteBuffer); + } catch (IOException e) { + throw new RunModelException(e.getCause().toString()); + } + } + } + + private void retrieveInterprocessingTensors(List<Tensor<?>> tensors) throws RunModelException{ + for (Tensor<?> tensor : tensors) { + try (RandomAccessFile rd = + new RandomAccessFile(tmpDir + File.separator + tensor.getName() + FILE_EXTENSION, "r"); + FileChannel fc = rd.getChannel();) { + MappedByteBuffer mem = fc.map(FileChannel.MapMode.READ_WRITE, 0, fc.size()); + ByteBuffer byteBuffer = mem.duplicate(); + tensor.setData(MappedBufferToImgLib2.build(byteBuffer)); + } catch (IOException e) { + throw new RunModelException(e.getCause().toString()); + } + } + } + + /** + * Create the arguments needed to execute tensorflow1 in another + * process with the corresponding tensors + * @return + */ + private List<String> getProcessCommandsWithoutArgs(){ + System.out.println(System.getProperty("java.library.path")); + String javaHome = System.getProperty("java.home"); + String javaBin = javaHome + File.separator + "bin" + File.separator + "java"; + String classpath = System.getProperty("java.class.path"); + ProtectionDomain protectionDomain = Tensorflow1Interface.class.getProtectionDomain(); + CodeSource codeSource = protectionDomain.getCodeSource(); + String className = Tensorflow1Interface.class.getName(); + classpath += File.pathSeparator; + for (File ff : new File(codeSource.getLocation().getPath()).getParentFile().listFiles()) { + classpath += ff.getAbsolutePath() + File.pathSeparator; + } + List<String> command = new LinkedList<String>(); + command.add(javaBin); + command.add("-cp"); + command.add(classpath); + command.add(className); + return command; + } // TODO make only one /** @@ -203,4 +393,75 @@ public class Tensorflow1Interface implements DeepLearningEngineInterface return outputName; } } + + /** + * Retrieve the file names used for interprocess communication + * @param args + * args provided to the main method + * @return a map with a list of input and output names + */ + private static HashMap<String, List<String>> getInputTensorsFileNames(String[] args) { + List<String> inputNames = new ArrayList<String>(); + List<String> outputNames = new ArrayList<String>(); + for (int i = 2; i < args.length; i ++) { + if (args[i].endsWith(INPUT_FILE_TERMINATION)) + inputNames.add(args[i].substring(0, args[i].length() - INPUT_FILE_TERMINATION.length())); + else if (args[i].endsWith(OUTPUT_FILE_TERMINATION)) + outputNames.add(args[i].substring(0, args[i].length() - OUTPUT_FILE_TERMINATION.length())); + } + if (inputNames.size() == 0) + throw new IllegalArgumentException("The args to the main method of '" + + Tensorflow1Interface.class.toString() + "' should contain at " + + "least one input, defined as '<input_name> + '" + INPUT_FILE_TERMINATION + "'."); + if (outputNames.size() == 0) + throw new IllegalArgumentException("The args to the main method of '" + + Tensorflow1Interface.class.toString() + "' should contain at " + + "least one output, defined as '<output_name> + '" + OUTPUT_FILE_TERMINATION + "'."); + HashMap<String, List<String>> map = new HashMap<String, List<String>>(); + map.put(INPUTS_MAP_KEY, inputNames); + map.put(OUTPUTS_MAP_KEY, outputNames); + return map; + } + + /** + * Get temporary directory to perform the interprocessing communication in MacOSX intel + * @return the tmp dir + * @throws IOException + */ + private static String getTemporaryDir() throws IOException { + String tmpDir; + if (System.getenv("temp") != null + && Files.isWritable(Paths.get(System.getenv("temp")))) { + return System.getenv("temp"); + } else if (System.getenv("TEMP") != null + && Files.isWritable(Paths.get(System.getenv("TEMP")))) { + return System.getenv("TEMP"); + } else if (System.getenv("tmp") != null + && Files.isWritable(Paths.get(System.getenv("tmp")))) { + return System.getenv("tmp"); + } else if (System.getenv("TMP") != null + && Files.isWritable(Paths.get(System.getenv("TMP")))) { + return System.getenv("TMP"); + } else if (System.getProperty("java.io.tmpdir") != null + && Files.isWritable(Paths.get(System.getProperty("java.io.tmpdir")))) { + return System.getProperty("java.io.tmpdir"); + } + String enginesDir = getEnginesDir(); + if (Files.isWritable(Paths.get(enginesDir))) { + tmpDir = enginesDir + File.separator + "temp"; + if (!(new File(tmpDir).isDirectory()) && !(new File(tmpDir).mkdirs())) + tmpDir = enginesDir; + } else { + throw new IOException("Unable to find temporal directory with writting rights. " + + "Please either allow writting on the system temporal folder or on '" + enginesDir + "'."); + } + return tmpDir; + } + + private static String getEnginesDir() { + ProtectionDomain protectionDomain = Tensorflow1Interface.class.getProtectionDomain(); + CodeSource codeSource = protectionDomain.getCodeSource(); + String jarFile = codeSource.getLocation().getPath(); + return jarFile; + } } diff --git a/src/main/java/org/bioimageanalysis/icy/deeplearning/tensorflow/v1/Tensorflow1InterfaceJAvaCPP.java b/src/main/java/org/bioimageanalysis/icy/deeplearning/tensorflow/v1/Tensorflow1InterfaceJAvaCPP.java index 2f9196d0fc44ceae5ebdf7d7cdeb4a2af47e5d78..f06aaede4a5ad103225cfc05b3cd9ba6dde89205 100644 --- a/src/main/java/org/bioimageanalysis/icy/deeplearning/tensorflow/v1/Tensorflow1InterfaceJAvaCPP.java +++ b/src/main/java/org/bioimageanalysis/icy/deeplearning/tensorflow/v1/Tensorflow1InterfaceJAvaCPP.java @@ -78,24 +78,6 @@ public class Tensorflow1InterfaceJAvaCPP implements DeepLearningEngineInterface "tf.saved_model.signature_constants.DEFAULT_EVAL_SIGNATURE_DEF_KEY", "tf.saved_model.signature_constants.SUPERVISED_TRAIN_METHOD_NAME", "tf.saved_model.signature_constants.SUPERVISED_EVAL_METHOD_NAME"}; - - /** - * Idetifier for the files that contain the data of the inputs - */ - final private static String INPUT_FILE_TERMINATION = "_model_input"; - - /** - * Idetifier for the files that contain the data of the outputs - */ - final private static String OUTPUT_FILE_TERMINATION = "_model_output"; - /** - * Key for the inputs in the map that retrieves the file names for interprocess communication - */ - final private static String INPUTS_MAP_KEY = "inputs"; - /** - * Key for the outputs in the map that retrieves the file names for interprocess communication - */ - final private static String OUTPUTS_MAP_KEY = "outputs"; /** * Number of bytes typically used to represent the datatype in the memory mapped file */ @@ -110,10 +92,6 @@ public class Tensorflow1InterfaceJAvaCPP implements DeepLearningEngineInterface * Add a factor 2 to the length as a security margin */ final private static int TYPICAL_SHAPE_LENTH = 2; - /** - * File extension for the temporal files used for interprocessing - */ - final private static String FILE_EXTENSION = ".dat"; private String tmpDir; @@ -125,35 +103,6 @@ public class Tensorflow1InterfaceJAvaCPP implements DeepLearningEngineInterface this.tmpDir = getTemporaryDir(); } - /** - * Retrieve the file names used for interprocess communication - * @param args - * args provided to the main method - * @return a map with a list of input and output names - */ - private static HashMap<String, List<String>> getInputTensorsFileNames(String[] args) { - List<String> inputNames = new ArrayList<String>(); - List<String> outputNames = new ArrayList<String>(); - for (int i = 2; i < args.length; i ++) { - if (args[i].endsWith(INPUT_FILE_TERMINATION)) - inputNames.add(args[i].substring(0, args[i].length() - INPUT_FILE_TERMINATION.length())); - else if (args[i].endsWith(OUTPUT_FILE_TERMINATION)) - outputNames.add(args[i].substring(0, args[i].length() - OUTPUT_FILE_TERMINATION.length())); - } - if (inputNames.size() == 0) - throw new IllegalArgumentException("The args to the main method of '" - + Tensorflow1InterfaceJAvaCPP.class.toString() + "' should contain at " - + "least one input, defined as '<input_name> + '" + INPUT_FILE_TERMINATION + "'."); - if (outputNames.size() == 0) - throw new IllegalArgumentException("The args to the main method of '" - + Tensorflow1InterfaceJAvaCPP.class.toString() + "' should contain at " - + "least one output, defined as '<output_name> + '" + OUTPUT_FILE_TERMINATION + "'."); - HashMap<String, List<String>> map = new HashMap<String, List<String>>(); - map.put(INPUTS_MAP_KEY, inputNames); - map.put(OUTPUTS_MAP_KEY, outputNames); - return map; - } - public static void main(String[] args) throws LoadModelException, RunModelException, IOException { // Unpack the args needed if (args.length < 4) @@ -286,48 +235,6 @@ public class Tensorflow1InterfaceJAvaCPP implements DeepLearningEngineInterface public void loadModel(String modelFolder, String modelSource) throws LoadModelException { } - - /** - * Get temporary directory to perform the interprocessing communication in MacOSX intel - * @return the tmp dir - * @throws IOException - */ - private static String getTemporaryDir() throws IOException { - String tmpDir; - if (System.getenv("temp") != null - && Files.isWritable(Paths.get(System.getenv("temp")))) { - return System.getenv("temp"); - } else if (System.getenv("TEMP") != null - && Files.isWritable(Paths.get(System.getenv("TEMP")))) { - return System.getenv("TEMP"); - } else if (System.getenv("tmp") != null - && Files.isWritable(Paths.get(System.getenv("tmp")))) { - return System.getenv("tmp"); - } else if (System.getenv("TMP") != null - && Files.isWritable(Paths.get(System.getenv("TMP")))) { - return System.getenv("TMP"); - } else if (System.getProperty("java.io.tmpdir") != null - && Files.isWritable(Paths.get(System.getProperty("java.io.tmpdir")))) { - return System.getProperty("java.io.tmpdir"); - } - String enginesDir = getEnginesDir(); - if (Files.isWritable(Paths.get(enginesDir))) { - tmpDir = enginesDir + File.separator + "temp"; - if (!(new File(tmpDir).isDirectory()) && !(new File(tmpDir).mkdirs())) - tmpDir = enginesDir; - } else { - throw new IOException("Unable to find temporal directory with writting rights. " - + "Please either allow writting on the system temporal folder or on '" + enginesDir + "'."); - } - return tmpDir; - } - - private static String getEnginesDir() { - ProtectionDomain protectionDomain = Tensorflow1InterfaceJAvaCPP.class.getProtectionDomain(); - CodeSource codeSource = protectionDomain.getCodeSource(); - String jarFile = codeSource.getLocation().getPath(); - return jarFile; - } @Override public void run(List<Tensor<?>> inputTensors, List<Tensor<?>> outputTensors) throws RunModelException { @@ -353,62 +260,6 @@ public class Tensorflow1InterfaceJAvaCPP implements DeepLearningEngineInterface retrieveInterprocessingTensors(outputTensors); } - - private void createTensorsForInterprocessing(List<Tensor<?>> tensors) throws RunModelException{ - for (Tensor<?> tensor : tensors) { - long lenFile = MappedFileBuilder.findTotalLengthFile(tensor); - try (RandomAccessFile rd = - new RandomAccessFile(tmpDir + File.separator + tensor.getName() + FILE_EXTENSION, "rw"); - FileChannel fc = rd.getChannel();) { - MappedByteBuffer mem = fc.map(FileChannel.MapMode.READ_WRITE, 0, lenFile); - ByteBuffer byteBuffer = mem.duplicate(); - byteBuffer.put(MappedFileBuilder.createFileHeader(tensor)); - MappedFileBuilder.build(tensor, byteBuffer); - } catch (IOException e) { - throw new RunModelException(e.getCause().toString()); - } - } - } - - private void retrieveInterprocessingTensors(List<Tensor<?>> tensors) throws RunModelException{ - for (Tensor<?> tensor : tensors) { - try (RandomAccessFile rd = - new RandomAccessFile(tmpDir + File.separator + tensor.getName() + FILE_EXTENSION, "r"); - FileChannel fc = rd.getChannel();) { - MappedByteBuffer mem = fc.map(FileChannel.MapMode.READ_WRITE, 0, fc.size()); - ByteBuffer byteBuffer = mem.duplicate(); - tensor.setData(MappedBufferToImgLib2.build(byteBuffer)); - } catch (IOException e) { - throw new RunModelException(e.getCause().toString()); - } - } - } - - /** - * Create the arguments needed to execute tensorflow1 in another - * process with the corresponding tensors - * @return - */ - private List<String> getProcessCommandsWithoutArgs(){ - System.out.println(System.getProperty("java.library.path")); - String javaHome = System.getProperty("java.home"); - String javaBin = javaHome + File.separator + "bin" + File.separator + "java"; - String classpath = System.getProperty("java.class.path"); - ProtectionDomain protectionDomain = Tensorflow1InterfaceJAvaCPP.class.getProtectionDomain(); - CodeSource codeSource = protectionDomain.getCodeSource(); - String jarFile = codeSource.getLocation().getPath(); - String className = Tensorflow1InterfaceJAvaCPP.class.getName(); - classpath += File.pathSeparator; - for (File ff : new File(codeSource.getLocation().getPath()).getParentFile().listFiles()) { - classpath += ff.getAbsolutePath() + File.pathSeparator; - } - List<String> command = new LinkedList<String>(); - command.add(javaBin); - command.add("-cp"); - command.add(classpath); - command.add(className); - return command; - } @Override public void closeModel() {