Skip to content
Snippets Groups Projects
Commit 42fe655e authored by François  LAURENT's avatar François LAURENT
Browse files

max:<n> balancing strategy

parent 2e87a5e7
Branches
Tags
1 merge request!3Set of commits to be tagged v0.15
Pipeline #106822 canceled
......@@ -312,11 +312,11 @@ version = "1.9.0"
[[deps.PlanarLarvae]]
deps = ["DelimitedFiles", "HDF5", "JSON3", "LinearAlgebra", "MAT", "Meshes", "OrderedCollections", "Random", "SHA", "StaticArrays", "Statistics", "StatsBase", "StructTypes"]
git-tree-sha1 = "ef6169e9f8705569925bef897704c7514b4d5f18"
git-tree-sha1 = "025970ba7e5b0b9455b2487c1ad2150a17edba0c"
repo-rev = "main"
repo-url = "https://gitlab.pasteur.fr/nyx/planarlarvae.jl"
uuid = "c2615984-ef14-4d40-b148-916c85b43307"
version = "0.12.0"
version = "0.13.0"
[[deps.PrecompileTools]]
deps = ["Preferences"]
......
......@@ -556,6 +556,7 @@ function write_larva_dataset_hdf5(output_dir::String,
seed=nothing,
distributed_sampling=true,
past_future_extensions=true,
balancing_strategy=nothing,
)
if distributed_sampling
new_write_larva_dataset_hdf5(output_dir, input_data;
......@@ -571,10 +572,12 @@ function write_larva_dataset_hdf5(output_dir::String,
fixmwt=fixmwt,
frameinterval=frameinterval,
includeall=includeall,
balancing_strategy=balancing_strategy,
past_future_extensions=past_future_extensions,
seed=seed)
else
past_future_extensions || throw("not implemented")
isnothing(balancing_strategy) || throw("not implemented")
legacy_write_larva_dataset_hdf5(output_dir, input_data, window_length;
labels=labels,
labelpointers=labelpointers,
......@@ -595,7 +598,8 @@ function new_write_larva_dataset_hdf5(output_dir, input_data;
window_length=20,
labels=nothing,
file_filter=nothing,
balance=true,
balance=nothing,
balancing_strategy=nothing,
frameinterval=0.1,
includeall="edited",
past_future_extensions=true,
......@@ -603,6 +607,8 @@ function new_write_larva_dataset_hdf5(output_dir, input_data;
kwargs...)
repo = if input_data isa String
isnothing(file_filter) ? Repository(input_data) : Repository(input_data, file_filter)
elseif input_data isa Repository
input_data
else
Repository(repo = pwd(), files = eltype(input_data) === String ? preload.(input_data) : input_data)
end
......@@ -613,23 +619,31 @@ function new_write_larva_dataset_hdf5(output_dir, input_data;
@assert !isnothing(frameinterval)
window = TimeWindow(window_length * frameinterval, round(Int, 1 / frameinterval);
maggotuba_compatibility=past_future_extensions)
selectors = isnothing(labels) ? getprimarylabels(first(Dataloaders.files(repo))) : labels
index = if startswith(balancing_strategy, "max:")
# `includeall` not supported
maxcount = parse(Int, balancing_strategy[5:end])
capacitysampling(labels, maxcount; seed=seed)
else
if isnothing(balance)
balance = lowercase(balancing_strategy) == "maggotuba"
end
min_max_ratio = balance ? 2 : 20
index = if isnothing(includeall)
ratiobasedsampling(selectors, min_max_ratio; seed=seed)
if isnothing(includeall)
ratiobasedsampling(labels, min_max_ratio; seed=seed)
else
ratiobasedsampling(selectors, min_max_ratio, prioritylabel(includeall); seed=seed)
ratiobasedsampling(labels, min_max_ratio, prioritylabel(includeall); seed=seed)
end
end
loader = DataLoader(repo, window, index)
try
buildindex(loader; unload=true)
catch ArgumentError
catch
# most likely error message: "collection must be non-empty"
@error "Most likely cause: no time segments could be isolated"
rethrow()
end
total_sample_size = length(loader.index)
classcounts, _ = Dataloaders.groupby(selectors, loader.index.targetcounts)
total_sample_size = length(index)
classcounts, _ = Dataloaders.groupby(index.sampler.selectors, index.targetcounts)
#
extended_window_length = past_future_extensions ? 3 * window_length : window_length
date = Dates.format(Dates.now(), "yyyy_mm_dd")
......
......@@ -484,9 +484,9 @@ run `poetry add {pkg}` from directory: \n
return input_files, labels
def generate_dataset(self, input_files,
labels=None, window_length=20, sample_size=None, balance=True,
labels=None, window_length=20, sample_size=None, balance=None,
include_all=None, frame_interval=None, seed=None,
past_future_extensions=None):
past_future_extensions=None, balancing_strategy=None):
"""
Generate a *larva_dataset hdf5* file in data/interim/{instance}/
"""
......@@ -502,6 +502,7 @@ run `poetry add {pkg}` from directory: \n
balance=balance,
includeall=include_all,
frameinterval=frame_interval,
balancing_strategy=balancing_strategy,
past_future_extensions=past_future_extensions,
seed=seed)
......
......@@ -45,8 +45,8 @@ algorithm aims to minimize. Weights are also specified as a comma-separated
list of floating-point values. As many weights as labels are expected.
In addition to class penalties, the majority classes can be subsampled in
different ways. Argument `--balancing-strategy` can take either "maggotuba" or
"auto", that correspond to specific subsampling strategies.
different ways. Argument `--balancing-strategy` can take values "maggotuba",
"auto" or "max:<n>", that correspond to specific subsampling strategies.
--balancing-strategy maggotuba
Denoting n the size of the minority class, classes of size less than 10n are
......@@ -60,6 +60,10 @@ different ways. Argument `--balancing-strategy` can take either "maggotuba" or
In addition, if class weights are not defined, they are set as the inverse
of the corresponding class size.
--balancing-strategy max:<n>
<n> is the maximum count for any given class. Classes with fewer occurrences
in the dataset are fully sampled.
Subsampling is done at random. However, data observations bearing a specific
secondary label can be included with priority, up to the target size if too
many, and then complemented with different randomly picked observations. To
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment