rc-foundry 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (180) hide show
  1. foundry/__init__.py +57 -0
  2. foundry/callbacks/__init__.py +5 -0
  3. foundry/callbacks/callback.py +116 -0
  4. foundry/callbacks/health_logging.py +419 -0
  5. foundry/callbacks/metrics_logging.py +211 -0
  6. foundry/callbacks/timing_logging.py +67 -0
  7. foundry/callbacks/train_logging.py +278 -0
  8. foundry/common.py +108 -0
  9. foundry/constants.py +28 -0
  10. foundry/hydra/resolvers.py +77 -0
  11. foundry/inference_engines/base.py +235 -0
  12. foundry/inference_engines/checkpoint_registry.py +66 -0
  13. foundry/metrics/__init__.py +12 -0
  14. foundry/metrics/losses.py +30 -0
  15. foundry/metrics/metric.py +319 -0
  16. foundry/model/layers/blocks.py +47 -0
  17. foundry/testing/__init__.py +6 -0
  18. foundry/testing/fixtures.py +19 -0
  19. foundry/testing/pytest_hooks.py +15 -0
  20. foundry/trainers/fabric.py +923 -0
  21. foundry/training/EMA.py +67 -0
  22. foundry/training/checkpoint.py +61 -0
  23. foundry/training/schedulers.py +91 -0
  24. foundry/utils/alignment.py +86 -0
  25. foundry/utils/components.py +415 -0
  26. foundry/utils/datasets.py +405 -0
  27. foundry/utils/ddp.py +103 -0
  28. foundry/utils/instantiators.py +72 -0
  29. foundry/utils/logging.py +279 -0
  30. foundry/utils/rigid.py +1460 -0
  31. foundry/utils/rotation_augmentation.py +65 -0
  32. foundry/utils/squashfs.py +172 -0
  33. foundry/utils/torch.py +317 -0
  34. foundry/utils/weights.py +271 -0
  35. foundry/version.py +34 -0
  36. foundry_cli/__init__.py +3 -0
  37. foundry_cli/download_checkpoints.py +281 -0
  38. mpnn/__init__.py +1 -0
  39. mpnn/collate/feature_collator.py +265 -0
  40. mpnn/inference.py +53 -0
  41. mpnn/inference_engines/mpnn.py +549 -0
  42. mpnn/loss/nll_loss.py +122 -0
  43. mpnn/metrics/nll.py +369 -0
  44. mpnn/metrics/sequence_recovery.py +440 -0
  45. mpnn/model/layers/graph_embeddings.py +2372 -0
  46. mpnn/model/layers/message_passing.py +332 -0
  47. mpnn/model/layers/position_wise_feed_forward.py +44 -0
  48. mpnn/model/layers/positional_encoding.py +98 -0
  49. mpnn/model/mpnn.py +2632 -0
  50. mpnn/pipelines/mpnn.py +162 -0
  51. mpnn/samplers/samplers.py +167 -0
  52. mpnn/train.py +341 -0
  53. mpnn/trainers/mpnn.py +193 -0
  54. mpnn/transforms/feature_aggregation/mpnn.py +184 -0
  55. mpnn/transforms/feature_aggregation/polymer_ligand_interface.py +76 -0
  56. mpnn/transforms/feature_aggregation/token_encodings.py +132 -0
  57. mpnn/transforms/feature_aggregation/user_settings.py +347 -0
  58. mpnn/transforms/polymer_ligand_interface.py +164 -0
  59. mpnn/utils/inference.py +2397 -0
  60. mpnn/utils/probability.py +37 -0
  61. mpnn/utils/weights.py +309 -0
  62. rc_foundry-0.1.1.dist-info/METADATA +239 -0
  63. rc_foundry-0.1.1.dist-info/RECORD +180 -0
  64. rc_foundry-0.1.1.dist-info/WHEEL +4 -0
  65. rc_foundry-0.1.1.dist-info/entry_points.txt +5 -0
  66. rc_foundry-0.1.1.dist-info/licenses/LICENSE.md +28 -0
  67. rf3/__init__.py +3 -0
  68. rf3/_version.py +33 -0
  69. rf3/alignment.py +79 -0
  70. rf3/callbacks/dump_validation_structures.py +101 -0
  71. rf3/callbacks/metrics_logging.py +324 -0
  72. rf3/chemical.py +1529 -0
  73. rf3/cli.py +77 -0
  74. rf3/data/cyclic_transform.py +78 -0
  75. rf3/data/extra_xforms.py +36 -0
  76. rf3/data/ground_truth_template.py +463 -0
  77. rf3/data/paired_msa.py +206 -0
  78. rf3/data/pipeline_utils.py +128 -0
  79. rf3/data/pipelines.py +558 -0
  80. rf3/diffusion_samplers/inference_sampler.py +222 -0
  81. rf3/inference.py +65 -0
  82. rf3/inference_engines/__init__.py +5 -0
  83. rf3/inference_engines/rf3.py +735 -0
  84. rf3/kinematics.py +354 -0
  85. rf3/loss/af3_confidence_loss.py +515 -0
  86. rf3/loss/af3_losses.py +655 -0
  87. rf3/loss/loss.py +179 -0
  88. rf3/metrics/chiral.py +179 -0
  89. rf3/metrics/clashing_chains.py +68 -0
  90. rf3/metrics/distogram.py +421 -0
  91. rf3/metrics/lddt.py +523 -0
  92. rf3/metrics/metadata.py +43 -0
  93. rf3/metrics/metric_utils.py +192 -0
  94. rf3/metrics/predicted_error.py +134 -0
  95. rf3/metrics/rasa.py +108 -0
  96. rf3/metrics/selected_distances.py +91 -0
  97. rf3/model/RF3.py +527 -0
  98. rf3/model/RF3_blocks.py +92 -0
  99. rf3/model/RF3_structure.py +303 -0
  100. rf3/model/layers/af3_auxiliary_heads.py +255 -0
  101. rf3/model/layers/af3_diffusion_transformer.py +544 -0
  102. rf3/model/layers/attention.py +313 -0
  103. rf3/model/layers/layer_utils.py +127 -0
  104. rf3/model/layers/mlff.py +118 -0
  105. rf3/model/layers/outer_product.py +59 -0
  106. rf3/model/layers/pairformer_layers.py +783 -0
  107. rf3/model/layers/structure_bias.py +56 -0
  108. rf3/scoring.py +1787 -0
  109. rf3/symmetry/resolve.py +284 -0
  110. rf3/train.py +194 -0
  111. rf3/trainers/rf3.py +570 -0
  112. rf3/util_module.py +47 -0
  113. rf3/utils/frames.py +109 -0
  114. rf3/utils/inference.py +665 -0
  115. rf3/utils/io.py +198 -0
  116. rf3/utils/loss.py +72 -0
  117. rf3/utils/predict_and_score.py +165 -0
  118. rf3/utils/predicted_error.py +673 -0
  119. rf3/utils/recycling.py +42 -0
  120. rf3/validate.py +140 -0
  121. rfd3/.gitignore +7 -0
  122. rfd3/Makefile +76 -0
  123. rfd3/__init__.py +12 -0
  124. rfd3/callbacks.py +66 -0
  125. rfd3/cli.py +41 -0
  126. rfd3/constants.py +212 -0
  127. rfd3/engine.py +543 -0
  128. rfd3/inference/datasets.py +193 -0
  129. rfd3/inference/input_parsing.py +1123 -0
  130. rfd3/inference/legacy_input_parsing.py +717 -0
  131. rfd3/inference/parsing.py +165 -0
  132. rfd3/inference/symmetry/atom_array.py +298 -0
  133. rfd3/inference/symmetry/checks.py +241 -0
  134. rfd3/inference/symmetry/contigs.py +63 -0
  135. rfd3/inference/symmetry/frames.py +355 -0
  136. rfd3/inference/symmetry/symmetry_utils.py +398 -0
  137. rfd3/metrics/design_metrics.py +465 -0
  138. rfd3/metrics/hbonds_hbplus_metrics.py +308 -0
  139. rfd3/metrics/hbonds_metrics.py +389 -0
  140. rfd3/metrics/losses.py +325 -0
  141. rfd3/metrics/metrics_utils.py +118 -0
  142. rfd3/metrics/sidechain_metrics.py +349 -0
  143. rfd3/model/RFD3.py +105 -0
  144. rfd3/model/RFD3_diffusion_module.py +387 -0
  145. rfd3/model/cfg_utils.py +81 -0
  146. rfd3/model/inference_sampler.py +635 -0
  147. rfd3/model/layers/attention.py +577 -0
  148. rfd3/model/layers/block_utils.py +580 -0
  149. rfd3/model/layers/blocks.py +777 -0
  150. rfd3/model/layers/chunked_pairwise.py +377 -0
  151. rfd3/model/layers/encoders.py +417 -0
  152. rfd3/model/layers/layer_utils.py +197 -0
  153. rfd3/model/layers/pairformer_layers.py +128 -0
  154. rfd3/run_inference.py +45 -0
  155. rfd3/testing/debug.py +139 -0
  156. rfd3/testing/debug_utils.py +73 -0
  157. rfd3/testing/testing_utils.py +356 -0
  158. rfd3/train.py +194 -0
  159. rfd3/trainer/dump_validation_structures.py +154 -0
  160. rfd3/trainer/fabric_trainer.py +923 -0
  161. rfd3/trainer/recycling.py +42 -0
  162. rfd3/trainer/rfd3.py +485 -0
  163. rfd3/trainer/trainer_utils.py +502 -0
  164. rfd3/transforms/conditioning_base.py +508 -0
  165. rfd3/transforms/conditioning_utils.py +200 -0
  166. rfd3/transforms/design_transforms.py +807 -0
  167. rfd3/transforms/dna_crop.py +523 -0
  168. rfd3/transforms/hbonds.py +407 -0
  169. rfd3/transforms/hbonds_hbplus.py +246 -0
  170. rfd3/transforms/ncaa_transforms.py +153 -0
  171. rfd3/transforms/pipelines.py +632 -0
  172. rfd3/transforms/ppi_transforms.py +541 -0
  173. rfd3/transforms/rasa.py +116 -0
  174. rfd3/transforms/symmetry.py +76 -0
  175. rfd3/transforms/training_conditions.py +552 -0
  176. rfd3/transforms/util_transforms.py +498 -0
  177. rfd3/transforms/virtual_atoms.py +305 -0
  178. rfd3/utils/inference.py +648 -0
  179. rfd3/utils/io.py +245 -0
  180. rfd3/utils/vizualize.py +276 -0
@@ -0,0 +1,465 @@
1
+ import numpy as np
2
+ import torch
3
+ from atomworks.ml.utils.token import (
4
+ get_token_starts,
5
+ )
6
+ from beartype.typing import Any
7
+ from rfd3.metrics.metrics_utils import (
8
+ _flatten_dict,
9
+ get_hotspot_contacts,
10
+ get_ss_metrics_and_rg,
11
+ )
12
+
13
+ from foundry.common import exists
14
+ from foundry.metrics.metric import Metric
15
+
# Idealized distance (Å) between consecutive C-alpha atoms in a protein chain;
# deviations beyond a threshold are treated as chain breaks.
STANDARD_CACA_DIST = 3.8
17
+
18
+
19
def get_clash_metrics(
    atom_array,
    clash_threshold=1.5,
    ligand_clash_threshold=1.5,
    chainbreak_threshold=0.75,
):
    """Compute chain-break and clash metrics for a predicted AtomArray.

    Args:
        atom_array: AtomArray carrying at least the annotations used below
            (`is_ligand`, `is_motif_atom_unindexed`, `is_motif_atom_with_fixed_coord`,
            `is_protein`, `atom_name`, `chain_iid`, `res_id`, `coord`).
        clash_threshold: distance (Å) below which two protein atoms from
            non-adjacent residues are counted as clashing.
        ligand_clash_threshold: distance (Å) below which a diffused backbone
            atom and a ligand atom are counted as clashing.
        chainbreak_threshold: allowed deviation (Å) from STANDARD_CACA_DIST
            between consecutive CA atoms before flagging a chain break.

    Returns:
        Flat dict of scalar metrics; entries whose value is None are dropped.
    """
    # HACK: For now, ligands are treated as any atomized residues
    # (ligand-annotated atoms that are not unindexed motif atoms).
    is_ligand = np.logical_and(
        atom_array.is_ligand, ~atom_array.is_motif_atom_unindexed
    )

    def get_chainbreaks():
        # Chain-break detection from consecutive CA-CA distances.
        ca_atoms = atom_array[atom_array.atom_name == "CA"]
        xyz = ca_atoms.coord
        xyz = torch.from_numpy(xyz)
        ca_dists = torch.norm(xyz[1:] - xyz[:-1], dim=-1)
        deviation = torch.abs(ca_dists - STANDARD_CACA_DIST)

        # Allow leniency for expected chain breaks (e.g. PPI): zero out the
        # deviation wherever consecutive CAs belong to different chain instances.
        chain_breaks = ca_atoms.chain_iid[1:] != ca_atoms.chain_iid[:-1]
        deviation[chain_breaks] = 0

        is_chainbreak = deviation > chainbreak_threshold
        # NOTE(review): assumes at least two CA atoms; `deviation.max(-1)` would
        # fail on an empty tensor — confirm upstream guarantees this.
        return {
            # torch.max with a dim returns a (values, indices) namedtuple.
            "max_ca_deviation": float(deviation.max(-1).values.mean()),
            "n_chainbreaks": int(is_chainbreak.sum()),
        }

    def get_interresidue_clashes(backbone_only=False):
        # Count protein atoms whose nearest non-adjacent-residue neighbour is
        # closer than `clash_threshold`.
        protein_array = atom_array[atom_array.is_protein]
        resid = protein_array.res_id - protein_array.res_id.min()
        xyz = protein_array.coord
        dists = np.linalg.norm(xyz[:, None] - xyz[None], axis=-1)  # N_atoms x N_atoms

        # Keep only the strict upper triangle (each pair considered once), then
        # block out pairs from the same or sequence-adjacent residues.
        mask = np.triu(np.ones_like(dists), k=1).astype(bool)
        block_mask = np.abs(resid[:, None] - resid[None, :]) <= 1
        mask[block_mask] = False
        dists[~mask] = 999

        if backbone_only:
            # Block out non-backbone atoms so only N/CA/C pairs can clash.
            backbone_mask = np.isin(protein_array.atom_name, ["N", "CA", "C"])
            mask = backbone_mask[:, None] & backbone_mask[None, :]
            dists[~mask] = 999

        num_clashes_L = dists.min(axis=-1) < clash_threshold
        return int(num_clashes_L.sum())

    def get_ligand_clash_metrics():
        if not is_ligand.any():
            return {}

        # Clashes are diffused (non-motif) backbone atoms against any ligand atom.
        xyz_ligand = atom_array[is_ligand].coord
        backbone_mask = np.isin(atom_array.atom_name, ["N", "CA", "C"]) & ~is_ligand
        xyz_diffused = atom_array[
            backbone_mask
            & ~atom_array.is_motif_atom_unindexed
            & ~atom_array.is_motif_atom_with_fixed_coord
        ].coord

        # If we have no diffused backbone atoms, return empty
        if xyz_diffused.shape[0] == 0:
            return {}

        diff = (
            xyz_diffused[:, None, :] - xyz_ligand[None, :, :]
        )  # (n_diffused, n_ligand, 3)
        dists_ligand = np.linalg.norm(diff, axis=-1)  # (n_diffused, n_ligand)
        # Per-ligand-atom distance to its closest diffused backbone atom.
        dists = np.min(dists_ligand, axis=0)
        return {
            "n_clashing.ligand_clashes": int(np.sum(dists < ligand_clash_threshold)),
            "n_clashing.ligand_min_distance": float(np.min(dists)),
        }

    # Accumulate metrics from the three helpers above.
    o = {}
    o = o | get_chainbreaks()
    o["n_clashing.interresidue_clashes_w_sidechain"] = get_interresidue_clashes()
    o["n_clashing.interresidue_clashes_w_backbone"] = get_interresidue_clashes(
        backbone_only=True
    )
    o |= get_ligand_clash_metrics()
    # Drop any None-valued entries.
    return {k: v for k, v in o.items() if exists(v)}
104
+
105
+
106
def convert_to_float_or_str(o):
    """Coerce dictionary values in place so the dict is JSON-serializable.

    Values that are already ``int``, ``float``, ``str``, or ``list`` are left
    untouched; anything else (e.g. numpy scalars) is converted with ``float()``.

    Args:
        o: dict of metric values. Mutated in place.

    Returns:
        The same dict ``o``, for call-chaining convenience.

    Raises:
        ValueError: if a value cannot be converted to ``float``.
    """
    for k, v in o.items():
        if not isinstance(v, (int, float, str, list)):
            try:
                o[k] = float(v)
            except Exception as e:
                # Chain the original exception so the root cause is preserved.
                raise ValueError(
                    f"Unsupported type for key {k}: {type(v)}. Error: {e}"
                ) from e
    return o
117
+
118
+
119
def get_all_backbone_metrics(
    atom_array,
    verbose=True,
    compute_non_clash_metrics_for_diffused_region_only: bool = False,
):
    """
    Calculate metrics for the AtomArray.

    The atom array coming in will be a cleaned atom array (no virtual atoms and
    corrected atom names) without guideposts.

    Args:
        atom_array: cleaned AtomArray with motif/protein annotations.
        verbose: when True, additionally compute SS/Rg and compositional metrics.
        compute_non_clash_metrics_for_diffused_region_only: when True, restrict
            the non-clash metrics to atoms whose coordinates were diffused
            (i.e. not fixed motif atoms).

    Returns:
        Flat dict of JSON-serializable metric values.
    """
    o = {}

    # ... Clash metrics (always computed on the full array, including fixed atoms)
    o = o | get_clash_metrics(
        atom_array,
    )

    if verbose:
        if compute_non_clash_metrics_for_diffused_region_only:
            # Subset to diffused region only
            atom_array = atom_array[~atom_array.is_motif_atom_with_fixed_coord]

        # ... Add additional metrics
        # NOTE(review): this subsets to ~is_motif_atom_with_fixed_coord even when
        # the flag above is False (and redundantly re-subsets when it is True) —
        # confirm whether SS/Rg is intended to always exclude fixed-coord atoms.
        o |= get_ss_metrics_and_rg(
            atom_array[~atom_array.is_motif_atom_with_fixed_coord]
        )

        # Basic compositional statistics over protein tokens (one atom per token).
        starts = get_token_starts(atom_array)
        protein_starts = starts[atom_array.is_protein[starts]]
        o["alanine_content"] = np.mean(atom_array[protein_starts].res_name == "ALA")
        o["glycine_content"] = np.mean(atom_array[protein_starts].res_name == "GLY")
        o["num_residues"] = len(protein_starts)

        # Centers of mass for the diffused and (if present) fixed regions.
        fixed = atom_array.is_motif_atom_with_fixed_coord
        o["diffused_com"] = np.mean(atom_array.coord[~fixed, :], axis=0).tolist()
        if np.any(fixed):
            o["fixed_com"] = np.mean(atom_array.coord[fixed, :], axis=0).tolist()

        # (Sequence-entropy metrics from b_factor were removed here; see
        # AtomArrayMetrics.compute for the live version of that computation.)
        # if "b_factor" in token_array.get_annotation_categories():
        #     m["sequence_entropy_mean"] = np.mean(token_array.b_factor)
        #     m["sequence_entropy_max"] = np.max(token_array.b_factor)
        #     m["sequence_entropy_min"] = np.min(token_array.b_factor)
        #     m["sequence_entropy_std"] = np.std(token_array.b_factor)

    # ... Ensure JSON-saveable
    o = convert_to_float_or_str(o)
    return o
168
+
169
+
170
class AtomArrayMetrics(Metric):
    """General metrics for the predicted atom array.

    Computes secondary-structure content / radius of gyration, compositional
    statistics, and (when a ``b_factor`` annotation is present) sequence-entropy
    summaries, averaged over the predicted atom-array stack.
    """

    def __init__(
        self,
        compute_for_diffused_region_only: bool = False,
        compute_ss_adherence_if_possible: bool = False,
    ):
        super().__init__()
        # NOTE(review): the three thresholds below are unused by this class
        # (mirrored from BackboneMetrics); kept for interface stability.
        self.clash_threshold = 1.2
        self.float_threshold = (
            3.0  # maximum closest-neighbour distance before considered a floating atom
        )
        self.standard_ca_dist = 3.8
        self.compute_for_diffused_region_only = compute_for_diffused_region_only
        self.compute_ss_adherence_if_possible = compute_ss_adherence_if_possible

    @property
    def kwargs_to_compute_args(self) -> dict[str, Any]:
        return {
            # FIX: one-element tuple (trailing comma) for consistency with the
            # other Metric subclasses; previously a bare parenthesized string.
            "atom_array_stack": ("predicted_atom_array_stack",),
            "feats": ("network_input", "f"),
        }

    def compute(self, atom_array_stack, feats):
        """Compute per-structure metrics and return their means over the stack."""
        o = {}

        for atom_array in atom_array_stack:
            # Subset to indexed tokens only
            atom_array = atom_array[~atom_array.is_motif_atom_unindexed]

            if self.compute_for_diffused_region_only:
                atom_array = atom_array[~atom_array.is_motif_atom_with_fixed_coord]

            # SS content and ROG: pass SS conditioning masks only when adherence
            # metrics are requested, all three masks exist, and at least one is set.
            if (
                self.compute_ss_adherence_if_possible
                and (
                    "is_helix_conditioning" in feats
                    and "is_sheet_conditioning" in feats
                    and "is_loop_conditioning" in feats
                )
                and (
                    feats["is_helix_conditioning"].sum() > 0
                    or feats["is_sheet_conditioning"].sum() > 0
                    or feats["is_loop_conditioning"].sum() > 0
                )
            ):
                ss_conditioning = {
                    "helix": feats["is_helix_conditioning"].cpu().numpy(),
                    "sheet": feats["is_sheet_conditioning"].cpu().numpy(),
                    "loop": feats["is_loop_conditioning"].cpu().numpy(),
                }
            else:
                ss_conditioning = None
            m = get_ss_metrics_and_rg(atom_array, ss_conditioning=ss_conditioning)

            # Subset to token level array for consistency
            token_array = atom_array[get_token_starts(atom_array)]

            # Basic compositional statistics
            m["alanine_content"] = np.mean(token_array.res_name == "ALA")
            m["glycine_content"] = np.mean(token_array.res_name == "GLY")

            # Sequence confidence (b_factor is assumed to carry per-token entropy
            # when present — TODO confirm against the writer of this annotation).
            if "b_factor" in token_array.get_annotation_categories():
                m["sequence_entropy_mean"] = np.mean(token_array.b_factor)
                m["sequence_entropy_max"] = np.max(token_array.b_factor)
                m["sequence_entropy_min"] = np.min(token_array.b_factor)
                m["sequence_entropy_std"] = np.std(token_array.b_factor)

            # Accumulate per-structure values under each metric key.
            for k, v in m.items():
                if k not in o:
                    o[k] = []
                o[k].append(v)

        # Summarize stats: mean over the stack for every metric.
        for k, v in o.items():
            o[k] = float(np.mean(v))
        return o
