diff --git a/src/napari_segment/_widget.py b/src/napari_segment/_widget.py
index 204e25f61e72ff01af439d2bfea0662955be2fb0..fc85458d602092bdc633bd279d8939e428965ffc 100644
--- a/src/napari_segment/_widget.py
+++ b/src/napari_segment/_widget.py
@@ -151,10 +151,8 @@ def get_gradient(bf_data: np.ndarray, smooth=10):
     applies gaussian filter
     Returns SegmentedImage object
     """
-    assert (
-        bf_data[0].ndim == 2
-    ), f"expected 2D shape, got shape {bf_data[0].shape}"
-    gradient = get_2d_gradient(bf_data[0])
+    data = strip_dimensions(bf_data)
+    gradient = get_2d_gradient(data)
     smoothed_gradient = gaussian_filter(gradient, smooth)
     #     sm = multiwell.gaussian_filter(well, smooth)
     return smoothed_gradient.reshape(bf_data.shape)
@@ -166,7 +164,8 @@ def threshold_gradient(
     fill: bool = True,
     erode: int = 1,
 ):
-    regions = smoothed_gradient[0] > thr * smoothed_gradient[0].max()
+    data = strip_dimensions(smoothed_gradient)
+    regions = data > thr * data.max()
 
     if fill:
         regions = binary_fill_holes(regions)
@@ -178,6 +177,13 @@ def threshold_gradient(
     return labels.reshape(smoothed_gradient.shape)
 
 
+def strip_dimensions(array:np.ndarray):
+    data = array.copy()
+    while data.ndim > 2:
+        assert data.shape[0] == 1, f'Unexpected multidimensional data! {data.shape}'
+        data = data[0]
+    return data
+
 def get_2d_gradient(xy):
     gx, gy = np.gradient(xy)
     return np.sqrt(gx**2 + gy**2)