rc-foundry 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (180) hide show
  1. foundry/__init__.py +57 -0
  2. foundry/callbacks/__init__.py +5 -0
  3. foundry/callbacks/callback.py +116 -0
  4. foundry/callbacks/health_logging.py +419 -0
  5. foundry/callbacks/metrics_logging.py +211 -0
  6. foundry/callbacks/timing_logging.py +67 -0
  7. foundry/callbacks/train_logging.py +278 -0
  8. foundry/common.py +108 -0
  9. foundry/constants.py +28 -0
  10. foundry/hydra/resolvers.py +77 -0
  11. foundry/inference_engines/base.py +235 -0
  12. foundry/inference_engines/checkpoint_registry.py +66 -0
  13. foundry/metrics/__init__.py +12 -0
  14. foundry/metrics/losses.py +30 -0
  15. foundry/metrics/metric.py +319 -0
  16. foundry/model/layers/blocks.py +47 -0
  17. foundry/testing/__init__.py +6 -0
  18. foundry/testing/fixtures.py +19 -0
  19. foundry/testing/pytest_hooks.py +15 -0
  20. foundry/trainers/fabric.py +923 -0
  21. foundry/training/EMA.py +67 -0
  22. foundry/training/checkpoint.py +61 -0
  23. foundry/training/schedulers.py +91 -0
  24. foundry/utils/alignment.py +86 -0
  25. foundry/utils/components.py +415 -0
  26. foundry/utils/datasets.py +405 -0
  27. foundry/utils/ddp.py +103 -0
  28. foundry/utils/instantiators.py +72 -0
  29. foundry/utils/logging.py +279 -0
  30. foundry/utils/rigid.py +1460 -0
  31. foundry/utils/rotation_augmentation.py +65 -0
  32. foundry/utils/squashfs.py +172 -0
  33. foundry/utils/torch.py +317 -0
  34. foundry/utils/weights.py +271 -0
  35. foundry/version.py +34 -0
  36. foundry_cli/__init__.py +3 -0
  37. foundry_cli/download_checkpoints.py +281 -0
  38. mpnn/__init__.py +1 -0
  39. mpnn/collate/feature_collator.py +265 -0
  40. mpnn/inference.py +53 -0
  41. mpnn/inference_engines/mpnn.py +549 -0
  42. mpnn/loss/nll_loss.py +122 -0
  43. mpnn/metrics/nll.py +369 -0
  44. mpnn/metrics/sequence_recovery.py +440 -0
  45. mpnn/model/layers/graph_embeddings.py +2372 -0
  46. mpnn/model/layers/message_passing.py +332 -0
  47. mpnn/model/layers/position_wise_feed_forward.py +44 -0
  48. mpnn/model/layers/positional_encoding.py +98 -0
  49. mpnn/model/mpnn.py +2632 -0
  50. mpnn/pipelines/mpnn.py +162 -0
  51. mpnn/samplers/samplers.py +167 -0
  52. mpnn/train.py +341 -0
  53. mpnn/trainers/mpnn.py +193 -0
  54. mpnn/transforms/feature_aggregation/mpnn.py +184 -0
  55. mpnn/transforms/feature_aggregation/polymer_ligand_interface.py +76 -0
  56. mpnn/transforms/feature_aggregation/token_encodings.py +132 -0
  57. mpnn/transforms/feature_aggregation/user_settings.py +347 -0
  58. mpnn/transforms/polymer_ligand_interface.py +164 -0
  59. mpnn/utils/inference.py +2397 -0
  60. mpnn/utils/probability.py +37 -0
  61. mpnn/utils/weights.py +309 -0
  62. rc_foundry-0.1.1.dist-info/METADATA +239 -0
  63. rc_foundry-0.1.1.dist-info/RECORD +180 -0
  64. rc_foundry-0.1.1.dist-info/WHEEL +4 -0
  65. rc_foundry-0.1.1.dist-info/entry_points.txt +5 -0
  66. rc_foundry-0.1.1.dist-info/licenses/LICENSE.md +28 -0
  67. rf3/__init__.py +3 -0
  68. rf3/_version.py +33 -0
  69. rf3/alignment.py +79 -0
  70. rf3/callbacks/dump_validation_structures.py +101 -0
  71. rf3/callbacks/metrics_logging.py +324 -0
  72. rf3/chemical.py +1529 -0
  73. rf3/cli.py +77 -0
  74. rf3/data/cyclic_transform.py +78 -0
  75. rf3/data/extra_xforms.py +36 -0
  76. rf3/data/ground_truth_template.py +463 -0
  77. rf3/data/paired_msa.py +206 -0
  78. rf3/data/pipeline_utils.py +128 -0
  79. rf3/data/pipelines.py +558 -0
  80. rf3/diffusion_samplers/inference_sampler.py +222 -0
  81. rf3/inference.py +65 -0
  82. rf3/inference_engines/__init__.py +5 -0
  83. rf3/inference_engines/rf3.py +735 -0
  84. rf3/kinematics.py +354 -0
  85. rf3/loss/af3_confidence_loss.py +515 -0
  86. rf3/loss/af3_losses.py +655 -0
  87. rf3/loss/loss.py +179 -0
  88. rf3/metrics/chiral.py +179 -0
  89. rf3/metrics/clashing_chains.py +68 -0
  90. rf3/metrics/distogram.py +421 -0
  91. rf3/metrics/lddt.py +523 -0
  92. rf3/metrics/metadata.py +43 -0
  93. rf3/metrics/metric_utils.py +192 -0
  94. rf3/metrics/predicted_error.py +134 -0
  95. rf3/metrics/rasa.py +108 -0
  96. rf3/metrics/selected_distances.py +91 -0
  97. rf3/model/RF3.py +527 -0
  98. rf3/model/RF3_blocks.py +92 -0
  99. rf3/model/RF3_structure.py +303 -0
  100. rf3/model/layers/af3_auxiliary_heads.py +255 -0
  101. rf3/model/layers/af3_diffusion_transformer.py +544 -0
  102. rf3/model/layers/attention.py +313 -0
  103. rf3/model/layers/layer_utils.py +127 -0
  104. rf3/model/layers/mlff.py +118 -0
  105. rf3/model/layers/outer_product.py +59 -0
  106. rf3/model/layers/pairformer_layers.py +783 -0
  107. rf3/model/layers/structure_bias.py +56 -0
  108. rf3/scoring.py +1787 -0
  109. rf3/symmetry/resolve.py +284 -0
  110. rf3/train.py +194 -0
  111. rf3/trainers/rf3.py +570 -0
  112. rf3/util_module.py +47 -0
  113. rf3/utils/frames.py +109 -0
  114. rf3/utils/inference.py +665 -0
  115. rf3/utils/io.py +198 -0
  116. rf3/utils/loss.py +72 -0
  117. rf3/utils/predict_and_score.py +165 -0
  118. rf3/utils/predicted_error.py +673 -0
  119. rf3/utils/recycling.py +42 -0
  120. rf3/validate.py +140 -0
  121. rfd3/.gitignore +7 -0
  122. rfd3/Makefile +76 -0
  123. rfd3/__init__.py +12 -0
  124. rfd3/callbacks.py +66 -0
  125. rfd3/cli.py +41 -0
  126. rfd3/constants.py +212 -0
  127. rfd3/engine.py +543 -0
  128. rfd3/inference/datasets.py +193 -0
  129. rfd3/inference/input_parsing.py +1123 -0
  130. rfd3/inference/legacy_input_parsing.py +717 -0
  131. rfd3/inference/parsing.py +165 -0
  132. rfd3/inference/symmetry/atom_array.py +298 -0
  133. rfd3/inference/symmetry/checks.py +241 -0
  134. rfd3/inference/symmetry/contigs.py +63 -0
  135. rfd3/inference/symmetry/frames.py +355 -0
  136. rfd3/inference/symmetry/symmetry_utils.py +398 -0
  137. rfd3/metrics/design_metrics.py +465 -0
  138. rfd3/metrics/hbonds_hbplus_metrics.py +308 -0
  139. rfd3/metrics/hbonds_metrics.py +389 -0
  140. rfd3/metrics/losses.py +325 -0
  141. rfd3/metrics/metrics_utils.py +118 -0
  142. rfd3/metrics/sidechain_metrics.py +349 -0
  143. rfd3/model/RFD3.py +105 -0
  144. rfd3/model/RFD3_diffusion_module.py +387 -0
  145. rfd3/model/cfg_utils.py +81 -0
  146. rfd3/model/inference_sampler.py +635 -0
  147. rfd3/model/layers/attention.py +577 -0
  148. rfd3/model/layers/block_utils.py +580 -0
  149. rfd3/model/layers/blocks.py +777 -0
  150. rfd3/model/layers/chunked_pairwise.py +377 -0
  151. rfd3/model/layers/encoders.py +417 -0
  152. rfd3/model/layers/layer_utils.py +197 -0
  153. rfd3/model/layers/pairformer_layers.py +128 -0
  154. rfd3/run_inference.py +45 -0
  155. rfd3/testing/debug.py +139 -0
  156. rfd3/testing/debug_utils.py +73 -0
  157. rfd3/testing/testing_utils.py +356 -0
  158. rfd3/train.py +194 -0
  159. rfd3/trainer/dump_validation_structures.py +154 -0
  160. rfd3/trainer/fabric_trainer.py +923 -0
  161. rfd3/trainer/recycling.py +42 -0
  162. rfd3/trainer/rfd3.py +485 -0
  163. rfd3/trainer/trainer_utils.py +502 -0
  164. rfd3/transforms/conditioning_base.py +508 -0
  165. rfd3/transforms/conditioning_utils.py +200 -0
  166. rfd3/transforms/design_transforms.py +807 -0
  167. rfd3/transforms/dna_crop.py +523 -0
  168. rfd3/transforms/hbonds.py +407 -0
  169. rfd3/transforms/hbonds_hbplus.py +246 -0
  170. rfd3/transforms/ncaa_transforms.py +153 -0
  171. rfd3/transforms/pipelines.py +632 -0
  172. rfd3/transforms/ppi_transforms.py +541 -0
  173. rfd3/transforms/rasa.py +116 -0
  174. rfd3/transforms/symmetry.py +76 -0
  175. rfd3/transforms/training_conditions.py +552 -0
  176. rfd3/transforms/util_transforms.py +498 -0
  177. rfd3/transforms/virtual_atoms.py +305 -0
  178. rfd3/utils/inference.py +648 -0
  179. rfd3/utils/io.py +245 -0
  180. rfd3/utils/vizualize.py +276 -0
@@ -0,0 +1,673 @@
1
+ import itertools
2
+ from typing import List
3
+
4
+ import einops
5
+ import numpy as np
6
+ import pandas as pd
7
+ import torch
8
+ import tree
9
+ from beartype.typing import Any
10
+ from biotite.structure import AtomArray, AtomArrayStack
11
+ from omegaconf import DictConfig
12
+ from rf3.chemical import NHEAVY
13
+ from rf3.metrics.metric_utils import (
14
+ compute_mean_over_subsampled_pairs,
15
+ compute_min_over_subsampled_pairs,
16
+ create_chainwise_masks_1d,
17
+ create_chainwise_masks_2d,
18
+ create_interface_masks_2d,
19
+ spread_batch_into_dictionary,
20
+ unbin_logits,
21
+ )
22
+
23
+
24
def get_mean_atomwise_plddt(
    plddt_logits: torch.Tensor,
    is_real_atom: torch.Tensor,
    max_value: float,
) -> torch.Tensor:
    """Aggregate per-atom pLDDT logits into a single mean pLDDT per batch element.

    The flattened (atom-slot x bin) logit dimension is split apart, the logits
    are unbinned into scalar pLDDT values, and the mean is taken over all real
    (non-padding) atoms.

    Args:
        plddt_logits: Tensor of shape [B, n_token, max_atoms_in_a_token * n_bin] with logits
        is_real_atom: Boolean mask indicating which atoms are real (i.e., not padding).
            NOTE(review): previously documented as [B, n_token, max_atoms_in_a_token],
            but the masking below (`[:, :max_atoms_in_a_token].unsqueeze(0)`) only lines
            up if this is a 2D [n_token, n_atom_slots] mask whose slot axis is sliced to
            the heavy-atom count — confirm the caller's shape.
        max_value: Maximum value for pLDDT (assigned to the last bin)

    Returns:
        plddt: Tensor of shape [B,] with the mean atom-wise pLDDT for each batch
    """
    assert (
        plddt_logits.ndim == 3
    ), "plddt_logits must be a 3D tensor (B, n_token, max_atoms_in_a_token * n_bins)"

    # TODO: Replace with the last dimension of is_real_atom; right now that number is too large (36) because it includes hydrogens
    max_atoms_in_a_token = NHEAVY

    # Since the pLDDT logits have the last dimension (max_atoms_in_a_token * n_bins), we can calculate n_bins directly
    assert (
        plddt_logits.shape[-1] % max_atoms_in_a_token == 0
    ), "The last dimension of plddt_logits must be divisible by max_atoms_in_a_token!"
    n_bins = plddt_logits.shape[-1] // max_atoms_in_a_token

    # ... reshape to match what unbin_logits expects:
    # [..., n_token, n_bins * max_atoms_in_a_token] -> [..., n_bins, n_token, max_atoms_in_a_token]
    reshaped_plddt_logits = einops.rearrange(
        plddt_logits,
        "... n_token (max_atoms_in_a_token n_bins) -> ... n_bins n_token max_atoms_in_a_token",
        max_atoms_in_a_token=max_atoms_in_a_token,
        n_bins=n_bins,
    ).float()

    # Convert the binned logits into scalar pLDDT values (expected value over bins)
    plddt = unbin_logits(
        reshaped_plddt_logits,
        max_value,
        n_bins,
    )

    # Ensure the mask lives on the same device as the unbinned values
    is_real_atom = is_real_atom.to(device=plddt.device)

    # ... create mask indicating which atoms are "real" (i.e., not padding) and calculate the mean
    mask = is_real_atom[:, :max_atoms_in_a_token].unsqueeze(0)
    atomwise_plddt_mean = (plddt * mask).sum(dim=(1, 2)) / mask.sum(dim=(1, 2))

    return atomwise_plddt_mean
73
+
74
+
75
def compile_af3_confidence_outputs(
    plddt_logits: torch.Tensor,
    pae_logits: torch.Tensor,
    pde_logits: torch.Tensor,
    chain_iid_token_lvl: torch.Tensor,
    is_real_atom: torch.Tensor,
    example_id: str,
    confidence_loss_cfg: DictConfig | dict,
) -> dict[str, Any]:
    """Given the confidence logits, computes the confidence metrics for the model's predictions.

    Args:
        plddt_logits: pLDDT logits of shape [B, n_token, NHEAVY * n_bins].
        pae_logits: pAE logits with the bin dimension last (permuted to (B, n_bins, ...) below).
        pde_logits: pDE logits with the bin dimension last.
        chain_iid_token_lvl: Per-token chain identifiers used to build chain/interface masks.
        is_real_atom: Boolean mask of real (non-padding) atoms; only the first NHEAVY
            atom slots are used.
        example_id: Identifier copied into every output row.
        confidence_loss_cfg: Config holding plddt/pae/pde sub-configs, each with
            max_value and n_bins. Accessed via attributes, so a plain dict must
            support attribute access (e.g. an OmegaConf DictConfig).

    Returns:
        dict[str, Any]: A dictionary containing the following:
            - confidence_df: A DataFrame containing the aggregate confidence metrics at the chain- and interface-level
            - plddt: The unbinned pLDDT values
            - pae: The unbinned pAE values
            - pde: The unbinned pDE values
    """
    # TODO: Refactor to accept an AtomArray
    # TODO: Taking the confidence_loss_cfg does not align with functional programming
    # best-practices; we should instead take the max_value and n_bins as arguments

    # Reorder the input tensors to be in (B, n_bins, ...) format for unbinning
    plddt = unbin_logits(
        plddt_logits.reshape(
            -1,
            plddt_logits.shape[1],
            NHEAVY,
            confidence_loss_cfg.plddt.n_bins,
        )
        .permute(0, 3, 1, 2)
        .float(),
        confidence_loss_cfg.plddt.max_value,
        confidence_loss_cfg.plddt.n_bins,
    )

    # Unbin the pae and pde logits
    pae = unbin_logits(
        pae_logits.permute(0, 3, 1, 2).float(),
        confidence_loss_cfg.pae.max_value,
        confidence_loss_cfg.pae.n_bins,
    )
    pde = unbin_logits(
        pde_logits.permute(0, 3, 1, 2).float(),
        confidence_loss_cfg.pde.max_value,
        confidence_loss_cfg.pde.n_bins,
    )

    # Calculate interface metrics (mean/min over inter-chain token pairs)
    interface_masks = create_interface_masks_2d(chain_iid_token_lvl, device=pae.device)
    pae_interface = {
        k: spread_batch_into_dictionary(compute_mean_over_subsampled_pairs(pae, v))
        for k, v in interface_masks.items()
    }
    pde_interface = {
        k: spread_batch_into_dictionary(compute_mean_over_subsampled_pairs(pde, v))
        for k, v in interface_masks.items()
    }

    pae_interface_min = {
        k: spread_batch_into_dictionary(compute_min_over_subsampled_pairs(pae, v))
        for k, v in interface_masks.items()
    }

    pde_interface_min = {
        k: spread_batch_into_dictionary(compute_min_over_subsampled_pairs(pde, v))
        for k, v in interface_masks.items()
    }

    # Calculate chainwise metrics (intra-chain token pairs)
    chain_masks_2d = create_chainwise_masks_2d(chain_iid_token_lvl, device=pae.device)
    pae_chainwise = {
        k: spread_batch_into_dictionary(compute_mean_over_subsampled_pairs(pae, v))
        for k, v in chain_masks_2d.items()
    }
    pde_chainwise = {
        k: spread_batch_into_dictionary(compute_mean_over_subsampled_pairs(pde, v))
        for k, v in chain_masks_2d.items()
    }

    # pLDDT is per-atom, so combine the 1D chain mask with the real-atom mask
    chain_masks_1d = create_chainwise_masks_1d(
        chain_iid_token_lvl, device=is_real_atom.device
    )
    plddt_chainwise = {
        k: spread_batch_into_dictionary(
            compute_mean_over_subsampled_pairs(
                plddt, is_real_atom[..., :NHEAVY] * v[:, None]
            )
        )
        for k, v in chain_masks_1d.items()
    }

    # Aggregate confidence data
    confidence_data = {
        "example_id": example_id,
        "mean_plddt": spread_batch_into_dictionary(
            compute_mean_over_subsampled_pairs(plddt, is_real_atom[..., :NHEAVY])
        ),
        "mean_pae": spread_batch_into_dictionary(pae.mean(dim=(-1, -2))),
        "mean_pde": spread_batch_into_dictionary(pde.mean(dim=(-1, -2))),
        "chain_wise_mean_plddt": plddt_chainwise,
        "chain_wise_mean_pae": pae_chainwise,
        "chain_wise_mean_pde": pde_chainwise,
        "interface_wise_mean_pae": pae_interface,
        "interface_wise_mean_pde": pde_interface,
        "interface_wise_min_pae": pae_interface_min,
        "interface_wise_min_pde": pde_interface_min,
    }

    # Generate DataFrame rows
    num_batches = plddt.shape[0]
    chains = np.unique(chain_iid_token_lvl)
    chain_pairs = list(itertools.combinations(chains, 2))

    # For every batch, chain, and interface (chain pair), generate a dataframe row
    chain_rows = [
        {
            "example_id": example_id,
            "chain_chainwise": chain,
            "chainwise_plddt": confidence_data["chain_wise_mean_plddt"][chain][
                batch_idx
            ],
            "chainwise_pde": confidence_data["chain_wise_mean_pde"][chain][batch_idx],
            "chainwise_pae": confidence_data["chain_wise_mean_pae"][chain][batch_idx],
            "overall_plddt": confidence_data["mean_plddt"][batch_idx],
            "overall_pde": confidence_data["mean_pde"][batch_idx],
            "overall_pae": confidence_data["mean_pae"][batch_idx],
            "batch_idx": batch_idx,
        }
        for batch_idx in range(num_batches)
        for chain in chains
    ]

    interface_rows = [
        {
            "example_id": example_id,
            "chain_i_interface": chain_i,
            "chain_j_interface": chain_j,
            "pae_interface": confidence_data["interface_wise_mean_pae"][
                (chain_i, chain_j)
            ][batch_idx],
            "pde_interface": confidence_data["interface_wise_mean_pde"][
                (chain_i, chain_j)
            ][batch_idx],
            "min_pae_interface": confidence_data["interface_wise_min_pae"][
                (chain_i, chain_j)
            ][batch_idx],
            "min_pde_interface": confidence_data["interface_wise_min_pde"][
                (chain_i, chain_j)
            ][batch_idx],
            "overall_plddt": confidence_data["mean_plddt"][batch_idx],
            "overall_pde": confidence_data["mean_pde"][batch_idx],
            "overall_pae": confidence_data["mean_pae"][batch_idx],
            "batch_idx": batch_idx,
        }
        for batch_idx in range(num_batches)
        for (chain_i, chain_j) in chain_pairs
    ]

    return {
        # FIX: was pd.DataFrame(itertools.chain([*rows...])) — wrapping an
        # already-materialized list in chain() is a no-op; pass the list directly.
        "confidence_df": pd.DataFrame([*chain_rows, *interface_rows]),
        "plddt": plddt,
        "pae": pae,
        "pde": pde,
    }
239
+
240
+
241
def compile_af3_style_confidence_outputs(
    plddt_logits: torch.Tensor,
    pae_logits: torch.Tensor,
    pde_logits: torch.Tensor,
    chain_iid_token_lvl: torch.Tensor | np.ndarray,
    is_real_atom: torch.Tensor,
    atom_array: AtomArray,
    confidence_loss_cfg: DictConfig | dict,
    batch_idx: int = 0,
) -> dict[str, Any]:
    """Compile confidence outputs in AlphaFold3-compatible format.

    Args:
        plddt_logits: pLDDT logits of shape [B, n_token, NHEAVY * n_bins].
        pae_logits: pAE logits with the bin dimension last.
        pde_logits: pDE logits with the bin dimension last.
        chain_iid_token_lvl: Per-token chain identifiers (moved to numpy if a tensor).
        is_real_atom: Boolean mask of real atoms; only the first NHEAVY slots are used.
        atom_array: Structure whose chain_id annotation supplies per-atom chain ids.
        confidence_loss_cfg: Config with plddt/pae/pde sub-configs (max_value, n_bins),
            accessed via attributes.
        batch_idx: Which diffusion-batch element to summarize (default 0).

    Returns a dict with:
        - summary_confidences: Dict for {name}_summary_confidences.json
        - confidences: Dict for {name}_confidences.json (per-atom data)
        - plddt, pae, pde: Raw tensors for further processing
    """
    # Unbin logits into scalar values, with bins moved to dim 1 as unbin_logits expects
    plddt = unbin_logits(
        plddt_logits.reshape(
            -1,
            plddt_logits.shape[1],
            NHEAVY,
            confidence_loss_cfg.plddt.n_bins,
        )
        .permute(0, 3, 1, 2)
        .float(),
        confidence_loss_cfg.plddt.max_value,
        confidence_loss_cfg.plddt.n_bins,
    )

    pae = unbin_logits(
        pae_logits.permute(0, 3, 1, 2).float(),
        confidence_loss_cfg.pae.max_value,
        confidence_loss_cfg.pae.n_bins,
    )
    pde = unbin_logits(
        pde_logits.permute(0, 3, 1, 2).float(),
        confidence_loss_cfg.pde.max_value,
        confidence_loss_cfg.pde.n_bins,
    )

    # Get chain information
    if isinstance(chain_iid_token_lvl, torch.Tensor):
        chain_iid_token_lvl = chain_iid_token_lvl.cpu().numpy()
    chains = list(np.unique(chain_iid_token_lvl))
    n_chains = len(chains)

    # Calculate chainwise metrics
    chain_masks_1d = create_chainwise_masks_1d(
        chain_iid_token_lvl, device=is_real_atom.device
    )

    # Chain-level pLDDT (mean over that chain's real atoms)
    chain_plddt = {}
    for chain, mask in chain_masks_1d.items():
        chain_plddt[chain] = compute_mean_over_subsampled_pairs(
            plddt, is_real_atom[..., :NHEAVY] * mask[:, None]
        )[batch_idx].item()

    # FIX: removed a dead intra-chain PAE computation (`chain_pae` built from
    # create_chainwise_masks_2d) — its result was never used in any output.

    # Chain-pair PAE/PDE (inter-chain, for iptm-like metric)
    interface_masks = create_interface_masks_2d(chain_iid_token_lvl, device=pae.device)
    chain_pair_pae = {}
    chain_pair_pae_min = {}
    chain_pair_pde = {}
    chain_pair_pde_min = {}
    for (chain_i, chain_j), mask in interface_masks.items():
        chain_pair_pae[(chain_i, chain_j)] = compute_mean_over_subsampled_pairs(
            pae, mask
        )[batch_idx].item()
        chain_pair_pae_min[(chain_i, chain_j)] = compute_min_over_subsampled_pairs(
            pae, mask
        )[batch_idx].item()
        chain_pair_pde[(chain_i, chain_j)] = compute_mean_over_subsampled_pairs(
            pde, mask
        )[batch_idx].item()
        chain_pair_pde_min[(chain_i, chain_j)] = compute_min_over_subsampled_pairs(
            pde, mask
        )[batch_idx].item()

    # Overall metrics for this batch
    overall_plddt = compute_mean_over_subsampled_pairs(
        plddt, is_real_atom[..., :NHEAVY]
    )[batch_idx].item()
    overall_pae = pae[batch_idx].mean().item()
    overall_pde = pde[batch_idx].mean().item()

    # Build chain_pair matrices (NxN); diagonal and missing pairs stay None
    chain_pair_pae_matrix = [[None] * n_chains for _ in range(n_chains)]
    chain_pair_pae_min_matrix = [[None] * n_chains for _ in range(n_chains)]
    chain_pair_pde_matrix = [[None] * n_chains for _ in range(n_chains)]
    chain_pair_pde_min_matrix = [[None] * n_chains for _ in range(n_chains)]
    for i, chain_i in enumerate(chains):
        for j, chain_j in enumerate(chains):
            if i != j and (chain_i, chain_j) in chain_pair_pae:
                chain_pair_pae_matrix[i][j] = round(
                    chain_pair_pae[(chain_i, chain_j)], 2
                )
                chain_pair_pae_min_matrix[i][j] = round(
                    chain_pair_pae_min[(chain_i, chain_j)], 2
                )
                chain_pair_pde_matrix[i][j] = round(
                    chain_pair_pde[(chain_i, chain_j)], 2
                )
                chain_pair_pde_min_matrix[i][j] = round(
                    chain_pair_pde_min[(chain_i, chain_j)], 2
                )

    # Extract per-atom pLDDT values (only real atoms, flattened)
    atom_plddts = plddt[batch_idx][is_real_atom[..., :NHEAVY]].cpu().tolist()

    # Extract atom/token chain and residue info from atom_array
    atom_chain_ids = atom_array.chain_id.tolist()
    token_chain_ids = list(chain_iid_token_lvl)
    token_res_ids = list(
        range(len(chain_iid_token_lvl))
    )  # Simplified; could map to actual res_id

    # PAE matrix for this batch
    pae_matrix = pae[batch_idx].cpu().tolist()

    # Build summary_confidences (AlphaFold3-style + RF3 extensions)
    summary_confidences = {
        # NOTE(review): "chain_ptm" is populated with chain-level pLDDT values,
        # not a true pTM — confirm downstream consumers expect this.
        "chain_ptm": [round(chain_plddt.get(c, 0.0), 2) for c in chains],
        "chain_pair_pae_min": chain_pair_pae_min_matrix,
        "chain_pair_pde_min": chain_pair_pde_min_matrix,
        "chain_pair_pae": chain_pair_pae_matrix,
        "chain_pair_pde": chain_pair_pde_matrix,
        "overall_plddt": round(overall_plddt, 4),
        "overall_pde": round(overall_pde, 4),
        "overall_pae": round(overall_pae, 4),
        # Note: ptm, iptm, has_clash should be populated from metrics_output
    }

    # Build full confidences (per-atom data)
    confidences = {
        "atom_chain_ids": atom_chain_ids,
        "atom_plddts": [round(p, 2) for p in atom_plddts],
        "pae": [[round(v, 2) for v in row] for row in pae_matrix],
        "token_chain_ids": token_chain_ids,
        "token_res_ids": token_res_ids,
    }

    return {
        "summary_confidences": summary_confidences,
        "confidences": confidences,
        "plddt": plddt,
        "pae": pae,
        "pde": pde,
    }
399
+
400
+
401
def compute_batch_indices_with_lowest_predicted_error(
    plddt: torch.Tensor,
    is_real_atom: torch.Tensor,
    pae: torch.Tensor,
    confidence_loss_cfg: dict | DictConfig,
    chain_iid_token_lvl: torch.Tensor,
    is_ligand: torch.Tensor,
    interfaces_to_score: list[tuple],
    pn_units_to_score: list[tuple],
    pde: torch.Tensor | None = None,
) -> dict[str, Any]:
    """Given the confidence logits, computes the index within the diffusion batch of the best predicted structure.

    Metrics include pAE, pLDDT, and pDE, among others.

    Args:
        plddt: pLDDT logits of shape [B, n_token, NHEAVY * n_bins] (reshaped and unbinned below).
        is_real_atom: Boolean mask of real atoms; only the first NHEAVY slots are used.
        pae: pAE logits with the bin dimension last.
        confidence_loss_cfg: Config with plddt/pae/pde sub-configs (max_value, n_bins).
        chain_iid_token_lvl: Per-token chain identifiers used to build pair masks.
        is_ligand: Per-token boolean tensor marking ligand tokens.
        interfaces_to_score: Pairs of PN-unit ids whose interfaces should be ranked.
        pn_units_to_score: PN-unit ids whose chains should be ranked.
        pde: Optional pDE logits. BUGFIX: previously the "pDE" ranking was silently
            computed from the *pAE* logits (there was no pde argument at all, yet the
            values were unbinned with the pde config). Pass the real pDE logits here
            for a correct pde_idx; when omitted, the old pae-derived behavior is kept
            for backward compatibility.

    Returns:
        dict[str, Any]: A dictionary containing the following keys:
            - pae_idx: The index within the diffusion batch of the structure with the best overall pAE (Predicted Aligned Error)
            - pde_idx: The index within the diffusion batch of the structure with the best overall pDE (Predicted Distance Error)
            - plddt_idx: The index within the diffusion batch of the structure with the best overall pLDDT (Predicted Local Distance
              Difference Test)
            - best_chain_to_all_idx: The index within the diffusion batch of the structure with the best pAE subsampled over any
              pair (i,j) where i == chain or j == chain
            - best_chain_to_self_idx: The index within the diffusion batch of the structure with the best pAE subsampled over any
              pair (i,j) where i == chain and j == chain
            - best_interface_idx: For each interface between two scored PN Units, the index within the diffusion batch of the
              structure with the best mean pAE for all (i,j) where i == interface_chain or j == interface_chain and i != j
            - best_lig_ipae_idx: The index within the diffusion batch for the best pAE subsampled over any pair (i,j)
              where i == chain or j == chain and i != j and i or j is a ligand
    """
    # TODO: Have this function take an `AtomArray` as input so we quickly build masks with much less code
    # TODO: Explore how we can write this function more concisely
    return_dict = {}

    # AF3's ranking metrics work like this, but using ptm instead of ipae:
    scored_chains, interfaces, interface_chains = _select_scored_units(
        interfaces_to_score, pn_units_to_score
    )

    chain_to_all_masks = _create_chain_to_all_masks(chain_iid_token_lvl, scored_chains)
    chain_to_self_masks = _create_chain_to_self_masks(
        chain_iid_token_lvl, scored_chains
    )
    interface_masks, lig_chains = _create_interface_masks(
        chain_iid_token_lvl, interfaces, is_ligand
    )

    # Masks are built on CPU; move every tensor leaf to the logits' device
    gpu = plddt.device
    chain_to_all_masks = tree.map_structure(
        lambda x: x.to(gpu) if hasattr(x, "cpu") else x, chain_to_all_masks
    )
    chain_to_self_masks = tree.map_structure(
        lambda x: x.to(gpu) if hasattr(x, "cpu") else x, chain_to_self_masks
    )
    interface_masks = tree.map_structure(
        lambda x: x.to(gpu) if hasattr(x, "cpu") else x, interface_masks
    )

    # Reshape logits to B, K, L, NHEAVY
    plddt = (
        plddt.reshape(
            -1,
            plddt.shape[1],
            NHEAVY,
            confidence_loss_cfg.plddt.n_bins,
        )
        .permute(0, 3, 1, 2)
        .float()
    )
    # Reshape the pae and pde logits to B, K, L, L
    pae_logits = pae.permute(0, 3, 1, 2).float()
    # BUGFIX: the pDE logits used to be derived from `pae` unconditionally; use the
    # dedicated `pde` tensor when the caller provides one.
    pde_logits = (pde if pde is not None else pae).permute(0, 3, 1, 2).float()

    pae_logits_unbinned = unbin_logits(
        pae_logits, confidence_loss_cfg.pae.max_value, confidence_loss_cfg.pae.n_bins
    )
    plddt_logits_unbinned = unbin_logits(
        plddt, confidence_loss_cfg.plddt.max_value, confidence_loss_cfg.plddt.n_bins
    )
    pde_logits_unbinned = unbin_logits(
        pde_logits, confidence_loss_cfg.pde.max_value, confidence_loss_cfg.pde.n_bins
    )

    # Whole-complex aggregates per batch element
    complex_pae = pae_logits_unbinned.mean(dim=(1, 2))
    complex_pde = pde_logits_unbinned.mean(dim=(1, 2))
    complex_plddt = (plddt_logits_unbinned * is_real_atom[..., :NHEAVY]).sum(
        dim=(1, 2)
    ) / is_real_atom[..., :NHEAVY].sum()

    # Lower is better for pAE/pDE; higher is better for pLDDT
    return_dict["pae_idx"] = torch.argmin(complex_pae)
    return_dict["pde_idx"] = torch.argmin(complex_pde)
    return_dict["plddt_idx"] = torch.argmax(complex_plddt)

    chain_to_self_paes = _get_masked_error_per_chain(
        scored_chains, chain_to_self_masks, pae_logits_unbinned
    )
    chain_to_all_paes = _get_masked_error_per_chain(
        scored_chains, chain_to_all_masks, pae_logits_unbinned
    )
    interface_chain_paes = _get_masked_error_per_chain(
        interface_chains, interface_masks, pae_logits_unbinned
    )
    # average over both interfaces
    average_interface_paes = _get_average_error_per_interface(
        interfaces, lig_chains, interface_chain_paes
    )

    return_dict["best_chain_to_all_idx"] = _get_lowest_error_indices(chain_to_all_paes)
    return_dict["best_chain_to_self_idx"] = _get_lowest_error_indices(
        chain_to_self_paes
    )
    return_dict["best_interface_idx"] = _get_lowest_error_indices(
        average_interface_paes
    )
    # for ligands, we don't average the error
    return_dict["best_lig_ipae_idx"] = _get_lowest_error_ligand_indices(
        interface_chain_paes, interfaces, lig_chains
    )
    return return_dict
