Skip to content
Snippets Groups Projects
Commit 11c81ce3 authored by carlosuc3m's avatar carlosuc3m
Browse files

correct tensor creation from imglib2

parent 2fcae999
No related branches found
No related tags found
No related merge requests found
......@@ -5,11 +5,14 @@ import java.nio.DoubleBuffer;
import java.nio.FloatBuffer;
import java.nio.IntBuffer;
import org.bioimageanalysis.icy.deeplearning.utils.IndexingUtils;
import org.tensorflow.Tensor;
import org.tensorflow.types.UInt8;
import net.imglib2.Cursor;
import net.imglib2.img.Img;
import net.imglib2.img.array.ArrayImgs;
import net.imglib2.img.ImgFactory;
import net.imglib2.img.cell.CellImgFactory;
import net.imglib2.type.Type;
import net.imglib2.type.numeric.integer.ByteType;
import net.imglib2.type.numeric.integer.IntType;
......@@ -69,13 +72,23 @@ public final class ImgLib2Builder
private static <T extends Type<T>> Img<ByteType> buildFromTensorByte(Tensor<UInt8> tensor)
{
long[] tensorShape = tensor.shape();
final ImgFactory< ByteType > factory = new CellImgFactory<>( new ByteType(), 5 );
final Img< ByteType > outputImg = (Img<ByteType>) factory.create(tensorShape);
Cursor<ByteType> tensorCursor= outputImg.cursor();
int totalSize = 1;
for (long i : tensorShape) {totalSize *= i;}
byte[] flatImageArray = new byte[totalSize];
ByteBuffer outBuff = ByteBuffer.wrap(flatImageArray);
byte[] flatArr = new byte[totalSize];
ByteBuffer outBuff = ByteBuffer.wrap(flatArr);
tensor.writeTo(outBuff);
outBuff = null;
return ArrayImgs.bytes(flatImageArray, tensorShape);
while (tensorCursor.hasNext()) {
tensorCursor.fwd();
long[] cursorPos = tensorCursor.positionAsLongArray();
int flatPos = IndexingUtils.multidimensionalIntoFlatIndex(cursorPos, tensorShape);
byte val = flatArr[flatPos];
tensorCursor.get().set(val);
}
return outputImg;
}
/**
......@@ -87,14 +100,24 @@ public final class ImgLib2Builder
*/
private static <T extends Type<T>> Img<IntType> buildFromTensorInt(Tensor<Integer> tensor)
{
long[] tensorShape = tensor.shape();
long[] tensorShape = tensor.shape();
final ImgFactory< IntType > factory = new CellImgFactory<>( new IntType(), 5 );
final Img< IntType > outputImg = (Img<IntType>) factory.create(tensorShape);
Cursor<IntType> tensorCursor= outputImg.cursor();
int totalSize = 1;
for (long i : tensorShape) {totalSize *= i;}
int[] flatImageArray = new int[totalSize];
IntBuffer outBuff = IntBuffer.wrap(flatImageArray);
int[] flatArr = new int[totalSize];
IntBuffer outBuff = IntBuffer.wrap(flatArr);
tensor.writeTo(outBuff);
outBuff = null;
return ArrayImgs.ints(flatImageArray, tensorShape);
while (tensorCursor.hasNext()) {
tensorCursor.fwd();
long[] cursorPos = tensorCursor.positionAsLongArray();
int flatPos = IndexingUtils.multidimensionalIntoFlatIndex(cursorPos, tensorShape);
int val = flatArr[flatPos];
tensorCursor.get().set(val);
}
return outputImg;
}
/**
......@@ -106,14 +129,24 @@ public final class ImgLib2Builder
*/
private static <T extends Type<T>> Img<FloatType> buildFromTensorFloat(Tensor<Float> tensor)
{
long[] tensorShape = tensor.shape();
long[] tensorShape = tensor.shape();
final ImgFactory< FloatType > factory = new CellImgFactory<>( new FloatType(), 5 );
final Img< FloatType > outputImg = (Img<FloatType>) factory.create(tensorShape);
Cursor<FloatType> tensorCursor= outputImg.cursor();
int totalSize = 1;
for (long i : tensorShape) {totalSize *= i;}
float[] flatImageArray = new float[totalSize];
FloatBuffer outBuff = FloatBuffer.wrap(flatImageArray);
float[] flatArr = new float[totalSize];
FloatBuffer outBuff = FloatBuffer.wrap(flatArr);
tensor.writeTo(outBuff);
outBuff = null;
return ArrayImgs.floats(flatImageArray, tensorShape);
while (tensorCursor.hasNext()) {
tensorCursor.fwd();
long[] cursorPos = tensorCursor.positionAsLongArray();
int flatPos = IndexingUtils.multidimensionalIntoFlatIndex(cursorPos, tensorShape);
float val = flatArr[flatPos];
tensorCursor.get().set(val);
}
return outputImg;
}
/**
......@@ -125,13 +158,23 @@ public final class ImgLib2Builder
*/
private static <T extends Type<T>> Img<DoubleType> buildFromTensorDouble(Tensor<Double> tensor)
{
long[] tensorShape = tensor.shape();
long[] tensorShape = tensor.shape();
final ImgFactory< DoubleType > factory = new CellImgFactory<>( new DoubleType(), 5 );
final Img< DoubleType > outputImg = (Img<DoubleType>) factory.create(tensorShape);
Cursor<DoubleType> tensorCursor= outputImg.cursor();
int totalSize = 1;
for (long i : tensorShape) {totalSize *= i;}
double[] flatImageArray = new double[totalSize];
DoubleBuffer outBuff = DoubleBuffer.wrap(flatImageArray);
double[] flatArr = new double[totalSize];
DoubleBuffer outBuff = DoubleBuffer.wrap(flatArr);
tensor.writeTo(outBuff);
outBuff = null;
return ArrayImgs.doubles(flatImageArray, tensorShape);
while (tensorCursor.hasNext()) {
tensorCursor.fwd();
long[] cursorPos = tensorCursor.positionAsLongArray();
int flatPos = IndexingUtils.multidimensionalIntoFlatIndex(cursorPos, tensorShape);
double val = flatArr[flatPos];
tensorCursor.get().set(val);
}
return outputImg;
}
}
......@@ -6,16 +6,20 @@ import java.nio.FloatBuffer;
import java.nio.IntBuffer;
import org.bioimageanalysis.icy.deeplearning.tensor.RaiArrayUtils;
import org.bioimageanalysis.icy.deeplearning.utils.IndexingUtils;
import org.tensorflow.Tensor;
import org.tensorflow.types.UInt8;
import net.imglib2.Cursor;
import net.imglib2.RandomAccessibleInterval;
import net.imglib2.img.Img;
import net.imglib2.type.Type;
import net.imglib2.type.numeric.integer.ByteType;
import net.imglib2.type.numeric.integer.IntType;
import net.imglib2.type.numeric.real.DoubleType;
import net.imglib2.type.numeric.real.FloatType;
import net.imglib2.util.Util;
import net.imglib2.view.IntervalView;
/**
* A TensorFlow {@link Tensor} builder for {@link INDArray} and {@link org.bioimageanalysis.icy.deeplearning.tensor.Tensor} objects.
......@@ -79,12 +83,30 @@ public final class TensorBuilder
* @throws IllegalArgumentException
* If the ndarray type is not supported.
*/
private static <T extends Type<T>> Tensor<UInt8> buildByte(RandomAccessibleInterval<ByteType> ndarray)
private static <T extends Type<T>> Tensor<UInt8> buildByte(RandomAccessibleInterval<ByteType> imgTensor)
{
byte[] arr = RaiArrayUtils.byteArray(ndarray);
ByteBuffer buff = ByteBuffer.wrap(arr);
Tensor<UInt8> tensor = Tensor.create(UInt8.class, ndarray.dimensionsAsLongArray(), buff);
return tensor;
long[] tensorShape = imgTensor.dimensionsAsLongArray();
Cursor<ByteType> tensorCursor;
if (imgTensor instanceof IntervalView)
tensorCursor = ((IntervalView<ByteType>) imgTensor).cursor();
else if (imgTensor instanceof Img)
tensorCursor = ((Img<ByteType>) imgTensor).cursor();
else
throw new IllegalArgumentException("The data of the " + Tensor.class + " has "
+ "to be an instance of " + Img.class + " or " + IntervalView.class);
long flatSize = 1;
for (long dd : imgTensor.dimensionsAsLongArray()) { flatSize *= dd;}
byte[] flatArr = new byte[(int) flatSize];
while (tensorCursor.hasNext()) {
tensorCursor.fwd();
long[] cursorPos = tensorCursor.positionAsLongArray();
int flatPos = IndexingUtils.multidimensionalIntoFlatIndex(cursorPos, tensorShape);
byte val = tensorCursor.get().getByte();
flatArr[flatPos] = val;
}
ByteBuffer buff = ByteBuffer.wrap(flatArr);
Tensor<UInt8> ndarray = Tensor.create(UInt8.class, imgTensor.dimensionsAsLongArray(), buff);
return ndarray;
}
/**
......@@ -96,12 +118,30 @@ public final class TensorBuilder
* @throws IllegalArgumentException
* If the ndarray type is not supported.
*/
private static <T extends Type<T>> Tensor<Integer> buildInt(RandomAccessibleInterval<IntType> ndarray)
private static <T extends Type<T>> Tensor<Integer> buildInt(RandomAccessibleInterval<IntType> imgTensor)
{
int[] arr = RaiArrayUtils.intArray(ndarray);
IntBuffer buff = IntBuffer.wrap(arr);
Tensor<Integer> tensor = Tensor.create(ndarray.dimensionsAsLongArray(), buff);
return tensor;
long[] tensorShape = imgTensor.dimensionsAsLongArray();
Cursor<IntType> tensorCursor;
if (imgTensor instanceof IntervalView)
tensorCursor = ((IntervalView<IntType>) imgTensor).cursor();
else if (imgTensor instanceof Img)
tensorCursor = ((Img<IntType>) imgTensor).cursor();
else
throw new IllegalArgumentException("The data of the " + Tensor.class + " has "
+ "to be an instance of " + Img.class + " or " + IntervalView.class);
long flatSize = 1;
for (long dd : imgTensor.dimensionsAsLongArray()) { flatSize *= dd;}
int[] flatArr = new int[(int) flatSize];
while (tensorCursor.hasNext()) {
tensorCursor.fwd();
long[] cursorPos = tensorCursor.positionAsLongArray();
int flatPos = IndexingUtils.multidimensionalIntoFlatIndex(cursorPos, tensorShape);
int val = tensorCursor.get().getInt();
flatArr[flatPos] = val;
}
IntBuffer buff = IntBuffer.wrap(flatArr);
Tensor<Integer> ndarray = Tensor.create(imgTensor.dimensionsAsLongArray(), buff);
return ndarray;
}
/**
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment