rc-foundry 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (180)
  1. foundry/__init__.py +57 -0
  2. foundry/callbacks/__init__.py +5 -0
  3. foundry/callbacks/callback.py +116 -0
  4. foundry/callbacks/health_logging.py +419 -0
  5. foundry/callbacks/metrics_logging.py +211 -0
  6. foundry/callbacks/timing_logging.py +67 -0
  7. foundry/callbacks/train_logging.py +278 -0
  8. foundry/common.py +108 -0
  9. foundry/constants.py +28 -0
  10. foundry/hydra/resolvers.py +77 -0
  11. foundry/inference_engines/base.py +235 -0
  12. foundry/inference_engines/checkpoint_registry.py +66 -0
  13. foundry/metrics/__init__.py +12 -0
  14. foundry/metrics/losses.py +30 -0
  15. foundry/metrics/metric.py +319 -0
  16. foundry/model/layers/blocks.py +47 -0
  17. foundry/testing/__init__.py +6 -0
  18. foundry/testing/fixtures.py +19 -0
  19. foundry/testing/pytest_hooks.py +15 -0
  20. foundry/trainers/fabric.py +923 -0
  21. foundry/training/EMA.py +67 -0
  22. foundry/training/checkpoint.py +61 -0
  23. foundry/training/schedulers.py +91 -0
  24. foundry/utils/alignment.py +86 -0
  25. foundry/utils/components.py +415 -0
  26. foundry/utils/datasets.py +405 -0
  27. foundry/utils/ddp.py +103 -0
  28. foundry/utils/instantiators.py +72 -0
  29. foundry/utils/logging.py +279 -0
  30. foundry/utils/rigid.py +1460 -0
  31. foundry/utils/rotation_augmentation.py +65 -0
  32. foundry/utils/squashfs.py +172 -0
  33. foundry/utils/torch.py +317 -0
  34. foundry/utils/weights.py +271 -0
  35. foundry/version.py +34 -0
  36. foundry_cli/__init__.py +3 -0
  37. foundry_cli/download_checkpoints.py +281 -0
  38. mpnn/__init__.py +1 -0
  39. mpnn/collate/feature_collator.py +265 -0
  40. mpnn/inference.py +53 -0
  41. mpnn/inference_engines/mpnn.py +549 -0
  42. mpnn/loss/nll_loss.py +122 -0
  43. mpnn/metrics/nll.py +369 -0
  44. mpnn/metrics/sequence_recovery.py +440 -0
  45. mpnn/model/layers/graph_embeddings.py +2372 -0
  46. mpnn/model/layers/message_passing.py +332 -0
  47. mpnn/model/layers/position_wise_feed_forward.py +44 -0
  48. mpnn/model/layers/positional_encoding.py +98 -0
  49. mpnn/model/mpnn.py +2632 -0
  50. mpnn/pipelines/mpnn.py +162 -0
  51. mpnn/samplers/samplers.py +167 -0
  52. mpnn/train.py +341 -0
  53. mpnn/trainers/mpnn.py +193 -0
  54. mpnn/transforms/feature_aggregation/mpnn.py +184 -0
  55. mpnn/transforms/feature_aggregation/polymer_ligand_interface.py +76 -0
  56. mpnn/transforms/feature_aggregation/token_encodings.py +132 -0
  57. mpnn/transforms/feature_aggregation/user_settings.py +347 -0
  58. mpnn/transforms/polymer_ligand_interface.py +164 -0
  59. mpnn/utils/inference.py +2397 -0
  60. mpnn/utils/probability.py +37 -0
  61. mpnn/utils/weights.py +309 -0
  62. rc_foundry-0.1.1.dist-info/METADATA +239 -0
  63. rc_foundry-0.1.1.dist-info/RECORD +180 -0
  64. rc_foundry-0.1.1.dist-info/WHEEL +4 -0
  65. rc_foundry-0.1.1.dist-info/entry_points.txt +5 -0
  66. rc_foundry-0.1.1.dist-info/licenses/LICENSE.md +28 -0
  67. rf3/__init__.py +3 -0
  68. rf3/_version.py +33 -0
  69. rf3/alignment.py +79 -0
  70. rf3/callbacks/dump_validation_structures.py +101 -0
  71. rf3/callbacks/metrics_logging.py +324 -0
  72. rf3/chemical.py +1529 -0
  73. rf3/cli.py +77 -0
  74. rf3/data/cyclic_transform.py +78 -0
  75. rf3/data/extra_xforms.py +36 -0
  76. rf3/data/ground_truth_template.py +463 -0
  77. rf3/data/paired_msa.py +206 -0
  78. rf3/data/pipeline_utils.py +128 -0
  79. rf3/data/pipelines.py +558 -0
  80. rf3/diffusion_samplers/inference_sampler.py +222 -0
  81. rf3/inference.py +65 -0
  82. rf3/inference_engines/__init__.py +5 -0
  83. rf3/inference_engines/rf3.py +735 -0
  84. rf3/kinematics.py +354 -0
  85. rf3/loss/af3_confidence_loss.py +515 -0
  86. rf3/loss/af3_losses.py +655 -0
  87. rf3/loss/loss.py +179 -0
  88. rf3/metrics/chiral.py +179 -0
  89. rf3/metrics/clashing_chains.py +68 -0
  90. rf3/metrics/distogram.py +421 -0
  91. rf3/metrics/lddt.py +523 -0
  92. rf3/metrics/metadata.py +43 -0
  93. rf3/metrics/metric_utils.py +192 -0
  94. rf3/metrics/predicted_error.py +134 -0
  95. rf3/metrics/rasa.py +108 -0
  96. rf3/metrics/selected_distances.py +91 -0
  97. rf3/model/RF3.py +527 -0
  98. rf3/model/RF3_blocks.py +92 -0
  99. rf3/model/RF3_structure.py +303 -0
  100. rf3/model/layers/af3_auxiliary_heads.py +255 -0
  101. rf3/model/layers/af3_diffusion_transformer.py +544 -0
  102. rf3/model/layers/attention.py +313 -0
  103. rf3/model/layers/layer_utils.py +127 -0
  104. rf3/model/layers/mlff.py +118 -0
  105. rf3/model/layers/outer_product.py +59 -0
  106. rf3/model/layers/pairformer_layers.py +783 -0
  107. rf3/model/layers/structure_bias.py +56 -0
  108. rf3/scoring.py +1787 -0
  109. rf3/symmetry/resolve.py +284 -0
  110. rf3/train.py +194 -0
  111. rf3/trainers/rf3.py +570 -0
  112. rf3/util_module.py +47 -0
  113. rf3/utils/frames.py +109 -0
  114. rf3/utils/inference.py +665 -0
  115. rf3/utils/io.py +198 -0
  116. rf3/utils/loss.py +72 -0
  117. rf3/utils/predict_and_score.py +165 -0
  118. rf3/utils/predicted_error.py +673 -0
  119. rf3/utils/recycling.py +42 -0
  120. rf3/validate.py +140 -0
  121. rfd3/.gitignore +7 -0
  122. rfd3/Makefile +76 -0
  123. rfd3/__init__.py +12 -0
  124. rfd3/callbacks.py +66 -0
  125. rfd3/cli.py +41 -0
  126. rfd3/constants.py +212 -0
  127. rfd3/engine.py +543 -0
  128. rfd3/inference/datasets.py +193 -0
  129. rfd3/inference/input_parsing.py +1123 -0
  130. rfd3/inference/legacy_input_parsing.py +717 -0
  131. rfd3/inference/parsing.py +165 -0
  132. rfd3/inference/symmetry/atom_array.py +298 -0
  133. rfd3/inference/symmetry/checks.py +241 -0
  134. rfd3/inference/symmetry/contigs.py +63 -0
  135. rfd3/inference/symmetry/frames.py +355 -0
  136. rfd3/inference/symmetry/symmetry_utils.py +398 -0
  137. rfd3/metrics/design_metrics.py +465 -0
  138. rfd3/metrics/hbonds_hbplus_metrics.py +308 -0
  139. rfd3/metrics/hbonds_metrics.py +389 -0
  140. rfd3/metrics/losses.py +325 -0
  141. rfd3/metrics/metrics_utils.py +118 -0
  142. rfd3/metrics/sidechain_metrics.py +349 -0
  143. rfd3/model/RFD3.py +105 -0
  144. rfd3/model/RFD3_diffusion_module.py +387 -0
  145. rfd3/model/cfg_utils.py +81 -0
  146. rfd3/model/inference_sampler.py +635 -0
  147. rfd3/model/layers/attention.py +577 -0
  148. rfd3/model/layers/block_utils.py +580 -0
  149. rfd3/model/layers/blocks.py +777 -0
  150. rfd3/model/layers/chunked_pairwise.py +377 -0
  151. rfd3/model/layers/encoders.py +417 -0
  152. rfd3/model/layers/layer_utils.py +197 -0
  153. rfd3/model/layers/pairformer_layers.py +128 -0
  154. rfd3/run_inference.py +45 -0
  155. rfd3/testing/debug.py +139 -0
  156. rfd3/testing/debug_utils.py +73 -0
  157. rfd3/testing/testing_utils.py +356 -0
  158. rfd3/train.py +194 -0
  159. rfd3/trainer/dump_validation_structures.py +154 -0
  160. rfd3/trainer/fabric_trainer.py +923 -0
  161. rfd3/trainer/recycling.py +42 -0
  162. rfd3/trainer/rfd3.py +485 -0
  163. rfd3/trainer/trainer_utils.py +502 -0
  164. rfd3/transforms/conditioning_base.py +508 -0
  165. rfd3/transforms/conditioning_utils.py +200 -0
  166. rfd3/transforms/design_transforms.py +807 -0
  167. rfd3/transforms/dna_crop.py +523 -0
  168. rfd3/transforms/hbonds.py +407 -0
  169. rfd3/transforms/hbonds_hbplus.py +246 -0
  170. rfd3/transforms/ncaa_transforms.py +153 -0
  171. rfd3/transforms/pipelines.py +632 -0
  172. rfd3/transforms/ppi_transforms.py +541 -0
  173. rfd3/transforms/rasa.py +116 -0
  174. rfd3/transforms/symmetry.py +76 -0
  175. rfd3/transforms/training_conditions.py +552 -0
  176. rfd3/transforms/util_transforms.py +498 -0
  177. rfd3/transforms/virtual_atoms.py +305 -0
  178. rfd3/utils/inference.py +648 -0
  179. rfd3/utils/io.py +245 -0
  180. rfd3/utils/vizualize.py +276 -0
@@ -0,0 +1,580 @@
+ import logging
+ from typing import Tuple
+
+ import torch
+ import torch.nn.functional as F
+ from jaxtyping import Float, Int
+
+ logger = logging.getLogger(__name__)
+
+
+ def bucketize_scaled_distogram(R_L, min_dist=1, max_dist=30, sigma_data=16, n_bins=65):
+     """
+     Bucketizes pairwise distances into bins based on EDM scaling.
+
+     min_dist and max_dist are given in angstroms; bin edges are placed in
+     scaled (angstrom / sigma_data) distance space.
+
+     R_L: B, N, 3
+     D_LL: B, N, N
+     D_LL_binned: B, N, N, n_bins
+     """
+     D_LL = R_L.unsqueeze(-2) - R_L.unsqueeze(-3)  # [B, N, N, 3]
+     D_LL = torch.linalg.norm(D_LL, dim=-1)  # [B, N, N]
+
+     # normalize bin range to scaled distance space
+     min_dist, max_dist = min_dist / sigma_data, max_dist / sigma_data
+
+     bins = torch.linspace(min_dist, max_dist, n_bins - 1, device=D_LL.device)
+     bin_idxs = torch.bucketize(D_LL, bins)
+     return F.one_hot(bin_idxs, num_classes=len(bins) + 1).float()
+
+
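A minimal usage sketch for the bucketizer above; shapes are illustrative and the coordinates are assumed to already be in scaled (coordinate / sigma_data) units, matching the bin scaling:

    R_L = torch.randn(2, 100, 3)                    # hypothetical scaled coordinates, (B, N, 3)
    D_LL_binned = bucketize_scaled_distogram(R_L)   # (2, 100, 100, 65) one-hot distogram
    assert D_LL_binned.shape == (2, 100, 100, 65)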
+ def build_valid_mask(
+     tok_idx: torch.Tensor, n_atoms_per_tok_max: int | None = None
+ ) -> torch.Tensor:
+     """
+     Args
+     ----
+     tok_idx : (n_atoms,) non-negative integer array
+     n_atoms_per_tok_max : if given, pad/truncate up to this size
+
+     Returns
+     -------
+     valid_mask : (n_tokens, A) True where an atom exists
+     """
+     tokens, counts = torch.unique(tok_idx, return_counts=True)
+     A = int(counts.max()) if n_atoms_per_tok_max is None else int(n_atoms_per_tok_max)
+
+     # build [n_tokens, A] mask; broadcasting keeps it vectorised
+     atom_idx_grid = torch.arange(A, device=tok_idx.device)[None, :]  # (1, A)
+     valid_mask = atom_idx_grid < counts[:, None]  # (n_tok, A)
+
+     return valid_mask
+
+
+ def ungroup_atoms(Q_L, valid_mask):
+     """
+     Args
+     ----
+     Q_L : (B, n_atoms, c)
+     valid_mask : (n_tokens, A)  # same object returned by `build_valid_mask`
+
+     Returns
+     -------
+     Q_IA : (B, n_tokens, A, c)  # padded with zeros
+     """
+     B, n_atoms, c = Q_L.shape
+     n_tokens, A = valid_mask.shape
+     Q_IA = torch.zeros(B, n_tokens, A, c, dtype=Q_L.dtype, device=Q_L.device)
+     mask4d = valid_mask.unsqueeze(0).unsqueeze(-1)  # (1, n_tok, A, 1)
+     mask4d = mask4d.expand(B, -1, -1, c)  # (B, n_tok, A, c)
+     Q_IA.masked_scatter_(mask4d, Q_L)
+     return Q_IA
+
+
+ def group_atoms(Q_IA: torch.Tensor, valid_mask: torch.Tensor) -> torch.Tensor:
+     """
+     Args
+     ----
+     Q_IA : (B, n_tokens, A, c)
+     valid_mask : (n_tokens, A)
+
+     Returns
+     -------
+     Q_L : (B, n_atoms, c) flattened real atoms, order preserved
+     """
+     B, _, _, c = Q_IA.shape
+     mask4d = valid_mask.unsqueeze(0).unsqueeze(-1).expand(B, -1, -1, c)  # (B, n_tok, A, c)
+     Q_L = Q_IA[mask4d].view(B, -1, c)  # restore (B, n_atoms, c) shape
+     return Q_L
+
+
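A small round-trip sketch for the grouping helpers above (the token map and sizes are made up for illustration):

    tok_idx = torch.tensor([0, 0, 1, 1, 1])          # token 0 has 2 atoms, token 1 has 3
    Q_L = torch.randn(4, 5, 8)                       # (B, n_atoms, c)
    valid_mask = build_valid_mask(tok_idx)           # (2, 3) boolean
    Q_IA = ungroup_atoms(Q_L, valid_mask)            # (4, 2, 3, 8), zero-padded per token
    assert torch.equal(group_atoms(Q_IA, valid_mask), Q_L)   # grouping is lossless for real atoms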
+ def group_pair(P_IAA, valid_mask):
+     # valid_mask: [I, A] (tokens x atoms-per-token)
+     # P_IAA: (B, I, A, A, c) or (I, A, A, c)
+     if P_IAA.ndim == 5:
+         B, _, _, A, c = P_IAA.shape
+         mask5d = valid_mask[None, ..., None, None].expand(
+             B, -1, -1, A, c
+         )  # (B, I, A, A, c)
+         P_LA = P_IAA[mask5d].view(B, -1, A, c)  # (B, n_valid, A, c)
+     elif P_IAA.ndim == 4:
+         _, _, A, c = P_IAA.shape
+         mask4d = valid_mask[..., None, None].expand(-1, -1, A, c)  # (I, A, A, c)
+         P_LA = P_IAA[mask4d].view(-1, A, c)  # (n_valid, A, c)
+     else:
+         raise ValueError(
+             f"Unexpected input shape {P_IAA.shape}: must be (B, I, A, A, c) or (I, A, A, c)"
+         )
+
+     return P_LA
+
+
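A shape-only sketch of group_pair, reusing the hypothetical token map from the round-trip example above:

    vm = build_valid_mask(torch.tensor([0, 0, 1, 1, 1]))   # (I=2, A=3)
    P_IAA = torch.randn(2, 3, 3, 4)                        # (I, A, A, c)
    P_LA = group_pair(P_IAA, vm)                           # (5, 3, 4): one (A, c) slab per real atom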
+ def scatter_add_pair_features(P_LK_tgt, P_LK_indices, P_LA_src, P_LA_indices):
+     """
+     Adds features from P_LA_src into P_LK_tgt at positions where P_LA_indices matches P_LK_indices.
+
+     Parameters
+     ----------
+     P_LK_indices : (D, L, k) LongTensor
+         Key indices | P_LK_indices[d, i, k] = global atom index that atom i attends to.
+     P_LK_tgt : (D, L, k, c) FloatTensor
+         Key features to scatter-add into.
+
+     P_LA_indices : (D, L, a) LongTensor
+         Indices of the additional features to scatter into P_LK_tgt.
+     P_LA_src : (D, L, a, c) FloatTensor
+         Features corresponding to P_LA_indices.
+
+     Both index tensors contain indices over the D batch dim, the L sequence
+     positions, and k keys / a additional features. Features from P_LA_src are
+     scattered into P_LK_tgt wherever the two index tensors match.
+     """
+     # Handle case when indices and P_LA_src don't have batch dimensions
+     D, L, k = P_LK_indices.shape
+     if P_LA_indices.ndim == 2:
+         P_LA_indices = P_LA_indices.unsqueeze(0).expand(D, -1, -1)
+     if P_LA_src.ndim == 3:
+         P_LA_src = P_LA_src.unsqueeze(0).expand(D, -1, -1, -1)
+     assert (
+         P_LA_src.shape[-1] == P_LK_tgt.shape[-1]
+     ), "Channel dims do not match, got: {} vs {}".format(
+         P_LA_src.shape[-1], P_LK_tgt.shape[-1]
+     )
+
+     matches = P_LA_indices.unsqueeze(-1) == P_LK_indices.unsqueeze(-2)  # (D, L, a, k)
+     n_matches = matches.sum(dim=-1)  # (D, L, a) keys matched per source atom
+     if not torch.all(n_matches >= 1):
+         raise ValueError("Did not find a scatter index for every atom")
+     if not torch.all(n_matches <= 1):
+         raise ValueError("Found multiple scatter indices for some atoms")
+     k_indices = matches.long().argmax(dim=-1)  # (D, L, a)
+     scatter_indices = k_indices.unsqueeze(-1).expand(
+         -1, -1, -1, P_LK_tgt.shape[-1]
+     )  # (D, L, a, c)
+     P_LK_tgt = P_LK_tgt.scatter_add(dim=2, index=scatter_indices, src=P_LA_src)
+     return P_LK_tgt
+
+
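A tiny worked sketch of the scatter-add above (all values hypothetical); each source index appears exactly once among the corresponding keys:

    P_LK_indices = torch.tensor([[[5, 7, 9], [2, 4, 6]]])   # (D=1, L=2, k=3)
    P_LK_tgt = torch.zeros(1, 2, 3, 2)                       # (D, L, k, c=2)
    P_LA_indices = torch.tensor([[[7], [6]]])                # (D, L, a=1)
    P_LA_src = torch.ones(1, 2, 1, 2)                        # (D, L, a, c)
    out = scatter_add_pair_features(P_LK_tgt, P_LK_indices, P_LA_src, P_LA_indices)
    # ones were added at key slot 1 of position 0 (index 7) and key slot 2 of position 1 (index 6)
    assert torch.equal(out[0, 0, 1], torch.ones(2)) and torch.equal(out[0, 1, 2], torch.ones(2))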
+ def _batched_gather(values: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
+     """
+     values : (D, L, C)
+     idx    : (D, L, k)
+     returns: (D, L, k, C)
+     """
+     D, L, C = values.shape
+     k = idx.shape[-1]
+
+     # (D, L, 1, C) → stride-0 along k → (D, L, k, C)
+     src = values.unsqueeze(2).expand(-1, -1, k, -1)
+     idx = idx.unsqueeze(-1).expand(-1, -1, -1, C)  # (D, L, k, C)
+
+     return torch.gather(src, 1, idx)  # dim=1 is the L-axis
+
+
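A quick sanity sketch of the gather helper (values are arbitrary):

    vals = torch.arange(6.0).view(1, 3, 2)           # (D=1, L=3, C=2)
    idx = torch.tensor([[[2, 0], [0, 1], [1, 1]]])   # (D, L, k=2) neighbour indices
    out = _batched_gather(vals, idx)                  # (1, 3, 2, 2)
    assert torch.equal(out[0, 0, 0], vals[0, 2])      # position 0 gathers its first neighbour, atom 2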
+ @torch.no_grad()
+ def create_attention_indices(
+     f, n_attn_keys, n_attn_seq_neighbours, X_L=None, tok_idx=None
+ ):
+     """
+     Entry-point function for creating attention indices for sequence- & structure-local attention.
+
+     f: input features of the model
+     n_attn_keys: number of (atom) attention keys
+     n_attn_seq_neighbours: number of neighbouring sequence tokens (residues) to attend to
+     X_L: optional input tensor for atom positions | if None, random coordinates are drawn
+     tok_idx: optional atom-to-token map | defaults to f["atom_to_token_map"]
+     """
+
+     tok_idx = f["atom_to_token_map"] if tok_idx is None else tok_idx
+     device = X_L.device if X_L is not None else tok_idx.device
+     L = len(tok_idx)
+
+     if X_L is None:
+         X_L = torch.randn(
+             (1, L, 3), device=device, dtype=torch.float
+         )  # (1, L, 3) - random coordinates
+     D_LL = torch.cdist(X_L, X_L, p=2)  # [D, L, L] - pairwise atom distances
+
+     # Create attention indices using neighbour distances
+     base_mask = ~f["unindexing_pair_mask"][
+         tok_idx[None, :], tok_idx[:, None]
+     ]  # [n_atoms, n_atoms]
+     k_actual = min(n_attn_keys, L)
+
+     # For symmetric structures, ensure inter-chain interactions are included
+     chain_ids = f["asym_id"][tok_idx] if "asym_id" in f else None
+     if (
+         chain_ids is not None and len(torch.unique(chain_ids)) > 3
+     ):  # structure with many chains
+         # Reserve 25% of attention keys for inter-chain interactions
+         k_inter_chain = max(32, k_actual // 4)  # At least 32 inter-chain keys
+         k_intra_chain = k_actual - k_inter_chain
+
+         attn_indices = get_sparse_attention_indices_with_inter_chain(
+             tok_idx,
+             D_LL,
+             n_seq_neighbours=n_attn_seq_neighbours,
+             k_intra=k_intra_chain,
+             k_inter=k_inter_chain,
+             chain_id=chain_ids,
+             base_mask=base_mask,
+         )
+     else:
+         # Regular attention for single chain or small structures
+         attn_indices = get_sparse_attention_indices(
+             tok_idx,
+             D_LL,
+             n_seq_neighbours=n_attn_seq_neighbours,
+             k_max=k_actual,
+             chain_id=chain_ids,
+             base_mask=base_mask,
+         )  # [D, L, k] | indices[d, i, j] = atom index of the j-th attention key for atom i
+
+     return attn_indices
+
+
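A runnable sketch of the entry point for a small single-chain case; the feature dict is a hypothetical stand-in holding only the two keys read here ("atom_to_token_map" and "unindexing_pair_mask"):

    tok_idx = torch.arange(12)                        # 12 atoms, one token each
    f = {
        "atom_to_token_map": tok_idx,
        "unindexing_pair_mask": torch.zeros(12, 12, dtype=torch.bool),   # nothing masked out
    }
    X_L = torch.randn(1, 12, 3)
    attn_idx = create_attention_indices(f, n_attn_keys=4, n_attn_seq_neighbours=1, X_L=X_L)
    assert attn_idx.shape == (1, 12, 4)               # sequence-local plus nearest keys per atom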
+ @torch.no_grad()
+ def get_sparse_attention_indices_with_inter_chain(
+     tok_idx, D_LL, n_seq_neighbours, k_intra, k_inter, chain_id, base_mask
+ ):
+     """
+     Create attention indices that guarantee inter-chain interactions for clash avoidance.
+
+     Args:
+         tok_idx: atom to token mapping
+         D_LL: pairwise distances [D, L, L]
+         n_seq_neighbours: number of sequence neighbors
+         k_intra: number of intra-chain attention keys
+         k_inter: number of inter-chain attention keys
+         chain_id: chain IDs for each atom
+         base_mask: base mask for valid pairs
+
+     Returns:
+         attn_indices: [D, L, k_total] where k_total = k_intra + k_inter
+     """
+     D, L, _ = D_LL.shape
+
+     # Get regular intra-chain indices (limited to k_intra)
+     intra_indices = get_sparse_attention_indices(
+         tok_idx, D_LL, n_seq_neighbours, k_intra, chain_id, base_mask
+     )  # [D, L, k_intra]
+
+     # Get inter-chain indices for clash avoidance
+     inter_indices = torch.zeros(D, L, k_inter, dtype=torch.long, device=D_LL.device)
+
+     for d in range(D):
+         for l in range(L):
+             query_chain = chain_id[l]
+
+             # Find atoms from different chains
+             other_chain_mask = (chain_id != query_chain) & base_mask[l, :]
+             other_chain_atoms = torch.where(other_chain_mask)[0]
+
+             if len(other_chain_atoms) > 0:
+                 # Get distances to other chains
+                 distances_to_other = D_LL[d, l, other_chain_atoms]
+
+                 # Select k_inter closest atoms from other chains
+                 n_select = min(k_inter, len(other_chain_atoms))
+                 _, closest_idx = torch.topk(distances_to_other, n_select, largest=False)
+                 selected_atoms = other_chain_atoms[closest_idx]
+
+                 # Fill inter-chain indices
+                 inter_indices[d, l, :n_select] = selected_atoms
+                 # Pad with random atoms if needed
+                 if n_select < k_inter:
+                     padding = torch.randint(
+                         0, L, (k_inter - n_select,), device=D_LL.device
+                     )
+                     inter_indices[d, l, n_select:] = padding
+             else:
+                 # No other chains found, fill with random indices
+                 inter_indices[d, l, :] = torch.randint(
+                     0, L, (k_inter,), device=D_LL.device
+                 )
+
+     # Combine intra and inter chain indices
+     combined_indices = torch.cat(
+         [intra_indices, inter_indices], dim=-1
+     )  # [D, L, k_total]
+
+     return combined_indices
+
+
+ @torch.no_grad()
+ def build_index_mask(
+     tok_idx: torch.Tensor,
+     n_sequence_neighbours: int,
+     k_max: int,
+     chain_id: torch.Tensor | None = None,
+     base_mask: torch.Tensor | None = None,
+ ) -> torch.Tensor:
+     """
+     Builds a mask that includes entire tokens from neighboring positions within a
+     tokenized sequence, never partially including a token. Limits the range to k_max,
+     applied as a window of k_max // 2 atom positions on either side.
+
+     Parameters:
+         tok_idx: (L,) tensor of token indices.
+         n_sequence_neighbours: number of tokens to include on either side.
+         k_max: max neighbourhood width in atom positions (across both directions).
+         chain_id: (L,) chain identifiers for each position (optional).
+         base_mask: (L, L) optional pre-mask to AND with.
+     """
+     device = tok_idx.device
+     L = tok_idx.shape[0]
+     k_max = min(k_max, L)
+     I = int(tok_idx.max()) + 1  # number of tokens (assumes contiguous 0-based token indices)
+     n_atoms_per_token = torch.zeros(I, device=device).float()
+     n_atoms_per_token.scatter_add_(0, tok_idx.long(), torch.ones_like(tok_idx).float())
+
+     # Create index masks for tokens and atoms
+     token_indices = torch.arange(I, device=device)
+     token_diff = (token_indices[:, None] - token_indices[None, :]).abs()
+     atom_indices = torch.arange(L, device=device)
+     atom_diff = (atom_indices[:, None] - atom_indices[None, :]).abs()
+
+     # Build token-token mask: [I, I]
+     token_mask = token_diff <= n_sequence_neighbours
+
+     # Expand token_mask to full [L, L] mask using broadcast
+     # tok_idx maps each atom position to a token index [L]
+     token_i = tok_idx[:, None]  # (L, 1)
+     token_j = tok_idx[None, :]  # (1, L)
+     mask = token_mask[token_i, token_j]  # (L, L)
+     mask = mask & (atom_diff <= (k_max // 2))
+
+     # Exclude tokens which are partially filled (L, I)
+     n_query_per_token = torch.zeros((L, I), device=device).float()
+     n_query_per_token.scatter_add_(
+         1, tok_idx.long()[None, :].expand(L, -1), mask.float()
+     )
+
+     # Find mask for the atoms for which the number of keys
+     # matches the number of atoms in the token (L, I)
+     fully_included = n_query_per_token == n_atoms_per_token[None, :]
+
+     # Contract to (I, I) and count the number of atoms within tokens that
+     # fully include other tokens
+     n_atoms_fully_included = torch.zeros((I, I), device=device)
+     n_atoms_fully_included.index_add_(0, tok_idx.long(), fully_included.float())
+     full_token_mask = n_atoms_fully_included == n_atoms_per_token[:, None]
+
+     # Map this back to (L, L): include token j in row i only if all its atoms are included
+     full_token_mask = full_token_mask[token_i, token_j]  # (L, L)
+     mask &= full_token_mask
+
+     if chain_id is not None:
+         same_chain = chain_id.unsqueeze(-1) == chain_id.unsqueeze(-2)
+         mask = mask & same_chain
+
+     if base_mask is not None:
+         mask = mask & base_mask
+
+     return mask
+
+
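A small check of the "whole tokens only" property, with a made-up token map:

    tok_idx = torch.tensor([0, 0, 1, 1, 1, 2, 2])
    m = build_index_mask(tok_idx, n_sequence_neighbours=1, k_max=7)   # (7, 7) bool
    # for every query atom, each key token is either fully included or fully excluded
    per_tok = torch.zeros(7, 3).scatter_add_(1, tok_idx.expand(7, -1), m.float())
    assert torch.all((per_tok == 0) | (per_tok == torch.bincount(tok_idx)))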
+ def extend_index_mask_with_neighbours(
+     mask: torch.Tensor, D_LL: torch.Tensor, k: int
+ ) -> torch.LongTensor:
+     """
+     Parameters
+     ----------
+     mask : (L, L) bool  # pre-selected neighbours (True = keep)
+     D_LL : (B, L, L) float32/float64  # pairwise distances (lower = closer)
+     k : int  # desired neighbours per query position
+
+     Returns
+     -------
+     neigh_idx : (B, L, k) long  # exactly k indices per row
+
+     NB: Indices of the mask are placed first along the k dimension, e.g.
+         indices[i, :] = [1, 2, 3, nan, nan] (from pre-built mask)
+         -> indices[i, :] = [1, 2, 3, 0, 5]  # where 0, 5 are additional k-NN (here k=5)
+     NB: If k = 14 * (2*n_seq_neigh + 1) (from above), tokens in the middle will have
+         no D_LL-local neighbours, while tokens at the edges will have an increasingly
+         large number of them.
+     """
+     if D_LL.ndim == 2:
+         D_LL = D_LL.unsqueeze(0)
+     B, L, _ = D_LL.shape
+     k = min(k, L)
+     assert mask.shape == (L, L) and D_LL.shape == (B, L, L)
+     device = D_LL.device
+     inf = torch.tensor(float("inf"), dtype=D_LL.dtype, device=device)
+
+     # 1. Selection of sequence neighbours
+     all_idx_row = torch.arange(L, device=device).expand(L, L)
+     indices = torch.where(mask, all_idx_row, inf)  # sentinel inf if not forced
+     indices = indices.sort(dim=1)[0][:, :k]  # (L, k)
+
+     # 2. Find k-NN excluding forced indices
+     D_LL = torch.where(mask, inf, D_LL)
+     filler_idx = torch.topk(D_LL, k, dim=-1, largest=False).indices
+
+     # ... Reverse last axis s.t. best-matched indices are last
+     filler_idx = filler_idx.flip(dims=[-1])
+
+     # 3. Fill indices
+     to_fill = indices == inf
+     to_fill = to_fill.expand_as(filler_idx)
+     indices = indices.expand_as(filler_idx)
+     indices = torch.where(to_fill, filler_idx, indices)
+
+     return indices.long()  # (B, L, k)
+
+
+ def get_sparse_attention_indices(
+     res_idx, D_LL, n_seq_neighbours, k_max, chain_id=None, base_mask=None
+ ):
+     mask = build_index_mask(
+         res_idx, n_seq_neighbours, k_max, chain_id=chain_id, base_mask=base_mask
+     )
+     indices = extend_index_mask_with_neighbours(mask, D_LL, k_max)
+
+     # Sort and assert no duplicates (optional but good practice)
+     indices, _ = torch.sort(indices, dim=-1)
+     if (indices[..., 1:] == indices[..., :-1]).any():
+         raise AssertionError("Tensor has duplicate elements along the last dimension.")
+
+     assert (
+         indices.shape[-1] == k_max
+     ), f"Expected k_max={k_max} indices, got {indices.shape[-1]} instead."
+
+     # Detach to avoid gradients flowing through indices
+     return indices.detach()
+
+
+ @torch.no_grad()
+ def indices_to_mask(neigh_idx):
+     """
+     Helper function for converting indices to masks for visualization.
+
+     Args:
+         neigh_idx: [L, k] or [B, L, k] tensor of indices for attention.
+     """
+     neigh_idx = neigh_idx.to(dtype=torch.long)
+
+     if neigh_idx.ndim == 2:
+         L = neigh_idx.shape[0]
+         mask_out = torch.zeros((L, L), dtype=torch.bool, device=neigh_idx.device)
+         mask_out.scatter_(1, neigh_idx, torch.ones_like(neigh_idx, dtype=torch.bool))
+
+     elif neigh_idx.ndim == 3:
+         B, L, k = neigh_idx.shape
+         mask_out = torch.zeros((B, L, L), dtype=torch.bool, device=neigh_idx.device)
+         mask_out.scatter_(2, neigh_idx, torch.ones_like(neigh_idx, dtype=torch.bool))
+
+     else:
+         raise ValueError(f"Expected ndim 2 or 3, got {neigh_idx.ndim}")
+
+     return mask_out
+
+
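For instance (indices chosen arbitrarily):

    attn_idx = torch.tensor([[0, 2], [1, 2], [2, 0]])   # (L=3, k=2)
    dense = indices_to_mask(attn_idx)                    # (3, 3) boolean attention mask
    assert dense.sum() == 6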
+ def create_valid_mask_LA(valid_mask):
+     """
+     Helper function for X_IAA (token-grouped atom-pair representations).
+     valid_mask: [I, A] represents which atoms in the token grouping are real;
+     sum(valid_mask) = L, where L is the total number of atoms.
+
+     Returns
+     -------
+     valid_mask_LA: [L, A] L atoms by A atoms in token grouping.
+     indices: [L, A] absolute atom indices of atoms in token grouping.
+
+     E.g. allows you to have [14, 14] matrices for every token in your protein,
+     where atomized tokens (or similar) will have invalid indices outside of [0, 0].
+     """
+     I, A = valid_mask.shape
+     L = int(valid_mask.sum())
+     pos = torch.arange(A, device=valid_mask.device)
+     rel_pos = pos.unsqueeze(-2) - pos.unsqueeze(-1)  # [A, A]
+     rel_pos = rel_pos.unsqueeze(0).expand(I, -1, -1)  # [I, A, A]
+     rel_pos_LA = rel_pos[valid_mask[..., None].expand_as(rel_pos)].view(
+         L, A
+     )  # [I, A, A] -> [L, A]
+
+     indices = torch.arange(L, device=valid_mask.device).unsqueeze(-1).expand(L, A)
+     indices = indices + rel_pos_LA
+
+     valid_mask_IAA = valid_mask.unsqueeze(-2).expand(-1, A, -1)
+     valid_mask_LA = valid_mask_IAA[
+         valid_mask.unsqueeze(-1).expand_as(valid_mask_IAA)
+     ].view(L, A)
+
+     indices[~valid_mask_LA] = -1
+
+     return valid_mask_LA, indices
+
+
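A tiny illustration with a made-up two-token mask (token 0 has 2 atoms, token 1 has 3):

    vm = torch.tensor([[True, True, False], [True, True, True]])   # (I=2, A=3)
    valid_mask_LA, idx = create_valid_mask_LA(vm)                   # both (L=5, A=3)
    # each row lists the absolute atom indices within that atom's own token, -1 where padded:
    # idx == [[0, 1, -1], [0, 1, -1], [2, 3, 4], [2, 3, 4], [2, 3, 4]]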
+ def pairwise_mean_pool(
+     pairwise_atom_features: Float[torch.Tensor, "batch n_atoms n_atoms d_hidden"],
+     atom_to_token_map: Int[torch.Tensor, "n_atoms"],
+     I: int,
+     dtype: torch.dtype,
+ ) -> Float[torch.Tensor, "batch n_tokens n_tokens d_hidden"]:
+     """Mean pooling of pairwise atom features to pairwise token features.
+
+     Args:
+         pairwise_atom_features: Pairwise features between atoms
+         atom_to_token_map: Mapping from atoms to tokens
+         I: Number of tokens
+         dtype: Data type for computations
+
+     Returns:
+         Token pairwise features pooled by averaging over atom pairs within tokens
+     """
+     B, _, _, _ = pairwise_atom_features.shape
+
+     # Create one-hot encoding for atom-to-token mapping
+     atom_to_token_onehot = F.one_hot(atom_to_token_map.long(), num_classes=I).to(
+         dtype
+     )  # (L, I)
+
+     # Use einsum to aggregate features across atom pairs for each token pair
+     # For each token pair (i, j), sum over all atom pairs (l1, l2) where l1→i and l2→j
+     # Result[b,i,j,d] = sum_l1,l2 ( onehot[l1,i] * onehot[l2,j] * features[b,l1,l2,d] )
+     use_memory_efficient_einsum = True
+     if use_memory_efficient_einsum:
+         # Memory-optimized implementation using two-step einsum:
+         # First step: contract on axis 1 (left-side tokens)
+         # (L, I)^T = (I, L), (B, L, L, d) → (B, I, L, d)
+         temp = torch.einsum(
+             "ia,bacd->bicd", atom_to_token_onehot.T, pairwise_atom_features
+         )
+
+         # Free the original to save memory if not needed
+         del pairwise_atom_features
+
+         # Second step: contract on axis 2 (right-side tokens)
+         # (L, I) = (L, I), (B, I, L, d) → (B, I, I, d)
+         token_features_sum = torch.einsum("cj,bicd->bijd", atom_to_token_onehot, temp)
+
+         # Optionally free temp
+         del temp
+     else:
+         token_features_sum = torch.einsum(
+             "ai,cj,bacd->bijd",
+             atom_to_token_onehot,  # (L, I)
+             atom_to_token_onehot,  # (L, I)
+             pairwise_atom_features,  # (B, L, L, d_hidden)
+         )  # (B, I, I, d_hidden)
+
+     # Count the number of atom pairs contributing to each token pair
+     # count[i, j] = number of atom pairs (l1, l2) where l1→i and l2→j (same for all batches)
+     atom_counts_per_token = atom_to_token_onehot.sum(dim=0)  # (I,)
+     token_pair_counts = torch.outer(
+         atom_counts_per_token, atom_counts_per_token
+     )  # (I, I) outer product
+
+     # Expand to match batch dimension: (I, I) -> (B, I, I)
+     token_pair_counts = token_pair_counts.unsqueeze(0).expand(B, -1, -1)
+
+     # Avoid division by zero and compute mean
+     token_pair_counts = torch.clamp(token_pair_counts, min=1)
+     token_pairwise_features = token_features_sum / token_pair_counts.unsqueeze(-1)
+
+     return token_pairwise_features
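A shape/correctness sketch for the pooling above (toy sizes; with constant input features the mean pooling should return the same constant):

    atom_to_token_map = torch.tensor([0, 0, 1, 1, 1])
    feats = torch.ones(2, 5, 5, 8)                    # (B, L, L, d_hidden)
    pooled = pairwise_mean_pool(feats, atom_to_token_map, I=2, dtype=torch.float32)
    assert pooled.shape == (2, 2, 2, 8) and torch.allclose(pooled, torch.ones_like(pooled))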