boltz-vsynthes 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (112)
  1. boltz/__init__.py +7 -0
  2. boltz/data/__init__.py +0 -0
  3. boltz/data/const.py +1184 -0
  4. boltz/data/crop/__init__.py +0 -0
  5. boltz/data/crop/affinity.py +164 -0
  6. boltz/data/crop/boltz.py +296 -0
  7. boltz/data/crop/cropper.py +45 -0
  8. boltz/data/feature/__init__.py +0 -0
  9. boltz/data/feature/featurizer.py +1230 -0
  10. boltz/data/feature/featurizerv2.py +2208 -0
  11. boltz/data/feature/symmetry.py +602 -0
  12. boltz/data/filter/__init__.py +0 -0
  13. boltz/data/filter/dynamic/__init__.py +0 -0
  14. boltz/data/filter/dynamic/date.py +76 -0
  15. boltz/data/filter/dynamic/filter.py +24 -0
  16. boltz/data/filter/dynamic/max_residues.py +37 -0
  17. boltz/data/filter/dynamic/resolution.py +34 -0
  18. boltz/data/filter/dynamic/size.py +38 -0
  19. boltz/data/filter/dynamic/subset.py +42 -0
  20. boltz/data/filter/static/__init__.py +0 -0
  21. boltz/data/filter/static/filter.py +26 -0
  22. boltz/data/filter/static/ligand.py +37 -0
  23. boltz/data/filter/static/polymer.py +299 -0
  24. boltz/data/module/__init__.py +0 -0
  25. boltz/data/module/inference.py +307 -0
  26. boltz/data/module/inferencev2.py +429 -0
  27. boltz/data/module/training.py +684 -0
  28. boltz/data/module/trainingv2.py +660 -0
  29. boltz/data/mol.py +900 -0
  30. boltz/data/msa/__init__.py +0 -0
  31. boltz/data/msa/mmseqs2.py +235 -0
  32. boltz/data/pad.py +84 -0
  33. boltz/data/parse/__init__.py +0 -0
  34. boltz/data/parse/a3m.py +134 -0
  35. boltz/data/parse/csv.py +100 -0
  36. boltz/data/parse/fasta.py +138 -0
  37. boltz/data/parse/mmcif.py +1239 -0
  38. boltz/data/parse/mmcif_with_constraints.py +1607 -0
  39. boltz/data/parse/schema.py +1851 -0
  40. boltz/data/parse/yaml.py +68 -0
  41. boltz/data/sample/__init__.py +0 -0
  42. boltz/data/sample/cluster.py +283 -0
  43. boltz/data/sample/distillation.py +57 -0
  44. boltz/data/sample/random.py +39 -0
  45. boltz/data/sample/sampler.py +49 -0
  46. boltz/data/tokenize/__init__.py +0 -0
  47. boltz/data/tokenize/boltz.py +195 -0
  48. boltz/data/tokenize/boltz2.py +396 -0
  49. boltz/data/tokenize/tokenizer.py +24 -0
  50. boltz/data/types.py +777 -0
  51. boltz/data/write/__init__.py +0 -0
  52. boltz/data/write/mmcif.py +305 -0
  53. boltz/data/write/pdb.py +171 -0
  54. boltz/data/write/utils.py +23 -0
  55. boltz/data/write/writer.py +330 -0
  56. boltz/main.py +1292 -0
  57. boltz/model/__init__.py +0 -0
  58. boltz/model/layers/__init__.py +0 -0
  59. boltz/model/layers/attention.py +132 -0
  60. boltz/model/layers/attentionv2.py +111 -0
  61. boltz/model/layers/confidence_utils.py +231 -0
  62. boltz/model/layers/dropout.py +34 -0
  63. boltz/model/layers/initialize.py +100 -0
  64. boltz/model/layers/outer_product_mean.py +98 -0
  65. boltz/model/layers/pair_averaging.py +135 -0
  66. boltz/model/layers/pairformer.py +337 -0
  67. boltz/model/layers/relative.py +58 -0
  68. boltz/model/layers/transition.py +78 -0
  69. boltz/model/layers/triangular_attention/__init__.py +0 -0
  70. boltz/model/layers/triangular_attention/attention.py +189 -0
  71. boltz/model/layers/triangular_attention/primitives.py +409 -0
  72. boltz/model/layers/triangular_attention/utils.py +380 -0
  73. boltz/model/layers/triangular_mult.py +212 -0
  74. boltz/model/loss/__init__.py +0 -0
  75. boltz/model/loss/bfactor.py +49 -0
  76. boltz/model/loss/confidence.py +590 -0
  77. boltz/model/loss/confidencev2.py +621 -0
  78. boltz/model/loss/diffusion.py +171 -0
  79. boltz/model/loss/diffusionv2.py +134 -0
  80. boltz/model/loss/distogram.py +48 -0
  81. boltz/model/loss/distogramv2.py +105 -0
  82. boltz/model/loss/validation.py +1025 -0
  83. boltz/model/models/__init__.py +0 -0
  84. boltz/model/models/boltz1.py +1286 -0
  85. boltz/model/models/boltz2.py +1249 -0
  86. boltz/model/modules/__init__.py +0 -0
  87. boltz/model/modules/affinity.py +223 -0
  88. boltz/model/modules/confidence.py +481 -0
  89. boltz/model/modules/confidence_utils.py +181 -0
  90. boltz/model/modules/confidencev2.py +495 -0
  91. boltz/model/modules/diffusion.py +844 -0
  92. boltz/model/modules/diffusion_conditioning.py +116 -0
  93. boltz/model/modules/diffusionv2.py +677 -0
  94. boltz/model/modules/encoders.py +639 -0
  95. boltz/model/modules/encodersv2.py +565 -0
  96. boltz/model/modules/transformers.py +322 -0
  97. boltz/model/modules/transformersv2.py +261 -0
  98. boltz/model/modules/trunk.py +688 -0
  99. boltz/model/modules/trunkv2.py +828 -0
  100. boltz/model/modules/utils.py +303 -0
  101. boltz/model/optim/__init__.py +0 -0
  102. boltz/model/optim/ema.py +389 -0
  103. boltz/model/optim/scheduler.py +99 -0
  104. boltz/model/potentials/__init__.py +0 -0
  105. boltz/model/potentials/potentials.py +497 -0
  106. boltz/model/potentials/schedules.py +32 -0
  107. boltz_vsynthes-1.0.0.dist-info/METADATA +151 -0
  108. boltz_vsynthes-1.0.0.dist-info/RECORD +112 -0
  109. boltz_vsynthes-1.0.0.dist-info/WHEEL +5 -0
  110. boltz_vsynthes-1.0.0.dist-info/entry_points.txt +2 -0
  111. boltz_vsynthes-1.0.0.dist-info/licenses/LICENSE +21 -0
  112. boltz_vsynthes-1.0.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,590 @@
1
+ import torch
2
+ from torch import nn
3
+
4
+ from boltz.data import const
5
+
6
+
7
def confidence_loss(
    model_out,
    feats,
    true_coords,
    true_coords_resolved_mask,
    multiplicity=1,
    alpha_pae=0.0,
):
    """Combine the confidence-head losses into one training objective.

    The total is ``plddt + pde + resolved + alpha_pae * pae``. The PAE term
    is only evaluated when ``alpha_pae`` is strictly positive; otherwise it
    is reported as ``0.0``.

    Parameters
    ----------
    model_out: Dict[str, torch.Tensor]
        Model outputs. Reads ``plddt_logits``, ``pde_logits``,
        ``resolved_logits``, ``sample_atom_coords`` and, when
        ``alpha_pae > 0``, ``pae_logits``.
    feats: Dict[str, torch.Tensor]
        Dictionary containing the model input.
    true_coords: torch.Tensor
        The atom coordinates after symmetry correction.
    true_coords_resolved_mask: torch.Tensor
        The resolved mask after symmetry correction.
    multiplicity: int, optional
        The diffusion batch size, by default 1.
    alpha_pae: float, optional
        The weight of the pae loss, by default 0.0.

    Returns
    -------
    Dict[str, torch.Tensor]
        ``{"loss": total, "loss_breakdown": {per-term losses}}``.

    """
    sampled_coords = model_out["sample_atom_coords"]

    plddt = plddt_loss(
        model_out["plddt_logits"],
        sampled_coords,
        true_coords,
        true_coords_resolved_mask,
        feats,
        multiplicity=multiplicity,
    )
    pde = pde_loss(
        model_out["pde_logits"],
        sampled_coords,
        true_coords,
        true_coords_resolved_mask,
        feats,
        multiplicity,
    )
    resolved = resolved_loss(
        model_out["resolved_logits"],
        feats,
        true_coords_resolved_mask,
        multiplicity=multiplicity,
    )

    # The PAE term is optional: only evaluate it when it is weighted in.
    if alpha_pae > 0.0:
        pae = pae_loss(
            model_out["pae_logits"],
            sampled_coords,
            true_coords,
            true_coords_resolved_mask,
            feats,
            multiplicity,
        )
    else:
        pae = 0.0

    breakdown = {
        "plddt_loss": plddt,
        "pde_loss": pde,
        "resolved_loss": resolved,
        "pae_loss": pae,
    }
    return {
        "loss": plddt + pde + resolved + alpha_pae * pae,
        "loss_breakdown": breakdown,
    }
