Skip to content
Snippets Groups Projects
Commit 68fa827e authored by François  LAURENT's avatar François LAURENT
Browse files

implements larvatagger.jl#58 at this project level

parent cd8edef1
No related branches found
No related tags found
No related merge requests found
......@@ -28,7 +28,7 @@ It was trained on a subset of 5000 files from the t5 and t15 databases. Spines w
## Usage
For installation, see [TaggingBackends' README](https://gitlab.pasteur.fr/nyx/TaggingBackends/-/tree/dev#recommanded-installation-and-troubleshooting).
For installation, see [TaggingBackends' README](https://gitlab.pasteur.fr/nyx/TaggingBackends/-/tree/dev#recommended-installation).
A MaggotUBA-based tagger is typically called using the `poetry run tagging-backend` command from the backend's project (directory tree).
......
import os
from pathlib import Path
import torch
from torch import nn
import json
import functools
from behavior_model.models.neural_nets import Encoder
from taggingbackends.explorer import check_permissions
class MaggotModule(nn.Module):
def __init__(self, path, cfgfile=None, ptfile=None):
......@@ -66,12 +68,14 @@ class MaggotModule(nn.Module):
path = self.path / cfgfile
with open(path, "w") as f:
json.dump(self.config, f, indent=2)
check_permissions(path)
return path
def save_model(self, ptfile=None):
if ptfile is None: ptfile = self.ptfile
path = self.path / ptfile
torch.save(self.model.state_dict(), path)
check_permissions(path)
return path
def save(self):
......@@ -81,6 +85,7 @@ class MaggotModule(nn.Module):
def parameters(self, recurse=True):
return self.model.parameters(recurse)
class MaggotEncoder(MaggotModule):
def __init__(self, path,
cfgfile=None,
......@@ -225,6 +230,7 @@ class DeepLinear(nn.Module):
def save(self, path):
torch.save(self.state_dict(), path)
check_permissions(path)
class MaggotClassifier(MaggotModule):
def __init__(self, path, behavior_labels=[], n_latent_features=None,
......
......@@ -26,7 +26,8 @@ def predict_model(backend, **kwargs):
model_files = backend.list_model_files()
config_file = [file for file in model_files if file.name.endswith("config.json")]
n_config_files = len(config_file)
assert 1 < n_config_files
if n_config_files == 0:
raise RuntimeError(f"no such tagger found: {backend.model_instance}")
config_file = [file
for file in config_file
if file.name.endswith("clf_config.json")
......
from taggingbackends.data.labels import Labels
from taggingbackends.data.dataset import LarvaDataset
from taggingbackends.explorer import check_permissions
from maggotuba.models.trainers import MaggotTrainer, MultiscaleMaggotTrainer, new_generator
import json
import glob
def train_model(backend, layers=1, pretrained_model_instance="default", **kwargs):
def train_model(backend, layers=1, pretrained_model_instance="default", subsets=(1, 0, 0), **kwargs):
# make_dataset generated or moved the larva_dataset file into data/interim/{instance}/
#larva_dataset_file = backend.list_interim_files("larva_dataset_*.hdf5") # recursive
larva_dataset_file = glob.glob(str(backend.interim_data_dir() / "larva_dataset_*.hdf5")) # not recursive (faster)
assert len(larva_dataset_file) == 1
dataset = LarvaDataset(larva_dataset_file[0], new_generator(), **kwargs)
# subsets=(1, 0, 0) => all data are training data; no validation or test subsets
dataset = LarvaDataset(larva_dataset_file[0], new_generator(), subsets=subsets, **kwargs)
labels = dataset.labels
assert 0 < len(labels)
labels = labels if isinstance(labels[0], str) else [s.decode() for s in labels]
......@@ -23,6 +25,7 @@ def train_model(backend, layers=1, pretrained_model_instance="default", **kwargs
model = MultiscaleMaggotTrainer(config_files, labels, layers)
# fine-tune and save the model
model.train(dataset)
print(f"saving model \"{backend.model_instance}\"")
model.save()
# TODO: merge the below two functions
......@@ -51,6 +54,7 @@ def import_pretrained_model(backend, pretrained_model_instance):
with open(file, "rb") as i:
with open(dst, "wb") as o:
o.write(i.read())
check_permissions(dst)
return config_file
def import_pretrained_models(backend, model_instances):
......@@ -77,6 +81,7 @@ def import_pretrained_models(backend, model_instances):
with open(file, "rb") as i:
with open(dst, "wb") as o:
o.write(i.read())
check_permissions(dst)
assert config_file is not None
config_files.append(config_file)
return config_files
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment