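"""CLI helpers for the DefenseFinder wiki: lint article front matter, download
predicted structures, and build the system / operon / RefSeq data tables."""
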
import typer
import sys
import json
import pandas as pd
import csv
import tempfile
import matplotlib.pyplot as plt
from operator import itemgetter
from pandas.errors import ParserError
from typing_extensions import Annotated
from typing import Optional, List
from pathlib import Path
from pydantic import BaseModel, ValidationError
import frontmatter
from enum import Enum
from rich.console import Console
import re
import requests
from Bio.PDB import PDBParser, MMCIFIO
import tarfile
import xml.etree.ElementTree as ET

console = Console()
app = typer.Typer()


class LayoutEnum(str, Enum):
    article = "article"
    db = "db"


class TableArticle(BaseModel):
    doi: str


class TableColumns(BaseModel):
    article: TableArticle
    Sensor: Optional[str] = None
    Activator: Optional[str] = None
    Effector: Optional[str] = None
    PFAM: Optional[str] = None


class RelevantAbstract(BaseModel):
    doi: str


class FrontMatter(BaseModel):
    title: str
    layout: LayoutEnum
    tableColumns: TableColumns
    relevantAbstracts: List[RelevantAbstract]
    contributors: List[str]


@app.command()
def lint(
    file: Annotated[
        Path,
        typer.Option(
            exists=False,
            file_okay=True,
            writable=True,
        ),
    ],
):
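    """Validate an article's front matter against the FrontMatter schema."""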
    console.rule(f"[bold blue]{file.name}", style="blue")

    with open(file) as f:
        metadata, _ = frontmatter.parse(f.read())
        # print(metadata)
        try:
            FrontMatter.model_validate(metadata)
        except ValidationError as exc:
            for err in exc.errors():
                console.print(
                    f"[red]{err['msg']} : {err['type']} {' -> '.join([str(l) for l in err['loc']])}"
                )
            # raise
            sys.exit(1)
        else:
            console.print("[green]Everything is all right")


@app.command(
    help="Download all structures listed in the statistics file, generate PAE plots from the raw data, and organize the files per system"
)
def structure(
    user: Annotated[str, typer.Option(help="username credential")],
    password: Annotated[str, typer.Option(help="password credential")],
    stat: Annotated[
        Path,
        typer.Option(
            exists=True,
            file_okay=True,
            dir_okay=True,
            writable=False,
            readable=True,
            resolve_path=True,
        ),
    ],
    output: Annotated[
        Path,
        typer.Option(
            exists=False,
            file_okay=False,
            dir_okay=True,
            writable=True,
            readable=True,
            resolve_path=True,
        ),
    ],
):
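    """The statistics TSV must provide at least the columns: system, pdb,
    pae_table and fasta_file. Files are fetched from data.atkinson-lab.com into
    <output>/<system>/; PDB files are also converted to mmCIF and PAE tables
    rendered as PNG."""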
    with open(stat, "r") as stat_f:
        reader = csv.DictReader(stat_f, delimiter="\t")
        count_row = 0
        for row in reader:
            count_row += 1
            system_dir_name = row["system"].lower()
            pdb_path_file = Path(row["pdb"])
            foldseek_html_file = Path(str(pdb_path_file).split(".pdb")[0] + ".html")
            png_structure = Path(str(pdb_path_file).split(".pdb")[0] + ".png")
            files = [
                {"f": pdb_path_file, "d": "PDB"},
                {"f": Path(row["pae_table"]), "d": "PAE"},
                {"f": Path(row["fasta_file"]), "d": "Fastas"},
                {"f": foldseek_html_file, "d": "foldseek_pdb_html"},
                {"f": png_structure, "d": "png"},
            ]

            target_dir = output / system_dir_name
            target_dir.mkdir(parents=True, exist_ok=True)
            base_url = "https://data.atkinson-lab.com/DefenceFinder"
            console.rule(f"[bold blue]{count_row} / {system_dir_name}", style="blue")
            for f in files:
                console.print(f"[bold blue]{f['d']}", style="blue")
                str_f = str(f["f"])
                if str_f not in ("", ".", "na"):
                    # get the file from atkinson lab
                    target_file = target_dir / f["f"]
                    file_to_fetch = f"{base_url}/{f['d']}/{f['f']}"
                    console.print(f"Fetch : {file_to_fetch}")
                    response = requests.get(
                        file_to_fetch,
                        auth=(user, password),
                        allow_redirects=True,
                        stream=True,
                    )
                    with open(target_file, "wb") as fh:
                        for chunk in response.iter_content(chunk_size=1024):
                            if chunk:
                                fh.write(chunk)

                    if f["d"] == "PDB":
                        # pdb2cif
                        pdb2cif(target_file)
                    if f["d"] == "PAE":
                        png_file = str(target_file).split(".tsv")[0] + ".pae.png"
                        try:
                            pae2png(target_file, png_file)
                        except ParserError as err:
                            console.print(
                                f"[red]file {target_file.name} cannot be parsed"
                            )
                            print(err)


@app.command()
def systems(
    dir: Annotated[
        Path,
        typer.Option(exists=False, file_okay=False, readable=True, dir_okay=True),
    ],
    pfam: Annotated[
        Path,
        typer.Option(
            exists=False,
            file_okay=True,
            writable=True,
        ),
    ],
    output: Annotated[
        Path,
        typer.Option(
            file_okay=True,
            dir_okay=False,
            writable=True,
            resolve_path=True,
        ),
    ],
):
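    """Flatten the front matter of every article in <dir>, enrich PFAM accessions
    with the metadata found in <pfam>, and write all systems to <output> as a
    single JSON array."""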
    with open(pfam, "r") as pf:
        pfam_df = pd.read_csv(pf, index_col="AC", keep_default_na=False)
        systems = []
        if output.exists():
            output.unlink()
        with open(output, "a") as ty:
            for file in dir.iterdir():
                if file.suffix == ".md":
                    console.rule(f"[bold blue]{file.name}", style="blue")
                    with open(file) as f:
                        metadata, _ = frontmatter.parse(f.read())
                        del metadata["layout"]
                        sanitizedMetadata = {**metadata}
                        if "tableColumns" in sanitizedMetadata:
                            table_data = sanitizedMetadata["tableColumns"]
                            if "PFAM" in table_data:
                                pfams_list = [
                                    pfam.strip()
                                    for pfam in table_data["PFAM"].split(",")
                                ]
                                pfam_metadata = list()
                                for pfam in pfams_list:
                                    try:
                                        pfam_obj = pfam_df.loc[[pfam]]
                                        # print(pfam_obj)
                                        pfam_to_dict = pfam_obj.to_dict(orient="index")
                                        pfam_to_dict[pfam]["AC"] = pfam
                                        flatten_value = pfam_to_dict[pfam]
                                        pfam_metadata.append(flatten_value)
                                    except KeyError as err:
                                        console.print(f"[bold red]{err}", style="red")
                                        console.print(
                                            f"[bold red]No PFAM entry for {pfam}",
                                            style="red",
                                        )
                                        continue
                                sanitizedMetadata["PFAM"] = pfam_metadata

                            if "article" in table_data:
                                sanitizedMetadata["doi"] = table_data["article"]["doi"]
                                if "abstract" in table_data["article"]:
                                    sanitizedMetadata["abstract"] = table_data[
                                        "article"
                                    ]["abstract"]
                                del table_data["article"]
                            if "PFAM" in table_data:
                                del table_data["PFAM"]
                            del sanitizedMetadata["tableColumns"]
                            sanitizedMetadata = {**sanitizedMetadata, **table_data}
                            systems.append(sanitizedMetadata)

            json_object = json.dumps(systems, indent=2)
            ty.write(json_object)


@app.command()
def system_operon_structure(
    output: Annotated[
        Path,
        typer.Option(
            file_okay=True,
            dir_okay=False,
            writable=True,
            resolve_path=True,
        ),
    ] = "./system-structures.csv",
    structure: Annotated[
        Path,
        typer.Option(
            file_okay=True,
            dir_okay=False,
            writable=False,
            resolve_path=True,
        ),
    ] = "./all_predictions_statistics.tsv",
    versions: Annotated[List[str], typer.Option(help="Defense Finder model versions (used in the archive file name)")] = [
        "1.2.2",
        "1.2.3",
        "v1.2.4",
    ],
    tags: Annotated[List[str], typer.Option(help="Defense Finder model release tags")] = [
        "1.2.2",
        "1.2.3",
        "1.2.4",
    ],
):
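    """Build a CSV listing the genes (and their exchangeables) of every
    system/subsystem found in <structure>, by parsing the XML model definitions
    shipped with each defense-finder-models release."""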

    # get defense finder model from github
    releases = zip(versions, tags)
    model_dirs = list(download_model_release(releases))

    systems = []
    with open(structure) as tsvfile:
        tsvreader = csv.DictReader(tsvfile, delimiter="\t")
        for row in tsvreader:
            systems.append({"system": row["system"], "subsystem": row["subsystem"]})
    system_genes = []
    system_genes_got = set()
    list_paths = list(gen_model_path(model_dirs))
    for system_def in systems:
        system, subsystem = itemgetter("system", "subsystem")(system_def)
        system_id = f"{system}-{subsystem}"
        if (
            system != "#N/A"
            and subsystem != "#N/A"
            and system_id not in system_genes_got
        ):
            system_genes_got.add(system_id)
            def_path = find_model_definition(system, subsystem, list_paths)
            in_exchangeables = False
            current_gene = {}
            exchangeables = []
            with open(def_path["path"]) as file:
                for event, elem in ET.iterparse(file, events=("start", "end")):
                    if event == "start":
                        if (
                            elem.tag == "gene"
                            and not in_exchangeables
                            and elem.attrib["presence"] != "forbidden"
                        ):
                            current_gene = {
                                "system": system,
                                "subsystem": subsystem,
                                "gene": elem.attrib["name"],
                                "version": def_path["version"],
                                "exchangeables": None,
                            }
                            system_genes.append(current_gene)
                        if elem.tag == "gene" and in_exchangeables:
                            exchangeables.append(elem.attrib["name"])
                        if elem.tag == "exchangeables":
                            in_exchangeables = True
                            exchangeables = []
                    elif event == "end":
                        if elem.tag == "exchangeables":
                            in_exchangeables = False
                            current_gene["exchangeables"] = ",".join(exchangeables)
                            exchangeables = []

    with open(output, "w") as f:
        fieldnames = ["id", "system", "subsystem", "version", "gene", "exchangeables"]
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        writer.writeheader()
        for idx, gene in enumerate(system_genes):
            # gene["alternatives"] = ",".join(gene["alternatives"])
            gene["id"] = idx
            writer.writerow(gene)


def download_model_release(releases):
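    """Download each defense-finder-models release archive from GitHub, extract
    it into a temporary directory and yield the release tag and its definitions
    directory."""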

    for version, tag in releases:
        df_model_url = f"https://github.com/mdmparis/defense-finder-models/releases/download/{tag}/defense-finder-models-{version}.tar.gz"
        _, tmp_path = tempfile.mkstemp()
        tmp_root_dir = tempfile.gettempdir()
        df_model_dir = Path(f"{tmp_root_dir}/defense-finder-models-{version}")
        df_model_definitions_dir = (
            df_model_dir / "defense-finder-models" / "definitions"
        )
        console.print(f"Download models: {df_model_url}")
        response = requests.get(
            df_model_url,
            allow_redirects=True,
        )
        with open(tmp_path, mode="wb") as file:
            file.write(response.content)

        console.print("untar file")
        with tarfile.open(tmp_path) as archive:
            archive.extractall(df_model_dir)
        yield {"version": tag, "dir": df_model_definitions_dir}


TMP_CIF = """
#
loop_
_ma_qa_metric.id
_ma_qa_metric.mode
_ma_qa_metric.name
_ma_qa_metric.software_group_id
_ma_qa_metric.type
1 global pLDDT 1 pLDDT 
2 local  pLDDT 1 pLDDT 
#
_ma_qa_metric_global.metric_id    1
_ma_qa_metric_global.metric_value {:.06}
_ma_qa_metric_global.model_id     1
_ma_qa_metric_global.ordinal_id   1
#
loop_
_ma_qa_metric_local.label_asym_id
_ma_qa_metric_local.label_comp_id
_ma_qa_metric_local.label_seq_id
_ma_qa_metric_local.metric_id
_ma_qa_metric_local.metric_value
_ma_qa_metric_local.model_id
_ma_qa_metric_local.ordinal_id
"""


def pdb2cif(pdb):
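    """Convert a PDB file to mmCIF alongside it: write an ma_qa_metric header
    with the mean pLDDT, append one local pLDDT row per residue (read from the
    B-factor column), then append the coordinates with Biopython's MMCIFIO."""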
    cif = str(pdb).split(".pdb")[0] + ".cif"
    console.print(f"convert {pdb} -> {cif}")
    p = PDBParser()
    struc = p.get_structure("", pdb)

    list_atoms = []
    for a in struc.get_atoms():
        list_atoms.append(
            [
                a.parent.parent.id,
                a.parent.resname,
                a.parent.id[1],
                "2",
                a.bfactor,
                1,
                a.parent.id[1],
            ]
        )
    df = pd.DataFrame(list_atoms).drop_duplicates()

    with open(cif, "w") as of:
        of.write(TMP_CIF.format(df[4].mean()))
    df.to_csv(cif, index=False, header=False, mode="a", sep=" ")

    io = MMCIFIO()
    io.set_structure(struc)

    with open(cif, "a") as of:
        io.save(of)


def pae2png(tsv_file, png_file):
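    """Render the PAE matrix stored in <tsv_file> as a PNG heatmap."""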

    console.print(f"Convert : {tsv_file} row data to {png_file}")
    v = pd.read_table(tsv_file, index_col=0, low_memory=False)
    fig, ax = plt.subplots(1, 1, figsize=(4, 3), facecolor=None)
    m = ax.matshow(v, cmap="Greens_r", vmin=0, vmax=35, aspect="auto", origin="lower")
    cbar = plt.colorbar(m, ax=ax, fraction=0.046, pad=0.04)
    ax.set_xlabel("Scored Residues")
    ax.set_ylabel("Aligned Residues")
    cbar.set_label("Expected Position Error (Å)")
    plt.tight_layout()
    plt.savefig(png_file, dpi=150, facecolor=None, transparent=True)
    plt.close()


@app.command(help="Remove version from sys_id")
def refseq(
    input: Annotated[
        Path,
        typer.Option(
            exists=False,
            file_okay=True,
            writable=True,
        ),
    ],
    output: Annotated[
        Path,
        typer.Option(
            file_okay=True,
            dir_okay=False,
            writable=True,
            resolve_path=True,
        ),
    ],
):
    with open(output, "w") as out, open(input, "r") as refseq_f:
        reader = csv.DictReader(refseq_f)
        fieldnames = reader.fieldnames
        writer = csv.DictWriter(out, fieldnames=fieldnames)
        writer.writeheader()
        for row in reader:
            if row["sys_id"] == "":
                row["sys_id"] = f'{row["Assembly"]}_{row["replicon"]}'
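            # drop the assembly version from sys_id, e.g. "GCF_000001.1_NC_000001" -> "GCF_000001_NC_000001"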
            result = re.sub(r"^(\w+)\.\d+(_.*)$", r"\1\2", row["sys_id"])
            console.print(f"[green]{row['sys_id']} ->  {result}")
            row["sys_id"] = result
            writer.writerow(row)


@app.command(
    help='Remove "No system found" hits if they are not the only hit for an assembly'
)
def refseq_sanitized_hits(
    input: Annotated[
        Path,
        typer.Option(
            exists=False,
            file_okay=True,
            writable=True,
        ),
    ],
    output: Annotated[
        Path,
        typer.Option(
            file_okay=True,
            dir_okay=False,
            writable=True,
            resolve_path=True,
        ),
    ],
):
    df = pd.read_csv(input)
    df_final = _sanitized_refseq_hits(df)
    df_final.reset_index().to_csv(output, index=False)
    return df_final


@app.command(help="Group hits per assembly and types (from 'sanitized-hits')")
def refseq_group_per_assembly_and_type(
    input: Annotated[
        Path,
        typer.Option(
            exists=False,
            file_okay=True,
            writable=True,
        ),
    ],
    output: Annotated[
        Path,
        typer.Option(
            file_okay=True,
            dir_okay=False,
            writable=True,
            resolve_path=True,
        ),
    ],
):
    df = pd.read_csv(input)
    df_final = _sanitized_refseq_hits(df)
    df_final_grouped = df_final.groupby(
        [
            "Assembly",
            "type",
            "Superkingdom",
            "phylum",
            "class",
            "order",
            "family",
            "genus",
            "species",
        ],
        as_index=False,
        dropna=False,
    ).size()
    df_final_grouped.reset_index().to_csv(output, index=False)


@app.command()
def refseq_group_per_assembly(
    input: Annotated[
        Path,
        typer.Option(
            exists=False,
            file_okay=True,
            writable=True,
        ),
    ],
    output: Annotated[
        Path,
        typer.Option(
            file_okay=True,
            dir_okay=False,
            writable=True,
            resolve_path=True,
        ),
    ],
):
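    """Count hits per assembly (keeping the taxonomy columns) and write the result to CSV."""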
    df = pd.read_csv(input)

    df["Assembly"] = df["Assembly"].apply(remove_version)
    df_grouped = df.groupby(
        [
            "Assembly",
            "Superkingdom",
            "phylum",
            "class",
            "order",
            "family",
            "genus",
            "species",
        ],
        as_index=False,
        dropna=False,
    ).size()
    df_grouped.reset_index().to_csv(output, index=False)


@app.command()
def refseq_type_count(
    input: Annotated[
        Path,
        typer.Option(
            exists=False,
            file_okay=True,
            writable=True,
            help="CSV file with type and taxonomy ('No system found' rows removed when other systems are found in the same assembly)",
        ),
    ],
    output: Annotated[
        Path,
        typer.Option(
            file_okay=True,
            dir_okay=False,
            writable=True,
            resolve_path=True,
        ),
    ],
):
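    """Count hits per system type and write the result to CSV."""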
    df = pd.read_csv(input)
    grouped_per_type = df.groupby(["type"], as_index=False, dropna=False).size()
    grouped_per_type.reset_index().to_csv(output, index=False)


@app.command()
def markdown(
    dir: Annotated[
        Path,
        typer.Option(
            exists=True,
            file_okay=False,
            writable=True,
            readable=True,
            resolve_path=True,
            help="Directory containing all system articles",
        ),
    ],
):
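    """Rewrite each article so that the 'Structure' and 'Distribution of the
    system among prokaryotes' sections only contain the ::article-structure and
    ::article-system-distribution-plot components."""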
    for file in dir.iterdir():
        if file.suffix == ".md":

            console.rule(f"[bold blue]{file.name}", style="blue")
            # make a copy of file
            _, tmp_path = tempfile.mkstemp()
            # with open(dst, "w") as tmp_f:
            dst = Path(tmp_path)
            dst.write_bytes(file.read_bytes())
            # check if article has ## Structure, ## Experimental Validation, ## Distribution of the system among prokaryotes

            with open(dst, "r+") as f:

                all_file = f.read()

                if (
                    re.search(
                        r"#{2}\s+Structure",
                        all_file,
                        flags=re.IGNORECASE | re.MULTILINE,
                    )
                    and re.search(
                        r"#{2}\s+Experimental\s+validation",
                        all_file,
                        flags=re.IGNORECASE | re.MULTILINE,
                    )
                    and re.search(
                        r"#{2}\s+Distribution\s+of\s+the\s+system\s+among\s+prokaryotes",
                        all_file,
                        flags=re.IGNORECASE | re.MULTILINE,
                    )
                ):

                    new_f_str = re.sub(
                        r"#{2}\s+Structure.*?#{2}\s+Experimental\s+validation",
                        "## Structure\n\n::article-structure\n::\n\n## Experimental validation",
                        all_file,
                        flags=re.DOTALL | re.IGNORECASE,
                    )
                    new_f = re.sub(
                        r"#{2}\s+Distribution\s+of\s+the\s+system\s+among\s+prokaryotes.*?#{2}\s+Structure",
                        "## Distribution of the system among prokaryotes\n\n::article-system-distribution-plot\n::\n\n## Structure",
                        new_f_str,
                        flags=re.DOTALL | re.IGNORECASE,
                    )
                    with open(file, "w") as f_out:
                        f_out.write(new_f)
                else:
                    console.log(f"[bold red]{file.name}: expected sections not found, check it manually")


def remove_version(assembly):
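    """Strip the version suffix from an assembly accession (e.g. GCF_000001.2 -> GCF_000001)."""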
    return assembly.split(".")[0]


def _sanitized_refseq_hits(df):
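    """Drop "No system found" rows for assemblies that also have at least one real system hit."""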
    df["Assembly"] = df["Assembly"].apply(remove_version)
    # Lower-case type names
    # df["type"] = df["type"].apply(lambda x: x.lower())

    # Get all row with no system type
    df_no_system = df.loc[df["type"] == "No system found"]
    # unique assembly with no sys
    serie_assembly_with_no_sys = df_no_system["Assembly"].unique()
    # filter assembly to have those with no sys
    df_with_no_sys = df[df["Assembly"].isin(serie_assembly_with_no_sys)]
    # Group them by assembly, type, taxo
    no_sys_assembly_by_size = df_with_no_sys.groupby(
        [
            "Assembly",
            "type",
            "Superkingdom",
            "phylum",
            "class",
            "order",
            "family",
            "genus",
            "species",
        ],
        as_index=False,
        dropna=False,
    ).size()

    # count each occurrence
    df_again_per_assembly = no_sys_assembly_by_size.groupby(
        "Assembly", as_index=False, dropna=False
    ).size()
    # keep only assemblies with size > 1 (size == 1 means "No system found" is the
    # only hit for that assembly, so that row must be kept)
    df_size_sup_1 = df_again_per_assembly[df_again_per_assembly["size"] > 1]
    assembly_where_should_remove_no_sys_found = df_size_sup_1["Assembly"].unique()

    # Build a new dataset that drops "No system found" entries when systems were
    # found on other replicons belonging to the same assembly
    df_filtered_assembly_only_with_sys = df[
        (df["type"] != "No system found")
        | ~df.Assembly.isin(assembly_where_should_remove_no_sys_found)
    ]
    return df_filtered_assembly_only_with_sys


def find_model_definition(system, subsystem, list_paths):
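    """Return the path and model version of the XML definition matching <system>/<subsystem>."""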
    found_path = None
    for p in list_paths:
        path = p["path"]
        parts = path.parts
        if path.stem == subsystem and parts[-2] == system:
            console.rule(f"{system} - {subsystem}")
            console.print(p)
            found_path = {"path": path, "version": p["version"]}
            break

    if found_path is None:
        raise FileNotFoundError(
            f"No model definition found for {system}/{subsystem}"
        )
    return found_path


def gen_model_path(model_dirs):
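    """Yield every XML definition path under the extracted model directories, tagged with its model version."""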
    for model_dir in model_dirs:
        for subdir in model_dir["dir"].iterdir():
            for system_path in subdir.iterdir():
                for subsystem_path in system_path.iterdir():
                    if subsystem_path.suffix == ".xml":
                        yield {"path": subsystem_path, "version": model_dir["version"]}
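

# Assumption: this script is meant to be run directly (python main.py <command>);
# if the project exposes app() through a console-script entry point instead,
# this guard is redundant.
if __name__ == "__main__":
    app()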