diff --git a/src/maggotuba/data/make_dataset.py b/src/maggotuba/data/make_dataset.py
index 8ff4dd4ccab970ddaceb6958e4fdeebdebac1157..174d074797c50f686181e4679e8ccd5d92eb5d5e 100644
--- a/src/maggotuba/data/make_dataset.py
+++ b/src/maggotuba/data/make_dataset.py
@@ -1,7 +1,7 @@
 import glob
 import pathlib
 
-def make_dataset(backend, labels_expected=False, **kwargs):
+def make_dataset(backend, labels_expected=False, trxmat_only=False, labels=None, **kwargs):
     if labels_expected:
         larva_dataset_file = glob.glob(str(backend.raw_data_dir() / "larva_dataset_*.hdf5"))
         if larva_dataset_file:
@@ -12,9 +12,16 @@ def make_dataset(backend, labels_expected=False, **kwargs):
             print(f"moving file to interim: {larva_dataset_file}")
             backend.move_to_interim(larva_dataset_file, copy=False)
         else:
+            if labels:
+                if isinstance(labels, str):
+                    labels = labels.split(',')
+                kwargs["labels"] = labels
             print("generating a larva_dataset file...")
             # generate a larva_dataset_*.hdf5 file in data/interim/{instance}/
-            out = backend.generate_dataset(backend.raw_data_dir(), **kwargs)
+            if trxmat_only:
+                out = backend.compile_trxmat_database(backend.raw_data_dir(), **kwargs)
+            else:
+                out = backend.generate_dataset(backend.raw_data_dir(), **kwargs)
             print(f"larva_dataset file generated: {out}")
 
 
diff --git a/src/maggotuba/models/train_model.py b/src/maggotuba/models/train_model.py
index 7a5d3a578a78d2912ec365b5a5758856c7cf4d8d..0401cd76cba4377ddff3fd0856fe30646ebc5b28 100644
--- a/src/maggotuba/models/train_model.py
+++ b/src/maggotuba/models/train_model.py
@@ -5,10 +5,12 @@ import numpy as np
 import json
 import torch
 import os
+import glob
 
 def train_model(backend):
     # make_dataset generated or moved the larva_dataset file into data/interim/{instance}/
-    larva_dataset_file = backend.list_interim_files()
+    # alternative: backend.list_interim_files("larva_dataset_*.hdf5"), but that lookup is recursive
+    larva_dataset_file = glob.glob(str(backend.interim_data_dir() / "larva_dataset_*.hdf5")) # not recursive (faster)
     assert len(larva_dataset_file) == 1
     dataset = LarvaDataset(larva_dataset_file[0], torch.Generator(device).manual_seed(42))
     nlabels = len(dataset.labels)