From ab9e7ceaa48469126b64c1c523d461e1f5d73cd8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Laurent?= <francois.laurent@posteo.net> Date: Tue, 2 Jan 2024 21:22:05 +0100 Subject: [PATCH] --embeddings option --- scripts/larvatagger | 2 ++ src/Taggers.jl | 12 +++++++++++- src/cli.jl | 2 ++ src/cli_toolkit.jl | 12 ++++++++++-- 4 files changed, 25 insertions(+), 3 deletions(-) diff --git a/scripts/larvatagger b/scripts/larvatagger index 6425ae8..d9a475b 100755 --- a/scripts/larvatagger +++ b/scripts/larvatagger @@ -34,6 +34,7 @@ Usage: larvatagger train <backend-path> <data-path> <model-instance> [--pretrained-model=<instance>] [--labels=<comma-separated-list>] [--sample-size=<N>] [--balancing-strategy=<strategy>] [--class-weights=<csv>] [--manual-label=<label>] [--layers=<N>] [--iterations=<N>] [--seed=<seed>] [--debug] larvatagger train <backend-path> <data-path> <model-instance> --fine-tune=<instance> [--balancing-strategy=<strategy>] [--manual-label=<label>] [--iterations=<N>] [--seed=<seed>] [--debug] larvatagger predict <backend-path> <model-instance> <data-path> [--output=<filename>] [--make-dataset] [--skip-make-dataset] [--data-isolation] [--debug] + larvatagger predict <backend-path> <model-instance> <data-path> --embeddings [--data-isolation] [--debug] larvatagger merge <input-path> <input-file> [<output-file>] [--manual-label=<label>] [--decode] larvatagger -V | --version larvatagger -h | --help @@ -59,6 +60,7 @@ Options: --seed=<seed> Seed for the backend's random number generators. --segment=<t0,t1> Start and end times (included, comma-separated) for cropping and including tracks. --debug Lower the logging level to DEBUG. + --embeddings (MaggotUBA) Call the backend to generate embeddings instead of labels. --decode Do not encode the labels into integer indices. --copy-labels Replicate discrete behavior data from the input file. --default-label=<label> Label all untagged data as <label>. diff --git a/src/Taggers.jl b/src/Taggers.jl index 38f57c2..f75bb50 100644 --- a/src/Taggers.jl +++ b/src/Taggers.jl @@ -2,7 +2,7 @@ module Taggers import PlanarLarvae.Formats, PlanarLarvae.Dataloaders -export Tagger, isbackend, resetmodel, resetdata, train, predict, finetune +export Tagger, isbackend, resetmodel, resetdata, train, predict, finetune, embed struct Tagger backend_dir::String @@ -247,4 +247,14 @@ function finetune(tagger::Tagger; original_instance=nothing, kwargs...) return ret end +function embed(tagger::Tagger; kwargs...) + args = ["--model-instance", tagger.model_instance] + if !isnothing(tagger.sandbox) + push!(args, "--sandbox") + push!(args, tagger.sandbox) + end + parsekwargs!(args, kwargs) + run(Cmd(`poetry run tagging-backend embed $(args)`; dir=tagger.backend_dir)) +end + end # module diff --git a/src/cli.jl b/src/cli.jl index ed85b62..c0a25fb 100644 --- a/src/cli.jl +++ b/src/cli.jl @@ -14,6 +14,7 @@ Usage: larvatagger.jl train <backend-path> <data-path> <model-instance> [--pretrained-model=<instance>] [--labels=<comma-separated-list>] [--sample-size=<N>] [--balancing-strategy=<strategy>] [--class-weights=<csv>] [--manual-label=<label>] [--layers=<N>] [--iterations=<N>] [--seed=<seed>] [--debug] larvatagger.jl train <backend-path> <data-path> <model-instance> --fine-tune=<instance> [--balancing-strategy=<strategy>] [--manual-label=<label>] [--iterations=<N>] [--seed=<seed>] [--debug] larvatagger.jl predict <backend-path> <model-instance> <data-path> [--output=<filename>] [--make-dataset] [--skip-make-dataset] [--data-isolation] [--debug] + larvatagger.jl predict <backend-path> <model-instance> <data-path> --embeddings [--data-isolation] [--debug] larvatagger.jl merge <input-path> <input-file> [<output-file>] [--manual-label=<label>] [--decode] larvatagger.jl -V | --version larvatagger.jl -h | --help @@ -39,6 +40,7 @@ Options: --seed=<seed> Seed for the backend's random number generators. --segment=<t0,t1> Start and end times (included, comma-separated) for cropping and including tracks. --debug Lower the logging level to DEBUG. + --embeddings (MaggotUBA) Call the backend to generate embeddings instead of labels. --decode Do not encode the labels into integer indices. --copy-labels Replicate discrete behavior data from the input file. --default-label=<label> Label all untagged data as <label>. diff --git a/src/cli_toolkit.jl b/src/cli_toolkit.jl index 109cde4..5614735 100644 --- a/src/cli_toolkit.jl +++ b/src/cli_toolkit.jl @@ -19,6 +19,7 @@ Usage: larvatagger-toolkit.jl train <backend-path> <data-path> <model-instance> [--pretrained-model=<instance>] [--labels=<comma-separated-list>] [--sample-size=<N>] [--balancing-strategy=<strategy>] [--class-weights=<csv>] [--manual-label=<label>] [--layers=<N>] [--iterations=<N>] [--seed=<seed>] [--debug] larvatagger-toolkit.jl train <backend-path> <data-path> <model-instance> --fine-tune=<instance> [--balancing-strategy=<strategy>] [--manual-label=<label>] [--iterations=<N>] [--seed=<seed>] [--debug] larvatagger-toolkit.jl predict <backend-path> <model-instance> <data-path> [--output=<filename>] [--make-dataset] [--skip-make-dataset] [--data-isolation] [--debug] + larvatagger-toolkit.jl predict <backend-path> <model-instance> <data-path> --embeddings [--data-isolation] [--debug] larvatagger-toolkit.jl merge <input-path> <input-file> [<output-file>] [--manual-label=<label>] [--decode] larvatagger-toolkit.jl -V | --version larvatagger-toolkit.jl -h | --help @@ -38,6 +39,7 @@ Options: --seed=<seed> Seed for the backend's random number generators. --segment=<t0,t1> Start and end times (included, comma-separated) for cropping and including tracks. --debug Lower the logging level to DEBUG. + --embeddings (MaggotUBA) Call the backend to generate embeddings instead of labels. --decode Do not encode the labels into integer indices. --copy-labels Replicate discrete behavior data from the input file. --default-label=<label> Label all untagged data as <label>. @@ -188,6 +190,7 @@ function main(args=ARGS; exit_on_error=true) data_path = parsed_args["<data-path>"] data_isolation = parsed_args["--data-isolation"] output_filename = parsed_args["--output"] + embeddings = parsed_args["--embeddings"] # datapath = abspath(data_path) destination = if isfile(datapath) @@ -216,8 +219,13 @@ function main(args=ARGS; exit_on_error=true) end resetdata(tagger) Taggers.push(tagger, datapath) - predict(tagger; skip_make_dataset=parsed_args["--skip-make-dataset"], - make_dataset=parsed_args["--make-dataset"], debug=parsed_args["--debug"]) + if embeddings + embed(tagger; skip_make_dataset=parsed_args["--skip-make-dataset"], + make_dataset=parsed_args["--make-dataset"], debug=parsed_args["--debug"]) + else + predict(tagger; skip_make_dataset=parsed_args["--skip-make-dataset"], + make_dataset=parsed_args["--make-dataset"], debug=parsed_args["--debug"]) + end Taggers.pull(tagger, destination) end end -- GitLab