From 3d3e5b90279818772e50b2825b43d30e31a8a2e1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Laurent?= <francois.laurent@posteo.net> Date: Mon, 3 Oct 2022 19:15:22 +0200 Subject: [PATCH] prediction from interim files only --- .gitignore | 4 +++- pyproject.toml | 2 +- src/maggotuba/models/predict_model.py | 9 ++++----- src/maggotuba/models/train_model.py | 4 ++-- 4 files changed, 10 insertions(+), 9 deletions(-) diff --git a/.gitignore b/.gitignore index 535f534..96d1243 100644 --- a/.gitignore +++ b/.gitignore @@ -20,8 +20,10 @@ poetry.lock .env env/ -# exclude data from source control by default +# exclude data and models from source control by default /data/ +/models/ +/pretrained_models/ # Visual Studio Code .vscode/ diff --git a/pyproject.toml b/pyproject.toml index dc095da..bc69114 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,7 +9,7 @@ packages = [ ] [tool.poetry.dependencies] -python = "^3.8,<3.11" +python = "^3.8,<3.10" structured-temporal-convolution = {git = "git@gitlab.pasteur.fr:les-larves/structured-temporal-convolution.git", branch="light-stable-for-tagging"} torch = "^1.11.0" numpy = "^1.19.3" diff --git a/src/maggotuba/models/predict_model.py b/src/maggotuba/models/predict_model.py index 129e070..dd83061 100644 --- a/src/maggotuba/models/predict_model.py +++ b/src/maggotuba/models/predict_model.py @@ -14,11 +14,10 @@ def predict_model(backend, **kwargs): The `predict_model.py` script is required. """ - # in the present case, as make_dataset.py and build_features.py do nothing, - # we pick files in `data/raw` - input_files = backend.list_input_files() - # we could go and pick files in `data/interim` as well: - input_files += backend.list_interim_files() + # we pick files in `data/interim` if any, otherwise in `data/raw` + input_files = backend.list_interim_files() + if not input_files: + input_files = backend.list_input_files() assert 0 < len(input_files) # initialize output labels input_files, labels = backend.prepare_labels(input_files) diff --git a/src/maggotuba/models/train_model.py b/src/maggotuba/models/train_model.py index 9def508..92dcf44 100644 --- a/src/maggotuba/models/train_model.py +++ b/src/maggotuba/models/train_model.py @@ -4,12 +4,12 @@ from maggotuba.models.trainers import MaggotTrainer, MultiscaleMaggotTrainer, ne import json import glob -def train_model(backend, layers=1, pretrained_model_instance="default"): +def train_model(backend, layers=1, pretrained_model_instance="default", **kwargs): # make_dataset generated or moved the larva_dataset file into data/interim/{instance}/ #larva_dataset_file = backend.list_interim_files("larva_dataset_*.hdf5") # recursive larva_dataset_file = glob.glob(str(backend.interim_data_dir() / "larva_dataset_*.hdf5")) # not recursive (faster) assert len(larva_dataset_file) == 1 - dataset = LarvaDataset(larva_dataset_file[0], new_generator()) + dataset = LarvaDataset(larva_dataset_file[0], new_generator(), **kwargs) labels = dataset.labels assert 0 < len(labels) labels = labels if isinstance(labels[0], str) else [s.decode() for s in labels] -- GitLab