diff --git a/data/coef_mean_model.tsv b/data/coef_mean_model.tsv
index 9cde86f5e42d1ec10a085ec23d1af69931dc01c1..ca07eb89886c598f96afb4e4f998f870e57d30f1 100644
--- a/data/coef_mean_model.tsv
+++ b/data/coef_mean_model.tsv
@@ -3,5 +3,5 @@ k_coef_mv	0.07740334977380119
 log10_avg_distance_cor_coef_mv	-0.6999110771883902
 log10_mean_gencov_coef_mv	0.746794584985343
 avg_Neff_coef_mv	0.07289261717080556
-avg_h2_mixer_coef_mv	-0.516496395500929
+avg_h2_coef_mv	-0.516496395500929
 avg_perc_h2_diff_region_coef_mv	0.15727591593399
diff --git a/data/combination_example.tsv b/data/combination_example.tsv
new file mode 100644
index 0000000000000000000000000000000000000000..b42f3a73d6e4e8d2ef72070c2b7174f804ce0ed9
--- /dev/null
+++ b/data/combination_example.tsv
@@ -0,0 +1,2 @@
+GRP1,z_GIANT_HIP z_GLG_HDL z_GLG_LDL z_MAGIC_2HGLU-ADJBMI
+GRP2,z_SPIRO-UKB_FVC z_SPIRO-UKB_FEV1 z_TAGC_ASTHMA
diff --git a/data/gain_results.tsv b/data/gain_results.tsv
new file mode 100644
index 0000000000000000000000000000000000000000..cd44d997b8fe28075d6dd53ef1374fc77cf4e6af
--- /dev/null
+++ b/data/gain_results.tsv
@@ -0,0 +1,3 @@
+	traits	k	avg_distance_cor	mean_gencov	avg_Neff	avg_h2	avg_perc_h2_diff_region	log10_mean_gencov	log10_avg_distance_cor	gain
+1	['z_SPIRO-UKB_FVC', 'z_SPIRO-UKB_FEV1', 'z_TAGC_ASTHMA']	0.1	0.1731946683845993	0.0637	0.3843393026739591	0.2785193310634847	0.7976315890930669	0.8139196701681637	0.8013809378674498	0.06428524764535551
+0	['z_GIANT_HIP', 'z_GLG_HDL', 'z_GLG_LDL', 'z_MAGIC_2HGLU-ADJBMI']	0.2	0.14899001074867035	0.01535	0.12076877719858631	0.22628198390356655	0.9055326131023057	0.6573854616675169	0.7879956172999502	-0.010766494024690904
diff --git a/data/range_feature_gain_prediction.tsv b/data/range_feature_gain_prediction.tsv
index 15031fd662c92ed128c48385f81514fd65ec0b1e..30c63afb46258403eeb27bf17cb6a9095afbd972 100644
--- a/data/range_feature_gain_prediction.tsv
+++ b/data/range_feature_gain_prediction.tsv
@@ -3,5 +3,5 @@ k	2.0	12.0
 log10_avg_distance_cor	-4.675617219570908	0.20864138105896807
 log10_mean_gencov	-4.4093921991254446	-0.46117501106209624
 avg_Neff	6730.5	697828.0
-avg_h2_mixer	0.014033707225812	0.4361454950334251
+avg_h2	0.014033707225812	0.4361454950334251
 avg_perc_h2_diff_region	0.0906544694784672	0.9831222899777692
diff --git a/jass/__main__.py b/jass/__main__.py
index 3b64555f9c224c599b8daeb564c96c50e77d0221..4a7cd1e8c133f485a1875caa841970c91a5cbafc 100644
--- a/jass/__main__.py
+++ b/jass/__main__.py
@@ -21,7 +21,7 @@ from jass.models.plots import (
     create_local_plot,
     create_qq_plot,
 )
-from jass.models.gain import compute_gain
+from jass.models.gain import create_features, compute_gain
 
 from pandas import read_hdf
 
@@ -282,12 +282,13 @@ def w_gene_annotation(args):
     )
 
 def w_compute_gain(args):
+    inittable_path = absolute_path_of_the_file(args.inittable_path)
     combi_path = absolute_path_of_the_file(args.combination_path)
     combi_path_with_gain = absolute_path_of_the_file(args.gain_path, True)
