diff --git a/src/napari_segment/_reader.py b/src/napari_segment/_reader.py index 7278e96fa40101eebc58f6e08c168d5d5b6a0176..a93237b6af2a551a7cc1b938769ce2703fa90ff2 100644 --- a/src/napari_segment/_reader.py +++ b/src/napari_segment/_reader.py @@ -126,6 +126,11 @@ def read_zarr(path): def read_nd2(path): print(f"reading {path}") data = nd2.ND2File(path) + try: + pixel_size_um = data.metadata.channels[0].volume.axesCalibration[0] + except Exception as e: + print(f'Pixel information unavailable: {e}') + pixel_size_um = 1 print(data.sizes) ddata = data.to_dask() # colormap = ["gray", "green"] @@ -140,7 +145,11 @@ def read_nd2(path): ddata, { "channel_axis": channel_axis, - "metadata": {"sizes": data.sizes, "path": path}, + "metadata": { + "sizes": data.sizes, + "path": path, + "pixel_size_um": pixel_size_um + }, }, # dict( # channel_axis=channel_axis, diff --git a/src/napari_segment/_widget.py b/src/napari_segment/_widget.py index d1c3374e340c3d94603c33a36cc3c92675b4b62a..9537ae18e27f51d17c09efe5b36d22c38f837418 100644 --- a/src/napari_segment/_widget.py +++ b/src/napari_segment/_widget.py @@ -6,6 +6,7 @@ see: https://napari.org/plugins/stable/guides.html#widgets Replace code below according to your needs. """ +from enum import Enum import os from functools import partial @@ -24,16 +25,25 @@ from scipy.ndimage import ( gaussian_filter, label, ) +from functools import reduce from skimage.measure import regionprops -# import matplotlib.pyplot as plt - +from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas +from matplotlib.figure import Figure +from threading import Thread class ExampleQWidget(q.QWidget): # your QWidget.__init__ can optionally request the napari viewer instance # in one of two ways: # 1. use a parameter called `napari_viewer`, as done here # 2. use a type annotation of 'napari.viewer.Viewer' for any parameter + + class Choices(Enum): + INT="Invert" + GRAD="Gradient" + GDIF="Gauss diff" + + def __init__(self, napari_viewer): super().__init__() self.viewer = napari_viewer @@ -47,8 +57,6 @@ class ExampleQWidget(q.QWidget): if isinstance(layer, Image) ], ) - self.input.changed.connect(self.restore_params) - self.input.changed.connect(self.preprocess) self.binning_widget = w.RadioButtons( label="binning", @@ -56,46 +64,50 @@ class ExampleQWidget(q.QWidget): value=4, orientation="horizontal", ) - self.binning_widget.changed.connect(self.preprocess) - self.thr = w.FloatSlider(label="Threshold", min=0.1, max=0.9) - self.thr.changed.connect(self.threshold) + self.thr = w.FloatSlider(label="Threshold", min=0.1, max=0.9, value=.4) - self.erode = w.SpinBox(label="erode", min=0, max=10) - self.erode.changed.connect(self.threshold) + self.erode = w.SpinBox(label="erode", min=0, max=10, value=0) self.use = w.RadioButtons( label="Use", - choices=["Intensity", "Gradient", "GDif"], - value="Intensity", + choices=[v.value for v in self.Choices], + value=self.Choices.INT.value, orientation="horizontal", + allow_multiple=True ) - self.use.changed.connect(self.preprocess) - self.smooth = w.SpinBox(label="smooth", min=0, max=10) - self.smooth.changed.connect(self.preprocess) + self.smooth = w.SpinBox(label="smooth", min=0, max=10, value=2) self.min_diam = w.Slider( label="Min_diameter", min=1, max=500, ) - self.min_diam.changed.connect(self.update_out) self.max_diam = w.Slider( label="Max_diameter", min=150, max=2000, step=150 ) - self.max_diam.changed.connect(self.update_out) self.max_ecc = w.FloatSlider( label="Max eccentricity", min=0.0, max=1.0, + value=.9 ) - self.max_ecc.changed.connect(self.update_out) self.btn = q.QPushButton("Save!") - self.btn.clicked.connect(self.save_params) + + self.canvas = FigureCanvas(Figure(figsize=(5,5))) + self.ax = self.canvas.figure.subplots(nrows=3, sharex=True) + self.ax[0].set_title("Number of detections") + self.ax[1].set_title(diams_title := "Diameters") + self.diams_title = diams_title + self.ax[2].set_title("Eccentricities") + + self._count, = self.ax[0].plot(range(10), [0]*10) + self._diams, = self.ax[1].plot(range(10), [0]*10) + self._eccs, = self.ax[2].plot(range(10), [0]*10) self.container = w.Container( widgets=[ @@ -117,16 +129,40 @@ class ExampleQWidget(q.QWidget): self.setLayout(q.QVBoxLayout()) self.layout().addWidget(self.container.native) self.layout().addWidget(self.btn) + self.layout().addWidget(self.canvas) self.layout().addStretch() self.viewer.layers.events.inserted.connect(self.reset_choices) self.viewer.layers.events.removed.connect(self.reset_choices) + self.input.changed.connect(self.restore_params) + self.input.changed.connect(self.preprocess) + self.binning_widget.changed.connect(self.preprocess) + self.thr.changed.connect(self.threshold) + self.erode.changed.connect(self.threshold) + self.use.changed.connect(self.preprocess) + self.smooth.changed.connect(self.preprocess) + self.min_diam.changed.connect(self.update_out) + self.max_diam.changed.connect(self.update_out) + self.max_ecc.changed.connect(self.update_out) + self.btn.clicked.connect(self.save_params) if self.input.current_choice: print("start") self.restore_params() self.preprocess() + def _invert(self, data2D): + return 1 - norm01(gaussian_filter(data2D, self.smooth.value)) + + def _grad(self, data2D): + return get_gradient(data2D, smooth=self.smooth.value) + + def _gdif(self, data2D): + return ( + gaussian_filter(data2D, self.smooth.value) + - gaussian_filter(data2D, self.smooth.value + 2) + ) + def preprocess(self): self.binning = self.binning_widget.value try: @@ -134,8 +170,17 @@ class ExampleQWidget(q.QWidget): ..., :: self.binning, :: self.binning ] except KeyError: - show_error("No data to process") return + + try: + self.pixel_size = self.viewer.layers[self.input.current_choice].metadata["pixel_size_um"] + self.unit = "um" + except KeyError: + self.pixel_size = 1 + self.unit = "px" + + self.ax[1].set_title(f"{self.diams_title}, {self.unit}") + self.scale = np.ones((len(self.data.shape),)) self.scale[-2:] = self.binning if isinstance(self.data, np.ndarray): @@ -147,23 +192,20 @@ class ExampleQWidget(q.QWidget): else: self.ddata = self.data.astype("f") - show_info(self.use.value) - if self.use.value == "Gradient": + # show_info(self.use.value) + if self.use.value == self.Choices.GRAD.value: self.smooth_gradient = self.ddata.map_blocks( - partial(get_gradient, smooth=self.smooth.value), + self._grad, dtype=self.ddata.dtype, ) - elif self.use.value == "Intensity": + elif self.use.value == self.Choices.INT.value: self.smooth_gradient = self.ddata.map_blocks( - lambda d: 1 - norm01(gaussian_filter(d, self.smooth.value)), + self._invert, dtype=self.ddata.dtype, ) - elif self.use.value == "GDif": + elif self.use.value == self.Choices.GDIF.value: self.smooth_gradient = self.ddata.map_blocks( - lambda d: ( - gaussian_filter(d, self.smooth.value) - - gaussian_filter(d, self.smooth.value + 2) - ), + self._gdif, dtype=self.ddata.dtype, ) else: @@ -171,7 +213,7 @@ class ExampleQWidget(q.QWidget): raise ( ValueError( f"""Filter `{self.use.value}` not understood! - Expected `Gradient` or `Intensity` or `GDif`""" + Expected {[v.value for v in self.Choices]}""" ) ) @@ -231,9 +273,47 @@ class ExampleQWidget(q.QWidget): self.viewer.layers[name].scale = self.scale self.viewer.layers[name].data = selected_labels # self.save_params() + + except TypeError as e: show_error(f"Relax filter! {e}") + try: + + self.plot_stats(selected_labels) + + except Exception as e: + show_error(f"Plot failed: {e}") + + def plot_stats(self, data): + + props = [regionprops(label_image=img) for img in data.compute()] + num_regions_per_frame = [len(p) for p in props] + diams_ = [ + [ + (i, prop.major_axis_length * self.binning * self.pixel_size) \ + for prop in props_per_frame + ] for i,props_per_frame in enumerate(props) + ] + diams = reduce(lambda a,b: a+b, diams_[:]) + eccs_ = [ + [ + (i, prop.eccentricity) for prop in props_per_frame + ] for i,props_per_frame in enumerate(props) + ] + eccs = reduce(lambda a,b: a+b, eccs_) + + self._count.set_data(*zip(*enumerate(num_regions_per_frame))) + self._diams.set_data(*zip(*diams)) + self._eccs.set_data(*zip(*eccs)) + [a.set_xlim(len(num_regions_per_frame)) for a in self.ax] + self.ax[0].set_ylim(min(num_regions_per_frame), max(num_regions_per_frame)) + self.ax[1].set_ylim(min(r := [d[1] for d in diams]), max(r)) + self.ax[2].set_ylim(0,1) + + self.canvas.draw_idle() + + def save_params(self): data = { "binning": self.binning, @@ -277,10 +357,13 @@ class ExampleQWidget(q.QWidget): data = yaml.safe_load(f) show_info(f"restoring parameters from {new_name}") - except (UnboundLocalError, UnicodeDecodeError): - with open(ppp := ".latest.params.yaml") as f: - data = yaml.safe_load(f) - show_info(f"restoring parameters from {ppp}") + except (UnboundLocalError, UnicodeDecodeError, FileNotFoundError): + try: + with open(ppp := ".latest.params.yaml") as f: + data = yaml.safe_load(f) + show_info(f"restoring parameters from {ppp}") + except FileNotFoundError: + return print(data) try: self.binning_widget.value = data["binning"] @@ -437,7 +520,7 @@ def filter_labels(labels, min_diam=50, max_diam=150, max_ecc=0.2): # print(f'good_labels {good_labels}') mask = np.sum([data == v for v in good_labels], axis=0) # print(mask.shape) - return (label(mask)[0].astype("uint16")).reshape(labels.shape) + return label(mask)[0].astype("uint16").reshape(labels.shape) def get_gradient(bf_data: np.ndarray, smooth=10, bin=1):