diff --git a/.gitignore b/.gitignore index 535f53400b9a540a1a4d007f733b5ec2b33a63d2..96d12431db3c9cd6b2188e8b535b2c8f8e2b58c0 100644 --- a/.gitignore +++ b/.gitignore @@ -20,8 +20,10 @@ poetry.lock .env env/ -# exclude data from source control by default +# exclude data and models from source control by default /data/ +/models/ +/pretrained_models/ # Visual Studio Code .vscode/ diff --git a/pyproject.toml b/pyproject.toml index dc095dae7e2220542cc46dc7febdee0d26de3639..bc6911401aab3336e2026ac7574077b188400fb5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,7 +9,7 @@ packages = [ ] [tool.poetry.dependencies] -python = "^3.8,<3.11" +python = "^3.8,<3.10" structured-temporal-convolution = {git = "git@gitlab.pasteur.fr:les-larves/structured-temporal-convolution.git", branch="light-stable-for-tagging"} torch = "^1.11.0" numpy = "^1.19.3" diff --git a/src/maggotuba/models/predict_model.py b/src/maggotuba/models/predict_model.py index 129e0706cb2608a58d02ccf1273fff3fe699461f..dd8306193d52f16705153f110f367f4e4ff84d58 100644 --- a/src/maggotuba/models/predict_model.py +++ b/src/maggotuba/models/predict_model.py @@ -14,11 +14,10 @@ def predict_model(backend, **kwargs): The `predict_model.py` script is required. """ - # in the present case, as make_dataset.py and build_features.py do nothing, - # we pick files in `data/raw` - input_files = backend.list_input_files() - # we could go and pick files in `data/interim` as well: - input_files += backend.list_interim_files() + # we pick files in `data/interim` if any, otherwise in `data/raw` + input_files = backend.list_interim_files() + if not input_files: + input_files = backend.list_input_files() assert 0 < len(input_files) # initialize output labels input_files, labels = backend.prepare_labels(input_files) diff --git a/src/maggotuba/models/train_model.py b/src/maggotuba/models/train_model.py index 9def508f6820bbd90549227b4ff00ff39c2bdeb2..92dcf4452066449f5cb5b9c8cf4726562dedde30 100644 --- a/src/maggotuba/models/train_model.py +++ b/src/maggotuba/models/train_model.py @@ -4,12 +4,12 @@ from maggotuba.models.trainers import MaggotTrainer, MultiscaleMaggotTrainer, ne import json import glob -def train_model(backend, layers=1, pretrained_model_instance="default"): +def train_model(backend, layers=1, pretrained_model_instance="default", **kwargs): # make_dataset generated or moved the larva_dataset file into data/interim/{instance}/ #larva_dataset_file = backend.list_interim_files("larva_dataset_*.hdf5") # recursive larva_dataset_file = glob.glob(str(backend.interim_data_dir() / "larva_dataset_*.hdf5")) # not recursive (faster) assert len(larva_dataset_file) == 1 - dataset = LarvaDataset(larva_dataset_file[0], new_generator()) + dataset = LarvaDataset(larva_dataset_file[0], new_generator(), **kwargs) labels = dataset.labels assert 0 < len(labels) labels = labels if isinstance(labels[0], str) else [s.decode() for s in labels]