rc_foundry-0.1.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- foundry/__init__.py +57 -0
- foundry/callbacks/__init__.py +5 -0
- foundry/callbacks/callback.py +116 -0
- foundry/callbacks/health_logging.py +419 -0
- foundry/callbacks/metrics_logging.py +211 -0
- foundry/callbacks/timing_logging.py +67 -0
- foundry/callbacks/train_logging.py +278 -0
- foundry/common.py +108 -0
- foundry/constants.py +28 -0
- foundry/hydra/resolvers.py +77 -0
- foundry/inference_engines/base.py +235 -0
- foundry/inference_engines/checkpoint_registry.py +66 -0
- foundry/metrics/__init__.py +12 -0
- foundry/metrics/losses.py +30 -0
- foundry/metrics/metric.py +319 -0
- foundry/model/layers/blocks.py +47 -0
- foundry/testing/__init__.py +6 -0
- foundry/testing/fixtures.py +19 -0
- foundry/testing/pytest_hooks.py +15 -0
- foundry/trainers/fabric.py +923 -0
- foundry/training/EMA.py +67 -0
- foundry/training/checkpoint.py +61 -0
- foundry/training/schedulers.py +91 -0
- foundry/utils/alignment.py +86 -0
- foundry/utils/components.py +415 -0
- foundry/utils/datasets.py +405 -0
- foundry/utils/ddp.py +103 -0
- foundry/utils/instantiators.py +72 -0
- foundry/utils/logging.py +279 -0
- foundry/utils/rigid.py +1460 -0
- foundry/utils/rotation_augmentation.py +65 -0
- foundry/utils/squashfs.py +172 -0
- foundry/utils/torch.py +317 -0
- foundry/utils/weights.py +271 -0
- foundry/version.py +34 -0
- foundry_cli/__init__.py +3 -0
- foundry_cli/download_checkpoints.py +281 -0
- mpnn/__init__.py +1 -0
- mpnn/collate/feature_collator.py +265 -0
- mpnn/inference.py +53 -0
- mpnn/inference_engines/mpnn.py +549 -0
- mpnn/loss/nll_loss.py +122 -0
- mpnn/metrics/nll.py +369 -0
- mpnn/metrics/sequence_recovery.py +440 -0
- mpnn/model/layers/graph_embeddings.py +2372 -0
- mpnn/model/layers/message_passing.py +332 -0
- mpnn/model/layers/position_wise_feed_forward.py +44 -0
- mpnn/model/layers/positional_encoding.py +98 -0
- mpnn/model/mpnn.py +2632 -0
- mpnn/pipelines/mpnn.py +162 -0
- mpnn/samplers/samplers.py +167 -0
- mpnn/train.py +341 -0
- mpnn/trainers/mpnn.py +193 -0
- mpnn/transforms/feature_aggregation/mpnn.py +184 -0
- mpnn/transforms/feature_aggregation/polymer_ligand_interface.py +76 -0
- mpnn/transforms/feature_aggregation/token_encodings.py +132 -0
- mpnn/transforms/feature_aggregation/user_settings.py +347 -0
- mpnn/transforms/polymer_ligand_interface.py +164 -0
- mpnn/utils/inference.py +2397 -0
- mpnn/utils/probability.py +37 -0
- mpnn/utils/weights.py +309 -0
- rc_foundry-0.1.1.dist-info/METADATA +239 -0
- rc_foundry-0.1.1.dist-info/RECORD +180 -0
- rc_foundry-0.1.1.dist-info/WHEEL +4 -0
- rc_foundry-0.1.1.dist-info/entry_points.txt +5 -0
- rc_foundry-0.1.1.dist-info/licenses/LICENSE.md +28 -0
- rf3/__init__.py +3 -0
- rf3/_version.py +33 -0
- rf3/alignment.py +79 -0
- rf3/callbacks/dump_validation_structures.py +101 -0
- rf3/callbacks/metrics_logging.py +324 -0
- rf3/chemical.py +1529 -0
- rf3/cli.py +77 -0
- rf3/data/cyclic_transform.py +78 -0
- rf3/data/extra_xforms.py +36 -0
- rf3/data/ground_truth_template.py +463 -0
- rf3/data/paired_msa.py +206 -0
- rf3/data/pipeline_utils.py +128 -0
- rf3/data/pipelines.py +558 -0
- rf3/diffusion_samplers/inference_sampler.py +222 -0
- rf3/inference.py +65 -0
- rf3/inference_engines/__init__.py +5 -0
- rf3/inference_engines/rf3.py +735 -0
- rf3/kinematics.py +354 -0
- rf3/loss/af3_confidence_loss.py +515 -0
- rf3/loss/af3_losses.py +655 -0
- rf3/loss/loss.py +179 -0
- rf3/metrics/chiral.py +179 -0
- rf3/metrics/clashing_chains.py +68 -0
- rf3/metrics/distogram.py +421 -0
- rf3/metrics/lddt.py +523 -0
- rf3/metrics/metadata.py +43 -0
- rf3/metrics/metric_utils.py +192 -0
- rf3/metrics/predicted_error.py +134 -0
- rf3/metrics/rasa.py +108 -0
- rf3/metrics/selected_distances.py +91 -0
- rf3/model/RF3.py +527 -0
- rf3/model/RF3_blocks.py +92 -0
- rf3/model/RF3_structure.py +303 -0
- rf3/model/layers/af3_auxiliary_heads.py +255 -0
- rf3/model/layers/af3_diffusion_transformer.py +544 -0
- rf3/model/layers/attention.py +313 -0
- rf3/model/layers/layer_utils.py +127 -0
- rf3/model/layers/mlff.py +118 -0
- rf3/model/layers/outer_product.py +59 -0
- rf3/model/layers/pairformer_layers.py +783 -0
- rf3/model/layers/structure_bias.py +56 -0
- rf3/scoring.py +1787 -0
- rf3/symmetry/resolve.py +284 -0
- rf3/train.py +194 -0
- rf3/trainers/rf3.py +570 -0
- rf3/util_module.py +47 -0
- rf3/utils/frames.py +109 -0
- rf3/utils/inference.py +665 -0
- rf3/utils/io.py +198 -0
- rf3/utils/loss.py +72 -0
- rf3/utils/predict_and_score.py +165 -0
- rf3/utils/predicted_error.py +673 -0
- rf3/utils/recycling.py +42 -0
- rf3/validate.py +140 -0
- rfd3/.gitignore +7 -0
- rfd3/Makefile +76 -0
- rfd3/__init__.py +12 -0
- rfd3/callbacks.py +66 -0
- rfd3/cli.py +41 -0
- rfd3/constants.py +212 -0
- rfd3/engine.py +543 -0
- rfd3/inference/datasets.py +193 -0
- rfd3/inference/input_parsing.py +1123 -0
- rfd3/inference/legacy_input_parsing.py +717 -0
- rfd3/inference/parsing.py +165 -0
- rfd3/inference/symmetry/atom_array.py +298 -0
- rfd3/inference/symmetry/checks.py +241 -0
- rfd3/inference/symmetry/contigs.py +63 -0
- rfd3/inference/symmetry/frames.py +355 -0
- rfd3/inference/symmetry/symmetry_utils.py +398 -0
- rfd3/metrics/design_metrics.py +465 -0
- rfd3/metrics/hbonds_hbplus_metrics.py +308 -0
- rfd3/metrics/hbonds_metrics.py +389 -0
- rfd3/metrics/losses.py +325 -0
- rfd3/metrics/metrics_utils.py +118 -0
- rfd3/metrics/sidechain_metrics.py +349 -0
- rfd3/model/RFD3.py +105 -0
- rfd3/model/RFD3_diffusion_module.py +387 -0
- rfd3/model/cfg_utils.py +81 -0
- rfd3/model/inference_sampler.py +635 -0
- rfd3/model/layers/attention.py +577 -0
- rfd3/model/layers/block_utils.py +580 -0
- rfd3/model/layers/blocks.py +777 -0
- rfd3/model/layers/chunked_pairwise.py +377 -0
- rfd3/model/layers/encoders.py +417 -0
- rfd3/model/layers/layer_utils.py +197 -0
- rfd3/model/layers/pairformer_layers.py +128 -0
- rfd3/run_inference.py +45 -0
- rfd3/testing/debug.py +139 -0
- rfd3/testing/debug_utils.py +73 -0
- rfd3/testing/testing_utils.py +356 -0
- rfd3/train.py +194 -0
- rfd3/trainer/dump_validation_structures.py +154 -0
- rfd3/trainer/fabric_trainer.py +923 -0
- rfd3/trainer/recycling.py +42 -0
- rfd3/trainer/rfd3.py +485 -0
- rfd3/trainer/trainer_utils.py +502 -0
- rfd3/transforms/conditioning_base.py +508 -0
- rfd3/transforms/conditioning_utils.py +200 -0
- rfd3/transforms/design_transforms.py +807 -0
- rfd3/transforms/dna_crop.py +523 -0
- rfd3/transforms/hbonds.py +407 -0
- rfd3/transforms/hbonds_hbplus.py +246 -0
- rfd3/transforms/ncaa_transforms.py +153 -0
- rfd3/transforms/pipelines.py +632 -0
- rfd3/transforms/ppi_transforms.py +541 -0
- rfd3/transforms/rasa.py +116 -0
- rfd3/transforms/symmetry.py +76 -0
- rfd3/transforms/training_conditions.py +552 -0
- rfd3/transforms/util_transforms.py +498 -0
- rfd3/transforms/virtual_atoms.py +305 -0
- rfd3/utils/inference.py +648 -0
- rfd3/utils/io.py +245 -0
- rfd3/utils/vizualize.py +276 -0
rfd3/model/layers/chunked_pairwise.py

@@ -0,0 +1,377 @@
"""
Chunked pairwise embedding implementation for memory-efficient large structure processing.

This module provides memory-optimized versions of pairwise embedders that compute
only the pairs needed for sparse attention, reducing memory usage from O(L²) to O(L×k).
"""

import math
from typing import Optional

import torch
import torch.nn as nn
from rfd3.model.layers.layer_utils import RMSNorm, linearNoBias

class ChunkedPositionPairDistEmbedder(nn.Module):
    """
    Memory-efficient version of PositionPairDistEmbedder that computes pairs on-demand.
    """

    def __init__(self, c_atompair, embed_frame=True):
        super().__init__()
        self.c_atompair = c_atompair
        self.embed_frame = embed_frame
        if embed_frame:
            self.process_d = linearNoBias(3, c_atompair)

        self.process_inverse_dist = linearNoBias(1, c_atompair)
        self.process_valid_mask = linearNoBias(1, c_atompair)

    def compute_pairs_chunked(
        self,
        query_pos: torch.Tensor,  # [D, 3]
        key_pos: torch.Tensor,  # [D, k, 3]
        valid_mask: torch.Tensor,  # [D, k, 1]
    ) -> torch.Tensor:
        """
        Compute pairwise embeddings for specific query-key pairs.

        Args:
            query_pos: Query positions [D, 3]
            key_pos: Key positions [D, k, 3]
            valid_mask: Valid pair mask [D, k, 1]

        Returns:
            P_sparse: Pairwise embeddings [D, k, c_atompair]
        """
        D, k = key_pos.shape[:2]

        # Compute pairwise distances: [D, k, 3]
        D_pairs = query_pos.unsqueeze(1) - key_pos  # [D, 1, 3] - [D, k, 3] = [D, k, 3]

        if self.embed_frame:
            # Embed pairwise distances
            P_pairs = self.process_d(D_pairs) * valid_mask  # [D, k, c_atompair]

            # Add inverse distance embedding
            norm_sq = torch.linalg.norm(D_pairs, dim=-1, keepdim=True) ** 2  # [D, k, 1]
            inv_dist = 1 / (1 + norm_sq)
            P_pairs = P_pairs + self.process_inverse_dist(inv_dist) * valid_mask

            # Add valid mask embedding
            P_pairs = (
                P_pairs
                + self.process_valid_mask(valid_mask.to(P_pairs.dtype)) * valid_mask
            )
        else:
            # Simplified version without frame embedding
            norm_sq = torch.linalg.norm(D_pairs, dim=-1, keepdim=True) ** 2
            norm_sq = torch.clamp(norm_sq, min=1e-6)
            inv_dist = 1 / (1 + norm_sq)
            P_pairs = self.process_inverse_dist(inv_dist) * valid_mask
            P_pairs = (
                P_pairs
                + self.process_valid_mask(valid_mask.to(P_pairs.dtype)) * valid_mask
            )

        return P_pairs

class ChunkedSinusoidalDistEmbed(nn.Module):
    """
    Memory-efficient version of SinusoidalDistEmbed.
    """

    def __init__(self, c_atompair, n_freqs=32):
        super().__init__()
        assert c_atompair % 2 == 0, "Output embedding dim must be even"

        self.n_freqs = n_freqs
        self.c_atompair = c_atompair

        self.output_proj = linearNoBias(2 * n_freqs, c_atompair)
        self.process_valid_mask = linearNoBias(1, c_atompair)

    def compute_pairs_chunked(
        self,
        query_pos: torch.Tensor,  # [D, 3]
        key_pos: torch.Tensor,  # [D, k, 3]
        valid_mask: torch.Tensor,  # [D, k, 1]
    ) -> torch.Tensor:
        """
        Compute sinusoidal distance embeddings for specific query-key pairs.
        """
        D, k = key_pos.shape[:2]
        device = query_pos.device

        # Compute pairwise distances
        D_pairs = query_pos.unsqueeze(1) - key_pos  # [D, k, 3]
        dist_matrix = torch.linalg.norm(D_pairs, dim=-1)  # [D, k]

        # Sinusoidal embedding
        half_dim = self.n_freqs
        freq = torch.exp(
            -math.log(10000.0)
            * torch.arange(0, half_dim, dtype=torch.float32, device=device)
            / half_dim
        )  # [n_freqs]

        angles = dist_matrix.unsqueeze(-1) * freq  # [D, k, n_freqs]
        sin_embed = torch.sin(angles)
        cos_embed = torch.cos(angles)
        sincos_embed = torch.cat([sin_embed, cos_embed], dim=-1)  # [D, k, 2*n_freqs]

        # Linear projection
        P_pairs = self.output_proj(sincos_embed)  # [D, k, c_atompair]
        P_pairs = P_pairs * valid_mask

        # Add linear embedding of valid mask
        P_pairs = (
            P_pairs + self.process_valid_mask(valid_mask.to(P_pairs.dtype)) * valid_mask
        )

        return P_pairs

class ChunkedPairwiseEmbedder(nn.Module):
    """
    Main chunked pairwise embedder that combines all embedding types.
    This replaces the full P_LL computation with sparse computation.
    """

    def __init__(
        self,
        c_atompair: int,
        motif_pos_embedder: Optional[ChunkedPositionPairDistEmbedder] = None,
        ref_pos_embedder: Optional[ChunkedPositionPairDistEmbedder] = None,
        process_single_l: Optional[nn.Module] = None,
        process_single_m: Optional[nn.Module] = None,
        process_z: Optional[nn.Module] = None,
        pair_mlp: Optional[nn.Module] = None,
        **kwargs,
    ):
        super().__init__()
        self.c_atompair = c_atompair
        self.motif_pos_embedder = motif_pos_embedder
        self.ref_pos_embedder = ref_pos_embedder

        # Use shared trained MLPs if provided, otherwise create new ones
        if process_single_l is not None:
            self.process_single_l = process_single_l
        else:
            self.process_single_l = nn.Sequential(
                nn.ReLU(), linearNoBias(128, c_atompair)
            )

        if process_single_m is not None:
            self.process_single_m = process_single_m
        else:
            self.process_single_m = nn.Sequential(
                nn.ReLU(), linearNoBias(128, c_atompair)
            )

        if process_z is not None:
            self.process_z = process_z
        else:
            self.process_z = nn.Sequential(RMSNorm(128), linearNoBias(128, c_atompair))

        if pair_mlp is not None:
            self.pair_mlp = pair_mlp
        else:
            self.pair_mlp = nn.Sequential(
                nn.ReLU(),
                linearNoBias(c_atompair, c_atompair),
                nn.ReLU(),
                linearNoBias(c_atompair, c_atompair),
                nn.ReLU(),
                linearNoBias(c_atompair, c_atompair),
            )

    def forward_chunked(
        self,
        f: dict,
        indices: torch.Tensor,  # [D, L, k] - sparse attention indices
        C_L: torch.Tensor,  # [D, L, c_token] - atom features
        Z_init_II: torch.Tensor,  # [I, I, c_z] - token pair features
        tok_idx: torch.Tensor,  # [L] - atom to token mapping
    ) -> torch.Tensor:
        """
        Compute P_LL only for the pairs specified by attention indices.

        Args:
            f: Feature dictionary
            indices: Sparse attention indices [D, L, k]
            C_L: Atom-level features [D, L, c_token]
            Z_init_II: Token-level pair features [I, I, c_z]
            tok_idx: Atom to token mapping [L]

        Returns:
            P_LL_sparse: Sparse pairwise features [D, L, k, c_atompair]
        """
        # Log the size of the chunked P_LL computation
        import logging

        logger = logging.getLogger(__name__)
        logger.info(
            f"ChunkedPairwiseEmbedder: Computing sparse P_LL for {indices.shape[1]} atoms with {indices.shape[2]} neighbors each"
        )

        D, L, k = indices.shape
        device = indices.device

        # Initialize sparse P_LL
        P_LL_sparse = torch.zeros(
            D, L, k, self.c_atompair, device=device, dtype=C_L.dtype
        )

        # Handle both batched and non-batched C_L
        if C_L.dim() == 2:  # [L, c_token] - add batch dimension
            C_L = C_L.unsqueeze(0)  # [1, L, c_token]
        # Add bounds checking to prevent index errors
        L_max = C_L.shape[1]
        valid_indices = torch.clamp(
            indices, 0, L_max - 1
        )  # Clamp indices to valid range

        # Ensure indices have the right shape for gathering
        if valid_indices.dim() == 2:  # [L, k] - add batch dimension
            valid_indices = valid_indices.unsqueeze(0).expand(
                C_L.shape[0], -1, -1
            )  # [D, L, k]

        # 1. Motif position embedding (if exists)
        if self.motif_pos_embedder is not None and "motif_pos" in f:
            motif_pos = f["motif_pos"]  # [L, 3]
            is_motif = f["is_motif_atom_with_fixed_coord"]  # [L]

            # For each query position
            for l in range(L):
                if is_motif[l]:  # Only compute if query is motif
                    key_indices = valid_indices[:, l, :]  # [D, k] - use clamped indices
                    key_pos = motif_pos[key_indices]  # [D, k, 3]
                    query_pos = motif_pos[l].unsqueeze(0).expand(D, -1)  # [D, 3]

                    # Valid mask: both query and keys must be motif
                    key_is_motif = is_motif[key_indices]  # [D, k]
                    valid_mask = key_is_motif.unsqueeze(-1).float()  # [D, k, 1]

                    if valid_mask.sum() > 0:
                        motif_pairs = self.motif_pos_embedder.compute_pairs_chunked(
                            query_pos, key_pos, valid_mask
                        )
                        P_LL_sparse[:, l, :, :] += motif_pairs

        # 2. Reference position embedding (if exists)
        if self.ref_pos_embedder is not None and "ref_pos" in f:
            ref_pos = f["ref_pos"]  # [L, 3]
            ref_space_uid = f["ref_space_uid"]  # [L]
            is_motif_seq = f["is_motif_atom_with_fixed_seq"]  # [L]

            for l in range(L):
                if is_motif_seq[l]:  # Only compute if query has sequence
                    key_indices = valid_indices[:, l, :]  # [D, k] - use clamped indices
                    key_pos = ref_pos[key_indices]  # [D, k, 3]
                    query_pos = ref_pos[l].unsqueeze(0).expand(D, -1)  # [D, 3]

                    # Valid mask: same token and both have sequence
                    key_space_uid = ref_space_uid[key_indices]  # [D, k]
                    key_is_motif_seq = is_motif_seq[key_indices]  # [D, k]

                    same_token = key_space_uid == ref_space_uid[l]  # [D, k]
                    valid_mask = (
                        (same_token & key_is_motif_seq).unsqueeze(-1).float()
                    )  # [D, k, 1]

                    if valid_mask.sum() > 0:
                        ref_pairs = self.ref_pos_embedder.compute_pairs_chunked(
                            query_pos, key_pos, valid_mask
                        )
                        P_LL_sparse[:, l, :, :] += ref_pairs

        # 3. Single embedding terms (broadcasted)
        # Gather key features for each query
        C_L_keys = torch.gather(
            C_L.unsqueeze(2).expand(-1, -1, k, -1),
            1,
            valid_indices.unsqueeze(-1).expand(-1, -1, -1, C_L.shape[-1]),
        )  # [D, L, k, c_token]
        C_L_queries = C_L.unsqueeze(2).expand(-1, -1, k, -1)  # [D, L, k, c_token]

        # Add single embeddings - match standard implementation structure
        # Standard does: self.process_single_l(C_L).unsqueeze(-2) + self.process_single_m(C_L).unsqueeze(-3)
        # We need to broadcast from [D, L, k, c_atompair] to match this
        single_l = self.process_single_l(C_L_queries)  # [D, L, k, c_atompair]
        single_m = self.process_single_m(C_L_keys)  # [D, L, k, c_atompair]
        P_LL_sparse += single_l + single_m

        # 4. Token pair features Z_init_II
        # Map atoms to tokens and gather token pair features
        # Handle tok_idx dimensions properly
        if tok_idx.dim() == 1:  # [L] - add batch dimension for consistency
            tok_idx_expanded = tok_idx.unsqueeze(0)  # [1, L]
        else:
            tok_idx_expanded = tok_idx

        tok_queries = tok_idx_expanded.unsqueeze(2).expand(-1, -1, k)  # [D, L, k]
        # Use valid_indices for token mapping as well
        tok_keys = torch.gather(
            tok_idx_expanded.unsqueeze(2).expand(-1, -1, k), 1, valid_indices
        )  # [D, L, k]

        # Gather Z_init_II[tok_queries, tok_keys] with safe indexing
        # Z_init_II shape is [I, I, c_z] (3D), not 4D
        # tok_queries shape: [D, L, k] - each value is a token index
        # We want: Z_init_II[tok_queries[d,l,k], tok_keys[d,l,k], :] for all d,l,k

        I_z, I_z2, c_z = Z_init_II.shape

        # Match the standard implementation exactly:
        # it does self.process_z(Z_init_II)[..., tok_idx, :, :][..., tok_idx, :],
        # i.e. 1) process Z_init_II first, 2) then do the double token indexing

        # Step 1: Process Z_init_II to get processed token pair features
        Z_processed = self.process_z(Z_init_II)  # [I, I, c_atompair]

        # Step 2: Do the double indexing like the standard implementation
        # Standard: Z_processed[..., tok_idx, :, :][..., tok_idx, :]
        # This creates Z_processed[tok_idx, :][:, tok_idx] which is [L, L, c_atompair]
        # Then we need to gather the sparse version

        Z_pairs_processed = torch.zeros(
            D, L, k, self.c_atompair, device=device, dtype=Z_processed.dtype
        )

        for d in range(D):
            # For this batch, get the token queries and keys
            tq = tok_queries[d]  # [L, k]
            tk = tok_keys[d]  # [L, k]

            # Ensure indices are within bounds
            tq = torch.clamp(tq, 0, I_z - 1)
            tk = torch.clamp(tk, 0, I_z2 - 1)

            # Apply the double token indexing like the standard implementation
            Z_pairs_processed[d] = Z_processed[tq, tk]  # [L, k, c_atompair]

        P_LL_sparse += Z_pairs_processed

        # 5. Final MLP - add the result rather than replacing it (to match the standard implementation)
        P_LL_sparse = P_LL_sparse + self.pair_mlp(P_LL_sparse)

        return P_LL_sparse.contiguous()

def create_chunked_embedders(
    c_atompair: int, embed_frame: bool = True
) -> ChunkedPairwiseEmbedder:
    """
    Factory function to create chunked pairwise embedder with standard components.
    """
    motif_pos_embedder = ChunkedPositionPairDistEmbedder(c_atompair, embed_frame)
    ref_pos_embedder = ChunkedPositionPairDistEmbedder(c_atompair, embed_frame)

    return ChunkedPairwiseEmbedder(
        c_atompair=c_atompair,
        motif_pos_embedder=motif_pos_embedder,
        ref_pos_embedder=ref_pos_embedder,
    )
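
For orientation, below is a minimal sketch of how the chunked embedder defined in this file might be driven. The neighbour-index construction, tensor shapes, and placeholder feature values are assumptions made for illustration only; the import path, the forward_chunked signature, and the feature-dictionary keys come from the file above. In practice the embedder would more likely be built with the model's trained process_single_l / process_single_m / process_z / pair_mlp modules passed in, since the factory's defaults are freshly initialized layers.

import torch
from rfd3.model.layers.chunked_pairwise import create_chunked_embedders

# Hypothetical sizes; c_token and c_z are 128 because the default sub-modules
# in ChunkedPairwiseEmbedder.__init__ expect 128-dimensional inputs.
D, L, k, c_token, c_z, c_atompair = 1, 64, 8, 128, 128, 16

# Illustrative sparse attention indices: each atom's k nearest neighbours
# (including itself), derived here from a dense distance matrix purely for
# demonstration - a real pipeline would supply these indices directly.
coords = torch.randn(L, 3)
dists = torch.cdist(coords, coords)                 # [L, L]
indices = dists.topk(k, largest=False).indices      # [L, k]
indices = indices.unsqueeze(0).expand(D, -1, -1)    # [D, L, k]

# Placeholder inputs matching the shapes documented in forward_chunked.
C_L = torch.randn(D, L, c_token)     # atom-level features
Z_init_II = torch.randn(L, L, c_z)   # token pair features (one token per atom here)
tok_idx = torch.arange(L)            # atom -> token mapping

# Feature dictionary with the keys forward_chunked looks up; all-False motif
# masks mean the per-atom position-embedding loops are skipped in this sketch.
f = {
    "motif_pos": coords,
    "is_motif_atom_with_fixed_coord": torch.zeros(L, dtype=torch.bool),
    "ref_pos": coords,
    "ref_space_uid": torch.arange(L),
    "is_motif_atom_with_fixed_seq": torch.zeros(L, dtype=torch.bool),
}

embedder = create_chunked_embedders(c_atompair=c_atompair)
P_LL_sparse = embedder.forward_chunked(f, indices, C_L, Z_init_II, tok_idx)
print(P_LL_sparse.shape)  # torch.Size([1, 64, 8, 16]) rather than a dense [1, 64, 64, 16]

The sparse output illustrates the module docstring's point: only L×k pair embeddings are materialized instead of the full L×L grid.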