boltz-vsynthes 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- boltz/__init__.py +7 -0
- boltz/data/__init__.py +0 -0
- boltz/data/const.py +1184 -0
- boltz/data/crop/__init__.py +0 -0
- boltz/data/crop/affinity.py +164 -0
- boltz/data/crop/boltz.py +296 -0
- boltz/data/crop/cropper.py +45 -0
- boltz/data/feature/__init__.py +0 -0
- boltz/data/feature/featurizer.py +1230 -0
- boltz/data/feature/featurizerv2.py +2208 -0
- boltz/data/feature/symmetry.py +602 -0
- boltz/data/filter/__init__.py +0 -0
- boltz/data/filter/dynamic/__init__.py +0 -0
- boltz/data/filter/dynamic/date.py +76 -0
- boltz/data/filter/dynamic/filter.py +24 -0
- boltz/data/filter/dynamic/max_residues.py +37 -0
- boltz/data/filter/dynamic/resolution.py +34 -0
- boltz/data/filter/dynamic/size.py +38 -0
- boltz/data/filter/dynamic/subset.py +42 -0
- boltz/data/filter/static/__init__.py +0 -0
- boltz/data/filter/static/filter.py +26 -0
- boltz/data/filter/static/ligand.py +37 -0
- boltz/data/filter/static/polymer.py +299 -0
- boltz/data/module/__init__.py +0 -0
- boltz/data/module/inference.py +307 -0
- boltz/data/module/inferencev2.py +429 -0
- boltz/data/module/training.py +684 -0
- boltz/data/module/trainingv2.py +660 -0
- boltz/data/mol.py +900 -0
- boltz/data/msa/__init__.py +0 -0
- boltz/data/msa/mmseqs2.py +235 -0
- boltz/data/pad.py +84 -0
- boltz/data/parse/__init__.py +0 -0
- boltz/data/parse/a3m.py +134 -0
- boltz/data/parse/csv.py +100 -0
- boltz/data/parse/fasta.py +138 -0
- boltz/data/parse/mmcif.py +1239 -0
- boltz/data/parse/mmcif_with_constraints.py +1607 -0
- boltz/data/parse/schema.py +1851 -0
- boltz/data/parse/yaml.py +68 -0
- boltz/data/sample/__init__.py +0 -0
- boltz/data/sample/cluster.py +283 -0
- boltz/data/sample/distillation.py +57 -0
- boltz/data/sample/random.py +39 -0
- boltz/data/sample/sampler.py +49 -0
- boltz/data/tokenize/__init__.py +0 -0
- boltz/data/tokenize/boltz.py +195 -0
- boltz/data/tokenize/boltz2.py +396 -0
- boltz/data/tokenize/tokenizer.py +24 -0
- boltz/data/types.py +777 -0
- boltz/data/write/__init__.py +0 -0
- boltz/data/write/mmcif.py +305 -0
- boltz/data/write/pdb.py +171 -0
- boltz/data/write/utils.py +23 -0
- boltz/data/write/writer.py +330 -0
- boltz/main.py +1292 -0
- boltz/model/__init__.py +0 -0
- boltz/model/layers/__init__.py +0 -0
- boltz/model/layers/attention.py +132 -0
- boltz/model/layers/attentionv2.py +111 -0
- boltz/model/layers/confidence_utils.py +231 -0
- boltz/model/layers/dropout.py +34 -0
- boltz/model/layers/initialize.py +100 -0
- boltz/model/layers/outer_product_mean.py +98 -0
- boltz/model/layers/pair_averaging.py +135 -0
- boltz/model/layers/pairformer.py +337 -0
- boltz/model/layers/relative.py +58 -0
- boltz/model/layers/transition.py +78 -0
- boltz/model/layers/triangular_attention/__init__.py +0 -0
- boltz/model/layers/triangular_attention/attention.py +189 -0
- boltz/model/layers/triangular_attention/primitives.py +409 -0
- boltz/model/layers/triangular_attention/utils.py +380 -0
- boltz/model/layers/triangular_mult.py +212 -0
- boltz/model/loss/__init__.py +0 -0
- boltz/model/loss/bfactor.py +49 -0
- boltz/model/loss/confidence.py +590 -0
- boltz/model/loss/confidencev2.py +621 -0
- boltz/model/loss/diffusion.py +171 -0
- boltz/model/loss/diffusionv2.py +134 -0
- boltz/model/loss/distogram.py +48 -0
- boltz/model/loss/distogramv2.py +105 -0
- boltz/model/loss/validation.py +1025 -0
- boltz/model/models/__init__.py +0 -0
- boltz/model/models/boltz1.py +1286 -0
- boltz/model/models/boltz2.py +1249 -0
- boltz/model/modules/__init__.py +0 -0
- boltz/model/modules/affinity.py +223 -0
- boltz/model/modules/confidence.py +481 -0
- boltz/model/modules/confidence_utils.py +181 -0
- boltz/model/modules/confidencev2.py +495 -0
- boltz/model/modules/diffusion.py +844 -0
- boltz/model/modules/diffusion_conditioning.py +116 -0
- boltz/model/modules/diffusionv2.py +677 -0
- boltz/model/modules/encoders.py +639 -0
- boltz/model/modules/encodersv2.py +565 -0
- boltz/model/modules/transformers.py +322 -0
- boltz/model/modules/transformersv2.py +261 -0
- boltz/model/modules/trunk.py +688 -0
- boltz/model/modules/trunkv2.py +828 -0
- boltz/model/modules/utils.py +303 -0
- boltz/model/optim/__init__.py +0 -0
- boltz/model/optim/ema.py +389 -0
- boltz/model/optim/scheduler.py +99 -0
- boltz/model/potentials/__init__.py +0 -0
- boltz/model/potentials/potentials.py +497 -0
- boltz/model/potentials/schedules.py +32 -0
- boltz_vsynthes-1.0.0.dist-info/METADATA +151 -0
- boltz_vsynthes-1.0.0.dist-info/RECORD +112 -0
- boltz_vsynthes-1.0.0.dist-info/WHEEL +5 -0
- boltz_vsynthes-1.0.0.dist-info/entry_points.txt +2 -0
- boltz_vsynthes-1.0.0.dist-info/licenses/LICENSE +21 -0
- boltz_vsynthes-1.0.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,688 @@
|
|
1
|
+
from typing import Optional
|
2
|
+
|
3
|
+
import torch
|
4
|
+
from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper
|
5
|
+
from torch import Tensor, nn
|
6
|
+
|
7
|
+
from boltz.data import const
|
8
|
+
from boltz.model.layers.attention import AttentionPairBias
|
9
|
+
from boltz.model.layers.dropout import get_dropout_mask
|
10
|
+
from boltz.model.layers.outer_product_mean import OuterProductMean
|
11
|
+
from boltz.model.layers.pair_averaging import PairWeightedAveraging
|
12
|
+
from boltz.model.layers.transition import Transition
|
13
|
+
from boltz.model.layers.triangular_attention.attention import (
|
14
|
+
TriangleAttentionEndingNode,
|
15
|
+
TriangleAttentionStartingNode,
|
16
|
+
)
|
17
|
+
from boltz.model.layers.triangular_mult import (
|
18
|
+
TriangleMultiplicationIncoming,
|
19
|
+
TriangleMultiplicationOutgoing,
|
20
|
+
)
|
21
|
+
from boltz.model.modules.encoders import AtomAttentionEncoder
|
22
|
+
|
23
|
+
|
24
|
+
class InputEmbedder(nn.Module):
    """Input embedder.

    Builds the per-token input embedding by concatenating, along the last
    axis, an atom-derived token embedding (or zeros when the atom encoder is
    disabled) with the ``res_type``, ``profile``, ``deletion_mean`` and
    ``pocket_feature`` features.
    """

    def __init__(
        self,
        atom_s: int,
        atom_z: int,
        token_s: int,
        token_z: int,
        atoms_per_window_queries: int,
        atoms_per_window_keys: int,
        atom_feature_dim: int,
        atom_encoder_depth: int,
        atom_encoder_heads: int,
        no_atom_encoder: bool = False,
    ) -> None:
        """Initialize the input embedder.

        Parameters
        ----------
        atom_s : int
            The atom single representation dimension.
        atom_z : int
            The atom pair representation dimension.
        token_s : int
            The single token representation dimension. This is also the width
            of the zero placeholder used when ``no_atom_encoder`` is True.
        token_z : int
            The pair token representation dimension.
        atoms_per_window_queries : int
            The number of atoms per window for queries.
        atoms_per_window_keys : int
            The number of atoms per window for keys.
        atom_feature_dim : int
            The atom feature dimension.
        atom_encoder_depth : int
            The atom encoder depth.
        atom_encoder_heads : int
            The atom encoder heads.
        no_atom_encoder : bool, optional
            If True, no atom attention encoder is created and ``forward``
            uses a zero tensor of width ``token_s`` in its place,
            by default False.

        """
        super().__init__()
        self.token_s = token_s
        self.no_atom_encoder = no_atom_encoder

        # The encoder submodule only exists when it will actually be used.
        if not no_atom_encoder:
            self.atom_attention_encoder = AtomAttentionEncoder(
                atom_s=atom_s,
                atom_z=atom_z,
                token_s=token_s,
                token_z=token_z,
                atoms_per_window_queries=atoms_per_window_queries,
                atoms_per_window_keys=atoms_per_window_keys,
                atom_feature_dim=atom_feature_dim,
                atom_encoder_depth=atom_encoder_depth,
                atom_encoder_heads=atom_encoder_heads,
                structure_prediction=False,
            )

    def forward(self, feats: dict[str, Tensor]) -> Tensor:
        """Perform the forward pass.

        Parameters
        ----------
        feats : dict[str, Tensor]
            Input features; reads ``res_type``, ``profile``,
            ``deletion_mean`` and ``pocket_feature`` (plus whatever the atom
            encoder consumes when it is enabled).

        Returns
        -------
        Tensor
            The embedded tokens: the atom-derived embedding (or zeros),
            ``res_type``, ``profile``, ``deletion_mean`` (as one channel) and
            ``pocket_feature`` concatenated on the last axis.

        """
        # Load relevant features
        res_type = feats["res_type"]
        profile = feats["profile"]
        # deletion_mean gains a trailing axis so it concatenates as one channel.
        deletion_mean = feats["deletion_mean"].unsqueeze(-1)
        pocket_feature = feats["pocket_feature"]

        # Compute input embedding
        if self.no_atom_encoder:
            # Zero placeholder sharing res_type's first two dims and device.
            a = torch.zeros(
                (res_type.shape[0], res_type.shape[1], self.token_s),
                device=res_type.device,
            )
        else:
            # Only the first of the encoder's five outputs is used here.
            a, _, _, _, _ = self.atom_attention_encoder(feats)
        s = torch.cat([a, res_type, profile, deletion_mean, pocket_feature], dim=-1)
        return s
|
114
|
+
|
115
|
+
|
116
|
+
class MSAModule(nn.Module):
    """MSA module.

    Embeds the MSA features, adds a projection of the token embedding, and
    runs a stack of ``MSALayer`` blocks that jointly update the MSA and pair
    representations. Only the updated pair representation is returned.
    """

    def __init__(
        self,
        msa_s: int,
        token_z: int,
        s_input_dim: int,
        msa_blocks: int,
        msa_dropout: float,
        z_dropout: float,
        pairwise_head_width: int = 32,
        pairwise_num_heads: int = 4,
        activation_checkpointing: bool = False,
        use_paired_feature: bool = False,
        offload_to_cpu: bool = False,
        subsample_msa: bool = False,
        num_subsampled_msa: int = 1024,
        **kwargs,
    ) -> None:
        """Initialize the MSA module.

        Parameters
        ----------
        msa_s : int
            The MSA embedding size.
        token_z : int
            The token pairwise embedding size.
        s_input_dim : int
            The input sequence dimension.
        msa_blocks : int
            The number of MSA blocks.
        msa_dropout : float
            The MSA dropout.
        z_dropout : float
            The pairwise dropout.
        pairwise_head_width : int, optional
            The pairwise head width, by default 32
        pairwise_num_heads : int, optional
            The number of pairwise heads, by default 4
        activation_checkpointing : bool, optional
            Whether to wrap each layer with activation checkpointing,
            by default False
        use_paired_feature : bool, optional
            Whether to append the ``msa_paired`` feature as an extra input
            channel, by default False
        offload_to_cpu : bool, optional
            Passed to ``checkpoint_wrapper`` when activation checkpointing is
            enabled, by default False
        subsample_msa : bool, optional
            Whether ``forward`` randomly subsamples MSA rows (in both training
            and evaluation), by default False
        num_subsampled_msa : int, optional
            Maximum number of MSA rows kept when subsampling, by default 1024
        **kwargs
            Accepted and ignored (allows passing a larger config dict).

        """
        super().__init__()
        self.msa_blocks = msa_blocks
        self.msa_dropout = msa_dropout
        self.z_dropout = z_dropout
        self.use_paired_feature = use_paired_feature
        self.subsample_msa = subsample_msa
        self.num_subsampled_msa = num_subsampled_msa

        self.s_proj = nn.Linear(s_input_dim, msa_s, bias=False)
        # Input channels: const.num_tokens from feats["msa"], plus
        # has_deletion and deletion_value, plus msa_paired when enabled.
        self.msa_proj = nn.Linear(
            const.num_tokens + 2 + int(use_paired_feature),
            msa_s,
            bias=False,
        )

        self.layers = nn.ModuleList()
        for _ in range(msa_blocks):
            layer = MSALayer(
                msa_s,
                token_z,
                msa_dropout,
                z_dropout,
                pairwise_head_width,
                pairwise_num_heads,
            )
            if activation_checkpointing:
                layer = checkpoint_wrapper(layer, offload_to_cpu=offload_to_cpu)
            self.layers.append(layer)

    def forward(
        self,
        z: Tensor,
        emb: Tensor,
        feats: dict[str, Tensor],
        use_kernels: bool = False,
    ) -> Tensor:
        """Perform the forward pass.

        Parameters
        ----------
        z : Tensor
            The pairwise embeddings
        emb : Tensor
            The input embeddings
        feats : dict[str, Tensor]
            Input features; reads ``msa``, ``has_deletion``,
            ``deletion_value``, ``msa_paired``, ``msa_mask`` and
            ``token_pad_mask``.
        use_kernels : bool, optional
            Forwarded to every layer, by default False

        Returns
        -------
        Tensor
            The output pairwise embeddings.

        """
        # Chunking is only applied at inference. Smaller chunks are used when
        # z.shape[1] (presumably the token count) exceeds the threshold.
        if self.training:
            chunk_heads_pwa = False
            chunk_size_transition_z = None
            chunk_size_transition_msa = None
            chunk_size_outer_product = None
            chunk_size_tri_attn = None
        elif z.shape[1] > const.chunk_size_threshold:
            chunk_heads_pwa = True
            chunk_size_transition_z = 64
            chunk_size_transition_msa = 32
            chunk_size_outer_product = 4
            chunk_size_tri_attn = 128
        else:
            chunk_heads_pwa = False
            chunk_size_transition_z = None
            chunk_size_transition_msa = None
            chunk_size_outer_product = None
            chunk_size_tri_attn = 512

        # Load relevant features
        msa = feats["msa"]
        has_deletion = feats["has_deletion"].unsqueeze(-1)
        deletion_value = feats["deletion_value"].unsqueeze(-1)
        is_paired = feats["msa_paired"].unsqueeze(-1)
        msa_mask = feats["msa_mask"]
        token_mask = feats["token_pad_mask"].float()
        # Outer product of the per-token pad mask gives the pairwise mask.
        token_mask = token_mask[:, :, None] * token_mask[:, None, :]

        # Compute MSA embeddings
        msa_parts = [msa, has_deletion, deletion_value]
        if self.use_paired_feature:
            msa_parts.append(is_paired)
        m = torch.cat(msa_parts, dim=-1)

        if self.subsample_msa:
            # Random row subset drawn from the default (CPU) generator; note
            # this also runs in evaluation mode.
            msa_indices = torch.randperm(m.shape[1])[: self.num_subsampled_msa]
            m = m[:, msa_indices]
            msa_mask = msa_mask[:, msa_indices]

        # Compute input projections; the token embedding is broadcast over
        # the MSA-row axis.
        m = self.msa_proj(m)
        m = m + self.s_proj(emb).unsqueeze(1)

        # Perform MSA blocks
        for layer in self.layers:
            z, m = layer(
                z,
                m,
                token_mask,
                msa_mask,
                chunk_heads_pwa,
                chunk_size_transition_z,
                chunk_size_transition_msa,
                chunk_size_outer_product,
                chunk_size_tri_attn,
                use_kernels=use_kernels,
            )
        return z
|
290
|
+
|
291
|
+
|
292
|
+
class MSALayer(nn.Module):
    """MSA layer.

    One block of the MSA stack. It updates the MSA representation (pair-weighted
    averaging + transition), feeds it into the pair representation via an
    outer-product mean, then applies the triangle updates and a transition
    to the pair representation.
    """

    def __init__(
        self,
        msa_s: int,
        token_z: int,
        msa_dropout: float,
        z_dropout: float,
        pairwise_head_width: int = 32,
        pairwise_num_heads: int = 4,
    ) -> None:
        """Initialize the MSA layer.

        Parameters
        ----------
        msa_s : int
            The MSA embedding size.
        token_z : int
            The pair representation dimension.
        msa_dropout : float
            The MSA dropout.
        z_dropout : float
            The pair dropout.
        pairwise_head_width : int, optional
            The head width of the triangle attention layers, by default 32
        pairwise_num_heads : int, optional
            The number of heads of the triangle attention layers, by default 4

        """
        super().__init__()
        self.msa_dropout = msa_dropout
        self.z_dropout = z_dropout
        self.msa_transition = Transition(dim=msa_s, hidden=msa_s * 4)
        # Pair-weighted averaging uses fixed sizes here (c_h=32, 8 heads).
        self.pair_weighted_averaging = PairWeightedAveraging(
            c_m=msa_s,
            c_z=token_z,
            c_h=32,
            num_heads=8,
        )

        self.tri_mul_out = TriangleMultiplicationOutgoing(token_z)
        self.tri_mul_in = TriangleMultiplicationIncoming(token_z)
        self.tri_att_start = TriangleAttentionStartingNode(
            token_z, pairwise_head_width, pairwise_num_heads, inf=1e9
        )
        self.tri_att_end = TriangleAttentionEndingNode(
            token_z, pairwise_head_width, pairwise_num_heads, inf=1e9
        )
        self.z_transition = Transition(
            dim=token_z,
            hidden=token_z * 4,
        )
        self.outer_product_mean = OuterProductMean(
            c_in=msa_s,
            c_hidden=32,
            c_out=token_z,
        )

    def forward(
        self,
        z: Tensor,
        m: Tensor,
        token_mask: Tensor,
        msa_mask: Tensor,
        chunk_heads_pwa: bool = False,
        chunk_size_transition_z: Optional[int] = None,
        chunk_size_transition_msa: Optional[int] = None,
        chunk_size_outer_product: Optional[int] = None,
        chunk_size_tri_attn: Optional[int] = None,
        use_kernels: bool = False,
    ) -> tuple[Tensor, Tensor]:
        """Perform the forward pass.

        Parameters
        ----------
        z : Tensor
            The pair representation
        m : Tensor
            The msa representation
        token_mask : Tensor
            The pairwise token mask (used by both the MSA and pair updates)
        msa_mask : Tensor
            The MSA mask, passed to the outer-product mean
        chunk_heads_pwa : bool, optional
            Passed to the pair-weighted averaging, by default False
        chunk_size_transition_z : int, optional
            Chunk size for the pair transition, by default None
        chunk_size_transition_msa : int, optional
            Chunk size for the MSA transition, by default None
        chunk_size_outer_product : int, optional
            Chunk size for the outer-product mean, by default None
        chunk_size_tri_attn : int, optional
            Chunk size for both triangle attentions, by default None
        use_kernels : bool, optional
            Passed to both triangle attentions, by default False

        Returns
        -------
        Tensor
            The output pairwise embeddings.
        Tensor
            The output MSA embeddings.

        """
        # Communication to MSA stack (residual updates of m)
        msa_dropout = get_dropout_mask(self.msa_dropout, m, self.training)
        m = m + msa_dropout * self.pair_weighted_averaging(
            m, z, token_mask, chunk_heads_pwa
        )
        m = m + self.msa_transition(m, chunk_size_transition_msa)

        # Communication to pairwise stack
        z = z + self.outer_product_mean(m, msa_mask, chunk_size_outer_product)

        # Compute pairwise stack; a fresh dropout mask is drawn per update.
        dropout = get_dropout_mask(self.z_dropout, z, self.training)
        z = z + dropout * self.tri_mul_out(z, mask=token_mask)

        dropout = get_dropout_mask(self.z_dropout, z, self.training)
        z = z + dropout * self.tri_mul_in(z, mask=token_mask)

        dropout = get_dropout_mask(self.z_dropout, z, self.training)
        z = z + dropout * self.tri_att_start(
            z,
            mask=token_mask,
            chunk_size=chunk_size_tri_attn,
            use_kernels=use_kernels,
        )

        # The ending-node update uses a column-wise dropout mask.
        dropout = get_dropout_mask(self.z_dropout, z, self.training, columnwise=True)
        z = z + dropout * self.tri_att_end(
            z,
            mask=token_mask,
            chunk_size=chunk_size_tri_attn,
            use_kernels=use_kernels,
        )

        z = z + self.z_transition(z, chunk_size_transition_z)

        return z, m
