"""
Snakefile to analyse PRO-seq data.
"""
import sys
major, minor = sys.version_info[:2]
if major < 3 or (major == 3 and minor < 6):
    sys.exit("Need at least python 3.6\n")

# TODO look at 5' and 3' ends of genes

#TODO: add local metaprofiles (around TSS and around TTS), no min length
# TSS or TTS should not be within 200 of a TSS or TTS

# counts using featureCounts, differential expression for transcripts, CDS and UTR (see genes.gtf), and introns (transcript - exons) (so we need counting on exons)
# ratios CDS / introns

import os
OPJ = os.path.join
from glob import glob
from pickle import load
from fileinput import input as finput

import warnings
# Needed to catch shell() failures inside "run" blocks (see rule make_bigwig).
from subprocess import CalledProcessError


def formatwarning(message, category, filename, lineno, line):
    """Used to format warning messages."""
    return "%s:%s: %s: %s\n" % (filename, lineno, category.__name__, message)


warnings.formatwarning = formatwarning


# Useful for functional style
from itertools import chain, product, starmap
from operator import eq

# Useful data structures
from collections import OrderedDict
from collections import defaultdict, Counter

# To parse SAM format
import pysam
import pyBigWig

# For data processing and displaying
from sklearn.decomposition import PCA
import matplotlib as mpl
# To be able to run the script without a defined $DISPLAY
mpl.use("PDF")
#mpl.rcParams["figure.figsize"] = 2, 4
mpl.rcParams["font.sans-serif"] = [
    "Arial", "Liberation Sans", "Bitstream Vera Sans"]
mpl.rcParams["font.family"] = "sans-serif"

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import husl

from libdeseq import do_deseq2
from libhts import median_ratio_to_pseudo_ref_size_factors, status_setter, plot_lfc_distribution, plot_MA
from libworkflows import wc_applied, ensure_relative, cleanup_and_backup, texscape
from libworkflows import get_chrom_sizes, column_converter
from libworkflows import strip_split, file_len, last_lines, save_plot, test_na_file
from libworkflows import make_id_list_getter, filter_combinator, SHELL_FUNCTIONS, warn_context
from libworkflows import feature_orientation2stranded
from libworkflows import sum_by_family
from libworkflows import read_htseq_counts, sum_htseq_counts
from libworkflows import read_feature_counts, sum_feature_counts
from smincludes import rules as irules
from smwrappers import wrappers_dir

alignment_settings = {"bowtie2": ""}

#TRIMMERS = ["cutadapt", "fastx_clipper"]
TRIMMERS = ["cutadapt"]
COUNTERS = ["feature_count"]
ORIENTATIONS = ["fwd", "rev", "all"]

COMPL = {"A" : "T", "C" : "G", "G" : "C", "T" : "A", "N" : "N"}

LFC_RANGE = {
    "protein_coding" : (-10, 10),
    "DNA_transposons_rmsk" : (-10, 10),
    "RNA_transposons_rmsk" : (-10, 10),
    "satellites_rmsk" : (-10, 10),
    "simple_repeats_rmsk" : (-10, 10),
    "DNA_transposons_rmsk_families" : (-10, 10),
    "RNA_transposons_rmsk_families" : (-10, 10),
    "satellites_rmsk_families" : (-10, 10),
    "simple_repeats_rmsk_families" : (-10, 10),
    "pseudogene" : (-10, 10),
    "all_rmsk" : (-10, 10),
    "all_rmsk_families" : (-10, 10),
    "alltypes" : (-10, 10)}
# Cutoffs in log fold change
LFC_CUTOFFS = [0.5, 1, 2]
UP_STATUSES = [f"up{cutoff}" for cutoff in LFC_CUTOFFS]
DOWN_STATUSES = [f"down{cutoff}" for cutoff in LFC_CUTOFFS]
#status2colour = make_status2colour(DOWN_STATUSES, UP_STATUSES)
#STATUSES = list(reversed(DOWN_STATUSES)) + ["down", "NS", "up"] + UP_STATUSES
#STATUS2COLOUR = dict(zip(STATUSES, sns.color_palette("coolwarm", len(STATUSES))))

# Or use --configfile instead
#configfile:
#    "PRO-seq.config.json"

# http://sailfish.readthedocs.io/en/master/library_type.html
LIB_TYPE = config["lib_type"]
# key: library name
# value: 3' adapter sequence
lib2adapt = config["lib2adapt"]
# key: library name
# value: [length of 5' UMI, length of 3' UMI]
lib2UMI = config["lib2UMI"]
# key: library name
# value: path to raw data
lib2raw = config["lib2raw"]
LIBS = list(lib2raw.keys())
REPS = config["replicates"]
COND_PAIRS = config["cond_pairs"]
msg = "\n".join([
    "Some contrasts do not use known library names.",
    "Contrasts: " + ", ".join([f"({cond}, {ref})" for (cond, ref) in COND_PAIRS])])
assert all([cond in LIBS and ref in LIBS for (cond, ref) in COND_PAIRS]), msg
CONTRASTS = [f"{cond1}_vs_{cond2}" for [cond1, cond2] in COND_PAIRS]
CONTRAST2PAIR = dict(zip(CONTRASTS, COND_PAIRS))
CONDITIONS = [{
    "lib" : lib,
    "rep" : rep} for rep in REPS for lib in LIBS]
# We use this in various places so that the library order is always the same:
COND_NAMES = ["_".join((
    cond["lib"],
    cond["rep"])) for cond in CONDITIONS]
COND_COLUMNS = pd.DataFrame(CONDITIONS).assign(
    cond_name=pd.Series(COND_NAMES).values).set_index("cond_name")
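# Illustrative example (hypothetical config values): with LIBS == ["WT", "mut"]
# and REPS == ["1", "2"], CONDITIONS is
# [{"lib": "WT", "rep": "1"}, {"lib": "mut", "rep": "1"},
#  {"lib": "WT", "rep": "2"}, {"lib": "mut", "rep": "2"}]
# and COND_NAMES is ["WT_1", "mut_1", "WT_2", "mut_2"].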
COUNT_BIOTYPES = config["count_biotypes"]

RMSK_BIOTYPES = [
    "DNA_transposons_rmsk",
    "RNA_transposons_rmsk",
    "satellites_rmsk",
    "simple_repeats_rmsk"]
RMSK_FAMILIES_BIOTYPES = [
    "DNA_transposons_rmsk_families",
    "RNA_transposons_rmsk_families",
    "satellites_rmsk_families",
    "simple_repeats_rmsk_families"]

BIOTYPES_TO_JOIN = {
    "all_rmsk": [biotype for biotype in COUNT_BIOTYPES if biotype in RMSK_BIOTYPES],
    "all_rmsk_families": [biotype for biotype in COUNT_BIOTYPES if biotype in RMSK_FAMILIES_BIOTYPES],
    # We only count "protein_coding", not "protein_coding_{5UTR,CDS,3UTR}"
    "alltypes": [biotype for biotype in COUNT_BIOTYPES if not biotype.startswith("protein_coding_")]}
JOINED_BIOTYPES = list(BIOTYPES_TO_JOIN.keys())
DE_BIOTYPES = [biotype for biotype in LFC_RANGE.keys() if biotype in COUNT_BIOTYPES + JOINED_BIOTYPES]


#ANNOT_BIOTYPES = config["annot_biotypes"]
#METAGENE_BIOTYPES = ["protein_coding", "DNA_transposons_rmsk", "RNA_transposons_rmsk"]
#METAGENE_BIOTYPES = ["protein_coding"]
#METAGENE_BIOTYPES = ["protein_coding", "protein_coding_5UTR", "protein_coding_CDS", "protein_coding_3UTR"]
METAGENE_BIOTYPES = [biotype for biotype in ["protein_coding", "protein_coding_5UTR", "protein_coding_CDS", "protein_coding_3UTR"] if biotype in COUNT_BIOTYPES]
# default id lists for MA plots
ID_LISTS = [
    "lfc_statuses",
    "germline_specific",
    "histone",
    "spermatogenic_Ortiz_2014", "oogenic_Ortiz_2014",
    "piRNA_dependent_prot_si_22G_down4_top200", "piRNA_dependent_prot_si_22G_down4",
    "csr1_prot_si_supertargets_common"]
ID_LISTS = config.get("maplot_gene_lists", ID_LISTS)
aligner = config["aligner"]
########################
# Genome configuration #
########################
genome_dict = config.get("genome_dict", None)
if genome_dict is not None:
    genome = genome_dict["name"]
    chrom_sizes = get_chrom_sizes(genome_dict["size"])
    genomelen = sum(chrom_sizes.values())
    genome_db = genome_dict["db"][aligner]
    # bed file binning the genome in 10nt bins
    genome_binned = genome_dict["binned"]
    annot_dir = genome_dict["annot_dir"]
    # TODO: figure out the difference between OPJ(convert_dir, "wormid2name.pickle") and genome_dict["converter"]
    convert_dir = genome_dict["convert_dir"]
    gene_lists_dir = genome_dict["gene_lists_dir"]
else:
    genome = "C_elegans"
    chrom_sizes = get_chrom_sizes(config["genome_size"])
    genomelen = sum(chrom_sizes.values())
    genome_db = config["index"]
    genome_binned = f"/pasteur/entites/Mhe/Genomes/{genome}/Caenorhabditis_elegans/Ensembl/WBcel235/Sequence/genome_binned_10.bed"
    annot_dir = config["annot_dir"]
    convert_dir = config["convert_dir"]
    gene_lists_dir = "/pasteur/entites/Mhe/bli/Gene_lists"
avail_id_lists = set(glob(OPJ(gene_lists_dir, "*_ids.txt")))
#gene_lists_dir = config["gene_lists_dir"]
#local_annot_dir = config["local_annot_dir"]
#output_dir = config["output_dir"]
#workdir: config["output_dir"]
output_dir = os.path.abspath(".")
log_dir = OPJ("logs")
data_dir = OPJ("data")
local_annot_dir = OPJ("annotations")
# Used to skip some genotype x treatment x replicate number combinations
# when some of them were not sequenced
forbidden = {frozenset(wc_comb.items()) for wc_comb in config["missing"]}

SIZE_FACTORS = ["protein_coding", "miRNA", "median_ratio_to_pseudo_ref"]
assert set(SIZE_FACTORS).issubset(set(COUNT_BIOTYPES) | {"median_ratio_to_pseudo_ref"})
#NORM_TYPES = config["norm_types"]
NORM_TYPES = ["protein_coding", "median_ratio_to_pseudo_ref"]
assert set(NORM_TYPES).issubset(set(SIZE_FACTORS))

# For metagene analyses
#META_MARGIN = 300
META_MARGIN = 0
META_SCALE = 2000
#UNSCALED_INSIDE = 500
UNSCALED_INSIDE = 0
#META_MIN_LEN = 1000
META_MIN_LEN = 2 * UNSCALED_INSIDE
MIN_DIST = 2 * META_MARGIN


######
# Colors from https://personal.sron.nl/~pault/
# (via https://personal.sron.nl/~pault/python/distinct_colours.py)
hexcols = ['#332288', '#88CCEE', '#44AA99', '#117733', '#999933', '#DDCC77', 
           '#CC6677', '#882255', '#AA4499', '#661100', '#6699CC', '#AA4466',
           '#4477AA']

greysafecols = ['#809BC8', '#FF6666', '#FFCC66', '#64C204']

xarr = [[12],
        [12, 6],
        [12, 6, 5],
        [12, 6, 5, 3],
        [0, 1, 3, 5, 6],
        [0, 1, 3, 5, 6, 8],
        [0, 1, 2, 3, 5, 6, 8],
        [0, 1, 2, 3, 4, 5, 6, 8],
        [0, 1, 2, 3, 4, 5, 6, 7, 8],
        [0, 1, 2, 3, 4, 5, 9, 6, 7, 8],
        [0, 10, 1, 2, 3, 4, 5, 9, 6, 7, 8],
        [0, 10, 1, 2, 3, 4, 5, 9, 6, 11, 7, 8]]

def get_distinct(nr):
    """Return *nr* (1 <= nr <= 12) distinct colours in HTML hex format."""
    assert 1 <= nr <= 12, "wrong nr of distinct colours!"
    # Look up the pre-selected combination of indices for this number of
    # colours and translate them through the colour table.
    return [hexcols[idx] for idx in xarr[nr - 1]]
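# Example (derived from the tables above):
# get_distinct(3) -> ['#4477AA', '#CC6677', '#DDCC77']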

######

# def filter_combinator(combinator, blacklist):
#     """This function builds a wildcards combination generator
#     based on the generator *combinator* and a set of combinations
#     to exclude *blacklist*."""
#     def filtered_combinator(*args, **kwargs):
#         """This function generates wildcards combinations.
#         It is to be used as second argument of *expand*."""
#         #print(*args)
#         for wc_comb in combinator(*args, **kwargs):
#             # Use frozenset instead of tuple
#             # in order to accomodate
#             # unpredictable wildcard order
#             if frozenset(wc_comb) not in blacklist:
#                 yield wc_comb
#     return filtered_combinator


filtered_product = filter_combinator(product, forbidden)
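# Sketch of intended use (assuming entries of config["missing"] look like
# {"lib": "some_lib", "rep": "2"}): any wildcard combination whose
# (name, value) pairs match such an entry is silently dropped from expansions
# that pass filtered_product as the second argument of expand(), e.g.
#   expand(OPJ("{lib}_{rep}_file.txt"), filtered_product, lib=LIBS, rep=REPS)
# (the path here is purely illustrative).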

wildcard_constraints:
    lib="|".join(LIBS),
    rep="\d+",
    orientation="|".join(ORIENTATIONS),
    biotype="|".join(COUNT_BIOTYPES + JOINED_BIOTYPES)


# Define functions to be used in shell portions
shell.prefix(SHELL_FUNCTIONS)

if len(CONDITIONS) < 2:
    pca_figs = []
else:
    pca_figs = expand(OPJ(
        "{trimmer}", "figures", aligner, "{counter}",
        "{biotype}_{orientation}_PCA.pdf"),
        trimmer=TRIMMERS, counter=COUNTERS, biotype=COUNT_BIOTYPES,
        orientation=ORIENTATIONS)

rule all:
    """This top rule is used to drive the whole workflow by taking as input its final products."""
    input:
        expand(OPJ(
            "{trimmer}", "figures", aligner, "{counter}",
            "{contrast}", "{orientation}_{biotype}", "MA_with_{id_list}.pdf"),
            trimmer=TRIMMERS, counter=COUNTERS, contrast=CONTRASTS,
            orientation=ORIENTATIONS, biotype=DE_BIOTYPES, id_list=ID_LISTS),
        expand(OPJ(
            "{trimmer}", "figures", aligner, "{counter}",
            "{contrast}", "{orientation}_{biotype}", "{fold_type}_distribution.pdf"),
            trimmer=TRIMMERS, counter=COUNTERS, contrast=CONTRASTS,
            orientation=ORIENTATIONS, biotype=DE_BIOTYPES, fold_type=["log2FoldChange"]),
        expand(OPJ(
            "{trimmer}", aligner, "mapped_C_elegans", "{counter}",
            "deseq2", "{contrast}", "{orientation}_{biotype}", "counts_and_res.txt"),
            trimmer=TRIMMERS, counter=COUNTERS, contrast=CONTRASTS,
            orientation=ORIENTATIONS, biotype=DE_BIOTYPES),
        pca_figs,
        #expand(OPJ(
        #    "{trimmer}", "figures", aligner, "{counter}",
        #    "{biotype}_{orientation}_PCA.pdf"),
        #    trimmer=TRIMMERS, counter=COUNTERS, biotype=COUNT_BIOTYPES,
        #    orientation=ORIENTATIONS),
        #expand(OPJ("{trimmer}", aligner, "mapped_C_elegans", "{counter}", "all_on_C_elegans", "alltypes_{orientation}_counts.txt"), trimmer=TRIMMERS, counter=COUNTERS, orientation=["all"]),
        #expand(expand(OPJ("{{trimmer}}", aligner, "mapped_C_elegans", "{{counter}}", "{lib}_{rep}_on_C_elegans", "{{biotype}}_{{orientation}}_counts.txt"), filtered_product, lib=LIBS, rep=REPS), trimmer=TRIMMERS, counter=COUNTERS, biotype=COUNT_BIOTYPES, orientation=["all"]),
        expand(expand(OPJ(
            "{{trimmer}}", "figures", aligner, "{lib}_{rep}",
            "adapt_on_C_elegans_last_bases.pdf"), filtered_product, lib=LIBS, rep=REPS),
            trimmer=TRIMMERS),
        expand(expand(OPJ(
            "{{trimmer}}", aligner, "mapped_C_elegans",
            "{lib}_{rep}_on_C_elegans_by_{{norm_type}}_{{orientation}}.bw"), filtered_product, lib=LIBS, rep=REPS),
            trimmer=TRIMMERS, norm_type=NORM_TYPES, orientation=["all"]),
        expand(OPJ(
            "{trimmer}", aligner, "mapped_C_elegans", "{counter}",
            "all_on_C_elegans", "{biotype}_{orientation}_TPM.txt"),
            trimmer=TRIMMERS, counter=COUNTERS, biotype=["alltypes"], orientation=ORIENTATIONS),
        #expand(OPJ("{trimmer}", "figures", aligner, "{lib}_mean", "{orientation}_on_merged_isolated_%d_{biotype}_min_%d_meta_profile.pdf" % (MIN_DIST, META_MIN_LEN)), trimmer=TRIMMERS, lib=LIBS, orientation=["all"], biotype=["protein_coding"]),
        #expand(OPJ("{trimmer}", "figures", aligner, "{lib}_mean", "{orientation}_on_merged_isolated_%d_{biotype}_min_%d_meta_profile.pdf" % (MIN_DIST, META_MIN_LEN)), trimmer=TRIMMERS, lib=LIBS, orientation=["all"], biotype=METAGENE_BIOTYPES),
        # TODO: Add metagene profiles similar to small RNA-seq
        expand(OPJ(
            "{trimmer}", "figures", aligner, "{lib}_by_{norm_type}_mean",
            "{orientation}_on_merged_isolated_%d_{biotype}_min_%d_meta_profile.pdf" % (MIN_DIST, META_MIN_LEN)),
            trimmer=TRIMMERS, lib=LIBS, norm_type=NORM_TYPES, orientation=["all", "fwd", "rev"],
            biotype=METAGENE_BIOTYPES),


include: ensure_relative(irules["link_raw_data"], workflow.basedir)
#include: "../snakemake_wrappers/includes/link_raw_data.rules"


rule trim_and_dedup:
    """The adaptor is trimmed, then reads are treated in two groups depending
    on whether the adapter was found or not. For each group the reads are
    sorted, deduplicated, and the random k-mers (k=4) that helped identify
    PCR duplicates are removed at both ends"""
    input:
        rules.link_raw_data.output,
    output:
        noadapt = OPJ(data_dir, "trimmed_{trimmer}", "{lib}_{rep}_noadapt_deduped.fastq.gz"),
        adapt = OPJ(data_dir, "trimmed_{trimmer}", "{lib}_{rep}_adapt_deduped.fastq.gz"),
        nb_raw =  OPJ(data_dir, "trimmed_{trimmer}", "{lib}_{rep}_nb_raw.txt"),
        nb_adapt =  OPJ(data_dir, "trimmed_{trimmer}", "{lib}_{rep}_nb_adapt.txt"),
        nb_adapt_deduped =  OPJ(data_dir, "trimmed_{trimmer}", "{lib}_{rep}_nb_adapt_deduped.txt"),
        nb_noadapt =  OPJ(data_dir, "trimmed_{trimmer}", "{lib}_{rep}_nb_noadapt.txt"),
        nb_noadapt_deduped =  OPJ(data_dir, "trimmed_{trimmer}", "{lib}_{rep}_nb_noadapt_deduped.txt"),
    params:
        adapter = lambda wildcards : lib2adapt[wildcards.lib],
        process_type = "PRO-seq",
        trim5 = lambda wildcards : lib2UMI[wildcards.lib][0],
        trim3 = lambda wildcards : lib2UMI[wildcards.lib][1],
    threads: 8 # Actually, to avoid too much IO
    message:
        "Trimming adaptor from raw data using {wildcards.trimmer}, deduplicating reads, and removing 5' and 3' random n-mers for {wildcards.lib}_{wildcards.rep}."
    benchmark:
        OPJ(data_dir, "trimmed_{trimmer}", "{lib}_{rep}_trim_benchmark.txt")
    log:
        trim = OPJ(data_dir, "trimmed_{trimmer}", "{lib}_{rep}_trim.log"),
        log = OPJ(log_dir, "{trimmer}", "trim_and_dedup", "{lib}_{rep}.log"),
        err = OPJ(log_dir, "{trimmer}", "trim_and_dedup", "{lib}_{rep}.err"),
    run:
        shell_commands = """
THREADS="{threads}" {params.process_type}_trim_and_dedup.sh {wildcards.trimmer} {input} \\
    {params.adapter} {params.trim5} {params.trim3} \\
    {output.adapt} {output.noadapt} {log.trim} \\
    {output.nb_raw} {output.nb_adapt} {output.nb_adapt_deduped} \\
    {output.nb_noadapt} {output.nb_noadapt_deduped} 1> {log.log} 2> {log.err}
"""
        shell(shell_commands)


rule trim_only:
    """The adaptor is trimmed, then reads are treated in two groups depending
    on whether the adapter was found or not. For each group, the extra k-mers
    are removed at both ends."""
    input:
        rules.link_raw_data.output,
    output:
        noadapt = OPJ(data_dir, "trimmed_{trimmer}", "{lib}_{rep}_noadapt.fastq.gz"),
        adapt = OPJ(data_dir, "trimmed_{trimmer}", "{lib}_{rep}_adapt.fastq.gz"),
        nb_raw =  OPJ(data_dir, "trimmed_{trimmer}", "{lib}_{rep}_nb_raw.txt"),
        nb_adapt =  OPJ(data_dir, "trimmed_{trimmer}", "{lib}_{rep}_nb_adapt.txt"),
        nb_noadapt =  OPJ(data_dir, "trimmed_{trimmer}", "{lib}_{rep}_nb_noadapt.txt"),
    params:
        adapter = lambda wildcards : lib2adapt[wildcards.lib],
        process_type = "PRO-seq",
        trim5 = lambda wildcards : lib2UMI[wildcards.lib][0],
        trim3 = lambda wildcards : lib2UMI[wildcards.lib][1],
    threads: 8 # Actually, to avoid too much IO
    message:
        "Trimming adaptor from raw data using {wildcards.trimmer} and removing 5' and 3' random n-mers for {wildcards.lib}_{wildcards.rep}."
    benchmark:
        OPJ(data_dir, "trimmed_{trimmer}", "{lib}_{rep}_trim_benchmark.txt")
    log:
        trim = OPJ(data_dir, "trimmed_{trimmer}", "{lib}_{rep}_trim.log"),
        log = OPJ(log_dir, "{trimmer}", "trim_only", "{lib}_{rep}.log"),
        err = OPJ(log_dir, "{trimmer}", "trim_only", "{lib}_{rep}.err"),
    run:
        shell_commands = """
THREADS="{threads}" {params.process_type}_trim_only.sh {wildcards.trimmer} {input} \\
    {params.adapter} {params.trim5} {params.trim3} \\
    {output.adapt} {output.noadapt} {log.trim} \\
    {output.nb_raw} {output.nb_adapt} {output.nb_noadapt} \\
    1> {log.log} 2> {log.err}
"""
        shell(shell_commands)


def source_fastq(wildcards):
    """
    Determine the correct pre-processed fastq file depending on the pipeline
    configuration and the current wildcards.
    """
    if config.get("deduplicate", False):
        return OPJ(
            data_dir, f"trimmed_{wildcards.trimmer}",
            f"{wildcards.lib}_{wildcards.rep}_{wildcards.type}_deduped.fastq.gz")
    else:
        return OPJ(
            data_dir, f"trimmed_{wildcards.trimmer}",
            f"{wildcards.lib}_{wildcards.rep}_{wildcards.type}.fastq.gz")


# TODO: Do not deduplicate, or at least do not use the noadapt_deduped: The 3' UMI is not present.
rule map_on_genome:
    input:
        fastq = source_fastq,
    output:
        # sam files take a lot of space
        sam = temp(OPJ("{trimmer}", aligner, "mapped_C_elegans", "{lib}_{rep}_{type}_on_C_elegans.sam")),
        nomap_fastq = OPJ("{trimmer}", aligner, "not_mapped_C_elegans", "{lib}_{rep}_{type}_unmapped_on_C_elegans.fastq.gz"),
    params:
        aligner = aligner,
        index = genome_db,
        settings = "",
    message:
        "Mapping {wildcards.lib}_{wildcards.rep}_{wildcards.type} on %s genome." % genome
    log:
        log = OPJ(log_dir, "{trimmer}", "map_{type}_on_genome", "{lib}_{rep}.log"),
        err = OPJ(log_dir, "{trimmer}", "map_{type}_on_genome", "{lib}_{rep}.err"),
    threads:
        8
#    shell:
#        """
#        genome_dir="${{HOME}}/Genomes"
#        genome="C_elegans"
#        bowtie2_genome_db="${{genome_dir}}/${{genome}}/Caenorhabditis_elegans/Ensembl/WBcel235/Sequence/Bowtie2Index/genome"
#        cmd="bowtie2 --seed 123 -t --mm -x ${{bowtie2_genome_db}} -U {input.fastq} --no-unal --un-gz {output.nomap_fastq} -S {output.sam}"
#        echo ${{cmd}} 1> {log.log} 2> {log.err}
#        eval ${{cmd}} 1>> {log.log} 2>> {log.err}
#        """
    wrapper:
        f"file://{wrappers_dir}/map_on_genome"


rule sam2indexedbam:
    input:
        sam = OPJ("{trimmer}", aligner, "mapped_C_elegans", "{lib}_{rep}_{type}_on_C_elegans.sam"),
    output:
        sorted_bam = OPJ("{trimmer}", aligner, "mapped_C_elegans", "{lib}_{rep}_{type}_on_C_elegans_sorted.bam"),
        index = OPJ("{trimmer}", aligner, "mapped_C_elegans", "{lib}_{rep}_{type}_on_C_elegans_sorted.bam.bai"),
    message:
        "Sorting and indexing sam file for {wildcards.lib}_{wildcards.rep}_{wildcards.type}."
    log:
        log = OPJ(log_dir, "{trimmer}", "sam2indexedbam", "{lib}_{rep}_{type}.log"),
        err = OPJ(log_dir, "{trimmer}", "sam2indexedbam", "{lib}_{rep}_{type}.err"),
    threads:
        8
    resources:
        mem_mb=4100
    wrapper:
        f"file://{wrappers_dir}/sam2indexedbam"


rule fuse_bams:
    """This rule fuses the two sorted bam files corresponding to the mapping
    of the reads containing the adaptor or not."""
    input:
        noadapt_sorted_bam = OPJ("{trimmer}", aligner, "mapped_C_elegans", "{lib}_{rep}_noadapt_on_C_elegans_sorted.bam"),
        adapt_sorted_bam = OPJ("{trimmer}", aligner, "mapped_C_elegans", "{lib}_{rep}_adapt_on_C_elegans_sorted.bam"),
    output:
        sorted_bam = OPJ("{trimmer}", aligner, "mapped_C_elegans", "{lib}_{rep}_on_C_elegans_sorted.bam"),
        bai = OPJ("{trimmer}", aligner, "mapped_C_elegans", "{lib}_{rep}_on_C_elegans_sorted.bam.bai"),
    message:
        "Fusing sorted bam files for {wildcards.lib}_{wildcards.rep}"
    log:
        log = OPJ(log_dir, "{trimmer}", "fuse_bams", "{lib}_{rep}.log"),
        err = OPJ(log_dir, "{trimmer}", "fuse_bams", "{lib}_{rep}.err"),
    shell:
        """
        samtools merge -c {output.sorted_bam} {input.noadapt_sorted_bam} {input.adapt_sorted_bam} 1> {log.log} 2> {log.err}
        indexed=""
        while [ ! ${{indexed}} ]
        do
            samtools index {output.sorted_bam} && indexed="OK"
            if [ ! ${{indexed}} ]
            then
                rm -f {output.bai}
                echo "Indexing failed. Retrying" 1>&2
            fi
        done 1>> {log.log} 2>> {log.err}
        """


rule compute_coverage:
    input:
        sorted_bam = rules.fuse_bams.output.sorted_bam,
    output:
        coverage = OPJ("{trimmer}", aligner, "mapped_C_elegans", "{lib}_{rep}_on_C_elegans_coverage.txt"),
    params:
        genomelen = genomelen,
    shell:
        """
        bases=$(samtools depth {input.sorted_bam} | awk '{{sum += $3}} END {{print sum}}') || error_exit "samtools depth failed"
        python3 -c "print(${{bases}} / {params.genomelen})" > {output.coverage}
        """

rule check_last_base:
    input:
        adapt_sorted_bam = OPJ("{trimmer}", aligner, "mapped_C_elegans", "{lib}_{rep}_adapt_on_C_elegans_sorted.bam"),
        adapt_index = OPJ("{trimmer}", aligner, "mapped_C_elegans", "{lib}_{rep}_adapt_on_C_elegans_sorted.bam.bai"),
    output:
        OPJ("{trimmer}", aligner, "mapped_C_elegans", "{lib}_{rep}_adapt_on_C_elegans_last_bases.txt")
    message:
        "Computing last base proportions for {wildcards.lib}_{wildcards.rep} (mapped reads from which the adaptor had been removed)."
    log:
        log = OPJ(log_dir, "{trimmer}", "check_last_base", "{lib}_{rep}.log"),
        err = OPJ(log_dir, "{trimmer}", "check_last_base", "{lib}_{rep}.err"),
    run:
        base_counts = defaultdict(Counter)
        with pysam.AlignmentFile(input.adapt_sorted_bam) as samfile:
            for ali in samfile.fetch():
                seq = ali.seq
                # To avoid errors when last base was erroneous:
                #seq = ali.get_reference_sequence()
                if ali.is_reverse:
                    base_counts[len(seq)][COMPL[seq[0].upper()]] += 1
                else:
                    base_counts[len(seq)][seq[-1].upper()] += 1
        with open(output[0], "w") as output_file:
            print("#length\tnb_reads\tA\tC\tG\tT\tN", file=output_file)
            for length, counter in sorted(base_counts.items()):
                nb_reads_this_length = sum(counter.values())
                print(length, nb_reads_this_length, *[str(counter[letter] / nb_reads_this_length) for letter in "ACGTN"], sep="\t", file=output_file)


# TODO: use Python to make the plot
# This may remove dependency on R docopt
rule plot_last_base:
    input:
        OPJ("{trimmer}", aligner, "mapped_C_elegans", "{lib}_{rep}_adapt_on_C_elegans_last_bases.txt")
    output:
        OPJ("{trimmer}", "figures", aligner, "{lib}_{rep}", "adapt_on_C_elegans_last_bases.pdf")
    params:
        title = lambda wildcards : "\"last base frequencies for %s_%s_%s\"" % (wildcards.trimmer, wildcards.lib, wildcards.rep)
    message:
        "Plotting last base proportions for {wildcards.lib}_{wildcards.rep} (mapped reads from which the adaptor had been removed)."
    log:
        OPJ(log_dir, "{trimmer}", "plot_last_base", "{lib}_{rep}.log")
    shell:
        """
        plot_last_base.R -i {input} -o {output} -t {params.title}
        """


def htseq_orientation2stranded(wildcards):
    orientation = wildcards.orientation
    if orientation == "fwd":
        if LIB_TYPE[-2:] == "SF":
            return "yes"
        elif LIB_TYPE[-2:] == "SR":
            return "reverse"
        else:
            raise ValueError(f"{LIB_TYPE} library type not compatible with strand-aware read counting.")
    elif orientation == "rev":
        if LIB_TYPE[-2:] == "SF":
            return "reverse"
        elif LIB_TYPE[-2:] == "SR":
            return "yes"
        else:
            raise ValueError(f"{LIB_TYPE} library type not compatible with strand-aware read counting.")
    elif orientation == "all":
        return "no"
    else:
        exit("Orientation is to be among \"fwd\", \"rev\" and \"all\".")

def biotype2annot(wildcards):
    #return "/pasteur/entites/Mhe/Genomes/C_elegans/Caenorhabditis_elegans/Ensembl/WBcel235/Annotation/Genes/%s.gtf" % wildcards.biotype
    if wildcards.biotype.endswith("_rmsk_families"):
        biotype = wildcards.biotype[:-9]
    else:
        biotype = wildcards.biotype
    return OPJ(annot_dir, f"{biotype}.gtf")
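# Examples:
#   "DNA_transposons_rmsk_families" -> OPJ(annot_dir, "DNA_transposons_rmsk.gtf")
#   "protein_coding"                -> OPJ(annot_dir, "protein_coding.gtf")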


rule htseq_count_reads:
    input:
        sorted_bam = OPJ("{trimmer}", aligner, "mapped_C_elegans", "{lib}_{rep}_on_C_elegans_sorted.bam"),
        bai = OPJ("{trimmer}", aligner, "mapped_C_elegans", "{lib}_{rep}_on_C_elegans_sorted.bam.bai"),
    output:
        counts = OPJ("{trimmer}", aligner, "mapped_C_elegans", "htseq_count", "{lib}_{rep}_on_C_elegans", "{biotype}_{orientation}_counts.txt"),
        counts_converted = OPJ("{trimmer}", aligner, "mapped_C_elegans", "htseq_count", "{lib}_{rep}_on_C_elegans", "{biotype}_{orientation}_counts_gene_names.txt"),
    params:
        stranded = htseq_orientation2stranded,
        mode = "union",
        annot = biotype2annot,
    message:
        "Counting {wildcards.orientation} {wildcards.biotype} reads for {wildcards.lib}_{wildcards.rep} with htseq-count."
    benchmark:
        OPJ(log_dir, "{trimmer}", "htseq_count_reads", "{lib}_{rep}_{biotype}_{orientation}_benchmark.txt")
    log:
        log = OPJ(log_dir, "{trimmer}", "htseq_count_reads", "{lib}_{rep}_{biotype}_{orientation}.log"),
        err = OPJ(log_dir, "{trimmer}", "htseq_count_reads", "{lib}_{rep}_{biotype}_{orientation}.err")
    wrapper:
        f"file://{wrappers_dir}/htseq_count_reads"


def parse_htseq_counts(counts_filename):
    with open(counts_filename) as counts_file:
        for line in counts_file:
            (gene, count) = line.strip().split()
            if gene.startswith("__"):
                return
            yield (gene, int(count))
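# htseq-count output is a two-column table (feature id, count), terminated by
# special counters whose names start with "__" (e.g. "__no_feature",
# "__ambiguous"); the generator above stops at the first such line.
# Example lines (hypothetical counts):
#   WBGene00000001    42
#   WBGene00000002    0
#   __no_feature      1234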


rule feature_count_reads:
    input:
        sorted_bam = OPJ(
            "{trimmer}", aligner, "mapped_C_elegans",
            "{lib}_{rep}_on_C_elegans_sorted.bam"),
        bai = OPJ(
            "{trimmer}", aligner, "mapped_C_elegans",
            "{lib}_{rep}_on_C_elegans_sorted.bam.bai"),
        # TODO: Why does the following fail?
        #sorted_bam = rules.fuse_bams.output.sorted_bam,
        #index = rules.fuse_bams.output.index,
    output:
        counts = OPJ(
            "{trimmer}", aligner, "mapped_C_elegans", "feature_count",
            "{lib}_{rep}_on_C_elegans", "{biotype}_{orientation}_counts.txt"),
        counts_converted = OPJ(
            "{trimmer}", aligner, "mapped_C_elegans", "feature_count",
            "{lib}_{rep}_on_C_elegans", "{biotype}_{orientation}_counts_gene_names.txt"),
    params:
        stranded = feature_orientation2stranded(LIB_TYPE),
        annot = biotype2annot,
        # pickled dictionary that associates gene ids to gene names
        converter = genome_dict["converter"]
    message:
        "Counting {wildcards.orientation} {wildcards.biotype} reads for {wildcards.lib}_{wildcards.rep} with featureCounts."
    log:
        log = OPJ(log_dir, "{trimmer}", "feature_count_reads", "{orientation}_{biotype}", "{lib}_{rep}.log"),
        err = OPJ(log_dir, "{trimmer}", "feature_count_reads", "{orientation}_{biotype}", "{lib}_{rep}.err")
    shell:
        """
        tmpdir=$(mktemp -dt "feature_{wildcards.lib}_{wildcards.rep}_{wildcards.biotype}_{wildcards.orientation}.XXXXXXXXXX")
        cmd="featureCounts -a {params.annot} -o {output.counts} -t transcript -g "gene_id" -O -M --primary -s {params.stranded} --fracOverlap 0 --tmpDir ${{tmpdir}} {input.sorted_bam}"
        featureCounts -v 2> {log.log}
        echo ${{cmd}} 1>> {log.log}
        eval ${{cmd}} 1>> {log.log} 2> {log.err} || error_exit "featureCounts failed"
        rm -rf ${{tmpdir}}
        cat {output.counts} | wormid2name > {output.counts_converted}
        # cat {output.counts} | id2name.py {params.converter} > {output.counts_converted}
        """


def parse_feature_counts(counts_filename):
    with open(counts_filename) as counts_file:
        for line in counts_file:
            # skip comments
            if line[0] == "#":
                continue
            fields = line.strip().split()
            # skip header
            if fields[:6] == ["Geneid", "Chr", "Start", "End", "Strand", "Length"]:
                continue
            gene, count = fields[0], int(fields[6])
            yield (gene, count)
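# featureCounts output starts with a "#" comment line (the command line used),
# then a header line ("Geneid  Chr  Start  End  Strand  Length  <bam>"), then
# one row per feature whose 7th field is the count for the single bam file
# given as input; both non-data lines are skipped by the generator above.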


def same_gene_order(od1, od2):
    """Returns True if the keys of ordered dictionaries *od1* and *od2* are the same, in the same order."""
    if len(od1) != len(od2):
        return False
    return all(starmap(eq, zip(od1.keys(), od2.keys())))
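# Examples:
# same_gene_order(OrderedDict([("a", 1), ("b", 2)]),
#                 OrderedDict([("a", 10), ("b", 20)]))  # -> True
# same_gene_order(OrderedDict([("a", 1)]), OrderedDict([("b", 1)]))  # -> False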


def plot_scatterplot(outfile, data, data_groups, group2colour):
    #axis = plt.gca()
    for (i, (group, colour)) in enumerate(group2colour.items()):
        plt.scatter(
            data[data_groups==i, 0],
            data[data_groups==i, 1],
            color=colour, lw=2, label=texscape(group))
    plt.xlabel("PC1")
    plt.ylabel("PC2")
    plt.legend(loc="best", shadow=False, scatterpoints=1)
    plt.tight_layout()
    plt.savefig(outfile, format=outfile.name.split(".")[-1])
    #plt.savefig(outfile, format="pdf")
    plt.cla()


rule do_PCA:
    input:
        expand(OPJ("{{trimmer}}", aligner, "mapped_C_elegans", "{{counter}}", "{lib}_{rep}_on_C_elegans", "{{biotype}}_{{orientation}}_counts.txt"), filtered_product, lib=LIBS, rep=REPS),
    output:
        #OPJ(aligner, "mapped_C_elegans", "htseq_count", "summaries", "{biotype}_{orientation}_PCA.pdf"),
        #OPJ(aligner, "mapped_C_elegans", "htseq_count", "summaries", "{biotype}_{orientation}_PCA.png"),
        OPJ("{trimmer}", "figures", aligner, "{counter}", "{biotype}_{orientation}_PCA.pdf"),
    message:
        "Summarizing counts for {wildcards.biotype}_{wildcards.orientation}"
    threads: 12  # trying to avoid TimeoutError and "LOCKERROR: matplotlib is trying to acquire the lock [...]"
    log:
        OPJ(log_dir, "{trimmer}", "{counter}", "do_PCA_{biotype}_{orientation}.log")
    run:
        if wildcards.counter == "htseq_count":
            counts_parser = parse_htseq_counts
        elif wildcards.counter == "feature_count":
            counts_parser = parse_feature_counts
        else:
            raise NotImplementedError("%s is not yet handled." % aligner)
        # We need the order to be fixed for the zip to be meaningful
        counts = OrderedDict([])
        nb_libs = len(LIBS)
        nb_reps = len(REPS)
        counts_array = np.empty((nb_libs * nb_reps,), dtype=object)
        #lib_categories = np.empty([len(LIBS) * len(REPS)], dtype=np.uint32)
        lib_categories = np.fromiter(chain(*[[i] * nb_reps for i in range(nb_libs)]), dtype=np.uint32)
        for i, (lib, rep) in enumerate(product(LIBS, REPS)):
            counts_filename = OPJ(
                wildcards.trimmer, aligner, "mapped_C_elegans", wildcards.counter,
                "%s_%s_on_C_elegans" % (lib, rep), "%s_%s_counts.txt" % (wildcards.biotype, wildcards.orientation))
            #print("Reading", counts_filename)
            counts[(lib, rep)] = OrderedDict(counts_parser(counts_filename))
            counts_array[i] = np.fromiter(counts[(lib, rep)].values(), np.uint32)
            #lib_categories[i] = i // len(REPS)
        # zipping the values with a shifted version and pairwise comparing the gene orders
        assert all(starmap(
            same_gene_order,
            zip(list(counts.values()),
                list(counts.values())[1:]))), "All counts files should have the same genes in the same order."
        #counts_array = np.array([library_array for library_array in counts_array])
        # Faster (http://stackoverflow.com/a/40402682/1878788)
        counts_array = np.concatenate(counts_array).reshape(len(counts_array), -1)
        libs_pca = PCA(n_components=2)
        libs_fitting = libs_pca.fit(counts_array)
        libs_transformed = libs_fitting.transform(counts_array)
        genes_pca = PCA(n_components=2)
        genes_fitting = genes_pca.fit(counts_array.T)
        genes_transformed = genes_fitting.transform(counts_array.T)
        # http://stackoverflow.com/questions/40425036/how-to-extract-the-extreme-two-colours-from-a-matplotlib-colormap
        #TODO: make this configurable
        colormap = "BrBG"
        cmap = plt.cm.get_cmap(colormap)
        # Extract rgb coordinates
        left_rgb = cmap(0)[:-1]
        right_rgb = cmap(cmap.N)[:-1]
        # Convert to husl and take the hue (first element)
        left_hue = husl.rgb_to_husl(*left_rgb)[0]
        right_hue = husl.rgb_to_husl(*right_rgb)[0]
        # Create a dark divergent palette
        palette = sns.diverging_palette(left_hue, right_hue, n=(2 * (nb_libs // 2)) + 1, center="dark")
        lib2colour = OrderedDict(zip(LIBS, [*palette[nb_libs // 2:(nb_libs+1) // 2], *palette[:nb_libs // 2], *palette[1 + nb_libs // 2:]]))
        with open(output[0], "wb") as outfile:
            plot_scatterplot(outfile, libs_transformed, lib_categories, lib2colour)

#def plot_cluster(counts_dir, biotype, orientation, counts_files):
#    libname_finder = re.compile("%s/(.+)_on_C_elegans_%s_%s_counts.txt" % (counts_dir, biotype, orientation))
#    libnames = [libname_finder.match(fname).groups()[0] for fname in counts_files]
#    d = pd.read_csv(filename, sep="\t", header=None, index_col=0)


rule summarize_counts:
    """For a given library, write a summary of the read counts for the various biotypes."""
    input:
        biotype_counts_files = expand(OPJ("{{trimmer}}", aligner, "mapped_C_elegans", "{{counter}}", "{{lib}}_{{rep}}_on_C_elegans", "{biotype}_{{orientation}}_counts.txt"), biotype=COUNT_BIOTYPES),
    output:
        summary = OPJ("{trimmer}", aligner, "mapped_C_elegans", "{counter}", "summaries", "{lib}_{rep}_on_C_elegans_{orientation}_counts.txt")
    run:
        if wildcards.counter == "htseq_count":
            sum_counter = sum_htseq_counts
        elif wildcards.counter == "feature_count":
            sum_counter = sum_feature_counts
        else:
            raise NotImplementedError(f"{wildcards.counter} not handled (yet?)")
        with open(output.summary, "w") as summary_file:
            header = "\t".join(COUNT_BIOTYPES)
            #summary_file.write("#biotypes\t%s\n" % header)
            summary_file.write("%s\n" % header)
            sums = "\t".join((str(sum_counter(counts_file)) for counts_file in input.biotype_counts_files))
            #summary_file.write("%s_%s_%s\t%s\n" % (wildcards.lib, wildcards.rep, wildcards.orientation, sums))
            summary_file.write("%s\n" % sums)

rule gather_read_counts_summaries:
    input:
        summary_tables = expand(OPJ("{{trimmer}}", aligner, "mapped_C_elegans", "{{counter}}", "summaries", "{lib}_{rep}_on_C_elegans_{{orientation}}_counts.txt"), filtered_product, lib=LIBS, rep=REPS),
    output:
        summary_table = OPJ("{trimmer}", aligner, "mapped_C_elegans", "{counter}", "summaries", "all_on_C_elegans_{orientation}_counts.txt"),
    run:
        summary_files = (OPJ(
            wildcards.trimmer,
            aligner,
            "mapped_C_elegans",
            wildcards.counter,
            "summaries",
            "{name}_on_C_elegans_{orientation}_counts.txt".format(
                name=cond_name,
                orientation=wildcards.orientation)) for cond_name in COND_NAMES)
        summaries = pd.concat((pd.read_table(summary_file).T.astype(int) for summary_file in summary_files), axis=1)
        summaries.columns = COND_NAMES
        summaries.to_csv(output.summary_table, sep="\t")


rule gather_counts:
    """For a given biotype, gather counts from all libraries in one table."""
    input:
        counts_tables = expand(OPJ("{{trimmer}}", aligner, "mapped_C_elegans", "{{counter}}", "{lib}_{rep}_on_C_elegans", "{{biotype}}_{{orientation}}_counts.txt"), filtered_product, lib=LIBS, rep=REPS),
    output:
        counts_table = OPJ("{trimmer}", aligner, "mapped_C_elegans", "{counter}", "all_on_C_elegans", "{biotype}_{orientation}_counts.txt"),
    wildcard_constraints:
        # Avoid ambiguity with join_all_counts
        biotype = "|".join(COUNT_BIOTYPES)
    run:
        # Gathering the counts data
        ############################
        counts_files = (OPJ(
            wildcards.trimmer,
            aligner,
            "mapped_C_elegans",
            wildcards.counter,
            "{name}_on_C_elegans".format(name=cond_name),
            "{biotype}_{orientation}_counts.txt".format(biotype=wildcards.biotype, orientation=wildcards.orientation)) for cond_name in COND_NAMES)
        if wildcards.counter == "htseq_count":
            counts_data = pd.concat(
                map(read_htseq_counts, counts_files),
                axis=1).fillna(0).astype(int)
        elif wildcards.counter == "intersect_count":
            counts_data = pd.concat(
                map(read_intersect_counts, counts_files),
                axis=1).fillna(0).astype(int)
        elif wildcards.counter == "feature_count":
            counts_data = pd.concat(
                map(read_feature_counts, counts_files),
                axis=1).fillna(0).astype(int)
        else:
            raise NotImplementedError(f"{wilcards.counter} not handled (yet?)")
        counts_data.columns = COND_NAMES
        # Simple_repeat|Simple_repeat|(TTTTTTG)n:1
        # Simple_repeat|Simple_repeat|(TTTTTTG)n:2
        # Simple_repeat|Simple_repeat|(TTTTTTG)n:3
        # Simple_repeat|Simple_repeat|(TTTTTTG)n:4
        # -> Simple_repeat|Simple_repeat|(TTTTTTG)n
        if wildcards.biotype.endswith("_rmsk_families"):
            counts_data = sum_by_family(counts_data)
        counts_data.index.names = ["gene"]
        counts_data.to_csv(output.counts_table, sep="\t")
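# Note on sum_by_family (assumed behaviour, based on the example above): rows
# whose index differs only by the trailing ":<n>" suffix are expected to be
# collapsed into a single row whose counts are the per-column sums; see
# libworkflows.sum_by_family for the actual implementation.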


@wc_applied
def source_counts_to_join(wildcards):
    """
    Determines which elementary biotype counts files should be joined to make the desired "joined" biotype.
    """
    return expand(
        OPJ("{{trimmer}}", aligner, "mapped_C_elegans",
            "{{counter}}", "all_on_C_elegans",
            "{biotype}_{{orientation}}_counts.txt"),
        biotype=BIOTYPES_TO_JOIN[wildcards.biotype])


rule join_all_counts:
    """concat counts for all biotypes into all"""
    input:
        counts_tables = source_counts_to_join,
        #counts_tables = expand(OPJ("{{trimmer}}", aligner, "mapped_C_elegans", "{{counter}}", "all_on_C_elegans", "{biotype}_{{orientation}}_counts.txt"), biotype=[bty for bty in COUNT_BIOTYPES if bty[-9:] != "_families"]),
        # counts_tables = expand(
        #     OPJ("{{trimmer}}", aligner, "mapped_C_elegans",
        #         "{{counter}}", "all_on_C_elegans",
        #         "{biotype}_{{orientation}}_counts.txt"),
        #     # We only count "protein_coding", not "protein_codin_{5UTR,CDS,3UTR}"
        #     biotype=[b for b in COUNT_BIOTYPES if not b.startswith("protein_coding_")]),
    output:
        counts_table = OPJ("{trimmer}", aligner, "mapped_C_elegans", "{counter}", "all_on_C_elegans", "{biotype}_{orientation}_counts.txt"),
    wildcard_constraints:
        biotype = "|".join(JOINED_BIOTYPES)
    run:
        counts_data = pd.concat((pd.read_table(table, index_col="gene") for table in input.counts_tables))
        assert len(counts_data.index.unique()) == len(counts_data.index), "Some genes appear several times in the counts table."
        counts_data.index.names = ["gene"]
        counts_data.to_csv(output.counts_table, sep="\t")


@wc_applied
def source_counts(wildcards):
    """Determines from which rule the gathered small counts should be sourced."""
    if wildcards.biotype in JOINED_BIOTYPES:
        return rules.join_all_counts.output.counts_table
    else:
        # "Directly" from the counts gathered across libraries
        return rules.gather_counts.output.counts_table


rule compute_RPK:
    """For a given biotype, compute the corresponding RPK value (reads per kilobase)."""
    input:
        counts_data = source_counts,
        #counts_data = rules.gather_counts.output.counts_table,
        #counts_table = OPJ("{trimmer}", aligner, "mapped_C_elegans", "{counter}",
        #    "all_on_C_elegans", "{biotype}_{orientation}_counts.txt"),
    output:
        rpk_file = OPJ("{trimmer}", aligner, "mapped_C_elegans", "{counter}",
            "all_on_C_elegans", "{biotype}_{orientation}_RPK.txt"),
    params:
        feature_lengths_file = OPJ(annot_dir, "union_exon_lengths.txt"),
    # run:
    #     counts_data = pd.read_table(input.counts_data, index_col="gene")
    #     feature_lengths = pd.read_table(params.feature_lengths_file, index_col="gene")
    #     common = counts_data.index.intersection(feature_lengths.index)
    #     rpk = 1000 * counts_data.loc[common].div(feature_lengths.loc[common]["union_exon_len"], axis="index")
    #     rpk.to_csv(output.rpk_file, sep="\t")
    wrapper:
        f"file://{wrappers_dir}/compute_RPK"


rule compute_sum_million_RPK:
    input:
        rpk_file = rules.compute_RPK.output.rpk_file,
    output:
        sum_rpk_file = OPJ("{trimmer}", aligner, "mapped_C_elegans", "{counter}",
            "all_on_C_elegans", "{biotype}_{orientation}_sum_million_RPK.txt"),
    run:
        sum_rpk = pd.read_table(
            input.rpk_file,
            index_col=0).sum()
        (sum_rpk / 1000000).to_csv(output.sum_rpk_file, sep="\t")


rule compute_TPM:
    """For a given biotype, compute the corresponding TPM value (reads per kilobase per million mappers)."""
    input:
        rpk_file = rules.compute_RPK.output.rpk_file
    output:
        tpm_file = OPJ("{trimmer}", aligner, "mapped_C_elegans", "{counter}",
            "all_on_C_elegans", "{biotype}_{orientation}_TPM.txt"),
    # The sum must be done over all counted features
    wildcard_constraints:
        biotype = "|".join(["alltypes"])
    # run:
    #     rpk = pd.read_table(input.rpk_file, index_col="gene")
    #     tpm = 1000000 * rpk / rpk.sum()
    #     tpm.to_csv(output.tpm_file, sep="\t")
    wrapper:
        f"file://{wrappers_dir}/compute_TPM"


@wc_applied
def source_quantif(wildcards):
    """Determines from which rule the gathered counts should be sourced."""
    if wildcards.quantif_type == "counts":
        return source_counts(wildcards)
        #return rules.gather_counts.output.counts_table
    elif wildcards.quantif_type == "RPK":
        return rules.compute_RPK.output.rpk_file
    elif wildcards.quantif_type == "TPM":
        return rules.compute_TPM.output.tpm_file
    else:
        raise NotImplementedError("%s is not yet handeled." % wildcards.quantif_type)


rule compute_median_ratio_to_pseudo_ref_size_factors:
    input:
        counts_table = source_counts,
        #counts_table = rules.gather_counts.output.counts_table,
    output:
        median_ratios_file = OPJ("{trimmer}", aligner, "mapped_C_elegans", "{counter}", "all_on_C_elegans", "{biotype}_{orientation}_median_ratios_to_pseudo_ref.txt"),
    run:
        counts_data = pd.read_table(
            input.counts_table,
            index_col=0,
            na_filter=False)
        # http://stackoverflow.com/a/21320592/1878788
        #median_ratios = pd.DataFrame(median_ratio_to_pseudo_ref_size_factors(counts_data)).T
        #median_ratios.index.names = ["median_ratio_to_pseudo_ref"]
        # Easier to grep when not transposed, actually:
        median_ratios = median_ratio_to_pseudo_ref_size_factors(counts_data)
        median_ratios.to_csv(output.median_ratios_file, sep="\t")


# Note that bamCoverage inverses strands:
# https://github.com/fidelram/deepTools/issues/494
def bamcoverage_filter(wildcards):
    if wildcards.orientation == "fwd":
        return "--filterRNAstrand reverse"
    elif wildcards.orientation == "fwd":
        return "--filterRNAstrand forward"
    else:
        return ""


def source_normalizer(wildcards):
    if wildcards.norm_type == "median_ratio_to_pseudo_ref":
        return OPJ(
            f"{wildcards.trimmer}", aligner, "mapped_C_elegans", COUNTERS[0], "all_on_C_elegans",
            "protein_coding_fwd_median_ratios_to_pseudo_ref.txt"),
    elif wildcards.norm_type in COUNT_BIOTYPES:
        return OPJ(
            f"{wildcards.trimmer}", aligner, "mapped_C_elegans", COUNTERS[0], "all_on_C_elegans",
            f"{wildcards.norm_type}_fwd_sum_million_RPK.txt"),
    else:
        raise NotImplementedError(f"{wildcards.norm_type} normalization not implemented yet.")


# Warning: the normalization is based on counts obtained with the first counter (COUNTERS[0]) only.
rule make_normalized_bigwig:
    input:
        bam = rules.fuse_bams.output.sorted_bam,
        # TODO: use sourcing function based on norm_type
        size_factor_file = source_normalizer,
        #size_factor_file = rules.compute_coverage.output.coverage
        #median_ratios_file = OPJ("{trimmer}", aligner, "mapped_C_elegans", COUNTERS[0], "all_on_C_elegans", "protein_coding_fwd_median_ratios_to_pseudo_ref.txt"),
        # TODO: compute this
        #scale_factor_file = OPJ(aligner, "mapped_C_elegans", "annotation", "all_%s_on_C_elegans" % size_selected, "pisimi_median_ratios_to_pseudo_ref.txt"),
    output:
        bigwig_norm = OPJ("{trimmer}", aligner, "mapped_C_elegans", "{lib}_{rep}_on_C_elegans_by_{norm_type}_{orientation}.bw"),
        #bigwig = OPJ(aligner, "mapped_C_elegans", "{lib}_{rep}_on_C_elegans_{orientation}.bw"),
    #params:
    #    orient_filter = bamcoverage_filter,
    threads: 4  # to limit memory usage, actually
    benchmark:
        OPJ(log_dir, "{trimmer}", "make_normalized_bigwig", "{lib}_{rep}_by_{norm_type}_{orientation}_benchmark.txt")
    params:
        genome_binned = genome_binned,
    log:
        log = OPJ(log_dir, "{trimmer}", "make_normalized_bigwig", "{lib}_{rep}_by_{norm_type}_{orientation}.log"),
        err = OPJ(log_dir, "{trimmer}", "make_normalized_bigwig", "{lib}_{rep}_by_{norm_type}_{orientation}.err"),
    shell:
        """
        bam2bigwig.sh {input.bam} {params.genome_binned} \\
            {wildcards.lib}_{wildcards.rep} {wildcards.orientation} %s \\
            {input.size_factor_file} {output.bigwig_norm} \\
            > {log.log} 2> {log.err} \\
            || error_exit "bam2bigwig.sh failed"
        """ % LIB_TYPE[-1]
        #"""
        #scale=$(python -c "print 1.0 / ${{size}}")
        #bamCoverage -b {input.bam} {params.orient_filter} \\
        #    -of=bigwig -bs 10 -p=1 \\
        #    --scaleFactor ${{scale}} -o {output.bigwig_norm} \\
        #    1>> {log.make_bigwig_log} 2>> {log.make_bigwig_err} \\
        #    || error_exit "bamCoverage failed"
        #bamCoverage -b {input.bam} --skipNAs {params.orient_filter} \\
        #    -of=bigwig -bs 10 -p=1 \\
        #    -o {output.bigwig} \\
        #    1>> {log.make_bigwig_log} 2>> {log.make_bigwig_err} \\
        #    || error_exit "bamCoverage failed"
        #"""


rule make_bigwig:
    input:
        bam = rules.fuse_bams.output.sorted_bam,
        # TODO: use sourcing function based on norm_type
        #size_factor_file = rules.compute_coverage.output.coverage
        median_ratios_file = OPJ("{trimmer}", aligner, "mapped_C_elegans", COUNTERS[0], "all_on_C_elegans", "protein_coding_fwd_median_ratios_to_pseudo_ref.txt"),
        # TODO: compute this
        #scale_factor_file = OPJ(aligner, "mapped_C_elegans", "annotation", "all_%s_on_C_elegans" % size_selected, "pisimi_median_ratios_to_pseudo_ref.txt"),
    output:
        bigwig_norm = OPJ("{trimmer}", aligner, "mapped_C_elegans", "{lib}_{rep}_on_C_elegans_by_{norm_type}_{orientation}_bamCoverage.bw"),
        #bigwig = OPJ(aligner, "mapped_C_elegans", "{lib}_{rep}_on_C_elegans_{orientation}.bw"),
    params:
        orient_filter = bamcoverage_filter,
    threads: 12  # to limit memory usage, actually
    benchmark:
        OPJ(log_dir, "{trimmer}", "make_bigwig", "{lib}_{rep}_by_{norm_type}_{orientation}_benchmark.txt")
    log:
        log = OPJ(log_dir, "{trimmer}", "make_bigwig", "{lib}_{rep}_by_{norm_type}_{orientation}.log"),
        err = OPJ(log_dir, "{trimmer}", "make_bigwig", "{lib}_{rep}_by_{norm_type}_{orientation}.err"),
    run:
        scale = 1 / float(pd.read_table(
            input.median_ratios_file, index_col=0, header=None).loc[
                f"{wildcards.lib}_{wildcards.rep}"])
        assert scale > 0
        # TODO: make this a function of deeptools version
        no_reads = """Error: The generated bedGraphFile was empty. Please adjust
your deepTools settings and check your input files.
"""
#        no_reads = """[bwClose] There was an error while finishing writing a bigWig file! The output is likely truncated.
#"""
        try:
            shell("""
                cmd="bamCoverage -b {input.bam} {params.orient_filter} \\
                    -of=bigwig -bs 10 -p={threads} \\
                    --scaleFactor %f -o {output.bigwig_norm} \\
                    1>> {log.log} 2>> {log.err}"
                > {log.err}
                echo ${{cmd}} > {log.log}
                eval ${{cmd}} || error_exit "bamCoverage failed"
            """ % scale)
        except CalledProcessError as e:
            if last_lines(log.err, 2) == no_reads:
                with open(output.bigwig_norm, "w") as bwfile:
                    bwfile.write("")
            else:
                raise


def source_bigwigs_for_merge(wildcards):
    # Filter replicates using the same (wildcard_name, value) frozensets as *forbidden*.
    return [
        OPJ(wildcards.trimmer, aligner, "mapped_C_elegans",
            f"{wildcards.lib}_{rep}_on_C_elegans_by_{wildcards.norm_type}_{wildcards.orientation}.bw")
        for rep in REPS
        if frozenset({("lib", wildcards.lib), ("rep", rep)}) not in forbidden]
    #return expand(OPJ(aligner, "mapped_C_elegans", "{lib}_{rep}_on_C_elegans_norm_{orientation}.bw"), lib=[wildcards.lib], rep=[rep for rep in REPS if frozenset((wildcards.lib, rep)) not in forbidden], orientation=[wildcards.orientation])


rule merge_bigwig_reps:
    """This rule merges bigwig files by computing a mean across replicates."""
    input:
        source_bigwigs_for_merge,
        #expand(OPJ(aligner, "mapped_C_elegans", "{{lib}}_{rep}_on_C_elegans_norm_{{orientation}}.bw"), rep=[rep for rep in REPS if frozenset((wildcards.lib, rep)) not in forbidden]),
    output:
        bw = OPJ("{trimmer}", aligner, "mapped_C_elegans", "{lib}_mean_on_C_elegans_by_{norm_type}_{orientation}.bw"),
    log:
        warnings = OPJ(log_dir, "{trimmer}", "merge_bigwig_reps", "{lib}_mean_on_C_elegans_by_{norm_type}_{orientation}.warnings"),
    threads: 2  # to limit memory usage, actually
    run:
        with warn_context(log.warnings) as warn:
            try:
                bws = [pyBigWig.open(bw_filename) for bw_filename in input]
                #for bw_filename in input:
                #    bws.append(pyBigWig.open(bw_filename))
            except RuntimeError as e:
                warn(str(e))
                warn("Generating empty file.\n")
                # Make the file empty
                open(output.bw, "w").close()
            else:
                bw_out = pyBigWig.open(output.bw, "w")
                bw_out.addHeader(list(chrom_sizes.items()))
                for (chrom, chrom_len) in chrom_sizes.items():
                    try:
                        assert all([bw.chroms()[chrom] == chrom_len for bw in bws])
                    except KeyError as e:
                        warn(str(e))
                        warn(f"Chromosome {chrom} might be missing from one of the input files.\n")
                        for filename, bw in zip(input, bws):
                            msg = " ".join([f"{filename}:", *list(bw.chroms().keys())])
                            warn(f"{msg}:\n")
                        #raise
                        warn(f"The bigwig files without {chrom} will be skipped.\n")
                    to_use = [bw for bw in bws if chrom in bw.chroms()]
                    if to_use:
                        means = np.nanmean(np.vstack([bw.values(chrom, 0, chrom_len) for bw in to_use]), axis=0)
                    else:
                        means = np.zeros(chrom_len)
                    # bin size is 10
                    bw_out.addEntries(chrom, 0, values=np.nan_to_num(means[0::10]), span=10, step=10)
                bw_out.close()
                for bw in bws:
                    bw.close()


from rpy2.robjects import Formula, StrVector
#from rpy2.rinterface import RRuntimeError
rule differential_expression:
    input:
        counts_table = source_counts,
        summary_table = rules.gather_read_counts_summaries.output.summary_table,
    output:
        deseq_results = OPJ("{trimmer}", aligner, "mapped_C_elegans", "{counter}", "deseq2", "{contrast}", "{orientation}_{biotype}", "deseq2.txt"),
        up_genes = OPJ("{trimmer}", aligner, "mapped_C_elegans", "{counter}", "deseq2", "{contrast}", "{orientation}_{biotype}", "up_genes.txt"),
        down_genes = OPJ("{trimmer}", aligner, "mapped_C_elegans", "{counter}", "deseq2", "{contrast}", "{orientation}_{biotype}", "down_genes.txt"),
        counts_and_res = OPJ("{trimmer}", aligner, "mapped_C_elegans", "{counter}", "deseq2", "{contrast}", "{orientation}_{biotype}", "counts_and_res.txt"),
    threads: 4  # to limit memory usage, actually
    run:
        counts_data = pd.read_table(input.counts_table, index_col="gene")
        summaries = pd.read_table(input.summary_table, index_col=0)
        # Running DESeq2
        #################
        (cond, ref) = CONTRAST2PAIR[wildcards.contrast]
        if not any(counts_data[f"{ref}_{rep}"].any() for rep in REPS):
            warnings.warn(
                "Reference data is all zero.\nSkipping %s_%s_%s" % (
                    wildcards.contrast, wildcards.orientation, wildcards.biotype))
            for outfile in output:
                shell(f"echo 'NA' > {outfile}")
        else:
            try:
                try:
                    contrast = StrVector(["lib", cond, ref])
                    formula = Formula("~ lib")
                    res, size_factors = do_deseq2(COND_NAMES, CONDITIONS, counts_data, formula=formula, contrast=contrast)
                #except RRuntimeError as e:
                except RuntimeError as e:
                    warnings.warn(
                        "Probably not enough usable data points to perform DESeq2 analyses:\n%s\nSkipping %s_%s_%s" % (
                            str(e), wildcards.contrast, wildcards.orientation, wildcards.biotype))
                    for outfile in output:
                        shell(f"echo 'NA' > {outfile}")
                else:
                    # Determining fold-change category
                    ###################################
                    set_de_status = status_setter(LFC_CUTOFFS, "log2FoldChange")
                    #counts_and_res = add_tags_column(pd.concat((counts_and_res, res), axis=1).assign(status=res.apply(set_de_status, axis=1)), input.tags_table, "small_type")
                    res = res.assign(status=res.apply(set_de_status, axis=1))
                    # Converting gene IDs
                    ######################
                    with open(OPJ(convert_dir, "wormid2cosmid.pickle"), "rb") as dict_file:
                        res = res.assign(cosmid=res.apply(column_converter(load(dict_file)), axis=1))
                    with open(OPJ(convert_dir, "wormid2name.pickle"), "rb") as dict_file:
                        res = res.assign(name=res.apply(column_converter(load(dict_file)), axis=1))
                    # Just to see if column_converter works also with named column, and not just index:
                    # with open(OPJ(convert_dir, "cosmid2name.pickle"), "rb") as dict_file:
                    #     res = res.assign(name=res.apply(column_converter(load(dict_file), "cosmid"), axis=1))
                    ##########################################
                    # res.to_csv(output.deseq_results, sep="\t", na_rep="NA", decimal=",")
                    res.to_csv(output.deseq_results, sep="\t", na_rep="NA")
                    # Joining counts and DESeq2 results in the same table and determining up- or down-regulation status
                    counts_and_res = counts_data
                    for normalizer in SIZE_FACTORS:
                        if normalizer == "median_ratio_to_pseudo_ref":
                            ## Adapted from DESeq paper (doi:10.1186/gb-2010-11-10-r106) but
                            ## add pseudo-count to compute the geometric mean, then remove it
                            #pseudo_ref = (counts_data + 1).apply(gmean, axis=1) - 1
                            #def median_ratio_to_pseudo_ref(col):
                            #    return (col / pseudo_ref).median()
                            #size_factors = counts_data.apply(median_ratio_to_pseudo_ref, axis=0)
                            size_factors = median_ratio_to_pseudo_ref_size_factors(counts_data)
                        else:
                            #raise NotImplementedError(f"{normalizer} normalization not implemented")
                            size_factors = summaries.loc[normalizer]
                        by_norm = counts_data / size_factors
                        by_norm.columns = by_norm.columns.map(lambda s: "%s_by_%s" % (s, normalizer))
                        counts_and_res = pd.concat((counts_and_res, by_norm), axis=1)
                    #counts_and_res = add_tags_column(pd.concat((counts_and_res, res), axis=1).assign(status=res.apply(set_de_status, axis=1)), input.tags_table, "small_type")
                    counts_and_res = pd.concat((counts_and_res, res), axis=1)
                    counts_and_res.to_csv(output.counts_and_res, sep="\t", na_rep="NA")
                    # Saving lists of up- and down-regulated genes
                    up_genes = list(counts_and_res.query(f"status in {UP_STATUSES}").index)
                    down_genes = list(counts_and_res.query(f"status in {DOWN_STATUSES}").index)
                    with open(output.up_genes, "w") as up_file:
                        if up_genes:
                            up_file.write("%s\n" % "\n".join(up_genes))
                        else:
                            up_file.truncate(0)
                    with open(output.down_genes, "w") as down_file:
                        if down_genes:
                            down_file.write("%s\n" % "\n".join(down_genes))
                        else:
                            down_file.truncate(0)
            # Does not seem to be caught...
            except KeyError as err:
                err_msg = str(err)
                warnings.warn("XXXXXXXXXXXXXXXXX Got KeyError XXXXXXXXXXXXXXXXX")
                if err_msg.startswith("Trying to release"):
                    warnings.warn(err_msg)
                    warnings.warn(f"Skipping {wildcards.contrast}_{wildcards.orientation}_{wildcards.biotype}\n")
                    for outfile in output:
                        shell(f"echo 'NA' > {outfile}")
                else:
                    raise
            except:
                warnings.warn("XXXXXXXXXXXXXXXXX Got another exception XXXXXXXXXXXXXXXXX")
                raise
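
# Hedged sketch of the fold-change status assignment done via libhts.status_setter
# in the rule above. The real implementation lives in libhts and is not shown here;
# the commented function below only illustrates the assumed idea: a gene is labelled
# "up{cutoff}" / "down{cutoff}" for the largest LFC cutoff its log2FoldChange reaches,
# and "NS" otherwise (any significance handling is omitted in this sketch).
# def set_de_status_sketch(row, cutoffs=LFC_CUTOFFS, column="log2FoldChange"):
#     lfc = row[column]
#     if pd.isnull(lfc):
#         return "NS"
#     for cutoff in sorted(cutoffs, reverse=True):
#         if lfc >= cutoff:
#             return f"up{cutoff}"
#         if lfc <= -cutoff:
#             return f"down{cutoff}"
#     return "NS"
# res.assign(status=res.apply(set_de_status_sketch, axis=1))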


rule make_lfc_distrib_plot:
    input:
        deseq_results = rules.differential_expression.output.deseq_results,
    output:
        lfc_plot = OPJ("{trimmer}", "figures", aligner, "{counter}", "{contrast}", "{orientation}_{biotype}", "{fold_type}_distribution.pdf"),
    run:
        if test_na_file(input.deseq_results):
            warnings.warn(
                "No DESeq2 results for %s_%s_%s. Making dummy output." % (
                    wildcards.contrast, wildcards.orientation, wildcards.biotype))
            for outfile in output:
                shell(f"echo 'NA' > {outfile}")
        else:
            res = pd.read_table(input.deseq_results, index_col=0)
            save_plot(
                output.lfc_plot, plot_lfc_distribution,
                res, wildcards.contrast, wildcards.fold_type,
                title="log fold-change distribution for %s, %s_%s" % (
                    wildcards.contrast, wildcards.orientation, wildcards.biotype))

def set_lfc_range(wildcards):
    return LFC_RANGE[wildcards.biotype]


# get_id_list takes wildcards (a gene list name or a path to a list file)
# and returns the corresponding list of WormBase gene IDs.
get_id_list = make_id_list_getter(gene_lists_dir, avail_id_lists)
rule make_MA_plot:
    input:
        deseq_results = rules.differential_expression.output.deseq_results,
    output:
        MA_plot = OPJ("{trimmer}", "figures", aligner, "{counter}", "{contrast}", "{orientation}_{biotype}", "MA_with_{id_list}.pdf"),
    params:
        lfc_range = set_lfc_range,
        id_list = get_id_list,
    run:
        if test_na_file(input.deseq_results):
            warnings.warn(
                "No DESeq2 results for %s_%s_%s. Making dummy output." % (
                    wildcards.contrast, wildcards.orientation, wildcards.biotype))
            for outfile in output:
                shell(f"echo 'NA' > {outfile}")
        else:
            res = pd.read_table(input.deseq_results, index_col=0)
            if params.id_list is None:
                grouping = "status"
                group2colour = None
            else:
                grouping = params.id_list
                group2colour = (wildcards.id_list, sns.xkcd_rgb["orange"])
            title = f"MA-plot for {wildcards.contrast}, {wildcards.orientation}_{wildcards.biotype}"
            if mpl.rcParams.get("text.usetex", False):
                title = texscape(title)
            save_plot(
                output.MA_plot, plot_MA, res,
                grouping=grouping,
                group2colour=group2colour,
                lfc_range=params.lfc_range,
                fold_type="log2FoldChange",
                title=title)


##################
# Metagene plots #
##################

# rule gather_annotations:
#     input:
#         expand(OPJ(annot_dir, "{biotype}.gtf"), biotype=ANNOT_BIOTYPES),
#     output:
#         merged_gtf = OPJ(local_annot_dir, "all_annotations.gtf"),
#         merged_gtf_gz = OPJ(local_annot_dir, "all_annotations.gtf.gz"),
#         index = OPJ(local_annot_dir, "all_annotations.gtf.gz.tbi"),
#         #merged_bed = OPJ(local_annot_dir, "all_annotations.bed"),
#     message:
#         "Gathering annotations for {}".format(", ".join(ANNOT_BIOTYPES))
#     shell:
#         """
#         sort -k1,1 -k4,4n -m {input} | tee {output.merged_gtf} | bgzip > {output.merged_gtf_gz}
#         tabix -p gff {output.merged_gtf_gz}
#         #ensembl_gtf2bed.py {output.merged_gtf} > {output.merged_bed}
#         """

# For metagene analyses:
#- Extract transcripts
#- Keep only genes with an identified TSS; among isoforms of the same gene, take the 3' end closest to the TSS
#- Avoid overlap between the UTR of gene1 and the UTR or TSS of gene2 (all biotypes except piRNA and antisense), based on the 3' ends furthest from the TSS of their respective genes

rule extract_transcripts:
    input:
        # We want to get back to the original "protein_coding" annotations,
        # not separated by UTR or CDS
        #gtf = rules.gather_annotations.output.merged_gtf,
        #gtf = OPJ(annot_dir, "genes.gtf"),
        #DNA_transposon_gtf = OPJ(annot_dir, "DNA_transposons_rmsk.gtf"),
        #RNA_transposon_gtf = OPJ(annot_dir, "RNA_transposons_rmsk.gtf"),
        OPJ(annot_dir, "genes.gtf"),
        OPJ(annot_dir, "DNA_transposons_rmsk.gtf"),
        OPJ(annot_dir, "RNA_transposons_rmsk.gtf"),
    output:
        bed = OPJ(local_annot_dir, "transcripts_all.bed"),
    run:
        with finput(files=input) as gtf_source, open(output.bed, "w") as bed_dest:
            for (chrom, _, bed_type, gtf_start, end,
                 score, strand, _, annot_field) in map(strip_split, gtf_source):
                if bed_type != "transcript":
                    continue
                # Parse the GTF attribute field (9th column), of the form:
                # gene_id "WBGene..."; transcript_id "..."; gene_biotype "protein_coding"; ...
                annots = dict([(k, v.rstrip(";").strip('"')) for (k, v) in [
                    f.split(" ") for f in annot_field[:-1].split("; ")]])
                gene_biotype = annots["gene_biotype"]
                # piRNA should not be highly expressed
                #if gene_biotype in {"miRNA", "piRNA", "antisense"}:
                if gene_biotype not in {"protein_coding", "pseudogene"}:
                    continue
                transcript_id = annots["transcript_id"]
                gene_id = annots["gene_id"]
                # GTF coordinates are 1-based; BED starts are 0-based, hence the -1
                print(chrom, int(gtf_start) - 1, end,
                      gene_id, gene_biotype, strand,
                      sep="\t", file=bed_dest)
    # shell:
    #     """
    #     extract_transcripts_from_gtf.py \\
    #         -g Caenorhabditis_elegans/Ensembl/WBcel235/Annotation/Genes/genes.gtf \\
    #         -o {params.annot_dir} \\
    #         -i "piRNA" "antisense" \\
    #         || error_exit "extract_transcripts failed"
    #     """



class Gene(object):
    """This is used to fuse coordinates of transcripts deriving from a same gene.
    *wide_end* corresponds to the furthest-extending 3' end. It will be used to
    determine possible "interferences" (risks of small RNA confusion) between genes."""
    __slots__ = ("gene_id", "chrom", "start", "tight_end", "wide_end", "biotype", "strand")
    def __init__(self, gene_id, chrom, start, end, strand, biotype):
        self.gene_id = gene_id
        self.chrom = chrom
        self.biotype = biotype
        self.strand = strand
        if strand == "+":
            self.start = int(start)
            self.tight_end = int(end)
            self.wide_end = int(end)
        else:
            self.start = int(end)
            self.tight_end = int(start)
            self.wide_end = int(start)

    def add_transcript(self, gene_id, chrom, start, end, strand):
        # assert gene_id == self.gene_id
        assert chrom == self.chrom
        assert strand == self.strand
        if strand == "+":
            self.start = min(int(start), self.start)
            self.tight_end = min(self.tight_end, int(end))
            self.wide_end = max(self.wide_end, int(end))
        else:
            self.start = max(int(end), self.start)
            self.tight_end = max(self.tight_end, int(start))
            self.wide_end = min(self.wide_end, int(start))

    @property
    def left(self):
        if self.strand == "+":
            return self.start
        else:
            return self.wide_end

    @property
    def right(self):
        if self.strand == "+":
            return self.wide_end
        else:
            return self.start

    def too_close_before(self, other, min_dist):
        """Returns True if *other* is strictly less than *min_dist*
        away after *self*."""
        return (other.left - self.right) - 1 < min_dist

    def tight_bed(self):
        if self.strand == "+":
            return "\t".join([
                self.chrom, str(self.start), str(self.tight_end),
                self.gene_id, self.biotype, self.strand])
        else:
            return "\t".join([
                self.chrom, str(self.tight_end), str(self.start),
                self.gene_id, self.biotype, self.strand])
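
# Minimal usage sketch of the Gene class, with hypothetical coordinates (not part
# of the pipeline): isoforms of the same gene are fused so that *start* is the
# 5'-most position, *tight_end* the 3' end closest to the TSS and *wide_end* the
# furthest one; *too_close_before* tests the gap to the next gene against a
# minimal distance.
# gene_a = Gene("geneA", "I", 100, 500, "+", "protein_coding")
# gene_a.add_transcript("geneA", "I", 120, 800, "+")
# assert (gene_a.start, gene_a.tight_end, gene_a.wide_end) == (100, 500, 800)
# gene_b = Gene("geneB", "I", 950, 2000, "+", "protein_coding")
# assert gene_a.too_close_before(gene_b, 200)      # gap of 149 bp < 200
# assert not gene_a.too_close_before(gene_b, 100)  # gap of 149 bp >= 100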


rule adjust_TSS:
    """Extends transcript coordinates if a new TSS can be found
    in the data provided by Kruesi et al 2013 (for L3 stage)."""
    input:
        in_bed = rules.extract_transcripts.output.bed,
    output:
        out_bed = OPJ(local_annot_dir, "transcripts_with_TSS.bed"),
    run:
        TSS_dict = defaultdict(set)
        with open("/pasteur/entites/Mhe/Genomes/C_elegans/TSS_annotations/Kruesi_TSS_coding_WT_L3_ce11_sorted.bed", "r") as TSS_bedfile:
            for chrom, bed_start, _, gene_info in map(strip_split, TSS_bedfile):
                # chrI	35384	35385	WBGene00022279|sesn-1@chrI:27595-32482|-1
                # chrI	47149	47150	WBGene00044345|Y48G1C.12@chrI:47472-49819|1
                # chrI	70172	70173	WBGene00000812|csk-1@chrI:71858-81071|1
                # chrI	110690	110691	WBGene00004274|rab-11.1@chrI:108686-110077|-1
                gene_id, _, strand = gene_info.split("|")
                if strand == "1":
                    strand = "+"
                elif strand == "-1":
                    strand = "-"
                else:
                    raise NotImplementedError("Unexpected strand information: %s" % strand)
                TSS_dict[gene_id].add((chrom, bed_start, strand))
        with open(input.in_bed, "r") as in_bedfile, open(output.out_bed, "w") as out_bedfile:
            for (chrom, start, end, gene_id, gene_biotype, strand) in map(strip_split, in_bedfile):
                if gene_id in TSS_dict:
                    # The set should have only one element
                    ((tss_chrom, tss_start, tss_strand),) = TSS_dict[gene_id]
                    assert tss_chrom == "chr%s" % chrom
                    assert tss_strand == strand
                    # TODO: check this is correct
                    if strand == "+":
                        #assert int(tss_start) <= int(start), "%s has a problem:\ntss_start:%s\tstart:%s" % (gene_id, tss_start, start)
                        #if int(tss_start) > int(start):
                        #    print("%s (+) has a problem:\ntss_start:%s\tstart:%s" % (gene_id, tss_start, start))
                        #    continue
                        start = tss_start
                    else:
                        #assert int(tss_start) >= int(end)
                        #if int(tss_start) < int(end):
                        #    print("%s (-) has a problem:\ntss_start:%s\tend:%s" % (gene_id, tss_start, end))
                        #    continue
                        end = tss_start
                print(chrom, start, end, gene_id, gene_biotype, strand, sep="\t", file=out_bedfile)
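
# Hedged illustration of the TSS bedfile parsing above, using the first example
# line quoted in the comments ("WBGene00022279|sesn-1@chrI:27595-32482|-1"):
# gene_id, _, strand = "WBGene00022279|sesn-1@chrI:27595-32482|-1".split("|")
# assert gene_id == "WBGene00022279" and strand == "-1"  # "-1" is then mapped to "-"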


# TODO: merge also with transposon
rule merge_transcripts:
    """Determine the span of each gene, based on its transcripts spans."""
    input:
        in_bed = rules.adjust_TSS.output.out_bed,
    output:
        out_bed = OPJ(local_annot_dir, "transcripts_merged_by_gene.bed"),
    run:
        # Use an OrderedDict so genes are written out in input order (grouped by chromosome),
        # which should help the downstream sort (rule resort_transcript_bed)
        genes = OrderedDict()
        with open(input.in_bed, "r") as in_bedfile:
            for (chrom, start, end, gene_id, gene_biotype, strand) in map(strip_split, in_bedfile):
                if gene_id in genes:
                    genes[gene_id].add_transcript(gene_id, chrom, start, end, strand)
                else:
                    genes[gene_id] = Gene(gene_id, chrom, start, end, strand, gene_biotype)
        with open(output.out_bed, "w") as out_bedfile:
            for gene in genes.values():
                print(
                    gene.chrom, str(gene.left), str(gene.right),
                    gene.gene_id, gene.biotype, gene.strand,
                    sep="\t", file=out_bedfile)


rule resort_transcript_bed:
    input:
        in_bed = rules.merge_transcripts.output.out_bed,
    output:
        out_bed = OPJ(local_annot_dir, "transcripts_merged_resorted.bed"),
    shell:
        """
        # Sort by chromosome, then start, then end, as assumed by rule filter_out_interfering_transcripts
        sort -k1,1 -k2,2n -k3,3n {input.in_bed} > {output.out_bed}
        """

# TODO
rule filter_out_interfering_transcripts:
    input:
        in_bed = rules.resort_transcript_bed.output.out_bed,
    output:
        out_bed = OPJ(local_annot_dir, "transcripts_merged_isolated_%d.bed" % MIN_DIST),
    run:
        last_chrom = ""
        # Zone occupied by genes previously encountered
        span = Gene("span", "", -MIN_DIST, -MIN_DIST, "+", "span")
        previous_gene = span
        with open(input.in_bed, "r") as in_bedfile, open(output.out_bed, "w") as out_bedfile:
            for (chrom, start, end, gene_id, gene_biotype, strand) in map(strip_split, in_bedfile):
                # This branch is also taken on the first iteration, because last_chrom starts as ""
                if chrom != last_chrom:
                    # see if we have a previous_gene that was not "discarded"
                    if previous_gene.gene_id != "span":
                        print(previous_gene.tight_bed(), sep="\t", file=out_bedfile)
                    # Reset stuff
                    last_chrom = chrom
                    # Zone occupied by genes previously encountered
                    span = Gene("span", chrom, -MIN_DIST, -MIN_DIST, "+", "span")
                    previous_gene = span
                this_gene = Gene(gene_id, chrom, start, end, strand, gene_biotype)
                # This code assumes that bed entries are sorted
                # by increasing order of start coordinate, then end coordinate
                if not span.too_close_before(this_gene, MIN_DIST):
                    # If nothing is too close before this_gene,
                    # then the previous gene cannot be too close
                    # before genes that come later either;
                    # it can be written to the output, provided nothing was too close before it.
                    if previous_gene.gene_id != "span":
                        print(previous_gene.tight_bed(), sep="\t", file=out_bedfile)
                    # this_gene will have the possibility to be written next iteration
                    previous_gene = this_gene
                else:
                    # this_gene is too close to the previous ones.
                    # We use this to "discard" this_gene
                    previous_gene = span
                span.add_transcript("span", chrom, start, end, "+")


rule select_genes_for_meta_profile:
    """Creates a bed file for metagene analysis based (on TSS being known: not now) and length being enough."""
    input:
        in_bed = rules.filter_out_interfering_transcripts.output.out_bed
    output:
        out_bed = OPJ(local_annot_dir, "transcripts_merged_isolated_%d_{biotype}_min_%d.bed" % (MIN_DIST, META_MIN_LEN)),
    run:
        TSS_dict = defaultdict(set)  # currently unused: the TSS-based filter below is commented out
        with open("/pasteur/entites/Mhe/Genomes/C_elegans/TSS_annotations/Kruesi_TSS_coding_WT_L3_ce11_sorted.bed", "r") as TSS_bedfile:
            for chrom, bed_start, _, gene_info in map(strip_split, TSS_bedfile):
                # chrI	35384	35385	WBGene00022279|sesn-1@chrI:27595-32482|-1
                # chrI	47149	47150	WBGene00044345|Y48G1C.12@chrI:47472-49819|1
                # chrI	70172	70173	WBGene00000812|csk-1@chrI:71858-81071|1
                # chrI	110690	110691	WBGene00004274|rab-11.1@chrI:108686-110077|-1
                gene_id, _, strand = gene_info.split("|")
                if strand == "1":
                    strand = "+"
                elif strand == "-1":
                    strand = "-"
                else:
                    raise NotImplementedError("Unexpected strand information: %s" % strand)
                TSS_dict[gene_id].add((chrom, bed_start, strand))
        with open(input.in_bed, "r") as in_bedfile, open(output.out_bed, "w") as out_bedfile:
            for (chrom, start, end, gene_id, gene_biotype, strand) in map(strip_split, in_bedfile):
                if wildcards.biotype != gene_biotype:
                    continue
                #assert wildcards.biotype == gene_biotype
                #if (gene_biotype == "protein_coding") and (gene_id not in TSS_dict):
                #    continue
                if int(end) - int(start) < META_MIN_LEN:
                    continue
                print(
                    chrom, start, end,
                    gene_id, gene_biotype, strand,
                    sep="\t", file=out_bedfile)


# In http://dx.doi.org/10.7554/eLife.00808, p. 27:
# -----
# To compare GRO-seq signal across genes, we scaled genes to be the same length,
# allowing us to average the GRO-seq signal across them. To avoid small genes
# that could affect the sensitivity of our analyses, we required that genes be
# ≥1.5 kb in length. These genes were scaled to the same length as follows: the
# 5′ end (1000 bp upstream to 500 bp downstream of the TSS) and the 3′ end (500
# bp upstream to 1000 bp downstream of the WB stop site) were not scaled, and the
# remainder of the gene was scaled to a length of 2 kb. We predicted that leaving
# the ends of the gene unscaled might allow us to better identify any effects
# that occurred at the ends of genes.
# -----


def meta_params(wildcards):
    biotype = wildcards.biotype
    # protein_coding genes and transposons use the same margins and scaling
    if biotype in {"protein_coding", "DNA_transposons_rmsk", "RNA_transposons_rmsk"}:
        return " ".join([
            "-b %d" % META_MARGIN,
            "--unscaled5prime %d" % UNSCALED_INSIDE,
            "-m %d" % META_SCALE,
            "--unscaled3prime %d" % UNSCALED_INSIDE,
            "-a %d" % META_MARGIN])
    if biotype.startswith("protein_coding_"):
        # No margins outside the annotated region for these sub-annotations
        return " ".join([
            "-b %d" % 0,
            "--unscaled5prime %d" % UNSCALED_INSIDE,
            "-m %d" % META_SCALE,
            "--unscaled3prime %d" % UNSCALED_INSIDE,
            "-a %d" % 0])
    raise NotImplementedError("Metagene analyses for %s not implemented." % biotype)


# TODO: make scripts to generate bed given gene names and one to plot the metaprofile
rule plot_meta_profile_mean:
    input:
        bigwig = rules.merge_bigwig_reps.output.bw,
        bed = rules.select_genes_for_meta_profile.output.out_bed,
    output:
        OPJ("{trimmer}", "figures", aligner, "{lib}_by_{norm_type}_mean", "{orientation}_on_merged_isolated_%d_{biotype}_min_%d_meta_profile.pdf" % (MIN_DIST, META_MIN_LEN)),
    params:
        meta_params = meta_params,
        # before = META_MARGIN,
        # after = META_MARGIN,
        # body_length = META_SCALE,
        # unscaled_5 = UNSCALED_INSIDE,
        # unscaled_3 = UNSCALED_INSIDE,
    log:
        plot_TSS_log = OPJ(log_dir, "{trimmer}", "plot_meta_profile", "{lib}", "by_{norm_type}", "{orientation}_on_merged_isolated_%d_{biotype}_min_%d.log" % (MIN_DIST, META_MIN_LEN)),
        plot_TSS_err = OPJ(log_dir, "{trimmer}", "plot_meta_profile", "{lib}", "by_{norm_type}", "{orientation}_on_merged_isolated_%d_{biotype}_min_%d.err" % (MIN_DIST, META_MIN_LEN)),
    threads: 12  # declared mainly to limit concurrent memory usage (computeMatrix itself runs with -p 1)
    run:
        if file_len(input.bed):
            shell("""tmpdir=$(mktemp -dt "plot_meta_profile_{wildcards.trimmer}_{wildcards.lib}_{wildcards.orientation}_{wildcards.biotype}_%d.XXXXXXXXXX")
computeMatrix scale-regions -S {input.bigwig} \\
    -R {input.bed} \\
    {params.meta_params} \\
    -p 1 \\
    -out ${{tmpdir}}/meta_profile.gz \\
    --skipZeros \\
    1> {log.plot_TSS_log} \\
    2> {log.plot_TSS_err} \\
    || error_exit "computeMatrix failed"
plotProfile -m ${{tmpdir}}/meta_profile.gz -out {output} \\
    1>> {log.plot_TSS_log} \\
    2>> {log.plot_TSS_err} \\
    || error_exit "plotProfile failed"
rm -rf ${{tmpdir}}
""" % META_MIN_LEN)
        else:
            warnings.warn("No regions selected in {input.bed}. Generating empty figure.\n")
            shell("""> {output}""")

onsuccess:
    print("PRO-seq analysis finished.")
    cleanup_and_backup(output_dir, config, delete=True)

onerror:
    shell(f"rm -rf {output_dir}_err")
    shell(f"cp -rp {output_dir} {output_dir}_err")
    cleanup_and_backup(output_dir + "_err", config)
    print("PRO-seq analysis failed.")