rc-foundry 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (180)
  1. foundry/__init__.py +57 -0
  2. foundry/callbacks/__init__.py +5 -0
  3. foundry/callbacks/callback.py +116 -0
  4. foundry/callbacks/health_logging.py +419 -0
  5. foundry/callbacks/metrics_logging.py +211 -0
  6. foundry/callbacks/timing_logging.py +67 -0
  7. foundry/callbacks/train_logging.py +278 -0
  8. foundry/common.py +108 -0
  9. foundry/constants.py +28 -0
  10. foundry/hydra/resolvers.py +77 -0
  11. foundry/inference_engines/base.py +235 -0
  12. foundry/inference_engines/checkpoint_registry.py +66 -0
  13. foundry/metrics/__init__.py +12 -0
  14. foundry/metrics/losses.py +30 -0
  15. foundry/metrics/metric.py +319 -0
  16. foundry/model/layers/blocks.py +47 -0
  17. foundry/testing/__init__.py +6 -0
  18. foundry/testing/fixtures.py +19 -0
  19. foundry/testing/pytest_hooks.py +15 -0
  20. foundry/trainers/fabric.py +923 -0
  21. foundry/training/EMA.py +67 -0
  22. foundry/training/checkpoint.py +61 -0
  23. foundry/training/schedulers.py +91 -0
  24. foundry/utils/alignment.py +86 -0
  25. foundry/utils/components.py +415 -0
  26. foundry/utils/datasets.py +405 -0
  27. foundry/utils/ddp.py +103 -0
  28. foundry/utils/instantiators.py +72 -0
  29. foundry/utils/logging.py +279 -0
  30. foundry/utils/rigid.py +1460 -0
  31. foundry/utils/rotation_augmentation.py +65 -0
  32. foundry/utils/squashfs.py +172 -0
  33. foundry/utils/torch.py +317 -0
  34. foundry/utils/weights.py +271 -0
  35. foundry/version.py +34 -0
  36. foundry_cli/__init__.py +3 -0
  37. foundry_cli/download_checkpoints.py +281 -0
  38. mpnn/__init__.py +1 -0
  39. mpnn/collate/feature_collator.py +265 -0
  40. mpnn/inference.py +53 -0
  41. mpnn/inference_engines/mpnn.py +549 -0
  42. mpnn/loss/nll_loss.py +122 -0
  43. mpnn/metrics/nll.py +369 -0
  44. mpnn/metrics/sequence_recovery.py +440 -0
  45. mpnn/model/layers/graph_embeddings.py +2372 -0
  46. mpnn/model/layers/message_passing.py +332 -0
  47. mpnn/model/layers/position_wise_feed_forward.py +44 -0
  48. mpnn/model/layers/positional_encoding.py +98 -0
  49. mpnn/model/mpnn.py +2632 -0
  50. mpnn/pipelines/mpnn.py +162 -0
  51. mpnn/samplers/samplers.py +167 -0
  52. mpnn/train.py +341 -0
  53. mpnn/trainers/mpnn.py +193 -0
  54. mpnn/transforms/feature_aggregation/mpnn.py +184 -0
  55. mpnn/transforms/feature_aggregation/polymer_ligand_interface.py +76 -0
  56. mpnn/transforms/feature_aggregation/token_encodings.py +132 -0
  57. mpnn/transforms/feature_aggregation/user_settings.py +347 -0
  58. mpnn/transforms/polymer_ligand_interface.py +164 -0
  59. mpnn/utils/inference.py +2397 -0
  60. mpnn/utils/probability.py +37 -0
  61. mpnn/utils/weights.py +309 -0
  62. rc_foundry-0.1.1.dist-info/METADATA +239 -0
  63. rc_foundry-0.1.1.dist-info/RECORD +180 -0
  64. rc_foundry-0.1.1.dist-info/WHEEL +4 -0
  65. rc_foundry-0.1.1.dist-info/entry_points.txt +5 -0
  66. rc_foundry-0.1.1.dist-info/licenses/LICENSE.md +28 -0
  67. rf3/__init__.py +3 -0
  68. rf3/_version.py +33 -0
  69. rf3/alignment.py +79 -0
  70. rf3/callbacks/dump_validation_structures.py +101 -0
  71. rf3/callbacks/metrics_logging.py +324 -0
  72. rf3/chemical.py +1529 -0
  73. rf3/cli.py +77 -0
  74. rf3/data/cyclic_transform.py +78 -0
  75. rf3/data/extra_xforms.py +36 -0
  76. rf3/data/ground_truth_template.py +463 -0
  77. rf3/data/paired_msa.py +206 -0
  78. rf3/data/pipeline_utils.py +128 -0
  79. rf3/data/pipelines.py +558 -0
  80. rf3/diffusion_samplers/inference_sampler.py +222 -0
  81. rf3/inference.py +65 -0
  82. rf3/inference_engines/__init__.py +5 -0
  83. rf3/inference_engines/rf3.py +735 -0
  84. rf3/kinematics.py +354 -0
  85. rf3/loss/af3_confidence_loss.py +515 -0
  86. rf3/loss/af3_losses.py +655 -0
  87. rf3/loss/loss.py +179 -0
  88. rf3/metrics/chiral.py +179 -0
  89. rf3/metrics/clashing_chains.py +68 -0
  90. rf3/metrics/distogram.py +421 -0
  91. rf3/metrics/lddt.py +523 -0
  92. rf3/metrics/metadata.py +43 -0
  93. rf3/metrics/metric_utils.py +192 -0
  94. rf3/metrics/predicted_error.py +134 -0
  95. rf3/metrics/rasa.py +108 -0
  96. rf3/metrics/selected_distances.py +91 -0
  97. rf3/model/RF3.py +527 -0
  98. rf3/model/RF3_blocks.py +92 -0
  99. rf3/model/RF3_structure.py +303 -0
  100. rf3/model/layers/af3_auxiliary_heads.py +255 -0
  101. rf3/model/layers/af3_diffusion_transformer.py +544 -0
  102. rf3/model/layers/attention.py +313 -0
  103. rf3/model/layers/layer_utils.py +127 -0
  104. rf3/model/layers/mlff.py +118 -0
  105. rf3/model/layers/outer_product.py +59 -0
  106. rf3/model/layers/pairformer_layers.py +783 -0
  107. rf3/model/layers/structure_bias.py +56 -0
  108. rf3/scoring.py +1787 -0
  109. rf3/symmetry/resolve.py +284 -0
  110. rf3/train.py +194 -0
  111. rf3/trainers/rf3.py +570 -0
  112. rf3/util_module.py +47 -0
  113. rf3/utils/frames.py +109 -0
  114. rf3/utils/inference.py +665 -0
  115. rf3/utils/io.py +198 -0
  116. rf3/utils/loss.py +72 -0
  117. rf3/utils/predict_and_score.py +165 -0
  118. rf3/utils/predicted_error.py +673 -0
  119. rf3/utils/recycling.py +42 -0
  120. rf3/validate.py +140 -0
  121. rfd3/.gitignore +7 -0
  122. rfd3/Makefile +76 -0
  123. rfd3/__init__.py +12 -0
  124. rfd3/callbacks.py +66 -0
  125. rfd3/cli.py +41 -0
  126. rfd3/constants.py +212 -0
  127. rfd3/engine.py +543 -0
  128. rfd3/inference/datasets.py +193 -0
  129. rfd3/inference/input_parsing.py +1123 -0
  130. rfd3/inference/legacy_input_parsing.py +717 -0
  131. rfd3/inference/parsing.py +165 -0
  132. rfd3/inference/symmetry/atom_array.py +298 -0
  133. rfd3/inference/symmetry/checks.py +241 -0
  134. rfd3/inference/symmetry/contigs.py +63 -0
  135. rfd3/inference/symmetry/frames.py +355 -0
  136. rfd3/inference/symmetry/symmetry_utils.py +398 -0
  137. rfd3/metrics/design_metrics.py +465 -0
  138. rfd3/metrics/hbonds_hbplus_metrics.py +308 -0
  139. rfd3/metrics/hbonds_metrics.py +389 -0
  140. rfd3/metrics/losses.py +325 -0
  141. rfd3/metrics/metrics_utils.py +118 -0
  142. rfd3/metrics/sidechain_metrics.py +349 -0
  143. rfd3/model/RFD3.py +105 -0
  144. rfd3/model/RFD3_diffusion_module.py +387 -0
  145. rfd3/model/cfg_utils.py +81 -0
  146. rfd3/model/inference_sampler.py +635 -0
  147. rfd3/model/layers/attention.py +577 -0
  148. rfd3/model/layers/block_utils.py +580 -0
  149. rfd3/model/layers/blocks.py +777 -0
  150. rfd3/model/layers/chunked_pairwise.py +377 -0
  151. rfd3/model/layers/encoders.py +417 -0
  152. rfd3/model/layers/layer_utils.py +197 -0
  153. rfd3/model/layers/pairformer_layers.py +128 -0
  154. rfd3/run_inference.py +45 -0
  155. rfd3/testing/debug.py +139 -0
  156. rfd3/testing/debug_utils.py +73 -0
  157. rfd3/testing/testing_utils.py +356 -0
  158. rfd3/train.py +194 -0
  159. rfd3/trainer/dump_validation_structures.py +154 -0
  160. rfd3/trainer/fabric_trainer.py +923 -0
  161. rfd3/trainer/recycling.py +42 -0
  162. rfd3/trainer/rfd3.py +485 -0
  163. rfd3/trainer/trainer_utils.py +502 -0
  164. rfd3/transforms/conditioning_base.py +508 -0
  165. rfd3/transforms/conditioning_utils.py +200 -0
  166. rfd3/transforms/design_transforms.py +807 -0
  167. rfd3/transforms/dna_crop.py +523 -0
  168. rfd3/transforms/hbonds.py +407 -0
  169. rfd3/transforms/hbonds_hbplus.py +246 -0
  170. rfd3/transforms/ncaa_transforms.py +153 -0
  171. rfd3/transforms/pipelines.py +632 -0
  172. rfd3/transforms/ppi_transforms.py +541 -0
  173. rfd3/transforms/rasa.py +116 -0
  174. rfd3/transforms/symmetry.py +76 -0
  175. rfd3/transforms/training_conditions.py +552 -0
  176. rfd3/transforms/util_transforms.py +498 -0
  177. rfd3/transforms/virtual_atoms.py +305 -0
  178. rfd3/utils/inference.py +648 -0
  179. rfd3/utils/io.py +245 -0
  180. rfd3/utils/vizualize.py +276 -0
mpnn/metrics/nll.py ADDED
@@ -0,0 +1,369 @@
1
+ import torch
2
+ from atomworks.ml.transforms.base import ConvertToTorch
3
+ from mpnn.collate.feature_collator import FeatureCollator
4
+ from mpnn.transforms.feature_aggregation.polymer_ligand_interface import (
5
+ FeaturizePolymerLigandInterfaceMask,
6
+ )
7
+ from mpnn.transforms.polymer_ligand_interface import ComputePolymerLigandInterface
8
+
9
+ from foundry.metrics.metric import Metric
10
+
11
+
12
class NLL(Metric):
    """
    Computes negative log likelihood (NLL) and perplexity for Protein/Ligand
    MPNN.

    This metric computes the NLL loss by averaging the negative log
    probabilities at the true token indices, masked by the loss mask. This
    follows the same computation as LabelSmoothedNLLLoss but without label
    smoothing and with averaging instead of a normalization constant.
    """

    def __init__(
        self,
        return_per_example_metrics=False,
        return_per_residue_metrics=False,
        **kwargs,
    ):
        """
        Initialize the NLL metric.

        Args:
            return_per_example_metrics (bool): If True, returns per-example
                metrics in addition to the aggregate metrics.
            return_per_residue_metrics (bool): If True, returns per-residue
                metrics in addition to the aggregate metrics.
            **kwargs: Additional keyword arguments passed to the base Metric
                class.
        """
        super().__init__(**kwargs)
        self.return_per_example_metrics = return_per_example_metrics
        self.return_per_residue_metrics = return_per_residue_metrics

    @property
    def kwargs_to_compute_args(self):
        """
        Map input keys to the compute method arguments.

        Returns:
            dict: Mapping from compute method argument names to nested
                dictionary keys in the input kwargs.
        """
        # NOTE(review): "S" is sourced from network_input while
        # "mask_for_loss" is sourced from network_output — confirm this
        # asymmetry is intentional and not a copy/paste slip.
        return {
            "log_probs": ("network_output", "decoder_features", "log_probs"),
            "S": ("network_input", "input_features", "S"),
            "mask_for_loss": ("network_output", "input_features", "mask_for_loss"),
        }

    def get_per_residue_mask(self, mask_for_loss, **kwargs):
        """
        Get the per-residue mask for computing NLL.

        This method can be overridden by subclasses to apply additional
        masking criteria (e.g. interface-only masks).

        Args:
            mask_for_loss (torch.Tensor): [B, L] - mask for loss.
            **kwargs: Additional arguments that may be needed by subclasses.

        Returns:
            per_residue_mask (torch.Tensor): [B, L] - per-residue mask for
                NLL computation.
        """
        per_residue_mask = mask_for_loss
        return per_residue_mask

    def compute_nll_metrics(self, S, log_probs, per_residue_mask):
        """
        Compute NLL and perplexity metrics using the provided per-residue mask.

        Args:
            S (torch.Tensor): [B, L] - the ground truth sequence (integer
                token indices).
            log_probs (torch.Tensor): [B, L, vocab_size] - the log
                probabilities for the sequence.
            per_residue_mask (torch.Tensor): [B, L] - per-residue mask for
                computation of NLL.

        Returns:
            nll_dict (dict): Dictionary containing the NLL metrics.
                - mean_nll [1]: mean NLL over (valid) examples (a valid
                  example is one with at least one valid residue according
                  to the per_residue_mask). NaN if no example is valid.
                - nll_per_example [B]: NLL per example, undefined for
                  examples with no valid residues.
                - nll_per_residue [B, L]: NLL per residue (masked, 0 for
                  masked out positions).
                - mean_perplexity [1]: mean perplexity over (valid) examples.
                - perplexity_per_example [B]: perplexity per example,
                  undefined for examples with no valid residues.
                - total_valid_per_example [B]: number of valid residues per
                  example.
                - valid_examples_mask [B]: boolean mask indicating examples
                  with valid residues.
                - per_residue_mask [B, L]: per-residue mask for NLL
                  computation (as float).
        """
        per_residue_mask = per_residue_mask.float()

        # total_valid_per_example [B] - number of valid residues per example.
        total_valid_per_example = per_residue_mask.sum(dim=-1)

        # valid_examples_mask [B] - examples with at least one valid residue.
        valid_examples_mask = total_valid_per_example > 0

        # nll_per_residue [B, L] - gather the log probability of the true
        # token directly instead of summing a one-hot product. This avoids
        # materializing a [B, L, vocab_size] tensor and sidesteps the
        # 0 * (-inf) = NaN that a masked one-hot product produces when
        # log_probs contains -inf entries at non-target positions.
        nll_per_residue = (
            -log_probs.gather(-1, S.unsqueeze(-1)).squeeze(-1) * per_residue_mask
        )

        # nll_per_example [B] - average NLL per example. Undefined (inf/NaN)
        # if an example has no valid residues.
        nll_per_example = nll_per_residue.sum(dim=-1) / total_valid_per_example

        # mean_nll [1] - mean of per-example NLL values over valid examples
        # only, so the undefined entries above never contaminate the mean.
        mean_nll = nll_per_example[valid_examples_mask].mean()

        # perplexity_per_example [B] - perplexity = exp(NLL) per example.
        perplexity_per_example = torch.exp(nll_per_example)

        # mean_perplexity [1] - mean of per-example perplexity values (over
        # valid examples).
        mean_perplexity = perplexity_per_example[valid_examples_mask].mean()

        nll_dict = {
            "mean_nll": mean_nll,
            "nll_per_example": nll_per_example,
            "nll_per_residue": nll_per_residue,
            "mean_perplexity": mean_perplexity,
            "perplexity_per_example": perplexity_per_example,
            "total_valid_per_example": total_valid_per_example,
            "valid_examples_mask": valid_examples_mask,
            "per_residue_mask": per_residue_mask,
        }
        return nll_dict

    def compute(self, log_probs, S, mask_for_loss, **kwargs):
        """
        Compute the negative log likelihood (NLL) and perplexity, meaned
        across all residues that are included in the loss calculation.

        Args:
            log_probs (torch.Tensor): [B, L, vocab_size] - the
                log probabilities for the sequence.
            S (torch.Tensor): [B, L] - the ground truth sequence.
            mask_for_loss (torch.Tensor): [B, L] - mask for loss,
                where True is a residue that is included in the loss
                calculation, and False is a residue that is not included
                in the loss calculation.
            **kwargs: Additional arguments that may be needed by subclasses.

        Returns:
            metric_dict (dict): Dictionary containing the computed metrics.
                - mean_nll [1]: mean NLL over (valid) examples.
                - mean_perplexity [1]: mean perplexity over (valid) examples.
                if self.return_per_example_metrics is True:
                - nll_per_example [B]: NLL per example, undefined for
                  examples with no valid residues.
                - perplexity_per_example [B]: perplexity per example,
                  undefined for examples with no valid residues.
                - total_valid_per_example [B]: number of valid residues per
                  example.
                - valid_examples_mask [B]: boolean mask indicating examples
                  with valid residues.
                if self.return_per_residue_metrics is True:
                - nll_per_residue [B, L]: NLL per residue (masked, 0 for
                  masked out positions).
                - per_residue_mask [B, L]: mask for sequence recovery.
        """
        # per_residue_mask [B, L] - mask for sequence recovery.
        per_residue_mask = self.get_per_residue_mask(mask_for_loss, **kwargs)

        # Compute NLL metrics.
        nll_metrics = self.compute_nll_metrics(S, log_probs, per_residue_mask)

        # Scalars are detached and converted to plain Python floats so they
        # are safe to log without holding onto the graph.
        metric_dict = {
            "mean_nll": nll_metrics["mean_nll"].detach().item(),
            "mean_perplexity": nll_metrics["mean_perplexity"].detach().item(),
        }
        if self.return_per_example_metrics:
            metric_dict.update(
                {
                    "nll_per_example": nll_metrics["nll_per_example"],
                    "perplexity_per_example": nll_metrics["perplexity_per_example"],
                    "total_valid_per_example": nll_metrics["total_valid_per_example"],
                    "valid_examples_mask": nll_metrics["valid_examples_mask"],
                }
            )
        if self.return_per_residue_metrics:
            metric_dict.update(
                {
                    "nll_per_residue": nll_metrics["nll_per_residue"],
                    "per_residue_mask": nll_metrics["per_residue_mask"],
                }
            )
        return metric_dict
