diff --git a/src/napari_segment/_widget.py b/src/napari_segment/_widget.py index d70195064e324fabe81d4f8e84b5d771be67a0a9..aeb08ba61242665aea3c09944d50102b78e66a9a 100644 --- a/src/napari_segment/_widget.py +++ b/src/napari_segment/_widget.py @@ -6,18 +6,27 @@ see: https://napari.org/plugins/stable/guides.html#widgets Replace code below according to your needs. """ -from enum import Enum +import logging import os -from functools import partial -from tempfile import TemporaryFile +from enum import Enum +from functools import partial, reduce +from importlib.metadata import PackageNotFoundError, version import dask import magicgui.widgets as w import napari import numpy as np +import pandas as pd import qtpy.QtWidgets as q import yaml from magicgui import magic_factory +from matplotlib.backends.backend_qt5agg import ( + FigureCanvasQTAgg as FigureCanvas, +) +from matplotlib.backends.backend_qt5agg import ( + NavigationToolbar2QT as NavigationToolbar, +) +from matplotlib.figure import Figure from napari.layers import Image from napari.utils.notifications import show_error, show_info from scipy.ndimage import ( @@ -26,17 +35,8 @@ from scipy.ndimage import ( gaussian_filter, label, ) -from functools import reduce from skimage.measure import regionprops - -from matplotlib.backends.backend_qt5agg import (FigureCanvasQTAgg as FigureCanvas) -from matplotlib.backends.backend_qt5agg import NavigationToolbar2QT as NavigationToolbar -from matplotlib.figure import Figure -import matplotlib -from threading import Thread -import logging -from importlib.metadata import PackageNotFoundError, version -import pandas as pd +from skimage.segmentation import clear_border try: __version__ = version("napari-segment") @@ -64,10 +64,9 @@ class SegmentStack(q.QWidget): # 2. use a type annotation of 'napari.viewer.Viewer' for any parameter class Choices(Enum): - INT="Invert" - GRAD="Gradient" - GDIF="Gauss diff" - + INT = "Invert" + GRAD = "Gradient" + GDIF = "Gauss diff" def __init__(self, napari_viewer): super().__init__() @@ -90,7 +89,9 @@ class SegmentStack(q.QWidget): orientation="horizontal", ) - self.thr = w.FloatSlider(label="Threshold", min=0.1, max=0.9, value=.4) + self.thr = w.FloatSlider( + label="Threshold", min=0.1, max=0.9, value=0.4 + ) self.erode = w.SpinBox(label="erode", min=0, max=10, value=0) @@ -99,7 +100,7 @@ class SegmentStack(q.QWidget): choices=[v.value for v in self.Choices], value=self.Choices.INT.value, orientation="horizontal", - allow_multiple=True + allow_multiple=True, ) self.smooth = w.SpinBox(label="smooth", min=0, max=10, value=2) @@ -115,19 +116,30 @@ class SegmentStack(q.QWidget): ) self.max_ecc = w.FloatSlider( - label="Max eccentricity", - min=0.0, - max=1.0, - value=.9 + label="Max eccentricity", min=0.0, max=1.0, value=0.9 ) - self.btn = q.QPushButton("Save params!") + self.btn_save_params = q.QPushButton("Save params!") self.btn_save_csv = q.QPushButton("Save csv!") self.check_auto_plot = w.CheckBox(label="Update Plots") self.check_auto_plot.changed.connect(self.plot_stats) - self.canvas = FigureCanvas(Figure(figsize=(5,5))) + self.pixel_size = 1 + self.pixel_unit = "px" + self.pixel_size_widget = w.LineEdit( + label="Pixel size", + value=self.pixel_size, + ) # bind=self.set_pixel_size) + self.pixel_unit_widget = w.LineEdit( + label="unit", value=self.pixel_unit + ) # , bind=self.set_pixel_unit) + self.pixel_block = w.Container( + widgets=[self.pixel_size_widget, self.pixel_unit_widget], + layout="horizontal", + ) + + self.canvas = FigureCanvas(Figure(figsize=(5, 5))) self.ax = self.canvas.figure.subplots(nrows=3, sharex=False) self.ax[0].set_title("Number of detections") self.ax[1].set_title(diams_title := "Diameters") @@ -135,9 +147,15 @@ class SegmentStack(q.QWidget): self.ax[2].set_title("Eccentricities") self.canvas.figure.tight_layout() - self._count, = self.ax[0].plot(range(10), [0]*10, "o", picker=True, pickradius=5) - self._diams, = self.ax[1].plot(range(10), [0]*10, "o", picker=True, pickradius=5) - self._eccs, = self.ax[2].plot(range(10), [0]*10, "o", picker=True, pickradius=5) + (self._count,) = self.ax[0].plot( + range(10), [0] * 10, "o", picker=True, pickradius=5 + ) + (self._diams,) = self.ax[1].plot( + range(10), [0] * 10, "o", picker=True, pickradius=5 + ) + (self._eccs,) = self.ax[2].plot( + range(10), [0] * 10, "o", picker=True, pickradius=5 + ) self.container = w.Container( label="Container", @@ -154,13 +172,14 @@ class SegmentStack(q.QWidget): self.min_diam, self.max_diam, self.max_ecc, - self.check_auto_plot - ] + ], ) self.setLayout(q.QVBoxLayout()) self.layout().addWidget(self.container.native) - self.layout().addWidget(self.btn) + self.layout().addWidget(self.btn_save_params) + self.layout().addWidget(self.check_auto_plot.native) + self.layout().addWidget(self.pixel_block.native) self.layout().addWidget(self.canvas) self.layout().addWidget(NavigationToolbar(self.canvas, self)) self.layout().addWidget(self.btn_save_csv) @@ -172,8 +191,8 @@ class SegmentStack(q.QWidget): self.input.changed.connect(self.restore_params) self.input.changed.connect(self.preprocess) - self.canvas.mpl_connect('button_press_event', self.onfigclick) - self.canvas.mpl_connect('pick_event', self.on_pick) + self.canvas.mpl_connect("button_press_event", self.onfigclick) + self.canvas.mpl_connect("pick_event", self.on_pick) logger.debug("Initialization finished.") @@ -187,18 +206,25 @@ class SegmentStack(q.QWidget): new_[0] = slice self.viewer.dims.current_step = tuple(new_) - def onfigclick(self, event): - print('%s click: button=%d, x=%d, y=%d, xdata=%f, ydata=%f' % - ('double' if event.dblclick else 'single', event.button, - event.x, event.y, event.xdata, event.ydata)) + print( + "%s click: button=%d, x=%d, y=%d, xdata=%f, ydata=%f" + % ( + "double" if event.dblclick else "single", + event.button, + event.x, + event.y, + event.xdata, + event.ydata, + ) + ) self.move_step(np.round(event.xdata, 0).astype(int)) def on_pick(self, event): line = event.artist xdata, ydata = line.get_data() ind = event.ind - print(f'on pick line: {np.array([xdata[ind], ydata[ind]]).T}') + print(f"on pick line: {np.array([xdata[ind], ydata[ind]]).T}") self.move_step(xdata[ind]) @@ -209,16 +235,23 @@ class SegmentStack(q.QWidget): return get_gradient(data2D, smooth=self.smooth.value) def _gdif(self, data2D): - return ( - gaussian_filter(data2D, self.smooth.value) - - gaussian_filter(data2D, self.smooth.value + 2) + return gaussian_filter(data2D, self.smooth.value) - gaussian_filter( + data2D, self.smooth.value + 2 ) + def set_pixel_size(self, value): + self.pixel_size = value + self.pixel_size_widget.value = value + + def set_pixel_unit(self, unit): + self.pixel_unit = unit + self.pixel_unit_widget.value = unit + def preprocess(self): if not self.input.current_choice: return logger.debug(f"start preprocessing {self.input.current_choice}") - + try: self.path = self.viewer.layers[self.input.current_choice].metadata[ "path" @@ -235,23 +268,24 @@ class SegmentStack(q.QWidget): except KeyError: return - + try: - self.pixel_size = self.viewer.layers[self.input.current_choice]\ - .metadata["pixel_size_um"] - self.unit = "um" + pixel_size = self.viewer.layers[ + self.input.current_choice + ].metadata["pixel_size_um"] + unit = "um" + self.set_pixel_size(pixel_size) + self.set_pixel_unit(unit) except KeyError: self.pixel_size = 1 - self.unit = "px" - - logger.debug(f"pixel size: {self.pixel_size} {self.unit}") + self.pixel_unit = "px" + + logger.debug(f"pixel size: {self.pixel_size} {self.pixel_unit}") - self.ax[1].set_title(f"{self.diams_title}, {self.unit}") - self.scale = np.ones((len(self.data.shape),)) self.scale[-2:] = self.binning - logger.debug(f'Computed scale for napari {self.scale}') + logger.debug(f"Computed scale for napari {self.scale}") if isinstance(self.data, np.ndarray): chunksize = np.ones(len(self.data.shape)) @@ -261,10 +295,10 @@ class SegmentStack(q.QWidget): ) else: self.ddata = self.data.astype("f") - - logger.debug(f'Dask array: {self.ddata}') - logger.debug(f'Processing data with {self.use.value}') + logger.debug(f"Dask array: {self.ddata}") + + logger.debug(f"Processing data with {self.use.value}") # show_info(self.use.value) if self.use.value == self.Choices.GRAD.value: self.smooth_gradient = self.ddata.map_blocks( @@ -283,8 +317,10 @@ class SegmentStack(q.QWidget): ) else: self.smooth_gradient = np.zeros_like(self.ddata) - logger.error(f"""Filter `{self.use.value}` not understood! - Expected {[v.value for v in self.Choices]}""") + logger.error( + f"""Filter `{self.use.value}` not understood! + Expected {[v.value for v in self.Choices]}""" + ) raise ( ValueError( f"""Filter `{self.use.value}` not understood! @@ -293,27 +329,33 @@ class SegmentStack(q.QWidget): ) if not (name := "Preprocessing") in self.viewer.layers: - logger.debug(f'No Preprocessing layer found, adding one') + logger.debug("No Preprocessing layer found, adding one") self.viewer.add_image( data=self.smooth_gradient, **{"name": name, "scale": self.scale}, ) else: - logger.debug(f'Updating Preprocessing layer with \ - {self.smooth_gradient}') + logger.debug( + f"Updating Preprocessing layer with \ + {self.smooth_gradient}" + ) self.viewer.layers[name].data = self.smooth_gradient self.viewer.layers[name].scale = self.scale - - logger.debug(f'Preprocessing finished, \ - sending {self.smooth_gradient} to thresholding') + + logger.debug( + f"Preprocessing finished, \ + sending {self.smooth_gradient} to thresholding" + ) self.threshold() def threshold(self): - logger.debug('Start thresholding step') + logger.debug("Start thresholding step") if not self.input.current_choice: return - logger.debug(f'Thresholding with the thr={self.thr.value} \ - and erode={self.erode.value}') + logger.debug( + f"Thresholding with the thr={self.thr.value} \ + and erode={self.erode.value}" + ) self.labels = self.smooth_gradient.map_blocks( partial( threshold_gradient, thr=self.thr.value, erode=self.erode.value @@ -321,7 +363,7 @@ class SegmentStack(q.QWidget): dtype=np.int32, ) if not (name := "Detections") in self.viewer.layers: - logger.debug(f'No Detections layer found, adding one') + logger.debug("No Detections layer found, adding one") self.viewer.add_labels( data=self.labels, opacity=0.3, @@ -329,27 +371,30 @@ class SegmentStack(q.QWidget): ) else: - logger.debug(f'Updating Detections layer with {self.labels}') + logger.debug(f"Updating Detections layer with {self.labels}") self.viewer.layers[name].data = self.labels self.viewer.layers[name].scale = self.scale - + self.viewer.layers[name].contour = 8 // self.binning - - - logger.debug(f'Thresholding succesful. \ - Sending labels {self.labels} to filtering') + + logger.debug( + f"Thresholding succesful. \ + Sending labels {self.labels} to filtering" + ) self.update_out() def update_out(self): if not self.input.current_choice: return - logger.debug(f'Start filtering') + logger.debug("Start filtering") try: - logger.debug(f"""Filtering labels by \ + logger.debug( + f"""Filtering labels by \ ({self.min_diam.value / self.binning} \ < size[px] < {self.max_diam.value / self.binning}) \ - and eccentricity > {self.max_ecc.value}""") + and eccentricity > {self.max_ecc.value}""" + ) selected_labels = dask.array.map_blocks( filter_labels, @@ -361,9 +406,8 @@ class SegmentStack(q.QWidget): ) self.selected_labels = selected_labels - if not (name := "selected labels") in self.viewer.layers: - logger.debug(f'No selected labels layer found, adding one') + logger.debug("No selected labels layer found, adding one") self.viewer.add_labels( data=selected_labels, @@ -371,13 +415,12 @@ class SegmentStack(q.QWidget): **{"name": name, "scale": self.scale}, ) else: - logger.debug(f'Updating labels layer with {selected_labels}') + logger.debug(f"Updating labels layer with {selected_labels}") self.viewer.layers[name].scale = self.scale self.viewer.layers[name].data = selected_labels # self.save_params() - except TypeError as e: show_error(f"Relax filter! {e}") @@ -386,50 +429,54 @@ class SegmentStack(q.QWidget): except Exception as e: show_error(f"Plot failed: {e}") - def plot_stats(self, ignore_check_auto_plot=False): if not self.check_auto_plot.value: if not ignore_check_auto_plot: return - props = [regionprops(label_image=img) for img in self.selected_labels.compute()] - num_regions_per_frame = [len(p) for p in props] + self.pixel_size = float(self.pixel_size_widget.value) + self.pixel_unit = str(self.pixel_unit_widget.value) + self.ax[1].set_title(f"{self.diams_title}, {self.pixel_unit}") + + props = [ + regionprops(label_image=img) + for img in self.selected_labels.compute() + ] + num_regions_per_frame = [len(p) for p in props] area_ = [ [ - (i, prop.area * (self.binning * self.pixel_size)**2) \ + (i, prop.area * (self.binning * self.pixel_size) ** 2) for prop in props_per_frame - ] for i,props_per_frame in enumerate(props) - ] - area = reduce(lambda a,b: a+b, area_[:]) - - self.props_df = \ - pd.DataFrame( - data=area, - columns=["frame", (area_col := f"area [{self.unit}$^2$]")] - ) - self.props_df.loc[:, (diam_col := f"diameter [{self.unit}]")] = \ + ] + for i, props_per_frame in enumerate(props) + ] + area = reduce(lambda a, b: a + b, area_[:]) + + self.props_df = pd.DataFrame( + data=area, + columns=["frame", (area_col := f"area [{self.pixel_unit}$^2$]")], + ) + self.props_df.loc[:, (diam_col := f"diameter [{self.pixel_unit}]")] = ( np.sqrt(self.props_df[area_col]) * 2 / np.pi + ) eccs_ = [ - [ - (prop.eccentricity) for prop in props_per_frame - ] for i,props_per_frame in enumerate(props) - ] - eccs = reduce(lambda a,b: a+b, eccs_) + [(prop.eccentricity) for prop in props_per_frame] + for i, props_per_frame in enumerate(props) + ] + eccs = reduce(lambda a, b: a + b, eccs_) self.props_df.loc[:, (ecc_col := "eccentricity")] = eccs - - data=self.props_df + data = self.props_df self._count.set_data(*zip(*enumerate(num_regions_per_frame))) self._diams.set_data(data["frame"], data[diam_col]) self._eccs.set_data(data["frame"], data[ecc_col]) [a.set_xlim(0, len(num_regions_per_frame)) for a in self.ax] self.ax[0].set_ylim(0, max(num_regions_per_frame) + 1) self.ax[1].set_ylim( - self.props_df[diam_col].min(), - self.props_df[diam_col].max() + self.props_df[diam_col].min(), self.props_df[diam_col].max() ) - self.ax[2].set_ylim(0,1) + self.ax[2].set_ylim(0, 1) self.canvas.draw_idle() @@ -442,6 +489,8 @@ class SegmentStack(q.QWidget): show_error(f"Unable to find the data to save: {e}") return path = self.path + ".table.csv" + self.props_df.loc[:, "filename"] = os.path.basename(self.path) + self.props_df.loc[:, "layer"] = self.input.current_choice try: self.props_df.to_csv(path) logger.info(f"Saved csv {path}") @@ -450,8 +499,6 @@ class SegmentStack(q.QWidget): logger.error(f"Error saving csv: {e}") show_error(f"Error saving csv: {e}") - - def save_params(self): data = { "binning": self.binning, @@ -462,9 +509,11 @@ class SegmentStack(q.QWidget): "min_diameter": self.min_diam.value, "max_diameter": self.max_diam.value, "max_ecc": self.max_ecc.value, + "pixel_size": self.pixel_size, + "pixel_unit": self.pixel_unit, } try: - + dir = os.path.dirname(self.path) filename = os.path.basename(self.path) new_name = filename + ".params.yaml" @@ -481,7 +530,7 @@ class SegmentStack(q.QWidget): logger.info(f"Parameters saves into {ff}") def restore_params(self): - logger.debug(f'Start restoring parameters') + logger.debug("Start restoring parameters") try: self.path = self.viewer.layers[self.input.current_choice].metadata[ "path" @@ -509,7 +558,7 @@ class SegmentStack(q.QWidget): show_info(f"restoring parameters from {ppp}") except FileNotFoundError: return - logger.debug(f'Loaded parameters: {data}') + logger.debug(f"Loaded parameters: {data}") try: self.binning_widget.value = data["binning"] self.use.value = data["use"] @@ -519,10 +568,12 @@ class SegmentStack(q.QWidget): self.min_diam.value = data["min_diameter"] self.max_diam.value = data["max_diameter"] self.max_ecc.value = data["max_ecc"] + self.pixel_size = data["pixel_size"] + self.pixel_unit = data["pixel_unit"] except Exception as e: show_error(f"Restore settings failed, {e}") logger.error(f"Restore settings failed, {e}") - + self.binning_widget.changed.connect(self.preprocess) self.thr.changed.connect(self.threshold) self.erode.changed.connect(self.threshold) @@ -531,10 +582,9 @@ class SegmentStack(q.QWidget): self.min_diam.changed.connect(self.update_out) self.max_diam.changed.connect(self.update_out) self.max_ecc.changed.connect(self.update_out) - self.btn.clicked.connect(self.save_params) + self.btn_save_params.clicked.connect(self.save_params) self.btn_save_csv.clicked.connect(self.save_csv) - def reset_choices(self, event=None): logger.debug(f"New layer added. Reset choices. Event: {event}") new_layers = [ @@ -544,12 +594,18 @@ class SegmentStack(q.QWidget): ] if self.input.choices != new_layers: self.input.choices = new_layers - logger.debug(f'Updating layer list with {new_layers}') + logger.debug(f"Updating layer list with {new_layers}") else: - logger.debug('No new data layers, probably added pipeline layers triggered this reset') + logger.debug( + "No new data layers, probably added pipeline \ + layers triggered this reset" + ) + def norm01(data): d = data.copy() + if d.max() == 0: + return d return (a := d - d.min()) / a.max() @@ -707,6 +763,8 @@ def threshold_gradient( data = strip_dimensions(smoothed_gradient) regions = data > thr * data.max() + regions = clear_border(regions) + if fill: regions = binary_fill_holes(regions)