diff --git a/raiss/stat_models.py b/raiss/stat_models.py index 1cec6781fc3580b55a503a0d1a0794135833b703..ea864f8ee1837874895bce8a9581a5cb4e70c738 100644 --- a/raiss/stat_models.py +++ b/raiss/stat_models.py @@ -68,6 +68,14 @@ def var_in_boundaries(var,lamb): return var +def invert_sig_t(sig_t, lamb, rcond): + try: + np.fill_diagonal(sig_t.values, (1+lamb)) + sig_t_inv = scipy.linalg.pinv(sig_t, rcond=rcond) + return(sig_t_inv) + except np.linalg.LinAlgError: + invert_sig_t(sig_t, lamb*1.1, rcond*1.1) + def raiss_model(zt, sig_t, sig_i_t, lamb=0.01, rcond=0.01, batch=True): """ Compute the variance @@ -80,9 +88,8 @@ def raiss_model(zt, sig_t, sig_i_t, lamb=0.01, rcond=0.01, batch=True): rcond (float): threshold to filter eigenvector with a eigenvalue under rcond make inversion biased but much more numerically robust """ - sig_t = sig_t.values - np.fill_diagonal(sig_t, (1+lamb)) - sig_t_inv = scipy.linalg.pinv(sig_t, rcond=rcond) + sig_t_inv = invert_sig_t(sig_t, lamb, rcond) + if batch: condition_number = np.array([np.linalg.cond(sig_t)]*sig_i_t.shape[0])