diff --git a/README.md b/README.md index 885dbf8fe252af68b9971e9046591047d110fa3a..e837c70701e9402405bd9af6d57cb18e0d889710 100644 --- a/README.md +++ b/README.md @@ -28,7 +28,7 @@ It was trained on a subset of 5000 files from the t5 and t15 databases. Spines w ## Usage -For installation, see [TaggingBackends' README](https://gitlab.pasteur.fr/nyx/TaggingBackends/-/tree/dev#recommanded-installation-and-troubleshooting). +For installation, see [TaggingBackends' README](https://gitlab.pasteur.fr/nyx/TaggingBackends/-/tree/dev#recommended-installation). A MaggotUBA-based tagger is typically called using the `poetry run tagging-backend` command from the backend's project (directory tree). diff --git a/src/maggotuba/models/modules.py b/src/maggotuba/models/modules.py index 76e1bf849719446d9b53f62d87270cad6fecaf6f..011c7f93e57d75e5cc2f57aeaca4d14f920a0416 100644 --- a/src/maggotuba/models/modules.py +++ b/src/maggotuba/models/modules.py @@ -1,9 +1,11 @@ +import os from pathlib import Path import torch from torch import nn import json import functools from behavior_model.models.neural_nets import Encoder +from taggingbackends.explorer import check_permissions class MaggotModule(nn.Module): def __init__(self, path, cfgfile=None, ptfile=None): @@ -66,12 +68,14 @@ class MaggotModule(nn.Module): path = self.path / cfgfile with open(path, "w") as f: json.dump(self.config, f, indent=2) + check_permissions(path) return path def save_model(self, ptfile=None): if ptfile is None: ptfile = self.ptfile path = self.path / ptfile torch.save(self.model.state_dict(), path) + check_permissions(path) return path def save(self): @@ -81,6 +85,7 @@ class MaggotModule(nn.Module): def parameters(self, recurse=True): return self.model.parameters(recurse) + class MaggotEncoder(MaggotModule): def __init__(self, path, cfgfile=None, @@ -225,6 +230,7 @@ class DeepLinear(nn.Module): def save(self, path): torch.save(self.state_dict(), path) + check_permissions(path) class MaggotClassifier(MaggotModule): def __init__(self, path, behavior_labels=[], n_latent_features=None, diff --git a/src/maggotuba/models/predict_model.py b/src/maggotuba/models/predict_model.py index 13997ceae857135dc519fd870319fb0e37a409d6..78e4fd91b0d932f161b1bdde1d24ff8bbde8319d 100644 --- a/src/maggotuba/models/predict_model.py +++ b/src/maggotuba/models/predict_model.py @@ -26,7 +26,8 @@ def predict_model(backend, **kwargs): model_files = backend.list_model_files() config_file = [file for file in model_files if file.name.endswith("config.json")] n_config_files = len(config_file) - assert 1 < n_config_files + if n_config_files == 0: + raise RuntimeError(f"no such tagger found: {backend.model_instance}") config_file = [file for file in config_file if file.name.endswith("clf_config.json") diff --git a/src/maggotuba/models/train_model.py b/src/maggotuba/models/train_model.py index 92dcf4452066449f5cb5b9c8cf4726562dedde30..cda4a58655c0dfcdcbd6382f17461ff82c206ef0 100644 --- a/src/maggotuba/models/train_model.py +++ b/src/maggotuba/models/train_model.py @@ -1,15 +1,17 @@ from taggingbackends.data.labels import Labels from taggingbackends.data.dataset import LarvaDataset +from taggingbackends.explorer import check_permissions from maggotuba.models.trainers import MaggotTrainer, MultiscaleMaggotTrainer, new_generator import json import glob -def train_model(backend, layers=1, pretrained_model_instance="default", **kwargs): +def train_model(backend, layers=1, pretrained_model_instance="default", subsets=(1, 0, 0), **kwargs): # make_dataset generated or moved the larva_dataset file into data/interim/{instance}/ #larva_dataset_file = backend.list_interim_files("larva_dataset_*.hdf5") # recursive larva_dataset_file = glob.glob(str(backend.interim_data_dir() / "larva_dataset_*.hdf5")) # not recursive (faster) assert len(larva_dataset_file) == 1 - dataset = LarvaDataset(larva_dataset_file[0], new_generator(), **kwargs) + # subsets=(1, 0, 0) => all data are training data; no validation or test subsets + dataset = LarvaDataset(larva_dataset_file[0], new_generator(), subsets=subsets, **kwargs) labels = dataset.labels assert 0 < len(labels) labels = labels if isinstance(labels[0], str) else [s.decode() for s in labels] @@ -23,6 +25,7 @@ def train_model(backend, layers=1, pretrained_model_instance="default", **kwargs model = MultiscaleMaggotTrainer(config_files, labels, layers) # fine-tune and save the model model.train(dataset) + print(f"saving model \"{backend.model_instance}\"") model.save() # TODO: merge the below two functions @@ -51,6 +54,7 @@ def import_pretrained_model(backend, pretrained_model_instance): with open(file, "rb") as i: with open(dst, "wb") as o: o.write(i.read()) + check_permissions(dst) return config_file def import_pretrained_models(backend, model_instances): @@ -77,6 +81,7 @@ def import_pretrained_models(backend, model_instances): with open(file, "rb") as i: with open(dst, "wb") as o: o.write(i.read()) + check_permissions(dst) assert config_file is not None config_files.append(config_file) return config_files