diff --git a/src/main/java/org/bioimageanalysis/icy/tensorflow/v1/tensor/TensorBuilder.java b/src/main/java/org/bioimageanalysis/icy/tensorflow/v1/tensor/TensorBuilder.java index fb4977c44dfafca5a028527c9a4a1a7d1a4f3687..931abdc6021bc68e6965b4ac026eb0c0831136d4 100644 --- a/src/main/java/org/bioimageanalysis/icy/tensorflow/v1/tensor/TensorBuilder.java +++ b/src/main/java/org/bioimageanalysis/icy/tensorflow/v1/tensor/TensorBuilder.java @@ -1,10 +1,21 @@ package org.bioimageanalysis.icy.tensorflow.v1.tensor; +import java.nio.ByteBuffer; +import java.nio.DoubleBuffer; +import java.nio.FloatBuffer; +import java.nio.IntBuffer; + +import org.bioimageanalysis.icy.deeplearning.tensor.RaiToArray; import org.tensorflow.Tensor; import org.tensorflow.types.UInt8; import net.imglib2.RandomAccessibleInterval; import net.imglib2.type.Type; +import net.imglib2.type.numeric.integer.ByteType; +import net.imglib2.type.numeric.integer.IntType; +import net.imglib2.type.numeric.real.DoubleType; +import net.imglib2.type.numeric.real.FloatType; +import net.imglib2.util.Util; /** * A TensorFlow {@link Tensor} builder for {@link INDArray} and {@link org.bioimageanalysis.icy.deeplearning.tensor.Tensor} objects. @@ -38,29 +49,29 @@ public final class TensorBuilder /** * Creates a {@link Tensor} based on the provided {@link RandomAccessibleInterval} and the desired dimension order for the resulting tensor. * - * @param ndarray + * @param rai * The NDArray to be converted. * @return The tensor created from the sequence. * @throws IllegalArgumentException * If the ndarray type is not supported. */ - public static <T extends Type<T>> Tensor<?> build(RandomAccessibleInterval<T> ndarray) + public static <T extends Type<T>> Tensor<?> build(RandomAccessibleInterval<T> rai) { - if (ndarray.dataType() == DataType.INT8 || ndarray.dataType() == DataType.UINT8) { - return buildByte(ndarray); - } else if (ndarray.dataType() == DataType.INT32) { - return buildInt(ndarray); - } else if (ndarray.dataType() == DataType.FLOAT) { - return buildFloat(ndarray); - } else if (ndarray.dataType() == DataType.DOUBLE) { - return buildDouble(ndarray); + if (Util.getTypeFromInterval(rai) instanceof ByteType) { + return buildByte((RandomAccessibleInterval<ByteType>) rai); + } else if (Util.getTypeFromInterval(rai) instanceof IntType) { + return buildInt((RandomAccessibleInterval<IntType>) rai); + } else if (Util.getTypeFromInterval(rai) instanceof FloatType) { + return buildFloat((RandomAccessibleInterval<FloatType>) rai); + } else if (Util.getTypeFromInterval(rai) instanceof DoubleType) { + return buildDouble((RandomAccessibleInterval<DoubleType>) rai); } else { - throw new IllegalArgumentException("The image has an unsupported type: " + ndarray.dataType().toString()); + throw new IllegalArgumentException("The image has an unsupported type: " + Util.getTypeFromInterval(rai).getClass().toString()); } } /** - * Creates a unsigned byte-typed {@link Tensor} based on the provided {@link INDArray} and the desired dimension order for the resulting tensor. + * Creates a unsigned byte-typed {@link Tensor} based on the provided {@link RandomAccessibleInterval} and the desired dimension order for the resulting tensor. * * @param ndarray * The sequence to be converted. @@ -68,17 +79,16 @@ public final class TensorBuilder * @throws IllegalArgumentException * If the ndarray type is not supported. */ - private static Tensor<UInt8> buildByte(INDArray ndarray) + private static <T extends Type<T>> Tensor<UInt8> buildByte(RandomAccessibleInterval<ByteType> ndarray) { - if (ndarray.dataType() != DataType.INT8 && ndarray.dataType() != DataType.UINT8) - throw new IllegalArgumentException("Image is not of byte type: " + ndarray.dataType().toString()); - - Tensor<UInt8> tensor = Tensor.create(UInt8.class, ndarray.shape(), ndarray.data().asNio()); + byte[] arr = RaiToArray.byteArray(ndarray); + ByteBuffer buff = ByteBuffer.wrap(arr); + Tensor<UInt8> tensor = Tensor.create(UInt8.class, ndarray.dimensionsAsLongArray(), buff); return tensor; } /** - * Creates a integer-typed {@link Tensor} based on the provided {@link INDArray} and the desired dimension order for the resulting tensor. + * Creates a integer-typed {@link Tensor} based on the provided {@link RandomAccessibleInterval} and the desired dimension order for the resulting tensor. * * @param ndarray * The sequence to be converted. @@ -86,17 +96,16 @@ public final class TensorBuilder * @throws IllegalArgumentException * If the ndarray type is not supported. */ - private static Tensor<Integer> buildInt(INDArray ndarray) + private static <T extends Type<T>> Tensor<Integer> buildInt(RandomAccessibleInterval<IntType> ndarray) { - if (ndarray.dataType() != DataType.INT32) - throw new IllegalArgumentException("Image is not of int type: " + ndarray.dataType().toString()); - - Tensor<Integer> tensor = Tensor.create(ndarray.shape(), ndarray.data().asNioInt()); + int[] arr = RaiToArray.intArray(ndarray); + IntBuffer buff = IntBuffer.wrap(arr); + Tensor<Integer> tensor = Tensor.create(ndarray.dimensionsAsLongArray(), buff); return tensor; } /** - * Creates a float-typed {@link Tensor} based on the provided {@link INDArray} and the desired dimension order for the resulting tensor. + * Creates a float-typed {@link Tensor} based on the provided {@link RandomAccessibleInterval} and the desired dimension order for the resulting tensor. * * @param ndarray * The sequence to be converted. @@ -104,17 +113,16 @@ public final class TensorBuilder * @throws IllegalArgumentException * If the ndarray type is not supported. */ - private static Tensor<Float> buildFloat(INDArray ndarray) + private static <T extends Type<T>> Tensor<Float> buildFloat(RandomAccessibleInterval<FloatType> ndarray) { - if (ndarray.dataType() != DataType.FLOAT) - throw new IllegalArgumentException("Image is not of float type: " + ndarray.dataType().toString()); - - Tensor<Float> tensor = Tensor.create(ndarray.shape(), ndarray.data().asNioFloat()); + float[] arr = RaiToArray.floatArray(ndarray); + FloatBuffer buff = FloatBuffer.wrap(arr); + Tensor<Float> tensor = Tensor.create(ndarray.dimensionsAsLongArray(), buff); return tensor; } /** - * Creates a double-typed {@link Tensor} based on the provided {@link INDArray} and the desired dimension order for the resulting tensor. + * Creates a double-typed {@link Tensor} based on the provided {@link RandomAccessibleInterval} and the desired dimension order for the resulting tensor. * * @param ndarray * The ndarray to be converted. @@ -122,12 +130,11 @@ public final class TensorBuilder * @throws IllegalArgumentException * If the ndarray type is not supported. */ - private static Tensor<Double> buildDouble(INDArray ndarray) + private static <T extends Type<T>> Tensor<Double> buildDouble(RandomAccessibleInterval<DoubleType> ndarray) { - if (ndarray.dataType() != DataType.DOUBLE) - throw new IllegalArgumentException("Image is not of float type: " + ndarray.dataType().toString()); - - Tensor<Double> tensor = Tensor.create(ndarray.shape(), ndarray.data().asNioDouble()); + double[] arr = RaiToArray.doubleArray(ndarray); + DoubleBuffer buff = DoubleBuffer.wrap(arr); + Tensor<Double> tensor = Tensor.create(ndarray.dimensionsAsLongArray(), buff); return tensor; } }