From 13fb18832352f70a70367415c2a8c51418e4b06b Mon Sep 17 00:00:00 2001 From: Hippolyte Verdier <hverdier@pasteur.fr> Date: Fri, 28 Oct 2022 15:34:22 +0200 Subject: [PATCH] small things --- examples/trajs-files.ipynb | 9 +++++++++ src/palm_tools/analysis/mmd_analysis.py | 1 + 2 files changed, 10 insertions(+) diff --git a/examples/trajs-files.ipynb b/examples/trajs-files.ipynb index 6dd1eca..d73c29b 100644 --- a/examples/trajs-files.ipynb +++ b/examples/trajs-files.ipynb @@ -1983,6 +1983,15 @@ "plt.legend()" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "mmd.find_closest_trajectories()" + ] + }, { "cell_type": "markdown", "metadata": {}, diff --git a/src/palm_tools/analysis/mmd_analysis.py b/src/palm_tools/analysis/mmd_analysis.py index 6019d5f..d39121f 100644 --- a/src/palm_tools/analysis/mmd_analysis.py +++ b/src/palm_tools/analysis/mmd_analysis.py @@ -548,6 +548,7 @@ class GroupBasedAnalysisStep(LatentVecsBasedAnalysisStep): ) df = df.loc[df["distance"] < max_dist] df.sort_values("distance", inplace=True) + df = df.groupby("traj_ID").first().reset_index() trajs = [] for i, row in df.iterrows(): -- GitLab