diff --git a/libhts/__init__.py b/libhts/__init__.py
index 3d089bba0c487d9853f2a98da83e4e10a62bb05a..2da39117da59070c590205818dfdcb687e918369 100644
--- a/libhts/__init__.py
+++ b/libhts/__init__.py
@@ -1 +1 @@
-from .libhts import do_deseq2, gtf_2_genes_exon_lengths, median_ratio_to_pseudo_ref_size_factors, plot_boxplots, plot_counts_distribution, plot_lfc_distribution, plot_MA, plot_norm_correlations, repeat_bed_2_lengths, size_factor_correlations, status_setter
+from .libhts import do_deseq2, gtf_2_genes_exon_lengths, median_ratio_to_pseudo_ref_size_factors, plot_boxplots, plot_counts_distribution, plot_lfc_distribution, plot_MA, plot_norm_correlations, plot_scatter, repeat_bed_2_lengths, size_factor_correlations, status_setter
diff --git a/libhts/libhts.py b/libhts/libhts.py
index 7b23e8a19c4f3155c1b7b149979dbe06691b7c6e..fa5e6e168c07e00995e95e6eef170bd69114f76a 100644
--- a/libhts/libhts.py
+++ b/libhts/libhts.py
@@ -10,8 +10,8 @@ def formatwarning(message, category, filename, lineno, line):
 warnings.formatwarning = formatwarning
 
 import pandas as pd
-# To compute correlation coefficient
-from scipy.stats.stats import pearsonr
+# To compute correlation coefficient, and compute linear regression
+from scipy.stats.stats import pearsonr, linregress
 # To compute geometric mean
 from scipy.stats.mstats import gmean
 import matplotlib.pyplot as plt
