diff --git a/src/taggingbackends/data/dataset.py b/src/taggingbackends/data/dataset.py
index 3159921b0dbb4b517b08a902ca4b2eadaf03a3c2..9cc670b5102e27d911806c50ff1a26132ef6dcb9 100644
--- a/src/taggingbackends/data/dataset.py
+++ b/src/taggingbackends/data/dataset.py
@@ -101,12 +101,13 @@ class LarvaDataset:
         train, val, test = torch.utils.data.random_split(TorchDataset(),
                 [ntrain, nval, ntest],
                 generator=self.generator)
-        self._training_set = iter(itertools.cycle(
-            torch.utils.data.DataLoader(train,
-                batch_size=self.batch_size,
-                shuffle=True,
-                generator=g_train,
-                drop_last=True)))
+        if 0 < ntrain:
+            self._training_set = iter(itertools.cycle(
+                torch.utils.data.DataLoader(train,
+                    batch_size=self.batch_size,
+                    shuffle=True,
+                    generator=g_train,
+                    drop_last=True)))
         self._validation_set = iter(
             torch.utils.data.DataLoader(val,
                 batch_size=self.batch_size))