diff --git a/libhts/__init__.py b/libhts/__init__.py index 3d089bba0c487d9853f2a98da83e4e10a62bb05a..2da39117da59070c590205818dfdcb687e918369 100644 --- a/libhts/__init__.py +++ b/libhts/__init__.py @@ -1 +1 @@ -from .libhts import do_deseq2, gtf_2_genes_exon_lengths, median_ratio_to_pseudo_ref_size_factors, plot_boxplots, plot_counts_distribution, plot_lfc_distribution, plot_MA, plot_norm_correlations, repeat_bed_2_lengths, size_factor_correlations, status_setter +from .libhts import do_deseq2, gtf_2_genes_exon_lengths, median_ratio_to_pseudo_ref_size_factors, plot_boxplots, plot_counts_distribution, plot_lfc_distribution, plot_MA, plot_norm_correlations, plot_scatter, repeat_bed_2_lengths, size_factor_correlations, status_setter diff --git a/libhts/libhts.py b/libhts/libhts.py index 7b23e8a19c4f3155c1b7b149979dbe06691b7c6e..fa5e6e168c07e00995e95e6eef170bd69114f76a 100644 --- a/libhts/libhts.py +++ b/libhts/libhts.py @@ -10,8 +10,8 @@ def formatwarning(message, category, filename, lineno, line): warnings.formatwarning = formatwarning import pandas as pd -# To compute correlation coefficient -from scipy.stats.stats import pearsonr +# To compute correlation coefficient, and compute linear regression +from scipy.stats.stats import pearsonr, linregress # To compute geometric mean from scipy.stats.mstats import gmean import matplotlib.pyplot as plt @@ -150,6 +150,7 @@ def do_deseq2(cond_names, conditions, counts_data, if formula is None: formula = Formula("~ lib") if contrast is None: + # FIXME: MUT and REF are not defined contrast = StrVector(["lib", MUT, REF]) if deseq2_args is None: deseq2_args = {"betaPrior" : True, "addMLE" : True, "independentFiltering" : True} @@ -374,22 +375,27 @@ def plot_MA(res, This may not be good for genes with low expression levels.""" fig, ax = plt.subplots() # Make a column indicating whether the gene is DE or NS - res = res.assign(is_DE=res.apply(set_de_status, axis=1)) + data = res.assign(is_DE=res.apply(set_de_status, axis=1)) if fold_type is None: fold_type = "log2FoldChange" # First plot the data in grey and black - for de_status, group in res.groupby("is_DE"): + for de_status, group in data.groupby("is_DE"): group.plot.scatter(x="baseMean", y=fold_type, s=2, logx=True, c=DE2COLOUR[de_status], label=f"{de_status} ({len(group)})", ax=ax) if grouping is not None: if isinstance(grouping, str): # Overlay colours based on the "grouping" column if group2colour is None: group2colour = STATUS2COLOUR - for status, group in res.groupby(grouping): - group.plot.scatter(x="baseMean", y=fold_type, s=1, logx=True, c=group2colour[status], label=f"{status} ({len(group)})", ax=ax) + for status, group in data.groupby(grouping): + group.plot.scatter( + x="baseMean", y=fold_type, s=1, logx=True, c=group2colour[status], + label=f"{status} ({len(group)})", ax=ax) else: (status, colour) = group2colour - res.ix[grouping].plot.scatter(x="baseMean", y=fold_type, s=1, logx=True, c=colour, label=f"{status} ({len(grouping)})", ax=ax) + row_indices = data.index.intersection(grouping) + data.ix[row_indices].plot.scatter( + x="baseMean", y=fold_type, s=1, logx=True, c=colour, + label=f"{status} ({len(row_indices)})", ax=ax) ax.axhline(y=1, linewidth=0.5, color="0.5", linestyle="dashed") ax.axhline(y=-1, linewidth=0.5, color="0.5", linestyle="dashed") # TODO: check data basemean range @@ -397,10 +403,71 @@ def plot_MA(res, ax.set_xlim(mean_range) if lfc_range is not None: (lfc_min, lfc_max) = lfc_range - lfc_here_min = getattr(res, fold_type).min() - lfc_here_max = getattr(res, fold_type).max() + lfc_here_min = getattr(data, fold_type).min() + lfc_here_max = getattr(data, fold_type).max() if (lfc_here_min < lfc_min) or (lfc_here_max > lfc_max): warnings.warn(f"Cannot plot {fold_type} data ([{lfc_here_min}, {lfc_here_max}]) in requested range ([{lfc_min}, {lfc_max}])\n") else: ax.set_ylim(lfc_range) +def plot_scatter(data, + x_column, + y_column, + regression=False, + grouping=None, + group2colour=None, + x_range=None, + y_range=None): + fig, ax = plt.subplots() + # First plot the data in grey + data.plot.scatter(x=x_column, y=y_column, s=2, c="lightgray", ax=ax) + if regression: + linreg = linregress(data[[x_column, y_column]].dropna()) + a = linreg.slope + b = linreg.intercept + def fit(x): + return (a * x) + b + min_x = data[[x_column]].min()[0] + max_x = data[[x_column]].max()[0] + min_y = fit(min_x) + max_y = fit(max_x) + xfit, yfit = (min_x, max_x), (min_y, max_y) + ax.plot(xfit, yfit, linewidth=0.5, color="0.5", linestyle="dashed") + # Overlay colour points + if grouping is not None: + if isinstance(grouping, str): + # Determine colours based on the "grouping" column + if group2colour is None: + statuses = data[grouping].unique() + group2colour = dict(zip(statuses, sns.color_palette("colorblind", len(statuses)))) + for status, group in data.groupby(grouping): + group.plot.scatter( + x=x_column, y=y_column, s=1, c=group2colour[status], + label=f"{status} ({len(group)})", ax=ax) + else: + # Apply a coulour to a list of genes + (status, colour) = group2colour + row_indices = data.index.intersection(grouping) + data.ix[row_indices].plot.scatter( + x=x_column, y=y_column, s=1, c=colour, + label=f"{status} ({len(row_indices)})", ax=ax) + ax.axhline(y=0, linewidth=0.5, color="0.5", linestyle="dashed") + ax.axvline(x=0, linewidth=0.5, color="0.5", linestyle="dashed") + # Set axis limits + if x_range is not None: + (x_min, x_max) = x_range + x_here_min = getattr(data, x_column).min() + x_here_max = getattr(data, x_column).max() + if (x_here_min < x_min) or (x_here_max > x_max): + warnings.warn(f"Cannot plot {x_column} data ([{x_here_min}, {x_here_max}]) in requested range ([{x_min}, {x_max}])\n") + else: + ax.set_xlim(x_range) + if y_range is not None: + (y_min, y_max) = y_range + y_here_min = getattr(data, y_column).min() + y_here_max = getattr(data, y_column).max() + if (y_here_min < y_min) or (y_here_max > y_max): + warnings.warn(f"Cannot plot {y_column} data ([{y_here_min}, {y_here_max}]) in requested range ([{y_min}, {y_max}])\n") + else: + ax.set_ylim(y_range) + return ax