rc_foundry-0.1.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (180)
  1. foundry/__init__.py +57 -0
  2. foundry/callbacks/__init__.py +5 -0
  3. foundry/callbacks/callback.py +116 -0
  4. foundry/callbacks/health_logging.py +419 -0
  5. foundry/callbacks/metrics_logging.py +211 -0
  6. foundry/callbacks/timing_logging.py +67 -0
  7. foundry/callbacks/train_logging.py +278 -0
  8. foundry/common.py +108 -0
  9. foundry/constants.py +28 -0
  10. foundry/hydra/resolvers.py +77 -0
  11. foundry/inference_engines/base.py +235 -0
  12. foundry/inference_engines/checkpoint_registry.py +66 -0
  13. foundry/metrics/__init__.py +12 -0
  14. foundry/metrics/losses.py +30 -0
  15. foundry/metrics/metric.py +319 -0
  16. foundry/model/layers/blocks.py +47 -0
  17. foundry/testing/__init__.py +6 -0
  18. foundry/testing/fixtures.py +19 -0
  19. foundry/testing/pytest_hooks.py +15 -0
  20. foundry/trainers/fabric.py +923 -0
  21. foundry/training/EMA.py +67 -0
  22. foundry/training/checkpoint.py +61 -0
  23. foundry/training/schedulers.py +91 -0
  24. foundry/utils/alignment.py +86 -0
  25. foundry/utils/components.py +415 -0
  26. foundry/utils/datasets.py +405 -0
  27. foundry/utils/ddp.py +103 -0
  28. foundry/utils/instantiators.py +72 -0
  29. foundry/utils/logging.py +279 -0
  30. foundry/utils/rigid.py +1460 -0
  31. foundry/utils/rotation_augmentation.py +65 -0
  32. foundry/utils/squashfs.py +172 -0
  33. foundry/utils/torch.py +317 -0
  34. foundry/utils/weights.py +271 -0
  35. foundry/version.py +34 -0
  36. foundry_cli/__init__.py +3 -0
  37. foundry_cli/download_checkpoints.py +281 -0
  38. mpnn/__init__.py +1 -0
  39. mpnn/collate/feature_collator.py +265 -0
  40. mpnn/inference.py +53 -0
  41. mpnn/inference_engines/mpnn.py +549 -0
  42. mpnn/loss/nll_loss.py +122 -0
  43. mpnn/metrics/nll.py +369 -0
  44. mpnn/metrics/sequence_recovery.py +440 -0
  45. mpnn/model/layers/graph_embeddings.py +2372 -0
  46. mpnn/model/layers/message_passing.py +332 -0
  47. mpnn/model/layers/position_wise_feed_forward.py +44 -0
  48. mpnn/model/layers/positional_encoding.py +98 -0
  49. mpnn/model/mpnn.py +2632 -0
  50. mpnn/pipelines/mpnn.py +162 -0
  51. mpnn/samplers/samplers.py +167 -0
  52. mpnn/train.py +341 -0
  53. mpnn/trainers/mpnn.py +193 -0
  54. mpnn/transforms/feature_aggregation/mpnn.py +184 -0
  55. mpnn/transforms/feature_aggregation/polymer_ligand_interface.py +76 -0
  56. mpnn/transforms/feature_aggregation/token_encodings.py +132 -0
  57. mpnn/transforms/feature_aggregation/user_settings.py +347 -0
  58. mpnn/transforms/polymer_ligand_interface.py +164 -0
  59. mpnn/utils/inference.py +2397 -0
  60. mpnn/utils/probability.py +37 -0
  61. mpnn/utils/weights.py +309 -0
  62. rc_foundry-0.1.1.dist-info/METADATA +239 -0
  63. rc_foundry-0.1.1.dist-info/RECORD +180 -0
  64. rc_foundry-0.1.1.dist-info/WHEEL +4 -0
  65. rc_foundry-0.1.1.dist-info/entry_points.txt +5 -0
  66. rc_foundry-0.1.1.dist-info/licenses/LICENSE.md +28 -0
  67. rf3/__init__.py +3 -0
  68. rf3/_version.py +33 -0
  69. rf3/alignment.py +79 -0
  70. rf3/callbacks/dump_validation_structures.py +101 -0
  71. rf3/callbacks/metrics_logging.py +324 -0
  72. rf3/chemical.py +1529 -0
  73. rf3/cli.py +77 -0
  74. rf3/data/cyclic_transform.py +78 -0
  75. rf3/data/extra_xforms.py +36 -0
  76. rf3/data/ground_truth_template.py +463 -0
  77. rf3/data/paired_msa.py +206 -0
  78. rf3/data/pipeline_utils.py +128 -0
  79. rf3/data/pipelines.py +558 -0
  80. rf3/diffusion_samplers/inference_sampler.py +222 -0
  81. rf3/inference.py +65 -0
  82. rf3/inference_engines/__init__.py +5 -0
  83. rf3/inference_engines/rf3.py +735 -0
  84. rf3/kinematics.py +354 -0
  85. rf3/loss/af3_confidence_loss.py +515 -0
  86. rf3/loss/af3_losses.py +655 -0
  87. rf3/loss/loss.py +179 -0
  88. rf3/metrics/chiral.py +179 -0
  89. rf3/metrics/clashing_chains.py +68 -0
  90. rf3/metrics/distogram.py +421 -0
  91. rf3/metrics/lddt.py +523 -0
  92. rf3/metrics/metadata.py +43 -0
  93. rf3/metrics/metric_utils.py +192 -0
  94. rf3/metrics/predicted_error.py +134 -0
  95. rf3/metrics/rasa.py +108 -0
  96. rf3/metrics/selected_distances.py +91 -0
  97. rf3/model/RF3.py +527 -0
  98. rf3/model/RF3_blocks.py +92 -0
  99. rf3/model/RF3_structure.py +303 -0
  100. rf3/model/layers/af3_auxiliary_heads.py +255 -0
  101. rf3/model/layers/af3_diffusion_transformer.py +544 -0
  102. rf3/model/layers/attention.py +313 -0
  103. rf3/model/layers/layer_utils.py +127 -0
  104. rf3/model/layers/mlff.py +118 -0
  105. rf3/model/layers/outer_product.py +59 -0
  106. rf3/model/layers/pairformer_layers.py +783 -0
  107. rf3/model/layers/structure_bias.py +56 -0
  108. rf3/scoring.py +1787 -0
  109. rf3/symmetry/resolve.py +284 -0
  110. rf3/train.py +194 -0
  111. rf3/trainers/rf3.py +570 -0
  112. rf3/util_module.py +47 -0
  113. rf3/utils/frames.py +109 -0
  114. rf3/utils/inference.py +665 -0
  115. rf3/utils/io.py +198 -0
  116. rf3/utils/loss.py +72 -0
  117. rf3/utils/predict_and_score.py +165 -0
  118. rf3/utils/predicted_error.py +673 -0
  119. rf3/utils/recycling.py +42 -0
  120. rf3/validate.py +140 -0
  121. rfd3/.gitignore +7 -0
  122. rfd3/Makefile +76 -0
  123. rfd3/__init__.py +12 -0
  124. rfd3/callbacks.py +66 -0
  125. rfd3/cli.py +41 -0
  126. rfd3/constants.py +212 -0
  127. rfd3/engine.py +543 -0
  128. rfd3/inference/datasets.py +193 -0
  129. rfd3/inference/input_parsing.py +1123 -0
  130. rfd3/inference/legacy_input_parsing.py +717 -0
  131. rfd3/inference/parsing.py +165 -0
  132. rfd3/inference/symmetry/atom_array.py +298 -0
  133. rfd3/inference/symmetry/checks.py +241 -0
  134. rfd3/inference/symmetry/contigs.py +63 -0
  135. rfd3/inference/symmetry/frames.py +355 -0
  136. rfd3/inference/symmetry/symmetry_utils.py +398 -0
  137. rfd3/metrics/design_metrics.py +465 -0
  138. rfd3/metrics/hbonds_hbplus_metrics.py +308 -0
  139. rfd3/metrics/hbonds_metrics.py +389 -0
  140. rfd3/metrics/losses.py +325 -0
  141. rfd3/metrics/metrics_utils.py +118 -0
  142. rfd3/metrics/sidechain_metrics.py +349 -0
  143. rfd3/model/RFD3.py +105 -0
  144. rfd3/model/RFD3_diffusion_module.py +387 -0
  145. rfd3/model/cfg_utils.py +81 -0
  146. rfd3/model/inference_sampler.py +635 -0
  147. rfd3/model/layers/attention.py +577 -0
  148. rfd3/model/layers/block_utils.py +580 -0
  149. rfd3/model/layers/blocks.py +777 -0
  150. rfd3/model/layers/chunked_pairwise.py +377 -0
  151. rfd3/model/layers/encoders.py +417 -0
  152. rfd3/model/layers/layer_utils.py +197 -0
  153. rfd3/model/layers/pairformer_layers.py +128 -0
  154. rfd3/run_inference.py +45 -0
  155. rfd3/testing/debug.py +139 -0
  156. rfd3/testing/debug_utils.py +73 -0
  157. rfd3/testing/testing_utils.py +356 -0
  158. rfd3/train.py +194 -0
  159. rfd3/trainer/dump_validation_structures.py +154 -0
  160. rfd3/trainer/fabric_trainer.py +923 -0
  161. rfd3/trainer/recycling.py +42 -0
  162. rfd3/trainer/rfd3.py +485 -0
  163. rfd3/trainer/trainer_utils.py +502 -0
  164. rfd3/transforms/conditioning_base.py +508 -0
  165. rfd3/transforms/conditioning_utils.py +200 -0
  166. rfd3/transforms/design_transforms.py +807 -0
  167. rfd3/transforms/dna_crop.py +523 -0
  168. rfd3/transforms/hbonds.py +407 -0
  169. rfd3/transforms/hbonds_hbplus.py +246 -0
  170. rfd3/transforms/ncaa_transforms.py +153 -0
  171. rfd3/transforms/pipelines.py +632 -0
  172. rfd3/transforms/ppi_transforms.py +541 -0
  173. rfd3/transforms/rasa.py +116 -0
  174. rfd3/transforms/symmetry.py +76 -0
  175. rfd3/transforms/training_conditions.py +552 -0
  176. rfd3/transforms/util_transforms.py +498 -0
  177. rfd3/transforms/virtual_atoms.py +305 -0
  178. rfd3/utils/inference.py +648 -0
  179. rfd3/utils/io.py +245 -0
  180. rfd3/utils/vizualize.py +276 -0
rf3/model/layers/af3_diffusion_transformer.py
@@ -0,0 +1,544 @@
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ from rf3.loss.loss import calc_chiral_grads_flat_impl
+ from rf3.model.layers.layer_utils import (
+     AdaLN,
+     LinearBiasInit,
+     MultiDimLinear,
+     collapse,
+     linearNoBias,
+ )
+ from rf3.model.layers.mlff import ConformerEmbeddingWeightedAverage
+
+ from foundry.training.checkpoint import activation_checkpointing
+ from foundry.utils.torch import device_of
+
+
+ class AtomAttentionEncoderDiffusion(nn.Module):
+     def __init__(
+         self,
+         c_atom,
+         c_atompair,
+         c_token,
+         c_tokenpair,
+         c_s,
+         atom_1d_features,
+         c_atom_1d_features,
+         atom_transformer,
+         broadcast_trunk_feats_on_1dim_old,
+         use_chiral_features,
+         no_grad_on_chiral_center,
+         use_inv_dist_squared: bool = False,
+         use_atom_level_embedding: bool = False,
+         atom_level_embedding_dim: int = 384,
+     ):
+         super().__init__()
+         self.c_atom = c_atom
+         self.c_atompair = c_atompair
+         self.c_token = c_token
+         self.c_tokenpair = c_tokenpair
+         self.c_s = c_s
+         self.atom_1d_features = atom_1d_features
+         self.broadcast_trunk_feats_on_1dim_old = broadcast_trunk_feats_on_1dim_old
+         self.use_chiral_features = use_chiral_features
+         self.no_grad_on_chiral_center = no_grad_on_chiral_center
+         self.use_atom_level_embedding = use_atom_level_embedding
+         self.atom_level_embedding_dim = atom_level_embedding_dim
+
+         self.process_input_features = linearNoBias(c_atom_1d_features, c_atom)
+
+         self.process_d = linearNoBias(3, c_atompair)  # x,y,z
+
+         self.process_inverse_dist = linearNoBias(1, c_atompair)
+         self.process_valid_mask = linearNoBias(1, c_atompair)
+
+         self.process_s_trunk = nn.Sequential(
+             nn.LayerNorm(c_s), linearNoBias(c_s, c_atom)
+         )
+         self.process_z = nn.Sequential(
+             nn.LayerNorm(c_tokenpair), linearNoBias(c_tokenpair, c_atompair)
+         )
+         self.process_r = linearNoBias(3, c_atom)
+         if self.use_chiral_features:
+             self.process_ch = linearNoBias(3, c_atom)
+
+         self.process_single_l = nn.Sequential(
+             nn.ReLU(), linearNoBias(c_atom, c_atompair)
+         )
+         self.process_single_m = nn.Sequential(
+             nn.ReLU(), linearNoBias(c_atom, c_atompair)
+         )
+
+         self.pair_mlp = nn.Sequential(
+             nn.ReLU(),
+             linearNoBias(self.c_atompair, c_atompair),
+             nn.ReLU(),
+             linearNoBias(self.c_atompair, c_atompair),
+             nn.ReLU(),
+             linearNoBias(self.c_atompair, c_atompair),
+         )
+
+         self.process_q = nn.Sequential(
+             linearNoBias(c_atom, c_token),
+             nn.ReLU(),
+         )
+
+         self.atom_transformer = AtomTransformer(
+             c_atom=c_atom, c_atompair=c_atompair, **atom_transformer
+         )
+
+         self.use_inv_dist_squared = use_inv_dist_squared
+
+         if self.use_atom_level_embedding:
+             self.process_atom_level_embedding = ConformerEmbeddingWeightedAverage(
+                 atom_level_embedding_dim=self.atom_level_embedding_dim,
+                 c_atompair=c_atompair,
+                 c_atom=c_atom,
+             )
+
+     def reset_parameters(self):
+         super().reset_parameters()
+         if self.use_chiral_features:
+             nn.init.zeros_(self.process_ch.weight)
+
+     def forward(
+         self,
+         f,  # Dict (Input feature dictionary)
+         R_L,  # [D, L, 3]
+         S_trunk_I,  # [B, I, C_S_trunk] [...,I,C_S_trunk]
+         Z_II,  # [B, I, I, C_Z] [...,I,I,C_Z]
+     ):
+         assert R_L is not None
+
+         tok_idx = f["atom_to_token_map"]
+         L = len(tok_idx)
+         I = tok_idx.max() + 1
+
+         f["ref_atom_name_chars"] = f["ref_atom_name_chars"].reshape(L, -1)
+         # Create the atom single conditioning: Embed per-atom meta data
+         C_L = self.process_input_features(
+             torch.cat(
+                 tuple(
+                     collapse(f[feature_name], L)
+                     for feature_name in self.atom_1d_features
+                 ),
+                 dim=-1,
+             )
+         )
+
+         if self.use_atom_level_embedding:
+             assert "atom_level_embedding" in f
+             C_L = C_L + self.process_atom_level_embedding(f["atom_level_embedding"])
+
+         # Embed offsets between atom reference positions
+         D_LL = f["ref_pos"].unsqueeze(-2) - f["ref_pos"].unsqueeze(-3)
+         V_LL = (
+             f["ref_space_uid"].unsqueeze(-1) == f["ref_space_uid"].unsqueeze(-2)
+         ).unsqueeze(-1)
+         P_LL = self.process_d(D_LL) * V_LL
+
+         @activation_checkpointing
+         def embed_atom_feats(R_L, C_L, D_LL, V_LL, P_LL, tok_idx):
+             # Embed pairwise inverse squared distances, and the valid mask
+             if self.training:
+                 if self.use_inv_dist_squared:
+                     P_LL = (
+                         P_LL
+                         + self.process_inverse_dist(
+                             1 / (1 + torch.sum(D_LL * D_LL, dim=-1, keepdim=True))
+                         )
+                         * V_LL
+                     )
+                 else:
+                     P_LL = (
+                         P_LL
+                         + self.process_inverse_dist(
+                             1 / (1 + torch.linalg.norm(D_LL, dim=-1, keepdim=True))
+                         )
+                         * V_LL
+                     )
+                 P_LL = P_LL + self.process_valid_mask(V_LL.to(P_LL.dtype)) * V_LL
+             else:
+                 if self.use_inv_dist_squared:
+                     P_LL[V_LL[..., 0]] += self.process_inverse_dist(
+                         1
+                         / (
+                             1
+                             + torch.sum(
+                                 D_LL[V_LL[..., 0]] * D_LL[V_LL[..., 0]],
+                                 dim=-1,
+                                 keepdim=True,
+                             )
+                         )
+                     )
+                 else:
+                     P_LL[V_LL[..., 0]] += self.process_inverse_dist(
+                         1
+                         / (
+                             1
+                             + torch.linalg.norm(
+                                 D_LL[V_LL[..., 0]], dim=-1, keepdim=True
+                             )
+                         )
+                     )
+                 P_LL[V_LL[..., 0]] += self.process_valid_mask(
+                     V_LL[V_LL[..., 0]].to(P_LL.dtype)
+                 )
+
+             # Initialise the atom single representation as the single conditioning.
+             Q_L = C_L
+
+             # If provided, add trunk embeddings and noisy positions.
+             if R_L is not None:
+                 # Broadcast the single and pair embedding from the trunk.
+                 S_trunk_embed_L = self.process_s_trunk(S_trunk_I)[..., tok_idx, :]
+
+                 C_L = C_L + S_trunk_embed_L
+                 assert not (C_L == Q_L).all()
+                 if self.broadcast_trunk_feats_on_1dim_old:
+                     P_LL = P_LL + self.process_z(Z_II)[..., tok_idx, tok_idx, :]
+                 else:
+                     P_LL = (
+                         P_LL + self.process_z(Z_II)[..., tok_idx, :, :][..., tok_idx, :]
+                     )
+
+                 # Add the noisy positions.
+                 Q_L = self.process_r(R_L) + Q_L
+
+             # Add chirality gradients
+             if self.use_chiral_features:
+                 with torch.autocast(
+                     device_type=device_of(self).type, enabled=False
+                 ):
+                     # Do not pass grads through grad calc
+                     R_L = calc_chiral_grads_flat_impl(
+                         R_L.detach(),
+                         f["chiral_centers"],
+                         f["chiral_center_dihedral_angles"],
+                         self.no_grad_on_chiral_center,
+                     ).nan_to_num()
+                 Q_L = self.process_ch(R_L) + Q_L
+
+             # Add the combined single conditioning to the pair representation.
+             P_LL = P_LL + (
+                 self.process_single_l(C_L).unsqueeze(-2)
+                 + self.process_single_m(C_L).unsqueeze(-3)
+             )
+
+             # Run a small MLP on the pair activations
+             P_LL = P_LL + self.pair_mlp(P_LL)
+
+             # Cross attention transformer.
+             Q_L = self.atom_transformer(Q_L, C_L, P_LL)
+
+             A_I_shape = Q_L.shape[:-2] + (
+                 I,
+                 self.c_token,
+             )
+             # Aggregate per-atom representation to per-token representation
+             processed_Q_L = self.process_q(Q_L)  # [L, C_atom] -> [L, C_token]
+             # Ensure dtype consistency for index_reduce
+             processed_Q_L = processed_Q_L.to(Q_L.dtype)
+
+             A_I = (
+                 torch.zeros(A_I_shape, device=Q_L.device, dtype=Q_L.dtype)
+                 .index_reduce(
+                     -2,
+                     f["atom_to_token_map"].long(),
+                     processed_Q_L,
+                     "mean",
+                     include_self=False,
+                 )
+                 .clone()
+             )
+
+             return A_I, Q_L, C_L, P_LL
+
+         return embed_atom_feats(R_L, C_L, D_LL, V_LL, P_LL, tok_idx)
+
+
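
The final aggregation step above pools per-atom features into per-token features with `Tensor.index_reduce`. A minimal standalone sketch of that mean pooling, with illustrative sizes and names (7 atoms, 3 tokens, feature dim 4), not values from the package:

import torch

atom_feats = torch.randn(7, 4)                       # per-atom features  [L, C]
atom_to_token = torch.tensor([0, 0, 1, 1, 1, 2, 2])  # token id per atom  [L]

# Mean-pool atoms that share a token id, as the encoder above does:
token_feats = torch.zeros(3, 4).index_reduce(
    0, atom_to_token, atom_feats, "mean", include_self=False
)
# token_feats[i] == atom_feats[atom_to_token == i].mean(0);
# include_self=False keeps the zero initialisation out of the average.
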
+ class AtomTransformer(nn.Module):
+     def __init__(
+         self,
+         c_atom,
+         c_atompair,
+         diffusion_transformer,
+         n_queries,
+         n_keys,
+         l_max: int = None,  # HACK: Unused, kept for backwards compatibility with 9/21 checkpoint
+     ):
+         super().__init__()
+
+         self.diffusion_transformer = DiffusionTransformer(
+             c_token=c_atom, c_s=c_atom, c_tokenpair=c_atompair, **diffusion_transformer
+         )
+
+     def forward(
+         self,
+         Ql,  # [B, L, C_atom]
+         Cl,  # [B, L, C_atom]
+         Plm,  # [B, L, L, C_atompair]
+     ):
+         Beta_lm = True
+         return self.diffusion_transformer(Ql, Cl, Plm, Beta_lm)
+
+
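
`Beta_lm = True` above routes `AttentionPairBiasDiffusion` (defined further down) into its windowed `atom_attention` path instead of dense attention. A standalone sketch of how that path builds query/key windows centred every `qbatch` positions, with illustrative sizes:

import torch

L, qbatch, kbatch = 100, 32, 128
Cs = torch.arange((L + qbatch - 1) // qbatch) * qbatch + qbatch // 2  # window centres
patchq = torch.arange(qbatch) - qbatch // 2
patchk = torch.arange(kbatch) - kbatch // 2

indicesQ = Cs[:, None] + patchq[None, :]     # [n_windows, qbatch] query slots
indicesK = Cs[:, None] + patchk[None, :]     # [n_windows, kbatch] key slots
maskQ = (indicesQ < 0) | (indicesQ > L - 1)  # out-of-range slots are masked out
print(indicesQ.shape, indicesK.shape)        # torch.Size([4, 32]) torch.Size([4, 128])
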
+ class DiffusionTransformer(nn.Module):
+     def __init__(self, c_token, c_s, c_tokenpair, n_block, diffusion_transformer_block):
+         super().__init__()
+         self.blocks = torch.nn.ModuleList(
+             [
+                 DiffusionTransformerBlock(
+                     c_token=c_token,
+                     c_s=c_s,
+                     c_tokenpair=c_tokenpair,
+                     **diffusion_transformer_block,
+                 )
+                 for _ in range(n_block)
+             ]
+         )
+
+     def forward(
+         self,
+         A_I,  # [..., I, C_token]
+         S_I,  # [..., I, C_token]
+         Z_II,  # [..., I, I, C_tokenpair]
+         Beta_II,  # [I, I]
+     ):
+         for block in self.blocks:
+             A_I = block(A_I, S_I, Z_II, Beta_II)
+         return A_I
+
+
+ class DiffusionTransformerBlock(nn.Module):
+     def __init__(
+         self,
+         c_token,
+         c_s,
+         c_tokenpair,
+         n_head,
+         no_residual_connection_between_attention_and_transition,
+         kq_norm,
+     ):
+         super().__init__()
+         self.attention_pair_bias = AttentionPairBiasDiffusion(
+             c_a=c_token, c_s=c_s, c_pair=c_tokenpair, n_head=n_head, kq_norm=kq_norm
+         )
+         self.conditioned_transition_block = ConditionedTransitionBlock(
+             c_token=c_token, c_s=c_s
+         )
+         self.no_residual_connection_between_attention_and_transition = (
+             no_residual_connection_between_attention_and_transition
+         )
+
+     @activation_checkpointing
+     def forward(
+         self,
+         A_I,  # [..., I, C_token]
+         S_I,  # [..., I, C_s]
+         Z_II,  # [..., I, I, C_tokenpair]
+         Beta_II,  # [I, I]
+     ):
+         if self.no_residual_connection_between_attention_and_transition:
+             B_I = self.attention_pair_bias(A_I, S_I, Z_II, Beta_II)
+             A_I = A_I + B_I + self.conditioned_transition_block(A_I, S_I)
+         else:
+             A_I = A_I + self.attention_pair_bias(A_I, S_I, Z_II, Beta_II)
+             A_I = A_I + self.conditioned_transition_block(A_I, S_I)
+
+         return A_I
+
+
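
The `no_residual_connection_between_attention_and_transition` flag above selects a parallel wiring in which the transition block sees the pre-attention activations. A toy sketch of the two wirings, with stand-in `nn.Linear` sublayers rather than the real modules:

import torch
import torch.nn as nn

attn = nn.Linear(8, 8)  # stand-in for AttentionPairBiasDiffusion
ffn = nn.Linear(8, 8)   # stand-in for ConditionedTransitionBlock
x = torch.randn(2, 8)

seq = x + attn(x)
seq = seq + ffn(seq)        # flag False: transition sees the post-attention state

par = x + attn(x) + ffn(x)  # flag True: transition sees the pre-attention state
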
+ class ConditionedTransitionBlock(nn.Module):
+     """SwiGLU transition block with adaptive layernorm"""
+
+     def __init__(self, c_token, c_s, n=2):
+         super().__init__()
+         self.ada_ln = AdaLN(c_a=c_token, c_s=c_s)
+         self.linear_1 = linearNoBias(c_token, c_token * n)
+         self.linear_2 = linearNoBias(c_token, c_token * n)
+         self.linear_output_project = nn.Sequential(
+             LinearBiasInit(c_s, c_token, biasinit=-2.0),
+             nn.Sigmoid(),
+         )
+         self.linear_3 = linearNoBias(c_token * n, c_token)
+
+     def forward(
+         self,
+         Ai,  # [B, I, C_token]
+         Si,  # [B, I, C_token]
+     ):
+         Ai = self.ada_ln(Ai, Si)
+         # BUG: This is not the correct implementation of SwiGLU
+         # Bi = torch.sigmoid(self.linear_1(Ai)) * self.linear_2(Ai)
+         # FIX: This is the correct implementation of SwiGLU
+         Bi = torch.nn.functional.silu(self.linear_1(Ai)) * self.linear_2(Ai)
+
+         # Output projection (from adaLN-Zero)
+         return self.linear_output_project(Si) * self.linear_3(Bi)
+
+
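
The BUG/FIX comments above distinguish a plain sigmoid gate (a GLU) from the SiLU gate that defines SwiGLU. A standalone sketch of the two, with illustrative shapes:

import torch
import torch.nn.functional as F

x = torch.randn(2, 8)
w1 = torch.randn(8, 16)
w2 = torch.randn(8, 16)

glu = torch.sigmoid(x @ w1) * (x @ w2)  # sigmoid(x W1) gates x W2
swiglu = F.silu(x @ w1) * (x @ w2)      # silu(z) = z * sigmoid(z)
# SwiGLU scales the sigmoid gate by its pre-activation, which is the FIX above.
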
+ class AttentionPairBiasDiffusion(nn.Module):
+     def __init__(self, c_a, c_s, c_pair, n_head, kq_norm):
+         super().__init__()
+         self.n_head = n_head
+         self.c_a = c_a
+         self.c_pair = c_pair
+         self.c = c_a // n_head
+
+         self.to_q = MultiDimLinear(c_a, (n_head, self.c), bias=False)
+         self.to_k = MultiDimLinear(c_a, (n_head, self.c), bias=False)
+         self.to_v = MultiDimLinear(c_a, (n_head, self.c), bias=False)
+         self.to_b = linearNoBias(c_pair, n_head)
+         self.to_g = nn.Sequential(
+             MultiDimLinear(c_a, (n_head, self.c), bias=False),
+             nn.Sigmoid(),
+         )
+         self.to_a = linearNoBias(c_a, c_a)
+         self.linear_output_project = nn.Sequential(
+             LinearBiasInit(c_s, c_a, biasinit=-2.0),
+             nn.Sigmoid(),
+         )
+         self.ln_0 = nn.LayerNorm((c_pair,))
+         self.ada_ln_1 = AdaLN(c_a=c_a, c_s=c_s)
+         self.use_deepspeed_evo = False
+         self.force_bfloat16 = True
+
+         self.kq_norm = kq_norm
+         if self.kq_norm:
+             self.key_layer_norm = nn.LayerNorm((self.n_head * self.c,))
+             self.query_layer_norm = nn.LayerNorm((self.n_head * self.c,))
+
+     @activation_checkpointing
+     def forward(
+         self,
+         A_I,  # [I, C_a]
+         S_I,  # [I, C_a] | None
+         Z_II,  # [I, I, C_z]
+         Beta_II,  # [I, I]
+     ):
+         # Input projections
+         assert S_I is not None
+         if S_I is not None:
+             A_I = self.ada_ln_1(A_I, S_I)
+
+         if Beta_II is not None:
+             # zero out layer norms for the key and query
+             return self.atom_attention(A_I, S_I, Z_II)
+
+         if self.use_deepspeed_evo or self.force_bfloat16:
+             A_I = A_I.to(torch.bfloat16)
+         assert len(A_I.shape) == 3, f"(Diffusion batch, I, C_a) but got {A_I.shape}"
+
+         Q_IH = self.to_q(A_I)  # / np.sqrt(self.c)
+         K_IH = self.to_k(A_I)
+         V_IH = self.to_v(A_I)
+         B_IIH = self.to_b(self.ln_0(Z_II))
+         G_IH = self.to_g(A_I)
+
+         if self.kq_norm:
+             Q_IH = self.query_layer_norm(
+                 Q_IH.reshape(-1, self.n_head * self.c)
+             ).reshape(Q_IH.shape)
+             K_IH = self.key_layer_norm(K_IH.reshape(-1, self.n_head * self.c)).reshape(
+                 K_IH.shape
+             )
+
+         _, L = B_IIH.shape[:2]
+
+         if not self.use_deepspeed_evo or L <= 24:
+             # Attention
+             Q_IH = Q_IH / np.sqrt(self.c)
+             A_IIH = torch.softmax(
+                 torch.einsum("...ihd,...jhd->...ijh", Q_IH, K_IH) + B_IIH, dim=-2
+             )  # softmax over j
+             ## G_IH: [B, I, H, C]
+             ## A_IIH: [B, I, I, H]
+             ## V_IH: [B, I, H, C]
+             A_I = torch.einsum("...ijh,...jhc->...ihc", A_IIH, V_IH)
+             A_I = G_IH * A_I  # [B, I, H, C]
+             A_I = A_I.flatten(start_dim=-2)  # [B, I, Ca]
+         else:
+             # DS4Sci_EvoformerAttention
+             # Q, K, V: [Batch, N_seq, N_res, Head, Dim]
+             # res_mask: [Batch, N_seq, 1, 1, N_res]
+             # pair_bias: [Batch, 1, Head, N_res, N_res]
+             from deepspeed.ops.deepspeed4science import DS4Sci_EvoformerAttention
+
+             Q_IH = Q_IH[:, None]
+             K_IH = K_IH[:, None]
+             V_IH = V_IH[:, None]
+             B_IIH = B_IIH.repeat(Q_IH.shape[0], 1, 1, 1)
+             B_IIH = B_IIH[:, None]
+             B_IIH = B_IIH.permute(0, 1, 4, 2, 3).to(torch.bfloat16)
+             mask = torch.zeros(
+                 [Q_IH.shape[0], 1, 1, 1, B_IIH.shape[-1]],
+                 dtype=torch.bfloat16,
+                 device=B_IIH.device,
+             )
+             A_I = DS4Sci_EvoformerAttention(Q_IH, K_IH, V_IH, [mask, B_IIH])
+             A_I = A_I * G_IH[:, None]
+             A_I = A_I.view(A_I.shape[0], A_I.shape[2], -1)
+
+         A_I = self.to_a(A_I)
+         # Output projection (from adaLN-Zero)
+         if S_I is not None:
+             A_I = self.linear_output_project(S_I) * A_I
+
+         return A_I
+
+     def atom_attention(self, A_I, S_I, Z_II, qbatch=32, kbatch=128):
+         assert qbatch % 2 == 0
+         assert kbatch % 2 == 0
+
+         if len(A_I.shape) == 2:
+             A_I = A_I[None]
+             Z_II = Z_II[None]
+         D, L = A_I.shape[:2]
+         Q_IH = self.to_q(A_I)
+         K_IH = self.to_k(A_I)
+         V_IH = self.to_v(A_I)
+         B_IIH = self.to_b(self.ln_0(Z_II))
+         G_IH = self.to_g(A_I)
+
+         if self.kq_norm:
+             Q_IH = self.query_layer_norm(
+                 Q_IH.reshape(-1, self.n_head * self.c)
+             ).reshape(Q_IH.shape)
+             K_IH = self.key_layer_norm(K_IH.reshape(-1, self.n_head * self.c)).reshape(
+                 K_IH.shape
+             )
+
+         nqbatch = (L + qbatch - 1) // qbatch
+         Cs = torch.arange(nqbatch, device=A_I.device) * qbatch + qbatch // 2
+         patchq = torch.arange(qbatch, device=A_I.device) - qbatch // 2
+         patchk = torch.arange(kbatch, device=A_I.device) - kbatch // 2
+
+         indicesQ = Cs[:, None] + patchq[None, :]
+         maskQ = (indicesQ < 0) | (indicesQ > L - 1)
+         indicesQ = torch.clamp(indicesQ, 0, L - 1)
+
+         indicesK = Cs[:, None] + patchk[None, :]
+         maskK = (indicesK < 0) | (indicesK > L - 1)
+         indicesK = torch.clamp(indicesK, 0, L - 1)
+
+         query_subset = Q_IH[:, indicesQ]
+         key_subset = K_IH[:, indicesK]
+         attn = torch.einsum("...ihd,...jhd->...ijh", query_subset, key_subset)
+         attn = attn / (self.c**0.5)
+
+         attn += B_IIH[:, indicesQ[:, :, None], indicesK[:, None, :]] - 1e9 * (
+             maskQ[None, :, :, None, None] + maskK[None, :, None, :, None]
+         )
+         attn = torch.softmax(attn, dim=-2)
+
+         value_subset = V_IH[:, indicesK]
+         atom_features = torch.einsum("...ijh,...jhc->...ihc", attn, value_subset)
+         atom_features = atom_features[:, ~maskQ]
+         atom_features = (G_IH * atom_features).view(D, L, -1)
+         atom_features = self.to_a(atom_features.view(D, L, -1))
+
+         A_I = self.linear_output_project(S_I) * atom_features
+
+         return A_I
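
For reference, a standalone sketch of the dense attention-with-pair-bias computation at the heart of `AttentionPairBiasDiffusion.forward`, with illustrative sizes; the gating, adaLN conditioning, and bfloat16/DeepSpeed paths are omitted:

import torch

B, I, H, c = 1, 10, 4, 16       # batch, tokens, heads, head dim
Q = torch.randn(B, I, H, c)
K = torch.randn(B, I, H, c)
V = torch.randn(B, I, H, c)
bias = torch.randn(B, I, I, H)  # projected from the pair representation Z_II

attn = torch.softmax(
    torch.einsum("...ihd,...jhd->...ijh", Q / c**0.5, K) + bias, dim=-2
)  # softmax over j, as in the code above
out = torch.einsum("...ijh,...jhc->...ihc", attn, V).flatten(start_dim=-2)
print(out.shape)  # torch.Size([1, 10, 64])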