Commit c71bca13 authored by Blaise Li
Browse files

Still fixing tex-compatibility, square MA-plots.

parent 084f5022
from .libhts import do_deseq2, gtf_2_genes_exon_lengths, median_ratio_to_pseudo_ref_size_factors, plot_boxplots, plot_counts_distribution, plot_lfc_distribution, plot_MA, plot_norm_correlations, plot_paired_scatters, plot_scatter, repeat_bed_2_lengths, size_factor_correlations, spikein_gtf_2_lengths, status_setter
from .libhts import (
do_deseq2, gtf_2_genes_exon_lengths, median_ratio_to_pseudo_ref_size_factors,
plot_boxplots, plot_counts_distribution, plot_histo,
plot_lfc_distribution, plot_MA,
plot_norm_correlations, plot_paired_scatters, plot_scatter,
repeat_bed_2_lengths, size_factor_correlations,
spikein_gtf_2_lengths, status_setter)
from math import floor, ceil
from functools import reduce
from re import sub
import warnings
from libworkflows import texscape
def formatwarning(message, category, filename, lineno, line):
......@@ -326,8 +328,7 @@ def plot_norm_correlations(correlations):
#ax2.set_xlabel("Pearson correlation coefficient")
usetex = mpl.rcParams.get("text.usetex", False)
if usetex:
correlations.columns = [sub(
"_", "\_", colname) for colname in correlations.columns]
correlations.columns = [texscape(colname) for colname in correlations.columns]
axis = sns.violinplot(data=correlations, cut=0)
axis.set_ylabel("Pearson correlation coefficient")
......@@ -339,8 +340,8 @@ def plot_counts_distribution(data, xlabel):
#axis.legend(ncol=len(REPS))
usetex = mpl.rcParams.get("text.usetex", False)
if usetex:
xlabel = sub("_", r"\_", xlabel)
data.columns = [sub("_", "\_", colname) for colname in data.columns]
xlabel = texscape(xlabel)
data.columns = [texscape(colname) for colname in data.columns]
try:
axis = data.plot.kde()
except ValueError as e:
......@@ -352,13 +353,43 @@ def plot_counts_distribution(data, xlabel):
axis.set_xlabel(xlabel)
def plot_histo(outfile, data, title=None):
    """Draw a bar chart from the rows of *data* and save it to *outfile*.

    One ``bar`` call is issued per row of *data*: the row's index value is
    used as the x position (labelled "read length") and the row's values as
    the bar heights (labelled "number of reads").

    :param outfile: destination handed to ``savefig``.
    :param data: :class:`pandas.DataFrame` whose index supplies the x
        positions and tick labels. The index must be non-empty and support
        ``- 0.5`` / ``+ 0.5`` arithmetic (it is used for the x limits).
    :param title: optional figure title; nothing is set when it is ``None``.

    When matplotlib's ``text.usetex`` option is on, the column names and the
    title are passed through ``texscape`` so that they are TeX-compatible.
    """
    fig = plt.figure(figsize=(15, 7))
    ax = fig.add_subplot(111)
    # Leave half a unit of margin before the first and after the last bar.
    ax.set_xlim([data.index[0] - 0.5, data.index[-1] + 0.5])
    bar_width = 0.8
    usetex = mpl.rcParams.get("text.usetex", False)
    if usetex:
        # NOTE: this relabels the columns of the caller's DataFrame in place,
        # as the other plotting helpers of this module do.
        data.columns = [texscape(colname) for colname in data.columns]
        # The title is optional: only escape it when one was given
        # (escaping None would raise a TypeError).
        if title is not None:
            title = texscape(title)
    for (read_len, count) in data.iterrows():
        ax.bar(
            read_len,
            count,
            align="center",
            width=bar_width)
    # NOTE(review): the bars are drawn without labels, so this legend is
    # presumably empty -- confirm whether it is still wanted.
    ax.legend()
    ax.set_xticks(data.index)
    ax.set_xticklabels(data.index)
    ax.set_xlabel("read length")
    ax.set_ylabel("number of reads")
    plt.setp(ax.get_xticklabels(), rotation=90)
    if title is not None:
        ax.set_title(title)
    fig.savefig(outfile)
def plot_boxplots(data, ylabel):
fig = plt.figure(figsize=(6, 12))
axis = fig.add_subplot(111)
usetex = mpl.rcParams.get("text.usetex", False)
if usetex:
ylabel = sub("_", r"\_", ylabel)
data.columns = [sub("_", "\_", colname) for colname in data.columns]
ylabel = texscape(ylabel)
data.columns = [texscape(colname) for colname in data.columns]
data.plot.box(ax=axis)
axis.set_ylabel(ylabel)
for label in axis.get_xticklabels():
......@@ -426,7 +457,7 @@ def plot_lfc_distribution(res, contrast, fold_type=None):
lfc.name = contrast
usetex = mpl.rcParams.get("text.usetex", False)
if usetex:
lfc.columns = [sub("_", "\_", colname) for colname in lfc.columns]
lfc.columns = [texscape(colname) for colname in lfc.columns]
axis = sns.kdeplot(lfc)
axis.set_xlabel(fold_type)
axis.set_ylabel("frequency")
......@@ -452,29 +483,50 @@ def plot_MA(res,
# Make a column indicating whether the gene is DE or NS
data = res.assign(is_DE=res.apply(set_de_status, axis=1))
x_column = "baseMean"
data = data.assign(logx=np.log10(data[x_column]))
if fold_type is None:
y_column = "log2FoldChange"
else:
y_column = fold_type
usetex = mpl.rcParams.get("text.usetex", False)
def scatter_group(group, label, colour, size=1):
"""Plots the data in *group* on the scatterplot."""
if usetex:
label = texscape(label)
group.plot.scatter(
#x=x_column,
x="logx",
y=y_column,
s=size,
#logx=True,
c=colour,
label=label, ax=ax)
if usetex:
data.columns = [texscape(colname) for colname in data.columns]
y_column = texscape(y_column)
de_status_column = "is\_DE"
else:
de_status_column = "is_DE"
# First plot the data in grey and black
for de_status, group in data.groupby("is_DE"):
group.plot.scatter(x=x_column, y=y_column, s=2, logx=True, c=DE2COLOUR[de_status], label=f"{de_status} ({len(group)})", ax=ax)
for (de_status, group) in data.groupby(de_status_column):
label = f"{de_status} ({len(group)})"
colour = DE2COLOUR[de_status]
scatter_group(group, label, colour, size=2)
if grouping is not None:
if isinstance(grouping, str):
# Overlay colours based on the "grouping" column
if group2colour is None:
group2colour = STATUS2COLOUR
for status, group in data.groupby(grouping):
group.plot.scatter(
x=x_column, y=y_column, s=1, logx=True, c=group2colour[status],
label=f"{status} ({len(group)})", ax=ax)
label = f"{status} ({len(group)})"
colour = group2colour[status]
scatter_group(group, label, colour)
else:
(status, colour) = group2colour
row_indices = data.index.intersection(grouping)
try:
data.ix[row_indices].plot.scatter(
x=x_column, y=y_column, s=1, logx=True, c=colour,
label=f"{status} ({len(row_indices)})", ax=ax)
label=f"{status} ({len(row_indices)})"
scatter_group(data.ix[row_indices], label, colour)
except ValueError as e:
if str(e) != "scatter requires x column to be numeric":
print(data.ix[row_indices])
......@@ -485,7 +537,8 @@ def plot_MA(res,
ax.axhline(y=-1, linewidth=0.5, color="0.5", linestyle="dashed")
# TODO: check data basemean range
if mean_range is not None:
ax.set_xlim(mean_range)
# ax.set_xlim(mean_range)
ax.set_xlim(np.log10(mean_range))
if lfc_range is not None:
(lfc_min, lfc_max) = lfc_range
lfc_here_min = getattr(data, y_column).min()
......@@ -494,6 +547,15 @@ def plot_MA(res,
warnings.warn(f"Cannot plot {y_column} data ([{lfc_here_min}, {lfc_here_max}]) in requested range ([{lfc_min}, {lfc_max}])\n")
else:
ax.set_ylim(lfc_range)
# https://stackoverflow.com/a/24867320/1878788
x_ticks = np.arange(
floor(np.ma.masked_invalid(data["logx"]).min()),
ceil(np.ma.masked_invalid(data["logx"]).max()),
1)
x_ticklabels = [r"$10^{{{}}}$".format(tick) for tick in x_ticks]
plt.xticks(x_ticks, x_ticklabels)
ax.set_xlabel(x_column)
def plot_scatter(data,
x_column,
......@@ -579,8 +641,8 @@ def plot_paired_scatters(data, columns=None, hue=None, log_log=False):
columns = data.columns
usetex = mpl.rcParams.get("text.usetex", False)
if usetex:
data.columns = [sub("_", "\_", colname) for colname in data.columns]
columns = [sub("_", "\_", colname) for colname in columns]
data.columns = [texscape(colname) for colname in data.columns]
columns = [texscape(colname) for colname in columns]
g = sns.PairGrid(data, vars=columns, hue=hue, size=8)
#g.map_offdiag(plt.scatter, marker=".")
g.map_lower(plt.scatter, marker=".")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment