diff --git a/src/taggingbackends/data/dataset.py b/src/taggingbackends/data/dataset.py index 3159921b0dbb4b517b08a902ca4b2eadaf03a3c2..9cc670b5102e27d911806c50ff1a26132ef6dcb9 100644 --- a/src/taggingbackends/data/dataset.py +++ b/src/taggingbackends/data/dataset.py @@ -101,12 +101,13 @@ class LarvaDataset: train, val, test = torch.utils.data.random_split(TorchDataset(), [ntrain, nval, ntest], generator=self.generator) - self._training_set = iter(itertools.cycle( - torch.utils.data.DataLoader(train, - batch_size=self.batch_size, - shuffle=True, - generator=g_train, - drop_last=True))) + if 0 < ntrain: + self._training_set = iter(itertools.cycle( + torch.utils.data.DataLoader(train, + batch_size=self.batch_size, + shuffle=True, + generator=g_train, + drop_last=True))) self._validation_set = iter( torch.utils.data.DataLoader(val, batch_size=self.batch_size))