diff --git a/scripts/larvatagger b/scripts/larvatagger index 6425ae80dbe56a837c0f0ef315db5f771cf6dd1d..d9a475be6a19e7f2cab8feccfc1feea31bd87bf4 100755 --- a/scripts/larvatagger +++ b/scripts/larvatagger @@ -34,6 +34,7 @@ Usage: larvatagger train <backend-path> <data-path> <model-instance> [--pretrained-model=<instance>] [--labels=<comma-separated-list>] [--sample-size=<N>] [--balancing-strategy=<strategy>] [--class-weights=<csv>] [--manual-label=<label>] [--layers=<N>] [--iterations=<N>] [--seed=<seed>] [--debug] larvatagger train <backend-path> <data-path> <model-instance> --fine-tune=<instance> [--balancing-strategy=<strategy>] [--manual-label=<label>] [--iterations=<N>] [--seed=<seed>] [--debug] larvatagger predict <backend-path> <model-instance> <data-path> [--output=<filename>] [--make-dataset] [--skip-make-dataset] [--data-isolation] [--debug] + larvatagger predict <backend-path> <model-instance> <data-path> --embeddings [--data-isolation] [--debug] larvatagger merge <input-path> <input-file> [<output-file>] [--manual-label=<label>] [--decode] larvatagger -V | --version larvatagger -h | --help @@ -59,6 +60,7 @@ Options: --seed=<seed> Seed for the backend's random number generators. --segment=<t0,t1> Start and end times (included, comma-separated) for cropping and including tracks. --debug Lower the logging level to DEBUG. + --embeddings (MaggotUBA) Call the backend to generate embeddings instead of labels. --decode Do not encode the labels into integer indices. --copy-labels Replicate discrete behavior data from the input file. --default-label=<label> Label all untagged data as <label>. 
diff --git a/src/Taggers.jl b/src/Taggers.jl
index 38f57c2e1683e57a49b83f3da2c3bf66744e678c..f75bb50625b14544dbda737e13752865f20c4d0d 100644
--- a/src/Taggers.jl
+++ b/src/Taggers.jl
@@ -2,7 +2,7 @@ module Taggers
 
 import PlanarLarvae.Formats, PlanarLarvae.Dataloaders
 
-export Tagger, isbackend, resetmodel, resetdata, train, predict, finetune
+export Tagger, isbackend, resetmodel, resetdata, train, predict, finetune, embed
 
 struct Tagger
     backend_dir::String
@@ -247,4 +247,22 @@ function finetune(tagger::Tagger; original_instance=nothing, kwargs...)
     return ret
 end
 
+"""
+    embed(tagger::Tagger; kwargs...)
+
+Run the backend's `embed` command for the model instance of `tagger`.
+
+The command `poetry run tagging-backend embed` is executed from `tagger.backend_dir`.
+`--model-instance` is always passed; `--sandbox` is added when `tagger.sandbox` is set.
+Keyword arguments are handed to `parsekwargs!` along with the argument list.
+Return the value of the underlying `run` call.
+"""
+function embed(tagger::Tagger; kwargs...)
+    backend_args = ["--model-instance", tagger.model_instance]
+    isnothing(tagger.sandbox) || push!(backend_args, "--sandbox", tagger.sandbox)
+    parsekwargs!(backend_args, kwargs)
+    command = `poetry run tagging-backend embed $(backend_args)`
+    return run(Cmd(command; dir=tagger.backend_dir))
+end
+
 end # module
diff --git a/src/cli.jl b/src/cli.jl
index ed85b62563a56cc43b110c207bdd43b6577c9f42..c0a25fbc05090bc4e0ee55eb5c23388c8b7e62d6 100644
--- a/src/cli.jl
+++ b/src/cli.jl
@@ -14,6 +14,7 @@ Usage:
   larvatagger.jl train <backend-path> <data-path> <model-instance> [--pretrained-model=<instance>] [--labels=<comma-separated-list>] [--sample-size=<N>] [--balancing-strategy=<strategy>] [--class-weights=<csv>] [--manual-label=<label>] [--layers=<N>] [--iterations=<N>] [--seed=<seed>] [--debug]
   larvatagger.jl train <backend-path> <data-path> <model-instance> --fine-tune=<instance> [--balancing-strategy=<strategy>] [--manual-label=<label>] [--iterations=<N>] [--seed=<seed>] [--debug]
   larvatagger.jl predict <backend-path> <model-instance> <data-path> [--output=<filename>] [--make-dataset] [--skip-make-dataset] [--data-isolation] [--debug]
+  larvatagger.jl predict <backend-path> <model-instance> <data-path> --embeddings [--data-isolation] [--debug]
   larvatagger.jl merge <input-path> <input-file> [<output-file>] [--manual-label=<label>] [--decode]
   larvatagger.jl -V | --version
   larvatagger.jl -h | --help
@@ -39,6 
+40,7 @@ Options: --seed=<seed> Seed for the backend's random number generators. --segment=<t0,t1> Start and end times (included, comma-separated) for cropping and including tracks. --debug Lower the logging level to DEBUG. + --embeddings (MaggotUBA) Call the backend to generate embeddings instead of labels. --decode Do not encode the labels into integer indices. --copy-labels Replicate discrete behavior data from the input file. --default-label=<label> Label all untagged data as <label>. diff --git a/src/cli_toolkit.jl b/src/cli_toolkit.jl index 109cde44e83fc88ae9b780d7bd439bd66cbe9963..5614735f43da6693ad420b0a416c32c74792c9b3 100644 --- a/src/cli_toolkit.jl +++ b/src/cli_toolkit.jl @@ -19,6 +19,7 @@ Usage: larvatagger-toolkit.jl train <backend-path> <data-path> <model-instance> [--pretrained-model=<instance>] [--labels=<comma-separated-list>] [--sample-size=<N>] [--balancing-strategy=<strategy>] [--class-weights=<csv>] [--manual-label=<label>] [--layers=<N>] [--iterations=<N>] [--seed=<seed>] [--debug] larvatagger-toolkit.jl train <backend-path> <data-path> <model-instance> --fine-tune=<instance> [--balancing-strategy=<strategy>] [--manual-label=<label>] [--iterations=<N>] [--seed=<seed>] [--debug] larvatagger-toolkit.jl predict <backend-path> <model-instance> <data-path> [--output=<filename>] [--make-dataset] [--skip-make-dataset] [--data-isolation] [--debug] + larvatagger-toolkit.jl predict <backend-path> <model-instance> <data-path> --embeddings [--data-isolation] [--debug] larvatagger-toolkit.jl merge <input-path> <input-file> [<output-file>] [--manual-label=<label>] [--decode] larvatagger-toolkit.jl -V | --version larvatagger-toolkit.jl -h | --help @@ -38,6 +39,7 @@ Options: --seed=<seed> Seed for the backend's random number generators. --segment=<t0,t1> Start and end times (included, comma-separated) for cropping and including tracks. --debug Lower the logging level to DEBUG. 
+ --embeddings (MaggotUBA) Call the backend to generate embeddings instead of labels. --decode Do not encode the labels into integer indices. --copy-labels Replicate discrete behavior data from the input file. --default-label=<label> Label all untagged data as <label>. @@ -188,6 +190,7 @@ function main(args=ARGS; exit_on_error=true) data_path = parsed_args["<data-path>"] data_isolation = parsed_args["--data-isolation"] output_filename = parsed_args["--output"] + embeddings = parsed_args["--embeddings"] # datapath = abspath(data_path) destination = if isfile(datapath) @@ -216,8 +219,13 @@ function main(args=ARGS; exit_on_error=true) end resetdata(tagger) Taggers.push(tagger, datapath) - predict(tagger; skip_make_dataset=parsed_args["--skip-make-dataset"], - make_dataset=parsed_args["--make-dataset"], debug=parsed_args["--debug"]) + if embeddings + embed(tagger; skip_make_dataset=parsed_args["--skip-make-dataset"], + make_dataset=parsed_args["--make-dataset"], debug=parsed_args["--debug"]) + else + predict(tagger; skip_make_dataset=parsed_args["--skip-make-dataset"], + make_dataset=parsed_args["--make-dataset"], debug=parsed_args["--debug"]) + end Taggers.pull(tagger, destination) end end