Skip to content
Snippets Groups Projects
Select Git revision
  • 74753080de9c4a801619f654fbfbcf3906b9c565
  • master default protected
2 results

plot_lfclfc_scatter.py

Blame
  • user avatar
    Blaise Li authored
    Also switched to read_csv to remove warnings.
    
    The latex processing fails when no X server is available; it seems that
    the xcolor package is not loaded.
    74753080
    History
    plot_lfclfc_scatter.py 19.33 KiB
    #!/usr/bin/env python3
    # vim: set fileencoding=<utf-8> :
    """This script reads data from "tidy" files and makes plots out of
    it, at the same scale.
    It also outputs a table containing the plotted data points."""
    
    import argparse
    import os
    import sys
    import warnings
    from operator import attrgetter, contains
    # from functools import partial
    import matplotlib as mpl
    # To be able to run the script without a defined $DISPLAY
    # mpl.use("PDF")
    from matplotlib.backends.backend_pgf import FigureCanvasPgf
    import pandas as pd
    from cytoolz import compose, concat, curry
    from libhts import plot_scatter
    from libworkflows import save_plot, strip_split
    
    
    def formatwarning(message, category, filename, lineno, line=None):
        """Format a warning as a single "file:lineno: Category: message" line.

        Meant to replace *warnings.formatwarning*.

        *message* is the warning text (or warning instance), *category* the
        warning class, *filename* and *lineno* the location it was issued from.
        *line* is accepted (and defaults to None, as in the standard library
        hook) only for signature compatibility: the source line is
        deliberately not echoed in the output.

        Returns the formatted text, terminated by a newline."""
        return f"{filename}:{lineno}: {category.__name__}: {message}\n"
    
    
    # Use the compact one-line format for every warning issued by this script.
    warnings.formatwarning = formatwarning

    # Short aliases for the path helpers used throughout the script.
    OPJ = os.path.join
    OPB = os.path.basename

    # Make the pgf canvas handle "pdf" output, so that text goes through LaTeX.
    # https://stackoverflow.com/a/42768093/1878788
    mpl.backend_bases.register_backend('pdf', FigureCanvasPgf)
    TEX_PARAMS = {
        "text.usetex": True,            # use LaTeX to write all text
        "pgf.rcfonts": False,           # Ignore Matplotlibrc
        "pgf.texsystem": "lualatex",  # hoping to avoid memory issues
        # The color package provides the \textcolor command used in the plot
        # annotations.
        # The preamble is given as a plain string: list values were deprecated
        # in matplotlib 3.3 and are rejected by later releases, whereas a
        # string is accepted by older and newer versions alike.
        "pgf.preamble": r'\usepackage{color}',
    }
    mpl.rcParams.update(TEX_PARAMS)
    # mpl.rcParams["figure.figsize"] = 2, 4
    # Preferred sans-serif fonts, in decreasing order of priority.
    mpl.rcParams["font.sans-serif"] = [
        "Arial", "Liberation Sans", "Bitstream Vera Sans"]
    mpl.rcParams["font.family"] = "sans-serif"
    
    
    # from matplotlib import numpy as np
    # import matplotlib.pyplot as plt
    # https://jakevdp.github.io/blog/2014/10/16/how-bad-is-your-colormap/
    # def grayify_cmap(cmap):
    #     """Return a grayscale version of the colormap"""
    #     cmap = plt.cm.get_cmap(cmap)
    #     colors = cmap(np.arange(cmap.N))
    #     # convert RGBA to perceived greyscale luminance
    #     # cf. http://alienryderflex.com/hsp.html
    #     rgb_weight = [0.299, 0.587, 0.114]
    #     luminance = np.sqrt(np.dot(colors[:, :3] ** 2, rgb_weight))
    #     colors[:, :3] = luminance[:, np.newaxis]
    #     return cmap.from_list(cmap.name + "_grayscale", colors, cmap.N)
    def get_gene_list(filename):
        """Reads gene identifiers from the file located at *filename*.

        The name of the list is the base name of the file, without its
        extension and without a trailing "_ids", if any.
        Lines starting with "#" are skipped. For the other lines, the
        identifier is the first element of what *strip_split* returns.

        Returns a pair (list_name, gene_list)."""
        suffix = "_ids"
        list_name = os.path.splitext(OPB(filename))[0]
        if list_name.endswith(suffix):
            list_name = list_name[:-len(suffix)]
        gene_list = []
        with open(filename, "r") as infile:
            for line in infile:
                if line.startswith("#"):
                    # Comment line: no identifier to extract
                    continue
                gene_list.append(strip_split(line)[0])
        return (list_name, gene_list)
    
    
    class Scatterplot:
        """A 2-dimension scatterplot comparing one column from each of two
        tab-separated tables that share a "gene" index column."""
        # data: merged table indexed by gene, with the plotted values in
        #     columns "x" and "y", plus the extra columns, if any
        # x_label, y_label: texts used to label the axes
        # grouping_col: name of the column of *data* used by default to group
        #     (and colour) the points, or None when there is no such column
        __slots__ = ("data", "x_label", "y_label", "grouping_col")

        @staticmethod
        def fuse_columns(row):
            """Joins text from columns in a pandas DataFrame row with underscores.
            Intended to be used with *DataFrame.apply*."""
            return "_".join(map(str, row.values))

        def __init__(self,
                     x_input_file,
                     y_input_file,
                     x_column,
                     y_column,
                     labels,
                     extra_cols=None):
            """Loads and merges the data to plot.

            *x_input_file* and *y_input_file* are paths to tab-separated files
            with a header line and a "gene" column, used as index.
            *x_column* and *y_column* are the names of the columns to read from
            these files. They are renamed "x" and "y" in *self.data*.
            *labels* is a pair (x_label, y_label).
            *extra_cols* is an optional list of names of extra columns to load.
            When it is not empty, these columns define *self.grouping_col*.
            None and an empty list both mean "no extra columns".

            Only the genes present in both files are kept. The merge fails if
            a gene occurs more than once in one of the files."""
            # An empty selection (what "--extra_cols" without values gives)
            # is handled like None. Otherwise, the grouping column would be
            # named "" (the result of joining nothing), a column that is
            # never created here.
            extra_cols = list(extra_cols) if extra_cols else []
            # read_csv only loads the columns for which these return True
            x_usecols = ["gene", x_column, *extra_cols].__contains__
            y_usecols = ["gene", y_column, *extra_cols].__contains__
            x_data = pd.read_csv(
                x_input_file, sep="\t", index_col="gene", usecols=x_usecols).rename(
                    columns={x_column: "x"})
            y_data = pd.read_csv(
                y_input_file, sep="\t", index_col="gene", usecols=y_usecols).rename(
                    columns={y_column: "y"})
            self.data = pd.merge(
                x_data, y_data,
                left_index=True, right_index=True, validate="one_to_one")
            if extra_cols:
                # A column found in both tables was duplicated by the merge,
                # with suffixes _x and _y: both copies are used in that case.
                grouping_cols = list(concat(
                    [colname] if colname in self.data.columns
                    else [f"{colname}_x", f"{colname}_y"]
                    for colname in extra_cols))
                # With only one column, the fused column would have the same
                # name as the existing one and override it:
                # "_".join(grouping_cols) == grouping_cols[0]
                if len(grouping_cols) > 1:
                    self.data = self.data.assign(**{
                        "_".join(grouping_cols): self.data[grouping_cols].apply(
                            Scatterplot.fuse_columns, axis=1)})
                self.grouping_col = "_".join(grouping_cols)
            else:
                self.grouping_col = None
            (self.x_label, self.y_label) = labels

        def apply_selector(self, selector, chose_from=None):
            """Returns a list of gene ids based on a *selector* string.
            The *selector* string will be used as a query string on *self.data*.
            If *chose_from* is not empty, gene ids will only be selected
            if they belong to *chose_from*."""
            # A set makes each membership test cheap
            allowed = set(chose_from) if chose_from is not None else set()
            selected = self.data.query(selector).index
            if allowed:
                return [gene_id for gene_id in selected if gene_id in allowed]
            return list(selected)

        def plot_maker(self, grouping=None, group2colour=None, **kwargs):
            """Builds a plotting function that can colour dots based on them
            belonging to a group defined by *grouping*.

            *grouping*, *group2colour* and the other keyword arguments are
            transmitted to *plot_scatter*. When *grouping* is a list or a
            tuple of gene ids, the number of these genes found in each
            quadrant is also displayed, using the colour given as second
            element of the pair *group2colour* (black if it cannot be
            extracted). If *x_range* or *y_range* are among the keyword
            arguments, their lower bound determines where the threshold
            lines are annotated.

            Returns a function taking no argument."""
            def plot_lfclfc_scatter():
                """Generates the scatterplot, returns its legend so that
                *save_plot* can include it in the bounding box.

                The legend is returned in a 1-element tuple. None is returned
                when there is no grouping, or when there is no data to plot
                (a warning is issued in that case)."""
                try:
                    axis = plot_scatter(
                        self.data,
                        "x", "y",
                        grouping=grouping,
                        group2colour=group2colour,
                        **kwargs)
                except ValueError as err:
                    if str(err) != "No data to plot.":
                        raise
                    warnings.warn("No data to plot.")
                    return None
                # Lines indicating 2-fold threshold.
                # Assumes the data are in log2 fold changes
                line_style = {
                    "linewidth": 0.5, "color": "0.5", "linestyle": "dashed"}
                # The annotations are anchored at the lower end of the axes
                if "x_range" in kwargs:
                    (x_annot_loc, _) = kwargs["x_range"]
                else:
                    x_annot_loc = min(self.data.x)
                if "y_range" in kwargs:
                    (y_annot_loc, _) = kwargs["y_range"]
                else:
                    y_annot_loc = min(self.data.y)
                # Horizontal thresholds, labelled on the outer side of the line
                for (y_value, valign) in ((1, "bottom"), (-1, "top")):
                    axis.axhline(y=y_value, **line_style)
                    axis.annotate(
                        f"y = {y_value}",
                        xy=(x_annot_loc, y_value), xycoords="data",
                        horizontalalignment='left',
                        verticalalignment=valign,
                        size="x-small",
                        color=line_style["color"])
                # Vertical thresholds, labelled on the outer side of the line
                for (x_value, halign) in ((-1, "right"), (1, "left")):
                    axis.axvline(x=x_value, **line_style)
                    axis.annotate(
                        f"x = {x_value}",
                        xy=(x_value, y_annot_loc), xycoords="data",
                        horizontalalignment=halign,
                        verticalalignment='bottom',
                        rotation=90, size="x-small",
                        color=line_style["color"])
                # Number of genes beyond lfc thresholds, in each quadrant
                if isinstance(grouping, (list, tuple)):
                    try:
                        (_, colour) = group2colour
                    except (ValueError, TypeError):
                        colour = "black"
                    select_ingroup = pd.Index(grouping).intersection

                    def ingroup_suffix(quadrant):
                        """Formats the number of genes of *quadrant* that
                        belong to the group, as coloured LaTeX text."""
                        return " (\\textcolor{%s}{%d})" % (
                            colour,
                            len(quadrant.loc[select_ingroup(quadrant.index)]))
                else:
                    def ingroup_suffix(quadrant):
                        """No explicit list of genes: nothing to add."""
                        return ""
                # (query, position in axes fraction, horizontal alignment,
                # vertical alignment) for each corner of the plot
                quadrant_specs = (
                    ("x > 1 & y > 1", (0.95, 0.95), "right", "top"),
                    ("x > 1 & y < -1", (0.95, 0.05), "right", "bottom"),
                    ("x < -1 & y > 1", (0.05, 0.95), "left", "top"),
                    ("x < -1 & y < -1", (0.05, 0.05), "left", "bottom"))
                for (query, position, halign, valign) in quadrant_specs:
                    quadrant = self.data.query(query)
                    axis.annotate(
                        f"{len(quadrant)}{ingroup_suffix(quadrant)}",
                        xy=position, xycoords="axes fraction",
                        size="x-small", color=line_style["color"],
                        horizontalalignment=halign,
                        verticalalignment=valign)
                axis.set_xlabel(self.x_label, fontsize=17)
                axis.set_ylabel(self.y_label, fontsize=17)
                # Note: setting the limits from x_range and y_range with
                # axis.set_xlim and axis.set_ylim was tried here,
                # but doesn't work with plt.axis("equal").
                # Note: fig.tight_layout() and fig.subplots_adjust(top=...)
                # were also tried here, without success.
                # TODO: force ticks to be integers
                if grouping is not None:
                    # The legend sits just above the top left corner
                    legend = axis.legend(
                        bbox_to_anchor=(0, 1),
                        bbox_transform=axis.transAxes, loc="lower left")
                    # Return a tuple of "extra artists",
                    # to correctly define the bounding box
                    return (legend,)
                return None
            return plot_lfclfc_scatter

        def save_plot(self, outfile, grouping=None, group2colour=None, **kwargs):
            """Creates the plotting function and transmits it for execution
            to the function that really does the saving.

            *outfile* is the path of the file in which to save the plot.
            When *grouping* is None, *self.grouping_col* is used instead
            (if it is not None). The other arguments are transmitted to
            *plot_maker*."""
            if grouping is None and self.grouping_col is not None:
                grouping = self.grouping_col
            # TODO: How to have it square?
            # (Using equal axes only when x_range and y_range are not both
            # provided has been considered.)
            equal_axes = True
            # Inside a method, this name refers to the module-level save_plot
            # function, not to the present method.
            save_plot(
                outfile,
                self.plot_maker(
                    grouping=grouping, group2colour=group2colour, **kwargs),
                equal_axes=equal_axes,
                tight=True)
    
    
    def _build_argument_parser():
        """Builds the parser for the command-line options of the script."""
        parser = argparse.ArgumentParser(
            description=__doc__,
            formatter_class=argparse.ArgumentDefaultsHelpFormatter)
        parser.add_argument(
            "-x", "--x_input_data",
            help="File containing a saved pandas DataFrame.\n"
            "It should be tab-separated data with one header line.\n"
            "The index column should be named \"gene\"",
            required=True)
        parser.add_argument(
            "-y", "--y_input_data",
            help="File containing a saved pandas DataFrame.\n"
            "It should be tab-separated data with one header line.\n"
            "The index column should be named \"gene\"",
            required=True)
        parser.add_argument(
            "--x_column",
            help="Name of the column to use from the table given with "
            "--x_input_data. Default is based on DESeq2 results.",
            default="log2FoldChange")
        parser.add_argument(
            "--y_column",
            help="Name of the column to use from the table given with "
            "--y_input_data. Default is based on DESeq2 results.",
            default="log2FoldChange")
        parser.add_argument(
            "--x_label",
            help="Name to label the column selected for the x axis",
            required=True)
        parser.add_argument(
            "--y_label",
            help="Name to label the column selected for the y axis",
            required=True)
        parser.add_argument(
            "--data_range",
            help="min and max of the data values to display. "
            "If the range is too narrow to plot data, it will be ignored.",
            type=int,
            nargs=2,
            default=[-12, 12])
        parser.add_argument(
            "--extra_cols",
            help="Columns containing categorical information "
            "to be used to colour points.",
            nargs="*")
        parser.add_argument(
            "-d", "--plot_dir",
            required=True,
            help="Directory in which scatterplots should be written.")
        parser.add_argument(
            "-p", "--plot_name",
            help="Base name for the plot files.")
        parser.add_argument(
            "-g", "--gene_list",
            help="File containing identifiers of genes to highlight in colour.")
        # NOTE(review): this option is parsed, but main() does not use it.
        parser.add_argument(
            "--gene_lists",
            help="Files containing identifiers of genes, "
            "so that they can be referenced in --selector.")
        parser.add_argument(
            "-s", "--selector",
            help="Pandas query string to select rows in the merged "
            "x and y input data. If a column name belongs to both "
            "the x and y data, append the _x or _y suffix to desambiguate.")
        parser.add_argument(
            "--selection_label",
            help="Label to use for the gene list given with --gene_list "
            "or obtained by applying the --selector option.")
        parser.add_argument(
            "-c", "--colour",
            default="red",
            help="Colour to use for the elements in the list given by option "
            "--gene_list or --selector.")
        parser.add_argument(
            "--plot_regression",
            help="Use this option to plot the regression line.",
            default=False,
            action="store_true")
        return parser


    def main():
        """Main function of the program.

        Parses the command line, loads and merges the x and y data,
        determines the genes to highlight (from --gene_list and/or --selector)
        and, if --plot_name was given, writes a log of the command, a table
        of the plotted data and the scatterplot in the --plot_dir directory.
        Nothing is written when --plot_name is not given.

        Exits with a usage error if --selector is used without
        --selection_label.
        Returns None."""
        print(" ".join(sys.argv))
        parser = _build_argument_parser()
        args = parser.parse_args()
        if args.selector and not args.selection_label:
            msg = """A label should be associated to the query
            given with the --selector option.
            Use the --selection_label option to set this.\n"""
            # Not an assert: asserts are skipped when python runs with -O,
            # and the highlighted list could then end up without a name.
            # The check is also done before reading any input file.
            parser.error(msg)
        plot_data = Scatterplot(
            args.x_input_data,
            args.y_input_data,
            args.x_column,
            args.y_column,
            (args.x_label, args.y_label),
            extra_cols=args.extra_cols)
        list_name = None
        if args.gene_list:
            (list_name, base_gene_list) = get_gene_list(args.gene_list)
        else:
            base_gene_list = []
        if args.selector:
            # An empty base_gene_list puts no restriction on the selection
            gene_list = plot_data.apply_selector(args.selector, base_gene_list)
        else:
            gene_list = list(plot_data.data.index.intersection(base_gene_list))
        if not args.plot_name:
            # No name for the output files: nothing to write
            return None
        # https://stackoverflow.com/a/14364249/1878788
        os.makedirs(args.plot_dir, exist_ok=True)
        out_pdf = OPJ(args.plot_dir, "%s.pdf" % args.plot_name)
        out_table = OPJ(args.plot_dir, "%s.tsv" % args.plot_name)
        out_log = OPJ(args.plot_dir, "%s.log" % args.plot_name)
        # Keep a record of the command that generated the files
        with open(out_log, "w") as log_file:
            print(" \\\n\t".join(sys.argv), file=log_file)
        if gene_list:
            # The table gets an extra column telling, for each gene,
            # whether it belongs to gene_list.
            plot_data.data.assign(
                hightlighted=plot_data.data.index.isin(gene_list)).to_csv(
                    out_table, sep="\t")
            # The label given on the command line takes precedence
            # over the name derived from the gene list file.
            if args.selection_label:
                list_name = args.selection_label
            plot_data.save_plot(
                out_pdf,
                grouping=gene_list, group2colour=(list_name, args.colour),
                x_range=args.data_range,
                y_range=args.data_range,
                axes_style={
                    "linewidth": 0.5, "color": "0.5", "linestyle": "-"},
                regression=args.plot_regression)
        else:
            plot_data.data.to_csv(out_table, sep="\t")
            plot_data.save_plot(
                out_pdf,
                x_range=args.data_range,
                y_range=args.data_range,
                regression=args.plot_regression)
        return None
    
    
    if __name__ == "__main__":
        # Only when executed as a script (not when imported):
        # run the program and use what it returns as exit status.
        raise SystemExit(main())