diff --git a/src/napari_segment/_reader.py b/src/napari_segment/_reader.py index bb844133e4e86eaaee3d01ee2b7e8951d5d79ee2..a7d7d7cac232b7c7d15be9fba8ca64a4ad7b287c 100644 --- a/src/napari_segment/_reader.py +++ b/src/napari_segment/_reader.py @@ -102,7 +102,7 @@ def read_zarr(path): def read_nd2(path): - print(f"opening {path}") + print(f"reading {path}") data = nd2.ND2File(path) print(data.sizes) ddata = data.to_dask() @@ -116,7 +116,10 @@ def read_nd2(path): return [ ( ddata, - {"channel_axis": channel_axis}, + { + "channel_axis": channel_axis, + "metadata": {"sizes": data.sizes, "path": path}, + }, # dict( # channel_axis=channel_axis, # name=[ch.channel.name for ch in data.metadata.channels], diff --git a/src/napari_segment/_widget.py b/src/napari_segment/_widget.py index fb09f3c78aa9ddc180fe71da6abcde04bedf6cd1..5370fb734b250e04b7ba8fcd289edf7fede53814 100644 --- a/src/napari_segment/_widget.py +++ b/src/napari_segment/_widget.py @@ -6,13 +6,18 @@ see: https://napari.org/plugins/stable/guides.html#widgets Replace code below according to your needs. """ +import os from functools import partial import dask +import magicgui.widgets as w import napari import numpy as np +import qtpy.QtWidgets as q +import yaml from magicgui import magic_factory -from qtpy.QtWidgets import QHBoxLayout, QPushButton, QWidget +from napari.layers import Image +from napari.utils.notifications import show_error from scipy.ndimage import ( binary_erosion, binary_fill_holes, @@ -24,7 +29,7 @@ from skimage.measure import regionprops # import matplotlib.pyplot as plt -class ExampleQWidget(QWidget): +class ExampleQWidget(q.QWidget): # your QWidget.__init__ can optionally request the napari viewer instance # in one of two ways: # 1. use a parameter called `napari_viewer`, as done here @@ -33,14 +38,249 @@ class ExampleQWidget(QWidget): super().__init__() self.viewer = napari_viewer - btn = QPushButton("Click me!") - btn.clicked.connect(self._on_click) + self.data = w.ComboBox( + label="BF data", + annotation=Image, + choices=[ + layer.name + for layer in self.viewer.layers + if isinstance(layer, Image) + ], + ) + self.data.changed.connect(self.restore_params) + self.data.changed.connect(self.preprocess) + + self.binning_widget = w.RadioButtons( + label="binning", + choices=[2**n for n in range(4)], + value=4, + orientation="horizontal", + ) + self.binning_widget.changed.connect(self.preprocess) + + self.thr = w.FloatSlider(label="Threshold", min=0.1, max=0.9) + self.thr.changed.connect(self.threshold) + + self.erode = w.SpinBox(label="erode", min=0, max=10) + self.erode.changed.connect(self.threshold) + + self.use = w.RadioButtons( + label="Use", + choices=["Intensity", "Gradient"], + value="Intensity", + orientation="horizontal", + ) + self.use.changed.connect(self.preprocess) + + self.smooth = w.SpinBox(label="smooth", min=0, max=10) + self.smooth.changed.connect(self.preprocess) + + self.min_diam = w.Slider( + label="Min_diameter", min=50, max=500, step=50 + ) + self.min_diam.changed.connect(self.update_out) - self.setLayout(QHBoxLayout()) - self.layout().addWidget(btn) + self.max_diam = w.Slider( + label="Max_diameter", min=150, max=2000, step=150 + ) + self.max_diam.changed.connect(self.update_out) + + self.max_ecc = w.FloatSlider( + label="Max eccentricity", + min=0.0, + max=1.0, + ) + self.max_ecc.changed.connect(self.update_out) + + self.btn = q.QPushButton("Save!") + self.btn.clicked.connect(self.save_params) + + self.container = w.Container( + widgets=[ + self.data, + w.Label(label="Prepocessing"), + self.binning_widget, + self.use, + w.Label(label="Detection"), + self.smooth, + self.thr, + self.erode, + w.Label(label="Filters"), + self.min_diam, + self.max_diam, + self.max_ecc, + ] + ) + + self.setLayout(q.QVBoxLayout()) + self.layout().addWidget(self.container.native) + self.layout().addWidget(self.btn) + self.layout().addStretch() - def _on_click(self): - print("napari has", len(self.viewer.layers), "layers") + self.viewer.layers.events.inserted.connect(self.reset_choices) + self.viewer.layers.events.removed.connect(self.reset_choices) + + if self.data.current_choice: + print("start") + self.restore_params() + self.preprocess() + + def preprocess(self): + self.binning = self.binning_widget.value + try: + self.ddata = self.viewer.layers[self.data.current_choice].data[ + ..., :: self.binning, :: self.binning + ] + except KeyError: + show_error("No data to process") + return + self.scale = np.ones((len(self.ddata.shape),)) + self.scale[-2:] = self.binning + if isinstance(self.ddata, np.ndarray): + chunksize = np.ones(self.ddata.ndims) + chunksize[-2:] = self.ddata.shape[-2:] # xy full size + self.ddata = dask.array.from_array(self.ddata, chunksize=chunksize) + self.smooth_gradient = ( + self.ddata.map_blocks( + partial(get_gradient, smooth=self.smooth.value), + dtype=self.ddata.dtype, + ) + if self.use.value == "Gradient" + else self.ddata.map_blocks( + lambda d: 1 + - (a := (s := gaussian_filter(d, self.smooth.value)) - s.min()) + / a.max(), + dtype=self.ddata.dtype, + ) + ) + if not (name := "Preprocessing") in self.viewer.layers: + self.viewer.add_image( + data=self.smooth_gradient, + **{"name": name, "scale": self.scale}, + ) + else: + self.viewer.layers[name].data = self.smooth_gradient + self.viewer.layers[name].scale = self.scale + self.threshold() + + def threshold(self): + if not self.data.current_choice: + return + self.labels = self.smooth_gradient.map_blocks( + partial( + threshold_gradient, thr=self.thr.value, erode=self.erode.value + ), + dtype=self.ddata.dtype, + ) + if not (name := "Detections") in self.viewer.layers: + self.viewer.add_labels( + data=self.labels, + opacity=0.3, + **{"name": name, "scale": self.scale}, + ) + else: + self.viewer.layers[name].data = self.labels + self.viewer.layers[name].scale = self.scale + self.viewer.layers[name].contour = 5 + self.update_out() + + def update_out(self): + + if not self.data.current_choice: + return + + try: + selected_labels = dask.array.map_blocks( + filter_labels, + self.labels, + self.min_diam.value / self.binning, + self.max_diam.value / self.binning, + self.max_ecc.value, + dtype=self.ddata.dtype, + ) + + if not (name := "selected labels") in self.viewer.layers: + self.viewer.add_labels( + data=selected_labels, + opacity=0.5, + **{"name": name, "scale": self.scale}, + ) + else: + self.viewer.layers[name].scale = self.scale + self.viewer.layers[name].data = selected_labels + # self.save_params() + except TypeError: + show_error("Relax filter!") + + def save_params(self): + data = { + "binning": self.binning, + "use": self.use.value, + "smooth": self.smooth.value, + "thr": self.thr.value, + "erode": self.erode.value, + "min_diameter": self.min_diam.value, + "max_diameter": self.max_diam.value, + "max_ecc": self.max_ecc.value, + } + try: + path = self.viewer.layers[self.data.current_choice].metadata[ + "path" + ] + dir = os.path.dirname(path) + filename = os.path.basename(path) + new_name = filename.replace(".nd2", ".params.yaml") + + with open(os.path.join(dir, new_name), "w") as f: + yaml.safe_dump(data, f) + except KeyError: + pass + with open(os.path.join(".latest.params.yaml"), "w") as f: + yaml.safe_dump(data, f) + + def restore_params(self): + try: + path = self.viewer.layers[self.data.current_choice].metadata[ + "path" + ] + dir = os.path.dirname(path) + filename = os.path.basename(path) + new_name = filename.replace(".nd2", ".params.yaml") + except KeyError: + pass + try: + with open(ppp := os.path.join(dir, new_name)) as f: + data = yaml.safe_load(f) + except OSError: + with open(ppp := os.path.join(".latest.params.yaml")) as f: + data = yaml.safe_load(f) + print(f"restoring parameters from {ppp}") + print(data) + try: + self.binning_widget.value = data["binning"] + self.use.value = data["use"] + self.smooth.value = data["smooth"] + self.thr.value = data["thr"] + self.erode.value = data["erode"] + self.min_diam.value = data["min_diameter"] + self.max_diam.value = data["max_diameter"] + self.max_ecc.value = data["max_ecc"] + except Exception as e: + show_error(f"Restore settings failed, {e}") + + def reset_choices(self, event=None): + self.data.reset_choices(event) + self.data.choices = [ + layer.name + for layer in self.viewer.layers + if isinstance(layer, Image) and layer.name != "Preprocessing" + ] + self.restore_params() + + +def norm01(data): + d = data.copy() + return (a := d - d.min()) / a.max() @magic_factory @@ -102,7 +342,11 @@ def segment_organoid( partial(get_gradient, smooth=smooth), dtype=ddata.dtype ) if use_gradient - else 1 - (a := ddata - ddata.min()) / a.max() + else ddata.map_blocks( + lambda d: 1 + - (a := (s := gaussian_filter(d, smooth)) - s.min()) / a.max(), + dtype=ddata.dtype, + ) ) labels = smooth_gradient.map_blocks( partial(threshold_gradient, thr=thr, erode=erode), @@ -121,6 +365,11 @@ def segment_organoid( scale[-2:] = bin kwargs = {"scale": scale} return [ + ( + smooth_gradient, + {"name": "Gradient", "visible": show_detections, **kwargs}, + "image", + ), ( labels, {"name": "Detections", "visible": show_detections, **kwargs}, @@ -145,9 +394,7 @@ def segment_organoid( def filter_labels(labels, min_diam=50, max_diam=150, max_ecc=0.2): if max_diam <= min_diam: - raise ValueError( - "min value is greater than max value for the diameter filter" - ) + min_diam = max_diam - 10 data = strip_dimensions(labels) props = regionprops( data, @@ -178,7 +425,7 @@ def get_gradient(bf_data: np.ndarray, smooth=10, bin=1): gradient = get_2d_gradient(data[::bin, ::bin]) smoothed_gradient = gaussian_filter(gradient, smooth) # sm = multiwell.gaussian_filter(well, smooth) - return smoothed_gradient.reshape(bf_data[..., ::bin, ::bin].shape) + return norm01(smoothed_gradient.reshape(bf_data[..., ::bin, ::bin].shape)) def threshold_gradient(