diff --git a/src/LarvaDatasets.jl b/src/LarvaDatasets.jl
index 409f925ade242420b3077d7c43727ef0d34e0ce2..22d590cc09a4509a7117207aea3d194e9b3da433 100644
--- a/src/LarvaDatasets.jl
+++ b/src/LarvaDatasets.jl
@@ -17,6 +17,7 @@ structure of *trx.mat* files, an alternative implementation is provided by modul
 """
 using PlanarLarvae, PlanarLarvae.Formats, PlanarLarvae.Features, PlanarLarvae.MWT
+using PlanarLarvae.Datasets: coerce
 using Random
 using HDF5
 using Dates
 
@@ -264,14 +265,55 @@ function thresholdedcounts(counts; majorityweight=20)
 end
 
 function write_larva_dataset_hdf5(path, counts, files, refs, nsteps_before, nsteps_after;
-        fixmwt=false, frameinterval=nothing,
+        fixmwt=false, frameinterval=nothing, includeall=nothing,
     )
     fixmwt && @warn "`fixmwt=true` is no longer supported"
     # this method mutates argument `refs`
-    refs′= Tuple{Int, Int, Int, eltype(keys(counts))}[]
-    for (label, count) in pairs(counts)
-        for (i, j, k) in shuffle(refs[label])[1:count]
-            push!(refs′, (i, j, k, label))
+    T = eltype(keys(counts))
+    refs′= Tuple{Int, Int, Int, T}[]
+    if !isnothing(includeall)
+        includeall = coerce(T, includeall)
+        if haskey(counts, includeall)
+            count = counts[includeall]
+            T′= Vector{Tuple{Int, Int, Int, T}}
+            specialrefs = Dict{T, T′}()
+            for (i, j, k) in refs[includeall]
+                for l in keys(refs)
+                    if l != includeall
+                        m = findfirst(==((i, j, k)), refs[l])
+                        if !isnothing(m)
+                            push!(get!(specialrefs, l, T′()), (i, j, k, l))
+                            deleteat!(refs[l], m)
+                        end
+                    end
+                end
+            end
+            if !isempty(specialrefs)
+                @info "Explicit inclusions based on label \"$(includeall)\":" [Symbol(label) => length(refs″) for (label, refs″) in pairs(specialrefs)]...
+                for (label, count) in pairs(counts)
+                    label == includeall && continue
+                    if label in keys(specialrefs)
+                        refs″= specialrefs[label]
+                        if count < length(refs″)
+                            refs″ = shuffle(refs″)[1:count]
+                        end
+                        refs′= vcat(refs′, refs″)
+                        count = count - length(refs″)
+                    end
+                    if 0 < count
+                        for (i, j, k) in shuffle(refs[label])[1:count]
+                            push!(refs′, (i, j, k, label))
+                        end
+                    end
+                end
+            end
+        end
+    end
+    if isempty(refs′)
+        for (label, count) in pairs(counts)
+            for (i, j, k) in shuffle(refs[label])[1:count]
+                push!(refs′, (i, j, k, label))
+            end
         end
     end
     empty!(refs) # free memory
@@ -474,6 +516,12 @@ Note that, if `input_data` lists all files, `labelledfiles` is not called and ar
 Similarly, if `labelpointers` is defined, `labelcounts` is not called and argument
 `timestep_filter` is not used.
 
+*New in version 0.10*: `includeall` specifies a secondary label for systematic inclusion in
+the dataset. The time segments with this secondary label are accounted for under the
+associated primary label, prior to applying the balancing rule. If the label specified by
+`includeall` is found in `labels`, it is considered as primary and `includeall` is ignored.
+Generally speaking, `labels` should not include any secondary label.
+
 Known issue: ASCII-compatible string attributes are ASCII encoded and deserialized as
 `bytes` by the *h5py* Python library.
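+
+For example (an illustrative call; the repository path, output directory and label
+names are hypothetical, and `"edited"` is simply the default secondary label):
+
+```julia
+write_larva_dataset_hdf5("data/interim/mytagger", "data/raw/myscreen";
+                         labels=["run", "cast"],  # primary labels only
+                         includeall="edited")     # secondary label
+```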
""" @@ -489,7 +537,8 @@ function write_larva_dataset_hdf5(output_dir::String, shallow=false, balance=true, fixmwt=false, - frameinterval=nothing) + frameinterval=nothing, + includeall="edited") files = if input_data isa String repository = input_data labelledfiles(repository, chunks; selection_rule=file_filter, shallow=shallow) @@ -515,14 +564,23 @@ function write_larva_dataset_hdf5(output_dir::String, refs = labelpointers counts = Dict{String, Int}(label=>length(pointers) for (label, pointers) in pairs(labelpointers)) end - if !isnothing(labels) - labels′ = p -> string(p[1]) in labels - missing_labels = [label for label in labels if label ∉ keys(counts)] - if !isempty(missing_labels) - @warn "No occurences found for labels: \"$(join(missing_labels, "\", \""))\"" + if isnothing(labels) + labels = collect(keys(counts)) + if !isnothing(includeall) && includeall ∈ labels + labels = [label for label in labels if label != includeall] + end + else + if !isnothing(includeall) && includeall ∈ labels + includeall = nothing + end + labels′= if isnothing(includeall) + p -> string(p[1]) in labels + else + p -> string(p[1]) in labels || string(p[1]) == includeall end filter!(labels′, counts) filter!(labels′, refs) + isempty(counts) && throw("None of specified labels were found") end if balance sample_sizes, total_sample_size = balancedcounts(counts, sample_size) @@ -530,12 +588,13 @@ function write_larva_dataset_hdf5(output_dir::String, isnothing(sample_size) || @error "Argument sample_size not supported for the specified balancing strategy" sample_sizes, total_sample_size = thresholdedcounts(counts) end - @info "Sample sizes (observed, selected):" [Symbol(label) => (count, get(sample_sizes, label, 0)) for (label, count) in pairs(counts)]... + @info "Sample sizes (observed, selected):" [Symbol(label) => (get(counts, label, 0), get(sample_sizes, label, 0)) for label in labels]... date = Dates.format(Dates.now(), "yyyy_mm_dd") output_file = joinpath(output_dir, "larva_dataset_$(date)_$(window_length)_$(window_length)_$(total_sample_size).hdf5") write_larva_dataset_hdf5(output_file, sample_sizes, files, refs, nsteps_before, nsteps_after; - fixmwt=fixmwt, frameinterval=frameinterval) + fixmwt=fixmwt, frameinterval=frameinterval, + includeall=includeall) h5open(output_file, "cw") do h5 attributes(h5["samples"])["len_traj"] = window_length diff --git a/src/taggingbackends/explorer.py b/src/taggingbackends/explorer.py index 13bc5815381f358c03494e172d6d440d46b9c61a..3d59007aa63d1ac740a1b702e074040ff5f06959 100644 --- a/src/taggingbackends/explorer.py +++ b/src/taggingbackends/explorer.py @@ -474,7 +474,7 @@ run `poetry add {pkg}` from directory: \n def generate_dataset(self, input_files, labels=None, window_length=20, sample_size=None, balance=True, - frame_interval=None): + include_all=None, frame_interval=None): """ Generate a *larva_dataset hdf5* file in data/interim/{instance}/ """ @@ -485,6 +485,7 @@ run `poetry add {pkg}` from directory: \n labels=labels, sample_size=sample_size, balance=balance, + includeall=include_all, frameinterval=frame_interval) def compile_trxmat_database(self, input_dir, diff --git a/src/taggingbackends/main.py b/src/taggingbackends/main.py index 31655f9231bbc94a2dc2c966c627a35eccfba72f..76761fdf39cb7697dca78051aa00d6633e80db67 100644 --- a/src/taggingbackends/main.py +++ b/src/taggingbackends/main.py @@ -10,6 +10,7 @@ Usage: tagging-backend [train|predict] --model-instance <name> tagging-backend train ... 
--sample-size <N> --balancing-strategy <strategy> tagging-backend train ... --frame-interval <I> --window-length <T> tagging-backend train ... --pretrained-model-instance <name> + tagging-backend train ... --include-all <secondary-label> tagging-backend train ... --skip-make-dataset --skip-build-features tagging-backend predict ... --make-dataset --build-features --sandbox <token> @@ -37,7 +38,7 @@ the `make_dataset` module is loaded and this may take quite some time due to dependencies (e.g. Julia FFI). The `--skip-make-dataset` option makes `train` truly skip this step; the corresponding module is not loaded. -Since version 0.8, `predict` makes `--skip-make-dataset` and +Since version 0.10, `predict` makes `--skip-make-dataset` and `--skip-build-features` the default behavior. As a counterpart, it admits arguments `--make-dataset` and `--build-features`. @@ -68,6 +69,7 @@ def main(fun=None): pretrained_model_instance = None sandbox = False balancing_strategy = 'auto' + include_all = None unknown_args = {} k = 2 while k < len(sys.argv): @@ -113,6 +115,9 @@ def main(fun=None): elif sys.argv[k] == "--balancing-strategy": k = k + 1 balancing_strategy = sys.argv[k] + elif sys.argv[k] == "--include-all": + k = k + 1 + include_all = sys.argv[k] else: unknown_args[sys.argv[k].lstrip('-').replace('-', '_')] = sys.argv[k+1] k = k + 1 @@ -146,6 +151,8 @@ def main(fun=None): make_dataset_kwargs["reuse_h5files"] = True elif reuse_h5files: logging.info("option --reuse-h5files is ignored in the absence of --trxmat-only") + if include_all: + make_dataset_kwargs["include_all"] = include_all backend._run_script(backend.make_dataset, **make_dataset_kwargs) if build_features: backend._run_script(backend.build_features)
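
Stepping back from the patch: the first LarvaDatasets.jl hunk implements the per-label selection rule that `includeall` introduces. The toy model below is a minimal sketch of that rule, not the implementation itself; `select_segments` is a hypothetical name, and the real code works on `(i, j, k)` pointer tuples and mutates `refs` in place.

```julia
using Random

# Toy model of the per-label selection rule: segments carrying the secondary
# label are force-included first (truncated if they exceed the budget), and
# the remaining slots are drawn at random from the label's other segments.
function select_segments(pool::Vector{Int}, forced::Vector{Int}, count::Int)
    forced = length(forced) <= count ? forced : shuffle(forced)[1:count]
    rest = setdiff(pool, forced)   # mirrors the deleteat! in the patch
    n = min(count - length(forced), length(rest))
    vcat(forced, shuffle(rest)[1:n])
end

segments = collect(1:100)  # segments with some primary label, e.g. "run"
edited = collect(1:10)     # the subset also carrying the secondary label
selected = select_segments(segments, edited, 50)
@assert issubset(edited, selected) && length(selected) == 50
```

Because the forced subset is charged against the per-label budget (`count = count - length(refs″)` in the patch), explicit inclusions cannot unbalance the resulting dataset.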