boltz-vsynthes 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (112)
  1. boltz/__init__.py +7 -0
  2. boltz/data/__init__.py +0 -0
  3. boltz/data/const.py +1184 -0
  4. boltz/data/crop/__init__.py +0 -0
  5. boltz/data/crop/affinity.py +164 -0
  6. boltz/data/crop/boltz.py +296 -0
  7. boltz/data/crop/cropper.py +45 -0
  8. boltz/data/feature/__init__.py +0 -0
  9. boltz/data/feature/featurizer.py +1230 -0
  10. boltz/data/feature/featurizerv2.py +2208 -0
  11. boltz/data/feature/symmetry.py +602 -0
  12. boltz/data/filter/__init__.py +0 -0
  13. boltz/data/filter/dynamic/__init__.py +0 -0
  14. boltz/data/filter/dynamic/date.py +76 -0
  15. boltz/data/filter/dynamic/filter.py +24 -0
  16. boltz/data/filter/dynamic/max_residues.py +37 -0
  17. boltz/data/filter/dynamic/resolution.py +34 -0
  18. boltz/data/filter/dynamic/size.py +38 -0
  19. boltz/data/filter/dynamic/subset.py +42 -0
  20. boltz/data/filter/static/__init__.py +0 -0
  21. boltz/data/filter/static/filter.py +26 -0
  22. boltz/data/filter/static/ligand.py +37 -0
  23. boltz/data/filter/static/polymer.py +299 -0
  24. boltz/data/module/__init__.py +0 -0
  25. boltz/data/module/inference.py +307 -0
  26. boltz/data/module/inferencev2.py +429 -0
  27. boltz/data/module/training.py +684 -0
  28. boltz/data/module/trainingv2.py +660 -0
  29. boltz/data/mol.py +900 -0
  30. boltz/data/msa/__init__.py +0 -0
  31. boltz/data/msa/mmseqs2.py +235 -0
  32. boltz/data/pad.py +84 -0
  33. boltz/data/parse/__init__.py +0 -0
  34. boltz/data/parse/a3m.py +134 -0
  35. boltz/data/parse/csv.py +100 -0
  36. boltz/data/parse/fasta.py +138 -0
  37. boltz/data/parse/mmcif.py +1239 -0
  38. boltz/data/parse/mmcif_with_constraints.py +1607 -0
  39. boltz/data/parse/schema.py +1851 -0
  40. boltz/data/parse/yaml.py +68 -0
  41. boltz/data/sample/__init__.py +0 -0
  42. boltz/data/sample/cluster.py +283 -0
  43. boltz/data/sample/distillation.py +57 -0
  44. boltz/data/sample/random.py +39 -0
  45. boltz/data/sample/sampler.py +49 -0
  46. boltz/data/tokenize/__init__.py +0 -0
  47. boltz/data/tokenize/boltz.py +195 -0
  48. boltz/data/tokenize/boltz2.py +396 -0
  49. boltz/data/tokenize/tokenizer.py +24 -0
  50. boltz/data/types.py +777 -0
  51. boltz/data/write/__init__.py +0 -0
  52. boltz/data/write/mmcif.py +305 -0
  53. boltz/data/write/pdb.py +171 -0
  54. boltz/data/write/utils.py +23 -0
  55. boltz/data/write/writer.py +330 -0
  56. boltz/main.py +1292 -0
  57. boltz/model/__init__.py +0 -0
  58. boltz/model/layers/__init__.py +0 -0
  59. boltz/model/layers/attention.py +132 -0
  60. boltz/model/layers/attentionv2.py +111 -0
  61. boltz/model/layers/confidence_utils.py +231 -0
  62. boltz/model/layers/dropout.py +34 -0
  63. boltz/model/layers/initialize.py +100 -0
  64. boltz/model/layers/outer_product_mean.py +98 -0
  65. boltz/model/layers/pair_averaging.py +135 -0
  66. boltz/model/layers/pairformer.py +337 -0
  67. boltz/model/layers/relative.py +58 -0
  68. boltz/model/layers/transition.py +78 -0
  69. boltz/model/layers/triangular_attention/__init__.py +0 -0
  70. boltz/model/layers/triangular_attention/attention.py +189 -0
  71. boltz/model/layers/triangular_attention/primitives.py +409 -0
  72. boltz/model/layers/triangular_attention/utils.py +380 -0
  73. boltz/model/layers/triangular_mult.py +212 -0
  74. boltz/model/loss/__init__.py +0 -0
  75. boltz/model/loss/bfactor.py +49 -0
  76. boltz/model/loss/confidence.py +590 -0
  77. boltz/model/loss/confidencev2.py +621 -0
  78. boltz/model/loss/diffusion.py +171 -0
  79. boltz/model/loss/diffusionv2.py +134 -0
  80. boltz/model/loss/distogram.py +48 -0
  81. boltz/model/loss/distogramv2.py +105 -0
  82. boltz/model/loss/validation.py +1025 -0
  83. boltz/model/models/__init__.py +0 -0
  84. boltz/model/models/boltz1.py +1286 -0
  85. boltz/model/models/boltz2.py +1249 -0
  86. boltz/model/modules/__init__.py +0 -0
  87. boltz/model/modules/affinity.py +223 -0
  88. boltz/model/modules/confidence.py +481 -0
  89. boltz/model/modules/confidence_utils.py +181 -0
  90. boltz/model/modules/confidencev2.py +495 -0
  91. boltz/model/modules/diffusion.py +844 -0
  92. boltz/model/modules/diffusion_conditioning.py +116 -0
  93. boltz/model/modules/diffusionv2.py +677 -0
  94. boltz/model/modules/encoders.py +639 -0
  95. boltz/model/modules/encodersv2.py +565 -0
  96. boltz/model/modules/transformers.py +322 -0
  97. boltz/model/modules/transformersv2.py +261 -0
  98. boltz/model/modules/trunk.py +688 -0
  99. boltz/model/modules/trunkv2.py +828 -0
  100. boltz/model/modules/utils.py +303 -0
  101. boltz/model/optim/__init__.py +0 -0
  102. boltz/model/optim/ema.py +389 -0
  103. boltz/model/optim/scheduler.py +99 -0
  104. boltz/model/potentials/__init__.py +0 -0
  105. boltz/model/potentials/potentials.py +497 -0
  106. boltz/model/potentials/schedules.py +32 -0
  107. boltz_vsynthes-1.0.0.dist-info/METADATA +151 -0
  108. boltz_vsynthes-1.0.0.dist-info/RECORD +112 -0
  109. boltz_vsynthes-1.0.0.dist-info/WHEEL +5 -0
  110. boltz_vsynthes-1.0.0.dist-info/entry_points.txt +2 -0
  111. boltz_vsynthes-1.0.0.dist-info/licenses/LICENSE +21 -0
  112. boltz_vsynthes-1.0.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,565 @@
