From 47f1d11504aa973c78f2310234c36f378af2f86d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Laurent?= <francois.laurent@posteo.net> Date: Mon, 2 Jan 2023 00:14:50 +0100 Subject: [PATCH] allows no training subset sampling --- src/taggingbackends/data/dataset.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/src/taggingbackends/data/dataset.py b/src/taggingbackends/data/dataset.py index 3159921..9cc670b 100644 --- a/src/taggingbackends/data/dataset.py +++ b/src/taggingbackends/data/dataset.py @@ -101,12 +101,13 @@ class LarvaDataset: train, val, test = torch.utils.data.random_split(TorchDataset(), [ntrain, nval, ntest], generator=self.generator) - self._training_set = iter(itertools.cycle( - torch.utils.data.DataLoader(train, - batch_size=self.batch_size, - shuffle=True, - generator=g_train, - drop_last=True))) + if 0 < ntrain: + self._training_set = iter(itertools.cycle( + torch.utils.data.DataLoader(train, + batch_size=self.batch_size, + shuffle=True, + generator=g_train, + drop_last=True))) self._validation_set = iter( torch.utils.data.DataLoader(val, batch_size=self.batch_size)) -- GitLab