diff --git a/src/LarvaDatasets.jl b/src/LarvaDatasets.jl index 8188429153f1f25c3387aee6a23d671981116271..eef36fe5d7a8856ac59e57bad724e044a0d70d7e 100644 --- a/src/LarvaDatasets.jl +++ b/src/LarvaDatasets.jl @@ -604,6 +604,7 @@ function new_write_larva_dataset_hdf5(output_dir, input_data; includeall="edited", past_future_extensions=true, seed=nothing, + sample_size=nothing, kwargs...) repo = if input_data isa String isnothing(file_filter) ? Repository(input_data) : Repository(input_data, file_filter) @@ -642,6 +643,11 @@ function new_write_larva_dataset_hdf5(output_dir, input_data; @error "Most likely cause: no time segments could be isolated" rethrow() end + if !isnothing(sample_size) + # TODO: move the below function into PlanarLarvae.Dataloaders + isnothing(seed) || @warn "Random number generation is not explicitly seeded with sample_size defined" + samplesize!(index, sample_size) + end total_sample_size = length(index) classcounts, _ = Dataloaders.groupby(index.sampler.selectors, index.targetcounts) # @@ -846,4 +852,67 @@ end runid(file) = splitpath(file.source)[end-1] +function samplesize!(index, sample_size) + total_sample_size = length(index) + sample_size < total_sample_size || return index + + ratio = sample_size / total_sample_size + + # apply `ratio` to the total counts first + maxcounts = Dataloaders.total(index.maxcounts) + targetcounts = Dataloaders.total(index.targetcounts) + for (label, count) in pairs(targetcounts) + targetcounts[label] = round(Int, count * ratio) + end + totalcount = sum(values(targetcounts)) + if totalcount < sample_size + for label in shuffle(keys(targetcounts)) + if targetcounts[label] < maxcounts[label] + targetcounts[label] += 1 + totalcount += 1 + totalcount == sample_size && break + end + end + elseif sample_size < totalcount + for label in shuffle(keys(targetcounts)) + if 0 < targetcounts[label] + targetcounts[label] -= 1 + totalcount -= 1 + totalcount == sample_size && break + end + end + end + + # apply `ratio` at the per-file level + targetcountsperfile = copy(index.targetcounts) + for counts in values(targetcountsperfile) + for (label, count) in pairs(counts) + counts[label] = round(Int, count * ratio) + end + end + targetcounts′= Dataloaders.total(targetcountsperfile) + for (label, targetcount) in pairs(targetcounts) + count = targetcounts′[label] + while count < targetcount + file = rand(keys(targetcountsperfile)) + counts = targetcountsperfile[file] + if counts[label] < index.maxcounts[file][label] + counts[label] += 1 + count += 1 + end + end + while targetcount < count + file = rand(keys(targetcountsperfile)) + counts = targetcountsperfile[file] + if 0 < counts[label] + counts[label] -= 1 + count -= 1 + end + end + end + + index.targetcounts = targetcountsperfile + return index +end + end