# make_dataset.py
import glob
import pathlib
import json
import sys
def _read_json(path):
    """Return the parsed content of the JSON file at `path`."""
    with open(path, "r") as f:
        return json.load(f)


def _autoencoder_config_path(backend, pretrained_model_instance, kwargs):
    """Return the path of the config file `frame_interval` is read from.

    The config of the original model instance is used if
    `original_model_instance` is in `kwargs`; otherwise the first (in sorted
    order) `*config.json` file of the pretrained model instance is used.

    Raises:
        FileNotFoundError: if no `*config.json` file exists for the pretrained
            model instance.
    """
    if 'original_model_instance' in kwargs:
        return str(backend.project_dir / 'models' / kwargs['original_model_instance'] / 'autoencoder_config.json')
    pattern = str(backend.project_dir / "pretrained_models" / pretrained_model_instance / "*config.json")
    # glob returns matches in arbitrary order; sort so that the selected file
    # does not depend on the filesystem
    candidates = sorted(glob.glob(pattern))
    if not candidates:
        # an explicit error instead of a bare IndexError on `[0]`
        raise FileNotFoundError(f"no config file found matching: {pattern}")
    return candidates[0]


def make_dataset(backend, labels_expected=False, trxmat_only=False,
                 balancing_strategy='maggotuba',
                 pretrained_model_instance='default', **kwargs):
    """Make a larva_dataset file available, reusing or generating it.

    Nothing is done unless `labels_expected` is true. Otherwise, if the raw
    data directory contains a `larva_dataset_*.hdf5` file, that file is passed
    to `backend.move_to_interim`. If there is none, a new file is generated
    with `backend.generate_dataset`.

    Before generating, missing arguments are loaded from model config files:

    * `frame_interval`, unless already in `kwargs`, from the autoencoder
      config (if that config defines it);
    * `window_length` and `labels`, if `original_model_instance` is given,
      from that instance's `autoencoder_config.json` (key `len_traj`) and
      `clf_config.json`. Any `labels` input argument is overridden.

    Args:
        backend: object exposing `raw_data_dir()`, `project_dir`,
            `move_to_interim()` and `generate_dataset()`.
        labels_expected: whether a labelled dataset is needed.
        trxmat_only: accepted but not used by this function.
        balancing_strategy: `balance=True` is passed to
            `backend.generate_dataset` only if this is the string
            'maggotuba' (case-insensitive).
        pretrained_model_instance: name of the directory under
            `pretrained_models` that is searched for a `*config.json` file.
        **kwargs: forwarded to `backend.generate_dataset`, minus
            `original_model_instance`.

    Raises:
        FileNotFoundError: if a required config file cannot be found.
    """
    # NOTE(review): the nesting below was reconstructed from a copy of this
    # file that had lost its indentation; the generation branch is assumed to
    # be the `else` of the "dataset file found" check -- confirm.
    if not labels_expected:
        return

    # sorted, so that the selected file is reproducible if several are found
    found = sorted(glob.glob(str(backend.raw_data_dir() / "larva_dataset_*.hdf5")))
    if found:
        if found[1:]:
            print("multiple larva_dataset files found")
        larva_dataset_file = pathlib.Path(found[0])
        # make the file available in data/interim/{instance}/
        print(f"moving file to interim: {larva_dataset_file}")
        backend.move_to_interim(larva_dataset_file, copy=False)
        return

    if 'frame_interval' not in kwargs:
        # load argument `frame_interval`; must come before
        # `original_model_instance` is popped from kwargs below
        config = _read_json(_autoencoder_config_path(
            backend, pretrained_model_instance, kwargs))
        try:
            frame_interval = config['frame_interval']
        except KeyError:
            # not defined in the config; let the backend decide
            pass
        else:
            kwargs['frame_interval'] = frame_interval

    if 'original_model_instance' in kwargs:
        original_instance = kwargs.pop('original_model_instance')
        model_dir = backend.project_dir / 'models' / original_instance
        # load parameter `window_length`
        config = _read_json(str(model_dir / 'autoencoder_config.json'))
        kwargs['window_length'] = int(config['len_traj'])
        # load parameter `labels`
        config = _read_json(str(model_dir / 'clf_config.json'))
        for key in ('original_behavior_labels', 'behavior_labels'):
            try:
                labels = config[key]
            except KeyError:
                pass
            else:
                # note kwargs['labels'] may be defined, but we dismiss
                # the input argument, because we need to preserve the
                # order of the labels (the class indices)
                if isinstance(labels, dict):
                    labels = labels['names']
                kwargs['labels'] = labels
                break

    print("generating a larva_dataset file...")
    # generate a larva_dataset_*.hdf5 file in data/interim/{instance}/
    balance = isinstance(balancing_strategy, str) and balancing_strategy.lower() == 'maggotuba'
    out = backend.generate_dataset(backend.raw_data_dir(),
                                   balance=balance, **kwargs)
    print(f"larva_dataset file generated: {out}")
from taggingbackends.main import main
# When executed as a script, pass `make_dataset` to the `main` function
# imported from `taggingbackends.main`.
if __name__ == "__main__":
    main(make_dataset)