diff --git a/Manifest.toml b/Manifest.toml index 0fe6fe2ff6b626d1f7008d0929601c7be8fe1bce..bbcd605befae4e3946638b875ce2ff2fae7fd0fb 100644 --- a/Manifest.toml +++ b/Manifest.toml @@ -310,11 +310,11 @@ version = "1.8.0" [[deps.PlanarLarvae]] deps = ["DelimitedFiles", "HDF5", "JSON3", "LinearAlgebra", "MAT", "Meshes", "OrderedCollections", "SHA", "StaticArrays", "Statistics", "StatsBase", "StructTypes"] -git-tree-sha1 = "4d26be48d93856d4d8f087f4b8e5d21d9c6c491d" +git-tree-sha1 = "f7e528d2ecb7b6ef13aab96fade5b4d0a4c64767" repo-rev = "main" repo-url = "https://gitlab.pasteur.fr/nyx/planarlarvae.jl" uuid = "c2615984-ef14-4d40-b148-916c85b43307" -version = "0.8.1" +version = "0.9.0" [[deps.Preferences]] deps = ["TOML"] diff --git a/Project.toml b/Project.toml index a75dfba0f839e4b7833a667211df8ef22c653b2f..f7d95c40bf9f815922f6ea956ee3d7cb5658c7f8 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "TaggingBackends" uuid = "e551f703-3b82-4335-b341-d497b48d519b" authors = ["François Laurent", "Institut Pasteur"] -version = "0.10.0" +version = "0.11.0" [deps] Dates = "ade2ca70-3891-5945-98fb-dc099432e06a" diff --git a/pyproject.toml b/pyproject.toml index e4b1161777c9e95a16a142e48462a3a767c44482..f0e6d26c9fe3068bbe01e3f10177d16790b67807 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "TaggingBackends" -version = "0.10" +version = "0.11" description = "Backbone for LarvaTagger.jl tagging backends" authors = ["François Laurent"] diff --git a/src/taggingbackends/data/dataset.py b/src/taggingbackends/data/dataset.py index 9ce597029f74b3433edf874862d23b2b18caf4ec..8504f2755471462f07d83e12a4e6f97ffb3549f3 100644 --- a/src/taggingbackends/data/dataset.py +++ b/src/taggingbackends/data/dataset.py @@ -8,7 +8,8 @@ from collections import Counter Torch-like dataset class for *larva_dataset hdf5* files. """ class LarvaDataset: - def __init__(self, dataset=None, generator=None, subsets=(.8, .1, .1)): + def __init__(self, dataset=None, generator=None, subsets=(.8, .1, .1), + balancing_strategy=None, class_weights=None): self.generator = generator self._full_set = dataset self.subset_shares = subsets @@ -21,6 +22,11 @@ class LarvaDataset: self._class_weights = None # this attribute was introduced to implement `training_labels` self._alt_training_set_loader = None + if class_weights is None: + self.weight_classes = isinstance(balancing_strategy, str) and (balancing_strategy.lower() == 'auto') + else: + self.class_weights = class_weights + """ *h5py.File*: *larva_dataset hdf5* file handler. """ diff --git a/src/taggingbackends/main.py b/src/taggingbackends/main.py index 76761fdf39cb7697dca78051aa00d6633e80db67..dde5335e0a03a7620aa8a334c59bc9fb800bc636 100644 --- a/src/taggingbackends/main.py +++ b/src/taggingbackends/main.py @@ -1,18 +1,19 @@ import os import sys import logging -from taggingbackends.explorer import BackendExplorer, BackendExplorerDecoder +from taggingbackends.explorer import BackendExplorer, BackendExplorerDecoder, getlogger def help(_print=False): msg = """ Usage: tagging-backend [train|predict] --model-instance <name> - tagging-backend train ... --labels <comma-separated-list> - tagging-backend train ... --sample-size <N> --balancing-strategy <strategy> + tagging-backend train ... --labels <labels> --class-weights <weights> + tagging-backend train ... --sample-size <N> --balancing-strategy <S> tagging-backend train ... --frame-interval <I> --window-length <T> tagging-backend train ... --pretrained-model-instance <name> tagging-backend train ... --include-all <secondary-label> tagging-backend train ... --skip-make-dataset --skip-build-features - tagging-backend predict ... --make-dataset --build-features --sandbox <token> + tagging-backend predict ... --make-dataset --build-features + tagging-backend predict ... --sandbox <token> `tagging-backend` typically is run using `poetry run`. A name must be provided to identify the trained model and its location within @@ -20,17 +21,48 @@ the backend. An intermediate *larva_dataset hdf5* file is generated, with series of labelled spines. If option `--sample-size` is passed, <N> time segments are sampled from -the raw database. The total length 3*<T> of time segments is 60 per default -(20 *past* points, 20 *present* points and 20 *future* points). +the raw database. + +The total length 3*<T> of time segments is 60 per default (20 *past* points, +20 *present* points and 20 *future* points). This means the length is always a +multiple of 3. This is inherited from MaggotUBA, whose *larva_dataset* interim +file was borrowed, and may be changed in the future. If frame interval <I> is specified (in seconds), spine series are resampled and interpolated around each time segment anchor (center). -**Deprecated**: -Option `--trxmat-only` is suitable for large databases made of trx.mat files -only. Intermediate HDF5 files are generated prior to counting the various -behavior labels and sampling time segments in the database. These intermediate -files may be kept to avoid loading the trx.mat files again. +Some behavior labels can be listed for inclusion in the *larva_dataset* interim +file, with argument `--labels` as a comma-separated list of labels. +If `--labels` is defined, `--class-weights` can be passed as well to specify +associated weights in the calculation of the cost function the training +algorithm aims to minimize. Weights are also specified as a comma-separated +list of floating-point values. As many weights as labels are expected. + +In addition to class penalties, the majority classes can be subsampled in +different ways. Argument `--balancing-strategy` can take either "maggotuba" or +"auto", that correspond to specific subsampling strategies. + +--balancing-strategy maggotuba + Denoting n the size of the minority class, classes of size less than 10n are + also considered minority classes and are subsampled down to size n. Classes + of size greater than or equal to 10n are considered majority classes and are + subsampled down to size 2n. + +--balancing-strategy auto + Denoting n the size of the minority class, classes of size greater than or + equal to 20n are subsampled down to 20n. Other classes are not subsampled. + In addition, if class weights are not defined, they are set as the inverse + of the corresponding class size. + +Subsampling is done at random. However, data observations bearing a specific +secondary label can be included with priority, up to the target size if too +many, and then complemented with different randomly picked observations. To +trigger this behavior, specify the secondary label with argument --include-all. + +`--sandbox <token>` makes `tagging-backend` use a token instead of <name> as +directory name in data/raw, data/interim and data/processed. +This is intended to prevent conflicts on running `predict` in parallel on +multiple data files with multiple calls. Note that an existing larva_dataset file in data/interim/<name> makes the `train` command skip the `make_dataset` step. However, even in such a case, @@ -41,11 +73,6 @@ truly skip this step; the corresponding module is not loaded. Since version 0.10, `predict` makes `--skip-make-dataset` and `--skip-build-features` the default behavior. As a counterpart, it admits arguments `--make-dataset` and `--build-features`. - -`--sandbox <token>` makes `tagging-backend` use a token instead of <name> as -directory name in data/raw, data/interim and data/processed. -This is intended to prevent conflicts on running `predict` in parallel on -multiple data files with multiple calls. """ if _print: print(msg) @@ -70,6 +97,7 @@ def main(fun=None): sandbox = False balancing_strategy = 'auto' include_all = None + class_weights = None unknown_args = {} k = 2 while k < len(sys.argv): @@ -112,6 +140,9 @@ def main(fun=None): elif sys.argv[k] == "--sandbox": k = k + 1 sandbox = sys.argv[k] + elif sys.argv[k] == "--class-weights": + k = k + 1 + class_weights = sys.argv[k] elif sys.argv[k] == "--balancing-strategy": k = k + 1 balancing_strategy = sys.argv[k] @@ -146,10 +177,12 @@ def main(fun=None): if frame_interval: make_dataset_kwargs["frame_interval"] = frame_interval if trxmat_only: + # deprecated make_dataset_kwargs["trxmat_only"] = True if reuse_h5files: make_dataset_kwargs["reuse_h5files"] = True elif reuse_h5files: + # deprecated logging.info("option --reuse-h5files is ignored in the absence of --trxmat-only") if include_all: make_dataset_kwargs["include_all"] = include_all @@ -162,6 +195,8 @@ def main(fun=None): train_kwargs = dict(balancing_strategy=balancing_strategy) if pretrained_model_instance: train_kwargs["pretrained_model_instance"] = pretrained_model_instance + if class_weights: + train_kwargs['class_weights'] = class_weights backend._run_script(backend.train_model, trailing=unknown_args, **train_kwargs) else: # called by make_dataset, build_features, train_model and predict_model @@ -180,6 +215,9 @@ def main(fun=None): elif key == "labels": if isinstance(val, str): val = val.split(',') + elif key == 'class_weights': + if isinstance(val, str): + val = [float(s) for s in val.split(',')] elif key == "pretrained_model_instance": val_ = val.split(',') if val_[1:]: