diff --git a/src/main/java/org/bioimageanalysis/icy/deeplearning/tensorflow/v1/Tensorflow1Interface.java b/src/main/java/org/bioimageanalysis/icy/deeplearning/tensorflow/v1/Tensorflow1Interface.java index 4531f54d3fbd559da0af277f95937d0a17669def..6ed3e29ab43039ae17baf232bc0bf9155bf62bf4 100644 --- a/src/main/java/org/bioimageanalysis/icy/deeplearning/tensorflow/v1/Tensorflow1Interface.java +++ b/src/main/java/org/bioimageanalysis/icy/deeplearning/tensorflow/v1/Tensorflow1Interface.java @@ -103,7 +103,7 @@ public class Tensorflow1Interface implements DeepLearningEngineInterface private static SavedModelBundle model; private static SignatureDef sig; - private boolean interprocessing = false; + private boolean interprocessing = true; private String tmpDir; @@ -111,6 +111,8 @@ public class Tensorflow1Interface implements DeepLearningEngineInterface public Tensorflow1Interface() throws IOException { + // TODO remove tmpDIr + tmpDir = getTemporaryDir(); boolean isMac = PlatformDetection.isMacOS(); boolean isIntel = new PlatformDetection().getArch().equals(PlatformDetection.ARCH_X86_64); if (isMac && isIntel) { @@ -122,6 +124,8 @@ public class Tensorflow1Interface implements DeepLearningEngineInterface public Tensorflow1Interface(boolean doInterprocessing) throws IOException { + // TODO remove tmpDIr + tmpDir = getTemporaryDir(); if (!doInterprocessing) { interprocessing = false; } else { @@ -139,7 +143,7 @@ public class Tensorflow1Interface implements DeepLearningEngineInterface // Unpack the args needed if (args.length < 4) throw new IllegalArgumentException("Error exectuting Tensorflow 1, " - + "at least arguments are required:" + System.lineSeparator() + + "at least 5 arguments are required:" + System.lineSeparator() + " - Folder where the model is located" + System.lineSeparator() + " - Temporary dir where the memory mapped files are located" + System.lineSeparator() + " - Name of the model input followed by the String + '_model_input'" + System.lineSeparator() @@ -152,14 +156,14 @@ public class Tensorflow1Interface implements DeepLearningEngineInterface + " - Name of the nth model output (if it exists) followed by the String + '_model_output'" + System.lineSeparator() ); String modelFolder = args[0]; - if (new File(modelFolder).isDirectory()) { + if (!(new File(modelFolder).isDirectory())) { throw new IllegalArgumentException("Argument 0 of the main method, '" + modelFolder + "' " + "should be an existing directory containing a Tensorflow 1 model."); } - Tensorflow1Interface tfInterface = new Tensorflow1Interface(); + Tensorflow1Interface tfInterface = new Tensorflow1Interface(false); tfInterface.tmpDir = args[1]; - if (new File(args[1]).isDirectory()) { + if (!(new File(args[1]).isDirectory())) { throw new IllegalArgumentException("Argument 1 of the main method, '" + args[1] + "' " + "should be an existing directory."); } @@ -241,6 +245,18 @@ public class Tensorflow1Interface implements DeepLearningEngineInterface args.add(this.tmpDir); for (Tensor tensor : inputTensors) {args.add(tensor.getName() + INPUT_FILE_TERMINATION);} for (Tensor tensor : outputTensors) {args.add(tensor.getName() + OUTPUT_FILE_TERMINATION);} + try { + main(new String[] {args.get(4), args.get(5), args.get(6), args.get(7)}); + } catch (LoadModelException e1) { + // TODO Auto-generated catch block + e1.printStackTrace(); + } catch (IOException e1) { + // TODO Auto-generated catch block + e1.printStackTrace(); + } catch (RunModelException e1) { + // TODO Auto-generated catch block + e1.printStackTrace(); + } ProcessBuilder builder = new ProcessBuilder(args); Process process; @@ -323,13 +339,12 @@ public class Tensorflow1Interface implements DeepLearningEngineInterface try (RandomAccessFile rd = new RandomAccessFile(tmpDir + File.separator + name + FILE_EXTENSION, "r"); FileChannel fc = rd.getChannel();) { - MappedByteBuffer mem = fc.map(FileChannel.MapMode.READ_WRITE, 0, fc.size()); + MappedByteBuffer mem = fc.map(FileChannel.MapMode.READ_ONLY, 0, fc.size()); ByteBuffer byteBuffer = mem.duplicate(); - //tensor.setData(MappedBufferToImgLib2.build(byteBuffer)); + return MappedBufferToImgLib2.buildTensor(byteBuffer); } catch (IOException e) { throw new RunModelException(e.getCause().toString()); } - return null; } /** @@ -338,7 +353,6 @@ public class Tensorflow1Interface implements DeepLearningEngineInterface * @return */ private List<String> getProcessCommandsWithoutArgs(){ - System.out.println(System.getProperty("java.library.path")); String javaHome = System.getProperty("java.home"); String javaBin = javaHome + File.separator + "bin" + File.separator + "java"; String classpath = System.getProperty("java.class.path"); diff --git a/src/main/java/org/bioimageanalysis/icy/deeplearning/tensorflow/v1/tensor/MappedBufferToImgLib2.java b/src/main/java/org/bioimageanalysis/icy/deeplearning/tensorflow/v1/tensor/MappedBufferToImgLib2.java index f5d63adbee9e1e4848d709b4b9923f042af2573b..143358d0a9c50e0a49e3aec92ee853e725a6ea04 100644 --- a/src/main/java/org/bioimageanalysis/icy/deeplearning/tensorflow/v1/tensor/MappedBufferToImgLib2.java +++ b/src/main/java/org/bioimageanalysis/icy/deeplearning/tensorflow/v1/tensor/MappedBufferToImgLib2.java @@ -1,6 +1,7 @@ package org.bioimageanalysis.icy.deeplearning.tensorflow.v1.tensor; import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; import java.util.Arrays; import java.util.HashMap; import java.util.regex.Matcher; @@ -237,7 +238,7 @@ public final class MappedBufferToImgLib2 int lenInfo = ByteBuffer.wrap(lenInfoInBytes).getInt(); byte[] stringInfoBytes = new byte[lenInfo]; buff.get(stringInfoBytes); - return new String(stringInfoBytes); + return new String(stringInfoBytes, StandardCharsets.UTF_8); } /** diff --git a/src/main/java/org/bioimageanalysis/icy/deeplearning/tensorflow/v1/tensor/MappedFileBuilder.java b/src/main/java/org/bioimageanalysis/icy/deeplearning/tensorflow/v1/tensor/MappedFileBuilder.java index a941153bd503ffc28010c11ca1829bf70bf5bee3..6819629d305865721528276806fa8df48dbe94b2 100644 --- a/src/main/java/org/bioimageanalysis/icy/deeplearning/tensorflow/v1/tensor/MappedFileBuilder.java +++ b/src/main/java/org/bioimageanalysis/icy/deeplearning/tensorflow/v1/tensor/MappedFileBuilder.java @@ -1,6 +1,7 @@ package org.bioimageanalysis.icy.deeplearning.tensorflow.v1.tensor; import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; import java.util.Arrays; import org.tensorflow.Tensor; @@ -189,17 +190,17 @@ public final class MappedFileBuilder + tensor.getAxesOrderString() + "','name':'" + tensor.getName() + "','shape':'" + dimsStr + "'}"; - byte[] descriptionBytes = descriptionStr.getBytes(); + byte[] descriptionBytes = descriptionStr.getBytes(StandardCharsets.UTF_8); int lenDescriptionBytes = descriptionBytes.length; - byte[] inAsBytes = ByteBuffer.allocate(4).putInt(lenDescriptionBytes).array(); - int totalHeaderLen = MODEL_RUNNER_HEADER.length + inAsBytes.length + lenDescriptionBytes; + byte[] intAsBytes = ByteBuffer.allocate(4).putInt(lenDescriptionBytes).array(); + int totalHeaderLen = MODEL_RUNNER_HEADER.length + intAsBytes.length + lenDescriptionBytes; byte[] byteHeader = new byte[totalHeaderLen]; for (int i = 0; i < MODEL_RUNNER_HEADER.length; i ++) byteHeader[i] = MODEL_RUNNER_HEADER[i]; - for (int i = MODEL_RUNNER_HEADER.length; i < MODEL_RUNNER_HEADER.length + inAsBytes.length; i ++) - byteHeader[i] = MODEL_RUNNER_HEADER[i - MODEL_RUNNER_HEADER.length]; - for (int i = MODEL_RUNNER_HEADER.length + inAsBytes.length; i < totalHeaderLen; i ++) - byteHeader[i] = MODEL_RUNNER_HEADER[i - MODEL_RUNNER_HEADER.length - inAsBytes.length]; + for (int i = MODEL_RUNNER_HEADER.length; i < MODEL_RUNNER_HEADER.length + intAsBytes.length; i ++) + byteHeader[i] = intAsBytes[i - MODEL_RUNNER_HEADER.length]; + for (int i = MODEL_RUNNER_HEADER.length + intAsBytes.length; i < totalHeaderLen; i ++) + byteHeader[i] = descriptionBytes[i - MODEL_RUNNER_HEADER.length - intAsBytes.length]; return byteHeader; } @@ -214,7 +215,9 @@ public final class MappedFileBuilder public static < T extends RealType< T > & NativeType< T > > long findTotalLengthFile(org.bioimageanalysis.icy.deeplearning.tensor.Tensor<T> tensor) { long startLen = createFileHeader(tensor).length; - long[] dimsArr = tensor.getData().dimensionsAsLongArray(); + long[] dimsArr = tensor.getData() != null ? tensor.getData().dimensionsAsLongArray() : null; + if (dimsArr == null) + return startLen; long totSizeFlat = 1; for (long i : dimsArr) {totSizeFlat *= i;} long nBytesDt = 1;