520
+
521
+
522
def annotate_atom_array_b_factor_with_plddt(
    atom_array: AtomArray | AtomArrayStack,
    plddt: torch.Tensor,
    is_real_atom: torch.Tensor,
) -> List[AtomArray]:
    """Write per-atom pLDDT values into the b_factor annotation of a structure.

    Args:
        atom_array: The AtomArray or AtomArrayStack to annotate
        plddt: The pLDDT tensor of shape (B, I, NHEAVY)
        is_real_atom: A mask indicating which atoms are in the structure of shape (I, NHEAVY)

    Returns:
        list[AtomArray]: The annotated list of AtomArrays. A list is returned because
        biotite's AtomArrayStack cannot hold per-model annotations (only coordinates
        vary across models), so each model gets its own AtomArray.
    """
    # Select only real atoms, keeping the diffusion-batch dimension: (B, n_real_atoms)
    per_atom_plddt = plddt[:, is_real_atom[..., :NHEAVY]]
    assert per_atom_plddt.shape[1] == atom_array.array_length()

    annotated: list[AtomArray] = []
    if isinstance(atom_array, AtomArrayStack):
        # One AtomArray per model, each with its own b_factor column
        for model_idx, model in enumerate(atom_array):
            model.set_annotation("b_factor", per_atom_plddt[model_idx].cpu().numpy())
            annotated.append(model)
    else:
        # A single AtomArray implies a single batch element
        assert per_atom_plddt.shape[0] == 1
        atom_array.set_annotation("b_factor", per_atom_plddt[0].cpu().numpy())
        annotated.append(atom_array)

    # Sanity check: no NaNs may leak into the written b_factors
    for model in annotated:
        assert np.isnan(model.b_factor).sum() == 0

    return annotated
557
+
558
+
559
+ def _select_scored_units(
560
+ interfaces_to_score: list[tuple], pn_units_to_score: list[tuple]
561
+ ):
562
+ scored_chains = []
563
+ interfaces = []
564
+ interface_chains = []
565
+ for k in interfaces_to_score:
566
+ interfaces.append(f"{k[0]}-{k[1]}")
567
+ interface_chains.append(k[0])
568
+ interface_chains.append(k[1])
569
+ for k in pn_units_to_score:
570
+ scored_chains.append(k[0])
571
+
572
+ return scored_chains, interfaces, interface_chains
573
+
574
+
575
+ def _create_chain_to_all_masks(ch_label, chains_to_score):
576
+ unique_chains = np.unique(ch_label)
577
+ I = len(ch_label)
578
+ chain_to_all_masks = {}
579
+ for chain in unique_chains:
580
+ if chain in chains_to_score:
581
+ indices = torch.from_numpy((ch_label == chain))
582
+ mask = indices.unsqueeze(0) | indices.unsqueeze(1)
583
+ # set the diagonal to false
584
+ mask = mask & ~torch.eye(I, device=mask.device, dtype=torch.bool)
585
+ chain_to_all_masks[chain] = mask
586
+ return chain_to_all_masks
587
+
588
+
589
+ def _create_chain_to_self_masks(ch_label, chains_to_score):
590
+ unique_chains = np.unique(ch_label)
591
+ I = len(ch_label)
592
+ chain_to_self_masks = {}
593
+ for chain in unique_chains:
594
+ if chain in chains_to_score:
595
+ indices = torch.from_numpy((ch_label == chain))
596
+ mask = indices.unsqueeze(0) & indices.unsqueeze(1)
597
+ # set the diagonal to false
598
+ mask = mask & ~torch.eye(I, device=mask.device, dtype=torch.bool)
599
+ chain_to_self_masks[chain] = mask
600
+ return chain_to_self_masks
601
+
602
+
603
+ def _create_interface_masks(ch_label, interfaces, is_ligand):
604
+ interface_masks = {}
605
+ interface_chains = []
606
+ ligand_chains = []
607
+ for interface in interfaces:
608
+ interface_chains.append(interface.split("-")[0])
609
+ interface_chains.append(interface.split("-")[1])
610
+ interface_chains = set(interface_chains)
611
+ for chain in interface_chains:
612
+ chain_indices = torch.from_numpy((ch_label == chain))
613
+
614
+ to_self = chain_indices.unsqueeze(0) & chain_indices.unsqueeze(1)
615
+ to_all = chain_indices.unsqueeze(0) | chain_indices.unsqueeze(1)
616
+ interface_mask = to_all & ~to_self
617
+ interface_masks[chain] = interface_mask
618
+
619
+ if torch.all(is_ligand[chain_indices]):
620
+ ligand_chains.append(chain)
621
+
622
+ return interface_masks, ligand_chains
623
+
624
+
625
def _get_masked_error_per_chain(chains, masks, unbinned_logits):
    """Mean predicted error per chain, restricted to each chain's pair mask."""
    return {
        chain: compute_mean_over_subsampled_pairs(unbinned_logits, masks[chain])
        for chain in chains
    }
633
+
634
+
635
+ def _get_average_error_per_interface(interfaces, lig_chains, interface_errors):
636
+ average_error = {}
637
+ for interface in interfaces:
638
+ chain_a = interface.split("-")[0]
639
+ chain_b = interface.split("-")[1]
640
+ average_error[interface] = (
641
+ interface_errors[chain_a] + interface_errors[chain_b]
642
+ ) / 2
643
+
644
+ return average_error
645
+
646
+
647
+ def _get_lowest_error_indices(errors):
648
+ lowest_error_indices = {}
649
+ for k, v in errors.items():
650
+ lowest_error_indices[k] = torch.argmin(v)
651
+
652
+ return lowest_error_indices
653
+
654
+
655
+ def _get_lowest_error_ligand_indices(errors, interfaces, lig_chains):
656
+ # ligands are a special case in AF3, where they only consider the ligand chain's error and not the average for the interface
657
+ lowest_error_indices = {}
658
+ for interface in interfaces:
659
+ chain_a = interface.split("-")[0]
660
+ chain_b = interface.split("-")[1]
661
+ if chain_a in lig_chains or chain_b in lig_chains:
662
+ if chain_a in lig_chains:
663
+ lig_chain = chain_a
664
+ elif chain_b in lig_chains:
665
+ lig_chain = chain_b
666
+
667
+ lowest_error_indices[interface] = torch.argmin(errors[lig_chain])
668
+ else:
669
+ # assign a random value to avoid key errors downstream; sorting ligand interfaces
670
+ # from other types is handles in analysis
671
+ lowest_error_indices[interface] = 0
672
+
673
+ return lowest_error_indices