boltz-vsynthes 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- boltz/__init__.py +7 -0
- boltz/data/__init__.py +0 -0
- boltz/data/const.py +1184 -0
- boltz/data/crop/__init__.py +0 -0
- boltz/data/crop/affinity.py +164 -0
- boltz/data/crop/boltz.py +296 -0
- boltz/data/crop/cropper.py +45 -0
- boltz/data/feature/__init__.py +0 -0
- boltz/data/feature/featurizer.py +1230 -0
- boltz/data/feature/featurizerv2.py +2208 -0
- boltz/data/feature/symmetry.py +602 -0
- boltz/data/filter/__init__.py +0 -0
- boltz/data/filter/dynamic/__init__.py +0 -0
- boltz/data/filter/dynamic/date.py +76 -0
- boltz/data/filter/dynamic/filter.py +24 -0
- boltz/data/filter/dynamic/max_residues.py +37 -0
- boltz/data/filter/dynamic/resolution.py +34 -0
- boltz/data/filter/dynamic/size.py +38 -0
- boltz/data/filter/dynamic/subset.py +42 -0
- boltz/data/filter/static/__init__.py +0 -0
- boltz/data/filter/static/filter.py +26 -0
- boltz/data/filter/static/ligand.py +37 -0
- boltz/data/filter/static/polymer.py +299 -0
- boltz/data/module/__init__.py +0 -0
- boltz/data/module/inference.py +307 -0
- boltz/data/module/inferencev2.py +429 -0
- boltz/data/module/training.py +684 -0
- boltz/data/module/trainingv2.py +660 -0
- boltz/data/mol.py +900 -0
- boltz/data/msa/__init__.py +0 -0
- boltz/data/msa/mmseqs2.py +235 -0
- boltz/data/pad.py +84 -0
- boltz/data/parse/__init__.py +0 -0
- boltz/data/parse/a3m.py +134 -0
- boltz/data/parse/csv.py +100 -0
- boltz/data/parse/fasta.py +138 -0
- boltz/data/parse/mmcif.py +1239 -0
- boltz/data/parse/mmcif_with_constraints.py +1607 -0
- boltz/data/parse/schema.py +1851 -0
- boltz/data/parse/yaml.py +68 -0
- boltz/data/sample/__init__.py +0 -0
- boltz/data/sample/cluster.py +283 -0
- boltz/data/sample/distillation.py +57 -0
- boltz/data/sample/random.py +39 -0
- boltz/data/sample/sampler.py +49 -0
- boltz/data/tokenize/__init__.py +0 -0
- boltz/data/tokenize/boltz.py +195 -0
- boltz/data/tokenize/boltz2.py +396 -0
- boltz/data/tokenize/tokenizer.py +24 -0
- boltz/data/types.py +777 -0
- boltz/data/write/__init__.py +0 -0
- boltz/data/write/mmcif.py +305 -0
- boltz/data/write/pdb.py +171 -0
- boltz/data/write/utils.py +23 -0
- boltz/data/write/writer.py +330 -0
- boltz/main.py +1292 -0
- boltz/model/__init__.py +0 -0
- boltz/model/layers/__init__.py +0 -0
- boltz/model/layers/attention.py +132 -0
- boltz/model/layers/attentionv2.py +111 -0
- boltz/model/layers/confidence_utils.py +231 -0
- boltz/model/layers/dropout.py +34 -0
- boltz/model/layers/initialize.py +100 -0
- boltz/model/layers/outer_product_mean.py +98 -0
- boltz/model/layers/pair_averaging.py +135 -0
- boltz/model/layers/pairformer.py +337 -0
- boltz/model/layers/relative.py +58 -0
- boltz/model/layers/transition.py +78 -0
- boltz/model/layers/triangular_attention/__init__.py +0 -0
- boltz/model/layers/triangular_attention/attention.py +189 -0
- boltz/model/layers/triangular_attention/primitives.py +409 -0
- boltz/model/layers/triangular_attention/utils.py +380 -0
- boltz/model/layers/triangular_mult.py +212 -0
- boltz/model/loss/__init__.py +0 -0
- boltz/model/loss/bfactor.py +49 -0
- boltz/model/loss/confidence.py +590 -0
- boltz/model/loss/confidencev2.py +621 -0
- boltz/model/loss/diffusion.py +171 -0
- boltz/model/loss/diffusionv2.py +134 -0
- boltz/model/loss/distogram.py +48 -0
- boltz/model/loss/distogramv2.py +105 -0
- boltz/model/loss/validation.py +1025 -0
- boltz/model/models/__init__.py +0 -0
- boltz/model/models/boltz1.py +1286 -0
- boltz/model/models/boltz2.py +1249 -0
- boltz/model/modules/__init__.py +0 -0
- boltz/model/modules/affinity.py +223 -0
- boltz/model/modules/confidence.py +481 -0
- boltz/model/modules/confidence_utils.py +181 -0
- boltz/model/modules/confidencev2.py +495 -0
- boltz/model/modules/diffusion.py +844 -0
- boltz/model/modules/diffusion_conditioning.py +116 -0
- boltz/model/modules/diffusionv2.py +677 -0
- boltz/model/modules/encoders.py +639 -0
- boltz/model/modules/encodersv2.py +565 -0
- boltz/model/modules/transformers.py +322 -0
- boltz/model/modules/transformersv2.py +261 -0
- boltz/model/modules/trunk.py +688 -0
- boltz/model/modules/trunkv2.py +828 -0
- boltz/model/modules/utils.py +303 -0
- boltz/model/optim/__init__.py +0 -0
- boltz/model/optim/ema.py +389 -0
- boltz/model/optim/scheduler.py +99 -0
- boltz/model/potentials/__init__.py +0 -0
- boltz/model/potentials/potentials.py +497 -0
- boltz/model/potentials/schedules.py +32 -0
- boltz_vsynthes-1.0.0.dist-info/METADATA +151 -0
- boltz_vsynthes-1.0.0.dist-info/RECORD +112 -0
- boltz_vsynthes-1.0.0.dist-info/WHEEL +5 -0
- boltz_vsynthes-1.0.0.dist-info/entry_points.txt +2 -0
- boltz_vsynthes-1.0.0.dist-info/licenses/LICENSE +21 -0
- boltz_vsynthes-1.0.0.dist-info/top_level.txt +1 -0
boltz/model/modules/affinity.py (new file, @@ -0,0 +1,223 @@):
```python
import torch
from torch import nn

import boltz.model.layers.initialize as init
from boltz.model.layers.pairformer import PairformerNoSeqModule
from boltz.model.modules.encodersv2 import PairwiseConditioning
from boltz.model.modules.transformersv2 import DiffusionTransformer
from boltz.model.modules.utils import LinearNoBias


class GaussianSmearing(torch.nn.Module):
    """Gaussian smearing."""

    def __init__(
        self,
        start: float = 0.0,
        stop: float = 5.0,
        num_gaussians: int = 50,
    ) -> None:
        super().__init__()
        offset = torch.linspace(start, stop, num_gaussians)
        self.num_gaussians = num_gaussians
        self.coeff = -0.5 / (offset[1] - offset[0]).item() ** 2
        self.register_buffer("offset", offset)

    def forward(self, dist):
        shape = dist.shape
        dist = dist.view(-1, 1) - self.offset.view(1, -1)
        return torch.exp(self.coeff * torch.pow(dist, 2)).reshape(
            *shape, self.num_gaussians
        )
```
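`GaussianSmearing` expands raw distances into radial-basis features; note that it is defined here but not referenced by the classes below. A minimal usage sketch (illustrative values, not part of the package):

```python
import torch

# Assumes the GaussianSmearing class defined above is in scope.
smear = GaussianSmearing(start=0.0, stop=5.0, num_gaussians=50)
dist = torch.tensor([[1.2, 3.4], [0.0, 4.9]])  # distances, any shape
feat = smear(dist)                             # appends an RBF feature dim
print(feat.shape)  # torch.Size([2, 2, 50])
```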
```python
class AffinityModule(nn.Module):
    """Algorithm 31"""

    def __init__(
        self,
        token_s,
        token_z,
        pairformer_args: dict,
        transformer_args: dict,
        num_dist_bins=64,
        max_dist=22,
        use_cross_transformer: bool = False,
        groups: dict = {},
    ):
        super().__init__()
        boundaries = torch.linspace(2, max_dist, num_dist_bins - 1)
        self.register_buffer("boundaries", boundaries)
        self.dist_bin_pairwise_embed = nn.Embedding(num_dist_bins, token_z)
        init.gating_init_(self.dist_bin_pairwise_embed.weight)

        self.s_to_z_prod_in1 = LinearNoBias(token_s, token_z)
        self.s_to_z_prod_in2 = LinearNoBias(token_s, token_z)

        self.z_norm = nn.LayerNorm(token_z)
        self.z_linear = LinearNoBias(token_z, token_z)

        self.pairwise_conditioner = PairwiseConditioning(
            token_z=token_z,
            dim_token_rel_pos_feats=token_z,
            num_transitions=2,
        )

        self.pairformer_stack = PairformerNoSeqModule(token_z, **pairformer_args)
        self.affinity_heads = AffinityHeadsTransformer(
            token_z,
            transformer_args["token_s"],
            transformer_args["num_blocks"],
            transformer_args["num_heads"],
            transformer_args["activation_checkpointing"],
            False,
            groups=groups,
        )

    def forward(
        self,
        s_inputs,
        z,
        x_pred,
        feats,
        multiplicity=1,
        use_kernels=False,
    ):
        z = self.z_linear(self.z_norm(z))
        z = z.repeat_interleave(multiplicity, 0)

        z = (
            z
            + self.s_to_z_prod_in1(s_inputs)[:, :, None, :]
            + self.s_to_z_prod_in2(s_inputs)[:, None, :, :]
        )

        token_to_rep_atom = feats["token_to_rep_atom"]
        token_to_rep_atom = token_to_rep_atom.repeat_interleave(multiplicity, 0)
        if len(x_pred.shape) == 4:
            B, mult, N, _ = x_pred.shape
            x_pred = x_pred.reshape(B * mult, N, -1)
        else:
            BM, N, _ = x_pred.shape
            B = BM // multiplicity
            mult = multiplicity
        x_pred_repr = torch.bmm(token_to_rep_atom.float(), x_pred)
        d = torch.cdist(x_pred_repr, x_pred_repr)

        distogram = (d.unsqueeze(-1) > self.boundaries).sum(dim=-1).long()
        distogram = self.dist_bin_pairwise_embed(distogram)

        z = z + self.pairwise_conditioner(z_trunk=z, token_rel_pos_feats=distogram)

        pad_token_mask = feats["token_pad_mask"].repeat_interleave(multiplicity, 0)
        rec_mask = (feats["mol_type"] == 0).repeat_interleave(multiplicity, 0)
        rec_mask = rec_mask * pad_token_mask
        lig_mask = (
            feats["affinity_token_mask"]
            .repeat_interleave(multiplicity, 0)
            .to(torch.bool)
        )
        lig_mask = lig_mask * pad_token_mask
        cross_pair_mask = (
            lig_mask[:, :, None] * rec_mask[:, None, :]
            + rec_mask[:, :, None] * lig_mask[:, None, :]
            + lig_mask[:, :, None] * lig_mask[:, None, :]
        )
        z = self.pairformer_stack(
            z,
            pair_mask=cross_pair_mask,
            use_kernels=use_kernels,
        )

        out_dict = {}

        # affinity heads
        out_dict.update(
            self.affinity_heads(z=z, feats=feats, multiplicity=multiplicity)
        )

        return out_dict
```
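In `AffinityModule.forward`, each pairwise distance is bucketed by counting how many of the `num_dist_bins - 1` boundaries it exceeds, giving an index in `[0, num_dist_bins - 1]` that feeds `dist_bin_pairwise_embed`. A standalone sketch of just that binning step (sizes and distances are illustrative):

```python
import torch

num_dist_bins, max_dist = 64, 22
boundaries = torch.linspace(2, max_dist, num_dist_bins - 1)

d = torch.tensor([[0.5, 2.1], [10.0, 30.0]])  # pairwise distances
bins = (d.unsqueeze(-1) > boundaries).sum(dim=-1).long()
# Counting strictly smaller boundaries matches torch.bucketize's default semantics:
assert torch.equal(bins, torch.bucketize(d, boundaries))
print(bins)  # tensor([[ 0,  1], [25, 63]]): below 2 -> bin 0, above 22 -> bin 63
```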
```python
class AffinityHeadsTransformer(nn.Module):
    def __init__(
        self,
        token_z,
        input_token_s,
        num_blocks,
        num_heads,
        activation_checkpointing,
        use_cross_transformer,
        groups={},
    ):
        super().__init__()
        self.affinity_out_mlp = nn.Sequential(
            nn.Linear(token_z, token_z),
            nn.ReLU(),
            nn.Linear(token_z, input_token_s),
            nn.ReLU(),
        )

        self.to_affinity_pred_value = nn.Sequential(
            nn.Linear(input_token_s, input_token_s),
            nn.ReLU(),
            nn.Linear(input_token_s, input_token_s),
            nn.ReLU(),
            nn.Linear(input_token_s, 1),
        )

        self.to_affinity_pred_score = nn.Sequential(
            nn.Linear(input_token_s, input_token_s),
            nn.ReLU(),
            nn.Linear(input_token_s, input_token_s),
            nn.ReLU(),
            nn.Linear(input_token_s, 1),
        )
        self.to_affinity_logits_binary = nn.Linear(1, 1)

    def forward(
        self,
        z,
        feats,
        multiplicity=1,
    ):
        pad_token_mask = (
            feats["token_pad_mask"].repeat_interleave(multiplicity, 0).unsqueeze(-1)
        )
        rec_mask = (
            (feats["mol_type"] == 0).repeat_interleave(multiplicity, 0).unsqueeze(-1)
        )
        rec_mask = rec_mask * pad_token_mask
        lig_mask = (
            feats["affinity_token_mask"]
            .repeat_interleave(multiplicity, 0)
            .to(torch.bool)
            .unsqueeze(-1)
        ) * pad_token_mask
        cross_pair_mask = (
            lig_mask[:, :, None] * rec_mask[:, None, :]
            + rec_mask[:, :, None] * lig_mask[:, None, :]
            + (lig_mask[:, :, None] * lig_mask[:, None, :])
        ) * (
            1
            - torch.eye(lig_mask.shape[1], device=lig_mask.device)
            .unsqueeze(-1)
            .unsqueeze(0)
        )

        g = torch.sum(z * cross_pair_mask, dim=(1, 2)) / (
            torch.sum(cross_pair_mask, dim=(1, 2)) + 1e-7
        )

        g = self.affinity_out_mlp(g)

        affinity_pred_value = self.to_affinity_pred_value(g).reshape(-1, 1)
        affinity_pred_score = self.to_affinity_pred_score(g).reshape(-1, 1)
        affinity_logits_binary = self.to_affinity_logits_binary(
            affinity_pred_score
        ).reshape(-1, 1)
        out_dict = {
            "affinity_pred_value": affinity_pred_value,
            "affinity_logits_binary": affinity_logits_binary,
        }
        return out_dict
```
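`AffinityHeadsTransformer` reduces the pair representation to one vector per complex with a masked mean over ligand-receptor and off-diagonal ligand-ligand pairs. A toy sketch of that pooling step, with made-up mask contents:

```python
import torch

B, N, C = 1, 4, 8
z = torch.randn(B, N, N, C)
lig = torch.tensor([[0, 0, 1, 1.0]]).unsqueeze(-1)  # tokens 2-3: ligand
rec = torch.tensor([[1, 1, 0, 0.0]]).unsqueeze(-1)  # tokens 0-1: receptor

# Keep ligand-receptor pairs (both directions) and ligand-ligand pairs,
# zeroing the diagonal, exactly as in the forward pass above.
pair = (
    lig[:, :, None] * rec[:, None, :]
    + rec[:, :, None] * lig[:, None, :]
    + lig[:, :, None] * lig[:, None, :]
) * (1 - torch.eye(N).unsqueeze(-1).unsqueeze(0))

g = (z * pair).sum(dim=(1, 2)) / (pair.sum(dim=(1, 2)) + 1e-7)
print(g.shape)  # torch.Size([1, 8]): one pooled vector per complex
```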
boltz/model/modules/confidence.py (new file, @@ -0,0 +1,481 @@):
```python
import torch
import torch.nn.functional as F
from torch import nn

import boltz.model.layers.initialize as init
from boltz.data import const
from boltz.model.modules.confidence_utils import (
    compute_aggregated_metric,
    compute_ptms,
)
from boltz.model.modules.encoders import RelativePositionEncoder
from boltz.model.modules.trunk import (
    InputEmbedder,
    MSAModule,
    PairformerModule,
)
from boltz.model.modules.utils import LinearNoBias
```
```python
class ConfidenceModule(nn.Module):
    """Confidence module."""

    def __init__(
        self,
        token_s,
        token_z,
        pairformer_args: dict,
        num_dist_bins=64,
        max_dist=22,
        add_s_to_z_prod=False,
        add_s_input_to_s=False,
        use_s_diffusion=False,
        add_z_input_to_z=False,
        confidence_args: dict = None,
        compute_pae: bool = False,
        imitate_trunk=False,
        full_embedder_args: dict = None,
        msa_args: dict = None,
        compile_pairformer=False,
    ):
        """Initialize the confidence module.

        Parameters
        ----------
        token_s : int
            The single representation dimension.
        token_z : int
            The pair representation dimension.
        pairformer_args : dict
            The pairformer arguments.
        num_dist_bins : int, optional
            The number of distance bins, by default 64.
        max_dist : int, optional
            The maximum distance, by default 22.
        add_s_to_z_prod : bool, optional
            Whether to add the s outer product to z, by default False.
        add_s_input_to_s : bool, optional
            Whether to add the s input to s, by default False.
        use_s_diffusion : bool, optional
            Whether to use the s diffusion embedding, by default False.
        add_z_input_to_z : bool, optional
            Whether to add the z input to z, by default False.
        confidence_args : dict, optional
            The confidence arguments, by default None.
        compute_pae : bool, optional
            Whether to compute PAE, by default False.
        imitate_trunk : bool, optional
            Whether to imitate the trunk, by default False.
        full_embedder_args : dict, optional
            The full embedder arguments, by default None.
        msa_args : dict, optional
            The MSA arguments, by default None.
        compile_pairformer : bool, optional
            Whether to compile the pairformer, by default False.

        """
        super().__init__()
        self.max_num_atoms_per_token = 23
        self.no_update_s = pairformer_args.get("no_update_s", False)
        boundaries = torch.linspace(2, max_dist, num_dist_bins - 1)
        self.register_buffer("boundaries", boundaries)
        self.dist_bin_pairwise_embed = nn.Embedding(num_dist_bins, token_z)
        init.gating_init_(self.dist_bin_pairwise_embed.weight)
        s_input_dim = (
            token_s + 2 * const.num_tokens + 1 + len(const.pocket_contact_info)
        )

        self.use_s_diffusion = use_s_diffusion
        if use_s_diffusion:
            self.s_diffusion_norm = nn.LayerNorm(2 * token_s)
            self.s_diffusion_to_s = LinearNoBias(2 * token_s, token_s)
            init.gating_init_(self.s_diffusion_to_s.weight)

        self.s_to_z = LinearNoBias(s_input_dim, token_z)
        self.s_to_z_transpose = LinearNoBias(s_input_dim, token_z)
        init.gating_init_(self.s_to_z.weight)
        init.gating_init_(self.s_to_z_transpose.weight)

        self.add_s_to_z_prod = add_s_to_z_prod
        if add_s_to_z_prod:
            self.s_to_z_prod_in1 = LinearNoBias(s_input_dim, token_z)
            self.s_to_z_prod_in2 = LinearNoBias(s_input_dim, token_z)
            self.s_to_z_prod_out = LinearNoBias(token_z, token_z)
            init.gating_init_(self.s_to_z_prod_out.weight)

        self.imitate_trunk = imitate_trunk
        if self.imitate_trunk:
            s_input_dim = (
                token_s + 2 * const.num_tokens + 1 + len(const.pocket_contact_info)
            )
            self.s_init = nn.Linear(s_input_dim, token_s, bias=False)
            self.z_init_1 = nn.Linear(s_input_dim, token_z, bias=False)
            self.z_init_2 = nn.Linear(s_input_dim, token_z, bias=False)

            # Input embeddings
            self.input_embedder = InputEmbedder(**full_embedder_args)
            self.rel_pos = RelativePositionEncoder(token_z)
            self.token_bonds = nn.Linear(1, token_z, bias=False)

            # Normalization layers
            self.s_norm = nn.LayerNorm(token_s)
            self.z_norm = nn.LayerNorm(token_z)

            # Recycling projections
            self.s_recycle = nn.Linear(token_s, token_s, bias=False)
            self.z_recycle = nn.Linear(token_z, token_z, bias=False)
            init.gating_init_(self.s_recycle.weight)
            init.gating_init_(self.z_recycle.weight)

            # Pairwise stack
            self.msa_module = MSAModule(
                token_z=token_z,
                s_input_dim=s_input_dim,
                **msa_args,
            )
            self.pairformer_module = PairformerModule(
                token_s,
                token_z,
                **pairformer_args,
            )
            if compile_pairformer:
                # Big models hit the default cache limit (8)
                self.is_pairformer_compiled = True
                torch._dynamo.config.cache_size_limit = 512
                torch._dynamo.config.accumulated_cache_size_limit = 512
                self.pairformer_module = torch.compile(
                    self.pairformer_module,
                    dynamic=False,
                    fullgraph=False,
                )

            self.final_s_norm = nn.LayerNorm(token_s)
            self.final_z_norm = nn.LayerNorm(token_z)
        else:
            self.s_inputs_norm = nn.LayerNorm(s_input_dim)
            if not self.no_update_s:
                self.s_norm = nn.LayerNorm(token_s)
            self.z_norm = nn.LayerNorm(token_z)

            self.add_s_input_to_s = add_s_input_to_s
            if add_s_input_to_s:
                self.s_input_to_s = LinearNoBias(s_input_dim, token_s)
                init.gating_init_(self.s_input_to_s.weight)

            self.add_z_input_to_z = add_z_input_to_z
            if add_z_input_to_z:
                self.rel_pos = RelativePositionEncoder(token_z)
                self.token_bonds = nn.Linear(1, token_z, bias=False)

            self.pairformer_stack = PairformerModule(
                token_s,
                token_z,
                **pairformer_args,
            )

        self.confidence_heads = ConfidenceHeads(
            token_s,
            token_z,
            compute_pae=compute_pae,
            **confidence_args,
        )

    def forward(
        self,
        s_inputs,
        s,
        z,
        x_pred,
        feats,
        pred_distogram_logits,
        multiplicity=1,
        s_diffusion=None,
        run_sequentially=False,
        use_kernels: bool = False,
    ):
        if run_sequentially and multiplicity > 1:
            assert z.shape[0] == 1, "Not supported with batch size > 1"
            out_dicts = []
            for sample_idx in range(multiplicity):
                out_dicts.append(  # noqa: PERF401
                    self.forward(
                        s_inputs,
                        s,
                        z,
                        x_pred[sample_idx : sample_idx + 1],
                        feats,
                        pred_distogram_logits,
                        multiplicity=1,
                        s_diffusion=s_diffusion[sample_idx : sample_idx + 1]
                        if s_diffusion is not None
                        else None,
                        run_sequentially=False,
                        use_kernels=use_kernels,
                    )
                )

            out_dict = {}
            for key in out_dicts[0]:
                if key != "pair_chains_iptm":
                    out_dict[key] = torch.cat([out[key] for out in out_dicts], dim=0)
                else:
                    pair_chains_iptm = {}
                    for chain_idx1 in out_dicts[0][key].keys():
                        chains_iptm = {}
                        for chain_idx2 in out_dicts[0][key][chain_idx1].keys():
                            chains_iptm[chain_idx2] = torch.cat(
                                [out[key][chain_idx1][chain_idx2] for out in out_dicts],
                                dim=0,
                            )
                        pair_chains_iptm[chain_idx1] = chains_iptm
                    out_dict[key] = pair_chains_iptm
            return out_dict

        if self.imitate_trunk:
            s_inputs = self.input_embedder(feats)

            # Initialize the sequence and pairwise embeddings
            s_init = self.s_init(s_inputs)
            z_init = (
                self.z_init_1(s_inputs)[:, :, None]
                + self.z_init_2(s_inputs)[:, None, :]
            )
            relative_position_encoding = self.rel_pos(feats)
            z_init = z_init + relative_position_encoding
            z_init = z_init + self.token_bonds(feats["token_bonds"].float())

            # Apply recycling
            s = s_init + self.s_recycle(self.s_norm(s))
            z = z_init + self.z_recycle(self.z_norm(z))

        else:
            s_inputs = self.s_inputs_norm(s_inputs).repeat_interleave(multiplicity, 0)
            if not self.no_update_s:
                s = self.s_norm(s)

            if self.add_s_input_to_s:
                s = s + self.s_input_to_s(s_inputs)

            z = self.z_norm(z)

            if self.add_z_input_to_z:
                relative_position_encoding = self.rel_pos(feats)
                z = z + relative_position_encoding
                z = z + self.token_bonds(feats["token_bonds"].float())

        s = s.repeat_interleave(multiplicity, 0)

        if self.use_s_diffusion:
            assert s_diffusion is not None
            s_diffusion = self.s_diffusion_norm(s_diffusion)
            s = s + self.s_diffusion_to_s(s_diffusion)

        z = z.repeat_interleave(multiplicity, 0)
        z = (
            z
            + self.s_to_z(s_inputs)[:, :, None, :]
            + self.s_to_z_transpose(s_inputs)[:, None, :, :]
        )

        if self.add_s_to_z_prod:
            z = z + self.s_to_z_prod_out(
                self.s_to_z_prod_in1(s_inputs)[:, :, None, :]
                * self.s_to_z_prod_in2(s_inputs)[:, None, :, :]
            )

        token_to_rep_atom = feats["token_to_rep_atom"]
        token_to_rep_atom = token_to_rep_atom.repeat_interleave(multiplicity, 0)
        if len(x_pred.shape) == 4:
            B, mult, N, _ = x_pred.shape
            x_pred = x_pred.reshape(B * mult, N, -1)
        x_pred_repr = torch.bmm(token_to_rep_atom.float(), x_pred)
        d = torch.cdist(x_pred_repr, x_pred_repr)

        distogram = (d.unsqueeze(-1) > self.boundaries).sum(dim=-1).long()
        distogram = self.dist_bin_pairwise_embed(distogram)

        z = z + distogram

        mask = feats["token_pad_mask"].repeat_interleave(multiplicity, 0)
        pair_mask = mask[:, :, None] * mask[:, None, :]

        if self.imitate_trunk:
            z = z + self.msa_module(z, s_inputs, feats, use_kernels=use_kernels)

            s, z = self.pairformer_module(
                s, z, mask=mask, pair_mask=pair_mask, use_kernels=use_kernels
            )

            s, z = self.final_s_norm(s), self.final_z_norm(z)

        else:
            s_t, z_t = self.pairformer_stack(
                s, z, mask=mask, pair_mask=pair_mask, use_kernels=use_kernels
            )

            # AF3 has residual connections; we remove them
            s = s_t
            z = z_t

        out_dict = {}

        # confidence heads
        out_dict.update(
            self.confidence_heads(
                s=s,
                z=z,
                x_pred=x_pred,
                d=d,
                feats=feats,
                multiplicity=multiplicity,
                pred_distogram_logits=pred_distogram_logits,
            )
        )

        return out_dict
```
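Throughout `ConfidenceModule.forward`, trunk tensors are expanded with `repeat_interleave(multiplicity, 0)` so that a batch of B trunk embeddings lines up with B * multiplicity predicted structures. A small shape sketch of that alignment (dimensions are illustrative):

```python
import torch

B, N, token_z, multiplicity = 2, 5, 16, 3
z = torch.randn(B, N, N, token_z)            # one pair rep per input complex
x_pred = torch.randn(B, multiplicity, N, 3)  # several diffusion samples each

z = z.repeat_interleave(multiplicity, 0)     # -> (B * multiplicity, N, N, token_z)
x_pred = x_pred.reshape(B * multiplicity, N, -1)
print(z.shape[0] == x_pred.shape[0])         # True: sample i aligns with trunk copy i
```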
```python
class ConfidenceHeads(nn.Module):
    """Confidence heads."""

    def __init__(
        self,
        token_s,
        token_z,
        num_plddt_bins=50,
        num_pde_bins=64,
        num_pae_bins=64,
        compute_pae: bool = True,
    ):
        """Initialize the confidence heads.

        Parameters
        ----------
        token_s : int
            The single representation dimension.
        token_z : int
            The pair representation dimension.
        num_plddt_bins : int, optional
            The number of pLDDT bins, by default 50.
        num_pde_bins : int, optional
            The number of PDE bins, by default 64.
        num_pae_bins : int, optional
            The number of PAE bins, by default 64.
        compute_pae : bool, optional
            Whether to compute PAE, by default True.

        """
        super().__init__()
        self.max_num_atoms_per_token = 23
        self.to_pde_logits = LinearNoBias(token_z, num_pde_bins)
        self.to_plddt_logits = LinearNoBias(token_s, num_plddt_bins)
        self.to_resolved_logits = LinearNoBias(token_s, 2)
        self.compute_pae = compute_pae
        if self.compute_pae:
            self.to_pae_logits = LinearNoBias(token_z, num_pae_bins)

    def forward(
        self,
        s,
        z,
        x_pred,
        d,
        feats,
        pred_distogram_logits,
        multiplicity=1,
    ):
        # Compute the pLDDT, PDE, PAE, and resolved logits
        plddt_logits = self.to_plddt_logits(s)
        pde_logits = self.to_pde_logits(z + z.transpose(1, 2))
        resolved_logits = self.to_resolved_logits(s)
        if self.compute_pae:
            pae_logits = self.to_pae_logits(z)

        # Weights used to compute the interface pLDDT
        ligand_weight = 2
        interface_weight = 1

        # Retrieve relevant features
        token_type = feats["mol_type"]
        token_type = token_type.repeat_interleave(multiplicity, 0)
        is_ligand_token = (token_type == const.chain_type_ids["NONPOLYMER"]).float()

        # Compute the aggregated pLDDT and ipLDDT
        plddt = compute_aggregated_metric(plddt_logits)
        token_pad_mask = feats["token_pad_mask"].repeat_interleave(multiplicity, 0)
        complex_plddt = (plddt * token_pad_mask).sum(dim=-1) / token_pad_mask.sum(
            dim=-1
        )

        is_contact = (d < 8).float()
        is_different_chain = (
            feats["asym_id"].unsqueeze(-1) != feats["asym_id"].unsqueeze(-2)
        ).float()
        is_different_chain = is_different_chain.repeat_interleave(multiplicity, 0)
        token_interface_mask = torch.max(
            is_contact * is_different_chain * (1 - is_ligand_token).unsqueeze(-1),
            dim=-1,
        ).values
        iplddt_weight = (
            is_ligand_token * ligand_weight + token_interface_mask * interface_weight
        )
        complex_iplddt = (plddt * token_pad_mask * iplddt_weight).sum(dim=-1) / (
            torch.sum(token_pad_mask * iplddt_weight, dim=-1) + 1e-5
        )

        # Compute the aggregated PDE and iPDE
        pde = compute_aggregated_metric(pde_logits, end=32)
        pred_distogram_prob = nn.functional.softmax(
            pred_distogram_logits, dim=-1
        ).repeat_interleave(multiplicity, 0)
        contacts = torch.zeros((1, 1, 1, 64), dtype=pred_distogram_prob.dtype).to(
            pred_distogram_prob.device
        )
        contacts[:, :, :, :20] = 1.0
        prob_contact = (pred_distogram_prob * contacts).sum(-1)
        token_pad_mask = feats["token_pad_mask"].repeat_interleave(multiplicity, 0)
        token_pad_pair_mask = (
            token_pad_mask.unsqueeze(-1)
            * token_pad_mask.unsqueeze(-2)
            * (
                1
                - torch.eye(
                    token_pad_mask.shape[1], device=token_pad_mask.device
                ).unsqueeze(0)
            )
        )
        token_pair_mask = token_pad_pair_mask * prob_contact
        complex_pde = (pde * token_pair_mask).sum(dim=(1, 2)) / token_pair_mask.sum(
            dim=(1, 2)
        )
        asym_id = feats["asym_id"].repeat_interleave(multiplicity, 0)
        token_interface_pair_mask = token_pair_mask * (
            asym_id.unsqueeze(-1) != asym_id.unsqueeze(-2)
        )
        complex_ipde = (pde * token_interface_pair_mask).sum(dim=(1, 2)) / (
            token_interface_pair_mask.sum(dim=(1, 2)) + 1e-5
        )

        out_dict = dict(
            pde_logits=pde_logits,
            plddt_logits=plddt_logits,
            resolved_logits=resolved_logits,
            pde=pde,
            plddt=plddt,
            complex_plddt=complex_plddt,
            complex_iplddt=complex_iplddt,
            complex_pde=complex_pde,
            complex_ipde=complex_ipde,
        )
        if self.compute_pae:
            out_dict["pae_logits"] = pae_logits
            out_dict["pae"] = compute_aggregated_metric(pae_logits, end=32)
            ptm, iptm, ligand_iptm, protein_iptm, pair_chains_iptm = compute_ptms(
                pae_logits, x_pred, feats, multiplicity
            )
            out_dict["ptm"] = ptm
            out_dict["iptm"] = iptm
            out_dict["ligand_iptm"] = ligand_iptm
            out_dict["protein_iptm"] = protein_iptm
            out_dict["pair_chains_iptm"] = pair_chains_iptm

        return out_dict
```
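`compute_aggregated_metric` (imported from `boltz.model.modules.confidence_utils`, which is not shown in this diff) collapses binned logits into a scalar per token or pair. A plausible stand-in, assuming an expectation over evenly spaced bin centers on `[0, end]`; the packaged helper may differ in detail:

```python
import torch

def aggregated_metric_sketch(logits: torch.Tensor, end: float = 1.0) -> torch.Tensor:
    """Hypothetical stand-in for compute_aggregated_metric: expectation
    over evenly spaced bin centers on [0, end]."""
    num_bins = logits.shape[-1]
    width = end / num_bins
    centers = width / 2 + width * torch.arange(
        num_bins, dtype=logits.dtype, device=logits.device
    )
    return (torch.softmax(logits, dim=-1) * centers).sum(dim=-1)

plddt = aggregated_metric_sketch(torch.randn(1, 10, 50))            # per-token, in (0, 1)
pde = aggregated_metric_sketch(torch.randn(1, 10, 10, 64), end=32)  # per-pair, in Angstroms
```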