From cc9e1a2c946ce89858bdaf1ed688483cad5c433e Mon Sep 17 00:00:00 2001
From: FreshAirTonight <mugao00@gmail.com>
Date: Wed, 2 Mar 2022 22:17:56 -0500
Subject: [PATCH] minor refactor

1. better way to handle the zero-template scenario
2. removed some deprecated files
---
 src/alphafold/data/complex.py           | 121 +++--
 src/alphafold/data/msa_pairing.py.no_pd | 641 ------------------------
 src/run_alphafold_stage1.py             | 220 --------
 src/run_alphafold_stage2a_comp.py       | 502 -------------------
 src/run_alphafold_stage2b.py            |  81 ---
 5 files changed, 57 insertions(+), 1508 deletions(-)
 delete mode 100644 src/alphafold/data/msa_pairing.py.no_pd
 delete mode 100644 src/run_alphafold_stage1.py
 delete mode 100644 src/run_alphafold_stage2a_comp.py
 delete mode 100644 src/run_alphafold_stage2b.py

diff --git a/src/alphafold/data/complex.py b/src/alphafold/data/complex.py
index 1b58683..973613d 100644
--- a/src/alphafold/data/complex.py
+++ b/src/alphafold/data/complex.py
@@ -38,8 +38,8 @@ def initialize_template_feats(num_templates_, num_res_, is_multimer=False):
     'template_aatype': np.zeros([num_templates_, num_res_, 22], np.float32),
     'template_all_atom_masks': np.zeros([num_templates_, num_res_, 37], np.float32),
     'template_all_atom_positions': np.zeros([num_templates_, num_res_, 37, 3], np.float32),
-    'template_domain_names': np.empty([num_templates_], dtype=object),
-    'template_sequence': np.empty([num_templates_], dtype=object),
+    'template_domain_names': np.empty([num_templates_], dtype=str),
+    'template_sequence': np.empty([num_templates_], dtype=str),
     'template_sum_probs': np.zeros([num_templates_,1], np.float32),
   }
 
@@ -100,6 +100,10 @@ def load_monomer_feature(target, flags):
     mono_feature_dict["template_sequence"] = mono_feature_dict["template_sequence"][:flags.max_template_hits]
     mono_feature_dict["template_sum_probs"] = mono_feature_dict["template_sum_probs"][:flags.max_template_hits,:]
 
+  if T == 0 or flags.no_template:  # deal with the scenario of no template found, or set a null template if requested
+    mono_template_features = initialize_template_feats(1, L, is_multimer=False)
+    mono_feature_dict.update(mono_template_features)
+
   if is_multimer_np:
     for i in range(monomer["copy_number"]):
       f_dict = pipeline_multimer.convert_monomer_features_af2complex(
@@ -418,9 +422,9 @@ def extract_domain_mono(mono_entry):
       mono_msa = np.concatenate((mono_msa, msa[:,sta:end]), axis=1)
       mono_mtx = np.concatenate((mono_mtx, mtx[:,sta:end]), axis=1)
       mono_res = np.concatenate((mono_res, resid[sta:end]), axis=0)
-    all_gap_rows = ~np.all(mono_msa == 21, axis=1) ## remove ones with gaps only
-    features['msa']= mono_msa[all_gap_rows]
-    features['deletion_matrix_int'] = mono_mtx[all_gap_rows]
+    not_all_gap_rows = ~np.all(mono_msa == 21, axis=1) ## remove ones with gaps only
+    features['msa']= mono_msa[not_all_gap_rows]
+    features['deletion_matrix_int'] = mono_mtx[not_all_gap_rows]
     mono_entry['feature_dict'] = features
     mono_sequence = mono_seq
     features['residue_index'] = mono_res
@@ -473,6 +477,7 @@ def extract_template_domain_mult(mono_entry):
   copy_num = mono_entry['copy_number']
   dom_range = mono_entry['domain_range']
   mono_sequence = features['sequence'][0].decode()
+
   if dom_range is not None:
     mono_seq = ''
     for idx, boundary in enumerate(dom_range):
@@ -649,34 +654,28 @@ def template_cropping_and_joining_mono(curr_input):
   monomers = curr_input['monomers']
   new_feature_dict = curr_input['new_feature_dict']
 
+  new_tem = initialize_template_feats(full_num_tem, full_num_res, False)
   col = 0; row = 0
-  if not flags.no_template:
-    new_tem = initialize_template_feats(full_num_tem, full_num_res, False)
-    for mono_entry in monomers:
-      features = extract_template_domain_mono(mono_entry)
-
-      copy_num = mono_entry['copy_number']
-      #num_res = features['template_aatype'].shape[1]
-      num_res = features['msa'].shape[1]
-      num_tem = len(features['template_domain_names'])
-
-      if num_tem == 0:
-        dom_fea = initialize_template_feats(1, num_res, is_multimer=False)
-      else:
-        dom_fea = features
-
-      for i in range(copy_num):
-        col_ = col + num_res
-        row_ = row + num_tem
-
-        new_tem['template_all_atom_positions'][row:row_,col:col_,...] = dom_fea['template_all_atom_positions']
-        new_tem['template_domain_names'][row:row_] = dom_fea['template_domain_names']
-        new_tem['template_sequence'][row:row_] = dom_fea['template_sequence']
-        new_tem['template_sum_probs'][row:row_] = dom_fea['template_sum_probs']
-        new_tem['template_aatype'][row:row_,col:col_,:] = dom_fea['template_aatype']
-        new_tem['template_all_atom_masks'][row:row_,col:col_,:] = dom_fea['template_all_atom_masks']
-        col = col_; row = row_
-    new_feature_dict.update(new_tem)
+  for mono_entry in monomers:
+    features = extract_template_domain_mono(mono_entry)
+
+    copy_num = mono_entry['copy_number']
+    #num_res = features['template_aatype'].shape[1]
+    num_res = features['msa'].shape[1]
+    num_tem = len(features['template_domain_names'])
+
+    for i in range(copy_num):
+      col_ = col + num_res
+      row_ = row + num_tem
+
+      new_tem['template_all_atom_positions'][row:row_,col:col_,...] = features['template_all_atom_positions']
+      new_tem['template_domain_names'][row:row_] = features['template_domain_names']
+      new_tem['template_sequence'][row:row_] = features['template_sequence']
+      new_tem['template_sum_probs'][row:row_] = features['template_sum_probs']
+      new_tem['template_aatype'][row:row_,col:col_,:] = features['template_aatype']
+      new_tem['template_all_atom_masks'][row:row_,col:col_,:] = features['template_all_atom_masks']
+      col = col_; row = row_
+  new_feature_dict.update(new_tem)
   curr_input.update({'new_feature_dict': new_feature_dict})
 
   return curr_input
@@ -702,39 +701,33 @@ def template_cropping_and_joining_mult(curr_input):
   new_feature_dict = curr_input['new_feature_dict']
 
   col = 0; row = 0
-  if not flags.no_template:
-    new_tem = initialize_template_feats(full_num_tem, full_num_res, is_multimer=True)
-    for mono_entry in monomers:
-      features = extract_template_domain_mult(mono_entry)
-      copy_num = mono_entry['copy_number']
-
-      #num_res = features['template_aatype'].shape[1]
-      num_res = features['msa'].shape[1]
-      num_tem = len(features['template_domain_names'])
-      if num_tem == 0:
-        dom_fea = initialize_template_feats(1, num_res, is_multimer=True)
-        dom_fea['asym_id'] = features['asym_id']
-        dom_fea['sym_id'] = features['sym_id']
-        dom_fea['entity_id'] = features['entity_id']
-      else:
-        dom_fea = features
-
-      for i in range(copy_num):
-        col_ = col + num_res
-        row_ = row + num_tem
-
-        new_tem['template_all_atom_positions'][row:row_,col:col_,...] = dom_fea['template_all_atom_positions']
-        new_tem['template_domain_names'][row:row_] = dom_fea['template_domain_names']
-        new_tem['template_sequence'][row:row_] = dom_fea['template_sequence']
-        new_tem['template_sum_probs'][row:row_] = dom_fea['template_sum_probs']
-        new_tem['template_all_atom_mask'][row:row_,col:col_,:] = dom_fea['template_all_atom_mask']
-        new_tem['template_aatype'][row:row_,col:col_] = dom_fea['template_aatype']
-        new_tem['asym_id'][col:col_] = dom_fea['asym_id']
-        new_tem['sym_id'][col:col_] = dom_fea['sym_id']
-        new_tem['entity_id'][col:col_] = dom_fea['entity_id']
-        col = col_; row = row_
-    new_feature_dict.update(new_tem)
+  new_tem = initialize_template_feats(full_num_tem, full_num_res, is_multimer=True)
+  for mono_entry in monomers:
+    features = extract_template_domain_mult(mono_entry)
+    copy_num = mono_entry['copy_number']
+
+    #num_res = features['template_aatype'].shape[1]
+    num_res = features['msa'].shape[1]
+    num_tem = len(features['template_domain_names'])
+
+    for i in range(copy_num):
+      col_ = col + num_res
+      row_ = row + num_tem
+
+      new_tem['template_all_atom_positions'][row:row_,col:col_,...] = features['template_all_atom_positions']
+      new_tem['template_domain_names'][row:row_] = features['template_domain_names']
+      new_tem['template_sequence'][row:row_] = features['template_sequence']
+      new_tem['template_sum_probs'][row:row_] = features['template_sum_probs']
+      new_tem['template_all_atom_mask'][row:row_,col:col_,:] = features['template_all_atom_mask']
+      new_tem['template_aatype'][row:row_,col:col_] = features['template_aatype']
+      new_tem['asym_id'][col:col_] = features['asym_id']
+      new_tem['sym_id'][col:col_] = features['sym_id']
+      new_tem['entity_id'][col:col_] = features['entity_id']
+
+      col = col_; row = row_
+
+  new_feature_dict.update(new_tem)
   curr_input.update({'new_feature_dict': new_feature_dict})
 
   return curr_input
diff --git a/src/alphafold/data/msa_pairing.py.no_pd b/src/alphafold/data/msa_pairing.py.no_pd
deleted file mode 100644
index 4583cd5..0000000
--- a/src/alphafold/data/msa_pairing.py.no_pd
+++ /dev/null
@@ -1,641 +0,0 @@
-# Copyright 2021 DeepMind Technologies Limited
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#      http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
- -"""Pairing logic for multimer data pipeline.""" - -import collections -import functools -import re -import string -from typing import Any, Dict, Iterable, List, Sequence - -from alphafold.common import residue_constants -from alphafold.data import pipeline -import numpy as np -#import pandas as pd -import scipy.linalg - -ALPHA_ACCESSION_ID_MAP = {x: y for y, x in enumerate(string.ascii_uppercase)} -ALPHANUM_ACCESSION_ID_MAP = { - chr: num for num, chr in enumerate(string.ascii_uppercase + string.digits) -} # A-Z,0-9 -NUM_ACCESSION_ID_MAP = {str(x): x for x in range(10)} # 0-9 - -MSA_GAP_IDX = residue_constants.restypes_with_x_and_gap.index('-') -SEQUENCE_GAP_CUTOFF = 0.5 -SEQUENCE_SIMILARITY_CUTOFF = 0.9 - -MSA_PAD_VALUES = {'msa_all_seq': MSA_GAP_IDX, - 'msa_mask_all_seq': 1, - 'deletion_matrix_all_seq': 0, - 'deletion_matrix_int_all_seq': 0, - 'msa': MSA_GAP_IDX, - 'msa_mask': 1, - 'deletion_matrix': 0, - 'deletion_matrix_int': 0} - -MSA_FEATURES = ('msa', 'msa_mask', 'deletion_matrix', 'deletion_matrix_int') -SEQ_FEATURES = ('residue_index', 'aatype', 'all_atom_positions', - 'all_atom_mask', 'seq_mask', 'between_segment_residues', - 'has_alt_locations', 'has_hetatoms', 'asym_id', 'entity_id', - 'sym_id', 'entity_mask', 'deletion_mean', - 'prediction_atom_mask', - 'literature_positions', 'atom_indices_to_group_indices', - 'rigid_group_default_frame') -TEMPLATE_FEATURES = ('template_aatype', 'template_all_atom_positions', - 'template_all_atom_mask',) -# Added for AF2Complex -ADDITIONAL_TEMPLATE_FEATURES = ('template_domain_names', 'template_sequence', - 'template_sum_probs') -CHAIN_FEATURES = ('num_alignments', 'seq_length') - - -domain_name_pattern = re.compile( - r'''^(?P<pdb>[a-z\d]{4}) - \{(?P<bioassembly>[\d+(\+\d+)?])\} - (?P<chain>[a-zA-Z\d]+) - \{(?P<transform_index>\d+)\}$ - ''', re.VERBOSE) - - -def create_paired_features( - chains: Iterable[pipeline.FeatureDict], - prokaryotic: bool, - ) -> List[pipeline.FeatureDict]: - """Returns the original chains with paired NUM_SEQ features. - - Args: - chains: A list of feature dictionaries for each chain. - prokaryotic: Whether the target complex is from a prokaryotic organism. - Used to determine the distance metric for pairing. - - Returns: - A list of feature dictionaries with sequence features including only - rows to be paired. - """ - chains = list(chains) - chain_keys = chains[0].keys() - - if len(chains) < 2: - return chains - else: - updated_chains = [] - paired_chains_to_paired_row_indices = pair_sequences( - chains, prokaryotic) - paired_rows = reorder_paired_rows( - paired_chains_to_paired_row_indices) - - for chain_num, chain in enumerate(chains): - new_chain = {k: v for k, v in chain.items() if '_all_seq' not in k} - for feature_name in chain_keys: - if feature_name.endswith('_all_seq'): - feats_padded = pad_features(chain[feature_name], feature_name) - new_chain[feature_name] = feats_padded[paired_rows[:, chain_num]] - new_chain['num_alignments_all_seq'] = np.asarray( - len(paired_rows[:, chain_num])) - updated_chains.append(new_chain) - return updated_chains - - -def pad_features(feature: np.ndarray, feature_name: str) -> np.ndarray: - """Add a 'padding' row at the end of the features list. - - The padding row will be selected as a 'paired' row in the case of partial - alignment - for the chain that doesn't have paired alignment. - - Args: - feature: The feature to be padded. - feature_name: The name of the feature to be padded. - - Returns: - The feature with an additional padding row. 
- """ - assert feature.dtype != np.dtype(np.string_) - if feature_name in ('msa_all_seq', 'msa_mask_all_seq', - 'deletion_matrix_all_seq', 'deletion_matrix_int_all_seq'): - num_res = feature.shape[1] - padding = MSA_PAD_VALUES[feature_name] * np.ones([1, num_res], - feature.dtype) - elif feature_name in ('msa_uniprot_accession_identifiers_all_seq', - 'msa_species_identifiers_all_seq'): - padding = [b''] - else: - return feature - feats_padded = np.concatenate([feature, padding], axis=0) - return feats_padded - -''' -def _make_msa_df(chain_features: pipeline.FeatureDict) -> pd.DataFrame: - """Makes dataframe with msa features needed for msa pairing.""" - chain_msa = chain_features['msa_all_seq'] - query_seq = chain_msa[0] - per_seq_similarity = np.sum( - query_seq[None] == chain_msa, axis=-1) / float(len(query_seq)) - per_seq_gap = np.sum(chain_msa == 21, axis=-1) / float(len(query_seq)) - msa_df = pd.DataFrame({ - 'msa_species_identifiers': - chain_features['msa_species_identifiers_all_seq'], - 'msa_uniprot_accession_identifiers': - chain_features['msa_uniprot_accession_identifiers_all_seq'], - 'msa_row': - np.arange(len( - chain_features['msa_uniprot_accession_identifiers_all_seq'])), - 'msa_similarity': per_seq_similarity, - 'gap': per_seq_gap - }) - return msa_df - - -def _create_species_dict(msa_df: pd.DataFrame) -> Dict[bytes, pd.DataFrame]: - """Creates mapping from species to msa dataframe of that species.""" - species_lookup = {} - for species, species_df in msa_df.groupby('msa_species_identifiers'): - species_lookup[species] = species_df - return species_lookup -''' - -@functools.lru_cache(maxsize=65536) -def encode_accession(accession_id: str) -> int: - """Map accession codes to the serial order in which they were assigned.""" - alpha = ALPHA_ACCESSION_ID_MAP # A-Z - alphanum = ALPHANUM_ACCESSION_ID_MAP # A-Z,0-9 - num = NUM_ACCESSION_ID_MAP # 0-9 - - coding = 0 - - # This is based on the uniprot accession id format - # https://www.uniprot.org/help/accession_numbers - if accession_id[0] in {'O', 'P', 'Q'}: - bases = (alpha, num, alphanum, alphanum, alphanum, num) - elif len(accession_id) == 6: - bases = (alpha, num, alpha, alphanum, alphanum, num) - elif len(accession_id) == 10: - bases = (alpha, num, alpha, alphanum, alphanum, num, alpha, alphanum, - alphanum, num) - - product = 1 - for place, base in zip(reversed(accession_id), reversed(bases)): - coding += base[place] * product - product *= len(base) - - return coding - - -def _calc_id_diff(id_a: bytes, id_b: bytes) -> int: - return abs(encode_accession(id_a.decode()) - encode_accession(id_b.decode())) - - -def _find_all_accession_matches(accession_id_lists: List[List[bytes]], - diff_cutoff: int = 20 - ) -> List[List[Any]]: - """Finds accession id matches across the chains based on their difference.""" - all_accession_tuples = [] - current_tuple = [] - tokens_used_in_answer = set() - - def _matches_all_in_current_tuple(inp: bytes, diff_cutoff: int) -> bool: - return all((_calc_id_diff(s, inp) < diff_cutoff for s in current_tuple)) - - def _all_tokens_not_used_before() -> bool: - return all((s not in tokens_used_in_answer for s in current_tuple)) - - def dfs(level, accession_id, diff_cutoff=diff_cutoff) -> None: - if level == len(accession_id_lists) - 1: - if _all_tokens_not_used_before(): - all_accession_tuples.append(list(current_tuple)) - for s in current_tuple: - tokens_used_in_answer.add(s) - return - - if level == -1: - new_list = accession_id_lists[level+1] - else: - new_list = [(_calc_id_diff(accession_id, s), s) for - 
s in accession_id_lists[level+1]] - new_list = sorted(new_list) - new_list = [s for d, s in new_list] - - for s in new_list: - if (_matches_all_in_current_tuple(s, diff_cutoff) and - s not in tokens_used_in_answer): - current_tuple.append(s) - dfs(level + 1, s) - current_tuple.pop() - dfs(-1, '') - return all_accession_tuples - -''' -def _accession_row(msa_df: pd.DataFrame, accession_id: bytes) -> pd.Series: - matched_df = msa_df[msa_df.msa_uniprot_accession_identifiers == accession_id] - return matched_df.iloc[0] - - -def _match_rows_by_genetic_distance( - this_species_msa_dfs: List[pd.DataFrame], - cutoff: int = 20) -> List[List[int]]: - """Finds MSA sequence pairings across chains within a genetic distance cutoff. - - The genetic distance between two sequences is approximated by taking the - difference in their UniProt accession ids. - - Args: - this_species_msa_dfs: a list of dataframes containing MSA features for - sequences for a specific species. If species is missing for a chain, the - dataframe is set to None. - cutoff: the genetic distance cutoff. - - Returns: - A list of lists, each containing M indices corresponding to paired MSA rows, - where M is the number of chains. - """ - num_examples = len(this_species_msa_dfs) # N - - accession_id_lists = [] # M - match_index_to_chain_index = {} - for chain_index, species_df in enumerate(this_species_msa_dfs): - if species_df is not None: - accession_id_lists.append( - list(species_df.msa_uniprot_accession_identifiers.values)) - # Keep track of which of the this_species_msa_dfs are not None. - match_index_to_chain_index[len(accession_id_lists) - 1] = chain_index - - all_accession_id_matches = _find_all_accession_matches( - accession_id_lists, cutoff) # [k, M] - - all_paired_msa_rows = [] # [k, N] - for accession_id_match in all_accession_id_matches: - paired_msa_rows = [] - for match_index, accession_id in enumerate(accession_id_match): - # Map back to chain index. - chain_index = match_index_to_chain_index[match_index] - seq_series = _accession_row( - this_species_msa_dfs[chain_index], accession_id) - - if (seq_series.msa_similarity > SEQUENCE_SIMILARITY_CUTOFF or - seq_series.gap > SEQUENCE_GAP_CUTOFF): - continue - else: - paired_msa_rows.append(seq_series.msa_row) - # If a sequence is skipped based on sequence similarity to the respective - # target sequence or a gap cuttoff, the lengths of accession_id_match and - # paired_msa_rows will be different. Skip this match. - if len(paired_msa_rows) == len(accession_id_match): - paired_and_non_paired_msa_rows = np.array([-1] * num_examples) - matched_chain_indices = list(match_index_to_chain_index.values()) - paired_and_non_paired_msa_rows[matched_chain_indices] = paired_msa_rows - all_paired_msa_rows.append(list(paired_and_non_paired_msa_rows)) - return all_paired_msa_rows - - -def _match_rows_by_sequence_similarity(this_species_msa_dfs: List[pd.DataFrame] - ) -> List[List[int]]: - """Finds MSA sequence pairings across chains based on sequence similarity. - - Each chain's MSA sequences are first sorted by their sequence similarity to - their respective target sequence. The sequences are then paired, starting - from the sequences most similar to their target sequence. - - Args: - this_species_msa_dfs: a list of dataframes containing MSA features for - sequences for a specific species. - - Returns: - A list of lists, each containing M indices corresponding to paired MSA rows, - where M is the number of chains. 
- """ - all_paired_msa_rows = [] - - num_seqs = [len(species_df) for species_df in this_species_msa_dfs - if species_df is not None] - take_num_seqs = np.min(num_seqs) - - sort_by_similarity = ( - lambda x: x.sort_values('msa_similarity', axis=0, ascending=False)) - - for species_df in this_species_msa_dfs: - if species_df is not None: - species_df_sorted = sort_by_similarity(species_df) - msa_rows = species_df_sorted.msa_row.iloc[:take_num_seqs].values - else: - msa_rows = [-1] * take_num_seqs # take the last 'padding' row - all_paired_msa_rows.append(msa_rows) - all_paired_msa_rows = list(np.array(all_paired_msa_rows).transpose()) - return all_paired_msa_rows - - -def pair_sequences(examples: List[pipeline.FeatureDict], - prokaryotic: bool) -> Dict[int, np.ndarray]: - """Returns indices for paired MSA sequences across chains.""" - - num_examples = len(examples) - - all_chain_species_dict = [] - common_species = set() - for chain_features in examples: - msa_df = _make_msa_df(chain_features) - species_dict = _create_species_dict(msa_df) - all_chain_species_dict.append(species_dict) - common_species.update(set(species_dict)) - - common_species = sorted(common_species) - common_species.remove(b'') # Remove target sequence species. - - all_paired_msa_rows = [np.zeros(len(examples), int)] - all_paired_msa_rows_dict = {k: [] for k in range(num_examples)} - all_paired_msa_rows_dict[num_examples] = [np.zeros(len(examples), int)] - - for species in common_species: - if not species: - continue - this_species_msa_dfs = [] - species_dfs_present = 0 - for species_dict in all_chain_species_dict: - if species in species_dict: - this_species_msa_dfs.append(species_dict[species]) - species_dfs_present += 1 - else: - this_species_msa_dfs.append(None) - - # Skip species that are present in only one chain. - if species_dfs_present <= 1: - continue - - if np.any( - np.array([len(species_df) for species_df in - this_species_msa_dfs if - isinstance(species_df, pd.DataFrame)]) > 600): - continue - - # In prokaryotes (and some eukaryotes), interacting genes are often - # co-located on the chromosome into operons. Because of that we can assume - # that if two proteins' intergenic distance is less than a threshold, they - # two proteins will form an an interacting pair. - # In most eukaryotes, a single protein's MSA can contain many paralogs. - # Two genes may interact even if they are not close by genomic distance. - # In case of eukaryotes, some methods pair MSA sequences using sequence - # similarity method. - # See Jinbo Xu's work: - # https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6030867/#B28. - if prokaryotic: - paired_msa_rows = _match_rows_by_genetic_distance(this_species_msa_dfs) - - if not paired_msa_rows: - continue - else: - paired_msa_rows = _match_rows_by_sequence_similarity(this_species_msa_dfs) - all_paired_msa_rows.extend(paired_msa_rows) - all_paired_msa_rows_dict[species_dfs_present].extend(paired_msa_rows) - all_paired_msa_rows_dict = { - num_examples: np.array(paired_msa_rows) for - num_examples, paired_msa_rows in all_paired_msa_rows_dict.items() - } - return all_paired_msa_rows_dict -''' - -def reorder_paired_rows(all_paired_msa_rows_dict: Dict[int, np.ndarray] - ) -> np.ndarray: - """Creates a list of indices of paired MSA rows across chains. - - Args: - all_paired_msa_rows_dict: a mapping from the number of paired chains to the - paired indices. - - Returns: - a list of lists, each containing indices of paired MSA rows across chains. 
- The paired-index lists are ordered by: - 1) the number of chains in the paired alignment, i.e, all-chain pairings - will come first. - 2) e-values - """ - all_paired_msa_rows = [] - - for num_pairings in sorted(all_paired_msa_rows_dict, reverse=True): - paired_rows = all_paired_msa_rows_dict[num_pairings] - paired_rows_product = abs(np.array([np.prod(rows) for rows in paired_rows])) - paired_rows_sort_index = np.argsort(paired_rows_product) - all_paired_msa_rows.extend(paired_rows[paired_rows_sort_index]) - - return np.array(all_paired_msa_rows) - - -def block_diag(*arrs: np.ndarray, pad_value: float = 0.0) -> np.ndarray: - """Like scipy.linalg.block_diag but with an optional padding value.""" - ones_arrs = [np.ones_like(x) for x in arrs] - off_diag_mask = 1.0 - scipy.linalg.block_diag(*ones_arrs) - diag = scipy.linalg.block_diag(*arrs) - diag += (off_diag_mask * pad_value).astype(diag.dtype) - return diag - - -def _correct_post_merged_feats( - np_example: pipeline.FeatureDict, - np_chains_list: Sequence[pipeline.FeatureDict], - pair_msa_sequences: bool) -> pipeline.FeatureDict: - """Adds features that need to be computed/recomputed post merging.""" - - np_example['seq_length'] = np.asarray(np_example['aatype'].shape[0], - dtype=np.int32) - np_example['num_alignments'] = np.asarray(np_example['msa'].shape[0], - dtype=np.int32) - - if not pair_msa_sequences: - # Generate a bias that is 1 for the first row of every block in the - # block diagonal MSA - i.e. make sure the cluster stack always includes - # the query sequences for each chain (since the first row is the query - # sequence). - cluster_bias_masks = [] - for chain in np_chains_list: - mask = np.zeros(chain['msa'].shape[0]) - mask[0] = 1 - cluster_bias_masks.append(mask) - np_example['cluster_bias_mask'] = np.concatenate(cluster_bias_masks) - - # Initialize Bert mask with masked out off diagonals. - msa_masks = [np.ones(x['msa'].shape, dtype=np.float32) - for x in np_chains_list] - - np_example['bert_mask'] = block_diag( - *msa_masks, pad_value=0) - else: - np_example['cluster_bias_mask'] = np.zeros(np_example['msa'].shape[0]) - np_example['cluster_bias_mask'][0] = 1 - - # Initialize Bert mask with masked out off diagonals. - msa_masks = [np.ones(x['msa'].shape, dtype=np.float32) for - x in np_chains_list] - msa_masks_all_seq = [np.ones(x['msa_all_seq'].shape, dtype=np.float32) for - x in np_chains_list] - - msa_mask_block_diag = block_diag( - *msa_masks, pad_value=0) - msa_mask_all_seq = np.concatenate(msa_masks_all_seq, axis=1) - np_example['bert_mask'] = np.concatenate( - [msa_mask_all_seq, msa_mask_block_diag], axis=0) - return np_example - - -def _pad_templates(chains: Sequence[pipeline.FeatureDict], - max_templates: int) -> Sequence[pipeline.FeatureDict]: - """For each chain pad the number of templates to a fixed size. - - Args: - chains: A list of protein chains. - max_templates: Each chain will be padded to have this many templates. - - Returns: - The list of chains, updated to have template features padded to - max_templates. - """ - for chain in chains: - for k, v in chain.items(): - if k in TEMPLATE_FEATURES: - padding = np.zeros_like(v.shape) - padding[0] = max_templates - v.shape[0] - padding = [(0, p) for p in padding] - chain[k] = np.pad(v, padding, mode='constant') - return chains - - -def _merge_features_from_multiple_chains( - chains: Sequence[pipeline.FeatureDict], - pair_msa_sequences: bool) -> pipeline.FeatureDict: - """Merge features from multiple chains. 
- - Args: - chains: A list of feature dictionaries that we want to merge. - pair_msa_sequences: Whether to concatenate MSA features along the - num_res dimension (if True), or to block diagonalize them (if False). - - Returns: - A feature dictionary for the merged example. - """ - merged_example = {} - for feature_name in chains[0]: - feats = [x[feature_name] for x in chains] - feature_name_split = feature_name.split('_all_seq')[0] - if feature_name_split in MSA_FEATURES: - if pair_msa_sequences or '_all_seq' in feature_name: - merged_example[feature_name] = np.concatenate(feats, axis=1) - else: - merged_example[feature_name] = block_diag( - *feats, pad_value=MSA_PAD_VALUES[feature_name]) - elif feature_name_split in SEQ_FEATURES: - merged_example[feature_name] = np.concatenate(feats, axis=0) - elif feature_name_split in TEMPLATE_FEATURES: - merged_example[feature_name] = np.concatenate(feats, axis=1) - elif feature_name_split in CHAIN_FEATURES: - merged_example[feature_name] = np.sum(x for x in feats).astype(np.int32) - else: - merged_example[feature_name] = feats[0] - return merged_example - - -def _merge_homomers_dense_msa( - chains: Iterable[pipeline.FeatureDict]) -> Sequence[pipeline.FeatureDict]: - """Merge all identical chains, making the resulting MSA dense. - - Args: - chains: An iterable of features for each chain. - - Returns: - A list of feature dictionaries. All features with the same entity_id - will be merged - MSA features will be concatenated along the num_res - dimension - making them dense. - """ - entity_chains = collections.defaultdict(list) - for chain in chains: - entity_id = chain['entity_id'][0] - entity_chains[entity_id].append(chain) - - grouped_chains = [] - for entity_id in sorted(entity_chains): - chains = entity_chains[entity_id] - grouped_chains.append(chains) - chains = [ - _merge_features_from_multiple_chains(chains, pair_msa_sequences=True) - for chains in grouped_chains] - return chains - - -def _concatenate_paired_and_unpaired_features( - example: pipeline.FeatureDict) -> pipeline.FeatureDict: - """Merges paired and block-diagonalised features.""" - features = MSA_FEATURES - for feature_name in features: - if feature_name in example: - feat = example[feature_name] - feat_all_seq = example[feature_name + '_all_seq'] - merged_feat = np.concatenate([feat_all_seq, feat], axis=0) - example[feature_name] = merged_feat - example['num_alignments'] = np.array(example['msa'].shape[0], - dtype=np.int32) - return example - - -def merge_chain_features(np_chains_list: List[pipeline.FeatureDict], - pair_msa_sequences: bool, - max_templates: int) -> pipeline.FeatureDict: - """Merges features for multiple chains to single FeatureDict. - - Args: - np_chains_list: List of FeatureDicts for each chain. - pair_msa_sequences: Whether to merge paired MSAs. - max_templates: The maximum number of templates to include. - - Returns: - Single FeatureDict for entire complex. - """ - np_chains_list = _pad_templates( - np_chains_list, max_templates=max_templates) - np_chains_list = _merge_homomers_dense_msa(np_chains_list) - # Unpaired MSA features will be always block-diagonalised; paired MSA - # features will be concatenated. 
- np_example = _merge_features_from_multiple_chains( - np_chains_list, pair_msa_sequences=False) - if pair_msa_sequences: - np_example = _concatenate_paired_and_unpaired_features(np_example) - np_example = _correct_post_merged_feats( - np_example=np_example, - np_chains_list=np_chains_list, - pair_msa_sequences=pair_msa_sequences) - - return np_example - - -def deduplicate_unpaired_sequences( - np_chains: List[pipeline.FeatureDict]) -> List[pipeline.FeatureDict]: - """Removes unpaired sequences which duplicate a paired sequence.""" - - feature_names = np_chains[0].keys() - msa_features = MSA_FEATURES - - for chain in np_chains: - sequence_set = set(tuple(s) for s in chain['msa_all_seq']) - keep_rows = [] - # Go through unpaired MSA seqs and remove any rows that correspond to the - # sequences that are already present in the paired MSA. - for row_num, seq in enumerate(chain['msa']): - if tuple(seq) not in sequence_set: - keep_rows.append(row_num) - for feature_name in feature_names: - if feature_name in msa_features: - if keep_rows: - chain[feature_name] = chain[feature_name][keep_rows] - else: - new_shape = list(chain[feature_name].shape) - new_shape[0] = 0 - chain[feature_name] = np.zeros(new_shape, - dtype=chain[feature_name].dtype) - chain['num_alignments'] = np.array(chain['msa'].shape[0], dtype=np.int32) - return np_chains diff --git a/src/run_alphafold_stage1.py b/src/run_alphafold_stage1.py deleted file mode 100644 index 7ce2e00..0000000 --- a/src/run_alphafold_stage1.py +++ /dev/null @@ -1,220 +0,0 @@ -# Copyright 2021 DeepMind Technologies Limited -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Run data pipeline to generate input features for alphafold, save the features - -"""AlphaFold Stage 1: data pipeline for the generation of input features.""" -import json -import os -import pathlib -import pickle -import random -import sys -import time -from typing import Dict -from socket import gethostname - -from absl import app -from absl import flags -from absl import logging -import numpy as np - -from alphafold.common import protein -from alphafold.data import pipeline -from alphafold.data import templates -from alphafold.model import data -from alphafold.model import config -#from alphafold.model import model -#from alphafold.relax import relax -# Internal import (7716). - -flags.DEFINE_list('fasta_paths', None, 'Paths to FASTA files, each containing ' - 'one sequence. Paths should be separated by commas. 
' - 'All FASTA paths must have a unique basename as the ' - 'basename is used to name the output directories for ' - 'each prediction.') -flags.DEFINE_string('output_dir', None, 'Path to a directory that will ' - 'store the results.') -flags.DEFINE_list('model_names', None, 'Names of models to use.') -flags.DEFINE_string('data_dir', None, 'Path to directory of supporting data.') -flags.DEFINE_string('jackhmmer_binary_path', '/usr/bin/jackhmmer', - 'Path to the JackHMMER executable.') -flags.DEFINE_string('hhblits_binary_path', '/usr/bin/hhblits', - 'Path to the HHblits executable.') -flags.DEFINE_string('hhsearch_binary_path', '/usr/bin/hhsearch', - 'Path to the HHsearch executable.') -flags.DEFINE_string('kalign_binary_path', '/usr/bin/kalign', - 'Path to the Kalign executable.') -flags.DEFINE_string('uniref90_database_path', None, 'Path to the Uniref90 ' - 'database for use by JackHMMER.') -flags.DEFINE_string('mgnify_database_path', None, 'Path to the MGnify ' - 'database for use by JackHMMER.') -flags.DEFINE_string('bfd_database_path', None, 'Path to the BFD ' - 'database for use by HHblits.') -flags.DEFINE_string('small_bfd_database_path', None, 'Path to the small ' - 'version of BFD used with the "reduced_dbs" preset.') -flags.DEFINE_string('uniclust30_database_path', None, 'Path to the Uniclust30 ' - 'database for use by HHblits.') -flags.DEFINE_string('pdb70_database_path', None, 'Path to the PDB70 ' - 'database for use by HHsearch.') -flags.DEFINE_string('template_mmcif_dir', None, 'Path to a directory with ' - 'template mmCIF structures, each named <pdb_id>.cif') -flags.DEFINE_string('max_template_date', None, 'Maximum template release date ' - 'to consider. Important if folding historical test sets.') -flags.DEFINE_string('obsolete_pdbs_path', None, 'Path to file containing a ' - 'mapping from obsolete PDB IDs to the PDB IDs of their ' - 'replacements.') -flags.DEFINE_enum('preset', None, - ['reduced_dbs', 'full_dbs', 'casp14'], - 'Choose preset model configuration - no ensembling and ' - 'smaller genetic database config (reduced_dbs), no ' - 'ensembling and full genetic database config (full_dbs) or ' - 'full genetic database config and 8 model ensemblings ' - '(casp14).') -flags.DEFINE_boolean('benchmark', False, 'Run multiple JAX model evaluations ' - 'to obtain a timing that excludes the compilation time, ' - 'which should be more indicative of the time required for ' - 'inferencing many proteins.') -flags.DEFINE_integer('random_seed', None, 'The random seed for the data ' - 'pipeline. By default, this is randomly generated. 
Note ' - 'that even if this is set, Alphafold may still not be ' - 'deterministic, because processes like GPU inference are ' - 'nondeterministic.') -FLAGS = flags.FLAGS - -MAX_TEMPLATE_HITS = 20 -RELAX_MAX_ITERATIONS = 0 -RELAX_ENERGY_TOLERANCE = 2.39 -RELAX_STIFFNESS = 10.0 -RELAX_EXCLUDE_RESIDUES = [] -RELAX_MAX_OUTER_ITERATIONS = 20 - - -def _check_flag(flag_name: str, preset: str, should_be_set: bool): - if should_be_set != bool(FLAGS[flag_name].value): - verb = 'be' if should_be_set else 'not be' - raise ValueError(f'{flag_name} must {verb} set for preset "{preset}"') - - -def predict_structure( - fasta_path: str, - fasta_name: str, - output_dir_base: str, - data_pipeline: pipeline.DataPipeline, - benchmark: bool, - random_seed: int): - """Predicts structure using AlphaFold for the given sequence.""" - timings = {} - output_dir = os.path.join(output_dir_base, fasta_name) - if not os.path.exists(output_dir): - os.makedirs(output_dir) - msa_output_dir = os.path.join(output_dir, 'msas') - if not os.path.exists(msa_output_dir): - os.makedirs(msa_output_dir) - - # Get features. - t_0 = time.time() - feature_dict = data_pipeline.process( - input_fasta_path=fasta_path, - msa_output_dir=msa_output_dir) - timings['features'] = time.time() - t_0 - - # Write out features as a pickled dictionary. - features_output_path = os.path.join(output_dir, 'features.pkl') - with open(features_output_path, 'wb') as f: - pickle.dump(feature_dict, f, protocol=4) - - logging.info('Final timings for %s: %s', fasta_name, timings) - - timings_output_path = os.path.join(output_dir, 'timings_fea.json') - with open(timings_output_path, 'w') as f: - f.write(json.dumps(timings, indent=4)) - - -def main(argv): - if len(argv) > 1: - raise app.UsageError('Too many command-line arguments.') - - use_small_bfd = FLAGS.preset == 'reduced_dbs' - _check_flag('small_bfd_database_path', FLAGS.preset, - should_be_set=use_small_bfd) - _check_flag('bfd_database_path', FLAGS.preset, - should_be_set=not use_small_bfd) - _check_flag('uniclust30_database_path', FLAGS.preset, - should_be_set=not use_small_bfd) - - if FLAGS.preset in ('reduced_dbs', 'full_dbs'): - num_ensemble = 1 - elif FLAGS.preset == 'casp14': - num_ensemble = 8 - - # Check for duplicate FASTA file names. 
- fasta_names = [pathlib.Path(p).stem for p in FLAGS.fasta_paths] - if len(fasta_names) != len(set(fasta_names)): - raise ValueError('All FASTA paths must have a unique basename.') - - template_featurizer = templates.TemplateHitFeaturizer( - mmcif_dir=FLAGS.template_mmcif_dir, - max_template_date=FLAGS.max_template_date, - max_hits=MAX_TEMPLATE_HITS, - kalign_binary_path=FLAGS.kalign_binary_path, - release_dates_path=None, - obsolete_pdbs_path=FLAGS.obsolete_pdbs_path) - - data_pipeline = pipeline.DataPipeline( - jackhmmer_binary_path=FLAGS.jackhmmer_binary_path, - hhblits_binary_path=FLAGS.hhblits_binary_path, - hhsearch_binary_path=FLAGS.hhsearch_binary_path, - uniref90_database_path=FLAGS.uniref90_database_path, - mgnify_database_path=FLAGS.mgnify_database_path, - bfd_database_path=FLAGS.bfd_database_path, - uniclust30_database_path=FLAGS.uniclust30_database_path, - small_bfd_database_path=FLAGS.small_bfd_database_path, - pdb70_database_path=FLAGS.pdb70_database_path, - template_featurizer=template_featurizer, - use_small_bfd=use_small_bfd) - - random_seed = FLAGS.random_seed - if random_seed is None: - random_seed = random.randrange(sys.maxsize) - logging.info('Using random seed %d for the data pipeline', random_seed) - - # Predict structure for each of the sequences. - for fasta_path, fasta_name in zip(FLAGS.fasta_paths, fasta_names): - host_name = gethostname() - print(f"Info: working on target {fasta_name} at {host_name}") - predict_structure( - fasta_path=fasta_path, - fasta_name=fasta_name, - output_dir_base=FLAGS.output_dir, - data_pipeline=data_pipeline, - benchmark=FLAGS.benchmark, - random_seed=random_seed) - - -if __name__ == '__main__': - flags.mark_flags_as_required([ - 'fasta_paths', - 'output_dir', - 'data_dir', - 'preset', - 'uniref90_database_path', - 'mgnify_database_path', - 'pdb70_database_path', - 'template_mmcif_dir', - 'max_template_date', - 'obsolete_pdbs_path', - ]) - - app.run(main) diff --git a/src/run_alphafold_stage2a_comp.py b/src/run_alphafold_stage2a_comp.py deleted file mode 100644 index 35b562d..0000000 --- a/src/run_alphafold_stage2a_comp.py +++ /dev/null @@ -1,502 +0,0 @@ -# Copyright 2021 DeepMind Technologies Limited -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -# Run AlphaFold DL modelds using pre-generated features w/o model relaxation -# -# Input: pre-generated features from AlphaFold DataPipeline on single sequences -# Output: un-relaxed protein models -# -# Note: AF2Complex is a modified, enhanced version of AlphaFold 2. 
-# Additional unofficial features added: -# -# Predicting models of a protein complex including both homooligomer and heterooligomer -# No MSA pairing required -# New metrics designed for evaluating protein-protein interface -# Saving structure models of all recycles -# Split feature generation (stage 1), DL inference (stage 2a), and model relaxation (stage 2b) -# -# Some other features such as option for dynamically controled number of recycles and -# residue index breaks were taken from ColabFold -# -# Mu Gao and Davi Nakajima An -# Georgia Institute of Technology - -"""Enhanced AlphaFold Stage2a: protein complex structure prediction with deep learning""" -import json -import os -import pickle -import random -import sys -import time -import re -from typing import Dict - -from absl import app -from absl import flags -from absl import logging - -from alphafold.common import protein -from alphafold.model import data -from alphafold.model import config -from alphafold.model import model -from alphafold.data import pipeline - -from alphafold.data.complex import * -from datetime import date - -import numpy as np -# Internal import (7716). - - -flags.DEFINE_string('target_lst_path', None, 'Path to a file containing a list of targets ' - 'in any monomer, homo- or hetero-oligomers ' - 'configurations. For example, TarA is a monomer, TarA:2 is a dimer ' - 'of two TarAs. TarA:2/TarB is a trimer of two TarA and one TarB, etc.' - ) -flags.DEFINE_string('output_dir', None, 'Path to a directory that will ' - 'store the results.') -flags.DEFINE_string('feature_dir', None, 'Path to a directory that will ' - 'contains pre-genearted feature in pickle format.') -flags.DEFINE_list('model_names', None, 'Names of deep learning models to use.') -flags.DEFINE_string('data_dir', None, 'Path to directory of supporting data.') -flags.DEFINE_enum('preset', None, - ['reduced_dbs', 'casp14', 'economy', 'super', 'super2', 'genome', 'genome2'], - 'Choose preset model configuration: <reduced_dbs> no ensembling, ' - '<economy> no ensemble, up to 256 MSA clusters, recycling up to 3 rounds; ' - '<super, super2> 1 or 2 ensembles, up to 512 MSA clusters, recycling up to 20 rounds; ' - '<genome, genome2> 1 or 2 ensembles, up to 512 MSA clusters, max number ' - 'of recycles and ensembles adjusted according to input sequence length; ' - 'or <casp14> 8 model ensemblings of the factory settings.') -flags.DEFINE_integer('random_seed', None, 'The random seed for the data ' - 'pipeline. By default, this is randomly generated. Note ' - 'that even if this is set, Alphafold may still not be ' - 'deterministic, because processes like GPU inference are ' - 'nondeterministic.') -flags.DEFINE_integer('max_recycles', None, 'The maximum number of recycles.', lower_bound=1) -flags.DEFINE_float('recycle_tol', None, 'The tolerance for recycling, caculated as the RMSD change ' - 'in the distogram of backbone Ca atoms. 
Recycling stops ' - 'if the change is smaller than this value', lower_bound=0.0, upper_bound=2.0) -flags.DEFINE_integer('num_ensemble', None, 'The number of ensembles of each model, 1 means no ensembling.', lower_bound=1) -flags.DEFINE_integer('max_msa_clusters', None, 'The maximum number of MSA clusters.', lower_bound=1) -flags.DEFINE_integer('max_extra_msa', None, 'The maximum number of extra MSA clusters.', lower_bound=1) -flags.DEFINE_boolean('write_complex_features', False, 'Save the feature dict for ' - 'complex prediction as a pickle file under the output direcotry') -flags.DEFINE_enum('template_mode', 'oligomer', ['none', 'monomer', 'oligomer'], - 'none - No template is allowed, ' - 'monomer - Use template only for monomer but not oligomer modeling, ' - 'oligomer - Use monomer template for all modeling if exists.') -flags.DEFINE_integer('save_recycled', 0, '0 - no recycle info saving, 1 - print ' - 'metrics of intermediate recycles, 2 - additionally saving pdb structures ' - 'of all recycles, 3 - additionally save all results in pickle ' - 'dictionaries of each recycling iteration.', lower_bound=0, upper_bound=3) - -FLAGS = flags.FLAGS - -MAX_TEMPLATE_HITS = 4 -MAX_MSA_DEPTH_MONO = 10000 ### maximum number of input sequences in the msa of a monomer - -################################################################################################## -# read either a single target string or a input list of targets in a file -# each line has a format like: <target> <length> (output_name), the output_name is optional. -# In <target>, use monomer:num to indicate num of copies in a homooligomer -# and name1/name2 to indicate heterooligomer. For example, TarA:2/TarB is 2 copies of TarA and 1 TarB -def _read_target_file( data_lst_file ): - target_lst = [] - if not os.path.exists( data_lst_file ): ### input is a single target in strings - fields = data_lst_file.split(',') - fullname = name = fields[0] - if len(fields) == 2: - name = fields[1] - target_lst.append( {'full':fullname, 'name':name} ) - else: ### input are a list of targets in a file - with open( data_lst_file ) as file: - for line in file: - if line.startswith("#"): - continue - line = line.strip() # strip "\n" - fields = line.split() - fullname = name = fields[0] - if len(fields) > 2 and not fields[2].startswith("#"): - name = fields[2] - target_lst.append( {'full':fullname, 'name':name} ) - - # process the components of a complex if detected - for target in target_lst: - complex = target['full'] - monomers = [] - subfields = complex.split('/') - for item in subfields: - cols = item.split(':') - if len(cols) == 1: - monomers.append( {cols[0]:1} ) ### monomer - elif len(cols) > 1: - monomers.append( {cols[0]:cols[1]} ) - if len(monomers) >= 1: - target['split'] = monomers - - return target_lst -################################################################################################## - - -################################################################################################## -def predict_structure( - target: Dict[str, str], - output_dir_base: str, - feature_dir_base: str, - model_runners: Dict[str, model.RunModel], - random_seed: int, - max_msa_clusters: int, - max_extra_msa: int, - max_recycles: int, - num_ensemble: int, - preset: str, - write_complex_features: bool, - template_mode: str): - """Predicts structure using AlphaFold for the given sequence.""" - timings = {} - - target_name = target['name'] - target_name = re.sub(":", "_x", target_name) - target_name = re.sub("/", "+", target_name) - - homo_copy = 
[] - seq_names = [] - for homo in target['split']: - for seq_name, seq_copy in homo.items(): - homo_copy.append(int(seq_copy)) - seq_names.append(seq_name) - - time.sleep(random.randint(0,30)) # mitigating creating the same output directory from multiple runs - - # Retrieve pre-generated features of monomers (single protien sequences) - t_0 = time.time() - feature_dicts = [] - for seq_name in seq_names: - feature_dir = os.path.join(feature_dir_base, seq_name) - if not os.path.exists(feature_dir): - raise SystemExit("Error: ", feature_dir, "does not exists") - - # load pre-generated features as a pickled dictionary. - features_input_path = os.path.join(feature_dir, 'features.pkl') - with open(features_input_path, "rb") as f: - mono_feature_dict = pickle.load(f) - N = len(mono_feature_dict["msa"]) - L = len(mono_feature_dict["residue_index"]) - T = len(mono_feature_dict["template_domain_names"]) - print(f"Info: {target_name} found monomer {seq_name} msa_depth = {N}, seq_len = {L}, num_templ = {T}") - if N > MAX_MSA_DEPTH_MONO: - print(f"Info: {seq_name} MSA size is too large, reducing to {MAX_MSA_DEPTH_MONO}") - mono_feature_dict["msa"] = mono_feature_dict["msa"][:MAX_MSA_DEPTH_MONO,:] - mono_feature_dict["deletion_matrix_int"] = mono_feature_dict["deletion_matrix_int"][:MAX_MSA_DEPTH_MONO,:] - mono_feature_dict['num_alignments'][:] = MAX_MSA_DEPTH_MONO - if T > MAX_TEMPLATE_HITS: - print(f"Info: {seq_name} reducing the number of structural templates to {MAX_TEMPLATE_HITS}") - mono_feature_dict["template_aatype"] = mono_feature_dict["template_aatype"][:MAX_TEMPLATE_HITS,...] - mono_feature_dict["template_all_atom_masks"] = mono_feature_dict["template_all_atom_masks"][:MAX_TEMPLATE_HITS,...] - mono_feature_dict["template_all_atom_positions"] = mono_feature_dict["template_all_atom_positions"][:MAX_TEMPLATE_HITS,...] 
- mono_feature_dict["template_domain_names"] = mono_feature_dict["template_domain_names"][:MAX_TEMPLATE_HITS] - mono_feature_dict["template_sequence"] = mono_feature_dict["template_sequence"][:MAX_TEMPLATE_HITS] - mono_feature_dict["template_sum_probs"] = mono_feature_dict["template_sum_probs"][:MAX_TEMPLATE_HITS,:] - feature_dicts.append( mono_feature_dict ) - - # Make features for complex structure prediction using monomer structures if necessary - if len(seq_names) == 1 and homo_copy[0] == 1: # monomer structure prediction - feature_dict = feature_dicts[0] - seq_len = len(feature_dict["residue_index"]) - Ls = [seq_len] - if template_mode == 'none': - new_tem = initialize_template_feats(0, seq_len) - feature_dict.update(new_tem) - else: # complex structure prediction - feature_dict, Ls = make_complex_features(feature_dicts, target_name, homo_copy, template_mode) - - mono_chains = [] - mono_chains = get_mono_chain(seq_names, homo_copy, Ls) - print(f"Info: individual chain(s) to model {mono_chains}") - - N = len(feature_dict["msa"]) - L = len(feature_dict["residue_index"]) - T = len(feature_dict["template_domain_names"]) - print(f"Info: modeling {target_name} with msa_depth = {N}, seq_len = {L}, num_templ = {T}") - timings['features'] = round(time.time() - t_0, 2) - - - output_dir = os.path.join(output_dir_base, target_name) - if not os.path.exists(output_dir): - try: - os.makedirs(output_dir) - except FileExistsError: - print(f"Warning: tried to create an existing {output_dir}, ignored") - - if write_complex_features: - feature_output_path = os.path.join(output_dir, 'features_comp.pkl') - with open(feature_output_path, 'wb') as f: - pickle.dump(feature_dict, f, protocol=4) - - today = date.today().strftime('%Y%m%d') - out_suffix = '_' + today + '_' + str(random_seed)[-6:] - - plddts = {} # predicted LDDT score - iterations = {} # recycle information - ptms = {}; pitms = {} # predicted TM-score - ires = {}; icnt = {} # interfacial residues and contacts - tols = {} # change in backbone pairwise distance to check with the recyle tolerance criterion - ints = {} # interface-score - # Run models for structure prediction - for model_name, model_runner in model_runners.items(): - model_out_name = model_name + out_suffix - logging.info('Running model %s', model_out_name) - t_0 = time.time() - - # set size of msa (to reduce memory requirements) - if max_msa_clusters is not None and max_extra_msa is not None: - msa_clusters = max(min(N, max_msa_clusters),5) - model_runner.config.data.eval.max_msa_clusters = msa_clusters - model_runner.config.data.common.max_extra_msa = max(min(N-msa_clusters,max_extra_msa),1) - if preset in ['genome', 'genome2', 'super']: - max_iter = max_recycles - max(0, (L - 500) // 50) - max_iter = max( 6, max_iter ) - if L > 1180: ### memory limit of a single 16GB GPU, applied to cases with mutliple ensembles - num_en = 1 - else: - num_en = num_ensemble - print(f"Info: {target_name} reset max_recycles = {max_iter}, num_ensemble = {num_en}") - model_runner.config.data.common.num_recycle = max_iter - model_runner.config.model.num_recycle = max_iter - model_runner.config.data.eval.num_ensemble = num_en - - processed_feature_dict = model_runner.process_features( - feature_dict, random_seed=random_seed) - timings[f'process_features_{model_out_name}'] = round(time.time() - t_0, 2) - - t_0 = time.time() - prediction_result, (tot_recycle, tol_value, recycled) = model_runner.predict(processed_feature_dict) - prediction_result['num_recycle'] = tot_recycle - 
prediction_result['mono_chains'] = mono_chains - - tols[model_out_name] = round(tol_value.tolist(), 3) - iterations[model_out_name] = tot_recycle.tolist() ### convert from jax numpy to regular list for json saving - - t_diff = time.time() - t_0 - timings[f'predict_and_compile_{model_out_name}'] = round(t_diff,1) - logging.info( - 'Total JAX model %s predict time (includes compilation time): %.1f seconds', model_out_name, t_diff) - - def _save_results(result, log_model_name, out_dir, recycle_index): - # Get mean pLDDT confidence metric. - plddt = np.mean(result['plddt']) - plddts[log_model_name] = round(plddt, 2) - ptm = 0 - if 'ptm' in result: - ptm = result['ptm'].tolist() - ptms[log_model_name] = round(ptm, 4) - pitm = 0; inter_residues = 0; inter_contacts = 0; inter_sc = 0 - if 'pitm' in result: - pitm = result['pitm']['score'].tolist() - inter_residues = result['pitm']['num_residues'].tolist() - inter_contacts = result['pitm']['num_contacts'].tolist() - pitms[log_model_name] = round(pitm, 4) - ires[log_model_name] = inter_residues - icnt[log_model_name] = int(inter_contacts) - if 'interface' in result: - inter_sc = result['interface']['score'].tolist() - ints[log_model_name] = round(inter_sc, 4) - - if recycle_index < tot_recycle: - tol = result['tol_val'].tolist() - tols[log_model_name] = round(tol, 2) - print(f"Info: {target_name} {log_model_name}, ", - f"tol = {tol:5.2f}, pLDDT = {plddt:.2f}, pTM-score = {ptm:.4f}", end='') - if len(seq_names) > 1 or sum(homo_copy) > 1: # complex target - print(f", piTM-score = {pitm:.4f}, interface-score = {inter_sc:.4f}", end='') - print(f", iRes = {inter_residues:<4d} iCnt = {inter_contacts:<4.0f}") - else: - print('') - else: - print(f"Info: {target_name} {log_model_name} performed {tot_recycle} recycles,", - f"final tol = {tol_value:.2f}, pLDDT = {plddt:.2f}, pTM-score = {ptm:.4f}", end='') - if len(seq_names) > 1 or sum(homo_copy) > 1: - print(f", piTM-score = {pitm:.4f}, interface-score = {inter_sc:.4f}", end='') - print(f", iRes = {inter_residues:<4d} iCnt = {inter_contacts:<4.0f}") - else: - print('') - - # Save the model outputs, not saving pkl for intermeidate recycles to save storage space - if recycle_index == tot_recycle or FLAGS.save_recycled == 3: - result_output_path = os.path.join(out_dir, f'{log_model_name}.pkl') - with open(result_output_path, 'wb') as f: - pickle.dump(result, f, protocol=4) - - if recycle_index == tot_recycle or FLAGS.save_recycled >= 2: - # Set the b-factors to the per-residue plddt - final_atom_mask = result['structure_module']['final_atom_mask'] - b_factors = result['plddt'][:, None] * final_atom_mask - - unrelaxed_protein = protein.from_prediction(processed_feature_dict, - result, b_factors=b_factors) - - unrelaxed_pdb_path = os.path.join(out_dir, f'{log_model_name}.pdb') - with open(unrelaxed_pdb_path, 'w') as f: - f.write(protein.to_pdb(unrelaxed_protein)) - - # output info of intermeidate recycles and save the coordinates - if FLAGS.save_recycled: - recycle_out_dir = os.path.join(output_dir, "recycled") - if FLAGS.save_recycled > 1 and not os.path.exists(recycle_out_dir): - os.mkdir(recycle_out_dir) - for i, rec_dict in enumerate(recycled): - if i < tot_recycle: - _save_results(rec_dict, f"{model_out_name}_recycled_{i:02d}", - recycle_out_dir, i) - - # the final results from this model run - _save_results(prediction_result, model_out_name, output_dir, tot_recycle) - # End of model runs - - # Rank by pTMscore if exists, otherwise pLDDTs - ranked_order = [] - if 'ptm' in prediction_result: - ranking_metric 
= 'pTM' - for idx, (mod_name, _) in enumerate( - sorted(ptms.items(), key=lambda x: x[1], reverse=True)): - ranked_order.append(mod_name) - else: - ranking_metric = 'pLDDT' - for idx, (mod_name, _) in enumerate( - sorted(plddts.items(), key=lambda x: x[1], reverse=True)): - ranked_order.append(mod_name) - - stats = {'plddts': plddts, 'ptms': ptms, 'order': ranked_order, - 'ranking_metric': ranking_metric, 'iterations': iterations, - 'tol_values':tols, 'chains': mono_chains, 'chain_lengths': Ls, - 'timings':timings} - if len(pitms): - stats = { **stats, 'pitms': pitms, 'interfacial residue number': ires, - 'interficial contact number': icnt, 'interface score': ints } - - if len(model_runners) > 1: #more than 1 model - ranking_output_path = os.path.join(output_dir, 'ranking_all'+out_suffix+'.json') - else: #only one model, use different model names to avoid overwriting same file - ranking_output_path = os.path.join(output_dir, 'ranking_'+model_out_name+'.json') - - with open(ranking_output_path, 'w') as f: - f.write(json.dumps(stats, sort_keys=True, indent=4)) - - logging.info('Final timings for %s: %s', target_name, timings) - -################################################################################################## - -def main(argv): - if len(argv) > 1: - raise app.UsageError('Too many command-line arguments.') - - # read a list of target files - target_lst = _read_target_file( FLAGS.target_lst_path ) - - max_recycles = 3; recycle_tol = 0 - max_extra_msa = None; max_msa_clusters = None - print("Info: using preset", FLAGS.preset) - if FLAGS.preset == 'reduced_dbs': - num_ensemble = 1 - elif FLAGS.preset == 'casp14': - num_ensemble = 8 - elif FLAGS.preset == 'economy': - num_ensemble = 1 - recycle_tol = 0.1 - max_extra_msa = 512 - max_msa_clusters = 256 - elif FLAGS.preset in ['super', 'super2']: - num_ensemble = 1 - max_recycles = 20 - recycle_tol = 0.1 - max_extra_msa = 1024 - max_msa_clusters = 512 - if FLAGS.preset == 'super2': num_ensemble = 2 - elif FLAGS.preset in ['genome', 'genome2']: - num_ensemble = 1 - max_recycles = 20 - recycle_tol = 0.5 - max_extra_msa = 1024 - max_msa_clusters = 512 - if FLAGS.preset == 'genome2': num_ensemble = 2 - - # allow customized parameters over preset - if FLAGS.num_ensemble is not None: - num_ensemble = FLAGS.num_ensemble - print(f"Info: set num_ensemble = {num_ensemble}") - if FLAGS.max_recycles is not None: - max_recycles = FLAGS.max_recycles - print(f"Info: set max_recyles = {max_recycles}") - if FLAGS.recycle_tol is not None: - recycle_tol = FLAGS.recycle_tol - print(f"Info: set recycle_tol = {recycle_tol}") - if FLAGS.max_msa_clusters is not None and FLAGS.max_extra_msa is not None: - max_msa_clusters = FLAGS.max_msa_clusters - max_extra_msa = FLAGS.max_extra_msa - print(f"Info: max_msa_clusters = {max_msa_clusters}, max_extra_msa = {max_extra_msa}") - - model_runners = {} - for model_name in FLAGS.model_names: - model_config = config.model_config(model_name) - model_config.data.eval.num_ensemble = num_ensemble - model_config.data.common.num_recycle = max_recycles - model_config.model.num_recycle = max_recycles - model_config.model.recycle_tol = recycle_tol - model_config.model.save_recycled = FLAGS.save_recycled - - model_params = data.get_model_haiku_params( - model_name=model_name, data_dir=FLAGS.data_dir) - model_runner = model.RunModel(model_config, model_params) - model_runners[model_name] = model_runner - - logging.info('Have %d models: %s', len(model_runners), - list(model_runners.keys())) - - random_seed = FLAGS.random_seed - 
if random_seed is None: - random_seed = random.randrange(sys.maxsize) - logging.info('Using random seed %d for the data pipeline', random_seed) - - # Predict structure for each target. - for target in target_lst: - target_name = target['name'] - target_split = target['split'] - print(f"Info: working on target {target_name}") - predict_structure( - target=target, - output_dir_base=FLAGS.output_dir, - feature_dir_base=FLAGS.feature_dir, - model_runners=model_runners, - random_seed=random_seed, - max_msa_clusters=max_msa_clusters, - max_extra_msa=max_extra_msa, - max_recycles=max_recycles, - num_ensemble=num_ensemble, - preset=FLAGS.preset, - write_complex_features=FLAGS.write_complex_features, - template_mode=FLAGS.template_mode - ) - - -if __name__ == '__main__': - flags.mark_flags_as_required([ - 'target_lst_path', - 'output_dir', - 'feature_dir', - 'model_names', - 'data_dir', - 'preset', - ]) - - app.run(main) diff --git a/src/run_alphafold_stage2b.py b/src/run_alphafold_stage2b.py deleted file mode 100644 index 806ceed..0000000 --- a/src/run_alphafold_stage2b.py +++ /dev/null @@ -1,81 +0,0 @@ -"""Run MD minization to relax a protein structure model from AF2""" -import os -import pickle -import re -import time - -from absl import app -from absl import flags -from tqdm import tqdm - -from alphafold.relax import relax -from alphafold.common import protein - -from run_alphafold_stage2a_comp import _read_target_file, FLAGS - -import numpy as np - - - -MAX_TEMPLATE_HITS = 20 -RELAX_MAX_ITERATIONS = 0 -RELAX_ENERGY_TOLERANCE = 2.39 -RELAX_STIFFNESS = 10.0 -RELAX_EXCLUDE_RESIDUES = [] -RELAX_MAX_OUTER_ITERATIONS = 20 - - -def main(argv): - if len(argv) > 1: - raise app.UsageError('Too many command-line arguments.') - - amber_relaxer = relax.AmberRelaxation( - max_iterations=RELAX_MAX_ITERATIONS, - tolerance=RELAX_ENERGY_TOLERANCE, - stiffness=RELAX_STIFFNESS, - exclude_residues=RELAX_EXCLUDE_RESIDUES, - max_outer_iterations=RELAX_MAX_OUTER_ITERATIONS) - - # read list of targets - target_lst = _read_target_file( FLAGS.target_lst_path ) - - for target in target_lst: - # get target name - target_name = target['name'] - target_name = re.sub(":", "_x", target_name) - target_name = re.sub("/", "+", target_name) - target_dir = os.path.join(FLAGS.output_dir, target_name) - - for afile in os.listdir(target_dir): - # find all unrelaxed pdb files, relaxed ones with 'relaxed' as prefix - if ".pdb" in afile and afile.startswith("model_"): - print(f"Info: {target_name} processing {afile}") - unrelaxed_pdb_file = afile - unrelaxed_pdb_path = os.path.join(target_dir, unrelaxed_pdb_file) - with open(unrelaxed_pdb_path, "r") as f: - unrelaxed_pdb_str = f.read() - - unrelaxed_protein = protein.from_pdb_string( unrelaxed_pdb_str ) - - # Relax the prediction. - t_0 = time.time() - relaxed_pdb_str, _, _ = amber_relaxer.process(prot=unrelaxed_protein) - relaxation_time = time.time() - t_0 - - # Save the relaxed PDB. - relaxed_output_path = os.path.join(target_dir, f'relaxed_{unrelaxed_pdb_file}') - with open(relaxed_output_path, 'w') as f: - f.write(relaxed_pdb_str) - - print(f"Info: {target_name} relaxation done, time spent {relaxation_time:.1f} seconds") - - -if __name__ == '__main__': - flags.mark_flags_as_required([ - 'target_lst_path', - 'output_dir', - 'feature_dir', - 'template_mode', - ]) - - app.run(main) -- GitLab
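
Note on the complex.py change above: the refactor moves the zero-template handling up into load_monomer_feature, so a monomer with no template hits (or with templates disabled via flags.no_template) receives a single null template before the per-chain features are cropped and joined, and the per-copy `if num_tem == 0` fallbacks in template_cropping_and_joining_mono/_mult are no longer needed. The sketch below illustrates that pattern in isolation; the example mono_feature_dict and the no_template variable are illustrative stand-ins rather than real pipeline data, and only the monomer template fields are shown.

import numpy as np

def initialize_template_feats(num_templates_, num_res_, is_multimer=False):
  # Null template features for a chain of num_res_ residues, mirroring the
  # patched helper in src/alphafold/data/complex.py (monomer fields only;
  # the multimer variant adds asym_id/sym_id/entity_id).
  return {
      'template_aatype': np.zeros([num_templates_, num_res_, 22], np.float32),
      'template_all_atom_masks': np.zeros([num_templates_, num_res_, 37], np.float32),
      'template_all_atom_positions': np.zeros([num_templates_, num_res_, 37, 3], np.float32),
      'template_domain_names': np.empty([num_templates_], dtype=str),
      'template_sequence': np.empty([num_templates_], dtype=str),
      'template_sum_probs': np.zeros([num_templates_, 1], np.float32),
  }

# Illustrative monomer features with zero template hits (not real pipeline output).
mono_feature_dict = {
    'msa': np.zeros([8, 50], np.int32),                 # 8 aligned sequences, 50 residues
    'residue_index': np.arange(50),
    'template_domain_names': np.empty([0], dtype=str),  # no template found
}
no_template = False  # stand-in for flags.no_template

L = len(mono_feature_dict['residue_index'])
T = len(mono_feature_dict['template_domain_names'])

# The patched logic: substitute exactly one null template when none were found
# (or when templates are disabled), so downstream joining code never has to
# special-case an empty template stack.
if T == 0 or no_template:
  mono_feature_dict.update(initialize_template_feats(1, L, is_multimer=False))

print(mono_feature_dict['template_aatype'].shape)  # (1, 50, 22)

Because the null features are all zeros (with empty name/sequence entries), copying them into the joined template arrays leaves that chain's template stack effectively empty, which is what lets the dom_fea fallback branches be deleted.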