boltz-vsynthes 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- boltz/__init__.py +7 -0
- boltz/data/__init__.py +0 -0
- boltz/data/const.py +1184 -0
- boltz/data/crop/__init__.py +0 -0
- boltz/data/crop/affinity.py +164 -0
- boltz/data/crop/boltz.py +296 -0
- boltz/data/crop/cropper.py +45 -0
- boltz/data/feature/__init__.py +0 -0
- boltz/data/feature/featurizer.py +1230 -0
- boltz/data/feature/featurizerv2.py +2208 -0
- boltz/data/feature/symmetry.py +602 -0
- boltz/data/filter/__init__.py +0 -0
- boltz/data/filter/dynamic/__init__.py +0 -0
- boltz/data/filter/dynamic/date.py +76 -0
- boltz/data/filter/dynamic/filter.py +24 -0
- boltz/data/filter/dynamic/max_residues.py +37 -0
- boltz/data/filter/dynamic/resolution.py +34 -0
- boltz/data/filter/dynamic/size.py +38 -0
- boltz/data/filter/dynamic/subset.py +42 -0
- boltz/data/filter/static/__init__.py +0 -0
- boltz/data/filter/static/filter.py +26 -0
- boltz/data/filter/static/ligand.py +37 -0
- boltz/data/filter/static/polymer.py +299 -0
- boltz/data/module/__init__.py +0 -0
- boltz/data/module/inference.py +307 -0
- boltz/data/module/inferencev2.py +429 -0
- boltz/data/module/training.py +684 -0
- boltz/data/module/trainingv2.py +660 -0
- boltz/data/mol.py +900 -0
- boltz/data/msa/__init__.py +0 -0
- boltz/data/msa/mmseqs2.py +235 -0
- boltz/data/pad.py +84 -0
- boltz/data/parse/__init__.py +0 -0
- boltz/data/parse/a3m.py +134 -0
- boltz/data/parse/csv.py +100 -0
- boltz/data/parse/fasta.py +138 -0
- boltz/data/parse/mmcif.py +1239 -0
- boltz/data/parse/mmcif_with_constraints.py +1607 -0
- boltz/data/parse/schema.py +1851 -0
- boltz/data/parse/yaml.py +68 -0
- boltz/data/sample/__init__.py +0 -0
- boltz/data/sample/cluster.py +283 -0
- boltz/data/sample/distillation.py +57 -0
- boltz/data/sample/random.py +39 -0
- boltz/data/sample/sampler.py +49 -0
- boltz/data/tokenize/__init__.py +0 -0
- boltz/data/tokenize/boltz.py +195 -0
- boltz/data/tokenize/boltz2.py +396 -0
- boltz/data/tokenize/tokenizer.py +24 -0
- boltz/data/types.py +777 -0
- boltz/data/write/__init__.py +0 -0
- boltz/data/write/mmcif.py +305 -0
- boltz/data/write/pdb.py +171 -0
- boltz/data/write/utils.py +23 -0
- boltz/data/write/writer.py +330 -0
- boltz/main.py +1292 -0
- boltz/model/__init__.py +0 -0
- boltz/model/layers/__init__.py +0 -0
- boltz/model/layers/attention.py +132 -0
- boltz/model/layers/attentionv2.py +111 -0
- boltz/model/layers/confidence_utils.py +231 -0
- boltz/model/layers/dropout.py +34 -0
- boltz/model/layers/initialize.py +100 -0
- boltz/model/layers/outer_product_mean.py +98 -0
- boltz/model/layers/pair_averaging.py +135 -0
- boltz/model/layers/pairformer.py +337 -0
- boltz/model/layers/relative.py +58 -0
- boltz/model/layers/transition.py +78 -0
- boltz/model/layers/triangular_attention/__init__.py +0 -0
- boltz/model/layers/triangular_attention/attention.py +189 -0
- boltz/model/layers/triangular_attention/primitives.py +409 -0
- boltz/model/layers/triangular_attention/utils.py +380 -0
- boltz/model/layers/triangular_mult.py +212 -0
- boltz/model/loss/__init__.py +0 -0
- boltz/model/loss/bfactor.py +49 -0
- boltz/model/loss/confidence.py +590 -0
- boltz/model/loss/confidencev2.py +621 -0
- boltz/model/loss/diffusion.py +171 -0
- boltz/model/loss/diffusionv2.py +134 -0
- boltz/model/loss/distogram.py +48 -0
- boltz/model/loss/distogramv2.py +105 -0
- boltz/model/loss/validation.py +1025 -0
- boltz/model/models/__init__.py +0 -0
- boltz/model/models/boltz1.py +1286 -0
- boltz/model/models/boltz2.py +1249 -0
- boltz/model/modules/__init__.py +0 -0
- boltz/model/modules/affinity.py +223 -0
- boltz/model/modules/confidence.py +481 -0
- boltz/model/modules/confidence_utils.py +181 -0
- boltz/model/modules/confidencev2.py +495 -0
- boltz/model/modules/diffusion.py +844 -0
- boltz/model/modules/diffusion_conditioning.py +116 -0
- boltz/model/modules/diffusionv2.py +677 -0
- boltz/model/modules/encoders.py +639 -0
- boltz/model/modules/encodersv2.py +565 -0
- boltz/model/modules/transformers.py +322 -0
- boltz/model/modules/transformersv2.py +261 -0
- boltz/model/modules/trunk.py +688 -0
- boltz/model/modules/trunkv2.py +828 -0
- boltz/model/modules/utils.py +303 -0
- boltz/model/optim/__init__.py +0 -0
- boltz/model/optim/ema.py +389 -0
- boltz/model/optim/scheduler.py +99 -0
- boltz/model/potentials/__init__.py +0 -0
- boltz/model/potentials/potentials.py +497 -0
- boltz/model/potentials/schedules.py +32 -0
- boltz_vsynthes-1.0.0.dist-info/METADATA +151 -0
- boltz_vsynthes-1.0.0.dist-info/RECORD +112 -0
- boltz_vsynthes-1.0.0.dist-info/WHEEL +5 -0
- boltz_vsynthes-1.0.0.dist-info/entry_points.txt +2 -0
- boltz_vsynthes-1.0.0.dist-info/licenses/LICENSE +21 -0
- boltz_vsynthes-1.0.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,181 @@
import torch
from torch import nn

from boltz.data import const
from boltz.model.loss.confidence import compute_frame_pred


def compute_aggregated_metric(logits, end=1.0):
    """Compute the metric from the logits.

    Parameters
    ----------
    logits : torch.Tensor
        The logits of the metric
    end : float
        Max value of the metric, by default 1.0

    Returns
    -------
    Tensor
        The metric value

    """
    num_bins = logits.shape[-1]
    bin_width = end / num_bins
    bounds = torch.arange(
        start=0.5 * bin_width, end=end, step=bin_width, device=logits.device
    )
    probs = nn.functional.softmax(logits, dim=-1)
    # Expected value of the metric: probability-weighted sum over bin centers
    plddt = torch.sum(
        probs * bounds.view(*((1,) * len(probs.shape[:-1])), *bounds.shape),
        dim=-1,
    )
    return plddt
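For reference, compute_aggregated_metric is the expected value of the metric under the predicted bin distribution, using bin centers as the values. A minimal standalone sketch of the same computation (illustrative only, not part of the package):

# --- illustrative sketch, not part of the package ---
import torch

logits = torch.randn(2, 10, 50)          # (batch, tokens, bins)
probs = torch.softmax(logits, dim=-1)
centers = (torch.arange(50) + 0.5) / 50  # bin centers in (0, 1]
expected = (probs * centers).sum(-1)     # same values as compute_aggregated_metric(logits)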

def tm_function(d, Nres):
    """Compute the rescaling function for pTM.

    Parameters
    ----------
    d : torch.Tensor
        The input distance values
    Nres : torch.Tensor
        The number of residues

    Returns
    -------
    Tensor
        Output of the function

    """
    # Standard TM-score normalization; Nres is clipped to at least 19 so d0 stays positive
    d0 = 1.24 * (torch.clip(Nres, min=19) - 15) ** (1 / 3) - 1.8
    return 1 / (1 + (d / d0) ** 2)
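A quick numeric check of tm_function (illustrative values, not from the package): for Nres = 100, d0 = 1.24 * (100 - 15)^(1/3) - 1.8 ≈ 3.65, so the score is 1 at d = 0 and 0.5 at d ≈ d0.

# --- illustrative sketch, not part of the package ---
Nres = torch.tensor(100.0)
d = torch.tensor([0.0, 3.65, 10.0])
tm_function(d, Nres)  # ≈ tensor([1.00, 0.50, 0.12])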

def compute_ptms(logits, x_preds, feats, multiplicity):
    """Compute pTM and ipTM scores.

    Parameters
    ----------
    logits : torch.Tensor
        The PAE logits
    x_preds : torch.Tensor
        The predicted coordinates
    feats : Dict[str, torch.Tensor]
        The input features
    multiplicity : int
        The batch size of the diffusion roll-out

    Returns
    -------
    Tensor
        pTM score
    Tensor
        ipTM score
    Tensor
        ligand ipTM score
    Tensor
        protein ipTM score
    Dict[int, Dict[int, Tensor]]
        chain-pair ipTM scores

    """
    # Compute mask for collinear and overlapping tokens
    _, mask_collinear_pred = compute_frame_pred(
        x_preds, feats["frames_idx"], feats, multiplicity, inference=True
    )
    mask_pad = feats["token_pad_mask"].repeat_interleave(multiplicity, 0)
    maski = mask_collinear_pred.reshape(-1, mask_collinear_pred.shape[-1])
    pair_mask_ptm = maski[:, :, None] * mask_pad[:, None, :] * mask_pad[:, :, None]
    asym_id = feats["asym_id"].repeat_interleave(multiplicity, 0)
    pair_mask_iptm = (
        maski[:, :, None]
        * (asym_id[:, None, :] != asym_id[:, :, None])
        * mask_pad[:, None, :]
        * mask_pad[:, :, None]
    )

    # Extract PAE values (bin centers over [0, 32])
    num_bins = logits.shape[-1]
    bin_width = 32.0 / num_bins
    end = 32.0
    pae_value = torch.arange(
        start=0.5 * bin_width, end=end, step=bin_width, device=logits.device
    ).unsqueeze(0)
    N_res = mask_pad.sum(dim=-1, keepdim=True)

    # Compute pTM and ipTM
    tm_value = tm_function(pae_value, N_res).unsqueeze(1).unsqueeze(2)
    probs = nn.functional.softmax(logits, dim=-1)
    tm_expected_value = torch.sum(
        probs * tm_value,
        dim=-1,
    )  # shape (B, N, N)
    ptm = torch.max(
        torch.sum(tm_expected_value * pair_mask_ptm, dim=-1)
        / (torch.sum(pair_mask_ptm, dim=-1) + 1e-5),
        dim=1,
    ).values
    iptm = torch.max(
        torch.sum(tm_expected_value * pair_mask_iptm, dim=-1)
        / (torch.sum(pair_mask_iptm, dim=-1) + 1e-5),
        dim=1,
    ).values

    # Compute ligand and protein ipTM
    token_type = feats["mol_type"]
    token_type = token_type.repeat_interleave(multiplicity, 0)
    is_ligand_token = (token_type == const.chain_type_ids["NONPOLYMER"]).float()
    is_protein_token = (token_type == const.chain_type_ids["PROTEIN"]).float()

    ligand_iptm_mask = (
        maski[:, :, None]
        * (asym_id[:, None, :] != asym_id[:, :, None])
        * mask_pad[:, None, :]
        * mask_pad[:, :, None]
        * (
            (is_ligand_token[:, :, None] * is_protein_token[:, None, :])
            + (is_protein_token[:, :, None] * is_ligand_token[:, None, :])
        )
    )
    protein_iptm_mask = (
        maski[:, :, None]
        * (asym_id[:, None, :] != asym_id[:, :, None])
        * mask_pad[:, None, :]
        * mask_pad[:, :, None]
        * (is_protein_token[:, :, None] * is_protein_token[:, None, :])
    )

    ligand_iptm = torch.max(
        torch.sum(tm_expected_value * ligand_iptm_mask, dim=-1)
        / (torch.sum(ligand_iptm_mask, dim=-1) + 1e-5),
        dim=1,
    ).values
    protein_iptm = torch.max(
        torch.sum(tm_expected_value * protein_iptm_mask, dim=-1)
        / (torch.sum(protein_iptm_mask, dim=-1) + 1e-5),
        dim=1,
    ).values

    # Compute pair chain ipTM
    chain_pair_iptm = {}
    asym_ids_list = torch.unique(asym_id).tolist()
    for idx1 in asym_ids_list:
        chain_iptm = {}
        for idx2 in asym_ids_list:
            mask_pair_chain = (
                maski[:, :, None]
                * (asym_id[:, None, :] == idx1)
                * (asym_id[:, :, None] == idx2)
                * mask_pad[:, None, :]
                * mask_pad[:, :, None]
            )

            chain_iptm[idx2] = torch.max(
                torch.sum(tm_expected_value * mask_pair_chain, dim=-1)
                / (torch.sum(mask_pair_chain, dim=-1) + 1e-5),
                dim=1,
            ).values
        chain_pair_iptm[idx1] = chain_iptm

    return ptm, iptm, ligand_iptm, protein_iptm, chain_pair_iptm
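The pTM-style reductions in compute_ptms all share one pattern: a masked mean over the second token index, followed by a max over the first (picking the best "anchor" token). A toy sketch of just that reduction (shapes and values illustrative, not from the package):

# --- illustrative sketch, not part of the package ---
import torch

tm_expected = torch.rand(1, 4, 4)   # per-pair expected TM contribution
mask = torch.ones(1, 4, 4)
mask[:, :, 2] = 0                   # e.g., drop a padded token
per_token = (tm_expected * mask).sum(-1) / (mask.sum(-1) + 1e-5)
ptm = per_token.max(dim=1).values   # best anchor token, as in compute_ptms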
@@ -0,0 +1,495 @@
import torch
from torch import nn
from torch.nn.functional import pad

import boltz.model.layers.initialize as init
from boltz.data import const
from boltz.model.layers.confidence_utils import (
    compute_aggregated_metric,
    compute_ptms,
)
from boltz.model.layers.pairformer import PairformerModule
from boltz.model.modules.encodersv2 import RelativePositionEncoder
from boltz.model.modules.trunkv2 import (
    ContactConditioning,
)
from boltz.model.modules.utils import LinearNoBias


class ConfidenceModule(nn.Module):
    """Algorithm 31"""

    def __init__(
        self,
        token_s,
        token_z,
        pairformer_args: dict,
        num_dist_bins=64,
        token_level_confidence=True,
        max_dist=22,
        add_s_to_z_prod=False,
        add_s_input_to_s=False,
        add_z_input_to_z=False,
        maximum_bond_distance=0,
        bond_type_feature=False,
        confidence_args: dict = None,
        compile_pairformer=False,
        fix_sym_check=False,
        cyclic_pos_enc=False,
        return_latent_feats=False,
        conditioning_cutoff_min=None,
        conditioning_cutoff_max=None,
        **kwargs,
    ):
        super().__init__()
        self.max_num_atoms_per_token = 23
        self.no_update_s = pairformer_args.get("no_update_s", False)
        boundaries = torch.linspace(2, max_dist, num_dist_bins - 1)
        self.register_buffer("boundaries", boundaries)
        self.dist_bin_pairwise_embed = nn.Embedding(num_dist_bins, token_z)
        init.gating_init_(self.dist_bin_pairwise_embed.weight)
        self.token_level_confidence = token_level_confidence

        self.s_to_z = LinearNoBias(token_s, token_z)
        self.s_to_z_transpose = LinearNoBias(token_s, token_z)
        init.gating_init_(self.s_to_z.weight)
        init.gating_init_(self.s_to_z_transpose.weight)

        self.add_s_to_z_prod = add_s_to_z_prod
        if add_s_to_z_prod:
            self.s_to_z_prod_in1 = LinearNoBias(token_s, token_z)
            self.s_to_z_prod_in2 = LinearNoBias(token_s, token_z)
            self.s_to_z_prod_out = LinearNoBias(token_z, token_z)
            init.gating_init_(self.s_to_z_prod_out.weight)

        self.s_inputs_norm = nn.LayerNorm(token_s)
        if not self.no_update_s:
            self.s_norm = nn.LayerNorm(token_s)

        self.z_norm = nn.LayerNorm(token_z)

        self.add_s_input_to_s = add_s_input_to_s
        if add_s_input_to_s:
            self.s_input_to_s = LinearNoBias(token_s, token_s)
            init.gating_init_(self.s_input_to_s.weight)

        self.add_z_input_to_z = add_z_input_to_z
        if add_z_input_to_z:
            self.rel_pos = RelativePositionEncoder(
                token_z, fix_sym_check=fix_sym_check, cyclic_pos_enc=cyclic_pos_enc
            )
            self.token_bonds = nn.Linear(
                1 if maximum_bond_distance == 0 else maximum_bond_distance + 2,
                token_z,
                bias=False,
            )
            self.bond_type_feature = bond_type_feature
            if bond_type_feature:
                self.token_bonds_type = nn.Embedding(len(const.bond_types) + 1, token_z)

            self.contact_conditioning = ContactConditioning(
                token_z=token_z,
                cutoff_min=conditioning_cutoff_min,
                cutoff_max=conditioning_cutoff_max,
            )

        pairformer_args["v2"] = True
        self.pairformer_stack = PairformerModule(
            token_s,
            token_z,
            **pairformer_args,
        )
        self.return_latent_feats = return_latent_feats

        self.confidence_heads = ConfidenceHeads(
            token_s,
            token_z,
            token_level_confidence=token_level_confidence,
            **confidence_args,
        )

    def forward(
        self,
        s_inputs,  # Float['b n ts']
        s,  # Float['b n ts']
        z,  # Float['b n n tz']
        x_pred,  # Float['bm m 3']
        feats,
        pred_distogram_logits,
        multiplicity=1,
        run_sequentially=False,
        use_kernels: bool = False,
    ):
        # Optionally run the diffusion samples one at a time and concatenate
        if run_sequentially and multiplicity > 1:
            assert z.shape[0] == 1, "Not supported with batch size > 1"
            out_dicts = []
            for sample_idx in range(multiplicity):
                out_dicts.append(  # noqa: PERF401
                    self.forward(
                        s_inputs,
                        s,
                        z,
                        x_pred[sample_idx : sample_idx + 1],
                        feats,
                        pred_distogram_logits,
                        multiplicity=1,
                        run_sequentially=False,
                        use_kernels=use_kernels,
                    )
                )

            out_dict = {}
            for key in out_dicts[0]:
                if key != "pair_chains_iptm":
                    out_dict[key] = torch.cat([out[key] for out in out_dicts], dim=0)
                else:
                    pair_chains_iptm = {}
                    for chain_idx1 in out_dicts[0][key]:
                        chains_iptm = {}
                        for chain_idx2 in out_dicts[0][key][chain_idx1]:
                            chains_iptm[chain_idx2] = torch.cat(
                                [out[key][chain_idx1][chain_idx2] for out in out_dicts],
                                dim=0,
                            )
                        pair_chains_iptm[chain_idx1] = chains_iptm
                    out_dict[key] = pair_chains_iptm
            return out_dict

        s_inputs = self.s_inputs_norm(s_inputs)
        if not self.no_update_s:
            s = self.s_norm(s)

        if self.add_s_input_to_s:
            s = s + self.s_input_to_s(s_inputs)

        z = self.z_norm(z)

        if self.add_z_input_to_z:
            relative_position_encoding = self.rel_pos(feats)
            z = z + relative_position_encoding
            z = z + self.token_bonds(feats["token_bonds"].float())
            if self.bond_type_feature:
                z = z + self.token_bonds_type(feats["type_bonds"].long())
            z = z + self.contact_conditioning(feats)

        s = s.repeat_interleave(multiplicity, 0)

        z = (
            z
            + self.s_to_z(s_inputs)[:, :, None, :]
            + self.s_to_z_transpose(s_inputs)[:, None, :, :]
        )
        if self.add_s_to_z_prod:
            z = z + self.s_to_z_prod_out(
                self.s_to_z_prod_in1(s_inputs)[:, :, None, :]
                * self.s_to_z_prod_in2(s_inputs)[:, None, :, :]
            )

        z = z.repeat_interleave(multiplicity, 0)
        s_inputs = s_inputs.repeat_interleave(multiplicity, 0)

        token_to_rep_atom = feats["token_to_rep_atom"]
        token_to_rep_atom = token_to_rep_atom.repeat_interleave(multiplicity, 0)
        if len(x_pred.shape) == 4:
            B, mult, N, _ = x_pred.shape
            x_pred = x_pred.reshape(B * mult, N, -1)
        else:
            BM, N, _ = x_pred.shape
        x_pred_repr = torch.bmm(token_to_rep_atom.float(), x_pred)
        d = torch.cdist(x_pred_repr, x_pred_repr)
        # Bin the predicted pairwise distances and embed them into the pair representation
        distogram = (d.unsqueeze(-1) > self.boundaries).sum(dim=-1).long()
        distogram = self.dist_bin_pairwise_embed(distogram)
        z = z + distogram

        mask = feats["token_pad_mask"].repeat_interleave(multiplicity, 0)
        pair_mask = mask[:, :, None] * mask[:, None, :]

        s_t, z_t = self.pairformer_stack(
            s, z, mask=mask, pair_mask=pair_mask, use_kernels=use_kernels
        )

        # AF3 has residual connections here; we remove them
        s = s_t
        z = z_t

        out_dict = {}

        if self.return_latent_feats:
            out_dict["s_conf"] = s
            out_dict["z_conf"] = z

        # Confidence heads
        out_dict.update(
            self.confidence_heads(
                s=s,
                z=z,
                x_pred=x_pred,
                d=d,
                feats=feats,
                multiplicity=multiplicity,
                pred_distogram_logits=pred_distogram_logits,
            )
        )
        return out_dict
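The distance binning in ConfidenceModule.forward, (d.unsqueeze(-1) > boundaries).sum(-1), is equivalent to torch.bucketize against the same boundaries. A standalone check with toy sizes (illustrative, not part of the package):

# --- illustrative sketch, not part of the package ---
import torch

boundaries = torch.linspace(2, 22, 63)  # 63 boundaries -> 64 bins, as in ConfidenceModule
d = torch.rand(1, 5, 5) * 30            # toy pairwise distances
bins_a = (d.unsqueeze(-1) > boundaries).sum(dim=-1).long()
bins_b = torch.bucketize(d, boundaries)
assert torch.equal(bins_a, bins_b)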


class ConfidenceHeads(nn.Module):
    def __init__(
        self,
        token_s,
        token_z,
        num_plddt_bins=50,
        num_pde_bins=64,
        num_pae_bins=64,
        token_level_confidence=True,
        use_separate_heads: bool = False,
        **kwargs,
    ):
        super().__init__()
        self.max_num_atoms_per_token = 23
        self.token_level_confidence = token_level_confidence
        self.use_separate_heads = use_separate_heads

        if self.use_separate_heads:
            self.to_pae_intra_logits = LinearNoBias(token_z, num_pae_bins)
            self.to_pae_inter_logits = LinearNoBias(token_z, num_pae_bins)
        else:
            self.to_pae_logits = LinearNoBias(token_z, num_pae_bins)

        if self.use_separate_heads:
            self.to_pde_intra_logits = LinearNoBias(token_z, num_pde_bins)
            self.to_pde_inter_logits = LinearNoBias(token_z, num_pde_bins)
        else:
            self.to_pde_logits = LinearNoBias(token_z, num_pde_bins)

        if self.token_level_confidence:
            self.to_plddt_logits = LinearNoBias(token_s, num_plddt_bins)
            self.to_resolved_logits = LinearNoBias(token_s, 2)
        else:
            self.to_plddt_logits = LinearNoBias(
                token_s, num_plddt_bins * self.max_num_atoms_per_token
            )
            self.to_resolved_logits = LinearNoBias(
                token_s, 2 * self.max_num_atoms_per_token
            )

    def forward(
        self,
        s,  # Float['b n ts']
        z,  # Float['b n n tz']
        x_pred,  # Float['bm m 3']
        d,
        feats,
        pred_distogram_logits,
        multiplicity=1,
    ):
        if self.use_separate_heads:
            asym_id_token = feats["asym_id"]
            is_same_chain = asym_id_token.unsqueeze(-1) == asym_id_token.unsqueeze(-2)
            is_different_chain = ~is_same_chain

        if self.use_separate_heads:
            pae_intra_logits = self.to_pae_intra_logits(z)
            pae_intra_logits = pae_intra_logits * is_same_chain.float().unsqueeze(-1)

            pae_inter_logits = self.to_pae_inter_logits(z)
            pae_inter_logits = pae_inter_logits * is_different_chain.float().unsqueeze(
                -1
            )

            pae_logits = pae_inter_logits + pae_intra_logits
        else:
            pae_logits = self.to_pae_logits(z)

        if self.use_separate_heads:
            # PDE is symmetric, so the heads see the symmetrized pair representation
            pde_intra_logits = self.to_pde_intra_logits(z + z.transpose(1, 2))
            pde_intra_logits = pde_intra_logits * is_same_chain.float().unsqueeze(-1)

            pde_inter_logits = self.to_pde_inter_logits(z + z.transpose(1, 2))
            pde_inter_logits = pde_inter_logits * is_different_chain.float().unsqueeze(
                -1
            )

            pde_logits = pde_inter_logits + pde_intra_logits
        else:
            pde_logits = self.to_pde_logits(z + z.transpose(1, 2))

        resolved_logits = self.to_resolved_logits(s)
        plddt_logits = self.to_plddt_logits(s)

        # Weights for the interface-focused aggregate pLDDT (ipLDDT)
        ligand_weight = 20
        non_interface_weight = 1
        interface_weight = 10

        token_type = feats["mol_type"]
        token_type = token_type.repeat_interleave(multiplicity, 0)
        is_ligand_token = (token_type == const.chain_type_ids["NONPOLYMER"]).float()

        if self.token_level_confidence:
            plddt = compute_aggregated_metric(plddt_logits)
            token_pad_mask = feats["token_pad_mask"].repeat_interleave(multiplicity, 0)
            complex_plddt = (plddt * token_pad_mask).sum(dim=-1) / token_pad_mask.sum(
                dim=-1
            )

            is_contact = (d < 8).float()
            is_different_chain = (
                feats["asym_id"].unsqueeze(-1) != feats["asym_id"].unsqueeze(-2)
            ).float()
            is_different_chain = is_different_chain.repeat_interleave(multiplicity, 0)
            token_interface_mask = torch.max(
                is_contact * is_different_chain * (1 - is_ligand_token).unsqueeze(-1),
                dim=-1,
            ).values
            token_non_interface_mask = (1 - token_interface_mask) * (
                1 - is_ligand_token
            )
            iplddt_weight = (
                is_ligand_token * ligand_weight
                + token_interface_mask * interface_weight
                + token_non_interface_mask * non_interface_weight
            )
            complex_iplddt = (plddt * token_pad_mask * iplddt_weight).sum(
                dim=-1
            ) / torch.sum(token_pad_mask * iplddt_weight, dim=-1)

        else:
            # Token to atom conversion for resolved logits
            B, N, _ = resolved_logits.shape
            resolved_logits = resolved_logits.reshape(
                B, N, self.max_num_atoms_per_token, 2
            )

            arange_max_num_atoms = (
                torch.arange(self.max_num_atoms_per_token)
                .reshape(1, 1, -1)
                .to(resolved_logits.device)
            )
            max_num_atoms_mask = (
                feats["atom_to_token"].sum(1).unsqueeze(-1) > arange_max_num_atoms
            )
            resolved_logits = resolved_logits[:, max_num_atoms_mask.squeeze(0)]
            resolved_logits = pad(
                resolved_logits,
                (
                    0,
                    0,
                    0,
                    int(
                        feats["atom_pad_mask"].shape[1]
                        - feats["atom_pad_mask"].sum().item()
                    ),
                ),
                value=0,
            )
            plddt_logits = plddt_logits.reshape(B, N, self.max_num_atoms_per_token, -1)
            plddt_logits = plddt_logits[:, max_num_atoms_mask.squeeze(0)]
            plddt_logits = pad(
                plddt_logits,
                (
                    0,
                    0,
                    0,
                    int(
                        feats["atom_pad_mask"].shape[1]
                        - feats["atom_pad_mask"].sum().item()
                    ),
                ),
                value=0,
            )
            atom_pad_mask = feats["atom_pad_mask"].repeat_interleave(multiplicity, 0)
            plddt = compute_aggregated_metric(plddt_logits)

            complex_plddt = (plddt * atom_pad_mask).sum(dim=-1) / atom_pad_mask.sum(
                dim=-1
            )
            token_type = feats["mol_type"].float()
            atom_to_token = feats["atom_to_token"].float()
            chain_id_token = feats["asym_id"].float()
            atom_type = torch.bmm(atom_to_token, token_type.unsqueeze(-1)).squeeze(-1)
            is_ligand_atom = (atom_type == const.chain_type_ids["NONPOLYMER"]).float()
            d_atom = torch.cdist(x_pred, x_pred)
            is_contact = (d_atom < 8).float()
            chain_id_atom = torch.bmm(
                atom_to_token, chain_id_token.unsqueeze(-1)
            ).squeeze(-1)
            is_different_chain = (
                chain_id_atom.unsqueeze(-1) != chain_id_atom.unsqueeze(-2)
            ).float()

            atom_interface_mask = torch.max(
                is_contact * is_different_chain * (1 - is_ligand_atom).unsqueeze(-1),
                dim=-1,
            ).values
            atom_non_interface_mask = (1 - atom_interface_mask) * (1 - is_ligand_atom)
            iplddt_weight = (
                is_ligand_atom * ligand_weight
                + atom_interface_mask * interface_weight
                + atom_non_interface_mask * non_interface_weight
            )

            complex_iplddt = (plddt * atom_pad_mask * iplddt_weight).sum(
                dim=-1
            ) / torch.sum(atom_pad_mask * iplddt_weight, dim=-1)

        # Compute the gPDE and giPDE
        pde = compute_aggregated_metric(pde_logits, end=32)
        pred_distogram_prob = nn.functional.softmax(
            pred_distogram_logits, dim=-1
        ).repeat_interleave(multiplicity, 0)
        contacts = torch.zeros((1, 1, 1, 64), dtype=pred_distogram_prob.dtype).to(
            pred_distogram_prob.device
        )
        contacts[:, :, :, :20] = 1.0
        prob_contact = (pred_distogram_prob * contacts).sum(-1)
        token_pad_mask = feats["token_pad_mask"].repeat_interleave(multiplicity, 0)
        token_pad_pair_mask = (
            token_pad_mask.unsqueeze(-1)
            * token_pad_mask.unsqueeze(-2)
            * (
                1
                - torch.eye(
                    token_pad_mask.shape[1], device=token_pad_mask.device
                ).unsqueeze(0)
            )
        )
        token_pair_mask = token_pad_pair_mask * prob_contact
        complex_pde = (pde * token_pair_mask).sum(dim=(1, 2)) / token_pair_mask.sum(
            dim=(1, 2)
        )
        asym_id = feats["asym_id"].repeat_interleave(multiplicity, 0)
        token_interface_pair_mask = token_pair_mask * (
            asym_id.unsqueeze(-1) != asym_id.unsqueeze(-2)
        )
        complex_ipde = (pde * token_interface_pair_mask).sum(dim=(1, 2)) / (
            token_interface_pair_mask.sum(dim=(1, 2)) + 1e-5
        )
        out_dict = dict(
            pde_logits=pde_logits,
            plddt_logits=plddt_logits,
            resolved_logits=resolved_logits,
            pde=pde,
            plddt=plddt,
            complex_plddt=complex_plddt,
            complex_iplddt=complex_iplddt,
            complex_pde=complex_pde,
            complex_ipde=complex_ipde,
        )
        out_dict["pae_logits"] = pae_logits
        out_dict["pae"] = compute_aggregated_metric(pae_logits, end=32)

        try:
            ptm, iptm, ligand_iptm, protein_iptm, pair_chains_iptm = compute_ptms(
                pae_logits, x_pred, feats, multiplicity
            )
            out_dict["ptm"] = ptm
            out_dict["iptm"] = iptm
            out_dict["ligand_iptm"] = ligand_iptm
            out_dict["protein_iptm"] = protein_iptm
            out_dict["pair_chains_iptm"] = pair_chains_iptm
        except Exception as e:
            print(f"Error in compute_ptms: {e}")
            out_dict["ptm"] = torch.zeros_like(complex_plddt)
            out_dict["iptm"] = torch.zeros_like(complex_plddt)
            out_dict["ligand_iptm"] = torch.zeros_like(complex_plddt)
            out_dict["protein_iptm"] = torch.zeros_like(complex_plddt)
            out_dict["pair_chains_iptm"] = torch.zeros_like(complex_plddt)

        return out_dict
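The complex ipLDDT above is a weighted mean of per-position pLDDT, with weights 20 / 10 / 1 for ligand, interface, and other positions, so ligand and interface quality dominate the aggregate. A toy illustration (values made up):

# --- illustrative sketch, not part of the package ---
import torch

plddt = torch.tensor([[0.9, 0.8, 0.4]])     # non-interface, interface, ligand positions
weight = torch.tensor([[1.0, 10.0, 20.0]])  # non_interface_weight, interface_weight, ligand_weight
complex_iplddt = (plddt * weight).sum(-1) / weight.sum(-1)
# ≈ 0.545, far below the unweighted mean of 0.7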