Skip to content
Snippets Groups Projects
Commit 5ad8dff6 authored by carlosuc3m's avatar carlosuc3m
Browse files

support sending empty tensors

parent 97df15c7
No related branches found
No related tags found
No related merge requests found
......@@ -76,6 +76,8 @@ public final class MappedBufferToImgLib2
String axes = (String) map.get(AXES_KEY);
String name = (String) map.get(NAME_KEY);
long[] shape = (long[]) map.get(SHAPE_KEY);
if (shape.length == 0)
return Tensor.buildEmptyTensor(name, axes);
Img<T> data;
switch (dtype)
......@@ -114,6 +116,8 @@ public final class MappedBufferToImgLib2
HashMap<String, Object> map = getDataTypeAndShape(infoStr);
String dtype = (String) map.get(DATA_TYPE_KEY);
long[] shape = (long[]) map.get(SHAPE_KEY);
if (shape.length == 0)
return null;
// Create an INDArray of the same type of the tensor
switch (dtype)
......@@ -253,7 +257,7 @@ public final class MappedBufferToImgLib2
String nameStr = m.group(3);
String shapeStr = m.group(4);
long[] shape = new long[0];
if (!shapeStr.isEmpty()) {
if (!shapeStr.isEmpty() && !shapeStr.equals("[]")) {
String[] tokens = shapeStr.split(", ?");
shape = Arrays.stream(tokens).mapToLong(Long::parseLong).toArray();
}
......
package org.bioimageanalysis.icy.deeplearning.tensorflow.v1.tensor;
import java.nio.ByteBuffer;
import java.nio.DoubleBuffer;
import java.nio.FloatBuffer;
import java.nio.IntBuffer;
import java.util.Arrays;
import org.bioimageanalysis.icy.deeplearning.utils.IndexingUtils;
import org.tensorflow.Tensor;
import org.tensorflow.types.UInt8;
import net.imglib2.Cursor;
import net.imglib2.RandomAccessibleInterval;
......@@ -188,9 +183,11 @@ public final class MappedFileBuilder
*/
public static < T extends RealType< T > & NativeType< T > > byte[]
createFileHeader(org.bioimageanalysis.icy.deeplearning.tensor.Tensor<T> tensor) {
String dimsStr =
tensor.getData() != null ? Arrays.toString(tensor.getData().dimensionsAsLongArray()) : "[]";
String descriptionStr = "{'dtype':'" + tensor.getDataType() + "','axes':'"
+ tensor.getAxesOrderString() + "','name':'" + tensor.getName() + "','shape':'"
+ Arrays.toString(tensor.getData().dimensionsAsLongArray()) + "'}";
+ dimsStr + "'}";
byte[] descriptionBytes = descriptionStr.getBytes();
int lenDescriptionBytes = descriptionBytes.length;
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment