diff --git a/src/main/java/org/bioimageanalysis/icy/tensorflow/v1/tensor/ImgLib2Builder.java b/src/main/java/org/bioimageanalysis/icy/tensorflow/v1/tensor/ImgLib2Builder.java index 146cc67963297c7fa533e79c26f5c982c6ca6a8b..74c5501e811c2bdc312df95117aa96c7d9e1cada 100644 --- a/src/main/java/org/bioimageanalysis/icy/tensorflow/v1/tensor/ImgLib2Builder.java +++ b/src/main/java/org/bioimageanalysis/icy/tensorflow/v1/tensor/ImgLib2Builder.java @@ -5,14 +5,19 @@ import java.nio.DoubleBuffer; import java.nio.FloatBuffer; import java.nio.IntBuffer; -import org.nd4j.linalg.api.buffer.DataType; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.factory.Nd4j; import org.tensorflow.Tensor; import org.tensorflow.types.UInt8; +import net.imglib2.img.Img; +import net.imglib2.img.array.ArrayImgs; +import net.imglib2.type.Type; +import net.imglib2.type.numeric.integer.ByteType; +import net.imglib2.type.numeric.integer.IntType; +import net.imglib2.type.numeric.real.DoubleType; +import net.imglib2.type.numeric.real.FloatType; + /** - * A {@link INDArray} builder for TensorFlow {@link Tensor} objects. + * A {@link Img} builder for TensorFlow {@link Tensor} objects. * * @author Carlos Garcia Lopez de Haro */ @@ -27,7 +32,7 @@ public final class ImgLib2Builder } /** - * Creates a {@link INDArray} from a given {@link Tensor} and an array with its dimensions order. + * Creates a {@link Img} from a given {@link Tensor} and an array with its dimensions order. * * @param tensor * The tensor data is read from. @@ -36,32 +41,32 @@ public final class ImgLib2Builder * If the tensor type is not supported. */ @SuppressWarnings("unchecked") - public static INDArray build(Tensor<?> tensor) throws IllegalArgumentException + public static <T extends Type<T>> Img<T> build(Tensor<?> tensor) throws IllegalArgumentException { // Create an INDArray of the same type of the tensor switch (tensor.dataType()) { case UINT8: - return buildFromTensorByte((Tensor<UInt8>) tensor); + return (Img<T>) buildFromTensorByte((Tensor<UInt8>) tensor); case INT32: - return buildFromTensorInt((Tensor<Integer>) tensor); + return (Img<T>) buildFromTensorInt((Tensor<Integer>) tensor); case FLOAT: - return buildFromTensorFloat((Tensor<Float>) tensor); + return (Img<T>) buildFromTensorFloat((Tensor<Float>) tensor); case DOUBLE: - return buildFromTensorDouble((Tensor<Double>) tensor); + return (Img<T>) buildFromTensorDouble((Tensor<Double>) tensor); default: throw new IllegalArgumentException("Unsupported tensor type: " + tensor.dataType()); } } /** - * Builds a {@link INDArray} from a unsigned byte-typed {@link Tensor}. + * Builds a {@link Img} from a unsigned byte-typed {@link Tensor}. * * @param tensor * The tensor data is read from. * @return The INDArray built from the tensor of type {@link DataType#UBYTE}. */ - private static INDArray buildFromTensorByte(Tensor<UInt8> tensor) + private static <T extends Type<T>> Img<ByteType> buildFromTensorByte(Tensor<UInt8> tensor) { long[] tensorShape = tensor.shape(); int totalSize = 1; @@ -70,17 +75,17 @@ public final class ImgLib2Builder ByteBuffer outBuff = ByteBuffer.wrap(flatImageArray); tensor.writeTo(outBuff); outBuff = null; - return Nd4j.create(flatImageArray, tensorShape, DataType.INT8); + return ArrayImgs.bytes(flatImageArray, tensorShape); } /** - * Builds a {@link INDArray} from a unsigned integer-typed {@link Tensor}. + * Builds a {@link Img} from a unsigned integer-typed {@link Tensor}. * * @param tensor * The tensor data is read from. * @return The sequence built from the tensor of type {@link DataType#INT}. */ - private static INDArray buildFromTensorInt(Tensor<Integer> tensor) + private static <T extends Type<T>> Img<IntType> buildFromTensorInt(Tensor<Integer> tensor) { long[] tensorShape = tensor.shape(); int totalSize = 1; @@ -89,17 +94,17 @@ public final class ImgLib2Builder IntBuffer outBuff = IntBuffer.wrap(flatImageArray); tensor.writeTo(outBuff); outBuff = null; - return Nd4j.create(flatImageArray, tensorShape, DataType.INT32); + return ArrayImgs.ints(flatImageArray, tensorShape); } /** - * Builds a {@link INDArray} from a unsigned float-typed {@link Tensor}. + * Builds a {@link Img} from a unsigned float-typed {@link Tensor}. * * @param tensor * The tensor data is read from. * @return The INDArray built from the tensor of type {@link DataType#FLOAT}. */ - private static INDArray buildFromTensorFloat(Tensor<Float> tensor) + private static <T extends Type<T>> Img<FloatType> buildFromTensorFloat(Tensor<Float> tensor) { long[] tensorShape = tensor.shape(); int totalSize = 1; @@ -108,17 +113,17 @@ public final class ImgLib2Builder FloatBuffer outBuff = FloatBuffer.wrap(flatImageArray); tensor.writeTo(outBuff); outBuff = null; - return Nd4j.create(flatImageArray, tensorShape, DataType.FLOAT); + return ArrayImgs.floats(flatImageArray, tensorShape); } /** - * Builds a {@link INDArray} from a unsigned double-typed {@link Tensor}. + * Builds a {@link Img} from a unsigned double-typed {@link Tensor}. * * @param tensor * The tensor data is read from. * @return The INDArray built from the tensor of type {@link DataType#DOUBLE}. */ - private static INDArray buildFromTensorDouble(Tensor<Double> tensor) + private static <T extends Type<T>> Img<DoubleType> buildFromTensorDouble(Tensor<Double> tensor) { long[] tensorShape = tensor.shape(); int totalSize = 1; @@ -127,6 +132,6 @@ public final class ImgLib2Builder DoubleBuffer outBuff = DoubleBuffer.wrap(flatImageArray); tensor.writeTo(outBuff); outBuff = null; - return Nd4j.create(flatImageArray, tensorShape, DataType.DOUBLE); + return ArrayImgs.doubles(flatImageArray, tensorShape); } } diff --git a/src/main/java/org/bioimageanalysis/icy/tensorflow/v1/tensor/Nd4fBuilder.java b/src/main/java/org/bioimageanalysis/icy/tensorflow/v1/tensor/Nd4fBuilder.java deleted file mode 100644 index b0a96da65a720e870d50c7f5b93582b70ceba4ef..0000000000000000000000000000000000000000 --- a/src/main/java/org/bioimageanalysis/icy/tensorflow/v1/tensor/Nd4fBuilder.java +++ /dev/null @@ -1,132 +0,0 @@ -package org.bioimageanalysis.icy.tensorflow.v1.tensor; - -import java.nio.ByteBuffer; -import java.nio.DoubleBuffer; -import java.nio.FloatBuffer; -import java.nio.IntBuffer; - -import org.nd4j.linalg.api.buffer.DataType; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.factory.Nd4j; -import org.tensorflow.Tensor; -import org.tensorflow.types.UInt8; - -/** - * A {@link INDArray} builder for TensorFlow {@link Tensor} objects. - * - * @author Carlos Garcia Lopez de Haro - */ -public final class Nd4fBuilder -{ - - /** - * Not used (Utility class). - */ - private Nd4fBuilder() - { - } - - /** - * Creates a {@link INDArray} from a given {@link Tensor} and an array with its dimensions order. - * - * @param tensor - * The tensor data is read from. - * @return The INDArray built from the tensor. - * @throws IllegalArgumentException - * If the tensor type is not supported. - */ - @SuppressWarnings("unchecked") - public static INDArray build(Tensor<?> tensor) throws IllegalArgumentException - { - // Create an INDArray of the same type of the tensor - switch (tensor.dataType()) - { - case UINT8: - return buildFromTensorByte((Tensor<UInt8>) tensor); - case INT32: - return buildFromTensorInt((Tensor<Integer>) tensor); - case FLOAT: - return buildFromTensorFloat((Tensor<Float>) tensor); - case DOUBLE: - return buildFromTensorDouble((Tensor<Double>) tensor); - default: - throw new IllegalArgumentException("Unsupported tensor type: " + tensor.dataType()); - } - } - - /** - * Builds a {@link INDArray} from a unsigned byte-typed {@link Tensor}. - * - * @param tensor - * The tensor data is read from. - * @return The INDArray built from the tensor of type {@link DataType#UBYTE}. - */ - private static INDArray buildFromTensorByte(Tensor<UInt8> tensor) - { - long[] tensorShape = tensor.shape(); - int totalSize = 1; - for (long i : tensorShape) {totalSize *= i;} - byte[] flatImageArray = new byte[totalSize]; - ByteBuffer outBuff = ByteBuffer.wrap(flatImageArray); - tensor.writeTo(outBuff); - outBuff = null; - return Nd4j.create(flatImageArray, tensorShape, DataType.INT8); - } - - /** - * Builds a {@link INDArray} from a unsigned integer-typed {@link Tensor}. - * - * @param tensor - * The tensor data is read from. - * @return The sequence built from the tensor of type {@link DataType#INT}. - */ - private static INDArray buildFromTensorInt(Tensor<Integer> tensor) - { - long[] tensorShape = tensor.shape(); - int totalSize = 1; - for (long i : tensorShape) {totalSize *= i;} - int[] flatImageArray = new int[totalSize]; - IntBuffer outBuff = IntBuffer.wrap(flatImageArray); - tensor.writeTo(outBuff); - outBuff = null; - return Nd4j.create(flatImageArray, tensorShape, DataType.INT32); - } - - /** - * Builds a {@link INDArray} from a unsigned float-typed {@link Tensor}. - * - * @param tensor - * The tensor data is read from. - * @return The INDArray built from the tensor of type {@link DataType#FLOAT}. - */ - private static INDArray buildFromTensorFloat(Tensor<Float> tensor) - { - long[] tensorShape = tensor.shape(); - int totalSize = 1; - for (long i : tensorShape) {totalSize *= i;} - float[] flatImageArray = new float[totalSize]; - FloatBuffer outBuff = FloatBuffer.wrap(flatImageArray); - tensor.writeTo(outBuff); - outBuff = null; - return Nd4j.create(flatImageArray, tensorShape, DataType.FLOAT); - } - - /** - * Builds a {@link INDArray} from a unsigned double-typed {@link Tensor}. - * - * @param tensor - * The tensor data is read from. - * @return The INDArray built from the tensor of type {@link DataType#DOUBLE}. - */ - private static INDArray buildFromTensorDouble(Tensor<Double> tensor) - { - long[] tensorShape = tensor.shape(); - int totalSize = 1; - for (long i : tensorShape) {totalSize *= i;} - double[] flatImageArray = new double[totalSize]; - DoubleBuffer outBuff = DoubleBuffer.wrap(flatImageArray); - tensor.writeTo(outBuff); - outBuff = null; - return Nd4j.create(flatImageArray, tensorShape, DataType.DOUBLE); - } -}