From e7838f74dd8b8735542ec0c3b05e8ea7e3ed7525 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Fran=C3=A7ois=20Laurent?= <francois.laurent@posteo.net>
Date: Fri, 30 Jun 2023 01:28:41 +0200
Subject: [PATCH] fine-tuning

---
 src/LarvaDatasets.jl            | 10 +++++--
 src/taggingbackends/explorer.py |  8 ++++++
 src/taggingbackends/main.py     | 52 +++++++++++++++++++++++++++-------
 3 files changed, 57 insertions(+), 13 deletions(-)

diff --git a/src/LarvaDatasets.jl b/src/LarvaDatasets.jl
index 2f6f45c..babf42d 100644
--- a/src/LarvaDatasets.jl
+++ b/src/LarvaDatasets.jl
@@ -617,7 +617,13 @@ function new_write_larva_dataset_hdf5(output_dir, input_data;
         ratiobasedsampling(selectors, min_max_ratio, prioritylabel(includeall); seed=seed)
     end
     loader = DataLoader(repo, window, index)
-    buildindex(loader; unload=true)
+    try
+        buildindex(loader; unload=true)
+    catch err
+        # most likely error message: "collection must be non-empty"
+        err isa ArgumentError && @error "Most likely cause: no time segments could be isolated"
+        rethrow()
+    end
     total_sample_size = length(loader.index)
     classcounts, _ = Dataloaders.groupby(selectors, loader.index.targetcounts)
     #
@@ -680,7 +686,7 @@ function new_write_larva_dataset_hdf5(output_dir, input_data;
         # ensure labels are ordered as provided in input;
         # see https://gitlab.pasteur.fr/nyx/TaggingBackends/-/issues/24
         h5["labels"] = labels
-        h5["label_counts"] = [classcounts[Symbol(label)] for label in labels]
+        h5["label_counts"] = [get(classcounts, Symbol(label), 0) for label in labels]
     end
     if !isnothing(frameinterval)
         attributes(g)["frame_interval"] = frameinterval
diff --git a/src/taggingbackends/explorer.py b/src/taggingbackends/explorer.py
index 955d71b..286dae4 100644
--- a/src/taggingbackends/explorer.py
+++ b/src/taggingbackends/explorer.py
@@ -62,6 +62,7 @@ class BackendExplorer:
         self._build_features = None
         self._train_model = None
         self._predict_model = None
+        self._finetune_model = None
         #
         self._sandbox = sandbox
 
@@ -133,6 +134,13 @@ Cannot find any Python package in project root directory:
         if self._predict_model is not False:
             return self._predict_model
 
+    @property
+    def finetune_model(self):
+        if self._finetune_model is None:
+            self._finetune_model = self._locate_script("models", "finetune_model")
+        if self._finetune_model is not False:
+            return self._finetune_model
+
     def _locate_script(self, subpkg, basename):
         basename = basename + ".py"
         in_root_dir = self.project_dir / basename
diff --git a/src/taggingbackends/main.py b/src/taggingbackends/main.py
index 931c471..166bc9c 100644
--- a/src/taggingbackends/main.py
+++ b/src/taggingbackends/main.py
@@ -5,14 +5,18 @@ from taggingbackends.explorer import BackendExplorer, BackendExplorerDecoder, ge
 
 def help(_print=False):
     msg = """
-Usage:  tagging-backend [train|predict] --model-instance <name>
+Usage:  tagging-backend [train|predict|finetune] --model-instance <name>
+        tagging-backend [train|finetune] ... --sample-size <N>
+        tagging-backend [train|finetune] ... --balancing-strategy <S>
+        tagging-backend [train|finetune] ... --include-all <secondary-label>
+        tagging-backend [train|finetune] ... --skip-make-dataset
+        tagging-backend [train|finetune] ... --skip-build-features
+        tagging-backend [train|finetune] ... --iterations <N>
+        tagging-backend [train|finetune] ... --seed <seed>
         tagging-backend train ... --labels <labels> --class-weights <weights>
-        tagging-backend train ... --sample-size <N> --balancing-strategy <S>
         tagging-backend train ... --frame-interval <I> --window-length <T>
         tagging-backend train ... --pretrained-model-instance <name>
-        tagging-backend train ... --include-all <secondary-label>
-        tagging-backend train ... --skip-make-dataset --skip-build-features
-        tagging-backend train ... --seed <seed>
+        tagging-backend finetune ... --original-model-instance <name>
         tagging-backend predict ... --make-dataset --build-features
         tagging-backend predict ... --sandbox <token>
         tagging-backend --help
@@ -75,6 +79,17 @@ truly skip this step; the corresponding module is not loaded.
 Since version 0.10, `predict` makes `--skip-make-dataset` and
 `--skip-build-features` the default behavior. As a counterpart, it admits
 arguments `--make-dataset` and `--build-features`.
+
+New in version 0.14: the `finetune` switch loads a trained model and further
+trains it on a similar dataset. The class labels and weights are inherited
+from the trained model. Each backend is responsible for storing this
+information; MaggotUBA, for example, does not store the class weights.
+
+Fine-tuning is typically used when the (re-)training dataset is small (and
+similar enough to the original training data). As a consequence, some classes
+may be underrepresented. While entirely missing classes are properly ignored,
+data points of underrepresented classes should be explicitly unlabelled so
+that they, too, are excluded from the (re-)training dataset.
 """
     if _print:
         print(msg)
@@ -89,7 +104,8 @@ def main(fun=None):
             help(True)
             #sys.exit("too few input arguments; subcommand expected: 'train' or 'predict'")
             return
-        train_or_predict = sys.argv[1]
+        task = sys.argv[1]
+        original_model_instance = None
         project_dir = model_instance = None
         input_files, labels = [], []
         sample_size = window_length = frame_interval = None
@@ -140,6 +156,9 @@ def main(fun=None):
             elif sys.argv[k] == "--pretrained-model-instance":
                 k = k + 1
                 pretrained_model_instance = sys.argv[k]
+            elif sys.argv[k] == "--original-model-instance":
+                k = k + 1
+                original_model_instance = sys.argv[k]
             elif sys.argv[k] == "--sandbox":
                 k = k + 1
                 sandbox = sys.argv[k]
@@ -167,12 +186,12 @@ def main(fun=None):
         if input_files:
             for file in input_files:
                 backend.move_to_raw(file)
-        if make_dataset is None and train_or_predict == 'train':
+        if make_dataset is None and task in ('train', 'finetune'):
             make_dataset = True
-        if build_features is None and train_or_predict == 'train':
+        if build_features is None and task in ('train', 'finetune'):
             build_features = True
         if make_dataset:
-            make_dataset_kwargs = dict(labels_expected=train_or_predict == "train",
+            make_dataset_kwargs = dict(labels_expected=task in ('train', 'finetune'),
                     balancing_strategy=balancing_strategy)
             if labels:
                 make_dataset_kwargs["labels"] = labels
@@ -192,6 +211,8 @@ def main(fun=None):
                 logging.info("option --reuse-h5files is ignored in the absence of --trxmat-only")
             if pretrained_model_instance is not None:
                 make_dataset_kwargs["pretrained_model_instance"] = pretrained_model_instance
+            if original_model_instance is not None:
+                make_dataset_kwargs["original_model_instance"] = original_model_instance
             if include_all:
                 make_dataset_kwargs["include_all"] = include_all
             if seed is not None:
@@ -199,9 +220,9 @@ def main(fun=None):
             backend._run_script(backend.make_dataset, **make_dataset_kwargs)
         if build_features:
             backend._run_script(backend.build_features)
-        if train_or_predict == "predict":
+        if task == "predict":
             backend._run_script(backend.predict_model, trailing=unknown_args)
-        else:
+        elif task == 'train':
             train_kwargs = dict(balancing_strategy=balancing_strategy)
             if pretrained_model_instance:
                 train_kwargs["pretrained_model_instance"] = pretrained_model_instance
@@ -210,6 +231,13 @@ def main(fun=None):
             if seed is not None:
                 train_kwargs['seed'] = seed
             backend._run_script(backend.train_model, trailing=unknown_args, **train_kwargs)
+        elif task == 'finetune':
+            finetune_kwargs = dict(balancing_strategy=balancing_strategy)
+            if original_model_instance:
+                finetune_kwargs['original_model_instance'] = original_model_instance
+            if seed is not None:
+                finetune_kwargs['seed'] = seed
+            backend._run_script(backend.finetune_model, trailing=unknown_args, **finetune_kwargs)
     else:
         # called by make_dataset, build_features, train_model and predict_model
         backend = BackendExplorerDecoder().decode(sys.argv[1])
@@ -234,6 +262,8 @@ def main(fun=None):
                 val_ = val.split(',')
                 if val_[1:]:
                     val = val_
+            elif key == 'original_model_instance':
+                pass  # do not try to convert to number
             elif isinstance(val, str):
                 try:
                     val = int(val)
-- 
GitLab
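
For reference, the switches documented in the help text compose as follows.
A hypothetical fine-tuning invocation (the instance names are placeholders,
not taken from the patch):

    tagging-backend finetune --model-instance tuned0630 \
        --original-model-instance trained0311 --seed 100

As with `train`, `finetune` generates the dataset and builds the features by
default; `--skip-make-dataset` and `--skip-build-features` opt out of these
steps.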
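On the backend side, the new `finetune_model` property resolves a
models/finetune_model.py script within the backend package, and main.py
forwards `balancing_strategy`, `original_model_instance` and `seed` to it.
A minimal sketch of such a script, assuming backend scripts re-enter
taggingbackends.main.main as the else branch above implies; the `finetune`
name, signature and body are illustrative, not defined by this patch:

    # models/finetune_model.py -- hypothetical backend-side entry point
    from taggingbackends.main import main

    def finetune(backend, original_model_instance=None,
                 balancing_strategy=None, seed=None, **kwargs):
        # `backend` is the BackendExplorer decoded in main(); a real
        # implementation would load the original model instance and resume
        # training on the newly generated dataset
        ...

    if __name__ == "__main__":
        main(finetune)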