diff --git a/src/napari_segment/_widget.py b/src/napari_segment/_widget.py index 204e25f61e72ff01af439d2bfea0662955be2fb0..fc85458d602092bdc633bd279d8939e428965ffc 100644 --- a/src/napari_segment/_widget.py +++ b/src/napari_segment/_widget.py @@ -151,10 +151,8 @@ def get_gradient(bf_data: np.ndarray, smooth=10): applies gaussian filter Returns SegmentedImage object """ - assert ( - bf_data[0].ndim == 2 - ), f"expected 2D shape, got shape {bf_data[0].shape}" - gradient = get_2d_gradient(bf_data[0]) + data = strip_dimensions(bf_data) + gradient = get_2d_gradient(data) smoothed_gradient = gaussian_filter(gradient, smooth) # sm = multiwell.gaussian_filter(well, smooth) return smoothed_gradient.reshape(bf_data.shape) @@ -166,7 +164,8 @@ def threshold_gradient( fill: bool = True, erode: int = 1, ): - regions = smoothed_gradient[0] > thr * smoothed_gradient[0].max() + data = strip_dimensions(smoothed_gradient) + regions = data > thr * data.max() if fill: regions = binary_fill_holes(regions) @@ -178,6 +177,13 @@ def threshold_gradient( return labels.reshape(smoothed_gradient.shape) +def strip_dimensions(array:np.ndarray): + data = array.copy() + while data.ndim > 2: + assert data.shape[0] == 1, f'Unexpected multidimensional data! {data.shape}' + data = data[0] + return data + def get_2d_gradient(xy): gx, gy = np.gradient(xy) return np.sqrt(gx**2 + gy**2)