206
+
207
+
208
class InterfaceNLL(NLL):
    """
    Negative log likelihood (NLL) and perplexity for Protein/Ligand MPNN,
    restricted to residues at the polymer-ligand interface.

    Behaves exactly like NLL, except that only residues lying within a
    configurable distance of ligand atoms contribute, and every returned
    metric name carries an "interface_" prefix.
    """

    def __init__(
        self,
        interface_distance_threshold: float = 5.0,
        return_per_example_metrics: bool = False,
        return_per_residue_metrics: bool = False,
        **kwargs,
    ):
        """
        Initialize the InterfaceNLL metric.

        Args:
            interface_distance_threshold (float): Distance threshold in
                Angstroms for considering residues to be at the interface.
                Defaults to 5.0.
            return_per_example_metrics (bool): If True, also return
                per-example metrics.
            return_per_residue_metrics (bool): If True, also return
                per-residue metrics.
            **kwargs: Forwarded to the base Metric class.
        """
        super().__init__(
            return_per_example_metrics=return_per_example_metrics,
            return_per_residue_metrics=return_per_residue_metrics,
            **kwargs,
        )
        self.interface_distance_threshold = interface_distance_threshold

    @property
    def kwargs_to_compute_args(self):
        """
        Map input keys to the compute method arguments.

        Returns:
            dict: The parent mapping, extended with the atom_array entry
                needed for interface detection.
        """
        mapping = dict(super().kwargs_to_compute_args)
        mapping["atom_array"] = ("network_input", "atom_array")
        return mapping

    def get_per_residue_mask(self, mask_for_loss, **kwargs):
        """
        Build the per-residue mask used for interface NLL.

        Runs the polymer-ligand interface transforms over each atom array in
        the batch, collates the resulting per-example interface masks with
        padding, and intersects (logical AND) the collated mask with
        mask_for_loss.

        Args:
            mask_for_loss (torch.Tensor): [B, L] - mask for loss.
            **kwargs: Must include "atom_array" (one atom array per batch
                element).

        Returns:
            per_residue_mask (torch.Tensor): [B, L] - combined mask for
                interface NLL computation.

        Raises:
            ValueError: If "atom_array" is missing from kwargs.
        """
        atom_arrays = kwargs.get("atom_array")
        if atom_arrays is None:
            raise ValueError(
                "atom_array is required for interface computation but was not found"
            )

        # Transform pipeline: detect interface -> featurize mask -> to torch.
        detect_interface = ComputePolymerLigandInterface(
            distance_threshold=self.interface_distance_threshold
        )
        featurize_mask = FeaturizePolymerLigandInterfaceMask()
        to_torch = ConvertToTorch(keys=["input_features"])

        # Run the pipeline per example and keep each interface mask tensor.
        example_masks = []
        for array in atom_arrays:
            features = to_torch(featurize_mask(detect_interface({"atom_array": array})))
            example_masks.append(
                features["input_features"]["polymer_ligand_interface_mask"]
            )

        # Collate the per-example masks into one padded batch tensor; padded
        # positions are filled with False so they never count as interface.
        collator = FeatureCollator(
            default_padding={"polymer_ligand_interface_mask": False}
        )
        collated = collator(
            [
                {
                    "input_features": {"polymer_ligand_interface_mask": mask},
                    "atom_array": None,  # Not needed for collation
                }
                for mask in example_masks
            ]
        )
        interface_mask = collated["input_features"]["polymer_ligand_interface_mask"]

        # Align device/dtype with mask_for_loss before combining.
        interface_mask = interface_mask.to(
            device=mask_for_loss.device, dtype=mask_for_loss.dtype
        )

        # A residue counts only if it is both in the loss and at the interface.
        return mask_for_loss & interface_mask

    def compute(self, log_probs, S, mask_for_loss, atom_array, **kwargs):
        """
        Compute interface NLL and perplexity, averaged over interface
        residues only.

        Delegates to the parent compute (which calls this class's
        get_per_residue_mask via the atom_array keyword) and renames every
        resulting metric with an "interface_" prefix.

        Args:
            log_probs (torch.Tensor): [B, L, vocab_size] - the
                log probabilities for the sequence.
            S (torch.Tensor): [B, L] - the ground truth sequence.
            mask_for_loss (torch.Tensor): [B, L] - mask for loss.
            atom_array: Per-example atom arrays used for interface detection.
            **kwargs: Additional arguments forwarded to the parent compute.

        Returns:
            metric_dict (dict): Interface NLL and perplexity metrics, each
                key prefixed with "interface_".
        """
        # atom_array is threaded through as a keyword so that
        # get_per_residue_mask can pick it up from **kwargs.
        base_metrics = super().compute(
            log_probs, S, mask_for_loss, atom_array=atom_array, **kwargs
        )
        return {f"interface_{name}": value for name, value in base_metrics.items()}