Commit 98e3ed82 authored by François Laurent

Merge branch 'dev' into main

parents 65ded650 5f105da5
Pipeline #97832 passed
@@ -310,11 +310,11 @@ version = "1.8.0"
 [[deps.PlanarLarvae]]
 deps = ["DelimitedFiles", "HDF5", "JSON3", "LinearAlgebra", "MAT", "Meshes", "OrderedCollections", "SHA", "StaticArrays", "Statistics", "StatsBase", "StructTypes"]
-git-tree-sha1 = "4d26be48d93856d4d8f087f4b8e5d21d9c6c491d"
+git-tree-sha1 = "f7e528d2ecb7b6ef13aab96fade5b4d0a4c64767"
 repo-rev = "main"
 repo-url = "https://gitlab.pasteur.fr/nyx/planarlarvae.jl"
 uuid = "c2615984-ef14-4d40-b148-916c85b43307"
-version = "0.8.1"
+version = "0.9.0"

 [[deps.Preferences]]
 deps = ["TOML"]
......
name = "TaggingBackends"
uuid = "e551f703-3b82-4335-b341-d497b48d519b"
authors = ["François Laurent", "Institut Pasteur"]
version = "0.10.0"
version = "0.11.0"
[deps]
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
......
 [tool.poetry]
 name = "TaggingBackends"
-version = "0.10"
+version = "0.11"
 description = "Backbone for LarvaTagger.jl tagging backends"
 authors = ["François Laurent"]
......
@@ -8,7 +8,8 @@ from collections import Counter
 Torch-like dataset class for *larva_dataset hdf5* files.
 """
 class LarvaDataset:
-    def __init__(self, dataset=None, generator=None, subsets=(.8, .1, .1)):
+    def __init__(self, dataset=None, generator=None, subsets=(.8, .1, .1),
+            balancing_strategy=None, class_weights=None):
         self.generator = generator
         self._full_set = dataset
         self.subset_shares = subsets
@@ -21,6 +22,11 @@ class LarvaDataset:
         self._class_weights = None
         # this attribute was introduced to implement `training_labels`
         self._alt_training_set_loader = None
+        if class_weights is None:
+            self.weight_classes = isinstance(balancing_strategy, str) and (balancing_strategy.lower() == 'auto')
+        else:
+            self.class_weights = class_weights
    """
    *h5py.File*: *larva_dataset hdf5* file handler.
    """
......
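A minimal sketch of how the new constructor arguments interact, following the logic visible in the hunk above; the file path and weights are hypothetical, and the module path is assumed:

    # module path assumed; see the backend sources for the actual location
    from taggingbackends.data.dataset import LarvaDataset

    # Explicit weights take precedence: once class_weights is set,
    # balancing_strategy='auto' no longer enables automatic class weighting.
    dataset = LarvaDataset(
        "data/interim/mymodel/larva_dataset.hdf5",  # hypothetical path
        balancing_strategy="auto",
        class_weights=[1., 2., .5],  # one weight per label
    )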
 import os
 import sys
 import logging
-from taggingbackends.explorer import BackendExplorer, BackendExplorerDecoder
+from taggingbackends.explorer import BackendExplorer, BackendExplorerDecoder, getlogger
 def help(_print=False):
     msg = """
 Usage: tagging-backend [train|predict] --model-instance <name>
        tagging-backend train ... --labels <comma-separated-list>
-       tagging-backend train ... --sample-size <N> --balancing-strategy <strategy>
+       tagging-backend train ... --labels <labels> --class-weights <weights>
+       tagging-backend train ... --sample-size <N> --balancing-strategy <S>
        tagging-backend train ... --frame-interval <I> --window-length <T>
        tagging-backend train ... --pretrained-model-instance <name>
        tagging-backend train ... --include-all <secondary-label>
        tagging-backend train ... --skip-make-dataset --skip-build-features
-       tagging-backend predict ... --make-dataset --build-features --sandbox <token>
+       tagging-backend predict ... --make-dataset --build-features
+       tagging-backend predict ... --sandbox <token>
 `tagging-backend` is typically run using `poetry run`.

 A name must be provided to identify the trained model and its location within
@@ -20,17 +21,48 @@ the backend.
 An intermediate *larva_dataset hdf5* file is generated, with series of labelled
 spines. If option `--sample-size` is passed, <N> time segments are sampled from
-the raw database. The total length 3*<T> of time segments is 60 per default
-(20 *past* points, 20 *present* points and 20 *future* points).
+the raw database.
+
+The total length 3*<T> of time segments is 60 per default (20 *past* points,
+20 *present* points and 20 *future* points). This means the length is always a
+multiple of 3. This is inherited from MaggotUBA, whose *larva_dataset* interim
+file was borrowed, and may be changed in the future.

 If frame interval <I> is specified (in seconds), spine series are resampled and
 interpolated around each time segment anchor (center).

+**Deprecated**:
 Option `--trxmat-only` is suitable for large databases made of trx.mat files
 only. Intermediate HDF5 files are generated prior to counting the various
 behavior labels and sampling time segments in the database. These intermediate
 files may be kept to avoid loading the trx.mat files again.
+
+Some behavior labels can be listed for inclusion in the *larva_dataset* interim
+file, with argument `--labels` as a comma-separated list of labels.
+
+If `--labels` is defined, `--class-weights` can be passed as well, to specify
+the weights associated with each label in the cost function the training
+algorithm aims to minimize. Weights are also given as a comma-separated list,
+of floating-point values; as many weights as labels are expected.
+
+In addition to class penalties, the majority classes can be subsampled in
+different ways. Argument `--balancing-strategy` can take either "maggotuba" or
+"auto", which correspond to specific subsampling strategies:
+
+--balancing-strategy maggotuba
+    Denoting n the size of the minority class, classes of size less than 10n
+    are also considered minority classes and are subsampled down to size n.
+    Classes of size greater than or equal to 10n are considered majority
+    classes and are subsampled down to size 2n.
+
+--balancing-strategy auto
+    Denoting n the size of the minority class, classes of size greater than or
+    equal to 20n are subsampled down to 20n. Other classes are not subsampled.
+    In addition, if class weights are not defined, they are set as the inverse
+    of the corresponding class size.
+
+Subsampling is done at random. However, data observations bearing a specific
+secondary label can be included with priority, up to the target size if too
+many, and then complemented with other, randomly picked observations. To
+trigger this behavior, specify the secondary label with argument
+`--include-all`.
+
+`--sandbox <token>` makes `tagging-backend` use a token instead of <name> as
+directory name in data/raw, data/interim and data/processed.
+This is intended to prevent conflicts when running `predict` in parallel on
+multiple data files with multiple calls.

 Note that an existing larva_dataset file in data/interim/<name> makes the
 `train` command skip the `make_dataset` step. However, even in such a case,
@@ -41,11 +73,6 @@ truly skip this step; the corresponding module is not loaded.
 Since version 0.10, `predict` makes `--skip-make-dataset` and
 `--skip-build-features` the default behavior. As a counterpart, it admits
 arguments `--make-dataset` and `--build-features`.
-
-`--sandbox <token>` makes `tagging-backend` use a token instead of <name> as
-directory name in data/raw, data/interim and data/processed.
-This is intended to prevent conflicts on running `predict` in parallel on
-multiple data files with multiple calls.
 """
     if _print:
         print(msg)
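The subsampling arithmetic behind the two balancing strategies can be made concrete with a short sketch; this is illustrative code with hypothetical label counts, not code from the backend:

    def target_sizes(class_sizes, strategy):
        """Per-class subsampling targets; n is the minority class size."""
        n = min(class_sizes.values())
        if strategy == "maggotuba":
            # classes below 10n are subsampled down to n,
            # classes of size >= 10n down to 2n
            return {label: n if size < 10 * n else 2 * n
                    for label, size in class_sizes.items()}
        elif strategy == "auto":
            # only classes of size >= 20n are subsampled, down to 20n
            return {label: min(size, 20 * n)
                    for label, size in class_sizes.items()}
        raise ValueError(f"unknown strategy: {strategy}")

    sizes = {"run": 5000, "cast": 900, "stop": 100}  # hypothetical counts
    print(target_sizes(sizes, "maggotuba"))  # {'run': 200, 'cast': 100, 'stop': 100}
    print(target_sizes(sizes, "auto"))       # {'run': 2000, 'cast': 900, 'stop': 100}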
@@ -70,6 +97,7 @@ def main(fun=None):
     sandbox = False
     balancing_strategy = 'auto'
     include_all = None
+    class_weights = None
     unknown_args = {}
     k = 2
     while k < len(sys.argv):
@@ -112,6 +140,9 @@
         elif sys.argv[k] == "--sandbox":
             k = k + 1
             sandbox = sys.argv[k]
+        elif sys.argv[k] == "--class-weights":
+            k = k + 1
+            class_weights = sys.argv[k]
         elif sys.argv[k] == "--balancing-strategy":
             k = k + 1
             balancing_strategy = sys.argv[k]
@@ -146,10 +177,12 @@
     if frame_interval:
         make_dataset_kwargs["frame_interval"] = frame_interval
     if trxmat_only:
+        # deprecated
         make_dataset_kwargs["trxmat_only"] = True
         if reuse_h5files:
             make_dataset_kwargs["reuse_h5files"] = True
     elif reuse_h5files:
+        # deprecated
         logging.info("option --reuse-h5files is ignored in the absence of --trxmat-only")
     if include_all:
         make_dataset_kwargs["include_all"] = include_all
@@ -162,6 +195,8 @@
         train_kwargs = dict(balancing_strategy=balancing_strategy)
         if pretrained_model_instance:
             train_kwargs["pretrained_model_instance"] = pretrained_model_instance
+        if class_weights:
+            train_kwargs['class_weights'] = class_weights
         backend._run_script(backend.train_model, trailing=unknown_args, **train_kwargs)
     else:
         # called by make_dataset, build_features, train_model and predict_model
@@ -180,6 +215,9 @@
         elif key == "labels":
             if isinstance(val, str):
                 val = val.split(',')
+        elif key == 'class_weights':
+            if isinstance(val, str):
+                val = [float(s) for s in val.split(',')]
         elif key == "pretrained_model_instance":
             val_ = val.split(',')
             if val_[1:]:
......
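For instance, a training call that combines the new `--class-weights` option with the existing ones might look as follows (model instance and label names are hypothetical):

    poetry run tagging-backend train --model-instance mymodel \
        --labels run,cast,stop --class-weights 1.0,2.0,0.5 \
        --balancing-strategy maggotuba

With `--balancing-strategy auto` and no `--class-weights`, weights default to the inverse of each class size, as documented in the help text.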