|
422
|
+
|
423
|
+
|
424
|
+
class PairformerModule(nn.Module):
    """Pairformer module: a stack of ``PairformerLayer`` blocks."""

    def __init__(
        self,
        token_s: int,
        token_z: int,
        num_blocks: int,
        num_heads: int = 16,
        dropout: float = 0.25,
        pairwise_head_width: int = 32,
        pairwise_num_heads: int = 4,
        activation_checkpointing: bool = False,
        no_update_s: bool = False,
        no_update_z: bool = False,
        offload_to_cpu: bool = False,
        **kwargs,
    ) -> None:
        """Initialize the Pairformer module.

        Parameters
        ----------
        token_s : int
            The token single embedding size.
        token_z : int
            The token pairwise embedding size.
        num_blocks : int
            The number of blocks.
        num_heads : int, optional
            The number of heads, by default 16
        dropout : float, optional
            The dropout rate, by default 0.25
        pairwise_head_width : int, optional
            The pairwise head width, by default 32
        pairwise_num_heads : int, optional
            The number of pairwise heads, by default 4
        activation_checkpointing : bool, optional
            Whether to wrap each layer with activation checkpointing,
            by default False
        no_update_s : bool, optional
            If True, every layer skips the single-embedding update,
            by default False
        no_update_z : bool, optional
            Passed only to the last layer (earlier layers get False),
            by default False
        offload_to_cpu : bool, optional
            Passed to ``checkpoint_wrapper`` when activation checkpointing is
            enabled, by default False
        **kwargs
            Accepted and ignored (allows passing a larger config dict).

        """
        super().__init__()
        self.token_z = token_z
        self.num_blocks = num_blocks
        self.dropout = dropout
        self.num_heads = num_heads

        self.layers = nn.ModuleList()
        for i in range(num_blocks):
            is_last = i == num_blocks - 1
            layer = PairformerLayer(
                token_s,
                token_z,
                num_heads,
                dropout,
                pairwise_head_width,
                pairwise_num_heads,
                no_update_s,
                no_update_z if is_last else False,
            )
            if activation_checkpointing:
                layer = checkpoint_wrapper(layer, offload_to_cpu=offload_to_cpu)
            self.layers.append(layer)

    def forward(
        self,
        s: Tensor,
        z: Tensor,
        mask: Tensor,
        pair_mask: Tensor,
        chunk_size_tri_attn: Optional[int] = None,
        use_kernels: bool = False,
    ) -> tuple[Tensor, Tensor]:
        """Perform the forward pass.

        Parameters
        ----------
        s : Tensor
            The sequence embeddings
        z : Tensor
            The pairwise embeddings
        mask : Tensor
            The token mask
        pair_mask : Tensor
            The pairwise mask
        chunk_size_tri_attn : int, optional
            Chunk size for triangle attention. When None (the default), no
            chunking is used in training, and at inference it is 128 if
            ``z.shape[1]`` exceeds ``const.chunk_size_threshold``, else 512.
            An explicit value is passed through unchanged.
        use_kernels : bool, optional
            Forwarded to every layer, by default False

        Returns
        -------
        Tensor
            The updated sequence embeddings.
        Tensor
            The updated pairwise embeddings.

        """
        # Previously the argument was unconditionally overwritten here; only
        # fill in the heuristic default when the caller did not choose one.
        if chunk_size_tri_attn is None and not self.training:
            if z.shape[1] > const.chunk_size_threshold:
                chunk_size_tri_attn = 128
            else:
                chunk_size_tri_attn = 512

        for layer in self.layers:
            s, z = layer(
                s,
                z,
                mask,
                pair_mask,
                chunk_size_tri_attn,
                use_kernels=use_kernels,
            )
        return s, z
