Commit c04ad30a authored by Blaise Li
Browse files

Log and save selection of lfc-lfc plots.

parent 1d91b2cc
#!/usr/bin/env python3 #!/usr/bin/env python3
# vim: set fileencoding=<utf-8> : # vim: set fileencoding=<utf-8> :
"""This script reads data from "tidy" files and makes plots out of """This script reads data from "tidy" files and makes plots out of
it, at the same scale.""" it, at the same scale.
It also outputs a table containing the plotted data points."""
import argparse import argparse
import os import os
import sys import sys
import warnings import warnings
from operator import attrgetter, contains
# from functools import partial
import matplotlib as mpl import matplotlib as mpl
# To be able to run the script without a defined $DISPLAY # To be able to run the script without a defined $DISPLAY
# mpl.use("PDF") # mpl.use("PDF")
from matplotlib.backends.backend_pgf import FigureCanvasPgf from matplotlib.backends.backend_pgf import FigureCanvasPgf
import pandas as pd import pandas as pd
from cytoolz import concat from cytoolz import compose, concat, curry
from libhts import plot_scatter from libhts import plot_scatter
from libworkflows import save_plot, strip_split from libworkflows import save_plot, strip_split
...@@ -86,8 +89,7 @@ class Scatterplot: ...@@ -86,8 +89,7 @@ class Scatterplot:
y_input_file, y_input_file,
x_column, x_column,
y_column, y_column,
x_label, labels,
y_label,
extra_cols=None): extra_cols=None):
if extra_cols is None: if extra_cols is None:
x_usecols = ["gene", x_column].__contains__ x_usecols = ["gene", x_column].__contains__
...@@ -127,8 +129,7 @@ class Scatterplot: ...@@ -127,8 +129,7 @@ class Scatterplot:
self.grouping_col = "_".join(extra_cols) self.grouping_col = "_".join(extra_cols)
else: else:
self.grouping_col = None self.grouping_col = None
self.x_label = x_label (self.x_label, self.y_label) = labels
self.y_label = y_label
def apply_selector(self, selector, chose_from=None): def apply_selector(self, selector, chose_from=None):
"""Returns a list of gene ids based on a *selector* string. """Returns a list of gene ids based on a *selector* string.
...@@ -142,8 +143,7 @@ class Scatterplot: ...@@ -142,8 +143,7 @@ class Scatterplot:
if chose_from: if chose_from:
return [gene_id for gene_id in self.data.query(selector).index return [gene_id for gene_id in self.data.query(selector).index
if gene_id in chose_from] if gene_id in chose_from]
else: return [gene_id for gene_id in self.data.query(selector).index]
return [gene_id for gene_id in self.data.query(selector).index]
def plot_maker(self, grouping=None, group2colour=None, **kwargs): def plot_maker(self, grouping=None, group2colour=None, **kwargs):
"""Builds a plotting function that can colour dots based on them """Builds a plotting function that can colour dots based on them
...@@ -153,12 +153,19 @@ class Scatterplot: ...@@ -153,12 +153,19 @@ class Scatterplot:
*save_plot* can include it in the bounding box.""" *save_plot* can include it in the bounding box."""
# fig, axis = plot_scatter( # fig, axis = plot_scatter(
# print(kwargs["x_range"]) # print(kwargs["x_range"])
axis = plot_scatter( try:
self.data, axis = plot_scatter(
"x", "y", self.data,
grouping=grouping, "x", "y",
group2colour=group2colour, grouping=grouping,
**kwargs) group2colour=group2colour,
**kwargs)
except ValueError as err:
if str(err) == "No data to plot.":
warnings.warn("No data to plot.")
return None
else:
raise
# Lines indicating 2-fold threshold. # Lines indicating 2-fold threshold.
# Assumes the data are in log2 fold changes # Assumes the data are in log2 fold changes
line_style = { line_style = {
...@@ -281,8 +288,7 @@ class Scatterplot: ...@@ -281,8 +288,7 @@ class Scatterplot:
bbox_to_anchor=(0, 1), bbox_to_anchor=(0, 1),
bbox_transform=axis.transAxes, loc="lower left") bbox_transform=axis.transAxes, loc="lower left")
return legend, return legend,
else: return None
return None
# Not working: # Not working:
# fig.tight_layout() # fig.tight_layout()
# strange behaviour (interaction with tight_layout?) # strange behaviour (interaction with tight_layout?)
...@@ -351,10 +357,10 @@ def main(): ...@@ -351,10 +357,10 @@ def main():
parser.add_argument( parser.add_argument(
"--data_range", "--data_range",
help="min and max of the data values to display. " help="min and max of the data values to display. "
"If the range is to narrow to plot data, it will be ignored.", "If the range is too narrow to plot data, it will be ignored.",
type = int, type=int,
nargs = 2, nargs=2,
default = [-12, 12]) default=[-12, 12])
parser.add_argument( parser.add_argument(
"--extra_cols", "--extra_cols",
help="Columns containing categorical information " help="Columns containing categorical information "
...@@ -405,8 +411,7 @@ def main(): ...@@ -405,8 +411,7 @@ def main():
args.y_input_data, args.y_input_data,
args.x_column, args.x_column,
args.y_column, args.y_column,
args.x_label, (args.x_label, args.y_label),
args.y_label,
extra_cols=args.extra_cols) extra_cols=args.extra_cols)
# if args.transform == "log2": # if args.transform == "log2":
# transform = 2 # transform = 2
...@@ -425,14 +430,29 @@ def main(): ...@@ -425,14 +430,29 @@ def main():
assert args.selection_label, msg assert args.selection_label, msg
gene_list = plot_data.apply_selector(args.selector, base_gene_list) gene_list = plot_data.apply_selector(args.selector, base_gene_list)
else: else:
gene_list = base_gene_list gene_list = list(plot_data.data.index.intersection(base_gene_list))
if args.plot_name: if args.plot_name:
# https://stackoverflow.com/a/14364249/1878788 # https://stackoverflow.com/a/14364249/1878788
os.makedirs(args.plot_dir, exist_ok=True) os.makedirs(args.plot_dir, exist_ok=True)
out_pdf = OPJ( out_pdf = OPJ(
args.plot_dir, args.plot_dir,
"%s.pdf" % args.plot_name) "%s.pdf" % args.plot_name)
out_table = OPJ(
args.plot_dir,
"%s.tsv" % args.plot_name)
out_log = OPJ(
args.plot_dir,
"%s.log" % args.plot_name)
with open(out_log, "w") as log_file:
print(" \\\n\t".join(sys.argv), file=log_file)
if gene_list and args.selection_label: if gene_list and args.selection_label:
plot_data.data.assign(hightlighted=plot_data.data.apply(
# apply takes a function of row
# get the row name
# check whether it belongs to gene_list
compose(curry(contains)(gene_list), attrgetter("name")),
axis=1)).to_csv(
out_table, sep="\t")
list_name = args.selection_label list_name = args.selection_label
# if args.gene_list: # if args.gene_list:
# if args.selection_label: # if args.selection_label:
...@@ -455,6 +475,7 @@ def main(): ...@@ -455,6 +475,7 @@ def main():
"linewidth": 0.5, "color": "0.5", "linestyle": "-"}, "linewidth": 0.5, "color": "0.5", "linestyle": "-"},
regression=args.plot_regression) regression=args.plot_regression)
else: else:
plot_data.data.to_csv(out_table, sep="\t")
plot_data.save_plot( plot_data.save_plot(
# args.x_axis, args.y_axis, # args.x_axis, args.y_axis,
out_pdf, out_pdf,
......
Supports Markdown
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment