rc_foundry-0.1.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (180)
  1. foundry/__init__.py +57 -0
  2. foundry/callbacks/__init__.py +5 -0
  3. foundry/callbacks/callback.py +116 -0
  4. foundry/callbacks/health_logging.py +419 -0
  5. foundry/callbacks/metrics_logging.py +211 -0
  6. foundry/callbacks/timing_logging.py +67 -0
  7. foundry/callbacks/train_logging.py +278 -0
  8. foundry/common.py +108 -0
  9. foundry/constants.py +28 -0
  10. foundry/hydra/resolvers.py +77 -0
  11. foundry/inference_engines/base.py +235 -0
  12. foundry/inference_engines/checkpoint_registry.py +66 -0
  13. foundry/metrics/__init__.py +12 -0
  14. foundry/metrics/losses.py +30 -0
  15. foundry/metrics/metric.py +319 -0
  16. foundry/model/layers/blocks.py +47 -0
  17. foundry/testing/__init__.py +6 -0
  18. foundry/testing/fixtures.py +19 -0
  19. foundry/testing/pytest_hooks.py +15 -0
  20. foundry/trainers/fabric.py +923 -0
  21. foundry/training/EMA.py +67 -0
  22. foundry/training/checkpoint.py +61 -0
  23. foundry/training/schedulers.py +91 -0
  24. foundry/utils/alignment.py +86 -0
  25. foundry/utils/components.py +415 -0
  26. foundry/utils/datasets.py +405 -0
  27. foundry/utils/ddp.py +103 -0
  28. foundry/utils/instantiators.py +72 -0
  29. foundry/utils/logging.py +279 -0
  30. foundry/utils/rigid.py +1460 -0
  31. foundry/utils/rotation_augmentation.py +65 -0
  32. foundry/utils/squashfs.py +172 -0
  33. foundry/utils/torch.py +317 -0
  34. foundry/utils/weights.py +271 -0
  35. foundry/version.py +34 -0
  36. foundry_cli/__init__.py +3 -0
  37. foundry_cli/download_checkpoints.py +281 -0
  38. mpnn/__init__.py +1 -0
  39. mpnn/collate/feature_collator.py +265 -0
  40. mpnn/inference.py +53 -0
  41. mpnn/inference_engines/mpnn.py +549 -0
  42. mpnn/loss/nll_loss.py +122 -0
  43. mpnn/metrics/nll.py +369 -0
  44. mpnn/metrics/sequence_recovery.py +440 -0
  45. mpnn/model/layers/graph_embeddings.py +2372 -0
  46. mpnn/model/layers/message_passing.py +332 -0
  47. mpnn/model/layers/position_wise_feed_forward.py +44 -0
  48. mpnn/model/layers/positional_encoding.py +98 -0
  49. mpnn/model/mpnn.py +2632 -0
  50. mpnn/pipelines/mpnn.py +162 -0
  51. mpnn/samplers/samplers.py +167 -0
  52. mpnn/train.py +341 -0
  53. mpnn/trainers/mpnn.py +193 -0
  54. mpnn/transforms/feature_aggregation/mpnn.py +184 -0
  55. mpnn/transforms/feature_aggregation/polymer_ligand_interface.py +76 -0
  56. mpnn/transforms/feature_aggregation/token_encodings.py +132 -0
  57. mpnn/transforms/feature_aggregation/user_settings.py +347 -0
  58. mpnn/transforms/polymer_ligand_interface.py +164 -0
  59. mpnn/utils/inference.py +2397 -0
  60. mpnn/utils/probability.py +37 -0
  61. mpnn/utils/weights.py +309 -0
  62. rc_foundry-0.1.1.dist-info/METADATA +239 -0
  63. rc_foundry-0.1.1.dist-info/RECORD +180 -0
  64. rc_foundry-0.1.1.dist-info/WHEEL +4 -0
  65. rc_foundry-0.1.1.dist-info/entry_points.txt +5 -0
  66. rc_foundry-0.1.1.dist-info/licenses/LICENSE.md +28 -0
  67. rf3/__init__.py +3 -0
  68. rf3/_version.py +33 -0
  69. rf3/alignment.py +79 -0
  70. rf3/callbacks/dump_validation_structures.py +101 -0
  71. rf3/callbacks/metrics_logging.py +324 -0
  72. rf3/chemical.py +1529 -0
  73. rf3/cli.py +77 -0
  74. rf3/data/cyclic_transform.py +78 -0
  75. rf3/data/extra_xforms.py +36 -0
  76. rf3/data/ground_truth_template.py +463 -0
  77. rf3/data/paired_msa.py +206 -0
  78. rf3/data/pipeline_utils.py +128 -0
  79. rf3/data/pipelines.py +558 -0
  80. rf3/diffusion_samplers/inference_sampler.py +222 -0
  81. rf3/inference.py +65 -0
  82. rf3/inference_engines/__init__.py +5 -0
  83. rf3/inference_engines/rf3.py +735 -0
  84. rf3/kinematics.py +354 -0
  85. rf3/loss/af3_confidence_loss.py +515 -0
  86. rf3/loss/af3_losses.py +655 -0
  87. rf3/loss/loss.py +179 -0
  88. rf3/metrics/chiral.py +179 -0
  89. rf3/metrics/clashing_chains.py +68 -0
  90. rf3/metrics/distogram.py +421 -0
  91. rf3/metrics/lddt.py +523 -0
  92. rf3/metrics/metadata.py +43 -0
  93. rf3/metrics/metric_utils.py +192 -0
  94. rf3/metrics/predicted_error.py +134 -0
  95. rf3/metrics/rasa.py +108 -0
  96. rf3/metrics/selected_distances.py +91 -0
  97. rf3/model/RF3.py +527 -0
  98. rf3/model/RF3_blocks.py +92 -0
  99. rf3/model/RF3_structure.py +303 -0
  100. rf3/model/layers/af3_auxiliary_heads.py +255 -0
  101. rf3/model/layers/af3_diffusion_transformer.py +544 -0
  102. rf3/model/layers/attention.py +313 -0
  103. rf3/model/layers/layer_utils.py +127 -0
  104. rf3/model/layers/mlff.py +118 -0
  105. rf3/model/layers/outer_product.py +59 -0
  106. rf3/model/layers/pairformer_layers.py +783 -0
  107. rf3/model/layers/structure_bias.py +56 -0
  108. rf3/scoring.py +1787 -0
  109. rf3/symmetry/resolve.py +284 -0
  110. rf3/train.py +194 -0
  111. rf3/trainers/rf3.py +570 -0
  112. rf3/util_module.py +47 -0
  113. rf3/utils/frames.py +109 -0
  114. rf3/utils/inference.py +665 -0
  115. rf3/utils/io.py +198 -0
  116. rf3/utils/loss.py +72 -0
  117. rf3/utils/predict_and_score.py +165 -0
  118. rf3/utils/predicted_error.py +673 -0
  119. rf3/utils/recycling.py +42 -0
  120. rf3/validate.py +140 -0
  121. rfd3/.gitignore +7 -0
  122. rfd3/Makefile +76 -0
  123. rfd3/__init__.py +12 -0
  124. rfd3/callbacks.py +66 -0
  125. rfd3/cli.py +41 -0
  126. rfd3/constants.py +212 -0
  127. rfd3/engine.py +543 -0
  128. rfd3/inference/datasets.py +193 -0
  129. rfd3/inference/input_parsing.py +1123 -0
  130. rfd3/inference/legacy_input_parsing.py +717 -0
  131. rfd3/inference/parsing.py +165 -0
  132. rfd3/inference/symmetry/atom_array.py +298 -0
  133. rfd3/inference/symmetry/checks.py +241 -0
  134. rfd3/inference/symmetry/contigs.py +63 -0
  135. rfd3/inference/symmetry/frames.py +355 -0
  136. rfd3/inference/symmetry/symmetry_utils.py +398 -0
  137. rfd3/metrics/design_metrics.py +465 -0
  138. rfd3/metrics/hbonds_hbplus_metrics.py +308 -0
  139. rfd3/metrics/hbonds_metrics.py +389 -0
  140. rfd3/metrics/losses.py +325 -0
  141. rfd3/metrics/metrics_utils.py +118 -0
  142. rfd3/metrics/sidechain_metrics.py +349 -0
  143. rfd3/model/RFD3.py +105 -0
  144. rfd3/model/RFD3_diffusion_module.py +387 -0
  145. rfd3/model/cfg_utils.py +81 -0
  146. rfd3/model/inference_sampler.py +635 -0
  147. rfd3/model/layers/attention.py +577 -0
  148. rfd3/model/layers/block_utils.py +580 -0
  149. rfd3/model/layers/blocks.py +777 -0
  150. rfd3/model/layers/chunked_pairwise.py +377 -0
  151. rfd3/model/layers/encoders.py +417 -0
  152. rfd3/model/layers/layer_utils.py +197 -0
  153. rfd3/model/layers/pairformer_layers.py +128 -0
  154. rfd3/run_inference.py +45 -0
  155. rfd3/testing/debug.py +139 -0
  156. rfd3/testing/debug_utils.py +73 -0
  157. rfd3/testing/testing_utils.py +356 -0
  158. rfd3/train.py +194 -0
  159. rfd3/trainer/dump_validation_structures.py +154 -0
  160. rfd3/trainer/fabric_trainer.py +923 -0
  161. rfd3/trainer/recycling.py +42 -0
  162. rfd3/trainer/rfd3.py +485 -0
  163. rfd3/trainer/trainer_utils.py +502 -0
  164. rfd3/transforms/conditioning_base.py +508 -0
  165. rfd3/transforms/conditioning_utils.py +200 -0
  166. rfd3/transforms/design_transforms.py +807 -0
  167. rfd3/transforms/dna_crop.py +523 -0
  168. rfd3/transforms/hbonds.py +407 -0
  169. rfd3/transforms/hbonds_hbplus.py +246 -0
  170. rfd3/transforms/ncaa_transforms.py +153 -0
  171. rfd3/transforms/pipelines.py +632 -0
  172. rfd3/transforms/ppi_transforms.py +541 -0
  173. rfd3/transforms/rasa.py +116 -0
  174. rfd3/transforms/symmetry.py +76 -0
  175. rfd3/transforms/training_conditions.py +552 -0
  176. rfd3/transforms/util_transforms.py +498 -0
  177. rfd3/transforms/virtual_atoms.py +305 -0
  178. rfd3/utils/inference.py +648 -0
  179. rfd3/utils/io.py +245 -0
  180. rfd3/utils/vizualize.py +276 -0
rfd3/model/RFD3_diffusion_module.py
@@ -0,0 +1,387 @@
+import functools
+import logging
+import os
+from contextlib import ExitStack
+
+import torch
+import torch.nn as nn
+from rfd3.model.layers.block_utils import (
+    bucketize_scaled_distogram,
+    create_attention_indices,
+)
+from rfd3.model.layers.blocks import (
+    CompactStreamingDecoder,
+    Downcast,
+    LinearEmbedWithPool,
+    LinearSequenceHead,
+    LocalAtomTransformer,
+    LocalTokenTransformer,
+)
+from rfd3.model.layers.encoders import (
+    DiffusionTokenEncoder,
+)
+from rfd3.model.layers.layer_utils import RMSNorm, linearNoBias
+
+from foundry.model.layers.blocks import (
+    FourierEmbedding,
+)
+
+logger = logging.getLogger(__name__)
+
+
+class RFD3DiffusionModule(nn.Module):
+    def __init__(
+        self,
+        *,
+        c_atom,
+        c_atompair,
+        c_token,
+        c_s,
+        c_z,
+        c_t_embed,
+        sigma_data,
+        f_pred,
+        n_attn_seq_neighbours,
+        n_attn_keys,
+        n_recycle,
+        atom_attention_encoder,
+        diffusion_token_encoder,
+        diffusion_transformer,
+        atom_attention_decoder,
+        # upcast,
+        downcast,
+        use_local_token_attention=True,
+        **_,
+    ):
+        super().__init__()
+        self.sigma_data = sigma_data
+        self.c_atom = c_atom
+        self.c_atompair = c_atompair
+        self.c_token = c_token
+        self.c_s = c_s
+        self.c_z = c_z
+        self.f_pred = f_pred
+        self.n_attn_seq_neighbours = n_attn_seq_neighbours
+        self.n_attn_keys = n_attn_keys
+        self.use_local_token_attention = use_local_token_attention
+
+        # Auxiliary
+        self.process_r = linearNoBias(3, c_atom)
+        self.to_r_update = nn.Sequential(RMSNorm((c_atom,)), linearNoBias(c_atom, 3))
+        self.sequence_head = LinearSequenceHead(c_token=c_token)
+
+        self.n_recycle = n_recycle
+        self.n_bins = 65
+        self.bucketize_fn = functools.partial(
+            bucketize_scaled_distogram,
+            min_dist=1,
+            max_dist=30,
+            sigma_data=1,
+            n_bins=self.n_bins,
+        )
+
+        # Time processing
+        self.fourier_embedding = nn.ModuleList(
+            [FourierEmbedding(c_t_embed), FourierEmbedding(c_t_embed)]
+        )
+        self.process_n = nn.ModuleList(
+            [
+                nn.Sequential(RMSNorm(c_t_embed), linearNoBias(c_t_embed, c_atom)),
+                nn.Sequential(RMSNorm(c_t_embed), linearNoBias(c_t_embed, c_s)),
+            ]
+        )
+        self.downcast_c = Downcast(c_atom=c_atom, c_token=c_s, c_s=None, **downcast)
+        self.downcast_q = Downcast(c_atom=c_atom, c_token=c_token, c_s=c_s, **downcast)
+        self.process_a = LinearEmbedWithPool(c_token)
+        self.process_c = nn.Sequential(RMSNorm(c_atom), linearNoBias(c_atom, c_atom))
+
+        # UNet-like architecture for processing across tokens and atoms
+        self.encoder = LocalAtomTransformer(
+            c_atom=c_atom, c_s=c_atom, c_atompair=c_atompair, **atom_attention_encoder
+        )
+
+        self.diffusion_token_encoder = DiffusionTokenEncoder(
+            c_s=c_s,
+            c_token=c_token,
+            c_z=c_z,
+            c_atompair=c_atompair,
+            **diffusion_token_encoder,
+        )
+
+        self.diffusion_transformer = LocalTokenTransformer(
+            c_token=c_token,
+            c_tokenpair=c_z,
+            c_s=c_s,
+            **diffusion_transformer,
+        )
+
+        self.decoder = CompactStreamingDecoder(
+            c_atom=c_atom,
+            c_atompair=c_atompair,
+            c_token=c_token,
+            c_s=c_s,
+            c_tokenpair=c_z,
+            **atom_attention_decoder,
+        )
+
+    def scale_positions_in(self, X_noisy_L, t):
+        if t.ndim == 1:
+            t = t[..., None, None]  # [B, (n_atoms), (3)]
+        elif t.ndim == 2:
+            t = t[..., None]  # [B, n_atoms, (3)]
+
+        if self.f_pred == "edm":
+            R_noisy_L = X_noisy_L / torch.sqrt(t**2 + self.sigma_data**2)
+        elif self.f_pred == "unconditioned":
+            R_noisy_L = torch.zeros_like(X_noisy_L)
+        elif self.f_pred == "noise_pred":
+            R_noisy_L = X_noisy_L
+        else:
+            raise Exception(f"{self.f_pred=} unrecognized")
+        return R_noisy_L
+
+    def scale_positions_out(self, R_update_L, X_noisy_L, t):
+        if t.ndim == 1:
+            t = t[..., None, None]
+        elif t.ndim == 2:
+            t = t[..., None]  # [B, n_atoms, (3)]
+
+        if self.f_pred == "edm":
+            X_out_L = (self.sigma_data**2 / (self.sigma_data**2 + t**2)) * X_noisy_L + (
+                self.sigma_data * t / (self.sigma_data**2 + t**2) ** 0.5
+            ) * R_update_L
+        elif self.f_pred == "unconditioned":
+            X_out_L = R_update_L
+        elif self.f_pred == "noise_pred":
+            X_out_L = X_noisy_L + R_update_L
+        else:
+            raise Exception(f"{self.f_pred=} unrecognized")
+        return X_out_L
+
+    def process_time_(self, t_L, i):
+        C_L = self.process_n[i](
+            self.fourier_embedding[i](
+                1 / 4 * torch.log(torch.clamp(t_L, min=1e-20) / self.sigma_data)
+            )
+        )
+        # Mask out zero-time features;
+        C_L = C_L * (t_L > 0).float()[..., None]  # [B, L, C_atom]
+        return C_L
+
+    def forward(
+        self,
+        X_noisy_L,
+        t,
+        f,
+        # Features from initialization
+        Q_L_init,
+        C_L,
+        P_LL,
+        S_I,
+        Z_II,
+        n_recycle=None,
+        # Chunked memory optimization parameters
+        chunked_pairwise_embedder=None,
+        initializer_outputs=None,
+        **kwargs,
+    ):
+        """
+        Diffusion forward pass with recycling.
+        Computes denoised positions given encoded features and noisy coordinates.
+        """
+        # ... Collect inputs
+        tok_idx = f["atom_to_token_map"]
+        L = len(tok_idx)
+        I = tok_idx.max() + 1  # Number of tokens
+        f["attn_indices"] = create_attention_indices(
+            X_L=X_noisy_L,
+            f=f,
+            n_attn_keys=self.n_attn_keys,
+            n_attn_seq_neighbours=self.n_attn_seq_neighbours,
+        )
+
+        # ... Expand t tensors
+        t_L = t.unsqueeze(-1).expand(-1, L) * (
+            ~f["is_motif_atom_with_fixed_coord"]
+        ).float().unsqueeze(0)
+        t_I = t.unsqueeze(-1).expand(-1, I) * (
+            ~f["is_motif_token_with_fully_fixed_coord"]
+        ).float().unsqueeze(0)
+
+        # ... Create scaled positions
+        R_L_uniform = self.scale_positions_in(X_noisy_L, t)
+        R_noisy_L = self.scale_positions_in(X_noisy_L, t_L)
+
+        # ... Pool initial representation to sequence level
+        A_I = self.process_a(R_noisy_L, tok_idx=tok_idx)
+        S_I = self.downcast_c(C_L, S_I, tok_idx=tok_idx)
+
+        # ... Add batch-wise features to inputs
+        Q_L = Q_L_init.unsqueeze(0) + self.process_r(R_noisy_L)
+        C_L = C_L.unsqueeze(0) + self.process_time_(t_L, i=0)
+        S_I = S_I.unsqueeze(0) + self.process_time_(t_I, i=1)
+        C_L = C_L + self.process_c(C_L)
+
+        # ... Run Local-Atom Self Attention and Pool
+        if chunked_pairwise_embedder is not None:
+            # Chunked mode: pass chunked embedder and feature dict
+            Q_L = self.encoder(
+                Q_L,
+                C_L,
+                P_LL=None,
+                indices=f["attn_indices"],
+                f=f,  # Pass feature dict for chunked computation
+                chunked_pairwise_embedder=chunked_pairwise_embedder,
+                initializer_outputs=initializer_outputs,
+            )
+        else:
+            # Standard mode: use full P_LL
+            Q_L = self.encoder(Q_L, C_L, P_LL, indices=f["attn_indices"])
+        A_I = self.downcast_q(Q_L, A_I=A_I, S_I=S_I, tok_idx=tok_idx)
+
+        # Debug chunked parameters
+
+        # ... Run forward with recycling
+        recycled_features = self.forward_with_recycle(
+            n_recycle,
+            X_noisy_L=X_noisy_L,
+            R_L_uniform=R_L_uniform,
+            t_L=t_L,
+            f=f,
+            Q_L=Q_L,
+            C_L=C_L,
+            P_LL=P_LL,
+            A_I=A_I,
+            S_I=S_I,
+            Z_II=Z_II,
+            chunked_pairwise_embedder=chunked_pairwise_embedder,
+            initializer_outputs=initializer_outputs,
+        )
+
+        # ... Collect outputs
+        outputs = {
+            "X_L": recycled_features["X_L"],  # [B, L, 3] denoised positions
+            "sequence_indices_I": recycled_features["sequence_indices_I"],
+            "sequence_logits_I": recycled_features["sequence_logits_I"],
+        }
+        return outputs
+
+    def forward_with_recycle(
+        self,
+        n_recycle,
+        **kwargs,
+    ):
+        if not self.training:
+            n_recycle = self.n_recycle
+        else:
+            assert n_recycle is not None
+
+        recycled_features = {}
+        for i in range(n_recycle):
+            with ExitStack() as stack:
+                last = not (i < n_recycle - 1)
+                if not last:
+                    stack.enter_context(torch.no_grad())
+
+                # Clear the autocast cache if gradients are enabled (workaround for autocast bug)
+                # See: https://github.com/pytorch/pytorch/issues/65766
+                if torch.is_grad_enabled():
+                    torch.clear_autocast_cache()
+
+                # Run forward
+                recycled_features = self.process_(
+                    D_II_self=recycled_features.get("D_II_self"),
+                    X_L_self=recycled_features.get("X_L"),
+                    **kwargs,
+                )
+
+        return recycled_features
+
+    def process_(
+        self,
+        D_II_self,
+        X_L_self,
+        *,
+        R_L_uniform,
+        X_noisy_L,
+        t_L,
+        f,
+        Q_L,
+        C_L,
+        P_LL,
+        A_I,
+        S_I,
+        Z_II,
+        chunked_pairwise_embedder=None,
+        initializer_outputs=None,
+        **_,
+    ):
+        # ... Embed token level features with atom level encodings
+        S_I, Z_II = self.diffusion_token_encoder(
+            f=f,
+            R_L=R_L_uniform,
+            D_II_self=D_II_self,
+            S_init_I=S_I,
+            Z_init_II=Z_II,
+            C_L=C_L,
+            P_LL=P_LL,
+        )
+
+        # ... Diffusion transformer
+        A_I = self.diffusion_transformer(
+            A_I,
+            S_I,
+            Z_II,
+            f=f,
+            X_L=(
+                X_noisy_L[..., f["is_ca"], :]
+                if X_L_self is None
+                else X_L_self[..., f["is_ca"], :]
+            ),
+            full=not (os.environ.get("RFD3_LOW_MEMORY_MODE", None) == "1"),
+        )
+
+        # ... Decoder readout
+        # Check if using chunked P_LL mode
+
+        if chunked_pairwise_embedder is not None:
+            # Chunked mode: pass embedder and no P_LL
+            A_I, Q_L, o = self.decoder(
+                A_I,
+                S_I,
+                Z_II,
+                Q_L,
+                C_L,
+                P_LL=None,  # Not used in chunked mode
+                tok_idx=f["atom_to_token_map"],
+                indices=f["attn_indices"],
+                f=f,  # Pass f for chunked computation
+                chunked_pairwise_embedder=chunked_pairwise_embedder,
+                initializer_outputs=initializer_outputs,
+            )
+        else:
+            # Original mode: use full P_LL
+            A_I, Q_L, o = self.decoder(
+                A_I,
+                S_I,
+                Z_II,
+                Q_L,
+                C_L,
+                P_LL=P_LL,
+                tok_idx=f["atom_to_token_map"],
+                indices=f["attn_indices"],
+            )
+
+        # ... Process outputs to positions update
+        R_update_L = self.to_r_update(Q_L)
+        X_out_L = self.scale_positions_out(R_update_L, X_noisy_L, t_L)
+
+        sequence_logits_I, sequence_indices_I = self.sequence_head(A_I=A_I)
+        D_II_self = self.bucketize_fn(X_out_L[..., f["is_ca"], :].detach())
+
+        return {
+            "X_L": X_out_L,
+            "D_II_self": D_II_self,
+            "sequence_logits_I": sequence_logits_I,
+            "sequence_indices_I": sequence_indices_I,
+        } | o
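
For reference, the "edm" branch of scale_positions_in / scale_positions_out above appears to follow the EDM preconditioning of Karras et al. (2022): the noisy coordinates are scaled by c_in(t) before entering the network, and the prediction is mixed back with the noisy input via c_skip(t) and c_out(t). A minimal standalone sketch of those coefficients (illustrative only; the function name is not part of this package):

import torch

def edm_preconditioning(t: torch.Tensor, sigma_data: float):
    # Coefficients as used in the module above:
    #   network input:  R_noisy = c_in(t) * X_noisy
    #   network output: X_out   = c_skip(t) * X_noisy + c_out(t) * R_update
    c_in = 1.0 / torch.sqrt(t**2 + sigma_data**2)
    c_skip = sigma_data**2 / (t**2 + sigma_data**2)
    c_out = t * sigma_data / torch.sqrt(t**2 + sigma_data**2)
    return c_in, c_skip, c_out
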
rfd3/model/cfg_utils.py
@@ -0,0 +1,81 @@
+import torch
+
+
+def strip_f(
+    f,
+    cfg_features,
+):
+    """
+    Strips conditioning features from 'f' for classifier-free guidance.
+
+    Args:
+        f (dict): Conditioning features
+        cfg_features (list): List of features to be set to 0
+
+    Returns:
+        dict: Stripped conditioning features
+    """
+    # Token and atom dimensions identify token/atom features regardless of their names (only these two keys are hardcoded)
+    token_dim = f["is_motif_token_unindexed"].shape[0]
+    atom_dim = f["is_motif_atom_unindexed"].shape[0]
+
+    # identify the first atom and token to be cropped
+    crop = torch.any(f["is_motif_atom_unindexed"]).item()
+    atom_crop_index = (
+        torch.where(f["is_motif_atom_unindexed"])[0][0]
+        if crop
+        else f["is_motif_atom_unindexed"].shape[0]
+    )
+    token_crop_index = (
+        torch.where(f["is_motif_token_unindexed"])[0][0]
+        if crop
+        else f["is_motif_token_unindexed"].shape[0]
+    )
+
+    # ... Mask out conditioning features
+    f_stripped = f.copy()
+
+    # Crop features depending on whether they are atom or token features and whether they are 1D or 2D
+    for k, v in f.items():
+        # handle cases not captured below
+        v_cropped = v
+
+        # handle token features
+        if token_dim in v.shape:
+            # Check if it's a 2D feature (square matrix)
+            if len(v.shape) == 2 and v.shape[0] == v.shape[1]:
+                v_cropped = v[:token_crop_index, :token_crop_index]
+            else:
+                v_cropped = v[:token_crop_index]
+        # handle atom features
+        if atom_dim in v.shape:
+            # Check if it's a 2D feature (square matrix)
+            if len(v.shape) == 2 and v.shape[0] == v.shape[1]:
+                v_cropped = v[:atom_crop_index, :atom_crop_index]
+            else:
+                v_cropped = v[:atom_crop_index]
+
+        # set the feature to its default value if it is in cfg_features
+        if k in cfg_features:
+            v_cropped = torch.zeros_like(v_cropped).to(
+                v_cropped.device, dtype=v_cropped.dtype
+            )
+
+        # update the feature in the dictionary
+        f_stripped[k] = v_cropped
+
+    return f_stripped
+
+
+def strip_X(X_L, f_stripped):
+    """
+    Strips unindexed motif atoms from X_L.
+
+    Args:
+        X_L (torch.Tensor): Atom coordinates
+        f_stripped (dict): Stripped conditioning features
+
+    Returns:
+        torch.Tensor: Atom coordinates with unindexed atoms removed
+    """
+    return X_L[..., : f_stripped["is_motif_atom_unindexed"].shape[0], :]
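
strip_f and strip_X above prepare the unconditioned inputs for classifier-free guidance: features listed in cfg_features are zeroed, and token/atom features (and the coordinates) are cropped at the first "unindexed" motif entry when such entries are present. A small, self-contained sketch of that behaviour on a toy feature dict (the keys "hotspot_feature" and "atom_pair_feature" are placeholders, not features used by the package):

import torch

# Toy inputs: 6 tokens / 10 atoms, where the last 2 tokens and last 4 atoms are unindexed motif entries.
f = {
    "is_motif_token_unindexed": torch.tensor([False] * 4 + [True] * 2),
    "is_motif_atom_unindexed": torch.tensor([False] * 6 + [True] * 4),
    "hotspot_feature": torch.rand(6),         # 1D token feature: cropped to 4 entries, then zeroed
    "atom_pair_feature": torch.rand(10, 10),  # 2D square atom feature: cropped to 6 x 6
}
f_uncond = strip_f(f, cfg_features=["hotspot_feature"])
X_L = torch.rand(1, 10, 3)
X_uncond = strip_X(X_L, f_uncond)  # shape [1, 6, 3]: unindexed motif atoms removed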