Skip to content
Snippets Groups Projects
Select Git revision
  • 512c3824789c4530a64775538806163a84d9c489
  • main default
  • handle-single-chip
  • master
  • v0.2.0
  • v0.1.4
  • v0.1.3
  • v0.1.2
  • 0.1.1
  • v0.0.1
10 results

merge_tables.py

Blame
  • merge_tables.py 1.87 KiB
    import pandas as pd
    import fire
    import seaborn as sns
    import matplotlib.pyplot as plt
    
    
    def suffixed_path(path, tag):
        """Return *path* with *tag* inserted between its stem and its extension.

        ``suffixed_path('out/plot.png', '_log')`` gives ``'out/plot_log.png'``.
        Only the final extension is considered. Two things differ from
        ``path.replace('.png', tag + '.png')``:

        - a ``.png`` inside a directory name is left alone;
        - a non-``.png`` path (e.g. ``.svg``) still gets a distinct name
          instead of being returned unchanged and later overwritten.
        """
        import os

        root, ext = os.path.splitext(str(path))
        return f"{root}{tag}{ext}"


    def combine(table_day1_path, table_day2_path, table_output_path, swarmplot_output_path, prob_plot_path, threshold: int = 20):
        """Merge a day-1 and a day-2 table, then write the merged CSV and plots.

        Args:
            table_day1_path: CSV with at least ``n_cells`` (and the ``[AB]``
                column used by the plots).
            table_day2_path: CSV with ``n_cells`` and ``intensity`` columns.
            table_output_path: where the merged table is written (no index).
            swarmplot_output_path: swarm plot of final counts. The intensity
                swarm plot goes next to it with ``_intensities`` added to the
                file stem.
            prob_plot_path: line plot of ``final_state`` vs ``[AB]``. A
                log-x variant goes next to it with ``_log`` added to the stem.
            threshold: ``final_state`` is True where the day-2 cell count is
                strictly greater than this value.
        """
        day1, day2 = [pd.read_csv(t) for t in [table_day1_path, table_day2_path]]

        # Assignment aligns on the DataFrame index: day-1 rows without a matching
        # index in day2 get NaN.
        # NOTE(review): assumes both CSVs list the same rows in the same order
        # — confirm.
        day1.loc[:, 'n_cells_final'] = day2.n_cells
        day1.loc[:, 'intensity_final'] = day2.intensity

        # NaN compares False, so unmatched rows end up with final_state == False.
        day1.loc[:, 'final_state'] = day1.n_cells_final > threshold

        day1.to_csv(table_output_path, index=False)

        # The plots only show rows whose initial count is below 3x the mean.
        n_max = int(day1.n_cells.mean() * 3)
        plot_swarm_counts(day1, n_max=n_max, path=swarmplot_output_path)
        plot_swarm_intensity(day1, n_max=n_max, path=suffixed_path(swarmplot_output_path, '_intensities'))
        plot_probs(day1, n_max=n_max, path=prob_plot_path)
        plot_probs_log(day1, n_max=n_max, path=suffixed_path(prob_plot_path, '_log'))
        
    def plot_swarm_counts(table:pd.DataFrame, n_max:int=5, path=None):
        """Swarm-plot ``n_cells_final`` against ``[AB]``, coloured by ``n_cells``.

        Only rows with ``n_cells < n_max`` are drawn. The figure is saved to
        *path* at 300 dpi and then closed.
        """
        fig, ax = plt.subplots(dpi=300)
        try:
            sns.swarmplot(
                ax=ax,
                data=table.query(f'n_cells < {n_max}'),
                x='[AB]',
                y='n_cells_final',
                hue='n_cells',
                dodge=True,
                size=1,
            )
            fig.savefig(path)
        finally:
            # pyplot keeps every figure alive until it is closed. Without this,
            # each call leaks a figure (matplotlib warns past 20 open figures).
            plt.close(fig)
    
    def plot_swarm_intensity(table:pd.DataFrame, n_max:int=5, path=None):
        """Swarm-plot ``intensity_final`` against ``[AB]``, coloured by ``n_cells``.

        Only rows with ``n_cells < n_max`` are drawn. The figure is saved to
        *path* at 300 dpi and then closed.
        """
        fig, ax = plt.subplots(dpi=300)
        try:
            sns.swarmplot(
                ax=ax,
                data=table.query(f'n_cells < {n_max}'),
                x='[AB]',
                y='intensity_final',
                hue='n_cells',
                dodge=True,
                size=1,
            )
            fig.savefig(path)
        finally:
            # Release the figure from pyplot's registry even if saving fails.
            plt.close(fig)
    
    def plot_probs(table:pd.DataFrame, n_max:int=5, path=None):
        """Line-plot ``final_state`` against ``[AB]``, one line per ``n_cells`` value.

        ``final_state`` is boolean. seaborn aggregates repeated x values (mean
        by default), so each line shows the fraction of rows with
        ``final_state`` True. Only rows with ``n_cells < n_max`` are used. The
        figure is saved to *path* at 300 dpi and then closed.
        """
        fig, ax = plt.subplots(dpi=300)
        try:
            sns.lineplot(
                ax=ax,
                data=table.query(f'n_cells < {n_max}'),
                x='[AB]',
                y='final_state',
                hue='n_cells',
            )
            fig.savefig(path)
        finally:
            # Release the figure from pyplot's registry even if saving fails.
            plt.close(fig)
    
    def plot_probs_log(table:pd.DataFrame, n_max:int=5, path=None):
        """Same line plot of ``final_state`` vs ``[AB]``, with a log-scaled x axis.

        Only rows with ``n_cells < n_max`` are used. The figure is saved to
        *path* at 300 dpi and then closed.
        """
        fig, ax = plt.subplots(dpi=300)
        try:
            sns.lineplot(
                ax=ax,
                data=table.query(f'n_cells < {n_max}'),
                x='[AB]',
                y='final_state',
                hue='n_cells',
            )
            ax.set_xscale('log')
            fig.savefig(path)
        finally:
            # Release the figure from pyplot's registry even if saving fails.
            plt.close(fig)
    
    if __name__ == "__main__":
        # Command-line entry point: python-fire maps CLI arguments onto combine().
        fire.Fire(component=combine)