rc-foundry 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- foundry/__init__.py +57 -0
- foundry/callbacks/__init__.py +5 -0
- foundry/callbacks/callback.py +116 -0
- foundry/callbacks/health_logging.py +419 -0
- foundry/callbacks/metrics_logging.py +211 -0
- foundry/callbacks/timing_logging.py +67 -0
- foundry/callbacks/train_logging.py +278 -0
- foundry/common.py +108 -0
- foundry/constants.py +28 -0
- foundry/hydra/resolvers.py +77 -0
- foundry/inference_engines/base.py +235 -0
- foundry/inference_engines/checkpoint_registry.py +66 -0
- foundry/metrics/__init__.py +12 -0
- foundry/metrics/losses.py +30 -0
- foundry/metrics/metric.py +319 -0
- foundry/model/layers/blocks.py +47 -0
- foundry/testing/__init__.py +6 -0
- foundry/testing/fixtures.py +19 -0
- foundry/testing/pytest_hooks.py +15 -0
- foundry/trainers/fabric.py +923 -0
- foundry/training/EMA.py +67 -0
- foundry/training/checkpoint.py +61 -0
- foundry/training/schedulers.py +91 -0
- foundry/utils/alignment.py +86 -0
- foundry/utils/components.py +415 -0
- foundry/utils/datasets.py +405 -0
- foundry/utils/ddp.py +103 -0
- foundry/utils/instantiators.py +72 -0
- foundry/utils/logging.py +279 -0
- foundry/utils/rigid.py +1460 -0
- foundry/utils/rotation_augmentation.py +65 -0
- foundry/utils/squashfs.py +172 -0
- foundry/utils/torch.py +317 -0
- foundry/utils/weights.py +271 -0
- foundry/version.py +34 -0
- foundry_cli/__init__.py +3 -0
- foundry_cli/download_checkpoints.py +281 -0
- mpnn/__init__.py +1 -0
- mpnn/collate/feature_collator.py +265 -0
- mpnn/inference.py +53 -0
- mpnn/inference_engines/mpnn.py +549 -0
- mpnn/loss/nll_loss.py +122 -0
- mpnn/metrics/nll.py +369 -0
- mpnn/metrics/sequence_recovery.py +440 -0
- mpnn/model/layers/graph_embeddings.py +2372 -0
- mpnn/model/layers/message_passing.py +332 -0
- mpnn/model/layers/position_wise_feed_forward.py +44 -0
- mpnn/model/layers/positional_encoding.py +98 -0
- mpnn/model/mpnn.py +2632 -0
- mpnn/pipelines/mpnn.py +162 -0
- mpnn/samplers/samplers.py +167 -0
- mpnn/train.py +341 -0
- mpnn/trainers/mpnn.py +193 -0
- mpnn/transforms/feature_aggregation/mpnn.py +184 -0
- mpnn/transforms/feature_aggregation/polymer_ligand_interface.py +76 -0
- mpnn/transforms/feature_aggregation/token_encodings.py +132 -0
- mpnn/transforms/feature_aggregation/user_settings.py +347 -0
- mpnn/transforms/polymer_ligand_interface.py +164 -0
- mpnn/utils/inference.py +2397 -0
- mpnn/utils/probability.py +37 -0
- mpnn/utils/weights.py +309 -0
- rc_foundry-0.1.1.dist-info/METADATA +239 -0
- rc_foundry-0.1.1.dist-info/RECORD +180 -0
- rc_foundry-0.1.1.dist-info/WHEEL +4 -0
- rc_foundry-0.1.1.dist-info/entry_points.txt +5 -0
- rc_foundry-0.1.1.dist-info/licenses/LICENSE.md +28 -0
- rf3/__init__.py +3 -0
- rf3/_version.py +33 -0
- rf3/alignment.py +79 -0
- rf3/callbacks/dump_validation_structures.py +101 -0
- rf3/callbacks/metrics_logging.py +324 -0
- rf3/chemical.py +1529 -0
- rf3/cli.py +77 -0
- rf3/data/cyclic_transform.py +78 -0
- rf3/data/extra_xforms.py +36 -0
- rf3/data/ground_truth_template.py +463 -0
- rf3/data/paired_msa.py +206 -0
- rf3/data/pipeline_utils.py +128 -0
- rf3/data/pipelines.py +558 -0
- rf3/diffusion_samplers/inference_sampler.py +222 -0
- rf3/inference.py +65 -0
- rf3/inference_engines/__init__.py +5 -0
- rf3/inference_engines/rf3.py +735 -0
- rf3/kinematics.py +354 -0
- rf3/loss/af3_confidence_loss.py +515 -0
- rf3/loss/af3_losses.py +655 -0
- rf3/loss/loss.py +179 -0
- rf3/metrics/chiral.py +179 -0
- rf3/metrics/clashing_chains.py +68 -0
- rf3/metrics/distogram.py +421 -0
- rf3/metrics/lddt.py +523 -0
- rf3/metrics/metadata.py +43 -0
- rf3/metrics/metric_utils.py +192 -0
- rf3/metrics/predicted_error.py +134 -0
- rf3/metrics/rasa.py +108 -0
- rf3/metrics/selected_distances.py +91 -0
- rf3/model/RF3.py +527 -0
- rf3/model/RF3_blocks.py +92 -0
- rf3/model/RF3_structure.py +303 -0
- rf3/model/layers/af3_auxiliary_heads.py +255 -0
- rf3/model/layers/af3_diffusion_transformer.py +544 -0
- rf3/model/layers/attention.py +313 -0
- rf3/model/layers/layer_utils.py +127 -0
- rf3/model/layers/mlff.py +118 -0
- rf3/model/layers/outer_product.py +59 -0
- rf3/model/layers/pairformer_layers.py +783 -0
- rf3/model/layers/structure_bias.py +56 -0
- rf3/scoring.py +1787 -0
- rf3/symmetry/resolve.py +284 -0
- rf3/train.py +194 -0
- rf3/trainers/rf3.py +570 -0
- rf3/util_module.py +47 -0
- rf3/utils/frames.py +109 -0
- rf3/utils/inference.py +665 -0
- rf3/utils/io.py +198 -0
- rf3/utils/loss.py +72 -0
- rf3/utils/predict_and_score.py +165 -0
- rf3/utils/predicted_error.py +673 -0
- rf3/utils/recycling.py +42 -0
- rf3/validate.py +140 -0
- rfd3/.gitignore +7 -0
- rfd3/Makefile +76 -0
- rfd3/__init__.py +12 -0
- rfd3/callbacks.py +66 -0
- rfd3/cli.py +41 -0
- rfd3/constants.py +212 -0
- rfd3/engine.py +543 -0
- rfd3/inference/datasets.py +193 -0
- rfd3/inference/input_parsing.py +1123 -0
- rfd3/inference/legacy_input_parsing.py +717 -0
- rfd3/inference/parsing.py +165 -0
- rfd3/inference/symmetry/atom_array.py +298 -0
- rfd3/inference/symmetry/checks.py +241 -0
- rfd3/inference/symmetry/contigs.py +63 -0
- rfd3/inference/symmetry/frames.py +355 -0
- rfd3/inference/symmetry/symmetry_utils.py +398 -0
- rfd3/metrics/design_metrics.py +465 -0
- rfd3/metrics/hbonds_hbplus_metrics.py +308 -0
- rfd3/metrics/hbonds_metrics.py +389 -0
- rfd3/metrics/losses.py +325 -0
- rfd3/metrics/metrics_utils.py +118 -0
- rfd3/metrics/sidechain_metrics.py +349 -0
- rfd3/model/RFD3.py +105 -0
- rfd3/model/RFD3_diffusion_module.py +387 -0
- rfd3/model/cfg_utils.py +81 -0
- rfd3/model/inference_sampler.py +635 -0
- rfd3/model/layers/attention.py +577 -0
- rfd3/model/layers/block_utils.py +580 -0
- rfd3/model/layers/blocks.py +777 -0
- rfd3/model/layers/chunked_pairwise.py +377 -0
- rfd3/model/layers/encoders.py +417 -0
- rfd3/model/layers/layer_utils.py +197 -0
- rfd3/model/layers/pairformer_layers.py +128 -0
- rfd3/run_inference.py +45 -0
- rfd3/testing/debug.py +139 -0
- rfd3/testing/debug_utils.py +73 -0
- rfd3/testing/testing_utils.py +356 -0
- rfd3/train.py +194 -0
- rfd3/trainer/dump_validation_structures.py +154 -0
- rfd3/trainer/fabric_trainer.py +923 -0
- rfd3/trainer/recycling.py +42 -0
- rfd3/trainer/rfd3.py +485 -0
- rfd3/trainer/trainer_utils.py +502 -0
- rfd3/transforms/conditioning_base.py +508 -0
- rfd3/transforms/conditioning_utils.py +200 -0
- rfd3/transforms/design_transforms.py +807 -0
- rfd3/transforms/dna_crop.py +523 -0
- rfd3/transforms/hbonds.py +407 -0
- rfd3/transforms/hbonds_hbplus.py +246 -0
- rfd3/transforms/ncaa_transforms.py +153 -0
- rfd3/transforms/pipelines.py +632 -0
- rfd3/transforms/ppi_transforms.py +541 -0
- rfd3/transforms/rasa.py +116 -0
- rfd3/transforms/symmetry.py +76 -0
- rfd3/transforms/training_conditions.py +552 -0
- rfd3/transforms/util_transforms.py +498 -0
- rfd3/transforms/virtual_atoms.py +305 -0
- rfd3/utils/inference.py +648 -0
- rfd3/utils/io.py +245 -0
- rfd3/utils/vizualize.py +276 -0
mpnn/metrics/nll.py
ADDED
@@ -0,0 +1,369 @@
```python
import torch
from atomworks.ml.transforms.base import ConvertToTorch
from mpnn.collate.feature_collator import FeatureCollator
from mpnn.transforms.feature_aggregation.polymer_ligand_interface import (
    FeaturizePolymerLigandInterfaceMask,
)
from mpnn.transforms.polymer_ligand_interface import ComputePolymerLigandInterface

from foundry.metrics.metric import Metric


class NLL(Metric):
    """
    Computes negative log likelihood (NLL) and perplexity for Protein/Ligand
    MPNN.

    This metric computes the NLL loss by averaging the negative log
    probabilities at the true token indices, masked by the loss mask. This
    follows the same computation as LabelSmoothedNLLLoss but without label
    smoothing and with averaging instead of a normalization constant.
    """

    def __init__(
        self,
        return_per_example_metrics=False,
        return_per_residue_metrics=False,
        **kwargs,
    ):
        """
        Initialize the NLL metric.

        Args:
            return_per_example_metrics (bool): If True, returns per-example
                metrics in addition to the aggregate metrics.
            return_per_residue_metrics (bool): If True, returns per-residue
                metrics in addition to the aggregate metrics.
            **kwargs: Additional keyword arguments passed to the base Metric
                class.
        """
        super().__init__(**kwargs)
        self.return_per_example_metrics = return_per_example_metrics
        self.return_per_residue_metrics = return_per_residue_metrics

    @property
    def kwargs_to_compute_args(self):
        """
        Map input keys to the compute method arguments.

        Returns:
            dict: Mapping from compute method argument names to nested
                dictionary keys in the input kwargs.
        """
        return {
            "log_probs": ("network_output", "decoder_features", "log_probs"),
            "S": ("network_input", "input_features", "S"),
            "mask_for_loss": ("network_output", "input_features", "mask_for_loss"),
        }

    def get_per_residue_mask(self, mask_for_loss, **kwargs):
        """
        Get the per-residue mask for computing NLL.

        This method can be overridden by subclasses to apply additional masking
        criteria.

        Args:
            mask_for_loss (torch.Tensor): [B, L] - mask for loss
            **kwargs: Additional arguments that may be needed by subclasses

        Returns:
            per_residue_mask (torch.Tensor): [B, L] - per-residue mask for NLL
                computation.
        """
        per_residue_mask = mask_for_loss
        return per_residue_mask

    def compute_nll_metrics(self, S, log_probs, per_residue_mask):
        """
        Compute NLL and perplexity metrics using the provided per-residue mask.

        Args:
            S (torch.Tensor): [B, L] - the ground truth sequence.
            log_probs (torch.Tensor): [B, L, vocab_size] - the log
                probabilities for the sequence.
            per_residue_mask (torch.Tensor): [B, L] - per-residue mask for
                computation of NLL.

        Returns:
            nll_dict (dict): Dictionary containing the NLL metrics.
                - mean_nll [1]: mean NLL over (valid) examples (a valid example
                  is one with at least one valid residue according to the
                  per_residue_mask).
                - nll_per_example [B]: NLL per example, undefined for examples
                  with no valid residues.
                - nll_per_residue [B, L]: NLL per residue (masked, 0 for
                  masked out positions).
                - mean_perplexity [1]: mean perplexity over (valid) examples.
                - perplexity_per_example [B]: perplexity per example, undefined
                  for examples with no valid residues.
                - total_valid_per_example [B]: number of valid residues per
                  example.
                - valid_examples_mask [B]: boolean mask indicating examples
                  with valid residues.
                - per_residue_mask [B, L]: per-residue mask for NLL computation.
        """
        _, _, vocab_size = log_probs.shape
        per_residue_mask = per_residue_mask.float()

        # total_valid_per_example [B] - number of valid residues per example.
        total_valid_per_example = per_residue_mask.sum(dim=-1)

        # valid_examples_mask [B] - boolean mask indicating examples with valid
        # residues.
        valid_examples_mask = total_valid_per_example > 0

        # S_onehot [B, L, vocab_size] - the one-hot encoded sequence.
        S_onehot = torch.nn.functional.one_hot(S, num_classes=vocab_size).float()

        # nll_per_residue [B, L] - the per-residue negative log likelihood,
        # masked by the per_residue_mask.
        nll_per_residue = -torch.sum(S_onehot * log_probs, dim=-1) * per_residue_mask

        # nll_per_example [B] - average NLL per example. Undefined if there are
        # no valid residues.
        nll_per_example = nll_per_residue.sum(dim=-1) / total_valid_per_example

        # mean_nll [1] - mean of per-example NLL values (over valid examples).
        mean_nll = nll_per_example[valid_examples_mask].mean()

        # perplexity_per_example [B] - perplexity per example.
        perplexity_per_example = torch.exp(nll_per_example)

        # mean_perplexity [1] - mean of per-example perplexity values (over
        # valid examples).
        mean_perplexity = perplexity_per_example[valid_examples_mask].mean()

        nll_dict = {
            "mean_nll": mean_nll,
            "nll_per_example": nll_per_example,
            "nll_per_residue": nll_per_residue,
            "mean_perplexity": mean_perplexity,
            "perplexity_per_example": perplexity_per_example,
            "total_valid_per_example": total_valid_per_example,
            "valid_examples_mask": valid_examples_mask,
            "per_residue_mask": per_residue_mask,
        }
        return nll_dict

    def compute(self, log_probs, S, mask_for_loss, **kwargs):
        """
        Compute the negative log likelihood (NLL) and perplexity, averaged
        across all residues that are included in the loss calculation.

        Args:
            log_probs (torch.Tensor): [B, L, vocab_size] - the
                log probabilities for the sequence.
            S (torch.Tensor): [B, L] - the ground truth sequence.
            mask_for_loss (torch.Tensor): [B, L] - mask for loss,
                where True is a residue that is included in the loss
                calculation, and False is a residue that is not included
                in the loss calculation.
            **kwargs: Additional arguments that may be needed by subclasses.

        Returns:
            metric_dict (dict): Dictionary containing the computed metrics.
                - mean_nll [1]: mean NLL over (valid) examples.
                - mean_perplexity [1]: mean perplexity over (valid) examples.
                if self.return_per_example_metrics is True:
                - nll_per_example [B]: NLL per example, undefined for examples
                  with no valid residues.
                - perplexity_per_example [B]: perplexity per example, undefined
                  for examples with no valid residues.
                - total_valid_per_example [B]: number of valid residues per
                  example.
                - valid_examples_mask [B]: boolean mask indicating examples
                  with valid residues.
                if self.return_per_residue_metrics is True:
                - nll_per_residue [B, L]: NLL per residue (masked, 0 for
                  masked out positions).
                - per_residue_mask [B, L]: per-residue mask used for NLL
                  computation.
        """
        # per_residue_mask [B, L] - per-residue mask used for NLL computation.
        per_residue_mask = self.get_per_residue_mask(mask_for_loss, **kwargs)

        # Compute NLL metrics.
        nll_metrics = self.compute_nll_metrics(S, log_probs, per_residue_mask)

        # Prepare the metric dictionary.
        metric_dict = {
            "mean_nll": nll_metrics["mean_nll"].detach().item(),
            "mean_perplexity": nll_metrics["mean_perplexity"].detach().item(),
        }
        if self.return_per_example_metrics:
            metric_dict.update(
                {
                    "nll_per_example": nll_metrics["nll_per_example"],
                    "perplexity_per_example": nll_metrics["perplexity_per_example"],
                    "total_valid_per_example": nll_metrics["total_valid_per_example"],
                    "valid_examples_mask": nll_metrics["valid_examples_mask"],
                }
            )
        if self.return_per_residue_metrics:
            metric_dict.update(
                {
                    "nll_per_residue": nll_metrics["nll_per_residue"],
                    "per_residue_mask": nll_metrics["per_residue_mask"],
                }
            )
        return metric_dict


class InterfaceNLL(NLL):
    """
    Computes negative log likelihood (NLL) and perplexity for Protein/Ligand
    MPNN specifically for residues at the polymer-ligand interface.

    This metric inherits from NLL but only computes metrics for residues that
    are within a specified distance threshold of ligand atoms. All returned
    metric names are prefixed with "interface_".
    """

    def __init__(
        self,
        interface_distance_threshold: float = 5.0,
        return_per_example_metrics: bool = False,
        return_per_residue_metrics: bool = False,
        **kwargs,
    ):
        """
        Initialize the InterfaceNLL metric.

        Args:
            interface_distance_threshold (float): Distance threshold in
                Angstroms for considering residues to be at the interface.
                Defaults to 5.0.
            return_per_example_metrics (bool): If True, returns per-example
                metrics in addition to the aggregate metrics.
            return_per_residue_metrics (bool): If True, returns per-residue
                metrics in addition to the aggregate metrics.
            **kwargs: Additional keyword arguments passed to the base Metric
                class.
        """
        super().__init__(
            return_per_example_metrics=return_per_example_metrics,
            return_per_residue_metrics=return_per_residue_metrics,
            **kwargs,
        )
        self.interface_distance_threshold = interface_distance_threshold

    @property
    def kwargs_to_compute_args(self):
        """
        Map input keys to the compute method arguments.

        Returns:
            dict: Mapping from compute method argument names to nested
                dictionary keys in the input kwargs.
        """
        args_mapping = super().kwargs_to_compute_args
        # Add atom_array to the mapping for interface computation
        args_mapping["atom_array"] = ("network_input", "atom_array")
        return args_mapping

    def get_per_residue_mask(self, mask_for_loss, **kwargs):
        """
        Get the per-residue mask for computing interface NLL.

        This method computes the interface mask by applying transforms to
        detect polymer-ligand interfaces and combines it with the original
        mask_for_loss using logical AND.

        Args:
            mask_for_loss (torch.Tensor): [B, L] - mask for loss
            **kwargs: Additional arguments including atom_array

        Returns:
            per_residue_mask (torch.Tensor): [B, L] - combined mask for
                interface NLL computation.
        """
        # Extract atom arrays from kwargs
        atom_arrays = kwargs.get("atom_array")
        if atom_arrays is None:
            raise ValueError(
                "atom_array is required for interface computation but was not found"
            )

        # Initialize transforms
        interface_transform = ComputePolymerLigandInterface(
            distance_threshold=self.interface_distance_threshold
        )
        mask_transform = FeaturizePolymerLigandInterfaceMask()
        convert_to_torch_transform = ConvertToTorch(keys=["input_features"])

        # Process each atom array in the batch
        batch_interface_masks = []
        for atom_array in atom_arrays:
            # Apply interface detection transform
            data = {"atom_array": atom_array}
            data = interface_transform(data)

            # Apply interface mask featurization
            data = mask_transform(data)

            # Convert to torch tensor
            data = convert_to_torch_transform(data)

            # Extract the interface mask
            interface_mask = data["input_features"]["polymer_ligand_interface_mask"]
            batch_interface_masks.append(interface_mask)

        # Collate interface masks with proper padding
        collator = FeatureCollator(
            default_padding={"polymer_ligand_interface_mask": False}
        )

        # Create mock pipeline outputs for collation
        mock_outputs = []
        for interface_mask in batch_interface_masks:
            mock_outputs.append(
                {
                    "input_features": {"polymer_ligand_interface_mask": interface_mask},
                    "atom_array": None,  # Not needed for collation
                }
            )

        # Collate the masks
        collated = collator(mock_outputs)
        interface_mask = collated["input_features"]["polymer_ligand_interface_mask"]

        # Convert to the same device and dtype as mask_for_loss
        interface_mask = interface_mask.to(
            device=mask_for_loss.device, dtype=mask_for_loss.dtype
        )

        # Combine with original mask using logical AND
        combined_mask = mask_for_loss & interface_mask

        return combined_mask

    def compute(self, log_probs, S, mask_for_loss, atom_array, **kwargs):
        """
        Compute the interface negative log likelihood (NLL) and perplexity,
        averaged across interface residues only.

        This method computes NLL and perplexity specifically for residues at
        the polymer-ligand interface and prefixes all output metrics with
        "interface_".

        Args:
            log_probs (torch.Tensor): [B, L, vocab_size] - the
                log probabilities for the sequence.
            S (torch.Tensor): [B, L] - the ground truth sequence.
            mask_for_loss (torch.Tensor): [B, L] - mask for loss.
            atom_array: Batch of atom arrays used for polymer-ligand
                interface detection.
            **kwargs: Additional arguments that may be needed by subclasses.

        Returns:
            metric_dict (dict): Dictionary containing the interface NLL and
                perplexity metrics with "interface_" prefix.
        """
        # Get the base metrics using parent class compute method.
        # Pass atom_array through kwargs for get_per_residue_mask method.
        kwargs_with_atom_array = {**kwargs, "atom_array": atom_array}
        base_metrics = super().compute(
            log_probs, S, mask_for_loss, **kwargs_with_atom_array
        )

        # Add "interface_" prefix to all metric keys
        interface_metrics = {}
        for key, value in base_metrics.items():
            interface_metrics[f"interface_{key}"] = value

        return interface_metrics
```