From 840881723c366fbc7302c83c6fccf4e466b769b2 Mon Sep 17 00:00:00 2001 From: Alexandre Blanc <alexandre.blanc@pasteur.fr> Date: Fri, 11 Oct 2024 18:17:40 +0200 Subject: [PATCH] Change trx syntax in repository, fix larva indexing --- src/pytrxmat/repository.py | 9 +++++---- tests/test_repository.py | 1 + tests/test_trx.py | 12 ++++++++++-- 3 files changed, 16 insertions(+), 6 deletions(-) diff --git a/src/pytrxmat/repository.py b/src/pytrxmat/repository.py index 5aa76f0..b53d4fd 100644 --- a/src/pytrxmat/repository.py +++ b/src/pytrxmat/repository.py @@ -1,5 +1,6 @@ from pathlib import Path from typing import Dict, List +import functools from .trx import TRX class DateTime: @@ -10,7 +11,7 @@ class DateTime: @property def trx(self)->TRX: - return TRX(self.path/'trx.mat') + return functools.partial(TRX, self.path/'trx.mat') class Protocol: def __init__(self, name:str, root:Path): @@ -34,7 +35,7 @@ class Line: self.name = name self.root = root assert(str(root.name) == name) - self._protocols:Dict[str, Protocol] = {p.name:Protocol(p.name, p) for p in self.root.glob('p_*') if p.is_dir()} + self._protocols:Dict[str, Protocol] = {p.name:Protocol(p.name, p) for p in self.root.iterdir() if p.is_dir() and (p.name.startswith('p_') or p.name.startswith('ch_'))} def __getitem__(self, protocol:str)->Protocol: return self._protocols[protocol] @@ -82,8 +83,8 @@ class Tracker: class Repository: - def __init__(self, root:Path, trackers:List[str]): - self.root = root + def __init__(self, root, trackers:List[str]): + self.root = Path(root) self._trackers:Dict[str, Tracker] = {t:Tracker(t, self.root/t) for t in trackers if (self.root/t in self.root.iterdir()) and (self.root/t).is_dir()} def __getitem__(self, tracker:str)->Tracker: diff --git a/tests/test_repository.py b/tests/test_repository.py index 9ec407f..7ef2a91 100644 --- a/tests/test_repository.py +++ b/tests/test_repository.py @@ -52,5 +52,6 @@ def test_repository(root): print(dt) for trx in repo.trx(): print(trx) + print(trx()) print(repo['t1']['a@a']['p_0']['00000000_000000'].trx) diff --git a/tests/test_trx.py b/tests/test_trx.py index ef77bde..ad0f32a 100644 --- a/tests/test_trx.py +++ b/tests/test_trx.py @@ -65,11 +65,19 @@ def test_get_string_asarray_false(trx: TRX): s = trx.get_string(['neuron', 'protocol'], asarray=False) print(s) +def test_larva_filtering(trx: TRX): + larvae = trx.get_as_array('numero_larva_num') + larva = larvae[0].item() + x = trx.get_as_array(['t', 'x_head'], l=[larva])[0] + print(x) + + if __name__ == '__main__': - trx = TRX('trx.mat') + trx = TRX('tests/trx.mat') test_get_scalar_list_larvae(trx) test_get_ts_list_larvae(trx) test_get_ts_and_scalar(trx) test_get_bunched_scalar(trx) test_get_string(trx) - test_get_string_asarray_false(trx) \ No newline at end of file + test_get_string_asarray_false(trx) + test_larva_filtering(trx) \ No newline at end of file -- GitLab