From 3d3e5b90279818772e50b2825b43d30e31a8a2e1 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Fran=C3=A7ois=20Laurent?= <francois.laurent@posteo.net>
Date: Mon, 3 Oct 2022 19:15:22 +0200
Subject: [PATCH] prediction from interim files if any, otherwise from raw files

---
 .gitignore                            | 4 +++-
 pyproject.toml                        | 2 +-
 src/maggotuba/models/predict_model.py | 9 ++++-----
 src/maggotuba/models/train_model.py   | 4 ++--
 4 files changed, 10 insertions(+), 9 deletions(-)

diff --git a/.gitignore b/.gitignore
index 535f534..96d1243 100644
--- a/.gitignore
+++ b/.gitignore
@@ -20,8 +20,10 @@ poetry.lock
 .env
 env/
 
-# exclude data from source control by default
+# exclude data and models from source control by default
 /data/
+/models/
+/pretrained_models/
 
 # Visual Studio Code
 .vscode/
diff --git a/pyproject.toml b/pyproject.toml
index dc095da..bc69114 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -9,7 +9,7 @@ packages = [
 ]
 
 [tool.poetry.dependencies]
-python = "^3.8,<3.11"
+python = "^3.8,<3.10"
 structured-temporal-convolution = {git = "git@gitlab.pasteur.fr:les-larves/structured-temporal-convolution.git", branch="light-stable-for-tagging"}
 torch = "^1.11.0"
 numpy = "^1.19.3"
diff --git a/src/maggotuba/models/predict_model.py b/src/maggotuba/models/predict_model.py
index 129e070..dd83061 100644
--- a/src/maggotuba/models/predict_model.py
+++ b/src/maggotuba/models/predict_model.py
@@ -14,11 +14,10 @@ def predict_model(backend, **kwargs):
 
     The `predict_model.py` script is required.
     """
-    # in the present case, as make_dataset.py and build_features.py do nothing,
-    # we pick files in `data/raw`
-    input_files = backend.list_input_files()
-    # we could go and pick files in `data/interim` as well:
-    input_files += backend.list_interim_files()
+    # we pick files in `data/interim` if any, otherwise in `data/raw`
+    input_files = backend.list_interim_files()
+    if not input_files:
+        input_files = backend.list_input_files()
     assert 0 < len(input_files)
     # initialize output labels
     input_files, labels = backend.prepare_labels(input_files)
diff --git a/src/maggotuba/models/train_model.py b/src/maggotuba/models/train_model.py
index 9def508..92dcf44 100644
--- a/src/maggotuba/models/train_model.py
+++ b/src/maggotuba/models/train_model.py
@@ -4,12 +4,12 @@ from maggotuba.models.trainers import MaggotTrainer, MultiscaleMaggotTrainer, ne
 import json
 import glob
 
-def train_model(backend, layers=1, pretrained_model_instance="default"):
+def train_model(backend, layers=1, pretrained_model_instance="default", **kwargs):
     # make_dataset generated or moved the larva_dataset file into data/interim/{instance}/
     #larva_dataset_file = backend.list_interim_files("larva_dataset_*.hdf5") # recursive
     larva_dataset_file = glob.glob(str(backend.interim_data_dir() / "larva_dataset_*.hdf5")) # not recursive (faster)
     assert len(larva_dataset_file) == 1
-    dataset = LarvaDataset(larva_dataset_file[0], new_generator())
+    dataset = LarvaDataset(larva_dataset_file[0], new_generator(), **kwargs)
     labels = dataset.labels
     assert 0 < len(labels)
     labels = labels if isinstance(labels[0], str) else [s.decode() for s in labels]
-- 
GitLab