rc-foundry 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (180)
  1. foundry/__init__.py +57 -0
  2. foundry/callbacks/__init__.py +5 -0
  3. foundry/callbacks/callback.py +116 -0
  4. foundry/callbacks/health_logging.py +419 -0
  5. foundry/callbacks/metrics_logging.py +211 -0
  6. foundry/callbacks/timing_logging.py +67 -0
  7. foundry/callbacks/train_logging.py +278 -0
  8. foundry/common.py +108 -0
  9. foundry/constants.py +28 -0
  10. foundry/hydra/resolvers.py +77 -0
  11. foundry/inference_engines/base.py +235 -0
  12. foundry/inference_engines/checkpoint_registry.py +66 -0
  13. foundry/metrics/__init__.py +12 -0
  14. foundry/metrics/losses.py +30 -0
  15. foundry/metrics/metric.py +319 -0
  16. foundry/model/layers/blocks.py +47 -0
  17. foundry/testing/__init__.py +6 -0
  18. foundry/testing/fixtures.py +19 -0
  19. foundry/testing/pytest_hooks.py +15 -0
  20. foundry/trainers/fabric.py +923 -0
  21. foundry/training/EMA.py +67 -0
  22. foundry/training/checkpoint.py +61 -0
  23. foundry/training/schedulers.py +91 -0
  24. foundry/utils/alignment.py +86 -0
  25. foundry/utils/components.py +415 -0
  26. foundry/utils/datasets.py +405 -0
  27. foundry/utils/ddp.py +103 -0
  28. foundry/utils/instantiators.py +72 -0
  29. foundry/utils/logging.py +279 -0
  30. foundry/utils/rigid.py +1460 -0
  31. foundry/utils/rotation_augmentation.py +65 -0
  32. foundry/utils/squashfs.py +172 -0
  33. foundry/utils/torch.py +317 -0
  34. foundry/utils/weights.py +271 -0
  35. foundry/version.py +34 -0
  36. foundry_cli/__init__.py +3 -0
  37. foundry_cli/download_checkpoints.py +281 -0
  38. mpnn/__init__.py +1 -0
  39. mpnn/collate/feature_collator.py +265 -0
  40. mpnn/inference.py +53 -0
  41. mpnn/inference_engines/mpnn.py +549 -0
  42. mpnn/loss/nll_loss.py +122 -0
  43. mpnn/metrics/nll.py +369 -0
  44. mpnn/metrics/sequence_recovery.py +440 -0
  45. mpnn/model/layers/graph_embeddings.py +2372 -0
  46. mpnn/model/layers/message_passing.py +332 -0
  47. mpnn/model/layers/position_wise_feed_forward.py +44 -0
  48. mpnn/model/layers/positional_encoding.py +98 -0
  49. mpnn/model/mpnn.py +2632 -0
  50. mpnn/pipelines/mpnn.py +162 -0
  51. mpnn/samplers/samplers.py +167 -0
  52. mpnn/train.py +341 -0
  53. mpnn/trainers/mpnn.py +193 -0
  54. mpnn/transforms/feature_aggregation/mpnn.py +184 -0
  55. mpnn/transforms/feature_aggregation/polymer_ligand_interface.py +76 -0
  56. mpnn/transforms/feature_aggregation/token_encodings.py +132 -0
  57. mpnn/transforms/feature_aggregation/user_settings.py +347 -0
  58. mpnn/transforms/polymer_ligand_interface.py +164 -0
  59. mpnn/utils/inference.py +2397 -0
  60. mpnn/utils/probability.py +37 -0
  61. mpnn/utils/weights.py +309 -0
  62. rc_foundry-0.1.1.dist-info/METADATA +239 -0
  63. rc_foundry-0.1.1.dist-info/RECORD +180 -0
  64. rc_foundry-0.1.1.dist-info/WHEEL +4 -0
  65. rc_foundry-0.1.1.dist-info/entry_points.txt +5 -0
  66. rc_foundry-0.1.1.dist-info/licenses/LICENSE.md +28 -0
  67. rf3/__init__.py +3 -0
  68. rf3/_version.py +33 -0
  69. rf3/alignment.py +79 -0
  70. rf3/callbacks/dump_validation_structures.py +101 -0
  71. rf3/callbacks/metrics_logging.py +324 -0
  72. rf3/chemical.py +1529 -0
  73. rf3/cli.py +77 -0
  74. rf3/data/cyclic_transform.py +78 -0
  75. rf3/data/extra_xforms.py +36 -0
  76. rf3/data/ground_truth_template.py +463 -0
  77. rf3/data/paired_msa.py +206 -0
  78. rf3/data/pipeline_utils.py +128 -0
  79. rf3/data/pipelines.py +558 -0
  80. rf3/diffusion_samplers/inference_sampler.py +222 -0
  81. rf3/inference.py +65 -0
  82. rf3/inference_engines/__init__.py +5 -0
  83. rf3/inference_engines/rf3.py +735 -0
  84. rf3/kinematics.py +354 -0
  85. rf3/loss/af3_confidence_loss.py +515 -0
  86. rf3/loss/af3_losses.py +655 -0
  87. rf3/loss/loss.py +179 -0
  88. rf3/metrics/chiral.py +179 -0
  89. rf3/metrics/clashing_chains.py +68 -0
  90. rf3/metrics/distogram.py +421 -0
  91. rf3/metrics/lddt.py +523 -0
  92. rf3/metrics/metadata.py +43 -0
  93. rf3/metrics/metric_utils.py +192 -0
  94. rf3/metrics/predicted_error.py +134 -0
  95. rf3/metrics/rasa.py +108 -0
  96. rf3/metrics/selected_distances.py +91 -0
  97. rf3/model/RF3.py +527 -0
  98. rf3/model/RF3_blocks.py +92 -0
  99. rf3/model/RF3_structure.py +303 -0
  100. rf3/model/layers/af3_auxiliary_heads.py +255 -0
  101. rf3/model/layers/af3_diffusion_transformer.py +544 -0
  102. rf3/model/layers/attention.py +313 -0
  103. rf3/model/layers/layer_utils.py +127 -0
  104. rf3/model/layers/mlff.py +118 -0
  105. rf3/model/layers/outer_product.py +59 -0
  106. rf3/model/layers/pairformer_layers.py +783 -0
  107. rf3/model/layers/structure_bias.py +56 -0
  108. rf3/scoring.py +1787 -0
  109. rf3/symmetry/resolve.py +284 -0
  110. rf3/train.py +194 -0
  111. rf3/trainers/rf3.py +570 -0
  112. rf3/util_module.py +47 -0
  113. rf3/utils/frames.py +109 -0
  114. rf3/utils/inference.py +665 -0
  115. rf3/utils/io.py +198 -0
  116. rf3/utils/loss.py +72 -0
  117. rf3/utils/predict_and_score.py +165 -0
  118. rf3/utils/predicted_error.py +673 -0
  119. rf3/utils/recycling.py +42 -0
  120. rf3/validate.py +140 -0
  121. rfd3/.gitignore +7 -0
  122. rfd3/Makefile +76 -0
  123. rfd3/__init__.py +12 -0
  124. rfd3/callbacks.py +66 -0
  125. rfd3/cli.py +41 -0
  126. rfd3/constants.py +212 -0
  127. rfd3/engine.py +543 -0
  128. rfd3/inference/datasets.py +193 -0
  129. rfd3/inference/input_parsing.py +1123 -0
  130. rfd3/inference/legacy_input_parsing.py +717 -0
  131. rfd3/inference/parsing.py +165 -0
  132. rfd3/inference/symmetry/atom_array.py +298 -0
  133. rfd3/inference/symmetry/checks.py +241 -0
  134. rfd3/inference/symmetry/contigs.py +63 -0
  135. rfd3/inference/symmetry/frames.py +355 -0
  136. rfd3/inference/symmetry/symmetry_utils.py +398 -0
  137. rfd3/metrics/design_metrics.py +465 -0
  138. rfd3/metrics/hbonds_hbplus_metrics.py +308 -0
  139. rfd3/metrics/hbonds_metrics.py +389 -0
  140. rfd3/metrics/losses.py +325 -0
  141. rfd3/metrics/metrics_utils.py +118 -0
  142. rfd3/metrics/sidechain_metrics.py +349 -0
  143. rfd3/model/RFD3.py +105 -0
  144. rfd3/model/RFD3_diffusion_module.py +387 -0
  145. rfd3/model/cfg_utils.py +81 -0
  146. rfd3/model/inference_sampler.py +635 -0
  147. rfd3/model/layers/attention.py +577 -0
  148. rfd3/model/layers/block_utils.py +580 -0
  149. rfd3/model/layers/blocks.py +777 -0
  150. rfd3/model/layers/chunked_pairwise.py +377 -0
  151. rfd3/model/layers/encoders.py +417 -0
  152. rfd3/model/layers/layer_utils.py +197 -0
  153. rfd3/model/layers/pairformer_layers.py +128 -0
  154. rfd3/run_inference.py +45 -0
  155. rfd3/testing/debug.py +139 -0
  156. rfd3/testing/debug_utils.py +73 -0
  157. rfd3/testing/testing_utils.py +356 -0
  158. rfd3/train.py +194 -0
  159. rfd3/trainer/dump_validation_structures.py +154 -0
  160. rfd3/trainer/fabric_trainer.py +923 -0
  161. rfd3/trainer/recycling.py +42 -0
  162. rfd3/trainer/rfd3.py +485 -0
  163. rfd3/trainer/trainer_utils.py +502 -0
  164. rfd3/transforms/conditioning_base.py +508 -0
  165. rfd3/transforms/conditioning_utils.py +200 -0
  166. rfd3/transforms/design_transforms.py +807 -0
  167. rfd3/transforms/dna_crop.py +523 -0
  168. rfd3/transforms/hbonds.py +407 -0
  169. rfd3/transforms/hbonds_hbplus.py +246 -0
  170. rfd3/transforms/ncaa_transforms.py +153 -0
  171. rfd3/transforms/pipelines.py +632 -0
  172. rfd3/transforms/ppi_transforms.py +541 -0
  173. rfd3/transforms/rasa.py +116 -0
  174. rfd3/transforms/symmetry.py +76 -0
  175. rfd3/transforms/training_conditions.py +552 -0
  176. rfd3/transforms/util_transforms.py +498 -0
  177. rfd3/transforms/virtual_atoms.py +305 -0
  178. rfd3/utils/inference.py +648 -0
  179. rfd3/utils/io.py +245 -0
  180. rfd3/utils/vizualize.py +276 -0
rf3/model/RF3.py ADDED
@@ -0,0 +1,527 @@
from collections import deque
from contextlib import ExitStack

import torch
import torch.utils.checkpoint as checkpoint
from beartype.typing import Any, Generator, Protocol
from omegaconf import DictConfig
from rf3.diffusion_samplers.inference_sampler import (
    SampleDiffusion,
    SamplePartialDiffusion,
)
from rf3.model.layers.pairformer_layers import (
    FeatureInitializer,
)
from rf3.model.RF3_structure import DiffusionModule, DistogramHead, Recycler
from torch import nn

from foundry.training.checkpoint import create_custom_forward

"""
Shape Annotation Glossary:
    I: # tokens (coarse representation)
    L: # atoms (fine representation)
    M: # msa
    T: # templates
    D: # diffusion structure batch dim

    C_s: Token-level single representation channel dimension
    C_z: Token-level pair representation channel dimension
    C_atom: Atom-level single representation channel dimension
    C_atompair: Atom-level pair representation channel dimension

Tensor Name Glossary:
    S: Token-level single representation (I, C_s)
    Z: Token-level pair representation (I, I, C_z)
    Q: Atom-level single representation (L, C_atom)
    P: Atom-level pair representation (L, L, C_atompair)
"""


class ShouldEarlyStopFn(Protocol):
    def __call__(
        self, confidence_outputs: dict[str, Any], first_recycle_outputs: dict[str, Any]
    ) -> tuple[bool, dict[str, Any]]:
        """Duck-typed function Protocol for early stopping based on confidence outputs.

        Returns:
            tuple: A pair containing:
                - should_stop (bool): Whether to stop early.
                - additional_data (dict): Metadata for the user, if any.
        """
        ...


class RF3(nn.Module):
    """RF3 Network module.

    We adhere to the PyTorch Lightning Style Guide; see (1).

    References:
        (1) PyTorch Lightning Style Guide: https://lightning.ai/docs/pytorch/latest/starter/style_guide.html
    """

    def __init__(
        self,
        *,
        # Arguments for modules that will be instantiated
        feature_initializer: DictConfig | dict,
        recycler: DictConfig | dict,
        diffusion_module: DictConfig | dict,
        distogram_head: DictConfig | dict,
        inference_sampler: DictConfig | dict,
        # Channel dimensions
        c_s: int,  # AF-3: 384
        c_z: int,  # AF-3: 128
        c_atom: int,  # AF-3: 128
        c_atompair: int,  # AF-3: 16
        c_s_inputs: int,  # AF-3: 449
    ):
        """Initializes the RF3 model.

        Args:
            feature_initializer: Arguments for FeatureInitializer
            recycler: Arguments for Recycler
            diffusion_module: Arguments for DiffusionModule
            distogram_head: Arguments for DistogramHead
            inference_sampler: Arguments for the SampleDiffusion class, used for inference (contains no trainable parameters)
            c_s: Token-level single representation channel dimension
            c_z: Token-level pair representation channel dimension
            c_atom: Atom-level single representation channel dimension
            c_atompair: Atom-level pair representation channel dimension
            c_s_inputs: Output dimension of the InputFeatureEmbedder
        """
        super().__init__()

        # ... initialize the FeatureInitializer, which creates the initial token-level representations and conditioning
        self.feature_initializer = FeatureInitializer(
            c_s=c_s,
            c_z=c_z,
            c_atom=c_atom,
            c_atompair=c_atompair,
            c_s_inputs=c_s_inputs,
            **feature_initializer,
        )

        # ... initialize the Recycler, which runs the trunk repeatedly with shared weights
        self.recycler = Recycler(c_s=c_s, c_z=c_z, **recycler)
        self.diffusion_module = DiffusionModule(
            c_atom=c_atom,
            c_atompair=c_atompair,
            c_s=c_s,
            c_z=c_z,
            **diffusion_module,
        )
        self.distogram_head = DistogramHead(c_z=c_z, **distogram_head)

        # ... initialize the inference sampler, which performs a full diffusion rollout during inference
        self.inference_sampler = (
            SampleDiffusion(**inference_sampler)
            if not inference_sampler.get("partial_t", False)
            else SamplePartialDiffusion(**inference_sampler)
        )

    def forward(
        self,
        input: dict,
        n_cycle: int,
        coord_atom_lvl_to_be_noised: torch.Tensor | None = None,
    ) -> dict:
        """Complete forward pass of the model.

        Runs recycling with gradients only on the final recycle.

        Args:
            input (dict): Dictionary of model inputs
            n_cycle (int): Number of recycling cycles for the trunk
            coord_atom_lvl_to_be_noised (torch.Tensor): Atom-level coordinates to be noised further. Optional;
                only used during inference for partial denoising.

        Returns:
            dict: Dictionary of model outputs, including:
                - X_L: Predicted atomic coordinates [D, L, 3]
                - distogram: Predicted distogram [I, I, C], where C is the number of bins in the distogram
                - If not training, additional lists are returned, each of length T:
                    * X_noisy_L_traj: List of noisy atomic coordinates at each timestep [D, L, 3]
                    * X_denoised_L_traj: List of denoised atomic coordinates at each timestep [D, L, 3]
                    * t_hats: List of tensor scalars representing the noise schedule at each timestep
        """
        # Cast features to lower precision if autocast is enabled
        if torch.is_autocast_enabled():
            autocast_dtype = torch.get_autocast_dtype("cuda")
            for x in [
                "msa_stack",
                "profile",
                "template_distogram",
                "template_restype",
                "template_unit_vector",
            ]:
                if x in input["f"]:
                    input["f"][x] = input["f"][x].to(autocast_dtype)

        # ... recycling
        # Gives dictionary of outputs S_inputs_I, S_init_I, Z_init_II, S_I, Z_II
        # (We use `deque` with maxlen=1 to ensure that we only keep the last output in memory)
        try:
            recycling_outputs = deque(
                self.trunk_forward_with_recycling(f=input["f"], n_recycles=n_cycle),
                maxlen=1,
            ).pop()
        except IndexError:
            # Handle the case where the generator is empty
            raise RuntimeError("Recycling generator produced no outputs")

        # Predict the distogram from the pair representation
        distogram_pred = self.distogram_head(recycling_outputs["Z_II"])

        # ... post-recycling (diffusion module)
        if self.training:
            # Single denoising step
            X_pred = self.diffusion_module(
                X_noisy_L=input["X_noisy_L"],
                t=input["t"],
                f=input["f"],
                S_inputs_I=recycling_outputs["S_inputs_I"],
                S_trunk_I=recycling_outputs["S_I"],
                Z_trunk_II=recycling_outputs["Z_II"],
            )  # [D, L, 3]
            return dict(
                X_L=X_pred,
                distogram=distogram_pred,
            )
        else:
            # Full diffusion rollout (no gradients, or we will OOM)
            sample_diffusion_outs = self.inference_sampler.sample_diffusion_like_af3(
                f=input["f"],
                S_inputs_I=recycling_outputs["S_inputs_I"],
                S_trunk_I=recycling_outputs["S_I"],
                Z_trunk_II=recycling_outputs["Z_II"],
                diffusion_module=self.diffusion_module,
                diffusion_batch_size=input["t"].shape[0],
                coord_atom_lvl_to_be_noised=coord_atom_lvl_to_be_noised,
            )
            return dict(
                X_L=sample_diffusion_outs["X_L"],
                distogram=distogram_pred,
                # For reporting, inference (validation or testing) only
                X_noisy_L_traj=sample_diffusion_outs["X_noisy_L_traj"],
                X_denoised_L_traj=sample_diffusion_outs["X_denoised_L_traj"],
                t_hats=sample_diffusion_outs["t_hats"],
            )

    def trunk_forward_with_recycling(
        self, f: dict, n_recycles: int
    ) -> Generator[dict[str, torch.Tensor]]:
        """Forward pass of the AF-3 trunk

        (i.e., the recycling process, including the MSAModule, PairformerStack, etc.).

        Notes:
            - We run with gradients ONLY on the final recycle
            - All recycles use shared weights (ResNet-style)
            - We yield results after each recycle to support use cases such as early stopping during inference

        Args:
            f: Feature dictionary
            n_recycles: Number of recycles to run

        Yields:
            dict: Recycling outputs, with keys:
                - S_inputs_I: Token-level single representation input, prior to AtomAttention [I, c_s_inputs]
                - S_init_I: Token-level single representation initialization [I, c_s], after AtomAttention but before recycling stack
                - Z_init_II: Token-level pair representation initialization [I, I, c_z], after AtomAttention but before recycling stack
                - S_I: Token-level single representation [I, c_s], after recycling stack
                - Z_II: Token-level pair representation [I, I, c_z], after recycling stack
        """
        # ... initialize the recycling process (feature initialization)
        # Gives S_inputs_I, S_init_I, Z_init_II, S_I, Z_II
        initialized_features = self.pre_recycle(f)

        # ... collect the recycling inputs, which will be updated in place
        recycling_inputs = {**initialized_features, "f": f}

        for i_cycle in range(n_recycles):
            with ExitStack() as stack:
                # For the first n_recycles - 1 cycles (all but the last recycle), we run without gradients
                if i_cycle < n_recycles - 1:
                    stack.enter_context(torch.no_grad())

                # Clear the autocast cache if gradients are enabled (workaround for autocast bug)
                # See: https://github.com/pytorch/pytorch/issues/65766
                if torch.is_grad_enabled():
                    torch.clear_autocast_cache()

                # Select the MSA for the current recycle (we sample an i.i.d. MSA for each recycle)
                recycling_inputs["f"]["msa"] = f["msa_stack"][i_cycle]

                # Run the model trunk (MSAModule, PairformerStack, etc.)
                # We alter the S_I and Z_II in place such that the next iteration uses the updated values
                recycling_inputs = self.recycle(**recycling_inputs)

            # Yield after each recycle
            yield {
                "S_inputs_I": recycling_inputs["S_inputs_I"],
                "S_init_I": recycling_inputs["S_init_I"],
                "Z_init_II": recycling_inputs["Z_init_II"],
                "S_I": recycling_inputs["S_I"],
                "Z_II": recycling_inputs["Z_II"],
            }

    def pre_recycle(self, f: dict) -> dict:
        """Prepare feature inputs for recycling.

        Includes:
            - Feature initialization (S_inputs_I, S_init_I, Z_init_II)
            - Initializing S_I and Z_II to zeros

        Returns:
            dict: Dictionary of recycling inputs, including:
                - S_inputs_I: Token-level single representation input (prior to AtomAttention) [I, c_s_inputs]
                - S_init_I: Token-level single representation initialization [I, c_s] (after round of AtomAttention)
                - Z_init_II: Token-level pair representation initialization [I, I, c_z] (after round of AtomAttention)
                - S_I: Token-level single representation [I, c_s], initialized to zeros
                - Z_II: Token-level pair representation [I, I, c_z], initialized to zeros
        """
        S_inputs_I, S_init_I, Z_init_II = self.feature_initializer(f)
        S_I = torch.zeros_like(S_init_I)
        Z_II = torch.zeros_like(Z_init_II)

        return dict(
            S_inputs_I=S_inputs_I,
            S_init_I=S_init_I,
            Z_init_II=Z_init_II,
            S_I=S_I,
            Z_II=Z_II,
        )

    def recycle(
        self,
        # TODO: Jax typing
        S_inputs_I,
        S_init_I,
        Z_init_II,
        S_I,
        Z_II,
        f,
    ):
        S_I, Z_II = self.recycler(
            f=f,
            S_inputs_I=S_inputs_I,
            S_init_I=S_init_I,
            Z_init_II=Z_init_II,
            S_I=S_I,
            Z_II=Z_II,
        )
        return dict(
            S_inputs_I=S_inputs_I,
            S_init_I=S_init_I,
            Z_init_II=Z_init_II,
            S_I=S_I,
            Z_II=Z_II,
            f=f,
        )


class RF3WithConfidence(RF3):
    """Model for training and inference with confidence metric computation."""

    def __init__(
        self,
        confidence_head: DictConfig | dict,
        mini_rollout_sampler: DictConfig | dict,
        **kwargs,
    ):
        """
        Args:
            (... all arguments from the AF3 class)
            confidence_head: Hydra configuration for the confidence head architecture
            mini_rollout_sampler: Hydra configuration for the mini-rollout sampler (e.g., SampleDiffusion with 20 rather than
                200 timesteps). Note that the `inference_sampler` argument in the AF3 class will still be used for full
                rollouts during inference.
        """
        # (Lazy import)
        from rf3.model.layers.af3_auxiliary_heads import ConfidenceHead  # noqa

        super().__init__(**kwargs)

        self.confidence_head = ConfidenceHead(**confidence_head)
        self.mini_rollout_sampler = SampleDiffusion(**mini_rollout_sampler)

    def forward(
        self,
        input: dict,
        n_cycle: int,
        coord_atom_lvl_to_be_noised: torch.Tensor | None = None,
        should_early_stop_fn: ShouldEarlyStopFn | None = None,
    ) -> dict:
        """Complete forward pass of the model with confidence head.

        Notes:
            - Performs a mini-rollout without gradients during training (e.g., 20 timesteps) and a full rollout (e.g., 200 timesteps) during inference
            - Runs the trunk forward without gradients to conserve memory (which departs from the AF-3 implementation)
            - Runs the forward pass (with gradients) for the confidence model

        Args:
            input (dict): Dictionary of model inputs. In addition to the standard AF-3 model inputs, we expect:
                - rep_atom_idxs: TBD
                - frame_atom_idxs: TBD
            n_cycle (int): Number of recycling cycles for the trunk
            coord_atom_lvl_to_be_noised (torch.Tensor): Atom-level coordinates to be noised further. Optional;
                only used during inference for partial denoising.
            should_early_stop_fn (Callable): Function that takes the confidence and trunk outputs after the first recycle and returns a boolean
                indicating whether to stop early and a dictionary with additional information (e.g., value and threshold).
                If None, no early stopping is performed. Optional; only used during inference.

        Returns:
            dict: Dictionary of model outputs, including:
                - X_L: Predicted atomic coordinates [D, L, 3] (from the mini rollout during training or full rollout during inference)
                - plddt: TBD
                - pae: TBD
                - pde: TBD
                - exp_resolved: TBD
        """
        # Cast features to lower precision if autocast is enabled
        if torch.is_autocast_enabled():
            autocast_dtype = torch.get_autocast_dtype("cuda")
            for x in [
                "msa_stack",
                "profile",
                "template_distogram",
                "template_restype",
                "template_unit_vector",
            ]:
                if x in input["f"]:
                    input["f"][x] = input["f"][x].to(autocast_dtype)

        diffusion_batch_size = input["t"].shape[0]
        with torch.no_grad():
            # ... recycling
            recycling_output_generator = self.trunk_forward_with_recycling(
                f=input["f"], n_recycles=n_cycle
            )
            if should_early_stop_fn:
                assert (
                    not self.training
                ), "Early stopping is not supported during training!"
                # ... get the recycling outputs after the first recycle
                first_recycle_outputs = next(recycling_output_generator)

                # ... compute confidence metrics (without structure)
                confidence_outputs = checkpoint.checkpoint(
                    create_custom_forward(
                        self.confidence_head, frame_atom_idxs=input["frame_atom_idxs"]
                    ),
                    first_recycle_outputs["S_inputs_I"],
                    first_recycle_outputs["S_I"],
                    first_recycle_outputs["Z_II"],
                    None,  # Omit structure
                    input["seq"],
                    input["rep_atom_idxs"],
                    use_reentrant=False,
                )

                should_early_stop, early_stop_data = should_early_stop_fn(
                    confidence_outputs=confidence_outputs,
                    first_recycle_outputs=first_recycle_outputs,
                )
                if should_early_stop:
                    result = {"early_stopped": True}
                    return result | early_stop_data

            # (We use `deque` with maxlen=1 to ensure that we only keep the last output in memory)
            try:
                recycling_outputs = deque(recycling_output_generator, maxlen=1).pop()
            except IndexError:
                # Handle the case where the generator is empty
                raise RuntimeError("Recycling generator produced no outputs")

            # Predict the distogram from the pair representation
            # (NOTE: Not necessary for confidence head training, but helpful for reporting)
            distogram_pred = self.distogram_head(recycling_outputs["Z_II"])

            # ... post-recycling (diffusion module)
            if self.training:
                # Mini-rollout (no gradients still)
                sample_diffusion_outs = (
                    self.mini_rollout_sampler.sample_diffusion_like_af3(
                        f=input["f"],
                        S_inputs_I=recycling_outputs["S_inputs_I"],
                        S_trunk_I=recycling_outputs["S_I"],
                        Z_trunk_II=recycling_outputs["Z_II"],
                        diffusion_module=self.diffusion_module,
                        diffusion_batch_size=diffusion_batch_size,
                        coord_atom_lvl_to_be_noised=coord_atom_lvl_to_be_noised,
                    )
                )
            else:
                # Full diffusion rollout (no gradients still)
                sample_diffusion_outs = (
                    self.inference_sampler.sample_diffusion_like_af3(
                        f=input["f"],
                        S_inputs_I=recycling_outputs["S_inputs_I"],
                        S_trunk_I=recycling_outputs["S_I"],
                        Z_trunk_II=recycling_outputs["Z_II"],
                        diffusion_module=self.diffusion_module,
                        diffusion_batch_size=diffusion_batch_size,
                        coord_atom_lvl_to_be_noised=coord_atom_lvl_to_be_noised,
                    )
                )

        # ... run non-batched confidence head
        # TODO: Write a version of the confidence head that splits into batches based on memory available
        # (Currently, we OOM with the full batch size, so we loop, which is slow)
        D = sample_diffusion_outs["X_L"].shape[0]
        confidence_stack = {}
        for i in range(D):
            confidence = checkpoint.checkpoint(
                create_custom_forward(
                    self.confidence_head, frame_atom_idxs=input["frame_atom_idxs"]
                ),
                recycling_outputs["S_inputs_I"],
                recycling_outputs["S_I"],
                recycling_outputs["Z_II"],
                sample_diffusion_outs["X_L"][i].unsqueeze(0),
                input["seq"],
                input["rep_atom_idxs"],
                use_reentrant=False,
            )

            for k, v in confidence.items():
                if k in confidence_stack:
                    confidence_stack[k] = torch.cat((confidence_stack[k], v), dim=0)
                else:
                    confidence_stack[k] = v
        confidence = confidence_stack

        # ... run batched confidence head
        # fd too much memory use at training time...
        # confidence = checkpoint.checkpoint(
        #     create_custom_forward(
        #         self.confidence_head, frame_atom_idxs=input["frame_atom_idxs"]
        #     ),
        #     recycling_outputs["S_inputs_I"],
        #     recycling_outputs["S_I"],
        #     recycling_outputs["Z_II"],
        #     sample_diffusion_outs["X_L"],
        #     input["seq"],
        #     input["rep_atom_idxs"],
        #     use_reentrant=False,
        # )

        # TODO: Return outputs in a more structured way (e.g., a dataclass)
        return dict(
            early_stopped=False,
            # We return X_L from diffusion sampling as X_pred_rollout_L to support future joint training with the confidence head (where we would have both X_L and X_pred_rollout_L)
            X_L=None,
            distogram=distogram_pred,
            # For reporting, inference (validation or testing) only
            X_noisy_L_traj=sample_diffusion_outs["X_noisy_L_traj"],
            X_denoised_L_traj=sample_diffusion_outs["X_denoised_L_traj"],
            t_hats=sample_diffusion_outs["t_hats"],
            # Confidence outputs
            X_pred_rollout_L=sample_diffusion_outs["X_L"],
            plddt=confidence["plddt_logits"],
            pae=confidence["pae_logits"],
            pde=confidence["pde_logits"],
            exp_resolved=confidence["exp_resolved_logits"],
        )
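Both RF3.forward and RF3WithConfidence.forward drain the recycling generator through deque(maxlen=1) so that only the last recycle's outputs stay referenced in memory. Below is a minimal, self-contained sketch of that pattern; it is illustrative only and not part of the package, and recycle_generator is a hypothetical stand-in for trunk_forward_with_recycling.

    from collections import deque

    def recycle_generator(n_recycles: int):
        """Hypothetical stand-in for trunk_forward_with_recycling: yields one dict per recycle."""
        state = 0
        for _ in range(n_recycles):
            state += 1  # pretend this is an expensive trunk pass
            yield {"S_I": state, "Z_II": state * 10}

    # deque(maxlen=1) consumes the generator and discards each earlier item as soon as
    # the next one arrives, so only the final recycle's outputs remain referenced.
    try:
        last = deque(recycle_generator(n_recycles=4), maxlen=1).pop()
    except IndexError:
        raise RuntimeError("Recycling generator produced no outputs")

    print(last)  # {'S_I': 4, 'Z_II': 40}

Because earlier yields become unreferenced immediately, their tensors can be freed before the next recycle runs, which is the memory-saving behavior the in-code comments describe.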
rf3/model/RF3_blocks.py ADDED
@@ -0,0 +1,92 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

from foundry.training.checkpoint import activation_checkpointing


class MSASubsampleEmbedder(nn.Module):
    def __init__(self, num_sequences, dim_raw_msa, c_msa_embed, c_s_inputs):
        super(MSASubsampleEmbedder, self).__init__()
        self.num_sequences = num_sequences
        self.emb_msa = nn.Linear(dim_raw_msa, c_msa_embed, bias=False)
        self.emb_S_inputs = nn.Linear(c_s_inputs, c_msa_embed, bias=False)

    @activation_checkpointing
    def forward(
        self,
        msa_SI,  # (S, I, 34) (32 tokens + has_deletion + deletion value)
        S_inputs,  # (L, S_dim)
    ):
        # Embed the subsampled MSA
        # (NOTE: We subsample in the data loader to avoid memory issues)
        msa_SI = self.emb_msa(msa_SI)
        msa_SI = msa_SI + self.emb_S_inputs(S_inputs)
        return msa_SI


class MSAPairWeightedAverage(nn.Module):
    """implements Algorithm 10 from AF3 paper"""

    def __init__(
        self,
        c_weighted_average,
        n_heads,
        c_msa_embed,
        c_z,
        separate_gate_for_every_channel,
    ):
        super(MSAPairWeightedAverage, self).__init__()
        self.weighted_average_channels = c_weighted_average
        self.n_heads = n_heads
        self.msa_channels = c_msa_embed
        self.pair_channels = c_z
        self.norm_msa = nn.LayerNorm(self.msa_channels)
        self.to_v = nn.Linear(
            self.msa_channels, self.n_heads * self.weighted_average_channels, bias=False
        )
        self.norm_pair = nn.LayerNorm(self.pair_channels)
        self.to_bias = nn.Linear(self.pair_channels, self.n_heads, bias=False)

        self.separate_gate_for_every_channel = separate_gate_for_every_channel
        if self.separate_gate_for_every_channel:
            self.to_gate = nn.Linear(
                self.msa_channels,
                self.weighted_average_channels * self.n_heads,
                bias=False,
            )
        else:
            self.to_gate = nn.Linear(self.msa_channels, self.n_heads, bias=False)

        self.to_out = nn.Linear(
            self.weighted_average_channels * self.n_heads, self.msa_channels, bias=False
        )

    @activation_checkpointing
    def forward(self, msa_SI, pair_II):
        S, I = msa_SI.shape[:2]

        # normalize inputs
        msa_SI = self.norm_msa(msa_SI)

        # construct values, bias and weights
        v_SIH = self.to_v(msa_SI).reshape(
            S, I, self.n_heads, self.weighted_average_channels
        )
        bias_IIH = self.to_bias(self.norm_pair(pair_II))
        w_IIH = F.softmax(bias_IIH, dim=-2)

        # construct gate
        gate_SIH = torch.sigmoid(self.to_gate(msa_SI))

        # compute weighted average & apply gate
        if self.separate_gate_for_every_channel:
            weights = torch.einsum("ijh,sjhc->sihc", w_IIH, v_SIH).reshape(S, I, -1)
            o_SIH = gate_SIH * weights
        else:
            weights = torch.einsum("ijh,sjhc->sihc", w_IIH, v_SIH)
            o_SIH = gate_SIH[..., None] * weights

        # concatenate heads and project
        msa_update_SI = self.to_out(o_SIH.reshape(S, I, -1))
        return msa_update_SI
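MSAPairWeightedAverage turns the pair representation into one weight distribution per head (softmax over the second token axis) and uses it to average the per-sequence MSA values via the einsum "ijh,sjhc->sihc". Below is a small shape-checking sketch of just that step; the dimensions are toy values chosen for illustration, not the package defaults.

    import torch
    import torch.nn.functional as F

    S, I, H, C = 3, 5, 2, 4  # sequences, tokens, heads, channels per head (illustrative)

    v_SIH = torch.randn(S, I, H, C)  # per-sequence values, as produced by to_v
    bias_IIH = torch.randn(I, I, H)  # pair-derived logits, one per head, as produced by to_bias

    # Softmax over the second token axis j: each (i, h) slice becomes a distribution over j
    w_IIH = F.softmax(bias_IIH, dim=-2)
    assert torch.allclose(w_IIH.sum(dim=-2), torch.ones(I, H))

    # Weighted average over tokens j: out[s, i, h, c] = sum_j w[i, j, h] * v[s, j, h, c]
    out = torch.einsum("ijh,sjhc->sihc", w_IIH, v_SIH)
    print(out.shape)  # torch.Size([3, 5, 2, 4])

The module then gates this result, flattens the head and channel axes, and projects back to the MSA embedding width with to_out.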