rc-foundry 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- foundry/__init__.py +57 -0
- foundry/callbacks/__init__.py +5 -0
- foundry/callbacks/callback.py +116 -0
- foundry/callbacks/health_logging.py +419 -0
- foundry/callbacks/metrics_logging.py +211 -0
- foundry/callbacks/timing_logging.py +67 -0
- foundry/callbacks/train_logging.py +278 -0
- foundry/common.py +108 -0
- foundry/constants.py +28 -0
- foundry/hydra/resolvers.py +77 -0
- foundry/inference_engines/base.py +235 -0
- foundry/inference_engines/checkpoint_registry.py +66 -0
- foundry/metrics/__init__.py +12 -0
- foundry/metrics/losses.py +30 -0
- foundry/metrics/metric.py +319 -0
- foundry/model/layers/blocks.py +47 -0
- foundry/testing/__init__.py +6 -0
- foundry/testing/fixtures.py +19 -0
- foundry/testing/pytest_hooks.py +15 -0
- foundry/trainers/fabric.py +923 -0
- foundry/training/EMA.py +67 -0
- foundry/training/checkpoint.py +61 -0
- foundry/training/schedulers.py +91 -0
- foundry/utils/alignment.py +86 -0
- foundry/utils/components.py +415 -0
- foundry/utils/datasets.py +405 -0
- foundry/utils/ddp.py +103 -0
- foundry/utils/instantiators.py +72 -0
- foundry/utils/logging.py +279 -0
- foundry/utils/rigid.py +1460 -0
- foundry/utils/rotation_augmentation.py +65 -0
- foundry/utils/squashfs.py +172 -0
- foundry/utils/torch.py +317 -0
- foundry/utils/weights.py +271 -0
- foundry/version.py +34 -0
- foundry_cli/__init__.py +3 -0
- foundry_cli/download_checkpoints.py +281 -0
- mpnn/__init__.py +1 -0
- mpnn/collate/feature_collator.py +265 -0
- mpnn/inference.py +53 -0
- mpnn/inference_engines/mpnn.py +549 -0
- mpnn/loss/nll_loss.py +122 -0
- mpnn/metrics/nll.py +369 -0
- mpnn/metrics/sequence_recovery.py +440 -0
- mpnn/model/layers/graph_embeddings.py +2372 -0
- mpnn/model/layers/message_passing.py +332 -0
- mpnn/model/layers/position_wise_feed_forward.py +44 -0
- mpnn/model/layers/positional_encoding.py +98 -0
- mpnn/model/mpnn.py +2632 -0
- mpnn/pipelines/mpnn.py +162 -0
- mpnn/samplers/samplers.py +167 -0
- mpnn/train.py +341 -0
- mpnn/trainers/mpnn.py +193 -0
- mpnn/transforms/feature_aggregation/mpnn.py +184 -0
- mpnn/transforms/feature_aggregation/polymer_ligand_interface.py +76 -0
- mpnn/transforms/feature_aggregation/token_encodings.py +132 -0
- mpnn/transforms/feature_aggregation/user_settings.py +347 -0
- mpnn/transforms/polymer_ligand_interface.py +164 -0
- mpnn/utils/inference.py +2397 -0
- mpnn/utils/probability.py +37 -0
- mpnn/utils/weights.py +309 -0
- rc_foundry-0.1.1.dist-info/METADATA +239 -0
- rc_foundry-0.1.1.dist-info/RECORD +180 -0
- rc_foundry-0.1.1.dist-info/WHEEL +4 -0
- rc_foundry-0.1.1.dist-info/entry_points.txt +5 -0
- rc_foundry-0.1.1.dist-info/licenses/LICENSE.md +28 -0
- rf3/__init__.py +3 -0
- rf3/_version.py +33 -0
- rf3/alignment.py +79 -0
- rf3/callbacks/dump_validation_structures.py +101 -0
- rf3/callbacks/metrics_logging.py +324 -0
- rf3/chemical.py +1529 -0
- rf3/cli.py +77 -0
- rf3/data/cyclic_transform.py +78 -0
- rf3/data/extra_xforms.py +36 -0
- rf3/data/ground_truth_template.py +463 -0
- rf3/data/paired_msa.py +206 -0
- rf3/data/pipeline_utils.py +128 -0
- rf3/data/pipelines.py +558 -0
- rf3/diffusion_samplers/inference_sampler.py +222 -0
- rf3/inference.py +65 -0
- rf3/inference_engines/__init__.py +5 -0
- rf3/inference_engines/rf3.py +735 -0
- rf3/kinematics.py +354 -0
- rf3/loss/af3_confidence_loss.py +515 -0
- rf3/loss/af3_losses.py +655 -0
- rf3/loss/loss.py +179 -0
- rf3/metrics/chiral.py +179 -0
- rf3/metrics/clashing_chains.py +68 -0
- rf3/metrics/distogram.py +421 -0
- rf3/metrics/lddt.py +523 -0
- rf3/metrics/metadata.py +43 -0
- rf3/metrics/metric_utils.py +192 -0
- rf3/metrics/predicted_error.py +134 -0
- rf3/metrics/rasa.py +108 -0
- rf3/metrics/selected_distances.py +91 -0
- rf3/model/RF3.py +527 -0
- rf3/model/RF3_blocks.py +92 -0
- rf3/model/RF3_structure.py +303 -0
- rf3/model/layers/af3_auxiliary_heads.py +255 -0
- rf3/model/layers/af3_diffusion_transformer.py +544 -0
- rf3/model/layers/attention.py +313 -0
- rf3/model/layers/layer_utils.py +127 -0
- rf3/model/layers/mlff.py +118 -0
- rf3/model/layers/outer_product.py +59 -0
- rf3/model/layers/pairformer_layers.py +783 -0
- rf3/model/layers/structure_bias.py +56 -0
- rf3/scoring.py +1787 -0
- rf3/symmetry/resolve.py +284 -0
- rf3/train.py +194 -0
- rf3/trainers/rf3.py +570 -0
- rf3/util_module.py +47 -0
- rf3/utils/frames.py +109 -0
- rf3/utils/inference.py +665 -0
- rf3/utils/io.py +198 -0
- rf3/utils/loss.py +72 -0
- rf3/utils/predict_and_score.py +165 -0
- rf3/utils/predicted_error.py +673 -0
- rf3/utils/recycling.py +42 -0
- rf3/validate.py +140 -0
- rfd3/.gitignore +7 -0
- rfd3/Makefile +76 -0
- rfd3/__init__.py +12 -0
- rfd3/callbacks.py +66 -0
- rfd3/cli.py +41 -0
- rfd3/constants.py +212 -0
- rfd3/engine.py +543 -0
- rfd3/inference/datasets.py +193 -0
- rfd3/inference/input_parsing.py +1123 -0
- rfd3/inference/legacy_input_parsing.py +717 -0
- rfd3/inference/parsing.py +165 -0
- rfd3/inference/symmetry/atom_array.py +298 -0
- rfd3/inference/symmetry/checks.py +241 -0
- rfd3/inference/symmetry/contigs.py +63 -0
- rfd3/inference/symmetry/frames.py +355 -0
- rfd3/inference/symmetry/symmetry_utils.py +398 -0
- rfd3/metrics/design_metrics.py +465 -0
- rfd3/metrics/hbonds_hbplus_metrics.py +308 -0
- rfd3/metrics/hbonds_metrics.py +389 -0
- rfd3/metrics/losses.py +325 -0
- rfd3/metrics/metrics_utils.py +118 -0
- rfd3/metrics/sidechain_metrics.py +349 -0
- rfd3/model/RFD3.py +105 -0
- rfd3/model/RFD3_diffusion_module.py +387 -0
- rfd3/model/cfg_utils.py +81 -0
- rfd3/model/inference_sampler.py +635 -0
- rfd3/model/layers/attention.py +577 -0
- rfd3/model/layers/block_utils.py +580 -0
- rfd3/model/layers/blocks.py +777 -0
- rfd3/model/layers/chunked_pairwise.py +377 -0
- rfd3/model/layers/encoders.py +417 -0
- rfd3/model/layers/layer_utils.py +197 -0
- rfd3/model/layers/pairformer_layers.py +128 -0
- rfd3/run_inference.py +45 -0
- rfd3/testing/debug.py +139 -0
- rfd3/testing/debug_utils.py +73 -0
- rfd3/testing/testing_utils.py +356 -0
- rfd3/train.py +194 -0
- rfd3/trainer/dump_validation_structures.py +154 -0
- rfd3/trainer/fabric_trainer.py +923 -0
- rfd3/trainer/recycling.py +42 -0
- rfd3/trainer/rfd3.py +485 -0
- rfd3/trainer/trainer_utils.py +502 -0
- rfd3/transforms/conditioning_base.py +508 -0
- rfd3/transforms/conditioning_utils.py +200 -0
- rfd3/transforms/design_transforms.py +807 -0
- rfd3/transforms/dna_crop.py +523 -0
- rfd3/transforms/hbonds.py +407 -0
- rfd3/transforms/hbonds_hbplus.py +246 -0
- rfd3/transforms/ncaa_transforms.py +153 -0
- rfd3/transforms/pipelines.py +632 -0
- rfd3/transforms/ppi_transforms.py +541 -0
- rfd3/transforms/rasa.py +116 -0
- rfd3/transforms/symmetry.py +76 -0
- rfd3/transforms/training_conditions.py +552 -0
- rfd3/transforms/util_transforms.py +498 -0
- rfd3/transforms/virtual_atoms.py +305 -0
- rfd3/utils/inference.py +648 -0
- rfd3/utils/io.py +245 -0
- rfd3/utils/vizualize.py +276 -0
|
@@ -0,0 +1,465 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
import torch
|
|
3
|
+
from atomworks.ml.utils.token import (
|
|
4
|
+
get_token_starts,
|
|
5
|
+
)
|
|
6
|
+
from beartype.typing import Any
|
|
7
|
+
from rfd3.metrics.metrics_utils import (
|
|
8
|
+
_flatten_dict,
|
|
9
|
+
get_hotspot_contacts,
|
|
10
|
+
get_ss_metrics_and_rg,
|
|
11
|
+
)
|
|
12
|
+
|
|
13
|
+
from foundry.common import exists
|
|
14
|
+
from foundry.metrics.metric import Metric
|
|
15
|
+
|
|
16
|
+
STANDARD_CACA_DIST = 3.8
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def get_clash_metrics(
    atom_array,
    clash_threshold=1.5,
    ligand_clash_threshold=1.5,
    chainbreak_threshold=0.75,
):
    """Compute clash and chain-break metrics for a designed structure.

    Args:
        atom_array: Structure with ``coord``, ``atom_name``, ``res_id``,
            ``chain_iid``, ``is_protein``, ``is_ligand`` and motif annotations.
            (Assumed to be an atomworks AtomArray — schema inferred from usage;
            confirm against callers.)
        clash_threshold: Distance (Å) below which two protein atoms from
            non-adjacent residues count as clashing.
        ligand_clash_threshold: Distance (Å) below which a diffused backbone
            atom and a ligand atom count as clashing.
        chainbreak_threshold: Allowed deviation (Å) from the ideal CA-CA
            distance before consecutive residues are flagged as a chain break.

    Returns:
        dict of int/float metric values; entries whose value is None are
        dropped.
    """
    # HACK: For now, ligands are treated as any atomized residues
    is_ligand = np.logical_and(
        atom_array.is_ligand, ~atom_array.is_motif_atom_unindexed
    )

    def get_chainbreaks():
        # Chain breaks are measured as the deviation of consecutive CA-CA
        # distances from the ideal 3.8 Å spacing.
        # NOTE(review): raises on structures with fewer than two CA atoms
        # (max of an empty tensor) — confirm callers guarantee >= 2 residues.
        ca_atoms = atom_array[atom_array.atom_name == "CA"]
        xyz = ca_atoms.coord
        xyz = torch.from_numpy(xyz)
        ca_dists = torch.norm(xyz[1:] - xyz[:-1], dim=-1)
        deviation = torch.abs(ca_dists - STANDARD_CACA_DIST)

        # Allow leniency for expected chain breaks (e.g. PPI)
        chain_breaks = ca_atoms.chain_iid[1:] != ca_atoms.chain_iid[:-1]
        deviation[chain_breaks] = 0

        is_chainbreak = deviation > chainbreak_threshold
        return {
            # deviation is 1-D, so .max(-1).values is the scalar maximum and
            # .mean() is a no-op kept for symmetry with the batched variant.
            "max_ca_deviation": float(deviation.max(-1).values.mean()),
            "n_chainbreaks": int(is_chainbreak.sum()),
        }

    def get_interresidue_clashes(backbone_only=False):
        # Counts protein atoms whose nearest *allowed* partner is closer than
        # `clash_threshold`. Intra-residue and sequence-adjacent residue pairs
        # are excluded.
        protein_array = atom_array[atom_array.is_protein]
        resid = protein_array.res_id - protein_array.res_id.min()
        xyz = protein_array.coord
        dists = np.linalg.norm(xyz[:, None] - xyz[None], axis=-1)  # N_atoms x N_atoms

        # Block out intra-residue distances.
        # NOTE(review): only the upper triangle is kept, so each clashing pair
        # is attributed to its lower-index atom — presumably intentional to
        # avoid double-counting pairs; confirm.
        mask = np.triu(np.ones_like(dists), k=1).astype(bool)
        block_mask = np.abs(resid[:, None] - resid[None, :]) <= 1
        mask[block_mask] = False
        dists[~mask] = 999

        if backbone_only:
            # Block out non-backbone atoms (dists keeps the 999s from above,
            # so the earlier exclusions remain in effect).
            backbone_mask = np.isin(protein_array.atom_name, ["N", "CA", "C"])
            mask = backbone_mask[:, None] & backbone_mask[None, :]
            dists[~mask] = 999

        num_clashes_L = dists.min(axis=-1) < clash_threshold
        return int(num_clashes_L.sum())

    def get_ligand_clash_metrics():
        if not is_ligand.any():
            return {}

        # Clashes are any non-motif atom against any ligand atom
        xyz_ligand = atom_array[is_ligand].coord
        backbone_mask = np.isin(atom_array.atom_name, ["N", "CA", "C"]) & ~is_ligand
        xyz_diffused = atom_array[
            backbone_mask
            & ~atom_array.is_motif_atom_unindexed
            & ~atom_array.is_motif_atom_with_fixed_coord
        ].coord

        # If we have no diffused backbone atoms, return empty
        if xyz_diffused.shape[0] == 0:
            return {}

        diff = (
            xyz_diffused[:, None, :] - xyz_ligand[None, :, :]
        )  # (n_diffused, n_ligand, 3)
        dists_ligand = np.linalg.norm(diff, axis=-1)  # (n_diffused, n_ligand)
        # Per-ligand-atom distance to its nearest diffused backbone atom
        dists = np.min(dists_ligand, axis=0)
        return {
            "n_clashing.ligand_clashes": int(np.sum(dists < ligand_clash_threshold)),
            "n_clashing.ligand_min_distance": float(np.min(dists)),
        }

    # Accumulate metrics
    o = {}
    o = o | get_chainbreaks()
    o["n_clashing.interresidue_clashes_w_sidechain"] = get_interresidue_clashes()
    o["n_clashing.interresidue_clashes_w_backbone"] = get_interresidue_clashes(
        backbone_only=True
    )
    o |= get_ligand_clash_metrics()
    return {k: v for k, v in o.items() if exists(v)}
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
def convert_to_float_or_str(o):
    """Coerce dict values in place so the dict is JSON-serializable.

    Values that are already ``int``, ``float``, ``str`` or ``list`` are left
    untouched; anything else (e.g. numpy scalars, 0-d tensors) is converted
    with ``float()``.

    Args:
        o: Dictionary of metric values. Mutated in place.

    Returns:
        The same dictionary, with non-JSON-native values replaced by floats.

    Raises:
        ValueError: If a value cannot be converted to ``float``. The original
            exception is chained as ``__cause__``.
    """
    for k, v in o.items():
        if not isinstance(v, (int, float, str, list)):
            try:
                o[k] = float(v)
            except Exception as e:
                # Chain the original exception so the root cause survives
                # (the original code dropped the traceback context).
                raise ValueError(
                    f"Unsupported type for key {k}: {type(v)}. Error: {e}"
                ) from e
    return o
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
def get_all_backbone_metrics(
    atom_array,
    verbose=True,
    compute_non_clash_metrics_for_diffused_region_only: bool = False,
):
    """
    Calculate metrics for the AtomArray.

    The atom array coming in will be a cleaned atom array (no virtual atoms and
    corrected atom names) without guideposts.

    Args:
        atom_array: Cleaned structure (atomworks AtomArray — schema assumed
            from usage).
        verbose: If True, additionally compute SS/ROG, compositional and
            center-of-mass statistics; clash metrics are always computed.
        compute_non_clash_metrics_for_diffused_region_only: If True, restrict
            all non-clash metrics to atoms outside the fixed-coordinate motif.

    Returns:
        dict of JSON-saveable metric values (see ``convert_to_float_or_str``).
    """
    o = {}

    # ... Clash metrics (always computed on the full array)
    o = o | get_clash_metrics(
        atom_array,
    )

    if verbose:
        if compute_non_clash_metrics_for_diffused_region_only:
            # Subset to diffused region only
            atom_array = atom_array[~atom_array.is_motif_atom_with_fixed_coord]

        # ... Add additional metrics
        # NOTE(review): if the subset above ran, this second
        # `~is_motif_atom_with_fixed_coord` filter is a no-op (all remaining
        # atoms are already diffused); it only has an effect when the flag is
        # False — confirm the asymmetry is intended.
        o |= get_ss_metrics_and_rg(
            atom_array[~atom_array.is_motif_atom_with_fixed_coord]
        )

        # Basic compositional statistics
        starts = get_token_starts(atom_array)
        protein_starts = starts[atom_array.is_protein[starts]]
        o["alanine_content"] = np.mean(atom_array[protein_starts].res_name == "ALA")
        o["glycine_content"] = np.mean(atom_array[protein_starts].res_name == "GLY")
        o["num_residues"] = len(protein_starts)

        # Centers of mass for diffused vs. fixed atoms
        fixed = atom_array.is_motif_atom_with_fixed_coord
        o["diffused_com"] = np.mean(atom_array.coord[~fixed, :], axis=0).tolist()
        if np.any(fixed):
            o["fixed_com"] = np.mean(atom_array.coord[fixed, :], axis=0).tolist()

    # if "b_factor" in token_array.get_annotation_categories():
    #     m["sequence_entropy_mean"] = np.mean(token_array.b_factor)
    #     m["sequence_entropy_max"] = np.max(token_array.b_factor)
    #     m["sequence_entropy_min"] = np.min(token_array.b_factor)
    #     m["sequence_entropy_std"] = np.std(token_array.b_factor)

    # ... Ensure JSON-saveable
    o = convert_to_float_or_str(o)
    return o
|
|
168
|
+
|
|
169
|
+
|
|
170
|
+
class AtomArrayMetrics(Metric):
    """General metrics for the predicted atom array.

    For each atom array in the predicted stack, computes secondary-structure /
    radius-of-gyration metrics and basic sequence-composition statistics, then
    averages each metric over the stack.
    """

    def __init__(
        self,
        compute_for_diffused_region_only: bool = False,
        compute_ss_adherence_if_possible: bool = False,
    ):
        """
        Args:
            compute_for_diffused_region_only: If True, drop fixed-coordinate
                motif atoms before computing metrics.
            compute_ss_adherence_if_possible: If True and all three SS
                conditioning masks are present with at least one non-empty,
                pass them through so SS adherence can be scored.
        """
        super().__init__()
        # NOTE(review): the three thresholds below are never read by this
        # class's compute(); they look copied from BackboneMetrics. Kept for
        # backward compatibility.
        self.clash_threshold = 1.2
        self.float_threshold = (
            3.0  # maximum closest-neighbour distance before considered a floating atom
        )
        self.standard_ca_dist = 3.8
        self.compute_for_diffused_region_only = compute_for_diffused_region_only
        self.compute_ss_adherence_if_possible = compute_ss_adherence_if_possible

    @property
    def kwargs_to_compute_args(self) -> dict[str, Any]:
        # FIX: "predicted_atom_array_stack" was a bare parenthesized string
        # (missing trailing comma). Use a 1-tuple, consistent with the tuple
        # key-paths used by the other Metric subclasses in this module.
        return {
            "atom_array_stack": ("predicted_atom_array_stack",),
            "feats": ("network_input", "f"),
        }

    def compute(self, atom_array_stack, feats):
        """Average SS/ROG and composition metrics over the predicted stack."""
        o = {}

        for atom_array in atom_array_stack:
            # Subset to indexed tokens only
            atom_array = atom_array[~atom_array.is_motif_atom_unindexed]

            if self.compute_for_diffused_region_only:
                atom_array = atom_array[~atom_array.is_motif_atom_with_fixed_coord]

            # SS content and ROG: only score adherence when all three
            # conditioning masks exist and at least one is non-empty.
            if (
                self.compute_ss_adherence_if_possible
                and (
                    "is_helix_conditioning" in feats
                    and "is_sheet_conditioning" in feats
                    and "is_loop_conditioning" in feats
                )
                and (
                    feats["is_helix_conditioning"].sum() > 0
                    or feats["is_sheet_conditioning"].sum() > 0
                    or feats["is_loop_conditioning"].sum() > 0
                )
            ):
                ss_conditioning = {
                    "helix": feats["is_helix_conditioning"].cpu().numpy(),
                    "sheet": feats["is_sheet_conditioning"].cpu().numpy(),
                    "loop": feats["is_loop_conditioning"].cpu().numpy(),
                }
            else:
                ss_conditioning = None
            m = get_ss_metrics_and_rg(atom_array, ss_conditioning=ss_conditioning)

            # Subset to token level array for consistency
            token_array = atom_array[get_token_starts(atom_array)]

            # Basic compositional statistics
            m["alanine_content"] = np.mean(token_array.res_name == "ALA")
            m["glycine_content"] = np.mean(token_array.res_name == "GLY")

            # Sequence confidence: b_factor presumably carries per-token
            # sequence entropy — confirm against the upstream writer.
            if "b_factor" in token_array.get_annotation_categories():
                m["sequence_entropy_mean"] = np.mean(token_array.b_factor)
                m["sequence_entropy_max"] = np.max(token_array.b_factor)
                m["sequence_entropy_min"] = np.min(token_array.b_factor)
                m["sequence_entropy_std"] = np.std(token_array.b_factor)

            # Accumulate per-array values
            for k, v in m.items():
                if k not in o:
                    o[k] = []
                o[k].append(v)

        # Reduce each metric by averaging over the stack
        for k, v in o.items():
            o[k] = float(np.mean(v))
        return o
|
|
251
|
+
|
|
252
|
+
|
|
253
|
+
class MetadataMetrics(Metric):
    """
    Fetches all floating point values from the prediction metadata
    """

    @property
    def kwargs_to_compute_args(self):
        return {
            "prediction_metadata": ("prediction_metadata",),
        }

    def compute(self, prediction_metadata):
        """Mean of every numeric metadata entry across all predictions."""
        if not prediction_metadata:
            return {}

        accumulated = {}
        for metadata in prediction_metadata.values():
            # Flatten nested metadata into dotted keys, then keep only the
            # numeric entries.
            for key, value in _flatten_dict(metadata).items():
                if isinstance(value, (int, float)):
                    accumulated.setdefault(key, []).append(value)

        # Reduce via mean over predictions
        return {key: float(np.mean(vals)) for key, vals in accumulated.items()}
|
|
284
|
+
|
|
285
|
+
|
|
286
|
+
class BackboneMetrics(Metric):
    """Clash, floating-atom and chain-break metrics computed directly on the
    predicted coordinates (network output)."""

    def __init__(self, compute_for_diffused_region_only: bool = False):
        """
        Args:
            compute_for_diffused_region_only: If True, floating-atom and CA
                metrics are restricted to the diffused (non-fixed-motif)
                region. Clash metrics always use all atoms, since a clash can
                involve a fixed atom.
        """
        super().__init__()
        self.clash_threshold = 1.2  # Å below which two atoms count as clashing
        self.float_threshold = (
            3.0  # maximum closest-neighbour distance before considered a floating atom
        )
        self.standard_ca_dist = 3.8  # ideal consecutive CA-CA distance (Å)
        self.compute_for_diffused_region_only = compute_for_diffused_region_only

    @property
    def kwargs_to_compute_args(self) -> dict[str, Any]:
        return {
            "X_L": ("network_output", "X_L"),  # [D, L, 3]
            "tok_idx": ("network_input", "f", "atom_to_token_map"),
            "f": ("network_input", "f"),
        }

    def compute(self, X_L, tok_idx, f):
        """Compute clash / floating / chain-break metrics, averaged over the
        diffusion-batch dimension."""
        o = {}
        xyz = X_L.detach().cpu().numpy()
        tok_idx = tok_idx.cpu().numpy()
        dists = np.linalg.norm(
            xyz[..., :, None, :] - xyz[..., None, :, :], axis=-1
        )  # N_atoms x N_atoms

        is_protein = f["is_protein"][tok_idx].cpu().numpy()  # n_atoms

        # Exclude self-distances, intra-token pairs and non-protein pairs
        mask = np.zeros_like(dists, dtype=bool)
        mask = mask | (np.eye(dists.shape[-1], dtype=bool))[None]
        mask = mask | (tok_idx[:, None] == tok_idx[None, :])[None]
        mask = mask | ~(is_protein[:, None] & is_protein[None, :])[None]
        dists[mask] = 999

        num_clashes_L = (dists.min(axis=-1) < self.clash_threshold).astype(
            float
        )  # B, L
        o["frac_clashing"] = float(num_clashes_L.mean(-1).mean())
        o["n_clashing"] = float(num_clashes_L.sum(-1).mean())

        if "is_backbone" in f:
            # Backbone-only clashes. `dists` still carries the 999s written
            # above, so the protein-pair exclusions remain in effect here.
            is_backbone = f["is_backbone"].cpu().numpy()
            mask = np.zeros_like(dists, dtype=bool)
            mask = mask | (tok_idx[:, None] == tok_idx[None, :])[None]
            mask = mask | ~(is_backbone[:, None] & is_backbone[None, :])[None]
            dists[mask] = 999
            o["frac_backbone_clashing"] = float(
                (dists.min(axis=-1) < self.clash_threshold)
                .astype(float)
                .mean(-1)
                .mean()
            )
            o["n_backbone_clashing"] = float(
                (dists.min(axis=-1) < self.clash_threshold).astype(float).sum(-1).mean()
            )

        # We do this after clash detection, since that should consider both chains
        if self.compute_for_diffused_region_only:
            diffused_region = ~(f["is_motif_atom_with_fixed_coord"].cpu().numpy())
            xyz = xyz[:, diffused_region]
            tok_idx = tok_idx[diffused_region]

        # Floating atoms: atoms whose nearest neighbour is further than
        # float_threshold.
        dists = np.linalg.norm(
            xyz[..., :, None, :] - xyz[..., None, :, :], axis=-1
        )  # N_atoms x N_atoms
        # BUGFIX: the original used `mask = mask & eye` on an all-False mask,
        # which left the mask all-False; self-distances (0) then survived, so
        # every atom's min distance was 0 and frac_floating was always 0.
        # Self-pairs must be OR-ed into the mask (as done for clashes above).
        mask = np.zeros_like(dists, dtype=bool)
        mask = mask | np.eye(dists.shape[-1], dtype=bool)[None]
        dists[mask] = 999

        is_floating = dists.min(axis=-1) > self.float_threshold
        o["frac_floating"] = float(is_floating.mean(-1).mean())

        if "is_ca" in f:
            # Chain-break statistics from consecutive protein CA atoms
            is_ca = f["is_ca"].cpu().numpy()
            if self.compute_for_diffused_region_only:
                is_ca = is_ca[diffused_region]
                is_protein = is_protein[diffused_region]
            idx_mask = is_ca & is_protein
            if self.compute_for_diffused_region_only:
                xyz = X_L.cpu()[:, diffused_region][:, idx_mask]
            else:
                xyz = X_L.cpu()[:, idx_mask]

            ca_dists = torch.norm(xyz[:, 1:] - xyz[:, :-1], dim=-1)
            deviation = torch.abs(ca_dists - self.standard_ca_dist)  # B, (I-1)
            is_chainbreak = deviation > 0.75

            o["max_ca_deviation"] = float(deviation.max(-1).values.mean())
            o["fraction_chainbreaks"] = float(is_chainbreak.float().mean(-1).mean())
            o["n_chainbreaks"] = float(is_chainbreak.float().sum(-1).mean())

        return o
