rc-foundry 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (180) hide show
  1. foundry/__init__.py +57 -0
  2. foundry/callbacks/__init__.py +5 -0
  3. foundry/callbacks/callback.py +116 -0
  4. foundry/callbacks/health_logging.py +419 -0
  5. foundry/callbacks/metrics_logging.py +211 -0
  6. foundry/callbacks/timing_logging.py +67 -0
  7. foundry/callbacks/train_logging.py +278 -0
  8. foundry/common.py +108 -0
  9. foundry/constants.py +28 -0
  10. foundry/hydra/resolvers.py +77 -0
  11. foundry/inference_engines/base.py +235 -0
  12. foundry/inference_engines/checkpoint_registry.py +66 -0
  13. foundry/metrics/__init__.py +12 -0
  14. foundry/metrics/losses.py +30 -0
  15. foundry/metrics/metric.py +319 -0
  16. foundry/model/layers/blocks.py +47 -0
  17. foundry/testing/__init__.py +6 -0
  18. foundry/testing/fixtures.py +19 -0
  19. foundry/testing/pytest_hooks.py +15 -0
  20. foundry/trainers/fabric.py +923 -0
  21. foundry/training/EMA.py +67 -0
  22. foundry/training/checkpoint.py +61 -0
  23. foundry/training/schedulers.py +91 -0
  24. foundry/utils/alignment.py +86 -0
  25. foundry/utils/components.py +415 -0
  26. foundry/utils/datasets.py +405 -0
  27. foundry/utils/ddp.py +103 -0
  28. foundry/utils/instantiators.py +72 -0
  29. foundry/utils/logging.py +279 -0
  30. foundry/utils/rigid.py +1460 -0
  31. foundry/utils/rotation_augmentation.py +65 -0
  32. foundry/utils/squashfs.py +172 -0
  33. foundry/utils/torch.py +317 -0
  34. foundry/utils/weights.py +271 -0
  35. foundry/version.py +34 -0
  36. foundry_cli/__init__.py +3 -0
  37. foundry_cli/download_checkpoints.py +281 -0
  38. mpnn/__init__.py +1 -0
  39. mpnn/collate/feature_collator.py +265 -0
  40. mpnn/inference.py +53 -0
  41. mpnn/inference_engines/mpnn.py +549 -0
  42. mpnn/loss/nll_loss.py +122 -0
  43. mpnn/metrics/nll.py +369 -0
  44. mpnn/metrics/sequence_recovery.py +440 -0
  45. mpnn/model/layers/graph_embeddings.py +2372 -0
  46. mpnn/model/layers/message_passing.py +332 -0
  47. mpnn/model/layers/position_wise_feed_forward.py +44 -0
  48. mpnn/model/layers/positional_encoding.py +98 -0
  49. mpnn/model/mpnn.py +2632 -0
  50. mpnn/pipelines/mpnn.py +162 -0
  51. mpnn/samplers/samplers.py +167 -0
  52. mpnn/train.py +341 -0
  53. mpnn/trainers/mpnn.py +193 -0
  54. mpnn/transforms/feature_aggregation/mpnn.py +184 -0
  55. mpnn/transforms/feature_aggregation/polymer_ligand_interface.py +76 -0
  56. mpnn/transforms/feature_aggregation/token_encodings.py +132 -0
  57. mpnn/transforms/feature_aggregation/user_settings.py +347 -0
  58. mpnn/transforms/polymer_ligand_interface.py +164 -0
  59. mpnn/utils/inference.py +2397 -0
  60. mpnn/utils/probability.py +37 -0
  61. mpnn/utils/weights.py +309 -0
  62. rc_foundry-0.1.1.dist-info/METADATA +239 -0
  63. rc_foundry-0.1.1.dist-info/RECORD +180 -0
  64. rc_foundry-0.1.1.dist-info/WHEEL +4 -0
  65. rc_foundry-0.1.1.dist-info/entry_points.txt +5 -0
  66. rc_foundry-0.1.1.dist-info/licenses/LICENSE.md +28 -0
  67. rf3/__init__.py +3 -0
  68. rf3/_version.py +33 -0
  69. rf3/alignment.py +79 -0
  70. rf3/callbacks/dump_validation_structures.py +101 -0
  71. rf3/callbacks/metrics_logging.py +324 -0
  72. rf3/chemical.py +1529 -0
  73. rf3/cli.py +77 -0
  74. rf3/data/cyclic_transform.py +78 -0
  75. rf3/data/extra_xforms.py +36 -0
  76. rf3/data/ground_truth_template.py +463 -0
  77. rf3/data/paired_msa.py +206 -0
  78. rf3/data/pipeline_utils.py +128 -0
  79. rf3/data/pipelines.py +558 -0
  80. rf3/diffusion_samplers/inference_sampler.py +222 -0
  81. rf3/inference.py +65 -0
  82. rf3/inference_engines/__init__.py +5 -0
  83. rf3/inference_engines/rf3.py +735 -0
  84. rf3/kinematics.py +354 -0
  85. rf3/loss/af3_confidence_loss.py +515 -0
  86. rf3/loss/af3_losses.py +655 -0
  87. rf3/loss/loss.py +179 -0
  88. rf3/metrics/chiral.py +179 -0
  89. rf3/metrics/clashing_chains.py +68 -0
  90. rf3/metrics/distogram.py +421 -0
  91. rf3/metrics/lddt.py +523 -0
  92. rf3/metrics/metadata.py +43 -0
  93. rf3/metrics/metric_utils.py +192 -0
  94. rf3/metrics/predicted_error.py +134 -0
  95. rf3/metrics/rasa.py +108 -0
  96. rf3/metrics/selected_distances.py +91 -0
  97. rf3/model/RF3.py +527 -0
  98. rf3/model/RF3_blocks.py +92 -0
  99. rf3/model/RF3_structure.py +303 -0
  100. rf3/model/layers/af3_auxiliary_heads.py +255 -0
  101. rf3/model/layers/af3_diffusion_transformer.py +544 -0
  102. rf3/model/layers/attention.py +313 -0
  103. rf3/model/layers/layer_utils.py +127 -0
  104. rf3/model/layers/mlff.py +118 -0
  105. rf3/model/layers/outer_product.py +59 -0
  106. rf3/model/layers/pairformer_layers.py +783 -0
  107. rf3/model/layers/structure_bias.py +56 -0
  108. rf3/scoring.py +1787 -0
  109. rf3/symmetry/resolve.py +284 -0
  110. rf3/train.py +194 -0
  111. rf3/trainers/rf3.py +570 -0
  112. rf3/util_module.py +47 -0
  113. rf3/utils/frames.py +109 -0
  114. rf3/utils/inference.py +665 -0
  115. rf3/utils/io.py +198 -0
  116. rf3/utils/loss.py +72 -0
  117. rf3/utils/predict_and_score.py +165 -0
  118. rf3/utils/predicted_error.py +673 -0
  119. rf3/utils/recycling.py +42 -0
  120. rf3/validate.py +140 -0
  121. rfd3/.gitignore +7 -0
  122. rfd3/Makefile +76 -0
  123. rfd3/__init__.py +12 -0
  124. rfd3/callbacks.py +66 -0
  125. rfd3/cli.py +41 -0
  126. rfd3/constants.py +212 -0
  127. rfd3/engine.py +543 -0
  128. rfd3/inference/datasets.py +193 -0
  129. rfd3/inference/input_parsing.py +1123 -0
  130. rfd3/inference/legacy_input_parsing.py +717 -0
  131. rfd3/inference/parsing.py +165 -0
  132. rfd3/inference/symmetry/atom_array.py +298 -0
  133. rfd3/inference/symmetry/checks.py +241 -0
  134. rfd3/inference/symmetry/contigs.py +63 -0
  135. rfd3/inference/symmetry/frames.py +355 -0
  136. rfd3/inference/symmetry/symmetry_utils.py +398 -0
  137. rfd3/metrics/design_metrics.py +465 -0
  138. rfd3/metrics/hbonds_hbplus_metrics.py +308 -0
  139. rfd3/metrics/hbonds_metrics.py +389 -0
  140. rfd3/metrics/losses.py +325 -0
  141. rfd3/metrics/metrics_utils.py +118 -0
  142. rfd3/metrics/sidechain_metrics.py +349 -0
  143. rfd3/model/RFD3.py +105 -0
  144. rfd3/model/RFD3_diffusion_module.py +387 -0
  145. rfd3/model/cfg_utils.py +81 -0
  146. rfd3/model/inference_sampler.py +635 -0
  147. rfd3/model/layers/attention.py +577 -0
  148. rfd3/model/layers/block_utils.py +580 -0
  149. rfd3/model/layers/blocks.py +777 -0
  150. rfd3/model/layers/chunked_pairwise.py +377 -0
  151. rfd3/model/layers/encoders.py +417 -0
  152. rfd3/model/layers/layer_utils.py +197 -0
  153. rfd3/model/layers/pairformer_layers.py +128 -0
  154. rfd3/run_inference.py +45 -0
  155. rfd3/testing/debug.py +139 -0
  156. rfd3/testing/debug_utils.py +73 -0
  157. rfd3/testing/testing_utils.py +356 -0
  158. rfd3/train.py +194 -0
  159. rfd3/trainer/dump_validation_structures.py +154 -0
  160. rfd3/trainer/fabric_trainer.py +923 -0
  161. rfd3/trainer/recycling.py +42 -0
  162. rfd3/trainer/rfd3.py +485 -0
  163. rfd3/trainer/trainer_utils.py +502 -0
  164. rfd3/transforms/conditioning_base.py +508 -0
  165. rfd3/transforms/conditioning_utils.py +200 -0
  166. rfd3/transforms/design_transforms.py +807 -0
  167. rfd3/transforms/dna_crop.py +523 -0
  168. rfd3/transforms/hbonds.py +407 -0
  169. rfd3/transforms/hbonds_hbplus.py +246 -0
  170. rfd3/transforms/ncaa_transforms.py +153 -0
  171. rfd3/transforms/pipelines.py +632 -0
  172. rfd3/transforms/ppi_transforms.py +541 -0
  173. rfd3/transforms/rasa.py +116 -0
  174. rfd3/transforms/symmetry.py +76 -0
  175. rfd3/transforms/training_conditions.py +552 -0
  176. rfd3/transforms/util_transforms.py +498 -0
  177. rfd3/transforms/virtual_atoms.py +305 -0
  178. rfd3/utils/inference.py +648 -0
  179. rfd3/utils/io.py +245 -0
  180. rfd3/utils/vizualize.py +276 -0
@@ -0,0 +1,2397 @@
1
+ """
2
+ CLI, JSON config handling, and structure-aware inputs for MPNN inference.
3
+
4
+ This module implements:
5
+
6
+ - Argument parser and CLI -> JSON builder.
7
+ - MPNNInferenceInput construction utilities.
8
+ """
9
+
10
+ import argparse
11
+ import ast
12
+ import copy
13
+ import json
14
+ import logging
15
+ import re
16
+ from dataclasses import dataclass
17
+ from os import PathLike
18
+ from pathlib import Path
19
+ from typing import Any
20
+
21
+ import numpy as np
22
+ from atomworks.io import parse
23
+ from atomworks.io.parser import STANDARD_PARSER_ARGS, parse_atom_array
24
+ from atomworks.io.utils.atom_array_plus import (
25
+ AtomArrayPlus,
26
+ as_atom_array_plus,
27
+ )
28
+ from atomworks.io.utils.io_utils import to_cif_file
29
+ from atomworks.ml.utils.token import get_token_starts, spread_token_wise
30
+ from biotite.structure import AtomArray
31
+ from mpnn.transforms.feature_aggregation.token_encodings import MPNN_TOKEN_ENCODING
32
+
33
+ logger = logging.getLogger(__name__)
34
+
35
# Defaults for the global (top-level) keys of the inference config JSON.
# A value of None means there is no usable default: the value must be
# supplied via the CLI or a config JSON (see `cli_to_json`, which requires
# model_type / checkpoint_path / is_legacy_weights when no config file is
# given).
MPNN_GLOBAL_INFERENCE_DEFAULTS: dict[str, Any] = {
    # Top-level Config JSON. When set, all other CLI flags are ignored.
    "config_json": None,
    # Model Type and Weights
    "checkpoint_path": None,
    "model_type": None,
    "is_legacy_weights": None,
    # Output controls
    "out_directory": None,
    "write_fasta": True,
    "write_structures": True,
}
47
+
48
# Defaults for each entry of the config's "inputs" list (one entry per
# structure to run). None generally means "use the downstream
# parser/pipeline default behavior" or "feature disabled".
MPNN_PER_INPUT_INFERENCE_DEFAULTS: dict[str, Any] = {
    # Structure Path and Name
    "structure_path": None,
    "name": None,
    # Sampling Parameters
    "seed": None,
    "batch_size": 1,
    "number_of_batches": 1,
    # Parser Overrides (override STANDARD_PARSER_ARGS when not None)
    "remove_ccds": [],
    "remove_waters": None,
    # Pipeline Setup Overrides
    "occupancy_threshold_sidechain": 0.0,
    "occupancy_threshold_backbone": 0.0,
    "undesired_res_names": [],
    # Scalar User Settings
    "structure_noise": 0.0,
    "decode_type": "auto_regressive",
    "causality_pattern": "auto_regressive",
    "initialize_sequence_embedding_with_ground_truth": False,
    "features_to_return": None,
    # Only applicable for LigandMPNN
    "atomize_side_chains": False,
    # Design scope - if all None, design all residues
    "fixed_residues": None,
    "designed_residues": None,
    "fixed_chains": None,
    "designed_chains": None,
    # Bias, Omission, and Pair Bias
    "bias": None,
    "bias_per_residue": None,
    "omit": ["UNK"],
    "omit_per_residue": None,
    "pair_bias": None,
    "pair_bias_per_residue_pair": None,
    # Temperature
    "temperature": 0.1,
    "temperature_per_residue": None,
    # Symmetry (residue-based groups and homo-oligomer chain groups are
    # mutually exclusive at the CLI level; see build_arg_parser)
    "symmetry_residues": None,
    "symmetry_residues_weights": None,
    "homo_oligomer_chains": None,
}
91
+
92
+
93
+ ################################################################################
94
+ # CLI / Arg parser
95
+ ################################################################################
96
+
97
+
98
def str2bool(v: str) -> bool:
    """Parse a boolean CLI argument.

    Accepts the original literal forms ("True"/"1" and "False"/"0") and,
    as a backward-compatible generalization, any case variant of
    "true"/"false" plus the common "yes"/"no" aliases.

    Args:
        v: Raw CLI token to interpret.

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If ``v`` is not a recognized boolean
            token.
    """
    normalized = v.strip().lower()
    if normalized in ("true", "1", "yes"):
        return True
    if normalized in ("false", "0", "no"):
        return False
    # Same error type/message shape as before so argparse reports it cleanly.
    raise argparse.ArgumentTypeError(f"Boolean value expected, got {v!r}")
106
+
107
+
108
+ def none_or_type(v: Any, specified_type) -> Any | None:
109
+ """
110
+ CLI type parser that turns 'None' into None. Otherwise, returns the value
111
+ cast to the given type. This function is useful for the parser/pipeline
112
+ override arguments where None has a special meaning (use default behavior).
113
+ """
114
+ if v == "None":
115
+ return None
116
+ return specified_type(v)
117
+
118
+
119
def build_arg_parser() -> argparse.ArgumentParser:
    """Build the MPNN inference arg parser.

    Returns:
        argparse.ArgumentParser: A parser exposing one flag per key of
        ``MPNN_GLOBAL_INFERENCE_DEFAULTS`` and
        ``MPNN_PER_INPUT_INFERENCE_DEFAULTS``. Defaults are read from
        those dicts so CLI behavior and JSON-config defaults stay in sync.
    """
    parser = argparse.ArgumentParser(
        description="MPNN JSON-driven inference CLI",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    # ---------------- Top-level Config JSON ---------------- #
    parser.add_argument(
        "--config_json",
        type=str,
        help=(
            "Path to existing JSON config file. When provided, all other CLI "
            "flags are parsed but ignored."
        ),
        default=MPNN_GLOBAL_INFERENCE_DEFAULTS["config_json"],
    )

    # ---------------- Model Type and Weights ---------------- #
    parser.add_argument(
        "--model_type",
        type=str,
        choices=["protein_mpnn", "ligand_mpnn"],
        help="Model type to use.",
        default=MPNN_GLOBAL_INFERENCE_DEFAULTS["model_type"],
    )
    parser.add_argument(
        "--checkpoint_path",
        type=str,
        help="Path to model checkpoint.",
        default=MPNN_GLOBAL_INFERENCE_DEFAULTS["checkpoint_path"],
    )
    parser.add_argument(
        "--is_legacy_weights",
        type=str2bool,
        choices=[True, False],
        help="Whether to interpret checkpoint as legacy-weight ordering.",
        default=MPNN_GLOBAL_INFERENCE_DEFAULTS["is_legacy_weights"],
    )

    # --------------- Output controls ---------------- #
    parser.add_argument(
        "--out_directory",
        type=str,
        help="Output directory for CIF/FASTA.",
        default=MPNN_GLOBAL_INFERENCE_DEFAULTS["out_directory"],
    )
    parser.add_argument(
        "--write_fasta",
        type=str2bool,
        choices=[True, False],
        help="Whether to write FASTA outputs.",
        default=MPNN_GLOBAL_INFERENCE_DEFAULTS["write_fasta"],
    )
    parser.add_argument(
        "--write_structures",
        type=str2bool,
        choices=[True, False],
        help="Whether to write designed structures (CIF).",
        default=MPNN_GLOBAL_INFERENCE_DEFAULTS["write_structures"],
    )

    # ---------------- Structure Path and Name ---------------- #
    parser.add_argument(
        "--structure_path",
        type=str,
        help="Path to structure file (CIF or PDB).",
        default=MPNN_PER_INPUT_INFERENCE_DEFAULTS["structure_path"],
    )
    parser.add_argument(
        "--name",
        type=str,
        help="Optional name / label for the input.",
        default=MPNN_PER_INPUT_INFERENCE_DEFAULTS["name"],
    )

    # ---------------- Sampling Parameters ---------------- #
    parser.add_argument(
        "--seed",
        type=int,
        help="Random seed for sampling.",
        default=MPNN_PER_INPUT_INFERENCE_DEFAULTS["seed"],
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        help=(
            "Batch size for sampling. At inference, this also controls "
            "the effective repeat_sample_num passed to the pipeline."
        ),
        default=MPNN_PER_INPUT_INFERENCE_DEFAULTS["batch_size"],
    )
    parser.add_argument(
        "--number_of_batches",
        type=int,
        help="Number of batches of size batch_size to draw.",
        default=MPNN_PER_INPUT_INFERENCE_DEFAULTS["number_of_batches"],
    )

    # ---------------- Parser overrides ---------------- #
    # These use none_or_type so the literal string 'None' round-trips to
    # Python None, meaning "defer to the parser's default behavior".
    parser.add_argument(
        "--remove_ccds",
        type=lambda v: none_or_type(v, str),
        help=(
            "Comma-separated list of CCD residue names to remove as solvents/"
            "crystallization components during parsing "
            "(overrides STANDARD_PARSER_ARGS). 'None' has special behavior: "
            "use the parser default behavior."
        ),
        default=MPNN_PER_INPUT_INFERENCE_DEFAULTS["remove_ccds"],
    )
    parser.add_argument(
        "--remove_waters",
        type=lambda v: none_or_type(v, str2bool),
        choices=[True, False, None],
        help=(
            "If set, override the parser default for removing water-like "
            "residues (overrides STANDARD_PARSER_ARGS). 'None' "
            "has special behavior: use the parser default behavior."
        ),
        default=MPNN_PER_INPUT_INFERENCE_DEFAULTS["remove_waters"],
    )

    # ---------------- Pipeline Setup Overrides ---------------- #
    parser.add_argument(
        "--occupancy_threshold_sidechain",
        type=lambda v: none_or_type(v, float),
        help=(
            "Sidechain occupancy threshold used in the MPNN pipeline. 'None' "
            "has special behavior: use the pipeline default behavior."
        ),
        default=MPNN_PER_INPUT_INFERENCE_DEFAULTS["occupancy_threshold_sidechain"],
    )
    parser.add_argument(
        "--occupancy_threshold_backbone",
        type=lambda v: none_or_type(v, float),
        help=(
            "Backbone occupancy threshold used in the MPNN pipeline. 'None' "
            "has special behavior: use the pipeline default behavior."
        ),
        default=MPNN_PER_INPUT_INFERENCE_DEFAULTS["occupancy_threshold_backbone"],
    )
    parser.add_argument(
        "--undesired_res_names",
        type=lambda v: none_or_type(v, str),
        help=(
            "JSON or comma-separated list of residue names to treat as "
            "undesired in the pipeline. 'None' has special behavior: use the "
            "pipeline default behavior."
        ),
        default=MPNN_PER_INPUT_INFERENCE_DEFAULTS["undesired_res_names"],
    )

    # ---------------- Scalar User Settings ---------------- #
    parser.add_argument(
        "--structure_noise",
        type=float,
        help=("Structure noise (Angstroms) used in user settings."),
        default=MPNN_PER_INPUT_INFERENCE_DEFAULTS["structure_noise"],
    )
    parser.add_argument(
        "--decode_type",
        type=str,
        choices=["auto_regressive", "teacher_forcing"],
        help=(
            "Decoding type for MPNN inference. "
            "\t- auto_regressive: use previously predicted residues for all "
            "previous positions when predicting each residue. This is the "
            "default for inference."
            "\t- teacher_forcing: use ground-truth residues from the structure "
            "for all previous positions when predicting each residue."
        ),
        default=MPNN_PER_INPUT_INFERENCE_DEFAULTS["decode_type"],
    )
    parser.add_argument(
        "--causality_pattern",
        type=str,
        choices=[
            "auto_regressive",
            "unconditional",
            "conditional",
            "conditional_minus_self",
        ],
        help=(
            "Causality pattern for decoding. "
            "\t- auto_regressive: each position attends to the sequence and "
            "decoder representation of all previously decoded positions. This "
            "is the default for inference."
            "\t- unconditional: each position does not attend to the sequence "
            "or decoder representation of any other positions (encoder "
            "representations only)."
            "\t- conditional: each position attends to the sequence and "
            "decoder representation of all other positions."
            "\t- conditional_minus_self: each position attends to the sequence "
            "and decoder representation of all other positions, except for "
            "itself (as a destination node)."
        ),
        default=MPNN_PER_INPUT_INFERENCE_DEFAULTS["causality_pattern"],
    )
    parser.add_argument(
        "--initialize_sequence_embedding_with_ground_truth",
        type=str2bool,
        choices=[True, False],
        help=(
            "Whether to initialize the sequence embedding with ground truth "
            "residues from the input structure. "
            "\t- False: initialize the sequence embedding with zeros. If doing "
            "auto-regressive decoding, initialize S_sampled with unknown "
            "residues. This is the default for inference."
            "\t- True: initialize the sequence embedding with the ground truth "
            "sequence from the input structure. If doing auto-regressive "
            "decoding, also initialize S_sampled with the ground truth. This "
            "affects the pair bias application."
        ),
        default=MPNN_PER_INPUT_INFERENCE_DEFAULTS[
            "initialize_sequence_embedding_with_ground_truth"
        ],
    )
    parser.add_argument(
        "--features_to_return",
        type=str,
        help=(
            "JSON dict for features_to_return; "
            'e.g. \'{"input_features": '
            '["mask_for_loss"], "decoder_features": ["log_probs"]}\''
        ),
        default=MPNN_PER_INPUT_INFERENCE_DEFAULTS["features_to_return"],
    )
    # Only applicable for LigandMPNN.
    parser.add_argument(
        "--atomize_side_chains",
        type=str2bool,
        choices=[True, False],
        help=(
            "Whether to atomize side chains of fixed residues. Only applicable "
            "for LigandMPNN."
        ),
        default=MPNN_PER_INPUT_INFERENCE_DEFAULTS["atomize_side_chains"],
    )

    # ---------------- Design scope (mutually exclusive) ---------------- #
    # Exactly one way of specifying the design scope may be used at a time.
    design_group = parser.add_mutually_exclusive_group(required=False)
    design_group.add_argument(
        "--fixed_residues",
        type=str,
        help=(
            'List of residue IDs to fix: e.g. \'["A35","B40","C52"]\' or "A35,B40,C52"'
        ),
        default=MPNN_PER_INPUT_INFERENCE_DEFAULTS["fixed_residues"],
    )
    design_group.add_argument(
        "--designed_residues",
        type=str,
        help=(
            "List of residue IDs to design: "
            'e.g. \'["A35","B40","C52"]\' or "A35,B40,C52"'
        ),
        default=MPNN_PER_INPUT_INFERENCE_DEFAULTS["designed_residues"],
    )
    design_group.add_argument(
        "--fixed_chains",
        type=str,
        help=('List of chain IDs to fix: e.g. \'["A","B"]\' or "A,B"'),
        default=MPNN_PER_INPUT_INFERENCE_DEFAULTS["fixed_chains"],
    )
    design_group.add_argument(
        "--designed_chains",
        type=str,
        help=('List of chain IDs to design: e.g. \'["A","B"]\' or "A,B"'),
        default=MPNN_PER_INPUT_INFERENCE_DEFAULTS["designed_chains"],
    )

    # ---------------- Bias, Omission, and Pair Bias ---------------- #
    parser.add_argument(
        "--bias",
        type=str,
        help='Bias dict: e.g. \'{"ALA": -1.0, "GLY": 0.5}\'',
        default=MPNN_PER_INPUT_INFERENCE_DEFAULTS["bias"],
    )
    parser.add_argument(
        "--bias_per_residue",
        type=str,
        help=(
            'Per-residue bias dict: e.g. \'{"A35": {"ALA": -2.0}}\'. Overwrites --bias.'
        ),
        default=MPNN_PER_INPUT_INFERENCE_DEFAULTS["bias_per_residue"],
    )
    parser.add_argument(
        "--omit",
        type=str,
        help=('List of residue types to omit: e.g. \'["ALA","GLY","UNK"]\'.'),
        default=MPNN_PER_INPUT_INFERENCE_DEFAULTS["omit"],
    )
    parser.add_argument(
        "--omit_per_residue",
        type=str,
        help=(
            "Per-residue list of residue types to omit: "
            'e.g. \'{"A35": ["ALA","GLY","UNK"]}\'. Overwrites --omit.'
        ),
        default=MPNN_PER_INPUT_INFERENCE_DEFAULTS["omit_per_residue"],
    )
    parser.add_argument(
        "--pair_bias",
        type=str,
        help=(
            "Controls the bias applied due to residue selections at "
            "neighboring positions: "
            '\'{"ALA": {"GLY": -0.5}, "GLY": {"ALA": -0.5}}\''
        ),
        default=MPNN_PER_INPUT_INFERENCE_DEFAULTS["pair_bias"],
    )
    parser.add_argument(
        "--pair_bias_per_residue_pair",
        type=str,
        help=(
            "Per-residue-pair dict for controlling bias due to residue "
            "selections at neighboring positions: "
            '\'{"A35": {"B40": {"ALA": {"GLY": -1.0}}}}\' . Overwrites '
            "--pair_bias. Note that this is NOT applied symmetrically; if "
            "the outer residue ID corresponds to the first token; the inner "
            "residue ID corresponds to the second token. This should be read "
            'as follows: for residue pair (i,j) (e.g. ("A35","B40")), the '
            "inner dictionaries dictate that if residue i is assigned as the "
            'first token (e.g. "ALA"), then the bias for assigning residue j '
            'is the innermost dict (e.g. {"GLY": -1.0} ).'
        ),
        default=MPNN_PER_INPUT_INFERENCE_DEFAULTS["pair_bias_per_residue_pair"],
    )

    # ---------------- Temperature ---------------- #
    parser.add_argument(
        "--temperature",
        type=float,
        help=("Temperature for sampling."),
        default=MPNN_PER_INPUT_INFERENCE_DEFAULTS["temperature"],
    )
    parser.add_argument(
        "--temperature_per_residue",
        type=str,
        help=(
            "Per-residue temperature dict: e.g. '{\"A35\": 0.1}'. Overwrites "
            "--temperature."
        ),
        default=MPNN_PER_INPUT_INFERENCE_DEFAULTS["temperature_per_residue"],
    )

    # ---------------- Symmetry ---------------- #
    # Residue-level symmetry groups and homo-oligomer chain groups are
    # alternative ways to express the same constraint, hence the group.
    sym_group = parser.add_mutually_exclusive_group(required=False)
    sym_group.add_argument(
        "--symmetry_residues",
        type=str,
        help=(
            "Residue-based symmetry groups, each a list of residue IDs. "
            "Example: "
            '\'[["A35","B35"],["A40","B40","C40"]]\''
        ),
        default=MPNN_PER_INPUT_INFERENCE_DEFAULTS["symmetry_residues"],
    )
    sym_group.add_argument(
        "--homo_oligomer_chains",
        type=str,
        help=(
            "Homo-oligomer chain groups, each a list of chain IDs. "
            "Within each group, chains must have the same number of residues "
            "in the same order; residues at matching positions across chains "
            "are treated as symmetry-equivalent. Example: "
            '\'[["A","B","C"]]\''
        ),
        default=MPNN_PER_INPUT_INFERENCE_DEFAULTS["homo_oligomer_chains"],
    )

    # Symmetry weights
    parser.add_argument(
        "--symmetry_residues_weights",
        type=str,
        help=(
            "Optional list of symmetry weights matching the shape of "
            "symmetry_residues. Example: "
            "'[[1.0, 1.0], [1.0, 0.5, -0.5]]'. "
            "Ignored if homo_oligomer_chains is used."
        ),
        default=MPNN_PER_INPUT_INFERENCE_DEFAULTS["symmetry_residues_weights"],
    )

    return parser
505
+
506
+
507
+ ###############################################################################
508
+ # JSON builder
509
+ ###############################################################################
510
+
511
+
512
def parse_json_like(value: Any) -> Any:
    """Best-effort conversion of a CLI/JSON string into a Python object.

    The attempts are made in this order:

    1. ``json.loads`` for proper JSON.
    2. ``ast.literal_eval`` for plain Python literals.
    3. Comma-splitting into a list of stripped, non-empty strings.
    4. Give up and return the string itself.

    Non-string inputs (including None, and values already decoded from a
    config JSON) are returned untouched.

    Args:
        value: The input to parse.

    Returns:
        Any: The decoded object, or ``value`` unchanged when it is not a
        string.
        - None -> None
        - non-str -> returned unchanged
        - JSON-like string -> dict / list / ...
        - python literal string (list, dict, etc.) -> list | dict | ...
        - comma-separated list string -> list[str]
    """
    if not isinstance(value, str):
        # Also covers None: nothing to decode.
        return value

    # Attempt structured decoders in order of strictness.
    for decode in (json.loads, ast.literal_eval):
        try:
            return decode(value)
        except (ValueError, SyntaxError):
            # json.JSONDecodeError is a ValueError subclass; keep trying.
            continue

    # Last structured fallback: a bare comma-separated list.
    if "," in value:
        return [piece.strip() for piece in value.split(",") if piece.strip()]

    # Plain single string.
    return value
554
+
555
+
556
def parse_list_like(value: str | None) -> list[Any] | None:
    """Normalize a list-like CLI string into a Python list (or None).

    Delegates the actual decoding to :func:`parse_json_like`; a scalar
    result is wrapped in a one-element list so callers always receive
    either None or a list.
    """
    decoded = parse_json_like(value)
    if decoded is None or isinstance(decoded, list):
        return decoded
    # Scalar result -> singleton list.
    return [decoded]
569
+
570
+
571
+ def _absolute_path_or_none(path_str: str | None) -> str | None:
572
+ """
573
+ Convert a path string to an absolute path if the string is not None or
574
+ empty.
575
+ """
576
+ if not path_str:
577
+ return None
578
+ return str(Path(path_str).expanduser().resolve())
579
+
580
+
581
def cli_to_json(args: argparse.Namespace) -> dict[str, Any]:
    """Translate parsed CLI arguments into the top-level JSON config dict.

    When ``--config_json`` is supplied, that file is loaded and returned
    verbatim and every other CLI flag is ignored. Otherwise a config with
    a single "inputs" entry is assembled from the flags.

    Raises:
        ValueError: If no config file is given and any of model_type,
            checkpoint_path, is_legacy_weights, or structure_path is
            missing.
    """
    if args.config_json:
        with open(_absolute_path_or_none(args.config_json), "r") as handle:
            return json.load(handle)

    # Without a config file, these four settings must all come from the CLI.
    required = (
        args.model_type,
        args.checkpoint_path,
        args.is_legacy_weights,
        args.structure_path,
    )
    if any(item is None for item in required):
        raise ValueError(
            "When --config_json is not provided, "
            "--model_type, "
            "--checkpoint_path, "
            "--is_legacy_weights, "
            "--structure_path "
            "must all be specified."
        )

    # The CLI only ever describes a single input; build its entry first.
    per_input: dict[str, Any] = {
        # Structure Path and Name
        "structure_path": args.structure_path,
        "name": args.name,
        # Sampling Parameters
        "seed": args.seed,
        "batch_size": args.batch_size,
        "number_of_batches": args.number_of_batches,
        # Parser Overrides
        "remove_ccds": parse_list_like(args.remove_ccds),
        "remove_waters": args.remove_waters,
        # Pipeline Setup Overrides
        "occupancy_threshold_sidechain": args.occupancy_threshold_sidechain,
        "occupancy_threshold_backbone": args.occupancy_threshold_backbone,
        "undesired_res_names": parse_list_like(args.undesired_res_names),
        # Scalar User Settings
        "structure_noise": args.structure_noise,
        "decode_type": args.decode_type,
        "causality_pattern": args.causality_pattern,
        "initialize_sequence_embedding_with_ground_truth": args.initialize_sequence_embedding_with_ground_truth,
        "features_to_return": parse_json_like(args.features_to_return),
        # Only applicable for LigandMPNN
        "atomize_side_chains": args.atomize_side_chains,
        # Design scope - if all None, design all residues
        "fixed_residues": parse_list_like(args.fixed_residues),
        "designed_residues": parse_list_like(args.designed_residues),
        "fixed_chains": parse_list_like(args.fixed_chains),
        "designed_chains": parse_list_like(args.designed_chains),
        # Bias, Omission, and Pair Bias
        "bias": parse_json_like(args.bias),
        "bias_per_residue": parse_json_like(args.bias_per_residue),
        "omit": parse_json_like(args.omit),
        "omit_per_residue": parse_json_like(args.omit_per_residue),
        "pair_bias": parse_json_like(args.pair_bias),
        "pair_bias_per_residue_pair": parse_json_like(
            args.pair_bias_per_residue_pair
        ),
        # Temperature
        "temperature": args.temperature,
        "temperature_per_residue": parse_json_like(args.temperature_per_residue),
        # Symmetry
        "symmetry_residues": parse_json_like(args.symmetry_residues),
        "symmetry_residues_weights": parse_json_like(
            args.symmetry_residues_weights
        ),
        "homo_oligomer_chains": parse_json_like(args.homo_oligomer_chains),
    }

    return {
        # Model Type and Weights
        "model_type": args.model_type,
        "checkpoint_path": args.checkpoint_path,
        "is_legacy_weights": args.is_legacy_weights,
        # Output controls
        "out_directory": args.out_directory,
        "write_fasta": args.write_fasta,
        "write_structures": args.write_structures,
        # Singleton inputs list (CLI only supports single input at a time).
        "inputs": [per_input],
    }
670
+
671
+
672
+ ###############################################################################
673
+ # MPNNInferenceInput
674
+ ###############################################################################
675
+
676
+
677
+ @dataclass
678
+ class MPNNInferenceInput:
679
+ """Container for structure + input_dict passed into inference."""
680
+
681
+ atom_array: AtomArray
682
+ input_dict: dict[str, Any]
683
+
684
    @staticmethod
    def from_atom_array_and_dict(
        *,
        atom_array: AtomArray | None = None,
        input_dict: dict[str, Any] | None = None,
    ) -> "MPNNInferenceInput":
        """Construct from an optional AtomArray and/or input dict.

        This method is responsible for per-input sanitization and defaulting.

        Args:
            atom_array (AtomArray | None): optional structure. When omitted,
                the structure is loaded from input_dict["structure_path"].
            input_dict (dict[str, Any] | None): optional JSON-style config.
                Deep-copied, so the caller's dict is never mutated.

        Returns:
            MPNNInferenceInput: validated, defaulted, annotated input.

        NOTE: if the user provides both an atom array and an input dictionary,
        the atom array is treated as the authoritative source for annotations.
        If the user passes an atom array with annotations such as:
            - mpnn_designed_residue_mask
            - mpnn_temperature
            - mpnn_bias
            - mpnn_symmetry_equivalence_group
            - mpnn_symmetry_weight
            - mpnn_pair_bias
        then those annotations will be used directly, and any corresponding
        fields from the input dictionary will be ignored. If you would like
        to override those annotations, please either do so in the atom array
        or delete the annotations from the atom array before passing it in.
        """
        # Copy input dictionary (deep copy: nested bias/omit containers must
        # not alias the caller's objects).
        input_dict = copy.deepcopy(input_dict) if input_dict is not None else dict()

        # Copy atom array, then round-trip it through the parser so a
        # user-supplied structure gets the same normalization as a
        # file-loaded one.
        atom_array = atom_array.copy() if atom_array is not None else None
        parser_output = parse_atom_array(atom_array) if atom_array is not None else {}
        # NOTE(review): if parse_atom_array yields no "assemblies", a
        # user-supplied atom_array is silently discarded here and rebuilt
        # from structure_path below — confirm this is intended.
        atom_array = (
            parser_output["assemblies"]["1"][0]
            if len(parser_output.get("assemblies", {})) > 0
            else None
        )

        # Validate the input dictionary. structure_path is only mandatory
        # when no structure object survived the step above.
        MPNNInferenceInput._validate_all(
            input_dict=input_dict,
            require_structure_path=(atom_array is None),
        )

        # Apply centralized defaults (in place), without overwriting
        # user-provided values.
        MPNNInferenceInput.apply_defaults(input_dict)

        # Process structure_path, name, and repeat_sample_num (in place).
        MPNNInferenceInput.post_process_inputs(input_dict)

        # Construct AtomArray if not provided.
        if atom_array is None:
            atom_array = MPNNInferenceInput.build_atom_array(input_dict)

        # Annotate the atom array with per-residue information from the
        # input dictionary.
        annotated = MPNNInferenceInput.annotate_atom_array(atom_array, input_dict)
        logger.info(f"Annotated AtomArray has {annotated.array_length()} atoms ")
        return MPNNInferenceInput(atom_array=annotated, input_dict=input_dict)
+ @staticmethod
744
+ def _parse_id(
745
+ id_str: str,
746
+ res_num_required: bool = False,
747
+ res_num_allowed: bool = True,
748
+ ) -> tuple[str, int | None, str | None]:
749
+ """
750
+ Parse flexible id strings into (chain_id, res_num, insertion_code).
751
+
752
+ Supported formats
753
+ -----------------
754
+ - '<chain>' (e.g. 'A', 'AB')
755
+ - '<chain><integer_res_num>' (e.g. 'A35', 'AB12')
756
+ - '<chain><integer_res_num><icode>' (e.g. 'A35B', 'AB12C')
757
+
758
+ Args:
759
+ id_str (str): chain/res_num/insertion_code string.
760
+ res_num_required (bool): whether a residue number is required.
761
+ res_num_allowed (bool): whether a residue number is allowed.
762
+ Returns:
763
+ tuple[str, int | None, str | None]: the parsed
764
+ (chain_id, res_num, insertion_code), where res_num and/or
765
+ insertion_code can be None if not provided.
766
+
767
+ Examples
768
+ --------
769
+ 'A' -> ('A', None, None)
770
+ 'AB' -> ('AB', None, None)
771
+ 'A35' -> ('A', 35, None)
772
+ 'A35B' -> ('A', 35, 'B')
773
+ 'AB12C' -> ('AB', 12, 'C')
774
+ """
775
+ # Match:
776
+ # [A-Za-z]+ : 1+ letters for chain ID
777
+ # (\d+)? : optional integer residue number
778
+ # ([A-Za-z]*) : optional insertion code (0+ letters)
779
+ m = re.fullmatch(r"([A-Za-z]+)(\d+)?([A-Za-z]*)", id_str)
780
+
781
+ # Check for valid format.
782
+ if not m:
783
+ raise ValueError(
784
+ f"ID '{id_str}' must look like "
785
+ "'<letters>', '<letters><number>', or "
786
+ "'<letters><number><letters>'."
787
+ )
788
+
789
+ # Extract matched groups.
790
+ chain_id, res_num_str, insertion_code_str = m.groups()
791
+
792
+ # Handle residue number.
793
+ if res_num_str is None:
794
+ if res_num_required:
795
+ raise ValueError(f"ID '{id_str}' must contain a residue number.")
796
+ res_num = None
797
+ else:
798
+ try:
799
+ res_num = int(res_num_str)
800
+ except ValueError as exc:
801
+ raise ValueError(
802
+ f"ID '{id_str}' must contain a valid integer "
803
+ "residue index after the chain ID."
804
+ ) from exc
805
+
806
+ if not res_num_allowed:
807
+ raise ValueError(
808
+ f"ID '{id_str}' is not allowed to contain a residue number."
809
+ )
810
+
811
+ # Handle insertion code (None or "" mapped to None).
812
+ if not insertion_code_str:
813
+ insertion_code = None
814
+ else:
815
+ insertion_code = insertion_code_str
816
+
817
+ return chain_id, res_num, insertion_code
818
+
819
+ @staticmethod
820
+ def _mask_from_ids(
821
+ atom_array: AtomArray,
822
+ targets: list[str],
823
+ ) -> np.ndarray:
824
+ """
825
+ Return a boolean mask over entries in 'atom_array' matching any ID
826
+ specifier in 'targets'.
827
+
828
+ Each target string can be one of:
829
+ - '<chain>' (e.g. 'A')
830
+ - '<chain><integer_res_num>' (e.g. 'A35')
831
+ - '<chain><integer_res_num><icode>' (e.g. 'A35B')
832
+
833
+ Args:
834
+ atom_array (AtomArray): The AtomArray to mask.
835
+ targets (list[str]): List of ID strings to match.
836
+
837
+ Matching rules
838
+ --------------
839
+ - If only chain is provided: match all entries in that chain.
840
+ - If chain + res_num:
841
+ * requires that 'atom_array' has a 'res_id' annotation/field.
842
+ * additionally require residue-number equality.
843
+ - If chain + res_num + icode:
844
+ * requires both 'res_id' and 'ins_code' to be present.
845
+ * additionally require insertion-code equality.
846
+
847
+ Safety checks
848
+ -------------
849
+ - If a target specifies a residue number but 'res_id' is missing, a
850
+ ValueError is raised.
851
+ - If a target specifies an insertion code but 'ins_code' is missing, a
852
+ ValueError is raised.
853
+ - If ANY target matches zero entries in 'atom_array', a ValueError is
854
+ raised to avoid silently ending up with an empty specification.
855
+
856
+ Raises
857
+ ------
858
+ ValueError
859
+ For malformed IDs, missing required fields, or IDs that match
860
+ no entries in 'atom_array'.
861
+ """
862
+ mask = np.zeros(atom_array.array_length(), dtype=bool)
863
+
864
+ chain_ids = atom_array.chain_id
865
+ res_ids = getattr(atom_array, "res_id", None)
866
+ ins_codes = getattr(atom_array, "ins_code", None)
867
+
868
+ for id_str in targets:
869
+ chain_id, res_num, insertion_code = MPNNInferenceInput._parse_id(
870
+ id_str, res_num_required=False, res_num_allowed=True
871
+ )
872
+
873
+ # Always constrain by chain.
874
+ local_mask = chain_ids == chain_id
875
+
876
+ # Optionally constrain by residue number.
877
+ if res_num is not None:
878
+ if res_ids is None:
879
+ raise ValueError(
880
+ f"ID '{id_str}' specifies a residue number, but "
881
+ "the provided AtomArray does not have a 'res_id' "
882
+ "annotation loaded."
883
+ )
884
+ local_mask &= res_ids == res_num
885
+
886
+ # Optionally constrain by insertion code.
887
+ if insertion_code is not None:
888
+ if ins_codes is None:
889
+ raise ValueError(
890
+ f"ID '{id_str}' specifies an insertion code, but "
891
+ "the provided AtomArray does not have an 'ins_code' "
892
+ "annotation loaded."
893
+ )
894
+ local_mask &= ins_codes == insertion_code
895
+
896
+ # Disallow IDs that match nothing
897
+ if not np.any(local_mask):
898
+ raise ValueError(
899
+ f"ID '{id_str}' did not match any entries in the structure."
900
+ )
901
+
902
+ mask |= local_mask
903
+
904
+ return mask
905
+
906
    @staticmethod
    def _validate_structure_path_and_name(
        input_dict: dict[str, Any],
        require_structure_path: bool,
    ) -> None:
        """
        Validate structure_path and name fields.

        Args:
            input_dict (dict[str, Any]): Input dictionary containing fields.
            require_structure_path (bool): If True, structure_path must be
                provided and must exist on disk. This should typically be True
                when 'atom_array' is not provided.

        Raises:
            ValueError: if structure_path is required but missing.
            TypeError: if structure_path or name is present but not a string.
            FileNotFoundError: if structure_path does not resolve to an
                existing file.
        """
        structure_path = input_dict.get("structure_path")

        # Check presence of structure_path if required.
        if require_structure_path and structure_path is None:
            raise ValueError(
                "structure_path is required when atom_array is not provided."
            )

        # Check structure_path validity if provided.
        if structure_path is not None:
            if not isinstance(structure_path, str):
                raise TypeError("structure_path must be a string path when provided.")
            # Resolve to an absolute path before the on-disk existence check.
            structure_path_abs = _absolute_path_or_none(structure_path)
            if structure_path_abs is None or not Path(structure_path_abs).is_file():
                raise FileNotFoundError(
                    f"structure_path does not exist: {structure_path}"
                )

        # Check name type if provided.
        name = input_dict.get("name")
        if name is not None and not isinstance(name, str):
            raise TypeError("name must be a string when provided.")
+ @staticmethod
944
+ def _validate_sampling_parameters(input_dict: dict[str, Any]) -> None:
945
+ """
946
+ Validate seed / batch_size / number_of_batches and repeat_sample_num.
947
+ """
948
+ # Check seed, batch_size, number_of_batches types and values.
949
+ for key in ("seed", "batch_size", "number_of_batches"):
950
+ val = input_dict.get(key)
951
+ if val is None:
952
+ continue
953
+ if not isinstance(val, int):
954
+ raise TypeError(f"{key} must be an int when provided.")
955
+ if key in ("batch_size", "number_of_batches") and val <= 0:
956
+ raise ValueError(f"{key} must be positive if provided.")
957
+
958
+ # repeat_sample_num is derived internally from batch_size and must not
959
+ # appear in user JSON.
960
+ if "repeat_sample_num" in input_dict:
961
+ raise ValueError(
962
+ "repeat_sample_num is not allowed in the JSON config; "
963
+ "use batch_size instead."
964
+ )
965
+
966
+ @staticmethod
967
+ def _validate_parser_overrides(input_dict: dict[str, Any]) -> None:
968
+ """Validate parser override fields: remove_ccds, remove_waters."""
969
+
970
+ # Check that remove_ccds is a list of strings if provided.
971
+ remove_ccds = input_dict.get("remove_ccds")
972
+ if remove_ccds is not None:
973
+ if not isinstance(remove_ccds, list):
974
+ raise TypeError("remove_ccds must be a list of CCD residue names.")
975
+ for item in remove_ccds:
976
+ if not isinstance(item, str):
977
+ raise TypeError(
978
+ f"remove_ccds entries must be strings, got {type(item)}"
979
+ )
980
+
981
+ # Check that remove_waters is a boolean if provided.
982
+ remove_waters = input_dict.get("remove_waters")
983
+ if remove_waters is not None and not isinstance(remove_waters, bool):
984
+ raise TypeError("remove_waters must be a bool when provided.")
985
+
986
+ @staticmethod
987
+ def _validate_pipeline_override_fields(input_dict: dict[str, Any]) -> None:
988
+ """
989
+ Validate pipeline set-up override fields:
990
+ occupancy thresholds and undesired_res_names.
991
+ """
992
+ # Check occupancy threshold types.
993
+ for key in ("occupancy_threshold_sidechain", "occupancy_threshold_backbone"):
994
+ val = input_dict.get(key)
995
+ if val is None:
996
+ continue
997
+ if not isinstance(val, (int, float)):
998
+ raise TypeError(f"{key} must be numeric when provided.")
999
+
1000
+ # Check undesired_res_names is a list of strings if provided.
1001
+ undesired_res_names = input_dict.get("undesired_res_names")
1002
+ if undesired_res_names is not None:
1003
+ if not isinstance(undesired_res_names, list):
1004
+ raise TypeError("undesired_res_names must be a list when provided.")
1005
+ for item in undesired_res_names:
1006
+ if not isinstance(item, str):
1007
+ raise TypeError(
1008
+ f"undesired_res_names entries must be strings, got {type(item)}"
1009
+ )
1010
+
1011
    @staticmethod
    def _validate_scalar_user_settings(input_dict: dict[str, Any]) -> None:
        """
        Validate scalar user settings and related fields:
        structure_noise, decode_type, causality_pattern,
        initialize_sequence_embedding_with_ground_truth,
        features_to_return, atomize_side_chains.

        Args:
            input_dict (dict[str, Any]): JSON config for a single input.

        Raises:
            TypeError: if a field has the wrong type.
            ValueError: if decode_type or causality_pattern is not one of
                the allowed string values.
        """
        # Check type of structure_noise.
        if input_dict.get("structure_noise") is not None and not isinstance(
            input_dict["structure_noise"], (int, float)
        ):
            raise TypeError("structure_noise must be numeric when provided.")

        # Check type and value of decode_type (closed vocabulary).
        decode_type = input_dict.get("decode_type")
        if decode_type is not None:
            if not isinstance(decode_type, str):
                raise TypeError("decode_type must be a string when provided.")
            allowed = {"auto_regressive", "teacher_forcing"}
            if decode_type not in allowed:
                raise ValueError(
                    f"decode_type must be one of {sorted(allowed)}, got '{decode_type}'"
                )

        # Check type and value of causality_pattern (closed vocabulary).
        causality_pattern = input_dict.get("causality_pattern")
        if causality_pattern is not None:
            if not isinstance(causality_pattern, str):
                raise TypeError("causality_pattern must be a string when provided.")
            allowed = {
                "auto_regressive",
                "unconditional",
                "conditional",
                "conditional_minus_self",
            }
            if causality_pattern not in allowed:
                raise ValueError(
                    f"causality_pattern must be one of {sorted(allowed)}, "
                    f"got '{causality_pattern}'"
                )

        # Check type of initialize_sequence_embedding_with_ground_truth.
        initialize_sequence_embedding_with_ground_truth = input_dict.get(
            "initialize_sequence_embedding_with_ground_truth"
        )
        if (
            initialize_sequence_embedding_with_ground_truth is not None
            and not isinstance(initialize_sequence_embedding_with_ground_truth, bool)
        ):
            raise TypeError(
                "initialize_sequence_embedding_with_ground_truth must be a "
                "bool when provided."
            )

        # Check type of features_to_return (contents are not validated here).
        features_to_return = input_dict.get("features_to_return")
        if features_to_return is not None and not isinstance(features_to_return, dict):
            raise TypeError("features_to_return must be a dict when provided.")

        # Check type of atomize_side_chains.
        atomize_side_chains = input_dict.get("atomize_side_chains")
        if atomize_side_chains is not None and not isinstance(
            atomize_side_chains, bool
        ):
            raise TypeError("atomize_side_chains must be a bool when provided.")
+ @staticmethod
1078
+ def _validate_design_scope(input_dict: dict[str, Any]) -> None:
1079
+ """
1080
+ Validate fixed/designed residue and chain fields.
1081
+
1082
+ - Lists must actually be lists.
1083
+ - Residue IDs must parse as <chain><integer> (e.g. 'A35').
1084
+ - Chain IDs must be strings.
1085
+ - Mutually exclusive combinations are disallowed.
1086
+ """
1087
+ fixed_res = input_dict.get("fixed_residues")
1088
+ designed_res = input_dict.get("designed_residues")
1089
+ fixed_chains = input_dict.get("fixed_chains")
1090
+ designed_chains = input_dict.get("designed_chains")
1091
+
1092
+ # Check types + residue-id parsing
1093
+ for key in ("fixed_residues", "designed_residues"):
1094
+ val = input_dict.get(key)
1095
+ if val is None:
1096
+ continue
1097
+ if not isinstance(val, list):
1098
+ raise TypeError(f"{key} must be a list if provided.")
1099
+ for res_id in val:
1100
+ if not isinstance(res_id, str):
1101
+ raise TypeError(
1102
+ f"{key} entries must be residue-id strings, got {type(res_id)}"
1103
+ )
1104
+ MPNNInferenceInput._parse_id(res_id, res_num_required=True)
1105
+
1106
+ # Check chain ID types
1107
+ for key in ("fixed_chains", "designed_chains"):
1108
+ val = input_dict.get(key)
1109
+ if val is None:
1110
+ continue
1111
+ if not isinstance(val, list):
1112
+ raise TypeError(f"{key} must be a list if provided.")
1113
+ for chain_id in val:
1114
+ if not isinstance(chain_id, str):
1115
+ raise TypeError(
1116
+ f"{key} entries must be chain-id strings, got {type(chain_id)}"
1117
+ )
1118
+ MPNNInferenceInput._parse_id(chain_id, res_num_allowed=False)
1119
+
1120
+ # Mutual exclusivity rules
1121
+ if fixed_res is not None and designed_res is not None:
1122
+ raise ValueError("Cannot set both fixed_residues and designed_residues.")
1123
+ if fixed_chains is not None and designed_chains is not None:
1124
+ raise ValueError("Cannot set both fixed_chains and designed_chains.")
1125
+ if (fixed_res or designed_res) and (fixed_chains or designed_chains):
1126
+ raise ValueError(
1127
+ "Cannot mix residue-based and chain-based design constraints "
1128
+ "in the same input."
1129
+ )
1130
+
1131
    @staticmethod
    def _validate_bias_omit_and_pair_bias(input_dict: dict[str, Any]) -> None:
        """
        Validate global/per-residue bias & omit and pair-bias containers.

        This centralizes checks for:
            - bias / bias_per_residue
            - omit / omit_per_residue
            - pair_bias
            - pair_bias_per_residue_pair

        Raises:
            TypeError: if a container has the wrong shape or a bias value
                is non-numeric.
            ValueError: if a token name is not in the MPNN vocabulary or a
                residue ID does not parse as '<chain><integer>[<icode>]'.
        """
        # Token names in every container below are validated against the
        # shared MPNN vocabulary.
        token_to_idx = MPNN_TOKEN_ENCODING.token_to_idx

        # Check bias type, token membership, and value types.
        bias = input_dict.get("bias")
        if bias is not None and not isinstance(bias, dict):
            raise TypeError("bias must be a dict {token_name: bias_value}.")
        if isinstance(bias, dict):
            for token_name, value in bias.items():
                if token_name not in token_to_idx:
                    raise ValueError(
                        f"bias key '{token_name}' is not in the MPNN token vocabulary."
                    )
                if not isinstance(value, (int, float)):
                    raise TypeError(
                        f"bias['{token_name}'] must be numeric, got {type(value)}"
                    )

        # Check bias_per_residue type, residue-id parsing, token membership,
        # and value types.
        bias_per_residue = input_dict.get("bias_per_residue")
        if bias_per_residue is not None and not isinstance(bias_per_residue, dict):
            raise TypeError("bias_per_residue must be a dict.")
        if isinstance(bias_per_residue, dict):
            for res_id, res_id_bias in bias_per_residue.items():
                # Check residue ID type and parsing
                if not isinstance(res_id, str):
                    raise TypeError(
                        "bias_per_residue keys must be residue-id strings, "
                        f"got {type(res_id)}"
                    )
                MPNNInferenceInput._parse_id(res_id, res_num_required=True)

                # Check bias for this res_id.
                if not isinstance(res_id_bias, dict):
                    raise TypeError(
                        f"bias_per_residue[{res_id}] must be a dict, "
                        f"got {type(res_id_bias)}"
                    )
                for token_name, value in res_id_bias.items():
                    if token_name not in token_to_idx:
                        raise ValueError(
                            f"bias_per_residue[{res_id}] key '{token_name}' is "
                            "not in the MPNN token vocabulary."
                        )
                    if not isinstance(value, (int, float)):
                        raise TypeError(
                            "bias_per_residue"
                            f"[{res_id}]['{token_name}'] must be numeric, "
                            f"got {type(value)}"
                        )

        # Check omit type and token membership.
        omit = input_dict.get("omit")
        if omit is not None and not isinstance(omit, list):
            raise TypeError("omit must be a list of residue codes.")
        if isinstance(omit, list):
            for token_name in omit:
                if token_name not in token_to_idx:
                    raise ValueError(
                        f"omit entry '{token_name}' is not in the MPNN token "
                        "vocabulary."
                    )

        # Check omit_per_residue type, residue-id parsing, and token membership.
        omit_per_residue = input_dict.get("omit_per_residue")
        if omit_per_residue is not None and not isinstance(omit_per_residue, dict):
            raise TypeError("omit_per_residue must be a dict.")
        if isinstance(omit_per_residue, dict):
            for res_id, res_id_omit in omit_per_residue.items():
                # Check residue ID type and parsing.
                if not isinstance(res_id, str):
                    raise TypeError(
                        "omit_per_residue keys must be residue-id strings, "
                        f"got {type(res_id)}"
                    )
                MPNNInferenceInput._parse_id(res_id, res_num_required=True)

                # Check omit list for this res_id.
                if not isinstance(res_id_omit, list):
                    raise TypeError(
                        f"omit_per_residue[{res_id}] must be a list, "
                        f"got {type(res_id_omit)}"
                    )
                for token_name in res_id_omit:
                    if token_name not in token_to_idx:
                        raise ValueError(
                            f"omit_per_residue[{res_id}] entry '{token_name}' "
                            "is not in the MPNN token vocabulary."
                        )

        # Check pair_bias type, token membership, and value types.
        pair_bias = input_dict.get("pair_bias")
        if pair_bias is not None and not isinstance(pair_bias, dict):
            raise TypeError("pair_bias must be a nested dict when provided.")
        if isinstance(pair_bias, dict):
            for token_i, token_j_to_bias in pair_bias.items():
                # Check outer token membership.
                if token_i not in token_to_idx:
                    raise ValueError(
                        f"pair_bias key '{token_i}' is not in the MPNN token "
                        "vocabulary."
                    )

                # Check token_j_to_bias type.
                if not isinstance(token_j_to_bias, dict):
                    raise TypeError(
                        f"pair_bias['{token_i}'] must be a dict mapping "
                        "token_name_j -> bias."
                    )

                # Check inner token membership and value types.
                for token_j, value in token_j_to_bias.items():
                    if token_j not in token_to_idx:
                        raise ValueError(
                            f"pair_bias['{token_i}'] key '{token_j}' is not in "
                            "the MPNN token vocabulary."
                        )
                    if not isinstance(value, (int, float)):
                        raise TypeError(
                            f"pair_bias['{token_i}']['{token_j}'] must be "
                            f"numeric, got {type(value)}"
                        )

        # Check pair_bias_per_residue_pair: residue-id keys at two levels,
        # then the same token/value checks as pair_bias.
        pair_bias_per = input_dict.get("pair_bias_per_residue_pair")
        if pair_bias_per is not None and not isinstance(pair_bias_per, dict):
            raise TypeError("pair_bias_per_residue_pair must be a dict when provided.")
        if isinstance(pair_bias_per, dict):
            for res_id_i, res_id_j_to_pair_bias in pair_bias_per.items():
                # Check residue ID type and parsing.
                if not isinstance(res_id_i, str):
                    raise TypeError(
                        "pair_bias_per_residue_pair keys must be residue-id "
                        f"strings, got {type(res_id_i)}"
                    )
                MPNNInferenceInput._parse_id(res_id_i, res_num_required=True)

                # Check res_id_j_to_pair_bias type.
                if not isinstance(res_id_j_to_pair_bias, dict):
                    raise TypeError(
                        f"pair_bias_per_residue_pair['{res_id_i}'] must be a "
                        "dict mapping res_id_j -> dict."
                    )

                for res_id_j, i_j_pair_bias in res_id_j_to_pair_bias.items():
                    # Check residue ID type and parsing.
                    if not isinstance(res_id_j, str):
                        raise TypeError(
                            "pair_bias_per_residue_pair inner keys must be "
                            f"residue-id strings, got {type(res_id_j)}"
                        )
                    MPNNInferenceInput._parse_id(res_id_j, res_num_required=True)

                    # Check the res i, res j pair bias dict.
                    if not isinstance(i_j_pair_bias, dict):
                        raise TypeError(
                            f"pair_bias_per_residue_pair['{res_id_i}']"
                            f"['{res_id_j}'] must be "
                            "a dict mapping token_name_i -> dict."
                        )
                    for token_i, token_j_to_bias in i_j_pair_bias.items():
                        # Check outer token membership.
                        if token_i not in token_to_idx:
                            raise ValueError(
                                "pair_bias_per_residue_pair"
                                f"['{res_id_i}']['{res_id_j}'] key '{token_i}' "
                                "is not in the MPNN token vocabulary."
                            )

                        # Check token_j_to_bias type.
                        if not isinstance(token_j_to_bias, dict):
                            raise TypeError(
                                "pair_bias_per_residue_pair"
                                f"['{res_id_i}']['{res_id_j}']['{token_i}'] "
                                "must be a dict mapping token_name_j -> bias."
                            )

                        # Check inner token membership and value types.
                        for token_j, value in token_j_to_bias.items():
                            if token_j not in token_to_idx:
                                raise ValueError(
                                    "pair_bias_per_residue_pair"
                                    f"['{res_id_i}']['{res_id_j}']"
                                    f"['{token_i}'] key '{token_j}' is not in "
                                    "the MPNN token vocabulary."
                                )
                            if not isinstance(value, (int, float)):
                                raise TypeError(
                                    "pair_bias_per_residue_pair"
                                    f"['{res_id_i}']['{res_id_j}']"
                                    f"['{token_i}']['{token_j}'] "
                                    f"must be numeric, got {type(value)}"
                                )
+ @staticmethod
1337
+ def _validate_temperature(input_dict: dict[str, Any]) -> None:
1338
+ """Validate temperature scalars and per-residue mappings."""
1339
+ # Check temperature type if provided.
1340
+ temperature = input_dict.get("temperature")
1341
+ if temperature is not None and not isinstance(temperature, (int, float)):
1342
+ raise TypeError("temperature must be numeric when provided.")
1343
+
1344
+ # Check temperature_per_residue type, residue-id parsing, and value
1345
+ # types.
1346
+ temperature_per_residue = input_dict.get("temperature_per_residue")
1347
+ if temperature_per_residue is not None and not isinstance(
1348
+ temperature_per_residue, dict
1349
+ ):
1350
+ raise TypeError("temperature_per_residue must be a dict.")
1351
+ if isinstance(temperature_per_residue, dict):
1352
+ # Check each residue ID and temperature value.
1353
+ for res_id, res_id_temperature in temperature_per_residue.items():
1354
+ # Check residue ID type and parsing.
1355
+ if not isinstance(res_id, str):
1356
+ raise TypeError(
1357
+ "temperature_per_residue keys must be residue-id "
1358
+ f"strings, got {type(res_id)}"
1359
+ )
1360
+ MPNNInferenceInput._parse_id(res_id, res_num_required=True)
1361
+
1362
+ # Check temperature value type.
1363
+ if not isinstance(res_id_temperature, (int, float)):
1364
+ raise TypeError(
1365
+ f"temperature_per_residue[{res_id}] must be numeric; "
1366
+ f"got {type(res_id_temperature)}"
1367
+ )
1368
+
1369
    @staticmethod
    def _validate_symmetry(input_dict: dict[str, Any]) -> None:
        """
        Validate symmetry-related fields, including residue-id parsing
        and mutual exclusivity between residue-based symmetry and
        homo-oligomer chain symmetry.

        Raises:
            TypeError: if a symmetry container has the wrong shape or a
                weight is non-numeric.
            ValueError: for duplicate residue IDs across symmetry groups,
                weights without (or mismatching) symmetry_residues, chain
                groups of fewer than 2 chains, or when both symmetry styles
                are specified at once.
        """
        symmetry_residues = input_dict.get("symmetry_residues")
        symmetry_residues_weights = input_dict.get("symmetry_residues_weights")
        homo_oligomer_chains = input_dict.get("homo_oligomer_chains")

        # Check symmetry_residues type and residue-id parsing.
        if symmetry_residues is not None:
            if not isinstance(symmetry_residues, list):
                raise TypeError(
                    "symmetry_residues must be a list of lists when provided."
                )

            # Residue IDs must be unique across ALL groups, not just within
            # a single group.
            seen_res_ids = set()
            for symmetry_residue_group in symmetry_residues:
                if not isinstance(symmetry_residue_group, list):
                    raise TypeError("Each element of symmetry_residues must be a list.")
                for res_id in symmetry_residue_group:
                    # Check the residue id.
                    if not isinstance(res_id, str):
                        raise TypeError(
                            "symmetry_residues entries must be residue-id "
                            f"strings, got {type(res_id)}"
                        )
                    MPNNInferenceInput._parse_id(res_id, res_num_required=True)

                    # Check for duplicates across all groups.
                    if res_id in seen_res_ids:
                        raise ValueError(
                            f"symmetry_residues contains duplicate residue "
                            f"ID '{res_id}' across groups."
                        )
                    seen_res_ids.add(res_id)

        # Check symmetry_residues_weights type, shape, and value types.
        if symmetry_residues_weights is not None:
            if not isinstance(symmetry_residues_weights, list):
                raise TypeError(
                    "symmetry_residues_weights must be a list of lists when provided."
                )

            # Check that symmetry_residues is also provided, and that the
            # outer lengths match.
            if symmetry_residues is None:
                raise ValueError(
                    "symmetry_residues_weights provided without symmetry_residues."
                )
            if len(symmetry_residues_weights) != len(symmetry_residues):
                raise ValueError(
                    "symmetry_residues_weights must have the same outer length "
                    "as symmetry_residues."
                )

            # Check that each symmetry_residues_weights group is a list and
            # that the inner lengths match symmetry_residues, also check
            # weight values.
            for symmetry_residue_group, symmetry_residue_group_weights in zip(
                symmetry_residues, symmetry_residues_weights
            ):
                # Check group type.
                if not isinstance(symmetry_residue_group_weights, list):
                    raise TypeError(
                        "Each element of symmetry_residues_weights must be a list."
                    )

                # Length check (each residue needs exactly one weight).
                if len(symmetry_residue_group) != len(symmetry_residue_group_weights):
                    raise ValueError(
                        f"symmetry_residues group {symmetry_residue_group} "
                        "has different length than corresponding weights "
                        f"group {symmetry_residue_group_weights}."
                    )

                # Weight type check.
                for weight in symmetry_residue_group_weights:
                    if not isinstance(weight, (int, float)):
                        raise TypeError(
                            "symmetry_residues_weights entries must be "
                            f"numeric; got {type(weight)}"
                        )

        # Check homo_oligomer_chains type and chain-id parsing.
        if homo_oligomer_chains is not None:
            if not isinstance(homo_oligomer_chains, list):
                raise TypeError(
                    "homo_oligomer_chains must be a list of lists when provided."
                )

            # Check each chain group.
            for chain_group in homo_oligomer_chains:
                # Check group type.
                if not isinstance(chain_group, list):
                    raise TypeError(
                        "Each element of homo_oligomer_chains must be a list."
                    )

                # Check that each group has at least 2 chains (a symmetry
                # group of one chain is meaningless).
                if len(chain_group) < 2:
                    raise ValueError(
                        "Each homo_oligomer_chains group must contain at "
                        "least 2 chains."
                    )

                # Check each chain ID (no residue numbers allowed).
                for chain_id in chain_group:
                    if not isinstance(chain_id, str):
                        raise TypeError(
                            "homo_oligomer_chains entries must be chain-id "
                            f"strings, got {type(chain_id)}"
                        )
                    MPNNInferenceInput._parse_id(chain_id, res_num_allowed=False)

        # Check mutual exclusivity of symmetry_residues and
        # homo_oligomer_chains.
        if symmetry_residues is not None and homo_oligomer_chains is not None:
            raise ValueError(
                "Residue-based symmetry (symmetry_residues / "
                "symmetry_residues_weights) and homo-oligomer symmetry "
                "(homo_oligomer_chains) are mutually exclusive; "
                "please specify only one."
            )
    @staticmethod
    def _validate_all(
        input_dict: dict[str, Any],
        require_structure_path: bool,
    ) -> None:
        """
        Run all JSON-level validation routines on a single input dict.

        Args:
            input_dict (dict[str, Any]): JSON config for single input.
            require_structure_path (bool): If True, a valid on-disk
                structure_path must be present.

        Raises:
            TypeError, ValueError, FileNotFoundError: propagated from the
                first sub-validator that rejects the input.
        """
        # Sub-validators raise on the first violation; none of them mutate
        # input_dict.
        MPNNInferenceInput._validate_structure_path_and_name(
            input_dict=input_dict,
            require_structure_path=require_structure_path,
        )
        MPNNInferenceInput._validate_sampling_parameters(input_dict)
        MPNNInferenceInput._validate_parser_overrides(input_dict)
        MPNNInferenceInput._validate_pipeline_override_fields(input_dict)
        MPNNInferenceInput._validate_scalar_user_settings(input_dict)
        MPNNInferenceInput._validate_design_scope(input_dict)
        MPNNInferenceInput._validate_bias_omit_and_pair_bias(input_dict)
        MPNNInferenceInput._validate_temperature(input_dict)
        MPNNInferenceInput._validate_symmetry(input_dict)
+ @staticmethod
1523
+ def apply_defaults(input_dict: dict[str, Any]) -> None:
1524
+ """Apply JSON-level defaults. Modifies in place."""
1525
+ for key, default_value in MPNN_PER_INPUT_INFERENCE_DEFAULTS.items():
1526
+ if key not in input_dict:
1527
+ input_dict[key] = default_value
1528
+
1529
+ @staticmethod
1530
+ def post_process_inputs(input_dict: dict[str, Any]) -> None:
1531
+ """Apply post-processing to input dict. Modifies in place."""
1532
+ # Ensure structure_path is absolute.
1533
+ input_dict["structure_path"] = _absolute_path_or_none(
1534
+ input_dict["structure_path"]
1535
+ )
1536
+
1537
+ # Set name if missing.
1538
+ if input_dict["name"] is None:
1539
+ if input_dict["structure_path"] is not None:
1540
+ input_dict["name"] = Path(input_dict["structure_path"]).stem
1541
+ else:
1542
+ input_dict["name"] = "unnamed"
1543
+
1544
+ # Set repeat_sample_num from batch_size.
1545
+ input_dict["repeat_sample_num"] = input_dict["batch_size"]
1546
+
1547
+ @staticmethod
1548
+ def build_atom_array(input_dict: dict[str, Any]) -> AtomArray:
1549
+ """Build AtomArray from structure_path and parser overrides."""
1550
+ # Override parser args if specified.
1551
+ parser_args = dict(STANDARD_PARSER_ARGS)
1552
+ if input_dict["remove_ccds"] is not None:
1553
+ parser_args["remove_ccds"] = input_dict["remove_ccds"]
1554
+ if input_dict["remove_waters"] is not None:
1555
+ parser_args["remove_waters"] = input_dict["remove_waters"]
1556
+
1557
+ # Parse structure file.
1558
+ data = parse(
1559
+ filename=input_dict["structure_path"],
1560
+ keep_cif_block=True,
1561
+ **parser_args,
1562
+ )
1563
+
1564
+ # Use assembly 1 if present, otherwise use asymmetric unit.
1565
+ if "assemblies" in data:
1566
+ atom_array = data["assemblies"]["1"][0]
1567
+ else:
1568
+ atom_array = data["asym_unit"][0]
1569
+
1570
+ return atom_array
1571
+
1572
+ @staticmethod
1573
+ def _annotate_design_scope(
1574
+ atom_array: AtomArray,
1575
+ input_dict: dict[str, Any],
1576
+ ) -> None:
1577
+ """
1578
+ Attach 'mpnn_designed_residue_mask' from design-scope fields.
1579
+
1580
+ This function assumes that no existing 'mpnn_designed_residue_mask'
1581
+ annotation is present; callers should skip invocation if the annotation
1582
+ already exists.
1583
+
1584
+ Semantics
1585
+ ---------
1586
+ - If all design-scope fields are None/empty, this function is a no-op
1587
+ and leaves the implicit "design all residues" behavior.
1588
+ - If designed_* fields are present, they define the design mask
1589
+ (starting from all False).
1590
+ - Otherwise, we start from all True and then clear fixed_* fields.
1591
+ """
1592
+ fixed_residues = input_dict["fixed_residues"] or []
1593
+ designed_residues = input_dict["designed_residues"] or []
1594
+ fixed_chains = input_dict["fixed_chains"] or []
1595
+ designed_chains = input_dict["designed_chains"] or []
1596
+
1597
+ # If absolutely nothing is specified -> design all, and rely on the
1598
+ # implicit "design all residues" behavior in the model code.
1599
+ if not (fixed_residues or designed_residues or fixed_chains or designed_chains):
1600
+ return
1601
+
1602
+ # Gather the token-level array.
1603
+ token_starts = get_token_starts(atom_array)
1604
+ token_level = atom_array[token_starts]
1605
+ n_tokens = token_level.array_length()
1606
+
1607
+ # Initialize mask depending on which fields are present.
1608
+ if designed_residues or designed_chains:
1609
+ designed_residue_mask_token_level = np.zeros(n_tokens, dtype=bool)
1610
+ elif fixed_residues or fixed_chains:
1611
+ designed_residue_mask_token_level = np.ones(n_tokens, dtype=bool)
1612
+ else:
1613
+ raise RuntimeError("Unreachable state in _annotate_design_scope.")
1614
+
1615
+ # Residue-based constraints.
1616
+ if fixed_residues:
1617
+ mask = MPNNInferenceInput._mask_from_ids(token_level, fixed_residues)
1618
+ designed_residue_mask_token_level[mask] = False
1619
+
1620
+ if designed_residues:
1621
+ mask = MPNNInferenceInput._mask_from_ids(token_level, designed_residues)
1622
+ designed_residue_mask_token_level[mask] = True
1623
+
1624
+ # Chain-based constraints.
1625
+ if fixed_chains:
1626
+ mask = MPNNInferenceInput._mask_from_ids(token_level, fixed_chains)
1627
+ designed_residue_mask_token_level[mask] = False
1628
+
1629
+ if designed_chains:
1630
+ mask = MPNNInferenceInput._mask_from_ids(token_level, designed_chains)
1631
+ designed_residue_mask_token_level[mask] = True
1632
+
1633
+ # Spread to atom level.
1634
+ mpnn_designed_residue_mask = spread_token_wise(
1635
+ atom_array, designed_residue_mask_token_level
1636
+ )
1637
+
1638
+ # Annotate.
1639
+ atom_array.set_annotation(
1640
+ "mpnn_designed_residue_mask",
1641
+ mpnn_designed_residue_mask.astype(bool),
1642
+ )
1643
+
1644
+ @staticmethod
1645
+ def _annotate_temperature(
1646
+ atom_array: AtomArray,
1647
+ input_dict: dict[str, Any],
1648
+ ) -> None:
1649
+ """
1650
+ Attach 'mpnn_temperature' annotation from scalar + per-residue
1651
+ temperature settings.
1652
+
1653
+ This function assumes that no existing 'mpnn_temperature' annotation
1654
+ is present; callers should skip invocation if the annotation already
1655
+ exists.
1656
+
1657
+ Semantics
1658
+ ---------
1659
+ - Per-residue values override the global scalar for the specified
1660
+ residues (not additive).
1661
+ - If neither a global temperature nor any per-residue temperatures
1662
+ are specified, this function is a no-op.
1663
+ """
1664
+ temperature = input_dict["temperature"]
1665
+ temperature_per_residue = input_dict["temperature_per_residue"] or {}
1666
+
1667
+ # If there is no global or per-residue temperature, this is a no-op.
1668
+ if temperature is None and not temperature_per_residue:
1669
+ return
1670
+ elif temperature is None and temperature_per_residue:
1671
+ raise RuntimeError(
1672
+ "temperature_per_residue provided without global temperature."
1673
+ )
1674
+
1675
+ # Gather token-level array.
1676
+ token_starts = get_token_starts(atom_array)
1677
+ token_level = atom_array[token_starts]
1678
+ n_tokens = token_level.array_length()
1679
+
1680
+ # Create the global temperature token array.
1681
+ temperature_token_level = np.full(n_tokens, temperature, dtype=np.float32)
1682
+
1683
+ # Per-residue overrides for temperature.
1684
+ for res_id_str, res_id_temperature in temperature_per_residue.items():
1685
+ token_mask = MPNNInferenceInput._mask_from_ids(token_level, [res_id_str])
1686
+ temperature_token_level[token_mask] = res_id_temperature
1687
+
1688
+ # Spread to atom level.
1689
+ mpnn_temperature = spread_token_wise(
1690
+ atom_array, temperature_token_level.astype(np.float32)
1691
+ )
1692
+
1693
+ # Annotate.
1694
+ atom_array.set_annotation(
1695
+ "mpnn_temperature",
1696
+ mpnn_temperature.astype(np.float32),
1697
+ )
1698
+
1699
+ @staticmethod
1700
+ def _build_bias_vector_from_dict(
1701
+ bias_dict: dict[str, float] | None,
1702
+ ) -> np.ndarray:
1703
+ """Convert {token_name: bias} dict to vocab-length vector."""
1704
+ # Create a zero bias vector.
1705
+ vocab_size = MPNN_TOKEN_ENCODING.n_tokens
1706
+ bias_vector = np.zeros((vocab_size,), dtype=np.float32)
1707
+
1708
+ # If no bias dict, return zero vector.
1709
+ if not bias_dict:
1710
+ return bias_vector
1711
+
1712
+ # Populate the bias vector.
1713
+ token_to_idx = MPNN_TOKEN_ENCODING.token_to_idx
1714
+ for token_name, token_bias in bias_dict.items():
1715
+ bias_vector[token_to_idx[token_name]] = token_bias
1716
+ return bias_vector
1717
+
1718
+ @staticmethod
1719
+ def _annotate_bias_and_omit(
1720
+ atom_array: AtomArray,
1721
+ input_dict: dict[str, Any],
1722
+ omit_bias_value: float = -1e8,
1723
+ ) -> None:
1724
+ """
1725
+ Attach 'mpnn_bias' annotation from:
1726
+ - bias
1727
+ - bias_per_residue
1728
+ - omit
1729
+ - omit_per_residue
1730
+
1731
+ This function assumes that no existing 'mpnn_bias' annotation is
1732
+ present; callers should skip invocation if the annotation already
1733
+ exists.
1734
+
1735
+ Behavior
1736
+ --------
1737
+ - bias_per_residue overrides global bias for overlapping tokens.
1738
+ - omit forms its own bias matrix (global + per-residue), then is
1739
+ added to the bias matrix.
1740
+ - If the resulting bias matrix is all zeros, this function is a no-op
1741
+ and does not create an annotation.
1742
+ """
1743
+ bias = input_dict["bias"] or {}
1744
+ bias_per_residue = input_dict["bias_per_residue"] or {}
1745
+ omit = input_dict["omit"] or []
1746
+ omit_per_residue = input_dict["omit_per_residue"] or {}
1747
+
1748
+ # Gather token-level array.
1749
+ token_starts = get_token_starts(atom_array)
1750
+ token_level = atom_array[token_starts]
1751
+ n_tokens = token_level.array_length()
1752
+
1753
+ vocab_size = MPNN_TOKEN_ENCODING.n_tokens
1754
+
1755
+ # ---------------- Bias ---------------- #
1756
+ # Initialize a zero bias matrix.
1757
+ bias_token_level = np.zeros((n_tokens, vocab_size), dtype=np.float32)
1758
+
1759
+ # Compute the vector for the global bias setting.
1760
+ global_bias_vector = MPNNInferenceInput._build_bias_vector_from_dict(bias)
1761
+
1762
+ # If there is a global bias, set all residues to it.
1763
+ if np.any(global_bias_vector != 0.0):
1764
+ bias_token_level[:] = global_bias_vector
1765
+
1766
+ # Per-residue bias overrides.
1767
+ for res_id_str, res_id_bias in bias_per_residue.items():
1768
+ # Construct the per-residue bias vector and mask.
1769
+ per_residue_bias_vector = MPNNInferenceInput._build_bias_vector_from_dict(
1770
+ res_id_bias
1771
+ )
1772
+ token_mask = MPNNInferenceInput._mask_from_ids(token_level, [res_id_str])
1773
+
1774
+ # Apply the per-residue bias vector.
1775
+ bias_token_level[token_mask] = per_residue_bias_vector
1776
+
1777
+ # ---------------- Omit ---------------- #
1778
+ # Initialize a zero omit bias matrix.
1779
+ omit_bias_token_level = np.zeros((n_tokens, vocab_size), dtype=np.float32)
1780
+
1781
+ # Compute the vector for the global omit setting.
1782
+ global_omit_bias_vector = MPNNInferenceInput._build_bias_vector_from_dict(
1783
+ {token_name: omit_bias_value for token_name in omit}
1784
+ )
1785
+
1786
+ # If there is a global omit, set all residues to it.
1787
+ if np.any(global_omit_bias_vector != 0.0):
1788
+ omit_bias_token_level[:] = global_omit_bias_vector
1789
+
1790
+ # Per-residue omit overrides.
1791
+ for res_id_str, res_id_omit in omit_per_residue.items():
1792
+ # Construct the per-residue omit bias vector and mask.
1793
+ per_residue_omit_bias_vector = (
1794
+ MPNNInferenceInput._build_bias_vector_from_dict(
1795
+ {token_name: omit_bias_value for token_name in res_id_omit}
1796
+ )
1797
+ )
1798
+ token_mask = MPNNInferenceInput._mask_from_ids(token_level, [res_id_str])
1799
+
1800
+ # Apply the per-residue omit vector.
1801
+ omit_bias_token_level[token_mask] = per_residue_omit_bias_vector
1802
+
1803
+ # ---------------- Combine bias and omit ---------------- #
1804
+ # Add omit into the bias matrix.
1805
+ bias_token_level = bias_token_level + omit_bias_token_level
1806
+
1807
+ # No-op if there is no non-zero bias information.
1808
+ if not np.any(bias_token_level != 0.0):
1809
+ return
1810
+
1811
+ # Spread to atom level.
1812
+ mpnn_bias = spread_token_wise(atom_array, bias_token_level)
1813
+
1814
+ # Annotate.
1815
+ atom_array.set_annotation("mpnn_bias", mpnn_bias.astype(np.float32))
1816
+
1817
    @staticmethod
    def _annotate_symmetry(
        atom_array: AtomArray,
        input_dict: dict[str, Any],
    ) -> None:
        """
        Attach symmetry-related annotations:

        - 'mpnn_symmetry_equivalence_group' (int group IDs)
        - 'mpnn_symmetry_weight' (optional weights)

        This function assumes that no existing symmetry annotations are
        present; callers should skip invocation if either
        'mpnn_symmetry_equivalence_group' or 'mpnn_symmetry_weight' already
        exist on the atom array.

        Args:
            atom_array (AtomArray): Structure to annotate in place.
            input_dict (dict[str, Any]): Validated JSON config providing
                'symmetry_residues', 'symmetry_residues_weights', and
                'homo_oligomer_chains'.

        Semantics
        ---------
        - Supports either residue-based symmetry or homo-oligomer chain
          symmetry (not both, as enforced by validation).
        - If no symmetry information is provided, this function is a no-op.
        - If no weights are provided, 'mpnn_symmetry_weight' is not created.
        - Any residues or chains that are not explicitly included in a symmetry
          group are treated as individual symmetry groups, each with its own
          unique group ID. If weights are present, these singleton groups have
          weight 1.0.
        """
        symmetry_residues = input_dict["symmetry_residues"]
        symmetry_residues_weights = input_dict["symmetry_residues_weights"]
        homo_oligomer_chains = input_dict["homo_oligomer_chains"]

        # If no symmetry information, this is a no-op.
        if symmetry_residues is None and homo_oligomer_chains is None:
            return

        # Gather token-level array.
        token_starts = get_token_starts(atom_array)
        token_level = atom_array[token_starts]
        n_tokens = token_level.array_length()

        # By default, every token is its own symmetry group. We will overwrite
        # these IDs for any tokens that participate in an explicit symmetry
        # group. The absolute group ID values do not matter, only equality.
        symmetry_equivalent_group_token_level = np.arange(n_tokens, dtype=np.int32)

        # Optional weights: only created if weights are provided. If present,
        # default weight is 1.0 for all tokens; explicit symmetry groups
        # overwrite the weights for the tokens they cover.
        symmetry_weight_token_level = None

        # Residue-based symmetry
        if symmetry_residues is not None:
            # If weights are provided, initialize the weights array.
            if symmetry_residues_weights is not None:
                symmetry_weight_token_level = np.ones(n_tokens, dtype=np.float32)

            # Start assigning new group IDs above the current maximum. Each
            # explicit symmetry group gets a fresh ID, so all residues in the
            # group share that ID.
            next_group_id = int(symmetry_equivalent_group_token_level.max()) + 1

            for group_index, symmetry_residue_group in enumerate(symmetry_residues):
                # Assign a unique group ID for this explicit residue group.
                group_id = next_group_id
                next_group_id += 1

                # Get the corresponding weights for this group if present.
                # Weights are positional: group_index selects the weight list
                # that parallels this residue group.
                if symmetry_residues_weights is not None:
                    group_weights = symmetry_residues_weights[group_index]
                else:
                    group_weights = None

                # Assign group ID and weights to each residue in the group.
                for position, res_id_str in enumerate(symmetry_residue_group):
                    # Get the token mask for this residue ID.
                    token_mask = MPNNInferenceInput._mask_from_ids(
                        token_level, [res_id_str]
                    )

                    # Write the group ID.
                    symmetry_equivalent_group_token_level[token_mask] = group_id

                    # Write the weight if applicable (position indexes into
                    # the weight list that parallels the residue group).
                    if (
                        symmetry_weight_token_level is not None
                        and group_weights is not None
                    ):
                        weight_value = group_weights[position]
                        symmetry_weight_token_level[token_mask] = weight_value

        # Homo-oligomer chain symmetry. We rely on the implicit behavior of
        # mpnn_symmetry_weights = None -> weight of 1.0 for all tokens.
        elif homo_oligomer_chains is not None:
            # Start assigning new group IDs above the current maximum. Each
            # explicit symmetry group gets a fresh ID, so all residues in the
            # group share that ID.
            next_group_id = int(symmetry_equivalent_group_token_level.max()) + 1

            for chain_group in homo_oligomer_chains:
                # For each chain in the group, collect token indices.
                per_chain_indices = []
                for chain_id_str in chain_group:
                    chain_mask = MPNNInferenceInput._mask_from_ids(
                        token_level, [chain_id_str]
                    )
                    per_chain_indices.append(np.nonzero(chain_mask)[0])

                # All chains in the group must have the same number of tokens.
                lengths = {indices.size for indices in per_chain_indices}
                if len(lengths) != 1:
                    raise ValueError(
                        "All chains in a homo_oligomer_chains group must have "
                        "the same number of residues/tokens. "
                        f"Group {chain_group!r} has token counts {lengths}."
                    )

                n_positions = next(iter(lengths))

                # Interleave by position: tokens at the same position along
                # each chain are symmetry-equivalent.
                for position_index in range(n_positions):
                    group_id = next_group_id
                    next_group_id += 1

                    # Grab the token indices for the current group (one token
                    # from each chain at this position).
                    token_indices = [
                        int(indices[position_index]) for indices in per_chain_indices
                    ]

                    # Assign the group ID to the tokens.
                    symmetry_equivalent_group_token_level[token_indices] = group_id

        # Spread to atom level for equivalence groups.
        mpnn_symmetry_equivalence_group = spread_token_wise(
            atom_array, symmetry_equivalent_group_token_level
        )

        # Annotate equivalence groups.
        atom_array.set_annotation(
            "mpnn_symmetry_equivalence_group",
            mpnn_symmetry_equivalence_group.astype(np.int32),
        )

        # If symmetry weights are present:
        if symmetry_weight_token_level is not None:
            # Spread to atom level for weights.
            mpnn_symmetry_weight = spread_token_wise(
                atom_array, symmetry_weight_token_level
            )

            # Annotate weights.
            atom_array.set_annotation(
                "mpnn_symmetry_weight",
                mpnn_symmetry_weight.astype(np.float32),
            )
1972
+
1973
+ @staticmethod
1974
+ def _build_pair_bias_matrix_from_dict(
1975
+ pair_bias_dict: dict[str, dict[str, float]] | None,
1976
+ ) -> np.ndarray:
1977
+ """Convert {token_i: {token_j: bias}} into a [vocab, vocab] matrix."""
1978
+ # Create a zero pair-bias matrix.
1979
+ vocab_size = MPNN_TOKEN_ENCODING.n_tokens
1980
+ pair_bias_matrix = np.zeros((vocab_size, vocab_size), dtype=np.float32)
1981
+
1982
+ # If no pair-bias dict, return zero matrix.
1983
+ if not pair_bias_dict:
1984
+ return pair_bias_matrix
1985
+
1986
+ # Populate the pair-bias matrix.
1987
+ token_to_idx = MPNN_TOKEN_ENCODING.token_to_idx
1988
+ for token_i, token_j_to_bias in pair_bias_dict.items():
1989
+ for token_j, value in token_j_to_bias.items():
1990
+ pair_bias_matrix[token_to_idx[token_i], token_to_idx[token_j]] = value
1991
+
1992
+ return pair_bias_matrix
1993
+
1994
    @staticmethod
    def _annotate_pair_bias(
        atom_array: AtomArrayPlus,
        input_dict: dict[str, Any],
    ) -> None:
        """
        Attach 2D 'mpnn_pair_bias' annotation, with pairs stored as:
        - pairs_arr: [num_pairs, 2] int32 indices (atom indices)
        - values_arr: [num_pairs, vocab, vocab] float32 bias matrices

        This function assumes that no existing 'mpnn_pair_bias' 2D annotation
        is present; callers should skip invocation if the annotation already
        exists.

        Args:
            atom_array (AtomArrayPlus): Structure to annotate in place.
            input_dict (dict[str, Any]): Validated JSON config providing
                'pair_bias' and 'pair_bias_per_residue_pair'.

        Raises:
            ValueError: If no CA atoms exist, or if a residue ID does not
                map to exactly one CA atom.

        Semantics
        ---------
        - Global pair_bias: applies to all residue pairs (CA representatives).
        - pair_bias_per_residue_pair:
          * For residue pairs present here, the matrix overrides the global
            pair_bias matrix.
        - If there is no pair-bias information at all, or if all matrices are
          zero, this function is a no-op.
        """
        pair_bias = input_dict["pair_bias"]
        pair_bias_per_residue_pair = input_dict["pair_bias_per_residue_pair"]

        # If there is no pair-bias information, this is a no-op.
        if not pair_bias and not pair_bias_per_residue_pair:
            return

        # Identify CA atoms as residue-level representatives.
        ca_mask = atom_array.atom_name == "CA"
        ca_indices = np.nonzero(ca_mask)[0]
        ca_array = atom_array[ca_mask]
        n_tokens_ca = ca_array.array_length()

        # Check for presence of CA atoms.
        if n_tokens_ca == 0:
            raise ValueError(
                "No CA atoms found in the structure; cannot build pair bias."
            )

        # Compute the global pair-bias matrix.
        global_pair_bias_matrix = MPNNInferenceInput._build_pair_bias_matrix_from_dict(
            pair_bias
        )

        # Define a dictionary to keep track of (token_index_i, token_index_j)
        # -> pair_bias_matrix mappings.
        pair_bias_matrices = {}

        # If there is a global pair-bias matrix, apply it to all residue pairs
        # (note: includes i == j self-pairs).
        if np.any(global_pair_bias_matrix != 0.0):
            for token_index_i in range(n_tokens_ca):
                for token_index_j in range(n_tokens_ca):
                    pair_bias_matrices[(int(token_index_i), int(token_index_j))] = (
                        global_pair_bias_matrix
                    )

        # Per-residue-pair overrides.
        if pair_bias_per_residue_pair:
            for res_id_i, res_id_j_to_pair_bias in pair_bias_per_residue_pair.items():
                # Map res_id_i to token index of corresponding CA atom.
                mask_i = MPNNInferenceInput._mask_from_ids(ca_array, [res_id_i])
                i_indices = np.nonzero(mask_i)[0]
                if i_indices.size != 1:
                    raise ValueError(
                        f"Residue ID '{res_id_i}' maps to "
                        f"{i_indices.size} CA atoms; expected exactly 1."
                    )
                token_index_i = int(i_indices[0])

                # Iterate over all res_id_j entries for this res_id_i.
                for res_id_j, i_j_pair_bias in res_id_j_to_pair_bias.items():
                    # Map res_id_j to token index of corresponding CA atom.
                    mask_j = MPNNInferenceInput._mask_from_ids(ca_array, [res_id_j])
                    j_indices = np.nonzero(mask_j)[0]
                    if j_indices.size != 1:
                        raise ValueError(
                            f"Residue ID '{res_id_j}' maps to "
                            f"{j_indices.size} CA atoms; expected exactly 1."
                        )
                    token_index_j = int(j_indices[0])

                    # Build the pair-bias matrix for this specific pair.
                    pair_bias_matrix_ij = (
                        MPNNInferenceInput._build_pair_bias_matrix_from_dict(
                            i_j_pair_bias
                        )
                    )

                    # Skip if the matrix is all zeros (a zero override would
                    # otherwise mask the global matrix for this pair).
                    if not np.any(pair_bias_matrix_ij != 0.0):
                        continue

                    # Override global for this specific pair.
                    pair_bias_matrices[(token_index_i, token_index_j)] = (
                        pair_bias_matrix_ij
                    )

        # If there are no non-zero matrices, this is a no-op.
        if not pair_bias_matrices:
            return

        # Build the pairs and values arrays. Token indices (into the CA
        # subset) are translated back to absolute atom indices here.
        items = list(pair_bias_matrices.items())
        pairs_arr = np.asarray(
            [
                [int(ca_indices[token_index_i]), int(ca_indices[token_index_j])]
                for (token_index_i, token_index_j), _ in items
            ],
            dtype=np.int32,
        )
        values_arr = np.stack(
            [values for _, values in items],
            axis=0,
        ).astype(np.float32)

        # Annotate.
        atom_array.set_annotation_2d("mpnn_pair_bias", pairs_arr, values_arr)
2114
+
2115
+ @staticmethod
2116
+ def annotate_atom_array(
2117
+ atom_array: AtomArray | AtomArrayPlus,
2118
+ input_dict: dict[str, Any],
2119
+ ) -> AtomArray | AtomArrayPlus:
2120
+ """
2121
+ Attach all MPNN-specific annotations to an AtomArray based on the
2122
+ (already-validated, default-applied) JSON input dict.
2123
+
2124
+ This function possibly creates the following annotations:
2125
+ - 'mpnn_designed_residue_mask' (bool array)
2126
+ - 'mpnn_temperature' (float32 array)
2127
+ - 'mpnn_bias' (float32 array)
2128
+ - 'mpnn_symmetry_equivalence_group' (int32 array)
2129
+ - 'mpnn_symmetry_weight' (float32 array)
2130
+ - 'mpnn_pair_bias' (2D annotation)
2131
+
2132
+ NOTE:
2133
+ If an annotation already exists on the atom array, the corresponding
2134
+ JSON settings are ignored:
2135
+ - 'mpnn_designed_residue_mask' -> design scope fields are ignored
2136
+ - 'mpnn_temperature' -> temperature fields are ignored
2137
+ - 'mpnn_bias' -> bias/omit fields are ignored
2138
+ - 'mpnn_symmetry_equivalence_group' -> symmetry fields are ignored
2139
+ - 'mpnn_pair_bias' -> pair-bias fields are ignored
2140
+
2141
+ Raises:
2142
+ RuntimeError: If pre-existing 'mpnn_symmetry_weight' atom array
2143
+ annotation is found without a corresponding
2144
+ 'mpnn_symmetry_equivalence_group' annotation, raises an error.
2145
+ """
2146
+ # Discover existing annotations.
2147
+ annotation_categories = set(atom_array.get_annotation_categories())
2148
+ # 2D annotations, dependent on AtomArrayPlus.
2149
+ if isinstance(atom_array, AtomArrayPlus):
2150
+ annotation_2d_categories = set(atom_array.get_annotation_2d_categories())
2151
+ else:
2152
+ annotation_2d_categories = set()
2153
+
2154
+ # Design scope
2155
+ if "mpnn_designed_residue_mask" not in annotation_categories:
2156
+ MPNNInferenceInput._annotate_design_scope(atom_array, input_dict)
2157
+
2158
+ # Temperature
2159
+ if "mpnn_temperature" not in annotation_categories:
2160
+ MPNNInferenceInput._annotate_temperature(atom_array, input_dict)
2161
+
2162
+ # Bias / omit
2163
+ if "mpnn_bias" not in annotation_categories:
2164
+ MPNNInferenceInput._annotate_bias_and_omit(atom_array, input_dict)
2165
+
2166
+ # Symmetry
2167
+ if "mpnn_symmetry_equivalence_group" not in annotation_categories:
2168
+ # Disallow having symmetry weight annotation without equivalence
2169
+ # group annotation.
2170
+ if "mpnn_symmetry_weight" in annotation_categories:
2171
+ raise RuntimeError(
2172
+ "Inconsistent existing symmetry annotations in atom array: "
2173
+ "'mpnn_symmetry_weight' annotation exists but "
2174
+ "'mpnn_symmetry_equivalence_group' annotation does not."
2175
+ )
2176
+ MPNNInferenceInput._annotate_symmetry(atom_array, input_dict)
2177
+
2178
+ # Pair bias (2D annotation)
2179
+ if "mpnn_pair_bias" not in annotation_2d_categories:
2180
+ # Create an AtomArrayPlus.
2181
+ atom_array_plus = as_atom_array_plus(atom_array)
2182
+
2183
+ MPNNInferenceInput._annotate_pair_bias(atom_array_plus, input_dict)
2184
+
2185
+ # If pair bias annotation was added, upgrade to AtomArrayPlus.
2186
+ new_has_pair_bias = (
2187
+ "mpnn_pair_bias" in atom_array_plus.get_annotation_2d_categories()
2188
+ )
2189
+ if new_has_pair_bias:
2190
+ atom_array = atom_array_plus
2191
+
2192
+ return atom_array
2193
+
2194
+
2195
+ ###############################################################################
2196
+ # MPNNInferenceOutput
2197
+ ###############################################################################
2198
+
2199
+
2200
+ @dataclass
2201
+ class MPNNInferenceOutput:
2202
+ """Container for inference output.
2203
+
2204
+ Attributes
2205
+ ----------
2206
+ atom_array:
2207
+ The final, per-design AtomArray to be written/saved.
2208
+ output_dict:
2209
+ Per-design metadata, not stored in the AtomArray:
2210
+ - 'batch_idx'
2211
+ - 'design_idx'
2212
+ - 'designed_sequence'
2213
+ - 'sequence_recovery'
2214
+ - 'ligand_interface_sequence_recovery'
2215
+ - 'model_type'
2216
+ - 'checkpoint_path'
2217
+ - 'is_legacy_weights'
2218
+ input_dict:
2219
+ The JSON-like config dict used for this design.
2220
+ """
2221
+
2222
+ atom_array: AtomArray
2223
+ output_dict: dict[str, Any]
2224
+ input_dict: dict[str, Any]
2225
+
2226
+ def _build_extra_categories(
2227
+ self,
2228
+ ) -> dict[str, dict[str, Any]]:
2229
+ """Convert 'input_dict' and 'output_dict' into CIF 'extra_categories'.
2230
+
2231
+ The result is:
2232
+ {
2233
+ "mpnn_input": {"col1": [val1], "col2": [val2], ...}
2234
+ "mpnn_output": {"col1": [val1], "col2": [val2], ...}
2235
+ }
2236
+
2237
+ Nested structures and non-scalar values are converted to strings.
2238
+ """
2239
+ categories = dict()
2240
+
2241
+ # For both inputs and outputs:
2242
+ for category_name, category_dict in [
2243
+ ("mpnn_input", self.input_dict),
2244
+ ("mpnn_output", self.output_dict),
2245
+ ]:
2246
+ # Initialize category dict.
2247
+ categories[category_name] = dict()
2248
+
2249
+ for key, value in category_dict.items():
2250
+ # For scalar values, store directly.
2251
+ if isinstance(value, (str, int, float, bool)):
2252
+ categories[category_name][key] = [value]
2253
+ # JSON-serializable types: convert to JSON string.
2254
+ elif isinstance(value, (list, dict, type(None))):
2255
+ categories[category_name][key] = [json.dumps(value)]
2256
+ else:
2257
+ raise ValueError(
2258
+ f"Cannot serialize key {key!r} with value {value!r} "
2259
+ f"of type {type(value)} in category {category_name!r}."
2260
+ )
2261
+
2262
+ return categories
2263
+
2264
+ def write_structure(
2265
+ self,
2266
+ *,
2267
+ base_path: PathLike | None = None,
2268
+ file_type: str = "cif",
2269
+ ):
2270
+ """
2271
+ Write this design as a CIF file.
2272
+
2273
+ Parameters
2274
+ ----------
2275
+ base_path:
2276
+ Base path *without* an enforced suffix; the 'file_type' argument
2277
+ controls how the suffix is added (e.g. 'cif.gz' -> '.cif.gz').
2278
+ file_type:
2279
+ One of {'cif', 'bcif', 'cif.gz'}. This is forwarded directly to
2280
+ 'atomworks.io.utils.io_utils.to_cif_file'.
2281
+ """
2282
+ if base_path is None:
2283
+ raise ValueError("base_path must be provided to write structure.")
2284
+
2285
+ extra_categories = self._build_extra_categories()
2286
+
2287
+ # NOTE: It is not currently possible to save mpnn_bias and
2288
+ # mpnn_pair_bias annotations to the CIF file due to shape limitations,
2289
+ # so we exclude them here.
2290
+ possible_extra_fields = [
2291
+ "mpnn_designed_residue_mask",
2292
+ "mpnn_temperature",
2293
+ "mpnn_symmetry_equivalence_group",
2294
+ "mpnn_symmetry_weight",
2295
+ ]
2296
+
2297
+ # Limit to fields actually present in the atom array.
2298
+ extra_fields = [
2299
+ field
2300
+ for field in possible_extra_fields
2301
+ if field in self.atom_array.get_annotation_categories()
2302
+ ]
2303
+
2304
+ # Save to CIF file.
2305
+ to_cif_file(
2306
+ self.atom_array,
2307
+ base_path,
2308
+ file_type=file_type,
2309
+ extra_fields=extra_fields,
2310
+ extra_categories=extra_categories,
2311
+ )
2312
+
2313
+ def write_fasta(
2314
+ self,
2315
+ *,
2316
+ base_path: PathLike | None = None,
2317
+ handle=None,
2318
+ ) -> None:
2319
+ """
2320
+ Write a single FASTA record for this design.
2321
+
2322
+ Parameters
2323
+ ----------
2324
+ base_path:
2325
+ Base path *without* an enforced suffix; if provided, the final
2326
+ path will be '{base_path}.fa'. If None, 'handle' must be provided.
2327
+ handle:
2328
+ An open writable file-like handle. If None, 'base_path' must be
2329
+ provided.
2330
+ """
2331
+ # At least one of handle or base_path must be provided, and they
2332
+ # are mutually exclusive.
2333
+ if handle is None and base_path is None:
2334
+ raise ValueError("At least one of handle or base_path must be provided.")
2335
+ if handle is not None and base_path is not None:
2336
+ raise ValueError("handle and base_path are mutually exclusive arguments.")
2337
+
2338
+ # Extract sequence.
2339
+ seq = self.output_dict.get("designed_sequence")
2340
+ if not seq:
2341
+ raise ValueError("No designed_sequence found for FASTA output.")
2342
+
2343
+ # Extract name, batch_idx, and design_idx.
2344
+ name = self.input_dict["name"]
2345
+ batch_idx = self.output_dict["batch_idx"]
2346
+ design_idx = self.output_dict["design_idx"]
2347
+
2348
+ # Extract recovery metrics.
2349
+ sequence_recovery = self.output_dict["sequence_recovery"]
2350
+ ligand_interface_sequence_recovery = self.output_dict[
2351
+ "ligand_interface_sequence_recovery"
2352
+ ]
2353
+
2354
+ # Initialize the header fields list.
2355
+ header_fields = []
2356
+
2357
+ # Construct the decorated name for the header.
2358
+ name_fields = []
2359
+ if name is not None:
2360
+ name_fields.append(name)
2361
+ if batch_idx is not None:
2362
+ name_fields.append(f"b{batch_idx}")
2363
+ if design_idx is not None:
2364
+ name_fields.append(f"d{design_idx}")
2365
+
2366
+ if name_fields:
2367
+ decorated_name = "_".join(name_fields)
2368
+ header_fields.append(decorated_name)
2369
+
2370
+ # Construct the recovery fields for the header.
2371
+ if sequence_recovery is not None:
2372
+ header_fields.append(f"sequence_recovery={float(sequence_recovery):.4f}")
2373
+ if ligand_interface_sequence_recovery is not None:
2374
+ header_fields.append(
2375
+ f"ligand_interface_sequence_recovery="
2376
+ f"{float(ligand_interface_sequence_recovery):.4f}"
2377
+ )
2378
+
2379
+ # Construct the header string.
2380
+ header = ">" + ", ".join(header_fields)
2381
+
2382
+ # If the handle is provided, write to it directly.
2383
+ if handle is not None:
2384
+ # Write the header.
2385
+ handle.write(f"{header}\n")
2386
+
2387
+ # Write the sequence.
2388
+ handle.write(f"{seq}\n")
2389
+ # Otherwise, open the file at base_path and write to it.
2390
+ else:
2391
+ fasta_path = Path(base_path).with_suffix(".fa")
2392
+ with open(fasta_path, "w") as handle:
2393
+ # Write the header.
2394
+ handle.write(f"{header}\n")
2395
+
2396
+ # Write the sequence.
2397
+ handle.write(f"{seq}\n")