|
555
|
+
|
556
|
+
|
557
|
+
class PairformerLayer(nn.Module):
    """Pairformer layer.

    Applies residual triangle multiplications, triangle attentions and a
    transition to the pair representation ``z``. Unless ``no_update_s`` is
    set, it then applies pair-biased attention and a transition to the single
    representation ``s``.
    """

    def __init__(
        self,
        token_s: int,
        token_z: int,
        num_heads: int = 16,
        dropout: float = 0.25,
        pairwise_head_width: int = 32,
        pairwise_num_heads: int = 4,
        no_update_s: bool = False,
        no_update_z: bool = False,
    ) -> None:
        """Initialize the Pairformer layer.

        Parameters
        ----------
        token_s : int
            The token single embedding size.
        token_z : int
            The token pairwise embedding size.
        num_heads : int, optional
            The number of heads of the single-representation attention,
            by default 16
        dropout : float, optional
            The dropout rate, by default 0.25
        pairwise_head_width : int, optional
            The pairwise head width, by default 32
        pairwise_num_heads : int, optional
            The number of pairwise heads, by default 4
        no_update_s : bool, optional
            If True, the single-embedding update is skipped and its modules
            (``attention``, ``transition_s``) are not created, by default False
        no_update_z : bool, optional
            Stored as ``self.no_update_z``, by default False.
            NOTE(review): ``forward`` never reads this flag; ``z`` is always
            updated. Confirm whether that is intended before relying on it.

        """
        super().__init__()
        self.token_z = token_z
        self.dropout = dropout
        self.num_heads = num_heads
        self.no_update_s = no_update_s
        self.no_update_z = no_update_z
        # Construction order is kept as-is (it affects parameter init RNG).
        if not self.no_update_s:
            self.attention = AttentionPairBias(token_s, token_z, num_heads)
        self.tri_mul_out = TriangleMultiplicationOutgoing(token_z)
        self.tri_mul_in = TriangleMultiplicationIncoming(token_z)
        self.tri_att_start = TriangleAttentionStartingNode(
            token_z, pairwise_head_width, pairwise_num_heads, inf=1e9
        )
        self.tri_att_end = TriangleAttentionEndingNode(
            token_z, pairwise_head_width, pairwise_num_heads, inf=1e9
        )
        if not self.no_update_s:
            self.transition_s = Transition(token_s, token_s * 4)
        self.transition_z = Transition(token_z, token_z * 4)

    def forward(
        self,
        s: Tensor,
        z: Tensor,
        mask: Tensor,
        pair_mask: Tensor,
        chunk_size_tri_attn: Optional[int] = None,
        use_kernels: bool = False,
    ) -> tuple[Tensor, Tensor]:
        """Perform the forward pass.

        Parameters
        ----------
        s : Tensor
            The single (sequence) embeddings
        z : Tensor
            The pairwise embeddings
        mask : Tensor
            The token mask, used by the single-representation attention
        pair_mask : Tensor
            The pairwise mask, used by the triangle updates
        chunk_size_tri_attn : int, optional
            Chunk size for both triangle attentions, by default None
        use_kernels : bool, optional
            Passed to both triangle attentions, by default False

        Returns
        -------
        Tensor
            The updated single embeddings (unchanged if ``no_update_s``).
        Tensor
            The updated pairwise embeddings.

        """
        # Compute pairwise stack; a fresh dropout mask is drawn per update.
        dropout = get_dropout_mask(self.dropout, z, self.training)
        z = z + dropout * self.tri_mul_out(z, mask=pair_mask)

        dropout = get_dropout_mask(self.dropout, z, self.training)
        z = z + dropout * self.tri_mul_in(z, mask=pair_mask)

        dropout = get_dropout_mask(self.dropout, z, self.training)
        z = z + dropout * self.tri_att_start(
            z,
            mask=pair_mask,
            chunk_size=chunk_size_tri_attn,
            use_kernels=use_kernels,
        )

        # The ending-node update uses a column-wise dropout mask.
        dropout = get_dropout_mask(self.dropout, z, self.training, columnwise=True)
        z = z + dropout * self.tri_att_end(
            z,
            mask=pair_mask,
            chunk_size=chunk_size_tri_attn,
            use_kernels=use_kernels,
        )

        z = z + self.transition_z(z)

        # Compute sequence stack, biased by the freshly updated z.
        if not self.no_update_s:
            s = s + self.attention(s, z, mask)
            s = s + self.transition_s(s)

        return s, z
|
654
|
+
|
655
|
+
|
656
|
+
class DistogramModule(nn.Module):
    """Linear head mapping symmetrized pair embeddings to distogram logits."""

    def __init__(self, token_z: int, num_bins: int) -> None:
        """Create the distogram head.

        Parameters
        ----------
        token_z : int
            Channel size of the incoming pair embeddings.
        num_bins : int
            Number of output bins produced for every pair.

        """
        super().__init__()
        self.distogram = nn.Linear(token_z, num_bins)

    def forward(self, z: Tensor) -> Tensor:
        """Compute distogram logits.

        Parameters
        ----------
        z : Tensor
            Pair embeddings. Axes 1 and 2 are swapped to symmetrize, so they
            must have equal size (presumably the two token axes).

        Returns
        -------
        Tensor
            Logits with the leading shape of ``z`` and ``num_bins`` channels
            in the last axis.

        """
        # Adding the (1, 2)-transpose makes the per-pair logits symmetric.
        symmetric_z = z + torch.transpose(z, 1, 2)
        logits = self.distogram(symmetric_z)
        return logits
|