|
|
380
|
+
|
|
381
|
+
|
|
382
|
+
class PPIMetrics(Metric):
    """PPI-specific metrics.

    Reports the fraction of annotated hotspot atoms contacted by the design,
    averaged over the predicted atom-array stack.
    """

    def __init__(self, distance_cutoff: float = 4.5):
        """
        Args:
            distance_cutoff: Max distance (Å) at which a hotspot counts as
                contacted.
        """
        super().__init__()
        self.distance_cutoff = distance_cutoff  # Distance cutoff for hotspot contacts

    @property
    def kwargs_to_compute_args(self) -> dict[str, Any]:
        # FIX: "predicted_atom_array_stack" was a bare parenthesized string
        # (missing trailing comma). Use a 1-tuple, consistent with the other
        # Metric subclasses in this module.
        return {
            "atom_array_stack": ("predicted_atom_array_stack",),
            # "ppi_hotspots_mask": ("network_input", "f", "is_atom_level_hotspot"),
        }

    def compute(self, atom_array_stack):
        """Average hotspot-contact fraction across the stack.

        Arrays without any hotspot annotation are skipped; returns {} when no
        array has hotspots.
        """
        # Get the number of hotspots for which a diffused atom is within the
        # distance cutoff
        metrics_dict = {"fraction_hotspots_contacted": []}
        for atom_array in atom_array_stack:
            ppi_hotspots_mask = atom_array.get_annotation(
                "is_atom_level_hotspot"
            ).astype(bool)
            if ppi_hotspots_mask.sum() == 0:
                continue

            fraction_contacted = get_hotspot_contacts(
                atom_array,
                hotspot_mask=ppi_hotspots_mask,
                distance_cutoff=self.distance_cutoff,
            )

            metrics_dict["fraction_hotspots_contacted"].append(fraction_contacted)

        fraction_contacted_array = np.array(metrics_dict["fraction_hotspots_contacted"])

        if fraction_contacted_array.size == 0:
            return {}

        return {"fraction_hotspots_contacted": float(np.mean(fraction_contacted_array))}
|
|
420
|
+
|
|
421
|
+
|
|
422
|
+
class SequenceMetrics(Metric):
    """Sequence-recovery and soft confusion-matrix metrics for the sequence
    head."""

    @property
    def kwargs_to_compute_args(self) -> dict[str, Any]:
        return {
            "S_I": ("network_output", "sequence_logits_I"),  # [D, I, K]
            "S_gt_I": ("extra_info", "seq_token_lvl"),  # [D, I]
        }

    def compute(self, S_I, S_gt_I):
        """
        Args:
            S_I: Sequence logits, shape (B, I, C).
            S_gt_I: Ground-truth class indices, shape (I,), broadcast over B.

        Returns:
            Dict with mean argmax recovery plus the soft confusion matrix
            (softmax probabilities accumulated against one-hot ground truth)
            flattened to scalar ``confusion_matrix_{i}_{j}`` entries.
        """
        o = {}
        seq_head_pred = S_I.argmax(dim=-1)  # B, I
        seq_head_recovery = seq_head_pred == S_gt_I

        # Mean recovery over all positions and batch elements.
        # NOTE(review): the original comment claimed unresolved residues were
        # filtered here, but no masking is applied.
        seq_head_recovery = seq_head_recovery.float().mean()
        o["seq_head_recovery"] = float(seq_head_recovery.mean())

        # Calculate the confusion matrix
        seq_head_gt = S_gt_I[None].expand(seq_head_pred.shape[0], -1)  # B, I

        # Softmax over classes gives a soft prediction distribution
        seq_head_pred = S_I.clone()
        seq_head_pred = torch.nn.functional.softmax(seq_head_pred, dim=-1)  # (B, I, C)

        # One-hot encode the ground truth
        seq_head_gt = torch.nn.functional.one_hot(
            seq_head_gt, num_classes=S_I.shape[-1]
        ).float()  # (B, I, C)

        # Permute predictions to shape (B, C, I) for matmul
        seq_head_pred = seq_head_pred.permute(0, 2, 1)  # (B, C, I)

        # Compute confusion matrix per batch (B, C, C)
        confusion_matrix = torch.matmul(seq_head_pred, seq_head_gt)

        # Sum over batch to get (C, C)
        confusion_matrix = confusion_matrix.sum(dim=0)
        confusion_matrix = confusion_matrix.cpu().numpy().astype(np.float32)

        for i in range(confusion_matrix.shape[0]):
            for j in range(confusion_matrix.shape[1]):
                # FIX: cast to Python float (was np.float32) so the output is
                # JSON-saveable, consistent with every other metric in this
                # module (see convert_to_float_or_str).
                o[f"confusion_matrix_{i}_{j}"] = float(confusion_matrix[i, j])

        return o
|