From 13fb18832352f70a70367415c2a8c51418e4b06b Mon Sep 17 00:00:00 2001
From: Hippolyte Verdier <hverdier@pasteur.fr>
Date: Fri, 28 Oct 2022 15:34:22 +0200
Subject: [PATCH] small things

---
 examples/trajs-files.ipynb              | 9 +++++++++
 src/palm_tools/analysis/mmd_analysis.py | 1 +
 2 files changed, 10 insertions(+)

diff --git a/examples/trajs-files.ipynb b/examples/trajs-files.ipynb
index 6dd1eca..d73c29b 100644
--- a/examples/trajs-files.ipynb
+++ b/examples/trajs-files.ipynb
@@ -1983,6 +1983,15 @@
     "plt.legend()"
    ]
   },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "mmd.find_closest_trajectories()"
+   ]
+  },
   {
    "cell_type": "markdown",
    "metadata": {},
diff --git a/src/palm_tools/analysis/mmd_analysis.py b/src/palm_tools/analysis/mmd_analysis.py
index 6019d5f..d39121f 100644
--- a/src/palm_tools/analysis/mmd_analysis.py
+++ b/src/palm_tools/analysis/mmd_analysis.py
@@ -548,6 +548,7 @@ class GroupBasedAnalysisStep(LatentVecsBasedAnalysisStep):
         )
         df = df.loc[df["distance"] < max_dist]
         df.sort_values("distance", inplace=True)
+        df = df.groupby("traj_ID").first().reset_index()
 
         trajs = []
         for i, row in df.iterrows():
-- 
GitLab