diff --git a/libhts/libhts.py b/libhts/libhts.py
index 24b9e20b65cc903acc951648b34a74bd37ee2060..45b9f6e5e8317b0cc6322d336d22b3fa573feff6 100644
--- a/libhts/libhts.py
+++ b/libhts/libhts.py
@@ -9,6 +9,7 @@ def formatwarning(message, category, filename, lineno, line):
 
 warnings.formatwarning = formatwarning
 
+import numpy as np
 import pandas as pd
 # To compute correlation coefficient, and compute linear regression
 from scipy.stats.stats import pearsonr, linregress
@@ -242,6 +243,27 @@ def do_deseq2(cond_names, conditions, counts_data,
     return res, {cond : size_factors.loc[cond][0] for cond in cond_names}
 
 
+## Not sure this is a good idea...
+# def masked_gmean(a, axis=0, dtype=None):
+#     """Modified from stats.py."""
+#     # Converts the data into a masked array
+#     ma = np.ma.masked_invalid(a)
+#     # Apply gmean
+#     if not isinstance(ma, np.ndarray):
+#         # if not an ndarray object attempt to convert it
+#         log_a = np.log(np.array(ma, dtype=dtype))
+#     elif dtype:
+#         # Must change the default dtype allowing array type
+#         if isinstance(ma, np.ma.MaskedArray):
+#             log_a = np.log(np.ma.asarray(ma, dtype=dtype))
+#         else:
+#             log_a = np.log(np.asarray(ma, dtype=dtype))
+#     else:
+#         log_a = np.log(ma)
+#     return np.exp(log_a.mean(axis=axis))
+
+
+
 def median_ratio_to_pseudo_ref_size_factors(counts_data):
     """Adapted from DESeq paper (doi:10.1186/gb-2010-11-10-r106)
     All libraries are used to define a pseudo-reference, which has
@@ -252,11 +274,20 @@ def median_ratio_to_pseudo_ref_size_factors(counts_data):
     #pseudo_ref = (counts_data + 1).apply(gmean, axis=1) - 1
     # Ignore lines with zeroes instead (may be bad for IP: many zeroes expected):
     pseudo_ref = (counts_data[counts_data.prod(axis=1) > 0]).apply(gmean, axis=1)
+    # Ignore lines with only zeroes
+    # pseudo_ref = (counts_data[counts_data.sum(axis=1) > 0]).apply(masked_gmean, axis=1)
     def median_ratio_to_pseudo_ref(col):
         return (col / pseudo_ref).median()
     #size_factors = counts_data.apply(median_ratio_to_pseudo_ref, axis=0)
-    return counts_data[counts_data.prod(axis=1) > 0].apply(
+    median_ratios = counts_data[counts_data.prod(axis=1) > 0].apply(
         median_ratio_to_pseudo_ref, axis=0)
+    # Not sure fillna(0) is appropriate
+    if any(median_ratios.isna()):
+        msg = "Could not compute median ratios to pseudo reference.\n"
+        warnings.warn(msg)
+        return median_ratios.fillna(1)
+    else:
+        return median_ratios
 
 
 def size_factor_correlations(counts_data, summaries, normalizer):
@@ -442,7 +473,8 @@ def plot_scatter(data,
                  grouping=None,
                  group2colour=None,
                  x_range=None,
-                 y_range=None):
+                 y_range=None,
+                 axes_style=None):
     fig, ax = plt.subplots()
     # First plot the data in grey
     data.plot.scatter(
@@ -468,7 +500,7 @@ def plot_scatter(data,
             if group2colour is None:
                 statuses = data[grouping].unique()
                 group2colour = dict(zip(statuses, sns.color_palette("colorblind", len(statuses))))
-            for status, group in data.groupby(grouping):
+            for (status, group) in data.groupby(grouping):
                 group.plot.scatter(
                     x=x_column, y=y_column, s=1, c=group2colour[status],
                     label=f"{status} ({len(group)})", ax=ax)
@@ -486,8 +518,12 @@ def plot_scatter(data,
                     raise
                 else:
                     warnings.warn(f"Nothing to plot for {status}\n")
-    ax.axhline(y=0, linewidth=0.5, color="0.5", linestyle="dashed")
-    ax.axvline(x=0, linewidth=0.5, color="0.5", linestyle="dashed")
+    if axes_style is None:
+        axes_style = {"linewidth": 0.5, "color": "0.5", "linestyle": "dashed"}
+    ax.axhline(y=0, **axes_style)
+    ax.axvline(x=0, **axes_style)
+    # ax.axhline(y=0, linewidth=0.5, color="0.5", linestyle="dashed")
+    # ax.axvline(x=0, linewidth=0.5, color="0.5", linestyle="dashed")
     # Set axis limits
     if x_range is not None:
         (x_min, x_max) = x_range