Skip to content
Snippets Groups Projects
Commit 377eb7de authored by François  LAURENT's avatar François LAURENT
Browse files

delegate to taggingbackends

parent 64b23de4
No related branches found
No related tags found
No related merge requests found
...@@ -6,8 +6,7 @@ import json ...@@ -6,8 +6,7 @@ import json
import glob import glob
def train_model(backend, layers=1, pretrained_model_instance="default", def train_model(backend, layers=1, pretrained_model_instance="default",
subsets=(1, 0, 0), rng_seed=None, iterations=1000, subsets=(1, 0, 0), rng_seed=None, iterations=1000, **kwargs):
balancing_strategy='maggotuba', **kwargs):
# make_dataset generated or moved the larva_dataset file into data/interim/{instance}/ # make_dataset generated or moved the larva_dataset file into data/interim/{instance}/
#larva_dataset_file = backend.list_interim_files("larva_dataset_*.hdf5") # recursive #larva_dataset_file = backend.list_interim_files("larva_dataset_*.hdf5") # recursive
larva_dataset_file = glob.glob(str(backend.interim_data_dir() / "larva_dataset_*.hdf5")) # not recursive (faster) larva_dataset_file = glob.glob(str(backend.interim_data_dir() / "larva_dataset_*.hdf5")) # not recursive (faster)
...@@ -15,7 +14,6 @@ def train_model(backend, layers=1, pretrained_model_instance="default", ...@@ -15,7 +14,6 @@ def train_model(backend, layers=1, pretrained_model_instance="default",
# subsets=(1, 0, 0) => all data are training data; no validation or test subsets # subsets=(1, 0, 0) => all data are training data; no validation or test subsets
dataset = LarvaDataset(larva_dataset_file[0], new_generator(rng_seed), dataset = LarvaDataset(larva_dataset_file[0], new_generator(rng_seed),
subsets=subsets, **kwargs) subsets=subsets, **kwargs)
dataset.weight_classes = isinstance(balancing_strategy, str) and (balancing_strategy.lower() == 'auto')
labels = dataset.labels labels = dataset.labels
assert 0 < len(labels) assert 0 < len(labels)
labels = labels if isinstance(labels[0], str) else [s.decode() for s in labels] labels = labels if isinstance(labels[0], str) else [s.decode() for s in labels]
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment