boltz_vsynthes-1.0.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (112)
  1. boltz/__init__.py +7 -0
  2. boltz/data/__init__.py +0 -0
  3. boltz/data/const.py +1184 -0
  4. boltz/data/crop/__init__.py +0 -0
  5. boltz/data/crop/affinity.py +164 -0
  6. boltz/data/crop/boltz.py +296 -0
  7. boltz/data/crop/cropper.py +45 -0
  8. boltz/data/feature/__init__.py +0 -0
  9. boltz/data/feature/featurizer.py +1230 -0
  10. boltz/data/feature/featurizerv2.py +2208 -0
  11. boltz/data/feature/symmetry.py +602 -0
  12. boltz/data/filter/__init__.py +0 -0
  13. boltz/data/filter/dynamic/__init__.py +0 -0
  14. boltz/data/filter/dynamic/date.py +76 -0
  15. boltz/data/filter/dynamic/filter.py +24 -0
  16. boltz/data/filter/dynamic/max_residues.py +37 -0
  17. boltz/data/filter/dynamic/resolution.py +34 -0
  18. boltz/data/filter/dynamic/size.py +38 -0
  19. boltz/data/filter/dynamic/subset.py +42 -0
  20. boltz/data/filter/static/__init__.py +0 -0
  21. boltz/data/filter/static/filter.py +26 -0
  22. boltz/data/filter/static/ligand.py +37 -0
  23. boltz/data/filter/static/polymer.py +299 -0
  24. boltz/data/module/__init__.py +0 -0
  25. boltz/data/module/inference.py +307 -0
  26. boltz/data/module/inferencev2.py +429 -0
  27. boltz/data/module/training.py +684 -0
  28. boltz/data/module/trainingv2.py +660 -0
  29. boltz/data/mol.py +900 -0
  30. boltz/data/msa/__init__.py +0 -0
  31. boltz/data/msa/mmseqs2.py +235 -0
  32. boltz/data/pad.py +84 -0
  33. boltz/data/parse/__init__.py +0 -0
  34. boltz/data/parse/a3m.py +134 -0
  35. boltz/data/parse/csv.py +100 -0
  36. boltz/data/parse/fasta.py +138 -0
  37. boltz/data/parse/mmcif.py +1239 -0
  38. boltz/data/parse/mmcif_with_constraints.py +1607 -0
  39. boltz/data/parse/schema.py +1851 -0
  40. boltz/data/parse/yaml.py +68 -0
  41. boltz/data/sample/__init__.py +0 -0
  42. boltz/data/sample/cluster.py +283 -0
  43. boltz/data/sample/distillation.py +57 -0
  44. boltz/data/sample/random.py +39 -0
  45. boltz/data/sample/sampler.py +49 -0
  46. boltz/data/tokenize/__init__.py +0 -0
  47. boltz/data/tokenize/boltz.py +195 -0
  48. boltz/data/tokenize/boltz2.py +396 -0
  49. boltz/data/tokenize/tokenizer.py +24 -0
  50. boltz/data/types.py +777 -0
  51. boltz/data/write/__init__.py +0 -0
  52. boltz/data/write/mmcif.py +305 -0
  53. boltz/data/write/pdb.py +171 -0
  54. boltz/data/write/utils.py +23 -0
  55. boltz/data/write/writer.py +330 -0
  56. boltz/main.py +1292 -0
  57. boltz/model/__init__.py +0 -0
  58. boltz/model/layers/__init__.py +0 -0
  59. boltz/model/layers/attention.py +132 -0
  60. boltz/model/layers/attentionv2.py +111 -0
  61. boltz/model/layers/confidence_utils.py +231 -0
  62. boltz/model/layers/dropout.py +34 -0
  63. boltz/model/layers/initialize.py +100 -0
  64. boltz/model/layers/outer_product_mean.py +98 -0
  65. boltz/model/layers/pair_averaging.py +135 -0
  66. boltz/model/layers/pairformer.py +337 -0
  67. boltz/model/layers/relative.py +58 -0
  68. boltz/model/layers/transition.py +78 -0
  69. boltz/model/layers/triangular_attention/__init__.py +0 -0
  70. boltz/model/layers/triangular_attention/attention.py +189 -0
  71. boltz/model/layers/triangular_attention/primitives.py +409 -0
  72. boltz/model/layers/triangular_attention/utils.py +380 -0
  73. boltz/model/layers/triangular_mult.py +212 -0
  74. boltz/model/loss/__init__.py +0 -0
  75. boltz/model/loss/bfactor.py +49 -0
  76. boltz/model/loss/confidence.py +590 -0
  77. boltz/model/loss/confidencev2.py +621 -0
  78. boltz/model/loss/diffusion.py +171 -0
  79. boltz/model/loss/diffusionv2.py +134 -0
  80. boltz/model/loss/distogram.py +48 -0
  81. boltz/model/loss/distogramv2.py +105 -0
  82. boltz/model/loss/validation.py +1025 -0
  83. boltz/model/models/__init__.py +0 -0
  84. boltz/model/models/boltz1.py +1286 -0
  85. boltz/model/models/boltz2.py +1249 -0
  86. boltz/model/modules/__init__.py +0 -0
  87. boltz/model/modules/affinity.py +223 -0
  88. boltz/model/modules/confidence.py +481 -0
  89. boltz/model/modules/confidence_utils.py +181 -0
  90. boltz/model/modules/confidencev2.py +495 -0
  91. boltz/model/modules/diffusion.py +844 -0
  92. boltz/model/modules/diffusion_conditioning.py +116 -0
  93. boltz/model/modules/diffusionv2.py +677 -0
  94. boltz/model/modules/encoders.py +639 -0
  95. boltz/model/modules/encodersv2.py +565 -0
  96. boltz/model/modules/transformers.py +322 -0
  97. boltz/model/modules/transformersv2.py +261 -0
  98. boltz/model/modules/trunk.py +688 -0
  99. boltz/model/modules/trunkv2.py +828 -0
  100. boltz/model/modules/utils.py +303 -0
  101. boltz/model/optim/__init__.py +0 -0
  102. boltz/model/optim/ema.py +389 -0
  103. boltz/model/optim/scheduler.py +99 -0
  104. boltz/model/potentials/__init__.py +0 -0
  105. boltz/model/potentials/potentials.py +497 -0
  106. boltz/model/potentials/schedules.py +32 -0
  107. boltz_vsynthes-1.0.0.dist-info/METADATA +151 -0
  108. boltz_vsynthes-1.0.0.dist-info/RECORD +112 -0
  109. boltz_vsynthes-1.0.0.dist-info/WHEEL +5 -0
  110. boltz_vsynthes-1.0.0.dist-info/entry_points.txt +2 -0
  111. boltz_vsynthes-1.0.0.dist-info/licenses/LICENSE +21 -0
  112. boltz_vsynthes-1.0.0.dist-info/top_level.txt +1 -0
boltz/model/modules/encoders.py
@@ -0,0 +1,639 @@
+ # started from code from https://github.com/lucidrains/alphafold3-pytorch, MIT License, Copyright (c) 2024 Phil Wang
+ from functools import partial
+ from math import pi
+
+ import torch
+ from einops import rearrange
+ from torch import nn
+ from torch.nn import Module, ModuleList
+ from torch.nn.functional import one_hot
+
+ import boltz.model.layers.initialize as init
+ from boltz.data import const
+ from boltz.model.layers.transition import Transition
+ from boltz.model.modules.transformers import AtomTransformer
+ from boltz.model.modules.utils import LinearNoBias
+
+
+ class FourierEmbedding(Module):
+     """Fourier embedding layer."""
+
+     def __init__(self, dim):
+         """Initialize the Fourier Embeddings.
+
+         Parameters
+         ----------
+         dim : int
+             The dimension of the embeddings.
+
+         """
+         super().__init__()
+         self.proj = nn.Linear(1, dim)
+         torch.nn.init.normal_(self.proj.weight, mean=0, std=1)
+         torch.nn.init.normal_(self.proj.bias, mean=0, std=1)
+         self.proj.requires_grad_(False)
+
+     def forward(
+         self,
+         times,
+     ):
+         times = rearrange(times, "b -> b 1")
+         rand_proj = self.proj(times)
+         return torch.cos(2 * pi * rand_proj)
+
+
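For orientation, a minimal usage sketch (illustrative only, not part of the package; shapes are hypothetical):

# Sketch: embed 4 per-sample timesteps (e.g. diffusion noise levels)
# into 256-dim Fourier features. The random projection is frozen at
# init, so the mapping is fixed for the lifetime of the module.
emb = FourierEmbedding(dim=256)
times = torch.rand(4)      # (b,)
feats = emb(times)         # (b, 256), values in [-1, 1]
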
+ class RelativePositionEncoder(Module):
+     """Relative position encoder."""
+
+     def __init__(self, token_z, r_max=32, s_max=2):
+         """Initialize the relative position encoder.
+
+         Parameters
+         ----------
+         token_z : int
+             The pair representation dimension.
+         r_max : int, optional
+             The maximum index distance, by default 32.
+         s_max : int, optional
+             The maximum chain distance, by default 2.
+
+         """
+         super().__init__()
+         self.r_max = r_max
+         self.s_max = s_max
+         self.linear_layer = LinearNoBias(4 * (r_max + 1) + 2 * (s_max + 1) + 1, token_z)
+
+     def forward(self, feats):
+         b_same_chain = torch.eq(
+             feats["asym_id"][:, :, None], feats["asym_id"][:, None, :]
+         )
+         b_same_residue = torch.eq(
+             feats["residue_index"][:, :, None], feats["residue_index"][:, None, :]
+         )
+         b_same_entity = torch.eq(
+             feats["entity_id"][:, :, None], feats["entity_id"][:, None, :]
+         )
+         rel_pos = (
+             feats["residue_index"][:, :, None] - feats["residue_index"][:, None, :]
+         )
+         if torch.any(feats["cyclic_period"] != 0):
+             period = torch.where(
+                 feats["cyclic_period"] > 0,
+                 feats["cyclic_period"],
+                 torch.zeros_like(feats["cyclic_period"]) + 10000,
+             ).unsqueeze(1)
+             rel_pos = (rel_pos - period * torch.round(rel_pos / period)).long()
+
+         d_residue = torch.clip(
+             rel_pos + self.r_max,
+             0,
+             2 * self.r_max,
+         )
+
+         d_residue = torch.where(
+             b_same_chain, d_residue, torch.zeros_like(d_residue) + 2 * self.r_max + 1
+         )
+         a_rel_pos = one_hot(d_residue, 2 * self.r_max + 2)
+
+         d_token = torch.clip(
+             feats["token_index"][:, :, None]
+             - feats["token_index"][:, None, :]
+             + self.r_max,
+             0,
+             2 * self.r_max,
+         )
+         d_token = torch.where(
+             b_same_chain & b_same_residue,
+             d_token,
+             torch.zeros_like(d_token) + 2 * self.r_max + 1,
+         )
+         a_rel_token = one_hot(d_token, 2 * self.r_max + 2)
+
+         d_chain = torch.clip(
+             feats["sym_id"][:, :, None] - feats["sym_id"][:, None, :] + self.s_max,
+             0,
+             2 * self.s_max,
+         )
+         d_chain = torch.where(
+             b_same_chain, torch.zeros_like(d_chain) + 2 * self.s_max + 1, d_chain
+         )
+         a_rel_chain = one_hot(d_chain, 2 * self.s_max + 2)
+
+         p = self.linear_layer(
+             torch.cat(
+                 [
+                     a_rel_pos.float(),
+                     a_rel_token.float(),
+                     b_same_entity.unsqueeze(-1).float(),
+                     a_rel_chain.float(),
+                 ],
+                 dim=-1,
+             )
+         )
+         return p
+
+
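To make the bucketing concrete, a hedged sketch for a single 5-residue chain (all feature values hypothetical). Each offset i - j is shifted by r_max and clipped to [0, 2*r_max]; the extra bucket 2*r_max + 1 is reserved for pairs from different chains:

# Sketch: one chain, five residues. All id tensors are (batch, tokens).
enc = RelativePositionEncoder(token_z=128, r_max=32, s_max=2)
idx = torch.arange(5).unsqueeze(0)
feats = {
    "asym_id": torch.zeros(1, 5, dtype=torch.long),
    "residue_index": idx,
    "entity_id": torch.zeros(1, 5, dtype=torch.long),
    "token_index": idx,
    "sym_id": torch.zeros(1, 5, dtype=torch.long),
    "cyclic_period": torch.zeros(1, 5, dtype=torch.long),  # 0 = not cyclic
}
p = enc(feats)   # (1, 5, 5, 128) pair features
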
+ class SingleConditioning(Module):
+     """Single conditioning layer."""
+
+     def __init__(
+         self,
+         sigma_data: float,
+         token_s=384,
+         dim_fourier=256,
+         num_transitions=2,
+         transition_expansion_factor=2,
+         eps=1e-20,
+     ):
+         """Initialize the single conditioning layer.
+
+         Parameters
+         ----------
+         sigma_data : float
+             The data sigma.
+         token_s : int, optional
+             The single representation dimension, by default 384.
+         dim_fourier : int, optional
+             The Fourier embedding dimension, by default 256.
+         num_transitions : int, optional
+             The number of transition layers, by default 2.
+         transition_expansion_factor : int, optional
+             The transition expansion factor, by default 2.
+         eps : float, optional
+             The epsilon value, by default 1e-20.
+
+         """
+         super().__init__()
+         self.eps = eps
+         self.sigma_data = sigma_data
+
+         input_dim = (
+             2 * token_s + 2 * const.num_tokens + 1 + len(const.pocket_contact_info)
+         )
+         self.norm_single = nn.LayerNorm(input_dim)
+         self.single_embed = nn.Linear(input_dim, 2 * token_s)
+         self.fourier_embed = FourierEmbedding(dim_fourier)
+         self.norm_fourier = nn.LayerNorm(dim_fourier)
+         self.fourier_to_single = LinearNoBias(dim_fourier, 2 * token_s)
+
+         transitions = ModuleList([])
+         for _ in range(num_transitions):
+             transition = Transition(
+                 dim=2 * token_s, hidden=transition_expansion_factor * 2 * token_s
+             )
+             transitions.append(transition)
+
+         self.transitions = transitions
+
+     def forward(
+         self,
+         *,
+         times,
+         s_trunk,
+         s_inputs,
+     ):
+         s = torch.cat((s_trunk, s_inputs), dim=-1)
+         s = self.single_embed(self.norm_single(s))
+         fourier_embed = self.fourier_embed(times)
+         normed_fourier = self.norm_fourier(fourier_embed)
+         fourier_to_single = self.fourier_to_single(normed_fourier)
+
+         s = rearrange(fourier_to_single, "b d -> b 1 d") + s
+
+         for transition in self.transitions:
+             s = transition(s) + s
+
+         return s, normed_fourier
+
+
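A sketch of the call contract (dimensions follow the defaults above; sigma_data=16.0 is an assumed EDM-style value for illustration, not read from the package):

# Sketch: condition a single representation on diffusion time.
# input_dim = 2*token_s + 2*const.num_tokens + 1 + len(const.pocket_contact_info),
# so s_inputs is sized to make the concatenation line up.
cond = SingleConditioning(sigma_data=16.0)
b, n, token_s = 2, 64, 384
input_dim = cond.norm_single.normalized_shape[0]
s_trunk = torch.randn(b, n, token_s)
s_inputs = torch.randn(b, n, input_dim - token_s)
s, normed_fourier = cond(times=torch.rand(b), s_trunk=s_trunk, s_inputs=s_inputs)
# s: (b, n, 2*token_s), normed_fourier: (b, 256)
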
+ class PairwiseConditioning(Module):
+     """Pairwise conditioning layer."""
+
+     def __init__(
+         self,
+         token_z,
+         dim_token_rel_pos_feats,
+         num_transitions=2,
+         transition_expansion_factor=2,
+     ):
+         """Initialize the pairwise conditioning layer.
+
+         Parameters
+         ----------
+         token_z : int
+             The pair representation dimension.
+         dim_token_rel_pos_feats : int
+             The token relative position features dimension.
+         num_transitions : int, optional
+             The number of transition layers, by default 2.
+         transition_expansion_factor : int, optional
+             The transition expansion factor, by default 2.
+
+         """
+         super().__init__()
+
+         self.dim_pairwise_init_proj = nn.Sequential(
+             nn.LayerNorm(token_z + dim_token_rel_pos_feats),
+             LinearNoBias(token_z + dim_token_rel_pos_feats, token_z),
+         )
+
+         transitions = ModuleList([])
+         for _ in range(num_transitions):
+             transition = Transition(
+                 dim=token_z, hidden=transition_expansion_factor * token_z
+             )
+             transitions.append(transition)
+
+         self.transitions = transitions
+
+     def forward(
+         self,
+         z_trunk,
+         token_rel_pos_feats,
+     ):
+         z = torch.cat((z_trunk, token_rel_pos_feats), dim=-1)
+         z = self.dim_pairwise_init_proj(z)
+
+         for transition in self.transitions:
+             z = transition(z) + z
+
+         return z
+
+
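And the analogous sketch for the pair representation (hypothetical dimensions):

# Sketch: fuse trunk pair features with the relative-position features
# produced by RelativePositionEncoder above.
pairwise = PairwiseConditioning(token_z=128, dim_token_rel_pos_feats=128)
z_trunk = torch.randn(2, 16, 16, 128)
rel_feats = torch.randn(2, 16, 16, 128)
z = pairwise(z_trunk, rel_feats)   # (2, 16, 16, 128)
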
+ def get_indexing_matrix(K, W, H, device):
+     # Build a fixed 0/1 matrix that gathers, for each of the K query
+     # windows of width W, the band of H neighboring atoms used as keys.
+     assert W % 2 == 0
+     assert H % (W // 2) == 0
+
+     h = H // (W // 2)
+     assert h % 2 == 0
+
+     arange = torch.arange(2 * K, device=device)
+     index = ((arange.unsqueeze(0) - arange.unsqueeze(1)) + h // 2).clamp(
+         min=0, max=h + 1
+     )
+     index = index.view(K, 2, 2 * K)[:, 0, :]
+     onehot = one_hot(index, num_classes=h + 2)[..., 1:-1].transpose(1, 0)
+     return onehot.reshape(2 * K, h * K).float()
+
+
+ def single_to_keys(single, indexing_matrix, W, H):
+     # Gather per-atom features into the overlapping key bands of width H.
+     B, N, D = single.shape
+     K = N // W
+     single = single.view(B, 2 * K, W // 2, D)
+     return torch.einsum("b j i d, j k -> b k i d", single, indexing_matrix).reshape(
+         B, K, H, D
+     )
+
+
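These two helpers implement the sliding-window attention layout: the N atoms are split into K = N // W query windows, and each window attends to a band of H neighboring atoms gathered by the fixed indexing matrix. A shape sketch under assumed values (N must be divisible by W):

# Sketch: 128 atoms, query windows of W=32, key bands of H=128.
B, N, D = 2, 128, 64
W, H = 32, 128
K = N // W                                         # 4 query windows
index_mat = get_indexing_matrix(K, W, H, "cpu")    # (2*K, (H // (W // 2)) * K)
keys = single_to_keys(torch.randn(B, N, D), index_mat, W, H)   # (B, K, H, D)
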
+ class AtomAttentionEncoder(Module):
+     """Atom attention encoder."""
+
+     def __init__(
+         self,
+         atom_s,
+         atom_z,
+         token_s,
+         token_z,
+         atoms_per_window_queries,
+         atoms_per_window_keys,
+         atom_feature_dim,
+         atom_encoder_depth=3,
+         atom_encoder_heads=4,
+         structure_prediction=True,
+         activation_checkpointing=False,
+     ):
+         """Initialize the atom attention encoder.
+
+         Parameters
+         ----------
+         atom_s : int
+             The atom single representation dimension.
+         atom_z : int
+             The atom pair representation dimension.
+         token_s : int
+             The single representation dimension.
+         token_z : int
+             The pair representation dimension.
+         atoms_per_window_queries : int
+             The number of atoms per window for queries.
+         atoms_per_window_keys : int
+             The number of atoms per window for keys.
+         atom_feature_dim : int
+             The atom feature dimension.
+         atom_encoder_depth : int, optional
+             The number of transformer layers, by default 3.
+         atom_encoder_heads : int, optional
+             The number of transformer heads, by default 4.
+         structure_prediction : bool, optional
+             Whether the encoder is used in the diffusion module, by default True.
+         activation_checkpointing : bool, optional
+             Whether to use activation checkpointing, by default False.
+
+         """
+         super().__init__()
+
+         self.embed_atom_features = LinearNoBias(atom_feature_dim, atom_s)
+         self.embed_atompair_ref_pos = LinearNoBias(3, atom_z)
+         self.embed_atompair_ref_dist = LinearNoBias(1, atom_z)
+         self.embed_atompair_mask = LinearNoBias(1, atom_z)
+         self.atoms_per_window_queries = atoms_per_window_queries
+         self.atoms_per_window_keys = atoms_per_window_keys
+
+         self.structure_prediction = structure_prediction
+         if structure_prediction:
+             self.s_to_c_trans = nn.Sequential(
+                 nn.LayerNorm(token_s), LinearNoBias(token_s, atom_s)
+             )
+             init.final_init_(self.s_to_c_trans[1].weight)
+
+             self.z_to_p_trans = nn.Sequential(
+                 nn.LayerNorm(token_z), LinearNoBias(token_z, atom_z)
+             )
+             init.final_init_(self.z_to_p_trans[1].weight)
+
+             self.r_to_q_trans = LinearNoBias(10, atom_s)
+             init.final_init_(self.r_to_q_trans.weight)
+
+         self.c_to_p_trans_k = nn.Sequential(
+             nn.ReLU(),
+             LinearNoBias(atom_s, atom_z),
+         )
+         init.final_init_(self.c_to_p_trans_k[1].weight)
+
+         self.c_to_p_trans_q = nn.Sequential(
+             nn.ReLU(),
+             LinearNoBias(atom_s, atom_z),
+         )
+         init.final_init_(self.c_to_p_trans_q[1].weight)
+
+         self.p_mlp = nn.Sequential(
+             nn.ReLU(),
+             LinearNoBias(atom_z, atom_z),
+             nn.ReLU(),
+             LinearNoBias(atom_z, atom_z),
+             nn.ReLU(),
+             LinearNoBias(atom_z, atom_z),
+         )
+         init.final_init_(self.p_mlp[5].weight)
+
+         self.atom_encoder = AtomTransformer(
+             dim=atom_s,
+             dim_single_cond=atom_s,
+             dim_pairwise=atom_z,
+             attn_window_queries=atoms_per_window_queries,
+             attn_window_keys=atoms_per_window_keys,
+             depth=atom_encoder_depth,
+             heads=atom_encoder_heads,
+             activation_checkpointing=activation_checkpointing,
+         )
+
+         self.atom_to_token_trans = nn.Sequential(
+             LinearNoBias(atom_s, 2 * token_s if structure_prediction else token_s),
+             nn.ReLU(),
+         )
+
+     def forward(
+         self,
+         feats,
+         s_trunk=None,
+         z=None,
+         r=None,
+         multiplicity=1,
+         model_cache=None,
+     ):
+         B, N, _ = feats["ref_pos"].shape
+         atom_mask = feats["atom_pad_mask"].bool()
+
+         layer_cache = None
+         if model_cache is not None:
+             cache_prefix = "atomencoder"
+             if cache_prefix not in model_cache:
+                 model_cache[cache_prefix] = {}
+             layer_cache = model_cache[cache_prefix]
+
+         if model_cache is None or len(layer_cache) == 0:
+             # either the model is not using the cache or this is the first run
+
+             atom_ref_pos = feats["ref_pos"]
+             atom_uid = feats["ref_space_uid"]
+             atom_feats = torch.cat(
+                 [
+                     atom_ref_pos,
+                     feats["ref_charge"].unsqueeze(-1),
+                     feats["atom_pad_mask"].unsqueeze(-1),
+                     feats["ref_element"],
+                     feats["ref_atom_name_chars"].reshape(B, N, 4 * 64),
+                 ],
+                 dim=-1,
+             )
+
+             c = self.embed_atom_features(atom_feats)
+
+             # NOTE: we already create the windows here to make this more efficient
+             W, H = self.atoms_per_window_queries, self.atoms_per_window_keys
+             B, N = c.shape[:2]
+             K = N // W
+             keys_indexing_matrix = get_indexing_matrix(K, W, H, c.device)
+             to_keys = partial(
+                 single_to_keys, indexing_matrix=keys_indexing_matrix, W=W, H=H
+             )
+
+             atom_ref_pos_queries = atom_ref_pos.view(B, K, W, 1, 3)
+             atom_ref_pos_keys = to_keys(atom_ref_pos).view(B, K, 1, H, 3)
+
+             d = atom_ref_pos_keys - atom_ref_pos_queries
+             d_norm = torch.sum(d * d, dim=-1, keepdim=True)
+             d_norm = 1 / (1 + d_norm)
+
+             atom_mask_queries = atom_mask.view(B, K, W, 1)
+             atom_mask_keys = (
+                 to_keys(atom_mask.unsqueeze(-1).float()).view(B, K, 1, H).bool()
+             )
+             atom_uid_queries = atom_uid.view(B, K, W, 1)
+             atom_uid_keys = (
+                 to_keys(atom_uid.unsqueeze(-1).float()).view(B, K, 1, H).long()
+             )
+             v = (
+                 (
+                     atom_mask_queries
+                     & atom_mask_keys
+                     & (atom_uid_queries == atom_uid_keys)
+                 )
+                 .float()
+                 .unsqueeze(-1)
+             )
+
+             p = self.embed_atompair_ref_pos(d) * v
+             p = p + self.embed_atompair_ref_dist(d_norm) * v
+             p = p + self.embed_atompair_mask(v) * v
+
+             q = c
+
+             if self.structure_prediction:
+                 # run only in the structure model, not in the initial encoding
+                 atom_to_token = feats["atom_to_token"].float()
+
+                 s_to_c = self.s_to_c_trans(s_trunk)
+                 s_to_c = torch.bmm(atom_to_token, s_to_c)
+                 c = c + s_to_c
+
+                 atom_to_token_queries = atom_to_token.view(
+                     B, K, W, atom_to_token.shape[-1]
+                 )
+                 atom_to_token_keys = to_keys(atom_to_token)
+                 z_to_p = self.z_to_p_trans(z)
+                 z_to_p = torch.einsum(
+                     "bijd,bwki,bwlj->bwkld",
+                     z_to_p,
+                     atom_to_token_queries,
+                     atom_to_token_keys,
+                 )
+                 p = p + z_to_p
+
+             p = p + self.c_to_p_trans_q(c.view(B, K, W, 1, c.shape[-1]))
+             p = p + self.c_to_p_trans_k(to_keys(c).view(B, K, 1, H, c.shape[-1]))
+             p = p + self.p_mlp(p)
+
+             if model_cache is not None:
+                 layer_cache["q"] = q
+                 layer_cache["c"] = c
+                 layer_cache["p"] = p
+                 layer_cache["to_keys"] = to_keys
+
+         else:
+             q = layer_cache["q"]
+             c = layer_cache["c"]
+             p = layer_cache["p"]
+             to_keys = layer_cache["to_keys"]
+
+         if self.structure_prediction:
+             # multiplicity only kicks in here because the sampled positions r differ
+             q = q.repeat_interleave(multiplicity, 0)
+             r_input = torch.cat(
+                 [r, torch.zeros((B * multiplicity, N, 7)).to(r)],
+                 dim=-1,
+             )
+             r_to_q = self.r_to_q_trans(r_input)
+             q = q + r_to_q
+
+         c = c.repeat_interleave(multiplicity, 0)
+         atom_mask = atom_mask.repeat_interleave(multiplicity, 0)
+
+         q = self.atom_encoder(
+             q=q,
+             mask=atom_mask,
+             c=c,
+             p=p,
+             multiplicity=multiplicity,
+             to_keys=to_keys,
+             model_cache=layer_cache,
+         )
+
+         q_to_a = self.atom_to_token_trans(q)
+         atom_to_token = feats["atom_to_token"].float()
+         atom_to_token = atom_to_token.repeat_interleave(multiplicity, 0)
+         atom_to_token_mean = atom_to_token / (
+             atom_to_token.sum(dim=1, keepdim=True) + 1e-6
+         )
+         a = torch.bmm(atom_to_token_mean.transpose(1, 2), q_to_a)
+
+         return a, q, c, p, to_keys
+
+
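A hedged end-to-end sketch of the encoder's input contract, run outside the diffusion module (structure_prediction=False, so s_trunk, z, and r are not needed). All shapes and the 128-way element one-hot are assumptions for illustration; they give atom_feature_dim = 3 + 1 + 1 + 128 + 4*64 = 389:

B, N, T = 1, 64, 16   # batch, atoms, tokens (hypothetical)
encoder = AtomAttentionEncoder(
    atom_s=128, atom_z=16, token_s=384, token_z=128,
    atoms_per_window_queries=32, atoms_per_window_keys=128,
    atom_feature_dim=389, structure_prediction=False,
)
feats = {
    "ref_pos": torch.randn(B, N, 3),
    "atom_pad_mask": torch.ones(B, N),
    "ref_charge": torch.zeros(B, N),
    "ref_element": one_hot(torch.zeros(B, N, dtype=torch.long), 128).float(),
    "ref_atom_name_chars": torch.zeros(B, N, 4, 64),
    "ref_space_uid": torch.zeros(B, N, dtype=torch.long),
    "atom_to_token": one_hot(torch.arange(N) % T, T).float().unsqueeze(0),
}
a, q, c, p, to_keys = encoder(feats)
# a: (B, T, token_s) token-level embedding; q, c, p feed the decoder below.
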
+ class AtomAttentionDecoder(Module):
+     """Atom attention decoder."""
+
+     def __init__(
+         self,
+         atom_s,
+         atom_z,
+         token_s,
+         attn_window_queries,
+         attn_window_keys,
+         atom_decoder_depth=3,
+         atom_decoder_heads=4,
+         activation_checkpointing=False,
+     ):
+         """Initialize the atom attention decoder.
+
+         Parameters
+         ----------
+         atom_s : int
+             The atom single representation dimension.
+         atom_z : int
+             The atom pair representation dimension.
+         token_s : int
+             The single representation dimension.
+         attn_window_queries : int
+             The number of atoms per window for queries.
+         attn_window_keys : int
+             The number of atoms per window for keys.
+         atom_decoder_depth : int, optional
+             The number of transformer layers, by default 3.
+         atom_decoder_heads : int, optional
+             The number of transformer heads, by default 4.
+         activation_checkpointing : bool, optional
+             Whether to use activation checkpointing, by default False.
+
+         """
+         super().__init__()
+
+         self.a_to_q_trans = LinearNoBias(2 * token_s, atom_s)
+         init.final_init_(self.a_to_q_trans.weight)
+
+         self.atom_decoder = AtomTransformer(
+             dim=atom_s,
+             dim_single_cond=atom_s,
+             dim_pairwise=atom_z,
+             attn_window_queries=attn_window_queries,
+             attn_window_keys=attn_window_keys,
+             depth=atom_decoder_depth,
+             heads=atom_decoder_heads,
+             activation_checkpointing=activation_checkpointing,
+         )
+
+         self.atom_feat_to_atom_pos_update = nn.Sequential(
+             nn.LayerNorm(atom_s), LinearNoBias(atom_s, 3)
+         )
+         init.final_init_(self.atom_feat_to_atom_pos_update[1].weight)
+
+     def forward(
+         self,
+         a,
+         q,
+         c,
+         p,
+         feats,
+         to_keys,
+         multiplicity=1,
+         model_cache=None,
+     ):
+         atom_mask = feats["atom_pad_mask"]
+         atom_mask = atom_mask.repeat_interleave(multiplicity, 0)
+
+         atom_to_token = feats["atom_to_token"].float()
+         atom_to_token = atom_to_token.repeat_interleave(multiplicity, 0)
+
+         a_to_q = self.a_to_q_trans(a)
+         a_to_q = torch.bmm(atom_to_token, a_to_q)
+         q = q + a_to_q
+
+         layer_cache = None
+         if model_cache is not None:
+             cache_prefix = "atomdecoder"
+             if cache_prefix not in model_cache:
+                 model_cache[cache_prefix] = {}
+             layer_cache = model_cache[cache_prefix]
+
+         q = self.atom_decoder(
+             q=q,
+             mask=atom_mask,
+             c=c,
+             p=p,
+             multiplicity=multiplicity,
+             to_keys=to_keys,
+             model_cache=layer_cache,
+         )
+
+         r_update = self.atom_feat_to_atom_pos_update(q)
+         return r_update
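Continuing the sketch above, the decoder is the inverse mapping: it broadcasts token activations back onto atoms and predicts a 3-D position update per atom. Note that a must have width 2*token_s, which is what the encoder produces when structure_prediction=True inside the diffusion module (hypothetical shapes again):

decoder = AtomAttentionDecoder(
    atom_s=128, atom_z=16, token_s=384,
    attn_window_queries=32, attn_window_keys=128,
)
a = torch.randn(B, T, 2 * 384)                   # token representation, width 2*token_s
r_update = decoder(a, q, c, p, feats, to_keys)   # (B, N, 3) position updates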