explainer.py

    import pandas as pd
    import numpy as np
    
    from manocca import MANOCCA
    from manova import MANOVA
    from univariate import UNIVARIATE
    
    from tools.h_clustering import cluster_corr
    from tools.build_graph import get_tree, build_graph
    
    import plotly.express as px
    from plotly.subplots import make_subplots
    import plotly.graph_objects as go
    
    from tqdm import tqdm
    
    class Explainer:
        """ Explainer for the MANOCCA model

        To give more insight into the test, the Explainer class provides a set of methods that help understand how MANOCCA performed.

        Parameters
        ----------
        model : MANOCCA class
            An initialized MANOCCA instance. Can be retrieved from the Simulation class with simu.get_model(0), where MANOCCA is the first model in L_methods.
    
        
        Attributes
        ----------
        model : MANOCCA instance
        covMat : numpy array, covariance matrix for all outputs.
        df_loadings : pandas DataFrame, dataframe with the PCA loadings.
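
        Examples
        --------
        A minimal usage sketch (``model`` is assumed to be a fitted MANOCCA
        instance, and 'age' a hypothetical predictor name):

        >>> expl = Explainer(model)
        >>> p_values, grid = expl.power_pc_kept('age', plot=False)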
    
        """
    
    
        def __init__(self, model):
            self.model = model
    
            self.covMat = None
            self.cols_covMat = None
    
            self.df_loadings = None
    
    
        def _get_loadings(self):
            # Retrieving PCA loadings
            pca = self.model.pca
            cols_outputs = np.array(self.model.cols_outputs.copy())
    
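            # Build one product column name per unordered pair of outputs,
            # e.g. outputs [a, b, c] give cols_prod = ['a|b', 'a|c', 'b|c'].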
            cols_prod = []
            for i in range(len(cols_outputs)):
                for j in range(i+1,len(cols_outputs)):
                    cols_prod += [str(cols_outputs[i]) + '|'+ str(cols_outputs[j])]
            self.df_loadings = pd.DataFrame(np.abs(pca.components_), columns = cols_prod) # loadings are stored in absolute value
    
    
        def power_pc_kept(self, pred, grid  = None, max_n_comp = None, plot = True):
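            """ Re-run the MANOCCA test for one predictor while varying the
            number of principal components kept.

            Parameters
            ----------
            pred : str, name of the predictor to test.
            grid : list of int, optional. Numbers of PCs to try; defaults to range(1, max_n_comp).
            max_n_comp : int, optional. Defaults to the model's n_comp.
            plot : bool, plot -log10(p-value) against the number of PCs kept.

            Returns
            -------
            L_p : numpy array of p-values, one per element of grid.
            grid : the grid actually used.
            """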
            i_pred = self.model.cols_predictors.index(pred)
            var = self.model.predictors[:,i_pred].reshape(-1,1)
    
            if max_n_comp is None :
                max_n_comp = self.model.n_comp
    
            if grid is None : 
                grid = list(range(1,max_n_comp))
    
            L_p = []
            for el in grid :
                self.model.test(var = var, n_comp = el)
                L_p += [self.model.p[0]]
            
            L_p = np.array(L_p)
    
            if plot:
                px.scatter(y=-np.log10(L_p), x=grid,
                           labels={'x': 'Number of PCs kept', 'y': '-log10(P-value)'}).show()
    
            return L_p, grid
    
    
        def significant_covariance(self, pc_num, max_loadings, plot = True):
            """
            !!! pc_num starts at 0 !!! 
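
            Examples
            --------
            A minimal sketch, assuming an Explainer wrapping a fitted model:

            >>> expl = Explainer(model)
            >>> mat = expl.significant_covariance(pc_num=0, max_loadings=20, plot=False)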
            """
            cols_outputs = np.array(self.model.cols_outputs.copy())
            
    
            if self.df_loadings is None:
                self._get_loadings()
    
            # loadings are already stored in absolute value by _get_loadings()
            tmp = self.df_loadings.iloc[pc_num, :].sort_values(ascending = False)
    
            top_loadings = tmp.copy()
    
    
            # Generate the clustered covariance matrix if not already computed
            if self.covMat is None:
                covMat = np.cov(self.model.outputs.T)
                covMat, idx = cluster_corr(covMat)
                cols_covMat = cols_outputs[idx]
                self.covMat = covMat
                self.cols_covMat = cols_covMat
            else :
                covMat = self.covMat
                cols_covMat = self.cols_covMat
    
    
    
            # np.empty + fill: quickest way found to create an all-NaN matrix
            mat = np.empty((covMat.shape[0], covMat.shape[1]))
            mat.fill(np.nan)
    
            cols_covMat = cols_covMat.astype(str)
            for k in range(max_loadings):
                el_ = top_loadings.index[k].split("|")

                i = np.where(cols_covMat == str(el_[0]))[0][0]
                j = np.where(cols_covMat == str(el_[1]))[0][0]
    
                mat[i,j] = covMat[i,j]
                mat[j,i] = covMat[j,i]
    
    
            mat = pd.DataFrame(mat, columns = cols_covMat, index = cols_covMat)
            covMat = pd.DataFrame(covMat, columns = cols_covMat, index = cols_covMat)
    
            fig = make_subplots(rows=1, cols=2)
            fig.update_layout(coloraxis=dict(colorscale='thermal', cmin = -1, cmax = 1))
    
            fig.add_trace(
                px.imshow(mat, zmin = -1, zmax = 1).data[0],
                row=1, col=1
            )
    
            fig.add_trace(
                px.imshow(covMat, zmin = -1, zmax = 1).data[0],
                row=1, col=2
            )
            # Axis ranges follow the covariance matrix dimensions (previously hard-coded to 371)
            xmin, xmax = 0, covMat.shape[1]
            ymin, ymax = covMat.shape[0], 0
            fig.update_xaxes(range=[xmin, xmax], autorange=False)
            fig.update_yaxes(range=[ymin, ymax], autorange=False)
            if plot :
                fig.show()
            
            return mat
    
    
        def univariate_cov(self, var, n_comp, plot = True, return_beta = False):
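            """ Test the predictor `var` against each of the first n_comp PCs
            of the product matrix with a univariate model.

            Returns
            -------
            univ.p (and univ.beta if return_beta is True).
            """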
            m = self.model
    
            i_var = m.cols_predictors.index(var)
    
            univ = UNIVARIATE(m.prodV_red[:,:n_comp], m.predictors[:,i_var].reshape(-1,1), cols_outputs = ['PC_' + str(i) for i in range(n_comp)], cols_predictors = ['var'])
            univ.test()
    
            if plot:
                univ.plot()
    
            if return_beta :
                return univ.p, univ.beta
            else :
                return univ.p
    
    
        def feature_importances(self, var, n_comp, threshold = 'all', return_raw_contrib = False) :
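            """ Decompose the MANOCCA signal for `var` into per-product
            contributions: the squared loadings of each kept PC, weighted by
            that PC's chi2 statistic (N_samples * beta**2) and summed over PCs.

            threshold : 'all', 'Bonferroni', or 'top_k' (e.g. 'top_10'),
                controls which PCs are kept in the sum.
            """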
            m = self.model
            i_var = m.cols_predictors.index(var)
            N_samples = np.sum(~np.isnan(m.predictors[:,i_var])) #retrieve sample (non-NaN values) size for the considered variable
            if isinstance(self.df_loadings, type(None)):
                self._get_loadings()
    
            p, beta = self.univariate_cov(var, n_comp, plot = False, return_beta = True)
            # print(p.shape)
            df_p_beta = pd.DataFrame(p, columns = ['p'])
            df_p_beta['beta'] = beta
            df_p_beta['chi2'] = N_samples*beta**2
    
    
            if threshold == 'Bonferroni':
                # Keep rows (PCs) passing the Bonferroni-corrected p-value threshold
                df_p_beta = df_p_beta[df_p_beta['p'] < 0.05/df_p_beta.shape[0]]
    
    
            elif 'top_' in threshold:
                df_p_beta.sort_values('p', inplace = True)
                nb_keep = int(threshold.split('_')[-1]) # retrieve number of pc to keep
                df_p_beta = df_p_beta.iloc[:nb_keep,:]
    
            elif threshold == 'all' :
                pass
    
            else :
                # Unknown threshold: fail fast rather than silently keeping all PCs
                raise ValueError('Threshold type not found: %s' % threshold)
    
            pc_to_look = list(df_p_beta.index)
    
            df_pc = self.df_loadings.loc[pc_to_look,:]
    
            res = pd.merge(df_pc, df_p_beta, left_index = True, right_index = True)
    
            if return_raw_contrib :
                return res
            else :
                df_loadings = res.iloc[:,:-3] # drop the p, beta and chi2 columns
                df_loadings = df_loadings*df_loadings # squared loadings
                df_prod_contrib = res["chi2"].values.reshape(-1,1)*df_loadings
                return df_prod_contrib.sum().sort_values(ascending = False)
    
        def get_graph(self, L_d_otu_pvals, name = 'test.html', tree = None, figsize = (800,800), show = True, notebook = True):
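            """ Plot an interactive graph of the p-values in L_d_otu_pvals;
            thin wrapper around tools.build_graph.build_graph. """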
    
            # Plot graph
            build_graph(L_d_otu_pvals, name = name, tree = tree, figsize = figsize, show = show, notebook = notebook)
    
        def split_contribution(self, prod_contrib, sep = "|"):
            """ Fold pairwise product contributions back onto individual outputs:
            each output receives the sum of the contributions of all products it appears in. """
            df_contrib_otu = pd.Series(index = self.model.cols_outputs, dtype = float)
            for otu in tqdm(df_contrib_otu.index) :
                # match products where otu is the left or the right member of 'left|right'
                df_contrib_otu[otu] = prod_contrib[prod_contrib.index.str.contains(r"^" + otu + r"\||\|" + otu + "$", regex = True)].sum()
            return df_contrib_otu
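
    # Typical workflow (a sketch; assumes a fitted MANOCCA instance `model`
    # and a hypothetical predictor named 'age'):
    #   expl = Explainer(model)
    #   expl.significant_covariance(0, max_loadings=20)  # covariance entries behind PC 0
    #   contrib = expl.feature_importances('age', n_comp=50)
    #   per_output = expl.split_contribution(contrib)    # fold products back onto outputs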