Skip to content
Snippets Groups Projects
Commit c0ce7942 authored by François  LAURENT's avatar François LAURENT
Browse files

implements #21

parent ef45018d
Branches
No related tags found
No related merge requests found
Pipeline #97340 passed
...@@ -295,6 +295,7 @@ function write_larva_dataset_hdf5(path, counts, files, refs, nsteps_before, nste ...@@ -295,6 +295,7 @@ function write_larva_dataset_hdf5(path, counts, files, refs, nsteps_before, nste
attributes(g)["n_samples"] = sampleid attributes(g)["n_samples"] = sampleid
# extension # extension
h5["labels"] = collect(keys(counts)) h5["labels"] = collect(keys(counts))
h5["label_counts"] = collect(values(counts))
#h5["files"] = [f.source for f in files] #h5["files"] = [f.source for f in files]
if !isnothing(frameinterval) if !isnothing(frameinterval)
attributes(g)["frame_interval"] = frameinterval attributes(g)["frame_interval"] = frameinterval
...@@ -516,6 +517,10 @@ function write_larva_dataset_hdf5(output_dir::String, ...@@ -516,6 +517,10 @@ function write_larva_dataset_hdf5(output_dir::String,
end end
if !isnothing(labels) if !isnothing(labels)
labels′ = p -> string(p[1]) in labels labels′ = p -> string(p[1]) in labels
missing_labels = [label for label in labels if label keys(counts)]
if !isempty(missing_labels)
@warn "No occurences found for labels: \"$(join(missing_labels, "\", \""))\""
end
filter!(labels′, counts) filter!(labels′, counts)
filter!(labels′, refs) filter!(labels′, refs)
end end
......
...@@ -230,6 +230,9 @@ class LarvaDataset: ...@@ -230,6 +230,9 @@ class LarvaDataset:
@property @property
def class_weights(self): def class_weights(self):
if not isinstance(self._class_weights, np.ndarray) and self._class_weights in (None, True): if not isinstance(self._class_weights, np.ndarray) and self._class_weights in (None, True):
try:
class_counts = np.asarray(self.full_set["label_counts"])
except KeyError:
_, class_counts = np.unique(self.training_labels, return_counts=True) _, class_counts = np.unique(self.training_labels, return_counts=True)
class_counts = np.array([class_counts[i] for i in range(len(self.labels))]) class_counts = np.array([class_counts[i] for i in range(len(self.labels))])
self._class_weights = 1 - class_counts / np.sum(class_counts) self._class_weights = 1 - class_counts / np.sum(class_counts)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment