rc-foundry 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (180)
  1. foundry/__init__.py +57 -0
  2. foundry/callbacks/__init__.py +5 -0
  3. foundry/callbacks/callback.py +116 -0
  4. foundry/callbacks/health_logging.py +419 -0
  5. foundry/callbacks/metrics_logging.py +211 -0
  6. foundry/callbacks/timing_logging.py +67 -0
  7. foundry/callbacks/train_logging.py +278 -0
  8. foundry/common.py +108 -0
  9. foundry/constants.py +28 -0
  10. foundry/hydra/resolvers.py +77 -0
  11. foundry/inference_engines/base.py +235 -0
  12. foundry/inference_engines/checkpoint_registry.py +66 -0
  13. foundry/metrics/__init__.py +12 -0
  14. foundry/metrics/losses.py +30 -0
  15. foundry/metrics/metric.py +319 -0
  16. foundry/model/layers/blocks.py +47 -0
  17. foundry/testing/__init__.py +6 -0
  18. foundry/testing/fixtures.py +19 -0
  19. foundry/testing/pytest_hooks.py +15 -0
  20. foundry/trainers/fabric.py +923 -0
  21. foundry/training/EMA.py +67 -0
  22. foundry/training/checkpoint.py +61 -0
  23. foundry/training/schedulers.py +91 -0
  24. foundry/utils/alignment.py +86 -0
  25. foundry/utils/components.py +415 -0
  26. foundry/utils/datasets.py +405 -0
  27. foundry/utils/ddp.py +103 -0
  28. foundry/utils/instantiators.py +72 -0
  29. foundry/utils/logging.py +279 -0
  30. foundry/utils/rigid.py +1460 -0
  31. foundry/utils/rotation_augmentation.py +65 -0
  32. foundry/utils/squashfs.py +172 -0
  33. foundry/utils/torch.py +317 -0
  34. foundry/utils/weights.py +271 -0
  35. foundry/version.py +34 -0
  36. foundry_cli/__init__.py +3 -0
  37. foundry_cli/download_checkpoints.py +281 -0
  38. mpnn/__init__.py +1 -0
  39. mpnn/collate/feature_collator.py +265 -0
  40. mpnn/inference.py +53 -0
  41. mpnn/inference_engines/mpnn.py +549 -0
  42. mpnn/loss/nll_loss.py +122 -0
  43. mpnn/metrics/nll.py +369 -0
  44. mpnn/metrics/sequence_recovery.py +440 -0
  45. mpnn/model/layers/graph_embeddings.py +2372 -0
  46. mpnn/model/layers/message_passing.py +332 -0
  47. mpnn/model/layers/position_wise_feed_forward.py +44 -0
  48. mpnn/model/layers/positional_encoding.py +98 -0
  49. mpnn/model/mpnn.py +2632 -0
  50. mpnn/pipelines/mpnn.py +162 -0
  51. mpnn/samplers/samplers.py +167 -0
  52. mpnn/train.py +341 -0
  53. mpnn/trainers/mpnn.py +193 -0
  54. mpnn/transforms/feature_aggregation/mpnn.py +184 -0
  55. mpnn/transforms/feature_aggregation/polymer_ligand_interface.py +76 -0
  56. mpnn/transforms/feature_aggregation/token_encodings.py +132 -0
  57. mpnn/transforms/feature_aggregation/user_settings.py +347 -0
  58. mpnn/transforms/polymer_ligand_interface.py +164 -0
  59. mpnn/utils/inference.py +2397 -0
  60. mpnn/utils/probability.py +37 -0
  61. mpnn/utils/weights.py +309 -0
  62. rc_foundry-0.1.1.dist-info/METADATA +239 -0
  63. rc_foundry-0.1.1.dist-info/RECORD +180 -0
  64. rc_foundry-0.1.1.dist-info/WHEEL +4 -0
  65. rc_foundry-0.1.1.dist-info/entry_points.txt +5 -0
  66. rc_foundry-0.1.1.dist-info/licenses/LICENSE.md +28 -0
  67. rf3/__init__.py +3 -0
  68. rf3/_version.py +33 -0
  69. rf3/alignment.py +79 -0
  70. rf3/callbacks/dump_validation_structures.py +101 -0
  71. rf3/callbacks/metrics_logging.py +324 -0
  72. rf3/chemical.py +1529 -0
  73. rf3/cli.py +77 -0
  74. rf3/data/cyclic_transform.py +78 -0
  75. rf3/data/extra_xforms.py +36 -0
  76. rf3/data/ground_truth_template.py +463 -0
  77. rf3/data/paired_msa.py +206 -0
  78. rf3/data/pipeline_utils.py +128 -0
  79. rf3/data/pipelines.py +558 -0
  80. rf3/diffusion_samplers/inference_sampler.py +222 -0
  81. rf3/inference.py +65 -0
  82. rf3/inference_engines/__init__.py +5 -0
  83. rf3/inference_engines/rf3.py +735 -0
  84. rf3/kinematics.py +354 -0
  85. rf3/loss/af3_confidence_loss.py +515 -0
  86. rf3/loss/af3_losses.py +655 -0
  87. rf3/loss/loss.py +179 -0
  88. rf3/metrics/chiral.py +179 -0
  89. rf3/metrics/clashing_chains.py +68 -0
  90. rf3/metrics/distogram.py +421 -0
  91. rf3/metrics/lddt.py +523 -0
  92. rf3/metrics/metadata.py +43 -0
  93. rf3/metrics/metric_utils.py +192 -0
  94. rf3/metrics/predicted_error.py +134 -0
  95. rf3/metrics/rasa.py +108 -0
  96. rf3/metrics/selected_distances.py +91 -0
  97. rf3/model/RF3.py +527 -0
  98. rf3/model/RF3_blocks.py +92 -0
  99. rf3/model/RF3_structure.py +303 -0
  100. rf3/model/layers/af3_auxiliary_heads.py +255 -0
  101. rf3/model/layers/af3_diffusion_transformer.py +544 -0
  102. rf3/model/layers/attention.py +313 -0
  103. rf3/model/layers/layer_utils.py +127 -0
  104. rf3/model/layers/mlff.py +118 -0
  105. rf3/model/layers/outer_product.py +59 -0
  106. rf3/model/layers/pairformer_layers.py +783 -0
  107. rf3/model/layers/structure_bias.py +56 -0
  108. rf3/scoring.py +1787 -0
  109. rf3/symmetry/resolve.py +284 -0
  110. rf3/train.py +194 -0
  111. rf3/trainers/rf3.py +570 -0
  112. rf3/util_module.py +47 -0
  113. rf3/utils/frames.py +109 -0
  114. rf3/utils/inference.py +665 -0
  115. rf3/utils/io.py +198 -0
  116. rf3/utils/loss.py +72 -0
  117. rf3/utils/predict_and_score.py +165 -0
  118. rf3/utils/predicted_error.py +673 -0
  119. rf3/utils/recycling.py +42 -0
  120. rf3/validate.py +140 -0
  121. rfd3/.gitignore +7 -0
  122. rfd3/Makefile +76 -0
  123. rfd3/__init__.py +12 -0
  124. rfd3/callbacks.py +66 -0
  125. rfd3/cli.py +41 -0
  126. rfd3/constants.py +212 -0
  127. rfd3/engine.py +543 -0
  128. rfd3/inference/datasets.py +193 -0
  129. rfd3/inference/input_parsing.py +1123 -0
  130. rfd3/inference/legacy_input_parsing.py +717 -0
  131. rfd3/inference/parsing.py +165 -0
  132. rfd3/inference/symmetry/atom_array.py +298 -0
  133. rfd3/inference/symmetry/checks.py +241 -0
  134. rfd3/inference/symmetry/contigs.py +63 -0
  135. rfd3/inference/symmetry/frames.py +355 -0
  136. rfd3/inference/symmetry/symmetry_utils.py +398 -0
  137. rfd3/metrics/design_metrics.py +465 -0
  138. rfd3/metrics/hbonds_hbplus_metrics.py +308 -0
  139. rfd3/metrics/hbonds_metrics.py +389 -0
  140. rfd3/metrics/losses.py +325 -0
  141. rfd3/metrics/metrics_utils.py +118 -0
  142. rfd3/metrics/sidechain_metrics.py +349 -0
  143. rfd3/model/RFD3.py +105 -0
  144. rfd3/model/RFD3_diffusion_module.py +387 -0
  145. rfd3/model/cfg_utils.py +81 -0
  146. rfd3/model/inference_sampler.py +635 -0
  147. rfd3/model/layers/attention.py +577 -0
  148. rfd3/model/layers/block_utils.py +580 -0
  149. rfd3/model/layers/blocks.py +777 -0
  150. rfd3/model/layers/chunked_pairwise.py +377 -0
  151. rfd3/model/layers/encoders.py +417 -0
  152. rfd3/model/layers/layer_utils.py +197 -0
  153. rfd3/model/layers/pairformer_layers.py +128 -0
  154. rfd3/run_inference.py +45 -0
  155. rfd3/testing/debug.py +139 -0
  156. rfd3/testing/debug_utils.py +73 -0
  157. rfd3/testing/testing_utils.py +356 -0
  158. rfd3/train.py +194 -0
  159. rfd3/trainer/dump_validation_structures.py +154 -0
  160. rfd3/trainer/fabric_trainer.py +923 -0
  161. rfd3/trainer/recycling.py +42 -0
  162. rfd3/trainer/rfd3.py +485 -0
  163. rfd3/trainer/trainer_utils.py +502 -0
  164. rfd3/transforms/conditioning_base.py +508 -0
  165. rfd3/transforms/conditioning_utils.py +200 -0
  166. rfd3/transforms/design_transforms.py +807 -0
  167. rfd3/transforms/dna_crop.py +523 -0
  168. rfd3/transforms/hbonds.py +407 -0
  169. rfd3/transforms/hbonds_hbplus.py +246 -0
  170. rfd3/transforms/ncaa_transforms.py +153 -0
  171. rfd3/transforms/pipelines.py +632 -0
  172. rfd3/transforms/ppi_transforms.py +541 -0
  173. rfd3/transforms/rasa.py +116 -0
  174. rfd3/transforms/symmetry.py +76 -0
  175. rfd3/transforms/training_conditions.py +552 -0
  176. rfd3/transforms/util_transforms.py +498 -0
  177. rfd3/transforms/virtual_atoms.py +305 -0
  178. rfd3/utils/inference.py +648 -0
  179. rfd3/utils/io.py +245 -0
  180. rfd3/utils/vizualize.py +276 -0
@@ -0,0 +1,440 @@
1
+ from atomworks.ml.transforms.base import ConvertToTorch
2
+ from mpnn.collate.feature_collator import FeatureCollator
3
+ from mpnn.transforms.feature_aggregation.polymer_ligand_interface import (
4
+ FeaturizePolymerLigandInterfaceMask,
5
+ )
6
+ from mpnn.transforms.polymer_ligand_interface import ComputePolymerLigandInterface
7
+
8
+ from foundry.metrics.metric import Metric
9
+
10
+
11
class SequenceRecovery(Metric):
    """Sequence recovery (accuracy) metric for Protein/Ligand MPNN.

    Both the sampled sequence and the argmax sequence are compared
    against the ground-truth sequence, and the fraction of correctly
    predicted residues is reported for each.
    """

    def __init__(
        self,
        return_per_example_metrics=False,
        return_per_residue_metrics=False,
        **kwargs,
    ):
        """Initialize the SequenceRecovery metric.

        Args:
            return_per_example_metrics (bool): Also return per-example
                tensors alongside the aggregate metrics.
            return_per_residue_metrics (bool): Also return per-residue
                tensors alongside the aggregate metrics.
            **kwargs: Forwarded to the base Metric class.
        """
        super().__init__(**kwargs)
        self.return_per_example_metrics = return_per_example_metrics
        self.return_per_residue_metrics = return_per_residue_metrics

    @property
    def kwargs_to_compute_args(self):
        """Map compute() argument names to nested input-dict paths.

        NOTE(review): ``mask_for_loss`` is resolved under
        ``network_output`` while ``S`` lives under ``network_input`` —
        confirm this asymmetry is intentional.

        Returns:
            dict: Mapping from compute() argument names to nested
                dictionary key paths in the input kwargs.
        """
        return {
            "S": ("network_input", "input_features", "S"),
            "S_sampled": ("network_output", "decoder_features", "S_sampled"),
            "S_argmax": ("network_output", "decoder_features", "S_argmax"),
            "mask_for_loss": ("network_output", "input_features", "mask_for_loss"),
        }

    def get_per_residue_mask(self, mask_for_loss, **kwargs):
        """Return the [B, L] residue mask used for recovery computation.

        The base implementation passes ``mask_for_loss`` through
        unchanged; subclasses may intersect it with additional criteria
        (e.g. interface residues only).

        Args:
            mask_for_loss (torch.Tensor): [B, L] mask for loss.
            **kwargs: Extra inputs that subclasses may consume.

        Returns:
            torch.Tensor: [B, L] per-residue mask.
        """
        return mask_for_loss

    def compute_sequence_recovery_metrics(self, S, S_pred, per_residue_mask):
        """Score ``S_pred`` against ``S`` under ``per_residue_mask``.

        Args:
            S (torch.Tensor): [B, L] ground-truth sequence.
            S_pred (torch.Tensor): [B, L] predicted sequence.
            per_residue_mask (torch.Tensor): [B, L] residues to score.

        Returns:
            dict: Sequence recovery statistics:
                - mean_sequence_recovery ([1]): mean over examples that
                  have at least one valid residue.
                - sequence_recovery_per_example ([B]): accuracy per
                  example; undefined where no residue is valid.
                - correct_per_example ([B]): correct-prediction counts.
                - correct_predictions_per_residue ([B, L]): 1/0 hits,
                  zeroed outside the mask.
                - total_valid_per_example ([B]): valid-residue counts.
                - valid_examples_mask ([B]): True where the example has
                  at least one valid residue.
                - per_residue_mask ([B, L]): the (float) mask used.
        """
        mask = per_residue_mask.float()

        # [B] number of scored residues per example.
        valid_counts = mask.sum(dim=-1)

        # [B] examples that contribute to the mean.
        has_valid = valid_counts > 0

        # [B, L] 1.0 where the prediction matches, restricted to the mask.
        hits = (S_pred == S).float() * mask

        # [B] correct predictions per example.
        hits_per_example = hits.sum(dim=-1)

        # [B] accuracy per example; division by zero leaves an undefined
        # value for examples without valid residues (filtered out below).
        recovery_per_example = hits_per_example / valid_counts

        # Scalar mean over well-defined examples only.
        mean_recovery = recovery_per_example[has_valid].mean()

        return {
            "mean_sequence_recovery": mean_recovery,
            "sequence_recovery_per_example": recovery_per_example,
            "correct_per_example": hits_per_example,
            "correct_predictions_per_residue": hits,
            "total_valid_per_example": valid_counts,
            "valid_examples_mask": has_valid,
            "per_residue_mask": mask,
        }

    def compute(self, S, S_sampled, S_argmax, mask_for_loss, **kwargs):
        """Compute recovery for both the sampled and argmax sequences.

        Shapes: B = batch size, L = sequence length.

        Args:
            S (torch.Tensor): [B, L] ground-truth sequence.
            S_sampled (torch.Tensor): [B, L] sequence sampled from the
                predicted probabilities.
            S_argmax (torch.Tensor): [B, L] argmax of the predicted
                probabilities.
            mask_for_loss (torch.Tensor): [B, L] True for residues
                included in the loss calculation.
            **kwargs: Extra inputs consumed by subclasses (via
                ``get_per_residue_mask``).

        Returns:
            dict: Always contains
                ``mean_sequence_recovery_{sampled,argmax}`` (Python
                floats). When ``return_per_example_metrics`` is True,
                also the per-example tensors
                (``sequence_recovery_per_example_*``,
                ``correct_per_example_*``, ``total_valid_per_example``,
                ``valid_examples_mask``). When
                ``return_per_residue_metrics`` is True, also the
                per-residue tensors
                (``correct_predictions_per_residue_*``,
                ``per_residue_mask``).
        """
        # [B, L] mask for sequence recovery (possibly subclass-refined).
        residue_mask = self.get_per_residue_mask(mask_for_loss, **kwargs)

        # Score both prediction variants with the same mask.
        stats = {
            tag: self.compute_sequence_recovery_metrics(S, seq, residue_mask)
            for tag, seq in (("sampled", S_sampled), ("argmax", S_argmax))
        }

        # Aggregate means are detached and converted to Python floats.
        metric_dict = {
            f"mean_sequence_recovery_{tag}": stats[tag]["mean_sequence_recovery"]
            .detach()
            .item()
            for tag in ("sampled", "argmax")
        }

        if self.return_per_example_metrics:
            for key in ("sequence_recovery_per_example", "correct_per_example"):
                for tag in ("sampled", "argmax"):
                    metric_dict[f"{key}_{tag}"] = stats[tag][key]
            # Identical for both variants; taken from the sampled stats.
            for key in ("total_valid_per_example", "valid_examples_mask"):
                metric_dict[key] = stats["sampled"][key]

        if self.return_per_residue_metrics:
            for tag in ("sampled", "argmax"):
                metric_dict[f"correct_predictions_per_residue_{tag}"] = stats[tag][
                    "correct_predictions_per_residue"
                ]
            # Identical for both variants; taken from the sampled stats.
            metric_dict["per_residue_mask"] = stats["sampled"]["per_residue_mask"]

        return metric_dict
276
+
277
+
278
class InterfaceSequenceRecovery(SequenceRecovery):
    """Sequence recovery restricted to polymer-ligand interface residues.

    Behaves like SequenceRecovery except that the residue mask is
    additionally intersected with an interface mask derived from each
    example's atom array, and every returned metric name is prefixed
    with "interface_".
    """

    def __init__(
        self,
        interface_distance_threshold: float = 5.0,
        return_per_example_metrics: bool = False,
        return_per_residue_metrics: bool = False,
        **kwargs,
    ):
        """Initialize the InterfaceSequenceRecovery metric.

        Args:
            interface_distance_threshold (float): Distance cutoff in
                Angstroms for a residue to count as interfacial.
                Defaults to 5.0.
            return_per_example_metrics (bool): Also return per-example
                tensors alongside the aggregate metrics.
            return_per_residue_metrics (bool): Also return per-residue
                tensors alongside the aggregate metrics.
            **kwargs: Forwarded to the base Metric class.
        """
        super().__init__(
            return_per_example_metrics=return_per_example_metrics,
            return_per_residue_metrics=return_per_residue_metrics,
            **kwargs,
        )
        self.interface_distance_threshold = interface_distance_threshold

    @property
    def kwargs_to_compute_args(self):
        """Parent mapping extended with the atom array needed to locate
        the polymer-ligand interface.

        Returns:
            dict: Mapping from compute() argument names to nested
                dictionary key paths in the input kwargs.
        """
        mapping = dict(super().kwargs_to_compute_args)
        mapping["atom_array"] = ("network_input", "atom_array")
        return mapping

    def get_per_residue_mask(self, mask_for_loss, **kwargs):
        """Intersect ``mask_for_loss`` with the interface residue mask.

        Runs the polymer-ligand interface transforms on each example's
        atom array, collates the resulting per-example masks into a
        padded batch, and combines that batch with ``mask_for_loss``
        via logical AND.

        Args:
            mask_for_loss (torch.Tensor): [B, L] mask for loss.
            **kwargs: Must contain ``atom_array`` (one per example).

        Returns:
            torch.Tensor: [B, L] combined per-residue mask.

        Raises:
            ValueError: If ``atom_array`` is absent from ``kwargs``.
        """
        atom_arrays = kwargs.get("atom_array")
        if atom_arrays is None:
            raise ValueError(
                "atom_array is required for interface computation but was not found"
            )

        # Transform pipeline: detect interface residues, featurize them
        # into a mask, and convert the feature dict to torch tensors.
        find_interface = ComputePolymerLigandInterface(
            distance_threshold=self.interface_distance_threshold
        )
        featurize_mask = FeaturizePolymerLigandInterfaceMask()
        to_torch = ConvertToTorch(keys=["input_features"])

        # One interface mask per example.
        per_example_masks = []
        for array in atom_arrays:
            example = to_torch(featurize_mask(find_interface({"atom_array": array})))
            per_example_masks.append(
                example["input_features"]["polymer_ligand_interface_mask"]
            )

        # Pad/collate the ragged per-example masks into one [B, L] batch;
        # padded positions are filled with False.
        collator = FeatureCollator(
            default_padding={"polymer_ligand_interface_mask": False}
        )
        pseudo_batch = [
            {
                "input_features": {"polymer_ligand_interface_mask": m},
                "atom_array": None,  # Not needed for collation
            }
            for m in per_example_masks
        ]
        interface_mask = collator(pseudo_batch)["input_features"][
            "polymer_ligand_interface_mask"
        ]

        # Align device/dtype with the loss mask before combining.
        interface_mask = interface_mask.to(
            device=mask_for_loss.device, dtype=mask_for_loss.dtype
        )

        return mask_for_loss & interface_mask

    def compute(self, S, S_sampled, S_argmax, mask_for_loss, atom_array, **kwargs):
        """Compute interface-restricted sequence recovery.

        Args:
            S (torch.Tensor): [B, L] ground-truth sequence.
            S_sampled (torch.Tensor): [B, L] sampled sequence.
            S_argmax (torch.Tensor): [B, L] argmax sequence.
            mask_for_loss (torch.Tensor): [B, L] mask for loss.
            atom_array: Per-example atom arrays used to locate the
                interface.
            **kwargs: Passed through to the parent compute().

        Returns:
            dict: Parent metrics, each key prefixed with "interface_".
        """
        # Thread atom_array through kwargs so get_per_residue_mask,
        # called inside the parent compute(), can see it.
        base_metrics = super().compute(
            S, S_sampled, S_argmax, mask_for_loss,
            **{**kwargs, "atom_array": atom_array},
        )

        # Prefix every metric key with "interface_".
        return {f"interface_{key}": value for key, value in base_metrics.items()}