Select Git revision
plot_lfclfc_scatter.py
plot_lfclfc_scatter.py 19.33 KiB
#!/usr/bin/env python3
# vim: set fileencoding=utf-8 :
"""This script reads data from "tidy" files and makes plots out of
it, at the same scale.
It also outputs a table containing the plotted data points."""
import argparse
import os
import sys
import warnings
from operator import attrgetter, contains
# from functools import partial
import matplotlib as mpl
# To be able to run the script without a defined $DISPLAY
# mpl.use("PDF")
from matplotlib.backends.backend_pgf import FigureCanvasPgf
import pandas as pd
from cytoolz import compose, concat, curry
from libhts import plot_scatter
from libworkflows import save_plot, strip_split
def formatwarning(message, category, filename, lineno, line=None):
    """Format a warning as a single ``file:line: Category: message`` line.

    The signature mirrors that of the standard *warnings.formatwarning*,
    including the optional *line* argument (the source line), which is
    accepted but deliberately left out of the output.

    Returns the formatted text, terminated by a newline.
    """
    return f"{filename}:{lineno}: {category.__name__}: {message}\n"
# Use the compact one-line format for every warning emitted by the script.
warnings.formatwarning = formatwarning
# Short aliases for frequently used path helpers.
OPJ = os.path.join
OPB = os.path.basename
# Have "pdf" output rendered by the PGF backend.
# https://stackoverflow.com/a/42768093/1878788
mpl.backend_bases.register_backend('pdf', FigureCanvasPgf)
TEX_PARAMS = {
    "text.usetex": True,  # use LaTeX to write all text
    "pgf.rcfonts": False,  # Ignore Matplotlibrc
    "pgf.texsystem": "lualatex",  # hoping to avoid memory issues
    # The preamble is given as a single string: recent Matplotlib versions
    # no longer accept a list of strings for this parameter.
    # The "color" package provides the \textcolor command used to colour
    # the in-group counts written on the plot.
    "pgf.preamble": r'\usepackage{color}',
}
mpl.rcParams.update(TEX_PARAMS)
# mpl.rcParams["figure.figsize"] = 2, 4
# Sans-serif fonts, tried in order of preference.
mpl.rcParams["font.sans-serif"] = [
    "Arial", "Liberation Sans", "Bitstream Vera Sans"]
mpl.rcParams["font.family"] = "sans-serif"
# from matplotlib import numpy as np
# import matplotlib.pyplot as plt
# https://jakevdp.github.io/blog/2014/10/16/how-bad-is-your-colormap/
# def grayify_cmap(cmap):
# """Return a grayscale version of the colormap"""
# cmap = plt.cm.get_cmap(cmap)
# colors = cmap(np.arange(cmap.N))
# # convert RGBA to perceived greyscale luminance
# # cf. http://alienryderflex.com/hsp.html
# rgb_weight = [0.299, 0.587, 0.114]
# luminance = np.sqrt(np.dot(colors[:, :3] ** 2, rgb_weight))
# colors[:, :3] = luminance[:, np.newaxis]
# return cmap.from_list(cmap.name + "_grayscale", colors, cmap.N)
def get_gene_list(filename):
    """Read gene identifiers from the file at path *filename*.

    The list name is the base name of the file, without its extension
    and without a trailing "_ids" if there is one.
    For each line that is neither blank nor starting with "#", the gene
    identifier is the first element returned by *strip_split*.

    Returns a pair (list_name, gene_list).
    """
    (list_name, _) = os.path.splitext(OPB(filename))
    if list_name.endswith("_ids"):
        list_name = list_name[:-4]
    with open(filename, "r") as infile:
        # Iterate lazily over the file instead of loading all lines first.
        # Blank lines are skipped: they carry no identifier.
        gene_list = [
            strip_split(line)[0] for line in infile
            if line.strip() and not line.startswith("#")]
    return (list_name, gene_list)
