From a2a3c5a5e03b41c7cac275ea42f9bb50471bead0 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Fran=C3=A7ois=20Laurent?= <francois.laurent@posteo.net>
Date: Tue, 3 May 2022 21:17:27 +0200
Subject: [PATCH] contributes to
 https://gitlab.pasteur.fr/nyx/TaggingBackends/-/issues/1

---
 pyproject.toml                        |  4 ++--
 src/maggotuba/models/predict_model.py | 14 ++++++++------
 2 files changed, 10 insertions(+), 8 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index 1507981..8ebdc6f 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -2,7 +2,7 @@
 name = "MaggotUBA-adapter"
 version = "0.1.0"
 description = "Interface between MaggotUBA and the Nyx tagging UI"
-authors = ["François Laurent <francois.laurent@posteo.net>"]
+authors = ["François Laurent"]
 license = "MIT"
 packages = [
 	{ include = "maggotuba", from = "src" },
@@ -10,7 +10,7 @@ packages = [
 
 [tool.poetry.dependencies]
 python = "^3.8,<3.11"
-taggingbackends = {git = "https://gitlab.pasteur.fr/nyx/TaggingBackends", rev = "main"}
+taggingbackends = {git = "https://gitlab.pasteur.fr/nyx/TaggingBackends", rev = "dev"}
 structured-temporal-convolution = {git = "git@gitlab.pasteur.fr:les-larves/structured-temporal-convolution.git", branch="dev-branch"}
 torch = "^1.11.0"
 numpy = "^1.19.3"
diff --git a/src/maggotuba/models/predict_model.py b/src/maggotuba/models/predict_model.py
index 3c636da..a6f0742 100644
--- a/src/maggotuba/models/predict_model.py
+++ b/src/maggotuba/models/predict_model.py
@@ -60,7 +60,7 @@ def predict_model(backend):
         config_file = [file for file in model_files if file.name.endswith("config.json")]
         model = RandomForest(config_file[-1]).load()
         # assign labels
-        labels = Labels()
+        labels = Labels(tracking=input_files)
         if isinstance(data, dict):
             ref_length = np.mean(np.concatenate([
                 model.body_length(spines) for spines in data.values()
@@ -84,12 +84,14 @@ def predict_model(backend):
                     labels[run, larva] = dict(zip(t[mask], predictions))
         # save the predicted labels to file
         if metadata:
-            labels[run]['metadata'] = metadata
+            labels.metadata = metadata
         else:
-            labels[run]['metadata'] = {'filename': file.name}
-        labels.metadata['labels'] = ["run", "bend", "stop", "hunch", "back", "roll"]
-        labels.metadata['label_colors'] = ["#000000", "#ff0000", "#00ff00",
-                "#0000ff", "#00ffff", "#ffff00"]
+            labels.metadata = {'filename': file.name}
+        labels.labelspec = {
+                "names": ["run", "bend", "stop", "hunch", "back", "roll"],
+                "colors": ["#000000", "#ff0000", "#00ff00", "#0000ff",
+                    "#00ffff", "#ffff00"]
+                }
         labels.dump(backend.processed_data_dir() / "predicted.labels")
 
 
-- 
GitLab