Commit 1c737ace authored by François Laurent

Merge branch 'dev'

parents 503e8425 5326a48b
Pipeline #94946 passed
name = "TaggingBackends" name = "TaggingBackends"
uuid = "e551f703-3b82-4335-b341-d497b48d519b" uuid = "e551f703-3b82-4335-b341-d497b48d519b"
authors = ["François Laurent", "Institut Pasteur"] authors = ["François Laurent", "Institut Pasteur"]
version = "0.7.1" version = "0.7.2"
[deps] [deps]
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a" Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
......
[tool.poetry]
name = "TaggingBackends"
-version = "0.7.1"
+version = "0.7.2"
description = "Backbone for LarvaTagger.jl tagging backends"
authors = ["François Laurent"]
...
@@ -101,12 +101,13 @@ class LarvaDataset:
        train, val, test = torch.utils.data.random_split(TorchDataset(),
                [ntrain, nval, ntest],
                generator=self.generator)
-       self._training_set = iter(itertools.cycle(
-           torch.utils.data.DataLoader(train,
-               batch_size=self.batch_size,
-               shuffle=True,
-               generator=g_train,
-               drop_last=True)))
+       if 0 < ntrain:
+           self._training_set = iter(itertools.cycle(
+               torch.utils.data.DataLoader(train,
+                   batch_size=self.batch_size,
+                   shuffle=True,
+                   generator=g_train,
+                   drop_last=True)))
        self._validation_set = iter(
            torch.utils.data.DataLoader(val,
                batch_size=self.batch_size))
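
The new `if 0 < ntrain:` guard avoids building a training iterator over an empty split: a `DataLoader` with `drop_last=True` over a zero-length subset yields no batches, and cycling over it is exhausted on the first `next()`. A minimal, self-contained illustration of that behaviour (not code from the repository):

```
# Minimal illustration of why the guard is needed: an empty training split
# produces a DataLoader with zero batches, and cycling over it is immediately
# exhausted.
import itertools
import torch

data = torch.utils.data.TensorDataset(torch.arange(10.0))
train, val = torch.utils.data.random_split(data, [0, 10])

loader = torch.utils.data.DataLoader(train, batch_size=2, drop_last=True)
print(len(loader))  # 0: no batches in an empty split

try:
    next(iter(itertools.cycle(loader)))
except StopIteration:
    print("empty training split; skip building the training iterator")
```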
@@ -186,14 +187,14 @@ def subset_size(ntot, train_share, val_share, test_share):
    while ndelta < 0:
        if 0 < val_share:
            nval += 1
-           ndelta -= 1
+           ndelta += 1
        if ndelta < 0:
            if 0 < test_share:
                ntest += 1
-               ndelta -= 1
+               ndelta += 1
            if ndelta < 0:
                ntrain += 1
-               ndelta -= 1
+               ndelta += 1
    while 0 < ndelta:
        if 0 < train_share:
            ntrain -= 1
...
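
The sign fix matters for termination: every time a subset is padded, `ndelta` has to move toward zero, otherwise the `while ndelta < 0` loop never exits. Below is a standalone sketch of the same leftover-redistribution idea, with a hypothetical `split_counts` helper rather than the repository's `subset_size`:

```
# Hypothetical helper, illustrating the corrected rebalancing only: after
# flooring the per-subset counts, hand the leftover samples to subsets with a
# non-zero share, moving the leftover counter toward zero on every step.
def split_counts(ntot, train_share, val_share, test_share):
    shares = (train_share, val_share, test_share)
    counts = [int(ntot * share) for share in shares]
    ndelta = ntot - sum(counts)  # rounding leftover; >= 0 when the shares sum to 1
    for i, share in enumerate(shares):
        while 0 < share and 0 < ndelta:
            counts[i] += 1
            ndelta -= 1          # converges because ndelta strictly decreases
    return tuple(counts)

print(split_counts(10, 0.5, 0.25, 0.25))  # (6, 2, 2): counts sum to ntot
```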
@@ -8,6 +8,54 @@ TIME_PRECISION = 0.001
labels_file_extension = (".nyxlabel", ".label", ".labels")

"""
Python interface to JSON *.label* files.
It is a poor relation of the Julia struct `Dataset` from `PlanarLarvae.jl`.
The following example function opens such a *.label* file and loads the
associated track data assumed to be available as a *trx.mat* file.
It iterates over the tracks defined in the *.label* file; the same tracks are
expected to be found in the *trx.mat* file as well.
```
import numpy as np
from sklearn.metrics import confusion_matrix
def retagged_trxmat_confusion_matrix(label_file):
    # load the .label and associated trx.mat files
    predicted_labels = Labels(label_file)
    labels = predicted_labels.labelspec
    track_data_file = TrxMat(label_file.parent / 'trx.mat')
    expected_labels = track_data_file.read(labels)
    # iterate the tracks
    cm = None
    for run, larva in predicted_labels:
        ypred = np.array([labels.index(label)
                          for label in predicted_labels[(run, larva)].values()])
        yexp = np.stack([expected_labels[label][run][larva]
                         for label in labels])
        iexp, yexp = np.nonzero(yexp.T==1) # the order of dimensions matters
        ypred = ypred[iexp] # if the trx.mat file mentions more tags than the
                            # .label file, some timesteps may appear as untagged
        cm_ = confusion_matrix(yexp, ypred, labels=np.arange(len(labels)))
        cm = cm_ if cm is None else cm + cm_
    return cm
```
Note that, unlike in the `TrxMat` class, keys in `Labels` objects are
(runid, trackid) pairs.
A `Labels` object will typically describe a single run, but can technically
represent several runs.
The `labelspec` attribute is assumed to be a list, which is valid for *.label*
files generated by automatic tagging.
Manual tagging will store label names as `Labels.labelspec['names']`, because
label colors are also stored in the `labelspec` attribute.
"""
class Labels:
    def __init__(self, labels=None, labelspec=None, metadata=None, units=None,
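
The docstring above notes that `labelspec` is a plain list for *.label* files produced by automatic tagging, while manual tagging stores the names under `labelspec['names']` next to colour information. A small hypothetical helper, not part of the module (the `'colors'` key below is illustrative only), showing how a consumer could obtain the names in both cases:

```
# Hypothetical helper: normalise labelspec into a list of label names.
def label_names(labelspec):
    if isinstance(labelspec, dict):
        # manual tagging: names stored under 'names', next to colour information
        return list(labelspec['names'])
    # automatic tagging: labelspec is already a list of names
    return list(labelspec)

assert label_names(["run", "cast"]) == ["run", "cast"]
assert label_names({"names": ["run", "cast"], "colors": ["#000", "#fff"]}) == ["run", "cast"]
```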
@@ -23,8 +71,12 @@ class Labels:
        self.labelspec, self.units = labelspec, units
        self._tracking = tracking
        self._input_labels = None
+        #
+        self.filepath = None # unused attribute; may help to manage data
+                             # dependencies in the future
        if isinstance(labels, (str, pathlib.Path)):
-            self.load(labels)
+            self.filepath = pathlib.Path(labels) if isinstance(labels, str) else labels
+            self.load()

    @property
    def tracking(self):
@@ -182,10 +234,12 @@ class Labels:
        else:
            json.dump(self, file, cls=LabelEncoder, indent=indent, **kwargs)

-    def load(self, file, format="json"):
+    def load(self, file=None, format="json"):
        if format != "json":
            raise ArgumentError(f"unsupported format: {format}")
-        if isinstance(file, str):
+        if file is None:
+            file = self.filepath
+        elif isinstance(file, str):
            file = pathlib.Path(file)
        if isinstance(file, pathlib.Path):
            if file.is_dir():
...
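
Together, the `filepath` attribute set in `__init__` and the new default in `load` let a `Labels` object be reloaded from the file it was constructed from without repeating the path. The sketch below mimics that pattern with a stand-in class; it is not the actual `Labels` implementation and only resolves the path:

```
# Stand-in class demonstrating the path-defaulting pattern introduced above;
# it only resolves the path, it does not parse .label files.
import pathlib

class PathBacked:
    def __init__(self, source=None):
        self.filepath = None
        if isinstance(source, (str, pathlib.Path)):
            self.filepath = pathlib.Path(source)

    def load(self, file=None):
        if file is None:
            file = self.filepath       # new default: reuse the stored path
        elif isinstance(file, str):
            file = pathlib.Path(file)
        return file                    # the real load() would parse this file

record = PathBacked("sample/predicted.label")
assert record.load() == pathlib.Path("sample/predicted.label")
assert record.load("other.label") == pathlib.Path("other.label")
```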
@@ -181,6 +181,11 @@ Cannot find any Python package in project root directory:
        # Julia logging
        elif line and line[0] in "[┌│└":
            print(line)
+        # Torch logs
+        elif line == ' return torch._C._cuda_getDeviceCount() > 0':
+            # typically follows:
+            # UserWarning: CUDA initialization: CUDA unknown error
+            print(line)
        # tensorflow logs
        elif 26 < len(line):
            # assume line[:26] to be e.g.: 2022-05-17 18:48:15.120981
...
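
The surrounding filter decides, line by line, whether backend output is forwarded to the user; the existing comment spells out the timestamp heuristic used to recognise TensorFlow logs. Here is a standalone sketch of that heuristic only, not the repository's implementation:

```
# Standalone check for the TensorFlow-style timestamp prefix mentioned in the
# comment above, e.g. "2022-05-17 18:48:15.120981".
from datetime import datetime

def looks_like_tf_log(line):
    if len(line) <= 26:
        return False
    try:
        datetime.strptime(line[:26], "%Y-%m-%d %H:%M:%S.%f")
    except ValueError:
        return False
    return True

assert looks_like_tf_log("2022-05-17 18:48:15.120981: I tensorflow/core/platform/...")
assert not looks_like_tf_log("┌ Warning: backend warning relayed by Julia logging")
```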