boltz-vsynthes 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (112)
  1. boltz/__init__.py +7 -0
  2. boltz/data/__init__.py +0 -0
  3. boltz/data/const.py +1184 -0
  4. boltz/data/crop/__init__.py +0 -0
  5. boltz/data/crop/affinity.py +164 -0
  6. boltz/data/crop/boltz.py +296 -0
  7. boltz/data/crop/cropper.py +45 -0
  8. boltz/data/feature/__init__.py +0 -0
  9. boltz/data/feature/featurizer.py +1230 -0
  10. boltz/data/feature/featurizerv2.py +2208 -0
  11. boltz/data/feature/symmetry.py +602 -0
  12. boltz/data/filter/__init__.py +0 -0
  13. boltz/data/filter/dynamic/__init__.py +0 -0
  14. boltz/data/filter/dynamic/date.py +76 -0
  15. boltz/data/filter/dynamic/filter.py +24 -0
  16. boltz/data/filter/dynamic/max_residues.py +37 -0
  17. boltz/data/filter/dynamic/resolution.py +34 -0
  18. boltz/data/filter/dynamic/size.py +38 -0
  19. boltz/data/filter/dynamic/subset.py +42 -0
  20. boltz/data/filter/static/__init__.py +0 -0
  21. boltz/data/filter/static/filter.py +26 -0
  22. boltz/data/filter/static/ligand.py +37 -0
  23. boltz/data/filter/static/polymer.py +299 -0
  24. boltz/data/module/__init__.py +0 -0
  25. boltz/data/module/inference.py +307 -0
  26. boltz/data/module/inferencev2.py +429 -0
  27. boltz/data/module/training.py +684 -0
  28. boltz/data/module/trainingv2.py +660 -0
  29. boltz/data/mol.py +900 -0
  30. boltz/data/msa/__init__.py +0 -0
  31. boltz/data/msa/mmseqs2.py +235 -0
  32. boltz/data/pad.py +84 -0
  33. boltz/data/parse/__init__.py +0 -0
  34. boltz/data/parse/a3m.py +134 -0
  35. boltz/data/parse/csv.py +100 -0
  36. boltz/data/parse/fasta.py +138 -0
  37. boltz/data/parse/mmcif.py +1239 -0
  38. boltz/data/parse/mmcif_with_constraints.py +1607 -0
  39. boltz/data/parse/schema.py +1851 -0
  40. boltz/data/parse/yaml.py +68 -0
  41. boltz/data/sample/__init__.py +0 -0
  42. boltz/data/sample/cluster.py +283 -0
  43. boltz/data/sample/distillation.py +57 -0
  44. boltz/data/sample/random.py +39 -0
  45. boltz/data/sample/sampler.py +49 -0
  46. boltz/data/tokenize/__init__.py +0 -0
  47. boltz/data/tokenize/boltz.py +195 -0
  48. boltz/data/tokenize/boltz2.py +396 -0
  49. boltz/data/tokenize/tokenizer.py +24 -0
  50. boltz/data/types.py +777 -0
  51. boltz/data/write/__init__.py +0 -0
  52. boltz/data/write/mmcif.py +305 -0
  53. boltz/data/write/pdb.py +171 -0
  54. boltz/data/write/utils.py +23 -0
  55. boltz/data/write/writer.py +330 -0
  56. boltz/main.py +1292 -0
  57. boltz/model/__init__.py +0 -0
  58. boltz/model/layers/__init__.py +0 -0
  59. boltz/model/layers/attention.py +132 -0
  60. boltz/model/layers/attentionv2.py +111 -0
  61. boltz/model/layers/confidence_utils.py +231 -0
  62. boltz/model/layers/dropout.py +34 -0
  63. boltz/model/layers/initialize.py +100 -0
  64. boltz/model/layers/outer_product_mean.py +98 -0
  65. boltz/model/layers/pair_averaging.py +135 -0
  66. boltz/model/layers/pairformer.py +337 -0
  67. boltz/model/layers/relative.py +58 -0
  68. boltz/model/layers/transition.py +78 -0
  69. boltz/model/layers/triangular_attention/__init__.py +0 -0
  70. boltz/model/layers/triangular_attention/attention.py +189 -0
  71. boltz/model/layers/triangular_attention/primitives.py +409 -0
  72. boltz/model/layers/triangular_attention/utils.py +380 -0
  73. boltz/model/layers/triangular_mult.py +212 -0
  74. boltz/model/loss/__init__.py +0 -0
  75. boltz/model/loss/bfactor.py +49 -0
  76. boltz/model/loss/confidence.py +590 -0
  77. boltz/model/loss/confidencev2.py +621 -0
  78. boltz/model/loss/diffusion.py +171 -0
  79. boltz/model/loss/diffusionv2.py +134 -0
  80. boltz/model/loss/distogram.py +48 -0
  81. boltz/model/loss/distogramv2.py +105 -0
  82. boltz/model/loss/validation.py +1025 -0
  83. boltz/model/models/__init__.py +0 -0
  84. boltz/model/models/boltz1.py +1286 -0
  85. boltz/model/models/boltz2.py +1249 -0
  86. boltz/model/modules/__init__.py +0 -0
  87. boltz/model/modules/affinity.py +223 -0
  88. boltz/model/modules/confidence.py +481 -0
  89. boltz/model/modules/confidence_utils.py +181 -0
  90. boltz/model/modules/confidencev2.py +495 -0
  91. boltz/model/modules/diffusion.py +844 -0
  92. boltz/model/modules/diffusion_conditioning.py +116 -0
  93. boltz/model/modules/diffusionv2.py +677 -0
  94. boltz/model/modules/encoders.py +639 -0
  95. boltz/model/modules/encodersv2.py +565 -0
  96. boltz/model/modules/transformers.py +322 -0
  97. boltz/model/modules/transformersv2.py +261 -0
  98. boltz/model/modules/trunk.py +688 -0
  99. boltz/model/modules/trunkv2.py +828 -0
  100. boltz/model/modules/utils.py +303 -0
  101. boltz/model/optim/__init__.py +0 -0
  102. boltz/model/optim/ema.py +389 -0
  103. boltz/model/optim/scheduler.py +99 -0
  104. boltz/model/potentials/__init__.py +0 -0
  105. boltz/model/potentials/potentials.py +497 -0
  106. boltz/model/potentials/schedules.py +32 -0
  107. boltz_vsynthes-1.0.0.dist-info/METADATA +151 -0
  108. boltz_vsynthes-1.0.0.dist-info/RECORD +112 -0
  109. boltz_vsynthes-1.0.0.dist-info/WHEEL +5 -0
  110. boltz_vsynthes-1.0.0.dist-info/entry_points.txt +2 -0
  111. boltz_vsynthes-1.0.0.dist-info/licenses/LICENSE +21 -0
  112. boltz_vsynthes-1.0.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,1025 @@
