boltz_vsynthes-1.0.0-py3-none-any.whl

This diff represents the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (112)
  1. boltz/__init__.py +7 -0
  2. boltz/data/__init__.py +0 -0
  3. boltz/data/const.py +1184 -0
  4. boltz/data/crop/__init__.py +0 -0
  5. boltz/data/crop/affinity.py +164 -0
  6. boltz/data/crop/boltz.py +296 -0
  7. boltz/data/crop/cropper.py +45 -0
  8. boltz/data/feature/__init__.py +0 -0
  9. boltz/data/feature/featurizer.py +1230 -0
  10. boltz/data/feature/featurizerv2.py +2208 -0
  11. boltz/data/feature/symmetry.py +602 -0
  12. boltz/data/filter/__init__.py +0 -0
  13. boltz/data/filter/dynamic/__init__.py +0 -0
  14. boltz/data/filter/dynamic/date.py +76 -0
  15. boltz/data/filter/dynamic/filter.py +24 -0
  16. boltz/data/filter/dynamic/max_residues.py +37 -0
  17. boltz/data/filter/dynamic/resolution.py +34 -0
  18. boltz/data/filter/dynamic/size.py +38 -0
  19. boltz/data/filter/dynamic/subset.py +42 -0
  20. boltz/data/filter/static/__init__.py +0 -0
  21. boltz/data/filter/static/filter.py +26 -0
  22. boltz/data/filter/static/ligand.py +37 -0
  23. boltz/data/filter/static/polymer.py +299 -0
  24. boltz/data/module/__init__.py +0 -0
  25. boltz/data/module/inference.py +307 -0
  26. boltz/data/module/inferencev2.py +429 -0
  27. boltz/data/module/training.py +684 -0
  28. boltz/data/module/trainingv2.py +660 -0
  29. boltz/data/mol.py +900 -0
  30. boltz/data/msa/__init__.py +0 -0
  31. boltz/data/msa/mmseqs2.py +235 -0
  32. boltz/data/pad.py +84 -0
  33. boltz/data/parse/__init__.py +0 -0
  34. boltz/data/parse/a3m.py +134 -0
  35. boltz/data/parse/csv.py +100 -0
  36. boltz/data/parse/fasta.py +138 -0
  37. boltz/data/parse/mmcif.py +1239 -0
  38. boltz/data/parse/mmcif_with_constraints.py +1607 -0
  39. boltz/data/parse/schema.py +1851 -0
  40. boltz/data/parse/yaml.py +68 -0
  41. boltz/data/sample/__init__.py +0 -0
  42. boltz/data/sample/cluster.py +283 -0
  43. boltz/data/sample/distillation.py +57 -0
  44. boltz/data/sample/random.py +39 -0
  45. boltz/data/sample/sampler.py +49 -0
  46. boltz/data/tokenize/__init__.py +0 -0
  47. boltz/data/tokenize/boltz.py +195 -0
  48. boltz/data/tokenize/boltz2.py +396 -0
  49. boltz/data/tokenize/tokenizer.py +24 -0
  50. boltz/data/types.py +777 -0
  51. boltz/data/write/__init__.py +0 -0
  52. boltz/data/write/mmcif.py +305 -0
  53. boltz/data/write/pdb.py +171 -0
  54. boltz/data/write/utils.py +23 -0
  55. boltz/data/write/writer.py +330 -0
  56. boltz/main.py +1292 -0
  57. boltz/model/__init__.py +0 -0
  58. boltz/model/layers/__init__.py +0 -0
  59. boltz/model/layers/attention.py +132 -0
  60. boltz/model/layers/attentionv2.py +111 -0
  61. boltz/model/layers/confidence_utils.py +231 -0
  62. boltz/model/layers/dropout.py +34 -0
  63. boltz/model/layers/initialize.py +100 -0
  64. boltz/model/layers/outer_product_mean.py +98 -0
  65. boltz/model/layers/pair_averaging.py +135 -0
  66. boltz/model/layers/pairformer.py +337 -0
  67. boltz/model/layers/relative.py +58 -0
  68. boltz/model/layers/transition.py +78 -0
  69. boltz/model/layers/triangular_attention/__init__.py +0 -0
  70. boltz/model/layers/triangular_attention/attention.py +189 -0
  71. boltz/model/layers/triangular_attention/primitives.py +409 -0
  72. boltz/model/layers/triangular_attention/utils.py +380 -0
  73. boltz/model/layers/triangular_mult.py +212 -0
  74. boltz/model/loss/__init__.py +0 -0
  75. boltz/model/loss/bfactor.py +49 -0
  76. boltz/model/loss/confidence.py +590 -0
  77. boltz/model/loss/confidencev2.py +621 -0
  78. boltz/model/loss/diffusion.py +171 -0
  79. boltz/model/loss/diffusionv2.py +134 -0
  80. boltz/model/loss/distogram.py +48 -0
  81. boltz/model/loss/distogramv2.py +105 -0
  82. boltz/model/loss/validation.py +1025 -0
  83. boltz/model/models/__init__.py +0 -0
  84. boltz/model/models/boltz1.py +1286 -0
  85. boltz/model/models/boltz2.py +1249 -0
  86. boltz/model/modules/__init__.py +0 -0
  87. boltz/model/modules/affinity.py +223 -0
  88. boltz/model/modules/confidence.py +481 -0
  89. boltz/model/modules/confidence_utils.py +181 -0
  90. boltz/model/modules/confidencev2.py +495 -0
  91. boltz/model/modules/diffusion.py +844 -0
  92. boltz/model/modules/diffusion_conditioning.py +116 -0
  93. boltz/model/modules/diffusionv2.py +677 -0
  94. boltz/model/modules/encoders.py +639 -0
  95. boltz/model/modules/encodersv2.py +565 -0
  96. boltz/model/modules/transformers.py +322 -0
  97. boltz/model/modules/transformersv2.py +261 -0
  98. boltz/model/modules/trunk.py +688 -0
  99. boltz/model/modules/trunkv2.py +828 -0
  100. boltz/model/modules/utils.py +303 -0
  101. boltz/model/optim/__init__.py +0 -0
  102. boltz/model/optim/ema.py +389 -0
  103. boltz/model/optim/scheduler.py +99 -0
  104. boltz/model/potentials/__init__.py +0 -0
  105. boltz/model/potentials/potentials.py +497 -0
  106. boltz/model/potentials/schedules.py +32 -0
  107. boltz_vsynthes-1.0.0.dist-info/METADATA +151 -0
  108. boltz_vsynthes-1.0.0.dist-info/RECORD +112 -0
  109. boltz_vsynthes-1.0.0.dist-info/WHEEL +5 -0
  110. boltz_vsynthes-1.0.0.dist-info/entry_points.txt +2 -0
  111. boltz_vsynthes-1.0.0.dist-info/licenses/LICENSE +21 -0
  112. boltz_vsynthes-1.0.0.dist-info/top_level.txt +1 -0
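
For reference, the same file inventory can be reproduced locally from the downloaded wheel with the standard library alone. The sketch below is illustrative only; the wheel filename is taken from the header above and is assumed to sit in the current directory:

import zipfile

# A wheel is a zip archive: list every file it ships, mirroring the table above.
with zipfile.ZipFile("boltz_vsynthes-1.0.0-py3-none-any.whl") as whl:
    for name in whl.namelist():
        print(name)

The single hunk that follows shows the content of one of the added files; its +688 line count matches boltz/model/modules/trunk.py in the listing above.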
@@ -0,0 +1,688 @@
+from typing import Optional
+
+import torch
+from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper
+from torch import Tensor, nn
+
+from boltz.data import const
+from boltz.model.layers.attention import AttentionPairBias
+from boltz.model.layers.dropout import get_dropout_mask
+from boltz.model.layers.outer_product_mean import OuterProductMean
+from boltz.model.layers.pair_averaging import PairWeightedAveraging
+from boltz.model.layers.transition import Transition
+from boltz.model.layers.triangular_attention.attention import (
+    TriangleAttentionEndingNode,
+    TriangleAttentionStartingNode,
+)
+from boltz.model.layers.triangular_mult import (
+    TriangleMultiplicationIncoming,
+    TriangleMultiplicationOutgoing,
+)
+from boltz.model.modules.encoders import AtomAttentionEncoder
+
+
+class InputEmbedder(nn.Module):
+    """Input embedder."""
+
+    def __init__(
+        self,
+        atom_s: int,
+        atom_z: int,
+        token_s: int,
+        token_z: int,
+        atoms_per_window_queries: int,
+        atoms_per_window_keys: int,
+        atom_feature_dim: int,
+        atom_encoder_depth: int,
+        atom_encoder_heads: int,
+        no_atom_encoder: bool = False,
+    ) -> None:
+        """Initialize the input embedder.
+
+        Parameters
+        ----------
+        atom_s : int
+            The atom single representation dimension.
+        atom_z : int
+            The atom pair representation dimension.
+        token_s : int
+            The single token representation dimension.
+        token_z : int
+            The pair token representation dimension.
+        atoms_per_window_queries : int
+            The number of atoms per window for queries.
+        atoms_per_window_keys : int
+            The number of atoms per window for keys.
+        atom_feature_dim : int
+            The atom feature dimension.
+        atom_encoder_depth : int
+            The atom encoder depth.
+        atom_encoder_heads : int
+            The atom encoder heads.
+        no_atom_encoder : bool, optional
+            Whether to skip the atom encoder, by default False
+
+        """
+        super().__init__()
+        self.token_s = token_s
+        self.no_atom_encoder = no_atom_encoder
+
+        if not no_atom_encoder:
+            self.atom_attention_encoder = AtomAttentionEncoder(
+                atom_s=atom_s,
+                atom_z=atom_z,
+                token_s=token_s,
+                token_z=token_z,
+                atoms_per_window_queries=atoms_per_window_queries,
+                atoms_per_window_keys=atoms_per_window_keys,
+                atom_feature_dim=atom_feature_dim,
+                atom_encoder_depth=atom_encoder_depth,
+                atom_encoder_heads=atom_encoder_heads,
+                structure_prediction=False,
+            )
+
+    def forward(self, feats: dict[str, Tensor]) -> Tensor:
+        """Perform the forward pass.
+
+        Parameters
+        ----------
+        feats : dict[str, Tensor]
+            Input features
+
+        Returns
+        -------
+        Tensor
+            The embedded tokens.
+
+        """
+        # Load relevant features
+        res_type = feats["res_type"]
+        profile = feats["profile"]
+        deletion_mean = feats["deletion_mean"].unsqueeze(-1)
+        pocket_feature = feats["pocket_feature"]
+
+        # Compute input embedding
+        if self.no_atom_encoder:
+            a = torch.zeros(
+                (res_type.shape[0], res_type.shape[1], self.token_s),
+                device=res_type.device,
+            )
+        else:
+            a, _, _, _, _ = self.atom_attention_encoder(feats)
+        s = torch.cat([a, res_type, profile, deletion_mean, pocket_feature], dim=-1)
+        return s
+
+
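The forward pass above simply concatenates the per-token features along the last dimension. A minimal, self-contained sketch of that bookkeeping (all sizes here are illustrative assumptions, not values from the package):

import torch

batch, n_tokens = 1, 16
token_s, n_res_types, pocket_dim = 384, 33, 4  # assumed sizes for illustration

a = torch.zeros(batch, n_tokens, token_s)                # atom-encoder output (zeros here for simplicity)
res_type = torch.zeros(batch, n_tokens, n_res_types)     # residue-type encoding
profile = torch.zeros(batch, n_tokens, n_res_types)      # MSA profile
deletion_mean = torch.zeros(batch, n_tokens, 1)          # scalar feature after unsqueeze(-1)
pocket_feature = torch.zeros(batch, n_tokens, pocket_dim)

s = torch.cat([a, res_type, profile, deletion_mean, pocket_feature], dim=-1)
print(s.shape)  # torch.Size([1, 16, 455]) with the assumed sizes above
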
+class MSAModule(nn.Module):
+    """MSA module."""
+
+    def __init__(
+        self,
+        msa_s: int,
+        token_z: int,
+        s_input_dim: int,
+        msa_blocks: int,
+        msa_dropout: float,
+        z_dropout: float,
+        pairwise_head_width: int = 32,
+        pairwise_num_heads: int = 4,
+        activation_checkpointing: bool = False,
+        use_paired_feature: bool = False,
+        offload_to_cpu: bool = False,
+        subsample_msa: bool = False,
+        num_subsampled_msa: int = 1024,
+        **kwargs,
+    ) -> None:
+        """Initialize the MSA module.
+
+        Parameters
+        ----------
+        msa_s : int
+            The MSA embedding size.
+        token_z : int
+            The token pairwise embedding size.
+        s_input_dim : int
+            The input sequence dimension.
+        msa_blocks : int
+            The number of MSA blocks.
+        msa_dropout : float
+            The MSA dropout.
+        z_dropout : float
+            The pairwise dropout.
+        pairwise_head_width : int, optional
+            The pairwise head width, by default 32
+        pairwise_num_heads : int, optional
+            The number of pairwise heads, by default 4
+        activation_checkpointing : bool, optional
+            Whether to use activation checkpointing, by default False
+        use_paired_feature : bool, optional
+            Whether to use the paired feature, by default False
+        offload_to_cpu : bool, optional
+            Whether to offload to CPU, by default False
+        subsample_msa : bool, optional
+            Whether to subsample the MSA rows, by default False
+        num_subsampled_msa : int, optional
+            The number of MSA rows kept when subsampling, by default 1024
+
+        """
+        super().__init__()
+        self.msa_blocks = msa_blocks
+        self.msa_dropout = msa_dropout
+        self.z_dropout = z_dropout
+        self.use_paired_feature = use_paired_feature
+        self.subsample_msa = subsample_msa
+        self.num_subsampled_msa = num_subsampled_msa
+
+        self.s_proj = nn.Linear(s_input_dim, msa_s, bias=False)
+        self.msa_proj = nn.Linear(
+            const.num_tokens + 2 + int(use_paired_feature),
+            msa_s,
+            bias=False,
+        )
+        self.layers = nn.ModuleList()
+        for i in range(msa_blocks):
+            if activation_checkpointing:
+                self.layers.append(
+                    checkpoint_wrapper(
+                        MSALayer(
+                            msa_s,
+                            token_z,
+                            msa_dropout,
+                            z_dropout,
+                            pairwise_head_width,
+                            pairwise_num_heads,
+                        ),
+                        offload_to_cpu=offload_to_cpu,
+                    )
+                )
+            else:
+                self.layers.append(
+                    MSALayer(
+                        msa_s,
+                        token_z,
+                        msa_dropout,
+                        z_dropout,
+                        pairwise_head_width,
+                        pairwise_num_heads,
+                    )
+                )
+
+    def forward(
+        self,
+        z: Tensor,
+        emb: Tensor,
+        feats: dict[str, Tensor],
+        use_kernels: bool = False,
+    ) -> Tensor:
+        """Perform the forward pass.
+
+        Parameters
+        ----------
+        z : Tensor
+            The pairwise embeddings
+        emb : Tensor
+            The input embeddings
+        feats : dict[str, Tensor]
+            Input features
+
+        Returns
+        -------
+        Tensor
+            The output pairwise embeddings.
+
+        """
+        # Set chunk sizes
+        if not self.training:
+            if z.shape[1] > const.chunk_size_threshold:
+                chunk_heads_pwa = True
+                chunk_size_transition_z = 64
+                chunk_size_transition_msa = 32
+                chunk_size_outer_product = 4
+                chunk_size_tri_attn = 128
+            else:
+                chunk_heads_pwa = False
+                chunk_size_transition_z = None
+                chunk_size_transition_msa = None
+                chunk_size_outer_product = None
+                chunk_size_tri_attn = 512
+        else:
+            chunk_heads_pwa = False
+            chunk_size_transition_z = None
+            chunk_size_transition_msa = None
+            chunk_size_outer_product = None
+            chunk_size_tri_attn = None
+
+        # Load relevant features
+        msa = feats["msa"]
+        has_deletion = feats["has_deletion"].unsqueeze(-1)
+        deletion_value = feats["deletion_value"].unsqueeze(-1)
+        is_paired = feats["msa_paired"].unsqueeze(-1)
+        msa_mask = feats["msa_mask"]
+        token_mask = feats["token_pad_mask"].float()
+        token_mask = token_mask[:, :, None] * token_mask[:, None, :]
+
+        # Compute MSA embeddings
+        if self.use_paired_feature:
+            m = torch.cat([msa, has_deletion, deletion_value, is_paired], dim=-1)
+        else:
+            m = torch.cat([msa, has_deletion, deletion_value], dim=-1)
+
+        if self.subsample_msa:
+            msa_indices = torch.randperm(m.shape[1])[: self.num_subsampled_msa]
+            m = m[:, msa_indices]
+            msa_mask = msa_mask[:, msa_indices]
+
+        # Compute input projections
+        m = self.msa_proj(m)
+        m = m + self.s_proj(emb).unsqueeze(1)
+
+        # Perform MSA blocks
+        for i in range(self.msa_blocks):
+            z, m = self.layers[i](
+                z,
+                m,
+                token_mask,
+                msa_mask,
+                chunk_heads_pwa,
+                chunk_size_transition_z,
+                chunk_size_transition_msa,
+                chunk_size_outer_product,
+                chunk_size_tri_attn,
+                use_kernels=use_kernels,
+            )
+        return z
+
+
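The subsample_msa branch above keeps a random subset of MSA rows and slices the row mask with the same indices so the two stay aligned. A standalone illustration of that step (tensor sizes are assumptions chosen for the example):

import torch

n_msa_rows, n_tokens, msa_feat = 4096, 64, 36   # assumed sizes
num_subsampled_msa = 1024

m = torch.randn(1, n_msa_rows, n_tokens, msa_feat)   # MSA features, batch first
msa_mask = torch.ones(1, n_msa_rows, n_tokens)

# Pick a random permutation of row indices and keep the first num_subsampled_msa.
msa_indices = torch.randperm(m.shape[1])[:num_subsampled_msa]
m = m[:, msa_indices]
msa_mask = msa_mask[:, msa_indices]

print(m.shape, msa_mask.shape)  # (1, 1024, 64, 36) (1, 1024, 64)
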
+class MSALayer(nn.Module):
+    """MSA layer."""
+
+    def __init__(
+        self,
+        msa_s: int,
+        token_z: int,
+        msa_dropout: float,
+        z_dropout: float,
+        pairwise_head_width: int = 32,
+        pairwise_num_heads: int = 4,
+    ) -> None:
+        """Initialize the MSA layer.
+
+        Parameters
+        ----------
+        msa_s : int
+            The MSA embedding size.
+        token_z : int
+            The pair representation dimension.
+        msa_dropout : float
+            The MSA dropout.
+        z_dropout : float
+            The pair dropout.
+        pairwise_head_width : int, optional
+            The pairwise head width, by default 32
+        pairwise_num_heads : int, optional
+            The number of pairwise heads, by default 4
+
+        """
+        super().__init__()
+        self.msa_dropout = msa_dropout
+        self.z_dropout = z_dropout
+        self.msa_transition = Transition(dim=msa_s, hidden=msa_s * 4)
+        self.pair_weighted_averaging = PairWeightedAveraging(
+            c_m=msa_s,
+            c_z=token_z,
+            c_h=32,
+            num_heads=8,
+        )
+
+        self.tri_mul_out = TriangleMultiplicationOutgoing(token_z)
+        self.tri_mul_in = TriangleMultiplicationIncoming(token_z)
+        self.tri_att_start = TriangleAttentionStartingNode(
+            token_z, pairwise_head_width, pairwise_num_heads, inf=1e9
+        )
+        self.tri_att_end = TriangleAttentionEndingNode(
+            token_z, pairwise_head_width, pairwise_num_heads, inf=1e9
+        )
+        self.z_transition = Transition(
+            dim=token_z,
+            hidden=token_z * 4,
+        )
+        self.outer_product_mean = OuterProductMean(
+            c_in=msa_s,
+            c_hidden=32,
+            c_out=token_z,
+        )
+
+    def forward(
+        self,
+        z: Tensor,
+        m: Tensor,
+        token_mask: Tensor,
+        msa_mask: Tensor,
+        chunk_heads_pwa: bool = False,
+        chunk_size_transition_z: Optional[int] = None,
+        chunk_size_transition_msa: Optional[int] = None,
+        chunk_size_outer_product: Optional[int] = None,
+        chunk_size_tri_attn: Optional[int] = None,
+        use_kernels: bool = False,
+    ) -> tuple[Tensor, Tensor]:
+        """Perform the forward pass.
+
+        Parameters
+        ----------
+        z : Tensor
+            The pair representation
+        m : Tensor
+            The MSA representation
+        token_mask : Tensor
+            The token mask
+        msa_mask : Tensor
+            The MSA mask
+
+        Returns
+        -------
+        Tensor
+            The output pairwise embeddings.
+        Tensor
+            The output MSA embeddings.
+
+        """
+        # Communication to MSA stack
+        msa_dropout = get_dropout_mask(self.msa_dropout, m, self.training)
+        m = m + msa_dropout * self.pair_weighted_averaging(
+            m, z, token_mask, chunk_heads_pwa
+        )
+        m = m + self.msa_transition(m, chunk_size_transition_msa)
+
+        # Communication to pairwise stack
+        z = z + self.outer_product_mean(m, msa_mask, chunk_size_outer_product)
+
+        # Compute pairwise stack
+        dropout = get_dropout_mask(self.z_dropout, z, self.training)
+        z = z + dropout * self.tri_mul_out(z, mask=token_mask)
+
+        dropout = get_dropout_mask(self.z_dropout, z, self.training)
+        z = z + dropout * self.tri_mul_in(z, mask=token_mask)
+
+        dropout = get_dropout_mask(self.z_dropout, z, self.training)
+        z = z + dropout * self.tri_att_start(
+            z,
+            mask=token_mask,
+            chunk_size=chunk_size_tri_attn,
+            use_kernels=use_kernels,
+        )
+
+        dropout = get_dropout_mask(self.z_dropout, z, self.training, columnwise=True)
+        z = z + dropout * self.tri_att_end(
+            z,
+            mask=token_mask,
+            chunk_size=chunk_size_tri_attn,
+            use_kernels=use_kernels,
+        )
+
+        z = z + self.z_transition(z, chunk_size_transition_z)
+
+        return z, m
+
+
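Every residual update in the layer above follows the same pattern: run a sub-layer, multiply its output by a dropout mask from get_dropout_mask, and add the result back to the running representation. The helper itself lives in boltz/model/layers/dropout.py and is not reproduced here; its columnwise flag suggests the mask is shared along one axis of the pair representation. The sketch below only shows the shape of the pattern, with ordinary elementwise dropout standing in for the real mask (an assumption for illustration):

import torch
from torch import nn

z = torch.randn(1, 32, 32, 128)     # pair representation (assumed sizes)
sublayer = nn.Linear(128, 128)      # stand-in for e.g. a triangle update
drop = nn.Dropout(p=0.25)

# Residual update: z <- z + dropout(sublayer(z)); in the real layer the dropout
# mask comes from get_dropout_mask rather than nn.Dropout.
z = z + drop(sublayer(z))
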
+class PairformerModule(nn.Module):
+    """Pairformer module."""
+
+    def __init__(
+        self,
+        token_s: int,
+        token_z: int,
+        num_blocks: int,
+        num_heads: int = 16,
+        dropout: float = 0.25,
+        pairwise_head_width: int = 32,
+        pairwise_num_heads: int = 4,
+        activation_checkpointing: bool = False,
+        no_update_s: bool = False,
+        no_update_z: bool = False,
+        offload_to_cpu: bool = False,
+        **kwargs,
+    ) -> None:
+        """Initialize the Pairformer module.
+
+        Parameters
+        ----------
+        token_s : int
+            The token single embedding size.
+        token_z : int
+            The token pairwise embedding size.
+        num_blocks : int
+            The number of blocks.
+        num_heads : int, optional
+            The number of heads, by default 16
+        dropout : float, optional
+            The dropout rate, by default 0.25
+        pairwise_head_width : int, optional
+            The pairwise head width, by default 32
+        pairwise_num_heads : int, optional
+            The number of pairwise heads, by default 4
+        activation_checkpointing : bool, optional
+            Whether to use activation checkpointing, by default False
+        no_update_s : bool, optional
+            Whether to skip updating the single embeddings, by default False
+        no_update_z : bool, optional
+            Whether to skip updating the pairwise embeddings, by default False
+        offload_to_cpu : bool, optional
+            Whether to offload to CPU, by default False
+
+        """
+        super().__init__()
+        self.token_z = token_z
+        self.num_blocks = num_blocks
+        self.dropout = dropout
+        self.num_heads = num_heads
+
+        self.layers = nn.ModuleList()
+        for i in range(num_blocks):
+            if activation_checkpointing:
+                self.layers.append(
+                    checkpoint_wrapper(
+                        PairformerLayer(
+                            token_s,
+                            token_z,
+                            num_heads,
+                            dropout,
+                            pairwise_head_width,
+                            pairwise_num_heads,
+                            no_update_s,
+                            False if i < num_blocks - 1 else no_update_z,
+                        ),
+                        offload_to_cpu=offload_to_cpu,
+                    )
+                )
+            else:
+                self.layers.append(
+                    PairformerLayer(
+                        token_s,
+                        token_z,
+                        num_heads,
+                        dropout,
+                        pairwise_head_width,
+                        pairwise_num_heads,
+                        no_update_s,
+                        False if i < num_blocks - 1 else no_update_z,
+                    )
+                )
+
+    def forward(
+        self,
+        s: Tensor,
+        z: Tensor,
+        mask: Tensor,
+        pair_mask: Tensor,
+        chunk_size_tri_attn: Optional[int] = None,
+        use_kernels: bool = False,
+    ) -> tuple[Tensor, Tensor]:
+        """Perform the forward pass.
+
+        Parameters
+        ----------
+        s : Tensor
+            The sequence embeddings
+        z : Tensor
+            The pairwise embeddings
+        mask : Tensor
+            The token mask
+        pair_mask : Tensor
+            The pairwise mask
+
+        Returns
+        -------
+        Tensor
+            The updated sequence embeddings.
+        Tensor
+            The updated pairwise embeddings.
+
+        """
+        if not self.training:
+            if z.shape[1] > const.chunk_size_threshold:
+                chunk_size_tri_attn = 128
+            else:
+                chunk_size_tri_attn = 512
+        else:
+            chunk_size_tri_attn = None
+
+        for layer in self.layers:
+            s, z = layer(
+                s,
+                z,
+                mask,
+                pair_mask,
+                chunk_size_tri_attn,
+                use_kernels=use_kernels,
+            )
+        return s, z
+
+
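The forward signature above takes a per-token mask alongside a pairwise mask; elsewhere in this file (MSAModule.forward) the pairwise mask is built as the outer product of the token padding mask, which is the natural companion here as well. A self-contained sketch of that construction, with assumed sizes:

import torch

batch, n_tokens = 1, 16
token_mask = torch.ones(batch, n_tokens)   # 1 for real tokens, 0 for padding
token_mask[:, 12:] = 0                     # pretend the last 4 positions are padding

# Outer product: pair (i, j) is valid only if both token i and token j are real.
pair_mask = token_mask[:, :, None] * token_mask[:, None, :]
print(pair_mask.shape)        # torch.Size([1, 16, 16])
print(pair_mask[0, 0, 12])    # tensor(0.)
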
+class PairformerLayer(nn.Module):
+    """Pairformer layer."""
+
+    def __init__(
+        self,
+        token_s: int,
+        token_z: int,
+        num_heads: int = 16,
+        dropout: float = 0.25,
+        pairwise_head_width: int = 32,
+        pairwise_num_heads: int = 4,
+        no_update_s: bool = False,
+        no_update_z: bool = False,
+    ) -> None:
+        """Initialize the Pairformer layer.
+
+        Parameters
+        ----------
+        token_s : int
+            The token single embedding size.
+        token_z : int
+            The token pairwise embedding size.
+        num_heads : int, optional
+            The number of heads, by default 16
+        dropout : float, optional
+            The dropout rate, by default 0.25
+        pairwise_head_width : int, optional
+            The pairwise head width, by default 32
+        pairwise_num_heads : int, optional
+            The number of pairwise heads, by default 4
+        no_update_s : bool, optional
+            Whether to skip updating the single embeddings, by default False
+        no_update_z : bool, optional
+            Whether to skip updating the pairwise embeddings, by default False
+
+        """
+        super().__init__()
+        self.token_z = token_z
+        self.dropout = dropout
+        self.num_heads = num_heads
+        self.no_update_s = no_update_s
+        self.no_update_z = no_update_z
+        if not self.no_update_s:
+            self.attention = AttentionPairBias(token_s, token_z, num_heads)
+        self.tri_mul_out = TriangleMultiplicationOutgoing(token_z)
+        self.tri_mul_in = TriangleMultiplicationIncoming(token_z)
+        self.tri_att_start = TriangleAttentionStartingNode(
+            token_z, pairwise_head_width, pairwise_num_heads, inf=1e9
+        )
+        self.tri_att_end = TriangleAttentionEndingNode(
+            token_z, pairwise_head_width, pairwise_num_heads, inf=1e9
+        )
+        if not self.no_update_s:
+            self.transition_s = Transition(token_s, token_s * 4)
+        self.transition_z = Transition(token_z, token_z * 4)
+
+    def forward(
+        self,
+        s: Tensor,
+        z: Tensor,
+        mask: Tensor,
+        pair_mask: Tensor,
+        chunk_size_tri_attn: Optional[int] = None,
+        use_kernels: bool = False,
+    ) -> tuple[Tensor, Tensor]:
+        """Perform the forward pass."""
+        # Compute pairwise stack
+        dropout = get_dropout_mask(self.dropout, z, self.training)
+        z = z + dropout * self.tri_mul_out(z, mask=pair_mask)
+
+        dropout = get_dropout_mask(self.dropout, z, self.training)
+        z = z + dropout * self.tri_mul_in(z, mask=pair_mask)
+
+        dropout = get_dropout_mask(self.dropout, z, self.training)
+        z = z + dropout * self.tri_att_start(
+            z,
+            mask=pair_mask,
+            chunk_size=chunk_size_tri_attn,
+            use_kernels=use_kernels,
+        )
+
+        dropout = get_dropout_mask(self.dropout, z, self.training, columnwise=True)
+        z = z + dropout * self.tri_att_end(
+            z,
+            mask=pair_mask,
+            chunk_size=chunk_size_tri_attn,
+            use_kernels=use_kernels,
+        )
+
+        z = z + self.transition_z(z)
+
+        # Compute sequence stack
+        if not self.no_update_s:
+            s = s + self.attention(s, z, mask)
+            s = s + self.transition_s(s)
+
+        return s, z
+
+
+class DistogramModule(nn.Module):
+    """Distogram Module."""
+
+    def __init__(self, token_z: int, num_bins: int) -> None:
+        """Initialize the distogram module.
+
+        Parameters
+        ----------
+        token_z : int
+            The token pairwise embedding size.
+        num_bins : int
+            The number of bins.
+
+        """
+        super().__init__()
+        self.distogram = nn.Linear(token_z, num_bins)
+
+    def forward(self, z: Tensor) -> Tensor:
+        """Perform the forward pass.
+
+        Parameters
+        ----------
+        z : Tensor
+            The pairwise embeddings
+
+        Returns
+        -------
+        Tensor
+            The predicted distogram.
+
+        """
+        z = z + z.transpose(1, 2)
+        return self.distogram(z)
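
DistogramModule is the simplest head in this file: symmetrize the pair representation and project it to distance bins with a single linear layer. A self-contained re-run of those two lines outside the class (sizes are illustrative assumptions):

import torch
from torch import nn

token_z, num_bins, n_tokens = 128, 64, 16   # assumed sizes
distogram = nn.Linear(token_z, num_bins)

z = torch.randn(1, n_tokens, n_tokens, token_z)
z = z + z.transpose(1, 2)          # make the pair representation symmetric in (i, j)
logits = distogram(z)              # one set of bin logits per token pair
probs = logits.softmax(dim=-1)     # per-pair distance distribution, if probabilities are needed

print(logits.shape)  # torch.Size([1, 16, 16, 64])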