#########################################################################
# ChIPflow: Standardized and reproducible ChIP-seq analysis from raw    #
#           data to differential analysis                               #
# Authors: Rachel Legendre, Maelle Daunesse                             #
# Copyright (c) 2019-2020  Institut Pasteur (Paris) and CNRS.           #
#                                                                       #
# This file is part of ChIPflow workflow.                               #
#                                                                       #
# ChIPflow is free software: you can redistribute it and/or modify      #
# it under the terms of the GNU General Public License as published by  #
# the Free Software Foundation, either version 3 of the License, or     #
# (at your option) any later version.                                   #
#                                                                       #
# ChIPflow is distributed in the hope that it will be useful,           #
# but WITHOUT ANY WARRANTY; without even the implied warranty of        #
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the          #
# GNU General Public License for more details.                          #
#                                                                       #
# You should have received a copy of the GNU General Public License     #
# along with ChIPflow (LICENSE).                                        #
# If not, see <https://www.gnu.org/licenses/>.                          #
#########################################################################



import pandas as pd
from fnmatch import fnmatch
from re import sub, match
from itertools import repeat, chain
import os


#-------------------------------------------------------
# read config files

configfile: "config/config.yaml"
# relative path to the rules directory
RULES = os.path.join("workflow", "rules")

design = pd.read_csv(config["design"]["design_file"], header=0, sep='\t')

"""
REQUIREMENTS IN DESIGN:
 - all files in one directory
 - fullname of files must be :
        MARK_COND_REP_MATE.fastq.gz
 - name on design must be:
        MARK_COND


"""

#-------------------------------------------------------
# list of all files in the directory 'input_dir'

filenames = [f for f in os.listdir(config["input_dir"]) if match(r'.*' + config["input_mate"] + config["input_extension"], f)]

if not filenames:
    raise ValueError("Please provide input fastq files")

#-------------------------------------------------------
# paired-end data handling

mate_pair = config["input_mate"]
rt1 = mate_pair.replace("[12]", "1")
rt2 = mate_pair.replace("[12]", "2")

R1 = [1 for this in filenames if rt1 in this]
R2 = [1 for this in filenames if rt2 in this]
if len(R2) == 0:
    paired = False
else:
    if R1 == R2:
        paired = True
    else:
        raise ValueError("Please provide single-end or paired-end files only")
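# Worked example (hypothetical): with input_mate "_R[12]", rt1 is "_R1" and
# rt2 is "_R2". If every "_R1" file has a "_R2" counterpart the dataset is
# treated as paired-end; if no "_R2" file exists it is single-end; any other
# mismatch between the two counts aborts the workflow.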

#get sample names
filename_R1 = [file for file in filenames if match(r'.*' + rt1 + config["input_extension"], file)]
samples = [sub(rt1 + config["input_extension"], '', file) for file in filename_R1]
marks = [ x.strip() for x in (config["design"]["marks"]).split(",")]
conds = [ x.strip() for x in (config["design"]["condition"]).split(",")]
rep_flag = config["design"]["replicates"]

# -----------------------------------------------
# get list of INPUT files

INPUT = []
for row in design.itertuples(index=True, name='Pandas'):
    for file in samples:
        input_name = getattr(row, "INPUT_NAME")
        if not pd.isna(input_name):
            if fnmatch(file, input_name + "*" + rep_flag + "1*"):
                # append the INPUT once per IP replicate it controls
                name = sub(mate_pair, '', file)
                INPUT.append(name)
                i = 1
                while i < getattr(row, "NB_IP"):
                    INPUT.append(name)
                    i += 1

# -----------------------------------------------
# get list of IP files

IP_ALL = []
for mark in design['IP_NAME']:
    for file in samples:
        if fnmatch(file, mark+"*") and file not in IP_ALL:
            name = sub(mate_pair, '', file)
            IP_ALL.append(name)
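# For the hypothetical design above, IP_ALL would be
# ["H3K4me3_WT_Rep1", "H3K4me3_WT_Rep2"] and INPUT would be
# ["INPUT_WT_Rep1", "INPUT_WT_Rep1"]: each INPUT is repeated once per IP
# replicate it controls, so the two lists can later be zipped pairwise.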


# -----------------------------------------------
# check design
# Select marks for statistical analysis: a mark is kept only if at least two of its conditions have a minimum of two replicates


MARK_OK = []
CORR_INPUT_OK = [] # corresponding INPUT condition
for mark in marks:
    nb = 0
    for row in design.itertuples(index=True, name='Pandas'):
        if mark == getattr(row, "IP_NAME").split("_")[0]:
            nb_rep = getattr(row, "NB_IP")
            if getattr(row, "IP_NAME").endswith(tuple(conds)) and nb_rep > 1:
                nb += 1
                # second qualifying condition found: keep the mark once
                if nb > 1:
                    MARK_OK.append(mark)
                    CORR_INPUT_OK.append(getattr(row, "INPUT_NAME").split("_")[0])
                    break

if config["differential_analysis"]["input_counting"]:
    CORR_INPUT = [] # corresponding INPUT files for input counting
    for mark in CORR_INPUT_OK:
        for file in samples:
            if fnmatch(file, mark+"_*") and file not in CORR_INPUT:
                name = sub(mate_pair, '', file)
                CORR_INPUT.append(name)
    
# -----------------------------------------------
# get REP names

###### TODO: better handling of automatic replicates when there are 3 replicates
###### TODO use https://snakemake.readthedocs.io/en/stable/project_info/faq.html#i-don-t-want-expand-to-use-every-wildcard-what-can-i-do 
#get list with replicates names for all IP

if nb_rep > 1:
    rep = [rep_flag + '1', rep_flag + '2']
    # deduce from rep the lists for SPR (self pseudo-replicates) & PPR (pooled pseudo-replicates)
    ppr = ['PPR1', 'PPR2', 'PPRPool']
    spr = ['SPR1.1', 'SPR1.2', 'SPR2.1', 'SPR2.2']
else:
    rep = [rep_flag + '1']
    

# -----------------------------------------------
# build the list of MARK_COND_REP from config.yaml and check the correspondence between design, config and fastq files

if len(conds) > 1:
    conf_cond = ["{mark}_{cond}_{flag}".format(cond=cond, mark=mark, flag=rep_flag) for cond in conds for mark in marks]
    for row in design.itertuples(index=True, name='Pandas'):
        for elem in conf_cond:
            if elem in getattr(row, "IP_NAME"):
                raise ValueError("Please check the correspondence between config and design file: %s conflicts with design name %s "
                                 % (elem, getattr(row, "IP_NAME")))
            elif not any(elem in s for s in samples):
                raise ValueError("Please check the correspondence between config file and fastq filenames: %s not found in %s"
                                 % (elem, samples))
            elif sum(getattr(row, "IP_NAME") + "_" + rep_flag in s for s in samples) != int(getattr(row, "NB_IP")):
                raise ValueError("Please check the correspondence between the number of replicates and/or prefix names in the design "
                                 "file and the fastq filenames: %s not found %s times in %s" % (getattr(row, "IP_NAME"),
                                                                                                getattr(row, "NB_IP"), samples))
            elif sum(getattr(row, "INPUT_NAME") + "_" + rep_flag in s for s in samples) != int(getattr(row, "NB_INPUT")):
                raise ValueError("Please check the correspondence between the number of replicates and/or prefix names in the design "
                                 "file and the fastq filenames: %s not found %s times in %s" % (getattr(row, "INPUT_NAME"), getattr(row, "NB_INPUT"), samples))


# -----------------------------------------------
# From the design file, classify IPs: with replicates and INPUT(s), with replicates but no INPUT, or with neither


IP_REP = [] # all IP passing IDR
IP_REP_DUP = [] # corresponding INPUT
IP_NO_INPUT = [] # all IP with replicates but without INPUT
INPUT_NA = [] # corresponding INPUT (list of NA according to number of IP)
IP_NA = [] # all IP without replicates and without INPUT
for row in design.itertuples(index=True, name='Pandas'):
    if getattr(row, "NB_IP") > 1 and getattr(row, "NB_INPUT") > 1:
        IP_REP_DUP.append(getattr(row, "INPUT_NAME")+"_Pool")
        IP_REP.append(getattr(row, "IP_NAME"))
    elif getattr(row, "NB_IP") > 1 and getattr(row, "NB_INPUT") == 1:
        IP_REP_DUP.append(getattr(row, "INPUT_NAME")+"_"+rep_flag+"1")
        IP_REP.append(getattr(row, "IP_NAME"))
    elif getattr(row, "NB_IP") > 1 and getattr(row, "NB_INPUT") < 1:
        IP_NO_INPUT.append(getattr(row, "IP_NAME"))
        INPUT_NA.append('NA')
    elif getattr(row, "NB_IP") == 1 and getattr(row, "NB_INPUT") < 1:
        IP_NA.append(getattr(row, "IP_NAME"))


# Keep only INPUTs (at least one replicate) that are linked to an IP with multiple replicates
INPUT_REP = []
for row in design.itertuples(index=True, name='Pandas'):
    if getattr(row, "NB_INPUT") >= 1 and getattr(row, "NB_IP") > 1 and (getattr(row, "INPUT_NAME")) not in INPUT_REP:
        INPUT_REP.append(getattr(row, "INPUT_NAME"))
# -----------------------------------------------
# get IP passed pre IDR step (for rep, ppr and spr)

IP_IDR = []
IP_SPR = []
IP_PPR = []
SPR_POOL = []
INPUT_SPR = []
INPUT_PPR = []
for cond in IP_REP:
    tmp = []
    for ip in IP_ALL:
        name = ip.split("_" + rep_flag)
        if ip.startswith(cond):
            spr_file = sub(rep_flag, 'SPR', ip)
            ppr_file = sub(rep_flag, 'PPR', ip)
            pool_file = sub(rep_flag, 'PPRPool', ip)
            tmp.append(ip)
            IP_SPR.append(spr_file + ".1")
            IP_SPR.append(spr_file + ".2")
            IP_PPR.append(ppr_file)
            SPR_POOL.append(pool_file)
            for row in design.itertuples(index=True, name='Pandas'):
                if getattr(row, "IP_NAME") in name[0]:
                    INPUT_PPR.append(getattr(row, "INPUT_NAME") + "_" + rep[0])
                    # one INPUT per SPR half
                    INPUT_SPR.append(getattr(row, "INPUT_NAME") + "_" + rep[0])
                    INPUT_SPR.append(getattr(row, "INPUT_NAME") + "_" + rep[0])
    IP_IDR.append(tmp)


# -----------------------------------------------
# get pooled IP passed pre IDR step and corresponding INPUT

PPR_POOL = []
INPUT_POOL = []
pass_pool = 0  # 0: a pooled INPUT is used; 1: the single INPUT replicate is reused
for row in design.itertuples(index=True, name='Pandas'):
    if getattr(row, "NB_IP") > 1 and getattr(row, "NB_INPUT") > 1:
        PPR_POOL.append(getattr(row, "IP_NAME") + "_PPRPool")
        INPUT_POOL.append(getattr(row, "INPUT_NAME") + "_Pool")
        pass_pool = 0
    elif getattr(row, "NB_IP") > 1 and getattr(row, "NB_INPUT") == 1:
        PPR_POOL.append(getattr(row, "IP_NAME") + "_PPRPool")
        INPUT_POOL.append(getattr(row, "INPUT_NAME") + "_" + rep_flag + "1")
        pass_pool = 1




# all files going through PhantomPeakQualTools if no-model mode is chosen
ALL = IP_ALL + IP_SPR + IP_PPR + PPR_POOL


# all files passing peak calling step
ALL_IP_PC = IP_ALL + IP_SPR + IP_PPR + PPR_POOL
ALL_INPUT_PC = INPUT + INPUT_SPR + INPUT_PPR + INPUT_POOL



# get files for IDR
CASE = [rep_flag, "PPR", "SPR1.", "SPR2."]*len(IP_REP)
REP_IDR = list(chain(*zip(*repeat(IP_REP,4))))
IN_IDR = list(chain(*zip(*repeat(IP_REP_DUP,4))))
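# Worked example (hypothetical, with rep_flag "Rep" and a single
# IP_REP = ["H3K4me3_WT"]):
#   CASE    = ["Rep", "PPR", "SPR1.", "SPR2."]
#   REP_IDR = ["H3K4me3_WT"] * 4
# repeat(IP_REP, 4) yields four copies of the list, zip(*...) regroups them
# element-wise and chain(*...) flattens the tuples, so every IP appears once
# per CASE value and the three lists stay aligned for expand(..., zip, ...).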



# -----------------------------------------------
# Add wildcard constraints

# the dash is placed last in the character classes so it is taken literally and not as a range
wildcard_constraints:
    sample = "[A-Za-z0-9_-]+_{0}[0-9]+".format(rep_flag),
    IP_REP = "[A-Za-z0-9_-]+_{0}[0-9]+".format(rep_flag),
    REP = "{0}[0-9]+".format(rep_flag),
    SPR = r"[A-Za-z0-9_-]+_SPR[0-9]\.[1-4]*",
    PPR = "[A-Za-z0-9_-]+_PPR[0-9]*",
    POOL = "[A-Za-z0-9_-]+_PPRPool",
    INPUT_POOL = "[A-Za-z0-9_-]+_(Pool|{0}1)".format(rep_flag),
    MARK = "[A-Za-z0-9_-]+"
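# For illustration (assuming rep_flag is "Rep"): "H3K4me3_WT_Rep1" matches the
# sample constraint, "H3K4me3_WT_SPR1.2" matches SPR and "H3K4me3_WT_PPRPool"
# matches POOL. Restricting names to alphanumerics, "_" and "-" prevents these
# wildcards from greedily swallowing neighbouring filename parts.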


# -----------------------------------------------
#initialize global variables


sample_dir = config["input_dir"]
analysis_dir = config["analysis_dir"]

input_data = [sample_dir + "/{{SAMPLE}}{}{}".format(rt1,config["input_extension"])]
if paired:
    input_data += [sample_dir + "/{{SAMPLE}}{}{}".format(rt2,config["input_extension"])]    
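# Note: doubled braces survive str.format() as literal braces, so the pattern
# keeps {SAMPLE} as a Snakemake wildcard. For illustration, with rt1 "_R1" and
# input_extension ".fastq.gz" the first entry becomes
# "<input_dir>/{SAMPLE}_R1.fastq.gz".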
# global variable for all output files
final_output = []



# -----------------------------------------------
# beginning of the rules

#----------------------------------
# quality control
#----------------------------------

fastqc_input_fastq = input_data
fastqc_output_done = "00-Fastqc/{{SAMPLE}}{}_fastqc.done".format(rt1)
fastqc_wkdir = "00-Fastqc"
fastqc_log = "00-Fastqc/logs/{{SAMPLE}}{}_fastqc_raw.log".format(rt1)
final_output.extend(expand(fastqc_output_done, SAMPLE=samples))
include: os.path.join(RULES, "fastqc.rules")



#----------------------------------
# Remove adapters
#----------------------------------

if config["adapters"]["remove"] :

    ## TODO add AlienTrimmer
    adapter_tool = "adapters"
    cutadapt_input_fastq = input_data
    cutadapt_wkdir = "01-Trimming"
    cutadapt_output = ["01-Trimming/{{SAMPLE}}{}_trim{}".format(rt1,config["input_extension"])]
    if paired:
        cutadapt_output += ["01-Trimming/{{SAMPLE}}{}_trim{}".format(rt2,config["input_extension"])]

    # Set parameters
    #todo write the good line according to tool
    cutadapt_adapt_list = config["adapters"]["adapter_list"]
    
    cutadapt_options = config["adapters"]["options"]
    cutadapt_mode = config["adapters"]["mode"]
    cutadapt_min  = config["adapters"]["m"]
    cutadapt_qual = config["adapters"]["quality"]
    cutadapt_log = "01-Trimming/logs/{SAMPLE}_trim.txt"
    final_output.extend(expand(cutadapt_output, SAMPLE=samples))
    include: os.path.join(RULES, "cutadapt.rules")
else:
    cutadapt_output = input_data


#----------------------------------
# genome handling
#----------------------------------

ref = [ config["genome"]["name"] ]

if config["design"]["spike"]:
    ref += [ "spikes" ]
    # TODO add possibility of index name and also, check if 


def mapping_index(wildcards):
    # resolve the {REF} wildcard to the matching fasta file
    if wildcards.REF == config["genome"]["name"]:
        return config["genome"]["fasta_file"]
    elif wildcards.REF == "spikes":
        return config["design"]["spike_genome_file"]
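# Snakemake calls this input function once per job, passing the job's
# wildcards, so {REF} resolves to either the reference or the spike-in genome
# fasta when the DAG is built.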

#----------------------------------
# bowtie2 indexing
#----------------------------------

if config["genome"]["index"]: 
    # indexing for bowtie2
    bowtie2_index_fasta = mapping_index
    bowtie2_index_output_done = os.path.join(config["genome"]["genome_directory"],"{REF}.1.bt2")
    bowtie2_index_output_prefix = os.path.join(config["genome"]["genome_directory"],"{REF}")
    bowtie2_index_log = "02-Mapping/logs/bowtie2_{REF}_indexing.log"
    final_output.extend(expand(bowtie2_index_output_done, REF=ref))
    include: os.path.join(RULES, "bowtie2_index.rules")

else:
    bowtie2_index_output_done = os.path.join(config["genome"]["genome_directory"],"{REF}.1.bt2")
    bowtie2_index_output_prefix = os.path.join(config["genome"]["genome_directory"],"{REF}")

#----------------------------------
# bowtie2 MAPPING
#----------------------------------

bowtie2_mapping_input = cutadapt_output
bowtie2_mapping_index_done = bowtie2_index_output_done
bowtie2_mapping_sort = "02-Mapping/{SAMPLE}_{REF}_sort.bam"
bowtie2_mapping_bam = "02-Mapping/{SAMPLE}_{REF}.bam"
bowtie2_mapping_sortprefix = "02-Mapping/{SAMPLE}_{REF}_sort"
bowtie2_mapping_logs_err = "02-Mapping/logs/{SAMPLE}_{REF}_mapping.err"
bowtie2_mapping_logs_out = "02-Mapping/logs/{SAMPLE}_{REF}_mapping.out"
bowtie2_mapping_prefix_index = bowtie2_index_output_prefix
bowtie2_mapping_options =  config["bowtie2_mapping"]["options"]
final_output.extend(expand(bowtie2_mapping_sort, SAMPLE=samples, REF=ref))
include: os.path.join(RULES, "bowtie2_mapping.rules")


#----------------------------------
# Mark duplicated reads
#----------------------------------

mark_duplicates_input = bowtie2_mapping_sort
mark_duplicates_output = "03-Deduplication/{SAMPLE}_{REF}_sort_dedup.bam"
mark_duplicates_metrics = "03-Deduplication/{SAMPLE}_{REF}_sort_dedup.txt"
mark_duplicates_log_std = "03-Deduplication/logs/{SAMPLE}_{REF}_sort_dedup.out"
mark_duplicates_log_err = "03-Deduplication/logs/{SAMPLE}_{REF}_sort_dedup.err"
mark_duplicates_tmpdir = config['tmpdir']
final_output.extend(expand(mark_duplicates_output, SAMPLE=samples, REF=ref))
include: os.path.join(RULES, "mark_duplicates.rules")

#----------------------------------
# Coverage step
#----------------------------------
if config["bamCoverage"]["do"]:
    bamCoverage_input = "03-Deduplication/{SAMPLE}_{REF}_sort_dedup.bam"
    bamCoverage_logs = "12-Coverage/logs/{SAMPLE}_{REF}.out"
    bamCoverage_output = "12-Coverage/{SAMPLE}_{REF}_coverage.bw"
    bamCoverage_options = config['bamCoverage']['options']
    if config["remove_biasedRegions"]["do"]:
        bamCoverage_options += " --blackListFileName {} ".format(config["remove_biasedRegions"]["bed_file"])
    final_output.extend(expand(bamCoverage_output, SAMPLE=samples, REF=ref))    
    include: os.path.join(RULES, "bamCoverage.rules")  



#----------------------------------
# Spikes counting
#----------------------------------

if config["design"]["spike"]:
    # counting on spikes
    spikes_counting_input = expand("03-Deduplication/{MARK}_{COND}_{REP}_{REF}_sort_dedup.bam", MARK=marks, COND=conds, REP=rep, REF="spikes")
    spikes_counting_output_json = "09-CountMatrix/Spikes_count.json"
    spikes_counting_output = "Spikes_metrics_mqc.out"
    spikes_counting_log = "03-Deduplication/logs/Spikes_metrics.o"
    final_output.extend([spikes_counting_output_json])
    include: os.path.join(RULES, "spikes_counting.rules")


#----------------------------------
# Remove biased regions
#----------------------------------

if config["remove_biasedRegions"]["do"]:
    biasedRegions = "_biasedRegions"
    biasedRegions_dir = "04-NobiasedRegions"
    remove_biasedRegions_input = "03-Deduplication/{MARK}_{COND}_{REP}_{REF}_sort_dedup.bam"
    remove_biasedRegions_output = "04-NobiasedRegions/{{MARK}}_{{COND}}_{{REP}}_{{REF}}_sort_dedup{}.bam".format(biasedRegions)
    remove_biasedRegions_log_std = "04-NobiasedRegions/logs/{{MARK}}_{{COND}}_{{REP}}_{{REF}}_sort_dedup{}.out".format(biasedRegions)
    remove_biasedRegions_log_err = "04-NobiasedRegions/logs/{{MARK}}_{{COND}}_{{REP}}_{{REF}}_sort_dedup{}.err".format(biasedRegions)
    final_output.extend(expand(remove_biasedRegions_output, MARK=marks, COND=conds, REP=rep, REF=ref))
    include: os.path.join(RULES, "remove_biasedRegions.rules")

else:
    biasedRegions = ""
    biasedRegions_dir = "03-Deduplication"


# downstream peak calling is performed only on files mapped against the reference genome, not the spike genome
ref = config["genome"]["name"]

#----------------------------------
# preIDR on INPUT
#----------------------------------

if len(INPUT_REP) > 0 and pass_pool == 0 :
    preIDR_pool_input_bam = expand("%s/{{INPUT}}_{REP}_%s_sort_dedup%s.bam" %
                                       (biasedRegions_dir,ref, biasedRegions), REP = rep)
    preIDR_pool_log = "%s/logs/{INPUT}_preIDR_input.o" % (biasedRegions_dir)
    if len(rep) > 2:
        preIDR_pool_output = ["{}/{{INPUT}}_Pool_{}_sort_dedup{}.bam".format(biasedRegions_dir, ref, biasedRegions),
                                 "{}/{{INPUT}}_MaxiPool_{}_sort_dedup{}.bam".format(biasedRegions_dir, ref, biasedRegions)]
    else:
        preIDR_pool_output = "{}/{{INPUT}}_Pool_{}_sort_dedup{}.bam".format(biasedRegions_dir, ref, biasedRegions)

    final_output.extend(expand(preIDR_pool_output, INPUT=INPUT_REP))
    include: os.path.join(RULES, "preIDR_Pool.rules")

#----------------------------------
# preIDR on IP
#----------------------------------

if len(IP_REP) > 0:
    # run SPR
    preIDR_SPR_input_bam = expand("%s/{{IP}}_{REP}_%s_sort_dedup%s.bam" %
                                     (biasedRegions_dir,ref, biasedRegions) , REP = rep)
    preIDR_SPR_log = "%s/logs/{IP}_preIDR_SPR.o"% (biasedRegions_dir)
    preIDR_SPR_output = expand("%s/{{IP}}_{SPR}_%s_sort_dedup%s.bam" % (biasedRegions_dir,ref, biasedRegions),  SPR = spr)
    final_output.extend(expand(preIDR_SPR_output, IP=IP_REP))
    include: os.path.join(RULES, "preIDR_SPR.rules")

    # run PPR
    preIDR_PPR_input_bam = expand("%s/{{IP}}_{REP}_%s_sort_dedup%s.bam" %
                                     (biasedRegions_dir,ref, biasedRegions), REP = rep)
    preIDR_PPR_log = "%s/logs/{IP}_preIDR_PPR.o" % (biasedRegions_dir)
    preIDR_PPR_output =  expand("%s/{{IP}}_{PPR}_%s_sort_dedup%s.bam" %
                                   (biasedRegions_dir,ref, biasedRegions), PPR = ppr)
    final_output.extend(expand(preIDR_PPR_output, IP=IP_REP))
    include: os.path.join(RULES, "preIDR_PPR.rules")

#----------------------------------
# Cross correlation
#----------------------------------

if config["peak_calling"]["no_model"]:
    # PhantomPeakQualTools rule
    spp_input = "{}/{{ALL}}_{}_sort_dedup{}.bam".format(biasedRegions_dir,ref, biasedRegions)
    spp_output_pdf = "05-PhantomPeakQualTools/{{ALL}}_{}_sort_dedup{}_phantom.pdf".format(ref, biasedRegions)
    spp_metrics = "05-PhantomPeakQualTools/{{ALL}}_{}_sort_dedup{}_spp.out".format(ref, biasedRegions)
    spp_log_std = "05-PhantomPeakQualTools/logs/{ALL}_phantom.out"
    spp_log_err = "05-PhantomPeakQualTools/logs/{ALL}_phantom.err"
    spp_tmpdir = config['tmpdir']
    final_output.extend(expand(spp_metrics, ALL=ALL))
    include: os.path.join(RULES, "spp.rules")
else :
    # PhantomPeakQualTools rule
    spp_input = "{}/{{ALL}}_{}_sort_dedup{}.bam".format(biasedRegions_dir, ref, biasedRegions)
    spp_output_pdf = "05-PhantomPeakQualTools/{{ALL}}_{}_sort_dedup{}_phantom.pdf".format(ref, biasedRegions)
    spp_metrics = "05-PhantomPeakQualTools/{{ALL}}_{}_sort_dedup{}_spp.out".format(ref, biasedRegions)
    spp_log_std = "05-PhantomPeakQualTools/logs/{ALL}_phantom.out"
    spp_log_err = "05-PhantomPeakQualTools/logs/{ALL}_phantom.err"
    spp_tmpdir = config['tmpdir']
    final_output.extend(expand(spp_metrics, ALL=IP_ALL))
    include: os.path.join(RULES, "spp.rules")



#----------------------------------
# Peak Calling
#----------------------------------


model = config["peak_calling"]["mode_choice"]
model_dir = model
if config["peak_calling"]["no_model"]:
    model_dir += "-nomodel"
if model in ["narrow", "broad"]:

    # Peak Calling on replicates
    if model in ["narrow"]:
        # add corresponding options
        macs2_options = "-p {} ".format(config["peak_calling"]['cutoff']) + config["peak_calling"]['options']
    else:
        macs2_options = "-p {} --broad --broad-cutoff {} ".format(config["peak_calling"]['cutoff'],
                            config["peak_calling"]['cutoff']) + config["peak_calling"]['options']

    macs2_input_bam = "{}/{{IP}}_{}_sort_dedup{}.bam".format(biasedRegions_dir, ref, biasedRegions)
    macs2_input_done = ["{}/{{INPUT}}_{}_sort_dedup{}.bam".format(biasedRegions_dir, ref, biasedRegions)]
    if config["peak_calling"]["no_model"]:
        macs2_options += " --nomodel "
        macs2_input_done += ["05-PhantomPeakQualTools/{{IP}}_{}_sort_dedup{}_spp.out".format(ref, biasedRegions)]
        macs2_shift_file = "05-PhantomPeakQualTools/{{IP}}_{}_sort_dedup{}_spp.out".format(ref, biasedRegions)
    else:
        macs2_shift_file = "Empty"
    # note: MACS2 paired-end mode is currently forced off for both layouts
    macs2_pe_mode = "no"
    macs2_control = "-c {}/{{INPUT}}_{}_sort_dedup{}.bam".format(biasedRegions_dir, ref, biasedRegions)
    macs2_log = "06-PeakCalling/{}/logs/{{IP}}_vs_{{INPUT}}.o".format(model_dir)
    macs2_output = "06-PeakCalling/{}/{{IP}}_vs_{{INPUT}}_peaks.{}Peak".format(model_dir, model)
    macs2_output_prefix = "06-PeakCalling/{}/{{IP}}_vs_{{INPUT}}".format(model_dir)

    final_output.extend(expand(macs2_output, zip, IP=ALL_IP_PC, INPUT=ALL_INPUT_PC))
    include: os.path.join(RULES, "macs2.rules")
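    # Note on expand(..., zip, ...): zip pairs the IP and INPUT lists
    # element-wise instead of taking their Cartesian product, so the i-th IP is
    # always called against the i-th INPUT built above. For illustration,
    # expand("{IP}_vs_{INPUT}", zip, IP=["A", "B"], INPUT=["x", "y"]) yields
    # ["A_vs_x", "B_vs_y"] rather than four combinations.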


#----------------------------------
# Peak Calling metrics
#----------------------------------

macs2_rep = []
macs2_rep.extend(expand("06-PeakCalling/{}/{{IP_REP}}_vs_{{INPUT}}_peaks.{}Peak".format(model_dir, model), zip, IP_REP=IP_ALL, INPUT=INPUT))
stats_peakCalling_input = macs2_rep
stats_peakCalling_csv = "Peaks_metrics_mqc.out"
stats_peakCalling_marks = marks
stats_peakCalling_conds = conds
stats_peakCalling_rep = rep_flag
stats_peakCalling_log = "06-PeakCalling/{}/Peaks_metrics.out".format(model_dir)
include: os.path.join(RULES, "stats_peakCalling.rules")
final_output.extend([stats_peakCalling_csv])



#----------------------------------
# IDR computing
#----------------------------------

if model in ["narrow"]:
    compute_idr_mode = "narrowPeak"
else :
    compute_idr_mode = "broadPeak"
compute_idr_input1 = "06-PeakCalling/{}/{{IP_IDR}}_{{CASE}}1_vs_{{INPUT}}_peaks.{}Peak".format(model_dir, model)
compute_idr_input2 = "06-PeakCalling/{}/{{IP_IDR}}_{{CASE}}2_vs_{{INPUT}}_peaks.{}Peak".format(model_dir, model)
compute_idr_output = "07-IDR/{}/{{IP_IDR}}_{{CASE}}1vs{{CASE}}2_{{INPUT}}_{}_{}_idr.txt".format(model_dir, ref, model)
compute_idr_output_peak = "07-IDR/{}/{{IP_IDR}}_{{CASE}}1vs{{CASE}}2_{{INPUT}}_{}_{}_idr{}.{}Peak".format(model_dir, 
    ref, model, config["compute_idr"]["thresh"], model)
compute_idr_log = "07-IDR/{}/logs/{{IP_IDR}}_{{CASE}}1vs{{CASE}}2_{{INPUT}}_{}_idr.o".format(model_dir,model)
include: os.path.join(RULES, "compute_idr.rules")
final_output.extend(expand(compute_idr_output, zip, IP_IDR=REP_IDR, CASE=CASE, INPUT=IN_IDR))


#----------------------------------
# Select reproducible peaks
#----------------------------------


if model in ["narrow"] and not config["compute_idr"]["intersectionApproach"]:
    # Select IDR peaks
    select_peaks_input_rep = "07-IDR/{}/{{IP_IDR}}_{}1vs{}2_{{INPUT}}_{}_{}_idr{}.{}Peak".format(model_dir, rep_flag, rep_flag, ref, model, config["compute_idr"]["thresh"], model)
    select_peaks_input_ppr = "07-IDR/{}/{{IP_IDR}}_{}1vs{}2_{{INPUT}}_{}_{}_idr{}.{}Peak".format(model_dir, "PPR", "PPR", ref, model, config["compute_idr"]["thresh"], model)
    select_peaks_input_pool = "06-PeakCalling/{}/{{IP_IDR}}_PPRPool_vs_{{INPUT}}_peaks.{}Peak".format(model_dir, model)
    select_peaks_logs = "08-ReproduciblePeaks/%s/logs/{IP_IDR}_vs_{INPUT}_selected_peaks.o" % model_dir
    select_peaks_output = "08-ReproduciblePeaks/{}/{{IP_IDR}}_vs_{{INPUT}}_select.{}Peak".format(model_dir, model)
    include: os.path.join(RULES, "select_peaks.rules")
    final_output.extend(expand(select_peaks_output, zip, IP_IDR=IP_REP, INPUT=IP_REP_DUP))
    
#----------------------------------
# Intersection approach
#----------------------------------

if model in ["broad"] or config["compute_idr"]["intersectionApproach"]:
    intersectionApproach_input_rep = expand("06-PeakCalling/%s/{{IP_IDR}}_{CASE}_vs_{{INPUT}}_peaks.%sPeak" % (model_dir, model), CASE=rep)
    intersectionApproach_input_pool = "06-PeakCalling/{}/{{IP_IDR}}_PPRPool_vs_{{INPUT}}_peaks.{}Peak".format(model_dir, model)
    intersectionApproach_logs = "08-ReproduciblePeaks/%s/logs/{IP_IDR}_vs_{INPUT}_IA_peaks.o" % model_dir
    intersectionApproach_output = "08-ReproduciblePeaks/{}/{{IP_IDR}}_vs_{{INPUT}}_IA.{}Peak".format(model_dir, model)
    if model in ["broad"] :
        intersectionApproach_overlap = 0.8
    else:
        intersectionApproach_overlap = config["compute_idr"]["ia_overlap"]
    include: os.path.join(RULES, "intersectionApproach.rules")
    final_output.extend(expand(intersectionApproach_output, zip, IP_IDR=IP_REP, INPUT=IP_REP_DUP))
    select_peaks_output = intersectionApproach_output

    

#----------------------------------
# Compute IDR metrics
#----------------------------------


idr_peaks = []
idr_peaks.extend(expand("07-IDR/{}/{{IP_IDR}}_{{CASE}}1vs{{CASE}}2_{{INPUT}}_{}_{}_idr{}.{}Peak".format(model_dir,
                        ref,model, config["compute_idr"]["thresh"], model), zip, IP_IDR=REP_IDR, CASE=CASE, INPUT=IN_IDR))
metrics_peaks_input = idr_peaks
metrics_peaks_marks = marks
metrics_peaks_conds = conds
metrics_peaks_rep = rep_flag
metrics_peaks_logs = "07-IDR/{}/logs/IDR_metrics.out".format(model_dir)
metrics_peaks_output = "IDR_metrics_mqc.out"
include: os.path.join(RULES, "metrics_peaks.rules")
final_output.extend([metrics_peaks_output])


#----------------------------------
# Run differential analysis
#----------------------------------

if len(conds) > 1 and config["differential_analysis"]["do"]:

    def getPeakFilesByMark(wildcards):
        ALL_IP = expand(select_peaks_output, zip, IP_IDR=IP_REP, INPUT=IP_REP_DUP)
        IP_dict = {}
        for mark in marks:
            IP_dict[mark] = []
            for file in ALL_IP:
                if mark in file:
                    IP_dict[mark].append(file)

        return IP_dict[wildcards["MARK"]]
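    # For illustration: for wildcard MARK="H3K4me3" (hypothetical),
    # getPeakFilesByMark returns every reproducible peak file whose path
    # contains that mark, e.g. the
    # "08-ReproduciblePeaks/<model_dir>/H3K4me3_*_select.*Peak" files.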

    #----------------------------------
    # get union peaks
    #----------------------------------

    union_peaks_input = getPeakFilesByMark
    union_peaks_logs = "09-CountMatrix/logs/{MARK}_Optimal_selected_peaks.o"
    if len(conds) == 2 :
        union_peaks_output = "09-CountMatrix/{{MARK}}_{}u{}_{}.optimal.{}Peak_overlap.bed".format(conds[0], conds[1], ref, model)
    elif len(conds) == 3 :
        union_peaks_output = "09-CountMatrix/{{MARK}}_{}u{}u{}_{}.optimal.{}Peak_overlap.bed".format(conds[0], conds[1], conds[2], ref, model)
    else :
       union_peaks_output = "09-CountMatrix/{{MARK}}_all_conds_{}.optimal.{}Peak_overlap.bed".format(ref, model)
    include: os.path.join(RULES, "union_peaks.rules")
    final_output.extend(expand(union_peaks_output, MARK=MARK_OK))

    #----------------------------------
    # get GFF from peak files
    #----------------------------------
    if len(conds) == 2 :
        bed_to_gff_input = "09-CountMatrix/{{MARK}}_{}u{}_{}.optimal.{}Peak_overlap.bed".format(conds[0], conds[1], ref, model)
        bed_to_gff_output = "09-CountMatrix/{{MARK}}_{}u{}_{}.optimal.{}Peak_overlap.gff".format(conds[0], conds[1], ref, model)
    elif len(conds) == 3 :
        bed_to_gff_input = "09-CountMatrix/{{MARK}}_{}u{}u{}_{}.optimal.{}Peak_overlap.bed".format(conds[0], conds[1], conds[2],  ref, model)
        bed_to_gff_output = "09-CountMatrix/{{MARK}}_{}u{}u{}_{}.optimal.{}Peak_overlap.gff".format(conds[0], conds[1], conds[2], ref, model)
    else :
        bed_to_gff_input = "09-CountMatrix/{{MARK}}_all_conds_{}.optimal.{}Peak_overlap.bed".format(ref, model)
        bed_to_gff_output = "09-CountMatrix/{{MARK}}_all_conds_{}.optimal.{}Peak_overlap.gff".format(ref, model)

    bed_to_gff_logs = "09-CountMatrix/logs/{MARK}_Optimal_bed2gff.o"
    final_output.extend(expand(bed_to_gff_output,  MARK=MARK_OK))
    include: os.path.join(RULES, "bed_to_gff.rules")


    def getBAMFilesByMark(wildcards):
        ALL_BAM = expand(macs2_input_bam, IP=IP_ALL )
        BAM_dict = {}
        for mark in marks:
            BAM_dict[mark] = []
            for file in ALL_BAM:
                if mark in file:
                    BAM_dict[mark].append(file)
        return BAM_dict[wildcards["MARK"]]

    #----------------------------------
    # feature Count on peaks
    #----------------------------------


    feature_counts_input = getBAMFilesByMark
    feature_counts_optional_input = []
    if config["differential_analysis"]["input_counting"]:
        feature_counts_optional_input += expand("{}/{{INPUT}}_{}_sort_dedup{}.bam".format(biasedRegions_dir, ref, biasedRegions), INPUT=CORR_INPUT)
    feature_counts_output_count = "09-CountMatrix/{{MARK}}_Matrix_Optimal_{}Peak.mtx".format(model)
    if len(conds) == 2 :
        feature_counts_gff = "09-CountMatrix/{{MARK}}_{}u{}_{}.optimal.{}Peak_overlap.gff".format(conds[0], conds[1], ref, model)
    elif len(conds) == 3 :
        feature_counts_gff = "09-CountMatrix/{{MARK}}_{}u{}u{}_{}.optimal.{}Peak_overlap.gff".format(conds[0], conds[1], conds[2], ref, model)
    else :
        feature_counts_gff = "09-CountMatrix/{{MARK}}_all_conds_{}.optimal.{}Peak_overlap.gff".format(ref, model)

    feature_counts_log = "09-CountMatrix/logs/{MARK}_counts.o"
    feature_counts_options = "-t peak -g gene_id"
    feature_counts_threads = 4
    final_output.extend(expand(feature_counts_output_count,  MARK=MARK_OK))
    include: os.path.join(RULES, "feature_counts.rules")

    #----------------------------------
    # differential analysis on peaks
    #----------------------------------
    

    method = config["differential_analysis"]["method"]
    norm = config["differential_analysis"]["normalisation"]

    chipuanar_init_input = feature_counts_output_count
    chipuanar_init_conds = conds
    chipuanar_init_rep = rep
    chipuanar_init_method = method
    chipuanar_init_norm = norm
    if config["differential_analysis"]["spikes"] and config["design"]["spike"]:
        chipuanar_init_spikes = spikes_counting_output_json
        chipuanar_init_input_done = spikes_counting_output_json
    else:
        chipuanar_init_spikes = ""
        chipuanar_init_input_done = feature_counts_output_count
    chipuanar_init_padj = config["differential_analysis"]["pAdjustMethod"]
    chipuanar_init_alpha = config["differential_analysis"]["alpha"]
    chipuanar_init_batch = config["differential_analysis"]["batch"]
    chipuanar_init_output_dir = "10-DifferentialAnalysis/{{MARK}}_{}_{}_{}".format(model_dir, method, norm)
    chipuanar_init_config_r = "10-DifferentialAnalysis/{{MARK}}_{}_{}_{}/config.R".format(model_dir, method, norm)
    chipuanar_init_genome = ref
    include: os.path.join(RULES, "chipuanar_init.rules")


    chipuanar_config_r = chipuanar_init_config_r
    chipuanar_output_dir = "10-DifferentialAnalysis/{{MARK}}_{}_{}_{}".format(model_dir, method, norm)
    chipuanar_report = "10-DifferentialAnalysis/{{MARK}}_{}_{}_{}/{{MARK}}_Stat_report_{}_{}.html".format(model_dir, method, norm,method, norm)
    chipuanar_logs = "10-DifferentialAnalysis/{{MARK}}_{}_{}_{}/{{MARK}}_{}_{}_Log.txt".format(model_dir, method, norm, method, norm)
    final_output.extend(expand(chipuanar_report,  MARK=MARK_OK))
    include: os.path.join(RULES, "chipuanar.rules")



#----------------------------------  
# MultiQC report
#----------------------------------

multiqc_input = final_output
multiqc_input_dir = "."
multiqc_logs = "11-Multiqc/multiqc.log"
multiqc_output = config['multiqc']['output-directory'] + "/%s/multiqc_report.html" % (model_dir)
multiqc_options = config['multiqc']['options'] + " -c config/multiqc_config.yaml "
multiqc_output_dir = config['multiqc']['output-directory'] + "/%s" % (model_dir)
final_output = [multiqc_output]
include: os.path.join(RULES, "multiqc.rules")


rule chipflow:
    input: final_output
        

#----------------------------------  
# Move needed files
#----------------------------------

onsuccess:

    # copy metrics files into the corresponding multiQC output directory (useful in exploratory mode)
    import shutil
    shutil.copyfile(metrics_peaks_output, config['multiqc']['output-directory'] + "/%s/%s_IDR_metrics.txt" % (model_dir, model_dir))
    shutil.copyfile(stats_peakCalling_csv, config['multiqc']['output-directory'] + "/%s/%s_peaks_metrics.txt" % (model_dir, model_dir))

    # move cluster log files
    import os
    import re  # the module itself is needed here; only sub/match are imported at the top
    pattern = re.compile("slurm.*")
    dest = "cluster_logs"
    for filepath in os.listdir("."):
        if pattern.match(filepath):
            if not os.path.exists(dest):
                os.makedirs(dest)
            shutil.move(filepath, dest)