diff --git a/src/main/java/org/bioimageanalysis/icy/deeplearning/tensorflow/v1/Tensorflow1Interface.java b/src/main/java/org/bioimageanalysis/icy/deeplearning/tensorflow/v1/Tensorflow1Interface.java index 9907ed3de9f85eda71ab1964cc37aec1b1430210..4102d0a1925db0faa36d0054f8814bc70990c99f 100644 --- a/src/main/java/org/bioimageanalysis/icy/deeplearning/tensorflow/v1/Tensorflow1Interface.java +++ b/src/main/java/org/bioimageanalysis/icy/deeplearning/tensorflow/v1/Tensorflow1Interface.java @@ -14,6 +14,7 @@ import java.util.ArrayList; import java.util.HashMap; import java.util.LinkedList; import java.util.List; +import java.util.stream.Collectors; import org.bioimageanalysis.icy.deeplearning.engine.DeepLearningEngineInterface; import org.bioimageanalysis.icy.deeplearning.exceptions.LoadModelException; @@ -32,6 +33,9 @@ import org.tensorflow.framework.TensorInfo; import com.google.protobuf.InvalidProtocolBufferException; +import net.imglib2.type.NativeType; +import net.imglib2.type.numeric.RealType; + /** * This plugin includes the libraries to convert back and forth TensorFlow 1 to Sequences and IcyBufferedImages. @@ -131,7 +135,7 @@ public class Tensorflow1Interface implements DeepLearningEngineInterface } } - public static void main(String[] args) throws LoadModelException { + public static void main(String[] args) throws LoadModelException, IOException { // Unpack the args needed if (args.length < 4) throw new IllegalArgumentException("Error exectuting Tensorflow 1, " @@ -164,17 +168,15 @@ public class Tensorflow1Interface implements DeepLearningEngineInterface HashMap<String, List<String>> map = getInputTensorsFileNames(args); List<String> inputNames = map.get(INPUTS_MAP_KEY); + List<Tensor<?>> inputList = inputNames.stream().map(n -> { + try { + return tfInterface.retrieveInterprocessingTensorsByName(n); + } catch (RunModelException e) { + return null; + } + }).collect(Collectors.toList()); List<String> outputNames = map.get(OUTPUTS_MAP_KEY); - for - for (String inName : inputNames) { - try (RandomAccessFile rd = - new RandomAccessFile(tfInterface.tmpDir + File.separator + inName, "r"); - FileChannel fc = rd.getChannel();) { - int initialLenghToRead = inName.length() + DATATYPE_KEY_LENGTH + MAGIC_WORD_LENGTH + TYPICAL_SHAPE_LENTH; - MappedByteBuffer mem = fc.map(FileChannel.MapMode.READ_WRITE, 0, fc.size()); - } - } } @Override @@ -308,6 +310,20 @@ public class Tensorflow1Interface implements DeepLearningEngineInterface } } + private < T extends RealType< T > & NativeType< T > > Tensor<T> + retrieveInterprocessingTensorsByName(String name) throws RunModelException { + try (RandomAccessFile rd = + new RandomAccessFile(tmpDir + File.separator + name + FILE_EXTENSION, "r"); + FileChannel fc = rd.getChannel();) { + MappedByteBuffer mem = fc.map(FileChannel.MapMode.READ_WRITE, 0, fc.size()); + ByteBuffer byteBuffer = mem.duplicate(); + //tensor.setData(MappedBufferToImgLib2.build(byteBuffer)); + } catch (IOException e) { + throw new RunModelException(e.getCause().toString()); + } + return null; + } + /** * Create the arguments needed to execute tensorflow1 in another * process with the corresponding tensors diff --git a/src/main/java/org/bioimageanalysis/icy/deeplearning/tensorflow/v1/tensor/MappedBufferToImgLib2.java b/src/main/java/org/bioimageanalysis/icy/deeplearning/tensorflow/v1/tensor/MappedBufferToImgLib2.java index 98878c84f9dddddb9f6891c4567309c62ddae703..855195a92c7255d8e6dffc47c3bda7a2ceb3e5b8 100644 --- a/src/main/java/org/bioimageanalysis/icy/deeplearning/tensorflow/v1/tensor/MappedBufferToImgLib2.java +++ b/src/main/java/org/bioimageanalysis/icy/deeplearning/tensorflow/v1/tensor/MappedBufferToImgLib2.java @@ -1,24 +1,20 @@ package org.bioimageanalysis.icy.deeplearning.tensorflow.v1.tensor; import java.nio.ByteBuffer; -import java.nio.DoubleBuffer; -import java.nio.FloatBuffer; -import java.nio.IntBuffer; import java.util.Arrays; import java.util.HashMap; -import java.util.Map; import java.util.regex.Matcher; import java.util.regex.Pattern; -import org.bioimageanalysis.icy.deeplearning.utils.IndexingUtils; -import org.tensorflow.Tensor; -import org.tensorflow.types.UInt8; +import org.bioimageanalysis.icy.deeplearning.tensor.Tensor; import net.imglib2.Cursor; import net.imglib2.img.Img; import net.imglib2.img.ImgFactory; import net.imglib2.img.cell.CellImgFactory; +import net.imglib2.type.NativeType; import net.imglib2.type.Type; +import net.imglib2.type.numeric.RealType; import net.imglib2.type.numeric.integer.ByteType; import net.imglib2.type.numeric.integer.IntType; import net.imglib2.type.numeric.real.DoubleType; @@ -36,7 +32,8 @@ public final class MappedBufferToImgLib2 * and retrieves data type and shape */ private static final Pattern HEADER_PATTERN = - Pattern.compile("\\{'dtype':'([a-zA-Z0-9]+)',shape':'\\[(\\d+(,\\d+)*)\\]'}"); + Pattern.compile("\\{'dtype':'([a-zA-Z0-9]+)','axes':'([a-zA-Z0-9]+)'" + + ",'name':'(.+?)',shape':'\\[(\\d+(,\\d+)*)\\]'}"); /** * Key for data type info */ @@ -45,6 +42,10 @@ public final class MappedBufferToImgLib2 * Key for shape info */ private static final String SHAPE_KEY = "shape"; + /** + * Key for axes info + */ + private static final String AXES_KEY = "axes"; /** * Not used (Utility class). @@ -53,6 +54,45 @@ public final class MappedBufferToImgLib2 { } + /** + * Creates a {@link Img} from a given {@link Tensor} and an array with its dimensions order. + * + * @param tensor + * The tensor data is read from. + * @return The INDArray built from the tensor. + * @throws IllegalArgumentException + * If the tensor type is not supported. + */ + @SuppressWarnings("unchecked") + public static < T extends RealType< T > & NativeType< T > > Tensor<T> buildTensor(String name, ByteBuffer buff) throws IllegalArgumentException + { + String infoStr = getTensorInfoFromBuffer(buff); + HashMap<String, Object> map = getDataTypeAndShape(infoStr); + String dtype = (String) map.get(DATA_TYPE_KEY); + String axes = (String) map.get(AXES_KEY); + long[] shape = (long[]) map.get(SHAPE_KEY); + + Img<T> data; + switch (dtype) + { + case "byte": + data = (Img<T>) buildFromTensorByte(buff, shape); + break; + case "int32": + data = (Img<T>) buildFromTensorInt(buff, shape); + break; + case "float32": + data = (Img<T>) buildFromTensorFloat(buff, shape); + break; + case "float64": + data = (Img<T>) buildFromTensorDouble(buff, shape); + break; + default: + throw new IllegalArgumentException("Unsupported tensor type: " + dtype); + } + return Tensor.build(name, axes, data); + } + /** * Creates a {@link Img} from a given {@link Tensor} and an array with its dimensions order. * @@ -204,7 +244,8 @@ public final class MappedBufferToImgLib2 + "info in file hader: " + infoStr); } String typeStr = m.group(1); - String shapeStr = m.group(2); + String axesStr = m.group(2); + String shapeStr = m.group(3); long[] shape = new long[0]; if (!shapeStr.isEmpty()) { String[] tokens = shapeStr.split(", ?"); @@ -212,7 +253,8 @@ public final class MappedBufferToImgLib2 } HashMap<String, Object> map = new HashMap<String, Object>(); map.put(DATA_TYPE_KEY, typeStr); - map.put(SHAPE_KEY, shapeStr); + map.put(AXES_KEY, axesStr); + map.put(SHAPE_KEY, shape); return map; } } diff --git a/src/main/java/org/bioimageanalysis/icy/deeplearning/tensorflow/v1/tensor/MappedFileBuilder.java b/src/main/java/org/bioimageanalysis/icy/deeplearning/tensorflow/v1/tensor/MappedFileBuilder.java index 183238ea2968e5bd99d0257532d3c52541403c04..af8ade9ad1acade71a582ea82ec0c9f7e5f0e494 100644 --- a/src/main/java/org/bioimageanalysis/icy/deeplearning/tensorflow/v1/tensor/MappedFileBuilder.java +++ b/src/main/java/org/bioimageanalysis/icy/deeplearning/tensorflow/v1/tensor/MappedFileBuilder.java @@ -188,7 +188,8 @@ public final class MappedFileBuilder */ public static < T extends RealType< T > & NativeType< T > > byte[] createFileHeader(org.bioimageanalysis.icy.deeplearning.tensor.Tensor<T> tensor) { - String descriptionStr = "{'dtype':'" + tensor.getDataType() + "','shape':'" + String descriptionStr = "{'dtype':'" + tensor.getDataType() + "','axes':'" + + tensor.getAxesOrderString() + "','name':'" + tensor.getName() + "','shape':'" + Arrays.toString(tensor.getData().dimensionsAsLongArray()) + "'}"; byte[] descriptionBytes = descriptionStr.getBytes();