diff --git a/src/LarvaDatasets.jl b/src/LarvaDatasets.jl
index 328675fa321eb82f51ae73e3a217f58f7ad4895d..409f925ade242420b3077d7c43727ef0d34e0ce2 100644
--- a/src/LarvaDatasets.jl
+++ b/src/LarvaDatasets.jl
@@ -295,6 +295,7 @@ function write_larva_dataset_hdf5(path, counts, files, refs, nsteps_before, nste
attributes(g)["n_samples"] = sampleid
# extension
h5["labels"] = collect(keys(counts))
+ h5["label_counts"] = collect(values(counts))
#h5["files"] = [f.source for f in files]
if !isnothing(frameinterval)
attributes(g)["frame_interval"] = frameinterval
@@ -516,6 +517,10 @@ function write_larva_dataset_hdf5(output_dir::String,
end
if !isnothing(labels)
labels′ = p -> string(p[1]) in labels
+ # Compare on string(key), exactly as the `labels′` predicate does, so that a label
+ # is reported only when no entry of `counts` matches it under that predicate.
+ missing_labels = [label for label in labels if !any(k -> string(k) == label, keys(counts))]
+ isempty(missing_labels) || @warn "No occurrences found for labels: \"$(join(missing_labels, "\", \""))\""
filter!(labels′, counts)
filter!(labels′, refs)
end
diff --git a/src/taggingbackends/data/dataset.py b/src/taggingbackends/data/dataset.py
index 3e9d2d32ca6520b0c9e848cf20bbce58bf1cd37a..9ce597029f74b3433edf874862d23b2b18caf4ec 100644
--- a/src/taggingbackends/data/dataset.py
+++ b/src/taggingbackends/data/dataset.py
@@ -230,8 +230,11 @@ class LarvaDataset:
@property
def class_weights(self):
if not isinstance(self._class_weights, np.ndarray) and self._class_weights in (None, True):
- _, class_counts = np.unique(self.training_labels, return_counts=True)
- class_counts = np.array([class_counts[i] for i in range(len(self.labels))])
+ try:
+ class_counts = np.asarray(self.full_set["label_counts"])
+ except KeyError:
+ _, class_counts = np.unique(self.training_labels, return_counts=True)
+ class_counts = np.array([class_counts[i] for i in range(len(self.labels))])
self._class_weights = 1 - class_counts / np.sum(class_counts)
return None if self._class_weights is False else self._class_weights