diff --git a/README.md b/README.md
index 46c8d43eaa720a11289216dbbff8c51f66adaac9..04a58aaf6dd1dcf2fe5f68f20c123134c87445c2 100644
--- a/README.md
+++ b/README.md
@@ -170,7 +170,7 @@ parameters are made available under the terms of the CC BY 4.0 license. Please
 see the [Disclaimer](#license-and-disclaimer) below for more detail.
 
 The AlphaFold parameters are available from
-https://storage.googleapis.com/alphafold/alphafold_params_2021-10-27.tar, and
+https://storage.googleapis.com/alphafold/alphafold_params_2022-01-19.tar, and
 are downloaded as part of the `scripts/download_all_data.sh` script. This script
 will download parameters for:
 
diff --git a/alphafold/data/msa_pairing.py b/alphafold/data/msa_pairing.py
index ddd36ee1e3309ba2f520440de0334f3b09b24b31..3200bf59a79576363276b0f5a397529a8a9f78f7 100644
--- a/alphafold/data/msa_pairing.py
+++ b/alphafold/data/msa_pairing.py
@@ -16,7 +16,6 @@
 
 import collections
 import functools
-import re
 import string
 from typing import Any, Dict, Iterable, List, Sequence
 
@@ -58,14 +57,6 @@ TEMPLATE_FEATURES = ('template_aatype', 'template_all_atom_positions',
 CHAIN_FEATURES = ('num_alignments', 'seq_length')
 
 
-domain_name_pattern = re.compile(
-    r'''^(?P<pdb>[a-z\d]{4})
-    \{(?P<bioassembly>[\d+(\+\d+)?])\}
-    (?P<chain>[a-zA-Z\d]+)
-    \{(?P<transform_index>\d+)\}$
-    ''', re.VERBOSE)
-
-
 def create_paired_features(
     chains: Iterable[pipeline.FeatureDict],
     prokaryotic: bool,
@@ -618,6 +609,7 @@ def deduplicate_unpaired_sequences(
   msa_features = MSA_FEATURES
 
   for chain in np_chains:
+    # Convert the msa_all_seq numpy array to a tuple for hashing.
     sequence_set = set(tuple(s) for s in chain['msa_all_seq'])
     keep_rows = []
     # Go through unpaired MSA seqs and remove any rows that correspond to the
@@ -627,12 +619,6 @@ def deduplicate_unpaired_sequences(
         keep_rows.append(row_num)
     for feature_name in feature_names:
       if feature_name in msa_features:
-        if keep_rows:
-          chain[feature_name] = chain[feature_name][keep_rows]
-        else:
-          new_shape = list(chain[feature_name].shape)
-          new_shape[0] = 0
-          chain[feature_name] = np.zeros(new_shape,
-                                         dtype=chain[feature_name].dtype)
+        chain[feature_name] = chain[feature_name][keep_rows]
     chain['num_alignments'] = np.array(chain['msa'].shape[0], dtype=np.int32)
   return np_chains
diff --git a/alphafold/data/parsers.py b/alphafold/data/parsers.py
index cbb58f23af71052c1f01f576880969c0a1d911d5..0d865fab8c91f575e2c45c986170a0b65197534c 100644
--- a/alphafold/data/parsers.py
+++ b/alphafold/data/parsers.py
@@ -20,6 +20,9 @@ import re
 import string
 from typing import Dict, Iterable, List, Optional, Sequence, Tuple, Set
 
+# Internal import (7716).
+
+
 DeletionMatrix = Sequence[Sequence[int]]
 
 
@@ -271,24 +274,27 @@ def _keep_line(line: str, seqnames: Set[str]) -> bool:
     return seqname in seqnames
 
 
-def truncate_stockholm_msa(stockholm_msa: str, max_sequences: int) -> str:
-  """Truncates a stockholm file to a maximum number of sequences."""
+def truncate_stockholm_msa(stockholm_msa_path: str, max_sequences: int) -> str:
+  """Reads + truncates a Stockholm file while preventing excessive RAM usage."""
   seqnames = set()
   filtered_lines = []
-  for line in stockholm_msa.splitlines():
-    if line.strip() and not line.startswith(('#', '//')):
-      # Ignore blank lines, markup and end symbols - remainder are alignment
-      # sequence parts.
-      seqname = line.partition(' ')[0]
-      seqnames.add(seqname)
-      if len(seqnames) >= max_sequences:
-        break
-
-  for line in stockholm_msa.splitlines():
-    if _keep_line(line, seqnames):
-      filtered_lines.append(line)
 
-  return '\n'.join(filtered_lines) + '\n'
+  with open(stockholm_msa_path) as f:
+    for line in f:
+      if line.strip() and not line.startswith(('#', '//')):
+        # Ignore blank lines, markup and end symbols - remainder are alignment
+        # sequence parts.
+        seqname = line.partition(' ')[0]
+        seqnames.add(seqname)
+        if len(seqnames) >= max_sequences:
+          break
+
+    f.seek(0)
+    for line in f:
+      if _keep_line(line, seqnames):
+        filtered_lines.append(line)
+
+  return ''.join(filtered_lines)
 
 
 def remove_empty_columns_from_stockholm_msa(stockholm_msa: str) -> str:
diff --git a/alphafold/data/pipeline.py b/alphafold/data/pipeline.py
index 1f643dad87e14480e31719624d67fcf76626af4b..42c8c2ded755bde325d7eccbb33c4357d849a271 100644
--- a/alphafold/data/pipeline.py
+++ b/alphafold/data/pipeline.py
@@ -91,16 +91,25 @@ def make_msa_features(msas: Sequence[parsers.Msa]) -> FeatureDict:
 
 def run_msa_tool(msa_runner, input_fasta_path: str, msa_out_path: str,
                  msa_format: str, use_precomputed_msas: bool,
+                 max_sto_sequences: Optional[int] = None
                  ) -> Mapping[str, Any]:
   """Runs an MSA tool, checking if output already exists first."""
   if not use_precomputed_msas or not os.path.exists(msa_out_path):
-    result = msa_runner.query(input_fasta_path)[0]
+    if msa_format == 'sto' and max_sto_sequences is not None:
+      result = msa_runner.query(input_fasta_path, max_sto_sequences)[0]  # pytype: disable=wrong-arg-count
+    else:
+      result = msa_runner.query(input_fasta_path)[0]
     with open(msa_out_path, 'w') as f:
       f.write(result[msa_format])
   else:
     logging.warning('Reading MSA from file %s', msa_out_path)
-    with open(msa_out_path, 'r') as f:
-      result = {msa_format: f.read()}
+    if msa_format == 'sto' and max_sto_sequences is not None:
+      precomputed_msa = parsers.truncate_stockholm_msa(
+          msa_out_path, max_sto_sequences)
+      result = {'sto': precomputed_msa}
+    else:
+      with open(msa_out_path, 'r') as f:
+        result = {msa_format: f.read()}
   return result
 
 
@@ -157,18 +166,23 @@ class DataPipeline:
 
     uniref90_out_path = os.path.join(msa_output_dir, 'uniref90_hits.sto')
     jackhmmer_uniref90_result = run_msa_tool(
-        self.jackhmmer_uniref90_runner, input_fasta_path, uniref90_out_path,
-        'sto', self.use_precomputed_msas)
+        msa_runner=self.jackhmmer_uniref90_runner,
+        input_fasta_path=input_fasta_path,
+        msa_out_path=uniref90_out_path,
+        msa_format='sto',
+        use_precomputed_msas=self.use_precomputed_msas,
+        max_sto_sequences=self.uniref_max_hits)
     mgnify_out_path = os.path.join(msa_output_dir, 'mgnify_hits.sto')
     jackhmmer_mgnify_result = run_msa_tool(
-        self.jackhmmer_mgnify_runner, input_fasta_path, mgnify_out_path, 'sto',
-        self.use_precomputed_msas)
+        msa_runner=self.jackhmmer_mgnify_runner,
+        input_fasta_path=input_fasta_path,
+        msa_out_path=mgnify_out_path,
+        msa_format='sto',
+        use_precomputed_msas=self.use_precomputed_msas,
+        max_sto_sequences=self.mgnify_max_hits)
 
     msa_for_templates = jackhmmer_uniref90_result['sto']
-    msa_for_templates = parsers.truncate_stockholm_msa(
-        msa_for_templates, max_sequences=self.uniref_max_hits)
-    msa_for_templates = parsers.deduplicate_stockholm_msa(
-        msa_for_templates)
+    msa_for_templates = parsers.deduplicate_stockholm_msa(msa_for_templates)
     msa_for_templates = parsers.remove_empty_columns_from_stockholm_msa(
         msa_for_templates)
 
@@ -187,9 +201,7 @@ class DataPipeline:
       f.write(pdb_templates_result)
 
     uniref90_msa = parsers.parse_stockholm(jackhmmer_uniref90_result['sto'])
-    uniref90_msa = uniref90_msa.truncate(max_seqs=self.uniref_max_hits)
     mgnify_msa = parsers.parse_stockholm(jackhmmer_mgnify_result['sto'])
-    mgnify_msa = mgnify_msa.truncate(max_seqs=self.mgnify_max_hits)
 
     pdb_template_hits = self.template_searcher.get_template_hits(
         output_string=pdb_templates_result, input_sequence=input_sequence)
@@ -197,14 +209,20 @@ class DataPipeline:
     if self._use_small_bfd:
       bfd_out_path = os.path.join(msa_output_dir, 'small_bfd_hits.sto')
       jackhmmer_small_bfd_result = run_msa_tool(
-          self.jackhmmer_small_bfd_runner, input_fasta_path, bfd_out_path,
-          'sto', self.use_precomputed_msas)
+          msa_runner=self.jackhmmer_small_bfd_runner,
+          input_fasta_path=input_fasta_path,
+          msa_out_path=bfd_out_path,
+          msa_format='sto',
+          use_precomputed_msas=self.use_precomputed_msas)
       bfd_msa = parsers.parse_stockholm(jackhmmer_small_bfd_result['sto'])
     else:
       bfd_out_path = os.path.join(msa_output_dir, 'bfd_uniclust_hits.a3m')
       hhblits_bfd_uniclust_result = run_msa_tool(
-          self.hhblits_bfd_uniclust_runner, input_fasta_path, bfd_out_path,
-          'a3m', self.use_precomputed_msas)
+          msa_runner=self.hhblits_bfd_uniclust_runner,
+          input_fasta_path=input_fasta_path,
+          msa_out_path=bfd_out_path,
+          msa_format='a3m',
+          use_precomputed_msas=self.use_precomputed_msas)
       bfd_msa = parsers.parse_a3m(hhblits_bfd_uniclust_result['a3m'])
 
     templates_result = self.template_featurizer.get_templates(
diff --git a/alphafold/data/tools/jackhmmer.py b/alphafold/data/tools/jackhmmer.py
index cb03324f9ecd9882f10788bd92ebc44a707cea87..bea71844c41ad7d511c7c5bb9aff42c7b75838d4 100644
--- a/alphafold/data/tools/jackhmmer.py
+++ b/alphafold/data/tools/jackhmmer.py
@@ -23,6 +23,7 @@ from urllib import request
 
 from absl import logging
 
+from alphafold.data import parsers
 from alphafold.data.tools import utils
 # Internal import (7716).
 
@@ -86,8 +87,10 @@ class Jackhmmer:
     self.get_tblout = get_tblout
     self.streaming_callback = streaming_callback
 
-  def _query_chunk(self, input_fasta_path: str, database_path: str
-                   ) -> Mapping[str, Any]:
+  def _query_chunk(self,
+                   input_fasta_path: str,
+                   database_path: str,
+                   max_sequences: Optional[int] = None) -> Mapping[str, Any]:
     """Queries the database chunk using Jackhmmer."""
     with utils.tmpdir_manager() as query_tmp_dir:
       sto_path = os.path.join(query_tmp_dir, 'output.sto')
@@ -145,8 +148,11 @@ class Jackhmmer:
         with open(tblout_path) as f:
           tbl = f.read()
 
-      with open(sto_path) as f:
-        sto = f.read()
+      if max_sequences is None:
+        with open(sto_path) as f:
+          sto = f.read()
+      else:
+        sto = parsers.truncate_stockholm_msa(sto_path, max_sequences)
 
     raw_output = dict(
         sto=sto,
@@ -157,10 +163,14 @@ class Jackhmmer:
 
     return raw_output
 
-  def query(self, input_fasta_path: str) -> Sequence[Mapping[str, Any]]:
+  def query(self,
+            input_fasta_path: str,
+            max_sequences: Optional[int] = None) -> Sequence[Mapping[str, Any]]:
     """Queries the database using Jackhmmer."""
     if self.num_streamed_chunks is None:
-      return [self._query_chunk(input_fasta_path, self.database_path)]
+      single_chunk_result = self._query_chunk(
+          input_fasta_path, self.database_path, max_sequences)
+      return [single_chunk_result]
 
     db_basename = os.path.basename(self.database_path)
     db_remote_chunk = lambda db_idx: f'{self.database_path}.{db_idx}'
@@ -187,8 +197,8 @@ class Jackhmmer:
 
         # Run Jackhmmer with the chunk
         future.result()
-        chunked_output.append(
-            self._query_chunk(input_fasta_path, db_local_chunk(i)))
+        chunked_output.append(self._query_chunk(
+            input_fasta_path, db_local_chunk(i), max_sequences))
 
         # Remove the local copy of the chunk
         os.remove(db_local_chunk(i))
diff --git a/alphafold/model/folding_multimer.py b/alphafold/model/folding_multimer.py
index 6bdc6f16303e8277bf15a9cd486a370dd91d7c70..a30d7f466c76eeb82b346f7e131217ea2d1fabf3 100644
--- a/alphafold/model/folding_multimer.py
+++ b/alphafold/model/folding_multimer.py
@@ -186,7 +186,7 @@ class PointProjection(hk.Module):
 
 
 class InvariantPointAttention(hk.Module):
-  """Covariant attention module.
+  """Invariant point attention module.
 
   The high-level idea is that this attention module works over a set of points
   and associated orientations in 3D space (e.g. protein residues).
diff --git a/alphafold/relax/amber_minimize.py b/alphafold/relax/amber_minimize.py
index d3ff9f74218bdcabe0b57d8e0e749814b583edcd..ef1496942e5c422f5556c56b447d67354e9c1496 100644
--- a/alphafold/relax/amber_minimize.py
+++ b/alphafold/relax/amber_minimize.py
@@ -76,7 +76,8 @@ def _openmm_minimize(
     tolerance: unit.Unit,
     stiffness: unit.Unit,
     restraint_set: str,
-    exclude_residues: Sequence[int]):
+    exclude_residues: Sequence[int],
+    use_gpu: bool):
   """Minimize energy via openmm."""
 
   pdb_file = io.StringIO(pdb_str)
@@ -90,7 +91,7 @@ def _openmm_minimize(
     _add_restraints(system, pdb, stiffness, restraint_set, exclude_residues)
 
   integrator = openmm.LangevinIntegrator(0, 0.01, 0.0)
-  platform = openmm.Platform.getPlatformByName("CPU")
+  platform = openmm.Platform.getPlatformByName("CUDA" if use_gpu else "CPU")
   simulation = openmm_app.Simulation(
       pdb.topology, system, integrator, platform)
   simulation.context.setPositions(pdb.positions)
@@ -371,6 +372,7 @@ def _run_one_iteration(
     stiffness: float,
     restraint_set: str,
     max_attempts: int,
+    use_gpu: bool,
     exclude_residues: Optional[Collection[int]] = None):
   """Runs the minimization pipeline.
 
@@ -383,6 +385,7 @@ def _run_one_iteration(
       potential.
     restraint_set: The set of atoms to restrain.
     max_attempts: The maximum number of minimization attempts.
+    use_gpu: Whether to run on GPU.
     exclude_residues: An optional list of zero-indexed residues to exclude from
         restraints.
 
@@ -407,7 +410,8 @@ def _run_one_iteration(
           pdb_string, max_iterations=max_iterations,
           tolerance=tolerance, stiffness=stiffness,
           restraint_set=restraint_set,
-          exclude_residues=exclude_residues)
+          exclude_residues=exclude_residues,
+          use_gpu=use_gpu)
       minimized = True
     except Exception as e:  # pylint: disable=broad-except
       logging.info(e)
@@ -421,6 +425,7 @@ def _run_one_iteration(
 def run_pipeline(
     prot: protein.Protein,
     stiffness: float,
+    use_gpu: bool,
     max_outer_iterations: int = 1,
     place_hydrogens_every_iteration: bool = True,
     max_iterations: int = 0,
@@ -438,6 +443,7 @@ def run_pipeline(
   Args:
     prot: A protein to be relaxed.
     stiffness: kcal/mol A**2, the restraint stiffness.
+    use_gpu: Whether to run on GPU.
     max_outer_iterations: The maximum number of iterative minimization.
     place_hydrogens_every_iteration: Whether hydrogens are re-initialized
         prior to every minimization.
@@ -473,7 +479,8 @@ def run_pipeline(
         tolerance=tolerance,
         stiffness=stiffness,
         restraint_set=restraint_set,
-        max_attempts=max_attempts)
+        max_attempts=max_attempts,
+        use_gpu=use_gpu)
     prot = protein.from_pdb_string(ret["min_pdb"])
     if place_hydrogens_every_iteration:
       pdb_string = clean_protein(prot, checks=True)
diff --git a/alphafold/relax/amber_minimize_test.py b/alphafold/relax/amber_minimize_test.py
index b67cb911cbb07b505c7313eb4e7c13d518f162d9..dc7e6ea5a6d275c9a32a701574655eb05cd976dd 100644
--- a/alphafold/relax/amber_minimize_test.py
+++ b/alphafold/relax/amber_minimize_test.py
@@ -21,6 +21,8 @@ from alphafold.relax import amber_minimize
 import numpy as np
 # Internal import (7716).
 
+_USE_GPU = False
+
 
 def _load_test_protein(data_path):
   pdb_path = os.path.join(absltest.get_default_test_srcdir(), data_path)
@@ -35,7 +37,7 @@ class AmberMinimizeTest(absltest.TestCase):
         'alphafold/relax/testdata/multiple_disulfides_target.pdb'
         )
     ret = amber_minimize.run_pipeline(prot, max_iterations=10, max_attempts=1,
-                                      stiffness=10.)
+                                      stiffness=10., use_gpu=_USE_GPU)
     self.assertIn('opt_time', ret)
     self.assertIn('min_attempts', ret)
 
@@ -50,7 +52,8 @@ class AmberMinimizeTest(absltest.TestCase):
         ' residues. This protein contains at least one residue with no atoms.'):
       amber_minimize.run_pipeline(prot, max_iterations=10,
                                   stiffness=1.,
-                                  max_attempts=1)
+                                  max_attempts=1,
+                                  use_gpu=_USE_GPU)
 
   def test_iterative_relax(self):
     prot = _load_test_protein(
@@ -59,7 +62,7 @@ class AmberMinimizeTest(absltest.TestCase):
     violations = amber_minimize.get_violation_metrics(prot)
     self.assertGreater(violations['num_residue_violations'], 0)
     out = amber_minimize.run_pipeline(
-        prot=prot, max_outer_iterations=10, stiffness=10.)
+        prot=prot, max_outer_iterations=10, stiffness=10., use_gpu=_USE_GPU)
     self.assertLess(out['efinal'], out['einit'])
     self.assertEqual(0, out['num_residue_violations'])
 
diff --git a/alphafold/relax/relax.py b/alphafold/relax/relax.py
index f7af5856dc03b714d9a3566d0d393729c57c0522..bd6c9fd04b277679ece2b62a7c373f1127e6b1a6 100644
--- a/alphafold/relax/relax.py
+++ b/alphafold/relax/relax.py
@@ -29,7 +29,8 @@ class AmberRelaxation(object):
                tolerance: float,
                stiffness: float,
                exclude_residues: Sequence[int],
-               max_outer_iterations: int):
+               max_outer_iterations: int,
+               use_gpu: bool):
     """Initialize Amber Relaxer.
 
     Args:
@@ -44,6 +45,7 @@ class AmberRelaxation(object):
        CASP14. Use 20 so that >95% of the bad cases are relaxed. Relax finishes
        as soon as there are no violations, hence in most cases this causes no
        slowdown. In the worst case we do 20 outer iterations.
+      use_gpu: Whether to run on GPU.
     """
 
     self._max_iterations = max_iterations
@@ -51,6 +53,7 @@ class AmberRelaxation(object):
     self._stiffness = stiffness
     self._exclude_residues = exclude_residues
     self._max_outer_iterations = max_outer_iterations
+    self._use_gpu = use_gpu
 
   def process(self, *,
               prot: protein.Protein) -> Tuple[str, Dict[str, Any], np.ndarray]:
@@ -59,7 +62,8 @@ class AmberRelaxation(object):
         prot=prot, max_iterations=self._max_iterations,
         tolerance=self._tolerance, stiffness=self._stiffness,
         exclude_residues=self._exclude_residues,
-        max_outer_iterations=self._max_outer_iterations)
+        max_outer_iterations=self._max_outer_iterations,
+        use_gpu=self._use_gpu)
     min_pos = out['pos']
     start_pos = out['posinit']
     rmsd = np.sqrt(np.sum((start_pos - min_pos)**2) / start_pos.shape[0])
diff --git a/alphafold/relax/relax_test.py b/alphafold/relax/relax_test.py
index eba67ef2359e0a35696fcf4fd47404ec47cd3736..57e594e8a4f684e8bbab0bf645bad3776cec3d00 100644
--- a/alphafold/relax/relax_test.py
+++ b/alphafold/relax/relax_test.py
@@ -34,7 +34,8 @@ class RunAmberRelaxTest(absltest.TestCase):
         'tolerance': 2.39,
         'stiffness': 10.0,
         'exclude_residues': [],
-        'max_outer_iterations': 1}
+        'max_outer_iterations': 1,
+        'use_gpu': False}
 
   def test_process(self):
     amber_relax = relax.AmberRelaxation(**self.test_config)
diff --git a/docker/Dockerfile b/docker/Dockerfile
index c895832fd477541b8149ab66df862ff35e5a88ae..e6a78f6d3d3af451bb8f548e79e814b5ebddd8a6 100644
--- a/docker/Dockerfile
+++ b/docker/Dockerfile
@@ -72,7 +72,7 @@ RUN wget -q -P /app/alphafold/alphafold/common/ \
 # Install pip packages.
 RUN pip3 install --upgrade --no-cache-dir pip \
     && pip3 install --no-cache-dir -r /app/alphafold/requirements.txt \
-    && pip3 install --upgrade --no-cache-dir jax jaxlib==0.1.69+cuda${CUDA_JAXLIB/./} -f \
+    && pip3 install --upgrade --no-cache-dir jax==0.2.14 jaxlib==0.1.69+cuda${CUDA_JAXLIB/./} -f \
       https://storage.googleapis.com/jax-releases/jax_releases.html
 
 # Apply OpenMM patch.
@@ -110,6 +110,8 @@ COPY --from=build /opt/hhsuite /opt/hhsuite
 
 # hhsuite executable path hardcoded as /usr/bin in run_alphafold.py
 RUN ln -s /opt/hhsuite/bin/* /usr/bin
+# Add SETUID bit to the ldconfig binary so that non-root users can run it.
+RUN chmod u+s /sbin/ldconfig.real
 
 # We need to run `ldconfig` first to ensure GPUs are visible, due to some quirk
 # with Debian. See https://github.com/NVIDIA/nvidia-docker/issues/1399 for
diff --git a/docker/run_docker.py b/docker/run_docker.py
index 5d0f9beb0990a161736f2a66b821b4b44d80acd7..860303d8336bb1b5e43e222131f1e7c03e56c9f6 100644
--- a/docker/run_docker.py
+++ b/docker/run_docker.py
@@ -28,6 +28,14 @@ from docker import types
 
 flags.DEFINE_bool(
     'use_gpu', True, 'Enable NVIDIA runtime to run with GPUs.')
+flags.DEFINE_boolean(
+    'run_relax', True,
+    'Whether to run the final relaxation step on the predicted models. Turning '
+    'relax off might result in predictions with distracting stereochemical '
+    'violations but might help in case you are having issues with the '
+    'relaxation stage.')
+flags.DEFINE_bool(
+    'enable_gpu_relax', True, 'Run relax on GPU if GPU is enabled.')
 flags.DEFINE_string(
     'gpu_devices', 'all',
     'Comma separated list of devices to pass to NVIDIA_VISIBLE_DEVICES.')
@@ -72,8 +80,17 @@ flags.DEFINE_boolean(
     'for inferencing many proteins.')
 flags.DEFINE_boolean(
     'use_precomputed_msas', False,
-    'Whether to read MSAs that have been written to disk. WARNING: This will '
-    'not check if the sequence, database or configuration have changed.')
+    'Whether to read MSAs that have been written to disk instead of running '
+    'the MSA tools. The MSA files are looked up in the output directory, so it '
+    'must stay the same between multiple runs that are to reuse the MSAs. '
+    'WARNING: This will not check if the sequence, database or configuration '
+    'have changed.')
+flags.DEFINE_string(
+    'docker_user', f'{os.geteuid()}:{os.getegid()}',
+    'UID:GID with which to run the Docker container. The output directories '
+    'will be owned by this user:group. By default, this is the current user. '
+    'Valid options are: uid or uid:gid, non-numeric values are not recognised '
+    'by Docker unless that user has been created within the container.')
 
 FLAGS = flags.FLAGS
 
@@ -84,6 +101,9 @@ def _create_mount(mount_name: str, path: str) -> Tuple[types.Mount, str]:
   path = os.path.abspath(path)
   source_path = os.path.dirname(path)
   target_path = os.path.join(_ROOT_MOUNT_DIRECTORY, mount_name)
+  if not os.path.exists(source_path):
+    raise ValueError(f'Failed to find source directory "{source_path}" to '
+                     'mount in Docker container.')
   logging.info('Mounting %s -> %s', source_path, target_path)
   mount = types.Mount(target_path, source_path, type='bind', read_only=True)
   return mount, os.path.join(target_path, os.path.basename(path))
@@ -184,6 +204,8 @@ def main(argv):
   output_target_path = os.path.join(_ROOT_MOUNT_DIRECTORY, 'output')
   mounts.append(types.Mount(output_target_path, FLAGS.output_dir, type='bind'))
 
+  use_gpu_relax = FLAGS.enable_gpu_relax and FLAGS.use_gpu
+
   command_args.extend([
       f'--output_dir={output_target_path}',
       f'--max_template_date={FLAGS.max_template_date}',
@@ -191,6 +213,8 @@ def main(argv):
       f'--model_preset={FLAGS.model_preset}',
       f'--benchmark={FLAGS.benchmark}',
       f'--use_precomputed_msas={FLAGS.use_precomputed_msas}',
+      f'--run_relax={FLAGS.run_relax}',
+      f'--use_gpu_relax={use_gpu_relax}',
       '--logtostderr',
   ])
 
@@ -206,6 +230,7 @@ def main(argv):
       remove=True,
       detach=True,
       mounts=mounts,
+      user=FLAGS.docker_user,
       environment={
           'NVIDIA_VISIBLE_DEVICES': FLAGS.gpu_devices,
           # The following flags allow us to make predictions on proteins that
diff --git a/notebooks/AlphaFold.ipynb b/notebooks/AlphaFold.ipynb
index e9d936868b55932bafb18f6364ceb35862158587..44a4dd832f23d04565417303007e620381003fff 100644
--- a/notebooks/AlphaFold.ipynb
+++ b/notebooks/AlphaFold.ipynb
@@ -14,7 +14,7 @@
         "\n",
         "In comparison to AlphaFold v2.1.0, this Colab notebook uses **no templates (homologous structures)** and a selected portion of the [BFD database](https://bfd.mmseqs.com/). We have validated these changes on several thousand recent PDB structures. While accuracy will be near-identical to the full AlphaFold system on many targets, a small fraction have a large drop in accuracy due to the smaller MSA and lack of templates. For best reliability, we recommend instead using the [full open source AlphaFold](https://github.com/deepmind/alphafold/), or the [AlphaFold Protein Structure Database](https://alphafold.ebi.ac.uk/).\n",
         "\n",
-        "**This Colab has an small drop in average accuracy for multimers compared to local AlphaFold installation, for full multimer accuracy it is highly recommended to run [AlphaFold locally](https://github.com/deepmind/alphafold#running-alphafold).** Moreover, the AlphaFold-Multimer requires searching for MSA for every unique sequence in the complex, hence it is substantially slower. If your notebook times-out due to slow multimer MSA search, we recommend either using Colab Pro or running AlphaFold locally.\n",
+        "**This Colab has a small drop in average accuracy for multimers compared to local AlphaFold installation, for full multimer accuracy it is highly recommended to run [AlphaFold locally](https://github.com/deepmind/alphafold#running-alphafold).** Moreover, the AlphaFold-Multimer requires searching for MSA for every unique sequence in the complex, hence it is substantially slower. If your notebook times-out due to slow multimer MSA search, we recommend either using Colab Pro or running AlphaFold locally.\n",
         "\n",
         "Please note that this Colab notebook is provided as an early-access prototype and is not a finished product. It is provided for theoretical modelling only and caution should be exercised in its use. \n",
         "\n",
@@ -37,6 +37,17 @@
         "FAQ on how to interpret AlphaFold predictions are [here](https://alphafold.ebi.ac.uk/faq)."
       ]
     },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "id": "uC1dKAwk2eyl"
+      },
+      "source": [
+        "## Setup\n",
+        "\n",
+        "Start by running the 2 cells below to set up AlphaFold and all required software."
+      ]
+    },
     {
       "cell_type": "code",
       "execution_count": null,
@@ -46,7 +57,7 @@
       },
       "outputs": [],
       "source": [
-        "#@title Install third-party software\n",
+        "#@title 1. Install third-party software\n",
         "\n",
         "#@markdown Please execute this cell by pressing the _Play_ button \n",
         "#@markdown on the left to download and import third-party software \n",
@@ -114,7 +125,7 @@
       },
       "outputs": [],
       "source": [
-        "#@title Download AlphaFold\n",
+        "#@title 2. Download AlphaFold\n",
         "\n",
         "#@markdown Please execute this cell by pressing the *Play* button on \n",
         "#@markdown the left.\n",
@@ -201,7 +212,7 @@
       },
       "outputs": [],
       "source": [
-        "#@title Enter the amino acid sequence(s) to fold ⬇️\n",
+        "#@title 3. Enter the amino acid sequence(s) to fold ⬇️\n",
         "#@markdown Enter the amino acid sequence(s) to fold:\n",
         "#@markdown * If you enter only a single sequence, the monomer model will be used.\n",
         "#@markdown * If you enter multiple sequences, the multimer model will be used.\n",
@@ -247,7 +258,7 @@
       },
       "outputs": [],
       "source": [
-        "#@title Search against genetic databases\n",
+        "#@title 4. Search against genetic databases\n",
         "\n",
         "#@markdown Once this cell has been executed, you will see\n",
         "#@markdown statistics about the multiple sequence alignment \n",
@@ -275,7 +286,6 @@
         "\n",
         "from alphafold.data import feature_processing\n",
         "from alphafold.data import msa_pairing\n",
-        "from alphafold.data import parsers\n",
         "from alphafold.data import pipeline\n",
         "from alphafold.data import pipeline_multimer\n",
         "from alphafold.data.tools import jackhmmer\n",
@@ -455,7 +465,7 @@
       },
       "outputs": [],
       "source": [
-        "#@title Run AlphaFold and download prediction\n",
+        "#@title 5. Run AlphaFold and download prediction\n",
         "\n",
         "#@markdown Once this cell has been executed, a zip-archive with\n",
         "#@markdown the obtained prediction will be automatically downloaded\n",
@@ -542,7 +552,8 @@
         "        tolerance=2.39,\n",
         "        stiffness=10.0,\n",
         "        exclude_residues=[],\n",
-        "        max_outer_iterations=3)\n",
+        "        max_outer_iterations=3,\n",
+        "        use_gpu=True)\n",
         "    relaxed_pdb, _, _ = amber_relaxer.process(prot=unrelaxed_proteins[best_model_name])\n",
         "  else:\n",
         "    print('Warning: Running without the relaxation stage.')\n",
@@ -694,7 +705,7 @@
         "*   How do I get a predicted protein structure for my protein?\n",
         "    *   Click on the _Connect_ button on the top right to get started.\n",
         "    *   Paste the amino acid sequence of your protein (without any headers) into the “Enter the amino acid sequence to fold”.\n",
-        "    *   Run all cells in the Colab, either by running them individually (with the play button on the left side) or via _Runtime_ \u003e _Run all._\n",
+        "    *   Run all cells in the Colab, either by running them individually (with the play button on the left side) or via _Runtime_ \u003e _Run all._ Make sure you run all 5 cells in order.\n",
         "    *   The predicted protein structure will be downloaded once all cells have been executed. Note: This can take minutes to hours - see below.\n",
         "*   How long will this take?\n",
         "    *   Downloading the AlphaFold source code can take up to a few minutes.\n",
diff --git a/run_alphafold.py b/run_alphafold.py
index 33fae99c8caa732b505afb8f57cad288a18c0dca..83034b9f342a7d5ae553f32696c1f7c318d3c220 100644
--- a/run_alphafold.py
+++ b/run_alphafold.py
@@ -34,11 +34,11 @@ from alphafold.data import templates
 from alphafold.data.tools import hhsearch
 from alphafold.data.tools import hmmsearch
 from alphafold.model import config
+from alphafold.model import data
 from alphafold.model import model
 from alphafold.relax import relax
 import numpy as np
 
-from alphafold.model import data
 # Internal import (7716).
 
 logging.set_verbosity(logging.INFO)
@@ -114,8 +114,21 @@ flags.DEFINE_integer('random_seed', None, 'The random seed for the data '
                      'deterministic, because processes like GPU inference are '
                      'nondeterministic.')
 flags.DEFINE_boolean('use_precomputed_msas', False, 'Whether to read MSAs that '
-                     'have been written to disk. WARNING: This will not check '
-                     'if the sequence, database or configuration have changed.')
+                     'have been written to disk instead of running the MSA '
+                     'tools. The MSA files are looked up in the output '
+                     'directory, so it must stay the same between multiple '
+                     'runs that are to reuse the MSAs. WARNING: This will not '
+                     'check if the sequence, database or configuration have '
+                     'changed.')
+flags.DEFINE_boolean('run_relax', True, 'Whether to run the final relaxation '
+                     'step on the predicted models. Turning relax off might '
+                     'result in predictions with distracting stereochemical '
+                     'violations but might help in case you are having issues '
+                     'with the relaxation stage.')
+flags.DEFINE_boolean('use_gpu_relax', None, 'Whether to relax on GPU. '
+                     'Relax on GPU can be much faster than CPU, so it is '
+                     'recommended to enable if possible. GPUs must be available'
+                     ' if this setting is enabled.')
 
 FLAGS = flags.FLAGS
 
@@ -384,12 +397,16 @@ def main(argv):
   logging.info('Have %d models: %s', len(model_runners),
                list(model_runners.keys()))
 
-  amber_relaxer = relax.AmberRelaxation(
-      max_iterations=RELAX_MAX_ITERATIONS,
-      tolerance=RELAX_ENERGY_TOLERANCE,
-      stiffness=RELAX_STIFFNESS,
-      exclude_residues=RELAX_EXCLUDE_RESIDUES,
-      max_outer_iterations=RELAX_MAX_OUTER_ITERATIONS)
+  if FLAGS.run_relax:
+    amber_relaxer = relax.AmberRelaxation(
+        max_iterations=RELAX_MAX_ITERATIONS,
+        tolerance=RELAX_ENERGY_TOLERANCE,
+        stiffness=RELAX_STIFFNESS,
+        exclude_residues=RELAX_EXCLUDE_RESIDUES,
+        max_outer_iterations=RELAX_MAX_OUTER_ITERATIONS,
+        use_gpu=FLAGS.use_gpu_relax)
+  else:
+    amber_relaxer = None
 
   random_seed = FLAGS.random_seed
   if random_seed is None:
@@ -422,6 +439,7 @@ if __name__ == '__main__':
       'template_mmcif_dir',
       'max_template_date',
       'obsolete_pdbs_path',
+      'use_gpu_relax',
   ])
 
   app.run(main)