diff --git a/src/napari_segment/_widget.py b/src/napari_segment/_widget.py index 09c6f131533fc9d2b8d85330674388e6d98d412b..155a2532f84d94f4646a2024821d57356c2b6e12 100644 --- a/src/napari_segment/_widget.py +++ b/src/napari_segment/_widget.py @@ -9,6 +9,7 @@ Replace code below according to your needs. from enum import Enum import os from functools import partial +from tempfile import TemporaryFile import dask import magicgui.widgets as w @@ -28,11 +29,14 @@ from scipy.ndimage import ( from functools import reduce from skimage.measure import regionprops -from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas +from matplotlib.backends.backend_qt5agg import (FigureCanvasQTAgg as FigureCanvas) +from matplotlib.backends.backend_qt5agg import NavigationToolbar2QT as NavigationToolbar from matplotlib.figure import Figure +import matplotlib from threading import Thread import logging from importlib.metadata import PackageNotFoundError, version +import pandas as pd try: __version__ = version("napari-segment") @@ -41,9 +45,9 @@ except PackageNotFoundError: __version__ = "Unknown" logging.basicConfig( - level=logging.DEBUG, format="%(asctime)s %(levelname)s : %(message)s" + level=logging.WARNING, format="%(asctime)s %(levelname)s : %(message)s" ) -logger = logging.getLogger("napari_segment.widget") +logger = logging.getLogger("napari_segment._widget") formatter = logging.Formatter( "%(asctime)s - %(name)s - %(levelname)s : %(message)s" @@ -117,20 +121,26 @@ class ExampleQWidget(q.QWidget): value=.9 ) - self.btn = q.QPushButton("Save!") + self.btn = q.QPushButton("Save params!") + self.btn_save_csv = q.QPushButton("Save csv!") + + self.check_auto_plot = w.CheckBox(label="Update Plots") + self.check_auto_plot.changed.connect(self.plot_stats) self.canvas = FigureCanvas(Figure(figsize=(5,5))) - self.ax = self.canvas.figure.subplots(nrows=3, sharex=True) + self.ax = self.canvas.figure.subplots(nrows=3, sharex=False) self.ax[0].set_title("Number of detections") self.ax[1].set_title(diams_title := "Diameters") self.diams_title = diams_title self.ax[2].set_title("Eccentricities") + self.canvas.figure.tight_layout() - self._count, = self.ax[0].plot(range(10), [0]*10) - self._diams, = self.ax[1].plot(range(10), [0]*10) - self._eccs, = self.ax[2].plot(range(10), [0]*10) + self._count, = self.ax[0].plot(range(10), [0]*10, "o", picker=True, pickradius=5) + self._diams, = self.ax[1].plot(range(10), [0]*10, "o", picker=True, pickradius=5) + self._eccs, = self.ax[2].plot(range(10), [0]*10, "o", picker=True, pickradius=5) self.container = w.Container( + label="Container", widgets=[ self.input, w.Label(label="Prepocessing"), @@ -144,6 +154,7 @@ class ExampleQWidget(q.QWidget): self.min_diam, self.max_diam, self.max_ecc, + self.check_auto_plot ] ) @@ -151,27 +162,46 @@ class ExampleQWidget(q.QWidget): self.layout().addWidget(self.container.native) self.layout().addWidget(self.btn) self.layout().addWidget(self.canvas) + self.layout().addWidget(NavigationToolbar(self.canvas, self)) + self.layout().addWidget(self.btn_save_csv) + self.layout().addStretch() self.viewer.layers.events.inserted.connect(self.reset_choices) self.viewer.layers.events.removed.connect(self.reset_choices) self.input.changed.connect(self.restore_params) self.input.changed.connect(self.preprocess) - self.binning_widget.changed.connect(self.preprocess) - self.thr.changed.connect(self.threshold) - self.erode.changed.connect(self.threshold) - self.use.changed.connect(self.preprocess) - self.smooth.changed.connect(self.preprocess) - self.min_diam.changed.connect(self.update_out) - self.max_diam.changed.connect(self.update_out) - self.max_ecc.changed.connect(self.update_out) - self.btn.clicked.connect(self.save_params) + + self.canvas.mpl_connect('button_press_event', self.onfigclick) + self.canvas.mpl_connect('pick_event', self.on_pick) + logger.debug("Initialization finished.") - if (choice := self.input.current_choice): + if self.input.current_choice: self.restore_params() self.preprocess() + def move_step(self, slice): + cur_ = self.viewer.dims.current_step + new_ = list(cur_) + new_[0] = slice + self.viewer.dims.current_step = tuple(new_) + + + def onfigclick(self, event): + print('%s click: button=%d, x=%d, y=%d, xdata=%f, ydata=%f' % + ('double' if event.dblclick else 'single', event.button, + event.x, event.y, event.xdata, event.ydata)) + self.move_step(np.round(event.xdata, 0).astype(int)) + + def on_pick(self, event): + line = event.artist + xdata, ydata = line.get_data() + ind = event.ind + print(f'on pick line: {np.array([xdata[ind], ydata[ind]]).T}') + + self.move_step(xdata[ind]) + def _invert(self, data2D): return 1 - norm01(gaussian_filter(data2D, self.smooth.value)) @@ -185,7 +215,17 @@ class ExampleQWidget(q.QWidget): ) def preprocess(self): + if not self.input.current_choice: + return logger.debug(f"start preprocessing {self.input.current_choice}") + + try: + self.path = self.viewer.layers[self.input.current_choice].metadata[ + "path" + ] + except KeyError: + self.path = "" + self.binning = self.binning_widget.value try: self.data = self.viewer.layers[self.input.current_choice].data[ @@ -197,7 +237,8 @@ class ExampleQWidget(q.QWidget): return try: - self.pixel_size = self.viewer.layers[self.input.current_choice].metadata["pixel_size_um"] + self.pixel_size = self.viewer.layers[self.input.current_choice]\ + .metadata["pixel_size_um"] self.unit = "um" except KeyError: @@ -258,18 +299,21 @@ class ExampleQWidget(q.QWidget): **{"name": name, "scale": self.scale}, ) else: - logger.debug(f'Updating Preprocessing layer with {self.smooth_gradient}') + logger.debug(f'Updating Preprocessing layer with \ + {self.smooth_gradient}') self.viewer.layers[name].data = self.smooth_gradient self.viewer.layers[name].scale = self.scale - logger.debug(f'Preprocessing finished, sending {self.smooth_gradient} to thresholding') + logger.debug(f'Preprocessing finished, \ + sending {self.smooth_gradient} to thresholding') self.threshold() def threshold(self): logger.debug('Start thresholding step') if not self.input.current_choice: return - logger.debug(f'Thresholding with the thr={self.thr.value} and erode={self.erode.value}') + logger.debug(f'Thresholding with the thr={self.thr.value} \ + and erode={self.erode.value}') self.labels = self.smooth_gradient.map_blocks( partial( threshold_gradient, thr=self.thr.value, erode=self.erode.value @@ -283,13 +327,15 @@ class ExampleQWidget(q.QWidget): opacity=0.3, **{"name": name, "scale": self.scale}, ) + self.viewer.layers[name].contour = 8 // self.binning + else: logger.debug(f'Updating Detections layer with {self.labels}') self.viewer.layers[name].data = self.labels self.viewer.layers[name].scale = self.scale - self.viewer.layers[name].contour = 8 / self.binning - logger.debug(f'Thresholding succesful. Sending labels {self.labels} to filtering') + logger.debug(f'Thresholding succesful. \ + Sending labels {self.labels} to filtering') self.update_out() def update_out(self): @@ -298,8 +344,9 @@ class ExampleQWidget(q.QWidget): return logger.debug(f'Start filtering') try: - logger.debug(f"""Filtering labels by ({self.min_diam.value / self.binning} - < size[px] < {self.max_diam.value / self.binning}) + logger.debug(f"""Filtering labels by \ + ({self.min_diam.value / self.binning} \ + < size[px] < {self.max_diam.value / self.binning}) \ and eccentricity > {self.max_ecc.value}""") selected_labels = dask.array.map_blocks( @@ -310,6 +357,8 @@ class ExampleQWidget(q.QWidget): self.max_ecc.value, dtype=np.int32, ) + self.selected_labels = selected_labels + if not (name := "selected labels") in self.viewer.layers: logger.debug(f'No selected labels layer found, adding one') @@ -320,7 +369,7 @@ class ExampleQWidget(q.QWidget): **{"name": name, "scale": self.scale}, ) else: - logger.debug(f'Updating selected labels layer with {selected_labels}') + logger.debug(f'Updating labels layer with {selected_labels}') self.viewer.layers[name].scale = self.scale self.viewer.layers[name].data = selected_labels @@ -331,40 +380,75 @@ class ExampleQWidget(q.QWidget): show_error(f"Relax filter! {e}") try: - - self.plot_stats(selected_labels) + self.plot_stats() except Exception as e: show_error(f"Plot failed: {e}") - - def plot_stats(self, data): - props = [regionprops(label_image=img) for img in data.compute()] + + def plot_stats(self, ignore_check_auto_plot=False): + if not self.check_auto_plot.value: + if not ignore_check_auto_plot: + return + props = [regionprops(label_image=img) for img in self.selected_labels.compute()] num_regions_per_frame = [len(p) for p in props] - diams_ = [ + area_ = [ [ - (i, prop.major_axis_length * self.binning * self.pixel_size) \ + (i, prop.area * (self.binning * self.pixel_size)**2) \ for prop in props_per_frame ] for i,props_per_frame in enumerate(props) ] - diams = reduce(lambda a,b: a+b, diams_[:]) + area = reduce(lambda a,b: a+b, area_[:]) + + self.props_df = \ + pd.DataFrame( + data=area, + columns=["frame", (area_col := f"area [{self.unit}$^2$]")] + ) + self.props_df.loc[:, (diam_col := f"diameter [{self.unit}]")] = \ + np.sqrt(self.props_df[area_col]) * 2 / np.pi + eccs_ = [ [ - (i, prop.eccentricity) for prop in props_per_frame + (prop.eccentricity) for prop in props_per_frame ] for i,props_per_frame in enumerate(props) ] eccs = reduce(lambda a,b: a+b, eccs_) + self.props_df.loc[:, (ecc_col := "eccentricity")] = eccs + + data=self.props_df self._count.set_data(*zip(*enumerate(num_regions_per_frame))) - self._diams.set_data(*zip(*diams)) - self._eccs.set_data(*zip(*eccs)) - [a.set_xlim(len(num_regions_per_frame)) for a in self.ax] - self.ax[0].set_ylim(min(num_regions_per_frame), max(num_regions_per_frame)) - self.ax[1].set_ylim(min(r := [d[1] for d in diams]), max(r)) + self._diams.set_data(data["frame"], data[diam_col]) + self._eccs.set_data(data["frame"], data[ecc_col]) + [a.set_xlim(0, len(num_regions_per_frame)) for a in self.ax] + self.ax[0].set_ylim(0, max(num_regions_per_frame) + 1) + self.ax[1].set_ylim( + self.props_df[diam_col].min(), + self.props_df[diam_col].max() + ) self.ax[2].set_ylim(0,1) self.canvas.draw_idle() + def save_csv(self): + self.plot_stats(ignore_check_auto_plot=True) + try: + len(self.props_df) + except Exception as e: + logger.error(f"Unable to find the data to save: {e}") + show_error(f"Unable to find the data to save: {e}") + return + path = self.path + ".table.csv" + try: + self.props_df.to_csv(path) + logger.info(f"Saved csv {path}") + show_info(f"Saved csv {path}") + except Exception as e: + logger.error(f"Error saving csv: {e}") + show_error(f"Error saving csv: {e}") + + def save_params(self): data = { @@ -378,27 +462,33 @@ class ExampleQWidget(q.QWidget): "max_ecc": self.max_ecc.value, } try: - path = self.viewer.layers[self.input.current_choice].metadata[ - "path" - ] - dir = os.path.dirname(path) - filename = os.path.basename(path) + + dir = os.path.dirname(self.path) + filename = os.path.basename(self.path) new_name = filename + ".params.yaml" with open(os.path.join(dir, new_name), "w") as f: yaml.safe_dump(data, f) show_info(f"Parameters saves into {new_name}") - except KeyError: - show_error("Saving parameters failed") + logger.info(f"Parameters saves into {new_name}") + except Exception as e: + show_error(f"Saving parameters into {new_name} failed: {e}") + logger.error(f"Saving parameters into {new_name} failed: {e}") with open((ff := ".latest.params.yaml"), "w") as f: yaml.safe_dump(data, f) - show_info(f"Parameters saves into {ff}") + logger.info(f"Parameters saves into {ff}") def restore_params(self): + logger.debug(f'Start restoring parameters') try: - path = self.viewer.layers[self.input.current_choice].metadata[ + self.path = self.viewer.layers[self.input.current_choice].metadata[ "path" ] + except KeyError: + self.path = "" + + try: + path = self.path dir = os.path.dirname(path) filename = os.path.basename(path) new_name = filename + ".params.yaml" @@ -408,6 +498,7 @@ class ExampleQWidget(q.QWidget): with open(ppp := os.path.join(dir, new_name)) as f: data = yaml.safe_load(f) show_info(f"restoring parameters from {new_name}") + logger.info(f"restoring parameters from {new_name}") except (UnboundLocalError, UnicodeDecodeError, FileNotFoundError): try: @@ -416,7 +507,7 @@ class ExampleQWidget(q.QWidget): show_info(f"restoring parameters from {ppp}") except FileNotFoundError: return - print(data) + logger.debug(f'Loaded parameters: {data}') try: self.binning_widget.value = data["binning"] self.use.value = data["use"] @@ -428,16 +519,32 @@ class ExampleQWidget(q.QWidget): self.max_ecc.value = data["max_ecc"] except Exception as e: show_error(f"Restore settings failed, {e}") + logger.error(f"Restore settings failed, {e}") + + self.binning_widget.changed.connect(self.preprocess) + self.thr.changed.connect(self.threshold) + self.erode.changed.connect(self.threshold) + self.use.changed.connect(self.preprocess) + self.smooth.changed.connect(self.preprocess) + self.min_diam.changed.connect(self.update_out) + self.max_diam.changed.connect(self.update_out) + self.max_ecc.changed.connect(self.update_out) + self.btn.clicked.connect(self.save_params) + self.btn_save_csv.clicked.connect(self.save_csv) + def reset_choices(self, event=None): - self.input.reset_choices(event) - self.input.choices = [ + logger.debug(f"New layer added. Reset choices. Event: {event}") + new_layers = [ layer.name for layer in self.viewer.layers if isinstance(layer, Image) and layer.name != "Preprocessing" ] - # self.restore_params() - + if self.input.choices != new_layers: + self.input.choices = new_layers + logger.debug(f'Updating layer list with {new_layers}') + else: + logger.debug('No new data layers, probably added pipeline layers triggered this reset') def norm01(data): d = data.copy()