boltz_vsynthes-1.0.0-py3-none-any.whl
This diff shows the content of publicly released package versions as published to one of the supported registries. It is provided for informational purposes only and reflects the differences between package versions exactly as they appear in their public registries.
- boltz/__init__.py +7 -0
- boltz/data/__init__.py +0 -0
- boltz/data/const.py +1184 -0
- boltz/data/crop/__init__.py +0 -0
- boltz/data/crop/affinity.py +164 -0
- boltz/data/crop/boltz.py +296 -0
- boltz/data/crop/cropper.py +45 -0
- boltz/data/feature/__init__.py +0 -0
- boltz/data/feature/featurizer.py +1230 -0
- boltz/data/feature/featurizerv2.py +2208 -0
- boltz/data/feature/symmetry.py +602 -0
- boltz/data/filter/__init__.py +0 -0
- boltz/data/filter/dynamic/__init__.py +0 -0
- boltz/data/filter/dynamic/date.py +76 -0
- boltz/data/filter/dynamic/filter.py +24 -0
- boltz/data/filter/dynamic/max_residues.py +37 -0
- boltz/data/filter/dynamic/resolution.py +34 -0
- boltz/data/filter/dynamic/size.py +38 -0
- boltz/data/filter/dynamic/subset.py +42 -0
- boltz/data/filter/static/__init__.py +0 -0
- boltz/data/filter/static/filter.py +26 -0
- boltz/data/filter/static/ligand.py +37 -0
- boltz/data/filter/static/polymer.py +299 -0
- boltz/data/module/__init__.py +0 -0
- boltz/data/module/inference.py +307 -0
- boltz/data/module/inferencev2.py +429 -0
- boltz/data/module/training.py +684 -0
- boltz/data/module/trainingv2.py +660 -0
- boltz/data/mol.py +900 -0
- boltz/data/msa/__init__.py +0 -0
- boltz/data/msa/mmseqs2.py +235 -0
- boltz/data/pad.py +84 -0
- boltz/data/parse/__init__.py +0 -0
- boltz/data/parse/a3m.py +134 -0
- boltz/data/parse/csv.py +100 -0
- boltz/data/parse/fasta.py +138 -0
- boltz/data/parse/mmcif.py +1239 -0
- boltz/data/parse/mmcif_with_constraints.py +1607 -0
- boltz/data/parse/schema.py +1851 -0
- boltz/data/parse/yaml.py +68 -0
- boltz/data/sample/__init__.py +0 -0
- boltz/data/sample/cluster.py +283 -0
- boltz/data/sample/distillation.py +57 -0
- boltz/data/sample/random.py +39 -0
- boltz/data/sample/sampler.py +49 -0
- boltz/data/tokenize/__init__.py +0 -0
- boltz/data/tokenize/boltz.py +195 -0
- boltz/data/tokenize/boltz2.py +396 -0
- boltz/data/tokenize/tokenizer.py +24 -0
- boltz/data/types.py +777 -0
- boltz/data/write/__init__.py +0 -0
- boltz/data/write/mmcif.py +305 -0
- boltz/data/write/pdb.py +171 -0
- boltz/data/write/utils.py +23 -0
- boltz/data/write/writer.py +330 -0
- boltz/main.py +1292 -0
- boltz/model/__init__.py +0 -0
- boltz/model/layers/__init__.py +0 -0
- boltz/model/layers/attention.py +132 -0
- boltz/model/layers/attentionv2.py +111 -0
- boltz/model/layers/confidence_utils.py +231 -0
- boltz/model/layers/dropout.py +34 -0
- boltz/model/layers/initialize.py +100 -0
- boltz/model/layers/outer_product_mean.py +98 -0
- boltz/model/layers/pair_averaging.py +135 -0
- boltz/model/layers/pairformer.py +337 -0
- boltz/model/layers/relative.py +58 -0
- boltz/model/layers/transition.py +78 -0
- boltz/model/layers/triangular_attention/__init__.py +0 -0
- boltz/model/layers/triangular_attention/attention.py +189 -0
- boltz/model/layers/triangular_attention/primitives.py +409 -0
- boltz/model/layers/triangular_attention/utils.py +380 -0
- boltz/model/layers/triangular_mult.py +212 -0
- boltz/model/loss/__init__.py +0 -0
- boltz/model/loss/bfactor.py +49 -0
- boltz/model/loss/confidence.py +590 -0
- boltz/model/loss/confidencev2.py +621 -0
- boltz/model/loss/diffusion.py +171 -0
- boltz/model/loss/diffusionv2.py +134 -0
- boltz/model/loss/distogram.py +48 -0
- boltz/model/loss/distogramv2.py +105 -0
- boltz/model/loss/validation.py +1025 -0
- boltz/model/models/__init__.py +0 -0
- boltz/model/models/boltz1.py +1286 -0
- boltz/model/models/boltz2.py +1249 -0
- boltz/model/modules/__init__.py +0 -0
- boltz/model/modules/affinity.py +223 -0
- boltz/model/modules/confidence.py +481 -0
- boltz/model/modules/confidence_utils.py +181 -0
- boltz/model/modules/confidencev2.py +495 -0
- boltz/model/modules/diffusion.py +844 -0
- boltz/model/modules/diffusion_conditioning.py +116 -0
- boltz/model/modules/diffusionv2.py +677 -0
- boltz/model/modules/encoders.py +639 -0
- boltz/model/modules/encodersv2.py +565 -0
- boltz/model/modules/transformers.py +322 -0
- boltz/model/modules/transformersv2.py +261 -0
- boltz/model/modules/trunk.py +688 -0
- boltz/model/modules/trunkv2.py +828 -0
- boltz/model/modules/utils.py +303 -0
- boltz/model/optim/__init__.py +0 -0
- boltz/model/optim/ema.py +389 -0
- boltz/model/optim/scheduler.py +99 -0
- boltz/model/potentials/__init__.py +0 -0
- boltz/model/potentials/potentials.py +497 -0
- boltz/model/potentials/schedules.py +32 -0
- boltz_vsynthes-1.0.0.dist-info/METADATA +151 -0
- boltz_vsynthes-1.0.0.dist-info/RECORD +112 -0
- boltz_vsynthes-1.0.0.dist-info/WHEEL +5 -0
- boltz_vsynthes-1.0.0.dist-info/entry_points.txt +2 -0
- boltz_vsynthes-1.0.0.dist-info/licenses/LICENSE +21 -0
- boltz_vsynthes-1.0.0.dist-info/top_level.txt +1 -0
boltz/model/modules/trunkv2.py
@@ -0,0 +1,828 @@
import torch
from torch import Tensor, nn
from torch.nn.functional import one_hot

from boltz.data import const
from boltz.model.layers.outer_product_mean import OuterProductMean
from boltz.model.layers.pair_averaging import PairWeightedAveraging
from boltz.model.layers.pairformer import (
    PairformerNoSeqLayer,
    PairformerNoSeqModule,
    get_dropout_mask,
)
from boltz.model.layers.transition import Transition
from boltz.model.modules.encodersv2 import (
    AtomAttentionEncoder,
    AtomEncoder,
    FourierEmbedding,
)


class ContactConditioning(nn.Module):
    def __init__(self, token_z: int, cutoff_min: float, cutoff_max: float):
        super().__init__()

        self.fourier_embedding = FourierEmbedding(token_z)
        self.encoder = nn.Linear(
            token_z + len(const.contact_conditioning_info) - 1, token_z
        )
        self.encoding_unspecified = nn.Parameter(torch.zeros(token_z))
        self.encoding_unselected = nn.Parameter(torch.zeros(token_z))
        self.cutoff_min = cutoff_min
        self.cutoff_max = cutoff_max

    def forward(self, feats):
        assert const.contact_conditioning_info["UNSPECIFIED"] == 0
        assert const.contact_conditioning_info["UNSELECTED"] == 1
        contact_conditioning = feats["contact_conditioning"][:, :, :, 2:]
        contact_threshold = feats["contact_threshold"]
        contact_threshold_normalized = (contact_threshold - self.cutoff_min) / (
            self.cutoff_max - self.cutoff_min
        )
        contact_threshold_fourier = self.fourier_embedding(
            contact_threshold_normalized.flatten()
        ).reshape(contact_threshold_normalized.shape + (-1,))

        contact_conditioning = torch.cat(
            [
                contact_conditioning,
                contact_threshold_normalized.unsqueeze(-1),
                contact_threshold_fourier,
            ],
            dim=-1,
        )
        contact_conditioning = self.encoder(contact_conditioning)

        contact_conditioning = (
            contact_conditioning
            * (
                1
                - feats["contact_conditioning"][:, :, :, 0:2].sum(dim=-1, keepdim=True)
            )
            + self.encoding_unspecified * feats["contact_conditioning"][:, :, :, 0:1]
            + self.encoding_unselected * feats["contact_conditioning"][:, :, :, 1:2]
        )
        return contact_conditioning

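
# Illustrative sketch (standalone, toy sizes; the tensors below are stand-ins, not part
# of the module above): the blend in ContactConditioning.forward keeps the learned
# encoder output only where neither the UNSPECIFIED (channel 0) nor the UNSELECTED
# (channel 1) flag is set, and otherwise substitutes the corresponding learned vector.
def _contact_blend_sketch():
    import torch

    token_z = 4
    encoded = torch.randn(1, 2, 2, token_z)    # stand-in for self.encoder(...)
    unspecified = torch.zeros(token_z)          # stand-in for self.encoding_unspecified
    unselected = torch.ones(token_z)            # stand-in for self.encoding_unselected
    flags = torch.tensor(
        [[[[1.0, 0.0], [0.0, 1.0]], [[0.0, 0.0], [0.0, 0.0]]]]
    )  # (B, N, N, 2): UNSPECIFIED / UNSELECTED flags per token pair
    out = (
        encoded * (1 - flags.sum(dim=-1, keepdim=True))
        + unspecified * flags[..., 0:1]
        + unselected * flags[..., 1:2]
    )
    return out
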
class InputEmbedder(nn.Module):
    def __init__(
        self,
        atom_s: int,
        atom_z: int,
        token_s: int,
        token_z: int,
        atoms_per_window_queries: int,
        atoms_per_window_keys: int,
        atom_feature_dim: int,
        atom_encoder_depth: int,
        atom_encoder_heads: int,
        activation_checkpointing: bool = False,
        add_method_conditioning: bool = False,
        add_modified_flag: bool = False,
        add_cyclic_flag: bool = False,
        add_mol_type_feat: bool = False,
        use_no_atom_char: bool = False,
        use_atom_backbone_feat: bool = False,
        use_residue_feats_atoms: bool = False,
    ) -> None:
        """Initialize the input embedder.

        Parameters
        ----------
        atom_s : int
            The atom embedding size.
        atom_z : int
            The atom pairwise embedding size.
        token_s : int
            The token embedding size.

        """
        super().__init__()
        self.token_s = token_s
        self.add_method_conditioning = add_method_conditioning
        self.add_modified_flag = add_modified_flag
        self.add_cyclic_flag = add_cyclic_flag
        self.add_mol_type_feat = add_mol_type_feat

        self.atom_encoder = AtomEncoder(
            atom_s=atom_s,
            atom_z=atom_z,
            token_s=token_s,
            token_z=token_z,
            atoms_per_window_queries=atoms_per_window_queries,
            atoms_per_window_keys=atoms_per_window_keys,
            atom_feature_dim=atom_feature_dim,
            structure_prediction=False,
            use_no_atom_char=use_no_atom_char,
            use_atom_backbone_feat=use_atom_backbone_feat,
            use_residue_feats_atoms=use_residue_feats_atoms,
        )

        self.atom_enc_proj_z = nn.Sequential(
            nn.LayerNorm(atom_z),
            nn.Linear(atom_z, atom_encoder_depth * atom_encoder_heads, bias=False),
        )

        self.atom_attention_encoder = AtomAttentionEncoder(
            atom_s=atom_s,
            token_s=token_s,
            atoms_per_window_queries=atoms_per_window_queries,
            atoms_per_window_keys=atoms_per_window_keys,
            atom_encoder_depth=atom_encoder_depth,
            atom_encoder_heads=atom_encoder_heads,
            structure_prediction=False,
            activation_checkpointing=activation_checkpointing,
        )

        self.res_type_encoding = nn.Linear(const.num_tokens, token_s, bias=False)
        self.msa_profile_encoding = nn.Linear(const.num_tokens + 1, token_s, bias=False)

        if add_method_conditioning:
            self.method_conditioning_init = nn.Embedding(
                const.num_method_types, token_s
            )
            self.method_conditioning_init.weight.data.fill_(0)
        if add_modified_flag:
            self.modified_conditioning_init = nn.Embedding(2, token_s)
            self.modified_conditioning_init.weight.data.fill_(0)
        if add_cyclic_flag:
            self.cyclic_conditioning_init = nn.Linear(1, token_s, bias=False)
            self.cyclic_conditioning_init.weight.data.fill_(0)
        if add_mol_type_feat:
            self.mol_type_conditioning_init = nn.Embedding(
                len(const.chain_type_ids), token_s
            )
            self.mol_type_conditioning_init.weight.data.fill_(0)

    def forward(self, feats: dict[str, Tensor], affinity: bool = False) -> Tensor:
        """Perform the forward pass.

        Parameters
        ----------
        feats : dict[str, Tensor]
            Input features

        Returns
        -------
        Tensor
            The embedded tokens.

        """
        # Load relevant features
        res_type = feats["res_type"].float()
        if affinity:
            profile = feats["profile_affinity"]
            deletion_mean = feats["deletion_mean_affinity"].unsqueeze(-1)
        else:
            profile = feats["profile"]
            deletion_mean = feats["deletion_mean"].unsqueeze(-1)

        # Compute input embedding
        q, c, p, to_keys = self.atom_encoder(feats)
        atom_enc_bias = self.atom_enc_proj_z(p)
        a, _, _, _ = self.atom_attention_encoder(
            feats=feats,
            q=q,
            c=c,
            atom_enc_bias=atom_enc_bias,
            to_keys=to_keys,
        )

        s = (
            a
            + self.res_type_encoding(res_type)
            + self.msa_profile_encoding(torch.cat([profile, deletion_mean], dim=-1))
        )

        if self.add_method_conditioning:
            s = s + self.method_conditioning_init(feats["method_feature"])
        if self.add_modified_flag:
            s = s + self.modified_conditioning_init(feats["modified"])
        if self.add_cyclic_flag:
            cyclic = feats["cyclic_period"].clamp(max=1.0).unsqueeze(-1)
            s = s + self.cyclic_conditioning_init(cyclic)
        if self.add_mol_type_feat:
            s = s + self.mol_type_conditioning_init(feats["mol_type"])

        return s

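
# Illustrative sketch (standalone, assumed toy sizes): the conditioning tables above are
# created with weight.data.fill_(0), so at initialization each optional term (method,
# modified, cyclic, mol type) adds exactly zero to the token embedding s and only starts
# to contribute once the weights are trained.
def _zero_init_conditioning_sketch():
    import torch
    from torch import nn

    token_s = 8
    table = nn.Embedding(2, token_s)
    table.weight.data.fill_(0)
    s = torch.randn(1, 3, token_s)          # stand-in token embeddings
    modified = torch.tensor([[0, 1, 1]])    # stand-in "modified" flags
    s_conditioned = s + table(modified)     # identical to s at initialization
    assert torch.equal(s, s_conditioned)
    return s_conditioned
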
class TemplateModule(nn.Module):
    """Template module."""

    def __init__(
        self,
        token_z: int,
        template_dim: int,
        template_blocks: int,
        dropout: float = 0.25,
        pairwise_head_width: int = 32,
        pairwise_num_heads: int = 4,
        post_layer_norm: bool = False,
        activation_checkpointing: bool = False,
        min_dist: float = 3.25,
        max_dist: float = 50.75,
        num_bins: int = 38,
        **kwargs,
    ) -> None:
        """Initialize the template module.

        Parameters
        ----------
        token_z : int
            The token pairwise embedding size.

        """
        super().__init__()
        self.min_dist = min_dist
        self.max_dist = max_dist
        self.num_bins = num_bins
        self.relu = nn.ReLU()
        self.z_norm = nn.LayerNorm(token_z)
        self.v_norm = nn.LayerNorm(template_dim)
        self.z_proj = nn.Linear(token_z, template_dim, bias=False)
        self.a_proj = nn.Linear(
            const.num_tokens * 2 + num_bins + 5,
            template_dim,
            bias=False,
        )
        self.u_proj = nn.Linear(template_dim, token_z, bias=False)
        self.pairformer = PairformerNoSeqModule(
            template_dim,
            num_blocks=template_blocks,
            dropout=dropout,
            pairwise_head_width=pairwise_head_width,
            pairwise_num_heads=pairwise_num_heads,
            post_layer_norm=post_layer_norm,
            activation_checkpointing=activation_checkpointing,
        )

    def forward(
        self,
        z: Tensor,
        feats: dict[str, Tensor],
        pair_mask: Tensor,
        use_kernels: bool = False,
    ) -> Tensor:
        """Perform the forward pass.

        Parameters
        ----------
        z : Tensor
            The pairwise embeddings
        feats : dict[str, Tensor]
            Input features
        pair_mask : Tensor
            The pair mask

        Returns
        -------
        Tensor
            The updated pairwise embeddings.

        """
        # Load relevant features
        asym_id = feats["asym_id"]
        res_type = feats["template_restype"]
        frame_rot = feats["template_frame_rot"]
        frame_t = feats["template_frame_t"]
        frame_mask = feats["template_mask_frame"]
        cb_coords = feats["template_cb"]
        ca_coords = feats["template_ca"]
        cb_mask = feats["template_mask_cb"]
        template_mask = feats["template_mask"].any(dim=2).float()
        num_templates = template_mask.sum(dim=1)
        num_templates = num_templates.clamp(min=1)

        # Compute pairwise masks
        b_cb_mask = cb_mask[:, :, :, None] * cb_mask[:, :, None, :]
        b_frame_mask = frame_mask[:, :, :, None] * frame_mask[:, :, None, :]

        b_cb_mask = b_cb_mask[..., None]
        b_frame_mask = b_frame_mask[..., None]

        # Compute asym mask, template features only attend within the same chain
        B, T = res_type.shape[:2]  # noqa: N806
        asym_mask = (asym_id[:, :, None] == asym_id[:, None, :]).float()
        asym_mask = asym_mask[:, None].expand(-1, T, -1, -1)

        # Compute template features
        with torch.autocast(device_type="cuda", enabled=False):
            # Compute distogram
            cb_dists = torch.cdist(cb_coords, cb_coords)
            boundaries = torch.linspace(self.min_dist, self.max_dist, self.num_bins - 1)
            boundaries = boundaries.to(cb_dists.device)
            distogram = (cb_dists[..., None] > boundaries).sum(dim=-1).long()
            distogram = one_hot(distogram, num_classes=self.num_bins)

            # Compute unit vector in each frame
            frame_rot = frame_rot.unsqueeze(2).transpose(-1, -2)
            frame_t = frame_t.unsqueeze(2).unsqueeze(-1)
            ca_coords = ca_coords.unsqueeze(3).unsqueeze(-1)
            vector = torch.matmul(frame_rot, (ca_coords - frame_t))
            norm = torch.norm(vector, dim=-1, keepdim=True)
            unit_vector = torch.where(norm > 0, vector / norm, torch.zeros_like(vector))
            unit_vector = unit_vector.squeeze(-1)

            # Concatenate input features
            a_tij = [distogram, b_cb_mask, unit_vector, b_frame_mask]
            a_tij = torch.cat(a_tij, dim=-1)
            a_tij = a_tij * asym_mask.unsqueeze(-1)

            res_type_i = res_type[:, :, :, None]
            res_type_j = res_type[:, :, None, :]
            res_type_i = res_type_i.expand(-1, -1, -1, res_type.size(2), -1)
            res_type_j = res_type_j.expand(-1, -1, res_type.size(2), -1, -1)
            a_tij = torch.cat([a_tij, res_type_i, res_type_j], dim=-1)
            a_tij = self.a_proj(a_tij)

        # Expand mask
        pair_mask = pair_mask[:, None].expand(-1, T, -1, -1)
        pair_mask = pair_mask.reshape(B * T, *pair_mask.shape[2:])

        # Compute input projections
        v = self.z_proj(self.z_norm(z[:, None])) + a_tij
        v = v.view(B * T, *v.shape[2:])
        v = v + self.pairformer(v, pair_mask, use_kernels=use_kernels)
        v = self.v_norm(v)
        v = v.view(B, T, *v.shape[1:])

        # Aggregate templates
        template_mask = template_mask[:, :, None, None, None]
        num_templates = num_templates[:, None, None, None]
        u = (v * template_mask).sum(dim=1) / num_templates.to(v)

        # Compute output projection
        u = self.u_proj(self.relu(u))
        return u

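
# Illustrative sketch (standalone, uses the module defaults min_dist=3.25,
# max_dist=50.75, num_bins=38): the template distogram above counts how many bin
# boundaries each CB-CB distance exceeds, so distances below min_dist land in bin 0
# and distances beyond max_dist land in the last bin, num_bins - 1.
def _template_distogram_sketch():
    import torch
    from torch.nn.functional import one_hot

    min_dist, max_dist, num_bins = 3.25, 50.75, 38
    boundaries = torch.linspace(min_dist, max_dist, num_bins - 1)
    dists = torch.tensor([1.0, 3.3, 25.0, 60.0])
    bins = (dists[..., None] > boundaries).sum(dim=-1).long()  # tensor([0, 1, 17, 37])
    return one_hot(bins, num_classes=num_bins)
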
class TemplateV2Module(nn.Module):
    """Template module."""

    def __init__(
        self,
        token_z: int,
        template_dim: int,
        template_blocks: int,
        dropout: float = 0.25,
        pairwise_head_width: int = 32,
        pairwise_num_heads: int = 4,
        post_layer_norm: bool = False,
        activation_checkpointing: bool = False,
        min_dist: float = 3.25,
        max_dist: float = 50.75,
        num_bins: int = 38,
        **kwargs,
    ) -> None:
        """Initialize the template module.

        Parameters
        ----------
        token_z : int
            The token pairwise embedding size.

        """
        super().__init__()
        self.min_dist = min_dist
        self.max_dist = max_dist
        self.num_bins = num_bins
        self.relu = nn.ReLU()
        self.z_norm = nn.LayerNorm(token_z)
        self.v_norm = nn.LayerNorm(template_dim)
        self.z_proj = nn.Linear(token_z, template_dim, bias=False)
        self.a_proj = nn.Linear(
            const.num_tokens * 2 + num_bins + 5,
            template_dim,
            bias=False,
        )
        self.u_proj = nn.Linear(template_dim, token_z, bias=False)
        self.pairformer = PairformerNoSeqModule(
            template_dim,
            num_blocks=template_blocks,
            dropout=dropout,
            pairwise_head_width=pairwise_head_width,
            pairwise_num_heads=pairwise_num_heads,
            post_layer_norm=post_layer_norm,
            activation_checkpointing=activation_checkpointing,
        )

    def forward(
        self,
        z: Tensor,
        feats: dict[str, Tensor],
        pair_mask: Tensor,
        use_kernels: bool = False,
    ) -> Tensor:
        """Perform the forward pass.

        Parameters
        ----------
        z : Tensor
            The pairwise embeddings
        feats : dict[str, Tensor]
            Input features
        pair_mask : Tensor
            The pair mask

        Returns
        -------
        Tensor
            The updated pairwise embeddings.

        """
        # Load relevant features
        res_type = feats["template_restype"]
        frame_rot = feats["template_frame_rot"]
        frame_t = feats["template_frame_t"]
        frame_mask = feats["template_mask_frame"]
        cb_coords = feats["template_cb"]
        ca_coords = feats["template_ca"]
        cb_mask = feats["template_mask_cb"]
        visibility_ids = feats["visibility_ids"]
        template_mask = feats["template_mask"].any(dim=2).float()
        num_templates = template_mask.sum(dim=1)
        num_templates = num_templates.clamp(min=1)

        # Compute pairwise masks
        b_cb_mask = cb_mask[:, :, :, None] * cb_mask[:, :, None, :]
        b_frame_mask = frame_mask[:, :, :, None] * frame_mask[:, :, None, :]

        b_cb_mask = b_cb_mask[..., None]
        b_frame_mask = b_frame_mask[..., None]

        # Compute visibility mask, template features only attend within the same group
        B, T = res_type.shape[:2]  # noqa: N806
        tmlp_pair_mask = (
            visibility_ids[:, :, :, None] == visibility_ids[:, :, None, :]
        ).float()

        # Compute template features
        with torch.autocast(device_type="cuda", enabled=False):
            # Compute distogram
            cb_dists = torch.cdist(cb_coords, cb_coords)
            boundaries = torch.linspace(self.min_dist, self.max_dist, self.num_bins - 1)
            boundaries = boundaries.to(cb_dists.device)
            distogram = (cb_dists[..., None] > boundaries).sum(dim=-1).long()
            distogram = one_hot(distogram, num_classes=self.num_bins)

            # Compute unit vector in each frame
            frame_rot = frame_rot.unsqueeze(2).transpose(-1, -2)
            frame_t = frame_t.unsqueeze(2).unsqueeze(-1)
            ca_coords = ca_coords.unsqueeze(3).unsqueeze(-1)
            vector = torch.matmul(frame_rot, (ca_coords - frame_t))
            norm = torch.norm(vector, dim=-1, keepdim=True)
            unit_vector = torch.where(norm > 0, vector / norm, torch.zeros_like(vector))
            unit_vector = unit_vector.squeeze(-1)

            # Concatenate input features
            a_tij = [distogram, b_cb_mask, unit_vector, b_frame_mask]
            a_tij = torch.cat(a_tij, dim=-1)
            a_tij = a_tij * tmlp_pair_mask.unsqueeze(-1)

            res_type_i = res_type[:, :, :, None]
            res_type_j = res_type[:, :, None, :]
            res_type_i = res_type_i.expand(-1, -1, -1, res_type.size(2), -1)
            res_type_j = res_type_j.expand(-1, -1, res_type.size(2), -1, -1)
            a_tij = torch.cat([a_tij, res_type_i, res_type_j], dim=-1)
            a_tij = self.a_proj(a_tij)

        # Expand mask
        pair_mask = pair_mask[:, None].expand(-1, T, -1, -1)
        pair_mask = pair_mask.reshape(B * T, *pair_mask.shape[2:])

        # Compute input projections
        v = self.z_proj(self.z_norm(z[:, None])) + a_tij
        v = v.view(B * T, *v.shape[2:])
        v = v + self.pairformer(v, pair_mask, use_kernels=use_kernels)
        v = self.v_norm(v)
        v = v.view(B, T, *v.shape[1:])

        # Aggregate templates
        template_mask = template_mask[:, :, None, None, None]
        num_templates = num_templates[:, None, None, None]
        u = (v * template_mask).sum(dim=1) / num_templates.to(v)

        # Compute output projection
        u = self.u_proj(self.relu(u))
        return u

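
# Illustrative sketch (standalone, toy ids): TemplateV2Module differs from TemplateModule
# mainly in masking template pair features with visibility_ids instead of asym_id, but
# the mask is built the same way: a pair (i, j) is kept only when the two positions
# share an id.
def _visibility_pair_mask_sketch():
    import torch

    visibility_ids = torch.tensor([[[0, 0, 1, 1]]])  # (B=1, T=1, N=4)
    pair_mask = (
        visibility_ids[:, :, :, None] == visibility_ids[:, :, None, :]
    ).float()                                        # (1, 1, 4, 4) block-diagonal mask
    return pair_mask
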
class MSAModule(nn.Module):
    """MSA module."""

    def __init__(
        self,
        msa_s: int,
        token_z: int,
        token_s: int,
        msa_blocks: int,
        msa_dropout: float,
        z_dropout: float,
        pairwise_head_width: int = 32,
        pairwise_num_heads: int = 4,
        activation_checkpointing: bool = False,
        use_paired_feature: bool = True,
        subsample_msa: bool = False,
        num_subsampled_msa: int = 1024,
        **kwargs,
    ) -> None:
        """Initialize the MSA module.

        Parameters
        ----------
        token_z : int
            The token pairwise embedding size.

        """
        super().__init__()
        self.msa_blocks = msa_blocks
        self.msa_dropout = msa_dropout
        self.z_dropout = z_dropout
        self.use_paired_feature = use_paired_feature
        self.activation_checkpointing = activation_checkpointing
        self.subsample_msa = subsample_msa
        self.num_subsampled_msa = num_subsampled_msa

        self.s_proj = nn.Linear(token_s, msa_s, bias=False)
        self.msa_proj = nn.Linear(
            const.num_tokens + 2 + int(use_paired_feature),
            msa_s,
            bias=False,
        )
        self.layers = nn.ModuleList()
        for i in range(msa_blocks):
            self.layers.append(
                MSALayer(
                    msa_s,
                    token_z,
                    msa_dropout,
                    z_dropout,
                    pairwise_head_width,
                    pairwise_num_heads,
                )
            )

    def forward(
        self,
        z: Tensor,
        emb: Tensor,
        feats: dict[str, Tensor],
        use_kernels: bool = False,
    ) -> Tensor:
        """Perform the forward pass.

        Parameters
        ----------
        z : Tensor
            The pairwise embeddings
        emb : Tensor
            The input embeddings
        feats : dict[str, Tensor]
            Input features
        use_kernels: bool
            Whether to use kernels for triangular updates

        Returns
        -------
        Tensor
            The output pairwise embeddings.

        """
        # Set chunk sizes
        if not self.training:
            if z.shape[1] > const.chunk_size_threshold:
                chunk_heads_pwa = True
                chunk_size_transition_z = 64
                chunk_size_transition_msa = 32
                chunk_size_outer_product = 4
                chunk_size_tri_attn = 128
            else:
                chunk_heads_pwa = False
                chunk_size_transition_z = None
                chunk_size_transition_msa = None
                chunk_size_outer_product = None
                chunk_size_tri_attn = 512
        else:
            chunk_heads_pwa = False
            chunk_size_transition_z = None
            chunk_size_transition_msa = None
            chunk_size_outer_product = None
            chunk_size_tri_attn = None

        # Load relevant features
        msa = feats["msa"]
        msa = torch.nn.functional.one_hot(msa, num_classes=const.num_tokens)
        has_deletion = feats["has_deletion"].unsqueeze(-1)
        deletion_value = feats["deletion_value"].unsqueeze(-1)
        is_paired = feats["msa_paired"].unsqueeze(-1)
        msa_mask = feats["msa_mask"]
        token_mask = feats["token_pad_mask"].float()
        token_mask = token_mask[:, :, None] * token_mask[:, None, :]

        # Compute MSA embeddings
        if self.use_paired_feature:
            m = torch.cat([msa, has_deletion, deletion_value, is_paired], dim=-1)
        else:
            m = torch.cat([msa, has_deletion, deletion_value], dim=-1)

        # Subsample the MSA
        if self.subsample_msa:
            msa_indices = torch.randperm(msa.shape[1])[: self.num_subsampled_msa]
            m = m[:, msa_indices]
            msa_mask = msa_mask[:, msa_indices]

        # Compute input projections
        m = self.msa_proj(m)
        m = m + self.s_proj(emb).unsqueeze(1)

        # Perform MSA blocks
        for i in range(self.msa_blocks):
            if self.activation_checkpointing and self.training:
                z, m = torch.utils.checkpoint.checkpoint(
                    self.layers[i],
                    z,
                    m,
                    token_mask,
                    msa_mask,
                    chunk_heads_pwa,
                    chunk_size_transition_z,
                    chunk_size_transition_msa,
                    chunk_size_outer_product,
                    chunk_size_tri_attn,
                    use_kernels,
                )
            else:
                z, m = self.layers[i](
                    z,
                    m,
                    token_mask,
                    msa_mask,
                    chunk_heads_pwa,
                    chunk_size_transition_z,
                    chunk_size_transition_msa,
                    chunk_size_outer_product,
                    chunk_size_tri_attn,
                    use_kernels,
                )
        return z

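
# Illustrative sketch (standalone; a toy vocabulary of 5 token types stands in for
# const.num_tokens, and use_paired_feature=True is assumed): the per-row MSA features
# above are the one-hot residue identity plus the has_deletion, deletion_value and
# msa_paired scalars, giving num_tokens + 3 channels per position before msa_proj.
def _msa_feature_sketch():
    import torch
    from torch.nn.functional import one_hot

    num_tokens = 5                                  # assumed toy value
    msa = torch.randint(0, num_tokens, (1, 2, 3))   # (B, n_rows, n_tokens)
    msa_oh = one_hot(msa, num_classes=num_tokens).float()
    has_deletion = torch.zeros(1, 2, 3, 1)
    deletion_value = torch.zeros(1, 2, 3, 1)
    is_paired = torch.ones(1, 2, 3, 1)
    m = torch.cat([msa_oh, has_deletion, deletion_value, is_paired], dim=-1)
    assert m.shape == (1, 2, 3, num_tokens + 3)
    return m
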
class MSALayer(nn.Module):
    """MSA layer."""

    def __init__(
        self,
        msa_s: int,
        token_z: int,
        msa_dropout: float,
        z_dropout: float,
        pairwise_head_width: int = 32,
        pairwise_num_heads: int = 4,
    ) -> None:
        """Initialize the MSA layer.

        Parameters
        ----------
        token_z : int
            The token pairwise embedding size.

        """
        super().__init__()
        self.msa_dropout = msa_dropout
        self.msa_transition = Transition(dim=msa_s, hidden=msa_s * 4)
        self.pair_weighted_averaging = PairWeightedAveraging(
            c_m=msa_s,
            c_z=token_z,
            c_h=32,
            num_heads=8,
        )

        self.pairformer_layer = PairformerNoSeqLayer(
            token_z=token_z,
            dropout=z_dropout,
            pairwise_head_width=pairwise_head_width,
            pairwise_num_heads=pairwise_num_heads,
        )
        self.outer_product_mean = OuterProductMean(
            c_in=msa_s,
            c_hidden=32,
            c_out=token_z,
        )

    def forward(
        self,
        z: Tensor,
        m: Tensor,
        token_mask: Tensor,
        msa_mask: Tensor,
        chunk_heads_pwa: bool = False,
        chunk_size_transition_z: int = None,
        chunk_size_transition_msa: int = None,
        chunk_size_outer_product: int = None,
        chunk_size_tri_attn: int = None,
        use_kernels: bool = False,
    ) -> tuple[Tensor, Tensor]:
        """Perform the forward pass.

        Parameters
        ----------
        z : Tensor
            The pairwise embeddings
        m : Tensor
            The MSA embeddings
        token_mask : Tensor
            The token pair mask
        msa_mask : Tensor
            The MSA mask

        Returns
        -------
        tuple[Tensor, Tensor]
            The updated pairwise and MSA embeddings.

        """
        # Communication to MSA stack
        msa_dropout = get_dropout_mask(self.msa_dropout, m, self.training)
        m = m + msa_dropout * self.pair_weighted_averaging(
            m, z, token_mask, chunk_heads_pwa
        )
        m = m + self.msa_transition(m, chunk_size_transition_msa)

        z = z + self.outer_product_mean(m, msa_mask, chunk_size_outer_product)

        # Compute pairwise stack
        z = self.pairformer_layer(
            z, token_mask, chunk_size_tri_attn, use_kernels=use_kernels
        )

        return z, m


class BFactorModule(nn.Module):
    """BFactor Module."""

    def __init__(self, token_s: int, num_bins: int) -> None:
        """Initialize the bfactor module.

        Parameters
        ----------
        token_s : int
            The token embedding size.

        """
        super().__init__()
        self.bfactor = nn.Linear(token_s, num_bins)
        self.num_bins = num_bins

    def forward(self, s: Tensor) -> Tensor:
        """Perform the forward pass.

        Parameters
        ----------
        s : Tensor
            The sequence embeddings

        Returns
        -------
        Tensor
            The predicted bfactor histogram.

        """
        return self.bfactor(s)


class DistogramModule(nn.Module):
    """Distogram Module."""

    def __init__(self, token_z: int, num_bins: int, num_distograms: int = 1) -> None:
        """Initialize the distogram module.

        Parameters
        ----------
        token_z : int
            The token pairwise embedding size.

        """
        super().__init__()
        self.distogram = nn.Linear(token_z, num_distograms * num_bins)
        self.num_distograms = num_distograms
        self.num_bins = num_bins

    def forward(self, z: Tensor) -> Tensor:
        """Perform the forward pass.

        Parameters
        ----------
        z : Tensor
            The pairwise embeddings

        Returns
        -------
        Tensor
            The predicted distogram.

        """
        z = z + z.transpose(1, 2)
        return self.distogram(z).reshape(
            z.shape[0], z.shape[1], z.shape[2], self.num_distograms, self.num_bins
        )