diff --git a/README.md b/README.md
index a97ac3dc70feaaf8bb4967848818779de82c168f..bf984a205d7ce38ec4ee213460350fea71494b82 100644
--- a/README.md
+++ b/README.md
@@ -13,6 +13,11 @@ You may test examples of AF2Complex or explore protein-protein interactions with
 
 ## Updates and Features
 
+#### Version 1.4.0 (2023-01-29)
+
+- Support AF-Multimer v3 multimer models (based on AF v2.3.1)
+- More options for input feature generation
+
 #### Version 1.3.0 (2022-08-30)
 
 - Google Colab notebook access
diff --git a/example/example1.sh b/example/example1.sh
index c5d561803134aa978ceb9d5dfccb5e70b5913148..b0edc922a0fc12303d736772fe8a2b2d60c41f6d 100755
--- a/example/example1.sh
+++ b/example/example1.sh
@@ -1,7 +1,7 @@
 #!/bin/bash
 #
 # An example script of an AF2Complex run for predicting structural models
-# of a multimeric target using AF-Multimer v2 model weights in fully paired MSA mode
+# of a multimeric target using AF-Multimer model weights in fully paired MSA mode
 
 # You need to take care of these two items, which are dependent on your installation.
 # 1) activate your conda environment for AlphaFold if you use conda
@@ -18,9 +18,9 @@ out_dir=af2c_mod # model output files will be under $out_dir/$target
 ### This preset defined the number of recycles, ensembles, MSA cluster sizes (for monomer_ptm models)
 preset=economy # up to 6 recycles, 1 ensemble.
 
-### Choose neural network model(s) from ['model_1/2/3/4/5_multimer', 'model_1/2/3/4/5_multimer_v2', or 'model_1/2/3/4/5_ptm']
+### Choose neural network model(s) from ['model_1/2/3/4/5_multimer_v3', 'model_1/2/3/4/5_multimer_v2', or 'model_1/2/3/4/5_ptm']
 # Using two AF2 multimer model released in alphafold2 version 2.2.0
-model=model_1_multimer_v2,model_3_multimer_v2
+model=model_1_multimer_v3,model_3_multimer_v3
 
 ### Choose model_preset from: ['monomer_ptm', 'multimer', 'multimer_np']
 # Notes:
@@ -29,7 +29,7 @@ model=model_1_multimer_v2,model_3_multimer_v2
 #   - multimer: apply multimer DL model to paired MSA pre-generated by AlphaFold-Multimer's official data pipeline
 #
 #   You must specify approriate model names compatible with the model preset you choose.
-#   E.g., mnomer_ptm for model_x_ptm, and multimer_np for model_x_multimer_v2
+#   E.g., mnomer_ptm for model_x_ptm, and multimer_np for model_x_multimer_v3
 #
 model_preset=multimer_np
 msa_pairing=all # will assemble msa pairing using monoermic features generated with af2complex feature procedure
diff --git a/example/example2.sh b/example/example2.sh
index f725defa5829b6542b4f7de5e031707935d46b03..20befe6341e351a95e99bf7ab10cc91c955698b6 100755
--- a/example/example2.sh
+++ b/example/example2.sh
@@ -18,7 +18,7 @@ out_dir=af2c_mod # model output files will be under $out_dir/$target
 ### This preset defined the number of recycles, ensembles, MSA cluster sizes (for monomer_ptm models)
 preset=economy  # up to 3 recycles, 1 ensemble.
 
-### Choose neural network model(s) from ['model_1/2/3/4/5_multimer', 'model_1/2/3/4/5_multimer_v2', or 'model_1/2/3/4/5_ptm']
+### Choose neural network model(s) from ['model_1/2/3/4/5_multimer_v3', 'model_1/2/3/4/5_multimer_v2', or 'model_1/2/3/4/5_ptm']
 # Using AF2 monomer_ptm model released in alphafold2 version 2.0.1
 model=model_1_ptm,model_3_ptm
 
diff --git a/example/example3.sh b/example/example3.sh
index d8e0b85191b524d4440406389bd25d2e079aea9a..74ebb91f236083625589a96cb5a7db6c95dcfec2 100755
--- a/example/example3.sh
+++ b/example/example3.sh
@@ -25,9 +25,9 @@ out_dir=af2c_mod # model output files will be under $out_dir/$target
 ### This preset defined the number of recycles, ensembles, MSA cluster sizes (for monomer_ptm models)
 preset=economy # up to 3 recycles, 1 ensemble.
 
-### Choose neural network model(s) from ['model_1/2/3/4/5_multimer', 'model_1/2/3/4/5_multimer_v2', or 'model_1/2/3/4/5_ptm']
-# Using AF2 multimer model released in alphafold2 version 2.1.1
-model=model_1_multimer_v2
+### Choose neural network model(s) from ['model_1/2/3/4/5_multimer_v3', 'model_1/2/3/4/5_multimer_v2', or 'model_1/2/3/4/5_ptm']
+# Using AF2 multimer model released in alphafold2 version 2.3.1
+model=model_1_multimer_v3
 
 ### Choose model_preset from: ['monomer_ptm', 'multimer', 'multimer_np']
 # Notes:
@@ -36,7 +36,7 @@ model=model_1_multimer_v2
 #   - multimer: apply multimer DL model to paired MSA pre-generated by AlphaFold-Multimer's official data pipeline
 #
 #   You must specify approriate model names compatible with the model preset you choose.
-#   E.g., mnomer_ptm for model_x_ptm, and multimer_np for model_x_multimer_v2
+#   E.g., mnomer_ptm for model_x_ptm, and multimer_np for model_x_multimer_v3
 model_preset=multimer_np
 
 recycling_setting=1   # output metrics but not saving pdb files during intermediate recycles
diff --git a/example/example4.sh b/example/example4.sh
index c27cd2a0dd05637b51f72568138d0a747d358b80..a19b3f14a0d93f7cdb129a9e9c1edb393d1b40e7 100755
--- a/example/example4.sh
+++ b/example/example4.sh
@@ -19,7 +19,7 @@ out_dir=af2c_mod # model output directory, s.t. output files will be on $out_dir
 ### This preset defined the number of recycles, ensembles, MSA cluster sizes (for monomer_ptm models)
 preset=expert # up to 20 recycles, 1 ensemble.
 
-### Choose neural network model(s) from ['model_1/2/3/4/5_multimer', 'model_1/2/3/4/5_multimer_v2', or 'model_1/2/3/4/5_ptm']
+### Choose neural network model(s) from ['model_1/2/3/4/5_multimer_v3', 'model_1/2/3/4/5_multimer_v2', or 'model_1/2/3/4/5_ptm']
 # Using AF2 monomer_ptm model released in alphafold2 version 2.0.1
 model=model_5_ptm
 
diff --git a/example/examples.sh b/example/examples.sh
index 2a91bddc260bc049a4440077a42f94a8df519470..bf800a67e2ae09c10e98f3998f47dfffa55f0a11 100755
--- a/example/examples.sh
+++ b/example/examples.sh
@@ -19,7 +19,7 @@ out_dir=af2c_mod # model output directory, s.t. output files will be on $out_dir
 preset=economy # up to 3 recycles, 1 ensemble.
 
 # these two options can be overwritten in examples.lst
-model=model_1_multimer_v2
+model=model_1_multimer_v3
 model_preset=multimer_np
 
 
diff --git a/example/run_af2comp.sh b/example/run_af2comp.sh
index 79fc8ea421c53c6a8dd56f50ccf13fb61b284d5e..9af080b3b464e363140297b810eb25b9da639e4e 100755
--- a/example/run_af2comp.sh
+++ b/example/run_af2comp.sh
@@ -16,12 +16,12 @@ out_dir=af2c_mod # model output directory, $out_dir/$target
 ### This preset defined the number of recycles, ensembles, MSA cluster sizes (for monomer_ptm models)
 preset=super   # up to 20 recycles, 1 ensemble.
 
-### Choose neural network model(s) from ['model_1/2/3/4/5_multimer', 'model_1/2/3/4/5_multimer_v2', or 'model_1/2/3/4/5_ptm']
+### Choose neural network model(s) from ['model_1/2/3/4/5_multimer_v3', 'model_1/2/3/4/5_multimer_v2', or 'model_1/2/3/4/5_ptm']
 # Using AF2 multimer model released in version 2.1.1
 #model=model_1_multimer,model_2_multime,rmodel_3_multimer,model_4_multimer,model_5_multimer
 # Using AF2 multimer model released in version 2.2.0
-#model=model_1_multimer_v2,model_2_multime_v2,rmodel_3_multimer_v2,model_4_multimer_v2,model_5_multimer_v2
-model=model_1_multimer_v2,model_2_multimer_v2
+#model=model_1_multimer_v3,model_2_multime_v3,rmodel_3_multimer_v3,model_4_multimer_v3,model_5_multimer_v3
+model=model_1_multimer_v3,model_2_multimer_v3
 # Using AF2 monomer_ptm model released in version 2.0.1
 #model=model_1_ptm,model_2_ptm,model_3_ptm,model_4_ptm,model_5_ptm
 
diff --git a/example/run_fea_gen.sh b/example/run_fea_gen.sh
index 38212bc72aebd7c33a4c078a4dca53ddcd0dcf51..760692f67c5671b9503de7bfe7c50677ac6f3152 100755
--- a/example/run_fea_gen.sh
+++ b/example/run_fea_gen.sh
@@ -8,7 +8,7 @@
 DATA_DIR=$HOME/scratch/afold/data
 export HHLIB=$HOME/data/tools/hh-suite/build
 #export HMMER=$HOME/data/tools/hmmer-3.2.1/build
-export HMMER=/usr/local/pace-apps/spack/packages/0.13/linux-rhel7-cascadelake/intel-19.0.5/hmmer-3.2.1-sngcwm2qjzzxseh42cryf432role4on5
+export HMMER=/usr/local/pace-apps/spack/packages/linux-rhel7-x86_64/gcc-10.3.0/hmmer-3.3.2-y25u7humnxtqnaf7yc2il3misyze6pac
 export KALIGN=$HOME/data/tools/kalign_v2/kalign
 af_dir=../src
 
@@ -18,17 +18,26 @@ if [ $# -eq 0 ]
     exit 1
 fi
 fasta_path=$1
-out_dir=af2c_fea
+out_dir=af2c_fea_test
+
+# choices are "reduced_dbs", "full_dbs", "uniprot"
 db_preset='reduced_dbs'
 
-# choices are "monomer, monomer+species, multimer"
+# choices are "monomer, multimer, monomer+species, monomer+fullpdb"
 # Option "monomer" and "multimer" follows alphafold official datapipeline for monomeric and
 # multimeric structure predictions, respectively.
+#
 # Option "monomer+species" is a modified monomeric pipeline such as the species information
 # is recorded for MSA pairing using only monomeric input features. This option is recommended.
-feature_mode='monomer+species'
+#feature_mode='monomer+species'
+#
+# Option "monomer+fullpdb": in addition to add species, it uses template pipeline for multimer
+# rather the template pipeline for the original monomer modeling. The mulitmer template pipeline
+# search full PDB for templates, which is more comprehensive than the monomer template pipeline.
+feature_mode='monomer+fullpdb'
 
-max_template_date=2020-05-15  # CASP14 starting date
+#max_template_date=2020-05-15  # CASP14 starting date
+max_template_date=$(date +"%Y-%m-%d")  # current date
 
 
 echo "Info: sequence file is $fasta_path"
@@ -41,7 +50,7 @@ echo "Info: max_template_date is $max_template_date"
 ##########################################################################################
 
 
-if [ "$feature_mode" = "multimer" ]; then
+if [ "$model_preset" = "multimer" ] || [ "$feature_mode" = "monomer+fullpdb" ]; then
   python $af_dir/run_af2c_fea.py --fasta_paths=$fasta_path --db_preset=$db_preset \
     --data_dir=$DATA_DIR --output_dir=$out_dir      \
     --uniprot_database_path=$DATA_DIR/uniprot/uniprot.fasta \
@@ -58,7 +67,8 @@ if [ "$feature_mode" = "multimer" ]; then
     --hmmsearch_binary_path=$HMMER/bin/hmmsearch \
     --hmmbuild_binary_path=$HMMER/bin/hmmbuild \
     --kalign_binary_path=$KALIGN \
-    --feature_mode=$feature_mode
+    --feature_mode=$feature_mode \
+    --use_precomputed_msas=True    
 elif [ "$feature_mode" = "monomer+species" ]; then
   python $af_dir/run_af2c_fea.py --fasta_paths=$fasta_path --db_preset=$db_preset \
     --data_dir=$DATA_DIR --output_dir=$out_dir      \
@@ -94,5 +104,6 @@ else
     --hmmsearch_binary_path=$HMMER/bin/hmmsearch \
     --hmmbuild_binary_path=$HMMER/bin/hmmbuild \
     --kalign_binary_path=$KALIGN \
-    --feature_mode=$feature_mode
+    --feature_mode=$feature_mode \
+    --use_precomputed_msas=True
 fi
diff --git a/example/run_relaxation.sh b/example/run_relaxation.sh
index a814707d9bc67601319bf9da1f87223eecd7fd2c..966bb674ea0168f69a440c87192cba264ed86866 100755
--- a/example/run_relaxation.sh
+++ b/example/run_relaxation.sh
@@ -19,3 +19,4 @@ python -u $af_dir/run_af2c_min.py \
   --target_lst_path=$target_lst_file \
   --output_dir=$out_dir \
   --input_dir=$inp_dir \
+  --use_gpu_relax
diff --git a/example/targets/test.lst b/example/targets/test.lst
index 828b3176dbadbf513ff7f416afde4ddece724c8f..34973a029b211861fe5fdfb90fdbb22b23d7a506 100644
--- a/example/targets/test.lst
+++ b/example/targets/test.lst
@@ -1,4 +1,4 @@
 ##Target(components) Size(AAs) Name(for output)
 #T1065s1/T1065s2 225 H1065
-T1072s1:2+T1072s2:2 340 H1072  ./H1072_adj_list.txt
-#T1060s3:12 1680 H1060v4
\ No newline at end of file
+#T1072s1:2+T1072s2:2 340 H1072  ./H1072_adj_list.txt
+T1060s3:12 1680 H1060v4
diff --git a/notebook/AF2Complex_notebook.ipynb b/notebook/AF2Complex_notebook.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..06ef040ca5ec3d90842d25db376fa93bdb36dbb6
--- /dev/null
+++ b/notebook/AF2Complex_notebook.ipynb
@@ -0,0 +1,1023 @@
+{
+  "cells": [
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "id": "sl1IbeCnW4zg"
+      },
+      "source": [
+        "# Workflow\n",
+        "\n",
+        "Step 1: Setup AF2Complex by running through the setup module\n",
+        "\n",
+        "Step 2: Pick one of three Target Run modules and follow the steps in each module\n",
+        "\n",
+        "Step 3: Download your predictions\n",
+        "- After running AF2Complex on a target the prediction location will be printed out Such as the screenshot below. \n",
+        "\n",
+        "![Screen Shot 2022-08-30 at 3.28.40 PM.png]()\n",
+        "\n",
+        "- You can then download these predictions by using Google Colab's file explorer like the GIF below demonstrates:\n",
+        "\n",
+        "![download_final.gif]()\n"
+      ]
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "id": "iyrTK1-HJ6Gs"
+      },
+      "source": [
+        "#Setup AF2Complex\n"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": null,
+      "metadata": {
+        "id": "DuR67XCbLKod"
+      },
+      "outputs": [],
+      "source": [
+        "#@title 1. Download deep learning model parameters of AlphaFold 2\n",
+        "\n",
+        "#@markdown Please execute this cell by pressing the *Play* button on \n",
+        "#@markdown the left.\n",
+        "\n",
+        "# from google.colab import output\n",
+        "# output.enable_custom_widget_manager()\n",
+        "from IPython.utils import io\n",
+        "import os\n",
+        "import subprocess\n",
+        "import tqdm.notebook\n",
+        "from google.colab import output\n",
+        "import os\n",
+        "\n",
+        "output.enable_custom_widget_manager()\n",
+        "os.environ['TF_FORCE_UNIFIED_MEMORY'] = '1'\n",
+        "os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '2.0'\n",
+        "\n",
+        "#SOURCE_URL = \"https://storage.googleapis.com/alphafold/alphafold_params_2022-03-02.tar\"\n",
+        "SOURCE_URL = \"https://storage.googleapis.com/alphafold/alphafold_params_2022-12-06.tar\"\n",
+        "PARAMS_DIR = '/content/afold/data/params'\n",
+        "\n",
+        "PARAMS_PATH = os.path.join(PARAMS_DIR, os.path.basename(SOURCE_URL))\n",
+        "\n",
+        "TQDM_BAR_FORMAT = '{l_bar}{bar}| {n_fmt}/{total_fmt} [elapsed: {elapsed} remaining: {remaining}]'\n",
+        "try:\n",
+        "  with tqdm.notebook.tqdm(total=100, bar_format=TQDM_BAR_FORMAT) as pbar:\n",
+        "    with io.capture_output() as captured:\n",
+        "\n",
+        "      if not os.path.exists(PARAMS_DIR):\n",
+        "          %shell mkdir --parents \"{PARAMS_DIR}\"\n",
+        "          %shell wget -O \"{PARAMS_PATH}\" \"{SOURCE_URL}\"\n",
+        "          pbar.update(40)\n",
+        "  \n",
+        "          %shell tar --extract --verbose --file=\"{PARAMS_PATH}\" \\\n",
+        "              --directory=\"{PARAMS_DIR}\" --preserve-permissions\n",
+        "          %shell rm \"{PARAMS_PATH}\"\n",
+        "          pbar.update(60)\n",
+        "      else:\n",
+        "          pbar.update(100)\n",
+        "except subprocess.CalledProcessError:\n",
+        "  print(captured)\n",
+        "  raise\n"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": null,
+      "metadata": {
+        "id": "zo7qqlklKQmd"
+      },
+      "outputs": [],
+      "source": [
+        "#@title 2. Install AF2Complex\n",
+        "\n",
+        "#@markdown Please execute this cell by pressing the _Play_ button \n",
+        "#@markdown \n",
+        "#@markdown This installs AF2Complex and the python packages it uses\n",
+        "\n",
+        "import os\n",
+        "import subprocess\n",
+        "\n",
+        "AF2C_examples = '/content/af2complex/example'\n",
+        "AF2C_src = '/content/af2complex/src'\n",
+        "AF_LIB_DIR = os.path.join(AF2C_src, 'alphafold')\n",
+        "UPLOAD_DIR = '/content/uploaded_feats/'\n",
+        "os.chdir('/content/')\n",
+        "\n",
+        "try:\n",
+        "  with tqdm.notebook.tqdm(total=100, bar_format=TQDM_BAR_FORMAT) as pbar:\n",
+        "    with io.capture_output() as captured:\n",
+        "        if not os.path.exists('/content/af2complex'):\n",
+        "            %shell git clone https://github.com/FreshAirTonight/af2complex.git\n",
+        "        pbar.update(15)\n",
+        "\n",
+        "        #Install third-party software\n",
+        "        %shell pip uninstall -y tensorflow keras\n",
+        "        pbar.update(5)\n",
+        "        # Install py3dmol.\n",
+        "        %shell pip install py3dmol\n",
+        "        pbar.update(5)\n",
+        "        %shell cd af2complex && pip install -r requirements.txt\n",
+        "        pbar.update(50)\n",
+        "        #%shell pip install --upgrade jax==0.2.14 jaxlib==0.1.69+cuda111 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html\n",
+        "        %shell pip3 install --upgrade --no-cache-dir jax==0.3.25 jaxlib==0.3.25+cuda11.cudnn805 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html\n",
+        "\n",
+        "        if not os.path.exists('/content/uploaded_feats/'):\n",
+        "            %shell mkdir /content/uploaded_feats/\n",
+        "        pbar.update(25)\n",
+        "except subprocess.CalledProcessError:\n",
+        "  print(captured)\n",
+        "  raise\n",
+        "\n"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": null,
+      "metadata": {
+        "cellView": "form",
+        "id": "-dkd1Qcz4254"
+      },
+      "outputs": [],
+      "source": [
+        "#@title 3. Define the configuration of your structure prediction run\n",
+        "#@markdown **Note**: Please re-run this cell if any variable below is changed\n",
+        "\n",
+        "#@markdown Choose preset model configuration: <deepmind> standard settings according to DeepMind, \n",
+        "#@markdown  i.e., 3 recycles and 1 ensemble; \n",
+        "#@markdown - **economy**: no ensemble, up to 256 MSA clusters, recycling up to 3 rounds; \n",
+        "#@markdown - **super/super2**: 1 or 2 ensembles, up to 512 MSA clusters, recycling up to 20 rounds; \n",
+        "#@markdown - **genome/genome2**: 1 or 2 ensembles, up to 512 MSA clusters, max number \n",
+        "#@markdown of recycles and ensembles adjusted according to input sequence length; \n",
+        "#@markdown - **expert**: similar to super but maintain the same recycle number regardless target size; \n",
+        "#@markdown - **casp14**: 8 model ensemblings used by DeepMind in CASP14.')\n",
+        "import numpy as np\n",
+        "DATA_DIR =  '/content/afold/data/'\n",
+        "preset = 'economy' #@param ['deepmind', 'casp14', 'economy', 'super', 'expert', 'super2', 'genome', 'genome2']\n",
+        "\n",
+        "#@markdown Choose between multimer_v2 or ptm AF parameter sets:\n",
+        "model_type = 'multimer_v3' #@param ['multimer_v3', 'monomer_ptm']\n",
+        "model_preset = {\n",
+        "    'multimer_v3': 'multimer_np',\n",
+        "    'monomer_ptm': 'monomer_ptm',\n",
+        "    }[model_type]\n",
+        "if model_type == 'monomer_ptm':\n",
+        "    model_type = 'ptm'\n",
+        "\n",
+        "#@markdown There are five different models you can choose from, check the ones you want to run (please check at least one) \n",
+        "param_set_1 = True #@param {type:\"boolean\"}\n",
+        "param_set_2 = False #@param {type:\"boolean\"}\n",
+        "param_set_3 = False #@param {type:\"boolean\"}\n",
+        "param_set_4 = False #@param {type:\"boolean\"}\n",
+        "param_set_5 = False #@param {type:\"boolean\"}\n",
+        "\n",
+        "param_set_nums = [param_set_1,param_set_2,param_set_3,param_set_4,param_set_5]\n",
+        "assert np.any(param_set_nums), 'Please check one of the param_sets '\n",
+        "models = []\n",
+        "for i, param_set in enumerate(param_set_nums):\n",
+        "    if param_set:\n",
+        "        models.append(f\"model_{i+1}_{model_type}\")\n",
+        "\n",
+        "#@markdown Choose your recycling setting:\n",
+        "#@markdown 0. no recycle info saving \n",
+        "#@markdown 1. print metrics of intermediate recycles\n",
+        "#@markdown 2. additionally saving pdb structures of all recycles, \n",
+        "#@markdown 3. additionally save all results in pickle\n",
+        "recycling_setting=\"1\" #@param [0, 1, 2, 3]\n",
+        "\n",
+        "#@markdown Input below how many predictions (each with a different random seed) will be \n",
+        "#@markdown generated per model. \n",
+        "\n",
+        "#@markdown E.g. if this is 2 and there are 5\n",
+        "#@markdown models then there will be 10 predictions per input. \n",
+        "num_predictions_per_model=1 #@param {type:\"integer\"}\n",
+        "\n",
+        "#@markdown Input below the maximum number of recycles. Leave as -1 if you don't want to limit the number of recycles.\n",
+        "max_recycles = 4 #@param {type: \"integer\"}\n",
+        "\n",
+        "\n",
+        "class dotdict(dict):\n",
+        "    \"\"\"dot.notation access to dictionary attributes\"\"\"\n",
+        "    __getattr__ = dict.get\n",
+        "    __setattr__ = dict.__setitem__\n",
+        "    __delattr__ = dict.__delitem__\n",
+        "def make_default_flags():\n",
+        "    return dotdict({\n",
+        "        'target_lst_path':None,\n",
+        "        'output_dir':'/content/af2complex/example/af2c_mod',\n",
+        "        'feature_dir':'/content/af2complex/example/af2c_fea',\n",
+        "        'model_names':None,\n",
+        "        'data_dir':DATA_DIR,\n",
+        "        'preset':'economy',\n",
+        "        'random_seed':None,\n",
+        "        'max_recycles':None,\n",
+        "        'num_ensemble':None,\n",
+        "        'max_msa_clusters':None,\n",
+        "        'max_extra_msa':None,\n",
+        "        'write_complex_features':False,\n",
+        "        'no_template':False,\n",
+        "        'output_pickle':True,\n",
+        "        'save_recycled':0,\n",
+        "        'checkpoint_tag':None,\n",
+        "        'max_mono_msa_depth':10000,\n",
+        "        'mono_msa_crop_size':5000,\n",
+        "        'max_template_hits':4,\n",
+        "        'model_preset':'monomer_ptm',\n",
+        "        'num_predictions_per_model':1,\n",
+        "        'msa_pairing':None,\n",
+        "        'do_cluster_analysis':False,\n",
+        "        'cluster_edge_thres':10,\n",
+        "    })\n",
+        "FLAGS = make_default_flags()\n",
+        "FLAGS['preset'] = preset\n",
+        "FLAGS['model_preset'] = model_preset\n",
+        "FLAGS['model_names'] = models\n",
+        "FLAGS['save_recycled'] = recycling_setting\n",
+        "FLAGS['num_predictions_per_model'] = num_predictions_per_model\n",
+        "FLAGS['max_recycles'] = max_recycles\n",
+        "\n",
+        "def make_mod_params():\n",
+        "    preset = FLAGS['preset'] \n",
+        "    model_preset = FLAGS['model_preset'] \n",
+        "    models = FLAGS['model_names'] \n",
+        "    recycling_setting = FLAGS['save_recycled'] \n",
+        "    target_lst_file = FLAGS['target_lst_file'] \n",
+        "    msa_pairing = FLAGS['msa_pairing'] \n",
+        "    out_dir = FLAGS['output_dir']\n",
+        "    fea_dir = FLAGS['feature_dir']\n",
+        "    num_predictions_per_model = FLAGS['num_predictions_per_model']\n",
+        "    max_recycles = FLAGS['max_recycles']\n",
+        "\n",
+        "    parameters = [\n",
+        "            f'--target_lst_path={target_lst_file}',\n",
+        "            f'--data_dir={DATA_DIR}',\n",
+        "            f'--output_dir={out_dir}',\n",
+        "            f'--feature_dir={fea_dir}',\n",
+        "            f'--model_names={\",\".join(models)}',\n",
+        "            f'--preset={preset}',\n",
+        "            f'--model_preset={model_preset}',\n",
+        "            f'--num_predictions_per_model={num_predictions_per_model}',\n",
+        "            f'--save_recycled={recycling_setting}']\n",
+        "        \n",
+        "    if msa_pairing != 'none':\n",
+        "        parameters.append(f'--msa_pairing={msa_pairing}')\n",
+        "    if max_recycles > 0:\n",
+        "        parameters.append(f'--max_recycles={max_recycles}')\n",
+        "\n",
+        "    return ' '.join(parameters)"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": null,
+      "metadata": {
+        "cellView": "form",
+        "id": "7ZKu8W-rTzJ3"
+      },
+      "outputs": [],
+      "source": [
+        "#@title 4. Define relevant methods for visualization\n",
+        "\n",
+        "os.chdir(AF2C_src)\n",
+        "import py3Dmol\n",
+        "from alphafold.data.complex import make_complex_features\n",
+        "from alphafold.model import config\n",
+        "from alphafold.common import confidence\n",
+        "from alphafold.data.complex import initialize_template_feats\n",
+        "\n",
+        "import alphafold.data.complex as af2c\n",
+        "from run_af2c_mod import get_asymid2chain_name\n",
+        "import pickle\n",
+        "\n",
+        "import numpy as np\n",
+        "import re\n",
+        "import pandas as pd\n",
+        "from ipywidgets import interact, Dropdown\n",
+        "from google.colab import widgets\n",
+        "import matplotlib.pyplot as plt\n",
+        "import numpy as np\n",
+        "import ipywidgets\n",
+        "from IPython.display import display\n",
+        "import pandas as pd\n",
+        "\n",
+        "\n",
+        "def show_pdb(pred_output_path, show_sidechains=False, show_mainchains=False):\n",
+        "  view = py3Dmol.view(js='https://3dmol.org/build/3Dmol.js',)\n",
+        "  view.addModel(open(pred_output_path,'r').read(),'pdb')\n",
+        "  view.setStyle({'cartoon': {'colorscheme': 'chain'}})\n",
+        "  if show_sidechains:\n",
+        "    BB = ['C','O','N']\n",
+        "    view.addStyle({'and':[{'resn':[\"GLY\",\"PRO\"],'invert':True},{'atom':BB,'invert':True}]},\n",
+        "                        {'stick':{'colorscheme':f\"WhiteCarbon\",'radius':0.3}})\n",
+        "    view.addStyle({'and':[{'resn':\"GLY\"},{'atom':'CA'}]},\n",
+        "                        {'sphere':{'colorscheme':f\"WhiteCarbon\",'radius':0.3}})\n",
+        "    view.addStyle({'and':[{'resn':\"PRO\"},{'atom':['C','O'],'invert':True}]},\n",
+        "                        {'stick':{'colorscheme':f\"WhiteCarbon\",'radius':0.3}})  \n",
+        "  if show_mainchains:\n",
+        "    BB = ['C','O','N','CA']\n",
+        "    view.addStyle({'atom':BB},{'stick':{'colorscheme':f\"WhiteCarbon\",'radius':0.3}})\n",
+        "\n",
+        "  view.zoomTo()\n",
+        "  return view\n",
+        "\n",
+        "def get_asym_id(target, flags):\n",
+        "  \"\"\"Defines the sequence of preprocessing steps to get the asym_id feature\n",
+        "  Args:\n",
+        "    target: dictionary with the items:\n",
+        "        name: name of the multimer,\n",
+        "        split: information about each monomer composing the multimer,\n",
+        "        full: a string denoting all stoichiometry and domains of all monomers\n",
+        "          composing the multimer to be modeled,\n",
+        "    flags: variable containing inference configuration\n",
+        "  Returns:\n",
+        "    asym_id\n",
+        "  \"\"\"\n",
+        "  monomers = af2c.load_monomer_feature(target, flags)\n",
+        "\n",
+        "  if flags.msa_pairing is not None:\n",
+        "    for i in range(len(monomers)):\n",
+        "        if 'deletion_matrix' in monomers[i]['feature_dict']:\n",
+        "            monomers[i]['feature_dict']['deletion_matrix_int'] = monomers[i]['feature_dict']['deletion_matrix']\n",
+        "  curr_input = {'monomers': monomers, 'target': target, 'flags': flags}\n",
+        "\n",
+        "  curr_input = af2c.targeted_domain_cropping_mono(curr_input)\n",
+        "  curr_input = af2c.add_asym_id_monomer_ptm(curr_input)\n",
+        "  asym_id = curr_input['asym_id_mono_ptm']\n",
+        "\n",
+        "  return asym_id\n",
+        "\n",
+        "def get_interface_score(\n",
+        "    model_name, target_name, full_name, asym_id, idx2chain_name, out_dir, asym_id_list):\n",
+        "        metric = []\n",
+        "        value = []\n",
+        "        pdb_path = os.path.join(out_dir, target_name, f'{model_name}.pdb')\n",
+        "        pkl_path = os.path.join(out_dir, target_name, f'{model_name}.pkl')\n",
+        "\n",
+        "        model_config = config.model_config(model_name[:7])\n",
+        "        breaks = np.linspace(\n",
+        "            0., model_config.model.heads.predicted_aligned_error.max_error_bin,\n",
+        "            model_config.model.heads.predicted_aligned_error.num_bins - 1)\n",
+        "        try:\n",
+        "            result = pickle.load(open(pkl_path, \"rb\"))\n",
+        "        except (EOFError,IOError) as error:\n",
+        "            print(f\"Warning: {target_name} {error} encountered, check the pickle file\")\n",
+        "            raise\n",
+        "\n",
+        "        super_asym_id, superid2chainids = confidence.join_superchains_asym_id(asym_id, asym_id_list)\n",
+        "\n",
+        "        res = confidence.interface_score(\n",
+        "            result['aligned_confidence_probs'],\n",
+        "            breaks,\n",
+        "            result['structure_module']['final_atom_positions'],\n",
+        "            result['structure_module']['final_atom_mask'],\n",
+        "            super_asym_id,\n",
+        "            is_probs=True)\n",
+        "\n",
+        "        ptm = result['ptm'].tolist()\n",
+        "        pitm = result['pitm']['score'].tolist()\n",
+        "\n",
+        "        inter_sc = res['score'].tolist()\n",
+        "        inter_residues = res['num_residues'].tolist()\n",
+        "        inter_contacts = res['num_contacts'].tolist()\n",
+        "        metric.append('MODEL NAME')\n",
+        "        value.append(model_name)\n",
+        "        metric.append('TARGET CHAINS')\n",
+        "        value.append(full_name)\n",
+        "        metric.append('===========')\n",
+        "        value.append('===========')\n",
+        "        metric.append('pTM-score')\n",
+        "        value.append(ptm)\n",
+        "        metric.append('piTM-score')\n",
+        "        value.append(pitm)\n",
+        "        metric.append('iRes')\n",
+        "        value.append(inter_residues)\n",
+        "        metric.append('iCnt')\n",
+        "        value.append(inter_contacts)\n",
+        "        metric.append('interface-score')\n",
+        "        value.append(inter_sc)\n",
+        "\n",
+        "        if FLAGS.do_cluster_analysis:\n",
+        "            clus_res = confidence.cluster_analysis(\n",
+        "                super_asym_id,\n",
+        "                result['structure_module']['final_atom_positions'],\n",
+        "                result['structure_module']['final_atom_mask'],\n",
+        "                edge_contacts_thres=FLAGS.cluster_edge_thres,\n",
+        "                superid2chainids=superid2chainids,\n",
+        "            )\n",
+        "            cluster_identities = []\n",
+        "            for cluster in clus_res['clusters']:\n",
+        "                cluster_identities.append([idx2chain_name[c] for c in cluster])\n",
+        "\n",
+        "            metric.append('num_clusters')\n",
+        "            value.append(clus_res['num_clusters'])\n",
+        "            metric.append('cluster_sizes')\n",
+        "            value.append(clus_res['cluster_size'])\n",
+        "            metric.append('clusters')\n",
+        "            value.append(cluster_identities)\n",
+        "        return pd.DataFrame({'Metric Name':metric, 'Value':value})\n",
+        "\n",
+        "DATA_DIR = '/content/afold/data'  \n",
+        "def display_metrics(target_lst_path, model_path, support, show_sidechains_=True, show_mainchains_=True, ):\n",
+        "    with io.capture_output() as captured:\n",
+        "        target_lst = af2c.read_af2c_target_file(target_lst_path)\n",
+        "        full_name = support[\"full_name\"] \n",
+        "        target_name= support[\"target_name\"] \n",
+        "        idx2chain_name= support[\"idx2chain_name\"]\n",
+        "        asym_id_list= support[\"asym_id_list\"]\n",
+        "        asym_id= support[\"asym_id\"] \n",
+        "        model_name = os.path.basename(model_path)\n",
+        "        pdb_path = os.path.join(FLAGS.output_dir, target_name, f'{model_name}.pdb')\n",
+        "\n",
+        "        metrics = get_interface_score(\n",
+        "            model_name, target_name, full_name, asym_id, idx2chain_name, FLAGS.output_dir, asym_id_list\n",
+        "        )\n",
+        "\n",
+        "    print(metrics.to_markdown())\n",
+        "    view = show_pdb(pdb_path, \n",
+        "                show_sidechains=show_sidechains_,\n",
+        "                show_mainchains=show_mainchains_)\n",
+        "    view.show()\n",
+        "\n",
+        "def visualize(show_sidechains, show_mainchains):\n",
+        "    is_pdb = lambda x: '.pdb' in x \n",
+        "    if not os.path.exists(FLAGS.target_lst_file):\n",
+        "        raise f'{FLAGS.target_lst_file} does not exist!'\n",
+        "    target_lst = af2c.read_af2c_target_file(FLAGS.target_lst_file)\n",
+        "    files = []\n",
+        "    model2support = {}\n",
+        "    for target in target_lst:\n",
+        "        target_name = target['name']\n",
+        "        target_name = re.sub(\":\", \"_x\", target_name)\n",
+        "        target_name = re.sub(\"/\", \"+\", target_name)\n",
+        "        target_dir = os.path.join(FLAGS.output_dir, target_name)\n",
+        "        if not os.path.exists(target_dir):\n",
+        "            raise Exception(\n",
+        "                f'No predictions for {target_name}. Predictions available are for {os.listdir(FLAGS.output_dir)}. Please make sure the inference cell was run correctly.'\n",
+        "            )\n",
+        "        target_files = os.listdir(target_dir)\n",
+        "        if len(target_files) == 0:\n",
+        "            raise Exception(\n",
+        "                f'No predictions for {target_name}. Predictions available are for {os.listdir(FLAGS.output_dir)}. Please make sure the inference cell was run correctly.'\n",
+        "            )\n",
+        "        for f in target_files:\n",
+        "            full_name = target['full']\n",
+        "            idx2chain_name = get_asymid2chain_name(target)\n",
+        "            asym_id_list = target['asym_id_list']\n",
+        "            if not FLAGS.write_complex_features:\n",
+        "                with io.capture_output() as captured:\n",
+        "                    asym_id = get_asym_id(target, FLAGS)\n",
+        "            else:\n",
+        "                feat_path = os.path.join(target_name, 'features_comp.pkl')\n",
+        "                try:\n",
+        "                    feature_dict = np.load(open(feat_path, 'rb'))\n",
+        "                except FileNotFoundError:\n",
+        "                    print('Did not find feature_comp.pkl file. ',\n",
+        "                        'To rebuild complex features, run without ',\n",
+        "                        '--write_complex_features flag.')\n",
+        "                asym_id = feature_dict['asym_id']\n",
+        "            if is_pdb(f):\n",
+        "                model_name = os.path.join(target_dir, target_name, f)[:-4]\n",
+        "                files.append(model_name)\n",
+        "                model2support[model_name] = {\n",
+        "                    'full_name': full_name, \n",
+        "                    'target_name': target_name, \n",
+        "                    'idx2chain_name': idx2chain_name, \n",
+        "                    'asym_id_list': asym_id_list, \n",
+        "                    'asym_id': asym_id, \n",
+        "                    }\n",
+        "\n",
+        "    tabs = widgets.TabBar(files)\n",
+        "\n",
+        "    for i, model in enumerate(files): \n",
+        "        with tabs.output_to(i):\n",
+        "            display_metrics(\n",
+        "                FLAGS.target_lst_file,\n",
+        "                model,\n",
+        "                model2support[model],\n",
+        "                show_sidechains,\n",
+        "                show_mainchains, \n",
+        "                )\n",
+        "\n",
+        "def get_dataset_desc(file_path):\n",
+        "    from google.colab import data_table\n",
+        "    data_table.enable_dataframe_formatter()\n",
+        "    with open(file_path, 'r') as f:\n",
+        "        ecoli_txt = f.readlines()\n",
+        "\n",
+        "    ids = []\n",
+        "    genes = []\n",
+        "    acs = []\n",
+        "    fulllen = []\n",
+        "    ranges = []\n",
+        "    length = []\n",
+        "    desc = []\n",
+        "    for line in ecoli_txt:\n",
+        "        if line.startswith('#'):\n",
+        "            continue\n",
+        "        line = line.split('\\t')\n",
+        "        ids.append(line[0])\n",
+        "        genes.append(line[1])\n",
+        "        acs.append(line[2])\n",
+        "        fulllen.append(line[3])\n",
+        "        ranges.append(line[4])\n",
+        "        length.append(line[5])\n",
+        "        desc.append(line[6])\n",
+        "\n",
+        "    df = pd.DataFrame({\n",
+        "        'ID': ids[1:],\n",
+        "        'Gene': genes[1:],\n",
+        "        'AC': acs[1:],\n",
+        "        'Full Length': fulllen[1:],\n",
+        "        'Range': ranges[1:],\n",
+        "        'Length': length[1:],\n",
+        "        'Description': desc[1:],\n",
+        "    })\n",
+        "    return df"
+      ]
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "id": "EUsieqM-OsSJ"
+      },
+      "source": [
+        "# Target Run (AF2Complex Examples)"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": null,
+      "metadata": {
+        "cellView": "form",
+        "id": "1QFJ32_zOqJU"
+      },
+      "outputs": [],
+      "source": [
+        "import subprocess\n",
+        "import numpy as np\n",
+        "os.chdir(AF_LIB_DIR)\n",
+        "#@markdown #1. Choose one of the AF2Complex examples to run below! \n",
+        "#@markdown Note: After choosing your parameters below, press the play button to run the example chosen:\n",
+        "FLAGS['feature_dir'] = '/content/af2complex/example/af2c_fea' \n",
+        "FLAGS['output_dir'] = '/content/af2complex/example/af2c_mod'\n",
+        "\n",
+        "example = 'H1065' #@param ['H1065', 'H1072', 'H1072_H1065', 'H1060v4']\n",
+        "\n",
+        "target_lst_file = {\n",
+        "    'H1065': '/content/af2complex/example/targets/example1.lst',\n",
+        "    'H1072': '/content/af2complex/example/targets/example2.lst',\n",
+        "    'H1072_H1065': '/content/af2complex/example/targets/example3.lst',\n",
+        "    'H1060v4': '/content/af2complex/example/targets/example4.lst',\n",
+        "}[example]\n",
+        "\n",
+        "#@markdown Choose the type of msa pairing you want to use (Note: 'none' will do no msa_pairing, 'all' will do every possible species pairing as was done in AF-Multimer):\n",
+        "msa_pairing = 'none' #@param ['none', 'all', 'custom', 'cyclic', 'linear']\n",
+        "\n",
+        "FLAGS['target_lst_file'] = target_lst_file\n",
+        "FLAGS['msa_pairing'] = msa_pairing\n",
+        "\n",
+        "pred_params = make_mod_params()\n",
+        "\n",
+        "# with io.capture_output() as captured:\n",
+        "%shell python -u ../run_af2c_mod.py {pred_params}\n",
+        "print(f'DONE! (predictions available on {FLAGS.output_dir}' )"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": null,
+      "metadata": {
+        "cellView": "form",
+        "id": "sd1bBQgTJLEA"
+      },
+      "outputs": [],
+      "source": [
+        "# %matplotlib inline\n",
+        "\n",
+        "#@markdown #2. Visualize your results below by pressing the *Play* button on the left\n",
+        "#@markdown Choose one of the AF2Complex examples to visualize below! \n",
+        "FLAGS['feature_dir'] = '/content/af2complex/example/af2c_fea' \n",
+        "FLAGS['output_dir'] = '/content/af2complex/example/af2c_mod'\n",
+        "\n",
+        "example = 'H1065' #@param ['H1065', 'H1072', 'H1072_H1065', 'H1060v4']\n",
+        "\n",
+        "target_lst_file = {\n",
+        "    'H1065': '/content/af2complex/example/targets/example1.lst',\n",
+        "    'H1072': '/content/af2complex/example/targets/example2.lst',\n",
+        "    'H1072_H1065': '/content/af2complex/example/targets/example3.lst',\n",
+        "    'H1060v4': '/content/af2complex/example/targets/example4.lst',\n",
+        "}[example]\n",
+        "\n",
+        "FLAGS['target_lst_file'] = target_lst_file\n",
+        "FLAGS['msa_pairing'] = msa_pairing\n",
+        "\n",
+        "pred_params = make_mod_params()\n",
+        "\n",
+        "show_sidechains = False #@param {type: 'boolean'}\n",
+        "show_mainchains = False #@param {type: 'boolean'}\n",
+        "\n",
+        "\n",
+        "AF2C_examples = '/content/af2complex/example'\n",
+        "AF2C_egtargets =  os.path.join(AF2C_examples, 'targets')\n",
+        "\n",
+        "visualize(show_sidechains, show_mainchains)"
+      ]
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "id": "1s0RID2AYdpk"
+      },
+      "source": [
+        "# Target Run (within the *E. coli* by using pre-generated features for the proteome!)"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": null,
+      "metadata": {
+        "cellView": "form",
+        "id": "ClAAecbPYrr0"
+      },
+      "outputs": [],
+      "source": [
+        "import os\n",
+        "#@title 1. Download the dataset from [Zenodo](https://zenodo.org/record/7008599#.YwFWR3bMJaQ)\n",
+        "#@markdown Note: Usually takes less than 20 minutes.\n",
+        "\n",
+        "AF2C_examples = '/content/af2complex/example'\n",
+        "os.chdir(AF2C_examples)\n",
+        "AF2C_ecoli = os.path.join(AF2C_examples, 'ecoli')\n",
+        "zenodo_link = 'https://zenodo.org/record/7008599/files/af2c_fea_ecoli_220331_msa10ktem10.tar?download=1'\n",
+        "\n",
+        "AF2C_ecoli_path = os.path.join(AF2C_ecoli, os.path.basename(zenodo_link))\n",
+        "\n",
+        "if not os.path.exists(AF2C_ecoli):\n",
+        "    os.mkdir(AF2C_ecoli)\n",
+        "\n",
+        "%shell wget -O {AF2C_ecoli_path} {zenodo_link}\n",
+        "    \n",
+        "with io.capture_output() as captured:\n",
+        "    %shell tar --extract --verbose --file={AF2C_ecoli_path} \\\n",
+        "        --directory={AF2C_ecoli} --preserve-permissions\n",
+        " \n",
+        "AF2C_ecoli_feas = AF2C_ecoli_path.split('.tar')[0]\n",
+        "txt_zenodo_link = \"https://zenodo.org/record/7008599/files/ecoli_af2c_fea.txt?download=1\"\n",
+        "AF2C_ecoli_txt_path = os.path.join(AF2C_ecoli, os.path.basename(txt_zenodo_link))\n",
+        "\n",
+        "%shell wget -O {AF2C_ecoli_txt_path} {txt_zenodo_link}"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": null,
+      "metadata": {
+        "cellView": "form",
+        "id": "IQuqyTcMMaS7"
+      },
+      "outputs": [],
+      "source": [
+        "import pandas as pd\n",
+        "\n",
+        "# @markdown # Find the genes in *E. coli* with pre-generated input features:\n",
+        "\n",
+        "ecoli_df = get_dataset_desc(AF2C_ecoli_txt_path)\n",
+        "ecoli_df"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": null,
+      "metadata": {
+        "cellView": "form",
+        "id": "yNYa7QFWQPZF"
+      },
+      "outputs": [],
+      "source": [
+        "import subprocess\n",
+        "import numpy as np\n",
+        "# from run_af2c_mod import FLAGS\n",
+        "os.chdir(AF_LIB_DIR)\n",
+        "# print(FLAGS.fea_dir, AF2C_ecoli_feas)\n",
+        "FLAGS.feature_dir =  AF2C_ecoli_feas\n",
+        "# print(FLAGS.fea_dir, AF2C_ecoli_feas)\n",
+        "FLAGS.output_dir = '/content/af2complex/example/ecoli/af2c_mod'\n",
+        "if not os.path.exists(FLAGS.output_dir):\n",
+        "    os.mkdir(FLAGS.output_dir)\n",
+        "\n",
+        "#@markdown #2. Define your protein complex target using the UniProt IDs you found in the table above, then run AF2Complex using the *Play* button on the left\n",
+        "#@markdown Define how the chains compose the target,\n",
+        "#@markdown e.g.: \n",
+        "# #@markdown - T1065s1/T1065s2  *(Explanation on [example1](https://github.gatech.edu/gmu3/af2complex/tree/master/example#example-1) for more)*\n",
+        "# #@markdown - T1072s1:2/T1072s2:2 *(Explanation on [example2](https://github.gatech.edu/gmu3/af2complex/tree/master/example#example-2) for more)\n",
+        "# #@markdown - T1065s1/T1065s2+T1072s1:2/T1072s2:2 *(Explanation on [example3](https://github.gatech.edu/gmu3/af2complex/tree/master/example#example-3) for more)*\n",
+        "# #@markdown - T1060s3:12 *(Explanation on [example4](https://github.gatech.edu/gmu3/af2complex/tree/master/example#example-4) for more)*\n",
+        "\n",
+        "#@markdown - SECE/SECG/SECY  *(SecYEG translocon, a hetero-trimer composed of SecE, SecG, and SecY, 680 AAs)*\n",
+        "#@markdown - PPID|265-359/DSBA|20-208  *(PpiD parvulin domain and DsbA, each has a residue ID range, 285 AAs)*\n",
+        "#@markdown - SURA|21-428/BAMA|21-420 *(surA and BamA, both have signal peptide removed, 808 AAs)*\n",
+        "#@markdown - PPID/YFGM  *(chaperon proteins PpiD and YfgM, 829 AAs)*\n",
+        "#@markdown - CCMA:2/CCMB:2/CCMC/CCMD/CCME *(CcmI system, 1327 AAs)*\n",
+        "#@markdown - YAJC:3  *(YajC, a membrane protein chaperon? 330 AAs)*\n",
+        "\n",
+        "#@markdown Note that a large target may require resources beyond the free-tier.\n",
+        "\n",
+        "# chains = 'e.g. T1065s1/T1065s2' #@param {type:'string'}\n",
+        "chains = 'SECE/SECG/SECY' #@param {type:'string'}\n",
+        "\n",
+        "#@markdown Name your target\n",
+        "# target = 'e.g. H1065' #@param {type: 'string'}\n",
+        "target = 'SecYEG' #@param {type: 'string'}\n",
+        "\n",
+        "#@markdown Put down the total number of AA of the target (does not need to be exact as this number will be parsed but not used in the code)\n",
+        "num_AA = 680 #@param{type: 'integer'}\n",
+        "\n",
+        "#@markdown Choose the type of msa pairing you want to use (Note: 'none' will do no msa_pairing, 'all' will do species pairing as in AF-Multimer):\n",
+        "msa_pairing = 'all' #@param ['none', 'all', 'custom', 'cyclic', 'linear']\n",
+        "FLAGS.msa_pairing = msa_pairing\n",
+        "\n",
+        "target_lst = f'{chains} {num_AA} {target}'\n",
+        "\n",
+        "target_lst_file = os.path.join(AF2C_ecoli, f'{target}.lst')\n",
+        "with open(target_lst_file, 'w') as f:\n",
+        "    f.write(target_lst)\n",
+        "    f.close()\n",
+        "FLAGS.target_lst_file = target_lst_file\n",
+        "\n",
+        "pred_params = make_mod_params()\n",
+        "\n",
+        "# with io.capture_output() as captured:\n",
+        "%shell python -u ../run_af2c_mod.py {pred_params}\n",
+        "print(f'DONE! (predictions available on {FLAGS.output_dir}' )"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": null,
+      "metadata": {
+        "cellView": "form",
+        "id": "MewHtxVmBxNS"
+      },
+      "outputs": [],
+      "source": [
+        "#@markdown # Press **Play** button to the left to see which targets you have predictions for so far\n",
+        "from ipywidgets import interact\n",
+        "pd.DataFrame({\n",
+        "    'Target Name': os.listdir(FLAGS.output_dir),\n",
+        "    'Number of Predictions': [\n",
+        "        len(list(filter(lambda x: '.pdb' in x, os.listdir(os.path.join(FLAGS.output_dir,f )))))\n",
+        "        for f in os.listdir(FLAGS.output_dir)],\n",
+        "})"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": null,
+      "metadata": {
+        "cellView": "form",
+        "id": "9G2i6I84ZJql"
+      },
+      "outputs": [],
+      "source": [
+        "#@markdown #3. Visualize your results below by pressing the *Play* button on the left\n",
+        "\n",
+        "#@markdown Check out the cell above to see which targets have predictions. Imput the target name below to visualize the proteins. \n",
+        "target_name = 'SecYEG' #@param {type: 'string'}\n",
+        "FLAGS.target_lst_file = os.path.join(AF2C_ecoli, f'{target_name}.lst')\n",
+        "if not os.path.exists(FLAGS.target_lst_file):\n",
+        "    raise Exception(f' Target: predictions for {target_name} do not exist, run the cell above to see which targets have predictions')\n",
+        "\n",
+        "show_sidechains = False #@param {type: 'boolean'}\n",
+        "show_mainchains = False #@param {type: 'boolean'}\n",
+        "\n",
+        "visualize(show_sidechains, show_mainchains)"
+      ]
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "id": "tct8Uw1TW8RQ"
+      },
+      "source": [
+        "#Target Run (Upload your own features)\n",
+        "\n",
+        "First create a folder with your features with the following file structure: \n",
+        "```\n",
+        "dataset_name\n",
+        "│   \n",
+        "└───chain_1\n",
+        "│   │   \n",
+        "│   └───features.pkl\n",
+        "└───chain_2\n",
+        "│   │   \n",
+        "│   └───features.pkl\n",
+        "...\n",
+        "```\n",
+        "Then, upload a .tar (or .tgz) file of this folder below (Section 1a). You can create a .tar file with the following unix terminal command (Please keep the folder name and the .tar file name the same): \n",
+        "\n",
+        "\n",
+        "```\n",
+        "tar -czf af2c_fea.tgz af2c_fea\n",
+        "```\n",
+        "\n",
+        "\n",
+        "Optionally, you can also upload a txt file describing the dataset to easily search through the dataset. It should look like the following (all tab separated):\n",
+        "```\n",
+        "### ID -- UniProt ID\n",
+        "### Gene -- Recommended gene name\n",
+        "### AC -- Accession ID\n",
+        "### Fulllen -- Full sequence length\n",
+        "### Range -- Residue range of the longest mature chain\n",
+        "### Len -- Seuence length\n",
+        "### Description -- description of the gene\n",
+        "ID\tGene\tAC\tFullLen\tRange\tLen\tDescription\n",
+        "3MG1\ttag\tP05100\t187\t1-187\t187\tDNA-3-methyladenine glycosylase 1\n",
+        "...\n",
+        "```"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": null,
+      "metadata": {
+        "cellView": "form",
+        "id": "nnj1Vwg9KW3A"
+      },
+      "outputs": [],
+      "source": [
+        "#@markdown #1a. Upload your dataset here (press *Play* button to the left)\n",
+        "#@markdown Please upload only one dataset at a time\n",
+        "\n",
+        "from google.colab import files\n",
+        "os.chdir(UPLOAD_DIR)\n",
+        "print('Upload the .tar (or .tgz) file')\n",
+        "\n",
+        "uploaded = files.upload()\n",
+        "dset_file = list(uploaded.keys())[0]\n",
+        "dset_name = dset_file.split('.')[0]\n",
+        "dset_dir = os.path.join(UPLOAD_DIR, dset_name)\n",
+        "print(f\"INFO: Uploaded dataset with name: {dset_name}\")\n",
+        "    \n",
+        "with io.capture_output() as captured:\n",
+        "    %shell tar --extract --verbose --file=/content/uploaded_feats/{dset_file} \\\n",
+        "        --directory={UPLOAD_DIR} --preserve-permissions\n",
+        " "
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": null,
+      "metadata": {
+        "cellView": "form",
+        "id": "XgDHn7W2vEWX"
+      },
+      "outputs": [],
+      "source": [
+        "#@markdown #1b. Check out the dataset (Optional)\n",
+        "#@markdown Upload the dataset description file\n",
+        "\n",
+        "os.chdir(UPLOAD_DIR)\n",
+        "print('Upload the .txt description file')\n",
+        "uploaded = files.upload()\n",
+        "desc_file = list(uploaded.keys())[0]\n",
+        "dset_df = get_dataset_desc(os.path.join(UPLOAD_DIR, desc_file))\n",
+        "dset_df"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": null,
+      "metadata": {
+        "cellView": "form",
+        "id": "76o6KEjsy6R7"
+      },
+      "outputs": [],
+      "source": [
+        "import numpy as np\n",
+        "# from run_af2c_mod import FLAGS\n",
+        "os.chdir(AF_LIB_DIR)\n",
+        "# print(FLAGS.fea_dir, AF2C_ecoli_feas)\n",
+        "FLAGS.feature_dir = os.path.join(UPLOAD_DIR, dset_name)\n",
+        "# print(FLAGS.fea_dir, AF2C_ecoli_feas)\n",
+        "FLAGS.output_dir = os.path.join(UPLOAD_DIR, dset_name, 'af2c_mod')\n",
+        "if not os.path.exists(FLAGS.output_dir):\n",
+        "    os.mkdir(FLAGS.output_dir)\n",
+        "\n",
+        "#@markdown #2. Define your protein complex target and Run AF2Complex on it using the *Play* button on the left\n",
+        "#@markdown Define how the chains compose the target, look at sections above for more information (Section 2 of *E. coli* target run)\n",
+        "\n",
+        "#@markdown Note that a large target may require resources beyond the free-tier.\n",
+        "\n",
+        "# chains = 'e.g. T1065s1/T1065s2' #@param {type:'string'}\n",
+        "chains = 'HgcA/HgcB' #@param {type:'string'}\n",
+        "\n",
+        "#@markdown Name your target\n",
+        "# target = 'e.g. H1065' #@param {type: 'string'}\n",
+        "target = 'HgcAB' #@param {type: 'string'}\n",
+        "\n",
+        "#@markdown Put down the total number of AA of the target (does not need to be exact as this number will be parsed but not used in the code)\n",
+        "num_AA = 433 #@param{type: 'integer'}\n",
+        "\n",
+        "#@markdown Choose the type of msa pairing you want to use (Note: 'none' will do no msa_pairing, 'all' will do species pairing as in AF-Multimer):\n",
+        "msa_pairing = 'all' #@param ['none', 'all', 'custom', 'cyclic', 'linear']\n",
+        "FLAGS.msa_pairing = msa_pairing\n",
+        "\n",
+        "target_lst = f'{chains} {num_AA} {target}'\n",
+        "\n",
+        "target_lst_file = os.path.join(dset_dir, f'{target}.lst')\n",
+        "with open(target_lst_file, 'w') as f:\n",
+        "    f.write(target_lst)\n",
+        "    f.close()\n",
+        "FLAGS.target_lst_file = target_lst_file\n",
+        "\n",
+        "pred_params = make_mod_params()\n",
+        "\n",
+        "# with io.capture_output() as captured:\n",
+        "%shell python -u ../run_af2c_mod.py {pred_params}\n",
+        "print(f'DONE! (predictions available on {FLAGS.output_dir}' )"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": null,
+      "metadata": {
+        "cellView": "form",
+        "id": "01DWp6Sjk6pP"
+      },
+      "outputs": [],
+      "source": [
+        "#@markdown # Press **Play** button to the left to see which targets you have predictions for so far\n",
+        "from ipywidgets import interact\n",
+        "pd.DataFrame({\n",
+        "    'Target Name': os.listdir(FLAGS.output_dir),\n",
+        "    'Number of Predictions': [\n",
+        "        len(list(filter(lambda x: '.pdb' in x, os.listdir(os.path.join(FLAGS.output_dir,f )))))\n",
+        "        for f in os.listdir(FLAGS.output_dir)],\n",
+        "})\n"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": null,
+      "metadata": {
+        "cellView": "form",
+        "id": "1rjkGMLYyVcU"
+      },
+      "outputs": [],
+      "source": [
+        "# %matplotlib inline\n",
+        "\n",
+        "#@markdown #3. Visualize your results below by pressing the *Play* button on the left\n",
+        "\n",
+        "#@markdown Place the target name you want to visualize below\n",
+        "# target = 'e.g. H1065' #@param {type: 'string'}\n",
+        "target_name = 'HgcAB' #@param {type: 'string'}\n",
+        "FLAGS.target_lst_file = os.path.join(dset_dir, f'{target_name}.lst')\n",
+        "if not os.path.exists(FLAGS.target_lst_file):\n",
+        "    raise Exception(f' Target: predictions for {target_name} do not exist, run the cell above to see which targets have predictions')\n",
+        "\n",
+        "show_sidechains = False #@param {type: 'boolean'}\n",
+        "show_mainchains = False #@param {type: 'boolean'}\n",
+        "\n",
+        "visualize(show_sidechains, show_mainchains)"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": null,
+      "metadata": {
+        "id": "7PN-BhHoTtCr"
+      },
+      "outputs": [],
+      "source": []
+    }
+  ],
+  "metadata": {
+    "accelerator": "GPU",
+    "colab": {
+      "private_outputs": true,
+      "provenance": [],
+      "toc_visible": true
+    },
+    "gpuClass": "standard",
+    "kernelspec": {
+      "display_name": "Python 3",
+      "language": "python",
+      "name": "python3"
+    },
+    "language_info": {
+      "name": "python",
+      "version": "3.10.6 (main, Nov 14 2022, 16:10:14) [GCC 11.3.0]"
+    },
+    "vscode": {
+      "interpreter": {
+        "hash": "e7370f93d1d0cde622a1f8e1c04877d8463912d04d973331ad4851f04de6915a"
+      }
+    }
+  },
+  "nbformat": 4,
+  "nbformat_minor": 0
+}
diff --git a/requirements.txt b/requirements.txt
index d0e90f3ca5a0b4389cfbdeea7d1fab989ce6447e..02ca4dce135d4f585d32d831ced550f3ec78afc2 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,13 +1,15 @@
-absl-py==0.13.0
+absl-py==1.0.0
 biopython==1.79
 chex==0.0.7
-dm-haiku==0.0.4
+dm-haiku==0.0.9
 dm-tree==0.1.6
+docker==5.0.0
 immutabledict==2.0.0
-jax==0.2.14
+jax==0.3.25
 ml-collections==0.1.0
-numpy==1.19.5
+numpy==1.21.6
 pandas==1.3.4
+protobuf==3.20.1
 scipy==1.7.0
-tensorflow==2.5.0
+tensorflow-cpu==2.9.0
 networkx==2.5
diff --git a/src/alphafold/common/confidence.py b/src/alphafold/common/confidence.py
index ffad0384cfe7fc77b21860a3b02bf3c65f5a2068..68dca877f994d0976464fc8fb1144cc71bff36ba 100644
--- a/src/alphafold/common/confidence.py
+++ b/src/alphafold/common/confidence.py
@@ -396,7 +396,7 @@ def interface_score(
     atom_mask: np.ndarray,
     asym_id: np.ndarray,
     residue_weights: Optional[np.ndarray] = None,
-    distance_threshold: Optional[int] = 4.5,
+    distance_threshold: Optional[float] = 4.5,
     is_probs: Optional[bool] = False,) -> Dict[str, Union[np.ndarray, int]]:
   """ Returns the interface-score, number of residues in the interface, and
     number of contacts of a complex model. This is a further tweak from piTM by
diff --git a/src/alphafold/common/residue_constants.py b/src/alphafold/common/residue_constants.py
index 4318875a9b29636ebbe26a5dd743547b856e21a3..049c9a6df3757d8feded6e52402298008879f687 100644
--- a/src/alphafold/common/residue_constants.py
+++ b/src/alphafold/common/residue_constants.py
@@ -120,7 +120,7 @@ chi_pi_periodic = [
 # 4,5,6,7: 'chi1,2,3,4-group'
 # The atom positions are relative to the axis-end-atom of the corresponding
 # rotation axis. The x-axis is in direction of the rotation axis, and the y-axis
-# is defined such that the dihedral-angle-definiting atom (the last entry in
+# is defined such that the dihedral-angle-defining atom (the last entry in
 # chi_angles_atoms above) is in the xy-plane (with a positive y-coordinate).
 # format: [atomname, group_idx, rel_position]
 rigid_group_atom_positions = {
@@ -772,10 +772,10 @@ def _make_rigid_transformation_4x4(ex, ey, translation):
 # and an array with (restype, atomtype, coord) for the atom positions
 # and compute affine transformation matrices (4,4) from one rigid group to the
 # previous group
-restype_atom37_to_rigid_group = np.zeros([21, 37], dtype=np.int)
+restype_atom37_to_rigid_group = np.zeros([21, 37], dtype=int)
 restype_atom37_mask = np.zeros([21, 37], dtype=np.float32)
 restype_atom37_rigid_group_positions = np.zeros([21, 37, 3], dtype=np.float32)
-restype_atom14_to_rigid_group = np.zeros([21, 14], dtype=np.int)
+restype_atom14_to_rigid_group = np.zeros([21, 14], dtype=int)
 restype_atom14_mask = np.zeros([21, 14], dtype=np.float32)
 restype_atom14_rigid_group_positions = np.zeros([21, 14, 3], dtype=np.float32)
 restype_rigid_group_default_frame = np.zeros([21, 8, 4, 4], dtype=np.float32)
diff --git a/src/alphafold/data/pipeline.py b/src/alphafold/data/pipeline.py
index cb954bdfdfad4378e310bea22b694f12a0233729..bca4306b7ce1d865438c5f49f49beb130610a4ba 100644
--- a/src/alphafold/data/pipeline.py
+++ b/src/alphafold/data/pipeline.py
@@ -233,14 +233,14 @@ class DataPipeline:
           use_precomputed_msas=self.use_precomputed_msas)
       bfd_msa = parsers.parse_stockholm(jackhmmer_small_bfd_result['sto'])
     else:
-      bfd_out_path = os.path.join(msa_output_dir, 'bfd_uniclust_hits.a3m')
-      hhblits_bfd_uniclust_result = run_msa_tool(
-          msa_runner=self.hhblits_bfd_uniclust_runner,
+      bfd_out_path = os.path.join(msa_output_dir, 'bfd_uniref_hits.a3m')
+      hhblits_bfd_uniref_result = run_msa_tool(
+          msa_runner=self.hhblits_bfd_uniref_runner,
           input_fasta_path=input_fasta_path,
           msa_out_path=bfd_out_path,
           msa_format='a3m',
           use_precomputed_msas=self.use_precomputed_msas)
-      bfd_msa = parsers.parse_a3m(hhblits_bfd_uniclust_result['a3m'])
+      bfd_msa = parsers.parse_a3m(hhblits_bfd_uniref_result['a3m'])      
 
     templates_result = self.template_featurizer.get_templates(
         query_sequence=input_sequence,
diff --git a/src/alphafold/data/pipeline_uniprot.py b/src/alphafold/data/pipeline_uniprot.py
new file mode 100644
index 0000000000000000000000000000000000000000..aa2719adf7e32bafd8da435dcc9df3818cad1d17
--- /dev/null
+++ b/src/alphafold/data/pipeline_uniprot.py
@@ -0,0 +1,199 @@
+# Modified by Mu Gao to consider only UniProt as the sequence database
+#
+# Copyright 2021 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Functions for building the input features for the AlphaFold model."""
+
+import os
+from typing import Any, Mapping, MutableMapping, Optional, Sequence, Union
+from absl import logging
+from alphafold.common import residue_constants
+from alphafold.data import msa_identifiers
+from alphafold.data import parsers
+from alphafold.data import templates
+from alphafold.data.tools import hhblits
+from alphafold.data.tools import hhsearch
+from alphafold.data.tools import hmmsearch
+from alphafold.data.tools import jackhmmer
+import numpy as np
+
+# Internal import (7716).
+
+FeatureDict = MutableMapping[str, np.ndarray]
+TemplateSearcher = Union[hhsearch.HHSearch, hmmsearch.Hmmsearch]
+
+
+def make_sequence_features(
+    sequence: str, description: str, num_res: int) -> FeatureDict:
+  """Constructs a feature dict of sequence features."""
+  features = {}
+  features['aatype'] = residue_constants.sequence_to_onehot(
+      sequence=sequence,
+      mapping=residue_constants.restype_order_with_x,
+      map_unknown_to_x=True)
+  features['between_segment_residues'] = np.zeros((num_res,), dtype=np.int32)
+  features['domain_name'] = np.array([description.encode('utf-8')],
+                                     dtype=np.object_)
+  features['residue_index'] = np.array(range(num_res), dtype=np.int32)
+  features['seq_length'] = np.array([num_res] * num_res, dtype=np.int32)
+  features['sequence'] = np.array([sequence.encode('utf-8')], dtype=np.object_)
+  return features
+
+
+def make_msa_features(msa: Sequence[parsers.Msa]) -> FeatureDict:
+  """Constructs a feature dict of MSA features."""
+  if not msa:
+    raise ValueError('At least one MSA must be provided.')
+
+  int_msa = []
+  deletion_matrix = []
+  species_ids = []
+  seen_sequences = set()
+
+  if not msa:
+    raise ValueError(f'MSA {msa_index} must contain at least one sequence.')
+  for sequence_index, sequence in enumerate(msa.sequences):
+    if sequence in seen_sequences:
+      continue
+    seen_sequences.add(sequence)
+    int_msa.append(
+      [residue_constants.HHBLITS_AA_TO_ID[res] for res in sequence])
+    deletion_matrix.append(msa.deletion_matrix[sequence_index])
+    identifiers = msa_identifiers.get_identifiers(
+      msa.descriptions[sequence_index])
+    species_ids.append(identifiers.species_id.encode('utf-8'))
+
+  num_res = len(msa.sequences[0])
+  num_alignments = len(int_msa)
+  features = {}
+  features['deletion_matrix_int'] = np.array(deletion_matrix, dtype=np.int32)
+  features['msa'] = np.array(int_msa, dtype=np.int32)
+  features['num_alignments'] = np.array(
+      [num_alignments] * num_res, dtype=np.int32)
+  features['msa_species_identifiers'] = np.array(species_ids, dtype=np.object_)
+  return features
+
+
+def run_msa_tool(msa_runner, input_fasta_path: str, msa_out_path: str,
+                 msa_format: str, use_precomputed_msas: bool,
+                 max_sto_sequences: Optional[int] = None
+                 ) -> Mapping[str, Any]:
+  """Runs an MSA tool, checking if output already exists first."""
+  if not use_precomputed_msas or not os.path.exists(msa_out_path):
+    if msa_format == 'sto' and max_sto_sequences is not None:
+      result = msa_runner.query(input_fasta_path, max_sto_sequences)[0]  # pytype: disable=wrong-arg-count
+    else:
+      result = msa_runner.query(input_fasta_path)[0]
+    with open(msa_out_path, 'w') as f:
+      f.write(result[msa_format])
+  else:
+    logging.warning('Reading MSA from file %s', msa_out_path)
+    if msa_format == 'sto' and max_sto_sequences is not None:
+      precomputed_msa = parsers.truncate_stockholm_msa(
+          msa_out_path, max_sto_sequences)
+      result = {'sto': precomputed_msa}
+    else:
+      with open(msa_out_path, 'r') as f:
+        result = {msa_format: f.read()}
+  return result
+
+
+class DataPipeline:
+  """Runs the alignment tools and assembles the input features."""
+
+  def __init__(self,
+               jackhmmer_binary_path: str,
+               hhblits_binary_path: str,
+               uniprot_database_path: str,
+               template_searcher: TemplateSearcher,
+               template_featurizer: templates.TemplateHitFeaturizer,
+               uniprot_max_hits: int = 30000,
+               use_precomputed_msas: bool = False,
+               add_species: bool = False):
+    """Initializes the data pipeline."""
+
+    self.template_searcher = template_searcher
+    self.template_featurizer = template_featurizer
+    self.use_precomputed_msas = use_precomputed_msas
+    self.add_species = add_species
+    if add_species:
+        self.uniprot_msa_runner = jackhmmer.Jackhmmer(
+            binary_path=jackhmmer_binary_path,
+            database_path=uniprot_database_path)
+        self.uniprot_max_hits = uniprot_max_hits
+
+  def process(self, input_fasta_path: str, msa_output_dir: str) -> FeatureDict:
+    """Runs alignment tools on the input sequence and creates features."""
+    with open(input_fasta_path) as f:
+      input_fasta_str = f.read()
+    input_seqs, input_descs = parsers.parse_fasta(input_fasta_str)
+    if len(input_seqs) != 1:
+      raise ValueError(
+          f'More than one input sequence found in {input_fasta_path}.')
+    input_sequence = input_seqs[0]
+    input_description = input_descs[0]
+    num_res = len(input_sequence)
+
+    uniprot_out_path = os.path.join(msa_output_dir, 'uniprot_hits.sto')
+    jackhmmer_uniprot_result = run_msa_tool(
+        msa_runner=self.uniprot_msa_runner,
+        input_fasta_path=input_fasta_path,
+        msa_out_path=uniprot_out_path,
+        msa_format='sto',
+        use_precomputed_msas=self.use_precomputed_msas,
+        max_sto_sequences=self.uniprot_max_hits)
+
+    msa_for_templates = jackhmmer_uniprot_result['sto']
+    msa_for_templates = parsers.deduplicate_stockholm_msa(msa_for_templates)
+    msa_for_templates = parsers.remove_empty_columns_from_stockholm_msa(
+        msa_for_templates)
+
+    uniprot_msa = parsers.parse_stockholm(jackhmmer_uniprot_result['sto'])
+
+    if self.template_searcher.input_format == 'sto':
+      pdb_templates_result = self.template_searcher.query(msa_for_templates)
+    elif self.template_searcher.input_format == 'a3m':
+      uniref90_msa_as_a3m = parsers.convert_stockholm_to_a3m(msa_for_templates)
+      pdb_templates_result = self.template_searcher.query(uniref90_msa_as_a3m)
+    else:
+      raise ValueError('Unrecognized template input format: '
+                       f'{self.template_searcher.input_format}')
+
+    pdb_hits_out_path = os.path.join(
+        msa_output_dir, f'pdb_hits.{self.template_searcher.output_format}')
+    with open(pdb_hits_out_path, 'w') as f:
+      f.write(pdb_templates_result)
+
+    pdb_template_hits = self.template_searcher.get_template_hits(
+        output_string=pdb_templates_result, input_sequence=input_sequence)
+
+    templates_result = self.template_featurizer.get_templates(
+        query_sequence=input_sequence,
+        hits=pdb_template_hits)
+
+    sequence_features = make_sequence_features(
+        sequence=input_sequence,
+        description=input_description,
+        num_res=num_res)
+
+    msa_features = make_msa_features(uniprot_msa)
+    #logging.info('UniProt MSA size: %d sequences.', len(uniprot_msa))
+    logging.info('Final (deduplicated) MSA size: %d sequences.',
+                 msa_features['num_alignments'][0])
+    logging.info('Total number of templates (NB: this can include bad '
+                 'templates and is later filtered to top 4): %d.',
+                 templates_result.features['template_domain_names'].shape[0])
+
+    return {**sequence_features, **msa_features, **templates_result.features}
diff --git a/src/alphafold/data/templates.py b/src/alphafold/data/templates.py
index d3759871190e781f3146109361a139611346323c..91681ac3fa073d41a9bdae40dc4f61cf6a9848a8 100644
--- a/src/alphafold/data/templates.py
+++ b/src/alphafold/data/templates.py
@@ -89,8 +89,8 @@ TEMPLATE_FEATURES = {
     'template_aatype': np.float32,
     'template_all_atom_masks': np.float32,
     'template_all_atom_positions': np.float32,
-    'template_domain_names': np.object,
-    'template_sequence': np.object,
+    'template_domain_names': object,
+    'template_sequence': object,
     'template_sum_probs': np.float32,
 }
 
@@ -1002,8 +1002,8 @@ class HmmsearchHitFeaturizer(TemplateHitFeaturizer):
               (1, num_res, residue_constants.atom_type_num), np.float32),
           'template_all_atom_positions': np.zeros(
               (1, num_res, residue_constants.atom_type_num, 3), np.float32),
-          'template_domain_names': np.array([''.encode()], dtype=np.object),
-          'template_sequence': np.array([''.encode()], dtype=np.object),
+          'template_domain_names': np.array([''.encode()], dtype=object),
+          'template_sequence': np.array([''.encode()], dtype=object),
           'template_sum_probs': np.array([0], dtype=np.float32)
       }
     return TemplateSearchResult(
diff --git a/src/alphafold/data/tools/jackhmmer.py b/src/alphafold/data/tools/jackhmmer.py
index 60e0e222c91457c8c1b1dbe0a5f4cac358cd1e69..68997f857f2c4fd3a69a59205310828a4cf08fd2 100644
--- a/src/alphafold/data/tools/jackhmmer.py
+++ b/src/alphafold/data/tools/jackhmmer.py
@@ -167,10 +167,20 @@ class Jackhmmer:
             input_fasta_path: str,
             max_sequences: Optional[int] = None) -> Sequence[Mapping[str, Any]]:
     """Queries the database using Jackhmmer."""
+    return self.query_multiple([input_fasta_path], max_sequences)[0]
+
+  def query_multiple(
+      self,
+      input_fasta_paths: Sequence[str],
+      max_sequences: Optional[int] = None,
+    ) -> Sequence[Sequence[Mapping[str, Any]]]:
+    """Queries the database for multiple queries using Jackhmmer."""
     if self.num_streamed_chunks is None:
-      single_chunk_result = self._query_chunk(
-          input_fasta_path, self.database_path, max_sequences)
-      return [single_chunk_result]
+      single_chunk_results = []
+      for input_fasta_path in input_fasta_paths:
+        single_chunk_results.append([self._query_chunk(
+            input_fasta_path, self.database_path, max_sequences)])
+      return single_chunk_results
 
     db_basename = os.path.basename(self.database_path)
     db_remote_chunk = lambda db_idx: f'{self.database_path}.{db_idx}'
@@ -185,7 +195,7 @@ class Jackhmmer:
 
     # Download the (i+1)-th chunk while Jackhmmer is running on the i-th chunk
     with futures.ThreadPoolExecutor(max_workers=2) as executor:
-      chunked_output = []
+      chunked_outputs = [[] for _ in range(len(input_fasta_paths))]
       for i in range(1, self.num_streamed_chunks + 1):
         # Copy the chunk locally
         if i == 1:
@@ -197,9 +207,9 @@ class Jackhmmer:
 
         # Run Jackhmmer with the chunk
         future.result()
-        chunked_output.append(self._query_chunk(
-            input_fasta_path, db_local_chunk(i), max_sequences))
-
+        for fasta_index, input_fasta_path in enumerate(input_fasta_paths):
+          chunked_outputs[fasta_index].append(self._query_chunk(
+              input_fasta_path, db_local_chunk(i), max_sequences))
         # Remove the local copy of the chunk
         os.remove(db_local_chunk(i))
         # Do not set next_future for the last chunk so that this works even for
@@ -208,4 +218,4 @@ class Jackhmmer:
           future = next_future
         if self.streaming_callback:
           self.streaming_callback(i)
-    return chunked_output
+    return chunked_outputs
diff --git a/src/alphafold/model/all_atom_multimer.py b/src/alphafold/model/all_atom_multimer.py
index 361652050effad98660c134ef96a247df13a5d69..2cc49c4d34012b31ab7814e5731c50fa9674e936 100644
--- a/src/alphafold/model/all_atom_multimer.py
+++ b/src/alphafold/model/all_atom_multimer.py
@@ -426,7 +426,7 @@ def torsion_angles_to_frames(
   chi3_frame_to_backb = chi2_frame_to_backb @ all_frames[:, 6]
   chi4_frame_to_backb = chi3_frame_to_backb @ all_frames[:, 7]
 
-  all_frames_to_backb = jax.tree_multimap(
+  all_frames_to_backb = jax.tree_map(
       lambda *x: jnp.concatenate(x, axis=-1), all_frames[:, 0:5],
       chi2_frame_to_backb[:, None], chi3_frame_to_backb[:, None],
       chi4_frame_to_backb[:, None])
diff --git a/src/alphafold/model/common_modules.py b/src/alphafold/model/common_modules.py
index 08776a7f00af3dab4c25289954166bc32bec3539..0b5cd07d5b6d643fceb865591c80001f67387ef1 100644
--- a/src/alphafold/model/common_modules.py
+++ b/src/alphafold/model/common_modules.py
@@ -128,3 +128,64 @@ class Linear(hk.Module):
 
     return output
 
+
+class LayerNorm(hk.LayerNorm):
+  """LayerNorm module.
+
+  Equivalent to hk.LayerNorm but with different parameter shapes: they are
+  always vectors rather than possibly higher-rank tensors. This makes it easier
+  to change the layout whilst keep the model weight-compatible.
+  """
+
+  def __init__(self,
+               axis,
+               create_scale: bool,
+               create_offset: bool,
+               eps: float = 1e-5,
+               scale_init=None,
+               offset_init=None,
+               use_fast_variance: bool = False,
+               name=None,
+               param_axis=None):
+    super().__init__(
+        axis=axis,
+        create_scale=False,
+        create_offset=False,
+        eps=eps,
+        scale_init=None,
+        offset_init=None,
+        use_fast_variance=use_fast_variance,
+        name=name,
+        param_axis=param_axis)
+    self._temp_create_scale = create_scale
+    self._temp_create_offset = create_offset
+
+  def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
+    is_bf16 = (x.dtype == jnp.bfloat16)
+    if is_bf16:
+      x = x.astype(jnp.float32)
+
+    param_axis = self.param_axis[0] if self.param_axis else -1
+    param_shape = (x.shape[param_axis],)
+
+    param_broadcast_shape = [1] * x.ndim
+    param_broadcast_shape[param_axis] = x.shape[param_axis]
+    scale = None
+    offset = None
+    if self._temp_create_scale:
+      scale = hk.get_parameter(
+          'scale', param_shape, x.dtype, init=self.scale_init)
+      scale = scale.reshape(param_broadcast_shape)
+
+    if self._temp_create_offset:
+      offset = hk.get_parameter(
+          'offset', param_shape, x.dtype, init=self.offset_init)
+      offset = offset.reshape(param_broadcast_shape)
+
+    out = super().__call__(x, scale=scale, offset=offset)
+
+    if is_bf16:
+      out = out.astype(jnp.bfloat16)
+
+    return out
+  
\ No newline at end of file
diff --git a/src/alphafold/model/config.py b/src/alphafold/model/config.py
index f8e56c8f4c450dd1a3400b9cfbe9fedf8939c4c8..bad78901e44a08a41fdf3c03bae7a296758ff56b 100644
--- a/src/alphafold/model/config.py
+++ b/src/alphafold/model/config.py
@@ -11,6 +11,10 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+#
+# Modified to keep the definitions of previous multimer_v2 and multimer models
+# Mu Gao
+
 """Model config."""
 
 import copy
@@ -26,12 +30,12 @@ NUM_TEMPLATES = shape_placeholders.NUM_TEMPLATES
 def model_config(name: str) -> ml_collections.ConfigDict:
   """Get the ConfigDict of a CASP14 model."""
 
-  if 'multimer' in name:
-    return CONFIG_MULTIMER
-
   if name not in CONFIG_DIFFS:
     raise ValueError(f'Invalid model name {name}.')
-  cfg = copy.deepcopy(CONFIG)
+  if 'multimer' in name:
+    cfg = copy.deepcopy(CONFIG_MULTIMER)
+  else:
+    cfg = copy.deepcopy(CONFIG)
   cfg.update_from_flattened_dict(CONFIG_DIFFS[name])
   return cfg
 
@@ -65,6 +69,13 @@ MODEL_PRESETS = {
         'model_4_multimer_v2',
         'model_5_multimer_v2',
     ),
+    'multimer_v3': (
+        'model_1_multimer_v3',
+        'model_2_multimer_v3',
+        'model_3_multimer_v3',
+        'model_4_multimer_v3',
+        'model_5_multimer_v3',
+    ),    
 }
 MODEL_PRESETS['monomer_casp14'] = MODEL_PRESETS['monomer']
 
@@ -125,8 +136,31 @@ CONFIG_DIFFS = {
     },
     'model_5_ptm': {
         'model.heads.predicted_aligned_error.weight': 0.1
-    }
+    },
+    'model_1_multimer_v3': {},
+    'model_2_multimer_v3': {},
+    'model_3_multimer_v3': {},
+    'model_4_multimer_v3': {
+        'model.embeddings_and_evoformer.num_extra_msa': 1152
+    },
+    'model_5_multimer_v3': {
+        'model.embeddings_and_evoformer.num_extra_msa': 1152
+    },
 }
+# Key differences between multimer v1/v2 and v3, mostly due to numerical
+# optimisations in the TriangleMultiplication module.
+common_updates = {
+    'model.embeddings_and_evoformer.num_msa': 252,
+    'model.embeddings_and_evoformer.num_extra_msa': 1152,
+    'model.embeddings_and_evoformer.evoformer.triangle_multiplication_incoming.fuse_projection_weights': False,
+    'model.embeddings_and_evoformer.evoformer.triangle_multiplication_outgoing.fuse_projection_weights': False,
+    'model.embeddings_and_evoformer.template.template_pair_stack.triangle_multiplication_incoming.fuse_projection_weights': False,
+    'model.embeddings_and_evoformer.template.template_pair_stack.triangle_multiplication_outgoing.fuse_projection_weights': False,
+}
+CONFIG_DIFFS.update(
+    {f'model_{i}_multimer': common_updates for i in range(1, 6)})
+CONFIG_DIFFS.update(
+    {f'model_{i}_multimer_v2': common_updates for i in range(1, 6)})
 
 CONFIG = ml_collections.ConfigDict({
     'data': {
@@ -267,14 +301,16 @@ CONFIG = ml_collections.ConfigDict({
                     'equation': 'ikc,jkc->ijc',
                     'num_intermediate_channel': 128,
                     'orientation': 'per_row',
-                    'shared_dropout': True
+                    'shared_dropout': True,
+                    'fuse_projection_weights': False,
                 },
                 'triangle_multiplication_incoming': {
                     'dropout_rate': 0.25,
                     'equation': 'kjc,kic->ijc',
                     'num_intermediate_channel': 128,
                     'orientation': 'per_row',
-                    'shared_dropout': True
+                    'shared_dropout': True,
+                    'fuse_projection_weights': False,
                 },
                 'pair_transition': {
                     'dropout_rate': 0.0,
@@ -335,14 +371,16 @@ CONFIG = ml_collections.ConfigDict({
                         'equation': 'ikc,jkc->ijc',
                         'num_intermediate_channel': 64,
                         'orientation': 'per_row',
-                        'shared_dropout': True
+                        'shared_dropout': True,
+                        'fuse_projection_weights': False,
                     },
                     'triangle_multiplication_incoming': {
                         'dropout_rate': 0.25,
                         'equation': 'kjc,kic->ijc',
                         'num_intermediate_channel': 64,
                         'orientation': 'per_row',
-                        'shared_dropout': True
+                        'shared_dropout': True,
+                        'fuse_projection_weights': False,
                     },
                     'pair_transition': {
                         'dropout_rate': 0.0,
@@ -361,7 +399,8 @@ CONFIG = ml_collections.ConfigDict({
             'multimer_mode': False,
             'subbatch_size': 4,
             'use_remat': False,
-            'zero_init': True
+            'zero_init': True,
+            'eval_dropout': False,
         },
         'heads': {
             'distogram': {
@@ -490,27 +529,29 @@ CONFIG_MULTIMER = ml_collections.ConfigDict({
                     'gating': True,
                     'num_head': 4,
                     'orientation': 'per_row',
-                    'shared_dropout': True
+                    'shared_dropout': True,
                 },
                 'triangle_multiplication_incoming': {
                     'dropout_rate': 0.25,
                     'equation': 'kjc,kic->ijc',
                     'num_intermediate_channel': 128,
                     'orientation': 'per_row',
-                    'shared_dropout': True
+                    'shared_dropout': True,
+                    'fuse_projection_weights': True,
                 },
                 'triangle_multiplication_outgoing': {
                     'dropout_rate': 0.25,
                     'equation': 'ikc,jkc->ijc',
                     'num_intermediate_channel': 128,
                     'orientation': 'per_row',
-                    'shared_dropout': True
+                    'shared_dropout': True,
+                    'fuse_projection_weights': True,
                 }
             },
             'extra_msa_channel': 64,
             'extra_msa_stack_num_block': 4,
-            'num_msa': 252,
-            'num_extra_msa': 1152,
+            'num_msa': 508,
+            'num_extra_msa': 2048,
             'masked_msa': {
                 'profile_prob': 0.1,
                 'replace_fraction': 0.15,
@@ -571,24 +612,29 @@ CONFIG_MULTIMER = ml_collections.ConfigDict({
                         'equation': 'kjc,kic->ijc',
                         'num_intermediate_channel': 64,
                         'orientation': 'per_row',
-                        'shared_dropout': True
+                        'shared_dropout': True,
+                        'fuse_projection_weights': True,
                     },
                     'triangle_multiplication_outgoing': {
                         'dropout_rate': 0.25,
                         'equation': 'ikc,jkc->ijc',
                         'num_intermediate_channel': 64,
                         'orientation': 'per_row',
-                        'shared_dropout': True
+                        'shared_dropout': True,
+                        'fuse_projection_weights': True,
                     }
                 }
             },
         },
         'global_config': {
+            'bfloat16': True,
+            'bfloat16_output': False,
             'deterministic': False,
             'multimer_mode': True,
             'subbatch_size': 4,
             'use_remat': False,
-            'zero_init': True
+            'zero_init': True,
+            'eval_dropout': False,
         },
         'heads': {
             'distogram': {
@@ -658,7 +704,13 @@ CONFIG_MULTIMER = ml_collections.ConfigDict({
             }
         },
         'num_ensemble_eval': 1,
-        'num_recycle': 3,
+        'num_recycle': 20,
+        # A negative value indicates that no early stopping will occur, i.e.
+        # the model will always run `num_recycle` number of recycling
+        # iterations.  A positive value will enable early stopping if the
+        # difference in pairwise distances is less than the tolerance between
+        # recycling steps.
+        'recycle_early_stop_tolerance': 0.5,
         'resample_msa_in_recycling': True
     }
 })
diff --git a/src/alphafold/model/folding.py b/src/alphafold/model/folding.py
index 1faf5bd58377880107da119b4b65c96a2f1e728d..e73266489190f7df63632be126443cf1c8a62422 100644
--- a/src/alphafold/model/folding.py
+++ b/src/alphafold/model/folding.py
@@ -331,7 +331,7 @@ class FoldIteration(hk.Module):
     safe_key, *sub_keys = safe_key.split(3)
     sub_keys = iter(sub_keys)
     act = safe_dropout_fn(act, next(sub_keys))
-    act = hk.LayerNorm(
+    act = common_modules.LayerNorm(
         axis=[-1],
         create_scale=True,
         create_offset=True,
@@ -353,7 +353,7 @@ class FoldIteration(hk.Module):
         act = jax.nn.relu(act)
     act += input_act
     act = safe_dropout_fn(act, next(sub_keys))
-    act = hk.LayerNorm(
+    act = common_modules.LayerNorm(
         axis=[-1],
         create_scale=True,
         create_offset=True,
@@ -410,7 +410,7 @@ def generate_affines(representations, batch, config, global_config,
   c = config
   sequence_mask = batch['seq_mask'][:, None]
 
-  act = hk.LayerNorm(
+  act = common_modules.LayerNorm(
       axis=[-1],
       create_scale=True,
       create_offset=True,
@@ -433,7 +433,7 @@ def generate_affines(representations, batch, config, global_config,
                  'affine': affine.to_tensor(),
                  }
 
-  act_2d = hk.LayerNorm(
+  act_2d = common_modules.LayerNorm(
       axis=[-1],
       create_scale=True,
       create_offset=True,
diff --git a/src/alphafold/model/folding_multimer.py b/src/alphafold/model/folding_multimer.py
index 90ce31b902e47564e0bc6e40ed71115ed2e64ad8..b565a0d41b4befa93c88769de1813d3e88a76246 100644
--- a/src/alphafold/model/folding_multimer.py
+++ b/src/alphafold/model/folding_multimer.py
@@ -427,7 +427,7 @@ class FoldIteration(hk.Module):
     safe_key, *sub_keys = safe_key.split(3)
     sub_keys = iter(sub_keys)
     act = safe_dropout_fn(act, next(sub_keys))
-    act = hk.LayerNorm(
+    act = common_modules.LayerNorm(
         axis=-1,
         create_scale=True,
         create_offset=True,
@@ -448,7 +448,7 @@ class FoldIteration(hk.Module):
         act = jax.nn.relu(act)
     act += input_act
     act = safe_dropout_fn(act, next(sub_keys))
-    act = hk.LayerNorm(
+    act = common_modules.LayerNorm(
         axis=-1,
         create_scale=True,
         create_offset=True,
@@ -500,7 +500,7 @@ def generate_monomer_rigids(representations: Mapping[str, jnp.ndarray],
   """
   c = config
   sequence_mask = batch['seq_mask'][:, None]
-  act = hk.LayerNorm(
+  act = common_modules.LayerNorm(
       axis=-1, create_scale=True, create_offset=True, name='single_layer_norm')(
           representations['single'])
 
@@ -523,7 +523,7 @@ def generate_monomer_rigids(representations: Mapping[str, jnp.ndarray],
           rigid
   }
 
-  act_2d = hk.LayerNorm(
+  act_2d = common_modules.LayerNorm(
       axis=-1,
       create_scale=True,
       create_offset=True,
@@ -546,7 +546,7 @@ def generate_monomer_rigids(representations: Mapping[str, jnp.ndarray],
         )
     outputs.append(output)
 
-  output = jax.tree_multimap(lambda *x: jnp.stack(x), *outputs)
+  output = jax.tree_map(lambda *x: jnp.stack(x), *outputs)
   # Pass along for LDDT-Head.
   output['act'] = activations['act']
 
@@ -789,7 +789,7 @@ def backbone_loss(gt_rigid: geometry.Rigid3Array,
   loss_fn = functools.partial(
       all_atom_multimer.frame_aligned_point_error,
       l1_clamp_distance=config.atom_clamp_distance,
-      loss_unit_distance=config.loss_unit_distance)
+      length_scale=config.loss_unit_distance)
 
   loss_fn = jax.vmap(loss_fn, (0, None, None, 0, None, None, None))
   fape = loss_fn(target_rigid, gt_rigid, gt_frames_mask,
@@ -823,7 +823,7 @@ def compute_frames(
   alt_gt_frames = frames_batch['rigidgroups_alt_gt_frames']
   use_alt = use_alt[:, None]
 
-  renamed_gt_frames = jax.tree_multimap(
+  renamed_gt_frames = jax.tree_map(
       lambda x, y: (1. - use_alt) * x + use_alt * y, gt_frames, alt_gt_frames)
 
   return renamed_gt_frames, frames_batch['rigidgroups_gt_exists']
@@ -1160,4 +1160,3 @@ class MultiRigidSidechain(hk.Module):
         'frames': all_frames_to_global,  # geometry.Rigid3Array (N, 8)
     })
     return outputs
-
diff --git a/src/alphafold/model/geometry/rigid_matrix_vector.py b/src/alphafold/model/geometry/rigid_matrix_vector.py
index 299f6401706e78c2b8872ec7db99235ab904578a..4f2c0f006b0a64f399abb28d4297395cd1b9cf1f 100644
--- a/src/alphafold/model/geometry/rigid_matrix_vector.py
+++ b/src/alphafold/model/geometry/rigid_matrix_vector.py
@@ -65,7 +65,7 @@ class Rigid3Array:
     """Return identity Rigid3Array of given shape."""
     return cls(
         rotation_matrix.Rot3Array.identity(shape, dtype=dtype),
-        vector.Vec3Array.zeros(shape, dtype=dtype))
+        vector.Vec3Array.zeros(shape, dtype=dtype))  # pytype: disable=wrong-arg-count  # trace-all-classes
 
   def scale_translation(self, factor: Float) -> Rigid3Array:
     """Scale translation in Rigid3Array by 'factor'."""
@@ -80,7 +80,7 @@ class Rigid3Array:
   def from_array(cls, array):
     rot = rotation_matrix.Rot3Array.from_array(array[..., :3])
     vec = vector.Vec3Array.from_array(array[..., -1])
-    return cls(rot, vec)
+    return cls(rot, vec)  # pytype: disable=wrong-arg-count  # trace-all-classes
 
   @classmethod
   def from_array4x4(cls, array: jnp.ndarray) -> Rigid3Array:
@@ -94,7 +94,7 @@ class Rigid3Array:
         )
     translation = vector.Vec3Array(
         array[..., 0, 3], array[..., 1, 3], array[..., 2, 3])
-    return cls(rotation, translation)
+    return cls(rotation, translation)  # pytype: disable=wrong-arg-count  # trace-all-classes
 
   def __getstate__(self):
     return (VERSION, (self.rotation, self.translation))
diff --git a/src/alphafold/model/geometry/rotation_matrix.py b/src/alphafold/model/geometry/rotation_matrix.py
index 3222329940078095645bd1e6c191a0314c143774..ccb211110024df5ce21dd39b7beb2ece7f5bfc83 100644
--- a/src/alphafold/model/geometry/rotation_matrix.py
+++ b/src/alphafold/model/geometry/rotation_matrix.py
@@ -73,7 +73,7 @@ class Rot3Array:
     """Returns identity of given shape."""
     ones = jnp.ones(shape, dtype=dtype)
     zeros = jnp.zeros(shape, dtype=dtype)
-    return cls(ones, zeros, zeros, zeros, ones, zeros, zeros, zeros, ones)
+    return cls(ones, zeros, zeros, zeros, ones, zeros, zeros, zeros, ones)  # pytype: disable=wrong-arg-count  # trace-all-classes
 
   @classmethod
   def from_two_vectors(cls, e0: vector.Vec3Array,
@@ -96,7 +96,7 @@ class Rot3Array:
     e1 = (e1 - c * e0).normalized()
     # Compute e2 as cross product of e0 and e1.
     e2 = e0.cross(e1)
-    return cls(e0.x, e1.x, e2.x, e0.y, e1.y, e2.y, e0.z, e1.z, e2.z)
+    return cls(e0.x, e1.x, e2.x, e0.y, e1.y, e2.y, e0.z, e1.z, e2.z)  # pytype: disable=wrong-arg-count  # trace-all-classes
 
   @classmethod
   def from_array(cls, array: jnp.ndarray) -> Rot3Array:
@@ -137,7 +137,7 @@ class Rot3Array:
     zx = 2 * (x * z - w * y)
     zy = 2 * (y * z + w * x)
     zz = 1 - 2 * (jnp.square(x) + jnp.square(y))
-    return cls(xx, xy, xz, yx, yy, yz, zx, zy, zz)
+    return cls(xx, xy, xz, yx, yy, yz, zx, zy, zz)  # pytype: disable=wrong-arg-count  # trace-all-classes
 
   @classmethod
   def random_uniform(cls, key, shape, dtype=jnp.float32) -> Rot3Array:
diff --git a/src/alphafold/model/geometry/struct_of_array.py b/src/alphafold/model/geometry/struct_of_array.py
index 97a89fd4a6d9ff1430540532922cdeb21d3aec42..562743b327dcd6872d60df202e8862967674b915 100644
--- a/src/alphafold/model/geometry/struct_of_array.py
+++ b/src/alphafold/model/geometry/struct_of_array.py
@@ -133,7 +133,7 @@ def flatten(instance):
   inner_treedefs = []
   num_arrays = []
   for array_like in array_likes:
-    flat_array_like, inner_treedef = jax.tree_flatten(array_like)
+    flat_array_like, inner_treedef = jax.tree_util.tree_flatten(array_like)
     inner_treedefs.append(inner_treedef)
     flat_array_likes += flat_array_like
     num_arrays.append(len(flat_array_like))
@@ -206,7 +206,7 @@ class StructOfArray:
       for num_array, inner_treedef, array_field in zip(num_arrays,
                                                        inner_treedefs,
                                                        array_fields):
-        value_dict[array_field] = jax.tree_unflatten(
+        value_dict[array_field] = jax.tree_util.tree_unflatten(
             inner_treedef, data[array_start:array_start + num_array])
         array_start += num_array
       metadata_fields = get_metadata_fields(new_cls)
diff --git a/src/alphafold/model/geometry/vector.py b/src/alphafold/model/geometry/vector.py
index 99dcb50f77f7165b8c7f1ed1f0c67a48ca3141cb..8f22cc54ba4c1596101a14522e77c7ddc06bfd02 100644
--- a/src/alphafold/model/geometry/vector.py
+++ b/src/alphafold/model/geometry/vector.py
@@ -53,10 +53,10 @@ class Vec3Array:
       assert all([x == z for x, z in zip(self.x.shape, self.z.shape)])
 
   def __add__(self, other: Vec3Array) -> Vec3Array:
-    return jax.tree_multimap(lambda x, y: x + y, self, other)
+    return jax.tree_map(lambda x, y: x + y, self, other)
 
   def __sub__(self, other: Vec3Array) -> Vec3Array:
-    return jax.tree_multimap(lambda x, y: x - y, self, other)
+    return jax.tree_map(lambda x, y: x - y, self, other)
 
   def __mul__(self, other: Float) -> Vec3Array:
     return jax.tree_map(lambda x: x * other, self)
@@ -104,7 +104,7 @@ class Vec3Array:
     """Return Vec3Array corresponding to zeros of given shape."""
     return cls(
         jnp.zeros(shape, dtype), jnp.zeros(shape, dtype),
-        jnp.zeros(shape, dtype))
+        jnp.zeros(shape, dtype))  # pytype: disable=wrong-arg-count  # trace-all-classes
 
   def to_array(self) -> jnp.ndarray:
     return jnp.stack([self.x, self.y, self.z], axis=-1)
diff --git a/src/alphafold/model/layer_stack_test.py b/src/alphafold/model/layer_stack_test.py
index 062221f6b753f188475de40f3d50f53324735ffc..d2682f895e5f68d882377ab7b8bd5e3c55c07232 100644
--- a/src/alphafold/model/layer_stack_test.py
+++ b/src/alphafold/model/layer_stack_test.py
@@ -198,8 +198,8 @@ class LayerStackTest(parameterized.TestCase):
     assert_fn = functools.partial(
         np.testing.assert_allclose, atol=1e-4, rtol=1e-4)
 
-    jax.tree_multimap(assert_fn, unrolled_grad,
-                      _slice_layers_params(layer_stack_grad))
+    jax.tree_map(assert_fn, unrolled_grad,
+                 _slice_layers_params(layer_stack_grad))
 
   def test_random(self):
     """Random numbers should be handled correctly."""
diff --git a/src/alphafold/model/mapping.py b/src/alphafold/model/mapping.py
index 2524572b5648d2cdeca87bf99c8f82eecc328431..0e736d521b5bd0ab1ef7d98a33be6c79008e45b0 100644
--- a/src/alphafold/model/mapping.py
+++ b/src/alphafold/model/mapping.py
@@ -47,11 +47,11 @@ def _maybe_get_size(array, axis):
 
 
 def _expand_axes(axes, values, name='sharded_apply'):
-  values_tree_def = jax.tree_flatten(values)[1]
+  values_tree_def = jax.tree_util.tree_flatten(values)[1]
   flat_axes = jax.api_util.flatten_axes(name, values_tree_def, axes)
   # Replace None's with PROXY
   flat_axes = [PROXY if x is None else x for x in flat_axes]
-  return jax.tree_unflatten(values_tree_def, flat_axes)
+  return jax.tree_util.tree_unflatten(values_tree_def, flat_axes)
 
 
 def sharded_map(
@@ -125,8 +125,8 @@ def sharded_apply(
     # Expand in axes and Determine Loop range
     in_axes_ = _expand_axes(in_axes, args)
 
-    in_sizes = jax.tree_multimap(_maybe_get_size, args, in_axes_)
-    flat_sizes = jax.tree_flatten(in_sizes)[0]
+    in_sizes = jax.tree_map(_maybe_get_size, args, in_axes_)
+    flat_sizes = jax.tree_util.tree_flatten(in_sizes)[0]
     in_size = max(flat_sizes)
     assert all(i in {in_size, -1} for i in flat_sizes)
 
@@ -137,7 +137,7 @@ def sharded_apply(
     last_shard_size = shard_size if last_shard_size == 0 else last_shard_size
 
     def apply_fun_to_slice(slice_start, slice_size):
-      input_slice = jax.tree_multimap(
+      input_slice = jax.tree_map(
           lambda array, axis: _maybe_slice(array, slice_start, slice_size, axis
                                           ), args, in_axes_)
       return fun(*input_slice)
@@ -158,11 +158,11 @@ def sharded_apply(
             shard_shape[axis] * num_extra_shards +
             remainder_shape[axis],) + shard_shape[axis + 1:]
 
-      out_shapes = jax.tree_multimap(make_output_shape, out_axes_, shard_shapes,
-                                     out_shapes)
+      out_shapes = jax.tree_map(make_output_shape, out_axes_, shard_shapes,
+                                out_shapes)
 
     # Calls dynamic Update slice with different argument order
-    # This is here since tree_multimap only works with positional arguments
+    # This is here since tree_map only works with positional arguments
     def dynamic_update_slice_in_dim(full_array, update, axis, i):
       return jax.lax.dynamic_update_slice_in_dim(full_array, update, i, axis)
 
@@ -170,7 +170,7 @@ def sharded_apply(
       slice_out = apply_fun_to_slice(slice_start, slice_size)
       update_slice = partial(
           dynamic_update_slice_in_dim, i=slice_start)
-      return jax.tree_multimap(update_slice, outputs, slice_out, out_axes_)
+      return jax.tree_map(update_slice, outputs, slice_out, out_axes_)
 
     def scan_iteration(outputs, i):
       new_outputs = compute_shard(outputs, i, shard_size)
@@ -181,7 +181,7 @@ def sharded_apply(
     def allocate_buffer(dtype, shape):
       return jnp.zeros(shape, dtype=dtype)
 
-    outputs = jax.tree_multimap(allocate_buffer, out_dtypes, out_shapes)
+    outputs = jax.tree_map(allocate_buffer, out_dtypes, out_shapes)
 
     if slice_starts.shape[0] > 0:
       outputs, _ = hk.scan(scan_iteration, outputs, slice_starts)
diff --git a/src/alphafold/model/model.py b/src/alphafold/model/model.py
index 680a17acbc0adeb018277a2e83afb946aad21466..b7982738f3dd61657cfbd4d7165b2edb13a1e935 100644
--- a/src/alphafold/model/model.py
+++ b/src/alphafold/model/model.py
@@ -67,7 +67,6 @@ def get_confidence_metrics(
     confidence_metrics['pitm'] = confidence.predicted_interface_tm_score(
         np.asarray(prediction_result['predicted_aligned_error']['logits']),
         np.asarray(prediction_result['predicted_aligned_error']['breaks']),
-        # np.asarray(residue_index),
         np.asarray(prediction_result['structure_module']['final_atom_positions']),
         np.asarray(prediction_result['structure_module']['final_atom_mask']),
         np.asarray(prediction_result['predicted_aligned_error']['asym_id']),
@@ -187,7 +186,7 @@ class RunModel:
     logging.info('Output shape was %s', shape)
     return shape
 
-  def predict(self, feat: features.FeatureDict, random_seed=0,
+  def predict(self, feat: features.FeatureDict, random_seed: int,
             prev=None, prev_ckpt_iter=0, asym_id_list=None, asym_id=None, 
             edge_contacts_thres=10) -> Mapping[str, Any]:
     """Makes a prediction by inferencing the model on the provided features.
@@ -196,7 +195,7 @@ class RunModel:
       feat: A dictionary of NumPy feature arrays as output by
         RunModel.process_features.
       random_seed: The random seed to use when running the model. In the
-        multimer model this controls the MSA sampling
+        multimer model this controls the MSA sampling.
 
     Returns:
       A dictionary of model outputs.
@@ -258,12 +257,14 @@ class RunModel:
 
     if self.config.model.save_recycled:
       *_, recycled_info = recycles
-      structs = recycled_info['atom_positions']
-      structs_masks = recycled_info['atom_mask']
-      plddt = recycled_info['plddt']
-      palign_logits = recycled_info['pred_aligned_error_logits']
-      palign_break = recycled_info['pred_aligned_error_breaks']
-      tol_values = recycled_info['tol_values']
+     # must convert jax array to np array, otherwise some interplay between
+     # jax array and the loops in the AF2Complex metric functions dramatically slow downs the calculations     
+      structs = np.asarray(recycled_info['atom_positions'])
+      structs_masks = np.asarray(recycled_info['atom_mask'])
+      plddt = np.asarray(recycled_info['plddt'])
+      palign_logits = np.asarray(recycled_info['pred_aligned_error_logits'])
+      palign_break = np.asarray(recycled_info['pred_aligned_error_breaks'])
+      tol_values = np.asarray(recycled_info['tol_values'])
       recycled_info_ = []
 
       for i, (s, m, p, a_logits, a_break, tol_val) in enumerate(zip(
@@ -283,7 +284,6 @@ class RunModel:
           "pitm": confidence.predicted_interface_tm_score(
             a_logits,
             a_break,
-            # res_index,
             s,
             m,
             asym_id,
diff --git a/src/alphafold/model/modules.py b/src/alphafold/model/modules.py
index b43cf1cb56ad227c89cd44fbe580807821dfd8be..af9f8ea962c1cc792fd011bd4c5835e3c27abe7f 100644
--- a/src/alphafold/model/modules.py
+++ b/src/alphafold/model/modules.py
@@ -81,6 +81,9 @@ def dropout_wrapper(module,
   residual = module(input_act, mask, is_training=is_training, **kwargs)
   dropout_rate = 0.0 if gc.deterministic else module.config.dropout_rate
 
+  # Will override `is_training` to True if want to use dropout.
+  should_apply_dropout = True if gc.eval_dropout else is_training
+
   if module.config.shared_dropout:
     if module.config.orientation == 'per_row':
       broadcast_dim = 0
@@ -92,7 +95,7 @@ def dropout_wrapper(module,
   residual = apply_dropout(tensor=residual,
                            safe_key=safe_key,
                            rate=dropout_rate,
-                           is_training=is_training,
+                           is_training=should_apply_dropout,
                            broadcast_dim=broadcast_dim)
 
   new_act = output_act + residual
@@ -560,7 +563,7 @@ class Transition(hk.Module):
     num_intermediate = int(nc * self.config.num_intermediate_factor)
     mask = jnp.expand_dims(mask, axis=-1)
 
-    act = hk.LayerNorm(
+    act = common_modules.LayerNorm(
         axis=[-1],
         create_scale=True,
         create_offset=True,
@@ -628,12 +631,15 @@ class Attention(hk.Module):
 
     q_weights = hk.get_parameter(
         'query_w', shape=(q_data.shape[-1], num_head, key_dim),
+        dtype=q_data.dtype,
         init=glorot_uniform())
     k_weights = hk.get_parameter(
         'key_w', shape=(m_data.shape[-1], num_head, key_dim),
+        dtype=q_data.dtype,
         init=glorot_uniform())
     v_weights = hk.get_parameter(
         'value_w', shape=(m_data.shape[-1], num_head, value_dim),
+        dtype=q_data.dtype,
         init=glorot_uniform())
 
     q = jnp.einsum('bqa,ahc->bqhc', q_data, q_weights) * key_dim**(-0.5)
@@ -654,10 +660,12 @@ class Attention(hk.Module):
       gating_weights = hk.get_parameter(
           'gating_w',
           shape=(q_data.shape[-1], num_head, value_dim),
+          dtype=q_data.dtype,
           init=hk.initializers.Constant(0.0))
       gating_bias = hk.get_parameter(
           'gating_b',
           shape=(num_head, value_dim),
+          dtype=q_data.dtype,
           init=hk.initializers.Constant(1.0))
 
       gate_values = jnp.einsum('bqc, chv->bqhv', q_data,
@@ -669,9 +677,12 @@ class Attention(hk.Module):
 
     o_weights = hk.get_parameter(
         'output_w', shape=(num_head, value_dim, self.output_dim),
+        dtype=q_data.dtype,
         init=init)
-    o_bias = hk.get_parameter('output_b', shape=(self.output_dim,),
-                              init=hk.initializers.Constant(0.0))
+    o_bias = hk.get_parameter(
+        'output_b', shape=(self.output_dim,),
+        dtype=q_data.dtype,
+        init=hk.initializers.Constant(0.0))
 
     output = jnp.einsum('bqhc,hco->bqo', weighted_avg, o_weights) + o_bias
 
@@ -718,12 +729,15 @@ class GlobalAttention(hk.Module):
 
     q_weights = hk.get_parameter(
         'query_w', shape=(q_data.shape[-1], num_head, key_dim),
+        dtype=q_data.dtype,
         init=glorot_uniform())
     k_weights = hk.get_parameter(
         'key_w', shape=(m_data.shape[-1], key_dim),
+        dtype=q_data.dtype,
         init=glorot_uniform())
     v_weights = hk.get_parameter(
         'value_w', shape=(m_data.shape[-1], value_dim),
+        dtype=q_data.dtype,
         init=glorot_uniform())
 
     v = jnp.einsum('bka,ac->bkc', m_data, v_weights)
@@ -744,18 +758,23 @@ class GlobalAttention(hk.Module):
 
     o_weights = hk.get_parameter(
         'output_w', shape=(num_head, value_dim, self.output_dim),
+        dtype=q_data.dtype,
         init=init)
-    o_bias = hk.get_parameter('output_b', shape=(self.output_dim,),
-                              init=hk.initializers.Constant(0.0))
+    o_bias = hk.get_parameter(
+        'output_b', shape=(self.output_dim,),
+        dtype=q_data.dtype,
+        init=hk.initializers.Constant(0.0))
 
     if self.config.gating:
       gating_weights = hk.get_parameter(
           'gating_w',
           shape=(q_data.shape[-1], num_head, value_dim),
+          dtype=q_data.dtype,
           init=hk.initializers.Constant(0.0))
       gating_bias = hk.get_parameter(
           'gating_b',
           shape=(num_head, value_dim),
+          dtype=q_data.dtype,
           init=hk.initializers.Constant(1.0))
 
       gate_values = jnp.einsum('bqc, chv->bqhv', q_data, gating_weights)
@@ -805,11 +824,11 @@ class MSARowAttentionWithPairBias(hk.Module):
     bias = (1e9 * (msa_mask - 1.))[:, None, None, :]
     assert len(bias.shape) == 4
 
-    msa_act = hk.LayerNorm(
+    msa_act = common_modules.LayerNorm(
         axis=[-1], create_scale=True, create_offset=True, name='query_norm')(
             msa_act)
 
-    pair_act = hk.LayerNorm(
+    pair_act = common_modules.LayerNorm(
         axis=[-1],
         create_scale=True,
         create_offset=True,
@@ -820,6 +839,7 @@ class MSARowAttentionWithPairBias(hk.Module):
     weights = hk.get_parameter(
         'feat_2d_weights',
         shape=(pair_act.shape[-1], c.num_head),
+        dtype=msa_act.dtype,
         init=hk.initializers.RandomNormal(stddev=init_factor))
     nonbatched_bias = jnp.einsum('qkc,ch->hqk', pair_act, weights)
 
@@ -872,7 +892,7 @@ class MSAColumnAttention(hk.Module):
     bias = (1e9 * (msa_mask - 1.))[:, None, None, :]
     assert len(bias.shape) == 4
 
-    msa_act = hk.LayerNorm(
+    msa_act = common_modules.LayerNorm(
         axis=[-1], create_scale=True, create_offset=True, name='query_norm')(
             msa_act)
 
@@ -927,7 +947,7 @@ class MSAColumnGlobalAttention(hk.Module):
     bias = (1e9 * (msa_mask - 1.))[:, None, None, :]
     assert len(bias.shape) == 4
 
-    msa_act = hk.LayerNorm(
+    msa_act = common_modules.LayerNorm(
         axis=[-1], create_scale=True, create_offset=True, name='query_norm')(
             msa_act)
 
@@ -984,7 +1004,7 @@ class TriangleAttention(hk.Module):
     bias = (1e9 * (pair_mask - 1.))[:, None, None, :]
     assert len(bias.shape) == 4
 
-    pair_act = hk.LayerNorm(
+    pair_act = common_modules.LayerNorm(
         axis=[-1], create_scale=True, create_offset=True, name='query_norm')(
             pair_act)
 
@@ -992,6 +1012,7 @@ class TriangleAttention(hk.Module):
     weights = hk.get_parameter(
         'feat_2d_weights',
         shape=(pair_act.shape[-1], c.num_head),
+        dtype=pair_act.dtype,
         init=hk.initializers.RandomNormal(stddev=init_factor))
     nonbatched_bias = jnp.einsum('qkc,ch->hqk', pair_act, weights)
 
@@ -1089,7 +1110,7 @@ class PredictedLDDTHead(hk.Module):
     """
     act = representations['structure_module']
 
-    act = hk.LayerNorm(
+    act = common_modules.LayerNorm(
         axis=[-1],
         create_scale=True,
         create_offset=True,
@@ -1311,6 +1332,19 @@ class ExperimentallyResolvedHead(hk.Module):
     return output
 
 
+def _layer_norm(axis=-1, name='layer_norm'):
+  return common_modules.LayerNorm(
+      axis=axis,
+      create_scale=True,
+      create_offset=True,
+      eps=1e-5,
+      use_fast_variance=True,
+      scale_init=hk.initializers.Constant(1.),
+      offset_init=hk.initializers.Constant(0.),
+      param_axis=axis,
+      name=name)
+
+
 class TriangleMultiplication(hk.Module):
   """Triangle multiplication layer ("outgoing" or "incoming").
 
@@ -1323,25 +1357,34 @@ class TriangleMultiplication(hk.Module):
     self.config = config
     self.global_config = global_config
 
-  def __call__(self, act, mask, is_training=True):
+  def __call__(self, left_act, left_mask, is_training=True):
     """Builds TriangleMultiplication module.
 
     Arguments:
-      act: Pair activations, shape [N_res, N_res, c_z]
-      mask: Pair mask, shape [N_res, N_res].
+      left_act: Pair activations, shape [N_res, N_res, c_z]
+      left_mask: Pair mask, shape [N_res, N_res].
       is_training: Whether the module is in training mode.
 
     Returns:
-      Outputs, same shape/type as act.
+      Outputs, same shape/type as left_act.
     """
     del is_training
+
+    if self.config.fuse_projection_weights:
+      return self._fused_triangle_multiplication(left_act, left_mask)
+    else:
+      return self._triangle_multiplication(left_act, left_mask)
+
+  @hk.transparent
+  def _triangle_multiplication(self, left_act, left_mask):
+    """Implementation of TriangleMultiplication used in AF2 and AF-M<2.3."""
     c = self.config
     gc = self.global_config
 
-    mask = mask[..., None]
+    mask = left_mask[..., None]
 
-    act = hk.LayerNorm(axis=[-1], create_scale=True, create_offset=True,
-                       name='layer_norm_input')(act)
+    act = common_modules.LayerNorm(axis=[-1], create_scale=True, create_offset=True,
+                       name='layer_norm_input')(left_act)
     input_act = act
 
     left_projection = common_modules.Linear(
@@ -1377,7 +1420,7 @@ class TriangleMultiplication(hk.Module):
     #   b = left_proj_act and a = right_proj_act
     act = jnp.einsum(c.equation, left_proj_act, right_proj_act)
 
-    act = hk.LayerNorm(
+    act = common_modules.LayerNorm(
         axis=[-1],
         create_scale=True,
         create_offset=True,
@@ -1400,6 +1443,50 @@ class TriangleMultiplication(hk.Module):
 
     return act
 
+  @hk.transparent
+  def _fused_triangle_multiplication(self, left_act, left_mask):
+    """TriangleMultiplication with fused projection weights."""
+    mask = left_mask[..., None]
+    c = self.config
+    gc = self.global_config
+
+    left_act = _layer_norm(axis=-1, name='left_norm_input')(left_act)
+
+    # Both left and right projections are fused into projection.
+    projection = common_modules.Linear(
+        2*c.num_intermediate_channel, name='projection')
+    proj_act = mask * projection(left_act)
+
+    # Both left + right gate are fused into gate_values.
+    gate_values = common_modules.Linear(
+        2 * c.num_intermediate_channel,
+        name='gate',
+        bias_init=1.,
+        initializer=utils.final_init(gc))(left_act)
+    proj_act *= jax.nn.sigmoid(gate_values)
+
+    left_proj_act = proj_act[:, :, :c.num_intermediate_channel]
+    right_proj_act = proj_act[:, :, c.num_intermediate_channel:]
+    act = jnp.einsum(c.equation, left_proj_act, right_proj_act)
+
+    act = _layer_norm(axis=-1, name='center_norm')(act)
+
+    output_channel = int(left_act.shape[-1])
+
+    act = common_modules.Linear(
+        output_channel,
+        initializer=utils.final_init(gc),
+        name='output_projection')(act)
+
+    gate_values = common_modules.Linear(
+        output_channel,
+        bias_init=1.,
+        initializer=utils.final_init(gc),
+        name='gating_linear')(left_act)
+    act *= jax.nn.sigmoid(gate_values)
+
+    return act
+
 
 class DistogramHead(hk.Module):
   """Head to predict a distogram.
@@ -1506,7 +1593,7 @@ class OuterProductMean(hk.Module):
     c = self.config
 
     mask = mask[..., None]
-    act = hk.LayerNorm([-1], True, True, name='layer_norm_input')(act)
+    act = common_modules.LayerNorm([-1], True, True, name='layer_norm_input')(act)
 
     left_act = mask * common_modules.Linear(
         c.num_outer_channel,
@@ -1529,9 +1616,11 @@ class OuterProductMean(hk.Module):
         'output_w',
         shape=(c.num_outer_channel, c.num_outer_channel,
                self.num_output_channel),
+        dtype=act.dtype,
         init=init_w)
     output_b = hk.get_parameter(
         'output_b', shape=(self.num_output_channel,),
+        dtype=act.dtype,
         init=hk.initializers.Constant(0.0))
 
     def compute_chunk(left_act):
@@ -1798,20 +1887,20 @@ class EmbeddingsAndEvoformer(hk.Module):
               dgram)
 
     if c.recycle_features:
-      if 'prev_msa_first_row' in batch:
-        prev_msa_first_row = hk.LayerNorm([-1],
-                                          True,
-                                          True,
-                                          name='prev_msa_first_row_norm')(
-                                              batch['prev_msa_first_row'])
-        msa_activations = msa_activations.at[0].add(prev_msa_first_row)
-
-      if 'prev_pair' in batch:
-        pair_activations += hk.LayerNorm([-1],
-                                         True,
-                                         True,
-                                         name='prev_pair_norm')(
-                                             batch['prev_pair'])
+      prev_msa_first_row = common_modules.LayerNorm(
+          axis=[-1],
+          create_scale=True,
+          create_offset=True,
+          name='prev_msa_first_row_norm')(
+              batch['prev_msa_first_row'])
+      msa_activations = msa_activations.at[0].add(prev_msa_first_row)
+
+      pair_activations += common_modules.LayerNorm(
+          axis=[-1],
+          create_scale=True,
+          create_offset=True,
+          name='prev_pair_norm')(
+              batch['prev_pair'])
 
     # Relative position encoding.
     # Jumper et al. (2021) Suppl. Alg. 4 "relpos"
@@ -2080,7 +2169,7 @@ class SingleTemplateEmbedding(hk.Module):
         self.config.template_pair_stack, self.global_config)(
             act, mask_2d, is_training)
 
-    act = hk.LayerNorm([-1], True, True, name='output_layer_norm')(act)
+    act = common_modules.LayerNorm([-1], True, True, name='output_layer_norm')(act)
     return act
 
 
diff --git a/src/alphafold/model/modules_multimer.py b/src/alphafold/model/modules_multimer.py
index c3a9b46dc161cb9c8a3d79381f0daf76e2d17a38..bf0125b30d60349a2a8c94de74225d4128e03c4b 100644
--- a/src/alphafold/model/modules_multimer.py
+++ b/src/alphafold/model/modules_multimer.py
@@ -594,11 +594,13 @@ class EmbeddingsAndEvoformer(hk.Module):
       Feature embedding using the features as described before.
     """
     c = self.config
+    gc = self.global_config
     rel_feats = []
     pos = batch['residue_index']
     asym_id = batch['asym_id']
     asym_id_same = jnp.equal(asym_id[:, None], asym_id[None, :])
     offset = pos[:, None] - pos[None, :]
+    dtype = jnp.bfloat16 if gc.bfloat16 else jnp.float32
 
     clipped_offset = jnp.clip(
         offset + c.max_relative_idx, a_min=0, a_max=2 * c.max_relative_idx)
@@ -638,6 +640,7 @@ class EmbeddingsAndEvoformer(hk.Module):
 
     rel_feat = jnp.concatenate(rel_feats, axis=-1)
 
+    rel_feat = rel_feat.astype(dtype)
     return common_modules.Linear(
         c.pair_channel,
         name='position_activations')(
@@ -649,6 +652,7 @@ class EmbeddingsAndEvoformer(hk.Module):
     gc = self.global_config
 
     batch = dict(batch)
+    dtype = jnp.bfloat16 if gc.bfloat16 else jnp.float32
 
     if safe_key is None:
       safe_key = prng.SafeKey(hk.next_rng_key())
@@ -657,180 +661,178 @@ class EmbeddingsAndEvoformer(hk.Module):
 
     batch['msa_profile'] = make_msa_profile(batch)
 
-    # print(f"aatype = {batch['aatype']}")
-    target_feat = jax.nn.one_hot(batch['aatype'], 21)
-
-    preprocess_1d = common_modules.Linear(
-        c.msa_channel, name='preprocess_1d')(
-            target_feat)
-
-    safe_key, sample_key, mask_key = safe_key.split(3)
-    batch = sample_msa(sample_key, batch, c.num_msa)
-    batch = make_masked_msa(batch, mask_key, c.masked_msa)
-
-    (batch['cluster_profile'],
-     batch['cluster_deletion_mean']) = nearest_neighbor_clusters(batch)
-
-    msa_feat = create_msa_feat(batch)
-
-    preprocess_msa = common_modules.Linear(
-        c.msa_channel, name='preprocess_msa')(
-            msa_feat)
-
-    msa_activations = jnp.expand_dims(preprocess_1d, axis=0) + preprocess_msa
-
-    left_single = common_modules.Linear(
-        c.pair_channel, name='left_single')(
-            target_feat)
-    right_single = common_modules.Linear(
-        c.pair_channel, name='right_single')(
-            target_feat)
-    pair_activations = left_single[:, None] + right_single[None]
-    mask_2d = batch['seq_mask'][:, None] * batch['seq_mask'][None, :]
-    mask_2d = mask_2d.astype(jnp.float32)
-
-    if c.recycle_pos and 'prev_pos' in batch:
-      prev_pseudo_beta = modules.pseudo_beta_fn(
-          batch['aatype'], batch['prev_pos'], None)
-
-      dgram = modules.dgram_from_positions(
-          prev_pseudo_beta, **self.config.prev_pos)
-      pair_activations += common_modules.Linear(
-          c.pair_channel, name='prev_pos_linear')(
-              dgram)
-
-    if c.recycle_features:
-      if 'prev_msa_first_row' in batch:
-        prev_msa_first_row = hk.LayerNorm(
+    with utils.bfloat16_context():
+      target_feat = jax.nn.one_hot(batch['aatype'], 21).astype(dtype)
+
+      preprocess_1d = common_modules.Linear(
+          c.msa_channel, name='preprocess_1d')(
+              target_feat)
+
+      safe_key, sample_key, mask_key = safe_key.split(3)
+      batch = sample_msa(sample_key, batch, c.num_msa)
+      batch = make_masked_msa(batch, mask_key, c.masked_msa)
+
+      (batch['cluster_profile'],
+       batch['cluster_deletion_mean']) = nearest_neighbor_clusters(batch)
+
+      msa_feat = create_msa_feat(batch).astype(dtype)
+
+      preprocess_msa = common_modules.Linear(
+          c.msa_channel, name='preprocess_msa')(
+              msa_feat)
+      msa_activations = jnp.expand_dims(preprocess_1d, axis=0) + preprocess_msa
+
+      left_single = common_modules.Linear(
+          c.pair_channel, name='left_single')(
+              target_feat)
+      right_single = common_modules.Linear(
+          c.pair_channel, name='right_single')(
+              target_feat)
+      pair_activations = left_single[:, None] + right_single[None]
+      mask_2d = batch['seq_mask'][:, None] * batch['seq_mask'][None, :]
+      mask_2d = mask_2d.astype(dtype)
+
+      if c.recycle_pos:
+        prev_pseudo_beta = modules.pseudo_beta_fn(
+            batch['aatype'], batch['prev_pos'], None)
+
+        dgram = modules.dgram_from_positions(
+            prev_pseudo_beta, **self.config.prev_pos)
+        dgram = dgram.astype(dtype)
+        pair_activations += common_modules.Linear(
+            c.pair_channel, name='prev_pos_linear')(
+                dgram)
+      if c.recycle_features:
+        prev_msa_first_row = common_modules.LayerNorm(
             axis=[-1],
             create_scale=True,
             create_offset=True,
             name='prev_msa_first_row_norm')(
-                batch['prev_msa_first_row'])
+                batch['prev_msa_first_row']).astype(dtype)
         msa_activations = msa_activations.at[0].add(prev_msa_first_row)
 
-      if 'prev_pair' in batch:
-        pair_activations += hk.LayerNorm(
+        pair_activations += common_modules.LayerNorm(
             axis=[-1],
             create_scale=True,
             create_offset=True,
             name='prev_pair_norm')(
-                batch['prev_pair'])
+                batch['prev_pair']).astype(dtype)
 
-    if c.max_relative_idx:
-      pair_activations += self._relative_encoding(batch)
+      if c.max_relative_idx:
+        pair_activations += self._relative_encoding(batch)
 
-    if c.template.enabled:
-      template_module = TemplateEmbedding(c.template, gc)
-      template_batch = {
-          'template_aatype': batch['template_aatype'],
-          'template_all_atom_positions': batch['template_all_atom_positions'],
-          'template_all_atom_mask': batch['template_all_atom_mask']
+      if c.template.enabled:
+        template_module = TemplateEmbedding(c.template, gc)
+        template_batch = {
+            'template_aatype': batch['template_aatype'],
+            'template_all_atom_positions': batch['template_all_atom_positions'],
+            'template_all_atom_mask': batch['template_all_atom_mask']
+        }
+        # Construct a mask such that only intra-chain template features are
+        # computed, since all templates are for each chain individually.
+        multichain_mask = batch['asym_id'][:, None] == batch['asym_id'][None, :]
+        safe_key, safe_subkey = safe_key.split()
+        template_act = template_module(
+            query_embedding=pair_activations,
+            template_batch=template_batch,
+            padding_mask_2d=mask_2d,
+            multichain_mask_2d=multichain_mask,
+            is_training=is_training,
+            safe_key=safe_subkey)
+        pair_activations += template_act
+
+      # Extra MSA stack.
+      (extra_msa_feat,
+       extra_msa_mask) = create_extra_msa_feature(batch, c.num_extra_msa)
+      extra_msa_activations = common_modules.Linear(
+          c.extra_msa_channel,
+          name='extra_msa_activations')(
+              extra_msa_feat).astype(dtype)
+      extra_msa_mask = extra_msa_mask.astype(dtype)
+
+      extra_evoformer_input = {
+          'msa': extra_msa_activations,
+          'pair': pair_activations,
       }
-      # Construct a mask such that only intra-chain template features are
-      # computed, since all templates are for each chain individually.
-      multichain_mask = batch['asym_id'][:, None] == batch['asym_id'][None, :]
-      safe_key, safe_subkey = safe_key.split()
-      template_act = template_module(
-          query_embedding=pair_activations,
-          template_batch=template_batch,
-          padding_mask_2d=mask_2d,
-          multichain_mask_2d=multichain_mask,
-          is_training=is_training,
-          safe_key=safe_subkey)
-      pair_activations += template_act
-
-    # Extra MSA stack.
-    (extra_msa_feat,
-     extra_msa_mask) = create_extra_msa_feature(batch, c.num_extra_msa)
-    extra_msa_activations = common_modules.Linear(
-        c.extra_msa_channel,
-        name='extra_msa_activations')(
-            extra_msa_feat)
-    extra_msa_mask = extra_msa_mask.astype(jnp.float32)
-
-    extra_evoformer_input = {
-        'msa': extra_msa_activations,
-        'pair': pair_activations,
-    }
-    extra_masks = {'msa': extra_msa_mask, 'pair': mask_2d}
+      extra_masks = {'msa': extra_msa_mask, 'pair': mask_2d}
 
-    extra_evoformer_iteration = modules.EvoformerIteration(
-        c.evoformer, gc, is_extra_msa=True, name='extra_msa_stack')
+      extra_evoformer_iteration = modules.EvoformerIteration(
+          c.evoformer, gc, is_extra_msa=True, name='extra_msa_stack')
 
-    def extra_evoformer_fn(x):
-      act, safe_key = x
-      safe_key, safe_subkey = safe_key.split()
-      extra_evoformer_output = extra_evoformer_iteration(
-          activations=act,
-          masks=extra_masks,
-          is_training=is_training,
-          safe_key=safe_subkey)
-      return (extra_evoformer_output, safe_key)
+      def extra_evoformer_fn(x):
+        act, safe_key = x
+        safe_key, safe_subkey = safe_key.split()
+        extra_evoformer_output = extra_evoformer_iteration(
+            activations=act,
+            masks=extra_masks,
+            is_training=is_training,
+            safe_key=safe_subkey)
+        return (extra_evoformer_output, safe_key)
 
-    if gc.use_remat:
-      extra_evoformer_fn = hk.remat(extra_evoformer_fn)
+      if gc.use_remat:
+        extra_evoformer_fn = hk.remat(extra_evoformer_fn)
 
-    safe_key, safe_subkey = safe_key.split()
-    extra_evoformer_stack = layer_stack.layer_stack(
-        c.extra_msa_stack_num_block)(
-            extra_evoformer_fn)
-    extra_evoformer_output, safe_key = extra_evoformer_stack(
-        (extra_evoformer_input, safe_subkey))
-
-    pair_activations = extra_evoformer_output['pair']
-
-    # Get the size of the MSA before potentially adding templates, so we
-    # can crop out the templates later.
-    num_msa_sequences = msa_activations.shape[0]
-    evoformer_input = {
-        'msa': msa_activations,
-        'pair': pair_activations,
-    }
-    evoformer_masks = {'msa': batch['msa_mask'].astype(jnp.float32),
-                       'pair': mask_2d}
-
-    if c.template.enabled:
-      template_features, template_masks = (
-          template_embedding_1d(batch=batch, num_channel=c.msa_channel))
-
-      evoformer_input['msa'] = jnp.concatenate(
-          [evoformer_input['msa'], template_features], axis=0)
-      evoformer_masks['msa'] = jnp.concatenate(
-          [evoformer_masks['msa'], template_masks], axis=0)
-
-    evoformer_iteration = modules.EvoformerIteration(
-        c.evoformer, gc, is_extra_msa=False, name='evoformer_iteration')
-
-    def evoformer_fn(x):
-      act, safe_key = x
       safe_key, safe_subkey = safe_key.split()
-      evoformer_output = evoformer_iteration(
-          activations=act,
-          masks=evoformer_masks,
-          is_training=is_training,
-          safe_key=safe_subkey)
-      return (evoformer_output, safe_key)
-
-    if gc.use_remat:
-      evoformer_fn = hk.remat(evoformer_fn)
+      extra_evoformer_stack = layer_stack.layer_stack(
+          c.extra_msa_stack_num_block)(
+              extra_evoformer_fn)
+      extra_evoformer_output, safe_key = extra_evoformer_stack(
+          (extra_evoformer_input, safe_subkey))
+
+      pair_activations = extra_evoformer_output['pair']
+
+      # Get the size of the MSA before potentially adding templates, so we
+      # can crop out the templates later.
+      num_msa_sequences = msa_activations.shape[0]
+      evoformer_input = {
+          'msa': msa_activations,
+          'pair': pair_activations,
+      }
+      evoformer_masks = {
+          'msa': batch['msa_mask'].astype(dtype),
+          'pair': mask_2d
+      }
+      if c.template.enabled:
+        template_features, template_masks = (
+            template_embedding_1d(
+                batch=batch, num_channel=c.msa_channel, global_config=gc))
+
+        evoformer_input['msa'] = jnp.concatenate(
+            [evoformer_input['msa'], template_features], axis=0)
+        evoformer_masks['msa'] = jnp.concatenate(
+            [evoformer_masks['msa'], template_masks], axis=0)
+      evoformer_iteration = modules.EvoformerIteration(
+          c.evoformer, gc, is_extra_msa=False, name='evoformer_iteration')
+
+      def evoformer_fn(x):
+        act, safe_key = x
+        safe_key, safe_subkey = safe_key.split()
+        evoformer_output = evoformer_iteration(
+            activations=act,
+            masks=evoformer_masks,
+            is_training=is_training,
+            safe_key=safe_subkey)
+        return (evoformer_output, safe_key)
+
+      if gc.use_remat:
+        evoformer_fn = hk.remat(evoformer_fn)
 
-    safe_key, safe_subkey = safe_key.split()
-    evoformer_stack = layer_stack.layer_stack(c.evoformer_num_block)(
-        evoformer_fn)
+      safe_key, safe_subkey = safe_key.split()
+      evoformer_stack = layer_stack.layer_stack(c.evoformer_num_block)(
+          evoformer_fn)
 
-    def run_evoformer(evoformer_input):
-      evoformer_output, _ = evoformer_stack((evoformer_input, safe_subkey))
-      return evoformer_output
+      def run_evoformer(evoformer_input):
+        evoformer_output, _ = evoformer_stack((evoformer_input, safe_subkey))
+        return evoformer_output
 
-    evoformer_output = run_evoformer(evoformer_input)
+      evoformer_output = run_evoformer(evoformer_input)
 
-    msa_activations = evoformer_output['msa']
-    pair_activations = evoformer_output['pair']
+      msa_activations = evoformer_output['msa']
+      pair_activations = evoformer_output['pair']
 
-    single_activations = common_modules.Linear(
-        c.seq_channel, name='single_activations')(
-            msa_activations[0])
+      single_activations = common_modules.Linear(
+          c.seq_channel, name='single_activations')(
+              msa_activations[0])
 
     output.update({
         'single':
@@ -844,6 +846,12 @@ class EmbeddingsAndEvoformer(hk.Module):
             msa_activations[0],
     })
 
+    # Convert back to float32 if we're not saving memory.
+    if not gc.bfloat16_output:
+      for k, v in output.items():
+        if v.dtype == jnp.bfloat16:
+          output[k] = v.astype(jnp.float32)
+
     return output
 
 
@@ -990,6 +998,9 @@ class SingleTemplateEmbedding(hk.Module):
       # backbone affine - i.e. in each residues local frame, what direction are
       # each of the other residues.
       raw_atom_pos = template_all_atom_positions
+      if gc.bfloat16:
+        # Vec3Arrays are required to be float32
+        raw_atom_pos = raw_atom_pos.astype(jnp.float32)
 
       atom_pos = geometry.Vec3Array.from_array(raw_atom_pos)
       rigid, backbone_mask = folding_multimer.make_backbone_affine(
@@ -1001,6 +1012,10 @@ class SingleTemplateEmbedding(hk.Module):
       unit_vector = rigid_vec.normalized()
       unit_vector = [unit_vector.x, unit_vector.y, unit_vector.z]
 
+      if gc.bfloat16:
+        unit_vector = [x.astype(jnp.bfloat16) for x in unit_vector]
+        backbone_mask = backbone_mask.astype(jnp.bfloat16)
+
       backbone_mask_2d = backbone_mask[:, None] * backbone_mask[None, :]
       backbone_mask_2d *= multichain_mask_2d
       unit_vector = [x*backbone_mask_2d for x in unit_vector]
@@ -1010,7 +1025,7 @@ class SingleTemplateEmbedding(hk.Module):
       to_concat.extend([(x, 0) for x in unit_vector])
       to_concat.append((backbone_mask_2d, 0))
 
-      query_embedding = hk.LayerNorm(
+      query_embedding = common_modules.LayerNorm(
           axis=[-1],
           create_scale=True,
           create_offset=True,
@@ -1059,12 +1074,13 @@ class SingleTemplateEmbedding(hk.Module):
             template_iteration_fn)
     act, safe_key = template_stack((act, safe_subkey))
 
-    act = hk.LayerNorm(
+    act = common_modules.LayerNorm(
         axis=[-1],
         create_scale=True,
         create_offset=True,
         name='output_layer_norm')(
             act)
+
     return act
 
 
@@ -1117,21 +1133,18 @@ class TemplateEmbeddingIteration(hk.Module):
         act,
         pair_mask,
         safe_key=next(sub_keys))
-
     act = dropout_wrapper_fn(
         modules.TriangleAttention(c.triangle_attention_starting_node, gc,
                                   name='triangle_attention_starting_node'),
         act,
         pair_mask,
         safe_key=next(sub_keys))
-
     act = dropout_wrapper_fn(
         modules.TriangleAttention(c.triangle_attention_ending_node, gc,
                                   name='triangle_attention_ending_node'),
         act,
         pair_mask,
         safe_key=next(sub_keys))
-
     act = dropout_wrapper_fn(
         modules.Transition(c.pair_transition, gc,
                            name='pair_transition'),
@@ -1142,7 +1155,7 @@ class TemplateEmbeddingIteration(hk.Module):
     return act
 
 
-def template_embedding_1d(batch, num_channel):
+def template_embedding_1d(batch, num_channel, global_config):
   """Embed templates into an (num_res, num_templates, num_channels) embedding.
 
   Args:
@@ -1153,6 +1166,7 @@ def template_embedding_1d(batch, num_channel):
       template_all_atom_mask, (num_templates, num_residues, 37) atom mask for
         each template.
     num_channel: The number of channels in the output.
+    global_config: The global_config.
 
   Returns:
     An embedding of shape (num_templates, num_res, num_channels) and a mask of
@@ -1186,6 +1200,10 @@ def template_embedding_1d(batch, num_channel):
 
   template_mask = chi_mask[:, :, 0]
 
+  if global_config.bfloat16:
+    template_features = template_features.astype(jnp.bfloat16)
+    template_mask = template_mask.astype(jnp.bfloat16)
+
   template_activations = common_modules.Linear(
       num_channel,
       initializer='relu',
diff --git a/src/alphafold/model/tf/protein_features_test.py b/src/alphafold/model/tf/protein_features_test.py
index ee88711281f5cfb366ab26b10c75f37b211120ea..f5a351ba863584222272318f10304e9707fe3edf 100644
--- a/src/alphafold/model/tf/protein_features_test.py
+++ b/src/alphafold/model/tf/protein_features_test.py
@@ -27,6 +27,10 @@ def _random_bytes():
 
 class FeaturesTest(parameterized.TestCase, tf.test.TestCase):
 
+  def setUp(self):
+    super().setUp()
+    tf.disable_v2_behavior()
+
   def testFeatureNames(self):
     self.assertEqual(len(protein_features.FEATURE_SIZES),
                      len(protein_features.FEATURE_TYPES))
@@ -47,5 +51,4 @@ class FeaturesTest(parameterized.TestCase, tf.test.TestCase):
 
 
 if __name__ == '__main__':
-  tf.disable_v2_behavior()
   absltest.main()
diff --git a/src/alphafold/model/tf/shape_helpers_test.py b/src/alphafold/model/tf/shape_helpers_test.py
index d7797b340514d9577dd77b9e9660babd0aa52b5e..16c032bae89db6320f27920f1715159f2319231d 100644
--- a/src/alphafold/model/tf/shape_helpers_test.py
+++ b/src/alphafold/model/tf/shape_helpers_test.py
@@ -21,6 +21,10 @@ import tensorflow.compat.v1 as tf
 
 class ShapeTest(tf.test.TestCase):
 
+  def setUp(self):
+    super().setUp()
+    tf.disable_v2_behavior()
+
   def test_shape_list(self):
     """Test that shape_list can allow for reshaping to dynamic shapes."""
     a = tf.zeros([10, 4, 4, 2])
@@ -35,5 +39,4 @@ class ShapeTest(tf.test.TestCase):
 
 
 if __name__ == '__main__':
-  tf.disable_v2_behavior()
   tf.test.main()
diff --git a/src/alphafold/model/utils.py b/src/alphafold/model/utils.py
index 40ca1683eab8c71522aae3ff9fa1dced89439772..8d70376e38bade8374592c38e5bd52d4e1f3ee4a 100644
--- a/src/alphafold/model/utils.py
+++ b/src/alphafold/model/utils.py
@@ -15,6 +15,7 @@
 """A collection of JAX utility functions for use in protein folding."""
 
 import collections
+import contextlib
 import functools
 import numbers
 from typing import Mapping
@@ -25,6 +26,27 @@ import jax.numpy as jnp
 import numpy as np
 
 
+def bfloat16_creator(next_creator, shape, dtype, init, context):
+  """Creates float32 variables when bfloat16 is requested."""
+  if context.original_dtype == jnp.bfloat16:
+    dtype = jnp.float32
+  return next_creator(shape, dtype, init)
+
+
+def bfloat16_getter(next_getter, value, context):
+  """Casts float32 to bfloat16 when bfloat16 was originally requested."""
+  if context.original_dtype == jnp.bfloat16:
+    assert value.dtype == jnp.float32
+    value = value.astype(jnp.bfloat16)
+  return next_getter(value)
+
+
+@contextlib.contextmanager
+def bfloat16_context():
+  with hk.custom_creator(bfloat16_creator), hk.custom_getter(bfloat16_getter):
+    yield
+
+
 def final_init(config):
   if config.zero_init:
     return 'zeros'
@@ -34,7 +56,7 @@ def final_init(config):
 
 def batched_gather(params, indices, axis=0, batch_dims=0):
   """Implements a JAX equivalent of `tf.gather` with `axis` and `batch_dims`."""
-  take_fn = lambda p, i: jnp.take(p, i, axis=axis)
+  take_fn = lambda p, i: jnp.take(p, i, axis=axis, mode='clip')
   for _ in range(batch_dims):
     take_fn = jax.vmap(take_fn)
   return take_fn(params, indices)
@@ -54,7 +76,7 @@ def mask_mean(mask, value, axis=None, drop_mask_channel=False, eps=1e-10):
     axis = [axis]
   elif axis is None:
     axis = list(range(len(mask_shape)))
-  assert isinstance(axis, collections.Iterable), (
+  assert isinstance(axis, collections.abc.Iterable), (
       'axis needs to be either an iterable, integer or "None"')
 
   broadcast_factor = 1.
diff --git a/src/alphafold/relax/amber_minimize.py b/src/alphafold/relax/amber_minimize.py
index ef1496942e5c422f5556c56b447d67354e9c1496..e21a0dc302109f01623a75ffb533b28958e85944 100644
--- a/src/alphafold/relax/amber_minimize.py
+++ b/src/alphafold/relax/amber_minimize.py
@@ -26,6 +26,7 @@ from alphafold.relax import cleanup
 from alphafold.relax import utils
 import ml_collections
 import numpy as np
+import jax
 from simtk import openmm
 from simtk import unit
 from simtk.openmm import app as openmm_app
@@ -486,7 +487,9 @@ def run_pipeline(
       pdb_string = clean_protein(prot, checks=True)
     else:
       pdb_string = ret["min_pdb"]
-    ret.update(get_violation_metrics(prot))
+    # Calculation of violations can cause CUDA errors for some JAX versions.
+    with jax.default_device(jax.devices("cpu")[0]):
+      ret.update(get_violation_metrics(prot))
     ret.update({
         "num_exclusions": len(exclude_residues),
         "iteration": iteration,
@@ -500,51 +503,3 @@ def run_pipeline(
                  ret["num_residue_violations"], ret["num_exclusions"])
     iteration += 1
   return ret
-
-
-def get_initial_energies(pdb_strs: Sequence[str],
-                         stiffness: float = 0.0,
-                         restraint_set: str = "non_hydrogen",
-                         exclude_residues: Optional[Sequence[int]] = None):
-  """Returns initial potential energies for a sequence of PDBs.
-
-  Assumes the input PDBs are ready for minimization, and all have the same
-  topology.
-  Allows time to be saved by not pdbfixing / rebuilding the system.
-
-  Args:
-    pdb_strs: List of PDB strings.
-    stiffness: kcal/mol A**2, spring constant of heavy atom restraining
-        potential.
-    restraint_set: Which atom types to restrain.
-    exclude_residues: An optional list of zero-indexed residues to exclude from
-        restraints.
-
-  Returns:
-    A list of initial energies in the same order as pdb_strs.
-  """
-  exclude_residues = exclude_residues or []
-
-  openmm_pdbs = [openmm_app.PDBFile(PdbStructure(io.StringIO(p)))
-                 for p in pdb_strs]
-  force_field = openmm_app.ForceField("amber99sb.xml")
-  system = force_field.createSystem(openmm_pdbs[0].topology,
-                                    constraints=openmm_app.HBonds)
-  stiffness = stiffness * ENERGY / (LENGTH**2)
-  if stiffness > 0 * ENERGY / (LENGTH**2):
-    _add_restraints(system, openmm_pdbs[0], stiffness, restraint_set,
-                    exclude_residues)
-  simulation = openmm_app.Simulation(openmm_pdbs[0].topology,
-                                     system,
-                                     openmm.LangevinIntegrator(0, 0.01, 0.0),
-                                     openmm.Platform.getPlatformByName("CPU"))
-  energies = []
-  for pdb in openmm_pdbs:
-    try:
-      simulation.context.setPositions(pdb.positions)
-      state = simulation.context.getState(getEnergy=True)
-      energies.append(state.getPotentialEnergy().value_in_unit(ENERGY))
-    except Exception as e:  # pylint: disable=broad-except
-      logging.error("Error getting initial energy, returning large value %s", e)
-      energies.append(unit.Quantity(1e20, ENERGY))
-  return energies
diff --git a/src/alphafold/relax/relax.py b/src/alphafold/relax/relax.py
index bd6c9fd04b277679ece2b62a7c373f1127e6b1a6..ebbd72d0247b624c6830dee9c0a01ec5a5d6ae61 100644
--- a/src/alphafold/relax/relax.py
+++ b/src/alphafold/relax/relax.py
@@ -56,7 +56,8 @@ class AmberRelaxation(object):
     self._use_gpu = use_gpu
 
   def process(self, *,
-              prot: protein.Protein) -> Tuple[str, Dict[str, Any], np.ndarray]:
+              prot: protein.Protein
+              ) -> Tuple[str, Dict[str, Any], Sequence[float]]:
     """Runs Amber relax on a prediction, adds hydrogens, returns PDB string."""
     out = amber_minimize.run_pipeline(
         prot=prot, max_iterations=self._max_iterations,
@@ -73,12 +74,11 @@ class AmberRelaxation(object):
         'attempts': out['min_attempts'],
         'rmsd': rmsd
     }
-    pdb_str = amber_minimize.clean_protein(prot)
-    min_pdb = utils.overwrite_pdb_coordinates(pdb_str, min_pos)
+    min_pdb = out['min_pdb']
     min_pdb = utils.overwrite_b_factors(min_pdb, prot.b_factors)
     utils.assert_equal_nonterminal_atom_types(
         protein.from_pdb_string(min_pdb).atom_mask,
         prot.atom_mask)
     violations = out['structural_violations'][
-        'total_per_residue_violations_mask']
+        'total_per_residue_violations_mask'].tolist()
     return min_pdb, debug_data, violations
diff --git a/src/alphafold/relax/relax_test.py b/src/alphafold/relax/relax_test.py
index 57e594e8a4f684e8bbab0bf645bad3776cec3d00..8ab5142e31f4955863f4b75c8c5eb8042bdcdb1c 100644
--- a/src/alphafold/relax/relax_test.py
+++ b/src/alphafold/relax/relax_test.py
@@ -82,7 +82,7 @@ class RunAmberRelaxTest(absltest.TestCase):
          0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0,
          0, 0, 0, 0])
     # Check no violations were added. Can't check exactly due to stochasticity.
-    self.assertTrue(np.all(num_violations <= exp_num_violations))
+    self.assertTrue(np.all(np.array(num_violations) <= exp_num_violations))
 
 
 if __name__ == '__main__':
diff --git a/src/alphafold/relax/utils.py b/src/alphafold/relax/utils.py
index 4bd4acad4e5e7f071fb0ad98e7711b4856843464..0207df5ba24d4876208a26ee06b7e0c65b509ee8 100644
--- a/src/alphafold/relax/utils.py
+++ b/src/alphafold/relax/utils.py
@@ -17,17 +17,6 @@ import io
 from alphafold.common import residue_constants
 from Bio import PDB
 import numpy as np
-from simtk.openmm import app as openmm_app
-from simtk.openmm.app.internal.pdbstructure import PdbStructure
-
-
-def overwrite_pdb_coordinates(pdb_str: str, pos) -> str:
-  pdb_file = io.StringIO(pdb_str)
-  structure = PdbStructure(pdb_file)
-  topology = openmm_app.PDBFile(structure).getTopology()
-  with io.StringIO() as f:
-    openmm_app.PDBFile.writeFile(topology, pos, f)
-    return f.getvalue()
 
 
 def overwrite_b_factors(pdb_str: str, bfactors: np.ndarray) -> str:
@@ -74,7 +63,7 @@ def assert_equal_nonterminal_atom_types(
   """Checks that pre- and post-minimized proteins have same atom set."""
   # Ignore any terminal OXT atoms which may have been added by minimization.
   oxt = residue_constants.atom_order['OXT']
-  no_oxt_mask = np.ones(shape=atom_mask.shape, dtype=np.bool)
+  no_oxt_mask = np.ones(shape=atom_mask.shape, dtype=bool)
   no_oxt_mask[..., oxt] = False
   np.testing.assert_almost_equal(ref_atom_mask[no_oxt_mask],
                                  atom_mask[no_oxt_mask])
diff --git a/src/run_af2c_fea.py b/src/run_af2c_fea.py
index 23e5407ccf12152418b63ec20b6d10ffe0eb3637..982327143bd8016fe6cc88a0bce6a6bf8486e0fa 100644
--- a/src/run_af2c_fea.py
+++ b/src/run_af2c_fea.py
@@ -40,6 +40,7 @@ from alphafold.common import protein
 from alphafold.common import residue_constants
 from alphafold.data import pipeline
 from alphafold.data import pipeline_multimer
+from alphafold.data import pipeline_uniprot
 from alphafold.data import templates
 from alphafold.data.tools import hhsearch
 from alphafold.data.tools import hmmsearch
@@ -98,14 +99,16 @@ flags.DEFINE_string('obsolete_pdbs_path', None, 'Path to file containing a '
                     'mapping from obsolete PDB IDs to the PDB IDs of their '
                     'replacements.')
 flags.DEFINE_enum('db_preset', 'reduced_dbs',
-                  ['full_dbs', 'reduced_dbs'],
+                  ['full_dbs', 'reduced_dbs', 'uniprot'],
                   'Choose preset MSA database configuration - '
                   'smaller genetic database config (reduced_dbs) or '
-                  'full genetic database config  (full_dbs)')
+                  'full genetic database config  (full_dbs) or '
+                  'only uniprot database plus pdb (uniprot)')
 flags.DEFINE_enum('feature_mode', 'monomer',
-                  ['monomer', 'monomer+species', 'multimer'],
+                  ['monomer', 'monomer+species', 'monomer+fullpdb', 'multimer'],
                   'Choose the mode of output feature sets - for monomer prediction, '
-                  'monomer plus species ids (for customized pairing later), and '
+                  'monomer plus species ids (for customized pairing later), '
+                  'monomer using full pdb (instead of pdb70) plus species id, and '
                   'for multimer prediction using the default MSA pairing')
 flags.DEFINE_integer('random_seed', None, 'The random seed for the data '
                      'pipeline. By default, this is randomly generated. Note '
@@ -182,21 +185,24 @@ def main(argv):
       raise ValueError(f'Could not find path to the "{tool_name}" binary. Make '
                        'sure it is installed on your system.')
 
-  use_small_bfd = FLAGS.db_preset == 'reduced_dbs'
-  _check_flag('small_bfd_database_path', 'db_preset',
-              should_be_set=use_small_bfd)
-  _check_flag('bfd_database_path', 'db_preset',
-              should_be_set=not use_small_bfd)
-  _check_flag('uniclust30_database_path', 'db_preset',
-              should_be_set=not use_small_bfd)
+  if FLAGS.db_preset != 'uniprot':
+      use_small_bfd = FLAGS.db_preset == 'reduced_dbs'
+      _check_flag('small_bfd_database_path', 'db_preset',
+                  should_be_set=use_small_bfd)
+      _check_flag('bfd_database_path', 'db_preset',
+                  should_be_set=(FLAGS.db_preset == 'full_dbs'))
+      _check_flag('uniclust30_database_path', 'db_preset',
+                  should_be_set=(FLAGS.db_preset == 'full_dbs'))
 
   run_multimer_system = 'multimer' in FLAGS.feature_mode
-  add_species = 'species' in FLAGS.feature_mode
+  add_species = any(x in FLAGS.feature_mode for x in ['species', 'fullpdb'])
   print(f"add_species is {add_species}")
   _check_flag('pdb70_database_path', 'feature_mode',
-              should_be_set=not run_multimer_system)
+              should_be_set=not (run_multimer_system
+                or 'fullpdb' in FLAGS.feature_mode))
   _check_flag('pdb_seqres_database_path', 'feature_mode',
-              should_be_set=run_multimer_system)
+              should_be_set=(run_multimer_system
+                        or 'fullpdb' in FLAGS.feature_mode))
   _check_flag('uniprot_database_path', 'feature_mode',
          should_be_set=(run_multimer_system or add_species))
 
@@ -206,7 +212,7 @@ def main(argv):
     raise ValueError('All FASTA paths must have a unique basename.')
 
 
-  if run_multimer_system:
+  if run_multimer_system or 'fullpdb' in FLAGS.feature_mode:
     template_searcher = hmmsearch.Hmmsearch(
         binary_path=FLAGS.hmmsearch_binary_path,
         hmmbuild_binary_path=FLAGS.hmmbuild_binary_path,
@@ -235,7 +241,8 @@ def main(argv):
   # else:
   #   mono_uniprot_database_path = FLAGS.uniref90_database_path
 
-  monomer_data_pipeline = pipeline.DataPipeline(
+  if FLAGS.db_preset != 'uniprot':
+    monomer_data_pipeline = pipeline.DataPipeline(
       jackhmmer_binary_path=FLAGS.jackhmmer_binary_path,
       hhblits_binary_path=FLAGS.hhblits_binary_path,
       uniref90_database_path=FLAGS.uniref90_database_path,
@@ -249,6 +256,15 @@ def main(argv):
       use_small_bfd=use_small_bfd,
       use_precomputed_msas=FLAGS.use_precomputed_msas,
       add_species=add_species)
+  else:
+    monomer_data_pipeline = pipeline_uniprot.DataPipeline(
+      jackhmmer_binary_path=FLAGS.jackhmmer_binary_path,
+      hhblits_binary_path=FLAGS.hhblits_binary_path,
+      uniprot_database_path=FLAGS.uniprot_database_path,
+      template_searcher=template_searcher,
+      template_featurizer=template_featurizer,
+      use_precomputed_msas=FLAGS.use_precomputed_msas,
+      add_species=add_species)
 
   if run_multimer_system:
     data_pipeline = pipeline_multimer.DataPipeline(
diff --git a/src/run_af2c_min.py b/src/run_af2c_min.py
index bf06d2c839699f5e4c77cb13d24438aaed90f73d..0ff4519677a2cc27b55f17ccb733bf56be3ee0ca 100644
--- a/src/run_af2c_min.py
+++ b/src/run_af2c_min.py
@@ -11,6 +11,7 @@ import os
 import pickle
 import re
 import time
+import json
 
 from absl import app
 from absl import flags
@@ -78,7 +79,8 @@ def main(argv):
     output_dir = os.path.join(FLAGS.output_dir, target_name)
     if not os.path.exists(output_dir):
       os.makedirs(output_dir)
-
+  
+    relax_metrics = {}
     for afile in os.listdir(input_dir):
       # find all unrelaxed pdb files, ignore ones with 'relaxed' as prefix
       if not afile.endswith(".pdb") or afile.startswith("relaxed_"):
@@ -96,6 +98,11 @@ def main(argv):
         # Relax the prediction.
         t_0 = time.time()
         relaxed_pdb_str, log, _ = amber_relaxer.process(prot=unrelaxed_protein)
+        relaxed_pdb_str, _, violations = amber_relaxer.process(prot=unrelaxed_protein)
+        relax_metrics[f'relaxed_{unrelaxed_pdb_file}'] = {
+            'remaining_violations': violations,
+            'remaining_violations_count': sum(violations)
+        }       
         relaxation_time = time.time() - t_0
 
         # Fix residue index, and ignore hydrogen atoms
@@ -107,7 +114,12 @@ def main(argv):
         with open(relaxed_output_path, 'w') as f:
           f.write(relaxed_pdb_str)
 
-        logging.info(f"{target_name} relaxation done, time spent {relaxation_time:.1f} seconds, Efinal {log['final_energy']:.2f}, rmsd {log['rmsd']:.2f}")
+        logging.info(f"{target_name} relaxation done, time spent {relaxation_time:.1f} seconds, "
+            f"Efinal {log['final_energy']:.2f}, rmsd {log['rmsd']:.2f}")
+
+    relax_metrics_path = os.path.join(output_dir, 'relax_metrics.json')
+    with open(relax_metrics_path, 'w') as f:
+      f.write(json.dumps(relax_metrics, indent=4))
 
 if __name__ == '__main__':
   flags.mark_flags_as_required([
diff --git a/src/run_af2c_mod.py b/src/run_af2c_mod.py
index 2f6f8fc74db91954482e873c633ab623d02fa73b..0ac472b0356c335054260d1c330cd25c0e7d751f 100644
--- a/src/run_af2c_mod.py
+++ b/src/run_af2c_mod.py
@@ -20,7 +20,7 @@
 #
 # Note: AF2Complex is a modified, enhanced version of AlphaFold 2.
 # Mu Gao and Davi Nakajima An
-# Georgia Institute of Technology, 2021-2022
+# Georgia Institute of Technology, 2021-2023
 #
 """AF2Complex: protein complex structure prediction with deep learning"""
 import json
@@ -116,6 +116,10 @@ flags.DEFINE_enum('msa_pairing', None,
 flags.DEFINE_boolean('do_cluster_analysis', False, 'Whether to print out clusters of protein chains in the prediction')
 flags.DEFINE_integer('cluster_edge_thres', 10, 'The number of contacts between chains that constitute an edge in the '
                   'cluster analysis', lower_bound=0)
+flags.DEFINE_float('pdb_iscore_cf', -1.0, 'If interface icore is present, only write the pdb of the structural model '
+                    'if the iScore is larger than this cutoff value. Useful for large-scale screening. ')
+flags.DEFINE_boolean('allow_dropout', False, 'Allow dropout during model inference. This is an experimental feature. '
+                     'Default is disabled.')
 
 FLAGS = flags.FLAGS
 
@@ -274,7 +278,7 @@ def predict_structure(
       # Get mean pLDDT confidence metric.
       plddt = np.mean(result['plddt'])
       plddts[log_model_name] = round(plddt, 2)
-      ptm = 0
+      ptm = 0; iptm = 0
       if 'ptm' in result:
           ptm = result['ptm'].tolist()
           ptms[log_model_name] = round(ptm, 4)
@@ -299,16 +303,20 @@ def predict_structure(
           print(f"Info: {target_name} {log_model_name}, ",
             f"tol = {tol:5.2f}, pLDDT = {plddt:.2f}, pTM-score = {ptm:.4f}", end='')
           if len(monomers) > 1 or monomers[0]['copy_number'] > 1: # hetero- or homo-oligomer target
-            print(f", piTM-score = {pitm:.4f}, interface-score = {inter_sc:.4f}", end='')
-            print(f", iRes = {inter_residues:<4d} iCnt = {inter_contacts:<4.0f}")
+            print(f", piTM-score = {pitm:.4f}, interface-score = {inter_sc:.4f}",
+              f", iRes = {inter_residues:<4d} iCnt = {inter_contacts:<4.0f}")
           else:
             print('')
       else:
           print(f"Info: {target_name} {log_model_name} performed {tot_recycle} recycles,",
-            f"final tol = {tol_value:.2f}, pLDDT = {plddt:.2f}, pTM-score = {ptm:.4f}", end='')
-          if len(monomers) > 1 or monomers[0]['copy_number'] > 1:
-            print(f", piTM-score = {pitm:.4f}, interface-score = {inter_sc:.4f}", end='')
-            print(f", iRes = {inter_residues:<4d} iCnt = {inter_contacts:<4.0f}")
+            f"final tol = {tol_value:.2f}, pLDDT = {plddt:.2f}", end='')
+          if 'iptm+ptm' in result:
+              print(f", iptm+ptm = {iptm:.4f}", end='')
+          else:
+              print(f", pTM-score = {ptm:.4f}", end='')
+          if len(monomers) > 1 or monomers[0]['copy_number'] > 1:            
+            print(f", piTM-score = {pitm:.4f}, interface-score = {inter_sc:.4f}",
+                f", iRes = {inter_residues:<4d} iCnt = {inter_contacts:<4.0f}")
             if 'cluster_analysis' in result:
               clus_res = result['cluster_analysis']
               idx2chain_name = get_asymid2chain_name(target)
@@ -326,12 +334,13 @@ def predict_structure(
             print('')
 
       # Save the model outputs, not saving pkl for intermeidate recycles to save storage space
-      if (recycle_index == tot_recycle and flags.output_pickle) or FLAGS.save_recycled == 3:
+      # skip saving pkl if iScore less than the specified cutoff
+      if ((recycle_index == tot_recycle and flags.output_pickle) or FLAGS.save_recycled == 3) and inter_sc > FLAGS.pdb_iscore_cf:
           result_output_path = os.path.join(out_dir, f'{log_model_name}.pkl')
           with open(result_output_path, 'wb') as f:
               pickle.dump(result, f, protocol=4)
 
-      if recycle_index == tot_recycle or FLAGS.save_recycled >= 2:
+      if (recycle_index == tot_recycle or FLAGS.save_recycled >= 2) and inter_sc > FLAGS.pdb_iscore_cf:
           # Set the b-factors to the per-residue plddt
           final_atom_mask = result['structure_module']['final_atom_mask']
           b_factors = result['plddt'][:, None] * final_atom_mask
@@ -469,6 +478,11 @@ def main(argv):
       model_config.model.recycle_tol = recycle_tol
       model_config.model.save_recycled = FLAGS.save_recycled
 
+      # allow drop out for model inference, this is an advanced feature only for expert users
+      if FLAGS.allow_dropout:
+          print("Info: allow dropout for model inference")
+          model_config.model.global_config.eval_dropout = True
+
       model_params = data.get_model_haiku_params(
           model_name=model_name, data_dir=FLAGS.data_dir)
       model_runner = model.RunModel(model_config, model_params)
diff --git a/src/setup.py b/src/setup.py
index 762bd7be8f03bdb6142493dcd8a419dcd066939d..753f70097db6c57489ba07aa30e8830ff6e17634 100644
--- a/src/setup.py
+++ b/src/setup.py
@@ -18,7 +18,7 @@ from setuptools import setup
 
 setup(
     name='alphafold',
-    version='2.0.0',
+    version='2.3.1',
     description='An implementation of the inference pipeline of AlphaFold v2.0.'
     'This is a completely new model that was entered as AlphaFold2 in CASP14 '
     'and published in Nature.',
@@ -38,10 +38,14 @@ setup(
         'jax',
         'ml-collections',
         'numpy',
+        'pandas',
         'scipy',
-        'tensorflow',
+        'tensorflow-cpu',
+    ],
+    tests_require=[
+        'matplotlib',  # For notebook_utils_test.
+        'mock',
     ],
-    tests_require=['mock'],
     classifiers=[
         'Development Status :: 5 - Production/Stable',
         'Intended Audience :: Science/Research',
@@ -49,6 +53,9 @@ setup(
         'Operating System :: POSIX :: Linux',
         'Programming Language :: Python :: 3.6',
         'Programming Language :: Python :: 3.7',
+        'Programming Language :: Python :: 3.8',
+        'Programming Language :: Python :: 3.9',
+        'Programming Language :: Python :: 3.10',
         'Topic :: Scientific/Engineering :: Artificial Intelligence',
     ],
 )
diff --git a/tools/run_interface_score.py b/tools/run_interface_score.py
index 3675be76e1253ca5964c08090c8851e1ded8c113..0b4611a3c10073a40261327cc9f567f5117fd9f0 100644
--- a/tools/run_interface_score.py
+++ b/tools/run_interface_score.py
@@ -1,15 +1,24 @@
 """Calculate interface score using pickle output file generated by AF2Complex"""
+#
+# Mu Gao and Davi Nakajima An
+# Georgia Institute of Technology
+#
 import os
 import sys
-
-sys.path.append('../src/')
 import pickle
 import re
 import json
+import time
+import logging
+
 from absl import app
 from absl import flags
 from tqdm import tqdm
 
+parent_dir = os.path.dirname( os.path.dirname(os.path.realpath(__file__)) )
+af_dir = os.path.join(parent_dir, 'src')
+sys.path.append(af_dir)
+
 from run_af2c_mod import FLAGS, get_asymid2chain_name
 
 from alphafold.data.complex import make_complex_features
@@ -22,6 +31,13 @@ import alphafold.data.complex as af2c
 import numpy as np
 # Internal import (7716).
 
+flags.DEFINE_string('model_str', 'model_', 'Only relax a model with a specified '
+                    'string, e.g., ranked_top1. The default will process all model_* pdb files')
+flags.DEFINE_float('interface_dist_thres', 4.5, 'The distance threshold in Angstrom for interface residues. '
+                    'If a heavy atom of residue i in chain A is less than this distance from '
+                    'another heavy atom of residue j in chain B, residues i j are interface residues.'
+                    'This is only used to calculate the confidence metrics such as the interface score.', 
+                    lower_bound=3.5)
 
 def get_asym_id(target, flags):
   """Defines the sequence of preprocessing steps to get the asym_id feature
@@ -82,7 +98,12 @@ def main(argv):
       asym_id = feature_dict['asym_id']
 
     for pkl_file in os.listdir(target_dir):
-      if ".pkl" in pkl_file and pkl_file.startswith("model_"):
+      # find all pickle files
+      if not pkl_file.endswith(".pkl"):
+          continue
+
+      if FLAGS.model_str in pkl_file:
+        t_0 = time.time()
         model_name = os.path.basename( pkl_file ).split(".")[0]
         model_config = config.model_config(pkl_file[:7])
         breaks = np.linspace(
@@ -103,16 +124,18 @@ def main(argv):
           result['structure_module']['final_atom_positions'],
           result['structure_module']['final_atom_mask'],
           super_asym_id,
+          distance_threshold=FLAGS.interface_dist_thres,
           is_probs=True)
 
         ptm = result['ptm'].tolist()
         pitm = result['pitm']['score'].tolist()
+        iptm = result['iptm+ptm'].tolist()
 
         inter_sc = res['score'].tolist()
         inter_residues = res['num_residues'].tolist()
         inter_contacts = res['num_contacts'].tolist()
 
-        print(f"Info: {target_name} (chains: {full_name}) {model_name}  pTM-score = {ptm:.4f}, ",
+        print(f"Info: {target_name} (chains: {full_name}) {model_name}  iptm+ptm = {iptm:.4f}, ",
             f"piTM-score = {pitm:.4f}, iRes = {inter_residues:<4d}, iCnt = {inter_contacts:<4.0f}, interface-score = {inter_sc:.4f}",)
 
         if FLAGS.do_cluster_analysis:
@@ -129,6 +152,8 @@ def main(argv):
 
           print(f"Info: num_clusters = {clus_res['num_clusters']}, cluster_sizes = {clus_res['cluster_size']}, ",
               f"clusters = {cluster_identities}\n")
+          
+        logging.info('Interface score calculation time spent: %.1f seconds', time.time() - t_0)
 
         '''
         fields   = model_name.split('_')