Skip to content
Snippets Groups Projects
Commit 40419e29 authored by François LAURENT
Browse files

Merge branch 'dev' into 20230111

parents 7386df2e d48b1fe4
Branches
No related tags found
No related merge requests found
[tool.poetry]
name = "MaggotUBA-adapter"
version = "0.10"
description = "Interface between MaggotUBA and the Nyx tagging UI"
authors = ["François Laurent"]
license = "MIT"
......
import glob
import pathlib
def make_dataset(backend, labels_expected=False, trxmat_only=False, **kwargs):
def make_dataset(backend, labels_expected=False, trxmat_only=False, balancing_strategy='maggotuba', **kwargs):
if labels_expected:
larva_dataset_file = glob.glob(str(backend.raw_data_dir() / "larva_dataset_*.hdf5"))
if larva_dataset_file:
......@@ -17,7 +17,9 @@ def make_dataset(backend, labels_expected=False, trxmat_only=False, **kwargs):
if False:#trxmat_only:
out = backend.compile_trxmat_database(backend.raw_data_dir(), **kwargs)
else:
out = backend.generate_dataset(backend.raw_data_dir(), **kwargs)
out = backend.generate_dataset(backend.raw_data_dir(),
balance=isinstance(balancing_strategy, str) and balancing_strategy.lower() == 'maggotuba',
**kwargs)
print(f"larva_dataset file generated: {out}")
......
......@@ -6,7 +6,7 @@ import json
import glob
def train_model(backend, layers=1, pretrained_model_instance="default",
subsets=(1, 0, 0), rng_seed=None, **kwargs):
subsets=(1, 0, 0), rng_seed=None, balancing_strategy='maggotuba', **kwargs):
# make_dataset generated or moved the larva_dataset file into data/interim/{instance}/
#larva_dataset_file = backend.list_interim_files("larva_dataset_*.hdf5") # recursive
larva_dataset_file = glob.glob(str(backend.interim_data_dir() / "larva_dataset_*.hdf5")) # not recursive (faster)
......@@ -14,6 +14,7 @@ def train_model(backend, layers=1, pretrained_model_instance="default",
# subsets=(1, 0, 0) => all data are training data; no validation or test subsets
dataset = LarvaDataset(larva_dataset_file[0], new_generator(rng_seed),
subsets=subsets, **kwargs)
dataset.weight_classes = isinstance(balancing_strategy, str) and (balancing_strategy.lower() == 'auto')
labels = dataset.labels
assert 0 < len(labels)
labels = labels if isinstance(labels[0], str) else [s.decode() for s in labels]
......
......@@ -159,10 +159,13 @@ class MaggotTrainer:
def train(self, dataset):
self.prepare_dataset(dataset)
kwargs = {}
if dataset.class_weights is not None:
kwargs['weight'] = torch.from_numpy(dataset.class_weights.astype(np.float32)).to(self.device)
model = self.model
model.train() # this only sets the model in training mode (enables gradients)
model.to(self.device)
criterion = nn.CrossEntropyLoss()
criterion = nn.CrossEntropyLoss(**kwargs)
nsteps = self.config['optim_iter']
grad_clip = self.config['grad_clip']
# pre-train the classifier with static encoder weights
......@@ -195,7 +198,7 @@ class MaggotTrainer:
return self
def draw(self, dataset, subset="train"):
data, expected = dataset.getobs(subset)
data, expected = dataset.getbatch(subset)
if isinstance(data, list):
data = torch.stack(data)
data = data.to(torch.float32).to(self.device)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment