diff --git a/Figures_manuscript/scripts/Fig_SUP_gain_vs_Ntrait.R b/Figures_manuscript/scripts/Fig_SUP_gain_vs_Ntrait.R
new file mode 100644
index 0000000000000000000000000000000000000000..3f643ce9f220d5577de2212ed41ffb5acbcc0638
--- /dev/null
+++ b/Figures_manuscript/scripts/Fig_SUP_gain_vs_Ntrait.R
@@ -0,0 +1,80 @@
+library(data.table)
+library(ggplot2)
+library(psych)
+library(cowplot)
+library(stringr)
+library(latex2exp)
+library(scales)
+
# Candidate trait sets from the 20k-sample 5-fold-CV run and from the
# large-N run; only the large-N sets are analysed below (the rbind with
# the 20k sets is kept commented out for reference).
set_features_20k <- fread("../inputs/JASS_5CVdata-2023-08-01/traitset_jass_5CVcombined_without_duplicates.tsv")

set_features_large_N <- fread("../inputs/JASS_largeN/stat_set_with_JASS_power_large_N_sets.tsv")

# Columns common to both tables (only needed if the rbind is re-enabled).
features <- intersect(names(set_features_20k), names(set_features_large_N))

set_features <- set_features_large_N # rbind(set_features_20k[, features, with = FALSE], set_features_large_N[, features, with = FALSE])

# Percent increase in significant loci brought by the joint test: new
# multi-trait-only loci (Joint) relative to the univariate baseline
# (Univariate + Both).
# BUG FIX(review): the previous formula,
#   100 * (Both + Univariate + Joint) / (Univariate + Both),
# equals 100 + the true percent increase (it is always >= 100%), i.e. it
# reported "percent of baseline", not "% increase" as plotted.
set_features[, percent_increase := 100 * Joint / (Univariate + Both)]
# Sanity print of the gain alongside its components.
set_features[, .(percent_increase, Total = Univariate + Both + Joint, Joint, Univariate, Both)]


head(set_features)
+
# Number of traits shared between the set `x` and a reference trait set.
#
# `traits` defaults (lazily, via R's promise semantics) to the global
# `trait_list` that the loop below assigns before each call, so the
# existing one-argument call sites keep working unchanged; passing the
# reference set explicitly removes the hidden-global dependency.
compute_overlap <- function(x, traits = trait_list) {
  length(intersect(x, traits))
}
+
# For every trait set, compute its mean overlap (fraction of shared
# traits) with all OTHER sets of the same size k. Sets are addressed by
# their "Id" key; `overlap` is initialised to -1 so unprocessed rows are
# detectable.
setkey(set_features, "Id")
set_features[, overlap := -1]

for(key in set_features$Id){

    # Space-separated trait list and size of the current set.
    xs = set_features[key,trait]
    ntrait = set_features[key,k]
    # Global used by compute_overlap() above.
    trait_list = strsplit(xs, " ")[[1]]

    # j returns one split trait vector per same-size peer set; apply(..., 2, ...)
    # walks them column-wise (assumes all peers have exactly ntrait traits —
    # TODO confirm), and the mean shared-trait count is normalised by this
    # row's k to yield a fraction in [0, 1].
    set_features[key, overlap := mean(apply(set_features[(k==ntrait) & (Id!=key), strsplit(trait, " ")], 2, compute_overlap))/k]

}
+
+
+
+
# Quick distribution checks on the gain metrics.
summary(set_features$percent_increase)
summary(set_features[, Joint])
summary(set_features[, Both+Univariate])

#########################################################################
# Log10-transform the heavy-tailed covariance features
# (creates one log10_<name> column per source column).
#########################################################################
set_features$condition_number_rcov
for (feat in c("avg_distance_cor", "condition_number_gcov",
               "condition_number_rcov", "mean_gencov",
               "mean_null_phencov")) {
  set_features[, (paste0("log10_", feat)) := log10(get(feat))]
}
+
# Display label -> column name for the figure panels.
feats_name <- c(
  "f[JASS<univ]" = "fraction_more_significant_joint_qval",
  "# of multi-trait loci" = "Joint",
  "# of univariate loci" = "Univariate",
  "% increase in significant loci" = "percent_increase"
)

# Inverse map: column name -> display label.
feats_name2 <- setNames(names(feats_name), feats_name)


# Long format: one row per (trait set, metric), keyed by set size k and
# mean overlap with same-size sets.
long_set <- melt(
  set_features[, .(k, overlap, Univariate, Joint,
                   fraction_more_significant_joint_qval, percent_increase)],
  id.var = c("k", "overlap")
)


# Rename the metrics with their display labels for the facet strips.
long_set$variable <- feats_name2[as.character(long_set$variable)]

unique(long_set$variable)

long_set[, `#trait` := k]

# One panel per metric, coloured by mean trait overlap.
png("../outputs/Figure_SUP_gain_Ntraits.png",
    width = 8.5, height = 8.5, unit = "in", res = 300)
p <- ggplot(long_set, aes(x = `#trait`, y = value, color = overlap)) +
  geom_point(alpha = 0.65) +
  facet_wrap(. ~ variable, scales = "free") +
  theme_minimal() +
  ylab('')
print(p)
dev.off()
\ No newline at end of file
diff --git a/Figures_manuscript/scripts/Figure_SUP_point3_reviewer2.py b/Figures_manuscript/scripts/Figure_SUP_point3_reviewer2.py
new file mode 100644
index 0000000000000000000000000000000000000000..b8dcb37e8586f1518cc38968fa5e8c826d945fca
--- /dev/null
+++ b/Figures_manuscript/scripts/Figure_SUP_point3_reviewer2.py
@@ -0,0 +1,148 @@
+import numpy as np
+import matplotlib
+matplotlib.use('AGG')
+import matplotlib.pyplot as plt
+from scipy.stats import ttest_ind
+from itertools import combinations
+import matplotlib.transforms as mtransforms
+import scipy
+from itertools import combinations
+import pandas as pd
+
+
# Per-trait feature table for the 72 traits, indexed by trait ID so that
# define_h2_group() below can look traits up by name.
trait_features = pd.read_csv("../inputs/72trait_data_2023-07-07.csv")
trait_features.set_index("ID", inplace=True)
# Median LD-score heritability across traits — threshold splitting traits
# into low/high-h2 halves.
median_h2 = trait_features.h2_LD.median()
+
def define_h2_group(x, features=None, median=None):
        """Classify a space-separated trait set by its members' heritability.

        Parameters
        ----------
        x : str
            Space-separated trait IDs (must exist in ``features``' index).
        features : pandas.DataFrame, optional
            Table with an ``h2_LD`` column; defaults to the global
            ``trait_features``, so existing one-argument calls are unchanged.
        median : float, optional
            h2 threshold; defaults to the global ``median_h2``.

        Returns
        -------
        str
            "low h²" if every trait is below the threshold, "high h²" if
            none is, "heterogeneous h²" otherwise.
        """
        if features is None:
                features = trait_features
        if median is None:
                median = median_h2
        below = features.loc[x.split(" "), "h2_LD"] < median
        # BUG FIX(review): the original returned "low h²" when NO trait was
        # below the median and "high h²" when ALL were — the labels were
        # swapped relative to the groups selected downstream
        # (df.heritability == "low h²" etc.).
        if not below.any():
                return "high h²"
        if below.all():
                return "low h²"
        return 'heterogeneous h²'
+
# --- Load train / validation trait-set tables and label every set by the
# --- heritability profile of its traits; drop sets without JASS results.
df_train = pd.read_csv("../inputs/JASS_5CVdata-2023-08-01/traitset_jass_5CVcombined_without_duplicates.tsv", delimiter='\t')
df_train["heritability"] = df_train.trait.apply(define_h2_group)
df_train = df_train.loc[~df_train.Joint.isna()]

df_val = pd.read_csv("../inputs/JASS_5CVdata-2023-08-01/traitset_jass_CVtest2-newSUMMARY_remove-nan.tsv", delimiter='\t')
output = '../outputs/Figure_SUP_point3_reviewer2.png'
df_val["heritability"] = df_val.trait.apply(define_h2_group)
df_val = df_val.loc[~df_val.Joint.isna()]

# Trait sets built from clinical categories, carrying the data-driven rank.
df_valr = pd.read_csv('../inputs/clinical_grouping_analysis_2023-09-06/category_traitset_with_mean_test_jass.tsv', delimiter='\t')


# Keep a single CV fold of the training split.
df_train = df_train.loc[(df_train.CV == 2) & (df_train.train), ]

