Skip to content
Snippets Groups Projects
Select Git revision
  • d1a3ec49f83513915918bc1f0ed95daa6ab995a0
  • main default protected
  • torch2
  • torch1
  • dev protected
  • 20230311_new_default
  • 20230311
  • design protected
  • 20230129
  • 20230111
  • 20221005 protected
  • 20220418 protected
  • v0.20
  • v0.19
  • v0.18
  • v0.17
  • v0.16.4
  • v0.16.3
  • v0.16.2
  • v0.16.1
  • v0.16
  • v0.15
  • v0.14
  • v0.13
  • v0.12.4
  • v0.12.3
  • v0.12.2
  • v0.12.1
  • v0.12
  • v0.11
  • v0.10
  • v0.9.1
32 results

make_dataset.py

Blame
  • make_dataset.py 3.33 KiB
    import glob
    import pathlib
    import json
    import sys
    
    def _sorted_glob(pattern):
        """Return the paths matching *pattern* (str or path-like), sorted.

        ``glob.glob`` yields matches in a filesystem-dependent order; sorting is
        what makes "take the first match" reproducible across machines.
        """
        return sorted(glob.glob(str(pattern)))


    def _load_json(path):
        """Parse and return the JSON document stored at *path* (read as UTF-8)."""
        with open(path, "r", encoding="utf-8") as f:
            return json.load(f)


    def _frame_interval_config_path(backend, original_instance, pretrained_model_instance):
        """Return the config file from which ``frame_interval`` should be read.

        Uses ``models/<original_instance>/autoencoder_config.json`` when an
        original model instance is given, otherwise the first (sorted)
        ``*config.json`` file in ``pretrained_models/<pretrained_model_instance>/``.

        Raises:
            FileNotFoundError: the pretrained directory holds no ``*config.json``.
        """
        if original_instance is not None:
            return backend.project_dir / 'models' / original_instance / 'autoencoder_config.json'
        pretrained_dir = backend.project_dir / "pretrained_models" / pretrained_model_instance
        candidates = _sorted_glob(pretrained_dir / "*config.json")
        if not candidates:
            raise FileNotFoundError(f"no *config.json file found in {pretrained_dir}")
        return candidates[0]


    def _original_instance_params(backend, original_instance):
        """Read ``window_length`` and ``labels`` from a previously trained instance.

        ``window_length`` is ``int(len_traj)`` from ``autoencoder_config.json``
        (KeyError if absent). ``labels`` is taken from the first key present among
        ``original_behavior_labels`` and ``behavior_labels`` in ``clf_config.json``;
        a dict value contributes its ``names`` entry. If neither key exists, no
        ``labels`` entry is returned.
        """
        model_dir = backend.project_dir / 'models' / original_instance
        enc_config = _load_json(model_dir / 'autoencoder_config.json')
        params = {'window_length': int(enc_config['len_traj'])}
        clf_config = _load_json(model_dir / 'clf_config.json')
        for key in ('original_behavior_labels', 'behavior_labels'):
            if key in clf_config:
                labels = clf_config[key]
                if isinstance(labels, dict):
                    labels = labels['names']
                params['labels'] = labels
                break
        return params


    def make_dataset(backend, labels_expected=False, trxmat_only=False,
                     balancing_strategy='maggotuba',
                     pretrained_model_instance='default', **kwargs):
        """Provide a ``larva_dataset_*.hdf5`` file for *backend*; return None.

        Nothing is done unless *labels_expected* is true. Then, if the raw data
        directory already contains ``larva_dataset_*.hdf5`` files, the first one
        (in sorted order) is moved to the interim directory. Otherwise a new file
        is generated with ``backend.generate_dataset``.

        Args:
            backend: object exposing ``raw_data_dir()``, ``project_dir``,
                ``move_to_interim(path, copy=...)`` and
                ``generate_dataset(input_dir, balance=..., **kwargs)``.
            labels_expected: when false, the function is a no-op.
            trxmat_only: accepted for interface compatibility; unused here.
            balancing_strategy: ``'maggotuba'`` (case-insensitive) requests a
                balanced dataset; any other value (or a non-string) does not.
            pretrained_model_instance: subdirectory of ``pretrained_models/``
                searched for ``*config.json`` to read ``frame_interval``. It is
                used only when neither ``frame_interval`` nor
                ``original_model_instance`` is passed.
            **kwargs: forwarded to ``generate_dataset``. ``original_model_instance``
                is consumed rather than forwarded. It supplies ``frame_interval``
                (unless given), ``window_length`` and ``labels``; these labels
                override any ``labels`` argument.

        Raises:
            FileNotFoundError: no ``*config.json`` in the pretrained directory
                when one is needed.
        """
        if not labels_expected:
            return

        existing = _sorted_glob(backend.raw_data_dir() / "larva_dataset_*.hdf5")
        if existing:
            if len(existing) > 1:
                print("multiple larva_dataset files found")
            larva_dataset_file = pathlib.Path(existing[0])
            # make the file available in data/interim/{instance}/
            print(f"moving file to interim: {larva_dataset_file}")
            backend.move_to_interim(larva_dataset_file, copy=False)
            return

        # Consumed here: generate_dataset never receives this key.
        original_instance = kwargs.pop('original_model_instance', None)

        if 'frame_interval' not in kwargs:
            config = _load_json(_frame_interval_config_path(
                backend, original_instance, pretrained_model_instance))
            if 'frame_interval' in config:
                kwargs['frame_interval'] = config['frame_interval']

        if original_instance is not None:
            # Labels from the trained model replace any `labels` argument, because
            # their order defines the class indices and must be preserved.
            kwargs.update(_original_instance_params(backend, original_instance))

        print("generating a larva_dataset file...")
        # generate a larva_dataset_*.hdf5 file in data/interim/{instance}/
        balance = isinstance(balancing_strategy, str) and balancing_strategy.lower() == 'maggotuba'
        out = backend.generate_dataset(backend.raw_data_dir(), balance=balance, **kwargs)
        print(f"larva_dataset file generated: {out}")
    
    
    from taggingbackends.main import main
    
    if __name__ == "__main__":
        # Script entry point: hand make_dataset to the shared taggingbackends
        # driver (imported above as `main`), which presumably parses the command
        # line and invokes it — see taggingbackends.main to confirm.
        main(make_dataset)