85
+
86
+
87
def resolved_loss(
    pred_resolved,
    feats,
    true_coords_resolved_mask,
    multiplicity=1,
):
    """Two-class cross-entropy on the per-token "resolved" head.

    The atom-level resolved mask is projected onto tokens through
    ``feats["token_to_rep_atom"]``. Logit channel 0 is scored against that
    projected mask and channel 1 against its complement. Padding tokens are
    excluded via ``feats["token_pad_mask"]``.

    Parameters
    ----------
    pred_resolved: torch.Tensor
        The resolved logits; the last axis holds the two classes.
    feats: Dict[str, torch.Tensor]
        Dictionary containing the model input (``token_to_rep_atom``,
        ``token_pad_mask``).
    true_coords_resolved_mask: torch.Tensor
        The resolved mask after symmetry correction.
    multiplicity: int, optional
        The diffusion batch size, by default 1.

    Returns
    -------
    torch.Tensor
        Resolved loss, averaged over the (expanded) batch.

    """
    # Per-token target obtained from each token's representative atom.
    rep_atom = feats["token_to_rep_atom"].repeat_interleave(multiplicity, 0).float()
    token_resolved = torch.bmm(
        rep_atom, true_coords_resolved_mask.unsqueeze(-1).float()
    ).squeeze(-1)
    token_pad = feats["token_pad_mask"].repeat_interleave(multiplicity, 0).float()

    log_probs = torch.nn.functional.log_softmax(pred_resolved, dim=-1)
    nll = (
        -token_resolved * log_probs[:, :, 0]
        - (1 - token_resolved) * log_probs[:, :, 1]
    )

    per_example = torch.sum(nll * token_pad, dim=-1) / (
        1e-7 + torch.sum(token_pad, dim=-1)
    )
    return torch.mean(per_example)
134
+
135
+
136
def plddt_loss(
    pred_lddt,
    pred_atom_coords,
    true_atom_coords,
    true_coords_resolved_mask,
    feats,
    multiplicity=1,
):
    """Cross-entropy between binned per-token lDDT targets and the pLDDT head.

    For every token, distances from its representative atom to the R-set
    atoms are compared between the reference and the predicted structure
    with ``lddt_dist``. The resulting per-token lDDT is discretised into
    ``pred_lddt.shape[-1]`` bins. R-set elements that belong to DNA/RNA
    tokens use a distance cutoff of 30 instead of 15.

    Parameters
    ----------
    pred_lddt: torch.Tensor
        The plddt logits; last axis is the bin axis.
    pred_atom_coords: torch.Tensor
        The predicted atom coordinates.
    true_atom_coords: torch.Tensor
        The atom coordinates after symmetry correction.
    true_coords_resolved_mask: torch.Tensor
        The resolved mask after symmetry correction.
    feats: Dict[str, torch.Tensor]
        Dictionary containing the model input (``r_set_to_rep_atom``,
        ``mol_type``, ``atom_to_token``, ``token_to_rep_atom``).
    multiplicity: int, optional
        The diffusion batch size, by default 1.

    Returns
    -------
    torch.Tensor
        Plddt loss, averaged over the (expanded) batch.

    """
    batch = true_atom_coords.shape[0]

    # Expand per-structure features to one copy per diffusion sample.
    r_set = feats["r_set_to_rep_atom"].repeat_interleave(multiplicity, 0).float()
    mol_type = feats["mol_type"].repeat_interleave(multiplicity, 0)
    atom_to_token = feats["atom_to_token"].float().repeat_interleave(multiplicity, 0)
    rep_atom = feats["token_to_rep_atom"].float().repeat_interleave(multiplicity, 0)

    # 1.0 for DNA/RNA tokens, 0.0 otherwise.
    nucleotide_token = (mol_type == const.chain_type_ids["DNA"]).float() + (
        mol_type == const.chain_type_ids["RNA"]
    ).float()

    # Token-representative-atom -> R-set-atom distance matrices.
    true_d = torch.cdist(
        torch.bmm(rep_atom, true_atom_coords),
        torch.bmm(r_set, true_atom_coords),
    )
    pred_d = torch.cdist(
        torch.bmm(rep_atom, pred_atom_coords),
        torch.bmm(r_set, pred_atom_coords),
    )

    # Atom pairs count only if both atoms are resolved and distinct; the
    # pair mask is then lifted to (token, R-set element) pairs.
    resolved = true_coords_resolved_mask
    pair_mask = resolved.unsqueeze(-1) * resolved.unsqueeze(-2)
    not_self = 1 - torch.eye(pair_mask.shape[1], device=pair_mask.device)
    pair_mask = pair_mask * not_self[None, :, :]
    pair_mask = torch.einsum("bnm,bkm->bnk", pair_mask, r_set)
    pair_mask = torch.bmm(rep_atom, pair_mask)
    token_resolved = torch.bmm(rep_atom, resolved.unsqueeze(-1).float())

    # Nucleotide R-set elements get a doubled inclusion radius.
    nucleotide_r = torch.bmm(
        r_set, torch.bmm(atom_to_token, nucleotide_token.unsqueeze(-1))
    ).squeeze(-1)
    cutoff = 15 + 15 * nucleotide_r.reshape(batch, 1, -1).repeat(
        1, true_d.shape[1], 1
    )

    target_lddt, has_neighbours = lddt_dist(
        pred_d, true_d, pair_mask, cutoff, per_atom=True
    )

    # Bin the [0, 1] lDDT target; the top value falls into the last bin.
    n_bins = pred_lddt.shape[-1]
    target_bin = torch.clamp(
        torch.floor(target_lddt * n_bins).long(), max=(n_bins - 1)
    )
    one_hot = nn.functional.one_hot(target_bin, num_classes=n_bins)
    nll = -1 * torch.sum(
        one_hot * torch.nn.functional.log_softmax(pred_lddt, dim=-1),
        dim=-1,
    )

    token_resolved = token_resolved.squeeze(-1)
    per_example = torch.sum(nll * token_resolved * has_neighbours, dim=-1) / (
        1e-7 + torch.sum(token_resolved * has_neighbours, dim=-1)
    )
    return torch.mean(per_example)
240
+
241
+
242
def pde_loss(
    pred_pde,
    pred_atom_coords,
    true_atom_coords,
    true_coords_resolved_mask,
    feats,
    multiplicity=1,
    max_dist=32.0,
):
    """Cross-entropy between binned pairwise distance errors and the PDE head.

    The target for token pair (i, j) is ``|d_true(i, j) - d_pred(i, j)|``,
    where distances are taken between the tokens' representative atoms. It
    is binned into ``pred_pde.shape[-1]`` equal bins over ``[0, max_dist]``,
    and larger errors are clamped into the last bin. Only pairs whose
    representative atoms are both resolved contribute.

    Parameters
    ----------
    pred_pde: torch.Tensor
        The pde logits; last axis is the bin axis.
    pred_atom_coords: torch.Tensor
        The predicted atom coordinates.
    true_atom_coords: torch.Tensor
        The atom coordinates after symmetry correction.
    true_coords_resolved_mask: torch.Tensor
        The resolved mask after symmetry correction.
    feats: Dict[str, torch.Tensor]
        Dictionary containing the model input (``token_to_rep_atom``).
    multiplicity: int, optional
        The diffusion batch size, by default 1.
    max_dist: float, optional
        Upper edge of the binned range, by default 32.0.

    Returns
    -------
    torch.Tensor
        Pde loss, averaged over the (expanded) batch.

    """
    rep_atom = feats["token_to_rep_atom"].repeat_interleave(multiplicity, 0).float()

    token_resolved = torch.bmm(
        rep_atom, true_coords_resolved_mask.unsqueeze(-1).float()
    ).squeeze(-1)
    pair_weight = token_resolved.unsqueeze(-1) * token_resolved.unsqueeze(-2)

    def _token_distances(coords):
        # Pairwise distances between the tokens' representative atoms.
        token_coords = torch.bmm(rep_atom, coords)
        return torch.cdist(token_coords, token_coords)

    target_pde = torch.abs(
        _token_distances(true_atom_coords) - _token_distances(pred_atom_coords)
    )

    n_bins = pred_pde.shape[-1]
    target_bin = torch.clamp(
        torch.floor(target_pde * n_bins / max_dist).long(), max=(n_bins - 1)
    )
    one_hot = nn.functional.one_hot(target_bin, num_classes=n_bins)
    nll = -1 * torch.sum(
        one_hot * torch.nn.functional.log_softmax(pred_pde, dim=-1),
        dim=-1,
    )

    per_example = torch.sum(nll * pair_weight, dim=(-2, -1)) / (
        1e-7 + torch.sum(pair_weight, dim=(-2, -1))
    )
    return torch.mean(per_example)
308
+
309
+
310
def pae_loss(
    pred_pae,
    pred_atom_coords,
    true_atom_coords,
    true_coords_resolved_mask,
    feats,
    multiplicity=1,
    max_dist=32.0,
):
    """Compute pae loss.

    For every token pair (i, j), the target is the distance between token
    j's frame-centre position expressed in token i's true frame and the
    same quantity expressed in token i's predicted frame. The target is
    binned into ``pred_pae.shape[-1]`` equal bins over ``[0, max_dist]``
    (larger errors are clamped into the last bin) and scored with a
    cross-entropy against ``pred_pae``.

    Parameters
    ----------
    pred_pae: torch.Tensor
        The pae logits. They are reshaped to
        (n_struct, multiplicity, n_frames, n_frames, num_bins) before the
        softmax.
    pred_atom_coords: torch.Tensor
        The predicted atom coordinates, (n_struct * multiplicity, n_atoms, 3).
    true_atom_coords: torch.Tensor
        The atom coordinates after symmetry correction, same shape as
        ``pred_atom_coords``.
    true_coords_resolved_mask: torch.Tensor
        The resolved mask after symmetry correction. It may have one row per
        diffusion sample, (n_struct * multiplicity, n_atoms), laid out like
        ``repeat_interleave(multiplicity, 0)``. It may instead have one row
        per structure, (n_struct, n_atoms), which is then shared by all
        samples.
    feats: Dict[str, torch.Tensor]
        Dictionary containing the model input. Reads ``frames_idx``,
        ``frame_resolved_mask`` and ``token_pad_mask`` here, plus the keys
        read by ``compute_frame_pred``.
    multiplicity: int, optional
        The diffusion batch size, by default 1.
    max_dist: float, optional
        Upper edge of the binned range, by default 32.0.

    Returns
    -------
    torch.Tensor
        Pae loss, averaged over structures and diffusion samples.

    """
    B, _, _ = true_atom_coords.shape
    n_struct = B // multiplicity

    # Retrieve frames and resolved masks
    frames_idx_original = feats["frames_idx"]
    mask_frame_true = feats["frame_resolved_mask"]

    # Adjust the frames for nonpolymers after symmetry correction!
    # NOTE: frames of polymers do not change under symmetry!
    frames_idx_true, mask_collinear_true = compute_frame_pred(
        true_atom_coords,
        frames_idx_original,
        feats,
        multiplicity,
        resolved_mask=true_coords_resolved_mask,
    )
    frame_true_atom_a, frame_true_atom_b, frame_true_atom_c = (
        frames_idx_true[:, :, :, 0],
        frames_idx_true[:, :, :, 1],
        frames_idx_true[:, :, :, 2],
    )
    # Compute token coords in true frames
    true_coords_transformed = express_coordinate_in_frame(
        true_atom_coords.reshape(n_struct, multiplicity, -1, 3),
        frame_true_atom_a,
        frame_true_atom_b,
        frame_true_atom_c,
    )

    # Compute pred frames and mask
    frames_idx_pred, mask_collinear_pred = compute_frame_pred(
        pred_atom_coords, frames_idx_original, feats, multiplicity
    )
    frame_pred_atom_a, frame_pred_atom_b, frame_pred_atom_c = (
        frames_idx_pred[:, :, :, 0],
        frames_idx_pred[:, :, :, 1],
        frames_idx_pred[:, :, :, 2],
    )
    # Compute token coords in pred frames
    pred_coords_transformed = express_coordinate_in_frame(
        pred_atom_coords.reshape(n_struct, multiplicity, -1, 3),
        frame_pred_atom_a,
        frame_pred_atom_b,
        frame_pred_atom_c,
    )

    target_pae = torch.sqrt(
        ((true_coords_transformed - pred_coords_transformed) ** 2).sum(-1) + 1e-8
    )

    # Resolved flag of each true frame's centre atom, read from the row of
    # the same (structure, diffusion sample) the frame was built for.
    # Indexing a flat (n_struct * multiplicity) mask with a structure index
    # alone would read another sample's, or another structure's, mask.
    if true_coords_resolved_mask.shape[0] == n_struct:
        # One mask per structure: share it across all samples.
        resolved_per_sample = true_coords_resolved_mask[:, None, :].expand(
            -1, multiplicity, -1
        )
    else:
        # One mask per sample, in repeat_interleave order.
        resolved_per_sample = true_coords_resolved_mask.reshape(
            n_struct, multiplicity, -1
        )
    device = pred_coords_transformed.device
    b_true_resolved_mask = resolved_per_sample[
        torch.arange(n_struct, device=device)[:, None, None],
        torch.arange(multiplicity, device=device)[None, :, None],
        frame_true_atom_b,
    ]

    pair_mask = (
        mask_frame_true[:, None, :, None]  # if true frame is invalid
        * mask_collinear_true[:, :, :, None]  # if true frame is invalid
        * mask_collinear_pred[:, :, :, None]  # if pred frame is invalid
        * b_true_resolved_mask[:, :, None, :]  # If atom j is not resolved
        * feats["token_pad_mask"][:, None, :, None]
        * feats["token_pad_mask"][:, None, None, :]
    )

    # compute loss
    num_bins = pred_pae.shape[-1]
    bin_index = torch.floor(target_pae * num_bins / max_dist).long()
    bin_index = torch.clamp(bin_index, max=(num_bins - 1))
    pae_one_hot = nn.functional.one_hot(bin_index, num_classes=num_bins)
    errors = -1 * torch.sum(
        pae_one_hot
        * torch.nn.functional.log_softmax(pred_pae.reshape(pae_one_hot.shape), dim=-1),
        dim=-1,
    )
    loss = torch.sum(errors * pair_mask, dim=(-2, -1)) / (
        1e-7 + torch.sum(pair_mask, dim=(-2, -1))
    )
    # Average over the batch dimension
    loss = torch.mean(loss)

    return loss
422
+
423
+
424
def lddt_dist(dmat_predicted, dmat_true, mask, cutoff=15.0, per_atom=False):
    """Score distance agreement lDDT-style.

    Each scored pair earns 0.25 for every threshold (0.5, 1, 2, 4) that its
    absolute distance error falls below. A pair is scored when its true
    distance is below ``cutoff`` (a scalar or a broadcastable tensor),
    weighted by ``mask``.

    Parameters
    ----------
    dmat_predicted, dmat_true: torch.Tensor
        Predicted and reference distance matrices of the same shape.
    mask: torch.Tensor
        Pairwise weight mask; identity pairs should already be masked out.
    cutoff: float or torch.Tensor, optional
        Inclusion radius on the true distances, by default 15.0.
    per_atom: bool, optional
        If True, normalise along the last axis (one score per row).
        Otherwise normalise over the last two axes. By default False.

    Returns
    -------
    tuple
        ``(score, mask_no_match)`` when ``per_atom``, where ``mask_no_match``
        is 1.0 for rows with at least one scored pair. Otherwise
        ``(score, total)``, where ``total`` is the summed pair weight.

    """
    # NOTE: the mask is a pairwise mask which should have the identity elements already masked out
    scored = (dmat_true < cutoff).float() * mask
    abs_err = torch.abs(dmat_true - dmat_predicted)
    pair_score = 0.25 * sum((abs_err < t).float() for t in (0.5, 1.0, 2.0, 4.0))

    if not per_atom:
        n_scored = torch.sum(scored, dim=(-2, -1))
        score = (1.0 / (1e-10 + n_scored)) * (
            1e-10 + torch.sum(scored * pair_score, dim=(-2, -1))
        )
        return score, torch.sum(scored, dim=(-1, -2))

    n_scored = torch.sum(scored, dim=-1)
    has_pairs = n_scored != 0
    score = (1.0 / (1e-10 + n_scored)) * (
        1e-10 + torch.sum(scored * pair_score, dim=-1)
    )
    return score, has_pairs.float()
