From 97df15c7183c800d9a0b3a24bf66e1526a93c402 Mon Sep 17 00:00:00 2001 From: carlosuc3m <100329787@alumnos.uc3m.es> Date: Tue, 28 Feb 2023 19:19:47 +0100 Subject: [PATCH] keep improving the interprocessing communication for tf1 mac --- .../tensorflow/v1/Tensorflow1Interface.java | 12 ++++++++++-- .../tensorflow/v1/tensor/MappedBufferToImgLib2.java | 11 +++++++++-- 2 files changed, 19 insertions(+), 4 deletions(-) diff --git a/src/main/java/org/bioimageanalysis/icy/deeplearning/tensorflow/v1/Tensorflow1Interface.java b/src/main/java/org/bioimageanalysis/icy/deeplearning/tensorflow/v1/Tensorflow1Interface.java index 4102d0a..4531f54 100644 --- a/src/main/java/org/bioimageanalysis/icy/deeplearning/tensorflow/v1/Tensorflow1Interface.java +++ b/src/main/java/org/bioimageanalysis/icy/deeplearning/tensorflow/v1/Tensorflow1Interface.java @@ -135,7 +135,7 @@ public class Tensorflow1Interface implements DeepLearningEngineInterface } } - public static void main(String[] args) throws LoadModelException, IOException { + public static void main(String[] args) throws LoadModelException, IOException, RunModelException { // Unpack the args needed if (args.length < 4) throw new IllegalArgumentException("Error exectuting Tensorflow 1, " @@ -176,7 +176,15 @@ public class Tensorflow1Interface implements DeepLearningEngineInterface } }).collect(Collectors.toList()); List<String> outputNames = map.get(OUTPUTS_MAP_KEY); - + List<Tensor<?>> outputList = outputNames.stream().map(n -> { + try { + return tfInterface.retrieveInterprocessingTensorsByName(n); + } catch (RunModelException e) { + return null; + } + }).collect(Collectors.toList()); + tfInterface.run(inputList, outputList); + tfInterface.createTensorsForInterprocessing(outputList); } @Override diff --git a/src/main/java/org/bioimageanalysis/icy/deeplearning/tensorflow/v1/tensor/MappedBufferToImgLib2.java b/src/main/java/org/bioimageanalysis/icy/deeplearning/tensorflow/v1/tensor/MappedBufferToImgLib2.java index 855195a..fbeb159 100644 --- a/src/main/java/org/bioimageanalysis/icy/deeplearning/tensorflow/v1/tensor/MappedBufferToImgLib2.java +++ b/src/main/java/org/bioimageanalysis/icy/deeplearning/tensorflow/v1/tensor/MappedBufferToImgLib2.java @@ -46,6 +46,10 @@ public final class MappedBufferToImgLib2 * Key for axes info */ private static final String AXES_KEY = "axes"; + /** + * Key for axes info + */ + private static final String NAME_KEY = "name"; /** * Not used (Utility class). @@ -64,12 +68,13 @@ public final class MappedBufferToImgLib2 * If the tensor type is not supported. */ @SuppressWarnings("unchecked") - public static < T extends RealType< T > & NativeType< T > > Tensor<T> buildTensor(String name, ByteBuffer buff) throws IllegalArgumentException + public static < T extends RealType< T > & NativeType< T > > Tensor<T> buildTensor(ByteBuffer buff) throws IllegalArgumentException { String infoStr = getTensorInfoFromBuffer(buff); HashMap<String, Object> map = getDataTypeAndShape(infoStr); String dtype = (String) map.get(DATA_TYPE_KEY); String axes = (String) map.get(AXES_KEY); + String name = (String) map.get(NAME_KEY); long[] shape = (long[]) map.get(SHAPE_KEY); Img<T> data; @@ -245,7 +250,8 @@ public final class MappedBufferToImgLib2 } String typeStr = m.group(1); String axesStr = m.group(2); - String shapeStr = m.group(3); + String nameStr = m.group(3); + String shapeStr = m.group(4); long[] shape = new long[0]; if (!shapeStr.isEmpty()) { String[] tokens = shapeStr.split(", ?"); @@ -255,6 +261,7 @@ public final class MappedBufferToImgLib2 map.put(DATA_TYPE_KEY, typeStr); map.put(AXES_KEY, axesStr); map.put(SHAPE_KEY, shape); + map.put(NAME_KEY, nameStr); return map; } } -- GitLab