diff --git a/Figures_manuscript/scripts/Fig_SUP_gain_vs_Ntrait.R b/Figures_manuscript/scripts/Fig_SUP_gain_vs_Ntrait.R new file mode 100644 index 0000000000000000000000000000000000000000..3f643ce9f220d5577de2212ed41ffb5acbcc0638 --- /dev/null +++ b/Figures_manuscript/scripts/Fig_SUP_gain_vs_Ntrait.R @@ -0,0 +1,80 @@ +library(data.table) +library(ggplot2) +library(psych) +library(cowplot) +library(stringr) +library(latex2exp) +library(scales) + +set_features_20k = fread("../inputs/JASS_5CVdata-2023-08-01/traitset_jass_5CVcombined_without_duplicates.tsv") + +set_features_large_N = fread("../inputs/JASS_largeN/stat_set_with_JASS_power_large_N_sets.tsv") + +features = intersect(names(set_features_20k), names(set_features_large_N)) + +set_features= set_features_large_N#rbind(set_features_20k[, features, with=FALSE], set_features_large_N[, features, with=FALSE]) + + +set_features[, percent_increase := 100*((Both+Univariate+Joint) / (Univariate+Both))] +set_features[, .(100*((Both+Univariate+Joint) / (Univariate+Both)), (Both+Univariate+Joint),(Univariate+Both+Joint), Joint, Univariate, Both)] + + +head(set_features) + +compute_overlap <- function(x){ + (length(intersect(x, trait_list))) +} + +setkey(set_features, "Id") +set_features[, overlap := -1] + +for(key in set_features$Id){ + + xs = set_features[key,trait] + ntrait = set_features[key,k] + trait_list = strsplit(xs, " ")[[1]] + + set_features[key, overlap := mean(apply(set_features[(k==ntrait) & (Id!=key), strsplit(trait, " ")], 2, compute_overlap))/k] + +} + + + + +summary(set_features$percent_increase) +summary(set_features[, Joint]) +summary(set_features[, Both+Univariate]) + +######################################################################### +######################################################################### +######################################################################### +set_features$condition_number_rcov +set_features[,log10_avg_distance_cor := log10(avg_distance_cor)] +set_features[,log10_condition_number_gcov := log10(condition_number_gcov)] +set_features[,log10_condition_number_rcov := log10(condition_number_rcov)] +set_features[,log10_mean_gencov := log10(mean_gencov)] +set_features[,log10_mean_null_phencov := log10(mean_null_phencov)] + +feats_name = c( + "f[JASS<univ]"= "fraction_more_significant_joint_qval", + "# of multi-trait loci" = "Joint", + "# of univariate loci" = "Univariate", + "% increase in significant loci" = "percent_increase") + +feats_name2 = names(feats_name) +names(feats_name2) = feats_name + + +long_set = melt(set_features[,.(k,overlap, Univariate, Joint, fraction_more_significant_joint_qval, percent_increase)], id.var=c("k", "overlap")) + + +long_set$variable = feats_name2[as.character(long_set$variable)] + +unique(long_set$variable) + +long_set[,`#trait`:= k] + +png("../outputs/Figure_SUP_gain_Ntraits.png", width=8.5, height=8.5, unit="in", res=300) +p = ggplot(long_set, aes(x=`#trait`, y=value, color=overlap)) + geom_point(alpha=0.65) + facet_wrap(.~variable, scales = "free")+ theme_minimal() +ylab('') +print(p) +dev.off() \ No newline at end of file diff --git a/Figures_manuscript/scripts/Figure_SUP_point3_reviewer2.py b/Figures_manuscript/scripts/Figure_SUP_point3_reviewer2.py new file mode 100644 index 0000000000000000000000000000000000000000..b8dcb37e8586f1518cc38968fa5e8c826d945fca --- /dev/null +++ b/Figures_manuscript/scripts/Figure_SUP_point3_reviewer2.py @@ -0,0 +1,148 @@ +import numpy as np +import matplotlib +matplotlib.use('AGG') +import matplotlib.pyplot as plt +from scipy.stats import ttest_ind +from itertools import combinations +import matplotlib.transforms as mtransforms +import scipy +from itertools import combinations +import pandas as pd + + +trait_features = pd.read_csv("../inputs/72trait_data_2023-07-07.csv") +trait_features.set_index("ID", inplace=True) +median_h2 = trait_features.h2_LD.median() + +def define_h2_group(x): + bool_h2 = trait_features.loc[x.split(" "), "h2_LD"] < median_h2 + if bool_h2.any() == False: + return "low h²" + if bool_h2.all(): + return "high h²" + else: + return 'heterogeneous h²' + +df_train = pd.read_csv("../inputs/JASS_5CVdata-2023-08-01/traitset_jass_5CVcombined_without_duplicates.tsv",delimiter='\t') +df_train["heritability"] = df_train.trait.apply(define_h2_group) +df_train = df_train.loc[~df_train.Joint.isna()] + +df_val = pd.read_csv("../inputs/JASS_5CVdata-2023-08-01/traitset_jass_CVtest2-newSUMMARY_remove-nan.tsv",delimiter='\t') +output = '../outputs/Figure_SUP_point3_reviewer2.png' +df_val["heritability"] = df_val.trait.apply(define_h2_group) +df_val = df_val.loc[~df_val.Joint.isna()] + +df_valr = pd.read_csv('../inputs/clinical_grouping_analysis_2023-09-06/category_traitset_with_mean_test_jass.tsv',delimiter='\t') + + +df_train = df_train.loc[(df_train.CV==2)& (df_train.train),] + +low_h2_train = df_train.loc[(df_train.heritability=="low h²") ] +high_h2_train = df_train.loc[(df_train.heritability=="high h²") ] +heterogeneous_h2_train = df_train.loc[(df_train.heritability=="heterogeneous h²") ] +random_baseline_train = df_train.iloc[np.random.choice(range(df_train.shape[0]), 100)] + +low_h2_val = df_val.loc[(df_val.heritability=="low h²") ] +high_h2_val = df_val.loc[(df_val.heritability=="high h²") ] +heterogeneous_h2_val = df_val.loc[(df_val.heritability=="heterogeneous h²") ] +random_baseline_val = df_val.iloc[np.random.choice(range(df_val.shape[0]), 100)] + +datadriven = df_valr.loc[df_valr.rank_datadriven < df_valr.rank_datadriven.nsmallest(101).iloc[-1]] +datadriven.reset_index(inplace=True,drop=True) + +datadriven["fraction_more_significant_joint_qval"] = datadriven.obs_gain +datadriven["Joint"] = datadriven.obs_joint + +print(len(datadriven)) + +def ttest(pair, metric): + return ttest_ind(dict_data_train[pair[0]][metric].values,dict_data_val[pair[1]][metric].values,axis=0,equal_var=False,alternative='two-sided') + +dict_data_train= { "low_h2": low_h2_train, + "high_h2":high_h2_train, + "heterogeneous_h2":heterogeneous_h2_train, + "random":random_baseline_train, +} + +dict_data_val= { "low_h2": low_h2_val, + "high_h2":high_h2_val, + "heterogeneous_h2":heterogeneous_h2_val, + "random":random_baseline_val, + "datadriven":datadriven +} + +position_dictionary = {"low_h2":1, "high_h2":2, "heterogeneous_h2":3, "random":4, "datadriven":5} +combi= list(combinations(dict_data_val.keys(), 2)) +# keep only tests that are significant after multi test correction +gain_ttest = {t:ttest(t,"fraction_more_significant_joint_qval") for t in combi} +gain_ttest = {k: v for k, v in gain_ttest.items() if (v[1]*10) < 0.05 } + +joint_ttest = {t:ttest(t,"Joint") for t in combi} +joint_ttest = {k: v for k, v in joint_ttest.items() if (v[1]*10) < 0.05 } + + +target = 'fraction_more_significant_joint_qval' +plt.rcParams.update({'font.size':6.5}) + +fig = plt.figure(figsize=[8.5,5]) +axs = fig.subplot_mosaic([['(A)','(B)']]) + +ax=axs['(A)'] +ax.boxplot([low_h2_val['fraction_more_significant_joint_qval'].values,high_h2_val['fraction_more_significant_joint_qval'],heterogeneous_h2_val['fraction_more_significant_joint_qval'],datadriven['fraction_more_significant_joint_qval'], random_baseline_val['fraction_more_significant_joint_qval']],notch=True, showfliers=False) +ax.set_xticks([1,2,3,4,5]) +ax.set_xticklabels(['Low h²\n ({} sets)'.format(len(low_h2_val)), + 'High h²\n ({} sets)'.format(len(high_h2_val)), + 'Heterogeneous h²\n({} sets)'.format(len(heterogeneous_h2_val)), + 'Data-driven\n (top {} sets)'.format(len(datadriven)), + 'Random\n({} sets)'.format(len(random_baseline_val)) + ],rotation=40) + +ax.set_ylim([0,2.1]) +ax.set_ylabel('observed gain') + +baseline = 1.09 + +for key_i in gain_ttest.keys(): + print(key_i) + + result =gain_ttest[key_i] + xstart = position_dictionary[key_i[0]] + xend = position_dictionary[key_i[1]] + + ax.plot([xstart, xstart, xend, xend],[baseline, baseline+0.01, baseline+0.01, baseline], lw=1, c='k') + sig_symbol = 'diff={}, p={}'.format(round(dict_data_val[key_i[1]]['fraction_more_significant_joint_qval'].mean()-dict_data_train[key_i[0]]['fraction_more_significant_joint_qval'].mean(),2),'%.1E' % result.pvalue) + ax.text((xstart+xend)/2, baseline+0.011, sig_symbol, ha='center', va='bottom', c='k',fontsize=6) + baseline = baseline + 0.1 + +# Joint +ax=axs['(B)'] +ax.boxplot([low_h2_val['Joint'].values,high_h2_val['Joint'].values,heterogeneous_h2_val['Joint'].values,datadriven['Joint'].values, random_baseline_val['Joint'].values],notch=True, showfliers=False) +ax.set_xticks([1,2,3,4,5]) +ax.set_xticklabels(['Low h²\n ({} sets)'.format(len(low_h2_val)), + 'High h²\n ({} sets)'.format(len(high_h2_val)), + 'Heterogeneous h²\n({} sets)'.format(len(heterogeneous_h2_val)), + 'Data-driven\n (top {} sets)'.format(len(datadriven)), + 'Random\n({} sets)'.format(len(random_baseline_val)) + ],rotation=40) +ax.set_ylim([0,405]) +ax.set_ylabel('observed #new associations') + +baseline = 205 + +for key_i in joint_ttest.keys(): + print(key_i) + + result =joint_ttest[key_i] + xstart = position_dictionary[key_i[0]] + xend = position_dictionary[key_i[1]] + + ax.plot([xstart, xstart, xend, xend],[baseline, baseline+5, baseline+5, baseline], lw=1, c='k') + sig_symbol = 'diff={}, p={}'.format(round(dict_data_val[key_i[1]]['Joint'].mean()-dict_data_train[key_i[0]]['Joint'].mean(),2),'%.1E' % result.pvalue) + ax.text((xstart+xend)/2, baseline+6, sig_symbol, ha='center', va='bottom', c='k',fontsize=6) + baseline = baseline + 20 + + +plt.subplots_adjust(bottom=0.2,wspace=0.2,hspace=0.6,left=0.1) +plt.savefig(output,dpi=300) + +