rc_foundry-0.1.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- foundry/__init__.py +57 -0
- foundry/callbacks/__init__.py +5 -0
- foundry/callbacks/callback.py +116 -0
- foundry/callbacks/health_logging.py +419 -0
- foundry/callbacks/metrics_logging.py +211 -0
- foundry/callbacks/timing_logging.py +67 -0
- foundry/callbacks/train_logging.py +278 -0
- foundry/common.py +108 -0
- foundry/constants.py +28 -0
- foundry/hydra/resolvers.py +77 -0
- foundry/inference_engines/base.py +235 -0
- foundry/inference_engines/checkpoint_registry.py +66 -0
- foundry/metrics/__init__.py +12 -0
- foundry/metrics/losses.py +30 -0
- foundry/metrics/metric.py +319 -0
- foundry/model/layers/blocks.py +47 -0
- foundry/testing/__init__.py +6 -0
- foundry/testing/fixtures.py +19 -0
- foundry/testing/pytest_hooks.py +15 -0
- foundry/trainers/fabric.py +923 -0
- foundry/training/EMA.py +67 -0
- foundry/training/checkpoint.py +61 -0
- foundry/training/schedulers.py +91 -0
- foundry/utils/alignment.py +86 -0
- foundry/utils/components.py +415 -0
- foundry/utils/datasets.py +405 -0
- foundry/utils/ddp.py +103 -0
- foundry/utils/instantiators.py +72 -0
- foundry/utils/logging.py +279 -0
- foundry/utils/rigid.py +1460 -0
- foundry/utils/rotation_augmentation.py +65 -0
- foundry/utils/squashfs.py +172 -0
- foundry/utils/torch.py +317 -0
- foundry/utils/weights.py +271 -0
- foundry/version.py +34 -0
- foundry_cli/__init__.py +3 -0
- foundry_cli/download_checkpoints.py +281 -0
- mpnn/__init__.py +1 -0
- mpnn/collate/feature_collator.py +265 -0
- mpnn/inference.py +53 -0
- mpnn/inference_engines/mpnn.py +549 -0
- mpnn/loss/nll_loss.py +122 -0
- mpnn/metrics/nll.py +369 -0
- mpnn/metrics/sequence_recovery.py +440 -0
- mpnn/model/layers/graph_embeddings.py +2372 -0
- mpnn/model/layers/message_passing.py +332 -0
- mpnn/model/layers/position_wise_feed_forward.py +44 -0
- mpnn/model/layers/positional_encoding.py +98 -0
- mpnn/model/mpnn.py +2632 -0
- mpnn/pipelines/mpnn.py +162 -0
- mpnn/samplers/samplers.py +167 -0
- mpnn/train.py +341 -0
- mpnn/trainers/mpnn.py +193 -0
- mpnn/transforms/feature_aggregation/mpnn.py +184 -0
- mpnn/transforms/feature_aggregation/polymer_ligand_interface.py +76 -0
- mpnn/transforms/feature_aggregation/token_encodings.py +132 -0
- mpnn/transforms/feature_aggregation/user_settings.py +347 -0
- mpnn/transforms/polymer_ligand_interface.py +164 -0
- mpnn/utils/inference.py +2397 -0
- mpnn/utils/probability.py +37 -0
- mpnn/utils/weights.py +309 -0
- rc_foundry-0.1.1.dist-info/METADATA +239 -0
- rc_foundry-0.1.1.dist-info/RECORD +180 -0
- rc_foundry-0.1.1.dist-info/WHEEL +4 -0
- rc_foundry-0.1.1.dist-info/entry_points.txt +5 -0
- rc_foundry-0.1.1.dist-info/licenses/LICENSE.md +28 -0
- rf3/__init__.py +3 -0
- rf3/_version.py +33 -0
- rf3/alignment.py +79 -0
- rf3/callbacks/dump_validation_structures.py +101 -0
- rf3/callbacks/metrics_logging.py +324 -0
- rf3/chemical.py +1529 -0
- rf3/cli.py +77 -0
- rf3/data/cyclic_transform.py +78 -0
- rf3/data/extra_xforms.py +36 -0
- rf3/data/ground_truth_template.py +463 -0
- rf3/data/paired_msa.py +206 -0
- rf3/data/pipeline_utils.py +128 -0
- rf3/data/pipelines.py +558 -0
- rf3/diffusion_samplers/inference_sampler.py +222 -0
- rf3/inference.py +65 -0
- rf3/inference_engines/__init__.py +5 -0
- rf3/inference_engines/rf3.py +735 -0
- rf3/kinematics.py +354 -0
- rf3/loss/af3_confidence_loss.py +515 -0
- rf3/loss/af3_losses.py +655 -0
- rf3/loss/loss.py +179 -0
- rf3/metrics/chiral.py +179 -0
- rf3/metrics/clashing_chains.py +68 -0
- rf3/metrics/distogram.py +421 -0
- rf3/metrics/lddt.py +523 -0
- rf3/metrics/metadata.py +43 -0
- rf3/metrics/metric_utils.py +192 -0
- rf3/metrics/predicted_error.py +134 -0
- rf3/metrics/rasa.py +108 -0
- rf3/metrics/selected_distances.py +91 -0
- rf3/model/RF3.py +527 -0
- rf3/model/RF3_blocks.py +92 -0
- rf3/model/RF3_structure.py +303 -0
- rf3/model/layers/af3_auxiliary_heads.py +255 -0
- rf3/model/layers/af3_diffusion_transformer.py +544 -0
- rf3/model/layers/attention.py +313 -0
- rf3/model/layers/layer_utils.py +127 -0
- rf3/model/layers/mlff.py +118 -0
- rf3/model/layers/outer_product.py +59 -0
- rf3/model/layers/pairformer_layers.py +783 -0
- rf3/model/layers/structure_bias.py +56 -0
- rf3/scoring.py +1787 -0
- rf3/symmetry/resolve.py +284 -0
- rf3/train.py +194 -0
- rf3/trainers/rf3.py +570 -0
- rf3/util_module.py +47 -0
- rf3/utils/frames.py +109 -0
- rf3/utils/inference.py +665 -0
- rf3/utils/io.py +198 -0
- rf3/utils/loss.py +72 -0
- rf3/utils/predict_and_score.py +165 -0
- rf3/utils/predicted_error.py +673 -0
- rf3/utils/recycling.py +42 -0
- rf3/validate.py +140 -0
- rfd3/.gitignore +7 -0
- rfd3/Makefile +76 -0
- rfd3/__init__.py +12 -0
- rfd3/callbacks.py +66 -0
- rfd3/cli.py +41 -0
- rfd3/constants.py +212 -0
- rfd3/engine.py +543 -0
- rfd3/inference/datasets.py +193 -0
- rfd3/inference/input_parsing.py +1123 -0
- rfd3/inference/legacy_input_parsing.py +717 -0
- rfd3/inference/parsing.py +165 -0
- rfd3/inference/symmetry/atom_array.py +298 -0
- rfd3/inference/symmetry/checks.py +241 -0
- rfd3/inference/symmetry/contigs.py +63 -0
- rfd3/inference/symmetry/frames.py +355 -0
- rfd3/inference/symmetry/symmetry_utils.py +398 -0
- rfd3/metrics/design_metrics.py +465 -0
- rfd3/metrics/hbonds_hbplus_metrics.py +308 -0
- rfd3/metrics/hbonds_metrics.py +389 -0
- rfd3/metrics/losses.py +325 -0
- rfd3/metrics/metrics_utils.py +118 -0
- rfd3/metrics/sidechain_metrics.py +349 -0
- rfd3/model/RFD3.py +105 -0
- rfd3/model/RFD3_diffusion_module.py +387 -0
- rfd3/model/cfg_utils.py +81 -0
- rfd3/model/inference_sampler.py +635 -0
- rfd3/model/layers/attention.py +577 -0
- rfd3/model/layers/block_utils.py +580 -0
- rfd3/model/layers/blocks.py +777 -0
- rfd3/model/layers/chunked_pairwise.py +377 -0
- rfd3/model/layers/encoders.py +417 -0
- rfd3/model/layers/layer_utils.py +197 -0
- rfd3/model/layers/pairformer_layers.py +128 -0
- rfd3/run_inference.py +45 -0
- rfd3/testing/debug.py +139 -0
- rfd3/testing/debug_utils.py +73 -0
- rfd3/testing/testing_utils.py +356 -0
- rfd3/train.py +194 -0
- rfd3/trainer/dump_validation_structures.py +154 -0
- rfd3/trainer/fabric_trainer.py +923 -0
- rfd3/trainer/recycling.py +42 -0
- rfd3/trainer/rfd3.py +485 -0
- rfd3/trainer/trainer_utils.py +502 -0
- rfd3/transforms/conditioning_base.py +508 -0
- rfd3/transforms/conditioning_utils.py +200 -0
- rfd3/transforms/design_transforms.py +807 -0
- rfd3/transforms/dna_crop.py +523 -0
- rfd3/transforms/hbonds.py +407 -0
- rfd3/transforms/hbonds_hbplus.py +246 -0
- rfd3/transforms/ncaa_transforms.py +153 -0
- rfd3/transforms/pipelines.py +632 -0
- rfd3/transforms/ppi_transforms.py +541 -0
- rfd3/transforms/rasa.py +116 -0
- rfd3/transforms/symmetry.py +76 -0
- rfd3/transforms/training_conditions.py +552 -0
- rfd3/transforms/util_transforms.py +498 -0
- rfd3/transforms/virtual_atoms.py +305 -0
- rfd3/utils/inference.py +648 -0
- rfd3/utils/io.py +245 -0
- rfd3/utils/vizualize.py +276 -0
rf3/utils/inference.py
ADDED
@@ -0,0 +1,665 @@
import json
import logging
import os
from concurrent.futures import ProcessPoolExecutor
from dataclasses import dataclass
from os import PathLike
from pathlib import Path
from typing import Iterable

import numpy as np
import pandas as pd
from atomworks.common import as_list
from atomworks.enums import GroundTruthConformerPolicy
from atomworks.io import parse
from atomworks.io.parser import parse_atom_array
from atomworks.io.tools.inference import (
    build_msa_paths_by_chain_id_from_component_list,
    components_to_atom_array,
)
from atomworks.io.transforms.categories import category_to_dict
from atomworks.io.utils.selection import AtomSelectionStack
from atomworks.ml.transforms.atom_array import add_global_token_id_annotation
from biotite.structure import AtomArray
from rf3.utils.io import (
    CIF_LIKE_EXTENSIONS,
    DICTIONARY_LIKE_EXTENSIONS,
    get_sharded_output_path,
)
from torch.utils.data import Dataset

logger = logging.getLogger(__name__)


def _resolve_override(override_value, source_value, param_name: str, example_id: str):
    """Resolve CLI override vs source value with warning."""
    if override_value is not None and source_value:
        logger.warning(f"CLI {param_name} overriding source value for {example_id}")
        return override_value
    return override_value if override_value is not None else source_value


def extract_example_id_from_path(path: Path) -> str:
    """Extract example ID from file path."""
    path_str = str(path.name)
    # Check for known extensions (longer matches first to handle .cif.gz before .gz)
    for ext in sorted(CIF_LIKE_EXTENSIONS | {".json"}, key=len, reverse=True):
        if path_str.endswith(ext):
            return path_str[: -len(ext)]
    # Fallback to simple stem
    return path.stem
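The longest-match loop above exists so that double extensions strip cleanly. A quick self-contained illustration (stubbing `CIF_LIKE_EXTENSIONS` locally rather than importing it, so the exact membership here is an assumption):

```python
from pathlib import Path

# Hypothetical stand-in for rf3.utils.io.CIF_LIKE_EXTENSIONS
CIF_LIKE_EXTENSIONS = {".cif", ".cif.gz", ".pdb"}

def extract_example_id(path: Path) -> str:
    # Same longest-extension-first logic as extract_example_id_from_path
    for ext in sorted(CIF_LIKE_EXTENSIONS | {".json"}, key=len, reverse=True):
        if path.name.endswith(ext):
            return path.name[: -len(ext)]
    return path.stem

assert extract_example_id(Path("7xyz.cif.gz")) == "7xyz"       # not "7xyz.cif"
assert extract_example_id(Path("design_01.json")) == "design_01"
```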
def extract_example_ids_from_json(path: Path) -> list[str]:
    """Extract example IDs from a JSON file containing one or more examples."""
    with open(path, "r") as f:
        data = json.load(f)
    return [ex["name"] for ex in data]


@dataclass
class InferenceInput:
    """Input specification for RF3 inference."""

    atom_array: AtomArray
    chain_info: dict
    example_id: str
    template_selection: list[str] | None = None
    ground_truth_conformer_selection: list[str] | None = None
    cyclic_chains: list[str] | None = None

    @classmethod
    def from_cif_path(
        cls,
        path: PathLike,
        example_id: str | None = None,
        template_selection: list[str] | str | None = None,
        ground_truth_conformer_selection: list[str] | str | None = None,
    ) -> "InferenceInput":
        """Load from CIF/PDB file.

        Args:
            path: Path to CIF/PDB file.
            example_id: Example ID. Defaults to filename stem.
            template_selection: Template selection override.
            ground_truth_conformer_selection: Conformer selection override.

        Returns:
            InferenceInput object.
        """
        parsed = parse(path, hydrogen_policy="remove", keep_cif_block=True)

        atom_array = (
            parsed["assemblies"]["1"][0]
            if "assemblies" in parsed
            else parsed["asym_unit"][0]
        )

        example_id = example_id or extract_example_id_from_path(Path(path))

        # Extract from CIF
        cif_template_sel = None
        cif_conformer_sel = None
        if "cif_block" in parsed:
            template_dict = category_to_dict(parsed["cif_block"], "template_selection")
            if template_dict:
                cif_template_sel = list(template_dict.get("template_selection", []))

            conformer_dict = category_to_dict(
                parsed["cif_block"], "ground_truth_conformer_selection"
            )
            if conformer_dict:
                cif_conformer_sel = list(
                    conformer_dict.get("ground_truth_conformer_selection", [])
                )

        # Resolve overrides (CLI priority)
        final_template_sel = _resolve_override(
            template_selection, cif_template_sel, "template_selection", example_id
        )
        final_conformer_sel = _resolve_override(
            ground_truth_conformer_selection,
            cif_conformer_sel,
            "ground_truth_conformer_selection",
            example_id,
        )

        return cls(
            atom_array=atom_array,
            chain_info=parsed["chain_info"],
            example_id=example_id,
            template_selection=final_template_sel,
            ground_truth_conformer_selection=final_conformer_sel,
        )

    @classmethod
    def from_json_dict(
        cls,
        data: dict,
        template_selection: list[str] | str | None = None,
        ground_truth_conformer_selection: list[str] | str | None = None,
    ) -> "InferenceInput":
        """Create from JSON dict with components.

        CLI args override JSON metadata.

        Args:
            data: JSON dictionary with components.
            template_selection: Template selection override.
            ground_truth_conformer_selection: Conformer selection override.

        Returns:
            InferenceInput object.
        """
        # Build atom_array from components
        atom_array, component_list = components_to_atom_array(
            data["components"],
            bonds=data.get("bonds"),
            return_components=True,
        )

        parsed = parse_atom_array(
            atom_array,
            build_assembly="_spoof",
            hydrogen_policy="keep",
        )

        chain_info = parsed.get("chain_info", {})
        atom_array = (
            parsed["assemblies"]["1"][0]
            if "assemblies" in parsed
            else parsed["asym_unit"][0]
        )

        # Merge MSA paths into chain_info
        msa_paths_by_chain_id = build_msa_paths_by_chain_id_from_component_list(
            component_list
        )
        if data.get("msa_paths") and isinstance(data.get("msa_paths"), dict):
            msa_paths_by_chain_id.update(data.get("msa_paths"))

        for chain_id, msa_path in msa_paths_by_chain_id.items():
            if chain_id in chain_info:
                chain_info[chain_id]["msa_path"] = msa_path

        # Resolve overrides (CLI priority)
        final_template_sel = _resolve_override(
            template_selection,
            data.get("template_selection"),
            "template_selection",
            data["name"],
        )
        final_conformer_sel = _resolve_override(
            ground_truth_conformer_selection,
            data.get("ground_truth_conformer_selection"),
            "ground_truth_conformer_selection",
            data["name"],
        )

        return cls(
            atom_array=atom_array,
            chain_info=chain_info,
            example_id=data["name"],
            template_selection=final_template_sel,
            ground_truth_conformer_selection=final_conformer_sel,
        )

    @classmethod
    def from_atom_array(
        cls,
        atom_array: AtomArray,
        chain_info: dict | None = None,
        example_id: str | None = None,
        template_selection: list[str] | str | None = None,
        ground_truth_conformer_selection: list[str] | str | None = None,
    ) -> "InferenceInput":
        """Create from AtomArray.

        Args:
            atom_array: Input AtomArray.
            chain_info: Chain info dict. Defaults to extracted from atom_array.
            example_id: Example ID. Defaults to generated ID.
            template_selection: Template selection.
            ground_truth_conformer_selection: Conformer selection.

        Returns:
            InferenceInput object.
        """
        # Use parse_atom_array
        parsed = parse_atom_array(
            atom_array,
            build_assembly="_spoof",
            hydrogen_policy="keep",
            extra_fields="all",
        )

        extracted_chain_info = parsed.get("chain_info", {})

        # Merge with provided chain_info (provided takes priority)
        if chain_info is not None:
            for chain_id, chain_data in chain_info.items():
                if chain_id in extracted_chain_info:
                    extracted_chain_info[chain_id].update(chain_data)
                else:
                    extracted_chain_info[chain_id] = chain_data

        final_atom_array = (
            parsed["assemblies"]["1"][0]
            if "assemblies" in parsed
            else parsed["asym_unit"][0]
        )

        return cls(
            atom_array=final_atom_array,
            chain_info=extracted_chain_info,
            example_id=example_id or f"inference_{id(atom_array)}",
            template_selection=template_selection,
            ground_truth_conformer_selection=ground_truth_conformer_selection,
        )

    def to_pipeline_input(self) -> dict:
        """Apply transformations and return input for Transform pipeline.

        Returns:
            Pipeline input dict with example_id, atom_array, and chain_info.
        """
        atom_array = self.atom_array.copy()

        # Apply template and conformer selections
        atom_array = apply_conformer_and_template_selections(
            atom_array,
            template_selection=self.template_selection,
            ground_truth_conformer_selection=self.ground_truth_conformer_selection,
        )

        if self.cyclic_chains:
            atom_array = cyclize_atom_array(atom_array, self.cyclic_chains)

        return {
            "example_id": self.example_id,
            "atom_array": atom_array,
            "chain_info": self.chain_info,
        }
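Taken together, the three constructors normalize a CIF file, a JSON components dict, or an in-memory AtomArray into the same object, and `to_pipeline_input()` is the single hand-off point to the transform pipeline. A minimal usage sketch (the path is hypothetical; the selection syntax mirrors the `AtomSelectionStack` examples used later in this file):

```python
from pathlib import Path
from rf3.utils.inference import InferenceInput

inp = InferenceInput.from_cif_path(
    Path("examples/7xyz.cif.gz"),   # hypothetical input structure
    template_selection=["A1-10"],   # CLI-style override; wins over in-file metadata
)
pipeline_input = inp.to_pipeline_input()
print(pipeline_input["example_id"])      # "7xyz"
print(len(pipeline_input["atom_array"]))
```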
def _process_single_path(
    path: Path,
    existing_outputs_dir: Path | None,
    sharding_pattern: str | None,
    template_selection: list[str] | str | None,
    ground_truth_conformer_selection: list[str] | str | None,
) -> list[InferenceInput]:
    """Worker function to process a single input file path.

    This function is defined at module level to be picklable for multiprocessing.

    Args:
        path: Path to a single input file.
        existing_outputs_dir: If set, skip examples with existing outputs.
        sharding_pattern: Sharding pattern for output paths.
        template_selection: Override for template selection.
        ground_truth_conformer_selection: Override for conformer selection.

    Returns:
        List of InferenceInput objects (may be empty if file is skipped).
    """

    def example_exists(example_id: str) -> bool:
        """Check if example already has predictions (sharding-aware)."""
        if not existing_outputs_dir:
            return False
        example_dir = get_sharded_output_path(
            example_id, existing_outputs_dir, sharding_pattern
        )
        return (example_dir / f"{example_id}_metrics.csv").exists()

    inference_inputs = []

    if path.suffix == ".json":
        # Load JSON and convert each entry
        with open(path, "r") as f:
            data = json.load(f)

        # Normalize to list
        if isinstance(data, dict):
            data = [data]

        for item in data:
            example_id = item["name"]
            if not example_exists(example_id):
                inference_inputs.append(
                    InferenceInput.from_json_dict(
                        item,
                        template_selection=template_selection,
                        ground_truth_conformer_selection=ground_truth_conformer_selection,
                    )
                )

    elif any(path.name.endswith(ext) for ext in CIF_LIKE_EXTENSIONS):
        # CIF/PDB file
        example_id = extract_example_id_from_path(path)
        if not example_exists(example_id):
            inference_inputs.append(
                InferenceInput.from_cif_path(
                    path,
                    example_id=example_id,
                    template_selection=template_selection,
                    ground_truth_conformer_selection=ground_truth_conformer_selection,
                )
            )
    else:
        raise ValueError(
            f"Unsupported file type: {path.suffix} (path: {path}). "
            f"Supported: {CIF_LIKE_EXTENSIONS | DICTIONARY_LIKE_EXTENSIONS}"
        )

    return inference_inputs


def prepare_inference_inputs_from_paths(
    inputs: PathLike | list[PathLike],
    existing_outputs_dir: PathLike | None = None,
    sharding_pattern: str | None = None,
    template_selection: list[str] | str | None = None,
    ground_truth_conformer_selection: list[str] | str | None = None,
) -> list[InferenceInput]:
    """Load InferenceInput objects from file paths.

    Handles CIF, PDB, and JSON files. Filters out existing outputs if requested.
    Uses multiprocessing to parallelize file loading across all available CPUs.

    Args:
        inputs: File path(s) or directory path(s).
        existing_outputs_dir: If set, skip examples with existing outputs.
        sharding_pattern: Sharding pattern for output paths.
        template_selection: Override for template selection (applied to all inputs).
        ground_truth_conformer_selection: Override for conformer selection (applied to all inputs).

    Returns:
        List of InferenceInput objects.
    """
    input_paths = as_list(inputs)

    # Collect all raw input files (reusing logic from build_file_paths_for_prediction)
    paths_to_raw_input_files = []
    for _path in input_paths:
        if Path(_path).is_dir():
            # Scan directory for supported file types (JSON + CIF-like)
            for file_type in CIF_LIKE_EXTENSIONS | DICTIONARY_LIKE_EXTENSIONS:
                paths_to_raw_input_files.extend(Path(_path).glob(f"*{file_type}"))
        else:
            paths_to_raw_input_files.append(Path(_path))

    # Determine number of CPUs to use (keep at least one worker so the
    # executor is valid even when no input files were found)
    num_cpus = max(1, min(os.cpu_count() or 1, len(paths_to_raw_input_files)))
    logger.info(
        f"Processing {len(paths_to_raw_input_files)} files using {num_cpus} CPUs"
    )

    # Convert existing_outputs_dir to Path if needed
    existing_outputs_dir_path = (
        Path(existing_outputs_dir) if existing_outputs_dir else None
    )

    # Process files in parallel using all available CPUs
    inference_inputs = []
    with ProcessPoolExecutor(max_workers=num_cpus) as executor:
        # Submit all tasks
        futures = [
            executor.submit(
                _process_single_path,
                path,
                existing_outputs_dir_path,
                sharding_pattern,
                template_selection,
                ground_truth_conformer_selection,
            )
            for path in paths_to_raw_input_files
        ]

        # Collect results in submission order
        for future in futures:
            result = future.result()
            inference_inputs.extend(result)

    logger.info(f"Loaded {len(inference_inputs)} inference inputs")
    return inference_inputs
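A sketch of the batch entry point over a directory (directory names are hypothetical). Pointing `existing_outputs_dir` at a previous run makes reruns cheap, since any example whose `<id>_metrics.csv` already exists is skipped inside the workers:

```python
from rf3.utils.inference import prepare_inference_inputs_from_paths

inputs = prepare_inference_inputs_from_paths(
    "designs/",                       # scanned for *.json and CIF-like files
    existing_outputs_dir="outputs/",  # skip examples already scored there
)
for inf_input in inputs:
    batch = inf_input.to_pipeline_input()
    # ... hand batch to the inference pipeline
```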
def apply_atom_selection_mask(
    atom_array: AtomArray, selection_list: Iterable[str]
) -> np.ndarray:
    """Return a combined boolean mask for a list of AtomSelectionStack queries.

    Args:
        atom_array: AtomArray to select from.
        selection_list: Iterable of AtomSelectionStack queries (e.g., "*/LIG", "A1-10").

    Returns:
        A boolean numpy array of shape (num_atoms,) where True indicates a selected atom.
    """
    selection_mask = np.zeros(len(atom_array), dtype=bool)
    for selection in selection_list:
        if not selection:
            continue
        try:
            selector = AtomSelectionStack.from_query(selection)
            mask = selector.get_mask(atom_array)
            selection_mask = selection_mask | mask
        except Exception as exc:  # Defensive: keep going if one selection fails
            logging.warning(
                "Failed to parse selection '%s': %s. Skipping.", selection, exc
            )
    return selection_mask
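Because per-query masks are OR-ed, overlapping selections are harmless, and a failed query degrades to a warning instead of an exception. A toy sketch using a two-atom array (the exact query grammar belongs to atomworks' `AtomSelectionStack`; the queries below simply mirror the docstring examples):

```python
import biotite.structure as struc
from rf3.utils.inference import apply_atom_selection_mask

atoms = struc.array([
    struc.Atom([0.0, 0.0, 0.0], chain_id="A", res_id=1, res_name="ALA", atom_name="CA"),
    struc.Atom([5.0, 0.0, 0.0], chain_id="B", res_id=1, res_name="LIG", atom_name="C1"),
])

# Empty strings are skipped; unparseable queries are logged and ignored.
mask = apply_atom_selection_mask(atoms, ["*/LIG", "A1-10", ""])
print(mask)  # expected: both atoms True, if the queries parse as documented
```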
def apply_template_selection(
    atom_array: AtomArray, template_selection: list[str] | str | None
) -> AtomArray:
    """Apply token-level template selection to `atom_array` with OR semantics.

    If the `is_input_file_templated` annotation already exists, this function ORs
    the new selection with the existing annotation. Otherwise, it creates it.

    Args:
        atom_array: AtomArray to annotate.
        template_selection: Selection string(s). Single strings are converted to lists. If None/empty, no-op.

    Returns:
        The same AtomArray with `is_input_file_templated` updated.
    """
    # Convert to list if needed
    template_selection_list = as_list(template_selection) if template_selection else []

    if not template_selection_list:
        # Ensure the annotation exists even if no selection provided
        if "is_input_file_templated" not in atom_array.get_annotation_categories():
            atom_array.set_annotation(
                "is_input_file_templated", np.zeros(len(atom_array), dtype=bool)
            )
        return atom_array

    # Build new mask
    selection_mask = apply_atom_selection_mask(atom_array, template_selection_list)
    logging.info(
        "Selected %d atoms for token-level templating with %d syntaxes",
        int(np.sum(selection_mask)),
        len([s for s in template_selection_list if s]),
    )

    # OR with existing annotation if present
    if "is_input_file_templated" in atom_array.get_annotation_categories():
        existing = atom_array.get_annotation("is_input_file_templated").astype(bool)
        selection_mask = existing | selection_mask
    atom_array.set_annotation("is_input_file_templated", selection_mask)
    return atom_array


def apply_ground_truth_conformer_selection(
    atom_array: AtomArray, ground_truth_conformer_selection: list[str] | str | None
) -> AtomArray:
    """Apply ground-truth conformer policy selection with union semantics.

    Behavior:
    - Creates `ground_truth_conformer_policy` if missing and initializes to IGNORE.
    - For selected atoms, sets policy to at least ADD without downgrading any
      existing policy (e.g., preserves REPLACE if present).

    Args:
        atom_array: AtomArray to annotate.
        ground_truth_conformer_selection: Selection string(s). Single strings are converted to lists. If None/empty, no-op.

    Returns:
        The same AtomArray with `ground_truth_conformer_policy` updated.
    """
    # Convert to list if needed
    ground_truth_conformer_selection_list = (
        as_list(ground_truth_conformer_selection)
        if ground_truth_conformer_selection
        else []
    )

    if not ground_truth_conformer_selection_list:
        if (
            "ground_truth_conformer_policy"
            not in atom_array.get_annotation_categories()
        ):
            atom_array.set_annotation(
                "ground_truth_conformer_policy",
                np.full(
                    len(atom_array), GroundTruthConformerPolicy.IGNORE, dtype=np.int8
                ),
            )
        return atom_array

    # Ensure annotation exists
    if "ground_truth_conformer_policy" not in atom_array.get_annotation_categories():
        atom_array.set_annotation(
            "ground_truth_conformer_policy",
            np.full(len(atom_array), GroundTruthConformerPolicy.IGNORE, dtype=np.int8),
        )

    selection_mask = apply_atom_selection_mask(
        atom_array, ground_truth_conformer_selection_list
    )
    logging.info(
        "Selected %d atoms for ground-truth conformer policy with %d syntaxes",
        int(np.sum(selection_mask)),
        len([s for s in ground_truth_conformer_selection_list if s]),
    )

    existing = atom_array.get_annotation("ground_truth_conformer_policy")
    existing[selection_mask] = GroundTruthConformerPolicy.ADD
    atom_array.set_annotation("ground_truth_conformer_policy", existing)

    return atom_array


def apply_conformer_and_template_selections(
    atom_array: AtomArray,
    template_selection: list[str] | str | None = None,
    ground_truth_conformer_selection: list[str] | str | None = None,
) -> AtomArray:
    """Apply template and conformer selections and basic preprocessing.

    This function replaces the former class method `prepare_atom_array`.

    - Applies `apply_template_selection` then `apply_ground_truth_conformer_selection`.
    - Replaces NaN coordinates with -1 for safety.

    Args:
        atom_array: AtomArray to prepare.
        template_selection: Template selection string(s). Single strings are converted to lists.
        ground_truth_conformer_selection: Ground-truth conformer selection string(s). Single strings are converted to lists.

    Returns:
        The same AtomArray with `is_input_file_templated` and `ground_truth_conformer_policy` updated.
    """
    atom_array = apply_template_selection(atom_array, template_selection)
    atom_array = apply_ground_truth_conformer_selection(
        atom_array, ground_truth_conformer_selection
    )
    # Safety: avoid unexpected behavior downstream
    atom_array.coord[np.isnan(atom_array.coord)] = -1
    return atom_array
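Even when both selections are `None`, the combined helper guarantees that both annotations exist and that NaN coordinates are scrubbed, so downstream transforms never need to special-case missing annotations. A sketch (the printed IGNORE value is whatever integer `GroundTruthConformerPolicy.IGNORE` maps to):

```python
import numpy as np
import biotite.structure as struc
from rf3.utils.inference import apply_conformer_and_template_selections

atoms = struc.array([
    struc.Atom([np.nan, 0.0, 0.0], chain_id="A", res_id=1, res_name="GLY", atom_name="CA"),
])

atoms = apply_conformer_and_template_selections(atoms)  # no selections given
print(atoms.is_input_file_templated)        # [False]  -- created, all-zeros
print(atoms.ground_truth_conformer_policy)  # [IGNORE] -- created, all-IGNORE
print(atoms.coord[0])                       # [-1.  0.  0.] -- NaN replaced
```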
def cyclize_atom_array(atom_array: AtomArray, cyclic_chains: list[str]) -> AtomArray:
    """Cyclize the atom array by positioning the termini properly if not already done.

    Behavior:
    - Positions the last carbon atom in the chain to be 1.3 Angstroms away from the first nitrogen atom if they are not already close.
    - Adds a bond between the termini for proper CIF output.

    Args:
        atom_array: AtomArray to cyclize.
        cyclic_chains: List of chain IDs to cyclize.

    Returns:
        The same AtomArray with the specified chains cyclized.
    """
    for chain in cyclic_chains:
        # Find the first nitrogen atom in the chain
        nitrogen_mask = (atom_array.chain_id == chain) & (atom_array.atom_name == "N")
        nitrogen_mask_indices = np.where(nitrogen_mask)[0]
        first_nitrogen_index = nitrogen_mask_indices[0]
        nitrogen_coord = atom_array.coord[first_nitrogen_index]

        # Move the last carbon atom in the chain to be 1.3 Angstroms away from the nitrogen
        carbon_mask = (atom_array.chain_id == chain) & (atom_array.atom_name == "C")
        carbon_mask_indices = np.where(carbon_mask)[0]
        last_carbon_index = carbon_mask_indices[-1]
        # Check if the last carbon is already close to the nitrogen
        termini_distance = np.linalg.norm(
            atom_array.coord[last_carbon_index] - nitrogen_coord
        )
        if not (termini_distance < 1.5 and termini_distance > 0.5):
            atom_array.coord[last_carbon_index] = nitrogen_coord + np.array(
                [1.3, 0.0, 0.0]
            )

        # Add a bond between the nitrogen and carbon so the output CIF has a connection
        atom_array.bonds.add_bond(first_nitrogen_index, last_carbon_index)
        atom_array.bonds.add_bond(last_carbon_index, first_nitrogen_index)

    return atom_array
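The 0.5–1.5 Å window brackets a plausible existing peptide C–N bond (about 1.33 Å); termini outside it are treated as open and snapped to 1.3 Å along an arbitrary axis, leaving exact geometry to the model. A numeric sketch of just that check:

```python
import numpy as np

nitrogen = np.array([0.0, 0.0, 0.0])
carbon = np.array([4.2, 0.0, 0.0])  # open termini, far apart

d = np.linalg.norm(carbon - nitrogen)
if not (0.5 < d < 1.5):             # same window as cyclize_atom_array
    carbon = nitrogen + np.array([1.3, 0.0, 0.0])

print(np.linalg.norm(carbon - nitrogen))  # 1.3
```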
class InferenceInputDataset(Dataset):
    """
    Dataset for inference inputs. Also has a length key telling you the number of tokens in each example for LoadBalancedDistributedSampler.

    To calculate the length of each example, we need to add the token_id annotation to the atom_array. If it doesn't exist yet, we add it,
    calculate the length, and then remove it since the downstream pipeline may not be expecting it. That means the num_tokens key may not ultimately
    be the same as what's actually used in the model, but this is a close enough approximation for load balancing.

    Args:
        inference_inputs: List of InferenceInput objects to wrap in a Dataset.
    """

    def __init__(self, inference_inputs: list[InferenceInput]):
        self.inference_inputs = inference_inputs
        self.key_to_balance = "num_tokens_approximate"

        # LoadBalancedDistributedSampler checks in dataset.data[key_to_balance] to determine balancing.
        # That means we need to make a dataframe in self.data that has a column with the key_to_balance.
        atom_array_token_lens = []
        for inf_input in self.inference_inputs:
            if "token_id" not in inf_input.atom_array.get_annotation_categories():
                inf_input.atom_array = add_global_token_id_annotation(
                    inf_input.atom_array
                )
                num_tokens = len(np.unique(inf_input.atom_array.token_id))

                # Remove the token_id annotation since the pipeline may not be expecting it
                inf_input.atom_array.del_annotation("token_id")
            else:
                num_tokens = len(np.unique(inf_input.atom_array.token_id))
            atom_array_token_lens.append(num_tokens)
        self.data = pd.DataFrame({self.key_to_balance: atom_array_token_lens})

    def __len__(self):
        return len(self.inference_inputs)

    def __getitem__(self, idx):
        return self.inference_inputs[idx]
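Finally, the dataset drops into a standard `DataLoader`. `LoadBalancedDistributedSampler` (referenced by the docstring but not defined in this file) reads `dataset.data`; the sketch below uses a plain sequential loader with `batch_size=None` so each item comes back as an uncollated `InferenceInput`:

```python
from torch.utils.data import DataLoader
from rf3.utils.inference import (
    InferenceInputDataset,
    prepare_inference_inputs_from_paths,
)

dataset = InferenceInputDataset(prepare_inference_inputs_from_paths("designs/"))
print(dataset.data["num_tokens_approximate"].tolist())  # per-example token counts

loader = DataLoader(dataset, batch_size=None)  # no collation; yields InferenceInput
for inf_input in loader:
    pipeline_input = inf_input.to_pipeline_input()
```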