From f4ebd759a7ec82c3dcc0f8c1c7095c59b2d4e081 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Laurent?= <francois.laurent@posteo.net> Date: Tue, 14 Mar 2023 12:41:31 +0100 Subject: [PATCH] implements https://gitlab.pasteur.fr/nyx/larvatagger.jl/-/issues/103 --- src/taggingbackends/data/labels.py | 27 +++++++++++++++++++++++++-- 1 file changed, 25 insertions(+), 2 deletions(-) diff --git a/src/taggingbackends/data/labels.py b/src/taggingbackends/data/labels.py index fc9bf9b..c9301e2 100644 --- a/src/taggingbackends/data/labels.py +++ b/src/taggingbackends/data/labels.py @@ -84,6 +84,7 @@ class Labels: self.secondarylabelspec = None self._tracking = tracking self._input_labels = None + self.decodingspec = None # self.filepath = None # unused attribute; may help to manage data # dependencies in the future @@ -360,7 +361,18 @@ class Labels: """ Decode the label indices as text (`str` or `list` of `str`). - Text labels are picked in `labelspec`. + Text labels are picked in `labelspec`, or `decodingspec` instead if defined. + + `decodingspec` is set to decode the encoded output from a tagger that + defines redundant labels, which are remapped onto labels in `labelspec`. + Such taggers may include several sections in their config file, including + `original_behavior_labels`, `behavior_labels` and `remapped_behavior_labels`. + + Note that `decodingspec` should not be set to decode a label file. A label + file should use and mention the remapped labels only, if the latter are + defined by the tagger used to generate the label file. When decoding a + label file, only `labelspec` should be defined in the `Labels` + object.
""" def decode(self, label=None): if label is None: @@ -370,7 +382,10 @@ class Labels: elif isinstance(label, dict): decoded = {t: self.decode(l) for t, l in label.items()} else: - labelset = self.full_label_list + if self.decodingspec is None: + labelset = self.full_label_list + else: + labelset = self.decodingspec if isinstance(label, int): decoded = labelset[label-1] elif isinstance(label, str): @@ -380,6 +395,14 @@ class Labels: decoded = [labelset[l-1] for l in label] return decoded + def load_model_config(self, config): + try: + self.labelspec = config["remapped_behavior_labels"] + except KeyError: + self.labelspec = config["behavior_labels"] + else: + self.decodingspec = config["behavior_labels"] + class LabelEncoder(json.JSONEncoder): def default(self, labels): if isinstance(labels, Labels): -- GitLab