import os
import sys
import logging

from taggingbackends.explorer import BackendExplorer, BackendExplorerDecoder, getlogger

def help(_print=False):
    msg = """
Usage:  tagging-backend [train|predict] --model-instance <name>
        tagging-backend train ... --labels <labels> --class-weights <weights>
        tagging-backend train ... --sample-size <N> --balancing-strategy <S>
        tagging-backend train ... --frame-interval <I> --window-length <T>
        tagging-backend train ... --pretrained-model-instance <name>
        tagging-backend train ... --include-all <secondary-label>
        tagging-backend train ... --skip-make-dataset --skip-build-features
        tagging-backend train ... --seed <seed>
        tagging-backend predict ... --make-dataset --build-features
        tagging-backend predict ... --sandbox <token>
        tagging-backend --help

`tagging-backend` is typically run using `poetry run`.

A name must be provided to identify the trained model and its location within
the backend.
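
For example (hypothetical model name and file path):

  poetry run tagging-backend train --model-instance my_model --labels run,cast
  poetry run tagging-backend predict --model-instance my_model \\
      --input-files path/to/data.h5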

An intermediate *larva_dataset hdf5* file is generated, with series of labelled
spines. If option `--sample-size` is passed, <N> time segments are sampled from
the raw database.

The total length 3*<T> of time segments is 60 by default (20 *past* points,
20 *present* points and 20 *future* points); the length is therefore always a
multiple of 3. This convention is inherited from MaggotUBA, whose
*larva_dataset* interim file format was borrowed, and may change in the future.

If a frame interval <I> is specified (in seconds), spine series are resampled
and interpolated around each time segment anchor (center).
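
For example, with the default `--window-length 20` and an illustrative
`--frame-interval 0.1`, each time segment holds 3*20 = 60 spine points sampled
at 10 Hz, i.e. about 6 seconds of data centered on the anchor.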

Some behavior labels can be listed for inclusion in the *larva_dataset* interim
file, with argument `--labels` as a comma-separated list of labels.

If `--labels` is defined, `--class-weights` can be passed as well to specify
the associated weights in the cost function the training algorithm aims to
minimize. Weights are also given as a comma-separated list of floating-point
values, with exactly as many weights as labels.
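
For example, `--labels run,cast,hunch --class-weights 1,1,5` (illustrative
label names) penalizes errors on the third class five times as much as errors
on the first two.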

In addition to class penalties, the majority classes can be subsampled in
different ways. Argument `--balancing-strategy` takes either "maggotuba" or
"auto", each selecting a specific subsampling strategy (see the worked example
after the two descriptions):

--balancing-strategy maggotuba
    Denoting n the size of the minority class, classes of size less than 10n
    are also considered minority classes and are subsampled down to size n.
    Classes of size greater than or equal to 10n are considered majority
    classes and are subsampled down to size 2n.

--balancing-strategy auto
    Denoting n the size of the minority class, classes of size greater than or
    equal to 20n are subsampled down to 20n. Other classes are not subsampled.
    In addition, if class weights are not defined, they are set as the inverse
    of the corresponding class size.
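
For example, with illustrative class sizes 100, 500 and 5000 (thus n = 100):
"maggotuba" subsamples the 500-observation class (< 10n) down to 100 and the
5000-observation class (>= 10n) down to 200, whereas "auto" subsamples the
5000-observation class only (>= 20n) down to 2000 and, if class weights are
not given, sets them to the inverse of the class sizes.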

Subsampling is done at random. However, observations bearing a specific
secondary label can be included with priority: they are subsampled down to the
target size if too many, and complemented with randomly picked observations
otherwise. To trigger this behavior, pass the secondary label with argument
`--include-all`.

`--sandbox <token>` makes `tagging-backend` use the token instead of <name> as
the directory name in data/raw, data/interim and data/processed. This is
intended to prevent conflicts when `predict` is run in parallel on multiple
data files.

Note that an existing larva_dataset file in data/interim/<name> makes the
`train` command skip the `make_dataset` step. However, even in such a case,
the `make_dataset` module is loaded, which may take quite some time due to
dependencies (e.g. the Julia FFI). The `--skip-make-dataset` option makes
`train` truly skip this step; the corresponding module is not even loaded.

Since version 0.10, `predict` makes `--skip-make-dataset` and
`--skip-build-features` the default behavior. As a counterpart, it admits
arguments `--make-dataset` and `--build-features`.
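
For example, `poetry run tagging-backend predict --model-instance my_model
--make-dataset` (hypothetical model name) restores the dataset-generation
step that `predict` skips by default.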
"""
if _print:
print(msg)
return msg

def main(fun=None):
    logging.basicConfig(level=logging.INFO,
            format="%(levelname)s:%(name)s: %(message)s")
    if fun is None:
        # called by scripts/tagging-backend
        if not sys.argv[1:] or any(arg == '--help' for arg in sys.argv):
            help(True)
            #sys.exit("too few input arguments; subcommand expected: 'train' or 'predict'")
            return
        train_or_predict = sys.argv[1]
        # default option values
        project_dir = model_instance = None
        input_files, labels = [], []
        sample_size = window_length = frame_interval = None
        trxmat_only = reuse_h5files = False
        make_dataset = build_features = None
        pretrained_model_instance = None
        sandbox = False
        balancing_strategy = 'auto'
        include_all = None
        class_weights = None
        seed = None
        unknown_args = {}
        k = 2
        while k < len(sys.argv):
            if sys.argv[k] == "--project-dir":
                k = k + 1
                project_dir = sys.argv[k]
            elif sys.argv[k] == "--model-instance":
                k = k + 1
                model_instance = sys.argv[k]
            elif sys.argv[k] == "--input-files":
                k = k + 1
                input_files = sys.argv[k].split(',')
            elif sys.argv[k] == "--labels":
                k = k + 1
                labels = sys.argv[k]
            elif sys.argv[k] == "--sample-size":
                k = k + 1
                sample_size = sys.argv[k]
            elif sys.argv[k] == "--window-length":
                k = k + 1
                window_length = sys.argv[k]
            elif sys.argv[k] == "--frame-interval":
                k = k + 1
                frame_interval = sys.argv[k]
            elif sys.argv[k] == "--trxmat-only":
                trxmat_only = True
            elif sys.argv[k] == "--reuse-h5files":
                reuse_h5files = True
            elif sys.argv[k] == "--skip-make-dataset":
                make_dataset = False
            elif sys.argv[k] == "--skip-build-features":
                build_features = False
            elif sys.argv[k] == '--make-dataset':
                make_dataset = True
            elif sys.argv[k] == '--build-features':
                build_features = True
            elif sys.argv[k] == "--pretrained-model-instance":
                k = k + 1
                pretrained_model_instance = sys.argv[k]
            elif sys.argv[k] == "--sandbox":
                k = k + 1
                sandbox = sys.argv[k]
            elif sys.argv[k] == "--class-weights":
                k = k + 1
                class_weights = sys.argv[k]
            elif sys.argv[k] == "--balancing-strategy":
                k = k + 1
                balancing_strategy = sys.argv[k]
            elif sys.argv[k] == "--include-all":
                k = k + 1
                include_all = sys.argv[k]
            elif sys.argv[k] == "--seed":
                k = k + 1
                seed = sys.argv[k]
            else:
                unknown_args[sys.argv[k].lstrip('-').replace('-', '_')] = sys.argv[k+1]
                k = k + 1
            k = k + 1
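        # Unrecognized options were collected in `unknown_args`, with leading
        # dashes stripped and inner dashes mapped to underscores; e.g. a
        # hypothetical `--batch-size 32` becomes unknown_args['batch_size'] = '32'.
        # They are forwarded below as trailing arguments to the backend scripts.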
        backend = BackendExplorer(project_dir, model_instance=model_instance,
                sandbox=sandbox)
        backend.reset_data(spare_raw=True)
        sys.stderr.flush()
        sys.stdout.flush()
        if input_files:
            for file in input_files:
                backend.move_to_raw(file)
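        # make_dataset and build_features are tri-state: None (unset), True
        # (--make-dataset/--build-features) or False (--skip-*). When unset,
        # they default to True for `train`; for `predict` they stay None and
        # the steps are skipped, the default behavior since version 0.10.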
        if make_dataset is None and train_or_predict == 'train':
            make_dataset = True
        if build_features is None and train_or_predict == 'train':
            build_features = True
        if make_dataset:
            make_dataset_kwargs = dict(labels_expected=train_or_predict == "train",
                    balancing_strategy=balancing_strategy)
            if labels:
                make_dataset_kwargs["labels"] = labels
            if sample_size:
                make_dataset_kwargs["sample_size"] = sample_size
            if window_length:
                make_dataset_kwargs["window_length"] = window_length
            if frame_interval:
                make_dataset_kwargs["frame_interval"] = frame_interval
            if trxmat_only:
                # deprecated
                make_dataset_kwargs["trxmat_only"] = True
                if reuse_h5files:
                    make_dataset_kwargs["reuse_h5files"] = True
            elif reuse_h5files:
                # deprecated
                logging.info("option --reuse-h5files is ignored in the absence of --trxmat-only")
            if pretrained_model_instance is not None:
                make_dataset_kwargs["pretrained_model_instance"] = pretrained_model_instance
            if include_all:
                make_dataset_kwargs["include_all"] = include_all
            if seed is not None:
                make_dataset_kwargs["seed"] = seed
            backend._run_script(backend.make_dataset, **make_dataset_kwargs)
        if build_features:
            backend._run_script(backend.build_features)
        if train_or_predict == "predict":
            backend._run_script(backend.predict_model, trailing=unknown_args)
        else:
            train_kwargs = dict(balancing_strategy=balancing_strategy)
            if pretrained_model_instance:
                train_kwargs["pretrained_model_instance"] = pretrained_model_instance
            if class_weights:
                train_kwargs['class_weights'] = class_weights
            if seed is not None:
                train_kwargs['seed'] = seed
            backend._run_script(backend.train_model, trailing=unknown_args, **train_kwargs)
    else:
        # called by make_dataset, build_features, train_model and predict_model
        backend = BackendExplorerDecoder().decode(sys.argv[1])
        def _decode(key, val):
            # values come in as strings; convert known options to their
            # expected types
            if val == "true":
                val = True
            elif val == "false":
                val = False
            elif key in ("sample_size", "window_length"):
                if isinstance(val, str):
                    val = int(val)
            elif key in ("frame_interval",):
                if isinstance(val, str):
                    val = float(val)
            elif key == "labels":
                if isinstance(val, str):
                    val = val.split(',')
            elif key == 'class_weights':
                if isinstance(val, str):
                    val = [float(s) for s in val.split(',')]
            elif key == "pretrained_model_instance":
                val_ = val.split(',')
                if val_[1:]:
                    val = val_
            elif isinstance(val, str):
                try:
                    val = int(val)
                except ValueError:
                    try:
                        val = float(val)
                    except ValueError:
                        pass
            return val
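        # For example (illustrative values):
        #   _decode("sample_size", "1000")    -> 1000
        #   _decode("frame_interval", "0.1")  -> 0.1
        #   _decode("labels", "run,cast")     -> ['run', 'cast']
        #   _decode("class_weights", "1,1,5") -> [1.0, 1.0, 5.0]
        #   _decode("seed", "42")             -> 42 (generic int fallback)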
        args = {}
        for k in range(2, len(sys.argv), 2):
            args[sys.argv[k]] = _decode(sys.argv[k], sys.argv[k+1])
        return fun(backend, **args)

if __name__ == "__main__":
    main()
    sys.exit()