From 77bb56a06a4f03b00fefaa35a09d494a53b14998 Mon Sep 17 00:00:00 2001
From: hjulienn <hanna.julienne@pasteur.fr>
Date: Thu, 7 Mar 2024 11:20:17 +0100
Subject: [PATCH] figure 6 modification: add random baseline

---
 Figures_manuscript/outputs/.placeholder       |   0
 ...ait_selection_method_without_duplicates.py | 170 +++++++++++-------
 2 files changed, 107 insertions(+), 63 deletions(-)
 create mode 100644 Figures_manuscript/outputs/.placeholder

diff --git a/Figures_manuscript/outputs/.placeholder b/Figures_manuscript/outputs/.placeholder
new file mode 100644
index 0000000..e69de29
diff --git a/Figures_manuscript/scripts/Figure_6_compare_trait_selection_method_without_duplicates.py b/Figures_manuscript/scripts/Figure_6_compare_trait_selection_method_without_duplicates.py
index 4248f22..440b8e1 100644
--- a/Figures_manuscript/scripts/Figure_6_compare_trait_selection_method_without_duplicates.py
+++ b/Figures_manuscript/scripts/Figure_6_compare_trait_selection_method_without_duplicates.py
@@ -1,4 +1,3 @@
-import pandas as pd
 import numpy as np
 import matplotlib
 matplotlib.use('AGG')
@@ -7,114 +6,158 @@ from scipy.stats import ttest_ind
 from itertools import combinations
 import matplotlib.transforms as mtransforms
 import scipy
+from itertools import combinations
+import pandas as pd
 
 df_train = pd.read_csv('../inputs/clinical_grouping_analysis_2023-09-06/category_traitset_with_mean_train_jass.tsv',delimiter='\t')
 df_val = pd.read_csv('../inputs/clinical_grouping_analysis_2023-09-06/category_traitset_with_mean_test_jass.tsv',delimiter='\t')
-output = '../outputs/Figure_6_compare_trait_selection_method_jass_without_duplicates.png' 
+output = '../outputs/Figure_6_compare_trait_selection_method_jass_without_duplicates.pdf' 
+
+
 similar_train = df_train.loc[df_train.n_group==1]
 dif_train = df_train.loc[(df_train.n_group>1)&(df_train.n_group<=4)]
 verydif_train = df_train.loc[df_train.n_group>4]
+random_baseline_train = df_train.iloc[np.random.choice(range(df_train.shape[0]), 100)]
+
 similar_val = df_val.loc[df_val.n_group==1]
 dif_val = df_val.loc[(df_val.n_group>1)&(df_val.n_group<=4)]
 verydif_val = df_val.loc[df_val.n_group>4]
-#datadriven = df_val.loc[df_val.rank_datadriven < 20]
+random_baseline_val = df_val.iloc[np.random.choice(range(df_val.shape[0]), 100)]
+
 datadriven = df_val.loc[df_val.rank_datadriven < df_val.rank_datadriven.nsmallest(101).iloc[-1]]
 
-# I have edited the previous process where I dropped duplicates already, though only those that were duplicated within each of training and validation data.
-#trait_sets = []
-#for t in datadriven.trait_set:
-#	traits = str(sorted(t.split(' ')))
-#	if traits in trait_sets:
-#		print(t)
-#	else:
-#		trait_sets.append(traits)
-#print(len(trait_sets))
-#
-#print(datadriven.loc[(datadriven.trait_set=='z_MEGASTROKE_AIS z_MEGASTROKE_AS')|(datadriven.trait_set=='z_MEGASTROKE_AS z_MEGASTROKE_AIS')])
-#datadriven = datadriven.loc[~datadriven.trait_set.isin(['z_MEGASTROKE_AIS z_MEGASTROKE_AS','z_MEGASTROKE_AS z_MEGASTROKE_AIS'])]
-#datadriven.reset_index(inplace=True,drop=True)
-#print(datadriven)
-#datadriven.loc[len(datadriven),:] = df_val.loc[df_val.trait_set.isin(['z_MEGASTROKE_AIS z_MEGASTROKE_AS','z_MEGASTROKE_AS z_MEGASTROKE_AIS'])].values[0]
 datadriven.reset_index(inplace=True,drop=True)
+
 print(len(datadriven))
 	
+def ttest(pair, metric):
+        return ttest_ind(dict_data_train[pair[0]][metric].values,dict_data_val[pair[1]][metric].values,axis=0,equal_var=False,alternative='two-sided')
+
+dict_data_train= { "similar": similar_train,
+ "dif":dif_train,
+ "verydif":verydif_train,
+ "random":random_baseline_train,
+}
+
+dict_data_val= { "similar": similar_val,
+ "dif":dif_val,
+ "verydif":verydif_val,
+ "random":random_baseline_val,
+ "datadriven":datadriven
+}
+
+
+combi= list(combinations(dict_data_val.keys(), 2))
+# keep only tests that are significant after multi test correction
+gain_ttest = {t:ttest(t,"obs_gain") for t in combi}
+gain_ttest = {k: v for k, v in gain_ttest.items() if (v[1]*10) < 0.05 }
+
+joint_ttest = {t:ttest(t,"obs_joint") for t in combi}
+joint_ttest = {k: v for k, v in joint_ttest.items() if (v[1]*10) < 0.05 }
+
 
 target = 'fraction_more_significant_joint_qval'
-plt.rcParams.update({'font.size':18})
-fig = plt.figure(figsize=[30,25])
+plt.rcParams.update({'font.size':6.5})
+
+fig = plt.figure(figsize=[8.5,7])
 axs = fig.subplot_mosaic([['(A)','(B)'],
                            ['(C)','(D)']])
 
 ax=axs['(A)']
-ax.boxplot([similar_val['obs_gain'].values,dif_val['obs_gain'],verydif_val['obs_gain'],datadriven['obs_gain']],notch=True)
-ax.set_xticks([1,2,3,4])
-ax.set_xticklabels(['Homogenous\n (1 group; {} sets)'.format(len(similar_val)),'Low heterogeneity\n (2-4 groups; {} sets)'.format(len(dif_val)),'High heterogeneity\n(>4 groups; {} sets)'.format(len(verydif_val)),'data-driven\n (top {} sets)'.format(len(datadriven))],rotation=30)
+ax.boxplot([similar_val['obs_gain'].values,dif_val['obs_gain'],verydif_val['obs_gain'],datadriven['obs_gain'], random_baseline_val['obs_gain']],notch=True, showfliers=False)
+ax.set_xticks([1,2,3,4,5])
+ax.set_xticklabels(['Homogenous\n (1 group; {} sets)'.format(len(similar_val)),
+                    'Low heterogeneity\n (2-4 groups; {} sets)'.format(len(dif_val)),
+                    'High heterogeneity\n(>4 groups; {} sets)'.format(len(verydif_val)),
+                    'Data-driven\n (top {} sets)'.format(len(datadriven)),
+                    'Random\n({} sets)'.format(len(random_baseline_val))
+                    ],rotation=40)
 ax.set_ylim([0,1.5])
 ax.set_ylabel('observed gain')
-result = ttest_ind(dif_val['obs_gain'].values,similar_train['obs_gain'].values,axis=0,equal_var=False,alternative='two-sided')
-ax.plot([1, 1, 2-0.01, 2-0.01],[1.09, 1.1, 1.1, 1.09], lw=1, c='k')
-sig_symbol = 'diff={}, p={}'.format(round(dif_val['obs_gain'].mean()-similar_train['obs_gain'].mean(),2),round(result.pvalue,3))
-ax.text(1.5, 1.11, sig_symbol, ha='center', va='bottom', c='k',fontsize=12)
-result = ttest_ind(verydif_val['obs_gain'].values,dif_train['obs_gain'].values,axis=0,equal_var=False,alternative='two-sided')
+
+result =gain_ttest[('dif', 'verydif')]
 ax.plot([2+0.01, 2+0.01, 3-0.01, 3-0.01],[1.09, 1.1, 1.1, 1.09], lw=1, c='k')
 sig_symbol = 'diff={}, p={}'.format(round(verydif_val['obs_gain'].mean()-dif_train['obs_gain'].mean(),2),'%.1E' % result.pvalue)
-ax.text(2.5, 1.11, sig_symbol, ha='center', va='bottom', c='k',fontsize=12)
-result = ttest_ind(datadriven['obs_gain'].values,verydif_train['obs_gain'].values,axis=0,equal_var=False,alternative='two-sided')
+ax.text(2, 1.11, sig_symbol, ha='center', va='bottom', c='k',fontsize=6)
+
+result = gain_ttest[('verydif', "datadriven")]
 ax.plot([3+0.01, 3+0.01, 4-0.01, 4-0.01],[1.09, 1.1, 1.1, 1.09], lw=1, c='k')
 sig_symbol = 'diff={}, p={}'.format(round(datadriven['obs_gain'].mean()-verydif_train['obs_gain'].mean(),2),'%.1E' % result.pvalue)
-ax.text(3.5, 1.11, sig_symbol, ha='center', va='bottom', c='k',fontsize=12)
-result = ttest_ind(datadriven['obs_gain'].values,dif_train['obs_gain'].values,axis=0,equal_var=False,alternative='two-sided')
+ax.text(3.8, 1.11, sig_symbol, ha='center', va='bottom', c='k',fontsize=6)
+
+result = gain_ttest[('dif', "datadriven")]
 ax.plot([2, 2, 4, 4],[1.19, 1.2, 1.2, 1.19], lw=1, c='k')
 sig_symbol = 'diff={}, p={}'.format(round(datadriven['obs_gain'].mean()-dif_train['obs_gain'].mean(),2),'%.1E' % result.pvalue)
-ax.text(3, 1.21, sig_symbol, ha='center', va='bottom', c='k',fontsize=12)
-result = ttest_ind(verydif_val['obs_gain'].values,similar_train['obs_gain'].values,axis=0,equal_var=False,alternative='two-sided')
-ax.plot([1, 1, 3, 3],[1.29, 1.3, 1.3, 1.29], lw=1, c='k')
-sig_symbol = 'diff={}, p={}'.format(round(verydif_val['obs_gain'].mean()-similar_train['obs_gain'].mean(),2),round(result.pvalue,3))
-ax.text(2, 1.31, sig_symbol, ha='center', va='bottom', c='k',fontsize=12)
-result = ttest_ind(datadriven['obs_gain'].values,similar_train['obs_gain'].values,axis=0,equal_var=False,alternative='two-sided')
-ax.plot([1, 1, 4, 4],[1.39, 1.4, 1.4, 1.39], lw=1, c='k')
+ax.text(3, 1.21, sig_symbol, ha='center', va='bottom', c='k',fontsize=6)
+
+result = gain_ttest[('similar', "datadriven")]
+ax.plot([1, 1, 4, 4],[1.29, 1.3, 1.3, 1.29], lw=1, c='k')
 sig_symbol = 'diff={}, p={}'.format(round(datadriven['obs_gain'].mean()-similar_train['obs_gain'].mean(),2),'%.1E' % result.pvalue)
-ax.text(2.5, 1.41, sig_symbol, ha='center', va='bottom', c='k',fontsize=12)
+ax.text(2.5, 1.31, sig_symbol, ha='center', va='bottom', c='k',fontsize=6)
+
+result = gain_ttest[('random', "datadriven")]
+ax.plot([4,4, 5, 5],[1.39, 1.4, 1.4, 1.39], lw=1, c='k')
+sig_symbol = 'diff={}, p={}'.format(round(datadriven['obs_gain'].mean()-similar_train['obs_gain'].mean(),2),'%.1E' % result.pvalue)
+ax.text(4.5, 1.41, sig_symbol, ha='center', va='bottom', c='k',fontsize=6)
+
 
 # Joint
 ax=axs['(B)']
-ax.boxplot([similar_val['obs_joint'].values,dif_val['obs_joint'].values,verydif_val['obs_joint'].values,datadriven['obs_joint'].values],notch=True)
-ax.set_xticks([1,2,3,4])
-ax.set_xticklabels(['Homogenous\n (1 group; {} sets)'.format(len(similar_val)),'Low heterogeneity\n (2-4 groups; {} sets)'.format(len(dif_val)),'High heterogeneity\n (>4 groups; {} sets)'.format(len(verydif_val)),'data-driven\n (top {} sets)'.format(len(datadriven))],rotation=30)
-ax.set_ylim([0,290])
+ax.boxplot([similar_val['obs_joint'].values,dif_val['obs_joint'].values,verydif_val['obs_joint'].values,datadriven['obs_joint'].values, random_baseline_val['obs_joint'].values],notch=True, showfliers=False)
+ax.set_xticks([1,2,3,4,5])
+ax.set_xticklabels(['Homogenous\n (1 group; {} sets)'.format(len(similar_val)),
+                'Low heterogeneity\n (2-4 groups; {} sets)'.format(len(dif_val)),
+                'High heterogeneity\n (>4 groups; {} sets)'.format(len(verydif_val)),
+                'data-driven\n (top {} sets)'.format(len(datadriven)),
+                'Random\n({} sets)'.format(len(random_baseline_val))],rotation=40)
+ax.set_ylim([0,331])
 ax.set_ylabel('observed #new associations')
