rc_foundry-0.1.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (180)
  1. foundry/__init__.py +57 -0
  2. foundry/callbacks/__init__.py +5 -0
  3. foundry/callbacks/callback.py +116 -0
  4. foundry/callbacks/health_logging.py +419 -0
  5. foundry/callbacks/metrics_logging.py +211 -0
  6. foundry/callbacks/timing_logging.py +67 -0
  7. foundry/callbacks/train_logging.py +278 -0
  8. foundry/common.py +108 -0
  9. foundry/constants.py +28 -0
  10. foundry/hydra/resolvers.py +77 -0
  11. foundry/inference_engines/base.py +235 -0
  12. foundry/inference_engines/checkpoint_registry.py +66 -0
  13. foundry/metrics/__init__.py +12 -0
  14. foundry/metrics/losses.py +30 -0
  15. foundry/metrics/metric.py +319 -0
  16. foundry/model/layers/blocks.py +47 -0
  17. foundry/testing/__init__.py +6 -0
  18. foundry/testing/fixtures.py +19 -0
  19. foundry/testing/pytest_hooks.py +15 -0
  20. foundry/trainers/fabric.py +923 -0
  21. foundry/training/EMA.py +67 -0
  22. foundry/training/checkpoint.py +61 -0
  23. foundry/training/schedulers.py +91 -0
  24. foundry/utils/alignment.py +86 -0
  25. foundry/utils/components.py +415 -0
  26. foundry/utils/datasets.py +405 -0
  27. foundry/utils/ddp.py +103 -0
  28. foundry/utils/instantiators.py +72 -0
  29. foundry/utils/logging.py +279 -0
  30. foundry/utils/rigid.py +1460 -0
  31. foundry/utils/rotation_augmentation.py +65 -0
  32. foundry/utils/squashfs.py +172 -0
  33. foundry/utils/torch.py +317 -0
  34. foundry/utils/weights.py +271 -0
  35. foundry/version.py +34 -0
  36. foundry_cli/__init__.py +3 -0
  37. foundry_cli/download_checkpoints.py +281 -0
  38. mpnn/__init__.py +1 -0
  39. mpnn/collate/feature_collator.py +265 -0
  40. mpnn/inference.py +53 -0
  41. mpnn/inference_engines/mpnn.py +549 -0
  42. mpnn/loss/nll_loss.py +122 -0
  43. mpnn/metrics/nll.py +369 -0
  44. mpnn/metrics/sequence_recovery.py +440 -0
  45. mpnn/model/layers/graph_embeddings.py +2372 -0
  46. mpnn/model/layers/message_passing.py +332 -0
  47. mpnn/model/layers/position_wise_feed_forward.py +44 -0
  48. mpnn/model/layers/positional_encoding.py +98 -0
  49. mpnn/model/mpnn.py +2632 -0
  50. mpnn/pipelines/mpnn.py +162 -0
  51. mpnn/samplers/samplers.py +167 -0
  52. mpnn/train.py +341 -0
  53. mpnn/trainers/mpnn.py +193 -0
  54. mpnn/transforms/feature_aggregation/mpnn.py +184 -0
  55. mpnn/transforms/feature_aggregation/polymer_ligand_interface.py +76 -0
  56. mpnn/transforms/feature_aggregation/token_encodings.py +132 -0
  57. mpnn/transforms/feature_aggregation/user_settings.py +347 -0
  58. mpnn/transforms/polymer_ligand_interface.py +164 -0
  59. mpnn/utils/inference.py +2397 -0
  60. mpnn/utils/probability.py +37 -0
  61. mpnn/utils/weights.py +309 -0
  62. rc_foundry-0.1.1.dist-info/METADATA +239 -0
  63. rc_foundry-0.1.1.dist-info/RECORD +180 -0
  64. rc_foundry-0.1.1.dist-info/WHEEL +4 -0
  65. rc_foundry-0.1.1.dist-info/entry_points.txt +5 -0
  66. rc_foundry-0.1.1.dist-info/licenses/LICENSE.md +28 -0
  67. rf3/__init__.py +3 -0
  68. rf3/_version.py +33 -0
  69. rf3/alignment.py +79 -0
  70. rf3/callbacks/dump_validation_structures.py +101 -0
  71. rf3/callbacks/metrics_logging.py +324 -0
  72. rf3/chemical.py +1529 -0
  73. rf3/cli.py +77 -0
  74. rf3/data/cyclic_transform.py +78 -0
  75. rf3/data/extra_xforms.py +36 -0
  76. rf3/data/ground_truth_template.py +463 -0
  77. rf3/data/paired_msa.py +206 -0
  78. rf3/data/pipeline_utils.py +128 -0
  79. rf3/data/pipelines.py +558 -0
  80. rf3/diffusion_samplers/inference_sampler.py +222 -0
  81. rf3/inference.py +65 -0
  82. rf3/inference_engines/__init__.py +5 -0
  83. rf3/inference_engines/rf3.py +735 -0
  84. rf3/kinematics.py +354 -0
  85. rf3/loss/af3_confidence_loss.py +515 -0
  86. rf3/loss/af3_losses.py +655 -0
  87. rf3/loss/loss.py +179 -0
  88. rf3/metrics/chiral.py +179 -0
  89. rf3/metrics/clashing_chains.py +68 -0
  90. rf3/metrics/distogram.py +421 -0
  91. rf3/metrics/lddt.py +523 -0
  92. rf3/metrics/metadata.py +43 -0
  93. rf3/metrics/metric_utils.py +192 -0
  94. rf3/metrics/predicted_error.py +134 -0
  95. rf3/metrics/rasa.py +108 -0
  96. rf3/metrics/selected_distances.py +91 -0
  97. rf3/model/RF3.py +527 -0
  98. rf3/model/RF3_blocks.py +92 -0
  99. rf3/model/RF3_structure.py +303 -0
  100. rf3/model/layers/af3_auxiliary_heads.py +255 -0
  101. rf3/model/layers/af3_diffusion_transformer.py +544 -0
  102. rf3/model/layers/attention.py +313 -0
  103. rf3/model/layers/layer_utils.py +127 -0
  104. rf3/model/layers/mlff.py +118 -0
  105. rf3/model/layers/outer_product.py +59 -0
  106. rf3/model/layers/pairformer_layers.py +783 -0
  107. rf3/model/layers/structure_bias.py +56 -0
  108. rf3/scoring.py +1787 -0
  109. rf3/symmetry/resolve.py +284 -0
  110. rf3/train.py +194 -0
  111. rf3/trainers/rf3.py +570 -0
  112. rf3/util_module.py +47 -0
  113. rf3/utils/frames.py +109 -0
  114. rf3/utils/inference.py +665 -0
  115. rf3/utils/io.py +198 -0
  116. rf3/utils/loss.py +72 -0
  117. rf3/utils/predict_and_score.py +165 -0
  118. rf3/utils/predicted_error.py +673 -0
  119. rf3/utils/recycling.py +42 -0
  120. rf3/validate.py +140 -0
  121. rfd3/.gitignore +7 -0
  122. rfd3/Makefile +76 -0
  123. rfd3/__init__.py +12 -0
  124. rfd3/callbacks.py +66 -0
  125. rfd3/cli.py +41 -0
  126. rfd3/constants.py +212 -0
  127. rfd3/engine.py +543 -0
  128. rfd3/inference/datasets.py +193 -0
  129. rfd3/inference/input_parsing.py +1123 -0
  130. rfd3/inference/legacy_input_parsing.py +717 -0
  131. rfd3/inference/parsing.py +165 -0
  132. rfd3/inference/symmetry/atom_array.py +298 -0
  133. rfd3/inference/symmetry/checks.py +241 -0
  134. rfd3/inference/symmetry/contigs.py +63 -0
  135. rfd3/inference/symmetry/frames.py +355 -0
  136. rfd3/inference/symmetry/symmetry_utils.py +398 -0
  137. rfd3/metrics/design_metrics.py +465 -0
  138. rfd3/metrics/hbonds_hbplus_metrics.py +308 -0
  139. rfd3/metrics/hbonds_metrics.py +389 -0
  140. rfd3/metrics/losses.py +325 -0
  141. rfd3/metrics/metrics_utils.py +118 -0
  142. rfd3/metrics/sidechain_metrics.py +349 -0
  143. rfd3/model/RFD3.py +105 -0
  144. rfd3/model/RFD3_diffusion_module.py +387 -0
  145. rfd3/model/cfg_utils.py +81 -0
  146. rfd3/model/inference_sampler.py +635 -0
  147. rfd3/model/layers/attention.py +577 -0
  148. rfd3/model/layers/block_utils.py +580 -0
  149. rfd3/model/layers/blocks.py +777 -0
  150. rfd3/model/layers/chunked_pairwise.py +377 -0
  151. rfd3/model/layers/encoders.py +417 -0
  152. rfd3/model/layers/layer_utils.py +197 -0
  153. rfd3/model/layers/pairformer_layers.py +128 -0
  154. rfd3/run_inference.py +45 -0
  155. rfd3/testing/debug.py +139 -0
  156. rfd3/testing/debug_utils.py +73 -0
  157. rfd3/testing/testing_utils.py +356 -0
  158. rfd3/train.py +194 -0
  159. rfd3/trainer/dump_validation_structures.py +154 -0
  160. rfd3/trainer/fabric_trainer.py +923 -0
  161. rfd3/trainer/recycling.py +42 -0
  162. rfd3/trainer/rfd3.py +485 -0
  163. rfd3/trainer/trainer_utils.py +502 -0
  164. rfd3/transforms/conditioning_base.py +508 -0
  165. rfd3/transforms/conditioning_utils.py +200 -0
  166. rfd3/transforms/design_transforms.py +807 -0
  167. rfd3/transforms/dna_crop.py +523 -0
  168. rfd3/transforms/hbonds.py +407 -0
  169. rfd3/transforms/hbonds_hbplus.py +246 -0
  170. rfd3/transforms/ncaa_transforms.py +153 -0
  171. rfd3/transforms/pipelines.py +632 -0
  172. rfd3/transforms/ppi_transforms.py +541 -0
  173. rfd3/transforms/rasa.py +116 -0
  174. rfd3/transforms/symmetry.py +76 -0
  175. rfd3/transforms/training_conditions.py +552 -0
  176. rfd3/transforms/util_transforms.py +498 -0
  177. rfd3/transforms/virtual_atoms.py +305 -0
  178. rfd3/utils/inference.py +648 -0
  179. rfd3/utils/io.py +245 -0
  180. rfd3/utils/vizualize.py +276 -0
rfd3/model/layers/chunked_pairwise.py
@@ -0,0 +1,377 @@
+"""
+Chunked pairwise embedding implementation for memory-efficient large structure processing.
+
+This module provides memory-optimized versions of pairwise embedders that compute
+only the pairs needed for sparse attention, reducing memory usage from O(L²) to O(L×k).
+"""
+
+import math
+from typing import Optional
+
+import torch
+import torch.nn as nn
+from rfd3.model.layers.layer_utils import RMSNorm, linearNoBias
+
+
+class ChunkedPositionPairDistEmbedder(nn.Module):
+    """
+    Memory-efficient version of PositionPairDistEmbedder that computes pairs on-demand.
+    """
+
+    def __init__(self, c_atompair, embed_frame=True):
+        super().__init__()
+        self.c_atompair = c_atompair
+        self.embed_frame = embed_frame
+        if embed_frame:
+            self.process_d = linearNoBias(3, c_atompair)
+
+        self.process_inverse_dist = linearNoBias(1, c_atompair)
+        self.process_valid_mask = linearNoBias(1, c_atompair)
+
+    def compute_pairs_chunked(
+        self,
+        query_pos: torch.Tensor,  # [D, 3]
+        key_pos: torch.Tensor,  # [D, k, 3]
+        valid_mask: torch.Tensor,  # [D, k, 1]
+    ) -> torch.Tensor:
+        """
+        Compute pairwise embeddings for specific query-key pairs.
+
+        Args:
+            query_pos: Query positions [D, 3]
+            key_pos: Key positions [D, k, 3]
+            valid_mask: Valid pair mask [D, k, 1]
+
+        Returns:
+            P_sparse: Pairwise embeddings [D, k, c_atompair]
+        """
+        D, k = key_pos.shape[:2]
+
+        # Compute pairwise distances: [D, k, 3]
+        D_pairs = query_pos.unsqueeze(1) - key_pos  # [D, 1, 3] - [D, k, 3] = [D, k, 3]
+
+        if self.embed_frame:
+            # Embed pairwise distances
+            P_pairs = self.process_d(D_pairs) * valid_mask  # [D, k, c_atompair]
+
+            # Add inverse distance embedding
+            norm_sq = torch.linalg.norm(D_pairs, dim=-1, keepdim=True) ** 2  # [D, k, 1]
+            inv_dist = 1 / (1 + norm_sq)
+            P_pairs = P_pairs + self.process_inverse_dist(inv_dist) * valid_mask
+
+            # Add valid mask embedding
+            P_pairs = (
+                P_pairs
+                + self.process_valid_mask(valid_mask.to(P_pairs.dtype)) * valid_mask
+            )
+        else:
+            # Simplified version without frame embedding
+            norm_sq = torch.linalg.norm(D_pairs, dim=-1, keepdim=True) ** 2
+            norm_sq = torch.clamp(norm_sq, min=1e-6)
+            inv_dist = 1 / (1 + norm_sq)
+            P_pairs = self.process_inverse_dist(inv_dist) * valid_mask
+            P_pairs = (
+                P_pairs
+                + self.process_valid_mask(valid_mask.to(P_pairs.dtype)) * valid_mask
+            )
+
+        return P_pairs
+
+
+class ChunkedSinusoidalDistEmbed(nn.Module):
+    """
+    Memory-efficient version of SinusoidalDistEmbed.
+    """
+
+    def __init__(self, c_atompair, n_freqs=32):
+        super().__init__()
+        assert c_atompair % 2 == 0, "Output embedding dim must be even"
+
+        self.n_freqs = n_freqs
+        self.c_atompair = c_atompair
+
+        self.output_proj = linearNoBias(2 * n_freqs, c_atompair)
+        self.process_valid_mask = linearNoBias(1, c_atompair)
+
+    def compute_pairs_chunked(
+        self,
+        query_pos: torch.Tensor,  # [D, 3]
+        key_pos: torch.Tensor,  # [D, k, 3]
+        valid_mask: torch.Tensor,  # [D, k, 1]
+    ) -> torch.Tensor:
+        """
+        Compute sinusoidal distance embeddings for specific query-key pairs.
+        """
+        D, k = key_pos.shape[:2]
+        device = query_pos.device
+
+        # Compute pairwise distances
+        D_pairs = query_pos.unsqueeze(1) - key_pos  # [D, k, 3]
+        dist_matrix = torch.linalg.norm(D_pairs, dim=-1)  # [D, k]
+
+        # Sinusoidal embedding
+        half_dim = self.n_freqs
+        freq = torch.exp(
+            -math.log(10000.0)
+            * torch.arange(0, half_dim, dtype=torch.float32, device=device)
+            / half_dim
+        )  # [n_freqs]
+
+        angles = dist_matrix.unsqueeze(-1) * freq  # [D, k, n_freqs]
+        sin_embed = torch.sin(angles)
+        cos_embed = torch.cos(angles)
+        sincos_embed = torch.cat([sin_embed, cos_embed], dim=-1)  # [D, k, 2*n_freqs]
+
+        # Linear projection
+        P_pairs = self.output_proj(sincos_embed)  # [D, k, c_atompair]
+        P_pairs = P_pairs * valid_mask
+
+        # Add linear embedding of valid mask
+        P_pairs = (
+            P_pairs + self.process_valid_mask(valid_mask.to(P_pairs.dtype)) * valid_mask
+        )
+
+        return P_pairs
+
+
+class ChunkedPairwiseEmbedder(nn.Module):
+    """
+    Main chunked pairwise embedder that combines all embedding types.
+    This replaces the full P_LL computation with sparse computation.
+    """
+
+    def __init__(
+        self,
+        c_atompair: int,
+        motif_pos_embedder: Optional[ChunkedPositionPairDistEmbedder] = None,
+        ref_pos_embedder: Optional[ChunkedPositionPairDistEmbedder] = None,
+        process_single_l: Optional[nn.Module] = None,
+        process_single_m: Optional[nn.Module] = None,
+        process_z: Optional[nn.Module] = None,
+        pair_mlp: Optional[nn.Module] = None,
+        **kwargs,
+    ):
+        super().__init__()
+        self.c_atompair = c_atompair
+        self.motif_pos_embedder = motif_pos_embedder
+        self.ref_pos_embedder = ref_pos_embedder
+
+        # Use shared trained MLPs if provided, otherwise create new ones
+        if process_single_l is not None:
+            self.process_single_l = process_single_l
+        else:
+            self.process_single_l = nn.Sequential(
+                nn.ReLU(), linearNoBias(128, c_atompair)
+            )
+
+        if process_single_m is not None:
+            self.process_single_m = process_single_m
+        else:
+            self.process_single_m = nn.Sequential(
+                nn.ReLU(), linearNoBias(128, c_atompair)
+            )
+
+        if process_z is not None:
+            self.process_z = process_z
+        else:
+            self.process_z = nn.Sequential(RMSNorm(128), linearNoBias(128, c_atompair))
+
+        if pair_mlp is not None:
+            self.pair_mlp = pair_mlp
+        else:
+            self.pair_mlp = nn.Sequential(
+                nn.ReLU(),
+                linearNoBias(c_atompair, c_atompair),
+                nn.ReLU(),
+                linearNoBias(c_atompair, c_atompair),
+                nn.ReLU(),
+                linearNoBias(c_atompair, c_atompair),
+            )
+
+    def forward_chunked(
+        self,
+        f: dict,
+        indices: torch.Tensor,  # [D, L, k] - sparse attention indices
+        C_L: torch.Tensor,  # [D, L, c_token] - atom features
+        Z_init_II: torch.Tensor,  # [I, I, c_z] - token pair features
+        tok_idx: torch.Tensor,  # [L] - atom to token mapping
+    ) -> torch.Tensor:
+        """
+        Compute P_LL only for the pairs specified by attention indices.
+
+        Args:
+            f: Feature dictionary
+            indices: Sparse attention indices [D, L, k]
+            C_L: Atom-level features [D, L, c_token]
+            Z_init_II: Token-level pair features [I, I, c_z]
+            tok_idx: Atom to token mapping [L]
+
+        Returns:
+            P_LL_sparse: Sparse pairwise features [D, L, k, c_atompair]
+        """
+        # Add logging for chunked P_LL computation
+        import logging
+
+        logger = logging.getLogger(__name__)
+        logger.info(
+            f"ChunkedPairwiseEmbedder: Computing sparse P_LL for {indices.shape[1]} atoms with {indices.shape[2]} neighbors each"
+        )
+        D, L, k = indices.shape
+        device = indices.device
+
+        # Initialize sparse P_LL
+        P_LL_sparse = torch.zeros(
+            D, L, k, self.c_atompair, device=device, dtype=C_L.dtype
+        )
+
+        # Handle both batched and non-batched C_L
+        if C_L.dim() == 2:  # [L, c_token] - add batch dimension
+            C_L = C_L.unsqueeze(0)  # [1, L, c_token]
+        # Add bounds checking to prevent index errors
+        L_max = C_L.shape[1]
+        valid_indices = torch.clamp(
+            indices, 0, L_max - 1
+        )  # Clamp indices to valid range
+
+        # Ensure indices have the right shape for gathering
+        if valid_indices.dim() == 2:  # [L, k] - add batch dimension
+            valid_indices = valid_indices.unsqueeze(0).expand(
+                C_L.shape[0], -1, -1
+            )  # [D, L, k]
+
+        # 1. Motif position embedding (if exists)
+        if self.motif_pos_embedder is not None and "motif_pos" in f:
+            motif_pos = f["motif_pos"]  # [L, 3]
+            is_motif = f["is_motif_atom_with_fixed_coord"]  # [L]
+
+            # For each query position
+            for l in range(L):
+                if is_motif[l]:  # Only compute if query is motif
+                    key_indices = valid_indices[:, l, :]  # [D, k] - use clamped indices
+                    key_pos = motif_pos[key_indices]  # [D, k, 3]
+                    query_pos = motif_pos[l].unsqueeze(0).expand(D, -1)  # [D, 3]
+
+                    # Valid mask: both query and keys must be motif
+                    key_is_motif = is_motif[key_indices]  # [D, k]
+                    valid_mask = key_is_motif.unsqueeze(-1).float()  # [D, k, 1]
+
+                    if valid_mask.sum() > 0:
+                        motif_pairs = self.motif_pos_embedder.compute_pairs_chunked(
+                            query_pos, key_pos, valid_mask
+                        )
+                        P_LL_sparse[:, l, :, :] += motif_pairs
+
+        # 2. Reference position embedding (if exists)
+        if self.ref_pos_embedder is not None and "ref_pos" in f:
+            ref_pos = f["ref_pos"]  # [L, 3]
+            ref_space_uid = f["ref_space_uid"]  # [L]
+            is_motif_seq = f["is_motif_atom_with_fixed_seq"]  # [L]
+
+            for l in range(L):
+                if is_motif_seq[l]:  # Only compute if query has sequence
+                    key_indices = valid_indices[:, l, :]  # [D, k] - use clamped indices
+                    key_pos = ref_pos[key_indices]  # [D, k, 3]
+                    query_pos = ref_pos[l].unsqueeze(0).expand(D, -1)  # [D, 3]
+
+                    # Valid mask: same token and both have sequence
+                    key_space_uid = ref_space_uid[key_indices]  # [D, k]
+                    key_is_motif_seq = is_motif_seq[key_indices]  # [D, k]
+
+                    same_token = key_space_uid == ref_space_uid[l]  # [D, k]
+                    valid_mask = (
+                        (same_token & key_is_motif_seq).unsqueeze(-1).float()
+                    )  # [D, k, 1]
+
+                    if valid_mask.sum() > 0:
+                        ref_pairs = self.ref_pos_embedder.compute_pairs_chunked(
+                            query_pos, key_pos, valid_mask
+                        )
+                        P_LL_sparse[:, l, :, :] += ref_pairs
+
+        # 3. Single embedding terms (broadcasted)
+        # Gather key features for each query
+        C_L_keys = torch.gather(
+            C_L.unsqueeze(2).expand(-1, -1, k, -1),
+            1,
+            valid_indices.unsqueeze(-1).expand(-1, -1, -1, C_L.shape[-1]),
+        )  # [D, L, k, c_token]
+        C_L_queries = C_L.unsqueeze(2).expand(-1, -1, k, -1)  # [D, L, k, c_token]
+
+        # Add single embeddings - match standard implementation structure
+        # Standard does: self.process_single_l(C_L).unsqueeze(-2) + self.process_single_m(C_L).unsqueeze(-3)
+        # We need to broadcast from [D, L, k, c_atompair] to match this
+        single_l = self.process_single_l(C_L_queries)  # [D, L, k, c_atompair]
+        single_m = self.process_single_m(C_L_keys)  # [D, L, k, c_atompair]
+        P_LL_sparse += single_l + single_m
+
+        # 4. Token pair features Z_init_II
+        # Map atoms to tokens and gather token pair features
+        # Handle tok_idx dimensions properly
+        if tok_idx.dim() == 1:  # [L] - add batch dimension for consistency
+            tok_idx_expanded = tok_idx.unsqueeze(0)  # [1, L]
+        else:
+            tok_idx_expanded = tok_idx
+
+        tok_queries = tok_idx_expanded.unsqueeze(2).expand(-1, -1, k)  # [D, L, k]
+        # Use valid_indices for token mapping as well
+        tok_keys = torch.gather(
+            tok_idx_expanded.unsqueeze(2).expand(-1, -1, k), 1, valid_indices
+        )  # [D, L, k]
+
+        # Gather Z_init_II[tok_queries, tok_keys] with safe indexing
+        # Z_init_II shape is [I, I, c_z] (3D), not 4D
+        # tok_queries shape: [D, L, k] - each value is a token index
+        # We want: Z_init_II[tok_queries[d,l,k], tok_keys[d,l,k], :] for all d,l,k
+
+        I_z, I_z2, c_z = Z_init_II.shape
+
+        # CRITICAL: Match standard implementation exactly!
+        # Standard does: self.process_z(Z_init_II)[..., tok_idx, :, :][..., tok_idx, :]
+        # This means: 1) Process Z_init_II first, 2) Then do double token indexing
+
+        # Step 1: Process Z_init_II to get processed token pair features
+        Z_processed = self.process_z(Z_init_II)  # [I, I, c_atompair]
+
+        # Step 2: Do the double indexing like the standard implementation
+        # Standard: Z_processed[..., tok_idx, :, :][..., tok_idx, :]
+        # This creates Z_processed[tok_idx, :][:, tok_idx] which is [L, L, c_atompair]
+        # Then we need to gather the sparse version
+
+        Z_pairs_processed = torch.zeros(
+            D, L, k, self.c_atompair, device=device, dtype=Z_processed.dtype
+        )
+
+        for d in range(D):
+            # For this batch, get the token queries and keys
+            tq = tok_queries[d]  # [L, k]
+            tk = tok_keys[d]  # [L, k]
+
+            # Ensure indices are within bounds
+            tq = torch.clamp(tq, 0, I_z - 1)
+            tk = torch.clamp(tk, 0, I_z2 - 1)
+
+            # Apply the double token indexing like standard implementation
+            Z_pairs_processed[d] = Z_processed[tq, tk]  # [L, k, c_atompair]
+
+        P_LL_sparse += Z_pairs_processed
+
+        # 5. Final MLP - ADD the result, don't replace (to match standard implementation)
+        P_LL_sparse = P_LL_sparse + self.pair_mlp(P_LL_sparse)
+
+        return P_LL_sparse.contiguous()
+
+
+def create_chunked_embedders(
+    c_atompair: int, embed_frame: bool = True
+) -> ChunkedPairwiseEmbedder:
+    """
+    Factory function to create chunked pairwise embedder with standard components.
+    """
+    motif_pos_embedder = ChunkedPositionPairDistEmbedder(c_atompair, embed_frame)
+    ref_pos_embedder = ChunkedPositionPairDistEmbedder(c_atompair, embed_frame)
+
+    return ChunkedPairwiseEmbedder(
+        c_atompair=c_atompair,
+        motif_pos_embedder=motif_pos_embedder,
+        ref_pos_embedder=ref_pos_embedder,
+    )
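
For orientation, a minimal smoke-test sketch for the chunked embedder added in rfd3/model/layers/chunked_pairwise.py above. This is not part of the package: it assumes rc-foundry is installed so the rfd3 imports resolve, uses the default projection widths from __init__ (so c_token and c_z are 128), passes an empty feature dict so the motif and reference-position branches are skipped, and keeps the batch size at 1 so the unbatched tok_idx path gathers cleanly. The tensor names and shapes follow the forward_chunked docstring; the illustrative sizes (64 atoms, 8 neighbors, 16 tokens, c_atompair=16) are arbitrary.

import torch
from rfd3.model.layers.chunked_pairwise import create_chunked_embedders

D, L, k = 1, 64, 8      # batch, atoms, neighbors per atom (illustrative)
n_tok = 16              # number of tokens I (illustrative)
embedder = create_chunked_embedders(c_atompair=16)

f = {}                                          # no motif/ref features: steps 1 and 2 are skipped
indices = torch.randint(0, L, (D, L, k))        # sparse attention indices [D, L, k]
C_L = torch.randn(D, L, 128)                    # atom features [D, L, c_token=128]
Z_init_II = torch.randn(n_tok, n_tok, 128)      # token pair features [I, I, c_z=128]
tok_idx = torch.randint(0, n_tok, (L,))         # atom-to-token mapping [L]

P_LL_sparse = embedder.forward_chunked(f, indices, C_L, Z_init_II, tok_idx)
assert P_LL_sparse.shape == (D, L, k, 16)       # sparse pairwise features [D, L, k, c_atompair]

The key point of the module is visible in that final shape: instead of a dense [L, L, c_atompair] pair tensor, only the k gathered neighbors per atom are embedded, which is where the stated O(L²) to O(L×k) memory reduction comes from.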