diff --git a/libhts/libhts.py b/libhts/libhts.py
index 24b9e20b65cc903acc951648b34a74bd37ee2060..45b9f6e5e8317b0cc6322d336d22b3fa573feff6 100644
--- a/libhts/libhts.py
+++ b/libhts/libhts.py
@@ -9,6 +9,7 @@ def formatwarning(message, category, filename, lineno, line):
 
 warnings.formatwarning = formatwarning
 
+import numpy as np
 import pandas as pd
 # To compute correlation coefficient, and compute linear regression
 from scipy.stats.stats import pearsonr, linregress
@@ -242,6 +243,27 @@ def do_deseq2(cond_names, conditions, counts_data,
     return res, {cond : size_factors.loc[cond][0] for cond in cond_names}
 
 
+## Not sure this is a good idea...
+# def masked_gmean(a, axis=0, dtype=None):
+#     """Modified from stats.py."""
+#     # Converts the data into a masked array
+#     ma = np.ma.masked_invalid(a)
+#     # Apply gmean
+#     if not isinstance(ma, np.ndarray):
+#         # if not an ndarray object attempt to convert it
+#         log_a = np.log(np.array(ma, dtype=dtype))
+#     elif dtype:
+#         # Must change the default dtype allowing array type
+#         if isinstance(ma, np.ma.MaskedArray):
+#             log_a = np.log(np.ma.asarray(ma, dtype=dtype))
+#         else:
+#             log_a = np.log(np.asarray(ma, dtype=dtype))
+#     else:
+#         log_a = np.log(ma)
+#     return np.exp(log_a.mean(axis=axis))
+
+
+
 def median_ratio_to_pseudo_ref_size_factors(counts_data):
     """Adapted from DESeq paper (doi:10.1186/gb-2010-11-10-r106)
     All libraries are used to define a pseudo-reference, which has
@@ -252,11 +274,20 @@ def median_ratio_to_pseudo_ref_size_factors(counts_data):
     #pseudo_ref = (counts_data + 1).apply(gmean, axis=1) - 1
     # Ignore lines with zeroes instead (may be bad for IP: many zeroes expected):
     pseudo_ref = (counts_data[counts_data.prod(axis=1) > 0]).apply(gmean, axis=1)
+    # Ignore lines with only zeroes
+    # pseudo_ref = (counts_data[counts_data.sum(axis=1) > 0]).apply(masked_gmean, axis=1)
     def median_ratio_to_pseudo_ref(col):
         return (col / pseudo_ref).median()
     #size_factors = counts_data.apply(median_ratio_to_pseudo_ref, axis=0)
-    return counts_data[counts_data.prod(axis=1) > 0].apply(
+    median_ratios = counts_data[counts_data.prod(axis=1) > 0].apply(
         median_ratio_to_pseudo_ref, axis=0)
+    # Not sure fillna(0) is appropriate
+    if any(median_ratios.isna()):
+        msg = "Could not compute median ratios to pseudo reference.\n"
+        warnings.warn(msg)
+        return median_ratios.fillna(1)
+    else:
+        return median_ratios
 
 
 def size_factor_correlations(counts_data, summaries, normalizer):
@@ -442,7 +473,8 @@ def plot_scatter(data,
                  grouping=None,
                  group2colour=None,
                  x_range=None,
-                 y_range=None):
+                 y_range=None,
+                 axes_style=None):
     fig, ax = plt.subplots()
     # First plot the data in grey
     data.plot.scatter(
@@ -468,7 +500,7 @@ def plot_scatter(data,
     if group2colour is None:
         statuses = data[grouping].unique()
         group2colour = dict(zip(statuses, sns.color_palette("colorblind", len(statuses))))
-    for status, group in data.groupby(grouping):
+    for (status, group) in data.groupby(grouping):
         group.plot.scatter(
             x=x_column, y=y_column, s=1, c=group2colour[status],
             label=f"{status} ({len(group)})", ax=ax)
@@ -486,8 +518,12 @@ def plot_scatter(data,
                 raise
         else:
             warnings.warn(f"Nothing to plot for {status}\n")
-    ax.axhline(y=0, linewidth=0.5, color="0.5", linestyle="dashed")
-    ax.axvline(x=0, linewidth=0.5, color="0.5", linestyle="dashed")
+    if axes_style is None:
+        axes_style = {"linewidth": 0.5, "color": "0.5", "linestyle": "dashed"}
+    ax.axhline(y=0, **axes_style)
+    ax.axvline(x=0, **axes_style)
+    # ax.axhline(y=0, linewidth=0.5, color="0.5", linestyle="dashed")
+    # ax.axvline(x=0, linewidth=0.5, color="0.5", linestyle="dashed")
     # Set axis limits
     if x_range is not None:
         (x_min, x_max) = x_range