diff --git a/README.md b/README.md
index d85e9d69a176b788650ddcaf975efe91107553d9..1af8c720e0d575fd9424feb2757dd446cadbf81d 100644
--- a/README.md
+++ b/README.md
@@ -61,8 +61,9 @@ Please follow these steps:
 
 1.  Download genetic databases and model parameters:
 
-    *   Install `aria2c` (on most Linux distributions it is available via the
-    package manager).
+    *   Install `aria2c`. On most Linux distributions it is available via the
+    package manager as the `aria2` package (on Debian-based distributions this
+    can be installed by running `sudo apt install aria2`).
 
     *   Please use the script `scripts/download_all_data.sh` to download
     and set up full databases. This may take substantial time (download size is
@@ -362,9 +363,11 @@ section.
       --output_dir=/home/user/absolute_path_to_the_output_dir
     ```
 
-1.  After generating the predicted model, by default AlphaFold runs a relaxation
-    step to improve geometrical quality. You can control this via `--run_relax=true`
-    (default) or `--run_relax=false`.
+1.  After generating the predicted model, AlphaFold runs a relaxation
+    step to improve local geometry. By default, only the best model (by
+    pLDDT) is relaxed (`--models_to_relax=best`); alternatively, all models
+    (`--models_to_relax=all`) or none of them (`--models_to_relax=none`) can
+    be relaxed.
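+
+    For example, an invocation along these lines (illustrative only; the
+    FASTA path, template date and output directory are placeholders) would
+    relax all models when running via Docker:
+
+    ```bash
+    python3 docker/run_docker.py \
+      --fasta_paths=/home/user/target.fasta \
+      --max_template_date=2022-01-01 \
+      --data_dir=$DOWNLOAD_DIR \
+      --output_dir=/home/user/absolute_path_to_the_output_dir \
+      --models_to_relax=all
+    ```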
 
 1.  The relaxation step can be run on GPU (faster, but could be less stable) or
     CPU (slow, but stable). This can be controlled with `--enable_gpu_relax=true`
diff --git a/docker/run_docker.py b/docker/run_docker.py
index bef5fdf063ef8a621993a03f30ffcfc8b10970fd..155d8fe2cd124497d6028d2288a02a8ae68b23fa 100644
--- a/docker/run_docker.py
+++ b/docker/run_docker.py
@@ -28,12 +28,15 @@ from docker import types
 
 flags.DEFINE_bool(
     'use_gpu', True, 'Enable NVIDIA runtime to run with GPUs.')
-flags.DEFINE_boolean(
-    'run_relax', True,
-    'Whether to run the final relaxation step on the predicted models. Turning '
-    'relax off might result in predictions with distracting stereochemical '
-    'violations but might help in case you are having issues with the '
-    'relaxation stage.')
+flags.DEFINE_enum('models_to_relax', 'best', ['best', 'all', 'none'],
+                  'The models to run the final relaxation step on. '
+                  'If `all`, all models are relaxed, which may be time '
+                  'consuming. If `best`, only the most confident model is '
+                  'relaxed. If `none`, relaxation is not run. Turning off '
+                  'relaxation might result in predictions with '
+                  'distracting stereochemical violations but might help '
+                  'in case you are having issues with the relaxation '
+                  'stage.')
 flags.DEFINE_bool(
     'enable_gpu_relax', True, 'Run relax on GPU if GPU is enabled.')
 flags.DEFINE_string(
@@ -221,7 +224,7 @@ def main(argv):
       f'--benchmark={FLAGS.benchmark}',
       f'--use_precomputed_msas={FLAGS.use_precomputed_msas}',
       f'--num_multimer_predictions_per_model={FLAGS.num_multimer_predictions_per_model}',
-      f'--run_relax={FLAGS.run_relax}',
+      f'--models_to_relax={FLAGS.models_to_relax}',
       f'--use_gpu_relax={use_gpu_relax}',
       '--logtostderr',
   ])
diff --git a/notebooks/AlphaFold.ipynb b/notebooks/AlphaFold.ipynb
index 44b6eaa8230561e2c83fefc904424fcb48a79490..8aaf5bcaba3b3522e3ad748f2f6d870c07d0c5b8 100644
--- a/notebooks/AlphaFold.ipynb
+++ b/notebooks/AlphaFold.ipynb
@@ -342,6 +342,7 @@
         "from concurrent import futures\n",
         "import json\n",
         "import random\n",
+        "import shutil\n",
         "\n",
         "from urllib import request\n",
         "from google.colab import files\n",
@@ -773,7 +774,7 @@
         "    f.write(pae_data)\n",
         "\n",
         "# --- Download the predictions ---\n",
-        "!zip -q -r {output_dir}.zip {output_dir}\n",
+        "shutil.make_archive(base_name='prediction', format='zip', root_dir=output_dir)\n",
         "files.download(f'{output_dir}.zip')"
       ]
     },
diff --git a/requirements.txt b/requirements.txt
index f8098f4afb40c428f148212fa8c3cfc0dc016bfb..a5819a1d3caea1be2ad1731af11d6fa55ce33442 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -9,6 +9,5 @@ jax==0.3.25
 ml-collections==0.1.0
 numpy==1.21.6
 pandas==1.3.4
-protobuf==3.20.1
 scipy==1.7.0
-tensorflow-cpu==2.9.0
+tensorflow-cpu==2.11.0
diff --git a/run_alphafold.py b/run_alphafold.py
index 4ddb31396f03f8b30deaa8e62982c2d9459ed500..0d89bfb47c8fc394e85fdf403e8492b88523cc8e 100644
--- a/run_alphafold.py
+++ b/run_alphafold.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 """Full AlphaFold protein structure prediction script."""
+import enum
 import json
 import os
 import pathlib
@@ -21,7 +22,7 @@ import random
 import shutil
 import sys
 import time
-from typing import Dict, Union
+from typing import Any, Dict, Mapping, Union
 
 from absl import app
 from absl import flags
@@ -37,12 +38,20 @@ from alphafold.model import config
 from alphafold.model import data
 from alphafold.model import model
 from alphafold.relax import relax
+import jax.numpy as jnp
 import numpy as np
 
 # Internal import (7716).
 
 logging.set_verbosity(logging.INFO)
 
+
+@enum.unique
+class ModelsToRelax(enum.Enum):
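+  """Which predicted models to run the Amber relaxation step on."""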
+  ALL = 0
+  BEST = 1
+  NONE = 2
+
 flags.DEFINE_list(
     'fasta_paths', None, 'Paths to FASTA files, each containing a prediction '
     'target that will be folded one after another. If a FASTA file contains '
@@ -119,11 +128,15 @@ flags.DEFINE_boolean('use_precomputed_msas', False, 'Whether to read MSAs that '
                      'runs that are to reuse the MSAs. WARNING: This will not '
                      'check if the sequence, database or configuration have '
                      'changed.')
-flags.DEFINE_boolean('run_relax', True, 'Whether to run the final relaxation '
-                     'step on the predicted models. Turning relax off might '
-                     'result in predictions with distracting stereochemical '
-                     'violations but might help in case you are having issues '
-                     'with the relaxation stage.')
+flags.DEFINE_enum_class('models_to_relax', ModelsToRelax.BEST, ModelsToRelax,
+                        'The models to run the final relaxation step on. '
+                        'If `all`, all models are relaxed, which may be time '
+                        'consuming. If `best`, only the most confident model '
+                        'is relaxed. If `none`, relaxation is not run. Turning '
+                        'off relaxation might result in predictions with '
+                        'distracting stereochemical violations but might help '
+                        'in case you are having issues with the relaxation '
+                        'stage.')
 flags.DEFINE_boolean('use_gpu_relax', None, 'Whether to relax on GPU. '
                      'Relax on GPU can be much faster than CPU, so it is '
                      'recommended to enable if possible. GPUs must be available'
@@ -148,6 +161,16 @@ def _check_flag(flag_name: str,
                      f'"--{other_flag_name}={FLAGS[other_flag_name].value}".')
 
 
+def _jnp_to_np(output: Dict[str, Any]) -> Dict[str, Any]:
+  """Recursively changes jax arrays to numpy arrays."""
+  for k, v in output.items():
+    if isinstance(v, dict):
+      output[k] = _jnp_to_np(v)
+    elif isinstance(v, jnp.ndarray):
+      output[k] = np.array(v)
+  return output
+
+
 def predict_structure(
     fasta_path: str,
     fasta_name: str,
@@ -156,7 +179,8 @@ def predict_structure(
     model_runners: Dict[str, model.RunModel],
     amber_relaxer: relax.AmberRelaxation,
     benchmark: bool,
-    random_seed: int):
+    random_seed: int,
+    models_to_relax: ModelsToRelax):
   """Predicts structure using AlphaFold for the given sequence."""
   logging.info('Predicting %s', fasta_name)
   timings = {}
@@ -180,6 +204,7 @@ def predict_structure(
     pickle.dump(feature_dict, f, protocol=4)
 
   unrelaxed_pdbs = {}
+  unrelaxed_proteins = {}
   relaxed_pdbs = {}
   relax_metrics = {}
   ranking_confidences = {}
@@ -217,10 +242,13 @@ def predict_structure(
     plddt = prediction_result['plddt']
     ranking_confidences[model_name] = prediction_result['ranking_confidence']
 
+    # Remove jax dependency from results.
+    np_prediction_result = _jnp_to_np(dict(prediction_result))
+
     # Save the model outputs.
     result_output_path = os.path.join(output_dir, f'result_{model_name}.pkl')
     with open(result_output_path, 'wb') as f:
-      pickle.dump(prediction_result, f, protocol=4)
+      pickle.dump(np_prediction_result, f, protocol=4)
 
     # Add the predicted LDDT in the b-factor column.
     # Note that higher predicted LDDT value means higher model confidence.
@@ -232,38 +260,48 @@ def predict_structure(
         b_factors=plddt_b_factors,
         remove_leading_feature_dimension=not model_runner.multimer_mode)
 
+    unrelaxed_proteins[model_name] = unrelaxed_protein
     unrelaxed_pdbs[model_name] = protein.to_pdb(unrelaxed_protein)
     unrelaxed_pdb_path = os.path.join(output_dir, f'unrelaxed_{model_name}.pdb')
     with open(unrelaxed_pdb_path, 'w') as f:
       f.write(unrelaxed_pdbs[model_name])
 
-    if amber_relaxer:
-      # Relax the prediction.
-      t_0 = time.time()
-      relaxed_pdb_str, _, violations = amber_relaxer.process(
-          prot=unrelaxed_protein)
-      relax_metrics[model_name] = {
-          'remaining_violations': violations,
-          'remaining_violations_count': sum(violations)
-      }
-      timings[f'relax_{model_name}'] = time.time() - t_0
-
-      relaxed_pdbs[model_name] = relaxed_pdb_str
-
-      # Save the relaxed PDB.
-      relaxed_output_path = os.path.join(
-          output_dir, f'relaxed_{model_name}.pdb')
-      with open(relaxed_output_path, 'w') as f:
-        f.write(relaxed_pdb_str)
-
-  # Rank by model confidence and write out relaxed PDBs in rank order.
-  ranked_order = []
-  for idx, (model_name, _) in enumerate(
-      sorted(ranking_confidences.items(), key=lambda x: x[1], reverse=True)):
-    ranked_order.append(model_name)
+  # Rank by model confidence.
+  ranked_order = [
+      model_name for model_name, confidence in
+      sorted(ranking_confidences.items(), key=lambda x: x[1], reverse=True)]
+
+  # Relax predictions.
+  if models_to_relax == ModelsToRelax.BEST:
+    to_relax = [ranked_order[0]]
+  elif models_to_relax == ModelsToRelax.ALL:
+    to_relax = ranked_order
+  elif models_to_relax == ModelsToRelax.NONE:
+    to_relax = []
+
+  for model_name in to_relax:
+    t_0 = time.time()
+    relaxed_pdb_str, _, violations = amber_relaxer.process(
+        prot=unrelaxed_proteins[model_name])
+    relax_metrics[model_name] = {
+        'remaining_violations': violations,
+        'remaining_violations_count': sum(violations)
+    }
+    timings[f'relax_{model_name}'] = time.time() - t_0
+
+    relaxed_pdbs[model_name] = relaxed_pdb_str
+
+    # Save the relaxed PDB.
+    relaxed_output_path = os.path.join(
+        output_dir, f'relaxed_{model_name}.pdb')
+    with open(relaxed_output_path, 'w') as f:
+      f.write(relaxed_pdb_str)
+
+  # Write out relaxed PDBs in rank order.
+  for idx, model_name in enumerate(ranked_order):
     ranked_output_path = os.path.join(output_dir, f'ranked_{idx}.pdb')
     with open(ranked_output_path, 'w') as f:
-      if amber_relaxer:
+      if model_name in relaxed_pdbs:
         f.write(relaxed_pdbs[model_name])
       else:
         f.write(unrelaxed_pdbs[model_name])
@@ -279,7 +317,7 @@ def predict_structure(
   timings_output_path = os.path.join(output_dir, 'timings.json')
   with open(timings_output_path, 'w') as f:
     f.write(json.dumps(timings, indent=4))
-  if amber_relaxer:
+  if models_to_relax != ModelsToRelax.NONE:
     relax_metrics_path = os.path.join(output_dir, 'relax_metrics.json')
     with open(relax_metrics_path, 'w') as f:
       f.write(json.dumps(relax_metrics, indent=4))
@@ -386,16 +424,13 @@ def main(argv):
   logging.info('Have %d models: %s', len(model_runners),
                list(model_runners.keys()))
 
-  if FLAGS.run_relax:
-    amber_relaxer = relax.AmberRelaxation(
-        max_iterations=RELAX_MAX_ITERATIONS,
-        tolerance=RELAX_ENERGY_TOLERANCE,
-        stiffness=RELAX_STIFFNESS,
-        exclude_residues=RELAX_EXCLUDE_RESIDUES,
-        max_outer_iterations=RELAX_MAX_OUTER_ITERATIONS,
-        use_gpu=FLAGS.use_gpu_relax)
-  else:
-    amber_relaxer = None
+  amber_relaxer = relax.AmberRelaxation(
+      max_iterations=RELAX_MAX_ITERATIONS,
+      tolerance=RELAX_ENERGY_TOLERANCE,
+      stiffness=RELAX_STIFFNESS,
+      exclude_residues=RELAX_EXCLUDE_RESIDUES,
+      max_outer_iterations=RELAX_MAX_OUTER_ITERATIONS,
+      use_gpu=FLAGS.use_gpu_relax)
 
   random_seed = FLAGS.random_seed
   if random_seed is None:
@@ -413,7 +448,8 @@ def main(argv):
         model_runners=model_runners,
         amber_relaxer=amber_relaxer,
         benchmark=FLAGS.benchmark,
-        random_seed=random_seed)
+        random_seed=random_seed,
+        models_to_relax=FLAGS.models_to_relax)
 
 
 if __name__ == '__main__':
diff --git a/run_alphafold_test.py b/run_alphafold_test.py
index b91189c9a05eabc5e6647a5d7ab3293c27c28c85..5e0d7699924c2c240ec8ca9af53e540413b5faaf 100644
--- a/run_alphafold_test.py
+++ b/run_alphafold_test.py
@@ -28,10 +28,10 @@ import numpy as np
 class RunAlphafoldTest(parameterized.TestCase):
 
   @parameterized.named_parameters(
-      ('relax', True),
-      ('no_relax', False),
+      ('relax', run_alphafold.ModelsToRelax.ALL),
+      ('no_relax', run_alphafold.ModelsToRelax.NONE),
   )
-  def test_end_to_end(self, do_relax):
+  def test_end_to_end(self, models_to_relax):
 
     data_pipeline_mock = mock.Mock()
     model_runner_mock = mock.Mock()
@@ -72,9 +72,11 @@ class RunAlphafoldTest(parameterized.TestCase):
         output_dir_base=out_dir,
         data_pipeline=data_pipeline_mock,
         model_runners={'model1': model_runner_mock},
-        amber_relaxer=amber_relaxer_mock if do_relax else None,
+        amber_relaxer=amber_relaxer_mock,
         benchmark=False,
-        random_seed=0)
+        random_seed=0,
+        models_to_relax=models_to_relax)
 
     base_output_files = os.listdir(out_dir)
     self.assertIn('target.fasta', base_output_files)
@@ -85,7 +87,7 @@ class RunAlphafoldTest(parameterized.TestCase):
         'features.pkl', 'msas', 'ranked_0.pdb', 'ranking_debug.json',
         'result_model1.pkl', 'timings.json', 'unrelaxed_model1.pdb',
     ]
-    if do_relax:
+    if models_to_relax == run_alphafold.ModelsToRelax.ALL:
       expected_files.extend(['relaxed_model1.pdb', 'relax_metrics.json'])
       with open(os.path.join(out_dir, 'test', 'relax_metrics.json')) as f:
         relax_metrics = json.loads(f.read())