448
+
449
+
450
def express_coordinate_in_frame(atom_coords, frame_atom_a, frame_atom_b, frame_atom_c):
    """Express every frame centre in every token's local frame.

    ``atom_coords`` is indexed as (batch, multiplicity, atom, xyz), and the
    ``frame_atom_*`` tensors hold atom indices per (batch, multiplicity,
    token). For token i the basis is built from the unit vectors
    w1 = a - b and w2 = c - b (b is the centre): e1 ~ w1 + w2,
    e2 ~ w2 - w1, e3 = e1 x e2.

    Returns
    -------
    torch.Tensor
        Shape (batch, multiplicity, n_tokens, n_tokens, 3). Entry [.., i, j, :]
        is ``b_j - b_i`` projected onto token i's (e1, e2, e3).

    """
    n_batch, n_mult = atom_coords.shape[0], atom_coords.shape[1]
    device = atom_coords.device
    idx_batch = torch.arange(n_batch)[:, None, None].to(device)
    idx_mult = torch.arange(n_mult)[None, :, None].to(device)

    def _gather(frame_atom):
        # Coordinates of the given frame atom for every token.
        return atom_coords[idx_batch, idx_mult, frame_atom]

    def _unit(vec):
        # Normalise along the xyz axis with a small stabilising epsilon.
        return vec / (torch.norm(vec, dim=-1, keepdim=True) + 1e-5)

    a = _gather(frame_atom_a)
    b = _gather(frame_atom_b)
    c = _gather(frame_atom_c)

    w1 = _unit(a - b)
    w2 = _unit(c - b)
    e1 = _unit(w1 + w2)
    e2 = _unit(w2 - w1)
    e3 = torch.linalg.cross(e1, e2)

    offsets = b[:, :, None, :, :] - b[:, :, :, None, :]
    return torch.stack(
        [torch.sum(offsets * axis[:, :, :, None, :], dim=-1) for axis in (e1, e2, e3)],
        dim=-1,
    )
480
+
481
+
482
def compute_collinear_mask(v1, v2):
    """Flag row-wise vector pairs that can span a well-defined frame.

    A row is valid when both vectors are longer than 1e-2 and the absolute
    cosine between them is below 0.9063 (about cos 25 degrees), i.e. they
    are neither overlapping nor nearly collinear.

    Parameters
    ----------
    v1, v2: torch.Tensor
        Vectors of shape (n, 3); the norm is taken along axis 1.

    Returns
    -------
    torch.Tensor
        Boolean tensor of shape (n,).

    """
    len1 = torch.norm(v1, dim=1, keepdim=True)
    len2 = torch.norm(v2, dim=1, keepdim=True)
    cosine = torch.sum((v1 / (len1 + 1e-6)) * (v2 / (len2 + 1e-6)), dim=1)
    not_collinear = torch.abs(cosine) < 0.9063
    long_enough = (len1.reshape(-1) > 1e-2) & (len2.reshape(-1) > 1e-2)
    return not_collinear & long_enough
492
+
493
+
494
def compute_frame_pred(
    pred_atom_coords,
    frames_idx_true,
    feats,
    multiplicity,
    resolved_mask=None,
    inference=False,
):
    """Rebuild nonpolymer frames from coordinates and flag degenerate frames.

    Frames of polymer chains, and of chains with fewer than three atoms,
    are copied from ``frames_idx_true``. For every other nonpolymer chain,
    each atom's frame is rebuilt from the chain's atoms sorted by distance
    in ``pred_atom_coords``. Atoms are excluded from the search by pushing
    pairs involving a masked atom to +inf. The frame is
    ``(2nd closest, closest, 3rd closest)``; the closest is normally the
    atom itself.

    Parameters
    ----------
    pred_atom_coords: torch.Tensor
        Coordinates of shape (n_struct * multiplicity, n_atoms, 3).
    frames_idx_true: torch.Tensor
        Reference frame atom indices, one row per structure.
    feats: Dict[str, torch.Tensor]
        Reads ``asym_id``, ``atom_to_token``, ``token_pad_mask``,
        ``atom_pad_mask``, ``mol_type`` and, when needed,
        ``atom_resolved_mask``.
    multiplicity: int
        Number of diffusion samples per structure.
    resolved_mask: torch.Tensor, optional
        Atom mask used to exclude atoms when not in inference mode, by
        default ``feats["atom_resolved_mask"]``. It may have one row per
        structure, or one row per sample in ``repeat_interleave`` order; in
        the latter case each sample uses its own row.
    inference: bool, optional
        If True, exclude only padding atoms (``feats["atom_pad_mask"]``).

    Returns
    -------
    frames_idx_pred: torch.Tensor
        Shape (n_struct, multiplicity, n_frames, 3).
    mask: torch.Tensor
        Shape (n_struct, multiplicity, n_frames); nonzero where the frame is
        neither collinear nor overlapping and the token is not padding.

    """
    # extract necessary features
    asym_id_token = feats["asym_id"]
    asym_id_atom = torch.bmm(
        feats["atom_to_token"].float(), asym_id_token.unsqueeze(-1).float()
    ).squeeze(-1)
    B, _, _ = pred_atom_coords.shape
    n_struct = B // multiplicity
    pred_atom_coords = pred_atom_coords.reshape(n_struct, multiplicity, -1, 3)
    frames_idx_pred = (
        frames_idx_true.clone()
        .repeat_interleave(multiplicity, 0)
        .reshape(n_struct, multiplicity, -1, 3)
    )

    # Mask of atoms allowed in the neighbour search (loop-invariant).
    if inference:
        atom_ok = feats["atom_pad_mask"]
    elif resolved_mask is None:
        atom_ok = feats["atom_resolved_mask"]
    else:
        atom_ok = resolved_mask
    # Normalise to (n_struct, samples, n_atoms). A per-structure mask is
    # shared by all samples; a per-sample mask keeps each sample's own row.
    if atom_ok.shape[0] == n_struct:
        atom_ok = atom_ok[:, None, :]
    else:
        atom_ok = atom_ok.reshape(n_struct, multiplicity, -1)

    nonpolymer_id = const.chain_type_ids["NONPOLYMER"]

    # Iterate through the batch and update the frames for nonpolymers.
    # NOTE(review): token_idx / atom_idx are running offsets, so this assumes
    # chains are stored contiguously in ascending asym_id order.
    for i, pred_atom_coord in enumerate(pred_atom_coords):
        token_idx = 0
        atom_idx = 0
        for chain_id in torch.unique(asym_id_token[i]):
            mask_chain_token = (asym_id_token[i] == chain_id) * feats["token_pad_mask"][i]
            mask_chain_atom = (asym_id_atom[i] == chain_id) * feats["atom_pad_mask"][i]
            num_tokens = int(mask_chain_token.sum().item())
            num_atoms = int(mask_chain_atom.sum().item())
            if feats["mol_type"][i, token_idx] != nonpolymer_id or num_atoms < 3:
                token_idx += num_tokens
                atom_idx += num_atoms
                continue

            chain_atoms = mask_chain_atom.bool()
            chain_coords = pred_atom_coord[:, chain_atoms]
            dist_mat = (
                (chain_coords[:, None, :, :] - chain_coords[:, :, None, :]) ** 2
            ).sum(-1) ** 0.5

            # Pairs involving an excluded atom are pushed to +inf so they
            # sort last.
            chain_ok = atom_ok[i][:, chain_atoms]
            excluded_pair = 1 - (chain_ok[:, None, :] * chain_ok[:, :, None]).to(
                torch.float32
            )
            excluded_pair[excluded_pair == 1] = torch.inf
            indices = torch.sort(dist_mat + excluded_pair, dim=2).indices

            # Frame = (2nd closest, closest, 3rd closest), as global atom ids.
            frames = (
                torch.cat(
                    [
                        indices[:, :, 1:2],
                        indices[:, :, 0:1],
                        indices[:, :, 2:3],
                    ],
                    dim=2,
                )
                + atom_idx
            )
            # The token-axis slice has num_atoms entries: this assumes one
            # token per atom in nonpolymer chains.
            frames_idx_pred[i, :, token_idx : token_idx + num_atoms, :] = frames
            token_idx += num_tokens
            atom_idx += num_atoms

    # Gather the coordinates of each frame's three atoms.
    device = frames_idx_pred.device
    frames_expanded = pred_atom_coords[
        torch.arange(n_struct, device=device)[:, None, None, None],
        torch.arange(multiplicity, device=device)[None, :, None, None],
        frames_idx_pred,
    ].reshape(-1, 3, 3)

    # Compute masks for collinear or overlapping atoms in the frame
    mask_collinear_pred = compute_collinear_mask(
        frames_expanded[:, 1] - frames_expanded[:, 0],
        frames_expanded[:, 1] - frames_expanded[:, 2],
    ).reshape(n_struct, multiplicity, -1)

    return frames_idx_pred, mask_collinear_pred * feats["token_pad_mask"][:, None, :]