diff --git a/src/maggotuba/data/make_dataset.py b/src/maggotuba/data/make_dataset.py
index 9e44127fd5af7510b91ffcdea790a592f5a7986c..4687ff713c4543e648ec3bd6cc2f45fbee212161 100644
--- a/src/maggotuba/data/make_dataset.py
+++ b/src/maggotuba/data/make_dataset.py
@@ -1,6 +1,7 @@
 import glob
 import pathlib
 import json
+import sys
 
 def make_dataset(backend, labels_expected=False, trxmat_only=False,
                  balancing_strategy='maggotuba',
@@ -17,8 +18,12 @@
     else:
         if 'frame_interval' not in kwargs:
-            autoencoder_config = glob.glob(str(backend.project_dir / "pretrained_models" / pretrained_model_instance / "*config.json"))
-            with open(autoencoder_config[0], "r") as f:
+            # load argument `frame_interval`
+            if 'original_model_instance' in kwargs:
+                autoencoder_config = str(backend.project_dir / 'models' / kwargs['original_model_instance'] / 'autoencoder_config.json')
+            else:
+                autoencoder_config = glob.glob(str(backend.project_dir / "pretrained_models" / pretrained_model_instance / "*config.json"))[0]
+            with open(autoencoder_config, "r") as f:
                 config = json.load(f)
             try:
                 frame_interval = config['frame_interval']
@@ -27,14 +32,34 @@
             else:
                 kwargs['frame_interval'] = frame_interval
 
+    if 'original_model_instance' in kwargs:
+        original_instance = kwargs.pop('original_model_instance')
+        # load parameter `window_length`
+        enc_config = str(backend.project_dir / 'models' / original_instance / 'autoencoder_config.json')
+        with open(enc_config, 'r') as f:
+            config = json.load(f)
+        kwargs['window_length'] = int(config['len_traj'])
+        # load parameter `labels`
+        clf_config = str(backend.project_dir / 'models' / original_instance / 'clf_config.json')
+        with open(clf_config, 'r') as f:
+            config = json.load(f)
+        for key in ('original_behavior_labels', 'behavior_labels'):
+            try:
+                labels = config[key]
+            except KeyError:
+                pass
+            else:
+                # note: kwargs['labels'] may already be defined, but we discard
+                # the input argument, because we need to preserve the original
+                # order of the labels (the class indices)
+                kwargs['labels'] = labels
+                break
+
     print("generating a larva_dataset file...")
     # generate a larva_dataset_*.hdf5 file in data/interim/{instance}/
-    if False:#trxmat_only:
-        out = backend.compile_trxmat_database(backend.raw_data_dir(), **kwargs)
-    else:
-        out = backend.generate_dataset(backend.raw_data_dir(),
-                balance=isinstance(balancing_strategy, str) and balancing_strategy.lower() == 'maggotuba',
-                **kwargs)
+    balance = isinstance(balancing_strategy, str) and balancing_strategy.lower() == 'maggotuba'
+    out = backend.generate_dataset(backend.raw_data_dir(),
+                                   balance=balance, **kwargs)
     print(f"larva_dataset file generated: {out}")
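
For reference, the `original_model_instance` branch above expects the original instance to keep its configuration under models/{instance}/autoencoder_config.json and models/{instance}/clf_config.json. A minimal sketch of the parameters it pulls out, assuming that layout (the helper name and the explicit `project_dir` argument are illustrative, not part of the patch):

    import json
    import pathlib

    def load_finetuning_params(project_dir, instance):
        # hypothetical helper mirroring the branch above
        model_dir = pathlib.Path(project_dir) / 'models' / instance
        with open(model_dir / 'autoencoder_config.json') as f:
            enc = json.load(f)
        with open(model_dir / 'clf_config.json') as f:
            clf = json.load(f)
        params = dict(frame_interval=enc.get('frame_interval'),
                      window_length=int(enc['len_traj']))
        # keep the original label order so that class indices are preserved
        for key in ('original_behavior_labels', 'behavior_labels'):
            if key in clf:
                params['labels'] = clf[key]
                break
        return params
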
diff --git a/src/maggotuba/models/finetune_model.py b/src/maggotuba/models/finetune_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..bfcb1978e8904668ac8e37aea59a822982c46059
--- /dev/null
+++ b/src/maggotuba/models/finetune_model.py
@@ -0,0 +1,53 @@
+from taggingbackends.data.labels import Labels
+import logging
+from taggingbackends.data.dataset import LarvaDataset
+from maggotuba.models.trainers import MaggotTrainer, new_generator, enforce_reproducibility, fork_model
+import glob
+
+def finetune_model(backend, original_model_instance="default",
+                   subsets=(1, 0, 0), seed=None, iterations=100, **kwargs):
+    # list training data files;
+    # we actually expect a single larva_dataset file that make_dataset generated
+    # or moved into data/interim/{instance}/
+    #larva_dataset_file = backend.list_interim_files("larva_dataset_*.hdf5") # this one is recursive
+    larva_dataset_file = glob.glob(str(backend.interim_data_dir() / "larva_dataset_*.hdf5")) # this other one is not recursive
+    assert len(larva_dataset_file) == 1
+
+    # instantiate a LarvaDataset object, that is similar to a PyTorch DataLoader
+    # and can initialize a Labels object
+    # note: subsets=(1, 0, 0) => all data are training data; no validation or test subsets
+    dataset = LarvaDataset(larva_dataset_file[0], new_generator(seed),
+                           subsets=subsets, **kwargs)
+
+    # initialize a Labels object
+    labels = dataset.labels
+    assert 0 < len(labels)
+
+    # the labels may be bytes objects; convert to str
+    labels = labels if isinstance(labels[0], str) else [s.decode() for s in labels]
+
+    # could be moved into `make_trainer`, but we need it to access the generator
+    enforce_reproducibility(dataset.generator)
+
+    # fork the original model
+    fork_model(backend, original_model_instance)
+    logging.info("model forked")
+
+    # load the forked model
+    config_file = backend.list_model_files('clf_config.json')[0]
+    model = MaggotTrainer(config_file)
+    model.n_pretraining_iter = 0
+    model.n_finetuning_iter = iterations
+
+    # fine-tune the model on the loaded dataset
+    model.train(dataset)
+
+    # save the model
+    print(f"saving model \"{backend.model_instance}\"")
+    model.save()
+
+
+from taggingbackends.main import main
+
+if __name__ == "__main__":
+    main(finetune_model)
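
The new module above is meant to be used together with the make_dataset change: a minimal sketch, assuming `backend` is a taggingbackends BackendExplorer already bound to the new (destination) model instance, and with an arbitrary iteration count:

    from maggotuba.data.make_dataset import make_dataset
    from maggotuba.models.finetune_model import finetune_model

    # build data/interim/{instance}/larva_dataset_*.hdf5, reusing the window
    # length and label order of the original model
    make_dataset(backend, original_model_instance='default')
    # fork models/default into the current instance and fine-tune it
    finetune_model(backend, original_model_instance='default',
                   iterations=1000)
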
diff --git a/src/maggotuba/models/modules.py b/src/maggotuba/models/modules.py
index 19912ad45dcace6404c2c42c0e63b27df2aa14ee..62ca0707bfc0f7d0b9ba8aa5b8721f2b54809565 100644
--- a/src/maggotuba/models/modules.py
+++ b/src/maggotuba/models/modules.py
@@ -227,8 +227,7 @@
         """
         Determine whether the encoder was pretrained as part of a MaggotUBA
-        autoencoder, or only initialized in the context of testing the benefit of
-        using a pretrained encoder.
+        autoencoder, or only initialized, with no pretraining.
 
         This is to be distinguished from the MaggotEncoder/PretrainedMaggotEncoder
         classes that instead represent the different states *after* and *before*
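
The trainers.py changes below split the former monolithic train() into reusable phases (init_model_for_training, pretrain_classifier, finetune) and expose the iteration counts as properties. A hedged sketch of how a caller is expected to drive them on a forked model (`config_file` and `dataset` are assumed to exist, and the iteration count is illustrative):

    from maggotuba.models.trainers import MaggotTrainer

    trainer = MaggotTrainer(config_file)   # clf_config.json of the forked instance
    trainer.n_pretraining_iter = 0         # setter writes clf.config['pretraining_iter']
    trainer.n_finetuning_iter = 500        # setter writes clf.config['finetuning_iter']
    # train() = prepare_dataset + init_model_for_training
    #           + optional pretrain_classifier + finetune
    trainer.train(dataset)                 # pre-training is skipped because the count is 0
    trainer.save()
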
diff --git a/src/maggotuba/models/trainers.py b/src/maggotuba/models/trainers.py
index debb78c8c77dbbe767ebc364fd4ca3073faca7be..bc1b1db75aca963396dc48a8e1fe28647f734f02 100644
--- a/src/maggotuba/models/trainers.py
+++ b/src/maggotuba/models/trainers.py
@@ -7,6 +7,7 @@
 from taggingbackends.features.skeleton import interpolate
 from taggingbackends.explorer import BackendExplorer, check_permissions
 import logging
 import json
+import re
 """
 This model borrows the pre-trained MaggotUBA encoder, substitute a dense layer
@@ -158,34 +159,41 @@
             raise ValueError(f"the dataset can provide segments of up to {dataset.window_length} time points")
         dataset._mask = slice(midpoint - before, midpoint + after)
 
-    def train(self, dataset):
-        self.prepare_dataset(dataset)
+    def init_model_for_training(self, dataset):
         kwargs = {}
         if dataset.class_weights is not None:
             kwargs['weight'] = torch.from_numpy(dataset.class_weights.astype(np.float32)).to(self.device)
-        model = self.model
-        model.train() # this only sets the model in training mode (enables gradients)
-        model.to(self.device)
+        self.model.train() # this only sets the model in training mode (enables gradients)
+        self.model.to(self.device)
         criterion = nn.CrossEntropyLoss(**kwargs)
+        return criterion
+
+    def _pretrain_classifier(self):
+        model = self.model
+        return model.n_pretraining_iter > 0 and model.encoder.was_pretrained()
+
+    def pretrain_classifier(self, criterion, dataset):
+        model = self.model
+        grad_clip = self.config['grad_clip']
+        optimizer = torch.optim.Adam(model.clf.parameters())
+        print("pre-training the classifier...")
+        for step in range(model.n_pretraining_iter):
+            optimizer.zero_grad()
+            # TODO: add an option for renormalizing the input
+            data, expected = self.draw(dataset)
+            predicted = self.forward(data, train=True)
+            loss = criterion(predicted, expected)
+            loss.backward()
+            nn.utils.clip_grad_norm_(model.clf.parameters(), grad_clip)
+            optimizer.step()
+
+    def finetune(self, criterion, dataset):
+        model = self.model
         grad_clip = self.config['grad_clip']
-        # pre-train the classifier with static encoder weights
-        if model.encoder.was_pretrained():
-            optimizer = torch.optim.Adam(model.clf.parameters())
-            print("pre-training the classifier...")
-            for step in range(self.model.n_pretraining_iter):
-                optimizer.zero_grad()
-                # TODO: add an option for renormalizing the input
-                data, expected = self.draw(dataset)
-                predicted = self.forward(data, train=True)
-                loss = criterion(predicted, expected)
-                loss.backward()
-                nn.utils.clip_grad_norm_(model.clf.parameters(), grad_clip)
-                optimizer.step()
-        # fine-tune both the encoder and the classifier
         optimizer = torch.optim.Adam(model.parameters())
-        print(("fine-tuning" if model.encoder.was_pretrained() else "training") + \
+        print(("fine-tuning" if self._pretrain_classifier() else "training") + \
               " the encoder and classifier...")
-        for step in range(self.model.n_finetuning_iter):
+        for step in range(model.n_finetuning_iter):
             optimizer.zero_grad()
             data, expected = self.draw(dataset)
             predicted = self.forward(data, train=True)
@@ -193,7 +201,15 @@
             loss.backward()
             nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
             optimizer.step()
-        #
+
+    def train(self, dataset):
+        self.prepare_dataset(dataset)
+        criterion = self.init_model_for_training(dataset)
+        # pre-train the classifier with static encoder weights
+        if self._pretrain_classifier():
+            self.pretrain_classifier(criterion, dataset)
+        # fine-tune both the encoder and the classifier
+        self.finetune(criterion, dataset)
         return self
 
     def draw(self, dataset, subset="train"):
@@ -245,6 +261,22 @@
     def root_dir(self, dir):
         self.model.root_dir = dir
 
+    @property
+    def n_pretraining_iter(self):
+        return self.model.n_pretraining_iter
+
+    @n_pretraining_iter.setter
+    def n_pretraining_iter(self, n):
+        self.model.clf.config['pretraining_iter'] = n
+
+    @property
+    def n_finetuning_iter(self):
+        return self.model.n_finetuning_iter
+
+    @n_finetuning_iter.setter
+    def n_finetuning_iter(self, n):
+        self.model.clf.config['finetuning_iter'] = n
+
 def new_generator(seed=None):
     generator = torch.Generator('cpu')
     if seed == 'random': return generator
@@ -400,6 +432,36 @@ def import_pretrained_models(backend, model_instances):
         config_files.append(config_file)
     return config_files
 
+"""
+Copy a model instance under another instance name.
+"""
+def fork_model(backend, src_instance):
+    srcdir = backend.model_dir(src_instance, False)
+    dstdir = backend.model_dir()
+    config_files = []
+    pattern = f"models/{src_instance}"
+    replacement = f"models/{backend.model_instance}"
+    for srcfile in srcdir.iterdir():
+        if not srcfile.is_file():
+            continue
+        dstfile = dstdir / srcfile.name
+        if srcfile.name.endswith('config.json'):
+            with open(srcfile) as f:
+                config = json.load(f)
+            for element, value in config.items():
+                if isinstance(value, str):
+                    value = re.sub(pattern, replacement, value)
+                    config[element] = value
+            with open(dstfile, 'w') as f:
+                json.dump(config, f, indent=2)
+            config_files.append(srcfile)
+        else:
+            with open(srcfile, 'rb') as i:
+                with open(dstfile, 'wb') as o:
+                    o.write(i.read())
+        check_permissions(dstfile)
+    return config_files
+
 # Julia functions
 
 def searchsortedfirst(xs, x):
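
Lastly, fork_model copies every regular file of models/{src_instance} into the current instance's model directory, rewriting instance-specific paths inside any *config.json it finds. A rough usage sketch (`backend` is again assumed to be a configured BackendExplorer, and the destination instance name is made up):

    from maggotuba.models.trainers import fork_model

    # with backend.model_instance == 'my_finetuned_tagger', this copies
    # models/default/* and rewrites every "models/default" substring found in
    # string values of *config.json files to "models/my_finetuned_tagger"
    fork_model(backend, 'default')
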