diff --git a/Manifest.toml b/Manifest.toml
index 2f435345361d346a772704ffd2a1779de2560d73..4f9b94d14b7d63e8b28fbb2e6a3fb00a669de3c4 100644
--- a/Manifest.toml
+++ b/Manifest.toml
@@ -309,9 +309,9 @@ uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
 version = "1.8.0"

 [[deps.PlanarLarvae]]
-deps = ["DelimitedFiles", "HDF5", "JSON3", "LinearAlgebra", "MAT", "Meshes", "OrderedCollections", "SHA", "StaticArrays", "Statistics", "StatsBase", "StructTypes"]
-git-tree-sha1 = "2c358618e63f5de3ea5d6fe9d2ddd15c4b4da3ac"
-repo-rev = "main"
+deps = ["DelimitedFiles", "HDF5", "JSON3", "LinearAlgebra", "MAT", "Meshes", "OrderedCollections", "Random", "SHA", "StaticArrays", "Statistics", "StatsBase", "StructTypes"]
+git-tree-sha1 = "af2845a57442b046685d2b230ac30da50bb8314d"
+repo-rev = "dev"
 repo-url = "https://gitlab.pasteur.fr/nyx/planarlarvae.jl"
 uuid = "c2615984-ef14-4d40-b148-916c85b43307"
 version = "0.10.0"
diff --git a/src/LarvaDatasets.jl b/src/LarvaDatasets.jl
index 0fd33be529c140fc9450319e7032a1075be26451..2f6f45c6d3e9aeaf4718d97e8bf3062c187366a9 100644
--- a/src/LarvaDatasets.jl
+++ b/src/LarvaDatasets.jl
@@ -16,7 +16,8 @@ structure of *trx.mat* files, an alternative implementation is provided by modul
 `Trxmat2HDF5.jl`.
 """

-using PlanarLarvae, PlanarLarvae.Formats, PlanarLarvae.Features, PlanarLarvae.MWT
+using PlanarLarvae, PlanarLarvae.Formats, PlanarLarvae.Features, PlanarLarvae.MWT,
+    PlanarLarvae.Dataloaders
 using PlanarLarvae.Datasets: coerce
 using Random
 using HDF5
@@ -52,7 +53,8 @@
 read.

 !!! note
     this function typically is the main bottleneck in the process of generating
-    *larva_dataset* hdf5 files.
+    *larva_dataset* hdf5 files. For large databases, consider passing
+    `distributed_sampling=true` to `write_larva_dataset_hdf5` instead.
 """
 function labelcounts(files, headlength=nothing, taillength=nothing;
@@ -265,12 +267,7 @@ function thresholdedcounts(counts; majorityweight=20)
     return thresholdedcounts, sum(values(thresholdedcounts))
 end

-function write_larva_dataset_hdf5(path, counts, files, refs, nsteps_before, nsteps_after;
-        fixmwt=false, frameinterval=nothing, includeall=nothing, seed=nothing,
-        )
-    fixmwt && @warn "`fixmwt=true` is no longer supported"
-    @debug "Seeding the random number generator" seed
-    isnothing(seed) || Random.seed!(seed)
+function makejobs(counts, files, refs::Dict{String, Vector{Tuple{Int, Int, Int}}}; includeall=nothing)
     # this method mutates argument `refs`
     T = eltype(keys(counts))
     refs′= Tuple{Int, Int, Int, T}[]
@@ -322,7 +319,7 @@ function write_larva_dataset_hdf5(path, counts, files, refs, nsteps_before, nste
     empty!(refs) # free memory
     refs′= sort(refs′)
     sampleid = 0
-    ch = Channel(; spawn=true) do ch
+    return Channel(; spawn=true) do ch
         for (i, file) in enumerate(files)
             refs = [r for (i′, r...) in refs′ if i′==i]
             isempty(refs) && continue
@@ -330,8 +327,18 @@ function write_larva_dataset_hdf5(path, counts, files, refs, nsteps_before, nste
             sampleid += length(refs)
         end
     end
+end
+
+function write_larva_dataset_hdf5(path, counts, files, refs, nsteps_before, nsteps_after;
+        fixmwt=false, frameinterval=nothing, includeall=nothing, seed=nothing,
+        )
+    fixmwt && @warn "`fixmwt=true` is no longer supported"
+    @debug "Seeding the random number generator" seed
+    isnothing(seed) || Random.seed!(seed)
+    ch = makejobs(counts, files, refs; includeall=includeall)
     h5open(path, "w") do h5
         g = create_group(h5, "samples")
+        # consume the jobs (one task per thread)
         Threads.foreach(ch) do (file, refs, sampleid)
             @info "Sampling series of spines in run: $(runid(file))"
             processfile(g, file, refs, sampleid, nsteps_before, nsteps_after;
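
Aside: `makejobs` and the reworked `write_larva_dataset_hdf5` above follow a
producer/consumer pattern: a spawned task feeds sampling jobs into a `Channel`,
and `Threads.foreach` drains it from worker threads. A minimal, self-contained
sketch of that pattern (illustrative payloads only, not code from this patch):

    using Base.Threads: threadid

    # producer: a spawned task lazily feeds jobs into the channel
    jobs = Channel(; spawn=true) do ch
        for i in 1:10
            put!(ch, (i, "job $i"))
        end
    end

    # consumers: one task per thread takes jobs until the channel is drained
    Threads.foreach(jobs) do (i, payload)
        @info "processing $payload on thread $(threadid())"
    end
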
@@ -546,7 +553,158 @@ function write_larva_dataset_hdf5(output_dir::String,
         fixmwt=false,
         frameinterval=nothing,
         includeall="edited",
-        seed=nothing)
+        seed=nothing,
+        distributed_sampling=true,
+        )
+    if distributed_sampling
+        new_write_larva_dataset_hdf5(output_dir, input_data;
+            window_length=window_length,
+            labels=labels,
+            labelpointers=labelpointers,
+            chunks=chunks,
+            sample_size=sample_size,
+            file_filter=file_filter,
+            timestep_filter=timestep_filter,
+            shallow=shallow,
+            balance=balance,
+            fixmwt=fixmwt,
+            frameinterval=frameinterval,
+            includeall=includeall,
+            seed=seed)
+    else
+        legacy_write_larva_dataset_hdf5(output_dir, input_data, window_length;
+            labels=labels,
+            labelpointers=labelpointers,
+            chunks=chunks,
+            sample_size=sample_size,
+            file_filter=file_filter,
+            timestep_filter=timestep_filter,
+            shallow=shallow,
+            balance=balance,
+            fixmwt=fixmwt,
+            frameinterval=frameinterval,
+            includeall=includeall,
+            seed=seed)
+    end
+end
+
+function new_write_larva_dataset_hdf5(output_dir, input_data;
+        window_length=20,
+        labels=nothing,
+        file_filter=nothing,
+        balance=true,
+        frameinterval=0.1,
+        includeall="edited",
+        seed=nothing,
+        kwargs...)
+    repo = if input_data isa String
+        isnothing(file_filter) ? Repository(input_data) : Repository(input_data, file_filter)
+    else
+        Repository(repo = pwd(), files = eltype(input_data) === String ? preload.(input_data) : input_data)
+    end
+    isempty(repo) && throw("no data files found")
+    if !isnothing(get(kwargs, :labelpointers, nothing))
+        throw("label pointers are not supported with distributed_sampling=true")
+    end
+    @assert !isnothing(frameinterval)
+    window = TimeWindow(window_length * frameinterval, round(Int, 1 / frameinterval);
+        maggotuba_compatibility=true)
+    selectors = isnothing(labels) ? getprimarylabels(first(Dataloaders.files(repo))) : labels
+    min_max_ratio = balance ? 2 : 20
+    index = if isnothing(includeall)
+        ratiobasedsampling(selectors, min_max_ratio; seed=seed)
+    else
+        ratiobasedsampling(selectors, min_max_ratio, prioritylabel(includeall); seed=seed)
+    end
+    loader = DataLoader(repo, window, index)
+    buildindex(loader; unload=true)
+    total_sample_size = length(loader.index)
+    classcounts, _ = Dataloaders.groupby(selectors, loader.index.targetcounts)
+    #
+    extended_window_length = 3 * window_length
+    date = Dates.format(Dates.now(), "yyyy_mm_dd")
+    win = window_length # shorter name to keep the next line within the allowed text width
+    output_file = "larva_dataset_$(date)_$(win)_$(win)_$(total_sample_size).hdf5"
+    output_file = joinpath(output_dir, output_file)
+    h5open(output_file, "w") do h5
+        g = create_group(h5, "samples")
+        sample(loader, :spine) do _, file, counts, segments
+            sampleid, nsegments = counts
+            @assert length(segments) == nsegments
+            tracks = gettimeseries(file; shallow=true)
+            median_body_length = medianbodylength(collect(values(tracks)))
+            for segment in segments
+                trackid = segment.trackid
+                anchortime = segment.anchortime
+                timeseries = segment.timeseries
+                label = segment.class
+                #
+                @assert length(segment) == window_length + 1
+                @assert length(timeseries) == extendedlength(segment) == extended_window_length + 1
+                times, states = [s[1] for s in timeseries], [s[2] for s in timeseries]
+                #
+                lengths = bodylength.(states[1:end-1])
+                spines = spine5.(states)
+                mask = Dataloaders.indicator(segment.window, segment)
+                spines = normalize_spines(spines, median_body_length; mask=mask)
+                #
+                sample = zeros(extended_window_length, 18)
+                sample[:,1] = times[1:end-1]
+                sample[:,8] = lengths
+                sample[:,9:end] = spines[1:end-1,:]
+                #
+                name = "sample_$sampleid"
+                # transpose for compatibility with h5py
+                # see issue https://github.com/JuliaIO/HDF5.jl/issues/785
+                g[name] = permutedims(sample, reverse(1:ndims(sample)))
+                sampleid += 1
+                #
+                d = g[name]
+                attributes(d)["larva_number"] = convert(Int, trackid)
+                # we should set `start_point` instead of `reference_time`, to comply with the
+                # original format, but this would not make sense here due to interpolation:
+                attributes(d)["reference_time"] = anchortime
+                attributes(d)["behavior"] = string(label)
+                attributes(d)["path"] = file.source
+            end
+        end
+        attributes(g)["len_traj"] = window_length
+        attributes(g)["len_pred"] = window_length
+        attributes(g)["n_samples"] = total_sample_size
+        # extensions
+        counts = Dataloaders.total(index.targetcounts)
+        if isnothing(labels)
+            h5["labels"] = string.(keys(classcounts))
+            h5["label_counts"] = collect(values(classcounts))
+        else
+            # ensure labels are ordered as provided in input;
+            # see https://gitlab.pasteur.fr/nyx/TaggingBackends/-/issues/24
+            h5["labels"] = labels
+            h5["label_counts"] = [classcounts[Symbol(label)] for label in labels]
+        end
+        if !isnothing(frameinterval)
+            attributes(g)["frame_interval"] = frameinterval
+        end
+    end
+    return output_file
+end
+
+function legacy_write_larva_dataset_hdf5(output_dir::String,
+        input_data::Union{String, <:AbstractVector},
+        window_length::Int=20;
+        labels::Union{Nothing, <:AbstractVector{String}}=nothing,
+        labelpointers::Union{Nothing, <:AbstractDict{String, Vector{Tuple{Int, Int, Int}}}}=nothing,
+        chunks::Bool=false,
+        sample_size=nothing,
+        file_filter=nothing,
+        timestep_filter=nothing,
+        shallow=false,
+        balance=true,
+        fixmwt=false,
+        frameinterval=nothing,
+        includeall="edited",
+        seed=nothing,
+        )
     files = if input_data isa String
         repository = input_data
         labelledfiles(repository, chunks; selection_rule=file_filter, shallow=shallow)
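
For reference, the new default path (`distributed_sampling=true`) delegates to
`new_write_larva_dataset_hdf5`. A sketch of a typical call, assuming the wrapper
keeps the legacy positional arguments (`output_dir`, `input_data`,
`window_length`) and lives in a `LarvaDatasets` module; the paths are
hypothetical:

    using LarvaDatasets

    # sample 20-step spine segments from the repository and write
    # larva_dataset_<date>_20_20_<n>.hdf5 into "datasets/"
    write_larva_dataset_hdf5("datasets", "path/to/repository", 20;
        distributed_sampling=true,  # false selects the legacy sampler
        frameinterval=0.1,          # seconds between time steps
        seed=0)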