diff --git a/jass/models/plots.py b/jass/models/plots.py
index 12a4eb8cec40d45d99ab6aef47cef0ce8cf30837..8f80635d5d8a17b249574b44ae9ac90a735f8073 100644
--- a/jass/models/plots.py
+++ b/jass/models/plots.py
@@ -16,10 +16,11 @@ import matplotlib.pyplot as plt
 from matplotlib import colors
 import matplotlib.patches as mpatches
 from scipy.stats import norm, chi2
-
+import seaborn as sns
 import os
 from pandas import DataFrame, read_hdf
-from scipy.stats import chi2
+import pandas as pd
+
 
 def replaceZeroes(df):
     """
@@ -233,7 +234,7 @@ def create_quadrant_plot(work_file_path: str,
     alim = np.ceil(pv_t.JASS_PVAL.loc[pv_t.color == "#f77189"].max() + 2)
     if np.isnan(alim):
         alim = 10
-    plt.axis([0, alim, 0, alim])
+    plt.plot([0, alim, 0, alim])
     # légendes abcisse et ordonnee
     plt.xlabel('-log10(P) for univariate tests', fontsize=12)
     # plt.show()
@@ -251,60 +252,50 @@ def create_quadrant_plot(work_file_path: str,
 
 
 def create_qq_plot(work_file_path: str, qq_plot_path: str):
-    df = read_hdf(work_file_path, "SumStatTab", columns=['JASS_PVAL', 'UNIVARIATE_MIN_PVAL'])
+    df = read_hdf(work_file_path, "SumStatTab", columns=['JASS_PVAL', 'UNIVARIATE_MIN_PVAL', 'UNIVARIATE_MIN_QVAL'])
+
+    df[['JASS_PVAL', 'UNIVARIATE_MIN_PVAL', "UNIVARIATE_MIN_QVAL"]] = replaceZeroes(
+        df[['JASS_PVAL', 'UNIVARIATE_MIN_PVAL','UNIVARIATE_MIN_QVAL']])
 
-    df[['JASS_PVAL', 'UNIVARIATE_MIN_PVAL']] = replaceZeroes(
-        df[['JASS_PVAL', 'UNIVARIATE_MIN_PVAL']])
-    dr = read_hdf(work_file_path, "Regions")
-    #count the number of traits
-    Ntrait = len([i for i in dr.columns if i[:2]=='z_'])
     pvalue = -np.log10(df.JASS_PVAL)
-<<<<<<< jass/models/plots.py
-    # Cast values between 0 and 1, 0 and 1 excluded
-    x = -np.log10(np.arange(1, pvalue.shape[0] + 1) / (pvalue.shape[0] + 2))
-    y = pvalue.sort_values()
-    plt.scatter(x[::-1], y, s=5)
-    pval_median = np.nanmedian(df.JASS_PVAL)
-    lambda_value_omni = chi2.isf(pval_median, df=Ntrait) / chi2.isf(0.5, df=Ntrait)
-    lambda_value_sumZ = chi2.isf(pval_median, df=1) / chi2.isf(0.5, df=1)
-    print("Number of trait analyzed = {}".format(Ntrait))
-    print("JASS median p-val = {:.3f}".format(pval_median))
-    print("Inflation Factor if the test is omnibus= {:.3f}".format(lambda_value_omni))
-    print("Inflation Factor if the test is a SumZ= {:.3f}".format(lambda_value_sumZ))
-    x_1 = np.linspace(0, 6)
-    y_1 = (pval_median/0.5) * x_1
-    x_2 = np.linspace(0, 6)
-    plt.plot(x_1, y_1, c="red")
-    plt.plot(x_2, x_2)
-    plt.title("Median p-val = {:.2f}".format(pval_median))
-    plt.xlabel("expected quantile of -log10(P)")
-    plt.ylabel("observed quantile of -log10(P)")
-
-    plt.savefig(qq_plot_path, dpi=600)
-
-=======
     pvalue_univ = -np.log10(df.UNIVARIATE_MIN_PVAL)
+    qvalue_univ = -np.log10(df.UNIVARIATE_MIN_QVAL)
     # compute_expected pvalue
-
-    QQ_pval = pd.DataFrame(index =-np.log10(np.arange(1, pvalue.shape[0] + 1) / (pvalue.shape[0] + 2)),
-                 {"JASS pvalue" : pvalue.sort_values(),
-                  "UNIV pvalue" : pvalue_univ.sort_values()})
-
+    # Each series is sorted independently so every column is an ascending quantile curve.
+
+    QQ_pval = pd.DataFrame({"JASS p-value" : pvalue.sort_values().values,
+                            "Univariate p-value" : pvalue_univ.sort_values().values,
+                            "Univariate q-value" : qvalue_univ.sort_values().values,
+                            })
+    QQ_pval = QQ_pval.iloc[::20,].dropna()
+    exp_val = np.flip(- np.log10(QQ_pval.index.values / QQ_pval.index.max()))
+
+    exp_val[-1] = QQ_pval.max().max()
+    QQ_pval.index = exp_val
+    QQ_pval = QQ_pval.iloc[:-1,]
     pval_median = df.JASS_PVAL.median()
     pval_median_univ = df.UNIVARIATE_MIN_PVAL.median()
+    pval_median_quniv = df.UNIVARIATE_MIN_QVAL.median()
+
     print("median pval")
     print(pval_median)
-    lambda_value_jass = chi.sf(pval_median) / chi.sf(0.5)
-    lambda_value_univ = chi.sf(pval_median_univ) / chi.sf(0.5)
+    # Inflation factor: chi2 (1 df) quantile at the median p-value over the chi2 (1 df) median.
+    lambda_value_jass = chi2.isf(pval_median, df=1) / chi2.isf(0.5, df=1)
+    lambda_value_univ = chi2.isf(pval_median_univ, df=1) / chi2.isf(0.5, df=1)
+    lambda_value_quniv = chi2.isf(pval_median_quniv, df=1) / chi2.isf(0.5, df=1)
     p = sns.lineplot(data=QQ_pval)
-    p.set("QQ plot")
-    p.set_xlabel("Expected p-values", fontsize = 20)
-    p.set_ylabel("Observed p-values", fontsize = 20)
+    alim = exp_val[-2]
+    plt.plot([0, alim],[0, alim], c="red", linewidth=0.5)
+    p.set_title("QQ plot\n λ JASS = {:.2f}\n λ univariate p-values = {:.2f} λ univariate q-values = {:.2f}".format(lambda_value_jass, lambda_value_univ, lambda_value_quniv), fontsize = 11)
+    p.set_xlabel("Expected -log10(p-values)", fontsize = 13)
+    p.set_ylabel("Observed -log10(p-values)", fontsize = 13)
+
     plt.savefig(qq_plot_path)
->>>>>>> jass/models/plots.py
+    plt.clf()
     print("------ QQ plot -----")
 
 
+
 def create_qq_plot_by_GWAS(init_file_path: str, qq_plot_folder: str):
     df = read_hdf(init_file_path, "SumStatTab", where="Region < {0}".format(2))
     uni_var = [i for i in df.columns if i[:2]=="z_"]
diff --git a/requirements.txt b/requirements.txt
index a3c65b209fe6a63dd6cfd9fb893bc3cdc6a95d02..cbdbddcbb29c31ca7932c7c96e6fa0f50da3e0d4 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -8,9 +8,10 @@ pandas
 tables
 scipy
 matplotlib
+seaborn
 celery
 pydantic
 fastapi
 uvicorn[standard]
 typing_extensions; python_version < '3.8'
-requests
\ No newline at end of file
+requests