Skip to content
Snippets Groups Projects
Commit 18deccbf authored by carlosuc3m's avatar carlosuc3m
Browse files

adapt tf intereface to imglib2

parent 481579e6
Branches
No related tags found
No related merge requests found
package org.bioimageanalysis.icy.tensorflow.v1.tensor;
import java.nio.ByteBuffer;
import java.nio.DoubleBuffer;
import java.nio.FloatBuffer;
import java.nio.IntBuffer;
import org.bioimageanalysis.icy.deeplearning.tensor.RaiToArray;
import org.tensorflow.Tensor;
import org.tensorflow.types.UInt8;
import net.imglib2.RandomAccessibleInterval;
import net.imglib2.type.Type;
import net.imglib2.type.numeric.integer.ByteType;
import net.imglib2.type.numeric.integer.IntType;
import net.imglib2.type.numeric.real.DoubleType;
import net.imglib2.type.numeric.real.FloatType;
import net.imglib2.util.Util;
/**
* A TensorFlow {@link Tensor} builder for {@link INDArray} and {@link org.bioimageanalysis.icy.deeplearning.tensor.Tensor} objects.
......@@ -38,29 +49,29 @@ public final class TensorBuilder
/**
* Creates a {@link Tensor} based on the provided {@link RandomAccessibleInterval} and the desired dimension order for the resulting tensor.
*
* @param ndarray
* @param rai
* The NDArray to be converted.
* @return The tensor created from the sequence.
* @throws IllegalArgumentException
* If the ndarray type is not supported.
*/
public static <T extends Type<T>> Tensor<?> build(RandomAccessibleInterval<T> ndarray)
public static <T extends Type<T>> Tensor<?> build(RandomAccessibleInterval<T> rai)
{
if (ndarray.dataType() == DataType.INT8 || ndarray.dataType() == DataType.UINT8) {
return buildByte(ndarray);
} else if (ndarray.dataType() == DataType.INT32) {
return buildInt(ndarray);
} else if (ndarray.dataType() == DataType.FLOAT) {
return buildFloat(ndarray);
} else if (ndarray.dataType() == DataType.DOUBLE) {
return buildDouble(ndarray);
if (Util.getTypeFromInterval(rai) instanceof ByteType) {
return buildByte((RandomAccessibleInterval<ByteType>) rai);
} else if (Util.getTypeFromInterval(rai) instanceof IntType) {
return buildInt((RandomAccessibleInterval<IntType>) rai);
} else if (Util.getTypeFromInterval(rai) instanceof FloatType) {
return buildFloat((RandomAccessibleInterval<FloatType>) rai);
} else if (Util.getTypeFromInterval(rai) instanceof DoubleType) {
return buildDouble((RandomAccessibleInterval<DoubleType>) rai);
} else {
throw new IllegalArgumentException("The image has an unsupported type: " + ndarray.dataType().toString());
throw new IllegalArgumentException("The image has an unsupported type: " + Util.getTypeFromInterval(rai).getClass().toString());
}
}
/**
* Creates a unsigned byte-typed {@link Tensor} based on the provided {@link INDArray} and the desired dimension order for the resulting tensor.
* Creates a unsigned byte-typed {@link Tensor} based on the provided {@link RandomAccessibleInterval} and the desired dimension order for the resulting tensor.
*
* @param ndarray
* The sequence to be converted.
......@@ -68,17 +79,16 @@ public final class TensorBuilder
* @throws IllegalArgumentException
* If the ndarray type is not supported.
*/
private static Tensor<UInt8> buildByte(INDArray ndarray)
private static <T extends Type<T>> Tensor<UInt8> buildByte(RandomAccessibleInterval<ByteType> ndarray)
{
if (ndarray.dataType() != DataType.INT8 && ndarray.dataType() != DataType.UINT8)
throw new IllegalArgumentException("Image is not of byte type: " + ndarray.dataType().toString());
Tensor<UInt8> tensor = Tensor.create(UInt8.class, ndarray.shape(), ndarray.data().asNio());
byte[] arr = RaiToArray.byteArray(ndarray);
ByteBuffer buff = ByteBuffer.wrap(arr);
Tensor<UInt8> tensor = Tensor.create(UInt8.class, ndarray.dimensionsAsLongArray(), buff);
return tensor;
}
/**
* Creates a integer-typed {@link Tensor} based on the provided {@link INDArray} and the desired dimension order for the resulting tensor.
* Creates a integer-typed {@link Tensor} based on the provided {@link RandomAccessibleInterval} and the desired dimension order for the resulting tensor.
*
* @param ndarray
* The sequence to be converted.
......@@ -86,17 +96,16 @@ public final class TensorBuilder
* @throws IllegalArgumentException
* If the ndarray type is not supported.
*/
private static Tensor<Integer> buildInt(INDArray ndarray)
private static <T extends Type<T>> Tensor<Integer> buildInt(RandomAccessibleInterval<IntType> ndarray)
{
if (ndarray.dataType() != DataType.INT32)
throw new IllegalArgumentException("Image is not of int type: " + ndarray.dataType().toString());
Tensor<Integer> tensor = Tensor.create(ndarray.shape(), ndarray.data().asNioInt());
int[] arr = RaiToArray.intArray(ndarray);
IntBuffer buff = IntBuffer.wrap(arr);
Tensor<Integer> tensor = Tensor.create(ndarray.dimensionsAsLongArray(), buff);
return tensor;
}
/**
* Creates a float-typed {@link Tensor} based on the provided {@link INDArray} and the desired dimension order for the resulting tensor.
* Creates a float-typed {@link Tensor} based on the provided {@link RandomAccessibleInterval} and the desired dimension order for the resulting tensor.
*
* @param ndarray
* The sequence to be converted.
......@@ -104,17 +113,16 @@ public final class TensorBuilder
* @throws IllegalArgumentException
* If the ndarray type is not supported.
*/
private static Tensor<Float> buildFloat(INDArray ndarray)
private static <T extends Type<T>> Tensor<Float> buildFloat(RandomAccessibleInterval<FloatType> ndarray)
{
if (ndarray.dataType() != DataType.FLOAT)
throw new IllegalArgumentException("Image is not of float type: " + ndarray.dataType().toString());
Tensor<Float> tensor = Tensor.create(ndarray.shape(), ndarray.data().asNioFloat());
float[] arr = RaiToArray.floatArray(ndarray);
FloatBuffer buff = FloatBuffer.wrap(arr);
Tensor<Float> tensor = Tensor.create(ndarray.dimensionsAsLongArray(), buff);
return tensor;
}
/**
* Creates a double-typed {@link Tensor} based on the provided {@link INDArray} and the desired dimension order for the resulting tensor.
* Creates a double-typed {@link Tensor} based on the provided {@link RandomAccessibleInterval} and the desired dimension order for the resulting tensor.
*
* @param ndarray
* The ndarray to be converted.
......@@ -122,12 +130,11 @@ public final class TensorBuilder
* @throws IllegalArgumentException
* If the ndarray type is not supported.
*/
private static Tensor<Double> buildDouble(INDArray ndarray)
private static <T extends Type<T>> Tensor<Double> buildDouble(RandomAccessibleInterval<DoubleType> ndarray)
{
if (ndarray.dataType() != DataType.DOUBLE)
throw new IllegalArgumentException("Image is not of float type: " + ndarray.dataType().toString());
Tensor<Double> tensor = Tensor.create(ndarray.shape(), ndarray.data().asNioDouble());
double[] arr = RaiToArray.doubleArray(ndarray);
DoubleBuffer buff = DoubleBuffer.wrap(arr);
Tensor<Double> tensor = Tensor.create(ndarray.dimensionsAsLongArray(), buff);
return tensor;
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment