diff --git a/src/napari_segment/_reader.py b/src/napari_segment/_reader.py index a7d7d7cac232b7c7d15be9fba8ca64a4ad7b287c..7278e96fa40101eebc58f6e08c168d5d5b6a0176 100644 --- a/src/napari_segment/_reader.py +++ b/src/napari_segment/_reader.py @@ -11,6 +11,7 @@ import os import dask import nd2 import numpy as np +import tifffile as tf def napari_get_reader(path): @@ -40,6 +41,9 @@ def napari_get_reader(path): if path.endswith(".zarr"): return read_zarr + if path.endswith(".tif"): + return read_tif + # otherwise we return the *function* that can read ``path``. if path.endswith(".npy"): return reader_function @@ -47,6 +51,24 @@ def napari_get_reader(path): return None +def read_tif(path): + data = tf.TiffFile(path) + arr = data.asarray() + channel_axis = ( + arr.shape.index(data.imagej_metadata["channels"]) + if data.is_imagej + else None + ) + + return [ + ( + arr, + {"channel_axis": channel_axis, "metadata": {"path": path}}, + "image", + ) + ] + + def read_zarr(path): print(f"read_zarr {path}") diff --git a/src/napari_segment/_widget.py b/src/napari_segment/_widget.py index 5370fb734b250e04b7ba8fcd289edf7fede53814..d1c3374e340c3d94603c33a36cc3c92675b4b62a 100644 --- a/src/napari_segment/_widget.py +++ b/src/napari_segment/_widget.py @@ -17,7 +17,7 @@ import qtpy.QtWidgets as q import yaml from magicgui import magic_factory from napari.layers import Image -from napari.utils.notifications import show_error +from napari.utils.notifications import show_error, show_info from scipy.ndimage import ( binary_erosion, binary_fill_holes, @@ -38,7 +38,7 @@ class ExampleQWidget(q.QWidget): super().__init__() self.viewer = napari_viewer - self.data = w.ComboBox( + self.input = w.ComboBox( label="BF data", annotation=Image, choices=[ @@ -47,8 +47,8 @@ class ExampleQWidget(q.QWidget): if isinstance(layer, Image) ], ) - self.data.changed.connect(self.restore_params) - self.data.changed.connect(self.preprocess) + self.input.changed.connect(self.restore_params) + self.input.changed.connect(self.preprocess) self.binning_widget = w.RadioButtons( label="binning", @@ -66,7 +66,7 @@ class ExampleQWidget(q.QWidget): self.use = w.RadioButtons( label="Use", - choices=["Intensity", "Gradient"], + choices=["Intensity", "Gradient", "GDif"], value="Intensity", orientation="horizontal", ) @@ -76,7 +76,9 @@ class ExampleQWidget(q.QWidget): self.smooth.changed.connect(self.preprocess) self.min_diam = w.Slider( - label="Min_diameter", min=50, max=500, step=50 + label="Min_diameter", + min=1, + max=500, ) self.min_diam.changed.connect(self.update_out) @@ -97,7 +99,7 @@ class ExampleQWidget(q.QWidget): self.container = w.Container( widgets=[ - self.data, + self.input, w.Label(label="Prepocessing"), self.binning_widget, self.use, @@ -120,7 +122,7 @@ class ExampleQWidget(q.QWidget): self.viewer.layers.events.inserted.connect(self.reset_choices) self.viewer.layers.events.removed.connect(self.reset_choices) - if self.data.current_choice: + if self.input.current_choice: print("start") self.restore_params() self.preprocess() @@ -128,31 +130,51 @@ class ExampleQWidget(q.QWidget): def preprocess(self): self.binning = self.binning_widget.value try: - self.ddata = self.viewer.layers[self.data.current_choice].data[ + self.data = self.viewer.layers[self.input.current_choice].data[ ..., :: self.binning, :: self.binning ] except KeyError: show_error("No data to process") return - self.scale = np.ones((len(self.ddata.shape),)) + self.scale = np.ones((len(self.data.shape),)) self.scale[-2:] = self.binning - if isinstance(self.ddata, np.ndarray): - chunksize = np.ones(self.ddata.ndims) - chunksize[-2:] = self.ddata.shape[-2:] # xy full size - self.ddata = dask.array.from_array(self.ddata, chunksize=chunksize) - self.smooth_gradient = ( - self.ddata.map_blocks( + if isinstance(self.data, np.ndarray): + chunksize = np.ones(len(self.data.shape)) + chunksize[-2:] = self.data.shape[-2:] # xy full size + self.ddata = dask.array.from_array( + self.data.astype("f"), chunks=chunksize + ) + else: + self.ddata = self.data.astype("f") + + show_info(self.use.value) + if self.use.value == "Gradient": + self.smooth_gradient = self.ddata.map_blocks( partial(get_gradient, smooth=self.smooth.value), dtype=self.ddata.dtype, ) - if self.use.value == "Gradient" - else self.ddata.map_blocks( - lambda d: 1 - - (a := (s := gaussian_filter(d, self.smooth.value)) - s.min()) - / a.max(), + elif self.use.value == "Intensity": + self.smooth_gradient = self.ddata.map_blocks( + lambda d: 1 - norm01(gaussian_filter(d, self.smooth.value)), dtype=self.ddata.dtype, ) - ) + elif self.use.value == "GDif": + self.smooth_gradient = self.ddata.map_blocks( + lambda d: ( + gaussian_filter(d, self.smooth.value) + - gaussian_filter(d, self.smooth.value + 2) + ), + dtype=self.ddata.dtype, + ) + else: + self.smooth_gradient = np.zeros_like(self.ddata) + raise ( + ValueError( + f"""Filter `{self.use.value}` not understood! + Expected `Gradient` or `Intensity` or `GDif`""" + ) + ) + if not (name := "Preprocessing") in self.viewer.layers: self.viewer.add_image( data=self.smooth_gradient, @@ -164,13 +186,13 @@ class ExampleQWidget(q.QWidget): self.threshold() def threshold(self): - if not self.data.current_choice: + if not self.input.current_choice: return self.labels = self.smooth_gradient.map_blocks( partial( threshold_gradient, thr=self.thr.value, erode=self.erode.value ), - dtype=self.ddata.dtype, + dtype=np.int32, ) if not (name := "Detections") in self.viewer.layers: self.viewer.add_labels( @@ -186,7 +208,7 @@ class ExampleQWidget(q.QWidget): def update_out(self): - if not self.data.current_choice: + if not self.input.current_choice: return try: @@ -196,7 +218,7 @@ class ExampleQWidget(q.QWidget): self.min_diam.value / self.binning, self.max_diam.value / self.binning, self.max_ecc.value, - dtype=self.ddata.dtype, + dtype=np.int32, ) if not (name := "selected labels") in self.viewer.layers: @@ -209,8 +231,8 @@ class ExampleQWidget(q.QWidget): self.viewer.layers[name].scale = self.scale self.viewer.layers[name].data = selected_labels # self.save_params() - except TypeError: - show_error("Relax filter!") + except TypeError as e: + show_error(f"Relax filter! {e}") def save_params(self): data = { @@ -224,37 +246,41 @@ class ExampleQWidget(q.QWidget): "max_ecc": self.max_ecc.value, } try: - path = self.viewer.layers[self.data.current_choice].metadata[ + path = self.viewer.layers[self.input.current_choice].metadata[ "path" ] dir = os.path.dirname(path) filename = os.path.basename(path) - new_name = filename.replace(".nd2", ".params.yaml") + new_name = filename + ".params.yaml" with open(os.path.join(dir, new_name), "w") as f: yaml.safe_dump(data, f) + show_info(f"Parameters saves into {new_name}") except KeyError: - pass - with open(os.path.join(".latest.params.yaml"), "w") as f: + show_error("Saving parameters failed") + with open((ff := ".latest.params.yaml"), "w") as f: yaml.safe_dump(data, f) + show_info(f"Parameters saves into {ff}") def restore_params(self): try: - path = self.viewer.layers[self.data.current_choice].metadata[ + path = self.viewer.layers[self.input.current_choice].metadata[ "path" ] dir = os.path.dirname(path) filename = os.path.basename(path) - new_name = filename.replace(".nd2", ".params.yaml") + new_name = filename + ".params.yaml" except KeyError: pass try: with open(ppp := os.path.join(dir, new_name)) as f: data = yaml.safe_load(f) - except OSError: - with open(ppp := os.path.join(".latest.params.yaml")) as f: + show_info(f"restoring parameters from {new_name}") + + except (UnboundLocalError, UnicodeDecodeError): + with open(ppp := ".latest.params.yaml") as f: data = yaml.safe_load(f) - print(f"restoring parameters from {ppp}") + show_info(f"restoring parameters from {ppp}") print(data) try: self.binning_widget.value = data["binning"] @@ -269,13 +295,13 @@ class ExampleQWidget(q.QWidget): show_error(f"Restore settings failed, {e}") def reset_choices(self, event=None): - self.data.reset_choices(event) - self.data.choices = [ + self.input.reset_choices(event) + self.input.choices = [ layer.name for layer in self.viewer.layers if isinstance(layer, Image) and layer.name != "Preprocessing" ] - self.restore_params() + # self.restore_params() def norm01(data): diff --git a/src/napari_segment/napari.yaml b/src/napari_segment/napari.yaml index 1c9eb364e22bd9e5e142c85b4f483d40c0fde4bd..e98f5bf1bc839adb59b0f0524e271909df6bdf2e 100644 --- a/src/napari_segment/napari.yaml +++ b/src/napari_segment/napari.yaml @@ -26,7 +26,7 @@ contributions: readers: - command: napari-segment.get_reader accepts_directories: True - filename_patterns: ['*.npy', '*.nd2', '*.zarr'] + filename_patterns: ['*.npy', '*.nd2', '*.zarr', '*.tif'] writers: - command: napari-segment.write_multiple layer_types: ['image*','labels*']