1
+ import torch
2
+
3
+ from boltz.data import const
4
+ from boltz.model.loss.confidence import (
5
+ compute_frame_pred,
6
+ express_coordinate_in_frame,
7
+ lddt_dist,
8
+ )
9
+ from boltz.model.loss.diffusion import weighted_rigid_align
10
+
11
+
12
def factored_lddt_loss(
    true_atom_coords,
    pred_atom_coords,
    feats,
    atom_mask,
    multiplicity=1,
    cardinality_weighted=False,
):
    """Break the atom-level lddt down by modality pair.

    Each atom inherits the molecule type (``mol_type``) and chain
    (``asym_id``) of its token through ``atom_to_token``. Pairs of distinct
    atoms allowed by ``atom_mask`` are grouped into ten categories, and
    ``lddt_dist`` is evaluated on each category separately. The cutoff is
    30 for pairs that involve at least one nucleotide atom and 15 otherwise.

    Parameters
    ----------
    true_atom_coords : torch.Tensor
        Ground truth atom coordinates after symmetry correction.
    pred_atom_coords : torch.Tensor
        Predicted atom coordinates.
    feats : Dict[str, torch.Tensor]
        Input features. ``atom_to_token``, ``mol_type`` and ``asym_id`` are
        read.
    atom_mask : torch.Tensor
        Atom mask, indexed as rank 2 (batch, atoms).
    multiplicity : int
        Diffusion batch size, by default 1. Per-token features are repeated
        this many times along the batch dimension.
    cardinality_weighted : bool
        If False (default), each total is replaced by a 0/1 flag saying
        whether the category had a positive total. If True, the raw totals
        returned by ``lddt_dist`` are kept.

    Returns
    -------
    Dict[str, torch.Tensor]
        The lddt for each modality.
    Dict[str, torch.Tensor]
        The total number of pairs (or the presence flag) for each modality.

    """
    atom_to_token = feats["atom_to_token"].float()

    def lift_to_atoms(token_values):
        # Copy a per-token integer feature onto the atoms of each token and
        # tile it across the diffusion samples.
        lifted = torch.bmm(atom_to_token, token_values.unsqueeze(-1).float())
        return lifted.squeeze(-1).long().repeat_interleave(multiplicity, 0)

    def outer(left, right):
        # Pairwise product: entry [b, i, j] = left[b, i] * right[b, j].
        return left[:, :, None] * right[:, None, :]

    def either_way(left, right):
        # Symmetric mask: one atom from `left`, the other from `right`.
        return outer(left, right) + outer(right, left)

    atom_type = lift_to_atoms(feats["mol_type"])
    is_ligand, is_dna, is_rna, is_protein = (
        (atom_type == const.chain_type_ids[kind]).float()
        for kind in ("NONPOLYMER", "DNA", "RNA", "PROTEIN")
    )
    is_nucleotide = is_dna + is_rna

    true_dists = torch.cdist(true_atom_coords, true_atom_coords)
    pred_dists = torch.cdist(pred_atom_coords, pred_atom_coords)

    # Pairs where both atoms are masked in, excluding self-pairs.
    valid_pairs = outer(atom_mask, atom_mask)
    off_diagonal = 1 - torch.eye(valid_pairs.shape[1], device=valid_pairs.device)
    valid_pairs = valid_pairs * off_diagonal[None, :, :]

    # 30 if either atom is a nucleotide atom, 15 otherwise.
    not_nucleotide = 1 - is_nucleotide
    cutoff = 15 + 15 * (1 - outer(not_nucleotide, not_nucleotide))

    scores = {}
    totals = {}

    def evaluate(name, mask):
        # The mask is a temporary that is released once lddt_dist returns,
        # so only one category mask is alive at a time.
        scores[name], totals[name] = lddt_dist(pred_dists, true_dists, mask, cutoff)

    evaluate("dna_protein", valid_pairs * either_way(is_dna, is_protein))
    evaluate("rna_protein", valid_pairs * either_way(is_rna, is_protein))
    evaluate("ligand_protein", valid_pairs * either_way(is_ligand, is_protein))
    evaluate("dna_ligand", valid_pairs * either_way(is_dna, is_ligand))
    evaluate("rna_ligand", valid_pairs * either_way(is_rna, is_ligand))
    evaluate("intra_dna", valid_pairs * outer(is_dna, is_dna))
    evaluate("intra_rna", valid_pairs * outer(is_rna, is_rna))

    # Chain identity is only needed for the ligand/protein splits below.
    atom_chain = lift_to_atoms(feats["asym_id"])
    same_chain = (atom_chain[:, :, None] == atom_chain[:, None, :]).float()

    evaluate(
        "intra_ligand", valid_pairs * same_chain * outer(is_ligand, is_ligand)
    )
    evaluate(
        "intra_protein", valid_pairs * same_chain * outer(is_protein, is_protein)
    )
    evaluate(
        "protein_protein",
        valid_pairs * (1 - same_chain) * outer(is_protein, is_protein),
    )

    key_order = (
        "dna_protein",
        "rna_protein",
        "ligand_protein",
        "dna_ligand",
        "rna_ligand",
        "intra_ligand",
        "intra_dna",
        "intra_rna",
        "intra_protein",
        "protein_protein",
    )
    lddt_dict = {name: scores[name] for name in key_order}
    total_dict = {
        name: totals[name] if cardinality_weighted else (totals[name] > 0.0).float()
        for name in key_order
    }
    return lddt_dict, total_dict
