diff --git a/src/pytrxmat/repository.py b/src/pytrxmat/repository.py index 5aa76f04fadd1b4b6b40a4b8ce0177fc74d12549..b53d4fdaa598ee21a63f3eefc5eb57d12da8db86 100644 --- a/src/pytrxmat/repository.py +++ b/src/pytrxmat/repository.py @@ -1,5 +1,6 @@ from pathlib import Path from typing import Dict, List +import functools from .trx import TRX class DateTime: @@ -10,7 +11,7 @@ class DateTime: @property def trx(self)->TRX: - return TRX(self.path/'trx.mat') + return functools.partial(TRX, self.path/'trx.mat') class Protocol: def __init__(self, name:str, root:Path): @@ -34,7 +35,7 @@ class Line: self.name = name self.root = root assert(str(root.name) == name) - self._protocols:Dict[str, Protocol] = {p.name:Protocol(p.name, p) for p in self.root.glob('p_*') if p.is_dir()} + self._protocols:Dict[str, Protocol] = {p.name:Protocol(p.name, p) for p in self.root.iterdir() if p.is_dir() and (p.name.startswith('p_') or p.name.startswith('ch_'))} def __getitem__(self, protocol:str)->Protocol: return self._protocols[protocol] @@ -82,8 +83,8 @@ class Tracker: class Repository: - def __init__(self, root:Path, trackers:List[str]): - self.root = root + def __init__(self, root, trackers:List[str]): + self.root = Path(root) self._trackers:Dict[str, Tracker] = {t:Tracker(t, self.root/t) for t in trackers if (self.root/t in self.root.iterdir()) and (self.root/t).is_dir()} def __getitem__(self, tracker:str)->Tracker: diff --git a/tests/test_repository.py b/tests/test_repository.py index 9ec407ff5460f142c44c9871bc4db702fbeb3dd8..7ef2a9137eb784ad825eda56c9283072dde40524 100644 --- a/tests/test_repository.py +++ b/tests/test_repository.py @@ -52,5 +52,6 @@ def test_repository(root): print(dt) for trx in repo.trx(): print(trx) + print(trx()) print(repo['t1']['a@a']['p_0']['00000000_000000'].trx) diff --git a/tests/test_trx.py b/tests/test_trx.py index ef77bde62d29b8171c6c7d364a8fae3df8f0750e..ad0f32a119ab1770bc8177a9734492b9d056ea3b 100644 --- a/tests/test_trx.py +++ b/tests/test_trx.py @@ -65,11 +65,19 @@ def test_get_string_asarray_false(trx: TRX): s = trx.get_string(['neuron', 'protocol'], asarray=False) print(s) +def test_larva_filtering(trx: TRX): + larvae = trx.get_as_array('numero_larva_num') + larva = larvae[0].item() + x = trx.get_as_array(['t', 'x_head'], l=[larva])[0] + print(x) + + if __name__ == '__main__': - trx = TRX('trx.mat') + trx = TRX('tests/trx.mat') test_get_scalar_list_larvae(trx) test_get_ts_list_larvae(trx) test_get_ts_and_scalar(trx) test_get_bunched_scalar(trx) test_get_string(trx) - test_get_string_asarray_false(trx) \ No newline at end of file + test_get_string_asarray_false(trx) + test_larva_filtering(trx) \ No newline at end of file