Skip to content
Snippets Groups Projects
main.py 21.69 KiB
import typer
import sys
import json
import pandas as pd
import csv
import tempfile
import matplotlib.pyplot as plt
from pandas.errors import ParserError
from typing_extensions import Annotated
from typing import Optional, List
from pathlib import Path
from pydantic import BaseModel, ValidationError
import frontmatter
from enum import Enum
from rich.console import Console
import re
import requests
from Bio.PDB import PDBParser, MMCIFIO
import tarfile
import xml.etree.ElementTree as ET

# Shared rich console used by every command below for (markup-coloured) output.
console = Console()
# Typer application; each function decorated with @app.command() is a sub-command.
app = typer.Typer()


class LayoutEnum(str, Enum):
    """Allowed values of the ``layout`` front-matter key (see ``FrontMatter``)."""

    article = "article"
    db = "db"


class TableArticle(BaseModel):
    """Article reference found under ``tableColumns.article`` in the front matter."""

    # DOI of the article (required).
    doi: str


class TableColumns(BaseModel):
    """Schema of the ``tableColumns`` mapping in an article's front matter."""

    # Required article reference.
    article: TableArticle
    # Optional free-text columns; absent keys validate as None.
    Sensor: Optional[str] = None
    Activator: Optional[str] = None
    Effector: Optional[str] = None
    # Comma-separated PFAM accessions (the ``systems`` command splits it on ",").
    PFAM: Optional[str] = None


class RelevantAbstract(BaseModel):
    """One entry of the ``relevantAbstracts`` front-matter list."""

    # DOI of the referenced publication (required).
    doi: str


class FrontMatter(BaseModel):
    """Expected YAML front matter of an article; validated by the ``lint`` command."""

    # Article title.
    title: str
    # Page layout; must be one of LayoutEnum's values.
    layout: LayoutEnum
    # Table metadata (article reference, optional columns, PFAM list).
    tableColumns: TableColumns
    # Publications listed for the article; each needs a ``doi``.
    relevantAbstracts: List[RelevantAbstract]
    # Contributor names.
    contributors: List[str]


@app.command()
def lint(
    file: Annotated[
        Path,
        typer.Option(
            # The article is only read: require it to exist and be readable.
            # (``writable=True`` made click reject read-only files.)
            exists=True,
            file_okay=True,
            dir_okay=False,
            readable=True,
        ),
    ],
):
    """Validate the YAML front matter of a markdown article against ``FrontMatter``.

    On failure every validation error is printed as
    ``<message> : <error type> <loc0 -> loc1 -> ...>`` and the process exits
    with status 1; otherwise a success message is printed.
    """
    console.rule(f"[bold blue]{file.name}", style="blue")

    metadata, _ = frontmatter.parse(file.read_text(encoding="utf-8"))
    try:
        FrontMatter.model_validate(metadata)
    except ValidationError as exc:
        for err in exc.errors():
            location = " -> ".join(str(part) for part in err["loc"])
            console.print(f"[red]{err['msg']} : {err['type']} {location}")
        sys.exit(1)
    console.print("[green] Everything is alright")


def _is_missing_path(path):
    """Return True when a stats-file path cell holds no usable file name.

    ``Path("")`` stringifies to ``"."``; the stats file uses ``"na"`` for
    absent files. Neither must be fetched.
    """
    return str(path) in ("", ".", "na")


def _download(url, auth, target_file, timeout=60):
    """Stream ``url`` into ``target_file`` with the ``(user, password)`` ``auth``.

    Returns True on success. On an HTTP error status or a network failure the
    problem is printed, a partially written file is removed, and False is
    returned so the caller skips post-processing instead of converting an
    error page.
    """
    opened = False
    try:
        # The context manager releases the streamed connection in every case.
        with requests.get(
            url,
            auth=auth,
            allow_redirects=True,
            stream=True,
            timeout=timeout,
        ) as response:
            response.raise_for_status()
            with open(target_file, "wb") as fh:
                opened = True
                for chunk in response.iter_content(chunk_size=1024):
                    if chunk:
                        fh.write(chunk)
    except requests.RequestException as err:
        console.print(f"[red] cannot fetch {url}: {err}")
        if opened:
            # Only drop what this call truncated; an older good copy that was
            # never opened is left alone.
            Path(target_file).unlink(missing_ok=True)
        return False
    return True


@app.command(
    help="Download all structures from statistics file. Generate PAE from raw data and organize these files per systems"
)
def structure(
    user: Annotated[str, typer.Option(help="username credential")],
    password: Annotated[str, typer.Option(help="password credential")],
    stat: Annotated[
        Path,
        typer.Option(
            exists=True,
            file_okay=True,
            dir_okay=True,
            writable=False,
            readable=True,
            resolve_path=True,
        ),
    ],
    output: Annotated[
        Path,
        typer.Option(
            exists=False,
            file_okay=False,
            dir_okay=True,
            writable=True,
            readable=True,
            resolve_path=True,
        ),
    ],
):
    """Fetch the per-system files listed in the tab-separated ``stat`` file.

    For every row, the ``pdb``, ``pae_table`` and ``fasta_file`` files are
    fetched from the Atkinson lab server into ``output/<system lower-cased>/``.
    So are the foldseek HTML and PNG named after the PDB file. Downloaded PDB
    files are converted with ``pdb2cif`` and PAE tables rendered with
    ``pae2png``. Files whose download fails are reported and skipped.
    """
    base_url = "https://data.atkinson-lab.com/DefenceFinder"
    auth = (user, password)
    with open(stat, "r", newline="") as stat_f:
        reader = csv.DictReader(stat_f, delimiter="\t")
        for count_row, row in enumerate(reader, start=1):
            system_dir_name = row["system"].lower()
            pdb_path_file = Path(row["pdb"])
            files = [
                {"f": pdb_path_file, "d": "PDB"},
                {"f": Path(row["pae_table"]), "d": "PAE"},
                {"f": Path(row["fasta_file"]), "d": "Fastas"},
            ]
            # The foldseek HTML and structure PNG are named after the PDB
            # file. Without a PDB there is nothing to derive them from (this
            # used to request bogus names such as "na.html").
            if not _is_missing_path(pdb_path_file):
                pdb_stem = str(pdb_path_file).split(".pdb")[0]
                files.append({"f": Path(pdb_stem + ".html"), "d": "foldseek_pdb_html"})
                files.append({"f": Path(pdb_stem + ".png"), "d": "png"})

            target_dir = output / system_dir_name
            target_dir.mkdir(parents=True, exist_ok=True)
            console.rule(f"[bold blue]{count_row} / {system_dir_name}", style="blue")
            for f in files:
                console.print(f"[bold blue]{f['d']}", style="blue")
                if _is_missing_path(f["f"]):
                    continue
                # get the file from atkinson lab
                target_file = target_dir / f["f"]
                file_to_fetch = f"{base_url}/{f['d']}/{f['f']}"
                console.print(f"Fetch : {file_to_fetch}")
                if not _download(file_to_fetch, auth, target_file):
                    continue

                if f["d"] == "PDB":
                    pdb2cif(target_file)
                elif f["d"] == "PAE":
                    png_file = str(target_file).split(".tsv")[0] + ".pae.png"
                    try:
                        pae2png(target_file, png_file)
                    except ParserError as err:
                        console.print(
                            f"[red] file {str(target_file.name)} cannot be parsed"
                        )
                        print(err)


def _lookup_pfam_entries(pfam_df, pfam_field):
    """Return the PFAM table rows for a comma-separated list of accessions.

    Each row is a dict of the ``pfam_df`` columns plus an ``"AC"`` key holding
    the accession. Blank accessions are skipped. Accessions missing from
    ``pfam_df`` are reported and skipped. ``None`` (e.g. an empty ``PFAM:``
    key) yields an empty list.
    """
    if pfam_field is None:
        return []
    entries = []
    for accession in (item.strip() for item in pfam_field.split(",")):
        if not accession:
            continue
        try:
            entry = pfam_df.loc[[accession]].to_dict(orient="index")[accession]
        except KeyError as err:
            console.print(f"[bold red]{err}", style="red")
            console.print(
                f"[bold red]No pfam entry or {accession}",
                style="red",
            )
            continue
        entry["AC"] = accession
        entries.append(entry)
    return entries


def _flatten_system_metadata(metadata, pfam_df):
    """Flatten one article's front matter into a system record, or return None.

    ``layout`` is dropped and the ``tableColumns`` mapping is merged into the
    top level. ``PFAM`` is expanded through ``_lookup_pfam_entries``.
    ``article`` is replaced by ``doi`` (and ``abstract`` when present).
    Articles without ``tableColumns`` yield None. ``metadata`` itself is not
    modified.
    """
    sanitized = {key: value for key, value in metadata.items() if key != "layout"}
    if "tableColumns" not in sanitized:
        return None
    table_data = dict(sanitized.pop("tableColumns"))

    if "PFAM" in table_data:
        sanitized["PFAM"] = _lookup_pfam_entries(pfam_df, table_data.pop("PFAM"))
    if "article" in table_data:
        article = table_data.pop("article")
        sanitized["doi"] = article["doi"]
        if "abstract" in article:
            sanitized["abstract"] = article["abstract"]
    # Remaining table columns (Sensor, Activator, ...) become top-level keys.
    return {**sanitized, **table_data}


@app.command()
def systems(
    dir: Annotated[
        Path,
        typer.Option(exists=True, file_okay=False, readable=True, dir_okay=True),
    ],
    pfam: Annotated[
        Path,
        typer.Option(
            # Only read: require an existing readable file, not write access.
            exists=True,
            file_okay=True,
            dir_okay=False,
            readable=True,
        ),
    ],
    output: Annotated[
        Path,
        typer.Option(
            file_okay=True,
            dir_okay=False,
            writable=True,
            resolve_path=True,
        ),
    ],
):
    """Write a JSON list of system records built from the ``.md`` files in ``dir``.

    ``pfam`` is a CSV indexed by its ``AC`` column, used to expand each
    article's PFAM accessions. ``output`` is replaced by the indented JSON
    list. It is written only after every article has been processed, so a
    failure midway leaves any previous output untouched.
    """
    pfam_df = pd.read_csv(pfam, index_col="AC", keep_default_na=False)
    records = []
    # Sorted so the JSON order does not depend on filesystem iteration order.
    for file in sorted(dir.iterdir()):
        if file.suffix != ".md":
            continue
        console.rule(f"[bold blue]{file.name}", style="blue")
        metadata, _ = frontmatter.parse(file.read_text(encoding="utf-8"))
        record = _flatten_system_metadata(metadata, pfam_df)
        if record is not None:
            records.append(record)

    output.write_text(json.dumps(records, indent=2))


def _safe_extract(archive_path, destination):
    """Extract the tar archive ``archive_path`` into ``destination``.

    Members that would land outside ``destination`` are refused, since the
    archive comes from the network.
    """
    with tarfile.open(archive_path) as archive:
        if hasattr(tarfile, "data_filter"):
            # Python >= 3.12 and security backports: the "data" filter rejects
            # absolute paths, ".." components, links escaping the destination, etc.
            archive.extractall(destination, filter="data")
            return
        root = Path(destination).resolve()
        for member in archive.getmembers():
            member_path = (root / member.name).resolve()
            if member_path != root and root not in member_path.parents:
                raise ValueError(
                    f"Refusing to extract {member.name!r} outside {root}"
                )
        archive.extractall(destination)


def _parse_subsystem_genes(system, subsystem_file):
    """Return one row dict per top-level ``<gene>`` of a model definition XML.

    Rows have keys ``system``, ``subsystem`` (the XML file stem), ``gene``
    (the ``name`` attribute) and ``exchangeables``. The last is the
    comma-joined names of the ``<gene>`` elements inside an
    ``<exchangeables>`` block that follows the gene, or None when there is
    none.
    """
    subsystem_name = subsystem_file.stem
    genes = []
    current_gene = {}
    exchangeables = []
    in_exchangeables = False
    for event, elem in ET.iterparse(subsystem_file, events=("start", "end")):
        if event == "start":
            if elem.tag == "gene":
                if in_exchangeables:
                    exchangeables.append(elem.attrib["name"])
                else:
                    current_gene = {
                        "system": system,
                        "subsystem": subsystem_name,
                        "gene": elem.attrib["name"],
                        "exchangeables": None,
                    }
                    genes.append(current_gene)
            elif elem.tag == "exchangeables":
                in_exchangeables = True
                exchangeables = []
        elif elem.tag == "exchangeables":
            # Attach the collected alternatives to the most recent top-level gene.
            in_exchangeables = False
            current_gene["exchangeables"] = ",".join(exchangeables)
            exchangeables = []
    return genes


@app.command()
def system_operon_structure(
    output: Annotated[
        Path,
        typer.Option(
            file_okay=True,
            dir_okay=False,
            writable=True,
            resolve_path=True,
        ),
    ] = "./system-structures.csv",
    version: Annotated[str, typer.Option(help="Defense finder model")] = "v1.2.4",
    tag: Annotated[str, typer.Option(help="Defense finder model")] = "1.2.4",
):
    """Write a CSV listing every gene of every defense-finder model definition.

    The ``defense-finder-models-<version>.tar.gz`` asset of release ``tag`` is
    downloaded from GitHub into a temporary directory and extracted. Every
    ``definitions/*/<system>/*.xml`` file is then parsed. Columns are ``id``
    (0-based row number), ``system``, ``subsystem``, ``gene`` and
    ``exchangeables``. Directories are walked in sorted order, so ``id`` is
    reproducible between runs. The temporary files are removed afterwards.
    """
    df_model_url = f"https://github.com/mdmparis/defense-finder-models/releases/download/{tag}/defense-finder-models-{version}.tar.gz"
    system_genes = []
    with tempfile.TemporaryDirectory() as tmp_root_dir:
        tmp_root = Path(tmp_root_dir)
        archive_path = tmp_root / "defense-finder-models.tar.gz"
        df_model_dir = tmp_root / f"defense-finder-models-{version}"
        df_model_definitions_dir = (
            df_model_dir / "defense-finder-models" / "definitions"
        )

        console.print(f"Download models: {df_model_url}")
        response = requests.get(df_model_url, allow_redirects=True, timeout=120)
        # Fail loudly instead of trying to untar an HTML error page.
        response.raise_for_status()
        archive_path.write_bytes(response.content)

        console.print("untar file")
        _safe_extract(archive_path, df_model_dir)

        for child in sorted(p for p in df_model_definitions_dir.iterdir() if p.is_dir()):
            for system_path in sorted(p for p in child.iterdir() if p.is_dir()):
                for subsystem in sorted(system_path.glob("*.xml")):
                    console.print(subsystem.stem)
                    system_genes.extend(
                        _parse_subsystem_genes(system_path.name, subsystem)
                    )

    with open(output, "w", newline="") as f:
        fieldnames = ["id", "system", "subsystem", "gene", "exchangeables"]
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        writer.writeheader()
        for gene_id, gene in enumerate(system_genes):
            writer.writerow({"id": gene_id, **gene})


# mmCIF fragment written first by ``pdb2cif``. It declares a global (id 1) and
# a local (id 2) pLDDT entry in ``ma_qa_metric`` and fills the global value.
# It then opens the ``ma_qa_metric_local`` loop, whose rows ``pdb2cif``
# appends immediately after this text.
# The single ``{:.06}`` placeholder is filled via ``str.format`` with the mean
# of the B-factor column computed in ``pdb2cif`` (6 significant digits).
# NOTE(review): ``pdb2cif`` appends the MMCIFIO output (which carries the
# ``data_`` header) after these items; confirm the consumer accepts items
# placed before the data block.
TMP_CIF = """
#
loop_
_ma_qa_metric.id
_ma_qa_metric.mode
_ma_qa_metric.name
_ma_qa_metric.software_group_id
_ma_qa_metric.type
1 global pLDDT 1 pLDDT 
2 local  pLDDT 1 pLDDT 
#
_ma_qa_metric_global.metric_id    1
_ma_qa_metric_global.metric_value {:.06}
_ma_qa_metric_global.model_id     1
_ma_qa_metric_global.ordinal_id   1
#
loop_
_ma_qa_metric_local.label_asym_id
_ma_qa_metric_local.label_comp_id
_ma_qa_metric_local.label_seq_id
_ma_qa_metric_local.metric_id
_ma_qa_metric_local.metric_value
_ma_qa_metric_local.model_id
_ma_qa_metric_local.ordinal_id
"""


def pdb2cif(pdb):
    """Convert the PDB file ``pdb`` into an mmCIF file written next to it.

    The output path is the input path cut at its first ``.pdb`` plus
    ``.cif``. The file contains, in order:

    1. ``TMP_CIF``, whose global value is the mean atom B-factor over the
       de-duplicated rows;
    2. one space-separated ``ma_qa_metric_local`` row per distinct
       (chain, residue name, residue number, B-factor) combination;
    3. the whole structure as written by ``MMCIFIO``.

    NOTE(review): B-factors are presumably pLDDT values (AlphaFold-style
    models) — confirm for other inputs.
    """
    cif = str(pdb).split(".pdb")[0] + ".cif"
    console.print(f"convert {pdb} -> {cif}")
    parsed = PDBParser().get_structure("", pdb)

    # Column order follows the ma_qa_metric_local loop header of TMP_CIF:
    # asym id, comp id, seq id, metric id ("2" = local pLDDT), value,
    # model id, ordinal id.
    metric_rows = []
    for atom in parsed.get_atoms():
        residue = atom.parent
        chain = residue.parent
        residue_number = residue.id[1]
        metric_rows.append(
            [
                chain.id,
                residue.resname,
                residue_number,
                "2",
                atom.bfactor,
                1,
                residue_number,
            ]
        )
    local_metrics = pd.DataFrame(metric_rows).drop_duplicates()
    global_value = local_metrics[4].mean()

    with open(cif, "w") as handle:
        handle.write(TMP_CIF.format(global_value))
    local_metrics.to_csv(cif, index=False, header=False, mode="a", sep=" ")

    writer = MMCIFIO()
    writer.set_structure(parsed)
    with open(cif, "a") as handle:
        writer.save(handle)


def pae2png(tsv_file, png_file):
    """Render the PAE table in ``tsv_file`` as a heat-map PNG at ``png_file``.

    The first column of the tab-separated file is used as the row index. All
    other cells are drawn with a fixed 0-35 colour range (``Greens_r``) and
    saved at 150 dpi with a transparent background.

    Parsing errors from ``pandas.read_table`` (e.g. ``ParserError``) propagate.
    """
    console.print(f"Convert : {tsv_file} row data to {png_file}")
    v = pd.read_table(tsv_file, index_col=0, low_memory=False)
    fig, ax = plt.subplots(1, 1, figsize=(4, 3), facecolor=None)
    try:
        # Operate on this figure explicitly instead of pyplot's "current" one.
        m = ax.matshow(
            v, cmap="Greens_r", vmin=0, vmax=35, aspect="auto", origin="lower"
        )
        cbar = fig.colorbar(m, ax=ax, fraction=0.046, pad=0.04)
        ax.set_xlabel("Scored Residues")
        ax.set_ylabel("Aligned Residues")
        cbar.set_label("Expected Position Error (Å)")
        fig.tight_layout()
        fig.savefig(png_file, dpi=150, facecolor=None, transparent=True)
    finally:
        # Release the figure even when plotting or saving fails, so repeated
        # calls do not accumulate open figures.
        plt.close(fig)


@app.command(help="Remove version from sys_id")
def refseq(
    input: Annotated[
        Path,
        typer.Option(
            # Only read: require an existing readable file, not write access.
            exists=True,
            file_okay=True,
            dir_okay=False,
            readable=True,
        ),
    ],
    output: Annotated[
        Path,
        typer.Option(
            file_okay=True,
            dir_okay=False,
            writable=True,
            resolve_path=True,
        ),
    ],
):
    """Copy the ``input`` CSV to ``output`` with version-less ``sys_id`` values.

    An empty ``sys_id`` is first filled with ``<Assembly>_<replicon>``. Then a
    leading ``<word>.<digits>`` loses its ``.<digits>`` part, e.g.
    ``NZ_CP0001.1_5`` -> ``NZ_CP0001_5``. All other columns are copied
    unchanged. An empty input produces an empty output.
    """
    version_pattern = re.compile(r"^(\w+)\.\d+(_.*)$")
    # newline="" as required by the csv module (avoids blank lines on Windows).
    with open(output, "w", newline="") as out, open(
        input, "r", newline=""
    ) as refseq_f:
        reader = csv.DictReader(refseq_f)
        fieldnames = reader.fieldnames
        if fieldnames is None:
            # Empty input: DictWriter cannot write a header without field names.
            return
        writer = csv.DictWriter(out, fieldnames=fieldnames)
        writer.writeheader()
        for row in reader:
            if row["sys_id"] == "":
                row["sys_id"] = f'{row["Assembly"]}_{row["replicon"]}'
            result = version_pattern.sub(r"\1\2", row["sys_id"])
            console.print(f"[green]{row['sys_id']} ->  {result}")
            row["sys_id"] = result
            writer.writerow(row)


@app.command(
    help='Remove "No system found" hits if the are not the only hit for an assembly'
)
def refseq_sanitized_hits(
    input: Annotated[
        Path,
        typer.Option(
            # Only read: require an existing readable file. ``writable=True``
            # made click reject read-only inputs.
            exists=True,
            file_okay=True,
            dir_okay=False,
            readable=True,
        ),
    ],
    output: Annotated[
        Path,
        typer.Option(
            file_okay=True,
            dir_okay=False,
            writable=True,
            resolve_path=True,
        ),
    ],
):
    """Write ``input`` without redundant "No system found" rows to ``output``.

    Filtering (and version removal from ``Assembly``) is done by
    ``_sanitized_refseq_hits``. ``reset_index()`` runs before writing, so the
    CSV gains an ``index`` column holding the original row labels. The
    filtered DataFrame is also returned to Python callers.
    """
    df = pd.read_csv(input)
    df_final = _sanitized_refseq_hits(df)
    df_final.reset_index().to_csv(output, index=False)
    return df_final


@app.command(help="Group hits per assembly and types (from 'sanitized-hits')")
def refseq_group_per_assembly_and_type(
    input: Annotated[
        Path,
        typer.Option(
            # Only read: require an existing readable file, not write access.
            exists=True,
            file_okay=True,
            dir_okay=False,
            readable=True,
        ),
    ],
    output: Annotated[
        Path,
        typer.Option(
            file_okay=True,
            dir_okay=False,
            writable=True,
            resolve_path=True,
        ),
    ],
):
    """Count sanitized hits per (Assembly, type, taxonomy) combination.

    Rows are first filtered by ``_sanitized_refseq_hits``. NaN keys are kept
    as their own groups (``dropna=False``). The CSV holds the group columns, a
    ``size`` count, and an extra ``index`` column added by ``reset_index()``.
    """
    df = pd.read_csv(input)
    df_final = _sanitized_refseq_hits(df)
    df_final_grouped = df_final.groupby(
        [
            "Assembly",
            "type",
            "Superkingdom",
            "phylum",
            "class",
            "order",
            "family",
            "genus",
            "species",
        ],
        as_index=False,
        dropna=False,
    ).size()
    df_final_grouped.reset_index().to_csv(output, index=False)


@app.command()
def refseq_group_per_assembly(
    input: Annotated[
        Path,
        typer.Option(
            # Only read: require an existing readable file, not write access.
            exists=True,
            file_okay=True,
            dir_okay=False,
            readable=True,
        ),
    ],
    output: Annotated[
        Path,
        typer.Option(
            file_okay=True,
            dir_okay=False,
            writable=True,
            resolve_path=True,
        ),
    ],
):
    """Count rows per version-less Assembly and taxonomy (no hit sanitization).

    ``Assembly`` values are cut at their first ``.`` via ``remove_version``.
    NaN keys are kept as their own groups (``dropna=False``). The CSV holds
    the group columns, a ``size`` count, and an ``index`` column from
    ``reset_index()``.
    """
    df = pd.read_csv(input)

    df["Assembly"] = df["Assembly"].apply(remove_version)
    df_grouped = df.groupby(
        [
            "Assembly",
            "Superkingdom",
            "phylum",
            "class",
            "order",
            "family",
            "genus",
            "species",
        ],
        as_index=False,
        dropna=False,
    ).size()
    df_grouped.reset_index().to_csv(output, index=False)


@app.command()
def refseq_type_count(
    input: Annotated[
        Path,
        typer.Option(
            # Only read: require an existing readable file, not write access.
            exists=True,
            file_okay=True,
            dir_okay=False,
            readable=True,
            help="csv file with type and taxo (No system found removed when other system are founded in the same assembly)",
        ),
    ],
    output: Annotated[
        Path,
        typer.Option(
            file_okay=True,
            dir_okay=False,
            writable=True,
            resolve_path=True,
        ),
    ],
):
    """Count rows per ``type`` value of ``input`` and write the counts to ``output``.

    NaN types form their own group (``dropna=False``). The CSV holds ``type``,
    ``size`` and an ``index`` column added by ``reset_index()``.
    """
    df = pd.read_csv(input)
    grouped_per_type = df.groupby(["type"], as_index=False, dropna=False).size()
    grouped_per_type.reset_index().to_csv(output, index=False)


