diff --git a/Manifest.toml b/Manifest.toml index 0810bd952b6a32a0a615582d4febadd8531d394d..d1dbb68891c08febfe490dc5efaa434e0895aee2 100644 --- a/Manifest.toml +++ b/Manifest.toml @@ -312,11 +312,11 @@ version = "1.9.0" [[deps.PlanarLarvae]] deps = ["DelimitedFiles", "HDF5", "JSON3", "LinearAlgebra", "MAT", "Meshes", "OrderedCollections", "Random", "SHA", "StaticArrays", "Statistics", "StatsBase", "StructTypes"] -git-tree-sha1 = "ef6169e9f8705569925bef897704c7514b4d5f18" +git-tree-sha1 = "025970ba7e5b0b9455b2487c1ad2150a17edba0c" repo-rev = "main" repo-url = "https://gitlab.pasteur.fr/nyx/planarlarvae.jl" uuid = "c2615984-ef14-4d40-b148-916c85b43307" -version = "0.12.0" +version = "0.13.0" [[deps.PrecompileTools]] deps = ["Preferences"] diff --git a/Project.toml b/Project.toml index 672f20ba4e212055c69b5cde18b7007ba2b39aba..effa5eff15d35ac9c190c55746c892cf5c09d516 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "TaggingBackends" uuid = "e551f703-3b82-4335-b341-d497b48d519b" authors = ["François Laurent", "Institut Pasteur"] -version = "0.14.1" +version = "0.15" [deps] Dates = "ade2ca70-3891-5945-98fb-dc099432e06a" diff --git a/pyproject.toml b/pyproject.toml index 7a4100f2386a8cb3eda595bab992471a693b8329..1b0e1297778318b18d3468c56e17abb79b80523d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "TaggingBackends" -version = "0.14.1" +version = "0.15" description = "Backbone for LarvaTagger.jl tagging backends" authors = ["François Laurent"] diff --git a/src/LarvaDatasets.jl b/src/LarvaDatasets.jl index dc37f4241240585a895e4e880eb77eb8bdb04f83..1c1291682409dff8154b2078cac0d432400d24be 100644 --- a/src/LarvaDatasets.jl +++ b/src/LarvaDatasets.jl @@ -556,6 +556,7 @@ function write_larva_dataset_hdf5(output_dir::String, seed=nothing, distributed_sampling=true, past_future_extensions=true, + balancing_strategy=nothing, ) if distributed_sampling new_write_larva_dataset_hdf5(output_dir, input_data; @@ -571,10 +572,12 
@@ function write_larva_dataset_hdf5(output_dir::String, fixmwt=fixmwt, frameinterval=frameinterval, includeall=includeall, + balancing_strategy=balancing_strategy, past_future_extensions=past_future_extensions, seed=seed) else past_future_extensions || throw("not implemented") + isnothing(balancing_strategy) || throw("not implemented") legacy_write_larva_dataset_hdf5(output_dir, input_data, window_length; labels=labels, labelpointers=labelpointers, @@ -595,7 +598,8 @@ function new_write_larva_dataset_hdf5(output_dir, input_data; window_length=20, labels=nothing, file_filter=nothing, - balance=true, + balance=nothing, + balancing_strategy=nothing, frameinterval=0.1, includeall="edited", past_future_extensions=true, @@ -603,6 +607,8 @@ function new_write_larva_dataset_hdf5(output_dir, input_data; kwargs...) repo = if input_data isa String isnothing(file_filter) ? Repository(input_data) : Repository(input_data, file_filter) + elseif input_data isa Repository + input_data else Repository(repo = pwd(), files = eltype(input_data) === String ? preload.(input_data) : input_data) end @@ -613,23 +619,31 @@ function new_write_larva_dataset_hdf5(output_dir, input_data; @assert !isnothing(frameinterval) window = TimeWindow(window_length * frameinterval, round(Int, 1 / frameinterval); maggotuba_compatibility=past_future_extensions) - selectors = isnothing(labels) ? getprimarylabels(first(Dataloaders.files(repo))) : labels - min_max_ratio = balance ? 2 : 20 - index = if isnothing(includeall) - ratiobasedsampling(selectors, min_max_ratio; seed=seed) + index = if !isnothing(balancing_strategy) && startswith(balancing_strategy, "max:") + # `includeall` not supported + maxcount = parse(Int, balancing_strategy[5:end]) + capacitysampling(labels, maxcount; seed=seed) else - ratiobasedsampling(selectors, min_max_ratio, prioritylabel(includeall); seed=seed) + if isnothing(balance) + balance = isnothing(balancing_strategy) || lowercase(balancing_strategy) == "maggotuba" + end + min_max_ratio = balance ? 
2 : 20 + if isnothing(includeall) + ratiobasedsampling(labels, min_max_ratio; seed=seed) + else + ratiobasedsampling(labels, min_max_ratio, prioritylabel(includeall); seed=seed) + end end loader = DataLoader(repo, window, index) try buildindex(loader; unload=true) - catch ArgumentError + catch # most likely error message: "collection must be non-empty" @error "Most likely cause: no time segments could be isolated" rethrow() end - total_sample_size = length(loader.index) - classcounts, _ = Dataloaders.groupby(selectors, loader.index.targetcounts) + total_sample_size = length(index) + classcounts, _ = Dataloaders.groupby(index.sampler.selectors, index.targetcounts) # extended_window_length = past_future_extensions ? 3 * window_length : window_length date = Dates.format(Dates.now(), "yyyy_mm_dd") diff --git a/src/taggingbackends/explorer.py b/src/taggingbackends/explorer.py index 0bbccb648c16643fdd22477f19a91904abda3040..caefe5ee00ee7f4d4d8b821af3ba1ec6018927a3 100644 --- a/src/taggingbackends/explorer.py +++ b/src/taggingbackends/explorer.py @@ -484,9 +484,9 @@ run `poetry add {pkg}` from directory: \n return input_files, labels def generate_dataset(self, input_files, - labels=None, window_length=20, sample_size=None, balance=True, + labels=None, window_length=20, sample_size=None, balance=None, include_all=None, frame_interval=None, seed=None, - past_future_extensions=None): + past_future_extensions=None, balancing_strategy=None): """ Generate a *larva_dataset hdf5* file in data/interim/{instance}/ """ @@ -502,6 +502,7 @@ run `poetry add {pkg}` from directory: \n balance=balance, includeall=include_all, frameinterval=frame_interval, + balancing_strategy=balancing_strategy, past_future_extensions=past_future_extensions, seed=seed) diff --git a/src/taggingbackends/main.py b/src/taggingbackends/main.py index a5b7b8e98d237e375fff0e7990709e49573a4080..cf3f85028dd7b5a8fd39fa406a48e0d678902dda 100644 --- a/src/taggingbackends/main.py +++ b/src/taggingbackends/main.py @@ 
-45,8 +45,8 @@ algorithm aims to minimize. Weights are also specified as a comma-separated list of floating-point values. As many weights as labels are expected. In addition to class penalties, the majority classes can be subsampled in -different ways. Argument `--balancing-strategy` can take either "maggotuba" or -"auto", that correspond to specific subsampling strategies. +different ways. Argument `--balancing-strategy` can take values "maggotuba", +"auto" or "max:<n>", which correspond to specific subsampling strategies. --balancing-strategy maggotuba Denoting n the size of the minority class, classes of size less than 10n are @@ -60,6 +60,10 @@ different ways. Argument `--balancing-strategy` can take either "maggotuba" or In addition, if class weights are not defined, they are set as the inverse of the corresponding class size. +--balancing-strategy max:<n> + <n> is the maximum count for any given class. Classes with fewer occurrences + in the dataset are fully sampled. + Subsampling is done at random. However, data observations bearing a specific secondary label can be included with priority, up to the target size if too many, and then complemented with different randomly picked observations. To