diff --git a/src/main/java/org/bioimageanalysis/icy/tensorflow/v1/tensor/ImgLib2Builder.java b/src/main/java/org/bioimageanalysis/icy/tensorflow/v1/tensor/ImgLib2Builder.java
index 74c5501e811c2bdc312df95117aa96c7d9e1cada..e6b8b6017e83ee311015a4fd39c18eebeeaf65ca 100644
--- a/src/main/java/org/bioimageanalysis/icy/tensorflow/v1/tensor/ImgLib2Builder.java
+++ b/src/main/java/org/bioimageanalysis/icy/tensorflow/v1/tensor/ImgLib2Builder.java
@@ -5,11 +5,14 @@ import java.nio.DoubleBuffer;
import java.nio.FloatBuffer;
import java.nio.IntBuffer;
+import org.bioimageanalysis.icy.deeplearning.utils.IndexingUtils;
import org.tensorflow.Tensor;
import org.tensorflow.types.UInt8;
+import net.imglib2.Cursor;
import net.imglib2.img.Img;
-import net.imglib2.img.array.ArrayImgs;
+import net.imglib2.img.ImgFactory;
+import net.imglib2.img.cell.CellImgFactory;
import net.imglib2.type.Type;
import net.imglib2.type.numeric.integer.ByteType;
import net.imglib2.type.numeric.integer.IntType;
@@ -69,13 +72,23 @@ public final class ImgLib2Builder
private static <T extends Type<T>> Img<ByteType> buildFromTensorByte(Tensor<UInt8> tensor)
{
long[] tensorShape = tensor.shape();
+ final ImgFactory< ByteType > factory = new CellImgFactory<>( new ByteType(), 5 );
+ final Img< ByteType > outputImg = (Img<ByteType>) factory.create(tensorShape);
+ Cursor<ByteType> tensorCursor= outputImg.cursor();
int totalSize = 1;
for (long i : tensorShape) {totalSize *= i;}
- byte[] flatImageArray = new byte[totalSize];
- ByteBuffer outBuff = ByteBuffer.wrap(flatImageArray);
+ byte[] flatArr = new byte[totalSize];
+ ByteBuffer outBuff = ByteBuffer.wrap(flatArr);
tensor.writeTo(outBuff);
outBuff = null;
- return ArrayImgs.bytes(flatImageArray, tensorShape);
+ while (tensorCursor.hasNext()) {
+ tensorCursor.fwd();
+ long[] cursorPos = tensorCursor.positionAsLongArray();
+ int flatPos = IndexingUtils.multidimensionalIntoFlatIndex(cursorPos, tensorShape);
+ byte val = flatArr[flatPos];
+ tensorCursor.get().set(val);
+ }
+ return outputImg;
}
/**
@@ -87,14 +100,24 @@ public final class ImgLib2Builder
*/
private static <T extends Type<T>> Img<IntType> buildFromTensorInt(Tensor<Integer> tensor)
{
- long[] tensorShape = tensor.shape();
+ long[] tensorShape = tensor.shape();
+ final ImgFactory< IntType > factory = new CellImgFactory<>( new IntType(), 5 );
+ final Img< IntType > outputImg = (Img<IntType>) factory.create(tensorShape);
+ Cursor<IntType> tensorCursor= outputImg.cursor();
int totalSize = 1;
for (long i : tensorShape) {totalSize *= i;}
- int[] flatImageArray = new int[totalSize];
- IntBuffer outBuff = IntBuffer.wrap(flatImageArray);
+ int[] flatArr = new int[totalSize];
+ IntBuffer outBuff = IntBuffer.wrap(flatArr);
tensor.writeTo(outBuff);
outBuff = null;
- return ArrayImgs.ints(flatImageArray, tensorShape);
+ while (tensorCursor.hasNext()) {
+ tensorCursor.fwd();
+ long[] cursorPos = tensorCursor.positionAsLongArray();
+ int flatPos = IndexingUtils.multidimensionalIntoFlatIndex(cursorPos, tensorShape);
+ int val = flatArr[flatPos];
+ tensorCursor.get().set(val);
+ }
+ return outputImg;
}
/**
@@ -106,14 +129,24 @@ public final class ImgLib2Builder
*/
private static <T extends Type<T>> Img<FloatType> buildFromTensorFloat(Tensor<Float> tensor)
{
- long[] tensorShape = tensor.shape();
+ long[] tensorShape = tensor.shape();
+ final ImgFactory< FloatType > factory = new CellImgFactory<>( new FloatType(), 5 );
+ final Img< FloatType > outputImg = (Img<FloatType>) factory.create(tensorShape);
+ Cursor<FloatType> tensorCursor= outputImg.cursor();
int totalSize = 1;
for (long i : tensorShape) {totalSize *= i;}
- float[] flatImageArray = new float[totalSize];
- FloatBuffer outBuff = FloatBuffer.wrap(flatImageArray);
+ float[] flatArr = new float[totalSize];
+ FloatBuffer outBuff = FloatBuffer.wrap(flatArr);
tensor.writeTo(outBuff);
outBuff = null;
- return ArrayImgs.floats(flatImageArray, tensorShape);
+ while (tensorCursor.hasNext()) {
+ tensorCursor.fwd();
+ long[] cursorPos = tensorCursor.positionAsLongArray();
+ int flatPos = IndexingUtils.multidimensionalIntoFlatIndex(cursorPos, tensorShape);
+ float val = flatArr[flatPos];
+ tensorCursor.get().set(val);
+ }
+ return outputImg;
}
/**
@@ -125,13 +158,23 @@ public final class ImgLib2Builder
*/
private static <T extends Type<T>> Img<DoubleType> buildFromTensorDouble(Tensor<Double> tensor)
{
- long[] tensorShape = tensor.shape();
+ long[] tensorShape = tensor.shape();
+ final ImgFactory< DoubleType > factory = new CellImgFactory<>( new DoubleType(), 5 );
+ final Img< DoubleType > outputImg = (Img<DoubleType>) factory.create(tensorShape);
+ Cursor<DoubleType> tensorCursor= outputImg.cursor();
int totalSize = 1;
for (long i : tensorShape) {totalSize *= i;}
- double[] flatImageArray = new double[totalSize];
- DoubleBuffer outBuff = DoubleBuffer.wrap(flatImageArray);
+ double[] flatArr = new double[totalSize];
+ DoubleBuffer outBuff = DoubleBuffer.wrap(flatArr);
tensor.writeTo(outBuff);
outBuff = null;
- return ArrayImgs.doubles(flatImageArray, tensorShape);
+ while (tensorCursor.hasNext()) {
+ tensorCursor.fwd();
+ long[] cursorPos = tensorCursor.positionAsLongArray();
+ int flatPos = IndexingUtils.multidimensionalIntoFlatIndex(cursorPos, tensorShape);
+ double val = flatArr[flatPos];
+ tensorCursor.get().set(val);
+ }
+ return outputImg;
}
}
diff --git a/src/main/java/org/bioimageanalysis/icy/tensorflow/v1/tensor/TensorBuilder.java b/src/main/java/org/bioimageanalysis/icy/tensorflow/v1/tensor/TensorBuilder.java
index 137533c2512ee97b56d7cd2bba823b5c2cb3f7d0..a3b2ff2d4ee43178ef0ffea4c7ea4d1f6888a560 100644
--- a/src/main/java/org/bioimageanalysis/icy/tensorflow/v1/tensor/TensorBuilder.java
+++ b/src/main/java/org/bioimageanalysis/icy/tensorflow/v1/tensor/TensorBuilder.java
@@ -6,16 +6,20 @@ import java.nio.FloatBuffer;
import java.nio.IntBuffer;
import org.bioimageanalysis.icy.deeplearning.tensor.RaiArrayUtils;
+import org.bioimageanalysis.icy.deeplearning.utils.IndexingUtils;
import org.tensorflow.Tensor;
import org.tensorflow.types.UInt8;
+import net.imglib2.Cursor;
import net.imglib2.RandomAccessibleInterval;
+import net.imglib2.img.Img;
import net.imglib2.type.Type;
import net.imglib2.type.numeric.integer.ByteType;
import net.imglib2.type.numeric.integer.IntType;
import net.imglib2.type.numeric.real.DoubleType;
import net.imglib2.type.numeric.real.FloatType;
import net.imglib2.util.Util;
+import net.imglib2.view.IntervalView;
/**
* A TensorFlow {@link Tensor} builder for {@link INDArray} and {@link org.bioimageanalysis.icy.deeplearning.tensor.Tensor} objects.
@@ -79,12 +83,30 @@ public final class TensorBuilder
* @throws IllegalArgumentException
* If the ndarray type is not supported.
*/
- private static <T extends Type<T>> Tensor<UInt8> buildByte(RandomAccessibleInterval<ByteType> ndarray)
+ private static <T extends Type<T>> Tensor<UInt8> buildByte(RandomAccessibleInterval<ByteType> imgTensor)
{
- byte[] arr = RaiArrayUtils.byteArray(ndarray);
- ByteBuffer buff = ByteBuffer.wrap(arr);
- Tensor<UInt8> tensor = Tensor.create(UInt8.class, ndarray.dimensionsAsLongArray(), buff);
- return tensor;
+ long[] tensorShape = imgTensor.dimensionsAsLongArray();
+ Cursor<ByteType> tensorCursor;
+ if (imgTensor instanceof IntervalView)
+ tensorCursor = ((IntervalView<ByteType>) imgTensor).cursor();
+ else if (imgTensor instanceof Img)
+ tensorCursor = ((Img<ByteType>) imgTensor).cursor();
+ else
+ throw new IllegalArgumentException("The data of the " + Tensor.class + " has "
+ + "to be an instance of " + Img.class + " or " + IntervalView.class);
+ long flatSize = 1;
+ for (long dd : imgTensor.dimensionsAsLongArray()) { flatSize *= dd;}
+ byte[] flatArr = new byte[(int) flatSize];
+ while (tensorCursor.hasNext()) {
+ tensorCursor.fwd();
+ long[] cursorPos = tensorCursor.positionAsLongArray();
+ int flatPos = IndexingUtils.multidimensionalIntoFlatIndex(cursorPos, tensorShape);
+ byte val = tensorCursor.get().getByte();
+ flatArr[flatPos] = val;
+ }
+ ByteBuffer buff = ByteBuffer.wrap(flatArr);
+ Tensor<UInt8> ndarray = Tensor.create(UInt8.class, imgTensor.dimensionsAsLongArray(), buff);
+ return ndarray;
}
/**
@@ -96,12 +118,30 @@ public final class TensorBuilder
* @throws IllegalArgumentException
* If the ndarray type is not supported.
*/
- private static <T extends Type<T>> Tensor<Integer> buildInt(RandomAccessibleInterval<IntType> ndarray)
+ private static <T extends Type<T>> Tensor<Integer> buildInt(RandomAccessibleInterval<IntType> imgTensor)
{
- int[] arr = RaiArrayUtils.intArray(ndarray);
- IntBuffer buff = IntBuffer.wrap(arr);
- Tensor<Integer> tensor = Tensor.create(ndarray.dimensionsAsLongArray(), buff);
- return tensor;
+ long[] tensorShape = imgTensor.dimensionsAsLongArray();
+ Cursor<IntType> tensorCursor;
+ if (imgTensor instanceof IntervalView)
+ tensorCursor = ((IntervalView<IntType>) imgTensor).cursor();
+ else if (imgTensor instanceof Img)
+ tensorCursor = ((Img<IntType>) imgTensor).cursor();
+ else
+ throw new IllegalArgumentException("The data of the " + Tensor.class + " has "
+ + "to be an instance of " + Img.class + " or " + IntervalView.class);
+ long flatSize = 1;
+ for (long dd : imgTensor.dimensionsAsLongArray()) { flatSize *= dd;}
+ int[] flatArr = new int[(int) flatSize];
+ while (tensorCursor.hasNext()) {
+ tensorCursor.fwd();
+ long[] cursorPos = tensorCursor.positionAsLongArray();
+ int flatPos = IndexingUtils.multidimensionalIntoFlatIndex(cursorPos, tensorShape);
+ int val = tensorCursor.get().getInt();
+ flatArr[flatPos] = val;
+ }
+ IntBuffer buff = IntBuffer.wrap(flatArr);
+ Tensor<Integer> ndarray = Tensor.create(imgTensor.dimensionsAsLongArray(), buff);
+ return ndarray;
}
/**