diff --git a/.gitignore b/.gitignore
index 429a119d7be766cf68478aa6e32e48b1c6ee3a7d..173785c2997152e5499314b856bf663767ebd897 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,2 +1,4 @@
 .snakemake
 *.pyc
+__pycache__
+venv
\ No newline at end of file
diff --git a/align.py b/align.py
index 6a53e39b44cc39e391edc1f9bde0849f27da0538..28af8bdace995910fe3a97b10cd24e7a3a8678ad 100644
--- a/align.py
+++ b/align.py
@@ -1,6 +1,8 @@
-from typing import List
 import numpy as np
-from droplet_growth import mic, register, poisson
+import register
+import poisson
+from skimage.measure import regionprops, regionprops_table
+import count
 import tifffile as tf
 import os
 import dask.array as da
@@ -11,189 +13,250 @@ import fire
 from multiprocessing import Pool
 import nd2
 import pandas as pd
-from skimage.measure import regionprops_table
 from scipy.ndimage import laplace, rotate
 import json


 def align_multichip(
-        BF_TRITC_2D_path,
-        out_path,
-        concentrations_path,
-        template_path,
-        labels_path,
-        table_path,
-        fit_poisson,
-        nmax=10):
-    with open(concentrations_path, 'r') as f:
-        concentrations_dct = (yaml.safe_load(f))
-    concentrations = concentrations_dct['concentrations']
-    unit = concentrations_dct['units']
-    print(f'concentrations from {concentrations_path}: {concentrations} [{unit}]')
-
+    BF_TRITC_2D_path,
+    out_path,
+    concentrations_path,
+    template_path,
+    labels_path,
+    table_path,
+    fit_poisson,
+    nmax=10,
+):
+    with open(concentrations_path, "r") as f:
+        concentrations_dct = yaml.safe_load(f)
+    concentrations = concentrations_dct["concentrations"]
+    unit = concentrations_dct["units"]
+    print(
+        f"concentrations from {concentrations_path}: {concentrations} [{unit}]"
+    )
     if "rotation_data_deg" in concentrations_dct:
         rotation_data_deg = int(concentrations_dct["rotation_data_deg"])
-        print(f'Rotation: {rotation_data_deg}')
+        print(f"Rotation: {rotation_data_deg}")
     else:
         rotation_data_deg = 0
-
-    tif_paths = [f'{os.path.dirname(BF_TRITC_2D_path)}/{c}{unit}_aligned.tif' for c in concentrations]
-    print('tif_paths: ', tif_paths)
-
+
+    tif_paths = [
+        f"{os.path.dirname(BF_TRITC_2D_path)}/{c}{unit}_aligned.tif"
+        for c in concentrations
+    ]
+    print("tif_paths: ", tif_paths)
+
     data = read_dask(BF_TRITC_2D_path)
     template16 = tf.imread(template_path)
     big_labels = tf.imread(labels_path)
     fun = partial(
-        align_parallel, 
+        align_parallel,
         template16=template16,
         big_labels=big_labels,
         unit=unit,
         fit_poisson=fit_poisson,
         nmax=nmax,
-        rotation_data_deg=rotation_data_deg )
-    
+        rotation_data_deg=rotation_data_deg,
+    )
+
     try:
-        p=Pool(data.shape[0])
-        out = p.map(fun, zip(data, tif_paths, concentrations)) 
+        p = Pool(data.shape[0])
+        out = p.map(fun, zip(data, tif_paths, concentrations))
     except TypeError as e:
-        print(f'Pool failed due to {e.args}')
+        print(f"Pool failed due to {e.args}")
         out = list(map(fun, zip(data, tif_paths, concentrations)))
     finally:
-        p.close()
+        # `p` is unbound if Pool() itself raised; close only when it exists.
+        if "p" in locals():
+            p.close()

-    
-    aligned = [o['stack'] for o in out]
-    counts = [o['counts'] for o in out]
-    tvecs = [o['tvec'] for o in out]

-    save_tvecs(tvecs, out_path.replace('.zarr', '.transform.json'))
+    aligned = [o["stack"] for o in out]
+    counts = [o["counts"] for o in out]
+    tvecs = [o["tvec"] for o in out]
+
+    save_tvecs(tvecs, out_path.replace(".zarr", ".transform.json"))

-    df = pd.concat(counts, ignore_index=True).sort_values(['[AB]','label'])
+    df = pd.concat(counts, ignore_index=True).sort_values(["[AB]", "label"])
     df.to_csv(table_path)

     daligned = da.from_array(np.array(aligned))
     convert.to_zarr(
-        daligned, 
-        path=out_path, 
-        steps=4, 
-        name=['BF','TRITC','mask'], 
-        colormap=['gray','green','blue'],
-        lut=((1000,30000),(440, 600),(0,501)),
+        daligned,
+        path=out_path,
+        steps=4,
+        name=["BF", "TRITC", "mask"],
+        colormap=["gray", "green", "blue"],
+        lut=((1000, 30000), (440, 600), (0, 501)),
     )
     return out_path
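+
+
+# Usage sketch (hypothetical file names; only the argument order comes from
+# the signature above). align_multichip is normally driven through the fire
+# CLI defined at the bottom of this module:
+#
+#   python align.py chip1/BF_TRITC.nd2 chip1/aligned.zarr \
+#       concentrations.yaml template16.tif labels.tif counts.csv True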

+
 def prep_tvec(transform_dict):
     t = transform_dict.copy()
     t["tvec"] = list(t["tvec"])
     return t

+
 def save_tvecs(tvecs, path):
     try:
         transform_data = list(map(prep_tvec, tvecs))
-        with open(path, 'w') as f:
+        with open(path, "w") as f:
             json.dump(transform_data, fp=f)
     except Exception as e:
-        print('saving transform json failed: ', e.args)
+        print("saving transform json failed: ", e.args)

+
 def align_parallel(args, **kwargs):
     return align2D(*args, **kwargs)
-    
+
+
 def align2D(
-        stack_dask,
-        path_tif,
-        ab,
-        template16=None,
-        big_labels=None,
-        unit='μg_mL',
-        fit_poisson=True,
-        nmax=20,
-        rotation_data_deg=0):
+    stack_dask,
+    path_tif,
+    ab,
+    template16=None,
+    big_labels=None,
+    unit="μg_mL",
+    fit_poisson=True,
+    nmax=20,
+    rotation_data_deg=0,
+):
     print(ab, unit)
     try:
         aligned = tf.imread(path_tif)
-        print(f'already aligned: {path_tif}:{aligned.shape}')
-        counts = count(aligned, ab=ab)
-        intensity_table = get_intensity_table(aligned[2],aligned[1])
-        counts.loc[:, 'intensity'] = intensity_table.intensity
+        print(f"already aligned: {path_tif}:{aligned.shape}")
+        counts = count_cells(aligned, ab=ab)
+        intensity_table = get_intensity_table(aligned[2], aligned[1])
+        counts.loc[:, "intensity"] = intensity_table.intensity
         return {"stack": aligned, "counts": counts}
     except FileNotFoundError:
-        print('Processing...')
-    
+        print("Processing...")
+
     data = stack_dask.compute()
     if rotation_data_deg != 0:
-        data = rotate(input=data, angle=rotation_data_deg, axes=(1,2))
-        print(f'Rotated data {rotation_data_deg} deg')
+        data = rotate(input=data, angle=rotation_data_deg, axes=(1, 2))
+        print(f"Rotated data {rotation_data_deg} deg")

     aligned, tvec = register.align_stack(
-        data, 
-        path_to_save=None, 
-        template16=template16, 
-        mask2=big_labels, 
-        binnings=(2,16,2)
+        data,
+        template16=template16,
+        mask2=big_labels,
+        binnings=(2, 16, 2),
     )
-    counts = count(aligned, ab=ab)
-    intensity_table = get_intensity_table(aligned[2],aligned[1])
-    counts.loc[:, 'intensity'] = intensity_table.intensity
-    
+    counts = count_cells(aligned, ab=ab)
+    intensity_table = get_intensity_table(aligned[2], aligned[1])
+    counts.loc[:, "intensity"] = intensity_table.intensity
+
     if fit_poisson:
-        l = poisson.fit(
-            counts.query(f'n_cells < {nmax}').n_cells,
+        lambda_fit_result = poisson.fit(
+            counts.query(f"n_cells < {nmax}").n_cells,
             title=f"automatic {ab}{unit.replace('_','/')}",
-            save_fig_path=path_tif.replace('.tif', '-counts-hist.png')
+            save_fig_path=path_tif.replace(".tif", "-counts-hist.png"),
         )
-        counts.loc[:, 'poisson fit'] = l
+        counts.loc[:, "poisson fit"] = lambda_fit_result

-    # counts.to_csv((cp := path_tif.replace('.tif', '-counts.csv')), index=None)
-    return {"stack": aligned, "counts": counts, 'tvec': tvec}
+    return {"stack": aligned, "counts": counts, "tvec": tvec}

-def count(aligned:np.ndarray, ab=None):
-    counts = mic.get_cell_numbers(aligned[1], aligned[2], threshold_abs=2, plot=False, bf=aligned[0])
-    counts.loc[:,'[AB]'] = ab
+
+def count_cells(aligned: np.ndarray, ab=None):
+    counts = get_cell_numbers(
+        aligned[1], aligned[2], threshold_abs=2, plot=False, bf=aligned[0]
+    )
+    counts.loc[:, "[AB]"] = ab
     return counts

+
+def get_cell_numbers(
+    multiwell_image: np.ndarray,
+    labels: np.ndarray,
+    plot=False,
+    threshold_abs: float = 2,
+    min_distance: float = 5,
+    meta: dict = {},
+    bf: np.ndarray = None,
+) -> pd.DataFrame:
+    props = regionprops(labels)
+
+    def get_n_peaks(i):
+        if bf is None:
+            return count.get_peak_number(
+                multiwell_image[props[i].slice],
+                plot=plot,
+                dif_gauss_sigma=(3, 5),
+                threshold_abs=threshold_abs,
+                min_distance=min_distance,
+                title=props[i].label,
+            )
+        else:
+            return count.get_peak_number(
+                multiwell_image[props[i].slice],
+                plot=plot,
+                dif_gauss_sigma=(3, 5),
+                threshold_abs=threshold_abs,
+                min_distance=min_distance,
+                title=props[i].label,
+                bf_crop=bf[props[i].slice],
+                return_std=True,
+            )
+
+    # Iterate over the regionprops entries rather than label values, so a
+    # non-contiguous label image cannot index past the end of `props`.
+    n_cells = list(map(get_n_peaks, range(len(props))))
+    return pd.DataFrame(
+        [
+            {
+                "label": prop.label,
+                "x": prop.centroid[0],
+                "y": prop.centroid[1],
+                "n_cells": n_cell[0],
+                # 'std': n_cell[1],
+                **meta,
+            }
+            for prop, n_cell in zip(props, n_cells)
+        ]
+    )
+
+
 def get_intensity_table(
     labelled_mask: np.ndarray,
     intensity_image: np.ndarray,
-    values = ['mean_intensity', ],
-    estimator = np.mean,
-    plot: bool = True
+    values=[
+        "mean_intensity",
+    ],
 ):
-    assert (iis := intensity_image).ndim == 2, (
-        f'expected 2D array for intensity, got shape {iis.shape}'
-    )
-    data = intensity_image.astype('f')
+    assert (
+        iis := intensity_image
+    ).ndim == 2, f"expected 2D array for intensity, got shape {iis.shape}"
+    data = intensity_image.astype("f")
     dict_li = regionprops_table(
-        labelled_mask,
-        intensity_image=data,
-        properties=['label', *values]
+        labelled_mask, intensity_image=data, properties=["label", *values]
     )
     dict_bg = regionprops_table(
         get_outlines(labelled_mask),
         intensity_image=data,
-        properties=['mean_intensity']
+        properties=["mean_intensity"],
     )
-    dict_litb = {**dict_li, 'bg_mean': dict_bg['mean_intensity']}
+    dict_litb = {**dict_li, "bg_mean": dict_bg["mean_intensity"]}
     df = pd.DataFrame.from_dict(dict_litb)
     df.loc[:, "intensity"] = df.mean_intensity - df.bg_mean
     return df

+
 def get_outlines(labels):
-    '''creates 1 px outline around labels'''
+    """Creates a 1 px outline around each label."""
     return labels * (np.abs(laplace(labels)) > 0)


-def read_dask(path:str):
-    if path.endswith('.zarr'):
-        data = da.from_zarr(path+'/0/')
-    elif path.endswith('.nd2'):
+def read_dask(path: str):
+    if path.endswith(".zarr"):
+        data = da.from_zarr(path + "/0/")
+    elif path.endswith(".nd2"):
         data = nd2.ND2File(path).to_dask()
     else:
-        raise ValueError(f'Unexpected file format, expected zarr or nd2')
-    print('data:', data)
+        raise ValueError("Unexpected file format, expected zarr or nd2")
+    print("data:", data)
     return data
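+
+
+# Usage sketch for read_dask (store name hypothetical): the zarr branch
+# assumes a multiscale store whose level-0 array lives under "<path>/0/",
+# which appears to be the layout convert.to_zarr writes above.
+#
+#   data = read_dask("chip1/aligned.zarr")
+#   data.shape  # lazy dask array, e.g. (channels, y, x)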
-    
+
+
 if __name__ == "__main__":
     fire.Fire(align_multichip)
diff --git a/count.py b/count.py
new file mode 100644
index 0000000000000000000000000000000000000000..f48b2ecdd0a7afe5d78d2ea6fb323847423a934d
--- /dev/null
+++ b/count.py
@@ -0,0 +1,297 @@
+from functools import partial
+from warnings import warn  # peak_local_max_labels below calls warn()
+import matplotlib.pyplot as plt
+import numpy as np
+from skimage.feature import peak_local_max
+from skimage.feature import peak
+from scipy import ndimage as ndi
+
+
+def peak_local_max_labels(
+    image,
+    min_distance=1,
+    threshold_abs=None,
+    threshold_rel=None,
+    exclude_border=True,
+    indices=True,
+    num_peaks=np.inf,
+    footprint=None,
+    labels=None,
+    num_peaks_per_label=np.inf,
+    p_norm=np.inf,
+):
+    """Find peaks in an image as coordinate list or boolean mask.
+
+    Peaks are the local maxima in a region of `2 * min_distance + 1`
+    (i.e. peaks are separated by at least `min_distance`).
+    If both `threshold_abs` and `threshold_rel` are provided, the maximum
+    of the two is chosen as the minimum intensity threshold of peaks.
+
+    .. versionchanged:: 0.18
+        Prior to version 0.18, peaks of the same height within a radius of
+        `min_distance` were all returned, but this could cause unexpected
+        behaviour. From 0.18 onwards, an arbitrary peak within the region is
+        returned. See issue gh-2592.
+
+    Parameters
+    ----------
+    image : ndarray
+        Input image.
+    min_distance : int, optional
+        The minimal allowed distance separating peaks. To find the
+        maximum number of peaks, use `min_distance=1`.
+    threshold_abs : float or None, optional
+        Minimum intensity of peaks. By default, the absolute threshold is
+        the minimum intensity of the image.
+    threshold_rel : float or None, optional
+        Minimum intensity of peaks, calculated as
+        ``max(image) * threshold_rel``.
+    exclude_border : int, tuple of ints, or bool, optional
+        If positive integer, `exclude_border` excludes peaks from within
+        `exclude_border`-pixels of the border of the image.
+        If tuple of non-negative ints, the length of the tuple must match the
+        input array's dimensionality. Each element of the tuple will exclude
+        peaks from within `exclude_border`-pixels of the border of the image
+        along that dimension.
+        If True, takes the `min_distance` parameter as value.
+        If zero or False, peaks are identified regardless of their distance
+        from the border.
+    indices : bool, optional
+        If True, the output will be an array representing peak
+        coordinates. The coordinates are sorted according to peaks
+        values (Larger first). If False, the output will be a boolean
+        array shaped as `image.shape` with peaks present at True
+        elements. ``indices`` is deprecated and will be removed in
+        version 0.20. Default behavior will be to always return peak
+        coordinates. You can obtain a mask as shown in the example
+        below.
+    num_peaks : int, optional
+        Maximum number of peaks. When the number of peaks exceeds `num_peaks`,
+        return `num_peaks` peaks based on highest peak intensity.
+    footprint : ndarray of bools, optional
+        If provided, `footprint == 1` represents the local region within which
+        to search for peaks at every point in `image`.
+    labels : ndarray of ints, optional
+        If provided, each unique region `labels == value` represents a unique
+        region to search for peaks. Zero is reserved for background. The
+        labels are returned as a third column of the output.
+    num_peaks_per_label : int, optional
+        Maximum number of peaks for each label.
+    p_norm : float
+        Which Minkowski p-norm to use. Should be in the range [1, inf].
+        A finite large p may cause a ValueError if overflow can occur.
+        ``inf`` corresponds to the Chebyshev distance and 2 to the
+        Euclidean distance.
+
+    Returns
+    -------
+    output : ndarray or ndarray of bools
+        * If `indices = True`  : (row, column, ...) coordinates of peaks.
+        * If `indices = False` : Boolean array shaped like `image`, with peaks
+          represented by True values.
+
+    Notes
+    -----
+    The peak local maximum function returns the coordinates of local peaks
+    (maxima) in an image. Internally, a maximum filter is used for finding
+    local maxima. This operation dilates the original image. After comparison
+    of the dilated and original image, this function returns the coordinates
+    or a mask of the peaks where the dilated image equals the original image.
+
+    See also
+    --------
+    skimage.feature.corner_peaks
+
+    Examples
+    --------
+    >>> img1 = np.zeros((7, 7))
+    >>> img1[3, 4] = 1
+    >>> img1[3, 2] = 1.5
+    >>> img1
+    array([[0. , 0. , 0. , 0. , 0. , 0. , 0. ],
+           [0. , 0. , 0. , 0. , 0. , 0. , 0. ],
+           [0. , 0. , 0. , 0. , 0. , 0. , 0. ],
+           [0. , 0. , 1.5, 0. , 1. , 0. , 0. ],
+           [0. , 0. , 0. , 0. , 0. , 0. , 0. ],
+           [0. , 0. , 0. , 0. , 0. , 0. , 0. ],
+           [0. , 0. , 0. , 0. , 0. , 0. , 0. ]])
+    >>> peak_local_max(img1, min_distance=1)
+    array([[3, 2],
+           [3, 4]])
+    >>> peak_local_max(img1, min_distance=2)
+    array([[3, 2]])
+    >>> img2 = np.zeros((20, 20, 20))
+    >>> img2[10, 10, 10] = 1
+    >>> img2[15, 15, 15] = 1
+    >>> peak_idx = peak_local_max(img2, exclude_border=0)
+    >>> peak_idx
+    array([[10, 10, 10],
+           [15, 15, 15]])
+    >>> peak_mask = np.zeros_like(img2, dtype=bool)
+    >>> peak_mask[tuple(peak_idx.T)] = True
+    >>> np.argwhere(peak_mask)
+    array([[10, 10, 10],
+           [15, 15, 15]])
+    """
+    if (footprint is None or footprint.size == 1) and min_distance < 1:
+        warn(
+            "When min_distance < 1, peak_local_max acts as finding "
+            "image > max(threshold_abs, threshold_rel * max(image)).",
+            RuntimeWarning,
+            stacklevel=2,
+        )
+
+    border_width = peak._get_excluded_border_width(
+        image, min_distance, exclude_border
+    )
+
+    threshold = peak._get_threshold(image, threshold_abs, threshold_rel)
+
+    if footprint is None:
+        size = 2 * min_distance + 1
+        footprint = np.ones((size,) * image.ndim, dtype=bool)
+    else:
+        footprint = np.asarray(footprint)
+
+    if labels is None:
+        # Non maximum filter
+        mask = peak._get_peak_mask(image, footprint, threshold)
+
+        mask = peak._exclude_border(mask, border_width)
+
+        # Select highest intensities (num_peaks)
+        coordinates = peak._get_high_intensity_peaks(
+            image, mask, num_peaks, min_distance, p_norm
+        )
+
+    else:
+        _labels = peak._exclude_border(
+            labels.astype(int, casting="safe"), border_width
+        )
+
+        if np.issubdtype(image.dtype, np.floating):
+            bg_val = np.finfo(image.dtype).min
+        else:
+            bg_val = np.iinfo(image.dtype).min
+
+        # For each label, extract a smaller image enclosing the object of
+        # interest, identify num_peaks_per_label peaks
+        labels_peak_coord = []
+
+        for label_idx, roi in enumerate(ndi.find_objects(_labels)):
+            if roi is None:
+                continue
+
+            # Get roi mask
+            label_mask = labels[roi] == label_idx + 1
+            # Extract image roi
+            img_object = image[roi].copy()
+            # Ensure masked values don't affect roi's local peaks
+            img_object[np.logical_not(label_mask)] = bg_val
+
+            mask = peak._get_peak_mask(
+                img_object, footprint, threshold, label_mask
+            )
+
+            coordinates = peak._get_high_intensity_peaks(
+                img_object, mask, num_peaks_per_label, min_distance, p_norm
+            )
+
+            # transform coordinates in global image indices space
+            for idx, s in enumerate(roi):
+                coordinates[:, idx] += s.start
+
+            # append the label value as an extra column
+            coordinates = np.hstack(
+                (coordinates, np.ones((len(coordinates), 1)) * (label_idx + 1))
+            )
+
+            labels_peak_coord.append(coordinates)
+
+        if labels_peak_coord:
+            coordinates = np.vstack(labels_peak_coord)
+        else:
+            coordinates = np.empty((0, 2), dtype=int)
+
+    if len(coordinates) > num_peaks:
+        out = np.zeros_like(image, dtype=bool)
+        # Use only the spatial columns; a label column would over-index here.
+        out[tuple(coordinates[:, : image.ndim].astype(int).T)] = True
+        coordinates = peak._get_high_intensity_peaks(
+            image, out, num_peaks, min_distance, p_norm
+        )
+
+    if indices:
+        return coordinates
+
+    # indices=False: return a boolean mask. The label column, if present,
+    # is dropped, since it cannot be represented in the mask.
+    out = np.zeros_like(image, dtype=bool)
+    out[tuple(coordinates[:, : image.ndim].astype(int).T)] = True
+    return out
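+
+
+# Minimal sketch of the extra label column (synthetic values, not from the
+# pipeline): two labelled boxes, one bright pixel in each.
+#
+#   img = np.zeros((9, 9)); img[2, 2] = img[6, 6] = 5.0
+#   lbl = np.zeros((9, 9), int); lbl[:4, :4] = 1; lbl[5:, 5:] = 2
+#   peak_local_max_labels(img, labels=lbl, exclude_border=0)
+#   # -> array([[2., 2., 1.],
+#   #           [6., 6., 2.]])   (row, col, label)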
+
+
+def crop(stack: np.ndarray, center: tuple, size: int):
+    im = stack[
+        :,
+        int(center[0]) - size // 2 : int(center[0]) + size // 2,
+        int(center[1]) - size // 2 : int(center[1]) + size // 2,
+    ]
+    return im
+
+
+def gdif(array2d, dif_gauss_sigma=(1, 3)):
+    array2d = array2d.astype("f")
+    return ndi.gaussian_filter(
+        array2d, sigma=dif_gauss_sigma[0]
+    ) - ndi.gaussian_filter(array2d, sigma=dif_gauss_sigma[1])
+
+
+def get_peak_number(
+    crop2d,
+    dif_gauss_sigma=(1, 3),
+    min_distance=3,
+    threshold_abs=5,
+    plot=False,
+    title="",
+    bf_crop=None,
+    return_std=False,
+):
+    image_max = gdif(crop2d, dif_gauss_sigma)
+    peaks = peak_local_max(
+        image_max, min_distance=min_distance, threshold_abs=threshold_abs
+    )
+
+    if plot:
+        if bf_crop is None:
+            fig, ax = plt.subplots(1, 2, sharey=True)
+            ax[0].imshow(crop2d)
+            ax[0].set_title(f"raw image {title}")
+            ax[1].imshow(image_max)
+            ax[1].set_title("Filtered + peak detection")
+            ax[1].plot(peaks[:, 1], peaks[:, 0], "r.")
+            plt.show()
+        else:
+            fig, ax = plt.subplots(1, 3, sharey=True)
+
+            ax[0].imshow(bf_crop, cmap="gray")
+            ax[0].set_title(f"BF {title}")
+
+            ax[1].imshow(crop2d, vmax=crop2d.mean() + 2 * crop2d.std())
+            ax[1].set_title(f"raw image {title}")
+
+            ax[2].imshow(image_max)
+            ax[2].set_title(
+                f"Filtered + {len(peaks)} peaks (std {image_max.std():.2f})"
+            )
+            ax[2].plot(peaks[:, 1], peaks[:, 0], "r.")
+            plt.show()
+
+    if return_std:
+        return len(peaks), crop2d.std()
+    else:
+        return (len(peaks),)
+
+
+def get_peaks_per_frame(stack3d, dif_gauss_sigma=(1, 3), **kwargs):
+    image_ref = gdif(stack3d[0], dif_gauss_sigma)
+    thr = 5 * image_ref.std()
+    return list(
+        map(partial(get_peak_number, threshold_abs=thr, **kwargs), stack3d)
+    )
+
+
+def get_peaks_all_wells(stack, centers, size, plot=0):
+    n_peaks = []
+    for c in centers:
+        print(".", end="")
+        well = crop(stack, c["center"], size)
+        n_peaks.append(get_peaks_per_frame(well, plot=plot))
+    return n_peaks
+
+
+def get_n_cells(peaks: list, n_frames=6):
+    return np.round(np.mean(peaks[:n_frames]), 0).astype(int)
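+
+
+# Sketch of the counting primitive on synthetic data (values illustrative):
+# gdif() band-pass filters the crop, then peaks above threshold_abs are
+# counted.
+#
+#   im = np.zeros((50, 50)); im[10, 10] = im[30, 25] = 100.0
+#   get_peak_number(im, dif_gauss_sigma=(1, 3), threshold_abs=5)
+#   # -> (2,)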
diff --git a/merge_tables.py b/merge_tables.py
index 86ccdb33d758494d6aa28d7334965c3329803391..af3b17e75359264811df519b219203f67db00c8d 100644
--- a/merge_tables.py
+++ b/merge_tables.py
@@ -4,41 +4,86 @@ import seaborn as sns
 import matplotlib.pyplot as plt


-def combine(table_day1_path, table_day2_path, table_output_path, swarmplot_output_path, prob_plot_path, threshold:int=20):
+def combine(
+    table_day1_path,
+    table_day2_path,
+    table_output_path,
+    swarmplot_output_path,
+    prob_plot_path,
+    threshold: int = 20,
+):
     day1, day2 = [pd.read_csv(t) for t in [table_day1_path, table_day2_path]]
-    day1.loc[:,'n_cells_final'] = day2.n_cells
-    day1.loc[:,'intensity_final'] = day2.intensity
+    day1.loc[:, "n_cells_final"] = day2.n_cells
+    day1.loc[:, "intensity_final"] = day2.intensity

-    day1.loc[:,'final_state'] = day1.n_cells_final > threshold
+    day1.loc[:, "final_state"] = day1.n_cells_final > threshold
     day1.to_csv(table_output_path, index=None)

     n_max = int(day1.n_cells.mean()) * 3 + 1
     plot_swarm_counts(day1, n_max=n_max, path=swarmplot_output_path)
-    plot_swarm_intensity(day1, n_max=n_max, path=swarmplot_output_path.replace('.png', '_intensities.png'))
-    plot_probs(day1, n_max=n_max,path=prob_plot_path)
-    plot_probs_log(day1, n_max=n_max,path=prob_plot_path.replace('.png', '_log.png'))
-
-def plot_swarm_counts(table:pd.DataFrame, n_max:int=5, path=None):
+    plot_swarm_intensity(
+        day1,
+        n_max=n_max,
+        path=swarmplot_output_path.replace(".png", "_intensities.png"),
+    )
+    plot_probs(day1, n_max=n_max, path=prob_plot_path)
+    plot_probs_log(
+        day1, n_max=n_max, path=prob_plot_path.replace(".png", "_log.png")
+    )
+
+
+def plot_swarm_counts(table: pd.DataFrame, n_max: int = 5, path=None):
     fig, ax = plt.subplots(dpi=300)
-    sns.swarmplot(ax=ax, data=table.query(f'n_cells <= {n_max}'), x='[AB]', y='n_cells_final', hue='n_cells', dodge=True, size=1 )
+    sns.swarmplot(
+        ax=ax,
+        data=table.query(f"n_cells <= {n_max}"),
+        x="[AB]",
+        y="n_cells_final",
+        hue="n_cells",
+        dodge=True,
+        size=1,
+    )
     fig.savefig(path)

-def plot_swarm_intensity(table:pd.DataFrame, n_max:int=5, path=None):
+
+def plot_swarm_intensity(table: pd.DataFrame, n_max: int = 5, path=None):
     fig, ax = plt.subplots(dpi=300)
-    sns.swarmplot(ax=ax, data=table.query(f'n_cells <= {n_max}'), x='[AB]', y='intensity_final', hue='n_cells', dodge=True, size=1 )
+    sns.swarmplot(
+        ax=ax,
+        data=table.query(f"n_cells <= {n_max}"),
+        x="[AB]",
+        y="intensity_final",
+        hue="n_cells",
+        dodge=True,
+        size=1,
+    )
     fig.savefig(path)

-def plot_probs(table:pd.DataFrame, n_max:int=5, path=None):
-    
+
+def plot_probs(table: pd.DataFrame, n_max: int = 5, path=None):
     fig, ax = plt.subplots(dpi=300)
-    sns.lineplot(ax=ax, data=table.query(f'n_cells <= {n_max}'), x='[AB]', y='final_state', hue='n_cells')
+    sns.lineplot(
+        ax=ax,
+        data=table.query(f"n_cells <= {n_max}"),
+        x="[AB]",
+        y="final_state",
+        hue="n_cells",
+    )
     fig.savefig(path)

-def plot_probs_log(table:pd.DataFrame, n_max:int=5, path=None):
+
+def plot_probs_log(table: pd.DataFrame, n_max: int = 5, path=None):
     fig, ax = plt.subplots(dpi=300)
-    sns.lineplot(ax=ax, data=table.query(f'n_cells <= {n_max}'), x='[AB]', y='final_state', hue='n_cells')
-    ax.set_xscale('log')
+    sns.lineplot(
+        ax=ax,
+        data=table.query(f"n_cells <= {n_max}"),
+        x="[AB]",
+        y="final_state",
+        hue="n_cells",
+    )
+    ax.set_xscale("log")
     fig.savefig(path)

+
 if __name__ == "__main__":
-    fire.Fire(combine)
\ No newline at end of file
+    fire.Fire(combine)
diff --git a/poisson.py b/poisson.py
new file mode 100644
index 0000000000000000000000000000000000000000..943ddfecd0447524c310f241766a3bc6d6475934
--- /dev/null
+++ b/poisson.py
@@ -0,0 +1,41 @@
+from scipy.optimize import curve_fit
+from scipy.stats import poisson
+import matplotlib.pyplot as plt
+import numpy as np
+
+
+def fit(
+    numbers: np.ndarray,
+    max_value=None,
+    xlabel="Initial number of cells",
+    title="",
+    plot=True,
+    save_fig_path=None,
+):
+    if max_value is None:
+        max_value = numbers.max()
+    bins = np.arange(max_value + 1) - 0.5
+    vector = bins[:-1] + 0.5
+    hist, bins = np.histogram(numbers, bins=bins, density=True)
+    popt, pcov = curve_fit(poisson.pmf, vector, hist, p0=(1.0,))
+    lambda_fit_result = popt[0]
+    if plot:
+        plt.hist(numbers, bins=bins, fill=None)
+        plt.plot(
+            vector,
+            len(numbers) * poisson.pmf(vector, lambda_fit_result),
+            ".-",
+            label=f"Poisson fit λ={lambda_fit_result:.1f}",
+            color="tab:red",
+        )
+        plt.xlabel(xlabel)
+        plt.title(title)
+        plt.legend()
+        if save_fig_path is not None:
+            try:
+                plt.savefig(save_fig_path)
+                print(f"Saved histogram {save_fig_path}")
+            except Exception as e:
+                print("saving histogram failed", e.args)
+        plt.show()
+    return lambda_fit_result
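+
+
+# Sketch with synthetic counts (the λ used here is arbitrary); fit()
+# recovers the Poisson rate from the normalized histogram of cell numbers:
+#
+#   rng = np.random.default_rng(0)
+#   lam = fit(rng.poisson(lam=1.3, size=500), plot=False)
+#   round(lam, 1)  # ≈ 1.3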
diff --git a/register.py b/register.py
new file mode 100644
index 0000000000000000000000000000000000000000..a6b9fac1a6cf55fd156321ace35093323b4acd34
--- /dev/null
+++ b/register.py
@@ -0,0 +1,208 @@
+from tifffile import imread
+import matplotlib.pyplot as plt
+import imreg_dft as reg
+import numpy as np
+
+
+def align_stack(
+    data_or_path,
+    template16,
+    mask2,
+    plot=False,
+    binnings=(1, 16, 2),
+):
+    """
+    stack should contain two channels: bright field and fluorescence.
+    The BF channel is binned down and registered against template16 (an
+    aligned BF reference); the resulting transform vector is then applied
+    to the original data, which is stacked with the mask.
+    The output stack is of the same size as mask.
+
+    :return: aligned_array[bf, fluo, mask], tvec
+    :rtype: np.ndarray, dict
+    """
+    if isinstance(data_or_path, str):
+        path = data_or_path
+        stack = imread(path)
+        print(path, stack.shape)
+    else:
+        assert data_or_path.ndim == 3 and data_or_path.shape[0] == 2
+        stack = data_or_path
+
+    bf, tritc = stack[:2]
+    stack_temp_scale = binnings[1] // binnings[0]
+    mask_temp_scale = binnings[1] // binnings[2]
+    stack_mask_scale = binnings[2] // binnings[0]
+
+    f_bf = bf[::stack_temp_scale, ::stack_temp_scale]
+
+    tvec8 = get_transform(f_bf, template16, plot=plot)
+    plt.show()
+    tvec = scale_tvec(tvec8, mask_temp_scale)
+    print(tvec)
+    try:
+        aligned_tritc = unpad(
+            transform(tritc[::stack_mask_scale, ::stack_mask_scale], tvec),
+            mask2.shape,
+        )
+        aligned_bf = unpad(
+            transform(bf[::stack_mask_scale, ::stack_mask_scale], tvec),
+            mask2.shape,
+        )
+    except ValueError as e:
+        print("stack_mask_scale: ", stack_mask_scale)
+        print(e.args)
+        raise e
+
+    aligned_stack = np.stack((aligned_bf, aligned_tritc, mask2)).astype(
+        "uint16"
+    )
+
+    return aligned_stack, tvec
+
+
+def get_transform(
+    image, template, plot=True, pad_ratio=1.2, figsize=(10, 5), dpi=300
+):
+    """
+    Pads image and template, registers and returns tvec
+    """
+    padded_template = pad(template, (s := increase(image.shape, pad_ratio)))
+    padded_image = pad(image, s)
+    tvec = register(padded_image, padded_template)
+    if plot:
+        aligned_bf = unpad(tvec["timg"], template.shape)
+        plt.figure(figsize=figsize, dpi=dpi)
+        plt.imshow(aligned_bf, cmap="gray")
+    return tvec
+
+
+def register(image, template):
+    """
+    Register image towards template
+    Return:
+        tvec: dict
+    """
+    assert np.array_equal(
+        image.shape, template.shape
+    ), f"unequal shapes {(image.shape, template.shape)}"
+    return reg.similarity(
+        template,
+        image,
+        constraints={
+            "scale": [1, 0.2],
+            "tx": [0, 50],
+            "ty": [0, 50],
+            "angle": [0, 30],
+        },
+    )
+
+
+def filter_by_fft(
+    image,
+    sigma=40,
+    fix_horizontal_stripes=False,
+    fix_vertical_stripes=False,
+    highpass=True,
+):
+    size = get_fft_size(image.shape)
+    fft = np.fft.fft2(pad(image, size))
+    if fix_vertical_stripes:
+        fft[0, 1:] = 0  # vertical
+    if fix_horizontal_stripes:
+        fft[1:, 0] = 0  # horizontal
+    fft = fft * (fft_mask(size, sigma=size[0] / sigma, highpass=highpass))
+    filtered_image = unpad(abs(np.fft.ifft2(fft)), image.shape)
+    return filtered_image
+
+
+def fft_mask(shape, sigma=10, highpass=False):
+    y, x = np.indices(shape)
+
+    gaus = np.exp(-((x - shape[1] // 2) ** 2) / sigma**2 / 2) * np.exp(
+        -((y - shape[0] // 2) ** 2) / sigma**2 / 2
+    )
+    gaus = gaus / gaus.max()
+    if highpass:
+        return 1 - gaus
+    return gaus
+
+
+def pad(image: np.ndarray, to_shape: tuple = None, padding: tuple = None):
+    if padding is None:
+        padding = calculate_padding(image.shape, to_shape)
+    try:
+        padded = np.pad(image, padding, "edge")
+    except TypeError as e:
+        print(e.args, padding)
+        raise e
+    return padded
+
+
+def unpad(image: np.ndarray, to_shape: tuple = None, padding: tuple = None):
+    if any(np.array(image.shape) - np.array(to_shape) < 0):
+        print(
+            f"""unpad:warning: image.shape {image.shape} \
+            is within to_shape {to_shape}"""
+        )
+        image = pad(image, np.array((image.shape, to_shape)).max(axis=0))
+        print(f"new image shape after padding {image.shape}")
+    if padding is None:
+        padding = calculate_padding(to_shape, image.shape)
+
+    y = [padding[0][0], -padding[0][1]]
+    if y[1] == 0:
+        y[1] = None
+    x = [padding[1][0], -padding[1][1]]
+    if x[1] == 0:
+        x[1] = None
+    return image[y[0]: y[1], x[0]: x[1]]
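+
+
+# pad()/unpad() round-trip sketch (shapes chosen arbitrarily):
+#
+#   a = np.arange(12).reshape(3, 4)
+#   b = pad(a, to_shape=(7, 8))          # edge-padded up to (7, 8)
+#   np.array_equal(unpad(b, (3, 4)), a)  # -> True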
+
+
+def get_fft_size(shape):
+    max_size = np.max(shape)
+    n = 5
+    while (2**n) < (max_size * 1.5):
+        n += 1
+    return (2**n, 2**n)
+
+
+def calculate_padding(shape1: tuple, shape2: tuple):
+    """
+    Calculates padding to get shape2 from shape1
+    Return:
+        2D tuple of indices
+    """
+    dif = np.array(shape2) - np.array(shape1)
+    assert all(
+        dif >= 0
+    ), f"Shape2 must be bigger than shape1, got {shape2}, {shape1}"
+    mid = dif // 2
+    rest = dif - mid
+    y = mid[0], rest[0]
+    x = mid[1], rest[1]
+    return y, x
+
+
+def scale_tvec(tvec, scale=8):
+    tvec_8x = tvec.copy()
+    tvec_8x["tvec"] = tvec["tvec"] * scale
+    # Drop the cached transformed image: it is only valid at the original
+    # scale. (Plain assignment never raises KeyError, so the previous
+    # try/except around it was dead code.)
+    tvec_8x["timg"] = None
+    return tvec_8x
+
+
+def transform(image, tvec):
+    print(f"transform {image.shape}")
+    fluo = reg.transform_img_dict(image, tvec)
+    return fluo.astype("uint")
+
+
+def increase(shape, increase_ratio):
+    assert increase_ratio > 1
+    shape = np.array(shape)
+    return tuple((shape * increase_ratio).astype(int))
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..93e4409c4734c3ff85686a36238a15cab6c05516
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,85 @@
+appdirs==1.4.4
+asciitree==0.3.3
+attrs==23.1.0
+certifi==2023.7.22
+charset-normalizer==3.2.0
+click==8.1.7
+cloudpickle==2.2.1
+ConfigArgParse==1.7
+connection-pool==0.0.3
+contourpy==1.1.1
+cycler==0.11.0
+dask==2023.5.0
+datrie==0.8.2
+docutils==0.20.1
+dpath==2.1.6
+entrypoints==0.4
+fasteners==0.18
+fastjsonschema==2.18.0
+fire==0.5.0
+fonttools==4.42.1
+fsspec==2023.9.1
+gitdb==4.0.10
+GitPython==3.1.36
+humanfriendly==10.0
+idna==3.4
+imageio==2.31.3
+importlib-metadata==6.8.0
+importlib-resources==6.0.1
+imreg-dft @ git+https://github.com/matejak/imreg_dft.git@4023e9e2c4110b04017c85e7bda9883a71b2eea0
+Jinja2==3.1.2
+jsonschema==4.19.0
+jsonschema-specifications==2023.7.1
+jupyter_core==5.3.1
+kiwisolver==1.4.5
+lazy_loader==0.3
+locket==1.0.0
+MarkupSafe==2.1.3
+matplotlib==3.7.3
+nbformat==5.9.2
+nd2==0.7.2
+networkx==3.1
+numcodecs==0.11.0
+numpy==1.24.4
+packaging==23.1
+pandas==2.0.3
+partd==1.4.0
+Pillow==10.0.1
+pkgutil_resolve_name==1.3.10
+plac==1.3.5
+platformdirs==3.10.0
+psutil==5.9.5
+PuLP==2.7.0
+pyparsing==3.1.1
+python-dateutil==2.8.2
+pytz==2023.3.post1
+PyWavelets==1.4.1
+PyYAML==6.0.1
+referencing==0.30.2
+requests==2.31.0
+reretry==0.11.8
+resource-backed-dask-array==0.1.0
+rpds-py==0.10.3
+scikit-image==0.21.0
+scipy==1.10.1
+seaborn==0.12.2
+six==1.16.0
+smart-open==6.4.0
+smmap==5.0.1
+snakemake==7.32.4
+stopit==1.1.2
+tabulate==0.9.0
+termcolor==2.3.0
+throttler==1.2.2
+tifffile==2023.7.10
+toolz==0.12.0
+toposort==1.10
+traitlets==5.10.0
+typing_extensions==4.8.0
+tzdata==2023.3
+urllib3==2.0.4
+wrapt==1.15.0
+yte==1.5.1
+zarr==2.16.1
+zarr-tools==0.4.5
+zipp==3.17.0