diff --git a/src/LarvaDatasets.jl b/src/LarvaDatasets.jl
index 3e89f955b78e957450bf24264f20e5bbbbab3223..e8ecf29445ecab112b712db76e12b23f5773bdd7 100644
--- a/src/LarvaDatasets.jl
+++ b/src/LarvaDatasets.jl
@@ -418,7 +418,7 @@ end
 """
     write_larva_dataset_hdf5(output_directory, input_files, window_length=20)
     write_larva_dataset_hdf5(...; labels=nothing, labelpointers=nothing)
-    write_larva_dataset_hdf5(...; sample_size=nothing, chunks=false, shallow=false)
+    write_larva_dataset_hdf5(...; sample_size=nothing, balance=true, chunks=false, shallow=false)
     write_larva_dataset_hdf5(...; file_filter, timestep_filter)
 
 Sample series of 5-point spines from data files and save them in a hdf5 file,
@@ -451,6 +451,8 @@ See also [`labelledfiles`](@ref Main.PlanarLarvae.Formats.labelledfiles) and
 [`labelcounts`](@ref) and their `selection_rule` arguments that correspond to `file_filter`
 and `timestep_filter` respectively.
 
+`balance` enables (default) or disables MaggotUBA-like class balancing. See also [`balancedcounts`](@ref).
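+
+For example (a sketch; `files` stands for a vector of input data files and the
+output directory is a placeholder):
+
+    write_larva_dataset_hdf5("data/interim/myinstance", files; balance=false)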
+
 Note that, if `input_data` lists all files, `labelledfiles` is not called and arguments
 `file_filter`, `chunks` and `shallow` are not used.
 Similarly, if `labelpointers` is defined, `labelcounts` is not called and argument
@@ -507,6 +509,7 @@ function write_larva_dataset_hdf5(output_dir::String,
     else
         sample_sizes = counts
         total_sample_size = sum(values(sample_sizes))
+        isnothing(sample_size) || @error "Argument sample_size is not supported for the specified balancing strategy"
     end
     @info "Sample sizes (observed, selected):" [Symbol(label) => (count, get(sample_sizes, label, 0)) for (label, count) in pairs(counts)]...
     date = Dates.format(Dates.now(), "yyyy_mm_dd")
diff --git a/src/taggingbackends/data/dataset.py b/src/taggingbackends/data/dataset.py
index 9cc670b5102e27d911806c50ff1a26132ef6dcb9..3e9d2d32ca6520b0c9e848cf20bbce58bf1cd37a 100644
--- a/src/taggingbackends/data/dataset.py
+++ b/src/taggingbackends/data/dataset.py
@@ -1,6 +1,8 @@
 import h5py
 import pathlib
 import itertools
+import numpy as np
+from collections import Counter
 
 """
 Torch-like dataset class for *larva_dataset hdf5* files.
@@ -16,6 +18,9 @@ class LarvaDataset:
         self._sample_size = None
         self._mask = slice(0, None)
         self._window_length = None
+        self._class_weights = None
+        # this attribute was introduced to implement `training_labels`
+        self._alt_training_set_loader = None
     """
     *h5py.File*: *larva_dataset hdf5* file handler.
     """
@@ -101,24 +106,28 @@ class LarvaDataset:
         train, val, test = torch.utils.data.random_split(TorchDataset(),
                 [ntrain, nval, ntest],
                 generator=self.generator)
-        if 0 < ntrain:
+        if ntrain == 0:
+            self._training_set = self._alt_training_set_loader = False
+        else:
             self._training_set = iter(itertools.cycle(
                 torch.utils.data.DataLoader(train,
                     batch_size=self.batch_size,
                     shuffle=True,
                     generator=g_train,
                     drop_last=True)))
+            self._alt_training_set_loader = \
+                torch.utils.data.DataLoader(train, batch_size=ntrain)
         self._validation_set = iter(
-            torch.utils.data.DataLoader(val,
-                batch_size=self.batch_size))
+            torch.utils.data.DataLoader(val, batch_size=self.batch_size))
         self._test_set = iter(
-            torch.utils.data.DataLoader(test,
-                batch_size=self.batch_size))
+            torch.utils.data.DataLoader(test, batch_size=self.batch_size))
     """
     Iterator over the training dataset.
     """
     @property
     def training_set(self):
         if self._training_set is None:
             self.split()
+        if self._training_set is False: # no training subset; don't call split again
+            return
         return self._training_set
@@ -139,9 +148,32 @@ class LarvaDataset:
             self.split()
         return self._test_set
     """
+    This property was introduced to implement `class_weights`.
+    It does not memoize, and loads the entire training subset in a single
+    batch on every access.
+    """
+    @property
+    def training_labels(self):
+        if self._alt_training_set_loader is None:
+            self.split()
+        if self._alt_training_set_loader is False: # no training subset; don't call split again
+            return
+        _, labels = next(iter(self._alt_training_set_loader))
+        try:
+            return labels.numpy()
+        except AttributeError:
+            return labels
+    """
     Draw an observation.
+
+    Warning: this method actually drew a batch, not a single observation;
+    it now raises `NotImplementedError`. Use `getbatch` instead.
     """
     def getobs(self, subset="train"):
+        raise NotImplementedError("getobs was renamed getbatch")
+    """
+    Draw a series of observations (or batch).
+    """
+    def getbatch(self, subset="train"):
         if subset.startswith("train"):
             dataset = self.training_set
         elif subset.startswith("val"):
@@ -150,24 +182,29 @@ class LarvaDataset:
             dataset = self.test_set
         return next(dataset)
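+    # hypothetical usage: draw a single training mini-batch
+    #   data, labels = dataset.getbatch("train")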
     """
-    Draw one or more observations.
+    Draw one or more batches.
+
+    The former argument name `n` is still accepted, as a deprecated alias
+    for `nbatches`.
     """
-    def getsample(self, subset="train", n=1):
+    def getsample(self, subset="train", nbatches=1, n=None):
+        if n is not None: # `n` is a deprecated alias for `nbatches`
+            nbatches = n
         if subset.startswith("train"):
             dataset = self.training_set
         elif subset.startswith("val"):
             dataset = self.validation_set
         elif subset.startswith("test"):
             dataset = self.test_set
-        if n == "all":
-            n = len(dataset)
+        if nbatches == "all":
+            nbatches = len(dataset)
         try:
-            while 0 < n:
-                n -= 1
+            while 0 < nbatches:
+                nbatches -= 1
                 yield next(dataset)
         except StopIteration:
             pass
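+    # Hypothetical usage: iterate over every validation batch. Note that
+    # "all" is only meaningful for the validation and test sets, since the
+    # training set cycles indefinitely:
+    #   for data, labels in dataset.getsample("val", "all"):
+    #       ...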
     """
+    Alias for `getbatch`, for backward compatibility.
+    """
+    def getsample(self, subset, n):
+        return self.getbatch(subset, n)
+    """
     *int*: number of time points in a segment.
     """
     @property
@@ -178,6 +215,34 @@ class LarvaDataset:
             self._window_length = anyrecord.shape[0]
         return self._window_length
 
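+    """
+    *bool*: whether class weights should be computed; defaults to `True`.
+    """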
+    @property
+    def weight_classes(self):
+        return self._class_weights is not False
+
+    @weight_classes.setter
+    def weight_classes(self, do_weight):
+        if do_weight:
+            if self._class_weights is False:
+                self._class_weights = True
+        else:
+            self._class_weights = False
+
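+    """
+    *numpy.ndarray*: per-class weights, or `None` if `weight_classes` is
+    `False`; if not explicitly set, each weight defaults to `1 - frequency`
+    of the class in the training labels, so that minority classes weigh more.
+    """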
+    @property
+    def class_weights(self):
+        if not isinstance(self._class_weights, np.ndarray) and self._class_weights in (None, True):
+            # count each class in the training subset; `np.unique` would skip
+            # the absent classes, hence the `Counter` with a default of 0
+            # (labels are assumed to be integer-encoded, indexing `self.labels`)
+            counts = Counter(self.training_labels)
+            class_counts = np.array([counts.get(label, 0) for label in range(len(self.labels))])
+            self._class_weights = 1 - class_counts / np.sum(class_counts)
+        return None if self._class_weights is False else self._class_weights
+
+    @class_weights.setter
+    def class_weights(self, weights):
+        # identity checks; `in` would compare elementwise on arrays
+        if weights is not None and weights is not False:
+            weights = np.asarray(weights)
+            if len(weights) != len(self.labels):
+                raise ValueError("not as many weights as labels")
+        self._class_weights = weights
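+
+    # Hypothetical usage: feed the weights to a class-weighted PyTorch loss:
+    #   weights = torch.from_numpy(dataset.class_weights).float()
+    #   criterion = torch.nn.CrossEntropyLoss(weight=weights)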
+
 def subset_size(ntot, train_share, val_share, test_share):
     ntrain = int(train_share * ntot)
     nval = int(val_share * ntot)
diff --git a/src/taggingbackends/explorer.py b/src/taggingbackends/explorer.py
index 803eb72db4cdbced317f9cca00e1984db9ffc6e4..13bc5815381f358c03494e172d6d440d46b9c61a 100644
--- a/src/taggingbackends/explorer.py
+++ b/src/taggingbackends/explorer.py
@@ -473,7 +473,7 @@ run `poetry add {pkg}` from directory: \n
         return input_files, labels
 
     def generate_dataset(self, input_files,
-            labels=None, window_length=20, sample_size=None,
+            labels=None, window_length=20, sample_size=None, balance=True,
             frame_interval=None):
         """
         Generate a *larva_dataset hdf5* file in data/interim/{instance}/
@@ -484,6 +484,7 @@ run `poetry add {pkg}` from directory: \n
                 window_length,
                 labels=labels,
                 sample_size=sample_size,
+                balance=balance,
                 frameinterval=frame_interval)
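+        # hypothetical usage, e.g. from a backend's make_dataset script:
+        #   backend.generate_dataset(input_files, window_length=20, balance=False)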
 
     def compile_trxmat_database(self, input_dir,
diff --git a/src/taggingbackends/main.py b/src/taggingbackends/main.py
index e5c435ef2b97aa6c32183a700ea4c0d29a007d3b..a6c6d9c19f3f25267b24b1cde85b8e83f035a8ec 100644
--- a/src/taggingbackends/main.py
+++ b/src/taggingbackends/main.py
@@ -7,7 +7,7 @@ def help(_print=False):
     msg = """
 Usage:  tagging-backend [train|predict] --model-instance <name>
         tagging-backend train ... --labels <comma-separated-list>
-        tagging-backend train ... --sample-size <N>
+        tagging-backend train ... --sample-size <N> --balancing-strategy <strategy>
         tagging-backend train ... --frame-interval <I> --window-length <T>
         tagging-backend train ... --pretrained-model-instance <name>
         tagging-backend predict ... --skip-make-dataset --sandbox <token>
@@ -62,6 +62,7 @@ def main(fun=None):
         skip_make_dataset = skip_build_features = False
         pretrained_model_instance = None
         sandbox = False
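+        # the balancing strategy is forwarded to both the make_dataset and
+        # the train scripts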
+        balancing_strategy = 'auto'
         unknown_args = {}
         k = 2
         while k < len(sys.argv):
@@ -100,6 +101,9 @@ def main(fun=None):
             elif sys.argv[k] == "--sandbox":
                 k = k + 1
                 sandbox = sys.argv[k]
+            elif sys.argv[k] == "--balancing-strategy":
+                k = k + 1
+                balancing_strategy = sys.argv[k]
             else:
                 unknown_args[sys.argv[k].lstrip('-').replace('-', '_')] = sys.argv[k+1]
                 k = k + 1
@@ -113,7 +117,8 @@ def main(fun=None):
             for file in input_files:
                 backend.move_to_raw(file)
         if not skip_make_dataset:
-            make_dataset_kwargs = dict(labels_expected=train_or_predict == "train")
+            make_dataset_kwargs = dict(labels_expected=train_or_predict == "train",
+                                       balancing_strategy=balancing_strategy)
             if labels:
                 make_dataset_kwargs["labels"] = labels
             if sample_size:
@@ -134,7 +139,7 @@ def main(fun=None):
         if train_or_predict == "predict":
             backend._run_script(backend.predict_model, trailing=unknown_args)
         else:
-            train_kwargs = {}
+            train_kwargs = dict(balancing_strategy=balancing_strategy)
             if pretrained_model_instance:
                 train_kwargs["pretrained_model_instance"] = pretrained_model_instance
             backend._run_script(backend.train_model, trailing=unknown_args, **train_kwargs)