diff --git a/src/taggingbackends/data/labels.py b/src/taggingbackends/data/labels.py index a6da2c08d4fc63e870d2e5edea2815d3959f2fd6..d741a89ecbecd4960365a2532469729c828d73e0 100644 --- a/src/taggingbackends/data/labels.py +++ b/src/taggingbackends/data/labels.py @@ -8,6 +8,54 @@ TIME_PRECISION = 0.001 labels_file_extension = (".nyxlabel", ".label", ".labels") +""" +Python interface to JSON *.label* files. + +It is a poor relation of the Julia struct `Dataset` from `PlanarLarvae.jl`. + +The following example function opens such a *.label* file and loads the +associated track data assumed to be available as a *trx.mat* file. +It iterates over the tracks defined in the *.label* file. The same tracks are +also expected to be found in the *trx.mat* file. + +``` +import numpy as np +from sklearn.metrics import confusion_matrix + +def retagged_trxmat_confusion_matrix(label_file): + # load the .label and associated trx.mat files + predicted_labels = Labels(label_file) + labels = predicted_labels.labelspec + track_data_file = TrxMat(label_file.parent / 'trx.mat') + expected_labels = track_data_file.read(labels) + + # iterate the tracks + cm = None + for run, larva in predicted_labels: + + ypred = np.array([labels.index(label) + for label in predicted_labels[(run, larva)].values()]) + yexp = np.stack([expected_labels[label][run][larva] + for label in labels]) + iexp, yexp = np.nonzero(yexp.T==1) # the order of dimensions matters + ypred = ypred[iexp] # if the trx.mat file mentions more tags than the + # .label file, some timesteps may appear as untagged + + cm_ = confusion_matrix(yexp, ypred, labels=np.arange(len(labels))) + cm = cm_ if cm is None else cm + cm_ + return cm +``` + +Note that unlike the `TrxMat` class, keys in `Labels` objects are +(runid, trackid) pairs. +A `Labels` object will typically describe a single run, but can technically +represent several runs. 
+ +The `labelspec` attribute is assumed to be a list, which is valid for *.label* +files generated by automatic tagging. +Manual tagging will store label names as `Labels.labelspec['names']`, because +label colors are also stored in the `labelspec` attribute. +""" class Labels: def __init__(self, labels=None, labelspec=None, metadata=None, units=None, @@ -23,8 +71,12 @@ class Labels: self.labelspec, self.units = labelspec, units self._tracking = tracking self._input_labels = None + # + self.filepath = None # unused attribute; may help to manage data + # dependencies in the future if isinstance(labels, (str, pathlib.Path)): - self.load(labels) + self.filepath = pathlib.Path(labels) if isinstance(labels, str) else labels + self.load() @property def tracking(self): @@ -182,10 +234,12 @@ class Labels: else: json.dump(self, file, cls=LabelEncoder, indent=indent, **kwargs) - def load(self, file, format="json"): + def load(self, file=None, format="json"): if format != "json": raise ArgumentError(f"unsupported format: {format}") - if isinstance(file, str): + if file is None: + file = self.filepath + elif isinstance(file, str): file = pathlib.Path(file) if isinstance(file, pathlib.Path): if file.is_dir(): @@ -231,7 +285,6 @@ class Labels: return data def from_hierarchical_format(self, data): - # never tested self.metadata = dict(data["metadata"]) run = self.metadata.pop("id") self.units = data.get("units", {})