-result = ttest_ind(dif_val['obs_joint'].values,similar_train['obs_joint'].values,axis=0,equal_var=False,alternative='two-sided')
-ax.plot([1, 1, 2-0.01, 2-0.01],[205, 210, 210, 205], lw=1, c='k')
-sig_symbol = 'diff={}, p={}'.format(round(dif_val['obs_joint'].mean()-similar_train['obs_joint'].mean(),2),round(result.pvalue,3))
-ax.text(1.5, 211, sig_symbol, ha='center', va='bottom', c='k',fontsize=12)
+
+
 result = ttest_ind(verydif_val['obs_joint'].values,dif_train['obs_joint'].values,axis=0,equal_var=False,alternative='two-sided')
-print(verydif_val['obs_joint'].mean()-dif_train['obs_joint'].mean(),result.pvalue)
-ax.plot([2+0.01, 2+0.01, 3-0.01, 3-0.01],[205, 210, 210, 205], lw=1, c='k')
+ax.plot([2, 2, 2.9, 2.9],[205, 210, 210, 205], lw=1, c='k')
 sig_symbol = 'diff={}, p={}'.format(round(verydif_val['obs_joint'].mean()-dif_train['obs_joint'].mean(),2),'%.1E' % result.pvalue)
-ax.text(2.5, 211, sig_symbol, ha='center', va='bottom', c='k',fontsize=12)
+ax.text(2, 211, sig_symbol, ha='center', va='bottom', c='k',fontsize=6)
+
 result = ttest_ind(datadriven['obs_joint'].values,verydif_train['obs_joint'].values,axis=0,equal_var=False,alternative='two-sided')
-ax.plot([3+0.01, 3+0.01, 4-0.01, 4-0.01],[205, 210, 210, 205], lw=1, c='k')
+ax.plot([3, 3, 3.9, 3.9],[205, 210, 210, 205], lw=1, c='k')
 sig_symbol = 'diff={}, p={}'.format(round(datadriven['obs_joint'].mean()-verydif_train['obs_joint'].mean(),2),'%.1E' % result.pvalue)
-ax.text(3.5, 211, sig_symbol, ha='center', va='bottom', c='k',fontsize=12)
-result = ttest_ind(datadriven['obs_joint'].values,dif_train['obs_joint'].values,axis=0,equal_var=False,alternative='two-sided')
-ax.plot([2, 2, 4, 4],[225, 230, 230, 225], lw=1, c='k')
+ax.text(4, 211, sig_symbol, ha='center', va='bottom', c='k',fontsize=6)
+
+
+result = ttest_ind(datadriven['obs_joint'].values, dif_train['obs_joint'].values,axis=0,equal_var=False,alternative='two-sided')
+ax.plot([2+0.01, 2+0.01, 4-0.01, 4-0.01],[225, 230, 230, 225], lw=1, c='k')
 sig_symbol = 'diff={}, p={}'.format(round(datadriven['obs_joint'].mean()-dif_train['obs_joint'].mean(),2),'%.1E' % result.pvalue)
-ax.text(3, 231, sig_symbol, ha='center', va='bottom', c='k',fontsize=12)
+ax.text(3, 231, sig_symbol, ha='center', va='bottom', c='k',fontsize=6)
+
 result = ttest_ind(verydif_val['obs_joint'].values,similar_train['obs_joint'].values,axis=0,equal_var=False,alternative='two-sided')
