boltz-vsynthes 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (112) hide show
  1. boltz/__init__.py +7 -0
  2. boltz/data/__init__.py +0 -0
  3. boltz/data/const.py +1184 -0
  4. boltz/data/crop/__init__.py +0 -0
  5. boltz/data/crop/affinity.py +164 -0
  6. boltz/data/crop/boltz.py +296 -0
  7. boltz/data/crop/cropper.py +45 -0
  8. boltz/data/feature/__init__.py +0 -0
  9. boltz/data/feature/featurizer.py +1230 -0
  10. boltz/data/feature/featurizerv2.py +2208 -0
  11. boltz/data/feature/symmetry.py +602 -0
  12. boltz/data/filter/__init__.py +0 -0
  13. boltz/data/filter/dynamic/__init__.py +0 -0
  14. boltz/data/filter/dynamic/date.py +76 -0
  15. boltz/data/filter/dynamic/filter.py +24 -0
  16. boltz/data/filter/dynamic/max_residues.py +37 -0
  17. boltz/data/filter/dynamic/resolution.py +34 -0
  18. boltz/data/filter/dynamic/size.py +38 -0
  19. boltz/data/filter/dynamic/subset.py +42 -0
  20. boltz/data/filter/static/__init__.py +0 -0
  21. boltz/data/filter/static/filter.py +26 -0
  22. boltz/data/filter/static/ligand.py +37 -0
  23. boltz/data/filter/static/polymer.py +299 -0
  24. boltz/data/module/__init__.py +0 -0
  25. boltz/data/module/inference.py +307 -0
  26. boltz/data/module/inferencev2.py +429 -0
  27. boltz/data/module/training.py +684 -0
  28. boltz/data/module/trainingv2.py +660 -0
  29. boltz/data/mol.py +900 -0
  30. boltz/data/msa/__init__.py +0 -0
  31. boltz/data/msa/mmseqs2.py +235 -0
  32. boltz/data/pad.py +84 -0
  33. boltz/data/parse/__init__.py +0 -0
  34. boltz/data/parse/a3m.py +134 -0
  35. boltz/data/parse/csv.py +100 -0
  36. boltz/data/parse/fasta.py +138 -0
  37. boltz/data/parse/mmcif.py +1239 -0
  38. boltz/data/parse/mmcif_with_constraints.py +1607 -0
  39. boltz/data/parse/schema.py +1851 -0
  40. boltz/data/parse/yaml.py +68 -0
  41. boltz/data/sample/__init__.py +0 -0
  42. boltz/data/sample/cluster.py +283 -0
  43. boltz/data/sample/distillation.py +57 -0
  44. boltz/data/sample/random.py +39 -0
  45. boltz/data/sample/sampler.py +49 -0
  46. boltz/data/tokenize/__init__.py +0 -0
  47. boltz/data/tokenize/boltz.py +195 -0
  48. boltz/data/tokenize/boltz2.py +396 -0
  49. boltz/data/tokenize/tokenizer.py +24 -0
  50. boltz/data/types.py +777 -0
  51. boltz/data/write/__init__.py +0 -0
  52. boltz/data/write/mmcif.py +305 -0
  53. boltz/data/write/pdb.py +171 -0
  54. boltz/data/write/utils.py +23 -0
  55. boltz/data/write/writer.py +330 -0
  56. boltz/main.py +1292 -0
  57. boltz/model/__init__.py +0 -0
  58. boltz/model/layers/__init__.py +0 -0
  59. boltz/model/layers/attention.py +132 -0
  60. boltz/model/layers/attentionv2.py +111 -0
  61. boltz/model/layers/confidence_utils.py +231 -0
  62. boltz/model/layers/dropout.py +34 -0
  63. boltz/model/layers/initialize.py +100 -0
  64. boltz/model/layers/outer_product_mean.py +98 -0
  65. boltz/model/layers/pair_averaging.py +135 -0
  66. boltz/model/layers/pairformer.py +337 -0
  67. boltz/model/layers/relative.py +58 -0
  68. boltz/model/layers/transition.py +78 -0
  69. boltz/model/layers/triangular_attention/__init__.py +0 -0
  70. boltz/model/layers/triangular_attention/attention.py +189 -0
  71. boltz/model/layers/triangular_attention/primitives.py +409 -0
  72. boltz/model/layers/triangular_attention/utils.py +380 -0
  73. boltz/model/layers/triangular_mult.py +212 -0
  74. boltz/model/loss/__init__.py +0 -0
  75. boltz/model/loss/bfactor.py +49 -0
  76. boltz/model/loss/confidence.py +590 -0
  77. boltz/model/loss/confidencev2.py +621 -0
  78. boltz/model/loss/diffusion.py +171 -0
  79. boltz/model/loss/diffusionv2.py +134 -0
  80. boltz/model/loss/distogram.py +48 -0
  81. boltz/model/loss/distogramv2.py +105 -0
  82. boltz/model/loss/validation.py +1025 -0
  83. boltz/model/models/__init__.py +0 -0
  84. boltz/model/models/boltz1.py +1286 -0
  85. boltz/model/models/boltz2.py +1249 -0
  86. boltz/model/modules/__init__.py +0 -0
  87. boltz/model/modules/affinity.py +223 -0
  88. boltz/model/modules/confidence.py +481 -0
  89. boltz/model/modules/confidence_utils.py +181 -0
  90. boltz/model/modules/confidencev2.py +495 -0
  91. boltz/model/modules/diffusion.py +844 -0
  92. boltz/model/modules/diffusion_conditioning.py +116 -0
  93. boltz/model/modules/diffusionv2.py +677 -0
  94. boltz/model/modules/encoders.py +639 -0
  95. boltz/model/modules/encodersv2.py +565 -0
  96. boltz/model/modules/transformers.py +322 -0
  97. boltz/model/modules/transformersv2.py +261 -0
  98. boltz/model/modules/trunk.py +688 -0
  99. boltz/model/modules/trunkv2.py +828 -0
  100. boltz/model/modules/utils.py +303 -0
  101. boltz/model/optim/__init__.py +0 -0
  102. boltz/model/optim/ema.py +389 -0
  103. boltz/model/optim/scheduler.py +99 -0
  104. boltz/model/potentials/__init__.py +0 -0
  105. boltz/model/potentials/potentials.py +497 -0
  106. boltz/model/potentials/schedules.py +32 -0
  107. boltz_vsynthes-1.0.0.dist-info/METADATA +151 -0
  108. boltz_vsynthes-1.0.0.dist-info/RECORD +112 -0
  109. boltz_vsynthes-1.0.0.dist-info/WHEEL +5 -0
  110. boltz_vsynthes-1.0.0.dist-info/entry_points.txt +2 -0
  111. boltz_vsynthes-1.0.0.dist-info/licenses/LICENSE +21 -0
  112. boltz_vsynthes-1.0.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,621 @@
1
+ import torch
2
+ from torch import nn
3
+
4
+ from boltz.data import const
5
+ from boltz.model.layers.confidence_utils import compute_frame_pred, tm_function
6
+
7
+
8
def confidence_loss(
    model_out,
    feats,
    true_coords,
    true_coords_resolved_mask,
    token_level_confidence=False,
    multiplicity=1,
    alpha_pae=0.0,
    mask_loss=None,
    relative_supervision_weight=0.0,
):
    """Combine the confidence-head losses into one weighted objective.

    The pLDDT, PDE and resolved losses are always computed. The PAE loss is
    only computed when ``alpha_pae > 0``. The relative (sample-vs-sample)
    terms are requested from each sub-loss only when
    ``relative_supervision_weight > 0``.

    Args:
        model_out: Dict of model outputs. Reads ``"plddt_logits"``,
            ``"pde_logits"``, ``"resolved_logits"`` and
            ``"sample_atom_coords"``. Reads ``"pae_logits"`` only when
            ``alpha_pae > 0``. The optional ``"relative_*_logits"`` entries
            are read with ``.get`` and may be absent.
        feats: Feature dict forwarded unchanged to every sub-loss.
        true_coords: Ground-truth atom coordinates.
        true_coords_resolved_mask: Mask of resolved ground-truth atoms.
        token_level_confidence: Forwarded to the pLDDT and resolved losses.
        multiplicity: Number of samples per input, forwarded to every sub-loss.
        alpha_pae: Weight of the PAE loss and of its relative term.
        mask_loss: Optional per-example weights forwarded to every sub-loss.
        relative_supervision_weight: Weight applied to the sum of the
            relative terms.

    Returns:
        Dict with the total ``"loss"`` and a ``"loss_breakdown"`` dict of the
        individual components. The PAE entries are ``0.0`` when the PAE loss
        is skipped.
    """
    # TODO no support for MD yet!
    # TODO only apply to the PDB structures not the distillation ones
    use_relative = relative_supervision_weight > 0.0

    plddt, rel_plddt = plddt_loss(
        model_out["plddt_logits"],
        model_out["sample_atom_coords"],
        feats,
        true_coords,
        true_coords_resolved_mask,
        token_level_confidence=token_level_confidence,
        multiplicity=multiplicity,
        mask_loss=mask_loss,
        relative_confidence_supervision=use_relative,
        relative_pred_lddt=model_out.get("relative_plddt_logits", None),
    )
    pde, rel_pde = pde_loss(
        model_out["pde_logits"],
        model_out["sample_atom_coords"],
        feats,
        true_coords,
        true_coords_resolved_mask,
        multiplicity,
        mask_loss=mask_loss,
        relative_confidence_supervision=use_relative,
        relative_pred_pde=model_out.get("relative_pde_logits", None),
    )
    resolved = resolved_loss(
        model_out["resolved_logits"],
        feats,
        true_coords_resolved_mask,
        token_level_confidence=token_level_confidence,
        multiplicity=multiplicity,
        mask_loss=mask_loss,
    )

    # The PAE head is optional: only touch "pae_logits" when it is weighted.
    if alpha_pae > 0.0:
        pae, rel_pae = pae_loss(
            model_out["pae_logits"],
            model_out["sample_atom_coords"],
            feats,
            true_coords,
            true_coords_resolved_mask,
            multiplicity,
            mask_loss=mask_loss,
            relative_confidence_supervision=use_relative,
            relative_pred_pae=model_out.get("relative_pae_logits", None),
        )
    else:
        pae, rel_pae = 0.0, 0.0

    absolute_terms = plddt + pde + resolved + alpha_pae * pae
    relative_terms = rel_plddt + rel_pde + alpha_pae * rel_pae
    total = absolute_terms + relative_supervision_weight * relative_terms

    breakdown = {
        "plddt_loss": plddt,
        "pde_loss": pde,
        "resolved_loss": resolved,
        "pae_loss": pae,
        "rel_plddt_loss": rel_plddt,
        "rel_pde_loss": rel_pde,
        "rel_pae_loss": rel_pae,
    }
    return {"loss": total, "loss_breakdown": breakdown}
