diff --git a/README.md b/README.md
index a97ac3dc70feaaf8bb4967848818779de82c168f..bf984a205d7ce38ec4ee213460350fea71494b82 100644
--- a/README.md
+++ b/README.md
@@ -13,6 +13,11 @@ You may test examples of AF2Complex or explore protein-protein interactions with
 
 ## Updates and Features
 
+#### Version 1.4.0 (2023-01-29)
+
+- Support AF-Multimer v3 model weights (based on AF v2.3.1)
+- More options for input feature generation
+
 #### Version 1.3.0 (2022-08-30)
 
 - Google Colab notebook access
diff --git a/example/example1.sh b/example/example1.sh
index c5d561803134aa978ceb9d5dfccb5e70b5913148..b0edc922a0fc12303d736772fe8a2b2d60c41f6d 100755
--- a/example/example1.sh
+++ b/example/example1.sh
@@ -1,7 +1,7 @@
 #!/bin/bash
 #
 # An example script of an AF2Complex run for predicting structural models
-# of a multimeric target using AF-Multimer v2 model weights in fully paired MSA mode
+# of a multimeric target using AF-Multimer model weights in fully paired MSA mode
 
 # You need to take care of these two items, which are dependent on your installation.
 # 1) activate your conda environment for AlphaFold if you use conda
@@ -18,9 +18,9 @@ out_dir=af2c_mod # model output files will be under $out_dir/$target
 ### This preset defined the number of recycles, ensembles, MSA cluster sizes (for monomer_ptm models)
 preset=economy # up to 6 recycles, 1 ensemble.
 
-### Choose neural network model(s) from ['model_1/2/3/4/5_multimer', 'model_1/2/3/4/5_multimer_v2', or 'model_1/2/3/4/5_ptm']
+### Choose neural network model(s) from ['model_1/2/3/4/5_multimer_v3', 'model_1/2/3/4/5_multimer_v2', or 'model_1/2/3/4/5_ptm']
 # Using two AF2 multimer model released in alphafold2 version 2.2.0
-model=model_1_multimer_v2,model_3_multimer_v2
+model=model_1_multimer_v3,model_3_multimer_v3
 
 ### Choose model_preset from: ['monomer_ptm', 'multimer', 'multimer_np']
 # Notes:
@@ -29,7 +29,7 @@ model=model_1_multimer_v2,model_3_multimer_v2
 # - multimer: apply multimer DL model to paired MSA pre-generated by AlphaFold-Multimer's official data pipeline
 #
 # You must specify approriate model names compatible with the model preset you choose.
-# E.g., mnomer_ptm for model_x_ptm, and multimer_np for model_x_multimer_v2
+# E.g., monomer_ptm for model_x_ptm, and multimer_np for model_x_multimer_v3
 #
 model_preset=multimer_np
 msa_pairing=all # will assemble msa pairing using monoermic features generated with af2complex feature procedure
diff --git a/example/example2.sh b/example/example2.sh
index f725defa5829b6542b4f7de5e031707935d46b03..20befe6341e351a95e99bf7ab10cc91c955698b6 100755
--- a/example/example2.sh
+++ b/example/example2.sh
@@ -18,7 +18,7 @@ out_dir=af2c_mod # model output files will be under $out_dir/$target
 ### This preset defined the number of recycles, ensembles, MSA cluster sizes (for monomer_ptm models)
 preset=economy # up to 3 recycles, 1 ensemble.
 
-### Choose neural network model(s) from ['model_1/2/3/4/5_multimer', 'model_1/2/3/4/5_multimer_v2', or 'model_1/2/3/4/5_ptm']
+### Choose neural network model(s) from ['model_1/2/3/4/5_multimer_v3', 'model_1/2/3/4/5_multimer_v2', or 'model_1/2/3/4/5_ptm']
 # Using AF2 monomer_ptm model released in alphafold2 version 2.0.1
 model=model_1_ptm,model_3_ptm
diff --git a/example/example3.sh b/example/example3.sh
index d8e0b85191b524d4440406389bd25d2e079aea9a..74ebb91f236083625589a96cb5a7db6c95dcfec2 100755
--- a/example/example3.sh
+++ b/example/example3.sh
@@ -25,9 +25,9 @@ out_dir=af2c_mod # model output files will be under $out_dir/$target
 ### This preset defined the number of recycles, ensembles, MSA cluster sizes (for monomer_ptm models)
 preset=economy # up to 3 recycles, 1 ensemble.
 
-### Choose neural network model(s) from ['model_1/2/3/4/5_multimer', 'model_1/2/3/4/5_multimer_v2', or 'model_1/2/3/4/5_ptm']
-# Using AF2 multimer model released in alphafold2 version 2.1.1
-model=model_1_multimer_v2
+### Choose neural network model(s) from ['model_1/2/3/4/5_multimer_v3', 'model_1/2/3/4/5_multimer_v2', or 'model_1/2/3/4/5_ptm']
+# Using AF2 multimer model released in alphafold2 version 2.3.1
+model=model_1_multimer_v3
 
 ### Choose model_preset from: ['monomer_ptm', 'multimer', 'multimer_np']
 # Notes:
@@ -36,7 +36,7 @@ model=model_1_multimer_v2
 # - multimer: apply multimer DL model to paired MSA pre-generated by AlphaFold-Multimer's official data pipeline
 #
 # You must specify approriate model names compatible with the model preset you choose.
-# E.g., mnomer_ptm for model_x_ptm, and multimer_np for model_x_multimer_v2
+# E.g., monomer_ptm for model_x_ptm, and multimer_np for model_x_multimer_v3
 
 model_preset=multimer_np
 recycling_setting=1 # output metrics but not saving pdb files during intermediate recycles
diff --git a/example/example4.sh b/example/example4.sh
index c27cd2a0dd05637b51f72568138d0a747d358b80..a19b3f14a0d93f7cdb129a9e9c1edb393d1b40e7 100755
--- a/example/example4.sh
+++ b/example/example4.sh
@@ -19,7 +19,7 @@ out_dir=af2c_mod # model output directory, s.t. output files will be on $out_dir
 ### This preset defined the number of recycles, ensembles, MSA cluster sizes (for monomer_ptm models)
 preset=expert # up to 20 recycles, 1 ensemble.
 
-### Choose neural network model(s) from ['model_1/2/3/4/5_multimer', 'model_1/2/3/4/5_multimer_v2', or 'model_1/2/3/4/5_ptm']
+### Choose neural network model(s) from ['model_1/2/3/4/5_multimer_v3', 'model_1/2/3/4/5_multimer_v2', or 'model_1/2/3/4/5_ptm']
 # Using AF2 monomer_ptm model released in alphafold2 version 2.0.1
 model=model_5_ptm
diff --git a/example/examples.sh b/example/examples.sh
index 2a91bddc260bc049a4440077a42f94a8df519470..bf800a67e2ae09c10e98f3998f47dfffa55f0a11 100755
--- a/example/examples.sh
+++ b/example/examples.sh
@@ -19,7 +19,7 @@ out_dir=af2c_mod # model output directory, s.t. output files will be on $out_dir
 preset=economy # up to 3 recycles, 1 ensemble.
 
 # these two options can be overwritten in examples.lst
-model=model_1_multimer_v2
+model=model_1_multimer_v3
 model_preset=multimer_np
 
diff --git a/example/run_af2comp.sh b/example/run_af2comp.sh
index 79fc8ea421c53c6a8dd56f50ccf13fb61b284d5e..9af080b3b464e363140297b810eb25b9da639e4e 100755
--- a/example/run_af2comp.sh
+++ b/example/run_af2comp.sh
@@ -16,12 +16,12 @@ out_dir=af2c_mod # model output directory, $out_dir/$target
 ### This preset defined the number of recycles, ensembles, MSA cluster sizes (for monomer_ptm models)
 preset=super # up to 20 recycles, 1 ensemble.
 
-### Choose neural network model(s) from ['model_1/2/3/4/5_multimer', 'model_1/2/3/4/5_multimer_v2', or 'model_1/2/3/4/5_ptm']
+### Choose neural network model(s) from ['model_1/2/3/4/5_multimer_v3', 'model_1/2/3/4/5_multimer_v2', or 'model_1/2/3/4/5_ptm']
 # Using AF2 multimer model released in version 2.1.1
 #model=model_1_multimer,model_2_multime,rmodel_3_multimer,model_4_multimer,model_5_multimer
 # Using AF2 multimer model released in version 2.2.0
-#model=model_1_multimer_v2,model_2_multime_v2,rmodel_3_multimer_v2,model_4_multimer_v2,model_5_multimer_v2
-model=model_1_multimer_v2,model_2_multimer_v2
+#model=model_1_multimer_v3,model_2_multimer_v3,model_3_multimer_v3,model_4_multimer_v3,model_5_multimer_v3
+model=model_1_multimer_v3,model_2_multimer_v3
 
 # Using AF2 monomer_ptm model released in version 2.0.1
 #model=model_1_ptm,model_2_ptm,model_3_ptm,model_4_ptm,model_5_ptm
diff --git a/example/run_fea_gen.sh b/example/run_fea_gen.sh
index 38212bc72aebd7c33a4c078a4dca53ddcd0dcf51..760692f67c5671b9503de7bfe7c50677ac6f3152 100755
--- a/example/run_fea_gen.sh
+++ b/example/run_fea_gen.sh
@@ -8,7 +8,7 @@ DATA_DIR=$HOME/scratch/afold/data
 export HHLIB=$HOME/data/tools/hh-suite/build
 #export HMMER=$HOME/data/tools/hmmer-3.2.1/build
-export HMMER=/usr/local/pace-apps/spack/packages/0.13/linux-rhel7-cascadelake/intel-19.0.5/hmmer-3.2.1-sngcwm2qjzzxseh42cryf432role4on5
+export HMMER=/usr/local/pace-apps/spack/packages/linux-rhel7-x86_64/gcc-10.3.0/hmmer-3.3.2-y25u7humnxtqnaf7yc2il3misyze6pac
 export KALIGN=$HOME/data/tools/kalign_v2/kalign
 af_dir=../src
 
@@ -18,17 +18,26 @@ if [ $# -eq 0 ]
     exit 1
 fi
 fasta_path=$1
-out_dir=af2c_fea
+out_dir=af2c_fea_test
+
+# choices are "reduced_dbs", "full_dbs", "uniprot"
 db_preset='reduced_dbs'
 
-# choices are "monomer, monomer+species, multimer"
+# choices are "monomer, multimer, monomer+species, monomer+fullpdb"
 # Option "monomer" and "multimer" follows alphafold official datapipeline for monomeric and
 # multimeric structure predictions, respectively.
+#
 # Option "monomer+species" is a modified monomeric pipeline such as the species information
 # is recorded for MSA pairing using only monomeric input features. This option is recommended.
-feature_mode='monomer+species'
+#feature_mode='monomer+species'
+#
+# Option "monomer+fullpdb": in addition to recording species information, it applies the
+# multimer template pipeline rather than the template pipeline of the original monomer
+# modeling. The multimer template pipeline searches the full PDB for templates, which is
+# more comprehensive than the monomer template pipeline.
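+#
+# A hypothetical invocation for reference (the FASTA path below is illustrative,
+# not a file shipped with this change):
+#   ./run_fea_gen.sh my_target.fasta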
+feature_mode='monomer+fullpdb'
 
-max_template_date=2020-05-15 # CASP14 starting date
+#max_template_date=2020-05-15 # CASP14 starting date
+max_template_date=$(date +"%Y-%m-%d") # current date
 
 echo "Info: sequence file is $fasta_path"
@@ -41,7 +50,7 @@ echo "Info: max_template_date is $max_template_date"
 
 ##########################################################################################
-if [ "$feature_mode" = "multimer" ]; then
+if [ "$feature_mode" = "multimer" ] || [ "$feature_mode" = "monomer+fullpdb" ]; then
    python $af_dir/run_af2c_fea.py --fasta_paths=$fasta_path --db_preset=$db_preset \
      --data_dir=$DATA_DIR --output_dir=$out_dir \
      --uniprot_database_path=$DATA_DIR/uniprot/uniprot.fasta \
@@ -58,7 +67,8 @@ if [ "$feature_mode" = "multimer" ]; then
      --hmmsearch_binary_path=$HMMER/bin/hmmsearch \
      --hmmbuild_binary_path=$HMMER/bin/hmmbuild \
      --kalign_binary_path=$KALIGN \
-      --feature_mode=$feature_mode
+      --feature_mode=$feature_mode \
+      --use_precomputed_msas=True
 elif [ "$feature_mode" = "monomer+species" ]; then
    python $af_dir/run_af2c_fea.py --fasta_paths=$fasta_path --db_preset=$db_preset \
      --data_dir=$DATA_DIR --output_dir=$out_dir \
@@ -94,5 +104,6 @@ else
      --hmmsearch_binary_path=$HMMER/bin/hmmsearch \
      --hmmbuild_binary_path=$HMMER/bin/hmmbuild \
      --kalign_binary_path=$KALIGN \
-      --feature_mode=$feature_mode
+      --feature_mode=$feature_mode \
+      --use_precomputed_msas=True
 fi
diff --git a/example/run_relaxation.sh b/example/run_relaxation.sh
index a814707d9bc67601319bf9da1f87223eecd7fd2c..966bb674ea0168f69a440c87192cba264ed86866 100755
--- a/example/run_relaxation.sh
+++ b/example/run_relaxation.sh
@@ -19,3 +19,4 @@ python -u $af_dir/run_af2c_min.py \
   --target_lst_path=$target_lst_file \
   --output_dir=$out_dir \
   --input_dir=$inp_dir \
+  --use_gpu_relax
diff --git a/example/targets/test.lst b/example/targets/test.lst
index 828b3176dbadbf513ff7f416afde4ddece724c8f..34973a029b211861fe5fdfb90fdbb22b23d7a506 100644
--- a/example/targets/test.lst
+++ b/example/targets/test.lst
@@ -1,4 +1,4 @@
 ##Target(components) Size(AAs) Name(for output)
 #T1065s1/T1065s2 225 H1065
-T1072s1:2+T1072s2:2 340 H1072 ./H1072_adj_list.txt
-#T1060s3:12 1680 H1060v4
\ No newline at end of file
+#T1072s1:2+T1072s2:2 340 H1072 ./H1072_adj_list.txt
+T1060s3:12 1680 H1060v4
diff --git a/notebook/AF2Complex_notebook.ipynb b/notebook/AF2Complex_notebook.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..06ef040ca5ec3d90842d25db376fa93bdb36dbb6
--- /dev/null
+++ b/notebook/AF2Complex_notebook.ipynb
@@ -0,0 +1,1023 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "sl1IbeCnW4zg"
+   },
+   "source": [
+    "# Workflow\n",
+    "\n",
+    "Step 1: Setup AF2Complex by running through the setup module\n",
+    "\n",
+    "Step 2: Pick one of the three Target Run modules and follow the steps in each module\n",
+    "\n",
+    "Step 3: Download your predictions\n",
+    "- After running AF2Complex on a target, the prediction location is printed out, as in the screenshot below. \n",
+    "\n",
+    "\n",
+    "\n",
+    "- You can then download these predictions by using Google Colab's file explorer, as the GIF below demonstrates:\n",
+    "\n",
+    "\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "iyrTK1-HJ6Gs"
+   },
+   "source": [
+    "# Setup AF2Complex\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "DuR67XCbLKod"
+   },
+   "outputs": [],
+   "source": [
+    "#@title 1. 
Download deep learning model parameters of AlphaFold 2\n", + "\n", + "#@markdown Please execute this cell by pressing the *Play* button on \n", + "#@markdown the left.\n", + "\n", + "# from google.colab import output\n", + "# output.enable_custom_widget_manager()\n", + "from IPython.utils import io\n", + "import os\n", + "import subprocess\n", + "import tqdm.notebook\n", + "from google.colab import output\n", + "import os\n", + "\n", + "output.enable_custom_widget_manager()\n", + "os.environ['TF_FORCE_UNIFIED_MEMORY'] = '1'\n", + "os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '2.0'\n", + "\n", + "#SOURCE_URL = \"https://storage.googleapis.com/alphafold/alphafold_params_2022-03-02.tar\"\n", + "SOURCE_URL = \"https://storage.googleapis.com/alphafold/alphafold_params_2022-12-06.tar\"\n", + "PARAMS_DIR = '/content/afold/data/params'\n", + "\n", + "PARAMS_PATH = os.path.join(PARAMS_DIR, os.path.basename(SOURCE_URL))\n", + "\n", + "TQDM_BAR_FORMAT = '{l_bar}{bar}| {n_fmt}/{total_fmt} [elapsed: {elapsed} remaining: {remaining}]'\n", + "try:\n", + " with tqdm.notebook.tqdm(total=100, bar_format=TQDM_BAR_FORMAT) as pbar:\n", + " with io.capture_output() as captured:\n", + "\n", + " if not os.path.exists(PARAMS_DIR):\n", + " %shell mkdir --parents \"{PARAMS_DIR}\"\n", + " %shell wget -O \"{PARAMS_PATH}\" \"{SOURCE_URL}\"\n", + " pbar.update(40)\n", + " \n", + " %shell tar --extract --verbose --file=\"{PARAMS_PATH}\" \\\n", + " --directory=\"{PARAMS_DIR}\" --preserve-permissions\n", + " %shell rm \"{PARAMS_PATH}\"\n", + " pbar.update(60)\n", + " else:\n", + " pbar.update(100)\n", + "except subprocess.CalledProcessError:\n", + " print(captured)\n", + " raise\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "zo7qqlklKQmd" + }, + "outputs": [], + "source": [ + "#@title 2. 
Install AF2Complex\n", + "\n", + "#@markdown Please execute this cell by pressing the _Play_ button \n", + "#@markdown \n", + "#@markdown This installs AF2Complex and the python packages it uses\n", + "\n", + "import os\n", + "import subprocess\n", + "\n", + "AF2C_examples = '/content/af2complex/example'\n", + "AF2C_src = '/content/af2complex/src'\n", + "AF_LIB_DIR = os.path.join(AF2C_src, 'alphafold')\n", + "UPLOAD_DIR = '/content/uploaded_feats/'\n", + "os.chdir('/content/')\n", + "\n", + "try:\n", + " with tqdm.notebook.tqdm(total=100, bar_format=TQDM_BAR_FORMAT) as pbar:\n", + " with io.capture_output() as captured:\n", + " if not os.path.exists('/content/af2complex'):\n", + " %shell git clone https://github.com/FreshAirTonight/af2complex.git\n", + " pbar.update(15)\n", + "\n", + " #Install third-party software\n", + " %shell pip uninstall -y tensorflow keras\n", + " pbar.update(5)\n", + " # Install py3dmol.\n", + " %shell pip install py3dmol\n", + " pbar.update(5)\n", + " %shell cd af2complex && pip install -r requirements.txt\n", + " pbar.update(50)\n", + " #%shell pip install --upgrade jax==0.2.14 jaxlib==0.1.69+cuda111 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html\n", + " %shell pip3 install --upgrade --no-cache-dir jax==0.3.25 jaxlib==0.3.25+cuda11.cudnn805 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html\n", + "\n", + " if not os.path.exists('/content/uploaded_feats/'):\n", + " %shell mkdir /content/uploaded_feats/\n", + " pbar.update(25)\n", + "except subprocess.CalledProcessError:\n", + " print(captured)\n", + " raise\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "-dkd1Qcz4254" + }, + "outputs": [], + "source": [ + "#@title 3. 
Define the configuration of your structure prediction run\n", + "#@markdown **Note**: Please re-run this cell if any variable below is changed\n", + "\n", + "#@markdown Choose a preset model configuration: \n", + "#@markdown - **deepmind**: standard settings according to DeepMind, i.e., 3 recycles and 1 ensemble; \n", + "#@markdown - **economy**: no ensemble, up to 256 MSA clusters, recycling up to 3 rounds; \n", + "#@markdown - **super/super2**: 1 or 2 ensembles, up to 512 MSA clusters, recycling up to 20 rounds; \n", + "#@markdown - **genome/genome2**: 1 or 2 ensembles, up to 512 MSA clusters, max number \n", + "#@markdown of recycles and ensembles adjusted according to input sequence length; \n", + "#@markdown - **expert**: similar to super but maintains the same recycle number regardless of target size; \n", + "#@markdown - **casp14**: 8 model ensemblings used by DeepMind in CASP14.\n", + "import numpy as np\n", + "DATA_DIR = '/content/afold/data/'\n", + "preset = 'economy' #@param ['deepmind', 'casp14', 'economy', 'super', 'expert', 'super2', 'genome', 'genome2']\n", + "\n", + "#@markdown Choose between the multimer_v3 and monomer_ptm AF parameter sets:\n", + "model_type = 'multimer_v3' #@param ['multimer_v3', 'monomer_ptm']\n", + "model_preset = {\n", + " 'multimer_v3': 'multimer_np',\n", + " 'monomer_ptm': 'monomer_ptm',\n", + " }[model_type]\n", + "if model_type == 'monomer_ptm':\n", + " model_type = 'ptm'\n", + "\n", + "#@markdown There are five different models you can choose from; check the ones you want to run (please check at least one) \n", + "param_set_1 = True #@param {type:\"boolean\"}\n", + "param_set_2 = False #@param {type:\"boolean\"}\n", + "param_set_3 = False #@param {type:\"boolean\"}\n", + "param_set_4 = False #@param {type:\"boolean\"}\n", + "param_set_5 = False #@param {type:\"boolean\"}\n", + "\n", + "param_set_nums = [param_set_1,param_set_2,param_set_3,param_set_4,param_set_5]\n", + "assert np.any(param_set_nums), 'Please check one of the param_sets '\n", + "models = []\n", + "for i, param_set in enumerate(param_set_nums):\n", + " if param_set:\n", + " models.append(f\"model_{i+1}_{model_type}\")\n", + "\n", + "#@markdown Choose your recycling setting:\n", + "#@markdown 0. no recycle info saving \n", + "#@markdown 1. print metrics of intermediate recycles\n", + "#@markdown 2. additionally save pdb structures of all recycles \n", + "#@markdown 3. additionally save all results in pickle\n", + "recycling_setting=\"1\" #@param [0, 1, 2, 3]\n", + "\n", + "#@markdown Input below how many predictions (each with a different random seed) will be \n", + "#@markdown generated per model. \n", + "\n", + "#@markdown E.g. if this is 2 and there are 5\n", + "#@markdown models then there will be 10 predictions per input. \n", + "num_predictions_per_model=1 #@param {type:\"integer\"}\n", + "\n", + "#@markdown Input below the maximum number of recycles. 
Leave as -1 if you don't want to limit the number of recycles.\n", + "max_recycles = 4 #@param {type: \"integer\"}\n", + "\n", + "\n", + "class dotdict(dict):\n", + " \"\"\"dot.notation access to dictionary attributes\"\"\"\n", + " __getattr__ = dict.get\n", + " __setattr__ = dict.__setitem__\n", + " __delattr__ = dict.__delitem__\n", + "def make_default_flags():\n", + " return dotdict({\n", + " 'target_lst_path':None,\n", + " 'output_dir':'/content/af2complex/example/af2c_mod',\n", + " 'feature_dir':'/content/af2complex/example/af2c_fea',\n", + " 'model_names':None,\n", + " 'data_dir':DATA_DIR,\n", + " 'preset':'economy',\n", + " 'random_seed':None,\n", + " 'max_recycles':None,\n", + " 'num_ensemble':None,\n", + " 'max_msa_clusters':None,\n", + " 'max_extra_msa':None,\n", + " 'write_complex_features':False,\n", + " 'no_template':False,\n", + " 'output_pickle':True,\n", + " 'save_recycled':0,\n", + " 'checkpoint_tag':None,\n", + " 'max_mono_msa_depth':10000,\n", + " 'mono_msa_crop_size':5000,\n", + " 'max_template_hits':4,\n", + " 'model_preset':'monomer_ptm',\n", + " 'num_predictions_per_model':1,\n", + " 'msa_pairing':None,\n", + " 'do_cluster_analysis':False,\n", + " 'cluster_edge_thres':10,\n", + " })\n", + "FLAGS = make_default_flags()\n", + "FLAGS['preset'] = preset\n", + "FLAGS['model_preset'] = model_preset\n", + "FLAGS['model_names'] = models\n", + "FLAGS['save_recycled'] = recycling_setting\n", + "FLAGS['num_predictions_per_model'] = num_predictions_per_model\n", + "FLAGS['max_recycles'] = max_recycles\n", + "\n", + "def make_mod_params():\n", + " preset = FLAGS['preset'] \n", + " model_preset = FLAGS['model_preset'] \n", + " models = FLAGS['model_names'] \n", + " recycling_setting = FLAGS['save_recycled'] \n", + " target_lst_file = FLAGS['target_lst_file'] \n", + " msa_pairing = FLAGS['msa_pairing'] \n", + " out_dir = FLAGS['output_dir']\n", + " fea_dir = FLAGS['feature_dir']\n", + " num_predictions_per_model = FLAGS['num_predictions_per_model']\n", + " max_recycles = FLAGS['max_recycles']\n", + "\n", + " parameters = [\n", + " f'--target_lst_path={target_lst_file}',\n", + " f'--data_dir={DATA_DIR}',\n", + " f'--output_dir={out_dir}',\n", + " f'--feature_dir={fea_dir}',\n", + " f'--model_names={\",\".join(models)}',\n", + " f'--preset={preset}',\n", + " f'--model_preset={model_preset}',\n", + " f'--num_predictions_per_model={num_predictions_per_model}',\n", + " f'--save_recycled={recycling_setting}']\n", + " \n", + " if msa_pairing != 'none':\n", + " parameters.append(f'--msa_pairing={msa_pairing}')\n", + " if max_recycles > 0:\n", + " parameters.append(f'--max_recycles={max_recycles}')\n", + "\n", + " return ' '.join(parameters)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "7ZKu8W-rTzJ3" + }, + "outputs": [], + "source": [ + "#@title 4. 
Define relevant methods for visualization\n", + "\n", + "os.chdir(AF2C_src)\n", + "import py3Dmol\n", + "from alphafold.data.complex import make_complex_features\n", + "from alphafold.model import config\n", + "from alphafold.common import confidence\n", + "from alphafold.data.complex import initialize_template_feats\n", + "\n", + "import alphafold.data.complex as af2c\n", + "from run_af2c_mod import get_asymid2chain_name\n", + "import pickle\n", + "\n", + "import numpy as np\n", + "import re\n", + "import pandas as pd\n", + "from ipywidgets import interact, Dropdown\n", + "from google.colab import widgets\n", + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "import ipywidgets\n", + "from IPython.display import display\n", + "import pandas as pd\n", + "\n", + "\n", + "def show_pdb(pred_output_path, show_sidechains=False, show_mainchains=False):\n", + " view = py3Dmol.view(js='https://3dmol.org/build/3Dmol.js',)\n", + " view.addModel(open(pred_output_path,'r').read(),'pdb')\n", + " view.setStyle({'cartoon': {'colorscheme': 'chain'}})\n", + " if show_sidechains:\n", + " BB = ['C','O','N']\n", + " view.addStyle({'and':[{'resn':[\"GLY\",\"PRO\"],'invert':True},{'atom':BB,'invert':True}]},\n", + " {'stick':{'colorscheme':f\"WhiteCarbon\",'radius':0.3}})\n", + " view.addStyle({'and':[{'resn':\"GLY\"},{'atom':'CA'}]},\n", + " {'sphere':{'colorscheme':f\"WhiteCarbon\",'radius':0.3}})\n", + " view.addStyle({'and':[{'resn':\"PRO\"},{'atom':['C','O'],'invert':True}]},\n", + " {'stick':{'colorscheme':f\"WhiteCarbon\",'radius':0.3}}) \n", + " if show_mainchains:\n", + " BB = ['C','O','N','CA']\n", + " view.addStyle({'atom':BB},{'stick':{'colorscheme':f\"WhiteCarbon\",'radius':0.3}})\n", + "\n", + " view.zoomTo()\n", + " return view\n", + "\n", + "def get_asym_id(target, flags):\n", + " \"\"\"Defines the sequence of preprocessing steps to get the asym_id feature\n", + " Args:\n", + " target: dictionary with the items:\n", + " name: name of the multimer,\n", + " split: information about each monomer composing the multimer,\n", + " full: a string denoting all stoichiometry and domains of all monomers\n", + " composing the multimer to be modeled,\n", + " flags: variable containing inference configuration\n", + " Returns:\n", + " asym_id\n", + " \"\"\"\n", + " monomers = af2c.load_monomer_feature(target, flags)\n", + "\n", + " if flags.msa_pairing is not None:\n", + " for i in range(len(monomers)):\n", + " if 'deletion_matrix' in monomers[i]['feature_dict']:\n", + " monomers[i]['feature_dict']['deletion_matrix_int'] = monomers[i]['feature_dict']['deletion_matrix']\n", + " curr_input = {'monomers': monomers, 'target': target, 'flags': flags}\n", + "\n", + " curr_input = af2c.targeted_domain_cropping_mono(curr_input)\n", + " curr_input = af2c.add_asym_id_monomer_ptm(curr_input)\n", + " asym_id = curr_input['asym_id_mono_ptm']\n", + "\n", + " return asym_id\n", + "\n", + "def get_interface_score(\n", + " model_name, target_name, full_name, asym_id, idx2chain_name, out_dir, asym_id_list):\n", + " metric = []\n", + " value = []\n", + " pdb_path = os.path.join(out_dir, target_name, f'{model_name}.pdb')\n", + " pkl_path = os.path.join(out_dir, target_name, f'{model_name}.pkl')\n", + "\n", + " model_config = config.model_config(model_name[:7])\n", + " breaks = np.linspace(\n", + " 0., model_config.model.heads.predicted_aligned_error.max_error_bin,\n", + " model_config.model.heads.predicted_aligned_error.num_bins - 1)\n", + " try:\n", + " result = pickle.load(open(pkl_path, \"rb\"))\n", + " except 
(EOFError,IOError) as error:\n", + " print(f\"Warning: {target_name} {error} encountered, check the pickle file\")\n", + " raise\n", + "\n", + " super_asym_id, superid2chainids = confidence.join_superchains_asym_id(asym_id, asym_id_list)\n", + "\n", + " res = confidence.interface_score(\n", + " result['aligned_confidence_probs'],\n", + " breaks,\n", + " result['structure_module']['final_atom_positions'],\n", + " result['structure_module']['final_atom_mask'],\n", + " super_asym_id,\n", + " is_probs=True)\n", + "\n", + " ptm = result['ptm'].tolist()\n", + " pitm = result['pitm']['score'].tolist()\n", + "\n", + " inter_sc = res['score'].tolist()\n", + " inter_residues = res['num_residues'].tolist()\n", + " inter_contacts = res['num_contacts'].tolist()\n", + " metric.append('MODEL NAME')\n", + " value.append(model_name)\n", + " metric.append('TARGET CHAINS')\n", + " value.append(full_name)\n", + " metric.append('===========')\n", + " value.append('===========')\n", + " metric.append('pTM-score')\n", + " value.append(ptm)\n", + " metric.append('piTM-score')\n", + " value.append(pitm)\n", + " metric.append('iRes')\n", + " value.append(inter_residues)\n", + " metric.append('iCnt')\n", + " value.append(inter_contacts)\n", + " metric.append('interface-score')\n", + " value.append(inter_sc)\n", + "\n", + " if FLAGS.do_cluster_analysis:\n", + " clus_res = confidence.cluster_analysis(\n", + " super_asym_id,\n", + " result['structure_module']['final_atom_positions'],\n", + " result['structure_module']['final_atom_mask'],\n", + " edge_contacts_thres=FLAGS.cluster_edge_thres,\n", + " superid2chainids=superid2chainids,\n", + " )\n", + " cluster_identities = []\n", + " for cluster in clus_res['clusters']:\n", + " cluster_identities.append([idx2chain_name[c] for c in cluster])\n", + "\n", + " metric.append('num_clusters')\n", + " value.append(clus_res['num_clusters'])\n", + " metric.append('cluster_sizes')\n", + " value.append(clus_res['cluster_size'])\n", + " metric.append('clusters')\n", + " value.append(cluster_identities)\n", + " return pd.DataFrame({'Metric Name':metric, 'Value':value})\n", + "\n", + "DATA_DIR = '/content/afold/data' \n", + "def display_metrics(target_lst_path, model_path, support, show_sidechains_=True, show_mainchains_=True, ):\n", + " with io.capture_output() as captured:\n", + " target_lst = af2c.read_af2c_target_file(target_lst_path)\n", + " full_name = support[\"full_name\"] \n", + " target_name= support[\"target_name\"] \n", + " idx2chain_name= support[\"idx2chain_name\"]\n", + " asym_id_list= support[\"asym_id_list\"]\n", + " asym_id= support[\"asym_id\"] \n", + " model_name = os.path.basename(model_path)\n", + " pdb_path = os.path.join(FLAGS.output_dir, target_name, f'{model_name}.pdb')\n", + "\n", + " metrics = get_interface_score(\n", + " model_name, target_name, full_name, asym_id, idx2chain_name, FLAGS.output_dir, asym_id_list\n", + " )\n", + "\n", + " print(metrics.to_markdown())\n", + " view = show_pdb(pdb_path, \n", + " show_sidechains=show_sidechains_,\n", + " show_mainchains=show_mainchains_)\n", + " view.show()\n", + "\n", + "def visualize(show_sidechains, show_mainchains):\n", + " is_pdb = lambda x: '.pdb' in x \n", + " if not os.path.exists(FLAGS.target_lst_file):\n", + " raise FileNotFoundError(f'{FLAGS.target_lst_file} does not exist!')\n", + " target_lst = af2c.read_af2c_target_file(FLAGS.target_lst_file)\n", + " files = []\n", + " model2support = {}\n", + " for target in target_lst:\n", + " target_name = target['name']\n", + " target_name = re.sub(\":\", \"_x\", target_name)\n", 
+ " target_name = re.sub(\"/\", \"+\", target_name)\n", + " target_dir = os.path.join(FLAGS.output_dir, target_name)\n", + " if not os.path.exists(target_dir):\n", + " raise Exception(\n", + " f'No predictions for {target_name}. Predictions available are for {os.listdir(FLAGS.output_dir)}. Please make sure the inference cell was run correctly.'\n", + " )\n", + " target_files = os.listdir(target_dir)\n", + " if len(target_files) == 0:\n", + " raise Exception(\n", + " f'No predictions for {target_name}. Predictions available are for {os.listdir(FLAGS.output_dir)}. Please make sure the inference cell was run correctly.'\n", + " )\n", + " for f in target_files:\n", + " full_name = target['full']\n", + " idx2chain_name = get_asymid2chain_name(target)\n", + " asym_id_list = target['asym_id_list']\n", + " if not FLAGS.write_complex_features:\n", + " with io.capture_output() as captured:\n", + " asym_id = get_asym_id(target, FLAGS)\n", + " else:\n", + " feat_path = os.path.join(target_name, 'features_comp.pkl')\n", + " try:\n", + " feature_dict = np.load(open(feat_path, 'rb'))\n", + " except FileNotFoundError:\n", + " print('Did not find feature_comp.pkl file. ',\n", + " 'To rebuild complex features, run without ',\n", + " '--write_complex_features flag.')\n", + " asym_id = feature_dict['asym_id']\n", + " if is_pdb(f):\n", + " model_name = os.path.join(target_dir, target_name, f)[:-4]\n", + " files.append(model_name)\n", + " model2support[model_name] = {\n", + " 'full_name': full_name, \n", + " 'target_name': target_name, \n", + " 'idx2chain_name': idx2chain_name, \n", + " 'asym_id_list': asym_id_list, \n", + " 'asym_id': asym_id, \n", + " }\n", + "\n", + " tabs = widgets.TabBar(files)\n", + "\n", + " for i, model in enumerate(files): \n", + " with tabs.output_to(i):\n", + " display_metrics(\n", + " FLAGS.target_lst_file,\n", + " model,\n", + " model2support[model],\n", + " show_sidechains,\n", + " show_mainchains, \n", + " )\n", + "\n", + "def get_dataset_desc(file_path):\n", + " from google.colab import data_table\n", + " data_table.enable_dataframe_formatter()\n", + " with open(file_path, 'r') as f:\n", + " ecoli_txt = f.readlines()\n", + "\n", + " ids = []\n", + " genes = []\n", + " acs = []\n", + " fulllen = []\n", + " ranges = []\n", + " length = []\n", + " desc = []\n", + " for line in ecoli_txt:\n", + " if line.startswith('#'):\n", + " continue\n", + " line = line.split('\\t')\n", + " ids.append(line[0])\n", + " genes.append(line[1])\n", + " acs.append(line[2])\n", + " fulllen.append(line[3])\n", + " ranges.append(line[4])\n", + " length.append(line[5])\n", + " desc.append(line[6])\n", + "\n", + " df = pd.DataFrame({\n", + " 'ID': ids[1:],\n", + " 'Gene': genes[1:],\n", + " 'AC': acs[1:],\n", + " 'Full Length': fulllen[1:],\n", + " 'Range': ranges[1:],\n", + " 'Length': length[1:],\n", + " 'Description': desc[1:],\n", + " })\n", + " return df" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "EUsieqM-OsSJ" + }, + "source": [ + "# Target Run (AF2Complex Examples)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "1QFJ32_zOqJU" + }, + "outputs": [], + "source": [ + "import subprocess\n", + "import numpy as np\n", + "os.chdir(AF_LIB_DIR)\n", + "#@markdown #1. Choose one of the AF2Complex examples to run below! 
\n", + "#@markdown Note: After choosing your parameters below, press the play button to run the chosen example:\n", + "FLAGS['feature_dir'] = '/content/af2complex/example/af2c_fea' \n", + "FLAGS['output_dir'] = '/content/af2complex/example/af2c_mod'\n", + "\n", + "example = 'H1065' #@param ['H1065', 'H1072', 'H1072_H1065', 'H1060v4']\n", + "\n", + "target_lst_file = {\n", + " 'H1065': '/content/af2complex/example/targets/example1.lst',\n", + " 'H1072': '/content/af2complex/example/targets/example2.lst',\n", + " 'H1072_H1065': '/content/af2complex/example/targets/example3.lst',\n", + " 'H1060v4': '/content/af2complex/example/targets/example4.lst',\n", + "}[example]\n", + "\n", + "#@markdown Choose the type of msa pairing you want to use (Note: 'none' will do no msa_pairing, 'all' will do every possible species pairing as was done in AF-Multimer):\n", + "msa_pairing = 'none' #@param ['none', 'all', 'custom', 'cyclic', 'linear']\n", + "\n", + "FLAGS['target_lst_file'] = target_lst_file\n", + "FLAGS['msa_pairing'] = msa_pairing\n", + "\n", + "pred_params = make_mod_params()\n", + "\n", + "# with io.capture_output() as captured:\n", + "%shell python -u ../run_af2c_mod.py {pred_params}\n", + "print(f'DONE! (predictions available on {FLAGS.output_dir}' )" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "sd1bBQgTJLEA" + }, + "outputs": [], + "source": [ + "# %matplotlib inline\n", + "\n", + "#@markdown #2. Visualize your results below by pressing the *Play* button on the left\n", + "#@markdown Choose one of the AF2Complex examples to visualize below! \n", + "FLAGS['feature_dir'] = '/content/af2complex/example/af2c_fea' \n", + "FLAGS['output_dir'] = '/content/af2complex/example/af2c_mod'\n", + "\n", + "example = 'H1065' #@param ['H1065', 'H1072', 'H1072_H1065', 'H1060v4']\n", + "\n", + "target_lst_file = {\n", + " 'H1065': '/content/af2complex/example/targets/example1.lst',\n", + " 'H1072': '/content/af2complex/example/targets/example2.lst',\n", + " 'H1072_H1065': '/content/af2complex/example/targets/example3.lst',\n", + " 'H1060v4': '/content/af2complex/example/targets/example4.lst',\n", + "}[example]\n", + "\n", + "FLAGS['target_lst_file'] = target_lst_file\n", + "FLAGS['msa_pairing'] = msa_pairing\n", + "\n", + "pred_params = make_mod_params()\n", + "\n", + "show_sidechains = False #@param {type: 'boolean'}\n", + "show_mainchains = False #@param {type: 'boolean'}\n", + "\n", + "\n", + "AF2C_examples = '/content/af2complex/example'\n", + "AF2C_egtargets = os.path.join(AF2C_examples, 'targets')\n", + "\n", + "visualize(show_sidechains, show_mainchains)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "1s0RID2AYdpk" + }, + "source": [ + "# Target Run (within *E. coli*, using pre-generated features for the proteome!)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "ClAAecbPYrr0" + }, + "outputs": [], + "source": [ + "import os\n", + "#@title 1. 
Download the dataset from [Zenodo](https://zenodo.org/record/7008599#.YwFWR3bMJaQ)\n", + "#@markdown Note: Usually takes less than 20 minutes.\n", + "\n", + "AF2C_examples = '/content/af2complex/example'\n", + "os.chdir(AF2C_examples)\n", + "AF2C_ecoli = os.path.join(AF2C_examples, 'ecoli')\n", + "zenodo_link = 'https://zenodo.org/record/7008599/files/af2c_fea_ecoli_220331_msa10ktem10.tar?download=1'\n", + "\n", + "AF2C_ecoli_path = os.path.join(AF2C_ecoli, os.path.basename(zenodo_link))\n", + "\n", + "if not os.path.exists(AF2C_ecoli):\n", + " os.mkdir(AF2C_ecoli)\n", + "\n", + "%shell wget -O {AF2C_ecoli_path} {zenodo_link}\n", + " \n", + "with io.capture_output() as captured:\n", + " %shell tar --extract --verbose --file={AF2C_ecoli_path} \\\n", + " --directory={AF2C_ecoli} --preserve-permissions\n", + " \n", + "AF2C_ecoli_feas = AF2C_ecoli_path.split('.tar')[0]\n", + "txt_zenodo_link = \"https://zenodo.org/record/7008599/files/ecoli_af2c_fea.txt?download=1\"\n", + "AF2C_ecoli_txt_path = os.path.join(AF2C_ecoli, os.path.basename(txt_zenodo_link))\n", + "\n", + "%shell wget -O {AF2C_ecoli_txt_path} {txt_zenodo_link}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "IQuqyTcMMaS7" + }, + "outputs": [], + "source": [ + "import pandas as pd\n", + "\n", + "# @markdown # Find the genes in *E. coli* with pre-generated input features:\n", + "\n", + "ecoli_df = get_dataset_desc(AF2C_ecoli_txt_path)\n", + "ecoli_df" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "yNYa7QFWQPZF" + }, + "outputs": [], + "source": [ + "import subprocess\n", + "import numpy as np\n", + "# from run_af2c_mod import FLAGS\n", + "os.chdir(AF_LIB_DIR)\n", + "# print(FLAGS.fea_dir, AF2C_ecoli_feas)\n", + "FLAGS.feature_dir = AF2C_ecoli_feas\n", + "# print(FLAGS.fea_dir, AF2C_ecoli_feas)\n", + "FLAGS.output_dir = '/content/af2complex/example/ecoli/af2c_mod'\n", + "if not os.path.exists(FLAGS.output_dir):\n", + " os.mkdir(FLAGS.output_dir)\n", + "\n", + "#@markdown #2. Define your protein complex target using the UniProt IDs you found in the table above, then run AF2Complex using the *Play* button on the left\n", + "#@markdown Define how the chains compose the target,\n", + "#@markdown e.g.: \n", + "# #@markdown - T1065s1/T1065s2 *(Explanation on [example1](https://github.gatech.edu/gmu3/af2complex/tree/master/example#example-1) for more)*\n", + "# #@markdown - T1072s1:2/T1072s2:2 *(Explanation on [example2](https://github.gatech.edu/gmu3/af2complex/tree/master/example#example-2) for more)\n", + "# #@markdown - T1065s1/T1065s2+T1072s1:2/T1072s2:2 *(Explanation on [example3](https://github.gatech.edu/gmu3/af2complex/tree/master/example#example-3) for more)*\n", + "# #@markdown - T1060s3:12 *(Explanation on [example4](https://github.gatech.edu/gmu3/af2complex/tree/master/example#example-4) for more)*\n", + "\n", + "#@markdown - SECE/SECG/SECY *(SecYEG translocon, a hetero-trimer composed of SecE, SecG, and SecY, 680 AAs)*\n", + "#@markdown - PPID|265-359/DSBA|20-208 *(PpiD parvulin domain and DsbA, each has a residue ID range, 285 AAs)*\n", + "#@markdown - SURA|21-428/BAMA|21-420 *(surA and BamA, both with the signal peptide removed, 808 AAs)*\n", + "#@markdown - PPID/YFGM *(chaperone proteins PpiD and YfgM, 829 AAs)*\n", + "#@markdown - CCMA:2/CCMB:2/CCMC/CCMD/CCME *(CcmI system, 1327 AAs)*\n", + "#@markdown - YAJC:3 *(YajC, a membrane protein chaperone? 
330 AAs)*\n", + "\n", + "#@markdown Note that a large target may require resources beyond the free-tier.\n", + "\n", + "# chains = 'e.g. T1065s1/T1065s2' #@param {type:'string'}\n", + "chains = 'SECE/SECG/SECY' #@param {type:'string'}\n", + "\n", + "#@markdown Name your target\n", + "# target = 'e.g. H1065' #@param {type: 'string'}\n", + "target = 'SecYEG' #@param {type: 'string'}\n", + "\n", + "#@markdown Put down the total number of AA of the target (does not need to be exact as this number will be parsed but not used in the code)\n", + "num_AA = 680 #@param{type: 'integer'}\n", + "\n", + "#@markdown Choose the type of msa pairing you want to use (Note: 'none' will do no msa_pairing, 'all' will do species pairing as in AF-Multimer):\n", + "msa_pairing = 'all' #@param ['none', 'all', 'custom', 'cyclic', 'linear']\n", + "FLAGS.msa_pairing = msa_pairing\n", + "\n", + "target_lst = f'{chains} {num_AA} {target}'\n", + "\n", + "target_lst_file = os.path.join(AF2C_ecoli, f'{target}.lst')\n", + "with open(target_lst_file, 'w') as f:\n", + " f.write(target_lst)\n", + " f.close()\n", + "FLAGS.target_lst_file = target_lst_file\n", + "\n", + "pred_params = make_mod_params()\n", + "\n", + "# with io.capture_output() as captured:\n", + "%shell python -u ../run_af2c_mod.py {pred_params}\n", + "print(f'DONE! (predictions available on {FLAGS.output_dir}' )" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "MewHtxVmBxNS" + }, + "outputs": [], + "source": [ + "#@markdown # Press **Play** button to the left to see which targets you have predictions for so far\n", + "from ipywidgets import interact\n", + "pd.DataFrame({\n", + " 'Target Name': os.listdir(FLAGS.output_dir),\n", + " 'Number of Predictions': [\n", + " len(list(filter(lambda x: '.pdb' in x, os.listdir(os.path.join(FLAGS.output_dir,f )))))\n", + " for f in os.listdir(FLAGS.output_dir)],\n", + "})" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "9G2i6I84ZJql" + }, + "outputs": [], + "source": [ + "#@markdown #3. Visualize your results below by pressing the *Play* button on the left\n", + "\n", + "#@markdown Check out the cell above to see which targets have predictions. Input the target name below to visualize the proteins. \n", + "target_name = 'SecYEG' #@param {type: 'string'}\n", + "FLAGS.target_lst_file = os.path.join(AF2C_ecoli, f'{target_name}.lst')\n", + "if not os.path.exists(FLAGS.target_lst_file):\n", + " raise Exception(f' Target: predictions for {target_name} do not exist, run the cell above to see which targets have predictions')\n", + "\n", + "show_sidechains = False #@param {type: 'boolean'}\n", + "show_mainchains = False #@param {type: 'boolean'}\n", + "\n", + "visualize(show_sidechains, show_mainchains)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "tct8Uw1TW8RQ" + }, + "source": [ + "# Target Run (Upload your own features)\n", + "\n", + "First, create a folder with your features with the following file structure: \n", + "```\n", + "dataset_name\n", + "│ \n", + "└───chain_1\n", + "│ │ \n", + "│ └───features.pkl\n", + "└───chain_2\n", + "│ │ \n", + "│ └───features.pkl\n", + "...\n", + "```\n", + "Then, upload a .tar (or .tgz) file of this folder below (Section 1a). 
You can create a .tar file with the following Unix terminal command (please keep the folder name and the .tar file name the same): \n", + "\n", + "\n", + "```\n", + "tar -czf af2c_fea.tgz af2c_fea\n", + "```\n", + "\n", + "\n", + "Optionally, you can also upload a txt file describing the dataset to easily search through the dataset. It should look like the following (all tab separated):\n", + "```\n", + "### ID -- UniProt ID\n", + "### Gene -- Recommended gene name\n", + "### AC -- Accession ID\n", + "### Fulllen -- Full sequence length\n", + "### Range -- Residue range of the longest mature chain\n", + "### Len -- Sequence length\n", + "### Description -- description of the gene\n", + "ID\tGene\tAC\tFullLen\tRange\tLen\tDescription\n", + "3MG1\ttag\tP05100\t187\t1-187\t187\tDNA-3-methyladenine glycosylase 1\n", + "...\n", + "```" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "nnj1Vwg9KW3A" + }, + "outputs": [], + "source": [ + "#@markdown #1a. Upload your dataset here (press *Play* button to the left)\n", + "#@markdown Please upload only one dataset at a time\n", + "\n", + "from google.colab import files\n", + "os.chdir(UPLOAD_DIR)\n", + "print('Upload the .tar (or .tgz) file')\n", + "\n", + "uploaded = files.upload()\n", + "dset_file = list(uploaded.keys())[0]\n", + "dset_name = dset_file.split('.')[0]\n", + "dset_dir = os.path.join(UPLOAD_DIR, dset_name)\n", + "print(f\"INFO: Uploaded dataset with name: {dset_name}\")\n", + " \n", + "with io.capture_output() as captured:\n", + " %shell tar --extract --verbose --file=/content/uploaded_feats/{dset_file} \\\n", + " --directory={UPLOAD_DIR} --preserve-permissions\n", + " " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "XgDHn7W2vEWX" + }, + "outputs": [], + "source": [ + "#@markdown #1b. Check out the dataset (Optional)\n", + "#@markdown Upload the dataset description file\n", + "\n", + "os.chdir(UPLOAD_DIR)\n", + "print('Upload the .txt description file')\n", + "uploaded = files.upload()\n", + "desc_file = list(uploaded.keys())[0]\n", + "dset_df = get_dataset_desc(os.path.join(UPLOAD_DIR, desc_file))\n", + "dset_df" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "76o6KEjsy6R7" + }, + "outputs": [], + "source": [ + "import numpy as np\n", + "# from run_af2c_mod import FLAGS\n", + "os.chdir(AF_LIB_DIR)\n", + "# print(FLAGS.fea_dir, AF2C_ecoli_feas)\n", + "FLAGS.feature_dir = os.path.join(UPLOAD_DIR, dset_name)\n", + "# print(FLAGS.fea_dir, AF2C_ecoli_feas)\n", + "FLAGS.output_dir = os.path.join(UPLOAD_DIR, dset_name, 'af2c_mod')\n", + "if not os.path.exists(FLAGS.output_dir):\n", + " os.mkdir(FLAGS.output_dir)\n", + "\n", + "#@markdown #2. Define your protein complex target and run AF2Complex on it using the *Play* button on the left\n", + "#@markdown Define how the chains compose the target, look at sections above for more information (Section 2 of *E. coli* target run)\n", + "\n", + "#@markdown Note that a large target may require resources beyond the free-tier.\n", + "\n", + "# chains = 'e.g. T1065s1/T1065s2' #@param {type:'string'}\n", + "chains = 'HgcA/HgcB' #@param {type:'string'}\n", + "\n", + "#@markdown Name your target\n", + "# target = 'e.g. 
H1065' #@param {type: 'string'}\n", + "target = 'HgcAB' #@param {type: 'string'}\n", + "\n", + "#@markdown Put down the total number of AA of the target (does not need to be exact as this number will be parsed but not used in the code)\n", + "num_AA = 433 #@param{type: 'integer'}\n", + "\n", + "#@markdown Choose the type of msa pairing you want to use (Note: 'none' will do no msa_pairing, 'all' will do species pairing as in AF-Multimer):\n", + "msa_pairing = 'all' #@param ['none', 'all', 'custom', 'cyclic', 'linear']\n", + "FLAGS.msa_pairing = msa_pairing\n", + "\n", + "target_lst = f'{chains} {num_AA} {target}'\n", + "\n", + "target_lst_file = os.path.join(dset_dir, f'{target}.lst')\n", + "with open(target_lst_file, 'w') as f:\n", + " f.write(target_lst)\n", + " f.close()\n", + "FLAGS.target_lst_file = target_lst_file\n", + "\n", + "pred_params = make_mod_params()\n", + "\n", + "# with io.capture_output() as captured:\n", + "%shell python -u ../run_af2c_mod.py {pred_params}\n", + "print(f'DONE! (predictions available on {FLAGS.output_dir}' )" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "01DWp6Sjk6pP" + }, + "outputs": [], + "source": [ + "#@markdown # Press **Play** button to the left to see which targets you have predictions for so far\n", + "from ipywidgets import interact\n", + "pd.DataFrame({\n", + " 'Target Name': os.listdir(FLAGS.output_dir),\n", + " 'Number of Predictions': [\n", + " len(list(filter(lambda x: '.pdb' in x, os.listdir(os.path.join(FLAGS.output_dir,f )))))\n", + " for f in os.listdir(FLAGS.output_dir)],\n", + "})\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "1rjkGMLYyVcU" + }, + "outputs": [], + "source": [ + "# %matplotlib inline\n", + "\n", + "#@markdown #3. Visualize your results below by pressing the *Play* button on the left\n", + "\n", + "#@markdown Place the target name you want to visualize below\n", + "# target = 'e.g. 
H1065' #@param {type: 'string'}\n", + "target_name = 'HgcAB' #@param {type: 'string'}\n", + "FLAGS.target_lst_file = os.path.join(dset_dir, f'{target_name}.lst')\n", + "if not os.path.exists(FLAGS.target_lst_file):\n", + " raise Exception(f' Target: predictions for {target_name} do not exist, run the cell above to see which targets have predictions')\n", + "\n", + "show_sidechains = False #@param {type: 'boolean'}\n", + "show_mainchains = False #@param {type: 'boolean'}\n", + "\n", + "visualize(show_sidechains, show_mainchains)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "7PN-BhHoTtCr" + }, + "outputs": [], + "source": [] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "private_outputs": true, + "provenance": [], + "toc_visible": true + }, + "gpuClass": "standard", + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "name": "python", + "version": "3.10.6 (main, Nov 14 2022, 16:10:14) [GCC 11.3.0]" + }, + "vscode": { + "interpreter": { + "hash": "e7370f93d1d0cde622a1f8e1c04877d8463912d04d973331ad4851f04de6915a" + } + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/requirements.txt b/requirements.txt index d0e90f3ca5a0b4389cfbdeea7d1fab989ce6447e..02ca4dce135d4f585d32d831ced550f3ec78afc2 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,13 +1,15 @@ -absl-py==0.13.0 +absl-py==1.0.0 biopython==1.79 chex==0.0.7 -dm-haiku==0.0.4 +dm-haiku==0.0.9 dm-tree==0.1.6 +docker==5.0.0 immutabledict==2.0.0 -jax==0.2.14 +jax==0.3.25 ml-collections==0.1.0 -numpy==1.19.5 +numpy==1.21.6 pandas==1.3.4 +protobuf==3.20.1 scipy==1.7.0 -tensorflow==2.5.0 +tensorflow-cpu==2.9.0 networkx==2.5 diff --git a/src/alphafold/common/confidence.py b/src/alphafold/common/confidence.py index ffad0384cfe7fc77b21860a3b02bf3c65f5a2068..68dca877f994d0976464fc8fb1144cc71bff36ba 100644 --- a/src/alphafold/common/confidence.py +++ b/src/alphafold/common/confidence.py @@ -396,7 +396,7 @@ def interface_score( atom_mask: np.ndarray, asym_id: np.ndarray, residue_weights: Optional[np.ndarray] = None, - distance_threshold: Optional[int] = 4.5, + distance_threshold: Optional[float] = 4.5, is_probs: Optional[bool] = False,) -> Dict[str, Union[np.ndarray, int]]: """ Returns the interface-score, number of residues in the interface, and number of contacts of a complex model. This is a further tweak from piTM by diff --git a/src/alphafold/common/residue_constants.py b/src/alphafold/common/residue_constants.py index 4318875a9b29636ebbe26a5dd743547b856e21a3..049c9a6df3757d8feded6e52402298008879f687 100644 --- a/src/alphafold/common/residue_constants.py +++ b/src/alphafold/common/residue_constants.py @@ -120,7 +120,7 @@ chi_pi_periodic = [ # 4,5,6,7: 'chi1,2,3,4-group' # The atom positions are relative to the axis-end-atom of the corresponding # rotation axis. The x-axis is in direction of the rotation axis, and the y-axis -# is defined such that the dihedral-angle-definiting atom (the last entry in +# is defined such that the dihedral-angle-defining atom (the last entry in # chi_angles_atoms above) is in the xy-plane (with a positive y-coordinate). 
# format: [atomname, group_idx, rel_position] rigid_group_atom_positions = { @@ -772,10 +772,10 @@ def _make_rigid_transformation_4x4(ex, ey, translation): # and an array with (restype, atomtype, coord) for the atom positions # and compute affine transformation matrices (4,4) from one rigid group to the # previous group -restype_atom37_to_rigid_group = np.zeros([21, 37], dtype=np.int) +restype_atom37_to_rigid_group = np.zeros([21, 37], dtype=int) restype_atom37_mask = np.zeros([21, 37], dtype=np.float32) restype_atom37_rigid_group_positions = np.zeros([21, 37, 3], dtype=np.float32) -restype_atom14_to_rigid_group = np.zeros([21, 14], dtype=np.int) +restype_atom14_to_rigid_group = np.zeros([21, 14], dtype=int) restype_atom14_mask = np.zeros([21, 14], dtype=np.float32) restype_atom14_rigid_group_positions = np.zeros([21, 14, 3], dtype=np.float32) restype_rigid_group_default_frame = np.zeros([21, 8, 4, 4], dtype=np.float32) diff --git a/src/alphafold/data/pipeline.py b/src/alphafold/data/pipeline.py index cb954bdfdfad4378e310bea22b694f12a0233729..bca4306b7ce1d865438c5f49f49beb130610a4ba 100644 --- a/src/alphafold/data/pipeline.py +++ b/src/alphafold/data/pipeline.py @@ -233,14 +233,14 @@ class DataPipeline: use_precomputed_msas=self.use_precomputed_msas) bfd_msa = parsers.parse_stockholm(jackhmmer_small_bfd_result['sto']) else: - bfd_out_path = os.path.join(msa_output_dir, 'bfd_uniclust_hits.a3m') - hhblits_bfd_uniclust_result = run_msa_tool( - msa_runner=self.hhblits_bfd_uniclust_runner, + bfd_out_path = os.path.join(msa_output_dir, 'bfd_uniref_hits.a3m') + hhblits_bfd_uniref_result = run_msa_tool( + msa_runner=self.hhblits_bfd_uniref_runner, input_fasta_path=input_fasta_path, msa_out_path=bfd_out_path, msa_format='a3m', use_precomputed_msas=self.use_precomputed_msas) - bfd_msa = parsers.parse_a3m(hhblits_bfd_uniclust_result['a3m']) + bfd_msa = parsers.parse_a3m(hhblits_bfd_uniref_result['a3m']) templates_result = self.template_featurizer.get_templates( query_sequence=input_sequence, diff --git a/src/alphafold/data/pipeline_uniprot.py b/src/alphafold/data/pipeline_uniprot.py new file mode 100644 index 0000000000000000000000000000000000000000..aa2719adf7e32bafd8da435dcc9df3818cad1d17 --- /dev/null +++ b/src/alphafold/data/pipeline_uniprot.py @@ -0,0 +1,199 @@ +# Modified by Mu Gao to consider only UniProt as the sequence database +# +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Functions for building the input features for the AlphaFold model.""" + +import os +from typing import Any, Mapping, MutableMapping, Optional, Sequence, Union +from absl import logging +from alphafold.common import residue_constants +from alphafold.data import msa_identifiers +from alphafold.data import parsers +from alphafold.data import templates +from alphafold.data.tools import hhblits +from alphafold.data.tools import hhsearch +from alphafold.data.tools import hmmsearch +from alphafold.data.tools import jackhmmer +import numpy as np + +# Internal import (7716). 
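+
+# A usage sketch for orientation (hypothetical paths; this pipeline is presumably
+# constructed by src/run_af2c_fea.py from its command-line flags, e.g. those set
+# in example/run_fea_gen.sh -- treat this as illustrative, not the official entry
+# point):
+#
+#   pipeline = DataPipeline(
+#       jackhmmer_binary_path='/usr/bin/jackhmmer',
+#       hhblits_binary_path='/usr/bin/hhblits',
+#       uniprot_database_path='/data/uniprot/uniprot.fasta',
+#       template_searcher=template_searcher,
+#       template_featurizer=template_featurizer,
+#       use_precomputed_msas=True,
+#       add_species=True)
+#   features = pipeline.process('target.fasta', 'msa_output_dir')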
+ +FeatureDict = MutableMapping[str, np.ndarray] +TemplateSearcher = Union[hhsearch.HHSearch, hmmsearch.Hmmsearch] + + +def make_sequence_features( + sequence: str, description: str, num_res: int) -> FeatureDict: + """Constructs a feature dict of sequence features.""" + features = {} + features['aatype'] = residue_constants.sequence_to_onehot( + sequence=sequence, + mapping=residue_constants.restype_order_with_x, + map_unknown_to_x=True) + features['between_segment_residues'] = np.zeros((num_res,), dtype=np.int32) + features['domain_name'] = np.array([description.encode('utf-8')], + dtype=np.object_) + features['residue_index'] = np.array(range(num_res), dtype=np.int32) + features['seq_length'] = np.array([num_res] * num_res, dtype=np.int32) + features['sequence'] = np.array([sequence.encode('utf-8')], dtype=np.object_) + return features + + +def make_msa_features(msa: parsers.Msa) -> FeatureDict: + """Constructs a feature dict of MSA features.""" + if not msa: + raise ValueError('MSA must contain at least one sequence.') + + int_msa = [] + deletion_matrix = [] + species_ids = [] + seen_sequences = set() + + for sequence_index, sequence in enumerate(msa.sequences): + if sequence in seen_sequences: + continue + seen_sequences.add(sequence) + int_msa.append( + [residue_constants.HHBLITS_AA_TO_ID[res] for res in sequence]) + deletion_matrix.append(msa.deletion_matrix[sequence_index]) + identifiers = msa_identifiers.get_identifiers( + msa.descriptions[sequence_index]) + species_ids.append(identifiers.species_id.encode('utf-8')) + + num_res = len(msa.sequences[0]) + num_alignments = len(int_msa) + features = {} + features['deletion_matrix_int'] = np.array(deletion_matrix, dtype=np.int32) + features['msa'] = np.array(int_msa, dtype=np.int32) + features['num_alignments'] = np.array( + [num_alignments] * num_res, dtype=np.int32) + features['msa_species_identifiers'] = np.array(species_ids, dtype=np.object_) + return features + + +def run_msa_tool(msa_runner, input_fasta_path: str, msa_out_path: str, + msa_format: str, use_precomputed_msas: bool, + max_sto_sequences: Optional[int] = None + ) -> Mapping[str, Any]: + """Runs an MSA tool, checking if output already exists first.""" + if not use_precomputed_msas or not os.path.exists(msa_out_path): + if msa_format == 'sto' and max_sto_sequences is not None: + result = msa_runner.query(input_fasta_path, max_sto_sequences)[0] # pytype: disable=wrong-arg-count + else: + result = msa_runner.query(input_fasta_path)[0] + with open(msa_out_path, 'w') as f: + f.write(result[msa_format]) + else: + logging.warning('Reading MSA from file %s', msa_out_path) + if msa_format == 'sto' and max_sto_sequences is not None: + precomputed_msa = parsers.truncate_stockholm_msa( + msa_out_path, max_sto_sequences) + result = {'sto': precomputed_msa} + else: + with open(msa_out_path, 'r') as f: + result = {msa_format: f.read()} + return result + + +class DataPipeline: + """Runs the alignment tools and assembles the input features.""" + + def __init__(self, + jackhmmer_binary_path: str, + hhblits_binary_path: str, + uniprot_database_path: str, + template_searcher: TemplateSearcher, + template_featurizer: templates.TemplateHitFeaturizer, + uniprot_max_hits: int = 30000, + use_precomputed_msas: bool = False, + add_species: bool = False): + """Initializes the data pipeline.""" + + self.template_searcher = template_searcher + self.template_featurizer = template_featurizer
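+    # Clarifying note: the template search/featurization helpers are stored as-is;
+    # the UniProt jackhmmer runner below is only constructed when add_species is
+    # requested, since species identifiers are needed solely for MSA pairing.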
+    self.use_precomputed_msas = use_precomputed_msas
+    self.add_species = add_species
+    if add_species:
+      self.uniprot_msa_runner = jackhmmer.Jackhmmer(
+          binary_path=jackhmmer_binary_path,
+          database_path=uniprot_database_path)
+    self.uniprot_max_hits = uniprot_max_hits
+
+  def process(self, input_fasta_path: str, msa_output_dir: str) -> FeatureDict:
+    """Runs alignment tools on the input sequence and creates features."""
+    with open(input_fasta_path) as f:
+      input_fasta_str = f.read()
+    input_seqs, input_descs = parsers.parse_fasta(input_fasta_str)
+    if len(input_seqs) != 1:
+      raise ValueError(
+          f'More than one input sequence found in {input_fasta_path}.')
+    input_sequence = input_seqs[0]
+    input_description = input_descs[0]
+    num_res = len(input_sequence)
+
+    uniprot_out_path = os.path.join(msa_output_dir, 'uniprot_hits.sto')
+    jackhmmer_uniprot_result = run_msa_tool(
+        msa_runner=self.uniprot_msa_runner,
+        input_fasta_path=input_fasta_path,
+        msa_out_path=uniprot_out_path,
+        msa_format='sto',
+        use_precomputed_msas=self.use_precomputed_msas,
+        max_sto_sequences=self.uniprot_max_hits)
+
+    msa_for_templates = jackhmmer_uniprot_result['sto']
+    msa_for_templates = parsers.deduplicate_stockholm_msa(msa_for_templates)
+    msa_for_templates = parsers.remove_empty_columns_from_stockholm_msa(
+        msa_for_templates)
+
+    uniprot_msa = parsers.parse_stockholm(jackhmmer_uniprot_result['sto'])
+
+    if self.template_searcher.input_format == 'sto':
+      pdb_templates_result = self.template_searcher.query(msa_for_templates)
+    elif self.template_searcher.input_format == 'a3m':
+      uniprot_msa_as_a3m = parsers.convert_stockholm_to_a3m(msa_for_templates)
+      pdb_templates_result = self.template_searcher.query(uniprot_msa_as_a3m)
+    else:
+      raise ValueError('Unrecognized template input format: '
+                       f'{self.template_searcher.input_format}')
+
+    pdb_hits_out_path = os.path.join(
+        msa_output_dir, f'pdb_hits.{self.template_searcher.output_format}')
+    with open(pdb_hits_out_path, 'w') as f:
+      f.write(pdb_templates_result)
+
+    pdb_template_hits = self.template_searcher.get_template_hits(
+        output_string=pdb_templates_result, input_sequence=input_sequence)
+
+    templates_result = self.template_featurizer.get_templates(
+        query_sequence=input_sequence,
+        hits=pdb_template_hits)
+
+    sequence_features = make_sequence_features(
+        sequence=input_sequence,
+        description=input_description,
+        num_res=num_res)
+
+    msa_features = make_msa_features(uniprot_msa)
+    logging.info('Final (deduplicated) MSA size: %d sequences.',
+                 msa_features['num_alignments'][0])
+    logging.info('Total number of templates (NB: this can include bad '
+                 'templates and is later filtered to top 4): %d.',
+                 templates_result.features['template_domain_names'].shape[0])
+
+    return {**sequence_features, **msa_features, **templates_result.features}
diff --git a/src/alphafold/data/templates.py b/src/alphafold/data/templates.py
index d3759871190e781f3146109361a139611346323c..91681ac3fa073d41a9bdae40dc4f61cf6a9848a8 100644
--- a/src/alphafold/data/templates.py
+++ b/src/alphafold/data/templates.py
@@ -89,8 +89,8 @@ TEMPLATE_FEATURES = {
     'template_aatype': np.float32,
     'template_all_atom_masks': np.float32,
     'template_all_atom_positions': np.float32,
-    'template_domain_names': np.object,
-    'template_sequence': np.object,
+    'template_domain_names': object,
+    'template_sequence': object,
     'template_sum_probs': np.float32,
 }
@@ -1002,8 +1002,8 @@ class HmmsearchHitFeaturizer(TemplateHitFeaturizer):
             (1, num_res,
residue_constants.atom_type_num), np.float32), 'template_all_atom_positions': np.zeros( (1, num_res, residue_constants.atom_type_num, 3), np.float32), - 'template_domain_names': np.array([''.encode()], dtype=np.object), - 'template_sequence': np.array([''.encode()], dtype=np.object), + 'template_domain_names': np.array([''.encode()], dtype=object), + 'template_sequence': np.array([''.encode()], dtype=object), 'template_sum_probs': np.array([0], dtype=np.float32) } return TemplateSearchResult( diff --git a/src/alphafold/data/tools/jackhmmer.py b/src/alphafold/data/tools/jackhmmer.py index 60e0e222c91457c8c1b1dbe0a5f4cac358cd1e69..68997f857f2c4fd3a69a59205310828a4cf08fd2 100644 --- a/src/alphafold/data/tools/jackhmmer.py +++ b/src/alphafold/data/tools/jackhmmer.py @@ -167,10 +167,20 @@ class Jackhmmer: input_fasta_path: str, max_sequences: Optional[int] = None) -> Sequence[Mapping[str, Any]]: """Queries the database using Jackhmmer.""" + return self.query_multiple([input_fasta_path], max_sequences)[0] + + def query_multiple( + self, + input_fasta_paths: Sequence[str], + max_sequences: Optional[int] = None, + ) -> Sequence[Sequence[Mapping[str, Any]]]: + """Queries the database for multiple queries using Jackhmmer.""" if self.num_streamed_chunks is None: - single_chunk_result = self._query_chunk( - input_fasta_path, self.database_path, max_sequences) - return [single_chunk_result] + single_chunk_results = [] + for input_fasta_path in input_fasta_paths: + single_chunk_results.append([self._query_chunk( + input_fasta_path, self.database_path, max_sequences)]) + return single_chunk_results db_basename = os.path.basename(self.database_path) db_remote_chunk = lambda db_idx: f'{self.database_path}.{db_idx}' @@ -185,7 +195,7 @@ class Jackhmmer: # Download the (i+1)-th chunk while Jackhmmer is running on the i-th chunk with futures.ThreadPoolExecutor(max_workers=2) as executor: - chunked_output = [] + chunked_outputs = [[] for _ in range(len(input_fasta_paths))] for i in range(1, self.num_streamed_chunks + 1): # Copy the chunk locally if i == 1: @@ -197,9 +207,9 @@ class Jackhmmer: # Run Jackhmmer with the chunk future.result() - chunked_output.append(self._query_chunk( - input_fasta_path, db_local_chunk(i), max_sequences)) - + for fasta_index, input_fasta_path in enumerate(input_fasta_paths): + chunked_outputs[fasta_index].append(self._query_chunk( + input_fasta_path, db_local_chunk(i), max_sequences)) # Remove the local copy of the chunk os.remove(db_local_chunk(i)) # Do not set next_future for the last chunk so that this works even for @@ -208,4 +218,4 @@ class Jackhmmer: future = next_future if self.streaming_callback: self.streaming_callback(i) - return chunked_output + return chunked_outputs diff --git a/src/alphafold/model/all_atom_multimer.py b/src/alphafold/model/all_atom_multimer.py index 361652050effad98660c134ef96a247df13a5d69..2cc49c4d34012b31ab7814e5731c50fa9674e936 100644 --- a/src/alphafold/model/all_atom_multimer.py +++ b/src/alphafold/model/all_atom_multimer.py @@ -426,7 +426,7 @@ def torsion_angles_to_frames( chi3_frame_to_backb = chi2_frame_to_backb @ all_frames[:, 6] chi4_frame_to_backb = chi3_frame_to_backb @ all_frames[:, 7] - all_frames_to_backb = jax.tree_multimap( + all_frames_to_backb = jax.tree_map( lambda *x: jnp.concatenate(x, axis=-1), all_frames[:, 0:5], chi2_frame_to_backb[:, None], chi3_frame_to_backb[:, None], chi4_frame_to_backb[:, None]) diff --git a/src/alphafold/model/common_modules.py b/src/alphafold/model/common_modules.py index 
08776a7f00af3dab4c25289954166bc32bec3539..0b5cd07d5b6d643fceb865591c80001f67387ef1 100644
--- a/src/alphafold/model/common_modules.py
+++ b/src/alphafold/model/common_modules.py
@@ -128,3 +128,64 @@ class Linear(hk.Module):
 
     return output
+
+
+class LayerNorm(hk.LayerNorm):
+  """LayerNorm module.
+
+  Equivalent to hk.LayerNorm but with different parameter shapes: they are
+  always vectors rather than possibly higher-rank tensors. This makes it
+  easier to change the layout whilst keeping the model weight-compatible.
+  """
+
+  def __init__(self,
+               axis,
+               create_scale: bool,
+               create_offset: bool,
+               eps: float = 1e-5,
+               scale_init=None,
+               offset_init=None,
+               use_fast_variance: bool = False,
+               name=None,
+               param_axis=None):
+    super().__init__(
+        axis=axis,
+        create_scale=False,
+        create_offset=False,
+        eps=eps,
+        scale_init=None,
+        offset_init=None,
+        use_fast_variance=use_fast_variance,
+        name=name,
+        param_axis=param_axis)
+    self._temp_create_scale = create_scale
+    self._temp_create_offset = create_offset
+
+  def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
+    is_bf16 = (x.dtype == jnp.bfloat16)
+    if is_bf16:
+      x = x.astype(jnp.float32)
+
+    param_axis = self.param_axis[0] if self.param_axis else -1
+    param_shape = (x.shape[param_axis],)
+
+    param_broadcast_shape = [1] * x.ndim
+    param_broadcast_shape[param_axis] = x.shape[param_axis]
+    scale = None
+    offset = None
+    if self._temp_create_scale:
+      scale = hk.get_parameter(
+          'scale', param_shape, x.dtype, init=self.scale_init)
+      scale = scale.reshape(param_broadcast_shape)
+
+    if self._temp_create_offset:
+      offset = hk.get_parameter(
+          'offset', param_shape, x.dtype, init=self.offset_init)
+      offset = offset.reshape(param_broadcast_shape)
+
+    out = super().__call__(x, scale=scale, offset=offset)
+
+    if is_bf16:
+      out = out.astype(jnp.bfloat16)
+
+    return out
+
\ No newline at end of file
diff --git a/src/alphafold/model/config.py b/src/alphafold/model/config.py
index f8e56c8f4c450dd1a3400b9cfbe9fedf8939c4c8..bad78901e44a08a41fdf3c03bae7a296758ff56b 100644
--- a/src/alphafold/model/config.py
+++ b/src/alphafold/model/config.py
@@ -11,6 +11,10 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
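A usage sketch of the `LayerNorm` subclass added above (illustrative, outside the patch; it assumes the patched `alphafold.model.common_modules` is importable). It demonstrates the two properties the docstring promises: parameters are rank-1 vectors over the channel axis regardless of input rank, and bfloat16 inputs are normalised in float32 and cast back on output:

```python
import haiku as hk
import jax
import jax.numpy as jnp
from alphafold.model import common_modules

def forward(x):
  # Same call signature the patch uses at each hk.LayerNorm replacement site.
  return common_modules.LayerNorm(
      axis=[-1], create_scale=True, create_offset=True, name='example_norm')(x)

fn = hk.transform(forward)
x = jnp.ones((4, 8), dtype=jnp.bfloat16)
params = fn.init(jax.random.PRNGKey(0), x)
out = fn.apply(params, None, x)
print(params['example_norm']['scale'].shape)  # (8,): a vector, not rank-2
print(out.dtype)                              # bfloat16, restored on output
```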
+# +# Modified to keep the definitions of previous multimer_v2 and multimer models +# Mu Gao + """Model config.""" import copy @@ -26,12 +30,12 @@ NUM_TEMPLATES = shape_placeholders.NUM_TEMPLATES def model_config(name: str) -> ml_collections.ConfigDict: """Get the ConfigDict of a CASP14 model.""" - if 'multimer' in name: - return CONFIG_MULTIMER - if name not in CONFIG_DIFFS: raise ValueError(f'Invalid model name {name}.') - cfg = copy.deepcopy(CONFIG) + if 'multimer' in name: + cfg = copy.deepcopy(CONFIG_MULTIMER) + else: + cfg = copy.deepcopy(CONFIG) cfg.update_from_flattened_dict(CONFIG_DIFFS[name]) return cfg @@ -65,6 +69,13 @@ MODEL_PRESETS = { 'model_4_multimer_v2', 'model_5_multimer_v2', ), + 'multimer_v3': ( + 'model_1_multimer_v3', + 'model_2_multimer_v3', + 'model_3_multimer_v3', + 'model_4_multimer_v3', + 'model_5_multimer_v3', + ), } MODEL_PRESETS['monomer_casp14'] = MODEL_PRESETS['monomer'] @@ -125,8 +136,31 @@ CONFIG_DIFFS = { }, 'model_5_ptm': { 'model.heads.predicted_aligned_error.weight': 0.1 - } + }, + 'model_1_multimer_v3': {}, + 'model_2_multimer_v3': {}, + 'model_3_multimer_v3': {}, + 'model_4_multimer_v3': { + 'model.embeddings_and_evoformer.num_extra_msa': 1152 + }, + 'model_5_multimer_v3': { + 'model.embeddings_and_evoformer.num_extra_msa': 1152 + }, } +# Key differences between multimer v1/v2 and v3, mostly due to numerical +# optimisations in the TriangleMultiplication module. +common_updates = { + 'model.embeddings_and_evoformer.num_msa': 252, + 'model.embeddings_and_evoformer.num_extra_msa': 1152, + 'model.embeddings_and_evoformer.evoformer.triangle_multiplication_incoming.fuse_projection_weights': False, + 'model.embeddings_and_evoformer.evoformer.triangle_multiplication_outgoing.fuse_projection_weights': False, + 'model.embeddings_and_evoformer.template.template_pair_stack.triangle_multiplication_incoming.fuse_projection_weights': False, + 'model.embeddings_and_evoformer.template.template_pair_stack.triangle_multiplication_outgoing.fuse_projection_weights': False, +} +CONFIG_DIFFS.update( + {f'model_{i}_multimer': common_updates for i in range(1, 6)}) +CONFIG_DIFFS.update( + {f'model_{i}_multimer_v2': common_updates for i in range(1, 6)}) CONFIG = ml_collections.ConfigDict({ 'data': { @@ -267,14 +301,16 @@ CONFIG = ml_collections.ConfigDict({ 'equation': 'ikc,jkc->ijc', 'num_intermediate_channel': 128, 'orientation': 'per_row', - 'shared_dropout': True + 'shared_dropout': True, + 'fuse_projection_weights': False, }, 'triangle_multiplication_incoming': { 'dropout_rate': 0.25, 'equation': 'kjc,kic->ijc', 'num_intermediate_channel': 128, 'orientation': 'per_row', - 'shared_dropout': True + 'shared_dropout': True, + 'fuse_projection_weights': False, }, 'pair_transition': { 'dropout_rate': 0.0, @@ -335,14 +371,16 @@ CONFIG = ml_collections.ConfigDict({ 'equation': 'ikc,jkc->ijc', 'num_intermediate_channel': 64, 'orientation': 'per_row', - 'shared_dropout': True + 'shared_dropout': True, + 'fuse_projection_weights': False, }, 'triangle_multiplication_incoming': { 'dropout_rate': 0.25, 'equation': 'kjc,kic->ijc', 'num_intermediate_channel': 64, 'orientation': 'per_row', - 'shared_dropout': True + 'shared_dropout': True, + 'fuse_projection_weights': False, }, 'pair_transition': { 'dropout_rate': 0.0, @@ -361,7 +399,8 @@ CONFIG = ml_collections.ConfigDict({ 'multimer_mode': False, 'subbatch_size': 4, 'use_remat': False, - 'zero_init': True + 'zero_init': True, + 'eval_dropout': False, }, 'heads': { 'distogram': { @@ -490,27 +529,29 @@ CONFIG_MULTIMER = 
ml_collections.ConfigDict({ 'gating': True, 'num_head': 4, 'orientation': 'per_row', - 'shared_dropout': True + 'shared_dropout': True, }, 'triangle_multiplication_incoming': { 'dropout_rate': 0.25, 'equation': 'kjc,kic->ijc', 'num_intermediate_channel': 128, 'orientation': 'per_row', - 'shared_dropout': True + 'shared_dropout': True, + 'fuse_projection_weights': True, }, 'triangle_multiplication_outgoing': { 'dropout_rate': 0.25, 'equation': 'ikc,jkc->ijc', 'num_intermediate_channel': 128, 'orientation': 'per_row', - 'shared_dropout': True + 'shared_dropout': True, + 'fuse_projection_weights': True, } }, 'extra_msa_channel': 64, 'extra_msa_stack_num_block': 4, - 'num_msa': 252, - 'num_extra_msa': 1152, + 'num_msa': 508, + 'num_extra_msa': 2048, 'masked_msa': { 'profile_prob': 0.1, 'replace_fraction': 0.15, @@ -571,24 +612,29 @@ CONFIG_MULTIMER = ml_collections.ConfigDict({ 'equation': 'kjc,kic->ijc', 'num_intermediate_channel': 64, 'orientation': 'per_row', - 'shared_dropout': True + 'shared_dropout': True, + 'fuse_projection_weights': True, }, 'triangle_multiplication_outgoing': { 'dropout_rate': 0.25, 'equation': 'ikc,jkc->ijc', 'num_intermediate_channel': 64, 'orientation': 'per_row', - 'shared_dropout': True + 'shared_dropout': True, + 'fuse_projection_weights': True, } } }, }, 'global_config': { + 'bfloat16': True, + 'bfloat16_output': False, 'deterministic': False, 'multimer_mode': True, 'subbatch_size': 4, 'use_remat': False, - 'zero_init': True + 'zero_init': True, + 'eval_dropout': False, }, 'heads': { 'distogram': { @@ -658,7 +704,13 @@ CONFIG_MULTIMER = ml_collections.ConfigDict({ } }, 'num_ensemble_eval': 1, - 'num_recycle': 3, + 'num_recycle': 20, + # A negative value indicates that no early stopping will occur, i.e. + # the model will always run `num_recycle` number of recycling + # iterations. A positive value will enable early stopping if the + # difference in pairwise distances is less than the tolerance between + # recycling steps. 
+ 'recycle_early_stop_tolerance': 0.5, 'resample_msa_in_recycling': True } }) diff --git a/src/alphafold/model/folding.py b/src/alphafold/model/folding.py index 1faf5bd58377880107da119b4b65c96a2f1e728d..e73266489190f7df63632be126443cf1c8a62422 100644 --- a/src/alphafold/model/folding.py +++ b/src/alphafold/model/folding.py @@ -331,7 +331,7 @@ class FoldIteration(hk.Module): safe_key, *sub_keys = safe_key.split(3) sub_keys = iter(sub_keys) act = safe_dropout_fn(act, next(sub_keys)) - act = hk.LayerNorm( + act = common_modules.LayerNorm( axis=[-1], create_scale=True, create_offset=True, @@ -353,7 +353,7 @@ class FoldIteration(hk.Module): act = jax.nn.relu(act) act += input_act act = safe_dropout_fn(act, next(sub_keys)) - act = hk.LayerNorm( + act = common_modules.LayerNorm( axis=[-1], create_scale=True, create_offset=True, @@ -410,7 +410,7 @@ def generate_affines(representations, batch, config, global_config, c = config sequence_mask = batch['seq_mask'][:, None] - act = hk.LayerNorm( + act = common_modules.LayerNorm( axis=[-1], create_scale=True, create_offset=True, @@ -433,7 +433,7 @@ def generate_affines(representations, batch, config, global_config, 'affine': affine.to_tensor(), } - act_2d = hk.LayerNorm( + act_2d = common_modules.LayerNorm( axis=[-1], create_scale=True, create_offset=True, diff --git a/src/alphafold/model/folding_multimer.py b/src/alphafold/model/folding_multimer.py index 90ce31b902e47564e0bc6e40ed71115ed2e64ad8..b565a0d41b4befa93c88769de1813d3e88a76246 100644 --- a/src/alphafold/model/folding_multimer.py +++ b/src/alphafold/model/folding_multimer.py @@ -427,7 +427,7 @@ class FoldIteration(hk.Module): safe_key, *sub_keys = safe_key.split(3) sub_keys = iter(sub_keys) act = safe_dropout_fn(act, next(sub_keys)) - act = hk.LayerNorm( + act = common_modules.LayerNorm( axis=-1, create_scale=True, create_offset=True, @@ -448,7 +448,7 @@ class FoldIteration(hk.Module): act = jax.nn.relu(act) act += input_act act = safe_dropout_fn(act, next(sub_keys)) - act = hk.LayerNorm( + act = common_modules.LayerNorm( axis=-1, create_scale=True, create_offset=True, @@ -500,7 +500,7 @@ def generate_monomer_rigids(representations: Mapping[str, jnp.ndarray], """ c = config sequence_mask = batch['seq_mask'][:, None] - act = hk.LayerNorm( + act = common_modules.LayerNorm( axis=-1, create_scale=True, create_offset=True, name='single_layer_norm')( representations['single']) @@ -523,7 +523,7 @@ def generate_monomer_rigids(representations: Mapping[str, jnp.ndarray], rigid } - act_2d = hk.LayerNorm( + act_2d = common_modules.LayerNorm( axis=-1, create_scale=True, create_offset=True, @@ -546,7 +546,7 @@ def generate_monomer_rigids(representations: Mapping[str, jnp.ndarray], ) outputs.append(output) - output = jax.tree_multimap(lambda *x: jnp.stack(x), *outputs) + output = jax.tree_map(lambda *x: jnp.stack(x), *outputs) # Pass along for LDDT-Head. output['act'] = activations['act'] @@ -789,7 +789,7 @@ def backbone_loss(gt_rigid: geometry.Rigid3Array, loss_fn = functools.partial( all_atom_multimer.frame_aligned_point_error, l1_clamp_distance=config.atom_clamp_distance, - loss_unit_distance=config.loss_unit_distance) + length_scale=config.loss_unit_distance) loss_fn = jax.vmap(loss_fn, (0, None, None, 0, None, None, None)) fape = loss_fn(target_rigid, gt_rigid, gt_frames_mask, @@ -823,7 +823,7 @@ def compute_frames( alt_gt_frames = frames_batch['rigidgroups_alt_gt_frames'] use_alt = use_alt[:, None] - renamed_gt_frames = jax.tree_multimap( + renamed_gt_frames = jax.tree_map( lambda x, y: (1. 
- use_alt) * x + use_alt * y, gt_frames, alt_gt_frames) return renamed_gt_frames, frames_batch['rigidgroups_gt_exists'] @@ -1160,4 +1160,3 @@ class MultiRigidSidechain(hk.Module): 'frames': all_frames_to_global, # geometry.Rigid3Array (N, 8) }) return outputs - diff --git a/src/alphafold/model/geometry/rigid_matrix_vector.py b/src/alphafold/model/geometry/rigid_matrix_vector.py index 299f6401706e78c2b8872ec7db99235ab904578a..4f2c0f006b0a64f399abb28d4297395cd1b9cf1f 100644 --- a/src/alphafold/model/geometry/rigid_matrix_vector.py +++ b/src/alphafold/model/geometry/rigid_matrix_vector.py @@ -65,7 +65,7 @@ class Rigid3Array: """Return identity Rigid3Array of given shape.""" return cls( rotation_matrix.Rot3Array.identity(shape, dtype=dtype), - vector.Vec3Array.zeros(shape, dtype=dtype)) + vector.Vec3Array.zeros(shape, dtype=dtype)) # pytype: disable=wrong-arg-count # trace-all-classes def scale_translation(self, factor: Float) -> Rigid3Array: """Scale translation in Rigid3Array by 'factor'.""" @@ -80,7 +80,7 @@ class Rigid3Array: def from_array(cls, array): rot = rotation_matrix.Rot3Array.from_array(array[..., :3]) vec = vector.Vec3Array.from_array(array[..., -1]) - return cls(rot, vec) + return cls(rot, vec) # pytype: disable=wrong-arg-count # trace-all-classes @classmethod def from_array4x4(cls, array: jnp.ndarray) -> Rigid3Array: @@ -94,7 +94,7 @@ class Rigid3Array: ) translation = vector.Vec3Array( array[..., 0, 3], array[..., 1, 3], array[..., 2, 3]) - return cls(rotation, translation) + return cls(rotation, translation) # pytype: disable=wrong-arg-count # trace-all-classes def __getstate__(self): return (VERSION, (self.rotation, self.translation)) diff --git a/src/alphafold/model/geometry/rotation_matrix.py b/src/alphafold/model/geometry/rotation_matrix.py index 3222329940078095645bd1e6c191a0314c143774..ccb211110024df5ce21dd39b7beb2ece7f5bfc83 100644 --- a/src/alphafold/model/geometry/rotation_matrix.py +++ b/src/alphafold/model/geometry/rotation_matrix.py @@ -73,7 +73,7 @@ class Rot3Array: """Returns identity of given shape.""" ones = jnp.ones(shape, dtype=dtype) zeros = jnp.zeros(shape, dtype=dtype) - return cls(ones, zeros, zeros, zeros, ones, zeros, zeros, zeros, ones) + return cls(ones, zeros, zeros, zeros, ones, zeros, zeros, zeros, ones) # pytype: disable=wrong-arg-count # trace-all-classes @classmethod def from_two_vectors(cls, e0: vector.Vec3Array, @@ -96,7 +96,7 @@ class Rot3Array: e1 = (e1 - c * e0).normalized() # Compute e2 as cross product of e0 and e1. 
e2 = e0.cross(e1) - return cls(e0.x, e1.x, e2.x, e0.y, e1.y, e2.y, e0.z, e1.z, e2.z) + return cls(e0.x, e1.x, e2.x, e0.y, e1.y, e2.y, e0.z, e1.z, e2.z) # pytype: disable=wrong-arg-count # trace-all-classes @classmethod def from_array(cls, array: jnp.ndarray) -> Rot3Array: @@ -137,7 +137,7 @@ class Rot3Array: zx = 2 * (x * z - w * y) zy = 2 * (y * z + w * x) zz = 1 - 2 * (jnp.square(x) + jnp.square(y)) - return cls(xx, xy, xz, yx, yy, yz, zx, zy, zz) + return cls(xx, xy, xz, yx, yy, yz, zx, zy, zz) # pytype: disable=wrong-arg-count # trace-all-classes @classmethod def random_uniform(cls, key, shape, dtype=jnp.float32) -> Rot3Array: diff --git a/src/alphafold/model/geometry/struct_of_array.py b/src/alphafold/model/geometry/struct_of_array.py index 97a89fd4a6d9ff1430540532922cdeb21d3aec42..562743b327dcd6872d60df202e8862967674b915 100644 --- a/src/alphafold/model/geometry/struct_of_array.py +++ b/src/alphafold/model/geometry/struct_of_array.py @@ -133,7 +133,7 @@ def flatten(instance): inner_treedefs = [] num_arrays = [] for array_like in array_likes: - flat_array_like, inner_treedef = jax.tree_flatten(array_like) + flat_array_like, inner_treedef = jax.tree_util.tree_flatten(array_like) inner_treedefs.append(inner_treedef) flat_array_likes += flat_array_like num_arrays.append(len(flat_array_like)) @@ -206,7 +206,7 @@ class StructOfArray: for num_array, inner_treedef, array_field in zip(num_arrays, inner_treedefs, array_fields): - value_dict[array_field] = jax.tree_unflatten( + value_dict[array_field] = jax.tree_util.tree_unflatten( inner_treedef, data[array_start:array_start + num_array]) array_start += num_array metadata_fields = get_metadata_fields(new_cls) diff --git a/src/alphafold/model/geometry/vector.py b/src/alphafold/model/geometry/vector.py index 99dcb50f77f7165b8c7f1ed1f0c67a48ca3141cb..8f22cc54ba4c1596101a14522e77c7ddc06bfd02 100644 --- a/src/alphafold/model/geometry/vector.py +++ b/src/alphafold/model/geometry/vector.py @@ -53,10 +53,10 @@ class Vec3Array: assert all([x == z for x, z in zip(self.x.shape, self.z.shape)]) def __add__(self, other: Vec3Array) -> Vec3Array: - return jax.tree_multimap(lambda x, y: x + y, self, other) + return jax.tree_map(lambda x, y: x + y, self, other) def __sub__(self, other: Vec3Array) -> Vec3Array: - return jax.tree_multimap(lambda x, y: x - y, self, other) + return jax.tree_map(lambda x, y: x - y, self, other) def __mul__(self, other: Float) -> Vec3Array: return jax.tree_map(lambda x: x * other, self) @@ -104,7 +104,7 @@ class Vec3Array: """Return Vec3Array corresponding to zeros of given shape.""" return cls( jnp.zeros(shape, dtype), jnp.zeros(shape, dtype), - jnp.zeros(shape, dtype)) + jnp.zeros(shape, dtype)) # pytype: disable=wrong-arg-count # trace-all-classes def to_array(self) -> jnp.ndarray: return jnp.stack([self.x, self.y, self.z], axis=-1) diff --git a/src/alphafold/model/layer_stack_test.py b/src/alphafold/model/layer_stack_test.py index 062221f6b753f188475de40f3d50f53324735ffc..d2682f895e5f68d882377ab7b8bd5e3c55c07232 100644 --- a/src/alphafold/model/layer_stack_test.py +++ b/src/alphafold/model/layer_stack_test.py @@ -198,8 +198,8 @@ class LayerStackTest(parameterized.TestCase): assert_fn = functools.partial( np.testing.assert_allclose, atol=1e-4, rtol=1e-4) - jax.tree_multimap(assert_fn, unrolled_grad, - _slice_layers_params(layer_stack_grad)) + jax.tree_map(assert_fn, unrolled_grad, + _slice_layers_params(layer_stack_grad)) def test_random(self): """Random numbers should be handled correctly.""" diff --git 
a/src/alphafold/model/mapping.py b/src/alphafold/model/mapping.py index 2524572b5648d2cdeca87bf99c8f82eecc328431..0e736d521b5bd0ab1ef7d98a33be6c79008e45b0 100644 --- a/src/alphafold/model/mapping.py +++ b/src/alphafold/model/mapping.py @@ -47,11 +47,11 @@ def _maybe_get_size(array, axis): def _expand_axes(axes, values, name='sharded_apply'): - values_tree_def = jax.tree_flatten(values)[1] + values_tree_def = jax.tree_util.tree_flatten(values)[1] flat_axes = jax.api_util.flatten_axes(name, values_tree_def, axes) # Replace None's with PROXY flat_axes = [PROXY if x is None else x for x in flat_axes] - return jax.tree_unflatten(values_tree_def, flat_axes) + return jax.tree_util.tree_unflatten(values_tree_def, flat_axes) def sharded_map( @@ -125,8 +125,8 @@ def sharded_apply( # Expand in axes and Determine Loop range in_axes_ = _expand_axes(in_axes, args) - in_sizes = jax.tree_multimap(_maybe_get_size, args, in_axes_) - flat_sizes = jax.tree_flatten(in_sizes)[0] + in_sizes = jax.tree_map(_maybe_get_size, args, in_axes_) + flat_sizes = jax.tree_util.tree_flatten(in_sizes)[0] in_size = max(flat_sizes) assert all(i in {in_size, -1} for i in flat_sizes) @@ -137,7 +137,7 @@ def sharded_apply( last_shard_size = shard_size if last_shard_size == 0 else last_shard_size def apply_fun_to_slice(slice_start, slice_size): - input_slice = jax.tree_multimap( + input_slice = jax.tree_map( lambda array, axis: _maybe_slice(array, slice_start, slice_size, axis ), args, in_axes_) return fun(*input_slice) @@ -158,11 +158,11 @@ def sharded_apply( shard_shape[axis] * num_extra_shards + remainder_shape[axis],) + shard_shape[axis + 1:] - out_shapes = jax.tree_multimap(make_output_shape, out_axes_, shard_shapes, - out_shapes) + out_shapes = jax.tree_map(make_output_shape, out_axes_, shard_shapes, + out_shapes) # Calls dynamic Update slice with different argument order - # This is here since tree_multimap only works with positional arguments + # This is here since tree_map only works with positional arguments def dynamic_update_slice_in_dim(full_array, update, axis, i): return jax.lax.dynamic_update_slice_in_dim(full_array, update, i, axis) @@ -170,7 +170,7 @@ def sharded_apply( slice_out = apply_fun_to_slice(slice_start, slice_size) update_slice = partial( dynamic_update_slice_in_dim, i=slice_start) - return jax.tree_multimap(update_slice, outputs, slice_out, out_axes_) + return jax.tree_map(update_slice, outputs, slice_out, out_axes_) def scan_iteration(outputs, i): new_outputs = compute_shard(outputs, i, shard_size) @@ -181,7 +181,7 @@ def sharded_apply( def allocate_buffer(dtype, shape): return jnp.zeros(shape, dtype=dtype) - outputs = jax.tree_multimap(allocate_buffer, out_dtypes, out_shapes) + outputs = jax.tree_map(allocate_buffer, out_dtypes, out_shapes) if slice_starts.shape[0] > 0: outputs, _ = hk.scan(scan_iteration, outputs, slice_starts) diff --git a/src/alphafold/model/model.py b/src/alphafold/model/model.py index 680a17acbc0adeb018277a2e83afb946aad21466..b7982738f3dd61657cfbd4d7165b2edb13a1e935 100644 --- a/src/alphafold/model/model.py +++ b/src/alphafold/model/model.py @@ -67,7 +67,6 @@ def get_confidence_metrics( confidence_metrics['pitm'] = confidence.predicted_interface_tm_score( np.asarray(prediction_result['predicted_aligned_error']['logits']), np.asarray(prediction_result['predicted_aligned_error']['breaks']), - # np.asarray(residue_index), np.asarray(prediction_result['structure_module']['final_atom_positions']), np.asarray(prediction_result['structure_module']['final_atom_mask']), 
       np.asarray(prediction_result['predicted_aligned_error']['asym_id']),
@@ -187,7 +186,7 @@ class RunModel:
       logging.info('Output shape was %s', shape)
       return shape
 
-  def predict(self, feat: features.FeatureDict, random_seed=0,
+  def predict(self, feat: features.FeatureDict, random_seed: int,
               prev=None, prev_ckpt_iter=0, asym_id_list=None,
               asym_id=None, edge_contacts_thres=10) -> Mapping[str, Any]:
     """Makes a prediction by inferencing the model on the provided features.
@@ -196,7 +195,7 @@ class RunModel:
       feat: A dictionary of NumPy feature arrays as output by
         RunModel.process_features.
       random_seed: The random seed to use when running the model. In the
-        multimer model this controls the MSA sampling
+        multimer model this controls the MSA sampling.
 
     Returns:
       A dictionary of model outputs.
@@ -258,12 +257,14 @@ class RunModel:
 
     if self.config.model.save_recycled:
       *_, recycled_info = recycles
-      structs = recycled_info['atom_positions']
-      structs_masks = recycled_info['atom_mask']
-      plddt = recycled_info['plddt']
-      palign_logits = recycled_info['pred_aligned_error_logits']
-      palign_break = recycled_info['pred_aligned_error_breaks']
-      tol_values = recycled_info['tol_values']
+      # Must convert JAX arrays to NumPy arrays; otherwise the interplay between
+      # JAX arrays and the Python loops in the AF2Complex metric functions
+      # dramatically slows down the calculations.
+      structs = np.asarray(recycled_info['atom_positions'])
+      structs_masks = np.asarray(recycled_info['atom_mask'])
+      plddt = np.asarray(recycled_info['plddt'])
+      palign_logits = np.asarray(recycled_info['pred_aligned_error_logits'])
+      palign_break = np.asarray(recycled_info['pred_aligned_error_breaks'])
+      tol_values = np.asarray(recycled_info['tol_values'])
 
       recycled_info_ = []
       for i, (s, m, p, a_logits, a_break, tol_val) in enumerate(zip(
@@ -283,7 +284,6 @@ class RunModel:
             "pitm": confidence.predicted_interface_tm_score(
                 a_logits,
                 a_break,
-                # res_index,
                 s,
                 m,
                 asym_id,
diff --git a/src/alphafold/model/modules.py b/src/alphafold/model/modules.py
index b43cf1cb56ad227c89cd44fbe580807821dfd8be..af9f8ea962c1cc792fd011bd4c5835e3c27abe7f 100644
--- a/src/alphafold/model/modules.py
+++ b/src/alphafold/model/modules.py
@@ -81,6 +81,9 @@ def dropout_wrapper(module,
   residual = module(input_act, mask, is_training=is_training, **kwargs)
   dropout_rate = 0.0 if gc.deterministic else module.config.dropout_rate
 
+  # Override `is_training` to True when dropout is requested at evaluation time.
+ should_apply_dropout = True if gc.eval_dropout else is_training + if module.config.shared_dropout: if module.config.orientation == 'per_row': broadcast_dim = 0 @@ -92,7 +95,7 @@ def dropout_wrapper(module, residual = apply_dropout(tensor=residual, safe_key=safe_key, rate=dropout_rate, - is_training=is_training, + is_training=should_apply_dropout, broadcast_dim=broadcast_dim) new_act = output_act + residual @@ -560,7 +563,7 @@ class Transition(hk.Module): num_intermediate = int(nc * self.config.num_intermediate_factor) mask = jnp.expand_dims(mask, axis=-1) - act = hk.LayerNorm( + act = common_modules.LayerNorm( axis=[-1], create_scale=True, create_offset=True, @@ -628,12 +631,15 @@ class Attention(hk.Module): q_weights = hk.get_parameter( 'query_w', shape=(q_data.shape[-1], num_head, key_dim), + dtype=q_data.dtype, init=glorot_uniform()) k_weights = hk.get_parameter( 'key_w', shape=(m_data.shape[-1], num_head, key_dim), + dtype=q_data.dtype, init=glorot_uniform()) v_weights = hk.get_parameter( 'value_w', shape=(m_data.shape[-1], num_head, value_dim), + dtype=q_data.dtype, init=glorot_uniform()) q = jnp.einsum('bqa,ahc->bqhc', q_data, q_weights) * key_dim**(-0.5) @@ -654,10 +660,12 @@ class Attention(hk.Module): gating_weights = hk.get_parameter( 'gating_w', shape=(q_data.shape[-1], num_head, value_dim), + dtype=q_data.dtype, init=hk.initializers.Constant(0.0)) gating_bias = hk.get_parameter( 'gating_b', shape=(num_head, value_dim), + dtype=q_data.dtype, init=hk.initializers.Constant(1.0)) gate_values = jnp.einsum('bqc, chv->bqhv', q_data, @@ -669,9 +677,12 @@ class Attention(hk.Module): o_weights = hk.get_parameter( 'output_w', shape=(num_head, value_dim, self.output_dim), + dtype=q_data.dtype, init=init) - o_bias = hk.get_parameter('output_b', shape=(self.output_dim,), - init=hk.initializers.Constant(0.0)) + o_bias = hk.get_parameter( + 'output_b', shape=(self.output_dim,), + dtype=q_data.dtype, + init=hk.initializers.Constant(0.0)) output = jnp.einsum('bqhc,hco->bqo', weighted_avg, o_weights) + o_bias @@ -718,12 +729,15 @@ class GlobalAttention(hk.Module): q_weights = hk.get_parameter( 'query_w', shape=(q_data.shape[-1], num_head, key_dim), + dtype=q_data.dtype, init=glorot_uniform()) k_weights = hk.get_parameter( 'key_w', shape=(m_data.shape[-1], key_dim), + dtype=q_data.dtype, init=glorot_uniform()) v_weights = hk.get_parameter( 'value_w', shape=(m_data.shape[-1], value_dim), + dtype=q_data.dtype, init=glorot_uniform()) v = jnp.einsum('bka,ac->bkc', m_data, v_weights) @@ -744,18 +758,23 @@ class GlobalAttention(hk.Module): o_weights = hk.get_parameter( 'output_w', shape=(num_head, value_dim, self.output_dim), + dtype=q_data.dtype, init=init) - o_bias = hk.get_parameter('output_b', shape=(self.output_dim,), - init=hk.initializers.Constant(0.0)) + o_bias = hk.get_parameter( + 'output_b', shape=(self.output_dim,), + dtype=q_data.dtype, + init=hk.initializers.Constant(0.0)) if self.config.gating: gating_weights = hk.get_parameter( 'gating_w', shape=(q_data.shape[-1], num_head, value_dim), + dtype=q_data.dtype, init=hk.initializers.Constant(0.0)) gating_bias = hk.get_parameter( 'gating_b', shape=(num_head, value_dim), + dtype=q_data.dtype, init=hk.initializers.Constant(1.0)) gate_values = jnp.einsum('bqc, chv->bqhv', q_data, gating_weights) @@ -805,11 +824,11 @@ class MSARowAttentionWithPairBias(hk.Module): bias = (1e9 * (msa_mask - 1.))[:, None, None, :] assert len(bias.shape) == 4 - msa_act = hk.LayerNorm( + msa_act = common_modules.LayerNorm( axis=[-1], create_scale=True, 
create_offset=True, name='query_norm')( msa_act) - pair_act = hk.LayerNorm( + pair_act = common_modules.LayerNorm( axis=[-1], create_scale=True, create_offset=True, @@ -820,6 +839,7 @@ class MSARowAttentionWithPairBias(hk.Module): weights = hk.get_parameter( 'feat_2d_weights', shape=(pair_act.shape[-1], c.num_head), + dtype=msa_act.dtype, init=hk.initializers.RandomNormal(stddev=init_factor)) nonbatched_bias = jnp.einsum('qkc,ch->hqk', pair_act, weights) @@ -872,7 +892,7 @@ class MSAColumnAttention(hk.Module): bias = (1e9 * (msa_mask - 1.))[:, None, None, :] assert len(bias.shape) == 4 - msa_act = hk.LayerNorm( + msa_act = common_modules.LayerNorm( axis=[-1], create_scale=True, create_offset=True, name='query_norm')( msa_act) @@ -927,7 +947,7 @@ class MSAColumnGlobalAttention(hk.Module): bias = (1e9 * (msa_mask - 1.))[:, None, None, :] assert len(bias.shape) == 4 - msa_act = hk.LayerNorm( + msa_act = common_modules.LayerNorm( axis=[-1], create_scale=True, create_offset=True, name='query_norm')( msa_act) @@ -984,7 +1004,7 @@ class TriangleAttention(hk.Module): bias = (1e9 * (pair_mask - 1.))[:, None, None, :] assert len(bias.shape) == 4 - pair_act = hk.LayerNorm( + pair_act = common_modules.LayerNorm( axis=[-1], create_scale=True, create_offset=True, name='query_norm')( pair_act) @@ -992,6 +1012,7 @@ class TriangleAttention(hk.Module): weights = hk.get_parameter( 'feat_2d_weights', shape=(pair_act.shape[-1], c.num_head), + dtype=pair_act.dtype, init=hk.initializers.RandomNormal(stddev=init_factor)) nonbatched_bias = jnp.einsum('qkc,ch->hqk', pair_act, weights) @@ -1089,7 +1110,7 @@ class PredictedLDDTHead(hk.Module): """ act = representations['structure_module'] - act = hk.LayerNorm( + act = common_modules.LayerNorm( axis=[-1], create_scale=True, create_offset=True, @@ -1311,6 +1332,19 @@ class ExperimentallyResolvedHead(hk.Module): return output +def _layer_norm(axis=-1, name='layer_norm'): + return common_modules.LayerNorm( + axis=axis, + create_scale=True, + create_offset=True, + eps=1e-5, + use_fast_variance=True, + scale_init=hk.initializers.Constant(1.), + offset_init=hk.initializers.Constant(0.), + param_axis=axis, + name=name) + + class TriangleMultiplication(hk.Module): """Triangle multiplication layer ("outgoing" or "incoming"). @@ -1323,25 +1357,34 @@ class TriangleMultiplication(hk.Module): self.config = config self.global_config = global_config - def __call__(self, act, mask, is_training=True): + def __call__(self, left_act, left_mask, is_training=True): """Builds TriangleMultiplication module. Arguments: - act: Pair activations, shape [N_res, N_res, c_z] - mask: Pair mask, shape [N_res, N_res]. + left_act: Pair activations, shape [N_res, N_res, c_z] + left_mask: Pair mask, shape [N_res, N_res]. is_training: Whether the module is in training mode. Returns: - Outputs, same shape/type as act. + Outputs, same shape/type as left_act. 
""" del is_training + + if self.config.fuse_projection_weights: + return self._fused_triangle_multiplication(left_act, left_mask) + else: + return self._triangle_multiplication(left_act, left_mask) + + @hk.transparent + def _triangle_multiplication(self, left_act, left_mask): + """Implementation of TriangleMultiplication used in AF2 and AF-M<2.3.""" c = self.config gc = self.global_config - mask = mask[..., None] + mask = left_mask[..., None] - act = hk.LayerNorm(axis=[-1], create_scale=True, create_offset=True, - name='layer_norm_input')(act) + act = common_modules.LayerNorm(axis=[-1], create_scale=True, create_offset=True, + name='layer_norm_input')(left_act) input_act = act left_projection = common_modules.Linear( @@ -1377,7 +1420,7 @@ class TriangleMultiplication(hk.Module): # b = left_proj_act and a = right_proj_act act = jnp.einsum(c.equation, left_proj_act, right_proj_act) - act = hk.LayerNorm( + act = common_modules.LayerNorm( axis=[-1], create_scale=True, create_offset=True, @@ -1400,6 +1443,50 @@ class TriangleMultiplication(hk.Module): return act + @hk.transparent + def _fused_triangle_multiplication(self, left_act, left_mask): + """TriangleMultiplication with fused projection weights.""" + mask = left_mask[..., None] + c = self.config + gc = self.global_config + + left_act = _layer_norm(axis=-1, name='left_norm_input')(left_act) + + # Both left and right projections are fused into projection. + projection = common_modules.Linear( + 2*c.num_intermediate_channel, name='projection') + proj_act = mask * projection(left_act) + + # Both left + right gate are fused into gate_values. + gate_values = common_modules.Linear( + 2 * c.num_intermediate_channel, + name='gate', + bias_init=1., + initializer=utils.final_init(gc))(left_act) + proj_act *= jax.nn.sigmoid(gate_values) + + left_proj_act = proj_act[:, :, :c.num_intermediate_channel] + right_proj_act = proj_act[:, :, c.num_intermediate_channel:] + act = jnp.einsum(c.equation, left_proj_act, right_proj_act) + + act = _layer_norm(axis=-1, name='center_norm')(act) + + output_channel = int(left_act.shape[-1]) + + act = common_modules.Linear( + output_channel, + initializer=utils.final_init(gc), + name='output_projection')(act) + + gate_values = common_modules.Linear( + output_channel, + bias_init=1., + initializer=utils.final_init(gc), + name='gating_linear')(left_act) + act *= jax.nn.sigmoid(gate_values) + + return act + class DistogramHead(hk.Module): """Head to predict a distogram. 
@@ -1506,7 +1593,7 @@ class OuterProductMean(hk.Module): c = self.config mask = mask[..., None] - act = hk.LayerNorm([-1], True, True, name='layer_norm_input')(act) + act = common_modules.LayerNorm([-1], True, True, name='layer_norm_input')(act) left_act = mask * common_modules.Linear( c.num_outer_channel, @@ -1529,9 +1616,11 @@ class OuterProductMean(hk.Module): 'output_w', shape=(c.num_outer_channel, c.num_outer_channel, self.num_output_channel), + dtype=act.dtype, init=init_w) output_b = hk.get_parameter( 'output_b', shape=(self.num_output_channel,), + dtype=act.dtype, init=hk.initializers.Constant(0.0)) def compute_chunk(left_act): @@ -1798,20 +1887,20 @@ class EmbeddingsAndEvoformer(hk.Module): dgram) if c.recycle_features: - if 'prev_msa_first_row' in batch: - prev_msa_first_row = hk.LayerNorm([-1], - True, - True, - name='prev_msa_first_row_norm')( - batch['prev_msa_first_row']) - msa_activations = msa_activations.at[0].add(prev_msa_first_row) - - if 'prev_pair' in batch: - pair_activations += hk.LayerNorm([-1], - True, - True, - name='prev_pair_norm')( - batch['prev_pair']) + prev_msa_first_row = common_modules.LayerNorm( + axis=[-1], + create_scale=True, + create_offset=True, + name='prev_msa_first_row_norm')( + batch['prev_msa_first_row']) + msa_activations = msa_activations.at[0].add(prev_msa_first_row) + + pair_activations += common_modules.LayerNorm( + axis=[-1], + create_scale=True, + create_offset=True, + name='prev_pair_norm')( + batch['prev_pair']) # Relative position encoding. # Jumper et al. (2021) Suppl. Alg. 4 "relpos" @@ -2080,7 +2169,7 @@ class SingleTemplateEmbedding(hk.Module): self.config.template_pair_stack, self.global_config)( act, mask_2d, is_training) - act = hk.LayerNorm([-1], True, True, name='output_layer_norm')(act) + act = common_modules.LayerNorm([-1], True, True, name='output_layer_norm')(act) return act diff --git a/src/alphafold/model/modules_multimer.py b/src/alphafold/model/modules_multimer.py index c3a9b46dc161cb9c8a3d79381f0daf76e2d17a38..bf0125b30d60349a2a8c94de74225d4128e03c4b 100644 --- a/src/alphafold/model/modules_multimer.py +++ b/src/alphafold/model/modules_multimer.py @@ -594,11 +594,13 @@ class EmbeddingsAndEvoformer(hk.Module): Feature embedding using the features as described before. 
""" c = self.config + gc = self.global_config rel_feats = [] pos = batch['residue_index'] asym_id = batch['asym_id'] asym_id_same = jnp.equal(asym_id[:, None], asym_id[None, :]) offset = pos[:, None] - pos[None, :] + dtype = jnp.bfloat16 if gc.bfloat16 else jnp.float32 clipped_offset = jnp.clip( offset + c.max_relative_idx, a_min=0, a_max=2 * c.max_relative_idx) @@ -638,6 +640,7 @@ class EmbeddingsAndEvoformer(hk.Module): rel_feat = jnp.concatenate(rel_feats, axis=-1) + rel_feat = rel_feat.astype(dtype) return common_modules.Linear( c.pair_channel, name='position_activations')( @@ -649,6 +652,7 @@ class EmbeddingsAndEvoformer(hk.Module): gc = self.global_config batch = dict(batch) + dtype = jnp.bfloat16 if gc.bfloat16 else jnp.float32 if safe_key is None: safe_key = prng.SafeKey(hk.next_rng_key()) @@ -657,180 +661,178 @@ class EmbeddingsAndEvoformer(hk.Module): batch['msa_profile'] = make_msa_profile(batch) - # print(f"aatype = {batch['aatype']}") - target_feat = jax.nn.one_hot(batch['aatype'], 21) - - preprocess_1d = common_modules.Linear( - c.msa_channel, name='preprocess_1d')( - target_feat) - - safe_key, sample_key, mask_key = safe_key.split(3) - batch = sample_msa(sample_key, batch, c.num_msa) - batch = make_masked_msa(batch, mask_key, c.masked_msa) - - (batch['cluster_profile'], - batch['cluster_deletion_mean']) = nearest_neighbor_clusters(batch) - - msa_feat = create_msa_feat(batch) - - preprocess_msa = common_modules.Linear( - c.msa_channel, name='preprocess_msa')( - msa_feat) - - msa_activations = jnp.expand_dims(preprocess_1d, axis=0) + preprocess_msa - - left_single = common_modules.Linear( - c.pair_channel, name='left_single')( - target_feat) - right_single = common_modules.Linear( - c.pair_channel, name='right_single')( - target_feat) - pair_activations = left_single[:, None] + right_single[None] - mask_2d = batch['seq_mask'][:, None] * batch['seq_mask'][None, :] - mask_2d = mask_2d.astype(jnp.float32) - - if c.recycle_pos and 'prev_pos' in batch: - prev_pseudo_beta = modules.pseudo_beta_fn( - batch['aatype'], batch['prev_pos'], None) - - dgram = modules.dgram_from_positions( - prev_pseudo_beta, **self.config.prev_pos) - pair_activations += common_modules.Linear( - c.pair_channel, name='prev_pos_linear')( - dgram) - - if c.recycle_features: - if 'prev_msa_first_row' in batch: - prev_msa_first_row = hk.LayerNorm( + with utils.bfloat16_context(): + target_feat = jax.nn.one_hot(batch['aatype'], 21).astype(dtype) + + preprocess_1d = common_modules.Linear( + c.msa_channel, name='preprocess_1d')( + target_feat) + + safe_key, sample_key, mask_key = safe_key.split(3) + batch = sample_msa(sample_key, batch, c.num_msa) + batch = make_masked_msa(batch, mask_key, c.masked_msa) + + (batch['cluster_profile'], + batch['cluster_deletion_mean']) = nearest_neighbor_clusters(batch) + + msa_feat = create_msa_feat(batch).astype(dtype) + + preprocess_msa = common_modules.Linear( + c.msa_channel, name='preprocess_msa')( + msa_feat) + msa_activations = jnp.expand_dims(preprocess_1d, axis=0) + preprocess_msa + + left_single = common_modules.Linear( + c.pair_channel, name='left_single')( + target_feat) + right_single = common_modules.Linear( + c.pair_channel, name='right_single')( + target_feat) + pair_activations = left_single[:, None] + right_single[None] + mask_2d = batch['seq_mask'][:, None] * batch['seq_mask'][None, :] + mask_2d = mask_2d.astype(dtype) + + if c.recycle_pos: + prev_pseudo_beta = modules.pseudo_beta_fn( + batch['aatype'], batch['prev_pos'], None) + + dgram = 
modules.dgram_from_positions( + prev_pseudo_beta, **self.config.prev_pos) + dgram = dgram.astype(dtype) + pair_activations += common_modules.Linear( + c.pair_channel, name='prev_pos_linear')( + dgram) + if c.recycle_features: + prev_msa_first_row = common_modules.LayerNorm( axis=[-1], create_scale=True, create_offset=True, name='prev_msa_first_row_norm')( - batch['prev_msa_first_row']) + batch['prev_msa_first_row']).astype(dtype) msa_activations = msa_activations.at[0].add(prev_msa_first_row) - if 'prev_pair' in batch: - pair_activations += hk.LayerNorm( + pair_activations += common_modules.LayerNorm( axis=[-1], create_scale=True, create_offset=True, name='prev_pair_norm')( - batch['prev_pair']) + batch['prev_pair']).astype(dtype) - if c.max_relative_idx: - pair_activations += self._relative_encoding(batch) + if c.max_relative_idx: + pair_activations += self._relative_encoding(batch) - if c.template.enabled: - template_module = TemplateEmbedding(c.template, gc) - template_batch = { - 'template_aatype': batch['template_aatype'], - 'template_all_atom_positions': batch['template_all_atom_positions'], - 'template_all_atom_mask': batch['template_all_atom_mask'] + if c.template.enabled: + template_module = TemplateEmbedding(c.template, gc) + template_batch = { + 'template_aatype': batch['template_aatype'], + 'template_all_atom_positions': batch['template_all_atom_positions'], + 'template_all_atom_mask': batch['template_all_atom_mask'] + } + # Construct a mask such that only intra-chain template features are + # computed, since all templates are for each chain individually. + multichain_mask = batch['asym_id'][:, None] == batch['asym_id'][None, :] + safe_key, safe_subkey = safe_key.split() + template_act = template_module( + query_embedding=pair_activations, + template_batch=template_batch, + padding_mask_2d=mask_2d, + multichain_mask_2d=multichain_mask, + is_training=is_training, + safe_key=safe_subkey) + pair_activations += template_act + + # Extra MSA stack. + (extra_msa_feat, + extra_msa_mask) = create_extra_msa_feature(batch, c.num_extra_msa) + extra_msa_activations = common_modules.Linear( + c.extra_msa_channel, + name='extra_msa_activations')( + extra_msa_feat).astype(dtype) + extra_msa_mask = extra_msa_mask.astype(dtype) + + extra_evoformer_input = { + 'msa': extra_msa_activations, + 'pair': pair_activations, } - # Construct a mask such that only intra-chain template features are - # computed, since all templates are for each chain individually. - multichain_mask = batch['asym_id'][:, None] == batch['asym_id'][None, :] - safe_key, safe_subkey = safe_key.split() - template_act = template_module( - query_embedding=pair_activations, - template_batch=template_batch, - padding_mask_2d=mask_2d, - multichain_mask_2d=multichain_mask, - is_training=is_training, - safe_key=safe_subkey) - pair_activations += template_act - - # Extra MSA stack. 
- (extra_msa_feat, - extra_msa_mask) = create_extra_msa_feature(batch, c.num_extra_msa) - extra_msa_activations = common_modules.Linear( - c.extra_msa_channel, - name='extra_msa_activations')( - extra_msa_feat) - extra_msa_mask = extra_msa_mask.astype(jnp.float32) - - extra_evoformer_input = { - 'msa': extra_msa_activations, - 'pair': pair_activations, - } - extra_masks = {'msa': extra_msa_mask, 'pair': mask_2d} + extra_masks = {'msa': extra_msa_mask, 'pair': mask_2d} - extra_evoformer_iteration = modules.EvoformerIteration( - c.evoformer, gc, is_extra_msa=True, name='extra_msa_stack') + extra_evoformer_iteration = modules.EvoformerIteration( + c.evoformer, gc, is_extra_msa=True, name='extra_msa_stack') - def extra_evoformer_fn(x): - act, safe_key = x - safe_key, safe_subkey = safe_key.split() - extra_evoformer_output = extra_evoformer_iteration( - activations=act, - masks=extra_masks, - is_training=is_training, - safe_key=safe_subkey) - return (extra_evoformer_output, safe_key) + def extra_evoformer_fn(x): + act, safe_key = x + safe_key, safe_subkey = safe_key.split() + extra_evoformer_output = extra_evoformer_iteration( + activations=act, + masks=extra_masks, + is_training=is_training, + safe_key=safe_subkey) + return (extra_evoformer_output, safe_key) - if gc.use_remat: - extra_evoformer_fn = hk.remat(extra_evoformer_fn) + if gc.use_remat: + extra_evoformer_fn = hk.remat(extra_evoformer_fn) - safe_key, safe_subkey = safe_key.split() - extra_evoformer_stack = layer_stack.layer_stack( - c.extra_msa_stack_num_block)( - extra_evoformer_fn) - extra_evoformer_output, safe_key = extra_evoformer_stack( - (extra_evoformer_input, safe_subkey)) - - pair_activations = extra_evoformer_output['pair'] - - # Get the size of the MSA before potentially adding templates, so we - # can crop out the templates later. - num_msa_sequences = msa_activations.shape[0] - evoformer_input = { - 'msa': msa_activations, - 'pair': pair_activations, - } - evoformer_masks = {'msa': batch['msa_mask'].astype(jnp.float32), - 'pair': mask_2d} - - if c.template.enabled: - template_features, template_masks = ( - template_embedding_1d(batch=batch, num_channel=c.msa_channel)) - - evoformer_input['msa'] = jnp.concatenate( - [evoformer_input['msa'], template_features], axis=0) - evoformer_masks['msa'] = jnp.concatenate( - [evoformer_masks['msa'], template_masks], axis=0) - - evoformer_iteration = modules.EvoformerIteration( - c.evoformer, gc, is_extra_msa=False, name='evoformer_iteration') - - def evoformer_fn(x): - act, safe_key = x safe_key, safe_subkey = safe_key.split() - evoformer_output = evoformer_iteration( - activations=act, - masks=evoformer_masks, - is_training=is_training, - safe_key=safe_subkey) - return (evoformer_output, safe_key) - - if gc.use_remat: - evoformer_fn = hk.remat(evoformer_fn) + extra_evoformer_stack = layer_stack.layer_stack( + c.extra_msa_stack_num_block)( + extra_evoformer_fn) + extra_evoformer_output, safe_key = extra_evoformer_stack( + (extra_evoformer_input, safe_subkey)) + + pair_activations = extra_evoformer_output['pair'] + + # Get the size of the MSA before potentially adding templates, so we + # can crop out the templates later. 
+ num_msa_sequences = msa_activations.shape[0] + evoformer_input = { + 'msa': msa_activations, + 'pair': pair_activations, + } + evoformer_masks = { + 'msa': batch['msa_mask'].astype(dtype), + 'pair': mask_2d + } + if c.template.enabled: + template_features, template_masks = ( + template_embedding_1d( + batch=batch, num_channel=c.msa_channel, global_config=gc)) + + evoformer_input['msa'] = jnp.concatenate( + [evoformer_input['msa'], template_features], axis=0) + evoformer_masks['msa'] = jnp.concatenate( + [evoformer_masks['msa'], template_masks], axis=0) + evoformer_iteration = modules.EvoformerIteration( + c.evoformer, gc, is_extra_msa=False, name='evoformer_iteration') + + def evoformer_fn(x): + act, safe_key = x + safe_key, safe_subkey = safe_key.split() + evoformer_output = evoformer_iteration( + activations=act, + masks=evoformer_masks, + is_training=is_training, + safe_key=safe_subkey) + return (evoformer_output, safe_key) + + if gc.use_remat: + evoformer_fn = hk.remat(evoformer_fn) - safe_key, safe_subkey = safe_key.split() - evoformer_stack = layer_stack.layer_stack(c.evoformer_num_block)( - evoformer_fn) + safe_key, safe_subkey = safe_key.split() + evoformer_stack = layer_stack.layer_stack(c.evoformer_num_block)( + evoformer_fn) - def run_evoformer(evoformer_input): - evoformer_output, _ = evoformer_stack((evoformer_input, safe_subkey)) - return evoformer_output + def run_evoformer(evoformer_input): + evoformer_output, _ = evoformer_stack((evoformer_input, safe_subkey)) + return evoformer_output - evoformer_output = run_evoformer(evoformer_input) + evoformer_output = run_evoformer(evoformer_input) - msa_activations = evoformer_output['msa'] - pair_activations = evoformer_output['pair'] + msa_activations = evoformer_output['msa'] + pair_activations = evoformer_output['pair'] - single_activations = common_modules.Linear( - c.seq_channel, name='single_activations')( - msa_activations[0]) + single_activations = common_modules.Linear( + c.seq_channel, name='single_activations')( + msa_activations[0]) output.update({ 'single': @@ -844,6 +846,12 @@ class EmbeddingsAndEvoformer(hk.Module): msa_activations[0], }) + # Convert back to float32 if we're not saving memory. + if not gc.bfloat16_output: + for k, v in output.items(): + if v.dtype == jnp.bfloat16: + output[k] = v.astype(jnp.float32) + return output @@ -990,6 +998,9 @@ class SingleTemplateEmbedding(hk.Module): # backbone affine - i.e. in each residues local frame, what direction are # each of the other residues. 
raw_atom_pos = template_all_atom_positions + if gc.bfloat16: + # Vec3Arrays are required to be float32 + raw_atom_pos = raw_atom_pos.astype(jnp.float32) atom_pos = geometry.Vec3Array.from_array(raw_atom_pos) rigid, backbone_mask = folding_multimer.make_backbone_affine( @@ -1001,6 +1012,10 @@ class SingleTemplateEmbedding(hk.Module): unit_vector = rigid_vec.normalized() unit_vector = [unit_vector.x, unit_vector.y, unit_vector.z] + if gc.bfloat16: + unit_vector = [x.astype(jnp.bfloat16) for x in unit_vector] + backbone_mask = backbone_mask.astype(jnp.bfloat16) + backbone_mask_2d = backbone_mask[:, None] * backbone_mask[None, :] backbone_mask_2d *= multichain_mask_2d unit_vector = [x*backbone_mask_2d for x in unit_vector] @@ -1010,7 +1025,7 @@ class SingleTemplateEmbedding(hk.Module): to_concat.extend([(x, 0) for x in unit_vector]) to_concat.append((backbone_mask_2d, 0)) - query_embedding = hk.LayerNorm( + query_embedding = common_modules.LayerNorm( axis=[-1], create_scale=True, create_offset=True, @@ -1059,12 +1074,13 @@ class SingleTemplateEmbedding(hk.Module): template_iteration_fn) act, safe_key = template_stack((act, safe_subkey)) - act = hk.LayerNorm( + act = common_modules.LayerNorm( axis=[-1], create_scale=True, create_offset=True, name='output_layer_norm')( act) + return act @@ -1117,21 +1133,18 @@ class TemplateEmbeddingIteration(hk.Module): act, pair_mask, safe_key=next(sub_keys)) - act = dropout_wrapper_fn( modules.TriangleAttention(c.triangle_attention_starting_node, gc, name='triangle_attention_starting_node'), act, pair_mask, safe_key=next(sub_keys)) - act = dropout_wrapper_fn( modules.TriangleAttention(c.triangle_attention_ending_node, gc, name='triangle_attention_ending_node'), act, pair_mask, safe_key=next(sub_keys)) - act = dropout_wrapper_fn( modules.Transition(c.pair_transition, gc, name='pair_transition'), @@ -1142,7 +1155,7 @@ class TemplateEmbeddingIteration(hk.Module): return act -def template_embedding_1d(batch, num_channel): +def template_embedding_1d(batch, num_channel, global_config): """Embed templates into an (num_res, num_templates, num_channels) embedding. Args: @@ -1153,6 +1166,7 @@ def template_embedding_1d(batch, num_channel): template_all_atom_mask, (num_templates, num_residues, 37) atom mask for each template. num_channel: The number of channels in the output. + global_config: The global_config. 
Returns: An embedding of shape (num_templates, num_res, num_channels) and a mask of @@ -1186,6 +1200,10 @@ def template_embedding_1d(batch, num_channel): template_mask = chi_mask[:, :, 0] + if global_config.bfloat16: + template_features = template_features.astype(jnp.bfloat16) + template_mask = template_mask.astype(jnp.bfloat16) + template_activations = common_modules.Linear( num_channel, initializer='relu', diff --git a/src/alphafold/model/tf/protein_features_test.py b/src/alphafold/model/tf/protein_features_test.py index ee88711281f5cfb366ab26b10c75f37b211120ea..f5a351ba863584222272318f10304e9707fe3edf 100644 --- a/src/alphafold/model/tf/protein_features_test.py +++ b/src/alphafold/model/tf/protein_features_test.py @@ -27,6 +27,10 @@ def _random_bytes(): class FeaturesTest(parameterized.TestCase, tf.test.TestCase): + def setUp(self): + super().setUp() + tf.disable_v2_behavior() + def testFeatureNames(self): self.assertEqual(len(protein_features.FEATURE_SIZES), len(protein_features.FEATURE_TYPES)) @@ -47,5 +51,4 @@ class FeaturesTest(parameterized.TestCase, tf.test.TestCase): if __name__ == '__main__': - tf.disable_v2_behavior() absltest.main() diff --git a/src/alphafold/model/tf/shape_helpers_test.py b/src/alphafold/model/tf/shape_helpers_test.py index d7797b340514d9577dd77b9e9660babd0aa52b5e..16c032bae89db6320f27920f1715159f2319231d 100644 --- a/src/alphafold/model/tf/shape_helpers_test.py +++ b/src/alphafold/model/tf/shape_helpers_test.py @@ -21,6 +21,10 @@ import tensorflow.compat.v1 as tf class ShapeTest(tf.test.TestCase): + def setUp(self): + super().setUp() + tf.disable_v2_behavior() + def test_shape_list(self): """Test that shape_list can allow for reshaping to dynamic shapes.""" a = tf.zeros([10, 4, 4, 2]) @@ -35,5 +39,4 @@ class ShapeTest(tf.test.TestCase): if __name__ == '__main__': - tf.disable_v2_behavior() tf.test.main() diff --git a/src/alphafold/model/utils.py b/src/alphafold/model/utils.py index 40ca1683eab8c71522aae3ff9fa1dced89439772..8d70376e38bade8374592c38e5bd52d4e1f3ee4a 100644 --- a/src/alphafold/model/utils.py +++ b/src/alphafold/model/utils.py @@ -15,6 +15,7 @@ """A collection of JAX utility functions for use in protein folding.""" import collections +import contextlib import functools import numbers from typing import Mapping @@ -25,6 +26,27 @@ import jax.numpy as jnp import numpy as np +def bfloat16_creator(next_creator, shape, dtype, init, context): + """Creates float32 variables when bfloat16 is requested.""" + if context.original_dtype == jnp.bfloat16: + dtype = jnp.float32 + return next_creator(shape, dtype, init) + + +def bfloat16_getter(next_getter, value, context): + """Casts float32 to bfloat16 when bfloat16 was originally requested.""" + if context.original_dtype == jnp.bfloat16: + assert value.dtype == jnp.float32 + value = value.astype(jnp.bfloat16) + return next_getter(value) + + +@contextlib.contextmanager +def bfloat16_context(): + with hk.custom_creator(bfloat16_creator), hk.custom_getter(bfloat16_getter): + yield + + def final_init(config): if config.zero_init: return 'zeros' @@ -34,7 +56,7 @@ def final_init(config): def batched_gather(params, indices, axis=0, batch_dims=0): """Implements a JAX equivalent of `tf.gather` with `axis` and `batch_dims`.""" - take_fn = lambda p, i: jnp.take(p, i, axis=axis) + take_fn = lambda p, i: jnp.take(p, i, axis=axis, mode='clip') for _ in range(batch_dims): take_fn = jax.vmap(take_fn) return take_fn(params, indices) @@ -54,7 +76,7 @@ def mask_mean(mask, value, axis=None, drop_mask_channel=False, eps=1e-10): 
diff --git a/src/alphafold/relax/amber_minimize.py b/src/alphafold/relax/amber_minimize.py
index ef1496942e5c422f5556c56b447d67354e9c1496..e21a0dc302109f01623a75ffb533b28958e85944 100644
--- a/src/alphafold/relax/amber_minimize.py
+++ b/src/alphafold/relax/amber_minimize.py
@@ -26,6 +26,7 @@ from alphafold.relax import cleanup
 from alphafold.relax import utils
 import ml_collections
 import numpy as np
+import jax
 from simtk import openmm
 from simtk import unit
 from simtk.openmm import app as openmm_app
@@ -486,7 +487,9 @@ def run_pipeline(
       pdb_string = clean_protein(prot, checks=True)
     else:
       pdb_string = ret["min_pdb"]
-    ret.update(get_violation_metrics(prot))
+    # Calculation of violations can cause CUDA errors for some JAX versions.
+    with jax.default_device(jax.devices("cpu")[0]):
+      ret.update(get_violation_metrics(prot))
     ret.update({
         "num_exclusions": len(exclude_residues),
         "iteration": iteration,
@@ -500,51 +503,3 @@ def run_pipeline(
                  ret["num_residue_violations"], ret["num_exclusions"])
     iteration += 1
   return ret
-
-
-def get_initial_energies(pdb_strs: Sequence[str],
-                         stiffness: float = 0.0,
-                         restraint_set: str = "non_hydrogen",
-                         exclude_residues: Optional[Sequence[int]] = None):
-  """Returns initial potential energies for a sequence of PDBs.
-
-  Assumes the input PDBs are ready for minimization, and all have the same
-  topology.
-  Allows time to be saved by not pdbfixing / rebuilding the system.
-
-  Args:
-    pdb_strs: List of PDB strings.
-    stiffness: kcal/mol A**2, spring constant of heavy atom restraining
-      potential.
-    restraint_set: Which atom types to restrain.
-    exclude_residues: An optional list of zero-indexed residues to exclude from
-      restraints.
-
-  Returns:
-    A list of initial energies in the same order as pdb_strs.
-  """
-  exclude_residues = exclude_residues or []
-
-  openmm_pdbs = [openmm_app.PDBFile(PdbStructure(io.StringIO(p)))
-                 for p in pdb_strs]
-  force_field = openmm_app.ForceField("amber99sb.xml")
-  system = force_field.createSystem(openmm_pdbs[0].topology,
-                                    constraints=openmm_app.HBonds)
-  stiffness = stiffness * ENERGY / (LENGTH**2)
-  if stiffness > 0 * ENERGY / (LENGTH**2):
-    _add_restraints(system, openmm_pdbs[0], stiffness, restraint_set,
-                    exclude_residues)
-  simulation = openmm_app.Simulation(openmm_pdbs[0].topology,
-                                     system,
-                                     openmm.LangevinIntegrator(0, 0.01, 0.0),
-                                     openmm.Platform.getPlatformByName("CPU"))
-  energies = []
-  for pdb in openmm_pdbs:
-    try:
-      simulation.context.setPositions(pdb.positions)
-      state = simulation.context.getState(getEnergy=True)
-      energies.append(state.getPotentialEnergy().value_in_unit(ENERGY))
-    except Exception as e:  # pylint: disable=broad-except
-      logging.error("Error getting initial energy, returning large value %s", e)
-      energies.append(unit.Quantity(1e20, ENERGY))
-  return energies
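The CPU pin in `run_pipeline` is a scoped device override available in recent JAX versions; placement outside the `with` block is unchanged. A tiny standalone sketch of the same pattern (the computation is a stand-in for `get_violation_metrics`):

```python
import jax
import jax.numpy as jnp

# Run one computation on the host CPU even when a GPU is the default backend.
with jax.default_device(jax.devices("cpu")[0]):
    metrics_like = jnp.square(jnp.arange(8.0)).sum()  # computed on CPU
```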
- """ - exclude_residues = exclude_residues or [] - - openmm_pdbs = [openmm_app.PDBFile(PdbStructure(io.StringIO(p))) - for p in pdb_strs] - force_field = openmm_app.ForceField("amber99sb.xml") - system = force_field.createSystem(openmm_pdbs[0].topology, - constraints=openmm_app.HBonds) - stiffness = stiffness * ENERGY / (LENGTH**2) - if stiffness > 0 * ENERGY / (LENGTH**2): - _add_restraints(system, openmm_pdbs[0], stiffness, restraint_set, - exclude_residues) - simulation = openmm_app.Simulation(openmm_pdbs[0].topology, - system, - openmm.LangevinIntegrator(0, 0.01, 0.0), - openmm.Platform.getPlatformByName("CPU")) - energies = [] - for pdb in openmm_pdbs: - try: - simulation.context.setPositions(pdb.positions) - state = simulation.context.getState(getEnergy=True) - energies.append(state.getPotentialEnergy().value_in_unit(ENERGY)) - except Exception as e: # pylint: disable=broad-except - logging.error("Error getting initial energy, returning large value %s", e) - energies.append(unit.Quantity(1e20, ENERGY)) - return energies diff --git a/src/alphafold/relax/relax.py b/src/alphafold/relax/relax.py index bd6c9fd04b277679ece2b62a7c373f1127e6b1a6..ebbd72d0247b624c6830dee9c0a01ec5a5d6ae61 100644 --- a/src/alphafold/relax/relax.py +++ b/src/alphafold/relax/relax.py @@ -56,7 +56,8 @@ class AmberRelaxation(object): self._use_gpu = use_gpu def process(self, *, - prot: protein.Protein) -> Tuple[str, Dict[str, Any], np.ndarray]: + prot: protein.Protein + ) -> Tuple[str, Dict[str, Any], Sequence[float]]: """Runs Amber relax on a prediction, adds hydrogens, returns PDB string.""" out = amber_minimize.run_pipeline( prot=prot, max_iterations=self._max_iterations, @@ -73,12 +74,11 @@ class AmberRelaxation(object): 'attempts': out['min_attempts'], 'rmsd': rmsd } - pdb_str = amber_minimize.clean_protein(prot) - min_pdb = utils.overwrite_pdb_coordinates(pdb_str, min_pos) + min_pdb = out['min_pdb'] min_pdb = utils.overwrite_b_factors(min_pdb, prot.b_factors) utils.assert_equal_nonterminal_atom_types( protein.from_pdb_string(min_pdb).atom_mask, prot.atom_mask) violations = out['structural_violations'][ - 'total_per_residue_violations_mask'] + 'total_per_residue_violations_mask'].tolist() return min_pdb, debug_data, violations diff --git a/src/alphafold/relax/relax_test.py b/src/alphafold/relax/relax_test.py index 57e594e8a4f684e8bbab0bf645bad3776cec3d00..8ab5142e31f4955863f4b75c8c5eb8042bdcdb1c 100644 --- a/src/alphafold/relax/relax_test.py +++ b/src/alphafold/relax/relax_test.py @@ -82,7 +82,7 @@ class RunAmberRelaxTest(absltest.TestCase): 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0]) # Check no violations were added. Can't check exactly due to stochasticity. 
diff --git a/src/alphafold/relax/utils.py b/src/alphafold/relax/utils.py
index 4bd4acad4e5e7f071fb0ad98e7711b4856843464..0207df5ba24d4876208a26ee06b7e0c65b509ee8 100644
--- a/src/alphafold/relax/utils.py
+++ b/src/alphafold/relax/utils.py
@@ -17,17 +17,6 @@ import io
 from alphafold.common import residue_constants
 from Bio import PDB
 import numpy as np
-from simtk.openmm import app as openmm_app
-from simtk.openmm.app.internal.pdbstructure import PdbStructure
-
-
-def overwrite_pdb_coordinates(pdb_str: str, pos) -> str:
-  pdb_file = io.StringIO(pdb_str)
-  structure = PdbStructure(pdb_file)
-  topology = openmm_app.PDBFile(structure).getTopology()
-  with io.StringIO() as f:
-    openmm_app.PDBFile.writeFile(topology, pos, f)
-    return f.getvalue()
 
 
 def overwrite_b_factors(pdb_str: str, bfactors: np.ndarray) -> str:
@@ -74,7 +63,7 @@ def assert_equal_nonterminal_atom_types(
   """Checks that pre- and post-minimized proteins have same atom set."""
   # Ignore any terminal OXT atoms which may have been added by minimization.
   oxt = residue_constants.atom_order['OXT']
-  no_oxt_mask = np.ones(shape=atom_mask.shape, dtype=np.bool)
+  no_oxt_mask = np.ones(shape=atom_mask.shape, dtype=bool)
   no_oxt_mask[..., oxt] = False
   np.testing.assert_almost_equal(ref_atom_mask[no_oxt_mask],
                                  atom_mask[no_oxt_mask])
diff --git a/src/run_af2c_fea.py b/src/run_af2c_fea.py
index 23e5407ccf12152418b63ec20b6d10ffe0eb3637..982327143bd8016fe6cc88a0bce6a6bf8486e0fa 100644
--- a/src/run_af2c_fea.py
+++ b/src/run_af2c_fea.py
@@ -40,6 +40,7 @@ from alphafold.common import protein
 from alphafold.common import residue_constants
 from alphafold.data import pipeline
 from alphafold.data import pipeline_multimer
+from alphafold.data import pipeline_uniprot
 from alphafold.data import templates
 from alphafold.data.tools import hhsearch
 from alphafold.data.tools import hmmsearch
@@ -98,14 +99,16 @@ flags.DEFINE_string('obsolete_pdbs_path', None, 'Path to file containing a '
                     'mapping from obsolete PDB IDs to the PDB IDs of their '
                     'replacements.')
 flags.DEFINE_enum('db_preset', 'reduced_dbs',
-                  ['full_dbs', 'reduced_dbs'],
+                  ['full_dbs', 'reduced_dbs', 'uniprot'],
                   'Choose preset MSA database configuration - '
                   'smaller genetic database config (reduced_dbs) or '
-                  'full genetic database config (full_dbs)')
+                  'full genetic database config (full_dbs) or '
+                  'only uniprot database plus pdb (uniprot)')
 flags.DEFINE_enum('feature_mode', 'monomer',
-                  ['monomer', 'monomer+species', 'multimer'],
+                  ['monomer', 'monomer+species', 'monomer+fullpdb', 'multimer'],
                   'Choose the mode of output feature sets - for monomer prediction, '
-                  'monomer plus species ids (for customized pairing later), and '
+                  'monomer plus species ids (for customized pairing later), '
+                  'monomer using full pdb (instead of pdb70) plus species id, and '
                   'for multimer prediction using the default MSA pairing')
 flags.DEFINE_integer('random_seed', None, 'The random seed for the data '
                      'pipeline. By default, this is randomly generated. Note '
@@ -182,21 +185,24 @@ def main(argv):
      raise ValueError(f'Could not find path to the "{tool_name}" binary. Make '
                       'sure it is installed on your system.')
 
-  use_small_bfd = FLAGS.db_preset == 'reduced_dbs'
-  _check_flag('small_bfd_database_path', 'db_preset',
-              should_be_set=use_small_bfd)
-  _check_flag('bfd_database_path', 'db_preset',
-              should_be_set=not use_small_bfd)
-  _check_flag('uniclust30_database_path', 'db_preset',
-              should_be_set=not use_small_bfd)
+  if FLAGS.db_preset != 'uniprot':
+    use_small_bfd = FLAGS.db_preset == 'reduced_dbs'
+    _check_flag('small_bfd_database_path', 'db_preset',
+                should_be_set=use_small_bfd)
+    _check_flag('bfd_database_path', 'db_preset',
+                should_be_set=(FLAGS.db_preset == 'full_dbs'))
+    _check_flag('uniclust30_database_path', 'db_preset',
+                should_be_set=(FLAGS.db_preset == 'full_dbs'))
 
   run_multimer_system = 'multimer' in FLAGS.feature_mode
-  add_species = 'species' in FLAGS.feature_mode
+  add_species = any(x in FLAGS.feature_mode for x in ['species', 'fullpdb'])
   print(f"add_species is {add_species}")
   _check_flag('pdb70_database_path', 'feature_mode',
-              should_be_set=not run_multimer_system)
+              should_be_set=not (run_multimer_system
+                                 or 'fullpdb' in FLAGS.feature_mode))
   _check_flag('pdb_seqres_database_path', 'feature_mode',
-              should_be_set=run_multimer_system)
+              should_be_set=(run_multimer_system
+                             or 'fullpdb' in FLAGS.feature_mode))
   _check_flag('uniprot_database_path', 'feature_mode',
               should_be_set=(run_multimer_system or add_species))
 
@@ -206,7 +212,7 @@ def main(argv):
     raise ValueError('All FASTA paths must have a unique basename.')
 
-  if run_multimer_system:
+  if run_multimer_system or 'fullpdb' in FLAGS.feature_mode:
     template_searcher = hmmsearch.Hmmsearch(
         binary_path=FLAGS.hmmsearch_binary_path,
         hmmbuild_binary_path=FLAGS.hmmbuild_binary_path,
@@ -235,7 +241,8 @@ def main(argv):
   # else:
   #   mono_uniprot_database_path = FLAGS.uniref90_database_path
 
-  monomer_data_pipeline = pipeline.DataPipeline(
+  if FLAGS.db_preset != 'uniprot':
+    monomer_data_pipeline = pipeline.DataPipeline(
       jackhmmer_binary_path=FLAGS.jackhmmer_binary_path,
       hhblits_binary_path=FLAGS.hhblits_binary_path,
       uniref90_database_path=FLAGS.uniref90_database_path,
@@ -249,6 +256,15 @@ def main(argv):
       use_small_bfd=use_small_bfd,
      use_precomputed_msas=FLAGS.use_precomputed_msas,
      add_species=add_species)
+  else:
+    monomer_data_pipeline = pipeline_uniprot.DataPipeline(
+        jackhmmer_binary_path=FLAGS.jackhmmer_binary_path,
+        hhblits_binary_path=FLAGS.hhblits_binary_path,
+        uniprot_database_path=FLAGS.uniprot_database_path,
+        template_searcher=template_searcher,
+        template_featurizer=template_featurizer,
+        use_precomputed_msas=FLAGS.use_precomputed_msas,
+        add_species=add_species)
 
   if run_multimer_system:
     data_pipeline = pipeline_multimer.DataPipeline(
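The reworked checks read more easily as a summary of which database flags each `db_preset` now requires. This condensed view follows from the `_check_flag` calls above (a reading aid, not code from the patch, and not an exhaustive list of pipeline inputs):

```python
# Which database path flags each --db_preset expects, per the checks above.
REQUIRED_DB_FLAGS = {
    'reduced_dbs': ['small_bfd_database_path'],
    'full_dbs': ['bfd_database_path', 'uniclust30_database_path'],
    'uniprot': ['uniprot_database_path'],  # BFD/uniclust checks are skipped
}
```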
diff --git a/src/run_af2c_min.py b/src/run_af2c_min.py
index bf06d2c839699f5e4c77cb13d24438aaed90f73d..0ff4519677a2cc27b55f17ccb733bf56be3ee0ca 100644
--- a/src/run_af2c_min.py
+++ b/src/run_af2c_min.py
@@ -11,6 +11,7 @@ import os
 import pickle
 import re
 import time
+import json
 
 from absl import app
 from absl import flags
@@ -78,7 +79,8 @@ def main(argv):
     output_dir = os.path.join(FLAGS.output_dir, target_name)
     if not os.path.exists(output_dir):
       os.makedirs(output_dir)
-
+
+    relax_metrics = {}
     for afile in os.listdir(input_dir):
       # find all unrelaxed pdb files, ignore ones with 'relaxed' as prefix
       if not afile.endswith(".pdb") or afile.startswith("relaxed_"):
         continue
@@ -96,6 +98,11 @@ def main(argv):
       # Relax the prediction.
       t_0 = time.time()
-      relaxed_pdb_str, log, _ = amber_relaxer.process(prot=unrelaxed_protein)
+      relaxed_pdb_str, log, violations = amber_relaxer.process(prot=unrelaxed_protein)
+      relax_metrics[f'relaxed_{unrelaxed_pdb_file}'] = {
+          'remaining_violations': violations,
+          'remaining_violations_count': sum(violations)
+      }
       relaxation_time = time.time() - t_0
 
       # Fix residue index, and ignore hydrogen atoms
@@ -107,7 +114,12 @@ def main(argv):
       with open(relaxed_output_path, 'w') as f:
         f.write(relaxed_pdb_str)
 
-      logging.info(f"{target_name} relaxation done, time spent {relaxation_time:.1f} seconds, Efinal {log['final_energy']:.2f}, rmsd {log['rmsd']:.2f}")
+      logging.info(f"{target_name} relaxation done, time spent {relaxation_time:.1f} seconds, "
+                   f"Efinal {log['final_energy']:.2f}, rmsd {log['rmsd']:.2f}")
+
+    relax_metrics_path = os.path.join(output_dir, 'relax_metrics.json')
+    with open(relax_metrics_path, 'w') as f:
+      f.write(json.dumps(relax_metrics, indent=4))
 
 if __name__ == '__main__':
   flags.mark_flags_as_required([
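For reference, relax_metrics.json maps each relaxed model to its per-residue violation mask and a total count. A hypothetical entry (the file name and values are illustrative, not produced by an actual run):

```python
import json

relax_metrics = {
    'relaxed_model_1_ptm.pdb': {
        'remaining_violations': [0.0, 1.0, 0.0],  # one residue still violating
        'remaining_violations_count': 1.0,
    }
}
print(json.dumps(relax_metrics, indent=4))
```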
diff --git a/src/run_af2c_mod.py b/src/run_af2c_mod.py
index 2f6f8fc74db91954482e873c633ab623d02fa73b..0ac472b0356c335054260d1c330cd25c0e7d751f 100644
--- a/src/run_af2c_mod.py
+++ b/src/run_af2c_mod.py
@@ -20,7 +20,7 @@
 #
 # Note: AF2Complex is a modified, enhanced version of AlphaFold 2.
 # Mu Gao and Davi Nakajima An
-# Georgia Institute of Technology, 2021-2022
+# Georgia Institute of Technology, 2021-2023
 #
 """AF2Complex: protein complex structure prediction with deep learning"""
 import json
@@ -116,6 +116,10 @@ flags.DEFINE_enum('msa_pairing', None,
 flags.DEFINE_boolean('do_cluster_analysis', False, 'Whether to print out clusters of protein chains in the prediction')
 flags.DEFINE_integer('cluster_edge_thres', 10, 'The number of contacts between chains that constitute an edge in the '
                      'cluster analysis', lower_bound=0)
+flags.DEFINE_float('pdb_iscore_cf', -1.0, 'If an interface iScore is present, only write the pdb of the structural model '
+                   'if the iScore is larger than this cutoff value. Useful for large-scale screening. ')
+flags.DEFINE_boolean('allow_dropout', False, 'Allow dropout during model inference. This is an experimental feature. '
+                     'Default is disabled.')
 
 FLAGS = flags.FLAGS
 
@@ -274,7 +278,7 @@ def predict_structure(
     # Get mean pLDDT confidence metric.
     plddt = np.mean(result['plddt'])
     plddts[log_model_name] = round(plddt, 2)
-    ptm = 0
+    ptm = 0; iptm = 0
     if 'ptm' in result:
       ptm = result['ptm'].tolist()
       ptms[log_model_name] = round(ptm, 4)
@@ -299,16 +303,20 @@ def predict_structure(
         print(f"Info: {target_name} {log_model_name}, ",
               f"tol = {tol:5.2f}, pLDDT = {plddt:.2f}, pTM-score = {ptm:.4f}", end='')
         if len(monomers) > 1 or monomers[0]['copy_number'] > 1:  # hetero- or homo-oligomer target
-          print(f", piTM-score = {pitm:.4f}, interface-score = {inter_sc:.4f}", end='')
-          print(f", iRes = {inter_residues:<4d} iCnt = {inter_contacts:<4.0f}")
+          print(f", piTM-score = {pitm:.4f}, interface-score = {inter_sc:.4f}",
+                f", iRes = {inter_residues:<4d} iCnt = {inter_contacts:<4.0f}")
         else:
           print('')
     else:
       print(f"Info: {target_name} {log_model_name} performed {tot_recycle} recycles,",
-            f"final tol = {tol_value:.2f}, pLDDT = {plddt:.2f}, pTM-score = {ptm:.4f}", end='')
-      if len(monomers) > 1 or monomers[0]['copy_number'] > 1:
-        print(f", piTM-score = {pitm:.4f}, interface-score = {inter_sc:.4f}", end='')
-        print(f", iRes = {inter_residues:<4d} iCnt = {inter_contacts:<4.0f}")
+            f"final tol = {tol_value:.2f}, pLDDT = {plddt:.2f}", end='')
+      if 'iptm+ptm' in result:
+        print(f", iptm+ptm = {iptm:.4f}", end='')
+      else:
+        print(f", pTM-score = {ptm:.4f}", end='')
+      if len(monomers) > 1 or monomers[0]['copy_number'] > 1:
+        print(f", piTM-score = {pitm:.4f}, interface-score = {inter_sc:.4f}",
+              f", iRes = {inter_residues:<4d} iCnt = {inter_contacts:<4.0f}")
       if 'cluster_analysis' in result:
         clus_res = result['cluster_analysis']
         idx2chain_name = get_asymid2chain_name(target)
@@ -326,12 +334,13 @@ def predict_structure(
         print('')
 
     # Save the model outputs, not saving pkl for intermediate recycles to save storage space
-    if (recycle_index == tot_recycle and flags.output_pickle) or FLAGS.save_recycled == 3:
+    # skip saving pkl if the iScore is less than the specified cutoff
+    if ((recycle_index == tot_recycle and flags.output_pickle) or FLAGS.save_recycled == 3) and inter_sc > FLAGS.pdb_iscore_cf:
       result_output_path = os.path.join(out_dir, f'{log_model_name}.pkl')
       with open(result_output_path, 'wb') as f:
         pickle.dump(result, f, protocol=4)
 
-    if recycle_index == tot_recycle or FLAGS.save_recycled >= 2:
+    if (recycle_index == tot_recycle or FLAGS.save_recycled >= 2) and inter_sc > FLAGS.pdb_iscore_cf:
       # Set the b-factors to the per-residue plddt
       final_atom_mask = result['structure_module']['final_atom_mask']
       b_factors = result['plddt'][:, None] * final_atom_mask
@@ -469,6 +478,11 @@ def main(argv):
   model_config.model.recycle_tol = recycle_tol
   model_config.model.save_recycled = FLAGS.save_recycled
 
+  # Allow dropout during model inference; an advanced feature intended for expert users.
+  if FLAGS.allow_dropout:
+    print("Info: allow dropout for model inference")
+    model_config.model.global_config.eval_dropout = True
+
   model_params = data.get_model_haiku_params(
       model_name=model_name, data_dir=FLAGS.data_dir)
   model_runner = model.RunModel(model_config, model_params)
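Both new run_af2c_mod.py flags target large-scale screening: `--pdb_iscore_cf` gates whether pkl and pdb outputs are written at all, and `--allow_dropout` turns on `eval_dropout` so repeated runs sample more diverse models. A sketch of the output gate with an illustrative cutoff (the default of -1.0 effectively writes everything):

```python
def keep_model_outputs(inter_sc: float, pdb_iscore_cf: float = -1.0) -> bool:
    # Mirrors the conditions added above: outputs are saved only when the
    # interface-score exceeds the cutoff (e.g., --pdb_iscore_cf=0.35).
    return inter_sc > pdb_iscore_cf
```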
diff --git a/src/setup.py b/src/setup.py
index 762bd7be8f03bdb6142493dcd8a419dcd066939d..753f70097db6c57489ba07aa30e8830ff6e17634 100644
--- a/src/setup.py
+++ b/src/setup.py
@@ -18,7 +18,7 @@ from setuptools import setup
 
 setup(
     name='alphafold',
-    version='2.0.0',
+    version='2.3.1',
    description='An implementation of the inference pipeline of AlphaFold v2.0.'
                'This is a completely new model that was entered as AlphaFold2 in CASP14 '
                'and published in Nature.',
@@ -38,10 +38,14 @@ setup(
         'jax',
         'ml-collections',
         'numpy',
+        'pandas',
         'scipy',
-        'tensorflow',
+        'tensorflow-cpu',
+    ],
+    tests_require=[
+        'matplotlib',  # For notebook_utils_test.
+        'mock',
     ],
-    tests_require=['mock'],
     classifiers=[
         'Development Status :: 5 - Production/Stable',
         'Intended Audience :: Science/Research',
@@ -49,6 +53,9 @@ setup(
         'Operating System :: POSIX :: Linux',
         'Programming Language :: Python :: 3.6',
         'Programming Language :: Python :: 3.7',
+        'Programming Language :: Python :: 3.8',
+        'Programming Language :: Python :: 3.9',
+        'Programming Language :: Python :: 3.10',
         'Topic :: Scientific/Engineering :: Artificial Intelligence',
     ],
 )
diff --git a/tools/run_interface_score.py b/tools/run_interface_score.py
index 3675be76e1253ca5964c08090c8851e1ded8c113..0b4611a3c10073a40261327cc9f567f5117fd9f0 100644
--- a/tools/run_interface_score.py
+++ b/tools/run_interface_score.py
@@ -1,15 +1,24 @@
 """Calculate interface score using pickle output file generated by AF2Complex"""
+#
+# Mu Gao and Davi Nakajima An
+# Georgia Institute of Technology
+#
 import os
 import sys
-
-sys.path.append('../src/')
 import pickle
 import re
 import json
+import time
+import logging
+
 from absl import app
 from absl import flags
 from tqdm import tqdm
 
+parent_dir = os.path.dirname( os.path.dirname(os.path.realpath(__file__)) )
+af_dir = os.path.join(parent_dir, 'src')
+sys.path.append(af_dir)
+
 from run_af2c_mod import FLAGS, get_asymid2chain_name
 from alphafold.data.complex import make_complex_features
@@ -22,6 +31,13 @@ import alphafold.data.complex as af2c
 import numpy as np
 # Internal import (7716).
 
+flags.DEFINE_string('model_str', 'model_', 'Only score models whose file names contain the specified '
+                    'string, e.g., ranked_top1. The default processes all model_* pickle files')
+flags.DEFINE_float('interface_dist_thres', 4.5, 'The distance threshold in Angstrom for interface residues. '
+                   'If a heavy atom of residue i in chain A is less than this distance from '
+                   'another heavy atom of residue j in chain B, residues i and j are interface residues. '
+                   'This is only used to calculate the confidence metrics such as the interface score.',
+                   lower_bound=3.5)
 
 def get_asym_id(target, flags):
   """Defines the sequence of preprocessing steps to get the asym_id feature
@@ -82,7 +98,12 @@ def main(argv):
     asym_id = feature_dict['asym_id']
 
     for pkl_file in os.listdir(target_dir):
-      if ".pkl" in pkl_file and pkl_file.startswith("model_"):
+      # find all pickle files
+      if not pkl_file.endswith(".pkl"):
+        continue
+
+      if FLAGS.model_str in pkl_file:
+        t_0 = time.time()
         model_name = os.path.basename( pkl_file ).split(".")[0]
         model_config = config.model_config(pkl_file[:7])
         breaks = np.linspace(
@@ -103,16 +124,18 @@ def main(argv):
             result['structure_module']['final_atom_positions'],
             result['structure_module']['final_atom_mask'],
             super_asym_id,
+            distance_threshold=FLAGS.interface_dist_thres,
             is_probs=True)
 
         ptm = result['ptm'].tolist()
         pitm = result['pitm']['score'].tolist()
+        iptm = result['iptm+ptm'].tolist()
         inter_sc = res['score'].tolist()
         inter_residues = res['num_residues'].tolist()
         inter_contacts = res['num_contacts'].tolist()
 
-        print(f"Info: {target_name} (chains: {full_name}) {model_name} pTM-score = {ptm:.4f}, ",
+        print(f"Info: {target_name} (chains: {full_name}) {model_name} iptm+ptm = {iptm:.4f}, ",
              f"piTM-score = {pitm:.4f}, iRes = {inter_residues:<4d}, iCnt = {inter_contacts:<4.0f}, interface-score = {inter_sc:.4f}",)
 
         if FLAGS.do_cluster_analysis:
@@ -129,6 +152,8 @@ def main(argv):
           print(f"Info: num_clusters = {clus_res['num_clusters']}, cluster_sizes = {clus_res['cluster_size']}, ",
                 f"clusters = {cluster_identities}\n")
+
+        logging.info('Interface score calculation time spent: %.1f seconds', time.time() - t_0)
 
 '''
  fields = model_name.split('_')