rc-foundry 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (180)
  1. foundry/__init__.py +57 -0
  2. foundry/callbacks/__init__.py +5 -0
  3. foundry/callbacks/callback.py +116 -0
  4. foundry/callbacks/health_logging.py +419 -0
  5. foundry/callbacks/metrics_logging.py +211 -0
  6. foundry/callbacks/timing_logging.py +67 -0
  7. foundry/callbacks/train_logging.py +278 -0
  8. foundry/common.py +108 -0
  9. foundry/constants.py +28 -0
  10. foundry/hydra/resolvers.py +77 -0
  11. foundry/inference_engines/base.py +235 -0
  12. foundry/inference_engines/checkpoint_registry.py +66 -0
  13. foundry/metrics/__init__.py +12 -0
  14. foundry/metrics/losses.py +30 -0
  15. foundry/metrics/metric.py +319 -0
  16. foundry/model/layers/blocks.py +47 -0
  17. foundry/testing/__init__.py +6 -0
  18. foundry/testing/fixtures.py +19 -0
  19. foundry/testing/pytest_hooks.py +15 -0
  20. foundry/trainers/fabric.py +923 -0
  21. foundry/training/EMA.py +67 -0
  22. foundry/training/checkpoint.py +61 -0
  23. foundry/training/schedulers.py +91 -0
  24. foundry/utils/alignment.py +86 -0
  25. foundry/utils/components.py +415 -0
  26. foundry/utils/datasets.py +405 -0
  27. foundry/utils/ddp.py +103 -0
  28. foundry/utils/instantiators.py +72 -0
  29. foundry/utils/logging.py +279 -0
  30. foundry/utils/rigid.py +1460 -0
  31. foundry/utils/rotation_augmentation.py +65 -0
  32. foundry/utils/squashfs.py +172 -0
  33. foundry/utils/torch.py +317 -0
  34. foundry/utils/weights.py +271 -0
  35. foundry/version.py +34 -0
  36. foundry_cli/__init__.py +3 -0
  37. foundry_cli/download_checkpoints.py +281 -0
  38. mpnn/__init__.py +1 -0
  39. mpnn/collate/feature_collator.py +265 -0
  40. mpnn/inference.py +53 -0
  41. mpnn/inference_engines/mpnn.py +549 -0
  42. mpnn/loss/nll_loss.py +122 -0
  43. mpnn/metrics/nll.py +369 -0
  44. mpnn/metrics/sequence_recovery.py +440 -0
  45. mpnn/model/layers/graph_embeddings.py +2372 -0
  46. mpnn/model/layers/message_passing.py +332 -0
  47. mpnn/model/layers/position_wise_feed_forward.py +44 -0
  48. mpnn/model/layers/positional_encoding.py +98 -0
  49. mpnn/model/mpnn.py +2632 -0
  50. mpnn/pipelines/mpnn.py +162 -0
  51. mpnn/samplers/samplers.py +167 -0
  52. mpnn/train.py +341 -0
  53. mpnn/trainers/mpnn.py +193 -0
  54. mpnn/transforms/feature_aggregation/mpnn.py +184 -0
  55. mpnn/transforms/feature_aggregation/polymer_ligand_interface.py +76 -0
  56. mpnn/transforms/feature_aggregation/token_encodings.py +132 -0
  57. mpnn/transforms/feature_aggregation/user_settings.py +347 -0
  58. mpnn/transforms/polymer_ligand_interface.py +164 -0
  59. mpnn/utils/inference.py +2397 -0
  60. mpnn/utils/probability.py +37 -0
  61. mpnn/utils/weights.py +309 -0
  62. rc_foundry-0.1.1.dist-info/METADATA +239 -0
  63. rc_foundry-0.1.1.dist-info/RECORD +180 -0
  64. rc_foundry-0.1.1.dist-info/WHEEL +4 -0
  65. rc_foundry-0.1.1.dist-info/entry_points.txt +5 -0
  66. rc_foundry-0.1.1.dist-info/licenses/LICENSE.md +28 -0
  67. rf3/__init__.py +3 -0
  68. rf3/_version.py +33 -0
  69. rf3/alignment.py +79 -0
  70. rf3/callbacks/dump_validation_structures.py +101 -0
  71. rf3/callbacks/metrics_logging.py +324 -0
  72. rf3/chemical.py +1529 -0
  73. rf3/cli.py +77 -0
  74. rf3/data/cyclic_transform.py +78 -0
  75. rf3/data/extra_xforms.py +36 -0
  76. rf3/data/ground_truth_template.py +463 -0
  77. rf3/data/paired_msa.py +206 -0
  78. rf3/data/pipeline_utils.py +128 -0
  79. rf3/data/pipelines.py +558 -0
  80. rf3/diffusion_samplers/inference_sampler.py +222 -0
  81. rf3/inference.py +65 -0
  82. rf3/inference_engines/__init__.py +5 -0
  83. rf3/inference_engines/rf3.py +735 -0
  84. rf3/kinematics.py +354 -0
  85. rf3/loss/af3_confidence_loss.py +515 -0
  86. rf3/loss/af3_losses.py +655 -0
  87. rf3/loss/loss.py +179 -0
  88. rf3/metrics/chiral.py +179 -0
  89. rf3/metrics/clashing_chains.py +68 -0
  90. rf3/metrics/distogram.py +421 -0
  91. rf3/metrics/lddt.py +523 -0
  92. rf3/metrics/metadata.py +43 -0
  93. rf3/metrics/metric_utils.py +192 -0
  94. rf3/metrics/predicted_error.py +134 -0
  95. rf3/metrics/rasa.py +108 -0
  96. rf3/metrics/selected_distances.py +91 -0
  97. rf3/model/RF3.py +527 -0
  98. rf3/model/RF3_blocks.py +92 -0
  99. rf3/model/RF3_structure.py +303 -0
  100. rf3/model/layers/af3_auxiliary_heads.py +255 -0
  101. rf3/model/layers/af3_diffusion_transformer.py +544 -0
  102. rf3/model/layers/attention.py +313 -0
  103. rf3/model/layers/layer_utils.py +127 -0
  104. rf3/model/layers/mlff.py +118 -0
  105. rf3/model/layers/outer_product.py +59 -0
  106. rf3/model/layers/pairformer_layers.py +783 -0
  107. rf3/model/layers/structure_bias.py +56 -0
  108. rf3/scoring.py +1787 -0
  109. rf3/symmetry/resolve.py +284 -0
  110. rf3/train.py +194 -0
  111. rf3/trainers/rf3.py +570 -0
  112. rf3/util_module.py +47 -0
  113. rf3/utils/frames.py +109 -0
  114. rf3/utils/inference.py +665 -0
  115. rf3/utils/io.py +198 -0
  116. rf3/utils/loss.py +72 -0
  117. rf3/utils/predict_and_score.py +165 -0
  118. rf3/utils/predicted_error.py +673 -0
  119. rf3/utils/recycling.py +42 -0
  120. rf3/validate.py +140 -0
  121. rfd3/.gitignore +7 -0
  122. rfd3/Makefile +76 -0
  123. rfd3/__init__.py +12 -0
  124. rfd3/callbacks.py +66 -0
  125. rfd3/cli.py +41 -0
  126. rfd3/constants.py +212 -0
  127. rfd3/engine.py +543 -0
  128. rfd3/inference/datasets.py +193 -0
  129. rfd3/inference/input_parsing.py +1123 -0
  130. rfd3/inference/legacy_input_parsing.py +717 -0
  131. rfd3/inference/parsing.py +165 -0
  132. rfd3/inference/symmetry/atom_array.py +298 -0
  133. rfd3/inference/symmetry/checks.py +241 -0
  134. rfd3/inference/symmetry/contigs.py +63 -0
  135. rfd3/inference/symmetry/frames.py +355 -0
  136. rfd3/inference/symmetry/symmetry_utils.py +398 -0
  137. rfd3/metrics/design_metrics.py +465 -0
  138. rfd3/metrics/hbonds_hbplus_metrics.py +308 -0
  139. rfd3/metrics/hbonds_metrics.py +389 -0
  140. rfd3/metrics/losses.py +325 -0
  141. rfd3/metrics/metrics_utils.py +118 -0
  142. rfd3/metrics/sidechain_metrics.py +349 -0
  143. rfd3/model/RFD3.py +105 -0
  144. rfd3/model/RFD3_diffusion_module.py +387 -0
  145. rfd3/model/cfg_utils.py +81 -0
  146. rfd3/model/inference_sampler.py +635 -0
  147. rfd3/model/layers/attention.py +577 -0
  148. rfd3/model/layers/block_utils.py +580 -0
  149. rfd3/model/layers/blocks.py +777 -0
  150. rfd3/model/layers/chunked_pairwise.py +377 -0
  151. rfd3/model/layers/encoders.py +417 -0
  152. rfd3/model/layers/layer_utils.py +197 -0
  153. rfd3/model/layers/pairformer_layers.py +128 -0
  154. rfd3/run_inference.py +45 -0
  155. rfd3/testing/debug.py +139 -0
  156. rfd3/testing/debug_utils.py +73 -0
  157. rfd3/testing/testing_utils.py +356 -0
  158. rfd3/train.py +194 -0
  159. rfd3/trainer/dump_validation_structures.py +154 -0
  160. rfd3/trainer/fabric_trainer.py +923 -0
  161. rfd3/trainer/recycling.py +42 -0
  162. rfd3/trainer/rfd3.py +485 -0
  163. rfd3/trainer/trainer_utils.py +502 -0
  164. rfd3/transforms/conditioning_base.py +508 -0
  165. rfd3/transforms/conditioning_utils.py +200 -0
  166. rfd3/transforms/design_transforms.py +807 -0
  167. rfd3/transforms/dna_crop.py +523 -0
  168. rfd3/transforms/hbonds.py +407 -0
  169. rfd3/transforms/hbonds_hbplus.py +246 -0
  170. rfd3/transforms/ncaa_transforms.py +153 -0
  171. rfd3/transforms/pipelines.py +632 -0
  172. rfd3/transforms/ppi_transforms.py +541 -0
  173. rfd3/transforms/rasa.py +116 -0
  174. rfd3/transforms/symmetry.py +76 -0
  175. rfd3/transforms/training_conditions.py +552 -0
  176. rfd3/transforms/util_transforms.py +498 -0
  177. rfd3/transforms/virtual_atoms.py +305 -0
  178. rfd3/utils/inference.py +648 -0
  179. rfd3/utils/io.py +245 -0
  180. rfd3/utils/vizualize.py +276 -0
@@ -0,0 +1,577 @@
+ import math
+ from math import sqrt
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from einops import rearrange
+ from opt_einsum import contract as einsum
+ from rfd3.model.layers.block_utils import (
+     create_attention_indices,
+     indices_to_mask,
+ )
+ from rfd3.model.layers.layer_utils import (
+     AdaLN,
+     LinearBiasInit,
+     RMSNorm,
+     linearNoBias,
+ )
+
+ from foundry.common import exists
+ from foundry.training.checkpoint import activation_checkpointing
+ from foundry.utils.ddp import RankedLogger
+
+ ranked_logger = RankedLogger(__name__, rank_zero_only=True)
+
+ try:
+     from cuequivariance_torch import attention_pair_bias as cueq_attention_pair_bias
+
+     # ranked_logger.info("Fused PairBiasAttention enabled!")
+     _CUEQ_AVAILABLE = True
+ except Exception:
+     # ranked_logger.warning(
+     #     "Using pytorch implementation instead of NVIDIA kernel"
+     #     "Ensure you are using the latest apptainer."
+     # )
+     _CUEQ_AVAILABLE = False
+
+
+ @torch.compiler.disable
+ def kernel_pairbias_attention(
+     *,
+     s: torch.Tensor,  # (B, U, D) sequence features used for gating/output inside the kernel
+     q: torch.Tensor,  # (B, H, U, DH)
+     k: torch.Tensor,  # (B, H, V, DH)
+     v: torch.Tensor,  # (B, H, V, DH)
+     z: torch.Tensor,  # (B, U, V, z_dim)
+     mask: torch.Tensor | None,  # (B, V) or (B*M, V) with 1=keep, 0=mask
+     num_heads: int,
+     w_proj_z: torch.Tensor,  # (H, z_dim)
+     w_proj_g: torch.Tensor,  # (D, D)
+     w_proj_o: torch.Tensor,  # (D, D)
+     w_ln_z: torch.Tensor,  # (z_dim,)
+     b_ln_z: torch.Tensor,  # (z_dim,)
+     b_proj_z: torch.Tensor | None = None,  # (H,)
+     b_proj_g: torch.Tensor | None = None,  # (D,)
+     b_proj_o: torch.Tensor | None = None,  # (D,)
+     attn_scale: float | None = None,
+     compute_pair_bias: bool = True,
+     multiplicity: int = 1,
+ ) -> torch.Tensor:
+     """Thin wrapper around cuequivariance_torch.attention_pair_bias."""
+     raise NotImplementedError("CUDA Kernel for attention pair bias not implemented")
+     # NOTE: the call below is unreachable until the raise above is removed.
+     out, _proj_z = cueq_attention_pair_bias(
+         s=s,
+         q=q,
+         k=k,
+         v=v,
+         z=z,
+         mask=mask,
+         num_heads=num_heads,
+         w_proj_z=w_proj_z,
+         w_proj_g=w_proj_g,
+         w_proj_o=w_proj_o,
+         w_ln_z=w_ln_z,
+         b_ln_z=b_ln_z,
+         b_proj_z=b_proj_z,
+         b_proj_g=b_proj_g,
+         b_proj_o=b_proj_o,
+         attn_scale=attn_scale,
+         compute_pair_bias=compute_pair_bias,
+         multiplicity=multiplicity,
+     )
+     return out  # (B, U, D)
+
+
+ ######################################################################################
+ ########################## Network Modules ##########################
+ ######################################################################################
+
+
+ class GatedCrossAttention(nn.Module):
+     def __init__(
+         self,
+         c_query,
+         c_kv,
+         c_pair=None,
+         c_model=128,
+         n_head=4,
+         kq_norm=True,
+         dropout=0.0,
+         **_,
+     ):
+         super().__init__()
+         self.n_head = n_head
+         self.scale = 1 / math.sqrt(c_model // n_head)
+         assert c_model % n_head == 0, "c_model must be divisible by n_head"
+
+         self.ln_q = RMSNorm(c_query)
+         self.ln_kv = RMSNorm(c_kv)
+
+         self.to_q = linearNoBias(c_query, c_model)
+         self.to_k = linearNoBias(c_kv, c_model)
+         self.to_v = linearNoBias(c_kv, c_model)
+         self.to_g = nn.Sequential(
+             linearNoBias(c_query, c_model),
+             nn.Sigmoid(),
+         )
+         self.to_out = nn.Sequential(nn.Linear(c_model, c_query), nn.Dropout(dropout))
+         self.kq_norm = kq_norm
+         if self.kq_norm:
+             self.k_norm = RMSNorm(c_model)
+             self.q_norm = RMSNorm(c_model)
+
+         self.c_pair = c_pair
+         if c_pair is not None:
+             self.to_b = nn.Sequential(RMSNorm(c_pair), linearNoBias(c_pair, n_head))
+         self.reset_parameter()
+
+     def reset_parameter(self):
+         # query/key/value/gate/output projections: Xavier uniform
+         nn.init.xavier_uniform_(self.to_q.weight)
+         nn.init.xavier_uniform_(self.to_k.weight)
+         nn.init.xavier_uniform_(self.to_v.weight)
+         nn.init.xavier_uniform_(self.to_g[0].weight)
+         nn.init.xavier_uniform_(self.to_out[0].weight)
+
+     def forward(self, q, kv, attn_mask=None, pair_bias=None):
+         """
+         Args:
+             q: [B, tok, n_q, c_query]
+             kv: [B, tok, n_kv, c_kv]
+             attn_mask: [n_q, n_kv]
+             pair_bias: [B, tok, n_q, n_kv, c_pair] (optional)
+         Returns:
+             attn_out: [B, tok, n_q, c_query]
+         """
+         q = self.ln_q(q)
+         kv = self.ln_kv(kv)
+
+         q, k, v, g = self.to_q(q), self.to_k(kv), self.to_v(kv), self.to_g(q)
+
+         if self.kq_norm:
+             k = self.k_norm(k)
+             q = self.q_norm(q)
+
+         q, k, v, g = map(
+             lambda t: rearrange(t, "b t n (h c) -> b h t n c", h=self.n_head),
+             (q, k, v, g),
+         )  # [B, tok, n, h * c] -> [B, heads, tok, n, c]
+
+         attn = einsum("bhtqc,bhtkc->bhtqk", q, k) * self.scale
+
+         if pair_bias is not None:
+             b = self.to_b(pair_bias)
+             b = rearrange(b, "b t q k h -> b h t q k", h=self.n_head)
+             attn = attn + b
+
+         # Invalid query handling:
+         if attn_mask is not None:
+             attn = attn.masked_fill(~attn_mask[None, None], float("-inf"))
+
+             # Bugfix: queries with no valid keys need a constant value, otherwise NaNs
+             # appear in the forward graph. It is unclear why this causes instabilities,
+             # since the invalid queries are masked out later.
+             invalid_queries = torch.logical_not(
+                 torch.any(attn_mask, dim=-1, keepdim=False)
+             )  # [n_q,]
+             attn[:, :, invalid_queries, :] = 0.0
+
+         attn = F.softmax(attn, dim=-1)
+         attn_out = einsum("bhtqk,bhtkd->bhtqd", attn, v)
+         attn_out = attn_out * g
+
+         attn_out = rearrange(attn_out, "b h t n c -> b t n (h c)")
+         attn_out = self.to_out(attn_out)  # [B, tok, n_q, c_query]
+         return attn_out
+
+
+ class LocalAttentionPairBias(nn.Module):
+     def __init__(
+         self,
+         c_a,
+         c_s,
+         c_pair,
+         n_head,
+         kq_norm=True,
+         n_attn_seq_neighbours=2,
+         n_attn_keys=128,
+     ):
+         super().__init__()
+         self.c = c_a  # model dim, same as the input feature dim
+         self.n_head = n_head
+
+         self.to_q = linearNoBias(c_a, self.c)
+         self.to_k = linearNoBias(c_a, self.c)
+         self.to_v = linearNoBias(c_a, self.c)
+         self.to_b = linearNoBias(c_pair, self.n_head)
+         self.to_g = nn.Sequential(
+             linearNoBias(c_a, self.c, bias=False),
+             nn.Sigmoid(),
+         )
+         self.kq_norm = kq_norm
+         if kq_norm:
+             self.ln_q = RMSNorm(self.c)
+             self.ln_k = RMSNorm(self.c)
+
+         # Output / input projections
+         self.to_o = linearNoBias(self.c, c_a)  # from attn output back to Q_L
+
+         # Conditioned path (adaLN) vs. unconditioned path (plain RMSNorm)
+         if exists(c_s):
+             self.ada_ln_1 = AdaLN(c_a=c_a, c_s=c_s)
+             self.linear_output_project = nn.Sequential(
+                 LinearBiasInit(c_s, c_a, biasinit=-2.0),
+                 nn.Sigmoid(),
+             )
+         else:
+             self.ln_1 = RMSNorm(c_a)
+
+         # Used if no indices are provided
+         self.n_attn_seq_neighbours = n_attn_seq_neighbours
+         self.n_attn_keys = n_attn_keys
+         self.use_checkpointing = True
+
+     def forward(
+         self,
+         Q_L,
+         C_L,
+         P_LL,
+         indices=None,
+         f=None,
+         X_L=None,
+         full=False,
+         chunked_pairwise_embedder=None,
+         initializer_outputs=None,
+     ):
+         """
+         Q_L: [D, L, c_a]
+         C_L: [D, L, c_s]
+         P_LL: [D, L, L, c_pair] or None (if using chunked mode)
+         indices: [D, L, k] long
+         chunked_pairwise_embedder: ChunkedPairwiseEmbedder for memory-efficient computation
+         initializer_outputs: dict containing features for chunked computation
+         """
+         # If no indices are provided, build them from the features f (and coordinates X_L)
+         if not exists(indices):
+             indices = create_attention_indices(
+                 f,
+                 n_attn_keys=self.n_attn_keys,
+                 n_attn_seq_neighbours=self.n_attn_seq_neighbours,
+                 X_L=X_L,
+             )
+
+         # Handle chunked P_LL computation
+         if chunked_pairwise_embedder is not None and P_LL is None:
+             # Compute sparse P_LL only for the attention indices
+             P_LL_sparse = chunked_pairwise_embedder.forward_chunked(
+                 f=f,
+                 indices=indices,
+                 C_L=initializer_outputs["C_L"],
+                 Z_init_II=initializer_outputs["Z_II"],
+                 tok_idx=f["atom_to_token_map"],
+             )
+             # P_LL_sparse is already in sparse format [D, L, k, c_pair]
+             use_sparse_pll = True
+         else:
+             # Original full P_LL computation
+             P_LL_sparse = None
+             use_sparse_pll = False
+
+         use_kernel = False
+
+         def do_attention(Q_L, C_L, P_LL):
+             if exists(C_L):
+                 Q_L = self.ada_ln_1(Q_L, C_L)
+             else:
+                 Q_L = self.ln_1(Q_L)
+
+             if use_kernel and not use_sparse_pll:
+                 # TODO: Update with latest kernel
+                 q, k, v, g, b = (
+                     self.to_q(Q_L),
+                     self.to_k(Q_L),
+                     self.to_v(Q_L),
+                     self.to_g(Q_L),
+                     self.to_b(P_LL),
+                 )
+                 q, k = (self.ln_q(q), self.ln_k(k)) if self.kq_norm else (q, k)
+                 attn_out = _fused_full_pairbias_attention(
+                     Q_L=q,  # already projected queries (B, L, c)
+                     K_L=k,
+                     V_L=v,
+                     P_LL=P_LL,  # pair features (B, L, L, c_pair)
+                     num_heads=self.n_head,
+                     to_b=None,  # pair-bias projector (H, c_pair)
+                     to_g_linear=None,  # gating linear (D, D)
+                     to_o_linear=None,  # output linear (D, D)
+                     w_ln_z_identity=None,
+                     b_ln_z_identity=None,
+                     attn_scale=1.0 / math.sqrt(self.c // self.n_head),
+                 )
+             else:
+                 # Sparse attention path
+                 q, k, v, g = (
+                     self.to_q(Q_L),
+                     self.to_k(Q_L),
+                     self.to_v(Q_L),
+                     self.to_g(Q_L),
+                 )
+                 q, k = (self.ln_q(q), self.ln_k(k)) if self.kq_norm else (q, k)
+
+                 if use_sparse_pll:
+                     # Use pre-computed sparse P_LL (already gathered)
+                     b = self.to_b(P_LL_sparse)  # [D, L, k, H]
+                     attn_out = sparse_pairbias_attention(
+                         Q=q,
+                         K=k,
+                         V=v,
+                         B=b,
+                         G=g,
+                         gather_bias=False,  # already gathered!
+                         indices=indices,
+                         H=self.n_head,
+                         full=full,
+                     )  # [D, L, c]
+                 else:
+                     # Original full P_LL path
+                     b = self.to_b(P_LL)
+                     attn_out = sparse_pairbias_attention(
+                         Q=q,
+                         K=k,
+                         V=v,
+                         B=b,
+                         G=g,
+                         gather_bias=True,
+                         indices=indices,
+                         H=self.n_head,
+                         full=full,
+                     )  # [D, L, c]
+
+             # Output projection (from adaLN-Zero)
+             Q_L = self.to_o(attn_out)
+             if exists(C_L):
+                 Q_L = self.linear_output_project(C_L) * Q_L
+
+             return Q_L
+
+         do_attention_ = (
+             activation_checkpointing(do_attention)
+             if self.use_checkpointing
+             else do_attention
+         )
+
+         # Call attention with the appropriate P_LL
+         if use_sparse_pll:
+             return do_attention_(Q_L, C_L, P_LL_sparse)
+         else:
+             return do_attention_(Q_L, C_L, P_LL)
+
+
+ ######################################################################################
+ ########################## Kernel Functions ##########################
+ ######################################################################################
+
+
+ def sparse_pairbias_attention(
+     Q, K, V, B, indices, H, gather_bias=True, G=None, full=False
+ ):
+     """
+     Computes attention with a sparse pairwise bias, where `indices` specifies which
+     keys each query token attends to.
+
+     Q: (D, L, c)  # query vectors
+     K: (D, L, c)  # key vectors
+     V: (D, L, c)  # value vectors
+     B: (L, L, H), (D, L, L, H), or pre-gathered (D, L, k, H)  # attention bias
+     G: (D, L, c)  # gate (optional)
+     indices: (D, L, k_neigh) long  # indices of the neighbour keys to attend to
+
+     Returns
+     -------
+     attn_out: (D, L, c)  # attention output
+     """
+     D, L, c = Q.shape
+     k = indices.shape[-1]  # k_neigh
+
+     if full:
+         # During training, compute the full attention matrix to create a more
+         # optimized computation graph.
+         return pairbias_attention_(
+             Q=Q,
+             K=K,
+             V=V,
+             B=B,
+             H=H,
+             valid_mask=indices_to_mask(indices),
+             G=G,
+         )
+
+     # Gather key/value vectors along dimension 1 according to each query's neighbour indices
+     batch_idx = torch.arange(D, device=Q.device).view(-1, 1, 1)  # (D, 1, 1)
+     K_gathered = K[batch_idx, indices].contiguous()  # (D, L, k, c)
+     V_gathered = V[batch_idx, indices].contiguous()  # (D, L, k, c)
+
+     # Gather the bias, or assume it is pre-gathered
+     if gather_bias:
+         query_idx = torch.arange(L, device=Q.device).view(1, L, 1)  # (1, L, 1)
+         query_idx = query_idx.expand(D, -1, k)
+         if B.ndim == 3:  # (L, L, H), unbatched
+             B_gathered = B[query_idx, indices, :]  # (D, L, k, H)
+         elif B.ndim == 4:  # (D, L, L, H)
+             B_gathered = B[batch_idx, query_idx, indices, :]  # (D, L, k, H)
+     else:
+         assert B.shape == (D, L, k, H), "B must be pre-gathered with shape (D, L, k, H)"
+         B_gathered = B
+     B_gathered = B_gathered.contiguous()
+
+     # Split into heads
+     Q = Q.reshape(D, L, H, c // H)
+     K_gathered = K_gathered.reshape(D, L, k, H, c // H)
+     V_gathered = V_gathered.reshape(D, L, k, H, c // H)
+     B_gathered = B_gathered.reshape(D, L, k, H)
+     Q = Q.permute(0, 2, 1, 3)  # (D, H, L, c // H)
+     K_gathered = K_gathered.permute(0, 3, 1, 2, 4)
+     V_gathered = V_gathered.permute(0, 3, 1, 2, 4)
+     B_gathered = B_gathered.permute(0, 3, 1, 2)
+
+     # Do attention
+     attn = torch.einsum("...ld,...lkd->...lk", Q, K_gathered)
+     attn = attn / sqrt(c // H)  # scale
+     attn = attn + B_gathered  # add bias
+     attn = torch.softmax(attn, dim=-1)  # softmax over keys (D, H, L, k)
+     attn_out = torch.einsum(
+         "...ij,...ijc->...ic", attn, V_gathered
+     )  # allocates a max of 4.95 GiB
+
+     # Optional gating
+     if G is not None:
+         G = G.reshape(D, L, H, c // H).permute(0, 2, 1, 3)
+         attn_out = attn_out * G
+
+     # Merge heads
+     attn_out = attn_out.permute(0, 2, 1, 3)
+     attn_out = attn_out.reshape(D, L, c).contiguous()
+
+     return attn_out  # (D, L, c)
+
+
+ def pairbias_attention_(Q, K, V, B, H, valid_mask=None, G=None):
+     """
+     Fully connected variant of pair-bias attention with optional gating and valid mask.
+     Equivalent to the sparse attention above, but attends over all keys.
+
+     attn_out: [batch_size, query_length, H * head_dim]
+     """
+     D, L, c = Q.shape
+     k = L
+
+     # Split into heads
+     Q = Q.reshape(D, L, H, c // H)
+     K = K.reshape(D, k, H, c // H)
+     V = V.reshape(D, k, H, c // H)
+     B = B.reshape(D, L, k, H)
+
+     # Move heads up front: [..., H, head_dim] -> [D, H, ..., head_dim]
+     Q = Q.permute(0, 2, 1, 3)  # (D, H, L, c // H)
+     K = K.permute(0, 2, 1, 3)
+     V = V.permute(0, 2, 1, 3)
+     B = B.permute(0, 3, 1, 2)
+
+     # Do attention
+     attn = torch.einsum("...ld,...kd->...lk", Q, K)
+     attn = attn / sqrt(c // H)  # scale
+     attn = attn + B  # add bias
+     if exists(valid_mask):
+         # Expand the valid mask over heads: (D, H, L, L)
+         attn = attn.masked_fill(~valid_mask.unsqueeze(1), float("-inf"))
+     attn = torch.softmax(attn, dim=-1)  # softmax over keys (D, H, L, k)
+     attn_out = torch.einsum("...ij,...jc->...ic", attn, V)
+
+     # Optional gating
+     if G is not None:
+         G = G.reshape(D, L, H, c // H).permute(0, 2, 1, 3)
+         attn_out = attn_out * G
+
+     # Merge heads
+     attn_out = attn_out.permute(0, 2, 1, 3)
+     attn_out = attn_out.reshape(D, L, c).contiguous()
+
+     return attn_out
+
+
+ def _fused_full_pairbias_attention(
+     *,
+     Q_L,  # (B, L, c) -- sequence features used to make q, k, v and for gating
+     K_L,  # (B, L, c)
+     V_L,  # (B, L, c)
+     P_LL,  # (B, L, L, c_pair)
+     num_heads: int,
+     to_b: nn.Linear,  # projects pair features -> heads (H)
+     to_g_linear: nn.Linear,  # weight (D, D), bias optional/None (pre-sigmoid; the kernel applies the gate)
+     to_o_linear: nn.Linear,  # weight (D, D), bias optional/None (the kernel applies the output projection)
+     w_ln_z_identity: torch.Tensor,  # (c_pair,)
+     b_ln_z_identity: torch.Tensor,  # (c_pair,)
+     attn_scale: float | None = None,
+ ):
+     """
+     Uses cuequivariance_torch.attention_pair_bias for dense (full) attention.
+     Expects Q/K/V to be projected *before* calling this function.
+     """
+     B, L, c = Q_L.shape
+     H = num_heads
+     assert c % H == 0, "Model dim must be divisible by num_heads"
+     DH = c // H
+
+     # q, k, v as (B, H, L, DH)
+     q = Q_L.reshape(B, L, H, DH).permute(0, 2, 1, 3).contiguous()
+     k = K_L.reshape(B, L, H, DH).permute(0, 2, 1, 3).contiguous()
+     v = V_L.reshape(B, L, H, DH).permute(0, 2, 1, 3).contiguous()
+
+     # s holds the sequence features for the gating/output projections
+     s = Q_L.contiguous()  # (B, L, c)
+
+     # mask: None (the kernel supports a key-padding mask of shape (B, V) or (B*M, V); not needed here)
+     mask = None
+
+     # Weights/biases for the kernel (shapes per the kernel docs):
+     # w_proj_z: (H, z_dim)
+     w_proj_z = to_b.weight  # (H, c_pair)
+     b_proj_z = to_b.bias if hasattr(to_b, "bias") else None
+
+     # w_proj_g / w_proj_o: (D, D)
+     w_proj_g = to_g_linear.weight  # (D, D)
+     b_proj_g = to_g_linear.bias if hasattr(to_g_linear, "bias") else None
+
+     w_proj_o = to_o_linear.weight  # (D, D)
+     b_proj_o = to_o_linear.bias if hasattr(to_o_linear, "bias") else None
+
+     # z-LayerNorm parameters
+     w_ln_z = w_ln_z_identity.to(dtype=P_LL.dtype, device=P_LL.device)
+     b_ln_z = b_ln_z_identity.to(dtype=P_LL.dtype, device=P_LL.device)
+
+     # Optional scaling (matches the manual path)
+     if attn_scale is None:
+         attn_scale = 1.0 / math.sqrt(DH)
+
+     # Call the fused kernel (B*M collapses to B here; multiplicity=1)
+     out, _proj_z = cueq_attention_pair_bias(
+         s=s,
+         q=q,
+         k=k,
+         v=v,
+         z=P_LL,
+         mask=mask,
+         num_heads=H,
+         w_proj_z=w_proj_z,
+         w_proj_g=w_proj_g,
+         w_proj_o=w_proj_o,
+         w_ln_z=w_ln_z,
+         b_ln_z=b_ln_z,
+         b_proj_z=b_proj_z,
+         b_proj_g=b_proj_g,
+         b_proj_o=b_proj_o,
+         attn_scale=attn_scale,
+         compute_pair_bias=True,
+         multiplicity=1,
+     )
+     # out: (B, L, c), already gated and projected
+     return out
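
For orientation, below is a minimal, illustrative usage sketch of the sparse path with random tensors. It is not part of the wheel: it assumes the hunk above corresponds to rfd3/model/layers/attention.py (the only file in the listing with +577 lines) and that the package and its torch dependency are importable. Shapes follow the docstring of sparse_pairbias_attention; all tensors are placeholders.

import torch

from rfd3.model.layers.attention import sparse_pairbias_attention  # assumed module path

# Toy sizes: D structures, L atoms, c channels, H heads, k neighbour keys per query.
D, L, c, H, k = 2, 16, 64, 4, 8

Q = torch.randn(D, L, c)                  # query vectors
K = torch.randn(D, L, c)                  # key vectors
V = torch.randn(D, L, c)                  # value vectors
B = torch.randn(L, L, H)                  # unbatched pair bias; gathered inside the function
G = torch.sigmoid(torch.randn(D, L, c))   # optional gate
indices = torch.randint(0, L, (D, L, k))  # which keys each query attends to

out = sparse_pairbias_attention(Q=Q, K=K, V=V, B=B, G=G, indices=indices, H=H)
print(out.shape)  # torch.Size([2, 16, 64])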