diff --git a/src/Dataloaders.jl b/src/Dataloaders.jl index 2e7f50ba975ab815fbe8c4409bc65b3d1e3dacf8..e377da42239f334289680f166b079ff61f2fe37f 100644 --- a/src/Dataloaders.jl +++ b/src/Dataloaders.jl @@ -679,7 +679,9 @@ function sample(f, loader::DataLoader, features=:spine; kwargs...) end function sample(file::Formats.PreloadedFile, window, ix::LazyIndex, features; kwargs...) - ix.sampler.rng = Random.seed!(Random.default_rng(), ix.sampler.seed) + if !(ix.sampler isa CapacitySampling) + ix.sampler.rng = Random.seed!(Random.default_rng(), ix.sampler.seed) + end sample(ix.sampler, file, window, ix.targetcounts[file], features; kwargs...) end @@ -767,11 +769,10 @@ end presample(_, ::Nothing, _, _, counts) = (0, sum(values(counts))) presample(_, cumulatedcount, _, _, counts) = (sum(cumulatedcount), sum(values(counts))) -mutable struct CapacitySampling <: RatioBasedSampling +struct CapacitySampling <: RatioBasedSampling selectors maxcount::Integer rng - seed end """ @@ -788,9 +789,9 @@ function capacitysampling(selectors, maxcount::Integer; seed=nothing, rng=nothin if isnothing(rng) rng = Random.default_rng() end - #Random.seed!(rng, seed) + Random.seed!(rng, seed) end - LazyIndex(CapacitySampling(asselectors(selectors), maxcount, rng, seed)) + LazyIndex(CapacitySampling(asselectors(selectors), maxcount, rng)) end function capacitysampling(maxcount::Integer; kwargs...)