From 014b6021550966d3f28b44ea29dd788a127eed3a Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Fran=C3=A7ois=20Laurent?= <francois.laurent@posteo.net>
Date: Fri, 14 Apr 2023 18:46:22 +0200
Subject: [PATCH] fixes #24
---
src/LarvaDatasets.jl | 4 ++++
src/taggingbackends/data/dataset.py | 4 ++++
2 files changed, 8 insertions(+)
diff --git a/src/LarvaDatasets.jl b/src/LarvaDatasets.jl
index 70d9183..7580a90 100644
--- a/src/LarvaDatasets.jl
+++ b/src/LarvaDatasets.jl
@@ -23,6 +23,7 @@ using HDF5
using Dates
using Statistics
using Memoization
+using OrderedCollections
export write_larva_dataset_hdf5, first_stimulus, labelcounts
@@ -591,6 +592,9 @@ function write_larva_dataset_hdf5(output_dir::String,
isnothing(sample_size) || @error "Argument sample_size not supported for the specified balancing strategy"
sample_sizes, total_sample_size = thresholdedcounts(counts)
end
+ # ensure label order is preserved
+ sample_sizes = OrderedDict((label => sample_sizes[label]) for label in labels if label in keys(sample_sizes))
+ #
@info "Sample sizes (observed, selected):" [Symbol(label) => (get(counts, label, 0), get(sample_sizes, label, 0)) for label in labels]...
date = Dates.format(Dates.now(), "yyyy_mm_dd")
output_file = joinpath(output_dir, "larva_dataset_$(date)_$(window_length)_$(window_length)_$(total_sample_size).hdf5")
diff --git a/src/taggingbackends/data/dataset.py b/src/taggingbackends/data/dataset.py
index 8504f27..0f8b966 100644
--- a/src/taggingbackends/data/dataset.py
+++ b/src/taggingbackends/data/dataset.py
@@ -37,6 +37,10 @@ class LarvaDataset:
return self._full_set
"""
*list* of *bytes*: Set of distinct labels.
+
+ If the hdf5 file does not feature a top-level `labels` element that lists
+ the labels, the fallback labels and their order are:
+ RUN, BEND, STOP, HUNCH, BACK, ROLL.
"""
@property
def labels(self):
--
GitLab