rc_foundry-0.1.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (180)
  1. foundry/__init__.py +57 -0
  2. foundry/callbacks/__init__.py +5 -0
  3. foundry/callbacks/callback.py +116 -0
  4. foundry/callbacks/health_logging.py +419 -0
  5. foundry/callbacks/metrics_logging.py +211 -0
  6. foundry/callbacks/timing_logging.py +67 -0
  7. foundry/callbacks/train_logging.py +278 -0
  8. foundry/common.py +108 -0
  9. foundry/constants.py +28 -0
  10. foundry/hydra/resolvers.py +77 -0
  11. foundry/inference_engines/base.py +235 -0
  12. foundry/inference_engines/checkpoint_registry.py +66 -0
  13. foundry/metrics/__init__.py +12 -0
  14. foundry/metrics/losses.py +30 -0
  15. foundry/metrics/metric.py +319 -0
  16. foundry/model/layers/blocks.py +47 -0
  17. foundry/testing/__init__.py +6 -0
  18. foundry/testing/fixtures.py +19 -0
  19. foundry/testing/pytest_hooks.py +15 -0
  20. foundry/trainers/fabric.py +923 -0
  21. foundry/training/EMA.py +67 -0
  22. foundry/training/checkpoint.py +61 -0
  23. foundry/training/schedulers.py +91 -0
  24. foundry/utils/alignment.py +86 -0
  25. foundry/utils/components.py +415 -0
  26. foundry/utils/datasets.py +405 -0
  27. foundry/utils/ddp.py +103 -0
  28. foundry/utils/instantiators.py +72 -0
  29. foundry/utils/logging.py +279 -0
  30. foundry/utils/rigid.py +1460 -0
  31. foundry/utils/rotation_augmentation.py +65 -0
  32. foundry/utils/squashfs.py +172 -0
  33. foundry/utils/torch.py +317 -0
  34. foundry/utils/weights.py +271 -0
  35. foundry/version.py +34 -0
  36. foundry_cli/__init__.py +3 -0
  37. foundry_cli/download_checkpoints.py +281 -0
  38. mpnn/__init__.py +1 -0
  39. mpnn/collate/feature_collator.py +265 -0
  40. mpnn/inference.py +53 -0
  41. mpnn/inference_engines/mpnn.py +549 -0
  42. mpnn/loss/nll_loss.py +122 -0
  43. mpnn/metrics/nll.py +369 -0
  44. mpnn/metrics/sequence_recovery.py +440 -0
  45. mpnn/model/layers/graph_embeddings.py +2372 -0
  46. mpnn/model/layers/message_passing.py +332 -0
  47. mpnn/model/layers/position_wise_feed_forward.py +44 -0
  48. mpnn/model/layers/positional_encoding.py +98 -0
  49. mpnn/model/mpnn.py +2632 -0
  50. mpnn/pipelines/mpnn.py +162 -0
  51. mpnn/samplers/samplers.py +167 -0
  52. mpnn/train.py +341 -0
  53. mpnn/trainers/mpnn.py +193 -0
  54. mpnn/transforms/feature_aggregation/mpnn.py +184 -0
  55. mpnn/transforms/feature_aggregation/polymer_ligand_interface.py +76 -0
  56. mpnn/transforms/feature_aggregation/token_encodings.py +132 -0
  57. mpnn/transforms/feature_aggregation/user_settings.py +347 -0
  58. mpnn/transforms/polymer_ligand_interface.py +164 -0
  59. mpnn/utils/inference.py +2397 -0
  60. mpnn/utils/probability.py +37 -0
  61. mpnn/utils/weights.py +309 -0
  62. rc_foundry-0.1.1.dist-info/METADATA +239 -0
  63. rc_foundry-0.1.1.dist-info/RECORD +180 -0
  64. rc_foundry-0.1.1.dist-info/WHEEL +4 -0
  65. rc_foundry-0.1.1.dist-info/entry_points.txt +5 -0
  66. rc_foundry-0.1.1.dist-info/licenses/LICENSE.md +28 -0
  67. rf3/__init__.py +3 -0
  68. rf3/_version.py +33 -0
  69. rf3/alignment.py +79 -0
  70. rf3/callbacks/dump_validation_structures.py +101 -0
  71. rf3/callbacks/metrics_logging.py +324 -0
  72. rf3/chemical.py +1529 -0
  73. rf3/cli.py +77 -0
  74. rf3/data/cyclic_transform.py +78 -0
  75. rf3/data/extra_xforms.py +36 -0
  76. rf3/data/ground_truth_template.py +463 -0
  77. rf3/data/paired_msa.py +206 -0
  78. rf3/data/pipeline_utils.py +128 -0
  79. rf3/data/pipelines.py +558 -0
  80. rf3/diffusion_samplers/inference_sampler.py +222 -0
  81. rf3/inference.py +65 -0
  82. rf3/inference_engines/__init__.py +5 -0
  83. rf3/inference_engines/rf3.py +735 -0
  84. rf3/kinematics.py +354 -0
  85. rf3/loss/af3_confidence_loss.py +515 -0
  86. rf3/loss/af3_losses.py +655 -0
  87. rf3/loss/loss.py +179 -0
  88. rf3/metrics/chiral.py +179 -0
  89. rf3/metrics/clashing_chains.py +68 -0
  90. rf3/metrics/distogram.py +421 -0
  91. rf3/metrics/lddt.py +523 -0
  92. rf3/metrics/metadata.py +43 -0
  93. rf3/metrics/metric_utils.py +192 -0
  94. rf3/metrics/predicted_error.py +134 -0
  95. rf3/metrics/rasa.py +108 -0
  96. rf3/metrics/selected_distances.py +91 -0
  97. rf3/model/RF3.py +527 -0
  98. rf3/model/RF3_blocks.py +92 -0
  99. rf3/model/RF3_structure.py +303 -0
  100. rf3/model/layers/af3_auxiliary_heads.py +255 -0
  101. rf3/model/layers/af3_diffusion_transformer.py +544 -0
  102. rf3/model/layers/attention.py +313 -0
  103. rf3/model/layers/layer_utils.py +127 -0
  104. rf3/model/layers/mlff.py +118 -0
  105. rf3/model/layers/outer_product.py +59 -0
  106. rf3/model/layers/pairformer_layers.py +783 -0
  107. rf3/model/layers/structure_bias.py +56 -0
  108. rf3/scoring.py +1787 -0
  109. rf3/symmetry/resolve.py +284 -0
  110. rf3/train.py +194 -0
  111. rf3/trainers/rf3.py +570 -0
  112. rf3/util_module.py +47 -0
  113. rf3/utils/frames.py +109 -0
  114. rf3/utils/inference.py +665 -0
  115. rf3/utils/io.py +198 -0
  116. rf3/utils/loss.py +72 -0
  117. rf3/utils/predict_and_score.py +165 -0
  118. rf3/utils/predicted_error.py +673 -0
  119. rf3/utils/recycling.py +42 -0
  120. rf3/validate.py +140 -0
  121. rfd3/.gitignore +7 -0
  122. rfd3/Makefile +76 -0
  123. rfd3/__init__.py +12 -0
  124. rfd3/callbacks.py +66 -0
  125. rfd3/cli.py +41 -0
  126. rfd3/constants.py +212 -0
  127. rfd3/engine.py +543 -0
  128. rfd3/inference/datasets.py +193 -0
  129. rfd3/inference/input_parsing.py +1123 -0
  130. rfd3/inference/legacy_input_parsing.py +717 -0
  131. rfd3/inference/parsing.py +165 -0
  132. rfd3/inference/symmetry/atom_array.py +298 -0
  133. rfd3/inference/symmetry/checks.py +241 -0
  134. rfd3/inference/symmetry/contigs.py +63 -0
  135. rfd3/inference/symmetry/frames.py +355 -0
  136. rfd3/inference/symmetry/symmetry_utils.py +398 -0
  137. rfd3/metrics/design_metrics.py +465 -0
  138. rfd3/metrics/hbonds_hbplus_metrics.py +308 -0
  139. rfd3/metrics/hbonds_metrics.py +389 -0
  140. rfd3/metrics/losses.py +325 -0
  141. rfd3/metrics/metrics_utils.py +118 -0
  142. rfd3/metrics/sidechain_metrics.py +349 -0
  143. rfd3/model/RFD3.py +105 -0
  144. rfd3/model/RFD3_diffusion_module.py +387 -0
  145. rfd3/model/cfg_utils.py +81 -0
  146. rfd3/model/inference_sampler.py +635 -0
  147. rfd3/model/layers/attention.py +577 -0
  148. rfd3/model/layers/block_utils.py +580 -0
  149. rfd3/model/layers/blocks.py +777 -0
  150. rfd3/model/layers/chunked_pairwise.py +377 -0
  151. rfd3/model/layers/encoders.py +417 -0
  152. rfd3/model/layers/layer_utils.py +197 -0
  153. rfd3/model/layers/pairformer_layers.py +128 -0
  154. rfd3/run_inference.py +45 -0
  155. rfd3/testing/debug.py +139 -0
  156. rfd3/testing/debug_utils.py +73 -0
  157. rfd3/testing/testing_utils.py +356 -0
  158. rfd3/train.py +194 -0
  159. rfd3/trainer/dump_validation_structures.py +154 -0
  160. rfd3/trainer/fabric_trainer.py +923 -0
  161. rfd3/trainer/recycling.py +42 -0
  162. rfd3/trainer/rfd3.py +485 -0
  163. rfd3/trainer/trainer_utils.py +502 -0
  164. rfd3/transforms/conditioning_base.py +508 -0
  165. rfd3/transforms/conditioning_utils.py +200 -0
  166. rfd3/transforms/design_transforms.py +807 -0
  167. rfd3/transforms/dna_crop.py +523 -0
  168. rfd3/transforms/hbonds.py +407 -0
  169. rfd3/transforms/hbonds_hbplus.py +246 -0
  170. rfd3/transforms/ncaa_transforms.py +153 -0
  171. rfd3/transforms/pipelines.py +632 -0
  172. rfd3/transforms/ppi_transforms.py +541 -0
  173. rfd3/transforms/rasa.py +116 -0
  174. rfd3/transforms/symmetry.py +76 -0
  175. rfd3/transforms/training_conditions.py +552 -0
  176. rfd3/transforms/util_transforms.py +498 -0
  177. rfd3/transforms/virtual_atoms.py +305 -0
  178. rfd3/utils/inference.py +648 -0
  179. rfd3/utils/io.py +245 -0
  180. rfd3/utils/vizualize.py +276 -0
rf3/metrics/lddt.py ADDED
@@ -0,0 +1,523 @@
import numpy as np
import torch
from atomworks.io.transforms.atom_array import ensure_atom_array_stack
from atomworks.ml.transforms.atom_array import AddGlobalTokenIdAnnotation
from atomworks.ml.transforms.atomize import AtomizeByCCDName
from atomworks.ml.transforms.base import Compose
from atomworks.ml.utils.token import get_token_starts
from beartype.typing import Any
from biotite.structure import AtomArray, AtomArrayStack, stack
from jaxtyping import Bool, Float, Int

from foundry.metrics.metric import Metric
from foundry.utils.ddp import RankedLogger

ranked_logger = RankedLogger(__name__, rank_zero_only=True)


def calc_lddt(
    X_L: Float[torch.Tensor, "D L 3"],
    X_gt_L: Float[torch.Tensor, "D L 3"],
    crd_mask_L: Bool[torch.Tensor, "D L"],
    tok_idx: Int[torch.Tensor, "L"],
    pairs_to_score: Bool[torch.Tensor, "L L"] | None = None,
    distance_cutoff: float = 15.0,
    eps: float = 1e-6,
) -> Float[torch.Tensor, "D"]:
    """Calculates LDDT scores for each model in the batch.

    Args:
        X_L: Predicted coordinates (D, L, 3).
        X_gt_L: Ground truth coordinates (D, L, 3).
        crd_mask_L: Coordinate mask indicating valid atoms (D, L).
        tok_idx: Token index of each atom (L,). Used to exclude same-token pairs.
        pairs_to_score: Boolean mask for pairs to score (L, L). If None, scores all valid pairs.
        distance_cutoff: Distance cutoff for scoring pairs.
        eps: Small epsilon to prevent division by zero.

    Returns:
        LDDT scores for each model (D,).
    """
    D, L = X_L.shape[:2]

    # Create pairs to score mask - if not provided, use upper triangular (includes diagonal)
    if pairs_to_score is None:
        pairs_to_score = torch.ones((L, L), dtype=torch.bool).triu(0).to(X_L.device)
    else:
        assert pairs_to_score.shape == (L, L)
        pairs_to_score = pairs_to_score.triu(0).to(X_L.device)

    # Get indices of atom pairs to evaluate
    first_index: Int[torch.Tensor, "n_pairs"]
    second_index: Int[torch.Tensor, "n_pairs"]
    first_index, second_index = torch.nonzero(pairs_to_score, as_tuple=True)

    # Compute LDDT score for each model in the batch
    lddt_scores = []
    for d in range(D):
        # Calculate pairwise distances in ground truth structure
        ground_truth_distances = torch.linalg.norm(
            X_gt_L[d, first_index] - X_gt_L[d, second_index], dim=-1
        )

        # Create mask for valid pairs to score:
        # 1. Ground truth distance > 0 (atoms not at same position)
        # 2. Ground truth distance < cutoff (within interaction range)
        pair_mask = torch.logical_and(
            ground_truth_distances > 0, ground_truth_distances < distance_cutoff
        )

        # Only score pairs that are resolved in the ground truth
        pair_mask *= crd_mask_L[d, first_index] * crd_mask_L[d, second_index]

        # Don't score pairs that are in the same token (e.g., same residue)
        pair_mask *= tok_idx[first_index] != tok_idx[second_index]

        # Filter to only "valid" pairs
        valid_pairs = pair_mask.nonzero(as_tuple=True)

        pair_mask_valid = pair_mask[valid_pairs].to(X_L.dtype)
        ground_truth_distances_valid = ground_truth_distances[valid_pairs]

        first_index_valid: Int[torch.Tensor, "n_valid_pairs"] = first_index[valid_pairs]
        second_index_valid: Int[torch.Tensor, "n_valid_pairs"] = second_index[
            valid_pairs
        ]

        # Calculate pairwise distances in predicted structure
        predicted_distances = torch.linalg.norm(
            X_L[d, first_index_valid] - X_L[d, second_index_valid], dim=-1
        )

        # Compute absolute distance differences (with small eps to avoid numerical issues)
        delta_distances = torch.abs(
            predicted_distances - ground_truth_distances_valid + eps
        )
        del predicted_distances, ground_truth_distances_valid

        # Calculate LDDT score using standard thresholds (0.5Å, 1.0Å, 2.0Å, 4.0Å)
        # LDDT is the average fraction of distances preserved within each threshold
        lddt_score = (
            0.25
            * (
                torch.sum((delta_distances < 0.5) * pair_mask_valid)  # 0.5Å threshold
                + torch.sum((delta_distances < 1.0) * pair_mask_valid)  # 1.0Å threshold
                + torch.sum((delta_distances < 2.0) * pair_mask_valid)  # 2.0Å threshold
                + torch.sum((delta_distances < 4.0) * pair_mask_valid)  # 4.0Å threshold
            )
            / (torch.sum(pair_mask_valid) + eps)  # Normalize by number of valid pairs
        )

        lddt_scores.append(lddt_score)

    return torch.tensor(lddt_scores, device=X_L.device)

def extract_lddt_features_from_atom_arrays(
    predicted_atom_array_stack: AtomArrayStack | AtomArray,
    ground_truth_atom_array_stack: AtomArrayStack | AtomArray,
) -> dict[str, Any]:
    """Extract all features needed for LDDT computation from AtomArrays.

    Args:
        predicted_atom_array_stack: Predicted coordinates as AtomArray(Stack)
        ground_truth_atom_array_stack: Ground truth coordinates as AtomArray(Stack)

    Returns:
        Dictionary containing:
            - X_L: Predicted coordinates tensor (D, L, 3)
            - X_gt_L: Ground truth coordinates tensor (D, L, 3)
            - crd_mask_L: Coordinate validity mask (D, L)
            - tok_idx: Token indices for each atom (L,)
            - chain_iid_token_lvl: Chain identification at token level
    """
    predicted_atom_array_stack = ensure_atom_array_stack(predicted_atom_array_stack)
    ground_truth_atom_array_stack = ensure_atom_array_stack(
        ground_truth_atom_array_stack
    )

    if (
        ground_truth_atom_array_stack.stack_depth() == 1
        and predicted_atom_array_stack.stack_depth() > 1
    ):
        # If the ground truth is a single model and the prediction is a stack, expand the
        # ground truth to the same stack depth as the prediction
        ground_truth_atom_array_stack = stack(
            [ground_truth_atom_array_stack[0]]
            * predicted_atom_array_stack.stack_depth()
        )

    # Convert AtomArray coordinates to tensors
    X_L: Float[torch.Tensor, "D L 3"] = torch.from_numpy(
        predicted_atom_array_stack.coord
    ).float()
    X_gt_L: Float[torch.Tensor, "D L 3"] = torch.from_numpy(
        ground_truth_atom_array_stack.coord
    ).float()

    # For the remaining feature generation, we can directly use the first model in the
    # stack (only the coordinates differ between models)
    ground_truth_atom_array = ground_truth_atom_array_stack[0]

    # Create coordinate mask using occupancy if available, falling back to coordinate validity
    if "occupancy" in ground_truth_atom_array.get_annotation_categories():
        # Use the occupancy annotation if present (occupancy > 0 means the atom is present),
        # broadcast to all models in the stack
        occupancy_mask = ground_truth_atom_array.occupancy > 0
        crd_mask_L: Bool[torch.Tensor, "D L"] = (
            torch.from_numpy(occupancy_mask)
            .bool()
            .unsqueeze(0)
            .expand(X_gt_L.shape[0], -1)
        )
    else:
        # Fall back to coordinate validity (not NaN)
        crd_mask_L: Bool[torch.Tensor, "D L"] = ~torch.isnan(X_gt_L).any(dim=-1)

    # Get token indices using the same logic as ComputeAtomToTokenMap
    if "token_id" in ground_truth_atom_array.get_annotation_categories():
        # Use the existing token_id annotation (matches ComputeAtomToTokenMap exactly)
        tok_idx = ground_truth_atom_array.token_id.astype(np.int32)
    else:
        # Generate annotations with a Transform pipeline
        pipe = Compose(
            [AtomizeByCCDName(atomize_by_default=True), AddGlobalTokenIdAnnotation()]
        )
        data = pipe({"atom_array": ground_truth_atom_array})
        tok_idx = data["atom_array"].token_id.astype(np.int32)

    # Compute chain identification at the token level
    token_starts = get_token_starts(ground_truth_atom_array)

    if "chain_iid" in ground_truth_atom_array.get_annotation_categories():
        chain_iid_token_lvl = ground_truth_atom_array.chain_iid[token_starts]
    else:
        # Use the chain_id annotation instead (e.g., for AF-3 outputs, where the chain_id is ostensibly the chain_iid)
        chain_iid_token_lvl = ground_truth_atom_array.chain_id[token_starts]

    return {
        "X_L": X_L,
        "X_gt_L": X_gt_L,
        "crd_mask_L": crd_mask_L,
        "tok_idx": tok_idx,
        "chain_iid_token_lvl": chain_iid_token_lvl,
    }

class AllAtomLDDT(Metric):
    """Computes all-atom LDDT scores from AtomArrays."""

    def __init__(self, log_lddt_for_every_batch: bool = False, **kwargs):
        super().__init__(**kwargs)
        self.log_lddt_for_every_batch = log_lddt_for_every_batch

    @property
    def kwargs_to_compute_args(self) -> dict[str, Any]:
        return {
            "predicted_atom_array_stack": "predicted_atom_array_stack",
            "ground_truth_atom_array_stack": "ground_truth_atom_array_stack",
        }

    def compute(
        self,
        predicted_atom_array_stack: AtomArrayStack | AtomArray,
        ground_truth_atom_array_stack: AtomArrayStack | AtomArray,
    ) -> dict[str, Any]:
        """Calculates all-atom LDDT between all pairs of atoms.

        Args:
            predicted_atom_array_stack: Predicted coordinates as AtomArray(Stack)
            ground_truth_atom_array_stack: Ground truth coordinates as AtomArray(Stack)

        Returns:
            A dictionary with all-atom LDDT scores:
                - lddt_scores: Raw LDDT scores for each model (torch.Tensor)
                - best_of_1_lddt: LDDT score for the first model
                - best_of_{N}_lddt: Best LDDT score across all N models
        """
        lddt_features = extract_lddt_features_from_atom_arrays(
            predicted_atom_array_stack, ground_truth_atom_array_stack
        )
        tok_idx = torch.tensor(lddt_features["tok_idx"]).to(lddt_features["X_L"].device)

        all_atom_lddt = calc_lddt(
            X_L=lddt_features["X_L"],
            X_gt_L=lddt_features["X_gt_L"],
            crd_mask_L=lddt_features["crd_mask_L"],
            tok_idx=tok_idx,
            pairs_to_score=None,  # By default, score all pairs, except those within the same token
            distance_cutoff=15.0,
        )

        result = {
            "best_of_1_lddt": all_atom_lddt[0].item(),
            f"best_of_{len(all_atom_lddt)}_lddt": all_atom_lddt.max().item(),
        }

        if self.log_lddt_for_every_batch:
            lddt_by_batch = {
                f"all_atom_lddt_{i}": all_atom_lddt[i].item()
                for i in range(len(all_atom_lddt))
            }
            result.update(lddt_by_batch)

        return result


class InterfaceLDDTByType(Metric):
    """Computes interface LDDT, grouped by interface type."""

    def __init__(self, log_lddt_for_every_batch: bool = False, **kwargs):
        super().__init__(**kwargs)
        self.log_lddt_for_every_batch = log_lddt_for_every_batch

    @property
    def kwargs_to_compute_args(self) -> dict[str, Any]:
        return {
            "predicted_atom_array_stack": "predicted_atom_array_stack",
            "ground_truth_atom_array_stack": "ground_truth_atom_array_stack",
            "interfaces_to_score": ("extra_info", "interfaces_to_score"),
        }

    def compute(
        self,
        predicted_atom_array_stack: AtomArrayStack | AtomArray,
        ground_truth_atom_array_stack: AtomArrayStack | AtomArray,
        interfaces_to_score: list = None,
        **kwargs,
    ) -> list[dict[str, Any]]:
        """Calculates interface LDDT between specific pairs of chains/units, grouped by interface type.

        Args:
            predicted_atom_array_stack: Predicted coordinates as AtomArray(Stack)
            ground_truth_atom_array_stack: Ground truth coordinates as AtomArray(Stack)
            interfaces_to_score: List of interface specifications, each as
                (pn_unit_i, pn_unit_j, interface_type)

        Returns:
            List of dictionaries containing interface LDDT results for each interface.
        """
        lddt_features = extract_lddt_features_from_atom_arrays(
            predicted_atom_array_stack, ground_truth_atom_array_stack
        )

        # Short-circuit if no interfaces to score
        if not interfaces_to_score:
            return []

        interface_results = []

        # Parse string inputs (for backwards compatibility)
        if isinstance(interfaces_to_score, str):
            interfaces_to_score = (
                eval(interfaces_to_score) if interfaces_to_score else []
            )

        # Loop over the interfaces to score
        for pn_unit_i, pn_unit_j, interface_type in interfaces_to_score:
            # Get tokens in pn_unit_i and pn_unit_j
            pn_unit_i_tokens = lddt_features["chain_iid_token_lvl"] == pn_unit_i
            pn_unit_j_tokens = lddt_features["chain_iid_token_lvl"] == pn_unit_j

            if pn_unit_i_tokens.sum() == 0 or pn_unit_j_tokens.sum() == 0:
                ranked_logger.warning(
                    f"No atoms found for {pn_unit_i} or {pn_unit_j}! Available chains: {np.unique(lddt_features['chain_iid_token_lvl']).tolist()}"
                )
                continue

            # Convert the token level to the atom level
            pn_unit_i_atoms = pn_unit_i_tokens[lddt_features["tok_idx"]]
            pn_unit_j_atoms = pn_unit_j_tokens[lddt_features["tok_idx"]]

            # Compute the outer product of chain_i and chain_j, which represents the interface
            chain_ij_atoms = torch.einsum(
                "L, K -> LK",
                torch.tensor(pn_unit_i_atoms),
                torch.tensor(pn_unit_j_atoms),
            ).to(lddt_features["X_L"].device)

            # Symmetrize the interface so we can later multiply with an upper triangular without losing information
            chain_ij_atoms = chain_ij_atoms | chain_ij_atoms.T

            # Compute LDDT using the pairs_to_score from the intersection
            lddt = calc_lddt(
                lddt_features["X_L"],
                lddt_features["X_gt_L"],
                lddt_features["crd_mask_L"],
                torch.tensor(lddt_features["tok_idx"]).to(lddt_features["X_L"].device),
                pairs_to_score=chain_ij_atoms,
                distance_cutoff=30.0,
            )

            # Add the results to the interface_results list
            n = len(lddt)
            result = {
                "pn_units": [pn_unit_i, pn_unit_j],
                "type": interface_type,
                "best_of_1_lddt": lddt[0].item(),
                f"best_of_{n}_lddt": lddt.max().item(),
            }

            if self.log_lddt_for_every_batch:
                lddt_by_batch = {f"lddt_{i}": lddt[i].item() for i in range(len(lddt))}
                result.update(lddt_by_batch)

            interface_results.append(result)

        return interface_results

class ChainLDDTByType(Metric):
    """Computes chain-wise LDDT scores from AtomArrays, grouped by chain type."""

    def __init__(self, log_lddt_for_every_batch: bool = False, **kwargs):
        super().__init__(**kwargs)
        self.log_lddt_for_every_batch = log_lddt_for_every_batch

    @property
    def kwargs_to_compute_args(self) -> dict[str, Any]:
        return {
            "predicted_atom_array_stack": "predicted_atom_array_stack",
            "ground_truth_atom_array_stack": "ground_truth_atom_array_stack",
            "pn_units_to_score": ("extra_info", "pn_units_to_score"),
        }

    def compute(
        self,
        predicted_atom_array_stack: AtomArrayStack | AtomArray,
        ground_truth_atom_array_stack: AtomArrayStack | AtomArray,
        pn_units_to_score: list = None,
        **kwargs,
    ) -> list[dict[str, Any]]:
        """Calculates intra-chain LDDT for specific chains/units.

        Args:
            predicted_atom_array_stack: Predicted coordinates as AtomArray(Stack)
            ground_truth_atom_array_stack: Ground truth coordinates as AtomArray(Stack)
            pn_units_to_score: List of chain specifications, each as (pn_unit_iid, chain_type)

        Returns:
            List of dictionaries containing chain LDDT results for each chain.
        """
        lddt_features = extract_lddt_features_from_atom_arrays(
            predicted_atom_array_stack, ground_truth_atom_array_stack
        )

        # Short-circuit if no chains to score
        if not pn_units_to_score:
            return []

        chain_results = []

        # Parse string inputs (for backwards compatibility)
        if isinstance(pn_units_to_score, str):
            pn_units_to_score = eval(pn_units_to_score) if pn_units_to_score else []

        # For all chains (pn_units) to score...
        for chain, chain_type in pn_units_to_score:
            # ... get tokens in the chain instance
            chain_tokens = lddt_features["chain_iid_token_lvl"] == chain
            if chain_tokens.sum() == 0:
                ranked_logger.warning(
                    f"No atoms found for {chain}! Available chains: {np.unique(lddt_features['chain_iid_token_lvl']).tolist()}"
                )
                continue

            # ... convert the token level to the atom level
            chain_atoms = chain_tokens[lddt_features["tok_idx"]]

            # ... compute the outer product of the chain with itself (the definition of intra-chain LDDT)
            chain_ij_atoms = torch.einsum(
                "L, K -> LK", torch.tensor(chain_atoms), torch.tensor(chain_atoms)
            ).to(lddt_features["X_L"].device)

            # ... compute LDDT using the pairs_to_score from this mask
            lddt = calc_lddt(
                lddt_features["X_L"],
                lddt_features["X_gt_L"],
                lddt_features["crd_mask_L"],
                torch.tensor(lddt_features["tok_idx"]).to(lddt_features["X_L"].device),
                pairs_to_score=chain_ij_atoms,
            )

            # ... and finally add the results to the chain_results list
            n = len(lddt)
            result = {
                "pn_units": [chain],
                "type": chain_type,
                "best_of_1_lddt": lddt[0].item(),
                f"best_of_{n}_lddt": lddt.max().item(),
            }

            if self.log_lddt_for_every_batch:
                lddt_by_batch = {f"lddt_{i}": lddt[i].item() for i in range(len(lddt))}
                result.update(lddt_by_batch)

            chain_results.append(result)

        return chain_results


class ByTypeLDDT(Metric):
    """Calculates LDDT scores by type for both chains and interfaces."""

    def __init__(self, log_lddt_for_every_batch: bool = True, **kwargs):
        super().__init__(**kwargs)
        self.interface_lddt = InterfaceLDDTByType(
            log_lddt_for_every_batch=log_lddt_for_every_batch, **kwargs
        )
        self.chain_lddt = ChainLDDTByType(
            log_lddt_for_every_batch=log_lddt_for_every_batch, **kwargs
        )

    @property
    def kwargs_to_compute_args(self) -> dict[str, Any]:
        return {
            "predicted_atom_array_stack": "predicted_atom_array_stack",
            "ground_truth_atom_array_stack": "ground_truth_atom_array_stack",
            "interfaces_to_score": ("extra_info", "interfaces_to_score"),
            "pn_units_to_score": ("extra_info", "pn_units_to_score"),
        }

    @property
    def optional_kwargs(self) -> set[str]:
        """Mark interfaces_to_score and pn_units_to_score as optional."""
        return {"interfaces_to_score", "pn_units_to_score"}

    def compute(
        self,
        predicted_atom_array_stack: AtomArrayStack | AtomArray,
        ground_truth_atom_array_stack: AtomArrayStack | AtomArray,
        interfaces_to_score: list[tuple[str, str, str]] | None = None,
        pn_units_to_score: list[tuple[str, str]] | None = None,
    ) -> list[dict[str, Any]]:
        """Calculates LDDT scores by type for both chains and interfaces.

        Args:
            predicted_atom_array_stack: Predicted coordinates as AtomArray(Stack)
            ground_truth_atom_array_stack: Ground truth coordinates as AtomArray(Stack)
            interfaces_to_score: Tuples of (pn_unit_i, pn_unit_j, interface_type)
                representing the interfaces to score
            pn_units_to_score: Tuples of (pn_unit_iid, chain_type)
                representing the chains to score

        Returns:
            Combined list of interface and chain LDDT results.
        """
        # Compute interface LDDT scores
        interface_results = self.interface_lddt.compute(
            predicted_atom_array_stack=predicted_atom_array_stack,
            ground_truth_atom_array_stack=ground_truth_atom_array_stack,
            interfaces_to_score=interfaces_to_score,
        )

        # Compute chain LDDT scores
        chain_results = self.chain_lddt.compute(
            predicted_atom_array_stack=predicted_atom_array_stack,
            ground_truth_atom_array_stack=ground_truth_atom_array_stack,
            pn_units_to_score=pn_units_to_score,
        )

        # Merge the results
        combined_results = interface_results + chain_results

        return combined_results
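For orientation, here is a minimal usage sketch of the calc_lddt helper defined above. This is editorial and not part of the wheel contents: the import path follows the file location rf3/metrics/lddt.py and assumes the package is installed, while the toy sizes, random coordinates, and noise scale are made up for illustration.

# Editorial usage sketch -- not part of rc-foundry.
import torch

from rf3.metrics.lddt import calc_lddt

D, L = 2, 16                                   # 2 models, 16 atoms (arbitrary toy sizes)
X_pred = torch.randn(D, L, 3)                  # hypothetical predicted coordinates
X_true = X_pred + 0.05 * torch.randn(D, L, 3)  # hypothetical ground truth, slightly perturbed
crd_mask = torch.ones(D, L, dtype=torch.bool)  # pretend every atom is resolved
tok_idx = torch.arange(L)                      # one token per atom, so no same-token pairs are excluded

scores = calc_lddt(X_pred, X_true, crd_mask, tok_idx, distance_cutoff=15.0)
print(scores)  # tensor of shape (D,), one LDDT value in [0, 1] per model

Each returned entry is a per-model score between 0 and 1; higher values mean the predicted pairwise distances better preserve those of the ground truth within the 0.5/1.0/2.0/4.0 Å thresholds.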
rf3/metrics/metadata.py ADDED
@@ -0,0 +1,43 @@
import json

from beartype.typing import Any, Literal

from foundry.metrics.metric import Metric


class ExtraInfo(Metric):
    """Stores the extra_info from the dataloader output in the metrics dictionary.
    Only basic Python types that are hashable and can be JSON serialized are stored."""

    def __init__(self, keys_to_store: list[str] | Literal["all"] = "all", **kwargs):
        super().__init__(**kwargs)
        self.keys_to_store = keys_to_store

    @property
    def kwargs_to_compute_args(self) -> dict[str, Any]:
        return {"extra_info": "extra_info"}

    def _is_basic_hashable_type(self, value: Any) -> bool:
        """Check if value is a basic Python type that is both JSON serializable and hashable."""
        try:
            # First check if it's hashable
            hash(value)

            # Then check if it's JSON serializable
            json.dumps(value)
            return True
        except (TypeError, OverflowError):
            return False

    def compute(
        self,
        extra_info: dict,
    ) -> dict[str, Any]:
        result = {}
        for key, value in extra_info.items():
            # Check if we should include this key
            if self.keys_to_store == "all" or key in self.keys_to_store:
                # Check if the value is a basic hashable type
                if self._is_basic_hashable_type(value):
                    result[key] = value
        return result
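As a closing illustration (editorial, not part of the package), the keep/drop rule used by ExtraInfo above can be sketched standalone: a value is stored only if it is both hashable and JSON serializable. The dictionary below is invented sample data.

# Editorial sketch -- mirrors ExtraInfo's filtering rule without importing the package.
import json

def is_basic_hashable(value) -> bool:
    try:
        hash(value)        # must be hashable ...
        json.dumps(value)  # ... and JSON serializable
        return True
    except (TypeError, OverflowError):
        return False

extra_info = {"example_id": "abc123", "resolution": 2.1, "chain_ids": ["A", "B"]}
kept = {k: v for k, v in extra_info.items() if is_basic_hashable(v)}
print(kept)  # {'example_id': 'abc123', 'resolution': 2.1} -- the list value is dropped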