rc-foundry 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- foundry/__init__.py +57 -0
- foundry/callbacks/__init__.py +5 -0
- foundry/callbacks/callback.py +116 -0
- foundry/callbacks/health_logging.py +419 -0
- foundry/callbacks/metrics_logging.py +211 -0
- foundry/callbacks/timing_logging.py +67 -0
- foundry/callbacks/train_logging.py +278 -0
- foundry/common.py +108 -0
- foundry/constants.py +28 -0
- foundry/hydra/resolvers.py +77 -0
- foundry/inference_engines/base.py +235 -0
- foundry/inference_engines/checkpoint_registry.py +66 -0
- foundry/metrics/__init__.py +12 -0
- foundry/metrics/losses.py +30 -0
- foundry/metrics/metric.py +319 -0
- foundry/model/layers/blocks.py +47 -0
- foundry/testing/__init__.py +6 -0
- foundry/testing/fixtures.py +19 -0
- foundry/testing/pytest_hooks.py +15 -0
- foundry/trainers/fabric.py +923 -0
- foundry/training/EMA.py +67 -0
- foundry/training/checkpoint.py +61 -0
- foundry/training/schedulers.py +91 -0
- foundry/utils/alignment.py +86 -0
- foundry/utils/components.py +415 -0
- foundry/utils/datasets.py +405 -0
- foundry/utils/ddp.py +103 -0
- foundry/utils/instantiators.py +72 -0
- foundry/utils/logging.py +279 -0
- foundry/utils/rigid.py +1460 -0
- foundry/utils/rotation_augmentation.py +65 -0
- foundry/utils/squashfs.py +172 -0
- foundry/utils/torch.py +317 -0
- foundry/utils/weights.py +271 -0
- foundry/version.py +34 -0
- foundry_cli/__init__.py +3 -0
- foundry_cli/download_checkpoints.py +281 -0
- mpnn/__init__.py +1 -0
- mpnn/collate/feature_collator.py +265 -0
- mpnn/inference.py +53 -0
- mpnn/inference_engines/mpnn.py +549 -0
- mpnn/loss/nll_loss.py +122 -0
- mpnn/metrics/nll.py +369 -0
- mpnn/metrics/sequence_recovery.py +440 -0
- mpnn/model/layers/graph_embeddings.py +2372 -0
- mpnn/model/layers/message_passing.py +332 -0
- mpnn/model/layers/position_wise_feed_forward.py +44 -0
- mpnn/model/layers/positional_encoding.py +98 -0
- mpnn/model/mpnn.py +2632 -0
- mpnn/pipelines/mpnn.py +162 -0
- mpnn/samplers/samplers.py +167 -0
- mpnn/train.py +341 -0
- mpnn/trainers/mpnn.py +193 -0
- mpnn/transforms/feature_aggregation/mpnn.py +184 -0
- mpnn/transforms/feature_aggregation/polymer_ligand_interface.py +76 -0
- mpnn/transforms/feature_aggregation/token_encodings.py +132 -0
- mpnn/transforms/feature_aggregation/user_settings.py +347 -0
- mpnn/transforms/polymer_ligand_interface.py +164 -0
- mpnn/utils/inference.py +2397 -0
- mpnn/utils/probability.py +37 -0
- mpnn/utils/weights.py +309 -0
- rc_foundry-0.1.1.dist-info/METADATA +239 -0
- rc_foundry-0.1.1.dist-info/RECORD +180 -0
- rc_foundry-0.1.1.dist-info/WHEEL +4 -0
- rc_foundry-0.1.1.dist-info/entry_points.txt +5 -0
- rc_foundry-0.1.1.dist-info/licenses/LICENSE.md +28 -0
- rf3/__init__.py +3 -0
- rf3/_version.py +33 -0
- rf3/alignment.py +79 -0
- rf3/callbacks/dump_validation_structures.py +101 -0
- rf3/callbacks/metrics_logging.py +324 -0
- rf3/chemical.py +1529 -0
- rf3/cli.py +77 -0
- rf3/data/cyclic_transform.py +78 -0
- rf3/data/extra_xforms.py +36 -0
- rf3/data/ground_truth_template.py +463 -0
- rf3/data/paired_msa.py +206 -0
- rf3/data/pipeline_utils.py +128 -0
- rf3/data/pipelines.py +558 -0
- rf3/diffusion_samplers/inference_sampler.py +222 -0
- rf3/inference.py +65 -0
- rf3/inference_engines/__init__.py +5 -0
- rf3/inference_engines/rf3.py +735 -0
- rf3/kinematics.py +354 -0
- rf3/loss/af3_confidence_loss.py +515 -0
- rf3/loss/af3_losses.py +655 -0
- rf3/loss/loss.py +179 -0
- rf3/metrics/chiral.py +179 -0
- rf3/metrics/clashing_chains.py +68 -0
- rf3/metrics/distogram.py +421 -0
- rf3/metrics/lddt.py +523 -0
- rf3/metrics/metadata.py +43 -0
- rf3/metrics/metric_utils.py +192 -0
- rf3/metrics/predicted_error.py +134 -0
- rf3/metrics/rasa.py +108 -0
- rf3/metrics/selected_distances.py +91 -0
- rf3/model/RF3.py +527 -0
- rf3/model/RF3_blocks.py +92 -0
- rf3/model/RF3_structure.py +303 -0
- rf3/model/layers/af3_auxiliary_heads.py +255 -0
- rf3/model/layers/af3_diffusion_transformer.py +544 -0
- rf3/model/layers/attention.py +313 -0
- rf3/model/layers/layer_utils.py +127 -0
- rf3/model/layers/mlff.py +118 -0
- rf3/model/layers/outer_product.py +59 -0
- rf3/model/layers/pairformer_layers.py +783 -0
- rf3/model/layers/structure_bias.py +56 -0
- rf3/scoring.py +1787 -0
- rf3/symmetry/resolve.py +284 -0
- rf3/train.py +194 -0
- rf3/trainers/rf3.py +570 -0
- rf3/util_module.py +47 -0
- rf3/utils/frames.py +109 -0
- rf3/utils/inference.py +665 -0
- rf3/utils/io.py +198 -0
- rf3/utils/loss.py +72 -0
- rf3/utils/predict_and_score.py +165 -0
- rf3/utils/predicted_error.py +673 -0
- rf3/utils/recycling.py +42 -0
- rf3/validate.py +140 -0
- rfd3/.gitignore +7 -0
- rfd3/Makefile +76 -0
- rfd3/__init__.py +12 -0
- rfd3/callbacks.py +66 -0
- rfd3/cli.py +41 -0
- rfd3/constants.py +212 -0
- rfd3/engine.py +543 -0
- rfd3/inference/datasets.py +193 -0
- rfd3/inference/input_parsing.py +1123 -0
- rfd3/inference/legacy_input_parsing.py +717 -0
- rfd3/inference/parsing.py +165 -0
- rfd3/inference/symmetry/atom_array.py +298 -0
- rfd3/inference/symmetry/checks.py +241 -0
- rfd3/inference/symmetry/contigs.py +63 -0
- rfd3/inference/symmetry/frames.py +355 -0
- rfd3/inference/symmetry/symmetry_utils.py +398 -0
- rfd3/metrics/design_metrics.py +465 -0
- rfd3/metrics/hbonds_hbplus_metrics.py +308 -0
- rfd3/metrics/hbonds_metrics.py +389 -0
- rfd3/metrics/losses.py +325 -0
- rfd3/metrics/metrics_utils.py +118 -0
- rfd3/metrics/sidechain_metrics.py +349 -0
- rfd3/model/RFD3.py +105 -0
- rfd3/model/RFD3_diffusion_module.py +387 -0
- rfd3/model/cfg_utils.py +81 -0
- rfd3/model/inference_sampler.py +635 -0
- rfd3/model/layers/attention.py +577 -0
- rfd3/model/layers/block_utils.py +580 -0
- rfd3/model/layers/blocks.py +777 -0
- rfd3/model/layers/chunked_pairwise.py +377 -0
- rfd3/model/layers/encoders.py +417 -0
- rfd3/model/layers/layer_utils.py +197 -0
- rfd3/model/layers/pairformer_layers.py +128 -0
- rfd3/run_inference.py +45 -0
- rfd3/testing/debug.py +139 -0
- rfd3/testing/debug_utils.py +73 -0
- rfd3/testing/testing_utils.py +356 -0
- rfd3/train.py +194 -0
- rfd3/trainer/dump_validation_structures.py +154 -0
- rfd3/trainer/fabric_trainer.py +923 -0
- rfd3/trainer/recycling.py +42 -0
- rfd3/trainer/rfd3.py +485 -0
- rfd3/trainer/trainer_utils.py +502 -0
- rfd3/transforms/conditioning_base.py +508 -0
- rfd3/transforms/conditioning_utils.py +200 -0
- rfd3/transforms/design_transforms.py +807 -0
- rfd3/transforms/dna_crop.py +523 -0
- rfd3/transforms/hbonds.py +407 -0
- rfd3/transforms/hbonds_hbplus.py +246 -0
- rfd3/transforms/ncaa_transforms.py +153 -0
- rfd3/transforms/pipelines.py +632 -0
- rfd3/transforms/ppi_transforms.py +541 -0
- rfd3/transforms/rasa.py +116 -0
- rfd3/transforms/symmetry.py +76 -0
- rfd3/transforms/training_conditions.py +552 -0
- rfd3/transforms/util_transforms.py +498 -0
- rfd3/transforms/virtual_atoms.py +305 -0
- rfd3/utils/inference.py +648 -0
- rfd3/utils/io.py +245 -0
- rfd3/utils/vizualize.py +276 -0
rf3/utils/io.py
ADDED
@@ -0,0 +1,198 @@
from os import PathLike
from pathlib import Path

import numpy as np
import torch
from atomworks.io.utils.io_utils import to_cif_file
from atomworks.ml.utils.io import apply_sharding_pattern
from atomworks.ml.utils.misc import hash_sequence
from beartype.typing import Literal
from biotite.structure import AtomArray, AtomArrayStack, stack

from foundry.utils.alignment import weighted_rigid_align
from foundry.utils.ddp import RankedLogger

ranked_logger = RankedLogger(__name__, rank_zero_only=True)

DICTIONARY_LIKE_EXTENSIONS = {".json", ".yaml", ".yml", ".pkl"}
CIF_LIKE_EXTENSIONS = {".cif", ".pdb", ".bcif", ".cif.gz", ".pdb.gz", ".bcif.gz"}


def get_sharded_output_path(
    example_id: str,
    base_dir: Path,
    sharding_pattern: str | None = None,
) -> Path:
    """Get output directory path for an example with optional sharding.

    Args:
        example_id: Example identifier (used as final directory name).
        base_dir: Base output directory.
        sharding_pattern: Sharding pattern like ``/0:2/2:4/`` or ``None`` for no sharding.
            Pattern defines how to split the hash of ``example_id`` into nested directories.

    Returns:
        Output directory path. If sharding is enabled, returns ``base_dir/shard1/shard2/.../example_id``.
        Otherwise returns ``base_dir/example_id``.

    Examples:
        Without sharding::

            get_sharded_output_path("entry_1", Path("/out"))
            # Returns: /out/entry_1

        With sharding pattern ``/0:2/2:4/``::

            get_sharded_output_path("entry_1", Path("/out"), "/0:2/2:4/")
            # Computes hash of "entry_1" (e.g., "a1b2c3d4e5f")
            # Returns: /out/a1/b2/entry_1
    """
    if not sharding_pattern:
        return base_dir / example_id

    # Hash the example ID and apply sharding pattern
    example_hash = hash_sequence(example_id)
    sharded_path = apply_sharding_pattern(example_hash, sharding_pattern)

    # Return base_dir / sharded_directories / example_id
    return base_dir / sharded_path.parent / example_id


def build_stack_from_atom_array_and_batched_coords(
    coords: np.ndarray | torch.Tensor,
    atom_array: AtomArray,
) -> AtomArrayStack:
    """Builds an AtomArrayStack from an AtomArray and a set of coordinates with a batch dimension.

    Additionally, handles the case where the AtomArray contains multiple transformations and we must
    adjust the chain_id.

    Args:
        coords (np.ndarray | torch.Tensor): The coordinates to be assigned to the AtomArrayStack.
            Must have shape (n_batch, n_atoms, 3).
        atom_array (AtomArray): The AtomArray to be stacked. Must have shape (n_atoms,).
    """
    if isinstance(coords, torch.Tensor):
        coords = coords.cpu().numpy()

    # (Diffusion batch size will become the number of models)
    n_batch = coords.shape[0]

    # Build the stack and assign the coordinates
    atom_array_stack = stack([atom_array for _ in range(n_batch)])
    atom_array_stack.coord = coords

    # Adjust chain_id if there are multiple transformations
    # (Otherwise, we will have ambiguous bond annotations, since only `chain_id` is used for the bond annotations)
    if (
        "transformation_id" in atom_array.get_annotation_categories()
        and len(np.unique(atom_array_stack.transformation_id)) > 1
    ):
        new_chain_ids = np.char.add(
            atom_array_stack.chain_id, atom_array_stack.transformation_id
        )
        atom_array_stack.set_annotation("chain_id", new_chain_ids)

    return atom_array_stack


def dump_structures(
    atom_arrays: AtomArrayStack | list[AtomArray] | AtomArray,
    base_path: PathLike,
    one_model_per_file: bool,
    extra_fields: list[str] | Literal["all"] = [],
    file_type: str = "cif.gz",
) -> None:
    """Dump structures to CIF files, given the coordinates and input AtomArray.

    Args:
        atom_arrays (AtomArrayStack | list[AtomArray] | AtomArray): Either an AtomArrayStack, a list of
            AtomArray objects, or a single AtomArray object to be dumped to CIF file(s).
        base_path (PathLike): Base path where the output files will be saved.
        one_model_per_file (bool): Flag to determine if each model should be dumped into a separate file.
            Has no effect if `atom_arrays` is a list of AtomArrays.
        extra_fields (list[str] | Literal["all"]): List of extra fields to include in the CIF file.
    """
    base_path = Path(base_path)

    if one_model_per_file:
        assert (
            isinstance(atom_arrays, AtomArrayStack) or isinstance(atom_arrays, list)
        ), "AtomArrayStack or list of AtomArray required when one_model_per_file is True"
        # One model per file -> loop over the diffusion batch
        for i in range(len(atom_arrays)):
            path = f"{base_path}_model_{i}"
            to_cif_file(
                atom_arrays[i],
                path,
                file_type=file_type,
                include_entity_poly=False,
                extra_fields=extra_fields,
            )
    else:
        # Include all models in a single CIF file
        to_cif_file(
            atom_arrays,
            base_path,
            file_type=file_type,
            include_entity_poly=False,
            extra_fields=extra_fields,
        )


def dump_trajectories(
    trajectory_list: list[torch.Tensor | np.ndarray],
    atom_array: AtomArray,
    base_path: Path,
    align_structures: bool = True,
    file_type: str = "cif.gz",
) -> None:
    """Write denoising trajectories to CIF files.

    Args:
        trajectory_list (list[torch.Tensor | np.ndarray]): List of tensors of length n_steps representing
            the diffusion trajectory at each step. Each tensor has shape [D, L, 3], where D is the
            diffusion batch size and L is the number of atoms.
        atom_array (AtomArray): Atom array corresponding to the coordinates.
        base_path (Path): Base path where the output files will be saved.
        align_structures (bool): Flag to determine if the structures should be aligned on the final
            prediction. If False, each step may have a different alignment.
        file_type (str): File type for output (e.g., "cif", "cif.gz", "pdb"). Defaults to ``"cif.gz"``.
    """
    n_steps = len(trajectory_list)

    if align_structures:
        # ... align the trajectories on the last prediction
        w_L = torch.ones(*trajectory_list[0].shape[:2]).to(trajectory_list[0].device)
        X_exists_L = torch.ones(trajectory_list[0].shape[1], dtype=torch.bool).to(
            trajectory_list[0].device
        )
        for step in range(n_steps - 1):
            trajectory_list[step] = weighted_rigid_align(
                X_L=trajectory_list[-1],
                X_gt_L=trajectory_list[step],
                X_exists_L=X_exists_L,
                w_L=w_L,
            )

    # ... invert the list, to make the trajectory compatible with PyMOL (which builds the bond graph from the first frame)
    trajectory_list = trajectory_list[::-1]

    # ... iterate over the range of D (diffusion batch size; e.g., 5 during validation)
    # (We want to convert `trajectory_list` to a list of length D where each item is a tensor of shape [n_steps, L, 3])
    trajectories_split_by_model = []
    for d in range(trajectory_list[0].shape[0]):
        trajectory_for_single_model = torch.stack(
            [trajectory_list[step][d] for step in range(n_steps)], dim=0
        )
        trajectories_split_by_model.append(trajectory_for_single_model)

    # ... write the trajectories to CIF files, named by epoch, dataset, example_id, and model index (within the diffusion batch)
    for i, trajectory in enumerate(trajectories_split_by_model):
        if isinstance(trajectory, torch.Tensor):
            trajectory = trajectory.cpu().numpy()
        atom_array_stack = build_stack_from_atom_array_and_batched_coords(
            trajectory, atom_array
        )

        path = f"{base_path}_model_{i}"
        to_cif_file(
            atom_array_stack, path, file_type=file_type, include_entity_poly=False
        )
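
A minimal usage sketch of the writers above; the input file name, example ID, and output directory are placeholders, not part of the package:

```python
from pathlib import Path

import biotite.structure.io as strucio

from rf3.utils.io import dump_structures, get_sharded_output_path

# Resolve a sharded output directory for a hypothetical example ID
out_dir = get_sharded_output_path("entry_1", Path("/out"), sharding_pattern="/0:2/2:4/")
out_dir.mkdir(parents=True, exist_ok=True)

# Load a structure (placeholder file) and write each model to its own CIF file
atom_array = strucio.load_structure("input.cif")
dump_structures(
    [atom_array],                   # list of AtomArray -> one file per model
    base_path=out_dir / "entry_1",  # files land at <base_path>_model_<i>
    one_model_per_file=True,
)
```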
rf3/utils/loss.py
ADDED
@@ -0,0 +1,72 @@
import torch


def convert_batched_losses_to_list_of_dicts(loss_dict: dict[str, torch.Tensor]):
    """Converts a dictionary of batched and non-batched loss tensors into a list of dictionaries.

    Args:
        loss_dict (dict): A dictionary where keys are loss names and values are PyTorch tensors.
            Some values may be batched (1D tensors), while others are not (0D tensors).

    Returns:
        list: A list of dictionaries, one per batch element, followed by a single dictionary of the
            non-batched losses.

    Example:
        >>> outputs = {
        ...     "loss_dict": {
        ...         "diffusion_loss": torch.tensor([0.0509, 0.0062]),
        ...         "smoothed_lddt_loss": torch.tensor([0.2507, 0.2797]),
        ...         "t": torch.tensor([1.7329, 9.3498]),
        ...         "distogram_loss": torch.tensor(1.7663),
        ...         "total_loss": torch.tensor(1.2281),
        ...     }
        ... }
        >>> convert_batched_losses_to_list_of_dicts(outputs["loss_dict"])
        [{'batch_idx': 0, 'diffusion_loss': 0.0509, 'smoothed_lddt_loss': 0.2507, 't': 1.7329},
         {'batch_idx': 1, 'diffusion_loss': 0.0062, 'smoothed_lddt_loss': 0.2797, 't': 9.3498},
         {'distogram_loss': 1.7663, 'total_loss': 1.2281}]
    """
    result = []
    batch_size = next((v.size(0) for v in loss_dict.values() if v.dim() == 1), 1)

    # Create a dictionary for each batch index
    for batch_idx in range(batch_size):
        batch_dict = {"batch_idx": batch_idx}

        for key, value in loss_dict.items():
            if value.dim() == 1:  # Check if the tensor is batched
                batch_dict[key] = value[batch_idx].item()

        result.append(batch_dict)

    # Create a dictionary for non-batched losses
    non_batched_dict = {}
    for key, value in loss_dict.items():
        if value.dim() == 0:  # Check if the tensor is not batched
            non_batched_dict[key] = value.item()

    result.append(non_batched_dict)

    return result


def mean_losses(loss_dict_batched: dict[str, torch.Tensor]) -> dict:
    """Compute the mean of each tensor in a dictionary of batched losses.

    Args:
        loss_dict_batched (dict[str, torch.Tensor]): A dictionary where each key maps to a tensor of losses.

    Returns:
        dict: A dictionary with the mean loss for each key (as a float).

    Example:
        >>> loss_dict_batched = {"loss1": torch.tensor([0.5, 0.7]), "loss2": torch.tensor([1.0])}
        >>> mean_losses(loss_dict_batched)
        {'loss1': 0.6, 'loss2': 1.0}
    """
    loss_dict = {}
    for key, batched_loss in loss_dict_batched.items():
        # Compute the mean of the tensor and store it in the dictionary
        loss_dict[key] = batched_loss.mean().item()

    return loss_dict
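
A short sketch of how the two helpers compose when logging; the tensor values follow the docstring example, and nothing here is additional package API:

```python
import torch

from rf3.utils.loss import convert_batched_losses_to_list_of_dicts, mean_losses

loss_dict = {
    "diffusion_loss": torch.tensor([0.0509, 0.0062]),  # batched: one value per diffusion sample
    "total_loss": torch.tensor(1.2281),                # already reduced to a scalar
}

# Per-sample rows (e.g., for a metrics table), plus one final row of non-batched losses
rows = convert_batched_losses_to_list_of_dicts(loss_dict)

# Scalar summary for logging: {"diffusion_loss": ~0.0286, "total_loss": ~1.2281}
summary = mean_losses(loss_dict)
```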
rf3/utils/predict_and_score.py
ADDED
@@ -0,0 +1,165 @@
"""Utility to run RF3 predictions and then a set of metrics on those predictions."""

from pathlib import Path

import biotite.structure as struc
from atomworks.ml.transforms.filters import remove_protein_terminal_oxygen
from atomworks.ml.utils.rng import create_rng_state_from_seeds, rng_state
from beartype.typing import Any
from rf3.inference_engines.rf3 import RF3InferenceEngine
from rf3.utils.inference import InferenceInput

from foundry.metrics.metric import MetricManager
from foundry.utils.ddp import RankedLogger

ranked_logger = RankedLogger(__name__, rank_zero_only=True)


def _clean_atom_array_for_rf3(atom_array: struc.AtomArray) -> struc.AtomArray:
    """Preprocess atom array by removing terminal oxygen and hydrogens."""
    original_count = len(atom_array)

    # Remove terminal oxygen atoms
    atom_array = remove_protein_terminal_oxygen(atom_array)
    if len(atom_array) < original_count:
        ranked_logger.warning(
            f"Removed {original_count - len(atom_array)} terminal oxygen atoms. "
            f"Atom count changed from {original_count} to {len(atom_array)}."
        )
        original_count = len(atom_array)

    # Filter to heavy atoms only (no hydrogen)
    atom_array = atom_array[atom_array.element != "H"]
    if len(atom_array) < original_count:
        ranked_logger.warning(
            f"Removed {original_count - len(atom_array)} hydrogen atoms. "
            f"Atom count changed from {original_count} to {len(atom_array)}."
        )

    return atom_array


def predict_and_score_with_rf3(
    atom_arrays: list[struc.AtomArray],
    ckpt_path: str | Path,
    metrics=None,
    n_recycles: int = 10,
    diffusion_batch_size: int = 5,
    num_steps: int = 50,
    example_ids: list[str] | None = None,
    annotate_b_factor_with_plddt: bool = True,
    rng_seed: int = 1,
) -> dict[str, dict[str, Any]]:
    """Predict structures with RF3 and evaluate against inputs.

    Metrics are computed using the RF3 inference engine's internal trainer,
    which automatically handles symmetry resolution during validation.

    Args:
        atom_arrays: List of input structures (ground truth).
        ckpt_path: Path to RF3 checkpoint file.
        metrics: Metrics to compute. Can be:
            - Dict mapping names to Metric objects
            - List of (name, Metric) tuples
            - None (no metrics)
        n_recycles: Number of recycles. Defaults to ``10``.
        diffusion_batch_size: Number of structures per input. Defaults to ``5``.
        num_steps: Number of diffusion steps. Defaults to ``50``.
        example_ids: Optional IDs for each structure. Defaults to "example_0", "example_1", etc.
        annotate_b_factor_with_plddt: Whether to write pLDDT to B-factor. Defaults to ``True``.
        rng_seed: RNG seed for reproducibility. Defaults to ``1``.

    Returns:
        Dict mapping example_id to::

            {
                "predicted_structures": list[AtomArray] | AtomArrayStack,
                "metrics": dict[str, float],
            }

    Example:
        ```python
        metrics = [
            ("all_atom_lddt", AllAtomLDDT()),
            ("by_type_lddt", ByTypeLDDT()),
        ]

        results = predict_and_score_with_rf3(
            atom_arrays=structures,
            ckpt_path="rf3_latest.pt",
            metrics=metrics,
        )
        ```
    """
    # Generate example IDs if not provided
    if example_ids is None:
        example_ids = [f"example_{i}" for i in range(len(atom_arrays))]

    # Preprocess atom arrays (remove terminal oxygen and hydrogens so that atom counts match)
    ranked_logger.info("Preprocessing atom arrays...")
    preprocessed_arrays = [_clean_atom_array_for_rf3(arr.copy()) for arr in atom_arrays]

    # Convert metrics to MetricManager if provided
    if metrics is not None:
        metric_manager = MetricManager.from_metrics(metrics)
    else:
        # (Prediction only, no metrics)
        metric_manager = None

    # Initialize RF3 engine (one-time) with custom metrics
    ranked_logger.info("Initializing RF3 inference engine...")
    inference_engine = RF3InferenceEngine(
        ckpt_path=ckpt_path,
        n_recycles=n_recycles,
        diffusion_batch_size=diffusion_batch_size,
        num_steps=num_steps,
        seed=None,  # We'll use external RNG state (set below)
        metrics_cfg=metric_manager,  # Pass MetricManager to engine
    )

    with rng_state(create_rng_state_from_seeds(rng_seed, rng_seed, rng_seed)):
        results = {}

        # Loop over each example
        for example_id, ground_truth_array in zip(example_ids, preprocessed_arrays):
            ranked_logger.info(f"Predicting structure for {example_id}...")

            # Create InferenceInput from AtomArray
            inference_input = InferenceInput.from_atom_array(
                ground_truth_array, example_id=example_id
            )

            # Run inference in-memory
            # The engine's trainer.validation_step() automatically:
            # 1. Runs inference
            # 2. Applies symmetry resolution
            # 3. Computes configured metrics
            inference_results = inference_engine.run(
                inputs=inference_input,
                out_dir=None,  # Return in-memory
                annotate_b_factor_with_plddt=annotate_b_factor_with_plddt,
            )

            # Extract results for this example
            result = inference_results[example_id]

            # Check for early stopping
            if result.get("early_stopped", False):
                ranked_logger.warning(
                    f"Early stopping triggered for {example_id} "
                    f"(mean pLDDT = {result.get('mean_plddt', 'N/A'):.2f})"
                )
                results[example_id] = {
                    "predicted_structures": None,
                    "metrics": result.get("metrics", {}),
                    "early_stopped": True,
                }
                continue

            # Store results
            results[example_id] = {
                "predicted_structures": result["predicted_structures"],
                "metrics": result.get("metrics", {}),
            }

    return results
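
A prediction-only call might look like the following sketch. The structure file and checkpoint path are placeholders, and the import path assumes this hunk is `rf3/utils/predict_and_score.py`, inferred from the file listing above:

```python
import biotite.structure.io as strucio

from rf3.utils.predict_and_score import predict_and_score_with_rf3

# Hypothetical ground-truth structure and checkpoint
structures = [strucio.load_structure("7xyz.cif")]

results = predict_and_score_with_rf3(
    atom_arrays=structures,
    ckpt_path="rf3_latest.pt",
    metrics=None,            # prediction only; pass (name, Metric) pairs to also score
    example_ids=["7xyz"],
)

predicted = results["7xyz"]["predicted_structures"]
```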