diff --git a/src/taggingbackends/data/dataset.py b/src/taggingbackends/data/dataset.py
index 3159921b0dbb4b517b08a902ca4b2eadaf03a3c2..9cc670b5102e27d911806c50ff1a26132ef6dcb9 100644
--- a/src/taggingbackends/data/dataset.py
+++ b/src/taggingbackends/data/dataset.py
@@ -101,12 +101,13 @@ class LarvaDataset:
train, val, test = torch.utils.data.random_split(TorchDataset(),
[ntrain, nval, ntest],
generator=self.generator)
- self._training_set = iter(itertools.cycle(
- torch.utils.data.DataLoader(train,
- batch_size=self.batch_size,
- shuffle=True,
- generator=g_train,
- drop_last=True)))
+ if 0 < ntrain:
+ self._training_set = iter(itertools.cycle(
+ torch.utils.data.DataLoader(train,
+ batch_size=self.batch_size,
+ shuffle=True,
+ generator=g_train,
+ drop_last=True)))
self._validation_set = iter(
torch.utils.data.DataLoader(val,
batch_size=self.batch_size))