Commit 68fa827e authored by François LAURENT

implements larvatagger.jl#58 at this project level

parent cd8edef1
@@ -28,7 +28,7 @@ It was trained on a subset of 5000 files from the t5 and t15 databases. Spines w
 ## Usage
 
-For installation, see [TaggingBackends' README](https://gitlab.pasteur.fr/nyx/TaggingBackends/-/tree/dev#recommanded-installation-and-troubleshooting).
+For installation, see [TaggingBackends' README](https://gitlab.pasteur.fr/nyx/TaggingBackends/-/tree/dev#recommended-installation).
 
 A MaggotUBA-based tagger is typically called using the `poetry run tagging-backend` command from the backend's project (directory tree).
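For concreteness, a hedged illustration of that entry point: the `train` and `predict` subcommands and the `--model-instance` option follow TaggingBackends' documented command-line interface, and the tagger name below is hypothetical. The calls are driven from Python here to keep the example in the project's language; running the same commands directly in a shell from the backend's project directory is equivalent.

import subprocess

# Train a new tagger from data previously dropped in the backend's input
# directory, then tag a tracking file with it ("my_tagger" is a hypothetical
# instance name). Both commands run from the backend's project directory,
# as the README states.
subprocess.run(["poetry", "run", "tagging-backend", "train",
                "--model-instance", "my_tagger"], check=True)
subprocess.run(["poetry", "run", "tagging-backend", "predict",
                "--model-instance", "my_tagger"], check=True)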
...
+import os
 from pathlib import Path
 import torch
 from torch import nn
 import json
 import functools
 from behavior_model.models.neural_nets import Encoder
+from taggingbackends.explorer import check_permissions
 
 class MaggotModule(nn.Module):
     def __init__(self, path, cfgfile=None, ptfile=None):
@@ -66,12 +68,14 @@ class MaggotModule(nn.Module):
         path = self.path / cfgfile
         with open(path, "w") as f:
             json.dump(self.config, f, indent=2)
+        check_permissions(path)
         return path
 
     def save_model(self, ptfile=None):
         if ptfile is None: ptfile = self.ptfile
         path = self.path / ptfile
         torch.save(self.model.state_dict(), path)
+        check_permissions(path)
         return path
 
     def save(self):
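Both save_config and save_model now route the freshly written file through check_permissions, imported above from taggingbackends.explorer. Its actual behavior is defined in TaggingBackends, not in this commit; the following is only a minimal sketch of the general technique (relaxing mode bits on files that a containerized process may create with restrictive defaults):

import os
import stat

def ensure_group_access(path):
    # Illustrative only: the real check_permissions lives in
    # taggingbackends.explorer and may apply a different policy.
    mode = stat.S_IMODE(os.stat(path).st_mode)
    wanted = mode | stat.S_IRGRP | stat.S_IWGRP | stat.S_IROTH
    if mode != wanted:
        os.chmod(path, wanted)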
@@ -81,6 +85,7 @@ class MaggotModule(nn.Module):
     def parameters(self, recurse=True):
         return self.model.parameters(recurse)
 
+
 class MaggotEncoder(MaggotModule):
     def __init__(self, path,
             cfgfile=None,
@@ -225,6 +230,7 @@ class DeepLinear(nn.Module):
     def save(self, path):
         torch.save(self.state_dict(), path)
+        check_permissions(path)
 
 class MaggotClassifier(MaggotModule):
     def __init__(self, path, behavior_labels=[], n_latent_features=None,
...
@@ -26,7 +26,8 @@ def predict_model(backend, **kwargs):
     model_files = backend.list_model_files()
     config_file = [file for file in model_files if file.name.endswith("config.json")]
     n_config_files = len(config_file)
-    assert 1 < n_config_files
+    if n_config_files == 0:
+        raise RuntimeError(f"no such tagger found: {backend.model_instance}")
     config_file = [file
             for file in config_file
             if file.name.endswith("clf_config.json")
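The replaced assertion is worth a note: assert statements are stripped when Python runs with -O, and a bare AssertionError names no tagger. A standalone sketch of the difference, with a stand-in for backend.model_instance:

config_files = []             # as if no "*config.json" file matched
model_instance = "my_tagger"  # hypothetical tagger name

# Old style: silently skipped under `python -O`, and names no tagger on failure.
# assert 1 < len(config_files)

# New style: always runs and reports which tagger was requested.
try:
    if len(config_files) == 0:
        raise RuntimeError(f"no such tagger found: {model_instance}")
except RuntimeError as exc:
    print(exc)  # -> no such tagger found: my_tagger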
...
 from taggingbackends.data.labels import Labels
 from taggingbackends.data.dataset import LarvaDataset
+from taggingbackends.explorer import check_permissions
 from maggotuba.models.trainers import MaggotTrainer, MultiscaleMaggotTrainer, new_generator
 import json
 import glob
 
-def train_model(backend, layers=1, pretrained_model_instance="default", **kwargs):
+def train_model(backend, layers=1, pretrained_model_instance="default", subsets=(1, 0, 0), **kwargs):
     # make_dataset generated or moved the larva_dataset file into data/interim/{instance}/
     #larva_dataset_file = backend.list_interim_files("larva_dataset_*.hdf5") # recursive
     larva_dataset_file = glob.glob(str(backend.interim_data_dir() / "larva_dataset_*.hdf5")) # not recursive (faster)
     assert len(larva_dataset_file) == 1
-    dataset = LarvaDataset(larva_dataset_file[0], new_generator(), **kwargs)
+    # subsets=(1, 0, 0) => all data are training data; no validation or test subsets
+    dataset = LarvaDataset(larva_dataset_file[0], new_generator(), subsets=subsets, **kwargs)
     labels = dataset.labels
     assert 0 < len(labels)
     labels = labels if isinstance(labels[0], str) else [s.decode() for s in labels]
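The new subsets argument is forwarded to LarvaDataset; per the added comment, (1, 0, 0) keeps all data for training, with no validation or test subsets. A minimal sketch of fraction-based splitting under that reading (the helper is illustrative, not LarvaDataset's actual implementation):

import random

def split_dataset(samples, subsets=(1, 0, 0), seed=0):
    # Split samples into train/validation/test according to the
    # (train, val, test) fractions; stand-in for what subsets suggests.
    train_frac, val_frac, test_frac = subsets
    total = train_frac + val_frac + test_frac
    shuffled = samples[:]
    random.Random(seed).shuffle(shuffled)
    n = len(shuffled)
    n_train = round(n * train_frac / total)
    n_val = round(n * val_frac / total)
    return (shuffled[:n_train],
            shuffled[n_train:n_train + n_val],
            shuffled[n_train + n_val:])

train, val, test = split_dataset(list(range(10)))  # subsets=(1, 0, 0)
assert len(train) == 10 and not val and not test

Expressing the split as fractions rather than counts lets the same signature cover both "train on everything" and conventional train/validation/test splits.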
@@ -23,6 +25,7 @@ def train_model(backend, layers=1, pretrained_model_instance="default", **kwargs
     model = MultiscaleMaggotTrainer(config_files, labels, layers)
     # fine-tune and save the model
     model.train(dataset)
+    print(f"saving model \"{backend.model_instance}\"")
     model.save()
 
 # TODO: merge the below two functions
@@ -51,6 +54,7 @@ def import_pretrained_model(backend, pretrained_model_instance):
     with open(file, "rb") as i:
         with open(dst, "wb") as o:
             o.write(i.read())
+    check_permissions(dst)
     return config_file
 
 def import_pretrained_models(backend, model_instances):
 
@@ -77,6 +81,7 @@ def import_pretrained_models(backend, model_instances):
         with open(file, "rb") as i:
             with open(dst, "wb") as o:
                 o.write(i.read())
+        check_permissions(dst)
         assert config_file is not None
         config_files.append(config_file)
     return config_files
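In both import functions, the nested open/read/write pair is a byte-for-byte copy of the pretrained model file. A sketch of the same step using the standard library, with hypothetical paths:

import shutil
from pathlib import Path

src = Path("pretrained_models/default/config.json")  # hypothetical source
dst = Path("models/my_tagger/config.json")           # hypothetical destination

# shutil.copyfile copies bytes only (no metadata), matching the
# open/read/write pair above; the commit then normalizes the new file's
# permissions separately via check_permissions(dst).
dst.parent.mkdir(parents=True, exist_ok=True)
shutil.copyfile(src, dst)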
...