From 47f1d11504aa973c78f2310234c36f378af2f86d Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Fran=C3=A7ois=20Laurent?= <francois.laurent@posteo.net>
Date: Mon, 2 Jan 2023 00:14:50 +0100
Subject: [PATCH] allows no training subset sampling

---
 src/taggingbackends/data/dataset.py | 13 +++++++------
 1 file changed, 7 insertions(+), 6 deletions(-)

diff --git a/src/taggingbackends/data/dataset.py b/src/taggingbackends/data/dataset.py
index 3159921..9cc670b 100644
--- a/src/taggingbackends/data/dataset.py
+++ b/src/taggingbackends/data/dataset.py
@@ -101,12 +101,13 @@ class LarvaDataset:
         train, val, test = torch.utils.data.random_split(TorchDataset(),
                 [ntrain, nval, ntest],
                 generator=self.generator)
-        self._training_set = iter(itertools.cycle(
-            torch.utils.data.DataLoader(train,
-                batch_size=self.batch_size,
-                shuffle=True,
-                generator=g_train,
-                drop_last=True)))
+        if 0 < ntrain:
+            self._training_set = iter(itertools.cycle(
+                torch.utils.data.DataLoader(train,
+                    batch_size=self.batch_size,
+                    shuffle=True,
+                    generator=g_train,
+                    drop_last=True)))
         self._validation_set = iter(
             torch.utils.data.DataLoader(val,
                 batch_size=self.batch_size))
-- 
GitLab