diff --git a/src/maggotuba/models/modules.py b/src/maggotuba/models/modules.py
index 8ecfab77c40c894bd3c49ff0a9b394df5e956b4c..7f1a757b8b3d30baa3d6e7ac0b44a47b0de8dbe5 100644
--- a/src/maggotuba/models/modules.py
+++ b/src/maggotuba/models/modules.py
@@ -122,9 +122,20 @@ class MaggotEncoders(nn.Module):
             paths = path
         self.encoders = [cls(path, cfgfile, **kwargs) for path in paths]
 
+    def __iter__(self):
+        return iter(self.encoders)
+
     def forward(self, x):
         return torch.cat([encoder(x) for encoder in self.encoders])
 
+    @property
+    def cfgfilepaths(self):
+        return [enc.cfgfilepath for enc in self.encoders]
+
+    @property
+    def ptfilepaths(self):
+        return [enc.ptfilepath for enc in self.encoders]
+
     def save_config(self, cfgfile=None):
         for encoder in self.encoders:
             encoder.save_config(cfgfile)
@@ -249,8 +260,33 @@ class SupervisedMaggot(nn.Module):
         return self.clf(self.encoder(x))
 
     def save(self):
-        self.encoder.save()
-        self.clf.config["autoencoder_config"] = str(self.encoder.cfgfilepath)
-        self.clf.config["enc_path"] = str(self.encoder.ptfilepath)
-        self.clf.save()
+        enc, clf = self.encoder, self.clf
+        enc.save()
+        clf.config["autoencoder_config"] = str(enc.cfgfilepath)
+        clf.config["enc_path"] = str(enc.ptfilepath)
+        clf.save()
+
+class MultiscaleSupervisedMaggot(nn.Module):
+    def __init__(self, cfgfilepath, behaviors=[], n_layers=1):
+        super().__init__()
+        if behaviors: # the model is only pre-trained
+            self.encoders = MaggotEncoders(cfgfilepath, cls=PretrainedMaggotEncoder)
+            # MaggotEncoders does not support indexing; take the first encoder
+            path = next(iter(self.encoders)).path.parent
+            n_latent_features = sum(enc.config["dim_latent"] for enc in self.encoders)
+            self.clf = MaggotClassifier(path / "clf_config.json",
+                                        behaviors, n_latent_features, n_layers)
+        else: # the model has been retrained
+            self.clf = MaggotClassifier(cfgfilepath)
+            self.encoders = MaggotEncoders(self.clf.config["autoencoder_config"],
+                                           ptfile=self.clf.config["enc_path"])
+
+    def forward(self, x):
+        return self.clf(self.encoders(x))
+
+    def save(self):
+        enc, clf = self.encoders, self.clf
+        enc.save()
+        clf.config["autoencoder_config"] = [str(p) for p in enc.cfgfilepaths]
+        clf.config["enc_path"] = [str(p) for p in enc.ptfilepaths]
+        clf.save()
diff --git a/src/maggotuba/models/predict_model.py b/src/maggotuba/models/predict_model.py
index 7af062ad4d858726db3d5e292a43f5c82e7cff91..129e0706cb2608a58d02ccf1273fff3fe699461f 100644
--- a/src/maggotuba/models/predict_model.py
+++ b/src/maggotuba/models/predict_model.py
@@ -1,6 +1,6 @@
 from taggingbackends.data.labels import Labels
 from taggingbackends.features.skeleton import get_5point_spines
-from maggotuba.models.denselayer import DenseLayer, new_generator
+from maggotuba.models.trainers import MaggotTrainer, MultiscaleMaggotTrainer, new_generator
 import numpy as np
 
 def predict_model(backend, **kwargs):
@@ -26,9 +26,18 @@
     # load the model
     model_files = backend.list_model_files()
     config_file = [file for file in model_files if file.name.endswith("config.json")]
-    if 1 < len(config_file):
-        config_file = [file for file in config_file if file.name.endswith("clf_config.json")]
-    model = DenseLayer(config_file[-1])
+    n_config_files = len(config_file)
+    assert 1 < n_config_files
+    config_file = [file
+                   for file in config_file
+                   if file.name.endswith("clf_config.json")
+                   and file.parent == backend.model_dir()]
+    assert len(config_file) == 1
+    config_file = config_file[-1]
+    if 2 < n_config_files:
+        model = MultiscaleMaggotTrainer(config_file)
+    else:
+        model = MaggotTrainer(config_file)
     # labels.labelspec = model.clf_config["behavior_labels"]
model.clf_config["behavior_labels"] # diff --git a/src/maggotuba/models/train_model.py b/src/maggotuba/models/train_model.py index 631ba7c54ba005ec0d3b38114cdd48891e6400b4..9def508f6820bbd90549227b4ff00ff39c2bdeb2 100644 --- a/src/maggotuba/models/train_model.py +++ b/src/maggotuba/models/train_model.py @@ -1,6 +1,6 @@ from taggingbackends.data.labels import Labels from taggingbackends.data.dataset import LarvaDataset -from maggotuba.models.denselayer import DenseLayer, new_generator +from maggotuba.models.trainers import MaggotTrainer, MultiscaleMaggotTrainer, new_generator import json import glob @@ -10,9 +10,24 @@ def train_model(backend, layers=1, pretrained_model_instance="default"): larva_dataset_file = glob.glob(str(backend.interim_data_dir() / "larva_dataset_*.hdf5")) # not recursive (faster) assert len(larva_dataset_file) == 1 dataset = LarvaDataset(larva_dataset_file[0], new_generator()) - nlabels = len(dataset.labels) - assert 0 < nlabels - # copy the pretrained model into the model instance directory + labels = dataset.labels + assert 0 < len(labels) + labels = labels if isinstance(labels[0], str) else [s.decode() for s in labels] + # copy and load the pretrained model into the model instance directory + if isinstance(pretrained_model_instance, str): + config_file = import_pretrained_model(backend, pretrained_model_instance) + model = MaggotTrainer(config_file, labels, layers) + else: + pretrained_model_instances = pretrained_model_instance + config_files = import_pretrained_models(backend, pretrained_model_instances) + model = MultiscaleMaggotTrainer(config_files, labels, layers) + # fine-tune and save the model + model.train(dataset) + model.save() + +# TODO: merge the below two functions + +def import_pretrained_model(backend, pretrained_model_instance): pretrained_autoencoder_dir = backend.project_dir / "pretrained_models" / pretrained_model_instance config_file = None for file in pretrained_autoencoder_dir.iterdir(): @@ -20,7 +35,7 @@ def train_model(backend, layers=1, pretrained_model_instance="default"): continue dst = backend.model_dir() / file.name if file.name.endswith("config.json"): - with open(str(file)) as f: + with open(file) as f: config = json.load(f) dir = backend.model_dir().relative_to(backend.project_dir) config["log_dir"] = str(dir) @@ -28,21 +43,43 @@ def train_model(backend, layers=1, pretrained_model_instance="default"): #config["project_dir"] = config["exp_folder"] = str(dir) #config["exp_name"] = backend.model_instance #config["config"] = str(dir / os.path.basename(config["config"])) - with open(str(dst), "w") as f: + with open(dst, "w") as f: json.dump(config, f, indent=2) assert config_file is None config_file = dst else: - with open(str(file), "rb") as i: - with open(str(dst), "wb") as o: + with open(file, "rb") as i: + with open(dst, "wb") as o: o.write(i.read()) - # load the pretrained model - labels = dataset.labels - labels = labels if isinstance(labels[0], str) else [s.decode() for s in labels] - model = DenseLayer(config_file, labels, layers) - # fine-tune and save the model - model.train(dataset) - model.save() + return config_file + +def import_pretrained_models(backend, model_instances): + config_files = [] + for pretrained_model_instance in model_instances: + pretrained_autoencoder_dir = backend.project_dir / "pretrained_models" / pretrained_model_instance + encoder_dir = backend.model_dir() / pretrained_model_instance + encoder_dir.mkdir(exist_ok=True) + config_file = None + for file in pretrained_autoencoder_dir.iterdir(): + if not 
+                continue
+            dst = encoder_dir / file.name
+            if file.name.endswith("config.json"):
+                with open(file) as f:
+                    config = json.load(f)
+                dir = encoder_dir.relative_to(backend.project_dir)
+                config["log_dir"] = str(dir)
+                with open(dst, "w") as f:
+                    json.dump(config, f, indent=2)
+                assert config_file is None
+                config_file = dst
+            else:
+                with open(file, "rb") as i:
+                    with open(dst, "wb") as o:
+                        o.write(i.read())
+        assert config_file is not None
+        config_files.append(config_file)
+    return config_files
 
 from taggingbackends.main import main
diff --git a/src/maggotuba/models/denselayer.py b/src/maggotuba/models/trainers.py
similarity index 91%
rename from src/maggotuba/models/denselayer.py
rename to src/maggotuba/models/trainers.py
index cf76118795842ab85a36e9b13f84b0b4c884eee6..44f38aa66000ab787ef9c007fee578b544ce2160 100644
--- a/src/maggotuba/models/denselayer.py
+++ b/src/maggotuba/models/trainers.py
@@ -3,7 +3,7 @@ import torch
 import torch.nn as nn
 from behavior_model.models.neural_nets import device
 #import behavior_model.data.utils as data_utils
-from maggotuba.models.modules import SupervisedMaggot
+from maggotuba.models.modules import SupervisedMaggot, MultiscaleSupervisedMaggot
 
 """
 This model borrows the pre-trained MaggotUBA encoder, substitute a dense layer
@@ -17,7 +17,7 @@ Several data preprocessing steps are included for use in prediction mode.
 Training the model instead relies on the readily-preprocessed data of a
 *larva_dataset hdf5* file.
 """
-class DenseLayer:
+class MaggotTrainer:
     def __init__(self, cfgfilepath, behaviors=[], n_layers=1,
                  average_body_length=None, device=device):
         self.model = SupervisedMaggot(cfgfilepath, behaviors, n_layers)
@@ -211,3 +211,19 @@
 
 def new_generator(seed=0b11010111001001101001110):
     return torch.Generator(device).manual_seed(seed)
+
+class MultiscaleMaggotTrainer(MaggotTrainer):
+    def __init__(self, cfgfilepath, behaviors=[], n_layers=1,
+                 average_body_length=None, device=device):
+        self.model = MultiscaleSupervisedMaggot(cfgfilepath, behaviors, n_layers)
+        self.average_body_length = average_body_length # usually set later
+        self.device = device
+        # check consistency
+        ref_config = self.config
+        for attr in ["batch_size", "len_traj", "optim_iter"]:
+            for enc in self.model.encoders:
+                assert enc.config[attr] == ref_config[attr]
+
+    @property
+    def config(self):
+        return next(iter(self.model.encoders)).config
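
Usage sketch: the commit above routes fine-tuning through two trainer classes. A single
pretrained encoder instance (a str) selects MaggotTrainer, while a list of instances
selects MultiscaleMaggotTrainer, whose per-encoder latent features are concatenated to
feed the classifier. A minimal illustration of the dispatch in train_model, assuming a
taggingbackends `backend` object; the list entries below are hypothetical instance names,
not part of the commit:

    from maggotuba.models.train_model import train_model

    # a single pretrained encoder (str) takes the MaggotTrainer branch
    train_model(backend, layers=1, pretrained_model_instance="default")

    # a list of pretrained encoders takes the MultiscaleMaggotTrainer branch;
    # the isinstance(pretrained_model_instance, str) check selects between the two
    train_model(backend, layers=1,
                pretrained_model_instance=["scale_a", "scale_b"])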