diff --git a/src/main/java/org/bioimageanalysis/icy/deeplearning/tensorflow/v1/tensor/MappedFileBuilder.java b/src/main/java/org/bioimageanalysis/icy/deeplearning/tensorflow/v1/tensor/MappedFileBuilder.java
index 6819629d305865721528276806fa8df48dbe94b2..7999f48fcc40ac36d9bebd9a55ae16d66b9d3862 100644
--- a/src/main/java/org/bioimageanalysis/icy/deeplearning/tensorflow/v1/tensor/MappedFileBuilder.java
+++ b/src/main/java/org/bioimageanalysis/icy/deeplearning/tensorflow/v1/tensor/MappedFileBuilder.java
@@ -4,6 +4,7 @@ import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
+import org.jcp.xml.dsig.internal.dom.Utils;
import org.tensorflow.Tensor;
import net.imglib2.Cursor;
@@ -14,6 +15,8 @@ import net.imglib2.type.Type;
import net.imglib2.type.numeric.RealType;
import net.imglib2.type.numeric.integer.ByteType;
import net.imglib2.type.numeric.integer.IntType;
+import net.imglib2.type.numeric.integer.LongType;
+import net.imglib2.type.numeric.integer.UnsignedByteType;
import net.imglib2.type.numeric.real.DoubleType;
import net.imglib2.type.numeric.real.FloatType;
import net.imglib2.util.Util;
@@ -182,13 +185,14 @@ public final class MappedFileBuilder
* @param tensor
* @return
*/
- public static < T extends RealType< T > & NativeType< T > > byte[]
+ public static < T extends RealType< T > & NativeType< T > > byte[]
createFileHeader(org.bioimageanalysis.icy.deeplearning.tensor.Tensor<T> tensor) {
String dimsStr =
tensor.getData() != null ? Arrays.toString(tensor.getData().dimensionsAsLongArray()) : "[]";
- String descriptionStr = "{'dtype':'" + tensor.getDataType() + "','axes':'"
- + tensor.getAxesOrderString() + "','name':'" + tensor.getName() + "','shape':'"
- + dimsStr + "'}";
+ String descriptionStr = "{'dtype':'"
+ + getDataTypeString(Util.getTypeFromInterval(tensor.getData())) + "','axes':'"
+ + tensor.getAxesOrderString() + "','name':'" + tensor.getName() + "','shape':'"
+ + dimsStr + "'}";
byte[] descriptionBytes = descriptionStr.getBytes(StandardCharsets.UTF_8);
int lenDescriptionBytes = descriptionBytes.length;
@@ -204,6 +208,32 @@ public final class MappedFileBuilder
return byteHeader;
}
+
+ /**
+ * Method that returns a Sting representing the datatype of T
+ * @param <T>
+ * @param type
+ * @return
+ */
+ public static< T extends RealType< T > & NativeType< T > > String getDataTypeString(T type) {
+ if (type instanceof ByteType) {
+ return "byte";
+ } else if (type instanceof IntType) {
+ return "int32";
+ } else if (type instanceof FloatType) {
+ return "float32";
+ } else if (type instanceof DoubleType) {
+ return "float64";
+ } else if (type instanceof LongType) {
+ return "int64";
+ } else if (type instanceof UnsignedByteType) {
+ return "ubyte";
+ } else {
+ throw new IllegalArgumentException("Unsupported data type. At the moment the only "
+ + "supported dtypes are: " + IntType.class + ", " + FloatType.class + ", "
+ + DoubleType.class + ", " + LongType.class + " and " + UnsignedByteType.class);
+ }
+ }
/**
* Get the total byte size of the temp file that is oging to be created to do interprocess