diff --git a/src/main/java/org/bioimageanalysis/icy/deeplearning/tensorflow/v1/tensor/MappedFileBuilder.java b/src/main/java/org/bioimageanalysis/icy/deeplearning/tensorflow/v1/tensor/MappedFileBuilder.java index 6819629d305865721528276806fa8df48dbe94b2..7999f48fcc40ac36d9bebd9a55ae16d66b9d3862 100644 --- a/src/main/java/org/bioimageanalysis/icy/deeplearning/tensorflow/v1/tensor/MappedFileBuilder.java +++ b/src/main/java/org/bioimageanalysis/icy/deeplearning/tensorflow/v1/tensor/MappedFileBuilder.java @@ -4,6 +4,7 @@ import java.nio.ByteBuffer; import java.nio.charset.StandardCharsets; import java.util.Arrays; +import org.jcp.xml.dsig.internal.dom.Utils; import org.tensorflow.Tensor; import net.imglib2.Cursor; @@ -14,6 +15,8 @@ import net.imglib2.type.Type; import net.imglib2.type.numeric.RealType; import net.imglib2.type.numeric.integer.ByteType; import net.imglib2.type.numeric.integer.IntType; +import net.imglib2.type.numeric.integer.LongType; +import net.imglib2.type.numeric.integer.UnsignedByteType; import net.imglib2.type.numeric.real.DoubleType; import net.imglib2.type.numeric.real.FloatType; import net.imglib2.util.Util; @@ -182,13 +185,14 @@ public final class MappedFileBuilder * @param tensor * @return */ - public static < T extends RealType< T > & NativeType< T > > byte[] + public static < T extends RealType< T > & NativeType< T > > byte[] createFileHeader(org.bioimageanalysis.icy.deeplearning.tensor.Tensor<T> tensor) { String dimsStr = tensor.getData() != null ? Arrays.toString(tensor.getData().dimensionsAsLongArray()) : "[]"; - String descriptionStr = "{'dtype':'" + tensor.getDataType() + "','axes':'" - + tensor.getAxesOrderString() + "','name':'" + tensor.getName() + "','shape':'" - + dimsStr + "'}"; + String descriptionStr = "{'dtype':'" + + getDataTypeString(Util.getTypeFromInterval(tensor.getData())) + "','axes':'" + + tensor.getAxesOrderString() + "','name':'" + tensor.getName() + "','shape':'" + + dimsStr + "'}"; byte[] descriptionBytes = descriptionStr.getBytes(StandardCharsets.UTF_8); int lenDescriptionBytes = descriptionBytes.length; @@ -204,6 +208,32 @@ public final class MappedFileBuilder return byteHeader; } + + /** + * Method that returns a Sting representing the datatype of T + * @param <T> + * @param type + * @return + */ + public static< T extends RealType< T > & NativeType< T > > String getDataTypeString(T type) { + if (type instanceof ByteType) { + return "byte"; + } else if (type instanceof IntType) { + return "int32"; + } else if (type instanceof FloatType) { + return "float32"; + } else if (type instanceof DoubleType) { + return "float64"; + } else if (type instanceof LongType) { + return "int64"; + } else if (type instanceof UnsignedByteType) { + return "ubyte"; + } else { + throw new IllegalArgumentException("Unsupported data type. At the moment the only " + + "supported dtypes are: " + IntType.class + ", " + FloatType.class + ", " + + DoubleType.class + ", " + LongType.class + " and " + UnsignedByteType.class); + } + } /** * Get the total byte size of the temp file that is oging to be created to do interprocess