diff --git a/src/LarvaDatasets.jl b/src/LarvaDatasets.jl index e8ecf29445ecab112b712db76e12b23f5773bdd7..f1dcbdeeb1787fba4118fc77c16a97b387840e45 100644 --- a/src/LarvaDatasets.jl +++ b/src/LarvaDatasets.jl @@ -209,7 +209,7 @@ function labelcounts(files, timebefore::AbstractFloat, timeafter::AbstractFloat; end """ - balancedcounts(observed_counts, target_count=nothing, majority_weight=2) + balancedcounts(observed_counts, targetcount=nothing, majorityweight=2) Derive sample sizes for all labels. @@ -221,6 +221,8 @@ labels. If `target_count` is set, sample sizes are adjusted so that the total sample size equals this value. If too few occurences of a label are found, an error is thrown. + +See also [`thresholdedcounts`](@ref). """ function balancedcounts(counts, targetcount=nothing, majorityweight=2) counts = typeof(counts)(k=>count for (k, count) in pairs(counts) if 0 < count) @@ -248,6 +250,19 @@ function balancedcounts(counts, targetcount=nothing, majorityweight=2) return Dict{eltype(keys(counts)), Int}(zip(keys(counts), balancedcounts)), sum(balancedcounts) end +""" + thresholdedcounts(observed_counts; majorityweight=10) + +Derive sample sizes for all labels, capping each at `majorityweight` times the smallest +positive count. Return the per-label sizes (zero-count labels dropped) and their sum. +""" +function thresholdedcounts(counts; majorityweight=10) + counts = typeof(counts)(k=>count for (k, count) in pairs(counts) if 0 < count) + majoritythresh = minimum(values(counts)) * majorityweight + thresholdedcounts = Dict(k=>(count < majoritythresh ? 
count : majoritythresh) for (k, count) in pairs(counts)) + return thresholdedcounts, sum(values(thresholdedcounts)) +end + function write_larva_dataset_hdf5(path, counts, files, refs, nsteps_before, nsteps_after; fixmwt=false, frameinterval=nothing, ) @@ -478,7 +493,7 @@ function write_larva_dataset_hdf5(output_dir::String, repository = input_data labelledfiles(repository, chunks; selection_rule=file_filter, shallow=shallow) elseif eltype(input_data) === String - [preload(f) for f in input_data] + preload.(input_data) else input_data end @@ -507,9 +522,8 @@ function write_larva_dataset_hdf5(output_dir::String, if balance sample_sizes, total_sample_size = balancedcounts(counts, sample_size) else - sample_sizes = counts - total_sample_size = sum(values(sample_sizes)) isnothing(sample_size) || @error "Argument sample_size not supported for the specified balancing strategy" + sample_sizes, total_sample_size = thresholdedcounts(counts) end @info "Sample sizes (observed, selected):" [Symbol(label) => (count, get(sample_sizes, label, 0)) for (label, count) in pairs(counts)]... date = Dates.format(Dates.now(), "yyyy_mm_dd")