rc-foundry 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- foundry/__init__.py +57 -0
- foundry/callbacks/__init__.py +5 -0
- foundry/callbacks/callback.py +116 -0
- foundry/callbacks/health_logging.py +419 -0
- foundry/callbacks/metrics_logging.py +211 -0
- foundry/callbacks/timing_logging.py +67 -0
- foundry/callbacks/train_logging.py +278 -0
- foundry/common.py +108 -0
- foundry/constants.py +28 -0
- foundry/hydra/resolvers.py +77 -0
- foundry/inference_engines/base.py +235 -0
- foundry/inference_engines/checkpoint_registry.py +66 -0
- foundry/metrics/__init__.py +12 -0
- foundry/metrics/losses.py +30 -0
- foundry/metrics/metric.py +319 -0
- foundry/model/layers/blocks.py +47 -0
- foundry/testing/__init__.py +6 -0
- foundry/testing/fixtures.py +19 -0
- foundry/testing/pytest_hooks.py +15 -0
- foundry/trainers/fabric.py +923 -0
- foundry/training/EMA.py +67 -0
- foundry/training/checkpoint.py +61 -0
- foundry/training/schedulers.py +91 -0
- foundry/utils/alignment.py +86 -0
- foundry/utils/components.py +415 -0
- foundry/utils/datasets.py +405 -0
- foundry/utils/ddp.py +103 -0
- foundry/utils/instantiators.py +72 -0
- foundry/utils/logging.py +279 -0
- foundry/utils/rigid.py +1460 -0
- foundry/utils/rotation_augmentation.py +65 -0
- foundry/utils/squashfs.py +172 -0
- foundry/utils/torch.py +317 -0
- foundry/utils/weights.py +271 -0
- foundry/version.py +34 -0
- foundry_cli/__init__.py +3 -0
- foundry_cli/download_checkpoints.py +281 -0
- mpnn/__init__.py +1 -0
- mpnn/collate/feature_collator.py +265 -0
- mpnn/inference.py +53 -0
- mpnn/inference_engines/mpnn.py +549 -0
- mpnn/loss/nll_loss.py +122 -0
- mpnn/metrics/nll.py +369 -0
- mpnn/metrics/sequence_recovery.py +440 -0
- mpnn/model/layers/graph_embeddings.py +2372 -0
- mpnn/model/layers/message_passing.py +332 -0
- mpnn/model/layers/position_wise_feed_forward.py +44 -0
- mpnn/model/layers/positional_encoding.py +98 -0
- mpnn/model/mpnn.py +2632 -0
- mpnn/pipelines/mpnn.py +162 -0
- mpnn/samplers/samplers.py +167 -0
- mpnn/train.py +341 -0
- mpnn/trainers/mpnn.py +193 -0
- mpnn/transforms/feature_aggregation/mpnn.py +184 -0
- mpnn/transforms/feature_aggregation/polymer_ligand_interface.py +76 -0
- mpnn/transforms/feature_aggregation/token_encodings.py +132 -0
- mpnn/transforms/feature_aggregation/user_settings.py +347 -0
- mpnn/transforms/polymer_ligand_interface.py +164 -0
- mpnn/utils/inference.py +2397 -0
- mpnn/utils/probability.py +37 -0
- mpnn/utils/weights.py +309 -0
- rc_foundry-0.1.1.dist-info/METADATA +239 -0
- rc_foundry-0.1.1.dist-info/RECORD +180 -0
- rc_foundry-0.1.1.dist-info/WHEEL +4 -0
- rc_foundry-0.1.1.dist-info/entry_points.txt +5 -0
- rc_foundry-0.1.1.dist-info/licenses/LICENSE.md +28 -0
- rf3/__init__.py +3 -0
- rf3/_version.py +33 -0
- rf3/alignment.py +79 -0
- rf3/callbacks/dump_validation_structures.py +101 -0
- rf3/callbacks/metrics_logging.py +324 -0
- rf3/chemical.py +1529 -0
- rf3/cli.py +77 -0
- rf3/data/cyclic_transform.py +78 -0
- rf3/data/extra_xforms.py +36 -0
- rf3/data/ground_truth_template.py +463 -0
- rf3/data/paired_msa.py +206 -0
- rf3/data/pipeline_utils.py +128 -0
- rf3/data/pipelines.py +558 -0
- rf3/diffusion_samplers/inference_sampler.py +222 -0
- rf3/inference.py +65 -0
- rf3/inference_engines/__init__.py +5 -0
- rf3/inference_engines/rf3.py +735 -0
- rf3/kinematics.py +354 -0
- rf3/loss/af3_confidence_loss.py +515 -0
- rf3/loss/af3_losses.py +655 -0
- rf3/loss/loss.py +179 -0
- rf3/metrics/chiral.py +179 -0
- rf3/metrics/clashing_chains.py +68 -0
- rf3/metrics/distogram.py +421 -0
- rf3/metrics/lddt.py +523 -0
- rf3/metrics/metadata.py +43 -0
- rf3/metrics/metric_utils.py +192 -0
- rf3/metrics/predicted_error.py +134 -0
- rf3/metrics/rasa.py +108 -0
- rf3/metrics/selected_distances.py +91 -0
- rf3/model/RF3.py +527 -0
- rf3/model/RF3_blocks.py +92 -0
- rf3/model/RF3_structure.py +303 -0
- rf3/model/layers/af3_auxiliary_heads.py +255 -0
- rf3/model/layers/af3_diffusion_transformer.py +544 -0
- rf3/model/layers/attention.py +313 -0
- rf3/model/layers/layer_utils.py +127 -0
- rf3/model/layers/mlff.py +118 -0
- rf3/model/layers/outer_product.py +59 -0
- rf3/model/layers/pairformer_layers.py +783 -0
- rf3/model/layers/structure_bias.py +56 -0
- rf3/scoring.py +1787 -0
- rf3/symmetry/resolve.py +284 -0
- rf3/train.py +194 -0
- rf3/trainers/rf3.py +570 -0
- rf3/util_module.py +47 -0
- rf3/utils/frames.py +109 -0
- rf3/utils/inference.py +665 -0
- rf3/utils/io.py +198 -0
- rf3/utils/loss.py +72 -0
- rf3/utils/predict_and_score.py +165 -0
- rf3/utils/predicted_error.py +673 -0
- rf3/utils/recycling.py +42 -0
- rf3/validate.py +140 -0
- rfd3/.gitignore +7 -0
- rfd3/Makefile +76 -0
- rfd3/__init__.py +12 -0
- rfd3/callbacks.py +66 -0
- rfd3/cli.py +41 -0
- rfd3/constants.py +212 -0
- rfd3/engine.py +543 -0
- rfd3/inference/datasets.py +193 -0
- rfd3/inference/input_parsing.py +1123 -0
- rfd3/inference/legacy_input_parsing.py +717 -0
- rfd3/inference/parsing.py +165 -0
- rfd3/inference/symmetry/atom_array.py +298 -0
- rfd3/inference/symmetry/checks.py +241 -0
- rfd3/inference/symmetry/contigs.py +63 -0
- rfd3/inference/symmetry/frames.py +355 -0
- rfd3/inference/symmetry/symmetry_utils.py +398 -0
- rfd3/metrics/design_metrics.py +465 -0
- rfd3/metrics/hbonds_hbplus_metrics.py +308 -0
- rfd3/metrics/hbonds_metrics.py +389 -0
- rfd3/metrics/losses.py +325 -0
- rfd3/metrics/metrics_utils.py +118 -0
- rfd3/metrics/sidechain_metrics.py +349 -0
- rfd3/model/RFD3.py +105 -0
- rfd3/model/RFD3_diffusion_module.py +387 -0
- rfd3/model/cfg_utils.py +81 -0
- rfd3/model/inference_sampler.py +635 -0
- rfd3/model/layers/attention.py +577 -0
- rfd3/model/layers/block_utils.py +580 -0
- rfd3/model/layers/blocks.py +777 -0
- rfd3/model/layers/chunked_pairwise.py +377 -0
- rfd3/model/layers/encoders.py +417 -0
- rfd3/model/layers/layer_utils.py +197 -0
- rfd3/model/layers/pairformer_layers.py +128 -0
- rfd3/run_inference.py +45 -0
- rfd3/testing/debug.py +139 -0
- rfd3/testing/debug_utils.py +73 -0
- rfd3/testing/testing_utils.py +356 -0
- rfd3/train.py +194 -0
- rfd3/trainer/dump_validation_structures.py +154 -0
- rfd3/trainer/fabric_trainer.py +923 -0
- rfd3/trainer/recycling.py +42 -0
- rfd3/trainer/rfd3.py +485 -0
- rfd3/trainer/trainer_utils.py +502 -0
- rfd3/transforms/conditioning_base.py +508 -0
- rfd3/transforms/conditioning_utils.py +200 -0
- rfd3/transforms/design_transforms.py +807 -0
- rfd3/transforms/dna_crop.py +523 -0
- rfd3/transforms/hbonds.py +407 -0
- rfd3/transforms/hbonds_hbplus.py +246 -0
- rfd3/transforms/ncaa_transforms.py +153 -0
- rfd3/transforms/pipelines.py +632 -0
- rfd3/transforms/ppi_transforms.py +541 -0
- rfd3/transforms/rasa.py +116 -0
- rfd3/transforms/symmetry.py +76 -0
- rfd3/transforms/training_conditions.py +552 -0
- rfd3/transforms/util_transforms.py +498 -0
- rfd3/transforms/virtual_atoms.py +305 -0
- rfd3/utils/inference.py +648 -0
- rfd3/utils/io.py +245 -0
- rfd3/utils/vizualize.py +276 -0
|
@@ -0,0 +1,783 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
from rf3.data.ground_truth_template import (
|
|
3
|
+
af3_noise_scale_to_noise_level,
|
|
4
|
+
)
|
|
5
|
+
from rf3.model.layers.af3_diffusion_transformer import AtomTransformer
|
|
6
|
+
from rf3.model.layers.attention import TriangleAttention, TriangleMultiplication
|
|
7
|
+
from rf3.model.layers.layer_utils import (
|
|
8
|
+
MultiDimLinear,
|
|
9
|
+
Transition,
|
|
10
|
+
collapse,
|
|
11
|
+
create_batch_dimension_if_not_present,
|
|
12
|
+
linearNoBias,
|
|
13
|
+
)
|
|
14
|
+
from rf3.model.layers.mlff import ConformerEmbeddingWeightedAverage
|
|
15
|
+
from rf3.model.layers.outer_product import (
|
|
16
|
+
OuterProductMean_AF3,
|
|
17
|
+
)
|
|
18
|
+
from rf3.model.RF3_blocks import MSAPairWeightedAverage, MSASubsampleEmbedder
|
|
19
|
+
from torch import nn
|
|
20
|
+
from torch.nn.functional import one_hot, relu
|
|
21
|
+
|
|
22
|
+
from foundry.model.layers.blocks import Dropout
|
|
23
|
+
from foundry.training.checkpoint import activation_checkpointing
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class AtomAttentionEncoderPairformer(nn.Module):
    """Atom-level encoder that pools per-atom embeddings into per-token features.

    A per-atom single conditioning is built from the feature-dict entries listed in
    ``atom_1d_features``. A per-atom-pair representation is built from offsets between
    reference positions, masked to atom pairs that share a ``ref_space_uid``. Atoms are
    refined with an ``AtomTransformer`` and then mean-pooled into tokens via
    ``atom_to_token_map``. This variant takes no coordinates or trunk inputs (they must
    be ``None``).
    """

    def __init__(
        self,
        c_atom,
        c_atompair,
        c_token,
        c_tokenpair,
        c_s,
        atom_1d_features,
        c_atom_1d_features,
        atom_transformer,
        use_inv_dist_squared: bool = False,  # HACK: For 9/21 checkpoint, default to False (as this argument was not present in the checkpoint config)
        use_atom_level_embedding: bool = False,
        atom_level_embedding_dim: int = 384,
    ):
        """
        Args:
            c_atom: channel dim of the per-atom single representation.
            c_atompair: channel dim of the per-atom-pair representation.
            c_token: channel dim of the pooled per-token output.
            c_tokenpair: stored on the module; not used by this class.
            c_s: stored on the module; not used by this class.
            atom_1d_features: names of feature-dict entries concatenated into the atom input.
            c_atom_1d_features: total channel width of that concatenation.
            atom_transformer: keyword arguments forwarded to ``AtomTransformer``.
            use_inv_dist_squared: embed ``1 / (1 + |d|^2)`` instead of ``1 / (1 + |d|)``.
            use_atom_level_embedding: add a projection of ``f["atom_level_embedding"]``
                to the atom conditioning.
            atom_level_embedding_dim: input width of that extra embedding.
        """
        super().__init__()
        self.c_atom, self.c_atompair = c_atom, c_atompair
        self.c_token, self.c_tokenpair = c_token, c_tokenpair
        self.c_s = c_s
        self.atom_1d_features = atom_1d_features

        # NOTE: submodules are created in a fixed order; it determines parameter
        # initialisation order and state_dict layout.
        self.process_input_features = linearNoBias(c_atom_1d_features, c_atom)

        # Pair embeddings: raw offset vector, inverse distance, and same-uid mask.
        self.process_d = linearNoBias(3, c_atompair)
        self.process_inverse_dist = linearNoBias(1, c_atompair)
        self.process_valid_mask = linearNoBias(1, c_atompair)

        self.use_atom_level_embedding = use_atom_level_embedding

        def relu_then_project():
            return nn.Sequential(nn.ReLU(), linearNoBias(c_atom, c_atompair))

        self.process_single_l = relu_then_project()
        self.process_single_m = relu_then_project()

        # Three (ReLU -> Linear) stages on the pair activations.
        mlp_layers = []
        for _ in range(3):
            mlp_layers.extend([nn.ReLU(), linearNoBias(self.c_atompair, c_atompair)])
        self.pair_mlp = nn.Sequential(*mlp_layers)

        self.process_q = nn.Sequential(linearNoBias(c_atom, c_token), nn.ReLU())

        self.atom_transformer = AtomTransformer(
            c_atom=c_atom, c_atompair=c_atompair, **atom_transformer
        )

        self.use_inv_dist_squared = use_inv_dist_squared

        if self.use_atom_level_embedding:
            self.process_atom_level_embedding = ConformerEmbeddingWeightedAverage(
                atom_level_embedding_dim=atom_level_embedding_dim,
                c_atompair=c_atompair,
                c_atom=c_atom,
            )

    def forward(
        self,
        f,  # Dict (Input feature dictionary)
        R_L,  # [D, L, 3]
        S_trunk_I,  # [B, I, C_S_trunk] [...,I,C_S_trunk]
        Z_II,  # [B, I, I, C_Z] [...,I,I,C_Z]
    ):
        """Embed reference-conformer features and pool them to tokens.

        Side effect: ``f["ref_atom_name_chars"]`` is replaced in the caller's dict by a
        version reshaped to ``[L, -1]``.

        Returns:
            Tuple ``(A_I, Q_L, C_L, P_LL)``:
            - the per-token mean of ``process_q(Q_L)`` over each token's atoms;
            - the refined atom single representation;
            - the atom single conditioning;
            - the atom pair representation.
        """
        # This encoder only embeds reference features; coordinate/trunk inputs are unsupported.
        assert R_L is None
        assert S_trunk_I is None
        assert Z_II is None

        atom_to_token = f["atom_to_token_map"]
        n_atoms = len(atom_to_token)  # N_atom
        n_tokens = atom_to_token.max() + 1  # N_token

        # Merge the per-character axis with the character one-hot axis
        # (e.g. [L, 4, 64] -> [L, 256]).
        f["ref_atom_name_chars"] = f["ref_atom_name_chars"].reshape(n_atoms, -1)

        # Atom single conditioning: one linear map over the concatenated per-atom features.
        per_atom_inputs = [collapse(f[name], n_atoms) for name in self.atom_1d_features]
        C_L = self.process_input_features(torch.cat(per_atom_inputs, dim=-1))

        if self.use_atom_level_embedding:
            assert "atom_level_embedding" in f
            C_L = C_L + self.process_atom_level_embedding(f["atom_level_embedding"])

        # Outer difference of reference positions: [L, 1, 3] - [1, L, 3] -> [L, L, 3].
        ref_pos = f["ref_pos"]
        D_LL = ref_pos.unsqueeze(-2) - ref_pos.unsqueeze(-3)

        # True where both atoms share a ref_space_uid; trailing singleton dim -> [L, L, 1].
        ref_uid = f["ref_space_uid"]
        V_LL = (ref_uid.unsqueeze(-1) == ref_uid.unsqueeze(-2)).unsqueeze(-1)

        @activation_checkpointing
        def embed_features(C_L, D_LL, V_LL):
            P_LL = self.process_d(D_LL) * V_LL  # [L, L, C_atompair]

            # Distance feature, squared or plain depending on configuration.
            if self.use_inv_dist_squared:
                dist_term = torch.sum(D_LL * D_LL, dim=-1, keepdim=True)
            else:
                dist_term = torch.linalg.norm(D_LL, dim=-1, keepdim=True)
            P_LL = P_LL + self.process_inverse_dist(1 / (1 + dist_term)) * V_LL
            P_LL = P_LL + self.process_valid_mask(V_LL.to(P_LL.dtype)) * V_LL

            # The atom single representation starts as the conditioning itself
            # (same tensor object; later updates rebind rather than mutate).
            Q_L = C_L

            # Broadcast-add the single conditioning along both pair axes (residual).
            single_to_pair = self.process_single_l(C_L).unsqueeze(
                -2
            ) + self.process_single_m(C_L).unsqueeze(-3)
            P_LL = P_LL + single_to_pair

            # Residual MLP on the pair activations.
            P_LL = P_LL + self.pair_mlp(P_LL)

            Q_L = self.atom_transformer(Q_L, C_L, P_LL)  # [L, C_atom]

            # Project to token width and average over each token's atoms.
            token_inputs = self.process_q(Q_L).to(Q_L.dtype)
            pooled_shape = Q_L.shape[:-2] + (n_tokens, self.c_token)
            A_I = torch.zeros(
                pooled_shape, device=Q_L.device, dtype=Q_L.dtype
            ).index_reduce(
                -2,  # atom axis
                atom_to_token.long(),
                token_inputs,
                "mean",
                include_self=False,  # ignore the zero initial values
            )
            return A_I, Q_L, C_L, P_LL

        return embed_features(C_L, D_LL, V_LL)
|
|
216
|
+
|
|
217
|
+
|
|
218
|
+
class AttentionPairBiasPairformerDeepspeed(nn.Module):
    """Gated multi-head self-attention over tokens, biased by the pair representation.

    With ``use_deepspeed_evo`` enabled and more than 24 tokens, DeepSpeed's
    ``DS4Sci_EvoformerAttention`` kernel is used. Otherwise attention is computed with
    einsum. The single-track conditioning input ``S_I`` is not supported and must be
    ``None``.
    """

    def __init__(self, c_a, c_s, c_pair, n_head):
        """
        Args:
            c_a: channel dim of the token representation.
            c_s: unused; kept for signature compatibility.
            c_pair: channel dim of the pair representation.
            n_head: number of attention heads; per-head dim is ``c_a // n_head``.
        """
        super().__init__()
        self.n_head = n_head
        self.c_a = c_a
        self.c_pair = c_pair
        self.c = c_a // n_head

        self.to_q = MultiDimLinear(c_a, (n_head, self.c), bias=False)
        self.to_k = MultiDimLinear(c_a, (n_head, self.c), bias=False)
        self.to_v = MultiDimLinear(c_a, (n_head, self.c), bias=False)
        self.to_b = linearNoBias(c_pair, n_head)
        self.to_g = nn.Sequential(
            MultiDimLinear(c_a, (n_head, self.c), bias=False),
            nn.Sigmoid(),
        )
        self.to_a = linearNoBias(c_a, c_a)
        self.ln_0 = nn.LayerNorm((c_pair,))
        self.ln_1 = nn.LayerNorm((c_a,))
        self.use_deepspeed_evo = False
        self.force_bfloat16 = True

    def forward(
        self,
        A_I,  # [I, C_a]
        S_I,  # [I, C_a] | None
        Z_II,  # [I, I, C_z]
        Beta_II=None,  # [I, I]
    ):
        """Attend over tokens with a pair-derived bias.

        Args:
            A_I: token representation, ``[..., I, C_a]``.
            S_I: must be ``None``.
            Z_II: pair representation, ``[..., I, I, C_pair]``.
            Beta_II: optional additive attention bias without a head axis, broadcast
                over heads. When ``None`` (the default), no extra bias is added.

        Returns:
            The gated, output-projected attention result, ``[..., I, C_a]``.
        """
        assert S_I is None
        A_I = self.ln_1(A_I)

        if self.use_deepspeed_evo or self.force_bfloat16:
            A_I = A_I.to(torch.bfloat16)

        Q_IH = self.to_q(A_I)  # scaled by 1/sqrt(c) below (einsum path)
        K_IH = self.to_k(A_I)
        V_IH = self.to_v(A_I)
        B_IIH = self.to_b(self.ln_0(Z_II))
        if Beta_II is not None:
            # Beta_II has no head axis; append one so it broadcasts over heads.
            B_IIH = B_IIH + Beta_II[..., None]
        G_IH = self.to_g(A_I)

        # For an unbatched [I, I, H] bias both values equal I; with a leading batch
        # dim, L is still the token count.
        B, L = B_IIH.shape[:2]

        if not self.use_deepspeed_evo or L <= 24:
            Q_IH = Q_IH / torch.sqrt(
                torch.tensor(self.c).to(Q_IH.device, torch.bfloat16)
            )
            # Attention
            A_IIH = torch.softmax(
                torch.einsum("...ihd,...jhd->...ijh", Q_IH, K_IH) + B_IIH, dim=-2
            )  # softmax over j
            ## G_IH: [I, H, C]
            ## A_IIH: [I, I, H]
            ## V_IH: [I, H, C]
            A_I = torch.einsum("...ijh,...jhc->...ihc", A_IIH, V_IH)
            A_I = G_IH * A_I  # [B, I, H, C]
            A_I = A_I.flatten(start_dim=-2)  # [B, I, Ca]
        else:
            # DS4Sci_EvoformerAttention
            # Q, K, V: [Batch, N_seq, N_res, Head, Dim]
            # res_mask: [Batch, N_seq, 1, 1, N_res]
            # pair_bias: [Batch, 1, Head, N_res, N_res]
            from deepspeed.ops.deepspeed4science import DS4Sci_EvoformerAttention

            # A leading size-1 axis here would indicate a batched input, which this
            # path does not support.
            assert Q_IH.shape[0] != 1, "this code assumes your structure is not batched"
            batch = 1
            n_res = Q_IH.shape[0]
            n_head = self.n_head
            c = self.c

            Q_IH = Q_IH[None, None]
            K_IH = K_IH[None, None]
            V_IH = V_IH[None, None]
            B_IIH = B_IIH.repeat(Q_IH.shape[0], 1, 1, 1)
            B_IIH = B_IIH[:, None]
            B_IIH = B_IIH.permute(0, 1, 4, 2, 3).to(torch.bfloat16)
            mask = torch.zeros(
                [Q_IH.shape[0], 1, 1, 1, B_IIH.shape[-1]],
                dtype=torch.bfloat16,
                device=B_IIH.device,
            )

            assert Q_IH.shape == (batch, 1, n_res, n_head, c)
            assert K_IH.shape == (batch, 1, n_res, n_head, c)
            assert V_IH.shape == (batch, 1, n_res, n_head, c)
            assert mask.shape == (batch, 1, 1, 1, n_res)
            assert B_IIH.shape == (batch, 1, n_head, n_res, n_res)

            A_I = DS4Sci_EvoformerAttention(Q_IH, K_IH, V_IH, [mask, B_IIH])

            assert A_I.shape == (batch, 1, n_res, n_head, c)
            A_I = A_I * G_IH[None, None]
            A_I = A_I.view(n_res, -1)

        A_I = self.to_a(A_I)

        return A_I
|
|
321
|
+
|
|
322
|
+
|
|
323
|
+
class PairformerBlock(nn.Module):
    """
    Attempt to replicate AF3 architecture from scratch.

    One Pairformer block. The pair representation is updated in order by outgoing and
    incoming triangle multiplication, starting- and ending-node triangle attention, and a
    transition, each as a residual with shared-axis dropout. If a single representation
    is supplied, it is then updated by pair-biased attention and a transition.
    """

    def __init__(
        self,
        c_s,
        c_z,
        p_drop,
        triangle_multiplication,
        triangle_attention,
        attention_pair_bias,
        n_transition=4,
        **kwargs,  # Catch-all for backwards compatibility
    ):
        super().__init__()

        # Dropout whose mask is shared along dim -2 (rows) / dim -3 (columns).
        self.drop_row = Dropout(broadcast_dim=-2, p_drop=p_drop)
        self.drop_col = Dropout(broadcast_dim=-3, p_drop=p_drop)

        # NOTE: creation order is kept stable (parameter init order / state_dict layout).
        hidden = triangle_multiplication["d_hidden"]
        self.tri_mul_outgoing = TriangleMultiplication(
            d_pair=c_z,
            d_hidden=hidden,
            direction="outgoing",
            bias=True,
            use_cuequivariance=True,
        )
        self.tri_mul_incoming = TriangleMultiplication(
            d_pair=c_z,
            d_hidden=hidden,
            direction="incoming",
            bias=True,
            use_cuequivariance=True,
        )

        self.tri_attn_start = TriangleAttention(
            c_z, **triangle_attention, start_node=True, use_cuequivariance=True
        )
        self.tri_attn_end = TriangleAttention(
            c_z, **triangle_attention, start_node=False, use_cuequivariance=True
        )

        self.z_transition = Transition(c=c_z, n=n_transition)

        # The single-track modules exist only when a single representation is used.
        if c_s > 0:
            self.s_transition = Transition(c=c_s, n=n_transition)
            self.attention_pair_bias = AttentionPairBiasPairformerDeepspeed(
                c_a=c_s, c_s=0, c_pair=c_z, **attention_pair_bias
            )

        # Triangle ops expect a [B, L, L, C] input; add a batch dim when missing.
        self.maybe_make_batched = create_batch_dimension_if_not_present(4)

    @activation_checkpointing
    def forward(self, S_I, Z_II):
        """Apply one block; returns ``(S_I, Z_II)`` (``S_I`` is passed through if ``None``)."""
        residual_pair_updates = (
            (self.tri_mul_outgoing, self.drop_row),
            (self.tri_mul_incoming, self.drop_row),
            (self.tri_attn_start, self.drop_row),
            (self.tri_attn_end, self.drop_col),
        )
        for pair_op, dropout in residual_pair_updates:
            Z_II = Z_II + dropout(self.maybe_make_batched(pair_op)(Z_II))
        Z_II = Z_II + self.z_transition(Z_II)

        if S_I is None:
            return S_I, Z_II

        zero_bias = torch.tensor([0.0], device=Z_II.device)
        S_I = S_I + self.attention_pair_bias(S_I, None, Z_II, Beta_II=zero_bias)
        S_I = S_I + self.s_transition(S_I)
        return S_I, Z_II
|
|
403
|
+
|
|
404
|
+
|
|
405
|
+
class FeatureInitializer(nn.Module):
    """Create the initial single and pair representations from the feature dict.

    ``forward`` returns ``(S_inputs_I, S_init_I, Z_init_II)``:
    - the embedded per-token inputs;
    - their linear projection to ``c_s``;
    - a pair tensor formed from two broadcast projections of the inputs, plus the
      relative position encoding, plus a projection of ``f["token_bonds"]``.
    """

    def __init__(
        self,
        c_s,
        c_z,
        c_atom,
        c_atompair,
        c_s_inputs,
        input_feature_embedder,
        relative_position_encoding,
    ):
        super().__init__()
        # NOTE: creation order is kept stable (parameter init order / state_dict layout).
        self.input_feature_embedder = InputFeatureEmbedder(
            c_atom=c_atom, c_atompair=c_atompair, **input_feature_embedder
        )
        self.to_s_init = linearNoBias(c_s_inputs, c_s)
        self.to_z_init_i = linearNoBias(c_s_inputs, c_z)
        self.to_z_init_j = linearNoBias(c_s_inputs, c_z)
        self.relative_position_encoding = RelativePositionEncoding(
            c_z=c_z, **relative_position_encoding
        )
        self.process_token_bonds = linearNoBias(1, c_z)

    def forward(
        self,
        f,
    ):
        token_inputs = self.input_feature_embedder(f)
        single_init = self.to_s_init(token_inputs)

        # NOTE: `to_z_init_i` is unsqueezed at -3 (so it varies along the second token
        # axis) and `to_z_init_j` at -2 (varies along the first). Names are kept as-is
        # for checkpoint compatibility.
        z_along_cols = self.to_z_init_i(token_inputs).unsqueeze(-3)
        z_along_rows = self.to_z_init_j(token_inputs).unsqueeze(-2)
        pair_init = z_along_cols + z_along_rows
        pair_init = pair_init + self.relative_position_encoding(f)

        bond_channel = f["token_bonds"].unsqueeze(-1).to(torch.float)
        pair_init = pair_init + self.process_token_bonds(bond_channel)
        return token_inputs, single_init, pair_init
|
|
442
|
+
|
|
443
|
+
|
|
444
|
+
class InputFeatureEmbedder(nn.Module):
    """Per-token input embedding: atom-encoder output concatenated with raw token features."""

    def __init__(self, features, c_atom, c_atompair, atom_attention_encoder):
        """
        Args:
            features: names of feature-dict entries appended (in order) after the
                atom-encoder output.
            c_atom: atom single channel dim passed to the atom encoder.
            c_atompair: atom-pair channel dim passed to the atom encoder.
            atom_attention_encoder: remaining keyword arguments for
                ``AtomAttentionEncoderPairformer``.
        """
        super().__init__()
        self.atom_attention_encoder = AtomAttentionEncoderPairformer(
            c_atom=c_atom, c_atompair=c_atompair, c_s=0, **atom_attention_encoder
        )
        self.features = features
        # Features that get a trailing singleton channel axis before concatenation.
        self.features_to_unsqueeze = ["deletion_mean"]

    def forward(
        self,
        f,
    ):
        """Return ``cat([A_I, *features], dim=-1)`` with one row per token."""
        A_I, _, _, _ = self.atom_attention_encoder(f, None, None, None)
        # Drop a leading singleton batch axis if the encoder produced one. Squeezing
        # unconditionally would also remove the token axis of an unbatched
        # single-token output ([1, C_token] -> [C_token]) and break the concat below.
        if A_I.dim() > 2:
            A_I = A_I.squeeze(0)
        token_features = [
            f[name].unsqueeze(-1) if name in self.features_to_unsqueeze else f[name]
            for name in self.features
        ]
        return torch.cat([A_I, *token_features], dim=-1)
|
|
469
|
+
|
|
470
|
+
|
|
471
|
+
class RelativePositionEncoding(nn.Module):
    """AF3-style relative position encoding of token pairs, with cyclic-chain support.

    Four per-pair features are concatenated and projected to ``c_z`` channels:
    - one-hot of the clipped residue offset (same chain only);
    - one-hot of the clipped token offset (same chain and residue only);
    - a same-entity flag;
    - one-hot of the clipped ``sym_id`` offset (same entity only).
    """

    def __init__(self, r_max, s_max, c_z):
        """
        Args:
            r_max: maximum absolute residue/token offset before clipping.
            s_max: maximum absolute ``sym_id`` offset before clipping.
            c_z: output channel dim.
        """
        super().__init__()
        self.r_max = r_max
        self.s_max = s_max
        self.c_z = c_z
        # relpos (2r+2) + reltoken (2r+2) + relchain (2s+2) + same-entity flag (1)
        self.linear = linearNoBias(
            2 * (2 * self.r_max + 2) + (2 * self.s_max + 2) + 1, c_z
        )

    def _apply_cyclic_offsets(self, f, offset):
        """Within each chain in ``f["cyclic_asym_ids"]``, use the shortest wrap-around offset.

        For a cyclic chain with ``n`` distinct residue indices, each within-chain offset
        ``d`` is replaced by whichever of ``d``, ``d + n`` and ``d - n`` has the smallest
        magnitude. Ties prefer ``d``, then ``d + n``.
        """
        for cyclic_asym_id in f.get("cyclic_asym_ids", []):
            in_chain = f["asym_id"] == cyclic_asym_id
            n_cyclic = f["residue_index"][in_chain].unique().shape[0]
            if n_cyclic == 0:
                continue
            chain_pair_mask = in_chain.unsqueeze(-1) & in_chain.unsqueeze(-2)

            offset_plus = offset + n_cyclic
            offset_minus = offset - n_cyclic
            abs_offset = offset.abs()
            abs_plus = offset_plus.abs()
            abs_minus = offset_minus.abs()

            wrapped = torch.where(abs_plus <= abs_minus, offset_plus, offset_minus)
            keep_direct = (abs_offset <= abs_plus) & (abs_offset <= abs_minus)
            cyclic_offset = torch.where(keep_direct, offset, wrapped)
            offset = torch.where(chain_pair_mask, cyclic_offset, offset)
        return offset

    def forward(self, f):
        """Return the projected relative encoding, ``[..., I, I, c_z]``."""
        asym_id = f["asym_id"]
        residue_index = f["residue_index"]
        token_index = f["token_index"]
        sym_id = f["sym_id"]
        entity_id = f["entity_id"]

        b_samechain_II = asym_id.unsqueeze(-1) == asym_id.unsqueeze(-2)
        b_sameresidue_II = residue_index.unsqueeze(-1) == residue_index.unsqueeze(-2)
        b_same_entity_II = entity_id.unsqueeze(-1) == entity_id.unsqueeze(-2)

        # Residue offset (wrapped for cyclic chains); bin 2*r_max+1 marks "other chain".
        offset = residue_index.unsqueeze(-1) - residue_index.unsqueeze(-2)
        offset = self._apply_cyclic_offsets(f, offset)
        d_residue_II = torch.where(
            b_samechain_II,
            torch.clip(offset + self.r_max, 0, 2 * self.r_max),
            2 * self.r_max + 1,
        )
        A_relpos_II = one_hot(d_residue_II.long(), 2 * self.r_max + 2)

        # Token offset within the same residue; bin 2*r_max+1 marks "not same residue".
        d_token_II = torch.where(
            b_samechain_II & b_sameresidue_II,
            torch.clip(
                token_index.unsqueeze(-1) - token_index.unsqueeze(-2) + self.r_max,
                0,
                2 * self.r_max,
            ),
            2 * self.r_max + 1,
        )
        # one_hot requires int64 indices; cast like the other encodings do.
        A_reltoken_II = one_hot(d_token_II.long(), 2 * self.r_max + 2)

        d_chain_II = torch.where(
            # NOTE: Implementing bugfix from the Protenix Technical report, where we use `same_entity` instead of `not same_chain` (as in the AF-3 pseudocode)
            # Reference: https://github.com/bytedance/Protenix/blob/main/Protenix_Technical_Report.pdf
            b_same_entity_II,
            torch.clip(
                sym_id.unsqueeze(-1) - sym_id.unsqueeze(-2) + self.s_max,
                0,
                2 * self.s_max,
            ),
            2 * self.s_max + 1,
        )
        A_relchain_II = one_hot(d_chain_II.long(), 2 * self.s_max + 2)

        return self.linear(
            torch.cat(
                [
                    A_relpos_II,
                    A_reltoken_II,
                    b_same_entity_II.unsqueeze(-1),
                    A_relchain_II,
                ],
                dim=-1,
            ).to(torch.float)
        )
|
|
574
|
+
|
|
575
|
+
|
|
576
|
+
class MSAModule(nn.Module):
|
|
577
|
+
def __init__(
|
|
578
|
+
self,
|
|
579
|
+
n_block,
|
|
580
|
+
c_m,
|
|
581
|
+
p_drop_msa,
|
|
582
|
+
p_drop_pair,
|
|
583
|
+
msa_subsample_embedder,
|
|
584
|
+
outer_product,
|
|
585
|
+
msa_pair_weighted_averaging,
|
|
586
|
+
msa_transition,
|
|
587
|
+
triangle_multiplication_outgoing,
|
|
588
|
+
triangle_multiplication_incoming,
|
|
589
|
+
triangle_attention_starting,
|
|
590
|
+
triangle_attention_ending,
|
|
591
|
+
pair_transition,
|
|
592
|
+
):
|
|
593
|
+
super().__init__()
|
|
594
|
+
self.n_block = n_block
|
|
595
|
+
self.msa_subsampler = MSASubsampleEmbedder(**msa_subsample_embedder)
|
|
596
|
+
self.outer_product = OuterProductMean_AF3(**outer_product)
|
|
597
|
+
self.msa_pair_weighted_averaging = MSAPairWeightedAverage(
|
|
598
|
+
**msa_pair_weighted_averaging
|
|
599
|
+
)
|
|
600
|
+
self.msa_transition = Transition(**msa_transition)
|
|
601
|
+
|
|
602
|
+
self.drop_row_msa = Dropout(broadcast_dim=-2, p_drop=p_drop_msa)
|
|
603
|
+
self.drop_row_pair = Dropout(broadcast_dim=-2, p_drop=p_drop_pair)
|
|
604
|
+
self.drop_col_pair = Dropout(broadcast_dim=-3, p_drop=p_drop_pair)
|
|
605
|
+
|
|
606
|
+
self.tri_mult_outgoing = TriangleMultiplication(
|
|
607
|
+
d_pair=triangle_multiplication_outgoing["d_pair"],
|
|
608
|
+
d_hidden=triangle_multiplication_outgoing["d_hidden"],
|
|
609
|
+
direction="outgoing",
|
|
610
|
+
bias=True,
|
|
611
|
+
use_cuequivariance=True,
|
|
612
|
+
)
|
|
613
|
+
self.tri_mult_incoming = TriangleMultiplication(
|
|
614
|
+
d_pair=triangle_multiplication_incoming["d_pair"],
|
|
615
|
+
d_hidden=triangle_multiplication_incoming["d_hidden"],
|
|
616
|
+
direction="incoming",
|
|
617
|
+
bias=True,
|
|
618
|
+
use_cuequivariance=True,
|
|
619
|
+
)
|
|
620
|
+
self.tri_attn_start = TriangleAttention(
|
|
621
|
+
**triangle_attention_starting, start_node=True, use_cuequivariance=True
|
|
622
|
+
)
|
|
623
|
+
self.tri_attn_end = TriangleAttention(
|
|
624
|
+
**triangle_attention_ending, start_node=False, use_cuequivariance=True
|
|
625
|
+
)
|
|
626
|
+
self.pair_transition = Transition(**pair_transition)
|
|
627
|
+
|
|
628
|
+
outer_product_expected_dim = 4 # B, S, I, C
|
|
629
|
+
self.maybe_make_batched_outer_product = create_batch_dimension_if_not_present(
|
|
630
|
+
outer_product_expected_dim
|
|
631
|
+
)
|
|
632
|
+
|
|
633
|
+
triangle_ops_expected_dim = 4 # B, I, I, C
|
|
634
|
+
self.maybe_make_batched_triangle_ops = create_batch_dimension_if_not_present(
|
|
635
|
+
triangle_ops_expected_dim
|
|
636
|
+
)
|
|
637
|
+
|
|
638
|
+
@activation_checkpointing
def forward(
    self,
    f,
    Z_II,
    S_inputs_I,
):
    """Refine the pair representation with ``self.n_block`` MSA-module passes.

    Each pass:
      1. feeds the MSA stream into the pair stream (outer product),
      2. updates the MSA stream from the pair stream (pair-weighted averaging,
         then a transition),
      3. applies the four triangle updates and a transition to the pair stream.
    Every pass reuses the same submodule instances.

    Args:
        f: feature dict; only ``f["msa"]`` is read here.
        Z_II: pair representation, updated residually.
        S_inputs_I: single input features handed to the MSA subsampler
            together with ``f["msa"]``.

    Returns:
        The updated pair representation. The MSA stream is not returned.
    """
    msa_SI = self.msa_subsampler(f["msa"], S_inputs_I)

    # Residual pair updates, applied in this order as (dropout, triangle op).
    # Only the ending-node attention uses the column-broadcast dropout.
    pair_updates = (
        (self.drop_row_pair, self.tri_mult_outgoing),
        (self.drop_row_pair, self.tri_mult_incoming),
        (self.drop_row_pair, self.tri_attn_start),
        (self.drop_col_pair, self.tri_attn_end),
    )

    for _ in range(self.n_block):
        # MSA -> pair communication
        outer = self.maybe_make_batched_outer_product(self.outer_product)(msa_SI)
        Z_II = Z_II + outer

        # pair -> MSA communication, followed by the MSA transition
        averaged = self.msa_pair_weighted_averaging(msa_SI, Z_II)
        msa_SI = msa_SI + self.drop_row_msa(averaged)
        msa_SI = msa_SI + self.msa_transition(msa_SI)

        # pair stream: triangle updates, then the pair transition
        for dropout, tri_op in pair_updates:
            Z_II = Z_II + dropout(self.maybe_make_batched_triangle_ops(tri_op)(Z_II))
        Z_II = Z_II + self.pair_transition(Z_II)

    return Z_II
|
|
675
|
+
|
|
676
|
+
|
|
677
|
+
class RF3TemplateEmbedder(nn.Module):
    """
    Template track that enables conditioning on noisy ground-truth templates at the token level.
    Supports all chain types.

    The template features are embedded and added to a projection of the current pair
    representation. The features are:
      - a distogram condition,
      - a mask of which token pairs carry a condition,
      - a per-pair noise level.
    The sum is refined by a stack of pair-only ``PairformerBlock``s, and the
    layer-normed block outputs are accumulated. The result is projected back to
    ``c_z`` channels so it can be added to ``Z_II``.

    Args:
        n_block: number of PairformerBlocks in the template pairformer.
        raw_template_dim: channel count of the template features built in ``forward``
            (distogram bins + 2, i.e. 66 for the 64-bin distogram documented there).
        c_z: channel count of the pair representation ``Z_II``.
        c: hidden channel count of the template track.
        p_drop: dropout probability passed to each PairformerBlock.
        use_fourier_encoding: unused; accepted only for checkpoint/config compatibility.
    """

    def __init__(
        self,
        n_block,
        raw_template_dim,
        c_z,
        c,
        p_drop,
        use_fourier_encoding: bool = False,  # HACK: Unused, kept for backwards compatibility with 9/21 checkpoint
    ):
        super().__init__()
        self.c = c
        self.emb_pair = nn.Linear(c_z, c, bias=False)
        self.norm_pair_before_pairformer = nn.LayerNorm(c_z)
        self.norm_after_pairformer = nn.LayerNorm(c)
        self.emb_templ = nn.Linear(raw_template_dim, c, bias=False)

        # template pairformer does not operate on sequence representation (c_s=0)
        self.pairformer = nn.ModuleList(
            [
                PairformerBlock(
                    c_s=0,
                    c_z=c,
                    p_drop=p_drop,
                    triangle_multiplication=dict(d_hidden=c),
                    triangle_attention=dict(d_hidden=c),
                    attention_pair_bias={},
                    n_transition=4,
                )
                for _ in range(n_block)
            ]
        )

        # NOTE: this is not consistent with AF3 paper which outputs this tensor in the template_channel dimension
        # In Algorithm 1, line 9, the outputs of this function are added to the Z_II tensor which has dimensions [B, I, I, C_z]
        # so we make the outputs of this module also have C_z channels
        self.agg_emb = nn.Linear(c, c_z, bias=False)

    def forward(
        self,
        f,
        Z_II,
    ):
        """Embed the noisy ground-truth template features in ``f``, conditioned on ``Z_II``.

        Args:
            f: feature dict. Reads the following keys:
                - ``"has_distogram_condition"`` ([I, I]),
                - ``"distogram_condition_noise_scale"`` ([I]),
                - ``"distogram_condition"`` ([I, I, 64]).
            Z_II: pair representation with ``c_z`` channels. The per-token indexing
                below treats it as unbatched ([I, I, c_z]).

        Returns:
            Template embedding with ``c_z`` channels, to be added to ``Z_II``.
        """

        @activation_checkpointing
        def embed_templates_like_rf3(
            has_distogram_condition,  # [I, I]
            distogram_condition_noise_scale,  # [I]
            distogram_condition,  # [I, I, 64], where 64 is the number of distogram bins
        ):
            # Combine per-token noise scales into a per-pair scale sqrt(s_i^2 + s_j^2),
            # then convert it with af3_noise_scale_to_noise_level.
            joint_noise_scale = (
                distogram_condition_noise_scale[None, :] ** 2
                + distogram_condition_noise_scale[:, None] ** 2
            ).sqrt()  # [I, I]
            joint_noise_level = af3_noise_scale_to_noise_level(joint_noise_scale)

            # ---------------------------- #

            # ... concatenate along the channel dimension
            template_feats = torch.cat(
                [
                    distogram_condition,  # [I, I, 64]
                    has_distogram_condition.unsqueeze(-1),  # [I, I, 1]
                    joint_noise_level.unsqueeze(-1),  # [I, I, 1]
                ],
                dim=-1,
            )  # [I, I, 66], where 66 = 64 + 1 + 1

            # ... zero out every feature of pairs that carry no condition
            template_feats = template_feats * has_distogram_condition.unsqueeze(-1)

            # ... embed template features
            template_channels = self.emb_templ(template_feats)  # [I, I, c]

            # ---------------------------- #

            # ... pass through pairformer
            v_II = (
                self.emb_pair(self.norm_pair_before_pairformer(Z_II))
                + template_channels
            )  # [I, I, c]

            # Accumulator takes v_II's shape, dtype and device. A default-dtype
            # (float32) buffer would promote the sum to float32 and make the
            # reduced-precision agg_emb projection fail when the module runs in bf16/fp16.
            u_II = torch.zeros_like(v_II)
            for block in self.pairformer:
                # Sequence input is None: the blocks are built with c_s=0.
                _, v_II = block(None, v_II)
                u_II = u_II + self.norm_after_pairformer(v_II)

            return self.agg_emb(relu(u_II))

        # Ground-truth template embedding (noisy ground-truth template as input)
        embedded_templates = embed_templates_like_rf3(
            has_distogram_condition=f["has_distogram_condition"],  # [I, I]
            distogram_condition_noise_scale=f["distogram_condition_noise_scale"],  # [I]
            distogram_condition=f[
                "distogram_condition"
            ],  # [I, I, 64], where 64 is the number of distogram bins
        )

        return embedded_templates
|