Skip to content
Snippets Groups Projects
Select Git revision
  • 97a359a6012cca00e6a27d492157e6980eda3d3b
  • master default protected
2 results

Dockerfile

Blame
  • make_dataset.py 1.90 KiB
    import glob
    import json
    import pathlib
    
    def make_dataset(backend, labels_expected=False, trxmat_only=False,
                     balancing_strategy='maggotuba',
                     pretrained_model_instance='default', **kwargs):
        """Make a larva_dataset file available in the interim data directory.

        Nothing is done unless `labels_expected` is true. In that case, if a
        ``larva_dataset_*.hdf5`` file is found in the raw data directory, it is
        handed to ``backend.move_to_interim``; otherwise a new file is generated
        with ``backend.generate_dataset``.

        Args:
            backend: object providing ``raw_data_dir()``, ``project_dir``,
                ``move_to_interim()`` and ``generate_dataset()``.
            labels_expected: if false, the function returns immediately.
            trxmat_only: accepted for interface compatibility but currently
                ignored (see the note in the body).
            balancing_strategy: class balancing is requested from
                ``generate_dataset`` only if this is the string 'maggotuba'
                (case-insensitive).
            pretrained_model_instance: name of the subdirectory of
                ``pretrained_models`` searched for a ``*config.json`` file that
                may define ``frame_interval``.
            **kwargs: forwarded to ``backend.generate_dataset``. An explicit
                ``frame_interval`` takes precedence over the one found in the
                pretrained model's config file.

        Returns:
            None.
        """
        if not labels_expected:
            return

        dataset_files = glob.glob(str(backend.raw_data_dir() / "larva_dataset_*.hdf5"))
        if dataset_files:
            if len(dataset_files) > 1:
                print("multiple larva_dataset files found")
            larva_dataset_file = pathlib.Path(dataset_files[0])
            # make the file available in data/interim/{instance}/
            print(f"moving file to interim: {larva_dataset_file}")
            backend.move_to_interim(larva_dataset_file, copy=False)
            return

        if 'frame_interval' not in kwargs:
            # fall back to the frame interval the pretrained model was trained with
            config_files = glob.glob(str(
                backend.project_dir / "pretrained_models"
                / pretrained_model_instance / "*config.json"))
            if config_files:
                with open(config_files[0], "r") as f:
                    config = json.load(f)
                if 'frame_interval' in config:
                    kwargs['frame_interval'] = config['frame_interval']
            else:
                # frame_interval is optional; a missing config file must not
                # abort the dataset generation
                print("no config file found for pretrained model: "
                      f"{pretrained_model_instance}")

        print("generating a larva_dataset file...")
        # generate a larva_dataset_*.hdf5 file in data/interim/{instance}/
        # NOTE: the `trxmat_only` path (backend.compile_trxmat_database) was
        # hard-disabled in the original code; generate_dataset is always used.
        balance = (isinstance(balancing_strategy, str)
                   and balancing_strategy.lower() == 'maggotuba')
        out = backend.generate_dataset(backend.raw_data_dir(),
                                       balance=balance,
                                       **kwargs)
        print(f"larva_dataset file generated: {out}")
    
    
    from taggingbackends.main import main
    
    # Script entry point: hand make_dataset over to taggingbackends' main
    # (presumably main builds the backend and the keyword arguments from the
    # command line -- TODO confirm against taggingbackends.main).
    if __name__ == "__main__":
        main(make_dataset)