Skip to content
Snippets Groups Projects
Commit 64099acc authored by François  LAURENT's avatar François LAURENT
Browse files

upper bound on majority classes in 'auto' balancing strategy

parent 18a95b35
No related branches found
No related tags found
No related merge requests found
Pipeline #96636 passed
......@@ -209,7 +209,7 @@ function labelcounts(files, timebefore::AbstractFloat, timeafter::AbstractFloat;
end
"""
balancedcounts(observed_counts, target_count=nothing, majority_weight=2)
balancedcounts(observed_counts, targetcount=nothing, majorityweight=2)
Derive sample sizes for all labels.
......@@ -221,6 +221,8 @@ labels.
If `targetcount` is set, sample sizes are adjusted so that the total sample size equals this
value. If too few occurrences of a label are found, an error is thrown.
See also [`thresholdedcounts`](@ref).
"""
function balancedcounts(counts, targetcount=nothing, majorityweight=2)
counts = typeof(counts)(k=>count for (k, count) in pairs(counts) if 0 < count)
......@@ -248,6 +250,19 @@ function balancedcounts(counts, targetcount=nothing, majorityweight=2)
return Dict{eltype(keys(counts)), Int}(zip(keys(counts), balancedcounts)), sum(balancedcounts)
end
"""
    thresholdedcounts(observed_counts; majorityweight=10)

Derive sample sizes for all labels, with an upper bound set as `majorityweight` times
the cardinality of the least represented class.

Labels whose observed count is not strictly positive are discarded.

# Keywords
- `majorityweight`: multiplier applied to the smallest observed count to obtain the
  upper bound; the bound is rounded down to an integer.

# Returns
A tuple `(sample_sizes, total)`, where `sample_sizes` maps each retained label to
`min(count, bound)` and `total` is the sum of these sample sizes. If no label has a
positive count, `sample_sizes` is empty and `total` is `0`.
"""
function thresholdedcounts(counts; majorityweight=10)
    counts = typeof(counts)(k => count for (k, count) in pairs(counts) if 0 < count)
    # `minimum` throws on an empty collection; with no observed label there is nothing
    # to sample from
    if isempty(counts)
        return (Dict{keytype(counts),Int}(), 0)
    end
    # sample sizes are numbers of occurrences, hence the integer bound even if
    # `majorityweight` is fractional
    majoritythresh = floor(Int, minimum(values(counts)) * majorityweight)
    samplesizes = Dict(k => min(count, majoritythresh) for (k, count) in pairs(counts))
    return samplesizes, sum(values(samplesizes))
end
function write_larva_dataset_hdf5(path, counts, files, refs, nsteps_before, nsteps_after;
fixmwt=false, frameinterval=nothing,
)
......@@ -478,7 +493,7 @@ function write_larva_dataset_hdf5(output_dir::String,
repository = input_data
labelledfiles(repository, chunks; selection_rule=file_filter, shallow=shallow)
elseif eltype(input_data) === String
[preload(f) for f in input_data]
preload.(input_data)
else
input_data
end
......@@ -507,9 +522,8 @@ function write_larva_dataset_hdf5(output_dir::String,
if balance
sample_sizes, total_sample_size = balancedcounts(counts, sample_size)
else
sample_sizes = counts
total_sample_size = sum(values(sample_sizes))
isnothing(sample_size) || @error "Argument sample_size not supported for the specified balancing strategy"
sample_sizes, total_sample_size = thresholdedcounts(counts)
end
@info "Sample sizes (observed, selected):" [Symbol(label) => (count, get(sample_sizes, label, 0)) for (label, count) in pairs(counts)]...
date = Dates.format(Dates.now(), "yyyy_mm_dd")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment