diff --git a/pyproject.toml b/pyproject.toml
index 53777e20537530cd2cc236d0077940209e2fd3a3..e9961fb106e3acaed7745c6a5611beafe090efcf 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "MaggotUBA-adapter"
-version = "0.9"
+version = "0.10"
 description = "Interface between MaggotUBA and the Nyx tagging UI"
 authors = ["François Laurent"]
 license = "MIT"
diff --git a/src/maggotuba/data/make_dataset.py b/src/maggotuba/data/make_dataset.py
index 50475f49d4da41e23e4360db0d20f520806b9fb3..2d0230632c2832e86dd4d55989ddc840fa60888a 100644
--- a/src/maggotuba/data/make_dataset.py
+++ b/src/maggotuba/data/make_dataset.py
@@ -1,7 +1,7 @@
 import glob
 import pathlib
 
-def make_dataset(backend, labels_expected=False, trxmat_only=False, **kwargs):
+def make_dataset(backend, labels_expected=False, trxmat_only=False, balancing_strategy='maggotuba', **kwargs):
     if labels_expected:
         larva_dataset_file = glob.glob(str(backend.raw_data_dir() / "larva_dataset_*.hdf5"))
         if larva_dataset_file:
@@ -17,7 +17,9 @@ def make_dataset(backend, labels_expected=False, trxmat_only=False, **kwargs):
         if False:#trxmat_only:
             out = backend.compile_trxmat_database(backend.raw_data_dir(), **kwargs)
         else:
-            out = backend.generate_dataset(backend.raw_data_dir(), **kwargs)
+            out = backend.generate_dataset(backend.raw_data_dir(),
+                    balance=isinstance(balancing_strategy, str) and balancing_strategy.lower() == 'maggotuba',
+                    **kwargs)
         print(f"larva_dataset file generated: {out}")
 
 
diff --git a/src/maggotuba/models/train_model.py b/src/maggotuba/models/train_model.py
index 1760a46f9f645a4a5e7c24069b6e6eb39eebbb9c..f9562c524a0de3cd72eed5a1817090286b346e9a 100644
--- a/src/maggotuba/models/train_model.py
+++ b/src/maggotuba/models/train_model.py
@@ -6,7 +6,7 @@ import json
 import glob
 
 def train_model(backend, layers=1, pretrained_model_instance="default",
-                subsets=(1, 0, 0), rng_seed=None, **kwargs):
+                subsets=(1, 0, 0), rng_seed=None, balancing_strategy='maggotuba', **kwargs):
     # make_dataset generated or moved the larva_dataset file into data/interim/{instance}/
     #larva_dataset_file = backend.list_interim_files("larva_dataset_*.hdf5") # recursive
     larva_dataset_file = glob.glob(str(backend.interim_data_dir() / "larva_dataset_*.hdf5")) # not recursive (faster)
@@ -14,6 +14,7 @@ def train_model(backend, layers=1, pretrained_model_instance="default",
     # subsets=(1, 0, 0) => all data are training data; no validation or test subsets
     dataset = LarvaDataset(larva_dataset_file[0], new_generator(rng_seed),
                            subsets=subsets, **kwargs)
+    dataset.weight_classes = isinstance(balancing_strategy, str) and (balancing_strategy.lower() == 'auto')
     labels = dataset.labels
     assert 0 < len(labels)
     labels = labels if isinstance(labels[0], str) else [s.decode() for s in labels]
diff --git a/src/maggotuba/models/trainers.py b/src/maggotuba/models/trainers.py
index 19b827ded7f2eac649692003be5747ec2314c1f4..154d6f139282c0a52d452c9f9937d991afa32cae 100644
--- a/src/maggotuba/models/trainers.py
+++ b/src/maggotuba/models/trainers.py
@@ -159,10 +159,13 @@ class MaggotTrainer:
 
     def train(self, dataset):
         self.prepare_dataset(dataset)
+        kwargs = {}
+        if dataset.class_weights is not None:
+            kwargs['weight'] = torch.from_numpy(dataset.class_weights.astype(np.float32)).to(self.device)
         model = self.model
         model.train() # this only sets the model in training mode (enables gradients)
         model.to(self.device)
-        criterion = nn.CrossEntropyLoss()
+        criterion = nn.CrossEntropyLoss(**kwargs)
         nsteps = self.config['optim_iter']
         grad_clip = self.config['grad_clip']
         # pre-train the classifier with static encoder weights
@@ -195,7 +198,7 @@ class MaggotTrainer:
         return self
 
     def draw(self, dataset, subset="train"):
-        data, expected = dataset.getobs(subset)
+        data, expected = dataset.getbatch(subset)
         if isinstance(data, list):
             data = torch.stack(data)
         data = data.to(torch.float32).to(self.device)
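
Note on the two balancing modes introduced above: `balancing_strategy='maggotuba'` (the default) keeps the previous behavior of balancing classes when the larva_dataset file is generated (`balance=True` passed to `generate_dataset`), whereas `balancing_strategy='auto'` leaves the dataset as generated and instead sets `dataset.weight_classes`, so that `MaggotTrainer.train` passes per-class weights to `nn.CrossEntropyLoss`. The sketch below illustrates the kind of inverse-frequency weighting this enables; it is a minimal example, not code from this patch, and `compute_class_weights` is a hypothetical helper (the actual `dataset.class_weights` is presumably computed inside LarvaDataset, outside this diff):

    import numpy as np
    import torch
    import torch.nn as nn

    def compute_class_weights(labels, num_classes):
        # hypothetical helper: weight each class by its inverse frequency,
        # scaled so a perfectly balanced dataset yields all-ones weights
        counts = np.bincount(labels, minlength=num_classes).astype(np.float32)
        return counts.sum() / (num_classes * np.maximum(counts, 1.0))

    # e.g. 3 classes with a strong imbalance: 90 / 9 / 1 samples
    labels = np.array([0] * 90 + [1] * 9 + [2] * 1)
    weights = compute_class_weights(labels, 3)  # ~[0.37, 3.70, 33.33]
    criterion = nn.CrossEntropyLoss(weight=torch.from_numpy(weights))

With such weights, misclassifying a sample of the rarest class costs roughly as much in aggregate as misclassifying the most common class, without resampling the data itself.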