diff --git a/.gitignore b/.gitignore
index 429a119d7be766cf68478aa6e32e48b1c6ee3a7d..173785c2997152e5499314b856bf663767ebd897 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,2 +1,4 @@
 .snakemake
 *.pyc
+__pycache__
+venv
\ No newline at end of file
diff --git a/align.py b/align.py
index 6a53e39b44cc39e391edc1f9bde0849f27da0538..28af8bdace995910fe3a97b10cd24e7a3a8678ad 100644
--- a/align.py
+++ b/align.py
@@ -1,6 +1,8 @@
-from typing import List
 import numpy as np
-from droplet_growth import mic, register, poisson
+import register
+import poisson
+from skimage.measure import regionprops, regionprops_table
+import count
 import tifffile as tf
 import os
 import dask.array as da
@@ -11,189 +13,250 @@ import fire
 from multiprocessing import Pool
 import nd2
 import pandas as pd
-from skimage.measure import regionprops_table
 from scipy.ndimage import laplace, rotate
 import json
 
 
 def align_multichip(
-    BF_TRITC_2D_path, 
-    out_path, 
-    concentrations_path, 
-    template_path, 
-    labels_path, 
-    table_path, 
-    fit_poisson, 
-    nmax=10):
-    with open(concentrations_path, 'r') as f:
-        concentrations_dct = (yaml.safe_load(f))
-    concentrations = concentrations_dct['concentrations']
-    unit = concentrations_dct['units']
-    print(f'concentrations from {concentrations_path}: {concentrations} [{unit}]')
-
+    BF_TRITC_2D_path,
+    out_path,
+    concentrations_path,
+    template_path,
+    labels_path,
+    table_path,
+    fit_poisson,
+    nmax=10,
+):
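+    """Align per-concentration stacks, count cells, and export results.
+
+    Reads concentrations from a YAML file, registers each BF/TRITC stack to
+    the chip template, counts cells per well, optionally fits a Poisson
+    distribution to the counts, and writes an aligned zarr plus a CSV table.
+    """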
+    with open(concentrations_path, "r") as f:
+        concentrations_dct = yaml.safe_load(f)
+    concentrations = concentrations_dct["concentrations"]
+    unit = concentrations_dct["units"]
+    print(
+        f"concentrations from {concentrations_path}: {concentrations} [{unit}]"
+    )
 
     if "rotation_data_deg" in concentrations_dct:
         rotation_data_deg = int(concentrations_dct["rotation_data_deg"])
-        print(f'Rotation: {rotation_data_deg}')
+        print(f"Rotation: {rotation_data_deg}")
     else:
         rotation_data_deg = 0
-    
-    tif_paths = [f'{os.path.dirname(BF_TRITC_2D_path)}/{c}{unit}_aligned.tif' for c in concentrations]
-    print('tif_paths: ', tif_paths)
-    
+
+    tif_paths = [
+        f"{os.path.dirname(BF_TRITC_2D_path)}/{c}{unit}_aligned.tif"
+        for c in concentrations
+    ]
+    print("tif_paths: ", tif_paths)
+
     data = read_dask(BF_TRITC_2D_path)
     template16 = tf.imread(template_path)
     big_labels = tf.imread(labels_path)
 
     fun = partial(
-        align_parallel, 
+        align_parallel,
         template16=template16,
         big_labels=big_labels,
         unit=unit,
         fit_poisson=fit_poisson,
         nmax=nmax,
-        rotation_data_deg=rotation_data_deg )
-    
+        rotation_data_deg=rotation_data_deg,
+    )
+
+    p = None
     try:
-        p=Pool(data.shape[0])
-        out = p.map(fun, zip(data, tif_paths, concentrations)) 
+        p = Pool(data.shape[0])
+        out = p.map(fun, zip(data, tif_paths, concentrations))
 
     except TypeError as e:
-        print(f'Pool failed due to {e.args}')
+        print(f"Pool failed due to {e.args}")
         out = list(map(fun, zip(data, tif_paths, concentrations)))
     finally:
-        p.close()
+        if p is not None:
+            p.close()
-    
-    aligned = [o['stack'] for o in out]
-    counts = [o['counts'] for o in out]
-    tvecs = [o['tvec'] for o in out]
 
-    save_tvecs(tvecs, out_path.replace('.zarr', '.transform.json'))
+    aligned = [o["stack"] for o in out]
+    counts = [o["counts"] for o in out]
+    tvecs = [o["tvec"] for o in out]
+
+    save_tvecs(tvecs, out_path.replace(".zarr", ".transform.json"))
 
-    df = pd.concat(counts, ignore_index=True).sort_values(['[AB]','label'])
+    df = pd.concat(counts, ignore_index=True).sort_values(["[AB]", "label"])
     df.to_csv(table_path)
 
     daligned = da.from_array(np.array(aligned))
     convert.to_zarr(
-        daligned, 
-        path=out_path, 
-        steps=4, 
-        name=['BF','TRITC','mask'], 
-        colormap=['gray','green','blue'],
-        lut=((1000,30000),(440, 600),(0,501)),
+        daligned,
+        path=out_path,
+        steps=4,
+        name=["BF", "TRITC", "mask"],
+        colormap=["gray", "green", "blue"],
+        lut=((1000, 30000), (440, 600), (0, 501)),
     )
     return out_path
 
+
 def prep_tvec(transform_dict):
     t = transform_dict.copy()
     t["tvec"] = list(t["tvec"])
     return t
 
+
 def save_tvecs(tvecs, path):
     try:
-        transform_data = list(map(prep_tvec, tvecs))
+        # skip entries with no transform (stacks that were already aligned)
+        transform_data = [prep_tvec(t) for t in tvecs if t is not None]
-        with open(path, 'w') as f:
+        with open(path, "w") as f:
             json.dump(transform_data, fp=f)
     except Exception as e:
-        print('saving transform json failed: ', e.args)
+        print("saving transform json failed: ", e.args)
+
 
 def align_parallel(args, **kwargs):
     return align2D(*args, **kwargs)
-    
+
+
 def align2D(
-    stack_dask, 
-    path_tif, 
-    ab, 
-    template16=None, 
-    big_labels=None, 
-    unit='μg_mL', 
-    fit_poisson=True, 
-    nmax=20, 
-    rotation_data_deg=0):
+    stack_dask,
+    path_tif,
+    ab,
+    template16=None,
+    big_labels=None,
+    unit="μg_mL",
+    fit_poisson=True,
+    nmax=20,
+    rotation_data_deg=0,
+):
     print(ab, unit)
     try:
         aligned = tf.imread(path_tif)
-        print(f'already aligned: {path_tif}:{aligned.shape}')
-        counts = count(aligned, ab=ab)
-        intensity_table = get_intensity_table(aligned[2],aligned[1])
-        counts.loc[:, 'intensity'] = intensity_table.intensity
+        print(f"already aligned: {path_tif}:{aligned.shape}")
+        counts = count_cells(aligned, ab=ab)
+        intensity_table = get_intensity_table(aligned[2], aligned[1])
+        counts.loc[:, "intensity"] = intensity_table.intensity
         return {"stack": aligned, "counts": counts}
     except FileNotFoundError:
-        print('Processing...')
-        
+        print("Processing...")
+
     data = stack_dask.compute()
     if rotation_data_deg != 0:
-        data = rotate(input=data, angle=rotation_data_deg, axes=(1,2))
-        print(f'Rotated data {rotation_data_deg} deg')
+        data = rotate(input=data, angle=rotation_data_deg, axes=(1, 2))
+        print(f"Rotated data {rotation_data_deg} deg")
 
     aligned, tvec = register.align_stack(
-        data, 
-        path_to_save=None, 
-        template16=template16, 
-        mask2=big_labels, 
-        binnings=(2,16,2)
+        data,
+        template16=template16,
+        mask2=big_labels,
+        binnings=(2, 16, 2),
     )
-    counts = count(aligned, ab=ab)
-    intensity_table = get_intensity_table(aligned[2],aligned[1])
-    counts.loc[:, 'intensity'] = intensity_table.intensity
-    
+    counts = count_cells(aligned, ab=ab)
+    intensity_table = get_intensity_table(aligned[2], aligned[1])
+    counts.loc[:, "intensity"] = intensity_table.intensity
+
     if fit_poisson:
-        l = poisson.fit(
-            counts.query(f'n_cells < {nmax}').n_cells, 
+        lambda_fit_result = poisson.fit(
+            counts.query(f"n_cells < {nmax}").n_cells,
             title=f"automatic {ab}{unit.replace('_','/')}",
-            save_fig_path=path_tif.replace('.tif', '-counts-hist.png')
+            save_fig_path=path_tif.replace(".tif", "-counts-hist.png"),
         )
-        counts.loc[:, 'poisson fit'] = l
+        counts.loc[:, "poisson fit"] = lambda_fit_result
 
-    # counts.to_csv((cp := path_tif.replace('.tif', '-counts.csv')), index=None)
-    return {"stack": aligned, "counts": counts, 'tvec': tvec}
+    return {"stack": aligned, "counts": counts, "tvec": tvec}
 
-def count(aligned:np.ndarray, ab=None):
-    counts = mic.get_cell_numbers(aligned[1], aligned[2], threshold_abs=2, plot=False, bf=aligned[0])
-    counts.loc[:,'[AB]'] = ab
+
+def count_cells(aligned: np.ndarray, ab=None):
+    counts = get_cell_numbers(
+        aligned[1], aligned[2], threshold_abs=2, plot=False, bf=aligned[0]
+    )
+    counts.loc[:, "[AB]"] = ab
     return counts
 
+
+def get_cell_numbers(
+    multiwell_image: np.ndarray,
+    labels: np.ndarray,
+    plot=False,
+    threshold_abs: float = 2,
+    min_distance: float = 5,
+    meta: dict = None,
+    bf: np.ndarray = None,
+) -> pd.DataFrame:
+    meta = meta if meta is not None else {}
+    props = regionprops(labels)
+
+    def get_n_peaks(i):
+        if bf is None:
+            return count.get_peak_number(
+                multiwell_image[props[i].slice],
+                plot=plot,
+                dif_gauss_sigma=(3, 5),
+                threshold_abs=threshold_abs,
+                min_distance=min_distance,
+                title=props[i].label,
+            )
+        else:
+            return count.get_peak_number(
+                multiwell_image[props[i].slice],
+                plot=plot,
+                dif_gauss_sigma=(3, 5),
+                threshold_abs=threshold_abs,
+                min_distance=min_distance,
+                title=props[i].label,
+                bf_crop=bf[props[i].slice],
+                return_std=True,
+            )
+
+    # iterate over detected regions; raw label values may be non-contiguous
+    n_cells = list(map(get_n_peaks, range(len(props))))
+    return pd.DataFrame(
+        [
+            {
+                "label": prop.label,
+                "x": prop.centroid[0],
+                "y": prop.centroid[1],
+                "n_cells": n_cell[0],
+                # 'std': n_cell[1],
+                **meta,
+            }
+            for prop, n_cell in zip(props, n_cells)
+        ]
+    )
+
+
 def get_intensity_table(
     labelled_mask: np.ndarray,
     intensity_image: np.ndarray,
-    values = ['mean_intensity', ],
-    estimator = np.mean,
-    plot: bool = True
+    values=("mean_intensity",),
 ):
-    assert (iis := intensity_image).ndim == 2, (
-        f'expected 2D array for intensity, got shape {iis.shape}'
-    )
-    data = intensity_image.astype('f')
+    assert (
+        iis := intensity_image
+    ).ndim == 2, f"expected 2D array for intensity, got shape {iis.shape}"
+    data = intensity_image.astype("f")
     dict_li = regionprops_table(
-        labelled_mask,
-        intensity_image=data,
-        properties=['label', *values]
+        labelled_mask, intensity_image=data, properties=["label", *values]
     )
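+    # per-label background: mean intensity over a 1 px outline around each
+    # label (see get_outlines)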
     dict_bg = regionprops_table(
         get_outlines(labelled_mask),
         intensity_image=data,
-        properties=['mean_intensity']
+        properties=["mean_intensity"],
     )
-    dict_litb = {**dict_li, 'bg_mean': dict_bg['mean_intensity']}
+    dict_litb = {**dict_li, "bg_mean": dict_bg["mean_intensity"]}
     df = pd.DataFrame.from_dict(dict_litb)
 
     df.loc[:, "intensity"] = df.mean_intensity - df.bg_mean
     return df
 
+
 def get_outlines(labels):
-    '''creates 1 px outline around labels'''
+    """creates 1 px outline around labels"""
     return labels * (np.abs(laplace(labels)) > 0)
 
 
-def read_dask(path:str):
-    if path.endswith('.zarr'):
-        data = da.from_zarr(path+'/0/')
-    elif path.endswith('.nd2'):
+def read_dask(path: str):
+    if path.endswith(".zarr"):
+        data = da.from_zarr(path + "/0/")
+    elif path.endswith(".nd2"):
         data = nd2.ND2File(path).to_dask()
     else:
-        raise ValueError(f'Unexpected file format, expected zarr or nd2')
-    print('data:', data)
+        raise ValueError("Unexpected file format, expected zarr or nd2")
+    print("data:", data)
     return data
-   
+
+
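+# Example CLI invocation via python-fire (hypothetical file names):
+#   python align.py chip.nd2 aligned.zarr concentrations.yaml \
+#       template16.tif labels.tif counts.csv --fit_poisson=True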
 if __name__ == "__main__":
     fire.Fire(align_multichip)
diff --git a/count.py b/count.py
new file mode 100644
index 0000000000000000000000000000000000000000..f48b2ecdd0a7afe5d78d2ea6fb323847423a934d
--- /dev/null
+++ b/count.py
@@ -0,0 +1,297 @@
+from functools import partial
+from warnings import warn
+
+import matplotlib.pyplot as plt
+import numpy as np
+from skimage.feature import peak_local_max
+from skimage.feature import peak
+from scipy import ndimage as ndi
+
+
+def peak_local_max_labels(
+    image,
+    min_distance=1,
+    threshold_abs=None,
+    threshold_rel=None,
+    exclude_border=True,
+    indices=True,
+    num_peaks=np.inf,
+    footprint=None,
+    labels=None,
+    num_peaks_per_label=np.inf,
+    p_norm=np.inf,
+):
+    """Find peaks in an image as coordinate list or boolean mask.
+    Peaks are the local maxima in a region of `2 * min_distance + 1`
+    (i.e. peaks are separated by at least `min_distance`).
+    If both `threshold_abs` and `threshold_rel` are provided, the maximum
+    of the two is chosen as the minimum intensity threshold of peaks.
+    .. versionchanged:: 0.18
+        Prior to version 0.18, peaks of the same height within a radius of
+        `min_distance` were all returned, but this could cause unexpected
+        behaviour. From 0.18 onwards, an arbitrary peak within the region is
+        returned. See issue gh-2592.
+    Parameters
+    ----------
+    image : ndarray
+        Input image.
+    min_distance : int, optional
+        The minimal allowed distance separating peaks. To find the
+        maximum number of peaks, use `min_distance=1`.
+    threshold_abs : float or None, optional
+        Minimum intensity of peaks. By default, the absolute threshold is
+        the minimum intensity of the image.
+    threshold_rel : float or None, optional
+        Minimum intensity of peaks, calculated as
+        ``max(image) * threshold_rel``.
+    exclude_border : int, tuple of ints, or bool, optional
+        If positive integer, `exclude_border` excludes peaks from within
+        `exclude_border`-pixels of the border of the image.
+        If tuple of non-negative ints, the length of the tuple must match the
+        input array's dimensionality.  Each element of the tuple will exclude
+        peaks from within `exclude_border`-pixels of the border of the image
+        along that dimension.
+        If True, takes the `min_distance` parameter as value.
+        If zero or False, peaks are identified regardless of their distance
+        from the border.
+    indices : bool, optional
+        If True, the output will be an array representing peak
+        coordinates. The coordinates are sorted according to peaks
+        values (Larger first). If False, the output will be a boolean
+        array shaped as `image.shape` with peaks present at True
+        elements. ``indices`` is deprecated and will be removed in
+        version 0.20. Default behavior will be to always return peak
+        coordinates. You can obtain a mask as shown in the example
+        below.
+    num_peaks : int, optional
+        Maximum number of peaks. When the number of peaks exceeds `num_peaks`,
+        return `num_peaks` peaks based on highest peak intensity.
+    footprint : ndarray of bools, optional
+        If provided, `footprint == 1` represents the local region within which
+        to search for peaks at every point in `image`.
+    labels : ndarray of ints, optional
+        If provided, each unique region `labels == value` represents a unique
+        region to search for peaks. Zero is reserved for background. The
+        labels are returned as a third column of the output coordinates.
+    num_peaks_per_label : int, optional
+        Maximum number of peaks for each label.
+    p_norm : float
+        Which Minkowski p-norm to use. Should be in the range [1, inf].
+        A finite large p may cause a ValueError if overflow can occur.
+        ``inf`` corresponds to the Chebyshev distance and 2 to the
+        Euclidean distance.
+    Returns
+    -------
+    output : ndarray or ndarray of bools
+        * If `indices = True`  : (row, column, ...) coordinates of peaks.
+        * If `indices = False` : Boolean array shaped like `image`, with peaks
+          represented by True values.
+    Notes
+    -----
+    The peak local maximum function returns the coordinates of local peaks
+    (maxima) in an image. Internally, a maximum filter is used for finding local
+    maxima. This operation dilates the original image. After comparison of the
+    dilated and original image, this function returns the coordinates or a mask
+    of the peaks where the dilated image equals the original image.
+    See also
+    --------
+    skimage.feature.corner_peaks
+    Examples
+    --------
+    >>> img1 = np.zeros((7, 7))
+    >>> img1[3, 4] = 1
+    >>> img1[3, 2] = 1.5
+    >>> img1
+    array([[0. , 0. , 0. , 0. , 0. , 0. , 0. ],
+           [0. , 0. , 0. , 0. , 0. , 0. , 0. ],
+           [0. , 0. , 0. , 0. , 0. , 0. , 0. ],
+           [0. , 0. , 1.5, 0. , 1. , 0. , 0. ],
+           [0. , 0. , 0. , 0. , 0. , 0. , 0. ],
+           [0. , 0. , 0. , 0. , 0. , 0. , 0. ],
+           [0. , 0. , 0. , 0. , 0. , 0. , 0. ]])
+    >>> peak_local_max(img1, min_distance=1)
+    array([[3, 2],
+           [3, 4]])
+    >>> peak_local_max(img1, min_distance=2)
+    array([[3, 2]])
+    >>> img2 = np.zeros((20, 20, 20))
+    >>> img2[10, 10, 10] = 1
+    >>> img2[15, 15, 15] = 1
+    >>> peak_idx = peak_local_max(img2, exclude_border=0)
+    >>> peak_idx
+    array([[10, 10, 10],
+           [15, 15, 15]])
+    >>> peak_mask = np.zeros_like(img2, dtype=bool)
+    >>> peak_mask[tuple(peak_idx.T)] = True
+    >>> np.argwhere(peak_mask)
+    array([[10, 10, 10],
+           [15, 15, 15]])
+    """
+    if (footprint is None or footprint.size == 1) and min_distance < 1:
+        warn(
+            "When min_distance < 1, peak_local_max acts as finding "
+            "image > max(threshold_abs, threshold_rel * max(image)).",
+            RuntimeWarning,
+            stacklevel=2,
+        )
+
+    border_width = peak._get_excluded_border_width(
+        image, min_distance, exclude_border
+    )
+
+    threshold = peak._get_threshold(image, threshold_abs, threshold_rel)
+
+    if footprint is None:
+        size = 2 * min_distance + 1
+        footprint = np.ones((size,) * image.ndim, dtype=bool)
+    else:
+        footprint = np.asarray(footprint)
+
+    if labels is None:
+        # Non maximum filter
+        mask = peak._get_peak_mask(image, footprint, threshold)
+
+        mask = peak._exclude_border(mask, border_width)
+
+        # Select highest intensities (num_peaks)
+        coordinates = peak._get_high_intensity_peaks(
+            image, mask, num_peaks, min_distance, p_norm
+        )
+
+    else:
+        _labels = peak._exclude_border(
+            labels.astype(int, casting="safe"), border_width
+        )
+
+        if np.issubdtype(image.dtype, np.floating):
+            bg_val = np.finfo(image.dtype).min
+        else:
+            bg_val = np.iinfo(image.dtype).min
+
+        # For each label, extract a smaller image enclosing the object of
+        # interest, identify num_peaks_per_label peaks
+        labels_peak_coord = []
+
+        for label_idx, roi in enumerate(ndi.find_objects(_labels)):
+            if roi is None:
+                continue
+
+            # Get roi mask
+            label_mask = labels[roi] == label_idx + 1
+            # Extract image roi
+            img_object = image[roi].copy()
+            # Ensure masked values don't affect roi's local peaks
+            img_object[np.logical_not(label_mask)] = bg_val
+
+            mask = peak._get_peak_mask(
+                img_object, footprint, threshold, label_mask
+            )
+
+            coordinates = peak._get_high_intensity_peaks(
+                img_object, mask, num_peaks_per_label, min_distance, p_norm
+            )
+
+            # transform coordinates in global image indices space
+            for idx, s in enumerate(roi):
+                coordinates[:, idx] += s.start
+
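+            # append the region label as an extra column so each peak can be
+            # traced back to its well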
+            coordinates = np.hstack(
+                (
+                    coordinates,
+                    np.full((len(coordinates), 1), label_idx + 1, dtype=int),
+                )
+            )
+
+            labels_peak_coord.append(coordinates)
+
+        if labels_peak_coord:
+            coordinates = np.vstack(labels_peak_coord)
+        else:
+            # preserve the extra label column even when no peaks were found
+            coordinates = np.empty((0, image.ndim + 1), dtype=int)
+
+        if len(coordinates) > num_peaks:
+            out = np.zeros_like(image, dtype=bool)
+            # keep only the spatial columns; the trailing column holds labels
+            out[tuple(coordinates[:, : image.ndim].astype(int).T)] = True
+            coordinates = peak._get_high_intensity_peaks(
+                image, out, num_peaks, min_distance, p_norm
+            )
+
+    if indices:
+        return coordinates
+
+    # indices=False: return a boolean mask of peak positions instead;
+    # coordinates may carry a trailing label column, so keep only the
+    # spatial columns when indexing
+    out = np.zeros_like(image, dtype=bool)
+    out[tuple(coordinates[:, : image.ndim].astype(int).T)] = True
+    return out
+
+
+def crop(stack: np.ndarray, center: tuple, size: int):
+    im = stack[
+        :,
+        int(center[0]) - size // 2 : int(center[0]) + size // 2,
+        int(center[1]) - size // 2 : int(center[1]) + size // 2,
+    ]
+    return im
+
+
+def gdif(array2d, dif_gauss_sigma=(1, 3)):
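+    """Difference-of-Gaussians band-pass filter.
+
+    Subtracting a wider Gaussian blur from a narrower one suppresses pixel
+    noise and slowly varying background while keeping spot-sized features.
+    """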
+    array2d = array2d.astype("f")
+    return ndi.gaussian_filter(
+        array2d, sigma=dif_gauss_sigma[0]
+    ) - ndi.gaussian_filter(array2d, sigma=dif_gauss_sigma[1])
+
+
+def get_peak_number(
+    crop2d,
+    dif_gauss_sigma=(1, 3),
+    min_distance=3,
+    threshold_abs=5,
+    plot=False,
+    title="",
+    bf_crop=None,
+    return_std=False,
+):
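+    """Count fluorescent spots in a single well crop.
+
+    Band-pass filters the crop with `gdif`, then detects local maxima above
+    `threshold_abs`. Returns ``(n_peaks,)``, or ``(n_peaks, std)`` when
+    `return_std` is True; `plot` shows the intermediate images.
+    """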
+    image_max = gdif(crop2d, dif_gauss_sigma)
+    peaks = peak_local_max(
+        image_max, min_distance=min_distance, threshold_abs=threshold_abs
+    )
+
+    if plot:
+        if bf_crop is None:
+            fig, ax = plt.subplots(1, 2, sharey=True)
+            ax[0].imshow(crop2d)
+            ax[0].set_title(f"raw image {title}")
+            ax[1].imshow(image_max)
+            ax[1].set_title("Filtered + peak detection")
+            ax[1].plot(peaks[:, 1], peaks[:, 0], "r.")
+            plt.show()
+        else:
+            fig, ax = plt.subplots(1, 3, sharey=True)
+
+            ax[0].imshow(bf_crop, cmap="gray")
+            ax[0].set_title(f"BF {title}")
+
+            ax[1].imshow(crop2d, vmax=crop2d.mean() + 2 * crop2d.std())
+            ax[1].set_title(f"raw image {title}")
+
+            ax[2].imshow(image_max)
+            ax[2].set_title(
+                f"Filtered + {len(peaks)} peaks (std {image_max.std():.2f})"
+            )
+            ax[2].plot(peaks[:, 1], peaks[:, 0], "r.")
+            plt.show()
+
+    if return_std:
+        return len(peaks), crop2d.std()
+    else:
+        return (len(peaks),)
+
+
+def get_peaks_per_frame(stack3d, dif_gauss_sigma=(1, 3), **kwargs):
+    image_ref = gdif(stack3d[0], dif_gauss_sigma)
+    thr = 5 * image_ref.std()
+    return list(
+        map(partial(get_peak_number, threshold_abs=thr, **kwargs), stack3d)
+    )
+
+
+def get_peaks_all_wells(stack, centers, size, plot=0):
+    n_peaks = []
+    for c in centers:
+        print(".", end="")
+        well = crop(stack, c["center"], size)
+        n_peaks.append(get_peaks_per_frame(well, plot=plot))
+    return n_peaks
+
+
+def get_n_cells(peaks: list, n_frames=6):
+    return np.round(np.mean(peaks[:n_frames]), 0).astype(int)
diff --git a/merge_tables.py b/merge_tables.py
index 86ccdb33d758494d6aa28d7334965c3329803391..af3b17e75359264811df519b219203f67db00c8d 100644
--- a/merge_tables.py
+++ b/merge_tables.py
@@ -4,41 +4,86 @@ import seaborn as sns
 import matplotlib.pyplot as plt
 
 
-def combine(table_day1_path, table_day2_path, table_output_path, swarmplot_output_path, prob_plot_path, threshold:int=20):
+def combine(
+    table_day1_path,
+    table_day2_path,
+    table_output_path,
+    swarmplot_output_path,
+    prob_plot_path,
+    threshold: int = 20,
+):
     day1, day2 = [pd.read_csv(t) for t in [table_day1_path, table_day2_path]]
-    day1.loc[:,'n_cells_final'] = day2.n_cells
-    day1.loc[:,'intensity_final'] = day2.intensity
+    day1.loc[:, "n_cells_final"] = day2.n_cells
+    day1.loc[:, "intensity_final"] = day2.intensity
 
-    day1.loc[:,'final_state'] = day1.n_cells_final > threshold
+    day1.loc[:, "final_state"] = day1.n_cells_final > threshold
 
     day1.to_csv(table_output_path, index=None)
     n_max = int(day1.n_cells.mean()) * 3 + 1
     plot_swarm_counts(day1, n_max=n_max, path=swarmplot_output_path)
-    plot_swarm_intensity(day1, n_max=n_max, path=swarmplot_output_path.replace('.png', '_intensities.png'))
-    plot_probs(day1, n_max=n_max,path=prob_plot_path)
-    plot_probs_log(day1, n_max=n_max,path=prob_plot_path.replace('.png', '_log.png'))
-    
-def plot_swarm_counts(table:pd.DataFrame, n_max:int=5, path=None):
+    plot_swarm_intensity(
+        day1,
+        n_max=n_max,
+        path=swarmplot_output_path.replace(".png", "_intensities.png"),
+    )
+    plot_probs(day1, n_max=n_max, path=prob_plot_path)
+    plot_probs_log(
+        day1, n_max=n_max, path=prob_plot_path.replace(".png", "_log.png")
+    )
+
+
+def plot_swarm_counts(table: pd.DataFrame, n_max: int = 5, path=None):
     fig, ax = plt.subplots(dpi=300)
-    sns.swarmplot(ax=ax, data=table.query(f'n_cells <= {n_max}'), x='[AB]', y='n_cells_final', hue='n_cells', dodge=True, size=1 )
+    sns.swarmplot(
+        ax=ax,
+        data=table.query(f"n_cells <= {n_max}"),
+        x="[AB]",
+        y="n_cells_final",
+        hue="n_cells",
+        dodge=True,
+        size=1,
+    )
     fig.savefig(path)
 
-def plot_swarm_intensity(table:pd.DataFrame, n_max:int=5, path=None):
+
+def plot_swarm_intensity(table: pd.DataFrame, n_max: int = 5, path=None):
     fig, ax = plt.subplots(dpi=300)
-    sns.swarmplot(ax=ax, data=table.query(f'n_cells <= {n_max}'), x='[AB]', y='intensity_final', hue='n_cells', dodge=True, size=1 )
+    sns.swarmplot(
+        ax=ax,
+        data=table.query(f"n_cells <= {n_max}"),
+        x="[AB]",
+        y="intensity_final",
+        hue="n_cells",
+        dodge=True,
+        size=1,
+    )
     fig.savefig(path)
 
-def plot_probs(table:pd.DataFrame, n_max:int=5, path=None):
-    
+
+def plot_probs(table: pd.DataFrame, n_max: int = 5, path=None):
     fig, ax = plt.subplots(dpi=300)
-    sns.lineplot(ax=ax, data=table.query(f'n_cells <= {n_max}'), x='[AB]', y='final_state', hue='n_cells')
+    sns.lineplot(
+        ax=ax,
+        data=table.query(f"n_cells <= {n_max}"),
+        x="[AB]",
+        y="final_state",
+        hue="n_cells",
+    )
     fig.savefig(path)
 
-def plot_probs_log(table:pd.DataFrame, n_max:int=5, path=None):
+
+def plot_probs_log(table: pd.DataFrame, n_max: int = 5, path=None):
     fig, ax = plt.subplots(dpi=300)
-    sns.lineplot(ax=ax, data=table.query(f'n_cells <= {n_max}'), x='[AB]', y='final_state', hue='n_cells')
-    ax.set_xscale('log')
+    sns.lineplot(
+        ax=ax,
+        data=table.query(f"n_cells <= {n_max}"),
+        x="[AB]",
+        y="final_state",
+        hue="n_cells",
+    )
+    ax.set_xscale("log")
     fig.savefig(path)
 
+
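+# Example CLI invocation via python-fire (hypothetical file names):
+#   python merge_tables.py day1.csv day2.csv merged.csv swarm.png \
+#       probs.png --threshold=20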
 if __name__ == "__main__":
-    fire.Fire(combine)
\ No newline at end of file
+    fire.Fire(combine)
diff --git a/poisson.py b/poisson.py
new file mode 100644
index 0000000000000000000000000000000000000000..943ddfecd0447524c310f241766a3bc6d6475934
--- /dev/null
+++ b/poisson.py
@@ -0,0 +1,41 @@
+from scipy.optimize import curve_fit
+from scipy.stats import poisson
+import matplotlib.pyplot as plt
+import numpy as np
+
+
+def fit(
+    numbers: np.ndarray,
+    max_value=None,
+    xlabel="Initial number of cells",
+    title="",
+    plot=True,
+    save_fig_path=None,
+):
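+    """Fit a Poisson pmf to the histogram of `numbers`; return the rate λ.
+
+    A minimal sketch with synthetic counts (assumes nothing beyond this
+    module's imports)::
+
+        lam = fit(np.random.poisson(1.3, size=500), plot=False)
+    """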
+    if max_value is None:
+        max_value = numbers.max()
+    bins = np.arange(max_value + 1) - 0.5
+    vector = bins[:-1] + 0.5
+    hist, bins = np.histogram(numbers, bins=bins, density=True)
+    popt, pcov = curve_fit(poisson.pmf, vector, hist, p0=(1.0,))
+    lambda_fit_result = popt[0]
+    if plot:
+        plt.hist(numbers, bins=bins, fill=None)
+        plt.plot(
+            vector,
+            len(numbers) * poisson.pmf(vector, lambda_fit_result),
+            ".-",
+            label=f"Poisson fit λ={lambda_fit_result:.1f}",
+            color="tab:red",
+        )
+        plt.xlabel(xlabel)
+        plt.title(title)
+        plt.legend()
+        if save_fig_path is not None:
+            try:
+                plt.savefig(save_fig_path)
+                print(f"Save histogram {save_fig_path}")
+            except Exception as e:
+                print("saving histogram failed", e.args)
+        plt.show()
+    return lambda_fit_result
diff --git a/register.py b/register.py
new file mode 100644
index 0000000000000000000000000000000000000000..a6b9fac1a6cf55fd156321ace35093323b4acd34
--- /dev/null
+++ b/register.py
@@ -0,0 +1,208 @@
+from tifffile import imread
+import matplotlib.pyplot as plt
+import imreg_dft as reg
+import numpy as np
+
+
+def align_stack(
+    data_or_path,
+    template16,
+    mask2,
+    plot=False,
+    binnings=(1, 16, 2),
+):
+    """
+    stack should contain two channels: bright field and fluorescence.
+    BF will be binned 8 times and registered with template8 (aligned BF).
+    When the transformation verctor will be applied to the original data and
+    stacked with the mask.
+    The output stack is of the same size as mask.
+
+    :return: aligned_array[bf, fluo, mask], tvec
+    :rtype: np.ndarray, dict
+
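+    A minimal usage sketch (hypothetical arrays; `bf`/`tritc` at the stack
+    binning, `labels` at the mask binning)::
+
+        aligned, tvec = align_stack(
+            np.stack([bf, tritc]), template16=template16, mask2=labels
+        )
+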
+    """
+    if isinstance(data_or_path, str):
+        path = data_or_path
+        stack = imread(path)
+        print(path, stack.shape)
+    else:
+        assert data_or_path.ndim == 3 and data_or_path.shape[0] == 2
+        stack = data_or_path
+
+    bf, tritc = stack[:2]
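+    # binnings = (stack, template, mask): derive the relative scale factors
+    # between the three resolutions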
+    stack_temp_scale = binnings[1] // binnings[0]
+    mask_temp_scale = binnings[1] // binnings[2]
+    stack_mask_scale = binnings[2] // binnings[0]
+
+    f_bf = bf[::stack_temp_scale, ::stack_temp_scale]
+
+    tvec8 = get_transform(f_bf, template16, plot=plot)
+    plt.show()
+    tvec = scale_tvec(tvec8, mask_temp_scale)
+    print(tvec)
+    try:
+        aligned_tritc = unpad(
+            transform(tritc[::stack_mask_scale, ::stack_mask_scale], tvec),
+            mask2.shape,
+        )
+        aligned_bf = unpad(
+            transform(bf[::stack_mask_scale, ::stack_mask_scale], tvec),
+            mask2.shape,
+        )
+    except ValueError as e:
+        print("stack_mask_scale: ", stack_mask_scale)
+        print(e.args)
+        raise e
+
+    aligned_stack = np.stack((aligned_bf, aligned_tritc, mask2)).astype(
+        "uint16"
+    )
+
+    return aligned_stack, tvec
+
+
+def get_transform(
+    image, template, plot=True, pad_ratio=1.2, figsize=(10, 5), dpi=300
+):
+    """
+    Pads image and template, registers and returns tvec
+    """
+    padded_template = pad(template, (s := increase(image.shape, pad_ratio)))
+    padded_image = pad(image, s)
+    tvec = register(padded_image, padded_template)
+    if plot:
+        aligned_bf = unpad(tvec["timg"], template.shape)
+        plt.figure(figsize=figsize, dpi=dpi)
+        plt.imshow(aligned_bf, cmap="gray")
+    return tvec
+
+
+def register(image, template):
+    """
+    Register image towards template
+    Return:
+    tvec:dict
+    """
+    assert np.array_equal(
+        image.shape, template.shape
+    ), f"unequal shapes {(image.shape, template.shape)}"
+    return reg.similarity(
+        template,
+        image,
+        constraints={
+            "scale": [1, 0.2],
+            "tx": [0, 50],
+            "ty": [0, 50],
+            "angle": [0, 30],
+        },
+    )
+
+
+def filter_by_fft(
+    image,
+    sigma=40,
+    fix_horizontal_stripes=False,
+    fix_vertical_stripes=False,
+    highpass=True,
+):
+    size = get_fft_size(image.shape)
+    fft = np.fft.fft2(pad(image, size))
+    if fix_vertical_stripes:
+        fft[0, 1:] = 0  # vertical
+    if fix_horizontal_stripes:
+        fft[1:, 0] = 0  # horizontal
+    fft = fft * (fft_mask(size, sigma=size[0] / sigma, highpass=highpass))
+    filtered_image = unpad(abs(np.fft.ifft2(fft)), image.shape)
+    return filtered_image
+
+
+def fft_mask(shape, sigma=10, highpass=False):
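+    """Centered 2D Gaussian mask for Fourier-domain filtering.
+
+    With ``highpass=True`` the mask is inverted (1 - Gaussian), suppressing
+    the low frequencies near the center instead of keeping them.
+    """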
+    y, x = np.indices(shape)
+
+    gaus = np.exp(-((x - shape[1] // 2) ** 2) / sigma**2 / 2) * np.exp(
+        -((y - shape[0] // 2) ** 2) / sigma**2 / 2
+    )
+    gaus = gaus / gaus.max()
+    if highpass:
+        return 1 - gaus
+    return gaus
+
+
+def pad(image: np.ndarray, to_shape: tuple = None, padding: tuple = None):
+    if padding is None:
+        padding = calculate_padding(image.shape, to_shape)
+    try:
+        padded = np.pad(image, padding, "edge")
+    except TypeError as e:
+        print(e.args, padding)
+        raise e
+    return padded
+
+
+def unpad(image: np.ndarray, to_shape: tuple = None, padding: tuple = None):
+    if to_shape is not None and any(
+        np.array(image.shape) - np.array(to_shape) < 0
+    ):
+        print(
+            f"""unpad:warning: image.shape {image.shape} \
+            is within to_shape {to_shape}"""
+        )
+        image = pad(image, np.array((image.shape, to_shape)).max(axis=0))
+        print(f"new image shape after padding {image.shape}")
+    if padding is None:
+        padding = calculate_padding(to_shape, image.shape)
+
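+    # a pad of 0 would make the slice stop -0 (an empty slice), so replace
+    # it with None to slice through to the array's end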
+    y = [padding[0][0], -padding[0][1]]
+    if y[1] == 0:
+        y[1] = None
+    x = [padding[1][0], -padding[1][1]]
+    if x[1] == 0:
+        x[1] = None
+    return image[y[0]: y[1], x[0]: x[1]]
+
+
+def get_fft_size(shape):
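+    """Smallest power-of-two square (>= 32x32) covering 1.5x the largest dim."""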
+    max_size = np.max(shape)
+    n = 5
+    while (2**n) < (max_size * 1.5):
+        n += 1
+    return (2**n, 2**n)
+
+
+def calculate_padding(shape1: tuple, shape2: tuple):
+    """
+    Calculates padding to get shape2 from shape1
+    Return:
+    2D tuple of indices
+    """
+    dif = np.array(shape2) - np.array(shape1)
+    assert all(
+        dif >= 0
+    ), f"Shape2 must be bigger than shape1, got {shape2}, {shape1}"
+    mid = dif // 2
+    rest = dif - mid
+    y = mid[0], rest[0]
+    x = mid[1], rest[1]
+    return y, x
+
+
+def scale_tvec(tvec, scale=8):
+    tvec_8x = tvec.copy()
+    tvec_8x["tvec"] = tvec["tvec"] * scale
+    # drop the transformed image: it is only valid at the registration scale
+    tvec_8x["timg"] = None
+    return tvec_8x
+
+
+def transform(image, tvec):
+    print(f"transform {image.shape}")
+    fluo = reg.transform_img_dict(image, tvec)
+    return fluo.astype("uint")
+
+
+def increase(shape, increase_ratio):
+    assert increase_ratio > 1
+    shape = np.array(shape)
+    return tuple((shape * increase_ratio).astype(int))
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..93e4409c4734c3ff85686a36238a15cab6c05516
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,85 @@
+appdirs==1.4.4
+asciitree==0.3.3
+attrs==23.1.0
+certifi==2023.7.22
+charset-normalizer==3.2.0
+click==8.1.7
+cloudpickle==2.2.1
+ConfigArgParse==1.7
+connection-pool==0.0.3
+contourpy==1.1.1
+cycler==0.11.0
+dask==2023.5.0
+datrie==0.8.2
+docutils==0.20.1
+dpath==2.1.6
+entrypoints==0.4
+fasteners==0.18
+fastjsonschema==2.18.0
+fire==0.5.0
+fonttools==4.42.1
+fsspec==2023.9.1
+gitdb==4.0.10
+GitPython==3.1.36
+humanfriendly==10.0
+idna==3.4
+imageio==2.31.3
+importlib-metadata==6.8.0
+importlib-resources==6.0.1
+imreg-dft @ git+https://github.com/matejak/imreg_dft.git@4023e9e2c4110b04017c85e7bda9883a71b2eea0
+Jinja2==3.1.2
+jsonschema==4.19.0
+jsonschema-specifications==2023.7.1
+jupyter_core==5.3.1
+kiwisolver==1.4.5
+lazy_loader==0.3
+locket==1.0.0
+MarkupSafe==2.1.3
+matplotlib==3.7.3
+nbformat==5.9.2
+nd2==0.7.2
+networkx==3.1
+numcodecs==0.11.0
+numpy==1.24.4
+packaging==23.1
+pandas==2.0.3
+partd==1.4.0
+Pillow==10.0.1
+pkgutil_resolve_name==1.3.10
+plac==1.3.5
+platformdirs==3.10.0
+psutil==5.9.5
+PuLP==2.7.0
+pyparsing==3.1.1
+python-dateutil==2.8.2
+pytz==2023.3.post1
+PyWavelets==1.4.1
+PyYAML==6.0.1
+referencing==0.30.2
+requests==2.31.0
+reretry==0.11.8
+resource-backed-dask-array==0.1.0
+rpds-py==0.10.3
+scikit-image==0.21.0
+scipy==1.10.1
+seaborn==0.12.2
+six==1.16.0
+smart-open==6.4.0
+smmap==5.0.1
+snakemake==7.32.4
+stopit==1.1.2
+tabulate==0.9.0
+termcolor==2.3.0
+throttler==1.2.2
+tifffile==2023.7.10
+toolz==0.12.0
+toposort==1.10
+traitlets==5.10.0
+typing_extensions==4.8.0
+tzdata==2023.3
+urllib3==2.0.4
+wrapt==1.15.0
+yte==1.5.1
+zarr==2.16.1
+zarr-tools==0.4.5
+zipp==3.17.0