rc-foundry 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- foundry/__init__.py +57 -0
- foundry/callbacks/__init__.py +5 -0
- foundry/callbacks/callback.py +116 -0
- foundry/callbacks/health_logging.py +419 -0
- foundry/callbacks/metrics_logging.py +211 -0
- foundry/callbacks/timing_logging.py +67 -0
- foundry/callbacks/train_logging.py +278 -0
- foundry/common.py +108 -0
- foundry/constants.py +28 -0
- foundry/hydra/resolvers.py +77 -0
- foundry/inference_engines/base.py +235 -0
- foundry/inference_engines/checkpoint_registry.py +66 -0
- foundry/metrics/__init__.py +12 -0
- foundry/metrics/losses.py +30 -0
- foundry/metrics/metric.py +319 -0
- foundry/model/layers/blocks.py +47 -0
- foundry/testing/__init__.py +6 -0
- foundry/testing/fixtures.py +19 -0
- foundry/testing/pytest_hooks.py +15 -0
- foundry/trainers/fabric.py +923 -0
- foundry/training/EMA.py +67 -0
- foundry/training/checkpoint.py +61 -0
- foundry/training/schedulers.py +91 -0
- foundry/utils/alignment.py +86 -0
- foundry/utils/components.py +415 -0
- foundry/utils/datasets.py +405 -0
- foundry/utils/ddp.py +103 -0
- foundry/utils/instantiators.py +72 -0
- foundry/utils/logging.py +279 -0
- foundry/utils/rigid.py +1460 -0
- foundry/utils/rotation_augmentation.py +65 -0
- foundry/utils/squashfs.py +172 -0
- foundry/utils/torch.py +317 -0
- foundry/utils/weights.py +271 -0
- foundry/version.py +34 -0
- foundry_cli/__init__.py +3 -0
- foundry_cli/download_checkpoints.py +281 -0
- mpnn/__init__.py +1 -0
- mpnn/collate/feature_collator.py +265 -0
- mpnn/inference.py +53 -0
- mpnn/inference_engines/mpnn.py +549 -0
- mpnn/loss/nll_loss.py +122 -0
- mpnn/metrics/nll.py +369 -0
- mpnn/metrics/sequence_recovery.py +440 -0
- mpnn/model/layers/graph_embeddings.py +2372 -0
- mpnn/model/layers/message_passing.py +332 -0
- mpnn/model/layers/position_wise_feed_forward.py +44 -0
- mpnn/model/layers/positional_encoding.py +98 -0
- mpnn/model/mpnn.py +2632 -0
- mpnn/pipelines/mpnn.py +162 -0
- mpnn/samplers/samplers.py +167 -0
- mpnn/train.py +341 -0
- mpnn/trainers/mpnn.py +193 -0
- mpnn/transforms/feature_aggregation/mpnn.py +184 -0
- mpnn/transforms/feature_aggregation/polymer_ligand_interface.py +76 -0
- mpnn/transforms/feature_aggregation/token_encodings.py +132 -0
- mpnn/transforms/feature_aggregation/user_settings.py +347 -0
- mpnn/transforms/polymer_ligand_interface.py +164 -0
- mpnn/utils/inference.py +2397 -0
- mpnn/utils/probability.py +37 -0
- mpnn/utils/weights.py +309 -0
- rc_foundry-0.1.1.dist-info/METADATA +239 -0
- rc_foundry-0.1.1.dist-info/RECORD +180 -0
- rc_foundry-0.1.1.dist-info/WHEEL +4 -0
- rc_foundry-0.1.1.dist-info/entry_points.txt +5 -0
- rc_foundry-0.1.1.dist-info/licenses/LICENSE.md +28 -0
- rf3/__init__.py +3 -0
- rf3/_version.py +33 -0
- rf3/alignment.py +79 -0
- rf3/callbacks/dump_validation_structures.py +101 -0
- rf3/callbacks/metrics_logging.py +324 -0
- rf3/chemical.py +1529 -0
- rf3/cli.py +77 -0
- rf3/data/cyclic_transform.py +78 -0
- rf3/data/extra_xforms.py +36 -0
- rf3/data/ground_truth_template.py +463 -0
- rf3/data/paired_msa.py +206 -0
- rf3/data/pipeline_utils.py +128 -0
- rf3/data/pipelines.py +558 -0
- rf3/diffusion_samplers/inference_sampler.py +222 -0
- rf3/inference.py +65 -0
- rf3/inference_engines/__init__.py +5 -0
- rf3/inference_engines/rf3.py +735 -0
- rf3/kinematics.py +354 -0
- rf3/loss/af3_confidence_loss.py +515 -0
- rf3/loss/af3_losses.py +655 -0
- rf3/loss/loss.py +179 -0
- rf3/metrics/chiral.py +179 -0
- rf3/metrics/clashing_chains.py +68 -0
- rf3/metrics/distogram.py +421 -0
- rf3/metrics/lddt.py +523 -0
- rf3/metrics/metadata.py +43 -0
- rf3/metrics/metric_utils.py +192 -0
- rf3/metrics/predicted_error.py +134 -0
- rf3/metrics/rasa.py +108 -0
- rf3/metrics/selected_distances.py +91 -0
- rf3/model/RF3.py +527 -0
- rf3/model/RF3_blocks.py +92 -0
- rf3/model/RF3_structure.py +303 -0
- rf3/model/layers/af3_auxiliary_heads.py +255 -0
- rf3/model/layers/af3_diffusion_transformer.py +544 -0
- rf3/model/layers/attention.py +313 -0
- rf3/model/layers/layer_utils.py +127 -0
- rf3/model/layers/mlff.py +118 -0
- rf3/model/layers/outer_product.py +59 -0
- rf3/model/layers/pairformer_layers.py +783 -0
- rf3/model/layers/structure_bias.py +56 -0
- rf3/scoring.py +1787 -0
- rf3/symmetry/resolve.py +284 -0
- rf3/train.py +194 -0
- rf3/trainers/rf3.py +570 -0
- rf3/util_module.py +47 -0
- rf3/utils/frames.py +109 -0
- rf3/utils/inference.py +665 -0
- rf3/utils/io.py +198 -0
- rf3/utils/loss.py +72 -0
- rf3/utils/predict_and_score.py +165 -0
- rf3/utils/predicted_error.py +673 -0
- rf3/utils/recycling.py +42 -0
- rf3/validate.py +140 -0
- rfd3/.gitignore +7 -0
- rfd3/Makefile +76 -0
- rfd3/__init__.py +12 -0
- rfd3/callbacks.py +66 -0
- rfd3/cli.py +41 -0
- rfd3/constants.py +212 -0
- rfd3/engine.py +543 -0
- rfd3/inference/datasets.py +193 -0
- rfd3/inference/input_parsing.py +1123 -0
- rfd3/inference/legacy_input_parsing.py +717 -0
- rfd3/inference/parsing.py +165 -0
- rfd3/inference/symmetry/atom_array.py +298 -0
- rfd3/inference/symmetry/checks.py +241 -0
- rfd3/inference/symmetry/contigs.py +63 -0
- rfd3/inference/symmetry/frames.py +355 -0
- rfd3/inference/symmetry/symmetry_utils.py +398 -0
- rfd3/metrics/design_metrics.py +465 -0
- rfd3/metrics/hbonds_hbplus_metrics.py +308 -0
- rfd3/metrics/hbonds_metrics.py +389 -0
- rfd3/metrics/losses.py +325 -0
- rfd3/metrics/metrics_utils.py +118 -0
- rfd3/metrics/sidechain_metrics.py +349 -0
- rfd3/model/RFD3.py +105 -0
- rfd3/model/RFD3_diffusion_module.py +387 -0
- rfd3/model/cfg_utils.py +81 -0
- rfd3/model/inference_sampler.py +635 -0
- rfd3/model/layers/attention.py +577 -0
- rfd3/model/layers/block_utils.py +580 -0
- rfd3/model/layers/blocks.py +777 -0
- rfd3/model/layers/chunked_pairwise.py +377 -0
- rfd3/model/layers/encoders.py +417 -0
- rfd3/model/layers/layer_utils.py +197 -0
- rfd3/model/layers/pairformer_layers.py +128 -0
- rfd3/run_inference.py +45 -0
- rfd3/testing/debug.py +139 -0
- rfd3/testing/debug_utils.py +73 -0
- rfd3/testing/testing_utils.py +356 -0
- rfd3/train.py +194 -0
- rfd3/trainer/dump_validation_structures.py +154 -0
- rfd3/trainer/fabric_trainer.py +923 -0
- rfd3/trainer/recycling.py +42 -0
- rfd3/trainer/rfd3.py +485 -0
- rfd3/trainer/trainer_utils.py +502 -0
- rfd3/transforms/conditioning_base.py +508 -0
- rfd3/transforms/conditioning_utils.py +200 -0
- rfd3/transforms/design_transforms.py +807 -0
- rfd3/transforms/dna_crop.py +523 -0
- rfd3/transforms/hbonds.py +407 -0
- rfd3/transforms/hbonds_hbplus.py +246 -0
- rfd3/transforms/ncaa_transforms.py +153 -0
- rfd3/transforms/pipelines.py +632 -0
- rfd3/transforms/ppi_transforms.py +541 -0
- rfd3/transforms/rasa.py +116 -0
- rfd3/transforms/symmetry.py +76 -0
- rfd3/transforms/training_conditions.py +552 -0
- rfd3/transforms/util_transforms.py +498 -0
- rfd3/transforms/virtual_atoms.py +305 -0
- rfd3/utils/inference.py +648 -0
- rfd3/utils/io.py +245 -0
- rfd3/utils/vizualize.py +276 -0
|
@@ -0,0 +1,577 @@
|
|
|
1
|
+
import math
|
|
2
|
+
from math import sqrt
|
|
3
|
+
|
|
4
|
+
import torch
|
|
5
|
+
import torch.nn as nn
|
|
6
|
+
import torch.nn.functional as F
|
|
7
|
+
from einops import rearrange
|
|
8
|
+
from opt_einsum import contract as einsum
|
|
9
|
+
from rfd3.model.layers.block_utils import (
|
|
10
|
+
create_attention_indices,
|
|
11
|
+
indices_to_mask,
|
|
12
|
+
)
|
|
13
|
+
from rfd3.model.layers.layer_utils import (
|
|
14
|
+
AdaLN,
|
|
15
|
+
LinearBiasInit,
|
|
16
|
+
RMSNorm,
|
|
17
|
+
linearNoBias,
|
|
18
|
+
)
|
|
19
|
+
|
|
20
|
+
from foundry.common import exists
|
|
21
|
+
from foundry.training.checkpoint import activation_checkpointing
|
|
22
|
+
from foundry.utils.ddp import RankedLogger
|
|
23
|
+
|
|
24
|
+
ranked_logger = RankedLogger(__name__, rank_zero_only=True)
|
|
25
|
+
|
|
26
|
+
try:
|
|
27
|
+
from cuequivariance_torch import attention_pair_bias as cueq_attention_pair_bias
|
|
28
|
+
|
|
29
|
+
# ranked_logger.info("Fused PairBiasAttention enabled!")
|
|
30
|
+
_CUEQ_AVAILABLE = True
|
|
31
|
+
except Exception:
|
|
32
|
+
# ranked_logger.warning(
|
|
33
|
+
# "Using pytorch implementation instead of NVIDIA kernel"
|
|
34
|
+
# "Ensure you are using the latest apptainer."
|
|
35
|
+
# )
|
|
36
|
+
_CUEQ_AVAILABLE = False
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
@torch.compiler.disable
|
|
40
|
+
def kernel_pairbias_attention(
|
|
41
|
+
*,
|
|
42
|
+
s: torch.Tensor, # (B, U, D) sequence features used for gating/output inside the kernel
|
|
43
|
+
q: torch.Tensor, # (B, H, U, DH)
|
|
44
|
+
k: torch.Tensor, # (B, H, V, DH)
|
|
45
|
+
v: torch.Tensor, # (B, H, V, DH)
|
|
46
|
+
z: torch.Tensor, # (B, U, V, z_dim)
|
|
47
|
+
mask: torch.Tensor | None, # (B, V) or (B*M, V) with 1=keep, 0=mask
|
|
48
|
+
num_heads: int,
|
|
49
|
+
w_proj_z: torch.Tensor, # (H, z_dim)
|
|
50
|
+
w_proj_g: torch.Tensor, # (D, D)
|
|
51
|
+
w_proj_o: torch.Tensor, # (D, D)
|
|
52
|
+
w_ln_z: torch.Tensor, # (z_dim,)
|
|
53
|
+
b_ln_z: torch.Tensor, # (z_dim,)
|
|
54
|
+
b_proj_z: torch.Tensor | None = None, # (H,)
|
|
55
|
+
b_proj_g: torch.Tensor | None = None, # (D,)
|
|
56
|
+
b_proj_o: torch.Tensor | None = None, # (D,)
|
|
57
|
+
attn_scale: float | None = None,
|
|
58
|
+
compute_pair_bias: bool = True,
|
|
59
|
+
multiplicity: int = 1,
|
|
60
|
+
) -> torch.Tensor:
|
|
61
|
+
"""Thin wrapper around cuequivariance_torch.attention_pair_bias."""
|
|
62
|
+
raise NotImplementedError("CUDA Kernel for attention pair bias not implemented")
|
|
63
|
+
out, _proj_z = cueq_attention_pair_bias(
|
|
64
|
+
s=s,
|
|
65
|
+
q=q,
|
|
66
|
+
k=k,
|
|
67
|
+
v=v,
|
|
68
|
+
z=z,
|
|
69
|
+
mask=mask,
|
|
70
|
+
num_heads=num_heads,
|
|
71
|
+
w_proj_z=w_proj_z,
|
|
72
|
+
w_proj_g=w_proj_g,
|
|
73
|
+
w_proj_o=w_proj_o,
|
|
74
|
+
w_ln_z=w_ln_z,
|
|
75
|
+
b_ln_z=b_ln_z,
|
|
76
|
+
b_proj_z=b_proj_z,
|
|
77
|
+
b_proj_g=b_proj_g,
|
|
78
|
+
b_proj_o=b_proj_o,
|
|
79
|
+
attn_scale=attn_scale,
|
|
80
|
+
compute_pair_bias=compute_pair_bias,
|
|
81
|
+
multiplicity=multiplicity,
|
|
82
|
+
)
|
|
83
|
+
return out # (B, U, D)
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
######################################################################################
|
|
87
|
+
########################## Network Modules ##########################
|
|
88
|
+
######################################################################################
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
class GatedCrossAttention(nn.Module):
|
|
92
|
+
def __init__(
|
|
93
|
+
self,
|
|
94
|
+
c_query,
|
|
95
|
+
c_kv,
|
|
96
|
+
c_pair=None,
|
|
97
|
+
c_model=128,
|
|
98
|
+
n_head=4,
|
|
99
|
+
kq_norm=True,
|
|
100
|
+
dropout=0.0,
|
|
101
|
+
**_,
|
|
102
|
+
):
|
|
103
|
+
super().__init__()
|
|
104
|
+
self.n_head = n_head
|
|
105
|
+
self.scale = 1 / math.sqrt(c_model // n_head)
|
|
106
|
+
assert c_model % n_head == 0, "c_model must be divisible by n_heads"
|
|
107
|
+
|
|
108
|
+
self.ln_q = RMSNorm(c_query)
|
|
109
|
+
self.ln_kv = RMSNorm(c_kv)
|
|
110
|
+
|
|
111
|
+
self.to_q = linearNoBias(c_query, c_model)
|
|
112
|
+
self.to_k = linearNoBias(c_kv, c_model)
|
|
113
|
+
self.to_v = linearNoBias(c_kv, c_model)
|
|
114
|
+
self.to_g = nn.Sequential(
|
|
115
|
+
linearNoBias(c_query, c_model),
|
|
116
|
+
nn.Sigmoid(),
|
|
117
|
+
)
|
|
118
|
+
self.to_out = nn.Sequential(nn.Linear(c_model, c_query), nn.Dropout(dropout))
|
|
119
|
+
self.kq_norm = kq_norm
|
|
120
|
+
if self.kq_norm:
|
|
121
|
+
self.k_norm = RMSNorm(c_model)
|
|
122
|
+
self.q_norm = RMSNorm(c_model)
|
|
123
|
+
|
|
124
|
+
self.c_pair = c_pair
|
|
125
|
+
if c_pair is not None:
|
|
126
|
+
self.to_b = nn.Sequential(RMSNorm(c_pair), linearNoBias(c_pair, n_head))
|
|
127
|
+
self.reset_parameter()
|
|
128
|
+
|
|
129
|
+
def reset_parameter(self):
|
|
130
|
+
# query/key/value projection: Xavier uniform
|
|
131
|
+
nn.init.xavier_uniform_(self.to_q.weight)
|
|
132
|
+
nn.init.xavier_uniform_(self.to_k.weight)
|
|
133
|
+
nn.init.xavier_uniform_(self.to_v.weight)
|
|
134
|
+
nn.init.xavier_uniform_(self.to_g[0].weight)
|
|
135
|
+
nn.init.xavier_uniform_(self.to_out[0].weight)
|
|
136
|
+
|
|
137
|
+
def forward(self, q, kv, attn_mask=None, pair_bias=None):
|
|
138
|
+
"""
|
|
139
|
+
Args:
|
|
140
|
+
q: [B, tok, n_q, c_query]
|
|
141
|
+
kv: [B, tok, n_kv, c_kv]
|
|
142
|
+
attn_mask: [n_q, n_kv]
|
|
143
|
+
Returns:
|
|
144
|
+
attn_out: [B, tok, n_q, c_query]
|
|
145
|
+
"""
|
|
146
|
+
|
|
147
|
+
q = self.ln_q(q)
|
|
148
|
+
kv = self.ln_kv(kv)
|
|
149
|
+
|
|
150
|
+
q, k, v, g = self.to_q(q), self.to_k(kv), self.to_v(kv), self.to_g(q)
|
|
151
|
+
|
|
152
|
+
if self.kq_norm:
|
|
153
|
+
k = self.k_norm(k)
|
|
154
|
+
q = self.q_norm(q)
|
|
155
|
+
|
|
156
|
+
q, k, v, g = map(
|
|
157
|
+
lambda t: rearrange(t, "b t n (h c) -> b h t n c", h=self.n_head),
|
|
158
|
+
(q, k, v, g),
|
|
159
|
+
) # [B, tok, n, heads, c] -> [B, heads, tok, n, c]
|
|
160
|
+
|
|
161
|
+
attn = einsum("bhtqc,bhtkc->bhtqk", q, k) * self.scale
|
|
162
|
+
|
|
163
|
+
if pair_bias is not None:
|
|
164
|
+
b = self.to_b(pair_bias)
|
|
165
|
+
b = rearrange(b, "b t q k (h) -> b (h) t q k", h=self.n_head)
|
|
166
|
+
attn = attn + b
|
|
167
|
+
|
|
168
|
+
# Invalid query handling:
|
|
169
|
+
if attn_mask is not None:
|
|
170
|
+
attn = attn.masked_fill(~attn_mask[None, None], float("-inf"))
|
|
171
|
+
|
|
172
|
+
# Bugfix: Empty queries need to have a constant value otherwise nans are in the forward graph. I don't
|
|
173
|
+
# know why this causes instabilities because the invalid queries are masked out later. Oh well!
|
|
174
|
+
invalid_queries = torch.logical_not(
|
|
175
|
+
torch.any(attn_mask, dim=-1, keepdim=False)
|
|
176
|
+
) # [n_q,]
|
|
177
|
+
attn[:, :, invalid_queries, :] = 0.0
|
|
178
|
+
|
|
179
|
+
attn = F.softmax(attn, dim=-1)
|
|
180
|
+
attn_out = einsum("bhtqk,bhtkd->bhtqd", attn, v)
|
|
181
|
+
attn_out = attn_out * g
|
|
182
|
+
|
|
183
|
+
attn_out = rearrange(attn_out, "b h t n c -> b t n (h c)")
|
|
184
|
+
attn_out = self.to_out(attn_out) # [B, n_tok, n_k, c]
|
|
185
|
+
return attn_out
|
|
186
|
+
|
|
187
|
+
|
|
188
|
+
class LocalAttentionPairBias(nn.Module):
|
|
189
|
+
def __init__(
|
|
190
|
+
self,
|
|
191
|
+
c_a,
|
|
192
|
+
c_s,
|
|
193
|
+
c_pair,
|
|
194
|
+
n_head,
|
|
195
|
+
kq_norm=True,
|
|
196
|
+
n_attn_seq_neighbours=2,
|
|
197
|
+
n_attn_keys=128,
|
|
198
|
+
):
|
|
199
|
+
super().__init__()
|
|
200
|
+
self.c = c_a # d_model dim same as input features
|
|
201
|
+
self.n_head = n_head
|
|
202
|
+
|
|
203
|
+
self.to_q = linearNoBias(c_a, self.c)
|
|
204
|
+
self.to_k = linearNoBias(c_a, self.c)
|
|
205
|
+
self.to_v = linearNoBias(c_a, self.c)
|
|
206
|
+
self.to_b = linearNoBias(c_pair, self.n_head)
|
|
207
|
+
self.to_g = nn.Sequential(
|
|
208
|
+
linearNoBias(c_a, self.c, bias=False),
|
|
209
|
+
nn.Sigmoid(),
|
|
210
|
+
)
|
|
211
|
+
self.kq_norm = kq_norm
|
|
212
|
+
if kq_norm:
|
|
213
|
+
self.ln_q = RMSNorm(self.c)
|
|
214
|
+
self.ln_k = RMSNorm(self.c)
|
|
215
|
+
|
|
216
|
+
# Output / Input projections
|
|
217
|
+
self.to_o = linearNoBias(self.c, c_a) # from attn to Q_L
|
|
218
|
+
|
|
219
|
+
# Conditioned
|
|
220
|
+
if exists(c_s):
|
|
221
|
+
self.ada_ln_1 = AdaLN(c_a=c_a, c_s=c_s)
|
|
222
|
+
self.linear_output_project = nn.Sequential(
|
|
223
|
+
LinearBiasInit(c_s, c_a, biasinit=-2.0),
|
|
224
|
+
nn.Sigmoid(),
|
|
225
|
+
)
|
|
226
|
+
else:
|
|
227
|
+
self.ln_1 = RMSNorm(c_a)
|
|
228
|
+
|
|
229
|
+
# Used if no indices are provided
|
|
230
|
+
self.n_attn_seq_neighbours = n_attn_seq_neighbours
|
|
231
|
+
self.n_attn_keys = n_attn_keys
|
|
232
|
+
self.use_checkpointing = True
|
|
233
|
+
|
|
234
|
+
def forward(
|
|
235
|
+
self,
|
|
236
|
+
Q_L,
|
|
237
|
+
C_L,
|
|
238
|
+
P_LL,
|
|
239
|
+
indices=None,
|
|
240
|
+
f=None,
|
|
241
|
+
X_L=None,
|
|
242
|
+
full=False,
|
|
243
|
+
chunked_pairwise_embedder=None,
|
|
244
|
+
initializer_outputs=None,
|
|
245
|
+
):
|
|
246
|
+
"""
|
|
247
|
+
Q_L: [D, L, c_a]
|
|
248
|
+
C_L: [D, L, c_s]
|
|
249
|
+
P_LL: [D, L, L, c_pair] or None (if using chunked mode)
|
|
250
|
+
indices: [D, L, k] long
|
|
251
|
+
chunked_pairwise_embedder: ChunkedPairwiseEmbedder for memory efficient computation
|
|
252
|
+
initializer_outputs: Dict containing features for chunked computation
|
|
253
|
+
"""
|
|
254
|
+
|
|
255
|
+
# If no indices are provided, prepare indices from
|
|
256
|
+
if not exists(indices):
|
|
257
|
+
indices = create_attention_indices(
|
|
258
|
+
f,
|
|
259
|
+
n_attn_keys=self.n_attn_keys,
|
|
260
|
+
n_attn_seq_neighbours=self.n_attn_seq_neighbours,
|
|
261
|
+
X_L=X_L,
|
|
262
|
+
)
|
|
263
|
+
|
|
264
|
+
# Handle chunked P_LL computation
|
|
265
|
+
if chunked_pairwise_embedder is not None and P_LL is None:
|
|
266
|
+
# Compute sparse P_LL only for the attention indices
|
|
267
|
+
P_LL_sparse = chunked_pairwise_embedder.forward_chunked(
|
|
268
|
+
f=f,
|
|
269
|
+
indices=indices,
|
|
270
|
+
C_L=initializer_outputs["C_L"],
|
|
271
|
+
Z_init_II=initializer_outputs["Z_II"],
|
|
272
|
+
tok_idx=f["atom_to_token_map"],
|
|
273
|
+
)
|
|
274
|
+
# P_LL_sparse is already in sparse format [D, L, k, c_pair]
|
|
275
|
+
use_sparse_pll = True
|
|
276
|
+
else:
|
|
277
|
+
# Original full P_LL computation
|
|
278
|
+
P_LL_sparse = None
|
|
279
|
+
use_sparse_pll = False
|
|
280
|
+
|
|
281
|
+
use_kernel = False
|
|
282
|
+
|
|
283
|
+
def do_attention(Q_L, C_L, P_LL):
|
|
284
|
+
if exists(C_L):
|
|
285
|
+
Q_L = self.ada_ln_1(Q_L, C_L)
|
|
286
|
+
else:
|
|
287
|
+
Q_L = self.ln_1(Q_L)
|
|
288
|
+
|
|
289
|
+
if use_kernel and not use_sparse_pll:
|
|
290
|
+
# TODO: Update with latest kernel
|
|
291
|
+
q, k, v, g, b = (
|
|
292
|
+
self.to_q(Q_L),
|
|
293
|
+
self.to_k(Q_L),
|
|
294
|
+
self.to_v(Q_L),
|
|
295
|
+
self.to_g(Q_L),
|
|
296
|
+
self.to_b(P_LL),
|
|
297
|
+
)
|
|
298
|
+
q, k = (self.ln_q(q), self.ln_k(k)) if self.kq_norm else (q, k)
|
|
299
|
+
attn_out = _fused_full_pairbias_attention(
|
|
300
|
+
Q_L=q, # already projected queries (B, L, c)
|
|
301
|
+
K_L=k,
|
|
302
|
+
V_L=v,
|
|
303
|
+
P_LL=P_LL, # pair features (B, L, L, c_pair)
|
|
304
|
+
num_heads=self.n_head,
|
|
305
|
+
to_b=None, # pair-bias projector (H, c_pair)
|
|
306
|
+
to_g_linear=None, # gating linear (D, D)
|
|
307
|
+
to_o_linear=None, # output linear (D, D)
|
|
308
|
+
w_ln_z_identity=None,
|
|
309
|
+
b_ln_z_identity=None,
|
|
310
|
+
attn_scale=1.0 / math.sqrt(self.c // self.n_head),
|
|
311
|
+
)
|
|
312
|
+
else:
|
|
313
|
+
# Sparse attention path
|
|
314
|
+
q, k, v, g = (
|
|
315
|
+
self.to_q(Q_L),
|
|
316
|
+
self.to_k(Q_L),
|
|
317
|
+
self.to_v(Q_L),
|
|
318
|
+
self.to_g(Q_L),
|
|
319
|
+
)
|
|
320
|
+
q, k = (self.ln_q(q), self.ln_k(k)) if self.kq_norm else (q, k)
|
|
321
|
+
|
|
322
|
+
if use_sparse_pll:
|
|
323
|
+
# Use pre-computed sparse P_LL (already gathered)
|
|
324
|
+
b = self.to_b(P_LL_sparse) # [D, L, k, H]
|
|
325
|
+
attn_out = sparse_pairbias_attention(
|
|
326
|
+
Q=q,
|
|
327
|
+
K=k,
|
|
328
|
+
V=v,
|
|
329
|
+
B=b,
|
|
330
|
+
G=g,
|
|
331
|
+
gather_bias=False, # Already gathered!
|
|
332
|
+
indices=indices,
|
|
333
|
+
H=self.n_head,
|
|
334
|
+
full=full,
|
|
335
|
+
) # [D, L, c]
|
|
336
|
+
else:
|
|
337
|
+
# Original full P_LL path
|
|
338
|
+
b = self.to_b(P_LL)
|
|
339
|
+
attn_out = sparse_pairbias_attention(
|
|
340
|
+
Q=q,
|
|
341
|
+
K=k,
|
|
342
|
+
V=v,
|
|
343
|
+
B=b,
|
|
344
|
+
G=g,
|
|
345
|
+
gather_bias=True,
|
|
346
|
+
indices=indices,
|
|
347
|
+
H=self.n_head,
|
|
348
|
+
full=full,
|
|
349
|
+
) # [D, L, c]
|
|
350
|
+
|
|
351
|
+
# Output projection (from adaLN-Zero)
|
|
352
|
+
Q_L = self.to_o(attn_out)
|
|
353
|
+
if exists(C_L):
|
|
354
|
+
Q_L = self.linear_output_project(C_L) * Q_L
|
|
355
|
+
|
|
356
|
+
return Q_L
|
|
357
|
+
|
|
358
|
+
do_attention_ = (
|
|
359
|
+
activation_checkpointing(do_attention)
|
|
360
|
+
if self.use_checkpointing
|
|
361
|
+
else do_attention
|
|
362
|
+
)
|
|
363
|
+
|
|
364
|
+
# Call attention with appropriate P_LL
|
|
365
|
+
if use_sparse_pll:
|
|
366
|
+
return do_attention_(Q_L, C_L, P_LL_sparse)
|
|
367
|
+
else:
|
|
368
|
+
return do_attention_(Q_L, C_L, P_LL)
|
|
369
|
+
|
|
370
|
+
|
|
371
|
+
######################################################################################
|
|
372
|
+
########################## Kernel Functions ##########################
|
|
373
|
+
######################################################################################
|
|
374
|
+
|
|
375
|
+
|
|
376
|
+
def sparse_pairbias_attention(
|
|
377
|
+
Q, K, V, B, indices, H, gather_bias=True, G=None, full=False
|
|
378
|
+
):
|
|
379
|
+
"""
|
|
380
|
+
Computes attention with sparse pairwise bias, where indices specify which
|
|
381
|
+
keys to attend to for each query token.
|
|
382
|
+
Q: (D, L, c) # query vectors
|
|
383
|
+
K: (D, L, c) # key vectors
|
|
384
|
+
V: (D, L, c) # value vectors
|
|
385
|
+
B: (L, L, H) # attention bias (unbatched or pre-gathered and [D, L, k, H])
|
|
386
|
+
G: (D, L, c) # Gate (optional)
|
|
387
|
+
B2: (D, L, 14, 14, H) # attention bias (batched and within token) (optional)
|
|
388
|
+
indices: (D, L, k_neigh) long # indices of neighbours to attend to
|
|
389
|
+
Returns
|
|
390
|
+
-------
|
|
391
|
+
attn_out: (D, L, c) # attention output
|
|
392
|
+
"""
|
|
393
|
+
D, L, c = Q.shape
|
|
394
|
+
k = indices.shape[-1] # k_neigh
|
|
395
|
+
|
|
396
|
+
if full:
|
|
397
|
+
# During training, compute full attention matrix to create a more optimized torch.tensor graph.
|
|
398
|
+
return pairbias_attention_(
|
|
399
|
+
Q=Q,
|
|
400
|
+
K=K,
|
|
401
|
+
V=V,
|
|
402
|
+
B=B,
|
|
403
|
+
H=H,
|
|
404
|
+
valid_mask=indices_to_mask(indices),
|
|
405
|
+
G=G,
|
|
406
|
+
)
|
|
407
|
+
|
|
408
|
+
# Pull vectors from dimension 1 into index torch.tensor according to unique k_neigh axis
|
|
409
|
+
batch_idx = torch.arange(D, device=Q.device).view(-1, 1, 1) # (D,1,1)
|
|
410
|
+
K_gathered = K[batch_idx, indices].contiguous() # (D, L, k, c)
|
|
411
|
+
V_gathered = V[batch_idx, indices].contiguous() # (D, L, k, c)
|
|
412
|
+
|
|
413
|
+
# Gather bias or assume pre-gathered
|
|
414
|
+
if gather_bias:
|
|
415
|
+
query_idx = torch.arange(L, device=Q.device).view(1, L, 1) # (1,L,1)
|
|
416
|
+
query_idx = query_idx.expand(D, -1, k)
|
|
417
|
+
if B.ndim == 3:
|
|
418
|
+
B_gathered = B[query_idx, indices, :] # (D, L, k, H)
|
|
419
|
+
elif B.ndim == 4: # (D, L, L, H)
|
|
420
|
+
B_gathered = B[batch_idx, query_idx, indices, :] # (D, L, k, H)
|
|
421
|
+
else:
|
|
422
|
+
assert B.shape == (D, L, k, H), "B must be batched with shape (D, L, k, H)"
|
|
423
|
+
B_gathered = B
|
|
424
|
+
B_gathered = B_gathered.contiguous()
|
|
425
|
+
|
|
426
|
+
# Split into heads
|
|
427
|
+
Q = Q.reshape(D, L, H, c // H)
|
|
428
|
+
K_gathered = K_gathered.reshape(D, L, k, H, c // H)
|
|
429
|
+
V_gathered = V_gathered.reshape(D, L, k, H, c // H)
|
|
430
|
+
B_gathered = B_gathered.reshape(D, L, k, H)
|
|
431
|
+
Q = Q.permute(0, 2, 1, 3) # [D, H, L, c // H]
|
|
432
|
+
K_gathered = K_gathered.permute(0, 3, 1, 2, 4)
|
|
433
|
+
V_gathered = V_gathered.permute(0, 3, 1, 2, 4)
|
|
434
|
+
B_gathered = B_gathered.permute(0, 3, 1, 2)
|
|
435
|
+
|
|
436
|
+
# Do attention
|
|
437
|
+
attn = torch.einsum("...ld,...lkd->...lk", Q, K_gathered)
|
|
438
|
+
attn = attn / sqrt(c // H) # scale
|
|
439
|
+
attn = attn + B_gathered # add bias
|
|
440
|
+
attn = torch.softmax(attn, dim=-1) # softmax over keys [D, H, L, k]
|
|
441
|
+
attn_out = torch.einsum(
|
|
442
|
+
"...ij,...ijc->...ic", attn, V_gathered
|
|
443
|
+
) # allocates a max of 4.95 GiB.
|
|
444
|
+
|
|
445
|
+
# Optional gating
|
|
446
|
+
if G is not None:
|
|
447
|
+
G = G.reshape(D, L, H, c // H).permute(0, 2, 1, 3)
|
|
448
|
+
attn_out = attn_out * G
|
|
449
|
+
|
|
450
|
+
# Merge heads
|
|
451
|
+
attn_out = attn_out.permute(0, 2, 1, 3)
|
|
452
|
+
attn_out = attn_out.reshape(D, L, c).contiguous()
|
|
453
|
+
|
|
454
|
+
return attn_out # [D, L, c]
|
|
455
|
+
|
|
456
|
+
|
|
457
|
+
def pairbias_attention_(Q, K, V, B, H, valid_mask=None, G=None):
|
|
458
|
+
"""
|
|
459
|
+
Fully connected variant of pairbias attention with optional gating and valid mask.
|
|
460
|
+
Equivalent to sparse attention but with all keys
|
|
461
|
+
|
|
462
|
+
Attn_out: [batch_size, query_length, H * head_dim]
|
|
463
|
+
"""
|
|
464
|
+
D, L, c = Q.shape
|
|
465
|
+
k = L
|
|
466
|
+
|
|
467
|
+
# Split into heads
|
|
468
|
+
Q = Q.reshape(D, L, H, c // H)
|
|
469
|
+
K = K.reshape(D, k, H, c // H)
|
|
470
|
+
V = V.reshape(D, k, H, c // H)
|
|
471
|
+
B = B.reshape(D, L, k, H)
|
|
472
|
+
|
|
473
|
+
# Flip heads upwards [..., H, d_model] -> [B, H, ..., d_model]
|
|
474
|
+
Q = Q.permute(0, 2, 1, 3) # [D, H, L, c // H]
|
|
475
|
+
K = K.permute(0, 2, 1, 3)
|
|
476
|
+
V = V.permute(0, 2, 1, 3)
|
|
477
|
+
B = B.permute(0, 3, 1, 2)
|
|
478
|
+
|
|
479
|
+
# Do attention
|
|
480
|
+
attn = torch.einsum("...ld,...kd->...lk", Q, K)
|
|
481
|
+
attn = attn / sqrt(c // H) # scale
|
|
482
|
+
attn = attn + B # add bias
|
|
483
|
+
if exists(valid_mask):
|
|
484
|
+
# expand valid mask over heads [D, H, L, L]
|
|
485
|
+
attn = attn.masked_fill(~valid_mask.unsqueeze(1), float("-inf"))
|
|
486
|
+
attn = torch.softmax(attn, dim=-1) # softmax over keys [D, H, L, k]
|
|
487
|
+
attn_out = torch.einsum("...ij,...jc->...ic", attn, V)
|
|
488
|
+
|
|
489
|
+
# Optional gating
|
|
490
|
+
if G is not None:
|
|
491
|
+
G = G.reshape(D, L, H, c // H).permute(0, 2, 1, 3)
|
|
492
|
+
attn_out = attn_out * G
|
|
493
|
+
|
|
494
|
+
# Merge heads
|
|
495
|
+
attn_out = attn_out.permute(0, 2, 1, 3)
|
|
496
|
+
attn_out = attn_out.reshape(D, L, c).contiguous()
|
|
497
|
+
|
|
498
|
+
return attn_out
|
|
499
|
+
|
|
500
|
+
|
|
501
|
+
def _fused_full_pairbias_attention(
|
|
502
|
+
*,
|
|
503
|
+
Q_L, # (B, L, c) -- sequence features used to make q,k,v and for gating
|
|
504
|
+
K_L, # (B, L, c)
|
|
505
|
+
V_L, # (B, L, c)
|
|
506
|
+
P_LL, # (B, L, L, c_pair)
|
|
507
|
+
num_heads: int,
|
|
508
|
+
to_b: nn.Linear, # projects pair features -> heads (H)
|
|
509
|
+
to_g_linear: nn.Linear, # weight (D, D), bias optional/None (pre-sigmoid, kernel handles gate)
|
|
510
|
+
to_o_linear: nn.Linear, # weight (D, D), bias optional/None (kernel handles output proj)
|
|
511
|
+
w_ln_z_identity: torch.torch.Tensor, # (c_pair,)
|
|
512
|
+
b_ln_z_identity: torch.torch.Tensor, # (c_pair,)
|
|
513
|
+
attn_scale: float | None = None,
|
|
514
|
+
):
|
|
515
|
+
"""
|
|
516
|
+
Uses cuequivariance_torch.attention_pair_bias for dense (full) attention.
|
|
517
|
+
Expects Q/K/V to be projected *before* calling this function.
|
|
518
|
+
"""
|
|
519
|
+
B, L, c = Q_L.shape
|
|
520
|
+
H = num_heads
|
|
521
|
+
assert c % H == 0, "Model dim must be divisible by num_heads"
|
|
522
|
+
DH = c // H
|
|
523
|
+
|
|
524
|
+
# q, k, v as (B, H, L, DH)
|
|
525
|
+
q = Q_L.reshape(B, L, H, DH).permute(0, 2, 1, 3).contiguous()
|
|
526
|
+
k = K_L.reshape(B, L, H, DH).permute(0, 2, 1, 3).contiguous()
|
|
527
|
+
v = V_L.reshape(B, L, H, DH).permute(0, 2, 1, 3).contiguous()
|
|
528
|
+
|
|
529
|
+
# s is the sequence features for gating/output projections
|
|
530
|
+
s = Q_L.contiguous() # (B, L, c)
|
|
531
|
+
|
|
532
|
+
# mask: None (kernel supports key padding mask shape (B,V) or (B*M,V); we don't need it here)
|
|
533
|
+
mask = None
|
|
534
|
+
|
|
535
|
+
# weights/biases for kernel (shapes per doc):
|
|
536
|
+
# w_proj_z: (H, z_dim)
|
|
537
|
+
w_proj_z = to_b.weight # (H, c_pair)
|
|
538
|
+
b_proj_z = to_b.bias if hasattr(to_b, "bias") else None
|
|
539
|
+
|
|
540
|
+
# w_proj_g / o: (D, D)
|
|
541
|
+
w_proj_g = to_g_linear.weight # (D, D)
|
|
542
|
+
b_proj_g = to_g_linear.bias if hasattr(to_g_linear, "bias") else None
|
|
543
|
+
|
|
544
|
+
w_proj_o = to_o_linear.weight # (D, D)
|
|
545
|
+
b_proj_o = to_o_linear.bias if hasattr(to_o_linear, "bias") else None
|
|
546
|
+
|
|
547
|
+
# z-LN params
|
|
548
|
+
w_ln_z = w_ln_z_identity.to(dtype=P_LL.dtype, device=P_LL.device)
|
|
549
|
+
b_ln_z = b_ln_z_identity.to(dtype=P_LL.dtype, device=P_LL.device)
|
|
550
|
+
|
|
551
|
+
# optional scaling (match your manual path)
|
|
552
|
+
if attn_scale is None:
|
|
553
|
+
attn_scale = 1.0 / math.sqrt(DH)
|
|
554
|
+
|
|
555
|
+
# Call the fused kernel (B*M collapses to B here; multiplicity=1)
|
|
556
|
+
out, _proj_z = cueq_attention_pair_bias(
|
|
557
|
+
s=s,
|
|
558
|
+
q=q,
|
|
559
|
+
k=k,
|
|
560
|
+
v=v,
|
|
561
|
+
z=P_LL,
|
|
562
|
+
mask=mask,
|
|
563
|
+
num_heads=H,
|
|
564
|
+
w_proj_z=w_proj_z,
|
|
565
|
+
w_proj_g=w_proj_g,
|
|
566
|
+
w_proj_o=w_proj_o,
|
|
567
|
+
w_ln_z=w_ln_z,
|
|
568
|
+
b_ln_z=b_ln_z,
|
|
569
|
+
b_proj_z=b_proj_z,
|
|
570
|
+
b_proj_g=b_proj_g,
|
|
571
|
+
b_proj_o=b_proj_o,
|
|
572
|
+
attn_scale=attn_scale,
|
|
573
|
+
compute_pair_bias=True,
|
|
574
|
+
multiplicity=1,
|
|
575
|
+
)
|
|
576
|
+
# out: (B, L, c) already gated & projected
|
|
577
|
+
return out
|