@@ -150,6 +150,7 @@ def do_deseq2(cond_names, conditions, counts_data,
     if formula is None:
         formula = Formula("~ lib")
     if contrast is None:
+        # FIXME: MUT and REF are not defined
         contrast = StrVector(["lib", MUT, REF])
     if deseq2_args is None:
         deseq2_args = {"betaPrior" : True, "addMLE" : True, "independentFiltering" : True}
@@ -374,22 +375,27 @@ def plot_MA(res,
     This may not be good for genes with low expression levels."""
     fig, ax = plt.subplots()
     # Make a column indicating whether the gene is DE or NS
-    res = res.assign(is_DE=res.apply(set_de_status, axis=1))
+    data = res.assign(is_DE=res.apply(set_de_status, axis=1))
     if fold_type is None:
         fold_type = "log2FoldChange"
     # First plot the data in grey and black
-    for de_status, group in res.groupby("is_DE"):
+    for de_status, group in data.groupby("is_DE"):
         group.plot.scatter(x="baseMean", y=fold_type, s=2, logx=True, c=DE2COLOUR[de_status], label=f"{de_status} ({len(group)})", ax=ax)
     if grouping is not None:
         if isinstance(grouping, str):
             # Overlay colours based on the "grouping" column
             if group2colour is None:
                 group2colour = STATUS2COLOUR
-            for status, group in res.groupby(grouping):
-                group.plot.scatter(x="baseMean", y=fold_type, s=1, logx=True, c=group2colour[status], label=f"{status} ({len(group)})", ax=ax)
+            for status, group in data.groupby(grouping):
+                group.plot.scatter(
+                    x="baseMean", y=fold_type, s=1, logx=True, c=group2colour[status],
+                    label=f"{status} ({len(group)})", ax=ax)
         else:
             (status, colour) = group2colour
-            res.ix[grouping].plot.scatter(x="baseMean", y=fold_type, s=1, logx=True, c=colour, label=f"{status} ({len(grouping)})", ax=ax)
+            row_indices = data.index.intersection(grouping)
+            data.ix[row_indices].plot.scatter(
+                x="baseMean", y=fold_type, s=1, logx=True, c=colour,
+                label=f"{status} ({len(row_indices)})", ax=ax)
     ax.axhline(y=1, linewidth=0.5, color="0.5", linestyle="dashed")
     ax.axhline(y=-1, linewidth=0.5, color="0.5", linestyle="dashed")
     # TODO: check data basemean range
@@ -397,10 +403,71 @@ def plot_MA(res,
         ax.set_xlim(mean_range)
     if lfc_range is not None:
         (lfc_min, lfc_max) = lfc_range
-        lfc_here_min = getattr(res, fold_type).min()
-        lfc_here_max = getattr(res, fold_type).max()
+        lfc_here_min = getattr(data, fold_type).min()
+        lfc_here_max = getattr(data, fold_type).max()
         if (lfc_here_min < lfc_min) or (lfc_here_max > lfc_max):
             warnings.warn(f"Cannot plot {fold_type} data ([{lfc_here_min}, {lfc_here_max}]) in requested range ([{lfc_min}, {lfc_max}])\n")
         else:
             ax.set_ylim(lfc_range)
 
+def plot_scatter(data,
+                 x_column,
+                 y_column,
+                 regression=False,
+                 grouping=None,
+                 group2colour=None,
+                 x_range=None,
+                 y_range=None):
+    """Scatter-plot column *y_column* of *data* against column *x_column*.
+
+    All points are drawn in light grey first. If *regression* is true, a
+    dashed least-squares line is added. If *grouping* is a column name,
+    points are overlaid coloured per value of that column, according to
+    *group2colour* (default: seaborn "colorblind" palette). Otherwise,
+    *grouping* is a collection of index labels and *group2colour* a
+    (label, colour) pair used to highlight those rows present in *data*.
+    *x_range* and *y_range* are (min, max) axis limits; limits that would
+    hide data are not applied, a warning is issued instead. Returns the axes."""
+    _, ax = plt.subplots()
+    # First plot the data in grey
+    data.plot.scatter(x=x_column, y=y_column, s=2, c="lightgray", ax=ax)
+    if regression:
+        # Pass x and y separately: a lone (2, 2) array would be split row-wise
+        pairs = data[[x_column, y_column]].dropna()
+        linreg = linregress(pairs[x_column], pairs[y_column])
+        xfit = (data[x_column].min(), data[x_column].max())
+        yfit = tuple((linreg.slope * x) + linreg.intercept for x in xfit)
+        ax.plot(xfit, yfit, linewidth=0.5, color="0.5", linestyle="dashed")
+    # Overlay colour points
+    if grouping is not None:
+        if isinstance(grouping, str):
+            # Determine colours based on the "grouping" column
+            if group2colour is None:
+                statuses = data[grouping].unique()
+                group2colour = dict(zip(statuses, sns.color_palette("colorblind", len(statuses))))
+            for status, group in data.groupby(grouping):
+                group.plot.scatter(
+                    x=x_column, y=y_column, s=1, c=group2colour[status],
+                    label=f"{status} ({len(group)})", ax=ax)
+        else:
+            # Apply a colour to a list of genes
+            (status, colour) = group2colour
+            row_indices = data.index.intersection(grouping)
+            # Label-based .loc: the .ix indexer was removed in pandas 1.0
+            data.loc[row_indices].plot.scatter(
+                x=x_column, y=y_column, s=1, c=colour,
+                label=f"{status} ({len(row_indices)})", ax=ax)
+    ax.axhline(y=0, linewidth=0.5, color="0.5", linestyle="dashed")
+    ax.axvline(x=0, linewidth=0.5, color="0.5", linestyle="dashed")
+    # Set axis limits; data[column] (not getattr) works for any column name
+    axes_limits = ((x_column, x_range, ax.set_xlim), (y_column, y_range, ax.set_ylim))
+    for (column, limits, set_limits) in axes_limits:
+        if limits is None:
+            continue
+        (lim_min, lim_max) = limits
+        (here_min, here_max) = (data[column].min(), data[column].max())
+        if (here_min < lim_min) or (here_max > lim_max):
+            warnings.warn(f"Cannot plot {column} data ([{here_min}, {here_max}]) in requested range ([{lim_min}, {lim_max}])\n")
+        else:
+            set_limits(limits)
+    return ax