rc-foundry 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- foundry/__init__.py +57 -0
- foundry/callbacks/__init__.py +5 -0
- foundry/callbacks/callback.py +116 -0
- foundry/callbacks/health_logging.py +419 -0
- foundry/callbacks/metrics_logging.py +211 -0
- foundry/callbacks/timing_logging.py +67 -0
- foundry/callbacks/train_logging.py +278 -0
- foundry/common.py +108 -0
- foundry/constants.py +28 -0
- foundry/hydra/resolvers.py +77 -0
- foundry/inference_engines/base.py +235 -0
- foundry/inference_engines/checkpoint_registry.py +66 -0
- foundry/metrics/__init__.py +12 -0
- foundry/metrics/losses.py +30 -0
- foundry/metrics/metric.py +319 -0
- foundry/model/layers/blocks.py +47 -0
- foundry/testing/__init__.py +6 -0
- foundry/testing/fixtures.py +19 -0
- foundry/testing/pytest_hooks.py +15 -0
- foundry/trainers/fabric.py +923 -0
- foundry/training/EMA.py +67 -0
- foundry/training/checkpoint.py +61 -0
- foundry/training/schedulers.py +91 -0
- foundry/utils/alignment.py +86 -0
- foundry/utils/components.py +415 -0
- foundry/utils/datasets.py +405 -0
- foundry/utils/ddp.py +103 -0
- foundry/utils/instantiators.py +72 -0
- foundry/utils/logging.py +279 -0
- foundry/utils/rigid.py +1460 -0
- foundry/utils/rotation_augmentation.py +65 -0
- foundry/utils/squashfs.py +172 -0
- foundry/utils/torch.py +317 -0
- foundry/utils/weights.py +271 -0
- foundry/version.py +34 -0
- foundry_cli/__init__.py +3 -0
- foundry_cli/download_checkpoints.py +281 -0
- mpnn/__init__.py +1 -0
- mpnn/collate/feature_collator.py +265 -0
- mpnn/inference.py +53 -0
- mpnn/inference_engines/mpnn.py +549 -0
- mpnn/loss/nll_loss.py +122 -0
- mpnn/metrics/nll.py +369 -0
- mpnn/metrics/sequence_recovery.py +440 -0
- mpnn/model/layers/graph_embeddings.py +2372 -0
- mpnn/model/layers/message_passing.py +332 -0
- mpnn/model/layers/position_wise_feed_forward.py +44 -0
- mpnn/model/layers/positional_encoding.py +98 -0
- mpnn/model/mpnn.py +2632 -0
- mpnn/pipelines/mpnn.py +162 -0
- mpnn/samplers/samplers.py +167 -0
- mpnn/train.py +341 -0
- mpnn/trainers/mpnn.py +193 -0
- mpnn/transforms/feature_aggregation/mpnn.py +184 -0
- mpnn/transforms/feature_aggregation/polymer_ligand_interface.py +76 -0
- mpnn/transforms/feature_aggregation/token_encodings.py +132 -0
- mpnn/transforms/feature_aggregation/user_settings.py +347 -0
- mpnn/transforms/polymer_ligand_interface.py +164 -0
- mpnn/utils/inference.py +2397 -0
- mpnn/utils/probability.py +37 -0
- mpnn/utils/weights.py +309 -0
- rc_foundry-0.1.1.dist-info/METADATA +239 -0
- rc_foundry-0.1.1.dist-info/RECORD +180 -0
- rc_foundry-0.1.1.dist-info/WHEEL +4 -0
- rc_foundry-0.1.1.dist-info/entry_points.txt +5 -0
- rc_foundry-0.1.1.dist-info/licenses/LICENSE.md +28 -0
- rf3/__init__.py +3 -0
- rf3/_version.py +33 -0
- rf3/alignment.py +79 -0
- rf3/callbacks/dump_validation_structures.py +101 -0
- rf3/callbacks/metrics_logging.py +324 -0
- rf3/chemical.py +1529 -0
- rf3/cli.py +77 -0
- rf3/data/cyclic_transform.py +78 -0
- rf3/data/extra_xforms.py +36 -0
- rf3/data/ground_truth_template.py +463 -0
- rf3/data/paired_msa.py +206 -0
- rf3/data/pipeline_utils.py +128 -0
- rf3/data/pipelines.py +558 -0
- rf3/diffusion_samplers/inference_sampler.py +222 -0
- rf3/inference.py +65 -0
- rf3/inference_engines/__init__.py +5 -0
- rf3/inference_engines/rf3.py +735 -0
- rf3/kinematics.py +354 -0
- rf3/loss/af3_confidence_loss.py +515 -0
- rf3/loss/af3_losses.py +655 -0
- rf3/loss/loss.py +179 -0
- rf3/metrics/chiral.py +179 -0
- rf3/metrics/clashing_chains.py +68 -0
- rf3/metrics/distogram.py +421 -0
- rf3/metrics/lddt.py +523 -0
- rf3/metrics/metadata.py +43 -0
- rf3/metrics/metric_utils.py +192 -0
- rf3/metrics/predicted_error.py +134 -0
- rf3/metrics/rasa.py +108 -0
- rf3/metrics/selected_distances.py +91 -0
- rf3/model/RF3.py +527 -0
- rf3/model/RF3_blocks.py +92 -0
- rf3/model/RF3_structure.py +303 -0
- rf3/model/layers/af3_auxiliary_heads.py +255 -0
- rf3/model/layers/af3_diffusion_transformer.py +544 -0
- rf3/model/layers/attention.py +313 -0
- rf3/model/layers/layer_utils.py +127 -0
- rf3/model/layers/mlff.py +118 -0
- rf3/model/layers/outer_product.py +59 -0
- rf3/model/layers/pairformer_layers.py +783 -0
- rf3/model/layers/structure_bias.py +56 -0
- rf3/scoring.py +1787 -0
- rf3/symmetry/resolve.py +284 -0
- rf3/train.py +194 -0
- rf3/trainers/rf3.py +570 -0
- rf3/util_module.py +47 -0
- rf3/utils/frames.py +109 -0
- rf3/utils/inference.py +665 -0
- rf3/utils/io.py +198 -0
- rf3/utils/loss.py +72 -0
- rf3/utils/predict_and_score.py +165 -0
- rf3/utils/predicted_error.py +673 -0
- rf3/utils/recycling.py +42 -0
- rf3/validate.py +140 -0
- rfd3/.gitignore +7 -0
- rfd3/Makefile +76 -0
- rfd3/__init__.py +12 -0
- rfd3/callbacks.py +66 -0
- rfd3/cli.py +41 -0
- rfd3/constants.py +212 -0
- rfd3/engine.py +543 -0
- rfd3/inference/datasets.py +193 -0
- rfd3/inference/input_parsing.py +1123 -0
- rfd3/inference/legacy_input_parsing.py +717 -0
- rfd3/inference/parsing.py +165 -0
- rfd3/inference/symmetry/atom_array.py +298 -0
- rfd3/inference/symmetry/checks.py +241 -0
- rfd3/inference/symmetry/contigs.py +63 -0
- rfd3/inference/symmetry/frames.py +355 -0
- rfd3/inference/symmetry/symmetry_utils.py +398 -0
- rfd3/metrics/design_metrics.py +465 -0
- rfd3/metrics/hbonds_hbplus_metrics.py +308 -0
- rfd3/metrics/hbonds_metrics.py +389 -0
- rfd3/metrics/losses.py +325 -0
- rfd3/metrics/metrics_utils.py +118 -0
- rfd3/metrics/sidechain_metrics.py +349 -0
- rfd3/model/RFD3.py +105 -0
- rfd3/model/RFD3_diffusion_module.py +387 -0
- rfd3/model/cfg_utils.py +81 -0
- rfd3/model/inference_sampler.py +635 -0
- rfd3/model/layers/attention.py +577 -0
- rfd3/model/layers/block_utils.py +580 -0
- rfd3/model/layers/blocks.py +777 -0
- rfd3/model/layers/chunked_pairwise.py +377 -0
- rfd3/model/layers/encoders.py +417 -0
- rfd3/model/layers/layer_utils.py +197 -0
- rfd3/model/layers/pairformer_layers.py +128 -0
- rfd3/run_inference.py +45 -0
- rfd3/testing/debug.py +139 -0
- rfd3/testing/debug_utils.py +73 -0
- rfd3/testing/testing_utils.py +356 -0
- rfd3/train.py +194 -0
- rfd3/trainer/dump_validation_structures.py +154 -0
- rfd3/trainer/fabric_trainer.py +923 -0
- rfd3/trainer/recycling.py +42 -0
- rfd3/trainer/rfd3.py +485 -0
- rfd3/trainer/trainer_utils.py +502 -0
- rfd3/transforms/conditioning_base.py +508 -0
- rfd3/transforms/conditioning_utils.py +200 -0
- rfd3/transforms/design_transforms.py +807 -0
- rfd3/transforms/dna_crop.py +523 -0
- rfd3/transforms/hbonds.py +407 -0
- rfd3/transforms/hbonds_hbplus.py +246 -0
- rfd3/transforms/ncaa_transforms.py +153 -0
- rfd3/transforms/pipelines.py +632 -0
- rfd3/transforms/ppi_transforms.py +541 -0
- rfd3/transforms/rasa.py +116 -0
- rfd3/transforms/symmetry.py +76 -0
- rfd3/transforms/training_conditions.py +552 -0
- rfd3/transforms/util_transforms.py +498 -0
- rfd3/transforms/virtual_atoms.py +305 -0
- rfd3/utils/inference.py +648 -0
- rfd3/utils/io.py +245 -0
- rfd3/utils/vizualize.py +276 -0
|
@@ -0,0 +1,735 @@
|
|
|
1
|
+
import json
|
|
2
|
+
import logging
|
|
3
|
+
import re
|
|
4
|
+
from dataclasses import dataclass, field
|
|
5
|
+
from os import PathLike
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
from typing import TextIO
|
|
8
|
+
|
|
9
|
+
import pandas as pd
|
|
10
|
+
import torch
|
|
11
|
+
import torch.distributed as dist
|
|
12
|
+
from atomworks.io.utils.io_utils import to_cif_file
|
|
13
|
+
from atomworks.ml.preprocessing.msa.finding import (
|
|
14
|
+
get_msa_depth_and_ext_from_folder,
|
|
15
|
+
get_msa_dirs_from_env,
|
|
16
|
+
)
|
|
17
|
+
from atomworks.ml.samplers import LoadBalancedDistributedSampler
|
|
18
|
+
from biotite.structure import AtomArray, AtomArrayStack
|
|
19
|
+
from omegaconf import OmegaConf
|
|
20
|
+
from torch.utils.data import DataLoader
|
|
21
|
+
|
|
22
|
+
from foundry.inference_engines.base import BaseInferenceEngine
|
|
23
|
+
from foundry.metrics.metric import MetricManager
|
|
24
|
+
from foundry.utils.ddp import RankedLogger
|
|
25
|
+
from rf3.model.RF3 import ShouldEarlyStopFn
|
|
26
|
+
from rf3.utils.inference import (
|
|
27
|
+
InferenceInput,
|
|
28
|
+
InferenceInputDataset,
|
|
29
|
+
prepare_inference_inputs_from_paths,
|
|
30
|
+
)
|
|
31
|
+
from rf3.utils.io import (
|
|
32
|
+
build_stack_from_atom_array_and_batched_coords,
|
|
33
|
+
dump_structures,
|
|
34
|
+
get_sharded_output_path,
|
|
35
|
+
)
|
|
36
|
+
from rf3.utils.predicted_error import (
|
|
37
|
+
annotate_atom_array_b_factor_with_plddt,
|
|
38
|
+
compile_af3_style_confidence_outputs,
|
|
39
|
+
get_mean_atomwise_plddt,
|
|
40
|
+
)
|
|
41
|
+
|
|
42
|
+
# Configure root logging when this module is imported: INFO level, records
# prefixed with wall-clock time, level and logger name.
# NOTE(review): calling basicConfig from library code configures the root
# logger for every consumer of the package; consider moving it to the CLI.
logging.basicConfig(
    format="%(asctime)s %(levelname)s %(name)s: %(message)s",
    datefmt="%H:%M:%S",
    level=logging.INFO,
)
# Module logger; rank_zero_only=True presumably suppresses output on
# non-zero ranks in distributed runs (see foundry.utils.ddp.RankedLogger).
ranked_logger = RankedLogger(__name__, rank_zero_only=True)
|
|
48
|
+
|
|
49
|
+
# Default Hydra-style metrics config used for RF3 inference: maps each metric
# name to a ``{"_target_": <import path>}`` spec (ptm, iptm, clashing chains).
DEFAULT_RF3_METRICS_CFG = {
    metric_name: {"_target_": target_path}
    for metric_name, target_path in (
        ("ptm", "rf3.metrics.predicted_error.ComputePTM"),
        ("iptm", "rf3.metrics.predicted_error.ComputeIPTM"),
        (
            "count_clashing_chains",
            "rf3.metrics.clashing_chains.CountClashingChains",
        ),
    )
}
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
def dump_json_compact_arrays(obj: dict, f: TextIO) -> None:
    """Dump JSON with indented structure but compact arrays (AF3 style).

    Dicts, and arrays that contain nested dicts/arrays, are pretty-printed with
    a two-space indent. "Leaf" arrays (whose elements are all scalars) are
    written on a single line without whitespace, e.g. ``[1,2,3]``; a 2-D array
    therefore becomes one compact line per row.

    Leaf arrays are swapped for unique placeholder strings before the indented
    dump and substituted back afterwards. This keeps string elements verbatim:
    a text post-pass that splits collapsed arrays on ``","`` would rewrite
    ``"A, B"`` as ``"A,B"``, and a ``]`` inside a string would defeat a
    bracket-matching regex.

    Args:
        obj: JSON-serializable object to write.
        f: Writable text stream that receives the JSON document.
    """
    import uuid

    # Random prefix so placeholders cannot collide with real string values.
    token_prefix = f"__compact_array_{uuid.uuid4().hex}_"
    leaf_arrays: list[str] = []

    def _replace_leaf_arrays(node):
        # Returns a copy of ``node`` where every leaf array is a placeholder string.
        if isinstance(node, dict):
            return {key: _replace_leaf_arrays(value) for key, value in node.items()}
        if isinstance(node, (list, tuple)):
            if any(isinstance(item, (dict, list, tuple)) for item in node):
                return [_replace_leaf_arrays(item) for item in node]
            leaf_arrays.append(json.dumps(list(node), separators=(",", ":")))
            return f"{token_prefix}{len(leaf_arrays) - 1}"
        return node

    json_str = json.dumps(_replace_leaf_arrays(obj), indent=2)

    # Single pass over the text; a callable replacement is inserted literally,
    # so backslashes inside the compact JSON are not reinterpreted.
    token_pattern = re.compile(rf'"{re.escape(token_prefix)}(\d+)"')
    json_str = token_pattern.sub(lambda m: leaf_arrays[int(m.group(1))], json_str)
    f.write(json_str)
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
def compute_ranking_score(
|
|
80
|
+
iptm: float | None,
|
|
81
|
+
ptm: float | None,
|
|
82
|
+
has_clash: bool,
|
|
83
|
+
) -> float:
|
|
84
|
+
"""Compute ranking score.
|
|
85
|
+
|
|
86
|
+
Formula: 0.8 * ipTM + 0.2 * pTM - 100 * has_clash
|
|
87
|
+
|
|
88
|
+
For single-chain predictions where ipTM is None, uses pTM only.
|
|
89
|
+
"""
|
|
90
|
+
if iptm is None:
|
|
91
|
+
# Single chain - use pTM only
|
|
92
|
+
iptm = ptm if ptm is not None else 0.0
|
|
93
|
+
if ptm is None:
|
|
94
|
+
ptm = 0.0
|
|
95
|
+
return 0.8 * iptm + 0.2 * ptm - 100 * int(has_clash)
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
@dataclass
class RF3Output:
    """Container for one RF3 prediction, analogous to RFD3Output.

    Holds a predicted structure together with its confidence metrics and
    serializes them in an AlphaFold3-compatible layout via ``dump``.
    """

    # Identifier used as the filename prefix for everything ``dump`` writes.
    example_id: str
    # Predicted structure.
    atom_array: AtomArray
    # Summary metrics; always written to ``*_summary_confidences.json``.
    summary_confidences: dict = field(default_factory=dict)
    # Optional full confidences; written to ``*_confidences.json`` when non-empty.
    confidences: dict | None = None
    # Sample index, used in the output filename.
    sample_idx: int = 0
    # Seed, used in the output filename.
    seed: int = 0

    def dump(
        self,
        out_dir: Path,
        file_type: str = "cif",
        dump_full_confidences: bool = True,
    ) -> None:
        """Write this prediction to ``out_dir`` (created if missing).

        Every file is prefixed with ``{example_id}_seed-{seed}_sample-{sample_idx}``:
        the structure goes through ``to_cif_file`` with stem ``<prefix>_model``,
        the summary metrics to ``<prefix>_summary_confidences.json``, and the
        full confidences to ``<prefix>_confidences.json``.

        Args:
            out_dir: Directory to save outputs to.
            file_type: File type for structure output ("cif" or "cif.gz").
            dump_full_confidences: Whether to also write the full confidences
                (skipped anyway when ``confidences`` is empty or None).
        """
        target_dir = Path(out_dir)
        target_dir.mkdir(parents=True, exist_ok=True)
        stem = target_dir / f"{self.example_id}_seed-{self.seed}_sample-{self.sample_idx}"

        to_cif_file(
            self.atom_array,
            f"{stem}_model",
            file_type=file_type,
            include_entity_poly=False,
        )

        with open(f"{stem}_summary_confidences.json", "w") as handle:
            dump_json_compact_arrays(self.summary_confidences, handle)

        if not (dump_full_confidences and self.confidences):
            return
        with open(f"{stem}_confidences.json", "w") as handle:
            dump_json_compact_arrays(self.confidences, handle)
