diff --git a/Artifacts.toml b/Artifacts.toml
index e1a3f6961d64bfa3d8f1823d8e78b68666a64e19..ef6feeb208a4afe4c6498c80d0d94e62c22377fc 100644
--- a/Artifacts.toml
+++ b/Artifacts.toml
@@ -1,3 +1,19 @@
+[20230311]
+git-tree-sha1 = "bd21d3575d436576f40116304d31f75ac1026c0f"
+lazy = true
+
+    [[20230311.download]]
+    url = "https://gitlab.pasteur.fr/nyx/artefacts/-/raw/ef0f4de04620a87c15c7855c192f76731bd603ec/MaggotUBA/20230311.tar.gz?inline=false"
+    sha256 = "9cb79c0d75e883e1bb80904fb3bacf030758cadfbd874d627c46554b5eafaf51"
+
+[20230311-0]
+git-tree-sha1 = "0790eefe8b0990622aa91b6ffdff9cbf862c983f"
+lazy = true
+
+    [[20230311-0.download]]
+    url = "https://gitlab.pasteur.fr/nyx/artefacts/-/raw/ef0f4de04620a87c15c7855c192f76731bd603ec/MaggotUBA/20230311-0.tar.gz?inline=false"
+    sha256 = "e7405cf9a6bda62422f89dcf8edd651a7b11164ac53d13d214e06b1d612802d3"
+
 [20230524-6behaviors-25-0]
 git-tree-sha1 = "361d2ecebbd74f0dbc7840f06db2aa8c2b3293c6"
 lazy = true
diff --git a/README.md b/README.md
index 8fef32e89f34a0609fc0f03bd8edd64c4f1e6c3a..418d07e78019e72287a577499f56c1cff151c037 100644
--- a/README.md
+++ b/README.md
@@ -110,6 +110,8 @@ Use the *larvatagger.jl* script of the [LarvaTagger.jl](https://gitlab.pasteur.f
 
 ### Retraining a tagger
 
+#### The `train` command
+
 A new model instance can be trained on a data repository with:
 
 ```
@@ -138,3 +140,22 @@ Training operates in two steps, first pretraining the dense-layer classifier, se
 See also the [`train_model.py`](https://gitlab.pasteur.fr/nyx/MaggotUBA-adapter/-/blob/dev/src/maggotuba/models/train_model.py) script.
 
 This generates a new sub-directory in the `models` directory of the `MaggotUBA-adapter` project, which makes the trained model discoverable for automatic tagging (*predict* command).
+
+#### The `finetune` command
+
+Alternatively, when new data are scarce but similar to the original training data, an already trained tagger can be further trained, instead of training a new tagger whose encoder only has been pretrained.
+
+This can be done with:
+```
+poetry run tagging-backend finetune --model-instance <new-tagger-name> --original-model-instance <reused-tagger-name>
+```
+
+See also the [`finetune_model.py`](https://gitlab.pasteur.fr/nyx/MaggotUBA-adapter/-/blob/dev/src/maggotuba/models/finetune_model.py) script.
+
+Beware that the behavior labels in the data must be compatible with the reused model instance. In particular, the [label mapping feature](https://gitlab.pasteur.fr/nyx/larvatagger.jl/-/issues/103) can interfere, because the predicted labels then differ from the labels the tagger actually generates in the first place, before the mapping is applied.
+If files with such mapped labels are used for fine-tuning, the `finetune` step may fail to sample a training dataset from these data.
+Indeed, `finetune` loads the pre-mapping label definition from the model files and samples instances of these labels in the data.
+
+This matters, for example, when the `20230311` tagger is used to generate a first round of predictions that are then manually corrected, with the aim of retraining (fine-tuning) a new tagger on the corrected labelled data. The corrected labels will not be compatible with the underlying classifier, so only `train` will apply.
+
+The `20230311-0` tagger may be considered instead if `finetune` is to be used for this purpose.
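As a quick check before running `finetune`, the label definition stored with the reused tagger can be compared with the labels present in the corrected data. The sketch below is illustrative only and not part of the backend API; it reads `clf_config.json` the same way `make_dataset.py` does in this patch, and the instance path and the `corrected_labels` set are placeholder values:

```
import json
from pathlib import Path

def model_labels(model_dir):
    # return the (pre-mapping) label definition stored with a tagger;
    # same lookup order as make_dataset.py
    with open(Path(model_dir) / "clf_config.json") as f:
        config = json.load(f)
    for key in ("original_behavior_labels", "behavior_labels"):
        if key in config:
            return config[key]
    raise KeyError("no label definition found in clf_config.json")

# placeholder values; adjust to the actual tagger and corrected data
expected = set(model_labels("models/20230311-0"))
corrected_labels = {"run", "cast", "stop", "hunch", "back", "roll"}

missing = expected - corrected_labels
if missing:
    print("finetune may fail to sample these labels:", sorted(missing))
```

If some expected labels are missing from the data, either revert the label mapping in the data or fall back to the `train` command.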
diff --git a/pyproject.toml b/pyproject.toml index d9f42805d6b9d0338f3e30dfe33c3d3c8ae20ff4..b032612b4713d5dd21784acbe94b0ae69b8af3f3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "MaggotUBA-adapter" -version = "0.14.0" +version = "0.15.0" description = "Interface between MaggotUBA and the Nyx tagging UI" authors = ["François Laurent"] license = "MIT" @@ -14,7 +14,7 @@ maggotuba-core = {git = "https://gitlab.pasteur.fr/nyx/MaggotUBA-core", tag = "v torch = "^1.11.0" numpy = "^1.19.3" protobuf = "3.9.2" -taggingbackends = {git = "https://gitlab.pasteur.fr/nyx/TaggingBackends", tag = "v0.13.1"} +taggingbackends = {git = "https://gitlab.pasteur.fr/nyx/TaggingBackends", tag = "v0.14"} [build-system] requires = ["poetry-core>=1.0.0"] diff --git a/scripts/make_models.jl b/scripts/make_models.jl index 1cc5d9d8cdc891d853d216c99f6940ac9f67586c..dff70ce92641cc27d32c40edb2e4816be0e314dd 100755 --- a/scripts/make_models.jl +++ b/scripts/make_models.jl @@ -10,7 +10,7 @@ using LazyArtifacts projectdir = dirname(Base.active_project()) function pretrained_models(name) - artifact = @artifact_str("$name") + artifact = @artifact_str(name) src = @artifact_str("$name/pretrained_models") dst = mkpath(joinpath(projectdir, "pretrained_models")) for filename in readdir(src; join=false) @@ -19,9 +19,19 @@ function pretrained_models(name) rm(artifact; recursive=true, force=true) end +function models(name) + src = artifact = @artifact_str(name) + dst = mkpath(joinpath(projectdir, "models")) + for filename in readdir(src; join=false) + mv(joinpath(src, filename), joinpath(dst, filename)) + end + rm(artifact; recursive=true, force=true) + @assert isdir(joinpath(dst, name)) +end + function main(args=ARGS) if isempty(args) - print("missing model name") + print("missing pretrained model name") exit() elseif length(args) == 1 && args[1] == "default" args = ["20230524-6behaviors-25", "20230524-hunch-25", "20230524-roll-25"] @@ -32,6 +42,9 @@ function main(args=ARGS) for arg in args pretrained_models(arg) end + for arg in ("20230311", "20230311-0") + models(arg) + end end main() diff --git a/src/maggotuba/data/make_dataset.py b/src/maggotuba/data/make_dataset.py index 9e44127fd5af7510b91ffcdea790a592f5a7986c..4687ff713c4543e648ec3bd6cc2f45fbee212161 100644 --- a/src/maggotuba/data/make_dataset.py +++ b/src/maggotuba/data/make_dataset.py @@ -1,6 +1,7 @@ import glob import pathlib import json +import sys def make_dataset(backend, labels_expected=False, trxmat_only=False, balancing_strategy='maggotuba', @@ -17,8 +18,12 @@ def make_dataset(backend, labels_expected=False, trxmat_only=False, else: if 'frame_interval' not in kwargs: - autoencoder_config = glob.glob(str(backend.project_dir / "pretrained_models" / pretrained_model_instance / "*config.json")) - with open(autoencoder_config[0], "r") as f: + # load argument `frame_interval` + if 'original_model_instance' in kwargs: + autoencoder_config = str(backend.project_dir / 'models' / kwargs['original_model_instance'] / 'autoencoder_config.json') + else: + autoencoder_config = glob.glob(str(backend.project_dir / "pretrained_models" / pretrained_model_instance / "*config.json"))[0] + with open(autoencoder_config, "r") as f: config = json.load(f) try: frame_interval = config['frame_interval'] @@ -27,14 +32,34 @@ def make_dataset(backend, labels_expected=False, trxmat_only=False, else: kwargs['frame_interval'] = frame_interval + if 'original_model_instance' in kwargs: + original_instance = kwargs.pop('original_model_instance') + # load 
parameter `window_length`
+            enc_config = str(backend.project_dir / 'models' / original_instance / 'autoencoder_config.json')
+            with open(enc_config, 'r') as f:
+                config = json.load(f)
+            kwargs['window_length'] = int(config['len_traj'])
+            # load parameter `labels`
+            clf_config = str(backend.project_dir / 'models' / original_instance / 'clf_config.json')
+            with open(clf_config, 'r') as f:
+                config = json.load(f)
+            for key in ('original_behavior_labels', 'behavior_labels'):
+                try:
+                    labels = config[key]
+                except KeyError:
+                    pass
+                else:
+                    # note kwargs['labels'] may be defined, but we dismiss
+                    # the input argument, because we need to preserve the
+                    # order of the labels (the class indices)
+                    kwargs['labels'] = labels
+                    break
+
     print("generating a larva_dataset file...")
     # generate a larva_dataset_*.hdf5 file in data/interim/{instance}/
-    if False:#trxmat_only:
-        out = backend.compile_trxmat_database(backend.raw_data_dir(), **kwargs)
-    else:
-        out = backend.generate_dataset(backend.raw_data_dir(),
-                balance=isinstance(balancing_strategy, str) and balancing_strategy.lower() == 'maggotuba',
-                **kwargs)
+    balance = isinstance(balancing_strategy, str) and balancing_strategy.lower() == 'maggotuba'
+    out = backend.generate_dataset(backend.raw_data_dir(),
+            balance=balance, **kwargs)
 
     print(f"larva_dataset file generated: {out}")
 
diff --git a/src/maggotuba/models/finetune_model.py b/src/maggotuba/models/finetune_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..bfcb1978e8904668ac8e37aea59a822982c46059
--- /dev/null
+++ b/src/maggotuba/models/finetune_model.py
@@ -0,0 +1,53 @@
+from taggingbackends.data.labels import Labels
+import logging
+from taggingbackends.data.dataset import LarvaDataset
+from maggotuba.models.trainers import MaggotTrainer, new_generator, enforce_reproducibility, fork_model
+import glob
+
+def finetune_model(backend, original_model_instance="default",
+                   subsets=(1, 0, 0), seed=None, iterations=100, **kwargs):
+    # list training data files;
+    # we actually expect a single larva_dataset file that make_dataset generated
+    # or moved into data/interim/{instance}/
+    #larva_dataset_file = backend.list_interim_files("larva_dataset_*.hdf5") # this one is recursive
+    larva_dataset_file = glob.glob(str(backend.interim_data_dir() / "larva_dataset_*.hdf5")) # this other one is not recursive
+    assert len(larva_dataset_file) == 1
+
+    # instantiate a LarvaDataset object, which is similar to a PyTorch DataLoader
+    # and can initialize a Labels object
+    # note: subsets=(1, 0, 0) => all data are training data; no validation or test subsets
+    dataset = LarvaDataset(larva_dataset_file[0], new_generator(seed),
+            subsets=subsets, **kwargs)
+
+    # initialize a Labels object
+    labels = dataset.labels
+    assert 0 < len(labels)
+
+    # the labels may be bytes objects; convert to str
+    labels = labels if isinstance(labels[0], str) else [s.decode() for s in labels]
+
+    # could be moved into `make_trainer`, but we need it to access the generator
+    enforce_reproducibility(dataset.generator)
+
+    # fork the original model
+    fork_model(backend, original_model_instance)
+    logging.info("model forked")
+
+    # load the forked model
+    config_file = backend.list_model_files('clf_config.json')[0]
+    model = MaggotTrainer(config_file)
+    model.n_pretraining_iter = 0
+    model.n_finetuning_iter = iterations
+
+    # fine-tune the model on the loaded dataset
+    model.train(dataset)
+
+    # save the model
+    print(f"saving model \"{backend.model_instance}\"")
+    model.save()
+
+
+from taggingbackends.main import 
main + +if __name__ == "__main__": + main(finetune_model) diff --git a/src/maggotuba/models/modules.py b/src/maggotuba/models/modules.py index 19912ad45dcace6404c2c42c0e63b27df2aa14ee..62ca0707bfc0f7d0b9ba8aa5b8721f2b54809565 100644 --- a/src/maggotuba/models/modules.py +++ b/src/maggotuba/models/modules.py @@ -227,8 +227,7 @@ class MaggotEncoder(MaggotModule): """ Determine whether the encoder was pretrained as part of a MaggotUBA - autoencoder, or only initialized in the context of testing the benefit of - using a pretrained encoder. + autoencoder, or only initialized, with no pretraining. This is to be distinguished from the MaggotEncoder/PretrainedMaggotEncoder classes that instead represent the different states *after* and *before* diff --git a/src/maggotuba/models/trainers.py b/src/maggotuba/models/trainers.py index debb78c8c77dbbe767ebc364fd4ca3073faca7be..bc1b1db75aca963396dc48a8e1fe28647f734f02 100644 --- a/src/maggotuba/models/trainers.py +++ b/src/maggotuba/models/trainers.py @@ -7,6 +7,7 @@ from taggingbackends.features.skeleton import interpolate from taggingbackends.explorer import BackendExplorer, check_permissions import logging import json +import re """ This model borrows the pre-trained MaggotUBA encoder, substitute a dense layer @@ -158,34 +159,41 @@ class MaggotTrainer: raise ValueError(f"the dataset can provide segments of up to {dataset.window_length} time points") dataset._mask = slice(midpoint - before, midpoint + after) - def train(self, dataset): - self.prepare_dataset(dataset) + def init_model_for_training(self, dataset): kwargs = {} if dataset.class_weights is not None: kwargs['weight'] = torch.from_numpy(dataset.class_weights.astype(np.float32)).to(self.device) - model = self.model - model.train() # this only sets the model in training mode (enables gradients) - model.to(self.device) + self.model.train() # this only sets the model in training mode (enables gradients) + self.model.to(self.device) criterion = nn.CrossEntropyLoss(**kwargs) + return criterion + + def _pretrain_classifier(self): + model = self.model + return model.n_pretraining_iter > 0 and model.encoder.was_pretrained() + + def pretrain_classifier(self, criterion, dataset): + model = self.model + grad_clip = self.config['grad_clip'] + optimizer = torch.optim.Adam(model.clf.parameters()) + print("pre-training the classifier...") + for step in range(model.n_pretraining_iter): + optimizer.zero_grad() + # TODO: add an option for renormalizing the input + data, expected = self.draw(dataset) + predicted = self.forward(data, train=True) + loss = criterion(predicted, expected) + loss.backward() + nn.utils.clip_grad_norm_(model.clf.parameters(), grad_clip) + optimizer.step() + + def finetune(self, criterion, dataset): + model = self.model grad_clip = self.config['grad_clip'] - # pre-train the classifier with static encoder weights - if model.encoder.was_pretrained(): - optimizer = torch.optim.Adam(model.clf.parameters()) - print("pre-training the classifier...") - for step in range(self.model.n_pretraining_iter): - optimizer.zero_grad() - # TODO: add an option for renormalizing the input - data, expected = self.draw(dataset) - predicted = self.forward(data, train=True) - loss = criterion(predicted, expected) - loss.backward() - nn.utils.clip_grad_norm_(model.clf.parameters(), grad_clip) - optimizer.step() - # fine-tune both the encoder and the classifier optimizer = torch.optim.Adam(model.parameters()) - print(("fine-tuning" if model.encoder.was_pretrained() else "training") + \ + print(("fine-tuning" if 
self._pretrain_classifier() else "training") + \ " the encoder and classifier...") - for step in range(self.model.n_finetuning_iter): + for step in range(model.n_finetuning_iter): optimizer.zero_grad() data, expected = self.draw(dataset) predicted = self.forward(data, train=True) @@ -193,7 +201,15 @@ class MaggotTrainer: loss.backward() nn.utils.clip_grad_norm_(model.parameters(), grad_clip) optimizer.step() - # + + def train(self, dataset): + self.prepare_dataset(dataset) + criterion = self.init_model_for_training(dataset) + # pre-train the classifier with static encoder weights + if self._pretrain_classifier(): + self.pretrain_classifier(criterion, dataset) + # fine-tune both the encoder and the classifier + self.finetune(criterion, dataset) return self def draw(self, dataset, subset="train"): @@ -245,6 +261,22 @@ class MaggotTrainer: def root_dir(self, dir): self.model.root_dir = dir + @property + def n_pretraining_iter(self): + return self.model.n_pretraining_iter + + @n_pretraining_iter.setter + def n_pretraining_iter(self, n): + self.model.clf.config['pretraining_iter'] = n + + @property + def n_finetuning_iter(self): + return self.model.n_finetuning_iter + + @n_finetuning_iter.setter + def n_finetuning_iter(self, n): + self.model.clf.config['finetuning_iter'] = n + def new_generator(seed=None): generator = torch.Generator('cpu') if seed == 'random': return generator @@ -400,6 +432,36 @@ def import_pretrained_models(backend, model_instances): config_files.append(config_file) return config_files +""" +Copy a model instance under another instance name. +""" +def fork_model(backend, src_instance): + srcdir = backend.model_dir(src_instance, False) + dstdir = backend.model_dir() + config_files = [] + pattern = f"models/{src_instance}" + replacement = f"models/{backend.model_instance}" + for srcfile in srcdir.iterdir(): + if not srcfile.is_file(): + continue + dstfile = dstdir / srcfile.name + if srcfile.name.endswith('config.json'): + with open(srcfile) as f: + config = json.load(f) + for element, value in config.items(): + if isinstance(value, str): + value = re.sub(pattern, replacement, value) + config[element] = value + with open(dstfile, 'w') as f: + json.dump(config, f, indent=2) + config_files.append(srcfile) + else: + with open(srcfile, 'rb') as i: + with open(dstfile, 'wb') as o: + o.write(i.read()) + check_permissions(dstfile) + return config_files + # Julia functions def searchsortedfirst(xs, x):
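For reference, the path rewriting that `fork_model` applies to the `*config.json` files can be pictured with the following standalone sketch; the instance names and the `log_dir` entry are made up for illustration and do not reflect the actual config layout:

```
import json
import re

src_instance = "20230311-0"           # reused tagger
new_instance = "my-finetuned-tagger"  # hypothetical new instance name

pattern = f"models/{src_instance}"
replacement = f"models/{new_instance}"

# toy config with one path-like entry (the key name is illustrative)
config = {
    "log_dir": f"models/{src_instance}",
    "n_iterations": 1000,
}

# rewrite every string value, as fork_model does for each *config.json file
for key, value in config.items():
    if isinstance(value, str):
        config[key] = re.sub(pattern, replacement, value)

print(json.dumps(config, indent=2))
# "log_dir" now reads "models/my-finetuned-tagger"
```

Non-string values are left untouched, and files other than the config files are copied byte for byte.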