diff --git a/examples/trajs-files.ipynb b/examples/trajs-files.ipynb index 6dd1eca584a752aa8956bfa27d96a38f05ac6881..d73c29bec04bc92243244955d01a4b8150e7d5e5 100644 --- a/examples/trajs-files.ipynb +++ b/examples/trajs-files.ipynb @@ -1983,6 +1983,15 @@ "plt.legend()" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "mmd.find_closest_trajectories()" + ] + }, { "cell_type": "markdown", "metadata": {}, diff --git a/src/palm_tools/analysis/mmd_analysis.py b/src/palm_tools/analysis/mmd_analysis.py index 6019d5fe62627e355a1dd181a93f22ccf6945228..d39121f087d126e468434139f6349a6057e5a38f 100644 --- a/src/palm_tools/analysis/mmd_analysis.py +++ b/src/palm_tools/analysis/mmd_analysis.py @@ -548,6 +548,7 @@ class GroupBasedAnalysisStep(LatentVecsBasedAnalysisStep): ) df = df.loc[df["distance"] < max_dist] df.sort_values("distance", inplace=True) + df = df.groupby("traj_ID").first().reset_index() trajs = [] for i, row in df.iterrows():