# FIX(review): fixed seed so the random baselines — and therefore the
# published figure and its t-tests — are reproducible across runs.
np.random.seed(42)

low_h2_train = df_train.loc[df_train.heritability == "low h²"]
high_h2_train = df_train.loc[df_train.heritability == "high h²"]
heterogeneous_h2_train = df_train.loc[df_train.heritability == "heterogeneous h²"]
# NOTE(review): np.random.choice samples WITH replacement here, so the
# 100-set baselines may contain duplicates — confirm this is intended.
random_baseline_train = df_train.iloc[np.random.choice(range(df_train.shape[0]), 100)]

low_h2_val = df_val.loc[df_val.heritability == "low h²"]
high_h2_val = df_val.loc[df_val.heritability == "high h²"]
heterogeneous_h2_val = df_val.loc[df_val.heritability == "heterogeneous h²"]
random_baseline_val = df_val.iloc[np.random.choice(range(df_val.shape[0]), 100)]

# Top ~100 data-driven sets: strictly better rank than the 101st-best.
# FIX(review): .copy() so the column assignments below modify an owned
# frame instead of a view (avoids SettingWithCopyWarning / silent no-op).
datadriven = df_valr.loc[df_valr.rank_datadriven < df_valr.rank_datadriven.nsmallest(101).iloc[-1]].copy()
datadriven.reset_index(inplace=True, drop=True)

# Align column names with the train/validation tables.
datadriven["fraction_more_significant_joint_qval"] = datadriven.obs_gain
datadriven["Joint"] = datadriven.obs_joint

print(len(datadriven))
+	
def ttest(pair, metric):
        """Welch two-sided t-test of `metric` between the TRAIN table of
        group pair[0] and the VALIDATION table of group pair[1].

        NOTE(review): mixing the train table of one group with the
        validation table of another looks deliberate (the plotted `diff`
        labels below use the same convention) — worth double-checking.
        """
        return ttest_ind(dict_data_train[pair[0]][metric].values,dict_data_val[pair[1]][metric].values,axis=0,equal_var=False,alternative='two-sided')

dict_data_train = {
    "low_h2": low_h2_train,
    "high_h2": high_h2_train,
    "heterogeneous_h2": heterogeneous_h2_train,
    "random": random_baseline_train,
}

# "datadriven" exists only in the validation dict; because it is inserted
# LAST, combinations() below only ever places it in pair[1], which is the
# element looked up in dict_data_val — so ttest() never misses a key.
dict_data_val = {
    "low_h2": low_h2_val,
    "high_h2": high_h2_val,
    "heterogeneous_h2": heterogeneous_h2_val,
    "random": random_baseline_val,
    "datadriven": datadriven,
}

# x position of each group in the boxplots below.
# BUG FIX(review): both boxplots draw the groups in the order
# low, high, heterogeneous, DATA-DRIVEN, RANDOM, but the original mapping
# put random at 4 and datadriven at 5, so every significance bracket
# involving those two groups pointed at the wrong boxes.
position_dictionary = {"low_h2": 1, "high_h2": 2, "heterogeneous_h2": 3,
                       "datadriven": 4, "random": 5}

combi = list(combinations(dict_data_val.keys(), 2))
# Keep only tests that stay significant after Bonferroni correction over
# the number of pairs actually tested.
# FIX(review): was a hard-coded factor of 10 (== C(5, 2)); derive it from
# combi so it stays correct if a group is added or removed.
n_tests = len(combi)
gain_ttest = {t: ttest(t, "fraction_more_significant_joint_qval") for t in combi}
gain_ttest = {k: v for k, v in gain_ttest.items() if (v[1] * n_tests) < 0.05}

joint_ttest = {t: ttest(t, "Joint") for t in combi}
joint_ttest = {k: v for k, v in joint_ttest.items() if (v[1] * n_tests) < 0.05}
+
+
# Panel (A): observed gain per group, with pairwise significance brackets.
# NOTE(review): `target` is assigned but never used afterwards.
target = 'fraction_more_significant_joint_qval'
plt.rcParams.update({'font.size':6.5})

fig = plt.figure(figsize=[8.5,5])
axs = fig.subplot_mosaic([['(A)','(B)']])

ax=axs['(A)']
# Box order fixes the x positions: 1=low, 2=high, 3=heterogeneous,
# 4=data-driven, 5=random. NOTE(review): the brackets below place their
# endpoints via position_dictionary — verify it matches this order.
ax.boxplot([low_h2_val['fraction_more_significant_joint_qval'].values,high_h2_val['fraction_more_significant_joint_qval'],heterogeneous_h2_val['fraction_more_significant_joint_qval'],datadriven['fraction_more_significant_joint_qval'], random_baseline_val['fraction_more_significant_joint_qval']],notch=True, showfliers=False)
ax.set_xticks([1,2,3,4,5])
ax.set_xticklabels(['Low h²\n ({} sets)'.format(len(low_h2_val)),
                    'High h²\n ({} sets)'.format(len(high_h2_val)),
                    'Heterogeneous h²\n({} sets)'.format(len(heterogeneous_h2_val)),
                    'Data-driven\n (top {} sets)'.format(len(datadriven)),
                    'Random\n({} sets)'.format(len(random_baseline_val))
                    ],rotation=40)

ax.set_ylim([0,2.1])
ax.set_ylabel('observed gain')

# y position of the first significance bracket; raised after each bracket
# so they stack without overlapping.
baseline = 1.09

# One bracket + "diff, p" annotation per significant pair.
for key_i in gain_ttest.keys():
    print(key_i)

    result =gain_ttest[key_i]
    xstart = position_dictionary[key_i[0]]
    xend = position_dictionary[key_i[1]]

    # Draw the bracket as a 4-point polyline: up, across, down.
    ax.plot([xstart, xstart, xend, xend],[baseline, baseline+0.01, baseline+0.01, baseline], lw=1, c='k')
    # diff = mean(validation of group 2) - mean(TRAIN of group 1),
    # matching the convention used by ttest().
    sig_symbol = 'diff={}, p={}'.format(round(dict_data_val[key_i[1]]['fraction_more_significant_joint_qval'].mean()-dict_data_train[key_i[0]]['fraction_more_significant_joint_qval'].mean(),2),'%.1E' % result.pvalue)
    ax.text((xstart+xend)/2, baseline+0.011, sig_symbol, ha='center', va='bottom', c='k',fontsize=6)
    baseline = baseline + 0.1
+
# Panel (B): observed number of new (joint-only) associations per group,
# same layout and bracket logic as panel (A).
ax=axs['(B)']
# Same box order as panel (A): low, high, heterogeneous, data-driven, random.
ax.boxplot([low_h2_val['Joint'].values,high_h2_val['Joint'].values,heterogeneous_h2_val['Joint'].values,datadriven['Joint'].values, random_baseline_val['Joint'].values],notch=True, showfliers=False)
ax.set_xticks([1,2,3,4,5])
ax.set_xticklabels(['Low h²\n ({} sets)'.format(len(low_h2_val)),
                    'High h²\n ({} sets)'.format(len(high_h2_val)),
                    'Heterogeneous h²\n({} sets)'.format(len(heterogeneous_h2_val)),
                    'Data-driven\n (top {} sets)'.format(len(datadriven)),
                    'Random\n({} sets)'.format(len(random_baseline_val))
                    ],rotation=40)
ax.set_ylim([0,405])
ax.set_ylabel('observed #new associations')

# Starting height for the stacked significance brackets (data units).
baseline = 205

for key_i in joint_ttest.keys():
    print(key_i)

    result =joint_ttest[key_i]
    xstart = position_dictionary[key_i[0]]
    xend = position_dictionary[key_i[1]]

    # Bracket polyline, then the "diff, p" annotation centred above it.
    ax.plot([xstart, xstart, xend, xend],[baseline, baseline+5, baseline+5, baseline], lw=1, c='k')
    # diff = mean(validation of group 2) - mean(TRAIN of group 1).
    sig_symbol = 'diff={}, p={}'.format(round(dict_data_val[key_i[1]]['Joint'].mean()-dict_data_train[key_i[0]]['Joint'].mean(),2),'%.1E' % result.pvalue)
    ax.text((xstart+xend)/2, baseline+6, sig_symbol, ha='center', va='bottom', c='k',fontsize=6)
    baseline = baseline + 20


plt.subplots_adjust(bottom=0.2,wspace=0.2,hspace=0.6,left=0.1)
plt.savefig(output,dpi=300)
+
+