diff --git a/src/maggotuba/models/denselayer.py b/src/maggotuba/models/denselayer.py
index 9ae12b8913ce0eb35cbf31f94545a618b2770196..8e89c50810a9fb061da829c24e3d8716db547899 100644
--- a/src/maggotuba/models/denselayer.py
+++ b/src/maggotuba/models/denselayer.py
@@ -7,21 +7,43 @@ import torch.nn as nn
 from behavior_model.models.neural_nets import Encoder, device
 import behavior_model.data.utils as data_utils
 
+class DeepLinear(nn.Module):
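+    """Stack of `n_layers` linear layers with ReLU activations in
+    between; with `n_layers=1` this reduces to a single `nn.Linear`,
+    i.e. the classifier it replaces."""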
+    def __init__(self, n_input, n_output, n_layers=1):
+        super().__init__()
+        if n_layers is None:
+            n_layers = 1
+        layers = []
+        for _ in range(n_layers - 1):
+            layers.append(nn.Linear(n_input, n_input))
+            layers.append(nn.ReLU())
+        layers.append(nn.Linear(n_input, n_output))
+        self.classifier = nn.Sequential(*layers)
+
+    def _init_layers(self):
+        # Xavier-uniform weights and zero biases, as for the former
+        # single-layer classifier; ReLU modules carry no parameters.
+        for layer in self.classifier:
+            if isinstance(layer, nn.Linear):
+                nn.init.xavier_uniform_(layer.weight)
+                nn.init.zeros_(layer.bias)
+
+    def forward(self, x):
+        return self.classifier(x)
 
 class SupervisedMaggot(nn.Module):
     def __init__(self, n_latent_features, n_behaviors, enc_config, enc_path,
-            clf_path=None):
+            clf_path=None, n_layers=1):
         super().__init__()
         # Pretrained or trained MaggotUBA encoder
         self.encoder = encoder = Encoder(**enc_config)
         encoder.load_state_dict(torch.load(enc_path))
         # Classifier stacked atop the encoder
-        self.clf = nn.Linear(n_latent_features, n_behaviors)
+        self.clf = DeepLinear(n_latent_features, n_behaviors, n_layers)
         if clf_path:
             self.clf.load_state_dict(torch.load(clf_path))
         else:
-            nn.init.xavier_uniform_(self.clf.weight)
-            nn.init.zeros_(self.clf.bias)
+            self.clf._init_layers()
 
     def forward(self, x):
         #x = torch.flip(x, (2,))
@@ -44,15 +66,15 @@ class DenseLayer:
             config=None,
             autoencoder_config=None,
             n_behaviors=None,
+            n_layers=1,
             average_body_length=None,
             device=device):
         # MaggotUBA autoencoder config
         self._config = autoencoder_config
         self._clf_config = config
         self.prepend_log_dir = True
-        self._model = None
-        if n_behaviors is not None:
-            self.n_behaviors = n_behaviors
+        self._n_behaviors = n_behaviors
+        self._n_layers = n_layers
         self.average_body_length = average_body_length
         self.device = device
 
@@ -66,8 +88,10 @@ class DenseLayer:
         if self._config is None:
             self._config = self.clf_config.get("autoencoder_config", None)
         if isinstance(self._config, (str, pathlib.Path)):
-            with open(self._config, "r") as f:
+            path = self._config
+            with open(path, "r") as f:
                 self._config = json.load(f)
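+            # record where the config was loaded from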
+            self._config["config"] = str(path)
         return self._config
 
     @config.setter
@@ -117,12 +141,23 @@ class DenseLayer:
 
     @property
     def n_behaviors(self):
-        return self.clf_config.get("n_behaviors", None)
+        return self.clf_config.get("n_behaviors", self._n_behaviors)
 
     @n_behaviors.setter
     def n_behaviors(self, n):
         self.clf_config["n_behaviors"] = n
 
+    # `clf_depth` counts the hidden layers, so `n_layers == clf_depth + 1`
+    @property
+    def n_layers(self):
+        try:
+            return self.clf_config["clf_depth"] + 1
+        except KeyError:
+            return self._n_layers
+
+    @n_layers.setter
+    def n_layers(self, n):
+        self.clf_config["clf_depth"] = 0 if n is None else n - 1
+
     def window(self, data):
         winlen = self.config["len_traj"]
         N = data.shape[0]+1
@@ -190,12 +225,15 @@ class DenseLayer:
         if train:
             return self.model(x)
         else:
-            if not isinstance(x, torch.Tensor):
+            if isinstance(x, torch.Tensor):
+                if x.dtype is not torch.float32:
+                    x = x.to(torch.float32)
+            else:
                 x = torch.from_numpy(x.astype(np.float32))
             y = self.model(x.to(self.device))
             return y.cpu().numpy()
 
-    def train(self, dataset):
+    def prepare_dataset(self, dataset):
         try:
             dataset.batch_size
         except AttributeError:
@@ -214,6 +252,9 @@ class DenseLayer:
         if not (0 <= midpoint - before and midpoint + after <= dataset.window_length):
             raise ValueError(f"the dataset can provide segments of up to {dataset.window_length} time points")
         dataset._mask = slice(midpoint - before, midpoint + after)
+
+    def train(self, dataset):
+        self.prepare_dataset(dataset)
         #
         enc_path = "best_validated_encoder.pt"
         if self.prepend_log_dir:
@@ -227,6 +268,7 @@ class DenseLayer:
                 n_behaviors=self.n_behaviors,
                 enc_config=self.config,
                 enc_path=enc_path,
+                n_layers=self.n_layers,
                 )
         model.train() # this only sets the model in training mode (enables gradients)
         model.to(self.device)
@@ -256,38 +298,53 @@ class DenseLayer:
         #
         return self
 
-    def draw(self, dataset):
-        data, expected = dataset.getsample()
+    def draw(self, dataset, subset="train"):
+        data, expected = dataset.getobs(subset)
         if isinstance(data, list):
             data = torch.stack(data)
         data = data.to(torch.float32).to(self.device)
         if isinstance(expected, list):
             expected = torch.stack(expected)
-        expected = expected.to(torch.long).to(self.device)
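+        # only training targets are cast to long and moved to the
+        # device; evaluation targets are returned as-is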
+        if subset.startswith("train"):
+            expected = expected.to(torch.long).to(self.device)
         return data, expected
 
     @torch.no_grad()
-    def predict(self, all_spines):
-        data = self.preprocess(all_spines)
-        if data is None:
-            return
+    def predict(self, data, subset=None):
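+        # Two calling conventions: with subset=None, `data` holds raw
+        # spines to preprocess and label; otherwise `data` is a
+        # LarvaDataset and the (predicted, expected) label arrays for
+        # the named subset are returned.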
         self.model = model = SupervisedMaggot(
                 n_latent_features=self.config["dim_latent"],
                 n_behaviors=self.n_behaviors,
                 enc_config=self.config,
                 enc_path=self.enc_path,
                 clf_path=self.clf_path,
+                n_layers=self.n_layers,
                 )
         model.eval()
         model.to(self.device)
-        output = self.forward(data)
-        label_ids = np.argmax(output, axis=1)
-        try:
-            self.labels
-        except AttributeError:
-            self.labels = self.clf_config["behavior_labels"]
-        labels = [self.labels[label] for label in label_ids]
-        return labels
+        if subset is None:
+            data = self.preprocess(data)
+            if data is None:
+                return
+            output = self.forward(data)
+            label_ids = np.argmax(output, axis=1)
+            try:
+                self.labels
+            except AttributeError:
+                self.labels = self.clf_config["behavior_labels"]
+            labels = [self.labels[label] for label in label_ids]
+            return labels
+        else:
+            dataset = data
+            self.prepare_dataset(dataset)
+            predicted, expected = [], []
+            for data, exp in dataset.getsample(subset, "all"):
+                output = self.forward(data)
+                pred = np.argmax(output, axis=1)
+                exp = exp.numpy()
+                assert pred.size == exp.size
+                predicted.append(pred)
+                expected.append(exp)
+            return np.concatenate(predicted), np.concatenate(expected)
 
     def save(self, config_path="clf_config.json", config_only=False):
         if self.prepend_log_dir:
@@ -302,8 +359,8 @@ class DenseLayer:
                 clf_path=self.clf_path,
                 n_behaviors=self.n_behaviors,
                 behavior_labels=self.labels,
+                clf_depth=self.n_layers - 1,
                 # additional information (not reused):
-                clf_depth=0,
                 bias=True,
                 init="xavier",
                 loss="cross-entropy",
@@ -311,3 +368,5 @@ class DenseLayer:
                 target=["present"],
                 ), f, indent=2)
 
+def new_generator():
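+    # fixed-seed generator so that dataset sampling is reproducible
+    # across training and prediction runs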
+    return torch.Generator(device).manual_seed(42)
diff --git a/src/maggotuba/models/predict_model.py b/src/maggotuba/models/predict_model.py
index 9ed347fa2ef2fdd50920d52c708b4f4826f16da9..1580c11b89e6a37fe8dc3195d5655247a2777ff3 100644
--- a/src/maggotuba/models/predict_model.py
+++ b/src/maggotuba/models/predict_model.py
@@ -3,11 +3,11 @@ from taggingbackends.data.chore import load_spine
 import taggingbackends.data.fimtrack as fimtrack
 from taggingbackends.data.labels import Labels
 from taggingbackends.features.skeleton import get_5point_spines
-from maggotuba.models.denselayer import DenseLayer
+from maggotuba.models.denselayer import DenseLayer, new_generator
 import numpy as np
 import json
 
-def predict_model(backend):
+def predict_model(backend, **kwargs):
     """
     This function generates predicted labels for all the input data.
 
@@ -27,10 +27,32 @@ def predict_model(backend):
     # initialize output labels
     input_files, labels = backend.prepare_labels(input_files)
     assert 0 < len(input_files)
+    # load the model
+    model_files = backend.list_model_files()
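+    # among the available *config.json files, prefer clf_config.json
+    # (the classifier config) over the autoencoder config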
+    config_file = [file for file in model_files if file.name.endswith("config.json")]
+    if 1 < len(config_file):
+        config_file = [file for file in config_file if file.name.endswith("clf_config.json")]
+    model = DenseLayer(config_file[-1])
     #
+    labels.labelspec = model.clf_config["behavior_labels"]
+    #
+    if len(input_files) == 1:
+        file = input_files[0]
+        if file.name.startswith("larva_dataset_") and file.name.endswith(".hdf5"):
+            ret = predict_larva_dataset(backend, model, file, labels, **kwargs)
+            return labels if ret is None else ret
+    #
+    ret = predict_individual_data_files(backend, model, input_files, labels)
+    return labels if ret is None else ret
+
+def predict_individual_data_files(backend, model, input_files, labels):
+    _break = False # for now, a single file can be labelled at a time
     for file in input_files:
         # load the input data (or features)
-        if file.name.endswith(".spine"):
+        if _break:
+            print(f"ignoring file: {file.name}")
+            continue
+        elif file.name.endswith(".spine"):
             spine = load_spine(file)
             run = spine["date_time"].iloc[0]
             larvae = spine["larva_id"].values
@@ -53,6 +75,7 @@ def predict_model(backend):
             t, data = fimtrack.read_spines(file, fps=labels.camera_framerate)
             run = "NA"
         else:
+            print(f"ignoring file: {file.name}")
             continue
         # downsample the skeleton
         if isinstance(data, dict):
@@ -60,12 +83,6 @@ def predict_model(backend):
                 data[larva] = get_5point_spines(data[larva])
         else:
             data = get_5point_spines(data)
-        # load the model
-        model_files = backend.list_model_files()
-        config_file = [file for file in model_files if file.name.endswith("config.json")]
-        if 1 < len(config_file):
-            config_file = [file for file in config_file if file.name.endswith("clf_config.json")]
-        model = DenseLayer(config_file[-1])
         # assign labels
         if isinstance(data, dict):
             ref_length = np.median(np.concatenate([
@@ -91,13 +108,14 @@ def predict_model(backend):
                 else:
                     labels[run, larva] = dict(zip(t[mask], predictions))
         # save the predicted labels to file
-        # labels.labelspec = {
-        #         "names": ["run", "bend", "stop", "hunch", "back", "roll"],
-        #         "colors": ["#000000", "#ff0000", "#00ff00", "#0000ff",
-        #             "#00ffff", "#ffff00"]
-        #         }
-        labels.labelspec = model.clf_config["behavior_labels"]
         labels.dump(backend.processed_data_dir() / "predicted.labels")
+        #
+        _break = True
+
+def predict_larva_dataset(backend, model, file, labels, subset="validation"):
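+    # delegate to DenseLayer.predict, which returns the (predicted,
+    # expected) label arrays for the requested subset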
+    from taggingbackends.data.dataset import LarvaDataset
+    dataset = LarvaDataset(file, new_generator())
+    return model.predict(dataset, subset)
 
 
 from taggingbackends.main import main
diff --git a/src/maggotuba/models/train_model.py b/src/maggotuba/models/train_model.py
index c121ddcacc4bc3f697bf8d3b2a9c82873ac3c0ba..a19f121173f0e23c6058daeb811fd0e01b4ee427 100644
--- a/src/maggotuba/models/train_model.py
+++ b/src/maggotuba/models/train_model.py
@@ -1,18 +1,18 @@
 from taggingbackends.data.labels import Labels
 from taggingbackends.data.dataset import LarvaDataset
-from maggotuba.models.denselayer import DenseLayer, device
+from maggotuba.models.denselayer import DenseLayer, new_generator
 import numpy as np
 import json
 import torch
 import os
 import glob
 
-def train_model(backend, pretrained_model_instance="default"):
+def train_model(backend, layers=1, pretrained_model_instance="default"):
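+    # `layers` sets the number of linear layers in the classifier head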
     # make_dataset generated or moved the larva_dataset file into data/interim/{instance}/
     #larva_dataset_file = backend.list_interim_files("larva_dataset_*.hdf5") # recursive
     larva_dataset_file = glob.glob(str(backend.interim_data_dir() / "larva_dataset_*.hdf5")) # not recursive (faster)
     assert len(larva_dataset_file) == 1
-    dataset = LarvaDataset(larva_dataset_file[0], torch.Generator(device).manual_seed(42))
+    dataset = LarvaDataset(larva_dataset_file[0], new_generator())
     nlabels = len(dataset.labels)
     assert 0 < nlabels
     # copy the pretrained model into the model instance directory
@@ -40,7 +40,8 @@ def train_model(backend, pretrained_model_instance="default"):
                 with open(str(dst), "wb") as o:
                     o.write(i.read())
     # load the pretrained model
-    model = DenseLayer(autoencoder_config=config_file, n_behaviors=nlabels)
+    model = DenseLayer(autoencoder_config=config_file, n_behaviors=nlabels,
+            n_layers=layers)
     # fine-tune and save the model
     model.train(dataset)
     model.save()