From cc9e1a2c946ce89858bdaf1ed688483cad5c433e Mon Sep 17 00:00:00 2001
From: FreshAirTonight <mugao00@gmail.com>
Date: Wed, 2 Mar 2022 22:17:56 -0500
Subject: [PATCH] minor refactor

1. better way to handle the zero-template scenario
2. removed some deprecated files
---
 src/alphafold/data/complex.py           | 121 +++--
 src/alphafold/data/msa_pairing.py.no_pd | 641 ------------------------
 src/run_alphafold_stage1.py             | 220 --------
 src/run_alphafold_stage2a_comp.py       | 502 -------------------
 src/run_alphafold_stage2b.py            |  81 ---
 5 files changed, 57 insertions(+), 1508 deletions(-)
 delete mode 100644 src/alphafold/data/msa_pairing.py.no_pd
 delete mode 100644 src/run_alphafold_stage1.py
 delete mode 100644 src/run_alphafold_stage2a_comp.py
 delete mode 100644 src/run_alphafold_stage2b.py

diff --git a/src/alphafold/data/complex.py b/src/alphafold/data/complex.py
index 1b58683..973613d 100644
--- a/src/alphafold/data/complex.py
+++ b/src/alphafold/data/complex.py
@@ -38,8 +38,8 @@ def initialize_template_feats(num_templates_, num_res_, is_multimer=False):
     'template_aatype': np.zeros([num_templates_, num_res_, 22], np.float32),
     'template_all_atom_masks': np.zeros([num_templates_, num_res_, 37], np.float32),
     'template_all_atom_positions': np.zeros([num_templates_, num_res_, 37, 3], np.float32),
-    'template_domain_names': np.empty([num_templates_], dtype=object),
-    'template_sequence': np.empty([num_templates_], dtype=object),
+    'template_domain_names': np.empty([num_templates_], dtype=str),
+    'template_sequence': np.empty([num_templates_], dtype=str),
     'template_sum_probs': np.zeros([num_templates_,1], np.float32),
   }
 
@@ -100,6 +100,10 @@ def load_monomer_feature(target, flags):
     mono_feature_dict["template_sequence"] = mono_feature_dict["template_sequence"][:flags.max_template_hits]
     mono_feature_dict["template_sum_probs"] = mono_feature_dict["template_sum_probs"][:flags.max_template_hits,:]
 
+  if T == 0 or flags.no_template:  # deal with the scenario of no template found, or set a null template if requested
+    mono_template_features = initialize_template_feats(1, L, is_multimer=False)
+    mono_feature_dict.update(mono_template_features)
+
   if is_multimer_np:
     for i in range(monomer["copy_number"]):
       f_dict = pipeline_multimer.convert_monomer_features_af2complex(
@@ -418,9 +422,9 @@ def extract_domain_mono(mono_entry):
       mono_msa = np.concatenate((mono_msa, msa[:,sta:end]), axis=1)
       mono_mtx = np.concatenate((mono_mtx, mtx[:,sta:end]), axis=1)
       mono_res = np.concatenate((mono_res, resid[sta:end]), axis=0)
-    all_gap_rows = ~np.all(mono_msa == 21, axis=1) ## remove ones with gaps only
-    features['msa']= mono_msa[all_gap_rows]
-    features['deletion_matrix_int'] = mono_mtx[all_gap_rows]
+    not_all_gap_rows = ~np.all(mono_msa == 21, axis=1) ## remove ones with gaps only
+    features['msa']= mono_msa[not_all_gap_rows]
+    features['deletion_matrix_int'] = mono_mtx[not_all_gap_rows]
     mono_entry['feature_dict'] = features
     mono_sequence = mono_seq
     features['residue_index'] = mono_res
@@ -473,6 +477,7 @@ def extract_template_domain_mult(mono_entry):
   copy_num = mono_entry['copy_number']
   dom_range = mono_entry['domain_range']
   mono_sequence = features['sequence'][0].decode()
+
   if dom_range is not None:
     mono_seq = ''
     for idx, boundary in enumerate(dom_range):
@@ -649,34 +654,28 @@ def template_cropping_and_joining_mono(curr_input):
   monomers = curr_input['monomers']
   new_feature_dict = curr_input['new_feature_dict']
 
+  new_tem = initialize_template_feats(full_num_tem, full_num_res, False)
   col = 0; row = 0
-  if not flags.no_template:
-    new_tem = initialize_template_feats(full_num_tem, full_num_res, False)
-    for mono_entry in monomers:
-      features = extract_template_domain_mono(mono_entry)
-
-      copy_num = mono_entry['copy_number']
-      #num_res = features['template_aatype'].shape[1]
-      num_res = features['msa'].shape[1]
-      num_tem = len(features['template_domain_names'])
-
-      if num_tem == 0:
-        dom_fea = initialize_template_feats(1, num_res, is_multimer=False)
-      else:
-        dom_fea = features
-
-      for i in range(copy_num):
-        col_ = col + num_res
-        row_ = row + num_tem
-
-        new_tem['template_all_atom_positions'][row:row_,col:col_,...] = dom_fea['template_all_atom_positions']
-        new_tem['template_domain_names'][row:row_] = dom_fea['template_domain_names']
-        new_tem['template_sequence'][row:row_] = dom_fea['template_sequence']
-        new_tem['template_sum_probs'][row:row_] = dom_fea['template_sum_probs']
-        new_tem['template_aatype'][row:row_,col:col_,:] = dom_fea['template_aatype']
-        new_tem['template_all_atom_masks'][row:row_,col:col_,:] = dom_fea['template_all_atom_masks']
-        col = col_; row = row_
-    new_feature_dict.update(new_tem)
+  for mono_entry in monomers:
+    features = extract_template_domain_mono(mono_entry)
+
+    copy_num = mono_entry['copy_number']
+    #num_res = features['template_aatype'].shape[1]
+    num_res = features['msa'].shape[1]
+    num_tem = len(features['template_domain_names'])
+
+    for i in range(copy_num):
+      col_ = col + num_res
+      row_ = row + num_tem
+
+      new_tem['template_all_atom_positions'][row:row_,col:col_,...] = features['template_all_atom_positions']
+      new_tem['template_domain_names'][row:row_] = features['template_domain_names']
+      new_tem['template_sequence'][row:row_] = features['template_sequence']
+      new_tem['template_sum_probs'][row:row_] = features['template_sum_probs']
+      new_tem['template_aatype'][row:row_,col:col_,:] = features['template_aatype']
+      new_tem['template_all_atom_masks'][row:row_,col:col_,:] = features['template_all_atom_masks']
+      col = col_; row = row_
+  new_feature_dict.update(new_tem)
   curr_input.update({'new_feature_dict': new_feature_dict})
 
   return curr_input
@@ -702,39 +701,33 @@ def template_cropping_and_joining_mult(curr_input):
   new_feature_dict = curr_input['new_feature_dict']
 
   col = 0; row = 0
-  if not flags.no_template:
-    new_tem = initialize_template_feats(full_num_tem, full_num_res, is_multimer=True)
-    for mono_entry in monomers:
-      features = extract_template_domain_mult(mono_entry)
-      copy_num = mono_entry['copy_number']
-
-      #num_res = features['template_aatype'].shape[1]
-      num_res = features['msa'].shape[1]
-      num_tem = len(features['template_domain_names'])
-      if num_tem == 0:
-        dom_fea = initialize_template_feats(1, num_res, is_multimer=True)
-        dom_fea['asym_id'] = features['asym_id']
-        dom_fea['sym_id'] = features['sym_id']
-        dom_fea['entity_id'] = features['entity_id']
-      else:
-        dom_fea = features
-
-      for i in range(copy_num):
-        col_ = col + num_res
-        row_ = row + num_tem
-
-        new_tem['template_all_atom_positions'][row:row_,col:col_,...] = dom_fea['template_all_atom_positions']
-        new_tem['template_domain_names'][row:row_] = dom_fea['template_domain_names']
-        new_tem['template_sequence'][row:row_] = dom_fea['template_sequence']
-        new_tem['template_sum_probs'][row:row_] = dom_fea['template_sum_probs']
-        new_tem['template_all_atom_mask'][row:row_,col:col_,:] = dom_fea['template_all_atom_mask']
-        new_tem['template_aatype'][row:row_,col:col_] = dom_fea['template_aatype']
-        new_tem['asym_id'][col:col_] = dom_fea['asym_id']
-        new_tem['sym_id'][col:col_] = dom_fea['sym_id']
-        new_tem['entity_id'][col:col_] = dom_fea['entity_id']
-        col = col_; row = row_
-    new_feature_dict.update(new_tem)
+  new_tem = initialize_template_feats(full_num_tem, full_num_res, is_multimer=True)
+  for mono_entry in monomers:
+    features = extract_template_domain_mult(mono_entry)
+    copy_num = mono_entry['copy_number']
+
+    #num_res = features['template_aatype'].shape[1]
+    num_res = features['msa'].shape[1]
+    num_tem = len(features['template_domain_names'])
+
+    for i in range(copy_num):
+      col_ = col + num_res
+      row_ = row + num_tem
+
+      new_tem['template_all_atom_positions'][row:row_,col:col_,...] = features['template_all_atom_positions']
+      new_tem['template_domain_names'][row:row_] = features['template_domain_names']
+      new_tem['template_sequence'][row:row_] = features['template_sequence']
+      new_tem['template_sum_probs'][row:row_] = features['template_sum_probs']
+      new_tem['template_all_atom_mask'][row:row_,col:col_,:] = features['template_all_atom_mask']
+      new_tem['template_aatype'][row:row_,col:col_] = features['template_aatype']
+      new_tem['asym_id'][col:col_] = features['asym_id']
+      new_tem['sym_id'][col:col_] = features['sym_id']
+      new_tem['entity_id'][col:col_] = features['entity_id']
+
+      col = col_; row = row_
+
+  new_feature_dict.update(new_tem)
   curr_input.update({'new_feature_dict': new_feature_dict})
 
   return curr_input
diff --git a/src/alphafold/data/msa_pairing.py.no_pd b/src/alphafold/data/msa_pairing.py.no_pd
deleted file mode 100644
index 4583cd5..0000000
--- a/src/alphafold/data/msa_pairing.py.no_pd
+++ /dev/null
@@ -1,641 +0,0 @@
-# Copyright 2021 DeepMind Technologies Limited
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#      http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
- -"""Pairing logic for multimer data pipeline.""" - -import collections -import functools -import re -import string -from typing import Any, Dict, Iterable, List, Sequence - -from alphafold.common import residue_constants -from alphafold.data import pipeline -import numpy as np -#import pandas as pd -import scipy.linalg - -ALPHA_ACCESSION_ID_MAP = {x: y for y, x in enumerate(string.ascii_uppercase)} -ALPHANUM_ACCESSION_ID_MAP = { - chr: num for num, chr in enumerate(string.ascii_uppercase + string.digits) -} # A-Z,0-9 -NUM_ACCESSION_ID_MAP = {str(x): x for x in range(10)} # 0-9 - -MSA_GAP_IDX = residue_constants.restypes_with_x_and_gap.index('-') -SEQUENCE_GAP_CUTOFF = 0.5 -SEQUENCE_SIMILARITY_CUTOFF = 0.9 - -MSA_PAD_VALUES = {'msa_all_seq': MSA_GAP_IDX, - 'msa_mask_all_seq': 1, - 'deletion_matrix_all_seq': 0, - 'deletion_matrix_int_all_seq': 0, - 'msa': MSA_GAP_IDX, - 'msa_mask': 1, - 'deletion_matrix': 0, - 'deletion_matrix_int': 0} - -MSA_FEATURES = ('msa', 'msa_mask', 'deletion_matrix', 'deletion_matrix_int') -SEQ_FEATURES = ('residue_index', 'aatype', 'all_atom_positions', - 'all_atom_mask', 'seq_mask', 'between_segment_residues', - 'has_alt_locations', 'has_hetatoms', 'asym_id', 'entity_id', - 'sym_id', 'entity_mask', 'deletion_mean', - 'prediction_atom_mask', - 'literature_positions', 'atom_indices_to_group_indices', - 'rigid_group_default_frame') -TEMPLATE_FEATURES = ('template_aatype', 'template_all_atom_positions', - 'template_all_atom_mask',) -# Added for AF2Complex -ADDITIONAL_TEMPLATE_FEATURES = ('template_domain_names', 'template_sequence', - 'template_sum_probs') -CHAIN_FEATURES = ('num_alignments', 'seq_length') - - -domain_name_pattern = re.compile( - r'''^(?P<pdb>[a-z\d]{4}) - \{(?P<bioassembly>[\d+(\+\d+)?])\} - (?P<chain>[a-zA-Z\d]+) - \{(?P<transform_index>\d+)\}$ - ''', re.VERBOSE) - - -def create_paired_features( - chains: Iterable[pipeline.FeatureDict], - prokaryotic: bool, - ) -> List[pipeline.FeatureDict]: - """Returns the original chains with paired NUM_SEQ features. - - Args: - chains: A list of feature dictionaries for each chain. - prokaryotic: Whether the target complex is from a prokaryotic organism. - Used to determine the distance metric for pairing. - - Returns: - A list of feature dictionaries with sequence features including only - rows to be paired. - """ - chains = list(chains) - chain_keys = chains[0].keys() - - if len(chains) < 2: - return chains - else: - updated_chains = [] - paired_chains_to_paired_row_indices = pair_sequences( - chains, prokaryotic) - paired_rows = reorder_paired_rows( - paired_chains_to_paired_row_indices) - - for chain_num, chain in enumerate(chains): - new_chain = {k: v for k, v in chain.items() if '_all_seq' not in k} - for feature_name in chain_keys: - if feature_name.endswith('_all_seq'): - feats_padded = pad_features(chain[feature_name], feature_name) - new_chain[feature_name] = feats_padded[paired_rows[:, chain_num]] - new_chain['num_alignments_all_seq'] = np.asarray( - len(paired_rows[:, chain_num])) - updated_chains.append(new_chain) - return updated_chains - - -def pad_features(feature: np.ndarray, feature_name: str) -> np.ndarray: - """Add a 'padding' row at the end of the features list. - - The padding row will be selected as a 'paired' row in the case of partial - alignment - for the chain that doesn't have paired alignment. - - Args: - feature: The feature to be padded. - feature_name: The name of the feature to be padded. - - Returns: - The feature with an additional padding row. 
- """ - assert feature.dtype != np.dtype(np.string_) - if feature_name in ('msa_all_seq', 'msa_mask_all_seq', - 'deletion_matrix_all_seq', 'deletion_matrix_int_all_seq'): - num_res = feature.shape[1] - padding = MSA_PAD_VALUES[feature_name] * np.ones([1, num_res], - feature.dtype) - elif feature_name in ('msa_uniprot_accession_identifiers_all_seq', - 'msa_species_identifiers_all_seq'): - padding = [b''] - else: - return feature - feats_padded = np.concatenate([feature, padding], axis=0) - return feats_padded - -''' -def _make_msa_df(chain_features: pipeline.FeatureDict) -> pd.DataFrame: - """Makes dataframe with msa features needed for msa pairing.""" - chain_msa = chain_features['msa_all_seq'] - query_seq = chain_msa[0] - per_seq_similarity = np.sum( - query_seq[None] == chain_msa, axis=-1) / float(len(query_seq)) - per_seq_gap = np.sum(chain_msa == 21, axis=-1) / float(len(query_seq)) - msa_df = pd.DataFrame({ - 'msa_species_identifiers': - chain_features['msa_species_identifiers_all_seq'], - 'msa_uniprot_accession_identifiers': - chain_features['msa_uniprot_accession_identifiers_all_seq'], - 'msa_row': - np.arange(len( - chain_features['msa_uniprot_accession_identifiers_all_seq'])), - 'msa_similarity': per_seq_similarity, - 'gap': per_seq_gap - }) - return msa_df - - -def _create_species_dict(msa_df: pd.DataFrame) -> Dict[bytes, pd.DataFrame]: - """Creates mapping from species to msa dataframe of that species.""" - species_lookup = {} - for species, species_df in msa_df.groupby('msa_species_identifiers'): - species_lookup[species] = species_df - return species_lookup -''' - -@functools.lru_cache(maxsize=65536) -def encode_accession(accession_id: str) -> int: - """Map accession codes to the serial order in which they were assigned.""" - alpha = ALPHA_ACCESSION_ID_MAP # A-Z - alphanum = ALPHANUM_ACCESSION_ID_MAP # A-Z,0-9 - num = NUM_ACCESSION_ID_MAP # 0-9 - - coding = 0 - - # This is based on the uniprot accession id format - # https://www.uniprot.org/help/accession_numbers - if accession_id[0] in {'O', 'P', 'Q'}: - bases = (alpha, num, alphanum, alphanum, alphanum, num) - elif len(accession_id) == 6: - bases = (alpha, num, alpha, alphanum, alphanum, num) - elif len(accession_id) == 10: - bases = (alpha, num, alpha, alphanum, alphanum, num, alpha, alphanum, - alphanum, num) - - product = 1 - for place, base in zip(reversed(accession_id), reversed(bases)): - coding += base[place] * product - product *= len(base) - - return coding - - -def _calc_id_diff(id_a: bytes, id_b: bytes) -> int: - return abs(encode_accession(id_a.decode()) - encode_accession(id_b.decode())) - - -def _find_all_accession_matches(accession_id_lists: List[List[bytes]], - diff_cutoff: int = 20 - ) -> List[List[Any]]: - """Finds accession id matches across the chains based on their difference.""" - all_accession_tuples = [] - current_tuple = [] - tokens_used_in_answer = set() - - def _matches_all_in_current_tuple(inp: bytes, diff_cutoff: int) -> bool: - return all((_calc_id_diff(s, inp) < diff_cutoff for s in current_tuple)) - - def _all_tokens_not_used_before() -> bool: - return all((s not in tokens_used_in_answer for s in current_tuple)) - - def dfs(level, accession_id, diff_cutoff=diff_cutoff) -> None: - if level == len(accession_id_lists) - 1: - if _all_tokens_not_used_before(): - all_accession_tuples.append(list(current_tuple)) - for s in current_tuple: - tokens_used_in_answer.add(s) - return - - if level == -1: - new_list = accession_id_lists[level+1] - else: - new_list = [(_calc_id_diff(accession_id, s), s) for - 
s in accession_id_lists[level+1]] - new_list = sorted(new_list) - new_list = [s for d, s in new_list] - - for s in new_list: - if (_matches_all_in_current_tuple(s, diff_cutoff) and - s not in tokens_used_in_answer): - current_tuple.append(s) - dfs(level + 1, s) - current_tuple.pop() - dfs(-1, '') - return all_accession_tuples - -''' -def _accession_row(msa_df: pd.DataFrame, accession_id: bytes) -> pd.Series: - matched_df = msa_df[msa_df.msa_uniprot_accession_identifiers == accession_id] - return matched_df.iloc[0] - - -def _match_rows_by_genetic_distance( - this_species_msa_dfs: List[pd.DataFrame], - cutoff: int = 20) -> List[List[int]]: - """Finds MSA sequence pairings across chains within a genetic distance cutoff. - - The genetic distance between two sequences is approximated by taking the - difference in their UniProt accession ids. - - Args: - this_species_msa_dfs: a list of dataframes containing MSA features for - sequences for a specific species. If species is missing for a chain, the - dataframe is set to None. - cutoff: the genetic distance cutoff. - - Returns: - A list of lists, each containing M indices corresponding to paired MSA rows, - where M is the number of chains. - """ - num_examples = len(this_species_msa_dfs) # N - - accession_id_lists = [] # M - match_index_to_chain_index = {} - for chain_index, species_df in enumerate(this_species_msa_dfs): - if species_df is not None: - accession_id_lists.append( - list(species_df.msa_uniprot_accession_identifiers.values)) - # Keep track of which of the this_species_msa_dfs are not None. - match_index_to_chain_index[len(accession_id_lists) - 1] = chain_index - - all_accession_id_matches = _find_all_accession_matches( - accession_id_lists, cutoff) # [k, M] - - all_paired_msa_rows = [] # [k, N] - for accession_id_match in all_accession_id_matches: - paired_msa_rows = [] - for match_index, accession_id in enumerate(accession_id_match): - # Map back to chain index. - chain_index = match_index_to_chain_index[match_index] - seq_series = _accession_row( - this_species_msa_dfs[chain_index], accession_id) - - if (seq_series.msa_similarity > SEQUENCE_SIMILARITY_CUTOFF or - seq_series.gap > SEQUENCE_GAP_CUTOFF): - continue - else: - paired_msa_rows.append(seq_series.msa_row) - # If a sequence is skipped based on sequence similarity to the respective - # target sequence or a gap cuttoff, the lengths of accession_id_match and - # paired_msa_rows will be different. Skip this match. - if len(paired_msa_rows) == len(accession_id_match): - paired_and_non_paired_msa_rows = np.array([-1] * num_examples) - matched_chain_indices = list(match_index_to_chain_index.values()) - paired_and_non_paired_msa_rows[matched_chain_indices] = paired_msa_rows - all_paired_msa_rows.append(list(paired_and_non_paired_msa_rows)) - return all_paired_msa_rows - - -def _match_rows_by_sequence_similarity(this_species_msa_dfs: List[pd.DataFrame] - ) -> List[List[int]]: - """Finds MSA sequence pairings across chains based on sequence similarity. - - Each chain's MSA sequences are first sorted by their sequence similarity to - their respective target sequence. The sequences are then paired, starting - from the sequences most similar to their target sequence. - - Args: - this_species_msa_dfs: a list of dataframes containing MSA features for - sequences for a specific species. - - Returns: - A list of lists, each containing M indices corresponding to paired MSA rows, - where M is the number of chains. 
- """ - all_paired_msa_rows = [] - - num_seqs = [len(species_df) for species_df in this_species_msa_dfs - if species_df is not None] - take_num_seqs = np.min(num_seqs) - - sort_by_similarity = ( - lambda x: x.sort_values('msa_similarity', axis=0, ascending=False)) - - for species_df in this_species_msa_dfs: - if species_df is not None: - species_df_sorted = sort_by_similarity(species_df) - msa_rows = species_df_sorted.msa_row.iloc[:take_num_seqs].values - else: - msa_rows = [-1] * take_num_seqs # take the last 'padding' row - all_paired_msa_rows.append(msa_rows) - all_paired_msa_rows = list(np.array(all_paired_msa_rows).transpose()) - return all_paired_msa_rows - - -def pair_sequences(examples: List[pipeline.FeatureDict], - prokaryotic: bool) -> Dict[int, np.ndarray]: - """Returns indices for paired MSA sequences across chains.""" - - num_examples = len(examples) - - all_chain_species_dict = [] - common_species = set() - for chain_features in examples: - msa_df = _make_msa_df(chain_features) - species_dict = _create_species_dict(msa_df) - all_chain_species_dict.append(species_dict) - common_species.update(set(species_dict)) - - common_species = sorted(common_species) - common_species.remove(b'') # Remove target sequence species. - - all_paired_msa_rows = [np.zeros(len(examples), int)] - all_paired_msa_rows_dict = {k: [] for k in range(num_examples)} - all_paired_msa_rows_dict[num_examples] = [np.zeros(len(examples), int)] - - for species in common_species: - if not species: - continue - this_species_msa_dfs = [] - species_dfs_present = 0 - for species_dict in all_chain_species_dict: - if species in species_dict: - this_species_msa_dfs.append(species_dict[species]) - species_dfs_present += 1 - else: - this_species_msa_dfs.append(None) - - # Skip species that are present in only one chain. - if species_dfs_present <= 1: - continue - - if np.any( - np.array([len(species_df) for species_df in - this_species_msa_dfs if - isinstance(species_df, pd.DataFrame)]) > 600): - continue - - # In prokaryotes (and some eukaryotes), interacting genes are often - # co-located on the chromosome into operons. Because of that we can assume - # that if two proteins' intergenic distance is less than a threshold, they - # two proteins will form an an interacting pair. - # In most eukaryotes, a single protein's MSA can contain many paralogs. - # Two genes may interact even if they are not close by genomic distance. - # In case of eukaryotes, some methods pair MSA sequences using sequence - # similarity method. - # See Jinbo Xu's work: - # https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6030867/#B28. - if prokaryotic: - paired_msa_rows = _match_rows_by_genetic_distance(this_species_msa_dfs) - - if not paired_msa_rows: - continue - else: - paired_msa_rows = _match_rows_by_sequence_similarity(this_species_msa_dfs) - all_paired_msa_rows.extend(paired_msa_rows) - all_paired_msa_rows_dict[species_dfs_present].extend(paired_msa_rows) - all_paired_msa_rows_dict = { - num_examples: np.array(paired_msa_rows) for - num_examples, paired_msa_rows in all_paired_msa_rows_dict.items() - } - return all_paired_msa_rows_dict -''' - -def reorder_paired_rows(all_paired_msa_rows_dict: Dict[int, np.ndarray] - ) -> np.ndarray: - """Creates a list of indices of paired MSA rows across chains. - - Args: - all_paired_msa_rows_dict: a mapping from the number of paired chains to the - paired indices. - - Returns: - a list of lists, each containing indices of paired MSA rows across chains. 
- The paired-index lists are ordered by: - 1) the number of chains in the paired alignment, i.e, all-chain pairings - will come first. - 2) e-values - """ - all_paired_msa_rows = [] - - for num_pairings in sorted(all_paired_msa_rows_dict, reverse=True): - paired_rows = all_paired_msa_rows_dict[num_pairings] - paired_rows_product = abs(np.array([np.prod(rows) for rows in paired_rows])) - paired_rows_sort_index = np.argsort(paired_rows_product) - all_paired_msa_rows.extend(paired_rows[paired_rows_sort_index]) - - return np.array(all_paired_msa_rows) - - -def block_diag(*arrs: np.ndarray, pad_value: float = 0.0) -> np.ndarray: - """Like scipy.linalg.block_diag but with an optional padding value.""" - ones_arrs = [np.ones_like(x) for x in arrs] - off_diag_mask = 1.0 - scipy.linalg.block_diag(*ones_arrs) - diag = scipy.linalg.block_diag(*arrs) - diag += (off_diag_mask * pad_value).astype(diag.dtype) - return diag - - -def _correct_post_merged_feats( - np_example: pipeline.FeatureDict, - np_chains_list: Sequence[pipeline.FeatureDict], - pair_msa_sequences: bool) -> pipeline.FeatureDict: - """Adds features that need to be computed/recomputed post merging.""" - - np_example['seq_length'] = np.asarray(np_example['aatype'].shape[0], - dtype=np.int32) - np_example['num_alignments'] = np.asarray(np_example['msa'].shape[0], - dtype=np.int32) - - if not pair_msa_sequences: - # Generate a bias that is 1 for the first row of every block in the - # block diagonal MSA - i.e. make sure the cluster stack always includes - # the query sequences for each chain (since the first row is the query - # sequence). - cluster_bias_masks = [] - for chain in np_chains_list: - mask = np.zeros(chain['msa'].shape[0]) - mask[0] = 1 - cluster_bias_masks.append(mask) - np_example['cluster_bias_mask'] = np.concatenate(cluster_bias_masks) - - # Initialize Bert mask with masked out off diagonals. - msa_masks = [np.ones(x['msa'].shape, dtype=np.float32) - for x in np_chains_list] - - np_example['bert_mask'] = block_diag( - *msa_masks, pad_value=0) - else: - np_example['cluster_bias_mask'] = np.zeros(np_example['msa'].shape[0]) - np_example['cluster_bias_mask'][0] = 1 - - # Initialize Bert mask with masked out off diagonals. - msa_masks = [np.ones(x['msa'].shape, dtype=np.float32) for - x in np_chains_list] - msa_masks_all_seq = [np.ones(x['msa_all_seq'].shape, dtype=np.float32) for - x in np_chains_list] - - msa_mask_block_diag = block_diag( - *msa_masks, pad_value=0) - msa_mask_all_seq = np.concatenate(msa_masks_all_seq, axis=1) - np_example['bert_mask'] = np.concatenate( - [msa_mask_all_seq, msa_mask_block_diag], axis=0) - return np_example - - -def _pad_templates(chains: Sequence[pipeline.FeatureDict], - max_templates: int) -> Sequence[pipeline.FeatureDict]: - """For each chain pad the number of templates to a fixed size. - - Args: - chains: A list of protein chains. - max_templates: Each chain will be padded to have this many templates. - - Returns: - The list of chains, updated to have template features padded to - max_templates. - """ - for chain in chains: - for k, v in chain.items(): - if k in TEMPLATE_FEATURES: - padding = np.zeros_like(v.shape) - padding[0] = max_templates - v.shape[0] - padding = [(0, p) for p in padding] - chain[k] = np.pad(v, padding, mode='constant') - return chains - - -def _merge_features_from_multiple_chains( - chains: Sequence[pipeline.FeatureDict], - pair_msa_sequences: bool) -> pipeline.FeatureDict: - """Merge features from multiple chains. 
- - Args: - chains: A list of feature dictionaries that we want to merge. - pair_msa_sequences: Whether to concatenate MSA features along the - num_res dimension (if True), or to block diagonalize them (if False). - - Returns: - A feature dictionary for the merged example. - """ - merged_example = {} - for feature_name in chains[0]: - feats = [x[feature_name] for x in chains] - feature_name_split = feature_name.split('_all_seq')[0] - if feature_name_split in MSA_FEATURES: - if pair_msa_sequences or '_all_seq' in feature_name: - merged_example[feature_name] = np.concatenate(feats, axis=1) - else: - merged_example[feature_name] = block_diag( - *feats, pad_value=MSA_PAD_VALUES[feature_name]) - elif feature_name_split in SEQ_FEATURES: - merged_example[feature_name] = np.concatenate(feats, axis=0) - elif feature_name_split in TEMPLATE_FEATURES: - merged_example[feature_name] = np.concatenate(feats, axis=1) - elif feature_name_split in CHAIN_FEATURES: - merged_example[feature_name] = np.sum(x for x in feats).astype(np.int32) - else: - merged_example[feature_name] = feats[0] - return merged_example - - -def _merge_homomers_dense_msa( - chains: Iterable[pipeline.FeatureDict]) -> Sequence[pipeline.FeatureDict]: - """Merge all identical chains, making the resulting MSA dense. - - Args: - chains: An iterable of features for each chain. - - Returns: - A list of feature dictionaries. All features with the same entity_id - will be merged - MSA features will be concatenated along the num_res - dimension - making them dense. - """ - entity_chains = collections.defaultdict(list) - for chain in chains: - entity_id = chain['entity_id'][0] - entity_chains[entity_id].append(chain) - - grouped_chains = [] - for entity_id in sorted(entity_chains): - chains = entity_chains[entity_id] - grouped_chains.append(chains) - chains = [ - _merge_features_from_multiple_chains(chains, pair_msa_sequences=True) - for chains in grouped_chains] - return chains - - -def _concatenate_paired_and_unpaired_features( - example: pipeline.FeatureDict) -> pipeline.FeatureDict: - """Merges paired and block-diagonalised features.""" - features = MSA_FEATURES - for feature_name in features: - if feature_name in example: - feat = example[feature_name] - feat_all_seq = example[feature_name + '_all_seq'] - merged_feat = np.concatenate([feat_all_seq, feat], axis=0) - example[feature_name] = merged_feat - example['num_alignments'] = np.array(example['msa'].shape[0], - dtype=np.int32) - return example - - -def merge_chain_features(np_chains_list: List[pipeline.FeatureDict], - pair_msa_sequences: bool, - max_templates: int) -> pipeline.FeatureDict: - """Merges features for multiple chains to single FeatureDict. - - Args: - np_chains_list: List of FeatureDicts for each chain. - pair_msa_sequences: Whether to merge paired MSAs. - max_templates: The maximum number of templates to include. - - Returns: - Single FeatureDict for entire complex. - """ - np_chains_list = _pad_templates( - np_chains_list, max_templates=max_templates) - np_chains_list = _merge_homomers_dense_msa(np_chains_list) - # Unpaired MSA features will be always block-diagonalised; paired MSA - # features will be concatenated. 
- np_example = _merge_features_from_multiple_chains( - np_chains_list, pair_msa_sequences=False) - if pair_msa_sequences: - np_example = _concatenate_paired_and_unpaired_features(np_example) - np_example = _correct_post_merged_feats( - np_example=np_example, - np_chains_list=np_chains_list, - pair_msa_sequences=pair_msa_sequences) - - return np_example - - -def deduplicate_unpaired_sequences( - np_chains: List[pipeline.FeatureDict]) -> List[pipeline.FeatureDict]: - """Removes unpaired sequences which duplicate a paired sequence.""" - - feature_names = np_chains[0].keys() - msa_features = MSA_FEATURES - - for chain in np_chains: - sequence_set = set(tuple(s) for s in chain['msa_all_seq']) - keep_rows = [] - # Go through unpaired MSA seqs and remove any rows that correspond to the - # sequences that are already present in the paired MSA. - for row_num, seq in enumerate(chain['msa']): - if tuple(seq) not in sequence_set: - keep_rows.append(row_num) - for feature_name in feature_names: - if feature_name in msa_features: - if keep_rows: - chain[feature_name] = chain[feature_name][keep_rows] - else: - new_shape = list(chain[feature_name].shape) - new_shape[0] = 0 - chain[feature_name] = np.zeros(new_shape, - dtype=chain[feature_name].dtype) - chain['num_alignments'] = np.array(chain['msa'].shape[0], dtype=np.int32) - return np_chains diff --git a/src/run_alphafold_stage1.py b/src/run_alphafold_stage1.py deleted file mode 100644 index 7ce2e00..0000000 --- a/src/run_alphafold_stage1.py +++ /dev/null @@ -1,220 +0,0 @@ -# Copyright 2021 DeepMind Technologies Limited -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Run data pipeline to generate input features for alphafold, save the features - -"""AlphaFold Stage 1: data pipeline for the generation of input features.""" -import json -import os -import pathlib -import pickle -import random -import sys -import time -from typing import Dict -from socket import gethostname - -from absl import app -from absl import flags -from absl import logging -import numpy as np - -from alphafold.common import protein -from alphafold.data import pipeline -from alphafold.data import templates -from alphafold.model import data -from alphafold.model import config -#from alphafold.model import model -#from alphafold.relax import relax -# Internal import (7716). - -flags.DEFINE_list('fasta_paths', None, 'Paths to FASTA files, each containing ' - 'one sequence. Paths should be separated by commas. 
' - 'All FASTA paths must have a unique basename as the ' - 'basename is used to name the output directories for ' - 'each prediction.') -flags.DEFINE_string('output_dir', None, 'Path to a directory that will ' - 'store the results.') -flags.DEFINE_list('model_names', None, 'Names of models to use.') -flags.DEFINE_string('data_dir', None, 'Path to directory of supporting data.') -flags.DEFINE_string('jackhmmer_binary_path', '/usr/bin/jackhmmer', - 'Path to the JackHMMER executable.') -flags.DEFINE_string('hhblits_binary_path', '/usr/bin/hhblits', - 'Path to the HHblits executable.') -flags.DEFINE_string('hhsearch_binary_path', '/usr/bin/hhsearch', - 'Path to the HHsearch executable.') -flags.DEFINE_string('kalign_binary_path', '/usr/bin/kalign', - 'Path to the Kalign executable.') -flags.DEFINE_string('uniref90_database_path', None, 'Path to the Uniref90 ' - 'database for use by JackHMMER.') -flags.DEFINE_string('mgnify_database_path', None, 'Path to the MGnify ' - 'database for use by JackHMMER.') -flags.DEFINE_string('bfd_database_path', None, 'Path to the BFD ' - 'database for use by HHblits.') -flags.DEFINE_string('small_bfd_database_path', None, 'Path to the small ' - 'version of BFD used with the "reduced_dbs" preset.') -flags.DEFINE_string('uniclust30_database_path', None, 'Path to the Uniclust30 ' - 'database for use by HHblits.') -flags.DEFINE_string('pdb70_database_path', None, 'Path to the PDB70 ' - 'database for use by HHsearch.') -flags.DEFINE_string('template_mmcif_dir', None, 'Path to a directory with ' - 'template mmCIF structures, each named <pdb_id>.cif') -flags.DEFINE_string('max_template_date', None, 'Maximum template release date ' - 'to consider. Important if folding historical test sets.') -flags.DEFINE_string('obsolete_pdbs_path', None, 'Path to file containing a ' - 'mapping from obsolete PDB IDs to the PDB IDs of their ' - 'replacements.') -flags.DEFINE_enum('preset', None, - ['reduced_dbs', 'full_dbs', 'casp14'], - 'Choose preset model configuration - no ensembling and ' - 'smaller genetic database config (reduced_dbs), no ' - 'ensembling and full genetic database config (full_dbs) or ' - 'full genetic database config and 8 model ensemblings ' - '(casp14).') -flags.DEFINE_boolean('benchmark', False, 'Run multiple JAX model evaluations ' - 'to obtain a timing that excludes the compilation time, ' - 'which should be more indicative of the time required for ' - 'inferencing many proteins.') -flags.DEFINE_integer('random_seed', None, 'The random seed for the data ' - 'pipeline. By default, this is randomly generated. 
Note ' - 'that even if this is set, Alphafold may still not be ' - 'deterministic, because processes like GPU inference are ' - 'nondeterministic.') -FLAGS = flags.FLAGS - -MAX_TEMPLATE_HITS = 20 -RELAX_MAX_ITERATIONS = 0 -RELAX_ENERGY_TOLERANCE = 2.39 -RELAX_STIFFNESS = 10.0 -RELAX_EXCLUDE_RESIDUES = [] -RELAX_MAX_OUTER_ITERATIONS = 20 - - -def _check_flag(flag_name: str, preset: str, should_be_set: bool): - if should_be_set != bool(FLAGS[flag_name].value): - verb = 'be' if should_be_set else 'not be' - raise ValueError(f'{flag_name} must {verb} set for preset "{preset}"') - - -def predict_structure( - fasta_path: str, - fasta_name: str, - output_dir_base: str, - data_pipeline: pipeline.DataPipeline, - benchmark: bool, - random_seed: int): - """Predicts structure using AlphaFold for the given sequence.""" - timings = {} - output_dir = os.path.join(output_dir_base, fasta_name) - if not os.path.exists(output_dir): - os.makedirs(output_dir) - msa_output_dir = os.path.join(output_dir, 'msas') - if not os.path.exists(msa_output_dir): - os.makedirs(msa_output_dir) - - # Get features. - t_0 = time.time() - feature_dict = data_pipeline.process( - input_fasta_path=fasta_path, - msa_output_dir=msa_output_dir) - timings['features'] = time.time() - t_0 - - # Write out features as a pickled dictionary. - features_output_path = os.path.join(output_dir, 'features.pkl') - with open(features_output_path, 'wb') as f: - pickle.dump(feature_dict, f, protocol=4) - - logging.info('Final timings for %s: %s', fasta_name, timings) - - timings_output_path = os.path.join(output_dir, 'timings_fea.json') - with open(timings_output_path, 'w') as f: - f.write(json.dumps(timings, indent=4)) - - -def main(argv): - if len(argv) > 1: - raise app.UsageError('Too many command-line arguments.') - - use_small_bfd = FLAGS.preset == 'reduced_dbs' - _check_flag('small_bfd_database_path', FLAGS.preset, - should_be_set=use_small_bfd) - _check_flag('bfd_database_path', FLAGS.preset, - should_be_set=not use_small_bfd) - _check_flag('uniclust30_database_path', FLAGS.preset, - should_be_set=not use_small_bfd) - - if FLAGS.preset in ('reduced_dbs', 'full_dbs'): - num_ensemble = 1 - elif FLAGS.preset == 'casp14': - num_ensemble = 8 - - # Check for duplicate FASTA file names. 
- fasta_names = [pathlib.Path(p).stem for p in FLAGS.fasta_paths] - if len(fasta_names) != len(set(fasta_names)): - raise ValueError('All FASTA paths must have a unique basename.') - - template_featurizer = templates.TemplateHitFeaturizer( - mmcif_dir=FLAGS.template_mmcif_dir, - max_template_date=FLAGS.max_template_date, - max_hits=MAX_TEMPLATE_HITS, - kalign_binary_path=FLAGS.kalign_binary_path, - release_dates_path=None, - obsolete_pdbs_path=FLAGS.obsolete_pdbs_path) - - data_pipeline = pipeline.DataPipeline( - jackhmmer_binary_path=FLAGS.jackhmmer_binary_path, - hhblits_binary_path=FLAGS.hhblits_binary_path, - hhsearch_binary_path=FLAGS.hhsearch_binary_path, - uniref90_database_path=FLAGS.uniref90_database_path, - mgnify_database_path=FLAGS.mgnify_database_path, - bfd_database_path=FLAGS.bfd_database_path, - uniclust30_database_path=FLAGS.uniclust30_database_path, - small_bfd_database_path=FLAGS.small_bfd_database_path, - pdb70_database_path=FLAGS.pdb70_database_path, - template_featurizer=template_featurizer, - use_small_bfd=use_small_bfd) - - random_seed = FLAGS.random_seed - if random_seed is None: - random_seed = random.randrange(sys.maxsize) - logging.info('Using random seed %d for the data pipeline', random_seed) - - # Predict structure for each of the sequences. - for fasta_path, fasta_name in zip(FLAGS.fasta_paths, fasta_names): - host_name = gethostname() - print(f"Info: working on target {fasta_name} at {host_name}") - predict_structure( - fasta_path=fasta_path, - fasta_name=fasta_name, - output_dir_base=FLAGS.output_dir, - data_pipeline=data_pipeline, - benchmark=FLAGS.benchmark, - random_seed=random_seed) - - -if __name__ == '__main__': - flags.mark_flags_as_required([ - 'fasta_paths', - 'output_dir', - 'data_dir', - 'preset', - 'uniref90_database_path', - 'mgnify_database_path', - 'pdb70_database_path', - 'template_mmcif_dir', - 'max_template_date', - 'obsolete_pdbs_path', - ]) - - app.run(main) diff --git a/src/run_alphafold_stage2a_comp.py b/src/run_alphafold_stage2a_comp.py deleted file mode 100644 index 35b562d..0000000 --- a/src/run_alphafold_stage2a_comp.py +++ /dev/null @@ -1,502 +0,0 @@ -# Copyright 2021 DeepMind Technologies Limited -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -# Run AlphaFold DL modelds using pre-generated features w/o model relaxation -# -# Input: pre-generated features from AlphaFold DataPipeline on single sequences -# Output: un-relaxed protein models -# -# Note: AF2Complex is a modified, enhanced version of AlphaFold 2. 
-# Additional unofficial features added: -# -# Predicting models of a protein complex including both homooligomer and heterooligomer -# No MSA pairing required -# New metrics designed for evaluating protein-protein interface -# Saving structure models of all recycles -# Split feature generation (stage 1), DL inference (stage 2a), and model relaxation (stage 2b) -# -# Some other features such as option for dynamically controled number of recycles and -# residue index breaks were taken from ColabFold -# -# Mu Gao and Davi Nakajima An -# Georgia Institute of Technology - -"""Enhanced AlphaFold Stage2a: protein complex structure prediction with deep learning""" -import json -import os -import pickle -import random -import sys -import time -import re -from typing import Dict - -from absl import app -from absl import flags -from absl import logging - -from alphafold.common import protein -from alphafold.model import data -from alphafold.model import config -from alphafold.model import model -from alphafold.data import pipeline - -from alphafold.data.complex import * -from datetime import date - -import numpy as np -# Internal import (7716). - - -flags.DEFINE_string('target_lst_path', None, 'Path to a file containing a list of targets ' - 'in any monomer, homo- or hetero-oligomers ' - 'configurations. For example, TarA is a monomer, TarA:2 is a dimer ' - 'of two TarAs. TarA:2/TarB is a trimer of two TarA and one TarB, etc.' - ) -flags.DEFINE_string('output_dir', None, 'Path to a directory that will ' - 'store the results.') -flags.DEFINE_string('feature_dir', None, 'Path to a directory that will ' - 'contains pre-genearted feature in pickle format.') -flags.DEFINE_list('model_names', None, 'Names of deep learning models to use.') -flags.DEFINE_string('data_dir', None, 'Path to directory of supporting data.') -flags.DEFINE_enum('preset', None, - ['reduced_dbs', 'casp14', 'economy', 'super', 'super2', 'genome', 'genome2'], - 'Choose preset model configuration: <reduced_dbs> no ensembling, ' - '<economy> no ensemble, up to 256 MSA clusters, recycling up to 3 rounds; ' - '<super, super2> 1 or 2 ensembles, up to 512 MSA clusters, recycling up to 20 rounds; ' - '<genome, genome2> 1 or 2 ensembles, up to 512 MSA clusters, max number ' - 'of recycles and ensembles adjusted according to input sequence length; ' - 'or <casp14> 8 model ensemblings of the factory settings.') -flags.DEFINE_integer('random_seed', None, 'The random seed for the data ' - 'pipeline. By default, this is randomly generated. Note ' - 'that even if this is set, Alphafold may still not be ' - 'deterministic, because processes like GPU inference are ' - 'nondeterministic.') -flags.DEFINE_integer('max_recycles', None, 'The maximum number of recycles.', lower_bound=1) -flags.DEFINE_float('recycle_tol', None, 'The tolerance for recycling, caculated as the RMSD change ' - 'in the distogram of backbone Ca atoms. 
Recycling stops ' - 'if the change is smaller than this value', lower_bound=0.0, upper_bound=2.0) -flags.DEFINE_integer('num_ensemble', None, 'The number of ensembles of each model, 1 means no ensembling.', lower_bound=1) -flags.DEFINE_integer('max_msa_clusters', None, 'The maximum number of MSA clusters.', lower_bound=1) -flags.DEFINE_integer('max_extra_msa', None, 'The maximum number of extra MSA clusters.', lower_bound=1) -flags.DEFINE_boolean('write_complex_features', False, 'Save the feature dict for ' - 'complex prediction as a pickle file under the output direcotry') -flags.DEFINE_enum('template_mode', 'oligomer', ['none', 'monomer', 'oligomer'], - 'none - No template is allowed, ' - 'monomer - Use template only for monomer but not oligomer modeling, ' - 'oligomer - Use monomer template for all modeling if exists.') -flags.DEFINE_integer('save_recycled', 0, '0 - no recycle info saving, 1 - print ' - 'metrics of intermediate recycles, 2 - additionally saving pdb structures ' - 'of all recycles, 3 - additionally save all results in pickle ' - 'dictionaries of each recycling iteration.', lower_bound=0, upper_bound=3) - -FLAGS = flags.FLAGS - -MAX_TEMPLATE_HITS = 4 -MAX_MSA_DEPTH_MONO = 10000 ### maximum number of input sequences in the msa of a monomer - -################################################################################################## -# read either a single target string or a input list of targets in a file -# each line has a format like: <target> <length> (output_name), the output_name is optional. -# In <target>, use monomer:num to indicate num of copies in a homooligomer -# and name1/name2 to indicate heterooligomer. For example, TarA:2/TarB is 2 copies of TarA and 1 TarB -def _read_target_file( data_lst_file ): - target_lst = [] - if not os.path.exists( data_lst_file ): ### input is a single target in strings - fields = data_lst_file.split(',') - fullname = name = fields[0] - if len(fields) == 2: - name = fields[1] - target_lst.append( {'full':fullname, 'name':name} ) - else: ### input are a list of targets in a file - with open( data_lst_file ) as file: - for line in file: - if line.startswith("#"): - continue - line = line.strip() # strip "\n" - fields = line.split() - fullname = name = fields[0] - if len(fields) > 2 and not fields[2].startswith("#"): - name = fields[2] - target_lst.append( {'full':fullname, 'name':name} ) - - # process the components of a complex if detected - for target in target_lst: - complex = target['full'] - monomers = [] - subfields = complex.split('/') - for item in subfields: - cols = item.split(':') - if len(cols) == 1: - monomers.append( {cols[0]:1} ) ### monomer - elif len(cols) > 1: - monomers.append( {cols[0]:cols[1]} ) - if len(monomers) >= 1: - target['split'] = monomers - - return target_lst -################################################################################################## - - -################################################################################################## -def predict_structure( - target: Dict[str, str], - output_dir_base: str, - feature_dir_base: str, - model_runners: Dict[str, model.RunModel], - random_seed: int, - max_msa_clusters: int, - max_extra_msa: int, - max_recycles: int, - num_ensemble: int, - preset: str, - write_complex_features: bool, - template_mode: str): - """Predicts structure using AlphaFold for the given sequence.""" - timings = {} - - target_name = target['name'] - target_name = re.sub(":", "_x", target_name) - target_name = re.sub("/", "+", target_name) - - homo_copy = 
[] - seq_names = [] - for homo in target['split']: - for seq_name, seq_copy in homo.items(): - homo_copy.append(int(seq_copy)) - seq_names.append(seq_name) - - time.sleep(random.randint(0,30)) # mitigating creating the same output directory from multiple runs - - # Retrieve pre-generated features of monomers (single protien sequences) - t_0 = time.time() - feature_dicts = [] - for seq_name in seq_names: - feature_dir = os.path.join(feature_dir_base, seq_name) - if not os.path.exists(feature_dir): - raise SystemExit("Error: ", feature_dir, "does not exists") - - # load pre-generated features as a pickled dictionary. - features_input_path = os.path.join(feature_dir, 'features.pkl') - with open(features_input_path, "rb") as f: - mono_feature_dict = pickle.load(f) - N = len(mono_feature_dict["msa"]) - L = len(mono_feature_dict["residue_index"]) - T = len(mono_feature_dict["template_domain_names"]) - print(f"Info: {target_name} found monomer {seq_name} msa_depth = {N}, seq_len = {L}, num_templ = {T}") - if N > MAX_MSA_DEPTH_MONO: - print(f"Info: {seq_name} MSA size is too large, reducing to {MAX_MSA_DEPTH_MONO}") - mono_feature_dict["msa"] = mono_feature_dict["msa"][:MAX_MSA_DEPTH_MONO,:] - mono_feature_dict["deletion_matrix_int"] = mono_feature_dict["deletion_matrix_int"][:MAX_MSA_DEPTH_MONO,:] - mono_feature_dict['num_alignments'][:] = MAX_MSA_DEPTH_MONO - if T > MAX_TEMPLATE_HITS: - print(f"Info: {seq_name} reducing the number of structural templates to {MAX_TEMPLATE_HITS}") - mono_feature_dict["template_aatype"] = mono_feature_dict["template_aatype"][:MAX_TEMPLATE_HITS,...] - mono_feature_dict["template_all_atom_masks"] = mono_feature_dict["template_all_atom_masks"][:MAX_TEMPLATE_HITS,...] - mono_feature_dict["template_all_atom_positions"] = mono_feature_dict["template_all_atom_positions"][:MAX_TEMPLATE_HITS,...] 
- mono_feature_dict["template_domain_names"] = mono_feature_dict["template_domain_names"][:MAX_TEMPLATE_HITS] - mono_feature_dict["template_sequence"] = mono_feature_dict["template_sequence"][:MAX_TEMPLATE_HITS] - mono_feature_dict["template_sum_probs"] = mono_feature_dict["template_sum_probs"][:MAX_TEMPLATE_HITS,:] - feature_dicts.append( mono_feature_dict ) - - # Make features for complex structure prediction using monomer structures if necessary - if len(seq_names) == 1 and homo_copy[0] == 1: # monomer structure prediction - feature_dict = feature_dicts[0] - seq_len = len(feature_dict["residue_index"]) - Ls = [seq_len] - if template_mode == 'none': - new_tem = initialize_template_feats(0, seq_len) - feature_dict.update(new_tem) - else: # complex structure prediction - feature_dict, Ls = make_complex_features(feature_dicts, target_name, homo_copy, template_mode) - - mono_chains = [] - mono_chains = get_mono_chain(seq_names, homo_copy, Ls) - print(f"Info: individual chain(s) to model {mono_chains}") - - N = len(feature_dict["msa"]) - L = len(feature_dict["residue_index"]) - T = len(feature_dict["template_domain_names"]) - print(f"Info: modeling {target_name} with msa_depth = {N}, seq_len = {L}, num_templ = {T}") - timings['features'] = round(time.time() - t_0, 2) - - - output_dir = os.path.join(output_dir_base, target_name) - if not os.path.exists(output_dir): - try: - os.makedirs(output_dir) - except FileExistsError: - print(f"Warning: tried to create an existing {output_dir}, ignored") - - if write_complex_features: - feature_output_path = os.path.join(output_dir, 'features_comp.pkl') - with open(feature_output_path, 'wb') as f: - pickle.dump(feature_dict, f, protocol=4) - - today = date.today().strftime('%Y%m%d') - out_suffix = '_' + today + '_' + str(random_seed)[-6:] - - plddts = {} # predicted LDDT score - iterations = {} # recycle information - ptms = {}; pitms = {} # predicted TM-score - ires = {}; icnt = {} # interfacial residues and contacts - tols = {} # change in backbone pairwise distance to check with the recyle tolerance criterion - ints = {} # interface-score - # Run models for structure prediction - for model_name, model_runner in model_runners.items(): - model_out_name = model_name + out_suffix - logging.info('Running model %s', model_out_name) - t_0 = time.time() - - # set size of msa (to reduce memory requirements) - if max_msa_clusters is not None and max_extra_msa is not None: - msa_clusters = max(min(N, max_msa_clusters),5) - model_runner.config.data.eval.max_msa_clusters = msa_clusters - model_runner.config.data.common.max_extra_msa = max(min(N-msa_clusters,max_extra_msa),1) - if preset in ['genome', 'genome2', 'super']: - max_iter = max_recycles - max(0, (L - 500) // 50) - max_iter = max( 6, max_iter ) - if L > 1180: ### memory limit of a single 16GB GPU, applied to cases with mutliple ensembles - num_en = 1 - else: - num_en = num_ensemble - print(f"Info: {target_name} reset max_recycles = {max_iter}, num_ensemble = {num_en}") - model_runner.config.data.common.num_recycle = max_iter - model_runner.config.model.num_recycle = max_iter - model_runner.config.data.eval.num_ensemble = num_en - - processed_feature_dict = model_runner.process_features( - feature_dict, random_seed=random_seed) - timings[f'process_features_{model_out_name}'] = round(time.time() - t_0, 2) - - t_0 = time.time() - prediction_result, (tot_recycle, tol_value, recycled) = model_runner.predict(processed_feature_dict) - prediction_result['num_recycle'] = tot_recycle - 
prediction_result['mono_chains'] = mono_chains - - tols[model_out_name] = round(tol_value.tolist(), 3) - iterations[model_out_name] = tot_recycle.tolist() ### convert from jax numpy to regular list for json saving - - t_diff = time.time() - t_0 - timings[f'predict_and_compile_{model_out_name}'] = round(t_diff,1) - logging.info( - 'Total JAX model %s predict time (includes compilation time): %.1f seconds', model_out_name, t_diff) - - def _save_results(result, log_model_name, out_dir, recycle_index): - # Get mean pLDDT confidence metric. - plddt = np.mean(result['plddt']) - plddts[log_model_name] = round(plddt, 2) - ptm = 0 - if 'ptm' in result: - ptm = result['ptm'].tolist() - ptms[log_model_name] = round(ptm, 4) - pitm = 0; inter_residues = 0; inter_contacts = 0; inter_sc = 0 - if 'pitm' in result: - pitm = result['pitm']['score'].tolist() - inter_residues = result['pitm']['num_residues'].tolist() - inter_contacts = result['pitm']['num_contacts'].tolist() - pitms[log_model_name] = round(pitm, 4) - ires[log_model_name] = inter_residues - icnt[log_model_name] = int(inter_contacts) - if 'interface' in result: - inter_sc = result['interface']['score'].tolist() - ints[log_model_name] = round(inter_sc, 4) - - if recycle_index < tot_recycle: - tol = result['tol_val'].tolist() - tols[log_model_name] = round(tol, 2) - print(f"Info: {target_name} {log_model_name}, ", - f"tol = {tol:5.2f}, pLDDT = {plddt:.2f}, pTM-score = {ptm:.4f}", end='') - if len(seq_names) > 1 or sum(homo_copy) > 1: # complex target - print(f", piTM-score = {pitm:.4f}, interface-score = {inter_sc:.4f}", end='') - print(f", iRes = {inter_residues:<4d} iCnt = {inter_contacts:<4.0f}") - else: - print('') - else: - print(f"Info: {target_name} {log_model_name} performed {tot_recycle} recycles,", - f"final tol = {tol_value:.2f}, pLDDT = {plddt:.2f}, pTM-score = {ptm:.4f}", end='') - if len(seq_names) > 1 or sum(homo_copy) > 1: - print(f", piTM-score = {pitm:.4f}, interface-score = {inter_sc:.4f}", end='') - print(f", iRes = {inter_residues:<4d} iCnt = {inter_contacts:<4.0f}") - else: - print('') - - # Save the model outputs, not saving pkl for intermeidate recycles to save storage space - if recycle_index == tot_recycle or FLAGS.save_recycled == 3: - result_output_path = os.path.join(out_dir, f'{log_model_name}.pkl') - with open(result_output_path, 'wb') as f: - pickle.dump(result, f, protocol=4) - - if recycle_index == tot_recycle or FLAGS.save_recycled >= 2: - # Set the b-factors to the per-residue plddt - final_atom_mask = result['structure_module']['final_atom_mask'] - b_factors = result['plddt'][:, None] * final_atom_mask - - unrelaxed_protein = protein.from_prediction(processed_feature_dict, - result, b_factors=b_factors) - - unrelaxed_pdb_path = os.path.join(out_dir, f'{log_model_name}.pdb') - with open(unrelaxed_pdb_path, 'w') as f: - f.write(protein.to_pdb(unrelaxed_protein)) - - # output info of intermeidate recycles and save the coordinates - if FLAGS.save_recycled: - recycle_out_dir = os.path.join(output_dir, "recycled") - if FLAGS.save_recycled > 1 and not os.path.exists(recycle_out_dir): - os.mkdir(recycle_out_dir) - for i, rec_dict in enumerate(recycled): - if i < tot_recycle: - _save_results(rec_dict, f"{model_out_name}_recycled_{i:02d}", - recycle_out_dir, i) - - # the final results from this model run - _save_results(prediction_result, model_out_name, output_dir, tot_recycle) - # End of model runs - - # Rank by pTMscore if exists, otherwise pLDDTs - ranked_order = [] - if 'ptm' in prediction_result: - ranking_metric 
= 'pTM' - for idx, (mod_name, _) in enumerate( - sorted(ptms.items(), key=lambda x: x[1], reverse=True)): - ranked_order.append(mod_name) - else: - ranking_metric = 'pLDDT' - for idx, (mod_name, _) in enumerate( - sorted(plddts.items(), key=lambda x: x[1], reverse=True)): - ranked_order.append(mod_name) - - stats = {'plddts': plddts, 'ptms': ptms, 'order': ranked_order, - 'ranking_metric': ranking_metric, 'iterations': iterations, - 'tol_values':tols, 'chains': mono_chains, 'chain_lengths': Ls, - 'timings':timings} - if len(pitms): - stats = { **stats, 'pitms': pitms, 'interfacial residue number': ires, - 'interficial contact number': icnt, 'interface score': ints } - - if len(model_runners) > 1: #more than 1 model - ranking_output_path = os.path.join(output_dir, 'ranking_all'+out_suffix+'.json') - else: #only one model, use different model names to avoid overwriting same file - ranking_output_path = os.path.join(output_dir, 'ranking_'+model_out_name+'.json') - - with open(ranking_output_path, 'w') as f: - f.write(json.dumps(stats, sort_keys=True, indent=4)) - - logging.info('Final timings for %s: %s', target_name, timings) - -################################################################################################## - -def main(argv): - if len(argv) > 1: - raise app.UsageError('Too many command-line arguments.') - - # read a list of target files - target_lst = _read_target_file( FLAGS.target_lst_path ) - - max_recycles = 3; recycle_tol = 0 - max_extra_msa = None; max_msa_clusters = None - print("Info: using preset", FLAGS.preset) - if FLAGS.preset == 'reduced_dbs': - num_ensemble = 1 - elif FLAGS.preset == 'casp14': - num_ensemble = 8 - elif FLAGS.preset == 'economy': - num_ensemble = 1 - recycle_tol = 0.1 - max_extra_msa = 512 - max_msa_clusters = 256 - elif FLAGS.preset in ['super', 'super2']: - num_ensemble = 1 - max_recycles = 20 - recycle_tol = 0.1 - max_extra_msa = 1024 - max_msa_clusters = 512 - if FLAGS.preset == 'super2': num_ensemble = 2 - elif FLAGS.preset in ['genome', 'genome2']: - num_ensemble = 1 - max_recycles = 20 - recycle_tol = 0.5 - max_extra_msa = 1024 - max_msa_clusters = 512 - if FLAGS.preset == 'genome2': num_ensemble = 2 - - # allow customized parameters over preset - if FLAGS.num_ensemble is not None: - num_ensemble = FLAGS.num_ensemble - print(f"Info: set num_ensemble = {num_ensemble}") - if FLAGS.max_recycles is not None: - max_recycles = FLAGS.max_recycles - print(f"Info: set max_recyles = {max_recycles}") - if FLAGS.recycle_tol is not None: - recycle_tol = FLAGS.recycle_tol - print(f"Info: set recycle_tol = {recycle_tol}") - if FLAGS.max_msa_clusters is not None and FLAGS.max_extra_msa is not None: - max_msa_clusters = FLAGS.max_msa_clusters - max_extra_msa = FLAGS.max_extra_msa - print(f"Info: max_msa_clusters = {max_msa_clusters}, max_extra_msa = {max_extra_msa}") - - model_runners = {} - for model_name in FLAGS.model_names: - model_config = config.model_config(model_name) - model_config.data.eval.num_ensemble = num_ensemble - model_config.data.common.num_recycle = max_recycles - model_config.model.num_recycle = max_recycles - model_config.model.recycle_tol = recycle_tol - model_config.model.save_recycled = FLAGS.save_recycled - - model_params = data.get_model_haiku_params( - model_name=model_name, data_dir=FLAGS.data_dir) - model_runner = model.RunModel(model_config, model_params) - model_runners[model_name] = model_runner - - logging.info('Have %d models: %s', len(model_runners), - list(model_runners.keys())) - - random_seed = FLAGS.random_seed - 
if random_seed is None: - random_seed = random.randrange(sys.maxsize) - logging.info('Using random seed %d for the data pipeline', random_seed) - - # Predict structure for each target. - for target in target_lst: - target_name = target['name'] - target_split = target['split'] - print(f"Info: working on target {target_name}") - predict_structure( - target=target, - output_dir_base=FLAGS.output_dir, - feature_dir_base=FLAGS.feature_dir, - model_runners=model_runners, - random_seed=random_seed, - max_msa_clusters=max_msa_clusters, - max_extra_msa=max_extra_msa, - max_recycles=max_recycles, - num_ensemble=num_ensemble, - preset=FLAGS.preset, - write_complex_features=FLAGS.write_complex_features, - template_mode=FLAGS.template_mode - ) - - -if __name__ == '__main__': - flags.mark_flags_as_required([ - 'target_lst_path', - 'output_dir', - 'feature_dir', - 'model_names', - 'data_dir', - 'preset', - ]) - - app.run(main) diff --git a/src/run_alphafold_stage2b.py b/src/run_alphafold_stage2b.py deleted file mode 100644 index 806ceed..0000000 --- a/src/run_alphafold_stage2b.py +++ /dev/null @@ -1,81 +0,0 @@ -"""Run MD minization to relax a protein structure model from AF2""" -import os -import pickle -import re -import time - -from absl import app -from absl import flags -from tqdm import tqdm - -from alphafold.relax import relax -from alphafold.common import protein - -from run_alphafold_stage2a_comp import _read_target_file, FLAGS - -import numpy as np - - - -MAX_TEMPLATE_HITS = 20 -RELAX_MAX_ITERATIONS = 0 -RELAX_ENERGY_TOLERANCE = 2.39 -RELAX_STIFFNESS = 10.0 -RELAX_EXCLUDE_RESIDUES = [] -RELAX_MAX_OUTER_ITERATIONS = 20 - - -def main(argv): - if len(argv) > 1: - raise app.UsageError('Too many command-line arguments.') - - amber_relaxer = relax.AmberRelaxation( - max_iterations=RELAX_MAX_ITERATIONS, - tolerance=RELAX_ENERGY_TOLERANCE, - stiffness=RELAX_STIFFNESS, - exclude_residues=RELAX_EXCLUDE_RESIDUES, - max_outer_iterations=RELAX_MAX_OUTER_ITERATIONS) - - # read list of targets - target_lst = _read_target_file( FLAGS.target_lst_path ) - - for target in target_lst: - # get target name - target_name = target['name'] - target_name = re.sub(":", "_x", target_name) - target_name = re.sub("/", "+", target_name) - target_dir = os.path.join(FLAGS.output_dir, target_name) - - for afile in os.listdir(target_dir): - # find all unrelaxed pdb files, relaxed ones with 'relaxed' as prefix - if ".pdb" in afile and afile.startswith("model_"): - print(f"Info: {target_name} processing {afile}") - unrelaxed_pdb_file = afile - unrelaxed_pdb_path = os.path.join(target_dir, unrelaxed_pdb_file) - with open(unrelaxed_pdb_path, "r") as f: - unrelaxed_pdb_str = f.read() - - unrelaxed_protein = protein.from_pdb_string( unrelaxed_pdb_str ) - - # Relax the prediction. - t_0 = time.time() - relaxed_pdb_str, _, _ = amber_relaxer.process(prot=unrelaxed_protein) - relaxation_time = time.time() - t_0 - - # Save the relaxed PDB. - relaxed_output_path = os.path.join(target_dir, f'relaxed_{unrelaxed_pdb_file}') - with open(relaxed_output_path, 'w') as f: - f.write(relaxed_pdb_str) - - print(f"Info: {target_name} relaxation done, time spent {relaxation_time:.1f} seconds") - - -if __name__ == '__main__': - flags.mark_flags_as_required([ - 'target_lst_path', - 'output_dir', - 'feature_dir', - 'template_mode', - ]) - - app.run(main) -- GitLab
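
Note on the complex.py change above: the refactor moves the zero-template handling up into load_monomer_feature, so a monomer with no template hits (or with templates disabled via flags.no_template) receives a single null template before the per-chain features are cropped and joined, and the per-copy `if num_tem == 0` fallbacks in template_cropping_and_joining_mono/_mult are no longer needed. The sketch below illustrates that pattern in isolation; the example mono_feature_dict and the no_template variable are illustrative stand-ins rather than real pipeline data, and only the monomer template fields are shown.

import numpy as np

def initialize_template_feats(num_templates_, num_res_, is_multimer=False):
  # Null template features for a chain of num_res_ residues, mirroring the
  # patched helper in src/alphafold/data/complex.py (monomer fields only;
  # the multimer variant adds asym_id/sym_id/entity_id).
  return {
      'template_aatype': np.zeros([num_templates_, num_res_, 22], np.float32),
      'template_all_atom_masks': np.zeros([num_templates_, num_res_, 37], np.float32),
      'template_all_atom_positions': np.zeros([num_templates_, num_res_, 37, 3], np.float32),
      'template_domain_names': np.empty([num_templates_], dtype=str),
      'template_sequence': np.empty([num_templates_], dtype=str),
      'template_sum_probs': np.zeros([num_templates_, 1], np.float32),
  }

# Illustrative monomer features with zero template hits (not real pipeline output).
mono_feature_dict = {
    'msa': np.zeros([8, 50], np.int32),                 # 8 aligned sequences, 50 residues
    'residue_index': np.arange(50),
    'template_domain_names': np.empty([0], dtype=str),  # no template found
}
no_template = False  # stand-in for flags.no_template

L = len(mono_feature_dict['residue_index'])
T = len(mono_feature_dict['template_domain_names'])

# The patched logic: substitute exactly one null template when none were found
# (or when templates are disabled), so downstream joining code never has to
# special-case an empty template stack.
if T == 0 or no_template:
  mono_feature_dict.update(initialize_template_feats(1, L, is_multimer=False))

print(mono_feature_dict['template_aatype'].shape)  # (1, 50, 22)

Because the null features are all zeros (with empty name/sequence entries), copying them into the joined template arrays leaves that chain's template stack effectively empty, which is what lets the dom_fea fallback branches be deleted.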