diff --git a/jass/models/plots.py b/jass/models/plots.py index 3953d0f3419db0e4bbebd31e675a7d54a0be2991..ef60219ef9e54f91a5016d74d3c6851f78f69a2d 100644 --- a/jass/models/plots.py +++ b/jass/models/plots.py @@ -17,10 +17,10 @@ from matplotlib import colors import matplotlib.patches as mpatches from scipy.stats import norm, chi2 import seaborn as sns -import os from pandas import DataFrame, read_hdf import pandas as pd +default_chunk_size=50 def replaceZeroes(df): """ @@ -32,30 +32,34 @@ def replaceZeroes(df): df.values[df.values == 0] = min_nonzero return df -def create_global_plot(work_file_path: str, global_plot_path: str): + +def get_info_4_global_plot(work_file_path: str): + regions = read_hdf(work_file_path, "Regions",columns=['Region','CHR','MiddlePosition']) + print(regions.dtypes) + N_reg = regions.Region.max() # Keep biggest element in Region column + binf = regions.Region.iloc[0] + chr_considered = regions.CHR.unique() + length_chr = regions.groupby("CHR").MiddlePosition.max() / 10 ** 6 + length_chr.loc[0] = 0 + return N_reg,binf,chr_considered,length_chr + +def create_global_plot(work_file_path: str, global_plot_path: str, chunk_size:int =default_chunk_size): """ create_global_plot generate genome-wide manhattan plot for a given set of phenotypes """ - regions = read_hdf(work_file_path, "Regions") - chr_length = regions.groupby('CHR').max().position - N_reg= regions.Region.max() + N_reg,binf,chr_considered,length_chr=get_info_4_global_plot(work_file_path) maxy = 0 fig = plt.figure(figsize=(30, 12)) ax = fig.add_subplot(111) - chunk_size = 50 colors = [ '#4287f5', 'orangered' ] - binf=regions.Region.iloc[0] - bsup= binf+chunk_size - chr_considered= regions.CHR.unique() - length_chr = regions.groupby("CHR").MiddlePosition.max() / 10**6 - length_chr.loc[0] = 0 + label = "Chr"+length_chr.loc[chr_considered].index.astype("str") lab_pos = length_chr.loc[chr_considered]/2 @@ -63,7 +67,7 @@ def create_global_plot(work_file_path: str, global_plot_path: str): pos_shift.index = pos_shift.index +1 pos_shift.loc[chr_considered[0]] = 0 lab_pos = lab_pos + [pos_shift.loc[i] for i in chr_considered] - + bsup = binf + chunk_size while binf < N_reg: df = read_hdf(work_file_path, "SumStatTab", columns=["CHR","position", 'JASS_PVAL', "Region"], where = "Region >= {0} and Region < {1}".format(binf, bsup)) binf+= chunk_size diff --git a/jass/test/expected_graphs/expected_global_plot.png b/jass/test/expected_graphs/expected_global_plot.png new file mode 100644 index 0000000000000000000000000000000000000000..a1896d6c7795ac13c1aa6e542cf41f7246e43a08 Binary files /dev/null and b/jass/test/expected_graphs/expected_global_plot.png differ diff --git a/jass/test/test_plots.py b/jass/test/test_plots.py index ad0012abe26c219928e091f836c1a26e87034b9f..99d12721b4baa6fe3789b16b8d06b819ce360235 100644 --- a/jass/test/test_plots.py +++ b/jass/test/test_plots.py @@ -3,8 +3,10 @@ from __future__ import absolute_import import os, shutil, tempfile from pathlib import Path - +import matplotlib as plt from jass.models import plots +from PIL import Image +import numpy as np from . import JassTestCase @@ -12,10 +14,13 @@ from . import JassTestCase class TestPlots(JassTestCase): test_folder = "data_real" + expected_res_folder = "expected_graphs" + #expected_res_folder="baseline_images/test_plot" def setUp(self): # Create a temporary directory self.test_dir = Path(tempfile.mkdtemp()) + self.ref_res_dir=Path(os.path.join(os.path.dirname(os.path.abspath(__file__)), self.expected_res_folder)) self.worktable_hdf_path = self.get_file_path_fn("worktable-withnans.hdf5") def tearDown(self): @@ -24,7 +29,15 @@ class TestPlots(JassTestCase): pass def test_create_global_plot(self): + #import shutil + #print(plt.rcParams) plots.create_global_plot(self.worktable_hdf_path, self.test_dir / "global_plot.png") + img_new=Image.open(self.test_dir /"global_plot.png") + img_ref=Image.open(self.ref_res_dir / "expected_global_plot.png") + sum_sq_diff = np.sum((np.asarray(img_new).astype('float') - np.asarray(img_ref).astype('float')) ** 2) + print("sum_sq_diff=",sum_sq_diff) + assert(sum_sq_diff==0.0) + def test_create_qq_plot(self): plots.create_qq_plot(self.worktable_hdf_path, self.test_dir / "qq_plot.png")