From ab9e7ceaa48469126b64c1c523d461e1f5d73cd8 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Fran=C3=A7ois=20Laurent?= <francois.laurent@posteo.net>
Date: Tue, 2 Jan 2024 21:22:05 +0100
Subject: [PATCH] --embeddings option

---
 scripts/larvatagger |  2 ++
 src/Taggers.jl      | 12 +++++++++++-
 src/cli.jl          |  2 ++
 src/cli_toolkit.jl  | 12 ++++++++++--
 4 files changed, 25 insertions(+), 3 deletions(-)

diff --git a/scripts/larvatagger b/scripts/larvatagger
index 6425ae8..d9a475b 100755
--- a/scripts/larvatagger
+++ b/scripts/larvatagger
@@ -34,6 +34,7 @@ Usage:
   larvatagger train <backend-path> <data-path> <model-instance> [--pretrained-model=<instance>] [--labels=<comma-separated-list>] [--sample-size=<N>] [--balancing-strategy=<strategy>] [--class-weights=<csv>] [--manual-label=<label>] [--layers=<N>] [--iterations=<N>] [--seed=<seed>] [--debug]
   larvatagger train <backend-path> <data-path> <model-instance> --fine-tune=<instance> [--balancing-strategy=<strategy>] [--manual-label=<label>] [--iterations=<N>] [--seed=<seed>] [--debug]
   larvatagger predict <backend-path> <model-instance> <data-path> [--output=<filename>] [--make-dataset] [--skip-make-dataset] [--data-isolation] [--debug]
+  larvatagger predict <backend-path> <model-instance> <data-path> --embeddings [--data-isolation] [--debug]
   larvatagger merge <input-path> <input-file> [<output-file>] [--manual-label=<label>] [--decode]
   larvatagger -V | --version
   larvatagger -h | --help
@@ -59,6 +60,7 @@ Options:
   --seed=<seed>         Seed for the backend's random number generators.
   --segment=<t0,t1>     Start and end times (included, comma-separated) for cropping and including tracks.
   --debug               Lower the logging level to DEBUG.
+  --embeddings          (MaggotUBA) Call the backend to generate embeddings instead of labels.
   --decode              Do not encode the labels into integer indices.
   --copy-labels         Replicate discrete behavior data from the input file.
   --default-label=<label>             Label all untagged data as <label>.
diff --git a/src/Taggers.jl b/src/Taggers.jl
index 38f57c2..f75bb50 100644
--- a/src/Taggers.jl
+++ b/src/Taggers.jl
@@ -2,7 +2,7 @@ module Taggers
 
 import PlanarLarvae.Formats, PlanarLarvae.Dataloaders
 
-export Tagger, isbackend, resetmodel, resetdata, train, predict, finetune
+export Tagger, isbackend, resetmodel, resetdata, train, predict, finetune, embed
 
 struct Tagger
     backend_dir::String
@@ -247,4 +247,14 @@ function finetune(tagger::Tagger; original_instance=nothing, kwargs...)
     return ret
 end
 
+function embed(tagger::Tagger; kwargs...)
+    args = ["--model-instance", tagger.model_instance]
+    if !isnothing(tagger.sandbox)
+        push!(args, "--sandbox")
+        push!(args, tagger.sandbox)
+    end
+    parsekwargs!(args, kwargs)
+    run(Cmd(`poetry run tagging-backend embed $(args)`; dir=tagger.backend_dir))
+end
+
 end # module
diff --git a/src/cli.jl b/src/cli.jl
index ed85b62..c0a25fb 100644
--- a/src/cli.jl
+++ b/src/cli.jl
@@ -14,6 +14,7 @@ Usage:
   larvatagger.jl train <backend-path> <data-path> <model-instance> [--pretrained-model=<instance>] [--labels=<comma-separated-list>] [--sample-size=<N>] [--balancing-strategy=<strategy>] [--class-weights=<csv>] [--manual-label=<label>] [--layers=<N>] [--iterations=<N>] [--seed=<seed>] [--debug]
   larvatagger.jl train <backend-path> <data-path> <model-instance> --fine-tune=<instance> [--balancing-strategy=<strategy>] [--manual-label=<label>] [--iterations=<N>] [--seed=<seed>] [--debug]
   larvatagger.jl predict <backend-path> <model-instance> <data-path> [--output=<filename>] [--make-dataset] [--skip-make-dataset] [--data-isolation] [--debug]
+  larvatagger.jl predict <backend-path> <model-instance> <data-path> --embeddings [--data-isolation] [--debug]
   larvatagger.jl merge <input-path> <input-file> [<output-file>] [--manual-label=<label>] [--decode]
   larvatagger.jl -V | --version
   larvatagger.jl -h | --help
@@ -39,6 +40,7 @@ Options:
   --seed=<seed>         Seed for the backend's random number generators.
   --segment=<t0,t1>     Start and end times (included, comma-separated) for cropping and including tracks.
   --debug               Lower the logging level to DEBUG.
+  --embeddings          (MaggotUBA) Call the backend to generate embeddings instead of labels.
   --decode              Do not encode the labels into integer indices.
   --copy-labels         Replicate discrete behavior data from the input file.
   --default-label=<label>             Label all untagged data as <label>.
diff --git a/src/cli_toolkit.jl b/src/cli_toolkit.jl
index 109cde4..5614735 100644
--- a/src/cli_toolkit.jl
+++ b/src/cli_toolkit.jl
@@ -19,6 +19,7 @@ Usage:
   larvatagger-toolkit.jl train <backend-path> <data-path> <model-instance> [--pretrained-model=<instance>] [--labels=<comma-separated-list>] [--sample-size=<N>] [--balancing-strategy=<strategy>] [--class-weights=<csv>] [--manual-label=<label>] [--layers=<N>] [--iterations=<N>] [--seed=<seed>] [--debug]
   larvatagger-toolkit.jl train <backend-path> <data-path> <model-instance> --fine-tune=<instance> [--balancing-strategy=<strategy>] [--manual-label=<label>] [--iterations=<N>] [--seed=<seed>] [--debug]
   larvatagger-toolkit.jl predict <backend-path> <model-instance> <data-path> [--output=<filename>] [--make-dataset] [--skip-make-dataset] [--data-isolation] [--debug]
+  larvatagger-toolkit.jl predict <backend-path> <model-instance> <data-path> --embeddings [--data-isolation] [--debug]
   larvatagger-toolkit.jl merge <input-path> <input-file> [<output-file>] [--manual-label=<label>] [--decode]
   larvatagger-toolkit.jl -V | --version
   larvatagger-toolkit.jl -h | --help
@@ -38,6 +39,7 @@ Options:
   --seed=<seed>        Seed for the backend's random number generators.
   --segment=<t0,t1>    Start and end times (included, comma-separated) for cropping and including tracks.
   --debug              Lower the logging level to DEBUG.
+  --embeddings         (MaggotUBA) Call the backend to generate embeddings instead of labels.
   --decode             Do not encode the labels into integer indices.
   --copy-labels        Replicate discrete behavior data from the input file.
   --default-label=<label>             Label all untagged data as <label>.
@@ -188,6 +190,7 @@ function main(args=ARGS; exit_on_error=true)
         data_path = parsed_args["<data-path>"]
         data_isolation = parsed_args["--data-isolation"]
         output_filename = parsed_args["--output"]
+        embeddings = parsed_args["--embeddings"]
         #
         datapath = abspath(data_path)
         destination = if isfile(datapath)
@@ -216,8 +219,13 @@ function main(args=ARGS; exit_on_error=true)
         end
         resetdata(tagger)
         Taggers.push(tagger, datapath)
-        predict(tagger; skip_make_dataset=parsed_args["--skip-make-dataset"],
-            make_dataset=parsed_args["--make-dataset"], debug=parsed_args["--debug"])
+        if embeddings
+            embed(tagger; skip_make_dataset=parsed_args["--skip-make-dataset"],
+                make_dataset=parsed_args["--make-dataset"], debug=parsed_args["--debug"])
+        else
+            predict(tagger; skip_make_dataset=parsed_args["--skip-make-dataset"],
+                make_dataset=parsed_args["--make-dataset"], debug=parsed_args["--debug"])
+        end
         Taggers.pull(tagger, destination)
     end
 end
-- 
GitLab