diff --git a/src/taggingbackends/data/dataset.py b/src/taggingbackends/data/dataset.py
index 9ce597029f74b3433edf874862d23b2b18caf4ec..8504f2755471462f07d83e12a4e6f97ffb3549f3 100644
--- a/src/taggingbackends/data/dataset.py
+++ b/src/taggingbackends/data/dataset.py
@@ -8,7 +8,8 @@ from collections import Counter
Torch-like dataset class for *larva_dataset hdf5* files.
"""
class LarvaDataset:
- def __init__(self, dataset=None, generator=None, subsets=(.8, .1, .1)):
+ def __init__(self, dataset=None, generator=None, subsets=(.8, .1, .1),
+ balancing_strategy=None, class_weights=None):
self.generator = generator
self._full_set = dataset
self.subset_shares = subsets
@@ -21,6 +22,11 @@ class LarvaDataset:
self._class_weights = None
# this attribute was introduced to implement `training_labels`
self._alt_training_set_loader = None
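+        # `class_weights` is assumed to be a property (cf. `_class_weights`
+        # above) whose setter also turns class weighting on, e.g. by
+        # setting `weight_classes`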
+        if class_weights is None:
+            self.weight_classes = (isinstance(balancing_strategy, str)
+                    and balancing_strategy.lower() == 'auto')
+        else:
+            self.class_weights = class_weights
+
"""
*h5py.File*: *larva_dataset hdf5* file handler.
"""
diff --git a/src/taggingbackends/main.py b/src/taggingbackends/main.py
index 76761fdf39cb7697dca78051aa00d6633e80db67..dde5335e0a03a7620aa8a334c59bc9fb800bc636 100644
--- a/src/taggingbackends/main.py
+++ b/src/taggingbackends/main.py
@@ -1,18 +1,19 @@
import os
import sys
import logging
-from taggingbackends.explorer import BackendExplorer, BackendExplorerDecoder
+from taggingbackends.explorer import BackendExplorer, BackendExplorerDecoder, getlogger
def help(_print=False):
msg = """
Usage: tagging-backend [train|predict] --model-instance <name>
- tagging-backend train ... --labels <comma-separated-list>
- tagging-backend train ... --sample-size <N> --balancing-strategy <strategy>
+ tagging-backend train ... --labels <labels> --class-weights <weights>
+ tagging-backend train ... --sample-size <N> --balancing-strategy <S>
tagging-backend train ... --frame-interval <I> --window-length <T>
tagging-backend train ... --pretrained-model-instance <name>
tagging-backend train ... --include-all <secondary-label>
tagging-backend train ... --skip-make-dataset --skip-build-features
- tagging-backend predict ... --make-dataset --build-features --sandbox <token>
+ tagging-backend predict ... --make-dataset --build-features
+ tagging-backend predict ... --sandbox <token>
`tagging-backend` typically is run using `poetry run`.
A name must be provided to identify the trained model and its location within
@@ -20,17 +21,48 @@ the backend.
An intermediate *larva_dataset hdf5* file is generated, with series of labelled
spines. If option `--sample-size` is passed, <N> time segments are sampled from
-the raw database. The total length 3*<T> of time segments is 60 per default
-(20 *past* points, 20 *present* points and 20 *future* points).
+the raw database.
+
+The total length 3*<T> of the time segments is 60 by default (20 *past*
+points, 20 *present* points and 20 *future* points); the length is thus always
+a multiple of 3. This constraint is inherited from MaggotUBA, from which the
+*larva_dataset* interim file format was borrowed, and may be lifted in the
+future.
If frame interval <I> is specified (in seconds), spine series are resampled and
interpolated around each time segment anchor (center).
-**Deprecated**:
-Option `--trxmat-only` is suitable for large databases made of trx.mat files
-only. Intermediate HDF5 files are generated prior to counting the various
-behavior labels and sampling time segments in the database. These intermediate
-files may be kept to avoid loading the trx.mat files again.
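+
+For example, `--window-length 30 --frame-interval 0.1` yields time segments of
+90 points each (30 past, 30 present and 30 future points), resampled at
+0.1-second intervals.
+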
+The behavior labels to be included in the *larva_dataset* interim file can be
+listed with argument `--labels`, as a comma-separated list of labels.
+If `--labels` is defined, `--class-weights` can be passed as well, to specify
+the weights of the corresponding classes in the cost function the training
+algorithm minimizes. Weights are also given as a comma-separated list, of
+floating-point values; exactly as many weights as labels are expected.
+
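+For example, with hypothetical label names:
+
+    tagging-backend train --model-instance <name> \
+        --labels run,bend,hunch --class-weights 1,1,10
+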
+In addition to class weighting, the majority classes can be subsampled in
+different ways. Argument `--balancing-strategy` takes either "maggotuba" or
+"auto", each corresponding to a specific subsampling strategy:
+
+--balancing-strategy maggotuba
+  Denoting by n the size of the minority class, classes of size less than 10n
+  are also considered minority classes and are subsampled down to size n.
+  Classes of size greater than or equal to 10n are considered majority classes
+  and are subsampled down to size 2n.
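+  For example, with three classes of sizes 100, 600 and 5000 (n = 100), the
+  600-observation class is subsampled down to 100 and the 5000-observation
+  class down to 200.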
+
+--balancing-strategy auto
+  Denoting by n the size of the minority class, classes of size greater than
+  or equal to 20n are subsampled down to 20n; other classes are not
+  subsampled. In addition, if class weights are not defined, they are set as
+  the inverse of the corresponding class size.
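+  With the same class sizes 100, 600 and 5000 (n = 100), only the
+  5000-observation class is subsampled, down to 2000; if `--class-weights` is
+  not passed, the weights default to the inverse class sizes.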
+
+Subsampling is done at random. However, data observations bearing a specific
+secondary label can be included with priority: they are taken first, up to the
+target class size, and the remainder is completed with randomly picked
+observations. To trigger this behavior, pass the secondary label with argument
+`--include-all`.
+
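+For example, `--include-all small_motion` (hypothetical secondary label) first
+takes the *small_motion* observations of each class, and then completes each
+class with randomly drawn observations.
+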
+`--sandbox <token>` makes `tagging-backend` use the token instead of <name> as
+directory name in data/raw, data/interim and data/processed.
+This is intended to prevent conflicts when `predict` is run in parallel on
+multiple data files, with one call per file.
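+
+For example, with an arbitrary token:
+
+    tagging-backend predict --model-instance <name> --sandbox 20230601-1
+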
Note that an existing larva_dataset file in data/interim/<name> makes the
`train` command skip the `make_dataset` step. However, even in such a case,
@@ -41,11 +73,6 @@ truly skip this step; the corresponding module is not loaded.
Since version 0.10, `predict` makes `--skip-make-dataset` and
`--skip-build-features` the default behavior. As a counterpart, it admits
arguments `--make-dataset` and `--build-features`.
-
-`--sandbox <token>` makes `tagging-backend` use a token instead of <name> as
-directory name in data/raw, data/interim and data/processed.
-This is intended to prevent conflicts on running `predict` in parallel on
-multiple data files with multiple calls.
"""
if _print:
print(msg)
@@ -70,6 +97,7 @@ def main(fun=None):
sandbox = False
balancing_strategy = 'auto'
include_all = None
+ class_weights = None
unknown_args = {}
k = 2
while k < len(sys.argv):
@@ -112,6 +140,9 @@ def main(fun=None):
elif sys.argv[k] == "--sandbox":
k = k + 1
sandbox = sys.argv[k]
+ elif sys.argv[k] == "--class-weights":
+ k = k + 1
+ class_weights = sys.argv[k]
elif sys.argv[k] == "--balancing-strategy":
k = k + 1
balancing_strategy = sys.argv[k]
@@ -146,10 +177,12 @@ def main(fun=None):
if frame_interval:
make_dataset_kwargs["frame_interval"] = frame_interval
if trxmat_only:
+ # deprecated
make_dataset_kwargs["trxmat_only"] = True
if reuse_h5files:
make_dataset_kwargs["reuse_h5files"] = True
elif reuse_h5files:
+ # deprecated
logging.info("option --reuse-h5files is ignored in the absence of --trxmat-only")
if include_all:
make_dataset_kwargs["include_all"] = include_all
@@ -162,6 +195,8 @@ def main(fun=None):
train_kwargs = dict(balancing_strategy=balancing_strategy)
if pretrained_model_instance:
train_kwargs["pretrained_model_instance"] = pretrained_model_instance
+ if class_weights:
+        train_kwargs["class_weights"] = class_weights
backend._run_script(backend.train_model, trailing=unknown_args, **train_kwargs)
else:
# called by make_dataset, build_features, train_model and predict_model
@@ -180,6 +215,9 @@ def main(fun=None):
elif key == "labels":
if isinstance(val, str):
val = val.split(',')
+            elif key == "class_weights":
+ if isinstance(val, str):
+ val = [float(s) for s in val.split(',')]
elif key == "pretrained_model_instance":
val_ = val.split(',')
if val_[1:]: