Skip to content
Snippets Groups Projects
Select Git revision
  • master
1 result

README.md

Blame
  • Script3_functions.py 24.04 KiB
    # -*- coding: utf-8 -*-
    """
    Created on Mon May 29 09:50:47 2017
    
    @author: apolline
    """
    import pandas as pd
    import numpy as np
    import re
    import os
    import os.path
    
    import matplotlib
    matplotlib.use('Agg')
    import matplotlib.pyplot as plt
    import matplotlib.patches as mpatches
    import matplotlib.image as mpimg
    
    
    #################################################################################################################################
    # Function to cut last two characters of a string
    # INPUT: a string  e.g. 'rs1234_A'
    # OUTPUT: a string e.g. 'rs1234'
    ################################################################################################################################# 
    def cut_last(snp):
        """Return ``snp`` without its last two characters.

        Used to strip a two-character allele suffix from a SNP identifier,
        e.g. ``'rs1234_A'`` -> ``'rs1234'``. Strings of two characters or
        fewer give an empty string.
        """
        return snp[:-2]
    
    #################################################################################################################################
    # Function to transform a list into a string
    # INPUT: a list of strings e.g. ['a', 'b', 'c']
    # OUTPUT: a string         e.g. 'a, b, c'
    #################################################################################################################################
    def list_to_string(l):
        """Join the items of a sequence into one comma-separated string.

        INPUT: a list, tuple or numpy array, e.g. ['a', 'b', 'c']
        OUTPUT: the items joined with ', ' (e.g. 'a, b, c'); an empty
                sequence gives ''. Any other input (for instance the NaN that
                a left merge leaves in unmatched rows) gives None, as before.
        """
        # isinstance (rather than an exact type comparison) so that the list
        # input announced in the header comment is actually handled.
        if not isinstance(l, (list, tuple, np.ndarray)):
            return None
        return ', '.join(str(item) for item in l)
    
    #################################################################################################################################
    # Add genome region corresponding to chromosome and position (1703 independent LD blocks)
    # INPUT: a dataframe (df) and a file with regions
    # OUTPUT: None
    #################################################################################################################################  
    def fill_regions(df, regions_file):
        """Add a 'Region' column to ``df`` in place.

        INPUT:
          df           -- data frame with 'CHR' and 'POS' columns
          regions_file -- tab-separated file whose header names are read as
                          'chr ', ' start ' and ' stop' (spaces included)
        OUTPUT: None. ``df['Region']`` holds the 1-based row number of the
                region containing each SNP (bounds inclusive), or 0 when no
                region matches. If regions overlap, the later row wins.
        """
        regions = pd.read_csv(regions_file, sep='\t')

        chrom = np.asarray(df['CHR'])
        pos = np.asarray(df['POS'])
        region_ids = np.zeros(len(df), dtype=int)

        # Iterate over every row of the file instead of a hard-coded 1703.
        rows = zip(regions['chr '], regions[' start '], regions[' stop'])
        for number, (chr_label, left, right) in enumerate(rows, start=1):
            found = re.search(r'\d+', str(chr_label))
            if found is None:
                # No numeric chromosome in the label: nothing can match it.
                continue
            # Scalar chromosome number; comparing against a one-element list
            # is rejected by pandas when lengths differ.
            chrnum = int(found.group())
            inside = (chrom == chrnum) & (pos >= left) & (pos <= right)
            region_ids[inside] = number

        # Assign a plain array so the result does not depend on df's index.
        df['Region'] = region_ids
    
    #################################################################################################################################
    # Compile results in one file with one line per SNP
    # INPUT: Folder with results per blocks, folder for summarized results, number of blocks, list of phenotypes, list of results, bim file, regions file, file for summarized results
    # OUTPUT: None
    #################################################################################################################################
    def compile_results(blocks_directory, summary_directory, n_blocks, pheno_list, res_list, bim_file, regions_file, results_file):
        """Compile the per-block result files into one table with one line per SNP.

        INPUT:
          blocks_directory  -- folder containing results_1.txt ... results_<n_blocks>.txt
          summary_directory -- folder used for the temporary concatenated file
          n_blocks          -- number of block files
          pheno_list        -- list of phenotype names
          res_list          -- list of result suffixes; columns are pheno + res
          bim_file          -- tab-separated bim file (columns 0, 1, 3 = CHR, SNP, POS)
          regions_file      -- regions file passed to fill_regions
          results_file      -- path of the tab-separated output file
        OUTPUT: None (writes ``results_file``).
        """
        summary_file = summary_directory + '/results_all_SNPs.txt'

        try:
            # Concatenate in Python rather than through a shell `cat ... >>`:
            # the file is truncated first (no stale content from an earlier
            # run is appended) and paths containing spaces work.
            with open(summary_file, 'wb') as merged:
                for i in range(1, int(n_blocks) + 1):
                    block_file = blocks_directory + '/results_' + str(i) + '.txt'
                    if not os.path.isfile(block_file):
                        # The shell loop also carried on past a missing block.
                        print('Missing block file, skipped: ' + block_file)
                        continue
                    with open(block_file, 'rb') as block:
                        merged.writelines(block)

            results = pd.read_csv(summary_file, sep='\t', header=None)
        finally:
            # The concatenated file is only an intermediate; remove it even
            # when reading fails.
            if os.path.exists(summary_file):
                os.remove(summary_file)

        results.columns = ['SNP'] + [pheno + res for pheno in pheno_list for res in res_list]

        # Replace rsxx_A/T/C/G with rsxx
        results['SNP'] = results['SNP'].apply(cut_last)

        # Merge with bim file to get chromosome and position
        bim = pd.read_csv(bim_file, sep='\t', header=None)[[0, 1, 3]]
        bim.columns = ['CHR', 'SNP', 'POS']
        data = pd.merge(results, bim, on='SNP', how='left')

        # Compute regions from chromosome and position
        fill_regions(data, regions_file)

        # Export
        data.to_csv(results_file, sep='\t', index=False)
    
    #################################################################################################################################
    # Build a data frame with significant associations (one line per Phenotype - SNP association)
    # INPUT: data frame, list of phenotypes, list of results, significance threshold
    # OUTPUT: data frame with one line per significant Phenotype - SNP association
    #################################################################################################################################
    def significant_snps(df, pheno_list, res_list, th):
        """Build a table with one line per significant Phenotype - SNP association.

        INPUT:
          df         -- data frame with 'Region', 'CHR', 'SNP' and, for each
                        phenotype, the columns pheno + res for res in res_list
                        (among them pheno + '_pvalMA' and pheno + '_pvalMC')
          pheno_list -- list of phenotype names
          res_list   -- list of six result suffixes, in the order
                        Standard_beta, Standard_pval, CMS_beta, CMS_pval,
                        CMS_rsquared, CMS_ncovs
          th         -- significance threshold
        OUTPUT: data frame, indexed 0..k-1, with a line for every SNP whose
                standard or CMS p-value is below ``th``. Rows are grouped by
                phenotype and keep the order they have in ``df``.
        """
        col = ['Phenotype', 'Region', 'CHR', 'SNP', 'Standard_beta', 'Standard_pval', 'CMS_beta', 'CMS_pval', 'CMS_rsquared', 'CMS_ncovs']
        pieces = []

        for pheno in pheno_list:
            print(pheno)

            # Union of the two tests, selected with one boolean mask.
            hit = (df[pheno + '_pvalMC'] < th) | (df[pheno + '_pvalMA'] < th)
            selected = df.loc[hit, ['Region', 'CHR', 'SNP'] + [pheno + res for res in res_list]].copy()
            selected.columns = col[1:]
            selected.insert(0, 'Phenotype', pheno)

            if len(selected) > 0:
                pieces.append(selected)

        if not pieces:
            return pd.DataFrame(columns=col)

        # Concatenating per-phenotype pieces avoids both the fixed-size
        # preallocation and the chained assignment into it.
        return pd.concat(pieces, ignore_index=True)
    
    #################################################################################################################################
    # Get genes from UCSC data base
    # INPUT: data frame, folder for summarized results
    # OUTPUT: data frame with genes in a new column
    #
    # Tables :
    #  - refFlat: A gene prediction with additional geneName field
    #     *geneName = Gene name
    #     *txStart = Transcription start position
    #     *txEnd = Transcription end position
    #
    #  - snp150: Polymorphism data from dbSNP
    #     *name = SNP id
    #     *chromStart =  	Start position in chromosome
    #     *chromEnd = End position in chromosome
    #################################################################################################################################
    def add_genes_ucsc(results_snps, summary_directory):
        """Annotate each SNP with its closest gene(s) from the UCSC hg19 database.

        INPUT:
          results_snps      -- data frame with a 'SNP' column
          summary_directory -- folder used for the temporary query output
        OUTPUT: ``results_snps`` left-merged with a new 'Genes' column holding
                the closest gene name(s) as a ', '-separated string. Only genes
                within 100000 positions of the SNP are returned by the query;
                SNPs without any such gene get no value.

        Requires the ``mysql`` command-line client and network access to
        genome-euro-mysql.soe.ucsc.edu.
        """
        import subprocess  # local import: only this function runs an external program

        tmp_file = summary_directory + '/tmp_ucsc'

        list_snps = np.unique(results_snps['SNP'])  # list of significant snps
        if len(list_snps) == 0:
            # Nothing to look up; an empty `in ()` clause would be invalid SQL.
            results_snps = results_snps.copy()
            results_snps['Genes'] = None
            return results_snps

        # NOTE(review): SNP ids are pasted into the SQL text unescaped; an id
        # containing a double quote would break the query.
        query = ('select R.geneName, R.txStart, R.txEnd, S.name, S.chromStart, S.chromEnd '
                 'from snp150 as S left join refFlat as R on '
                 '(S.chrom=R.chrom and not(R.txEnd+100000<S.chromStart or S.chromEnd+100000<R.txStart)) '
                 'where S.name in (' + ','.join('"' + str(snp) + '"' for snp in list_snps) + ')')

        try:
            # Argument list instead of a shell string: no quoting problems, and
            # the output file is handled by Python rather than by `>`.
            with open(tmp_file, 'w') as out:
                subprocess.call(['mysql', '--user=genome', '--host=genome-euro-mysql.soe.ucsc.edu',
                                 '-A', '-D', 'hg19', '-e', query], stdout=out)
            ucsc = pd.read_csv(tmp_file, sep='\t')
        finally:
            if os.path.exists(tmp_file):
                os.remove(tmp_file)

        ucsc = ucsc.dropna().drop_duplicates()

        # Distance between SNP and gene: positive gap on whichever side the
        # gene lies, 0 when they overlap. The SNP end is chromEnd (the original
        # code read chromStart twice).
        gene_before_snp = ucsc['chromStart'] - ucsc['txEnd']
        gene_after_snp = ucsc['txStart'] - ucsc['chromEnd']
        ucsc['dist'] = np.maximum(np.maximum(gene_before_snp, gene_after_snp), 0)

        # Keep, for every SNP, the gene(s) at minimal distance.
        closest = ucsc[ucsc['dist'] == ucsc.groupby('name')['dist'].transform('min')]

        recap = closest.groupby('name')['geneName'].apply(list).apply(np.unique)
        # reset_index gives 'SNP' as a plain column only, so the merge key is
        # not ambiguous between an index level and a column.
        recap = recap.rename('Genes').rename_axis('SNP').reset_index()

        results_snps = pd.merge(results_snps, recap, on='SNP', how='left')
        results_snps['Genes'] = results_snps['Genes'].apply(list_to_string)

        return results_snps
    
    #################################################################################################################################
    # Build a data frame with best snps per region (one line per Phenotype - Region association)
    # INPUT: data frame, list of phenotypes
    # OUTPUT: data frame with one line per Phenotype - Region association
    #################################################################################################################################
    def best_snp_per_region(results_snps, pheno_list):
        """Keep the best SNP of every region, for every phenotype.

        INPUT:
          results_snps -- data frame with 'Phenotype', 'Region',
                          'Standard_pval' and 'CMS_pval' columns
          pheno_list   -- list of phenotype names
        OUTPUT: data frame (same columns, index 0..k-1) with one line per
                Phenotype - Region pair: the first SNP of that region whose
                CMS p-value equals the smallest p-value of the region over
                both tests, or else the first SNP whose standard p-value does.
        """
        # Positional index so that the selected labels are unique.
        base = results_snps.reset_index(drop=True)
        picked = []

        for pheno in pheno_list:
            print(pheno)

            temp = base[base['Phenotype'] == pheno]

            # groupby yields regions in sorted order, one pass per phenotype.
            for reg, group in temp.groupby('Region'):
                minima = [p for p in (group['Standard_pval'].min(), group['CMS_pval'].min()) if pd.notna(p)]
                if not minima:
                    # No usable p-value in this region.
                    continue
                best_pval = min(minima)

                # Search inside the region only: an equal p-value in another
                # region must not be picked.
                hits = group[group['CMS_pval'] == best_pval]
                if len(hits) == 0:
                    hits = group[group['Standard_pval'] == best_pval]
                picked.append(hits.index[0])

        return base.loc[picked].reset_index(drop=True)
     
    #################################################################################################################################
    # Count significant regions per phenotype
    # INPUT: data frame, list of phenotypes, significance threshold
    # OUTPUT: data frame with number of significant regions per phenotype and per test (STD or CMS)
    #################################################################################################################################   
    def count_significant_regions(results_regions, pheno_list, th):
        """Count significant lines per phenotype and per test.

        INPUT:
          results_regions -- data frame with 'Phenotype', 'Standard_pval' and
                             'CMS_pval' columns (one line per region is
                             expected, e.g. the output of best_snp_per_region)
          pheno_list      -- list of phenotype names
          th              -- significance threshold
        OUTPUT: data frame with one line per phenotype and the columns
                'Phenotype', 'Standard_only', 'CMS_only', 'Standard_and_CMS',
                'Total' (lines with a p-value strictly below ``th``).
        """
        columns = ['Phenotype', 'Standard_only', 'CMS_only', 'Standard_and_CMS', 'Total']
        records = []

        for pheno in pheno_list:
            print(pheno)

            temp = results_regions[results_regions['Phenotype'] == pheno]

            std_hit = np.asarray(temp['Standard_pval'] < th, dtype=bool)
            cms_hit = np.asarray(temp['CMS_pval'] < th, dtype=bool)
            n_both = int((std_hit & cms_hit).sum())

            records.append([
                pheno,
                int(std_hit.sum()) - n_both,
                int(cms_hit.sum()) - n_both,
                n_both,
                int((std_hit | cms_hit).sum()),
            ])

        # Built in one go: chained assignment into a preallocated frame is not
        # guaranteed to write through.
        return pd.DataFrame(records, columns=columns)
        
    #################################################################################################################################
    # Manhattan plot
    # INPUT: 
    # OUTPUT: 
    #################################################################################################################################
    def manhattan_plot(df, pheno, pval, th, colors, directory):
        """Draw a Manhattan plot for one phenotype and one test.

        INPUT:
          df        -- data frame with 'CHR', 'POS' and the p-value column
          pheno     -- phenotype name (used in the output file name)
          pval      -- name of the p-value column; a name ending in 'A' is
                       treated as the standard test, anything else as CMS
          th        -- significance threshold (drawn as a horizontal red line)
          colors    -- list of colours, one per chromosome group; reused
                       cyclically when there are more groups than colours
          directory -- output folder
        OUTPUT: None. Saves Manhattan_plot_<pheno>_Standard.png or
                Manhattan_plot_<pheno>_CMS.png in ``directory``.

        Side effect: a column '-log10(<pval>)' is added to ``df`` (kept as in
        the original implementation).
        """
        threshold = -np.log10(th)

        log_col = '-log10(' + pval + ')'
        df[log_col] = -np.log10(df[pval])

        fig = plt.figure(figsize=(16, 9))
        ax = fig.add_subplot(111)
        x_labels = []
        x_labels_pos = []

        # Running offset: each chromosome is drawn to the right of the
        # previous one.
        m = 0

        for i, (name, group) in enumerate(df.groupby('CHR')):
            group = pd.DataFrame(group)
            group['absolute_position'] = group['POS'] + m
            m = np.max(group['absolute_position'])
            # Modulo so that extra chromosome groups do not raise IndexError.
            group.plot(kind='scatter', x='absolute_position', y=log_col, color=colors[i % len(colors)], ax=ax)

            first = group['absolute_position'].iloc[0]
            last = group['absolute_position'].iloc[-1]
            x_labels.append(name)
            # Tick halfway between the first and last row of the group.
            x_labels_pos.append(last - (last - first) / 2)

        ax.set_xticks(x_labels_pos)
        ax.set_xticklabels(x_labels)
        ax.set_xlim([0, m])
        ax.set_xlabel('Chromosome')
        ax.set_ylabel('-log10(P)')
        plt.axhline(y=threshold, color='r', linestyle='-')

        if pval.endswith('A'):
            plt.title('Standard test', size=16)
            fig.savefig(directory + '/Manhattan_plot_' + pheno + '_Standard.png', dpi=120, bbox_inches='tight')
        else:
            plt.title('CMS test', size=16)
            fig.savefig(directory + '/Manhattan_plot_' + pheno + '_CMS.png', dpi=120, bbox_inches='tight')

        plt.close(fig)
    
    
    #################################################################################################################################
    # QQplot
    # INPUT:
    # OUTPUT:
    #################################################################################################################################
    def qqplot(df, pheno, directory):
        """Draw the QQ plot of both tests for one phenotype.

        INPUT:
          df        -- data frame with pheno + '_pvalMA' (standard test) and
                       pheno + '_pvalMC' (CMS test) columns
          pheno     -- phenotype name
          directory -- output folder
        OUTPUT: None. Saves QQplot_<pheno>.png in ``directory``.
        """
        total = len(df)

        # Observed -log10(p), sorted in increasing order.
        observed_std = np.sort(-np.log10(df[pheno + '_pvalMA']))
        observed_cms = np.sort(-np.log10(df[pheno + '_pvalMC']))

        # Expected -log10(p) for quantiles k / (n + 2), k = 1..n; reversed so
        # that it is increasing like the observed values.
        expected = -np.log10(np.arange(1, total + 1) / (total + 2))
        expected_increasing = expected[::-1]

        fig = plt.figure(figsize=(9, 9))
        for observed, colour in ((observed_std, 'blue'), (observed_cms, 'red')):
            plt.scatter(expected_increasing, observed, s=5, c=colour)

        # Diagonal y = x
        diagonal = np.linspace(0, 6)
        plt.plot(diagonal, diagonal, c='black')

        # Ratio of the observed median to the expected median
        expected_median = np.median(expected)
        lambda_value_std = np.median(observed_std) / expected_median
        lambda_value_cms = np.median(observed_cms) / expected_median

        plt.text(x=0, y=8, s='Standard test: lambda = ' + str(lambda_value_std), color='blue')
        plt.text(x=0, y=7, s='CMS test: lambda = ' + str(lambda_value_cms), color='red')
        plt.title('QQplot', size=16)

        fig.savefig(directory + '/QQplot_' + pheno + '.png', dpi=120, bbox_inches='tight')
        plt.close(fig)
    
    
    #################################################################################################################################
    # Quadrant plot
    # INPUT
    # OUTPUT
    #################################################################################################################################        
    def quadrant_plot(df, pheno, th, directory):
        """Draw the CMS vs standard quadrant plot for one phenotype.

        Each point is one region, placed at -log10 of its smallest standard
        p-value (x) and of its smallest CMS p-value (y), coloured by which
        test(s) fall strictly below the threshold.

        INPUT:
          df        -- data frame with 'Region', pheno + '_pvalMA' and
                       pheno + '_pvalMC' columns
          pheno     -- phenotype name
          th        -- significance threshold
          directory -- output folder
        OUTPUT: None. Saves Quadrant_plot_<pheno>.png in ``directory``.

        Regions with an id below 1 are skipped, as before; there is no longer
        an upper limit of 1703 on the region id.
        """
        fig = plt.figure(figsize=(9, 9))

        std = df[[pheno + '_pvalMA', 'Region']].groupby('Region').min()[pheno + '_pvalMA']
        cms = df[[pheno + '_pvalMC', 'Region']].groupby('Region').min()[pheno + '_pvalMC']

        keep = np.asarray(std.index >= 1)
        std_p = np.asarray(std, dtype=float)[keep]
        cms_p = np.asarray(cms, dtype=float)[keep]

        x = -np.log10(std_p)
        y = -np.log10(cms_p)

        # Same classification as the former if/elif chain; anything matching
        # none of the three (including p == th or NaN) counts as "other".
        cms_only = (cms_p < th) & (std_p > th)
        both = (cms_p < th) & (std_p < th)
        std_only = (cms_p > th) & (std_p < th)
        others = ~(cms_only | both | std_only)

        # One scatter call per category instead of one per region; grey first
        # so that significant points stay visible.
        for mask, colour in ((others, 'grey'), (std_only, 'salmon'), (both, 'turquoise'), (cms_only, 'lightgreen')):
            if mask.any():
                plt.scatter(x[mask], y[mask], c=colour)

        n_stand = int(std_only.sum())
        n_cms = int(cms_only.sum())
        n_both = int(both.sum())
        n_others = int(others.sum())

        plt.axvline(-np.log10(th), color='grey', linestyle='--')
        plt.axhline(-np.log10(th), color='grey', linestyle='--')

        plt.xlabel('-log10(P) for standard test', fontsize=18)
        plt.ylabel('-log10(P) for CMS test', fontsize=16)

        red_patch = mpatches.Patch(color='salmon', label='Significant pvalues for standard test only (' + str(n_stand) + ')')
        blue_patch = mpatches.Patch(color='turquoise', label='Significant pvalues for CMS and standard tests (' + str(n_both) + ')')
        green_patch = mpatches.Patch(color='lightgreen', label='Significant pvalues for CMS test only (' + str(n_cms) + ')')
        black_patch = mpatches.Patch(color='grey', label='Non significant pvalues (' + str(n_others) + ')')

        plt.legend(handles=[red_patch, blue_patch, green_patch, black_patch])
        plt.title('CMS vs Standard test', size=16)

        fig.savefig(directory + '/Quadrant_plot_' + pheno + '.png', dpi=120, bbox_inches='tight')
        plt.close(fig)
        
    
    #################################################################################################################################
    # 4 plots on same file
    # INPUT
    # OUTPUT
    #################################################################################################################################
    def result_plot(df, pheno_list, th, directory):
        """Save, for each phenotype, one image combining the four plots.

        INPUT:
          df         -- data frame passed on to manhattan_plot, qqplot and
                        quadrant_plot
          pheno_list -- list of phenotype names
          th         -- significance threshold
          directory  -- output folder
        OUTPUT: None. Saves Plots_<pheno>.png in ``directory`` for every
                phenotype, then deletes every file of ``directory`` whose name
                starts with 'Manhattan_plot_', 'QQplot_' or 'Quadrant_plot_'.
        """
        # One colour per chromosome group of the Manhattan plots
        colors = [
            [0.23625862, 0.52860261, 0.78364119], [0.03608782, 0.55719131, 0.1664441],
            [0.94307407, 0.35026836, 0.68064485], [0.25533632, 0.18023988, 0.42811556],
            [0.83003783, 0.81582078, 0.37646543], [0.65316214, 0.13796559, 0.45594953],
            [0.47797385, 0.18014566, 0.09708977], [0.44313678, 0.59353625, 0.50071228],
            [0.5985085, 0.76463359, 0.75645524], [0.67242609, 0.13103319, 0.17105948],
            [0.05590727, 0.56748858, 0.47916381], [0.80333682, 0.72828047, 0.39257576],
            [0.46142155, 0.65768434, 0.98160204], [0.85915425, 0.80420234, 0.47793356],
            [0.16912189, 0.17209861, 0.4464207], [0.6452403, 0.39523205, 0.59043336],
            [0.56400114, 0.46104143, 0.71735811], [0.86413373, 0.26718486, 0.04390139],
            [0.76195355, 0.17223056, 0.38066256], [0.4721071, 0.18925624, 0.25739407],
            [0.62012875, 0.97591735, 0.21921685], [0.64509652, 0.54463331, 0.23358183],
        ]

        for pheno in pheno_list:
            print(pheno)
            manhattan_plot(df, pheno, pheno + '_pvalMA', th, colors, directory)
            manhattan_plot(df, pheno, pheno + '_pvalMC', th, colors, directory)
            qqplot(df, pheno, directory)
            quadrant_plot(df, pheno, th, directory)

            # Panels in subplot order: standard Manhattan, QQ plot,
            # CMS Manhattan, quadrant plot.
            panels = [
                directory + '/Manhattan_plot_' + pheno + '_Standard.png',
                directory + '/QQplot_' + pheno + '.png',
                directory + '/Manhattan_plot_' + pheno + '_CMS.png',
                directory + '/Quadrant_plot_' + pheno + '.png',
            ]

            fig = plt.figure(figsize=(32, 18))
            for position, image_file in enumerate(panels, start=1):
                plt.subplot(2, 2, position)
                plt.imshow(mpimg.imread(image_file))
                plt.axis('off')

            plt.suptitle(pheno, size=32)
            fig.savefig(directory + '/Plots_' + pheno + '.png', dpi=240, bbox_inches='tight')
            plt.close(fig)

        # Remove the intermediate images without going through a shell, so a
        # directory name containing spaces or shell characters is handled.
        prefixes = ('Manhattan_plot_', 'QQplot_', 'Quadrant_plot_')
        for file_name in os.listdir(directory):
            file_path = os.path.join(directory, file_name)
            if file_name.startswith(prefixes) and os.path.isfile(file_path):
                os.remove(file_path)
    
    
    #################################################################################################################################
    # Global QQplot
    # INPUT
    # OUTPUT
    #################################################################################################################################
    def global_qqplot(df, pheno_list, directory):
        """Draw the QQ plot of both tests with all phenotypes pooled.

        INPUT:
          df         -- data frame with pheno + '_pvalMA' and pheno + '_pvalMC'
                        columns for every phenotype
          pheno_list -- list of phenotype names
          directory  -- output folder
        OUTPUT: None. Saves Global_QQplot.png and Global_QQplot_zoom.png
                (the latter without the top 0.01% of the points) in
                ``directory``.
        """
        fig = plt.figure(figsize=(9, 9))

        n_points = len(df) * len(pheno_list)
        # Expected -log10(p) for quantiles k / (n + 2), in increasing order
        expected = -np.log10(np.arange(1, n_points + 1) / (n_points + 2))[::-1]

        pooled_std = []
        pooled_cms = []
        for pheno in pheno_list:
            print(pheno)
            pooled_std.extend(-np.log10(list(df[pheno + '_pvalMA'])))
            pooled_cms.extend(-np.log10(list(df[pheno + '_pvalMC'])))

        observed_std = np.sort(pooled_std)
        observed_cms = np.sort(pooled_cms)

        for observed, colour in ((observed_std, 'blue'), (observed_cms, 'red')):
            plt.scatter(expected, observed, s=5, c=colour)

        # Diagonal y = x
        diagonal = np.linspace(0, 8)
        plt.plot(diagonal, diagonal, c='black')

        # Ratio of the observed median to the expected median
        expected_median = np.median(expected)
        lambda_value_std = np.median(observed_std) / expected_median
        lambda_value_cms = np.median(observed_cms) / expected_median

        plt.text(x=0, y=6, s='Standard test: lambda = ' + str(lambda_value_std), color='blue')
        plt.text(x=0, y=10, s='CMS test: lambda = ' + str(lambda_value_cms), color='red')
        plt.title('QQplot', size=16)

        fig.savefig(directory + '/Global_QQplot.png', dpi=120, bbox_inches='tight')
        plt.close(fig)

        # Zoomed version: only the first 99.99% of the sorted points
        cutoff = int(0.9999 * n_points)
        fig = plt.figure(figsize=(9, 9))
        for observed, colour in ((observed_std, 'blue'), (observed_cms, 'red')):
            plt.scatter(expected[:cutoff], observed[:cutoff], s=5, c=colour)

        diagonal = np.linspace(0, 2)
        plt.plot(diagonal, diagonal, c='black')
        fig.savefig(directory + '/Global_QQplot_zoom.png', dpi=120, bbox_inches='tight')
        plt.close(fig)
    
           
    #################################################################################################################################
    # Global quadrant plot
    # INPUT
    # OUTPUT
    #################################################################################################################################        
    def global_quadrant_plot(df, pheno_list, th, directory):
        """Draw the CMS vs standard quadrant plot with all phenotypes pooled.

        Each point is one Phenotype - Region pair, placed at -log10 of the
        region's smallest standard p-value (x) and smallest CMS p-value (y),
        coloured by which test(s) fall strictly below the threshold.

        INPUT:
          df         -- data frame with 'Region' and, for every phenotype,
                        pheno + '_pvalMA' and pheno + '_pvalMC' columns
          pheno_list -- list of phenotype names
          th         -- significance threshold
          directory  -- output folder
        OUTPUT: None. Saves Plots_Global_Quadrant_plot.png in ``directory``.

        Regions with an id below 1 are skipped, as before; there is no longer
        an upper limit of 1703 on the region id.
        """
        fig = plt.figure(figsize=(9, 9))

        n_stand = 0
        n_cms = 0
        n_both = 0
        n_others = 0

        for pheno in pheno_list:
            print(pheno)
            std = df[[pheno + '_pvalMA', 'Region']].groupby('Region').min()[pheno + '_pvalMA']
            cms = df[[pheno + '_pvalMC', 'Region']].groupby('Region').min()[pheno + '_pvalMC']

            keep = np.asarray(std.index >= 1)
            std_p = np.asarray(std, dtype=float)[keep]
            cms_p = np.asarray(cms, dtype=float)[keep]

            x = -np.log10(std_p)
            y = -np.log10(cms_p)

            # Same classification as the former if/elif chain; anything
            # matching none of the three (including p == th or NaN) counts as
            # "other".
            cms_only = (cms_p < th) & (std_p > th)
            both = (cms_p < th) & (std_p < th)
            std_only = (cms_p > th) & (std_p < th)
            others = ~(cms_only | both | std_only)

            # One scatter call per category instead of one per region.
            for mask, colour in ((others, 'grey'), (std_only, 'salmon'), (both, 'turquoise'), (cms_only, 'lightgreen')):
                if mask.any():
                    plt.scatter(x[mask], y[mask], c=colour)

            n_stand += int(std_only.sum())
            n_cms += int(cms_only.sum())
            n_both += int(both.sum())
            n_others += int(others.sum())

        plt.axvline(-np.log10(th), color='grey', linestyle='--')
        plt.axhline(-np.log10(th), color='grey', linestyle='--')

        plt.xlabel('-log10(P) for standard test', fontsize=18)
        plt.ylabel('-log10(P) for CMS test', fontsize=16)

        red_patch = mpatches.Patch(color='salmon', label='Significant pvalues for standard test only (' + str(n_stand) + ')')
        blue_patch = mpatches.Patch(color='turquoise', label='Significant pvalues for CMS and standard tests (' + str(n_both) + ')')
        green_patch = mpatches.Patch(color='lightgreen', label='Significant pvalues for CMS test only (' + str(n_cms) + ')')
        black_patch = mpatches.Patch(color='grey', label='Non significant pvalues (' + str(n_others) + ')')

        plt.legend(handles=[red_patch, blue_patch, green_patch, black_patch])
        plt.title('CMS vs Standard test', size=16)

        fig.savefig(directory + '/Plots_Global_Quadrant_plot.png', dpi=120, bbox_inches='tight')
        plt.close(fig)