rc_foundry-0.1.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (180)
  1. foundry/__init__.py +57 -0
  2. foundry/callbacks/__init__.py +5 -0
  3. foundry/callbacks/callback.py +116 -0
  4. foundry/callbacks/health_logging.py +419 -0
  5. foundry/callbacks/metrics_logging.py +211 -0
  6. foundry/callbacks/timing_logging.py +67 -0
  7. foundry/callbacks/train_logging.py +278 -0
  8. foundry/common.py +108 -0
  9. foundry/constants.py +28 -0
  10. foundry/hydra/resolvers.py +77 -0
  11. foundry/inference_engines/base.py +235 -0
  12. foundry/inference_engines/checkpoint_registry.py +66 -0
  13. foundry/metrics/__init__.py +12 -0
  14. foundry/metrics/losses.py +30 -0
  15. foundry/metrics/metric.py +319 -0
  16. foundry/model/layers/blocks.py +47 -0
  17. foundry/testing/__init__.py +6 -0
  18. foundry/testing/fixtures.py +19 -0
  19. foundry/testing/pytest_hooks.py +15 -0
  20. foundry/trainers/fabric.py +923 -0
  21. foundry/training/EMA.py +67 -0
  22. foundry/training/checkpoint.py +61 -0
  23. foundry/training/schedulers.py +91 -0
  24. foundry/utils/alignment.py +86 -0
  25. foundry/utils/components.py +415 -0
  26. foundry/utils/datasets.py +405 -0
  27. foundry/utils/ddp.py +103 -0
  28. foundry/utils/instantiators.py +72 -0
  29. foundry/utils/logging.py +279 -0
  30. foundry/utils/rigid.py +1460 -0
  31. foundry/utils/rotation_augmentation.py +65 -0
  32. foundry/utils/squashfs.py +172 -0
  33. foundry/utils/torch.py +317 -0
  34. foundry/utils/weights.py +271 -0
  35. foundry/version.py +34 -0
  36. foundry_cli/__init__.py +3 -0
  37. foundry_cli/download_checkpoints.py +281 -0
  38. mpnn/__init__.py +1 -0
  39. mpnn/collate/feature_collator.py +265 -0
  40. mpnn/inference.py +53 -0
  41. mpnn/inference_engines/mpnn.py +549 -0
  42. mpnn/loss/nll_loss.py +122 -0
  43. mpnn/metrics/nll.py +369 -0
  44. mpnn/metrics/sequence_recovery.py +440 -0
  45. mpnn/model/layers/graph_embeddings.py +2372 -0
  46. mpnn/model/layers/message_passing.py +332 -0
  47. mpnn/model/layers/position_wise_feed_forward.py +44 -0
  48. mpnn/model/layers/positional_encoding.py +98 -0
  49. mpnn/model/mpnn.py +2632 -0
  50. mpnn/pipelines/mpnn.py +162 -0
  51. mpnn/samplers/samplers.py +167 -0
  52. mpnn/train.py +341 -0
  53. mpnn/trainers/mpnn.py +193 -0
  54. mpnn/transforms/feature_aggregation/mpnn.py +184 -0
  55. mpnn/transforms/feature_aggregation/polymer_ligand_interface.py +76 -0
  56. mpnn/transforms/feature_aggregation/token_encodings.py +132 -0
  57. mpnn/transforms/feature_aggregation/user_settings.py +347 -0
  58. mpnn/transforms/polymer_ligand_interface.py +164 -0
  59. mpnn/utils/inference.py +2397 -0
  60. mpnn/utils/probability.py +37 -0
  61. mpnn/utils/weights.py +309 -0
  62. rc_foundry-0.1.1.dist-info/METADATA +239 -0
  63. rc_foundry-0.1.1.dist-info/RECORD +180 -0
  64. rc_foundry-0.1.1.dist-info/WHEEL +4 -0
  65. rc_foundry-0.1.1.dist-info/entry_points.txt +5 -0
  66. rc_foundry-0.1.1.dist-info/licenses/LICENSE.md +28 -0
  67. rf3/__init__.py +3 -0
  68. rf3/_version.py +33 -0
  69. rf3/alignment.py +79 -0
  70. rf3/callbacks/dump_validation_structures.py +101 -0
  71. rf3/callbacks/metrics_logging.py +324 -0
  72. rf3/chemical.py +1529 -0
  73. rf3/cli.py +77 -0
  74. rf3/data/cyclic_transform.py +78 -0
  75. rf3/data/extra_xforms.py +36 -0
  76. rf3/data/ground_truth_template.py +463 -0
  77. rf3/data/paired_msa.py +206 -0
  78. rf3/data/pipeline_utils.py +128 -0
  79. rf3/data/pipelines.py +558 -0
  80. rf3/diffusion_samplers/inference_sampler.py +222 -0
  81. rf3/inference.py +65 -0
  82. rf3/inference_engines/__init__.py +5 -0
  83. rf3/inference_engines/rf3.py +735 -0
  84. rf3/kinematics.py +354 -0
  85. rf3/loss/af3_confidence_loss.py +515 -0
  86. rf3/loss/af3_losses.py +655 -0
  87. rf3/loss/loss.py +179 -0
  88. rf3/metrics/chiral.py +179 -0
  89. rf3/metrics/clashing_chains.py +68 -0
  90. rf3/metrics/distogram.py +421 -0
  91. rf3/metrics/lddt.py +523 -0
  92. rf3/metrics/metadata.py +43 -0
  93. rf3/metrics/metric_utils.py +192 -0
  94. rf3/metrics/predicted_error.py +134 -0
  95. rf3/metrics/rasa.py +108 -0
  96. rf3/metrics/selected_distances.py +91 -0
  97. rf3/model/RF3.py +527 -0
  98. rf3/model/RF3_blocks.py +92 -0
  99. rf3/model/RF3_structure.py +303 -0
  100. rf3/model/layers/af3_auxiliary_heads.py +255 -0
  101. rf3/model/layers/af3_diffusion_transformer.py +544 -0
  102. rf3/model/layers/attention.py +313 -0
  103. rf3/model/layers/layer_utils.py +127 -0
  104. rf3/model/layers/mlff.py +118 -0
  105. rf3/model/layers/outer_product.py +59 -0
  106. rf3/model/layers/pairformer_layers.py +783 -0
  107. rf3/model/layers/structure_bias.py +56 -0
  108. rf3/scoring.py +1787 -0
  109. rf3/symmetry/resolve.py +284 -0
  110. rf3/train.py +194 -0
  111. rf3/trainers/rf3.py +570 -0
  112. rf3/util_module.py +47 -0
  113. rf3/utils/frames.py +109 -0
  114. rf3/utils/inference.py +665 -0
  115. rf3/utils/io.py +198 -0
  116. rf3/utils/loss.py +72 -0
  117. rf3/utils/predict_and_score.py +165 -0
  118. rf3/utils/predicted_error.py +673 -0
  119. rf3/utils/recycling.py +42 -0
  120. rf3/validate.py +140 -0
  121. rfd3/.gitignore +7 -0
  122. rfd3/Makefile +76 -0
  123. rfd3/__init__.py +12 -0
  124. rfd3/callbacks.py +66 -0
  125. rfd3/cli.py +41 -0
  126. rfd3/constants.py +212 -0
  127. rfd3/engine.py +543 -0
  128. rfd3/inference/datasets.py +193 -0
  129. rfd3/inference/input_parsing.py +1123 -0
  130. rfd3/inference/legacy_input_parsing.py +717 -0
  131. rfd3/inference/parsing.py +165 -0
  132. rfd3/inference/symmetry/atom_array.py +298 -0
  133. rfd3/inference/symmetry/checks.py +241 -0
  134. rfd3/inference/symmetry/contigs.py +63 -0
  135. rfd3/inference/symmetry/frames.py +355 -0
  136. rfd3/inference/symmetry/symmetry_utils.py +398 -0
  137. rfd3/metrics/design_metrics.py +465 -0
  138. rfd3/metrics/hbonds_hbplus_metrics.py +308 -0
  139. rfd3/metrics/hbonds_metrics.py +389 -0
  140. rfd3/metrics/losses.py +325 -0
  141. rfd3/metrics/metrics_utils.py +118 -0
  142. rfd3/metrics/sidechain_metrics.py +349 -0
  143. rfd3/model/RFD3.py +105 -0
  144. rfd3/model/RFD3_diffusion_module.py +387 -0
  145. rfd3/model/cfg_utils.py +81 -0
  146. rfd3/model/inference_sampler.py +635 -0
  147. rfd3/model/layers/attention.py +577 -0
  148. rfd3/model/layers/block_utils.py +580 -0
  149. rfd3/model/layers/blocks.py +777 -0
  150. rfd3/model/layers/chunked_pairwise.py +377 -0
  151. rfd3/model/layers/encoders.py +417 -0
  152. rfd3/model/layers/layer_utils.py +197 -0
  153. rfd3/model/layers/pairformer_layers.py +128 -0
  154. rfd3/run_inference.py +45 -0
  155. rfd3/testing/debug.py +139 -0
  156. rfd3/testing/debug_utils.py +73 -0
  157. rfd3/testing/testing_utils.py +356 -0
  158. rfd3/train.py +194 -0
  159. rfd3/trainer/dump_validation_structures.py +154 -0
  160. rfd3/trainer/fabric_trainer.py +923 -0
  161. rfd3/trainer/recycling.py +42 -0
  162. rfd3/trainer/rfd3.py +485 -0
  163. rfd3/trainer/trainer_utils.py +502 -0
  164. rfd3/transforms/conditioning_base.py +508 -0
  165. rfd3/transforms/conditioning_utils.py +200 -0
  166. rfd3/transforms/design_transforms.py +807 -0
  167. rfd3/transforms/dna_crop.py +523 -0
  168. rfd3/transforms/hbonds.py +407 -0
  169. rfd3/transforms/hbonds_hbplus.py +246 -0
  170. rfd3/transforms/ncaa_transforms.py +153 -0
  171. rfd3/transforms/pipelines.py +632 -0
  172. rfd3/transforms/ppi_transforms.py +541 -0
  173. rfd3/transforms/rasa.py +116 -0
  174. rfd3/transforms/symmetry.py +76 -0
  175. rfd3/transforms/training_conditions.py +552 -0
  176. rfd3/transforms/util_transforms.py +498 -0
  177. rfd3/transforms/virtual_atoms.py +305 -0
  178. rfd3/utils/inference.py +648 -0
  179. rfd3/utils/io.py +245 -0
  180. rfd3/utils/vizualize.py +276 -0
rf3/model/layers/pairformer_layers.py
@@ -0,0 +1,783 @@
+ import torch
+ from rf3.data.ground_truth_template import (
+     af3_noise_scale_to_noise_level,
+ )
+ from rf3.model.layers.af3_diffusion_transformer import AtomTransformer
+ from rf3.model.layers.attention import TriangleAttention, TriangleMultiplication
+ from rf3.model.layers.layer_utils import (
+     MultiDimLinear,
+     Transition,
+     collapse,
+     create_batch_dimension_if_not_present,
+     linearNoBias,
+ )
+ from rf3.model.layers.mlff import ConformerEmbeddingWeightedAverage
+ from rf3.model.layers.outer_product import (
+     OuterProductMean_AF3,
+ )
+ from rf3.model.RF3_blocks import MSAPairWeightedAverage, MSASubsampleEmbedder
+ from torch import nn
+ from torch.nn.functional import one_hot, relu
+
+ from foundry.model.layers.blocks import Dropout
+ from foundry.training.checkpoint import activation_checkpointing
+
+
+ class AtomAttentionEncoderPairformer(nn.Module):
+     def __init__(
+         self,
+         c_atom,
+         c_atompair,
+         c_token,
+         c_tokenpair,
+         c_s,
+         atom_1d_features,
+         c_atom_1d_features,
+         atom_transformer,
+         use_inv_dist_squared: bool = False,  # HACK: For the 9/21 checkpoint, default to False (this argument was not present in the checkpoint config)
+         use_atom_level_embedding: bool = False,
+         atom_level_embedding_dim: int = 384,
+     ):
+         super().__init__()
+         self.c_atom = c_atom
+         self.c_atompair = c_atompair
+         self.c_token = c_token
+         self.c_tokenpair = c_tokenpair
+         self.c_s = c_s
+         self.atom_1d_features = atom_1d_features
+
+         self.process_input_features = linearNoBias(c_atom_1d_features, c_atom)
+
+         self.process_d = linearNoBias(3, c_atompair)
+         self.process_inverse_dist = linearNoBias(1, c_atompair)
+         self.process_valid_mask = linearNoBias(1, c_atompair)
+
+         self.use_atom_level_embedding = use_atom_level_embedding
+
+         self.process_single_l = nn.Sequential(
+             nn.ReLU(), linearNoBias(c_atom, c_atompair)
+         )
+         self.process_single_m = nn.Sequential(
+             nn.ReLU(), linearNoBias(c_atom, c_atompair)
+         )
+
+         self.pair_mlp = nn.Sequential(
+             nn.ReLU(),
+             linearNoBias(self.c_atompair, c_atompair),
+             nn.ReLU(),
+             linearNoBias(self.c_atompair, c_atompair),
+             nn.ReLU(),
+             linearNoBias(self.c_atompair, c_atompair),
+         )
+
+         self.process_q = nn.Sequential(
+             linearNoBias(c_atom, c_token),
+             nn.ReLU(),
+         )
+
+         self.atom_transformer = AtomTransformer(
+             c_atom=c_atom, c_atompair=c_atompair, **atom_transformer
+         )
+
+         self.use_inv_dist_squared = use_inv_dist_squared
+
+         if self.use_atom_level_embedding:
+             self.process_atom_level_embedding = ConformerEmbeddingWeightedAverage(
+                 atom_level_embedding_dim=atom_level_embedding_dim,
+                 c_atompair=c_atompair,
+                 c_atom=c_atom,
+             )
+
+     def forward(
+         self,
+         f,  # Dict (input feature dictionary)
+         R_L,  # [D, L, 3]
+         S_trunk_I,  # [..., I, C_S_trunk]
+         Z_II,  # [..., I, I, C_Z]
+     ):
+         assert R_L is None
+         assert S_trunk_I is None
+         assert Z_II is None
+
+         # ... get the number of atoms and tokens
+         tok_idx = f["atom_to_token_map"]
+         L = len(tok_idx)  # N_atom
+         I = tok_idx.max() + 1  # N_token
+
+         # ... flatten the last two dimensions of ref_atom_name_chars
+         # (the letter dimension and the one-hot encoding of the unicode character dimension)
+         f["ref_atom_name_chars"] = f["ref_atom_name_chars"].reshape(
+             L, -1
+         )  # [L, 4, 64] -> [L, 256], where L = N_atom
+
+         # Atom single conditioning (C_L): linearly embed concatenated per-atom features
+         # (e.g., ref_pos, ref_charge, ref_mask, ref_element, ref_atom_name_chars)
+         C_L = self.process_input_features(
+             torch.cat(
+                 tuple(
+                     collapse(f[feature_name], L)
+                     for feature_name in self.atom_1d_features
+                 ),
+                 dim=-1,
+             )
+         )  # [L, C_atom]
+
+         if self.use_atom_level_embedding:
+             assert "atom_level_embedding" in f
+             C_L = C_L + self.process_atom_level_embedding(f["atom_level_embedding"])
+
+         # Now we have the single conditioning (C_L) for each atom. We will:
+         # 1. Use C_L to initialize the pair atom representation
+         # 2. Pass C_L as a skip connection to the diffusion module
+
+         # Embed offsets between atom reference positions.
+         # ref_pos has shape [L, 3], so ref_pos.unsqueeze(-2) has shape [L, 1, 3] and
+         # ref_pos.unsqueeze(-3) has shape [1, L, 3]; broadcasting both to [L, L, 3]
+         # and taking the difference yields the outer difference.
+         D_LL = f["ref_pos"].unsqueeze(-2) - f["ref_pos"].unsqueeze(
+             -3
+         )  # [L, 1, 3] - [1, L, 3] -> [L, L, 3]
+
+         # Create a mask indicating whether two atoms are on the same chain AND the
+         # same residue (i.e., share the same ref_space_uid).
+         # (We add a singleton dimension to make the mask broadcastable with D_LL,
+         # which will be useful later.)
+         V_LL = (
+             f["ref_space_uid"].unsqueeze(-1) == f["ref_space_uid"].unsqueeze(-2)
+         ).unsqueeze(-1)  # [L, 1] == [1, L] -> [L, L, 1]
+
+         @activation_checkpointing
+         def embed_features(C_L, D_LL, V_LL):
+             P_LL = self.process_d(D_LL) * V_LL  # [L, L, 3] -> [L, L, C_atompair]
+
+             # Embed pairwise inverse (squared) distances, masked to valid pairs
+             if self.use_inv_dist_squared:
+                 P_LL += (
+                     self.process_inverse_dist(
+                         1 / (1 + torch.sum(D_LL * D_LL, dim=-1, keepdim=True))
+                     )
+                     * V_LL
+                 )  # [L, L, 1] -> [L, L, C_atompair]
+             else:
+                 P_LL = (
+                     P_LL
+                     + self.process_inverse_dist(
+                         1 / (1 + torch.linalg.norm(D_LL, dim=-1, keepdim=True))
+                     )
+                     * V_LL
+                 )  # [L, L, 1] -> [L, L, C_atompair]
+
+             P_LL = P_LL + self.process_valid_mask(V_LL.to(P_LL.dtype)) * V_LL
+
+             # Initialise the atom single representation as the single conditioning.
+             # NOTE: This creates a new view on the tensor, so the original is not
+             # modified (unless an in-place operation is performed).
+             Q_L = C_L
+
+             # Add the combined single conditioning to the pair representation
+             # (with a residual connection).
+             P_LL = P_LL + (
+                 self.process_single_l(C_L).unsqueeze(-2)
+                 + self.process_single_m(C_L).unsqueeze(-3)
+             )  # [L, 1, C_atompair] + [1, L, C_atompair] -> [L, L, C_atompair]
+
+             # Run a small MLP on the pair activations
+             # (with a residual connection).
+             P_LL = P_LL + self.pair_mlp(
+                 P_LL
+             )  # [L, L, C_atompair] -> [L, L, C_atompair]
+
+             # Cross-attention transformer
+             Q_L = self.atom_transformer(Q_L, C_L, P_LL)  # [L, C_atom]
+
+             # ... get the desired shape of the per-token representation, [I, C_token]
+             A_I_shape = Q_L.shape[:-2] + (I, self.c_token)
+
+             # Aggregate the per-atom representation to a per-token representation
+             # (each token becomes the mean activation of the atoms in that token).
+             processed_Q_L = self.process_q(Q_L)  # [L, C_atom] -> [L, C_token]
+             # Ensure dtype consistency for index_reduce
+             processed_Q_L = processed_Q_L.to(Q_L.dtype)
+
+             A_I = torch.zeros(
+                 A_I_shape, device=Q_L.device, dtype=Q_L.dtype
+             ).index_reduce(
+                 -2,  # operate on the second-to-last dimension (the atom dimension)
+                 f["atom_to_token_map"].long(),  # [L], atom index -> token index; must be a torch.int64 or torch.int32 tensor
+                 processed_Q_L,  # [L, C_token]
+                 "mean",
+                 include_self=False,  # do not use the original values in A_I (all zeros) when aggregating
+             )  # [L, C_token] -> [I, C_token]
+
+             return A_I, Q_L, C_L, P_LL
+
+         return embed_features(C_L, D_LL, V_LL)
+
+
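# A minimal, self-contained sketch of the atom -> token aggregation used above:
# index_reduce with "mean" and include_self=False averages all atom rows that map
# to the same token index. The names below (atom_feats, atom_to_token) are
# illustrative, not from the package.
import torch

L, I, C = 5, 2, 4                              # 5 atoms, 2 tokens, 4 channels
atom_feats = torch.randn(L, C)                 # per-atom features, [L, C]
atom_to_token = torch.tensor([0, 0, 0, 1, 1])  # token index per atom, [L]

token_feats = torch.zeros(I, C).index_reduce(
    0,                   # reduce over the atom dimension
    atom_to_token,       # integer index tensor mapping atoms to tokens
    atom_feats,
    "mean",
    include_self=False,  # ignore the zero-initialised destination values
)
# token 0 is the mean of atoms 0-2, token 1 the mean of atoms 3-4
assert torch.allclose(token_feats[0], atom_feats[:3].mean(0))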
+ class AttentionPairBiasPairformerDeepspeed(nn.Module):
+     def __init__(self, c_a, c_s, c_pair, n_head):
+         super().__init__()
+         self.n_head = n_head
+         self.c_a = c_a
+         self.c_pair = c_pair
+         self.c = c_a // n_head
+
+         self.to_q = MultiDimLinear(c_a, (n_head, self.c), bias=False)
+         self.to_k = MultiDimLinear(c_a, (n_head, self.c), bias=False)
+         self.to_v = MultiDimLinear(c_a, (n_head, self.c), bias=False)
+         self.to_b = linearNoBias(c_pair, n_head)
+         self.to_g = nn.Sequential(
+             MultiDimLinear(c_a, (n_head, self.c), bias=False),
+             nn.Sigmoid(),
+         )
+         self.to_a = linearNoBias(c_a, c_a)
+         # self.linear_output_project = nn.Sequential(
+         #     LinearBiasInit(c_s, c_a, biasinit=-2.),
+         #     nn.Sigmoid(),
+         # )
+         self.ln_0 = nn.LayerNorm((c_pair,))
+         # self.ada_ln_1 = AdaLN(c_a=c_a, c_s=c_s)
+         self.ln_1 = nn.LayerNorm((c_a,))
+         self.use_deepspeed_evo = False
+         self.force_bfloat16 = True
+
+     def forward(
+         self,
+         A_I,  # [I, C_a]
+         S_I,  # [I, C_a] | None
+         Z_II,  # [I, I, C_z]
+         Beta_II=None,  # [I, I]
+     ):
+         # Input projections
+         assert S_I is None
+         A_I = self.ln_1(A_I)
+
+         if self.use_deepspeed_evo or self.force_bfloat16:
+             A_I = A_I.to(torch.bfloat16)
+
+         Q_IH = self.to_q(A_I)  # / np.sqrt(self.c)
+         K_IH = self.to_k(A_I)
+         V_IH = self.to_v(A_I)
+         B_IIH = self.to_b(self.ln_0(Z_II)) + Beta_II[..., None]
+         G_IH = self.to_g(A_I)
+
+         B, L = B_IIH.shape[:2]
+
+         if not self.use_deepspeed_evo or L <= 24:
+             Q_IH = Q_IH / torch.sqrt(
+                 torch.tensor(self.c).to(Q_IH.device, torch.bfloat16)
+             )
+             # Attention
+             A_IIH = torch.softmax(
+                 torch.einsum("...ihd,...jhd->...ijh", Q_IH, K_IH) + B_IIH, dim=-2
+             )  # softmax over j
+             ## G_IH: [I, H, C]
+             ## A_IIH: [I, I, H]
+             ## V_IH: [I, H, C]
+             A_I = torch.einsum("...ijh,...jhc->...ihc", A_IIH, V_IH)
+             A_I = G_IH * A_I  # [B, I, H, C]
+             A_I = A_I.flatten(start_dim=-2)  # [B, I, C_a]
+         else:
+             # DS4Sci_EvoformerAttention expects:
+             #   Q, K, V: [Batch, N_seq, N_res, Head, Dim]
+             #   res_mask: [Batch, N_seq, 1, 1, N_res]
+             #   pair_bias: [Batch, 1, Head, N_res, N_res]
+             from deepspeed.ops.deepspeed4science import DS4Sci_EvoformerAttention
+
+             assert Q_IH.shape[0] != 1, "this code assumes your structure is not batched"
+             batch = 1
+             n_res = Q_IH.shape[0]
+             n_head = self.n_head
+             c = self.c
+
+             Q_IH = Q_IH[None, None]
+             K_IH = K_IH[None, None]
+             V_IH = V_IH[None, None]
+             B_IIH = B_IIH.repeat(Q_IH.shape[0], 1, 1, 1)
+             B_IIH = B_IIH[:, None]
+             B_IIH = B_IIH.permute(0, 1, 4, 2, 3).to(torch.bfloat16)
+             mask = torch.zeros(
+                 [Q_IH.shape[0], 1, 1, 1, B_IIH.shape[-1]],
+                 dtype=torch.bfloat16,
+                 device=B_IIH.device,
+             )
+
+             assert Q_IH.shape == (batch, 1, n_res, n_head, c)
+             assert K_IH.shape == (batch, 1, n_res, n_head, c)
+             assert V_IH.shape == (batch, 1, n_res, n_head, c)
+             assert mask.shape == (batch, 1, 1, 1, n_res)
+             assert B_IIH.shape == (batch, 1, n_head, n_res, n_res)
+
+             A_I = DS4Sci_EvoformerAttention(Q_IH, K_IH, V_IH, [mask, B_IIH])
+
+             assert A_I.shape == (batch, 1, n_res, n_head, c)
+             A_I = A_I * G_IH[None, None]
+             A_I = A_I.view(n_res, -1)
+
+         A_I = self.to_a(A_I)
+
+         return A_I
+
+
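# A small sketch of the attention-with-pair-bias pattern that the non-DeepSpeed
# branch above implements: standard scaled dot-product logits plus a per-head
# bias projected from the pair representation. Shapes and names are illustrative.
import math

import torch

I, H, C = 6, 2, 8                 # tokens, heads, per-head channels
q, k, v = (torch.randn(I, H, C) for _ in range(3))
pair_bias = torch.randn(I, I, H)  # e.g. a linear projection of Z_II

logits = torch.einsum("ihd,jhd->ijh", q, k) / math.sqrt(C) + pair_bias
attn = logits.softmax(dim=-2)     # normalise over the j (key) axis
out = torch.einsum("ijh,jhc->ihc", attn, v).flatten(start_dim=-2)  # [I, H * C]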
+ class PairformerBlock(nn.Module):
+     """
+     Attempt to replicate the AF3 architecture from scratch.
+     """
+
+     def __init__(
+         self,
+         c_s,
+         c_z,
+         p_drop,
+         triangle_multiplication,
+         triangle_attention,
+         attention_pair_bias,
+         n_transition=4,
+         **kwargs,  # catch-all for backwards compatibility
+     ):
+         super().__init__()
+
+         self.drop_row = Dropout(broadcast_dim=-2, p_drop=p_drop)
+         self.drop_col = Dropout(broadcast_dim=-3, p_drop=p_drop)
+
+         self.tri_mul_outgoing = TriangleMultiplication(
+             d_pair=c_z,
+             d_hidden=triangle_multiplication["d_hidden"],
+             direction="outgoing",
+             bias=True,
+             use_cuequivariance=True,
+         )
+         self.tri_mul_incoming = TriangleMultiplication(
+             d_pair=c_z,
+             d_hidden=triangle_multiplication["d_hidden"],
+             direction="incoming",
+             bias=True,
+             use_cuequivariance=True,
+         )
+
+         self.tri_attn_start = TriangleAttention(
+             c_z,
+             **triangle_attention,
+             start_node=True,
+             use_cuequivariance=True,
+         )
+         self.tri_attn_end = TriangleAttention(
+             c_z,
+             **triangle_attention,
+             start_node=False,
+             use_cuequivariance=True,
+         )
+
+         self.z_transition = Transition(c=c_z, n=n_transition)
+
+         # Single-representation updates are only needed when a single
+         # representation is present (c_s > 0; S_I is not None in forward)
+         if c_s > 0:
+             self.s_transition = Transition(c=c_s, n=n_transition)
+             self.attention_pair_bias = AttentionPairBiasPairformerDeepspeed(
+                 c_a=c_s, c_s=0, c_pair=c_z, **attention_pair_bias
+             )
+
+         triangle_operations_expected_dim = 4  # B, L, L, C
+         self.maybe_make_batched = create_batch_dimension_if_not_present(
+             triangle_operations_expected_dim
+         )
+
+     @activation_checkpointing
+     def forward(self, S_I, Z_II):
+         Z_II = Z_II + self.drop_row(
+             self.maybe_make_batched(self.tri_mul_outgoing)(Z_II)
+         )
+         Z_II = Z_II + self.drop_row(
+             self.maybe_make_batched(self.tri_mul_incoming)(Z_II)
+         )
+         Z_II = Z_II + self.drop_row(self.maybe_make_batched(self.tri_attn_start)(Z_II))
+         Z_II = Z_II + self.drop_col(self.maybe_make_batched(self.tri_attn_end)(Z_II))
+         Z_II = Z_II + self.z_transition(Z_II)
+         if S_I is not None:
+             S_I = S_I + self.attention_pair_bias(
+                 S_I, None, Z_II, Beta_II=torch.tensor([0.0], device=Z_II.device)
+             )
+             S_I = S_I + self.s_transition(S_I)
+
+         return S_I, Z_II
+
+
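# The drop_row/drop_col wrappers above apply dropout whose mask is shared along
# one axis of the pair tensor (row- or column-wise dropout, as in the AF2/AF3
# pairformer stack). A rough functional sketch, assuming the foundry Dropout
# behaves along these lines:
import torch

def broadcast_dropout(x: torch.Tensor, p: float, broadcast_dim: int) -> torch.Tensor:
    mask_shape = list(x.shape)
    mask_shape[broadcast_dim] = 1  # one Bernoulli draw shared along this axis
    keep = (torch.rand(mask_shape, device=x.device) > p).to(x.dtype)
    return x * keep / (1.0 - p)    # inverted-dropout rescaling

z = torch.randn(1, 4, 4, 8)        # [B, I, I, C] pair activations
z_row_dropped = broadcast_dropout(z, p=0.25, broadcast_dim=-2)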
+ class FeatureInitializer(nn.Module):
+     def __init__(
+         self,
+         c_s,
+         c_z,
+         c_atom,
+         c_atompair,
+         c_s_inputs,
+         input_feature_embedder,
+         relative_position_encoding,
+     ):
+         super().__init__()
+         self.input_feature_embedder = InputFeatureEmbedder(
+             c_atom=c_atom, c_atompair=c_atompair, **input_feature_embedder
+         )
+         self.to_s_init = linearNoBias(c_s_inputs, c_s)
+         self.to_z_init_i = linearNoBias(c_s_inputs, c_z)
+         self.to_z_init_j = linearNoBias(c_s_inputs, c_z)
+         self.relative_position_encoding = RelativePositionEncoding(
+             c_z=c_z, **relative_position_encoding
+         )
+         self.process_token_bonds = linearNoBias(1, c_z)
+
+     def forward(
+         self,
+         f,
+     ):
+         S_inputs_I = self.input_feature_embedder(f)
+         S_init_I = self.to_s_init(S_inputs_I)
+         Z_init_II = self.to_z_init_i(S_inputs_I).unsqueeze(-3) + self.to_z_init_j(
+             S_inputs_I
+         ).unsqueeze(-2)
+         Z_init_II = Z_init_II + self.relative_position_encoding(f)
+         Z_init_II = Z_init_II + self.process_token_bonds(
+             f["token_bonds"].unsqueeze(-1).to(torch.float)
+         )
+         return S_inputs_I, S_init_I, Z_init_II
+
+
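# The pair initialisation above is an "outer sum": two linear projections of the
# single representation are broadcast against each other to give an [I, I, C_z]
# tensor. A minimal illustration with made-up sizes:
import torch

I, C_s_inputs, C_z = 4, 16, 8
s = torch.randn(I, C_s_inputs)
to_z_i = torch.nn.Linear(C_s_inputs, C_z, bias=False)
to_z_j = torch.nn.Linear(C_s_inputs, C_z, bias=False)

# [1, I, C_z] + [I, 1, C_z] broadcasts to [I, I, C_z]
z_init = to_z_i(s).unsqueeze(-3) + to_z_j(s).unsqueeze(-2)
assert z_init.shape == (I, I, C_z)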
+ class InputFeatureEmbedder(nn.Module):
+     def __init__(self, features, c_atom, c_atompair, atom_attention_encoder):
+         super().__init__()
+         self.atom_attention_encoder = AtomAttentionEncoderPairformer(
+             c_atom=c_atom, c_atompair=c_atompair, c_s=0, **atom_attention_encoder
+         )
+         self.features = features
+         self.features_to_unsqueeze = ["deletion_mean"]
+
+     def forward(
+         self,
+         f,
+     ):
+         A_I, _, _, _ = self.atom_attention_encoder(f, None, None, None)
+         S_I = torch.cat(
+             [A_I.squeeze(0)]
+             + [
+                 f[feature].unsqueeze(-1)
+                 if feature in self.features_to_unsqueeze
+                 else f[feature]
+                 for feature in self.features
+             ],
+             dim=-1,
+         )
+         return S_I
+
+
+ class RelativePositionEncoding(nn.Module):
+     def __init__(self, r_max, s_max, c_z):
+         super().__init__()
+         self.r_max = r_max
+         self.s_max = s_max
+         self.c_z = c_z
+         self.linear = linearNoBias(
+             2 * (2 * self.r_max + 2) + (2 * self.s_max + 2) + 1, c_z
+         )
+
+     def forward(self, f):
+         b_samechain_II = f["asym_id"].unsqueeze(-1) == f["asym_id"].unsqueeze(-2)
+         b_sameresidue_II = f["residue_index"].unsqueeze(-1) == f[
+             "residue_index"
+         ].unsqueeze(-2)
+         b_same_entity_II = f["entity_id"].unsqueeze(-1) == f["entity_id"].unsqueeze(-2)
+
+         # Handle cyclic chains
+         cyclic_asym_ids = f.get("cyclic_asym_ids", [])
+         if len(cyclic_asym_ids) > 0:
+             offset = f["residue_index"].unsqueeze(-1) - f["residue_index"].unsqueeze(-2)
+
+             for cyclic_asym_id in cyclic_asym_ids:
+                 len_cyclic_chain = (
+                     f["residue_index"][f["asym_id"] == cyclic_asym_id].unique().shape[0]
+                 )
+                 cyclic_chain_mask = (f["asym_id"].unsqueeze(-1) == cyclic_asym_id) & (
+                     f["asym_id"].unsqueeze(-2) == cyclic_asym_id
+                 )
+
+                 # Cyclic offset: of {offset, offset + N, offset - N}, keep the
+                 # candidate with the smallest absolute value
+                 if len_cyclic_chain > 0:
+                     offset_plus = offset + len_cyclic_chain
+                     offset_minus = offset - len_cyclic_chain
+                     abs_offset = offset.abs()
+                     abs_offset_plus = offset_plus.abs()
+                     abs_offset_minus = offset_minus.abs()
+
+                     choice_plus_or_minus = torch.where(
+                         abs_offset_plus <= abs_offset_minus, offset_plus, offset_minus
+                     )
+                     c_offset = torch.where(
+                         (abs_offset <= abs_offset_plus)
+                         & (abs_offset <= abs_offset_minus),
+                         offset,
+                         choice_plus_or_minus,
+                     )
+                     offset = torch.where(cyclic_chain_mask, c_offset, offset)
+
+             offset = (offset + self.r_max).clamp(0, 2 * self.r_max)
+             d_residue_II = torch.where(
+                 b_samechain_II, offset, (2 * self.r_max + 1) * torch.ones_like(offset)
+             )
+
+         else:
+             d_residue_II = torch.where(
+                 b_samechain_II,
+                 torch.clip(
+                     f["residue_index"].unsqueeze(-1)
+                     - f["residue_index"].unsqueeze(-2)
+                     + self.r_max,
+                     0,
+                     2 * self.r_max,
+                 ),
+                 2 * self.r_max + 1,
+             )
+
+         A_relpos_II = one_hot(d_residue_II.long(), 2 * self.r_max + 2)
+         d_token_II = torch.where(
+             b_samechain_II * b_sameresidue_II,
+             torch.clip(
+                 f["token_index"].unsqueeze(-1)
+                 - f["token_index"].unsqueeze(-2)
+                 + self.r_max,
+                 0,
+                 2 * self.r_max,
+             ),
+             2 * self.r_max + 1,
+         )
+         A_reltoken_II = one_hot(d_token_II.long(), 2 * self.r_max + 2)
+         d_chain_II = torch.where(
+             # NOTE: Implements the bugfix from the Protenix technical report: use
+             # `same_entity` instead of `not same_chain` (as in the AF-3 pseudocode).
+             # Reference: https://github.com/bytedance/Protenix/blob/main/Protenix_Technical_Report.pdf
+             b_same_entity_II,
+             torch.clip(
+                 f["sym_id"].unsqueeze(-1) - f["sym_id"].unsqueeze(-2) + self.s_max,
+                 0,
+                 2 * self.s_max,
+             ),
+             2 * self.s_max + 1,
+         )
+         A_relchain_II = one_hot(d_chain_II.long(), 2 * self.s_max + 2)
+         return self.linear(
+             torch.cat(
+                 [
+                     A_relpos_II,
+                     A_reltoken_II,
+                     b_same_entity_II.unsqueeze(-1),
+                     A_relchain_II,
+                 ],
+                 dim=-1,
+             ).to(torch.float)
+         )
+
+
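# For cyclic chains, the branch above replaces the linear residue offset with the
# signed offset of smallest magnitude around the ring: for each pair it considers
# offset, offset + N and offset - N (N = chain length) and keeps whichever is
# closest to zero. A tiny worked example with N = 10:
import torch

N = 10
idx = torch.arange(N)
offset = idx[:, None] - idx[None, :]  # plain offsets, in [-9, 9]

plus, minus = offset + N, offset - N
wrap = torch.where(plus.abs() <= minus.abs(), plus, minus)
cyclic = torch.where(
    (offset.abs() <= plus.abs()) & (offset.abs() <= minus.abs()), offset, wrap
)
# residues 0 and 9 are neighbours on the ring: offset 9 becomes -1
assert cyclic[9, 0].item() == -1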
+ class MSAModule(nn.Module):
+     def __init__(
+         self,
+         n_block,
+         c_m,
+         p_drop_msa,
+         p_drop_pair,
+         msa_subsample_embedder,
+         outer_product,
+         msa_pair_weighted_averaging,
+         msa_transition,
+         triangle_multiplication_outgoing,
+         triangle_multiplication_incoming,
+         triangle_attention_starting,
+         triangle_attention_ending,
+         pair_transition,
+     ):
+         super().__init__()
+         self.n_block = n_block
+         self.msa_subsampler = MSASubsampleEmbedder(**msa_subsample_embedder)
+         self.outer_product = OuterProductMean_AF3(**outer_product)
+         self.msa_pair_weighted_averaging = MSAPairWeightedAverage(
+             **msa_pair_weighted_averaging
+         )
+         self.msa_transition = Transition(**msa_transition)
+
+         self.drop_row_msa = Dropout(broadcast_dim=-2, p_drop=p_drop_msa)
+         self.drop_row_pair = Dropout(broadcast_dim=-2, p_drop=p_drop_pair)
+         self.drop_col_pair = Dropout(broadcast_dim=-3, p_drop=p_drop_pair)
+
+         self.tri_mult_outgoing = TriangleMultiplication(
+             d_pair=triangle_multiplication_outgoing["d_pair"],
+             d_hidden=triangle_multiplication_outgoing["d_hidden"],
+             direction="outgoing",
+             bias=True,
+             use_cuequivariance=True,
+         )
+         self.tri_mult_incoming = TriangleMultiplication(
+             d_pair=triangle_multiplication_incoming["d_pair"],
+             d_hidden=triangle_multiplication_incoming["d_hidden"],
+             direction="incoming",
+             bias=True,
+             use_cuequivariance=True,
+         )
+         self.tri_attn_start = TriangleAttention(
+             **triangle_attention_starting, start_node=True, use_cuequivariance=True
+         )
+         self.tri_attn_end = TriangleAttention(
+             **triangle_attention_ending, start_node=False, use_cuequivariance=True
+         )
+         self.pair_transition = Transition(**pair_transition)
+
+         outer_product_expected_dim = 4  # B, S, I, C
+         self.maybe_make_batched_outer_product = create_batch_dimension_if_not_present(
+             outer_product_expected_dim
+         )
+
+         triangle_ops_expected_dim = 4  # B, I, I, C
+         self.maybe_make_batched_triangle_ops = create_batch_dimension_if_not_present(
+             triangle_ops_expected_dim
+         )
+
+     @activation_checkpointing
+     def forward(
+         self,
+         f,
+         Z_II,
+         S_inputs_I,
+     ):
+         msa = f["msa"]
+         msa_SI = self.msa_subsampler(msa, S_inputs_I)
+
+         for i in range(self.n_block):
+             # update MSA features
+             Z_II = Z_II + self.maybe_make_batched_outer_product(self.outer_product)(
+                 msa_SI
+             )
+             msa_SI = msa_SI + self.drop_row_msa(
+                 self.msa_pair_weighted_averaging(msa_SI, Z_II)
+             )
+             msa_SI = msa_SI + self.msa_transition(msa_SI)
+
+             # update pair features
+             Z_II = Z_II + self.drop_row_pair(
+                 self.maybe_make_batched_triangle_ops(self.tri_mult_outgoing)(Z_II)
+             )
+             Z_II = Z_II + self.drop_row_pair(
+                 self.maybe_make_batched_triangle_ops(self.tri_mult_incoming)(Z_II)
+             )
+
+             Z_II = Z_II + self.drop_row_pair(
+                 self.maybe_make_batched_triangle_ops(self.tri_attn_start)(Z_II)
+             )
+             Z_II = Z_II + self.drop_col_pair(
+                 self.maybe_make_batched_triangle_ops(self.tri_attn_end)(Z_II)
+             )
+             Z_II = Z_II + self.pair_transition(Z_II)
+
+         return Z_II
+
+
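# The outer-product update above is how MSA information reaches the pair track:
# for each token pair (i, j), the outer product of MSA columns i and j is averaged
# over sequences and projected to C_z. A rough sketch of the idea only; the real
# OuterProductMean_AF3 adds normalisation and per-column projections.
import torch

S, I, C_m, C_z = 3, 4, 8, 16  # sequences, tokens, channels
msa = torch.randn(S, I, C_m)
proj = torch.nn.Linear(C_m * C_m, C_z, bias=False)

outer = torch.einsum("sic,sjd->ijcd", msa, msa) / S  # mean over sequences
z_update = proj(outer.reshape(I, I, C_m * C_m))      # [I, I, C_z]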
+ class RF3TemplateEmbedder(nn.Module):
+     """
+     Template track that enables conditioning on noisy ground-truth templates at
+     the token level. Supports all chain types.
+     """
+
+     def __init__(
+         self,
+         n_block,
+         raw_template_dim,
+         c_z,
+         c,
+         p_drop,
+         use_fourier_encoding: bool = False,  # HACK: Unused, kept for backwards compatibility with the 9/21 checkpoint
+     ):
+         super().__init__()
+         self.c = c
+         self.emb_pair = nn.Linear(c_z, c, bias=False)
+         self.norm_pair_before_pairformer = nn.LayerNorm(c_z)
+         self.norm_after_pairformer = nn.LayerNorm(c)
+         self.emb_templ = nn.Linear(raw_template_dim, c, bias=False)
+
+         # The template pairformer does not operate on the sequence representation
+         self.pairformer = nn.ModuleList(
+             [
+                 PairformerBlock(
+                     c_s=0,
+                     c_z=c,
+                     p_drop=p_drop,
+                     triangle_multiplication=dict(d_hidden=c),
+                     triangle_attention=dict(d_hidden=c),
+                     attention_pair_bias={},
+                     n_transition=4,
+                 )
+                 for _ in range(n_block)
+             ]
+         )
+
+         # NOTE: This is not consistent with the AF3 paper, which outputs this tensor
+         # in the template_channel dimension. In Algorithm 1, line 9, the outputs of
+         # this function are added to the Z_II tensor, which has dimensions
+         # [B, I, I, C_z], so we make the outputs of this module match those dimensions.
+         self.agg_emb = nn.Linear(c, c_z, bias=False)
+
+     def forward(
+         self,
+         f,
+         Z_II,
+     ):
+         @activation_checkpointing
+         def embed_templates_like_rf3(
+             has_distogram_condition,  # [I, I]
+             distogram_condition_noise_scale,  # [I]
+             distogram_condition,  # [I, I, 64], where 64 is the number of distogram bins
+         ):
+             I = Z_II.shape[0]  # n_tokens
+
+             # Transform the noise scale to a reasonable range
+             joint_noise_scale = (
+                 distogram_condition_noise_scale[None, :] ** 2
+                 + distogram_condition_noise_scale[:, None] ** 2
+             ).sqrt()
+             joint_noise_level = af3_noise_scale_to_noise_level(joint_noise_scale)
+
+             # ---------------------------- #
+
+             # ... concatenate along the channel dimension
+             template_feats = torch.cat(
+                 [
+                     distogram_condition,  # [I, I, 64]
+                     has_distogram_condition.unsqueeze(-1),  # [I, I, 1]
+                     joint_noise_level.unsqueeze(-1),  # [I, I, 1]
+                 ],
+                 dim=-1,
+             )  # [I, I, 66]
+
+             # ... remove any invalid interactions
+             template_feats = template_feats * has_distogram_condition.unsqueeze(
+                 -1
+             )  # [I, I, 66], where 66 = 64 + 1 + 1
+
+             # ... embed template features
+             template_channels = self.emb_templ(template_feats)  # [I, I, c]
+
+             # ---------------------------- #
+
+             # ... pass through the pairformer
+             u_II = torch.zeros(I, I, self.c, device=Z_II.device)
+             v_II = (
+                 self.emb_pair(self.norm_pair_before_pairformer(Z_II))
+                 + template_channels
+             )  # [I, I, c]
+             for block in self.pairformer:
+                 _, v_II = block(None, v_II)
+                 u_II = u_II + self.norm_after_pairformer(v_II)
+
+             return self.agg_emb(relu(u_II))
+
+         # Ground-truth template embedding (noisy ground-truth template as input)
+         embedded_templates = embed_templates_like_rf3(
+             has_distogram_condition=f["has_distogram_condition"],  # [I, I]
+             distogram_condition_noise_scale=f["distogram_condition_noise_scale"],  # [I]
+             distogram_condition=f["distogram_condition"],  # [I, I, 64]
+         )
+
+         return embedded_templates
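# The per-token noise scales above are combined into a pairwise scale in
# quadrature before being mapped to a noise level by the package's
# af3_noise_scale_to_noise_level; only the quadrature step is illustrated here.
import torch

noise_scale = torch.tensor([0.0, 1.0, 2.0])  # per-token noise scales, [I]
joint = (noise_scale[None, :] ** 2 + noise_scale[:, None] ** 2).sqrt()
# joint[i, j] = sqrt(scale_i**2 + scale_j**2), e.g. joint[1, 2] = sqrt(5)
assert torch.isclose(joint[1, 2], torch.tensor(5.0).sqrt())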