rc-foundry 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- foundry/__init__.py +57 -0
- foundry/callbacks/__init__.py +5 -0
- foundry/callbacks/callback.py +116 -0
- foundry/callbacks/health_logging.py +419 -0
- foundry/callbacks/metrics_logging.py +211 -0
- foundry/callbacks/timing_logging.py +67 -0
- foundry/callbacks/train_logging.py +278 -0
- foundry/common.py +108 -0
- foundry/constants.py +28 -0
- foundry/hydra/resolvers.py +77 -0
- foundry/inference_engines/base.py +235 -0
- foundry/inference_engines/checkpoint_registry.py +66 -0
- foundry/metrics/__init__.py +12 -0
- foundry/metrics/losses.py +30 -0
- foundry/metrics/metric.py +319 -0
- foundry/model/layers/blocks.py +47 -0
- foundry/testing/__init__.py +6 -0
- foundry/testing/fixtures.py +19 -0
- foundry/testing/pytest_hooks.py +15 -0
- foundry/trainers/fabric.py +923 -0
- foundry/training/EMA.py +67 -0
- foundry/training/checkpoint.py +61 -0
- foundry/training/schedulers.py +91 -0
- foundry/utils/alignment.py +86 -0
- foundry/utils/components.py +415 -0
- foundry/utils/datasets.py +405 -0
- foundry/utils/ddp.py +103 -0
- foundry/utils/instantiators.py +72 -0
- foundry/utils/logging.py +279 -0
- foundry/utils/rigid.py +1460 -0
- foundry/utils/rotation_augmentation.py +65 -0
- foundry/utils/squashfs.py +172 -0
- foundry/utils/torch.py +317 -0
- foundry/utils/weights.py +271 -0
- foundry/version.py +34 -0
- foundry_cli/__init__.py +3 -0
- foundry_cli/download_checkpoints.py +281 -0
- mpnn/__init__.py +1 -0
- mpnn/collate/feature_collator.py +265 -0
- mpnn/inference.py +53 -0
- mpnn/inference_engines/mpnn.py +549 -0
- mpnn/loss/nll_loss.py +122 -0
- mpnn/metrics/nll.py +369 -0
- mpnn/metrics/sequence_recovery.py +440 -0
- mpnn/model/layers/graph_embeddings.py +2372 -0
- mpnn/model/layers/message_passing.py +332 -0
- mpnn/model/layers/position_wise_feed_forward.py +44 -0
- mpnn/model/layers/positional_encoding.py +98 -0
- mpnn/model/mpnn.py +2632 -0
- mpnn/pipelines/mpnn.py +162 -0
- mpnn/samplers/samplers.py +167 -0
- mpnn/train.py +341 -0
- mpnn/trainers/mpnn.py +193 -0
- mpnn/transforms/feature_aggregation/mpnn.py +184 -0
- mpnn/transforms/feature_aggregation/polymer_ligand_interface.py +76 -0
- mpnn/transforms/feature_aggregation/token_encodings.py +132 -0
- mpnn/transforms/feature_aggregation/user_settings.py +347 -0
- mpnn/transforms/polymer_ligand_interface.py +164 -0
- mpnn/utils/inference.py +2397 -0
- mpnn/utils/probability.py +37 -0
- mpnn/utils/weights.py +309 -0
- rc_foundry-0.1.1.dist-info/METADATA +239 -0
- rc_foundry-0.1.1.dist-info/RECORD +180 -0
- rc_foundry-0.1.1.dist-info/WHEEL +4 -0
- rc_foundry-0.1.1.dist-info/entry_points.txt +5 -0
- rc_foundry-0.1.1.dist-info/licenses/LICENSE.md +28 -0
- rf3/__init__.py +3 -0
- rf3/_version.py +33 -0
- rf3/alignment.py +79 -0
- rf3/callbacks/dump_validation_structures.py +101 -0
- rf3/callbacks/metrics_logging.py +324 -0
- rf3/chemical.py +1529 -0
- rf3/cli.py +77 -0
- rf3/data/cyclic_transform.py +78 -0
- rf3/data/extra_xforms.py +36 -0
- rf3/data/ground_truth_template.py +463 -0
- rf3/data/paired_msa.py +206 -0
- rf3/data/pipeline_utils.py +128 -0
- rf3/data/pipelines.py +558 -0
- rf3/diffusion_samplers/inference_sampler.py +222 -0
- rf3/inference.py +65 -0
- rf3/inference_engines/__init__.py +5 -0
- rf3/inference_engines/rf3.py +735 -0
- rf3/kinematics.py +354 -0
- rf3/loss/af3_confidence_loss.py +515 -0
- rf3/loss/af3_losses.py +655 -0
- rf3/loss/loss.py +179 -0
- rf3/metrics/chiral.py +179 -0
- rf3/metrics/clashing_chains.py +68 -0
- rf3/metrics/distogram.py +421 -0
- rf3/metrics/lddt.py +523 -0
- rf3/metrics/metadata.py +43 -0
- rf3/metrics/metric_utils.py +192 -0
- rf3/metrics/predicted_error.py +134 -0
- rf3/metrics/rasa.py +108 -0
- rf3/metrics/selected_distances.py +91 -0
- rf3/model/RF3.py +527 -0
- rf3/model/RF3_blocks.py +92 -0
- rf3/model/RF3_structure.py +303 -0
- rf3/model/layers/af3_auxiliary_heads.py +255 -0
- rf3/model/layers/af3_diffusion_transformer.py +544 -0
- rf3/model/layers/attention.py +313 -0
- rf3/model/layers/layer_utils.py +127 -0
- rf3/model/layers/mlff.py +118 -0
- rf3/model/layers/outer_product.py +59 -0
- rf3/model/layers/pairformer_layers.py +783 -0
- rf3/model/layers/structure_bias.py +56 -0
- rf3/scoring.py +1787 -0
- rf3/symmetry/resolve.py +284 -0
- rf3/train.py +194 -0
- rf3/trainers/rf3.py +570 -0
- rf3/util_module.py +47 -0
- rf3/utils/frames.py +109 -0
- rf3/utils/inference.py +665 -0
- rf3/utils/io.py +198 -0
- rf3/utils/loss.py +72 -0
- rf3/utils/predict_and_score.py +165 -0
- rf3/utils/predicted_error.py +673 -0
- rf3/utils/recycling.py +42 -0
- rf3/validate.py +140 -0
- rfd3/.gitignore +7 -0
- rfd3/Makefile +76 -0
- rfd3/__init__.py +12 -0
- rfd3/callbacks.py +66 -0
- rfd3/cli.py +41 -0
- rfd3/constants.py +212 -0
- rfd3/engine.py +543 -0
- rfd3/inference/datasets.py +193 -0
- rfd3/inference/input_parsing.py +1123 -0
- rfd3/inference/legacy_input_parsing.py +717 -0
- rfd3/inference/parsing.py +165 -0
- rfd3/inference/symmetry/atom_array.py +298 -0
- rfd3/inference/symmetry/checks.py +241 -0
- rfd3/inference/symmetry/contigs.py +63 -0
- rfd3/inference/symmetry/frames.py +355 -0
- rfd3/inference/symmetry/symmetry_utils.py +398 -0
- rfd3/metrics/design_metrics.py +465 -0
- rfd3/metrics/hbonds_hbplus_metrics.py +308 -0
- rfd3/metrics/hbonds_metrics.py +389 -0
- rfd3/metrics/losses.py +325 -0
- rfd3/metrics/metrics_utils.py +118 -0
- rfd3/metrics/sidechain_metrics.py +349 -0
- rfd3/model/RFD3.py +105 -0
- rfd3/model/RFD3_diffusion_module.py +387 -0
- rfd3/model/cfg_utils.py +81 -0
- rfd3/model/inference_sampler.py +635 -0
- rfd3/model/layers/attention.py +577 -0
- rfd3/model/layers/block_utils.py +580 -0
- rfd3/model/layers/blocks.py +777 -0
- rfd3/model/layers/chunked_pairwise.py +377 -0
- rfd3/model/layers/encoders.py +417 -0
- rfd3/model/layers/layer_utils.py +197 -0
- rfd3/model/layers/pairformer_layers.py +128 -0
- rfd3/run_inference.py +45 -0
- rfd3/testing/debug.py +139 -0
- rfd3/testing/debug_utils.py +73 -0
- rfd3/testing/testing_utils.py +356 -0
- rfd3/train.py +194 -0
- rfd3/trainer/dump_validation_structures.py +154 -0
- rfd3/trainer/fabric_trainer.py +923 -0
- rfd3/trainer/recycling.py +42 -0
- rfd3/trainer/rfd3.py +485 -0
- rfd3/trainer/trainer_utils.py +502 -0
- rfd3/transforms/conditioning_base.py +508 -0
- rfd3/transforms/conditioning_utils.py +200 -0
- rfd3/transforms/design_transforms.py +807 -0
- rfd3/transforms/dna_crop.py +523 -0
- rfd3/transforms/hbonds.py +407 -0
- rfd3/transforms/hbonds_hbplus.py +246 -0
- rfd3/transforms/ncaa_transforms.py +153 -0
- rfd3/transforms/pipelines.py +632 -0
- rfd3/transforms/ppi_transforms.py +541 -0
- rfd3/transforms/rasa.py +116 -0
- rfd3/transforms/symmetry.py +76 -0
- rfd3/transforms/training_conditions.py +552 -0
- rfd3/transforms/util_transforms.py +498 -0
- rfd3/transforms/virtual_atoms.py +305 -0
- rfd3/utils/inference.py +648 -0
- rfd3/utils/io.py +245 -0
- rfd3/utils/vizualize.py +276 -0
rfd3/engine.py
ADDED
|
@@ -0,0 +1,543 @@
|
|
|
1
|
+
import json
|
|
2
|
+
import logging
|
|
3
|
+
import os
|
|
4
|
+
import time
|
|
5
|
+
from dataclasses import dataclass, field
|
|
6
|
+
from os import PathLike
|
|
7
|
+
from pathlib import Path
|
|
8
|
+
from typing import Any, Dict, List, Optional
|
|
9
|
+
|
|
10
|
+
import torch
|
|
11
|
+
import yaml
|
|
12
|
+
from atomworks.io.utils.io_utils import to_cif_file
|
|
13
|
+
from biotite.structure import AtomArray, AtomArrayStack
|
|
14
|
+
from toolz import merge_with
|
|
15
|
+
|
|
16
|
+
from foundry.common import exists
|
|
17
|
+
from foundry.inference_engines.base import BaseInferenceEngine
|
|
18
|
+
from foundry.inference_engines.checkpoint_registry import REGISTERED_CHECKPOINTS
|
|
19
|
+
from foundry.utils.alignment import weighted_rigid_align
|
|
20
|
+
from foundry.utils.ddp import RankedLogger
|
|
21
|
+
from rfd3.constants import SAVED_CONDITIONING_ANNOTATIONS
|
|
22
|
+
from rfd3.inference.datasets import (
|
|
23
|
+
assemble_distributed_inference_loader_from_json,
|
|
24
|
+
)
|
|
25
|
+
from rfd3.inference.input_parsing import DesignInputSpecification
|
|
26
|
+
from rfd3.model.inference_sampler import SampleDiffusionConfig
|
|
27
|
+
from rfd3.utils.inference import ensure_input_is_abspath
|
|
28
|
+
from rfd3.utils.io import (
|
|
29
|
+
CIF_LIKE_EXTENSIONS,
|
|
30
|
+
build_stack_from_atom_array_and_batched_coords,
|
|
31
|
+
extract_example_id_from_path,
|
|
32
|
+
find_files_with_extension,
|
|
33
|
+
)
|
|
34
|
+
|
|
35
|
+
# Import-time logging setup. basicConfig installs a root handler at INFO level;
# it is a no-op when the root logger already has handlers configured.
# NOTE(review): configuring the root logger from a library module affects every
# application that imports this module.
logging.basicConfig(level="INFO")

# Module logger; rank_zero_only=True presumably restricts emission to the rank-0
# process in distributed runs — confirm against foundry.utils.ddp.RankedLogger.
ranked_logger = RankedLogger(__name__, rank_zero_only=True)
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
@dataclass(kw_only=True)
class RFD3InferenceConfig:
    """Keyword-only configuration for RFD3 inference.

    Instances also behave as a read-only mapping (``keys`` / ``__getitem__``
    below), so a config can be unpacked with ``**config`` — presumably into
    ``RFD3InferenceEngine(...)``; confirm against callers.
    """

    ckpt_path: str | Path = 'rfd3'  # Defaults to foundry installation upon instantiation
    diffusion_batch_size: int = 16

    # RFD3 specific
    # Skip examples whose ids are already found in the output directory.
    # (Single declaration: this field used to be declared twice, and the later
    # `= True` silently overrode the earlier `= False`.)
    skip_existing: bool = True
    json_keys_subset: Optional[List[str]] = None
    specification: Optional[dict] = field(default_factory=dict)
    inference_sampler: SampleDiffusionConfig | dict = field(default_factory=dict)

    # Saving args
    cleanup_guideposts: bool = True
    cleanup_virtual_atoms: bool = True
    read_sequence_from_sequence_head: bool = True
    output_full_json: bool = True

    # Prefix to add to all output samples
    # Default: None -> f'{jsonfilebasename}_{jsonkey}_{batch}_{model}'
    # Otherwise: string -> f'{string}{jsonkey}_{batch}_{model}'
    # e.g. Empty string -> f'{jsonkey}_{batch}_{model}'
    # e.g. Chunk string -> f'{chunkprefix_}{jsonkey}_{batch}_{model}' (pipelines usage)
    global_prefix: Optional[str] = None
    dump_prediction_metadata_json: bool = True
    dump_trajectories: bool = False
    align_trajectory_structures: bool = False
    prevalidate_inputs: bool = True
    low_memory_mode: bool = (
        False  # False for standard mode, True for memory efficient tokenization mode
    )

    # Other:
    num_nodes: int = 1
    devices_per_node: int = 1
    verbose: bool = False
    seed: Optional[int] = None

    # For use as mapping:
    def keys(self):
        """Return the dataclass field names (enables ``**config`` unpacking)."""
        return self.__dataclass_fields__.keys()

    def __getitem__(self, key):
        """Return the value of field ``key`` (enables ``**config`` unpacking)."""
        return getattr(self, key)
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
@dataclass
class RFD3Output:
    """One generated design plus optional diffusion trajectories.

    ``example_id`` is used as the output file stem. ``RFD3InferenceEngine``
    builds it as ``f"{example_id}_model_{idx}"``.
    """

    atom_array: AtomArray
    metadata: dict
    example_id: str
    denoised_trajectory_stack: Optional[AtomArrayStack] = None
    noisy_trajectory_stack: Optional[AtomArrayStack] = None

    def dump(
        self,
        out_dir,
        verbose=True,
    ):
        """Write this output to ``out_dir``.

        With ``<base> = out_dir/example_id`` (made absolute):
          * the structure is written via ``to_cif_file(<base>, file_type="cif.gz")``;
          * ``metadata`` (when non-empty) is written as JSON to ``<base>.json``;
          * each trajectory stack that is not None is written via ``to_cif_file``
            to ``<stem>_denoised_model_<idx>`` / ``<stem>_noisy_model_<idx>``,
            where ``<base> == <stem>_model_<idx>``.

        Args:
            out_dir: Directory to write into.
            verbose: If True, log where the outputs were written.
        """
        base_path = Path(os.path.join(out_dir, self.example_id)).absolute()
        to_cif_file(
            self.atom_array,
            base_path,
            file_type="cif.gz",
            include_entity_poly=False,
            extra_fields=SAVED_CONDITIONING_ANNOTATIONS,
        )
        if self.metadata:
            with open(f"{base_path}.json", "w") as f:
                json.dump(self.metadata, f, indent=4)

        # Trajectory saving (denoised first, then noisy).
        trajectories = (
            ("denoised", self.denoised_trajectory_stack),
            ("noisy", self.noisy_trajectory_stack),
        )
        for kind, stack in trajectories:
            if stack is None:
                continue
            to_cif_file(
                stack,
                self._trajectory_path(base_path, kind),
                file_type="cif.gz",
                include_entity_poly=False,
            )

        if verbose:
            ranked_logger.info(f"Outputs for {self.example_id} written to {base_path}.")

    @staticmethod
    def _trajectory_path(base_path, kind: str) -> str:
        """Map ``<stem>_model_<idx>`` to ``<stem>_<kind>_model_<idx>``.

        Splitting happens on the *last* ``_model_`` occurrence, so multi-digit
        model indices and stems ending in any of the letters of "model" are
        preserved. (The previous ``rstrip("_model_")`` removed a character set,
        not a suffix, and kept only the final digit of the index.) Paths
        without ``_model_`` get ``_<kind>`` appended instead.
        """
        base = str(base_path)
        stem, sep, model_idx = base.rpartition("_model_")
        if not sep:
            return f"{base}_{kind}"
        return f"{stem}_{kind}_model_{model_idx}"
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
class RFD3InferenceEngine(BaseInferenceEngine):
    """Inference engine for RFdiffusion3.

    Builds on ``BaseInferenceEngine``. Judging from their use below, the base
    class provides ``self.pipeline``, ``self.trainer``, ``self.ckpt_path``,
    ``self.seed`` and ``initialize()``.

    On top of that, this class:
      * resolves user inputs into design specifications;
      * expands each specification into ``n_batches`` examples, optionally
        skipping ones already present in the output directory;
      * wraps each trainer validation step's samples as ``RFD3Output`` objects.
    """

    def __init__(
        self,
        *,
        # Default input handling args
        skip_existing: bool,
        json_keys_subset: None | List[str],
        prevalidate_inputs: bool,
        # Base inference engine args
        diffusion_batch_size: int,
        inference_sampler: dict,
        specification: dict | None,
        # Structure dumping arguments
        global_prefix: str | None,
        cleanup_guideposts: bool,
        cleanup_virtual_atoms: bool,
        read_sequence_from_sequence_head: bool,
        output_full_json: bool,
        dump_prediction_metadata_json: bool,
        dump_trajectories: bool,
        align_trajectory_structures: bool,
        low_memory_mode: bool,
        **kwargs,
    ):
        """Store RFD3 options and forward model/trainer overrides to the base engine.

        Overrides are passed to ``BaseInferenceEngine`` as follows:
          * ``diffusion_batch_size`` -> transform override;
          * ``inference_sampler`` -> sampler overrides;
          * the cleanup / sequence-head / full-json flags -> trainer overrides;
          * remaining ``**kwargs`` -> passed through unchanged.
        """
        super().__init__(
            transform_overrides={"diffusion_batch_size": diffusion_batch_size},
            inference_sampler_overrides={**inference_sampler},
            trainer_overrides={
                "cleanup_guideposts": cleanup_guideposts,
                "cleanup_virtual_atoms": cleanup_virtual_atoms,
                "read_sequence_from_sequence_head": read_sequence_from_sequence_head,
                "output_full_json": output_full_json,
            },
            **kwargs,
        )
        # Merged on top of every file-derived example specification (overrides win).
        self.specification_overrides = dict(specification or {})

        # Input handling args
        self.global_prefix = global_prefix
        self.json_keys_subset = json_keys_subset
        self.prevalidate_inputs = prevalidate_inputs
        self.skip_existing = skip_existing

        # Saving / other args
        self.dump_prediction_metadata_json = dump_prediction_metadata_json
        self.dump_trajectories = dump_trajectories
        self.align_trajectory_structures = align_trajectory_structures
        if not cleanup_guideposts:
            ranked_logger.warning(
                "Guideposts will not be cleaned up. This is intended for debugging purposes."
            )
        if not cleanup_virtual_atoms:
            ranked_logger.warning(
                "Virtual atoms will not be cleaned up. Some tools like MPNN may run, but outputs will not be like native structures."
            )

        # Low-memory (memory efficient tokenization) mode
        if low_memory_mode:
            ranked_logger.info("Low memory mode enabled.")
            # HACK: Set attribute to the diffusion module. This is done through a
            # process-wide environment variable that is never unset here, so it
            # stays enabled for any engine created later in the same process.
            os.environ["RFD3_LOW_MEMORY_MODE"] = "1"

    def run(
        self,
        *,
        inputs: str | PathLike | AtomArray | DesignInputSpecification,
        n_batches: int | None = None,
        out_dir: str | PathLike | None = None,
    ):
        """Run inference on ``inputs``.

        Args:
            inputs: JSON/YAML/structure path(s), DesignInputSpecification(s), or None.
            n_batches: If given, each input is expanded into ``n_batches``
                examples named ``f"{prefix}_{batch_id}"``.
            out_dir: If given, outputs are written there and ``{}`` is returned.
                Otherwise outputs are returned in memory.

        Returns:
            ``{example_id: [RFD3Output, ...]}``. This is empty when ``out_dir``
            is set.
        """
        self._set_out_dir(out_dir)
        inputs = self._canonicalize_inputs(inputs)
        design_specifications = self._multiply_specifications(
            inputs=inputs,
            n_batches=n_batches,
        )
        # Initialize the model/trainer only after the inputs have been resolved.
        self.initialize()
        outputs = self._run_multi(design_specifications)
        return outputs

    def _set_out_dir(self, out_dir: str | PathLike | None):
        """Create ``out_dir`` if given and store it (as a Path, or None)."""
        out_dir = Path(out_dir) if out_dir else None
        if out_dir:
            out_dir.mkdir(parents=True, exist_ok=True)
            ranked_logger.info(f"Outputs will be written to {out_dir.resolve()}.")
        self.out_dir = out_dir

    def _run_multi(self, specs) -> None | Dict[str, List[RFD3Output]]:
        """Run every specification through the pipeline and model.

        Outputs are dumped to ``self.out_dir`` when it is set. Otherwise they are
        collected and returned keyed by example id.
        """
        # ==============================================================================
        # Prepare pipeline and inference loader
        # ==============================================================================
        loader = assemble_distributed_inference_loader_from_json(
            # Passed directly to ContigJSONDataset
            data=specs,
            transform=self.pipeline,
            name="inference-dataset",
            cif_parser_args=None,
            subset_to_keys=None,
            eval_every_n=1,
            # Sampler args
            world_size=self.trainer.fabric.world_size,
            rank=self.trainer.fabric.global_rank,
        )
        loader = self.trainer.fabric.setup_dataloaders(
            loader,
            use_distributed_sampler=False,
        )

        # ==============================================================================
        # Evaluate, using `validation_step`
        # ==============================================================================
        outputs = {}
        for batch in loader:
            pipeline_output = batch[0]
            example_id = pipeline_output["example_id"]

            # Run model
            output_list = self._model_forward(pipeline_output)
            if self.out_dir:
                for output in output_list:
                    output.dump(out_dir=self.out_dir)
            else:
                outputs[example_id] = output_list
        return outputs

    def _model_forward(self, pipeline_output) -> List[RFD3Output]:
        """Run one trainer validation step and wrap each sample as an RFD3Output."""
        t0 = time.time()
        with torch.no_grad():
            pipeline_output = self.trainer.fabric.to_device(pipeline_output)
            output_val = self.trainer.validation_step(
                batch=pipeline_output,
                batch_idx=0,
                compute_metrics=False,
            )
        t_end = time.time()

        if self.dump_trajectories:
            network_output = output_val["network_output"]
            # Per-step lists are stacked and transposed so the sample axis comes first.
            X_noisy_L_traj = torch.stack(
                network_output["X_noisy_L_traj"]
            ).transpose(0, 1)  # [D, N_steps, L, 3]
            X_denoised_L_traj = torch.stack(
                network_output["X_denoised_L_traj"]
            ).transpose(0, 1)  # [D, N_steps, L, 3]

        # The checkpoint path is the same for every sample, so resolve it once.
        ckpt_path_str = None
        if self.dump_prediction_metadata_json:
            ckpt = Path(self.ckpt_path)
            if ckpt.is_symlink():
                ckpt = ckpt.resolve(strict=True)  # follow symlink to target
            ckpt_path_str = str(ckpt)

        predicted_stack = output_val["predicted_atom_array_stack"]
        outputs = []
        for idx in range(len(predicted_stack)):
            # Add additional information to prediction metadata
            if self.dump_prediction_metadata_json:
                output_val["prediction_metadata"][idx]["ckpt_path"] = ckpt_path_str
                output_val["prediction_metadata"][idx]["seed"] = self.seed

            if self.dump_trajectories:
                # Each stack is built from its own trajectory. They used to be
                # swapped, so "denoised" files held the noisy trajectory.
                X_denoised_L_traj_i = _reshape_trajectory(
                    X_denoised_L_traj[idx], self.align_trajectory_structures
                )
                X_noisy_L_traj_i = _reshape_trajectory(X_noisy_L_traj[idx], False)
                denoised_trajectory_stack = (
                    build_stack_from_atom_array_and_batched_coords(
                        X_denoised_L_traj_i, pipeline_output["atom_array"]
                    )
                )
                noisy_trajectory_stack = build_stack_from_atom_array_and_batched_coords(
                    X_noisy_L_traj_i, pipeline_output["atom_array"]
                )
            else:
                denoised_trajectory_stack = None
                noisy_trajectory_stack = None

            outputs.append(
                RFD3Output(
                    example_id=f"{pipeline_output['example_id']}_model_{idx}",
                    atom_array=predicted_stack[idx],
                    metadata=output_val["prediction_metadata"][idx]
                    if self.dump_prediction_metadata_json
                    else {},
                    denoised_trajectory_stack=denoised_trajectory_stack,
                    noisy_trajectory_stack=noisy_trajectory_stack,
                )
            )

        ranked_logger.info(f"Finished inference batch in {t_end - t0:.2f} seconds.")
        return outputs

    ###############################################
    # Input merging
    ###############################################

    def _canonicalize_inputs(
        self, inputs
    ) -> Dict[str, dict | DesignInputSpecification]:
        """Map user inputs to ``{example_prefix: specification}``.

        Handling by input type:
          * None -> one spec named ``""`` containing only the specification overrides.
          * path or list of paths (str / PathLike) -> ``process_input`` (overrides applied).
          * DesignInputSpecification or list of them -> keys ``backbone_<i>``
            (overrides are not applied).
          * AtomArray or list of AtomArrays -> NotImplementedError.

        Raises:
            NotImplementedError: For AtomArray inputs.
            ValueError: For any other input type, including lists that mix types.
        """
        if inputs is None:
            # Create empty specification dictionary
            return {"": {**self.specification_overrides}}

        path_types = (str, PathLike, Path)
        is_list = isinstance(inputs, list)
        is_json_like = isinstance(inputs, path_types) or (
            is_list and all(isinstance(i, path_types) for i in inputs)
        )
        is_specification_like = isinstance(inputs, DesignInputSpecification) or (
            is_list and all(isinstance(i, DesignInputSpecification) for i in inputs)
        )
        # Only a single AtomArray or a list made up solely of AtomArrays matches;
        # other lists fall through to the ValueError below.
        is_atom_array_like = isinstance(inputs, AtomArray) or (
            is_list and all(isinstance(i, AtomArray) for i in inputs)
        )

        if is_json_like:
            # List of file paths
            inputs = process_input(
                inputs,
                json_keys_subset=self.json_keys_subset,
                global_prefix=self.global_prefix,
                specification_overrides=self.specification_overrides,
                validate=self.prevalidate_inputs,
            )  # any -> Dict[Name: DesignInputSpecification]
        elif is_specification_like:
            # List of DesignInputSpecifications
            if isinstance(inputs, DesignInputSpecification):
                inputs = [inputs]
            inputs = {f"backbone_{i}": spec for i, spec in enumerate(inputs)}
        elif is_atom_array_like:
            raise NotImplementedError("AtomArray inputs not yet supported.")
        else:
            raise ValueError(
                f"Invalid input type: {type(inputs)}. Expected JSON/YAML file paths, AtomArray, or DesignInputSpecification.\nInput: {inputs}"
            )

        return inputs

    def _multiply_specifications(
        self, inputs: Dict[str, dict | DesignInputSpecification], n_batches=None
    ) -> Dict[str, Dict[str, Any]]:
        """Expand each specification into one entry per batch.

        Example ids are ``f"{prefix}_{batch_id}"`` when ``n_batches`` is given,
        otherwise just ``prefix``. When ``skip_existing`` is set, ids already
        found in ``self.out_dir`` are dropped. Every batch entry of one prefix
        references the same specification object.
        """
        # Find existing example IDS in output directory (empty when there is none).
        existing_example_ids = set()
        if exists(self.out_dir):
            existing_example_ids = {
                extract_example_id_from_path(path, CIF_LIKE_EXTENSIONS)
                for path in find_files_with_extension(self.out_dir, CIF_LIKE_EXTENSIONS)
            }
            ranked_logger.info(
                f"Found {len(existing_example_ids)} existing example IDs in the output directory."
            )

        # Based on inputs, construct the specifications to loop through
        n_repeats = n_batches if exists(n_batches) else 1
        design_specifications = {}
        for prefix, example_spec in inputs.items():
            for batch_id in range(n_repeats):
                example_id = f"{prefix}_{batch_id}" if exists(n_batches) else prefix

                if self.skip_existing and example_id in existing_example_ids:
                    ranked_logger.info(
                        f"Skipping design specification for example {example_id} | Already exists."
                    )
                    continue
                design_specifications[example_id] = example_spec
        return design_specifications
