From 840881723c366fbc7302c83c6fccf4e466b769b2 Mon Sep 17 00:00:00 2001
From: Alexandre Blanc <alexandre.blanc@pasteur.fr>
Date: Fri, 11 Oct 2024 18:17:40 +0200
Subject: [PATCH] Change trx syntax in repository, fix larva indexing

---
 src/pytrxmat/repository.py |  9 +++++----
 tests/test_repository.py   |  1 +
 tests/test_trx.py          | 12 ++++++++++--
 3 files changed, 16 insertions(+), 6 deletions(-)

diff --git a/src/pytrxmat/repository.py b/src/pytrxmat/repository.py
index 5aa76f0..b53d4fd 100644
--- a/src/pytrxmat/repository.py
+++ b/src/pytrxmat/repository.py
@@ -1,5 +1,6 @@
 from pathlib import Path
 from typing import Dict, List
+import functools
 from .trx import TRX
 
 class DateTime:
@@ -10,7 +11,7 @@ class DateTime:
     
     @property
     def trx(self)->TRX:
-        return TRX(self.path/'trx.mat')
+        return functools.partial(TRX, self.path/'trx.mat')
 
 class Protocol:
     def __init__(self, name:str, root:Path):
@@ -34,7 +35,7 @@ class Line:
         self.name = name
         self.root = root
         assert(str(root.name) == name)
-        self._protocols:Dict[str, Protocol] = {p.name:Protocol(p.name, p) for p in self.root.glob('p_*') if p.is_dir()}
+        self._protocols:Dict[str, Protocol] = {p.name:Protocol(p.name, p) for p in self.root.iterdir() if p.is_dir() and (p.name.startswith('p_') or p.name.startswith('ch_'))}
 
     def __getitem__(self, protocol:str)->Protocol:
         return self._protocols[protocol]
@@ -82,8 +83,8 @@ class Tracker:
 
 
 class Repository:
-    def __init__(self, root:Path, trackers:List[str]):
-        self.root = root
+    def __init__(self, root, trackers:List[str]):
+        self.root = Path(root)
         self._trackers:Dict[str, Tracker] = {t:Tracker(t, self.root/t) for t in trackers if (self.root/t in self.root.iterdir()) and (self.root/t).is_dir()}
 
     def __getitem__(self, tracker:str)->Tracker:
diff --git a/tests/test_repository.py b/tests/test_repository.py
index 9ec407f..7ef2a91 100644
--- a/tests/test_repository.py
+++ b/tests/test_repository.py
@@ -52,5 +52,6 @@ def test_repository(root):
         print(dt)
     for trx in repo.trx():
         print(trx)
+        print(trx())
 
     print(repo['t1']['a@a']['p_0']['00000000_000000'].trx)
diff --git a/tests/test_trx.py b/tests/test_trx.py
index ef77bde..ad0f32a 100644
--- a/tests/test_trx.py
+++ b/tests/test_trx.py
@@ -65,11 +65,19 @@ def test_get_string_asarray_false(trx: TRX):
     s = trx.get_string(['neuron', 'protocol'], asarray=False)
     print(s)
 
+def test_larva_filtering(trx: TRX):
+    larvae = trx.get_as_array('numero_larva_num')
+    larva = larvae[0].item()
+    x = trx.get_as_array(['t', 'x_head'], l=[larva])[0]
+    print(x)
+
+
 if __name__ == '__main__':
-    trx = TRX('trx.mat')
+    trx = TRX('tests/trx.mat')
     test_get_scalar_list_larvae(trx)
     test_get_ts_list_larvae(trx)
     test_get_ts_and_scalar(trx)
     test_get_bunched_scalar(trx)
     test_get_string(trx)
-    test_get_string_asarray_false(trx)
\ No newline at end of file
+    test_get_string_asarray_false(trx)
+    test_larva_filtering(trx)
\ No newline at end of file
-- 
GitLab