Skip to content
Snippets Groups Projects
Commit 480c350e authored by Andrey Aristov's avatar Andrey Aristov
Browse files

interactive plots

parent ce92485f
No related branches found
No related tags found
No related merge requests found
......@@ -9,6 +9,7 @@ Replace code below according to your needs.
from enum import Enum
import os
from functools import partial
from tempfile import TemporaryFile
import dask
import magicgui.widgets as w
......@@ -28,11 +29,14 @@ from scipy.ndimage import (
from functools import reduce
from skimage.measure import regionprops
from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas
from matplotlib.backends.backend_qt5agg import (FigureCanvasQTAgg as FigureCanvas)
from matplotlib.backends.backend_qt5agg import NavigationToolbar2QT as NavigationToolbar
from matplotlib.figure import Figure
import matplotlib
from threading import Thread
import logging
from importlib.metadata import PackageNotFoundError, version
import pandas as pd
try:
__version__ = version("napari-segment")
......@@ -41,9 +45,9 @@ except PackageNotFoundError:
__version__ = "Unknown"
logging.basicConfig(
level=logging.DEBUG, format="%(asctime)s %(levelname)s : %(message)s"
level=logging.WARNING, format="%(asctime)s %(levelname)s : %(message)s"
)
logger = logging.getLogger("napari_segment.widget")
logger = logging.getLogger("napari_segment._widget")
formatter = logging.Formatter(
"%(asctime)s - %(name)s - %(levelname)s : %(message)s"
......@@ -117,20 +121,26 @@ class ExampleQWidget(q.QWidget):
value=.9
)
self.btn = q.QPushButton("Save!")
self.btn = q.QPushButton("Save params!")
self.btn_save_csv = q.QPushButton("Save csv!")
self.check_auto_plot = w.CheckBox(label="Update Plots")
self.check_auto_plot.changed.connect(self.plot_stats)
self.canvas = FigureCanvas(Figure(figsize=(5,5)))
self.ax = self.canvas.figure.subplots(nrows=3, sharex=True)
self.ax = self.canvas.figure.subplots(nrows=3, sharex=False)
self.ax[0].set_title("Number of detections")
self.ax[1].set_title(diams_title := "Diameters")
self.diams_title = diams_title
self.ax[2].set_title("Eccentricities")
self.canvas.figure.tight_layout()
self._count, = self.ax[0].plot(range(10), [0]*10)
self._diams, = self.ax[1].plot(range(10), [0]*10)
self._eccs, = self.ax[2].plot(range(10), [0]*10)
self._count, = self.ax[0].plot(range(10), [0]*10, "o", picker=True, pickradius=5)
self._diams, = self.ax[1].plot(range(10), [0]*10, "o", picker=True, pickradius=5)
self._eccs, = self.ax[2].plot(range(10), [0]*10, "o", picker=True, pickradius=5)
self.container = w.Container(
label="Container",
widgets=[
self.input,
w.Label(label="Prepocessing"),
......@@ -144,6 +154,7 @@ class ExampleQWidget(q.QWidget):
self.min_diam,
self.max_diam,
self.max_ecc,
self.check_auto_plot
]
)
......@@ -151,27 +162,46 @@ class ExampleQWidget(q.QWidget):
self.layout().addWidget(self.container.native)
self.layout().addWidget(self.btn)
self.layout().addWidget(self.canvas)
self.layout().addWidget(NavigationToolbar(self.canvas, self))
self.layout().addWidget(self.btn_save_csv)
self.layout().addStretch()
self.viewer.layers.events.inserted.connect(self.reset_choices)
self.viewer.layers.events.removed.connect(self.reset_choices)
self.input.changed.connect(self.restore_params)
self.input.changed.connect(self.preprocess)
self.binning_widget.changed.connect(self.preprocess)
self.thr.changed.connect(self.threshold)
self.erode.changed.connect(self.threshold)
self.use.changed.connect(self.preprocess)
self.smooth.changed.connect(self.preprocess)
self.min_diam.changed.connect(self.update_out)
self.max_diam.changed.connect(self.update_out)
self.max_ecc.changed.connect(self.update_out)
self.btn.clicked.connect(self.save_params)
self.canvas.mpl_connect('button_press_event', self.onfigclick)
self.canvas.mpl_connect('pick_event', self.on_pick)
logger.debug("Initialization finished.")
if (choice := self.input.current_choice):
if self.input.current_choice:
self.restore_params()
self.preprocess()
def move_step(self, slice):
cur_ = self.viewer.dims.current_step
new_ = list(cur_)
new_[0] = slice
self.viewer.dims.current_step = tuple(new_)
def onfigclick(self, event):
print('%s click: button=%d, x=%d, y=%d, xdata=%f, ydata=%f' %
('double' if event.dblclick else 'single', event.button,
event.x, event.y, event.xdata, event.ydata))
self.move_step(np.round(event.xdata, 0).astype(int))
def on_pick(self, event):
line = event.artist
xdata, ydata = line.get_data()
ind = event.ind
print(f'on pick line: {np.array([xdata[ind], ydata[ind]]).T}')
self.move_step(xdata[ind])
def _invert(self, data2D):
return 1 - norm01(gaussian_filter(data2D, self.smooth.value))
......@@ -185,7 +215,17 @@ class ExampleQWidget(q.QWidget):
)
def preprocess(self):
if not self.input.current_choice:
return
logger.debug(f"start preprocessing {self.input.current_choice}")
try:
self.path = self.viewer.layers[self.input.current_choice].metadata[
"path"
]
except KeyError:
self.path = ""
self.binning = self.binning_widget.value
try:
self.data = self.viewer.layers[self.input.current_choice].data[
......@@ -197,7 +237,8 @@ class ExampleQWidget(q.QWidget):
return
try:
self.pixel_size = self.viewer.layers[self.input.current_choice].metadata["pixel_size_um"]
self.pixel_size = self.viewer.layers[self.input.current_choice]\
.metadata["pixel_size_um"]
self.unit = "um"
except KeyError:
......@@ -258,18 +299,21 @@ class ExampleQWidget(q.QWidget):
**{"name": name, "scale": self.scale},
)
else:
logger.debug(f'Updating Preprocessing layer with {self.smooth_gradient}')
logger.debug(f'Updating Preprocessing layer with \
{self.smooth_gradient}')
self.viewer.layers[name].data = self.smooth_gradient
self.viewer.layers[name].scale = self.scale
logger.debug(f'Preprocessing finished, sending {self.smooth_gradient} to thresholding')
logger.debug(f'Preprocessing finished, \
sending {self.smooth_gradient} to thresholding')
self.threshold()
def threshold(self):
logger.debug('Start thresholding step')
if not self.input.current_choice:
return
logger.debug(f'Thresholding with the thr={self.thr.value} and erode={self.erode.value}')
logger.debug(f'Thresholding with the thr={self.thr.value} \
and erode={self.erode.value}')
self.labels = self.smooth_gradient.map_blocks(
partial(
threshold_gradient, thr=self.thr.value, erode=self.erode.value
......@@ -283,13 +327,15 @@ class ExampleQWidget(q.QWidget):
opacity=0.3,
**{"name": name, "scale": self.scale},
)
self.viewer.layers[name].contour = 8 // self.binning
else:
logger.debug(f'Updating Detections layer with {self.labels}')
self.viewer.layers[name].data = self.labels
self.viewer.layers[name].scale = self.scale
self.viewer.layers[name].contour = 8 / self.binning
logger.debug(f'Thresholding succesful. Sending labels {self.labels} to filtering')
logger.debug(f'Thresholding succesful. \
Sending labels {self.labels} to filtering')
self.update_out()
def update_out(self):
......@@ -298,8 +344,9 @@ class ExampleQWidget(q.QWidget):
return
logger.debug(f'Start filtering')
try:
logger.debug(f"""Filtering labels by ({self.min_diam.value / self.binning}
< size[px] < {self.max_diam.value / self.binning})
logger.debug(f"""Filtering labels by \
({self.min_diam.value / self.binning} \
< size[px] < {self.max_diam.value / self.binning}) \
and eccentricity > {self.max_ecc.value}""")
selected_labels = dask.array.map_blocks(
......@@ -310,6 +357,8 @@ class ExampleQWidget(q.QWidget):
self.max_ecc.value,
dtype=np.int32,
)
self.selected_labels = selected_labels
if not (name := "selected labels") in self.viewer.layers:
logger.debug(f'No selected labels layer found, adding one')
......@@ -320,7 +369,7 @@ class ExampleQWidget(q.QWidget):
**{"name": name, "scale": self.scale},
)
else:
logger.debug(f'Updating selected labels layer with {selected_labels}')
logger.debug(f'Updating labels layer with {selected_labels}')
self.viewer.layers[name].scale = self.scale
self.viewer.layers[name].data = selected_labels
......@@ -331,40 +380,75 @@ class ExampleQWidget(q.QWidget):
show_error(f"Relax filter! {e}")
try:
self.plot_stats(selected_labels)
self.plot_stats()
except Exception as e:
show_error(f"Plot failed: {e}")
def plot_stats(self, data):
props = [regionprops(label_image=img) for img in data.compute()]
def plot_stats(self, ignore_check_auto_plot=False):
if not self.check_auto_plot.value:
if not ignore_check_auto_plot:
return
props = [regionprops(label_image=img) for img in self.selected_labels.compute()]
num_regions_per_frame = [len(p) for p in props]
diams_ = [
area_ = [
[
(i, prop.major_axis_length * self.binning * self.pixel_size) \
(i, prop.area * (self.binning * self.pixel_size)**2) \
for prop in props_per_frame
] for i,props_per_frame in enumerate(props)
]
diams = reduce(lambda a,b: a+b, diams_[:])
area = reduce(lambda a,b: a+b, area_[:])
self.props_df = \
pd.DataFrame(
data=area,
columns=["frame", (area_col := f"area [{self.unit}$^2$]")]
)
self.props_df.loc[:, (diam_col := f"diameter [{self.unit}]")] = \
np.sqrt(self.props_df[area_col]) * 2 / np.pi
eccs_ = [
[
(i, prop.eccentricity) for prop in props_per_frame
(prop.eccentricity) for prop in props_per_frame
] for i,props_per_frame in enumerate(props)
]
eccs = reduce(lambda a,b: a+b, eccs_)
self.props_df.loc[:, (ecc_col := "eccentricity")] = eccs
data=self.props_df
self._count.set_data(*zip(*enumerate(num_regions_per_frame)))
self._diams.set_data(*zip(*diams))
self._eccs.set_data(*zip(*eccs))
[a.set_xlim(len(num_regions_per_frame)) for a in self.ax]
self.ax[0].set_ylim(min(num_regions_per_frame), max(num_regions_per_frame))
self.ax[1].set_ylim(min(r := [d[1] for d in diams]), max(r))
self._diams.set_data(data["frame"], data[diam_col])
self._eccs.set_data(data["frame"], data[ecc_col])
[a.set_xlim(0, len(num_regions_per_frame)) for a in self.ax]
self.ax[0].set_ylim(0, max(num_regions_per_frame) + 1)
self.ax[1].set_ylim(
self.props_df[diam_col].min(),
self.props_df[diam_col].max()
)
self.ax[2].set_ylim(0,1)
self.canvas.draw_idle()
def save_csv(self):
self.plot_stats(ignore_check_auto_plot=True)
try:
len(self.props_df)
except Exception as e:
logger.error(f"Unable to find the data to save: {e}")
show_error(f"Unable to find the data to save: {e}")
return
path = self.path + ".table.csv"
try:
self.props_df.to_csv(path)
logger.info(f"Saved csv {path}")
show_info(f"Saved csv {path}")
except Exception as e:
logger.error(f"Error saving csv: {e}")
show_error(f"Error saving csv: {e}")
def save_params(self):
data = {
......@@ -378,27 +462,33 @@ class ExampleQWidget(q.QWidget):
"max_ecc": self.max_ecc.value,
}
try:
path = self.viewer.layers[self.input.current_choice].metadata[
"path"
]
dir = os.path.dirname(path)
filename = os.path.basename(path)
dir = os.path.dirname(self.path)
filename = os.path.basename(self.path)
new_name = filename + ".params.yaml"
with open(os.path.join(dir, new_name), "w") as f:
yaml.safe_dump(data, f)
show_info(f"Parameters saves into {new_name}")
except KeyError:
show_error("Saving parameters failed")
logger.info(f"Parameters saves into {new_name}")
except Exception as e:
show_error(f"Saving parameters into {new_name} failed: {e}")
logger.error(f"Saving parameters into {new_name} failed: {e}")
with open((ff := ".latest.params.yaml"), "w") as f:
yaml.safe_dump(data, f)
show_info(f"Parameters saves into {ff}")
logger.info(f"Parameters saves into {ff}")
def restore_params(self):
logger.debug(f'Start restoring parameters')
try:
path = self.viewer.layers[self.input.current_choice].metadata[
self.path = self.viewer.layers[self.input.current_choice].metadata[
"path"
]
except KeyError:
self.path = ""
try:
path = self.path
dir = os.path.dirname(path)
filename = os.path.basename(path)
new_name = filename + ".params.yaml"
......@@ -408,6 +498,7 @@ class ExampleQWidget(q.QWidget):
with open(ppp := os.path.join(dir, new_name)) as f:
data = yaml.safe_load(f)
show_info(f"restoring parameters from {new_name}")
logger.info(f"restoring parameters from {new_name}")
except (UnboundLocalError, UnicodeDecodeError, FileNotFoundError):
try:
......@@ -416,7 +507,7 @@ class ExampleQWidget(q.QWidget):
show_info(f"restoring parameters from {ppp}")
except FileNotFoundError:
return
print(data)
logger.debug(f'Loaded parameters: {data}')
try:
self.binning_widget.value = data["binning"]
self.use.value = data["use"]
......@@ -428,16 +519,32 @@ class ExampleQWidget(q.QWidget):
self.max_ecc.value = data["max_ecc"]
except Exception as e:
show_error(f"Restore settings failed, {e}")
logger.error(f"Restore settings failed, {e}")
self.binning_widget.changed.connect(self.preprocess)
self.thr.changed.connect(self.threshold)
self.erode.changed.connect(self.threshold)
self.use.changed.connect(self.preprocess)
self.smooth.changed.connect(self.preprocess)
self.min_diam.changed.connect(self.update_out)
self.max_diam.changed.connect(self.update_out)
self.max_ecc.changed.connect(self.update_out)
self.btn.clicked.connect(self.save_params)
self.btn_save_csv.clicked.connect(self.save_csv)
def reset_choices(self, event=None):
self.input.reset_choices(event)
self.input.choices = [
logger.debug(f"New layer added. Reset choices. Event: {event}")
new_layers = [
layer.name
for layer in self.viewer.layers
if isinstance(layer, Image) and layer.name != "Preprocessing"
]
# self.restore_params()
if self.input.choices != new_layers:
self.input.choices = new_layers
logger.debug(f'Updating layer list with {new_layers}')
else:
logger.debug('No new data layers, probably added pipeline layers triggered this reset')
def norm01(data):
d = data.copy()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment