diff --git a/Manifest.toml b/Manifest.toml index d1dbb68891c08febfe490dc5efaa434e0895aee2..0c46941449e49308b431cdc1f53e9756808dd416 100644 --- a/Manifest.toml +++ b/Manifest.toml @@ -1,6 +1,6 @@ # This file is machine-generated - editing it directly is not advised -julia_version = "1.9.0" +julia_version = "1.9.2" manifest_format = "2.0" project_hash = "2c20afabe03d014276e9478d0fdccbc2cdd634c1" @@ -64,7 +64,7 @@ weakdeps = ["Dates", "LinearAlgebra"] [[deps.CompilerSupportLibraries_jll]] deps = ["Artifacts", "Libdl"] uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae" -version = "1.0.2+0" +version = "1.0.5+0" [[deps.Conda]] deps = ["Downloads", "JSON", "VersionParsing"] @@ -308,12 +308,12 @@ version = "2.5.10" [[deps.Pkg]] deps = ["Artifacts", "Dates", "Downloads", "FileWatching", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"] uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" -version = "1.9.0" +version = "1.9.2" [[deps.PlanarLarvae]] deps = ["DelimitedFiles", "HDF5", "JSON3", "LinearAlgebra", "MAT", "Meshes", "OrderedCollections", "Random", "SHA", "StaticArrays", "Statistics", "StatsBase", "StructTypes"] -git-tree-sha1 = "025970ba7e5b0b9455b2487c1ad2150a17edba0c" -repo-rev = "main" +git-tree-sha1 = "c3397f0c8a6ce76acdbe0517a060e39b90a30db8" +repo-rev = "dev" repo-url = "https://gitlab.pasteur.fr/nyx/planarlarvae.jl" uuid = "c2615984-ef14-4d40-b148-916c85b43307" version = "0.13.0" @@ -483,7 +483,7 @@ version = "1.2.13+0" [[deps.libblastrampoline_jll]] deps = ["Artifacts", "Libdl"] uuid = "8e850b90-86db-534c-a0d3-1478176c7d93" -version = "5.7.0+0" +version = "5.8.0+0" [[deps.nghttp2_jll]] deps = ["Artifacts", "Libdl"] diff --git a/src/LarvaDatasets.jl b/src/LarvaDatasets.jl index fdb25e8da0cca04516cf977c32132e5f02d157e1..a2f1249d32a8e86ecf5de94200eed5ef59d7ba5e 100644 --- a/src/LarvaDatasets.jl +++ b/src/LarvaDatasets.jl @@ -643,11 +643,7 @@ function new_write_larva_dataset_hdf5(output_dir, input_data; @error "Most likely cause: no time segments could be isolated" rethrow() end - if !isnothing(sample_size) - # TODO: move the below function into PlanarLarvae.Dataloaders - isnothing(seed) || @warn "Random number generation is not explicitly seeded with sample_size defined" - samplesize!(index, sample_size) - end + isnothing(sample_size) || samplesize!(index, sample_size) total_sample_size = length(index) classcounts, _ = Dataloaders.groupby(index.sampler.selectors, index.targetcounts) # @@ -852,67 +848,4 @@ end runid(file) = splitpath(file.source)[end-1] -function samplesize!(index, sample_size) - total_sample_size = length(index) - sample_size < total_sample_size || return index - - ratio = sample_size / total_sample_size - - # apply `ratio` to the total counts first - maxcounts = Dataloaders.total(index.maxcounts) - targetcounts = Dataloaders.total(index.targetcounts) - for (label, count) in pairs(targetcounts) - targetcounts[label] = round(Int, count * ratio) - end - totalcount = sum(values(targetcounts)) - if totalcount < sample_size - for label in shuffle(keys(targetcounts)) - if targetcounts[label] < maxcounts[label] - targetcounts[label] += 1 - totalcount += 1 - totalcount == sample_size && break - end - end - elseif sample_size < totalcount - for label in shuffle(keys(targetcounts)) - if 0 < targetcounts[label] - targetcounts[label] -= 1 - totalcount -= 1 - totalcount == sample_size && break - end - end - end - - # apply `ratio` at the per-file level - targetcountsperfile = copy(index.targetcounts) - for counts in values(targetcountsperfile) - for (label, count) in pairs(counts) - counts[label] = round(Int, count * ratio) - end - end - targetcounts′= Dataloaders.total(targetcountsperfile) - for (label, targetcount) in pairs(targetcounts) - count = targetcounts′[label] - while count < targetcount - file = rand(keys(targetcountsperfile)) - counts = targetcountsperfile[file] - if haskey(counts, label) && counts[label] < index.maxcounts[file][label] - counts[label] += 1 - count += 1 - end - end - while targetcount < count - file = rand(keys(targetcountsperfile)) - counts = targetcountsperfile[file] - if 0 < get(counts, label, 0) - counts[label] -= 1 - count -= 1 - end - end - end - - index.targetcounts = targetcountsperfile - return index -end - end