rc_foundry-0.1.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (180)
  1. foundry/__init__.py +57 -0
  2. foundry/callbacks/__init__.py +5 -0
  3. foundry/callbacks/callback.py +116 -0
  4. foundry/callbacks/health_logging.py +419 -0
  5. foundry/callbacks/metrics_logging.py +211 -0
  6. foundry/callbacks/timing_logging.py +67 -0
  7. foundry/callbacks/train_logging.py +278 -0
  8. foundry/common.py +108 -0
  9. foundry/constants.py +28 -0
  10. foundry/hydra/resolvers.py +77 -0
  11. foundry/inference_engines/base.py +235 -0
  12. foundry/inference_engines/checkpoint_registry.py +66 -0
  13. foundry/metrics/__init__.py +12 -0
  14. foundry/metrics/losses.py +30 -0
  15. foundry/metrics/metric.py +319 -0
  16. foundry/model/layers/blocks.py +47 -0
  17. foundry/testing/__init__.py +6 -0
  18. foundry/testing/fixtures.py +19 -0
  19. foundry/testing/pytest_hooks.py +15 -0
  20. foundry/trainers/fabric.py +923 -0
  21. foundry/training/EMA.py +67 -0
  22. foundry/training/checkpoint.py +61 -0
  23. foundry/training/schedulers.py +91 -0
  24. foundry/utils/alignment.py +86 -0
  25. foundry/utils/components.py +415 -0
  26. foundry/utils/datasets.py +405 -0
  27. foundry/utils/ddp.py +103 -0
  28. foundry/utils/instantiators.py +72 -0
  29. foundry/utils/logging.py +279 -0
  30. foundry/utils/rigid.py +1460 -0
  31. foundry/utils/rotation_augmentation.py +65 -0
  32. foundry/utils/squashfs.py +172 -0
  33. foundry/utils/torch.py +317 -0
  34. foundry/utils/weights.py +271 -0
  35. foundry/version.py +34 -0
  36. foundry_cli/__init__.py +3 -0
  37. foundry_cli/download_checkpoints.py +281 -0
  38. mpnn/__init__.py +1 -0
  39. mpnn/collate/feature_collator.py +265 -0
  40. mpnn/inference.py +53 -0
  41. mpnn/inference_engines/mpnn.py +549 -0
  42. mpnn/loss/nll_loss.py +122 -0
  43. mpnn/metrics/nll.py +369 -0
  44. mpnn/metrics/sequence_recovery.py +440 -0
  45. mpnn/model/layers/graph_embeddings.py +2372 -0
  46. mpnn/model/layers/message_passing.py +332 -0
  47. mpnn/model/layers/position_wise_feed_forward.py +44 -0
  48. mpnn/model/layers/positional_encoding.py +98 -0
  49. mpnn/model/mpnn.py +2632 -0
  50. mpnn/pipelines/mpnn.py +162 -0
  51. mpnn/samplers/samplers.py +167 -0
  52. mpnn/train.py +341 -0
  53. mpnn/trainers/mpnn.py +193 -0
  54. mpnn/transforms/feature_aggregation/mpnn.py +184 -0
  55. mpnn/transforms/feature_aggregation/polymer_ligand_interface.py +76 -0
  56. mpnn/transforms/feature_aggregation/token_encodings.py +132 -0
  57. mpnn/transforms/feature_aggregation/user_settings.py +347 -0
  58. mpnn/transforms/polymer_ligand_interface.py +164 -0
  59. mpnn/utils/inference.py +2397 -0
  60. mpnn/utils/probability.py +37 -0
  61. mpnn/utils/weights.py +309 -0
  62. rc_foundry-0.1.1.dist-info/METADATA +239 -0
  63. rc_foundry-0.1.1.dist-info/RECORD +180 -0
  64. rc_foundry-0.1.1.dist-info/WHEEL +4 -0
  65. rc_foundry-0.1.1.dist-info/entry_points.txt +5 -0
  66. rc_foundry-0.1.1.dist-info/licenses/LICENSE.md +28 -0
  67. rf3/__init__.py +3 -0
  68. rf3/_version.py +33 -0
  69. rf3/alignment.py +79 -0
  70. rf3/callbacks/dump_validation_structures.py +101 -0
  71. rf3/callbacks/metrics_logging.py +324 -0
  72. rf3/chemical.py +1529 -0
  73. rf3/cli.py +77 -0
  74. rf3/data/cyclic_transform.py +78 -0
  75. rf3/data/extra_xforms.py +36 -0
  76. rf3/data/ground_truth_template.py +463 -0
  77. rf3/data/paired_msa.py +206 -0
  78. rf3/data/pipeline_utils.py +128 -0
  79. rf3/data/pipelines.py +558 -0
  80. rf3/diffusion_samplers/inference_sampler.py +222 -0
  81. rf3/inference.py +65 -0
  82. rf3/inference_engines/__init__.py +5 -0
  83. rf3/inference_engines/rf3.py +735 -0
  84. rf3/kinematics.py +354 -0
  85. rf3/loss/af3_confidence_loss.py +515 -0
  86. rf3/loss/af3_losses.py +655 -0
  87. rf3/loss/loss.py +179 -0
  88. rf3/metrics/chiral.py +179 -0
  89. rf3/metrics/clashing_chains.py +68 -0
  90. rf3/metrics/distogram.py +421 -0
  91. rf3/metrics/lddt.py +523 -0
  92. rf3/metrics/metadata.py +43 -0
  93. rf3/metrics/metric_utils.py +192 -0
  94. rf3/metrics/predicted_error.py +134 -0
  95. rf3/metrics/rasa.py +108 -0
  96. rf3/metrics/selected_distances.py +91 -0
  97. rf3/model/RF3.py +527 -0
  98. rf3/model/RF3_blocks.py +92 -0
  99. rf3/model/RF3_structure.py +303 -0
  100. rf3/model/layers/af3_auxiliary_heads.py +255 -0
  101. rf3/model/layers/af3_diffusion_transformer.py +544 -0
  102. rf3/model/layers/attention.py +313 -0
  103. rf3/model/layers/layer_utils.py +127 -0
  104. rf3/model/layers/mlff.py +118 -0
  105. rf3/model/layers/outer_product.py +59 -0
  106. rf3/model/layers/pairformer_layers.py +783 -0
  107. rf3/model/layers/structure_bias.py +56 -0
  108. rf3/scoring.py +1787 -0
  109. rf3/symmetry/resolve.py +284 -0
  110. rf3/train.py +194 -0
  111. rf3/trainers/rf3.py +570 -0
  112. rf3/util_module.py +47 -0
  113. rf3/utils/frames.py +109 -0
  114. rf3/utils/inference.py +665 -0
  115. rf3/utils/io.py +198 -0
  116. rf3/utils/loss.py +72 -0
  117. rf3/utils/predict_and_score.py +165 -0
  118. rf3/utils/predicted_error.py +673 -0
  119. rf3/utils/recycling.py +42 -0
  120. rf3/validate.py +140 -0
  121. rfd3/.gitignore +7 -0
  122. rfd3/Makefile +76 -0
  123. rfd3/__init__.py +12 -0
  124. rfd3/callbacks.py +66 -0
  125. rfd3/cli.py +41 -0
  126. rfd3/constants.py +212 -0
  127. rfd3/engine.py +543 -0
  128. rfd3/inference/datasets.py +193 -0
  129. rfd3/inference/input_parsing.py +1123 -0
  130. rfd3/inference/legacy_input_parsing.py +717 -0
  131. rfd3/inference/parsing.py +165 -0
  132. rfd3/inference/symmetry/atom_array.py +298 -0
  133. rfd3/inference/symmetry/checks.py +241 -0
  134. rfd3/inference/symmetry/contigs.py +63 -0
  135. rfd3/inference/symmetry/frames.py +355 -0
  136. rfd3/inference/symmetry/symmetry_utils.py +398 -0
  137. rfd3/metrics/design_metrics.py +465 -0
  138. rfd3/metrics/hbonds_hbplus_metrics.py +308 -0
  139. rfd3/metrics/hbonds_metrics.py +389 -0
  140. rfd3/metrics/losses.py +325 -0
  141. rfd3/metrics/metrics_utils.py +118 -0
  142. rfd3/metrics/sidechain_metrics.py +349 -0
  143. rfd3/model/RFD3.py +105 -0
  144. rfd3/model/RFD3_diffusion_module.py +387 -0
  145. rfd3/model/cfg_utils.py +81 -0
  146. rfd3/model/inference_sampler.py +635 -0
  147. rfd3/model/layers/attention.py +577 -0
  148. rfd3/model/layers/block_utils.py +580 -0
  149. rfd3/model/layers/blocks.py +777 -0
  150. rfd3/model/layers/chunked_pairwise.py +377 -0
  151. rfd3/model/layers/encoders.py +417 -0
  152. rfd3/model/layers/layer_utils.py +197 -0
  153. rfd3/model/layers/pairformer_layers.py +128 -0
  154. rfd3/run_inference.py +45 -0
  155. rfd3/testing/debug.py +139 -0
  156. rfd3/testing/debug_utils.py +73 -0
  157. rfd3/testing/testing_utils.py +356 -0
  158. rfd3/train.py +194 -0
  159. rfd3/trainer/dump_validation_structures.py +154 -0
  160. rfd3/trainer/fabric_trainer.py +923 -0
  161. rfd3/trainer/recycling.py +42 -0
  162. rfd3/trainer/rfd3.py +485 -0
  163. rfd3/trainer/trainer_utils.py +502 -0
  164. rfd3/transforms/conditioning_base.py +508 -0
  165. rfd3/transforms/conditioning_utils.py +200 -0
  166. rfd3/transforms/design_transforms.py +807 -0
  167. rfd3/transforms/dna_crop.py +523 -0
  168. rfd3/transforms/hbonds.py +407 -0
  169. rfd3/transforms/hbonds_hbplus.py +246 -0
  170. rfd3/transforms/ncaa_transforms.py +153 -0
  171. rfd3/transforms/pipelines.py +632 -0
  172. rfd3/transforms/ppi_transforms.py +541 -0
  173. rfd3/transforms/rasa.py +116 -0
  174. rfd3/transforms/symmetry.py +76 -0
  175. rfd3/transforms/training_conditions.py +552 -0
  176. rfd3/transforms/util_transforms.py +498 -0
  177. rfd3/transforms/virtual_atoms.py +305 -0
  178. rfd3/utils/inference.py +648 -0
  179. rfd3/utils/io.py +245 -0
  180. rfd3/utils/vizualize.py +276 -0
mpnn/inference_engines/mpnn.py ADDED
@@ -0,0 +1,549 @@
+ import copy
+ from pathlib import Path
+ from typing import Any
+
+ import numpy as np
+ import torch
+ from atomworks.constants import (
+     DICT_THREE_TO_ONE,
+     PROTEIN_BACKBONE_ATOM_NAMES,
+     UNKNOWN_AA,
+ )
+ from atomworks.ml.utils.token import get_token_starts, spread_token_wise
+ from biotite.structure import AtomArray
+ from mpnn.collate.feature_collator import FeatureCollator
+ from mpnn.metrics.sequence_recovery import (
+     InterfaceSequenceRecovery,
+     SequenceRecovery,
+ )
+ from mpnn.model.mpnn import LigandMPNN, ProteinMPNN
+ from mpnn.pipelines.mpnn import build_mpnn_transform_pipeline
+ from mpnn.transforms.feature_aggregation.token_encodings import MPNN_TOKEN_ENCODING
+ from mpnn.utils.inference import (
+     MPNN_GLOBAL_INFERENCE_DEFAULTS,
+     MPNNInferenceInput,
+     MPNNInferenceOutput,
+     _absolute_path_or_none,
+ )
+ from mpnn.utils.weights import load_legacy_weights
+
+ from foundry.inference_engines.checkpoint_registry import REGISTERED_CHECKPOINTS
+ from foundry.metrics.metric import MetricManager
+ from foundry.utils.ddp import RankedLogger
+
+ ranked_logger = RankedLogger(__name__, rank_zero_only=True)
+
+
+ class MPNNInferenceEngine:
+     """Inference engine for ProteinMPNN/LigandMPNN."""
+
+     def __init__(
+         self,
+         *,
+         model_type: str = MPNN_GLOBAL_INFERENCE_DEFAULTS["model_type"],
+         checkpoint_path: str = MPNN_GLOBAL_INFERENCE_DEFAULTS["checkpoint_path"],
+         is_legacy_weights: bool = MPNN_GLOBAL_INFERENCE_DEFAULTS["is_legacy_weights"],
+         out_directory: str | None = MPNN_GLOBAL_INFERENCE_DEFAULTS["out_directory"],
+         write_fasta: bool = MPNN_GLOBAL_INFERENCE_DEFAULTS["write_fasta"],
+         write_structures: bool = MPNN_GLOBAL_INFERENCE_DEFAULTS["write_structures"],
+         device: str | torch.device | None = None,
+     ):
+         # Store the raw configuration.
+         self.model_type = model_type
+         self.is_legacy_weights = is_legacy_weights
+         self.out_directory = out_directory
+         self.write_fasta = write_fasta
+         self.write_structures = write_structures
+
+         # Allow a null checkpoint path when using a foundry-registered checkpoint.
+         # TODO: Currently this assumes the model type is the key in the checkpoint
+         # registry. Rework needed.
+         self.checkpoint_path = (
+             str(REGISTERED_CHECKPOINTS[self.model_type.replace("_", "")].get_default_path())
+             if not checkpoint_path
+             else checkpoint_path
+         )
+
+         # Determine the device.
+         if device is not None:
+             self.device = torch.device(device)
+         else:
+             self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+         # Set up the allowed model types.
+         self.allowed_model_types = {"protein_mpnn", "ligand_mpnn"}
+
+         # Validate the user configuration.
+         self._validate_all()
+
+         # Post-process the configuration (making paths absolute, etc.).
+         self._post_process_engine_config()
+
+         # Build and load the model.
+         self.model = self._build_and_load_model().to(self.device)
+
+         # Construct the metrics manager.
+         self.metrics = self._build_metrics_manager()
+
+     def _validate_model_config(self) -> None:
+         """Validate model-type and checkpoint-related configuration."""
+         # Model type.
+         if self.model_type not in self.allowed_model_types:
+             raise ValueError(
+                 f"model_type must be one of {self.allowed_model_types}; "
+                 f"got {self.model_type!r}"
+             )
+
+         # Checkpoint path.
+         if not isinstance(self.checkpoint_path, str):
+             raise TypeError("checkpoint_path must be a string path.")
+
+         # Check that the checkpoint path exists.
+         ckpt_path = Path(_absolute_path_or_none(self.checkpoint_path))
+         if not ckpt_path.is_file():
+             raise FileNotFoundError(
+                 f"checkpoint_path does not exist: {self.checkpoint_path}"
+             )
+
+         # Legacy-weights flag.
+         if not isinstance(self.is_legacy_weights, bool):
+             raise TypeError("is_legacy_weights must be a bool.")
+
+     def _validate_output_config(self) -> None:
+         """Validate output-directory and writing-related configuration."""
+         # Output directory.
+         if self.out_directory is not None:
+             # Must be a string.
+             if not isinstance(self.out_directory, str):
+                 raise TypeError("out_directory must be a string when provided.")
+
+         # Boolean writing flags.
+         for name in ("write_fasta", "write_structures"):
+             value = getattr(self, name)
+             if not isinstance(value, bool):
+                 raise TypeError(f"{name} must be a bool.")
+
+             # If asked to write outputs, out_directory must be set.
+             if value and self.out_directory is None:
+                 raise ValueError(f"{name} is True, but out_directory is not set.")
+
+     def _validate_all(self) -> None:
+         """Run validation on the user-specified engine config variables."""
+         # Validate the model configuration.
+         self._validate_model_config()
+
+         # Validate the output configuration.
+         self._validate_output_config()
+
+     def _post_process_engine_config(self) -> None:
+         """Normalize paths into absolute paths."""
+         # Make the checkpoint path absolute.
+         self.checkpoint_path = _absolute_path_or_none(self.checkpoint_path)
+
+         # Make the output directory absolute.
+         if self.out_directory is not None:
+             self.out_directory = _absolute_path_or_none(self.out_directory)
+
+     def _build_and_load_model(self) -> torch.nn.Module:
+         # Instantiate the model architecture.
+         if self.model_type == "protein_mpnn":
+             model = ProteinMPNN()
+         elif self.model_type == "ligand_mpnn":
+             model = LigandMPNN()
+         else:
+             raise ValueError(f"Unsupported model_type: {self.model_type}")
+
+         # Load weights.
+         if self.is_legacy_weights:
+             ranked_logger.info("Loading legacy MPNN weights.")
+             load_legacy_weights(model, self.checkpoint_path)
+         else:
+             ranked_logger.info("Loading MPNN weights.")
+
+             # Load the checkpoint.
+             checkpoint = torch.load(
+                 self.checkpoint_path, map_location="cpu", weights_only=False
+             )
+
+             # Check that the checkpoint is a dict with a 'model' key.
+             if not isinstance(checkpoint, dict) or "model" not in checkpoint:
+                 raise TypeError("Expected checkpoint to be a dict with a 'model' key.")
+
+             state_dict = checkpoint["model"]
+
+             model.load_state_dict(state_dict, strict=True)
+
+         # Set the model to eval mode.
+         model.eval()
+
+         return model
+
+     def _build_metrics_manager(self) -> MetricManager:
+         """Build the metrics manager for inference."""
+
+         # Construct the metrics dict.
+         metrics: dict[str, Any] = {
+             "sequence_recovery": SequenceRecovery(return_per_example_metrics=True),
+         }
+         if self.model_type == "ligand_mpnn":
+             metrics["interface_sequence_recovery"] = InterfaceSequenceRecovery(
+                 return_per_example_metrics=True
+             )
+
+         # Construct the MetricManager.
+         metric_manager = MetricManager.from_metrics(metrics, raise_errors=True)
+
+         return metric_manager
+
+     # ------------------------------------------------------------------ #
+     # Public API
+     # ------------------------------------------------------------------ #
+     def run(
+         self,
+         *,
+         input_dicts: list[dict[str, Any]] | None = None,
+         atom_arrays: list[AtomArray] | None = None,
+     ) -> list[MPNNInferenceOutput]:
+         """Run inference and return a list of MPNNInferenceOutput objects.
+
+         Parameters
+         ----------
+         input_dicts:
+             Optional list of per-input JSON-like dictionaries (one per
+             input). If None, 'atom_arrays' must be provided.
+         atom_arrays:
+             Optional list of externally provided AtomArray objects. If given,
+             must align one-to-one with 'input_dicts'. If None, 'input_dicts'
+             must be sufficient to resolve structures internally.
+
+         Returns
+         -------
+         list[MPNNInferenceOutput]
+             A flat list of per-design MPNNInferenceOutput objects. Writing
+             of CIF/FASTA outputs is handled internally based on engine-level
+             configuration.
+         """
+         if input_dicts is None and atom_arrays is None:
+             raise ValueError(
+                 "At least one of 'input_dicts' or 'atom_arrays' must be provided."
+             )
+         if atom_arrays is not None and input_dicts is not None:
+             if len(atom_arrays) != len(input_dicts):
+                 raise ValueError(
+                     "'atom_arrays' and 'input_dicts' must have the same length."
+                 )
+
+         # Determine the number of inputs.
+         num_inputs = len(input_dicts) if input_dicts is not None else len(atom_arrays)
+         results: list[MPNNInferenceOutput] = []
+         for input_idx in range(num_inputs):
+             # Construct the per-input MPNNInferenceInput.
+             inference_input = MPNNInferenceInput.from_atom_array_and_dict(
+                 atom_array=atom_arrays[input_idx] if atom_arrays is not None else None,
+                 input_dict=input_dicts[input_idx] if input_dicts is not None else None,
+             )
+
+             # Optional per-input RNG seeding for deterministic sampling: the
+             # seed is set once per input, before its first batch runs.
+             seed = inference_input.input_dict["seed"]
+             if seed is not None:
+                 torch.manual_seed(seed)
+                 np.random.seed(seed)
+                 if torch.cuda.is_available():
+                     torch.cuda.manual_seed_all(seed)
+
+             # Run the batches for this input.
+             for batch_idx in range(inference_input.input_dict["number_of_batches"]):
+                 ranked_logger.info(
+                     f"Running MPNN inference for input {input_idx}, "
+                     f"batch {batch_idx}..."
+                 )
+
+                 # Run a single batch.
+                 result = self._run_batch(
+                     atom_array=inference_input.atom_array,
+                     input_dict=inference_input.input_dict,
+                     batch_idx=batch_idx,
+                 )
+                 results.extend(result)
+
+         # Write outputs if requested.
+         self._write_outputs(results)
+
+         return results
+
+     def _run_batch(
+         self,
+         atom_array: AtomArray,
+         input_dict: dict[str, Any],
+         batch_idx: int | None = None,
+     ) -> list[MPNNInferenceOutput]:
+         """
+         Run a single batch (possibly multiple designs) through the pipeline.
+
+         This function:
+         - builds the transform pipeline based on 'input_dict',
+         - runs the pipeline and collator,
+         - executes the model forward pass,
+         - decodes sequences and applies them to the pipeline-output
+           AtomArray,
+         - constructs 'MPNNInferenceOutput' objects.
+         """
+         # Override the default pipeline args from input_dict.
+         pipeline_args = {}
+         if input_dict["occupancy_threshold_sidechain"] is not None:
+             pipeline_args["occupancy_threshold_sidechain"] = input_dict[
+                 "occupancy_threshold_sidechain"
+             ]
+         if input_dict["occupancy_threshold_backbone"] is not None:
+             pipeline_args["occupancy_threshold_backbone"] = input_dict[
+                 "occupancy_threshold_backbone"
+             ]
+         if input_dict["undesired_res_names"] is not None:
+             pipeline_args["undesired_res_names"] = input_dict["undesired_res_names"]
+
+         # Construct the pipeline.
+         pipeline = build_mpnn_transform_pipeline(
+             model_type=self.model_type,
+             is_inference=True,
+             minimal_return=True,
+             device=self.device,
+             **pipeline_args,
+         )
+
+         # Construct the collator.
+         collator = FeatureCollator()
+
+         # Data dict for the pipeline: atom_array plus scalar user settings.
+         data: dict[str, Any] = {
+             "atom_array": atom_array.copy(),
+             # Scalar user settings.
+             "structure_noise": input_dict["structure_noise"],
+             "decode_type": input_dict["decode_type"],
+             "causality_pattern": input_dict["causality_pattern"],
+             "initialize_sequence_embedding_with_ground_truth": input_dict[
+                 "initialize_sequence_embedding_with_ground_truth"
+             ],
+             "atomize_side_chains": input_dict["atomize_side_chains"],
+             "repeat_sample_num": input_dict["repeat_sample_num"],
+             "features_to_return": input_dict["features_to_return"],
+         }
+
+         # Run the pipeline.
+         pipeline_output = pipeline(data)
+
+         # Construct the collated network input.
+         network_input = collator([pipeline_output])
+
+         # Run the model forward pass.
+         with torch.no_grad():
+             network_output = self.model(network_input)
+
+         # Compute metrics once per batch.
+         metrics_output = self.metrics(
+             network_input=network_input,
+             network_output=network_output,
+             extra_info={},
+         )
+
+         # Extract the sampled sequences.
+         # S_sampled: [B = batch size, L = sequence length]
+         S_sampled = (
+             network_output["decoder_features"]["S_sampled"].detach().cpu().numpy()
+         )
+         B, L = S_sampled.shape
+         if B != input_dict["batch_size"]:
+             raise ValueError(
+                 "Mismatch between network output batch size and input_dict batch_size."
+             )
+
+         # Extract the metrics.
+         sequence_recovery_per_design = (
+             metrics_output["sequence_recovery.sequence_recovery_per_example_sampled"]
+             .detach()
+             .cpu()
+             .numpy()
+         )
+         if self.model_type == "ligand_mpnn":
+             interface_sequence_recovery_per_design = (
+                 metrics_output[
+                     "interface_sequence_recovery.interface_sequence_recovery_per_example_sampled"
+                 ]
+                 .detach()
+                 .cpu()
+                 .numpy()
+             )
+         else:
+             interface_sequence_recovery_per_design = None
+
+         # Grab the index-to-token mapping from the token encoding.
+         idx_to_token = MPNN_TOKEN_ENCODING.idx_to_token
+
+         # Construct the output objects.
+         outputs: list[MPNNInferenceOutput] = []
+         for design_idx in range(input_dict["batch_size"]):
+             # Per design, copy the atom array.
+             design_atom_array = pipeline_output["atom_array"].copy()
+
+             # Grab the non-atomized atom- and token-level arrays. This mimics
+             # the token-level extraction logic in the pipeline, so it should
+             # yield a one-to-one mapping between decoded tokens and
+             # non-atomized residues.
+             design_non_atomized_array = design_atom_array[~design_atom_array.atomize]
+             design_non_atomized_token_starts = get_token_starts(
+                 design_non_atomized_array
+             )
+             design_non_atomized_token_level = design_non_atomized_array[
+                 design_non_atomized_token_starts
+             ]
+
+             # Create the res_name array for the design.
+             designed_resnames = np.array(
+                 [idx_to_token[int(token_idx)] for token_idx in S_sampled[design_idx]],
+                 dtype=design_atom_array.res_name.dtype,
+             )
+
+             # Sanity check: the decoded sequence length must match the number
+             # of non-atomized tokens.
+             if len(design_non_atomized_token_level) != len(designed_resnames):
+                 raise ValueError(
+                     "Mismatch between number of non-atomized tokens and "
+                     "decoded sequence length."
+                 )
+
+             # Spread token-level residue names back to atom level, but only
+             # over the non-atomized subset.
+             designed_resnames_atom = spread_token_wise(
+                 design_non_atomized_array,
+                 designed_resnames,
+             )
+
+             # Create a full res_name array.
+             full_resnames = design_atom_array.res_name.copy()
+             full_resnames[~design_atom_array.atomize] = designed_resnames_atom
+
+             # Overwrite with the designed residue names.
+             design_atom_array.set_annotation("res_name", full_resnames)
+
+             # Remove any non-atomized residue atoms that no longer belong
+             # (i.e. old side-chain atoms): keep any atom that is atomized,
+             # any backbone atom, and any atom in a fixed residue.
+             design_is_backbone_atom = np.isin(
+                 design_atom_array.atom_name,
+                 PROTEIN_BACKBONE_ATOM_NAMES,
+             )
+             if (
+                 "mpnn_designed_residue_mask"
+                 in design_atom_array.get_annotation_categories()
+             ):
+                 design_is_fixed_atom = ~design_atom_array.mpnn_designed_residue_mask
+             else:
+                 design_is_fixed_atom = np.zeros(len(design_atom_array), dtype=bool)
+             design_atom_array = design_atom_array[
+                 design_atom_array.atomize
+                 | design_is_backbone_atom
+                 | design_is_fixed_atom
+             ]
+
+             # Construct the one-letter sequence and recovery metrics for the
+             # output dict.
+             one_letter_seq = "".join(
+                 [
+                     DICT_THREE_TO_ONE.get(res_name, DICT_THREE_TO_ONE[UNKNOWN_AA])
+                     for res_name in designed_resnames
+                 ]
+             )
+             sequence_recovery = float(sequence_recovery_per_design[design_idx])
+             if interface_sequence_recovery_per_design is not None:
+                 ligand_interface_sequence_recovery = float(
+                     interface_sequence_recovery_per_design[design_idx]
+                 )
+             else:
+                 ligand_interface_sequence_recovery = None
+
+             # Build the output dict.
+             output_dict = {
+                 "batch_idx": batch_idx,
+                 "design_idx": design_idx,
+                 "designed_sequence": one_letter_seq,
+                 "sequence_recovery": sequence_recovery,
+                 "ligand_interface_sequence_recovery": (
+                     ligand_interface_sequence_recovery
+                 ),
+                 "model_type": self.model_type,
+                 "checkpoint_path": self.checkpoint_path,
+                 "is_legacy_weights": self.is_legacy_weights,
+             }
+
+             outputs.append(
+                 MPNNInferenceOutput(
+                     atom_array=design_atom_array,
+                     output_dict=output_dict,
+                     input_dict=copy.deepcopy(input_dict),
+                 )
+             )
+
+         return outputs
+
+     def _write_outputs(self, results: list[MPNNInferenceOutput]) -> None:
+         """Write CIF and/or FASTA outputs based on engine-level settings."""
+         out_directory = self.out_directory
+
+         # If writing was requested but no output directory is set, raise.
+         if not out_directory and (self.write_fasta or self.write_structures):
+             raise ValueError(
+                 "Output directory is not set, but writing of outputs was requested."
+             )
+         elif not out_directory:
+             # Nothing to do.
+             return
+
+         # Make the output directory if it does not exist.
+         out_dir_path = Path(out_directory)
+         out_dir_path.mkdir(parents=True, exist_ok=True)
+
+         if self.write_structures:
+             # One CIF per design.
+             for idx, result in enumerate(results):
+                 name = result.input_dict["name"]
+                 batch_idx = result.output_dict["batch_idx"]
+                 design_idx = result.output_dict["design_idx"]
+
+                 # Can't write without a name.
+                 if name is None:
+                     raise ValueError(
+                         f"Cannot write structure for result {idx}: 'name' is "
+                         "not set in input_dict."
+                     )
+
+                 # Construct the output file path.
+                 file_stem = f"{name}_b{batch_idx}_d{design_idx}"
+                 base_path = out_dir_path / file_stem
+
+                 # Use the MPNNInferenceOutput helper for writing.
+                 result.write_structure(
+                     base_path=base_path,
+                 )
+
+         # Write FASTA outputs if requested, one per input name.
+         if self.write_fasta:
+             # Group results by input name.
+             grouped: dict[str, list[MPNNInferenceOutput]] = {}
+             for result in results:
+                 name = result.input_dict["name"]
+
+                 # Can't write without a name.
+                 if name is None:
+                     raise ValueError(
+                         "Cannot write FASTA output: 'name' is not set in input_dict."
+                     )
+
+                 if name not in grouped:
+                     grouped[name] = []
+
+                 grouped[name].append(result)
+
+             # Write one FASTA file per input name.
+             for name, group in grouped.items():
+                 fasta_path = out_dir_path / f"{name}.fa"
+                 # Append mode so that multiple runs can accumulate designs.
+                 with fasta_path.open("a") as handle:
+                     for result in group:
+                         result.write_fasta(handle=handle)
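
For orientation, here is a minimal usage sketch of the engine above. It is illustrative only: the checkpoint path, input file, and output directory are placeholders, and it assumes that input-dict keys left unspecified are filled with defaults by MPNNInferenceInput.from_atom_array_and_dict (the keys shown, "name", "seed", "number_of_batches", and "batch_size", are ones run() and _write_outputs() read directly).

import biotite.structure.io as strucio

from mpnn.inference_engines.mpnn import MPNNInferenceEngine

# Placeholder checkpoint and output locations for illustration only.
engine = MPNNInferenceEngine(
    model_type="protein_mpnn",
    checkpoint_path="/path/to/protein_mpnn.pt",
    out_directory="./mpnn_designs",
    write_fasta=True,
    write_structures=True,
)

# Load a structure with biotite; featurization happens in the transform pipeline.
atom_array = strucio.load_structure("input.pdb")

results = engine.run(
    atom_arrays=[atom_array],
    input_dicts=[
        {"name": "my_design", "seed": 42, "number_of_batches": 2, "batch_size": 8}
    ],
)

for out in results:
    print(out.output_dict["designed_sequence"], out.output_dict["sequence_recovery"])

Each element of the returned list is one design: with number_of_batches=2 and batch_size=8 as above, run() would yield 16 MPNNInferenceOutput objects for the single input, since each batch contributes batch_size designs.
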
mpnn/loss/nll_loss.py ADDED
@@ -0,0 +1,122 @@
+ import torch
+ import torch.nn as nn
+
+
+ class LabelSmoothedNLLLoss(nn.Module):
+     def __init__(self, label_smoothing_eps=0.1, normalization_constant=6000.0):
+         """
+         Label-smoothed negative log likelihood loss for Protein/Ligand MPNN.
+
+         Args:
+             label_smoothing_eps (float): The label smoothing factor. Default
+                 is 0.1.
+             normalization_constant (float): The normalization constant for
+                 the loss. Instead of averaging per sample in the batch, or
+                 averaging across all tokens, the summed loss is divided by
+                 this constant. Default is 6000.0.
+         """
+         super().__init__()
+
+         self.label_smoothing_eps = label_smoothing_eps
+         self.normalization_constant = normalization_constant
+
+     def forward(self, network_input, network_output, loss_input):
+         """
+         Given the network_input (the same input_features passed to the
+         model), the network output, and the loss input, compute the loss.
+
+         Args:
+             network_input (dict): The input to the network.
+                 - input_features (dict): Contains the input features.
+                     - S (torch.Tensor): [B, L] - the sequence of residues.
+             network_output (dict): The output of the network, a dictionary
+                 containing several sub-dictionaries; the necessary
+                 sub-dictionaries and their needed keys are listed below:
+                 - input_features (dict): Contains the modified input features.
+                     - mask_for_loss (torch.Tensor): [B, L] - the mask for
+                       the loss computation.
+                 - decoder_features (dict): Contains the decoder features.
+                     - log_probs (torch.Tensor): [B, L, vocab_size] - the log
+                       probabilities for the sequence.
+             loss_input (dict): Dictionary containing additional inputs needed
+                 for the loss computation. Unused here.
+
+         Returns:
+             The loss and a dictionary containing the loss values.
+             - label_smoothed_nll_loss_agg (torch.Tensor): [1] - the
+               aggregated label-smoothed negative log likelihood loss, masked
+               by the loss mask, summed across the batch and length
+               dimensions, and normalized by the normalization constant. This
+               is the final loss value returned by the loss function.
+             - loss_dict (dict): A dictionary containing the loss outputs.
+                 - label_smoothed_nll_loss_per_residue (torch.Tensor): [B, L]
+                   - the per-residue label-smoothed negative log likelihood
+                   loss, masked by the loss mask.
+                 - label_smoothed_nll_loss_agg (torch.Tensor): [1] - identical
+                   to the aggregated loss value described above.
+         """
+         input_features = network_input["input_features"]
+
+         # Check that the input features contain the necessary keys.
+         if "S" not in input_features:
+             raise ValueError("Input features must contain 'S' key.")
+
+         # Check that the network output contains the necessary keys.
+         if "input_features" not in network_output:
+             raise ValueError("Network output must contain 'input_features' key.")
+         if "mask_for_loss" not in network_output["input_features"]:
+             raise ValueError(
+                 "Network output must contain 'mask_for_loss' key in "
+                 "'input_features'."
+             )
+         if "decoder_features" not in network_output:
+             raise ValueError("Network output must contain 'decoder_features' key.")
+         if "log_probs" not in network_output["decoder_features"]:
+             raise ValueError(
+                 "Network output must contain 'log_probs' key in "
+                 "'decoder_features'."
+             )
+
+         B, L, vocab_size = network_output["decoder_features"]["log_probs"].shape
+
+         # S_onehot [B, L, vocab_size] - the one-hot encoded sequence.
+         S_onehot = torch.nn.functional.one_hot(
+             input_features["S"], num_classes=vocab_size
+         ).float()
+
+         # label_smoothed_S_onehot [B, L, vocab_size] - the label-smoothed
+         # encoded sequence.
+         label_smoothed_S_onehot = (
+             1 - self.label_smoothing_eps
+         ) * S_onehot + self.label_smoothing_eps / vocab_size
+
+         # label_smoothed_nll_loss_per_residue [B, L] - the per-residue
+         # label-smoothed negative log likelihood loss, masked by the loss
+         # mask.
+         label_smoothed_nll_loss_per_residue = (
+             -torch.sum(
+                 label_smoothed_S_onehot
+                 * network_output["decoder_features"]["log_probs"],
+                 dim=-1,
+             )
+             * network_output["input_features"]["mask_for_loss"]
+         )
+
+         # label_smoothed_nll_loss_agg - the aggregated label-smoothed
+         # negative log likelihood loss, summed across the batch and length
+         # dimensions and normalized by the normalization constant. This is
+         # the final loss value returned by the loss function.
+         label_smoothed_nll_loss_agg = (
+             torch.sum(label_smoothed_nll_loss_per_residue) / self.normalization_constant
+         )
+
+         # Construct the output loss dictionary.
+         loss_dict = {
+             "label_smoothed_nll_loss_per_residue": label_smoothed_nll_loss_per_residue.detach(),
+             "label_smoothed_nll_loss_agg": label_smoothed_nll_loss_agg.detach(),
+         }
+
+         return label_smoothed_nll_loss_agg, loss_dict
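
As a sanity check on the smoothing algebra: the per-residue term above is the cross-entropy against the smoothed target (1 - eps) * onehot + eps / vocab_size, which is exactly what torch.nn.functional.cross_entropy computes via its label_smoothing argument. A self-contained sketch with random tensors (shapes and eps are arbitrary):

import torch
import torch.nn.functional as F

from mpnn.loss.nll_loss import LabelSmoothedNLLLoss

B, L, vocab_size = 2, 7, 21  # arbitrary batch, length, and vocabulary sizes
eps = 0.1

S = torch.randint(vocab_size, (B, L))   # ground-truth residue indices
logits = torch.randn(B, L, vocab_size)
log_probs = F.log_softmax(logits, dim=-1)
mask = torch.ones(B, L)                 # score every position

loss_fn = LabelSmoothedNLLLoss(label_smoothing_eps=eps, normalization_constant=6000.0)
agg, loss_dict = loss_fn(
    network_input={"input_features": {"S": S}},
    network_output={
        "input_features": {"mask_for_loss": mask},
        "decoder_features": {"log_probs": log_probs},
    },
    loss_input={},
)

# Reference: PyTorch's built-in label smoothing, computed per token.
ref = F.cross_entropy(
    logits.reshape(-1, vocab_size),
    S.reshape(-1),
    reduction="none",
    label_smoothing=eps,
).reshape(B, L)

assert torch.allclose(loss_dict["label_smoothed_nll_loss_per_residue"], ref, atol=1e-5)
assert torch.allclose(agg, ref.sum() / 6000.0, atol=1e-5)

The fixed 6000.0 divisor means the loss scale tracks the total number of scored tokens rather than being normalized away, consistent with the docstring's note that the constant replaces per-sample or per-token averaging.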