diff --git a/src/napari_segment/_widget.py b/src/napari_segment/_widget.py index aeb08ba61242665aea3c09944d50102b78e66a9a..d7267cbb6ba2b212ef04bf643cadf619abc235e9 100644 --- a/src/napari_segment/_widget.py +++ b/src/napari_segment/_widget.py @@ -27,7 +27,7 @@ from matplotlib.backends.backend_qt5agg import ( NavigationToolbar2QT as NavigationToolbar, ) from matplotlib.figure import Figure -from napari.layers import Image +from napari.layers import Image, Labels from napari.utils.notifications import show_error, show_info from scipy.ndimage import ( binary_erosion, @@ -75,11 +75,24 @@ class SegmentStack(q.QWidget): self.input = w.ComboBox( label="BF data", annotation=Image, - choices=[ - layer.name - for layer in self.viewer.layers - if isinstance(layer, Image) + choices=self.update_images(), + ) + + self.stat_layer_selector = w.ComboBox( + label="Labels to quantify", + annotation=Labels, + choices=self.update_labels(), + ) + + self.btn_make_manual_labels = w.PushButton(text="Clone for manual correction") + self.btn_make_manual_labels.clicked.connect(self.make_manual_layer) + + self.stat_layer_selector_container = w.Container( + widgets=[ + self.stat_layer_selector, + self.btn_make_manual_labels ], + layout="horizontal" ) self.binning_widget = w.RadioButtons( @@ -122,8 +135,24 @@ class SegmentStack(q.QWidget): self.btn_save_params = q.QPushButton("Save params!") self.btn_save_csv = q.QPushButton("Save csv!") - self.check_auto_plot = w.CheckBox(label="Update Plots") - self.check_auto_plot.changed.connect(self.plot_stats) + self.check_auto_plot = w.CheckBox(label="Auto Update") + self.check_auto_plot.changed.connect( + partial( + self.plot_stats, force=True + ) + ) + + self.btn_update_stats = w.PushButton(text="Update plots") + self.btn_update_stats.clicked.connect(self.plot_stats) + + self.update_plot_container = w.Container( + widgets=[ + self.btn_update_stats, + self.check_auto_plot + ], + layout='horizontal' + ) + self.pixel_size = 1 self.pixel_unit = "px" @@ -178,7 +207,8 @@ class SegmentStack(q.QWidget): self.setLayout(q.QVBoxLayout()) self.layout().addWidget(self.container.native) self.layout().addWidget(self.btn_save_params) - self.layout().addWidget(self.check_auto_plot.native) + self.layout().addWidget(self.update_plot_container.native) + self.layout().addWidget(self.stat_layer_selector_container.native) self.layout().addWidget(self.pixel_block.native) self.layout().addWidget(self.canvas) self.layout().addWidget(NavigationToolbar(self.canvas, self)) @@ -192,7 +222,7 @@ class SegmentStack(q.QWidget): self.input.changed.connect(self.preprocess) self.canvas.mpl_connect("button_press_event", self.onfigclick) - self.canvas.mpl_connect("pick_event", self.on_pick) + # self.canvas.mpl_connect("pick_event", self.on_pick) logger.debug("Initialization finished.") @@ -200,6 +230,40 @@ class SegmentStack(q.QWidget): self.restore_params() self.preprocess() + def update_labels(self): + return [ + layer.name + for layer in self.viewer.layers + if isinstance(layer, Labels) + ] + + def update_images(self): + return [ + layer.name + for layer in self.viewer.layers + if isinstance(layer, Image) + ] + + def make_manual_layer(self): + try: + clone = self.selected_labels.compute() + layer = self.viewer.add_labels( + data=clone, + name=(manual_layer_name := "Manual Labels"), + scale=self.scale, + metadata={ + "binning": self.binning, + "source": self.viewer.layers["selected labels"] + } + ) + self.stat_layer_selector.choices = self.update_labels()[::-1] + + self.plot_stats(force=True) + + except Exception as e: + show_error(err := f"Unable to create manual layer: {e}") + logger.error(err) + def move_step(self, slice): cur_ = self.viewer.dims.current_step new_ = list(cur_) @@ -330,7 +394,7 @@ class SegmentStack(q.QWidget): if not (name := "Preprocessing") in self.viewer.layers: logger.debug("No Preprocessing layer found, adding one") - self.viewer.add_image( + self.layer_with_preprocessing = self.viewer.add_image( data=self.smooth_gradient, **{"name": name, "scale": self.scale}, ) @@ -339,8 +403,8 @@ class SegmentStack(q.QWidget): f"Updating Preprocessing layer with \ {self.smooth_gradient}" ) - self.viewer.layers[name].data = self.smooth_gradient - self.viewer.layers[name].scale = self.scale + self.layer_with_preprocessing.data = self.smooth_gradient + self.layer_with_preprocessing.scale = self.scale logger.debug( f"Preprocessing finished, \ @@ -364,7 +428,7 @@ class SegmentStack(q.QWidget): ) if not (name := "Detections") in self.viewer.layers: logger.debug("No Detections layer found, adding one") - self.viewer.add_labels( + self.layer_with_detections = self.viewer.add_labels( data=self.labels, opacity=0.3, **{"name": name, "scale": self.scale}, @@ -372,10 +436,10 @@ class SegmentStack(q.QWidget): else: logger.debug(f"Updating Detections layer with {self.labels}") - self.viewer.layers[name].data = self.labels - self.viewer.layers[name].scale = self.scale + self.layer_with_detections.data = self.labels + self.layer_with_detections.scale = self.scale - self.viewer.layers[name].contour = 8 // self.binning + self.layer_with_detections.contour = 8 // self.binning logger.debug( f"Thresholding succesful. \ @@ -409,7 +473,7 @@ class SegmentStack(q.QWidget): if not (name := "selected labels") in self.viewer.layers: logger.debug("No selected labels layer found, adding one") - self.viewer.add_labels( + self.layer_with_selected_labels = self.viewer.add_labels( data=selected_labels, opacity=0.5, **{"name": name, "scale": self.scale}, @@ -417,30 +481,42 @@ class SegmentStack(q.QWidget): else: logger.debug(f"Updating labels layer with {selected_labels}") - self.viewer.layers[name].scale = self.scale - self.viewer.layers[name].data = selected_labels - # self.save_params() + self.layer_with_selected_labels.scale = self.scale + self.layer_with_selected_labels.data = selected_labels except TypeError as e: show_error(f"Relax filter! {e}") + self.stat_layer_selector.choices = self.update_labels()[::-1] + # self.stat_layer_selector.current_choice = "selected labels" + try: self.plot_stats() except Exception as e: show_error(f"Plot failed: {e}") - def plot_stats(self, ignore_check_auto_plot=False): + def plot_stats(self, force=False): + if not self.stat_layer_selector.current_choice: + return + if not self.check_auto_plot.value: - if not ignore_check_auto_plot: + if not force: return self.pixel_size = float(self.pixel_size_widget.value) self.pixel_unit = str(self.pixel_unit_widget.value) self.ax[1].set_title(f"{self.diams_title}, {self.pixel_unit}") + data = self.viewer.layers[ + self.stat_layer_selector.current_choice + ].data + + if isinstance(data, dask.array.Array): + data = data.compute() + props = [ regionprops(label_image=img) - for img in self.selected_labels.compute() + for img in data ] num_regions_per_frame = [len(p) for p in props] area_ = [ @@ -481,7 +557,7 @@ class SegmentStack(q.QWidget): self.canvas.draw_idle() def save_csv(self): - self.plot_stats(ignore_check_auto_plot=True) + self.plot_stats(force=True) try: len(self.props_df) except Exception as e: