diff --git a/src/LarvaDatasets.jl b/src/LarvaDatasets.jl
index 328675fa321eb82f51ae73e3a217f58f7ad4895d..409f925ade242420b3077d7c43727ef0d34e0ce2 100644
--- a/src/LarvaDatasets.jl
+++ b/src/LarvaDatasets.jl
@@ -295,6 +295,7 @@ function write_larva_dataset_hdf5(path, counts, files, refs, nsteps_before, nste
         attributes(g)["n_samples"] = sampleid
         # extension
         h5["labels"] = collect(keys(counts))
+        h5["label_counts"] = collect(values(counts))
         #h5["files"] = [f.source for f in files]
         if !isnothing(frameinterval)
             attributes(g)["frame_interval"] = frameinterval
@@ -516,6 +517,10 @@ function write_larva_dataset_hdf5(output_dir::String,
     end
     if !isnothing(labels)
         labels′ = p -> string(p[1]) in labels
+        missing_labels = [label for label in labels if label ∉ keys(counts)]
+        if !isempty(missing_labels)
+            @warn "No occurences found for labels: \"$(join(missing_labels, "\", \""))\""
+        end
         filter!(labels′, counts)
         filter!(labels′, refs)
     end
diff --git a/src/taggingbackends/data/dataset.py b/src/taggingbackends/data/dataset.py
index 3e9d2d32ca6520b0c9e848cf20bbce58bf1cd37a..9ce597029f74b3433edf874862d23b2b18caf4ec 100644
--- a/src/taggingbackends/data/dataset.py
+++ b/src/taggingbackends/data/dataset.py
@@ -230,8 +230,11 @@ class LarvaDataset:
     @property
     def class_weights(self):
         if not isinstance(self._class_weights, np.ndarray) and self._class_weights in (None, True):
-            _, class_counts = np.unique(self.training_labels, return_counts=True)
-            class_counts = np.array([class_counts[i] for i in range(len(self.labels))])
+            try:
+                class_counts = np.asarray(self.full_set["label_counts"])
+            except KeyError:
+                _, class_counts = np.unique(self.training_labels, return_counts=True)
+                class_counts = np.array([class_counts[i] for i in range(len(self.labels))])
             self._class_weights = 1 - class_counts / np.sum(class_counts)
         return None if self._class_weights is False else self._class_weights