rc-foundry 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (180)
  1. foundry/__init__.py +57 -0
  2. foundry/callbacks/__init__.py +5 -0
  3. foundry/callbacks/callback.py +116 -0
  4. foundry/callbacks/health_logging.py +419 -0
  5. foundry/callbacks/metrics_logging.py +211 -0
  6. foundry/callbacks/timing_logging.py +67 -0
  7. foundry/callbacks/train_logging.py +278 -0
  8. foundry/common.py +108 -0
  9. foundry/constants.py +28 -0
  10. foundry/hydra/resolvers.py +77 -0
  11. foundry/inference_engines/base.py +235 -0
  12. foundry/inference_engines/checkpoint_registry.py +66 -0
  13. foundry/metrics/__init__.py +12 -0
  14. foundry/metrics/losses.py +30 -0
  15. foundry/metrics/metric.py +319 -0
  16. foundry/model/layers/blocks.py +47 -0
  17. foundry/testing/__init__.py +6 -0
  18. foundry/testing/fixtures.py +19 -0
  19. foundry/testing/pytest_hooks.py +15 -0
  20. foundry/trainers/fabric.py +923 -0
  21. foundry/training/EMA.py +67 -0
  22. foundry/training/checkpoint.py +61 -0
  23. foundry/training/schedulers.py +91 -0
  24. foundry/utils/alignment.py +86 -0
  25. foundry/utils/components.py +415 -0
  26. foundry/utils/datasets.py +405 -0
  27. foundry/utils/ddp.py +103 -0
  28. foundry/utils/instantiators.py +72 -0
  29. foundry/utils/logging.py +279 -0
  30. foundry/utils/rigid.py +1460 -0
  31. foundry/utils/rotation_augmentation.py +65 -0
  32. foundry/utils/squashfs.py +172 -0
  33. foundry/utils/torch.py +317 -0
  34. foundry/utils/weights.py +271 -0
  35. foundry/version.py +34 -0
  36. foundry_cli/__init__.py +3 -0
  37. foundry_cli/download_checkpoints.py +281 -0
  38. mpnn/__init__.py +1 -0
  39. mpnn/collate/feature_collator.py +265 -0
  40. mpnn/inference.py +53 -0
  41. mpnn/inference_engines/mpnn.py +549 -0
  42. mpnn/loss/nll_loss.py +122 -0
  43. mpnn/metrics/nll.py +369 -0
  44. mpnn/metrics/sequence_recovery.py +440 -0
  45. mpnn/model/layers/graph_embeddings.py +2372 -0
  46. mpnn/model/layers/message_passing.py +332 -0
  47. mpnn/model/layers/position_wise_feed_forward.py +44 -0
  48. mpnn/model/layers/positional_encoding.py +98 -0
  49. mpnn/model/mpnn.py +2632 -0
  50. mpnn/pipelines/mpnn.py +162 -0
  51. mpnn/samplers/samplers.py +167 -0
  52. mpnn/train.py +341 -0
  53. mpnn/trainers/mpnn.py +193 -0
  54. mpnn/transforms/feature_aggregation/mpnn.py +184 -0
  55. mpnn/transforms/feature_aggregation/polymer_ligand_interface.py +76 -0
  56. mpnn/transforms/feature_aggregation/token_encodings.py +132 -0
  57. mpnn/transforms/feature_aggregation/user_settings.py +347 -0
  58. mpnn/transforms/polymer_ligand_interface.py +164 -0
  59. mpnn/utils/inference.py +2397 -0
  60. mpnn/utils/probability.py +37 -0
  61. mpnn/utils/weights.py +309 -0
  62. rc_foundry-0.1.1.dist-info/METADATA +239 -0
  63. rc_foundry-0.1.1.dist-info/RECORD +180 -0
  64. rc_foundry-0.1.1.dist-info/WHEEL +4 -0
  65. rc_foundry-0.1.1.dist-info/entry_points.txt +5 -0
  66. rc_foundry-0.1.1.dist-info/licenses/LICENSE.md +28 -0
  67. rf3/__init__.py +3 -0
  68. rf3/_version.py +33 -0
  69. rf3/alignment.py +79 -0
  70. rf3/callbacks/dump_validation_structures.py +101 -0
  71. rf3/callbacks/metrics_logging.py +324 -0
  72. rf3/chemical.py +1529 -0
  73. rf3/cli.py +77 -0
  74. rf3/data/cyclic_transform.py +78 -0
  75. rf3/data/extra_xforms.py +36 -0
  76. rf3/data/ground_truth_template.py +463 -0
  77. rf3/data/paired_msa.py +206 -0
  78. rf3/data/pipeline_utils.py +128 -0
  79. rf3/data/pipelines.py +558 -0
  80. rf3/diffusion_samplers/inference_sampler.py +222 -0
  81. rf3/inference.py +65 -0
  82. rf3/inference_engines/__init__.py +5 -0
  83. rf3/inference_engines/rf3.py +735 -0
  84. rf3/kinematics.py +354 -0
  85. rf3/loss/af3_confidence_loss.py +515 -0
  86. rf3/loss/af3_losses.py +655 -0
  87. rf3/loss/loss.py +179 -0
  88. rf3/metrics/chiral.py +179 -0
  89. rf3/metrics/clashing_chains.py +68 -0
  90. rf3/metrics/distogram.py +421 -0
  91. rf3/metrics/lddt.py +523 -0
  92. rf3/metrics/metadata.py +43 -0
  93. rf3/metrics/metric_utils.py +192 -0
  94. rf3/metrics/predicted_error.py +134 -0
  95. rf3/metrics/rasa.py +108 -0
  96. rf3/metrics/selected_distances.py +91 -0
  97. rf3/model/RF3.py +527 -0
  98. rf3/model/RF3_blocks.py +92 -0
  99. rf3/model/RF3_structure.py +303 -0
  100. rf3/model/layers/af3_auxiliary_heads.py +255 -0
  101. rf3/model/layers/af3_diffusion_transformer.py +544 -0
  102. rf3/model/layers/attention.py +313 -0
  103. rf3/model/layers/layer_utils.py +127 -0
  104. rf3/model/layers/mlff.py +118 -0
  105. rf3/model/layers/outer_product.py +59 -0
  106. rf3/model/layers/pairformer_layers.py +783 -0
  107. rf3/model/layers/structure_bias.py +56 -0
  108. rf3/scoring.py +1787 -0
  109. rf3/symmetry/resolve.py +284 -0
  110. rf3/train.py +194 -0
  111. rf3/trainers/rf3.py +570 -0
  112. rf3/util_module.py +47 -0
  113. rf3/utils/frames.py +109 -0
  114. rf3/utils/inference.py +665 -0
  115. rf3/utils/io.py +198 -0
  116. rf3/utils/loss.py +72 -0
  117. rf3/utils/predict_and_score.py +165 -0
  118. rf3/utils/predicted_error.py +673 -0
  119. rf3/utils/recycling.py +42 -0
  120. rf3/validate.py +140 -0
  121. rfd3/.gitignore +7 -0
  122. rfd3/Makefile +76 -0
  123. rfd3/__init__.py +12 -0
  124. rfd3/callbacks.py +66 -0
  125. rfd3/cli.py +41 -0
  126. rfd3/constants.py +212 -0
  127. rfd3/engine.py +543 -0
  128. rfd3/inference/datasets.py +193 -0
  129. rfd3/inference/input_parsing.py +1123 -0
  130. rfd3/inference/legacy_input_parsing.py +717 -0
  131. rfd3/inference/parsing.py +165 -0
  132. rfd3/inference/symmetry/atom_array.py +298 -0
  133. rfd3/inference/symmetry/checks.py +241 -0
  134. rfd3/inference/symmetry/contigs.py +63 -0
  135. rfd3/inference/symmetry/frames.py +355 -0
  136. rfd3/inference/symmetry/symmetry_utils.py +398 -0
  137. rfd3/metrics/design_metrics.py +465 -0
  138. rfd3/metrics/hbonds_hbplus_metrics.py +308 -0
  139. rfd3/metrics/hbonds_metrics.py +389 -0
  140. rfd3/metrics/losses.py +325 -0
  141. rfd3/metrics/metrics_utils.py +118 -0
  142. rfd3/metrics/sidechain_metrics.py +349 -0
  143. rfd3/model/RFD3.py +105 -0
  144. rfd3/model/RFD3_diffusion_module.py +387 -0
  145. rfd3/model/cfg_utils.py +81 -0
  146. rfd3/model/inference_sampler.py +635 -0
  147. rfd3/model/layers/attention.py +577 -0
  148. rfd3/model/layers/block_utils.py +580 -0
  149. rfd3/model/layers/blocks.py +777 -0
  150. rfd3/model/layers/chunked_pairwise.py +377 -0
  151. rfd3/model/layers/encoders.py +417 -0
  152. rfd3/model/layers/layer_utils.py +197 -0
  153. rfd3/model/layers/pairformer_layers.py +128 -0
  154. rfd3/run_inference.py +45 -0
  155. rfd3/testing/debug.py +139 -0
  156. rfd3/testing/debug_utils.py +73 -0
  157. rfd3/testing/testing_utils.py +356 -0
  158. rfd3/train.py +194 -0
  159. rfd3/trainer/dump_validation_structures.py +154 -0
  160. rfd3/trainer/fabric_trainer.py +923 -0
  161. rfd3/trainer/recycling.py +42 -0
  162. rfd3/trainer/rfd3.py +485 -0
  163. rfd3/trainer/trainer_utils.py +502 -0
  164. rfd3/transforms/conditioning_base.py +508 -0
  165. rfd3/transforms/conditioning_utils.py +200 -0
  166. rfd3/transforms/design_transforms.py +807 -0
  167. rfd3/transforms/dna_crop.py +523 -0
  168. rfd3/transforms/hbonds.py +407 -0
  169. rfd3/transforms/hbonds_hbplus.py +246 -0
  170. rfd3/transforms/ncaa_transforms.py +153 -0
  171. rfd3/transforms/pipelines.py +632 -0
  172. rfd3/transforms/ppi_transforms.py +541 -0
  173. rfd3/transforms/rasa.py +116 -0
  174. rfd3/transforms/symmetry.py +76 -0
  175. rfd3/transforms/training_conditions.py +552 -0
  176. rfd3/transforms/util_transforms.py +498 -0
  177. rfd3/transforms/virtual_atoms.py +305 -0
  178. rfd3/utils/inference.py +648 -0
  179. rfd3/utils/io.py +245 -0
  180. rfd3/utils/vizualize.py +276 -0
rf3/utils/io.py ADDED
@@ -0,0 +1,198 @@
+ from os import PathLike
+ from pathlib import Path
+
+ import numpy as np
+ import torch
+ from atomworks.io.utils.io_utils import to_cif_file
+ from atomworks.ml.utils.io import apply_sharding_pattern
+ from atomworks.ml.utils.misc import hash_sequence
+ from beartype.typing import Literal
+ from biotite.structure import AtomArray, AtomArrayStack, stack
+
+ from foundry.utils.alignment import weighted_rigid_align
+ from foundry.utils.ddp import RankedLogger
+
+ ranked_logger = RankedLogger(__name__, rank_zero_only=True)
+
+ DICTIONARY_LIKE_EXTENSIONS = {".json", ".yaml", ".yml", ".pkl"}
+ CIF_LIKE_EXTENSIONS = {".cif", ".pdb", ".bcif", ".cif.gz", ".pdb.gz", ".bcif.gz"}
+
+
+ def get_sharded_output_path(
+     example_id: str,
+     base_dir: Path,
+     sharding_pattern: str | None = None,
+ ) -> Path:
+     """Get the output directory path for an example, with optional sharding.
+
+     Args:
+         example_id: Example identifier (used as the final directory name).
+         base_dir: Base output directory.
+         sharding_pattern: Sharding pattern like ``/0:2/2:4/``, or ``None`` for no sharding.
+             The pattern defines how to split the hash of ``example_id`` into nested directories.
+
+     Returns:
+         Output directory path. If sharding is enabled, returns
+         ``base_dir/shard1/shard2/.../example_id``; otherwise returns ``base_dir/example_id``.
+
+     Examples:
+         Without sharding::
+
+             get_sharded_output_path("entry_1", Path("/out"))
+             # Returns: /out/entry_1
+
+         With sharding pattern ``/0:2/2:4/``::
+
+             get_sharded_output_path("entry_1", Path("/out"), "/0:2/2:4/")
+             # Computes hash of "entry_1" (e.g., "a1b2c3d4e5f")
+             # Returns: /out/a1/b2/entry_1
+     """
+     if not sharding_pattern:
+         return base_dir / example_id
+
+     # Hash the example ID and apply the sharding pattern
+     example_hash = hash_sequence(example_id)
+     sharded_path = apply_sharding_pattern(example_hash, sharding_pattern)
+
+     # Return base_dir / sharded_directories / example_id
+     return base_dir / sharded_path.parent / example_id
+
+
+ def build_stack_from_atom_array_and_batched_coords(
+     coords: np.ndarray | torch.Tensor,
+     atom_array: AtomArray,
+ ) -> AtomArrayStack:
+     """Build an AtomArrayStack from an AtomArray and coordinates with a batch dimension.
+
+     Also handles the case where the AtomArray contains multiple transformations, in which
+     case the chain_id must be adjusted.
+
+     Args:
+         coords (np.ndarray | torch.Tensor): The coordinates to assign to the AtomArrayStack.
+             Must have shape (n_batch, n_atoms, 3).
+         atom_array (AtomArray): The AtomArray to be stacked. Must have shape (n_atoms,).
+     """
+     if isinstance(coords, torch.Tensor):
+         coords = coords.cpu().numpy()
+
+     # (The diffusion batch size will become the number of models)
+     n_batch = coords.shape[0]
+
+     # Build the stack and assign the coordinates
+     atom_array_stack = stack([atom_array for _ in range(n_batch)])
+     atom_array_stack.coord = coords
+
+     # Adjust chain_id if there are multiple transformations
+     # (Otherwise bond annotations would be ambiguous, since only `chain_id` is used for them)
+     if (
+         "transformation_id" in atom_array.get_annotation_categories()
+         and len(np.unique(atom_array_stack.transformation_id)) > 1
+     ):
+         new_chain_ids = np.char.add(
+             atom_array_stack.chain_id, atom_array_stack.transformation_id
+         )
+         atom_array_stack.set_annotation("chain_id", new_chain_ids)
+
+     return atom_array_stack
+
+
+ def dump_structures(
+     atom_arrays: AtomArrayStack | list[AtomArray] | AtomArray,
+     base_path: PathLike,
+     one_model_per_file: bool,
+     extra_fields: list[str] | Literal["all"] = [],
+     file_type: str = "cif.gz",
+ ) -> None:
+     """Dump structures to CIF files, given the coordinates and input AtomArray.
+
+     Args:
+         atom_arrays (AtomArrayStack | list[AtomArray] | AtomArray): Either an AtomArrayStack,
+             a list of AtomArray objects, or a single AtomArray object to be dumped to CIF file(s).
+         base_path (PathLike): Base path where the output files will be saved.
+         one_model_per_file (bool): Whether each model should be dumped into a separate file.
+             Has no effect if `atom_arrays` is a list of AtomArrays.
+         extra_fields (list[str] | Literal["all"]): List of extra fields to include in the CIF file.
+         file_type (str): File type for output (e.g., "cif", "cif.gz", "pdb"). Defaults to ``"cif.gz"``.
+     """
+     base_path = Path(base_path)
+
+     if one_model_per_file:
+         assert isinstance(
+             atom_arrays, (AtomArrayStack, list)
+         ), "AtomArrayStack or list of AtomArray required when one_model_per_file is True"
+         # One model per file -> loop over the diffusion batch
+         for i in range(len(atom_arrays)):
+             path = f"{base_path}_model_{i}"
+             to_cif_file(
+                 atom_arrays[i],
+                 path,
+                 file_type=file_type,
+                 include_entity_poly=False,
+                 extra_fields=extra_fields,
+             )
+     else:
+         # Include all models in a single CIF file
+         to_cif_file(
+             atom_arrays,
+             base_path,
+             file_type=file_type,
+             include_entity_poly=False,
+             extra_fields=extra_fields,
+         )
+
+
+ def dump_trajectories(
+     trajectory_list: list[torch.Tensor | np.ndarray],
+     atom_array: AtomArray,
+     base_path: Path,
+     align_structures: bool = True,
+     file_type: str = "cif.gz",
+ ) -> None:
+     """Write denoising trajectories to CIF files.
+
+     Args:
+         trajectory_list (list[torch.Tensor | np.ndarray]): List of length n_steps holding the
+             diffusion trajectory at each step. Each tensor has shape [D, L, 3], where D is the
+             diffusion batch size and L is the number of atoms.
+         atom_array (AtomArray): Atom array corresponding to the coordinates.
+         base_path (Path): Base path where the output files will be saved.
+         align_structures (bool): Whether to align the structures on the final prediction.
+             If False, each step may have a different alignment.
+         file_type (str): File type for output (e.g., "cif", "cif.gz", "pdb"). Defaults to ``"cif.gz"``.
+     """
+     n_steps = len(trajectory_list)
+
+     if align_structures:
+         # ... align the trajectories on the last prediction
+         w_L = torch.ones(*trajectory_list[0].shape[:2]).to(trajectory_list[0].device)
+         X_exists_L = torch.ones(trajectory_list[0].shape[1], dtype=torch.bool).to(
+             trajectory_list[0].device
+         )
+         for step in range(n_steps - 1):
+             trajectory_list[step] = weighted_rigid_align(
+                 X_L=trajectory_list[-1],
+                 X_gt_L=trajectory_list[step],
+                 X_exists_L=X_exists_L,
+                 w_L=w_L,
+             )
+
+     # ... invert the list, to make the trajectory compatible with PyMOL (which builds the bond graph from the first frame)
+     trajectory_list = trajectory_list[::-1]
+
+     # ... iterate over the range of D (diffusion batch size; e.g., 5 during validation)
+     # (We want to convert `trajectory_list` to a list of length D where each item is a tensor of shape [n_steps, L, 3])
+     trajectories_split_by_model = []
+     for d in range(trajectory_list[0].shape[0]):
+         trajectory_for_single_model = torch.stack(
+             [trajectory_list[step][d] for step in range(n_steps)], dim=0
+         )
+         trajectories_split_by_model.append(trajectory_for_single_model)
+
+     # ... write one trajectory file per model index (within the diffusion batch)
+     for i, trajectory in enumerate(trajectories_split_by_model):
+         if isinstance(trajectory, torch.Tensor):
+             trajectory = trajectory.cpu().numpy()
+         atom_array_stack = build_stack_from_atom_array_and_batched_coords(
+             trajectory, atom_array
+         )
+
+         path = f"{base_path}_model_{i}"
+         to_cif_file(
+             atom_array_stack, path, file_type=file_type, include_entity_poly=False
+         )
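For orientation, a minimal usage sketch tying these three helpers together. The toy AtomArray, its annotations, and all paths are illustrative, not part of the package; real inputs come from parsed CIF/PDB files, and `to_cif_file` may require richer annotations than shown:

```python
# Hypothetical usage sketch for rf3/utils/io.py; shapes and paths are illustrative only.
from pathlib import Path

import numpy as np
from biotite.structure import AtomArray

from rf3.utils.io import (
    build_stack_from_atom_array_and_batched_coords,
    dump_structures,
    get_sharded_output_path,
)

# Resolve a sharded output directory: /out/<shard1>/<shard2>/entry_1
out_dir = get_sharded_output_path("entry_1", Path("/out"), sharding_pattern="/0:2/2:4/")

# Toy 10-atom structure with minimal annotations
atom_array = AtomArray(10)
atom_array.coord = np.zeros((10, 3), dtype=np.float32)
atom_array.chain_id = np.full(10, "A")
atom_array.res_id = np.arange(1, 11)
atom_array.res_name = np.full(10, "GLY")
atom_array.atom_name = np.full(10, "CA")
atom_array.element = np.full(10, "C")

# Pretend diffusion output: D=3 models over the same atoms
coords = np.random.rand(3, 10, 3).astype(np.float32)
stack_ = build_stack_from_atom_array_and_batched_coords(coords, atom_array)

# Writes /out/.../entry_1/pred_model_0.cif.gz, pred_model_1.cif.gz, ...
dump_structures(stack_, out_dir / "pred", one_model_per_file=True)
```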
rf3/utils/loss.py ADDED
@@ -0,0 +1,72 @@
+ import torch
+
+
+ def convert_batched_losses_to_list_of_dicts(loss_dict: dict[str, torch.Tensor]) -> list[dict]:
+     """Convert a dictionary of batched and non-batched loss tensors into a list of dictionaries.
+
+     Args:
+         loss_dict (dict): A dictionary where keys are loss names and values are PyTorch tensors.
+             Some values may be batched (1D tensors), while others are not (0D tensors).
+
+     Returns:
+         list: A list of dictionaries, one per batch index, plus a final dictionary
+             holding the non-batched losses.
+
+     Example:
+         >>> outputs = {
+         ...     "loss_dict": {
+         ...         "diffusion_loss": torch.tensor([0.0509, 0.0062]),
+         ...         "smoothed_lddt_loss": torch.tensor([0.2507, 0.2797]),
+         ...         "t": torch.tensor([1.7329, 9.3498]),
+         ...         "distogram_loss": torch.tensor(1.7663),
+         ...         "total_loss": torch.tensor(1.2281),
+         ...     }
+         ... }
+         >>> convert_batched_losses_to_list_of_dicts(outputs["loss_dict"])
+         [{'batch_idx': 0, 'diffusion_loss': 0.0509, 'smoothed_lddt_loss': 0.2507, 't': 1.7329},
+          {'batch_idx': 1, 'diffusion_loss': 0.0062, 'smoothed_lddt_loss': 0.2797, 't': 9.3498},
+          {'distogram_loss': 1.7663, 'total_loss': 1.2281}]
+     """
+     result = []
+     batch_size = next((v.size(0) for v in loss_dict.values() if v.dim() == 1), 1)
+
+     # Create a dictionary for each batch index
+     for batch_idx in range(batch_size):
+         batch_dict = {"batch_idx": batch_idx}
+
+         for key, value in loss_dict.items():
+             if value.dim() == 1:  # The tensor is batched
+                 batch_dict[key] = value[batch_idx].item()
+
+         result.append(batch_dict)
+
+     # Create a dictionary for the non-batched losses
+     non_batched_dict = {}
+     for key, value in loss_dict.items():
+         if value.dim() == 0:  # The tensor is not batched
+             non_batched_dict[key] = value.item()
+
+     result.append(non_batched_dict)
+
+     return result
+
+
+ def mean_losses(loss_dict_batched: dict[str, torch.Tensor]) -> dict:
+     """Compute the mean of each tensor in a dictionary of batched losses.
+
+     Args:
+         loss_dict_batched (dict[str, torch.Tensor]): A dictionary where each key maps to a tensor of losses.
+
+     Returns:
+         dict: A dictionary with the mean loss for each key (as a Python float).
+
+     Example:
+         >>> loss_dict_batched = {"loss1": torch.tensor([0.5, 0.7]), "loss2": torch.tensor([1.0])}
+         >>> mean_losses(loss_dict_batched)
+         {'loss1': 0.6, 'loss2': 1.0}
+     """
+     loss_dict = {}
+     for key, batched_loss in loss_dict_batched.items():
+         # Compute the mean of the tensor and store it as a float
+         loss_dict[key] = batched_loss.mean().item()
+
+     return loss_dict
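A short, self-contained example of how these two helpers compose (values are illustrative):

```python
import torch

from rf3.utils.loss import convert_batched_losses_to_list_of_dicts, mean_losses

loss_dict = {
    "diffusion_loss": torch.tensor([0.0509, 0.0062]),  # batched (1D): one value per example
    "distogram_loss": torch.tensor(1.7663),            # non-batched (0D): already reduced
}

# Per-example rows plus a trailing row for the non-batched losses
rows = convert_batched_losses_to_list_of_dicts(loss_dict)
# -> [{'batch_idx': 0, 'diffusion_loss': 0.0509},
#     {'batch_idx': 1, 'diffusion_loss': 0.0062},
#     {'distogram_loss': 1.7663}]

# Scalar summary (e.g., for logging): the mean of each tensor as a float
means = mean_losses(loss_dict)
# -> {'diffusion_loss': ~0.0286, 'distogram_loss': ~1.7663}
```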
rf3/utils/predict_and_score.py ADDED
@@ -0,0 +1,165 @@
+ """Utility to run RF3 predictions and then a set of metrics on those predictions."""
+
+ from pathlib import Path
+
+ import biotite.structure as struc
+ from atomworks.ml.transforms.filters import remove_protein_terminal_oxygen
+ from atomworks.ml.utils.rng import create_rng_state_from_seeds, rng_state
+ from beartype.typing import Any
+ from rf3.inference_engines.rf3 import RF3InferenceEngine
+ from rf3.utils.inference import InferenceInput
+
+ from foundry.metrics.metric import MetricManager
+ from foundry.utils.ddp import RankedLogger
+
+ ranked_logger = RankedLogger(__name__, rank_zero_only=True)
+
+
+ def _clean_atom_array_for_rf3(atom_array: struc.AtomArray) -> struc.AtomArray:
+     """Preprocess an atom array by removing terminal oxygen and hydrogen atoms."""
+     original_count = len(atom_array)
+
+     # Remove terminal oxygen atoms
+     atom_array = remove_protein_terminal_oxygen(atom_array)
+     if len(atom_array) < original_count:
+         ranked_logger.warning(
+             f"Removed {original_count - len(atom_array)} terminal oxygen atoms. "
+             f"Atom count changed from {original_count} to {len(atom_array)}."
+         )
+     original_count = len(atom_array)
+
+     # Keep heavy atoms only (drop hydrogens)
+     atom_array = atom_array[atom_array.element != "H"]
+     if len(atom_array) < original_count:
+         ranked_logger.warning(
+             f"Removed {original_count - len(atom_array)} hydrogen atoms. "
+             f"Atom count changed from {original_count} to {len(atom_array)}."
+         )
+
+     return atom_array
+
+
+ def predict_and_score_with_rf3(
+     atom_arrays: list[struc.AtomArray],
+     ckpt_path: str | Path,
+     metrics=None,
+     n_recycles: int = 10,
+     diffusion_batch_size: int = 5,
+     num_steps: int = 50,
+     example_ids: list[str] | None = None,
+     annotate_b_factor_with_plddt: bool = True,
+     rng_seed: int = 1,
+ ) -> dict[str, dict[str, Any]]:
+     """Predict structures with RF3 and evaluate them against the inputs.
+
+     Metrics are computed using the RF3 inference engine's internal trainer,
+     which automatically handles symmetry resolution during validation.
+
+     Args:
+         atom_arrays: List of input structures (ground truth).
+         ckpt_path: Path to the RF3 checkpoint file.
+         metrics: Metrics to compute. Can be:
+             - a dict mapping names to Metric objects,
+             - a list of (name, Metric) tuples, or
+             - None (no metrics).
+         n_recycles: Number of recycles. Defaults to ``10``.
+         diffusion_batch_size: Number of structures predicted per input. Defaults to ``5``.
+         num_steps: Number of diffusion steps. Defaults to ``50``.
+         example_ids: Optional IDs for each structure. Defaults to "example_0", "example_1", etc.
+         annotate_b_factor_with_plddt: Whether to write pLDDT to the B-factor column. Defaults to ``True``.
+         rng_seed: RNG seed for reproducibility. Defaults to ``1``.
+
+     Returns:
+         Dict mapping example_id to::
+
+             {
+                 "predicted_structures": list[AtomArray] | AtomArrayStack,
+                 "metrics": dict[str, float],
+             }
+
+     Example:
+         ```python
+         metrics = [
+             ("all_atom_lddt", AllAtomLDDT()),
+             ("by_type_lddt", ByTypeLDDT()),
+         ]
+
+         results = predict_and_score_with_rf3(
+             atom_arrays=structures,
+             ckpt_path="rf3_latest.pt",
+             metrics=metrics,
+         )
+         ```
+     """
+     # Generate example IDs if not provided
+     if example_ids is None:
+         example_ids = [f"example_{i}" for i in range(len(atom_arrays))]
+
+     # Preprocess atom arrays (remove terminal oxygen and hydrogens so that atom counts match)
+     ranked_logger.info("Preprocessing atom arrays...")
+     preprocessed_arrays = [_clean_atom_array_for_rf3(arr.copy()) for arr in atom_arrays]
+
+     # Convert metrics to a MetricManager if provided
+     if metrics is not None:
+         metric_manager = MetricManager.from_metrics(metrics)
+     else:
+         # (Prediction only, no metrics)
+         metric_manager = None
+
+     # Initialize the RF3 engine (one-time) with the custom metrics
+     ranked_logger.info("Initializing RF3 inference engine...")
+     inference_engine = RF3InferenceEngine(
+         ckpt_path=ckpt_path,
+         n_recycles=n_recycles,
+         diffusion_batch_size=diffusion_batch_size,
+         num_steps=num_steps,
+         seed=None,  # We'll use an external RNG state (set below)
+         metrics_cfg=metric_manager,  # Pass the MetricManager to the engine
+     )
+
+     with rng_state(create_rng_state_from_seeds(rng_seed, rng_seed, rng_seed)):
+         results = {}
+
+         # Loop over each example
+         for example_id, ground_truth_array in zip(example_ids, preprocessed_arrays):
+             ranked_logger.info(f"Predicting structure for {example_id}...")
+
+             # Create an InferenceInput from the AtomArray
+             inference_input = InferenceInput.from_atom_array(
+                 ground_truth_array, example_id=example_id
+             )
+
+             # Run inference in-memory. The engine's trainer.validation_step() automatically:
+             #   1. Runs inference
+             #   2. Applies symmetry resolution
+             #   3. Computes the configured metrics
+             inference_results = inference_engine.run(
+                 inputs=inference_input,
+                 out_dir=None,  # Return in-memory
+                 annotate_b_factor_with_plddt=annotate_b_factor_with_plddt,
+             )
+
+             # Extract results for this example
+             result = inference_results[example_id]
+
+             # Check for early stopping (format pLDDT only when present, since applying
+             # a float format spec to the "N/A" fallback string would raise a ValueError)
+             if result.get("early_stopped", False):
+                 mean_plddt = result.get("mean_plddt")
+                 plddt_str = f"{mean_plddt:.2f}" if mean_plddt is not None else "N/A"
+                 ranked_logger.warning(
+                     f"Early stopping triggered for {example_id} (mean pLDDT = {plddt_str})"
+                 )
+                 results[example_id] = {
+                     "predicted_structures": None,
+                     "metrics": result.get("metrics", {}),
+                     "early_stopped": True,
+                 }
+                 continue
+
+             # Store results
+             results[example_id] = {
+                 "predicted_structures": result["predicted_structures"],
+                 "metrics": result.get("metrics", {}),
+             }
+
+     return results
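A hedged end-to-end sketch of calling this utility. The checkpoint name is taken from the docstring example above; the input loading and the `metrics=None` choice are illustrative, and only `predict_and_score_with_rf3` itself comes from this module:

```python
# Hypothetical driver script for rf3/utils/predict_and_score.py.
import biotite.structure.io as strucio

from rf3.utils.predict_and_score import predict_and_score_with_rf3

# Load ground-truth structures (any source of AtomArrays works)
structures = [strucio.load_structure("example.cif")]

results = predict_and_score_with_rf3(
    atom_arrays=structures,
    ckpt_path="rf3_latest.pt",  # illustrative checkpoint path
    metrics=None,               # prediction only; pass (name, Metric) pairs to also score
    example_ids=["my_example"],
)

# Walk the returned mapping: one entry per example_id
for example_id, entry in results.items():
    if entry.get("early_stopped"):
        print(f"{example_id}: early stopped, metrics={entry['metrics']}")
        continue
    print(f"{example_id}: {len(entry['predicted_structures'])} predicted models")
    for name, value in entry["metrics"].items():
        print(f"  {name}: {value}")
```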