rc_foundry-0.1.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (180)
  1. foundry/__init__.py +57 -0
  2. foundry/callbacks/__init__.py +5 -0
  3. foundry/callbacks/callback.py +116 -0
  4. foundry/callbacks/health_logging.py +419 -0
  5. foundry/callbacks/metrics_logging.py +211 -0
  6. foundry/callbacks/timing_logging.py +67 -0
  7. foundry/callbacks/train_logging.py +278 -0
  8. foundry/common.py +108 -0
  9. foundry/constants.py +28 -0
  10. foundry/hydra/resolvers.py +77 -0
  11. foundry/inference_engines/base.py +235 -0
  12. foundry/inference_engines/checkpoint_registry.py +66 -0
  13. foundry/metrics/__init__.py +12 -0
  14. foundry/metrics/losses.py +30 -0
  15. foundry/metrics/metric.py +319 -0
  16. foundry/model/layers/blocks.py +47 -0
  17. foundry/testing/__init__.py +6 -0
  18. foundry/testing/fixtures.py +19 -0
  19. foundry/testing/pytest_hooks.py +15 -0
  20. foundry/trainers/fabric.py +923 -0
  21. foundry/training/EMA.py +67 -0
  22. foundry/training/checkpoint.py +61 -0
  23. foundry/training/schedulers.py +91 -0
  24. foundry/utils/alignment.py +86 -0
  25. foundry/utils/components.py +415 -0
  26. foundry/utils/datasets.py +405 -0
  27. foundry/utils/ddp.py +103 -0
  28. foundry/utils/instantiators.py +72 -0
  29. foundry/utils/logging.py +279 -0
  30. foundry/utils/rigid.py +1460 -0
  31. foundry/utils/rotation_augmentation.py +65 -0
  32. foundry/utils/squashfs.py +172 -0
  33. foundry/utils/torch.py +317 -0
  34. foundry/utils/weights.py +271 -0
  35. foundry/version.py +34 -0
  36. foundry_cli/__init__.py +3 -0
  37. foundry_cli/download_checkpoints.py +281 -0
  38. mpnn/__init__.py +1 -0
  39. mpnn/collate/feature_collator.py +265 -0
  40. mpnn/inference.py +53 -0
  41. mpnn/inference_engines/mpnn.py +549 -0
  42. mpnn/loss/nll_loss.py +122 -0
  43. mpnn/metrics/nll.py +369 -0
  44. mpnn/metrics/sequence_recovery.py +440 -0
  45. mpnn/model/layers/graph_embeddings.py +2372 -0
  46. mpnn/model/layers/message_passing.py +332 -0
  47. mpnn/model/layers/position_wise_feed_forward.py +44 -0
  48. mpnn/model/layers/positional_encoding.py +98 -0
  49. mpnn/model/mpnn.py +2632 -0
  50. mpnn/pipelines/mpnn.py +162 -0
  51. mpnn/samplers/samplers.py +167 -0
  52. mpnn/train.py +341 -0
  53. mpnn/trainers/mpnn.py +193 -0
  54. mpnn/transforms/feature_aggregation/mpnn.py +184 -0
  55. mpnn/transforms/feature_aggregation/polymer_ligand_interface.py +76 -0
  56. mpnn/transforms/feature_aggregation/token_encodings.py +132 -0
  57. mpnn/transforms/feature_aggregation/user_settings.py +347 -0
  58. mpnn/transforms/polymer_ligand_interface.py +164 -0
  59. mpnn/utils/inference.py +2397 -0
  60. mpnn/utils/probability.py +37 -0
  61. mpnn/utils/weights.py +309 -0
  62. rc_foundry-0.1.1.dist-info/METADATA +239 -0
  63. rc_foundry-0.1.1.dist-info/RECORD +180 -0
  64. rc_foundry-0.1.1.dist-info/WHEEL +4 -0
  65. rc_foundry-0.1.1.dist-info/entry_points.txt +5 -0
  66. rc_foundry-0.1.1.dist-info/licenses/LICENSE.md +28 -0
  67. rf3/__init__.py +3 -0
  68. rf3/_version.py +33 -0
  69. rf3/alignment.py +79 -0
  70. rf3/callbacks/dump_validation_structures.py +101 -0
  71. rf3/callbacks/metrics_logging.py +324 -0
  72. rf3/chemical.py +1529 -0
  73. rf3/cli.py +77 -0
  74. rf3/data/cyclic_transform.py +78 -0
  75. rf3/data/extra_xforms.py +36 -0
  76. rf3/data/ground_truth_template.py +463 -0
  77. rf3/data/paired_msa.py +206 -0
  78. rf3/data/pipeline_utils.py +128 -0
  79. rf3/data/pipelines.py +558 -0
  80. rf3/diffusion_samplers/inference_sampler.py +222 -0
  81. rf3/inference.py +65 -0
  82. rf3/inference_engines/__init__.py +5 -0
  83. rf3/inference_engines/rf3.py +735 -0
  84. rf3/kinematics.py +354 -0
  85. rf3/loss/af3_confidence_loss.py +515 -0
  86. rf3/loss/af3_losses.py +655 -0
  87. rf3/loss/loss.py +179 -0
  88. rf3/metrics/chiral.py +179 -0
  89. rf3/metrics/clashing_chains.py +68 -0
  90. rf3/metrics/distogram.py +421 -0
  91. rf3/metrics/lddt.py +523 -0
  92. rf3/metrics/metadata.py +43 -0
  93. rf3/metrics/metric_utils.py +192 -0
  94. rf3/metrics/predicted_error.py +134 -0
  95. rf3/metrics/rasa.py +108 -0
  96. rf3/metrics/selected_distances.py +91 -0
  97. rf3/model/RF3.py +527 -0
  98. rf3/model/RF3_blocks.py +92 -0
  99. rf3/model/RF3_structure.py +303 -0
  100. rf3/model/layers/af3_auxiliary_heads.py +255 -0
  101. rf3/model/layers/af3_diffusion_transformer.py +544 -0
  102. rf3/model/layers/attention.py +313 -0
  103. rf3/model/layers/layer_utils.py +127 -0
  104. rf3/model/layers/mlff.py +118 -0
  105. rf3/model/layers/outer_product.py +59 -0
  106. rf3/model/layers/pairformer_layers.py +783 -0
  107. rf3/model/layers/structure_bias.py +56 -0
  108. rf3/scoring.py +1787 -0
  109. rf3/symmetry/resolve.py +284 -0
  110. rf3/train.py +194 -0
  111. rf3/trainers/rf3.py +570 -0
  112. rf3/util_module.py +47 -0
  113. rf3/utils/frames.py +109 -0
  114. rf3/utils/inference.py +665 -0
  115. rf3/utils/io.py +198 -0
  116. rf3/utils/loss.py +72 -0
  117. rf3/utils/predict_and_score.py +165 -0
  118. rf3/utils/predicted_error.py +673 -0
  119. rf3/utils/recycling.py +42 -0
  120. rf3/validate.py +140 -0
  121. rfd3/.gitignore +7 -0
  122. rfd3/Makefile +76 -0
  123. rfd3/__init__.py +12 -0
  124. rfd3/callbacks.py +66 -0
  125. rfd3/cli.py +41 -0
  126. rfd3/constants.py +212 -0
  127. rfd3/engine.py +543 -0
  128. rfd3/inference/datasets.py +193 -0
  129. rfd3/inference/input_parsing.py +1123 -0
  130. rfd3/inference/legacy_input_parsing.py +717 -0
  131. rfd3/inference/parsing.py +165 -0
  132. rfd3/inference/symmetry/atom_array.py +298 -0
  133. rfd3/inference/symmetry/checks.py +241 -0
  134. rfd3/inference/symmetry/contigs.py +63 -0
  135. rfd3/inference/symmetry/frames.py +355 -0
  136. rfd3/inference/symmetry/symmetry_utils.py +398 -0
  137. rfd3/metrics/design_metrics.py +465 -0
  138. rfd3/metrics/hbonds_hbplus_metrics.py +308 -0
  139. rfd3/metrics/hbonds_metrics.py +389 -0
  140. rfd3/metrics/losses.py +325 -0
  141. rfd3/metrics/metrics_utils.py +118 -0
  142. rfd3/metrics/sidechain_metrics.py +349 -0
  143. rfd3/model/RFD3.py +105 -0
  144. rfd3/model/RFD3_diffusion_module.py +387 -0
  145. rfd3/model/cfg_utils.py +81 -0
  146. rfd3/model/inference_sampler.py +635 -0
  147. rfd3/model/layers/attention.py +577 -0
  148. rfd3/model/layers/block_utils.py +580 -0
  149. rfd3/model/layers/blocks.py +777 -0
  150. rfd3/model/layers/chunked_pairwise.py +377 -0
  151. rfd3/model/layers/encoders.py +417 -0
  152. rfd3/model/layers/layer_utils.py +197 -0
  153. rfd3/model/layers/pairformer_layers.py +128 -0
  154. rfd3/run_inference.py +45 -0
  155. rfd3/testing/debug.py +139 -0
  156. rfd3/testing/debug_utils.py +73 -0
  157. rfd3/testing/testing_utils.py +356 -0
  158. rfd3/train.py +194 -0
  159. rfd3/trainer/dump_validation_structures.py +154 -0
  160. rfd3/trainer/fabric_trainer.py +923 -0
  161. rfd3/trainer/recycling.py +42 -0
  162. rfd3/trainer/rfd3.py +485 -0
  163. rfd3/trainer/trainer_utils.py +502 -0
  164. rfd3/transforms/conditioning_base.py +508 -0
  165. rfd3/transforms/conditioning_utils.py +200 -0
  166. rfd3/transforms/design_transforms.py +807 -0
  167. rfd3/transforms/dna_crop.py +523 -0
  168. rfd3/transforms/hbonds.py +407 -0
  169. rfd3/transforms/hbonds_hbplus.py +246 -0
  170. rfd3/transforms/ncaa_transforms.py +153 -0
  171. rfd3/transforms/pipelines.py +632 -0
  172. rfd3/transforms/ppi_transforms.py +541 -0
  173. rfd3/transforms/rasa.py +116 -0
  174. rfd3/transforms/symmetry.py +76 -0
  175. rfd3/transforms/training_conditions.py +552 -0
  176. rfd3/transforms/util_transforms.py +498 -0
  177. rfd3/transforms/virtual_atoms.py +305 -0
  178. rfd3/utils/inference.py +648 -0
  179. rfd3/utils/io.py +245 -0
  180. rfd3/utils/vizualize.py +276 -0
rf3/inference_engines/rf3.py
@@ -0,0 +1,735 @@
import json
import logging
import re
from dataclasses import dataclass, field
from os import PathLike
from pathlib import Path
from typing import TextIO

import pandas as pd
import torch
import torch.distributed as dist
from atomworks.io.utils.io_utils import to_cif_file
from atomworks.ml.preprocessing.msa.finding import (
    get_msa_depth_and_ext_from_folder,
    get_msa_dirs_from_env,
)
from atomworks.ml.samplers import LoadBalancedDistributedSampler
from biotite.structure import AtomArray, AtomArrayStack
from omegaconf import OmegaConf
from torch.utils.data import DataLoader

from foundry.inference_engines.base import BaseInferenceEngine
from foundry.metrics.metric import MetricManager
from foundry.utils.ddp import RankedLogger
from rf3.model.RF3 import ShouldEarlyStopFn
from rf3.utils.inference import (
    InferenceInput,
    InferenceInputDataset,
    prepare_inference_inputs_from_paths,
)
from rf3.utils.io import (
    build_stack_from_atom_array_and_batched_coords,
    dump_structures,
    get_sharded_output_path,
)
from rf3.utils.predicted_error import (
    annotate_atom_array_b_factor_with_plddt,
    compile_af3_style_confidence_outputs,
    get_mean_atomwise_plddt,
)

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s %(name)s: %(message)s",
    datefmt="%H:%M:%S",
)
ranked_logger = RankedLogger(__name__, rank_zero_only=True)

# Default metrics configuration for RF3 inference (ptm, iptm, clashing chains)
DEFAULT_RF3_METRICS_CFG = {
    "ptm": {"_target_": "rf3.metrics.predicted_error.ComputePTM"},
    "iptm": {"_target_": "rf3.metrics.predicted_error.ComputeIPTM"},
    "count_clashing_chains": {
        "_target_": "rf3.metrics.clashing_chains.CountClashingChains"
    },
}

def dump_json_compact_arrays(obj: dict, f: TextIO) -> None:
    """Dump JSON with indented structure but compact arrays (AF3 style).

    Arrays are written on single lines instead of one element per line.
    """
    # First dump with indent to get structure
    json_str = json.dumps(obj, indent=2)
    # Collapse arrays onto single lines using regex
    # Match arrays that span multiple lines and collapse them
    pattern = re.compile(r"\[\s*\n\s*([^\[\]]*?)\s*\n\s*\]", re.DOTALL)
    while pattern.search(json_str):
        json_str = pattern.sub(
            lambda m: "["
            + ",".join(item.strip() for item in m.group(1).split(","))
            + "]",
            json_str,
        )
    f.write(json_str)

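# Editor's illustration (not part of the original module): with the regex above, an
# array that json.dumps(..., indent=2) would spread over one line per element is
# collapsed back onto a single line, e.g.
#
#   >>> import io
#   >>> buf = io.StringIO()
#   >>> dump_json_compact_arrays({"plddt": [91.2, 88.7, 90.1]}, buf)
#   >>> print(buf.getvalue())
#   {
#     "plddt": [91.2,88.7,90.1]
#   }
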
def compute_ranking_score(
    iptm: float | None,
    ptm: float | None,
    has_clash: bool,
) -> float:
    """Compute ranking score.

    Formula: 0.8 * ipTM + 0.2 * pTM - 100 * has_clash

    For single-chain predictions where ipTM is None, uses pTM only.
    """
    if iptm is None:
        # Single chain - use pTM only
        iptm = ptm if ptm is not None else 0.0
    if ptm is None:
        ptm = 0.0
    return 0.8 * iptm + 0.2 * ptm - 100 * int(has_clash)

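# Editor's note - a worked example of the ranking formula above (values are
# illustrative): ipTM 0.80, pTM 0.75 and no clash gives
# 0.8 * 0.80 + 0.2 * 0.75 - 100 * 0 = 0.79; a detected clash subtracts 100,
# dropping the same sample to -99.21.
#
#   >>> round(compute_ranking_score(iptm=0.80, ptm=0.75, has_clash=False), 4)
#   0.79
#   >>> round(compute_ranking_score(iptm=None, ptm=0.75, has_clash=False), 4)  # single chain
#   0.75
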
@dataclass
class RF3Output:
    """Output container for RF3 predictions, analogous to RFD3Output.

    Stores predicted structures and confidence metrics in AlphaFold3-compatible format.
    """

    example_id: str
    atom_array: AtomArray
    summary_confidences: dict = field(default_factory=dict)
    confidences: dict | None = None
    sample_idx: int = 0
    seed: int = 0

    def dump(
        self,
        out_dir: Path,
        file_type: str = "cif",
        dump_full_confidences: bool = True,
    ) -> None:
        """Save output to disk in AlphaFold3-compatible format.

        Args:
            out_dir: Directory to save outputs to.
            file_type: File type for structure output ("cif" or "cif.gz").
            dump_full_confidences: Whether to save full per-atom confidences.
        """
        out_dir = Path(out_dir)
        out_dir.mkdir(parents=True, exist_ok=True)

        sample_name = f"{self.example_id}_seed-{self.seed}_sample-{self.sample_idx}"
        base_path = out_dir / sample_name

        # Save structure
        to_cif_file(
            self.atom_array,
            f"{base_path}_model",
            file_type=file_type,
            include_entity_poly=False,
        )

        # Save summary_confidences.json
        with open(f"{base_path}_summary_confidences.json", "w") as f:
            dump_json_compact_arrays(self.summary_confidences, f)

        # Save confidences.json (optional, for full per-atom data)
        if dump_full_confidences and self.confidences:
            with open(f"{base_path}_confidences.json", "w") as f:
                dump_json_compact_arrays(self.confidences, f)

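# Editor's note: for a hypothetical example_id "7xyz", seed 0 and sample_idx 1, the
# dump() method above writes (file names follow directly from the f-strings; the
# structure extension is chosen by to_cif_file from file_type, e.g. ".cif" or ".cif.gz"):
#
#   <out_dir>/7xyz_seed-0_sample-1_model.cif
#   <out_dir>/7xyz_seed-0_sample-1_summary_confidences.json
#   <out_dir>/7xyz_seed-0_sample-1_confidences.json   (only when dump_full_confidences and confidences are set)
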
def dump_ranking_scores(
    outputs: list[RF3Output],
    out_dir: Path,
    example_id: str,
) -> None:
    """Write {example_id}_ranking_scores.csv with ranking scores for all samples."""
    rows = [
        {
            "seed": o.seed,
            "sample": o.sample_idx,
            "ranking_score": o.summary_confidences.get("ranking_score"),
        }
        for o in outputs
    ]
    df = pd.DataFrame(rows)
    df.to_csv(out_dir / f"{example_id}_ranking_scores.csv", index=False)

def dump_top_ranked_outputs(
    outputs: list[RF3Output],
    out_dir: Path,
    example_id: str,
    file_type: str = "cif",
) -> RF3Output:
    """Copy the top-ranked model and summary to the top-level directory.

    Returns the top-ranked RF3Output.
    """
    # Find the output with the highest ranking score
    best_output = max(
        outputs,
        key=lambda o: o.summary_confidences.get("ranking_score", float("-inf")),
    )

    # Save top-ranked model at top level
    to_cif_file(
        best_output.atom_array,
        out_dir / f"{example_id}_model",
        file_type=file_type,
        include_entity_poly=False,
    )

    # Save top-ranked summary_confidences at top level
    with open(out_dir / f"{example_id}_summary_confidences.json", "w") as f:
        dump_json_compact_arrays(best_output.summary_confidences, f)

    # Save top-ranked full confidences at top level (if present)
    if best_output.confidences:
        with open(out_dir / f"{example_id}_confidences.json", "w") as f:
            dump_json_compact_arrays(best_output.confidences, f)

    return best_output

def should_early_stop_by_mean_plddt(
    threshold: float, is_real_atom: torch.Tensor, max_value_of_plddt: float
) -> ShouldEarlyStopFn:
    """Returns a closure that triggers early stopping when mean pLDDT falls below the specified threshold."""

    def fn(confidence_outputs: dict, **kwargs):
        mean_plddt = get_mean_atomwise_plddt(
            plddt_logits=confidence_outputs["plddt_logits"].unsqueeze(0),
            is_real_atom=is_real_atom,
            max_value=max_value_of_plddt,
        )
        return (mean_plddt < threshold).item(), {
            "mean_plddt": mean_plddt.item(),
            "threshold": threshold,
        }

    return fn

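# Editor's sketch of how the closure is consumed (argument values are placeholders;
# the call pattern mirrors its use in run() below, where the closure is passed to
# trainer.validation_step as should_early_stop_fn):
#
#   fn = should_early_stop_by_mean_plddt(
#       threshold=70.0,                      # e.g. a pLDDT cutoff on the model's pLDDT scale
#       is_real_atom=is_real_atom,           # boolean mask from the pipeline's confidence_feats
#       max_value_of_plddt=max_plddt_value,  # cfg.trainer.loss.confidence_loss.plddt.max_value
#   )
#   stop, info = fn({"plddt_logits": plddt_logits})
#   # stop -> bool; info -> {"mean_plddt": <float>, "threshold": 70.0}
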
class RF3InferenceEngine(BaseInferenceEngine):
    """RF3 inference engine.

    Separates model setup (expensive, once) from inference (can run multiple times).

    Usage:
        # Setup once
        engine = RF3InferenceEngine(
            ckpt_path="rf3_latest.pt",
            n_recycles=10,
            diffusion_batch_size=5,
        )

        # Run inference multiple times with different inputs
        results1 = engine.run(inputs="path/to/cifs", out_dir="./predictions")
        results2 = engine.run(inputs=InferenceInput.from_atom_array(array), out_dir=None)
        results3 = engine.run(inputs=[input1, input2], out_dir="./more_predictions")
    """

    def __init__(
        self,
        # Model parameters
        n_recycles: int = 10,
        diffusion_batch_size: int = 5,
        num_steps: int = 50,
        # Templating, MSAs, etc.
        template_noise_scale: float = 1e-5,
        raise_if_missing_msa_for_protein_of_length_n: int | None = None,
        # Output control
        compress_outputs: bool = False,
        early_stopping_plddt_threshold: float | None = None,
        # Metrics
        metrics_cfg: dict | OmegaConf | MetricManager | str | None = "default",
        **kwargs,
    ):
        """Initialize inference engine and load model.

        Model config is loaded from checkpoint and overridden with parameters provided here.

        Args:
            n_recycles: Number of recycles. Defaults to ``10``.
            diffusion_batch_size: Number of structures to generate per input. Defaults to ``5``.
            num_steps: Number of diffusion steps. Defaults to ``50``.
            template_noise_scale: Noise scale for template coordinates. Defaults to ``1e-5``.
            raise_if_missing_msa_for_protein_of_length_n: Debug flag for MSA checking. Defaults to ``None``.
            compress_outputs: Whether to gzip output files. Defaults to ``False``.
            early_stopping_plddt_threshold: Stop early if the mean pLDDT falls below this threshold. Defaults to ``None``.
            metrics_cfg: Metrics configuration. Can be:
                - "default" to use standard RF3 metrics (ptm, iptm, clashing chains)
                - dict/OmegaConf with Hydra configs
                - Pre-instantiated MetricManager
                - None (no metrics).
                Defaults to ``"default"``.
            **kwargs: Additional arguments passed to BaseInferenceEngine:
                - ckpt_path (PathLike, required): Path to model checkpoint.
                - seed (int | None): Random seed. If None, uses external RNG state. Defaults to ``None``.
                - num_nodes (int): Number of nodes for distributed inference. Defaults to ``1``.
                - devices_per_node (int): Number of devices per node. Defaults to ``1``.
                - verbose (bool): If True, show detailed logging and config trees. Defaults to ``False``.
        """
        # Set MSA directories from environment variable only
        if env_var_msa_dirs := get_msa_dirs_from_env(raise_if_not_set=False):
            override_msa_dirs = [str(msa_dir) for msa_dir in env_var_msa_dirs]
            ranked_logger.debug(
                f"Using MSA directories from environment variable: {override_msa_dirs}"
            )
        else:
            override_msa_dirs = []
            ranked_logger.debug(
                "No MSA directories set (LOCAL_MSA_DIRS env var not found)"
            )

        super().__init__(
            transform_overrides={
                "diffusion_batch_size": diffusion_batch_size,
                "n_recycles": n_recycles,
                "raise_if_missing_msa_for_protein_of_length_n": raise_if_missing_msa_for_protein_of_length_n,
                "undesired_res_names": [],
                "template_noise_scales": {
                    "atomized": template_noise_scale,
                    "not_atomized": template_noise_scale,
                },
                "allowed_chain_types_for_conditioning": None,
                "protein_msa_dirs": [
                    {
                        "dir": msa_dir,
                        "extension": extension.value,
                        "directory_depth": depth,
                    }
                    for msa_dir, depth, extension in [
                        (msa_dir, *get_msa_depth_and_ext_from_folder(Path(msa_dir)))
                        for msa_dir in override_msa_dirs
                    ]
                ],
                "rna_msa_dirs": [],
                # (Paranoia - in validation, these should be set correctly anyhow)
                "p_give_polymer_ref_conf": 0.0,
                "p_give_non_polymer_ref_conf": 0.0,
                "p_dropout_ref_conf": 0.0,
                "use_element_for_atom_names_of_atomized_tokens": True,
            },
            inference_sampler_overrides={
                "num_timesteps": num_steps,
            },
            **kwargs,
        )

        # Remove loss override if present (i.e. keep the loss config from the checkpoint)
        self.overrides["trainer"].pop("loss", None)

        # Store metrics config for later - it is set directly on the trainer in initialize()
        self._metrics_cfg = metrics_cfg

        # Output and early-stopping settings
        self.early_stopping_plddt_threshold = early_stopping_plddt_threshold
        self.compress_outputs = compress_outputs

    def initialize(self):
        # Log checkpoint path on first init (base class logger may be suppressed in quiet mode)
        if not self.initialized_:
            ranked_logger.info(
                f"Loading checkpoint from {Path(self.ckpt_path).resolve()}..."
            )

        cfg = super().initialize()

        if cfg is not None:
            self.cfg = cfg  # store for later use

        # Set trainer metrics directly based on what was requested
        # This bypasses the OmegaConf merge issue with empty dicts
        if isinstance(self._metrics_cfg, MetricManager):
            # Already instantiated - use directly
            self.trainer.metrics = self._metrics_cfg
        elif self._metrics_cfg == "default":
            # Use default RF3 metrics (ptm, iptm, clashing chains)
            self.trainer.metrics = MetricManager.instantiate_from_hydra(
                metrics_cfg=DEFAULT_RF3_METRICS_CFG
            )
        elif self._metrics_cfg is not None:
            # Hydra config dict - instantiate MetricManager
            self.trainer.metrics = MetricManager.instantiate_from_hydra(
                metrics_cfg=self._metrics_cfg
            )
        else:
            # No metrics requested - disable them
            self.trainer.metrics = None

        return cfg

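    # Editor's note - a custom metrics_cfg uses the same Hydra "_target_" layout as
    # DEFAULT_RF3_METRICS_CFG at the top of this module. Hypothetical usage (the
    # checkpoint path is a placeholder):
    #
    #   engine = RF3InferenceEngine(
    #       ckpt_path="rf3_latest.pt",
    #       metrics_cfg={"ptm": {"_target_": "rf3.metrics.predicted_error.ComputePTM"}},
    #   )
    #   engine.initialize()  # instantiates a MetricManager from the dict and attaches it to the trainer
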
    def run(
        self,
        inputs: (
            InferenceInput
            | list[InferenceInput]
            | AtomArray
            | list[AtomArray]
            | PathLike
            | list[PathLike]
        ),
        # Output control
        out_dir: PathLike | None = None,
        dump_predictions: bool = True,
        dump_trajectories: bool = False,
        one_model_per_file: bool = False,
        annotate_b_factor_with_plddt: bool = False,
        sharding_pattern: str | None = None,
        skip_existing: bool = False,
        # Selection overrides (applied to all input types)
        template_selection: list[str] | str | None = None,
        ground_truth_conformer_selection: list[str] | str | None = None,
        cyclic_chains: list[str] = [],
    ) -> dict[str, dict] | None:
        """Run inference on inputs.

        The engine is initialized on first use; subsequent calls reuse the loaded model.

        Args:
            inputs: Single/list of InferenceInput objects, AtomArray objects, file paths, or a directory.
            out_dir: Output directory. If None, results are returned in memory (see Returns). Defaults to ``None``.
            dump_predictions: Whether to save predicted structures. Defaults to ``True``.
            dump_trajectories: Whether to save diffusion trajectories. Defaults to ``False``.
            one_model_per_file: Save each model in a separate file. Defaults to ``False``.
            annotate_b_factor_with_plddt: Write pLDDT to B-factor column. Defaults to ``False``.
            sharding_pattern: Sharding pattern for output organization. Defaults to ``None``.
            skip_existing: Skip inputs with existing outputs. Requires ``out_dir`` to be set. If ``True`` when ``out_dir=None``, a warning is logged and skipping is disabled. Defaults to ``False``.
            template_selection: Template selection override. Defaults to ``None``.
            ground_truth_conformer_selection: Conformer selection override. Defaults to ``None``.
            cyclic_chains: List of chain IDs to cyclize. Defaults to ``[]``.

        Returns:
            If ``out_dir`` is None: Dict mapping example_id to a list of RF3Output objects
            (early-stopped examples map to a small summary dict instead).
            If ``out_dir`` is set: None (results saved to disk).
        """
        self.initialize()

        # Setup output directory if provided
        out_dir = Path(out_dir) if out_dir else None
        if out_dir:
            out_dir.mkdir(parents=True, exist_ok=True)
            ranked_logger.info(f"Outputs will be written to {out_dir.resolve()}.")
        if not out_dir:
            ranked_logger.warning(
                "out_dir is None - results will be returned in memory! If you want to save to disk, please provide an out_dir."
            )

        # Validate skip_existing configuration
        if skip_existing and out_dir is None:
            ranked_logger.warning(
                "skip_existing=True requires out_dir to be set. "
                "Disabling skip_existing for in-memory inference mode."
            )
            skip_existing = False

        # Determine file type based on compression setting
        file_type = "cif.gz" if self.compress_outputs else "cif"

        # Convert inputs to InferenceInput objects
        if isinstance(inputs, InferenceInput):
            inference_inputs = [inputs]
        elif isinstance(inputs, list) and all(
            isinstance(i, InferenceInput) for i in inputs
        ):
            inference_inputs = inputs
        elif isinstance(inputs, AtomArray):
            # Single AtomArray - convert to InferenceInput
            inference_inputs = [
                InferenceInput.from_atom_array(
                    inputs,
                    template_selection=template_selection,
                    ground_truth_conformer_selection=ground_truth_conformer_selection,
                )
            ]
        elif isinstance(inputs, list) and all(isinstance(i, AtomArray) for i in inputs):
            # List of AtomArrays - convert each to InferenceInput
            inference_inputs = [
                InferenceInput.from_atom_array(
                    arr,
                    example_id=f"inference_{i}",
                    template_selection=template_selection,
                    ground_truth_conformer_selection=ground_truth_conformer_selection,
                )
                for i, arr in enumerate(inputs)
            ]
        elif isinstance(inputs, (str, Path)) or (
            isinstance(inputs, list) and isinstance(inputs[0], (str, Path))
        ):
            inference_inputs = prepare_inference_inputs_from_paths(
                inputs=inputs,
                existing_outputs_dir=out_dir if skip_existing else None,
                sharding_pattern=sharding_pattern,
                template_selection=template_selection,
                ground_truth_conformer_selection=ground_truth_conformer_selection,
            )
        else:
            raise ValueError(f"Unsupported inputs type: {type(inputs)}")

        # Flag chains for cyclization if specified
        if cyclic_chains:
            for input_spec in inference_inputs:
                input_spec.cyclic_chains = cyclic_chains

        # Make InferenceInputDataset
        inference_dataset = InferenceInputDataset(inference_inputs)
        ranked_logger.info(f"Found {len(inference_dataset)} structures to predict!")

        # Make LoadBalancedDistributedSampler
        sampler = LoadBalancedDistributedSampler(
            dataset=inference_dataset,
            key_to_balance=inference_dataset.key_to_balance,
            num_replicas=self.trainer.fabric.world_size,
            rank=self.trainer.fabric.global_rank,
            drop_last=False,
        )

        loader = DataLoader(
            dataset=inference_dataset,
            sampler=sampler,
            batch_size=1,
            num_workers=0,  # multiprocessing disabled; InferenceInput objects are cheap to read
            collate_fn=lambda x: x,  # no collation since we're not batching
            pin_memory=True,
            drop_last=False,
        )

        # Prepare results dict (if returning in-memory)
        results = {} if out_dir is None else None

        # Main inference loop
        for batch_idx, input_spec in enumerate(loader):
            input_spec = input_spec[
                0
            ]  # since we're not batching, the loader returns a list of length 1
            ranked_logger.info(
                f"Predicting structure {batch_idx + 1}/{len(loader)}: {input_spec.example_id}"
            )

            # Create output directory for this example if saving to disk
            if out_dir:
                example_out_dir = get_sharded_output_path(
                    input_spec.example_id, out_dir, sharding_pattern
                )
                example_out_dir.mkdir(parents=True, exist_ok=True)

            # Run through Transform pipeline
            pipeline_output = self.pipeline(input_spec.to_pipeline_input())

            # Setup early stopping function if configured
            should_early_stop_fn = None
            if (
                "confidence_feats" in pipeline_output
                and self.early_stopping_plddt_threshold
                and self.early_stopping_plddt_threshold > 0
            ):
                should_early_stop_fn = should_early_stop_by_mean_plddt(
                    self.early_stopping_plddt_threshold,
                    pipeline_output["confidence_feats"]["is_real_atom"],
                    self.cfg.trainer.loss.confidence_loss.plddt.max_value,
                )

            # Model inference
            with torch.no_grad():
                pipeline_output = self.trainer.fabric.to_device(pipeline_output)
                if should_early_stop_fn:
                    valid_step_outs = self.trainer.validation_step(
                        batch=pipeline_output,
                        batch_idx=0,
                        compute_metrics=True,
                        should_early_stop_fn=should_early_stop_fn,
                    )
                else:
                    valid_step_outs = self.trainer.validation_step(
                        batch=pipeline_output,
                        batch_idx=0,
                        compute_metrics=True,
                    )
            network_output = valid_step_outs["network_output"]
            metrics_output = valid_step_outs["metrics_output"]

            # Handle early stopping
            if network_output.get("early_stopped", False):
                ranked_logger.warning(
                    f"Early stopping triggered for {input_spec.example_id} "
                    f"with mean pLDDT {network_output['mean_plddt']:.2f} < "
                    f"{self.early_stopping_plddt_threshold:.2f}!"
                )

                if out_dir:
                    # Save early stop info to disk
                    dict_to_save = {
                        k: v for k, v in network_output.items() if v is not None
                    }
                    df_to_save = pd.DataFrame([dict_to_save])
                    df_to_save.to_csv(example_out_dir / "score.csv", index=False)

                    df_to_save = pd.DataFrame([metrics_output])
                    df_to_save.to_csv(
                        example_out_dir / f"{input_spec.example_id}_metrics.csv",
                        index=False,
                    )
                else:
                    # Store in results dict
                    results[input_spec.example_id] = {
                        "early_stopped": True,
                        "mean_plddt": network_output["mean_plddt"],
                        "metrics": metrics_output,
                    }

                continue

            # Build predicted structures
            atom_array_stack = build_stack_from_atom_array_and_batched_coords(
                network_output["X_L"], pipeline_output["atom_array"]
            )
            num_samples = (
                len(atom_array_stack)
                if isinstance(atom_array_stack, AtomArrayStack)
                else 1
            )

            # Build RF3Output objects for each sample
            rf3_outputs: list[RF3Output] = []
            for sample_idx in range(num_samples):
                # Get atom array for this sample
                if isinstance(atom_array_stack, AtomArrayStack):
                    sample_atom_array = atom_array_stack[sample_idx]
                else:
                    sample_atom_array = atom_array_stack

                # Compile confidence outputs in AF3 format (if available)
                summary_confidences = {}
                confidences = None
                if "plddt" in network_output:
                    conf_outs = compile_af3_style_confidence_outputs(
                        plddt_logits=network_output["plddt"],
                        pae_logits=network_output["pae"],
                        pde_logits=network_output["pde"],
                        chain_iid_token_lvl=pipeline_output["ground_truth"][
                            "chain_iid_token_lvl"
                        ],
                        is_real_atom=pipeline_output["confidence_feats"][
                            "is_real_atom"
                        ],
                        atom_array=pipeline_output["atom_array"],
                        confidence_loss_cfg=self.cfg.trainer.loss.confidence_loss,
                        batch_idx=sample_idx,
                    )
                    summary_confidences = conf_outs["summary_confidences"]
                    confidences = conf_outs["confidences"]

                    # Annotate b-factor with pLDDT if requested
                    if annotate_b_factor_with_plddt:
                        atom_array_list = annotate_atom_array_b_factor_with_plddt(
                            atom_array_stack,
                            conf_outs["plddt"],
                            pipeline_output["confidence_feats"]["is_real_atom"],
                        )
                        sample_atom_array = atom_array_list[sample_idx]

                # Add metrics (ptm, iptm, has_clash) to summary_confidences
                if metrics_output:
                    ptm_key = f"ptm.ptm_{sample_idx}"
                    iptm_key = f"iptm.iptm_{sample_idx}"
                    clash_key = f"count_clashing_chains.has_clash_{sample_idx}"

                    ptm_val = metrics_output.get(ptm_key)
                    iptm_val = metrics_output.get(iptm_key)
                    has_clash = bool(metrics_output.get(clash_key, 0))

                    # Convert to native Python floats for JSON serialization
                    ptm = float(ptm_val) if ptm_val is not None else None
                    iptm = float(iptm_val) if iptm_val is not None else None

                    summary_confidences["ptm"] = ptm
                    summary_confidences["iptm"] = iptm
                    summary_confidences["has_clash"] = has_clash

                    ranking_score = compute_ranking_score(
                        iptm=iptm,
                        ptm=ptm,
                        has_clash=has_clash,
                    )
                    summary_confidences["ranking_score"] = round(ranking_score, 4)

                rf3_outputs.append(
                    RF3Output(
                        example_id=input_spec.example_id,
                        atom_array=sample_atom_array,
                        summary_confidences=summary_confidences,
                        confidences=confidences,
                        sample_idx=sample_idx,
                        seed=self.seed if self.seed is not None else 0,
                    )
                )

            # Save or return results
            if out_dir:
                # Save to disk in AlphaFold3-style directory structure
                # Top-level: ranking_scores.csv, best model, best summary
                dump_ranking_scores(rf3_outputs, example_out_dir, input_spec.example_id)
                dump_top_ranked_outputs(
                    rf3_outputs,
                    example_out_dir,
                    input_spec.example_id,
                    file_type=file_type,
                )

                # Per-sample subdirectories
                if dump_predictions:
                    for rf3_out in rf3_outputs:
                        sample_subdir = (
                            example_out_dir
                            / f"seed-{rf3_out.seed}_sample-{rf3_out.sample_idx}"
                        )
                        rf3_out.dump(
                            out_dir=sample_subdir,
                            file_type=file_type,
                            dump_full_confidences=True,
                        )

                if dump_trajectories:
                    dump_structures(
                        atom_arrays=network_output["X_denoised_L_traj"],
                        base_path=example_out_dir / "denoised",
                        one_model_per_file=True,
                        file_type=file_type,
                    )
                    dump_structures(
                        atom_arrays=network_output["X_noisy_L_traj"],
                        base_path=example_out_dir / "noisy",
                        one_model_per_file=True,
                        file_type=file_type,
                    )

                ranked_logger.info(
                    f"Outputs for {input_spec.example_id} written to {example_out_dir}!"
                )
            else:
                # Store in memory - return list of RF3Output objects
                results[input_spec.example_id] = rf3_outputs

        # Merge results across ranks
        self.trainer.fabric.barrier()
        if results is not None and dist.is_initialized():
            gathered_results = [None] * self.trainer.fabric.world_size
            dist.all_gather_object(
                gathered_results, results
            )  # returns a list of dicts, need to combine them
            gathered_results = {
                k: v for result in gathered_results for k, v in result.items()
            }  # combine the dicts into a single dict
            results = gathered_results

        return results
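For readers scanning the diff, a minimal sketch of the in-memory path implied by run() above (out_dir=None): the call returns a dict keyed by example_id whose values are lists of RF3Output objects (or, for early-stopped examples, a small summary dict), which the caller can rank and dump selectively. Identifiers mirror the module above; the checkpoint path, input directory, and output directory are placeholders, and the import path assumes the rf3/inference_engines/rf3.py location shown in the file list.

    from pathlib import Path

    from rf3.inference_engines.rf3 import RF3InferenceEngine

    engine = RF3InferenceEngine(ckpt_path="rf3_latest.pt", diffusion_batch_size=5)
    results = engine.run(inputs="path/to/cifs", out_dir=None)

    for example_id, outputs in results.items():
        if isinstance(outputs, dict):
            # Early-stopped example: only a summary dict (mean_plddt, metrics) is available
            continue
        # Pick the sample with the highest AF3-style ranking score and write it out
        best = max(
            outputs,
            key=lambda o: o.summary_confidences.get("ranking_score", float("-inf")),
        )
        best.dump(Path("./predictions") / example_id)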