boltz-vsynthes 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- boltz/__init__.py +7 -0
- boltz/data/__init__.py +0 -0
- boltz/data/const.py +1184 -0
- boltz/data/crop/__init__.py +0 -0
- boltz/data/crop/affinity.py +164 -0
- boltz/data/crop/boltz.py +296 -0
- boltz/data/crop/cropper.py +45 -0
- boltz/data/feature/__init__.py +0 -0
- boltz/data/feature/featurizer.py +1230 -0
- boltz/data/feature/featurizerv2.py +2208 -0
- boltz/data/feature/symmetry.py +602 -0
- boltz/data/filter/__init__.py +0 -0
- boltz/data/filter/dynamic/__init__.py +0 -0
- boltz/data/filter/dynamic/date.py +76 -0
- boltz/data/filter/dynamic/filter.py +24 -0
- boltz/data/filter/dynamic/max_residues.py +37 -0
- boltz/data/filter/dynamic/resolution.py +34 -0
- boltz/data/filter/dynamic/size.py +38 -0
- boltz/data/filter/dynamic/subset.py +42 -0
- boltz/data/filter/static/__init__.py +0 -0
- boltz/data/filter/static/filter.py +26 -0
- boltz/data/filter/static/ligand.py +37 -0
- boltz/data/filter/static/polymer.py +299 -0
- boltz/data/module/__init__.py +0 -0
- boltz/data/module/inference.py +307 -0
- boltz/data/module/inferencev2.py +429 -0
- boltz/data/module/training.py +684 -0
- boltz/data/module/trainingv2.py +660 -0
- boltz/data/mol.py +900 -0
- boltz/data/msa/__init__.py +0 -0
- boltz/data/msa/mmseqs2.py +235 -0
- boltz/data/pad.py +84 -0
- boltz/data/parse/__init__.py +0 -0
- boltz/data/parse/a3m.py +134 -0
- boltz/data/parse/csv.py +100 -0
- boltz/data/parse/fasta.py +138 -0
- boltz/data/parse/mmcif.py +1239 -0
- boltz/data/parse/mmcif_with_constraints.py +1607 -0
- boltz/data/parse/schema.py +1851 -0
- boltz/data/parse/yaml.py +68 -0
- boltz/data/sample/__init__.py +0 -0
- boltz/data/sample/cluster.py +283 -0
- boltz/data/sample/distillation.py +57 -0
- boltz/data/sample/random.py +39 -0
- boltz/data/sample/sampler.py +49 -0
- boltz/data/tokenize/__init__.py +0 -0
- boltz/data/tokenize/boltz.py +195 -0
- boltz/data/tokenize/boltz2.py +396 -0
- boltz/data/tokenize/tokenizer.py +24 -0
- boltz/data/types.py +777 -0
- boltz/data/write/__init__.py +0 -0
- boltz/data/write/mmcif.py +305 -0
- boltz/data/write/pdb.py +171 -0
- boltz/data/write/utils.py +23 -0
- boltz/data/write/writer.py +330 -0
- boltz/main.py +1292 -0
- boltz/model/__init__.py +0 -0
- boltz/model/layers/__init__.py +0 -0
- boltz/model/layers/attention.py +132 -0
- boltz/model/layers/attentionv2.py +111 -0
- boltz/model/layers/confidence_utils.py +231 -0
- boltz/model/layers/dropout.py +34 -0
- boltz/model/layers/initialize.py +100 -0
- boltz/model/layers/outer_product_mean.py +98 -0
- boltz/model/layers/pair_averaging.py +135 -0
- boltz/model/layers/pairformer.py +337 -0
- boltz/model/layers/relative.py +58 -0
- boltz/model/layers/transition.py +78 -0
- boltz/model/layers/triangular_attention/__init__.py +0 -0
- boltz/model/layers/triangular_attention/attention.py +189 -0
- boltz/model/layers/triangular_attention/primitives.py +409 -0
- boltz/model/layers/triangular_attention/utils.py +380 -0
- boltz/model/layers/triangular_mult.py +212 -0
- boltz/model/loss/__init__.py +0 -0
- boltz/model/loss/bfactor.py +49 -0
- boltz/model/loss/confidence.py +590 -0
- boltz/model/loss/confidencev2.py +621 -0
- boltz/model/loss/diffusion.py +171 -0
- boltz/model/loss/diffusionv2.py +134 -0
- boltz/model/loss/distogram.py +48 -0
- boltz/model/loss/distogramv2.py +105 -0
- boltz/model/loss/validation.py +1025 -0
- boltz/model/models/__init__.py +0 -0
- boltz/model/models/boltz1.py +1286 -0
- boltz/model/models/boltz2.py +1249 -0
- boltz/model/modules/__init__.py +0 -0
- boltz/model/modules/affinity.py +223 -0
- boltz/model/modules/confidence.py +481 -0
- boltz/model/modules/confidence_utils.py +181 -0
- boltz/model/modules/confidencev2.py +495 -0
- boltz/model/modules/diffusion.py +844 -0
- boltz/model/modules/diffusion_conditioning.py +116 -0
- boltz/model/modules/diffusionv2.py +677 -0
- boltz/model/modules/encoders.py +639 -0
- boltz/model/modules/encodersv2.py +565 -0
- boltz/model/modules/transformers.py +322 -0
- boltz/model/modules/transformersv2.py +261 -0
- boltz/model/modules/trunk.py +688 -0
- boltz/model/modules/trunkv2.py +828 -0
- boltz/model/modules/utils.py +303 -0
- boltz/model/optim/__init__.py +0 -0
- boltz/model/optim/ema.py +389 -0
- boltz/model/optim/scheduler.py +99 -0
- boltz/model/potentials/__init__.py +0 -0
- boltz/model/potentials/potentials.py +497 -0
- boltz/model/potentials/schedules.py +32 -0
- boltz_vsynthes-1.0.0.dist-info/METADATA +151 -0
- boltz_vsynthes-1.0.0.dist-info/RECORD +112 -0
- boltz_vsynthes-1.0.0.dist-info/WHEEL +5 -0
- boltz_vsynthes-1.0.0.dist-info/entry_points.txt +2 -0
- boltz_vsynthes-1.0.0.dist-info/licenses/LICENSE +21 -0
- boltz_vsynthes-1.0.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,639 @@
|
|
1
|
+
# started from code from https://github.com/lucidrains/alphafold3-pytorch, MIT License, Copyright (c) 2024 Phil Wang
|
2
|
+
from functools import partial
|
3
|
+
from math import pi
|
4
|
+
|
5
|
+
import torch
|
6
|
+
from einops import rearrange
|
7
|
+
from torch import nn
|
8
|
+
from torch.nn import Module, ModuleList
|
9
|
+
from torch.nn.functional import one_hot
|
10
|
+
|
11
|
+
import boltz.model.layers.initialize as init
|
12
|
+
from boltz.data import const
|
13
|
+
from boltz.model.layers.transition import Transition
|
14
|
+
from boltz.model.modules.transformers import AtomTransformer
|
15
|
+
from boltz.model.modules.utils import LinearNoBias
|
16
|
+
|
17
|
+
|
18
|
+
class FourierEmbedding(Module):
    """Fourier embedding layer.

    Maps every scalar ``t`` to ``cos(2 * pi * (w * t + b))`` where ``w`` and
    ``b`` are a frozen ``dim``-dimensional random projection drawn from N(0, 1).
    """

    def __init__(self, dim):
        """Initialize the Fourier Embeddings.

        Parameters
        ----------
        dim : int
            The dimension of the embeddings.

        """
        super().__init__()
        self.proj = nn.Linear(1, dim)
        # Re-draw weight and bias from a standard normal and freeze them: the
        # projection is a fixed random feature map, not a trained layer.
        torch.nn.init.normal_(self.proj.weight, mean=0, std=1)
        torch.nn.init.normal_(self.proj.bias, mean=0, std=1)
        self.proj.requires_grad_(False)

    def forward(
        self,
        times,
    ):
        """Embed a tensor of scalar values.

        Parameters
        ----------
        times : torch.Tensor
            Scalars to embed, of any shape ``(...)``. A 1-D ``(b,)`` input
            gives exactly the same result as before this layer accepted
            other shapes.

        Returns
        -------
        torch.Tensor
            Embeddings of shape ``(..., dim)`` with values in ``[-1, 1]``.

        """
        # Add a trailing size-1 feature axis so the (1 -> dim) projection is
        # applied independently to every scalar.
        rand_proj = self.proj(times.unsqueeze(-1))
        return torch.cos(2 * pi * rand_proj)