-ax.plot([1, 1, 3, 3],[245, 250, 250, 245], lw=1, c='k')
-sig_symbol = 'diff={}, p={}'.format(round(verydif_val['obs_joint'].mean()-similar_train['obs_joint'].mean(),2),'%.1E' % result.pvalue)
-ax.text(2, 251, sig_symbol, ha='center', va='bottom', c='k',fontsize=12)
+ax.plot([1, 1, 3-0.01, 3-0.01],[245, 250, 250, 245], lw=1, c='k')
+sig_symbol = 'diff={}, p={}'.format(round(verydif_val['obs_joint'].mean()-similar_train['obs_joint'].mean(),2),round(result.pvalue,3))
+ax.text(2, 251, sig_symbol, ha='center', va='bottom', c='k',fontsize=6)
+
 result = ttest_ind(datadriven['obs_joint'].values,similar_train['obs_joint'].values,axis=0,equal_var=False,alternative='two-sided')
-ax.plot([1, 1, 4, 4],[265, 270, 270, 265], lw=1, c='k')
+ax.plot([1+0.01, 1+0.01, 4-0.01, 4-0.01],[265, 270, 270, 265], lw=1, c='k')
 sig_symbol = 'diff={}, p={}'.format(round(datadriven['obs_joint'].mean()-similar_train['obs_joint'].mean(),2),'%.1E' % result.pvalue)
-ax.text(2.5, 271, sig_symbol, ha='center', va='bottom', c='k',fontsize=12)
+ax.text(2.5, 271, sig_symbol, ha='center', va='bottom', c='k',fontsize=6)
+
+result = ttest_ind(random_baseline_val['obs_joint'].values,similar_train['obs_joint'].values,axis=0,equal_var=False,alternative='two-sided')
+ax.plot([1+0.01, 1+0.01, 5-0.01, 5-0.01],[285, 290, 290, 285], lw=1, c='k')
+sig_symbol = 'diff={}, p={}'.format(round(random_baseline_val['obs_joint'].mean()-similar_train['obs_joint'].mean(),2),'%.1E' % result.pvalue)
+ax.text(3, 291, sig_symbol, ha='center', va='bottom', c='k',fontsize=6)
+
+
+result = ttest_ind(datadriven['obs_joint'].values,random_baseline_train['obs_joint'].values,axis=0,equal_var=False,alternative='two-sided')
+ax.plot([4, 4, 5, 5],[305, 310, 310, 305], lw=1, c='k')
+sig_symbol = 'diff={}, p={}'.format(round(datadriven['obs_joint'].mean()-random_baseline_train['obs_joint'].mean(),2),'%.1E' % result.pvalue)
+ax.text(4.2, 311, sig_symbol, ha='center', va='bottom', c='k',fontsize=6)
+
 
 clinical_name = {'[1]':'Neoplasm','[2]':'Cardiovascular Diseases','[3]':'Musculoskeletal and Neural\n Physiological Phenomena','[4]':'Psychological Phenomena','[5]':'Physiological Phenomena','[6]':'Nutritional and Metabolic Diseases','[7]':'Circulatory and Respiratory\n Physiological Phenomena','[8]':'Nervous System Diseases','[9]':'Mental disorders','[10]':'Immune System Diseases','[11]':'Reproductive and Urinary\n Physiological Phenomena','[12]':'Eye Diseases','[13]':'Population Characteristics','[14]':'Enzymes and Coenzymes'}
 
 ax=axs['(C)']
 names = sorted(set(similar_val.group_names))
 print(names)
-ax.boxplot([similar_val.loc[(similar_val.n_group==1)&(similar_val.group_names==name)]['obs_gain'].values for name in names],vert=False)
+ax.boxplot([similar_val.loc[(similar_val.n_group==1)&(similar_val.group_names==name)]['obs_gain'].values for name in names],vert=False, showfliers=False)
 ax.set_xlim([-0.01,1.01])
 ax.set_yticks(range(1,len(names)+1))
 ax.set_yticklabels([clinical_name[name]+' ({} sets)'.format(len(similar_val.loc[(similar_val.n_group==1)&(similar_val.group_names==name)])) for name in names])
@@ -139,6 +182,7 @@ print('heterogenous:',dif_val['obs_joint'].mean())
 print('very heterogenous:',verydif_val['obs_joint'].mean())
 print('similar:',similar_val['obs_joint'].mean())
 
-plt.subplots_adjust(bottom=0.2,wspace=0.2,hspace=0.4,left=0.3)
+plt.subplots_adjust(bottom=0.05,wspace=0.2,hspace=0.6,left=0.3)
 plt.savefig(output,dpi=300)
 
+
-- 
GitLab