From c0ce79427c827e7f7e357a53c5341a43439ad048 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Laurent?= <francois.laurent@posteo.net> Date: Wed, 8 Feb 2023 02:37:54 +0100 Subject: [PATCH] implements #21 --- src/LarvaDatasets.jl | 5 +++++ src/taggingbackends/data/dataset.py | 7 +++++-- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/src/LarvaDatasets.jl b/src/LarvaDatasets.jl index 328675f..409f925 100644 --- a/src/LarvaDatasets.jl +++ b/src/LarvaDatasets.jl @@ -295,6 +295,7 @@ function write_larva_dataset_hdf5(path, counts, files, refs, nsteps_before, nste attributes(g)["n_samples"] = sampleid # extension h5["labels"] = collect(keys(counts)) + h5["label_counts"] = collect(values(counts)) #h5["files"] = [f.source for f in files] if !isnothing(frameinterval) attributes(g)["frame_interval"] = frameinterval @@ -516,6 +517,10 @@ function write_larva_dataset_hdf5(output_dir::String, end if !isnothing(labels) labels′ = p -> string(p[1]) in labels + missing_labels = [label for label in labels if label ∉ keys(counts)] + if !isempty(missing_labels) + @warn "No occurences found for labels: \"$(join(missing_labels, "\", \""))\"" + end filter!(labels′, counts) filter!(labels′, refs) end diff --git a/src/taggingbackends/data/dataset.py b/src/taggingbackends/data/dataset.py index 3e9d2d3..9ce5970 100644 --- a/src/taggingbackends/data/dataset.py +++ b/src/taggingbackends/data/dataset.py @@ -230,8 +230,11 @@ class LarvaDataset: @property def class_weights(self): if not isinstance(self._class_weights, np.ndarray) and self._class_weights in (None, True): - _, class_counts = np.unique(self.training_labels, return_counts=True) - class_counts = np.array([class_counts[i] for i in range(len(self.labels))]) + try: + class_counts = np.asarray(self.full_set["label_counts"]) + except KeyError: + _, class_counts = np.unique(self.training_labels, return_counts=True) + class_counts = np.array([class_counts[i] for i in range(len(self.labels))]) self._class_weights = 1 - class_counts / np.sum(class_counts) return None if self._class_weights is False else self._class_weights -- GitLab