rc_foundry-0.1.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- foundry/__init__.py +57 -0
- foundry/callbacks/__init__.py +5 -0
- foundry/callbacks/callback.py +116 -0
- foundry/callbacks/health_logging.py +419 -0
- foundry/callbacks/metrics_logging.py +211 -0
- foundry/callbacks/timing_logging.py +67 -0
- foundry/callbacks/train_logging.py +278 -0
- foundry/common.py +108 -0
- foundry/constants.py +28 -0
- foundry/hydra/resolvers.py +77 -0
- foundry/inference_engines/base.py +235 -0
- foundry/inference_engines/checkpoint_registry.py +66 -0
- foundry/metrics/__init__.py +12 -0
- foundry/metrics/losses.py +30 -0
- foundry/metrics/metric.py +319 -0
- foundry/model/layers/blocks.py +47 -0
- foundry/testing/__init__.py +6 -0
- foundry/testing/fixtures.py +19 -0
- foundry/testing/pytest_hooks.py +15 -0
- foundry/trainers/fabric.py +923 -0
- foundry/training/EMA.py +67 -0
- foundry/training/checkpoint.py +61 -0
- foundry/training/schedulers.py +91 -0
- foundry/utils/alignment.py +86 -0
- foundry/utils/components.py +415 -0
- foundry/utils/datasets.py +405 -0
- foundry/utils/ddp.py +103 -0
- foundry/utils/instantiators.py +72 -0
- foundry/utils/logging.py +279 -0
- foundry/utils/rigid.py +1460 -0
- foundry/utils/rotation_augmentation.py +65 -0
- foundry/utils/squashfs.py +172 -0
- foundry/utils/torch.py +317 -0
- foundry/utils/weights.py +271 -0
- foundry/version.py +34 -0
- foundry_cli/__init__.py +3 -0
- foundry_cli/download_checkpoints.py +281 -0
- mpnn/__init__.py +1 -0
- mpnn/collate/feature_collator.py +265 -0
- mpnn/inference.py +53 -0
- mpnn/inference_engines/mpnn.py +549 -0
- mpnn/loss/nll_loss.py +122 -0
- mpnn/metrics/nll.py +369 -0
- mpnn/metrics/sequence_recovery.py +440 -0
- mpnn/model/layers/graph_embeddings.py +2372 -0
- mpnn/model/layers/message_passing.py +332 -0
- mpnn/model/layers/position_wise_feed_forward.py +44 -0
- mpnn/model/layers/positional_encoding.py +98 -0
- mpnn/model/mpnn.py +2632 -0
- mpnn/pipelines/mpnn.py +162 -0
- mpnn/samplers/samplers.py +167 -0
- mpnn/train.py +341 -0
- mpnn/trainers/mpnn.py +193 -0
- mpnn/transforms/feature_aggregation/mpnn.py +184 -0
- mpnn/transforms/feature_aggregation/polymer_ligand_interface.py +76 -0
- mpnn/transforms/feature_aggregation/token_encodings.py +132 -0
- mpnn/transforms/feature_aggregation/user_settings.py +347 -0
- mpnn/transforms/polymer_ligand_interface.py +164 -0
- mpnn/utils/inference.py +2397 -0
- mpnn/utils/probability.py +37 -0
- mpnn/utils/weights.py +309 -0
- rc_foundry-0.1.1.dist-info/METADATA +239 -0
- rc_foundry-0.1.1.dist-info/RECORD +180 -0
- rc_foundry-0.1.1.dist-info/WHEEL +4 -0
- rc_foundry-0.1.1.dist-info/entry_points.txt +5 -0
- rc_foundry-0.1.1.dist-info/licenses/LICENSE.md +28 -0
- rf3/__init__.py +3 -0
- rf3/_version.py +33 -0
- rf3/alignment.py +79 -0
- rf3/callbacks/dump_validation_structures.py +101 -0
- rf3/callbacks/metrics_logging.py +324 -0
- rf3/chemical.py +1529 -0
- rf3/cli.py +77 -0
- rf3/data/cyclic_transform.py +78 -0
- rf3/data/extra_xforms.py +36 -0
- rf3/data/ground_truth_template.py +463 -0
- rf3/data/paired_msa.py +206 -0
- rf3/data/pipeline_utils.py +128 -0
- rf3/data/pipelines.py +558 -0
- rf3/diffusion_samplers/inference_sampler.py +222 -0
- rf3/inference.py +65 -0
- rf3/inference_engines/__init__.py +5 -0
- rf3/inference_engines/rf3.py +735 -0
- rf3/kinematics.py +354 -0
- rf3/loss/af3_confidence_loss.py +515 -0
- rf3/loss/af3_losses.py +655 -0
- rf3/loss/loss.py +179 -0
- rf3/metrics/chiral.py +179 -0
- rf3/metrics/clashing_chains.py +68 -0
- rf3/metrics/distogram.py +421 -0
- rf3/metrics/lddt.py +523 -0
- rf3/metrics/metadata.py +43 -0
- rf3/metrics/metric_utils.py +192 -0
- rf3/metrics/predicted_error.py +134 -0
- rf3/metrics/rasa.py +108 -0
- rf3/metrics/selected_distances.py +91 -0
- rf3/model/RF3.py +527 -0
- rf3/model/RF3_blocks.py +92 -0
- rf3/model/RF3_structure.py +303 -0
- rf3/model/layers/af3_auxiliary_heads.py +255 -0
- rf3/model/layers/af3_diffusion_transformer.py +544 -0
- rf3/model/layers/attention.py +313 -0
- rf3/model/layers/layer_utils.py +127 -0
- rf3/model/layers/mlff.py +118 -0
- rf3/model/layers/outer_product.py +59 -0
- rf3/model/layers/pairformer_layers.py +783 -0
- rf3/model/layers/structure_bias.py +56 -0
- rf3/scoring.py +1787 -0
- rf3/symmetry/resolve.py +284 -0
- rf3/train.py +194 -0
- rf3/trainers/rf3.py +570 -0
- rf3/util_module.py +47 -0
- rf3/utils/frames.py +109 -0
- rf3/utils/inference.py +665 -0
- rf3/utils/io.py +198 -0
- rf3/utils/loss.py +72 -0
- rf3/utils/predict_and_score.py +165 -0
- rf3/utils/predicted_error.py +673 -0
- rf3/utils/recycling.py +42 -0
- rf3/validate.py +140 -0
- rfd3/.gitignore +7 -0
- rfd3/Makefile +76 -0
- rfd3/__init__.py +12 -0
- rfd3/callbacks.py +66 -0
- rfd3/cli.py +41 -0
- rfd3/constants.py +212 -0
- rfd3/engine.py +543 -0
- rfd3/inference/datasets.py +193 -0
- rfd3/inference/input_parsing.py +1123 -0
- rfd3/inference/legacy_input_parsing.py +717 -0
- rfd3/inference/parsing.py +165 -0
- rfd3/inference/symmetry/atom_array.py +298 -0
- rfd3/inference/symmetry/checks.py +241 -0
- rfd3/inference/symmetry/contigs.py +63 -0
- rfd3/inference/symmetry/frames.py +355 -0
- rfd3/inference/symmetry/symmetry_utils.py +398 -0
- rfd3/metrics/design_metrics.py +465 -0
- rfd3/metrics/hbonds_hbplus_metrics.py +308 -0
- rfd3/metrics/hbonds_metrics.py +389 -0
- rfd3/metrics/losses.py +325 -0
- rfd3/metrics/metrics_utils.py +118 -0
- rfd3/metrics/sidechain_metrics.py +349 -0
- rfd3/model/RFD3.py +105 -0
- rfd3/model/RFD3_diffusion_module.py +387 -0
- rfd3/model/cfg_utils.py +81 -0
- rfd3/model/inference_sampler.py +635 -0
- rfd3/model/layers/attention.py +577 -0
- rfd3/model/layers/block_utils.py +580 -0
- rfd3/model/layers/blocks.py +777 -0
- rfd3/model/layers/chunked_pairwise.py +377 -0
- rfd3/model/layers/encoders.py +417 -0
- rfd3/model/layers/layer_utils.py +197 -0
- rfd3/model/layers/pairformer_layers.py +128 -0
- rfd3/run_inference.py +45 -0
- rfd3/testing/debug.py +139 -0
- rfd3/testing/debug_utils.py +73 -0
- rfd3/testing/testing_utils.py +356 -0
- rfd3/train.py +194 -0
- rfd3/trainer/dump_validation_structures.py +154 -0
- rfd3/trainer/fabric_trainer.py +923 -0
- rfd3/trainer/recycling.py +42 -0
- rfd3/trainer/rfd3.py +485 -0
- rfd3/trainer/trainer_utils.py +502 -0
- rfd3/transforms/conditioning_base.py +508 -0
- rfd3/transforms/conditioning_utils.py +200 -0
- rfd3/transforms/design_transforms.py +807 -0
- rfd3/transforms/dna_crop.py +523 -0
- rfd3/transforms/hbonds.py +407 -0
- rfd3/transforms/hbonds_hbplus.py +246 -0
- rfd3/transforms/ncaa_transforms.py +153 -0
- rfd3/transforms/pipelines.py +632 -0
- rfd3/transforms/ppi_transforms.py +541 -0
- rfd3/transforms/rasa.py +116 -0
- rfd3/transforms/symmetry.py +76 -0
- rfd3/transforms/training_conditions.py +552 -0
- rfd3/transforms/util_transforms.py +498 -0
- rfd3/transforms/virtual_atoms.py +305 -0
- rfd3/utils/inference.py +648 -0
- rfd3/utils/io.py +245 -0
- rfd3/utils/vizualize.py +276 -0
@@ -0,0 +1,440 @@
from atomworks.ml.transforms.base import ConvertToTorch
from mpnn.collate.feature_collator import FeatureCollator
from mpnn.transforms.feature_aggregation.polymer_ligand_interface import (
    FeaturizePolymerLigandInterfaceMask,
)
from mpnn.transforms.polymer_ligand_interface import ComputePolymerLigandInterface

from foundry.metrics.metric import Metric


class SequenceRecovery(Metric):
    """
    Computes sequence recovery accuracy for Protein/Ligand MPNN.

    This metric compares both the sampled predicted sequence and the argmax
    sequence to the ground truth sequence and computes the percentage of
    correctly predicted residues for both versions.
    """

    def __init__(
        self,
        return_per_example_metrics=False,
        return_per_residue_metrics=False,
        **kwargs,
    ):
        """
        Initialize the SequenceRecovery metric.

        Args:
            return_per_example_metrics (bool): If True, returns per-example
                metrics in addition to the aggregate metrics.
            return_per_residue_metrics (bool): If True, returns per-residue
                metrics in addition to the aggregate metrics.
            **kwargs: Additional keyword arguments passed to the base Metric
                class.
        """
        super().__init__(**kwargs)
        self.return_per_example_metrics = return_per_example_metrics
        self.return_per_residue_metrics = return_per_residue_metrics

    @property
    def kwargs_to_compute_args(self):
        """Map input keys to the compute method arguments.

        Returns:
            dict: Mapping from compute method argument names to nested
                dictionary keys in the input kwargs.
        """
        return {
            "S": ("network_input", "input_features", "S"),
            "S_sampled": ("network_output", "decoder_features", "S_sampled"),
            "S_argmax": ("network_output", "decoder_features", "S_argmax"),
            "mask_for_loss": ("network_output", "input_features", "mask_for_loss"),
        }

    def get_per_residue_mask(self, mask_for_loss, **kwargs):
        """
        Get the per-residue mask for computing sequence recovery.

        This method can be overridden by subclasses to apply additional
        masking criteria (e.g., interface residues only).

        Args:
            mask_for_loss (torch.Tensor): [B, L] - mask for loss
            **kwargs: Additional arguments that may be needed by subclasses

        Returns:
            per_residue_mask (torch.Tensor): [B, L] - per-residue mask for
                sequence recovery computation.
        """
        per_residue_mask = mask_for_loss
        return per_residue_mask

    def compute_sequence_recovery_metrics(self, S, S_pred, per_residue_mask):
        """
        Compute sequence recovery metrics using the ground truth sequence,
        the predicted sequence, and the per-residue mask.

        Args:
            S (torch.Tensor): [B, L] - the ground truth sequence.
            S_pred (torch.Tensor): [B, L] - the predicted sequence.
            per_residue_mask (torch.Tensor): [B, L] - per-residue mask for
                computation of sequence recovery.

        Returns:
            sequence_recovery_dict (dict): Dictionary containing the sequence
                recovery metrics.
                - mean_sequence_recovery (torch.Tensor): [1] - mean sequence
                    recovery across valid examples (a valid example is one
                    that has at least one valid residue according to the
                    per_residue_mask).
                - sequence_recovery_per_example (torch.Tensor): [B] - sequence
                    recovery per example, undefined for examples with no
                    valid residues.
                - correct_per_example (torch.Tensor): [B] - total number of
                    correct predictions per example.
                - correct_predictions_per_residue (torch.Tensor): [B, L] -
                    boolean tensor indicating if the predicted sequence matches
                    the ground truth sequence (1 for correct, 0 for incorrect,
                    masked by per_residue_mask).
                - total_valid_per_example (torch.Tensor): [B] - number of
                    valid residues per example.
                - valid_examples_mask (torch.Tensor): [B] - boolean mask
                    indicating examples with valid residues.
                - per_residue_mask (torch.Tensor): [B, L] - per-residue mask
                    used for the sequence recovery computation.
        """
        per_residue_mask = per_residue_mask.float()

        # total_valid_per_example [B] - sum of valid residues per example.
        total_valid_per_example = per_residue_mask.sum(dim=-1)

        # valid_examples_mask [B] - boolean mask indicating examples with
        # valid residues.
        valid_examples_mask = total_valid_per_example > 0

        # Compute sequence recovery accuracy for the predicted residues.
        # correct_predictions_per_residue [B, L] - boolean tensor indicating
        # if the predicted sequence matches the ground truth sequence. Masked
        # by the per_residue_mask.
        correct_predictions_per_residue = (S_pred == S).float() * per_residue_mask

        # correct_per_example [B] - sum of correct predictions per example.
        correct_per_example = correct_predictions_per_residue.sum(dim=-1)

        # sequence_recovery_per_example [B] - compute the sequence recovery
        # (accuracy) per example. Undefined if there are no valid residues.
        sequence_recovery_per_example = correct_per_example / total_valid_per_example

        # mean_sequence_recovery [1] - mean sequence recovery across
        # examples with valid residues.
        mean_sequence_recovery = sequence_recovery_per_example[
            valid_examples_mask
        ].mean()

        # Create the sequence recovery dictionary.
        sequence_recovery_dict = {
            "mean_sequence_recovery": mean_sequence_recovery,
            "sequence_recovery_per_example": sequence_recovery_per_example,
            "correct_per_example": correct_per_example,
            "correct_predictions_per_residue": correct_predictions_per_residue,
            "total_valid_per_example": total_valid_per_example,
            "valid_examples_mask": valid_examples_mask,
            "per_residue_mask": per_residue_mask,
        }

        return sequence_recovery_dict

    def compute(self, S, S_sampled, S_argmax, mask_for_loss, **kwargs):
        """
        Compute sequence recovery accuracy for both sampled and argmax
        sequences.

        This method compares both the sampled predicted sequence and the argmax
        sequence to the ground truth sequence and computes the fraction of
        correctly predicted residues for both versions (i.e. the accuracy).

        A NOTE on shapes:
            B: batch size
            L: sequence length
            vocab_size: vocabulary size

        Args:
            S (torch.Tensor): [B, L] - the ground truth sequence.
            S_sampled (torch.Tensor): [B, L] - the sampled sequence,
                sampled from the probabilities (unknown residues are not
                sampled).
            S_argmax (torch.Tensor): [B, L] - the predicted sequence,
                obtained by taking the argmax of the probabilities
                (unknown residues are not selected).
            mask_for_loss (torch.Tensor): [B, L] - mask for loss,
                where True is a residue that is included in the loss
                calculation, and False is a residue that is not included
                in the loss calculation.
            **kwargs: Additional arguments that may be needed by subclasses.

        Returns:
            metric_dict (dict): Dictionary containing the sequence recovery
                metrics.
                - mean_sequence_recovery_sampled (float): mean sequence
                    recovery for the sampled sequence.
                - mean_sequence_recovery_argmax (float): mean sequence
                    recovery for the argmax sequence.
                if self.return_per_example_metrics is True:
                - sequence_recovery_per_example_sampled (torch.Tensor): [B] -
                    sequence recovery per example for the sampled sequence,
                    undefined for examples with no valid residues.
                - sequence_recovery_per_example_argmax (torch.Tensor): [B] -
                    sequence recovery per example for the argmax sequence,
                    undefined for examples with no valid residues.
                - correct_per_example_sampled (torch.Tensor): [B] - total
                    number of correct predictions per example for the sampled
                    sequence.
                - correct_per_example_argmax (torch.Tensor): [B] - total
                    number of correct predictions per example for the argmax
                    sequence.
                - total_valid_per_example (torch.Tensor): [B] - number of valid
                    residues per example.
                - valid_examples_mask (torch.Tensor): [B] - boolean mask for
                    valid examples.
                if self.return_per_residue_metrics is True:
                - correct_predictions_per_residue_sampled (torch.Tensor):
                    [B, L] - boolean tensor indicating if the sampled
                    sequence matches the ground truth sequence (1 for correct,
                    0 for incorrect, masked by per_residue_mask).
                - correct_predictions_per_residue_argmax (torch.Tensor):
                    [B, L] - boolean tensor indicating if the argmax sequence
                    matches the ground truth sequence (1 for correct, 0 for
                    incorrect, masked by per_residue_mask).
                - per_residue_mask (torch.Tensor): [B, L] - per-residue
                    mask for sequence recovery computation.
        """
        # per_residue_mask [B, L] - mask for sequence recovery.
        per_residue_mask = self.get_per_residue_mask(mask_for_loss, **kwargs)

        # Compute sequence recovery metrics for the sampled sequence.
        sequence_recovery_metrics_sampled = self.compute_sequence_recovery_metrics(
            S, S_sampled, per_residue_mask
        )

        # Compute sequence recovery metrics for the argmax sequence.
        sequence_recovery_metrics_argmax = self.compute_sequence_recovery_metrics(
            S, S_argmax, per_residue_mask
        )

        # Prepare the metric dictionary.
        metric_dict = {
            "mean_sequence_recovery_sampled": sequence_recovery_metrics_sampled[
                "mean_sequence_recovery"
            ]
            .detach()
            .item(),
            "mean_sequence_recovery_argmax": sequence_recovery_metrics_argmax[
                "mean_sequence_recovery"
            ]
            .detach()
            .item(),
        }
        if self.return_per_example_metrics:
            metric_dict.update(
                {
                    "sequence_recovery_per_example_sampled": sequence_recovery_metrics_sampled[
                        "sequence_recovery_per_example"
                    ],
                    "sequence_recovery_per_example_argmax": sequence_recovery_metrics_argmax[
                        "sequence_recovery_per_example"
                    ],
                    "correct_per_example_sampled": sequence_recovery_metrics_sampled[
                        "correct_per_example"
                    ],
                    "correct_per_example_argmax": sequence_recovery_metrics_argmax[
                        "correct_per_example"
                    ],
                    "total_valid_per_example": sequence_recovery_metrics_sampled[
                        "total_valid_per_example"
                    ],
                    "valid_examples_mask": sequence_recovery_metrics_sampled[
                        "valid_examples_mask"
                    ],
                }
            )
        if self.return_per_residue_metrics:
            metric_dict.update(
                {
                    "correct_predictions_per_residue_sampled": sequence_recovery_metrics_sampled[
                        "correct_predictions_per_residue"
                    ],
                    "correct_predictions_per_residue_argmax": sequence_recovery_metrics_argmax[
                        "correct_predictions_per_residue"
                    ],
                    "per_residue_mask": sequence_recovery_metrics_sampled[
                        "per_residue_mask"
                    ],
                }
            )

        return metric_dict


class InterfaceSequenceRecovery(SequenceRecovery):
    """
    Computes sequence recovery accuracy for Protein/Ligand MPNN specifically
    for residues at the polymer-ligand interface.

    This metric inherits from SequenceRecovery but only computes metrics for
    residues that are within a specified distance threshold of ligand atoms.
    All returned metric names are prefixed with "interface_".
    """

    def __init__(
        self,
        interface_distance_threshold: float = 5.0,
        return_per_example_metrics: bool = False,
        return_per_residue_metrics: bool = False,
        **kwargs,
    ):
        """
        Initialize the InterfaceSequenceRecovery metric.

        Args:
            interface_distance_threshold (float): Distance threshold in
                Angstroms for considering residues to be at the interface.
                Defaults to 5.0.
            return_per_example_metrics (bool): If True, returns per-example
                metrics in addition to the aggregate metrics.
            return_per_residue_metrics (bool): If True, returns per-residue
                metrics in addition to the aggregate metrics.
            **kwargs: Additional keyword arguments passed to the base Metric
                class.
        """
        super().__init__(
            return_per_example_metrics=return_per_example_metrics,
            return_per_residue_metrics=return_per_residue_metrics,
            **kwargs,
        )
        self.interface_distance_threshold = interface_distance_threshold

    @property
    def kwargs_to_compute_args(self):
        """Map input keys to the compute method arguments.

        Returns:
            dict: Mapping from compute method argument names to nested
                dictionary keys in the input kwargs.
        """
        args_mapping = super().kwargs_to_compute_args
        # Add atom_array to the mapping for interface computation
        args_mapping["atom_array"] = ("network_input", "atom_array")
        return args_mapping

    def get_per_residue_mask(self, mask_for_loss, **kwargs):
        """
        Get the per-residue mask for computing interface sequence recovery.

        This method computes the interface mask by applying transforms to
        detect polymer-ligand interfaces and combines it with the original
        mask_for_loss using logical AND.

        Args:
            mask_for_loss (torch.Tensor): [B, L] - mask for loss
            **kwargs: Additional arguments including atom_array

        Returns:
            per_residue_mask (torch.Tensor): [B, L] - combined mask for
                interface sequence recovery computation.
        """
        # Extract atom arrays from kwargs
        atom_arrays = kwargs.get("atom_array")
        if atom_arrays is None:
            raise ValueError(
                "atom_array is required for interface "
                "computation but was not found"
            )

        # Initialize transforms
        interface_transform = ComputePolymerLigandInterface(
            distance_threshold=self.interface_distance_threshold
        )
        mask_transform = FeaturizePolymerLigandInterfaceMask()
        convert_to_torch_transform = ConvertToTorch(keys=["input_features"])

        # Process each atom array in the batch
        batch_interface_masks = []
        for atom_array in atom_arrays:
            # Apply interface detection transform
            data = {"atom_array": atom_array}
            data = interface_transform(data)

            # Apply interface mask featurization
            data = mask_transform(data)

            # Convert to torch tensor
            data = convert_to_torch_transform(data)

            # Extract the interface mask
            interface_mask = data["input_features"]["polymer_ligand_interface_mask"]

            # Collect the per-example interface mask
            batch_interface_masks.append(interface_mask)

        # Collate interface masks with proper padding
        collator = FeatureCollator(
            default_padding={"polymer_ligand_interface_mask": False}
        )

        # Create mock pipeline outputs for collation
        mock_outputs = []
        for interface_mask in batch_interface_masks:
            mock_outputs.append(
                {
                    "input_features": {"polymer_ligand_interface_mask": interface_mask},
                    "atom_array": None,  # Not needed for collation
                }
            )

        # Collate the masks
        collated = collator(mock_outputs)
        interface_mask = collated["input_features"]["polymer_ligand_interface_mask"]

        # Convert to the same device and dtype as mask_for_loss
        interface_mask = interface_mask.to(
            device=mask_for_loss.device, dtype=mask_for_loss.dtype
        )

        # Combine with original mask using logical AND
        combined_mask = mask_for_loss & interface_mask

        return combined_mask

    def compute(self, S, S_sampled, S_argmax, mask_for_loss, atom_array, **kwargs):
        """
        Compute interface sequence recovery accuracy for both sampled and
        argmax sequences.

        This method computes sequence recovery specifically for residues at
        the polymer-ligand interface and prefixes all output metrics with
        "interface_".

        Args:
            S (torch.Tensor): [B, L] - the ground truth sequence.
            S_sampled (torch.Tensor): [B, L] - the sampled sequence.
            S_argmax (torch.Tensor): [B, L] - the predicted sequence.
            mask_for_loss (torch.Tensor): [B, L] - mask for loss.
            atom_array (list): per-example atom arrays used to detect the
                polymer-ligand interface.
            **kwargs: Additional arguments that may be needed by subclasses.

        Returns:
            metric_dict (dict): Dictionary containing the interface sequence
                recovery metrics with "interface_" prefix.
        """
        # Get the base metrics using parent class compute method
        # Pass atom_array through kwargs for get_per_residue_mask method
        kwargs_with_atom_array = {**kwargs, "atom_array": atom_array}
        base_metrics = super().compute(
            S, S_sampled, S_argmax, mask_for_loss, **kwargs_with_atom_array
        )

        # Add "interface_" prefix to all metric keys
        interface_metrics = {}
        for key, value in base_metrics.items():
            interface_metrics[f"interface_{key}"] = value

        return interface_metrics
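
For orientation, the arithmetic inside compute_sequence_recovery_metrics can be reproduced in a few lines of plain PyTorch. The sketch below uses invented toy tensors and bypasses the foundry Metric plumbing (the kwargs_to_compute_args routing, collation, and trainer callbacks); it is illustrative only, not part of the package.

import torch

# Toy ground truth and predictions (values and shapes are invented).
S = torch.tensor([[3, 1, 4, 1, 5], [2, 7, 1, 8, 2]])       # [B, L] ground truth
S_pred = torch.tensor([[3, 1, 0, 1, 5], [2, 7, 1, 0, 0]])  # [B, L] predictions
mask_for_loss = torch.tensor([[1, 1, 1, 1, 0],
                              [1, 1, 0, 0, 0]]).bool()     # [B, L]

per_residue_mask = mask_for_loss.float()
total_valid_per_example = per_residue_mask.sum(dim=-1)     # tensor([4., 2.])
valid_examples_mask = total_valid_per_example > 0

# Correct predictions, zeroed outside the mask (position 4 of example 1
# matches the ground truth but is masked out, so it does not count).
correct = (S_pred == S).float() * per_residue_mask
sequence_recovery_per_example = correct.sum(dim=-1) / total_valid_per_example
mean_sequence_recovery = sequence_recovery_per_example[valid_examples_mask].mean()

print(sequence_recovery_per_example)  # tensor([0.7500, 1.0000])
print(mean_sequence_recovery)         # tensor(0.8750)

InterfaceSequenceRecovery changes only the mask in this picture: it ANDs mask_for_loss with the polymer-ligand interface mask before the same arithmetic runs, so only interface residues contribute to the counts.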