diff --git a/src/main/java/org/bioimageanalysis/icy/tensorflow/v1/tensor/ImgLib2Builder.java b/src/main/java/org/bioimageanalysis/icy/tensorflow/v1/tensor/ImgLib2Builder.java new file mode 100644 index 0000000000000000000000000000000000000000..146cc67963297c7fa533e79c26f5c982c6ca6a8b --- /dev/null +++ b/src/main/java/org/bioimageanalysis/icy/tensorflow/v1/tensor/ImgLib2Builder.java @@ -0,0 +1,132 @@ +package org.bioimageanalysis.icy.tensorflow.v1.tensor; + +import java.nio.ByteBuffer; +import java.nio.DoubleBuffer; +import java.nio.FloatBuffer; +import java.nio.IntBuffer; + +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; +import org.tensorflow.Tensor; +import org.tensorflow.types.UInt8; + +/** + * A {@link INDArray} builder for TensorFlow {@link Tensor} objects. + * + * @author Carlos Garcia Lopez de Haro + */ +public final class ImgLib2Builder +{ + + /** + * Not used (Utility class). + */ + private ImgLib2Builder() + { + } + + /** + * Creates a {@link INDArray} from a given {@link Tensor} and an array with its dimensions order. + * + * @param tensor + * The tensor data is read from. + * @return The INDArray built from the tensor. + * @throws IllegalArgumentException + * If the tensor type is not supported. + */ + @SuppressWarnings("unchecked") + public static INDArray build(Tensor<?> tensor) throws IllegalArgumentException + { + // Create an INDArray of the same type of the tensor + switch (tensor.dataType()) + { + case UINT8: + return buildFromTensorByte((Tensor<UInt8>) tensor); + case INT32: + return buildFromTensorInt((Tensor<Integer>) tensor); + case FLOAT: + return buildFromTensorFloat((Tensor<Float>) tensor); + case DOUBLE: + return buildFromTensorDouble((Tensor<Double>) tensor); + default: + throw new IllegalArgumentException("Unsupported tensor type: " + tensor.dataType()); + } + } + + /** + * Builds a {@link INDArray} from a unsigned byte-typed {@link Tensor}. + * + * @param tensor + * The tensor data is read from. + * @return The INDArray built from the tensor of type {@link DataType#UBYTE}. + */ + private static INDArray buildFromTensorByte(Tensor<UInt8> tensor) + { + long[] tensorShape = tensor.shape(); + int totalSize = 1; + for (long i : tensorShape) {totalSize *= i;} + byte[] flatImageArray = new byte[totalSize]; + ByteBuffer outBuff = ByteBuffer.wrap(flatImageArray); + tensor.writeTo(outBuff); + outBuff = null; + return Nd4j.create(flatImageArray, tensorShape, DataType.INT8); + } + + /** + * Builds a {@link INDArray} from a unsigned integer-typed {@link Tensor}. + * + * @param tensor + * The tensor data is read from. + * @return The sequence built from the tensor of type {@link DataType#INT}. + */ + private static INDArray buildFromTensorInt(Tensor<Integer> tensor) + { + long[] tensorShape = tensor.shape(); + int totalSize = 1; + for (long i : tensorShape) {totalSize *= i;} + int[] flatImageArray = new int[totalSize]; + IntBuffer outBuff = IntBuffer.wrap(flatImageArray); + tensor.writeTo(outBuff); + outBuff = null; + return Nd4j.create(flatImageArray, tensorShape, DataType.INT32); + } + + /** + * Builds a {@link INDArray} from a unsigned float-typed {@link Tensor}. + * + * @param tensor + * The tensor data is read from. + * @return The INDArray built from the tensor of type {@link DataType#FLOAT}. + */ + private static INDArray buildFromTensorFloat(Tensor<Float> tensor) + { + long[] tensorShape = tensor.shape(); + int totalSize = 1; + for (long i : tensorShape) {totalSize *= i;} + float[] flatImageArray = new float[totalSize]; + FloatBuffer outBuff = FloatBuffer.wrap(flatImageArray); + tensor.writeTo(outBuff); + outBuff = null; + return Nd4j.create(flatImageArray, tensorShape, DataType.FLOAT); + } + + /** + * Builds a {@link INDArray} from a unsigned double-typed {@link Tensor}. + * + * @param tensor + * The tensor data is read from. + * @return The INDArray built from the tensor of type {@link DataType#DOUBLE}. + */ + private static INDArray buildFromTensorDouble(Tensor<Double> tensor) + { + long[] tensorShape = tensor.shape(); + int totalSize = 1; + for (long i : tensorShape) {totalSize *= i;} + double[] flatImageArray = new double[totalSize]; + DoubleBuffer outBuff = DoubleBuffer.wrap(flatImageArray); + tensor.writeTo(outBuff); + outBuff = null; + return Nd4j.create(flatImageArray, tensorShape, DataType.DOUBLE); + } +} diff --git a/src/main/java/org/bioimageanalysis/icy/tensorflow/v1/tensor/TensorBuilder.java b/src/main/java/org/bioimageanalysis/icy/tensorflow/v1/tensor/TensorBuilder.java index 3332be6996af203c5d60fb7dfe1fd66190597a12..fb4977c44dfafca5a028527c9a4a1a7d1a4f3687 100644 --- a/src/main/java/org/bioimageanalysis/icy/tensorflow/v1/tensor/TensorBuilder.java +++ b/src/main/java/org/bioimageanalysis/icy/tensorflow/v1/tensor/TensorBuilder.java @@ -1,10 +1,11 @@ package org.bioimageanalysis.icy.tensorflow.v1.tensor; -import org.nd4j.linalg.api.buffer.DataType; -import org.nd4j.linalg.api.ndarray.INDArray; import org.tensorflow.Tensor; import org.tensorflow.types.UInt8; +import net.imglib2.RandomAccessibleInterval; +import net.imglib2.type.Type; + /** * A TensorFlow {@link Tensor} builder for {@link INDArray} and {@link org.bioimageanalysis.icy.deeplearning.tensor.Tensor} objects. * @@ -35,7 +36,7 @@ public final class TensorBuilder } /** - * Creates a {@link Tensor} based on the provided {@link INDArray} and the desired dimension order for the resulting tensor. + * Creates a {@link Tensor} based on the provided {@link RandomAccessibleInterval} and the desired dimension order for the resulting tensor. * * @param ndarray * The NDArray to be converted. @@ -43,7 +44,7 @@ public final class TensorBuilder * @throws IllegalArgumentException * If the ndarray type is not supported. */ - public static Tensor<?> build(INDArray ndarray) + public static <T extends Type<T>> Tensor<?> build(RandomAccessibleInterval<T> ndarray) { if (ndarray.dataType() == DataType.INT8 || ndarray.dataType() == DataType.UINT8) { return buildByte(ndarray);