|
|
147
|
+
|
|
148
|
+
|
|
149
|
+
def dump_ranking_scores(
    outputs: list[RF3Output],
    out_dir: Path,
    example_id: str,
) -> None:
    """Write {example_id}_ranking_scores.csv with ranking scores for all samples.

    One row per output, in the given order, with columns ``seed``, ``sample``
    and ``ranking_score`` (from ``summary_confidences["ranking_score"]``; left
    empty when that key is missing).

    Args:
        outputs: Predictions to tabulate. May be empty, in which case a
            header-only CSV is written.
        out_dir: Existing directory to write the CSV into (str or Path).
        example_id: Prefix of the CSV filename.
    """
    rows = [
        {
            "seed": o.seed,
            "sample": o.sample_idx,
            "ranking_score": o.summary_confidences.get("ranking_score"),
        }
        for o in outputs
    ]
    # Explicit columns keep the header (and column order) even when ``outputs`` is empty.
    df = pd.DataFrame(rows, columns=["seed", "sample", "ranking_score"])
    df.to_csv(Path(out_dir) / f"{example_id}_ranking_scores.csv", index=False)
|
|
165
|
+
|
|
166
|
+
|
|
167
|
+
def dump_top_ranked_outputs(
    outputs: list[RF3Output],
    out_dir: Path,
    example_id: str,
    file_type: str = "cif",
    dump_full_confidences: bool = True,
) -> RF3Output:
    """Copy the top-ranked model and its confidences to the top-level directory.

    The output with the highest ``summary_confidences["ranking_score"]`` wins.
    Ties go to the earliest output, and a missing or ``None`` score ranks last.
    Writes the structure via ``to_cif_file`` with stem ``{example_id}_model``,
    plus ``{example_id}_summary_confidences.json`` and, when requested and
    present, ``{example_id}_confidences.json``.

    Args:
        outputs: Candidate predictions; must be non-empty.
        out_dir: Destination directory (str or Path); it is not created here.
        example_id: Filename prefix for the written files.
        file_type: Structure file type passed to ``to_cif_file``.
        dump_full_confidences: Whether to write the full confidences file,
            mirroring ``RF3Output.dump``. Defaults to ``True`` (previous behavior).

    Returns:
        The top-ranked RF3Output.

    Raises:
        ValueError: If ``outputs`` is empty.
    """
    if not outputs:
        raise ValueError(f"No outputs to rank for example '{example_id}'.")

    out_dir = Path(out_dir)

    def _ranking_key(output):
        score = output.summary_confidences.get("ranking_score")
        # A missing or None score ranks last instead of breaking the comparison.
        return float("-inf") if score is None else score

    best_output = max(outputs, key=_ranking_key)

    # Save top-ranked model at top level
    to_cif_file(
        best_output.atom_array,
        out_dir / f"{example_id}_model",
        file_type=file_type,
        include_entity_poly=False,
    )

    # Save top-ranked summary_confidences at top level
    with open(out_dir / f"{example_id}_summary_confidences.json", "w") as f:
        dump_json_compact_arrays(best_output.summary_confidences, f)

    # Save top-ranked full confidences at top level (if requested and present)
    if dump_full_confidences and best_output.confidences:
        with open(out_dir / f"{example_id}_confidences.json", "w") as f:
            dump_json_compact_arrays(best_output.confidences, f)

    return best_output