|
43
|
+
|
44
|
+
|
45
|
+
class RelativePositionEncoder(Module):
    """Relative position encoder.

    Builds pairwise one-hot features from residue, token and chain offsets,
    plus a same-entity flag, and projects them to the pair dimension.
    """

    def __init__(self, token_z, r_max=32, s_max=2):
        """Initialize the relative position encoder.

        Parameters
        ----------
        token_z : int
            The pair representation dimension.
        r_max : int, optional
            The maximum index distance, by default 32.
        s_max : int, optional
            The maximum chain distance, by default 2.

        """
        super().__init__()
        self.r_max = r_max
        self.s_max = s_max
        # Offsets are clipped into [0, 2 * max] and one extra class marks
        # "not comparable", so each one-hot has 2 * max + 2 classes.
        num_pos_classes = 2 * r_max + 2
        num_chain_classes = 2 * s_max + 2
        # residue one-hot + token one-hot + same-entity flag + chain one-hot
        self.linear_layer = LinearNoBias(
            2 * num_pos_classes + 1 + num_chain_classes, token_z
        )

    def forward(self, feats):
        """Compute the relative position pair features.

        Parameters
        ----------
        feats : dict[str, torch.Tensor]
            Reads ``asym_id``, ``residue_index``, ``entity_id``,
            ``token_index``, ``sym_id`` and ``cyclic_period``. Each is assumed
            to be a per-token ``(B, N)`` integer tensor (the ``[:, :, None]``
            / ``[:, None, :]`` indexing builds an all-pairs comparison) —
            TODO confirm with the featurizer.

        Returns
        -------
        torch.Tensor
            Pair features with last dimension ``token_z``.

        """
        r_max, s_max = self.r_max, self.s_max
        pos_overflow = 2 * r_max + 1
        chain_overflow = 2 * s_max + 1

        def pairwise(key, op):
            # Apply ``op`` between every (i, j) pair of the feature ``key``.
            values = feats[key]
            return op(values[:, :, None], values[:, None, :])

        same_chain = pairwise("asym_id", torch.eq)
        same_residue = pairwise("residue_index", torch.eq)
        same_entity = pairwise("entity_id", torch.eq)
        residue_offset = pairwise("residue_index", torch.sub)

        cyclic_period = feats["cyclic_period"]
        if torch.any(cyclic_period != 0):
            # Non-positive periods are replaced by 10000, so for those chains
            # round(offset / 10000) is 0 whenever |offset| < 5000 and the
            # offset is left unchanged. The period of token j is used.
            period = torch.where(
                cyclic_period > 0,
                cyclic_period,
                torch.full_like(cyclic_period, 10000),
            ).unsqueeze(1)
            residue_offset = (
                residue_offset - period * torch.round(residue_offset / period)
            ).long()

        residue_class = torch.clip(residue_offset + r_max, 0, 2 * r_max)
        residue_class = residue_class.masked_fill(~same_chain, pos_overflow)

        token_offset = pairwise("token_index", torch.sub)
        token_class = torch.clip(token_offset + r_max, 0, 2 * r_max)
        token_class = token_class.masked_fill(
            ~(same_chain & same_residue), pos_overflow
        )

        chain_offset = pairwise("sym_id", torch.sub)
        chain_class = torch.clip(chain_offset + s_max, 0, 2 * s_max)
        # Chain offsets are only meaningful across different chains.
        chain_class = chain_class.masked_fill(same_chain, chain_overflow)

        features = torch.cat(
            [
                one_hot(residue_class, pos_overflow + 1).float(),
                one_hot(token_class, pos_overflow + 1).float(),
                same_entity.unsqueeze(-1).float(),
                one_hot(chain_class, chain_overflow + 1).float(),
            ],
            dim=-1,
        )
        return self.linear_layer(features)
|
134
|
+
|
135
|
+
|
136
|
+
class SingleConditioning(Module):
    """Single conditioning layer.

    Combines trunk and input single features with a Fourier embedding of the
    diffusion time, then refines the result with residual transitions.
    """

    def __init__(
        self,
        sigma_data: float,
        token_s=384,
        dim_fourier=256,
        num_transitions=2,
        transition_expansion_factor=2,
        eps=1e-20,
    ):
        """Initialize the single conditioning layer.

        Parameters
        ----------
        sigma_data : float
            The data sigma.
        token_s : int, optional
            The single representation dimension, by default 384.
        dim_fourier : int, optional
            The fourier embeddings dimension, by default 256.
        num_transitions : int, optional
            The number of transitions layers, by default 2.
        transition_expansion_factor : int, optional
            The transition expansion factor, by default 2.
        eps : float, optional
            The epsilon value, by default 1e-20.

        """
        super().__init__()
        # Stored as attributes; not read by forward below.
        self.eps = eps
        self.sigma_data = sigma_data

        width = 2 * token_s
        # Expected width of torch.cat((s_trunk, s_inputs), dim=-1).
        input_dim = (
            width + 2 * const.num_tokens + 1 + len(const.pocket_contact_info)
        )

        # Submodules are created in this order on purpose: it fixes both the
        # random initialisation sequence and the parameter registration order.
        self.norm_single = nn.LayerNorm(input_dim)
        self.single_embed = nn.Linear(input_dim, width)
        self.fourier_embed = FourierEmbedding(dim_fourier)
        self.norm_fourier = nn.LayerNorm(dim_fourier)
        self.fourier_to_single = LinearNoBias(dim_fourier, width)
        self.transitions = ModuleList(
            [
                Transition(dim=width, hidden=transition_expansion_factor * width)
                for _ in range(num_transitions)
            ]
        )

    def forward(
        self,
        *,
        times,
        s_trunk,
        s_inputs,
    ):
        """Build the time-conditioned single representation.

        Parameters
        ----------
        times : torch.Tensor
            Diffusion times, one per batch element (assumed ``(B,)`` — the
            Fourier projection is rearranged as ``"b d -> b 1 d"``).
        s_trunk : torch.Tensor
            Trunk single representation.
        s_inputs : torch.Tensor
            Input single features; concatenated with ``s_trunk`` on the last
            axis, the result must have ``input_dim`` features.

        Returns
        -------
        tuple[torch.Tensor, torch.Tensor]
            The conditioned single representation (last dim ``2 * token_s``)
            and the layer-normed Fourier embedding of ``times``.

        """
        features = torch.cat((s_trunk, s_inputs), dim=-1)
        single = self.single_embed(self.norm_single(features))

        normed_fourier = self.norm_fourier(self.fourier_embed(times))
        # Broadcast the per-sample time conditioning over all tokens.
        time_cond = rearrange(self.fourier_to_single(normed_fourier), "b d -> b 1 d")
        single = single + time_cond

        for block in self.transitions:
            single = single + block(single)

        return single, normed_fourier
