diff --git a/src/main/java/org/bioimageanalysis/icy/deeplearning/tensorflow/v1/Tensorflow1Interface.java b/src/main/java/org/bioimageanalysis/icy/deeplearning/tensorflow/v1/Tensorflow1Interface.java index c50a3fcb98519a19b504a30b91c5a41e7ea089e1..18e5e5dc3224d65fbd23a6ecd5d340f79cddf2e0 100644 --- a/src/main/java/org/bioimageanalysis/icy/deeplearning/tensorflow/v1/Tensorflow1Interface.java +++ b/src/main/java/org/bioimageanalysis/icy/deeplearning/tensorflow/v1/Tensorflow1Interface.java @@ -65,6 +65,14 @@ public class Tensorflow1Interface implements DeepLearningEngineInterface public Tensorflow1Interface() { } + + public static void main(String[] args) throws LoadModelException { + Tensorflow1Interface intf = new Tensorflow1Interface(); + intf.loadModel("C:\\Users\\angel\\OneDrive\\Documentos\\pasteur\\git\\deep-icy\\models\\Neuron Segmentation in 2D EM (Membrane)_02022023_175546", ""); + String aa = intf.getModelInputName("input"); + String bb = intf.getModelOutputName("output"); + System.out.print(false); + } @Override public void loadModel(String modelFolder, String modelSource) throws LoadModelException { @@ -133,7 +141,6 @@ public class Tensorflow1Interface implements DeepLearningEngineInterface model.close(); } model = null; - } // TODO make only one diff --git a/src/main/java/org/bioimageanalysis/icy/deeplearning/tensorflow/v1/Tensorflow1InterfaceJAvaCPP.java b/src/main/java/org/bioimageanalysis/icy/deeplearning/tensorflow/v1/Tensorflow1InterfaceJAvaCPP.java new file mode 100644 index 0000000000000000000000000000000000000000..2f9196d0fc44ceae5ebdf7d7cdeb4a2af47e5d78 --- /dev/null +++ b/src/main/java/org/bioimageanalysis/icy/deeplearning/tensorflow/v1/Tensorflow1InterfaceJAvaCPP.java @@ -0,0 +1,417 @@ +package org.bioimageanalysis.icy.deeplearning.tensorflow.v1; + +import java.io.File; +import java.io.IOException; +import java.io.OutputStream; +import java.io.RandomAccessFile; +import java.net.ServerSocket; +import java.net.Socket; +import java.net.URL; +import java.net.URLClassLoader; +import java.nio.ByteBuffer; +import java.nio.FloatBuffer; +import java.nio.MappedByteBuffer; +import java.nio.channels.FileChannel; +import java.nio.file.Files; +import java.nio.file.Paths; +import java.security.CodeSource; +import java.security.ProtectionDomain; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.LinkedList; +import java.util.List; +import java.util.stream.Collectors; + +import org.bioimageanalysis.icy.deeplearning.engine.DeepLearningEngineInterface; +import org.bioimageanalysis.icy.deeplearning.exceptions.LoadModelException; +import org.bioimageanalysis.icy.deeplearning.exceptions.RunModelException; +import org.bioimageanalysis.icy.deeplearning.tensor.Tensor; +import org.bioimageanalysis.icy.deeplearning.tensorflow.v1.tensor.MappedBufferToImgLib2; +import org.bioimageanalysis.icy.deeplearning.tensorflow.v1.tensor.MappedFileBuilder; +import org.bioimageanalysis.icy.deeplearning.tensorflow.v1.tensor.TensorBuilder; +import org.bytedeco.javacpp.BytePointer; +import org.bytedeco.javacpp.Pointer; +import org.bytedeco.javacpp.indexer.Indexer; +import org.bytedeco.tensorflow.RunOptions; +import org.bytedeco.tensorflow.SessionOptions; +import org.bytedeco.tensorflow.StringTensorPairVector; +import org.bytedeco.tensorflow.StringUnorderedSet; +import org.bytedeco.tensorflow.StringVector; +import org.bytedeco.tensorflow.TensorShape; +import org.bytedeco.tensorflow.TensorVector; + + +/** + * This plugin includes the libraries to convert back and forth TensorFlow 1 to Sequences and IcyBufferedImages. + * + * @see IcyBufferedImageBuilder IcyBufferedImageBuilder: Create images from tensors. + * @see Nd4fBuilder SequenceBuilder: Create sequences from tensors. + * @see TensorBuilder TensorBuilder: Create tensors from images and sequences. + * @author Carlos Garcia Lopez de Haro and Daniel Felipe Gonzalez Obando + */ +public class Tensorflow1InterfaceJAvaCPP implements DeepLearningEngineInterface +{ + private static final String[] MODEL_TAGS = {"serve", "inference", "train", "eval", "gpu", "tpu"}; + + private static final String[] TF_MODEL_TAGS = {"tf.saved_model.tag_constants.SERVING", + "tf.saved_model.tag_constants.INFERENCE", "tf.saved_model.tag_constants.TRAINING", + "tf.saved_model.tag_constants.EVAL", "tf.saved_model.tag_constants.GPU", + "tf.saved_model.tag_constants.TPU"}; + + private static final String[] SIGNATURE_CONSTANTS = {"serving_default", "inputs", "tensorflow/serving/classify", + "classes", "scores", "inputs", "tensorflow/serving/predict", "outputs", "inputs", + "tensorflow/serving/regress", "outputs", "train", "eval", "tensorflow/supervised/training", + "tensorflow/supervised/eval"}; + + private static final String[] TF_SIGNATURE_CONSTANTS = { + "tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY", + "tf.saved_model.signature_constants.CLASSIFY_INPUTS", + "tf.saved_model.signature_constants.CLASSIFY_METHOD_NAME", + "tf.saved_model.signature_constants.CLASSIFY_OUTPUT_CLASSES", + "tf.saved_model.signature_constants.CLASSIFY_OUTPUT_SCORES", + "tf.saved_model.signature_constants.PREDICT_INPUTS", + "tf.saved_model.signature_constants.PREDICT_METHOD_NAME", + "tf.saved_model.signature_constants.PREDICT_OUTPUTS", "tf.saved_model.signature_constants.REGRESS_INPUTS", + "tf.saved_model.signature_constants.REGRESS_METHOD_NAME", + "tf.saved_model.signature_constants.REGRESS_OUTPUTS", + "tf.saved_model.signature_constants.DEFAULT_TRAIN_SIGNATURE_DEF_KEY", + "tf.saved_model.signature_constants.DEFAULT_EVAL_SIGNATURE_DEF_KEY", + "tf.saved_model.signature_constants.SUPERVISED_TRAIN_METHOD_NAME", + "tf.saved_model.signature_constants.SUPERVISED_EVAL_METHOD_NAME"}; + + /** + * Idetifier for the files that contain the data of the inputs + */ + final private static String INPUT_FILE_TERMINATION = "_model_input"; + + /** + * Idetifier for the files that contain the data of the outputs + */ + final private static String OUTPUT_FILE_TERMINATION = "_model_output"; + /** + * Key for the inputs in the map that retrieves the file names for interprocess communication + */ + final private static String INPUTS_MAP_KEY = "inputs"; + /** + * Key for the outputs in the map that retrieves the file names for interprocess communication + */ + final private static String OUTPUTS_MAP_KEY = "outputs"; + /** + * Number of bytes typically used to represent the datatype in the memory mapped file + */ + final private static int DATATYPE_KEY_LENGTH = 2; + /** + * Number of bytes typically used to represent an identification header in the memory mapped file + */ + final private static int MAGIC_WORD_LENGTH = 2; + /** + * Number of bytes typically used to represent the shape in the memory mapped file. + * A typical shape (on the bigger side) would be: [1, 3, 32, 1024, 1024]. + * Add a factor 2 to the length as a security margin + */ + final private static int TYPICAL_SHAPE_LENTH = 2; + /** + * File extension for the temporal files used for interprocessing + */ + final private static String FILE_EXTENSION = ".dat"; + + private String tmpDir; + + private String modelFolder; + + + public Tensorflow1InterfaceJAvaCPP() throws IOException + { + this.tmpDir = getTemporaryDir(); + } + + /** + * Retrieve the file names used for interprocess communication + * @param args + * args provided to the main method + * @return a map with a list of input and output names + */ + private static HashMap<String, List<String>> getInputTensorsFileNames(String[] args) { + List<String> inputNames = new ArrayList<String>(); + List<String> outputNames = new ArrayList<String>(); + for (int i = 2; i < args.length; i ++) { + if (args[i].endsWith(INPUT_FILE_TERMINATION)) + inputNames.add(args[i].substring(0, args[i].length() - INPUT_FILE_TERMINATION.length())); + else if (args[i].endsWith(OUTPUT_FILE_TERMINATION)) + outputNames.add(args[i].substring(0, args[i].length() - OUTPUT_FILE_TERMINATION.length())); + } + if (inputNames.size() == 0) + throw new IllegalArgumentException("The args to the main method of '" + + Tensorflow1InterfaceJAvaCPP.class.toString() + "' should contain at " + + "least one input, defined as '<input_name> + '" + INPUT_FILE_TERMINATION + "'."); + if (outputNames.size() == 0) + throw new IllegalArgumentException("The args to the main method of '" + + Tensorflow1InterfaceJAvaCPP.class.toString() + "' should contain at " + + "least one output, defined as '<output_name> + '" + OUTPUT_FILE_TERMINATION + "'."); + HashMap<String, List<String>> map = new HashMap<String, List<String>>(); + map.put(INPUTS_MAP_KEY, inputNames); + map.put(OUTPUTS_MAP_KEY, outputNames); + return map; + } + + public static void main(String[] args) throws LoadModelException, RunModelException, IOException { + // Unpack the args needed + if (args.length < 4) + throw new IllegalArgumentException("Error exectuting Tensorflow 1, " + + "at least arguments are required:" + System.lineSeparator() + + " - Folder where the model is located" + System.lineSeparator() + + " - Temporary dir where the memory mapped files are located" + System.lineSeparator() + + " - Name of the model input followed by the String + '_model_input'" + System.lineSeparator() + + " - Name of the second model input (if it exists) followed by the String + '_model_input'" + System.lineSeparator() + + " - ...." + System.lineSeparator() + + " - Name of the nth model input (if it exists) followed by the String + '_model_input'" + System.lineSeparator() + + " - Name of the model output followed by the String + '_model_output'" + System.lineSeparator() + + " - Name of the second model output (if it exists) followed by the String + '_model_output'" + System.lineSeparator() + + " - ...." + System.lineSeparator() + + " - Name of the nth model output (if it exists) followed by the String + '_model_output'" + System.lineSeparator() + ); + String modelFolder = args[0]; + if (new File(modelFolder).isDirectory()) { + throw new IllegalArgumentException("Argument 0 of the main method, '" + modelFolder + "' " + + "should be an existing directory containing a Tensorflow 1 model."); + } + + Tensorflow1InterfaceJAvaCPP tfInterface = new Tensorflow1InterfaceJAvaCPP(); + tfInterface.tmpDir = args[1]; + if (new File(args[1]).isDirectory()) { + throw new IllegalArgumentException("Argument 1 of the main method, '" + args[1] + "' " + + "should be an existing directory."); + } + + tfInterface.loadModel(modelFolder, modelFolder); + + HashMap<String, List<String>> map = getInputTensorsFileNames(args); + List<String> inputNames = map.get(INPUTS_MAP_KEY); + List<String> outputNames = map.get(OUTPUTS_MAP_KEY); + + for (String inName : inputNames) { + try (RandomAccessFile rd = + new RandomAccessFile(tfInterface.tmpDir + File.separator + inName, "r"); + FileChannel fc = rd.getChannel();) { + int initialLenghToRead = inName.length() + DATATYPE_KEY_LENGTH + MAGIC_WORD_LENGTH + TYPICAL_SHAPE_LENTH; + MappedByteBuffer mem = fc.map(FileChannel.MapMode.READ_WRITE, 0, fc.size()); + } + } + + + + + + + System.out.println(""); + System.out.println(System.getProperty("java.library.path")); + String dir = args[0]; + org.bytedeco.tensorflow.SavedModelBundle model = new org.bytedeco.tensorflow.SavedModelBundle(); + StringUnorderedSet ss = new StringUnorderedSet(); + ss.insert(new BytePointer("serve")); + SessionOptions sesOp = new SessionOptions(); + RunOptions runOp = new RunOptions(); + org.bytedeco.tensorflow.global.tensorflow.LoadSavedModel(sesOp, runOp, dir, ss, model); + + TensorShape shape = new TensorShape(1, 512, 512, 1); + org.bytedeco.tensorflow.Tensor x = + new org.bytedeco.tensorflow.Tensor(org.bytedeco.tensorflow.global.tensorflow.DT_FLOAT, + shape); + FloatBuffer x_flat = x.createBuffer(); + float[] arr = new float[512 * 512]; + x_flat.put(arr); + + TensorVector outputV = new TensorVector(); + //outputV.resize(1); + org.bytedeco.tensorflow.Session session = model.session(); + StringTensorPairVector pair = new StringTensorPairVector(new String[] {"input_2_4:0"}, new org.bytedeco.tensorflow.Tensor[] {x}); + StringVector sVec = new StringVector("conv2d_33_4/Sigmoid:0"); + StringVector svec2 = new StringVector("conv2d_33_4/Sigmoid:0"); + session.Run(pair, sVec, svec2, outputV); + /* + session.Run(new RunOptions(), + new StringTensorPairVector(new String[] {"input_2_4:0"}, new org.bytedeco.tensorflow.Tensor[] {x}), + new StringVector("conv2d_33_4/Sigmoid:0"), new StringVector("conv2d_33_4/Sigmoid:0"), + out, new RunMetadata()); + */ + + + //org.bytedeco.tensorflow.Tensor y = outputV.get(0); + + Indexer aa = outputV.get(0).createIndexer(); + Indexer bb = x.createIndexer(); + //Indexer cc = out.get(0).createIndexer(); + FloatBuffer y_flat = outputV.get(0).createBuffer(); + float sss = 0; + for (int i = 0; i < y_flat.capacity(); i ++) { + sss += y_flat.get(i); + } + float mean = sss / y_flat.capacity(); + x.close(); + outputV.close(); + session.Close(); + session.close(); + model.close(); + aa.close(); + bb.close(); + ss.close(); + sesOp.close(); + runOp.close(); + shape.Clear(); + shape.close(); + pair.clear(); + pair.close(); + sVec.clear(); + sVec.close(); + svec2.clear(); + svec2.close(); + + + x.deallocate(); + outputV.deallocate(); + session.deallocate(); + session.deallocate(); + model.deallocate(); + ss.deallocate(); + sesOp.deallocate(); + runOp.deallocate(); + shape.deallocate(); + pair.deallocate(); + sVec.deallocate(); + svec2.deallocate(); + System.out.println(mean); + } + + @Override + public void loadModel(String modelFolder, String modelSource) throws LoadModelException { + + } + + /** + * Get temporary directory to perform the interprocessing communication in MacOSX intel + * @return the tmp dir + * @throws IOException + */ + private static String getTemporaryDir() throws IOException { + String tmpDir; + if (System.getenv("temp") != null + && Files.isWritable(Paths.get(System.getenv("temp")))) { + return System.getenv("temp"); + } else if (System.getenv("TEMP") != null + && Files.isWritable(Paths.get(System.getenv("TEMP")))) { + return System.getenv("TEMP"); + } else if (System.getenv("tmp") != null + && Files.isWritable(Paths.get(System.getenv("tmp")))) { + return System.getenv("tmp"); + } else if (System.getenv("TMP") != null + && Files.isWritable(Paths.get(System.getenv("TMP")))) { + return System.getenv("TMP"); + } else if (System.getProperty("java.io.tmpdir") != null + && Files.isWritable(Paths.get(System.getProperty("java.io.tmpdir")))) { + return System.getProperty("java.io.tmpdir"); + } + String enginesDir = getEnginesDir(); + if (Files.isWritable(Paths.get(enginesDir))) { + tmpDir = enginesDir + File.separator + "temp"; + if (!(new File(tmpDir).isDirectory()) && !(new File(tmpDir).mkdirs())) + tmpDir = enginesDir; + } else { + throw new IOException("Unable to find temporal directory with writting rights. " + + "Please either allow writting on the system temporal folder or on '" + enginesDir + "'."); + } + return tmpDir; + } + + private static String getEnginesDir() { + ProtectionDomain protectionDomain = Tensorflow1InterfaceJAvaCPP.class.getProtectionDomain(); + CodeSource codeSource = protectionDomain.getCodeSource(); + String jarFile = codeSource.getLocation().getPath(); + return jarFile; + } + + @Override + public void run(List<Tensor<?>> inputTensors, List<Tensor<?>> outputTensors) throws RunModelException { + createTensorsForInterprocessing(inputTensors); + List<String> args = getProcessCommandsWithoutArgs(); + args.add(modelFolder); + args.add(this.tmpDir); + for (Tensor tensor : inputTensors) {args.add(tensor.getName() + INPUT_FILE_TERMINATION);} + for (Tensor tensor : outputTensors) {args.add(tensor.getName() + OUTPUT_FILE_TERMINATION);} + + ProcessBuilder builder = new ProcessBuilder(args); + Process process; + try { + process = builder.inheritIO().start(); + if (process.waitFor() != 0) + throw new RunModelException("Error executing the Tensorflow 1 model in" + + " a separate process. The process was not terminated correctly."); + } catch (RunModelException e) { + throw e; + } catch (Exception e) { + throw new RunModelException(e.getCause().toString()); + } + + retrieveInterprocessingTensors(outputTensors); + } + + private void createTensorsForInterprocessing(List<Tensor<?>> tensors) throws RunModelException{ + for (Tensor<?> tensor : tensors) { + long lenFile = MappedFileBuilder.findTotalLengthFile(tensor); + try (RandomAccessFile rd = + new RandomAccessFile(tmpDir + File.separator + tensor.getName() + FILE_EXTENSION, "rw"); + FileChannel fc = rd.getChannel();) { + MappedByteBuffer mem = fc.map(FileChannel.MapMode.READ_WRITE, 0, lenFile); + ByteBuffer byteBuffer = mem.duplicate(); + byteBuffer.put(MappedFileBuilder.createFileHeader(tensor)); + MappedFileBuilder.build(tensor, byteBuffer); + } catch (IOException e) { + throw new RunModelException(e.getCause().toString()); + } + } + } + + private void retrieveInterprocessingTensors(List<Tensor<?>> tensors) throws RunModelException{ + for (Tensor<?> tensor : tensors) { + try (RandomAccessFile rd = + new RandomAccessFile(tmpDir + File.separator + tensor.getName() + FILE_EXTENSION, "r"); + FileChannel fc = rd.getChannel();) { + MappedByteBuffer mem = fc.map(FileChannel.MapMode.READ_WRITE, 0, fc.size()); + ByteBuffer byteBuffer = mem.duplicate(); + tensor.setData(MappedBufferToImgLib2.build(byteBuffer)); + } catch (IOException e) { + throw new RunModelException(e.getCause().toString()); + } + } + } + + /** + * Create the arguments needed to execute tensorflow1 in another + * process with the corresponding tensors + * @return + */ + private List<String> getProcessCommandsWithoutArgs(){ + System.out.println(System.getProperty("java.library.path")); + String javaHome = System.getProperty("java.home"); + String javaBin = javaHome + File.separator + "bin" + File.separator + "java"; + String classpath = System.getProperty("java.class.path"); + ProtectionDomain protectionDomain = Tensorflow1InterfaceJAvaCPP.class.getProtectionDomain(); + CodeSource codeSource = protectionDomain.getCodeSource(); + String jarFile = codeSource.getLocation().getPath(); + String className = Tensorflow1InterfaceJAvaCPP.class.getName(); + classpath += File.pathSeparator; + for (File ff : new File(codeSource.getLocation().getPath()).getParentFile().listFiles()) { + classpath += ff.getAbsolutePath() + File.pathSeparator; + } + List<String> command = new LinkedList<String>(); + command.add(javaBin); + command.add("-cp"); + command.add(classpath); + command.add(className); + return command; + } + + @Override + public void closeModel() { + Pointer.interruptDeallocatorThread(); + } +} diff --git a/src/main/java/org/bioimageanalysis/icy/deeplearning/tensorflow/v1/tensor/MappedBufferToImgLib2.java b/src/main/java/org/bioimageanalysis/icy/deeplearning/tensorflow/v1/tensor/MappedBufferToImgLib2.java new file mode 100644 index 0000000000000000000000000000000000000000..98878c84f9dddddb9f6891c4567309c62ddae703 --- /dev/null +++ b/src/main/java/org/bioimageanalysis/icy/deeplearning/tensorflow/v1/tensor/MappedBufferToImgLib2.java @@ -0,0 +1,218 @@ +package org.bioimageanalysis.icy.deeplearning.tensorflow.v1.tensor; + +import java.nio.ByteBuffer; +import java.nio.DoubleBuffer; +import java.nio.FloatBuffer; +import java.nio.IntBuffer; +import java.util.Arrays; +import java.util.HashMap; +import java.util.Map; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +import org.bioimageanalysis.icy.deeplearning.utils.IndexingUtils; +import org.tensorflow.Tensor; +import org.tensorflow.types.UInt8; + +import net.imglib2.Cursor; +import net.imglib2.img.Img; +import net.imglib2.img.ImgFactory; +import net.imglib2.img.cell.CellImgFactory; +import net.imglib2.type.Type; +import net.imglib2.type.numeric.integer.ByteType; +import net.imglib2.type.numeric.integer.IntType; +import net.imglib2.type.numeric.real.DoubleType; +import net.imglib2.type.numeric.real.FloatType; + +/** + * A {@link Img} builder for TensorFlow {@link Tensor} objects. + * + * @author Carlos Garcia Lopez de Haro + */ +public final class MappedBufferToImgLib2 +{ + /** + * Pattern that matches the header of the temporal file for interprocess communication + * and retrieves data type and shape + */ + private static final Pattern HEADER_PATTERN = + Pattern.compile("\\{'dtype':'([a-zA-Z0-9]+)',shape':'\\[(\\d+(,\\d+)*)\\]'}"); + /** + * Key for data type info + */ + private static final String DATA_TYPE_KEY = "dtype"; + /** + * Key for shape info + */ + private static final String SHAPE_KEY = "shape"; + + /** + * Not used (Utility class). + */ + private MappedBufferToImgLib2() + { + } + + /** + * Creates a {@link Img} from a given {@link Tensor} and an array with its dimensions order. + * + * @param tensor + * The tensor data is read from. + * @return The INDArray built from the tensor. + * @throws IllegalArgumentException + * If the tensor type is not supported. + */ + @SuppressWarnings("unchecked") + public static <T extends Type<T>> Img<T> build(ByteBuffer tensor) throws IllegalArgumentException + { + String infoStr = getTensorInfoFromBuffer(tensor); + HashMap<String, Object> map = getDataTypeAndShape(infoStr); + String dtype = (String) map.get(DATA_TYPE_KEY); + long[] shape = (long[]) map.get(SHAPE_KEY); + + // Create an INDArray of the same type of the tensor + switch (dtype) + { + case "byte": + return (Img<T>) buildFromTensorByte(tensor, shape); + case "int32": + return (Img<T>) buildFromTensorInt(tensor, shape); + case "float32": + return (Img<T>) buildFromTensorFloat(tensor, shape); + case "float64": + return (Img<T>) buildFromTensorDouble(tensor, shape); + default: + throw new IllegalArgumentException("Unsupported tensor type: " + dtype); + } + } + + /** + * Builds a {@link Img} from a unsigned byte-typed {@link Tensor}. + * + * @param tensor + * The tensor data is read from. + * @return The INDArray built from the tensor of type {@link DataType#UBYTE}. + */ + private static Img<ByteType> buildFromTensorByte(ByteBuffer tensor, long[] tensorShape) + { + final ImgFactory< ByteType > factory = new CellImgFactory<>( new ByteType(), 5 ); + final Img< ByteType > outputImg = (Img<ByteType>) factory.create(tensorShape); + Cursor<ByteType> tensorCursor= outputImg.cursor(); + while (tensorCursor.hasNext()) { + tensorCursor.fwd(); + tensorCursor.get().set(tensor.get()); + } + return outputImg; + } + + /** + * Builds a {@link Img} from a unsigned integer-typed {@link Tensor}. + * + * @param tensor + * The tensor data is read from. + * @return The sequence built from the tensor of type {@link DataType#INT}. + */ + private static Img<IntType> buildFromTensorInt(ByteBuffer tensor, long[] tensorShape) + { + final ImgFactory< IntType > factory = new CellImgFactory<>( new IntType(), 5 ); + final Img< IntType > outputImg = (Img<IntType>) factory.create(tensorShape); + Cursor<IntType> tensorCursor= outputImg.cursor(); + byte[] bytes = new byte[4]; + while (tensorCursor.hasNext()) { + tensorCursor.fwd(); + tensor.get(bytes); + int val = ((int) (bytes[0] << 24)) + ((int) (bytes[1] << 16)) + + ((int) (bytes[2] << 8)) + ((int) (bytes[3])); + tensorCursor.get().set(val); + } + return outputImg; + } + + /** + * Builds a {@link Img} from a unsigned float-typed {@link Tensor}. + * + * @param tensor + * The tensor data is read from. + * @return The INDArray built from the tensor of type {@link DataType#FLOAT}. + */ + private static Img<FloatType> buildFromTensorFloat(ByteBuffer tensor, long[] tensorShape) + { + final ImgFactory< FloatType > factory = new CellImgFactory<>( new FloatType(), 5 ); + final Img< FloatType > outputImg = (Img<FloatType>) factory.create(tensorShape); + Cursor<FloatType> tensorCursor= outputImg.cursor(); + byte[] bytes = new byte[4]; + while (tensorCursor.hasNext()) { + tensorCursor.fwd(); + tensor.get(bytes); + float val = ByteBuffer.wrap(bytes).getFloat(); + tensorCursor.get().set(val); + } + return outputImg; + } + + /** + * Builds a {@link Img} from a unsigned double-typed {@link Tensor}. + * + * @param tensor + * The tensor data is read from. + * @return The INDArray built from the tensor of type {@link DataType#DOUBLE}. + */ + private static Img<DoubleType> buildFromTensorDouble(ByteBuffer tensor, long[] tensorShape) + { + final ImgFactory< DoubleType > factory = new CellImgFactory<>( new DoubleType(), 5 ); + final Img< DoubleType > outputImg = (Img<DoubleType>) factory.create(tensorShape); + Cursor<DoubleType> tensorCursor= outputImg.cursor(); + byte[] bytes = new byte[8]; + while (tensorCursor.hasNext()) { + tensorCursor.fwd(); + tensor.get(bytes); + double val = ByteBuffer.wrap(bytes).getDouble(); + tensorCursor.get().set(val); + } + return outputImg; + } + + /** + * GEt the String info stored at the beginning of the buffer that contains + * the data type and the shape. + * @param buff + * @return + */ + public static String getTensorInfoFromBuffer(ByteBuffer buff) { + byte[] arr = new byte[MappedFileBuilder.MODEL_RUNNER_HEADER.length]; + buff.get(arr); + if (!Arrays.equals(arr, MappedFileBuilder.MODEL_RUNNER_HEADER)) + throw new IllegalArgumentException("Error sending tensors between processes."); + byte[] lenInfoInBytes = new byte[4]; + buff.get(lenInfoInBytes); + int lenInfo = ByteBuffer.wrap(lenInfoInBytes).getInt(); + byte[] stringInfoBytes = new byte[lenInfo]; + buff.get(stringInfoBytes); + return new String(stringInfoBytes); + } + + /** + * MEthod that retrieves the data type string and shape long array representing + * the data type and dimensions of the tensor saved in the temp file + * @param infoStr + * @return + */ + public static HashMap<String, Object> getDataTypeAndShape(String infoStr) { + Matcher m = HEADER_PATTERN.matcher(infoStr); + if (!m.find()) { + throw new IllegalArgumentException("Cannot find datatype and dimensions " + + "info in file hader: " + infoStr); + } + String typeStr = m.group(1); + String shapeStr = m.group(2); + long[] shape = new long[0]; + if (!shapeStr.isEmpty()) { + String[] tokens = shapeStr.split(", ?"); + shape = Arrays.stream(tokens).mapToLong(Long::parseLong).toArray(); + } + HashMap<String, Object> map = new HashMap<String, Object>(); + map.put(DATA_TYPE_KEY, typeStr); + map.put(SHAPE_KEY, shapeStr); + return map; + } +} diff --git a/src/main/java/org/bioimageanalysis/icy/deeplearning/tensorflow/v1/tensor/MappedFileBuilder.java b/src/main/java/org/bioimageanalysis/icy/deeplearning/tensorflow/v1/tensor/MappedFileBuilder.java new file mode 100644 index 0000000000000000000000000000000000000000..183238ea2968e5bd99d0257532d3c52541403c04 --- /dev/null +++ b/src/main/java/org/bioimageanalysis/icy/deeplearning/tensorflow/v1/tensor/MappedFileBuilder.java @@ -0,0 +1,237 @@ +package org.bioimageanalysis.icy.deeplearning.tensorflow.v1.tensor; + +import java.nio.ByteBuffer; +import java.nio.DoubleBuffer; +import java.nio.FloatBuffer; +import java.nio.IntBuffer; +import java.util.Arrays; + +import org.bioimageanalysis.icy.deeplearning.utils.IndexingUtils; +import org.tensorflow.Tensor; +import org.tensorflow.types.UInt8; + +import net.imglib2.Cursor; +import net.imglib2.RandomAccessibleInterval; +import net.imglib2.img.Img; +import net.imglib2.type.NativeType; +import net.imglib2.type.Type; +import net.imglib2.type.numeric.RealType; +import net.imglib2.type.numeric.integer.ByteType; +import net.imglib2.type.numeric.integer.IntType; +import net.imglib2.type.numeric.real.DoubleType; +import net.imglib2.type.numeric.real.FloatType; +import net.imglib2.util.Util; +import net.imglib2.view.IntervalView; + +/** + * Class that creates temporal files for interprocessing communication for MacOSX execution + * + * @author Carlos Garcia Lopez de Haro + */ +public final class MappedFileBuilder +{ + + final public static byte[] MODEL_RUNNER_HEADER = + {(byte) 0x93, 'M', 'O', 'D', 'E', 'L', '-', 'R', 'U', 'N', 'N', 'E', 'R'}; + + /** + * Not used (Utility class). + */ + private MappedFileBuilder() + { + } + + /** + * Creates a {@link Tensor} based on the provided {@link org.bioimageanalysis.icy.deeplearning.tensor.Tensor} and the desired dimension order for the resulting tensor. + * + * @param ndarray + * The Tensor to be converted. + * @return The tensor created from the sequence. + * @throws IllegalArgumentException + * If the ndarray type is not supported. + */ + public static < T extends RealType< T > & NativeType< T > > void build(org.bioimageanalysis.icy.deeplearning.tensor.Tensor<T> tensor, ByteBuffer byteBuffer) + { + build(tensor.getData(), byteBuffer); + } + + /** + * Creates a {@link Tensor} based on the provided {@link RandomAccessibleInterval} and the desired dimension order for the resulting tensor. + * + * @param rai + * The NDArray to be converted. + * @return The tensor created from the sequence. + * @throws IllegalArgumentException + * If the ndarray type is not supported. + */ + public static <T extends Type<T>> void build(RandomAccessibleInterval<T> rai, ByteBuffer byteBuffer) + { + if (Util.getTypeFromInterval(rai) instanceof ByteType) { + buildByte((RandomAccessibleInterval<ByteType>) rai, byteBuffer); + } else if (Util.getTypeFromInterval(rai) instanceof IntType) { + buildInt((RandomAccessibleInterval<IntType>) rai, byteBuffer); + } else if (Util.getTypeFromInterval(rai) instanceof FloatType) { + buildFloat((RandomAccessibleInterval<FloatType>) rai, byteBuffer); + } else if (Util.getTypeFromInterval(rai) instanceof DoubleType) { + buildDouble((RandomAccessibleInterval<DoubleType>) rai, byteBuffer); + } else { + throw new IllegalArgumentException("The image has an unsupported type: " + Util.getTypeFromInterval(rai).getClass().toString()); + } + } + + /** + * Creates a unsigned byte-typed {@link Tensor} based on the provided {@link RandomAccessibleInterval} and the desired dimension order for the resulting tensor. + * + * @param ndarray + * The sequence to be converted. + * @return The INDArray created from the sequence. + * @throws IllegalArgumentException + * If the ndarray type is not supported. + */ + private static void buildByte(RandomAccessibleInterval<ByteType> imgTensor, ByteBuffer byteBuffer) + { + Cursor<ByteType> tensorCursor; + if (imgTensor instanceof IntervalView) + tensorCursor = ((IntervalView<ByteType>) imgTensor).cursor(); + else if (imgTensor instanceof Img) + tensorCursor = ((Img<ByteType>) imgTensor).cursor(); + else + throw new IllegalArgumentException("The data of the " + Tensor.class + " has " + + "to be an instance of " + Img.class + " or " + IntervalView.class); + while (tensorCursor.hasNext()) { + tensorCursor.fwd(); + byteBuffer.put(tensorCursor.get().getByte()); + } + } + + /** + * Creates a integer-typed {@link Tensor} based on the provided {@link RandomAccessibleInterval} and the desired dimension order for the resulting tensor. + * + * @param ndarray + * The sequence to be converted. + * @return The tensor created from the INDArray. + * @throws IllegalArgumentException + * If the ndarray type is not supported. + */ + private static void buildInt(RandomAccessibleInterval<IntType> imgTensor, ByteBuffer byteBuffer) + { + Cursor<IntType> tensorCursor; + if (imgTensor instanceof IntervalView) + tensorCursor = ((IntervalView<IntType>) imgTensor).cursor(); + else if (imgTensor instanceof Img) + tensorCursor = ((Img<IntType>) imgTensor).cursor(); + else + throw new IllegalArgumentException("The data of the " + Tensor.class + " has " + + "to be an instance of " + Img.class + " or " + IntervalView.class); + while (tensorCursor.hasNext()) { + tensorCursor.fwd(); + byteBuffer.putInt(tensorCursor.get().getInt()); + } + } + + /** + * Creates a float-typed {@link Tensor} based on the provided {@link RandomAccessibleInterval} and the desired dimension order for the resulting tensor. + * + * @param ndarray + * The sequence to be converted. + * @return The tensor created from the INDArray. + * @throws IllegalArgumentException + * If the ndarray type is not supported. + */ + private static void buildFloat(RandomAccessibleInterval<FloatType> imgTensor, ByteBuffer byteBuffer) + { + Cursor<FloatType> tensorCursor; + if (imgTensor instanceof IntervalView) + tensorCursor = ((IntervalView<FloatType>) imgTensor).cursor(); + else if (imgTensor instanceof Img) + tensorCursor = ((Img<FloatType>) imgTensor).cursor(); + else + throw new IllegalArgumentException("The data of the " + Tensor.class + " has " + + "to be an instance of " + Img.class + " or " + IntervalView.class); + while (tensorCursor.hasNext()) { + tensorCursor.fwd(); + byteBuffer.putFloat(tensorCursor.get().getRealFloat()); + } + } + + /** + * Creates a double-typed {@link Tensor} based on the provided {@link RandomAccessibleInterval} and the desired dimension order for the resulting tensor. + * + * @param ndarray + * The ndarray to be converted. + * @return The tensor created from the INDArray. + * @throws IllegalArgumentException + * If the ndarray type is not supported. + */ + private static void buildDouble(RandomAccessibleInterval<DoubleType> imgTensor, ByteBuffer byteBuffer) + { + Cursor<DoubleType> tensorCursor; + if (imgTensor instanceof IntervalView) + tensorCursor = ((IntervalView<DoubleType>) imgTensor).cursor(); + else if (imgTensor instanceof Img) + tensorCursor = ((Img<DoubleType>) imgTensor).cursor(); + else + throw new IllegalArgumentException("The data of the " + Tensor.class + " has " + + "to be an instance of " + Img.class + " or " + IntervalView.class); + while (tensorCursor.hasNext()) { + tensorCursor.fwd(); + byteBuffer.putDouble(tensorCursor.get().getRealDouble()); + } + } + + /** + * Create header for the temp file that is used for interprocess communication. + * The header should contain the first key word as an array of bytes (MODEl-RUNNER) + * @param <T> + * @param tensor + * @return + */ + public static < T extends RealType< T > & NativeType< T > > byte[] + createFileHeader(org.bioimageanalysis.icy.deeplearning.tensor.Tensor<T> tensor) { + String descriptionStr = "{'dtype':'" + tensor.getDataType() + "','shape':'" + + Arrays.toString(tensor.getData().dimensionsAsLongArray()) + "'}"; + + byte[] descriptionBytes = descriptionStr.getBytes(); + int lenDescriptionBytes = descriptionBytes.length; + byte[] inAsBytes = ByteBuffer.allocate(4).putInt(lenDescriptionBytes).array(); + int totalHeaderLen = MODEL_RUNNER_HEADER.length + inAsBytes.length + lenDescriptionBytes; + byte[] byteHeader = new byte[totalHeaderLen]; + for (int i = 0; i < MODEL_RUNNER_HEADER.length; i ++) + byteHeader[i] = MODEL_RUNNER_HEADER[i]; + for (int i = MODEL_RUNNER_HEADER.length; i < MODEL_RUNNER_HEADER.length + inAsBytes.length; i ++) + byteHeader[i] = MODEL_RUNNER_HEADER[i - MODEL_RUNNER_HEADER.length]; + for (int i = MODEL_RUNNER_HEADER.length + inAsBytes.length; i < totalHeaderLen; i ++) + byteHeader[i] = MODEL_RUNNER_HEADER[i - MODEL_RUNNER_HEADER.length - inAsBytes.length]; + + return byteHeader; + } + + /** + * Get the total byte size of the temp file that is oging to be created to do interprocess + * communication for MacOSX + * @param <T> + * @param tensor + * @return + */ + public static < T extends RealType< T > & NativeType< T > > long + findTotalLengthFile(org.bioimageanalysis.icy.deeplearning.tensor.Tensor<T> tensor) { + long startLen = createFileHeader(tensor).length; + long[] dimsArr = tensor.getData().dimensionsAsLongArray(); + long totSizeFlat = 1; + for (long i : dimsArr) {totSizeFlat *= i;} + long nBytesDt = 1; + Type<T> dtype = Util.getTypeFromInterval(tensor.getData()); + if (dtype instanceof IntType) { + nBytesDt = 4; + } else if (dtype instanceof ByteType) { + nBytesDt = 1; + } else if (dtype instanceof FloatType) { + nBytesDt = 4; + } else if (dtype instanceof DoubleType) { + nBytesDt = 8; + } else { + throw new IllegalArgumentException("Unsupported tensor type: " + dtype); + } + return startLen + nBytesDt * totSizeFlat; + } +}