boltz-vsynthes 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- boltz/__init__.py +7 -0
- boltz/data/__init__.py +0 -0
- boltz/data/const.py +1184 -0
- boltz/data/crop/__init__.py +0 -0
- boltz/data/crop/affinity.py +164 -0
- boltz/data/crop/boltz.py +296 -0
- boltz/data/crop/cropper.py +45 -0
- boltz/data/feature/__init__.py +0 -0
- boltz/data/feature/featurizer.py +1230 -0
- boltz/data/feature/featurizerv2.py +2208 -0
- boltz/data/feature/symmetry.py +602 -0
- boltz/data/filter/__init__.py +0 -0
- boltz/data/filter/dynamic/__init__.py +0 -0
- boltz/data/filter/dynamic/date.py +76 -0
- boltz/data/filter/dynamic/filter.py +24 -0
- boltz/data/filter/dynamic/max_residues.py +37 -0
- boltz/data/filter/dynamic/resolution.py +34 -0
- boltz/data/filter/dynamic/size.py +38 -0
- boltz/data/filter/dynamic/subset.py +42 -0
- boltz/data/filter/static/__init__.py +0 -0
- boltz/data/filter/static/filter.py +26 -0
- boltz/data/filter/static/ligand.py +37 -0
- boltz/data/filter/static/polymer.py +299 -0
- boltz/data/module/__init__.py +0 -0
- boltz/data/module/inference.py +307 -0
- boltz/data/module/inferencev2.py +429 -0
- boltz/data/module/training.py +684 -0
- boltz/data/module/trainingv2.py +660 -0
- boltz/data/mol.py +900 -0
- boltz/data/msa/__init__.py +0 -0
- boltz/data/msa/mmseqs2.py +235 -0
- boltz/data/pad.py +84 -0
- boltz/data/parse/__init__.py +0 -0
- boltz/data/parse/a3m.py +134 -0
- boltz/data/parse/csv.py +100 -0
- boltz/data/parse/fasta.py +138 -0
- boltz/data/parse/mmcif.py +1239 -0
- boltz/data/parse/mmcif_with_constraints.py +1607 -0
- boltz/data/parse/schema.py +1851 -0
- boltz/data/parse/yaml.py +68 -0
- boltz/data/sample/__init__.py +0 -0
- boltz/data/sample/cluster.py +283 -0
- boltz/data/sample/distillation.py +57 -0
- boltz/data/sample/random.py +39 -0
- boltz/data/sample/sampler.py +49 -0
- boltz/data/tokenize/__init__.py +0 -0
- boltz/data/tokenize/boltz.py +195 -0
- boltz/data/tokenize/boltz2.py +396 -0
- boltz/data/tokenize/tokenizer.py +24 -0
- boltz/data/types.py +777 -0
- boltz/data/write/__init__.py +0 -0
- boltz/data/write/mmcif.py +305 -0
- boltz/data/write/pdb.py +171 -0
- boltz/data/write/utils.py +23 -0
- boltz/data/write/writer.py +330 -0
- boltz/main.py +1292 -0
- boltz/model/__init__.py +0 -0
- boltz/model/layers/__init__.py +0 -0
- boltz/model/layers/attention.py +132 -0
- boltz/model/layers/attentionv2.py +111 -0
- boltz/model/layers/confidence_utils.py +231 -0
- boltz/model/layers/dropout.py +34 -0
- boltz/model/layers/initialize.py +100 -0
- boltz/model/layers/outer_product_mean.py +98 -0
- boltz/model/layers/pair_averaging.py +135 -0
- boltz/model/layers/pairformer.py +337 -0
- boltz/model/layers/relative.py +58 -0
- boltz/model/layers/transition.py +78 -0
- boltz/model/layers/triangular_attention/__init__.py +0 -0
- boltz/model/layers/triangular_attention/attention.py +189 -0
- boltz/model/layers/triangular_attention/primitives.py +409 -0
- boltz/model/layers/triangular_attention/utils.py +380 -0
- boltz/model/layers/triangular_mult.py +212 -0
- boltz/model/loss/__init__.py +0 -0
- boltz/model/loss/bfactor.py +49 -0
- boltz/model/loss/confidence.py +590 -0
- boltz/model/loss/confidencev2.py +621 -0
- boltz/model/loss/diffusion.py +171 -0
- boltz/model/loss/diffusionv2.py +134 -0
- boltz/model/loss/distogram.py +48 -0
- boltz/model/loss/distogramv2.py +105 -0
- boltz/model/loss/validation.py +1025 -0
- boltz/model/models/__init__.py +0 -0
- boltz/model/models/boltz1.py +1286 -0
- boltz/model/models/boltz2.py +1249 -0
- boltz/model/modules/__init__.py +0 -0
- boltz/model/modules/affinity.py +223 -0
- boltz/model/modules/confidence.py +481 -0
- boltz/model/modules/confidence_utils.py +181 -0
- boltz/model/modules/confidencev2.py +495 -0
- boltz/model/modules/diffusion.py +844 -0
- boltz/model/modules/diffusion_conditioning.py +116 -0
- boltz/model/modules/diffusionv2.py +677 -0
- boltz/model/modules/encoders.py +639 -0
- boltz/model/modules/encodersv2.py +565 -0
- boltz/model/modules/transformers.py +322 -0
- boltz/model/modules/transformersv2.py +261 -0
- boltz/model/modules/trunk.py +688 -0
- boltz/model/modules/trunkv2.py +828 -0
- boltz/model/modules/utils.py +303 -0
- boltz/model/optim/__init__.py +0 -0
- boltz/model/optim/ema.py +389 -0
- boltz/model/optim/scheduler.py +99 -0
- boltz/model/potentials/__init__.py +0 -0
- boltz/model/potentials/potentials.py +497 -0
- boltz/model/potentials/schedules.py +32 -0
- boltz_vsynthes-1.0.0.dist-info/METADATA +151 -0
- boltz_vsynthes-1.0.0.dist-info/RECORD +112 -0
- boltz_vsynthes-1.0.0.dist-info/WHEEL +5 -0
- boltz_vsynthes-1.0.0.dist-info/entry_points.txt +2 -0
- boltz_vsynthes-1.0.0.dist-info/licenses/LICENSE +21 -0
- boltz_vsynthes-1.0.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,590 @@
|
|
1
|
+
import torch
|
2
|
+
from torch import nn
|
3
|
+
|
4
|
+
from boltz.data import const
|
5
|
+
|
6
|
+
|
7
|
+
def confidence_loss(
    model_out,
    feats,
    true_coords,
    true_coords_resolved_mask,
    multiplicity=1,
    alpha_pae=0.0,
):
    """Aggregate the confidence-head losses into a single objective.

    The total is ``plddt + pde + resolved + alpha_pae * pae``. The PAE term
    is only evaluated when ``alpha_pae`` is strictly positive; otherwise it
    is reported as the plain float ``0.0``.

    Parameters
    ----------
    model_out: Dict[str, torch.Tensor]
        Model output. Reads ``"plddt_logits"``, ``"pde_logits"``,
        ``"resolved_logits"``, ``"sample_atom_coords"`` and, when the PAE
        term is enabled, ``"pae_logits"``.
    feats: Dict[str, torch.Tensor]
        Dictionary containing the model input
    true_coords: torch.Tensor
        The atom coordinates after symmetry correction
    true_coords_resolved_mask: torch.Tensor
        The resolved mask after symmetry correction
    multiplicity: int, optional
        The diffusion batch size, by default 1
    alpha_pae: float, optional
        The weight of the pae loss, by default 0.0

    Returns
    -------
    Dict[str, Any]
        ``"loss"`` holds the weighted total and ``"loss_breakdown"`` holds
        the individual ``plddt_loss``, ``pde_loss``, ``resolved_loss`` and
        ``pae_loss`` terms.

    """
    # Ground-truth arguments shared by every coordinate-based head.
    targets = (true_coords, true_coords_resolved_mask, feats)

    plddt = plddt_loss(
        model_out["plddt_logits"],
        model_out["sample_atom_coords"],
        *targets,
        multiplicity=multiplicity,
    )
    pde = pde_loss(
        model_out["pde_logits"],
        model_out["sample_atom_coords"],
        *targets,
        multiplicity,
    )
    resolved = resolved_loss(
        model_out["resolved_logits"],
        feats,
        true_coords_resolved_mask,
        multiplicity=multiplicity,
    )

    # Skip the PAE computation entirely when it carries no weight.
    if alpha_pae > 0.0:
        pae = pae_loss(
            model_out["pae_logits"],
            model_out["sample_atom_coords"],
            *targets,
            multiplicity,
        )
    else:
        pae = 0.0

    return {
        "loss": plddt + pde + resolved + alpha_pae * pae,
        "loss_breakdown": {
            "plddt_loss": plddt,
            "pde_loss": pde,
            "resolved_loss": resolved,
            "pae_loss": pae,
        },
    }
|
85
|
+
|
86
|
+
|
87
|
+
def resolved_loss(
    pred_resolved,
    feats,
    true_coords_resolved_mask,
    multiplicity=1,
):
    """Cross-entropy between predicted and true per-token resolved state.

    Index 0 of the last logits axis is scored against tokens whose
    representative atom is resolved, index 1 against those that are not.
    The per-sample loss is averaged over non-padded tokens and then over
    the batch axis.

    Parameters
    ----------
    pred_resolved: torch.Tensor
        The resolved logits; indexed as ``[:, :, 0]`` and ``[:, :, 1]``
    feats: Dict[str, torch.Tensor]
        Dictionary containing the model input. Reads ``"token_to_rep_atom"``
        and ``"token_pad_mask"``.
    true_coords_resolved_mask: torch.Tensor
        The resolved mask after symmetry correction
    multiplicity: int, optional
        The diffusion batch size, by default 1

    Returns
    -------
    torch.Tensor
        Resolved loss (scalar)

    """
    # Repeat per-example features so their batch axis lines up with the
    # diffusion samples.
    rep_atom_selector = (
        feats["token_to_rep_atom"].repeat_interleave(multiplicity, 0).float()
    )

    # Resolved flag of each token's representative atom.
    resolved_target = torch.bmm(
        rep_atom_selector, true_coords_resolved_mask.unsqueeze(-1).float()
    ).squeeze(-1)

    token_pad = feats["token_pad_mask"].repeat_interleave(multiplicity, 0).float()

    log_probs = torch.nn.functional.log_softmax(pred_resolved, dim=-1)
    nll = (
        -resolved_target * log_probs[:, :, 0]
        - (1 - resolved_target) * log_probs[:, :, 1]
    )

    # Mean over real tokens; the epsilon guards fully padded samples.
    per_sample = torch.sum(nll * token_pad, dim=-1) / (
        1e-7 + torch.sum(token_pad, dim=-1)
    )

    # Average over the batch dimension
    return torch.mean(per_sample)
|
134
|
+
|
135
|
+
|
136
|
+
def plddt_loss(
    pred_lddt,
    pred_atom_coords,
    true_atom_coords,
    true_coords_resolved_mask,
    feats,
    multiplicity=1,
):
    """Cross-entropy between predicted pLDDT bins and the realised lDDT.

    The target lDDT of every token is computed from distances between the
    token's representative atom and the atoms selected by
    ``feats["r_set_to_rep_atom"]``, then discretised into as many equal bins
    as ``pred_lddt`` has entries on its last axis.

    Parameters
    ----------
    pred_lddt: torch.Tensor
        The plddt logits
    pred_atom_coords: torch.Tensor
        The predicted atom coordinates
    true_atom_coords: torch.Tensor
        The atom coordinates after symmetry correction
    true_coords_resolved_mask: torch.Tensor
        The resolved mask after symmetry correction
    feats: Dict[str, torch.Tensor]
        Dictionary containing the model input. Reads ``"r_set_to_rep_atom"``,
        ``"mol_type"``, ``"atom_to_token"`` and ``"token_to_rep_atom"``.
    multiplicity: int, optional
        The diffusion batch size, by default 1

    Returns
    -------
    torch.Tensor
        Plddt loss (scalar)

    """
    n_samples = true_atom_coords.shape[0]

    # Per-example features, repeated to match the diffusion samples.
    r_set_selector = (
        feats["r_set_to_rep_atom"].repeat_interleave(multiplicity, 0).float()
    )
    mol_type = feats["mol_type"].repeat_interleave(multiplicity, 0)
    dna_token = (mol_type == const.chain_type_ids["DNA"]).float()
    rna_token = (mol_type == const.chain_type_ids["RNA"]).float()
    nucleotide_token = dna_token + rna_token
    atom_to_token = feats["atom_to_token"].float().repeat_interleave(multiplicity, 0)
    rep_selector = feats["token_to_rep_atom"].float().repeat_interleave(
        multiplicity, 0
    )

    # Distances from each token's representative atom to every R-set atom,
    # in the true and in the predicted structure.
    true_d = torch.cdist(
        torch.bmm(rep_selector, true_atom_coords),
        torch.bmm(r_set_selector, true_atom_coords),
    )
    pred_d = torch.cdist(
        torch.bmm(rep_selector, pred_atom_coords),
        torch.bmm(r_set_selector, pred_atom_coords),
    )

    # Atom-pair validity (both resolved, not the same atom), then projected
    # onto the (token, R-set element) grid.
    resolved = true_coords_resolved_mask
    atom_pairs = resolved.unsqueeze(-1) * resolved.unsqueeze(-2)
    off_diagonal = 1 - torch.eye(atom_pairs.shape[1], device=atom_pairs.device)
    atom_pairs = atom_pairs * off_diagonal[None, :, :]
    pair_mask = torch.einsum("bnm,bkm->bnk", atom_pairs, r_set_selector)
    pair_mask = torch.bmm(rep_selector, pair_mask)

    # Resolved flag of each token's representative atom.
    token_resolved = torch.bmm(rep_selector, resolved.unsqueeze(-1).float())

    # Inclusion radius per R-set element: 15 by default, 30 where the
    # element maps to a DNA/RNA token.
    nucleotide_r_element = torch.bmm(
        r_set_selector, torch.bmm(atom_to_token, nucleotide_token.unsqueeze(-1))
    ).squeeze(-1)
    cutoff = 15 + 15 * nucleotide_r_element.reshape(n_samples, 1, -1).repeat(
        1, true_d.shape[1], 1
    )

    # Per-token lDDT plus a flag for tokens that have any scored neighbour.
    target_lddt, has_neighbour = lddt_dist(
        pred_d, true_d, pair_mask, cutoff, per_atom=True
    )

    # Discretise the target; a value of exactly 1.0 goes in the last bin.
    num_bins = pred_lddt.shape[-1]
    bin_index = torch.clamp(
        torch.floor(target_lddt * num_bins).long(), max=(num_bins - 1)
    )
    target_one_hot = nn.functional.one_hot(bin_index, num_classes=num_bins)
    nll = -1 * torch.sum(
        target_one_hot * torch.nn.functional.log_softmax(pred_lddt, dim=-1),
        dim=-1,
    )

    token_resolved = token_resolved.squeeze(-1)
    per_sample = torch.sum(nll * token_resolved * has_neighbour, dim=-1) / (
        1e-7 + torch.sum(token_resolved * has_neighbour, dim=-1)
    )

    # Average over the batch dimension
    return torch.mean(per_sample)
|
240
|
+
|
241
|
+
|
242
|
+
def pde_loss(
    pred_pde,
    pred_atom_coords,
    true_atom_coords,
    true_coords_resolved_mask,
    feats,
    multiplicity=1,
    max_dist=32.0,
):
    """Cross-entropy between predicted PDE bins and the realised error.

    The target for a token pair is the absolute difference between the true
    and the predicted distance of the two tokens' representative atoms,
    discretised into equal-width bins over ``[0, max_dist)``; larger errors
    fall into the last bin.

    Parameters
    ----------
    pred_pde: torch.Tensor
        The pde logits
    pred_atom_coords: torch.Tensor
        The predicted atom coordinates
    true_atom_coords: torch.Tensor
        The atom coordinates after symmetry correction
    true_coords_resolved_mask: torch.Tensor
        The resolved mask after symmetry correction
    feats: Dict[str, torch.Tensor]
        Dictionary containing the model input. Reads ``"token_to_rep_atom"``.
    multiplicity: int, optional
        The diffusion batch size, by default 1
    max_dist: float, optional
        Upper edge of the binned error range, by default 32.0

    Returns
    -------
    torch.Tensor
        Pde loss (scalar)

    """
    rep_selector = (
        feats["token_to_rep_atom"].repeat_interleave(multiplicity, 0).float()
    )

    # A pair counts only if both representative atoms are resolved.
    token_resolved = torch.bmm(
        rep_selector, true_coords_resolved_mask.unsqueeze(-1).float()
    ).squeeze(-1)
    pair_mask = token_resolved.unsqueeze(-1) * token_resolved.unsqueeze(-2)

    # Token-level distance matrices and their absolute difference.
    true_tokens = torch.bmm(rep_selector, true_atom_coords)
    pred_tokens = torch.bmm(rep_selector, pred_atom_coords)
    true_d = torch.cdist(true_tokens, true_tokens)
    pred_d = torch.cdist(pred_tokens, pred_tokens)
    target_pde = torch.abs(true_d - pred_d)

    # Discretise the target and score it against the logits.
    num_bins = pred_pde.shape[-1]
    bin_index = torch.clamp(
        torch.floor(target_pde * num_bins / max_dist).long(), max=(num_bins - 1)
    )
    target_one_hot = nn.functional.one_hot(bin_index, num_classes=num_bins)
    nll = -1 * torch.sum(
        target_one_hot * torch.nn.functional.log_softmax(pred_pde, dim=-1),
        dim=-1,
    )

    per_sample = torch.sum(nll * pair_mask, dim=(-2, -1)) / (
        1e-7 + torch.sum(pair_mask, dim=(-2, -1))
    )

    # Average over the batch dimension
    return torch.mean(per_sample)
|
308
|
+
|
309
|
+
|
310
|
+
def pae_loss(
    pred_pae,
    pred_atom_coords,
    true_atom_coords,
    true_coords_resolved_mask,
    feats,
    multiplicity=1,
    max_dist=32.0,
):
    """Cross-entropy between predicted PAE bins and the realised error.

    Token positions are expressed in every token's local frame, once for the
    true and once for the predicted structure. The target is the Euclidean
    distance between the two frame-local positions, discretised into
    equal-width bins over ``[0, max_dist)``; larger errors fall into the
    last bin.

    Parameters
    ----------
    pred_pae: torch.Tensor
        The pae logits
    pred_atom_coords: torch.Tensor
        The predicted atom coordinates
    true_atom_coords: torch.Tensor
        The atom coordinates after symmetry correction
    true_coords_resolved_mask: torch.Tensor
        The resolved mask after symmetry correction
    feats: Dict[str, torch.Tensor]
        Dictionary containing the model input. Reads ``"frames_idx"``,
        ``"frame_resolved_mask"`` and ``"token_pad_mask"`` here; further
        keys are read by ``compute_frame_pred``.
    multiplicity: int, optional
        The diffusion batch size, by default 1
    max_dist: float, optional
        Upper edge of the binned error range, by default 32.0

    Returns
    -------
    torch.Tensor
        Pae loss (scalar)

    """
    # Retrieve frames and resolved masks
    reference_frames = feats["frames_idx"]
    frame_resolved = feats["frame_resolved_mask"]

    # Adjust the frames for nonpolymers after symmetry correction!
    # NOTE: frames of polymers do not change under symmetry!
    true_frames, true_frame_ok = compute_frame_pred(
        true_atom_coords,
        reference_frames,
        feats,
        multiplicity,
        resolved_mask=true_coords_resolved_mask,
    )
    true_a = true_frames[:, :, :, 0]
    true_b = true_frames[:, :, :, 1]
    true_c = true_frames[:, :, :, 2]

    # Token coordinates expressed in the true frames.
    n_true, _, _ = true_atom_coords.shape
    true_in_frame = express_coordinate_in_frame(
        true_atom_coords.reshape(n_true // multiplicity, multiplicity, -1, 3),
        true_a,
        true_b,
        true_c,
    )

    # Frames rebuilt on the predicted coordinates, and the token coordinates
    # expressed in them.
    pred_frames, pred_frame_ok = compute_frame_pred(
        pred_atom_coords, reference_frames, feats, multiplicity
    )
    pred_a = pred_frames[:, :, :, 0]
    pred_b = pred_frames[:, :, :, 1]
    pred_c = pred_frames[:, :, :, 2]
    n_pred, _, _ = pred_atom_coords.shape
    pred_in_frame = express_coordinate_in_frame(
        pred_atom_coords.reshape(n_pred // multiplicity, multiplicity, -1, 3),
        pred_a,
        pred_b,
        pred_c,
    )

    # The small constant keeps the square root differentiable at zero error.
    target_pae = torch.sqrt(((true_in_frame - pred_in_frame) ** 2).sum(-1) + 1e-8)

    # Resolved flag of every token's frame-centre atom (atom "b").
    # NOTE(review): only rows ``0 .. B // multiplicity - 1`` of
    # ``true_coords_resolved_mask`` are read here, whereas ``resolved_loss``
    # and ``pde_loss`` treat that tensor as having one row per diffusion
    # sample. Confirm this is intended when ``multiplicity > 1``.
    example_index = torch.arange(n_pred // multiplicity)[:, None, None].to(
        pred_in_frame.device
    )
    centre_resolved = true_coords_resolved_mask[example_index, true_b]

    token_pad = feats["token_pad_mask"]
    pair_mask = (
        frame_resolved[:, None, :, None]  # true frame flagged invalid in feats
        * true_frame_ok[:, :, :, None]  # true frame collinear / overlapping
        * pred_frame_ok[:, :, :, None]  # pred frame collinear / overlapping
        * centre_resolved[:, :, None, :]  # atom j is not resolved
        * token_pad[:, None, :, None]
        * token_pad[:, None, None, :]
    )

    # Discretise the target and score it against the logits.
    num_bins = pred_pae.shape[-1]
    bin_index = torch.clamp(
        torch.floor(target_pae * num_bins / max_dist).long(), max=(num_bins - 1)
    )
    target_one_hot = nn.functional.one_hot(bin_index, num_classes=num_bins)
    log_probs = torch.nn.functional.log_softmax(
        pred_pae.reshape(target_one_hot.shape), dim=-1
    )
    nll = -1 * torch.sum(target_one_hot * log_probs, dim=-1)

    per_sample = torch.sum(nll * pair_mask, dim=(-2, -1)) / (
        1e-7 + torch.sum(pair_mask, dim=(-2, -1))
    )

    # Average over the batch dimension
    return torch.mean(per_sample)
|
422
|
+
|
423
|
+
|
424
|
+
def lddt_dist(dmat_predicted, dmat_true, mask, cutoff=15.0, per_atom=False):
    """Compute lDDT from a predicted and a true distance matrix.

    A pair is scored when its true distance is below ``cutoff`` and ``mask``
    is non-zero for it. Each scored pair earns a quarter point for every
    tolerance in (0.5, 1.0, 2.0, 4.0) that its absolute distance error stays
    under.

    Parameters
    ----------
    dmat_predicted: torch.Tensor
        Predicted pairwise distances
    dmat_true: torch.Tensor
        True pairwise distances
    mask: torch.Tensor
        Pairwise mask; the caller must already have masked out identity
        pairs
    cutoff: float or torch.Tensor, optional
        Inclusion radius applied to ``dmat_true``, by default 15.0
    per_atom: bool, optional
        Reduce over the last axis only instead of the last two, by default
        False

    Returns
    -------
    Tuple[torch.Tensor, torch.Tensor]
        The lDDT score and, if ``per_atom``, a float flag that is 1 where a
        row has at least one scored pair; otherwise the number of scored
        pairs.

    """
    # Pairs that take part in the score.
    scored = (dmat_true < cutoff).float() * mask
    abs_err = torch.abs(dmat_true - dmat_predicted)

    # Count how many of the four tolerances each pair satisfies.
    passed = (abs_err < 0.5).float()
    for tolerance in (1.0, 2.0, 4.0):
        passed = passed + (abs_err < tolerance).float()
    pair_score = 0.25 * passed

    # Normalize over the appropriate axes. The epsilon in numerator and
    # denominator makes an empty neighbourhood evaluate to about 1.
    reduce_dims = -1 if per_atom else (-2, -1)
    n_scored = torch.sum(scored, dim=reduce_dims)
    norm = 1.0 / (1e-10 + n_scored)
    score = norm * (1e-10 + torch.sum(scored * pair_score, dim=reduce_dims))

    if per_atom:
        return score, (n_scored != 0).float()
    return score, n_scored
|
448
|
+
|
449
|
+
|
450
|
+
def express_coordinate_in_frame(atom_coords, frame_atom_a, frame_atom_b, frame_atom_c):
    """Express every frame centre in the local basis of every frame.

    Each frame is given by three atom indices ``(a, b, c)`` with ``b`` as
    its centre. Its basis is built from the unit vectors ``b->a`` and
    ``b->c``: their normalised sum, their normalised difference and the
    cross product of the two.

    Parameters
    ----------
    atom_coords: torch.Tensor
        Atom coordinates, indexed as ``[batch, multiplicity, atom]``
    frame_atom_a: torch.Tensor
        Atom index of the first frame atom, per frame
    frame_atom_b: torch.Tensor
        Atom index of the frame centre, per frame
    frame_atom_c: torch.Tensor
        Atom index of the third frame atom, per frame

    Returns
    -------
    torch.Tensor
        Entry ``[..., i, j, :]`` holds the centre of frame ``j`` relative to
        the centre of frame ``i``, projected onto the basis of frame ``i``.

    """
    n_batch, n_mult = atom_coords.shape[0], atom_coords.shape[1]
    batch_ix = torch.arange(n_batch)[:, None, None].to(atom_coords.device)
    mult_ix = torch.arange(n_mult)[None, :, None].to(atom_coords.device)

    def gather(atom_index):
        # Look up the coordinates of one atom per frame.
        return atom_coords[batch_ix, mult_ix, atom_index]

    def unit(vec):
        # Normalise along the last axis; the epsilon avoids division by zero.
        return vec / (torch.norm(vec, dim=-1, keepdim=True) + 1e-5)

    # extract frame atoms
    centre = gather(frame_atom_b)
    w1 = unit(gather(frame_atom_a) - centre)
    w2 = unit(gather(frame_atom_c) - centre)

    # build orthogonal frame
    e1 = unit(w1 + w2)
    e2 = unit(w2 - w1)
    e3 = torch.linalg.cross(e1, e2)

    # project onto frame basis
    offset = centre[:, :, None, :, :] - centre[:, :, :, None, :]
    components = [
        torch.sum(offset * axis[:, :, :, None, :], dim=-1, keepdim=True)
        for axis in (e1, e2, e3)
    ]
    return torch.cat(components, dim=-1)
|
480
|
+
|
481
|
+
|
482
|
+
def compute_collinear_mask(v1, v2):
    """Flag vector pairs that span a usable (non-degenerate) frame.

    Parameters
    ----------
    v1: torch.Tensor
        First vectors; rows along axis 0, components along axis 1
    v2: torch.Tensor
        Second vectors, same layout as ``v1``

    Returns
    -------
    torch.Tensor
        Boolean tensor with one entry per row. True when both vectors are
        longer than 1e-2 and the absolute cosine of the angle between them
        is below 0.9063 (approximately cos 25 degrees).

    """
    length1 = torch.norm(v1, dim=1, keepdim=True)
    length2 = torch.norm(v2, dim=1, keepdim=True)

    # Cosine of the enclosed angle; the epsilon guards zero-length vectors.
    unit1 = v1 / (length1 + 1e-6)
    unit2 = v2 / (length2 + 1e-6)
    not_collinear = torch.abs(torch.sum(unit1 * unit2, dim=1)) < 0.9063

    # Near-zero vectors mean two frame atoms (almost) coincide.
    v1_long_enough = length1.reshape(-1) > 1e-2
    v2_long_enough = length2.reshape(-1) > 1e-2

    return not_collinear & v1_long_enough & v2_long_enough
|
492
|
+
|
493
|
+
|
494
|
+
def compute_frame_pred(
    pred_atom_coords,
    frames_idx_true,
    feats,
    multiplicity,
    resolved_mask=None,
    inference=False,
):
    """Rebuild nonpolymer frames on the given coordinates and validate them.

    Frames start as a copy of ``frames_idx_true`` repeated for every
    diffusion sample. For each chain whose first token is of type
    ``NONPOLYMER`` and that has at least three atoms, the frame of every
    atom is replaced by ``(nearest, closest, second nearest)`` according to
    the pairwise distances in ``pred_atom_coords``. ``closest`` is the
    first entry of the distance sort, presumably the atom itself.

    Parameters
    ----------
    pred_atom_coords: torch.Tensor
        Atom coordinates to build the frames on; reshaped internally to
        ``(B // multiplicity, multiplicity, -1, 3)``
    frames_idx_true: torch.Tensor
        Reference frame atom indices, one ``(a, b, c)`` triplet per token
    feats: Dict[str, torch.Tensor]
        Dictionary containing the model input. Reads ``"asym_id"``,
        ``"atom_to_token"``, ``"token_pad_mask"``, ``"atom_pad_mask"``,
        ``"mol_type"`` and, outside inference when ``resolved_mask`` is
        None, ``"atom_resolved_mask"``.
    multiplicity: int
        The diffusion batch size
    resolved_mask: torch.Tensor, optional
        Atom validity mask used when ranking neighbours outside inference;
        falls back to ``feats["atom_resolved_mask"]``, by default None
    inference: bool, optional
        Rank neighbours using ``feats["atom_pad_mask"]`` instead of a
        resolved mask, by default False

    Returns
    -------
    Tuple[torch.Tensor, torch.Tensor]
        The frame atom indices per example, sample and token, and a mask
        that is non-zero where the frame is neither collinear nor
        overlapping and the token is not padding.

    """
    # Chain id of every token, and of every atom via its owning token.
    token_chain = feats["asym_id"]
    atom_chain = torch.bmm(
        feats["atom_to_token"].float(), token_chain.unsqueeze(-1).float()
    ).squeeze(-1)

    n_total, _, _ = pred_atom_coords.shape
    n_examples = n_total // multiplicity
    coords = pred_atom_coords.reshape(n_examples, multiplicity, -1, 3)
    frames_idx_pred = (
        frames_idx_true.clone()
        .repeat_interleave(multiplicity, 0)
        .reshape(n_examples, multiplicity, -1, 3)
    )

    # Iterate through the batch and update the frames for nonpolymers
    for i, example_coords in enumerate(coords):
        # Running offsets of the current chain; this relies on chains being
        # laid out contiguously in the order torch.unique returns their ids.
        token_offset = 0
        atom_offset = 0
        for chain_id in torch.unique(token_chain[i]):
            chain_tokens = (token_chain[i] == chain_id) * feats["token_pad_mask"][i]
            chain_atoms = (atom_chain[i] == chain_id) * feats["atom_pad_mask"][i]
            n_chain_tokens = int(chain_tokens.sum().item())
            n_chain_atoms = int(chain_atoms.sum().item())

            # Only nonpolymer chains with enough atoms get new frames.
            if (
                feats["mol_type"][i, token_offset] == const.chain_type_ids["NONPOLYMER"]
                and n_chain_atoms >= 3
            ):
                atom_selector = chain_atoms.bool()
                chain_xyz = example_coords[:, atom_selector]
                delta = chain_xyz[:, None, :, :] - chain_xyz[:, :, None, :]
                dist_mat = (delta**2).sum(-1) ** 0.5

                # Pick the validity mask that decides which atoms may serve
                # as neighbours.
                if inference:
                    validity = feats["atom_pad_mask"]
                else:
                    if resolved_mask is None:
                        resolved_mask = feats["atom_resolved_mask"]
                    validity = resolved_mask
                chain_valid = validity[i][atom_selector]

                # Pairs with an invalid atom get an infinite distance so
                # they sort last.
                penalty = 1 - (chain_valid[None, :] * chain_valid[:, None]).to(
                    torch.float32
                )
                penalty[penalty == 1] = torch.inf

                # Sort the atoms by distance
                order = torch.sort(dist_mat + penalty, dim=2).indices

                # Compute the frames: (a, b, c) = sort positions (1, 0, 2),
                # shifted from chain-local to global atom indices.
                frames = order[:, :, [1, 0, 2]] + atom_offset

                # NOTE(review): the token slice is sized by the atom count,
                # which presumably relies on nonpolymer chains having one
                # token per atom. Verify against the tokenizer.
                frames_idx_pred[
                    i, :, token_offset : token_offset + n_chain_atoms, :
                ] = frames

            token_offset += n_chain_tokens
            atom_offset += n_chain_atoms

    # Gather the coordinates of the three atoms of every frame.
    example_ix = torch.arange(0, n_examples, 1)[:, None, None, None].to(
        frames_idx_pred.device
    )
    sample_ix = torch.arange(0, multiplicity, 1)[None, :, None, None].to(
        frames_idx_pred.device
    )
    frame_xyz = coords[example_ix, sample_ix, frames_idx_pred].reshape(-1, 3, 3)

    # Compute masks for collinear or overlapping atoms in the frame
    mask_collinear_pred = compute_collinear_mask(
        frame_xyz[:, 1] - frame_xyz[:, 0],
        frame_xyz[:, 1] - frame_xyz[:, 2],
    ).reshape(n_examples, multiplicity, -1)

    return frames_idx_pred, mask_collinear_pred * feats["token_pad_mask"][:, None, :]
|