diff --git a/src/LarvaDatasets.jl b/src/LarvaDatasets.jl index 328675fa321eb82f51ae73e3a217f58f7ad4895d..409f925ade242420b3077d7c43727ef0d34e0ce2 100644 --- a/src/LarvaDatasets.jl +++ b/src/LarvaDatasets.jl @@ -295,6 +295,7 @@ function write_larva_dataset_hdf5(path, counts, files, refs, nsteps_before, nste attributes(g)["n_samples"] = sampleid # extension h5["labels"] = collect(keys(counts)) + h5["label_counts"] = collect(values(counts)) #h5["files"] = [f.source for f in files] if !isnothing(frameinterval) attributes(g)["frame_interval"] = frameinterval @@ -516,6 +517,10 @@ function write_larva_dataset_hdf5(output_dir::String, end if !isnothing(labels) labels′ = p -> string(p[1]) in labels + missing_labels = [label for label in labels if label ∉ keys(counts)] + if !isempty(missing_labels) + @warn "No occurences found for labels: \"$(join(missing_labels, "\", \""))\"" + end filter!(labels′, counts) filter!(labels′, refs) end diff --git a/src/taggingbackends/data/dataset.py b/src/taggingbackends/data/dataset.py index 3e9d2d32ca6520b0c9e848cf20bbce58bf1cd37a..9ce597029f74b3433edf874862d23b2b18caf4ec 100644 --- a/src/taggingbackends/data/dataset.py +++ b/src/taggingbackends/data/dataset.py @@ -230,8 +230,11 @@ class LarvaDataset: @property def class_weights(self): if not isinstance(self._class_weights, np.ndarray) and self._class_weights in (None, True): - _, class_counts = np.unique(self.training_labels, return_counts=True) - class_counts = np.array([class_counts[i] for i in range(len(self.labels))]) + try: + class_counts = np.asarray(self.full_set["label_counts"]) + except KeyError: + _, class_counts = np.unique(self.training_labels, return_counts=True) + class_counts = np.array([class_counts[i] for i in range(len(self.labels))]) self._class_weights = 1 - class_counts / np.sum(class_counts) return None if self._class_weights is False else self._class_weights