def _has_required_sections(text):
    """Return True if ``text`` has the Structure, Experimental validation and
    Distribution-among-prokaryotes ``##`` headings (case-insensitive)."""
    patterns = (
        r"#{2}\s+Structure",
        r"#{2}\s+Experimental\s+validation",
        r"#{2}\s+Distribution\s+of\s+the\s+system\s+among\s+prokaryotes",
    )
    return all(
        re.search(pattern, text, flags=re.IGNORECASE | re.MULTILINE)
        for pattern in patterns
    )


def _inject_components(text):
    """Replace two section bodies of ``text`` with component placeholders.

    The body of "Structure" (up to "Experimental validation") becomes an
    ``::article-structure`` block. The body of "Distribution of the system
    among prokaryotes" (up to "Structure") becomes an
    ``::article-system-distribution-plot`` block.
    """
    text = re.sub(
        r"#{2}\s+Structure.*?#{2}\s+Experimental\s+validation",
        "## Structure\n\n::article-structure\n::\n\n## Experimental validation",
        text,
        flags=re.DOTALL | re.IGNORECASE,
    )
    return re.sub(
        r"#{2}\s+Distribution\s+of\s+the\s+system\s+among\s+prokaryotes.*?#{2}\s+Structure",
        "## Distribution of the system among prokaryotes\n\n::article-system-distribution-plot\n::\n\n## Structure",
        text,
        flags=re.DOTALL | re.IGNORECASE,
    )


@app.command()
def markdown(
    dir: Annotated[
        Path,
        typer.Option(
            exists=True,
            file_okay=False,
            writable=True,
            readable=True,
            resolve_path=True,
            help="Dir where all systems article are",
        ),
    ],
):
    """Insert article components into every ``.md`` file directly in ``dir``.

    Files containing the three expected sections are rewritten in place (see
    ``_inject_components``). Other files are reported for manual review.
    """
    for file in dir.iterdir():
        if file.suffix != ".md":
            continue
        console.rule(f"[bold blue]{file.name}", style="blue")
        # The article is read completely before being rewritten, so no
        # temporary copy is needed. The former mkstemp copy leaked a file
        # descriptor and a temp file per article.
        all_file = file.read_text(encoding="utf-8")
        if _has_required_sections(all_file):
            file.write_text(_inject_components(all_file), encoding="utf-8")
        else:
            console.log("[bold red]check it manually")


def remove_version(assembly):
    return assembly.split(".")[0]


def _sanitized_refseq_hits(df):
    df["Assembly"] = df["Assembly"].apply(remove_version)
    # Lower type namesmc
    # df["type"] = df["type"].apply(lambda x: x.lower())

    # Get all row with no system type
    df_no_system = df.loc[df["type"] == "No system found"]
    # unique assembly with no sys
    serie_assembly_with_no_sys = df_no_system["Assembly"].unique()
    # filter assembly to have those with no sys
    df_with_no_sys = df[df["Assembly"].isin(serie_assembly_with_no_sys)]
    # Group them by assembly, type, taxo
    no_sys_assembly_by_size = df_with_no_sys.groupby(
        [
            "Assembly",
            "type",
            "Superkingdom",
            "phylum",
            "class",
            "order",
            "family",
            "genus",
            "species",
        ],
        as_index=False,
        dropna=False,
    ).size()

    # count each occurrence
    df_again_per_assembly = no_sys_assembly_by_size.groupby(
        "Assembly", as_index=False, dropna=False
    ).size()
    # filter to keep only size > 1 (when == 1 it means that there is only "no system found for an assembly")
    # so we should keep it
    df_size_sup_1 = df_again_per_assembly[df_again_per_assembly["size"] > 1]
    assembly_where_should_remove_no_sys_found = df_size_sup_1["Assembly"].unique()

    # Construct new dataset to remove entries with no system found
    # while found system on other replicon that belongs to the
    # same assembly
    df_filtered_assembly_only_with_sys = df[
        (df["type"] != "No system found")
        | ~df.Assembly.isin(assembly_where_should_remove_no_sys_found)
    ]
    return df_filtered_assembly_only_with_sys