From 4b68a149cb6963d9130d6514cba85bbd0c7e1059 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Fran=C3=A7ois=20Laurent?= <francois.laurent@posteo.net>
Date: Fri, 13 Jan 2023 20:07:11 +0100
Subject: [PATCH] implements
 https://gitlab.pasteur.fr/nyx/larvatagger.jl/-/issues/85 on the
 TaggingBackends side

---
 src/taggingbackends/explorer.py | 26 ++++++++++++++++++++------
 src/taggingbackends/main.py     | 14 ++++++++++++--
 2 files changed, 32 insertions(+), 8 deletions(-)

diff --git a/src/taggingbackends/explorer.py b/src/taggingbackends/explorer.py
index 2f0e14f..803eb72 100644
--- a/src/taggingbackends/explorer.py
+++ b/src/taggingbackends/explorer.py
@@ -7,6 +7,7 @@ import fnmatch
 import importlib
 import logging
 import subprocess
+import tempfile
 from collections import defaultdict
 
 JULIA_PROJECT = os.environ.get('JULIA_PROJECT', '')
@@ -47,7 +48,8 @@ class BackendExplorer:
     Locator for paths to data, scripts, model instances, etc.
     """
 
-    def __init__(self, project_dir=None, package_name=None, model_instance=None):
+    def __init__(self, project_dir=None, package_name=None, model_instance=None,
+                 sandbox=None):
         self.project_dir = pathlib.Path(os.getcwd() if project_dir is None else project_dir)
         logging.debug(f"project directory: {self.project_dir}")
         self._package_name = package_name
@@ -57,6 +59,8 @@ class BackendExplorer:
         self._build_features = None
         self._train_model = None
         self._predict_model = None
+        #
+        self._sandbox = sandbox
 
     @property
     def package_name(self):
@@ -270,6 +274,15 @@ run `poetry add {pkg}` from directory: \n
             raise
         return pkg
 
+    @property
+    def sandbox(self):
+        if self._sandbox is False:
+            self._sandbox = None
+        elif self._sandbox is True:
+            self._sandbox = pathlib.Path(tempfile.mkdtemp(dir=self.project_dir / 'data' / 'raw')).name
+            logging.info(f"sandboxing in {self._sandbox}")
+        return self._sandbox
+
     def _model_dir(self, parent_dir, model_instance=None, create_if_missing=True):
         if model_instance is None:
             model_instance = self.model_instance
@@ -281,19 +294,19 @@ run `poetry add {pkg}` from directory: \n
     def raw_data_dir(self, model_instance=None, create_if_missing=True):
         return self._model_dir(
                 self.project_dir / "data" / "raw",
-                model_instance,
+                self.sandbox if model_instance is None else model_instance,
                 create_if_missing)
 
     def interim_data_dir(self, model_instance=None, create_if_missing=True):
         return self._model_dir(
                 self.project_dir / "data" / "interim",
-                model_instance,
+                self.sandbox if model_instance is None else model_instance,
                 create_if_missing)
 
     def processed_data_dir(self, model_instance=None, create_if_missing=True):
         return self._model_dir(
                 self.project_dir / "data" / "processed",
-                model_instance,
+                self.sandbox if model_instance is None else model_instance,
                 create_if_missing)
 
     def model_dir(self, model_instance=None, create_if_missing=True):
@@ -480,6 +493,7 @@ run `poetry add {pkg}` from directory: \n
         interim *.h5* data files in data/interim/{instance}/ and generate a
         *larva_dataset hdf5* file similarly to `generate_dataset`.
         """
+        logging.warning('BackendExplorer.compile_trxmat_database is deprecated and will soon be removed')
         input_dir = str(input_dir) # in the case input_dir is a pathlib.Path
         interim_dir = str(self.interim_data_dir())
         if not reuse_h5files:
@@ -506,7 +520,7 @@ run `poetry add {pkg}` from directory: \n
             met = dict(raw=self.raw_data_dir,
                        interim=self.interim_data_dir,
                        processed=self.processed_data_dir,
-                       )[dir]
+                      )[dir]
             shutil.rmtree(met(model_instance, False), ignore_errors=True)
 
     def reset_model(self, model_instance=None):
@@ -523,7 +537,7 @@ class BackendExplorerEncoder(json.JSONEncoder):
     def default(self, explorer):
         if isinstance(explorer, BackendExplorer):
             data = {}
-            for attr in ("project_dir", "package_name", "model_instance"):
+            for attr in ("project_dir", "package_name", "model_instance", "sandbox"):
                 try:
                     val = getattr(explorer, attr)
                 except AttributeError:
diff --git a/src/taggingbackends/main.py b/src/taggingbackends/main.py
index b167129..e5c435e 100644
--- a/src/taggingbackends/main.py
+++ b/src/taggingbackends/main.py
@@ -10,7 +10,7 @@ Usage:  tagging-backend [train|predict] --model-instance <name>
         tagging-backend train ... --sample-size <N>
         tagging-backend train ... --frame-interval <I> --window-length <T>
         tagging-backend train ... --pretrained-model-instance <name>
-        tagging-backend predict ... --skip-make-dataset
+        tagging-backend predict ... --skip-make-dataset --sandbox <token>
 
 `tagging-backend` typically is run using `poetry run`.
 A name must be provided to identify the trained model and its location within
@@ -35,6 +35,11 @@ Note that an existing larva_dataset file in data/interim/<name> makes the
 the `make_dataset` module is loaded and this may take quite some time due to
 dependencies (e.g. Julia FFI). The `--skip-make-dataset` option makes `train`
 truly skip this step; the corresponding module is not loaded.
+
+`--sandbox <token>` makes `tagging-backend` use a token instead of <name> as
+directory name in data/raw, data/interim and data/processed.
+This is intended to prevent conflicts on running `predict` in parallel on
+multiple data files with multiple calls.
 """
     if _print:
         print(msg)
@@ -56,6 +61,7 @@ def main(fun=None):
         trxmat_only = reuse_h5files = False
         skip_make_dataset = skip_build_features = False
         pretrained_model_instance = None
+        sandbox = False
         unknown_args = {}
         k = 2
         while k < len(sys.argv):
@@ -91,11 +97,15 @@ def main(fun=None):
             elif sys.argv[k] == "--pretrained-model-instance":
                 k = k + 1
                 pretrained_model_instance = sys.argv[k]
+            elif sys.argv[k] == "--sandbox":
+                k = k + 1
+                sandbox = sys.argv[k]
             else:
                 unknown_args[sys.argv[k].lstrip('-').replace('-', '_')] = sys.argv[k+1]
                 k = k + 1
             k = k + 1
-        backend = BackendExplorer(project_dir, model_instance=model_instance)
+        backend = BackendExplorer(project_dir, model_instance=model_instance,
+                                  sandbox=sandbox)
         backend.reset_data(spare_raw=True)
         sys.stderr.flush()
         sys.stdout.flush()
-- 
GitLab