Skip to content
Snippets Groups Projects
Commit 481579e6 authored by carlosuc3m's avatar carlosuc3m
Browse files

Merge branch 'develop' of

parent dcc96005
No related branches found
No related tags found
No related merge requests found
package org.bioimageanalysis.icy.tensorflow.v1.tensor;
import java.nio.ByteBuffer;
import java.nio.DoubleBuffer;
import java.nio.FloatBuffer;
import java.nio.IntBuffer;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.tensorflow.Tensor;
import org.tensorflow.types.UInt8;
/**
* A {@link INDArray} builder for TensorFlow {@link Tensor} objects.
*
* @author Carlos Garcia Lopez de Haro
*/
public final class ImgLib2Builder
{
/**
* Not used (Utility class).
*/
private ImgLib2Builder()
{
}
/**
* Creates a {@link INDArray} from a given {@link Tensor} and an array with its dimensions order.
*
* @param tensor
* The tensor data is read from.
* @return The INDArray built from the tensor.
* @throws IllegalArgumentException
* If the tensor type is not supported.
*/
@SuppressWarnings("unchecked")
public static INDArray build(Tensor<?> tensor) throws IllegalArgumentException
{
// Create an INDArray of the same type of the tensor
switch (tensor.dataType())
{
case UINT8:
return buildFromTensorByte((Tensor<UInt8>) tensor);
case INT32:
return buildFromTensorInt((Tensor<Integer>) tensor);
case FLOAT:
return buildFromTensorFloat((Tensor<Float>) tensor);
case DOUBLE:
return buildFromTensorDouble((Tensor<Double>) tensor);
default:
throw new IllegalArgumentException("Unsupported tensor type: " + tensor.dataType());
}
}
/**
* Builds a {@link INDArray} from a unsigned byte-typed {@link Tensor}.
*
* @param tensor
* The tensor data is read from.
* @return The INDArray built from the tensor of type {@link DataType#UBYTE}.
*/
private static INDArray buildFromTensorByte(Tensor<UInt8> tensor)
{
long[] tensorShape = tensor.shape();
int totalSize = 1;
for (long i : tensorShape) {totalSize *= i;}
byte[] flatImageArray = new byte[totalSize];
ByteBuffer outBuff = ByteBuffer.wrap(flatImageArray);
tensor.writeTo(outBuff);
outBuff = null;
return Nd4j.create(flatImageArray, tensorShape, DataType.INT8);
}
/**
* Builds a {@link INDArray} from a unsigned integer-typed {@link Tensor}.
*
* @param tensor
* The tensor data is read from.
* @return The sequence built from the tensor of type {@link DataType#INT}.
*/
private static INDArray buildFromTensorInt(Tensor<Integer> tensor)
{
long[] tensorShape = tensor.shape();
int totalSize = 1;
for (long i : tensorShape) {totalSize *= i;}
int[] flatImageArray = new int[totalSize];
IntBuffer outBuff = IntBuffer.wrap(flatImageArray);
tensor.writeTo(outBuff);
outBuff = null;
return Nd4j.create(flatImageArray, tensorShape, DataType.INT32);
}
/**
* Builds a {@link INDArray} from a unsigned float-typed {@link Tensor}.
*
* @param tensor
* The tensor data is read from.
* @return The INDArray built from the tensor of type {@link DataType#FLOAT}.
*/
private static INDArray buildFromTensorFloat(Tensor<Float> tensor)
{
long[] tensorShape = tensor.shape();
int totalSize = 1;
for (long i : tensorShape) {totalSize *= i;}
float[] flatImageArray = new float[totalSize];
FloatBuffer outBuff = FloatBuffer.wrap(flatImageArray);
tensor.writeTo(outBuff);
outBuff = null;
return Nd4j.create(flatImageArray, tensorShape, DataType.FLOAT);
}
/**
* Builds a {@link INDArray} from a unsigned double-typed {@link Tensor}.
*
* @param tensor
* The tensor data is read from.
* @return The INDArray built from the tensor of type {@link DataType#DOUBLE}.
*/
private static INDArray buildFromTensorDouble(Tensor<Double> tensor)
{
long[] tensorShape = tensor.shape();
int totalSize = 1;
for (long i : tensorShape) {totalSize *= i;}
double[] flatImageArray = new double[totalSize];
DoubleBuffer outBuff = DoubleBuffer.wrap(flatImageArray);
tensor.writeTo(outBuff);
outBuff = null;
return Nd4j.create(flatImageArray, tensorShape, DataType.DOUBLE);
}
}
package org.bioimageanalysis.icy.tensorflow.v1.tensor; package org.bioimageanalysis.icy.tensorflow.v1.tensor;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.tensorflow.Tensor; import org.tensorflow.Tensor;
import org.tensorflow.types.UInt8; import org.tensorflow.types.UInt8;
import net.imglib2.RandomAccessibleInterval;
import net.imglib2.type.Type;
/** /**
* A TensorFlow {@link Tensor} builder for {@link INDArray} and {@link org.bioimageanalysis.icy.deeplearning.tensor.Tensor} objects. * A TensorFlow {@link Tensor} builder for {@link INDArray} and {@link org.bioimageanalysis.icy.deeplearning.tensor.Tensor} objects.
* *
...@@ -35,7 +36,7 @@ public final class TensorBuilder ...@@ -35,7 +36,7 @@ public final class TensorBuilder
} }
/** /**
* Creates a {@link Tensor} based on the provided {@link INDArray} and the desired dimension order for the resulting tensor. * Creates a {@link Tensor} based on the provided {@link RandomAccessibleInterval} and the desired dimension order for the resulting tensor.
* *
* @param ndarray * @param ndarray
* The NDArray to be converted. * The NDArray to be converted.
...@@ -43,7 +44,7 @@ public final class TensorBuilder ...@@ -43,7 +44,7 @@ public final class TensorBuilder
* @throws IllegalArgumentException * @throws IllegalArgumentException
* If the ndarray type is not supported. * If the ndarray type is not supported.
*/ */
public static Tensor<?> build(INDArray ndarray) public static <T extends Type<T>> Tensor<?> build(RandomAccessibleInterval<T> ndarray)
{ {
if (ndarray.dataType() == DataType.INT8 || ndarray.dataType() == DataType.UINT8) { if (ndarray.dataType() == DataType.INT8 || ndarray.dataType() == DataType.UINT8) {
return buildByte(ndarray); return buildByte(ndarray);
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment