rc-foundry 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (180) hide show
  1. foundry/__init__.py +57 -0
  2. foundry/callbacks/__init__.py +5 -0
  3. foundry/callbacks/callback.py +116 -0
  4. foundry/callbacks/health_logging.py +419 -0
  5. foundry/callbacks/metrics_logging.py +211 -0
  6. foundry/callbacks/timing_logging.py +67 -0
  7. foundry/callbacks/train_logging.py +278 -0
  8. foundry/common.py +108 -0
  9. foundry/constants.py +28 -0
  10. foundry/hydra/resolvers.py +77 -0
  11. foundry/inference_engines/base.py +235 -0
  12. foundry/inference_engines/checkpoint_registry.py +66 -0
  13. foundry/metrics/__init__.py +12 -0
  14. foundry/metrics/losses.py +30 -0
  15. foundry/metrics/metric.py +319 -0
  16. foundry/model/layers/blocks.py +47 -0
  17. foundry/testing/__init__.py +6 -0
  18. foundry/testing/fixtures.py +19 -0
  19. foundry/testing/pytest_hooks.py +15 -0
  20. foundry/trainers/fabric.py +923 -0
  21. foundry/training/EMA.py +67 -0
  22. foundry/training/checkpoint.py +61 -0
  23. foundry/training/schedulers.py +91 -0
  24. foundry/utils/alignment.py +86 -0
  25. foundry/utils/components.py +415 -0
  26. foundry/utils/datasets.py +405 -0
  27. foundry/utils/ddp.py +103 -0
  28. foundry/utils/instantiators.py +72 -0
  29. foundry/utils/logging.py +279 -0
  30. foundry/utils/rigid.py +1460 -0
  31. foundry/utils/rotation_augmentation.py +65 -0
  32. foundry/utils/squashfs.py +172 -0
  33. foundry/utils/torch.py +317 -0
  34. foundry/utils/weights.py +271 -0
  35. foundry/version.py +34 -0
  36. foundry_cli/__init__.py +3 -0
  37. foundry_cli/download_checkpoints.py +281 -0
  38. mpnn/__init__.py +1 -0
  39. mpnn/collate/feature_collator.py +265 -0
  40. mpnn/inference.py +53 -0
  41. mpnn/inference_engines/mpnn.py +549 -0
  42. mpnn/loss/nll_loss.py +122 -0
  43. mpnn/metrics/nll.py +369 -0
  44. mpnn/metrics/sequence_recovery.py +440 -0
  45. mpnn/model/layers/graph_embeddings.py +2372 -0
  46. mpnn/model/layers/message_passing.py +332 -0
  47. mpnn/model/layers/position_wise_feed_forward.py +44 -0
  48. mpnn/model/layers/positional_encoding.py +98 -0
  49. mpnn/model/mpnn.py +2632 -0
  50. mpnn/pipelines/mpnn.py +162 -0
  51. mpnn/samplers/samplers.py +167 -0
  52. mpnn/train.py +341 -0
  53. mpnn/trainers/mpnn.py +193 -0
  54. mpnn/transforms/feature_aggregation/mpnn.py +184 -0
  55. mpnn/transforms/feature_aggregation/polymer_ligand_interface.py +76 -0
  56. mpnn/transforms/feature_aggregation/token_encodings.py +132 -0
  57. mpnn/transforms/feature_aggregation/user_settings.py +347 -0
  58. mpnn/transforms/polymer_ligand_interface.py +164 -0
  59. mpnn/utils/inference.py +2397 -0
  60. mpnn/utils/probability.py +37 -0
  61. mpnn/utils/weights.py +309 -0
  62. rc_foundry-0.1.1.dist-info/METADATA +239 -0
  63. rc_foundry-0.1.1.dist-info/RECORD +180 -0
  64. rc_foundry-0.1.1.dist-info/WHEEL +4 -0
  65. rc_foundry-0.1.1.dist-info/entry_points.txt +5 -0
  66. rc_foundry-0.1.1.dist-info/licenses/LICENSE.md +28 -0
  67. rf3/__init__.py +3 -0
  68. rf3/_version.py +33 -0
  69. rf3/alignment.py +79 -0
  70. rf3/callbacks/dump_validation_structures.py +101 -0
  71. rf3/callbacks/metrics_logging.py +324 -0
  72. rf3/chemical.py +1529 -0
  73. rf3/cli.py +77 -0
  74. rf3/data/cyclic_transform.py +78 -0
  75. rf3/data/extra_xforms.py +36 -0
  76. rf3/data/ground_truth_template.py +463 -0
  77. rf3/data/paired_msa.py +206 -0
  78. rf3/data/pipeline_utils.py +128 -0
  79. rf3/data/pipelines.py +558 -0
  80. rf3/diffusion_samplers/inference_sampler.py +222 -0
  81. rf3/inference.py +65 -0
  82. rf3/inference_engines/__init__.py +5 -0
  83. rf3/inference_engines/rf3.py +735 -0
  84. rf3/kinematics.py +354 -0
  85. rf3/loss/af3_confidence_loss.py +515 -0
  86. rf3/loss/af3_losses.py +655 -0
  87. rf3/loss/loss.py +179 -0
  88. rf3/metrics/chiral.py +179 -0
  89. rf3/metrics/clashing_chains.py +68 -0
  90. rf3/metrics/distogram.py +421 -0
  91. rf3/metrics/lddt.py +523 -0
  92. rf3/metrics/metadata.py +43 -0
  93. rf3/metrics/metric_utils.py +192 -0
  94. rf3/metrics/predicted_error.py +134 -0
  95. rf3/metrics/rasa.py +108 -0
  96. rf3/metrics/selected_distances.py +91 -0
  97. rf3/model/RF3.py +527 -0
  98. rf3/model/RF3_blocks.py +92 -0
  99. rf3/model/RF3_structure.py +303 -0
  100. rf3/model/layers/af3_auxiliary_heads.py +255 -0
  101. rf3/model/layers/af3_diffusion_transformer.py +544 -0
  102. rf3/model/layers/attention.py +313 -0
  103. rf3/model/layers/layer_utils.py +127 -0
  104. rf3/model/layers/mlff.py +118 -0
  105. rf3/model/layers/outer_product.py +59 -0
  106. rf3/model/layers/pairformer_layers.py +783 -0
  107. rf3/model/layers/structure_bias.py +56 -0
  108. rf3/scoring.py +1787 -0
  109. rf3/symmetry/resolve.py +284 -0
  110. rf3/train.py +194 -0
  111. rf3/trainers/rf3.py +570 -0
  112. rf3/util_module.py +47 -0
  113. rf3/utils/frames.py +109 -0
  114. rf3/utils/inference.py +665 -0
  115. rf3/utils/io.py +198 -0
  116. rf3/utils/loss.py +72 -0
  117. rf3/utils/predict_and_score.py +165 -0
  118. rf3/utils/predicted_error.py +673 -0
  119. rf3/utils/recycling.py +42 -0
  120. rf3/validate.py +140 -0
  121. rfd3/.gitignore +7 -0
  122. rfd3/Makefile +76 -0
  123. rfd3/__init__.py +12 -0
  124. rfd3/callbacks.py +66 -0
  125. rfd3/cli.py +41 -0
  126. rfd3/constants.py +212 -0
  127. rfd3/engine.py +543 -0
  128. rfd3/inference/datasets.py +193 -0
  129. rfd3/inference/input_parsing.py +1123 -0
  130. rfd3/inference/legacy_input_parsing.py +717 -0
  131. rfd3/inference/parsing.py +165 -0
  132. rfd3/inference/symmetry/atom_array.py +298 -0
  133. rfd3/inference/symmetry/checks.py +241 -0
  134. rfd3/inference/symmetry/contigs.py +63 -0
  135. rfd3/inference/symmetry/frames.py +355 -0
  136. rfd3/inference/symmetry/symmetry_utils.py +398 -0
  137. rfd3/metrics/design_metrics.py +465 -0
  138. rfd3/metrics/hbonds_hbplus_metrics.py +308 -0
  139. rfd3/metrics/hbonds_metrics.py +389 -0
  140. rfd3/metrics/losses.py +325 -0
  141. rfd3/metrics/metrics_utils.py +118 -0
  142. rfd3/metrics/sidechain_metrics.py +349 -0
  143. rfd3/model/RFD3.py +105 -0
  144. rfd3/model/RFD3_diffusion_module.py +387 -0
  145. rfd3/model/cfg_utils.py +81 -0
  146. rfd3/model/inference_sampler.py +635 -0
  147. rfd3/model/layers/attention.py +577 -0
  148. rfd3/model/layers/block_utils.py +580 -0
  149. rfd3/model/layers/blocks.py +777 -0
  150. rfd3/model/layers/chunked_pairwise.py +377 -0
  151. rfd3/model/layers/encoders.py +417 -0
  152. rfd3/model/layers/layer_utils.py +197 -0
  153. rfd3/model/layers/pairformer_layers.py +128 -0
  154. rfd3/run_inference.py +45 -0
  155. rfd3/testing/debug.py +139 -0
  156. rfd3/testing/debug_utils.py +73 -0
  157. rfd3/testing/testing_utils.py +356 -0
  158. rfd3/train.py +194 -0
  159. rfd3/trainer/dump_validation_structures.py +154 -0
  160. rfd3/trainer/fabric_trainer.py +923 -0
  161. rfd3/trainer/recycling.py +42 -0
  162. rfd3/trainer/rfd3.py +485 -0
  163. rfd3/trainer/trainer_utils.py +502 -0
  164. rfd3/transforms/conditioning_base.py +508 -0
  165. rfd3/transforms/conditioning_utils.py +200 -0
  166. rfd3/transforms/design_transforms.py +807 -0
  167. rfd3/transforms/dna_crop.py +523 -0
  168. rfd3/transforms/hbonds.py +407 -0
  169. rfd3/transforms/hbonds_hbplus.py +246 -0
  170. rfd3/transforms/ncaa_transforms.py +153 -0
  171. rfd3/transforms/pipelines.py +632 -0
  172. rfd3/transforms/ppi_transforms.py +541 -0
  173. rfd3/transforms/rasa.py +116 -0
  174. rfd3/transforms/symmetry.py +76 -0
  175. rfd3/transforms/training_conditions.py +552 -0
  176. rfd3/transforms/util_transforms.py +498 -0
  177. rfd3/transforms/virtual_atoms.py +305 -0
  178. rfd3/utils/inference.py +648 -0
  179. rfd3/utils/io.py +245 -0
  180. rfd3/utils/vizualize.py +276 -0
@@ -0,0 +1,777 @@
1
+ import logging
2
+ import math
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from atomworks.ml.encoding_definitions import AF3SequenceEncoding
8
+ from einops import rearrange
9
+ from rfd3.model.layers.attention import (
10
+ GatedCrossAttention,
11
+ LocalAttentionPairBias,
12
+ )
13
+ from rfd3.model.layers.block_utils import (
14
+ build_valid_mask,
15
+ create_attention_indices,
16
+ group_atoms,
17
+ ungroup_atoms,
18
+ )
19
+ from rfd3.model.layers.layer_utils import (
20
+ AdaLN,
21
+ EmbeddingLayer,
22
+ LinearBiasInit,
23
+ RMSNorm,
24
+ Transition,
25
+ collapse,
26
+ linearNoBias,
27
+ )
28
+ from rfd3.model.layers.pairformer_layers import PairformerBlock
29
+ from torch.nn.functional import one_hot
30
+
31
+ from foundry import DISABLE_CHECKPOINTING
32
+ from foundry.common import exists
33
+
34
+ logger = logging.getLogger(__name__)
35
+
36
+
37
+ # SwiGLU transition block with adaptive layernorm
38
class ConditionedTransitionBlock(nn.Module):
    """SwiGLU transition block with adaptive layernorm (adaLN-Zero gating).

    The activations are normalised with AdaLN conditioned on Si, passed
    through a SwiGLU feed-forward, and finally gated by a sigmoid projection
    of the conditioning signal whose bias starts at -2.0 so the block is
    near-zero at initialisation.
    """

    def __init__(self, c_token, c_s, n=2):
        super().__init__()
        self.ada_ln = AdaLN(c_a=c_token, c_s=c_s)
        self.linear_1 = linearNoBias(c_token, c_token * n)
        self.linear_2 = linearNoBias(c_token, c_token * n)
        self.linear_output_project = nn.Sequential(
            LinearBiasInit(c_s, c_token, biasinit=-2.0),
            nn.Sigmoid(),
        )
        self.linear_3 = linearNoBias(c_token * n, c_token)

    def forward(
        self,
        Ai,  # [B, I, C_token] activations
        Si,  # [B, I, C_token] conditioning signal
    ):
        normed = self.ada_ln(Ai, Si)
        # SwiGLU: silu-gated linear unit. (A previous sigmoid-only gate was
        # an incorrect SwiGLU implementation and has been replaced by silu.)
        gated = F.silu(self.linear_1(normed)) * self.linear_2(normed)
        # Output projection (from adaLN-Zero)
        return self.linear_output_project(Si) * self.linear_3(gated)
63
+
64
+
65
class PositionPairDistEmbedder(nn.Module):
    """Embeds pairwise offsets of reference positions into c_atompair channels.

    Args:
        c_atompair (int): Output channel dimension of the pair embedding.
        embed_frame (bool): When True, additionally embed the raw 3D offset
            vectors (AF3-style reference-position embedding path).
    """

    def __init__(self, c_atompair, embed_frame=True):
        super().__init__()
        self.embed_frame = embed_frame
        if embed_frame:
            # Linear embedding of raw 3-vector offsets; only used by forward_af3.
            self.process_d = linearNoBias(3, c_atompair)

        self.process_inverse_dist = linearNoBias(1, c_atompair)
        self.process_valid_mask = linearNoBias(1, c_atompair)

    def forward_af3(self, D_LL, V_LL):
        """Forward the same way reference positions are embedded in AF3."""
        # D_LL: pairwise offset vectors; V_LL: pair validity mask (used both
        # multiplicatively and as a boolean index, so presumably
        # [..., L, L, 1] boolean -- TODO confirm with caller).
        P_LL = self.process_d(D_LL) * V_LL

        # Embed pairwise inverse squared distances, and the valid mask
        if self.training:
            # Dense path: compute on all pairs and zero invalid ones via the
            # mask product (keeps everything autograd-friendly).
            P_LL = (
                P_LL
                + self.process_inverse_dist(
                    1 / (1 + torch.linalg.norm(D_LL, dim=-1, keepdim=True) ** 2)
                )
                * V_LL
            )
            P_LL = P_LL + self.process_valid_mask(V_LL.to(P_LL.dtype)) * V_LL
        else:
            # Inference path: in-place indexed adds touch only the valid
            # pairs (saves compute/memory; in-place writes would not be safe
            # for autograd, hence the training/eval split).
            P_LL[V_LL[..., 0]] += self.process_inverse_dist(
                1
                / (1 + torch.linalg.norm(D_LL[V_LL[..., 0]], dim=-1, keepdim=True) ** 2)
            )
            P_LL[V_LL[..., 0]] += self.process_valid_mask(
                V_LL[V_LL[..., 0]].to(P_LL.dtype)
            )
        return P_LL

    def forward(self, ref_pos, valid_mask):
        # ref_pos: [..., L, 3] positions; valid_mask: [..., L, L, 1] pair mask.
        D_LL = ref_pos.unsqueeze(-2) - ref_pos.unsqueeze(-3)
        V_LL = valid_mask

        if self.embed_frame:
            # Embed pairwise distances
            return self.forward_af3(D_LL, V_LL)
        # Frame-free variant: only inverse squared distance + mask embeddings.
        norm = torch.linalg.norm(D_LL, dim=-1, keepdim=True) ** 2
        norm = torch.clamp(norm, min=1e-6)  # guard against exact zeros
        inv_dist = 1 / (1 + norm)
        P_LL = self.process_inverse_dist(inv_dist) * V_LL
        P_LL = P_LL + self.process_valid_mask(V_LL.to(P_LL.dtype)) * V_LL
        return P_LL
113
+
114
+
115
class OneDFeatureEmbedder(nn.Module):
    """
    Embeds 1D features into a single vector.

    Args:
        features (dict): Dictionary of feature names and their number of channels.
        output_channels (int): Output dimension of the projected embedding.
    """

    def __init__(self, features, output_channels):
        super().__init__()
        # Keep only features whose channel count is actually set.
        self.features = {name: n for name, n in features.items() if exists(n)}
        n_total_inputs = sum(self.features.values())
        embedders = {}
        for name, n_channels in self.features.items():
            embedders[name] = EmbeddingLayer(n_channels, n_total_inputs, output_channels)
        self.embedders = nn.ModuleDict(embedders)

    def forward(self, f, collapse_length):
        """Sum the per-feature embeddings of the collapsed float features."""
        total = 0
        for name, n_channels in self.features.items():
            if not exists(n_channels):
                continue
            embedded = self.embedders[name](collapse(f[name].float(), collapse_length))
            total = total + embedded
        return total
145
+
146
+
147
class SinusoidalDistEmbed(nn.Module):
    """
    Applies sinusoidal embedding to pairwise distances and projects to c_atompair.

    Args:
        c_atompair (int): Output dimension of the projected embedding (must be even).
        n_freqs (int): Number of sin/cos pairs; total sinusoidal dim = 2 * n_freqs.
    """

    def __init__(self, c_atompair, n_freqs=32):
        super().__init__()
        assert c_atompair % 2 == 0, "Output embedding dim must be even"

        self.n_freqs = n_freqs
        self.c_atompair = c_atompair

        # FIX: the frequency ladder is a constant; the old code rebuilt it on
        # CPU and copied it to the device on every forward call. Precompute it
        # once here. Registered as a *non-persistent* buffer so it follows
        # .to(device)/.cuda() but stays out of the state_dict, keeping old
        # checkpoints loadable.
        freq = torch.exp(
            -math.log(10000.0)
            * torch.arange(0, n_freqs, dtype=torch.float32)
            / n_freqs
        )
        self.register_buffer("freq", freq, persistent=False)

        self.output_proj = linearNoBias(2 * n_freqs, c_atompair)
        self.process_valid_mask = linearNoBias(1, c_atompair)

    def forward(self, pos, valid_mask):
        """
        Args:
            pos: [L, 3] or [B, L, 3] ground truth atom positions
            valid_mask: [L, L, 1] or [B, L, L, 1] boolean mask
        Returns:
            P_LL: [L, L, c_atompair] or [B, L, L, c_atompair]
        """
        # Compute pairwise distances
        D_LL = pos.unsqueeze(-2) - pos.unsqueeze(-3)  # [L, L, 3] or [B, L, L, 3]
        dist_matrix = torch.linalg.norm(D_LL, dim=-1)  # [L, L] or [B, L, L]

        # Sinusoidal embedding: [..., n_freqs] angles -> [..., 2 * n_freqs]
        angles = dist_matrix.unsqueeze(-1) * self.freq
        sin_embed = torch.sin(angles)
        cos_embed = torch.cos(angles)
        sincos_embed = torch.cat([sin_embed, cos_embed], dim=-1)

        # Linear projection, zeroed on invalid pairs
        P_LL = self.output_proj(sincos_embed)  # [..., c_atompair]
        P_LL = P_LL * valid_mask

        # Add linear embedding of valid mask
        P_LL = P_LL + self.process_valid_mask(valid_mask.to(P_LL.dtype)) * valid_mask
        return P_LL
199
+
200
+
201
class LinearEmbedWithPool(nn.Module):
    """Linearly embeds per-atom 3D inputs, then mean-pools them per token."""

    def __init__(self, c_token):
        super().__init__()
        self.c_token = c_token
        self.linear = linearNoBias(3, c_token)

    def forward(self, R_L, tok_idx):
        # R_L: per-atom inputs [B, L, 3]; tok_idx: atom -> token index map.
        batch = R_L.shape[0]
        n_tokens = int(tok_idx.max().item()) + 1

        Q_L = self.linear(R_L)
        pooled = torch.zeros(
            (batch, n_tokens, self.c_token), device=R_L.device, dtype=Q_L.dtype
        )
        # Mean-reduce every atom embedding into its owning token slot;
        # include_self=False so the zero init does not bias the mean.
        pooled = pooled.index_reduce(
            -2, tok_idx.long(), Q_L, "mean", include_self=False
        )
        return pooled.clone()
228
+
229
+
230
class SimpleRecycler(nn.Module):
    """Recycling block: re-embeds previous-cycle single/pair representations
    and refines them with a Pairformer stack.

    `template_embedder` and `msa_module` configs are accepted for interface
    compatibility but unused (template/MSA embedding has been removed).
    """

    def __init__(
        self,
        c_s,
        c_z,
        template_embedder,
        msa_module,
        n_pairformer_blocks,
        pairformer_block,
    ):
        super().__init__()
        self.c_z = c_z
        self.process_zh = nn.Sequential(RMSNorm(c_z), linearNoBias(c_z, c_z))
        self.process_sh = nn.Sequential(RMSNorm(c_s), linearNoBias(c_s, c_s))
        stack = [
            PairformerBlock(c_s=c_s, c_z=c_z, **pairformer_block)
            for _ in range(n_pairformer_blocks)
        ]
        self.pairformer_stack = nn.ModuleList(stack)
        # Templates and msa's removed:
        # self.template_embedder = TemplateEmbedder(c_z=c_z, **template_embedder)
        # self.msa_module = MSAModule(**msa_module)

    def forward(
        self,
        f,
        S_inputs_I,
        S_init_I,
        Z_init_II,
        S_I,
        Z_II,
    ):
        # Residual re-embedding of the recycled pair representation.
        Z_II = Z_init_II + self.process_zh(Z_II)

        # Templates and msa's removed:
        # Z_II = Z_II + self.template_embedder(f, Z_II)
        # Z_II = self.msa_module(f, Z_II, S_inputs_I)

        # Residual re-embedding of the recycled single representation.
        S_I = S_init_I + self.process_sh(S_I)
        for layer in self.pairformer_stack:
            S_I, Z_II = layer(S_I, Z_II)
        return S_I, Z_II
279
+
280
+
281
class RelativePositionEncodingWithIndexRemoval(nn.Module):
    """
    Usual RPE but utilizes `is_motif_atom_3d_unindexed` to ensure within-chain position is spoofed.

    Expects integer tensors `asym_id`, `entity_id`, `residue_index`,
    `token_index`, `sym_id` of shape [..., I] in `f`, plus a boolean
    `unindexing_pair_mask` of shape [..., I, I]. Returns [..., I, I, c_z].
    """

    def __init__(self, r_max, s_max, c_z):
        super().__init__()
        self.r_max = r_max
        self.s_max = s_max
        self.c_z = c_z

        self.num_tok_pos_bins = (
            2 * self.r_max + 2
        ) + 1  # original af3 + 1 for unknown index
        self.linear = linearNoBias(
            2 * self.num_tok_pos_bins + (2 * self.s_max + 2) + 1, c_z
        )

    def forward(self, f):
        b_samechain_II = f["asym_id"].unsqueeze(-1) == f["asym_id"].unsqueeze(-2)
        b_same_entity_II = f["entity_id"].unsqueeze(-1) == f["entity_id"].unsqueeze(-2)
        # Residue offset clipped to [0, 2*r_max]; cross-chain pairs get the
        # sentinel bin 2*r_max + 1.
        d_residue_II = torch.where(
            b_samechain_II,
            torch.clip(
                f["residue_index"].unsqueeze(-1)
                - f["residue_index"].unsqueeze(-2)
                + self.r_max,
                0,
                2 * self.r_max,
            ),
            2 * self.r_max + 1,
        )
        b_sameresidue_II = f["residue_index"].unsqueeze(-1) == f[
            "residue_index"
        ].unsqueeze(-2)
        tok_distance = (
            f["token_index"].unsqueeze(-1) - f["token_index"].unsqueeze(-2) + self.r_max
        )
        # Token offset is only meaningful within the same chain AND residue.
        d_token_II = torch.where(
            b_samechain_II * b_sameresidue_II,
            torch.clip(
                tok_distance,
                0,
                2 * self.r_max,
            ),
            2 * self.r_max + 1,
        )

        # Chain distances are kept
        d_chain_II = torch.where(
            # NOTE: Implementing bugfix from the Protenix Technical report, where we use `same_entity` instead of `not same_chain` (as in the AF-3 pseudocode)
            # Reference: https://github.com/bytedance/Protenix/blob/main/Protenix_Technical_Report.pdf
            b_same_entity_II,
            torch.clip(
                f["sym_id"].unsqueeze(-1) - f["sym_id"].unsqueeze(-2) + self.s_max,
                0,
                2 * self.s_max,
            ),
            2 * self.s_max + 1,
        )
        A_relchain_II = one_hot(d_chain_II.long(), 2 * self.s_max + 2)

        #########################################################
        # Cancel out distances from unindexed motifs
        unindexing_pair_mask = f[
            "unindexing_pair_mask"
        ]  # [I, I] pairs which shouldn't talk to one another

        # Special position case: route masked pairs to the "unknown index" bin
        d_token_II[unindexing_pair_mask] = self.num_tok_pos_bins - 1
        d_residue_II[unindexing_pair_mask] = self.num_tok_pos_bins - 1

        A_relpos_II = one_hot(d_residue_II.long(), self.num_tok_pos_bins)
        # BUGFIX: cast to int64 like the sibling encodings above --
        # torch.nn.functional.one_hot requires a LongTensor, and
        # `token_index` is not guaranteed to be int64.
        A_reltoken_II = one_hot(d_token_II.long(), self.num_tok_pos_bins)
        #########################################################

        return self.linear(
            torch.cat(
                [
                    A_relpos_II,
                    A_reltoken_II,
                    b_same_entity_II.unsqueeze(-1),
                    A_relchain_II,
                ],
                dim=-1,
            ).to(torch.float)
        )
368
+
369
+
370
class VirtualPredictor(nn.Module):
    """Predicts a single per-atom scalar from the atom embeddings Q_L."""

    def __init__(self, c_atom):
        super().__init__()
        self.process_atom_embeddings = nn.Sequential(
            RMSNorm((c_atom,)), linearNoBias(c_atom, 1)
        )

    def forward(self, Q_L):
        # [..., L, c_atom] -> [..., L, 1]
        return self.process_atom_embeddings(Q_L)
379
+
380
+
381
class SequenceHead(nn.Module):
    """Sequence prediction head fusing intra-residue distogram features with
    token embeddings into 32-way per-token sequence logits.
    """

    def __init__(self, c_token):
        super(SequenceHead, self).__init__()

        # Distogram feature extraction
        # 196 = 14 * 14, the flattened intra-residue atom distance matrix.
        self.dist_fc1 = nn.Linear(196, 128)
        self.dist_relu = nn.ReLU()
        self.dist_fc2 = nn.Linear(128, 64)

        # Embedding feature extraction
        self.embed_fc1 = nn.Linear(c_token, 128)
        self.embed_relu = nn.ReLU()
        self.embed_fc2 = nn.Linear(128, 64)

        # Fusion layer
        self.fusion_fc = nn.Linear(128, 32)

        # Sequence encoding
        self.sequence_encoding_ = AF3SequenceEncoding()

    def forward(self, A_I, Q_L, X_L, f):
        # A_I: [B, I, c_token] token embeddings; X_L: [B, L, 3] atom coords.
        # f["atom_to_token_map"]: per-atom token index -- assumed shape [L]
        # shared across the batch; TODO confirm with caller.
        B, L, _ = X_L.shape
        max_res_id = f["atom_to_token_map"].max().item() + 1

        # Detach tensors to avoid gradients through main module
        # X_L = X_L.detach()
        # A_I = A_I.detach()
        # Q_L = Q_L.detach()

        # Compute distograms
        # Only residues with exactly 14 atoms get a distogram; all other
        # residues keep an all-zero 14x14 block.
        residue_distogram = torch.zeros(B, max_res_id, 14, 14, device=X_L.device)
        for i in range(max_res_id):
            residue_mask = f["atom_to_token_map"] == i
            if residue_mask.sum() == 14:
                coords = X_L[:, residue_mask]  # (B, 14, 3)
                residue_distogram[:, i] = torch.cdist(coords, coords)

        # Flatten distogram
        dist_features = residue_distogram.view(B, max_res_id, 196)

        # Pass through separate MLPs
        dist_out = self.dist_fc1(dist_features)
        dist_out = self.dist_relu(dist_out)
        dist_out = self.dist_fc2(dist_out)

        embed_out = self.embed_fc1(A_I)
        embed_out = self.embed_relu(embed_out)
        embed_out = self.embed_fc2(embed_out)

        # Fusion via concatenation
        fused = torch.cat([dist_out, embed_out], dim=-1)
        Seq_I = self.fusion_fc(fused)

        indices = self.decode(Seq_I)

        return Seq_I, indices

    def decode(self, Seq_I):
        # Greedy decode: most likely class per token.
        indices = Seq_I.argmax(dim=-1)  # [B, I]
        return indices
441
+
442
+
443
class LinearSequenceHead(nn.Module):
    """Linear sequence head: projects token embeddings to 32-way logits.

    Placeholder classes ("UNK", "X", "DX", "<G>") are masked out of the
    decoded prediction, but remain present in the raw logits.
    """

    def __init__(self, c_token):
        super().__init__()
        n_tok_all = 32
        disallowed_idxs = AF3SequenceEncoding().encode(["UNK", "X", "DX", "<G>"])
        mask = torch.ones(n_tok_all, dtype=torch.bool)
        mask[disallowed_idxs] = False
        self.register_buffer("valid_out_mask", mask)
        self.linear = nn.Linear(c_token, n_tok_all)

    def forward(self, A_I, **_):
        # A_I: [B, I, c_token]. Extra kwargs are accepted and ignored so the
        # call signature is interchangeable with SequenceHead.forward.
        logits = self.linear(A_I)
        indices = self.decode(logits)
        return logits, indices

    def decode(self, logits):
        """Greedy decode restricted to allowed classes.

        logits: [B, I, 32] -> indices: [B, I] in [0, 31].
        """
        probs = F.softmax(logits, dim=-1)
        # Zero out disallowed classes so argmax can never pick them.
        probs = probs * self.valid_out_mask[None, None, :].to(probs.device)
        # Renormalise (does not affect argmax, but keeps probs well-formed).
        probs = probs / (probs.sum(dim=-1, keepdim=True) + 1e-8)
        # FIX: use torch's `dim=` (was numpy-style `axis=`); also dropped an
        # unused `D, I, _ = logits.shape` unpack.
        indices = probs.argmax(dim=-1)
        return indices
467
+
468
+
469
class Upcast(nn.Module):
    """Broadcasts token-level embeddings back to atom level.

    method="broadcast": add a linear projection of each token embedding to
    every atom in that token. method="cross_attention": atoms cross-attend to
    n_split channel-chunks of their token embedding.
    """

    def __init__(
        self, c_token, c_atom, method="broadcast", cross_attention_block=None, n_split=6
    ):
        super().__init__()
        self.method = method
        self.n_split = n_split
        if self.method == "broadcast":
            self.project = nn.Sequential(
                RMSNorm((c_token,)), linearNoBias(c_token, c_atom)
            )
        elif self.method == "cross_attention":
            self.gca = GatedCrossAttention(
                c_query=c_atom, c_kv=c_token // self.n_split, **cross_attention_block
            )
        else:
            raise ValueError(f"Unknown upcast method: {self.method}")

    def forward_(self, Q_IA, A_I, valid_mask=None):
        # Q_IA: [B, I, A, c_atom] grouped atom embeddings; A_I: [B, I, c_token].
        if self.method == "broadcast":
            Q_IA = Q_IA + self.project(A_I)[..., None, :]
        elif self.method == "cross_attention":
            assert exists(A_I) and exists(valid_mask)
            # Split token channels into n_split key/value slots
            A_I = rearrange(A_I, "b n (s c) -> b n s c", s=self.n_split)
            n_tokens, n_atom_per_tok = Q_IA.shape[1], Q_IA.shape[2]

            # FIX: removed a dead first attn_mask computation of shape
            # (n_tokens, 1, n_atom_per_tok) that was immediately overwritten
            # by the mask below; only the second mask was ever used.
            # Attention mask: [n_tokens, n_atom_per_tok, n_split]; padding
            # atoms attend to nothing.
            attn_mask = torch.ones(
                (n_tokens, n_atom_per_tok, self.n_split), device=A_I.device, dtype=bool
            )
            attn_mask[~valid_mask, :] = False

            Q_IA = Q_IA + self.gca(q=Q_IA, kv=A_I, attn_mask=attn_mask)
        return Q_IA

    def forward(self, Q_L, A_I, tok_idx):
        # Q_L: flat atom embeddings; tok_idx: atom -> token index map.
        valid_mask = build_valid_mask(tok_idx)
        Q_IA = ungroup_atoms(Q_L, valid_mask)
        Q_IA = self.forward_(Q_IA, A_I, valid_mask)
        Q_L = group_atoms(Q_IA, valid_mask)
        return Q_L
516
+
517
+
518
class Downcast(nn.Module):
    """Downcast modules for when atoms are already reshaped from N_atoms -> (N_tokens, 14)

    Pools atom-level embeddings into token-level embeddings, either via a
    masked mean of a linear projection ("mean") or by letting each token
    cross-attend to its atoms ("cross_attention"). Optionally adds a
    projection of the single representation S_I.
    """

    def __init__(
        self, c_atom, c_token, c_s=None, method="mean", cross_attention_block=None
    ):
        super().__init__()
        self.method = method
        self.c_token = c_token
        self.c_atom = c_atom
        # Optional projection of the single/conditioning representation S_I.
        if c_s is not None:
            self.process_s = nn.Sequential(
                RMSNorm((c_s,)),
                linearNoBias(c_s, c_token),
            )
        else:
            self.process_s = None

        if self.method == "mean":
            self.project = linearNoBias(c_atom, c_token)
        elif self.method == "cross_attention":
            self.gca = GatedCrossAttention(
                c_query=c_token,
                c_kv=c_atom,
                **cross_attention_block,
            )
        else:
            raise ValueError(f"Unknown downcast method: {self.method}")

    def forward_(self, Q_IA, A_I, S_I=None, valid_mask=None):
        # Q_IA: [..., I, A, c_atom] grouped atoms; valid_mask marks real atoms.
        if self.method == "mean":
            # Sum over atom slots divided by the per-token valid count.
            # NOTE(review): this is only a true masked mean if padded atom
            # slots in Q_IA are zero -- presumably guaranteed by
            # ungroup_atoms; confirm against block_utils.
            A_I_update = self.project(Q_IA).sum(-2) / valid_mask.sum(-1, keepdim=True)
        elif self.method == "cross_attention":
            assert exists(A_I) and exists(valid_mask)
            # Attention mask: ..., 1, n_atom_per_tok (1 querying token to atoms in token)
            attn_mask = valid_mask[..., None, :]
            A_I_update = self.gca(
                q=A_I[..., None, :], kv=Q_IA, attn_mask=attn_mask
            ).squeeze(-2)

        # Residual when token embeddings already exist; fresh otherwise.
        A_I = A_I + A_I_update if exists(A_I) else A_I_update

        if self.process_s is not None:
            A_I = A_I + self.process_s(S_I)
        return A_I

    def forward(self, Q_L, A_I, S_I=None, tok_idx=None):
        # Normalise everything to batched form; remember whether to squeeze back.
        valid_mask = build_valid_mask(tok_idx)
        if Q_L.ndim == 2:
            squeeze = True
            Q_L = Q_L.unsqueeze(0)
        else:
            squeeze = False

        A_I = A_I.unsqueeze(0) if exists(A_I) and A_I.ndim == 2 else A_I
        S_I = S_I.unsqueeze(0) if exists(S_I) and S_I.ndim == 2 else S_I

        # Group flat atoms into per-token slots: [B, I, A, c_atom].
        Q_IA = ungroup_atoms(Q_L, valid_mask)

        A_I = self.forward_(Q_IA, A_I, S_I, valid_mask=valid_mask)

        if squeeze:
            A_I = A_I.squeeze(0)
        return A_I
582
+
583
+
584
+ ######################################################################################
585
+ ########################## Local Atom Transformer ##########################
586
+ ######################################################################################
587
+
588
+
589
class LocalTokenTransformer(nn.Module):
    """Stack of pair-biased local attention blocks operating at token level."""

    def __init__(
        self,
        c_token,
        c_tokenpair,
        c_s,
        n_block,
        diffusion_transformer_block,
        n_registers=None,
        n_local_tokens=8,
        n_keys=32,
    ):
        super().__init__()
        self.n_local_tokens = n_local_tokens
        self.n_keys = n_keys
        blocks = []
        for _ in range(n_block):
            blocks.append(
                StructureLocalAtomTransformerBlock(
                    c_atom=c_token,
                    c_s=c_s,
                    c_atompair=c_tokenpair,
                    **diffusion_transformer_block,
                )
            )
        self.blocks = nn.ModuleList(blocks)

    def forward(self, A_I, S_I, Z_II, f, X_L, full=False):
        # Neighbourhood indices are computed once and shared by every block.
        indices = create_attention_indices(
            X_L=X_L,
            f=f,
            tok_idx=torch.arange(A_I.shape[1], device=A_I.device),
            n_attn_keys=self.n_keys,
            n_attn_seq_neighbours=self.n_local_tokens,
        )

        for block in self.blocks:
            # Toggle activation checkpointing according to the global flag.
            block.attention_pair_bias.use_checkpointing = not DISABLE_CHECKPOINTING
            # A_I: [B, I, C_token], S_I: [B, I, C_s], Z_II: [B, I, I, C_tokenpair]
            # `full=True` runs dense attention; does not accelerate inference,
            # but memory scales better in the local (non-full) mode.
            A_I = block(A_I, S_I, Z_II, indices=indices, full=full)

        return A_I
640
+
641
+
642
class LocalAtomTransformer(nn.Module):
    """Stack of pair-biased local attention blocks operating at atom level."""

    def __init__(self, c_atom, c_s, c_atompair, atom_transformer_block, n_blocks):
        super().__init__()
        self.blocks = nn.ModuleList(
            StructureLocalAtomTransformerBlock(
                c_atom=c_atom,
                c_s=c_s,
                c_atompair=c_atompair,
                **atom_transformer_block,
            )
            for _ in range(n_blocks)
        )

    def forward(self, Q_L, C_L, P_LL, **kwargs):
        # Sequentially refine the atom embeddings through every block.
        for layer in self.blocks:
            Q_L = layer(Q_L, C_L, P_LL, **kwargs)
        return Q_L
661
+
662
+
663
class StructureLocalAtomTransformerBlock(nn.Module):
    """One transformer block: pair-biased local attention followed by a
    (optionally conditioned) transition, each wrapped in a residual.

    Args:
        c_atom: Channel dim of the primary embeddings Q_L.
        c_s: Channel dim of the conditioning signal C_L; None selects an
            unconditioned Transition instead of the AdaLN-conditioned one.
        c_atompair: Channel dim of the pair representation P_LL.
        dropout: Dropout rate applied to the attention residual branch.
        no_residual_connection_between_attention_and_transition: Must be
            False; the no-residual variant is unsupported by this block.
    """

    def __init__(
        self,
        *,
        c_atom,
        c_s,
        c_atompair,
        dropout,
        no_residual_connection_between_attention_and_transition,
        **transformer_block,
    ):
        super().__init__()
        assert not no_residual_connection_between_attention_and_transition
        self.c_s = c_s
        self.dropout = nn.Dropout(dropout)
        self.attention_pair_bias = LocalAttentionPairBias(
            c_a=c_atom, c_s=c_s, c_pair=c_atompair, **transformer_block
        )
        # Conditioning channel present -> adaLN-conditioned transition.
        if exists(c_s):
            self.transition_block = ConditionedTransitionBlock(c_token=c_atom, c_s=c_s)
        else:
            self.transition_block = Transition(c=c_atom, n=4)

    def forward(
        self,
        Q_L,  # [..., I, C_token]
        C_L,  # [..., I, C_s] conditioning, may be None
        P_LL,  # [..., I, I, C_tokenpair]
        f=None,
        chunked_pairwise_embedder=None,
        initializer_outputs=None,
        **kwargs,
    ):
        # Residual pair-biased local attention with dropout on the branch.
        Q_L = Q_L + self.dropout(
            self.attention_pair_bias(
                Q_L,
                C_L,
                P_LL,
                f=f,
                chunked_pairwise_embedder=chunked_pairwise_embedder,
                initializer_outputs=initializer_outputs,
                **kwargs,
            )
        )
        # Residual transition; the conditioned form consumes C_L.
        if exists(C_L):
            Q_L = Q_L + self.transition_block(Q_L, C_L)
        else:
            Q_L = Q_L + self.transition_block(Q_L)
        return Q_L
712
+
713
+
714
class CompactStreamingDecoder(nn.Module):
    """Interleaves token->atom upcasts with local atom attention blocks, then
    pools the atoms back to tokens once at the end.

    Args:
        diffusion_transformer_block: Unused; accepted for config compatibility.
    """

    def __init__(
        self,
        c_atom,
        c_atompair,
        c_token,
        c_s,
        c_tokenpair,
        atom_transformer_block,
        upcast,
        downcast,
        n_blocks,
        diffusion_transformer_block=False,
    ):
        super().__init__()
        self.n_blocks = n_blocks

        # One upcast per block; the atom transformer is conditioned on the
        # atom-sized C_L (hence c_s=c_atom below).
        self.upcast = nn.ModuleList(
            [Upcast(c_atom=c_atom, c_token=c_token, **upcast) for _ in range(n_blocks)]
        )
        self.atom_transformer = nn.ModuleList(
            [
                StructureLocalAtomTransformerBlock(
                    c_atom=c_atom,
                    c_s=c_atom,
                    c_atompair=c_atompair,
                    **atom_transformer_block,
                )
                for _ in range(n_blocks)
            ]
        )
        self.downcast = Downcast(c_atom=c_atom, c_token=c_token, c_s=c_s, **downcast)

    def forward(
        self,
        A_I,
        S_I,
        Z_II,
        Q_L,
        C_L,
        P_LL,
        tok_idx,
        indices,
        f=None,
        chunked_pairwise_embedder=None,
        initializer_outputs=None,
    ):
        # Alternate token->atom information flow with local atom attention.
        for i in range(self.n_blocks):
            Q_L = self.upcast[i](Q_L, A_I, tok_idx=tok_idx)
            Q_L = self.atom_transformer[i](
                Q_L,
                C_L,
                P_LL,
                indices=indices,
                f=f,
                chunked_pairwise_embedder=chunked_pairwise_embedder,
                initializer_outputs=initializer_outputs,
            )

        # Downcast to sequence. Inputs are detached so this token update does
        # not backpropagate into the atom trunk above.
        A_I = self.downcast(Q_L.detach(), A_I.detach(), S_I.detach(), tok_idx=tok_idx)

        # Auxiliary-output dict; currently always empty.
        o = {}
        return A_I, Q_L, o