rc_foundry-0.1.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (180)
  1. foundry/__init__.py +57 -0
  2. foundry/callbacks/__init__.py +5 -0
  3. foundry/callbacks/callback.py +116 -0
  4. foundry/callbacks/health_logging.py +419 -0
  5. foundry/callbacks/metrics_logging.py +211 -0
  6. foundry/callbacks/timing_logging.py +67 -0
  7. foundry/callbacks/train_logging.py +278 -0
  8. foundry/common.py +108 -0
  9. foundry/constants.py +28 -0
  10. foundry/hydra/resolvers.py +77 -0
  11. foundry/inference_engines/base.py +235 -0
  12. foundry/inference_engines/checkpoint_registry.py +66 -0
  13. foundry/metrics/__init__.py +12 -0
  14. foundry/metrics/losses.py +30 -0
  15. foundry/metrics/metric.py +319 -0
  16. foundry/model/layers/blocks.py +47 -0
  17. foundry/testing/__init__.py +6 -0
  18. foundry/testing/fixtures.py +19 -0
  19. foundry/testing/pytest_hooks.py +15 -0
  20. foundry/trainers/fabric.py +923 -0
  21. foundry/training/EMA.py +67 -0
  22. foundry/training/checkpoint.py +61 -0
  23. foundry/training/schedulers.py +91 -0
  24. foundry/utils/alignment.py +86 -0
  25. foundry/utils/components.py +415 -0
  26. foundry/utils/datasets.py +405 -0
  27. foundry/utils/ddp.py +103 -0
  28. foundry/utils/instantiators.py +72 -0
  29. foundry/utils/logging.py +279 -0
  30. foundry/utils/rigid.py +1460 -0
  31. foundry/utils/rotation_augmentation.py +65 -0
  32. foundry/utils/squashfs.py +172 -0
  33. foundry/utils/torch.py +317 -0
  34. foundry/utils/weights.py +271 -0
  35. foundry/version.py +34 -0
  36. foundry_cli/__init__.py +3 -0
  37. foundry_cli/download_checkpoints.py +281 -0
  38. mpnn/__init__.py +1 -0
  39. mpnn/collate/feature_collator.py +265 -0
  40. mpnn/inference.py +53 -0
  41. mpnn/inference_engines/mpnn.py +549 -0
  42. mpnn/loss/nll_loss.py +122 -0
  43. mpnn/metrics/nll.py +369 -0
  44. mpnn/metrics/sequence_recovery.py +440 -0
  45. mpnn/model/layers/graph_embeddings.py +2372 -0
  46. mpnn/model/layers/message_passing.py +332 -0
  47. mpnn/model/layers/position_wise_feed_forward.py +44 -0
  48. mpnn/model/layers/positional_encoding.py +98 -0
  49. mpnn/model/mpnn.py +2632 -0
  50. mpnn/pipelines/mpnn.py +162 -0
  51. mpnn/samplers/samplers.py +167 -0
  52. mpnn/train.py +341 -0
  53. mpnn/trainers/mpnn.py +193 -0
  54. mpnn/transforms/feature_aggregation/mpnn.py +184 -0
  55. mpnn/transforms/feature_aggregation/polymer_ligand_interface.py +76 -0
  56. mpnn/transforms/feature_aggregation/token_encodings.py +132 -0
  57. mpnn/transforms/feature_aggregation/user_settings.py +347 -0
  58. mpnn/transforms/polymer_ligand_interface.py +164 -0
  59. mpnn/utils/inference.py +2397 -0
  60. mpnn/utils/probability.py +37 -0
  61. mpnn/utils/weights.py +309 -0
  62. rc_foundry-0.1.1.dist-info/METADATA +239 -0
  63. rc_foundry-0.1.1.dist-info/RECORD +180 -0
  64. rc_foundry-0.1.1.dist-info/WHEEL +4 -0
  65. rc_foundry-0.1.1.dist-info/entry_points.txt +5 -0
  66. rc_foundry-0.1.1.dist-info/licenses/LICENSE.md +28 -0
  67. rf3/__init__.py +3 -0
  68. rf3/_version.py +33 -0
  69. rf3/alignment.py +79 -0
  70. rf3/callbacks/dump_validation_structures.py +101 -0
  71. rf3/callbacks/metrics_logging.py +324 -0
  72. rf3/chemical.py +1529 -0
  73. rf3/cli.py +77 -0
  74. rf3/data/cyclic_transform.py +78 -0
  75. rf3/data/extra_xforms.py +36 -0
  76. rf3/data/ground_truth_template.py +463 -0
  77. rf3/data/paired_msa.py +206 -0
  78. rf3/data/pipeline_utils.py +128 -0
  79. rf3/data/pipelines.py +558 -0
  80. rf3/diffusion_samplers/inference_sampler.py +222 -0
  81. rf3/inference.py +65 -0
  82. rf3/inference_engines/__init__.py +5 -0
  83. rf3/inference_engines/rf3.py +735 -0
  84. rf3/kinematics.py +354 -0
  85. rf3/loss/af3_confidence_loss.py +515 -0
  86. rf3/loss/af3_losses.py +655 -0
  87. rf3/loss/loss.py +179 -0
  88. rf3/metrics/chiral.py +179 -0
  89. rf3/metrics/clashing_chains.py +68 -0
  90. rf3/metrics/distogram.py +421 -0
  91. rf3/metrics/lddt.py +523 -0
  92. rf3/metrics/metadata.py +43 -0
  93. rf3/metrics/metric_utils.py +192 -0
  94. rf3/metrics/predicted_error.py +134 -0
  95. rf3/metrics/rasa.py +108 -0
  96. rf3/metrics/selected_distances.py +91 -0
  97. rf3/model/RF3.py +527 -0
  98. rf3/model/RF3_blocks.py +92 -0
  99. rf3/model/RF3_structure.py +303 -0
  100. rf3/model/layers/af3_auxiliary_heads.py +255 -0
  101. rf3/model/layers/af3_diffusion_transformer.py +544 -0
  102. rf3/model/layers/attention.py +313 -0
  103. rf3/model/layers/layer_utils.py +127 -0
  104. rf3/model/layers/mlff.py +118 -0
  105. rf3/model/layers/outer_product.py +59 -0
  106. rf3/model/layers/pairformer_layers.py +783 -0
  107. rf3/model/layers/structure_bias.py +56 -0
  108. rf3/scoring.py +1787 -0
  109. rf3/symmetry/resolve.py +284 -0
  110. rf3/train.py +194 -0
  111. rf3/trainers/rf3.py +570 -0
  112. rf3/util_module.py +47 -0
  113. rf3/utils/frames.py +109 -0
  114. rf3/utils/inference.py +665 -0
  115. rf3/utils/io.py +198 -0
  116. rf3/utils/loss.py +72 -0
  117. rf3/utils/predict_and_score.py +165 -0
  118. rf3/utils/predicted_error.py +673 -0
  119. rf3/utils/recycling.py +42 -0
  120. rf3/validate.py +140 -0
  121. rfd3/.gitignore +7 -0
  122. rfd3/Makefile +76 -0
  123. rfd3/__init__.py +12 -0
  124. rfd3/callbacks.py +66 -0
  125. rfd3/cli.py +41 -0
  126. rfd3/constants.py +212 -0
  127. rfd3/engine.py +543 -0
  128. rfd3/inference/datasets.py +193 -0
  129. rfd3/inference/input_parsing.py +1123 -0
  130. rfd3/inference/legacy_input_parsing.py +717 -0
  131. rfd3/inference/parsing.py +165 -0
  132. rfd3/inference/symmetry/atom_array.py +298 -0
  133. rfd3/inference/symmetry/checks.py +241 -0
  134. rfd3/inference/symmetry/contigs.py +63 -0
  135. rfd3/inference/symmetry/frames.py +355 -0
  136. rfd3/inference/symmetry/symmetry_utils.py +398 -0
  137. rfd3/metrics/design_metrics.py +465 -0
  138. rfd3/metrics/hbonds_hbplus_metrics.py +308 -0
  139. rfd3/metrics/hbonds_metrics.py +389 -0
  140. rfd3/metrics/losses.py +325 -0
  141. rfd3/metrics/metrics_utils.py +118 -0
  142. rfd3/metrics/sidechain_metrics.py +349 -0
  143. rfd3/model/RFD3.py +105 -0
  144. rfd3/model/RFD3_diffusion_module.py +387 -0
  145. rfd3/model/cfg_utils.py +81 -0
  146. rfd3/model/inference_sampler.py +635 -0
  147. rfd3/model/layers/attention.py +577 -0
  148. rfd3/model/layers/block_utils.py +580 -0
  149. rfd3/model/layers/blocks.py +777 -0
  150. rfd3/model/layers/chunked_pairwise.py +377 -0
  151. rfd3/model/layers/encoders.py +417 -0
  152. rfd3/model/layers/layer_utils.py +197 -0
  153. rfd3/model/layers/pairformer_layers.py +128 -0
  154. rfd3/run_inference.py +45 -0
  155. rfd3/testing/debug.py +139 -0
  156. rfd3/testing/debug_utils.py +73 -0
  157. rfd3/testing/testing_utils.py +356 -0
  158. rfd3/train.py +194 -0
  159. rfd3/trainer/dump_validation_structures.py +154 -0
  160. rfd3/trainer/fabric_trainer.py +923 -0
  161. rfd3/trainer/recycling.py +42 -0
  162. rfd3/trainer/rfd3.py +485 -0
  163. rfd3/trainer/trainer_utils.py +502 -0
  164. rfd3/transforms/conditioning_base.py +508 -0
  165. rfd3/transforms/conditioning_utils.py +200 -0
  166. rfd3/transforms/design_transforms.py +807 -0
  167. rfd3/transforms/dna_crop.py +523 -0
  168. rfd3/transforms/hbonds.py +407 -0
  169. rfd3/transforms/hbonds_hbplus.py +246 -0
  170. rfd3/transforms/ncaa_transforms.py +153 -0
  171. rfd3/transforms/pipelines.py +632 -0
  172. rfd3/transforms/ppi_transforms.py +541 -0
  173. rfd3/transforms/rasa.py +116 -0
  174. rfd3/transforms/symmetry.py +76 -0
  175. rfd3/transforms/training_conditions.py +552 -0
  176. rfd3/transforms/util_transforms.py +498 -0
  177. rfd3/transforms/virtual_atoms.py +305 -0
  178. rfd3/utils/inference.py +648 -0
  179. rfd3/utils/io.py +245 -0
  180. rfd3/utils/vizualize.py +276 -0
rf3/cli.py ADDED
@@ -0,0 +1,77 @@
+ from pathlib import Path
+
+ import typer
+ from hydra import compose, initialize_config_dir
+
+ app = typer.Typer(pretty_exceptions_enable=False)
+
+
+ @app.command(
+     context_settings={"allow_extra_args": True, "ignore_unknown_options": True}
+ )
+ def fold(
+     ctx: typer.Context,
+     verbose: bool = typer.Option(
+         False, "--verbose", "-v", help="Show detailed logging output"
+     ),
+ ):
+     """Run structure prediction using hydra config overrides or a simple input file."""
+     # Configure logging BEFORE any heavy imports
+     if not verbose:
+         from foundry.utils.logging import configure_minimal_inference_logging
+
+         configure_minimal_inference_logging()
+
+     # Find the RF3 configs directory relative to this file
+     # This file is at: models/rf3/src/rf3/cli.py
+     # Configs are at:  models/rf3/configs/
+     rf3_package_dir = Path(__file__).parent.parent.parent  # Go up to models/rf3/
+     config_path = str(rf3_package_dir / "configs")
+
+     # Get all arguments
+     args = ctx.params.get("args", []) + ctx.args
+
+     # Parse arguments
+     hydra_overrides = []
+
+     if len(args) == 1 and "=" not in args[0]:
+         # Old style: a single positional argument is assumed to be the inputs
+         hydra_overrides.append(f"inputs={args[0]}")
+     else:
+         # New style: all arguments are hydra overrides
+         hydra_overrides.extend(args)
+
+     # Ensure we have at least a default inference_engine if not specified
+     has_inference_engine = any(
+         arg.startswith("inference_engine=") for arg in hydra_overrides
+     )
+     if not has_inference_engine:
+         hydra_overrides.append("inference_engine=rf3")
+
+     # Handle the verbose flag
+     if verbose:
+         hydra_overrides.append("verbose=true")
+
+     with initialize_config_dir(config_dir=config_path, version_base="1.3"):
+         cfg = compose(config_name="inference", overrides=hydra_overrides)
+         # Lazy import to avoid loading heavy dependencies at CLI startup
+         from rf3.inference import run_inference
+
+         run_inference(cfg)
+
+
+ @app.command(
+     context_settings={"allow_extra_args": True, "ignore_unknown_options": True}
+ )
+ def predict(
+     ctx: typer.Context,
+     verbose: bool = typer.Option(
+         False, "--verbose", "-v", help="Show detailed logging output"
+     ),
+ ):
+     """Alias for the fold command."""
+     fold(ctx, verbose=verbose)
+
+
+ if __name__ == "__main__":
+     app()
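For context, the fold command's override handling follows a simple convention: a single positional argument without "=" is treated as the inputs file, and anything else is passed through to hydra verbatim. A minimal sketch of that convention in isolation; the helper name, the file path, and the assumption that the published entry point is named rf3 are illustrative, not part of the package:

    def to_hydra_overrides(args: list[str], verbose: bool = False) -> list[str]:
        """Mirror of `rf3 fold`'s argument handling (illustrative only)."""
        if len(args) == 1 and "=" not in args[0]:
            overrides = [f"inputs={args[0]}"]  # old style: positional inputs file
        else:
            overrides = list(args)  # new style: raw hydra overrides
        if not any(a.startswith("inference_engine=") for a in overrides):
            overrides.append("inference_engine=rf3")
        if verbose:
            overrides.append("verbose=true")
        return overrides

    # e.g. `rf3 fold my_complex.yaml` would compose with:
    assert to_hydra_overrides(["my_complex.yaml"]) == [
        "inputs=my_complex.yaml",
        "inference_engine=rf3",
    ]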
rf3/data/cyclic_transform.py ADDED
@@ -0,0 +1,78 @@
+ import logging
+
+ import numpy as np
+ from atomworks.ml.transforms.base import Transform
+
+ logger = logging.getLogger("atomworks.ml")
+
+
+ class AddCyclicBonds(Transform):
+     """
+     Transform that detects and adds cyclic (head-to-tail) peptide bonds in protein chains.
+     This transform analyzes the atom-level structure of each chain in the input data to identify
+     cyclic bonds between the N-terminal nitrogen of the first residue and the C-terminal carbon
+     of the last residue, based on spatial proximity. If such a bond is detected (0.5 Å < distance < 1.5 Å),
+     it updates the token-level bond features to reflect the presence of the cyclic bond and flags that
+     chain as being cyclic.
+
+     Requirements:
+         - Must be applied after the "AddAF3TokenBondFeatures", "EncodeAF3TokenLevelFeatures", and "AddGlobalTokenIdAnnotation" transforms.
+     """
+
+     # Must run after AddAF3TokenBondFeatures, because we want to include poly-poly bonds, which that transform does not.
+     requires_previous_transforms = [
+         "AddAF3TokenBondFeatures",
+         "EncodeAF3TokenLevelFeatures",
+         "AddGlobalTokenIdAnnotation",
+     ]
+
+     def forward(self, data: dict) -> dict:
+         atom_array = data["atom_array"]
+         token_bonds = data["feats"]["token_bonds"]
+         asym_ids = data["feats"]["asym_id"]
+
+         cyclic_token_bonds = np.zeros_like(token_bonds, dtype=bool)
+         cyclic_asym_ids = set()
+
+         # Check each chain for a cyclic bond
+         for chain in np.unique(atom_array.chain_id):
+             chain_mask = atom_array.chain_id == chain
+             if not np.any(chain_mask):
+                 continue
+             chain_array = atom_array[chain_mask]
+             residue_ids = np.unique(chain_array.res_id)
+             if len(residue_ids) < 2:
+                 continue
+             first_residue = residue_ids[0]
+             last_residue = residue_ids[-1]
+             first_nitrogen_mask = (chain_array.res_id == first_residue) & (
+                 chain_array.atom_name == "N"
+             )
+             last_carbon_mask = (chain_array.res_id == last_residue) & (
+                 chain_array.atom_name == "C"
+             )
+             if first_nitrogen_mask.sum() == 1 and last_carbon_mask.sum() == 1:
+                 first_nitrogen = chain_array[first_nitrogen_mask]
+                 last_carbon = chain_array[last_carbon_mask]
+                 distance = np.linalg.norm(
+                     first_nitrogen.coord[0] - last_carbon.coord[0]
+                 )
+                 if 0.5 < distance < 1.5:  # roughly peptide-bond length
+                     cyclic_token_bonds[
+                         first_nitrogen.token_id[0], last_carbon.token_id[0]
+                     ] = True
+                     cyclic_token_bonds[
+                         last_carbon.token_id[0], first_nitrogen.token_id[0]
+                     ] = True
+                     logger.warning(
+                         f"Detected cyclic bond in chain {chain} of {data['example_id']} between residues {first_residue} and {last_residue} with distance {distance:.2f} Å"
+                     )
+
+         cyclic_asym_ids.update(
+             asym_ids[np.where(np.any(cyclic_token_bonds, axis=0))[0]].tolist()
+         )
+         token_bonds |= cyclic_token_bonds
+         data["feats"]["token_bonds"] = token_bonds
+         data["feats"]["cyclic_asym_ids"] = list(cyclic_asym_ids)
+
+         return data
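The detection criterion above reduces to a single Euclidean distance test between the first residue's backbone N and the last residue's backbone C. A self-contained numpy sketch of the same window, with synthetic coordinates:

    import numpy as np

    # A head-to-tail peptide bond is ~1.33 Å long; the transform accepts 0.5-1.5 Å.
    first_nitrogen_coord = np.array([0.00, 0.00, 0.00])
    last_carbon_coord = np.array([1.33, 0.00, 0.00])

    distance = np.linalg.norm(first_nitrogen_coord - last_carbon_coord)
    is_cyclic = 0.5 < distance < 1.5
    print(f"{distance:.2f} A, cyclic: {is_cyclic}")  # 1.33 A, cyclic: True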
rf3/data/extra_xforms.py ADDED
@@ -0,0 +1,36 @@
+ import torch
+ from atomworks.ml.transforms._checks import (
+     check_contains_keys,
+ )
+ from atomworks.ml.transforms.base import Transform
+
+
+ class CheckForNaNsInInputs(Transform):
+     """
+     Checks that the network inputs contain no NaNs.
+
+     During inference, the coordinates to be noised are first replaced with
+     pure Gaussian noise (a short-term hack; see the TODO below).
+
+     Raises:
+         AssertionError: If NaNs are found in the coordinates or the noise.
+     """
+
+     def check_input(self, data: dict):
+         check_contains_keys(data, ["coord_atom_lvl_to_be_noised"])
+         check_contains_keys(data, ["noise"])
+
+     def forward(self, data: dict) -> dict:
+         # During inference, replace the coordinates with pure noise
+         # TODO: Move elsewhere in the pipeline; placing it here is a short-term hack
+         if data.get("is_inference", False):
+             data["coord_atom_lvl_to_be_noised"] = torch.randn_like(
+                 data["coord_atom_lvl_to_be_noised"]
+             )
+
+         assert not torch.isnan(
+             data["coord_atom_lvl_to_be_noised"]
+         ).any(), "NaN found in network input"
+         assert not torch.isnan(data["noise"]).any(), "NaN found in network noise"
+
+         return data
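To see the inference-time behavior in isolation, here is a minimal sketch of the forward logic above applied to a toy dict; the shapes and values are illustrative only:

    import torch

    data = {
        "coord_atom_lvl_to_be_noised": torch.zeros(5, 3),
        "noise": torch.randn(5, 3),
        "is_inference": True,
    }

    # During inference the transform discards the input coordinates and
    # replaces them with pure Gaussian noise before the NaN checks run.
    if data.get("is_inference", False):
        data["coord_atom_lvl_to_be_noised"] = torch.randn_like(
            data["coord_atom_lvl_to_be_noised"]
        )

    assert not torch.isnan(data["coord_atom_lvl_to_be_noised"]).any()
    assert not torch.isnan(data["noise"]).any()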
rf3/data/ground_truth_template.py ADDED
@@ -0,0 +1,463 @@
+ import logging
+ from dataclasses import dataclass
+
+ import numpy as np
+ import torch
+ from atomworks.enums import ChainType
+ from atomworks.ml.transforms._checks import (
+     check_atom_array_annotation,
+     check_contains_keys,
+     check_is_instance,
+ )
+ from atomworks.ml.transforms.atomize import AtomizeByCCDName
+ from atomworks.ml.transforms.base import Transform
+ from atomworks.ml.utils.token import (
+     get_af3_token_center_masks,
+     get_token_starts,
+ )
+ from beartype.typing import Any, Callable, Final, Sequence
+ from biotite.structure import AtomArray
+ from jaxtyping import Bool, Float, Shaped
+ from torch import Tensor
+
+ from foundry.utils.torch import assert_no_nans
+
+ logger = logging.getLogger(__name__)
+
+ MaskingFunction = Callable[[AtomArray], Bool[Shaped, "n"]]
+ """A function that takes in an AtomArray and returns a boolean mask."""
+
+ NoiseScaleSampler = Callable[[Sequence[int]], Float[Tensor, "..."] | float]
+ """
+ A noise scale sampler that, when given a shape tuple, returns a sample of
+ noise scales of the appropriate shape.
+
+ Examples:
+     - partial(torch.normal, mean=0.0, std=1.0)
+     - af3_noise_scale_distribution
+     - af3_noise_scale_distribution_wrapped
+ """
+
+
+ @dataclass
+ class TokenGroupNoiseScaleSampler:
+     mask_and_sampling_fns: tuple[tuple[MaskingFunction, NoiseScaleSampler], ...]
+
+     def __call__(self, atom_array: AtomArray) -> Tensor:
+         # ... determine token centers
+         token_center_mask = get_af3_token_center_masks(atom_array)  # [n_atom] (bool)
+         token_array = atom_array[token_center_mask]  # [n_token] (AtomArray)
+
+         # ... sample a noise scale for each token group
+         noise_scales = torch.full(
+             size=(len(token_array),),
+             fill_value=float("nan"),
+             dtype=torch.float32,
+         )
+         for mask_fn, sampling_fn in self.mask_and_sampling_fns:
+             mask = mask_fn(token_array)
+             n_tokens_to_sample = mask.sum()
+             if n_tokens_to_sample > 0:
+                 # ... all tokens in that group receive the same noise scale
+                 noise_scales[mask] = sampling_fn((1,))
+
+         return noise_scales
+
+
+ DEFAULT_DISTOGRAM_BINS: Final[Float[Tensor, "63"]] = torch.concat(
+     (
+         torch.arange(1.0, 4.0, 0.1, device="cpu"),
+         torch.arange(4.0, 20.5, 0.5, device="cpu"),
+     )
+ )
+ """
+ Default bins for discretizing distances in the distogram (in Angstrom).
+     - 0.1 A resolution from 1.0 - 4.0 A
+     - 0.5 A resolution from 4.0 - 20.0 A
+ Total number of bins: 64 (i.e. 63 bin boundaries above).
+ """
+
+
+ def wrap_probability_distribution(
+     samples: Float[Tensor, "..."],
+     lower: float = float("-inf"),
+     upper: float = float("inf"),
+ ) -> Float[Tensor, "..."]:
+     """
+     Wrap a probability distribution around lower and upper bounds to create
+     samples from the corresponding wrapped probability distribution.
+
+     Args:
+         - samples: Input tensor of samples to wrap
+         - lower: Lower bound for wrapping (inclusive, unless infinite)
+         - upper: Upper bound for wrapping (inclusive, unless infinite)
+
+     Returns:
+         - samples: Samples wrapped around the lower and upper bounds within
+           the interval ]lower, upper[.
+
+     Reference:
+         - https://en.wikipedia.org/wiki/Wrapped_distribution
+     """
+     if lower > float("-inf") and upper < float("inf"):
+         return ((samples - lower) % (upper - lower)) + lower
+     elif lower > float("-inf"):
+         return lower + (samples - lower).abs()
+     elif upper < float("inf"):
+         return upper - (samples - upper).abs()
+     return samples
+
+
+ def wrapped_normal(
+     mean: float,
+     std: float,
+     size: Sequence[int],
+     *,
+     lower: float = float("-inf"),
+     upper: float = float("inf"),
+     **normal_kwargs,
+ ) -> Float[Tensor, "..."]:
+     """Sample from a wrapped normal distribution."""
+     samples = torch.normal(mean=mean, std=std, size=size, **normal_kwargs)
+     return wrap_probability_distribution(samples, lower, upper)
+
+
+ def af3_noise_scale_to_noise_level(
+     noise_scale: Tensor | float, eps: float = 1e-8
+ ) -> Tensor:
+     """Convert AlphaFold3 noise scale (t^) in Angstroms to noise level (t).
+
+     This function converts from a noise scale in Angstroms (t^) to the
+     corresponding standard normal noise level (t) using the formula:
+         t = (log(t^ / 16.0) + 1.2) / 1.5
+
+     Args:
+         - noise_scale (Tensor): The noise scale (t^) in Angstroms,
+           representing the standard deviation of positional noise.
+
+     Returns:
+         - noise_level (Tensor): The corresponding noise level (t) as a
+           standard normal random variable.
+
+     Notes:
+         - We use the term 'noise-level' to refer to the standard normal random
+           variable `t` in the AF3 paper, and 'noise-scale' to refer to the
+           variable `t^`, which denotes the noise scale in Angstrom. This is
+           the inverse operation of af3_noise_level_to_noise_scale().
+         - To avoid taking the log of zero, we clamp the noise scale to a
+           small constant (`eps`) before dividing by 16.0.
+     """
+     noise_scale_tensor = torch.as_tensor(noise_scale)
+     return (torch.log(torch.clamp(noise_scale_tensor, min=eps) / 16.0) + 1.2) / 1.5
+
+
+ def af3_noise_level_to_noise_scale(noise_level: Tensor | float) -> Tensor:
+     """Convert AlphaFold3 noise level (t) to noise scale (t^) in Angstroms.
+
+     This function converts from a standard normal noise level (t) to the
+     corresponding noise scale in Angstroms (t^) using the formula:
+         t^ = 16.0 * exp(1.5 * t - 1.2), i.e. t^ ~ log-N(log(16.0) - 1.2, 1.5^2)
+
+     Args:
+         - noise_level (Tensor): The noise level (t) as a standard normal
+           random variable, sampled from N(0, 1).
+
+     Returns:
+         - noise_scale (Tensor): The corresponding noise scale (t^) in Angstroms,
+           representing the standard deviation of positional noise to apply.
+
+     Note:
+         This is the inverse operation of af3_noise_scale_to_noise_level(). The
+         transformation is designed to convert between a normal distribution and
+         a log-normal distribution with specific parameters chosen by AlphaFold3.
+     """
+     return 16.0 * torch.exp((torch.as_tensor(noise_level) * 1.5) - 1.2)
+
+
+ def af3_noise_scale_distribution(size: Sequence[int], **kwargs) -> Tensor:
+     """
+     The log-normal noise-scale distribution used in AF3 (in Angstrom).
+
+     t^ = 16.0 * exp(1.5 * t - 1.2),
+     where:
+         - t = noise-level ~ N(0, 1)
+         - t^ = noise-scale ~ log-N(log(16.0) - 1.2, 1.5^2)
+     """
+     noise_level = torch.normal(mean=0.0, std=1.0, size=size, **kwargs)
+     return af3_noise_level_to_noise_scale(noise_level)
+
+
+ def af3_noise_scale_distribution_wrapped(
+     size: Sequence[int],
+     *,
+     lower_noise_level: float = float("-inf"),
+     upper_noise_level: float = float("inf"),
+     **kwargs,
+ ) -> Tensor:
+     """
+     The noise-scale distribution used in AF3 (in Angstrom), wrapped around the
+     lower and upper bounds.
+
+     WARNING: The lower/upper bounds here correspond to the noise-level (t),
+     not the noise-scale (t^); wrapping happens in noise-level space before
+     converting to the corresponding log-normal noise-scale distribution (t^)
+     in Angstroms.
+     """
+     noise_level = wrapped_normal(
+         mean=0.0,
+         std=1.0,
+         size=size,
+         lower=lower_noise_level,
+         upper=upper_noise_level,
+         **kwargs,
+     )
+     return af3_noise_level_to_noise_scale(noise_level)
+
+
+ def featurize_noised_ground_truth_as_template_distogram(
+     atom_array: AtomArray,
+     *,
+     noise_scale: Float[Tensor, "n_token"] | float,
+     distogram_bins: Float[Tensor, "n_bin_edges"],
+     allowed_chain_types: list[ChainType],
+     is_unconditional: bool = True,
+     p_condition_per_token: float = 0.0,
+     p_provide_inter_molecule_distances: float = 0.0,
+     existing_annotation_to_check: str = "is_input_file_templated",
+ ) -> dict[str, Tensor]:
+     """Featurize noised ground truth as a template distogram for conditioning.
+
+     Used to leak ground-truth information into the model.
+
+     Args:
+         atom_array (AtomArray): The input atom array. Must have 'chain_type', 'occupancy', and 'molecule_id' annotations.
+         noise_scale (Tensor | float): Standard deviation of the noise to add to the ground truth.
+             Different tokens may have different noise scales (e.g. one noise scale for
+             side-chains, one for ligand atoms, and one for backbone atoms).
+             Units are in Angstrom. If given as a tensor, must be of shape [n_token] (float).
+         distogram_bins (Tensor): Bins for discretizing distances in the distogram (in Angstrom).
+             Shape: [n_bin].
+         allowed_chain_types (list): List of allowed chain types. Only token pairs where BOTH
+             tokens have a chain type in this list will have a distogram condition.
+         is_unconditional (bool): Whether we are sampling unconditionally.
+             See Classifier-Free Diffusion Guidance (Ho et al., 2022) for details.
+             Default: True (no conditioning).
+         p_condition_per_token (float, optional):
+             Probability of conditioning each eligible token. Default: 0.0 (no conditioning).
+         p_provide_inter_molecule_distances (float, optional):
+             Probability of providing inter-molecule (inter-chain) distances. Default: 0.0 (mask all inter-molecule pairs).
+         existing_annotation_to_check (str):
+             If this annotation exists in the AtomArray, we ALWAYS template where it is True.
+             Useful for inference.
+
+     Returns:
+         dict[str, Tensor]:
+             Dictionary with the following keys:
+             - 'distogram_condition_noise_scale': Float[Tensor, "n_token"]. Noise scale for each token (0 for unconditioned tokens).
+             - 'has_distogram_condition': Bool[Tensor, "n_token n_token"]. Mask indicating which token pairs are conditioned.
+             - 'distogram_condition': Float[Tensor, "n_token n_token n_bins"]. One-hot encoded distogram for each token pair.
+
+     NOTE:
+         - We use the center atom for each token (CA for proteins, C1' for nucleic acids) in the token-level conditioning.
+         - If a token is not conditioned, its noise scale is set to 0 and its pairwise distances are masked.
+     """
+     MASK_VALUE = float("nan")
+
+     # Ensure the noise scale is a tensor (a scalar is broadcast to all tokens)
+     noise_scale = torch.as_tensor(noise_scale, dtype=torch.float32)
+
+     # Get full atom array token starts (useful for going from atom-level -> token-level annotations)
+     _a_token_starts = get_token_starts(atom_array)  # [n_token] (int)
+     _n_token = len(_a_token_starts)
+
+     # Create one blank template (ground truth), initialized to mask tokens (we will only use the distogram, and ignore the other features)
+     template_distogram = torch.full((_n_token, _n_token), fill_value=MASK_VALUE)
+
+     # Sample Gaussian noise according to the noise scale for each token
+     # NOTE: If a scalar noise scale is provided, it is broadcast to all tokens
+     # NOTE: We sample noise independently for each token; no two tokens will have the exact same noise
+     noise = torch.normal(mean=0.0, std=1.0, size=(_n_token, 3)) * noise_scale.unsqueeze(
+         -1
+     )
+
+     # Get the center coordinates of the tokens (CA for proteins, C1' for nucleic acids), and add noise
+     center_token_mask = get_af3_token_center_masks(atom_array)  # [n_atom] (bool)
+     noisy_center_coords = (
+         torch.from_numpy(atom_array.coord[center_token_mask]) + noise
+     )  # [n_token, 3] (float)
+
+     # Create a mask of supported chain types...
+     tokens_with_supported_chain_types_mask = np.isin(
+         atom_array.chain_type[center_token_mask], allowed_chain_types
+     )  # [n_token] (bool)
+
+     # ... and a mask of tokens with resolved center atoms
+     resolved_tokens_mask = (
+         atom_array.occupancy[center_token_mask] > 0
+     )  # [n_token] (bool)
+
+     # The tokens to fill are those with supported chain types, resolved center atoms, and non-NaN noise
+     token_to_fill_mask = (
+         tokens_with_supported_chain_types_mask
+         & resolved_tokens_mask
+         & torch.isfinite(noise).all(dim=-1).numpy()
+     )  # [n_token] (bool)
+
+     # Check whether the annotation exists and force templating where it is True
+     if (
+         existing_annotation_to_check
+         and existing_annotation_to_check in atom_array.get_annotation_categories()
+     ):
+         existing_annotation_values = atom_array.get_annotation(
+             existing_annotation_to_check
+         )[center_token_mask]
+         forced_template_mask = np.asarray(existing_annotation_values, dtype=bool)
+     else:
+         forced_template_mask = np.full(_n_token, False, dtype=bool)
+
+     # If unconditional, discard all conditioning...
+     if is_unconditional:
+         token_to_fill_mask = np.full_like(token_to_fill_mask, False)
+     else:
+         # Probability of conditioning each token
+         _should_apply_condition = np.random.rand(_n_token) < p_condition_per_token
+         token_to_fill_mask = (
+             token_to_fill_mask & _should_apply_condition
+         ) | forced_template_mask
+
+     token_idxs_to_fill = np.where(token_to_fill_mask)[0]  # [n_token_to_fill] (int)
+
+     # ... fill the template_distogram
+     ix1, ix2 = np.ix_(token_idxs_to_fill, token_idxs_to_fill)
+     template_distogram[ix1, ix2] = torch.cdist(
+         noisy_center_coords[token_to_fill_mask],
+         noisy_center_coords[token_to_fill_mask],
+         compute_mode="donot_use_mm_for_euclid_dist",  # Important for numerical stability
+     )
+
+     # Create an n_token x n_token mask, where True indicates a condition (i.e., True for all conditioned token pairs)
+     token_to_fill_mask_II = token_to_fill_mask[:, None] & token_to_fill_mask[None, :]
+
+     # ... mask inter-molecule distances, if required
+     if np.random.rand() > p_provide_inter_molecule_distances:
+         # Create a mask of token pairs that belong to different molecules
+         is_inter_molecule = (
+             atom_array.molecule_id[center_token_mask][:, None]
+             != atom_array.molecule_id[center_token_mask]
+         )
+
+         # ... mask inter-molecule distances
+         token_to_fill_mask_II[is_inter_molecule] = False
+         template_distogram[is_inter_molecule] = MASK_VALUE
+
+     # Discretize distances into bins (NaNs go to the last bin)
+     template_distogram_binned: Tensor = torch.bucketize(
+         template_distogram, boundaries=distogram_bins
+     )  # (n_token, n_token)
+     n_bins: int = len(distogram_bins) + 1
+     template_distogram_onehot: Float[Tensor, "n_token n_token n_bins"] = (
+         torch.nn.functional.one_hot(
+             template_distogram_binned, num_classes=n_bins
+         ).to(torch.float32)
+     )
+
+     # Expand the noise scale to (n_token,); clone because `expand` returns a
+     # view that must not be written into
+     expanded_noise_scale: Float[Tensor, "n_token"] = (
+         noise_scale.expand(_n_token).clone()
+     )
+     # Set the noise scale to 0 for unconditioned tokens
+     expanded_noise_scale[~token_to_fill_mask] = 0.0
+
+     out: dict[str, Tensor] = {
+         "distogram_condition_noise_scale": expanded_noise_scale,  # (n_token,)
+         "has_distogram_condition": torch.as_tensor(
+             token_to_fill_mask_II, dtype=torch.bool
+         ),  # (n_token, n_token)
+         "distogram_condition": template_distogram_onehot,  # (n_token, n_token, n_bins)
+     }
+
+     assert_no_nans(out, msg="Conditioning features contain NaNs!")
+
+     return out
+
+
+ class FeaturizeNoisedGroundTruthAsTemplateDistogram(Transform):
+     """Add noised ground truth as a template distogram.
+
+     Creates template features by adding Gaussian noise to the ground truth
+     coordinates and converting the resulting distances into a discretized
+     distogram.
+
+     Args:
+         noise_scale_distribution (Callable): Function that returns the standard
+             deviation of noise to add to the ground truth coordinates. Should take
+             a sequence of dimensions and return a tensor or float. Default is
+             af3_noise_scale_distribution.
+         distogram_bins (Tensor): Bin boundaries for discretizing distances in
+             the distogram. Shape [n_bins - 1].
+         allowed_chain_types (list): List of allowed chain types. Default is all chain types.
+         p_condition_per_token (float): Per-token probability of conditioning, for those
+             tokens that satisfy all other conditions. Default is 0.0 (no conditioning).
+         p_provide_inter_molecule_distances (float): Probability of providing inter-molecule
+             (inter-chain) distances. Default is 0.0 (no inter-molecule distances provided).
+         existing_annotation_to_check (str): Name of an annotation in the AtomArray that,
+             if present and True for a token, will force that token to be templated
+             regardless of other conditions. Default is "is_input_file_templated".
+
+     Adds the following features to the `feats` dict:
+         - "distogram_condition_noise_scale": Noise scale for each token
+           [n_token] (float)
+         - "has_distogram_condition": Mask indicating which token pairs have a
+           distogram condition [n_token, n_token] (bool)
+         - "distogram_condition": One-hot encoded distogram
+           [n_token, n_token, n_bins] (float)
+     """
+
+     requires_previous_transforms = [AtomizeByCCDName]
+
+     def __init__(
+         self,
+         noise_scale_distribution: NoiseScaleSampler
+         | TokenGroupNoiseScaleSampler = af3_noise_scale_distribution,
+         distogram_bins: torch.Tensor = DEFAULT_DISTOGRAM_BINS,
+         allowed_chain_types: list[ChainType] = ChainType.get_all_types(),
+         p_condition_per_token: float = 0.0,
+         p_provide_inter_molecule_distances: float = 0.0,
+         existing_annotation_to_check: str = "is_input_file_templated",
+     ):
+         self.distogram_bins = distogram_bins
+         self.noise_scale_distribution = noise_scale_distribution
+         self.p_provide_inter_molecule_distances = p_provide_inter_molecule_distances
+         self.allowed_chain_types = allowed_chain_types
+         self.p_condition_per_token = p_condition_per_token
+         self.existing_annotation_to_check = existing_annotation_to_check
+
+     def check_input(self, data: dict[str, Any]) -> None:
+         check_contains_keys(data, ["atom_array"])
+         check_is_instance(data, "atom_array", AtomArray)
+         check_atom_array_annotation(data, ["chain_type", "occupancy", "molecule_id"])
+
+     def forward(self, data: dict[str, Any]) -> dict[str, Any]:
+         atom_array = data["atom_array"]
+
+         if isinstance(self.noise_scale_distribution, TokenGroupNoiseScaleSampler):
+             # ... a different noise scale for each token group
+             noise_scale = self.noise_scale_distribution(atom_array)  # [n_token] (float)
+         else:
+             # ... the same noise scale for all tokens
+             noise_scale = self.noise_scale_distribution(size=(1,))  # [1] (float)
+
+         template_features = featurize_noised_ground_truth_as_template_distogram(
+             atom_array=atom_array,
+             noise_scale=noise_scale,
+             allowed_chain_types=self.allowed_chain_types,
+             distogram_bins=self.distogram_bins,
+             p_provide_inter_molecule_distances=self.p_provide_inter_molecule_distances,
+             is_unconditional=data.get("is_unconditional", False),
+             p_condition_per_token=self.p_condition_per_token,
+             existing_annotation_to_check=self.existing_annotation_to_check,
+         )
+
+         # Add the template features to the `feats` dict
+         if "feats" not in data:
+             data["feats"] = {}
+         data["feats"].update(template_features)
+
+         return data
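As a quick numerical check of the noise-level/noise-scale pair defined above, the two formulas round-trip exactly; a minimal sketch (printed values are approximate):

    import torch

    # t^ = 16 * exp(1.5 * t - 1.2) and its inverse, as in the functions above
    t = torch.tensor([-1.0, 0.0, 1.0])
    t_hat = 16.0 * torch.exp(1.5 * t - 1.2)
    t_back = (torch.log(t_hat / 16.0) + 1.2) / 1.5
    assert torch.allclose(t, t_back)
    print(t_hat)  # tensor([ 1.0753,  4.8191, 21.5977])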