From e7838f74dd8b8735542ec0c3b05e8ea7e3ed7525 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Fran=C3=A7ois=20Laurent?= <francois.laurent@posteo.net>
Date: Fri, 30 Jun 2023 01:28:41 +0200
Subject: [PATCH] fine-tuning
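
Add a `finetune` subcommand, alongside `train` and `predict`, that loads an
already-trained model and further trains it on a new dataset; the original
model is designated with the new `--original-model-instance` option. Also
report a clearer error when no time segments can be isolated while building
the dataset index, and tolerate classes missing from the label counts.

Example invocation (the instance names in angle brackets are placeholders):

    tagging-backend finetune --model-instance <new-instance> \
        --original-model-instance <trained-instance>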

---
 src/LarvaDatasets.jl            | 10 +++++--
 src/taggingbackends/explorer.py |  8 ++++++
 src/taggingbackends/main.py     | 51 ++++++++++++++++++++++++++-------
 3 files changed, 56 insertions(+), 13 deletions(-)

diff --git a/src/LarvaDatasets.jl b/src/LarvaDatasets.jl
index 2f6f45c..babf42d 100644
--- a/src/LarvaDatasets.jl
+++ b/src/LarvaDatasets.jl
@@ -617,7 +617,13 @@ function new_write_larva_dataset_hdf5(output_dir, input_data;
         ratiobasedsampling(selectors, min_max_ratio, prioritylabel(includeall); seed=seed)
     end
     loader = DataLoader(repo, window, index)
-    buildindex(loader; unload=true)
+    try
+        buildindex(loader; unload=true)
+    catch err
+        # most likely error message: "collection must be non-empty"
+        err isa ArgumentError && @error "Most likely cause: no time segments could be isolated"
+        rethrow()
+    end
     total_sample_size = length(loader.index)
     classcounts, _ = Dataloaders.groupby(selectors, loader.index.targetcounts)
     #
@@ -680,7 +686,7 @@ function new_write_larva_dataset_hdf5(output_dir, input_data;
             # ensure labels are ordered as provided in input;
             # see https://gitlab.pasteur.fr/nyx/TaggingBackends/-/issues/24
             h5["labels"] = labels
-            h5["label_counts"] = [classcounts[Symbol(label)] for label in labels]
+            h5["label_counts"] = [get(classcounts, Symbol(label), 0) for label in labels]
         end
         if !isnothing(frameinterval)
             attributes(g)["frame_interval"] = frameinterval
diff --git a/src/taggingbackends/explorer.py b/src/taggingbackends/explorer.py
index 955d71b..286dae4 100644
--- a/src/taggingbackends/explorer.py
+++ b/src/taggingbackends/explorer.py
@@ -62,6 +62,7 @@ class BackendExplorer:
         self._build_features = None
         self._train_model = None
         self._predict_model = None
+        self._finetune_model = None
         #
         self._sandbox = sandbox
 
@@ -133,6 +134,13 @@ Cannot find any Python package in project root directory:
         if self._predict_model is not False:
             return self._predict_model
 
+    @property
+    def finetune_model(self):
+        if self._finetune_model is None:
+            self._finetune_model = self._locate_script("models", "finetune_model")
+        if self._finetune_model is not False:
+            return self._finetune_model
+
     def _locate_script(self, subpkg, basename):
         basename = basename + ".py"
         in_root_dir = self.project_dir / basename
diff --git a/src/taggingbackends/main.py b/src/taggingbackends/main.py
index 931c471..166bc9c 100644
--- a/src/taggingbackends/main.py
+++ b/src/taggingbackends/main.py
@@ -5,14 +5,18 @@ from taggingbackends.explorer import BackendExplorer, BackendExplorerDecoder, ge
 
 def help(_print=False):
     msg = """
-Usage:  tagging-backend [train|predict] --model-instance <name>
+Usage:  tagging-backend [train|predict|finetune] --model-instance <name>
+        tagging-backend [train|finetune] ... --sample-size <N>
+        tagging-backend [train|finetune] ... --balancing-strategy <S>
+        tagging-backend [train|finetune] ... --include-all <secondary-label>
+        tagging-backend [train|finetune] ... --skip-make-dataset
+        tagging-backend [train|finetune] ... --skip-build-features
+        tagging-backend [train|finetune] ... --iterations <N>
+        tagging-backend [train|finetune] ... --seed <seed>
         tagging-backend train ... --labels <labels> --class-weights <weights>
-        tagging-backend train ... --sample-size <N> --balancing-strategy <S>
         tagging-backend train ... --frame-interval <I> --window-length <T>
         tagging-backend train ... --pretrained-model-instance <name>
-        tagging-backend train ... --include-all <secondary-label>
-        tagging-backend train ... --skip-make-dataset --skip-build-features
-        tagging-backend train ... --seed <seed>
+        tagging-backend finetune ... --original-model-instance <name>
         tagging-backend predict ... --make-dataset --build-features
         tagging-backend predict ... --sandbox <token>
         tagging-backend --help
@@ -75,6 +79,17 @@ truly skip this step; the corresponding module is not loaded.
 Since version 0.10, `predict` makes `--skip-make-dataset` and
 `--skip-build-features` the default behavior. As a counterpart, it admits
 arguments `--make-dataset` and `--build-features`.
+
+New in version 0.14: the `finetune` switch loads a trained model and further
+trains it on a similar dataset. The class labels and weights are inherited
+from the trained model, provided the backend stores this information; for
+example, MaggotUBA does not store the class weights.
+
+Fine-tuning is typically used when the (re-)training dataset is small (and
+similar enough to the original training data). As a consequence, some classes
+may be underrepresented. While entirely missing classes are properly ignored,
+data points of underrepresented classes should be explicitly unlabelled so
+that they are likewise excluded from the (re-)training dataset.
 """
     if _print:
         print(msg)
@@ -89,7 +104,7 @@ def main(fun=None):
             help(True)
             #sys.exit("too few input arguments; subcommand expected: 'train' or 'predict'")
             return
-        train_or_predict = sys.argv[1]
+        task = sys.argv[1]
         project_dir = model_instance = None
         input_files, labels = [], []
         sample_size = window_length = frame_interval = None
@@ -140,6 +155,9 @@ def main(fun=None):
             elif sys.argv[k] == "--pretrained-model-instance":
                 k = k + 1
                 pretrained_model_instance = sys.argv[k]
+            elif sys.argv[k] == "--original-model-instance":
+                k = k + 1
+                original_model_instance = sys.argv[k]
             elif sys.argv[k] == "--sandbox":
                 k = k + 1
                 sandbox = sys.argv[k]
@@ -167,12 +185,12 @@ def main(fun=None):
         if input_files:
             for file in input_files:
                 backend.move_to_raw(file)
-        if make_dataset is None and train_or_predict == 'train':
+        if make_dataset is None and task in ('train', 'finetune'):
             make_dataset = True
-        if build_features is None and train_or_predict == 'train':
+        if build_features is None and task in ('train', 'finetune'):
             build_features = True
         if make_dataset:
-            make_dataset_kwargs = dict(labels_expected=train_or_predict == "train",
+            make_dataset_kwargs = dict(labels_expected=task in ('train', 'finetune'),
                                        balancing_strategy=balancing_strategy)
             if labels:
                 make_dataset_kwargs["labels"] = labels
@@ -192,6 +210,8 @@ def main(fun=None):
                 logging.info("option --reuse-h5files is ignored in the absence of --trxmat-only")
             if pretrained_model_instance is not None:
                 make_dataset_kwargs["pretrained_model_instance"] = pretrained_model_instance
+            if original_model_instance is not None:
+                make_dataset_kwargs["original_model_instance"] = original_model_instance
             if include_all:
                 make_dataset_kwargs["include_all"] = include_all
             if seed is not None:
@@ -199,9 +219,9 @@ def main(fun=None):
             backend._run_script(backend.make_dataset, **make_dataset_kwargs)
         if build_features:
             backend._run_script(backend.build_features)
-        if train_or_predict == "predict":
+        if task == "predict":
             backend._run_script(backend.predict_model, trailing=unknown_args)
-        else:
+        elif task == 'train':
             train_kwargs = dict(balancing_strategy=balancing_strategy)
             if pretrained_model_instance:
                 train_kwargs["pretrained_model_instance"] = pretrained_model_instance
@@ -210,6 +230,13 @@ def main(fun=None):
             if seed is not None:
                 train_kwargs['seed'] = seed
             backend._run_script(backend.train_model, trailing=unknown_args, **train_kwargs)
+        elif task == 'finetune':
+            finetune_kwargs = dict(balancing_strategy=balancing_strategy)
+            if original_model_instance:
+                finetune_kwargs['original_model_instance'] = original_model_instance
+            if seed is not None:
+                finetune_kwargs['seed'] = seed
+            backend._run_script(backend.finetune_model, trailing=unknown_args, **finetune_kwargs)
     else:
         # called by make_dataset, build_features, train_model and predict_model
         backend = BackendExplorerDecoder().decode(sys.argv[1])
@@ -234,6 +261,8 @@ def main(fun=None):
                 val_ = val.split(',')
                 if val_[1:]:
                     val = val_
+            elif key == 'original_model_instance':
+                pass # do not try to convert to number
             elif isinstance(val, str):
                 try:
                     val = int(val)
-- 
GitLab