From ae33bb423153e66751ebe9f69596d28a4bc2d20b Mon Sep 17 00:00:00 2001
From: Christophe  BOETTO <cboetto@pasteur.fr>
Date: Fri, 22 Sep 2023 17:28:19 +0200
Subject: [PATCH] added extended version to manocca

---
 python/src/explainer.py            |  2 +-
 python/src/manocca.py              | 37 ++++++++++++++++++++++++++++--
 python/src/tools/compute_manova.py |  3 +++
 3 files changed, 39 insertions(+), 3 deletions(-)

diff --git a/python/src/explainer.py b/python/src/explainer.py
index f71a3a7..ace2bcd 100644
--- a/python/src/explainer.py
+++ b/python/src/explainer.py
@@ -219,7 +219,7 @@ class Explainer :
         if return_raw_contrib == True :
             return res
         else :
-            df_loadings = res.iloc[:,:-3]
+            df_loadings = res.iloc[:,:-3] #-3 because we added p, beta and chi2 at the end
             df_loadings = df_loadings*df_loadings
             df_prod_contrib = res["chi2"].values.reshape(-1,1)*df_loadings
             return df_prod_contrib.sum().sort_values(ascending = False)#.to_dict()
diff --git a/python/src/manocca.py b/python/src/manocca.py
index 7ee0397..95ac228 100644
--- a/python/src/manocca.py
+++ b/python/src/manocca.py
@@ -99,7 +99,7 @@ class MANOCCA:
 
     """
     def __init__(self, predictors, outputs, covariates=None, cols_outputs = None,
-                 cols_predictors = None, cols_covariates = None, prodV_red=None, n_comp = None, prod_to_keep = None, use_resid = True, use_pca = True, n_jobs = 1):
+                 cols_predictors = None, cols_covariates = None, prodV_red=None, n_comp = None, prod_to_keep = None, use_resid = True, use_pca = True, use_extended = False, n_jobs = 1):
 
         ### Initializing 
         self.outputs = outputs
@@ -126,6 +126,7 @@ class MANOCCA:
         self.n_jobs = n_jobs
         self.use_resid = use_resid
         self.use_pca = use_pca
+        self.use_extended = use_extended
 
         # Filled later
         self.prodV = None
@@ -148,7 +149,10 @@ class MANOCCA:
                 self.predictors = np.apply_along_axis(lambda x : pt.adjust_covariates(x,covariates), axis = 0, arr = self.predictors)
 
         else : # If not we preprocess the data and compute prodV and prodV_red
-            self.prodV = self.get_prodV_para_wrap(self.outputs)
+            if self.use_extended : 
+                self.prodV = self.get_prodV_extended_wrap(self.outputs)
+            else :
+                self.prodV = self.get_prodV_para_wrap(self.outputs)
             # self.prodV = scale(self.prodV)
             if not isinstance(self.prod_to_keep, type(None)): # Filtering out some columns
                 self.filter_prodV_columns()
@@ -267,6 +271,35 @@ class MANOCCA:
         else :
             return tmp
 
+    ### For extended ###
+    def get_prodV_extended(self, DD0, job_id, nb_compute):
+        L_prodV = []
+        for i in range(job_id, DD0.shape[1], nb_compute):
+    #         tmp = np.transpose(np.transpose(DD0[:,(i+1):DD0.shape[1]])*DD0[:,i])
+            tmp = (DD0[:,(i):]*DD0[:,i].reshape(-1,1))
+            tmp = pt.get_qt(tmp)
+            L_prodV += [tmp]
+        return L_prodV
+
+    def get_prodV_extended_wrap(self, DD0, verbose = 10):
+        n_jobs = self.n_jobs
+        if n_jobs == -1:
+            nb_compute = cpu_count()
+        else :
+            nb_compute = min(cpu_count(),n_jobs)
+        print("Computing prodV with %i cpu" %nb_compute)
+
+        res = Parallel(n_jobs=n_jobs, verbose = verbose)(delayed(self.get_prodV_extended)(DD0,j, nb_compute) for j in range(nb_compute))
+        
+        # reordering
+        res_ordered = []
+        for i in range(len(res[0])):
+            for L in res :
+                if L != [] : 
+                    res_ordered += [L.pop(0)]
+        return np.hstack([DD0]+res_ordered)
+
+    ### Other functions ###
     def get_prod_cols(self, cols, sep = '|'):
         return [cols[i]+sep+cols[j] for i in range(len(cols)-1) for j in range(i+1,len(cols))]
 
diff --git a/python/src/tools/compute_manova.py b/python/src/tools/compute_manova.py
index 589199a..644bb16 100644
--- a/python/src/tools/compute_manova.py
+++ b/python/src/tools/compute_manova.py
@@ -20,6 +20,7 @@ def custom_manova(Y,X,C=None, return_beta = False):
 
     # (sign_num, num) = np.linalg.slogdet(dot_Y-X.shape[0]*np.dot(beta,beta.T))
     (sign_num, num) = np.linalg.slogdet(dot_Y - np.dot( np.dot(beta, X.T) , np.dot(X, beta.T)))
+    # (sign_num, num) = np.linalg.slogdet((Y - X @ beta.T).T@(Y - X @beta.T))
     (sign_denom, denom) = np.linalg.slogdet(dot_Y)
     lamb = np.exp(sign_num*num-(sign_denom*denom))
     # print(lamb)
@@ -81,6 +82,8 @@ def linear_regression(Y, X, C=None):
         beta = np.dot(np.linalg.inv(np.dot(X.T,X)) , np.dot(X.T,Y))
         beta = beta.T
         # beta = np.dot(X.T,Y)/X.shape[0]
+        # print(beta.shape)
+        # beta = beta.T
     return beta
 
 
-- 
GitLab