|
|
401
|
+
|
|
402
|
+
|
|
403
|
+
def normalize_inputs(inputs: str | list | None) -> list[str | None]:
|
|
404
|
+
"""
|
|
405
|
+
inputs: str | list[str] | None
|
|
406
|
+
- Can be:
|
|
407
|
+
- A single path to a JSON, YAML, or regular input file (cif or pdb)
|
|
408
|
+
- A comma-separated string of paths (e.g. "a.json,b.json")
|
|
409
|
+
- A list of file paths
|
|
410
|
+
- None or an empty list, in which case a dummy input is added (used for e.g. motif-only design)
|
|
411
|
+
- Returns list of paths or [None] if no inputs are provided
|
|
412
|
+
"""
|
|
413
|
+
if inputs is None or (isinstance(inputs, list) and len(inputs) == 0):
|
|
414
|
+
inputs = [None]
|
|
415
|
+
elif isinstance(inputs, str):
|
|
416
|
+
inputs = inputs.split(",")
|
|
417
|
+
elif not isinstance(inputs, list):
|
|
418
|
+
raise ValueError(
|
|
419
|
+
f"Invalid input type: {type(inputs)}. Expected str, list, or None.\nInput: {inputs}"
|
|
420
|
+
)
|
|
421
|
+
return inputs
|
|
422
|
+
|
|
423
|
+
|
|
424
|
+
def process_input(
    inputs: str | list | None,
    json_keys_subset: str | list | None = None,
    global_prefix: str | None = None,
    specification_overrides: dict | None = None,
    validate: bool = True,
) -> Dict[str, dict]:
    """Resolve raw inputs into design-specification kwargs, before batching.

    Args:
        inputs: Anything accepted by ``normalize_inputs``: a path, a
            comma-separated string of paths, a list of paths, or None. A single
            ``os.PathLike`` is also accepted.
        json_keys_subset: Extract only this subset of JSON/YAML keys (a list or
            a comma-separated string). None keeps all keys.
        global_prefix: If provided, example ids are ``f"{global_prefix}{key}"``
            and structure-file ids are ``f"{global_prefix}{basename}"``.
            Otherwise JSON/YAML examples are named ``f"{file_basename}_{key}"``.
        specification_overrides: Merged on top of every example's args; the
            override values win.
        validate: If True, each resulting spec is checked with
            ``DesignInputSpecification.safe_init`` before returning.

    Specification files are those ending in ``.json``, ``.yaml`` or ``.yml``.
    A top-level ``global_args`` entry is applied to every example with
    ``dict.update``, so its values replace same-named per-example values.
    Any other existing path is treated as a structure input ``{"input": path}``,
    and a None entry yields a ``"backbone"`` spec of the overrides alone.

    Returns:
        Dictionaries of specification args, pre-batching::

            {
                'jsonfile_jsonkey1': {**args_from_key1},
                'jsonfile_jsonkey2': {**args_from_key2},
            }

    Raises:
        ValueError: If a JSON/YAML file's top level is not a mapping.
    """
    specification_overrides = dict(specification_overrides or {})
    spec_file_extensions = (".json", ".yaml", ".yml")

    def merge_args(example_args: dict) -> dict:
        # Last value wins, so specification_overrides take precedence.
        return merge_with(lambda x: x[-1], example_args, specification_overrides)

    # Wrap a single path object so it is neither rejected nor split on commas.
    if isinstance(inputs, PathLike):
        inputs = [inputs]
    # Always a non-empty list; [None] when there are no inputs.
    inputs = normalize_inputs(inputs)

    # If global_prefix is not provided, default to using the basename of the
    # JSON or YAML file (when provided).
    use_json_basename_prefix = global_prefix is None

    # Parse the key subset once, not once per input file.
    if isinstance(json_keys_subset, str):
        json_keys_subset = json_keys_subset.split(",")

    # ... Determine prefix of sample to create
    all_specs = {}
    for input_path in inputs:
        if isinstance(input_path, PathLike):
            input_path = os.fspath(input_path)

        if exists(input_path) and input_path.endswith(spec_file_extensions):
            # ... Load JSON or YAML file
            with open(input_path, "r") as f:
                data = (
                    json.load(f) if input_path.endswith(".json") else yaml.safe_load(f)
                )
            if not isinstance(data, dict):
                raise ValueError(
                    f"Expected {input_path} to contain a mapping of example names to "
                    f"specifications, got {type(data).__name__}."
                )

            # ... Apply any global args for this input file
            if "global_args" in data:
                global_args = data.pop("global_args")
                for example_args in data.values():
                    example_args.update(global_args)

            # ... Subset to keys (missing keys are silently skipped)
            if json_keys_subset is not None:
                data = {key: data[key] for key in json_keys_subset if key in data}

            # ... Extract each accumulated example in data.
            file_stem = os.path.splitext(os.path.basename(input_path))[0]
            for example, args in data.items():
                args = ensure_input_is_abspath(args, input_path)
                if use_json_basename_prefix:
                    prefix = f"{file_stem}_{example}"
                else:
                    prefix = f"{global_prefix}{example}"
                args["extra"] = args.get("extra", {}) | {"example": example}
                all_specs[prefix] = dict(merge_args(args))

        elif exists(input_path):
            # Regular structure input (cif/pdb): spec is just the input path.
            prefix = os.path.basename(os.path.splitext(input_path)[0])
            if global_prefix is not None:
                prefix = f"{global_prefix}{prefix}"
            all_specs[prefix] = dict(merge_args({"input": input_path}))
        else:
            # Dummy input (no file): only the specification overrides.
            all_specs["backbone"] = dict(specification_overrides)

    if validate:
        for prefix, example_spec in all_specs.items():
            ranked_logger.info(
                f"Prevalidating design specification for example: {prefix}"
            )
            DesignInputSpecification.safe_init(**example_spec)

    return all_specs
|
|
523
|
+
|
|
524
|
+
|
|
525
|
+
def _reshape_trajectory(traj, align_structures: bool):
|
|
526
|
+
traj = [traj[i] for i in range(len(traj))]
|
|
527
|
+
n_steps = len(traj)
|
|
528
|
+
max_frames = 100
|
|
529
|
+
|
|
530
|
+
if align_structures:
|
|
531
|
+
# ... align the trajectories on the last prediction
|
|
532
|
+
for step in range(n_steps - 1):
|
|
533
|
+
traj[step] = weighted_rigid_align(
|
|
534
|
+
X_L=traj[-1],
|
|
535
|
+
X_gt_L=traj[step],
|
|
536
|
+
)
|
|
537
|
+
traj = traj[::-1] # reverse to go from noised -> denoised
|
|
538
|
+
if n_steps > max_frames:
|
|
539
|
+
selected_indices = torch.linspace(0, n_steps - 1, max_frames).long().tolist()
|
|
540
|
+
traj = [traj[i] for i in selected_indices]
|
|
541
|
+
|
|
542
|
+
traj = torch.stack(traj).cpu().numpy()
|
|
543
|
+
return traj
|