88
+
89
+
90
def resolved_loss(
    pred_resolved,
    feats,
    true_coords_resolved_mask,
    token_level_confidence=False,
    multiplicity=1,
    mask_loss=None,
):
    """Cross-entropy loss for the resolved/unresolved classification head.

    Logit channel 0 is scored where the target mask is 1, and channel 1 where
    it is 0. Padded positions are excluded, and each sample is averaged over
    its valid positions before the batch reduction.

    Args:
        pred_resolved: Resolved logits. They are indexed as ``[:, :, k]``, so
            presumably shaped ``(B * multiplicity, N, 2)``; TODO confirm.
        feats: Feature dict. Reads ``"token_to_rep_atom"`` and
            ``"token_pad_mask"`` when ``token_level_confidence`` is set,
            otherwise ``"atom_pad_mask"``.
        true_coords_resolved_mask: Per-atom resolved mask of the true
            structure.
        token_level_confidence: Project the atom mask onto tokens through
            ``feats["token_to_rep_atom"]``.
        multiplicity: Number of samples per input. Features are repeated this
            many times along dim 0.
        mask_loss: Optional per-example weights for the batch average. They
            are repeated over samples.

    Returns:
        Scalar loss tensor.
    """
    with torch.autocast("cuda", enabled=False):
        if token_level_confidence:
            rep_atom = (
                feats["token_to_rep_atom"].repeat_interleave(multiplicity, 0).float()
            )
            target = torch.bmm(
                rep_atom, true_coords_resolved_mask.unsqueeze(-1).float()
            ).squeeze(-1)
            pad_key = "token_pad_mask"
        else:
            target = true_coords_resolved_mask.float()
            pad_key = "atom_pad_mask"
        valid = feats[pad_key].repeat_interleave(multiplicity, 0).float()

        log_probs = torch.nn.functional.log_softmax(pred_resolved.float(), dim=-1)
        resolved_term = -target * log_probs[:, :, 0]
        unresolved_term = (1 - target) * log_probs[:, :, 1]
        errors = resolved_term - unresolved_term
        per_sample = (errors * valid).sum(dim=-1) / (1e-7 + valid.sum(dim=-1))

        # Average over the batch dimension
        if mask_loss is None:
            return per_sample.mean()
        weights = (
            mask_loss.repeat_interleave(multiplicity, 0)
            .reshape(-1, multiplicity)
            .float()
        )
        return (per_sample.reshape(-1, multiplicity) * weights).sum() / (
            weights.sum() + 1e-7
        )
139
+
140
+
141
def get_target_lddt(
    pred_atom_coords,
    feats,
    true_atom_coords,
    true_coords_resolved_mask,
    token_level_confidence=True,
    multiplicity=1,
):
    """Compute per-atom (or per-token) lDDT targets for the pLDDT head.

    Distances run from every atom to the atoms selected by
    ``feats["r_set_to_rep_atom"]``. When ``token_level_confidence`` is set,
    they run from every token's representative atom instead. Distances are
    scored with :func:`lddt_dist` using an inclusion cutoff of 15 for
    non-nucleotide R-set elements and 30 for DNA/RNA ones. The units are
    presumably Angstrom; confirm against the coordinate convention.

    Args:
        pred_atom_coords: Predicted atom coordinates. They are used with
            ``torch.bmm``, so presumably shaped
            ``(B * multiplicity, N_atoms, 3)``.
        feats: Feature dict. Reads ``"r_set_to_rep_atom"``, ``"mol_type"``,
            ``"atom_to_token"`` and ``"token_to_rep_atom"``. Each is repeated
            ``multiplicity`` times along dim 0.
        true_atom_coords: Ground-truth atom coordinates, in the same layout as
            ``pred_atom_coords``.
        true_coords_resolved_mask: Resolved-atom mask of the true coordinates.
        token_level_confidence: Score one representative atom per token
            instead of every atom.
        multiplicity: Number of samples per input.

    Returns:
        Tuple ``(target_lddt, mask_no_match, atom_mask)``:

        * ``target_lddt``: the lDDT scores.
        * ``mask_no_match``: a float mask that is 0 where no R-set element
          fell under the cutoff.
        * ``atom_mask``: the resolved mask. It is token-level and shaped
          ``(B, N_tokens, 1)`` when ``token_level_confidence`` is set;
          otherwise it is the input mask unchanged.
    """
    # torch.cuda.amp.autocast(...) is deprecated in recent PyTorch releases;
    # torch.autocast("cuda", ...) is the supported equivalent.
    with torch.autocast("cuda", enabled=False):
        # extract necessary features
        atom_mask = true_coords_resolved_mask

        R_set_to_rep_atom = feats["r_set_to_rep_atom"]
        R_set_to_rep_atom = R_set_to_rep_atom.repeat_interleave(multiplicity, 0).float()

        token_type = feats["mol_type"]
        token_type = token_type.repeat_interleave(multiplicity, 0)
        is_nucleotide_token = (token_type == const.chain_type_ids["DNA"]).float() + (
            token_type == const.chain_type_ids["RNA"]
        ).float()

        B = true_atom_coords.shape[0]

        atom_to_token = feats["atom_to_token"].float()
        atom_to_token = atom_to_token.repeat_interleave(multiplicity, 0)

        token_to_rep_atom = feats["token_to_rep_atom"].float()
        token_to_rep_atom = token_to_rep_atom.repeat_interleave(multiplicity, 0)

        if token_level_confidence:
            true_token_coords = torch.bmm(token_to_rep_atom, true_atom_coords)
            pred_token_coords = torch.bmm(token_to_rep_atom, pred_atom_coords)

        # compute true lddt: rows are atoms (or tokens), columns are R-set elements
        true_d = torch.cdist(
            true_token_coords if token_level_confidence else true_atom_coords,
            torch.bmm(R_set_to_rep_atom, true_atom_coords),
        )
        pred_d = torch.cdist(
            pred_token_coords if token_level_confidence else pred_atom_coords,
            torch.bmm(R_set_to_rep_atom, pred_atom_coords),
        )

        # Pairwise resolved mask with self-pairs removed, then projected onto
        # the R-set columns.
        pair_mask = atom_mask.unsqueeze(-1) * atom_mask.unsqueeze(-2)
        pair_mask = (
            pair_mask
            * (1 - torch.eye(pair_mask.shape[1], device=pair_mask.device))[None, :, :]
        )
        pair_mask = torch.einsum("bnm,bkm->bnk", pair_mask, R_set_to_rep_atom)

        if token_level_confidence:
            pair_mask = torch.bmm(token_to_rep_atom, pair_mask)
            atom_mask = torch.bmm(token_to_rep_atom, atom_mask.unsqueeze(-1).float())

        # Nucleotide R-set elements get a doubled inclusion radius (15 -> 30).
        is_nucleotide_R_element = torch.bmm(
            R_set_to_rep_atom,
            torch.bmm(atom_to_token, is_nucleotide_token.unsqueeze(-1).float()),
        ).squeeze(-1)
        cutoff = 15 + 15 * is_nucleotide_R_element.reshape(B, 1, -1).repeat(
            1, true_d.shape[1], 1
        )
        target_lddt, mask_no_match = lddt_dist(
            pred_d, true_d, pair_mask, cutoff, per_atom=True
        )
        return target_lddt, mask_no_match, atom_mask
205
+
206
+
207
def plddt_loss(
    pred_lddt,
    pred_atom_coords,
    feats,
    true_atom_coords,
    true_coords_resolved_mask,
    token_level_confidence=False,
    multiplicity=1,
    mask_loss=None,
    relative_confidence_supervision=False,
    relative_pred_lddt=None,
):
    """Cross-entropy loss between binned lDDT targets and the pLDDT logits.

    Targets come from :func:`get_target_lddt`. They are binned into
    ``num_bins = pred_lddt.shape[-1]`` uniform bins, averaged over resolved
    positions that have at least one scored neighbour, and then reduced over
    the batch.

    When ``relative_confidence_supervision`` is set, an extra loss is
    computed. It bins the signed lDDT difference between every pair of
    samples of the same input into ``2 * num_bins - 1`` classes and scores it
    against ``relative_pred_lddt``.

    Args:
        pred_lddt: pLDDT logits, with bins on the last dimension.
        pred_atom_coords: Predicted atom coordinates.
        feats: Feature dict forwarded to :func:`get_target_lddt`.
        true_atom_coords: Ground-truth atom coordinates.
        true_coords_resolved_mask: Resolved-atom mask of the true coordinates.
        token_level_confidence: Score per token instead of per atom.
        multiplicity: Number of samples per input.
        mask_loss: Optional per-example weights for the batch average.
        relative_confidence_supervision: Also compute the relative loss.
        relative_pred_lddt: Relative pLDDT logits. They are required when
            ``relative_confidence_supervision`` is set.

    Returns:
        Tuple ``(loss, rel_loss)``. ``rel_loss`` is the Python float ``0.0``
        unless relative supervision is enabled.

    Raises:
        ValueError: If ``relative_confidence_supervision`` is set but
            ``relative_pred_lddt`` is ``None``.
    """
    # Fail early with a clear message instead of a TypeError from
    # log_softmax(None) deep inside the relative branch.
    if relative_confidence_supervision and relative_pred_lddt is None:
        raise ValueError(
            "relative_pred_lddt is required when relative_confidence_supervision=True"
        )

    target_lddt, mask_no_match, atom_mask = get_target_lddt(
        pred_atom_coords=pred_atom_coords,
        feats=feats,
        true_atom_coords=true_atom_coords,
        true_coords_resolved_mask=true_coords_resolved_mask,
        token_level_confidence=token_level_confidence,
        multiplicity=multiplicity,
    )

    num_bins = pred_lddt.shape[-1]
    # Scores lie in [0, 1]. A perfect 1.0 would map to index num_bins, so it
    # is clamped into the last bin.
    bin_index = torch.floor(target_lddt * num_bins).long()
    bin_index = torch.clamp(bin_index, max=(num_bins - 1))
    lddt_one_hot = nn.functional.one_hot(bin_index, num_classes=num_bins)
    errors = -1 * torch.sum(
        lddt_one_hot * torch.nn.functional.log_softmax(pred_lddt, dim=-1),
        dim=-1,
    )
    # Token-level masks come back as (B, N_tokens, 1); atom-level masks are
    # already (B, N_atoms). Only drop a genuine trailing singleton, so a
    # one-atom input is not collapsed to (B,) and mis-broadcast against errors.
    if atom_mask.dim() == 3:
        atom_mask = atom_mask.squeeze(-1)
    loss = torch.sum(errors * atom_mask * mask_no_match, dim=-1) / (
        1e-7 + torch.sum(atom_mask * mask_no_match, dim=-1)
    )
    # Average over the batch dimension
    if mask_loss is not None:
        mask_loss = mask_loss.repeat_interleave(multiplicity, 0).reshape(
            -1, multiplicity
        )
        loss = torch.sum(loss.reshape(-1, multiplicity) * mask_loss) / (
            torch.sum(mask_loss) + 1e-7
        )
    else:
        loss = torch.mean(loss)

    rel_loss = 0.0
    if relative_confidence_supervision:
        # relative LDDT loss: pairwise differences between samples of one input
        B = true_atom_coords.shape[0]
        relative_target_lddt = target_lddt.view(
            B // multiplicity, multiplicity, 1, -1
        ) - target_lddt.view(B // multiplicity, 1, multiplicity, -1)
        # Signed binning: magnitude is binned, sign is reapplied, and the result
        # is shifted into [0, 2 * num_bins - 2].
        rel_bin_index = torch.floor(
            torch.abs(relative_target_lddt) * num_bins
        ).long() * torch.sign(relative_target_lddt)
        rel_bin_index = torch.clamp(
            rel_bin_index, max=(num_bins - 1), min=-(num_bins - 1)
        ).long() + (num_bins - 1)
        rel_lddt_one_hot = nn.functional.one_hot(
            rel_bin_index, num_classes=2 * num_bins - 1
        )
        rel_errors = -1 * torch.sum(
            rel_lddt_one_hot
            * torch.nn.functional.log_softmax(relative_pred_lddt, dim=-1),
            dim=-1,
        )
        rel_atom_mask = atom_mask.view(B // multiplicity, multiplicity, 1, -1).repeat(
            1, 1, multiplicity, 1
        )
        rel_mask_no_match = mask_no_match.view(
            B // multiplicity, multiplicity, 1, -1
        ).repeat(1, 1, multiplicity, 1)
        rel_loss = torch.sum(rel_errors * rel_atom_mask * rel_mask_no_match, dim=-1) / (
            1e-7 + torch.sum(rel_atom_mask * rel_mask_no_match, dim=-1)
        )

        if mask_loss is not None:
            # mask_loss was already reshaped to (B // multiplicity, multiplicity) above
            rel_mask_loss = mask_loss.view(B // multiplicity, multiplicity, 1).repeat(
                1, 1, multiplicity
            )
            rel_loss = torch.sum(rel_loss * rel_mask_loss) / (
                torch.sum(rel_mask_loss) + 1e-7
            )
        else:
            rel_loss = torch.mean(rel_loss)

    return loss, rel_loss
293
+
294
+
295
def lddt_dist(dmat_predicted, dmat_true, mask, cutoff=15.0, per_atom=False):
    """Score predicted distances against true distances with the lDDT metric.

    A pair is scored when its true distance is below ``cutoff`` and ``mask``
    is non-zero. Each scored pair earns the fraction of the thresholds
    (0.5, 1, 2, 4) that its absolute distance error falls under.

    Args:
        dmat_predicted: Predicted distance matrix.
        dmat_true: True distance matrix, with the same shape.
        mask: Pairwise weight mask. Identity pairs should already be masked
            out.
        cutoff: Inclusion radius on the true distance. It may be a scalar or a
            tensor broadcastable to ``dmat_true``.
        per_atom: Reduce over the last axis only (one score per row) instead
            of over the last two axes.

    Returns:
        If ``per_atom``: ``(score, mask_no_match)``, where ``mask_no_match``
        is 1.0 for rows with at least one scored pair and 0.0 otherwise.
        Otherwise: ``(score, total)``, where ``total`` is the summed pair
        weight. Rows or matrices with no scored pairs evaluate to a score of
        about 1.0 because of the ``1e-10`` terms.
    """
    # NOTE: the mask is a pairwise mask which should have the identity elements already masked out
    scored_pairs = (dmat_true < cutoff).float() * mask
    abs_err = (dmat_true - dmat_predicted).abs()
    pair_score = 0.25 * sum((abs_err < thr).float() for thr in (0.5, 1.0, 2.0, 4.0))

    reduce_dims = -1 if per_atom else (-2, -1)
    n_scored = scored_pairs.sum(dim=reduce_dims)
    weighted = (scored_pairs * pair_score).sum(dim=reduce_dims)
    score = (1.0 / (1e-10 + n_scored)) * (1e-10 + weighted)

    if per_atom:
        return score, (n_scored != 0).float()
    return score, n_scored
319
+
320
+
321
def express_coordinate_in_frame(atom_coords, frame_atom_a, frame_atom_b, frame_atom_c):
    """Express every frame-center atom in every token's local frame.

    For each token, the atoms at indices ``frame_atom_a``, ``frame_atom_b``
    and ``frame_atom_c`` define an orthonormal basis:

    * ``e1`` bisects the unit vectors b->a and b->c.
    * ``e2`` is along their difference.
    * ``e3`` is ``e1 x e2``.

    Output entry ``[..., i, j, :]`` is ``b_j - b_i`` projected onto token
    ``i``'s basis.

    Args:
        atom_coords: Coordinates indexed as ``[batch, sample, atom]``, so
            presumably ``(B, multiplicity, N_atoms, 3)``.
        frame_atom_a, frame_atom_b, frame_atom_c: Atom indices per token,
            presumably ``(B, multiplicity, N_tokens)``.

    Returns:
        Tensor of shape ``(B, multiplicity, N_tokens, N_tokens, 3)``.
    """
    n_batch, n_samples = atom_coords.shape[0], atom_coords.shape[1]
    device = atom_coords.device
    batch_idx = torch.arange(n_batch)[:, None, None].to(device)
    sample_idx = torch.arange(n_samples)[None, :, None].to(device)

    def _gather(atom_idx):
        # Advanced indexing picks one atom per (batch, sample, token).
        return atom_coords[batch_idx, sample_idx, atom_idx]

    def _unit(vec):
        # The 1e-5 guards against division by zero for degenerate frames.
        return vec / (torch.norm(vec, dim=-1, keepdim=True) + 1e-5)

    a = _gather(frame_atom_a)
    b = _gather(frame_atom_b)
    c = _gather(frame_atom_c)

    w1 = _unit(a - b)
    w2 = _unit(c - b)
    e1 = _unit(w1 + w2)
    e2 = _unit(w2 - w1)
    e3 = torch.linalg.cross(e1, e2)

    # NOTE: it is unclear based on what atom of the token the error is computed, here I will use the atom indicated by b (center of frame)
    offsets = b[:, :, None, :, :] - b[:, :, :, None, :]  # [..., i, j] = b_j - b_i
    components = [
        torch.sum(offsets * axis[:, :, :, None, :], dim=-1, keepdim=True)
        for axis in (e1, e2, e3)
    ]
    return torch.cat(components, dim=-1)
353
+
354
+
355
def get_target_pae(
    pred_atom_coords,
    feats,
    true_atom_coords,
    true_coords_resolved_mask,
    multiplicity=1,
):
    """Compute aligned-error targets and the pair mask for the PAE head.

    Frames for the true and the predicted structure are obtained from
    ``compute_frame_pred``. Per the inline comment, this re-derives
    nonpolymer frames after symmetry correction. Every token's frame-center
    atom is then expressed in every token's frame (see
    :func:`express_coordinate_in_frame`). The target is the Euclidean
    distance between the true and predicted projections.

    Args:
        pred_atom_coords: Predicted atom coordinates. They are reshaped to
            ``(B // multiplicity, multiplicity, -1, 3)``, so the leading
            dimension is ``B``.
        feats: Feature dict. Reads ``"frames_idx"``,
            ``"frame_resolved_mask"`` and ``"token_pad_mask"``, and is passed
            through to ``compute_frame_pred``.
        true_atom_coords: Ground-truth atom coordinates, in the same layout as
            ``pred_atom_coords``.
        true_coords_resolved_mask: Resolved-atom mask of the true coordinates.
        multiplicity: Number of samples per input.

    Returns:
        Tuple ``(target_pae, pair_mask)``, presumably both shaped
        ``(B // multiplicity, multiplicity, N_tokens, N_tokens)``. Confirm
        against ``compute_frame_pred``'s output layout.
    """
    # torch.cuda.amp.autocast(...) is deprecated in recent PyTorch releases;
    # torch.autocast("cuda", ...) is the supported equivalent.
    with torch.autocast("cuda", enabled=False):
        # Retrieve frames and resolved masks
        frames_idx_original = feats["frames_idx"]
        mask_frame_true = feats["frame_resolved_mask"]

        # Adjust the frames for nonpolymers after symmetry correction!
        # NOTE: frames of polymers do not change under symmetry!
        frames_idx_true, mask_collinear_true = compute_frame_pred(
            true_atom_coords,
            frames_idx_original,
            feats,
            multiplicity,
            resolved_mask=true_coords_resolved_mask,
        )

        frame_true_atom_a, frame_true_atom_b, frame_true_atom_c = (
            frames_idx_true[:, :, :, 0],
            frames_idx_true[:, :, :, 1],
            frames_idx_true[:, :, :, 2],
        )
        # Compute token coords in true frames
        B, N, _ = true_atom_coords.shape
        true_atom_coords = true_atom_coords.reshape(
            B // multiplicity, multiplicity, -1, 3
        )
        true_coords_transformed = express_coordinate_in_frame(
            true_atom_coords, frame_true_atom_a, frame_true_atom_b, frame_true_atom_c
        )

        # Compute pred frames and mask
        frames_idx_pred, mask_collinear_pred = compute_frame_pred(
            pred_atom_coords, frames_idx_original, feats, multiplicity
        )
        frame_pred_atom_a, frame_pred_atom_b, frame_pred_atom_c = (
            frames_idx_pred[:, :, :, 0],
            frames_idx_pred[:, :, :, 1],
            frames_idx_pred[:, :, :, 2],
        )
        # Compute token coords in pred frames
        B, N, _ = pred_atom_coords.shape
        pred_atom_coords = pred_atom_coords.reshape(
            B // multiplicity, multiplicity, -1, 3
        )
        pred_coords_transformed = express_coordinate_in_frame(
            pred_atom_coords, frame_pred_atom_a, frame_pred_atom_b, frame_pred_atom_c
        )

        # The small epsilon keeps sqrt differentiable at zero error.
        target_pae = torch.sqrt(
            ((true_coords_transformed - pred_coords_transformed) ** 2).sum(-1) + 1e-8
        )

        # Compute mask for the pae loss
        # NOTE(review): this selects rows 0..B//multiplicity-1 of
        # true_coords_resolved_mask, while the coordinates above are laid out
        # as (B // multiplicity, multiplicity, ...). If the mask is expanded
        # with repeat_interleave like the coordinates, the rows only line up
        # for a single input structure per batch. Confirm before relying on
        # B // multiplicity > 1.
        b_true_resolved_mask = true_coords_resolved_mask[
            torch.arange(B // multiplicity)[:, None, None].to(
                pred_coords_transformed.device
            ),
            frame_true_atom_b,
        ]

        pair_mask = (
            mask_frame_true[:, None, :, None]  # if true frame is invalid
            * mask_collinear_true[:, :, :, None]  # if true frame is invalid (presumably collinear)
            * mask_collinear_pred[:, :, :, None]  # if pred frame is invalid (presumably collinear)
            * b_true_resolved_mask[:, :, None, :]  # If atom j is not resolved
            * feats["token_pad_mask"][:, None, :, None]
            * feats["token_pad_mask"][:, None, None, :]
        )
        return target_pae, pair_mask
430
+
431
+
432
def pae_loss(
    pred_pae,
    pred_atom_coords,
    feats,
    true_atom_coords,
    true_coords_resolved_mask,
    multiplicity=1,
    max_dist=32.0,
    mask_loss=None,
    relative_confidence_supervision=False,
    relative_pred_pae=None,
):
    """Cross-entropy loss between binned aligned-error targets and PAE logits.

    Targets from :func:`get_target_pae` are binned uniformly over
    ``[0, max_dist)`` into ``num_bins = pred_pae.shape[-1]`` bins. Values past
    ``max_dist`` are clamped into the last bin. The loss is averaged over
    masked token pairs, then over the batch.

    When ``relative_confidence_supervision`` is set, the signed PAE
    difference between every pair of samples is binned into
    ``2 * num_bins - 1`` classes and scored against ``relative_pred_pae``.

    Args:
        pred_pae: PAE logits, unpacked as ``(B, N, N, num_bins)``.
        pred_atom_coords: Predicted atom coordinates.
        feats: Feature dict forwarded to :func:`get_target_pae`.
        true_atom_coords: Ground-truth atom coordinates.
        true_coords_resolved_mask: Resolved-atom mask of the true coordinates.
        multiplicity: Number of samples per input.
        max_dist: Upper edge of the binned range.
        mask_loss: Optional per-example weights for the batch average.
        relative_confidence_supervision: Also compute the relative loss.
        relative_pred_pae: Relative PAE logits. They are required when
            ``relative_confidence_supervision`` is set.

    Returns:
        Tuple ``(loss, rel_loss)``. ``rel_loss`` is the Python float ``0.0``
        unless relative supervision is enabled.

    Raises:
        ValueError: If ``relative_confidence_supervision`` is set but
            ``relative_pred_pae`` is ``None``.
    """
    # Fail early with a clear message instead of a TypeError from
    # log_softmax(None) deep inside the relative branch.
    if relative_confidence_supervision and relative_pred_pae is None:
        raise ValueError(
            "relative_pred_pae is required when relative_confidence_supervision=True"
        )

    target_pae, pair_mask = get_target_pae(
        pred_atom_coords=pred_atom_coords,
        feats=feats,
        true_atom_coords=true_atom_coords,
        true_coords_resolved_mask=true_coords_resolved_mask,
        multiplicity=multiplicity,
    )

    # compute loss
    num_bins = pred_pae.shape[-1]
    bin_index = torch.floor(target_pae * num_bins / max_dist).long()
    bin_index = torch.clamp(bin_index, max=(num_bins - 1))
    pae_one_hot = nn.functional.one_hot(bin_index, num_classes=num_bins)
    # pred_pae is flat over (batch * samples); reshape it to the
    # (batch, samples, ...) layout of the targets.
    errors = -1 * torch.sum(
        pae_one_hot
        * torch.nn.functional.log_softmax(pred_pae.reshape(pae_one_hot.shape), dim=-1),
        dim=-1,
    )
    loss = torch.sum(errors * pair_mask, dim=(-2, -1)) / (
        1e-7 + torch.sum(pair_mask, dim=(-2, -1))
    )
    # Average over the batch dimension
    if mask_loss is not None:
        mask_loss = mask_loss.repeat_interleave(multiplicity, 0).reshape(
            -1, multiplicity
        )
        loss = torch.sum(loss.reshape(-1, multiplicity) * mask_loss) / (
            torch.sum(mask_loss) + 1e-7
        )
    else:
        loss = torch.mean(loss)

    rel_loss = 0.0
    if relative_confidence_supervision:
        B, N, _, _ = pred_pae.shape
        rel_target_pae = target_pae.view(
            B // multiplicity, multiplicity, 1, N, N
        ) - target_pae.view(B // multiplicity, 1, multiplicity, N, N)
        # Signed binning: magnitude is binned, sign is reapplied, and the result
        # is shifted into [0, 2 * num_bins - 2].
        rel_bin_index = torch.floor(
            torch.abs(rel_target_pae) * num_bins / max_dist
        ).long() * torch.sign(rel_target_pae)
        rel_bin_index = torch.clamp(
            rel_bin_index, max=(num_bins - 1), min=-(num_bins - 1)
        ).long() + (num_bins - 1)
        rel_pae_one_hot = nn.functional.one_hot(
            rel_bin_index, num_classes=2 * num_bins - 1
        )
        rel_errors = -1 * torch.sum(
            rel_pae_one_hot
            * torch.nn.functional.log_softmax(relative_pred_pae, dim=-1),
            dim=-1,
        )
        rel_mask = pair_mask.view(B // multiplicity, multiplicity, 1, N, N).repeat(
            1, 1, multiplicity, 1, 1
        )
        rel_loss = torch.sum(rel_errors * rel_mask, dim=(-2, -1)) / (
            1e-7 + torch.sum(rel_mask, dim=(-2, -1))
        )

        if mask_loss is not None:
            # mask_loss was already reshaped to (B // multiplicity, multiplicity) above
            rel_mask_loss = mask_loss.view(B // multiplicity, multiplicity, 1).repeat(
                1, 1, multiplicity
            )
            rel_loss = torch.sum(rel_loss * rel_mask_loss) / (
                torch.sum(rel_mask_loss) + 1e-7
            )
        else:
            rel_loss = torch.mean(rel_loss)

    return loss, rel_loss
514
+
515
+
516
def get_target_pde(
    pred_atom_coords,
    feats,
    true_atom_coords,
    true_coords_resolved_mask,
    multiplicity=1,
):
    """Compute distance-error targets between token pairs for the PDE head.

    Each token is represented by the atom selected in
    ``feats["token_to_rep_atom"]``. The target is the absolute difference
    between the true and the predicted inter-token distance.

    Args:
        pred_atom_coords: Predicted atom coordinates, used with ``torch.bmm``.
        feats: Feature dict. Reads ``"token_to_rep_atom"``, which is repeated
            ``multiplicity`` times along dim 0.
        true_atom_coords: Ground-truth atom coordinates, in the same layout as
            ``pred_atom_coords``.
        true_coords_resolved_mask: Resolved-atom mask of the true coordinates.
        multiplicity: Number of samples per input.

    Returns:
        Tuple ``(target_pde, mask)``. ``target_pde`` holds
        ``|d_true - d_pred|`` for every token pair. ``mask`` is the outer
        product of the per-token resolved mask; the diagonal is not excluded.
    """
    # torch.cuda.amp.autocast(...) is deprecated in recent PyTorch releases;
    # torch.autocast("cuda", ...) is the supported equivalent.
    with torch.autocast("cuda", enabled=False):
        # extract necessary features
        token_to_rep_atom = feats["token_to_rep_atom"]
        token_to_rep_atom = token_to_rep_atom.repeat_interleave(multiplicity, 0).float()
        token_mask = torch.bmm(
            token_to_rep_atom, true_coords_resolved_mask.unsqueeze(-1).float()
        ).squeeze(-1)
        mask = token_mask.unsqueeze(-1) * token_mask.unsqueeze(-2)

        # compute true pde
        true_token_coords = torch.bmm(token_to_rep_atom, true_atom_coords)
        pred_token_coords = torch.bmm(token_to_rep_atom, pred_atom_coords)

        true_d = torch.cdist(true_token_coords, true_token_coords)
        pred_d = torch.cdist(pred_token_coords, pred_token_coords)
        target_pde = torch.abs(true_d - pred_d)
        return target_pde, mask
540
+
541
+
542
def pde_loss(
    pred_pde,
    pred_atom_coords,
    feats,
    true_atom_coords,
    true_coords_resolved_mask,
    multiplicity=1,
    max_dist=32.0,
    mask_loss=None,
    relative_confidence_supervision=False,
    relative_pred_pde=None,
):
    """Cross-entropy loss between binned distance-error targets and PDE logits.

    Targets from :func:`get_target_pde` are binned uniformly over
    ``[0, max_dist)`` into ``num_bins = pred_pde.shape[-1]`` bins. Values past
    ``max_dist`` are clamped into the last bin. The loss is averaged over
    masked token pairs, then over the batch.

    When ``relative_confidence_supervision`` is set, the signed PDE
    difference between every pair of samples is binned into
    ``2 * num_bins - 1`` classes and scored against ``relative_pred_pde``.

    Args:
        pred_pde: PDE logits, with bins on the last dimension.
        pred_atom_coords: Predicted atom coordinates.
        feats: Feature dict forwarded to :func:`get_target_pde`.
        true_atom_coords: Ground-truth atom coordinates.
        true_coords_resolved_mask: Resolved-atom mask of the true coordinates.
        multiplicity: Number of samples per input.
        max_dist: Upper edge of the binned range.
        mask_loss: Optional per-example weights for the batch average.
        relative_confidence_supervision: Also compute the relative loss.
        relative_pred_pde: Relative PDE logits. They are required when
            ``relative_confidence_supervision`` is set.

    Returns:
        Tuple ``(loss, rel_loss)``. ``rel_loss`` is the Python float ``0.0``
        unless relative supervision is enabled.

    Raises:
        ValueError: If ``relative_confidence_supervision`` is set but
            ``relative_pred_pde`` is ``None``.
    """
    # Fail early with a clear message instead of a TypeError from
    # log_softmax(None) deep inside the relative branch.
    if relative_confidence_supervision and relative_pred_pde is None:
        raise ValueError(
            "relative_pred_pde is required when relative_confidence_supervision=True"
        )

    target_pde, mask = get_target_pde(
        pred_atom_coords=pred_atom_coords,
        feats=feats,
        true_atom_coords=true_atom_coords,
        true_coords_resolved_mask=true_coords_resolved_mask,
        multiplicity=multiplicity,
    )
    # compute loss
    num_bins = pred_pde.shape[-1]
    bin_index = torch.floor(target_pde * num_bins / max_dist).long()
    bin_index = torch.clamp(bin_index, max=(num_bins - 1))
    pde_one_hot = nn.functional.one_hot(bin_index, num_classes=num_bins)
    errors = -1 * torch.sum(
        pde_one_hot * torch.nn.functional.log_softmax(pred_pde, dim=-1),
        dim=-1,
    )
    loss = torch.sum(errors * mask, dim=(-2, -1)) / (
        1e-7 + torch.sum(mask, dim=(-2, -1))
    )
    # Average over the batch dimension
    if mask_loss is not None:
        mask_loss = mask_loss.repeat_interleave(multiplicity, 0).reshape(
            -1, multiplicity
        )
        loss = torch.sum(loss.reshape(-1, multiplicity) * mask_loss) / (
            torch.sum(mask_loss) + 1e-7
        )
    else:
        loss = torch.mean(loss)

    rel_loss = 0.0
    if relative_confidence_supervision:
        B, N = target_pde.shape[:2]
        rel_target_pde = target_pde.view(
            B // multiplicity, multiplicity, 1, N, N
        ) - target_pde.view(B // multiplicity, 1, multiplicity, N, N)
        # Signed binning: magnitude is binned, sign is reapplied, and the result
        # is shifted into [0, 2 * num_bins - 2].
        rel_bin_index = torch.floor(
            torch.abs(rel_target_pde) * num_bins / max_dist
        ).long() * torch.sign(rel_target_pde)
        rel_bin_index = torch.clamp(
            rel_bin_index, max=(num_bins - 1), min=-(num_bins - 1)
        ).long() + (num_bins - 1)
        rel_pde_one_hot = nn.functional.one_hot(
            rel_bin_index, num_classes=2 * num_bins - 1
        )
        rel_errors = -1 * torch.sum(
            rel_pde_one_hot
            * torch.nn.functional.log_softmax(relative_pred_pde, dim=-1),
            dim=-1,
        )
        rel_mask = mask.view(B // multiplicity, multiplicity, 1, N, N).repeat(
            1, 1, multiplicity, 1, 1
        )
        rel_loss = torch.sum(rel_errors * rel_mask, dim=(-2, -1)) / (
            1e-7 + torch.sum(rel_mask, dim=(-2, -1))
        )

        if mask_loss is not None:
            # mask_loss was already reshaped to (B // multiplicity, multiplicity) above
            rel_mask_loss = mask_loss.view(B // multiplicity, multiplicity, 1).repeat(
                1, 1, multiplicity
            )
            rel_loss = torch.sum(rel_loss * rel_mask_loss) / (
                torch.sum(rel_mask_loss) + 1e-7
            )
        else:
            rel_loss = torch.mean(rel_loss)

    return loss, rel_loss