diff --git a/Project.toml b/Project.toml index 53f88ab207220c6425c3da81e6a8bf1e65e83892..7002ab52a2c0beba001d16fbfc8e0cf4e2a589c1 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "PlanarLarvae" uuid = "c2615984-ef14-4d40-b148-916c85b43307" authors = ["François Laurent", "Institut Pasteur"] -version = "0.12" +version = "0.13" [deps] DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab" diff --git a/src/Dataloaders.jl b/src/Dataloaders.jl index 687e968422c1f6bac2b56701ca6edd1c18efd19f..ade162ea5efa3fc009a2742caebfca831813d68a 100644 --- a/src/Dataloaders.jl +++ b/src/Dataloaders.jl @@ -11,13 +11,13 @@ The default design is optimized for tracking data stored as json .label files. module Dataloaders using ..LarvaBase: LarvaBase, times -using ..Datasets: Track, TrackID +using ..Datasets: Run, Track, TrackID using ..Formats using ..MWT: interpolate using Random export DataLoader, Repository, TimeWindow, ratiobasedsampling, buildindex, sample, - extendedlength, prioritylabel + extendedlength, prioritylabel, capacitysampling """ The default window type defines a time segment as the time of a segment-specific reference @@ -49,7 +49,7 @@ end TimeWindow(duration) = TimeWindow(duration, nothing) function TimeWindow(duration, samplerate; maggotuba_compatibility=false) 0 < duration || throw("Non-positive time window duration") - 0 < samplerate || throw("Non-positive sampling frequency") + isnothing(samplerate) || 0 < samplerate || throw("Non-positive sampling frequency") margin = maggotuba_compatibility ? duration : 0 TimeWindow(.5duration, .5duration, margin, margin, samplerate) end @@ -98,11 +98,14 @@ function indicator(window::TimeWindow, segment::TimeSegment) end function segment(file, window, trackid, step, class) - ts = times(Formats.getnativerepr(file)[trackid]) + track = Formats.getnativerepr(file)[trackid] + ts = times(track) @assert 1 < step < length(ts) @inbounds anchortime = round(ts[step]; digits=4) segmentdata = if isnothing(window.samplerate) - throw("Not implemented") + @assert track isa LarvaBase.TimeSeries "Not implemented; try defining samplerate" + t0, t1 = anchortime - window.durationbefore, anchortime + window.durationafter + [(t, x) for (t, x) in track if t0 <= t <= t1] else window.durationafter == window.durationbefore || throw("Not implemented: asymmetric window") file′= isa(file, Formats.JSONLabels) ? file.dependencies[1] : file @@ -110,7 +113,7 @@ function segment(file, window, trackid, step, class) ts, xs = [t for (t, _) in timeseries], [x for (_, x) in timeseries] interpolate(ts, xs, extendedtimes(window, anchortime)) end - @assert length(segmentdata) == extendedlength(window) + isnothing(window.samplerate) || @assert length(segmentdata) == extendedlength(window) TimeSegment(trackid, anchortime, window, class, segmentdata) end @@ -121,16 +124,22 @@ end function Repository(root::String, pattern::Regex; basename_only::Bool=false) root = expanduser(root) - files = String[] + files = Formats.PreloadedFile[] for (dir, _, files′) in walkdir(root; follow_symlinks=true) for file′ in files′ file = joinpath(dir, file′) if !isnothing(Base.match(pattern, basename_only ? file′ : file)) - push!(files, file) + file″= try + preload(file; shallow=true) + catch + @warn "Cannot read labels from file" file + continue + end + push!(files, file″) end end end - Repository(root, preload.(files)) + Repository(root, files) end function Repository(root::String, fileselector::Function) @@ -176,8 +185,9 @@ function countlabels(loader::DataLoader; unload=false) countlabels(loader.repository, loader.window; unload=unload) end -function countlabels(repository, window; unload=false) - Count = Dict{Union{String, Vector{String}}, Int} +const Count = Dict{Union{String, Vector{String}}, Int} + +function countlabels(repository::Repository, window; unload=false) counts = Dict{Formats.PreloadedFile, Count}() ch = Channel() do ch foreach(files(repository)) do file @@ -186,13 +196,7 @@ function countlabels(repository, window; unload=false) end c = Threads.Condition() Threads.foreach(ch) do file - counts′= Count() - for track in values(getrun(file)) - labelsequence = track[:labels][indicator(window, track)] - for label in labelsequence - counts′[label] = get(counts′, label, 0) + 1 - end - end + counts′= countlabels(Formats.getnativerepr(file), window) unload && unload!(file; gc=true) lock(c) try @@ -204,6 +208,29 @@ function countlabels(repository, window; unload=false) return counts end +function countlabels(run::Run, window) + counts = Count() + for track in values(run) + labelsequence = track[:labels][indicator(window, track)] + for label in labelsequence + counts[label] = get(counts, label, 0) + 1 + end + end + return counts +end + +function countlabels(timeseries::LarvaBase.Larvae, window) + counts = Count() + for track in values(timeseries) + foreach(track[indicator(window, track)]) do step + _, state = step + label = convert(Vector{String}, state[:tags]) + counts[label] = get(counts, label, 0) + 1 + end + end + return counts +end + function total(counts) iterator = values(counts) counts′, state = iterate(iterator) @@ -249,6 +276,10 @@ struct IntraClassRatios <: RatioBasedSampling seed end +function withselectors(sampler::T, selectors) where {T} + T((field === :selectors ? selectors : getfield(sampler, field) for field in fieldnames(T))...) +end + function ratiobasedsampling(selectors, majority_minority_ratio; seed=nothing) LazyIndex(ClassRatios(asselectors(selectors), majority_minority_ratio, seed)) end @@ -318,6 +349,7 @@ function indicator(selector::LabelSelector, labelsequence::AbstractVector) end indicator(selector::LabelSelector, track::Track) = indicator(selector, track[:labels]) +asselectors(undefined::Nothing) = undefined asselectors(selectors::AbstractDict{Symbol, <:LabelSelector}) = selectors asselectors(label::Union{String, Symbol}) = Dict(asselector(label)) asselectors(labels::AbstractVector{<:Union{String, Symbol}}) = Dict(asselector(label) for label in labels) @@ -381,8 +413,19 @@ function buildindex(loader::DataLoader; kwargs...) end function buildindex(ix::LazyIndex, repository, window; unload=false, verbose=true) + sampler = ix.sampler + if hasproperty(sampler, :selectors) && isnothing(sampler.selectors) + anyfile = first(files(repository)) + labels = getprimarylabels(anyfile) + if verbose && !(anyfile isa Formats.JSONLabels) + @info "Assuming any data file specifies all the labels" labels + end + ix.sampler = sampler = withselectors(sampler, asselectors(labels)) + end + # ix.maxcounts = countlabels(repository, window; unload=unload) - ix.targetcounts = buildindex(ix.sampler, ix) + ix.targetcounts = buildindex(sampler, ix) + # if verbose maxcounts = total(ix.maxcounts) targetcounts = total(ix.targetcounts) @@ -519,10 +562,10 @@ function sample(sampler, file, window, counts, features; verbose=false) # number and time step index) for each behavioral class T = eltype(keys(counts)) T′= Tuple{TrackID, Int, Symbol} - run = getrun(file) + run = Formats.getnativerepr(file) index = Dict{T, Vector{T′}}() for (trackid, track) in pairs(run) - labels = track[:labels] + labels = getlabels′(track) ind = collect(indicator(window, track)) for step in ind label = labels[step] @@ -565,9 +608,13 @@ function sample(sampler, file, window, counts, features; verbose=false) end Formats.load!(file′) end - elseif isempty(file.timeseries) + else + isempty(file.timeseries) || empty!(file.timeseries) + capabilities = [cap[1] for cap in file.capabilities] for capability in (:spine, :outline, :tags) - capability in file.capabilities && capability ∉ features && Formats.drop_record!(file, capability) + if capability in capabilities && capability ∉ features + Formats.drop_record!(file, capability) + end end end @@ -580,10 +627,36 @@ function sample(sampler, file, window, counts, features; verbose=false) end end +getlabels′(track::Track) = track[:labels] + +function getlabels′(timeseries::LarvaBase.TimeSeries) + [convert(Vector{String}, state[:tags]) for (_, state) in timeseries] +end + function presample(state, file::Formats.PreloadedFile, window, ix::LazyIndex) presample(ix.sampler, state, file, window, ix.targetcounts[file]) end presample(_, ::Nothing, _, _, counts) = (0, sum(values(counts))) presample(_, cumulatedcount, _, _, counts) = (sum(cumulatedcount), sum(values(counts))) +struct CapacitySampling <: RatioBasedSampling + selectors + maxcount::Integer + seed +end + +function capacitysampling(selectors, maxcount::Integer; seed=nothing) + LazyIndex(CapacitySampling(asselectors(selectors), maxcount, seed)) +end + +function capacitysampling(maxcount::Integer; seed=nothing) + capacitysampling(nothing, maxcount; seed=seed) +end + +function ratio(sampler::CapacitySampling, counts) + maxcount = sampler.maxcount + ratios = Dict(class => min(count, maxcount) / count for (class, count) in pairs(counts)) + return ratios +end + end diff --git a/test/runtests.jl b/test/runtests.jl index c47f28ee42b29e5dad0d5a5ec3115ad995202e6f..ae75658551b42f80bcecc35ca9fbddb4a9e21dda 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -900,5 +900,47 @@ if all_tests || "Dataloaders" in ARGS repo3 = Repository(joinpath(dataset, "**", "g.*")) @test Dataloaders.filepaths(repo3) == Dataloaders.filepaths(repo2) + mktempdir() do dir + run1 = joinpath(dir, "20140918_170215") + mkpath(run1) + cp(artifact"sample_trxmat_file_small/trx_small.mat", joinpath(run1, "trx.mat")) + # create a fake trx.mat file with a header from an actual v5 MAT file + run2 = joinpath(dir, "20111114_091245") + mkpath(run2) + open(joinpath(run2, "trx.mat"), "w") do f + write(f, "MATLAB 5.0 MAT-file, Platform: GLNXA64, Created on: Wed Mar 7 07:19:12 2018 ") + end + # + repo4 = if VERSION < v"1.7" + # @test_warn exists but does not catch warnings... + Repository(joinpath(dir, "**", "trx.mat")) + else + @test_warn "Cannot read labels from file" Repository(joinpath(dir, "**", "trx.mat")) + end + @test length(repo4) == 1 && basename(dirname(Dataloaders.filepaths(repo4)[1])) == "20140918_170215" + # + window4 = TimeWindow(5) + index4 = capacitysampling(60; seed=347980) + loader4 = DataLoader(repo4, window4, index4) + buildindex(loader4; verbose=false) + @test index4.sampler.maxcount == 60 && index4.sampler.seed == 347980 + @test Set(keys(index4.sampler.selectors)) == Set([:back, :back_large, :back_strong, :back_weak, :cast, :cast_large, :cast_strong, :cast_weak, :hunch, :hunch_large, :hunch_strong, :hunch_weak, :roll, :roll_large, :roll_strong, :roll_weak, :run, :run_large, :run_strong, :run_weak, :stop, :stop_large, :stop_strong, :stop_weak, :small_motion]) + counts4 = copy(first(values(index4.targetcounts))) + counts4′= pop!(counts4, ["back", "back_weak", "small_motion"]) + counts4″= pop!(counts4, ["hunch", "hunch_weak", "small_motion"]) + @test counts4′== 55 && counts4″== 22 && all(==(60), values(counts4)) + # + lengths = nothing # set scope + Dataloaders.sample(loader4, :spine) do i, _, _, segments + lengths = [length(segment.timeseries) for segment in segments] + end + # default rng changed with Julia 1.7 + @test lengths == if VERSION < v"1.7" + [66, 66, 65, 66, 66, 66, 66, 66, 65, 65, 66, 65, 66, 66, 66, 66, 66, 65, 65, 66, 66, 66, 66, 65, 65, 66, 65, 66, 63, 64, 65, 67, 65, 66, 65, 65, 66, 66, 65, 66, 65, 65, 63, 63, 63, 66, 66, 65, 66, 66, 65, 65, 65, 66, 66, 66, 66, 65, 66, 66, 66, 66, 65, 65, 66, 66, 66, 65, 65, 65, 65, 63, 63, 65, 66, 66, 66, 66, 66, 65, 66, 66, 65, 65, 65, 65, 65, 66, 66, 66, 66, 63, 63, 65, 65, 65, 65, 63, 66, 67, 65, 65, 65, 65, 65, 65, 65, 65, 65, 65, 66, 66, 66, 65, 65, 65, 66, 65, 66, 66, 66, 65, 65, 65, 65, 65, 66, 65, 63, 63, 63, 63, 63, 66, 65, 66, 65, 66, 64, 65, 66, 64, 64, 64, 64, 65, 65, 65, 65, 66, 66, 66, 66, 63, 66, 66, 66, 66, 66, 66, 65, 65, 65, 65, 65, 65, 65, 65, 65, 65, 65, 66, 66, 66, 66, 64, 64, 63, 63, 63, 63, 63, 63, 63, 63, 63, 63, 63, 63, 63, 63, 62, 63, 64, 63, 65, 65, 66, 66, 66, 66, 66, 66, 66, 66, 66, 65, 64, 65, 64, 65, 66, 66, 65, 65, 66, 66, 66, 65, 65, 66, 66, 66, 66, 66, 65, 66, 66, 66, 65, 64, 64, 63, 63, 63, 63, 63, 63, 63, 63, 63, 63, 63, 63, 63, 63, 65, 65, 65, 65, 65, 64, 64, 64, 64, 64, 64, 64, 63, 63, 65, 65, 65, 66, 66, 66, 66, 66, 66, 66, 66, 66, 65, 66, 66, 66, 65, 65, 66, 66, 66, 66, 65, 65, 66, 65, 65, 65, 65, 65, 65, 65, 65, 66, 66, 65, 65, 65, 65, 65, 66, 65, 65, 65, 65, 65, 65, 65, 65, 65, 65, 66, 66, 66, 65, 66, 66, 65, 65, 65, 64, 64, 64, 65, 65, 65, 65, 65, 65, 65, 63, 63, 63, 63, 65, 63, 63, 64, 64, 64, 64, 64, 66, 66, 66, 65, 65, 65, 66, 66, 66, 66, 66, 64, 64, 64, 64, 63, 63, 63, 63, 63, 63, 63, 63, 63, 65, 65, 66, 66, 66, 65, 65, 65, 66, 65, 66, 65, 65, 65, 65, 65, 65, 65, 66, 65, 65, 66, 66, 65, 64, 64, 64, 65, 64, 64, 64, 63, 63, 63, 63, 63, 63, 63, 63, 63, 63, 63, 63, 63, 63, 63, 64, 63, 63, 66, 65, 63, 66, 66, 66, 66, 65, 66, 66, 66, 66, 66, 66, 66, 65, 65, 66, 66, 66, 66, 65, 63, 64, 64, 65, 64, 64, 65, 66, 66, 66, 66, 65, 66, 66, 66, 63, 63, 63, 64, 64, 64, 65, 64, 64, 64, 64, 63, 63, 63, 63, 63, 63, 63, 63, 63, 63, 63, 63, 62, 62, 62, 62, 63, 63, 64, 64, 66, 64, 64, 65, 65, 63, 65, 64, 64, 66, 65, 64, 64, 63] + else + [66, 66, 66, 65, 66, 65, 66, 66, 66, 66, 66, 65, 66, 65, 66, 66, 66, 66, 65, 66, 65, 65, 66, 65, 66, 65, 66, 62, 63, 64, 66, 65, 66, 65, 65, 65, 67, 65, 66, 64, 65, 66, 66, 66, 66, 66, 66, 65, 65, 66, 66, 66, 66, 66, 66, 65, 65, 65, 65, 66, 66, 66, 65, 66, 65, 66, 65, 63, 63, 63, 66, 66, 66, 66, 65, 65, 65, 66, 65, 66, 66, 66, 66, 65, 66, 66, 65, 66, 66, 66, 65, 65, 64, 64, 63, 63, 64, 65, 65, 65, 65, 63, 63, 63, 63, 63, 63, 66, 65, 66, 65, 66, 66, 66, 66, 65, 65, 65, 65, 65, 65, 65, 65, 65, 65, 65, 65, 66, 65, 66, 66, 65, 65, 66, 65, 65, 66, 65, 65, 65, 65, 66, 64, 64, 64, 63, 66, 66, 66, 63, 62, 65, 65, 66, 65, 65, 66, 64, 64, 64, 65, 63, 63, 63, 63, 63, 63, 64, 63, 65, 66, 66, 65, 66, 65, 66, 65, 66, 65, 65, 66, 65, 65, 65, 65, 65, 65, 66, 66, 66, 66, 64, 65, 65, 64, 63, 63, 63, 63, 62, 63, 63, 64, 64, 64, 64, 65, 66, 66, 66, 66, 66, 65, 66, 66, 66, 66, 66, 65, 66, 66, 66, 67, 65, 65, 65, 64, 63, 66, 65, 66, 66, 66, 65, 66, 66, 65, 64, 64, 64, 64, 64, 64, 63, 63, 63, 63, 63, 63, 63, 63, 63, 63, 63, 63, 63, 63, 63, 64, 64, 65, 65, 65, 65, 64, 64, 64, 64, 64, 64, 63, 62, 66, 65, 66, 67, 66, 65, 65, 66, 66, 66, 66, 66, 66, 66, 66, 66, 66, 65, 65, 65, 66, 66, 66, 66, 65, 65, 65, 65, 65, 65, 65, 65, 65, 65, 65, 65, 66, 66, 65, 65, 65, 65, 65, 65, 65, 65, 65, 65, 65, 65, 65, 65, 65, 66, 66, 65, 66, 66, 66, 65, 65, 64, 64, 64, 65, 65, 65, 65, 65, 65, 65, 64, 64, 64, 63, 64, 64, 63, 62, 64, 64, 64, 64, 64, 64, 66, 66, 65, 66, 65, 66, 65, 65, 65, 66, 65, 63, 63, 63, 63, 64, 64, 64, 64, 64, 63, 63, 63, 63, 63, 63, 63, 63, 63, 63, 64, 65, 65, 66, 66, 65, 65, 65, 65, 65, 65, 65, 66, 66, 66, 66, 64, 63, 63, 64, 65, 65, 64, 64, 64, 63, 63, 63, 63, 64, 64, 64, 64, 64, 64, 64, 64, 63, 63, 63, 64, 66, 66, 66, 65, 65, 66, 66, 65, 66, 66, 66, 65, 66, 66, 65, 65, 65, 66, 66, 66, 66, 66, 63, 63, 65, 66, 66, 63, 63, 63, 64, 64, 64, 64, 65, 65, 63, 63, 63, 63, 63, 63, 63, 63, 63, 63, 63, 63, 63, 63, 64, 64, 65, 64, 63, 63, 64, 66, 63, 64, 64, 64, 64, 63, 65, 65, 63, 63, 64] + end + end + end end