diff --git a/README.md b/README.md index d85e9d69a176b788650ddcaf975efe91107553d9..1af8c720e0d575fd9424feb2757dd446cadbf81d 100644 --- a/README.md +++ b/README.md @@ -61,8 +61,9 @@ Please follow these steps: 1. Download genetic databases and model parameters: - * Install `aria2c` (on most Linux distributions it is available via the - package manager). + * Install `aria2c`. On most Linux distributions it is available via the + package manager as the `aria2` package (on Debian-based distributions this + can be installed by running `sudo apt install aria2`). * Please use the script `scripts/download_all_data.sh` to download and set up full databases. This may take substantial time (download size is @@ -362,9 +363,11 @@ section. --output_dir=/home/user/absolute_path_to_the_output_dir ``` -1. After generating the predicted model, by default AlphaFold runs a relaxation - step to improve geometrical quality. You can control this via `--run_relax=true` - (default) or `--run_relax=false`. +1. After generating the predicted model, AlphaFold runs a relaxation + step to improve local geometry. By default, only the best model (by + pLDDT) is relaxed (`--models_to_relax=best`), but also all of the models + (`--models_to_relax=all`) or none of the models (`--models_to_relax=none`) + can be relaxed. 1. The relaxation step can be run on GPU (faster, but could be less stable) or CPU (slow, but stable). This can be controlled with `--enable_gpu_relax=true` diff --git a/docker/run_docker.py b/docker/run_docker.py index bef5fdf063ef8a621993a03f30ffcfc8b10970fd..155d8fe2cd124497d6028d2288a02a8ae68b23fa 100644 --- a/docker/run_docker.py +++ b/docker/run_docker.py @@ -28,12 +28,15 @@ from docker import types flags.DEFINE_bool( 'use_gpu', True, 'Enable NVIDIA runtime to run with GPUs.') -flags.DEFINE_boolean( - 'run_relax', True, - 'Whether to run the final relaxation step on the predicted models. Turning ' - 'relax off might result in predictions with distracting stereochemical ' - 'violations but might help in case you are having issues with the ' - 'relaxation stage.') +flags.DEFINE_enum('models_to_relax', 'best', ['best', 'all', 'none'], + 'The models to run the final relaxation step on. ' + 'If `all`, all models are relaxed, which may be time ' + 'consuming. If `best`, only the most confident model is ' + 'relaxed. If `none`, relaxation is not run. Turning off ' + 'relaxation might result in predictions with ' + 'distracting stereochemical violations but might help ' + 'in case you are having issues with the relaxation ' + 'stage.') flags.DEFINE_bool( 'enable_gpu_relax', True, 'Run relax on GPU if GPU is enabled.') flags.DEFINE_string( @@ -221,7 +224,7 @@ def main(argv): f'--benchmark={FLAGS.benchmark}', f'--use_precomputed_msas={FLAGS.use_precomputed_msas}', f'--num_multimer_predictions_per_model={FLAGS.num_multimer_predictions_per_model}', - f'--run_relax={FLAGS.run_relax}', + f'--models_to_relax={FLAGS.models_to_relax}', f'--use_gpu_relax={use_gpu_relax}', '--logtostderr', ]) diff --git a/notebooks/AlphaFold.ipynb b/notebooks/AlphaFold.ipynb index 44b6eaa8230561e2c83fefc904424fcb48a79490..8aaf5bcaba3b3522e3ad748f2f6d870c07d0c5b8 100644 --- a/notebooks/AlphaFold.ipynb +++ b/notebooks/AlphaFold.ipynb @@ -342,6 +342,7 @@ "from concurrent import futures\n", "import json\n", "import random\n", + "import shutil\n", "\n", "from urllib import request\n", "from google.colab import files\n", @@ -773,7 +774,7 @@ " f.write(pae_data)\n", "\n", "# --- Download the predictions ---\n", - "!zip -q -r {output_dir}.zip {output_dir}\n", + "shutil.make_archive(base_name='prediction', format='zip', root_dir=output_dir)\n", "files.download(f'{output_dir}.zip')" ] }, diff --git a/requirements.txt b/requirements.txt index f8098f4afb40c428f148212fa8c3cfc0dc016bfb..a5819a1d3caea1be2ad1731af11d6fa55ce33442 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,6 +9,5 @@ jax==0.3.25 ml-collections==0.1.0 numpy==1.21.6 pandas==1.3.4 -protobuf==3.20.1 scipy==1.7.0 -tensorflow-cpu==2.9.0 +tensorflow-cpu==2.11.0 diff --git a/run_alphafold.py b/run_alphafold.py index 4ddb31396f03f8b30deaa8e62982c2d9459ed500..0d89bfb47c8fc394e85fdf403e8492b88523cc8e 100644 --- a/run_alphafold.py +++ b/run_alphafold.py @@ -13,6 +13,7 @@ # limitations under the License. """Full AlphaFold protein structure prediction script.""" +import enum import json import os import pathlib @@ -21,7 +22,7 @@ import random import shutil import sys import time -from typing import Dict, Union +from typing import Any, Dict, Mapping, Union from absl import app from absl import flags @@ -37,12 +38,20 @@ from alphafold.model import config from alphafold.model import data from alphafold.model import model from alphafold.relax import relax +import jax.numpy as jnp import numpy as np # Internal import (7716). logging.set_verbosity(logging.INFO) + +@enum.unique +class ModelsToRelax(enum.Enum): + ALL = 0 + BEST = 1 + NONE = 2 + flags.DEFINE_list( 'fasta_paths', None, 'Paths to FASTA files, each containing a prediction ' 'target that will be folded one after another. If a FASTA file contains ' @@ -119,11 +128,15 @@ flags.DEFINE_boolean('use_precomputed_msas', False, 'Whether to read MSAs that ' 'runs that are to reuse the MSAs. WARNING: This will not ' 'check if the sequence, database or configuration have ' 'changed.') -flags.DEFINE_boolean('run_relax', True, 'Whether to run the final relaxation ' - 'step on the predicted models. Turning relax off might ' - 'result in predictions with distracting stereochemical ' - 'violations but might help in case you are having issues ' - 'with the relaxation stage.') +flags.DEFINE_enum_class('models_to_relax', ModelsToRelax.BEST, ModelsToRelax, + 'The models to run the final relaxation step on. ' + 'If `all`, all models are relaxed, which may be time ' + 'consuming. If `best`, only the most confident model ' + 'is relaxed. If `none`, relaxation is not run. Turning ' + 'off relaxation might result in predictions with ' + 'distracting stereochemical violations but might help ' + 'in case you are having issues with the relaxation ' + 'stage.') flags.DEFINE_boolean('use_gpu_relax', None, 'Whether to relax on GPU. ' 'Relax on GPU can be much faster than CPU, so it is ' 'recommended to enable if possible. GPUs must be available' @@ -148,6 +161,16 @@ def _check_flag(flag_name: str, f'"--{other_flag_name}={FLAGS[other_flag_name].value}".') +def _jnp_to_np(output: Dict[str, Any]) -> Dict[str, Any]: + """Recursively changes jax arrays to numpy arrays.""" + for k, v in output.items(): + if isinstance(v, dict): + output[k] = _jnp_to_np(v) + elif isinstance(v, jnp.ndarray): + output[k] = np.array(v) + return output + + def predict_structure( fasta_path: str, fasta_name: str, @@ -156,7 +179,8 @@ def predict_structure( model_runners: Dict[str, model.RunModel], amber_relaxer: relax.AmberRelaxation, benchmark: bool, - random_seed: int): + random_seed: int, + models_to_relax: ModelsToRelax): """Predicts structure using AlphaFold for the given sequence.""" logging.info('Predicting %s', fasta_name) timings = {} @@ -180,6 +204,7 @@ def predict_structure( pickle.dump(feature_dict, f, protocol=4) unrelaxed_pdbs = {} + unrelaxed_proteins = {} relaxed_pdbs = {} relax_metrics = {} ranking_confidences = {} @@ -217,10 +242,13 @@ def predict_structure( plddt = prediction_result['plddt'] ranking_confidences[model_name] = prediction_result['ranking_confidence'] + # Remove jax dependency from results. + np_prediction_result = _jnp_to_np(dict(prediction_result)) + # Save the model outputs. result_output_path = os.path.join(output_dir, f'result_{model_name}.pkl') with open(result_output_path, 'wb') as f: - pickle.dump(prediction_result, f, protocol=4) + pickle.dump(np_prediction_result, f, protocol=4) # Add the predicted LDDT in the b-factor column. # Note that higher predicted LDDT value means higher model confidence. @@ -232,38 +260,48 @@ def predict_structure( b_factors=plddt_b_factors, remove_leading_feature_dimension=not model_runner.multimer_mode) + unrelaxed_proteins[model_name] = unrelaxed_protein unrelaxed_pdbs[model_name] = protein.to_pdb(unrelaxed_protein) unrelaxed_pdb_path = os.path.join(output_dir, f'unrelaxed_{model_name}.pdb') with open(unrelaxed_pdb_path, 'w') as f: f.write(unrelaxed_pdbs[model_name]) - if amber_relaxer: - # Relax the prediction. - t_0 = time.time() - relaxed_pdb_str, _, violations = amber_relaxer.process( - prot=unrelaxed_protein) - relax_metrics[model_name] = { - 'remaining_violations': violations, - 'remaining_violations_count': sum(violations) - } - timings[f'relax_{model_name}'] = time.time() - t_0 - - relaxed_pdbs[model_name] = relaxed_pdb_str - - # Save the relaxed PDB. - relaxed_output_path = os.path.join( - output_dir, f'relaxed_{model_name}.pdb') - with open(relaxed_output_path, 'w') as f: - f.write(relaxed_pdb_str) - - # Rank by model confidence and write out relaxed PDBs in rank order. - ranked_order = [] - for idx, (model_name, _) in enumerate( - sorted(ranking_confidences.items(), key=lambda x: x[1], reverse=True)): - ranked_order.append(model_name) + # Rank by model confidence. + ranked_order = [ + model_name for model_name, confidence in + sorted(ranking_confidences.items(), key=lambda x: x[1], reverse=True)] + + # Relax predictions. + if models_to_relax == ModelsToRelax.BEST: + to_relax = [ranked_order[0]] + elif models_to_relax == ModelsToRelax.ALL: + to_relax = ranked_order + elif models_to_relax == ModelsToRelax.NONE: + to_relax = [] + + for model_name in to_relax: + t_0 = time.time() + relaxed_pdb_str, _, violations = amber_relaxer.process( + prot=unrelaxed_proteins[model_name]) + relax_metrics[model_name] = { + 'remaining_violations': violations, + 'remaining_violations_count': sum(violations) + } + timings[f'relax_{model_name}'] = time.time() - t_0 + + relaxed_pdbs[model_name] = relaxed_pdb_str + + # Save the relaxed PDB. + relaxed_output_path = os.path.join( + output_dir, f'relaxed_{model_name}.pdb') + with open(relaxed_output_path, 'w') as f: + f.write(relaxed_pdb_str) + + # Write out relaxed PDBs in rank order. + for idx, model_name in enumerate(ranked_order): ranked_output_path = os.path.join(output_dir, f'ranked_{idx}.pdb') with open(ranked_output_path, 'w') as f: - if amber_relaxer: + if model_name in relaxed_pdbs: f.write(relaxed_pdbs[model_name]) else: f.write(unrelaxed_pdbs[model_name]) @@ -279,7 +317,7 @@ def predict_structure( timings_output_path = os.path.join(output_dir, 'timings.json') with open(timings_output_path, 'w') as f: f.write(json.dumps(timings, indent=4)) - if amber_relaxer: + if models_to_relax != ModelsToRelax.NONE: relax_metrics_path = os.path.join(output_dir, 'relax_metrics.json') with open(relax_metrics_path, 'w') as f: f.write(json.dumps(relax_metrics, indent=4)) @@ -386,16 +424,13 @@ def main(argv): logging.info('Have %d models: %s', len(model_runners), list(model_runners.keys())) - if FLAGS.run_relax: - amber_relaxer = relax.AmberRelaxation( - max_iterations=RELAX_MAX_ITERATIONS, - tolerance=RELAX_ENERGY_TOLERANCE, - stiffness=RELAX_STIFFNESS, - exclude_residues=RELAX_EXCLUDE_RESIDUES, - max_outer_iterations=RELAX_MAX_OUTER_ITERATIONS, - use_gpu=FLAGS.use_gpu_relax) - else: - amber_relaxer = None + amber_relaxer = relax.AmberRelaxation( + max_iterations=RELAX_MAX_ITERATIONS, + tolerance=RELAX_ENERGY_TOLERANCE, + stiffness=RELAX_STIFFNESS, + exclude_residues=RELAX_EXCLUDE_RESIDUES, + max_outer_iterations=RELAX_MAX_OUTER_ITERATIONS, + use_gpu=FLAGS.use_gpu_relax) random_seed = FLAGS.random_seed if random_seed is None: @@ -413,7 +448,8 @@ def main(argv): model_runners=model_runners, amber_relaxer=amber_relaxer, benchmark=FLAGS.benchmark, - random_seed=random_seed) + random_seed=random_seed, + models_to_relax=FLAGS.models_to_relax) if __name__ == '__main__': diff --git a/run_alphafold_test.py b/run_alphafold_test.py index b91189c9a05eabc5e6647a5d7ab3293c27c28c85..5e0d7699924c2c240ec8ca9af53e540413b5faaf 100644 --- a/run_alphafold_test.py +++ b/run_alphafold_test.py @@ -28,10 +28,10 @@ import numpy as np class RunAlphafoldTest(parameterized.TestCase): @parameterized.named_parameters( - ('relax', True), - ('no_relax', False), + ('relax', run_alphafold.ModelsToRelax.ALL), + ('no_relax', run_alphafold.ModelsToRelax.NONE), ) - def test_end_to_end(self, do_relax): + def test_end_to_end(self, models_to_relax): data_pipeline_mock = mock.Mock() model_runner_mock = mock.Mock() @@ -72,9 +72,11 @@ class RunAlphafoldTest(parameterized.TestCase): output_dir_base=out_dir, data_pipeline=data_pipeline_mock, model_runners={'model1': model_runner_mock}, - amber_relaxer=amber_relaxer_mock if do_relax else None, + amber_relaxer=amber_relaxer_mock, benchmark=False, - random_seed=0) + random_seed=0, + models_to_relax=models_to_relax, + ) base_output_files = os.listdir(out_dir) self.assertIn('target.fasta', base_output_files) @@ -85,7 +87,7 @@ class RunAlphafoldTest(parameterized.TestCase): 'features.pkl', 'msas', 'ranked_0.pdb', 'ranking_debug.json', 'result_model1.pkl', 'timings.json', 'unrelaxed_model1.pdb', ] - if do_relax: + if models_to_relax == run_alphafold.ModelsToRelax.ALL: expected_files.extend(['relaxed_model1.pdb', 'relax_metrics.json']) with open(os.path.join(out_dir, 'test', 'relax_metrics.json')) as f: relax_metrics = json.loads(f.read())