Skip to content
Snippets Groups Projects
Select Git revision
  • 1e68e86d842c99d67198c7cebf5363d450eb27c4
  • master default protected
2 results

plot_scatterplot.py

Blame
  • user avatar
    Blaise Li authored
    1e68e86d
    History
    plot_scatterplot.py 21.92 KiB
    #!/usr/bin/env python3
    # Copyright (C) 2020 Blaise Li
    #
    # This program is free software: you can redistribute it and/or modify
    # it under the terms of the GNU General Public License as published by
    # the Free Software Foundation, either version 3 of the License, or
    # (at your option) any later version.
    #
    # This program is distributed in the hope that it will be useful,
    # but WITHOUT ANY WARRANTY; without even the implied warranty of
    # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    # GNU General Public License for more details.
    #
    # You should have received a copy of the GNU General Public License
    # along with this program.  If not, see <https://www.gnu.org/licenses/>.
    """This script reads data from "tidy" files and makes a scatter plot out of it.
    It also outputs a table containing the plotted data points."""
    
    import argparse
    import os
    import sys
    import warnings
    from operator import attrgetter, contains
    # from functools import partial
    import matplotlib as mpl
    # To be able to run the script without a defined $DISPLAY
    # mpl.use("PDF")
    from matplotlib.backends.backend_pgf import FigureCanvasPgf
    import pandas as pd
    from cytoolz import compose, concat, curry
    from libhts import plot_scatter
    from libworkflows import save_plot, strip_split
    
    
    def formatwarning(message, category, filename, lineno, line=None):
        """Format a warning as a single ``file:lineno: Category: message`` line.

        Drop-in replacement for :func:`warnings.formatwarning`.

        *message* is the warning message (or warning instance), *category*
        is the warning class, *filename* and *lineno* locate the place where
        the warning was issued.
        *line* is accepted (and optional, as in the standard signature) only
        for compatibility: the source line is never included in the output.

        Returns the formatted text, terminated by a newline.
        """
        return f"{filename}:{lineno}: {category.__name__}: {message}\n"
    
    
    # Make the warnings module use the compact one-line formatter.
    warnings.formatwarning = formatwarning
    
    # Short aliases for the path helpers used throughout the script.
    OPJ = os.path.join
    OPB = os.path.basename
    
    # Associate the "pdf" output format with the PGF canvas.
    # https://stackoverflow.com/a/42768093/1878788
    mpl.backend_bases.register_backend('pdf', FigureCanvasPgf)
    # Matplotlib settings applied to the global rcParams just below.
    # NOTE(review): "pgf.preamble" is given as a list here; recent matplotlib
    # releases may expect a single string -- confirm against the installed
    # version.
    TEX_PARAMS = {
        "text.usetex": True,            # use LaTeX to write all text
        "pgf.rcfonts": False,           # Ignore Matplotlibrc
        "pgf.texsystem": "lualatex",  # hoping to avoid memory issues
        "pgf.preamble": [
            r'\usepackage{color}'     # "color" package, for \textcolor
        ]
    }
    mpl.rcParams.update(TEX_PARAMS)
    # mpl.rcParams["figure.figsize"] = 2, 4
    # Sans-serif fonts, listed in order of preference.
    mpl.rcParams["font.sans-serif"] = [
        "Arial", "Liberation Sans", "Bitstream Vera Sans"]
    mpl.rcParams["font.family"] = "sans-serif"
    
    
    # from matplotlib import numpy as np
    # import matplotlib.pyplot as plt
    # https://jakevdp.github.io/blog/2014/10/16/how-bad-is-your-colormap/
    # def grayify_cmap(cmap):
    #     """Return a grayscale version of the colormap"""
    #     cmap = plt.cm.get_cmap(cmap)
    #     colors = cmap(np.arange(cmap.N))
    #     # convert RGBA to perceived greyscale luminance
    #     # cf. http://alienryderflex.com/hsp.html
    #     rgb_weight = [0.299, 0.587, 0.114]
    #     luminance = np.sqrt(np.dot(colors[:, :3] ** 2, rgb_weight))
    #     colors[:, :3] = luminance[:, np.newaxis]
    #     return cmap.from_list(cmap.name + "_grayscale", colors, cmap.N)
    def get_gene_list(filename):
        """Read gene identifiers from the file at path *filename*.

        The list name is the base name of the file, without its extension
        and without a trailing "_ids" if there is one.

        One identifier is taken per line: the first field returned by
        *strip_split* for that line. Lines starting with "#" are ignored,
        and so are blank lines, so that they cannot produce an empty
        identifier.

        Returns a pair (list_name, gene_list).
        """
        (list_name, _) = os.path.splitext(OPB(filename))
        if list_name.endswith("_ids"):
            list_name = list_name[:-4]
        with open(filename, "r") as infile:
            # Iterate over the file directly: no need to load all lines first.
            gene_list = [
                strip_split(line)[0] for line in infile
                if line.strip() and not line.startswith("#")]
        return (list_name, gene_list)
    
    
    class Scatterplot:
        """A 2-dimension scatterplot."""
        __slots__ = ("data", "x_label", "y_label", "grouping_col")
    
        @staticmethod
        def fuse_columns(row):
            """Joins text from columns in a pandas DataFrame row with underscores.
            Intended to be used with *DataFrame.apply*."""
            return "_".join(map(str, row.values))
    
        @staticmethod
        def put_quadrants(axis, x_annot_loc, y_annot_loc, line_style):
            """Delineate quadrants of absolute fold changes over 2.
    
            *axis* is the :class:`matplotlib.axes.Axes` on which to work.
    
            *x_annot_loc* and *y_annot_loc* are the position at which to annotate
            the lines delineating the quadrants.
    
            This works through side effects on *axis*.
            """
            axis.axhline(y=1, **line_style)
            axis.annotate(
                f"y = 1", xy=(x_annot_loc, 1), xycoords="data",
                horizontalalignment='left',
                verticalalignment='bottom',
                size="x-small",
                color=line_style["color"])
            axis.axhline(y=-1, **line_style)
            axis.annotate(
                f"y = -1", xy=(x_annot_loc, -1), xycoords="data",
                horizontalalignment='left',
                verticalalignment='top',
                size="x-small",
                color=line_style["color"])
            axis.axvline(x=1, **line_style)
            axis.annotate(
                f"x = -1", xy=(-1, y_annot_loc), xycoords="data",
                horizontalalignment='right',
                verticalalignment='bottom',
                rotation=90, size="x-small",
                color=line_style["color"])
            axis.axvline(x=-1, **line_style)
            axis.annotate(
                f"x = 1", xy=(1, y_annot_loc), xycoords="data",
                horizontalalignment='left',
                verticalalignment='bottom',
                rotation=90, size="x-small",
                color=line_style["color"])
    
        @staticmethod
        def write_quadrant_counts(axis,
                                  annot_up_up, annot_up_down,
                                  annot_down_up, annot_down_down,
                                  text_color):
            """Annotate quadrants of absolute fold changes over 2.
    
            *axis* is the :class:`matplotlib.axes.Axes` on which to work.
    
            *annot_up_up*, *annot_up_down*, *annot_down_up* and *annot_down_down*
            are the annotations to put in each quadrant.
    
            This works through side effects on *axis*.
            """
            axis.annotate(
                annot_up_up,
                xy=(0.95, 0.95), xycoords="axes fraction",
                size="x-small", color=text_color,
                horizontalalignment="right",
                verticalalignment="top")
            axis.annotate(
                annot_up_down,
                xy=(0.95, 0.05), xycoords="axes fraction",
                size="x-small", color=text_color,
                horizontalalignment="right",
                verticalalignment="bottom")
            axis.annotate(
                annot_down_up,
                xy=(0.05, 0.95), xycoords="axes fraction",
                size="x-small", color=text_color,
                horizontalalignment="left",
                verticalalignment="top")
            axis.annotate(
                annot_down_down,
                xy=(0.05, 0.05), xycoords="axes fraction",
                size="x-small", color=text_color,
                horizontalalignment="left",
                verticalalignment="bottom")
    
        def __init__(self,
                     x_input_file,
                     y_input_file,
                     x_column,
                     y_column,
                     labels,
                     extra_cols=None):
            # usecols can be a callable to filter column names:
            # If callable, the callable function will be evaluated against the
            # column names, returning names where the callable function evaluates
            # to True.
            if extra_cols is None:
                x_usecols = ["gene", x_column].__contains__
                y_usecols = ["gene", y_column].__contains__
            else:
                x_usecols = ["gene", x_column, *extra_cols].__contains__
                y_usecols = ["gene", y_column, *extra_cols].__contains__
            # The columns containing the data to plot might have the same name
            # in the two tables.
            # We rename them to x and y for simplicity.
            x_data = pd.read_csv(
                x_input_file, sep="\t", index_col="gene", usecols=x_usecols).rename(
                    columns={x_column: "x"})
            y_data = pd.read_csv(
                y_input_file, sep="\t", index_col="gene", usecols=y_usecols).rename(
                    columns={y_column: "y"})
            # Just some experiments
            # from cytoolz import merge_with
            # from cytoolz.curried import merge_with as cmerge
            # common = x_data.index.intersection(y_data.index)
            # as_dicts = merge_with(
            #     cmerge(set),
            #     x_data.loc[common].to_dict(orient="index"),
            #     y_data.loc[common].to_dict(orient="index"))
            # data = pd.DataFrame(as_dicts).T
            self.data = pd.merge(
                x_data, y_data,
                left_index=True, right_index=True, validate="one_to_one")
            # Compute a classifier column (to be used to colour points)
            if extra_cols is not None:
                extra_cols = list(concat((
                    [colname] if colname in self.data.columns
                    else [f"{colname}_x", f"{colname}_y"]
                    for colname in extra_cols)))
                # if only one element in extra_cols, it will be overridden:
                # "_".join(extra_cols) == extra_cols[0]
                if len(extra_cols) > 1:
                    self.data = self.data.assign(**{
                        "_".join(extra_cols): self.data[extra_cols].apply(
                            Scatterplot.fuse_columns, axis=1)})
                self.grouping_col = "_".join(extra_cols)
            else:
                self.grouping_col = None
            (self.x_label, self.y_label) = labels
    
        def apply_selector(self, selector, chose_from=None):
            """Returns a list of gene ids based on a *selector* string.
            The *selector* string will be used as a query string on *self.data*.
            If *chose_from* is not empty, gene ids will only been selected
            if they belong to *chose_from*."""
            if chose_from is None:
                chose_from = {}
            else:
                chose_from = set(chose_from)
            if chose_from:
                return [gene_id for gene_id in self.data.query(selector).index
                        if gene_id in chose_from]
            return [gene_id for gene_id in self.data.query(selector).index]
    
        def plot_maker(self,
                       annotate_folds=True,
                       grouping=None,
                       group2colour=None,
                       **kwargs):
            """Builds a plotting function that can colour dots based on them
            belonging to a group defined by *grouping*.
            If *annotate_folds* is True, lines indicating 2-fold thresholds
            will be added to the plot, as well as counts in each quadrants."""
            def plotting_function():
                """Generates the scatterplot, returns its legend so that
                *save_plot* can include it in the bounding box."""
                # fig, axis = plot_scatter(
                # print(kwargs["x_range"])
                try:
                    axis = plot_scatter(
                        self.data,
                        "x", "y",
                        grouping=grouping,
                        group2colour=group2colour,
                        **kwargs)
                except ValueError as err:
                    if str(err) == "No data to plot.":
                        warnings.warn("No data to plot.")
                        return None
                    else:
                        raise
                if annotate_folds:
                    # Lines indicating 2-fold threshold.
                    # Assumes the data are in log2 fold changes
                    line_style = {
                        "linewidth": 0.5, "color": "0.5", "linestyle": "dashed"}
                    if "x_range" in kwargs:
                        (x_annot_loc, _) = kwargs["x_range"]
                    else:
                        x_annot_loc = min(self.data.x)
                    if "y_range" in kwargs:
                        (y_annot_loc, _) = kwargs["y_range"]
                    else:
                        y_annot_loc = min(self.data.y)
                    Scatterplot.put_quadrants(
                        axis, x_annot_loc, y_annot_loc, line_style)
                    # Genes beyond lfc thresholds, in each quadrant
                    up_up = self.data.query(
                        "x > 1 & y > 1")
                    up_down = self.data.query(
                        "x > 1 & y < -1")
                    down_up = self.data.query(
                        "x < -1 & y > 1")
                    down_down = self.data.query(
                        "x < -1 & y < -1")
                    if isinstance(grouping, (list, tuple)):
                        try:
                            (_, colour) = group2colour
                        except (ValueError, TypeError):
                            colour = "black"
                        select_ingroup = pd.Index(grouping).intersection
                        ingroup_up_up = " (\\textcolor{%s}{%d})" % (
                            colour,
                            len(up_up.loc[select_ingroup(up_up.index)]),)
                        ingroup_up_down = " (\\textcolor{%s}{%d})" % (
                            colour,
                            len(up_down.loc[select_ingroup(up_down.index)]))
                        ingroup_down_up = " (\\textcolor{%s}{%d})" % (
                            colour,
                            len(down_up.loc[select_ingroup(down_up.index)]))
                        ingroup_down_down = " (\\textcolor{%s}{%d})" % (
                            colour,
                            len(down_down.loc[select_ingroup(down_down.index)]))
                    else:
                        ingroup_up_up = ""
                        ingroup_up_down = ""
                        ingroup_down_up = ""
                        ingroup_down_down = ""
                    Scatterplot.write_quadrant_counts(
                        axis,
                        f"{len(up_up)}{ingroup_up_up}",
                        f"{len(up_down)}{ingroup_up_down}",
                        f"{len(down_up)}{ingroup_down_up}",
                        f"{len(down_down)}{ingroup_down_down}",
                        line_style["color"])
                axis.set_xlabel(self.x_label, fontsize=17)
                axis.set_ylabel(self.y_label, fontsize=17)
                # This doesn't work with plt.axis("equal")
                # if "x_range" in kwargs:
                #     (xmin, xmax) = kwargs["x_range"]
                #     print("setting x_range")
                #     axis.set_xlim((xmin, xmax))
                # if "y_range" in kwargs:
                #     (ymin, ymax) = kwargs["y_range"]
                #     print("setting y_range")
                #     axis.set_ylim((ymin, ymax))
                # Move legend to middle top
                # axis.legend(loc="upper center")
                if grouping is not None:
                    legend = axis.legend(
                        bbox_to_anchor=(0, 1),
                        bbox_transform=axis.transAxes, loc="lower left")
                    return legend,
                return None
                # Not working:
                # fig.tight_layout()
                # strange behaviour (interaction with tight_layout?)
                # fig.subplots_adjust(top=1.5)
                # fig.subplots_adjust(top=0.75)
                # TODO: force ticks to be integers
                # Return a tuple of "extra artists",
                # to correctly define the bounding box
            return plotting_function
    
        def save_plot(self, outfile,
                      annotate_folds=True, grouping=None, group2colour=None,
                      **kwargs):
            """Build the plotting function and hand it over, together with
            *outfile*, to the module-level *save_plot* function, which does
            the actual saving.

            When *grouping* is None, *self.grouping_col* is used instead.
            Axes are requested to be equal if and only if *annotate_folds*
            is true. Extra *kwargs* are transmitted to *plot_maker*."""
            effective_grouping = (
                grouping if grouping is not None else self.grouping_col)
            # TODO: How to have it square?
            plotting_function = self.plot_maker(
                annotate_folds,
                grouping=effective_grouping,
                group2colour=group2colour,
                **kwargs)
            save_plot(
                outfile,
                plotting_function,
                equal_axes=bool(annotate_folds),
                tight=True)
    
    
    def main():
        """Main function of the program.

        Parses the command line, loads and merges the x and y data, determines
        which genes to highlight (from --gene_list and/or --selector), then
        writes in the --plot_dir directory, under the base name --plot_name:
        a log of the command line (.log), the table of plotted data (.tsv)
        and the scatterplot (.pdf).

        Nothing is written when --plot_name is not given: a warning is
        issued instead.
        Exits through *parser.error* if --selector is used without
        --selection_label.
        """
        print(" ".join(sys.argv))
        parser = argparse.ArgumentParser(
            description=__doc__,
            formatter_class=argparse.ArgumentDefaultsHelpFormatter)
        parser.add_argument(
            "-x", "--x_input_data",
            help="File containing a saved pandas DataFrame.\n"
            "It should be tab-separated data with one header line.\n"
            "The index column should be named \"gene\"",
            required=True)
        parser.add_argument(
            "-y", "--y_input_data",
            help="File containing a saved pandas DataFrame.\n"
            "It should be tab-separated data with one header line.\n"
            "The index column should be named \"gene\"",
            required=True)
        parser.add_argument(
            "--x_column",
            help="Name of the column to use from the table given with "
            "--x_input_data. Default is based on DESeq2 results.",
            default="log2FoldChange")
        parser.add_argument(
            "--y_column",
            help="Name of the column to use from the table given with "
            "--y_input_data. Default is based on DESeq2 results.",
            default="log2FoldChange")
        parser.add_argument(
            "--x_label",
            help="Name to label the column selected for the x axis",
            required=True)
        parser.add_argument(
            "--y_label",
            help="Name to label the column selected for the y axis",
            required=True)
        parser.add_argument(
            "--data_range",
            help="min and max of the data values to display. "
            "If the range is too narrow to plot data, it will be ignored.",
            type=int,
            nargs=2,
            default=[-12, 12])
        parser.add_argument(
            "--extra_cols",
            help="Columns containing categorical information "
            "to be used to colour points.",
            nargs="*")
        parser.add_argument(
            "-d", "--plot_dir",
            required=True,
            help="Directory in which scatterplots should be written.")
        parser.add_argument(
            "-p", "--plot_name",
            help="Base name for the plot files.")
        parser.add_argument(
            "-g", "--gene_list",
            help="File containing identifiers of genes to highlight in colour.")
        # NOTE(review): this option is parsed but not used anywhere in this
        # function. It is kept so that existing command lines remain valid.
        parser.add_argument(
            "--gene_lists",
            help="Files containing identifiers of genes, "
            "so that they can be referenced in --selector.")
        parser.add_argument(
            "-s", "--selector",
            help="Pandas query string to select rows in the merged "
            "x and y input data. If a column name belongs to both "
            "the x and y data, append the _x or _y suffix to desambiguate.")
        parser.add_argument(
            "--selection_label",
            help="Label to use for the gene list given with --gene_list "
            "or obtained by applying the --selector option.")
        parser.add_argument(
            "-c", "--colour",
            default="red",
            help="Colour to use for the elements in the list given by option "
            "--gene_list or --selector.")
        parser.add_argument(
            "--not_foldchanges",
            help="Use this option to inactivate plotting options for "
            "log2FoldChange.",
            default=False,
            action="store_true")
        parser.add_argument(
            "--plot_regression",
            help="Use this option to plot the regression line.",
            default=False,
            action="store_true")
        args = parser.parse_args()
        # Validate option combinations before loading any data.
        # (An assert would be stripped when running python with -O.)
        if args.selector and not args.selection_label:
            parser.error(
                "A label should be associated to the query given with the "
                "--selector option. "
                "Use the --selection_label option to set this.")
        plot_data = Scatterplot(
            args.x_input_data,
            args.y_input_data,
            args.x_column,
            args.y_column,
            (args.x_label, args.y_label),
            extra_cols=args.extra_cols)
        if args.gene_list:
            (list_name, base_gene_list) = get_gene_list(args.gene_list)
        else:
            base_gene_list = []
        if args.selector:
            gene_list = plot_data.apply_selector(args.selector, base_gene_list)
        else:
            gene_list = list(plot_data.data.index.intersection(base_gene_list))
        if not args.plot_name:
            warnings.warn("No --plot_name given: nothing will be written.")
            return None
        # https://stackoverflow.com/a/14364249/1878788
        os.makedirs(args.plot_dir, exist_ok=True)
        out_pdf = OPJ(args.plot_dir, f"{args.plot_name}.pdf")
        out_table = OPJ(args.plot_dir, f"{args.plot_name}.tsv")
        out_log = OPJ(args.plot_dir, f"{args.plot_name}.log")
        # Keep a trace of the command line, one argument per line.
        with open(out_log, "w") as log_file:
            print(" \\\n\t".join(sys.argv), file=log_file)
        if gene_list:
            # Add a boolean column telling whether the gene is highlighted.
            # NOTE: the column name (with its typo) is part of the output
            # format and is kept as is.
            plot_data.data.assign(
                hightlighted=plot_data.data.index.isin(gene_list)).to_csv(
                    out_table, sep="\t")
            # The label given explicitly takes precedence over the name
            # derived from the gene list file.
            # list_name is always defined at this point: a non-empty
            # gene_list requires --gene_list (which sets list_name) or
            # --selector (which requires --selection_label).
            if args.selection_label:
                list_name = args.selection_label
            plot_data.save_plot(
                out_pdf,
                annotate_folds=not args.not_foldchanges,
                grouping=gene_list, group2colour=(list_name, args.colour),
                x_range=args.data_range,
                y_range=args.data_range,
                axes_style={
                    "linewidth": 0.5, "color": "0.5", "linestyle": "-"},
                regression=args.plot_regression)
        else:
            plot_data.data.to_csv(out_table, sep="\t")
            plot_data.save_plot(
                out_pdf,
                annotate_folds=not args.not_foldchanges,
                x_range=args.data_range,
                y_range=args.data_range,
                regression=args.plot_regression)
        return None
    
    
    # Run the command-line interface only when the file is executed as a
    # script; whatever main() returns is handed to sys.exit().
    if __name__ == "__main__":
        sys.exit(main())