    import os
    import sys
    import logging
    from taggingbackends.explorer import BackendExplorer, BackendExplorerDecoder
    
    def help(_print=False):
        msg = """
    Usage:  tagging-backend [train|predict] --model-instance <name>
            tagging-backend train ... --labels <comma-separated-list>
            tagging-backend train ... --sample-size <N> --balancing-strategy <strategy>
            tagging-backend train ... --frame-interval <I> --window-length <T>
            tagging-backend train ... --pretrained-model-instance <name>
            tagging-backend train ... --include-all <secondary-label>
            tagging-backend train ... --skip-make-dataset --skip-build-features
            tagging-backend predict ... --make-dataset --build-features --sandbox <token>
    
    `tagging-backend` is typically run using `poetry run`.
    A name must be provided to identify the trained model and its location within
    the backend.
    
    An intermediate *larva_dataset hdf5* file is generated, with series of
    labelled spines. If option `--sample-size` is passed, <N> time segments are
    sampled from the raw database. Time segments are 3*<T> points long, i.e. 60
    points by default (20 *past* points, 20 *present* points and 20 *future*
    points).
    
    If frame interval <I> is specified (in seconds), spine series are resampled and
    interpolated around each time segment anchor (center).
    
    **Deprecated**:
    Option `--trxmat-only` is suitable for large databases made of trx.mat files
    only. Intermediate HDF5 files are generated prior to counting the various
    behavior labels and sampling time segments in the database. These intermediate
    files may be kept to avoid loading the trx.mat files again.
    
    Note that an existing larva_dataset file in data/interim/<name> makes the
    `train` command skip the `make_dataset` step. However, even in such a case,
    the `make_dataset` module is loaded and this may take quite some time due to
    dependencies (e.g. Julia FFI). The `--skip-make-dataset` option makes `train`
    truly skip this step; the corresponding module is not loaded.
    
    Since version 0.10, `--skip-make-dataset` and `--skip-build-features` are
    the default behavior for `predict`. Conversely, `predict` admits the
    arguments `--make-dataset` and `--build-features` to run these steps anyway.
    
    `--sandbox <token>` makes `tagging-backend` use <token> instead of <name> as
    the directory name in data/raw, data/interim and data/processed.
    This is intended to prevent conflicts when running multiple `predict` jobs
    in parallel on different data files.
    """
        if _print:
            print(msg)
        return msg
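
    # Example invocations (hypothetical values; see the usage text above):
    #   poetry run tagging-backend train --model-instance <name> \
    #       --labels run,cast --sample-size 1000
    #   poetry run tagging-backend predict --model-instance <name> --sandbox <token>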
    
    def main(fun=None):
        logging.basicConfig(level=logging.INFO,
                format="%(levelname)s:%(name)s: %(message)s")
        if fun is None:
            # called by scripts/tagging-backend
            if not sys.argv[1:]:
                help(True)
                sys.exit("too few input arguments; subcommand expected: 'train' or 'predict'")
            train_or_predict = sys.argv[1]
            project_dir = model_instance = None
            input_files, labels = [], []
            sample_size = window_length = frame_interval = None
            trxmat_only = reuse_h5files = False
            make_dataset = build_features = None
            pretrained_model_instance = None
            sandbox = False
            balancing_strategy = 'auto'
            include_all = None
            unknown_args = {}
            k = 2
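            # hand-rolled option parsing: sys.argv[1] is the subcommand, so
            # options start at index 2; options that take a value also consume
            # the next token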
            while k < len(sys.argv):
                if sys.argv[k] == "--project-dir":
                    k = k + 1
                    project_dir = sys.argv[k]
                elif sys.argv[k] == "--model-instance":
                    k = k + 1
                    model_instance = sys.argv[k]
                elif sys.argv[k] == "--input-files":
                    k = k + 1
                    input_files = sys.argv[k].split(',')
                elif sys.argv[k] == "--labels":
                    k = k + 1
                    labels = sys.argv[k]
                elif sys.argv[k] == "--sample-size":
                    k = k + 1
                    sample_size = sys.argv[k]
                elif sys.argv[k] == "--window-length":
                    k = k + 1
                    window_length = sys.argv[k]
                elif sys.argv[k] == "--frame-interval":
                    k = k + 1
                    frame_interval = sys.argv[k]
                elif sys.argv[k] == "--trxmat-only":
                    trxmat_only = True
                elif sys.argv[k] == "--reuse-h5files":
                    reuse_h5files = True
                elif sys.argv[k] == "--skip-make-dataset":
                    make_dataset = False
                elif sys.argv[k] == "--skip-build-features":
                    build_features = False
                elif sys.argv[k] == '--make-dataset':
                    make_dataset = True
                elif sys.argv[k] == '--build-features':
                    build_features = True
                elif sys.argv[k] == "--pretrained-model-instance":
                    k = k + 1
                    pretrained_model_instance = sys.argv[k]
                elif sys.argv[k] == "--sandbox":
                    k = k + 1
                    sandbox = sys.argv[k]
                elif sys.argv[k] == "--balancing-strategy":
                    k = k + 1
                    balancing_strategy = sys.argv[k]
                elif sys.argv[k] == "--include-all":
                    k = k + 1
                    include_all = sys.argv[k]
                else:
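                    # any unrecognized option is assumed to take exactly one
                    # value; leading dashes are stripped and inner dashes mapped
                    # to underscores so the key can be passed on as a keyword
                    # argument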
                    unknown_args[sys.argv[k].lstrip('-').replace('-', '_')] = sys.argv[k+1]
                    k = k + 1
                k = k + 1
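            # with --sandbox, <token> replaces <name> as the directory name
            # under data/raw, data/interim and data/processed (see help text)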
            backend = BackendExplorer(project_dir, model_instance=model_instance,
                                      sandbox=sandbox)
            backend.reset_data(spare_raw=True)
            sys.stderr.flush()
            sys.stdout.flush()
            if input_files:
                for file in input_files:
                    backend.move_to_raw(file)
            if make_dataset is None and train_or_predict == 'train':
                make_dataset = True
            if build_features is None and train_or_predict == 'train':
                build_features = True
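            # `train` runs make_dataset and build_features by default (unless
            # --skip-* is passed); since version 0.10, `predict` skips both
            # steps unless --make-dataset/--build-features are passed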
            if make_dataset:
                make_dataset_kwargs = dict(labels_expected=train_or_predict == "train",
                                           balancing_strategy=balancing_strategy)
                if labels:
                    make_dataset_kwargs["labels"] = labels
                if sample_size:
                    make_dataset_kwargs["sample_size"] = sample_size
                if window_length:
                    make_dataset_kwargs["window_length"] = window_length
                if frame_interval:
                    make_dataset_kwargs["frame_interval"] = frame_interval
                if trxmat_only:
                    make_dataset_kwargs["trxmat_only"] = True
                    if reuse_h5files:
                        make_dataset_kwargs["reuse_h5files"] = True
                elif reuse_h5files:
                    logging.info("option --reuse-h5files is ignored in the absence of --trxmat-only")
                if include_all:
                    make_dataset_kwargs["include_all"] = include_all
                backend._run_script(backend.make_dataset, **make_dataset_kwargs)
            if build_features:
                backend._run_script(backend.build_features)
            if train_or_predict == "predict":
                backend._run_script(backend.predict_model, trailing=unknown_args)
            else:
                train_kwargs = dict(balancing_strategy=balancing_strategy)
                if pretrained_model_instance:
                    train_kwargs["pretrained_model_instance"] = pretrained_model_instance
                backend._run_script(backend.train_model, trailing=unknown_args, **train_kwargs)
        else:
            # called by make_dataset, build_features, train_model and predict_model
            backend = BackendExplorerDecoder().decode(sys.argv[1])
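            # the remaining arguments come as alternating key/value pairs;
            # values are plain strings that _decode below converts back to
            # Python types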
            def _decode(key, val):
                if val == "true":
                    val = True
                elif val == "false":
                    val = False
                elif key in ("sample_size", "window_length"):
                    if isinstance(val, str):
                        val = int(val)
                elif key in ("frame_interval",):
                    if isinstance(val, str):
                        val = float(val)
                elif key == "labels":
                    if isinstance(val, str):
                        val = val.split(',')
                elif key == "pretrained_model_instance":
                    val_ = val.split(',')
                    if val_[1:]:
                        val = val_
                elif isinstance(val, str):
                    try:
                        val = int(val)
                    except ValueError:
                        try:
                            val = float(val)
                        except ValueError:
                            pass
                return val
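            # for illustration (hypothetical values): _decode("sample_size",
            # "100") returns 100, _decode("labels", "run,cast") returns
            # ['run', 'cast'], and _decode("frame_interval", "0.1") returns 0.1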
            args = {}
            for k in range(2, len(sys.argv), 2):
                args[sys.argv[k]] = _decode(sys.argv[k], sys.argv[k+1])
            return fun(backend, **args)
        sys.exit()