From 97df15c7183c800d9a0b3a24bf66e1526a93c402 Mon Sep 17 00:00:00 2001
From: carlosuc3m <100329787@alumnos.uc3m.es>
Date: Tue, 28 Feb 2023 19:19:47 +0100
Subject: [PATCH] keep improving the interprocessing communication for tf1 mac
---
.../tensorflow/v1/Tensorflow1Interface.java | 12 ++++++++++--
.../tensorflow/v1/tensor/MappedBufferToImgLib2.java | 11 +++++++++--
2 files changed, 19 insertions(+), 4 deletions(-)
diff --git a/src/main/java/org/bioimageanalysis/icy/deeplearning/tensorflow/v1/Tensorflow1Interface.java b/src/main/java/org/bioimageanalysis/icy/deeplearning/tensorflow/v1/Tensorflow1Interface.java
index 4102d0a..4531f54 100644
--- a/src/main/java/org/bioimageanalysis/icy/deeplearning/tensorflow/v1/Tensorflow1Interface.java
+++ b/src/main/java/org/bioimageanalysis/icy/deeplearning/tensorflow/v1/Tensorflow1Interface.java
@@ -135,7 +135,7 @@ public class Tensorflow1Interface implements DeepLearningEngineInterface
}
}
- public static void main(String[] args) throws LoadModelException, IOException {
+ public static void main(String[] args) throws LoadModelException, IOException, RunModelException {
// Unpack the args needed
if (args.length < 4)
throw new IllegalArgumentException("Error exectuting Tensorflow 1, "
@@ -176,7 +176,15 @@ public class Tensorflow1Interface implements DeepLearningEngineInterface
}
}).collect(Collectors.toList());
List<String> outputNames = map.get(OUTPUTS_MAP_KEY);
-
+ List<Tensor<?>> outputList = outputNames.stream().map(n -> {
+ try {
+ return tfInterface.retrieveInterprocessingTensorsByName(n);
+ } catch (RunModelException e) {
+ return null;
+ }
+ }).collect(Collectors.toList());
+ tfInterface.run(inputList, outputList);
+ tfInterface.createTensorsForInterprocessing(outputList);
}
@Override
diff --git a/src/main/java/org/bioimageanalysis/icy/deeplearning/tensorflow/v1/tensor/MappedBufferToImgLib2.java b/src/main/java/org/bioimageanalysis/icy/deeplearning/tensorflow/v1/tensor/MappedBufferToImgLib2.java
index 855195a..fbeb159 100644
--- a/src/main/java/org/bioimageanalysis/icy/deeplearning/tensorflow/v1/tensor/MappedBufferToImgLib2.java
+++ b/src/main/java/org/bioimageanalysis/icy/deeplearning/tensorflow/v1/tensor/MappedBufferToImgLib2.java
@@ -46,6 +46,10 @@ public final class MappedBufferToImgLib2
* Key for axes info
*/
private static final String AXES_KEY = "axes";
+ /**
+ * Key for axes info
+ */
+ private static final String NAME_KEY = "name";
/**
* Not used (Utility class).
@@ -64,12 +68,13 @@ public final class MappedBufferToImgLib2
* If the tensor type is not supported.
*/
@SuppressWarnings("unchecked")
- public static < T extends RealType< T > & NativeType< T > > Tensor<T> buildTensor(String name, ByteBuffer buff) throws IllegalArgumentException
+ public static < T extends RealType< T > & NativeType< T > > Tensor<T> buildTensor(ByteBuffer buff) throws IllegalArgumentException
{
String infoStr = getTensorInfoFromBuffer(buff);
HashMap<String, Object> map = getDataTypeAndShape(infoStr);
String dtype = (String) map.get(DATA_TYPE_KEY);
String axes = (String) map.get(AXES_KEY);
+ String name = (String) map.get(NAME_KEY);
long[] shape = (long[]) map.get(SHAPE_KEY);
Img<T> data;
@@ -245,7 +250,8 @@ public final class MappedBufferToImgLib2
}
String typeStr = m.group(1);
String axesStr = m.group(2);
- String shapeStr = m.group(3);
+ String nameStr = m.group(3);
+ String shapeStr = m.group(4);
long[] shape = new long[0];
if (!shapeStr.isEmpty()) {
String[] tokens = shapeStr.split(", ?");
@@ -255,6 +261,7 @@ public final class MappedBufferToImgLib2
map.put(DATA_TYPE_KEY, typeStr);
map.put(AXES_KEY, axesStr);
map.put(SHAPE_KEY, shape);
+ map.put(NAME_KEY, nameStr);
return map;
}
}
--
GitLab