From 11c81ce3f88ebafb27dd02b805ad763ba75390cd Mon Sep 17 00:00:00 2001 From: carlosuc3m <100329787@alumnos.uc3m.es> Date: Thu, 20 Oct 2022 17:27:43 +0200 Subject: [PATCH] correct tensor creation from imglib2 --- .../tensorflow/v1/tensor/ImgLib2Builder.java | 75 +++++++++++++++---- .../tensorflow/v1/tensor/TensorBuilder.java | 60 ++++++++++++--- 2 files changed, 109 insertions(+), 26 deletions(-) diff --git a/src/main/java/org/bioimageanalysis/icy/tensorflow/v1/tensor/ImgLib2Builder.java b/src/main/java/org/bioimageanalysis/icy/tensorflow/v1/tensor/ImgLib2Builder.java index 74c5501..e6b8b60 100644 --- a/src/main/java/org/bioimageanalysis/icy/tensorflow/v1/tensor/ImgLib2Builder.java +++ b/src/main/java/org/bioimageanalysis/icy/tensorflow/v1/tensor/ImgLib2Builder.java @@ -5,11 +5,14 @@ import java.nio.DoubleBuffer; import java.nio.FloatBuffer; import java.nio.IntBuffer; +import org.bioimageanalysis.icy.deeplearning.utils.IndexingUtils; import org.tensorflow.Tensor; import org.tensorflow.types.UInt8; +import net.imglib2.Cursor; import net.imglib2.img.Img; -import net.imglib2.img.array.ArrayImgs; +import net.imglib2.img.ImgFactory; +import net.imglib2.img.cell.CellImgFactory; import net.imglib2.type.Type; import net.imglib2.type.numeric.integer.ByteType; import net.imglib2.type.numeric.integer.IntType; @@ -69,13 +72,23 @@ public final class ImgLib2Builder private static <T extends Type<T>> Img<ByteType> buildFromTensorByte(Tensor<UInt8> tensor) { long[] tensorShape = tensor.shape(); + final ImgFactory< ByteType > factory = new CellImgFactory<>( new ByteType(), 5 ); + final Img< ByteType > outputImg = (Img<ByteType>) factory.create(tensorShape); + Cursor<ByteType> tensorCursor= outputImg.cursor(); int totalSize = 1; for (long i : tensorShape) {totalSize *= i;} - byte[] flatImageArray = new byte[totalSize]; - ByteBuffer outBuff = ByteBuffer.wrap(flatImageArray); + byte[] flatArr = new byte[totalSize]; + ByteBuffer outBuff = ByteBuffer.wrap(flatArr); tensor.writeTo(outBuff); outBuff = null; - return ArrayImgs.bytes(flatImageArray, tensorShape); + while (tensorCursor.hasNext()) { + tensorCursor.fwd(); + long[] cursorPos = tensorCursor.positionAsLongArray(); + int flatPos = IndexingUtils.multidimensionalIntoFlatIndex(cursorPos, tensorShape); + byte val = flatArr[flatPos]; + tensorCursor.get().set(val); + } + return outputImg; } /** @@ -87,14 +100,24 @@ public final class ImgLib2Builder */ private static <T extends Type<T>> Img<IntType> buildFromTensorInt(Tensor<Integer> tensor) { - long[] tensorShape = tensor.shape(); + long[] tensorShape = tensor.shape(); + final ImgFactory< IntType > factory = new CellImgFactory<>( new IntType(), 5 ); + final Img< IntType > outputImg = (Img<IntType>) factory.create(tensorShape); + Cursor<IntType> tensorCursor= outputImg.cursor(); int totalSize = 1; for (long i : tensorShape) {totalSize *= i;} - int[] flatImageArray = new int[totalSize]; - IntBuffer outBuff = IntBuffer.wrap(flatImageArray); + int[] flatArr = new int[totalSize]; + IntBuffer outBuff = IntBuffer.wrap(flatArr); tensor.writeTo(outBuff); outBuff = null; - return ArrayImgs.ints(flatImageArray, tensorShape); + while (tensorCursor.hasNext()) { + tensorCursor.fwd(); + long[] cursorPos = tensorCursor.positionAsLongArray(); + int flatPos = IndexingUtils.multidimensionalIntoFlatIndex(cursorPos, tensorShape); + int val = flatArr[flatPos]; + tensorCursor.get().set(val); + } + return outputImg; } /** @@ -106,14 +129,24 @@ public final class ImgLib2Builder */ private static <T extends Type<T>> Img<FloatType> buildFromTensorFloat(Tensor<Float> tensor) { - long[] tensorShape = tensor.shape(); + long[] tensorShape = tensor.shape(); + final ImgFactory< FloatType > factory = new CellImgFactory<>( new FloatType(), 5 ); + final Img< FloatType > outputImg = (Img<FloatType>) factory.create(tensorShape); + Cursor<FloatType> tensorCursor= outputImg.cursor(); int totalSize = 1; for (long i : tensorShape) {totalSize *= i;} - float[] flatImageArray = new float[totalSize]; - FloatBuffer outBuff = FloatBuffer.wrap(flatImageArray); + float[] flatArr = new float[totalSize]; + FloatBuffer outBuff = FloatBuffer.wrap(flatArr); tensor.writeTo(outBuff); outBuff = null; - return ArrayImgs.floats(flatImageArray, tensorShape); + while (tensorCursor.hasNext()) { + tensorCursor.fwd(); + long[] cursorPos = tensorCursor.positionAsLongArray(); + int flatPos = IndexingUtils.multidimensionalIntoFlatIndex(cursorPos, tensorShape); + float val = flatArr[flatPos]; + tensorCursor.get().set(val); + } + return outputImg; } /** @@ -125,13 +158,23 @@ public final class ImgLib2Builder */ private static <T extends Type<T>> Img<DoubleType> buildFromTensorDouble(Tensor<Double> tensor) { - long[] tensorShape = tensor.shape(); + long[] tensorShape = tensor.shape(); + final ImgFactory< DoubleType > factory = new CellImgFactory<>( new DoubleType(), 5 ); + final Img< DoubleType > outputImg = (Img<DoubleType>) factory.create(tensorShape); + Cursor<DoubleType> tensorCursor= outputImg.cursor(); int totalSize = 1; for (long i : tensorShape) {totalSize *= i;} - double[] flatImageArray = new double[totalSize]; - DoubleBuffer outBuff = DoubleBuffer.wrap(flatImageArray); + double[] flatArr = new double[totalSize]; + DoubleBuffer outBuff = DoubleBuffer.wrap(flatArr); tensor.writeTo(outBuff); outBuff = null; - return ArrayImgs.doubles(flatImageArray, tensorShape); + while (tensorCursor.hasNext()) { + tensorCursor.fwd(); + long[] cursorPos = tensorCursor.positionAsLongArray(); + int flatPos = IndexingUtils.multidimensionalIntoFlatIndex(cursorPos, tensorShape); + double val = flatArr[flatPos]; + tensorCursor.get().set(val); + } + return outputImg; } } diff --git a/src/main/java/org/bioimageanalysis/icy/tensorflow/v1/tensor/TensorBuilder.java b/src/main/java/org/bioimageanalysis/icy/tensorflow/v1/tensor/TensorBuilder.java index 137533c..a3b2ff2 100644 --- a/src/main/java/org/bioimageanalysis/icy/tensorflow/v1/tensor/TensorBuilder.java +++ b/src/main/java/org/bioimageanalysis/icy/tensorflow/v1/tensor/TensorBuilder.java @@ -6,16 +6,20 @@ import java.nio.FloatBuffer; import java.nio.IntBuffer; import org.bioimageanalysis.icy.deeplearning.tensor.RaiArrayUtils; +import org.bioimageanalysis.icy.deeplearning.utils.IndexingUtils; import org.tensorflow.Tensor; import org.tensorflow.types.UInt8; +import net.imglib2.Cursor; import net.imglib2.RandomAccessibleInterval; +import net.imglib2.img.Img; import net.imglib2.type.Type; import net.imglib2.type.numeric.integer.ByteType; import net.imglib2.type.numeric.integer.IntType; import net.imglib2.type.numeric.real.DoubleType; import net.imglib2.type.numeric.real.FloatType; import net.imglib2.util.Util; +import net.imglib2.view.IntervalView; /** * A TensorFlow {@link Tensor} builder for {@link INDArray} and {@link org.bioimageanalysis.icy.deeplearning.tensor.Tensor} objects. @@ -79,12 +83,30 @@ public final class TensorBuilder * @throws IllegalArgumentException * If the ndarray type is not supported. */ - private static <T extends Type<T>> Tensor<UInt8> buildByte(RandomAccessibleInterval<ByteType> ndarray) + private static <T extends Type<T>> Tensor<UInt8> buildByte(RandomAccessibleInterval<ByteType> imgTensor) { - byte[] arr = RaiArrayUtils.byteArray(ndarray); - ByteBuffer buff = ByteBuffer.wrap(arr); - Tensor<UInt8> tensor = Tensor.create(UInt8.class, ndarray.dimensionsAsLongArray(), buff); - return tensor; + long[] tensorShape = imgTensor.dimensionsAsLongArray(); + Cursor<ByteType> tensorCursor; + if (imgTensor instanceof IntervalView) + tensorCursor = ((IntervalView<ByteType>) imgTensor).cursor(); + else if (imgTensor instanceof Img) + tensorCursor = ((Img<ByteType>) imgTensor).cursor(); + else + throw new IllegalArgumentException("The data of the " + Tensor.class + " has " + + "to be an instance of " + Img.class + " or " + IntervalView.class); + long flatSize = 1; + for (long dd : imgTensor.dimensionsAsLongArray()) { flatSize *= dd;} + byte[] flatArr = new byte[(int) flatSize]; + while (tensorCursor.hasNext()) { + tensorCursor.fwd(); + long[] cursorPos = tensorCursor.positionAsLongArray(); + int flatPos = IndexingUtils.multidimensionalIntoFlatIndex(cursorPos, tensorShape); + byte val = tensorCursor.get().getByte(); + flatArr[flatPos] = val; + } + ByteBuffer buff = ByteBuffer.wrap(flatArr); + Tensor<UInt8> ndarray = Tensor.create(UInt8.class, imgTensor.dimensionsAsLongArray(), buff); + return ndarray; } /** @@ -96,12 +118,30 @@ public final class TensorBuilder * @throws IllegalArgumentException * If the ndarray type is not supported. */ - private static <T extends Type<T>> Tensor<Integer> buildInt(RandomAccessibleInterval<IntType> ndarray) + private static <T extends Type<T>> Tensor<Integer> buildInt(RandomAccessibleInterval<IntType> imgTensor) { - int[] arr = RaiArrayUtils.intArray(ndarray); - IntBuffer buff = IntBuffer.wrap(arr); - Tensor<Integer> tensor = Tensor.create(ndarray.dimensionsAsLongArray(), buff); - return tensor; + long[] tensorShape = imgTensor.dimensionsAsLongArray(); + Cursor<IntType> tensorCursor; + if (imgTensor instanceof IntervalView) + tensorCursor = ((IntervalView<IntType>) imgTensor).cursor(); + else if (imgTensor instanceof Img) + tensorCursor = ((Img<IntType>) imgTensor).cursor(); + else + throw new IllegalArgumentException("The data of the " + Tensor.class + " has " + + "to be an instance of " + Img.class + " or " + IntervalView.class); + long flatSize = 1; + for (long dd : imgTensor.dimensionsAsLongArray()) { flatSize *= dd;} + int[] flatArr = new int[(int) flatSize]; + while (tensorCursor.hasNext()) { + tensorCursor.fwd(); + long[] cursorPos = tensorCursor.positionAsLongArray(); + int flatPos = IndexingUtils.multidimensionalIntoFlatIndex(cursorPos, tensorShape); + int val = tensorCursor.get().getInt(); + flatArr[flatPos] = val; + } + IntBuffer buff = IntBuffer.wrap(flatArr); + Tensor<Integer> ndarray = Tensor.create(imgTensor.dimensionsAsLongArray(), buff); + return ndarray; } /** -- GitLab