From cde7fc33d2c1e11e8c55ca7a847cfb1f0d4d9038 Mon Sep 17 00:00:00 2001
From: Blaise Li <blaise.li__git@nsup.org>
Date: Tue, 24 Mar 2020 20:41:49 +0100
Subject: [PATCH] Linting.
---
libhts/libhts.py | 388 ++++++++++++++++++++++-------------------------
1 file changed, 178 insertions(+), 210 deletions(-)
diff --git a/libhts/libhts.py b/libhts/libhts.py
index 0ff19b2..8dc859b 100644
--- a/libhts/libhts.py
+++ b/libhts/libhts.py
@@ -14,18 +14,8 @@
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from math import floor, ceil, sqrt, log
from functools import reduce
-from re import sub
+# from re import sub
import warnings
-from libworkflows import texscape
-
-
-def formatwarning(message, category, filename, lineno, line):
- """Used to format warning messages."""
- return "%s:%s: %s: %s\n" % (filename, lineno, category.__name__, message)
-
-
-warnings.formatwarning = formatwarning
-
import numpy as np
import pandas as pd
# To compute correlation coefficient, and compute linear regression
@@ -56,9 +46,21 @@ import seaborn as sns
from pybedtools import BedTool
import pyBigWig
import networkx as nx
+from libworkflows import texscape
+
+
+def formatwarning(
+ message, category, filename, lineno, line): # pylint: disable=W0613
+ """Used to format warning messages."""
+ return "%s:%s: %s: %s\n" % (filename, lineno, category.__name__, message)
-class Exon(object):
+warnings.formatwarning = formatwarning
+
+
+# This might represent any type of genomic interval.
+class Exon():
+ """Object representing an exon."""
__slots__ = ("chrom", "start", "end")
def __init__(self, chrom, start, end):
self.chrom = chrom
@@ -66,11 +68,17 @@ class Exon(object):
self.end = end
def overlap(self, other):
+ """
+ Tell whether *self* and *other* overlap.
+ """
if self.chrom != other.chrom:
return False
return (self.start <= other.start < self.end) or (other.start <= self.start < other.end)
def merge(self, other):
+ """
+ Create a new Exon object by merging *self* with *other*.
+ """
# Not necessary: can be indirectly linked
#assert overlap(self, other)
return Exon(self.chrom, min(self.start, other.start), max(self.end, other.end))
@@ -78,10 +86,10 @@ class Exon(object):
def __len__(self):
return self.end - self.start
-overlap = Exon.overlap
-merge = Exon.merge
+OVERLAP = Exon.overlap
+MERGE = Exon.merge
-class Gene(object):
+class Gene():
"""This object contains information obtained from a gtf file."""
__slots__ = ("gene_id", "exons", "union_exon_length")
def __init__(self, gene_id):
@@ -96,6 +104,10 @@ class Gene(object):
# self.transcripts[the_id] = feature
def add_exon(self, feature):
+ """
+        Add one Exon object to the exon graph, based on the information
+        in gtf feature *feature*.
+ """
#the_id = feature.attrs["exon_id"]
#assert the_id not in self.exons
#self.exons[the_id] = feature
@@ -117,12 +129,13 @@ class Gene(object):
# len, BedTool(self.exons.values()).merge().features()))
#self.union_exon_length = 0
# We group nodes that overlap, and merge them
- #overlapping_exons = nx.quotient_graph(self.exons, overlap)
+ #overlapping_exons = nx.quotient_graph(self.exons, OVERLAP)
#for node in overlapping_exons.nodes():
- # mex = reduce(merge, node)
+ # mex = reduce(MERGE, node)
# self.union_exon_length += len(mex)
self.union_exon_length = sum((len(reduce(
- merge, node)) for node in nx.quotient_graph(self.exons, overlap).nodes()))
+ MERGE, node)) for node in nx.quotient_graph(
+ self.exons, OVERLAP).nodes()))
def gtf_2_genes_exon_lengths(gtf_filename):
@@ -191,7 +204,9 @@ def spikein_gtf_2_lengths(spikein_gtf):
name=("union_exon_len")).rename_axis("gene"))
-def id_list_gtf2bed(identifiers, gtf_filename, feature_type="transcript", id_kwd="gene_id"):
+def id_list_gtf2bed(
+ identifiers, gtf_filename,
+ feature_type="transcript", id_kwd="gene_id"):
"""
Extract bed coordinates of an iterable of identifiers from a gtf file.
@@ -205,16 +220,17 @@ def id_list_gtf2bed(identifiers, gtf_filename, feature_type="transcript", id_kwd
"""
if identifiers:
ids = set(identifiers)
+
def feature_filter(feature):
return feature[2] == feature_type and feature[id_kwd] in ids
gtf = BedTool(gtf_filename)
return gtf.filter(feature_filter)
- else:
- # https://stackoverflow.com/a/13243870/1878788
- def empty_bed_generator():
- return
- yield
- return empty_bed_generator()
+
+ # https://stackoverflow.com/a/13243870/1878788
+ def empty_bed_generator():
+ return
+ yield # pylint: disable=W0101
+ return empty_bed_generator()
def make_empty_bigwig(filename, chrom_sizes):
@@ -236,15 +252,17 @@ def make_empty_bigwig(filename, chrom_sizes):
#################
# Bowtie2 stuff #
#################
-def zero(value):
+def zero(value): # pylint: disable=W0613
+ """Constant zero."""
return 0
def identity(value):
+ """Identity function."""
return value
-bowtie2_function_selector = {
+BOWTIE2_FUNCTION_SELECTOR = {
"C": zero,
"L": identity,
"S": sqrt,
@@ -270,7 +288,8 @@ def make_seeding_function(seeding_string):
[func_type, constant, coeff] = interval_string.split(",")
constant = float(constant)
coeff = float(coeff)
- func_type = bowtie2_function_selector[func_type]
+ func_type = BOWTIE2_FUNCTION_SELECTOR[func_type]
+
def seeding_function(read_len):
interval = floor(constant + (coeff * func_type(read_len)))
seeds = []
@@ -289,7 +308,7 @@ def aligner2min_mapq(aligner, wildcards):
What minimal MAPQ value should a read have to be considered uniquely mapped?
See <https://sequencing.qcfail.com/articles/mapq-values-are-really-useful-but-their-implementation-is-a-mess/>.
- """
+ """ # pylint: disable=C0301
mapping_type = None
try:
mapping_type = wildcards.mapping_type
@@ -308,15 +327,13 @@ def aligner2min_mapq(aligner, wildcards):
if mapping_type is None or mapping_type.startswith("unique_"):
if aligner == "hisat2":
return "-Q 60"
- elif aligner == "bowtie2":
+ if aligner == "bowtie2":
return "-Q 23"
- else:
- raise NotImplementedError(f"{aligner} not handled (yet?)")
- else:
- return ""
+ raise NotImplementedError(f"{aligner} not handled (yet?)")
+ return ""
-## Not sure this is a good idea...
+# Not sure this is a good idea...
# def masked_gmean(a, axis=0, dtype=None):
# """Modified from stats.py."""
# # Converts the data into a masked array
@@ -336,7 +353,6 @@ def aligner2min_mapq(aligner, wildcards):
# return np.exp(log_a.mean(axis=axis))
-
def median_ratio_to_pseudo_ref_size_factors(counts_data):
"""Adapted from DESeq paper (doi:10.1186/gb-2010-11-10-r106)
All libraries are used to define a pseudo-reference, which has
@@ -344,14 +360,15 @@ def median_ratio_to_pseudo_ref_size_factors(counts_data):
For a given library, the median across genes of the ratios to the
pseudo-reference is used as size factor."""
# Add pseudo-count to compute the geometric mean, then remove it
- #pseudo_ref = (counts_data + 1).apply(gmean, axis=1) - 1
+ # pseudo_ref = (counts_data + 1).apply(gmean, axis=1) - 1
# Ignore lines with zeroes instead (may be bad for IP: many zeroes expected):
pseudo_ref = (counts_data[counts_data.prod(axis=1) > 0]).apply(gmean, axis=1)
# Ignore lines with only zeroes
# pseudo_ref = (counts_data[counts_data.sum(axis=1) > 0]).apply(masked_gmean, axis=1)
+
def median_ratio_to_pseudo_ref(col):
return (col / pseudo_ref).median()
- #size_factors = counts_data.apply(median_ratio_to_pseudo_ref, axis=0)
+ # size_factors = counts_data.apply(median_ratio_to_pseudo_ref, axis=0)
median_ratios = counts_data[counts_data.prod(axis=1) > 0].apply(
median_ratio_to_pseudo_ref, axis=0)
# Not sure fillna(0) is appropriate
@@ -359,8 +376,7 @@ def median_ratio_to_pseudo_ref_size_factors(counts_data):
msg = "Could not compute median ratios to pseudo reference.\n"
warnings.warn(msg)
return median_ratios.fillna(1)
- else:
- return median_ratios
+ return median_ratios
def size_factor_correlations(counts_data, summaries, normalizer):
@@ -371,18 +387,22 @@ def size_factor_correlations(counts_data, summaries, normalizer):
size_factors = median_ratio_to_pseudo_ref_size_factors(counts_data)
else:
size_factors = summaries.loc[normalizer]
- #by_norm = counts_data / size_factors
+ # by_norm = counts_data / size_factors
+
def compute_pearsonr_with_size_factor(row):
return pearsonr(row, size_factors)[0]
- #return by_norm.apply(compute_pearsonr_with_size_factor, axis=1)
+ # return by_norm.apply(compute_pearsonr_with_size_factor, axis=1)
return (counts_data / size_factors).apply(compute_pearsonr_with_size_factor, axis=1)
def plot_norm_correlations(correlations):
- #fig, (ax1, ax2) = plt.subplots(nrows=2, ncols=1, sharex=True)
- #correlations.plot.kde(ax=ax1)
- #sns.violinplot(data=correlations, orient="h", ax=ax2)
- #ax2.set_xlabel("Pearson correlation coefficient")
+ """
+ Make violin plots to represent data in *correlations*.
+ """
+ # fig, (ax1, ax2) = plt.subplots(nrows=2, ncols=1, sharex=True)
+ # correlations.plot.kde(ax=ax1)
+ # sns.violinplot(data=correlations, orient="h", ax=ax2)
+ # ax2.set_xlabel("Pearson correlation coefficient")
usetex = mpl.rcParams.get("text.usetex", False)
if usetex:
correlations.columns = [texscape(colname) for colname in correlations.columns]
@@ -391,17 +411,21 @@ def plot_norm_correlations(correlations):
def plot_counts_distribution(data, xlabel):
+ """
+ Plot a kernel density estimate of the distribution of counts in *data*.
+ """
# TODO: try to plot with semilog x axis
- #axis = data.plot.kde(legend=None)
- #axis.set_xlabel(xlabel)
- #axis.legend(ncol=len(REPS))
+ # axis = data.plot.kde(legend=None)
+ # axis.set_xlabel(xlabel)
+ # axis.legend(ncol=len(REPS))
usetex = mpl.rcParams.get("text.usetex", False)
if usetex:
xlabel = texscape(xlabel)
data.columns = [texscape(colname) for colname in data.columns]
try:
axis = data.plot.kde()
- except ValueError as e:
+ # except ValueError as e:
+ except ValueError:
msg = "".join([
"There seems to be a problem with the data.\n",
"The data matrix has %d lines and %d columns.\n" % (len(data), len(data.columns))])
@@ -411,16 +435,18 @@ def plot_counts_distribution(data, xlabel):
def plot_histo(outfile, data, title=None):
- fig = plt.figure(figsize=(15,7))
- ax = fig.add_subplot(111)
- ax.set_xlim([data.index[0] - 0.5, data.index[-1] + 0.5])
- #ax.set_ylim([0, 100])
+ """
+ Plot a histogram of *data* in file *outfile*.
+ """
+ fig = plt.figure(figsize=(15, 7))
+ axis = fig.add_subplot(111)
+ axis.set_xlim([data.index[0] - 0.5, data.index[-1] + 0.5])
+ # axis.set_ylim([0, 100])
bar_width = 0.8
- #letter2legend = dict(zip("ACGT", "ACGT"))
+ # letter2legend = dict(zip("ACGT", "ACGT"))
usetex = mpl.rcParams.get("text.usetex", False)
if usetex:
data.columns = [texscape(colname) for colname in data.columns]
- #title = sub("_", r"\_", title)
title = texscape(title)
for (read_len, count) in data.iterrows():
plt.bar(
@@ -428,20 +454,21 @@ def plot_histo(outfile, data, title=None):
count,
align="center",
width=bar_width)
- #color=letter2colour[letter],
- #label=letter2legend[letter])
- ax.legend()
- ax.set_xticks(data.index)
- ax.set_xticklabels(data.index)
- ax.set_xlabel("read length")
- ax.set_ylabel("number of reads")
- plt.setp(ax.get_xticklabels(), rotation=90)
+ # color=letter2colour[letter],
+ # label=letter2legend[letter])
+ axis.legend()
+ axis.set_xticks(data.index)
+ axis.set_xticklabels(data.index)
+ axis.set_xlabel("read length")
+ axis.set_ylabel("number of reads")
+ plt.setp(axis.get_xticklabels(), rotation=90)
if title is not None:
plt.title(title)
## debug
try:
plt.savefig(outfile)
- except RuntimeError as e:
+ # except RuntimeError as e:
+ except RuntimeError:
print(data.index)
print(title)
raise
@@ -449,6 +476,9 @@ def plot_histo(outfile, data, title=None):
def plot_boxplots(data, ylabel):
+ """
+ Plot boxplots of data in *data* using *ylabel* as y-axis label.
+ """
fig = plt.figure(figsize=(6, 12))
axis = fig.add_subplot(111)
usetex = mpl.rcParams.get("text.usetex", False)
@@ -465,90 +495,6 @@ def plot_boxplots(data, ylabel):
############
# DE stuff #
############
-# def do_deseq2(cond_names, conditions, counts_data,
-# formula=None, contrast=None, deseq2_args=None):
-# """Runs a DESeq2 differential expression analysis."""
-# if formula is None:
-# formula = Formula("~ lib")
-# if contrast is None:
-# # FIXME: MUT and REF are not defined
-# # Maybe just make (formula and) contrast mandatory
-# contrast = StrVector(["lib", MUT, REF])
-# if deseq2_args is None:
-# deseq2_args = {"betaPrior" : True, "addMLE" : True, "independentFiltering" : True}
-# col_data = pd.DataFrame(conditions).assign(
-# cond_name=pd.Series(cond_names).values).set_index("cond_name")
-# # In case we want contrasts between factor combinations
-# if ("lib" in col_data.columns) and ("treat" in col_data.columns):
-# col_data = col_data.assign(
-# lib_treat = ["%s_%s" % (lib, treat) for (lib, treat) in zip(
-# col_data["lib"], col_data["treat"])])
-# # http://stackoverflow.com/a/31206596/1878788
-# pandas2ri.activate() # makes some conversions automatic
-# # r_counts_data = pandas2ri.py2ri(counts_data)
-# # r_col_data = pandas2ri.py2ri(col_data)
-# # r.DESeqDataSetFromMatrix(countData=r_counts_data, colData=r_col_data, design=Formula("~lib"))
-# dds = deseq2.DESeqDataSetFromMatrix(
-# countData=counts_data,
-# colData=col_data,
-# design=formula)
-# # dds = deseq2.DESeq(dds, betaPrior=deseq2_args["betaPrior"])
-# # Decompose into the 3 steps to have more control on the options
-# try:
-# dds = deseq2.estimateSizeFactors_DESeqDataSet(dds, type="ratio")
-# except RRuntimeError as e:
-# if sum(counts_data.prod(axis=1)) == 0:
-# msg = "".join(["Error occurred in estimateSizeFactors:\n%s\n" % e,
-# "This is probably because every gene has at least one zero.\n",
-# "We will try to use the \"poscounts\" method instead."])
-# warnings.warn(msg)
-# try:
-# dds = deseq2.estimateSizeFactors_DESeqDataSet(dds, type="poscounts")
-# except RRuntimeError as e:
-# msg = "".join(["Error occurred in estimateSizeFactors:\n%s\n" % e,
-# "We give up."])
-# warnings.warn(msg)
-# raise
-# #print(counts_data.dtypes)
-# #print(counts_data.columns)
-# #print(len(counts_data))
-# #raise
-# else:
-# raise
-# size_factors = pandas2ri.ri2py(as_df(deseq2.sizeFactors_DESeqDataSet(dds)))
-# #for cond in cond_names:
-# # #s = size_factors.loc[cond][0]
-# # #(*_, s) = size_factors.loc[cond]
-# #pd.DataFrame({cond : size_factors.loc[cond][0] for cond in COND_NAMES}, index=('size_factor',))
-# try:
-# dds = deseq2.estimateDispersions_DESeqDataSet(dds, fitType="parametric")
-# except RRuntimeError as e:
-# msg = "".join(["Error occurred in estimateDispersions:\n%s\n" % e,
-# "We will try with fitType=\"local\"."])
-# warnings.warn(msg)
-# try:
-# dds = deseq2.estimateDispersions_DESeqDataSet(dds, fitType="local")
-# except RRuntimeError as e:
-# msg = "".join(["Error occurred in estimateDispersions:\n%s\n" % e,
-# "We will try with fitType=\"mean\"."])
-# warnings.warn(msg)
-# try:
-# dds = deseq2.estimateDispersions_DESeqDataSet(dds, fitType="mean")
-# except RRuntimeError as e:
-# msg = "".join(["Error occurred in estimateDispersions:\n%s\n" % e,
-# "We give up."])
-# warnings.warn(msg)
-# raise
-# dds = deseq2.nbinomWaldTest(dds, betaPrior=deseq2_args["betaPrior"])
-# res = pandas2ri.ri2py(as_df(deseq2.results(
-# dds,
-# contrast=contrast,
-# addMLE=deseq2_args["addMLE"],
-# independentFiltering=deseq2_args["independentFiltering"])))
-# res.index = counts_data.index
-# return res, {cond : size_factors.loc[cond][0] for cond in cond_names}
-
-
# Cutoffs in log fold change
LFC_CUTOFFS = [0.5, 1, 2]
UP_STATUSES = [f"up{cutoff}" for cutoff in LFC_CUTOFFS]
@@ -564,7 +510,7 @@ def status_setter(lfc_cutoffs=None, fold_type="log2FoldChange"):
"""Determines the up- or down-regulation status corresponding to a given
row of a deseq2 results table."""
if row["padj"] < 0.05:
- #if row["log2FoldChange"] > 0:
+ # if row["log2FoldChange"] > 0:
lfc = row[fold_type]
if lfc > 0:
# Start from the highest cutoff,
@@ -573,13 +519,11 @@ def status_setter(lfc_cutoffs=None, fold_type="log2FoldChange"):
if lfc > cutoff:
return f"up{cutoff}"
return "up"
- else:
- for cutoff in sorted(lfc_cutoffs, reverse=True):
- if lfc < -cutoff:
- return f"down{cutoff}"
- return "down"
- else:
- return "NS"
+ for cutoff in sorted(lfc_cutoffs, reverse=True):
+ if lfc < -cutoff:
+ return f"down{cutoff}"
+ return "down"
+ return "NS"
return set_status
@@ -589,20 +533,19 @@ def set_de_status(row):
based on the adjusted p-value in row of a deseq2 results table."""
if row["padj"] < 0.05:
return "DE"
- else:
- return "NS"
+ return "NS"
DE2COLOUR = {
# black
- "DE" : "k",
+ "DE": "k",
# pale grey
- "NS" : "0.85"}
+ "NS": "0.85"}
def plot_lfc_distribution(res, contrast, fold_type=None):
"""*fold_type* is "log2FoldChange" by default.
It can also be "lfcMLE", which is based on uncorrected values.
This may not be good for genes with low expression levels."""
- #lfc = res.lfcMLE.dropna()
+ # lfc = res.lfcMLE.dropna()
if fold_type is None:
fold_type = "log2FoldChange"
lfc = getattr(res, fold_type).dropna()
@@ -618,11 +561,16 @@ def plot_lfc_distribution(res, contrast, fold_type=None):
def make_status2colour(down_statuses, up_statuses):
+ """
+ Generate a dictionary associating colours to statuses.
+ """
statuses = list(reversed(down_statuses)) + ["down", "NS", "up"] + up_statuses
return dict(zip(statuses, sns.color_palette("coolwarm", len(statuses))))
STATUS2COLOUR = make_status2colour(DOWN_STATUSES, UP_STATUSES)
+
+
# TODO: use other labelling than logfold or gene lists, i.e. biotype
def plot_MA(res,
grouping=None,
@@ -633,9 +581,10 @@ def plot_MA(res,
"""*fold_type* is "log2FoldChange" by default.
It can also be "lfcMLE", which is based on uncorrected values.
This may not be good for genes with low expression levels."""
- if not len(res):
+    # NB: `if not res:` would raise ValueError (DataFrame truth value is ambiguous).
+    if res.empty:
raise ValueError("No data to plot.")
- fig, ax = plt.subplots()
+ fig, axis = plt.subplots()
# Make a column indicating whether the gene is DE or NS
data = res.assign(is_DE=res.apply(set_de_status, axis=1))
x_column = "baseMean"
@@ -645,22 +594,23 @@ def plot_MA(res,
else:
y_column = fold_type
usetex = mpl.rcParams.get("text.usetex", False)
+
def scatter_group(group, label, colour, size=1):
"""Plots the data in *group* on the scatterplot."""
if usetex:
label = texscape(label)
group.plot.scatter(
- #x=x_column,
+ # x=x_column,
x="logx",
y=y_column,
s=size,
- #logx=True,
+ # logx=True,
c=colour,
- label=label, ax=ax)
+ label=label, ax=axis)
if usetex:
data.columns = [texscape(colname) for colname in data.columns]
y_column = texscape(y_column)
- de_status_column = "is\_DE"
+            de_status_column = r"is\_DE"
else:
de_status_column = "is_DE"
# First plot the data in grey and black
@@ -681,28 +631,30 @@ def plot_MA(res,
(status, colour) = group2colour
row_indices = data.index.intersection(grouping)
try:
- label=f"{status} ({len(row_indices)})"
+ label = f"{status} ({len(row_indices)})"
scatter_group(data.ix[row_indices], label, colour)
- except ValueError as e:
- if str(e) != "scatter requires x column to be numeric":
+ except ValueError as err:
+ if str(err) != "scatter requires x column to be numeric":
print(data.ix[row_indices])
raise
- else:
- warnings.warn(f"Nothing to plot for {status}\n")
- ax.axhline(y=1, linewidth=0.5, color="0.5", linestyle="dashed")
- ax.axhline(y=-1, linewidth=0.5, color="0.5", linestyle="dashed")
+ warnings.warn(f"Nothing to plot for {status}\n")
+ axis.axhline(y=1, linewidth=0.5, color="0.5", linestyle="dashed")
+ axis.axhline(y=-1, linewidth=0.5, color="0.5", linestyle="dashed")
# TODO: check data basemean range
if mean_range is not None:
- # ax.set_xlim(mean_range)
- ax.set_xlim(np.log10(mean_range))
+ # axis.set_xlim(mean_range)
+ axis.set_xlim(np.log10(mean_range))
if lfc_range is not None:
(lfc_min, lfc_max) = lfc_range
lfc_here_min = getattr(data, y_column).min()
lfc_here_max = getattr(data, y_column).max()
if (lfc_here_min < lfc_min) or (lfc_here_max > lfc_max):
- warnings.warn(f"Cannot plot {y_column} data ([{lfc_here_min}, {lfc_here_max}]) in requested range ([{lfc_min}, {lfc_max}])\n")
+ warnings.warn(
+ f"Cannot plot {y_column} data "
+ f"([{lfc_here_min}, {lfc_here_max}]) in requested range "
+ f"([{lfc_min}, {lfc_max}])\n")
else:
- ax.set_ylim(lfc_range)
+ axis.set_ylim(lfc_range)
# https://stackoverflow.com/a/24867320/1878788
x_ticks = np.arange(
floor(np.ma.masked_invalid(data["logx"]).min()),
@@ -710,7 +662,7 @@ def plot_MA(res,
1)
x_ticklabels = [r"$10^{{{}}}$".format(tick) for tick in x_ticks]
plt.xticks(x_ticks, x_ticklabels)
- ax.set_xlabel(x_column)
+ axis.set_xlabel(x_column)
def plot_scatter(data,
@@ -722,19 +674,27 @@ def plot_scatter(data,
x_range=None,
y_range=None,
axes_style=None):
- if not len(data):
+ """
+    Plot a scatterplot of column *y_column* against column *x_column* of *data*.
+ """
+ # No rows
+    # NB: `if not data:` would raise ValueError (DataFrame truth value is
+    # ambiguous); use the pandas-idiomatic emptiness check instead.
+    if data.empty:
raise ValueError("No data to plot.")
- fig, ax = plt.subplots()
- # ax.set_adjustable('box')
+ # fig, axis = plt.subplots()
+ _, axis = plt.subplots()
+ # axis.set_adjustable('box')
# First plot the data in grey
data.plot.scatter(
x=x_column, y=y_column,
s=2, c="black", alpha=0.15, edgecolors='none',
- ax=ax)
+ ax=axis)
if regression:
linreg = linregress(data[[x_column, y_column]].dropna())
- a = linreg.slope
- b = linreg.intercept
+ a = linreg.slope # pylint: disable=C0103
+ b = linreg.intercept # pylint: disable=C0103
+
def fit(x):
return (a * x) + b
min_x = data[[x_column]].min()[0]
@@ -742,18 +702,22 @@ def plot_scatter(data,
min_y = fit(min_x)
max_y = fit(max_x)
xfit, yfit = (min_x, max_x), (min_y, max_y)
- ax.plot(xfit, yfit, linewidth=0.5, color="0.5", linestyle="dashed")
+ axis.plot(
+ xfit, yfit,
+ linewidth=0.5, color="0.5", linestyle="dashed")
# Overlay colour points
if grouping is not None:
if isinstance(grouping, str):
# Determine colours based on the "grouping" column
if group2colour is None:
statuses = data[grouping].unique()
- group2colour = dict(zip(statuses, sns.color_palette("colorblind", len(statuses))))
+ group2colour = dict(zip(
+ statuses,
+ sns.color_palette("colorblind", len(statuses))))
for (status, group) in data.groupby(grouping):
group.plot.scatter(
x=x_column, y=y_column, s=1, c=group2colour[status],
- label=f"{status} ({len(group)})", ax=ax)
+ label=f"{status} ({len(group)})", ax=axis)
else:
# Apply a colour to a list of genes
(status, colour) = group2colour
@@ -761,37 +725,41 @@ def plot_scatter(data,
try:
data.ix[row_indices].plot.scatter(
x=x_column, y=y_column, s=1, c=colour,
- label=f"{status} ({len(row_indices)})", ax=ax)
- except ValueError as e:
- if str(e) != "scatter requires x column to be numeric":
+ label=f"{status} ({len(row_indices)})", ax=axis)
+ except ValueError as err:
+ if str(err) != "scatter requires x column to be numeric":
print(data.ix[row_indices])
raise
- else:
- warnings.warn(f"Nothing to plot for {status}\n")
+ warnings.warn(f"Nothing to plot for {status}\n")
if axes_style is None:
axes_style = {"linewidth": 0.5, "color": "0.5", "linestyle": "dashed"}
- ax.axhline(y=0, **axes_style)
- ax.axvline(x=0, **axes_style)
- # ax.axhline(y=0, linewidth=0.5, color="0.5", linestyle="dashed")
- # ax.axvline(x=0, linewidth=0.5, color="0.5", linestyle="dashed")
+ axis.axhline(y=0, **axes_style)
+ axis.axvline(x=0, **axes_style)
+ # axis.axhline(y=0, linewidth=0.5, color="0.5", linestyle="dashed")
+ # axis.axvline(x=0, linewidth=0.5, color="0.5", linestyle="dashed")
# Set axis limits
if x_range is not None:
(x_min, x_max) = x_range
x_here_min = getattr(data, x_column).min()
x_here_max = getattr(data, x_column).max()
if (x_here_min < x_min) or (x_here_max > x_max):
- warnings.warn(f"Cannot plot {x_column} data ([{x_here_min}, {x_here_max}]) in requested range ([{x_min}, {x_max}])\n")
+ warnings.warn(
+ f"Cannot plot {x_column} data "
+ f"([{x_here_min}, {x_here_max}]) in requested range "
+ f"([{x_min}, {x_max}])\n")
else:
- ax.set_xlim(x_range)
+ axis.set_xlim(x_range)
if y_range is not None:
(y_min, y_max) = y_range
y_here_min = getattr(data, y_column).min()
y_here_max = getattr(data, y_column).max()
if (y_here_min < y_min) or (y_here_max > y_max):
- warnings.warn(f"Cannot plot {y_column} data ([{y_here_min}, {y_here_max}]) in requested range ([{y_min}, {y_max}])\n")
+ warnings.warn(
+ f"Cannot plot {y_column} data ([{y_here_min}, {y_here_max}]) "
+ f"in requested range ([{y_min}, {y_max}])\n")
else:
- ax.set_ylim(y_range)
- return ax
+ axis.set_ylim(y_range)
+ return axis
def plot_paired_scatters(data, columns=None, hue=None, log_log=False):
@@ -802,11 +770,11 @@ def plot_paired_scatters(data, columns=None, hue=None, log_log=False):
if usetex:
data.columns = [texscape(colname) for colname in data.columns]
columns = [texscape(colname) for colname in columns]
- g = sns.PairGrid(data, vars=columns, hue=hue, size=8)
- #g.map_offdiag(plt.scatter, marker=".")
- g.map_lower(plt.scatter, marker=".")
+ grid = sns.PairGrid(data, vars=columns, hue=hue, size=8)
+ # grid.map_offdiag(plt.scatter, marker=".")
+ grid.map_lower(plt.scatter, marker=".")
if log_log:
- for ax in g.axes.ravel():
- ax.set_xscale('log')
- ax.set_yscale('log')
- g.add_legend()
+ for axis in grid.axes.ravel():
+ axis.set_xscale('log')
+ axis.set_yscale('log')
+ grid.add_legend()
--
GitLab