diff --git a/src/maggotuba/data/make_dataset.py b/src/maggotuba/data/make_dataset.py
index 50475f49d4da41e23e4360db0d20f520806b9fb3..2d0230632c2832e86dd4d55989ddc840fa60888a 100644
--- a/src/maggotuba/data/make_dataset.py
+++ b/src/maggotuba/data/make_dataset.py
@@ -1,7 +1,7 @@
 import glob
 import pathlib
 
-def make_dataset(backend, labels_expected=False, trxmat_only=False, **kwargs):
+def make_dataset(backend, labels_expected=False, trxmat_only=False, balancing_strategy='maggotuba', **kwargs):
     if labels_expected:
         larva_dataset_file = glob.glob(str(backend.raw_data_dir() / "larva_dataset_*.hdf5"))
         if larva_dataset_file:
@@ -17,7 +17,10 @@ def make_dataset(backend, labels_expected=False, trxmat_only=False, **kwargs):
         if False:#trxmat_only:
             out = backend.compile_trxmat_database(backend.raw_data_dir(), **kwargs)
         else:
-            out = backend.generate_dataset(backend.raw_data_dir(), **kwargs)
+            out = backend.generate_dataset(backend.raw_data_dir(),
+                                           # balance classes at generation time under the default 'maggotuba' strategy only
+                                           balance=isinstance(balancing_strategy, str) and balancing_strategy.lower() == 'maggotuba',
+                                           **kwargs)
         print(f"larva_dataset file generated: {out}")
diff --git a/src/maggotuba/models/train_model.py b/src/maggotuba/models/train_model.py
index 516bc0cf1d67218631ed7cf2b18cce105cf7b51d..28d80c23a96327d664a12b526cb12c9140d31936 100644
--- a/src/maggotuba/models/train_model.py
+++ b/src/maggotuba/models/train_model.py
@@ -6,7 +6,7 @@ import json
 import glob
 
 def train_model(backend, layers=1, pretrained_model_instance="default",
-                subsets=(1, 0, 0), rng_seed=None, **kwargs):
+                subsets=(1, 0, 0), rng_seed=None, balancing_strategy='maggotuba', **kwargs):
     # make_dataset generated or moved the larva_dataset file into data/interim/{instance}/
     #larva_dataset_file = backend.list_interim_files("larva_dataset_*.hdf5") # recursive
     larva_dataset_file = glob.glob(str(backend.interim_data_dir() / "larva_dataset_*.hdf5")) # not recursive (faster)
@@ -14,6 +14,7 @@ def train_model(backend, layers=1, pretrained_model_instance="default",
     # subsets=(1, 0, 0) => all data are training data; no validation or test subsets
     dataset = LarvaDataset(larva_dataset_file[0], new_generator(rng_seed),
                            subsets=subsets, **kwargs)
+    dataset.weight_classes = isinstance(balancing_strategy, str) and (balancing_strategy.lower() == 'auto')
     labels = dataset.labels
     assert 0 < len(labels)
     labels = labels if isinstance(labels[0], str) else [s.decode() for s in labels]
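Note: under the 'auto' strategy, weight_classes is set on the dataset and the trainer (trainers.py below) reads dataset.class_weights back to weight the loss; how LarvaDataset actually derives the weights is not shown in this diff. A common recipe, and only a plausible stand-in here, is inverse-frequency weighting; a minimal sketch:

    import numpy as np

    def inverse_frequency_weights(labels):
        # weight_c = N / (n_classes * count_c): rarer classes get larger
        # weights, so a weighted loss compensates for class imbalance.
        _, counts = np.unique(labels, return_counts=True)
        return (counts.sum() / (len(counts) * counts)).astype(np.float32)

    labels = ['run'] * 80 + ['cast'] * 15 + ['hunch'] * 5
    print(inverse_frequency_weights(labels))   # ~[2.22, 6.67, 0.42] for cast, hunch, run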
diff --git a/src/maggotuba/models/trainers.py b/src/maggotuba/models/trainers.py
index 19b827ded7f2eac649692003be5747ec2314c1f4..154d6f139282c0a52d452c9f9937d991afa32cae 100644
--- a/src/maggotuba/models/trainers.py
+++ b/src/maggotuba/models/trainers.py
@@ -159,10 +159,14 @@ class MaggotTrainer:
 
     def train(self, dataset):
         self.prepare_dataset(dataset)
+        kwargs = {}
+        if dataset.class_weights is not None:
+            # class weights are available when the 'auto' balancing strategy is selected
+            kwargs['weight'] = torch.from_numpy(dataset.class_weights.astype(np.float32)).to(self.device)
         model = self.model
         model.train() # training mode only (affects dropout/batch-norm layers, not gradient tracking)
         model.to(self.device)
-        criterion = nn.CrossEntropyLoss()
+        criterion = nn.CrossEntropyLoss(**kwargs)
         nsteps = self.config['optim_iter']
         grad_clip = self.config['grad_clip']
         # pre-train the classifier with static encoder weights
@@ -195,7 +199,7 @@ class MaggotTrainer:
         return self
 
     def draw(self, dataset, subset="train"):
-        data, expected = dataset.getobs(subset)
+        data, expected = dataset.getbatch(subset)
         if isinstance(data, list):
             data = torch.stack(data)
         data = data.to(torch.float32).to(self.device)
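Two notes on the trainers.py changes. First, the kwargs indirection in train() builds the criterion exactly as before when no class weights are available, and otherwise forwards them through CrossEntropyLoss's weight argument, which scales each sample's loss by the weight of its target class (the default mean reduction is then renormalized by those weights). A self-contained sketch with made-up weights:

    import numpy as np
    import torch
    import torch.nn as nn

    class_weights = np.array([2.22, 6.67, 0.42])   # e.g. inverse-frequency weights
    kwargs = {}
    if class_weights is not None:                  # mirrors the guard in train()
        kwargs['weight'] = torch.from_numpy(class_weights.astype(np.float32))
    criterion = nn.CrossEntropyLoss(**kwargs)      # unweighted if kwargs stays empty

    logits = torch.randn(4, 3)                     # batch of 4 samples, 3 classes
    target = torch.tensor([0, 2, 2, 1])
    print(criterion(logits, target))

Second, the getobs -> getbatch change in draw() appears to track a renamed accessor on the LarvaDataset side; both names denote the method that yields a (data, expected) batch for the requested subset.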