diff --git a/src/taggingbackends/data/convert.py b/src/taggingbackends/data/convert.py new file mode 100644 index 0000000000000000000000000000000000000000..2a65318d2706fc45d90f1480c2e0a49f9581c866 --- /dev/null +++ b/src/taggingbackends/data/convert.py @@ -0,0 +1,35 @@ +from .trxmat import TrxMat +from .labels import Labels +import numpy as np +from pathlib import Path + +""" +Import behavior labels from a trx.mat file and return a Labels object. +""" +def import_labels_from_trxmat(trxmat_file, labels, decode=True): + if not isinstance(labels, list) or not labels: + raise ValueError("labels should be a non-empty list of strings") + if isinstance(trxmat_file, str): + trxmat_file = Path(trxmat_file) + trxmat = TrxMat(trxmat_file) + trxmat_labels = trxmat.read(['t'] + labels) + imported_labels = Labels(labelspec=labels, tracking=[trxmat_file]) + run = next(iter(trxmat_labels[labels[0]])) + larvae = list(trxmat_labels[labels[0]][run]) + for larva in larvae: + times = trxmat_labels['t'][run][larva] + for i, label in enumerate(labels): + indicator = (i+1) * (trxmat_labels[label][run][larva]==1) + if i==0: + encoded_labels = indicator + else: + encoded_labels += indicator + if np.any(len(labels) < encoded_labels): + raise NotImplementedError("overlapping labels") + if decode: + _labels = imported_labels.decode(encoded_labels) + else: + _labels = encoded_labels + imported_labels[(run, larva)] = {t: l for t, l in zip(times, _labels)} + return imported_labels + diff --git a/src/taggingbackends/data/labels.py b/src/taggingbackends/data/labels.py index d741a89ecbecd4960365a2532469729c828d73e0..c54f58ee2a5bbbd148f608f0663b9b71d4d4606c 100644 --- a/src/taggingbackends/data/labels.py +++ b/src/taggingbackends/data/labels.py @@ -314,15 +314,21 @@ class Labels: encoded.append([labelset.index(label)+1 for label in label]) return encoded - def decode(self, label): + def decode(self, label=None): if isinstance(self.labelspec, dict): labelset = self.labelspec["names"] else: labelset = self.labelspec - if isinstance(label, int): + if label is None: + label = decoded = self + for run_larva in label: + label[run_larva] = self.decode(label[run_larva]) + elif isinstance(label, dict): + decoded = {t: labelset[l-1] for t, l in label.items()} + elif isinstance(label, int): decoded = labelset[label-1] else: - decoded = [labelset[label-1] for label in label] + decoded = [labelset[l-1] for l in label] return decoded class LabelEncoder(json.JSONEncoder):