diff --git a/src/napari_segment/_widget.py b/src/napari_segment/_widget.py index 9bab1f7cfa0916c8168c9bc1ce77116035df3bd1..204e25f61e72ff01af439d2bfea0662955be2fb0 100644 --- a/src/napari_segment/_widget.py +++ b/src/napari_segment/_widget.py @@ -14,13 +14,12 @@ import numpy as np from magicgui import magic_factory from qtpy.QtWidgets import QHBoxLayout, QPushButton, QWidget from scipy.ndimage import ( - binary_dilation, binary_erosion, binary_fill_holes, gaussian_filter, label, ) -from skimage.measure import regionprops_table +from skimage.measure import regionprops # import matplotlib.pyplot as plt @@ -65,67 +64,118 @@ def save_values(val): "max": 0.5, }, erode={"widget_type": "Slider", "min": 1, "max": 10}, + min_diam={"widget_type": "Slider", "min": 10, "max": 250}, + max_diam={"widget_type": "Slider", "min": 150, "max": 1000}, + max_ecc={ + "label": "Eccentricity", + "widget_type": "FloatSlider", + "min": 0.0, + "max": 1.0, + }, ) def segment_organoid( BF_layer: "napari.layers.Image", fluo_layer: "napari.layers.Image", - thr: float = 0.4, + thr: float = 0.3, erode: int = 10, - donut=10, + min_diam=150, + max_diam=550, + max_ecc=0.7, + show_detections=True, ) -> napari.types.LayerDataTuple: # frame = napari.current_viewer().cursor.position[0] kwargs = {} ddata = BF_layer.data - labels = ddata.map_blocks( - partial(segment_bf, thr=thr, erode=erode), dtype=ddata.dtype + smooth_gradient = ddata.map_blocks( + partial(get_gradient), dtype=ddata.dtype ) - selected_labels = dask.array.map_blocks(filter_biggest, labels, donut) - print(selected_labels.shape) - return [ - (labels, {"name": "raw_labels", "visible": False, **kwargs}, "labels"), - (selected_labels, {"name": "selected labels", **kwargs}, "labels"), - ] - - -def filter_biggest(labels, donut=10): - props = regionprops_table( + labels = smooth_gradient.map_blocks( + partial(threshold_gradient, thr=thr, erode=erode), + dtype=ddata.dtype, + ) + try: + selected_labels = dask.array.map_blocks( + filter_labels, + labels, + min_diam, + max_diam, + max_ecc, + dtype=ddata.dtype, + ) + # print(selected_labels.shape) + return [ + ( + labels, + {"name": "Detections", "visible": show_detections, **kwargs}, + "labels", + ), + (selected_labels, {"name": "selected labels", **kwargs}, "labels"), + ] + except TypeError: + return [ + ( + labels, + {"name": "Detections", "visible": True, **kwargs}, + "labels", + ), + ] + + +def filter_labels(labels, min_diam=50, max_diam=150, max_ecc=0.2): + if max_diam <= min_diam: + raise ValueError( + "min value is greater than max value for the diameter filter" + ) + props = regionprops( labels[0], - properties=( - "label", - "area", - ), ) - biggest_prop_index = np.argmax(props["area"]) - label_of_biggest_object = props["label"][biggest_prop_index] - spheroid_mask = labels[0] == label_of_biggest_object - bg_mask = np.bitwise_xor( - binary_dilation(spheroid_mask, structure=np.ones((donut, donut))), - spheroid_mask, + good_props = filter( + lambda p: (d := p.major_axis_length) > min_diam + and d < max_diam + and p.eccentricity < max_ecc, + props, ) - return ( - spheroid_mask.astype("uint16") + 2 * bg_mask.astype("uint16") - ).reshape(labels.shape) + good_labels = [p.label for p in good_props] + if len(good_labels) < 1: + return np.zeros_like(labels) + # print(f'good_labels {good_labels}') + mask = np.sum([labels == v for v in good_labels], axis=0) + # print(mask.shape) + return (label(mask)[0].astype("uint16")).reshape(labels.shape) -def segment_bf(well, thr=0.2, smooth=10, erode=10, fill=True, plot=False): +def get_gradient(bf_data: np.ndarray, smooth=10): """ - Serments input 2d array using thresholded gradient with filling + Removes first dimension, + Computes gradient of the image, + applies gaussian filter Returns SegmentedImage object """ - grad = get_2d_gradient(well[0]) - sm = gaussian_filter(grad, smooth) + assert ( + bf_data[0].ndim == 2 + ), f"expected 2D shape, got shape {bf_data[0].shape}" + gradient = get_2d_gradient(bf_data[0]) + smoothed_gradient = gaussian_filter(gradient, smooth) # sm = multiwell.gaussian_filter(well, smooth) + return smoothed_gradient.reshape(bf_data.shape) - regions = sm > thr * sm.max() + +def threshold_gradient( + smoothed_gradient: np.ndarray, + thr: float = 0.4, + fill: bool = True, + erode: int = 1, +): + regions = smoothed_gradient[0] > thr * smoothed_gradient[0].max() if fill: regions = binary_fill_holes(regions) - if erode: + if erode and erode > 0: regions = binary_erosion(regions, iterations=erode) labels, _ = label(regions) - return labels.reshape(well.shape) + return labels.reshape(smoothed_gradient.shape) def get_2d_gradient(xy):