From 481579e6149f500632e051d98cf5d23f97ff67ea Mon Sep 17 00:00:00 2001 From: carlosuc3m <100329787@alumnos.uc3m.es> Date: Tue, 12 Jul 2022 18:45:45 +0200 Subject: [PATCH] Merge branch 'develop' of https://gitlab.pasteur.fr/bia/tensorflow-1-interface.git into develop --- .../tensorflow/v1/tensor/ImgLib2Builder.java | 132 ++++++++++++++++++ .../tensorflow/v1/tensor/TensorBuilder.java | 9 +- 2 files changed, 137 insertions(+), 4 deletions(-) create mode 100644 src/main/java/org/bioimageanalysis/icy/tensorflow/v1/tensor/ImgLib2Builder.java diff --git a/src/main/java/org/bioimageanalysis/icy/tensorflow/v1/tensor/ImgLib2Builder.java b/src/main/java/org/bioimageanalysis/icy/tensorflow/v1/tensor/ImgLib2Builder.java new file mode 100644 index 0000000..146cc67 --- /dev/null +++ b/src/main/java/org/bioimageanalysis/icy/tensorflow/v1/tensor/ImgLib2Builder.java @@ -0,0 +1,132 @@ +package org.bioimageanalysis.icy.tensorflow.v1.tensor; + +import java.nio.ByteBuffer; +import java.nio.DoubleBuffer; +import java.nio.FloatBuffer; +import java.nio.IntBuffer; + +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; +import org.tensorflow.Tensor; +import org.tensorflow.types.UInt8; + +/** + * A {@link INDArray} builder for TensorFlow {@link Tensor} objects. + * + * @author Carlos Garcia Lopez de Haro + */ +public final class ImgLib2Builder +{ + + /** + * Not used (Utility class). + */ + private ImgLib2Builder() + { + } + + /** + * Creates a {@link INDArray} from a given {@link Tensor} and an array with its dimensions order. + * + * @param tensor + * The tensor data is read from. + * @return The INDArray built from the tensor. + * @throws IllegalArgumentException + * If the tensor type is not supported. + */ + @SuppressWarnings("unchecked") + public static INDArray build(Tensor<?> tensor) throws IllegalArgumentException + { + // Create an INDArray of the same type of the tensor + switch (tensor.dataType()) + { + case UINT8: + return buildFromTensorByte((Tensor<UInt8>) tensor); + case INT32: + return buildFromTensorInt((Tensor<Integer>) tensor); + case FLOAT: + return buildFromTensorFloat((Tensor<Float>) tensor); + case DOUBLE: + return buildFromTensorDouble((Tensor<Double>) tensor); + default: + throw new IllegalArgumentException("Unsupported tensor type: " + tensor.dataType()); + } + } + + /** + * Builds a {@link INDArray} from a unsigned byte-typed {@link Tensor}. + * + * @param tensor + * The tensor data is read from. + * @return The INDArray built from the tensor of type {@link DataType#UBYTE}. + */ + private static INDArray buildFromTensorByte(Tensor<UInt8> tensor) + { + long[] tensorShape = tensor.shape(); + int totalSize = 1; + for (long i : tensorShape) {totalSize *= i;} + byte[] flatImageArray = new byte[totalSize]; + ByteBuffer outBuff = ByteBuffer.wrap(flatImageArray); + tensor.writeTo(outBuff); + outBuff = null; + return Nd4j.create(flatImageArray, tensorShape, DataType.INT8); + } + + /** + * Builds a {@link INDArray} from a unsigned integer-typed {@link Tensor}. + * + * @param tensor + * The tensor data is read from. + * @return The sequence built from the tensor of type {@link DataType#INT}. + */ + private static INDArray buildFromTensorInt(Tensor<Integer> tensor) + { + long[] tensorShape = tensor.shape(); + int totalSize = 1; + for (long i : tensorShape) {totalSize *= i;} + int[] flatImageArray = new int[totalSize]; + IntBuffer outBuff = IntBuffer.wrap(flatImageArray); + tensor.writeTo(outBuff); + outBuff = null; + return Nd4j.create(flatImageArray, tensorShape, DataType.INT32); + } + + /** + * Builds a {@link INDArray} from a unsigned float-typed {@link Tensor}. + * + * @param tensor + * The tensor data is read from. + * @return The INDArray built from the tensor of type {@link DataType#FLOAT}. + */ + private static INDArray buildFromTensorFloat(Tensor<Float> tensor) + { + long[] tensorShape = tensor.shape(); + int totalSize = 1; + for (long i : tensorShape) {totalSize *= i;} + float[] flatImageArray = new float[totalSize]; + FloatBuffer outBuff = FloatBuffer.wrap(flatImageArray); + tensor.writeTo(outBuff); + outBuff = null; + return Nd4j.create(flatImageArray, tensorShape, DataType.FLOAT); + } + + /** + * Builds a {@link INDArray} from a unsigned double-typed {@link Tensor}. + * + * @param tensor + * The tensor data is read from. + * @return The INDArray built from the tensor of type {@link DataType#DOUBLE}. + */ + private static INDArray buildFromTensorDouble(Tensor<Double> tensor) + { + long[] tensorShape = tensor.shape(); + int totalSize = 1; + for (long i : tensorShape) {totalSize *= i;} + double[] flatImageArray = new double[totalSize]; + DoubleBuffer outBuff = DoubleBuffer.wrap(flatImageArray); + tensor.writeTo(outBuff); + outBuff = null; + return Nd4j.create(flatImageArray, tensorShape, DataType.DOUBLE); + } +} diff --git a/src/main/java/org/bioimageanalysis/icy/tensorflow/v1/tensor/TensorBuilder.java b/src/main/java/org/bioimageanalysis/icy/tensorflow/v1/tensor/TensorBuilder.java index 3332be6..fb4977c 100644 --- a/src/main/java/org/bioimageanalysis/icy/tensorflow/v1/tensor/TensorBuilder.java +++ b/src/main/java/org/bioimageanalysis/icy/tensorflow/v1/tensor/TensorBuilder.java @@ -1,10 +1,11 @@ package org.bioimageanalysis.icy.tensorflow.v1.tensor; -import org.nd4j.linalg.api.buffer.DataType; -import org.nd4j.linalg.api.ndarray.INDArray; import org.tensorflow.Tensor; import org.tensorflow.types.UInt8; +import net.imglib2.RandomAccessibleInterval; +import net.imglib2.type.Type; + /** * A TensorFlow {@link Tensor} builder for {@link INDArray} and {@link org.bioimageanalysis.icy.deeplearning.tensor.Tensor} objects. * @@ -35,7 +36,7 @@ public final class TensorBuilder } /** - * Creates a {@link Tensor} based on the provided {@link INDArray} and the desired dimension order for the resulting tensor. + * Creates a {@link Tensor} based on the provided {@link RandomAccessibleInterval} and the desired dimension order for the resulting tensor. * * @param ndarray * The NDArray to be converted. @@ -43,7 +44,7 @@ public final class TensorBuilder * @throws IllegalArgumentException * If the ndarray type is not supported. */ - public static Tensor<?> build(INDArray ndarray) + public static <T extends Type<T>> Tensor<?> build(RandomAccessibleInterval<T> ndarray) { if (ndarray.dataType() == DataType.INT8 || ndarray.dataType() == DataType.UINT8) { return buildByte(ndarray); -- GitLab