rc-foundry 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (180)
  1. foundry/__init__.py +57 -0
  2. foundry/callbacks/__init__.py +5 -0
  3. foundry/callbacks/callback.py +116 -0
  4. foundry/callbacks/health_logging.py +419 -0
  5. foundry/callbacks/metrics_logging.py +211 -0
  6. foundry/callbacks/timing_logging.py +67 -0
  7. foundry/callbacks/train_logging.py +278 -0
  8. foundry/common.py +108 -0
  9. foundry/constants.py +28 -0
  10. foundry/hydra/resolvers.py +77 -0
  11. foundry/inference_engines/base.py +235 -0
  12. foundry/inference_engines/checkpoint_registry.py +66 -0
  13. foundry/metrics/__init__.py +12 -0
  14. foundry/metrics/losses.py +30 -0
  15. foundry/metrics/metric.py +319 -0
  16. foundry/model/layers/blocks.py +47 -0
  17. foundry/testing/__init__.py +6 -0
  18. foundry/testing/fixtures.py +19 -0
  19. foundry/testing/pytest_hooks.py +15 -0
  20. foundry/trainers/fabric.py +923 -0
  21. foundry/training/EMA.py +67 -0
  22. foundry/training/checkpoint.py +61 -0
  23. foundry/training/schedulers.py +91 -0
  24. foundry/utils/alignment.py +86 -0
  25. foundry/utils/components.py +415 -0
  26. foundry/utils/datasets.py +405 -0
  27. foundry/utils/ddp.py +103 -0
  28. foundry/utils/instantiators.py +72 -0
  29. foundry/utils/logging.py +279 -0
  30. foundry/utils/rigid.py +1460 -0
  31. foundry/utils/rotation_augmentation.py +65 -0
  32. foundry/utils/squashfs.py +172 -0
  33. foundry/utils/torch.py +317 -0
  34. foundry/utils/weights.py +271 -0
  35. foundry/version.py +34 -0
  36. foundry_cli/__init__.py +3 -0
  37. foundry_cli/download_checkpoints.py +281 -0
  38. mpnn/__init__.py +1 -0
  39. mpnn/collate/feature_collator.py +265 -0
  40. mpnn/inference.py +53 -0
  41. mpnn/inference_engines/mpnn.py +549 -0
  42. mpnn/loss/nll_loss.py +122 -0
  43. mpnn/metrics/nll.py +369 -0
  44. mpnn/metrics/sequence_recovery.py +440 -0
  45. mpnn/model/layers/graph_embeddings.py +2372 -0
  46. mpnn/model/layers/message_passing.py +332 -0
  47. mpnn/model/layers/position_wise_feed_forward.py +44 -0
  48. mpnn/model/layers/positional_encoding.py +98 -0
  49. mpnn/model/mpnn.py +2632 -0
  50. mpnn/pipelines/mpnn.py +162 -0
  51. mpnn/samplers/samplers.py +167 -0
  52. mpnn/train.py +341 -0
  53. mpnn/trainers/mpnn.py +193 -0
  54. mpnn/transforms/feature_aggregation/mpnn.py +184 -0
  55. mpnn/transforms/feature_aggregation/polymer_ligand_interface.py +76 -0
  56. mpnn/transforms/feature_aggregation/token_encodings.py +132 -0
  57. mpnn/transforms/feature_aggregation/user_settings.py +347 -0
  58. mpnn/transforms/polymer_ligand_interface.py +164 -0
  59. mpnn/utils/inference.py +2397 -0
  60. mpnn/utils/probability.py +37 -0
  61. mpnn/utils/weights.py +309 -0
  62. rc_foundry-0.1.1.dist-info/METADATA +239 -0
  63. rc_foundry-0.1.1.dist-info/RECORD +180 -0
  64. rc_foundry-0.1.1.dist-info/WHEEL +4 -0
  65. rc_foundry-0.1.1.dist-info/entry_points.txt +5 -0
  66. rc_foundry-0.1.1.dist-info/licenses/LICENSE.md +28 -0
  67. rf3/__init__.py +3 -0
  68. rf3/_version.py +33 -0
  69. rf3/alignment.py +79 -0
  70. rf3/callbacks/dump_validation_structures.py +101 -0
  71. rf3/callbacks/metrics_logging.py +324 -0
  72. rf3/chemical.py +1529 -0
  73. rf3/cli.py +77 -0
  74. rf3/data/cyclic_transform.py +78 -0
  75. rf3/data/extra_xforms.py +36 -0
  76. rf3/data/ground_truth_template.py +463 -0
  77. rf3/data/paired_msa.py +206 -0
  78. rf3/data/pipeline_utils.py +128 -0
  79. rf3/data/pipelines.py +558 -0
  80. rf3/diffusion_samplers/inference_sampler.py +222 -0
  81. rf3/inference.py +65 -0
  82. rf3/inference_engines/__init__.py +5 -0
  83. rf3/inference_engines/rf3.py +735 -0
  84. rf3/kinematics.py +354 -0
  85. rf3/loss/af3_confidence_loss.py +515 -0
  86. rf3/loss/af3_losses.py +655 -0
  87. rf3/loss/loss.py +179 -0
  88. rf3/metrics/chiral.py +179 -0
  89. rf3/metrics/clashing_chains.py +68 -0
  90. rf3/metrics/distogram.py +421 -0
  91. rf3/metrics/lddt.py +523 -0
  92. rf3/metrics/metadata.py +43 -0
  93. rf3/metrics/metric_utils.py +192 -0
  94. rf3/metrics/predicted_error.py +134 -0
  95. rf3/metrics/rasa.py +108 -0
  96. rf3/metrics/selected_distances.py +91 -0
  97. rf3/model/RF3.py +527 -0
  98. rf3/model/RF3_blocks.py +92 -0
  99. rf3/model/RF3_structure.py +303 -0
  100. rf3/model/layers/af3_auxiliary_heads.py +255 -0
  101. rf3/model/layers/af3_diffusion_transformer.py +544 -0
  102. rf3/model/layers/attention.py +313 -0
  103. rf3/model/layers/layer_utils.py +127 -0
  104. rf3/model/layers/mlff.py +118 -0
  105. rf3/model/layers/outer_product.py +59 -0
  106. rf3/model/layers/pairformer_layers.py +783 -0
  107. rf3/model/layers/structure_bias.py +56 -0
  108. rf3/scoring.py +1787 -0
  109. rf3/symmetry/resolve.py +284 -0
  110. rf3/train.py +194 -0
  111. rf3/trainers/rf3.py +570 -0
  112. rf3/util_module.py +47 -0
  113. rf3/utils/frames.py +109 -0
  114. rf3/utils/inference.py +665 -0
  115. rf3/utils/io.py +198 -0
  116. rf3/utils/loss.py +72 -0
  117. rf3/utils/predict_and_score.py +165 -0
  118. rf3/utils/predicted_error.py +673 -0
  119. rf3/utils/recycling.py +42 -0
  120. rf3/validate.py +140 -0
  121. rfd3/.gitignore +7 -0
  122. rfd3/Makefile +76 -0
  123. rfd3/__init__.py +12 -0
  124. rfd3/callbacks.py +66 -0
  125. rfd3/cli.py +41 -0
  126. rfd3/constants.py +212 -0
  127. rfd3/engine.py +543 -0
  128. rfd3/inference/datasets.py +193 -0
  129. rfd3/inference/input_parsing.py +1123 -0
  130. rfd3/inference/legacy_input_parsing.py +717 -0
  131. rfd3/inference/parsing.py +165 -0
  132. rfd3/inference/symmetry/atom_array.py +298 -0
  133. rfd3/inference/symmetry/checks.py +241 -0
  134. rfd3/inference/symmetry/contigs.py +63 -0
  135. rfd3/inference/symmetry/frames.py +355 -0
  136. rfd3/inference/symmetry/symmetry_utils.py +398 -0
  137. rfd3/metrics/design_metrics.py +465 -0
  138. rfd3/metrics/hbonds_hbplus_metrics.py +308 -0
  139. rfd3/metrics/hbonds_metrics.py +389 -0
  140. rfd3/metrics/losses.py +325 -0
  141. rfd3/metrics/metrics_utils.py +118 -0
  142. rfd3/metrics/sidechain_metrics.py +349 -0
  143. rfd3/model/RFD3.py +105 -0
  144. rfd3/model/RFD3_diffusion_module.py +387 -0
  145. rfd3/model/cfg_utils.py +81 -0
  146. rfd3/model/inference_sampler.py +635 -0
  147. rfd3/model/layers/attention.py +577 -0
  148. rfd3/model/layers/block_utils.py +580 -0
  149. rfd3/model/layers/blocks.py +777 -0
  150. rfd3/model/layers/chunked_pairwise.py +377 -0
  151. rfd3/model/layers/encoders.py +417 -0
  152. rfd3/model/layers/layer_utils.py +197 -0
  153. rfd3/model/layers/pairformer_layers.py +128 -0
  154. rfd3/run_inference.py +45 -0
  155. rfd3/testing/debug.py +139 -0
  156. rfd3/testing/debug_utils.py +73 -0
  157. rfd3/testing/testing_utils.py +356 -0
  158. rfd3/train.py +194 -0
  159. rfd3/trainer/dump_validation_structures.py +154 -0
  160. rfd3/trainer/fabric_trainer.py +923 -0
  161. rfd3/trainer/recycling.py +42 -0
  162. rfd3/trainer/rfd3.py +485 -0
  163. rfd3/trainer/trainer_utils.py +502 -0
  164. rfd3/transforms/conditioning_base.py +508 -0
  165. rfd3/transforms/conditioning_utils.py +200 -0
  166. rfd3/transforms/design_transforms.py +807 -0
  167. rfd3/transforms/dna_crop.py +523 -0
  168. rfd3/transforms/hbonds.py +407 -0
  169. rfd3/transforms/hbonds_hbplus.py +246 -0
  170. rfd3/transforms/ncaa_transforms.py +153 -0
  171. rfd3/transforms/pipelines.py +632 -0
  172. rfd3/transforms/ppi_transforms.py +541 -0
  173. rfd3/transforms/rasa.py +116 -0
  174. rfd3/transforms/symmetry.py +76 -0
  175. rfd3/transforms/training_conditions.py +552 -0
  176. rfd3/transforms/util_transforms.py +498 -0
  177. rfd3/transforms/virtual_atoms.py +305 -0
  178. rfd3/utils/inference.py +648 -0
  179. rfd3/utils/io.py +245 -0
  180. rfd3/utils/vizualize.py +276 -0
rfd3/model/inference_sampler.py
@@ -0,0 +1,635 @@
+ import inspect
+ from dataclasses import dataclass
+ from typing import Any, Literal
+
+ import torch
+ from jaxtyping import Float
+
+ from foundry.common import exists
+ from foundry.utils.ddp import RankedLogger
+ from foundry.utils.rotation_augmentation import (
+     rot_vec_mul,
+     uniform_random_rotation,
+ )
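+
+ # NOTE: `weighted_rigid_align`, `strip_X`, and `apply_symmetry_to_xyz_atomwise` are
+ # referenced below but not imported above; they are assumed to be provided by the
+ # package's alignment / symmetry / conditioning utilities.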
+
+ ranked_logger = RankedLogger(__name__, rank_zero_only=True)
+
+
+ @dataclass(kw_only=True)
+ class SampleDiffusionConfig:
+     kind: Literal["default", "symmetry"] = "default"
+
+     # Standard EDM args
+     num_timesteps: int = 200
+     min_t: int = 0
+     max_t: int = 1
+     sigma_data: int = 16
+     s_min: float = 4e-4
+     s_max: int = 160
+     p: int = 7
+     gamma_0: float = 0.6
+     gamma_min: float = 1.0
+     noise_scale: float = 1.003
+     step_scale: float = 1.5
+     solver: Literal["af3"] = "af3"
+
+     # RFD3 / design args
+     center_option: str = "all"
+     s_trans: float = 1.0
+     s_jitter_origin: float = 0.0
+     fraction_of_steps_to_fix_motif: float = 0.0
+     skip_few_diffusion_steps: bool = False
+     allow_realignment: bool = False
+     insert_motif_at_end: bool = True
+     use_classifier_free_guidance: bool = False
+     cfg_scale: float = 2.0
+     cfg_t_max: float | None = None
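+
+     # Worked example with the defaults above: the inference noise schedule built in
+     # `_construct_inference_noise_schedule` runs from
+     #     sigma_data * s_max = 16 * 160  = 2560.0  at t = 0 (most noise)
+     # down to
+     #     sigma_data * s_min = 16 * 4e-4 = 0.0064  at t = 1 (least noise)
+     # over `num_timesteps` linearly spaced values of t.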
+
+
+ class SampleDiffusionWithMotif(SampleDiffusionConfig):
+     """Diffusion sampler that supports optional motif alignment."""
+
+     def _construct_inference_noise_schedule(
+         self, device: torch.device, partial_t: torch.Tensor | None = None
+     ) -> torch.Tensor:
+         """Constructs a noise schedule for use during inference.
+
+         The inference noise schedule is defined in the AF-3 supplement as:
+
+             t_hat = sigma_data * (s_max**(1/p) + t * (s_min**(1/p) - s_max**(1/p)))**p
+
+         Returns:
+             torch.Tensor: The noise schedule `t_hat`, truncated when `partial_t` is given.
+
+         Reference:
+             AlphaFold 3 Supplement, Section 3.7.1.
+         """
+         # Create a linearly spaced tensor of timesteps between min_t and max_t
+         t = torch.linspace(self.min_t, self.max_t, self.num_timesteps, device=device)
+
+         # Construct the noise schedule, using the formula provided in the reference
+         t_hat = (
+             self.sigma_data
+             * (
+                 (self.s_max) ** (1 / self.p)
+                 + t * (self.s_min ** (1 / self.p) - self.s_max ** (1 / self.p))
+             )
+             ** self.p
+         )
+         noise_schedule = t_hat
+
+         if partial_t is not None:
+             # For now, partial t is a global parameter
+             partial_t = float(partial_t.mean())
+             ranked_logger.info("Using partial diffusion with t={}".format(partial_t))
+
+             # Debug the noise schedule filtering
+             original_schedule_len = len(noise_schedule)
+             original_max = noise_schedule.max().item()
+             original_min = noise_schedule.min().item()
+
+             noise_schedule = noise_schedule[noise_schedule <= partial_t]
+
+             new_schedule_len = len(noise_schedule)
+             if new_schedule_len > 0:
+                 new_max = noise_schedule.max().item()
+                 new_min = noise_schedule.min().item()
+                 ranked_logger.info(
+                     f"Noise schedule: {original_schedule_len} → {new_schedule_len} steps"
+                 )
+                 ranked_logger.info(
+                     f"Original range: [{original_min:.3f}, {original_max:.3f}]"
+                 )
+                 ranked_logger.info(f"Filtered range: [{new_min:.3f}, {new_max:.3f}]")
+             else:
+                 ranked_logger.warning(
+                     f"No noise schedule steps found with t <= {partial_t}!"
+                 )
+                 ranked_logger.info(
+                     f"Original schedule range: [{original_min:.3f}, {original_max:.3f}]"
+                 )
+                 # Fallback to smallest available step
+                 noise_schedule_original = self._construct_inference_noise_schedule(
+                     device=device
+                 )
+                 noise_schedule = noise_schedule_original[-1:]  # Just use the final step
+                 ranked_logger.info(
+                     f"Using fallback: final step with t={noise_schedule[0].item():.6f}"
+                 )
+
+         return noise_schedule
+
+     def _get_initial_structure(
+         self,
+         c0: torch.Tensor,
+         D: int,
+         L: int,
+         coord_atom_lvl_to_be_noised: torch.Tensor,
+         is_motif_atom_with_fixed_coord,
+     ) -> torch.Tensor:
+         noise = c0 * torch.normal(mean=0.0, std=1.0, size=(D, L, 3), device=c0.device)
+         noise[..., is_motif_atom_with_fixed_coord, :] = 0  # Zero out noise going in
+         X_L = noise + coord_atom_lvl_to_be_noised
+         return X_L
+
+     def sample_diffusion_like_af3(
+         self,
+         *,
+         f: dict[str, Any],
+         diffusion_module: torch.nn.Module,
+         diffusion_batch_size: int,
+         coord_atom_lvl_to_be_noised: Float[torch.Tensor, "D L 3"],
+         initializer_outputs,
+         ref_initializer_outputs: dict[str, Any] | None,
+         f_ref: dict[str, Any] | None,
+     ) -> dict[str, Any]:
+         # Motif setup to recenter the motif at every step
+         is_motif_atom_with_fixed_coord = f["is_motif_atom_with_fixed_coord"]
+
+         # Book-keeping
+         noise_schedule = self._construct_inference_noise_schedule(
+             device=coord_atom_lvl_to_be_noised.device,
+             partial_t=f.get("partial_t", None),
+         )
+
+         L = f["ref_element"].shape[0]
+         D = diffusion_batch_size
+
+         X_L = self._get_initial_structure(
+             c0=noise_schedule[0],
+             D=D,
+             L=L,
+             coord_atom_lvl_to_be_noised=coord_atom_lvl_to_be_noised.clone(),
+             is_motif_atom_with_fixed_coord=is_motif_atom_with_fixed_coord,
+         )  # (D, L, 3)
+
+         if self.s_jitter_origin > 0.0:
+             X_L[:, is_motif_atom_with_fixed_coord, :] += torch.normal(
+                 mean=0.0,
+                 std=self.s_jitter_origin,
+                 size=(D, 1, 3),
+                 device=X_L.device,
+             )
+
+         X_noisy_L_traj = []
+         X_denoised_L_traj = []
+         sequence_entropy_traj = []
+         t_hats = []
+
+         threshold_step = (len(noise_schedule) - 1) * self.fraction_of_steps_to_fix_motif
+
+         for step_num, (c_t_minus_1, c_t) in enumerate(
+             zip(noise_schedule, noise_schedule[1:])
+         ):
+             # Assert no grads on X_L
+             assert not torch.is_grad_enabled(), "Computation graph should not be active"
+             assert not X_L.requires_grad, "X_L should not require gradients"
+
+             # Apply a random rotation and translation to the structure
+             if self.allow_realignment:
+                 X_L, _ = centre_random_augment_around_motif(
+                     X_L,
+                     coord_atom_lvl_to_be_noised,
+                     is_motif_atom_with_fixed_coord,
+                     center_option=self.center_option,
+                     # If centering_affects_motif is True, the model's predictions from (step_num-1) might affect the motif
+                     centering_affects_motif=(max(step_num - 1, 0)) >= threshold_step,
+                     # If keeping the motif position wrt the origin fixed, we can't do translational augmentation
+                     # We want to keep this position fixed in the interval where the model is not allowed to change it
+                     s_trans=self.s_trans if step_num >= threshold_step else 0.0,
+                 )
+
+             # Update gamma & step scale
+             gamma = self.gamma_0 if c_t > self.gamma_min else 0
+             step_scale = self.step_scale
+
+             # Compute the value of t_hat
+             t_hat = c_t_minus_1 * (gamma + 1)
+
+             # Noise the coordinates with scaled Gaussian noise
+             epsilon_L = (
+                 self.noise_scale
+                 * torch.sqrt(torch.square(t_hat) - torch.square(c_t_minus_1))
+                 * torch.normal(mean=0.0, std=1.0, size=X_L.shape, device=X_L.device)
+             )
+             epsilon_L[..., is_motif_atom_with_fixed_coord, :] = (
+                 0  # No noise injection for fixed atoms
+             )
+             X_noisy_L = X_L + epsilon_L
+
+             # Denoise the coordinates
+             # Handle chunked mode vs standard mode
+             if "chunked_pairwise_embedder" in initializer_outputs:
+                 # Chunked mode: explicitly provide P_LL=None
+                 chunked_embedder = initializer_outputs[
+                     "chunked_pairwise_embedder"
+                 ]  # Don't pop, just get
+                 other_outputs = {
+                     k: v
+                     for k, v in initializer_outputs.items()
+                     if k != "chunked_pairwise_embedder"
+                 }
+                 outs = diffusion_module(
+                     X_noisy_L=X_noisy_L,
+                     t=t_hat.tile(D),
+                     f=f,
+                     P_LL=None,  # Not used in chunked mode
+                     chunked_pairwise_embedder=chunked_embedder,
+                     initializer_outputs=other_outputs,
+                     **other_outputs,
+                 )
+             else:
+                 # Standard mode: P_LL is included in initializer_outputs
+                 outs = diffusion_module(
+                     X_noisy_L=X_noisy_L,
+                     t=t_hat.tile(D),
+                     f=f,
+                     **initializer_outputs,
+                 )
+
+             X_denoised_L = outs["X_L"] if "X_L" in outs else outs
+
+             # Compute the delta between the noisy and denoised coordinates, scaled by t_hat
+             delta_L = (
+                 X_noisy_L - X_denoised_L
+             ) / t_hat  # gradient of x wrt. t at x_t_hat
+             d_t = c_t - t_hat
+
+             if self.use_classifier_free_guidance and (
+                 self.cfg_t_max is None or c_t > self.cfg_t_max
+             ):
+                 X_noisy_L_stripped = strip_X(X_noisy_L, f_ref)
+
+                 # unconditional forward pass
+                 outs_ref = diffusion_module(
+                     X_noisy_L=X_noisy_L_stripped,  # modified X
+                     t=t_hat.tile(D),
+                     f=f_ref,  # modified f
+                     **ref_initializer_outputs,
+                 )
+
+                 X_denoised_L_stripped = outs_ref["X_L"]
+
+                 delta_L_ref = (
+                     X_noisy_L_stripped - X_denoised_L_stripped
+                 ) / t_hat  # gradient of x wrt. t at x_t_hat
+
+                 # pad delta_L_ref with zeros to match delta_L (for the unindexed atoms)
+                 if delta_L_ref.shape[1] < delta_L.shape[1]:
+                     delta_L_ref = torch.cat(
+                         [
+                             delta_L_ref,
+                             torch.zeros_like(delta_L[:, delta_L_ref.shape[1] :, :]),
+                         ],
+                         dim=1,
+                     )
+
+                 # apply CFG
+                 delta_L = delta_L + (self.cfg_scale - 1) * (delta_L - delta_L_ref)
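+                 # Equivalently: delta_L = cfg_scale * delta_cond - (cfg_scale - 1) * delta_uncond,
+                 # i.e. the conditional update direction is extrapolated away from the unconditional one.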
+
+             if exists(outs.get("sequence_logits_I")):
+                 # Compute confidence
+                 p = torch.softmax(
+                     outs["sequence_logits_I"], dim=-1
+                 ).cpu()  # shape (D, L, 32)
+                 seq_entropy = -torch.sum(
+                     p * torch.log(p + 1e-10), dim=-1
+                 )  # shape (D, L,)
+                 sequence_entropy_traj.append(seq_entropy)
+
+             # Update the coordinates, scaled by the step size
+             X_L = X_noisy_L + step_scale * d_t * delta_L
+
+             # Append the results to the trajectory (for visualization of the diffusion process)
+             X_noisy_L_scaled = (
+                 self.sigma_data * X_noisy_L / torch.sqrt(t_hat**2 + self.sigma_data**2)
+             )  # Save noisy traj as scaled inputs
+             X_noisy_L_traj.append(X_noisy_L_scaled)
+             X_denoised_L_traj.append(X_denoised_L)
+             t_hats.append(t_hat)
+
+         if torch.any(is_motif_atom_with_fixed_coord) and self.allow_realignment:
+             # Insert the gt motif at the end
+             X_L, _ = centre_random_augment_around_motif(
+                 X_L,
+                 coord_atom_lvl_to_be_noised,
+                 is_motif_atom_with_fixed_coord,
+                 reinsert_motif=self.insert_motif_at_end,
+             )
+
+             # Align prediction to original motif
+             X_L = weighted_rigid_align(
+                 coord_atom_lvl_to_be_noised,
+                 X_L,
+                 X_exists_L=is_motif_atom_with_fixed_coord,
+             )
+
+         return dict(
+             X_L=X_L,  # (D, L, 3)
+             X_noisy_L_traj=X_noisy_L_traj,  # list[Tensor[D, L, 3]]
+             X_denoised_L_traj=X_denoised_L_traj,  # list[Tensor[D, L, 3]]
+             t_hats=t_hats,  # list[Tensor[D]], where D is shared across all diffusion batches
+             sequence_logits_I=outs.get("sequence_logits_I"),  # (D, I, 32)
+             sequence_indices_I=outs.get("sequence_indices_I"),  # (D, I, 32)
+             sequence_entropy_traj=sequence_entropy_traj,  # list[Tensor[D, I]]
+         )
+
+
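+ # Both the sampler above and the symmetry variant below use the same per-step
+ # EDM-style update (AF-3 sampler):
+ #     t_hat    = c_[t-1] * (gamma + 1)
+ #     X_noisy  = X + noise_scale * sqrt(t_hat**2 - c_[t-1]**2) * N(0, I)
+ #     X_hat    = diffusion_module(X_noisy, t_hat, ...)
+ #     X       <- X_noisy + step_scale * (c_t - t_hat) * (X_noisy - X_hat) / t_hat
+ # The symmetry variant additionally re-symmetrizes X_hat for the first
+ # `sym_step_frac` fraction of steps and skips classifier-free guidance.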
+ class SampleDiffusionWithSymmetry(SampleDiffusionWithMotif):
+     """
+     Wrapper around SampleDiffusionWithMotif that applies symmetry constraints
+     to the denoised coordinates during sampling.
+     """
+
+     def __init__(self, sym_step_frac: float = 0.9, **kwargs):
+         assert (
+             kwargs.get("gamma_0", 0.6) > 0.5
+         ), "gamma_0 must be greater than 0.5 for symmetry sampling"
+         self.sym_step_frac = sym_step_frac
+         super().__init__(**kwargs)
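+         # With the default sym_step_frac of 0.9, the denoised coordinates are
+         # re-symmetrized for roughly the first 90% of the schedule (while
+         # c_t > noise_schedule[int(0.9 * len(noise_schedule))]); the remaining
+         # steps are left unconstrained.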
+
+     def apply_symmetry_to_X_L(self, X_L, f):
+         # check that we are doing symmetric inference
+         assert "sym_transform" in f.keys(), "Symmetry transform not found in f"
+
+         # update symmetric frames to correct for change in global frame
+         symmetry_feats = {k: v for k, v in f.items() if "sym" in k}
+
+         # apply symmetry frame shift to X_L
+         X_L = apply_symmetry_to_xyz_atomwise(
+             X_L, symmetry_feats, partial_diffusion=("partial_t" in f)
+         )
+
+         return X_L
+
+     def sample_diffusion_like_af3(
+         self,
+         *,
+         f: dict[str, Any],
+         diffusion_module: torch.nn.Module,
+         diffusion_batch_size: int,
+         coord_atom_lvl_to_be_noised: Float[torch.Tensor, "D L 3"],
+         initializer_outputs,
+         ref_initializer_outputs: dict[str, Any] | None,
+         f_ref: dict[str, Any] | None,
+         **_,
+     ) -> dict[str, Any]:
+         # Motif setup to recenter the motif at every step
+         is_motif_atom_with_fixed_coord = f["is_motif_atom_with_fixed_coord"]
+         # Book-keeping
+         noise_schedule = self._construct_inference_noise_schedule(
+             device=coord_atom_lvl_to_be_noised.device,
+             partial_t=f.get("partial_t", None),
+         )
+
+         L = f["ref_element"].shape[0]
+         D = diffusion_batch_size
+         X_L = self._get_initial_structure(
+             c0=noise_schedule[0],
+             D=D,
+             L=L,
+             coord_atom_lvl_to_be_noised=coord_atom_lvl_to_be_noised.clone(),
+             is_motif_atom_with_fixed_coord=is_motif_atom_with_fixed_coord,
+         )  # (D, L, 3)
+
+         X_noisy_L_traj = []
+         X_denoised_L_traj = []
+         sequence_entropy_traj = []
+         t_hats = []
+
+         # symmetrize X_L until the step gamma = gamma_min_sym
+         gamma_min_sym_idx = min(
+             int(len(noise_schedule) * self.sym_step_frac), len(noise_schedule) - 1
+         )
+         gamma_min_sym = noise_schedule[gamma_min_sym_idx]
+
+         ranked_logger.info(f"gamma_min_sym: {gamma_min_sym}")
+         ranked_logger.info(f"gamma_min: {self.gamma_min}")
+         for step_num, (c_t_minus_1, c_t) in enumerate(
+             zip(noise_schedule, noise_schedule[1:])
+         ):
+             # Assert no grads on X_L
+             assert not torch.is_grad_enabled(), "Computation graph should not be active"
+             assert not X_L.requires_grad, "X_L should not require gradients"
+
+             # Apply a random rotation and translation to the structure
+             if self.allow_realignment:
+                 X_L, R = centre_random_augment_around_motif(
+                     X_L,
+                     coord_atom_lvl_to_be_noised,
+                     is_motif_atom_with_fixed_coord,
+                 )
+
+             # Update gamma & step scale
+             gamma = self.gamma_0 if c_t > self.gamma_min else 0
+             step_scale = self.step_scale
+
+             # Compute the value of t_hat
+             t_hat = c_t_minus_1 * (gamma + 1)
+
+             # Noise the coordinates with scaled Gaussian noise
+             epsilon_L = (
+                 self.noise_scale
+                 * torch.sqrt(torch.square(t_hat) - torch.square(c_t_minus_1))
+                 * torch.normal(mean=0.0, std=1.0, size=X_L.shape, device=X_L.device)
+             )
+             epsilon_L[..., is_motif_atom_with_fixed_coord, :] = (
+                 0  # No noise injection for fixed atoms
+             )
+
+             # NOTE: no symmetry applied to the noisy structure
+             X_noisy_L = X_L + epsilon_L
+
+             # Denoise the coordinates
+             # Handle chunked mode vs standard mode (same as default sampler)
+             if "chunked_pairwise_embedder" in initializer_outputs:
+                 # Chunked mode: explicitly provide P_LL=None
+                 chunked_embedder = initializer_outputs[
+                     "chunked_pairwise_embedder"
+                 ]  # Don't pop, just get
+                 other_outputs = {
+                     k: v
+                     for k, v in initializer_outputs.items()
+                     if k != "chunked_pairwise_embedder"
+                 }
+                 outs = diffusion_module(
+                     X_noisy_L=X_noisy_L,
+                     t=t_hat.tile(D),
+                     f=f,
+                     P_LL=None,  # Not used in chunked mode
+                     chunked_pairwise_embedder=chunked_embedder,
+                     initializer_outputs=other_outputs,
+                     **other_outputs,
+                 )
+             else:
+                 # Standard mode: P_LL is included in initializer_outputs
+                 outs = diffusion_module(
+                     X_noisy_L=X_noisy_L,
+                     t=t_hat.tile(D),
+                     f=f,
+                     **initializer_outputs,
+                 )
+             # apply symmetry to X_denoised_L
+             if "X_L" in outs and c_t > gamma_min_sym:
+                 # outs["original_X_L"] = outs["X_L"].clone()
+                 outs["X_L"] = self.apply_symmetry_to_X_L(outs["X_L"], f)
+
+             X_denoised_L = outs["X_L"] if "X_L" in outs else outs
+
+             # Compute the delta between the noisy and denoised coordinates, scaled by t_hat
+             delta_L = (
+                 X_noisy_L - X_denoised_L
+             ) / t_hat  # gradient of x wrt. t at x_t_hat
+             d_t = c_t - t_hat
+
+             # NOTE: no classifier-free guidance for symmetry
+
+             if exists(outs.get("sequence_logits_I")):
+                 # Compute confidence
+                 p = torch.softmax(
+                     outs["sequence_logits_I"], dim=-1
+                 ).cpu()  # shape (D, L, 32)
+                 seq_entropy = -torch.sum(
+                     p * torch.log(p + 1e-10), dim=-1
+                 )  # shape (D, L,)
+                 sequence_entropy_traj.append(seq_entropy)
+
+             # Update the coordinates, scaled by the step size
+             # delta_L should be symmetric
+             X_L = X_noisy_L + step_scale * d_t * delta_L
+
+             # Append the results to the trajectory (for visualization of the diffusion process)
+             X_noisy_L_scaled = (
+                 self.sigma_data * X_noisy_L / torch.sqrt(t_hat**2 + self.sigma_data**2)
+             )  # Save noisy traj as scaled inputs
+             X_noisy_L_traj.append(X_noisy_L_scaled)
+             X_denoised_L_traj.append(X_denoised_L)
+             t_hats.append(t_hat)
+
+         if torch.any(is_motif_atom_with_fixed_coord) and self.allow_realignment:
+             # Insert the gt motif at the end
+             X_L, R = centre_random_augment_around_motif(
+                 X_L,
+                 coord_atom_lvl_to_be_noised,
+                 is_motif_atom_with_fixed_coord,
+                 reinsert_motif=self.insert_motif_at_end,
+             )
+
+             # apply symmetry frame shift to X_L
+             X_L = self.apply_symmetry_to_X_L(X_L, f)
+
+             # Align prediction to original motif
+             X_L = weighted_rigid_align(
+                 coord_atom_lvl_to_be_noised,
+                 X_L,
+                 X_exists_L=is_motif_atom_with_fixed_coord,
+             )
+
+         return dict(
+             X_L=X_L,  # (D, L, 3)
+             X_noisy_L_traj=X_noisy_L_traj,  # list[Tensor[D, L, 3]]
+             X_denoised_L_traj=X_denoised_L_traj,  # list[Tensor[D, L, 3]]
+             t_hats=t_hats,  # list[Tensor[D]], where D is shared across all diffusion batches
+             sequence_logits_I=outs.get("sequence_logits_I"),  # (D, I, 32)
+             sequence_indices_I=outs.get("sequence_indices_I"),  # (D, I, 32)
+             sequence_entropy_traj=sequence_entropy_traj,  # list[Tensor[D, I]]
+         )
+
+
+ class ConditionalDiffusionSampler:
+     """
+     Conditional diffusion sampler: chooses at construction time which sampler to use,
+     then forwards `sample_diffusion_like_af3` to the chosen sampler.
+     If you write a new sampler, add it to the registry below and select it via
+     `inference_sampler.kind` in the inference engine config.
+     """
+
+     _registry = {
+         "default": SampleDiffusionWithMotif,
+         "symmetry": SampleDiffusionWithSymmetry,
+     }
+
+     def __init__(self, kind="default", **kwargs):
+         ranked_logger.info(
+             f"Initializing ConditionalDiffusionSampler with kind: {kind}"
+         )
+         try:
+             SamplerCls = self._registry[kind]
+             # remove kwargs that the sampler cannot take
+             init_args = self.get_class_init_args(SamplerCls)
+             kwargs = {k: v for k, v in kwargs.items() if k in init_args}
+         except KeyError:
+             raise ValueError(
+                 f"Invalid sampler kind: {kind}, must be one of {list(self._registry.keys())}"
+             )
+         self.sampler = SamplerCls(**kwargs)
+
+     def sample_diffusion_like_af3(self, **kwargs):
+         return self.sampler.sample_diffusion_like_af3(**kwargs)
+
+     def get_class_init_args(self, cls):
+         arg_names = []
+         if hasattr(cls, "__init__") and callable(cls.__init__):
+             for p_cls in cls.__mro__:
+                 if "__init__" in p_cls.__dict__ and p_cls is not object:
+                     signature = inspect.signature(p_cls.__init__)
+                     arg_names.extend(
+                         [param.name for param in signature.parameters.values()]
+                     )
+         return arg_names
+
+
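+ # Example of how the registry above is used (hypothetical config values):
+ # ConditionalDiffusionSampler(kind="symmetry", sym_step_frac=0.9, gamma_0=0.6,
+ # step_scale=1.5, some_unrelated_key=...) picks SampleDiffusionWithSymmetry from
+ # `_registry`, drops any kwargs that its `__init__` chain does not accept (via
+ # `get_class_init_args`), instantiates it, and forwards every
+ # `sample_diffusion_like_af3(...)` call to that instance.
+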
+ def centre_random_augment_around_motif(
+     X_L: torch.Tensor,  # (D, L, 3) noisy diffused coordinates
+     coord_atom_lvl_to_be_noised: torch.Tensor,  # (D, L, 3) original coordinates
+     is_motif_atom_with_fixed_coord: torch.Tensor,  # (L,) boolean mask of atoms whose coordinates are kept fixed
+     s_trans: float = 1.0,
+     center_option: str = "all",
+     centering_affects_motif: bool = True,
+     reinsert_motif=True,
+ ):
+     D, L, _ = X_L.shape
+
+     if reinsert_motif and torch.any(is_motif_atom_with_fixed_coord):
+         # ... Align original coordinates to the prediction
+         coords_with_gt_aligned = weighted_rigid_align(
+             X_L[..., is_motif_atom_with_fixed_coord, :],
+             coord_atom_lvl_to_be_noised[..., is_motif_atom_with_fixed_coord, :],
+         )
+
+         # ... Insert original coordinates into X_L
+         X_L[..., is_motif_atom_with_fixed_coord, :] = coords_with_gt_aligned
+
+     # ... Centering
+     if torch.any(is_motif_atom_with_fixed_coord):
+         if center_option == "motif":
+             center = torch.mean(
+                 X_L[..., is_motif_atom_with_fixed_coord, :], dim=-2, keepdim=True
+             )  # (D, 1, 3) - COM of motif atoms
+         elif center_option == "diffuse":
+             center = torch.mean(
+                 X_L[..., ~is_motif_atom_with_fixed_coord, :], dim=-2, keepdim=True
+             )  # (D, 1, 3) - COM of diffused atoms
+         else:
+             center = torch.mean(X_L, dim=-2, keepdim=True)
+     else:
+         center = torch.mean(X_L, dim=-2, keepdim=True)
+
+     # ... Center
+     if centering_affects_motif:
+         X_L = X_L - center
+     else:
+         X_L[..., ~is_motif_atom_with_fixed_coord, :] = (
+             X_L[..., ~is_motif_atom_with_fixed_coord, :] - center
+         )
+
+     # ... Random augmentation
+     R = uniform_random_rotation((D,)).to(X_L.device)
+     noise = (
+         torch.normal(mean=0, std=1, size=(D, 1, 3), device=X_L.device) * s_trans
+     )  # (D, 1, 3)
+     X_L = rot_vec_mul(R[:, None], X_L) + noise
+
+     return X_L, R