diff --git a/src/maggotuba/data/make_dataset.py b/src/maggotuba/data/make_dataset.py index 8ff4dd4ccab970ddaceb6958e4fdeebdebac1157..174d074797c50f686181e4679e8ccd5d92eb5d5e 100644 --- a/src/maggotuba/data/make_dataset.py +++ b/src/maggotuba/data/make_dataset.py @@ -1,7 +1,7 @@ import glob import pathlib -def make_dataset(backend, labels_expected=False, **kwargs): +def make_dataset(backend, labels_expected=False, trxmat_only=False, labels=None, **kwargs): if labels_expected: larva_dataset_file = glob.glob(str(backend.raw_data_dir() / "larva_dataset_*.hdf5")) if larva_dataset_file: @@ -12,9 +12,16 @@ def make_dataset(backend, labels_expected=False, **kwargs): print(f"moving file to interim: {larva_dataset_file}") backend.move_to_interim(larva_dataset_file, copy=False) else: + if labels: + if isinstance(labels, str): + labels = labels.split(',') + kwargs["labels"] = labels print("generating a larva_dataset file...") # generate a larva_dataset_*.hdf5 file in data/interim/{instance}/ - out = backend.generate_dataset(backend.raw_data_dir(), **kwargs) + if trxmat_only: + out = backend.compile_trxmat_database(backend.raw_data_dir(), **kwargs) + else: + out = backend.generate_dataset(backend.raw_data_dir(), **kwargs) print(f"larva_dataset file generated: {out}") diff --git a/src/maggotuba/models/train_model.py b/src/maggotuba/models/train_model.py index 7a5d3a578a78d2912ec365b5a5758856c7cf4d8d..0401cd76cba4377ddff3fd0856fe30646ebc5b28 100644 --- a/src/maggotuba/models/train_model.py +++ b/src/maggotuba/models/train_model.py @@ -5,10 +5,12 @@ import numpy as np import json import torch import os +import glob def train_model(backend): # make_dataset generated or moved the larva_dataset file into data/interim/{instance}/ - larva_dataset_file = backend.list_interim_files() + #larva_dataset_file = backend.list_interim_files("larva_dataset_*.hdf5") # recursive + larva_dataset_file = glob.glob(str(backend.interim_data_dir() / "larva_dataset_*.hdf5")) # not recursive (faster) assert len(larva_dataset_file) == 1 dataset = LarvaDataset(larva_dataset_file[0], torch.Generator(device).manual_seed(42)) nlabels = len(dataset.labels)