diff --git a/src/taggingbackends/data/labels.py b/src/taggingbackends/data/labels.py index fc9bf9b2dc7e2ec07eff960c5a5a01f585666d62..c9301e2ccdd580152607548640483df261bd8c8e 100644 --- a/src/taggingbackends/data/labels.py +++ b/src/taggingbackends/data/labels.py @@ -84,6 +84,7 @@ class Labels: self.secondarylabelspec = None self._tracking = tracking self._input_labels = None + self.decodingspec = None # self.filepath = None # unused attribute; may help to manage data # dependencies in the future @@ -360,7 +361,18 @@ class Labels: """ Decode the label indices as text (`str` or `list` of `str`). - Text labels are picked in `labelspec`. + Text labels are picked in `labelspec`, or `decodingspec` instead if defined. + + `decodingspec` is set to decode the encoded output from a tagger that + defines redundant labels, which are remapped onto labels in `labelspec`. + These may include several sections in their config file, including + `original_behavior_labels`, `behavior_labels`, `remapped_behavior_labels`. + + Note that `decodingspec` should not be set to decode a label file. A label + file should use and mention the remapped labels only, if the latter are + defined by the tagger used to generate the label file. In the case of + decoding a label file, `labelspec` only should be defined in the `Labels` + object. """ def decode(self, label=None): if label is None: @@ -370,7 +382,10 @@ class Labels: elif isinstance(label, dict): decoded = {t: self.decode(l) for t, l in label.items()} else: - labelset = self.full_label_list + if self.decodingspec is None: + labelset = self.full_label_list + else: + labelset = self.decodingspec if isinstance(label, int): decoded = labelset[label-1] elif isinstance(label, str): @@ -380,6 +395,14 @@ class Labels: decoded = [labelset[l-1] for l in label] return decoded + def load_model_config(self, config): + try: + self.labelspec = config["remapped_behavior_labels"] + except KeyError: + self.labelspec = config["behavior_labels"] + else: + self.decodingspec = config["behavior_labels"] + class LabelEncoder(json.JSONEncoder): def default(self, labels): if isinstance(labels, Labels):