rc_foundry-0.1.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- foundry/__init__.py +57 -0
- foundry/callbacks/__init__.py +5 -0
- foundry/callbacks/callback.py +116 -0
- foundry/callbacks/health_logging.py +419 -0
- foundry/callbacks/metrics_logging.py +211 -0
- foundry/callbacks/timing_logging.py +67 -0
- foundry/callbacks/train_logging.py +278 -0
- foundry/common.py +108 -0
- foundry/constants.py +28 -0
- foundry/hydra/resolvers.py +77 -0
- foundry/inference_engines/base.py +235 -0
- foundry/inference_engines/checkpoint_registry.py +66 -0
- foundry/metrics/__init__.py +12 -0
- foundry/metrics/losses.py +30 -0
- foundry/metrics/metric.py +319 -0
- foundry/model/layers/blocks.py +47 -0
- foundry/testing/__init__.py +6 -0
- foundry/testing/fixtures.py +19 -0
- foundry/testing/pytest_hooks.py +15 -0
- foundry/trainers/fabric.py +923 -0
- foundry/training/EMA.py +67 -0
- foundry/training/checkpoint.py +61 -0
- foundry/training/schedulers.py +91 -0
- foundry/utils/alignment.py +86 -0
- foundry/utils/components.py +415 -0
- foundry/utils/datasets.py +405 -0
- foundry/utils/ddp.py +103 -0
- foundry/utils/instantiators.py +72 -0
- foundry/utils/logging.py +279 -0
- foundry/utils/rigid.py +1460 -0
- foundry/utils/rotation_augmentation.py +65 -0
- foundry/utils/squashfs.py +172 -0
- foundry/utils/torch.py +317 -0
- foundry/utils/weights.py +271 -0
- foundry/version.py +34 -0
- foundry_cli/__init__.py +3 -0
- foundry_cli/download_checkpoints.py +281 -0
- mpnn/__init__.py +1 -0
- mpnn/collate/feature_collator.py +265 -0
- mpnn/inference.py +53 -0
- mpnn/inference_engines/mpnn.py +549 -0
- mpnn/loss/nll_loss.py +122 -0
- mpnn/metrics/nll.py +369 -0
- mpnn/metrics/sequence_recovery.py +440 -0
- mpnn/model/layers/graph_embeddings.py +2372 -0
- mpnn/model/layers/message_passing.py +332 -0
- mpnn/model/layers/position_wise_feed_forward.py +44 -0
- mpnn/model/layers/positional_encoding.py +98 -0
- mpnn/model/mpnn.py +2632 -0
- mpnn/pipelines/mpnn.py +162 -0
- mpnn/samplers/samplers.py +167 -0
- mpnn/train.py +341 -0
- mpnn/trainers/mpnn.py +193 -0
- mpnn/transforms/feature_aggregation/mpnn.py +184 -0
- mpnn/transforms/feature_aggregation/polymer_ligand_interface.py +76 -0
- mpnn/transforms/feature_aggregation/token_encodings.py +132 -0
- mpnn/transforms/feature_aggregation/user_settings.py +347 -0
- mpnn/transforms/polymer_ligand_interface.py +164 -0
- mpnn/utils/inference.py +2397 -0
- mpnn/utils/probability.py +37 -0
- mpnn/utils/weights.py +309 -0
- rc_foundry-0.1.1.dist-info/METADATA +239 -0
- rc_foundry-0.1.1.dist-info/RECORD +180 -0
- rc_foundry-0.1.1.dist-info/WHEEL +4 -0
- rc_foundry-0.1.1.dist-info/entry_points.txt +5 -0
- rc_foundry-0.1.1.dist-info/licenses/LICENSE.md +28 -0
- rf3/__init__.py +3 -0
- rf3/_version.py +33 -0
- rf3/alignment.py +79 -0
- rf3/callbacks/dump_validation_structures.py +101 -0
- rf3/callbacks/metrics_logging.py +324 -0
- rf3/chemical.py +1529 -0
- rf3/cli.py +77 -0
- rf3/data/cyclic_transform.py +78 -0
- rf3/data/extra_xforms.py +36 -0
- rf3/data/ground_truth_template.py +463 -0
- rf3/data/paired_msa.py +206 -0
- rf3/data/pipeline_utils.py +128 -0
- rf3/data/pipelines.py +558 -0
- rf3/diffusion_samplers/inference_sampler.py +222 -0
- rf3/inference.py +65 -0
- rf3/inference_engines/__init__.py +5 -0
- rf3/inference_engines/rf3.py +735 -0
- rf3/kinematics.py +354 -0
- rf3/loss/af3_confidence_loss.py +515 -0
- rf3/loss/af3_losses.py +655 -0
- rf3/loss/loss.py +179 -0
- rf3/metrics/chiral.py +179 -0
- rf3/metrics/clashing_chains.py +68 -0
- rf3/metrics/distogram.py +421 -0
- rf3/metrics/lddt.py +523 -0
- rf3/metrics/metadata.py +43 -0
- rf3/metrics/metric_utils.py +192 -0
- rf3/metrics/predicted_error.py +134 -0
- rf3/metrics/rasa.py +108 -0
- rf3/metrics/selected_distances.py +91 -0
- rf3/model/RF3.py +527 -0
- rf3/model/RF3_blocks.py +92 -0
- rf3/model/RF3_structure.py +303 -0
- rf3/model/layers/af3_auxiliary_heads.py +255 -0
- rf3/model/layers/af3_diffusion_transformer.py +544 -0
- rf3/model/layers/attention.py +313 -0
- rf3/model/layers/layer_utils.py +127 -0
- rf3/model/layers/mlff.py +118 -0
- rf3/model/layers/outer_product.py +59 -0
- rf3/model/layers/pairformer_layers.py +783 -0
- rf3/model/layers/structure_bias.py +56 -0
- rf3/scoring.py +1787 -0
- rf3/symmetry/resolve.py +284 -0
- rf3/train.py +194 -0
- rf3/trainers/rf3.py +570 -0
- rf3/util_module.py +47 -0
- rf3/utils/frames.py +109 -0
- rf3/utils/inference.py +665 -0
- rf3/utils/io.py +198 -0
- rf3/utils/loss.py +72 -0
- rf3/utils/predict_and_score.py +165 -0
- rf3/utils/predicted_error.py +673 -0
- rf3/utils/recycling.py +42 -0
- rf3/validate.py +140 -0
- rfd3/.gitignore +7 -0
- rfd3/Makefile +76 -0
- rfd3/__init__.py +12 -0
- rfd3/callbacks.py +66 -0
- rfd3/cli.py +41 -0
- rfd3/constants.py +212 -0
- rfd3/engine.py +543 -0
- rfd3/inference/datasets.py +193 -0
- rfd3/inference/input_parsing.py +1123 -0
- rfd3/inference/legacy_input_parsing.py +717 -0
- rfd3/inference/parsing.py +165 -0
- rfd3/inference/symmetry/atom_array.py +298 -0
- rfd3/inference/symmetry/checks.py +241 -0
- rfd3/inference/symmetry/contigs.py +63 -0
- rfd3/inference/symmetry/frames.py +355 -0
- rfd3/inference/symmetry/symmetry_utils.py +398 -0
- rfd3/metrics/design_metrics.py +465 -0
- rfd3/metrics/hbonds_hbplus_metrics.py +308 -0
- rfd3/metrics/hbonds_metrics.py +389 -0
- rfd3/metrics/losses.py +325 -0
- rfd3/metrics/metrics_utils.py +118 -0
- rfd3/metrics/sidechain_metrics.py +349 -0
- rfd3/model/RFD3.py +105 -0
- rfd3/model/RFD3_diffusion_module.py +387 -0
- rfd3/model/cfg_utils.py +81 -0
- rfd3/model/inference_sampler.py +635 -0
- rfd3/model/layers/attention.py +577 -0
- rfd3/model/layers/block_utils.py +580 -0
- rfd3/model/layers/blocks.py +777 -0
- rfd3/model/layers/chunked_pairwise.py +377 -0
- rfd3/model/layers/encoders.py +417 -0
- rfd3/model/layers/layer_utils.py +197 -0
- rfd3/model/layers/pairformer_layers.py +128 -0
- rfd3/run_inference.py +45 -0
- rfd3/testing/debug.py +139 -0
- rfd3/testing/debug_utils.py +73 -0
- rfd3/testing/testing_utils.py +356 -0
- rfd3/train.py +194 -0
- rfd3/trainer/dump_validation_structures.py +154 -0
- rfd3/trainer/fabric_trainer.py +923 -0
- rfd3/trainer/recycling.py +42 -0
- rfd3/trainer/rfd3.py +485 -0
- rfd3/trainer/trainer_utils.py +502 -0
- rfd3/transforms/conditioning_base.py +508 -0
- rfd3/transforms/conditioning_utils.py +200 -0
- rfd3/transforms/design_transforms.py +807 -0
- rfd3/transforms/dna_crop.py +523 -0
- rfd3/transforms/hbonds.py +407 -0
- rfd3/transforms/hbonds_hbplus.py +246 -0
- rfd3/transforms/ncaa_transforms.py +153 -0
- rfd3/transforms/pipelines.py +632 -0
- rfd3/transforms/ppi_transforms.py +541 -0
- rfd3/transforms/rasa.py +116 -0
- rfd3/transforms/symmetry.py +76 -0
- rfd3/transforms/training_conditions.py +552 -0
- rfd3/transforms/util_transforms.py +498 -0
- rfd3/transforms/virtual_atoms.py +305 -0
- rfd3/utils/inference.py +648 -0
- rfd3/utils/io.py +245 -0
- rfd3/utils/vizualize.py +276 -0
rf3/metrics/lddt.py
ADDED
@@ -0,0 +1,523 @@
import numpy as np
import torch
from atomworks.io.transforms.atom_array import ensure_atom_array_stack
from atomworks.ml.transforms.atom_array import AddGlobalTokenIdAnnotation
from atomworks.ml.transforms.atomize import AtomizeByCCDName
from atomworks.ml.transforms.base import Compose
from atomworks.ml.utils.token import get_token_starts
from beartype.typing import Any
from biotite.structure import AtomArray, AtomArrayStack, stack
from jaxtyping import Bool, Float, Int

from foundry.metrics.metric import Metric
from foundry.utils.ddp import RankedLogger

ranked_logger = RankedLogger(__name__, rank_zero_only=True)

def calc_lddt(
    X_L: Float[torch.Tensor, "D L 3"],
    X_gt_L: Float[torch.Tensor, "D L 3"],
    crd_mask_L: Bool[torch.Tensor, "D L"],
    tok_idx: Int[torch.Tensor, "L"],
    pairs_to_score: Bool[torch.Tensor, "L L"] | None = None,
    distance_cutoff: float = 15.0,
    eps: float = 1e-6,
) -> Float[torch.Tensor, "D"]:
    """Calculates LDDT scores for each model in the batch.

    Args:
        X_L: Predicted coordinates (D, L, 3).
        X_gt_L: Ground truth coordinates (D, L, 3).
        crd_mask_L: Coordinate mask indicating valid atoms (D, L).
        tok_idx: Token index of each atom (L,). Used to exclude same-token pairs.
        pairs_to_score: Boolean mask for pairs to score (L, L). If None, scores all valid pairs.
        distance_cutoff: Distance cutoff for scoring pairs.
        eps: Small epsilon to prevent division by zero.

    Returns:
        LDDT scores for each model (D,).
    """
    D, L = X_L.shape[:2]

    # Create pairs to score mask - if not provided, use upper triangular (includes diagonal)
    if pairs_to_score is None:
        pairs_to_score = torch.ones((L, L), dtype=torch.bool).triu(0).to(X_L.device)
    else:
        assert pairs_to_score.shape == (L, L)
        pairs_to_score = pairs_to_score.triu(0).to(X_L.device)

    # Get indices of atom pairs to evaluate
    first_index: Int[torch.Tensor, "n_pairs"]
    second_index: Int[torch.Tensor, "n_pairs"]
    first_index, second_index = torch.nonzero(pairs_to_score, as_tuple=True)

    # Compute LDDT score for each model in the batch
    lddt_scores = []
    for d in range(D):
        # Calculate pairwise distances in ground truth structure
        ground_truth_distances = torch.linalg.norm(
            X_gt_L[d, first_index] - X_gt_L[d, second_index], dim=-1
        )

        # Create mask for valid pairs to score:
        # 1. Ground truth distance > 0 (atoms not at same position)
        # 2. Ground truth distance < cutoff (within interaction range)
        pair_mask = torch.logical_and(
            ground_truth_distances > 0, ground_truth_distances < distance_cutoff
        )

        # Only score pairs that are resolved in the ground truth
        pair_mask *= crd_mask_L[d, first_index] * crd_mask_L[d, second_index]

        # Don't score pairs that are in the same token (e.g., same residue)
        pair_mask *= tok_idx[first_index] != tok_idx[second_index]

        # Filter to only "valid" pairs
        valid_pairs = pair_mask.nonzero(as_tuple=True)

        pair_mask_valid = pair_mask[valid_pairs].to(X_L.dtype)
        ground_truth_distances_valid = ground_truth_distances[valid_pairs]

        first_index_valid: Int[torch.Tensor, "n_valid_pairs"] = first_index[valid_pairs]
        second_index_valid: Int[torch.Tensor, "n_valid_pairs"] = second_index[
            valid_pairs
        ]

        # Calculate pairwise distances in predicted structure
        predicted_distances = torch.linalg.norm(
            X_L[d, first_index_valid] - X_L[d, second_index_valid], dim=-1
        )

        # Compute absolute distance differences (with small eps to avoid numerical issues)
        delta_distances = torch.abs(
            predicted_distances - ground_truth_distances_valid + eps
        )
        del predicted_distances, ground_truth_distances_valid

        # Calculate LDDT score using standard thresholds (0.5Å, 1.0Å, 2.0Å, 4.0Å)
        # LDDT is the average fraction of distances preserved within each threshold
        lddt_score = (
            0.25
            * (
                torch.sum((delta_distances < 0.5) * pair_mask_valid)  # 0.5Å threshold
                + torch.sum((delta_distances < 1.0) * pair_mask_valid)  # 1.0Å threshold
                + torch.sum((delta_distances < 2.0) * pair_mask_valid)  # 2.0Å threshold
                + torch.sum((delta_distances < 4.0) * pair_mask_valid)  # 4.0Å threshold
            )
            / (torch.sum(pair_mask_valid) + eps)  # Normalize by number of valid pairs
        )

        lddt_scores.append(lddt_score)

    return torch.tensor(lddt_scores, device=X_L.device)

def extract_lddt_features_from_atom_arrays(
    predicted_atom_array_stack: AtomArrayStack | AtomArray,
    ground_truth_atom_array_stack: AtomArrayStack | AtomArray,
) -> dict[str, Any]:
    """Extract all features needed for LDDT computation from AtomArrays.

    Args:
        predicted_atom_array_stack: Predicted coordinates as AtomArray(Stack)
        ground_truth_atom_array_stack: Ground truth coordinates as AtomArray(Stack)

    Returns:
        Dictionary containing:
            - X_L: Predicted coordinates tensor (D, L, 3)
            - X_gt_L: Ground truth coordinates tensor (D, L, 3)
            - crd_mask_L: Coordinate validity mask (D, L)
            - tok_idx: Token indices for each atom (L,)
            - chain_iid_token_lvl: Chain identification at token level
    """
    predicted_atom_array_stack = ensure_atom_array_stack(predicted_atom_array_stack)
    ground_truth_atom_array_stack = ensure_atom_array_stack(
        ground_truth_atom_array_stack
    )

    if (
        ground_truth_atom_array_stack.stack_depth() == 1
        and predicted_atom_array_stack.stack_depth() > 1
    ):
        # If the ground truth is a single model and the predicted is a stack, expand the
        # ground truth to the same stack depth as the predicted
        ground_truth_atom_array_stack = stack(
            [ground_truth_atom_array_stack[0]]
            * predicted_atom_array_stack.stack_depth()
        )

    # Compute coordinates - convert AtomArrays to tensors
    X_L: Float[torch.Tensor, "D L 3"] = torch.from_numpy(
        predicted_atom_array_stack.coord
    ).float()
    X_gt_L: Float[torch.Tensor, "D L 3"] = torch.from_numpy(
        ground_truth_atom_array_stack.coord
    ).float()

    # For the remaining feature generation, we can directly use the first model in the stack (only coordinates differ)
    ground_truth_atom_array = ground_truth_atom_array_stack[0]

    # Create coordinate mask using occupancy if available, fallback to coordinate validity
    if "occupancy" in ground_truth_atom_array.get_annotation_categories():
        # Use the occupancy annotation if present (occupancy > 0 means the atom is present),
        # broadcast to all models in the stack
        occupancy_mask = ground_truth_atom_array.occupancy > 0
        crd_mask_L: Bool[torch.Tensor, "D L"] = (
            torch.from_numpy(occupancy_mask)
            .bool()
            .unsqueeze(0)
            .expand(X_gt_L.shape[0], -1)
        )
    else:
        # Fallback to coordinate validity (not NaN)
        crd_mask_L: Bool[torch.Tensor, "D L"] = ~torch.isnan(X_gt_L).any(dim=-1)

    # Get token indices using the same logic as ComputeAtomToTokenMap
    if "token_id" in ground_truth_atom_array.get_annotation_categories():
        # Use the existing token_id annotation (matches ComputeAtomToTokenMap exactly)
        tok_idx = ground_truth_atom_array.token_id.astype(np.int32)
    else:
        # Generate annotations with a Transform pipeline
        pipe = Compose(
            [AtomizeByCCDName(atomize_by_default=True), AddGlobalTokenIdAnnotation()]
        )
        data = pipe({"atom_array": ground_truth_atom_array})
        tok_idx = data["atom_array"].token_id.astype(np.int32)

    # Compute chain identification at the token level
    token_starts = get_token_starts(ground_truth_atom_array)

    if "chain_iid" in ground_truth_atom_array.get_annotation_categories():
        chain_iid_token_lvl = ground_truth_atom_array.chain_iid[token_starts]
    else:
        # Use the chain_id annotation instead (e.g., for AF-3 outputs, where the chain_id is ostensibly the chain_iid)
        chain_iid_token_lvl = ground_truth_atom_array.chain_id[token_starts]

    return {
        "X_L": X_L,
        "X_gt_L": X_gt_L,
        "crd_mask_L": crd_mask_L,
        "tok_idx": tok_idx,
        "chain_iid_token_lvl": chain_iid_token_lvl,
    }

class AllAtomLDDT(Metric):
    """Computes all-atom LDDT scores from AtomArrays."""

    def __init__(self, log_lddt_for_every_batch: bool = False, **kwargs):
        super().__init__(**kwargs)
        self.log_lddt_for_every_batch = log_lddt_for_every_batch

    @property
    def kwargs_to_compute_args(self) -> dict[str, Any]:
        return {
            "predicted_atom_array_stack": "predicted_atom_array_stack",
            "ground_truth_atom_array_stack": "ground_truth_atom_array_stack",
        }

    def compute(
        self,
        predicted_atom_array_stack: AtomArrayStack | AtomArray,
        ground_truth_atom_array_stack: AtomArrayStack | AtomArray,
    ) -> dict[str, Any]:
        """Calculates all-atom LDDT between all pairs of atoms.

        Args:
            predicted_atom_array_stack: Predicted coordinates as AtomArray(Stack)
            ground_truth_atom_array_stack: Ground truth coordinates as AtomArray(Stack)

        Returns:
            A dictionary with all-atom LDDT scores:
                - best_of_1_lddt: LDDT score for the first model
                - best_of_{N}_lddt: Best LDDT score across all N models
        """
        lddt_features = extract_lddt_features_from_atom_arrays(
            predicted_atom_array_stack, ground_truth_atom_array_stack
        )
        tok_idx = torch.tensor(lddt_features["tok_idx"]).to(lddt_features["X_L"].device)

        all_atom_lddt = calc_lddt(
            X_L=lddt_features["X_L"],
            X_gt_L=lddt_features["X_gt_L"],
            crd_mask_L=lddt_features["crd_mask_L"],
            tok_idx=tok_idx,
            pairs_to_score=None,  # By default, score all pairs, except those within the same token
            distance_cutoff=15.0,
        )

        result = {
            "best_of_1_lddt": all_atom_lddt[0].item(),
            f"best_of_{len(all_atom_lddt)}_lddt": all_atom_lddt.max().item(),
        }

        if self.log_lddt_for_every_batch:
            lddt_by_batch = {
                f"all_atom_lddt_{i}": all_atom_lddt[i].item()
                for i in range(len(all_atom_lddt))
            }
            result.update(lddt_by_batch)

        return result

class InterfaceLDDTByType(Metric):
    """Computes interface LDDT, grouped by interface type."""

    def __init__(self, log_lddt_for_every_batch: bool = False, **kwargs):
        super().__init__(**kwargs)
        self.log_lddt_for_every_batch = log_lddt_for_every_batch

    @property
    def kwargs_to_compute_args(self) -> dict[str, Any]:
        return {
            "predicted_atom_array_stack": "predicted_atom_array_stack",
            "ground_truth_atom_array_stack": "ground_truth_atom_array_stack",
            "interfaces_to_score": ("extra_info", "interfaces_to_score"),
        }

    def compute(
        self,
        predicted_atom_array_stack: AtomArrayStack | AtomArray,
        ground_truth_atom_array_stack: AtomArrayStack | AtomArray,
        interfaces_to_score: list = None,
        **kwargs,
    ) -> list[dict[str, Any]]:
        """Calculates interface LDDT between specific pairs of chains/units, grouped by interface type.

        Args:
            predicted_atom_array_stack: Predicted coordinates as AtomArray(Stack)
            ground_truth_atom_array_stack: Ground truth coordinates as AtomArray(Stack)
            interfaces_to_score: List of interface specifications, each as
                (pn_unit_i, pn_unit_j, interface_type)

        Returns:
            List of dictionaries containing interface LDDT results for each interface.
        """
        lddt_features = extract_lddt_features_from_atom_arrays(
            predicted_atom_array_stack, ground_truth_atom_array_stack
        )

        # Short-circuit if no interfaces to score
        if not interfaces_to_score:
            return []

        interface_results = []

        # Parse string inputs (for backwards compatibility)
        if isinstance(interfaces_to_score, str):
            interfaces_to_score = (
                eval(interfaces_to_score) if interfaces_to_score else []
            )

        # Loop over the interfaces to score
        for pn_unit_i, pn_unit_j, interface_type in interfaces_to_score:
            # Get tokens in pn_unit_i and pn_unit_j
            pn_unit_i_tokens = lddt_features["chain_iid_token_lvl"] == pn_unit_i
            pn_unit_j_tokens = lddt_features["chain_iid_token_lvl"] == pn_unit_j

            if pn_unit_i_tokens.sum() == 0 or pn_unit_j_tokens.sum() == 0:
                ranked_logger.warning(
                    f"No atoms found for {pn_unit_i} or {pn_unit_j}! Available chains: {np.unique(lddt_features['chain_iid_token_lvl']).tolist()}"
                )
                continue

            # Convert the token level to the atom level
            pn_unit_i_atoms = pn_unit_i_tokens[lddt_features["tok_idx"]]
            pn_unit_j_atoms = pn_unit_j_tokens[lddt_features["tok_idx"]]

            # Compute the outer product of chain_i and chain_j, which represents the interface
            chain_ij_atoms = torch.einsum(
                "L, K -> LK",
                torch.tensor(pn_unit_i_atoms),
                torch.tensor(pn_unit_j_atoms),
            ).to(lddt_features["X_L"].device)

            # Symmetrize the interface so we can later multiply with an upper triangular without losing information
            chain_ij_atoms = chain_ij_atoms | chain_ij_atoms.T

            # Compute LDDT using the interface mask as the pairs_to_score
            lddt = calc_lddt(
                lddt_features["X_L"],
                lddt_features["X_gt_L"],
                lddt_features["crd_mask_L"],
                torch.tensor(lddt_features["tok_idx"]).to(lddt_features["X_L"].device),
                pairs_to_score=chain_ij_atoms,
                distance_cutoff=30.0,
            )

            # Add the results to the interface_results list
            n = len(lddt)
            result = {
                "pn_units": [pn_unit_i, pn_unit_j],
                "type": interface_type,
                "best_of_1_lddt": lddt[0].item(),
                f"best_of_{n}_lddt": lddt.max().item(),
            }

            if self.log_lddt_for_every_batch:
                lddt_by_batch = {f"lddt_{i}": lddt[i].item() for i in range(len(lddt))}
                result.update(lddt_by_batch)

            interface_results.append(result)

        return interface_results

class ChainLDDTByType(Metric):
    """Computes chain-wise LDDT scores from AtomArrays, grouped by chain type."""

    def __init__(self, log_lddt_for_every_batch: bool = False, **kwargs):
        super().__init__(**kwargs)
        self.log_lddt_for_every_batch = log_lddt_for_every_batch

    @property
    def kwargs_to_compute_args(self) -> dict[str, Any]:
        return {
            "predicted_atom_array_stack": "predicted_atom_array_stack",
            "ground_truth_atom_array_stack": "ground_truth_atom_array_stack",
            "pn_units_to_score": ("extra_info", "pn_units_to_score"),
        }

    def compute(
        self,
        predicted_atom_array_stack: AtomArrayStack | AtomArray,
        ground_truth_atom_array_stack: AtomArrayStack | AtomArray,
        pn_units_to_score: list = None,
        **kwargs,
    ) -> list[dict[str, Any]]:
        """Calculates intra-chain LDDT for specific chains/units.

        Args:
            predicted_atom_array_stack: Predicted coordinates as AtomArray(Stack)
            ground_truth_atom_array_stack: Ground truth coordinates as AtomArray(Stack)
            pn_units_to_score: List of chain specifications, each as (pn_unit_iid, chain_type)

        Returns:
            List of dictionaries containing chain LDDT results for each chain.
        """
        lddt_features = extract_lddt_features_from_atom_arrays(
            predicted_atom_array_stack, ground_truth_atom_array_stack
        )

        # Short-circuit if no chains to score
        if not pn_units_to_score:
            return []

        chain_results = []

        # Parse string inputs (for backwards compatibility)
        if isinstance(pn_units_to_score, str):
            pn_units_to_score = eval(pn_units_to_score) if pn_units_to_score else []

        # For all chains (pn_units) to score...
        for chain, chain_type in pn_units_to_score:
            # ... get tokens in the chain instance
            chain_tokens = lddt_features["chain_iid_token_lvl"] == chain
            if chain_tokens.sum() == 0:
                ranked_logger.warning(
                    f"No atoms found for {chain}! Available chains: {np.unique(lddt_features['chain_iid_token_lvl']).tolist()}"
                )
                continue

            # ... convert the token level to the atom level
            chain_atoms = chain_tokens[lddt_features["tok_idx"]]

            # ... compute the outer product of the chain with itself (the definition of intra-chain LDDT)
            chain_ij_atoms = torch.einsum(
                "L, K -> LK", torch.tensor(chain_atoms), torch.tensor(chain_atoms)
            ).to(lddt_features["X_L"].device)

            # ... compute LDDT using the chain self-product as the pairs_to_score
            lddt = calc_lddt(
                lddt_features["X_L"],
                lddt_features["X_gt_L"],
                lddt_features["crd_mask_L"],
                torch.tensor(lddt_features["tok_idx"]).to(lddt_features["X_L"].device),
                pairs_to_score=chain_ij_atoms,
            )

            # ... and finally add the results to the chain_results list
            n = len(lddt)
            result = {
                "pn_units": [chain],
                "type": chain_type,
                "best_of_1_lddt": lddt[0].item(),
                f"best_of_{n}_lddt": lddt.max().item(),
            }

            if self.log_lddt_for_every_batch:
                lddt_by_batch = {f"lddt_{i}": lddt[i].item() for i in range(len(lddt))}
                result.update(lddt_by_batch)

            chain_results.append(result)

        return chain_results

class ByTypeLDDT(Metric):
    """Calculates LDDT scores by type for both chains and interfaces."""

    def __init__(self, log_lddt_for_every_batch: bool = True, **kwargs):
        super().__init__(**kwargs)
        self.interface_lddt = InterfaceLDDTByType(
            log_lddt_for_every_batch=log_lddt_for_every_batch, **kwargs
        )
        self.chain_lddt = ChainLDDTByType(
            log_lddt_for_every_batch=log_lddt_for_every_batch, **kwargs
        )

    @property
    def kwargs_to_compute_args(self) -> dict[str, Any]:
        return {
            "predicted_atom_array_stack": "predicted_atom_array_stack",
            "ground_truth_atom_array_stack": "ground_truth_atom_array_stack",
            "interfaces_to_score": ("extra_info", "interfaces_to_score"),
            "pn_units_to_score": ("extra_info", "pn_units_to_score"),
        }

    @property
    def optional_kwargs(self) -> set[str]:
        """Mark interfaces_to_score and pn_units_to_score as optional."""
        return {"interfaces_to_score", "pn_units_to_score"}

    def compute(
        self,
        predicted_atom_array_stack: AtomArrayStack | AtomArray,
        ground_truth_atom_array_stack: AtomArrayStack | AtomArray,
        interfaces_to_score: list[tuple[str, str, str]] | None = None,
        pn_units_to_score: list[tuple[str, str]] | None = None,
    ) -> list[dict[str, Any]]:
        """Calculates LDDT scores by type for both chains and interfaces.

        Args:
            predicted_atom_array_stack: Predicted coordinates as AtomArray(Stack)
            ground_truth_atom_array_stack: Ground truth coordinates as AtomArray(Stack)
            interfaces_to_score: Tuples of (pn_unit_i, pn_unit_j, interface_type)
                representing the interfaces to score
            pn_units_to_score: Tuples of (pn_unit_iid, chain_type)
                representing the chains to score

        Returns:
            Combined list of interface and chain LDDT results.
        """

        # Compute interface LDDT scores
        interface_results = self.interface_lddt.compute(
            predicted_atom_array_stack=predicted_atom_array_stack,
            ground_truth_atom_array_stack=ground_truth_atom_array_stack,
            interfaces_to_score=interfaces_to_score,
        )

        # Compute chain LDDT scores
        chain_results = self.chain_lddt.compute(
            predicted_atom_array_stack=predicted_atom_array_stack,
            ground_truth_atom_array_stack=ground_truth_atom_array_stack,
            pn_units_to_score=pn_units_to_score,
        )

        # Merge the results
        combined_results = interface_results + chain_results

        return combined_results
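
For orientation, here is a minimal, hypothetical smoke test of calc_lddt on synthetic tensors. It is not part of the package: the shapes, random inputs, and direct function call are illustrative assumptions (in the package, the Metric subclasses above are driven by the validation pipeline rather than called by hand).

# Hypothetical smoke test for calc_lddt (not part of rc-foundry).
import torch

from rf3.metrics.lddt import calc_lddt

D, L = 2, 50  # two models, 50 atoms
X_gt = torch.randn(L, 3).expand(D, L, 3).clone()      # ground truth repeated per model
X_pred = X_gt + 0.3 * torch.randn(D, L, 3)            # mildly perturbed predictions
crd_mask = torch.ones(D, L, dtype=torch.bool)         # every atom resolved
tok_idx = torch.arange(L)                             # one token per atom

scores = calc_lddt(X_pred, X_gt, crd_mask, tok_idx)
print(scores.shape)  # torch.Size([2]); values lie in [0, 1], higher is better

With identical predicted and ground-truth coordinates the score approaches 1.0, and it falls as predicted pairwise distances drift from the ground truth beyond the 0.5/1/2/4 Å thresholds.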
rf3/metrics/metadata.py
ADDED
@@ -0,0 +1,43 @@
import json

from beartype.typing import Any, Literal

from foundry.metrics.metric import Metric


class ExtraInfo(Metric):
    """Stores the extra_info from the dataloader output in the metrics dictionary.
    Only basic Python types that are hashable and can be JSON serialized are stored."""

    def __init__(self, keys_to_store: list[str] | Literal["all"] = "all", **kwargs):
        super().__init__(**kwargs)
        self.keys_to_store = keys_to_store

    @property
    def kwargs_to_compute_args(self) -> dict[str, Any]:
        return {"extra_info": "extra_info"}

    def _is_basic_hashable_type(self, value: Any) -> bool:
        """Check if value is a basic Python type that is both JSON serializable and hashable."""
        try:
            # First check if it's hashable
            hash(value)

            # Then check if it's JSON serializable
            json.dumps(value)
            return True
        except (TypeError, OverflowError):
            return False

    def compute(
        self,
        extra_info: dict,
    ) -> dict[str, Any]:
        result = {}
        for key, value in extra_info.items():
            # Check if we should include this key
            if self.keys_to_store == "all" or key in self.keys_to_store:
                # Check if the value is a basic hashable type
                if self._is_basic_hashable_type(value):
                    result[key] = value
        return result
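
Similarly, a small hypothetical example of how ExtraInfo filters a dataloader's extra_info dict. This is not part of the package and assumes Metric.__init__ can be constructed without additional required arguments.

# Hypothetical ExtraInfo usage (not part of rc-foundry).
from rf3.metrics.metadata import ExtraInfo

metric = ExtraInfo(keys_to_store=["pdb_id", "resolution", "atom_array"])
extra_info = {
    "pdb_id": "1ABC",        # kept: hashable and JSON-serializable string
    "resolution": 2.1,       # kept: float
    "atom_array": object(),  # requested, but dropped: not JSON-serializable
    "notes": ["a", "b"],     # dropped: key not in keys_to_store
}
print(metric.compute(extra_info))  # {'pdb_id': '1ABC', 'resolution': 2.1}

Only requested keys whose values survive both the hash() and json.dumps() checks end up in the returned metrics dictionary.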