diff --git a/libhts/libhts.py b/libhts/libhts.py index 0ff19b28eb7cfbb4f216cba75d54f55d7c1a53a4..8dc859bb78a49c5ab8cd8d8124cbae85247a6b8c 100644 --- a/libhts/libhts.py +++ b/libhts/libhts.py @@ -14,18 +14,8 @@ # along with this program. If not, see <https://www.gnu.org/licenses/>. from math import floor, ceil, sqrt, log from functools import reduce -from re import sub +# from re import sub import warnings -from libworkflows import texscape - - -def formatwarning(message, category, filename, lineno, line): - """Used to format warning messages.""" - return "%s:%s: %s: %s\n" % (filename, lineno, category.__name__, message) - - -warnings.formatwarning = formatwarning - import numpy as np import pandas as pd # To compute correlation coefficient, and compute linear regression @@ -56,9 +46,21 @@ import seaborn as sns from pybedtools import BedTool import pyBigWig import networkx as nx +from libworkflows import texscape + + +def formatwarning( + message, category, filename, lineno, line): # pylint: disable=W0613 + """Used to format warning messages.""" + return "%s:%s: %s: %s\n" % (filename, lineno, category.__name__, message) -class Exon(object): +warnings.formatwarning = formatwarning + + +# This might represent any type of genomic interval. +class Exon(): + """Object representing an exon.""" __slots__ = ("chrom", "start", "end") def __init__(self, chrom, start, end): self.chrom = chrom @@ -66,11 +68,17 @@ class Exon(object): self.end = end def overlap(self, other): + """ + Tell whether *self* and *other* overlap. + """ if self.chrom != other.chrom: return False return (self.start <= other.start < self.end) or (other.start <= self.start < other.end) def merge(self, other): + """ + Create a new Exon object by merging *self* with *other*. + """ # Not necessary: can be indirectly linked #assert overlap(self, other) return Exon(self.chrom, min(self.start, other.start), max(self.end, other.end)) @@ -78,10 +86,10 @@ class Exon(object): def __len__(self): return self.end - self.start -overlap = Exon.overlap -merge = Exon.merge +OVERLAP = Exon.overlap +MERGE = Exon.merge -class Gene(object): +class Gene(): """This object contains information obtained from a gtf file.""" __slots__ = ("gene_id", "exons", "union_exon_length") def __init__(self, gene_id): @@ -96,6 +104,10 @@ class Gene(object): # self.transcripts[the_id] = feature def add_exon(self, feature): + """ + Add one Exon object to the exon graph based in the information in gtf + information *feature*. + """ #the_id = feature.attrs["exon_id"] #assert the_id not in self.exons #self.exons[the_id] = feature @@ -117,12 +129,13 @@ class Gene(object): # len, BedTool(self.exons.values()).merge().features())) #self.union_exon_length = 0 # We group nodes that overlap, and merge them - #overlapping_exons = nx.quotient_graph(self.exons, overlap) + #overlapping_exons = nx.quotient_graph(self.exons, OVERLAP) #for node in overlapping_exons.nodes(): - # mex = reduce(merge, node) + # mex = reduce(MERGE, node) # self.union_exon_length += len(mex) self.union_exon_length = sum((len(reduce( - merge, node)) for node in nx.quotient_graph(self.exons, overlap).nodes())) + MERGE, node)) for node in nx.quotient_graph( + self.exons, OVERLAP).nodes())) def gtf_2_genes_exon_lengths(gtf_filename): @@ -191,7 +204,9 @@ def spikein_gtf_2_lengths(spikein_gtf): name=("union_exon_len")).rename_axis("gene")) -def id_list_gtf2bed(identifiers, gtf_filename, feature_type="transcript", id_kwd="gene_id"): +def id_list_gtf2bed( + identifiers, gtf_filename, + feature_type="transcript", id_kwd="gene_id"): """ Extract bed coordinates of an iterable of identifiers from a gtf file. @@ -205,16 +220,17 @@ def id_list_gtf2bed(identifiers, gtf_filename, feature_type="transcript", id_kwd """ if identifiers: ids = set(identifiers) + def feature_filter(feature): return feature[2] == feature_type and feature[id_kwd] in ids gtf = BedTool(gtf_filename) return gtf.filter(feature_filter) - else: - # https://stackoverflow.com/a/13243870/1878788 - def empty_bed_generator(): - return - yield - return empty_bed_generator() + + # https://stackoverflow.com/a/13243870/1878788 + def empty_bed_generator(): + return + yield # pylint: disable=W0101 + return empty_bed_generator() def make_empty_bigwig(filename, chrom_sizes): @@ -236,15 +252,17 @@ def make_empty_bigwig(filename, chrom_sizes): ################# # Bowtie2 stuff # ################# -def zero(value): +def zero(value): # pylint: disable=W0613 + """Constant zero.""" return 0 def identity(value): + """Identity function.""" return value -bowtie2_function_selector = { +BOWTIE2_FUNCTION_SELECTOR = { "C": zero, "L": identity, "S": sqrt, @@ -270,7 +288,8 @@ def make_seeding_function(seeding_string): [func_type, constant, coeff] = interval_string.split(",") constant = float(constant) coeff = float(coeff) - func_type = bowtie2_function_selector[func_type] + func_type = BOWTIE2_FUNCTION_SELECTOR[func_type] + def seeding_function(read_len): interval = floor(constant + (coeff * func_type(read_len))) seeds = [] @@ -289,7 +308,7 @@ def aligner2min_mapq(aligner, wildcards): What minimal MAPQ value should a read have to be considered uniquely mapped? See <https://sequencing.qcfail.com/articles/mapq-values-are-really-useful-but-their-implementation-is-a-mess/>. - """ + """ # pylint: disable=C0301 mapping_type = None try: mapping_type = wildcards.mapping_type @@ -308,15 +327,13 @@ def aligner2min_mapq(aligner, wildcards): if mapping_type is None or mapping_type.startswith("unique_"): if aligner == "hisat2": return "-Q 60" - elif aligner == "bowtie2": + if aligner == "bowtie2": return "-Q 23" - else: - raise NotImplementedError(f"{aligner} not handled (yet?)") - else: - return "" + raise NotImplementedError(f"{aligner} not handled (yet?)") + return "" -## Not sure this is a good idea... +# Not sure this is a good idea... # def masked_gmean(a, axis=0, dtype=None): # """Modified from stats.py.""" # # Converts the data into a masked array @@ -336,7 +353,6 @@ def aligner2min_mapq(aligner, wildcards): # return np.exp(log_a.mean(axis=axis)) - def median_ratio_to_pseudo_ref_size_factors(counts_data): """Adapted from DESeq paper (doi:10.1186/gb-2010-11-10-r106) All libraries are used to define a pseudo-reference, which has @@ -344,14 +360,15 @@ def median_ratio_to_pseudo_ref_size_factors(counts_data): For a given library, the median across genes of the ratios to the pseudo-reference is used as size factor.""" # Add pseudo-count to compute the geometric mean, then remove it - #pseudo_ref = (counts_data + 1).apply(gmean, axis=1) - 1 + # pseudo_ref = (counts_data + 1).apply(gmean, axis=1) - 1 # Ignore lines with zeroes instead (may be bad for IP: many zeroes expected): pseudo_ref = (counts_data[counts_data.prod(axis=1) > 0]).apply(gmean, axis=1) # Ignore lines with only zeroes # pseudo_ref = (counts_data[counts_data.sum(axis=1) > 0]).apply(masked_gmean, axis=1) + def median_ratio_to_pseudo_ref(col): return (col / pseudo_ref).median() - #size_factors = counts_data.apply(median_ratio_to_pseudo_ref, axis=0) + # size_factors = counts_data.apply(median_ratio_to_pseudo_ref, axis=0) median_ratios = counts_data[counts_data.prod(axis=1) > 0].apply( median_ratio_to_pseudo_ref, axis=0) # Not sure fillna(0) is appropriate @@ -359,8 +376,7 @@ def median_ratio_to_pseudo_ref_size_factors(counts_data): msg = "Could not compute median ratios to pseudo reference.\n" warnings.warn(msg) return median_ratios.fillna(1) - else: - return median_ratios + return median_ratios def size_factor_correlations(counts_data, summaries, normalizer): @@ -371,18 +387,22 @@ def size_factor_correlations(counts_data, summaries, normalizer): size_factors = median_ratio_to_pseudo_ref_size_factors(counts_data) else: size_factors = summaries.loc[normalizer] - #by_norm = counts_data / size_factors + # by_norm = counts_data / size_factors + def compute_pearsonr_with_size_factor(row): return pearsonr(row, size_factors)[0] - #return by_norm.apply(compute_pearsonr_with_size_factor, axis=1) + # return by_norm.apply(compute_pearsonr_with_size_factor, axis=1) return (counts_data / size_factors).apply(compute_pearsonr_with_size_factor, axis=1) def plot_norm_correlations(correlations): - #fig, (ax1, ax2) = plt.subplots(nrows=2, ncols=1, sharex=True) - #correlations.plot.kde(ax=ax1) - #sns.violinplot(data=correlations, orient="h", ax=ax2) - #ax2.set_xlabel("Pearson correlation coefficient") + """ + Make violin plots to represent data in *correlations*. + """ + # fig, (ax1, ax2) = plt.subplots(nrows=2, ncols=1, sharex=True) + # correlations.plot.kde(ax=ax1) + # sns.violinplot(data=correlations, orient="h", ax=ax2) + # ax2.set_xlabel("Pearson correlation coefficient") usetex = mpl.rcParams.get("text.usetex", False) if usetex: correlations.columns = [texscape(colname) for colname in correlations.columns] @@ -391,17 +411,21 @@ def plot_norm_correlations(correlations): def plot_counts_distribution(data, xlabel): + """ + Plot a kernel density estimate of the distribution of counts in *data*. + """ # TODO: try to plot with semilog x axis - #axis = data.plot.kde(legend=None) - #axis.set_xlabel(xlabel) - #axis.legend(ncol=len(REPS)) + # axis = data.plot.kde(legend=None) + # axis.set_xlabel(xlabel) + # axis.legend(ncol=len(REPS)) usetex = mpl.rcParams.get("text.usetex", False) if usetex: xlabel = texscape(xlabel) data.columns = [texscape(colname) for colname in data.columns] try: axis = data.plot.kde() - except ValueError as e: + # except ValueError as e: + except ValueError: msg = "".join([ "There seems to be a problem with the data.\n", "The data matrix has %d lines and %d columns.\n" % (len(data), len(data.columns))]) @@ -411,16 +435,18 @@ def plot_counts_distribution(data, xlabel): def plot_histo(outfile, data, title=None): - fig = plt.figure(figsize=(15,7)) - ax = fig.add_subplot(111) - ax.set_xlim([data.index[0] - 0.5, data.index[-1] + 0.5]) - #ax.set_ylim([0, 100]) + """ + Plot a histogram of *data* in file *outfile*. + """ + fig = plt.figure(figsize=(15, 7)) + axis = fig.add_subplot(111) + axis.set_xlim([data.index[0] - 0.5, data.index[-1] + 0.5]) + # axis.set_ylim([0, 100]) bar_width = 0.8 - #letter2legend = dict(zip("ACGT", "ACGT")) + # letter2legend = dict(zip("ACGT", "ACGT")) usetex = mpl.rcParams.get("text.usetex", False) if usetex: data.columns = [texscape(colname) for colname in data.columns] - #title = sub("_", r"\_", title) title = texscape(title) for (read_len, count) in data.iterrows(): plt.bar( @@ -428,20 +454,21 @@ def plot_histo(outfile, data, title=None): count, align="center", width=bar_width) - #color=letter2colour[letter], - #label=letter2legend[letter]) - ax.legend() - ax.set_xticks(data.index) - ax.set_xticklabels(data.index) - ax.set_xlabel("read length") - ax.set_ylabel("number of reads") - plt.setp(ax.get_xticklabels(), rotation=90) + # color=letter2colour[letter], + # label=letter2legend[letter]) + axis.legend() + axis.set_xticks(data.index) + axis.set_xticklabels(data.index) + axis.set_xlabel("read length") + axis.set_ylabel("number of reads") + plt.setp(axis.get_xticklabels(), rotation=90) if title is not None: plt.title(title) ## debug try: plt.savefig(outfile) - except RuntimeError as e: + # except RuntimeError as e: + except RuntimeError: print(data.index) print(title) raise @@ -449,6 +476,9 @@ def plot_histo(outfile, data, title=None): def plot_boxplots(data, ylabel): + """ + Plot boxplots of data in *data* using *ylabel* as y-axis label. + """ fig = plt.figure(figsize=(6, 12)) axis = fig.add_subplot(111) usetex = mpl.rcParams.get("text.usetex", False) @@ -465,90 +495,6 @@ def plot_boxplots(data, ylabel): ############ # DE stuff # ############ -# def do_deseq2(cond_names, conditions, counts_data, -# formula=None, contrast=None, deseq2_args=None): -# """Runs a DESeq2 differential expression analysis.""" -# if formula is None: -# formula = Formula("~ lib") -# if contrast is None: -# # FIXME: MUT and REF are not defined -# # Maybe just make (formula and) contrast mandatory -# contrast = StrVector(["lib", MUT, REF]) -# if deseq2_args is None: -# deseq2_args = {"betaPrior" : True, "addMLE" : True, "independentFiltering" : True} -# col_data = pd.DataFrame(conditions).assign( -# cond_name=pd.Series(cond_names).values).set_index("cond_name") -# # In case we want contrasts between factor combinations -# if ("lib" in col_data.columns) and ("treat" in col_data.columns): -# col_data = col_data.assign( -# lib_treat = ["%s_%s" % (lib, treat) for (lib, treat) in zip( -# col_data["lib"], col_data["treat"])]) -# # http://stackoverflow.com/a/31206596/1878788 -# pandas2ri.activate() # makes some conversions automatic -# # r_counts_data = pandas2ri.py2ri(counts_data) -# # r_col_data = pandas2ri.py2ri(col_data) -# # r.DESeqDataSetFromMatrix(countData=r_counts_data, colData=r_col_data, design=Formula("~lib")) -# dds = deseq2.DESeqDataSetFromMatrix( -# countData=counts_data, -# colData=col_data, -# design=formula) -# # dds = deseq2.DESeq(dds, betaPrior=deseq2_args["betaPrior"]) -# # Decompose into the 3 steps to have more control on the options -# try: -# dds = deseq2.estimateSizeFactors_DESeqDataSet(dds, type="ratio") -# except RRuntimeError as e: -# if sum(counts_data.prod(axis=1)) == 0: -# msg = "".join(["Error occurred in estimateSizeFactors:\n%s\n" % e, -# "This is probably because every gene has at least one zero.\n", -# "We will try to use the \"poscounts\" method instead."]) -# warnings.warn(msg) -# try: -# dds = deseq2.estimateSizeFactors_DESeqDataSet(dds, type="poscounts") -# except RRuntimeError as e: -# msg = "".join(["Error occurred in estimateSizeFactors:\n%s\n" % e, -# "We give up."]) -# warnings.warn(msg) -# raise -# #print(counts_data.dtypes) -# #print(counts_data.columns) -# #print(len(counts_data)) -# #raise -# else: -# raise -# size_factors = pandas2ri.ri2py(as_df(deseq2.sizeFactors_DESeqDataSet(dds))) -# #for cond in cond_names: -# # #s = size_factors.loc[cond][0] -# # #(*_, s) = size_factors.loc[cond] -# #pd.DataFrame({cond : size_factors.loc[cond][0] for cond in COND_NAMES}, index=('size_factor',)) -# try: -# dds = deseq2.estimateDispersions_DESeqDataSet(dds, fitType="parametric") -# except RRuntimeError as e: -# msg = "".join(["Error occurred in estimateDispersions:\n%s\n" % e, -# "We will try with fitType=\"local\"."]) -# warnings.warn(msg) -# try: -# dds = deseq2.estimateDispersions_DESeqDataSet(dds, fitType="local") -# except RRuntimeError as e: -# msg = "".join(["Error occurred in estimateDispersions:\n%s\n" % e, -# "We will try with fitType=\"mean\"."]) -# warnings.warn(msg) -# try: -# dds = deseq2.estimateDispersions_DESeqDataSet(dds, fitType="mean") -# except RRuntimeError as e: -# msg = "".join(["Error occurred in estimateDispersions:\n%s\n" % e, -# "We give up."]) -# warnings.warn(msg) -# raise -# dds = deseq2.nbinomWaldTest(dds, betaPrior=deseq2_args["betaPrior"]) -# res = pandas2ri.ri2py(as_df(deseq2.results( -# dds, -# contrast=contrast, -# addMLE=deseq2_args["addMLE"], -# independentFiltering=deseq2_args["independentFiltering"]))) -# res.index = counts_data.index -# return res, {cond : size_factors.loc[cond][0] for cond in cond_names} - - # Cutoffs in log fold change LFC_CUTOFFS = [0.5, 1, 2] UP_STATUSES = [f"up{cutoff}" for cutoff in LFC_CUTOFFS] @@ -564,7 +510,7 @@ def status_setter(lfc_cutoffs=None, fold_type="log2FoldChange"): """Determines the up- or down-regulation status corresponding to a given row of a deseq2 results table.""" if row["padj"] < 0.05: - #if row["log2FoldChange"] > 0: + # if row["log2FoldChange"] > 0: lfc = row[fold_type] if lfc > 0: # Start from the highest cutoff, @@ -573,13 +519,11 @@ def status_setter(lfc_cutoffs=None, fold_type="log2FoldChange"): if lfc > cutoff: return f"up{cutoff}" return "up" - else: - for cutoff in sorted(lfc_cutoffs, reverse=True): - if lfc < -cutoff: - return f"down{cutoff}" - return "down" - else: - return "NS" + for cutoff in sorted(lfc_cutoffs, reverse=True): + if lfc < -cutoff: + return f"down{cutoff}" + return "down" + return "NS" return set_status @@ -589,20 +533,19 @@ def set_de_status(row): based on the adjusted p-value in row of a deseq2 results table.""" if row["padj"] < 0.05: return "DE" - else: - return "NS" + return "NS" DE2COLOUR = { # black - "DE" : "k", + "DE": "k", # pale grey - "NS" : "0.85"} + "NS": "0.85"} def plot_lfc_distribution(res, contrast, fold_type=None): """*fold_type* is "log2FoldChange" by default. It can also be "lfcMLE", which is based on uncorrected values. This may not be good for genes with low expression levels.""" - #lfc = res.lfcMLE.dropna() + # lfc = res.lfcMLE.dropna() if fold_type is None: fold_type = "log2FoldChange" lfc = getattr(res, fold_type).dropna() @@ -618,11 +561,16 @@ def plot_lfc_distribution(res, contrast, fold_type=None): def make_status2colour(down_statuses, up_statuses): + """ + Generate a dictionary associating colours to statuses. + """ statuses = list(reversed(down_statuses)) + ["down", "NS", "up"] + up_statuses return dict(zip(statuses, sns.color_palette("coolwarm", len(statuses)))) STATUS2COLOUR = make_status2colour(DOWN_STATUSES, UP_STATUSES) + + # TODO: use other labelling than logfold or gene lists, i.e. biotype def plot_MA(res, grouping=None, @@ -633,9 +581,10 @@ def plot_MA(res, """*fold_type* is "log2FoldChange" by default. It can also be "lfcMLE", which is based on uncorrected values. This may not be good for genes with low expression levels.""" - if not len(res): + # if not len(res): + if not res: raise ValueError("No data to plot.") - fig, ax = plt.subplots() + fig, axis = plt.subplots() # Make a column indicating whether the gene is DE or NS data = res.assign(is_DE=res.apply(set_de_status, axis=1)) x_column = "baseMean" @@ -645,22 +594,23 @@ def plot_MA(res, else: y_column = fold_type usetex = mpl.rcParams.get("text.usetex", False) + def scatter_group(group, label, colour, size=1): """Plots the data in *group* on the scatterplot.""" if usetex: label = texscape(label) group.plot.scatter( - #x=x_column, + # x=x_column, x="logx", y=y_column, s=size, - #logx=True, + # logx=True, c=colour, - label=label, ax=ax) + label=label, ax=axis) if usetex: data.columns = [texscape(colname) for colname in data.columns] y_column = texscape(y_column) - de_status_column = "is\_DE" + de_status_column = "is\_DE" # pylint: disable=W1401 else: de_status_column = "is_DE" # First plot the data in grey and black @@ -681,28 +631,30 @@ def plot_MA(res, (status, colour) = group2colour row_indices = data.index.intersection(grouping) try: - label=f"{status} ({len(row_indices)})" + label = f"{status} ({len(row_indices)})" scatter_group(data.ix[row_indices], label, colour) - except ValueError as e: - if str(e) != "scatter requires x column to be numeric": + except ValueError as err: + if str(err) != "scatter requires x column to be numeric": print(data.ix[row_indices]) raise - else: - warnings.warn(f"Nothing to plot for {status}\n") - ax.axhline(y=1, linewidth=0.5, color="0.5", linestyle="dashed") - ax.axhline(y=-1, linewidth=0.5, color="0.5", linestyle="dashed") + warnings.warn(f"Nothing to plot for {status}\n") + axis.axhline(y=1, linewidth=0.5, color="0.5", linestyle="dashed") + axis.axhline(y=-1, linewidth=0.5, color="0.5", linestyle="dashed") # TODO: check data basemean range if mean_range is not None: - # ax.set_xlim(mean_range) - ax.set_xlim(np.log10(mean_range)) + # axis.set_xlim(mean_range) + axis.set_xlim(np.log10(mean_range)) if lfc_range is not None: (lfc_min, lfc_max) = lfc_range lfc_here_min = getattr(data, y_column).min() lfc_here_max = getattr(data, y_column).max() if (lfc_here_min < lfc_min) or (lfc_here_max > lfc_max): - warnings.warn(f"Cannot plot {y_column} data ([{lfc_here_min}, {lfc_here_max}]) in requested range ([{lfc_min}, {lfc_max}])\n") + warnings.warn( + f"Cannot plot {y_column} data " + f"([{lfc_here_min}, {lfc_here_max}]) in requested range " + f"([{lfc_min}, {lfc_max}])\n") else: - ax.set_ylim(lfc_range) + axis.set_ylim(lfc_range) # https://stackoverflow.com/a/24867320/1878788 x_ticks = np.arange( floor(np.ma.masked_invalid(data["logx"]).min()), @@ -710,7 +662,7 @@ def plot_MA(res, 1) x_ticklabels = [r"$10^{{{}}}$".format(tick) for tick in x_ticks] plt.xticks(x_ticks, x_ticklabels) - ax.set_xlabel(x_column) + axis.set_xlabel(x_column) def plot_scatter(data, @@ -722,19 +674,27 @@ def plot_scatter(data, x_range=None, y_range=None, axes_style=None): - if not len(data): + """ + Plot a scatterplot using data from *data*, using columns + """ + # No rows + # if not len(data): + # Does it work like that too? + if not data: raise ValueError("No data to plot.") - fig, ax = plt.subplots() - # ax.set_adjustable('box') + # fig, axis = plt.subplots() + _, axis = plt.subplots() + # axis.set_adjustable('box') # First plot the data in grey data.plot.scatter( x=x_column, y=y_column, s=2, c="black", alpha=0.15, edgecolors='none', - ax=ax) + ax=axis) if regression: linreg = linregress(data[[x_column, y_column]].dropna()) - a = linreg.slope - b = linreg.intercept + a = linreg.slope # pylint: disable=C0103 + b = linreg.intercept # pylint: disable=C0103 + def fit(x): return (a * x) + b min_x = data[[x_column]].min()[0] @@ -742,18 +702,22 @@ def plot_scatter(data, min_y = fit(min_x) max_y = fit(max_x) xfit, yfit = (min_x, max_x), (min_y, max_y) - ax.plot(xfit, yfit, linewidth=0.5, color="0.5", linestyle="dashed") + axis.plot( + xfit, yfit, + linewidth=0.5, color="0.5", linestyle="dashed") # Overlay colour points if grouping is not None: if isinstance(grouping, str): # Determine colours based on the "grouping" column if group2colour is None: statuses = data[grouping].unique() - group2colour = dict(zip(statuses, sns.color_palette("colorblind", len(statuses)))) + group2colour = dict(zip( + statuses, + sns.color_palette("colorblind", len(statuses)))) for (status, group) in data.groupby(grouping): group.plot.scatter( x=x_column, y=y_column, s=1, c=group2colour[status], - label=f"{status} ({len(group)})", ax=ax) + label=f"{status} ({len(group)})", ax=axis) else: # Apply a colour to a list of genes (status, colour) = group2colour @@ -761,37 +725,41 @@ def plot_scatter(data, try: data.ix[row_indices].plot.scatter( x=x_column, y=y_column, s=1, c=colour, - label=f"{status} ({len(row_indices)})", ax=ax) - except ValueError as e: - if str(e) != "scatter requires x column to be numeric": + label=f"{status} ({len(row_indices)})", ax=axis) + except ValueError as err: + if str(err) != "scatter requires x column to be numeric": print(data.ix[row_indices]) raise - else: - warnings.warn(f"Nothing to plot for {status}\n") + warnings.warn(f"Nothing to plot for {status}\n") if axes_style is None: axes_style = {"linewidth": 0.5, "color": "0.5", "linestyle": "dashed"} - ax.axhline(y=0, **axes_style) - ax.axvline(x=0, **axes_style) - # ax.axhline(y=0, linewidth=0.5, color="0.5", linestyle="dashed") - # ax.axvline(x=0, linewidth=0.5, color="0.5", linestyle="dashed") + axis.axhline(y=0, **axes_style) + axis.axvline(x=0, **axes_style) + # axis.axhline(y=0, linewidth=0.5, color="0.5", linestyle="dashed") + # axis.axvline(x=0, linewidth=0.5, color="0.5", linestyle="dashed") # Set axis limits if x_range is not None: (x_min, x_max) = x_range x_here_min = getattr(data, x_column).min() x_here_max = getattr(data, x_column).max() if (x_here_min < x_min) or (x_here_max > x_max): - warnings.warn(f"Cannot plot {x_column} data ([{x_here_min}, {x_here_max}]) in requested range ([{x_min}, {x_max}])\n") + warnings.warn( + f"Cannot plot {x_column} data " + f"([{x_here_min}, {x_here_max}]) in requested range " + f"([{x_min}, {x_max}])\n") else: - ax.set_xlim(x_range) + axis.set_xlim(x_range) if y_range is not None: (y_min, y_max) = y_range y_here_min = getattr(data, y_column).min() y_here_max = getattr(data, y_column).max() if (y_here_min < y_min) or (y_here_max > y_max): - warnings.warn(f"Cannot plot {y_column} data ([{y_here_min}, {y_here_max}]) in requested range ([{y_min}, {y_max}])\n") + warnings.warn( + f"Cannot plot {y_column} data ([{y_here_min}, {y_here_max}]) " + f"in requested range ([{y_min}, {y_max}])\n") else: - ax.set_ylim(y_range) - return ax + axis.set_ylim(y_range) + return axis def plot_paired_scatters(data, columns=None, hue=None, log_log=False): @@ -802,11 +770,11 @@ def plot_paired_scatters(data, columns=None, hue=None, log_log=False): if usetex: data.columns = [texscape(colname) for colname in data.columns] columns = [texscape(colname) for colname in columns] - g = sns.PairGrid(data, vars=columns, hue=hue, size=8) - #g.map_offdiag(plt.scatter, marker=".") - g.map_lower(plt.scatter, marker=".") + grid = sns.PairGrid(data, vars=columns, hue=hue, size=8) + # grid.map_offdiag(plt.scatter, marker=".") + grid.map_lower(plt.scatter, marker=".") if log_log: - for ax in g.axes.ravel(): - ax.set_xscale('log') - ax.set_yscale('log') - g.add_legend() + for axis in grid.axes.ravel(): + axis.set_xscale('log') + axis.set_yscale('log') + grid.add_legend()