Skip to content
Snippets Groups Projects
Select Git revision
  • 453a88812060f6e0a416785093a005bc369a771f
  • main default
  • handle-single-chip
  • master
  • v0.2.0
  • v0.1.4
  • v0.1.3
  • v0.1.2
  • 0.1.1
  • v0.0.1
10 results

align.py

Blame
  • align.py 5.53 KiB
    from typing import List
    import numpy as np
    from droplet_growth import mic, register, poisson
    import tifffile as tf
    import os
    import dask.array as da
    from functools import partial
    from zarr_tools import convert
    import yaml
    import fire
    from multiprocessing import Pool
    import nd2
    import pandas as pd
    from skimage.measure import regionprops_table
    from scipy.ndimage import laplace, rotate
    import json
    
    
    def align_multichip(BF_TRITC_2D_path, out_path, concentrations_path, template_path, labels_path, table_path, fit_poisson, nmax=10):
        """Align every chip of a multi-chip acquisition and export the results.

        Reads per-chip concentrations from a YAML file, aligns each chip of
        ``BF_TRITC_2D_path`` with ``align2D`` (in a process pool when possible)
        and writes:

        * the concatenated counts table as CSV to ``table_path``;
        * the registration transforms to ``out_path`` with ``.zarr`` replaced by
          ``.transform.json`` (best effort: failures are printed, not raised);
        * the aligned stacks as a zarr store at ``out_path``.

        Args:
            BF_TRITC_2D_path: ``.zarr`` or ``.nd2`` input (see ``read_dask``);
                its first axis is iterated in lockstep with the concentrations.
            out_path: destination zarr path.
            concentrations_path: YAML with ``concentrations`` (integers, they
                are formatted with ``{c:02d}``), ``units`` and an optional
                ``rotation_template_deg``.
            template_path: TIFF passed to registration as ``template16``.
            labels_path: TIFF passed to registration as the label mask.
            table_path: destination CSV for the counts table.
            fit_poisson: forwarded to ``align2D``.
            nmax: forwarded to ``align2D`` (upper bound for the Poisson fit).

        Returns:
            ``out_path``.
        """
        with open(concentrations_path, 'r') as f:
            concentrations_dct = yaml.safe_load(f)
        concentrations = concentrations_dct['concentrations']
        if "rotation_template_deg" in concentrations_dct:
            rotate_template_deg = int(concentrations_dct["rotation_template_deg"])
            print(f'Rotation: {rotate_template_deg}')
        else:
            rotate_template_deg = 0
        print('concentrations: ', concentrations)
        unit = concentrations_dct['units']
        tif_paths = [f'{os.path.dirname(BF_TRITC_2D_path)}/{c:02d}{unit}_aligned.tif' for c in concentrations]
        print('tif_paths: ', tif_paths)

        data = read_dask(BF_TRITC_2D_path)
        template16 = tf.imread(template_path)
        if rotate_template_deg != 0:
            # Rotate in the image plane (last two axes) only. scipy's default
            # axes=(1, 0) would rotate across the leading chip/channel axes of
            # the stack. For a 2D input the two choices are identical.
            data = da.from_array(rotate(data, rotate_template_deg, axes=(-2, -1)))
            print(f'Rotated data {rotate_template_deg} deg')
        big_labels = tf.imread(labels_path)

        fun = partial(
            align_parallel,
            template16=template16,
            big_labels=big_labels,
            unit=unit,
            fit_poisson=fit_poisson,
            nmax=nmax,
        )
        out = _map_chips(fun, data, tif_paths, concentrations)

        aligned = [o['stack'] for o in out]
        counts = [o['counts'] for o in out]
        # Chips reloaded from an existing *_aligned.tif may carry no transform.
        tvecs = [o.get('tvec') for o in out]
        _save_transforms(tvecs, out_path.replace('.zarr', '.transform.json'))

        df = pd.concat(counts, ignore_index=True).sort_values(['[AB]', 'label'])
        df.to_csv(table_path)

        daligned = da.from_array(np.array(aligned))
        convert.to_zarr(
            daligned,
            path=out_path,
            steps=4,
            name=['BF', 'TRITC', 'mask'],
            colormap=['gray', 'green', 'blue'],
            lut=((1000, 30000), (440, 600), (0, 501)),
        )
        return out_path


    def _map_chips(fun, data, tif_paths, concentrations):
        """Apply ``fun`` to each (chip, tif_path, concentration) triple, in parallel if possible."""
        # zip() stops at the shortest input, so more workers than items is waste.
        n_workers = max(1, min(data.shape[0], len(tif_paths)))
        try:
            # The context manager tears the pool down on every exit path. The
            # old try/finally referenced `p` even when Pool() itself failed.
            with Pool(n_workers) as p:
                return p.map(fun, zip(data, tif_paths, concentrations))
        except TypeError as e:
            print(f'Pool failed due to {e.args}')
            return list(map(fun, zip(data, tif_paths, concentrations)))


    def _save_transforms(transforms, json_path):
        """Best-effort dump of per-chip registration transforms to ``json_path``.

        ``None`` entries (no transform available) are written as ``null``. Any
        failure is printed and swallowed so it never aborts the pipeline.
        """
        def to_builtin(obj):
            # json cannot encode numpy arrays/scalars (e.g. a "tvec" array);
            # convert them to plain Python lists/numbers.
            if isinstance(obj, (np.ndarray, np.generic)):
                return obj.tolist()
            raise TypeError(f'Object of type {type(obj).__name__} is not JSON serializable')

        try:
            with open(json_path, 'w') as f:
                json.dump(list(transforms), fp=f, default=to_builtin)
        except Exception as e:
            print('saving transform json failed: ', e.args)
    
     
    def align_parallel(args, **kwargs):
        """Adapter for ``Pool.map``: unpack one argument tuple into ``align2D``.

        ``Pool.map`` hands a single item to each call. In this module that item
        is the ``(stack, tif_path, concentration)`` tuple zipped together by
        ``align_multichip``. ``kwargs`` carries the settings bound with
        ``functools.partial``.
        """
        positional = tuple(args)
        return align2D(*positional, **kwargs)
        
    def align2D(stack_dask, path_tif, ab, template16=None, big_labels=None, unit='μg_mL', fit_poisson=True, nmax=20):
        """Align one chip (or reuse a cached alignment) and count cells per droplet.

        If ``path_tif`` can be read, it is used as the aligned stack and
        registration is skipped. Otherwise ``stack_dask`` is computed and
        registered with ``register.align_stack``. The result is not written to
        ``path_tif`` here.

        Args:
            stack_dask: lazy per-chip stack; ``.compute()`` is called on it.
            path_tif: cached aligned TIFF; also the stem of the Poisson
                histogram figure (``-counts-hist.png``).
            ab: concentration value written into the ``[AB]`` column.
            template16: template passed to ``register.align_stack``.
            big_labels: label mask passed to ``register.align_stack`` (``mask2``).
            unit: concentration unit, used only in the plot title.
            fit_poisson: if true, add a ``poisson fit`` column. This now applies
                to cached stacks too, so the column is present for every chip.
            nmax: only rows with ``n_cells < nmax`` enter the fit.

        Returns:
            dict with:

            * ``stack``: the aligned array; channel 0 is passed as BF,
              channel 1 as intensity and channel 2 as labels below.
            * ``counts``: a DataFrame.
            * ``tvec``: the registration transform, or ``None`` for a cached
              stack.
        """
        print(ab, unit)
        try:
            aligned = tf.imread(path_tif)
        except FileNotFoundError:
            aligned = None

        if aligned is not None:
            print(f'already aligned: {path_tif}:{aligned.shape}')
            tvec = None  # no registration ran, but callers expect the key
        else:
            print('Processing...')
            # Registration runs outside the except clause so its errors are not
            # reported as "during handling of FileNotFoundError".
            aligned, tvec = register.align_stack(
                stack_dask.compute(),
                path_to_save=None,
                template16=template16,
                mask2=big_labels,
                binnings=(2, 16, 2),
            )

        counts = count(aligned, ab=ab)
        intensity_table = get_intensity_table(aligned[2], aligned[1])
        counts.loc[:, 'intensity'] = intensity_table.intensity

        if fit_poisson:
            poisson_lambda = poisson.fit(
                counts.query(f'n_cells < {nmax}').n_cells,
                title=f"automatic {ab}{unit.replace('_','/')}",
                save_fig_path=path_tif.replace('.tif', '-counts-hist.png'),
            )
            counts.loc[:, 'poisson fit'] = poisson_lambda

        return {"stack": aligned, "counts": counts, 'tvec': tvec}
    
    def count(aligned: np.ndarray, ab=None):
        """Count cells per droplet in one aligned stack and tag the rows with ``ab``.

        Channels 1 and 2 of ``aligned`` are passed positionally to
        ``mic.get_cell_numbers`` (channel 0 goes in as ``bf``). The returned
        table gets an ``[AB]`` column holding ``ab``.
        """
        first, second, bright_field = aligned[1], aligned[2], aligned[0]
        table = mic.get_cell_numbers(first, second, threshold_abs=2, plot=False, bf=bright_field)
        table.loc[:, '[AB]'] = ab
        return table
    
    def get_intensity_table(
        labelled_mask: np.ndarray,
        intensity_image: np.ndarray,
        values=('mean_intensity',),
        estimator=np.mean,
        plot: bool = True,
    ):
        """Background-corrected mean intensity for every labelled region.

        For each label, the mean of ``intensity_image`` inside the region is
        reduced by the mean over that region's outline pixels (``get_outlines``):
        ``intensity = mean_intensity - bg_mean``.

        Args:
            labelled_mask: label image passed to ``regionprops_table``.
            intensity_image: 2D image of the same shape (cast to float32).
            values: ``regionprops`` properties to report. ``mean_intensity`` is
                always added because ``intensity`` is derived from it.
            estimator: unused; kept for API compatibility.
            plot: unused; kept for API compatibility.

        Returns:
            DataFrame with one row per label. Columns are ``label``, the
            requested properties, ``bg_mean`` and ``intensity``. ``bg_mean``
            and ``intensity`` are NaN for a label whose outline is empty.

        Raises:
            ValueError: if ``intensity_image`` is not 2D.
        """
        # Explicit raise instead of assert, so the check survives `python -O`.
        if intensity_image.ndim != 2:
            raise ValueError(
                f'expected 2D array for intensity, got shape {intensity_image.shape}'
            )
        data = intensity_image.astype('f')

        properties = ['label', *values]
        if 'mean_intensity' not in properties:
            properties.append('mean_intensity')
        df = pd.DataFrame(regionprops_table(
            labelled_mask,
            intensity_image=data,
            properties=properties,
        ))
        bg = pd.DataFrame(regionprops_table(
            get_outlines(labelled_mask),
            intensity_image=data,
            properties=['label', 'mean_intensity'],
        )).rename(columns={'mean_intensity': 'bg_mean'})

        # Join on label, not array position: a region can end up with no
        # outline pixels (the Laplacian can cancel between neighbours). Its
        # background would then be missing and every later row would shift.
        df = df.merge(bg, on='label', how='left')
        df.loc[:, "intensity"] = df.mean_intensity - df.bg_mean
        return df
    
    def get_outlines(labels):
        """Return ``labels`` restricted to a thin outline of each region.

        A pixel keeps its label value where the discrete Laplacian of the label
        image is non-zero. Those are typically pixels next to a change of label
        value. Every other pixel becomes 0.
        """
        edge_response = laplace(labels)
        on_edge = np.abs(edge_response) > 0
        return labels * on_edge
    
    
    def read_dask(path: str):
        """Open an image file lazily as a dask array.

        * ``.zarr``: reads the array stored at ``<path>/0/`` with
          ``da.from_zarr``.
        * ``.nd2``: uses ``nd2.ND2File(path).to_dask()``.

        ``path`` may be a ``str`` or any path-like object. A trailing ``/`` and
        upper-case extensions are accepted.

        Raises:
            ValueError: for any other extension; the message names the path.
        """
        path = os.fspath(path).rstrip('/')
        lowered = path.lower()
        if lowered.endswith('.zarr'):
            data = da.from_zarr(path + '/0/')
        elif lowered.endswith('.nd2'):
            data = nd2.ND2File(path).to_dask()
        else:
            raise ValueError(f'Unexpected file format {path!r}, expected zarr or nd2')
        print('data:', data)
        return data
       
    if __name__ == "__main__":
        # Command-line entry point: python-fire turns align_multichip's
        # parameters into CLI arguments.
        fire.Fire(component=align_multichip)