diff --git a/libhts/__init__.py b/libhts/__init__.py index efc99af6c65d15876f7e985bdf8a184513540c15..c342524162c4ee06ac8f6bd45c2c1a0bf773ade5 100644 --- a/libhts/__init__.py +++ b/libhts/__init__.py @@ -1 +1,7 @@ -from .libhts import do_deseq2, gtf_2_genes_exon_lengths, median_ratio_to_pseudo_ref_size_factors, plot_boxplots, plot_counts_distribution, plot_lfc_distribution, plot_MA, plot_norm_correlations, plot_paired_scatters, plot_scatter, repeat_bed_2_lengths, size_factor_correlations, spikein_gtf_2_lengths, status_setter +from .libhts import ( + do_deseq2, gtf_2_genes_exon_lengths, median_ratio_to_pseudo_ref_size_factors, + plot_boxplots, plot_counts_distribution, plot_histo, + plot_lfc_distribution, plot_MA, + plot_norm_correlations, plot_paired_scatters, plot_scatter, + repeat_bed_2_lengths, size_factor_correlations, + spikein_gtf_2_lengths, status_setter) diff --git a/libhts/libhts.py b/libhts/libhts.py index a3d89b2c1a7bd0df786a5496abce59007b20a83d..9b901e256cf06034e2296de76c027fc7b4adf6d4 100644 --- a/libhts/libhts.py +++ b/libhts/libhts.py @@ -1,6 +1,8 @@ +from math import floor, ceil from functools import reduce from re import sub import warnings +from libworkflows import texscape def formatwarning(message, category, filename, lineno, line): @@ -326,8 +328,7 @@ def plot_norm_correlations(correlations): #ax2.set_xlabel("Pearson correlation coefficient") usetex = mpl.rcParams.get("text.usetex", False) if usetex: - correlations.columns = [sub( - "_", "\_", colname) for colname in correlations.columns] + correlations.columns = [texscape(colname) for colname in correlations.columns] axis = sns.violinplot(data=correlations, cut=0) axis.set_ylabel("Pearson correlation coefficient") @@ -339,8 +340,8 @@ def plot_counts_distribution(data, xlabel): #axis.legend(ncol=len(REPS)) usetex = mpl.rcParams.get("text.usetex", False) if usetex: - xlabel = sub("_", r"\_", xlabel) - data.columns = [sub("_", "\_", colname) for colname in data.columns] + xlabel = texscape(xlabel) 
def plot_histo(outfile, data, title=None):
    """Draw *data* as a bar chart and save the figure to *outfile*.

    Bars are drawn row by row: the row's index label is the x position
    (axis labelled "read length") and the row's values are the bar heights
    (axis labelled "number of reads").

    Args:
        outfile: Destination handed unchanged to ``Figure.savefig``.
        data: Table whose index labels are numeric (0.5 is subtracted from
            the first label and added to the last one to set the x-limits).
            NOTE(review): looks like a pandas DataFrame of read counts
            indexed by increasing read length -- confirm against callers.
        title: Optional figure title. ``None`` (the default) draws no title.

    Returns:
        None. The figure is written to *outfile*.
    """
    usetex = mpl.rcParams.get("text.usetex", False)
    if usetex:
        # Escape the column names on a copy so that the caller's table is
        # not renamed behind its back.
        data = data.copy()
        data.columns = [texscape(colname) for colname in data.columns]
        # The title is optional: escaping None would raise a TypeError.
        # texscape is used (rather than an ad-hoc underscore substitution)
        # to stay consistent with the other plotting helpers of this module.
        if title is not None:
            title = texscape(title)

    fig = plt.figure(figsize=(15, 7))
    ax = fig.add_subplot(111)
    # An empty table has no first/last index label; leave default limits.
    if len(data.index) > 0:
        # Half a unit of margin on each side so the outermost bars,
        # which are centred on their index label, are fully visible.
        ax.set_xlim([data.index[0] - 0.5, data.index[-1] + 0.5])

    bar_width = 0.8
    for (read_len, count) in data.iterrows():
        ax.bar(read_len, count, align="center", width=bar_width)

    # No bar carries a label, so only draw a legend if some labelled artist
    # exists; an unconditional call merely triggers a matplotlib warning.
    handles, _ = ax.get_legend_handles_labels()
    if handles:
        ax.legend()

    ax.set_xticks(data.index)
    ax.set_xticklabels(data.index)
    ax.set_xlabel("read length")
    ax.set_ylabel("number of reads")
    plt.setp(ax.get_xticklabels(), rotation=90)
    if title is not None:
        ax.set_title(title)
    fig.savefig(outfile)
res.assign(is_DE=res.apply(set_de_status, axis=1)) x_column = "baseMean" + data = data.assign(logx=np.log10(data[x_column])) if fold_type is None: y_column = "log2FoldChange" else: y_column = fold_type + usetex = mpl.rcParams.get("text.usetex", False) + def scatter_group(group, label, colour, size=1): + """Plots the data in *group* on the scatterplot.""" + if usetex: + label = texscape(label) + group.plot.scatter( + #x=x_column, + x="logx", + y=y_column, + s=size, + #logx=True, + c=colour, + label=label, ax=ax) + if usetex: + data.columns = [texscape(colname) for colname in data.columns] + y_column = texscape(y_column) + de_status_column = "is\_DE" + else: + de_status_column = "is_DE" # First plot the data in grey and black - for de_status, group in data.groupby("is_DE"): - group.plot.scatter(x=x_column, y=y_column, s=2, logx=True, c=DE2COLOUR[de_status], label=f"{de_status} ({len(group)})", ax=ax) + for (de_status, group) in data.groupby(de_status_column): + label = f"{de_status} ({len(group)})" + colour = DE2COLOUR[de_status] + scatter_group(group, label, colour, size=2) if grouping is not None: if isinstance(grouping, str): # Overlay colours based on the "grouping" column if group2colour is None: group2colour = STATUS2COLOUR for status, group in data.groupby(grouping): - group.plot.scatter( - x=x_column, y=y_column, s=1, logx=True, c=group2colour[status], - label=f"{status} ({len(group)})", ax=ax) + label = f"{status} ({len(group)})" + colour = group2colour[status] + scatter_group(group, label, colour) else: (status, colour) = group2colour row_indices = data.index.intersection(grouping) try: - data.ix[row_indices].plot.scatter( - x=x_column, y=y_column, s=1, logx=True, c=colour, - label=f"{status} ({len(row_indices)})", ax=ax) + label=f"{status} ({len(row_indices)})" + scatter_group(data.ix[row_indices], label, colour) except ValueError as e: if str(e) != "scatter requires x column to be numeric": print(data.ix[row_indices]) @@ -485,7 +537,8 @@ def plot_MA(res, 
ax.axhline(y=-1, linewidth=0.5, color="0.5", linestyle="dashed") # TODO: check data basemean range if mean_range is not None: - ax.set_xlim(mean_range) + # ax.set_xlim(mean_range) + ax.set_xlim(np.log10(mean_range)) if lfc_range is not None: (lfc_min, lfc_max) = lfc_range lfc_here_min = getattr(data, y_column).min() @@ -494,6 +547,15 @@ def plot_MA(res, warnings.warn(f"Cannot plot {y_column} data ([{lfc_here_min}, {lfc_here_max}]) in requested range ([{lfc_min}, {lfc_max}])\n") else: ax.set_ylim(lfc_range) + # https://stackoverflow.com/a/24867320/1878788 + x_ticks = np.arange( + floor(np.ma.masked_invalid(data["logx"]).min()), + ceil(np.ma.masked_invalid(data["logx"]).max()), + 1) + x_ticklabels = [r"$10^{{{}}}$".format(tick) for tick in x_ticks] + plt.xticks(x_ticks, x_ticklabels) + ax.set_xlabel(x_column) + def plot_scatter(data, x_column, @@ -579,8 +641,8 @@ def plot_paired_scatters(data, columns=None, hue=None, log_log=False): columns = data.columns usetex = mpl.rcParams.get("text.usetex", False) if usetex: - data.columns = [sub("_", "\_", colname) for colname in data.columns] - columns = [sub("_", "\_", colname) for colname in columns] + data.columns = [texscape(colname) for colname in data.columns] + columns = [texscape(colname) for colname in columns] g = sns.PairGrid(data, vars=columns, hue=hue, size=8) #g.map_offdiag(plt.scatter, marker=".") g.map_lower(plt.scatter, marker=".")