|
|
201
|
+
|
|
202
|
+
|
|
203
|
+
def should_early_stop_by_mean_plddt(
    threshold: float, is_real_atom: torch.Tensor, max_value_of_plddt: float
) -> ShouldEarlyStopFn:
    """Build an early-stop callback that fires when mean pLDDT drops below ``threshold``.

    The returned callable reads ``confidence_outputs["plddt_logits"]``, adds a
    leading batch dimension, reduces it with ``get_mean_atomwise_plddt`` and
    returns ``(stop, info)``. ``stop`` is ``mean_plddt < threshold`` as a Python
    bool, and ``info`` carries the observed mean pLDDT and the threshold.
    """

    def _check_plddt(confidence_outputs: dict, **kwargs):
        batched_logits = confidence_outputs["plddt_logits"].unsqueeze(0)
        mean_plddt = get_mean_atomwise_plddt(
            plddt_logits=batched_logits,
            is_real_atom=is_real_atom,
            max_value=max_value_of_plddt,
        )
        below_threshold = (mean_plddt < threshold).item()
        diagnostics = {
            "mean_plddt": mean_plddt.item(),
            "threshold": threshold,
        }
        return below_threshold, diagnostics

    return _check_plddt
|
|
220
|
+
|
|
221
|
+
|
|
222
|
+
class RF3InferenceEngine(BaseInferenceEngine):
|
|
223
|
+
"""RF3 inference engine.
|
|
224
|
+
|
|
225
|
+
Separates model setup (expensive, once) from inference (can run multiple times).
|
|
226
|
+
|
|
227
|
+
Usage:
|
|
228
|
+
# Setup once
|
|
229
|
+
engine = RF3InferenceEngine(
|
|
230
|
+
ckpt_path="rf3_latest.pt",
|
|
231
|
+
n_recycles=10,
|
|
232
|
+
diffusion_batch_size=5,
|
|
233
|
+
)
|
|
234
|
+
|
|
235
|
+
# Run inference multiple times with different inputs
|
|
236
|
+
results1 = engine.run(inputs="path/to/cifs", out_dir="./predictions")
|
|
237
|
+
results2 = engine.run(inputs=InferenceInput.from_atom_array(array), out_dir=None)
|
|
238
|
+
results3 = engine.run(inputs=[input1, input2], out_dir="./more_predictions")
|
|
239
|
+
"""
|
|
240
|
+
|
|
241
|
+
def __init__(
    self,
    # Model parameters
    n_recycles: int = 10,
    diffusion_batch_size: int = 5,
    num_steps: int = 50,
    # Templating, MSAs, etc.
    template_noise_scale: float = 1e-5,
    raise_if_missing_msa_for_protein_of_length_n: int | None = None,
    # Output control
    compress_outputs: bool = False,
    early_stopping_plddt_threshold: float | None = None,
    # Metrics
    metrics_cfg: dict | OmegaConf | MetricManager | str | None = "default",
    **kwargs,
):
    """Configure the RF3 inference engine.

    The model config is loaded from the checkpoint by the base class; the
    arguments below are forwarded to it as transform / sampler overrides.

    Args:
        n_recycles: Number of recycles. Defaults to ``10``.
        diffusion_batch_size: Number of structures to generate per input. Defaults to ``5``.
        num_steps: Number of diffusion steps. Defaults to ``50``.
        template_noise_scale: Noise scale for template coordinates (used for both
            atomized and non-atomized templates). Defaults to ``1e-5``.
        raise_if_missing_msa_for_protein_of_length_n: Debug flag for MSA checking. Defaults to ``None``.
        compress_outputs: Whether to gzip output files. Defaults to ``False``.
        early_stopping_plddt_threshold: Stop early if pLDDT below threshold. Defaults to ``None``.
        metrics_cfg: Metrics configuration. Can be:
            - "default" to use standard RF3 metrics (ptm, iptm, clashing chains)
            - dict/OmegaConf with Hydra configs
            - Pre-instantiated MetricManager
            - None (no metrics).
            Defaults to ``"default"``.
        **kwargs: Additional arguments passed to BaseInferenceEngine:
            - ckpt_path (PathLike, required): Path to model checkpoint.
            - seed (int | None): Random seed. If None, uses external RNG state. Defaults to ``None``.
            - num_nodes (int): Number of nodes for distributed inference. Defaults to ``1``.
            - devices_per_node (int): Number of devices per node. Defaults to ``1``.
            - verbose (bool): If True, show detailed logging and config trees. Defaults to ``False``.
    """
    # Protein MSA directories are taken from the environment only.
    env_msa_dirs = get_msa_dirs_from_env(raise_if_not_set=False)
    if not env_msa_dirs:
        override_msa_dirs = []
        ranked_logger.debug(
            "No MSA directories set (LOCAL_MSA_DIRS env var not found)"
        )
    else:
        override_msa_dirs = [str(env_dir) for env_dir in env_msa_dirs]
        ranked_logger.debug(
            f"Using MSA directories from environment variable: {override_msa_dirs}"
        )

    # Probe every directory first, then describe each one for the transforms.
    probed_msa_dirs = [
        (msa_dir, *get_msa_depth_and_ext_from_folder(Path(msa_dir)))
        for msa_dir in override_msa_dirs
    ]
    protein_msa_dirs = [
        {
            "dir": msa_dir,
            "extension": extension.value,
            "directory_depth": depth,
        }
        for msa_dir, depth, extension in probed_msa_dirs
    ]

    transform_overrides = {
        "diffusion_batch_size": diffusion_batch_size,
        "n_recycles": n_recycles,
        "raise_if_missing_msa_for_protein_of_length_n": raise_if_missing_msa_for_protein_of_length_n,
        "undesired_res_names": [],
        "template_noise_scales": {
            "atomized": template_noise_scale,
            "not_atomized": template_noise_scale,
        },
        "allowed_chain_types_for_conditioning": None,
        "protein_msa_dirs": protein_msa_dirs,
        "rna_msa_dirs": [],
        # Reference-conformer randomization disabled outright (validation
        # pipelines should already set these, this is belt-and-braces).
        "p_give_polymer_ref_conf": 0.0,
        "p_give_non_polymer_ref_conf": 0.0,
        "p_dropout_ref_conf": 0.0,
        "use_element_for_atom_names_of_atomized_tokens": True,
    }
    super().__init__(
        transform_overrides=transform_overrides,
        inference_sampler_overrides={"num_timesteps": num_steps},
        **kwargs,
    )

    # Keep the checkpoint's loss config rather than any override.
    self.overrides["trainer"].pop("loss", None)

    # Resolved into trainer.metrics later, in initialize().
    self._metrics_cfg = metrics_cfg

    # Early-stopping and output-format settings.
    self.early_stopping_plddt_threshold = early_stopping_plddt_threshold
    self.compress_outputs = compress_outputs
