Skip to content
Snippets Groups Projects
Commit 4280ef4c authored by Blaise Li's avatar Blaise Li
Browse files

Scatterplot tweaking. Code homogeneization.

parent 49cfaa10
No related branches found
No related tags found
No related merge requests found
......@@ -376,11 +376,14 @@ def plot_MA(res,
fig, ax = plt.subplots()
# Make a column indicating whether the gene is DE or NS
data = res.assign(is_DE=res.apply(set_de_status, axis=1))
x_column = "baseMean"
if fold_type is None:
fold_type = "log2FoldChange"
y_column = "log2FoldChange"
else:
y_column = fold_type
# First plot the data in grey and black
for de_status, group in data.groupby("is_DE"):
group.plot.scatter(x="baseMean", y=fold_type, s=2, logx=True, c=DE2COLOUR[de_status], label=f"{de_status} ({len(group)})", ax=ax)
group.plot.scatter(x=x_column, y=y_column, s=2, logx=True, c=DE2COLOUR[de_status], label=f"{de_status} ({len(group)})", ax=ax)
if grouping is not None:
if isinstance(grouping, str):
# Overlay colours based on the "grouping" column
......@@ -388,13 +391,13 @@ def plot_MA(res,
group2colour = STATUS2COLOUR
for status, group in data.groupby(grouping):
group.plot.scatter(
x="baseMean", y=fold_type, s=1, logx=True, c=group2colour[status],
x=x_column, y=y_column, s=1, logx=True, c=group2colour[status],
label=f"{status} ({len(group)})", ax=ax)
else:
(status, colour) = group2colour
row_indices = data.index.intersection(grouping)
data.ix[row_indices].plot.scatter(
x="baseMean", y=fold_type, s=1, logx=True, c=colour,
x=x_column, y=y_column, s=1, logx=True, c=colour,
label=f"{status} ({len(row_indices)})", ax=ax)
ax.axhline(y=1, linewidth=0.5, color="0.5", linestyle="dashed")
ax.axhline(y=-1, linewidth=0.5, color="0.5", linestyle="dashed")
......@@ -403,10 +406,10 @@ def plot_MA(res,
ax.set_xlim(mean_range)
if lfc_range is not None:
(lfc_min, lfc_max) = lfc_range
lfc_here_min = getattr(data, fold_type).min()
lfc_here_max = getattr(data, fold_type).max()
lfc_here_min = getattr(data, y_column).min()
lfc_here_max = getattr(data, y_column).max()
if (lfc_here_min < lfc_min) or (lfc_here_max > lfc_max):
warnings.warn(f"Cannot plot {fold_type} data ([{lfc_here_min}, {lfc_here_max}]) in requested range ([{lfc_min}, {lfc_max}])\n")
warnings.warn(f"Cannot plot {y_column} data ([{lfc_here_min}, {lfc_here_max}]) in requested range ([{lfc_min}, {lfc_max}])\n")
else:
ax.set_ylim(lfc_range)
......@@ -420,7 +423,10 @@ def plot_scatter(data,
y_range=None):
fig, ax = plt.subplots()
# First plot the data in grey
data.plot.scatter(x=x_column, y=y_column, s=2, c="lightgray", ax=ax)
data.plot.scatter(
x=x_column, y=y_column,
s=2, c="black", alpha=0.15, edgecolors='none',
ax=ax)
if regression:
linreg = linregress(data[[x_column, y_column]].dropna())
a = linreg.slope
......
......@@ -121,6 +121,50 @@ class Scatterplot:
ax.axhline(y=-1, linewidth=0.5, color="lightblue", linestyle="dashed")
ax.axvline(x=1, linewidth=0.5, color="lightblue", linestyle="dashed")
ax.axvline(x=-1, linewidth=0.5, color="lightblue", linestyle="dashed")
# up_up = 100 * len(self.data.query(
# f"{self.x_col} > 1 & {self.y_col} > 1")) / len(self.data)
# up_down = 100 * len(self.data.query(
# f"{self.x_col} > 1 & {self.y_col} < 1")) / len(self.data)
# down_up = 100 * len(self.data.query(
# f"{self.x_col} < 1 & {self.y_col} > 1")) / len(self.data)
# down_down = 100 * len(self.data.query(
# f"{self.x_col} < 1 & {self.y_col} < 1")) / len(self.data)
up_up = len(self.data.query(
f"{self.x_col} > 1 & {self.y_col} > 1"))
up_down = len(self.data.query(
f"{self.x_col} > 1 & {self.y_col} < 1"))
down_up = len(self.data.query(
f"{self.x_col} < 1 & {self.y_col} > 1"))
down_down = len(self.data.query(
f"{self.x_col} < 1 & {self.y_col} < 1"))
#ax.text(0.9, 0.9, f"{up_up}", transform = ax.transAxes)
#ax.text(0.9, 0.1, f"{up_down}", transform = ax.transAxes)
#ax.text(0.1, 0.9, f"{down_up}", transform = ax.transAxes)
#ax.text(0.1, 0.1, f"{down_down}", transform = ax.transAxes)
ax.annotate(
f"{up_up}", xy=(0.95, 0.95), xycoords='axes fraction',
size="x-small", color="lightblue",
horizontalalignment="right",
verticalalignment="top")
ax.annotate(
f"{up_down}", xy=(0.95, 0.05), xycoords='axes fraction',
size="x-small", color="lightblue",
horizontalalignment="right",
verticalalignment="bottom")
ax.annotate(
f"{down_up}", xy=(0.05, 0.95), xycoords='axes fraction',
size="x-small", color="lightblue",
horizontalalignment="left",
verticalalignment="top")
ax.annotate(
f"{down_down}", xy=(0.05, 0.05), xycoords='axes fraction',
size="x-small", color="lightblue",
horizontalalignment="left",
verticalalignment="bottom")
# Move legend to middle top
#ax.legend(loc="upper center")
ax.legend(bbox_to_anchor=(0, 1), bbox_transform=ax.transAxes, loc="lower left")
#ax.legend(bbox_to_anchor=(0.5, 0.9), bbox_transform=ax.transAxes)
# TODO: force ticks to be integers
return ax
return plot_lfclfc_scatter
......@@ -129,7 +173,8 @@ class Scatterplot:
if grouping is None and self.grouping_col is not None:
grouping = self.grouping_col
save_plot(outfile, self.plot_maker(
grouping=grouping, group2colour=group2colour, **kwargs), equal_axes=True)
grouping=grouping, group2colour=group2colour, **kwargs),
equal_axes=True, tight=True)
#########################################################
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment