diff --git a/examples/trajs-files.ipynb b/examples/trajs-files.ipynb
index 6dd1eca584a752aa8956bfa27d96a38f05ac6881..d73c29bec04bc92243244955d01a4b8150e7d5e5 100644
--- a/examples/trajs-files.ipynb
+++ b/examples/trajs-files.ipynb
@@ -1983,6 +1983,15 @@
     "plt.legend()"
    ]
   },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "mmd.find_closest_trajectories()"
+   ]
+  },
   {
    "cell_type": "markdown",
    "metadata": {},
diff --git a/src/palm_tools/analysis/mmd_analysis.py b/src/palm_tools/analysis/mmd_analysis.py
index 6019d5fe62627e355a1dd181a93f22ccf6945228..d39121f087d126e468434139f6349a6057e5a38f 100644
--- a/src/palm_tools/analysis/mmd_analysis.py
+++ b/src/palm_tools/analysis/mmd_analysis.py
@@ -548,6 +548,7 @@ class GroupBasedAnalysisStep(LatentVecsBasedAnalysisStep):
         )
         df = df.loc[df["distance"] < max_dist]
         df.sort_values("distance", inplace=True)
+        df = df.groupby("traj_ID").first().reset_index()
 
         trajs = []
         for i, row in df.iterrows():