class Scatterplot:
    """A 2-dimension scatterplot of values read from two tables.

    Instance attributes (see *__slots__*):

    * *data*: table indexed by gene, obtained by merging the x and y input
      tables; the plotted values are in columns "x" and "y".
    * *x_label*, *y_label*: labels of the axes.
    * *grouping_col*: name of the column of *data* used by default to group
      the points, or None if no extra columns were requested.
    """
    __slots__ = ("data", "x_label", "y_label", "grouping_col")

    @staticmethod
    def fuse_columns(row):
        """Joins text from columns in a pandas DataFrame row with underscores.
        Intended to be used with *DataFrame.apply*."""
        return "_".join(map(str, row.values))

    def __init__(self,
                 x_input_file,
                 y_input_file,
                 x_column,
                 y_column,
                 labels,
                 extra_cols=None):
        """Loads and merges the data to plot.

        *x_input_file* and *y_input_file* are tab-separated tables with a
        header line and a "gene" column, which is used as index.
        *x_column* and *y_column* are the names of the columns to plot;
        they are renamed "x" and "y" respectively.
        *labels* is a pair (x_label, y_label).
        *extra_cols* is an optional list of names of columns to load along
        with the plotted ones. None or an empty list both mean that there
        are no extra columns.

        The two tables are merged on their index with the default (inner)
        pandas merge, and gene identifiers must be unique in each table
        (*validate="one_to_one"*).

        Raises KeyError if a requested extra column is found in neither table.
        """
        # An empty list (for instance "--extra_cols" without values)
        # is handled like the absence of extra columns.
        extra_cols = list(extra_cols) if extra_cols else []
        x_usecols = ["gene", x_column, *extra_cols].__contains__
        y_usecols = ["gene", y_column, *extra_cols].__contains__
        x_data = pd.read_csv(
            x_input_file, sep="\t", index_col="gene", usecols=x_usecols).rename(
                columns={x_column: "x"})
        y_data = pd.read_csv(
            y_input_file, sep="\t", index_col="gene", usecols=y_usecols).rename(
                columns={y_column: "y"})
        self.data = pd.merge(
            x_data, y_data,
            left_index=True, right_index=True, validate="one_to_one")
        if extra_cols:
            # A column present in both tables appears twice after the merge,
            # with the "_x" and "_y" suffixes added by pandas.
            merged_cols = []
            for colname in extra_cols:
                if colname in self.data.columns:
                    merged_cols.append(colname)
                else:
                    merged_cols.extend([f"{colname}_x", f"{colname}_y"])
            missing = [
                colname for colname in merged_cols
                if colname not in self.data.columns]
            if missing:
                raise KeyError(
                    f"Extra column(s) not found in the input data: {missing}")
            grouping_col = "_".join(merged_cols)
            # With only one column, "_".join(merged_cols) == merged_cols[0]:
            # the existing column is used as is.
            if len(merged_cols) > 1:
                self.data = self.data.assign(**{
                    grouping_col: self.data[merged_cols].apply(
                        Scatterplot.fuse_columns, axis=1)})
            self.grouping_col = grouping_col
        else:
            self.grouping_col = None
        (self.x_label, self.y_label) = labels

    def apply_selector(self, selector, chose_from=None):
        """Returns a list of gene ids based on a *selector* string.

        The *selector* string will be used as a query string on *self.data*.
        If *chose_from* is not empty, gene ids will only be selected
        if they belong to *chose_from*."""
        selected = self.data.query(selector).index
        allowed = set() if chose_from is None else set(chose_from)
        if allowed:
            return [gene_id for gene_id in selected if gene_id in allowed]
        return list(selected)

    def plot_maker(self, grouping=None, group2colour=None, **kwargs):
        """Builds a plotting function that can colour dots based on them
        belonging to a group defined by *grouping*.

        *grouping*, *group2colour* and the extra keyword arguments are
        transmitted to *plot_scatter*. "x_range" and "y_range", if present
        among the keyword arguments, are also used to place the labels
        of the threshold lines.

        Returns a function taking no arguments."""
        def plot_lfclfc_scatter():
            """Generates the scatterplot, returns its legend so that
            *save_plot* can include it in the bounding box.

            Returns None if there was no data to plot or if no grouping
            was used, and a 1-tuple containing the legend otherwise."""
            try:
                axis = plot_scatter(
                    self.data,
                    "x", "y",
                    grouping=grouping,
                    group2colour=group2colour,
                    **kwargs)
            except ValueError as err:
                if str(err) == "No data to plot.":
                    warnings.warn("No data to plot.")
                    return None
                raise
            # Lines indicating 2-fold threshold.
            # Assumes the data are in log2 fold changes
            line_style = {
                "linewidth": 0.5, "color": "0.5", "linestyle": "dashed"}
            annot_colour = line_style["color"]
            # Labels of the threshold lines are placed at the lower end
            # of the displayed range, or of the data if no range is given.
            # Series.min ignores missing values, unlike the builtin min.
            if "x_range" in kwargs:
                (x_annot_loc, _) = kwargs["x_range"]
            else:
                x_annot_loc = self.data.x.min()
            if "y_range" in kwargs:
                (y_annot_loc, _) = kwargs["y_range"]
            else:
                y_annot_loc = self.data.y.min()
            # Horizontal thresholds, labelled above (y = 1) or below (y = -1)
            for (level, valign) in ((1, "bottom"), (-1, "top")):
                axis.axhline(y=level, **line_style)
                axis.annotate(
                    f"y = {level}", xy=(x_annot_loc, level), xycoords="data",
                    horizontalalignment="left",
                    verticalalignment=valign,
                    size="x-small",
                    color=annot_colour)
            # Vertical thresholds, labelled on their outer side
            for (level, halign) in ((1, "left"), (-1, "right")):
                axis.axvline(x=level, **line_style)
                axis.annotate(
                    f"x = {level}", xy=(level, y_annot_loc), xycoords="data",
                    horizontalalignment=halign,
                    verticalalignment="bottom",
                    rotation=90, size="x-small",
                    color=annot_colour)
            # Number of genes beyond lfc thresholds, in each quadrant.
            # If the grouping is an explicit list of genes, the number of
            # those genes found in the quadrant is added between
            # parentheses, in the colour of the group.
            if isinstance(grouping, (list, tuple)):
                try:
                    (_, colour) = group2colour
                except (ValueError, TypeError):
                    colour = "black"
                select_ingroup = pd.Index(grouping).intersection
            else:
                colour = None
                select_ingroup = None
            # (query, position in axes fraction,
            #  horizontal alignment, vertical alignment)
            quadrants = (
                ("x > 1 & y > 1", (0.95, 0.95), "right", "top"),
                ("x > 1 & y < -1", (0.95, 0.05), "right", "bottom"),
                ("x < -1 & y > 1", (0.05, 0.95), "left", "top"),
                ("x < -1 & y < -1", (0.05, 0.05), "left", "bottom"))
            for (query, position, halign, valign) in quadrants:
                in_quadrant = self.data.query(query)
                if select_ingroup is None:
                    ingroup = ""
                else:
                    ingroup = " (\\textcolor{%s}{%d})" % (
                        colour,
                        len(in_quadrant.loc[
                            select_ingroup(in_quadrant.index)]))
                axis.annotate(
                    f"{len(in_quadrant)}{ingroup}",
                    xy=position, xycoords="axes fraction",
                    size="x-small", color=annot_colour,
                    horizontalalignment=halign,
                    verticalalignment=valign)
            axis.set_xlabel(self.x_label, fontsize=17)
            axis.set_ylabel(self.y_label, fontsize=17)
            # Setting the axes limits from "x_range" and "y_range" here
            # doesn't work with plt.axis("equal")
            # TODO: force ticks to be integers
            if grouping is not None:
                # Legend placed above the top left corner of the axes
                legend = axis.legend(
                    bbox_to_anchor=(0, 1),
                    bbox_transform=axis.transAxes, loc="lower left")
                # Return a tuple of "extra artists",
                # to correctly define the bounding box
                return (legend,)
            return None
        return plot_lfclfc_scatter

    def save_plot(self, outfile, grouping=None, group2colour=None, **kwargs):
        """Creates the plotting function and transmits it for execution
        to the function that really does the saving.

        If *grouping* is None, *self.grouping_col* is used instead
        (when it is itself not None)."""
        if grouping is None and self.grouping_col is not None:
            grouping = self.grouping_col
        # TODO: How to have it square?
        equal_axes = True
        # This is the module-level save_plot function, not this method.
        save_plot(
            outfile,
            self.plot_maker(
                grouping=grouping, group2colour=group2colour, **kwargs),
            equal_axes=equal_axes,
            tight=True)
def main():
    """Main function of the program.

    Parses the command line, loads and merges the x and y data, determines
    which genes should be highlighted, then writes in the directory given
    with --plot_dir a log of the command (.log), a table of the plotted
    data (.tsv) and the scatterplot (.pdf), all named after --plot_name.

    Returns None, so that *sys.exit(main())* exits with status 0.
    """
    print(" ".join(sys.argv))
    parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument(
        "-x", "--x_input_data",
        help="File containing a saved pandas DataFrame.\n"
        "It should be tab-separated data with one header line.\n"
        "The index column should be named \"gene\"",
        required=True)
    parser.add_argument(
        "-y", "--y_input_data",
        help="File containing a saved pandas DataFrame.\n"
        "It should be tab-separated data with one header line.\n"
        "The index column should be named \"gene\"",
        required=True)
    parser.add_argument(
        "--x_column",
        help="Name of the column to use from the table given with "
        "--x_input_data. Default is based on DESeq2 results.",
        default="log2FoldChange")
    parser.add_argument(
        "--y_column",
        help="Name of the column to use from the table given with "
        "--y_input_data. Default is based on DESeq2 results.",
        default="log2FoldChange")
    parser.add_argument(
        "--x_label",
        help="Name to label the column selected for the x axis",
        required=True)
    parser.add_argument(
        "--y_label",
        help="Name to label the column selected for the y axis",
        required=True)
    parser.add_argument(
        "--data_range",
        help="min and max of the data values to display. "
        "If the range is too narrow to plot data, it will be ignored.",
        type=int,
        nargs=2,
        default=[-12, 12])
    parser.add_argument(
        "--extra_cols",
        help="Columns containing categorical information "
        "to be used to colour points.",
        nargs="*")
    parser.add_argument(
        "-d", "--plot_dir",
        required=True,
        help="Directory in which scatterplots should be written.")
    parser.add_argument(
        "-p", "--plot_name",
        help="Base name for the plot files.")
    parser.add_argument(
        "-g", "--gene_list",
        help="File containing identifiers of genes to highlight in colour.")
    # NOTE: this option is parsed but its value is not used below.
    parser.add_argument(
        "--gene_lists",
        help="Files containing identifiers of genes, "
        "so that they can be referenced in --selector.")
    parser.add_argument(
        "-s", "--selector",
        help="Pandas query string to select rows in the merged "
        "x and y input data. If a column name belongs to both "
        "the x and y data, append the _x or _y suffix to desambiguate.")
    parser.add_argument(
        "--selection_label",
        help="Label to use for the gene list given with --gene_list "
        "or obtained by applying the --selector option.")
    parser.add_argument(
        "-c", "--colour",
        default="red",
        help="Colour to use for the elements in the list given by option "
        "--gene_list or --selector.")
    parser.add_argument(
        "--plot_regression",
        help="Use this option to plot the regression line.",
        default=False,
        action="store_true")
    args = parser.parse_args()
    # Validate option combinations before doing any work.
    # parser.error is used rather than assert, which is disabled
    # when python runs with -O.
    if args.selector and not args.selection_label:
        parser.error(
            "A label should be associated to the query given with "
            "the --selector option. "
            "Use the --selection_label option to set this.")
    plot_data = Scatterplot(
        args.x_input_data,
        args.y_input_data,
        args.x_column,
        args.y_column,
        (args.x_label, args.y_label),
        extra_cols=args.extra_cols)
    if args.gene_list:
        (list_name, base_gene_list) = get_gene_list(args.gene_list)
    else:
        list_name = None
        base_gene_list = []
    if args.selector:
        # Genes matching the query, restricted to those of --gene_list
        # if such a list was given.
        gene_list = plot_data.apply_selector(args.selector, base_gene_list)
    else:
        gene_list = list(plot_data.data.index.intersection(base_gene_list))
    if not args.plot_name:
        # Output file names all derive from --plot_name.
        warnings.warn("No --plot_name given: no output will be written.")
        return None
    # https://stackoverflow.com/a/14364249/1878788
    os.makedirs(args.plot_dir, exist_ok=True)
    out_pdf = OPJ(args.plot_dir, "%s.pdf" % args.plot_name)
    out_table = OPJ(args.plot_dir, "%s.tsv" % args.plot_name)
    out_log = OPJ(args.plot_dir, "%s.log" % args.plot_name)
    # Keep a record of the command, one argument per line.
    with open(out_log, "w") as log_file:
        print(" \\\n\t".join(sys.argv), file=log_file)
    if gene_list:
        # Flag the rows whose gene belongs to gene_list.
        # The column name (including its spelling) is kept as is,
        # because it is part of the format of the output table.
        plot_data.data.assign(
            hightlighted=plot_data.data.index.isin(gene_list)).to_csv(
                out_table, sep="\t")
        if args.selection_label:
            # The explicit label takes precedence over the name
            # derived from the file given with --gene_list.
            list_name = args.selection_label
        plot_data.save_plot(
            out_pdf,
            grouping=gene_list, group2colour=(list_name, args.colour),
            x_range=args.data_range,
            y_range=args.data_range,
            axes_style={
                "linewidth": 0.5, "color": "0.5", "linestyle": "-"},
            regression=args.plot_regression)
    else:
        plot_data.data.to_csv(out_table, sep="\t")
        plot_data.save_plot(
            out_pdf,
            x_range=args.data_range,
            y_range=args.data_range,
            regression=args.plot_regression)
    return None
# Run the program only when the file is executed as a script,
# and use the value returned by main as exit status.
if __name__ == "__main__":
    exit_status = main()
    sys.exit(exit_status)