251
+
252
+
253
class MetadataMetrics(Metric):
    """
    Fetches all floating point values from the prediction metadata
    """

    @property
    def kwargs_to_compute_args(self):
        return {
            "prediction_metadata": ("prediction_metadata",),
        }

    def compute(self, prediction_metadata):
        """Flatten each prediction's metadata and mean-reduce its numeric fields."""
        if not prediction_metadata:
            return {}

        collected = {}
        for metadata in prediction_metadata.values():
            # Flatten nested metadata into dotted keys.
            flat = _flatten_dict(metadata)

            # Keep only numeric leaves, grouped by key across predictions.
            for key, value in flat.items():
                if isinstance(value, (int, float)):
                    collected.setdefault(key, []).append(value)

        # Reduce each key via the mean over predictions.
        return {key: float(np.mean(values)) for key, values in collected.items()}
284
+
285
+
286
class BackboneMetrics(Metric):
    """Clash, floating-atom, and chain-break metrics computed directly on the
    denoised coordinates ``X_L`` from the network output.

    All pairwise-distance work is done in numpy on CPU copies of the tensors.
    """

    def __init__(self, compute_for_diffused_region_only: bool = False):
        super().__init__()
        # Distance (Å) below which two atoms from different tokens clash.
        self.clash_threshold = 1.2
        self.float_threshold = (
            3.0  # maximum closest-neighbour distance before considered a floating atom
        )
        # Idealized CA-CA distance (Å) between consecutive residues.
        self.standard_ca_dist = 3.8
        self.compute_for_diffused_region_only = compute_for_diffused_region_only

    @property
    def kwargs_to_compute_args(self) -> dict[str, Any]:
        return {
            "X_L": ("network_output", "X_L"),  # [D, L, 3]
            "tok_idx": ("network_input", "f", "atom_to_token_map"),
            "f": ("network_input", "f"),
        }

    def compute(self, X_L, tok_idx, f):
        """Compute clash / floating / chain-break metrics, averaged over the
        diffusion batch dimension D."""
        o = {}
        xyz = X_L.detach().cpu().numpy()
        tok_idx = tok_idx.cpu().numpy()
        dists = np.linalg.norm(
            xyz[..., :, None, :] - xyz[..., None, :, :], axis=-1
        )  # [D, L, L] pairwise atom distances

        is_protein = f["is_protein"][tok_idx].cpu().numpy()  # n_atoms

        # Block self-distances, intra-token pairs, and non-protein pairs.
        mask = np.zeros_like(dists, dtype=bool)
        mask = mask | (np.eye(dists.shape[-1], dtype=bool))[None]
        mask = mask | (tok_idx[:, None] == tok_idx[None, :])[None]
        mask = mask | ~(is_protein[:, None] & is_protein[None, :])[None]
        dists[mask] = 999

        num_clashes_L = (dists.min(axis=-1) < self.clash_threshold).astype(
            float
        )  # B, L
        o["frac_clashing"] = float(num_clashes_L.mean(-1).mean())
        o["n_clashing"] = float(num_clashes_L.sum(-1).mean())

        if "is_backbone" in f:
            # Reuses `dists`, which already has self/intra-token/non-protein
            # pairs blocked; additionally block non-backbone pairs.
            is_backbone = f["is_backbone"].cpu().numpy()
            mask = np.zeros_like(dists, dtype=bool)
            mask = mask | (tok_idx[:, None] == tok_idx[None, :])[None]
            mask = mask | ~(is_backbone[:, None] & is_backbone[None, :])[None]
            dists[mask] = 999
            o["frac_backbone_clashing"] = float(
                (dists.min(axis=-1) < self.clash_threshold)
                .astype(float)
                .mean(-1)
                .mean()
            )
            o["n_backbone_clashing"] = float(
                (dists.min(axis=-1) < self.clash_threshold).astype(float).sum(-1).mean()
            )

        # We do this after clash detection, since that should consider both chains
        if self.compute_for_diffused_region_only:
            diffused_region = ~(f["is_motif_atom_with_fixed_coord"].cpu().numpy())
            xyz = xyz[:, diffused_region]
            tok_idx = tok_idx[diffused_region]

        # Num floating: atoms whose nearest neighbour is farther than float_threshold.
        dists = np.linalg.norm(
            xyz[..., :, None, :] - xyz[..., None, :, :], axis=-1
        )  # N_atoms x N_atoms
        # BUG FIX: this previously used `mask & eye`, which on an all-False mask
        # stays all-False, so the zero self-distances were never blocked and
        # frac_floating was always 0. The diagonal must be OR-ed in.
        mask = np.zeros_like(dists, dtype=bool)
        mask = mask | np.eye(dists.shape[-1], dtype=bool)[None]
        dists[mask] = 999

        is_floating = dists.min(axis=-1) > self.float_threshold
        o["frac_floating"] = float(is_floating.mean(-1).mean())

        if "is_ca" in f:
            # Chain-break metrics from consecutive protein CA atoms.
            is_ca = f["is_ca"].cpu().numpy()
            if self.compute_for_diffused_region_only:
                is_ca = is_ca[diffused_region]
                is_protein = is_protein[diffused_region]
            idx_mask = is_ca & is_protein
            if self.compute_for_diffused_region_only:
                xyz = X_L.cpu()[:, diffused_region][:, idx_mask]
            else:
                xyz = X_L.cpu()[:, idx_mask]

            ca_dists = torch.norm(xyz[:, 1:] - xyz[:, :-1], dim=-1)
            deviation = torch.abs(ca_dists - self.standard_ca_dist)  # B, (I-1)
            is_chainbreak = deviation > 0.75

            o["max_ca_deviation"] = float(deviation.max(-1).values.mean())
            o["fraction_chainbreaks"] = float(is_chainbreak.float().mean(-1).mean())
            o["n_chainbreaks"] = float(is_chainbreak.float().sum(-1).mean())

        return o
380
+
381
+
382
class PPIMetrics(Metric):
    """PPI-specific metrics: fraction of hotspot atoms contacted by the design."""

    def __init__(self, distance_cutoff: float = 4.5):
        super().__init__()
        self.distance_cutoff = distance_cutoff  # Distance cutoff for hotspot contacts

    @property
    def kwargs_to_compute_args(self) -> dict[str, Any]:
        return {
            # FIX: one-element tuple (trailing comma) for consistency with the
            # other Metric subclasses; previously a bare parenthesized string.
            "atom_array_stack": ("predicted_atom_array_stack",),
            # "ppi_hotspots_mask": ("network_input", "f", "is_atom_level_hotspot"),
        }

    def compute(self, atom_array_stack):
        """Return the mean fraction of hotspots contacted within the cutoff,
        averaged over structures that actually have hotspot annotations."""
        fractions = []
        for atom_array in atom_array_stack:
            ppi_hotspots_mask = atom_array.get_annotation(
                "is_atom_level_hotspot"
            ).astype(bool)
            # Skip structures without hotspot annotations.
            if ppi_hotspots_mask.sum() == 0:
                continue

            fractions.append(
                get_hotspot_contacts(
                    atom_array,
                    hotspot_mask=ppi_hotspots_mask,
                    distance_cutoff=self.distance_cutoff,
                )
            )

        # No structure had hotspots: report nothing rather than NaN.
        if not fractions:
            return {}

        return {"fraction_hotspots_contacted": float(np.mean(fractions))}
420
+
421
+
422
class SequenceMetrics(Metric):
    """Sequence-recovery and soft confusion-matrix metrics from the sequence head."""

    @property
    def kwargs_to_compute_args(self) -> dict[str, Any]:
        return {
            "S_I": ("network_output", "sequence_logits_I"),  # [D, I, K]
            # NOTE(review): commented as [D, I] but expanded below as if 1-D [I]
            # — confirm the actual shape written by the trainer.
            "S_gt_I": ("extra_info", "seq_token_lvl"),  # [D, I]
        }

    def compute(self, S_I, S_gt_I):
        """Compute argmax sequence recovery and a softmax-weighted confusion
        matrix (summed over the batch) between predicted and ground-truth tokens."""
        o = {}
        seq_head_pred = S_I.argmax(dim=-1)  # B, I
        seq_head_recovery = seq_head_pred == S_gt_I

        # Mean recovery over all positions.
        # NOTE(review): despite the original comment claiming unresolved residues
        # are filtered, no masking is applied here — confirm whether intended.
        seq_head_recovery = seq_head_recovery.float().mean()
        o["seq_head_recovery"] = float(seq_head_recovery.mean())

        # Broadcast ground truth across the batch dimension for the confusion matrix.
        seq_head_gt = S_gt_I[None].expand(seq_head_pred.shape[0], -1)  # B, I

        # Use softmax probabilities (soft counts) rather than hard argmax predictions.
        seq_head_pred = S_I.clone()
        seq_head_pred = torch.nn.functional.softmax(seq_head_pred, dim=-1)  # (B, I, C)

        # One-hot encode the ground-truth tokens to (B, I, C).
        seq_head_gt = torch.nn.functional.one_hot(
            seq_head_gt, num_classes=S_I.shape[-1]
        ).float()  # (B, I, C)

        # Permute predictions to shape (B, C, I) for matmul
        seq_head_pred = seq_head_pred.permute(0, 2, 1)  # (B, C, I)

        # Compute confusion matrix per batch (B, C, C)
        confusion_matrix = torch.matmul(seq_head_pred, seq_head_gt)

        # Sum over batch to get (C, C)
        confusion_matrix = confusion_matrix.sum(dim=0)
        confusion_matrix = confusion_matrix.cpu().numpy().astype(np.float32)

        # Emit one scalar per (pred, gt) class pair.
        # NOTE(review): values are np.float32, not Python float — other metrics
        # in this module cast to float(); confirm downstream JSON handling.
        for i in range(confusion_matrix.shape[0]):
            for j in range(confusion_matrix.shape[1]):
                o[f"confusion_matrix_{i}_{j}"] = confusion_matrix[i, j]

        return o