Commit cde7fc33 authored by Blaise Li's avatar Blaise Li

Linting.

parent 66415b36
......@@ -14,18 +14,8 @@
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from math import floor, ceil, sqrt, log
from functools import reduce
from re import sub
# from re import sub
import warnings
from libworkflows import texscape
def formatwarning(message, category, filename, lineno, line):
"""Used to format warning messages."""
return "%s:%s: %s: %s\n" % (filename, lineno, category.__name__, message)
warnings.formatwarning = formatwarning
import numpy as np
import pandas as pd
# To compute correlation coefficient, and compute linear regression
......@@ -56,9 +46,21 @@ import seaborn as sns
from pybedtools import BedTool
import pyBigWig
import networkx as nx
from libworkflows import texscape
def formatwarning(
message, category, filename, lineno, line): # pylint: disable=W0613
"""Used to format warning messages."""
return "%s:%s: %s: %s\n" % (filename, lineno, category.__name__, message)
class Exon(object):
warnings.formatwarning = formatwarning
# This might represent any type of genomic interval.
class Exon():
"""Object representing an exon."""
__slots__ = ("chrom", "start", "end")
def __init__(self, chrom, start, end):
self.chrom = chrom
......@@ -66,11 +68,17 @@ class Exon(object):
self.end = end
def overlap(self, other):
"""
Tell whether *self* and *other* overlap.
"""
if self.chrom != other.chrom:
return False
return (self.start <= other.start < self.end) or (other.start <= self.start < other.end)
def merge(self, other):
"""
Create a new Exon object by merging *self* with *other*.
"""
# Not necessary: can be indirectly linked
#assert overlap(self, other)
return Exon(self.chrom, min(self.start, other.start), max(self.end, other.end))
......@@ -78,10 +86,10 @@ class Exon(object):
def __len__(self):
return self.end - self.start
overlap = Exon.overlap
merge = Exon.merge
OVERLAP = Exon.overlap
MERGE = Exon.merge
class Gene(object):
class Gene():
"""This object contains information obtained from a gtf file."""
__slots__ = ("gene_id", "exons", "union_exon_length")
def __init__(self, gene_id):
......@@ -96,6 +104,10 @@ class Gene(object):
# self.transcripts[the_id] = feature
def add_exon(self, feature):
"""
Add one Exon object to the exon graph based in the information in gtf
information *feature*.
"""
#the_id = feature.attrs["exon_id"]
#assert the_id not in self.exons
#self.exons[the_id] = feature
......@@ -117,12 +129,13 @@ class Gene(object):
# len, BedTool(self.exons.values()).merge().features()))
#self.union_exon_length = 0
# We group nodes that overlap, and merge them
#overlapping_exons = nx.quotient_graph(self.exons, overlap)
#overlapping_exons = nx.quotient_graph(self.exons, OVERLAP)
#for node in overlapping_exons.nodes():
# mex = reduce(merge, node)
# mex = reduce(MERGE, node)
# self.union_exon_length += len(mex)
self.union_exon_length = sum((len(reduce(
merge, node)) for node in nx.quotient_graph(self.exons, overlap).nodes()))
MERGE, node)) for node in nx.quotient_graph(
self.exons, OVERLAP).nodes()))
def gtf_2_genes_exon_lengths(gtf_filename):
......@@ -191,7 +204,9 @@ def spikein_gtf_2_lengths(spikein_gtf):
name=("union_exon_len")).rename_axis("gene"))
def id_list_gtf2bed(identifiers, gtf_filename, feature_type="transcript", id_kwd="gene_id"):
def id_list_gtf2bed(
identifiers, gtf_filename,
feature_type="transcript", id_kwd="gene_id"):
"""
Extract bed coordinates of an iterable of identifiers from a gtf file.
......@@ -205,16 +220,17 @@ def id_list_gtf2bed(identifiers, gtf_filename, feature_type="transcript", id_kwd
"""
if identifiers:
ids = set(identifiers)
def feature_filter(feature):
return feature[2] == feature_type and feature[id_kwd] in ids
gtf = BedTool(gtf_filename)
return gtf.filter(feature_filter)
else:
# https://stackoverflow.com/a/13243870/1878788
def empty_bed_generator():
return
yield
return empty_bed_generator()
# https://stackoverflow.com/a/13243870/1878788
def empty_bed_generator():
return
yield # pylint: disable=W0101
return empty_bed_generator()
def make_empty_bigwig(filename, chrom_sizes):
......@@ -236,15 +252,17 @@ def make_empty_bigwig(filename, chrom_sizes):
#################
# Bowtie2 stuff #
#################
def zero(value):
def zero(value): # pylint: disable=W0613
"""Constant zero."""
return 0
def identity(value):
"""Identity function."""
return value
bowtie2_function_selector = {
BOWTIE2_FUNCTION_SELECTOR = {
"C": zero,
"L": identity,
"S": sqrt,
......@@ -270,7 +288,8 @@ def make_seeding_function(seeding_string):
[func_type, constant, coeff] = interval_string.split(",")
constant = float(constant)
coeff = float(coeff)
func_type = bowtie2_function_selector[func_type]
func_type = BOWTIE2_FUNCTION_SELECTOR[func_type]
def seeding_function(read_len):
interval = floor(constant + (coeff * func_type(read_len)))
seeds = []
......@@ -289,7 +308,7 @@ def aligner2min_mapq(aligner, wildcards):
What minimal MAPQ value should a read have to be considered uniquely mapped?
See <https://sequencing.qcfail.com/articles/mapq-values-are-really-useful-but-their-implementation-is-a-mess/>.
"""
""" # pylint: disable=C0301
mapping_type = None
try:
mapping_type = wildcards.mapping_type
......@@ -308,15 +327,13 @@ def aligner2min_mapq(aligner, wildcards):
if mapping_type is None or mapping_type.startswith("unique_"):
if aligner == "hisat2":
return "-Q 60"
elif aligner == "bowtie2":
if aligner == "bowtie2":
return "-Q 23"
else:
raise NotImplementedError(f"{aligner} not handled (yet?)")
else:
return ""
raise NotImplementedError(f"{aligner} not handled (yet?)")
return ""
## Not sure this is a good idea...
# Not sure this is a good idea...
# def masked_gmean(a, axis=0, dtype=None):
# """Modified from stats.py."""
# # Converts the data into a masked array
......@@ -336,7 +353,6 @@ def aligner2min_mapq(aligner, wildcards):
# return np.exp(log_a.mean(axis=axis))
def median_ratio_to_pseudo_ref_size_factors(counts_data):
"""Adapted from DESeq paper (doi:10.1186/gb-2010-11-10-r106)
All libraries are used to define a pseudo-reference, which has
......@@ -344,14 +360,15 @@ def median_ratio_to_pseudo_ref_size_factors(counts_data):
For a given library, the median across genes of the ratios to the
pseudo-reference is used as size factor."""
# Add pseudo-count to compute the geometric mean, then remove it
#pseudo_ref = (counts_data + 1).apply(gmean, axis=1) - 1
# pseudo_ref = (counts_data + 1).apply(gmean, axis=1) - 1
# Ignore lines with zeroes instead (may be bad for IP: many zeroes expected):
pseudo_ref = (counts_data[counts_data.prod(axis=1) > 0]).apply(gmean, axis=1)
# Ignore lines with only zeroes
# pseudo_ref = (counts_data[counts_data.sum(axis=1) > 0]).apply(masked_gmean, axis=1)
def median_ratio_to_pseudo_ref(col):
return (col / pseudo_ref).median()
#size_factors = counts_data.apply(median_ratio_to_pseudo_ref, axis=0)
# size_factors = counts_data.apply(median_ratio_to_pseudo_ref, axis=0)
median_ratios = counts_data[counts_data.prod(axis=1) > 0].apply(
median_ratio_to_pseudo_ref, axis=0)
# Not sure fillna(0) is appropriate
......@@ -359,8 +376,7 @@ def median_ratio_to_pseudo_ref_size_factors(counts_data):
msg = "Could not compute median ratios to pseudo reference.\n"
warnings.warn(msg)
return median_ratios.fillna(1)
else:
return median_ratios
return median_ratios
def size_factor_correlations(counts_data, summaries, normalizer):
......@@ -371,18 +387,22 @@ def size_factor_correlations(counts_data, summaries, normalizer):
size_factors = median_ratio_to_pseudo_ref_size_factors(counts_data)
else:
size_factors = summaries.loc[normalizer]
#by_norm = counts_data / size_factors
# by_norm = counts_data / size_factors
def compute_pearsonr_with_size_factor(row):
return pearsonr(row, size_factors)[0]
#return by_norm.apply(compute_pearsonr_with_size_factor, axis=1)
# return by_norm.apply(compute_pearsonr_with_size_factor, axis=1)
return (counts_data / size_factors).apply(compute_pearsonr_with_size_factor, axis=1)
def plot_norm_correlations(correlations):
#fig, (ax1, ax2) = plt.subplots(nrows=2, ncols=1, sharex=True)
#correlations.plot.kde(ax=ax1)
#sns.violinplot(data=correlations, orient="h", ax=ax2)
#ax2.set_xlabel("Pearson correlation coefficient")
"""
Make violin plots to represent data in *correlations*.
"""
# fig, (ax1, ax2) = plt.subplots(nrows=2, ncols=1, sharex=True)
# correlations.plot.kde(ax=ax1)
# sns.violinplot(data=correlations, orient="h", ax=ax2)
# ax2.set_xlabel("Pearson correlation coefficient")
usetex = mpl.rcParams.get("text.usetex", False)
if usetex:
correlations.columns = [texscape(colname) for colname in correlations.columns]
......@@ -391,17 +411,21 @@ def plot_norm_correlations(correlations):
def plot_counts_distribution(data, xlabel):
"""
Plot a kernel density estimate of the distribution of counts in *data*.
"""
# TODO: try to plot with semilog x axis
#axis = data.plot.kde(legend=None)
#axis.set_xlabel(xlabel)
#axis.legend(ncol=len(REPS))
# axis = data.plot.kde(legend=None)
# axis.set_xlabel(xlabel)
# axis.legend(ncol=len(REPS))
usetex = mpl.rcParams.get("text.usetex", False)
if usetex:
xlabel = texscape(xlabel)
data.columns = [texscape(colname) for colname in data.columns]
try:
axis = data.plot.kde()
except ValueError as e:
# except ValueError as e:
except ValueError:
msg = "".join([
"There seems to be a problem with the data.\n",
"The data matrix has %d lines and %d columns.\n" % (len(data), len(data.columns))])
......@@ -411,16 +435,18 @@ def plot_counts_distribution(data, xlabel):
def plot_histo(outfile, data, title=None):
fig = plt.figure(figsize=(15,7))
ax = fig.add_subplot(111)
ax.set_xlim([data.index[0] - 0.5, data.index[-1] + 0.5])
#ax.set_ylim([0, 100])
"""
Plot a histogram of *data* in file *outfile*.
"""
fig = plt.figure(figsize=(15, 7))
axis = fig.add_subplot(111)
axis.set_xlim([data.index[0] - 0.5, data.index[-1] + 0.5])
# axis.set_ylim([0, 100])
bar_width = 0.8
#letter2legend = dict(zip("ACGT", "ACGT"))
# letter2legend = dict(zip("ACGT", "ACGT"))
usetex = mpl.rcParams.get("text.usetex", False)
if usetex:
data.columns = [texscape(colname) for colname in data.columns]
#title = sub("_", r"\_", title)
title = texscape(title)
for (read_len, count) in data.iterrows():
plt.bar(
......@@ -428,20 +454,21 @@ def plot_histo(outfile, data, title=None):
count,
align="center",
width=bar_width)
#color=letter2colour[letter],
#label=letter2legend[letter])
ax.legend()
ax.set_xticks(data.index)
ax.set_xticklabels(data.index)
ax.set_xlabel("read length")
ax.set_ylabel("number of reads")
plt.setp(ax.get_xticklabels(), rotation=90)
# color=letter2colour[letter],
# label=letter2legend[letter])
axis.legend()
axis.set_xticks(data.index)
axis.set_xticklabels(data.index)
axis.set_xlabel("read length")
axis.set_ylabel("number of reads")
plt.setp(axis.get_xticklabels(), rotation=90)
if title is not None:
plt.title(title)
## debug
try:
plt.savefig(outfile)
except RuntimeError as e:
# except RuntimeError as e:
except RuntimeError:
print(data.index)
print(title)
raise
......@@ -449,6 +476,9 @@ def plot_histo(outfile, data, title=None):
def plot_boxplots(data, ylabel):
"""
Plot boxplots of data in *data* using *ylabel* as y-axis label.
"""
fig = plt.figure(figsize=(6, 12))
axis = fig.add_subplot(111)
usetex = mpl.rcParams.get("text.usetex", False)
......@@ -465,90 +495,6 @@ def plot_boxplots(data, ylabel):
############
# DE stuff #
############
# def do_deseq2(cond_names, conditions, counts_data,
# formula=None, contrast=None, deseq2_args=None):
# """Runs a DESeq2 differential expression analysis."""
# if formula is None:
# formula = Formula("~ lib")
# if contrast is None:
# # FIXME: MUT and REF are not defined
# # Maybe just make (formula and) contrast mandatory
# contrast = StrVector(["lib", MUT, REF])
# if deseq2_args is None:
# deseq2_args = {"betaPrior" : True, "addMLE" : True, "independentFiltering" : True}
# col_data = pd.DataFrame(conditions).assign(
# cond_name=pd.Series(cond_names).values).set_index("cond_name")
# # In case we want contrasts between factor combinations
# if ("lib" in col_data.columns) and ("treat" in col_data.columns):
# col_data = col_data.assign(
# lib_treat = ["%s_%s" % (lib, treat) for (lib, treat) in zip(
# col_data["lib"], col_data["treat"])])
# # http://stackoverflow.com/a/31206596/1878788
# pandas2ri.activate() # makes some conversions automatic
# # r_counts_data = pandas2ri.py2ri(counts_data)
# # r_col_data = pandas2ri.py2ri(col_data)
# # r.DESeqDataSetFromMatrix(countData=r_counts_data, colData=r_col_data, design=Formula("~lib"))
# dds = deseq2.DESeqDataSetFromMatrix(
# countData=counts_data,
# colData=col_data,
# design=formula)
# # dds = deseq2.DESeq(dds, betaPrior=deseq2_args["betaPrior"])
# # Decompose into the 3 steps to have more control on the options
# try:
# dds = deseq2.estimateSizeFactors_DESeqDataSet(dds, type="ratio")
# except RRuntimeError as e:
# if sum(counts_data.prod(axis=1)) == 0:
# msg = "".join(["Error occurred in estimateSizeFactors:\n%s\n" % e,
# "This is probably because every gene has at least one zero.\n",
# "We will try to use the \"poscounts\" method instead."])
# warnings.warn(msg)
# try:
# dds = deseq2.estimateSizeFactors_DESeqDataSet(dds, type="poscounts")
# except RRuntimeError as e:
# msg = "".join(["Error occurred in estimateSizeFactors:\n%s\n" % e,
# "We give up."])
# warnings.warn(msg)
# raise
# #print(counts_data.dtypes)
# #print(counts_data.columns)
# #print(len(counts_data))
# #raise
# else:
# raise
# size_factors = pandas2ri.ri2py(as_df(deseq2.sizeFactors_DESeqDataSet(dds)))
# #for cond in cond_names:
# # #s = size_factors.loc[cond][0]
# # #(*_, s) = size_factors.loc[cond]
# #pd.DataFrame({cond : size_factors.loc[cond][0] for cond in COND_NAMES}, index=('size_factor',))
# try:
# dds = deseq2.estimateDispersions_DESeqDataSet(dds, fitType="parametric")
# except RRuntimeError as e:
# msg = "".join(["Error occurred in estimateDispersions:\n%s\n" % e,
# "We will try with fitType=\"local\"."])
# warnings.warn(msg)
# try:
# dds = deseq2.estimateDispersions_DESeqDataSet(dds, fitType="local")
# except RRuntimeError as e:
# msg = "".join(["Error occurred in estimateDispersions:\n%s\n" % e,
# "We will try with fitType=\"mean\"."])
# warnings.warn(msg)
# try:
# dds = deseq2.estimateDispersions_DESeqDataSet(dds, fitType="mean")
# except RRuntimeError as e:
# msg = "".join(["Error occurred in estimateDispersions:\n%s\n" % e,
# "We give up."])
# warnings.warn(msg)
# raise
# dds = deseq2.nbinomWaldTest(dds, betaPrior=deseq2_args["betaPrior"])
# res = pandas2ri.ri2py(as_df(deseq2.results(
# dds,
# contrast=contrast,
# addMLE=deseq2_args["addMLE"],
# independentFiltering=deseq2_args["independentFiltering"])))
# res.index = counts_data.index
# return res, {cond : size_factors.loc[cond][0] for cond in cond_names}
# Cutoffs in log fold change
LFC_CUTOFFS = [0.5, 1, 2]
UP_STATUSES = [f"up{cutoff}" for cutoff in LFC_CUTOFFS]
......@@ -564,7 +510,7 @@ def status_setter(lfc_cutoffs=None, fold_type="log2FoldChange"):
"""Determines the up- or down-regulation status corresponding to a given
row of a deseq2 results table."""
if row["padj"] < 0.05:
#if row["log2FoldChange"] > 0:
# if row["log2FoldChange"] > 0:
lfc = row[fold_type]
if lfc > 0:
# Start from the highest cutoff,
......@@ -573,13 +519,11 @@ def status_setter(lfc_cutoffs=None, fold_type="log2FoldChange"):
if lfc > cutoff:
return f"up{cutoff}"
return "up"
else:
for cutoff in sorted(lfc_cutoffs, reverse=True):
if lfc < -cutoff:
return f"down{cutoff}"
return "down"
else:
return "NS"
for cutoff in sorted(lfc_cutoffs, reverse=True):
if lfc < -cutoff:
return f"down{cutoff}"
return "down"
return "NS"
return set_status
......@@ -589,20 +533,19 @@ def set_de_status(row):
based on the adjusted p-value in row of a deseq2 results table."""
if row["padj"] < 0.05:
return "DE"
else:
return "NS"
return "NS"
DE2COLOUR = {
# black
"DE" : "k",
"DE": "k",
# pale grey
"NS" : "0.85"}
"NS": "0.85"}
def plot_lfc_distribution(res, contrast, fold_type=None):
"""*fold_type* is "log2FoldChange" by default.
It can also be "lfcMLE", which is based on uncorrected values.
This may not be good for genes with low expression levels."""
#lfc = res.lfcMLE.dropna()
# lfc = res.lfcMLE.dropna()
if fold_type is None:
fold_type = "log2FoldChange"
lfc = getattr(res, fold_type).dropna()
......@@ -618,11 +561,16 @@ def plot_lfc_distribution(res, contrast, fold_type=None):
def make_status2colour(down_statuses, up_statuses):
"""
Generate a dictionary associating colours to statuses.
"""
statuses = list(reversed(down_statuses)) + ["down", "NS", "up"] + up_statuses
return dict(zip(statuses, sns.color_palette("coolwarm", len(statuses))))
STATUS2COLOUR = make_status2colour(DOWN_STATUSES, UP_STATUSES)
# TODO: use other labelling than logfold or gene lists, i.e. biotype
def plot_MA(res,
grouping=None,
......@@ -633,9 +581,10 @@ def plot_MA(res,
"""*fold_type* is "log2FoldChange" by default.
It can also be "lfcMLE", which is based on uncorrected values.
This may not be good for genes with low expression levels."""
if not len(res):
# if not len(res):
if not res:
raise ValueError("No data to plot.")
fig, ax = plt.subplots()
fig, axis = plt.subplots()
# Make a column indicating whether the gene is DE or NS
data = res.assign(is_DE=res.apply(set_de_status, axis=1))
x_column = "baseMean"
......@@ -645,22 +594,23 @@ def plot_MA(res,
else:
y_column = fold_type
usetex = mpl.rcParams.get("text.usetex", False)
def scatter_group(group, label, colour, size=1):
"""Plots the data in *group* on the scatterplot."""
if usetex:
label = texscape(label)
group.plot.scatter(
#x=x_column,
# x=x_column,
x="logx",
y=y_column,
s=size,
#logx=True,
# logx=True,
c=colour,
label=label, ax=ax)
label=label, ax=axis)
if usetex:
data.columns = [texscape(colname) for colname in data.columns]
y_column = texscape(y_column)
de_status_column = "is\_DE"
de_status_column = "is\_DE" # pylint: disable=W1401
else:
de_status_column = "is_DE"
# First plot the data in grey and black
......@@ -681,28 +631,30 @@ def plot_MA(res,
(status, colour) = group2colour
row_indices = data.index.intersection(grouping)
try:
label=f"{status} ({len(row_indices)})"
label = f"{status} ({len(row_indices)})"
scatter_group(data.ix[row_indices], label, colour)
except ValueError as e:
if str(e) != "scatter requires x column to be numeric":
except ValueError as err:
if str(err) != "scatter requires x column to be numeric":
print(data.ix[row_indices])
raise
else:
warnings.warn(f"Nothing to plot for {status}\n")
ax.axhline(y=1, linewidth=0.5, color="0.5", linestyle="dashed")
ax.axhline(y=-1, linewidth=0.5, color="0.5", linestyle="dashed")
warnings.warn(f"Nothing to plot for {status}\n")
axis.axhline(y=1, linewidth=0.5, color="0.5", linestyle="dashed")
axis.axhline(y=-1, linewidth=0.5, color="0.5", linestyle="dashed")
# TODO: check data basemean range
if mean_range is not None:
# ax.set_xlim(mean_range)
ax.set_xlim(np.log10(mean_range))
# axis.set_xlim(mean_range)
axis.set_xlim(np.log10(mean_range))
if lfc_range is not None:
(lfc_min, lfc_max) = lfc_range
lfc_here_min = getattr(data, y_column).min()
lfc_here_max = getattr(data, y_column).max()
if (lfc_here_min < lfc_min) or (lfc_here_max > lfc_max):
warnings.warn(f"Cannot plot {y_column} data ([{lfc_here_min}, {lfc_here_max}]) in requested range ([{lfc_min}, {lfc_max}])\n")
warnings.warn(
f"Cannot plot {y_column} data "
f"([{lfc_here_min}, {lfc_here_max}]) in requested range "
f"([{lfc_min}, {lfc_max}])\n")
else:
ax.set_ylim(lfc_range)
axis.set_ylim(lfc_range)
# https://stackoverflow.com/a/24867320/1878788
x_ticks = np.arange(
floor(np.ma.masked_invalid(data["logx"]).min()),
......@@ -710,7 +662,7 @@ def plot_MA(res,
1)
x_ticklabels = [r"$10^{{{}}}$".format(tick) for tick in x_ticks]
plt.xticks(x_ticks, x_ticklabels)
ax.set_xlabel(x_column)
axis.set_xlabel(x_column)
def plot_scatter(data,
......@@ -722,19 +674,27 @@ def plot_scatter(data,
x_range=None,
y_range=None,
axes_style=None):
if not len(data):
"""
Plot a scatterplot using data from *data*, using columns
"""
# No rows
# if not len(data):
# Does it work like that too?
if not data:
raise ValueError("No data to plot.")
fig, ax = plt.subplots()
# ax.set_adjustable('box')
# fig, axis = plt.subplots()
_, axis = plt.subplots()
# axis.set_adjustable('box')
# First plot the data in grey
data.plot.scatter(
x=x_column, y=y_column,
s=2, c="black", alpha=0.15, edgecolors='none',
ax=ax)
ax=axis)
if regression:
linreg = linregress(data[[x_column, y_column]].dropna())
a = linreg.slope
b = linreg.intercept
a = linreg.slope # pylint: disable=C0103
b = linreg.intercept # pylint: disable=C0103
def fit(x):
return (a * x) + b
min_x = data[[x_column]].min()[0]
......@@ -742,18 +702,22 @@ def plot_scatter(data,
min_y = fit(min_x)
max_y = fit(max_x)
xfit, yfit = (min_x, max_x), (min_y, max_y)
ax.plot(xfit, yfit, linewidth=0.5, color="0.5", linestyle="dashed")
axis.plot(
xfit, yfit,
linewidth=0.5, color="0.5", linestyle="dashed")
# Overlay colour points
if grouping is not None:
if isinstance(grouping, str):
# Determine colours based on the "grouping" column
if group2colour is None:
statuses = data[grouping].unique()
group2colour = dict(zip(statuses, sns.color_palette("colorblind", len(statuses))))
group2colour = dict(zip(
statuses,
sns.color_palette("colorblind", len(statuses))))
for (status, group) in data.groupby(grouping):
group.plot.scatter(
x=x_column, y=y_column, s=1, c=group2colour[status],
label=f"{status} ({len(group)})", ax=ax)
label=f"{status} ({len(group)})", ax=axis)
else:
# Apply a colour to a list of genes
(status, colour) = group2colour
......@@ -761,37 +725,41 @@ def plot_scatter(data,
try:
data.ix[row_indices].plot.scatter(
x=x_column, y=y_column, s=1, c=colour,
label=f"{status} ({len(row_indices)})", ax=ax)
except ValueError as e:
if str(e) != "scatter requires x column to be numeric":
label=f"{status} ({len(row_indices)})", ax=axis)
except ValueError as err: