boltz-vsynthes 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- boltz/__init__.py +7 -0
- boltz/data/__init__.py +0 -0
- boltz/data/const.py +1184 -0
- boltz/data/crop/__init__.py +0 -0
- boltz/data/crop/affinity.py +164 -0
- boltz/data/crop/boltz.py +296 -0
- boltz/data/crop/cropper.py +45 -0
- boltz/data/feature/__init__.py +0 -0
- boltz/data/feature/featurizer.py +1230 -0
- boltz/data/feature/featurizerv2.py +2208 -0
- boltz/data/feature/symmetry.py +602 -0
- boltz/data/filter/__init__.py +0 -0
- boltz/data/filter/dynamic/__init__.py +0 -0
- boltz/data/filter/dynamic/date.py +76 -0
- boltz/data/filter/dynamic/filter.py +24 -0
- boltz/data/filter/dynamic/max_residues.py +37 -0
- boltz/data/filter/dynamic/resolution.py +34 -0
- boltz/data/filter/dynamic/size.py +38 -0
- boltz/data/filter/dynamic/subset.py +42 -0
- boltz/data/filter/static/__init__.py +0 -0
- boltz/data/filter/static/filter.py +26 -0
- boltz/data/filter/static/ligand.py +37 -0
- boltz/data/filter/static/polymer.py +299 -0
- boltz/data/module/__init__.py +0 -0
- boltz/data/module/inference.py +307 -0
- boltz/data/module/inferencev2.py +429 -0
- boltz/data/module/training.py +684 -0
- boltz/data/module/trainingv2.py +660 -0
- boltz/data/mol.py +900 -0
- boltz/data/msa/__init__.py +0 -0
- boltz/data/msa/mmseqs2.py +235 -0
- boltz/data/pad.py +84 -0
- boltz/data/parse/__init__.py +0 -0
- boltz/data/parse/a3m.py +134 -0
- boltz/data/parse/csv.py +100 -0
- boltz/data/parse/fasta.py +138 -0
- boltz/data/parse/mmcif.py +1239 -0
- boltz/data/parse/mmcif_with_constraints.py +1607 -0
- boltz/data/parse/schema.py +1851 -0
- boltz/data/parse/yaml.py +68 -0
- boltz/data/sample/__init__.py +0 -0
- boltz/data/sample/cluster.py +283 -0
- boltz/data/sample/distillation.py +57 -0
- boltz/data/sample/random.py +39 -0
- boltz/data/sample/sampler.py +49 -0
- boltz/data/tokenize/__init__.py +0 -0
- boltz/data/tokenize/boltz.py +195 -0
- boltz/data/tokenize/boltz2.py +396 -0
- boltz/data/tokenize/tokenizer.py +24 -0
- boltz/data/types.py +777 -0
- boltz/data/write/__init__.py +0 -0
- boltz/data/write/mmcif.py +305 -0
- boltz/data/write/pdb.py +171 -0
- boltz/data/write/utils.py +23 -0
- boltz/data/write/writer.py +330 -0
- boltz/main.py +1292 -0
- boltz/model/__init__.py +0 -0
- boltz/model/layers/__init__.py +0 -0
- boltz/model/layers/attention.py +132 -0
- boltz/model/layers/attentionv2.py +111 -0
- boltz/model/layers/confidence_utils.py +231 -0
- boltz/model/layers/dropout.py +34 -0
- boltz/model/layers/initialize.py +100 -0
- boltz/model/layers/outer_product_mean.py +98 -0
- boltz/model/layers/pair_averaging.py +135 -0
- boltz/model/layers/pairformer.py +337 -0
- boltz/model/layers/relative.py +58 -0
- boltz/model/layers/transition.py +78 -0
- boltz/model/layers/triangular_attention/__init__.py +0 -0
- boltz/model/layers/triangular_attention/attention.py +189 -0
- boltz/model/layers/triangular_attention/primitives.py +409 -0
- boltz/model/layers/triangular_attention/utils.py +380 -0
- boltz/model/layers/triangular_mult.py +212 -0
- boltz/model/loss/__init__.py +0 -0
- boltz/model/loss/bfactor.py +49 -0
- boltz/model/loss/confidence.py +590 -0
- boltz/model/loss/confidencev2.py +621 -0
- boltz/model/loss/diffusion.py +171 -0
- boltz/model/loss/diffusionv2.py +134 -0
- boltz/model/loss/distogram.py +48 -0
- boltz/model/loss/distogramv2.py +105 -0
- boltz/model/loss/validation.py +1025 -0
- boltz/model/models/__init__.py +0 -0
- boltz/model/models/boltz1.py +1286 -0
- boltz/model/models/boltz2.py +1249 -0
- boltz/model/modules/__init__.py +0 -0
- boltz/model/modules/affinity.py +223 -0
- boltz/model/modules/confidence.py +481 -0
- boltz/model/modules/confidence_utils.py +181 -0
- boltz/model/modules/confidencev2.py +495 -0
- boltz/model/modules/diffusion.py +844 -0
- boltz/model/modules/diffusion_conditioning.py +116 -0
- boltz/model/modules/diffusionv2.py +677 -0
- boltz/model/modules/encoders.py +639 -0
- boltz/model/modules/encodersv2.py +565 -0
- boltz/model/modules/transformers.py +322 -0
- boltz/model/modules/transformersv2.py +261 -0
- boltz/model/modules/trunk.py +688 -0
- boltz/model/modules/trunkv2.py +828 -0
- boltz/model/modules/utils.py +303 -0
- boltz/model/optim/__init__.py +0 -0
- boltz/model/optim/ema.py +389 -0
- boltz/model/optim/scheduler.py +99 -0
- boltz/model/potentials/__init__.py +0 -0
- boltz/model/potentials/potentials.py +497 -0
- boltz/model/potentials/schedules.py +32 -0
- boltz_vsynthes-1.0.0.dist-info/METADATA +151 -0
- boltz_vsynthes-1.0.0.dist-info/RECORD +112 -0
- boltz_vsynthes-1.0.0.dist-info/WHEEL +5 -0
- boltz_vsynthes-1.0.0.dist-info/entry_points.txt +2 -0
- boltz_vsynthes-1.0.0.dist-info/licenses/LICENSE +21 -0
- boltz_vsynthes-1.0.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,565 @@
|
|
1
|
+
# started from code from https://github.com/lucidrains/alphafold3-pytorch, MIT License, Copyright (c) 2024 Phil Wang
|
2
|
+
from functools import partial
|
3
|
+
from math import pi
|
4
|
+
|
5
|
+
import torch
|
6
|
+
from einops import rearrange
|
7
|
+
from torch import nn
|
8
|
+
from torch.nn import Linear, Module, ModuleList
|
9
|
+
from torch.nn.functional import one_hot
|
10
|
+
|
11
|
+
import boltz.model.layers.initialize as init
|
12
|
+
from boltz.model.layers.transition import Transition
|
13
|
+
from boltz.model.modules.transformersv2 import AtomTransformer
|
14
|
+
from boltz.model.modules.utils import LinearNoBias
|
15
|
+
|
16
|
+
|
17
|
+
class FourierEmbedding(Module):
    """Algorithm 22: fixed random Fourier features of a scalar.

    A frozen ``Linear(1, dim)`` layer, whose weight and bias are both drawn
    from a standard normal distribution, maps each scalar to ``dim`` phases;
    the embedding is ``cos(2 * pi * phase)``.
    """

    def __init__(self, dim):
        super().__init__()
        projection = nn.Linear(1, dim)
        # Weight first, then bias, so the random draws happen in a fixed order.
        for tensor in (projection.weight, projection.bias):
            torch.nn.init.normal_(tensor, mean=0, std=1)
        # The projection is random but never trained.
        projection.requires_grad_(False)
        self.proj = projection

    def forward(
        self,
        times,  # Float[' b'],
    ):  # -> Float['b d']:
        """Embed a 1-D batch of scalars into ``dim`` cosine features."""
        column = rearrange(times, "b -> b 1")
        phases = self.proj(column)
        return torch.cos(2 * pi * phases)
|
34
|
+
|
35
|
+
|
36
|
+
class RelativePositionEncoder(Module):
    """Algorithm 3: pairwise relative-position features projected to ``token_z``.

    For every token pair the following one-hot / binary features are
    concatenated and passed through a bias-free linear layer:

    * clipped residue-index offset (``2 * r_max + 2`` classes),
    * clipped token-index offset (``2 * r_max + 2`` classes),
    * same-entity flag (1 value),
    * clipped ``sym_id`` offset (``2 * s_max + 2`` classes).

    In each one-hot block the last class is an "out of scope" bucket.
    """

    def __init__(
        self, token_z, r_max=32, s_max=2, fix_sym_check=False, cyclic_pos_enc=False
    ):
        super().__init__()
        self.r_max = r_max
        self.s_max = s_max
        # 2 * (2 * r_max + 2) + 1 + (2 * s_max + 2) input features.
        n_features = 4 * (r_max + 1) + 2 * (s_max + 1) + 1
        self.linear_layer = LinearNoBias(n_features, token_z)
        self.fix_sym_check = fix_sym_check
        self.cyclic_pos_enc = cyclic_pos_enc

    def forward(self, feats):
        """Return the projected relative-position encoding for all token pairs.

        Reads ``asym_id``, ``residue_index``, ``entity_id``, ``token_index``
        and ``sym_id`` from ``feats`` (plus ``cyclic_period`` when cyclic
        encoding is enabled); each is indexed as a ``(batch, tokens)`` tensor.
        """
        asym_id = feats["asym_id"]
        residue_index = feats["residue_index"]
        entity_id = feats["entity_id"]

        same_chain = torch.eq(asym_id[:, :, None], asym_id[:, None, :])
        same_residue = torch.eq(residue_index[:, :, None], residue_index[:, None, :])
        same_entity = torch.eq(entity_id[:, :, None], entity_id[:, None, :])

        rel_max = 2 * self.r_max
        sym_max = 2 * self.s_max

        # --- residue offset ------------------------------------------------
        residue_offset = residue_index[:, :, None] - residue_index[:, None, :]

        if self.cyclic_pos_enc and torch.any(feats["cyclic_period"] > 0):
            cyclic_period = feats["cyclic_period"]
            # Non-cyclic entries get a large period so wrapping is a no-op.
            # NOTE(review): `period` is broadcast against the (b, n, n) offsets
            # as-is; confirm the shape of feats["cyclic_period"] when b > 1.
            period = torch.where(
                cyclic_period > 0,
                cyclic_period,
                torch.full_like(cyclic_period, 10000),
            )
            wraps = torch.round(residue_offset / period)
            residue_offset = (residue_offset - period * wraps).long()

        residue_offset = torch.clip(residue_offset + self.r_max, 0, rel_max)
        # Pairs on different chains fall into the extra bucket.
        residue_offset = torch.where(
            same_chain,
            residue_offset,
            torch.full_like(residue_offset, rel_max + 1),
        )
        a_rel_pos = one_hot(residue_offset, rel_max + 2)

        # --- token offset --------------------------------------------------
        token_index = feats["token_index"]
        token_offset = torch.clip(
            token_index[:, :, None] - token_index[:, None, :] + self.r_max,
            0,
            rel_max,
        )
        # Only meaningful for tokens of the same residue on the same chain.
        token_offset = torch.where(
            same_chain & same_residue,
            token_offset,
            torch.full_like(token_offset, rel_max + 1),
        )
        a_rel_token = one_hot(token_offset, rel_max + 2)

        # --- chain (symmetry copy) offset ------------------------------------
        sym_id = feats["sym_id"]
        chain_offset = torch.clip(
            sym_id[:, :, None] - sym_id[:, None, :] + self.s_max,
            0,
            sym_max,
        )
        # Pairs selected by this mask are sent to the extra bucket: different
        # entities when `fix_sym_check` is set, otherwise same-chain pairs.
        if self.fix_sym_check:
            out_of_scope = ~same_entity
        else:
            out_of_scope = same_chain
        chain_offset = torch.where(
            out_of_scope,
            torch.full_like(chain_offset, sym_max + 1),
            chain_offset,
        )
        a_rel_chain = one_hot(chain_offset, sym_max + 2)

        pair_features = torch.cat(
            [
                a_rel_pos.float(),
                a_rel_token.float(),
                same_entity.unsqueeze(-1).float(),
                a_rel_chain.float(),
            ],
            dim=-1,
        )
        return self.linear_layer(pair_features)
|
121
|
+
|
122
|
+
|
123
|
+
class SingleConditioning(Module):
    """Algorithm 21: build the single-representation conditioning signal.

    The trunk and input single representations are concatenated, normalised
    and linearly embedded. Unless ``disable_times`` is set, a Fourier
    embedding of ``times`` is normalised, projected and added to every token.
    A stack of residual ``Transition`` blocks is applied last.
    """

    def __init__(
        self,
        sigma_data: float,
        token_s: int = 384,
        dim_fourier: int = 256,
        num_transitions: int = 2,
        transition_expansion_factor: int = 2,
        eps: float = 1e-20,
        disable_times: bool = False,
    ) -> None:
        super().__init__()
        self.eps = eps
        self.sigma_data = sigma_data
        self.disable_times = disable_times

        # Width of the concatenated (trunk, inputs) single representation.
        width = 2 * token_s

        self.norm_single = nn.LayerNorm(width)
        self.single_embed = nn.Linear(width, width)
        if not disable_times:
            self.fourier_embed = FourierEmbedding(dim_fourier)
            self.norm_fourier = nn.LayerNorm(dim_fourier)
            self.fourier_to_single = LinearNoBias(dim_fourier, width)

        self.transitions = ModuleList(
            [
                Transition(dim=width, hidden=transition_expansion_factor * width)
                for _ in range(num_transitions)
            ]
        )

    def forward(
        self,
        times,  # Float[' b'],
        s_trunk,  # Float['b n ts'],
        s_inputs,  # Float['b n ts'],
    ):  # -> Float['b n 2ts']:
        """Return ``(s, normed_fourier)``.

        ``normed_fourier`` is the layer-normalised Fourier embedding of
        ``times``, or ``None`` when the module was built with
        ``disable_times=True``.
        """
        joined = torch.cat((s_trunk, s_inputs), dim=-1)
        s = self.single_embed(self.norm_single(joined))

        normed_fourier = None
        if not self.disable_times:
            # note: sigma rescaling done in diffusion module
            normed_fourier = self.norm_fourier(self.fourier_embed(times))
            time_term = self.fourier_to_single(normed_fourier)
            # Broadcast the per-sample time term over the token axis.
            s = rearrange(time_term, "b d -> b 1 d") + s

        for block in self.transitions:
            s = block(s) + s

        return s, normed_fourier
|
178
|
+
|
179
|
+
|
180
|
+
class PairwiseConditioning(Module):
    """Algorithm 21: build the pair-representation conditioning signal.

    The trunk pair representation is concatenated with the relative-position
    features along the channel axis, layer-normalised, projected back to
    ``token_z`` channels, and refined by residual ``Transition`` blocks.
    """

    def __init__(
        self,
        token_z,
        dim_token_rel_pos_feats,
        num_transitions=2,
        transition_expansion_factor=2,
    ):
        super().__init__()

        joined_dim = token_z + dim_token_rel_pos_feats
        self.dim_pairwise_init_proj = nn.Sequential(
            nn.LayerNorm(joined_dim),
            LinearNoBias(joined_dim, token_z),
        )

        self.transitions = ModuleList(
            [
                Transition(dim=token_z, hidden=transition_expansion_factor * token_z)
                for _ in range(num_transitions)
            ]
        )

    def forward(
        self,
        z_trunk,  # Float['b n n tz'],
        token_rel_pos_feats,  # Float['b n n dim_token_rel_pos_feats'],
    ):  # -> Float['b n n tz']:
        """Return the conditioned pair representation."""
        joined = torch.cat((z_trunk, token_rel_pos_feats), dim=-1)
        z = self.dim_pairwise_init_proj(joined)

        for block in self.transitions:
            z = block(z) + z

        return z
|
218
|
+
|
219
|
+
|
220
|
+
def get_indexing_matrix(K, W, H, device):
    """Build the 0/1 matrix that gathers key windows from half-window blocks.

    The sequence is viewed as ``2 * K`` blocks of ``W // 2`` positions. Each of
    the ``K`` query windows gets ``h = H // (W // 2)`` key blocks centred on
    it; blocks that would fall outside the sequence map to an all-zero column.

    Args:
        K: number of query windows.
        W: query window size (must be even).
        H: key window size (a multiple of ``W // 2``, with an even quotient).
        device: device on which the matrix is created.

    Returns:
        Float tensor of shape ``(2 * K, h * K)``.
    """
    assert W % 2 == 0
    half_window = W // 2
    assert H % half_window == 0

    h = H // half_window
    assert h % 2 == 0

    blocks = torch.arange(2 * K, device=device)
    # shifted[i, j] = (j - i) + h // 2, clamped so that out-of-range offsets
    # land in the two sentinel classes 0 and h + 1.
    shifted = (blocks[None, :] - blocks[:, None]) + h // 2
    shifted = shifted.clamp(min=0, max=h + 1)
    # One row per query window: every second block row.
    per_window = shifted[0::2]
    # Drop the two sentinel classes, leaving h real key slots per window.
    slots = one_hot(per_window, num_classes=h + 2)[..., 1:-1]
    return slots.transpose(1, 0).reshape(2 * K, h * K).float()
|
234
|
+
|
235
|
+
|
236
|
+
def single_to_keys(single, indexing_matrix, W, H):
    """Gather per-position features into per-window key blocks.

    Args:
        single: tensor of shape ``(B, N, D)``; must be viewable as
            ``(B, 2 * (N // W), W // 2, D)``.
        indexing_matrix: matrix of shape ``(2 * K, h * K)`` with
            ``K = N // W`` and ``h = H // (W // 2)``.
        W: query window size.
        H: key window size.

    Returns:
        Tensor of shape ``(B, K, H, D)``.
    """
    n_batch, n_positions, n_channels = single.shape
    n_windows = n_positions // W
    # j = 2K half-window blocks, i = W // 2 positions per block.
    half_blocks = single.view(n_batch, 2 * n_windows, W // 2, n_channels)
    # k = h * K selected key blocks.
    gathered = torch.einsum(
        "b j i d, j k -> b k i d", half_blocks, indexing_matrix
    )
    return gathered.reshape(n_batch, n_windows, H, n_channels)
|
243
|
+
|
244
|
+
|
245
|
+
class AtomEncoder(Module):
    """Embed per-atom reference features and windowed atom-pair features.

    ``forward`` returns ``(q, c, p, to_keys)``:

    * ``q`` - the linear embedding of the concatenated atom features,
    * ``c`` - the same embedding, plus a projection of ``s_trunk`` when
      ``structure_prediction`` is enabled,
    * ``p`` - pair features for every (query window, key window) atom pair,
    * ``to_keys`` - a callable that gathers a ``(B, N, D)`` tensor into key
      windows, reusable by downstream modules.
    """

    def __init__(
        self,
        atom_s,
        atom_z,
        token_s,
        token_z,
        atoms_per_window_queries,
        atoms_per_window_keys,
        atom_feature_dim,
        structure_prediction=True,
        use_no_atom_char=False,
        use_atom_backbone_feat=False,
        use_residue_feats_atoms=False,
    ):
        super().__init__()

        # Submodules are created in a fixed order so parameter registration
        # (and any random initialisation) stays stable.
        self.embed_atom_features = Linear(atom_feature_dim, atom_s)
        self.embed_atompair_ref_pos = LinearNoBias(3, atom_z)
        self.embed_atompair_ref_dist = LinearNoBias(1, atom_z)
        self.embed_atompair_mask = LinearNoBias(1, atom_z)

        self.atoms_per_window_queries = atoms_per_window_queries
        self.atoms_per_window_keys = atoms_per_window_keys
        self.use_no_atom_char = use_no_atom_char
        self.use_atom_backbone_feat = use_atom_backbone_feat
        self.use_residue_feats_atoms = use_residue_feats_atoms

        self.structure_prediction = structure_prediction
        if structure_prediction:
            self.s_to_c_trans = self._norm_then_project(token_s, atom_s)
            self.z_to_p_trans = self._norm_then_project(token_z, atom_z)

        self.c_to_p_trans_k = self._relu_then_project(atom_s, atom_z)
        self.c_to_p_trans_q = self._relu_then_project(atom_s, atom_z)

        # Three (ReLU, bias-free linear) stages; only the last linear layer
        # receives the "final" initialisation.
        mlp_layers = []
        for _ in range(3):
            mlp_layers.extend((nn.ReLU(), LinearNoBias(atom_z, atom_z)))
        self.p_mlp = nn.Sequential(*mlp_layers)
        init.final_init_(self.p_mlp[-1].weight)

    @staticmethod
    def _norm_then_project(dim_in, dim_out):
        """LayerNorm followed by a final-initialised bias-free projection."""
        block = nn.Sequential(nn.LayerNorm(dim_in), LinearNoBias(dim_in, dim_out))
        init.final_init_(block[1].weight)
        return block

    @staticmethod
    def _relu_then_project(dim_in, dim_out):
        """ReLU followed by a final-initialised bias-free projection."""
        block = nn.Sequential(nn.ReLU(), LinearNoBias(dim_in, dim_out))
        init.final_init_(block[1].weight)
        return block

    def forward(
        self,
        feats,
        s_trunk=None,  # Float['bm n ts'],
        z=None,  # Float['bm n n tz'],
    ):
        """Compute atom embeddings and windowed pair features.

        ``s_trunk`` and ``z`` are only read when the module was built with
        ``structure_prediction=True``. The atom count must be divisible by
        the query window size because the tensors are reshaped into windows
        with ``view``.
        """
        # Everything here runs with CUDA autocast disabled.
        with torch.autocast("cuda", enabled=False):
            ref_pos = feats["ref_pos"]  # Float['b m 3'],
            n_batch, n_atoms, _ = ref_pos.shape
            pad_mask = feats["atom_pad_mask"].bool()  # Bool['b m'],
            space_uid = feats["ref_space_uid"]  # Long['b m'],

            per_atom = [
                ref_pos,
                feats["ref_charge"].unsqueeze(-1),
                feats["ref_element"],
            ]
            if not self.use_no_atom_char:
                per_atom.append(
                    feats["ref_atom_name_chars"].reshape(n_batch, n_atoms, 4 * 64)
                )
            if self.use_atom_backbone_feat:
                per_atom.append(feats["atom_backbone_feat"])
            if self.use_residue_feats_atoms:
                token_feats = torch.cat(
                    [
                        feats["res_type"],
                        feats["modified"].unsqueeze(-1),
                        one_hot(feats["mol_type"], num_classes=4).float(),
                    ],
                    dim=-1,
                )
                # Broadcast token-level features to the atoms of each token.
                token_to_atom = feats["atom_to_token"].float()
                per_atom.append(torch.bmm(token_to_atom, token_feats))

            c = self.embed_atom_features(torch.cat(per_atom, dim=-1))

            # note we are already creating the windows to make it more efficient
            win_q = self.atoms_per_window_queries
            win_k = self.atoms_per_window_keys
            n_windows = n_atoms // win_q
            to_keys = partial(
                single_to_keys,
                indexing_matrix=get_indexing_matrix(
                    n_windows, win_q, win_k, c.device
                ),
                W=win_q,
                H=win_k,
            )

            pos_q = ref_pos.view(n_batch, n_windows, win_q, 1, 3)
            pos_k = to_keys(ref_pos).view(n_batch, n_windows, 1, win_k, 3)

            offsets = pos_k - pos_q  # Float['b k w h 3']
            sq_dist = torch.sum(offsets * offsets, dim=-1, keepdim=True)
            # 1 / (1 + squared distance), Float['b k w h 1']
            inv_dist = 1 / (1 + sq_dist)

            mask_q = pad_mask.view(n_batch, n_windows, win_q, 1)
            mask_k = (
                to_keys(pad_mask.unsqueeze(-1).float())
                .view(n_batch, n_windows, 1, win_k)
                .bool()
            )
            uid_q = space_uid.view(n_batch, n_windows, win_q, 1)
            uid_k = (
                to_keys(space_uid.unsqueeze(-1).float())
                .view(n_batch, n_windows, 1, win_k)
                .long()
            )
            # A pair is valid when both atoms are real and share a ref_space_uid.
            valid = mask_q & mask_k & (uid_q == uid_k)
            v = valid.float().unsqueeze(-1)  # Float['b k w h 1']

            p = (
                self.embed_atompair_ref_pos(offsets) * v
                + self.embed_atompair_ref_dist(inv_dist) * v
                + self.embed_atompair_mask(v) * v
            )

            # q keeps the embedding before any trunk conditioning is added.
            q = c

            if self.structure_prediction:
                # run only in structure model not in initial encoding
                atom_to_token = feats["atom_to_token"].float()  # Long['b m n'],

                s_to_c = torch.bmm(
                    atom_to_token, self.s_to_c_trans(s_trunk.float())
                )
                c = c + s_to_c.to(c)

                atom_to_token_q = atom_to_token.view(
                    n_batch, n_windows, win_q, atom_to_token.shape[-1]
                )
                atom_to_token_k = to_keys(atom_to_token)
                z_to_p = torch.einsum(
                    "bijd,bwki,bwlj->bwkld",
                    self.z_to_p_trans(z.float()),
                    atom_to_token_q,
                    atom_to_token_k,
                )
                p = p + z_to_p.to(p)

            n_channels = c.shape[-1]
            p = p + self.c_to_p_trans_q(
                c.view(n_batch, n_windows, win_q, 1, n_channels)
            )
            p = p + self.c_to_p_trans_k(
                to_keys(c).view(n_batch, n_windows, 1, win_k, n_channels)
            )
            p = p + self.p_mlp(p)
        return q, c, p, to_keys
|
412
|
+
|
413
|
+
|
414
|
+
class AtomAttentionEncoder(Module):
    """Run the windowed atom transformer and pool atoms into tokens.

    When ``structure_prediction`` is enabled, the atom queries are additionally
    conditioned on the coordinates ``r`` and the token output has
    ``2 * token_s`` channels; otherwise it has ``token_s`` channels.
    """

    def __init__(
        self,
        atom_s,
        token_s,
        atoms_per_window_queries,
        atoms_per_window_keys,
        atom_encoder_depth=3,
        atom_encoder_heads=4,
        structure_prediction=True,
        activation_checkpointing=False,
        transformer_post_layer_norm=False,
    ):
        super().__init__()

        self.structure_prediction = structure_prediction
        if structure_prediction:
            self.r_to_q_trans = LinearNoBias(3, atom_s)
            init.final_init_(self.r_to_q_trans.weight)

        self.atom_encoder = AtomTransformer(
            dim=atom_s,
            dim_single_cond=atom_s,
            attn_window_queries=atoms_per_window_queries,
            attn_window_keys=atoms_per_window_keys,
            depth=atom_encoder_depth,
            heads=atom_encoder_heads,
            activation_checkpointing=activation_checkpointing,
            post_layer_norm=transformer_post_layer_norm,
        )

        self.atom_to_token_trans = nn.Sequential(
            LinearNoBias(atom_s, 2 * token_s if structure_prediction else token_s),
            nn.ReLU(),
        )

    def forward(
        self,
        feats,
        q,
        c,
        atom_enc_bias,
        to_keys,
        r=None,  # Float['bm m 3'],
        multiplicity=1,
    ):
        """Encode atoms and aggregate them into token features.

        Args:
            feats: feature dict; ``atom_pad_mask`` and ``atom_to_token`` are read.
            q: atom query features.
            c: atom conditioning features.
            atom_enc_bias: attention bias forwarded to the atom transformer.
            to_keys: callable gathering atom tensors into key windows.
            r: atom coordinates; required when ``structure_prediction`` is set.
            multiplicity: number of repeats of each batch element along dim 0.

        Returns:
            ``(a, q, c, to_keys)`` where ``a`` is the per-token mean of the
            projected atom features, cast to the dtype/device of ``q``.
        """
        atom_mask = feats["atom_pad_mask"].bool()  # Bool['b m'],

        if self.structure_prediction:
            # only here the multiplicity kicks in because we use the different positions r
            q = q.repeat_interleave(multiplicity, 0)
            q = q + self.r_to_q_trans(r)

        c = c.repeat_interleave(multiplicity, 0)
        atom_mask = atom_mask.repeat_interleave(multiplicity, 0)

        q = self.atom_encoder(
            q=q,
            mask=atom_mask,
            c=c,
            bias=atom_enc_bias,
            multiplicity=multiplicity,
            to_keys=to_keys,
        )

        # The pooling runs with CUDA autocast disabled, on float32 tensors.
        with torch.autocast("cuda", enabled=False):
            q_to_a = self.atom_to_token_trans(q).float()
            atom_to_token = feats["atom_to_token"].float()
            atom_to_token = atom_to_token.repeat_interleave(multiplicity, 0)
            # Normalise each token column by its atom count; the epsilon
            # avoids division by zero for tokens with no atoms.
            atom_to_token_mean = atom_to_token / (
                atom_to_token.sum(dim=1, keepdim=True) + 1e-6
            )
            a = torch.bmm(atom_to_token_mean.transpose(1, 2), q_to_a)

        a = a.to(q)

        return a, q, c, to_keys
|
493
|
+
|
494
|
+
|
495
|
+
class AtomAttentionDecoder(Module):
    """Algorithm 6: broadcast token features to atoms and predict position updates.

    Token activations are projected to the atom width, scattered to atoms via
    ``atom_to_token`` and added to the atom queries. The atom transformer is
    then run and its output mapped to a 3-vector per atom.
    """

    def __init__(
        self,
        atom_s,
        token_s,
        attn_window_queries,
        attn_window_keys,
        atom_decoder_depth=3,
        atom_decoder_heads=4,
        activation_checkpointing=False,
        transformer_post_layer_norm=False,
    ):
        super().__init__()

        self.a_to_q_trans = LinearNoBias(2 * token_s, atom_s)
        init.final_init_(self.a_to_q_trans.weight)

        self.atom_decoder = AtomTransformer(
            dim=atom_s,
            dim_single_cond=atom_s,
            attn_window_queries=attn_window_queries,
            attn_window_keys=attn_window_keys,
            depth=atom_decoder_depth,
            heads=atom_decoder_heads,
            activation_checkpointing=activation_checkpointing,
            post_layer_norm=transformer_post_layer_norm,
        )

        # The output head is a bare projection when the transformer is built
        # with post_layer_norm=True; otherwise a LayerNorm is placed in front.
        position_head = LinearNoBias(atom_s, 3)
        init.final_init_(position_head.weight)
        if transformer_post_layer_norm:
            self.atom_feat_to_atom_pos_update = position_head
        else:
            self.atom_feat_to_atom_pos_update = nn.Sequential(
                nn.LayerNorm(atom_s), position_head
            )

    def forward(
        self,
        a,  # Float['bm n 2ts'],
        q,  # Float['bm m as'],
        c,  # Float['bm m as'],
        atom_dec_bias,  # Float['bm m m az'],
        feats,
        to_keys,
        multiplicity=1,
    ):
        """Return the predicted per-atom position update."""
        # The token-to-atom broadcast runs with CUDA autocast disabled.
        with torch.autocast("cuda", enabled=False):
            token_to_atom = feats["atom_to_token"].float()
            token_to_atom = token_to_atom.repeat_interleave(multiplicity, 0)

            projected = self.a_to_q_trans(a.float())
            a_to_q = torch.bmm(token_to_atom, projected)

        q = q + a_to_q.to(q)
        pad_mask = feats["atom_pad_mask"]  # Bool['b m'],
        pad_mask = pad_mask.repeat_interleave(multiplicity, 0)

        q = self.atom_decoder(
            q=q,
            mask=pad_mask,
            c=c,
            bias=atom_dec_bias,
            multiplicity=multiplicity,
            to_keys=to_keys,
        )

        return self.atom_feat_to_atom_pos_update(q)
|