Skip to content
Snippets Groups Projects
Commit 0db3975e authored by François  LAURENT's avatar François LAURENT
Browse files

sample_size support in new_write_larva_dataset_hdf5

parent 95e864e5
No related branches found
No related tags found
3 merge requests: !9 "Set of commits to be tagged 0.16", !7 "Save larva_dataset file along with model files", !6 "sample_size argument and check_larva_dataset_hdf5 function"
Pipeline #110098 passed
...@@ -604,6 +604,7 @@ function new_write_larva_dataset_hdf5(output_dir, input_data; ...@@ -604,6 +604,7 @@ function new_write_larva_dataset_hdf5(output_dir, input_data;
includeall="edited", includeall="edited",
past_future_extensions=true, past_future_extensions=true,
seed=nothing, seed=nothing,
sample_size=nothing,
kwargs...) kwargs...)
repo = if input_data isa String repo = if input_data isa String
isnothing(file_filter) ? Repository(input_data) : Repository(input_data, file_filter) isnothing(file_filter) ? Repository(input_data) : Repository(input_data, file_filter)
...@@ -642,6 +643,11 @@ function new_write_larva_dataset_hdf5(output_dir, input_data; ...@@ -642,6 +643,11 @@ function new_write_larva_dataset_hdf5(output_dir, input_data;
@error "Most likely cause: no time segments could be isolated" @error "Most likely cause: no time segments could be isolated"
rethrow() rethrow()
end end
if !isnothing(sample_size)
# TODO: move the below function into PlanarLarvae.Dataloaders
isnothing(seed) || @warn "Random number generation is not explicitly seeded with sample_size defined"
samplesize!(index, sample_size)
end
total_sample_size = length(index) total_sample_size = length(index)
classcounts, _ = Dataloaders.groupby(index.sampler.selectors, index.targetcounts) classcounts, _ = Dataloaders.groupby(index.sampler.selectors, index.targetcounts)
# #
...@@ -846,4 +852,67 @@ end ...@@ -846,4 +852,67 @@ end
runid(file) = splitpath(file.source)[end-1] runid(file) = splitpath(file.source)[end-1]
"""
    samplesize!(index, sample_size)

Scale down the per-file target counts of `index` so that they add up to
(approximately) `sample_size`, then return `index`.

If `sample_size` is not lower than `length(index)`, `index` is returned untouched.
Otherwise:

1. the per-label totals (`Dataloaders.total(index.targetcounts)`) are multiplied by
   `sample_size / length(index)` and rounded; the rounding error is compensated by
   adding or removing one unit to/from randomly chosen labels, without exceeding the
   per-label totals of `index.maxcounts`;
2. the same ratio is applied to the counts of every file, and single units are added
   to or removed from randomly chosen files until each per-label sum matches the
   total computed in step 1;
3. the resulting per-file counts are assigned to `index.targetcounts`.

Random choices are drawn with `shuffle` and `rand` without an explicit RNG argument.
`index` is only modified once all the counts have been successfully computed.

# Throws
- `ErrorException` if a per-label total cannot be matched because no file can take
  (or give away) one more unit.
"""
function samplesize!(index, sample_size)
    total_sample_size = length(index)
    sample_size < total_sample_size || return index
    ratio = sample_size / total_sample_size

    # apply `ratio` to the total counts first
    maxcounts = Dataloaders.total(index.maxcounts)
    targetcounts = Dataloaders.total(index.targetcounts)
    # materialize the labels: `shuffle` only has methods for arrays, and this also
    # avoids updating `targetcounts` while iterating over it
    labels = collect(keys(targetcounts))
    for label in labels
        targetcounts[label] = round(Int, targetcounts[label] * ratio)
    end

    # rounding may leave the grand total slightly off; compensate with single units
    # on randomly ordered labels (at most one unit per label)
    totalcount = sum(values(targetcounts))
    if totalcount < sample_size
        for label in shuffle(labels)
            if targetcounts[label] < maxcounts[label]
                targetcounts[label] += 1
                totalcount += 1
                totalcount == sample_size && break
            end
        end
    elseif sample_size < totalcount
        for label in shuffle(labels)
            if 0 < targetcounts[label]
                targetcounts[label] -= 1
                totalcount -= 1
                totalcount == sample_size && break
            end
        end
    end

    # apply `ratio` at the per-file level; `copy` is shallow, so the inner counts are
    # copied as well to leave `index` untouched until the final assignment
    targetcountsperfile = copy(index.targetcounts)
    for file in collect(keys(targetcountsperfile))
        filecounts = copy(targetcountsperfile[file])
        for label in collect(keys(filecounts))
            filecounts[label] = round(Int, filecounts[label] * ratio)
        end
        targetcountsperfile[file] = filecounts
    end

    # reconcile the per-file counts with the per-label totals computed above
    roundedtotals = Dataloaders.total(targetcountsperfile)
    for (label, targetcount) in pairs(targetcounts)
        current = roundedtotals[label]
        while current < targetcount
            # draw uniformly among the files that still have room for `label`;
            # a label missing from a file is treated as a zero count
            candidates = [f for (f, c) in pairs(targetcountsperfile)
                          if get(c, label, 0) < get(index.maxcounts[f], label, 0)]
            isempty(candidates) &&
                error("samplesize!: cannot increase the count for label $(label); " *
                      "every file is already at its maximum count")
            filecounts = targetcountsperfile[rand(candidates)]
            filecounts[label] = get(filecounts, label, 0) + 1
            current += 1
        end
        while targetcount < current
            # draw uniformly among the files that still have `label` to give away
            candidates = [f for (f, c) in pairs(targetcountsperfile)
                          if 0 < get(c, label, 0)]
            isempty(candidates) &&
                error("samplesize!: cannot decrease the count for label $(label); " *
                      "every file is already at zero")
            targetcountsperfile[rand(candidates)][label] -= 1
            current -= 1
        end
    end

    index.targetcounts = targetcountsperfile
    return index
end
end end
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment