From 06aa49ad6ff6fde0546ce30d06a5067d72f48d5a Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Fran=C3=A7ois=20Laurent?= <francois.laurent@posteo.net>
Date: Sun, 15 Jan 2023 21:54:34 +0100
Subject: [PATCH] Labels.encode and Labels.decode reworked; version increment

---
 Project.toml                       |  2 +-
 pyproject.toml                     |  2 +-
 src/taggingbackends/data/labels.py | 57 +++++++++++++++++++++---------
 3 files changed, 42 insertions(+), 19 deletions(-)

diff --git a/Project.toml b/Project.toml
index 4fefa10..1aebd82 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,7 +1,7 @@
 name = "TaggingBackends"
 uuid = "e551f703-3b82-4335-b341-d497b48d519b"
 authors = ["François Laurent", "Institut Pasteur"]
-version = "0.7.2"
+version = "0.8"
 
 [deps]
 Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
diff --git a/pyproject.toml b/pyproject.toml
index 84e30fe..cb7e012 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "TaggingBackends"
-version = "0.7.2"
+version = "0.8"
 description = "Backbone for LarvaTagger.jl tagging backends"
 authors = ["François Laurent"]
 
diff --git a/src/taggingbackends/data/labels.py b/src/taggingbackends/data/labels.py
index 3da447f..f9402f5 100644
--- a/src/taggingbackends/data/labels.py
+++ b/src/taggingbackends/data/labels.py
@@ -309,34 +309,57 @@ class Labels:
                     for timestamp, label in zip(track["t"], track["labels"])}
         return self
 
-    def encode(self, labels):
-        if isinstance(self.labelspec, dict):
-            labelset = self.labelspec["names"]
+    def encode(self, label=None):
+        """
+        Encode the text labels as indices (`int` or `list` of `int`).
+
+        Indices are 1-based: index `i` refers to entry `i - 1` of the label
+        names in attribute `labelspec`.
+        """
+        if label is None:
+            encoded = label = self
+            for run_larva in label:
+                label[run_larva] = self.encode(label[run_larva])
+        elif isinstance(label, dict):
+            encoded = {t: self.encode(l) for t, l in label.items()}
         else:
-            labelset = self.labelspec
-        encoded = []
-        for label in labels:
+            if isinstance(self.labelspec, dict):
+                labelset = self.labelspec['names']
+            else:
+                labelset = self.labelspec
             if isinstance(label, str):
-                encoded.append(labelset.index(label)+1)
+                encoded = labelset.index(label) + 1
+            elif isinstance(label, int):
+                encoded = label
+                logging.debug('label(s) already encoded')
             else:
-                encoded.append([labelset.index(label)+1 for label in label])
+                encoded = [labelset.index(l) + 1 for l in label]
         return encoded
 
+    """
+    Decode the label indices as text (`str` or `list` of `str`).
+
+    Text labels are picked in `labelspec`.
+    """
     def decode(self, label=None):
-        if isinstance(self.labelspec, dict):
-            labelset = self.labelspec["names"]
-        else:
-            labelset = self.labelspec
         if label is None:
-            label = decoded = self
+            decoded = label = self
             for run_larva in label:
                 label[run_larva] = self.decode(label[run_larva])
         elif isinstance(label, dict):
-            decoded = {t: labelset[l-1] for t, l in label.items()}
-        elif isinstance(label, int):
-            decoded = labelset[label-1]
+            decoded = {t: self.decode(l) for t, l in label.items()}
         else:
-            decoded = [labelset[l-1] for l in label]
+            if isinstance(self.labelspec, dict):
+                labelset = self.labelspec['names']
+            else:
+                labelset = self.labelspec
+            if isinstance(label, int):
+                decoded = labelset[label-1]
+            elif isinstance(label, str):
+                decoded = label
+                logging.debug('label(s) already decoded')
+            else:
+                decoded = [labelset[l-1] for l in label]
         return decoded
 
 class LabelEncoder(json.JSONEncoder):
-- 
GitLab