diff --git a/src/LarvaDatasets.jl b/src/LarvaDatasets.jl
index 3e89f955b78e957450bf24264f20e5bbbbab3223..e8ecf29445ecab112b712db76e12b23f5773bdd7 100644
--- a/src/LarvaDatasets.jl
+++ b/src/LarvaDatasets.jl
@@ -418,7 +418,7 @@ end
"""
write_larva_dataset_hdf5(output_directory, input_files, window_length=20)
write_larva_dataset_hdf5(...; labels=nothing, labelpointers=nothing)
- write_larva_dataset_hdf5(...; sample_size=nothing, chunks=false, shallow=false)
+ write_larva_dataset_hdf5(...; sample_size=nothing, balance=true, chunks=false, shallow=false)
write_larva_dataset_hdf5(...; file_filter, timestep_filter)
Sample series of 5-point spines from data files and save them in a hdf5 file,
@@ -451,6 +451,8 @@ See also [`labelledfiles`](@ref Main.PlanarLarvae.Formats.labelledfiles) and
[`labelcounts`](@ref) and their `selection_rule` arguments that correspond to `file_filter`
and `timestep_filter` respectively.
+`balance` enables MaggotUBA-like class balancing and defaults to `true`. See also
+[`balancedcounts`](@ref).
+
Note that, if `input_data` lists all files, `labelledfiles` is not called and arguments
`file_filter`, `chunks` and `shallow` are not used.
Similarly, if `labelpointers` is defined, `labelcounts` is not called and argument
@@ -507,6 +509,7 @@ function write_larva_dataset_hdf5(output_dir::String,
else
sample_sizes = counts
total_sample_size = sum(values(sample_sizes))
+        isnothing(sample_size) || @error "Argument `sample_size` is not supported with the selected balancing strategy"
end
@info "Sample sizes (observed, selected):" [Symbol(label) => (count, get(sample_sizes, label, 0)) for (label, count) in pairs(counts)]...
date = Dates.format(Dates.now(), "yyyy_mm_dd")
diff --git a/src/taggingbackends/data/dataset.py b/src/taggingbackends/data/dataset.py
index 9cc670b5102e27d911806c50ff1a26132ef6dcb9..3e9d2d32ca6520b0c9e848cf20bbce58bf1cd37a 100644
--- a/src/taggingbackends/data/dataset.py
+++ b/src/taggingbackends/data/dataset.py
@@ -1,6 +1,8 @@
import h5py
import pathlib
import itertools
+import numpy as np
+from collections import Counter
"""
Torch-like dataset class for *larva_dataset hdf5* files.
@@ -16,6 +18,9 @@ class LarvaDataset:
self._sample_size = None
self._mask = slice(0, None)
self._window_length = None
+ self._class_weights = None
+ # this attribute was introduced to implement `training_labels`
+ self._alt_training_set_loader = None
"""
*h5py.File*: *larva_dataset hdf5* file handler.
"""
@@ -101,24 +106,28 @@ class LarvaDataset:
train, val, test = torch.utils.data.random_split(TorchDataset(),
[ntrain, nval, ntest],
generator=self.generator)
- if 0 < ntrain:
+ if ntrain == 0:
+ self._training_set = self._alt_training_set_loader = False
+ else:
self._training_set = iter(itertools.cycle(
torch.utils.data.DataLoader(train,
batch_size=self.batch_size,
shuffle=True,
generator=g_train,
drop_last=True)))
+ self._alt_training_set_loader = \
+ torch.utils.data.DataLoader(train, batch_size=ntrain)
self._validation_set = iter(
- torch.utils.data.DataLoader(val,
- batch_size=self.batch_size))
+ torch.utils.data.DataLoader(val, batch_size=self.batch_size))
self._test_set = iter(
- torch.utils.data.DataLoader(test,
- batch_size=self.batch_size))
+ torch.utils.data.DataLoader(test, batch_size=self.batch_size))
"""
Iterator over the training dataset.
"""
@property
def training_set(self):
+ if self._training_set is False: # not available; don't call split again
+ return
if self._training_set is None:
self.split()
return self._training_set
@@ -139,9 +148,32 @@ class LarvaDataset:
self.split()
return self._test_set
"""
+ This property was introduced to implement `class_weights`.
+ It does not memoize.
+ """
+ @property
+ def training_labels(self):
+        if self._alt_training_set_loader is None:
+            self.split()
+        if self._alt_training_set_loader is False: # not available (empty training split)
+            return
+ _, labels = next(iter(self._alt_training_set_loader))
+ try:
+ return labels.numpy()
+ except AttributeError:
+ return labels
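+    # Usage sketch (assumes `dataset` is an initialized LarvaDataset and
+    # labels are integer-encoded):
+    #
+    #     y = dataset.training_labels       # e.g. array([2, 0, 1, ...])
+    #     per_class = Counter(y.tolist())   # per-class counts in the split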
+ """
Draw an observation.
+
+    Warning: `getobs` actually drew a batch, not a single observation;
+    it now raises `NotImplementedError`. Use `getbatch` instead.
"""
def getobs(self, subset="train"):
+        raise NotImplementedError("getobs was renamed getbatch")
+ """
+    Draw a batch of observations.
+ """
+ def getbatch(self, subset="train"):
if subset.startswith("train"):
dataset = self.training_set
elif subset.startswith("val"):
@@ -150,24 +182,29 @@ class LarvaDataset:
dataset = self.test_set
return next(dataset)
"""
- Draw one or more observations.
+    Draw one or more batches. The former argument `n` was renamed
+    `nbatches`; positional calls remain backward compatible.
"""
- def getsample(self, subset="train", n=1):
+ def getsample(self, subset="train", nbatches=1):
if subset.startswith("train"):
dataset = self.training_set
elif subset.startswith("val"):
dataset = self.validation_set
elif subset.startswith("test"):
dataset = self.test_set
- if n == "all":
- n = len(dataset)
+ if nbatches == "all":
+ nbatches = len(dataset)
try:
- while 0 < n:
- n -= 1
+ while 0 < nbatches:
+ nbatches -= 1
yield next(dataset)
except StopIteration:
pass
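+    # Usage sketch (assumes an initialized `dataset`; batches are
+    # (input, label) pairs as yielded by the underlying DataLoader):
+    #
+    #     x, y = dataset.getbatch("train")           # one training batch
+    #     for x, y in dataset.getsample("val", 2):   # two validation batches
+    #         ...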
"""
*int*: number of time points in a segment.
"""
@property
@@ -178,6 +215,34 @@ class LarvaDataset:
self._window_length = anyrecord.shape[0]
return self._window_length
+ @property
+ def weight_classes(self):
+ return self._class_weights is not False
+
+ @weight_classes.setter
+ def weight_classes(self, do_weight):
+ if do_weight:
+ if self._class_weights is False:
+ self._class_weights = True
+ else:
+ self._class_weights = False
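+    # Setting `weight_classes = True` re-enables weighting after it was
+    # disabled; explicit or lazily computed weights are kept. Setting a falsy
+    # value disables weighting and discards any stored weights.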
+
+ @property
+ def class_weights(self):
+ if not isinstance(self._class_weights, np.ndarray) and self._class_weights in (None, True):
+            present, counts = np.unique(self.training_labels, return_counts=True)
+            # align counts with the full label set (integer-encoded labels
+            # assumed); classes absent from the training split count as zero
+            class_counts = np.zeros(len(self.labels), dtype=int)
+            class_counts[present] = counts
+ self._class_weights = 1 - class_counts / np.sum(class_counts)
+ return None if self._class_weights is False else self._class_weights
+
+ @class_weights.setter
+ def class_weights(self, weights):
+        if weights is not None and weights is not False:
+ weights = np.asarray(weights)
+ if len(weights) != len(self.labels):
+ raise ValueError("not as many weights as labels")
+ self._class_weights = weights
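+    # Worked example (hypothetical counts): with three classes and training
+    # counts (500, 300, 200), the lazily computed weights are
+    # 1 - (500, 300, 200)/1000 = (0.5, 0.7, 0.8): rarer classes weigh more.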
+
def subset_size(ntot, train_share, val_share, test_share):
ntrain = int(train_share * ntot)
nval = int(val_share * ntot)
diff --git a/src/taggingbackends/explorer.py b/src/taggingbackends/explorer.py
index 803eb72db4cdbced317f9cca00e1984db9ffc6e4..13bc5815381f358c03494e172d6d440d46b9c61a 100644
--- a/src/taggingbackends/explorer.py
+++ b/src/taggingbackends/explorer.py
@@ -473,7 +473,7 @@ run `poetry add {pkg}` from directory: \n
return input_files, labels
def generate_dataset(self, input_files,
- labels=None, window_length=20, sample_size=None,
+ labels=None, window_length=20, sample_size=None, balance=True,
frame_interval=None):
"""
Generate a *larva_dataset hdf5* file in data/interim/{instance}/
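+
+        The `balance` argument is forwarded to the Julia-side dataset writer
+        (`write_larva_dataset_hdf5`); pass `balance=False` to keep the
+        observed class proportions.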
@@ -484,6 +484,7 @@ run `poetry add {pkg}` from directory: \n
window_length,
labels=labels,
sample_size=sample_size,
+ balance=balance,
frameinterval=frame_interval)
def compile_trxmat_database(self, input_dir,
diff --git a/src/taggingbackends/main.py b/src/taggingbackends/main.py
index e5c435ef2b97aa6c32183a700ea4c0d29a007d3b..a6c6d9c19f3f25267b24b1cde85b8e83f035a8ec 100644
--- a/src/taggingbackends/main.py
+++ b/src/taggingbackends/main.py
@@ -7,7 +7,7 @@ def help(_print=False):
msg = """
Usage: tagging-backend [train|predict] --model-instance <name>
tagging-backend train ... --labels <comma-separated-list>
- tagging-backend train ... --sample-size <N>
+ tagging-backend train ... --sample-size <N> --balancing-strategy <strategy>
tagging-backend train ... --frame-interval <I> --window-length <T>
tagging-backend train ... --pretrained-model-instance <name>
tagging-backend predict ... --skip-make-dataset --sandbox <token>
@@ -62,6 +62,7 @@ def main(fun=None):
skip_make_dataset = skip_build_features = False
pretrained_model_instance = None
sandbox = False
+ balancing_strategy = 'auto'
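+    # 'auto' defers the choice of balancing strategy to the backend;
+    # the value is forwarded to both the make_dataset and train steps below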
unknown_args = {}
k = 2
while k < len(sys.argv):
@@ -100,6 +101,9 @@ def main(fun=None):
elif sys.argv[k] == "--sandbox":
k = k + 1
sandbox = sys.argv[k]
+ elif sys.argv[k] == "--balancing-strategy":
+ k = k + 1
+ balancing_strategy = sys.argv[k]
else:
unknown_args[sys.argv[k].lstrip('-').replace('-', '_')] = sys.argv[k+1]
k = k + 1
@@ -113,7 +117,8 @@ def main(fun=None):
for file in input_files:
backend.move_to_raw(file)
if not skip_make_dataset:
- make_dataset_kwargs = dict(labels_expected=train_or_predict == "train")
+ make_dataset_kwargs = dict(labels_expected=train_or_predict == "train",
+ balancing_strategy=balancing_strategy)
if labels:
make_dataset_kwargs["labels"] = labels
if sample_size:
@@ -134,7 +139,7 @@ def main(fun=None):
if train_or_predict == "predict":
backend._run_script(backend.predict_model, trailing=unknown_args)
else:
- train_kwargs = {}
+ train_kwargs = dict(balancing_strategy=balancing_strategy)
if pretrained_model_instance:
train_kwargs["pretrained_model_instance"] = pretrained_model_instance
backend._run_script(backend.train_model, trailing=unknown_args, **train_kwargs)