diff --git a/pyproject.toml b/pyproject.toml
index 01e9c3d3c407fff7624add08df7e33820c8bc4fd..8047108a7bab45f3a5dc8e9bc1e8f8d0a4d3b73d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -14,7 +14,7 @@ maggotuba-core = {git = "https://gitlab.pasteur.fr/nyx/MaggotUBA-core", tag = "v
 torch = "^1.11.0"
 numpy = "^1.19.3"
 protobuf = "3.9.2"
-taggingbackends = {git = "https://gitlab.pasteur.fr/nyx/TaggingBackends", tag = "v0.16"}
+taggingbackends = {git = "https://gitlab.pasteur.fr/nyx/TaggingBackends", rev = "dev"}
 
 [build-system]
 requires = ["poetry-core>=1.0.0"]
diff --git a/src/maggotuba/features/__init__.py b/src/maggotuba/features/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/src/maggotuba/features/preprocess.py b/src/maggotuba/features/preprocess.py
new file mode 100644
index 0000000000000000000000000000000000000000..0b98c4de7e5c93f953db28db6e87e702058ec98a
--- /dev/null
+++ b/src/maggotuba/features/preprocess.py
@@ -0,0 +1,108 @@
+import numpy as np
+
+
+class Preprocessor:
+    def __init__(self, configured, average_body_length=1.0):
+        self.configured = configured
+        # usually set later
+        self.average_body_length = average_body_length
+
+    @property
+    def config(self):
+        return self.configured.config
+
+    @property
+    def swap_head_tail(self):
+        return self.config.get('swap_head_tail', True)
+
+    @swap_head_tail.setter
+    def swap_head_tail(self, b):
+        self.config['swap_head_tail'] = b
+
+    def window(self, t, data):
+        interpolation_args = {k: self.config[k]
+                              for k in ('spine_interpolation', 'frame_interval')
+                              if k in self.config}
+        winlen = self.config["len_traj"]
+        N = data.shape[0]+1
+        if interpolation_args:
+            for m in range(0, N-1):
+                win = interpolate(t, data, m, winlen, **interpolation_args)
+                if win is not None:
+                    assert win.shape[0] == winlen
+                    yield t[m], win
+        else:
+            for m in range(0, N-winlen):
+                n = m + winlen
+                yield t[(m + n) // 2], data[m:n]
+
+    def pad(self, target_t, defined_t, data):
+        if data.shape[0] == 1:
+            return np.repeat(data, len(target_t), axis=0)
+        else:
+            head = searchsortedfirst(target_t, defined_t[0])
+            tail = len(target_t) - (searchsortedlast(target_t, defined_t[-1]) + 1)
+            ind = np.r_[
+                    np.zeros(head, dtype=int),
+                    np.arange(data.shape[0]),
+                    (data.shape[1]-1) * np.ones(tail, dtype=int),
+                    ]
+            if len(ind) != len(target_t):
+                raise RuntimeError('missing time steps')
+            return data[ind]
+
+    def body_length(self, data):
+        dx = np.diff(data[:,0::2], axis=1)
+        dy = np.diff(data[:,1::2], axis=1)
+        return np.sum(np.sqrt(dx*dx + dy*dy), axis=1)
+
+    def normalize(self, w):
+        # center coordinates
+        wc = np.mean(w[:,4:6], axis=0, keepdims=True)
+        w = w - np.tile(wc, (1, 5))
+        # rotate
+        v = np.mean(w[:,8:10] - w[:,0:2], axis=0)
+        vnorm = np.sqrt(np.dot(v, v))
+        if vnorm == 0:
+            logging.warning('null distance between head and tail')
+        else:
+            v = v / vnorm
+        c, s = v / self.average_body_length # scale using the rotation matrix
+        rot = np.array([[ c, s],
+                        [-s, c]]) # clockwise rotation
+        w = np.einsum("ij,jkl", rot, np.reshape(w.T, (2, 5, -1), order='F'))
+        return w
+
+    """
+    Preprocess a single track.
+
+    This includes running a sliding window, resampling the track in each window,
+    normalizing the spines, etc.
+    """
+    def preprocess(self, t, data):
+        defined_t = []
+        ws = []
+        for t_, w in self.window(t, data):
+            defined_t.append(t_)
+            ws.append(self.normalize(w))
+        if ws:
+            ret = self.pad(t, defined_t, np.stack(ws))
+            if self.swap_head_tail:
+                ret = ret[:,:,::-1,:]
+            return ret
+
+    def __callable__(self, *args):
+        return self.proprocess(*args)
+
+
+# Julia functions
+def searchsortedfirst(xs, x):
+    for i, x_ in enumerate(xs):
+        if x <= x_:
+            return i
+
+def searchsortedlast(xs, x):
+    for i in range(len(xs))[::-1]:
+        x_ = xs[i]
+        if x_ <= x:
+            return i
diff --git a/src/maggotuba/models/trainers.py b/src/maggotuba/models/trainers.py
index 2603a8490dc3ec052ae15962592a7b85ff5bcab8..40f733bae9a491b5d652ffc108b039a57971ea40 100644
--- a/src/maggotuba/models/trainers.py
+++ b/src/maggotuba/models/trainers.py
@@ -3,6 +3,7 @@ import torch
 import torch.nn as nn
 from behavior_model.models.neural_nets import device
 from maggotuba.models.modules import SupervisedMaggot, MultiscaleSupervisedMaggot, MaggotBag
+from maggotuba.features.preprocess import Preprocessor
 from taggingbackends.features.skeleton import interpolate
 from taggingbackends.explorer import BackendExplorer, check_permissions
 import logging
@@ -26,7 +27,7 @@ class MaggotTrainer:
     def __init__(self, cfgfilepath, behaviors=[], n_layers=1, n_iterations=None,
             average_body_length=1.0, device=device):
         self.model = SupervisedMaggot(cfgfilepath, behaviors, n_layers, n_iterations)
-        self.average_body_length = average_body_length # usually set later
+        self.preprocessor = Preprocessor(self, average_body_length)
         self.device = device
 
     @property
@@ -45,88 +46,8 @@ class MaggotTrainer:
     def labels(self, labels):
         self.model.clf.behavior_labels = labels
 
-    @property
-    def swap_head_tail(self):
-        return self.config.get('swap_head_tail', True)
-
-    @swap_head_tail.setter
-    def swap_head_tail(self, b):
-        self.config['swap_head_tail'] = b
-
-    ### TODO: move parts of the below code in a features module
-    # all the code in this section is called by `predict` only
-    def window(self, t, data):
-        interpolation_args = {k: self.config[k]
-                              for k in ('spine_interpolation', 'frame_interval')
-                              if k in self.config}
-        winlen = self.config["len_traj"]
-        N = data.shape[0]+1
-        if interpolation_args:
-            for m in range(0, N-1):
-                win = interpolate(t, data, m, winlen, **interpolation_args)
-                if win is not None:
-                    assert win.shape[0] == winlen
-                    yield t[m], win
-        else:
-            for m in range(0, N-winlen):
-                n = m + winlen
-                yield t[(m + n) // 2], data[m:n]
-
-    def pad(self, target_t, defined_t, data):
-        if data.shape[0] == 1:
-            return np.repeat(data, len(target_t), axis=0)
-        else:
-            head = searchsortedfirst(target_t, defined_t[0])
-            tail = len(target_t) - (searchsortedlast(target_t, defined_t[-1]) + 1)
-            ind = np.r_[
-                    np.zeros(head, dtype=int),
-                    np.arange(data.shape[0]),
-                    (data.shape[1]-1) * np.ones(tail, dtype=int),
-                    ]
-            if len(ind) != len(target_t):
-                raise RuntimeError('missing time steps')
-            return data[ind]
-
     def body_length(self, data):
-        dx = np.diff(data[:,0::2], axis=1)
-        dy = np.diff(data[:,1::2], axis=1)
-        return np.sum(np.sqrt(dx*dx + dy*dy), axis=1)
-
-    def normalize(self, w):
-        # center coordinates
-        wc = np.mean(w[:,4:6], axis=0, keepdims=True)
-        w = w - np.tile(wc, (1, 5))
-        # rotate
-        v = np.mean(w[:,8:10] - w[:,0:2], axis=0)
-        vnorm = np.sqrt(np.dot(v, v))
-        if vnorm == 0:
-            logging.warning('null distance between head and tail')
-        else:
-            v = v / vnorm
-        c, s = v / self.average_body_length # scale using the rotation matrix
-        rot = np.array([[ c, s],
-                        [-s, c]]) # clockwise rotation
-        w = np.einsum("ij,jkl", rot, np.reshape(w.T, (2, 5, -1), order='F'))
-        return w
-
-    """
-    Preprocess a single track.
-
-    This includes running a sliding window, resampling the track in each window,
-    normalizing the spines, etc.
-    """
-    def preprocess(self, t, data):
-        defined_t = []
-        ws = []
-        for t_, w in self.window(t, data):
-            defined_t.append(t_)
-            ws.append(self.normalize(w))
-        if ws:
-            ret = self.pad(t, defined_t, np.stack(ws))
-            if self.swap_head_tail:
-                ret = ret[:,:,::-1,:]
-            return ret
-    ###
+        return self.preprocessor.body_length(data)
 
     def forward(self, x, train=False):
         if train:
@@ -233,7 +154,7 @@ class MaggotTrainer:
         model.to(self.device)
         if subset is None:
             # data is a (times, spines) couple
-            data = self.preprocess(*data)
+            data = self.preprocessor(*data)
             if data is None:
                 return
             output = self.forward(data)
@@ -323,7 +244,7 @@ class MultiscaleMaggotTrainer(MaggotTrainer):
             average_body_length=1.0, device=device):
         self.model = MultiscaleSupervisedMaggot(cfgfilepath, behaviors,
                                                 n_layers, n_iterations)
-        self.average_body_length = average_body_length # usually set later
+        self.preprocessor = Preprocessor(self, average_body_length)
         self.device = device
         self._default_encoder_config = None
         # check consistency
@@ -349,7 +270,7 @@ class MaggotBagging(MaggotTrainer):
     def __init__(self, cfgfilepaths, behaviors=[], n_layers=1,
             average_body_length=1.0, device=device):
         self.model = MaggotBag(cfgfilepaths, behaviors, n_layers)
-        self.average_body_length = average_body_length # usually set later
+        self.preprocessor = Preprocessor(self, average_body_length)
         self.device = device