|
207
|
+
|
208
|
+
|
209
|
+
class PairwiseConditioning(Module):
    """Pairwise conditioning layer.

    Fuses the trunk pair representation with relative position features and
    refines the result with residual transitions.
    """

    def __init__(
        self,
        token_z,
        dim_token_rel_pos_feats,
        num_transitions=2,
        transition_expansion_factor=2,
    ):
        """Initialize the pairwise conditioning layer.

        Parameters
        ----------
        token_z : int
            The pair representation dimension.
        dim_token_rel_pos_feats : int
            The token relative position features dimension.
        num_transitions : int, optional
            The number of transitions layers, by default 2.
        transition_expansion_factor : int, optional
            The transition expansion factor, by default 2.

        """
        super().__init__()

        fused_dim = token_z + dim_token_rel_pos_feats
        # Normalise the concatenated features, then project back to token_z.
        self.dim_pairwise_init_proj = nn.Sequential(
            nn.LayerNorm(fused_dim),
            LinearNoBias(fused_dim, token_z),
        )

        self.transitions = ModuleList(
            [
                Transition(dim=token_z, hidden=transition_expansion_factor * token_z)
                for _ in range(num_transitions)
            ]
        )

    def forward(
        self,
        z_trunk,
        token_rel_pos_feats,
    ):
        """Return the conditioned pair representation.

        Parameters
        ----------
        z_trunk : torch.Tensor
            Trunk pair representation (last dim ``token_z``).
        token_rel_pos_feats : torch.Tensor
            Relative position features (last dim ``dim_token_rel_pos_feats``).

        Returns
        -------
        torch.Tensor
            Pair representation with last dim ``token_z``.

        """
        fused = torch.cat((z_trunk, token_rel_pos_feats), dim=-1)
        pair = self.dim_pairwise_init_proj(fused)

        for block in self.transitions:
            pair = pair + block(pair)

        return pair
