Select Git revision
backgrounding_tasks.rst
main.py 8.51 KiB
import os
import sys
import logging
from taggingbackends.explorer import BackendExplorer, BackendExplorerDecoder
def help(_print=False):
    """Return the usage message of the `tagging-backend` command.

    Args:
        _print: if true, the message is also written to standard output.

    Returns:
        The usage message, as a string.
    """
    usage = """
Usage: tagging-backend [train|predict] --model-instance <name>
tagging-backend train ... --labels <comma-separated-list>
tagging-backend train ... --sample-size <N> --balancing-strategy <strategy>
tagging-backend train ... --frame-interval <I> --window-length <T>
tagging-backend train ... --pretrained-model-instance <name>
tagging-backend train ... --include-all <secondary-label>
tagging-backend train ... --skip-make-dataset --skip-build-features
tagging-backend predict ... --make-dataset --build-features --sandbox <token>
`tagging-backend` typically is run using `poetry run`.
A name must be provided to identify the trained model and its location within
the backend.
An intermediate *larva_dataset hdf5* file is generated, with series of labelled
spines. If option `--sample-size` is passed, <N> time segments are sampled from
the raw database. The total length 3*<T> of time segments is 60 per default
(20 *past* points, 20 *present* points and 20 *future* points).
If frame interval <I> is specified (in seconds), spine series are resampled and
interpolated around each time segment anchor (center).
**Deprecated**:
Option `--trxmat-only` is suitable for large databases made of trx.mat files
only. Intermediate HDF5 files are generated prior to counting the various
behavior labels and sampling time segments in the database. These intermediate
files may be kept to avoid loading the trx.mat files again.
Note that an existing larva_dataset file in data/interim/<name> makes the
`train` command skip the `make_dataset` step. However, even in such a case,
the `make_dataset` module is loaded and this may take quite some time due to
dependencies (e.g. Julia FFI). The `--skip-make-dataset` option makes `train`
truly skip this step; the corresponding module is not loaded.
Since version 0.10, `predict` makes `--skip-make-dataset` and
`--skip-build-features` the default behavior. As a counterpart, it admits
arguments `--make-dataset` and `--build-features`.
`--sandbox <token>` makes `tagging-backend` use a token instead of <name> as
directory name in data/raw, data/interim and data/processed.
This is intended to prevent conflicts on running `predict` in parallel on
multiple data files with multiple calls.
"""
    if not _print:
        return usage
    print(usage)
    return usage
def _parse_cli_args(argv):
    """Parse the options that follow the subcommand of a `tagging-backend` call.

    Args:
        argv: the command-line arguments after the subcommand
            (i.e. `sys.argv[2:]`).

    Returns:
        A dict of settings keyed by option name (leading dashes removed,
        inner dashes turned into underscores), holding defaults for the
        options that were not given. Unrecognized `--option value` pairs are
        collected in the nested dict under key "unknown_args".

    Exits with an error message if an option that expects a value is the
    last argument.
    """
    # options that take no value: option -> (setting, value to set)
    switches = {
        "--trxmat-only": ("trxmat_only", True),
        "--reuse-h5files": ("reuse_h5files", True),
        "--skip-make-dataset": ("make_dataset", False),
        "--skip-build-features": ("build_features", False),
        "--make-dataset": ("make_dataset", True),
        "--build-features": ("build_features", True),
    }
    # recognized options that are followed by a value
    with_value = {
        "--project-dir", "--model-instance", "--input-files", "--labels",
        "--sample-size", "--window-length", "--frame-interval",
        "--pretrained-model-instance", "--sandbox", "--balancing-strategy",
        "--include-all",
    }
    opts = dict(
        project_dir=None,
        model_instance=None,
        input_files=[],
        labels=[],
        sample_size=None,
        window_length=None,
        frame_interval=None,
        trxmat_only=False,
        reuse_h5files=False,
        make_dataset=None,
        build_features=None,
        pretrained_model_instance=None,
        sandbox=False,
        balancing_strategy='auto',
        include_all=None,
        unknown_args={},
    )
    k = 0
    while k < len(argv):
        option = argv[k]
        if option in switches:
            key, value = switches[option]
            opts[key] = value
        else:
            # every other option, recognized or not, consumes the next argument
            if k + 1 == len(argv):
                sys.exit(f"missing value for option '{option}'")
            k += 1
            key, value = option.lstrip('-').replace('-', '_'), argv[k]
            if option == "--input-files":
                opts[key] = value.split(',')
            elif option in with_value:
                opts[key] = value
            else:
                opts["unknown_args"][key] = value
        k += 1
    return opts


def _run_cli(argv):
    """Prepare the backend and run the scripts requested by a `tagging-backend` command line.

    Args:
        argv: the full command line, as in `sys.argv`; `argv[1]` must be the
            subcommand 'train' or 'predict'.

    Exits with an error message (after printing the usage) if the subcommand
    is missing or not recognized.
    """
    if not argv[1:]:
        help(True)
        sys.exit("too few input arguments; subcommand expected: 'train' or 'predict'")
    train_or_predict = argv[1]
    if train_or_predict not in ("train", "predict"):
        # reject anything else instead of silently treating it as 'train'
        help(True)
        sys.exit(f"unknown subcommand '{train_or_predict}'; expected: 'train' or 'predict'")
    training = train_or_predict == "train"
    opts = _parse_cli_args(argv[2:])
    unknown_args = opts["unknown_args"]

    backend = BackendExplorer(opts["project_dir"],
                              model_instance=opts["model_instance"],
                              sandbox=opts["sandbox"])
    backend.reset_data(spare_raw=True)
    sys.stderr.flush()
    sys.stdout.flush()
    for file in opts["input_files"]:
        backend.move_to_raw(file)

    # when not explicitly requested or skipped, the preprocessing steps run
    # for `train` only
    make_dataset = opts["make_dataset"]
    if make_dataset is None:
        make_dataset = training
    build_features = opts["build_features"]
    if build_features is None:
        build_features = training

    if make_dataset:
        make_dataset_kwargs = dict(labels_expected=training,
                                   balancing_strategy=opts["balancing_strategy"])
        # optional settings are forwarded only when given (non-empty)
        for key in ("labels", "sample_size", "window_length", "frame_interval"):
            if opts[key]:
                make_dataset_kwargs[key] = opts[key]
        if opts["trxmat_only"]:
            make_dataset_kwargs["trxmat_only"] = True
            if opts["reuse_h5files"]:
                make_dataset_kwargs["reuse_h5files"] = True
        elif opts["reuse_h5files"]:
            logging.info("option --reuse-h5files is ignored in the absence of --trxmat-only")
        if opts["include_all"]:
            make_dataset_kwargs["include_all"] = opts["include_all"]
        backend._run_script(backend.make_dataset, **make_dataset_kwargs)
    if build_features:
        backend._run_script(backend.build_features)
    if training:
        train_kwargs = dict(balancing_strategy=opts["balancing_strategy"])
        if opts["pretrained_model_instance"]:
            train_kwargs["pretrained_model_instance"] = opts["pretrained_model_instance"]
        backend._run_script(backend.train_model, trailing=unknown_args, **train_kwargs)
    else:
        backend._run_script(backend.predict_model, trailing=unknown_args)


def _decode_script_arg(key, val):
    """Convert the string value `val` of backend-script argument `key` into a Python value.

    "true"/"false" become booleans; `sample_size` and `window_length` become
    ints; `frame_interval` becomes a float; `labels` is split on commas;
    `pretrained_model_instance` is split on commas only if it contains one;
    any other string becomes an int or a float when it parses as one, and is
    left unchanged otherwise.
    """
    if val == "true":
        return True
    if val == "false":
        return False
    if key in ("sample_size", "window_length"):
        return int(val) if isinstance(val, str) else val
    if key == "frame_interval":
        return float(val) if isinstance(val, str) else val
    if key == "labels":
        return val.split(',') if isinstance(val, str) else val
    if key == "pretrained_model_instance":
        parts = val.split(',')
        return parts if parts[1:] else val
    if isinstance(val, str):
        for convert in (int, float):
            try:
                return convert(val)
            except ValueError:
                pass
    return val


def _decode_script_args(argv):
    """Decode the `key value key value ...` arguments of a backend script into keyword arguments.

    Raises:
        ValueError: if the arguments do not come in pairs.
    """
    if len(argv) % 2:
        raise ValueError(f"arguments must come as key-value pairs; got: {argv}")
    return {key: _decode_script_arg(key, val)
            for key, val in zip(argv[0::2], argv[1::2])}


def main(fun=None):
    """Run the `tagging-backend` command line, or a backend script.

    When `fun` is None (call from scripts/tagging-backend), `sys.argv` is
    parsed as a `tagging-backend` command line, the requested steps are run
    through `BackendExplorer._run_script`, and the process exits.

    Otherwise (call from make_dataset, build_features, train_model or
    predict_model), `sys.argv[1]` is decoded into a backend with
    `BackendExplorerDecoder`, the remaining arguments are decoded as
    key-value pairs, and `fun(backend, **kwargs)` is called.

    Args:
        fun: callable taking the decoded backend and keyword arguments, or
            None to run the command-line interface.

    Returns:
        The return value of `fun`, when `fun` is not None.
    """
    logging.basicConfig(level=logging.INFO,
                        format="%(levelname)s:%(name)s: %(message)s")
    if fun is None:
        # called by scripts/tagging-backend
        _run_cli(sys.argv)
        sys.exit()
    # called by make_dataset, build_features, train_model and predict_model
    backend = BackendExplorerDecoder().decode(sys.argv[1])
    return fun(backend, **_decode_script_args(sys.argv[2:]))