196
+
197
+
198
def factored_token_lddt_dist_loss(true_d, pred_d, feats, cardinality_weighted=False):
    """Compute the distogram lddt factorized into the different modalities.

    Token pairs allowed by ``token_disto_mask`` (self-pairs excluded) are
    grouped by the molecule types of the two tokens. Ligand and protein
    pairs are further split by whether the tokens share a chain.
    ``lddt_dist`` is then evaluated on each group. The cutoff is 30 when
    either token is a nucleotide and 15 otherwise.

    Parameters
    ----------
    true_d : torch.Tensor
        Ground truth pairwise distances. It is scored against token-pair masks
        built from per-token features, so it is expected to be token-level.
    pred_d : torch.Tensor
        Predicted pairwise distances, same layout as ``true_d``.
    feats : Dict[str, torch.Tensor]
        Input features. ``mol_type``, ``token_disto_mask`` and ``asym_id``
        are read.
    cardinality_weighted : bool
        If False (default), each total is replaced by a 0/1 flag saying
        whether the category had a positive total. If True, the raw totals
        returned by ``lddt_dist`` are kept.

    Returns
    -------
    Dict[str, torch.Tensor]
        The lddt for each modality.
    Dict[str, torch.Tensor]
        The total number of pairs (or the presence flag) for each modality.

    """

    def outer(left, right):
        # Pairwise product: entry [b, i, j] = left[b, i] * right[b, j].
        return left[:, :, None] * right[:, None, :]

    def either_way(left, right):
        # Symmetric mask: one token from `left`, the other from `right`.
        return outer(left, right) + outer(right, left)

    # Per-token molecule type masks.
    token_type = feats["mol_type"]
    ligand_mask = (token_type == const.chain_type_ids["NONPOLYMER"]).float()
    dna_mask = (token_type == const.chain_type_ids["DNA"]).float()
    rna_mask = (token_type == const.chain_type_ids["RNA"]).float()
    protein_mask = (token_type == const.chain_type_ids["PROTEIN"]).float()
    nucleotide_mask = dna_mask + rna_mask

    # Valid token pairs, excluding self-pairs. Build the identity directly on
    # the mask's device instead of allocating it on the CPU and copying it
    # over on every call.
    token_mask = outer(feats["token_disto_mask"], feats["token_disto_mask"])
    off_diagonal = 1 - torch.eye(token_mask.shape[1], device=token_mask.device)
    token_mask = token_mask * off_diagonal[None].to(token_mask)

    # 30 if either token is a nucleotide, 15 otherwise.
    cutoff = 15 + 15 * (
        1 - (1 - nucleotide_mask[:, :, None]) * (1 - nucleotide_mask[:, None, :])
    )

    # Chain identity is computed once; the original recomputed it before the
    # protein splits.
    chain_id = feats["asym_id"]
    same_chain_mask = (chain_id[:, :, None] == chain_id[:, None, :]).float()

    lddt_dict = {}
    total_dict = {}

    def record(name, mask):
        lddt_dict[name], total_dict[name] = lddt_dist(pred_d, true_d, mask, cutoff)

    record("dna_protein", token_mask * either_way(dna_mask, protein_mask))
    record("rna_protein", token_mask * either_way(rna_mask, protein_mask))
    record("ligand_protein", token_mask * either_way(ligand_mask, protein_mask))
    record("dna_ligand", token_mask * either_way(dna_mask, ligand_mask))
    record("rna_ligand", token_mask * either_way(rna_mask, ligand_mask))
    record(
        "intra_ligand",
        token_mask * same_chain_mask * outer(ligand_mask, ligand_mask),
    )
    record("intra_dna", token_mask * outer(dna_mask, dna_mask))
    record("intra_rna", token_mask * outer(rna_mask, rna_mask))
    record(
        "intra_protein",
        token_mask * same_chain_mask * outer(protein_mask, protein_mask),
    )
    record(
        "protein_protein",
        token_mask * (1 - same_chain_mask) * outer(protein_mask, protein_mask),
    )

    if not cardinality_weighted:
        # Reduce each total to "did this category contain any pair?".
        for key in total_dict:
            total_dict[key] = (total_dict[key] > 0.0).float()

    return lddt_dict, total_dict
345
+
346
+
347
def compute_plddt_mae(
    pred_atom_coords,
    feats,
    true_atom_coords,
    pred_lddt,
    true_coords_resolved_mask,
    multiplicity=1,
):
    """Measure the per-modality error of the predicted lddt.

    A reference lddt is recomputed with ``lddt_dist(..., per_atom=True)``.
    It compares each token's representative atom against the atoms selected
    by ``r_set_to_rep_atom``, in both predicted and true coordinates. Its
    mean absolute difference to ``pred_lddt`` is then reported separately for
    protein, ligand, DNA and RNA tokens. The cutoff is 30 for nucleotide
    reference elements and 15 otherwise.

    Parameters
    ----------
    pred_atom_coords : torch.Tensor
        Predicted atom coordinates.
    feats : Dict[str, torch.Tensor]
        Input features. ``r_set_to_rep_atom``, ``mol_type``,
        ``atom_to_token`` and ``token_to_rep_atom`` are read.
    true_atom_coords : torch.Tensor
        Ground truth atom coordinates.
    pred_lddt : torch.Tensor
        Predicted lddt.
    true_coords_resolved_mask : torch.Tensor
        Resolved atom mask.
    multiplicity : int
        Diffusion batch size, by default 1. Per-input features are repeated
        this many times along the batch dimension.

    Returns
    -------
    Dict[str, torch.Tensor]
        The mae for each modality.
    Dict[str, torch.Tensor]
        The summed token mask (the "total") for each modality.

    """

    def tile(tensor):
        # Repeat a per-input feature once per diffusion sample.
        return tensor.repeat_interleave(multiplicity, 0)

    r_set_to_rep_atom = tile(feats["r_set_to_rep_atom"]).float()
    token_type = tile(feats["mol_type"])
    atom_to_token = tile(feats["atom_to_token"].float())
    token_to_rep_atom = tile(feats["token_to_rep_atom"].float())

    is_nucleotide_token = (token_type == const.chain_type_ids["DNA"]).float() + (
        token_type == const.chain_type_ids["RNA"]
    ).float()
    batch_size = true_atom_coords.shape[0]

    # Distances from each token's representative atom to every reference atom.
    true_d = torch.cdist(
        torch.bmm(token_to_rep_atom, true_atom_coords),
        torch.bmm(r_set_to_rep_atom, true_atom_coords),
    )
    pred_d = torch.cdist(
        torch.bmm(token_to_rep_atom, pred_atom_coords),
        torch.bmm(r_set_to_rep_atom, pred_atom_coords),
    )

    # Resolved atom pairs (no self-pairs), projected to
    # (token representative, reference element) pairs.
    resolved = true_coords_resolved_mask
    pair_mask = resolved.unsqueeze(-1) * resolved.unsqueeze(-2)
    off_diagonal = 1 - torch.eye(pair_mask.shape[1], device=pair_mask.device)
    pair_mask = pair_mask * off_diagonal[None, :, :]
    pair_mask = torch.einsum("bnm,bkm->bnk", pair_mask, r_set_to_rep_atom)
    pair_mask = torch.bmm(token_to_rep_atom, pair_mask)

    token_resolved = torch.bmm(
        token_to_rep_atom, resolved.unsqueeze(-1).float()
    ).squeeze(-1)

    # 30 for nucleotide reference elements, 15 otherwise, broadcast over rows.
    atom_is_nucleotide = torch.bmm(atom_to_token, is_nucleotide_token.unsqueeze(-1))
    is_nucleotide_R_element = torch.bmm(
        r_set_to_rep_atom, atom_is_nucleotide
    ).squeeze(-1)
    cutoff = 15 + 15 * is_nucleotide_R_element.reshape(batch_size, 1, -1).repeat(
        1, true_d.shape[1], 1
    )

    target_lddt, mask_no_match = lddt_dist(
        pred_d, true_d, pair_mask, cutoff, per_atom=True
    )

    # The absolute error is shared by all modalities; compute it once.
    abs_error = torch.abs(target_lddt - pred_lddt)

    mae_plddt_dict = {}
    total_dict = {}
    modalities = (
        ("protein", "PROTEIN"),
        ("ligand", "NONPOLYMER"),
        ("dna", "DNA"),
        ("rna", "RNA"),
    )
    for name, chain_type in modalities:
        modality_mask = (
            (token_type == const.chain_type_ids[chain_type]).float()
            * token_resolved
            * mask_no_match
        )
        mask_sum = torch.sum(modality_mask)
        mae_plddt_dict[name] = torch.sum(abs_error * modality_mask) / (
            mask_sum + 1e-5
        )
        total_dict[name] = mask_sum

    return mae_plddt_dict, total_dict
482
+
483
+
484
def compute_pde_mae(
    pred_atom_coords,
    feats,
    true_atom_coords,
    pred_pde,
    true_coords_resolved_mask,
    multiplicity=1,
):
    """Compute the pde (pairwise distance error) mean absolute error.

    The target pde for every token pair is the absolute difference between
    the true and predicted distances of the tokens' representative atoms,
    discretized into bins. Its mean absolute difference to ``pred_pde`` is
    reported for ten modality-pair categories.

    Parameters
    ----------
    pred_atom_coords : torch.Tensor
        Predicted atom coordinates.
    feats : Dict[str, torch.Tensor]
        Input features. ``token_to_rep_atom``, ``mol_type`` and ``asym_id``
        are read.
    true_atom_coords : torch.Tensor
        Ground truth atom coordinates.
    pred_pde : torch.Tensor
        Predicted pde.
    true_coords_resolved_mask : torch.Tensor
        Resolved atom mask.
    multiplicity : int
        Diffusion batch size, by default 1. Per-input features are repeated
        this many times along the batch dimension.

    Returns
    -------
    Dict[str, torch.Tensor]
        The mae for each modality pair.
    Dict[str, torch.Tensor]
        The total number of token pairs for each modality pair.

    """
    token_to_rep_atom = feats["token_to_rep_atom"].float()
    token_to_rep_atom = token_to_rep_atom.repeat_interleave(multiplicity, 0)

    # A token counts as resolved through its representative atom.
    token_mask = torch.bmm(
        token_to_rep_atom, true_coords_resolved_mask.unsqueeze(-1).float()
    ).squeeze(-1)

    token_type = feats["mol_type"].repeat_interleave(multiplicity, 0)

    true_token_coords = torch.bmm(token_to_rep_atom, true_atom_coords)
    pred_token_coords = torch.bmm(token_to_rep_atom, pred_atom_coords)

    # Target pde: |true_d - pred_d| binned into 64 bins of width 0.5
    # (everything from 31.5 upward lands in the last bin), then mapped to
    # the bin center.
    true_d = torch.cdist(true_token_coords, true_token_coords)
    pred_d = torch.cdist(pred_token_coords, pred_token_coords)
    target_pde = (
        torch.clamp(
            torch.floor(torch.abs(true_d - pred_d) * 64 / 32).long(), max=63
        ).float()
        * 0.5
        + 0.25
    )

    # Shared by every category below; the original recomputed it ten times.
    abs_error = torch.abs(target_pde - pred_pde)

    # Resolved token pairs, excluding self-pairs.
    pair_mask = token_mask.unsqueeze(-1) * token_mask.unsqueeze(-2)
    pair_mask = (
        pair_mask
        * (1 - torch.eye(pair_mask.shape[1], device=pair_mask.device))[None, :, :]
    )

    protein_mask = (token_type == const.chain_type_ids["PROTEIN"]).float()
    ligand_mask = (token_type == const.chain_type_ids["NONPOLYMER"]).float()
    dna_mask = (token_type == const.chain_type_ids["DNA"]).float()
    rna_mask = (token_type == const.chain_type_ids["RNA"]).float()

    chain_id = feats["asym_id"].repeat_interleave(multiplicity, 0)
    same_chain_mask = (chain_id[:, :, None] == chain_id[:, None, :]).float()

    def outer(left, right):
        # Pairwise product: entry [b, i, j] = left[b, i] * right[b, j].
        return left[:, :, None] * right[:, None, :]

    def either_way(left, right):
        # Symmetric mask: one token from `left`, the other from `right`.
        return outer(left, right) + outer(right, left)

    mae_pde_dict = {}
    total_pde_dict = {}

    def record(name, mask):
        # Masked mean of the absolute error; the epsilon keeps empty
        # categories at 0 instead of dividing by zero.
        total = torch.sum(mask)
        mae_pde_dict[name] = torch.sum(abs_error * mask) / (total + 1e-5)
        total_pde_dict[name] = total

    record("dna_protein", pair_mask * either_way(dna_mask, protein_mask))
    record("rna_protein", pair_mask * either_way(rna_mask, protein_mask))
    record("ligand_protein", pair_mask * either_way(ligand_mask, protein_mask))
    record("dna_ligand", pair_mask * either_way(dna_mask, ligand_mask))
    record("rna_ligand", pair_mask * either_way(rna_mask, ligand_mask))
    # NOTE(review): unlike intra_protein, intra_ligand here is not restricted
    # to the same chain (the lddt variants in this file do restrict it).
    # Kept as-is to preserve the metric; confirm this is intended.
    record("intra_ligand", pair_mask * outer(ligand_mask, ligand_mask))
    record("intra_dna", pair_mask * outer(dna_mask, dna_mask))
    record("intra_rna", pair_mask * outer(rna_mask, rna_mask))
    record(
        "intra_protein",
        pair_mask * same_chain_mask * outer(protein_mask, protein_mask),
    )
    record(
        "protein_protein",
        pair_mask * (1 - same_chain_mask) * outer(protein_mask, protein_mask),
    )

    return mae_pde_dict, total_pde_dict
666
+
667
+
668
+ def compute_pae_mae(
669
+ pred_atom_coords,
670
+ feats,
671
+ true_atom_coords,
672
+ pred_pae,
673
+ true_coords_resolved_mask,
674
+ multiplicity=1,
675
+ ):
676
+ """Compute the pae mean absolute error.
677
+
678
+ Parameters
679
+ ----------
680
+ pred_atom_coords : torch.Tensor
681
+ Predicted atom coordinates
682
+ feats : torch.Tensor
683
+ Input features
684
+ true_atom_coords : torch.Tensor
685
+ Ground truth atom coordinates
686
+ pred_pae : torch.Tensor
687
+ Predicted pae
688
+ true_coords_resolved_mask : torch.Tensor
689
+ Resolved atom mask
690
+ multiplicity : int
691
+ Diffusion batch size, by default 1
692
+
693
+ Returns
694
+ -------
695
+ Tensor
696
+ The mae for each modality
697
+ Tensor
698
+ The total number of pairs for each modality
699
+
700
+ """
701
+ # Retrieve frames and resolved masks
702
+ frames_idx_original = feats["frames_idx"]
703
+ mask_frame_true = feats["frame_resolved_mask"]
704
+
705
+ # Adjust the frames for nonpolymers after symmetry correction!
706
+ # NOTE: frames of polymers do not change under symmetry!
707
+ frames_idx_true, mask_collinear_true = compute_frame_pred(
708
+ true_atom_coords,
709
+ frames_idx_original,
710
+ feats,
711
+ multiplicity,
712
+ resolved_mask=true_coords_resolved_mask,
713
+ )
714
+
715
+ frame_true_atom_a, frame_true_atom_b, frame_true_atom_c = (
716
+ frames_idx_true[:, :, :, 0],
717
+ frames_idx_true[:, :, :, 1],
718
+ frames_idx_true[:, :, :, 2],
719
+ )
720
+ # Compute token coords in true frames
721
+ B, N, _ = true_atom_coords.shape
722
+ true_atom_coords = true_atom_coords.reshape(B // multiplicity, multiplicity, -1, 3)
723
+ true_coords_transformed = express_coordinate_in_frame(
724
+ true_atom_coords, frame_true_atom_a, frame_true_atom_b, frame_true_atom_c
725
+ )
726
+
727
+ # Compute pred frames and mask
728
+ frames_idx_pred, mask_collinear_pred = compute_frame_pred(
729
+ pred_atom_coords, frames_idx_original, feats, multiplicity
730
+ )
731
+ frame_pred_atom_a, frame_pred_atom_b, frame_pred_atom_c = (
732
+ frames_idx_pred[:, :, :, 0],
733
+ frames_idx_pred[:, :, :, 1],
734
+ frames_idx_pred[:, :, :, 2],
735
+ )
736
+ # Compute token coords in pred frames
737
+ B, N, _ = pred_atom_coords.shape
738
+ pred_atom_coords = pred_atom_coords.reshape(B // multiplicity, multiplicity, -1, 3)
739
+ pred_coords_transformed = express_coordinate_in_frame(
740
+ pred_atom_coords, frame_pred_atom_a, frame_pred_atom_b, frame_pred_atom_c
741
+ )
742
+
743
+ target_pae_continuous = torch.sqrt(
744
+ ((true_coords_transformed - pred_coords_transformed) ** 2).sum(-1) + 1e-8
745
+ )
746
+ target_pae = (
747
+ torch.clamp(torch.floor(target_pae_continuous * 64 / 32).long(), max=63).float()
748
+ * 0.5
749
+ + 0.25
750
+ )
751
+
752
+ # Compute mask for the pae loss
753
+ b_true_resolved_mask = true_coords_resolved_mask[
754
+ torch.arange(B // multiplicity)[:, None, None].to(
755
+ pred_coords_transformed.device
756
+ ),
757
+ frame_true_atom_b,
758
+ ]
759
+
760
+ pair_mask = (
761
+ mask_frame_true[:, None, :, None] # if true frame is invalid
762
+ * mask_collinear_true[:, :, :, None] # if true frame is invalid
763
+ * mask_collinear_pred[:, :, :, None] # if pred frame is invalid
764
+ * b_true_resolved_mask[:, :, None, :] # If atom j is not resolved
765
+ * feats["token_pad_mask"][:, None, :, None]
766
+ * feats["token_pad_mask"][:, None, None, :]
767
+ )
768
+
769
+ token_type = feats["mol_type"]
770
+ token_type = token_type.repeat_interleave(multiplicity, 0)
771
+
772
+ protein_mask = (token_type == const.chain_type_ids["PROTEIN"]).float()
773
+ ligand_mask = (token_type == const.chain_type_ids["NONPOLYMER"]).float()
774
+ dna_mask = (token_type == const.chain_type_ids["DNA"]).float()
775
+ rna_mask = (token_type == const.chain_type_ids["RNA"]).float()
776
+
777
+ # compute different paes
778
+ dna_protein_mask = pair_mask * (
779
+ dna_mask[:, :, None] * protein_mask[:, None, :]
780
+ + protein_mask[:, :, None] * dna_mask[:, None, :]
781
+ )
782
+ dna_protein_mae = torch.sum(torch.abs(target_pae - pred_pae) * dna_protein_mask) / (
783
+ torch.sum(dna_protein_mask) + 1e-5
784
+ )
785
+ dna_protein_total = torch.sum(dna_protein_mask)
786
+
787
+ rna_protein_mask = pair_mask * (
788
+ rna_mask[:, :, None] * protein_mask[:, None, :]
789
+ + protein_mask[:, :, None] * rna_mask[:, None, :]
790
+ )
791
+ rna_protein_mae = torch.sum(torch.abs(target_pae - pred_pae) * rna_protein_mask) / (
792
+ torch.sum(rna_protein_mask) + 1e-5
793
+ )
794
+ rna_protein_total = torch.sum(rna_protein_mask)
795
+
796
+ ligand_protein_mask = pair_mask * (
797
+ ligand_mask[:, :, None] * protein_mask[:, None, :]
798
+ + protein_mask[:, :, None] * ligand_mask[:, None, :]
799
+ )
800
+ ligand_protein_mae = torch.sum(
801
+ torch.abs(target_pae - pred_pae) * ligand_protein_mask
802
+ ) / (torch.sum(ligand_protein_mask) + 1e-5)
803
+ ligand_protein_total = torch.sum(ligand_protein_mask)
804
+
805
+ dna_ligand_mask = pair_mask * (
806
+ dna_mask[:, :, None] * ligand_mask[:, None, :]
807
+ + ligand_mask[:, :, None] * dna_mask[:, None, :]
808
+ )
809
+ dna_ligand_mae = torch.sum(torch.abs(target_pae - pred_pae) * dna_ligand_mask) / (
810
+ torch.sum(dna_ligand_mask) + 1e-5
811
+ )
812
+ dna_ligand_total = torch.sum(dna_ligand_mask)
813
+
814
+ rna_ligand_mask = pair_mask * (
815
+ rna_mask[:, :, None] * ligand_mask[:, None, :]
816
+ + ligand_mask[:, :, None] * rna_mask[:, None, :]
817
+ )
818
+ rna_ligand_mae = torch.sum(torch.abs(target_pae - pred_pae) * rna_ligand_mask) / (
819
+ torch.sum(rna_ligand_mask) + 1e-5
820
+ )
821
+ rna_ligand_total = torch.sum(rna_ligand_mask)
822
+
823
+ intra_ligand_mask = pair_mask * (ligand_mask[:, :, None] * ligand_mask[:, None, :])
824
+ intra_ligand_mae = torch.sum(
825
+ torch.abs(target_pae - pred_pae) * intra_ligand_mask
826
+ ) / (torch.sum(intra_ligand_mask) + 1e-5)
827
+ intra_ligand_total = torch.sum(intra_ligand_mask)
828
+
829
+ intra_dna_mask = pair_mask * (dna_mask[:, :, None] * dna_mask[:, None, :])
830
+ intra_dna_mae = torch.sum(torch.abs(target_pae - pred_pae) * intra_dna_mask) / (
831
+ torch.sum(intra_dna_mask) + 1e-5
832
+ )
833
+ intra_dna_total = torch.sum(intra_dna_mask)
834
+
835
+ intra_rna_mask = pair_mask * (rna_mask[:, :, None] * rna_mask[:, None, :])
836
+ intra_rna_mae = torch.sum(torch.abs(target_pae - pred_pae) * intra_rna_mask) / (
837
+ torch.sum(intra_rna_mask) + 1e-5
838
+ )
839
+ intra_rna_total = torch.sum(intra_rna_mask)
840
+
841
+ chain_id = feats["asym_id"].repeat_interleave(multiplicity, 0)
842
+ same_chain_mask = (chain_id[:, :, None] == chain_id[:, None, :]).float()
843
+
844
+ intra_protein_mask = (
845
+ pair_mask
846
+ * same_chain_mask
847
+ * (protein_mask[:, :, None] * protein_mask[:, None, :])
848
+ )
849
+ intra_protein_mae = torch.sum(
850
+ torch.abs(target_pae - pred_pae) * intra_protein_mask
851
+ ) / (torch.sum(intra_protein_mask) + 1e-5)
852
+ intra_protein_total = torch.sum(intra_protein_mask)
853
+
854
+ protein_protein_mask = (
855
+ pair_mask
856
+ * (1 - same_chain_mask)
857
+ * (protein_mask[:, :, None] * protein_mask[:, None, :])
858
+ )
859
+ protein_protein_mae = torch.sum(
860
+ torch.abs(target_pae - pred_pae) * protein_protein_mask
861
+ ) / (torch.sum(protein_protein_mask) + 1e-5)
862
+ protein_protein_total = torch.sum(protein_protein_mask)
863
+
864
+ mae_pae_dict = {
865
+ "dna_protein": dna_protein_mae,
866
+ "rna_protein": rna_protein_mae,
867
+ "ligand_protein": ligand_protein_mae,
868
+ "dna_ligand": dna_ligand_mae,
869
+ "rna_ligand": rna_ligand_mae,
870
+ "intra_ligand": intra_ligand_mae,
871
+ "intra_dna": intra_dna_mae,
872
+ "intra_rna": intra_rna_mae,
873
+ "intra_protein": intra_protein_mae,
874
+ "protein_protein": protein_protein_mae,
875
+ }
876
+ total_pae_dict = {
877
+ "dna_protein": dna_protein_total,
878
+ "rna_protein": rna_protein_total,
879
+ "ligand_protein": ligand_protein_total,
880
+ "dna_ligand": dna_ligand_total,
881
+ "rna_ligand": rna_ligand_total,
882
+ "intra_ligand": intra_ligand_total,
883
+ "intra_dna": intra_dna_total,
884
+ "intra_rna": intra_rna_total,
885
+ "intra_protein": intra_protein_total,
886
+ "protein_protein": protein_protein_total,
887
+ }
888
+
889
+ return mae_pae_dict, total_pae_dict
890
+
891
+
892
def weighted_minimum_rmsd(
    pred_atom_coords,
    feats,
    multiplicity=1,
    nucleotide_weight=5.0,
    ligand_weight=10.0,
):
    """Compute the weighted RMSD of predictions against aligned ground truth.

    The ground-truth coordinates are rigidly aligned onto every predicted
    sample with per-atom weights (DNA/RNA and non-polymer atoms up-weighted).
    The same weights are then used to compute the RMSD over resolved atoms.

    Parameters
    ----------
    pred_atom_coords : torch.Tensor
        Predicted atom coordinates, one entry along dim 0 per
        (batch element, sample). Assumed to use the same sample ordering as
        ``repeat_interleave`` on the batch features — TODO confirm with callers.
    feats : dict[str, torch.Tensor]
        Input features. Reads ``"coords"`` (only index 0 along dim 1 is used),
        ``"atom_resolved_mask"``, ``"atom_to_token"`` and ``"mol_type"``.
    multiplicity : int, optional
        Number of diffusion samples per batch element, by default 1.
    nucleotide_weight : float, optional
        Weight added on top of 1 for DNA and RNA atoms, by default 5.0.
    ligand_weight : float, optional
        Weight added on top of 1 for non-polymer atoms, by default 10.0.

    Returns
    -------
    Tensor
        The weighted RMSD of every sample.
    Tensor
        The best (minimum) RMSD among the ``multiplicity`` samples of each
        batch element.

    Notes
    -----
    With non-negative weights, a sample with no resolved atoms has a zero
    denominator and its RMSD is NaN (0 / 0).
    """
    # Select index 0 along dim 1 *before* repeating, so the unused entries
    # are never copied ``multiplicity`` times (result is identical).
    atom_coords = feats["coords"][:, 0].repeat_interleave(multiplicity, 0)
    atom_mask = feats["atom_resolved_mask"].repeat_interleave(multiplicity, 0)

    # Project the per-token molecule type onto atoms. Assumes atom_to_token is
    # a one-hot atom->token map, so the float products are exact integers —
    # TODO confirm.
    atom_type = (
        torch.bmm(
            feats["atom_to_token"].float(), feats["mol_type"].unsqueeze(-1).float()
        )
        .squeeze(-1)
        .long()
    )
    atom_type = atom_type.repeat_interleave(multiplicity, 0)

    # Per-atom weights used both for the rigid alignment and for the RMSD.
    align_weights = atom_coords.new_ones(atom_coords.shape[:2])
    align_weights = align_weights * (
        1
        + nucleotide_weight
        * (
            torch.eq(atom_type, const.chain_type_ids["DNA"]).float()
            + torch.eq(atom_type, const.chain_type_ids["RNA"]).float()
        )
        + ligand_weight
        * torch.eq(atom_type, const.chain_type_ids["NONPOLYMER"]).float()
    )

    # The aligned ground truth is a fixed reference: no gradients flow into it.
    with torch.no_grad():
        atom_coords_aligned_ground_truth = weighted_rigid_align(
            atom_coords, pred_atom_coords, align_weights, mask=atom_mask
        )

    # Per-atom squared distance between prediction and aligned ground truth.
    sq_dist = ((pred_atom_coords - atom_coords_aligned_ground_truth) ** 2).sum(dim=-1)
    rmsd = torch.sqrt(
        torch.sum(sq_dist * align_weights * atom_mask, dim=-1)
        / torch.sum(align_weights * atom_mask, dim=-1)
    )
    # repeat_interleave keeps the samples of one batch element contiguous, so
    # this reshape groups them and the min picks the best sample per element.
    best_rmsd = torch.min(rmsd.reshape(-1, multiplicity), dim=1).values

    return rmsd, best_rmsd
960
+
961
+
962
def weighted_minimum_rmsd_single(
    pred_atom_coords,
    atom_coords,
    atom_mask,
    atom_to_token,
    mol_type,
    nucleotide_weight=5.0,
    ligand_weight=10.0,
):
    """Weighted RMSD of a prediction against rigidly aligned ground truth.

    Parameters
    ----------
    pred_atom_coords : torch.Tensor
        Predicted atom coordinates.
    atom_coords : torch.Tensor
        Ground-truth atom coordinates. Its first two dimensions give the
        shape of the per-atom weights.
    atom_mask : torch.Tensor
        Resolved-atom mask; only masked-in atoms contribute to the
        alignment and to the RMSD.
    atom_to_token : torch.Tensor
        Batched atom-to-token mapping matrix (consumed by ``torch.bmm``).
    mol_type : torch.Tensor
        Molecule type of each token; mapped onto atoms through
        ``atom_to_token``.
    nucleotide_weight : float, optional
        Weight added on top of 1 for DNA and RNA atoms, by default 5.0.
    ligand_weight : float, optional
        Weight added on top of 1 for non-polymer atoms, by default 10.0.

    Returns
    -------
    Tensor
        The weighted RMSD, reduced over the last (atom) dimension.
    Tensor
        The ground-truth coordinates after weighted rigid alignment.
    Tensor
        The per-atom weights used for both alignment and RMSD.
    """
    # Molecule type of every atom, obtained by projecting token types.
    token_types = mol_type.unsqueeze(-1).float()
    per_atom_type = torch.bmm(atom_to_token.float(), token_types).squeeze(-1).long()

    def _flag(kind):
        # 1.0 where the atom belongs to the given chain type, else 0.0.
        return torch.eq(per_atom_type, const.chain_type_ids[kind]).float()

    nucleotide_flag = _flag("DNA") + _flag("RNA")
    ligand_flag = _flag("NONPOLYMER")
    weight_scale = 1 + nucleotide_weight * nucleotide_flag + ligand_weight * ligand_flag
    align_weights = atom_coords.new_ones(atom_coords.shape[:2]) * weight_scale

    # The aligned reference is computed without autograd tracking.
    with torch.no_grad():
        aligned_truth = weighted_rigid_align(
            atom_coords, pred_atom_coords, align_weights, mask=atom_mask
        )

    # Per-atom squared distance, then a weighted mean over resolved atoms.
    sq_dist = (pred_atom_coords - aligned_truth).pow(2).sum(dim=-1)
    numerator = (sq_dist * align_weights * atom_mask).sum(dim=-1)
    denominator = (align_weights * atom_mask).sum(dim=-1)
    rmsd = (numerator / denominator).sqrt()
    return rmsd, aligned_truth, align_weights