diff --git a/jass/models/plots.py b/jass/models/plots.py
index 3953d0f3419db0e4bbebd31e675a7d54a0be2991..ef60219ef9e54f91a5016d74d3c6851f78f69a2d 100644
--- a/jass/models/plots.py
+++ b/jass/models/plots.py
@@ -17,10 +17,10 @@ from matplotlib import colors
 import matplotlib.patches as mpatches
 from scipy.stats import norm, chi2
 import seaborn as sns
-import os
 from pandas import DataFrame, read_hdf
 import pandas as pd
 
+default_chunk_size=50
 
 def replaceZeroes(df):
     """
@@ -32,30 +32,34 @@ def replaceZeroes(df):
     df.values[df.values == 0] = min_nonzero
     return df
 
-def create_global_plot(work_file_path: str, global_plot_path: str):
+
+def get_info_4_global_plot(work_file_path: str):
+    regions = read_hdf(work_file_path, "Regions",columns=['Region','CHR','MiddlePosition'])
+    print(regions.dtypes)
+    N_reg = regions.Region.max()  # Keep biggest element in Region column
+    binf = regions.Region.iloc[0]
+    chr_considered = regions.CHR.unique()
+    length_chr = regions.groupby("CHR").MiddlePosition.max() / 10 ** 6
+    length_chr.loc[0] = 0
+    return N_reg,binf,chr_considered,length_chr
+
+def create_global_plot(work_file_path: str, global_plot_path: str, chunk_size:int =default_chunk_size):
     """
     create_global_plot
     generate genome-wide manhattan plot for a given set of phenotypes
     """
 
-    regions = read_hdf(work_file_path, "Regions")
-    chr_length = regions.groupby('CHR').max().position
-    N_reg= regions.Region.max()
+    N_reg,binf,chr_considered,length_chr=get_info_4_global_plot(work_file_path)
     maxy = 0
 
     fig = plt.figure(figsize=(30, 12))
     ax = fig.add_subplot(111)
 
-    chunk_size = 50
     colors = [
         '#4287f5',
         'orangered'
         ]
-    binf=regions.Region.iloc[0]
-    bsup= binf+chunk_size
-    chr_considered= regions.CHR.unique()
-    length_chr = regions.groupby("CHR").MiddlePosition.max() / 10**6
-    length_chr.loc[0] = 0
+
     label = "Chr"+length_chr.loc[chr_considered].index.astype("str")
 
     lab_pos = length_chr.loc[chr_considered]/2
@@ -63,7 +67,7 @@ def create_global_plot(work_file_path: str, global_plot_path: str):
     pos_shift.index = pos_shift.index +1
     pos_shift.loc[chr_considered[0]] = 0
     lab_pos = lab_pos + [pos_shift.loc[i] for i in chr_considered]
-
+    bsup = binf + chunk_size
     while binf < N_reg:
         df = read_hdf(work_file_path, "SumStatTab", columns=["CHR","position", 'JASS_PVAL', "Region"], where = "Region >= {0} and Region < {1}".format(binf, bsup))
         binf+= chunk_size
diff --git a/jass/test/expected_graphs/expected_global_plot.png b/jass/test/expected_graphs/expected_global_plot.png
new file mode 100644
index 0000000000000000000000000000000000000000..a1896d6c7795ac13c1aa6e542cf41f7246e43a08
Binary files /dev/null and b/jass/test/expected_graphs/expected_global_plot.png differ
diff --git a/jass/test/test_plots.py b/jass/test/test_plots.py
index ad0012abe26c219928e091f836c1a26e87034b9f..99d12721b4baa6fe3789b16b8d06b819ce360235 100644
--- a/jass/test/test_plots.py
+++ b/jass/test/test_plots.py
@@ -3,8 +3,10 @@
 from __future__ import absolute_import
 import os, shutil, tempfile
 from pathlib import Path
-
+import matplotlib as plt
 from jass.models import plots
+from PIL import Image
+import numpy as np
 
 from . import JassTestCase
 
@@ -12,10 +14,13 @@ from . import JassTestCase
 class TestPlots(JassTestCase):
 
     test_folder = "data_real"
+    expected_res_folder = "expected_graphs"
+    #expected_res_folder="baseline_images/test_plot"
 
     def setUp(self):
         # Create a temporary directory
         self.test_dir = Path(tempfile.mkdtemp())
+        self.ref_res_dir=Path(os.path.join(os.path.dirname(os.path.abspath(__file__)), self.expected_res_folder))
         self.worktable_hdf_path = self.get_file_path_fn("worktable-withnans.hdf5")
 
     def tearDown(self):
@@ -24,7 +29,15 @@ class TestPlots(JassTestCase):
         pass
 
     def test_create_global_plot(self):
+        #import shutil
+        #print(plt.rcParams)
         plots.create_global_plot(self.worktable_hdf_path, self.test_dir / "global_plot.png")
+        img_new=Image.open(self.test_dir /"global_plot.png")
+        img_ref=Image.open(self.ref_res_dir / "expected_global_plot.png")
+        sum_sq_diff = np.sum((np.asarray(img_new).astype('float') - np.asarray(img_ref).astype('float')) ** 2)
+        print("sum_sq_diff=",sum_sq_diff)
+        assert(sum_sq_diff==0.0)
+
 
     def test_create_qq_plot(self):
         plots.create_qq_plot(self.worktable_hdf_path, self.test_dir / "qq_plot.png")