1
+ # started from code from https://github.com/lucidrains/alphafold3-pytorch, MIT License, Copyright (c) 2024 Phil Wang
2
+ from functools import partial
3
+ from math import pi
4
+
5
+ import torch
6
+ from einops import rearrange
7
+ from torch import nn
8
+ from torch.nn import Linear, Module, ModuleList
9
+ from torch.nn.functional import one_hot
10
+
11
+ import boltz.model.layers.initialize as init
12
+ from boltz.model.layers.transition import Transition
13
+ from boltz.model.modules.transformersv2 import AtomTransformer
14
+ from boltz.model.modules.utils import LinearNoBias
15
+
16
+
17
class FourierEmbedding(Module):
    """Random Fourier embedding of scalar times (Algorithm 22).

    Each scalar is projected by a frozen ``Linear(1, dim)`` whose weight and
    bias are both drawn from N(0, 1), and the result is passed through
    ``cos(2 * pi * x)``.
    """

    def __init__(self, dim):
        super().__init__()
        projection = nn.Linear(1, dim)
        nn.init.normal_(projection.weight, mean=0, std=1)
        nn.init.normal_(projection.bias, mean=0, std=1)
        # The projection is a fixed random basis: it is never trained.
        projection.requires_grad_(False)
        self.proj = projection

    def forward(
        self,
        times,  # Float[' b'],
    ):  # -> Float['b d']:
        """Map a 1-D tensor of ``b`` times to a ``(b, dim)`` embedding."""
        return torch.cos(2 * pi * self.proj(rearrange(times, "b -> b 1")))
34
+
35
+
36
class RelativePositionEncoder(Module):
    """Relative position encoding of token pairs (Algorithm 3).

    For every token pair (i, j) the following features are concatenated and
    projected to ``token_z`` channels by a bias-free linear layer:

    * relative residue index clipped to [-r_max, r_max] (``2 * r_max + 2``
      one-hot classes; the last class marks pairs on different chains);
    * relative token index clipped the same way (the last class marks pairs
      that are not on the same chain *and* in the same residue);
    * a single flag for "same ``entity_id``";
    * relative ``sym_id`` clipped to [-s_max, s_max] (``2 * s_max + 2``
      classes; the last class marks same-chain pairs, or pairs from different
      entities when ``fix_sym_check`` is set).

    With ``cyclic_pos_enc``, residue offsets involving tokens with a positive
    ``cyclic_period`` are wrapped to the nearest periodic image,
    ``d - P * round(d / P)``.
    """

    def __init__(
        self, token_z, r_max=32, s_max=2, fix_sym_check=False, cyclic_pos_enc=False
    ):
        super().__init__()
        self.r_max = r_max
        self.s_max = s_max
        self.linear_layer = LinearNoBias(4 * (r_max + 1) + 2 * (s_max + 1) + 1, token_z)
        self.fix_sym_check = fix_sym_check
        self.cyclic_pos_enc = cyclic_pos_enc

    def forward(self, feats):
        """Return projected pair features of shape ``(b, n, n, token_z)``.

        ``feats`` must provide integer tensors ``asym_id``, ``residue_index``,
        ``entity_id``, ``token_index`` and ``sym_id`` indexed ``[b, n]``. When
        ``cyclic_pos_enc`` is enabled it must also provide ``cyclic_period``
        (assumed to be per token, ``[b, n]``).
        """
        b_same_chain = torch.eq(
            feats["asym_id"][:, :, None], feats["asym_id"][:, None, :]
        )
        b_same_residue = torch.eq(
            feats["residue_index"][:, :, None], feats["residue_index"][:, None, :]
        )
        b_same_entity = torch.eq(
            feats["entity_id"][:, :, None], feats["entity_id"][:, None, :]
        )

        d_residue = (
            feats["residue_index"][:, :, None] - feats["residue_index"][:, None, :]
        )

        if self.cyclic_pos_enc and torch.any(feats["cyclic_period"] > 0):
            # Non-cyclic tokens get a huge period so their offsets are unchanged.
            period = torch.where(
                feats["cyclic_period"] > 0,
                feats["cyclic_period"],
                torch.zeros_like(feats["cyclic_period"]) + 10000,
            )
            # ``period`` is per token [b, n] while ``d_residue`` is [b, n, n].
            # Insert the query axis explicitly: implicit broadcasting lines
            # [b, n] up with the trailing [n, n] axes, which only works for
            # b == 1. Indexing by the key token j reproduces the b == 1 result.
            period = period[:, None, :]
            d_residue = (d_residue - period * torch.round(d_residue / period)).long()

        d_residue = torch.clip(
            d_residue + self.r_max,
            0,
            2 * self.r_max,
        )
        # Pairs on different chains go to the extra bucket 2 * r_max + 1.
        d_residue = torch.where(
            b_same_chain, d_residue, torch.zeros_like(d_residue) + 2 * self.r_max + 1
        )
        a_rel_pos = one_hot(d_residue, 2 * self.r_max + 2)

        d_token = torch.clip(
            feats["token_index"][:, :, None]
            - feats["token_index"][:, None, :]
            + self.r_max,
            0,
            2 * self.r_max,
        )
        # Token offsets are only informative within the same chain and residue.
        d_token = torch.where(
            b_same_chain & b_same_residue,
            d_token,
            torch.zeros_like(d_token) + 2 * self.r_max + 1,
        )
        a_rel_token = one_hot(d_token, 2 * self.r_max + 2)

        d_chain = torch.clip(
            feats["sym_id"][:, :, None] - feats["sym_id"][:, None, :] + self.s_max,
            0,
            2 * self.s_max,
        )
        # Note: added | (~b_same_entity) based on observation of ProteinX manuscript
        d_chain = torch.where(
            (~b_same_entity) if self.fix_sym_check else b_same_chain,
            torch.zeros_like(d_chain) + 2 * self.s_max + 1,
            d_chain,
        )
        a_rel_chain = one_hot(d_chain, 2 * self.s_max + 2)

        p = self.linear_layer(
            torch.cat(
                [
                    a_rel_pos.float(),
                    a_rel_token.float(),
                    b_same_entity.unsqueeze(-1).float(),
                    a_rel_chain.float(),
                ],
                dim=-1,
            )
        )
        return p
121
+
122
+
123
class SingleConditioning(Module):
    """Single-representation conditioning for the diffusion module (Algorithm 21).

    Concatenates the trunk and input single representations, normalises and
    projects them, optionally adds a Fourier embedding of the noise level, and
    then applies ``num_transitions`` residual ``Transition`` blocks.
    """

    def __init__(
        self,
        sigma_data: float,
        token_s: int = 384,
        dim_fourier: int = 256,
        num_transitions: int = 2,
        transition_expansion_factor: int = 2,
        eps: float = 1e-20,
        disable_times: bool = False,
    ) -> None:
        super().__init__()
        self.eps = eps
        self.sigma_data = sigma_data
        self.disable_times = disable_times

        # forward concatenates s_trunk and s_inputs, giving 2 * token_s channels
        width = 2 * token_s
        self.norm_single = nn.LayerNorm(width)
        self.single_embed = nn.Linear(width, width)
        if not disable_times:
            self.fourier_embed = FourierEmbedding(dim_fourier)
            self.norm_fourier = nn.LayerNorm(dim_fourier)
            self.fourier_to_single = LinearNoBias(dim_fourier, width)

        self.transitions = ModuleList(
            [
                Transition(dim=width, hidden=transition_expansion_factor * width)
                for _ in range(num_transitions)
            ]
        )

    def forward(
        self,
        times,  # Float[' b'],
        s_trunk,  # Float['b n ts'],
        s_inputs,  # Float['b n ts'],
    ):  # -> Float['b n 2ts']:
        """Return ``(s, normed_fourier)``.

        ``normed_fourier`` is ``None`` when ``disable_times`` is set.
        """
        s = self.single_embed(self.norm_single(torch.cat((s_trunk, s_inputs), dim=-1)))

        normed_fourier = None
        if not self.disable_times:
            # note: sigma rescaling of `times` is done in the diffusion module
            normed_fourier = self.norm_fourier(self.fourier_embed(times))
            s = s + rearrange(self.fourier_to_single(normed_fourier), "b d -> b 1 d")

        for block in self.transitions:
            s = s + block(s)

        return s, normed_fourier
178
+
179
+
180
class PairwiseConditioning(Module):
    """Pair-representation conditioning for the diffusion module (Algorithm 21).

    Concatenates the trunk pair representation with relative-position
    features, normalises and projects the result back to ``token_z`` channels,
    and then applies ``num_transitions`` residual ``Transition`` blocks.
    """

    def __init__(
        self,
        token_z,
        dim_token_rel_pos_feats,
        num_transitions=2,
        transition_expansion_factor=2,
    ):
        super().__init__()

        in_dim = token_z + dim_token_rel_pos_feats
        self.dim_pairwise_init_proj = nn.Sequential(
            nn.LayerNorm(in_dim),
            LinearNoBias(in_dim, token_z),
        )

        self.transitions = ModuleList(
            [
                Transition(dim=token_z, hidden=transition_expansion_factor * token_z)
                for _ in range(num_transitions)
            ]
        )

    def forward(
        self,
        z_trunk,  # Float['b n n tz'],
        token_rel_pos_feats,  # Float['b n n dim_token_rel_pos_feats'],
    ):  # -> Float['b n n tz']:
        """Return the conditioned pair representation ``(b, n, n, token_z)``."""
        z = self.dim_pairwise_init_proj(
            torch.cat((z_trunk, token_rel_pos_feats), dim=-1)
        )
        for block in self.transitions:
            z = z + block(z)
        return z
218
+
219
+
220
def get_indexing_matrix(K, W, H, device):
    """Build the 0/1 gather matrix that maps half-window chunks to key windows.

    The atoms are viewed as ``2 * K`` chunks of ``W // 2`` atoms each. For
    each of the ``K`` query windows, the keys consist of ``h = H // (W // 2)``
    consecutive chunks centred on that window. Chunks that would fall before
    the first or after the last chunk have an all-zero column, so they gather
    zeros (padding).

    Returns:
        A float tensor of shape ``(2 * K, h * K)`` on ``device``.
    """
    assert W % 2 == 0
    half_window = W // 2
    assert H % half_window == 0

    h = H // half_window
    assert h % 2 == 0

    chunk_ids = torch.arange(2 * K, device=device)
    # offset of key chunk (column) relative to each window's first chunk (row),
    # shifted by h // 2 and clamped into the padding classes 0 and h + 1
    offsets = (chunk_ids[None, :] - chunk_ids[:, None] + h // 2).clamp(
        min=0, max=h + 1
    )
    # one row per query window: windows start at every second chunk
    offsets = offsets[::2]
    # drop the two padding classes, keeping the h real key positions
    selector = one_hot(offsets, num_classes=h + 2)[..., 1:-1]
    return selector.permute(1, 0, 2).reshape(2 * K, h * K).float()
234
+
235
+
236
def single_to_keys(single, indexing_matrix, W, H):
    """Gather per-window key features from a per-atom tensor.

    ``single`` of shape ``(B, N, D)`` is split into ``2 * K`` chunks of
    ``W // 2`` atoms (with ``K = N // W``). The chunks are recombined through
    ``indexing_matrix`` (shape ``(2 * K, h * K)``, as produced by
    ``get_indexing_matrix``) into ``K`` windows of ``H`` key atoms.

    Returns:
        A tensor of shape ``(B, K, H, D)``.
    """
    batch, n_atoms, dim = single.shape
    n_windows = n_atoms // W
    chunks = single.view(batch, 2 * n_windows, W // 2, dim)
    # j = 2K chunks, i = W // 2 atoms per chunk, k = h * K key chunks
    gathered = torch.einsum("bjid,jk->bkid", chunks, indexing_matrix)
    return gathered.reshape(batch, n_windows, H, dim)
243
+
244
+
245
class AtomEncoder(Module):
    """Embed per-atom reference features into atom single and windowed pair reps.

    Atoms are grouped into windows of ``atoms_per_window_queries`` query atoms.
    Each window attends to ``atoms_per_window_keys`` key atoms, which are
    gathered with ``get_indexing_matrix`` / ``single_to_keys``. When
    ``structure_prediction`` is set, the trunk single (``s_trunk``) and pair
    (``z``) representations are also projected onto the atoms.
    """

    def __init__(
        self,
        atom_s,
        atom_z,
        token_s,
        token_z,
        atoms_per_window_queries,
        atoms_per_window_keys,
        atom_feature_dim,
        structure_prediction=True,
        use_no_atom_char=False,
        use_atom_backbone_feat=False,
        use_residue_feats_atoms=False,
    ):
        super().__init__()

        # per-atom embedding and the three per-pair embeddings
        self.embed_atom_features = Linear(atom_feature_dim, atom_s)
        self.embed_atompair_ref_pos = LinearNoBias(3, atom_z)
        self.embed_atompair_ref_dist = LinearNoBias(1, atom_z)
        self.embed_atompair_mask = LinearNoBias(1, atom_z)

        self.atoms_per_window_queries = atoms_per_window_queries
        self.atoms_per_window_keys = atoms_per_window_keys
        self.use_no_atom_char = use_no_atom_char
        self.use_atom_backbone_feat = use_atom_backbone_feat
        self.use_residue_feats_atoms = use_residue_feats_atoms
        self.structure_prediction = structure_prediction

        if structure_prediction:
            # trunk single -> atom single and trunk pair -> atom pair projections
            self.s_to_c_trans = nn.Sequential(
                nn.LayerNorm(token_s), LinearNoBias(token_s, atom_s)
            )
            init.final_init_(self.s_to_c_trans[1].weight)

            self.z_to_p_trans = nn.Sequential(
                nn.LayerNorm(token_z), LinearNoBias(token_z, atom_z)
            )
            init.final_init_(self.z_to_p_trans[1].weight)

        # atom single -> atom pair projections, keys first, then queries
        for attr_name in ("c_to_p_trans_k", "c_to_p_trans_q"):
            projection = nn.Sequential(nn.ReLU(), LinearNoBias(atom_s, atom_z))
            init.final_init_(projection[1].weight)
            setattr(self, attr_name, projection)

        self.p_mlp = nn.Sequential(
            nn.ReLU(),
            LinearNoBias(atom_z, atom_z),
            nn.ReLU(),
            LinearNoBias(atom_z, atom_z),
            nn.ReLU(),
            LinearNoBias(atom_z, atom_z),
        )
        init.final_init_(self.p_mlp[5].weight)

    def _gather_atom_features(self, feats, ref_pos, batch, n_atoms):
        """Concatenate the per-atom input features enabled by the constructor flags."""
        parts = [
            ref_pos,
            feats["ref_charge"].unsqueeze(-1),
            feats["ref_element"],
        ]
        if not self.use_no_atom_char:
            parts.append(feats["ref_atom_name_chars"].reshape(batch, n_atoms, 4 * 64))
        if self.use_atom_backbone_feat:
            parts.append(feats["atom_backbone_feat"])
        if self.use_residue_feats_atoms:
            residue_feats = torch.cat(
                [
                    feats["res_type"],
                    feats["modified"].unsqueeze(-1),
                    one_hot(feats["mol_type"], num_classes=4).float(),
                ],
                dim=-1,
            )
            # broadcast token-level residue features onto their atoms
            parts.append(torch.bmm(feats["atom_to_token"].float(), residue_feats))
        return torch.cat(parts, dim=-1)

    def forward(
        self,
        feats,
        s_trunk=None,  # Float['bm n ts'],
        z=None,  # Float['bm n n tz'],
    ):
        """Encode atoms.

        ``s_trunk`` and ``z`` are only read when ``structure_prediction`` is set.

        Returns:
            ``(q, c, p, to_keys)``. ``q`` is the atom embedding before trunk
            conditioning. ``c`` is the (possibly trunk-conditioned) atom single
            representation. ``p`` is the windowed atom pair representation of
            shape ``(b, K, W, H, atom_z)``. ``to_keys`` is the window-to-keys
            gather function.
        """
        with torch.autocast("cuda", enabled=False):
            ref_pos = feats["ref_pos"]  # Float['b m 3'],
            batch, n_atoms, _ = ref_pos.shape
            pad_mask = feats["atom_pad_mask"].bool()  # Bool['b m'],
            space_uid = feats["ref_space_uid"]  # Long['b m'],

            c = self.embed_atom_features(
                self._gather_atom_features(feats, ref_pos, batch, n_atoms)
            )

            # windows are built up front for efficiency: ``win_q`` query atoms
            # per window, each attending to ``win_k`` key atoms
            win_q, win_k = self.atoms_per_window_queries, self.atoms_per_window_keys
            n_windows = n_atoms // win_q
            to_keys = partial(
                single_to_keys,
                indexing_matrix=get_indexing_matrix(n_windows, win_q, win_k, c.device),
                W=win_q,
                H=win_k,
            )

            pos_queries = ref_pos.view(batch, n_windows, win_q, 1, 3)
            pos_keys = to_keys(ref_pos).view(batch, n_windows, 1, win_k, 3)
            offsets = pos_keys - pos_queries  # Float['b k w h 3']
            # AF3 feeds in 1 / (1 + squared distance) rather than the distance
            inv_dist = 1 / (
                1 + (offsets * offsets).sum(dim=-1, keepdim=True)
            )  # Float['b k w h 1']

            # a pair is valid when both atoms are unpadded and share a ref_space_uid
            mask_queries = pad_mask.view(batch, n_windows, win_q, 1)
            mask_keys = (
                to_keys(pad_mask.unsqueeze(-1).float())
                .view(batch, n_windows, 1, win_k)
                .bool()
            )
            uid_queries = space_uid.view(batch, n_windows, win_q, 1)
            uid_keys = (
                to_keys(space_uid.unsqueeze(-1).float())
                .view(batch, n_windows, 1, win_k)
                .long()
            )
            valid = (
                (mask_queries & mask_keys & (uid_queries == uid_keys))
                .float()
                .unsqueeze(-1)
            )  # Float['b k w h 1']

            p = self.embed_atompair_ref_pos(offsets) * valid
            p = p + self.embed_atompair_ref_dist(inv_dist) * valid
            p = p + self.embed_atompair_mask(valid) * valid

            # q keeps the embedding from before any trunk conditioning is added to c
            q = c

            if self.structure_prediction:
                # trunk conditioning runs only in the structure model, not in
                # the initial encoding
                atom_to_token = feats["atom_to_token"].float()  # Float['b m n'],

                c = c + torch.bmm(
                    atom_to_token, self.s_to_c_trans(s_trunk.float())
                ).to(c)

                a2t_queries = atom_to_token.view(
                    batch, n_windows, win_q, atom_to_token.shape[-1]
                )
                a2t_keys = to_keys(atom_to_token)
                z_to_p = torch.einsum(
                    "bijd,bwki,bwlj->bwkld",
                    self.z_to_p_trans(z.float()),
                    a2t_queries,
                    a2t_keys,
                )
                p = p + z_to_p.to(p)

            atom_dim = c.shape[-1]
            p = p + self.c_to_p_trans_q(c.view(batch, n_windows, win_q, 1, atom_dim))
            p = p + self.c_to_p_trans_k(
                to_keys(c).view(batch, n_windows, 1, win_k, atom_dim)
            )
            p = p + self.p_mlp(p)
        return q, c, p, to_keys
412
+
413
+
414
class AtomAttentionEncoder(Module):
    """Atom transformer encoder that pools atom features onto tokens.

    Runs an ``AtomTransformer`` over the windowed atom representation and
    averages the projected atom features per token. When
    ``structure_prediction`` is set, atom coordinates ``r`` are also injected
    into the atom queries, and the token output has ``2 * token_s`` channels
    (``token_s`` otherwise).
    """

    def __init__(
        self,
        atom_s,
        token_s,
        atoms_per_window_queries,
        atoms_per_window_keys,
        atom_encoder_depth=3,
        atom_encoder_heads=4,
        structure_prediction=True,
        activation_checkpointing=False,
        transformer_post_layer_norm=False,
    ):
        super().__init__()

        self.structure_prediction = structure_prediction
        if structure_prediction:
            self.r_to_q_trans = LinearNoBias(3, atom_s)
            init.final_init_(self.r_to_q_trans.weight)

        self.atom_encoder = AtomTransformer(
            dim=atom_s,
            dim_single_cond=atom_s,
            attn_window_queries=atoms_per_window_queries,
            attn_window_keys=atoms_per_window_keys,
            depth=atom_encoder_depth,
            heads=atom_encoder_heads,
            activation_checkpointing=activation_checkpointing,
            post_layer_norm=transformer_post_layer_norm,
        )

        self.atom_to_token_trans = nn.Sequential(
            LinearNoBias(atom_s, 2 * token_s if structure_prediction else token_s),
            nn.ReLU(),
        )

    def forward(
        self,
        feats,
        q,
        c,
        atom_enc_bias,
        to_keys,
        r=None,  # Float['bm m 3'],
        multiplicity=1,
    ):
        """Encode atoms and aggregate them onto tokens.

        Args:
            feats: feature dict; reads ``atom_pad_mask`` and ``atom_to_token``.
            q: atom query features with batch size ``b``.
            c: atom conditioning features with batch size ``b``.
            atom_enc_bias: pair bias forwarded to the atom transformer.
            to_keys: window-to-keys gather function forwarded to the transformer.
            r: atom coordinates with batch ``b * multiplicity``; required when
                ``structure_prediction`` is set.
            multiplicity: samples per input. ``c``, the atom mask and
                ``atom_to_token`` are always repeated this many times along
                dim 0; ``q`` is repeated only when ``structure_prediction`` is set.

        Returns:
            ``(a, q, c, to_keys)``, where ``a`` is the per-token mean of the
            projected atom features, ``q`` the encoded atom queries and ``c``
            the repeated conditioning.
        """
        atom_mask = feats["atom_pad_mask"].bool()  # Bool['b m'],

        if self.structure_prediction:
            # only here the multiplicity kicks in because we use the different positions r
            q = q.repeat_interleave(multiplicity, 0)
            q = q + self.r_to_q_trans(r)

        c = c.repeat_interleave(multiplicity, 0)
        atom_mask = atom_mask.repeat_interleave(multiplicity, 0)

        q = self.atom_encoder(
            q=q,
            mask=atom_mask,
            c=c,
            bias=atom_enc_bias,
            multiplicity=multiplicity,
            to_keys=to_keys,
        )

        with torch.autocast("cuda", enabled=False):
            q_to_a = self.atom_to_token_trans(q).float()
            atom_to_token = feats["atom_to_token"].float()
            atom_to_token = atom_to_token.repeat_interleave(multiplicity, 0)
            # Divide each token column by its column sum (the atom count when
            # atom_to_token is one-hot) to mean-pool; the epsilon keeps tokens
            # with no atoms at zero instead of dividing by zero.
            atom_to_token_mean = atom_to_token / (
                atom_to_token.sum(dim=1, keepdim=True) + 1e-6
            )
            a = torch.bmm(atom_to_token_mean.transpose(1, 2), q_to_a)

        a = a.to(q)

        return a, q, c, to_keys
493
+
494
+
495
class AtomAttentionDecoder(Module):
    """Broadcast token activations back to atoms and predict position updates (Algorithm 6).

    Token activations ``a`` (``2 * token_s`` channels) are projected to atom
    width, added to each atom's query through ``atom_to_token``, refined by an
    ``AtomTransformer``, and mapped to a 3-D update per atom.
    """

    def __init__(
        self,
        atom_s,
        token_s,
        attn_window_queries,
        attn_window_keys,
        atom_decoder_depth=3,
        atom_decoder_heads=4,
        activation_checkpointing=False,
        transformer_post_layer_norm=False,
    ):
        super().__init__()

        self.a_to_q_trans = LinearNoBias(2 * token_s, atom_s)
        init.final_init_(self.a_to_q_trans.weight)

        self.atom_decoder = AtomTransformer(
            dim=atom_s,
            dim_single_cond=atom_s,
            attn_window_queries=attn_window_queries,
            attn_window_keys=attn_window_keys,
            depth=atom_decoder_depth,
            heads=atom_decoder_heads,
            activation_checkpointing=activation_checkpointing,
            post_layer_norm=transformer_post_layer_norm,
        )

        # NOTE(review): presumably the transformer's post layer norm makes an
        # extra LayerNorm redundant in that configuration; confirm.
        if transformer_post_layer_norm:
            self.atom_feat_to_atom_pos_update = LinearNoBias(atom_s, 3)
            output_linear = self.atom_feat_to_atom_pos_update
        else:
            self.atom_feat_to_atom_pos_update = nn.Sequential(
                nn.LayerNorm(atom_s), LinearNoBias(atom_s, 3)
            )
            output_linear = self.atom_feat_to_atom_pos_update[1]
        init.final_init_(output_linear.weight)

    def forward(
        self,
        a,  # Float['bm n 2ts'],
        q,  # Float['bm m as'],
        c,  # Float['bm m as'],
        atom_dec_bias,  # Float['bm m m az'],
        feats,
        to_keys,
        multiplicity=1,
    ):
        """Return per-atom coordinate updates (last dimension 3).

        ``atom_to_token`` and ``atom_pad_mask`` from ``feats`` are repeated
        ``multiplicity`` times along dim 0 to match ``a`` and ``q``.
        """
        with torch.autocast("cuda", enabled=False):
            token_to_atom = feats["atom_to_token"].float().repeat_interleave(
                multiplicity, 0
            )
            # copy each token's projected activation onto its atoms
            broadcast = torch.bmm(token_to_atom, self.a_to_q_trans(a.float()))

        q = q + broadcast.to(q)
        atom_mask = feats["atom_pad_mask"].repeat_interleave(multiplicity, 0)  # Bool['b m'],

        q = self.atom_decoder(
            q=q,
            mask=atom_mask,
            c=c,
            bias=atom_dec_bias,
            multiplicity=multiplicity,
            to_keys=to_keys,
        )

        return self.atom_feat_to_atom_pos_update(q)