rc-foundry 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (180)
  1. foundry/__init__.py +57 -0
  2. foundry/callbacks/__init__.py +5 -0
  3. foundry/callbacks/callback.py +116 -0
  4. foundry/callbacks/health_logging.py +419 -0
  5. foundry/callbacks/metrics_logging.py +211 -0
  6. foundry/callbacks/timing_logging.py +67 -0
  7. foundry/callbacks/train_logging.py +278 -0
  8. foundry/common.py +108 -0
  9. foundry/constants.py +28 -0
  10. foundry/hydra/resolvers.py +77 -0
  11. foundry/inference_engines/base.py +235 -0
  12. foundry/inference_engines/checkpoint_registry.py +66 -0
  13. foundry/metrics/__init__.py +12 -0
  14. foundry/metrics/losses.py +30 -0
  15. foundry/metrics/metric.py +319 -0
  16. foundry/model/layers/blocks.py +47 -0
  17. foundry/testing/__init__.py +6 -0
  18. foundry/testing/fixtures.py +19 -0
  19. foundry/testing/pytest_hooks.py +15 -0
  20. foundry/trainers/fabric.py +923 -0
  21. foundry/training/EMA.py +67 -0
  22. foundry/training/checkpoint.py +61 -0
  23. foundry/training/schedulers.py +91 -0
  24. foundry/utils/alignment.py +86 -0
  25. foundry/utils/components.py +415 -0
  26. foundry/utils/datasets.py +405 -0
  27. foundry/utils/ddp.py +103 -0
  28. foundry/utils/instantiators.py +72 -0
  29. foundry/utils/logging.py +279 -0
  30. foundry/utils/rigid.py +1460 -0
  31. foundry/utils/rotation_augmentation.py +65 -0
  32. foundry/utils/squashfs.py +172 -0
  33. foundry/utils/torch.py +317 -0
  34. foundry/utils/weights.py +271 -0
  35. foundry/version.py +34 -0
  36. foundry_cli/__init__.py +3 -0
  37. foundry_cli/download_checkpoints.py +281 -0
  38. mpnn/__init__.py +1 -0
  39. mpnn/collate/feature_collator.py +265 -0
  40. mpnn/inference.py +53 -0
  41. mpnn/inference_engines/mpnn.py +549 -0
  42. mpnn/loss/nll_loss.py +122 -0
  43. mpnn/metrics/nll.py +369 -0
  44. mpnn/metrics/sequence_recovery.py +440 -0
  45. mpnn/model/layers/graph_embeddings.py +2372 -0
  46. mpnn/model/layers/message_passing.py +332 -0
  47. mpnn/model/layers/position_wise_feed_forward.py +44 -0
  48. mpnn/model/layers/positional_encoding.py +98 -0
  49. mpnn/model/mpnn.py +2632 -0
  50. mpnn/pipelines/mpnn.py +162 -0
  51. mpnn/samplers/samplers.py +167 -0
  52. mpnn/train.py +341 -0
  53. mpnn/trainers/mpnn.py +193 -0
  54. mpnn/transforms/feature_aggregation/mpnn.py +184 -0
  55. mpnn/transforms/feature_aggregation/polymer_ligand_interface.py +76 -0
  56. mpnn/transforms/feature_aggregation/token_encodings.py +132 -0
  57. mpnn/transforms/feature_aggregation/user_settings.py +347 -0
  58. mpnn/transforms/polymer_ligand_interface.py +164 -0
  59. mpnn/utils/inference.py +2397 -0
  60. mpnn/utils/probability.py +37 -0
  61. mpnn/utils/weights.py +309 -0
  62. rc_foundry-0.1.1.dist-info/METADATA +239 -0
  63. rc_foundry-0.1.1.dist-info/RECORD +180 -0
  64. rc_foundry-0.1.1.dist-info/WHEEL +4 -0
  65. rc_foundry-0.1.1.dist-info/entry_points.txt +5 -0
  66. rc_foundry-0.1.1.dist-info/licenses/LICENSE.md +28 -0
  67. rf3/__init__.py +3 -0
  68. rf3/_version.py +33 -0
  69. rf3/alignment.py +79 -0
  70. rf3/callbacks/dump_validation_structures.py +101 -0
  71. rf3/callbacks/metrics_logging.py +324 -0
  72. rf3/chemical.py +1529 -0
  73. rf3/cli.py +77 -0
  74. rf3/data/cyclic_transform.py +78 -0
  75. rf3/data/extra_xforms.py +36 -0
  76. rf3/data/ground_truth_template.py +463 -0
  77. rf3/data/paired_msa.py +206 -0
  78. rf3/data/pipeline_utils.py +128 -0
  79. rf3/data/pipelines.py +558 -0
  80. rf3/diffusion_samplers/inference_sampler.py +222 -0
  81. rf3/inference.py +65 -0
  82. rf3/inference_engines/__init__.py +5 -0
  83. rf3/inference_engines/rf3.py +735 -0
  84. rf3/kinematics.py +354 -0
  85. rf3/loss/af3_confidence_loss.py +515 -0
  86. rf3/loss/af3_losses.py +655 -0
  87. rf3/loss/loss.py +179 -0
  88. rf3/metrics/chiral.py +179 -0
  89. rf3/metrics/clashing_chains.py +68 -0
  90. rf3/metrics/distogram.py +421 -0
  91. rf3/metrics/lddt.py +523 -0
  92. rf3/metrics/metadata.py +43 -0
  93. rf3/metrics/metric_utils.py +192 -0
  94. rf3/metrics/predicted_error.py +134 -0
  95. rf3/metrics/rasa.py +108 -0
  96. rf3/metrics/selected_distances.py +91 -0
  97. rf3/model/RF3.py +527 -0
  98. rf3/model/RF3_blocks.py +92 -0
  99. rf3/model/RF3_structure.py +303 -0
  100. rf3/model/layers/af3_auxiliary_heads.py +255 -0
  101. rf3/model/layers/af3_diffusion_transformer.py +544 -0
  102. rf3/model/layers/attention.py +313 -0
  103. rf3/model/layers/layer_utils.py +127 -0
  104. rf3/model/layers/mlff.py +118 -0
  105. rf3/model/layers/outer_product.py +59 -0
  106. rf3/model/layers/pairformer_layers.py +783 -0
  107. rf3/model/layers/structure_bias.py +56 -0
  108. rf3/scoring.py +1787 -0
  109. rf3/symmetry/resolve.py +284 -0
  110. rf3/train.py +194 -0
  111. rf3/trainers/rf3.py +570 -0
  112. rf3/util_module.py +47 -0
  113. rf3/utils/frames.py +109 -0
  114. rf3/utils/inference.py +665 -0
  115. rf3/utils/io.py +198 -0
  116. rf3/utils/loss.py +72 -0
  117. rf3/utils/predict_and_score.py +165 -0
  118. rf3/utils/predicted_error.py +673 -0
  119. rf3/utils/recycling.py +42 -0
  120. rf3/validate.py +140 -0
  121. rfd3/.gitignore +7 -0
  122. rfd3/Makefile +76 -0
  123. rfd3/__init__.py +12 -0
  124. rfd3/callbacks.py +66 -0
  125. rfd3/cli.py +41 -0
  126. rfd3/constants.py +212 -0
  127. rfd3/engine.py +543 -0
  128. rfd3/inference/datasets.py +193 -0
  129. rfd3/inference/input_parsing.py +1123 -0
  130. rfd3/inference/legacy_input_parsing.py +717 -0
  131. rfd3/inference/parsing.py +165 -0
  132. rfd3/inference/symmetry/atom_array.py +298 -0
  133. rfd3/inference/symmetry/checks.py +241 -0
  134. rfd3/inference/symmetry/contigs.py +63 -0
  135. rfd3/inference/symmetry/frames.py +355 -0
  136. rfd3/inference/symmetry/symmetry_utils.py +398 -0
  137. rfd3/metrics/design_metrics.py +465 -0
  138. rfd3/metrics/hbonds_hbplus_metrics.py +308 -0
  139. rfd3/metrics/hbonds_metrics.py +389 -0
  140. rfd3/metrics/losses.py +325 -0
  141. rfd3/metrics/metrics_utils.py +118 -0
  142. rfd3/metrics/sidechain_metrics.py +349 -0
  143. rfd3/model/RFD3.py +105 -0
  144. rfd3/model/RFD3_diffusion_module.py +387 -0
  145. rfd3/model/cfg_utils.py +81 -0
  146. rfd3/model/inference_sampler.py +635 -0
  147. rfd3/model/layers/attention.py +577 -0
  148. rfd3/model/layers/block_utils.py +580 -0
  149. rfd3/model/layers/blocks.py +777 -0
  150. rfd3/model/layers/chunked_pairwise.py +377 -0
  151. rfd3/model/layers/encoders.py +417 -0
  152. rfd3/model/layers/layer_utils.py +197 -0
  153. rfd3/model/layers/pairformer_layers.py +128 -0
  154. rfd3/run_inference.py +45 -0
  155. rfd3/testing/debug.py +139 -0
  156. rfd3/testing/debug_utils.py +73 -0
  157. rfd3/testing/testing_utils.py +356 -0
  158. rfd3/train.py +194 -0
  159. rfd3/trainer/dump_validation_structures.py +154 -0
  160. rfd3/trainer/fabric_trainer.py +923 -0
  161. rfd3/trainer/recycling.py +42 -0
  162. rfd3/trainer/rfd3.py +485 -0
  163. rfd3/trainer/trainer_utils.py +502 -0
  164. rfd3/transforms/conditioning_base.py +508 -0
  165. rfd3/transforms/conditioning_utils.py +200 -0
  166. rfd3/transforms/design_transforms.py +807 -0
  167. rfd3/transforms/dna_crop.py +523 -0
  168. rfd3/transforms/hbonds.py +407 -0
  169. rfd3/transforms/hbonds_hbplus.py +246 -0
  170. rfd3/transforms/ncaa_transforms.py +153 -0
  171. rfd3/transforms/pipelines.py +632 -0
  172. rfd3/transforms/ppi_transforms.py +541 -0
  173. rfd3/transforms/rasa.py +116 -0
  174. rfd3/transforms/symmetry.py +76 -0
  175. rfd3/transforms/training_conditions.py +552 -0
  176. rfd3/transforms/util_transforms.py +498 -0
  177. rfd3/transforms/virtual_atoms.py +305 -0
  178. rfd3/utils/inference.py +648 -0
  179. rfd3/utils/io.py +245 -0
  180. rfd3/utils/vizualize.py +276 -0
@@ -0,0 +1,515 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from rf3.chemical import NFRAMES, NHEAVY, frame_indices
4
+
5
+ # TODO: REFACTOR; COPIED FROM RF2AA. WE NEED TO ADD DOCSTRINGS, EXAMPLES, HOPEFULLY TESTS, AND CLEAN UP
6
+ from rf3.metrics.metric_utils import (
7
+ compute_mean_over_subsampled_pairs,
8
+ unbin_logits,
9
+ )
10
+ from rf3.utils.frames import (
11
+ get_frames,
12
+ mask_unresolved_frames_batched,
13
+ rigid_from_3_points,
14
+ )
15
+ from scipy.stats import spearmanr
16
+
17
+
18
class ConfidenceLoss(nn.Module):
    """Combined confidence-head loss: pLDDT, PAE, PDE and experimentally-resolved.

    Each of ``plddt``, ``pae``, ``pde`` and ``exp_resolved`` is a config object
    exposing at least ``n_bins``, ``max_value`` and ``weight`` (as read by
    ``forward`` and the ``calc_*`` helpers). ``rank_loss`` optionally enables a
    ListNet ranking term; ``log_statistics`` turns on correlation logging.
    """

    def __init__(
        self,
        plddt,
        pae,
        pde,
        exp_resolved,
        weight=1,
        rank_loss=None,
        log_statistics=False,
    ):
        super().__init__()
        # Per-head binning/weight configurations.
        self.plddt = plddt
        self.pae = pae
        self.pde = pde
        self.exp_resolved = exp_resolved
        # Global multiplier applied to the combined confidence loss.
        self.weight = weight
        # Optional ranking-loss config and statistics switch.
        self.rank_loss = rank_loss
        self.log_statistics = log_statistics
        # Per-element cross-entropy so validity masks can be applied pre-reduction.
        self.cce = nn.CrossEntropyLoss(reduction="none")
        self.eps = 1e-6
40
def forward(
    self,
    network_input,
    network_output,
    loss_input,
):
    """Compute the weighted confidence loss and a dict of detached terms.

    Args:
        network_input: unused here; kept for the trainer's loss-call interface.
        network_output: dict with logits "plddt", "pae", "pde",
            "exp_resolved" and rollout coordinates "X_pred_rollout_L".
        loss_input: dict with "X_gt_L", "crd_mask_L", "seq", "is_real_atom",
            "frame_atom_idxs", "rep_atom_idxs" (and "atom_frames" used by
            ``calc_pae``).

    Returns:
        (self.weight * confidence_loss, loss_dict) where loss_dict holds the
        detached per-term losses plus optional ranking/correlation statistics.
    """
    X_gt_L = loss_input["X_gt_L"]
    X_exists_L = loss_input["crd_mask_L"]
    X_pred_L = network_output["X_pred_rollout_L"]
    B = X_pred_L.shape[0]
    I = loss_input["is_real_atom"].shape[0]

    # --- pLDDT: cross-entropy against the binned true lDDT per atom ---
    true_lddt_binned, is_resolved_I = self.calc_lddt(
        X_pred_L, X_gt_L, X_exists_L, loss_input["seq"], loss_input["is_real_atom"]
    )

    plddt_logits = (
        network_output["plddt"]
        .reshape(-1, I, NHEAVY, self.plddt.n_bins)
        .permute(0, 3, 1, 2)
    )
    plddt_loss = (
        self.cce(
            plddt_logits,
            true_lddt_binned[..., :NHEAVY].long(),
        )
        * is_resolved_I[..., :NHEAVY]
    )
    plddt_loss = plddt_loss.sum() / (is_resolved_I.sum() + self.eps)

    # --- PAE: frame-aligned error over (frame, CA) pairs ---
    pae_logits = network_output["pae"]
    true_pae_binned, pae_logits, valid_pae_pairs = self.calc_pae(
        loss_input,
        X_pred_L,
        X_gt_L,
        X_exists_L,
        pae_logits,
        loss_input["frame_atom_idxs"],
    )
    pae_loss = self.cce(pae_logits, true_pae_binned) * valid_pae_pairs
    pae_loss = pae_loss.sum() / (valid_pae_pairs.sum() + self.eps)

    # --- PDE: binned |d_gt - d_pred| over representative-atom pairs ---
    true_pde_binned, is_valid_pair = self.calc_pde(
        X_pred_L, X_gt_L, X_exists_L, loss_input["rep_atom_idxs"]
    )
    pde_logits = network_output["pde"].permute(0, 3, 1, 2)
    pde_loss = self.cce(pde_logits, true_pde_binned) * is_valid_pair
    pde_loss = pde_loss.sum() / (is_valid_pair.sum() + self.eps)

    # --- experimentally resolved: classify per-atom resolvedness ---
    exp_resolved_logits = network_output["exp_resolved"]
    exp_resolved_loss = (
        self.cce(
            exp_resolved_logits.reshape(
                B, I, NHEAVY, self.exp_resolved.n_bins
            ).permute(0, 3, 1, 2),
            is_resolved_I[:, :, :NHEAVY].long(),
        )
        * loss_input["is_real_atom"][:, :NHEAVY]
    )
    exp_resolved_loss = exp_resolved_loss.sum() / (
        loss_input["is_real_atom"][:, :NHEAVY].sum() + self.eps
    )
    exp_resolved_loss = exp_resolved_loss / B

    loss_dict = dict(
        plddt_loss=plddt_loss.detach(),
        pae_loss=pae_loss.detach(),
        pde_loss=pde_loss.detach(),
        exp_resolved_loss=exp_resolved_loss.detach(),
    )

    confidence_loss = (
        self.plddt.weight * plddt_loss
        + self.pae.weight * pae_loss
        + self.pde.weight * pde_loss
        + self.exp_resolved.weight * exp_resolved_loss
    )

    # BUGFIX: ``rank_loss`` defaults to None in __init__, but the original
    # code dereferenced ``self.rank_loss.use_listnet_loss`` unconditionally
    # (in this condition and again below), crashing with an AttributeError
    # under the default configuration. Guard the optional config explicitly.
    use_listnet = self.rank_loss is not None and self.rank_loss.use_listnet_loss

    if self.log_statistics or use_listnet:
        # Get correlations across and within batches.
        # Recover continuous "true" values per metric from the binned targets.
        true_lddt, true_lddt_per_structure = self.get_true_metrics(
            true_lddt_binned, self.plddt, is_resolved_I
        )
        true_pae, true_pae_per_structure = self.get_true_metrics(
            true_pae_binned, self.pae, valid_pae_pairs
        )
        true_pde, true_pde_per_structure = self.get_true_metrics(
            true_pde_binned, self.pde, is_valid_pair
        )

        # Reorder the input tensors to (B, n_bins, ...) for unbinning;
        # pae and pde logits were already permuted above.
        plddt_logit_stack = network_output["plddt"]
        plddt_per_structure = unbin_logits(
            plddt_logit_stack.reshape(
                -1,
                I,
                NHEAVY,
                self.plddt.n_bins,
            )
            .permute(0, 3, 1, 2)
            .float(),
            self.plddt.max_value,
            self.plddt.n_bins,
        )
        pae_per_structure = unbin_logits(
            pae_logits, self.pae.max_value, self.pae.n_bins
        )
        pde_per_structure = unbin_logits(
            pde_logits, self.pde.max_value, self.pde.n_bins
        )

        plddt_per_structure = torch.cat(
            [
                compute_mean_over_subsampled_pairs(
                    plddt_per_structure[i][None],
                    is_resolved_I[i, ..., :NHEAVY],
                )
                for i in range(plddt_logit_stack.shape[0])
            ],
            dim=0,
        )
        # NOTE(review): the PAE mean below masks with is_valid_pair (the PDE
        # pair mask), not valid_pae_pairs — confirm this is intentional.
        pae_per_structure = torch.cat(
            [
                compute_mean_over_subsampled_pairs(
                    pae_per_structure[i][None], is_valid_pair[i]
                )
                for i in range(pae_per_structure.shape[0])
            ],
            dim=0,
        )
        pde_per_structure = torch.cat(
            [
                compute_mean_over_subsampled_pairs(
                    pde_per_structure[i][None], is_valid_pair[i]
                )
                for i in range(pde_per_structure.shape[0])
            ],
            dim=0,
        )

        plddt = plddt_per_structure.mean()
        pae = pae_per_structure.mean()
        pde = pde_per_structure.mean()

        if self.log_statistics:
            self.log_correlation_statistics(
                plddt,
                pae,
                pde,
                true_lddt,
                true_pae,
                true_pde,
                true_lddt_per_structure,
                true_pae_per_structure,
                true_pde_per_structure,
                plddt_per_structure,
                pae_per_structure,
                pde_per_structure,
                loss_dict,
            )

        if use_listnet:
            # An easy way of incentivizing ranking accuracy (ListNet):
            plddt_rank_loss = self.listnet_loss(
                true_lddt_per_structure, plddt_per_structure
            )
            pae_rank_loss = self.listnet_loss(
                true_pae_per_structure, pae_per_structure
            )
            pde_rank_loss = self.listnet_loss(
                true_pde_per_structure, pde_per_structure
            )

            rank_loss_dict = dict(
                plddt_rank_loss=plddt_rank_loss.detach(),
                pae_rank_loss=pae_rank_loss.detach(),
                pde_rank_loss=pde_rank_loss.detach(),
            )
            loss_dict.update(rank_loss_dict)
            confidence_loss += (
                plddt_rank_loss + pae_rank_loss + pde_rank_loss
            ) * self.rank_loss.weight

    return self.weight * confidence_loss, loss_dict
227
def calc_lddt(self, X_pred_L, X_gt_L, X_exists_L, seq, is_real_atom):
    """Compute per-atom lDDT of the prediction and bin it as a CE target.

    Args:
        X_pred_L: (B, L, 3) predicted coordinates over the flat real-atom axis.
        X_gt_L: (B, L, 3) ground-truth coordinates.
        X_exists_L: (B, L) mask of resolved atoms.
        seq: unused; kept for interface compatibility with callers.
        is_real_atom: (I, A) bool mask mapping token/atom slots to the flat
            axis L (assumes L == is_real_atom.sum() — TODO confirm with caller).

    Returns:
        Tuple of the binned per-atom lDDT with shape (B, I, A) and the
        resolved-atom mask scattered to the same (B, I, A) layout.
    """
    tok_idx = is_real_atom.nonzero()[:, 0]

    I = is_real_atom.shape[0]
    B = X_pred_L.shape[0]

    # If the structure is too big, compute the (L, L) distance matrices over
    # two half-batches to reduce peak memory, then re-concatenate.
    if I > 384:
        ground_truth_distances = torch.cdist(
            X_gt_L[: B // 2],
            X_gt_L[: B // 2],
            compute_mode="donot_use_mm_for_euclid_dist",
        )
        predicted_distances = torch.cdist(
            X_pred_L[: B // 2],
            X_pred_L[: B // 2],
            compute_mode="donot_use_mm_for_euclid_dist",
        )

        ground_truth_distances2 = torch.cdist(
            X_gt_L[B // 2 :],
            X_gt_L[B // 2 :],
            compute_mode="donot_use_mm_for_euclid_dist",
        )
        predicted_distances2 = torch.cdist(
            X_pred_L[B // 2 :],
            X_pred_L[B // 2 :],
            compute_mode="donot_use_mm_for_euclid_dist",
        )

        ground_truth_distances = torch.cat(
            (ground_truth_distances, ground_truth_distances2), dim=0
        )
        predicted_distances = torch.cat(
            (predicted_distances, predicted_distances2), dim=0
        )
    else:
        ground_truth_distances = torch.cdist(
            X_gt_L, X_gt_L, compute_mode="donot_use_mm_for_euclid_dist"
        )
        predicted_distances = torch.cdist(
            X_pred_L, X_pred_L, compute_mode="donot_use_mm_for_euclid_dist"
        )

    X_exists_LL = X_exists_L.unsqueeze(-1) * X_exists_L.unsqueeze(-2)

    difference_distances = torch.abs(ground_truth_distances - predicted_distances)
    # lDDT per pair: fraction of the 0.5/1/2/4 A thresholds satisfied.
    # (Removed a dead `lddt_matrix = torch.zeros_like(...)` pre-assignment
    # that was immediately overwritten.)
    lddt_matrix = (
        0.25 * (difference_distances < 4.0)
        + 0.25 * (difference_distances < 2.0)
        + 0.25 * (difference_distances < 1.0)
        + 0.25 * (difference_distances < 0.5)
    )
    in_same_residue_LL = tok_idx.unsqueeze(-1) == tok_idx.unsqueeze(-2)
    close_distances_LL = ground_truth_distances < 15.0

    # Include distances where both atoms are resolved, not in the same
    # residue, and within an inclusion radius (15 A).
    mask_LL = X_exists_LL * ~in_same_residue_LL * close_distances_LL
    lddt_per_atom_L = (lddt_matrix * mask_LL).sum(-1) / (mask_LL.sum(-1) + self.eps)

    # Scatter flat per-atom values back into the (I, A) token layout,
    # aggregating only over the resolved atoms in each residue.
    lddt_per_atom_I = torch.zeros_like(is_real_atom, dtype=torch.float32)
    lddt_per_atom_I = lddt_per_atom_I.unsqueeze(0).repeat(B, 1, 1)

    lddt_per_atom_I[:, is_real_atom] = lddt_per_atom_L
    X_exists_I = torch.zeros_like(is_real_atom, dtype=torch.bool)
    X_exists_I = X_exists_I.unsqueeze(0).repeat(B, 1, 1)
    X_exists_I[:, is_real_atom] = X_exists_L
    lddt_per_atom_binned = self.bin_values(
        lddt_per_atom_I, max_value=self.plddt.max_value, n_bins=self.plddt.n_bins
    )

    return lddt_per_atom_binned, X_exists_I
302
def calc_pae(
    self,
    loss_input,
    X_pred_L,
    X_gt_L,
    X_exists_L,
    pae_logits,
    frame_atom_idxs,
    eps=1e-4,
):
    """Build the binned PAE target, reordered logits, and validity mask.

    Gathers three frame-defining atoms per token, constructs rigid frames
    for prediction and ground truth, expresses every CA in every frame's
    local coordinates, and bins the norm of the pred/GT discrepancy.

    Args:
        loss_input: dict; reads "seq" and "atom_frames".
        X_pred_L / X_gt_L: (B, L, 3) predicted / ground-truth coordinates.
        X_exists_L: (B, L) resolved-atom mask.
        pae_logits: raw PAE logits; returned permuted to (B, n_bins, ...).
        frame_atom_idxs: per-token indices of the three frame atoms
            (assumes shape (I, 3) before batching — TODO confirm).
        eps: numerical floor inside the sqrt for the error norm.

    Returns:
        (binned PAE target, permuted pae_logits, (B, I, I) valid-pair mask).
    """
    seq = loss_input["seq"]
    atom_frames = loss_input["atom_frames"]
    B = X_pred_L.shape[0]

    # Construct the backbone atoms in the faux atom-36 representation so we
    # can use existing machinery to get frames.
    frame_atom_idxs = frame_atom_idxs.unsqueeze(0).expand(B, -1, -1)
    X_pred_I = torch.zeros(B, seq.shape[-1], 36, 3, device=X_pred_L.device)
    X_pred_I[..., 0, :] = torch.gather(
        X_pred_L, 1, frame_atom_idxs[..., 0].unsqueeze(-1).expand(-1, -1, 3)
    )
    X_pred_I[..., 1, :] = torch.gather(
        X_pred_L, 1, frame_atom_idxs[..., 1].unsqueeze(-1).expand(-1, -1, 3)
    )
    X_pred_I[..., 2, :] = torch.gather(
        X_pred_L, 1, frame_atom_idxs[..., 2].unsqueeze(-1).expand(-1, -1, 3)
    )

    # Same gather for the ground-truth coordinates.
    X_gt_I = torch.zeros(B, seq.shape[-1], 36, 3, device=X_gt_L.device)
    X_gt_I[..., 0, :] = torch.gather(
        X_gt_L, 1, frame_atom_idxs[..., 0].unsqueeze(-1).expand(-1, -1, 3)
    )
    X_gt_I[..., 1, :] = torch.gather(
        X_gt_L, 1, frame_atom_idxs[..., 1].unsqueeze(-1).expand(-1, -1, 3)
    )
    X_gt_I[..., 2, :] = torch.gather(
        X_gt_L, 1, frame_atom_idxs[..., 2].unsqueeze(-1).expand(-1, -1, 3)
    )

    # Mark which of the three gathered frame atoms are actually resolved.
    atom_mask = torch.zeros(
        B, seq.shape[-1], 36, device=X_exists_L.device, dtype=torch.bool
    )
    atom_mask[..., 0] = torch.gather(X_exists_L, 1, frame_atom_idxs[..., 0])
    atom_mask[..., 1] = torch.gather(X_exists_L, 1, frame_atom_idxs[..., 1])
    atom_mask[..., 2] = torch.gather(X_exists_L, 1, frame_atom_idxs[..., 2])

    frames, frame_mask = get_frames(
        0,
        0,
        seq.unsqueeze(0).repeat(B, 1),
        frame_indices.to(seq.device),
        atom_frames,
    )

    N, L, natoms, _ = X_pred_I.shape

    # Flatten middle dims so we can gather across residues.
    X_prime = X_pred_I.reshape(N, L * natoms, -1, 3).repeat(1, 1, NFRAMES, 1)
    Y_prime = X_gt_I.reshape(N, L * natoms, -1, 3).repeat(1, 1, NFRAMES, 1)
    frames_reindex_batched, frame_mask_batched = mask_unresolved_frames_batched(
        frames, frame_mask, atom_mask
    )

    # Gather the three frame atoms and build predicted rigid frames (uX, tX).
    X_x = torch.gather(
        X_prime, 1, frames_reindex_batched[..., 0:1].repeat(1, 1, 1, 3)
    )
    X_y = torch.gather(
        X_prime, 1, frames_reindex_batched[..., 1:2].repeat(1, 1, 1, 3)
    )
    X_z = torch.gather(
        X_prime, 1, frames_reindex_batched[..., 2:3].repeat(1, 1, 1, 3)
    )
    uX, tX = rigid_from_3_points(X_x, X_y, X_z)

    # Same construction for ground-truth rigid frames (uY, tY).
    Y_x = torch.gather(
        Y_prime, 1, frames_reindex_batched[..., 0:1].repeat(1, 1, 1, 3)
    )
    Y_y = torch.gather(
        Y_prime, 1, frames_reindex_batched[..., 1:2].repeat(1, 1, 1, 3)
    )
    Y_z = torch.gather(
        Y_prime, 1, frames_reindex_batched[..., 2:3].repeat(1, 1, 1, 3)
    )
    uY, tY = rigid_from_3_points(Y_x, Y_y, Y_z)

    # Keep only the first (backbone) frame per token.
    uX = uX[:, :, 0]
    uY = uY[:, :, 0]

    # Compute xij_ca across the batch: each predicted CA expressed in each
    # predicted frame's local coordinates.
    # uX: (B, L, 3), X_pred_I: (B, A, 3), X_y: (B, L, 3)
    xij_ca = torch.einsum(
        "bfji,bfaj->bfai",
        uX,  # select valid frames for backbone, shape (B, N_valid_frames, 3)
        X_pred_I[:, None, :, 1] - X_y[:, :, None, 0],
    )  # Result: (B, N_valid_frames, N_valid_ca, 3)

    # Compute xij_ca_t across the batch: ground-truth counterpart.
    # uY: (B, L, 3), X_gt_I: (B, A, 3), Y_y: (B, L, 3)
    xij_ca_t = torch.einsum(
        "bfji,bfaj->bfai",
        uY,  # select valid frames for backbone, shape (B, N_valid_frames, 3)
        X_gt_I[:, None, :, 1] - Y_y[:, :, None, 0],
    )  # Result: (B, N_valid_frames, N_valid_ca, 3)

    valid_frames = frame_mask_batched[:, :, 0]  # valid backbone frames (B,I)
    valid_ca = atom_mask[:, :, 1]  # valid CA atoms (B,I)
    valid_pairs = (
        valid_frames[:, :, None] & valid_ca[:, None, :]
    )  # valid pairs (B,I,I)

    # Aligned-error magnitude per (frame, CA) pair; detached — this is a
    # training target, not a differentiable path.
    eij_label = (
        torch.sqrt(torch.square(xij_ca - xij_ca_t).sum(dim=-1) + eps)
        .clone()
        .detach()
    )
    true_pae_label = self.bin_values(
        eij_label, max_value=self.pae.max_value, n_bins=self.pae.n_bins
    )
    pae_logits = pae_logits.permute(0, 3, 1, 2)  # (1, nbins, N_frames, N_ca)

    return true_pae_label.detach(), pae_logits, valid_pairs
423
def calc_pde(self, X_pred_L, X_gt_L, X_exists_L, rep_atoms):
    """Bin the per-pair distance error over representative atoms.

    Selects one representative atom per token via ``rep_atoms``, then bins
    |d_gt - d_pred| for every atom pair using the PDE bin config.

    Returns:
        (detached binned error of shape (B, R, R),
         detached pair-validity mask where both atoms are resolved).
    """
    rep_pred = X_pred_L.index_select(1, rep_atoms)
    rep_gt = X_gt_L.index_select(1, rep_atoms)
    rep_exists = X_exists_L.index_select(1, rep_atoms)

    dist_pred = torch.cdist(
        rep_pred, rep_pred, compute_mode="donot_use_mm_for_euclid_dist"
    )
    dist_gt = torch.cdist(
        rep_gt, rep_gt, compute_mode="donot_use_mm_for_euclid_dist"
    )

    abs_error = (dist_gt - dist_pred).abs()
    binned_error = self.bin_values(
        abs_error, max_value=self.pde.max_value, n_bins=self.pde.n_bins
    )
    pair_mask = rep_exists.unsqueeze(-1) * rep_exists.unsqueeze(-2)
    return binned_error.detach(), pair_mask.detach()
440
def bin_values(self, values, max_value, n_bins):
    """Discretize ``values`` into ``n_bins`` uniform bins over [0, max_value].

    Returns integer bin indices in [0, n_bins - 1]; values above
    ``max_value`` fall into the last bin.
    """
    step = max_value / n_bins
    # n_bins - 1 interior edges at step, 2*step, ..., max_value - step.
    edges = torch.linspace(step, max_value - step, n_bins - 1, device=values.device)
    # right=True: a value exactly on an edge is assigned to the bin above it.
    return torch.bucketize(values, edges, right=True)
448
def log_correlation_statistics(
    self,
    plddt,
    pae,
    pde,
    true_lddt,
    true_pae,
    true_pde,
    true_lddt_per_structure,
    true_pae_per_structure,
    true_pde_per_structure,
    plddt_per_structure,
    pae_per_structure,
    pde_per_structure,
    loss_dict,
):
    """Add predicted/true confidence summaries to ``loss_dict`` in place.

    Logs the batch-mean predicted and true values per metric, the Spearman
    rank correlation between per-structure predictions and truth, and the
    per-structure max-min spreads.
    """

    def _rank_corr(truth, pred):
        # Spearman rank correlation across structures (p-value discarded).
        rho, _pvalue = spearmanr(truth.cpu().numpy(), pred.cpu().numpy())
        return torch.tensor(rho)

    def _spread(per_structure):
        return per_structure.max() - per_structure.min()

    loss_dict.update(
        {
            "pred_err_plddt": plddt,
            "pred_err_pae": pae,
            "pred_err_pde": pde,
            "true_err_plddt": true_lddt,
            "true_err_pae": true_pae,
            "true_err_pde": true_pde,
            "plddt_rank_corr": _rank_corr(true_lddt_per_structure, plddt_per_structure),
            "pae_rank_corr": _rank_corr(true_pae_per_structure, pae_per_structure),
            "pde_rank_corr": _rank_corr(true_pde_per_structure, pde_per_structure),
            "plddt_spread": _spread(plddt_per_structure),
            "pae_spread": _spread(pae_per_structure),
            "pde_spread": _spread(pde_per_structure),
            "true_plddt_spread": _spread(true_lddt_per_structure),
            "true_pae_spread": _spread(true_pae_per_structure),
            "true_pde_spread": _spread(true_pde_per_structure),
        }
    )
498
def get_true_metrics(self, true_metric_binned, metric_config, mask):
    """Recover continuous metric values from bin indices via bin centers.

    Assumes bins uniformly cover [0, metric_config.max_value]; bin index k
    maps to its center (k + 1) * width - width / 2.

    Returns:
        (masked mean over all entries, per-structure masked means of shape (B,)).
    """
    width = metric_config.max_value / metric_config.n_bins
    centers = (
        (true_metric_binned.detach() + 1) * width - (width / 2)
    ) * mask
    per_structure = centers.sum(dim=(1, 2)) / (
        mask.sum(dim=(1, 2)) + self.eps
    )
    overall = centers.sum() / (mask.sum() + self.eps)

    return overall, per_structure
511
def listnet_loss(self, true_metric_per_structure, pred_metric_per_structure):
    """ListNet top-one ranking loss.

    Cross-entropy between the softmax distributions (over structures) of
    the true and predicted per-structure scores; averaged, not summed,
    over structures.
    """
    target_dist = torch.softmax(true_metric_per_structure, dim=0)
    pred_dist = torch.softmax(pred_metric_per_structure, dim=0)
    return -torch.mean(target_dist * torch.log(pred_dist))