Skip to content
Snippets Groups Projects
Select Git revision
  • d69bd2cecc04dd88e0e131803decab4b0dd80612
  • master default protected
2 results

simulation.py

Blame
  • user avatar
    Christophe BOETTO authored
    d69bd2ce
    History
    simulation.py 5.11 KiB
    import numpy as np
    import pandas as pd
    
    from manocca import MANOCCA
    from manova import MANOVA
    
    
    from tools.preprocessing_tools import scale
    from tools.h_clustering import cluster_corr
    
    
    import plotly.graph_objects as go
    from plotly.subplots import make_subplots
    
    from joblib import Parallel, delayed, cpu_count
    
    import tools.preprocessing_tools as pt
    
    class Simu :
        """Build one model per requested method on shared data and compare their p-values.

        Each name in ``methods`` is mapped to a ``MANOCCA`` or ``MANOVA`` instance
        built from the same predictors, outputs and covariates. ``run_simu`` runs
        every model and stacks their ``p`` attributes column-wise into ``self.p``;
        ``plot_comparison`` plots those p-values; ``simu_null`` collects p-values
        obtained after shuffling one predictor column.
        """

        def __init__(self,  methods, predictors, outputs=None, covariates=None, cols_outputs = None,
                     cols_predictors = None, cols_covariates = None, L_preproc = None, prodV_red=None, use_resid = True, n_comp = None, n_jobs = 1):
            """Store the data, optionally preprocess the outputs and build the models.

            Parameters
            ----------
            methods : str or sequence of str
                Model name(s) to build; a single string is wrapped in a list.
                Supported names are ``'MANOCCA'`` and ``'MANOVA'``.
            predictors, outputs, covariates :
                Input data. They are passed through ``pt._extract_cols`` together
                with the matching ``cols_*`` argument (presumably returning the
                values and the column names -- confirm against
                ``tools.preprocessing_tools``). Predictors are only converted when
                they are a DataFrame, ndarray or Series; outputs and covariates
                only when they are not None.
            cols_outputs, cols_predictors, cols_covariates :
                Column names forwarded to ``pt._extract_cols`` and to the models.
            L_preproc : sequence, optional
                Preprocessing steps forwarded to ``pt.pipeline`` as ``L_pipe`` and
                applied to the outputs. Nothing is applied when None or empty.
            prodV_red, n_comp, n_jobs :
                Forwarded unchanged to ``MANOCCA``. ``n_jobs`` is also used by
                ``simu_null`` to decide whether to run in parallel.
            use_resid :
                Forwarded unchanged to both ``MANOCCA`` and ``MANOVA``.

            Raises
            ------
            ValueError
                If ``methods`` contains a name that is not supported.
            """
            ### Initializing
            self.outputs = outputs
            self.cols_outputs = cols_outputs
            # Guard on `outputs` itself: this used to test `covariates`, so the
            # outputs were only converted when covariates happened to be given.
            if outputs is not None :
                self.outputs, self.cols_outputs = pt._extract_cols(self.outputs, self.cols_outputs)

            self.predictors = predictors
            self.cols_predictors = cols_predictors
            if isinstance(predictors, (pd.DataFrame, np.ndarray, pd.Series)) :
                self.predictors, self.cols_predictors = pt._extract_cols(self.predictors, self.cols_predictors)

            self.covariates = covariates
            self.cols_covariates = cols_covariates
            if covariates is not None :
                self.covariates, self.cols_covariates = pt._extract_cols(self.covariates, self.cols_covariates)

            self.prodV_red = prodV_red

            self.n_comp = n_comp
            self.use_resid = use_resid
            self.n_jobs = n_jobs

            # `None` replaces the former mutable `[]` default; both mean "no preprocessing".
            if L_preproc is not None and len(L_preproc) > 0:
                self.outputs = pt.pipeline(self.outputs, L_pipe = L_preproc)

            self.methods = [methods] if isinstance(methods, str) else methods
            self.L_models = []
            self._build_model()

            ### To fill later ###
            # p-values of every model stacked column-wise; set by run_simu().
            self.p = None

        def _build_model(self):
            """Instantiate one model per entry of ``self.methods`` into ``self.L_models``.

            ``self.L_models[i]`` always corresponds to ``self.methods[i]``; other
            methods of this class rely on that alignment.

            Raises
            ------
            ValueError
                If a method name is neither ``'MANOCCA'`` nor ``'MANOVA'``. Such
                names used to be skipped silently, which shifted the models
                relative to ``self.methods``.
            """
            models = []
            for m in self.methods :
                if m == 'MANOCCA':
                    models.append(MANOCCA(self.predictors, self.outputs, self.covariates, self.cols_outputs, self.cols_predictors, self.cols_covariates, self.prodV_red, self.n_comp, self.use_resid, self.n_jobs))
                elif m == 'MANOVA':
                    models.append(MANOVA(self.predictors, self.outputs, self.covariates, self.cols_outputs, self.cols_predictors, self.cols_covariates, self.use_resid))
                else :
                    raise ValueError("Unknown method {!r}: expected 'MANOCCA' or 'MANOVA'".format(m))
            self.L_models = models

        def get_model(self, num = 0):
            """Return the model stored at position ``num`` of ``self.L_models``."""
            return self.L_models[num]

        def run_simu(self, n_comp, var = None):
            """Run ``test`` on every model and gather their p-values in ``self.p``.

            Each model is called as ``m.test(var=var, n_comp=n_comp)``. The first
            model's ``p`` is stored as is; every following one is appended with
            ``np.hstack``, so column ``i`` of ``self.p`` belongs to
            ``self.methods[i]`` (assuming each ``m.p`` is a column array -- confirm
            against the model classes). ``self.p`` stays None when there is no model.
            """
            self.p = None
            for m in self.L_models :
                m.test(var = var, n_comp = n_comp)
                if self.p is None:
                    self.p = m.p
                else :
                    self.p = np.hstack([self.p, m.p])

        def plot_comparison(self, fig=None, show = True, row = 1, col = 1):
            """Scatter ``-log10(p)`` of every method against the predictor names.

            A horizontal trace at ``-log10(0.05 / number of predictors)`` is added
            as the Bonferroni threshold.

            Parameters
            ----------
            fig : plotly.graph_objects.Figure, optional
                Figure to draw on. When it is not a ``go.Figure``, a new 1x1
                subplot figure is created and given a title.
            show : bool
                When True the figure is shown and None is returned; otherwise the
                figure is returned without being shown.
            row, col : int
                Subplot cell the traces are added to.

            Raises
            ------
            RuntimeError
                If no p-values are available yet (``run_simu`` was not called).
            """
            # Fail with a clear message rather than a TypeError on `None[:, i]`.
            if self.p is None:
                raise RuntimeError("No p-values available: call run_simu() before plot_comparison().")

            if not isinstance(fig,go.Figure):
                fig = make_subplots(rows = 1,cols=1)
                fig.update_layout(title = "Comparison of " + ' '.join(self.methods))

            for i, m in enumerate(self.methods):
                fig.add_trace(
                    go.Scatter(x = self.cols_predictors, y=-np.log10(self.p[:,i]), name = m, mode='markers'),
                    row=row,col=col
                )

            n_predictors = len(self.cols_predictors)
            threshold = -np.log10(0.05/n_predictors)
            fig.add_trace(
                go.Scatter(x = self.cols_predictors, y = [threshold]*n_predictors, name = 'Bonferroni threshold'),
                row=row,col=col
            )
            if show:
                fig.show()
            else :
                return fig

        def _null_loop(self, m, var_null, N_iters, n_comp = None ) :
            """Return ``N_iters`` p-values of model ``m`` on a reshuffled predictor.

            ``var_null`` is shuffled in place along its first axis before every
            call to ``m.test``; the first entry of ``m.p`` is recorded each time.
            The result has shape ``(N_iters, 1)``.
            """
            p_null = np.empty((N_iters, 1))
            for i in range(N_iters):
                np.random.shuffle(var_null)

                m.test(var_null, n_comp = n_comp)
                p_null[i] = m.p[0]
            return p_null

        @staticmethod
        def _split_iterations(total, n_chunks):
            """Split ``total`` iterations into ``n_chunks`` counts that differ by at most one."""
            base, extra = divmod(total, n_chunks)
            return [base + 1 if k < extra else base for k in range(n_chunks)]

        def simu_null(self, var_name, model, Nsimu = 1000, n_comp = None):
            """Collect ``Nsimu`` p-values of ``model`` after shuffling predictor ``var_name``.

            Parameters
            ----------
            var_name :
                Entry of ``self.cols_predictors`` naming the column to shuffle. The
                column is copied first, so ``self.predictors`` is left untouched.
            model : str
                Entry of ``self.methods`` selecting which model to run.
            Nsimu : int
                Total number of shuffles.
            n_comp :
                Forwarded to the model's ``test`` method.

            Returns
            -------
            numpy.ndarray
                Array of shape ``(Nsimu, 1)``.

            Raises
            ------
            ValueError
                If ``var_name`` or ``model`` is not found.
            """
            var_pos = self.cols_predictors.index(var_name)
            var_null = self.predictors[:,var_pos].copy().reshape(-1,1)
            m = self.L_models[self.methods.index(model)]

            if self.n_jobs == 1 :
                return self._null_loop(m, var_null, Nsimu, n_comp)

            # One chunk per requested worker (the split used to follow the machine's
            # CPU count whatever `n_jobs` was). Negative values follow joblib's
            # convention: -1 means all CPUs, -2 all but one, and so on.
            if self.n_jobs is not None and self.n_jobs < 0:
                n_workers = max(cpu_count() + 1 + self.n_jobs, 1)
            else:
                n_workers = max(self.n_jobs or 1, 1)

            # NOTE(review): every worker draws from its own global numpy RNG; this
            # presumably yields independent shuffles with joblib's default
            # process-based backend -- confirm if another backend is configured.
            iters = self._split_iterations(Nsimu, n_workers)
            res = Parallel(n_jobs=self.n_jobs, verbose = 0)(delayed(self._null_loop)(m, var_null, nb_iters, n_comp) for nb_iters in iters)
            return np.vstack(res)
    
    
        # def _GWAS_loop(self, i_job, L_split):
        #     arr = arr[:,maf_mask[L_split[run]:L_split[run+1]]]
    
    
        # def run_GWAS(self):
        #     N = self.predictors.shape[1]
    
        #     if self.n_jobs == -1 or self.n_jobs > cpu_count():
        #         n_array = cpu_count()
        #     else :
        #         n_array = self.n_jobs
        #     L_split = [int(N/n_array)*k for k in range(n_array)]
        #     L_split += [N]
    
        #     res = Parallel(n_jobs=n_jobs, verbose = 10)(delayed(self._GWAS_loop)(k, L_splits) for k in range(n_array))
        #     return res