Skip to content
Snippets Groups Projects
Commit cab06b59 authored by François  LAURENT's avatar François LAURENT
Browse files

follows TaggingBackends#8

parent 312786a5
No related branches found
No related tags found
No related merge requests found
import glob import glob
import pathlib import pathlib
def make_dataset(backend, labels_expected=False, **kwargs): def make_dataset(backend, labels_expected=False, trxmat_only=False, labels=None, **kwargs):
if labels_expected: if labels_expected:
larva_dataset_file = glob.glob(str(backend.raw_data_dir() / "larva_dataset_*.hdf5")) larva_dataset_file = glob.glob(str(backend.raw_data_dir() / "larva_dataset_*.hdf5"))
if larva_dataset_file: if larva_dataset_file:
...@@ -12,8 +12,15 @@ def make_dataset(backend, labels_expected=False, **kwargs): ...@@ -12,8 +12,15 @@ def make_dataset(backend, labels_expected=False, **kwargs):
print(f"moving file to interim: {larva_dataset_file}") print(f"moving file to interim: {larva_dataset_file}")
backend.move_to_interim(larva_dataset_file, copy=False) backend.move_to_interim(larva_dataset_file, copy=False)
else: else:
if labels:
if isinstance(labels, str):
labels = labels.split(',')
kwargs["labels"] = labels
print("generating a larva_dataset file...") print("generating a larva_dataset file...")
# generate a larva_dataset_*.hdf5 file in data/interim/{instance}/ # generate a larva_dataset_*.hdf5 file in data/interim/{instance}/
if trxmat_only:
out = backend.compile_trxmat_database(backend.raw_data_dir(), **kwargs)
else:
out = backend.generate_dataset(backend.raw_data_dir(), **kwargs) out = backend.generate_dataset(backend.raw_data_dir(), **kwargs)
print(f"larva_dataset file generated: {out}") print(f"larva_dataset file generated: {out}")
......
...@@ -5,10 +5,12 @@ import numpy as np ...@@ -5,10 +5,12 @@ import numpy as np
import json import json
import torch import torch
import os import os
import glob
def train_model(backend): def train_model(backend):
# make_dataset generated or moved the larva_dataset file into data/interim/{instance}/ # make_dataset generated or moved the larva_dataset file into data/interim/{instance}/
larva_dataset_file = backend.list_interim_files() #larva_dataset_file = backend.list_interim_files("larva_dataset_*.hdf5") # recursive
larva_dataset_file = glob.glob(str(backend.interim_data_dir() / "larva_dataset_*.hdf5")) # not recursive (faster)
assert len(larva_dataset_file) == 1 assert len(larva_dataset_file) == 1
dataset = LarvaDataset(larva_dataset_file[0], torch.Generator(device).manual_seed(42)) dataset = LarvaDataset(larva_dataset_file[0], torch.Generator(device).manual_seed(42))
nlabels = len(dataset.labels) nlabels = len(dataset.labels)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment