import os
import sys
import logging
from taggingbackends.explorer import BackendExplorer, BackendExplorerDecoder, getlogger
    
def help(_print=False):
    msg = """
Usage:  tagging-backend [train|predict] --model-instance <name>
        tagging-backend train ... --labels <labels> --class-weights <weights>
        tagging-backend train ... --sample-size <N> --balancing-strategy <S>
        tagging-backend train ... --frame-interval <I> --window-length <T>
        tagging-backend train ... --pretrained-model-instance <name>
        tagging-backend train ... --include-all <secondary-label>
        tagging-backend train ... --skip-make-dataset --skip-build-features
        tagging-backend train ... --seed <seed>
        tagging-backend predict ... --make-dataset --build-features
        tagging-backend predict ... --sandbox <token>
        tagging-backend --help

`tagging-backend` is typically run using `poetry run`.
A name must be provided to identify the trained model and its location within
the backend.

An intermediate *larva_dataset hdf5* file is generated, containing series of
labelled spines. If option `--sample-size` is passed, <N> time segments are
sampled from the raw database.

The total length 3*<T> of time segments is 60 by default (20 *past* points,
20 *present* points and 20 *future* points), which means the length is always a
multiple of 3. This is inherited from MaggotUBA, whose *larva_dataset* interim
file format was borrowed, and may change in the future.

If frame interval <I> is specified (in seconds), spine series are resampled and
interpolated around each time segment anchor (center).

Some behavior labels can be listed for inclusion in the *larva_dataset* interim
file, with argument `--labels` as a comma-separated list of labels.
If `--labels` is defined, `--class-weights` can be passed as well to specify
the associated weights in the cost function that the training algorithm aims
to minimize. Weights are also specified as a comma-separated list of
floating-point values, one weight per label.

In addition to class penalties, the majority classes can be subsampled in
different ways. Argument `--balancing-strategy` takes either "maggotuba" or
"auto", each corresponding to a specific subsampling strategy.

--balancing-strategy maggotuba
    Denoting by n the size of the minority class, classes of size less than 10n
    are also considered minority classes and are subsampled down to size n.
    Classes of size greater than or equal to 10n are considered majority
    classes and are subsampled down to size 2n.

--balancing-strategy auto
    Denoting by n the size of the minority class, classes of size greater than
    or equal to 20n are subsampled down to 20n. Other classes are not
    subsampled. In addition, if class weights are not defined, they are set as
    the inverse of the corresponding class size.

Subsampling is done at random. However, observations bearing a specific
secondary label can be included with priority (up to the target size, if there
are too many of them), and then complemented with other, randomly picked
observations. To trigger this behavior, specify the secondary label with
argument `--include-all`.

`--sandbox <token>` makes `tagging-backend` use <token> instead of <name> as
directory name in data/raw, data/interim and data/processed.
This is intended to prevent conflicts when `predict` is run in parallel, with
one call per data file.

Note that an existing larva_dataset file in data/interim/<name> makes the
`train` command skip the `make_dataset` step. However, even in such a case,
the `make_dataset` module is loaded, which may take a significant amount of
time due to its dependencies (e.g. Julia FFI). The `--skip-make-dataset` option
makes `train` truly skip this step; the corresponding module is not loaded.

Since version 0.10, `predict` skips the `make_dataset` and `build_features`
steps by default, as if `--skip-make-dataset` and `--skip-build-features` were
passed. As a counterpart, it accepts the `--make-dataset` and
`--build-features` options to run these steps.
"""
    if _print:
        print(msg)
    return msg
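

# Illustrative sketch only: a hypothetical helper, not called by the CLI and not
# necessarily how the make_dataset step implements balancing. It computes the
# target class sizes implied by the two --balancing-strategy rules described in
# help() above. `class_sizes` is assumed to map each label to its number of
# observations. Class weights (set by "auto" when undefined) are not covered.
def _sketch_balancing_targets(class_sizes, strategy="auto"):
    n = min(class_sizes.values())  # size of the minority class
    if strategy == "maggotuba":
        # classes smaller than 10n are subsampled down to n, the others to 2n
        return {label: n if size < 10 * n else 2 * n
                for label, size in class_sizes.items()}
    elif strategy == "auto":
        # only classes of size >= 20n are subsampled, down to 20n
        return {label: min(size, 20 * n) for label, size in class_sizes.items()}
    raise ValueError(f"unknown balancing strategy: {strategy!r}")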
    
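
# Illustrative sketch only: a hypothetical helper, not called by the CLI. It
# mimics the --include-all rule described in help(): observations bearing the
# secondary label are picked first (assumed to be chosen at random when there
# are more of them than `target_size`), and the sample is then completed with
# other randomly picked observations. Observations are assumed to be
# (identifier, set_of_secondary_labels) pairs.
def _sketch_include_all(observations, target_size, secondary_label, seed=None):
    import random
    rng = random.Random(seed)
    prioritized = [o for o in observations if secondary_label in o[1]]
    others = [o for o in observations if secondary_label not in o[1]]
    if len(prioritized) >= target_size:
        # too many prioritized observations: keep only target_size of them
        return rng.sample(prioritized, target_size)
    # complement with randomly picked observations lacking the label
    rest = rng.sample(others, min(target_size - len(prioritized), len(others)))
    return prioritized + rest
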
def main(fun=None):
    logging.basicConfig(level=logging.INFO,
            format="%(levelname)s:%(name)s: %(message)s")
    if fun is None:
        # called by scripts/tagging-backend
        if not sys.argv[1:] or '--help' in sys.argv:
            help(True)
            #sys.exit("too few input arguments; subcommand expected: 'train' or 'predict'")
            return
        train_or_predict = sys.argv[1]
        if train_or_predict not in ("train", "predict"):
            sys.exit(f"unknown subcommand: {train_or_predict!r}; expected 'train' or 'predict'")
        project_dir = model_instance = None
        input_files, labels = [], []
        sample_size = window_length = frame_interval = None
        trxmat_only = reuse_h5files = False
        make_dataset = build_features = None
        pretrained_model_instance = None
        sandbox = False
        balancing_strategy = 'auto'
        include_all = None
        class_weights = None
        seed = None
        unknown_args = {}
        # parse the options that follow the subcommand
        k = 2
        while k < len(sys.argv):
            if sys.argv[k] == "--project-dir":
                k = k + 1
                project_dir = sys.argv[k]
            elif sys.argv[k] == "--model-instance":
                k = k + 1
                model_instance = sys.argv[k]
            elif sys.argv[k] == "--input-files":
                k = k + 1
                input_files = sys.argv[k].split(',')
            elif sys.argv[k] == "--labels":
                k = k + 1
                labels = sys.argv[k]
            elif sys.argv[k] == "--sample-size":
                k = k + 1
                sample_size = sys.argv[k]
            elif sys.argv[k] == "--window-length":
                k = k + 1
                window_length = sys.argv[k]
            elif sys.argv[k] == "--frame-interval":
                k = k + 1
                frame_interval = sys.argv[k]
            elif sys.argv[k] == "--trxmat-only":
                trxmat_only = True
            elif sys.argv[k] == "--reuse-h5files":
                reuse_h5files = True
            elif sys.argv[k] == "--skip-make-dataset":
                make_dataset = False
            elif sys.argv[k] == "--skip-build-features":
                build_features = False
            elif sys.argv[k] == '--make-dataset':
                make_dataset = True
            elif sys.argv[k] == '--build-features':
                build_features = True
            elif sys.argv[k] == "--pretrained-model-instance":
                k = k + 1
                pretrained_model_instance = sys.argv[k]
            elif sys.argv[k] == "--sandbox":
                k = k + 1
                sandbox = sys.argv[k]
            elif sys.argv[k] == "--class-weights":
                k = k + 1
                class_weights = sys.argv[k]
            elif sys.argv[k] == "--balancing-strategy":
                k = k + 1
                balancing_strategy = sys.argv[k]
            elif sys.argv[k] == "--include-all":
                k = k + 1
                include_all = sys.argv[k]
            elif sys.argv[k] == "--seed":
                k = k + 1
                seed = sys.argv[k]
            else:
                # unrecognized options are passed on to the train_model or
                # predict_model script, e.g. `--foo-bar 1` as foo_bar='1'
                if k + 1 >= len(sys.argv):
                    sys.exit(f"missing value for option: {sys.argv[k]}")
                unknown_args[sys.argv[k].lstrip('-').replace('-', '_')] = sys.argv[k+1]
                k = k + 1
            k = k + 1
        backend = BackendExplorer(project_dir, model_instance=model_instance,
                                  sandbox=sandbox)
        backend.reset_data(spare_raw=True)
        sys.stderr.flush()
        sys.stdout.flush()
        if input_files:
            for file in input_files:
                backend.move_to_raw(file)
        # `train` runs the make_dataset and build_features steps by default;
        # `predict` runs them only if --make-dataset/--build-features are passed
        if make_dataset is None and train_or_predict == 'train':
            make_dataset = True
        if build_features is None and train_or_predict == 'train':
            build_features = True
        if make_dataset:
            make_dataset_kwargs = dict(labels_expected=train_or_predict == "train",
                                       balancing_strategy=balancing_strategy)
            if labels:
                make_dataset_kwargs["labels"] = labels
            if sample_size:
                make_dataset_kwargs["sample_size"] = sample_size
            if window_length:
                make_dataset_kwargs["window_length"] = window_length
            if frame_interval:
                make_dataset_kwargs["frame_interval"] = frame_interval
            if trxmat_only:
                # deprecated
                make_dataset_kwargs["trxmat_only"] = True
                if reuse_h5files:
                    make_dataset_kwargs["reuse_h5files"] = True
            elif reuse_h5files:
                # deprecated
                logging.info("option --reuse-h5files is ignored in the absence of --trxmat-only")
            if pretrained_model_instance is not None:
                make_dataset_kwargs["pretrained_model_instance"] = pretrained_model_instance
            if include_all:
                make_dataset_kwargs["include_all"] = include_all
            if seed is not None:
                make_dataset_kwargs["seed"] = seed
            backend._run_script(backend.make_dataset, **make_dataset_kwargs)
        if build_features:
            backend._run_script(backend.build_features)
        if train_or_predict == "predict":
            backend._run_script(backend.predict_model, trailing=unknown_args)
        else:
            train_kwargs = dict(balancing_strategy=balancing_strategy)
            if pretrained_model_instance:
                train_kwargs["pretrained_model_instance"] = pretrained_model_instance
            if class_weights:
                train_kwargs['class_weights'] = class_weights
            if seed is not None:
                train_kwargs['seed'] = seed
            backend._run_script(backend.train_model, trailing=unknown_args, **train_kwargs)
    else:
        # called by make_dataset, build_features, train_model and predict_model
        backend = BackendExplorerDecoder().decode(sys.argv[1])
        def _decode(key, val):
            # convert a string-encoded argument value back into a Python object
            if val == "true":
                val = True
            elif val == "false":
                val = False
            elif key in ("sample_size", "window_length"):
                if isinstance(val, str):
                    val = int(val)
            elif key in ("frame_interval",):
                if isinstance(val, str):
                    val = float(val)
            elif key == "labels":
                if isinstance(val, str):
                    val = val.split(',')
            elif key == 'class_weights':
                if isinstance(val, str):
                    val = [float(s) for s in val.split(',')]
            elif key == "pretrained_model_instance":
                # a comma-separated value denotes several model instances
                val_ = val.split(',')
                if val_[1:]:
                    val = val_
            elif isinstance(val, str):
                # other values: try int, then float, else keep the string
                try:
                    val = int(val)
                except ValueError:
                    try:
                        val = float(val)
                    except ValueError:
                        pass
            return val
        # the remaining arguments come as key-value pairs
        args = {}
        for k in range(2, len(sys.argv), 2):
            args[sys.argv[k]] = _decode(sys.argv[k], sys.argv[k+1])
        return fun(backend, **args)
    sys.exit()
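

# Illustrative sketch only: a hypothetical helper, not called by the CLI. It
# shows the time segment layout described in help(): with window length T, a
# segment spans 3*T points (T *past*, T *present* and T *future* points), i.e.
# 60 points for the default T = 20. Placing the anchor at the center of the
# *present* block (rounded down) is an assumption made for this sketch.
def _sketch_segment_ranges(anchor, window_length=20):
    t = window_length
    start = anchor - t - t // 2  # first index of the *past* block
    past = range(start, start + t)
    present = range(start + t, start + 2 * t)
    future = range(start + 2 * t, start + 3 * t)
    return past, present, future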