From c0ce79427c827e7f7e357a53c5341a43439ad048 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Fran=C3=A7ois=20Laurent?= <francois.laurent@posteo.net>
Date: Wed, 8 Feb 2023 02:37:54 +0100
Subject: [PATCH] implements #21

---
 src/LarvaDatasets.jl                | 5 +++++
 src/taggingbackends/data/dataset.py | 7 +++++--
 2 files changed, 10 insertions(+), 2 deletions(-)

diff --git a/src/LarvaDatasets.jl b/src/LarvaDatasets.jl
index 328675f..409f925 100644
--- a/src/LarvaDatasets.jl
+++ b/src/LarvaDatasets.jl
@@ -295,6 +295,7 @@ function write_larva_dataset_hdf5(path, counts, files, refs, nsteps_before, nste
         attributes(g)["n_samples"] = sampleid
         # extension
         h5["labels"] = collect(keys(counts))
+        h5["label_counts"] = collect(values(counts))  # index-aligned with h5["labels"]
         #h5["files"] = [f.source for f in files]
         if !isnothing(frameinterval)
             attributes(g)["frame_interval"] = frameinterval
@@ -516,6 +517,10 @@ function write_larva_dataset_hdf5(output_dir::String,
     end
     if !isnothing(labels)
         labels′ = p -> string(p[1]) in labels
+        known = Set(string(k) for k in keys(counts))  # string keys, as in labels′
+        missing_labels = [label for label in labels if label ∉ known]
+        isempty(missing_labels) ||
+            @warn "No occurrences found for labels: \"$(join(missing_labels, "\", \""))\""
         filter!(labels′, counts)
         filter!(labels′, refs)
     end
diff --git a/src/taggingbackends/data/dataset.py b/src/taggingbackends/data/dataset.py
index 3e9d2d3..9ce5970 100644
--- a/src/taggingbackends/data/dataset.py
+++ b/src/taggingbackends/data/dataset.py
@@ -230,8 +230,11 @@ class LarvaDataset:
     @property
     def class_weights(self):
         if not isinstance(self._class_weights, np.ndarray) and self._class_weights in (None, True):
-            _, class_counts = np.unique(self.training_labels, return_counts=True)
-            class_counts = np.array([class_counts[i] for i in range(len(self.labels))])
+            try:
+                class_counts = np.asarray(self.full_set["label_counts"])  # assumed ordered like self.labels — TODO confirm
+            except KeyError:  # no "label_counts" entry: count the training labels instead
+                _, class_counts = np.unique(self.training_labels, return_counts=True)
+                class_counts = np.array([class_counts[i] for i in range(len(self.labels))])
             self._class_weights = 1 - class_counts / np.sum(class_counts)
         return None if self._class_weights is False else self._class_weights
 
-- 
GitLab