Skip to content
Snippets Groups Projects
Commit 18deccbf authored by carlosuc3m's avatar carlosuc3m
Browse files

adapt tf intereface to imglib2

parent 481579e6
No related branches found
No related tags found
No related merge requests found
package org.bioimageanalysis.icy.tensorflow.v1.tensor; package org.bioimageanalysis.icy.tensorflow.v1.tensor;
import java.nio.ByteBuffer;
import java.nio.DoubleBuffer;
import java.nio.FloatBuffer;
import java.nio.IntBuffer;
import org.bioimageanalysis.icy.deeplearning.tensor.RaiToArray;
import org.tensorflow.Tensor; import org.tensorflow.Tensor;
import org.tensorflow.types.UInt8; import org.tensorflow.types.UInt8;
import net.imglib2.RandomAccessibleInterval; import net.imglib2.RandomAccessibleInterval;
import net.imglib2.type.Type; import net.imglib2.type.Type;
import net.imglib2.type.numeric.integer.ByteType;
import net.imglib2.type.numeric.integer.IntType;
import net.imglib2.type.numeric.real.DoubleType;
import net.imglib2.type.numeric.real.FloatType;
import net.imglib2.util.Util;
/** /**
* A TensorFlow {@link Tensor} builder for {@link INDArray} and {@link org.bioimageanalysis.icy.deeplearning.tensor.Tensor} objects. * A TensorFlow {@link Tensor} builder for {@link INDArray} and {@link org.bioimageanalysis.icy.deeplearning.tensor.Tensor} objects.
...@@ -38,29 +49,29 @@ public final class TensorBuilder ...@@ -38,29 +49,29 @@ public final class TensorBuilder
/** /**
* Creates a {@link Tensor} based on the provided {@link RandomAccessibleInterval} and the desired dimension order for the resulting tensor. * Creates a {@link Tensor} based on the provided {@link RandomAccessibleInterval} and the desired dimension order for the resulting tensor.
* *
* @param ndarray * @param rai
* The NDArray to be converted. * The NDArray to be converted.
* @return The tensor created from the sequence. * @return The tensor created from the sequence.
* @throws IllegalArgumentException * @throws IllegalArgumentException
* If the ndarray type is not supported. * If the ndarray type is not supported.
*/ */
public static <T extends Type<T>> Tensor<?> build(RandomAccessibleInterval<T> ndarray) public static <T extends Type<T>> Tensor<?> build(RandomAccessibleInterval<T> rai)
{ {
if (ndarray.dataType() == DataType.INT8 || ndarray.dataType() == DataType.UINT8) { if (Util.getTypeFromInterval(rai) instanceof ByteType) {
return buildByte(ndarray); return buildByte((RandomAccessibleInterval<ByteType>) rai);
} else if (ndarray.dataType() == DataType.INT32) { } else if (Util.getTypeFromInterval(rai) instanceof IntType) {
return buildInt(ndarray); return buildInt((RandomAccessibleInterval<IntType>) rai);
} else if (ndarray.dataType() == DataType.FLOAT) { } else if (Util.getTypeFromInterval(rai) instanceof FloatType) {
return buildFloat(ndarray); return buildFloat((RandomAccessibleInterval<FloatType>) rai);
} else if (ndarray.dataType() == DataType.DOUBLE) { } else if (Util.getTypeFromInterval(rai) instanceof DoubleType) {
return buildDouble(ndarray); return buildDouble((RandomAccessibleInterval<DoubleType>) rai);
} else { } else {
throw new IllegalArgumentException("The image has an unsupported type: " + ndarray.dataType().toString()); throw new IllegalArgumentException("The image has an unsupported type: " + Util.getTypeFromInterval(rai).getClass().toString());
} }
} }
/** /**
* Creates a unsigned byte-typed {@link Tensor} based on the provided {@link INDArray} and the desired dimension order for the resulting tensor. * Creates a unsigned byte-typed {@link Tensor} based on the provided {@link RandomAccessibleInterval} and the desired dimension order for the resulting tensor.
* *
* @param ndarray * @param ndarray
* The sequence to be converted. * The sequence to be converted.
...@@ -68,17 +79,16 @@ public final class TensorBuilder ...@@ -68,17 +79,16 @@ public final class TensorBuilder
* @throws IllegalArgumentException * @throws IllegalArgumentException
* If the ndarray type is not supported. * If the ndarray type is not supported.
*/ */
private static Tensor<UInt8> buildByte(INDArray ndarray) private static <T extends Type<T>> Tensor<UInt8> buildByte(RandomAccessibleInterval<ByteType> ndarray)
{ {
if (ndarray.dataType() != DataType.INT8 && ndarray.dataType() != DataType.UINT8) byte[] arr = RaiToArray.byteArray(ndarray);
throw new IllegalArgumentException("Image is not of byte type: " + ndarray.dataType().toString()); ByteBuffer buff = ByteBuffer.wrap(arr);
Tensor<UInt8> tensor = Tensor.create(UInt8.class, ndarray.dimensionsAsLongArray(), buff);
Tensor<UInt8> tensor = Tensor.create(UInt8.class, ndarray.shape(), ndarray.data().asNio());
return tensor; return tensor;
} }
/** /**
* Creates a integer-typed {@link Tensor} based on the provided {@link INDArray} and the desired dimension order for the resulting tensor. * Creates a integer-typed {@link Tensor} based on the provided {@link RandomAccessibleInterval} and the desired dimension order for the resulting tensor.
* *
* @param ndarray * @param ndarray
* The sequence to be converted. * The sequence to be converted.
...@@ -86,17 +96,16 @@ public final class TensorBuilder ...@@ -86,17 +96,16 @@ public final class TensorBuilder
* @throws IllegalArgumentException * @throws IllegalArgumentException
* If the ndarray type is not supported. * If the ndarray type is not supported.
*/ */
private static Tensor<Integer> buildInt(INDArray ndarray) private static <T extends Type<T>> Tensor<Integer> buildInt(RandomAccessibleInterval<IntType> ndarray)
{ {
if (ndarray.dataType() != DataType.INT32) int[] arr = RaiToArray.intArray(ndarray);
throw new IllegalArgumentException("Image is not of int type: " + ndarray.dataType().toString()); IntBuffer buff = IntBuffer.wrap(arr);
Tensor<Integer> tensor = Tensor.create(ndarray.dimensionsAsLongArray(), buff);
Tensor<Integer> tensor = Tensor.create(ndarray.shape(), ndarray.data().asNioInt());
return tensor; return tensor;
} }
/** /**
* Creates a float-typed {@link Tensor} based on the provided {@link INDArray} and the desired dimension order for the resulting tensor. * Creates a float-typed {@link Tensor} based on the provided {@link RandomAccessibleInterval} and the desired dimension order for the resulting tensor.
* *
* @param ndarray * @param ndarray
* The sequence to be converted. * The sequence to be converted.
...@@ -104,17 +113,16 @@ public final class TensorBuilder ...@@ -104,17 +113,16 @@ public final class TensorBuilder
* @throws IllegalArgumentException * @throws IllegalArgumentException
* If the ndarray type is not supported. * If the ndarray type is not supported.
*/ */
private static Tensor<Float> buildFloat(INDArray ndarray) private static <T extends Type<T>> Tensor<Float> buildFloat(RandomAccessibleInterval<FloatType> ndarray)
{ {
if (ndarray.dataType() != DataType.FLOAT) float[] arr = RaiToArray.floatArray(ndarray);
throw new IllegalArgumentException("Image is not of float type: " + ndarray.dataType().toString()); FloatBuffer buff = FloatBuffer.wrap(arr);
Tensor<Float> tensor = Tensor.create(ndarray.dimensionsAsLongArray(), buff);
Tensor<Float> tensor = Tensor.create(ndarray.shape(), ndarray.data().asNioFloat());
return tensor; return tensor;
} }
/** /**
* Creates a double-typed {@link Tensor} based on the provided {@link INDArray} and the desired dimension order for the resulting tensor. * Creates a double-typed {@link Tensor} based on the provided {@link RandomAccessibleInterval} and the desired dimension order for the resulting tensor.
* *
* @param ndarray * @param ndarray
* The ndarray to be converted. * The ndarray to be converted.
...@@ -122,12 +130,11 @@ public final class TensorBuilder ...@@ -122,12 +130,11 @@ public final class TensorBuilder
* @throws IllegalArgumentException * @throws IllegalArgumentException
* If the ndarray type is not supported. * If the ndarray type is not supported.
*/ */
private static Tensor<Double> buildDouble(INDArray ndarray) private static <T extends Type<T>> Tensor<Double> buildDouble(RandomAccessibleInterval<DoubleType> ndarray)
{ {
if (ndarray.dataType() != DataType.DOUBLE) double[] arr = RaiToArray.doubleArray(ndarray);
throw new IllegalArgumentException("Image is not of float type: " + ndarray.dataType().toString()); DoubleBuffer buff = DoubleBuffer.wrap(arr);
Tensor<Double> tensor = Tensor.create(ndarray.dimensionsAsLongArray(), buff);
Tensor<Double> tensor = Tensor.create(ndarray.shape(), ndarray.data().asNioDouble());
return tensor; return tensor;
} }
} }
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment