boltz_vsynthes-1.0.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (112)
  1. boltz/__init__.py +7 -0
  2. boltz/data/__init__.py +0 -0
  3. boltz/data/const.py +1184 -0
  4. boltz/data/crop/__init__.py +0 -0
  5. boltz/data/crop/affinity.py +164 -0
  6. boltz/data/crop/boltz.py +296 -0
  7. boltz/data/crop/cropper.py +45 -0
  8. boltz/data/feature/__init__.py +0 -0
  9. boltz/data/feature/featurizer.py +1230 -0
  10. boltz/data/feature/featurizerv2.py +2208 -0
  11. boltz/data/feature/symmetry.py +602 -0
  12. boltz/data/filter/__init__.py +0 -0
  13. boltz/data/filter/dynamic/__init__.py +0 -0
  14. boltz/data/filter/dynamic/date.py +76 -0
  15. boltz/data/filter/dynamic/filter.py +24 -0
  16. boltz/data/filter/dynamic/max_residues.py +37 -0
  17. boltz/data/filter/dynamic/resolution.py +34 -0
  18. boltz/data/filter/dynamic/size.py +38 -0
  19. boltz/data/filter/dynamic/subset.py +42 -0
  20. boltz/data/filter/static/__init__.py +0 -0
  21. boltz/data/filter/static/filter.py +26 -0
  22. boltz/data/filter/static/ligand.py +37 -0
  23. boltz/data/filter/static/polymer.py +299 -0
  24. boltz/data/module/__init__.py +0 -0
  25. boltz/data/module/inference.py +307 -0
  26. boltz/data/module/inferencev2.py +429 -0
  27. boltz/data/module/training.py +684 -0
  28. boltz/data/module/trainingv2.py +660 -0
  29. boltz/data/mol.py +900 -0
  30. boltz/data/msa/__init__.py +0 -0
  31. boltz/data/msa/mmseqs2.py +235 -0
  32. boltz/data/pad.py +84 -0
  33. boltz/data/parse/__init__.py +0 -0
  34. boltz/data/parse/a3m.py +134 -0
  35. boltz/data/parse/csv.py +100 -0
  36. boltz/data/parse/fasta.py +138 -0
  37. boltz/data/parse/mmcif.py +1239 -0
  38. boltz/data/parse/mmcif_with_constraints.py +1607 -0
  39. boltz/data/parse/schema.py +1851 -0
  40. boltz/data/parse/yaml.py +68 -0
  41. boltz/data/sample/__init__.py +0 -0
  42. boltz/data/sample/cluster.py +283 -0
  43. boltz/data/sample/distillation.py +57 -0
  44. boltz/data/sample/random.py +39 -0
  45. boltz/data/sample/sampler.py +49 -0
  46. boltz/data/tokenize/__init__.py +0 -0
  47. boltz/data/tokenize/boltz.py +195 -0
  48. boltz/data/tokenize/boltz2.py +396 -0
  49. boltz/data/tokenize/tokenizer.py +24 -0
  50. boltz/data/types.py +777 -0
  51. boltz/data/write/__init__.py +0 -0
  52. boltz/data/write/mmcif.py +305 -0
  53. boltz/data/write/pdb.py +171 -0
  54. boltz/data/write/utils.py +23 -0
  55. boltz/data/write/writer.py +330 -0
  56. boltz/main.py +1292 -0
  57. boltz/model/__init__.py +0 -0
  58. boltz/model/layers/__init__.py +0 -0
  59. boltz/model/layers/attention.py +132 -0
  60. boltz/model/layers/attentionv2.py +111 -0
  61. boltz/model/layers/confidence_utils.py +231 -0
  62. boltz/model/layers/dropout.py +34 -0
  63. boltz/model/layers/initialize.py +100 -0
  64. boltz/model/layers/outer_product_mean.py +98 -0
  65. boltz/model/layers/pair_averaging.py +135 -0
  66. boltz/model/layers/pairformer.py +337 -0
  67. boltz/model/layers/relative.py +58 -0
  68. boltz/model/layers/transition.py +78 -0
  69. boltz/model/layers/triangular_attention/__init__.py +0 -0
  70. boltz/model/layers/triangular_attention/attention.py +189 -0
  71. boltz/model/layers/triangular_attention/primitives.py +409 -0
  72. boltz/model/layers/triangular_attention/utils.py +380 -0
  73. boltz/model/layers/triangular_mult.py +212 -0
  74. boltz/model/loss/__init__.py +0 -0
  75. boltz/model/loss/bfactor.py +49 -0
  76. boltz/model/loss/confidence.py +590 -0
  77. boltz/model/loss/confidencev2.py +621 -0
  78. boltz/model/loss/diffusion.py +171 -0
  79. boltz/model/loss/diffusionv2.py +134 -0
  80. boltz/model/loss/distogram.py +48 -0
  81. boltz/model/loss/distogramv2.py +105 -0
  82. boltz/model/loss/validation.py +1025 -0
  83. boltz/model/models/__init__.py +0 -0
  84. boltz/model/models/boltz1.py +1286 -0
  85. boltz/model/models/boltz2.py +1249 -0
  86. boltz/model/modules/__init__.py +0 -0
  87. boltz/model/modules/affinity.py +223 -0
  88. boltz/model/modules/confidence.py +481 -0
  89. boltz/model/modules/confidence_utils.py +181 -0
  90. boltz/model/modules/confidencev2.py +495 -0
  91. boltz/model/modules/diffusion.py +844 -0
  92. boltz/model/modules/diffusion_conditioning.py +116 -0
  93. boltz/model/modules/diffusionv2.py +677 -0
  94. boltz/model/modules/encoders.py +639 -0
  95. boltz/model/modules/encodersv2.py +565 -0
  96. boltz/model/modules/transformers.py +322 -0
  97. boltz/model/modules/transformersv2.py +261 -0
  98. boltz/model/modules/trunk.py +688 -0
  99. boltz/model/modules/trunkv2.py +828 -0
  100. boltz/model/modules/utils.py +303 -0
  101. boltz/model/optim/__init__.py +0 -0
  102. boltz/model/optim/ema.py +389 -0
  103. boltz/model/optim/scheduler.py +99 -0
  104. boltz/model/potentials/__init__.py +0 -0
  105. boltz/model/potentials/potentials.py +497 -0
  106. boltz/model/potentials/schedules.py +32 -0
  107. boltz_vsynthes-1.0.0.dist-info/METADATA +151 -0
  108. boltz_vsynthes-1.0.0.dist-info/RECORD +112 -0
  109. boltz_vsynthes-1.0.0.dist-info/WHEEL +5 -0
  110. boltz_vsynthes-1.0.0.dist-info/entry_points.txt +2 -0
  111. boltz_vsynthes-1.0.0.dist-info/licenses/LICENSE +21 -0
  112. boltz_vsynthes-1.0.0.dist-info/top_level.txt +1 -0
boltz/model/modules/trunkv2.py (new file)
@@ -0,0 +1,828 @@
from typing import Optional

import torch
from torch import Tensor, nn
from torch.nn.functional import one_hot

from boltz.data import const
from boltz.model.layers.outer_product_mean import OuterProductMean
from boltz.model.layers.pair_averaging import PairWeightedAveraging
from boltz.model.layers.pairformer import (
    PairformerNoSeqLayer,
    PairformerNoSeqModule,
    get_dropout_mask,
)
from boltz.model.layers.transition import Transition
from boltz.model.modules.encodersv2 import (
    AtomAttentionEncoder,
    AtomEncoder,
    FourierEmbedding,
)


class ContactConditioning(nn.Module):
    def __init__(self, token_z: int, cutoff_min: float, cutoff_max: float):
        super().__init__()

        self.fourier_embedding = FourierEmbedding(token_z)
        self.encoder = nn.Linear(
            token_z + len(const.contact_conditioning_info) - 1, token_z
        )
        self.encoding_unspecified = nn.Parameter(torch.zeros(token_z))
        self.encoding_unselected = nn.Parameter(torch.zeros(token_z))
        self.cutoff_min = cutoff_min
        self.cutoff_max = cutoff_max

    def forward(self, feats):
        assert const.contact_conditioning_info["UNSPECIFIED"] == 0
        assert const.contact_conditioning_info["UNSELECTED"] == 1
        contact_conditioning = feats["contact_conditioning"][:, :, :, 2:]
        contact_threshold = feats["contact_threshold"]
        contact_threshold_normalized = (contact_threshold - self.cutoff_min) / (
            self.cutoff_max - self.cutoff_min
        )
        contact_threshold_fourier = self.fourier_embedding(
            contact_threshold_normalized.flatten()
        ).reshape(contact_threshold_normalized.shape + (-1,))

        contact_conditioning = torch.cat(
            [
                contact_conditioning,
                contact_threshold_normalized.unsqueeze(-1),
                contact_threshold_fourier,
            ],
            dim=-1,
        )
        contact_conditioning = self.encoder(contact_conditioning)

        contact_conditioning = (
            contact_conditioning
            * (
                1
                - feats["contact_conditioning"][:, :, :, 0:2].sum(dim=-1, keepdim=True)
            )
            + self.encoding_unspecified * feats["contact_conditioning"][:, :, :, 0:1]
            + self.encoding_unselected * feats["contact_conditioning"][:, :, :, 1:2]
        )
        return contact_conditioning


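# A minimal sketch of the gating used above (illustrative sizes, not the real
# feature layout): pairs flagged UNSPECIFIED or UNSELECTED are replaced by a
# learned constant embedding, while all other pairs keep the encoder output.
import torch

token_z = 4
flags = torch.tensor([[1.0, 0.0], [0.0, 0.0]]).view(1, 1, 2, 2)  # (B, N, N, 2)
encoded = torch.randn(1, 1, 2, token_z)
enc_unspecified = torch.zeros(token_z)
enc_unselected = torch.zeros(token_z)
out = (
    encoded * (1 - flags.sum(dim=-1, keepdim=True))
    + enc_unspecified * flags[..., 0:1]
    + enc_unselected * flags[..., 1:2]
)
# First pair is UNSPECIFIED, so it takes enc_unspecified; second keeps encoded.

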
class InputEmbedder(nn.Module):
    def __init__(
        self,
        atom_s: int,
        atom_z: int,
        token_s: int,
        token_z: int,
        atoms_per_window_queries: int,
        atoms_per_window_keys: int,
        atom_feature_dim: int,
        atom_encoder_depth: int,
        atom_encoder_heads: int,
        activation_checkpointing: bool = False,
        add_method_conditioning: bool = False,
        add_modified_flag: bool = False,
        add_cyclic_flag: bool = False,
        add_mol_type_feat: bool = False,
        use_no_atom_char: bool = False,
        use_atom_backbone_feat: bool = False,
        use_residue_feats_atoms: bool = False,
    ) -> None:
        """Initialize the input embedder.

        Parameters
        ----------
        atom_s : int
            The atom embedding size.
        atom_z : int
            The atom pairwise embedding size.
        token_s : int
            The token embedding size.

        """
        super().__init__()
        self.token_s = token_s
        self.add_method_conditioning = add_method_conditioning
        self.add_modified_flag = add_modified_flag
        self.add_cyclic_flag = add_cyclic_flag
        self.add_mol_type_feat = add_mol_type_feat

        self.atom_encoder = AtomEncoder(
            atom_s=atom_s,
            atom_z=atom_z,
            token_s=token_s,
            token_z=token_z,
            atoms_per_window_queries=atoms_per_window_queries,
            atoms_per_window_keys=atoms_per_window_keys,
            atom_feature_dim=atom_feature_dim,
            structure_prediction=False,
            use_no_atom_char=use_no_atom_char,
            use_atom_backbone_feat=use_atom_backbone_feat,
            use_residue_feats_atoms=use_residue_feats_atoms,
        )

        self.atom_enc_proj_z = nn.Sequential(
            nn.LayerNorm(atom_z),
            nn.Linear(atom_z, atom_encoder_depth * atom_encoder_heads, bias=False),
        )

        self.atom_attention_encoder = AtomAttentionEncoder(
            atom_s=atom_s,
            token_s=token_s,
            atoms_per_window_queries=atoms_per_window_queries,
            atoms_per_window_keys=atoms_per_window_keys,
            atom_encoder_depth=atom_encoder_depth,
            atom_encoder_heads=atom_encoder_heads,
            structure_prediction=False,
            activation_checkpointing=activation_checkpointing,
        )

        self.res_type_encoding = nn.Linear(const.num_tokens, token_s, bias=False)
        self.msa_profile_encoding = nn.Linear(const.num_tokens + 1, token_s, bias=False)

        if add_method_conditioning:
            self.method_conditioning_init = nn.Embedding(
                const.num_method_types, token_s
            )
            self.method_conditioning_init.weight.data.fill_(0)
        if add_modified_flag:
            self.modified_conditioning_init = nn.Embedding(2, token_s)
            self.modified_conditioning_init.weight.data.fill_(0)
        if add_cyclic_flag:
            self.cyclic_conditioning_init = nn.Linear(1, token_s, bias=False)
            self.cyclic_conditioning_init.weight.data.fill_(0)
        if add_mol_type_feat:
            self.mol_type_conditioning_init = nn.Embedding(
                len(const.chain_type_ids), token_s
            )
            self.mol_type_conditioning_init.weight.data.fill_(0)

    def forward(self, feats: dict[str, Tensor], affinity: bool = False) -> Tensor:
        """Perform the forward pass.

        Parameters
        ----------
        feats : dict[str, Tensor]
            Input features
        affinity : bool
            Whether to use the affinity-specific MSA profile features.

        Returns
        -------
        Tensor
            The embedded tokens.

        """
        # Load relevant features
        res_type = feats["res_type"].float()
        if affinity:
            profile = feats["profile_affinity"]
            deletion_mean = feats["deletion_mean_affinity"].unsqueeze(-1)
        else:
            profile = feats["profile"]
            deletion_mean = feats["deletion_mean"].unsqueeze(-1)

        # Compute input embedding
        q, c, p, to_keys = self.atom_encoder(feats)
        atom_enc_bias = self.atom_enc_proj_z(p)
        a, _, _, _ = self.atom_attention_encoder(
            feats=feats,
            q=q,
            c=c,
            atom_enc_bias=atom_enc_bias,
            to_keys=to_keys,
        )

        s = (
            a
            + self.res_type_encoding(res_type)
            + self.msa_profile_encoding(torch.cat([profile, deletion_mean], dim=-1))
        )

        if self.add_method_conditioning:
            s = s + self.method_conditioning_init(feats["method_feature"])
        if self.add_modified_flag:
            s = s + self.modified_conditioning_init(feats["modified"])
        if self.add_cyclic_flag:
            cyclic = feats["cyclic_period"].clamp(max=1.0).unsqueeze(-1)
            s = s + self.cyclic_conditioning_init(cyclic)
        if self.add_mol_type_feat:
            s = s + self.mol_type_conditioning_init(feats["mol_type"])

        return s


class TemplateModule(nn.Module):
    """Template module."""

    def __init__(
        self,
        token_z: int,
        template_dim: int,
        template_blocks: int,
        dropout: float = 0.25,
        pairwise_head_width: int = 32,
        pairwise_num_heads: int = 4,
        post_layer_norm: bool = False,
        activation_checkpointing: bool = False,
        min_dist: float = 3.25,
        max_dist: float = 50.75,
        num_bins: int = 38,
        **kwargs,
    ) -> None:
        """Initialize the template module.

        Parameters
        ----------
        token_z : int
            The token pairwise embedding size.

        """
        super().__init__()
        self.min_dist = min_dist
        self.max_dist = max_dist
        self.num_bins = num_bins
        self.relu = nn.ReLU()
        self.z_norm = nn.LayerNorm(token_z)
        self.v_norm = nn.LayerNorm(template_dim)
        self.z_proj = nn.Linear(token_z, template_dim, bias=False)
        self.a_proj = nn.Linear(
            const.num_tokens * 2 + num_bins + 5,
            template_dim,
            bias=False,
        )
        self.u_proj = nn.Linear(template_dim, token_z, bias=False)
        self.pairformer = PairformerNoSeqModule(
            template_dim,
            num_blocks=template_blocks,
            dropout=dropout,
            pairwise_head_width=pairwise_head_width,
            pairwise_num_heads=pairwise_num_heads,
            post_layer_norm=post_layer_norm,
            activation_checkpointing=activation_checkpointing,
        )

    def forward(
        self,
        z: Tensor,
        feats: dict[str, Tensor],
        pair_mask: Tensor,
        use_kernels: bool = False,
    ) -> Tensor:
        """Perform the forward pass.

        Parameters
        ----------
        z : Tensor
            The pairwise embeddings
        feats : dict[str, Tensor]
            Input features
        pair_mask : Tensor
            The pair mask

        Returns
        -------
        Tensor
            The updated pairwise embeddings.

        """
        # Load relevant features
        asym_id = feats["asym_id"]
        res_type = feats["template_restype"]
        frame_rot = feats["template_frame_rot"]
        frame_t = feats["template_frame_t"]
        frame_mask = feats["template_mask_frame"]
        cb_coords = feats["template_cb"]
        ca_coords = feats["template_ca"]
        cb_mask = feats["template_mask_cb"]
        template_mask = feats["template_mask"].any(dim=2).float()
        num_templates = template_mask.sum(dim=1)
        num_templates = num_templates.clamp(min=1)

        # Compute pairwise masks
        b_cb_mask = cb_mask[:, :, :, None] * cb_mask[:, :, None, :]
        b_frame_mask = frame_mask[:, :, :, None] * frame_mask[:, :, None, :]

        b_cb_mask = b_cb_mask[..., None]
        b_frame_mask = b_frame_mask[..., None]

        # Compute asym mask: template features only attend within the same chain
        B, T = res_type.shape[:2]  # noqa: N806
        asym_mask = (asym_id[:, :, None] == asym_id[:, None, :]).float()
        asym_mask = asym_mask[:, None].expand(-1, T, -1, -1)

        # Compute template features
        with torch.autocast(device_type="cuda", enabled=False):
            # Compute distogram
            cb_dists = torch.cdist(cb_coords, cb_coords)
            boundaries = torch.linspace(self.min_dist, self.max_dist, self.num_bins - 1)
            boundaries = boundaries.to(cb_dists.device)
            distogram = (cb_dists[..., None] > boundaries).sum(dim=-1).long()
            distogram = one_hot(distogram, num_classes=self.num_bins)

            # Compute unit vector in each frame
            frame_rot = frame_rot.unsqueeze(2).transpose(-1, -2)
            frame_t = frame_t.unsqueeze(2).unsqueeze(-1)
            ca_coords = ca_coords.unsqueeze(3).unsqueeze(-1)
            vector = torch.matmul(frame_rot, (ca_coords - frame_t))
            norm = torch.norm(vector, dim=-1, keepdim=True)
            unit_vector = torch.where(norm > 0, vector / norm, torch.zeros_like(vector))
            unit_vector = unit_vector.squeeze(-1)

            # Concatenate input features
            a_tij = [distogram, b_cb_mask, unit_vector, b_frame_mask]
            a_tij = torch.cat(a_tij, dim=-1)
            a_tij = a_tij * asym_mask.unsqueeze(-1)

            res_type_i = res_type[:, :, :, None]
            res_type_j = res_type[:, :, None, :]
            res_type_i = res_type_i.expand(-1, -1, -1, res_type.size(2), -1)
            res_type_j = res_type_j.expand(-1, -1, res_type.size(2), -1, -1)
            a_tij = torch.cat([a_tij, res_type_i, res_type_j], dim=-1)
            a_tij = self.a_proj(a_tij)

        # Expand mask
        pair_mask = pair_mask[:, None].expand(-1, T, -1, -1)
        pair_mask = pair_mask.reshape(B * T, *pair_mask.shape[2:])

        # Compute input projections
        v = self.z_proj(self.z_norm(z[:, None])) + a_tij
        v = v.view(B * T, *v.shape[2:])
        v = v + self.pairformer(v, pair_mask, use_kernels=use_kernels)
        v = self.v_norm(v)
        v = v.view(B, T, *v.shape[1:])

        # Aggregate templates
        template_mask = template_mask[:, :, None, None, None]
        num_templates = num_templates[:, None, None, None]
        u = (v * template_mask).sum(dim=1) / num_templates.to(v)

        # Compute output projection
        u = self.u_proj(self.relu(u))
        return u


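# A minimal sketch of the distance binning above: with the default 38 bins,
# distances below min_dist fall in bin 0 and distances beyond max_dist in the
# last bin (values here are illustrative).
import torch
from torch.nn.functional import one_hot

min_dist, max_dist, num_bins = 3.25, 50.75, 38
boundaries = torch.linspace(min_dist, max_dist, num_bins - 1)
dists = torch.tensor([2.0, 10.0, 60.0])
bins = (dists[..., None] > boundaries).sum(dim=-1).long()  # tensor([0, 6, 37])
distogram = one_hot(bins, num_classes=num_bins)            # shape (3, 38)

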
class TemplateV2Module(nn.Module):
    """Template module (v2)."""

    def __init__(
        self,
        token_z: int,
        template_dim: int,
        template_blocks: int,
        dropout: float = 0.25,
        pairwise_head_width: int = 32,
        pairwise_num_heads: int = 4,
        post_layer_norm: bool = False,
        activation_checkpointing: bool = False,
        min_dist: float = 3.25,
        max_dist: float = 50.75,
        num_bins: int = 38,
        **kwargs,
    ) -> None:
        """Initialize the template module.

        Parameters
        ----------
        token_z : int
            The token pairwise embedding size.

        """
        super().__init__()
        self.min_dist = min_dist
        self.max_dist = max_dist
        self.num_bins = num_bins
        self.relu = nn.ReLU()
        self.z_norm = nn.LayerNorm(token_z)
        self.v_norm = nn.LayerNorm(template_dim)
        self.z_proj = nn.Linear(token_z, template_dim, bias=False)
        self.a_proj = nn.Linear(
            const.num_tokens * 2 + num_bins + 5,
            template_dim,
            bias=False,
        )
        self.u_proj = nn.Linear(template_dim, token_z, bias=False)
        self.pairformer = PairformerNoSeqModule(
            template_dim,
            num_blocks=template_blocks,
            dropout=dropout,
            pairwise_head_width=pairwise_head_width,
            pairwise_num_heads=pairwise_num_heads,
            post_layer_norm=post_layer_norm,
            activation_checkpointing=activation_checkpointing,
        )

    def forward(
        self,
        z: Tensor,
        feats: dict[str, Tensor],
        pair_mask: Tensor,
        use_kernels: bool = False,
    ) -> Tensor:
        """Perform the forward pass.

        Parameters
        ----------
        z : Tensor
            The pairwise embeddings
        feats : dict[str, Tensor]
            Input features
        pair_mask : Tensor
            The pair mask

        Returns
        -------
        Tensor
            The updated pairwise embeddings.

        """
        # Load relevant features
        res_type = feats["template_restype"]
        frame_rot = feats["template_frame_rot"]
        frame_t = feats["template_frame_t"]
        frame_mask = feats["template_mask_frame"]
        cb_coords = feats["template_cb"]
        ca_coords = feats["template_ca"]
        cb_mask = feats["template_mask_cb"]
        visibility_ids = feats["visibility_ids"]
        template_mask = feats["template_mask"].any(dim=2).float()
        num_templates = template_mask.sum(dim=1)
        num_templates = num_templates.clamp(min=1)

        # Compute pairwise masks
        b_cb_mask = cb_mask[:, :, :, None] * cb_mask[:, :, None, :]
        b_frame_mask = frame_mask[:, :, :, None] * frame_mask[:, :, None, :]

        b_cb_mask = b_cb_mask[..., None]
        b_frame_mask = b_frame_mask[..., None]

        # Compute visibility mask: template features only attend within the
        # same visibility group
        B, T = res_type.shape[:2]  # noqa: N806
        tmlp_pair_mask = (
            visibility_ids[:, :, :, None] == visibility_ids[:, :, None, :]
        ).float()

        # Compute template features
        with torch.autocast(device_type="cuda", enabled=False):
            # Compute distogram
            cb_dists = torch.cdist(cb_coords, cb_coords)
            boundaries = torch.linspace(self.min_dist, self.max_dist, self.num_bins - 1)
            boundaries = boundaries.to(cb_dists.device)
            distogram = (cb_dists[..., None] > boundaries).sum(dim=-1).long()
            distogram = one_hot(distogram, num_classes=self.num_bins)

            # Compute unit vector in each frame
            frame_rot = frame_rot.unsqueeze(2).transpose(-1, -2)
            frame_t = frame_t.unsqueeze(2).unsqueeze(-1)
            ca_coords = ca_coords.unsqueeze(3).unsqueeze(-1)
            vector = torch.matmul(frame_rot, (ca_coords - frame_t))
            norm = torch.norm(vector, dim=-1, keepdim=True)
            unit_vector = torch.where(norm > 0, vector / norm, torch.zeros_like(vector))
            unit_vector = unit_vector.squeeze(-1)

            # Concatenate input features
            a_tij = [distogram, b_cb_mask, unit_vector, b_frame_mask]
            a_tij = torch.cat(a_tij, dim=-1)
            a_tij = a_tij * tmlp_pair_mask.unsqueeze(-1)

            res_type_i = res_type[:, :, :, None]
            res_type_j = res_type[:, :, None, :]
            res_type_i = res_type_i.expand(-1, -1, -1, res_type.size(2), -1)
            res_type_j = res_type_j.expand(-1, -1, res_type.size(2), -1, -1)
            a_tij = torch.cat([a_tij, res_type_i, res_type_j], dim=-1)
            a_tij = self.a_proj(a_tij)

        # Expand mask
        pair_mask = pair_mask[:, None].expand(-1, T, -1, -1)
        pair_mask = pair_mask.reshape(B * T, *pair_mask.shape[2:])

        # Compute input projections
        v = self.z_proj(self.z_norm(z[:, None])) + a_tij
        v = v.view(B * T, *v.shape[2:])
        v = v + self.pairformer(v, pair_mask, use_kernels=use_kernels)
        v = self.v_norm(v)
        v = v.view(B, T, *v.shape[1:])

        # Aggregate templates
        template_mask = template_mask[:, :, None, None, None]
        num_templates = num_templates[:, None, None, None]
        u = (v * template_mask).sum(dim=1) / num_templates.to(v)

        # Compute output projection
        u = self.u_proj(self.relu(u))
        return u


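# A simplified sketch of the per-frame unit-vector computation above: rotate
# CA offsets into each local frame and normalize, guarding zero-length
# vectors. The pairwise (i, j) structure is dropped here for clarity.
import torch

num_res = 4
frame_rot = torch.eye(3).expand(num_res, 3, 3)   # per-residue rotations
frame_t = torch.zeros(num_res, 3)                # per-residue translations
ca = torch.randn(num_res, 3)
vector = torch.einsum("nij,nj->ni", frame_rot.transpose(-1, -2), ca - frame_t)
norm = vector.norm(dim=-1, keepdim=True)
unit = torch.where(norm > 0, vector / norm, torch.zeros_like(vector))

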
class MSAModule(nn.Module):
    """MSA module."""

    def __init__(
        self,
        msa_s: int,
        token_z: int,
        token_s: int,
        msa_blocks: int,
        msa_dropout: float,
        z_dropout: float,
        pairwise_head_width: int = 32,
        pairwise_num_heads: int = 4,
        activation_checkpointing: bool = False,
        use_paired_feature: bool = True,
        subsample_msa: bool = False,
        num_subsampled_msa: int = 1024,
        **kwargs,
    ) -> None:
        """Initialize the MSA module.

        Parameters
        ----------
        token_z : int
            The token pairwise embedding size.

        """
        super().__init__()
        self.msa_blocks = msa_blocks
        self.msa_dropout = msa_dropout
        self.z_dropout = z_dropout
        self.use_paired_feature = use_paired_feature
        self.activation_checkpointing = activation_checkpointing
        self.subsample_msa = subsample_msa
        self.num_subsampled_msa = num_subsampled_msa

        self.s_proj = nn.Linear(token_s, msa_s, bias=False)
        self.msa_proj = nn.Linear(
            const.num_tokens + 2 + int(use_paired_feature),
            msa_s,
            bias=False,
        )
        self.layers = nn.ModuleList()
        for _ in range(msa_blocks):
            self.layers.append(
                MSALayer(
                    msa_s,
                    token_z,
                    msa_dropout,
                    z_dropout,
                    pairwise_head_width,
                    pairwise_num_heads,
                )
            )

    def forward(
        self,
        z: Tensor,
        emb: Tensor,
        feats: dict[str, Tensor],
        use_kernels: bool = False,
    ) -> Tensor:
        """Perform the forward pass.

        Parameters
        ----------
        z : Tensor
            The pairwise embeddings
        emb : Tensor
            The input embeddings
        feats : dict[str, Tensor]
            Input features
        use_kernels : bool
            Whether to use kernels for triangular updates

        Returns
        -------
        Tensor
            The output pairwise embeddings.

        """
        # Set chunk sizes
        if not self.training:
            if z.shape[1] > const.chunk_size_threshold:
                chunk_heads_pwa = True
                chunk_size_transition_z = 64
                chunk_size_transition_msa = 32
                chunk_size_outer_product = 4
                chunk_size_tri_attn = 128
            else:
                chunk_heads_pwa = False
                chunk_size_transition_z = None
                chunk_size_transition_msa = None
                chunk_size_outer_product = None
                chunk_size_tri_attn = 512
        else:
            chunk_heads_pwa = False
            chunk_size_transition_z = None
            chunk_size_transition_msa = None
            chunk_size_outer_product = None
            chunk_size_tri_attn = None

        # Load relevant features
        msa = feats["msa"]
        msa = torch.nn.functional.one_hot(msa, num_classes=const.num_tokens)
        has_deletion = feats["has_deletion"].unsqueeze(-1)
        deletion_value = feats["deletion_value"].unsqueeze(-1)
        is_paired = feats["msa_paired"].unsqueeze(-1)
        msa_mask = feats["msa_mask"]
        token_mask = feats["token_pad_mask"].float()
        token_mask = token_mask[:, :, None] * token_mask[:, None, :]

        # Compute MSA embeddings
        if self.use_paired_feature:
            m = torch.cat([msa, has_deletion, deletion_value, is_paired], dim=-1)
        else:
            m = torch.cat([msa, has_deletion, deletion_value], dim=-1)

        # Subsample the MSA
        if self.subsample_msa:
            msa_indices = torch.randperm(msa.shape[1])[: self.num_subsampled_msa]
            m = m[:, msa_indices]
            msa_mask = msa_mask[:, msa_indices]

        # Compute input projections
        m = self.msa_proj(m)
        m = m + self.s_proj(emb).unsqueeze(1)

        # Perform MSA blocks
        for i in range(self.msa_blocks):
            if self.activation_checkpointing and self.training:
                z, m = torch.utils.checkpoint.checkpoint(
                    self.layers[i],
                    z,
                    m,
                    token_mask,
                    msa_mask,
                    chunk_heads_pwa,
                    chunk_size_transition_z,
                    chunk_size_transition_msa,
                    chunk_size_outer_product,
                    chunk_size_tri_attn,
                    use_kernels,
                )
            else:
                z, m = self.layers[i](
                    z,
                    m,
                    token_mask,
                    msa_mask,
                    chunk_heads_pwa,
                    chunk_size_transition_z,
                    chunk_size_transition_msa,
                    chunk_size_outer_product,
                    chunk_size_tri_attn,
                    use_kernels,
                )
        return z


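# A minimal sketch of the per-position MSA features assembled above
# (illustrative sizes; 33 is a stand-in for const.num_tokens): one-hot
# residue type, a deletion flag, a deletion value, and the paired flag.
import torch
from torch.nn.functional import one_hot

num_tokens = 33                                   # assumed vocabulary size
msa = torch.randint(0, num_tokens, (1, 8, 16))    # (batch, sequences, tokens)
m = torch.cat(
    [
        one_hot(msa, num_classes=num_tokens).float(),
        torch.zeros(1, 8, 16, 1),                 # has_deletion
        torch.zeros(1, 8, 16, 1),                 # deletion_value
        torch.ones(1, 8, 16, 1),                  # msa_paired
    ],
    dim=-1,
)                                                 # (1, 8, 16, num_tokens + 3)

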
class MSALayer(nn.Module):
    """MSA layer."""

    def __init__(
        self,
        msa_s: int,
        token_z: int,
        msa_dropout: float,
        z_dropout: float,
        pairwise_head_width: int = 32,
        pairwise_num_heads: int = 4,
    ) -> None:
        """Initialize the MSA layer.

        Parameters
        ----------
        token_z : int
            The token pairwise embedding size.

        """
        super().__init__()
        self.msa_dropout = msa_dropout
        self.msa_transition = Transition(dim=msa_s, hidden=msa_s * 4)
        self.pair_weighted_averaging = PairWeightedAveraging(
            c_m=msa_s,
            c_z=token_z,
            c_h=32,
            num_heads=8,
        )

        self.pairformer_layer = PairformerNoSeqLayer(
            token_z=token_z,
            dropout=z_dropout,
            pairwise_head_width=pairwise_head_width,
            pairwise_num_heads=pairwise_num_heads,
        )
        self.outer_product_mean = OuterProductMean(
            c_in=msa_s,
            c_hidden=32,
            c_out=token_z,
        )

    def forward(
        self,
        z: Tensor,
        m: Tensor,
        token_mask: Tensor,
        msa_mask: Tensor,
        chunk_heads_pwa: bool = False,
        chunk_size_transition_z: Optional[int] = None,
        chunk_size_transition_msa: Optional[int] = None,
        chunk_size_outer_product: Optional[int] = None,
        chunk_size_tri_attn: Optional[int] = None,
        use_kernels: bool = False,
    ) -> tuple[Tensor, Tensor]:
        """Perform the forward pass.

        Parameters
        ----------
        z : Tensor
            The pairwise embeddings
        m : Tensor
            The MSA embeddings
        token_mask : Tensor
            The token pair mask
        msa_mask : Tensor
            The MSA mask

        Returns
        -------
        tuple[Tensor, Tensor]
            The updated pairwise and MSA embeddings.

        """
        # Communication to MSA stack
        msa_dropout = get_dropout_mask(self.msa_dropout, m, self.training)
        m = m + msa_dropout * self.pair_weighted_averaging(
            m, z, token_mask, chunk_heads_pwa
        )
        m = m + self.msa_transition(m, chunk_size_transition_msa)

        z = z + self.outer_product_mean(m, msa_mask, chunk_size_outer_product)

        # Compute pairwise stack
        z = self.pairformer_layer(
            z, token_mask, chunk_size_tri_attn, use_kernels=use_kernels
        )

        return z, m


class BFactorModule(nn.Module):
    """BFactor Module."""

    def __init__(self, token_s: int, num_bins: int) -> None:
        """Initialize the bfactor module.

        Parameters
        ----------
        token_s : int
            The token embedding size.

        """
        super().__init__()
        self.bfactor = nn.Linear(token_s, num_bins)
        self.num_bins = num_bins

    def forward(self, s: Tensor) -> Tensor:
        """Perform the forward pass.

        Parameters
        ----------
        s : Tensor
            The sequence embeddings

        Returns
        -------
        Tensor
            The predicted bfactor histogram.

        """
        return self.bfactor(s)


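# A minimal usage sketch (illustrative sizes): per-token B-factor bin logits.
import torch

bfactor_head = BFactorModule(token_s=64, num_bins=16)
s = torch.randn(1, 32, 64)     # (batch, tokens, token_s)
logits = bfactor_head(s)       # (1, 32, 16)

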
class DistogramModule(nn.Module):
    """Distogram Module."""

    def __init__(self, token_z: int, num_bins: int, num_distograms: int = 1) -> None:
        """Initialize the distogram module.

        Parameters
        ----------
        token_z : int
            The token pairwise embedding size.

        """
        super().__init__()
        self.distogram = nn.Linear(token_z, num_distograms * num_bins)
        self.num_distograms = num_distograms
        self.num_bins = num_bins

    def forward(self, z: Tensor) -> Tensor:
        """Perform the forward pass.

        Parameters
        ----------
        z : Tensor
            The pairwise embeddings

        Returns
        -------
        Tensor
            The predicted distogram.

        """
        z = z + z.transpose(1, 2)
        return self.distogram(z).reshape(
            z.shape[0], z.shape[1], z.shape[2], self.num_distograms, self.num_bins
        )
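

# A minimal usage sketch (illustrative sizes): the pair representation is
# symmetrized as z + z^T before projection, so the predicted distogram is
# symmetric in its two token indices.
import torch

head = DistogramModule(token_z=128, num_bins=64, num_distograms=1)
z = torch.randn(1, 48, 48, 128)
logits = head(z)               # (1, 48, 48, 1, 64)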