From fab127a029f1707ba01c4e72a8a6d5653183d184 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Fran=C3=A7ois=20Laurent?= <francois.laurent@posteo.net>
Date: Fri, 27 May 2022 12:04:06 +0200
Subject: [PATCH] metadata handling moved to TaggingBackends (see commit
 https://gitlab.pasteur.fr/nyx/TaggingBackends/-/commit/8e3e2fd5d5f24bf6523a9e969cbfe94f34b5457b)

---
 src/maggotuba/models/denselayer.py    |  2 +-
 src/maggotuba/models/predict_model.py | 27 +++++++++++----------------
 2 files changed, 12 insertions(+), 17 deletions(-)

diff --git a/src/maggotuba/models/denselayer.py b/src/maggotuba/models/denselayer.py
index 82a6f7a..7196bcd 100644
--- a/src/maggotuba/models/denselayer.py
+++ b/src/maggotuba/models/denselayer.py
@@ -169,7 +169,7 @@ class DenseLayer:
         if train:
             return y
         else:
-            return y.numpy()
+            return y.cpu().numpy()
 
     def train(self, all_spines=None, tags=None):
         if all_spines is None or tags is None:
diff --git a/src/maggotuba/models/predict_model.py b/src/maggotuba/models/predict_model.py
index 681cf4b..4299dd2 100644
--- a/src/maggotuba/models/predict_model.py
+++ b/src/maggotuba/models/predict_model.py
@@ -25,13 +25,10 @@ def predict_model(backend):
     # we could go and pick files in `data/interim` as well:
     input_files += backend.list_interim_files()
     assert 0 < len(input_files)
-    metadata = None
-    metadata_file = [file for file in input_files if file.name == "metadata"]
-    if metadata_file:
-        metadata_file = metadata_file[0]
-        input_files.remove(metadata_file)
-        with open(metadata_file, "r") as f:
-            metadata = json.load(f)
+    # initialize output labels
+    input_files, labels = backend.prepare_labels(input_files)
+    assert 0 < len(input_files)
+    #
     for file in input_files:
         # load the input data (or features)
         if file.name.endswith(".spine"):
@@ -40,7 +37,7 @@ def predict_model(backend):
             larvae = spine["larva_id"].values
             t = spine["time"].values
             data = spine.iloc[:,3:].values
-        elif file.name == "trx.mat":
+        elif file.name.endswith(".mat"):
             trx = TrxMat(file)
             t = trx["t"]
             data = trx["spine"]
@@ -49,11 +46,14 @@ def predict_model(backend):
                 run, data = next(iter(data.items()))
             t = t[run]
         elif file.name.endswith(".csv"):
-            print("assuming 30 fps")
-            t, data = fimtrack.read_spines(file, fps=30)
+            if labels.camera_framerate:
+                print(f"camera frame rate: {labels.camera_framerate}fps")
+            else:
+                print("assuming 30-fps camera frame rate")
+                labels.camera_framerate = 30
+            t, data = fimtrack.read_spines(file, fps=labels.camera_framerate)
             run = "NA"
         else:
-            # TODO: support more file formats
             continue
         # downsample the skeleton
         if isinstance(data, dict):
@@ -68,7 +68,6 @@ def predict_model(backend):
             config_file = [file for file in config_file if file.name.endswith("clf_config.json")]
         model = Clf(config_file[-1])
         # assign labels
-        labels = Labels(tracking=input_files)
         if isinstance(data, dict):
             ref_length = np.mean(np.concatenate([
                 model.body_length(spines) for spines in data.values()
@@ -93,10 +92,6 @@ def predict_model(backend):
                 else:
                     labels[run, larva] = dict(zip(t[mask], predictions))
         # save the predicted labels to file
-        if metadata:
-            labels.metadata = metadata
-        else:
-            labels.metadata = {'filename': file.name}
         labels.labelspec = {
                 "names": ["run", "bend", "stop", "hunch", "back", "roll"],
                 "colors": ["#000000", "#ff0000", "#00ff00", "#0000ff",
-- 
GitLab