diff --git a/src/taggingbackends/data/dataset.py b/src/taggingbackends/data/dataset.py
index 9ce597029f74b3433edf874862d23b2b18caf4ec..8504f2755471462f07d83e12a4e6f97ffb3549f3 100644
--- a/src/taggingbackends/data/dataset.py
+++ b/src/taggingbackends/data/dataset.py
@@ -8,7 +8,8 @@ from collections import Counter
 Torch-like dataset class for *larva_dataset hdf5* files.
 """
 class LarvaDataset:
-    def __init__(self, dataset=None, generator=None, subsets=(.8, .1, .1)):
+    def __init__(self, dataset=None, generator=None, subsets=(.8, .1, .1),
+            balancing_strategy=None, class_weights=None):
         self.generator = generator
         self._full_set = dataset
         self.subset_shares = subsets
@@ -21,6 +22,11 @@ class LarvaDataset:
         self._class_weights = None
         # this attribute was introduced to implement `training_labels`
         self._alt_training_set_loader = None
+        if class_weights is None:
+            self.weight_classes = isinstance(balancing_strategy, str) and (balancing_strategy.lower() == 'auto')
+        else:
+            self.class_weights = class_weights
+
     """
     *h5py.File*: *larva_dataset hdf5* file handler.
     """
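The two new `LarvaDataset` keyword arguments interact as follows: explicit
`class_weights` take precedence, and automatic class weighting is otherwise
enabled only by the "auto" balancing strategy. Below is a minimal standalone
sketch of that decision; `resolve_class_weights` is a hypothetical name
introduced here for illustration only, as the patch inlines this logic in
`__init__`:

    from typing import Optional, Sequence, Union

    def resolve_class_weights(
            balancing_strategy: Optional[str] = None,
            class_weights: Optional[Sequence[float]] = None,
    ) -> Union[bool, Sequence[float]]:
        if class_weights is not None:
            # explicit weights short-circuit the strategy-based default
            return class_weights
        # without explicit weights, weighting is on iff the strategy is 'auto'
        return (isinstance(balancing_strategy, str)
                and balancing_strategy.lower() == 'auto')

    assert resolve_class_weights('auto') is True
    assert resolve_class_weights('maggotuba') is False
    assert resolve_class_weights(None, [2.0, 1.0]) == [2.0, 1.0]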
diff --git a/src/taggingbackends/main.py b/src/taggingbackends/main.py
index 76761fdf39cb7697dca78051aa00d6633e80db67..dde5335e0a03a7620aa8a334c59bc9fb800bc636 100644
--- a/src/taggingbackends/main.py
+++ b/src/taggingbackends/main.py
@@ -1,18 +1,19 @@
 import os
 import sys
 import logging
-from taggingbackends.explorer import BackendExplorer, BackendExplorerDecoder
+from taggingbackends.explorer import BackendExplorer, BackendExplorerDecoder, getlogger
 
 def help(_print=False):
     msg = """
 Usage:  tagging-backend [train|predict] --model-instance <name>
-        tagging-backend train ... --labels <comma-separated-list>
-        tagging-backend train ... --sample-size <N> --balancing-strategy <strategy>
+        tagging-backend train ... --labels <labels> --class-weights <weights>
+        tagging-backend train ... --sample-size <N> --balancing-strategy <S>
         tagging-backend train ... --frame-interval <I> --window-length <T>
         tagging-backend train ... --pretrained-model-instance <name>
         tagging-backend train ... --include-all <secondary-label>
         tagging-backend train ... --skip-make-dataset --skip-build-features
-        tagging-backend predict ... --make-dataset --build-features --sandbox <token>
+        tagging-backend predict ... --make-dataset --build-features
+        tagging-backend predict ... --sandbox <token>
 
 `tagging-backend` typically is run using `poetry run`.
 
 A name must be provided to identify the trained model and its location within
@@ -20,17 +21,48 @@ the backend.
 
 An intermediate *larva_dataset hdf5* file is generated, with series of
 labelled spines. If option `--sample-size` is passed, <N> time segments are
 sampled from
-the raw database. The total length 3*<T> of time segments is 60 per default
-(20 *past* points, 20 *present* points and 20 *future* points).
+the raw database.
+
+The total length 3*<T> of time segments is 60 per default (20 *past* points,
+20 *present* points and 20 *future* points), hence always a multiple of 3.
+This is inherited from MaggotUBA, from which the *larva_dataset* interim file
+format was borrowed, and may change in the future.
 
 If frame interval <I> is specified (in seconds), spine series are resampled
 and interpolated around each time segment anchor (center).
 
-**Deprecated**:
-Option `--trxmat-only` is suitable for large databases made of trx.mat files
-only. Intermediate HDF5 files are generated prior to counting the various
-behavior labels and sampling time segments in the database. These intermediate
-files may be kept to avoid loading the trx.mat files again.
+Specific behavior labels can be listed for inclusion in the *larva_dataset*
+interim file, with argument `--labels` as a comma-separated list of labels.
+If `--labels` is defined, `--class-weights` can be passed as well to specify
+the weights these classes are given in the cost function the training
+algorithm minimizes. Weights are also specified as a comma-separated list,
+of floating-point values; exactly as many weights as labels are expected.
+
+In addition to class penalties, the majority classes can be subsampled in
+different ways. Argument `--balancing-strategy` can take either "maggotuba"
+or "auto", which correspond to the following subsampling strategies.
+
+--balancing-strategy maggotuba
+  Denoting by n the size of the minority class, classes of size less than 10n
+  are also considered minority classes and are subsampled down to size n.
+  Classes of size greater than or equal to 10n are considered majority
+  classes and are subsampled down to size 2n.
+
+--balancing-strategy auto
+  Denoting by n the size of the minority class, classes of size greater than
+  or equal to 20n are subsampled down to 20n. Other classes are not
+  subsampled. In addition, if class weights are not defined, they are set to
+  the inverse of the corresponding class size.
+
+Subsampling is done at random. However, data observations bearing a specific
+secondary label can be included with priority: they are kept up to the target
+class size, then complemented with randomly picked observations. To trigger
+this behavior, specify the secondary label with argument `--include-all`.
+
+`--sandbox <token>` makes `tagging-backend` use a token instead of <name> as
+directory name in data/raw, data/interim and data/processed.
+This is intended to prevent conflicts when running `predict` in parallel on
+multiple data files.
 
 Note that an existing larva_dataset file in data/interim/<name> makes
 the `train` command skip the `make_dataset` step. However, even in such a case,
@@ -41,11 +73,6 @@ truly skip this step; the corresponding module is not loaded.
 
 Since version 0.10, `predict` makes `--skip-make-dataset` and
 `--skip-build-features` the default behavior. As a counterpart, it admits
 arguments `--make-dataset` and `--build-features`.
-
-`--sandbox <token>` makes `tagging-backend` use a token instead of <name> as
-directory name in data/raw, data/interim and data/processed.
-This is intended to prevent conflicts on running `predict` in parallel on
-multiple data files with multiple calls.
 """
     if _print:
         print(msg)
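Before the option-parsing hunks below, it may help to see what the two
documented strategies amount to numerically. The following sketch derives the
target class sizes from the help text above; the label names and counts are
made up for illustration, and this is not the backend's actual implementation:

    from collections import Counter

    def target_sizes(counts, strategy):
        n = min(counts.values())  # size of the minority class
        if strategy == 'maggotuba':
            # classes of size < 10n are subsampled down to n,
            # classes of size >= 10n down to 2n
            return {label: n if size < 10 * n else 2 * n
                    for label, size in counts.items()}
        elif strategy == 'auto':
            # only classes of size >= 20n are subsampled, down to 20n
            return {label: min(size, 20 * n)
                    for label, size in counts.items()}
        raise ValueError(f"unknown balancing strategy: {strategy}")

    counts = Counter(run=100_000, cast=12_000, hunch=1_000)
    assert target_sizes(counts, 'maggotuba') == \
        {'run': 2_000, 'cast': 2_000, 'hunch': 1_000}
    assert target_sizes(counts, 'auto') == \
        {'run': 20_000, 'cast': 12_000, 'hunch': 1_000}

Under "auto", if `--class-weights` is not passed, each class weight
additionally defaults to the inverse of the corresponding class size.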
""" if _print: print(msg) @@ -70,6 +97,7 @@ def main(fun=None): sandbox = False balancing_strategy = 'auto' include_all = None + class_weights = None unknown_args = {} k = 2 while k < len(sys.argv): @@ -112,6 +140,9 @@ def main(fun=None): elif sys.argv[k] == "--sandbox": k = k + 1 sandbox = sys.argv[k] + elif sys.argv[k] == "--class-weights": + k = k + 1 + class_weights = sys.argv[k] elif sys.argv[k] == "--balancing-strategy": k = k + 1 balancing_strategy = sys.argv[k] @@ -146,10 +177,12 @@ def main(fun=None): if frame_interval: make_dataset_kwargs["frame_interval"] = frame_interval if trxmat_only: + # deprecated make_dataset_kwargs["trxmat_only"] = True if reuse_h5files: make_dataset_kwargs["reuse_h5files"] = True elif reuse_h5files: + # deprecated logging.info("option --reuse-h5files is ignored in the absence of --trxmat-only") if include_all: make_dataset_kwargs["include_all"] = include_all @@ -162,6 +195,8 @@ def main(fun=None): train_kwargs = dict(balancing_strategy=balancing_strategy) if pretrained_model_instance: train_kwargs["pretrained_model_instance"] = pretrained_model_instance + if class_weights: + train_kwargs['class_weights'] = class_weights backend._run_script(backend.train_model, trailing=unknown_args, **train_kwargs) else: # called by make_dataset, build_features, train_model and predict_model @@ -180,6 +215,9 @@ def main(fun=None): elif key == "labels": if isinstance(val, str): val = val.split(',') + elif key == 'class_weights': + if isinstance(val, str): + val = [float(s) for s in val.split(',')] elif key == "pretrained_model_instance": val_ = val.split(',') if val_[1:]: