rc-foundry 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- foundry/__init__.py +57 -0
- foundry/callbacks/__init__.py +5 -0
- foundry/callbacks/callback.py +116 -0
- foundry/callbacks/health_logging.py +419 -0
- foundry/callbacks/metrics_logging.py +211 -0
- foundry/callbacks/timing_logging.py +67 -0
- foundry/callbacks/train_logging.py +278 -0
- foundry/common.py +108 -0
- foundry/constants.py +28 -0
- foundry/hydra/resolvers.py +77 -0
- foundry/inference_engines/base.py +235 -0
- foundry/inference_engines/checkpoint_registry.py +66 -0
- foundry/metrics/__init__.py +12 -0
- foundry/metrics/losses.py +30 -0
- foundry/metrics/metric.py +319 -0
- foundry/model/layers/blocks.py +47 -0
- foundry/testing/__init__.py +6 -0
- foundry/testing/fixtures.py +19 -0
- foundry/testing/pytest_hooks.py +15 -0
- foundry/trainers/fabric.py +923 -0
- foundry/training/EMA.py +67 -0
- foundry/training/checkpoint.py +61 -0
- foundry/training/schedulers.py +91 -0
- foundry/utils/alignment.py +86 -0
- foundry/utils/components.py +415 -0
- foundry/utils/datasets.py +405 -0
- foundry/utils/ddp.py +103 -0
- foundry/utils/instantiators.py +72 -0
- foundry/utils/logging.py +279 -0
- foundry/utils/rigid.py +1460 -0
- foundry/utils/rotation_augmentation.py +65 -0
- foundry/utils/squashfs.py +172 -0
- foundry/utils/torch.py +317 -0
- foundry/utils/weights.py +271 -0
- foundry/version.py +34 -0
- foundry_cli/__init__.py +3 -0
- foundry_cli/download_checkpoints.py +281 -0
- mpnn/__init__.py +1 -0
- mpnn/collate/feature_collator.py +265 -0
- mpnn/inference.py +53 -0
- mpnn/inference_engines/mpnn.py +549 -0
- mpnn/loss/nll_loss.py +122 -0
- mpnn/metrics/nll.py +369 -0
- mpnn/metrics/sequence_recovery.py +440 -0
- mpnn/model/layers/graph_embeddings.py +2372 -0
- mpnn/model/layers/message_passing.py +332 -0
- mpnn/model/layers/position_wise_feed_forward.py +44 -0
- mpnn/model/layers/positional_encoding.py +98 -0
- mpnn/model/mpnn.py +2632 -0
- mpnn/pipelines/mpnn.py +162 -0
- mpnn/samplers/samplers.py +167 -0
- mpnn/train.py +341 -0
- mpnn/trainers/mpnn.py +193 -0
- mpnn/transforms/feature_aggregation/mpnn.py +184 -0
- mpnn/transforms/feature_aggregation/polymer_ligand_interface.py +76 -0
- mpnn/transforms/feature_aggregation/token_encodings.py +132 -0
- mpnn/transforms/feature_aggregation/user_settings.py +347 -0
- mpnn/transforms/polymer_ligand_interface.py +164 -0
- mpnn/utils/inference.py +2397 -0
- mpnn/utils/probability.py +37 -0
- mpnn/utils/weights.py +309 -0
- rc_foundry-0.1.1.dist-info/METADATA +239 -0
- rc_foundry-0.1.1.dist-info/RECORD +180 -0
- rc_foundry-0.1.1.dist-info/WHEEL +4 -0
- rc_foundry-0.1.1.dist-info/entry_points.txt +5 -0
- rc_foundry-0.1.1.dist-info/licenses/LICENSE.md +28 -0
- rf3/__init__.py +3 -0
- rf3/_version.py +33 -0
- rf3/alignment.py +79 -0
- rf3/callbacks/dump_validation_structures.py +101 -0
- rf3/callbacks/metrics_logging.py +324 -0
- rf3/chemical.py +1529 -0
- rf3/cli.py +77 -0
- rf3/data/cyclic_transform.py +78 -0
- rf3/data/extra_xforms.py +36 -0
- rf3/data/ground_truth_template.py +463 -0
- rf3/data/paired_msa.py +206 -0
- rf3/data/pipeline_utils.py +128 -0
- rf3/data/pipelines.py +558 -0
- rf3/diffusion_samplers/inference_sampler.py +222 -0
- rf3/inference.py +65 -0
- rf3/inference_engines/__init__.py +5 -0
- rf3/inference_engines/rf3.py +735 -0
- rf3/kinematics.py +354 -0
- rf3/loss/af3_confidence_loss.py +515 -0
- rf3/loss/af3_losses.py +655 -0
- rf3/loss/loss.py +179 -0
- rf3/metrics/chiral.py +179 -0
- rf3/metrics/clashing_chains.py +68 -0
- rf3/metrics/distogram.py +421 -0
- rf3/metrics/lddt.py +523 -0
- rf3/metrics/metadata.py +43 -0
- rf3/metrics/metric_utils.py +192 -0
- rf3/metrics/predicted_error.py +134 -0
- rf3/metrics/rasa.py +108 -0
- rf3/metrics/selected_distances.py +91 -0
- rf3/model/RF3.py +527 -0
- rf3/model/RF3_blocks.py +92 -0
- rf3/model/RF3_structure.py +303 -0
- rf3/model/layers/af3_auxiliary_heads.py +255 -0
- rf3/model/layers/af3_diffusion_transformer.py +544 -0
- rf3/model/layers/attention.py +313 -0
- rf3/model/layers/layer_utils.py +127 -0
- rf3/model/layers/mlff.py +118 -0
- rf3/model/layers/outer_product.py +59 -0
- rf3/model/layers/pairformer_layers.py +783 -0
- rf3/model/layers/structure_bias.py +56 -0
- rf3/scoring.py +1787 -0
- rf3/symmetry/resolve.py +284 -0
- rf3/train.py +194 -0
- rf3/trainers/rf3.py +570 -0
- rf3/util_module.py +47 -0
- rf3/utils/frames.py +109 -0
- rf3/utils/inference.py +665 -0
- rf3/utils/io.py +198 -0
- rf3/utils/loss.py +72 -0
- rf3/utils/predict_and_score.py +165 -0
- rf3/utils/predicted_error.py +673 -0
- rf3/utils/recycling.py +42 -0
- rf3/validate.py +140 -0
- rfd3/.gitignore +7 -0
- rfd3/Makefile +76 -0
- rfd3/__init__.py +12 -0
- rfd3/callbacks.py +66 -0
- rfd3/cli.py +41 -0
- rfd3/constants.py +212 -0
- rfd3/engine.py +543 -0
- rfd3/inference/datasets.py +193 -0
- rfd3/inference/input_parsing.py +1123 -0
- rfd3/inference/legacy_input_parsing.py +717 -0
- rfd3/inference/parsing.py +165 -0
- rfd3/inference/symmetry/atom_array.py +298 -0
- rfd3/inference/symmetry/checks.py +241 -0
- rfd3/inference/symmetry/contigs.py +63 -0
- rfd3/inference/symmetry/frames.py +355 -0
- rfd3/inference/symmetry/symmetry_utils.py +398 -0
- rfd3/metrics/design_metrics.py +465 -0
- rfd3/metrics/hbonds_hbplus_metrics.py +308 -0
- rfd3/metrics/hbonds_metrics.py +389 -0
- rfd3/metrics/losses.py +325 -0
- rfd3/metrics/metrics_utils.py +118 -0
- rfd3/metrics/sidechain_metrics.py +349 -0
- rfd3/model/RFD3.py +105 -0
- rfd3/model/RFD3_diffusion_module.py +387 -0
- rfd3/model/cfg_utils.py +81 -0
- rfd3/model/inference_sampler.py +635 -0
- rfd3/model/layers/attention.py +577 -0
- rfd3/model/layers/block_utils.py +580 -0
- rfd3/model/layers/blocks.py +777 -0
- rfd3/model/layers/chunked_pairwise.py +377 -0
- rfd3/model/layers/encoders.py +417 -0
- rfd3/model/layers/layer_utils.py +197 -0
- rfd3/model/layers/pairformer_layers.py +128 -0
- rfd3/run_inference.py +45 -0
- rfd3/testing/debug.py +139 -0
- rfd3/testing/debug_utils.py +73 -0
- rfd3/testing/testing_utils.py +356 -0
- rfd3/train.py +194 -0
- rfd3/trainer/dump_validation_structures.py +154 -0
- rfd3/trainer/fabric_trainer.py +923 -0
- rfd3/trainer/recycling.py +42 -0
- rfd3/trainer/rfd3.py +485 -0
- rfd3/trainer/trainer_utils.py +502 -0
- rfd3/transforms/conditioning_base.py +508 -0
- rfd3/transforms/conditioning_utils.py +200 -0
- rfd3/transforms/design_transforms.py +807 -0
- rfd3/transforms/dna_crop.py +523 -0
- rfd3/transforms/hbonds.py +407 -0
- rfd3/transforms/hbonds_hbplus.py +246 -0
- rfd3/transforms/ncaa_transforms.py +153 -0
- rfd3/transforms/pipelines.py +632 -0
- rfd3/transforms/ppi_transforms.py +541 -0
- rfd3/transforms/rasa.py +116 -0
- rfd3/transforms/symmetry.py +76 -0
- rfd3/transforms/training_conditions.py +552 -0
- rfd3/transforms/util_transforms.py +498 -0
- rfd3/transforms/virtual_atoms.py +305 -0
- rfd3/utils/inference.py +648 -0
- rfd3/utils/io.py +245 -0
- rfd3/utils/vizualize.py +276 -0
|
@@ -0,0 +1,777 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
import math
|
|
3
|
+
|
|
4
|
+
import torch
|
|
5
|
+
import torch.nn as nn
|
|
6
|
+
import torch.nn.functional as F
|
|
7
|
+
from atomworks.ml.encoding_definitions import AF3SequenceEncoding
|
|
8
|
+
from einops import rearrange
|
|
9
|
+
from rfd3.model.layers.attention import (
|
|
10
|
+
GatedCrossAttention,
|
|
11
|
+
LocalAttentionPairBias,
|
|
12
|
+
)
|
|
13
|
+
from rfd3.model.layers.block_utils import (
|
|
14
|
+
build_valid_mask,
|
|
15
|
+
create_attention_indices,
|
|
16
|
+
group_atoms,
|
|
17
|
+
ungroup_atoms,
|
|
18
|
+
)
|
|
19
|
+
from rfd3.model.layers.layer_utils import (
|
|
20
|
+
AdaLN,
|
|
21
|
+
EmbeddingLayer,
|
|
22
|
+
LinearBiasInit,
|
|
23
|
+
RMSNorm,
|
|
24
|
+
Transition,
|
|
25
|
+
collapse,
|
|
26
|
+
linearNoBias,
|
|
27
|
+
)
|
|
28
|
+
from rfd3.model.layers.pairformer_layers import PairformerBlock
|
|
29
|
+
from torch.nn.functional import one_hot
|
|
30
|
+
|
|
31
|
+
from foundry import DISABLE_CHECKPOINTING
|
|
32
|
+
from foundry.common import exists
|
|
33
|
+
|
|
34
|
+
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
# SwiGLU transition block with adaptive layernorm
class ConditionedTransitionBlock(nn.Module):
    """SwiGLU feed-forward conditioned on a second stream through AdaLN.

    ``Ai`` is normalised with AdaLN (conditioned on ``Si``) and passed through a
    SwiGLU layer with hidden width ``c_token * n``. The result is gated by a
    sigmoid projection of ``Si`` (adaLN-Zero style output gate whose bias is
    initialised to -2.0).

    Args:
        c_token: Channel dimension of the input/output stream ``Ai``.
        c_s: Channel dimension of the conditioning stream ``Si``.
        n: Expansion factor of the hidden SwiGLU layer.
    """

    def __init__(self, c_token, c_s, n=2):
        super().__init__()
        hidden = c_token * n
        self.ada_ln = AdaLN(c_a=c_token, c_s=c_s)
        # Parallel input projections: SiLU-activated gate branch and value branch.
        self.linear_1 = linearNoBias(c_token, hidden)
        self.linear_2 = linearNoBias(c_token, hidden)
        self.linear_output_project = nn.Sequential(
            LinearBiasInit(c_s, c_token, biasinit=-2.0),
            nn.Sigmoid(),
        )
        self.linear_3 = linearNoBias(hidden, c_token)

    def forward(
        self,
        Ai,  # [B, I, C_token]
        Si,  # [B, I, C_s]
    ):
        normed = self.ada_ln(Ai, Si)
        # SwiGLU: SiLU(x W1) * (x W2)
        swiglu = F.silu(self.linear_1(normed)) * self.linear_2(normed)
        # Output gate computed from the conditioning stream (adaLN-Zero).
        gate = self.linear_output_project(Si)
        return gate * self.linear_3(swiglu)
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
class PositionPairDistEmbedder(nn.Module):
    """Embed pairwise offsets / distances between positions into pair features.

    Args:
        c_atompair: Output channel dimension of the pair embedding.
        embed_frame: If True, also embed the raw 3D offset vectors (the way AF3
            embeds reference positions). Otherwise only the clamped inverse
            squared distance and the validity mask are embedded.
    """

    def __init__(self, c_atompair, embed_frame=True):
        super().__init__()
        self.embed_frame = embed_frame
        if embed_frame:
            self.process_d = linearNoBias(3, c_atompair)

        self.process_inverse_dist = linearNoBias(1, c_atompair)
        self.process_valid_mask = linearNoBias(1, c_atompair)

    @staticmethod
    def _inverse_sq_dist(offsets):
        """Return ``1 / (1 + |offsets|^2)`` with a trailing singleton channel."""
        return 1 / (1 + torch.linalg.norm(offsets, dim=-1, keepdim=True) ** 2)

    def forward_af3(self, D_LL, V_LL):
        """Forward the same way reference positions are embedded in AF3.

        Args:
            D_LL: Pairwise offset vectors, last dim 3 (``[..., L, L, 3]``).
            V_LL: Boolean pair validity mask with a trailing singleton dim
                (``[..., L, L, 1]``).
        """
        P_LL = self.process_d(D_LL) * V_LL

        if not self.training:
            # Inference: evaluate the extra terms only on the valid pairs.
            P_LL[V_LL[..., 0]] += self.process_inverse_dist(
                self._inverse_sq_dist(D_LL[V_LL[..., 0]])
            )
            P_LL[V_LL[..., 0]] += self.process_valid_mask(
                V_LL[V_LL[..., 0]].to(P_LL.dtype)
            )
            return P_LL

        # Training: dense masked form of the same two terms.
        P_LL = P_LL + self.process_inverse_dist(self._inverse_sq_dist(D_LL)) * V_LL
        P_LL = P_LL + self.process_valid_mask(V_LL.to(P_LL.dtype)) * V_LL
        return P_LL

    def forward(self, ref_pos, valid_mask):
        D_LL = ref_pos.unsqueeze(-2) - ref_pos.unsqueeze(-3)
        V_LL = valid_mask

        if self.embed_frame:
            # Embed pairwise offsets (AF3 style)
            return self.forward_af3(D_LL, V_LL)

        # Frame-free variant: squared distance clamped away from zero, then inverted.
        sq_norm = torch.clamp(
            torch.linalg.norm(D_LL, dim=-1, keepdim=True) ** 2, min=1e-6
        )
        P_LL = self.process_inverse_dist(1 / (1 + sq_norm)) * V_LL
        return P_LL + self.process_valid_mask(V_LL.to(P_LL.dtype)) * V_LL
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
class OneDFeatureEmbedder(nn.Module):
    """
    Embeds 1D features into a single vector.

    Every retained feature gets its own ``EmbeddingLayer``, and the per-feature
    embeddings are summed. A feature is dropped when ``exists`` returns False
    for its channel count (presumably ``None`` — confirm against
    ``foundry.common.exists``).

    Args:
        features (dict): Dictionary of feature names and their number of channels.
        output_channels (int): Output dimension of the projected embedding.
    """

    def __init__(self, features, output_channels):
        super().__init__()
        self.features = {}
        for name, n_channels in features.items():
            if exists(n_channels):
                self.features[name] = n_channels

        # Every embedder is told the combined input width of all retained features.
        total_in = sum(self.features.values())
        self.embedders = nn.ModuleDict()
        for name, n_channels in self.features.items():
            self.embedders[name] = EmbeddingLayer(n_channels, total_in, output_channels)

    def forward(self, f, collapse_length):
        embeddings = [
            self.embedders[name](collapse(f[name].float(), collapse_length))
            for name, n_channels in self.features.items()
            if exists(n_channels)
        ]
        return sum(embeddings)
|
|
145
|
+
|
|
146
|
+
|
|
147
|
+
class SinusoidalDistEmbed(nn.Module):
    """
    Applies sinusoidal embedding to pairwise distances and projects to c_atompair.

    Args:
        c_atompair (int): Output dimension of the projected embedding (must be even).
        n_freqs (int): Number of sin/cos frequency pairs. The sinusoidal feature
            has ``2 * n_freqs`` channels before the linear projection.
    """

    def __init__(self, c_atompair, n_freqs=32):
        super().__init__()
        assert c_atompair % 2 == 0, "Output embedding dim must be even"

        # Number of sin/cos pairs -> total sinusoidal dim = 2 * n_freqs
        self.n_freqs = n_freqs
        self.c_atompair = c_atompair

        self.output_proj = linearNoBias(2 * n_freqs, c_atompair)
        self.process_valid_mask = linearNoBias(1, c_atompair)

    def forward(self, pos, valid_mask):
        """
        Args:
            pos: [L, 3] or [B, L, 3] atom positions
            valid_mask: [L, L, 1] or [B, L, L, 1] boolean mask
        Returns:
            P_LL: [L, L, c_atompair] or [B, L, L, c_atompair]
        """
        offsets = pos.unsqueeze(-2) - pos.unsqueeze(-3)  # [..., L, L, 3]
        dist = torch.linalg.norm(offsets, dim=-1)  # [..., L, L]

        # Geometric frequency ladder, built in float32 and moved to dist's device.
        n = self.n_freqs
        freq = torch.exp(
            -math.log(10000.0) * torch.arange(0, n, dtype=torch.float32) / n
        ).to(dist.device)  # [n_freqs]

        angles = dist.unsqueeze(-1) * freq  # [..., n_freqs]
        sincos = torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)

        # Project, zero invalid pairs, then add the embedded validity mask.
        P_LL = self.output_proj(sincos) * valid_mask
        return P_LL + self.process_valid_mask(valid_mask.to(P_LL.dtype)) * valid_mask
|
|
199
|
+
|
|
200
|
+
|
|
201
|
+
class LinearEmbedWithPool(nn.Module):
    """Linearly embed per-atom 3-vectors and mean-pool them onto their tokens.

    Args:
        c_token: Output channel dimension per token.
    """

    def __init__(self, c_token):
        super().__init__()
        self.c_token = c_token
        self.linear = linearNoBias(3, c_token)

    def forward(self, R_L, tok_idx):
        """
        Args:
            R_L: Per-atom inputs of shape [B, L, 3].
            tok_idx: 1D token index of every atom, shape [L].

        Returns:
            Tensor [B, I, c_token] with I = tok_idx.max() + 1. Each row is the mean
            of the embedded atoms of that token. Because ``include_self=False``,
            rows of tokens that receive no atoms keep their initial zeros.
        """
        n_batch = R_L.shape[0]
        n_tokens = int(tok_idx.max().item()) + 1
        Q_L = self.linear(R_L)

        pooled = torch.zeros(
            (n_batch, n_tokens, self.c_token), device=R_L.device, dtype=Q_L.dtype
        )
        pooled = pooled.index_reduce(
            -2, tok_idx.long(), Q_L, "mean", include_self=False
        )
        return pooled.clone()
|
|
228
|
+
|
|
229
|
+
|
|
230
|
+
class SimpleRecycler(nn.Module):
    """Recycle previous single/pair states into the initial embeddings.

    The previous ``S_I`` / ``Z_II`` are normalised, projected and added to the
    initial embeddings. The result is then refined by a stack of
    ``PairformerBlock`` layers. ``template_embedder`` and ``msa_module`` are
    accepted for config compatibility only; the template and MSA modules have
    been removed.
    """

    def __init__(
        self,
        c_s,
        c_z,
        template_embedder,
        msa_module,
        n_pairformer_blocks,
        pairformer_block,
    ):
        super().__init__()
        self.c_z = c_z
        self.process_zh = nn.Sequential(RMSNorm(c_z), linearNoBias(c_z, c_z))
        self.process_sh = nn.Sequential(RMSNorm(c_s), linearNoBias(c_s, c_s))
        self.pairformer_stack = nn.ModuleList(
            PairformerBlock(c_s=c_s, c_z=c_z, **pairformer_block)
            for _ in range(n_pairformer_blocks)
        )

    def forward(
        self,
        f,
        S_inputs_I,
        S_init_I,
        Z_init_II,
        S_I,
        Z_II,
    ):
        # ``f`` and ``S_inputs_I`` were consumed by the removed template/MSA
        # modules; they remain in the signature for caller compatibility.
        Z_II = Z_init_II + self.process_zh(Z_II)
        S_I = S_init_I + self.process_sh(S_I)
        for pairformer in self.pairformer_stack:
            S_I, Z_II = pairformer(S_I, Z_II)
        return S_I, Z_II
|
|
279
|
+
|
|
280
|
+
|
|
281
|
+
class RelativePositionEncodingWithIndexRemoval(nn.Module):
    """
    Usual RPE, but pairs flagged in ``f["unindexing_pair_mask"]`` (e.g. unindexed
    motif atoms) have their within-chain position spoofed.

    AF3-style relative position encoding: residue, token and chain offsets plus a
    same-entity flag. The residue and token offsets get one extra "unknown index"
    bin. Flagged pairs are moved into that bin, so their sequence separation is
    hidden from the model.

    Args:
        r_max: Clipping radius for residue and token offsets.
        s_max: Clipping radius for symmetry-id (chain) offsets.
        c_z: Output pair channel dimension.
    """

    def __init__(self, r_max, s_max, c_z):
        super().__init__()
        self.r_max = r_max
        self.s_max = s_max
        self.c_z = c_z

        # 2 * r_max + 1 clipped offsets + 1 "other chain/residue" bin (original
        # AF3) + 1 extra bin for unknown (unindexed) positions.
        self.num_tok_pos_bins = (2 * self.r_max + 2) + 1
        self.linear = linearNoBias(
            2 * self.num_tok_pos_bins + (2 * self.s_max + 2) + 1, c_z
        )

    def forward(self, f):
        """Compute the relative position pair embedding.

        Args:
            f: Feature dict with integer tensors ``asym_id``, ``entity_id``,
                ``residue_index``, ``token_index`` and ``sym_id`` (one entry per
                token), and a boolean ``unindexing_pair_mask`` over token pairs.

        Returns:
            Pair embedding with ``c_z`` channels in the last dimension.
        """
        unknown_bin = self.num_tok_pos_bins - 1  # == 2 * r_max + 2

        b_samechain_II = f["asym_id"].unsqueeze(-1) == f["asym_id"].unsqueeze(-2)
        b_same_entity_II = f["entity_id"].unsqueeze(-1) == f["entity_id"].unsqueeze(-2)
        b_sameresidue_II = (
            f["residue_index"].unsqueeze(-1) == f["residue_index"].unsqueeze(-2)
        )

        # Residue offset: clipped to [0, 2 * r_max] within a chain, else 2 * r_max + 1.
        d_residue_II = torch.where(
            b_samechain_II,
            torch.clip(
                f["residue_index"].unsqueeze(-1)
                - f["residue_index"].unsqueeze(-2)
                + self.r_max,
                0,
                2 * self.r_max,
            ),
            2 * self.r_max + 1,
        )

        # Token offset: only meaningful within the same residue of the same chain.
        d_token_II = torch.where(
            b_samechain_II & b_sameresidue_II,
            torch.clip(
                f["token_index"].unsqueeze(-1)
                - f["token_index"].unsqueeze(-2)
                + self.r_max,
                0,
                2 * self.r_max,
            ),
            2 * self.r_max + 1,
        )

        # Chain distances are kept
        d_chain_II = torch.where(
            # NOTE: Implementing bugfix from the Protenix Technical report, where we use `same_entity` instead of `not same_chain` (as in the AF-3 pseudocode)
            # Reference: https://github.com/bytedance/Protenix/blob/main/Protenix_Technical_Report.pdf
            b_same_entity_II,
            torch.clip(
                f["sym_id"].unsqueeze(-1) - f["sym_id"].unsqueeze(-2) + self.s_max,
                0,
                2 * self.s_max,
            ),
            2 * self.s_max + 1,
        )
        A_relchain_II = one_hot(d_chain_II.long(), 2 * self.s_max + 2)

        # Cancel out distances for unindexed motifs: the mask marks the pairs
        # whose sequence separation must not be exposed.
        unindexing_pair_mask = f["unindexing_pair_mask"]
        d_token_II[unindexing_pair_mask] = unknown_bin
        d_residue_II[unindexing_pair_mask] = unknown_bin

        # one_hot only accepts int64 indices. Cast both offsets; the token offset
        # used to be passed uncast, which fails for int32 index features.
        A_relpos_II = one_hot(d_residue_II.long(), self.num_tok_pos_bins)
        A_reltoken_II = one_hot(d_token_II.long(), self.num_tok_pos_bins)

        features = torch.cat(
            [
                A_relpos_II,
                A_reltoken_II,
                b_same_entity_II.unsqueeze(-1).to(A_relpos_II.dtype),
                A_relchain_II,
            ],
            dim=-1,
        )
        return self.linear(features.to(torch.float))
|
|
368
|
+
|
|
369
|
+
|
|
370
|
+
class VirtualPredictor(nn.Module):
    """Per-atom scalar head: RMSNorm followed by a bias-free projection to 1 channel.

    Args:
        c_atom: Channel dimension of the atom embeddings.
    """

    def __init__(self, c_atom):
        super().__init__()
        self.process_atom_embeddings = nn.Sequential(
            RMSNorm((c_atom,)),
            linearNoBias(c_atom, 1),
        )

    def forward(self, Q_L):
        """Map atom embeddings ``[..., c_atom]`` to one value per atom ``[..., 1]``."""
        head = self.process_atom_embeddings
        return head(Q_L)
|
|
379
|
+
|
|
380
|
+
|
|
381
|
+
class SequenceHead(nn.Module):
    """Sequence head fusing intra-token distograms with token embeddings.

    For each token with exactly 14 atoms, the 14x14 intra-token distance matrix
    is flattened (196 features) and fed through an MLP. Tokens with any other
    atom count contribute an all-zero distogram. A second MLP processes the
    token embeddings. The two 64-d outputs are concatenated and projected to 32
    logits.

    Args:
        c_token: Channel dimension of the token embeddings ``A_I``.
    """

    def __init__(self, c_token):
        super().__init__()

        # Distogram branch: 14 * 14 = 196 -> 128 -> 64
        self.dist_fc1 = nn.Linear(196, 128)
        self.dist_relu = nn.ReLU()
        self.dist_fc2 = nn.Linear(128, 64)

        # Embedding branch: c_token -> 128 -> 64
        self.embed_fc1 = nn.Linear(c_token, 128)
        self.embed_relu = nn.ReLU()
        self.embed_fc2 = nn.Linear(128, 64)

        # Fusion of the concatenated 64 + 64 features into 32 logits
        self.fusion_fc = nn.Linear(128, 32)

        # Sequence encoding
        self.sequence_encoding_ = AF3SequenceEncoding()

    def forward(self, A_I, Q_L, X_L, f):
        """
        Args:
            A_I: Token embeddings; the token axis must line up with the number of
                tokens in ``f["atom_to_token_map"]`` (presumably [B, I, c_token]).
            Q_L: Atom embeddings (not used).
            X_L: Atom coordinates, a 3D tensor (presumably [B, L, 3]).
            f: Feature dict; ``f["atom_to_token_map"]`` gives the token of each atom.

        Returns:
            ``(Seq_I, indices)``: 32-channel logits and their argmax over the last dim.
        """
        n_batch, _, _ = X_L.shape
        atom_to_token = f["atom_to_token_map"]
        n_tokens = atom_to_token.max().item() + 1

        # Gradients currently flow into X_L / A_I (no detach).
        distogram = torch.zeros(n_batch, n_tokens, 14, 14, device=X_L.device)
        for tok in range(n_tokens):
            in_token = atom_to_token == tok
            if in_token.sum() != 14:
                continue
            xyz = X_L[:, in_token]  # (B, 14, 3)
            distogram[:, tok] = torch.cdist(xyz, xyz)

        dist_feat = self.dist_fc2(
            self.dist_relu(self.dist_fc1(distogram.view(n_batch, n_tokens, 196)))
        )
        embed_feat = self.embed_fc2(self.embed_relu(self.embed_fc1(A_I)))

        Seq_I = self.fusion_fc(torch.cat([dist_feat, embed_feat], dim=-1))
        return Seq_I, self.decode(Seq_I)

    def decode(self, Seq_I):
        """Return the argmax class index along the last dimension."""
        return Seq_I.argmax(dim=-1)
|
|
441
|
+
|
|
442
|
+
|
|
443
|
+
class LinearSequenceHead(nn.Module):
    """Linear sequence head over the 32-token AF3 sequence vocabulary.

    Decoding excludes the tokens ``UNK``, ``X``, ``DX`` and ``<G>``.

    Args:
        c_token: Channel dimension of the token embeddings.
    """

    def __init__(self, c_token):
        super().__init__()
        n_tok_all = 32
        disallowed_idxs = AF3SequenceEncoding().encode(["UNK", "X", "DX", "<G>"])
        mask = torch.ones(n_tok_all, dtype=torch.bool)
        mask[disallowed_idxs] = False
        self.register_buffer("valid_out_mask", mask)
        self.linear = nn.Linear(c_token, n_tok_all)

    def forward(self, A_I, **_):
        logits = self.linear(A_I)
        indices = self.decode(logits)
        return logits, indices

    def decode(self, logits):
        """Return the most likely allowed token index at each position.

        Args:
            logits: Tensor of shape [..., 32]; any number of leading dims.

        Returns:
            Integer tensor of shape [...] with values in [0, 31]. Disallowed
            tokens are not selected.
        """
        # Mask in logit space rather than multiplying softmax probabilities.
        # When disallowed tokens dominate, the allowed probabilities can
        # underflow to exactly 0, and argmax would then return index 0
        # regardless of the logits. The argmax is otherwise identical,
        # since softmax and renormalisation preserve ordering.
        valid = self.valid_out_mask.to(logits.device)
        masked = logits.masked_fill(~valid, float("-inf"))
        return masked.argmax(dim=-1)
|
|
467
|
+
|
|
468
|
+
|
|
469
|
+
class Upcast(nn.Module):
    """Push token features back onto their atoms.

    Args:
        c_token: Token channel dimension.
        c_atom: Atom channel dimension.
        method: ``"broadcast"`` adds a projected copy of each token's features
            to all of its atom slots. ``"cross_attention"`` lets atom slots
            attend to ``n_split`` channel chunks of their token's features.
        cross_attention_block: Extra kwargs for ``GatedCrossAttention``
            (required for ``"cross_attention"``).
        n_split: Number of chunks the token channels are split into for
            cross-attention. ``c_token`` must be divisible by it.

    Raises:
        ValueError: If ``method`` is not recognised.
    """

    def __init__(
        self, c_token, c_atom, method="broadcast", cross_attention_block=None, n_split=6
    ):
        super().__init__()
        self.method = method
        self.n_split = n_split
        if self.method == "broadcast":
            self.project = nn.Sequential(
                RMSNorm((c_token,)), linearNoBias(c_token, c_atom)
            )
        elif self.method == "cross_attention":
            self.gca = GatedCrossAttention(
                c_query=c_atom, c_kv=c_token // self.n_split, **cross_attention_block
            )
        else:
            raise ValueError(f"Unknown upcast method: {self.method}")

    def forward_(self, Q_IA, A_I, valid_mask=None):
        """Update grouped atom features ``Q_IA`` (token axis 1, atom-slot axis 2)."""
        if self.method == "broadcast":
            Q_IA = Q_IA + self.project(A_I)[..., None, :]
        elif self.method == "cross_attention":
            assert exists(A_I) and exists(valid_mask)
            # Split token channels into n_split key/value entries per token.
            A_I = rearrange(A_I, "b n (s c) -> b n s c", s=self.n_split)
            n_tokens, n_atom_per_tok = Q_IA.shape[1], Q_IA.shape[2]

            # Attention mask [n_tokens, n_atom_per_tok, n_split]: every key is
            # masked out for atom slots where valid_mask is False.
            # (A previous mask built here was immediately overwritten and has
            # been removed.)
            attn_mask = torch.ones(
                (n_tokens, n_atom_per_tok, self.n_split),
                device=A_I.device,
                dtype=torch.bool,
            )
            attn_mask[~valid_mask, :] = False

            Q_IA = Q_IA + self.gca(q=Q_IA, kv=A_I, attn_mask=attn_mask)
        return Q_IA

    def forward(self, Q_L, A_I, tok_idx):
        valid_mask = build_valid_mask(tok_idx)
        Q_IA = ungroup_atoms(Q_L, valid_mask)
        Q_IA = self.forward_(Q_IA, A_I, valid_mask)
        return group_atoms(Q_IA, valid_mask)
|
|
516
|
+
|
|
517
|
+
|
|
518
|
+
class Downcast(nn.Module):
    """Downcast modules for when atoms are already reshaped from N_atoms -> (N_tokens, 14)

    Pools grouped per-atom features into per-token features. ``"mean"`` uses a
    projected mean over each token's valid atom slots. ``"cross_attention"``
    lets each token attend to its atom slots. If ``c_s`` is given, a projection
    of ``S_I`` is added to the result.

    Args:
        c_atom: Atom channel dimension.
        c_token: Token channel dimension.
        c_s: Optional channel dimension of the conditioning stream ``S_I``.
        method: ``"mean"`` or ``"cross_attention"``.
        cross_attention_block: Extra kwargs for ``GatedCrossAttention``.

    Raises:
        ValueError: If ``method`` is not recognised.
    """

    def __init__(
        self, c_atom, c_token, c_s=None, method="mean", cross_attention_block=None
    ):
        super().__init__()
        self.method = method
        self.c_token = c_token
        self.c_atom = c_atom
        if c_s is not None:
            self.process_s = nn.Sequential(
                RMSNorm((c_s,)),
                linearNoBias(c_s, c_token),
            )
        else:
            self.process_s = None

        if self.method == "mean":
            self.project = linearNoBias(c_atom, c_token)
        elif self.method == "cross_attention":
            self.gca = GatedCrossAttention(
                c_query=c_token,
                c_kv=c_atom,
                **cross_attention_block,
            )
        else:
            raise ValueError(f"Unknown downcast method: {self.method}")

    def forward_(self, Q_IA, A_I, S_I=None, valid_mask=None):
        if self.method == "mean":
            # Mean over the valid atom slots of each token. Invalid slots are
            # masked explicitly, and the count is clamped so that a token with
            # no valid atoms yields zeros instead of NaN (0 / 0).
            n_valid = valid_mask.sum(-1, keepdim=True).clamp(min=1)
            A_I_update = (self.project(Q_IA) * valid_mask[..., None]).sum(-2) / n_valid
        elif self.method == "cross_attention":
            assert exists(A_I) and exists(valid_mask)
            # Attention mask: ..., 1, n_atom_per_tok (1 querying token to atoms in token)
            attn_mask = valid_mask[..., None, :]
            A_I_update = self.gca(
                q=A_I[..., None, :], kv=Q_IA, attn_mask=attn_mask
            ).squeeze(-2)

        A_I = A_I + A_I_update if exists(A_I) else A_I_update

        if self.process_s is not None:
            A_I = A_I + self.process_s(S_I)
        return A_I

    def forward(self, Q_L, A_I, S_I=None, tok_idx=None):
        valid_mask = build_valid_mask(tok_idx)
        # Accept unbatched inputs by adding (and later removing) a batch dim.
        if Q_L.ndim == 2:
            squeeze = True
            Q_L = Q_L.unsqueeze(0)
        else:
            squeeze = False

        A_I = A_I.unsqueeze(0) if exists(A_I) and A_I.ndim == 2 else A_I
        S_I = S_I.unsqueeze(0) if exists(S_I) and S_I.ndim == 2 else S_I

        Q_IA = ungroup_atoms(Q_L, valid_mask)

        A_I = self.forward_(Q_IA, A_I, S_I, valid_mask=valid_mask)

        if squeeze:
            A_I = A_I.squeeze(0)
        return A_I
|
|
582
|
+
|
|
583
|
+
|
|
584
|
+
######################################################################################
|
|
585
|
+
########################## Local Atom Transformer ##########################
|
|
586
|
+
######################################################################################
|
|
587
|
+
|
|
588
|
+
|
|
589
|
+
class LocalTokenTransformer(nn.Module):
    """Stack of local-attention transformer blocks applied at token level.

    Each layer is a ``StructureLocalAtomTransformerBlock`` built with token
    channels. Attention neighbourhoods are computed once per forward call by
    ``create_attention_indices``, with ``n_keys`` passed as ``n_attn_keys`` and
    ``n_local_tokens`` as ``n_attn_seq_neighbours``.

    Args:
        c_token: Token channel dimension.
        c_tokenpair: Token-pair channel dimension.
        c_s: Conditioning channel dimension.
        n_block: Number of blocks.
        diffusion_transformer_block: Extra kwargs for every block.
        n_registers: Accepted for config compatibility; unused here.
        n_local_tokens: Number of sequence neighbours per query.
        n_keys: Number of attention keys per query.
    """

    def __init__(
        self,
        c_token,
        c_tokenpair,
        c_s,
        n_block,
        diffusion_transformer_block,
        n_registers=None,
        n_local_tokens=8,
        n_keys=32,
    ):
        super().__init__()
        self.n_local_tokens = n_local_tokens
        self.n_keys = n_keys
        layers = [
            StructureLocalAtomTransformerBlock(
                c_atom=c_token,
                c_s=c_s,
                c_atompair=c_tokenpair,
                **diffusion_transformer_block,
            )
            for _ in range(n_block)
        ]
        self.blocks = nn.ModuleList(layers)

    def forward(self, A_I, S_I, Z_II, f, X_L, full=False):
        n_tokens = A_I.shape[1]
        indices = create_attention_indices(
            X_L=X_L,
            f=f,
            tok_idx=torch.arange(n_tokens, device=A_I.device),
            n_attn_keys=self.n_keys,
            n_attn_seq_neighbours=self.n_local_tokens,
        )

        use_checkpointing = not DISABLE_CHECKPOINTING
        for layer in self.blocks:
            layer.attention_pair_bias.use_checkpointing = use_checkpointing
            # A_I: [B, L, C_token]; S_I: [B, L, C_s]; Z_II: [B, L, L, C_tokenpair]
            # ``full`` does not accelerate inference, but memory scales better.
            A_I = layer(A_I, S_I, Z_II, indices=indices, full=full)
        return A_I
|
|
640
|
+
|
|
641
|
+
|
|
642
|
+
class LocalAtomTransformer(nn.Module):
    """Sequential stack of ``n_blocks`` ``StructureLocalAtomTransformerBlock`` layers.

    Args:
        c_atom: Atom channel dimension.
        c_s: Conditioning channel dimension.
        c_atompair: Atom-pair channel dimension.
        atom_transformer_block: Extra kwargs for every block.
        n_blocks: Number of blocks.
    """

    def __init__(self, c_atom, c_s, c_atompair, atom_transformer_block, n_blocks):
        super().__init__()
        layers = [
            StructureLocalAtomTransformerBlock(
                c_atom=c_atom,
                c_s=c_s,
                c_atompair=c_atompair,
                **atom_transformer_block,
            )
            for _ in range(n_blocks)
        ]
        self.blocks = nn.ModuleList(layers)

    def forward(self, Q_L, C_L, P_LL, **kwargs):
        """Apply every block in order; extra keyword arguments go to each block."""
        out = Q_L
        for layer in self.blocks:
            out = layer(out, C_L, P_LL, **kwargs)
        return out
|
|
661
|
+
|
|
662
|
+
|
|
663
|
+
class StructureLocalAtomTransformerBlock(nn.Module):
    """Residual local pair-biased attention followed by a residual transition.

    Computes ``Q <- Q + Dropout(LocalAttentionPairBias(Q, C, P))`` and then
    ``Q <- Q + Transition(Q[, C])``.

    At construction, a ``ConditionedTransitionBlock`` is used when ``c_s``
    exists; otherwise a plain ``Transition`` with expansion factor 4 is used.
    At call time, ``C_L`` is passed to the transition only when it exists.

    Keyword-only args:
        c_atom: Channel dimension of ``Q_L``.
        c_s: Channel dimension of the conditioning ``C_L`` (may be None).
        c_atompair: Channel dimension of ``P_LL``.
        dropout: Dropout probability on the attention output.
        no_residual_connection_between_attention_and_transition: Must be
            falsy; only the residual variant is implemented.
        **transformer_block: Forwarded to ``LocalAttentionPairBias``.
    """

    def __init__(
        self,
        *,
        c_atom,
        c_s,
        c_atompair,
        dropout,
        no_residual_connection_between_attention_and_transition,
        **transformer_block,
    ):
        super().__init__()
        assert not no_residual_connection_between_attention_and_transition
        self.c_s = c_s
        self.dropout = nn.Dropout(dropout)
        self.attention_pair_bias = LocalAttentionPairBias(
            c_a=c_atom, c_s=c_s, c_pair=c_atompair, **transformer_block
        )
        self.transition_block = (
            ConditionedTransitionBlock(c_token=c_atom, c_s=c_s)
            if exists(c_s)
            else Transition(c=c_atom, n=4)
        )

    def forward(
        self,
        Q_L,  # [..., I, C_token]
        C_L,  # [..., I, C_s]
        P_LL,  # [..., I, I, C_tokenpair]
        f=None,
        chunked_pairwise_embedder=None,
        initializer_outputs=None,
        **kwargs,
    ):
        attn_out = self.attention_pair_bias(
            Q_L,
            C_L,
            P_LL,
            f=f,
            chunked_pairwise_embedder=chunked_pairwise_embedder,
            initializer_outputs=initializer_outputs,
            **kwargs,
        )
        Q_L = Q_L + self.dropout(attn_out)

        transition_args = (Q_L, C_L) if exists(C_L) else (Q_L,)
        return Q_L + self.transition_block(*transition_args)
|
|
712
|
+
|
|
713
|
+
|
|
714
|
+
class CompactStreamingDecoder(nn.Module):
    """Atom-level decoder: alternating token->atom upcasts and local atom blocks,
    followed by a pooling of the atoms back onto tokens.

    Args:
        c_atom: Atom channel dimension (also used as the conditioning width of
            the atom blocks).
        c_atompair: Atom-pair channel dimension.
        c_token: Token channel dimension.
        c_s: Conditioning channel dimension passed to ``Downcast``.
        c_tokenpair: Accepted for config compatibility; unused here.
        atom_transformer_block: Extra kwargs for every atom block.
        upcast: Extra kwargs for every ``Upcast``.
        downcast: Extra kwargs for the final ``Downcast``.
        n_blocks: Number of upcast + atom-block rounds.
        diffusion_transformer_block: Accepted for config compatibility; unused.
    """

    def __init__(
        self,
        c_atom,
        c_atompair,
        c_token,
        c_s,
        c_tokenpair,
        atom_transformer_block,
        upcast,
        downcast,
        n_blocks,
        diffusion_transformer_block=False,
    ):
        super().__init__()
        self.n_blocks = n_blocks

        self.upcast = nn.ModuleList(
            Upcast(c_atom=c_atom, c_token=c_token, **upcast) for _ in range(n_blocks)
        )
        self.atom_transformer = nn.ModuleList(
            StructureLocalAtomTransformerBlock(
                c_atom=c_atom,
                c_s=c_atom,
                c_atompair=c_atompair,
                **atom_transformer_block,
            )
            for _ in range(n_blocks)
        )
        self.downcast = Downcast(c_atom=c_atom, c_token=c_token, c_s=c_s, **downcast)

    def forward(
        self,
        A_I,
        S_I,
        Z_II,
        Q_L,
        C_L,
        P_LL,
        tok_idx,
        indices,
        f=None,
        chunked_pairwise_embedder=None,
        initializer_outputs=None,
    ):
        """Run the decoder.

        ``Z_II`` is not used. Returns ``(A_I, Q_L, {})``, where ``A_I`` comes
        from the downcast applied to *detached* inputs, so no gradient from
        ``A_I`` reaches ``Q_L``, the input ``A_I`` or ``S_I``.
        """
        for block_idx in range(self.n_blocks):
            Q_L = self.upcast[block_idx](Q_L, A_I, tok_idx=tok_idx)
            Q_L = self.atom_transformer[block_idx](
                Q_L,
                C_L,
                P_LL,
                indices=indices,
                f=f,
                chunked_pairwise_embedder=chunked_pairwise_embedder,
                initializer_outputs=initializer_outputs,
            )

        # Downcast to sequence (gradient-isolated from the atom stream).
        pooled = self.downcast(
            Q_L.detach(), A_I.detach(), S_I.detach(), tok_idx=tok_idx
        )
        return pooled, Q_L, {}
|