diff --git a/pyproject.toml b/pyproject.toml
index 53777e20537530cd2cc236d0077940209e2fd3a3..e9961fb106e3acaed7745c6a5611beafe090efcf 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "MaggotUBA-adapter"
-version = "0.9"
+version = "0.10"
 description = "Interface between MaggotUBA and the Nyx tagging UI"
 authors = ["François Laurent"]
 license = "MIT"
diff --git a/src/maggotuba/data/make_dataset.py b/src/maggotuba/data/make_dataset.py
index 50475f49d4da41e23e4360db0d20f520806b9fb3..2d0230632c2832e86dd4d55989ddc840fa60888a 100644
--- a/src/maggotuba/data/make_dataset.py
+++ b/src/maggotuba/data/make_dataset.py
@@ -1,7 +1,7 @@
 import glob
 import pathlib
 
-def make_dataset(backend, labels_expected=False, trxmat_only=False, **kwargs):
+def make_dataset(backend, labels_expected=False, trxmat_only=False, balancing_strategy='maggotuba', **kwargs):
     if labels_expected:
         larva_dataset_file = glob.glob(str(backend.raw_data_dir() / "larva_dataset_*.hdf5"))
         if larva_dataset_file:
@@ -17,7 +17,9 @@ def make_dataset(backend, labels_expected=False, trxmat_only=False, **kwargs):
             if False:#trxmat_only:
                 out = backend.compile_trxmat_database(backend.raw_data_dir(), **kwargs)
             else:
-                out = backend.generate_dataset(backend.raw_data_dir(), **kwargs)
+                out = backend.generate_dataset(backend.raw_data_dir(),
+                                               balance=isinstance(balancing_strategy, str) and balancing_strategy.lower() == 'maggotuba',
+                                               **kwargs)
             print(f"larva_dataset file generated: {out}")
 
 
diff --git a/src/maggotuba/models/train_model.py b/src/maggotuba/models/train_model.py
index 1760a46f9f645a4a5e7c24069b6e6eb39eebbb9c..f9562c524a0de3cd72eed5a1817090286b346e9a 100644
--- a/src/maggotuba/models/train_model.py
+++ b/src/maggotuba/models/train_model.py
@@ -6,7 +6,7 @@ import json
 import glob
 
 def train_model(backend, layers=1, pretrained_model_instance="default",
-                subsets=(1, 0, 0), rng_seed=None, **kwargs):
+                subsets=(1, 0, 0), rng_seed=None, balancing_strategy='maggotuba', **kwargs):
     # make_dataset generated or moved the larva_dataset file into data/interim/{instance}/
     #larva_dataset_file = backend.list_interim_files("larva_dataset_*.hdf5") # recursive
     larva_dataset_file = glob.glob(str(backend.interim_data_dir() / "larva_dataset_*.hdf5")) # not recursive (faster)
@@ -14,6 +14,7 @@ def train_model(backend, layers=1, pretrained_model_instance="default",
     # subsets=(1, 0, 0) => all data are training data; no validation or test subsets
     dataset = LarvaDataset(larva_dataset_file[0], new_generator(rng_seed),
                            subsets=subsets, **kwargs)
+    dataset.weight_classes = isinstance(balancing_strategy, str) and (balancing_strategy.lower() == 'auto')
     labels = dataset.labels
     assert 0 < len(labels)
     labels = labels if isinstance(labels[0], str) else [s.decode() for s in labels]
diff --git a/src/maggotuba/models/trainers.py b/src/maggotuba/models/trainers.py
index 19b827ded7f2eac649692003be5747ec2314c1f4..154d6f139282c0a52d452c9f9937d991afa32cae 100644
--- a/src/maggotuba/models/trainers.py
+++ b/src/maggotuba/models/trainers.py
@@ -159,10 +159,13 @@ class MaggotTrainer:
 
     def train(self, dataset):
         self.prepare_dataset(dataset)
+        kwargs = {}
+        if dataset.class_weights is not None:
+            kwargs['weight'] = torch.from_numpy(dataset.class_weights.astype(np.float32)).to(self.device)
         model = self.model
         model.train() # this only sets the model in training mode (enables gradients)
         model.to(self.device)
-        criterion = nn.CrossEntropyLoss()
+        criterion = nn.CrossEntropyLoss(**kwargs)
         nsteps = self.config['optim_iter']
         grad_clip = self.config['grad_clip']
         # pre-train the classifier with static encoder weights
@@ -195,7 +198,7 @@ class MaggotTrainer:
         return self
 
     def draw(self, dataset, subset="train"):
-        data, expected = dataset.getobs(subset)
+        data, expected = dataset.getbatch(subset)
         if isinstance(data, list):
             data = torch.stack(data)
         data = data.to(torch.float32).to(self.device)