Select Git revision
explainer.py
explainer.py 7.54 KiB
import pandas as pd
import numpy as np
from manocca import MANOCCA
from manova import MANOVA
from univariate import UNIVARIATE
from tools.h_clustering import cluster_corr
from tools.build_graph import get_tree, build_graph
import plotly.express as px
from plotly.subplots import make_subplots
import plotly.graph_objects as go
# import matplotlib.pyplot as plt
from tqdm import tqdm
class Explainer:
    """Explainer for the MANOCCA model.

    To give more insight into a test, the Explainer class provides a set of
    methods that help understand how MANOCCA performed: p-value as a function
    of the number of principal components kept, output pairs driving a given
    component, per-component univariate tests, and per-pair / per-output
    contribution scores.

    Parameters
    ----------
    model : MANOCCA
        An initialised MANOCCA instance. It can be retrieved from the
        Simulation class with ``simu.get_model(0)`` when MANOCCA is the first
        model of ``L_methods``.

    Attributes
    ----------
    model : MANOCCA
        The wrapped model.
    covMat : numpy.ndarray or None
        Covariance matrix of all outputs, reordered by ``cluster_corr``.
        Computed lazily by ``significant_covariance``.
    cols_covMat : numpy.ndarray or None
        Output names in the row/column order of ``covMat``.
    df_loadings : pandas.DataFrame or None
        Absolute PCA loadings: one row per row of ``model.pca.components_``
        and one column per output pair, labelled ``"<out_i>|<out_j>"``.
        Computed lazily by ``_get_loadings``.
    """

    def __init__(self, model):
        self.model = model
        # Caches filled on first use by significant_covariance / _get_loadings.
        self.covMat = None
        self.cols_covMat = None
        self.df_loadings = None

    def _get_loadings(self):
        """Build ``self.df_loadings`` from the model's fitted PCA.

        The columns of ``model.pca.components_`` are labelled with the output
        pairs ``"<out_i>|<out_j>"`` (i < j, in ``model.cols_outputs`` order).
        The loadings are stored as absolute values.

        NOTE(review): assumes ``components_`` has exactly one column per output
        pair, i.e. ``n * (n - 1) / 2`` columns for ``n`` outputs -- TODO
        confirm against how MANOCCA builds its products.
        """
        # Going through np.array keeps the same string conversion as the
        # names used in significant_covariance, so the two label sets agree.
        names = [str(name) for name in np.array(self.model.cols_outputs)]
        pair_labels = [
            first + '|' + second
            for k, first in enumerate(names)
            for second in names[k + 1:]
        ]
        self.df_loadings = pd.DataFrame(
            np.abs(self.model.pca.components_), columns=pair_labels
        )

    def power_pc_kept(self, pred, grid=None, max_n_comp=None, plot=True):
        """Compute the MANOCCA p-value of ``pred`` for several numbers of PCs kept.

        Parameters
        ----------
        pred : str
            Predictor name, looked up in ``model.cols_predictors``.
        grid : list of int, optional
            Numbers of principal components to try. Defaults to
            ``1 .. max_n_comp - 1``.
        max_n_comp : int, optional
            Upper bound (exclusive) of the default grid. Defaults to
            ``model.n_comp``. Ignored when ``grid`` is given.
        plot : bool
            If True, show a scatter plot of -log10(p-value) against the grid.

        Returns
        -------
        L_p : numpy.ndarray
            ``model.p[0]`` after each ``model.test`` call, one entry per grid value.
        grid : list
            The grid that was actually used.

        Notes
        -----
        ``model.test`` is called once per grid value, so the model's stored
        test results (e.g. ``model.p``) reflect the last grid value afterwards.
        """
        i_pred = self.model.cols_predictors.index(pred)
        var = self.model.predictors[:, i_pred].reshape(-1, 1)
        if max_n_comp is None:
            max_n_comp = self.model.n_comp
        if grid is None:
            # NOTE(review): the default grid excludes max_n_comp itself --
            # confirm this is intended.
            grid = list(range(1, max_n_comp))

        L_p = []
        for n_comp in grid:
            self.model.test(var=var, n_comp=n_comp)
            L_p.append(self.model.p[0])
        L_p = np.array(L_p)

        if plot:
            # px.scatter only builds the figure; it must be shown explicitly.
            fig = px.scatter(
                y=-np.log10(np.ravel(L_p)),
                x=grid,
                labels={'x': 'Number of PC kept', 'y': '-log10(P-value)'},
            )
            fig.show()
        return L_p, grid

    def significant_covariance(self, pc_num, max_loadings, plot=True):
        """Highlight the output pairs with the largest loadings on one PC.

        Parameters
        ----------
        pc_num : int
            Index of the principal component. !!! pc_num starts at 0 !!!
        max_loadings : int
            Number of top-loading pairs to keep. Values larger than the number
            of pairs keep them all; values <= 0 keep none.
        plot : bool
            If True, show two heatmaps side by side: the selected covariances
            and the full covariance matrix.

        Returns
        -------
        pandas.DataFrame
            Matrix shaped like ``covMat`` (rows/columns labelled with output
            names). It holds NaN everywhere except the (symmetric) entries of
            the selected pairs, which carry their covariance.
        """
        if self.df_loadings is None:
            self._get_loadings()
        # Loadings are already absolute values (see _get_loadings), so a plain
        # descending sort ranks the pairs by loading magnitude.
        top_loadings = self.df_loadings.iloc[pc_num, :].sort_values(ascending=False)

        # Compute (and cache) the reordered covariance matrix on first use.
        if self.covMat is None:
            cols_outputs = np.array(self.model.cols_outputs)
            covMat, idx = cluster_corr(np.cov(self.model.outputs.T))
            self.covMat = covMat
            self.cols_covMat = cols_outputs[idx]
        covMat = self.covMat
        cols_covMat = np.asarray(self.cols_covMat).astype(str)

        # Name -> position lookup built once; the first occurrence wins, as
        # with the previous np.where(...)[0][0] search.
        position = {}
        for k, name in enumerate(cols_covMat):
            position.setdefault(str(name), k)

        mat = np.full(covMat.shape, np.nan)
        n_pairs = max(0, min(max_loadings, len(top_loadings)))
        for label in top_loadings.index[:n_pairs]:
            parts = label.split("|")
            i, j = position[parts[0]], position[parts[1]]
            mat[i, j] = covMat[i, j]
            mat[j, i] = covMat[j, i]
        mat = pd.DataFrame(mat, columns=cols_covMat, index=cols_covMat)

        if plot:
            full = pd.DataFrame(covMat, columns=cols_covMat, index=cols_covMat)
            fig = make_subplots(rows=1, cols=2)
            # Colour range fixed to [-1, 1]. NOTE(review): covMat comes from
            # np.cov (covariances, not correlations), so larger values saturate.
            fig.update_layout(coloraxis=dict(colorscale='thermal', cmin=-1, cmax=1))
            fig.add_trace(px.imshow(mat, zmin=-1, zmax=1).data[0], row=1, col=1)
            fig.add_trace(px.imshow(full, zmin=-1, zmax=1).data[0], row=1, col=2)
            # Axis range follows the matrix size (was hard-coded to 371);
            # the y axis is reversed so the first row is drawn at the top.
            n = covMat.shape[0]
            fig.update_xaxes(range=[-0.5, n - 0.5], autorange=False)
            fig.update_yaxes(range=[n - 0.5, -0.5], autorange=False)
            fig.show()
        return mat

    def univariate_cov(self, var, n_comp, plot=True, return_beta=False):
        """Test the predictor ``var`` against each of the first ``n_comp`` components.

        Runs a UNIVARIATE model on ``model.prodV_red[:, :n_comp]`` (outputs
        named ``PC_0 .. PC_{n_comp-1}``) with the single predictor column ``var``.

        Returns
        -------
        p, or (p, beta) when ``return_beta`` is True, as exposed by UNIVARIATE.
        """
        m = self.model
        i_var = m.cols_predictors.index(var)
        univ = UNIVARIATE(
            m.prodV_red[:, :n_comp],
            m.predictors[:, i_var].reshape(-1, 1),
            cols_outputs=['PC_' + str(i) for i in range(n_comp)],
            cols_predictors=['var'],
        )
        univ.test()
        if plot:
            univ.plot()
        if return_beta:
            return univ.p, univ.beta
        return univ.p

    def feature_importances(self, var, n_comp, threshold='all', return_raw_contrib=False):
        """Score each output pair by its contribution to the association with ``var``.

        For each selected component k, ``chi2_k = N * beta_k ** 2``, where N is
        the number of non-NaN values of ``var`` and ``beta_k`` comes from
        ``univariate_cov``. A pair's score is ``sum_k chi2_k * loading_{k,pair} ** 2``.

        Parameters
        ----------
        var : str
            Predictor name, looked up in ``model.cols_predictors``.
        n_comp : int
            Number of leading components tested.
        threshold : str
            Selects the components to use:

            - ``'all'``: all of them.
            - ``'Bonferroni'``: those with ``p < 0.05 / n_comp``.
            - ``'top_<k>'``: the k components with the smallest p-values.

            Any other value prints a message and keeps all components.
        return_raw_contrib : bool
            If True, return the loadings of the selected components merged with
            their ``p``, ``beta`` and ``chi2`` columns instead of the scores.

        Returns
        -------
        pandas.Series (pair label -> score, sorted descending) or, with
        ``return_raw_contrib``, the merged pandas.DataFrame.
        """
        m = self.model
        i_var = m.cols_predictors.index(var)
        # Sample size = number of non-NaN values of the considered predictor.
        N_samples = np.sum(~np.isnan(m.predictors[:, i_var]))
        if self.df_loadings is None:
            self._get_loadings()

        p, beta = self.univariate_cov(var, n_comp, plot=False, return_beta=True)
        df_p_beta = pd.DataFrame(p, columns=['p'])
        df_p_beta['beta'] = beta
        df_p_beta['chi2'] = N_samples * beta ** 2

        if threshold == 'all':
            pass
        elif threshold == 'Bonferroni':
            # Filter rows on the p-value column only: indexing with a boolean
            # DataFrame would NaN-mask every column instead of dropping rows.
            df_p_beta = df_p_beta[df_p_beta['p'] < 0.05 / df_p_beta.shape[0]]
        elif isinstance(threshold, str) and threshold.startswith('top_'):
            nb_keep = int(threshold.split('_')[-1])  # number of PCs to keep
            df_p_beta = df_p_beta.sort_values('p').iloc[:nb_keep, :]
        else:
            print('Threshold type not found')

        df_pc = self.df_loadings.loc[list(df_p_beta.index), :]
        res = pd.merge(df_pc, df_p_beta, left_index=True, right_index=True)
        if return_raw_contrib:
            return res

        # The loading columns come first in the merge result (left frame).
        n_loadings = df_pc.shape[1]
        squared_loadings = res.iloc[:, :n_loadings] ** 2
        df_prod_contrib = res['chi2'].to_numpy().reshape(-1, 1) * squared_loadings
        return df_prod_contrib.sum().sort_values(ascending=False)

    def get_graph(self, L_d_otu_pvals, name='test.html', tree=None, figsize=(800, 800), show=True, notebook=True):
        """Plot a graph of the results by delegating to ``build_graph`` with the same arguments."""
        build_graph(L_d_otu_pvals, name=name, tree=tree, figsize=figsize, show=show, notebook=notebook)

    def split_contribution(self, prod_contrib, sep="|"):
        """Aggregate per-pair contributions into per-output totals.

        Each output in ``model.cols_outputs`` receives the sum of the
        contributions whose pair label starts with ``name + sep`` or ends with
        ``sep + name``. Matching is literal (no regex), so names containing
        regex metacharacters are handled correctly. Output names themselves
        must not contain ``sep``.

        Parameters
        ----------
        prod_contrib : pandas.Series
            Contributions indexed by pair label, e.g. the output of
            ``feature_importances``.
        sep : str
            Separator used in the pair labels.

        Returns
        -------
        pandas.Series of float indexed by ``model.cols_outputs``.
        """
        outputs = self.model.cols_outputs
        labels = prod_contrib.index.astype(str)
        totals = []
        for otu in tqdm(outputs):
            name = str(otu)
            mask = np.asarray(labels.str.startswith(name + sep), dtype=bool) | np.asarray(
                labels.str.endswith(sep + name), dtype=bool
            )
            totals.append(prod_contrib[mask].sum())
        return pd.Series(totals, index=outputs, dtype=float)