|
261
|
+
|
262
|
+
|
263
|
+
def get_indexing_matrix(K, W, H, device):
    """Build the one-hot matrix that gathers key windows from half-window blocks.

    The atoms are viewed as ``2 * K`` blocks of ``W // 2`` atoms. Key window
    ``k`` consists of the ``h = H // (W // 2)`` blocks
    ``2k - h // 2 + 1, ..., 2k + h // 2``, which are centred on query window
    ``k`` (blocks ``2k`` and ``2k + 1``). Block indices outside ``[0, 2K)``
    get an all-zero column, so they act as zero padding.

    Parameters
    ----------
    K : int
        Number of query windows.
    W : int
        Atoms per query window; must be a positive even number.
    H : int
        Atoms per key window; ``H // (W // 2)`` must be an even integer.
    device : torch.device or str
        Device on which to build the matrix.

    Returns
    -------
    torch.Tensor
        Float matrix of shape ``(2 * K, K * h)``. Column ``k * h + m`` selects
        the block placed at slot ``m`` of key window ``k``.

    Raises
    ------
    ValueError
        If ``W``/``H`` do not satisfy the constraints above. (These were
        ``assert`` checks, which disappear under ``python -O``.)

    """
    if W <= 0 or W % 2 != 0:
        raise ValueError(f"W must be a positive even integer, got W={W}")
    half_w = W // 2
    if H % half_w != 0:
        raise ValueError(f"H ({H}) must be a multiple of W // 2 ({half_w})")
    h = H // half_w
    if h % 2 != 0:
        raise ValueError(f"H // (W // 2) must be even, got {h}")

    arange = torch.arange(2 * K, device=device)
    # index[i, j] = j - i + h // 2, clamped so out-of-range offsets land on
    # the sentinel classes 0 and h + 1, which are sliced away below.
    index = ((arange.unsqueeze(0) - arange.unsqueeze(1)) + h // 2).clamp(
        min=0, max=h + 1
    )
    # Keep only the rows for the first block of each query window.
    index = index.view(K, 2, 2 * K)[:, 0, :]
    onehot = one_hot(index, num_classes=h + 2)[..., 1:-1].transpose(1, 0)
    return onehot.reshape(2 * K, h * K).float()
|
277
|
+
|
278
|
+
|
279
|
+
def single_to_keys(single, indexing_matrix, W, H):
    """Gather the key window of every query window.

    Parameters
    ----------
    single : torch.Tensor
        Per-atom tensor of shape ``(B, N, D)``; ``N`` must split into
        ``2 * (N // W)`` blocks of ``W // 2`` atoms for the ``view`` to succeed.
    indexing_matrix : torch.Tensor
        Block-to-key-window selection matrix as built by
        ``get_indexing_matrix(N // W, W, H, ...)``.
    W : int
        Atoms per query window.
    H : int
        Atoms per key window.

    Returns
    -------
    torch.Tensor
        Tensor of shape ``(B, N // W, H, D)``.

    """
    batch, num_atoms, dim = single.shape
    num_windows = num_atoms // W
    blocks = single.view(batch, 2 * num_windows, W // 2, dim)
    # Contract over the block axis: each output column picks (at most) one block.
    gathered = torch.einsum("jk,bjid->bkid", indexing_matrix, blocks)
    return gathered.reshape(batch, num_windows, H, dim)
|
286
|
+
|
287
|
+
|
288
|
+
class AtomAttentionEncoder(Module):
    """Atom attention encoder.

    Embeds per-atom reference features into atom single (``c``, ``q``) and
    windowed atom-pair (``p``) representations, runs an ``AtomTransformer``
    over them and mean-pools the result onto tokens.
    """

    def __init__(
        self,
        atom_s,
        atom_z,
        token_s,
        token_z,
        atoms_per_window_queries,
        atoms_per_window_keys,
        atom_feature_dim,
        atom_encoder_depth=3,
        atom_encoder_heads=4,
        structure_prediction=True,
        activation_checkpointing=False,
    ):
        """Initialize the atom attention encoder.

        Parameters
        ----------
        atom_s : int
            The atom single representation dimension.
        atom_z : int
            The atom pair representation dimension.
        token_s : int
            The single representation dimension.
        token_z : int
            The pair representation dimension.
        atoms_per_window_queries : int
            The number of atoms per window for queries.
        atoms_per_window_keys : int
            The number of atoms per window for keys.
        atom_feature_dim : int
            The atom feature dimension.
        atom_encoder_depth : int, optional
            The number of transformer layers, by default 3.
        atom_encoder_heads : int, optional
            The number of transformer heads, by default 4.
        structure_prediction : bool, optional
            Whether it is used in the diffusion module, by default True.
        activation_checkpointing : bool, optional
            Whether to use activation checkpointing, by default False.

        """
        super().__init__()

        self.embed_atom_features = LinearNoBias(atom_feature_dim, atom_s)
        self.embed_atompair_ref_pos = LinearNoBias(3, atom_z)
        self.embed_atompair_ref_dist = LinearNoBias(1, atom_z)
        self.embed_atompair_mask = LinearNoBias(1, atom_z)
        self.atoms_per_window_queries = atoms_per_window_queries
        self.atoms_per_window_keys = atoms_per_window_keys

        self.structure_prediction = structure_prediction
        if structure_prediction:
            # Projections of the trunk outputs and of the atom positions r;
            # only created (and used) in the diffusion/structure setting.
            self.s_to_c_trans = nn.Sequential(
                nn.LayerNorm(token_s), LinearNoBias(token_s, atom_s)
            )
            init.final_init_(self.s_to_c_trans[1].weight)

            self.z_to_p_trans = nn.Sequential(
                nn.LayerNorm(token_z), LinearNoBias(token_z, atom_z)
            )
            init.final_init_(self.z_to_p_trans[1].weight)

            # 10 inputs: the 3 values of r per atom, zero-padded with 7 more.
            self.r_to_q_trans = LinearNoBias(10, atom_s)
            init.final_init_(self.r_to_q_trans.weight)

        self.c_to_p_trans_k = nn.Sequential(
            nn.ReLU(),
            LinearNoBias(atom_s, atom_z),
        )
        init.final_init_(self.c_to_p_trans_k[1].weight)

        self.c_to_p_trans_q = nn.Sequential(
            nn.ReLU(),
            LinearNoBias(atom_s, atom_z),
        )
        init.final_init_(self.c_to_p_trans_q[1].weight)

        self.p_mlp = nn.Sequential(
            nn.ReLU(),
            LinearNoBias(atom_z, atom_z),
            nn.ReLU(),
            LinearNoBias(atom_z, atom_z),
            nn.ReLU(),
            LinearNoBias(atom_z, atom_z),
        )
        init.final_init_(self.p_mlp[5].weight)

        self.atom_encoder = AtomTransformer(
            dim=atom_s,
            dim_single_cond=atom_s,
            dim_pairwise=atom_z,
            attn_window_queries=atoms_per_window_queries,
            attn_window_keys=atoms_per_window_keys,
            depth=atom_encoder_depth,
            heads=atom_encoder_heads,
            activation_checkpointing=activation_checkpointing,
        )

        self.atom_to_token_trans = nn.Sequential(
            LinearNoBias(atom_s, 2 * token_s if structure_prediction else token_s),
            nn.ReLU(),
        )

    def forward(
        self,
        feats,
        s_trunk=None,
        z=None,
        r=None,
        multiplicity=1,
        model_cache=None,
    ):
        """Encode atoms and aggregate them onto tokens.

        Parameters
        ----------
        feats : dict[str, torch.Tensor]
            Reads ``ref_pos`` (``(B, N, 3)``), ``atom_pad_mask``,
            ``ref_space_uid``, ``ref_charge``, ``ref_element``,
            ``ref_atom_name_chars`` and ``atom_to_token`` (``(B, N, T)``,
            used with ``bmm``).
        s_trunk : torch.Tensor, optional
            Token single representation; used only when
            ``structure_prediction`` and the cache is empty.
        z : torch.Tensor, optional
            Token pair representation; used only when
            ``structure_prediction`` and the cache is empty.
        r : torch.Tensor, optional
            Per-sample atom positions of shape ``(B * multiplicity, N, 3)``
            (presumably the current coordinates — confirm with the caller).
            Required when ``structure_prediction``.
        multiplicity : int, optional
            Number of samples per input, by default 1.
        model_cache : dict, optional
            If given, ``q``, ``c``, ``p`` and ``to_keys`` are computed on the
            first call, stored under ``model_cache["atomencoder"]`` and reused
            afterwards (so later ``s_trunk``/``z`` values are ignored).

        Returns
        -------
        tuple
            ``(a, q, c, p, to_keys)``. ``a`` is the atom output mean-pooled
            per token. ``q`` and ``c`` are atom single features repeated
            ``multiplicity`` times along dim 0. ``p`` holds the windowed pair
            features ``(B, K, W, H, atom_z)`` and is not repeated. ``to_keys``
            maps atom tensors to key windows.

        """
        B, N, _ = feats["ref_pos"].shape
        atom_mask = feats["atom_pad_mask"].bool()

        layer_cache = None
        if model_cache is not None:
            cache_prefix = "atomencoder"
            if cache_prefix not in model_cache:
                model_cache[cache_prefix] = {}
            layer_cache = model_cache[cache_prefix]

        if model_cache is None or len(layer_cache) == 0:
            # either model is not using the cache or it is the first time running it

            atom_ref_pos = feats["ref_pos"]
            atom_uid = feats["ref_space_uid"]
            atom_feats = torch.cat(
                [
                    atom_ref_pos,
                    feats["ref_charge"].unsqueeze(-1),
                    feats["atom_pad_mask"].unsqueeze(-1),
                    feats["ref_element"],
                    feats["ref_atom_name_chars"].reshape(B, N, 4 * 64),
                ],
                dim=-1,
            )

            c = self.embed_atom_features(atom_feats)

            # NOTE: we are already creating the windows to make it more efficient
            W, H = self.atoms_per_window_queries, self.atoms_per_window_keys
            K = N // W
            keys_indexing_matrix = get_indexing_matrix(K, W, H, c.device)
            to_keys = partial(
                single_to_keys, indexing_matrix=keys_indexing_matrix, W=W, H=H
            )

            atom_ref_pos_queries = atom_ref_pos.view(B, K, W, 1, 3)
            atom_ref_pos_keys = to_keys(atom_ref_pos).view(B, K, 1, H, 3)

            d = atom_ref_pos_keys - atom_ref_pos_queries
            d_norm = torch.sum(d * d, dim=-1, keepdim=True)
            d_norm = 1 / (1 + d_norm)

            atom_mask_queries = atom_mask.view(B, K, W, 1)
            atom_mask_keys = (
                to_keys(atom_mask.unsqueeze(-1).float()).view(B, K, 1, H).bool()
            )
            atom_uid_queries = atom_uid.view(B, K, W, 1)
            atom_uid_keys = (
                to_keys(atom_uid.unsqueeze(-1).float()).view(B, K, 1, H).long()
            )
            # Pair validity: both atoms are real and share a ref_space_uid.
            v = (
                (
                    atom_mask_queries
                    & atom_mask_keys
                    & (atom_uid_queries == atom_uid_keys)
                )
                .float()
                .unsqueeze(-1)
            )

            p = self.embed_atompair_ref_pos(d) * v
            p = p + self.embed_atompair_ref_dist(d_norm) * v
            p = p + self.embed_atompair_mask(v) * v

            # q is taken before the trunk conditioning is added to c below.
            q = c

            if self.structure_prediction:
                # run only in structure model not in initial encoding
                atom_to_token = feats["atom_to_token"].float()

                s_to_c = self.s_to_c_trans(s_trunk)
                s_to_c = torch.bmm(atom_to_token, s_to_c)
                c = c + s_to_c

                atom_to_token_queries = atom_to_token.view(
                    B, K, W, atom_to_token.shape[-1]
                )
                atom_to_token_keys = to_keys(atom_to_token)
                z_to_p = self.z_to_p_trans(z)
                z_to_p = torch.einsum(
                    "bijd,bwki,bwlj->bwkld",
                    z_to_p,
                    atom_to_token_queries,
                    atom_to_token_keys,
                )
                p = p + z_to_p

            p = p + self.c_to_p_trans_q(c.view(B, K, W, 1, c.shape[-1]))
            p = p + self.c_to_p_trans_k(to_keys(c).view(B, K, 1, H, c.shape[-1]))
            p = p + self.p_mlp(p)

            if model_cache is not None:
                layer_cache["q"] = q
                layer_cache["c"] = c
                layer_cache["p"] = p
                layer_cache["to_keys"] = to_keys

        else:
            q = layer_cache["q"]
            c = layer_cache["c"]
            p = layer_cache["p"]
            to_keys = layer_cache["to_keys"]

        if self.structure_prediction:
            # only here the multiplicity kicks in because we use the different positions r
            q = q.repeat_interleave(multiplicity, 0)
            # Pad r to the 10 inputs of r_to_q_trans. new_zeros allocates with
            # r's dtype/device directly instead of building a CPU tensor and
            # copying it to r's device on every call.
            r_input = torch.cat(
                [r, r.new_zeros((B * multiplicity, N, 7))],
                dim=-1,
            )
            r_to_q = self.r_to_q_trans(r_input)
            q = q + r_to_q

        c = c.repeat_interleave(multiplicity, 0)
        atom_mask = atom_mask.repeat_interleave(multiplicity, 0)

        q = self.atom_encoder(
            q=q,
            mask=atom_mask,
            c=c,
            p=p,
            multiplicity=multiplicity,
            to_keys=to_keys,
            model_cache=layer_cache,
        )

        q_to_a = self.atom_to_token_trans(q)
        atom_to_token = feats["atom_to_token"].float()
        atom_to_token = atom_to_token.repeat_interleave(multiplicity, 0)
        # Column-normalise so the bmm below averages the atoms of each token.
        atom_to_token_mean = atom_to_token / (
            atom_to_token.sum(dim=1, keepdim=True) + 1e-6
        )
        a = torch.bmm(atom_to_token_mean.transpose(1, 2), q_to_a)

        return a, q, c, p, to_keys
|
541
|
+
|
542
|
+
|
543
|
+
class AtomAttentionDecoder(Module):
    """Atom attention decoder.

    Broadcasts token features back onto atoms, runs an ``AtomTransformer``
    and projects each atom to a 3-component position update.
    """

    def __init__(
        self,
        atom_s,
        atom_z,
        token_s,
        attn_window_queries,
        attn_window_keys,
        atom_decoder_depth=3,
        atom_decoder_heads=4,
        activation_checkpointing=False,
    ):
        """Initialize the atom attention decoder.

        Parameters
        ----------
        atom_s : int
            The atom single representation dimension.
        atom_z : int
            The atom pair representation dimension.
        token_s : int
            The single representation dimension.
        attn_window_queries : int
            The number of atoms per window for queries.
        attn_window_keys : int
            The number of atoms per window for keys.
        atom_decoder_depth : int, optional
            The number of transformer layers, by default 3.
        atom_decoder_heads : int, optional
            The number of transformer heads, by default 4.
        activation_checkpointing : bool, optional
            Whether to use activation checkpointing, by default False.

        """
        super().__init__()

        # Token features (2 * token_s wide) -> atom single features.
        self.a_to_q_trans = LinearNoBias(2 * token_s, atom_s)
        init.final_init_(self.a_to_q_trans.weight)

        self.atom_decoder = AtomTransformer(
            dim=atom_s,
            dim_single_cond=atom_s,
            dim_pairwise=atom_z,
            attn_window_queries=attn_window_queries,
            attn_window_keys=attn_window_keys,
            depth=atom_decoder_depth,
            heads=atom_decoder_heads,
            activation_checkpointing=activation_checkpointing,
        )

        # Atom features -> 3 output values per atom.
        self.atom_feat_to_atom_pos_update = nn.Sequential(
            nn.LayerNorm(atom_s), LinearNoBias(atom_s, 3)
        )
        init.final_init_(self.atom_feat_to_atom_pos_update[1].weight)

    def forward(
        self,
        a,
        q,
        c,
        p,
        feats,
        to_keys,
        multiplicity=1,
        model_cache=None,
    ):
        """Decode token features into per-atom position updates.

        Parameters
        ----------
        a : torch.Tensor
            Token features with last dim ``2 * token_s``, batched as
            ``B * multiplicity`` (it is multiplied against the repeated
            ``atom_to_token``).
        q, c, p, to_keys
            Atom single features, conditioning, windowed pair features and
            key-window gather function — presumably as returned by
            ``AtomAttentionEncoder`` (verify against the caller).
        feats : dict[str, torch.Tensor]
            Reads ``atom_pad_mask`` and ``atom_to_token``.
        multiplicity : int, optional
            Number of samples per input, by default 1.
        model_cache : dict, optional
            If given, the transformer's cache lives under
            ``model_cache["atomdecoder"]``.

        Returns
        -------
        torch.Tensor
            Per-atom update with last dim 3.

        """
        atom_mask = feats["atom_pad_mask"].repeat_interleave(multiplicity, 0)
        atom_to_token = feats["atom_to_token"].float().repeat_interleave(
            multiplicity, 0
        )

        # Copy each token's projected features onto all of its atoms.
        token_update = torch.bmm(atom_to_token, self.a_to_q_trans(a))
        q = q + token_update

        if model_cache is None:
            decoder_cache = None
        else:
            if "atomdecoder" not in model_cache:
                model_cache["atomdecoder"] = {}
            decoder_cache = model_cache["atomdecoder"]

        decoded = self.atom_decoder(
            q=q,
            mask=atom_mask,
            c=c,
            p=p,
            multiplicity=multiplicity,
            to_keys=to_keys,
            model_cache=decoder_cache,
        )

        return self.atom_feat_to_atom_pos_update(decoded)
|