diff --git a/src/LarvaDatasets.jl b/src/LarvaDatasets.jl
index 3e89f955b78e957450bf24264f20e5bbbbab3223..e8ecf29445ecab112b712db76e12b23f5773bdd7 100644
--- a/src/LarvaDatasets.jl
+++ b/src/LarvaDatasets.jl
@@ -418,7 +418,7 @@ end
 """
     write_larva_dataset_hdf5(output_directory, input_files, window_length=20)
     write_larva_dataset_hdf5(...; labels=nothing, labelpointers=nothing)
-    write_larva_dataset_hdf5(...; sample_size=nothing, chunks=false, shallow=false)
+    write_larva_dataset_hdf5(...; sample_size=nothing, balance=true, chunks=false, shallow=false)
     write_larva_dataset_hdf5(...; file_filter, timestep_filter)
 
 Sample series of 5-point spines from data files and save them in a hdf5 file,
@@ -451,6 +451,8 @@ See also [`labelledfiles`](@ref Main.PlanarLarvae.Formats.labelledfiles) and
 [`labelcounts`](@ref) and their `selection_rule` arguments that correspond to
 `file_filter` and `timestep_filter` respectively.
 
+`balance` controls MaggotUBA-like class balancing. See also [`balancedcounts`](@ref).
+
 Note that, if `input_data` lists all files, `labelledfiles` is not called and
 arguments `file_filter`, `chunks` and `shallow` are not used. Similarly, if
 `labelpointers` is defined, `labelcounts` is not called and argument
@@ -507,6 +509,7 @@ function write_larva_dataset_hdf5(output_dir::String,
     else
         sample_sizes = counts
         total_sample_size = sum(values(sample_sizes))
+        isnothing(sample_size) || @error "Argument sample_size is not supported with the specified balancing strategy"
     end
     @info "Sample sizes (observed, selected):" [Symbol(label) => (count, get(sample_sizes, label, 0)) for (label, count) in pairs(counts)]...
     date = Dates.format(Dates.now(), "yyyy_mm_dd")
diff --git a/src/taggingbackends/data/dataset.py b/src/taggingbackends/data/dataset.py
index 9cc670b5102e27d911806c50ff1a26132ef6dcb9..3e9d2d32ca6520b0c9e848cf20bbce58bf1cd37a 100644
--- a/src/taggingbackends/data/dataset.py
+++ b/src/taggingbackends/data/dataset.py
@@ -1,6 +1,8 @@
 import h5py
 import pathlib
 import itertools
+import numpy as np
+from collections import Counter
 
 """
 Torch-like dataset class for *larva_dataset hdf5* files.
@@ -16,6 +18,9 @@ class LarvaDataset:
         self._sample_size = None
         self._mask = slice(0, None)
         self._window_length = None
+        self._class_weights = None
+        # this attribute was introduced to implement `training_labels`
+        self._alt_training_set_loader = None
     """
     *h5py.File*: *larva_dataset hdf5* file handler.
     """
@@ -101,24 +106,30 @@ class LarvaDataset:
         train, val, test = torch.utils.data.random_split(TorchDataset(),
                 [ntrain, nval, ntest],
                 generator=self.generator)
-        if 0 < ntrain:
+        if ntrain == 0:
+            self._training_set = self._alt_training_set_loader = False
+        else:
             self._training_set = iter(itertools.cycle(
                 torch.utils.data.DataLoader(train,
                     batch_size=self.batch_size,
                     shuffle=True,
                     generator=g_train,
                     drop_last=True)))
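+            # keep a second loader over the train split that serves it in a
+            # single batch, so that `training_labels` can list the drawn labels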
""" @property def training_set(self): + if self._training_set is False: # not available; don't call split again + return if self._training_set is None: self.split() return self._training_set @@ -139,9 +148,32 @@ class LarvaDataset: self.split() return self._test_set """ + This property was introduced to implement `class_weights`. + It does not memoize. + """ + @property + def training_labels(self): + if self._alt_training_set_loader is False: # not available; don't call split again + return + if self._alt_training_set_loader is None: + self.split() + _, labels = next(iter(self._alt_training_set_loader)) + try: + return labels.numpy() + except AttributeError: + return labels + """ Draw an observation. + + Warning: this actually drew a batch, not an observation; + it now throws an error. """ def getobs(self, subset="train"): + raise NotImplementedError("renamed getbatch") + """ + Draw a series of observations (or batch). + """ + def getbatch(self, subset="train"): if subset.startswith("train"): dataset = self.training_set elif subset.startswith("val"): @@ -150,24 +182,29 @@ class LarvaDataset: dataset = self.test_set return next(dataset) """ - Draw one or more observations. + Draw one or more batches. """ - def getsample(self, subset="train", n=1): + def getsample(self, subset="train", nbatches=1): if subset.startswith("train"): dataset = self.training_set elif subset.startswith("val"): dataset = self.validation_set elif subset.startswith("test"): dataset = self.test_set - if n == "all": - n = len(dataset) + if nbatches == "all": + nbatches = len(dataset) try: - while 0 < n: - n -= 1 + while 0 < nbatches: + nbatches -= 1 yield next(dataset) except StopIteration: pass """ + Alias for `getbatch`, for backward compatibility. + """ + def getsample(self, subset, n): + return self.getbatch(subset, n) + """ *int*: number of time points in a segment. 
""" @property @@ -178,6 +215,34 @@ class LarvaDataset: self._window_length = anyrecord.shape[0] return self._window_length + @property + def weight_classes(self): + return self._class_weights is not False + + @weight_classes.setter + def weight_classes(self, do_weight): + if do_weight: + if self._class_weights is False: + self._class_weights = True + else: + self._class_weights = False + + @property + def class_weights(self): + if not isinstance(self._class_weights, np.ndarray) and self._class_weights in (None, True): + _, class_counts = np.unique(self.training_labels, return_counts=True) + class_counts = np.array([class_counts[i] for i in range(len(self.labels))]) + self._class_weights = 1 - class_counts / np.sum(class_counts) + return None if self._class_weights is False else self._class_weights + + @class_weights.setter + def class_weights(self, weights): + if weights not in (None, False): + weights = np.asarray(weights) + if len(weights) != len(self.labels): + raise ValueError("not as many weights as labels") + self._class_weights = weights + def subset_size(ntot, train_share, val_share, test_share): ntrain = int(train_share * ntot) nval = int(val_share * ntot) diff --git a/src/taggingbackends/explorer.py b/src/taggingbackends/explorer.py index 803eb72db4cdbced317f9cca00e1984db9ffc6e4..13bc5815381f358c03494e172d6d440d46b9c61a 100644 --- a/src/taggingbackends/explorer.py +++ b/src/taggingbackends/explorer.py @@ -473,7 +473,7 @@ run `poetry add {pkg}` from directory: \n return input_files, labels def generate_dataset(self, input_files, - labels=None, window_length=20, sample_size=None, + labels=None, window_length=20, sample_size=None, balance=True, frame_interval=None): """ Generate a *larva_dataset hdf5* file in data/interim/{instance}/ @@ -484,6 +484,7 @@ run `poetry add {pkg}` from directory: \n window_length, labels=labels, sample_size=sample_size, + balance=balance, frameinterval=frame_interval) def compile_trxmat_database(self, input_dir, diff --git a/src/taggingbackends/main.py b/src/taggingbackends/main.py index e5c435ef2b97aa6c32183a700ea4c0d29a007d3b..a6c6d9c19f3f25267b24b1cde85b8e83f035a8ec 100644 --- a/src/taggingbackends/main.py +++ b/src/taggingbackends/main.py @@ -7,7 +7,7 @@ def help(_print=False): msg = """ Usage: tagging-backend [train|predict] --model-instance <name> tagging-backend train ... --labels <comma-separated-list> - tagging-backend train ... --sample-size <N> + tagging-backend train ... --sample-size <N> --balancing-strategy <strategy> tagging-backend train ... --frame-interval <I> --window-length <T> tagging-backend train ... --pretrained-model-instance <name> tagging-backend predict ... 
@@ -484,6 +489,7 @@
                 window_length,
                 labels=labels,
                 sample_size=sample_size,
+                balance=balance,
                 frameinterval=frame_interval)
 
     def compile_trxmat_database(self, input_dir,
diff --git a/src/taggingbackends/main.py b/src/taggingbackends/main.py
index e5c435ef2b97aa6c32183a700ea4c0d29a007d3b..a6c6d9c19f3f25267b24b1cde85b8e83f035a8ec 100644
--- a/src/taggingbackends/main.py
+++ b/src/taggingbackends/main.py
@@ -7,7 +7,7 @@ def help(_print=False):
     msg = """
 Usage:  tagging-backend [train|predict] --model-instance <name>
         tagging-backend train ... --labels <comma-separated-list>
-        tagging-backend train ... --sample-size <N>
+        tagging-backend train ... --sample-size <N> --balancing-strategy <strategy>
         tagging-backend train ... --frame-interval <I> --window-length <T>
         tagging-backend train ... --pretrained-model-instance <name>
         tagging-backend predict ... --skip-make-dataset --sandbox <token>
@@ -62,6 +62,7 @@ def main(fun=None):
     skip_make_dataset = skip_build_features = False
     pretrained_model_instance = None
     sandbox = False
+    balancing_strategy = 'auto'
     unknown_args = {}
     k = 2
     while k < len(sys.argv):
@@ -100,6 +101,9 @@ def main(fun=None):
         elif sys.argv[k] == "--sandbox":
             k = k + 1
             sandbox = sys.argv[k]
+        elif sys.argv[k] == "--balancing-strategy":
+            k = k + 1
+            balancing_strategy = sys.argv[k]
         else:
             unknown_args[sys.argv[k].lstrip('-').replace('-', '_')] = sys.argv[k+1]
             k = k + 1
@@ -113,7 +117,8 @@ def main(fun=None):
     for file in input_files:
         backend.move_to_raw(file)
     if not skip_make_dataset:
-        make_dataset_kwargs = dict(labels_expected=train_or_predict == "train")
+        make_dataset_kwargs = dict(labels_expected=train_or_predict == "train",
+                balancing_strategy=balancing_strategy)
         if labels:
             make_dataset_kwargs["labels"] = labels
         if sample_size:
@@ -134,7 +139,9 @@ def main(fun=None):
     if train_or_predict == "predict":
         backend._run_script(backend.predict_model, trailing=unknown_args)
     else:
-        train_kwargs = {}
+        train_kwargs = dict(balancing_strategy=balancing_strategy)
         if pretrained_model_instance:
             train_kwargs["pretrained_model_instance"] = pretrained_model_instance
         backend._run_script(backend.train_model, trailing=unknown_args, **train_kwargs)
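+        # `balancing_strategy` is forwarded to the training script here, and
+        # to the dataset-generation step via `make_dataset_kwargs` above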