From 014b6021550966d3f28b44ea29dd788a127eed3a Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Fran=C3=A7ois=20Laurent?= <francois.laurent@posteo.net>
Date: Fri, 14 Apr 2023 18:46:22 +0200
Subject: [PATCH] preserve label order in sample_sizes; document fallback labels (fixes #24)

---
 src/LarvaDatasets.jl                | 4 ++++
 src/taggingbackends/data/dataset.py | 4 ++++
 2 files changed, 8 insertions(+)

diff --git a/src/LarvaDatasets.jl b/src/LarvaDatasets.jl
index 70d9183..7580a90 100644
--- a/src/LarvaDatasets.jl
+++ b/src/LarvaDatasets.jl
@@ -23,6 +23,7 @@ using HDF5
 using Dates
 using Statistics
 using Memoization
+using OrderedCollections
 
 export write_larva_dataset_hdf5, first_stimulus, labelcounts
 
@@ -591,6 +592,9 @@ function write_larva_dataset_hdf5(output_dir::String,
         isnothing(sample_size) || @error "Argument sample_size not supported for the specified balancing strategy"
         sample_sizes, total_sample_size = thresholdedcounts(counts)
     end
+    # rebuild sample_sizes as an OrderedDict whose entries follow the order of `labels`
+    sample_sizes = OrderedDict((label => sample_sizes[label]) for label in labels if label in keys(sample_sizes))
+    #
     @info "Sample sizes (observed, selected):" [Symbol(label) => (get(counts, label, 0), get(sample_sizes, label, 0)) for label in labels]...
     date = Dates.format(Dates.now(), "yyyy_mm_dd")
     output_file = joinpath(output_dir, "larva_dataset_$(date)_$(window_length)_$(window_length)_$(total_sample_size).hdf5")
diff --git a/src/taggingbackends/data/dataset.py b/src/taggingbackends/data/dataset.py
index 8504f27..0f8b966 100644
--- a/src/taggingbackends/data/dataset.py
+++ b/src/taggingbackends/data/dataset.py
@@ -37,6 +37,10 @@ class LarvaDataset:
         return self._full_set
     """
     *list* of *bytes*: Set of distinct labels.
+
+    If the HDF5 file has no top-level `labels` element listing the labels,
+    the following fallback labels are assumed, in this order:
+    RUN, BEND, STOP, HUNCH, BACK, ROLL.
     """
     @property
     def labels(self):
-- 
GitLab