|
|
338
|
+
|
|
339
|
+
def initialize(self):
    """Initialize via the base class, then attach the requested metrics to the trainer.

    Metrics are resolved from the ``metrics_cfg`` given at construction:
    a MetricManager is used as-is, ``"default"`` instantiates
    ``DEFAULT_RF3_METRICS_CFG``, ``None`` disables metrics, and any other value
    is treated as a Hydra config for ``MetricManager.instantiate_from_hydra``.
    Metrics are assigned on every call.

    Returns:
        Whatever ``BaseInferenceEngine.initialize`` returned. The ``is not None``
        check suggests it is None once the engine is already initialized (TODO confirm).
    """
    # Announce the checkpoint ourselves: the base-class logger may be quiet.
    if not self.initialized_:
        ranked_logger.info(
            f"Loading checkpoint from {Path(self.ckpt_path).resolve()}..."
        )

    cfg = super().initialize()
    if cfg is not None:
        self.cfg = cfg  # kept for later use

    # Assigning trainer.metrics directly sidesteps OmegaConf merging of empty dicts.
    requested = self._metrics_cfg
    if isinstance(requested, MetricManager):
        metrics = requested
    elif requested == "default":
        metrics = MetricManager.instantiate_from_hydra(
            metrics_cfg=DEFAULT_RF3_METRICS_CFG
        )
    elif requested is None:
        metrics = None
    else:
        metrics = MetricManager.instantiate_from_hydra(metrics_cfg=requested)
    self.trainer.metrics = metrics

    return cfg
|
|
371
|
+
|
|
372
|
+
def run(
|
|
373
|
+
self,
|
|
374
|
+
inputs: (
|
|
375
|
+
InferenceInput
|
|
376
|
+
| list[InferenceInput]
|
|
377
|
+
| AtomArray
|
|
378
|
+
| list[AtomArray]
|
|
379
|
+
| PathLike
|
|
380
|
+
| list[PathLike]
|
|
381
|
+
),
|
|
382
|
+
# Output control
|
|
383
|
+
out_dir: PathLike | None = None,
|
|
384
|
+
dump_predictions: bool = True,
|
|
385
|
+
dump_trajectories: bool = False,
|
|
386
|
+
one_model_per_file: bool = False,
|
|
387
|
+
annotate_b_factor_with_plddt: bool = False,
|
|
388
|
+
sharding_pattern: str | None = None,
|
|
389
|
+
skip_existing: bool = False,
|
|
390
|
+
# Selection overrides (applied to all input types)
|
|
391
|
+
template_selection: list[str] | str | None = None,
|
|
392
|
+
ground_truth_conformer_selection: list[str] | str | None = None,
|
|
393
|
+
cyclic_chains: list[str] = [],
|
|
394
|
+
) -> dict[str, dict] | None:
|
|
395
|
+
"""Run inference on inputs.
|
|
396
|
+
|
|
397
|
+
Requires a pre-initialized inference engine.
|
|
398
|
+
|
|
399
|
+
Args:
|
|
400
|
+
inputs: Single/list of InferenceInput objects, AtomArray objects, file paths, or directory.
|
|
401
|
+
out_dir: Output directory. If None, returns results as an AtomArray and dictionaries of metrics. Defaults to ``None``.
|
|
402
|
+
dump_predictions: Whether to save predicted structures. Defaults to ``True``.
|
|
403
|
+
dump_trajectories: Whether to save diffusion trajectories. Defaults to ``False``.
|
|
404
|
+
one_model_per_file: Save each model in separate file. Defaults to ``False``.
|
|
405
|
+
annotate_b_factor_with_plddt: Write pLDDT to B-factor column. Defaults to ``False``.
|
|
406
|
+
sharding_pattern: Sharding pattern for output organization. Defaults to ``None``.
|
|
407
|
+
skip_existing: Skip inputs with existing outputs. Requires ``out_dir`` to be set. If ``True`` when ``out_dir=None``, a warning is logged and skipping is disabled. Defaults to ``False``.
|
|
408
|
+
template_selection: Template selection override. Defaults to ``None``.
|
|
409
|
+
ground_truth_conformer_selection: Conformer selection override. Defaults to ``None``.
|
|
410
|
+
cyclic_chains: List of chain IDs to cyclize. Defaults to ``[]``.
|
|
411
|
+
|
|
412
|
+
Returns:
|
|
413
|
+
If ``out_dir`` is None: Dict mapping example_id to list of RF3Output objects.
|
|
414
|
+
If ``out_dir`` is set: None (results saved to disk).
|
|
415
|
+
"""
|
|
416
|
+
self.initialize()
|
|
417
|
+
|
|
418
|
+
# Setup output directory if provided
|
|
419
|
+
out_dir = Path(out_dir) if out_dir else None
|
|
420
|
+
if out_dir:
|
|
421
|
+
out_dir.mkdir(parents=True, exist_ok=True)
|
|
422
|
+
ranked_logger.info(f"Outputs will be written to {out_dir.resolve()}.")
|
|
423
|
+
if not out_dir:
|
|
424
|
+
ranked_logger.warning(
|
|
425
|
+
"out_dir is None - results will be returned in memory! If you want to save to disk, please provide an out_dir."
|
|
426
|
+
)
|
|
427
|
+
|
|
428
|
+
# Validate skip_existing configuration
|
|
429
|
+
if skip_existing and out_dir is None:
|
|
430
|
+
ranked_logger.warning(
|
|
431
|
+
"skip_existing=True requires out_dir to be set. "
|
|
432
|
+
"Disabling skip_existing for in-memory inference mode."
|
|
433
|
+
)
|
|
434
|
+
skip_existing = False
|
|
435
|
+
|
|
436
|
+
# Determine file type based on compression setting
|
|
437
|
+
file_type = "cif.gz" if self.compress_outputs else "cif"
|
|
438
|
+
|
|
439
|
+
# Convert inputs to InferenceInput objects
|
|
440
|
+
if isinstance(inputs, InferenceInput):
|
|
441
|
+
inference_inputs = [inputs]
|
|
442
|
+
elif isinstance(inputs, list) and all(
|
|
443
|
+
isinstance(i, InferenceInput) for i in inputs
|
|
444
|
+
):
|
|
445
|
+
inference_inputs = inputs
|
|
446
|
+
elif isinstance(inputs, AtomArray):
|
|
447
|
+
# Single AtomArray - convert to InferenceInput
|
|
448
|
+
inference_inputs = [
|
|
449
|
+
InferenceInput.from_atom_array(
|
|
450
|
+
inputs,
|
|
451
|
+
template_selection=template_selection,
|
|
452
|
+
ground_truth_conformer_selection=ground_truth_conformer_selection,
|
|
453
|
+
)
|
|
454
|
+
]
|
|
455
|
+
elif isinstance(inputs, list) and all(isinstance(i, AtomArray) for i in inputs):
|
|
456
|
+
# List of AtomArrays - convert each to InferenceInput
|
|
457
|
+
inference_inputs = [
|
|
458
|
+
InferenceInput.from_atom_array(
|
|
459
|
+
arr,
|
|
460
|
+
example_id=f"inference_{i}",
|
|
461
|
+
template_selection=template_selection,
|
|
462
|
+
ground_truth_conformer_selection=ground_truth_conformer_selection,
|
|
463
|
+
)
|
|
464
|
+
for i, arr in enumerate(inputs)
|
|
465
|
+
]
|
|
466
|
+
elif isinstance(inputs, (str, Path)) or (
|
|
467
|
+
isinstance(inputs, list) and isinstance(inputs[0], (str, Path))
|
|
468
|
+
):
|
|
469
|
+
inference_inputs = prepare_inference_inputs_from_paths(
|
|
470
|
+
inputs=inputs,
|
|
471
|
+
existing_outputs_dir=out_dir if skip_existing else None,
|
|
472
|
+
sharding_pattern=sharding_pattern,
|
|
473
|
+
template_selection=template_selection,
|
|
474
|
+
ground_truth_conformer_selection=ground_truth_conformer_selection,
|
|
475
|
+
)
|
|
476
|
+
else:
|
|
477
|
+
raise ValueError(f"Unsupported inputs type: {type(inputs)}")
|
|
478
|
+
|
|
479
|
+
# Flag chains for cyclization if specified
|
|
480
|
+
if cyclic_chains:
|
|
481
|
+
for input_spec in inference_inputs:
|
|
482
|
+
input_spec.cyclic_chains = cyclic_chains
|
|
483
|
+
|
|
484
|
+
# make InferenceInputDataset
|
|
485
|
+
inference_dataset = InferenceInputDataset(inference_inputs)
|
|
486
|
+
ranked_logger.info(f"Found {len(inference_dataset)} structures to predict!")
|
|
487
|
+
|
|
488
|
+
# make LoadBalancedDistributedSampler
|
|
489
|
+
sampler = LoadBalancedDistributedSampler(
|
|
490
|
+
dataset=inference_dataset,
|
|
491
|
+
key_to_balance=inference_dataset.key_to_balance,
|
|
492
|
+
num_replicas=self.trainer.fabric.world_size,
|
|
493
|
+
rank=self.trainer.fabric.global_rank,
|
|
494
|
+
drop_last=False,
|
|
495
|
+
)
|
|
496
|
+
|
|
497
|
+
loader = DataLoader(
|
|
498
|
+
dataset=inference_dataset,
|
|
499
|
+
sampler=sampler,
|
|
500
|
+
batch_size=1,
|
|
501
|
+
num_workers=0, # multiprocessing is disabled since it shouldn't be hard to read InferenceInput objects
|
|
502
|
+
collate_fn=lambda x: x, # no collation since we're not batching
|
|
503
|
+
pin_memory=True,
|
|
504
|
+
drop_last=False,
|
|
505
|
+
)
|
|
506
|
+
|
|
507
|
+
# Prepare results dict (if returning in-memory)
|
|
508
|
+
results = {} if out_dir is None else None
|
|
509
|
+
|
|
510
|
+
# Main inference loop
|
|
511
|
+
for batch_idx, input_spec in enumerate(loader):
|
|
512
|
+
input_spec = input_spec[
|
|
513
|
+
0
|
|
514
|
+
] # since we're not batching, the loader returns a list of length 1
|
|
515
|
+
ranked_logger.info(
|
|
516
|
+
f"Predicting structure {batch_idx + 1}/{len(loader)}: {input_spec.example_id}"
|
|
517
|
+
)
|
|
518
|
+
|
|
519
|
+
# Create output directory for this example if saving to disk
|
|
520
|
+
if out_dir:
|
|
521
|
+
example_out_dir = get_sharded_output_path(
|
|
522
|
+
input_spec.example_id, out_dir, sharding_pattern
|
|
523
|
+
)
|
|
524
|
+
example_out_dir.mkdir(parents=True, exist_ok=True)
|
|
525
|
+
|
|
526
|
+
# Run through Transform pipeline
|
|
527
|
+
pipeline_output = self.pipeline(input_spec.to_pipeline_input())
|
|
528
|
+
|
|
529
|
+
# Setup early stopping function if configured
|
|
530
|
+
should_early_stop_fn = None
|
|
531
|
+
if (
|
|
532
|
+
"confidence_feats" in pipeline_output
|
|
533
|
+
and self.early_stopping_plddt_threshold
|
|
534
|
+
and self.early_stopping_plddt_threshold > 0
|
|
535
|
+
):
|
|
536
|
+
should_early_stop_fn = should_early_stop_by_mean_plddt(
|
|
537
|
+
self.early_stopping_plddt_threshold,
|
|
538
|
+
pipeline_output["confidence_feats"]["is_real_atom"],
|
|
539
|
+
self.cfg.trainer.loss.confidence_loss.plddt.max_value,
|
|
540
|
+
)
|
|
541
|
+
|
|
542
|
+
# Model inference
|
|
543
|
+
with torch.no_grad():
|
|
544
|
+
pipeline_output = self.trainer.fabric.to_device(pipeline_output)
|
|
545
|
+
if should_early_stop_fn:
|
|
546
|
+
valid_step_outs = self.trainer.validation_step(
|
|
547
|
+
batch=pipeline_output,
|
|
548
|
+
batch_idx=0,
|
|
549
|
+
compute_metrics=True,
|
|
550
|
+
should_early_stop_fn=should_early_stop_fn,
|
|
551
|
+
)
|
|
552
|
+
else:
|
|
553
|
+
valid_step_outs = self.trainer.validation_step(
|
|
554
|
+
batch=pipeline_output,
|
|
555
|
+
batch_idx=0,
|
|
556
|
+
compute_metrics=True,
|
|
557
|
+
)
|
|
558
|
+
network_output = valid_step_outs["network_output"]
|
|
559
|
+
metrics_output = valid_step_outs["metrics_output"]
|
|
560
|
+
|
|
561
|
+
# Handle early stopping
|
|
562
|
+
if network_output.get("early_stopped", False):
|
|
563
|
+
ranked_logger.warning(
|
|
564
|
+
f"Early stopping triggered for {input_spec.example_id} "
|
|
565
|
+
f"with mean pLDDT {network_output['mean_plddt']:.2f} < "
|
|
566
|
+
f"{self.early_stopping_plddt_threshold:.2f}!"
|
|
567
|
+
)
|
|
568
|
+
|
|
569
|
+
if out_dir:
|
|
570
|
+
# Save early stop info to disk
|
|
571
|
+
dict_to_save = {
|
|
572
|
+
k: v for k, v in network_output.items() if v is not None
|
|
573
|
+
}
|
|
574
|
+
df_to_save = pd.DataFrame([dict_to_save])
|
|
575
|
+
df_to_save.to_csv(example_out_dir / "score.csv", index=False)
|
|
576
|
+
|
|
577
|
+
df_to_save = pd.DataFrame([metrics_output])
|
|
578
|
+
df_to_save.to_csv(
|
|
579
|
+
example_out_dir / f"{input_spec.example_id}_metrics.csv",
|
|
580
|
+
index=False,
|
|
581
|
+
)
|
|
582
|
+
else:
|
|
583
|
+
# Store in results dict
|
|
584
|
+
results[input_spec.example_id] = {
|
|
585
|
+
"early_stopped": True,
|
|
586
|
+
"mean_plddt": network_output["mean_plddt"],
|
|
587
|
+
"metrics": metrics_output,
|
|
588
|
+
}
|
|
589
|
+
|
|
590
|
+
continue
|
|
591
|
+
|
|
592
|
+
# Build predicted structures
|
|
593
|
+
atom_array_stack = build_stack_from_atom_array_and_batched_coords(
|
|
594
|
+
network_output["X_L"], pipeline_output["atom_array"]
|
|
595
|
+
)
|
|
596
|
+
num_samples = (
|
|
597
|
+
len(atom_array_stack)
|
|
598
|
+
if isinstance(atom_array_stack, AtomArrayStack)
|
|
599
|
+
else 1
|
|
600
|
+
)
|
|
601
|
+
|
|
602
|
+
# Build RF3Output objects for each sample
|
|
603
|
+
rf3_outputs: list[RF3Output] = []
|
|
604
|
+
for sample_idx in range(num_samples):
|
|
605
|
+
# Get atom array for this sample
|
|
606
|
+
if isinstance(atom_array_stack, AtomArrayStack):
|
|
607
|
+
sample_atom_array = atom_array_stack[sample_idx]
|
|
608
|
+
else:
|
|
609
|
+
sample_atom_array = atom_array_stack
|
|
610
|
+
|
|
611
|
+
# Compile confidence outputs in AF3 format (if available)
|
|
612
|
+
summary_confidences = {}
|
|
613
|
+
confidences = None
|
|
614
|
+
if "plddt" in network_output:
|
|
615
|
+
conf_outs = compile_af3_style_confidence_outputs(
|
|
616
|
+
plddt_logits=network_output["plddt"],
|
|
617
|
+
pae_logits=network_output["pae"],
|
|
618
|
+
pde_logits=network_output["pde"],
|
|
619
|
+
chain_iid_token_lvl=pipeline_output["ground_truth"][
|
|
620
|
+
"chain_iid_token_lvl"
|
|
621
|
+
],
|
|
622
|
+
is_real_atom=pipeline_output["confidence_feats"][
|
|
623
|
+
"is_real_atom"
|
|
624
|
+
],
|
|
625
|
+
atom_array=pipeline_output["atom_array"],
|
|
626
|
+
confidence_loss_cfg=self.cfg.trainer.loss.confidence_loss,
|
|
627
|
+
batch_idx=sample_idx,
|
|
628
|
+
)
|
|
629
|
+
summary_confidences = conf_outs["summary_confidences"]
|
|
630
|
+
confidences = conf_outs["confidences"]
|
|
631
|
+
|
|
632
|
+
# Annotate b-factor with pLDDT if requested
|
|
633
|
+
if annotate_b_factor_with_plddt:
|
|
634
|
+
atom_array_list = annotate_atom_array_b_factor_with_plddt(
|
|
635
|
+
atom_array_stack,
|
|
636
|
+
conf_outs["plddt"],
|
|
637
|
+
pipeline_output["confidence_feats"]["is_real_atom"],
|
|
638
|
+
)
|
|
639
|
+
sample_atom_array = atom_array_list[sample_idx]
|
|
640
|
+
|
|
641
|
+
# Add metrics (ptm, iptm, has_clash) to summary_confidences
|
|
642
|
+
if metrics_output:
|
|
643
|
+
ptm_key = f"ptm.ptm_{sample_idx}"
|
|
644
|
+
iptm_key = f"iptm.iptm_{sample_idx}"
|
|
645
|
+
clash_key = f"count_clashing_chains.has_clash_{sample_idx}"
|
|
646
|
+
|
|
647
|
+
ptm_val = metrics_output.get(ptm_key)
|
|
648
|
+
iptm_val = metrics_output.get(iptm_key)
|
|
649
|
+
has_clash = bool(metrics_output.get(clash_key, 0))
|
|
650
|
+
|
|
651
|
+
# Convert to native Python floats for JSON serialization
|
|
652
|
+
ptm = float(ptm_val) if ptm_val is not None else None
|
|
653
|
+
iptm = float(iptm_val) if iptm_val is not None else None
|
|
654
|
+
|
|
655
|
+
summary_confidences["ptm"] = ptm
|
|
656
|
+
summary_confidences["iptm"] = iptm
|
|
657
|
+
summary_confidences["has_clash"] = has_clash
|
|
658
|
+
|
|
659
|
+
ranking_score = compute_ranking_score(
|
|
660
|
+
iptm=iptm,
|
|
661
|
+
ptm=ptm,
|
|
662
|
+
has_clash=has_clash,
|
|
663
|
+
)
|
|
664
|
+
summary_confidences["ranking_score"] = round(ranking_score, 4)
|
|
665
|
+
|
|
666
|
+
rf3_outputs.append(
|
|
667
|
+
RF3Output(
|
|
668
|
+
example_id=input_spec.example_id,
|
|
669
|
+
atom_array=sample_atom_array,
|
|
670
|
+
summary_confidences=summary_confidences,
|
|
671
|
+
confidences=confidences,
|
|
672
|
+
sample_idx=sample_idx,
|
|
673
|
+
seed=self.seed if self.seed is not None else 0,
|
|
674
|
+
)
|
|
675
|
+
)
|
|
676
|
+
|
|
677
|
+
# Save or return results
|
|
678
|
+
if out_dir:
|
|
679
|
+
# Save to disk in AlphaFold3-style directory structure
|
|
680
|
+
# Top-level: ranking_scores.csv, best model, best summary
|
|
681
|
+
dump_ranking_scores(rf3_outputs, example_out_dir, input_spec.example_id)
|
|
682
|
+
dump_top_ranked_outputs(
|
|
683
|
+
rf3_outputs,
|
|
684
|
+
example_out_dir,
|
|
685
|
+
input_spec.example_id,
|
|
686
|
+
file_type=file_type,
|
|
687
|
+
)
|
|
688
|
+
|
|
689
|
+
# Per-sample subdirectories
|
|
690
|
+
if dump_predictions:
|
|
691
|
+
for rf3_out in rf3_outputs:
|
|
692
|
+
sample_subdir = (
|
|
693
|
+
example_out_dir
|
|
694
|
+
/ f"seed-{rf3_out.seed}_sample-{rf3_out.sample_idx}"
|
|
695
|
+
)
|
|
696
|
+
rf3_out.dump(
|
|
697
|
+
out_dir=sample_subdir,
|
|
698
|
+
file_type=file_type,
|
|
699
|
+
dump_full_confidences=True,
|
|
700
|
+
)
|
|
701
|
+
|
|
702
|
+
if dump_trajectories:
|
|
703
|
+
dump_structures(
|
|
704
|
+
atom_arrays=network_output["X_denoised_L_traj"],
|
|
705
|
+
base_path=example_out_dir / "denoised",
|
|
706
|
+
one_model_per_file=True,
|
|
707
|
+
file_type=file_type,
|
|
708
|
+
)
|
|
709
|
+
dump_structures(
|
|
710
|
+
atom_arrays=network_output["X_noisy_L_traj"],
|
|
711
|
+
base_path=example_out_dir / "noisy",
|
|
712
|
+
one_model_per_file=True,
|
|
713
|
+
file_type=file_type,
|
|
714
|
+
)
|
|
715
|
+
|
|
716
|
+
ranked_logger.info(
|
|
717
|
+
f"Outputs for {input_spec.example_id} written to {example_out_dir}!"
|
|
718
|
+
)
|
|
719
|
+
else:
|
|
720
|
+
# Store in memory - return list of RF3Output objects
|
|
721
|
+
results[input_spec.example_id] = rf3_outputs
|
|
722
|
+
|
|
723
|
+
# merge results across ranks
|
|
724
|
+
self.trainer.fabric.barrier()
|
|
725
|
+
if results is not None and dist.is_initialized():
|
|
726
|
+
gathered_results = [None] * self.trainer.fabric.world_size
|
|
727
|
+
dist.all_gather_object(
|
|
728
|
+
gathered_results, results
|
|
729
|
+
) # returns a list of dicts, need to combine them
|
|
730
|
+
gathered_results = {
|
|
731
|
+
k: v for result in gathered_results for k, v in result.items()
|
|
732
|
+
} # combine the dicts into a single dict
|
|
733
|
+
results = gathered_results
|
|
734
|
+
|
|
735
|
+
return results
|