+    
+    features = create_features(inittable_path, combi_path)
+    compute_gain(features, combi_path_with_gain)
 
-    compute_gain(
-        combi_path, combi_path_with_gain
-    )
 
 
 def get_parser():
@@ -632,12 +633,17 @@ def get_parser():
     
     # ------- compute predicted gain -------#
     parser_create_mp = subparsers.add_parser(
-        "predict-gain", help="predict gain based on the genetic architecture of the set of multi-trait"
+        "predict-gain", help="Predict gain based on the genetic architecture of the set of multi-trait. To function, this command need the inittable to contain genetic covariance store under the key 'GEN_COV in the inittable'"
+    )
+    parser_create_mp.add_argument(
+        "--inittable-path",
+        required=True,
+        help="Path to the inittable",
     )
     parser_create_mp.add_argument(
         "--combination-path",
         required=True,
-        help="path to the worktable file containing the data",
+        help="Path to the file storing combination to be scored",
     )
     parser_create_mp.add_argument(
         "--gain-path", required=True, help="path to save predicted gain"
diff --git a/jass/models/gain.py b/jass/models/gain.py
index f5c2fc45ecc31173e431e4bc91c9d7c84949fab2..d92db8fec2befc4f099bc26c4ae05b26eb230c40 100644
--- a/jass/models/gain.py
+++ b/jass/models/gain.py
@@ -1,6 +1,8 @@
 import pandas as pd
 import numpy as np
 
+
+# data issued from https://doi.org/10.1101/2023.10.27.564319
 X_range = pd.read_csv("./data/range_feature_gain_prediction.tsv", sep="\t", index_col=0)
 model_coefficients =  pd.read_csv("./data/coef_mean_model.tsv", sep="\t", index_col=0)
 
@@ -14,15 +16,126 @@ def preprocess_feature(df_combinations):
  
     df_combinations['log10_mean_gencov'] = np.log10(df_combinations.mean_gencov)
     df_combinations['log10_avg_distance_cor'] = np.log10(df_combinations.avg_distance_cor)
-    for f in ["k", "log10_avg_distance_cor", "log10_mean_gencov", "avg_Neff", "avg_h2_mixer", "avg_perc_h2_diff_region"]:
+    for f in ["k", "log10_avg_distance_cor", "log10_mean_gencov", "avg_Neff", "avg_h2", "avg_perc_h2_diff_region"]:
         df_combinations[f] = scale_feature(df_combinations[f], f)
     return df_combinations
 
-
-def compute_gain(path_combi, path_output):
-
-    df_combinations = pd.read_csv(path_combi)
+def compute_gain(df_combinations, path_output):
 
     preprocess_feature(df_combinations)
-    df_combinations["gain"] = df_combinations[["k", "log10_avg_distance_cor", "log10_mean_gencov", "avg_Neff", "avg_h2_mixer", "avg_perc_h2_diff_region"]].dot(model_coefficients["0"].values)
+    df_combinations["gain"] = df_combinations[["k", "log10_avg_distance_cor", "log10_mean_gencov", "avg_Neff", "avg_h2", "avg_perc_h2_diff_region"]].dot(model_coefficients["0"].values)
     df_combinations.sort_values(by="gain", ascending=False).to_csv(path_output, sep="\t")
+
+# cov to cor
+def cov2cor(c):
+    """
+    Return a correlation matrix given a covariance matrix. 
+    : c = covariance matrix
+    """
+    D = 1 / np.sqrt(np.diag(c)) # takes the inverse of sqrt of diag.
+    return D * c * D
+
+def compute_detected_undected_h2(inittable_path): 
+
+    phenoL = pd.read_hdf(inittable_path, "PhenoList")
+    gen_cov = pd.read_hdf(inittable_path, "GEN_COV")
+    region = pd.read_hdf(inittable_path, "Regions")
+
+    combi_c = list(phenoL.index)
+    reg_start=0
+    reg_end=50
+
+    chunk_size=50
+    Nchunk = region.shape[0] // chunk_size + 1
+    start_value = 0
+    
+    zscore_threshold = 5.452
+    h2_GW = np.zeros(len(combi_c))
+
+    for chunk in range(start_value, Nchunk):
+        print(chunk)
+        binf = chunk * chunk_size
+        bsup = (chunk + 1) * chunk_size
+
+        init_extract = pd.read_hdf(inittable_path, "SumStatTab", where= "Region >= {0} and Region < {1}".format(reg_start, reg_end))
+
+        init_extract[combi_c] = init_extract[combi_c].abs()
+        max_zscore = init_extract[["Region"] + combi_c].groupby("Region").max()
+
+        Neff_term = np.ones(max_zscore.shape)
+        Neff_term = Neff_term* (1/phenoL.loc[combi_c, "Effective_sample_size"].values)
+
+        beta_2 = max_zscore.mask((max_zscore < zscore_threshold)) 
+        beta_2 = beta_2 * np.sqrt(Neff_term)
+        beta_2 = beta_2.mask( (beta_2 > 0.019))
+
+        h2_GW += (beta_2**2).sum()
+
+    h2 = np.diag(gen_cov.loc[combi_c, combi_c])
+    undetected_h2 = ((h2 - h2_GW) / h2)
+
+    phenoL["h2"] = h2
+    phenoL["h2_GW"] = h2_GW
+    phenoL["undetected_h2_perc"] = undetected_h2
+    
+    return phenoL
+
+def add_h2_to_pheno_description(inittable_path):
+    phenoL_before = pd.read_hdf(inittable_path, "PhenoList")
+    if "avg_perc_h2_diff_region" in phenoL_before.columns:    
+        phenoL = compute_detected_undected_h2(inittable_path)
+        phenoL.to_hdf(inittable_path, key="table")
+
+
+def compute_mean_cov(cov, combi_c):
+    rows, cols = np.indices(cov.loc[combi_c, combi_c].shape)
+    mean_gencov = cov.loc[combi_c, combi_c].where(rows != cols).stack().abs().mean()
+    return mean_gencov
+
+def compute_diff_cor(res_cov, gen_cov, combi_c):
+    res_cor = cov2cor(res_cov.loc[combi_c, combi_c])
+    gen_cor = cov2cor(gen_cov.loc[combi_c, combi_c])
+    rows, cols = np.indices(res_cor.loc[combi_c, combi_c].shape)
+    off_gencor = res_cor.where(rows != cols).stack()
+    off_rescor = gen_cor.where(rows != cols).stack()
+    return (off_gencor - off_rescor).abs().mean()
+
+def compute_mean_undetected_h2(phenoL, combi_c):
+    mean_h2 = np.mean(phenoL.loc[combi_c, "undetected_h2_perc"])
+    return mean_h2
+
+def compute_mean_h2(phenoL, combi_c):
+    mean_h2 = np.mean(phenoL.loc[combi_c, "h2"])
+    return mean_h2
+
+def compute_mean_Neff(phenoL, combi_c):
+    mean_neff = np.mean(phenoL.loc[combi_c, "Effective_sample_size"])
+    return mean_neff
+
+
+#beta_2_GW = beta_2_GW * (1/phenoL.loc[combi_c, "Effective_sample_size"].values)
+def create_features(inittable_path, combi_file):
+    add_h2_to_pheno_description(inittable_path)
+
+    phenoL = pd.read_hdf(inittable_path, "PhenoList")
+    gen_cov = pd.read_hdf(inittable_path, "GEN_COV")
+    res_cov = pd.read_hdf(inittable_path, "COV")
+
+    combi = pd.read_csv(combi_file, sep=",", index_col=0, names=["combi"])
+    combi = list(combi.combi.str.split(" "))
+
+    D = {'traits':[] ,'k':[], 'avg_distance_cor': [], 'mean_gencov': [], 
+        'avg_Neff':[], 'avg_h2':[], 'avg_perc_h2_diff_region':[]}
+
+    for c in combi:
+        D['traits'].append(str(c))
+        D["k"].append(len(c))
+        D["avg_distance_cor"].append(compute_diff_cor(res_cov, gen_cov, c))
+        D["mean_gencov"].append(compute_mean_cov(gen_cov, c))
+        D["avg_Neff"].append(compute_mean_Neff(phenoL, c))
+
+        D["avg_h2"].append(compute_mean_h2(phenoL, c))
+        D["avg_perc_h2_diff_region"].append(compute_mean_undetected_h2(phenoL, c))
+
+    return pd.DataFrame.from_dict(D)
+