rc-foundry 0.1.1 (rc_foundry-0.1.1-py3-none-any.whl)

This diff is provided for informational purposes only; it reflects the contents of publicly released package versions, and the changes between them, as they appear in their respective public registries.
Files changed (180)
  1. foundry/__init__.py +57 -0
  2. foundry/callbacks/__init__.py +5 -0
  3. foundry/callbacks/callback.py +116 -0
  4. foundry/callbacks/health_logging.py +419 -0
  5. foundry/callbacks/metrics_logging.py +211 -0
  6. foundry/callbacks/timing_logging.py +67 -0
  7. foundry/callbacks/train_logging.py +278 -0
  8. foundry/common.py +108 -0
  9. foundry/constants.py +28 -0
  10. foundry/hydra/resolvers.py +77 -0
  11. foundry/inference_engines/base.py +235 -0
  12. foundry/inference_engines/checkpoint_registry.py +66 -0
  13. foundry/metrics/__init__.py +12 -0
  14. foundry/metrics/losses.py +30 -0
  15. foundry/metrics/metric.py +319 -0
  16. foundry/model/layers/blocks.py +47 -0
  17. foundry/testing/__init__.py +6 -0
  18. foundry/testing/fixtures.py +19 -0
  19. foundry/testing/pytest_hooks.py +15 -0
  20. foundry/trainers/fabric.py +923 -0
  21. foundry/training/EMA.py +67 -0
  22. foundry/training/checkpoint.py +61 -0
  23. foundry/training/schedulers.py +91 -0
  24. foundry/utils/alignment.py +86 -0
  25. foundry/utils/components.py +415 -0
  26. foundry/utils/datasets.py +405 -0
  27. foundry/utils/ddp.py +103 -0
  28. foundry/utils/instantiators.py +72 -0
  29. foundry/utils/logging.py +279 -0
  30. foundry/utils/rigid.py +1460 -0
  31. foundry/utils/rotation_augmentation.py +65 -0
  32. foundry/utils/squashfs.py +172 -0
  33. foundry/utils/torch.py +317 -0
  34. foundry/utils/weights.py +271 -0
  35. foundry/version.py +34 -0
  36. foundry_cli/__init__.py +3 -0
  37. foundry_cli/download_checkpoints.py +281 -0
  38. mpnn/__init__.py +1 -0
  39. mpnn/collate/feature_collator.py +265 -0
  40. mpnn/inference.py +53 -0
  41. mpnn/inference_engines/mpnn.py +549 -0
  42. mpnn/loss/nll_loss.py +122 -0
  43. mpnn/metrics/nll.py +369 -0
  44. mpnn/metrics/sequence_recovery.py +440 -0
  45. mpnn/model/layers/graph_embeddings.py +2372 -0
  46. mpnn/model/layers/message_passing.py +332 -0
  47. mpnn/model/layers/position_wise_feed_forward.py +44 -0
  48. mpnn/model/layers/positional_encoding.py +98 -0
  49. mpnn/model/mpnn.py +2632 -0
  50. mpnn/pipelines/mpnn.py +162 -0
  51. mpnn/samplers/samplers.py +167 -0
  52. mpnn/train.py +341 -0
  53. mpnn/trainers/mpnn.py +193 -0
  54. mpnn/transforms/feature_aggregation/mpnn.py +184 -0
  55. mpnn/transforms/feature_aggregation/polymer_ligand_interface.py +76 -0
  56. mpnn/transforms/feature_aggregation/token_encodings.py +132 -0
  57. mpnn/transforms/feature_aggregation/user_settings.py +347 -0
  58. mpnn/transforms/polymer_ligand_interface.py +164 -0
  59. mpnn/utils/inference.py +2397 -0
  60. mpnn/utils/probability.py +37 -0
  61. mpnn/utils/weights.py +309 -0
  62. rc_foundry-0.1.1.dist-info/METADATA +239 -0
  63. rc_foundry-0.1.1.dist-info/RECORD +180 -0
  64. rc_foundry-0.1.1.dist-info/WHEEL +4 -0
  65. rc_foundry-0.1.1.dist-info/entry_points.txt +5 -0
  66. rc_foundry-0.1.1.dist-info/licenses/LICENSE.md +28 -0
  67. rf3/__init__.py +3 -0
  68. rf3/_version.py +33 -0
  69. rf3/alignment.py +79 -0
  70. rf3/callbacks/dump_validation_structures.py +101 -0
  71. rf3/callbacks/metrics_logging.py +324 -0
  72. rf3/chemical.py +1529 -0
  73. rf3/cli.py +77 -0
  74. rf3/data/cyclic_transform.py +78 -0
  75. rf3/data/extra_xforms.py +36 -0
  76. rf3/data/ground_truth_template.py +463 -0
  77. rf3/data/paired_msa.py +206 -0
  78. rf3/data/pipeline_utils.py +128 -0
  79. rf3/data/pipelines.py +558 -0
  80. rf3/diffusion_samplers/inference_sampler.py +222 -0
  81. rf3/inference.py +65 -0
  82. rf3/inference_engines/__init__.py +5 -0
  83. rf3/inference_engines/rf3.py +735 -0
  84. rf3/kinematics.py +354 -0
  85. rf3/loss/af3_confidence_loss.py +515 -0
  86. rf3/loss/af3_losses.py +655 -0
  87. rf3/loss/loss.py +179 -0
  88. rf3/metrics/chiral.py +179 -0
  89. rf3/metrics/clashing_chains.py +68 -0
  90. rf3/metrics/distogram.py +421 -0
  91. rf3/metrics/lddt.py +523 -0
  92. rf3/metrics/metadata.py +43 -0
  93. rf3/metrics/metric_utils.py +192 -0
  94. rf3/metrics/predicted_error.py +134 -0
  95. rf3/metrics/rasa.py +108 -0
  96. rf3/metrics/selected_distances.py +91 -0
  97. rf3/model/RF3.py +527 -0
  98. rf3/model/RF3_blocks.py +92 -0
  99. rf3/model/RF3_structure.py +303 -0
  100. rf3/model/layers/af3_auxiliary_heads.py +255 -0
  101. rf3/model/layers/af3_diffusion_transformer.py +544 -0
  102. rf3/model/layers/attention.py +313 -0
  103. rf3/model/layers/layer_utils.py +127 -0
  104. rf3/model/layers/mlff.py +118 -0
  105. rf3/model/layers/outer_product.py +59 -0
  106. rf3/model/layers/pairformer_layers.py +783 -0
  107. rf3/model/layers/structure_bias.py +56 -0
  108. rf3/scoring.py +1787 -0
  109. rf3/symmetry/resolve.py +284 -0
  110. rf3/train.py +194 -0
  111. rf3/trainers/rf3.py +570 -0
  112. rf3/util_module.py +47 -0
  113. rf3/utils/frames.py +109 -0
  114. rf3/utils/inference.py +665 -0
  115. rf3/utils/io.py +198 -0
  116. rf3/utils/loss.py +72 -0
  117. rf3/utils/predict_and_score.py +165 -0
  118. rf3/utils/predicted_error.py +673 -0
  119. rf3/utils/recycling.py +42 -0
  120. rf3/validate.py +140 -0
  121. rfd3/.gitignore +7 -0
  122. rfd3/Makefile +76 -0
  123. rfd3/__init__.py +12 -0
  124. rfd3/callbacks.py +66 -0
  125. rfd3/cli.py +41 -0
  126. rfd3/constants.py +212 -0
  127. rfd3/engine.py +543 -0
  128. rfd3/inference/datasets.py +193 -0
  129. rfd3/inference/input_parsing.py +1123 -0
  130. rfd3/inference/legacy_input_parsing.py +717 -0
  131. rfd3/inference/parsing.py +165 -0
  132. rfd3/inference/symmetry/atom_array.py +298 -0
  133. rfd3/inference/symmetry/checks.py +241 -0
  134. rfd3/inference/symmetry/contigs.py +63 -0
  135. rfd3/inference/symmetry/frames.py +355 -0
  136. rfd3/inference/symmetry/symmetry_utils.py +398 -0
  137. rfd3/metrics/design_metrics.py +465 -0
  138. rfd3/metrics/hbonds_hbplus_metrics.py +308 -0
  139. rfd3/metrics/hbonds_metrics.py +389 -0
  140. rfd3/metrics/losses.py +325 -0
  141. rfd3/metrics/metrics_utils.py +118 -0
  142. rfd3/metrics/sidechain_metrics.py +349 -0
  143. rfd3/model/RFD3.py +105 -0
  144. rfd3/model/RFD3_diffusion_module.py +387 -0
  145. rfd3/model/cfg_utils.py +81 -0
  146. rfd3/model/inference_sampler.py +635 -0
  147. rfd3/model/layers/attention.py +577 -0
  148. rfd3/model/layers/block_utils.py +580 -0
  149. rfd3/model/layers/blocks.py +777 -0
  150. rfd3/model/layers/chunked_pairwise.py +377 -0
  151. rfd3/model/layers/encoders.py +417 -0
  152. rfd3/model/layers/layer_utils.py +197 -0
  153. rfd3/model/layers/pairformer_layers.py +128 -0
  154. rfd3/run_inference.py +45 -0
  155. rfd3/testing/debug.py +139 -0
  156. rfd3/testing/debug_utils.py +73 -0
  157. rfd3/testing/testing_utils.py +356 -0
  158. rfd3/train.py +194 -0
  159. rfd3/trainer/dump_validation_structures.py +154 -0
  160. rfd3/trainer/fabric_trainer.py +923 -0
  161. rfd3/trainer/recycling.py +42 -0
  162. rfd3/trainer/rfd3.py +485 -0
  163. rfd3/trainer/trainer_utils.py +502 -0
  164. rfd3/transforms/conditioning_base.py +508 -0
  165. rfd3/transforms/conditioning_utils.py +200 -0
  166. rfd3/transforms/design_transforms.py +807 -0
  167. rfd3/transforms/dna_crop.py +523 -0
  168. rfd3/transforms/hbonds.py +407 -0
  169. rfd3/transforms/hbonds_hbplus.py +246 -0
  170. rfd3/transforms/ncaa_transforms.py +153 -0
  171. rfd3/transforms/pipelines.py +632 -0
  172. rfd3/transforms/ppi_transforms.py +541 -0
  173. rfd3/transforms/rasa.py +116 -0
  174. rfd3/transforms/symmetry.py +76 -0
  175. rfd3/transforms/training_conditions.py +552 -0
  176. rfd3/transforms/util_transforms.py +498 -0
  177. rfd3/transforms/virtual_atoms.py +305 -0
  178. rfd3/utils/inference.py +648 -0
  179. rfd3/utils/io.py +245 -0
  180. rfd3/utils/vizualize.py +276 -0
rf3/utils/inference.py ADDED
@@ -0,0 +1,665 @@
1
+ import json
2
+ import logging
3
+ import os
4
+ from concurrent.futures import ProcessPoolExecutor
5
+ from dataclasses import dataclass
6
+ from os import PathLike
7
+ from pathlib import Path
8
+ from typing import Iterable
9
+
10
+ import numpy as np
11
+ import pandas as pd
12
+ from atomworks.common import as_list
13
+ from atomworks.enums import GroundTruthConformerPolicy
14
+ from atomworks.io import parse
15
+ from atomworks.io.parser import parse_atom_array
16
+ from atomworks.io.tools.inference import (
17
+ build_msa_paths_by_chain_id_from_component_list,
18
+ components_to_atom_array,
19
+ )
20
+ from atomworks.io.transforms.categories import category_to_dict
21
+ from atomworks.io.utils.selection import AtomSelectionStack
22
+ from atomworks.ml.transforms.atom_array import add_global_token_id_annotation
23
+ from biotite.structure import AtomArray
24
+ from rf3.utils.io import (
25
+ CIF_LIKE_EXTENSIONS,
26
+ DICTIONARY_LIKE_EXTENSIONS,
27
+ get_sharded_output_path,
28
+ )
29
+ from torch.utils.data import Dataset
30
+
31
+ logger = logging.getLogger(__name__)
32
+
33
+
34
+ def _resolve_override(override_value, source_value, param_name: str, example_id: str):
35
+ """Resolve CLI override vs source value with warning."""
36
+ if override_value is not None and source_value:
37
+ logger.warning(f"CLI {param_name} overriding source value for {example_id}")
38
+ return override_value
39
+ return override_value if override_value is not None else source_value
40
+
41
+
42
+ def extract_example_id_from_path(path: Path) -> str:
43
+ """Extract example ID from file path."""
44
+ path_str = str(path.name)
45
+ # Check for known extensions (longer matches first to handle .cif.gz before .gz)
46
+ for ext in sorted(CIF_LIKE_EXTENSIONS | {".json"}, key=len, reverse=True):
47
+ if path_str.endswith(ext):
48
+ return path_str[: -len(ext)]
49
+ # Fallback to simple stem
50
+ return path.stem
51
+
52
+
53
+ def extract_example_ids_from_json(path: Path) -> list[str]:
54
+ """Extract example IDs from a JSON file containing one or more examples."""
55
+ with open(path, "r") as f:
56
+ data = json.load(f)
57
+ return [ex["name"] for ex in data]
58
+
59
+
60
+ @dataclass
61
+ class InferenceInput:
62
+ """Input specification for RF3 inference."""
63
+
64
+ atom_array: AtomArray
65
+ chain_info: dict
66
+ example_id: str
67
+ template_selection: list[str] | None = None
68
+ ground_truth_conformer_selection: list[str] | None = None
69
+ cyclic_chains: list[str] | None = None
70
+
71
+ @classmethod
72
+ def from_cif_path(
73
+ cls,
74
+ path: PathLike,
75
+ example_id: str | None = None,
76
+ template_selection: list[str] | str | None = None,
77
+ ground_truth_conformer_selection: list[str] | str | None = None,
78
+ ) -> "InferenceInput":
79
+ """Load from CIF/PDB file.
80
+
81
+ Args:
82
+ path: Path to CIF/PDB file.
83
+ example_id: Example ID. Defaults to filename stem.
84
+ template_selection: Template selection override.
85
+ ground_truth_conformer_selection: Conformer selection override.
86
+
87
+ Returns:
88
+ InferenceInput object.
89
+ """
90
+ parsed = parse(path, hydrogen_policy="remove", keep_cif_block=True)
91
+
92
+ atom_array = (
93
+ parsed["assemblies"]["1"][0]
94
+ if "assemblies" in parsed
95
+ else parsed["asym_unit"][0]
96
+ )
97
+
98
+ example_id = example_id or extract_example_id_from_path(Path(path))
99
+
100
+ # Extract from CIF
101
+ cif_template_sel = None
102
+ cif_conformer_sel = None
103
+ if "cif_block" in parsed:
104
+ template_dict = category_to_dict(parsed["cif_block"], "template_selection")
105
+ if template_dict:
106
+ cif_template_sel = list(template_dict.get("template_selection", []))
107
+
108
+ conformer_dict = category_to_dict(
109
+ parsed["cif_block"], "ground_truth_conformer_selection"
110
+ )
111
+ if conformer_dict:
112
+ cif_conformer_sel = list(
113
+ conformer_dict.get("ground_truth_conformer_selection", [])
114
+ )
115
+
116
+ # Resolve overrides (CLI priority)
117
+ final_template_sel = _resolve_override(
118
+ template_selection, cif_template_sel, "template_selection", example_id
119
+ )
120
+ final_conformer_sel = _resolve_override(
121
+ ground_truth_conformer_selection,
122
+ cif_conformer_sel,
123
+ "ground_truth_conformer_selection",
124
+ example_id,
125
+ )
126
+
127
+ return cls(
128
+ atom_array=atom_array,
129
+ chain_info=parsed["chain_info"],
130
+ example_id=example_id,
131
+ template_selection=final_template_sel,
132
+ ground_truth_conformer_selection=final_conformer_sel,
133
+ )
134
+
135
+ @classmethod
136
+ def from_json_dict(
137
+ cls,
138
+ data: dict,
139
+ template_selection: list[str] | str | None = None,
140
+ ground_truth_conformer_selection: list[str] | str | None = None,
141
+ ) -> "InferenceInput":
142
+ """Create from JSON dict with components.
143
+
144
+ CLI args override JSON metadata.
145
+
146
+ Args:
147
+ data: JSON dictionary with components.
148
+ template_selection: Template selection override.
149
+ ground_truth_conformer_selection: Conformer selection override.
150
+
151
+ Returns:
152
+ InferenceInput object.
153
+ """
154
+ # Build atom_array from components
155
+ atom_array, component_list = components_to_atom_array(
156
+ data["components"],
157
+ bonds=data.get("bonds"),
158
+ return_components=True,
159
+ )
160
+
161
+ parsed = parse_atom_array(
162
+ atom_array,
163
+ build_assembly="_spoof",
164
+ hydrogen_policy="keep",
165
+ )
166
+
167
+ chain_info = parsed.get("chain_info", {})
168
+ atom_array = (
169
+ parsed["assemblies"]["1"][0]
170
+ if "assemblies" in parsed
171
+ else parsed["asym_unit"][0]
172
+ )
173
+
174
+ # Merge MSA paths into chain_info
175
+ msa_paths_by_chain_id = build_msa_paths_by_chain_id_from_component_list(
176
+ component_list
177
+ )
178
+ if data.get("msa_paths") and isinstance(data.get("msa_paths"), dict):
179
+ msa_paths_by_chain_id.update(data.get("msa_paths"))
180
+
181
+ for chain_id, msa_path in msa_paths_by_chain_id.items():
182
+ if chain_id in chain_info:
183
+ chain_info[chain_id]["msa_path"] = msa_path
184
+
185
+ # Resolve overrides (CLI priority)
186
+ final_template_sel = _resolve_override(
187
+ template_selection,
188
+ data.get("template_selection"),
189
+ "template_selection",
190
+ data["name"],
191
+ )
192
+ final_conformer_sel = _resolve_override(
193
+ ground_truth_conformer_selection,
194
+ data.get("ground_truth_conformer_selection"),
195
+ "ground_truth_conformer_selection",
196
+ data["name"],
197
+ )
198
+
199
+ return cls(
200
+ atom_array=atom_array,
201
+ chain_info=chain_info,
202
+ example_id=data["name"],
203
+ template_selection=final_template_sel,
204
+ ground_truth_conformer_selection=final_conformer_sel,
205
+ )
206
+
207
+ @classmethod
208
+ def from_atom_array(
209
+ cls,
210
+ atom_array: AtomArray,
211
+ chain_info: dict | None = None,
212
+ example_id: str | None = None,
213
+ template_selection: list[str] | str | None = None,
214
+ ground_truth_conformer_selection: list[str] | str | None = None,
215
+ ) -> "InferenceInput":
216
+ """Create from AtomArray.
217
+
218
+ Args:
219
+ atom_array: Input AtomArray.
220
+ chain_info: Chain info dict. Defaults to extracted from atom_array.
221
+ example_id: Example ID. Defaults to generated ID.
222
+ template_selection: Template selection.
223
+ ground_truth_conformer_selection: Conformer selection.
224
+
225
+ Returns:
226
+ InferenceInput object.
227
+ """
228
+ # Use parse_atom_array
229
+ parsed = parse_atom_array(
230
+ atom_array,
231
+ build_assembly="_spoof",
232
+ hydrogen_policy="keep",
233
+ extra_fields="all",
234
+ )
235
+
236
+ extracted_chain_info = parsed.get("chain_info", {})
237
+
238
+ # Merge with provided chain_info (provided takes priority)
239
+ if chain_info is not None:
240
+ for chain_id, chain_data in chain_info.items():
241
+ if chain_id in extracted_chain_info:
242
+ extracted_chain_info[chain_id].update(chain_data)
243
+ else:
244
+ extracted_chain_info[chain_id] = chain_data
245
+
246
+ final_atom_array = (
247
+ parsed["assemblies"]["1"][0]
248
+ if "assemblies" in parsed
249
+ else parsed["asym_unit"][0]
250
+ )
251
+
252
+ return cls(
253
+ atom_array=final_atom_array,
254
+ chain_info=extracted_chain_info,
255
+ example_id=example_id or f"inference_{id(atom_array)}",
256
+ template_selection=template_selection,
257
+ ground_truth_conformer_selection=ground_truth_conformer_selection,
258
+ )
259
+
260
+ def to_pipeline_input(self) -> dict:
261
+ """Apply transformations and return input for Transform pipeline.
262
+
263
+ Returns:
264
+ Pipeline input dict with example_id, atom_array, and chain_info.
265
+ """
266
+ atom_array = self.atom_array.copy()
267
+
268
+ # Apply template and conformer selections
269
+ atom_array = apply_conformer_and_template_selections(
270
+ atom_array,
271
+ template_selection=self.template_selection,
272
+ ground_truth_conformer_selection=self.ground_truth_conformer_selection,
273
+ )
274
+
275
+ if self.cyclic_chains:
276
+ atom_array = cyclize_atom_array(atom_array, self.cyclic_chains)
277
+
278
+ return {
279
+ "example_id": self.example_id,
280
+ "atom_array": atom_array,
281
+ "chain_info": self.chain_info,
282
+ }
283
+
284
+
285
+ def _process_single_path(
286
+ path: Path,
287
+ existing_outputs_dir: Path | None,
288
+ sharding_pattern: str | None,
289
+ template_selection: list[str] | str | None,
290
+ ground_truth_conformer_selection: list[str] | str | None,
291
+ ) -> list[InferenceInput]:
292
+ """Worker function to process a single input file path.
293
+
294
+ This function is defined at module level to be picklable for multiprocessing.
295
+
296
+ Args:
297
+ path: Path to a single input file.
298
+ existing_outputs_dir: If set, skip examples with existing outputs.
299
+ sharding_pattern: Sharding pattern for output paths.
300
+ template_selection: Override for template selection.
301
+ ground_truth_conformer_selection: Override for conformer selection.
302
+
303
+ Returns:
304
+ List of InferenceInput objects (may be empty if file is skipped).
305
+ """
306
+
307
+ def example_exists(example_id: str) -> bool:
308
+ """Check if example already has predictions (sharding-aware)."""
309
+ if not existing_outputs_dir:
310
+ return False
311
+ example_dir = get_sharded_output_path(
312
+ example_id, existing_outputs_dir, sharding_pattern
313
+ )
314
+ return (example_dir / f"{example_id}_metrics.csv").exists()
315
+
316
+ inference_inputs = []
317
+
318
+ if path.suffix == ".json":
319
+ # Load JSON and convert each entry
320
+ with open(path, "r") as f:
321
+ data = json.load(f)
322
+
323
+ # Normalize to list
324
+ if isinstance(data, dict):
325
+ data = [data]
326
+
327
+ for item in data:
328
+ example_id = item["name"]
329
+ if not example_exists(example_id):
330
+ inference_inputs.append(
331
+ InferenceInput.from_json_dict(
332
+ item,
333
+ template_selection=template_selection,
334
+ ground_truth_conformer_selection=ground_truth_conformer_selection,
335
+ )
336
+ )
337
+
338
+ elif any(path.name.endswith(ext) for ext in CIF_LIKE_EXTENSIONS):
339
+ # CIF/PDB file
340
+ example_id = extract_example_id_from_path(path)
341
+ if not example_exists(example_id):
342
+ inference_inputs.append(
343
+ InferenceInput.from_cif_path(
344
+ path,
345
+ example_id=example_id,
346
+ template_selection=template_selection,
347
+ ground_truth_conformer_selection=ground_truth_conformer_selection,
348
+ )
349
+ )
350
+ else:
351
+ raise ValueError(
352
+ f"Unsupported file type: {path.suffix} (path: {path}). "
353
+ f"Supported: {CIF_LIKE_EXTENSIONS | DICTIONARY_LIKE_EXTENSIONS}"
354
+ )
355
+
356
+ return inference_inputs
357
+
358
+
359
+ def prepare_inference_inputs_from_paths(
360
+ inputs: PathLike | list[PathLike],
361
+ existing_outputs_dir: PathLike | None = None,
362
+ sharding_pattern: str | None = None,
363
+ template_selection: list[str] | str | None = None,
364
+ ground_truth_conformer_selection: list[str] | str | None = None,
365
+ ) -> list[InferenceInput]:
366
+ """Load InferenceInput objects from file paths.
367
+
368
+ Handles CIF, PDB, and JSON files. Filters out existing outputs if requested.
369
+ Uses multiprocessing to parallelize file loading across all available CPUs.
370
+
371
+ Args:
372
+ inputs: File path(s) or directory path(s).
373
+ existing_outputs_dir: If set, skip examples with existing outputs.
374
+ sharding_pattern: Sharding pattern for output paths.
375
+ template_selection: Override for template selection (applied to all inputs).
376
+ ground_truth_conformer_selection: Override for conformer selection (applied to all inputs).
377
+
378
+ Returns:
379
+ List of InferenceInput objects.
380
+ """
381
+ input_paths = as_list(inputs)
382
+
383
+ # Collect all raw input files (reusing logic from build_file_paths_for_prediction)
384
+ paths_to_raw_input_files = []
385
+ for _path in input_paths:
386
+ if Path(_path).is_dir():
387
+ # Scan directory for supported file types (JSON + CIF-like)
388
+ for file_type in CIF_LIKE_EXTENSIONS | DICTIONARY_LIKE_EXTENSIONS:
389
+ paths_to_raw_input_files.extend(Path(_path).glob(f"*{file_type}"))
390
+ else:
391
+ paths_to_raw_input_files.append(Path(_path))
392
+
393
+ # Determine number of CPUs to use
394
+ num_cpus = min(os.cpu_count() or 1, len(paths_to_raw_input_files))
395
+ logger.info(
396
+ f"Processing {len(paths_to_raw_input_files)} files using {num_cpus} CPUs"
397
+ )
398
+
399
+ # Convert existing_outputs_dir to Path if needed
400
+ existing_outputs_dir_path = (
401
+ Path(existing_outputs_dir) if existing_outputs_dir else None
402
+ )
403
+
404
+ # Process files in parallel using all available CPUs
405
+ inference_inputs = []
406
+ with ProcessPoolExecutor(max_workers=num_cpus) as executor:
407
+ # Submit all tasks
408
+ futures = [
409
+ executor.submit(
410
+ _process_single_path,
411
+ path,
412
+ existing_outputs_dir_path,
413
+ sharding_pattern,
414
+ template_selection,
415
+ ground_truth_conformer_selection,
416
+ )
417
+ for path in paths_to_raw_input_files
418
+ ]
419
+
420
+ # Collect results as they complete
421
+ for future in futures:
422
+ result = future.result()
423
+ inference_inputs.extend(result)
424
+
425
+ logger.info(f"Loaded {len(inference_inputs)} inference inputs")
426
+ return inference_inputs
427
+
428
+
429
+ def apply_atom_selection_mask(
430
+ atom_array: AtomArray, selection_list: Iterable[str]
431
+ ) -> np.ndarray:
432
+ """Return a combined boolean mask for a list of AtomSelectionStack queries.
433
+
434
+ Args:
435
+ atom_array: AtomArray to select from.
436
+ selection_list: Iterable of AtomSelectionStack queries (e.g., "*/LIG", "A1-10").
437
+
438
+ Returns:
439
+ A boolean numpy array of shape (num_atoms,) where True indicates a selected atom.
440
+ """
441
+ selection_mask = np.zeros(len(atom_array), dtype=bool)
442
+ for selection in selection_list:
443
+ if not selection:
444
+ continue
445
+ try:
446
+ selector = AtomSelectionStack.from_query(selection)
447
+ mask = selector.get_mask(atom_array)
448
+ selection_mask = selection_mask | mask
449
+ except Exception as exc: # Defensive: keep going if one selection fails
450
+ logging.warning(
451
+ "Failed to parse selection '%s': %s. Skipping.", selection, exc
452
+ )
453
+ return selection_mask
454
+
455
+
456
+ def apply_template_selection(
457
+ atom_array: AtomArray, template_selection: list[str] | str | None
458
+ ) -> AtomArray:
459
+ """Apply token-level template selection to `atom_array` with OR semantics.
460
+
461
+ If the `is_input_file_templated` annotation already exists, this function ORs
462
+ the new selection with the existing annotation. Otherwise, it creates it.
463
+
464
+ Args:
465
+ atom_array: AtomArray to annotate.
466
+ template_selection: Selection string(s). Single strings are converted to lists. If None/empty, no-op.
467
+
468
+ Returns:
469
+ The same AtomArray with `is_input_file_templated` updated.
470
+ """
471
+ # Convert to list if needed
472
+ template_selection_list = as_list(template_selection) if template_selection else []
473
+
474
+ if not template_selection_list:
475
+ # Ensure the annotation exists even if no selection provided
476
+ if "is_input_file_templated" not in atom_array.get_annotation_categories():
477
+ atom_array.set_annotation(
478
+ "is_input_file_templated", np.zeros(len(atom_array), dtype=bool)
479
+ )
480
+ return atom_array
481
+
482
+ # Build new mask
483
+ selection_mask = apply_atom_selection_mask(atom_array, template_selection_list)
484
+ logging.info(
485
+ "Selected %d atoms for token-level templating with %d syntaxes",
486
+ int(np.sum(selection_mask)),
487
+ len([s for s in template_selection_list if s]),
488
+ )
489
+
490
+ # OR with existing annotation if present
491
+ if "is_input_file_templated" in atom_array.get_annotation_categories():
492
+ existing = atom_array.get_annotation("is_input_file_templated").astype(bool)
493
+ selection_mask = existing | selection_mask
494
+ atom_array.set_annotation("is_input_file_templated", selection_mask)
495
+ return atom_array
496
+
497
+
498
+ def apply_ground_truth_conformer_selection(
499
+ atom_array: AtomArray, ground_truth_conformer_selection: list[str] | str | None
500
+ ) -> AtomArray:
501
+ """Apply ground-truth conformer policy selection with union semantics.
502
+
503
+ Behavior:
504
+ - Creates `ground_truth_conformer_policy` if missing and initializes to IGNORE.
505
+ - For selected atoms, sets policy to at least ADD without downgrading any
506
+ existing policy (e.g., preserves REPLACE if present).
507
+
508
+ Args:
509
+ atom_array: AtomArray to annotate.
510
+ ground_truth_conformer_selection: Selection string(s). Single strings are converted to lists. If None/empty, no-op.
511
+
512
+ Returns:
513
+ The same AtomArray with `ground_truth_conformer_policy` updated.
514
+ """
515
+ # Convert to list if needed
516
+ ground_truth_conformer_selection_list = (
517
+ as_list(ground_truth_conformer_selection)
518
+ if ground_truth_conformer_selection
519
+ else []
520
+ )
521
+
522
+ if not ground_truth_conformer_selection_list:
523
+ if (
524
+ "ground_truth_conformer_policy"
525
+ not in atom_array.get_annotation_categories()
526
+ ):
527
+ atom_array.set_annotation(
528
+ "ground_truth_conformer_policy",
529
+ np.full(
530
+ len(atom_array), GroundTruthConformerPolicy.IGNORE, dtype=np.int8
531
+ ),
532
+ )
533
+ return atom_array
534
+
535
+ # Ensure annotation exists
536
+ if "ground_truth_conformer_policy" not in atom_array.get_annotation_categories():
537
+ atom_array.set_annotation(
538
+ "ground_truth_conformer_policy",
539
+ np.full(len(atom_array), GroundTruthConformerPolicy.IGNORE, dtype=np.int8),
540
+ )
541
+
542
+ selection_mask = apply_atom_selection_mask(
543
+ atom_array, ground_truth_conformer_selection_list
544
+ )
545
+ logging.info(
546
+ "Selected %d atoms for ground-truth conformer policy with %d syntaxes",
547
+ int(np.sum(selection_mask)),
548
+ len([s for s in ground_truth_conformer_selection_list if s]),
549
+ )
550
+
551
+ existing = atom_array.get_annotation("ground_truth_conformer_policy")
552
+ existing[selection_mask] = GroundTruthConformerPolicy.ADD
553
+ atom_array.set_annotation("ground_truth_conformer_policy", existing)
554
+
555
+ return atom_array
556
+
557
+
558
+ def apply_conformer_and_template_selections(
559
+ atom_array: AtomArray,
560
+ template_selection: list[str] | str | None = None,
561
+ ground_truth_conformer_selection: list[str] | str | None = None,
562
+ ) -> AtomArray:
563
+ """Apply template and conformer selections and basic preprocessing.
564
+
565
+ This function replaces the former class method `prepare_atom_array`.
566
+
567
+ - Applies `apply_template_selection` then `apply_ground_truth_conformer_selection`.
568
+ - Replaces NaN coordinates with -1 for safety.
569
+
570
+ Args:
571
+ atom_array: AtomArray to prepare.
572
+ template_selection: Template selection string(s). Single strings are converted to lists.
573
+ ground_truth_conformer_selection: Ground-truth conformer selection string(s). Single strings are converted to lists.
574
+
575
+ Returns:
576
+ The same AtomArray with `is_input_file_templated` and `ground_truth_conformer_policy` updated.
577
+ """
578
+ atom_array = apply_template_selection(atom_array, template_selection)
579
+ atom_array = apply_ground_truth_conformer_selection(
580
+ atom_array, ground_truth_conformer_selection
581
+ )
582
+ # Safety: avoid unexpected behavior downstream
583
+ atom_array.coord[np.isnan(atom_array.coord)] = -1
584
+ return atom_array
585
+
586
+
587
+ def cyclize_atom_array(atom_array: AtomArray, cyclic_chains: list[str]) -> AtomArray:
588
+ """Cyclize the atom array by positioining the termini properly if not already done.
589
+
590
+ Behavior:
591
+ - Positions the last carbon atom in the chain to be 1.3 Angstroms away from the first nitrogen atom if they are not already close.
592
+ - Adds a bond between the termini for proper CIF output.
593
+
594
+ Args:
595
+ atom_array: AtomArray to cyclize.
596
+ cyclic_chains: List of chain IDs to cyclize.
597
+
598
+ Returns:
599
+ The same AtomArray with the specified chains cyclized.
600
+ """
601
+ for chain in cyclic_chains:
602
+ # Find the first nitrogen atom in the chain
603
+ nitrogen_mask = (atom_array.chain_id == chain) & (atom_array.atom_name == "N")
604
+ nitrogen_mask_indices = np.where(nitrogen_mask)[0]
605
+ first_nitrogen_index = nitrogen_mask_indices[0]
606
+ nitrogen_coord = atom_array.coord[first_nitrogen_index]
607
+
608
+ # move the last carbon atom in the chain to be 1.3 Angstroms away from the nitrogen
609
+ carbon_mask = (atom_array.chain_id == chain) & (atom_array.atom_name == "C")
610
+ carbon_mask_indices = np.where(carbon_mask)[0]
611
+ last_carbon_index = carbon_mask_indices[-1]
612
+ # check if the last carbon is already close to the nitrogen
613
+ termini_distance = np.linalg.norm(
614
+ atom_array.coord[last_carbon_index] - nitrogen_coord
615
+ )
616
+ if not (termini_distance < 1.5 and termini_distance > 0.5):
617
+ atom_array.coord[last_carbon_index] = nitrogen_coord + np.array(
618
+ [1.3, 0.0, 0.0]
619
+ )
620
+
621
+ # add a bond between the nitrogen and carbon so output cif has a connection
622
+ atom_array.bonds.add_bond(first_nitrogen_index, last_carbon_index)
623
+ atom_array.bonds.add_bond(last_carbon_index, first_nitrogen_index)
624
+
625
+ return atom_array
626
+
627
+
628
+ class InferenceInputDataset(Dataset):
629
+ """
630
+ Dataset for inference inputs. Also has a length key telling you the number of tokens in each example for LoadBalancedDistributedSampler.
631
+
632
+ To calculate the length of each example, we need to add the token_id annotation to the atom_array. If it doesn't exist yet, we add it,
633
+ calculate the length, and then remove it since the downstream pipeline may not be expecting it. That means the num_tokens key may not ultimately
634
+ be the same as what's actually used in the model, but this is a close enough approximation for load balancing.
635
+
636
+ Args:
637
+ inference_inputs: List of InferenceInput objects to wrap in a Dataset.
638
+ """
639
+
640
+ def __init__(self, inference_inputs: list[InferenceInput]):
641
+ self.inference_inputs = inference_inputs
642
+ self.key_to_balance = "num_tokens_approximate"
643
+
644
+ # LoadBalancedDistributedSampler checks in dataset.data[key_to_balance] to determine balancing.
645
+ # That means we need to make a dataframe in self.data that has a column with the key_to_balance.
646
+ atom_array_token_lens = []
647
+ for inf_input in self.inference_inputs:
648
+ if "token_id" not in inf_input.atom_array.get_annotation_categories():
649
+ inf_input.atom_array = add_global_token_id_annotation(
650
+ inf_input.atom_array
651
+ )
652
+ num_tokens = len(np.unique(inf_input.atom_array.token_id))
653
+
654
+ # remove the token_id annotation since the pipeline may not be expecting it
655
+ inf_input.atom_array.del_annotation("token_id")
656
+ else:
657
+ num_tokens = len(np.unique(inf_input.atom_array.token_id))
658
+ atom_array_token_lens.append(num_tokens)
659
+ self.data = pd.DataFrame({self.key_to_balance: atom_array_token_lens})
660
+
661
+ def __len__(self):
662
+ return len(self.inference_inputs)
663
+
664
+ def __getitem__(self, idx):
665
+ return self.inference_inputs[idx]
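
The template and conformer selection helpers above use OR/union semantics, so repeated calls accumulate annotations instead of overwriting them. Below is a minimal sketch of that behavior, assuming a local example.cif and the AtomSelectionStack query syntax quoted in the docstrings ("*/LIG", "A1-10"); the file name is hypothetical and the snippet is illustrative, not part of the wheel:

import numpy as np
from atomworks.io import parse
from rf3.utils.inference import apply_template_selection

# Parse a structure the same way InferenceInput.from_cif_path does (assumed local file).
parsed = parse("example.cif", hydrogen_policy="remove", keep_cif_block=True)
atom_array = (
    parsed["assemblies"]["1"][0] if "assemblies" in parsed else parsed["asym_unit"][0]
)

# First call creates the `is_input_file_templated` annotation for residues 1-10 of chain A.
atom_array = apply_template_selection(atom_array, "A1-10")
# Second call ORs in a ligand selection rather than replacing the existing annotation.
atom_array = apply_template_selection(atom_array, ["*/LIG"])

mask = atom_array.get_annotation("is_input_file_templated")
print(int(np.sum(mask)), "atoms flagged for token-level templating")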
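
Taken together, the loading helpers and the dataset wrapper give a short path from raw input files to pipeline inputs. A usage sketch follows, assuming an installed rf3 package, an inputs/ directory of CIF or JSON files, and a previous_runs/ directory from an earlier run (all hypothetical names, not part of the wheel):

from torch.utils.data import DataLoader

from rf3.utils.inference import (
    InferenceInputDataset,
    prepare_inference_inputs_from_paths,
)

# Scan inputs/ for supported files, skipping any example that already has a
# "<example_id>_metrics.csv" under previous_runs/.
inference_inputs = prepare_inference_inputs_from_paths(
    "inputs/",
    existing_outputs_dir="previous_runs/",
)

# The dataset exposes approximate per-example token counts in
# dataset.data["num_tokens_approximate"] for a load-balanced sampler.
dataset = InferenceInputDataset(inference_inputs)

# Each item is an InferenceInput; to_pipeline_input() produces the dict consumed
# by the downstream Transform pipeline.
loader = DataLoader(
    dataset,
    batch_size=1,
    collate_fn=lambda batch: [item.to_pipeline_input() for item in batch],
)
for batch in loader:
    print(batch[0]["example_id"], len(batch[0]["atom_array"]), "atoms")

Keeping batch_size at 1 mirrors the per-example inference pattern; the custom collate_fn simply defers the pipeline-input conversion until iteration.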