From 669085df6f1aae90d889a43276de6c5083f7c890 Mon Sep 17 00:00:00 2001
From: Andrey Aristov <aaristov@pasteur.fr>
Date: Wed, 28 Sep 2022 17:13:41 +0200
Subject: [PATCH] improvements
---
src/napari_segment/_reader.py | 22 +++++++
src/napari_segment/_widget.py | 106 ++++++++++++++++++++-------------
src/napari_segment/napari.yaml | 2 +-
3 files changed, 89 insertions(+), 41 deletions(-)
diff --git a/src/napari_segment/_reader.py b/src/napari_segment/_reader.py
index a7d7d7c..7278e96 100644
--- a/src/napari_segment/_reader.py
+++ b/src/napari_segment/_reader.py
@@ -11,6 +11,7 @@ import os
import dask
import nd2
import numpy as np
+import tifffile as tf
def napari_get_reader(path):
@@ -40,6 +41,9 @@ def napari_get_reader(path):
if path.endswith(".zarr"):
return read_zarr
+ if path.endswith(".tif"):
+ return read_tif
+
# otherwise we return the *function* that can read ``path``.
if path.endswith(".npy"):
return reader_function
@@ -47,6 +51,24 @@ def napari_get_reader(path):
return None
+def read_tif(path):
+ data = tf.TiffFile(path)
+ arr = data.asarray()
+ channel_axis = (
+ arr.shape.index(data.imagej_metadata["channels"])
+ if data.is_imagej
+ else None
+ )
+
+ return [
+ (
+ arr,
+ {"channel_axis": channel_axis, "metadata": {"path": path}},
+ "image",
+ )
+ ]
+
+
def read_zarr(path):
print(f"read_zarr {path}")
diff --git a/src/napari_segment/_widget.py b/src/napari_segment/_widget.py
index 5370fb7..d1c3374 100644
--- a/src/napari_segment/_widget.py
+++ b/src/napari_segment/_widget.py
@@ -17,7 +17,7 @@ import qtpy.QtWidgets as q
import yaml
from magicgui import magic_factory
from napari.layers import Image
-from napari.utils.notifications import show_error
+from napari.utils.notifications import show_error, show_info
from scipy.ndimage import (
binary_erosion,
binary_fill_holes,
@@ -38,7 +38,7 @@ class ExampleQWidget(q.QWidget):
super().__init__()
self.viewer = napari_viewer
- self.data = w.ComboBox(
+ self.input = w.ComboBox(
label="BF data",
annotation=Image,
choices=[
@@ -47,8 +47,8 @@ class ExampleQWidget(q.QWidget):
if isinstance(layer, Image)
],
)
- self.data.changed.connect(self.restore_params)
- self.data.changed.connect(self.preprocess)
+ self.input.changed.connect(self.restore_params)
+ self.input.changed.connect(self.preprocess)
self.binning_widget = w.RadioButtons(
label="binning",
@@ -66,7 +66,7 @@ class ExampleQWidget(q.QWidget):
self.use = w.RadioButtons(
label="Use",
- choices=["Intensity", "Gradient"],
+ choices=["Intensity", "Gradient", "GDif"],
value="Intensity",
orientation="horizontal",
)
@@ -76,7 +76,9 @@ class ExampleQWidget(q.QWidget):
self.smooth.changed.connect(self.preprocess)
self.min_diam = w.Slider(
- label="Min_diameter", min=50, max=500, step=50
+ label="Min_diameter",
+ min=1,
+ max=500,
)
self.min_diam.changed.connect(self.update_out)
@@ -97,7 +99,7 @@ class ExampleQWidget(q.QWidget):
self.container = w.Container(
widgets=[
- self.data,
+ self.input,
w.Label(label="Prepocessing"),
self.binning_widget,
self.use,
@@ -120,7 +122,7 @@ class ExampleQWidget(q.QWidget):
self.viewer.layers.events.inserted.connect(self.reset_choices)
self.viewer.layers.events.removed.connect(self.reset_choices)
- if self.data.current_choice:
+ if self.input.current_choice:
print("start")
self.restore_params()
self.preprocess()
@@ -128,31 +130,51 @@ class ExampleQWidget(q.QWidget):
def preprocess(self):
self.binning = self.binning_widget.value
try:
- self.ddata = self.viewer.layers[self.data.current_choice].data[
+ self.data = self.viewer.layers[self.input.current_choice].data[
..., :: self.binning, :: self.binning
]
except KeyError:
show_error("No data to process")
return
- self.scale = np.ones((len(self.ddata.shape),))
+ self.scale = np.ones((len(self.data.shape),))
self.scale[-2:] = self.binning
- if isinstance(self.ddata, np.ndarray):
- chunksize = np.ones(self.ddata.ndims)
- chunksize[-2:] = self.ddata.shape[-2:] # xy full size
- self.ddata = dask.array.from_array(self.ddata, chunksize=chunksize)
- self.smooth_gradient = (
- self.ddata.map_blocks(
+ if isinstance(self.data, np.ndarray):
+ chunksize = np.ones(len(self.data.shape))
+ chunksize[-2:] = self.data.shape[-2:] # xy full size
+ self.ddata = dask.array.from_array(
+ self.data.astype("f"), chunks=chunksize
+ )
+ else:
+ self.ddata = self.data.astype("f")
+
+ show_info(self.use.value)
+ if self.use.value == "Gradient":
+ self.smooth_gradient = self.ddata.map_blocks(
partial(get_gradient, smooth=self.smooth.value),
dtype=self.ddata.dtype,
)
- if self.use.value == "Gradient"
- else self.ddata.map_blocks(
- lambda d: 1
- - (a := (s := gaussian_filter(d, self.smooth.value)) - s.min())
- / a.max(),
+ elif self.use.value == "Intensity":
+ self.smooth_gradient = self.ddata.map_blocks(
+ lambda d: 1 - norm01(gaussian_filter(d, self.smooth.value)),
dtype=self.ddata.dtype,
)
- )
+ elif self.use.value == "GDif":
+ self.smooth_gradient = self.ddata.map_blocks(
+ lambda d: (
+ gaussian_filter(d, self.smooth.value)
+ - gaussian_filter(d, self.smooth.value + 2)
+ ),
+ dtype=self.ddata.dtype,
+ )
+ else:
+ self.smooth_gradient = np.zeros_like(self.ddata)
+ raise (
+ ValueError(
+ f"""Filter `{self.use.value}` not understood!
+ Expected `Gradient` or `Intensity` or `GDif`"""
+ )
+ )
+
if not (name := "Preprocessing") in self.viewer.layers:
self.viewer.add_image(
data=self.smooth_gradient,
@@ -164,13 +186,13 @@ class ExampleQWidget(q.QWidget):
self.threshold()
def threshold(self):
- if not self.data.current_choice:
+ if not self.input.current_choice:
return
self.labels = self.smooth_gradient.map_blocks(
partial(
threshold_gradient, thr=self.thr.value, erode=self.erode.value
),
- dtype=self.ddata.dtype,
+ dtype=np.int32,
)
if not (name := "Detections") in self.viewer.layers:
self.viewer.add_labels(
@@ -186,7 +208,7 @@ class ExampleQWidget(q.QWidget):
def update_out(self):
- if not self.data.current_choice:
+ if not self.input.current_choice:
return
try:
@@ -196,7 +218,7 @@ class ExampleQWidget(q.QWidget):
self.min_diam.value / self.binning,
self.max_diam.value / self.binning,
self.max_ecc.value,
- dtype=self.ddata.dtype,
+ dtype=np.int32,
)
if not (name := "selected labels") in self.viewer.layers:
@@ -209,8 +231,8 @@ class ExampleQWidget(q.QWidget):
self.viewer.layers[name].scale = self.scale
self.viewer.layers[name].data = selected_labels
# self.save_params()
- except TypeError:
- show_error("Relax filter!")
+ except TypeError as e:
+ show_error(f"Relax filter! {e}")
def save_params(self):
data = {
@@ -224,37 +246,41 @@ class ExampleQWidget(q.QWidget):
"max_ecc": self.max_ecc.value,
}
try:
- path = self.viewer.layers[self.data.current_choice].metadata[
+ path = self.viewer.layers[self.input.current_choice].metadata[
"path"
]
dir = os.path.dirname(path)
filename = os.path.basename(path)
- new_name = filename.replace(".nd2", ".params.yaml")
+ new_name = filename + ".params.yaml"
with open(os.path.join(dir, new_name), "w") as f:
yaml.safe_dump(data, f)
+ show_info(f"Parameters saves into {new_name}")
except KeyError:
- pass
- with open(os.path.join(".latest.params.yaml"), "w") as f:
+ show_error("Saving parameters failed")
+ with open((ff := ".latest.params.yaml"), "w") as f:
yaml.safe_dump(data, f)
+ show_info(f"Parameters saves into {ff}")
def restore_params(self):
try:
- path = self.viewer.layers[self.data.current_choice].metadata[
+ path = self.viewer.layers[self.input.current_choice].metadata[
"path"
]
dir = os.path.dirname(path)
filename = os.path.basename(path)
- new_name = filename.replace(".nd2", ".params.yaml")
+ new_name = filename + ".params.yaml"
except KeyError:
pass
try:
with open(ppp := os.path.join(dir, new_name)) as f:
data = yaml.safe_load(f)
- except OSError:
- with open(ppp := os.path.join(".latest.params.yaml")) as f:
+ show_info(f"restoring parameters from {new_name}")
+
+ except (UnboundLocalError, UnicodeDecodeError):
+ with open(ppp := ".latest.params.yaml") as f:
data = yaml.safe_load(f)
- print(f"restoring parameters from {ppp}")
+ show_info(f"restoring parameters from {ppp}")
print(data)
try:
self.binning_widget.value = data["binning"]
@@ -269,13 +295,13 @@ class ExampleQWidget(q.QWidget):
show_error(f"Restore settings failed, {e}")
def reset_choices(self, event=None):
- self.data.reset_choices(event)
- self.data.choices = [
+ self.input.reset_choices(event)
+ self.input.choices = [
layer.name
for layer in self.viewer.layers
if isinstance(layer, Image) and layer.name != "Preprocessing"
]
- self.restore_params()
+ # self.restore_params()
def norm01(data):
diff --git a/src/napari_segment/napari.yaml b/src/napari_segment/napari.yaml
index 1c9eb36..e98f5bf 100644
--- a/src/napari_segment/napari.yaml
+++ b/src/napari_segment/napari.yaml
@@ -26,7 +26,7 @@ contributions:
readers:
- command: napari-segment.get_reader
accepts_directories: True
- filename_patterns: ['*.npy', '*.nd2', '*.zarr']
+ filename_patterns: ['*.npy', '*.nd2', '*.zarr', '*.tif']
writers:
- command: napari-segment.write_multiple
layer_types: ['image*','labels*']
--
GitLab