Skip to content
Snippets Groups Projects
Select Git revision
  • 769b82597855b41362593c28328f84749c025d19
  • master default protected
  • exponential-backoff-login
  • v1.10.0
  • v1.9.2
  • v1.9.0
  • v1.8.8
  • v1.8.7
  • v1.8.5
  • v1.8.4
  • v1.8.2
  • v1.8
  • v1.7
  • v1.6
  • v1.5
  • v1.4
  • v1.3
  • v1.2
  • v1.1
  • v1.0.1
  • v1.0
  • v0.2.80
  • v0.2.79
23 results

basetheme_bootstrap.js

Blame
  • libhts.py 19.60 KiB
    from functools import reduce
    import warnings
    
    
    def formatwarning(message, category, filename, lineno, line=None):
        """Format a warning as one "filename:lineno: Category: message" line.

        The signature mirrors that of the standard ``warnings.formatwarning``:
        *line* is accepted (and optional) for compatibility, but it is not
        used, so the offending source line is never echoed.
        """
        return "%s:%s: %s: %s\n" % (filename, lineno, category.__name__, message)


    # Replaces the default formatter of the warnings module.
    warnings.formatwarning = formatwarning
    
    import pandas as pd
    # To compute correlation coefficient, and compute linear regression
    from scipy.stats.stats import pearsonr, linregress
    # To compute geometric mean
    from scipy.stats.mstats import gmean
    import matplotlib.pyplot as plt
    import seaborn as sns
    from rpy2.robjects import r, pandas2ri, Formula, StrVector
    as_df = r("as.data.frame")
    from rpy2.rinterface import RRuntimeError
    from rpy2.robjects.packages import importr
    deseq2 = importr("DESeq2")
    from pybedtools import BedTool
    import networkx as nx
    
    
    class Exon(object):
        """Interval from *start* to *end* on chromosome *chrom*.

        The interval is handled as half-open: its length is ``end - start``
        and an exon starting where another one ends does not overlap it.

        Exons compare and hash by value (chrom, start, end), so that an exon
        reported several times is recognized as already present when exons
        are used as set members, dict keys or graph nodes.
        """
        __slots__ = ("chrom", "start", "end")

        def __init__(self, chrom, start, end):
            self.chrom = chrom
            self.start = start
            self.end = end

        def _key(self):
            """Return the tuple on which equality and hashing are based."""
            return (self.chrom, self.start, self.end)

        def __eq__(self, other):
            if not isinstance(other, Exon):
                return NotImplemented
            return self._key() == other._key()

        def __hash__(self):
            return hash(self._key())

        def __repr__(self):
            return "Exon(%r, %r, %r)" % self._key()

        def overlap(self, other):
            """Tell whether *self* and *other* overlap.

            Exons on different chromosomes never overlap."""
            if self.chrom != other.chrom:
                return False
            return (self.start <= other.start < self.end) or (other.start <= self.start < other.end)

        def merge(self, other):
            """Return a new Exon spanning from the lowest start to the highest end.

            The chromosome is taken from *self*. The two exons do not need to
            overlap directly: they can be indirectly linked."""
            return Exon(self.chrom, min(self.start, other.start), max(self.end, other.end))

        def __len__(self):
            return self.end - self.start


    # Function-style aliases of the methods
    overlap = Exon.overlap
    merge = Exon.merge
    
    class Gene(object):
        """This object contains information obtained from a gtf file.

        *gene_id* identifies the gene, *exons* is a networkx graph whose nodes
        are the exons registered with :meth:`add_exon`, and
        *union_exon_length* stays None until :meth:`set_union_exon_length`
        has been called.
        """
        __slots__ = ("gene_id", "exons", "union_exon_length")

        def __init__(self, gene_id):
            self.gene_id = gene_id
            self.exons = nx.Graph()
            self.union_exon_length = None

        def add_exon(self, feature):
            """Register the exon described by *feature*.

            Only the chrom, start and end attributes of *feature* are used.
            Nothing is added if the resulting exon is already a node of
            the graph."""
            exon = Exon(feature.chrom, feature.start, feature.end)
            if exon not in self.exons:
                self.exons.add_node(exon)

        # The merging cannot be done on the full BedTool because we dont want
        # to merge together exons not belonging to the same gene.
        def set_union_exon_length(self):
            """Set *union_exon_length* to the number of positions covered by
            at least one exon of the gene.

            Overlapping exons are merged before their lengths are summed, so
            that shared positions are counted only once."""
            # Overlap is not transitive (A can overlap B, and B overlap C,
            # without A overlapping C), so exons cannot be grouped by testing
            # them against one arbitrary member of a group. Instead, sort them
            # by position and sweep, extending the current merged interval as
            # long as the next exon starts inside it.
            spans = sorted((exon.chrom, exon.start, exon.end) for exon in self.exons)
            total = 0
            in_run = False
            run_chrom = run_start = run_end = None
            for (chrom, start, end) in spans:
                if in_run and chrom == run_chrom and start < run_end:
                    run_end = max(run_end, end)
                else:
                    if in_run:
                        # The previous merged interval is complete
                        total += run_end - run_start
                    (run_chrom, run_start, run_end) = (chrom, start, end)
                    in_run = True
            if in_run:
                total += run_end - run_start
            self.union_exon_length = total
    
    
    def gtf_2_genes_exon_lengths(gtf_filename):
        """Returns a pandas DataFrame where union exon lengths are associated to gene IDs.

        Only the "exon" entries of the gtf file *gtf_filename* are used. They
        are grouped by their "gene_id" attribute, and the union exon length of
        each gene is computed once the whole file has been read.
        The index of the returned DataFrame is named "gene".
        """
        genes = {}
        # The file is kept open while the features are iterated over
        # (presumably BedTool reads from the handle lazily -- TODO confirm),
        # and is closed afterwards in all cases.
        with open(gtf_filename, "r") as gtf_file:
            for feature in BedTool(gtf_file).features():
                # The third field of the entry is the feature type
                if feature[2] != "exon":
                    continue
                gene_id = feature.attrs["gene_id"]
                if gene_id not in genes:
                    genes[gene_id] = Gene(gene_id)
                # A given exon may be registered for several transcripts,
                # hence several gtf entries: Gene.add_exon checks whether
                # the exon is already present before adding it.
                genes[gene_id].add_exon(feature)
        for gene in genes.values():
            gene.set_union_exon_length()
        # NOTE(review): the Series name is a 1-tuple, whereas
        # repeat_bed_2_lengths uses the plain string "union_exon_len".
        # Kept as is, because callers may rely on it -- confirm.
        return pd.DataFrame(pd.Series(
            {gene.gene_id : gene.union_exon_length for gene in genes.values()},
            name=("union_exon_len",)).rename_axis("gene"))
    
    
    def repeat_bed_2_lengths(repeat_bed):
        """Sums, by family, the lengths of the repetitive elements of a bed file.

        The name of an element is expected to be made of the family name,
        a colon, and a number. For instance, the following four elements:
        Simple_repeat|Simple_repeat|(TTTTTTG)n:1
        Simple_repeat|Simple_repeat|(TTTTTTG)n:2
        Simple_repeat|Simple_repeat|(TTTTTTG)n:3
        Simple_repeat|Simple_repeat|(TTTTTTG)n:4
        all belong to the family:
        Simple_repeat|Simple_repeat|(TTTTTTG)n
        Returns a DataFrame indexed by family name (index named "gene"),
        with the summed lengths in the "union_exon_len" column.
        """
        # Columns 1, 2 and 3 of the bed file are start, end and name.
        # The name (third of the selected columns) is used as index.
        coords = pd.read_table(
            repeat_bed,
            header=None,
            usecols=[1, 2, 3],
            index_col=2)
        # Length of each element: end - start
        sizes = (coords[2] - coords[1]).rename("union_exon_len")
        # Family name: everything before the last colon of the element name
        families = [element.rpartition(":")[0] for element in coords.index]
        # The reads assigned to a repeated element can come from any member
        # of its family, hence the sum of the lengths over the family.
        # The family is called "gene" for convenience and compatibility.
        by_element = pd.DataFrame(sizes).assign(gene=families)
        return by_element.groupby("gene").sum()
    
    
    def do_deseq2(cond_names, conditions, counts_data,
                  formula=None, contrast=None, deseq2_args=None):
        """Runs a DESeq2 differential expression analysis.

        *cond_names* are the library names: they index the column data given
        to DESeq2, and are the keys of the returned size factors.
        *conditions* is turned into a pandas DataFrame describing the
        libraries (one row per name in *cond_names*). When it has both a
        "lib" and a "treat" column, a combined "lib_treat" column is added.
        *counts_data* is the count table given to DESeq2 as countData.
        Its index is re-used as index of the results.
        *formula* is the design formula ("~ lib" by default).
        *contrast* is passed as is to the DESeq2 "results" function. It is
        mandatory, because no meaningful default can be built here.
        *deseq2_args* may set "betaPrior", "addMLE" and
        "independentFiltering". Options that are not given default to True.

        Returns a pair: the results as a pandas DataFrame, and a dict
        associating its size factor to each name in *cond_names*.

        Raises ValueError if *contrast* is None. The RRuntimeError is
        re-raised when a DESeq2 step and its fallbacks (if any) fail.
        """
        if contrast is None:
            # There used to be a default based on names that were never defined.
            raise ValueError(
                "do_deseq2 needs an explicit contrast, for instance "
                "StrVector([\"lib\", <mutant>, <reference>]).")
        if formula is None:
            formula = Formula("~ lib")
        # Start from the defaults so that a partial dict can be given
        options = {"betaPrior" : True, "addMLE" : True, "independentFiltering" : True}
        if deseq2_args is not None:
            options.update(deseq2_args)
        col_data = pd.DataFrame(conditions).assign(
            cond_name=pd.Series(cond_names).values).set_index("cond_name")
        # In case we want contrasts between factor combinations
        if ("lib" in col_data.columns) and ("treat" in col_data.columns):
            col_data = col_data.assign(
                lib_treat = ["%s_%s" % (lib, treat) for (lib, treat) in zip(
                    col_data["lib"], col_data["treat"])])
        # http://stackoverflow.com/a/31206596/1878788
        pandas2ri.activate()  # makes some conversions automatic
        dds = deseq2.DESeqDataSetFromMatrix(
            countData=counts_data,
            colData=col_data,
            design=formula)
        # The analysis is decomposed into its 3 steps
        # to have more control on the options.
        # Step 1: size factors
        try:
            dds = deseq2.estimateSizeFactors_DESeqDataSet(dds, type="ratio")
        except RRuntimeError as ratio_err:
            # The fallback is only attempted when every gene has at least one
            # zero. This is tested element-wise: the product of the counts of
            # a gene can overflow.
            if not (counts_data == 0).any(axis=1).all():
                raise
            warnings.warn(
                "Error occurred in estimateSizeFactors:\n%s\n"
                "This is probably because every gene has at least one zero.\n"
                "We will try to use the \"poscounts\" method instead." % ratio_err)
            try:
                dds = deseq2.estimateSizeFactors_DESeqDataSet(dds, type="poscounts")
            except RRuntimeError as poscounts_err:
                warnings.warn(
                    "Error occurred in estimateSizeFactors:\n%s\n"
                    "We give up." % poscounts_err)
                raise
        size_factors = pandas2ri.ri2py(as_df(deseq2.sizeFactors_DESeqDataSet(dds)))
        # Step 2: dispersions, trying the fit types in turn until one succeeds
        fit_types = ("parametric", "local", "mean")
        for (attempt, fit_type) in enumerate(fit_types):
            try:
                dds = deseq2.estimateDispersions_DESeqDataSet(dds, fitType=fit_type)
                break
            except RRuntimeError as err:
                msg = "Error occurred in estimateDispersions:\n%s\n" % err
                if fit_type == fit_types[-1]:
                    warnings.warn(msg + "We give up.")
                    raise
                warnings.warn(
                    msg + "We will try with fitType=\"%s\"." % fit_types[attempt + 1])
        # Step 3: test
        dds = deseq2.nbinomWaldTest(dds, betaPrior=options["betaPrior"])
        res = pandas2ri.ri2py(as_df(deseq2.results(
            dds,
            contrast=contrast,
            addMLE=options["addMLE"],
            independentFiltering=options["independentFiltering"])))
        res.index = counts_data.index
        # The size factor is the first (and presumably only) value
        # of the row of a library: it is taken by position.
        return res, {cond : size_factors.loc[cond].iloc[0] for cond in cond_names}
    
    
    def median_ratio_to_pseudo_ref_size_factors(counts_data):
        """Adapted from DESeq paper (doi:10.1186/gb-2010-11-10-r106)
        All libraries are used to define a pseudo-reference, which has
        the geometric mean across libraries for a given gene in *counts_data*.
        For a given library, the median across genes of the ratios to the
        pseudo-reference is used as size factor.

        *counts_data* has one row per gene and one column per library.
        Only the genes having a strictly positive count in every library are
        used (may be bad for IP: many zeroes expected).
        Returns a Series associating its size factor to each library.
        """
        # The genes are selected element-wise rather than by testing whether
        # the product of their counts is positive: with integer counts, the
        # product across libraries can overflow and wrap around to zero or to
        # a negative value, which would discard highly expressed genes.
        expressed = counts_data[(counts_data > 0).all(axis=1)]
        pseudo_ref = expressed.apply(gmean, axis=1)

        def median_ratio_to_pseudo_ref(col):
            """Median, across genes, of the ratios of a library to the pseudo-reference."""
            return (col / pseudo_ref).median()

        return expressed.apply(median_ratio_to_pseudo_ref, axis=0)
    
    
    def size_factor_correlations(counts_data, summaries, normalizer):
        """Correlates, across libraries, the normalized values with the size factors.

        The size factors of type *normalizer* are computed from *counts_data*
        when *normalizer* is "median_ratio_to_pseudo_ref". Otherwise, they are
        taken from the row of *summaries* named *normalizer*.
        The normalized data are obtained by dividing *counts_data* by these
        size factors.
        Returns, for each row of *counts_data*, the Pearson correlation
        coefficient between its normalized values and the size factors."""
        factors = (
            median_ratio_to_pseudo_ref_size_factors(counts_data)
            if normalizer == "median_ratio_to_pseudo_ref"
            else summaries.loc[normalizer])
        normalized = counts_data / factors
        # pearsonr gives the coefficient first, then the p-value
        return normalized.apply(lambda values: pearsonr(values, factors)[0], axis=1)
    
    
    def plot_norm_correlations(correlations):
        """Draws violin plots of *correlations* with seaborn.

        The violins are not extended beyond the extreme values (cut=0).
        The y axis is labelled as Pearson correlation coefficients."""
        sns.violinplot(data=correlations, cut=0).set_ylabel(
            "Pearson correlation coefficient")
    
    
    def plot_counts_distribution(data, xlabel):
        """Plots the kernel density estimates of *data*, with *xlabel* as x axis label.

        If the plotting fails with a ValueError, a warning reporting the
        dimensions of *data* is issued, and the error is re-raised."""
        # TODO: try to plot with semilog x axis
        try:
            density_axis = data.plot.kde()
        except ValueError:
            (nb_lines, nb_columns) = (len(data), len(data.columns))
            warnings.warn(
                "There seems to be a problem with the data.\n"
                "The data matrix has %d lines and %d columns.\n" % (nb_lines, nb_columns))
            raise
        density_axis.set_xlabel(xlabel)
    
    
    def plot_boxplots(data, ylabel):
        """Draws box plots of *data* in a new 6 x 12 figure.

        The y axis is labelled with *ylabel* and the x tick labels are
        written vertically."""
        box_axis = plt.figure(figsize=(6, 12)).add_subplot(111)
        data.plot.box(ax=box_axis)
        box_axis.set_ylabel(ylabel)
        # Vertical labels
        for tick_label in box_axis.get_xticklabels():
            tick_label.set_rotation(90)
        plt.tight_layout()
    
    
    # Cutoffs in log fold change
    LFC_CUTOFFS = [0.5, 1, 2]
    # Status names derived from the cutoffs, in the same order:
    # "up" or "down" followed by the cutoff
    UP_STATUSES = ["up%s" % lfc_cutoff for lfc_cutoff in LFC_CUTOFFS]
    DOWN_STATUSES = ["down%s" % lfc_cutoff for lfc_cutoff in LFC_CUTOFFS]
    
    
    def status_setter(lfc_cutoffs=None, fold_type="log2FoldChange"):
        """Returns a function determining the regulation status of a row of
        a deseq2 results table.

        *lfc_cutoffs* are the (positive) log fold change cutoffs used to grade
        the statuses. The module-level LFC_CUTOFFS are used by default.
        *fold_type* is the column from which the log fold change is read.
        It can also be "lfcMLE", which is based on uncorrected values.
        This may not be good for genes with low expression levels."""
        if lfc_cutoffs is None:
            lfc_cutoffs = LFC_CUTOFFS
        # Sorted once here, instead of for every row
        decreasing_cutoffs = sorted(lfc_cutoffs, reverse=True)

        def set_status(row):
            """Determines the up- or down-regulation status corresponding to a given
            row of a deseq2 results table.

            The status is "NS" unless the adjusted p-value is below 0.05.
            Otherwise, it is "up" or "down", followed by the highest cutoff
            that the log fold change exceeds (if any)."""
            # Written this way so that a missing (NaN) padj gives "NS"
            if not row["padj"] < 0.05:
                return "NS"
            lfc = row[fold_type]
            if lfc > 0:
                # Start from the highest cutoff,
                # and decrease until below lfc
                for cutoff in decreasing_cutoffs:
                    if lfc > cutoff:
                        return f"up{cutoff}"
                return "up"
            if lfc < 0:
                for cutoff in decreasing_cutoffs:
                    if lfc < -cutoff:
                        return f"down{cutoff}"
                return "down"
            # A null or missing (NaN) fold change has no direction
            return "NS"
        return set_status
    
    
    # res = res.assign(is_DE=res.apply(set_de_status, axis=1))
    def set_de_status(row):
        """Tells whether a gene is differentially expressed ("DE") or not ("NS").

        *row* is a row of a deseq2 results table: the gene is "DE" when
        its adjusted p-value ("padj") is below 0.05."""
        return "DE" if row["padj"] < 0.05 else "NS"
    # Colours associated to the statuses given by set_de_status:
    # black for "DE", pale grey for "NS"
    DE2COLOUR = dict(DE="k", NS="0.85")
    
    
    def plot_lfc_distribution(res, contrast, fold_type=None):
        """Plots the kernel density estimate of the log fold changes found in *res*.

        *contrast* is used to name the plotted data.
        *fold_type* is the attribute of *res* containing the log fold changes.
        It is "log2FoldChange" by default.
        It can also be "lfcMLE", which is based on uncorrected values.
        This may not be good for genes with low expression levels."""
        column = "log2FoldChange" if fold_type is None else fold_type
        # Missing values are not plotted
        fold_changes = getattr(res, column).dropna()
        fold_changes.name = contrast
        density_axis = sns.kdeplot(fold_changes)
        density_axis.set_xlabel(column)
        density_axis.set_ylabel("frequency")
    
    
    def make_status2colour(down_statuses, up_statuses):
        """Associates a colour of the seaborn "coolwarm" palette to each status.

        The statuses are ordered as follows: *down_statuses* in reverse order,
        then "down", "NS" and "up", then *up_statuses*. The palette has as
        many colours as there are statuses, and is assigned in that order.
        Returns a dict whose keys are the statuses."""
        ordered = list(reversed(down_statuses))
        ordered.extend(["down", "NS", "up"])
        ordered = ordered + up_statuses
        palette = sns.color_palette("coolwarm", len(ordered))
        return {status: colour for (status, colour) in zip(ordered, palette)}
    
    
    # Colours associated to the statuses derived from the module-level cutoffs.
    # This is the default mapping used by plot_MA to colour groups.
    STATUS2COLOUR = make_status2colour(DOWN_STATUSES, UP_STATUSES)
    # TODO: use other labelling than logfold or gene lists, i.e. biotype
    def plot_MA(res,
                grouping=None,
                group2colour=None,
                mean_range=None,
                lfc_range=None,
                fold_type=None):
        """Draws a scatter plot of the log fold change against "baseMean"
        (on a log scale) for the deseq2 results table *res*.

        All genes are first plotted in black ("DE") or pale grey ("NS"),
        depending on their adjusted p-value.
        If *grouping* is a column name, the genes are then overlaid with one
        colour per value found in this column. The colours are taken from
        the dict *group2colour* (STATUS2COLOUR by default).
        If *grouping* is a collection of gene identifiers, the genes of *res*
        belonging to it are overlaid instead. *group2colour* must then be
        a (label, colour) pair.
        *mean_range* sets the limits of the x axis.
        *lfc_range* sets the limits of the y axis, but only if all the data
        fit in it. Otherwise, a warning is issued.
        *fold_type* is "log2FoldChange" by default.
        It can also be "lfcMLE", which is based on uncorrected values.
        This may not be good for genes with low expression levels."""
        fig, ax = plt.subplots()
        # Make a column indicating whether the gene is DE or NS
        data = res.assign(is_DE=res.apply(set_de_status, axis=1))
        x_column = "baseMean"
        y_column = "log2FoldChange" if fold_type is None else fold_type
        # First plot the data in grey and black
        for de_status, group in data.groupby("is_DE"):
            group.plot.scatter(
                x=x_column, y=y_column, s=2, logx=True, c=DE2COLOUR[de_status],
                label=f"{de_status} ({len(group)})", ax=ax)
        if grouping is not None:
            if isinstance(grouping, str):
                # Overlay colours based on the "grouping" column
                if group2colour is None:
                    group2colour = STATUS2COLOUR
                for status, group in data.groupby(grouping):
                    group.plot.scatter(
                        x=x_column, y=y_column, s=1, logx=True, c=group2colour[status],
                        label=f"{status} ({len(group)})", ax=ax)
            else:
                # Apply a colour to a list of genes
                (status, colour) = group2colour
                row_indices = data.index.intersection(grouping)
                # The rows are selected by label
                # (.loc replaces .ix, which no longer exists in pandas)
                data.loc[row_indices].plot.scatter(
                    x=x_column, y=y_column, s=1, logx=True, c=colour,
                    label=f"{status} ({len(row_indices)})", ax=ax)
        # Guide lines at log fold changes of 1 and -1
        ax.axhline(y=1, linewidth=0.5, color="0.5", linestyle="dashed")
        ax.axhline(y=-1, linewidth=0.5, color="0.5", linestyle="dashed")
        # TODO: check data basemean range
        if mean_range is not None:
            ax.set_xlim(mean_range)
        if lfc_range is not None:
            (lfc_min, lfc_max) = lfc_range
            lfc_here_min = data[y_column].min()
            lfc_here_max = data[y_column].max()
            # The requested range is not applied if it would hide data
            if (lfc_here_min < lfc_min) or (lfc_here_max > lfc_max):
                warnings.warn(f"Cannot plot {y_column} data ([{lfc_here_min}, {lfc_here_max}]) in requested range ([{lfc_min}, {lfc_max}])\n")
            else:
                ax.set_ylim(lfc_range)
    
    def plot_scatter(data,
                     x_column,
                     y_column,
                     regression=False,
                     grouping=None,
                     group2colour=None,
                     x_range=None,
                     y_range=None):
        """Draws a scatter plot of column *y_column* of *data* against column
        *x_column*, and returns the axis on which it was drawn.

        All points are first plotted in transparent black.
        If *regression* is true, the linear regression line of y on x is
        added, from the lowest to the highest x value.
        If *grouping* is a column name, the points are then overlaid with one
        colour per value found in this column. The colours are taken from
        the dict *group2colour* (by default, from the seaborn "colorblind"
        palette).
        If *grouping* is a collection of row labels, the rows of *data*
        belonging to it are overlaid instead. *group2colour* must then be
        a (label, colour) pair.
        *x_range* and *y_range* set the limits of the corresponding axis, but
        only if all the data fit in them. Otherwise, a warning is issued."""
        fig, ax = plt.subplots()
        # First plot the data in grey
        data.plot.scatter(
            x=x_column, y=y_column,
            s=2, c="black", alpha=0.15, edgecolors='none',
            ax=ax)
        if regression:
            # The fit uses the points for which both coordinates are known.
            # x and y are given separately: letting linregress split one
            # two-column array is deprecated, and ambiguous with two rows.
            complete = data[[x_column, y_column]].dropna()
            linreg = linregress(complete[x_column], complete[y_column])
            a = linreg.slope
            b = linreg.intercept
            def fit(x):
                return (a * x) + b
            min_x = data[x_column].min()
            max_x = data[x_column].max()
            xfit, yfit = (min_x, max_x), (fit(min_x), fit(max_x))
            ax.plot(xfit, yfit, linewidth=0.5, color="0.5", linestyle="dashed")
        # Overlay colour points
        if grouping is not None:
            if isinstance(grouping, str):
                # Determine colours based on the "grouping" column
                if group2colour is None:
                    statuses = data[grouping].unique()
                    group2colour = dict(zip(statuses, sns.color_palette("colorblind", len(statuses))))
                for status, group in data.groupby(grouping):
                    group.plot.scatter(
                        x=x_column, y=y_column, s=1, c=group2colour[status],
                        label=f"{status} ({len(group)})", ax=ax)
            else:
                # Apply a colour to a list of genes
                (status, colour) = group2colour
                row_indices = data.index.intersection(grouping)
                # The rows are selected by label
                # (.loc replaces .ix, which no longer exists in pandas)
                data.loc[row_indices].plot.scatter(
                    x=x_column, y=y_column, s=1, c=colour,
                    label=f"{status} ({len(row_indices)})", ax=ax)
        ax.axhline(y=0, linewidth=0.5, color="0.5", linestyle="dashed")
        ax.axvline(x=0, linewidth=0.5, color="0.5", linestyle="dashed")
        # Set axis limits: a requested range is not applied if it would hide data
        for (column, requested, set_lim) in (
                (x_column, x_range, ax.set_xlim),
                (y_column, y_range, ax.set_ylim)):
            if requested is None:
                continue
            (range_min, range_max) = requested
            here_min = data[column].min()
            here_max = data[column].max()
            if (here_min < range_min) or (here_max > range_max):
                warnings.warn(f"Cannot plot {column} data ([{here_min}, {here_max}]) in requested range ([{range_min}, {range_max}])\n")
            else:
                set_lim(requested)
        return ax