diff --git a/src/main/java/org/bioimageanalysis/icy/tensorflow/v1/tensor/TensorBuilder.java b/src/main/java/org/bioimageanalysis/icy/tensorflow/v1/tensor/TensorBuilder.java index a3b2ff2d4ee43178ef0ffea4c7ea4d1f6888a560..094fc00835d4868e25c2caa828e3cd37a6c41b06 100644 --- a/src/main/java/org/bioimageanalysis/icy/tensorflow/v1/tensor/TensorBuilder.java +++ b/src/main/java/org/bioimageanalysis/icy/tensorflow/v1/tensor/TensorBuilder.java @@ -5,7 +5,6 @@ import java.nio.DoubleBuffer; import java.nio.FloatBuffer; import java.nio.IntBuffer; -import org.bioimageanalysis.icy.deeplearning.tensor.RaiArrayUtils; import org.bioimageanalysis.icy.deeplearning.utils.IndexingUtils; import org.tensorflow.Tensor; import org.tensorflow.types.UInt8; @@ -153,12 +152,30 @@ public final class TensorBuilder * @throws IllegalArgumentException * If the ndarray type is not supported. */ - private static <T extends Type<T>> Tensor<Float> buildFloat(RandomAccessibleInterval<FloatType> ndarray) + private static <T extends Type<T>> Tensor<Float> buildFloat(RandomAccessibleInterval<FloatType> imgTensor) { - float[] arr = RaiArrayUtils.floatArray(ndarray); - FloatBuffer buff = FloatBuffer.wrap(arr); - Tensor<Float> tensor = Tensor.create(ndarray.dimensionsAsLongArray(), buff); - return tensor; + long[] tensorShape = imgTensor.dimensionsAsLongArray(); + Cursor<FloatType> tensorCursor; + if (imgTensor instanceof IntervalView) + tensorCursor = ((IntervalView<FloatType>) imgTensor).cursor(); + else if (imgTensor instanceof Img) + tensorCursor = ((Img<FloatType>) imgTensor).cursor(); + else + throw new IllegalArgumentException("The data of the " + Tensor.class + " has " + + "to be an instance of " + Img.class + " or " + IntervalView.class); + long flatSize = 1; + for (long dd : imgTensor.dimensionsAsLongArray()) { flatSize *= dd;} + float[] flatArr = new float[(int) flatSize]; + while (tensorCursor.hasNext()) { + tensorCursor.fwd(); + long[] cursorPos = tensorCursor.positionAsLongArray(); + int flatPos = IndexingUtils.multidimensionalIntoFlatIndex(cursorPos, tensorShape); + float val = tensorCursor.get().getRealFloat(); + flatArr[flatPos] = val; + } + FloatBuffer buff = FloatBuffer.wrap(flatArr); + Tensor<Float> tensor = Tensor.create(imgTensor.dimensionsAsLongArray(), buff); + return tensor; } /** @@ -170,11 +187,29 @@ public final class TensorBuilder * @throws IllegalArgumentException * If the ndarray type is not supported. */ - private static <T extends Type<T>> Tensor<Double> buildDouble(RandomAccessibleInterval<DoubleType> ndarray) + private static <T extends Type<T>> Tensor<Double> buildDouble(RandomAccessibleInterval<DoubleType> imgTensor) { - double[] arr = RaiArrayUtils.doubleArray(ndarray); - DoubleBuffer buff = DoubleBuffer.wrap(arr); - Tensor<Double> tensor = Tensor.create(ndarray.dimensionsAsLongArray(), buff); - return tensor; + long[] tensorShape = imgTensor.dimensionsAsLongArray(); + Cursor<DoubleType> tensorCursor; + if (imgTensor instanceof IntervalView) + tensorCursor = ((IntervalView<DoubleType>) imgTensor).cursor(); + else if (imgTensor instanceof Img) + tensorCursor = ((Img<DoubleType>) imgTensor).cursor(); + else + throw new IllegalArgumentException("The data of the " + Tensor.class + " has " + + "to be an instance of " + Img.class + " or " + IntervalView.class); + long flatSize = 1; + for (long dd : imgTensor.dimensionsAsLongArray()) { flatSize *= dd;} + double[] flatArr = new double[(int) flatSize]; + while (tensorCursor.hasNext()) { + tensorCursor.fwd(); + long[] cursorPos = tensorCursor.positionAsLongArray(); + int flatPos = IndexingUtils.multidimensionalIntoFlatIndex(cursorPos, tensorShape); + double val = tensorCursor.get().getRealFloat(); + flatArr[flatPos] = val; + } + DoubleBuffer buff = DoubleBuffer.wrap(flatArr); + Tensor<Double> tensor = Tensor.create(imgTensor.dimensionsAsLongArray(), buff); + return tensor; } }