rc-foundry 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- foundry/__init__.py +57 -0
- foundry/callbacks/__init__.py +5 -0
- foundry/callbacks/callback.py +116 -0
- foundry/callbacks/health_logging.py +419 -0
- foundry/callbacks/metrics_logging.py +211 -0
- foundry/callbacks/timing_logging.py +67 -0
- foundry/callbacks/train_logging.py +278 -0
- foundry/common.py +108 -0
- foundry/constants.py +28 -0
- foundry/hydra/resolvers.py +77 -0
- foundry/inference_engines/base.py +235 -0
- foundry/inference_engines/checkpoint_registry.py +66 -0
- foundry/metrics/__init__.py +12 -0
- foundry/metrics/losses.py +30 -0
- foundry/metrics/metric.py +319 -0
- foundry/model/layers/blocks.py +47 -0
- foundry/testing/__init__.py +6 -0
- foundry/testing/fixtures.py +19 -0
- foundry/testing/pytest_hooks.py +15 -0
- foundry/trainers/fabric.py +923 -0
- foundry/training/EMA.py +67 -0
- foundry/training/checkpoint.py +61 -0
- foundry/training/schedulers.py +91 -0
- foundry/utils/alignment.py +86 -0
- foundry/utils/components.py +415 -0
- foundry/utils/datasets.py +405 -0
- foundry/utils/ddp.py +103 -0
- foundry/utils/instantiators.py +72 -0
- foundry/utils/logging.py +279 -0
- foundry/utils/rigid.py +1460 -0
- foundry/utils/rotation_augmentation.py +65 -0
- foundry/utils/squashfs.py +172 -0
- foundry/utils/torch.py +317 -0
- foundry/utils/weights.py +271 -0
- foundry/version.py +34 -0
- foundry_cli/__init__.py +3 -0
- foundry_cli/download_checkpoints.py +281 -0
- mpnn/__init__.py +1 -0
- mpnn/collate/feature_collator.py +265 -0
- mpnn/inference.py +53 -0
- mpnn/inference_engines/mpnn.py +549 -0
- mpnn/loss/nll_loss.py +122 -0
- mpnn/metrics/nll.py +369 -0
- mpnn/metrics/sequence_recovery.py +440 -0
- mpnn/model/layers/graph_embeddings.py +2372 -0
- mpnn/model/layers/message_passing.py +332 -0
- mpnn/model/layers/position_wise_feed_forward.py +44 -0
- mpnn/model/layers/positional_encoding.py +98 -0
- mpnn/model/mpnn.py +2632 -0
- mpnn/pipelines/mpnn.py +162 -0
- mpnn/samplers/samplers.py +167 -0
- mpnn/train.py +341 -0
- mpnn/trainers/mpnn.py +193 -0
- mpnn/transforms/feature_aggregation/mpnn.py +184 -0
- mpnn/transforms/feature_aggregation/polymer_ligand_interface.py +76 -0
- mpnn/transforms/feature_aggregation/token_encodings.py +132 -0
- mpnn/transforms/feature_aggregation/user_settings.py +347 -0
- mpnn/transforms/polymer_ligand_interface.py +164 -0
- mpnn/utils/inference.py +2397 -0
- mpnn/utils/probability.py +37 -0
- mpnn/utils/weights.py +309 -0
- rc_foundry-0.1.1.dist-info/METADATA +239 -0
- rc_foundry-0.1.1.dist-info/RECORD +180 -0
- rc_foundry-0.1.1.dist-info/WHEEL +4 -0
- rc_foundry-0.1.1.dist-info/entry_points.txt +5 -0
- rc_foundry-0.1.1.dist-info/licenses/LICENSE.md +28 -0
- rf3/__init__.py +3 -0
- rf3/_version.py +33 -0
- rf3/alignment.py +79 -0
- rf3/callbacks/dump_validation_structures.py +101 -0
- rf3/callbacks/metrics_logging.py +324 -0
- rf3/chemical.py +1529 -0
- rf3/cli.py +77 -0
- rf3/data/cyclic_transform.py +78 -0
- rf3/data/extra_xforms.py +36 -0
- rf3/data/ground_truth_template.py +463 -0
- rf3/data/paired_msa.py +206 -0
- rf3/data/pipeline_utils.py +128 -0
- rf3/data/pipelines.py +558 -0
- rf3/diffusion_samplers/inference_sampler.py +222 -0
- rf3/inference.py +65 -0
- rf3/inference_engines/__init__.py +5 -0
- rf3/inference_engines/rf3.py +735 -0
- rf3/kinematics.py +354 -0
- rf3/loss/af3_confidence_loss.py +515 -0
- rf3/loss/af3_losses.py +655 -0
- rf3/loss/loss.py +179 -0
- rf3/metrics/chiral.py +179 -0
- rf3/metrics/clashing_chains.py +68 -0
- rf3/metrics/distogram.py +421 -0
- rf3/metrics/lddt.py +523 -0
- rf3/metrics/metadata.py +43 -0
- rf3/metrics/metric_utils.py +192 -0
- rf3/metrics/predicted_error.py +134 -0
- rf3/metrics/rasa.py +108 -0
- rf3/metrics/selected_distances.py +91 -0
- rf3/model/RF3.py +527 -0
- rf3/model/RF3_blocks.py +92 -0
- rf3/model/RF3_structure.py +303 -0
- rf3/model/layers/af3_auxiliary_heads.py +255 -0
- rf3/model/layers/af3_diffusion_transformer.py +544 -0
- rf3/model/layers/attention.py +313 -0
- rf3/model/layers/layer_utils.py +127 -0
- rf3/model/layers/mlff.py +118 -0
- rf3/model/layers/outer_product.py +59 -0
- rf3/model/layers/pairformer_layers.py +783 -0
- rf3/model/layers/structure_bias.py +56 -0
- rf3/scoring.py +1787 -0
- rf3/symmetry/resolve.py +284 -0
- rf3/train.py +194 -0
- rf3/trainers/rf3.py +570 -0
- rf3/util_module.py +47 -0
- rf3/utils/frames.py +109 -0
- rf3/utils/inference.py +665 -0
- rf3/utils/io.py +198 -0
- rf3/utils/loss.py +72 -0
- rf3/utils/predict_and_score.py +165 -0
- rf3/utils/predicted_error.py +673 -0
- rf3/utils/recycling.py +42 -0
- rf3/validate.py +140 -0
- rfd3/.gitignore +7 -0
- rfd3/Makefile +76 -0
- rfd3/__init__.py +12 -0
- rfd3/callbacks.py +66 -0
- rfd3/cli.py +41 -0
- rfd3/constants.py +212 -0
- rfd3/engine.py +543 -0
- rfd3/inference/datasets.py +193 -0
- rfd3/inference/input_parsing.py +1123 -0
- rfd3/inference/legacy_input_parsing.py +717 -0
- rfd3/inference/parsing.py +165 -0
- rfd3/inference/symmetry/atom_array.py +298 -0
- rfd3/inference/symmetry/checks.py +241 -0
- rfd3/inference/symmetry/contigs.py +63 -0
- rfd3/inference/symmetry/frames.py +355 -0
- rfd3/inference/symmetry/symmetry_utils.py +398 -0
- rfd3/metrics/design_metrics.py +465 -0
- rfd3/metrics/hbonds_hbplus_metrics.py +308 -0
- rfd3/metrics/hbonds_metrics.py +389 -0
- rfd3/metrics/losses.py +325 -0
- rfd3/metrics/metrics_utils.py +118 -0
- rfd3/metrics/sidechain_metrics.py +349 -0
- rfd3/model/RFD3.py +105 -0
- rfd3/model/RFD3_diffusion_module.py +387 -0
- rfd3/model/cfg_utils.py +81 -0
- rfd3/model/inference_sampler.py +635 -0
- rfd3/model/layers/attention.py +577 -0
- rfd3/model/layers/block_utils.py +580 -0
- rfd3/model/layers/blocks.py +777 -0
- rfd3/model/layers/chunked_pairwise.py +377 -0
- rfd3/model/layers/encoders.py +417 -0
- rfd3/model/layers/layer_utils.py +197 -0
- rfd3/model/layers/pairformer_layers.py +128 -0
- rfd3/run_inference.py +45 -0
- rfd3/testing/debug.py +139 -0
- rfd3/testing/debug_utils.py +73 -0
- rfd3/testing/testing_utils.py +356 -0
- rfd3/train.py +194 -0
- rfd3/trainer/dump_validation_structures.py +154 -0
- rfd3/trainer/fabric_trainer.py +923 -0
- rfd3/trainer/recycling.py +42 -0
- rfd3/trainer/rfd3.py +485 -0
- rfd3/trainer/trainer_utils.py +502 -0
- rfd3/transforms/conditioning_base.py +508 -0
- rfd3/transforms/conditioning_utils.py +200 -0
- rfd3/transforms/design_transforms.py +807 -0
- rfd3/transforms/dna_crop.py +523 -0
- rfd3/transforms/hbonds.py +407 -0
- rfd3/transforms/hbonds_hbplus.py +246 -0
- rfd3/transforms/ncaa_transforms.py +153 -0
- rfd3/transforms/pipelines.py +632 -0
- rfd3/transforms/ppi_transforms.py +541 -0
- rfd3/transforms/rasa.py +116 -0
- rfd3/transforms/symmetry.py +76 -0
- rfd3/transforms/training_conditions.py +552 -0
- rfd3/transforms/util_transforms.py +498 -0
- rfd3/transforms/virtual_atoms.py +305 -0
- rfd3/utils/inference.py +648 -0
- rfd3/utils/io.py +245 -0
- rfd3/utils/vizualize.py +276 -0
|
@@ -0,0 +1,580 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
from typing import Tuple
|
|
3
|
+
|
|
4
|
+
import torch
|
|
5
|
+
import torch.nn.functional as F
|
|
6
|
+
from jaxtyping import Float, Int
|
|
7
|
+
|
|
8
|
+
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def bucketize_scaled_distogram(R_L, min_dist=1, max_dist=30, sigma_data=16, n_bins=65):
    """
    One-hot bin all pairwise distances of R_L using bin edges in EDM-scaled units.

    min_dist / max_dist are given in angstroms and divided by sigma_data to obtain
    the bin range, so R_L is presumably already in the same scaled units
    (TODO confirm with callers). n_bins - 1 evenly spaced edges yield n_bins
    classes: distances <= the first edge land in class 0, distances above the last
    edge land in class n_bins - 1.

    R_L: [B, N, 3]
    returns: [B, N, N, n_bins] float one-hot
    """
    # Pairwise displacement vectors [B, N, N, 3] -> Euclidean distances [B, N, N]
    displacement = R_L[..., :, None, :] - R_L[..., None, :, :]
    dist = torch.linalg.vector_norm(displacement, dim=-1)

    # Bin edges expressed in scaled units
    lo = min_dist / sigma_data
    hi = max_dist / sigma_data
    edges = torch.linspace(lo, hi, n_bins - 1, device=dist.device)

    bucket = torch.bucketize(dist, edges)
    return F.one_hot(bucket, num_classes=edges.numel() + 1).float()
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def build_valid_mask(
|
|
34
|
+
tok_idx: torch.Tensor, n_atoms_per_tok_max: int | None = None
|
|
35
|
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
|
36
|
+
"""
|
|
37
|
+
Args
|
|
38
|
+
----
|
|
39
|
+
tok_idx : (n_atoms,) non negative integer array
|
|
40
|
+
n_atoms_per_tok_max : if given, pad/truncate up to this size
|
|
41
|
+
|
|
42
|
+
Returns
|
|
43
|
+
-------
|
|
44
|
+
valid_mask : (n_tokens, A) True where an atom exists
|
|
45
|
+
tokens : (n_tokens,) the unique token IDs in ascending order
|
|
46
|
+
"""
|
|
47
|
+
tokens, counts = torch.unique(tok_idx, return_counts=True)
|
|
48
|
+
A = int(counts.max()) if n_atoms_per_tok_max is None else int(n_atoms_per_tok_max)
|
|
49
|
+
|
|
50
|
+
# build [n_tokens, A] mask; broadcasting keeps it vectorised
|
|
51
|
+
atom_idx_grid = torch.arange(A, device=tok_idx.device)[None, :] # (1, A)
|
|
52
|
+
valid_mask = atom_idx_grid < counts[:, None] # (n_tok, A)
|
|
53
|
+
|
|
54
|
+
return valid_mask
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
def ungroup_atoms(Q_L, valid_mask):
    """
    Scatter flat per-atom features into a zero-padded (token, slot) layout.

    Args
    ----
    Q_L : (B, n_atoms, c)
    valid_mask : (n_tokens, A) bool, e.g. the mask produced by `build_valid_mask`

    Returns
    -------
    Q_IA : (B, n_tokens, A, c); slots where valid_mask is False are zero. Atoms
           are consumed in order and placed into the True slots in row-major
           (token, slot) order, so atoms are assumed to be grouped by token
           in ascending token order.
    """
    n_batch, _, n_chan = Q_L.shape
    n_tok, n_slot = valid_mask.shape

    # (n_tok, A) -> (1, n_tok, A, 1) -> (B, n_tok, A, c)
    scatter_mask = valid_mask[None, :, :, None].expand(n_batch, n_tok, n_slot, n_chan)

    padded = Q_L.new_zeros((n_batch, n_tok, n_slot, n_chan))
    padded.masked_scatter_(scatter_mask, Q_L)
    return padded
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
def group_atoms(Q_IA: torch.Tensor, valid_mask: torch.Tensor) -> torch.Tensor:
    """
    Drop padded slots from a (token, slot) layout and flatten back to atoms.

    Args
    ----
    Q_IA : (B, n_tokens, A, c)
    valid_mask : (n_tokens, A) bool, True where the slot holds a real atom

    Returns
    -------
    Q_L : (B, n_atoms, c) with n_atoms = valid_mask.sum(); atoms appear in
          row-major (token, slot) order of valid_mask.
    """
    n_batch, _, _, n_chan = Q_IA.shape

    # (n_tok, A) -> (1, n_tok, A, 1); masked_select broadcasts it over batch and channels
    keep = valid_mask[None, :, :, None]
    selected = torch.masked_select(Q_IA, keep)
    return selected.view(n_batch, -1, n_chan)
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
def group_pair(P_IAA, valid_mask):
    """
    Keep only the rows of token-grouped pair features that belong to real atoms.

    valid_mask : (L, A) bool over (token, slot)
    P_IAA      : (B, L, A, A, c) or (L, A, A, c)

    Returns (B, n_valid, A, c) or (n_valid, A, c), where n_valid = valid_mask.sum()
    and rows are ordered row-major over (token, slot). The second A axis is kept whole.

    Raises ValueError for any other input rank.
    """
    if P_IAA.ndim not in (4, 5):
        raise ValueError(
            f"Unexpected input shape {P_IAA.shape}: must be (B, L, A, A, c) or (L, A, A, c)"
        )

    batched = P_IAA.ndim == 5
    n_keys, n_chan = P_IAA.shape[-2], P_IAA.shape[-1]

    # (L, A) -> (L, A, 1, 1), plus a leading batch axis when batched;
    # masked_select broadcasts it over the remaining axes.
    selector = valid_mask[..., None, None]
    if batched:
        selector = selector.unsqueeze(0)
    picked = torch.masked_select(P_IAA, selector)

    if batched:
        return picked.view(P_IAA.shape[0], -1, n_keys, n_chan)  # (B, n_valid, A, c)
    return picked.view(-1, n_keys, n_chan)  # (n_valid, A, c)
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
def scatter_add_pair_features(P_LK_tgt, P_LK_indices, P_LA_src, P_LA_indices):
    """
    Add features from P_LA_src into P_LK_tgt at the key slots whose index matches.

    Parameters
    ----------
    P_LK_tgt : (D, L, k, c) FloatTensor
        Key features to scatter-add into. Not modified in place; a new tensor is returned.
    P_LK_indices : (D, L, k) LongTensor
        Key indices | P_LK_indices[d, i, j] = global atom index that atom i attends to in slot j.
    P_LA_src : (D, L, a, c) or (L, a, c) FloatTensor
        Additional features. A 3-D input is broadcast over D.
    P_LA_indices : (D, L, a) or (L, a) LongTensor
        Global atom indices of the additional features. A 2-D input is broadcast over D.

    For every (d, i, s), P_LA_src[d, i, s] is added to the slot j of P_LK_tgt[d, i]
    for which P_LK_indices[d, i, j] == P_LA_indices[d, i, s].

    Returns
    -------
    (D, L, k, c) tensor.

    Raises
    ------
    ValueError
        If some (d, i) row has no matching index pair at all, or if a source index
        matches more than one key slot.
    AssertionError
        If the channel dimensions of source and target differ.

    NB: a source entry whose index matches no key (e.g. a -1 padding index) has an
    all-False match row, so argmax puts it in key slot 0. Presumably callers pass
    zero features for such entries (TODO confirm).
    """
    # Broadcast un-batched inputs over the D batch dimension
    D, L, k = P_LK_indices.shape
    if P_LA_indices.ndim == 2:
        P_LA_indices = P_LA_indices.unsqueeze(0).expand(D, -1, -1)
    if P_LA_src.ndim == 3:
        # (L, a, c) -> (1, L, a, c) -> (D, L, a, c); expand needs one size per dim
        P_LA_src = P_LA_src.unsqueeze(0).expand(D, -1, -1, -1)
    assert (
        P_LA_src.shape[-1] == P_LK_tgt.shape[-1]
    ), "Channel dims do not match, got: {} vs {}".format(
        P_LA_src.shape[-1], P_LK_tgt.shape[-1]
    )

    matches = P_LA_indices.unsqueeze(-1) == P_LK_indices.unsqueeze(-2)  # (D, L, a, k)
    if not torch.all(matches.sum(dim=(-1, -2)) >= 1):
        # Some (d, i) row has no source index matching any key
        raise ValueError("Did not find a scatter index for every atom")
    if not torch.all(matches.sum(dim=-1) <= 1):
        # Some source index matches several key slots -> ambiguous destination
        raise ValueError("Found multiple scatter indices for some atoms")

    k_indices = matches.long().argmax(dim=-1)  # (D, L, a)
    scatter_indices = k_indices.unsqueeze(-1).expand(
        -1, -1, -1, P_LK_tgt.shape[-1]
    )  # (D, L, a, c)
    P_LK_tgt = P_LK_tgt.scatter_add(dim=2, index=scatter_indices, src=P_LA_src)
    return P_LK_tgt
|
|
160
|
+
|
|
161
|
+
|
|
162
|
+
def _batched_gather(values: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
|
|
163
|
+
"""
|
|
164
|
+
values : (D, L, C)
|
|
165
|
+
idx : (D, L, k)
|
|
166
|
+
returns: (D, L, k, C)
|
|
167
|
+
"""
|
|
168
|
+
D, L, C = values.shape
|
|
169
|
+
k = idx.shape[-1]
|
|
170
|
+
|
|
171
|
+
# (D, L, 1, C) → stride-0 along k → (D, L, k, C)
|
|
172
|
+
src = values.unsqueeze(2).expand(-1, -1, k, -1)
|
|
173
|
+
idx = idx.unsqueeze(-1).expand(-1, -1, -1, C) # (D, L, k, C)
|
|
174
|
+
|
|
175
|
+
return torch.gather(src, 1, idx) # dim=1 is the L-axis
|
|
176
|
+
|
|
177
|
+
|
|
178
|
+
@torch.no_grad()
def create_attention_indices(
    f, n_attn_keys, n_attn_seq_neighbours, X_L=None, tok_idx=None
):
    """
    Entry-point function for creating attention indices for sequence & structure-local attention.

    f: input features of the model. Reads "unindexing_pair_mask", "atom_to_token_map"
       (only when tok_idx is None) and, if present, "asym_id".
    n_attn_keys: number of (atom) attention keys (capped at the number of atoms)
    n_attn_seq_neighbours: number of neighbouring sequence tokens (residues) to attend to
    X_L: optional atom positions [D, L, 3]. If None, random positions are used, so the
         distance-based (non-sequence) keys are effectively random padding atoms.
    tok_idx: optional atom-to-token map [L]; defaults to f["atom_to_token_map"]

    Returns
    -------
    attn_indices: [D, L, k] | attn_indices[b, i, j] = atom index of the j-th key of atom i.

    Structures with more than 3 distinct chains reserve max(32, k // 4) keys for
    inter-chain neighbours, but only when the key budget can hold them. Smaller
    budgets fall back to regular attention instead of failing.
    """
    tok_idx = f["atom_to_token_map"] if tok_idx is None else tok_idx
    device = X_L.device if X_L is not None else tok_idx.device
    L = len(tok_idx)

    if X_L is None:
        X_L = torch.randn(
            (1, L, 3), device=device, dtype=torch.float
        )  # [1, L, 3] - random
    D_LL = torch.cdist(X_L, X_L, p=2)  # [D, L, L] - pairwise atom distances

    # Pairs flagged in unindexing_pair_mask are excluded from attention.
    # NOTE(review): base_mask[i, j] reads unindexing_pair_mask[tok_idx[j], tok_idx[i]],
    # i.e. transposed w.r.t. (query, key); equivalent only if that mask is symmetric.
    base_mask = ~f["unindexing_pair_mask"][
        tok_idx[None, :], tok_idx[:, None]
    ]  # [n_atoms, n_atoms]
    k_actual = min(n_attn_keys, L)

    # For symmetric structures, ensure inter-chain interactions are included
    chain_ids = f["asym_id"][tok_idx] if "asym_id" in f else None

    # Reserve 25% of attention keys (at least 32) for inter-chain interactions
    k_inter_chain = max(32, k_actual // 4)
    k_intra_chain = k_actual - k_inter_chain

    # The inter-chain path is only viable when the reserved keys fit in the budget;
    # a negative k_intra_chain would make the downstream top-k selection fail.
    use_inter_chain = (
        chain_ids is not None
        and len(torch.unique(chain_ids)) > 3
        and k_intra_chain >= 0
    )

    if use_inter_chain:
        attn_indices = get_sparse_attention_indices_with_inter_chain(
            tok_idx,
            D_LL,
            n_seq_neighbours=n_attn_seq_neighbours,
            k_intra=k_intra_chain,
            k_inter=k_inter_chain,
            chain_id=chain_ids,
            base_mask=base_mask,
        )
    else:
        # Regular attention for single chain or small structures
        attn_indices = get_sparse_attention_indices(
            tok_idx,
            D_LL,
            n_seq_neighbours=n_attn_seq_neighbours,
            k_max=k_actual,
            chain_id=chain_ids,
            base_mask=base_mask,
        )  # [D, L, k] | indices[b, i, j] = atom index for atom i to j-th attn query

    return attn_indices
|
|
237
|
+
|
|
238
|
+
|
|
239
|
+
@torch.no_grad()
def get_sparse_attention_indices_with_inter_chain(
    tok_idx, D_LL, n_seq_neighbours, k_intra, k_inter, chain_id, base_mask
):
    """
    Build attention indices that always include keys from other chains (clash avoidance).

    Args:
        tok_idx: atom to token mapping
        D_LL: pairwise distances [D, L, L]
        n_seq_neighbours: number of sequence neighbors
        k_intra: number of keys taken from the regular sparse-attention selection
        k_inter: number of inter-chain keys per atom
        chain_id: chain IDs for each atom [L]
        base_mask: [L, L] mask of allowed pairs

    Returns:
        attn_indices: [D, L, k_intra + k_inter]. The inter-chain part holds the
        nearest allowed atoms on other chains, padded with uniformly random atom
        indices when there are fewer than k_inter candidates. The random padding
        is not deduplicated against the intra-chain part.
    """
    n_batch, n_atoms, _ = D_LL.shape
    device = D_LL.device

    # Regular (sequence + distance) keys, limited to k_intra per atom
    intra_indices = get_sparse_attention_indices(
        tok_idx, D_LL, n_seq_neighbours, k_intra, chain_id, base_mask
    )  # [D, L, k_intra]

    # candidates[q, m]: atom m is on another chain than query q and the pair is allowed
    candidates = (chain_id.unsqueeze(0) != chain_id.unsqueeze(1)) & base_mask

    inter_indices = torch.zeros(
        n_batch, n_atoms, k_inter, dtype=torch.long, device=device
    )
    for d in range(n_batch):
        for q in range(n_atoms):
            cand_atoms = torch.where(candidates[q])[0]

            if len(cand_atoms) == 0:
                # No other chains reachable: fill the whole row with random atoms
                inter_indices[d, q, :] = torch.randint(
                    0, n_atoms, (k_inter,), device=device
                )
                continue

            # Nearest other-chain atoms first
            n_select = min(k_inter, len(cand_atoms))
            nearest = torch.topk(D_LL[d, q, cand_atoms], n_select, largest=False).indices
            inter_indices[d, q, :n_select] = cand_atoms[nearest]

            # Top up with random atoms if there were too few candidates
            n_pad = k_inter - n_select
            if n_pad > 0:
                inter_indices[d, q, n_select:] = torch.randint(
                    0, n_atoms, (n_pad,), device=device
                )

    return torch.cat([intra_indices, inter_indices], dim=-1)  # [D, L, k_total]
|
|
305
|
+
|
|
306
|
+
|
|
307
|
+
@torch.no_grad()
def build_index_mask(
    tok_idx: torch.Tensor,
    n_sequence_neighbours: int,
    k_max: int,
    chain_id: torch.Tensor | None = None,
    base_mask: torch.Tensor | None = None,
) -> torch.Tensor:
    """
    Build an (L, L) atom-level mask of sequence-local keys that never partially
    includes a token.

    Atom j is a candidate key for atom i when:
      * their tokens are within n_sequence_neighbours token ids of each other, and
      * the atoms themselves are within k_max // 2 positions of each other
        (an atom-level window; k_max is first capped at L).
    Tokens that are only partially covered by this window are then dropped. Token u
    is kept for every atom of token t only if EVERY atom of t sees ALL atoms of u,
    so all atoms of a query token share the same key tokens. The result is finally
    ANDed with the same-chain mask (if chain_id is given) and base_mask (if given).

    Parameters:
        tok_idx: (L,) tensor of token indices. Token ids are treated as 0..max(tok_idx);
            ids absent from tok_idx get zero atoms.
        n_sequence_neighbours: number of tokens to include on either side.
        k_max: atom-window size; atoms within k_max // 2 positions are candidates.
        chain_id: (L,) chain identifiers for each position (optional).
        base_mask: (L, L) optional pre-mask to AND with.

    Returns:
        (L, L) bool mask; mask[i, j] True means atom j is a forced key for atom i.
    """
    device = tok_idx.device
    L = tok_idx.shape[0]
    k_max = min(k_max, L)
    I = int(tok_idx.max()) + 1  # Number of token ids (max id + 1, gaps included)
    # n_atoms_per_token[t] = number of atoms mapped to token t
    n_atoms_per_token = torch.zeros(I, device=device).float()
    n_atoms_per_token.scatter_add_(0, tok_idx.long(), torch.ones_like(tok_idx).float())

    # Create index masks for tokens and atoms (absolute index differences)
    token_indices = torch.arange(I, device=device)
    token_diff = (token_indices[:, None] - token_indices[None, :]).abs()
    atom_indices = torch.arange(L, device=device)
    atom_diff = (atom_indices[:, None] - atom_indices[None, :]).abs()

    # Build token-token mask: [I, I]
    token_mask = token_diff <= n_sequence_neighbours

    # Expand token_mask to full [L, L] mask using broadcast
    # tok_idx maps each position to a token index [L]
    token_i = tok_idx[:, None]  # (L, 1)
    token_j = tok_idx[None, :]  # (1, L)
    mask = token_mask[token_i, token_j]  # (L, L)
    # Additionally restrict to an atom-level window of +/- k_max // 2 positions
    mask = mask & (atom_diff <= (k_max // 2))

    # Exclude tokens which are partially filled (L, I):
    # n_query_per_token[l, t] = number of atoms of token t currently kept in row l
    n_query_per_token = torch.zeros((L, I), device=device).float()
    n_query_per_token.scatter_add_(
        1, tok_idx.long()[None, :].expand(L, -1), mask.float()
    )

    # Find mask for the atoms for which the number of keys
    # match the number of atoms in the token (L, I)
    fully_included = n_query_per_token == n_atoms_per_token[None, :]

    # Contract rows by query token (I, I) and count, per (query token, key token),
    # how many atoms of the query token fully include the key token
    n_atoms_fully_included = torch.zeros((I, I), device=device)
    n_atoms_fully_included.index_add_(0, tok_idx.long(), fully_included.float())
    # True only if every atom of the query token fully includes the key token
    full_token_mask = n_atoms_fully_included == n_atoms_per_token[:, None]

    # Map this back to (L, L) — include token j in row i only if all its atoms are included
    full_token_mask = full_token_mask[token_i, token_j]  # (L, L)
    mask &= full_token_mask

    if chain_id is not None:
        same_chain = chain_id.unsqueeze(-1) == chain_id.unsqueeze(-2)
        mask = mask & same_chain

    if base_mask is not None:
        mask = mask & base_mask

    return mask
|
|
378
|
+
|
|
379
|
+
|
|
380
|
+
def extend_index_mask_with_neighbours(
    mask: torch.Tensor, D_LL: torch.Tensor, k: int
) -> torch.LongTensor:
    """
    Turn a forced-neighbour mask into exactly k key indices per query, topping up
    with nearest neighbours by distance.

    Parameters
    ----------
    mask : (L, L) bool     # pre-selected neighbours (True = keep)
    D_LL : (B, L, L) or (L, L) float  # pairwise distances (lower = closer)
    k: int                  # desired neighbours per query (capped at L)

    Returns
    -------
    neigh_idx : (B, L, k) long  # exactly k indices per row

    Forced indices come first along k, in ascending order. The remaining n_fill
    slots take the n_fill nearest non-forced atoms, with the nearest one in the
    last column. If a row has more than k forced indices, only the k smallest
    indices are kept. E.g.
        forced[i, :]  = [1, 2, 3, _, _]
        indices[i, :] = [1, 2, 3, 0, 5]   # 0, 5 are the extra nearest neighbours (k=5)

    NB: if k = 14 * (2*n_seq_neigh + 1) (from the sequence mask), tokens in the
    middle of a chain get no distance-based neighbours, while tokens near the
    edges get an increasing number of them.
    """
    if D_LL.ndim == 2:
        D_LL = D_LL.unsqueeze(0)
    B, L, _ = D_LL.shape
    k = min(k, L)
    assert mask.shape == (L, L) and D_LL.shape == (B, L, L)
    device = D_LL.device
    inf = torch.tensor(float("inf"), dtype=D_LL.dtype, device=device)

    # 1. Selection of sequence neighbours.
    # Use the integer sentinel L (never a valid index), so indices stay in an
    # integer dtype. A float inf sentinel would promote them to D_LL.dtype and
    # round large indices under bf16/fp16.
    sentinel = torch.tensor(L, dtype=torch.long, device=device)
    all_idx_row = torch.arange(L, device=device).expand(L, L)
    indices = torch.where(mask, all_idx_row, sentinel)  # sentinel if not forced
    indices = indices.sort(dim=1)[0][:, :k]  # (L, k), forced first, sentinels last

    # 2. Find k-nn excluding forced indices
    D_LL = torch.where(mask, inf, D_LL)
    filler_idx = torch.topk(D_LL, k, dim=-1, largest=False).indices  # (B, L, k)

    # ... Reverse last axis s.t. best matched indices are last, lining them up
    # with the trailing sentinel slots
    filler_idx = filler_idx.flip(dims=[-1])

    # 3. Fill sentinel slots with the nearest non-forced neighbours
    to_fill = (indices == sentinel).expand_as(filler_idx)
    indices = indices.expand_as(filler_idx)
    indices = torch.where(to_fill, filler_idx, indices)

    return indices.long()  # (B, L, k)
|
|
428
|
+
|
|
429
|
+
|
|
430
|
+
def get_sparse_attention_indices(
    res_idx, D_LL, n_seq_neighbours, k_max, chain_id=None, base_mask=None
):
    """
    Build sorted sparse-attention key indices from sequence locality plus distance.

    Sequence-local keys come from build_index_mask(res_idx, n_seq_neighbours, k_max,
    chain_id, base_mask). Remaining slots are topped up with the nearest atoms by
    D_LL via extend_index_mask_with_neighbours.

    Returns a detached [D, L, k_max] tensor, sorted ascending along the last axis.

    Raises AssertionError if any row contains a duplicate index, or if fewer than
    k_max indices were produced (e.g. k_max larger than L).
    """
    forced = build_index_mask(
        res_idx, n_seq_neighbours, k_max, chain_id=chain_id, base_mask=base_mask
    )
    neighbours = torch.sort(
        extend_index_mask_with_neighbours(forced, D_LL, k_max), dim=-1
    ).values

    # After sorting, duplicates are necessarily adjacent
    if torch.any(neighbours[..., 1:] == neighbours[..., :-1]):
        raise AssertionError("Tensor has duplicate elements along the last dimension.")

    n_found = neighbours.shape[-1]
    assert (
        n_found == k_max
    ), f"Expected k_max={k_max} indices, got {neighbours.shape[-1]} instead."

    # Detach so the index tensor carries no autograd history
    return neighbours.detach()
|
|
449
|
+
|
|
450
|
+
|
|
451
|
+
@torch.no_grad()
def indices_to_mask(neigh_idx):
    """
    Convert attention key indices into a dense boolean mask (e.g. for visualization).

    Args:
        neigh_idx: [L, k] or [B, L, k] tensor of key indices (cast to long).

    Returns:
        [L, L] or [B, L, L] bool mask, True at every (query, key) listed in neigh_idx.

    Raises:
        ValueError: if neigh_idx is not 2-D or 3-D.
    """
    neigh_idx = neigh_idx.to(dtype=torch.long)
    if neigh_idx.ndim not in (2, 3):
        raise ValueError(f"Expected ndim 2 or 3, got {neigh_idx.ndim}")

    n_rows = neigh_idx.shape[-2]
    dense = torch.zeros(
        (*neigh_idx.shape[:-1], n_rows), dtype=torch.bool, device=neigh_idx.device
    )
    # Mark each listed key along the last axis
    dense.scatter_(-1, neigh_idx, torch.ones_like(neigh_idx, dtype=torch.bool))
    return dense
|
|
475
|
+
|
|
476
|
+
|
|
477
|
+
def create_valid_mask_LA(valid_mask):
    """
    Helper for X_IAA (token-grouped atom-pair representations).

    valid_mask: [I, A] bool; True marks the slots of each token grouping that hold
        real atoms. L = sum(valid_mask) is the total number of atoms.

    Returns
    -------
    valid_mask_LA: [L, A] for each real atom, which slots of its own token are real.
    indices: [L, A] absolute atom index of every slot in the atom's token, computed
        as (own atom index) + (slot offset), which assumes atoms are numbered
        contiguously per token. Invalid slots are set to -1.

    E.g. this gives a fixed-width [A, A] block per token (such as [14, 14] for a
    protein), where tokens with fewer atoms (e.g. atomized tokens) carry -1 in
    their unused slots.
    """
    n_tok, A = valid_mask.shape
    n_atoms = valid_mask.sum()
    slot = torch.arange(A, device=valid_mask.device)

    # offsets[i, a, b] = b - a: position of slot b relative to slot a within token i
    offsets = (slot[None, :] - slot[:, None]).expand(n_tok, A, A)
    offsets_LA = offsets[valid_mask]  # [L, A], one row per real atom

    # Row of each real atom = occupancy of its own token's slots
    valid_mask_LA = valid_mask[:, None, :].expand(n_tok, A, A)[valid_mask]  # [L, A]

    indices = torch.arange(n_atoms, device=valid_mask.device)[:, None] + offsets_LA
    indices[~valid_mask_LA] = -1

    return valid_mask_LA, indices
|
|
511
|
+
|
|
512
|
+
|
|
513
|
+
def pairwise_mean_pool(
    pairwise_atom_features: Float[torch.Tensor, "batch n_atoms n_atoms d_hidden"],
    atom_to_token_map: Int[torch.Tensor, "n_atoms"],
    I: int,
    dtype: torch.dtype,
) -> Float[torch.Tensor, "batch n_tokens n_tokens d_hidden"]:
    """Average pairwise atom features over atom pairs to get pairwise token features.

    Args:
        pairwise_atom_features: Pairwise features between atoms, (B, L, L, d)
        atom_to_token_map: Token index of each atom, (L,)
        I: Number of tokens (one-hot width)
        dtype: dtype of the one-hot membership matrix used in the contractions

    Returns:
        (B, I, I, d) where entry (i, j) is the mean of features[b, l1, l2] over all
        atom pairs with l1 -> i and l2 -> j. Token pairs with no atoms get zeros,
        because the pair count is clamped to at least 1.
    """
    n_batch, _, _, _ = pairwise_atom_features.shape

    # membership[l, i] = 1 if atom l belongs to token i
    membership = F.one_hot(atom_to_token_map.long(), num_classes=I).to(dtype)  # (L, I)

    # Sum over atom pairs in two contractions (rows, then columns) through a
    # (B, I, L, d) intermediate instead of a single three-operand einsum.
    row_pooled = torch.einsum(
        "ia,bacd->bicd", membership.T, pairwise_atom_features
    )  # (B, I, L, d)
    del pairwise_atom_features  # drop the local reference early
    pair_sums = torch.einsum("cj,bicd->bijd", membership, row_pooled)  # (B, I, I, d)
    del row_pooled

    # Number of contributing atom pairs per token pair, shared across the batch
    atoms_per_token = membership.sum(dim=0)  # (I,)
    pair_counts = torch.outer(atoms_per_token, atoms_per_token)  # (I, I)
    pair_counts = pair_counts.unsqueeze(0).expand(n_batch, -1, -1).clamp(min=1)

    return pair_sums / pair_counts.unsqueeze(-1)
|