rc_foundry-0.1.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (180)
  1. foundry/__init__.py +57 -0
  2. foundry/callbacks/__init__.py +5 -0
  3. foundry/callbacks/callback.py +116 -0
  4. foundry/callbacks/health_logging.py +419 -0
  5. foundry/callbacks/metrics_logging.py +211 -0
  6. foundry/callbacks/timing_logging.py +67 -0
  7. foundry/callbacks/train_logging.py +278 -0
  8. foundry/common.py +108 -0
  9. foundry/constants.py +28 -0
  10. foundry/hydra/resolvers.py +77 -0
  11. foundry/inference_engines/base.py +235 -0
  12. foundry/inference_engines/checkpoint_registry.py +66 -0
  13. foundry/metrics/__init__.py +12 -0
  14. foundry/metrics/losses.py +30 -0
  15. foundry/metrics/metric.py +319 -0
  16. foundry/model/layers/blocks.py +47 -0
  17. foundry/testing/__init__.py +6 -0
  18. foundry/testing/fixtures.py +19 -0
  19. foundry/testing/pytest_hooks.py +15 -0
  20. foundry/trainers/fabric.py +923 -0
  21. foundry/training/EMA.py +67 -0
  22. foundry/training/checkpoint.py +61 -0
  23. foundry/training/schedulers.py +91 -0
  24. foundry/utils/alignment.py +86 -0
  25. foundry/utils/components.py +415 -0
  26. foundry/utils/datasets.py +405 -0
  27. foundry/utils/ddp.py +103 -0
  28. foundry/utils/instantiators.py +72 -0
  29. foundry/utils/logging.py +279 -0
  30. foundry/utils/rigid.py +1460 -0
  31. foundry/utils/rotation_augmentation.py +65 -0
  32. foundry/utils/squashfs.py +172 -0
  33. foundry/utils/torch.py +317 -0
  34. foundry/utils/weights.py +271 -0
  35. foundry/version.py +34 -0
  36. foundry_cli/__init__.py +3 -0
  37. foundry_cli/download_checkpoints.py +281 -0
  38. mpnn/__init__.py +1 -0
  39. mpnn/collate/feature_collator.py +265 -0
  40. mpnn/inference.py +53 -0
  41. mpnn/inference_engines/mpnn.py +549 -0
  42. mpnn/loss/nll_loss.py +122 -0
  43. mpnn/metrics/nll.py +369 -0
  44. mpnn/metrics/sequence_recovery.py +440 -0
  45. mpnn/model/layers/graph_embeddings.py +2372 -0
  46. mpnn/model/layers/message_passing.py +332 -0
  47. mpnn/model/layers/position_wise_feed_forward.py +44 -0
  48. mpnn/model/layers/positional_encoding.py +98 -0
  49. mpnn/model/mpnn.py +2632 -0
  50. mpnn/pipelines/mpnn.py +162 -0
  51. mpnn/samplers/samplers.py +167 -0
  52. mpnn/train.py +341 -0
  53. mpnn/trainers/mpnn.py +193 -0
  54. mpnn/transforms/feature_aggregation/mpnn.py +184 -0
  55. mpnn/transforms/feature_aggregation/polymer_ligand_interface.py +76 -0
  56. mpnn/transforms/feature_aggregation/token_encodings.py +132 -0
  57. mpnn/transforms/feature_aggregation/user_settings.py +347 -0
  58. mpnn/transforms/polymer_ligand_interface.py +164 -0
  59. mpnn/utils/inference.py +2397 -0
  60. mpnn/utils/probability.py +37 -0
  61. mpnn/utils/weights.py +309 -0
  62. rc_foundry-0.1.1.dist-info/METADATA +239 -0
  63. rc_foundry-0.1.1.dist-info/RECORD +180 -0
  64. rc_foundry-0.1.1.dist-info/WHEEL +4 -0
  65. rc_foundry-0.1.1.dist-info/entry_points.txt +5 -0
  66. rc_foundry-0.1.1.dist-info/licenses/LICENSE.md +28 -0
  67. rf3/__init__.py +3 -0
  68. rf3/_version.py +33 -0
  69. rf3/alignment.py +79 -0
  70. rf3/callbacks/dump_validation_structures.py +101 -0
  71. rf3/callbacks/metrics_logging.py +324 -0
  72. rf3/chemical.py +1529 -0
  73. rf3/cli.py +77 -0
  74. rf3/data/cyclic_transform.py +78 -0
  75. rf3/data/extra_xforms.py +36 -0
  76. rf3/data/ground_truth_template.py +463 -0
  77. rf3/data/paired_msa.py +206 -0
  78. rf3/data/pipeline_utils.py +128 -0
  79. rf3/data/pipelines.py +558 -0
  80. rf3/diffusion_samplers/inference_sampler.py +222 -0
  81. rf3/inference.py +65 -0
  82. rf3/inference_engines/__init__.py +5 -0
  83. rf3/inference_engines/rf3.py +735 -0
  84. rf3/kinematics.py +354 -0
  85. rf3/loss/af3_confidence_loss.py +515 -0
  86. rf3/loss/af3_losses.py +655 -0
  87. rf3/loss/loss.py +179 -0
  88. rf3/metrics/chiral.py +179 -0
  89. rf3/metrics/clashing_chains.py +68 -0
  90. rf3/metrics/distogram.py +421 -0
  91. rf3/metrics/lddt.py +523 -0
  92. rf3/metrics/metadata.py +43 -0
  93. rf3/metrics/metric_utils.py +192 -0
  94. rf3/metrics/predicted_error.py +134 -0
  95. rf3/metrics/rasa.py +108 -0
  96. rf3/metrics/selected_distances.py +91 -0
  97. rf3/model/RF3.py +527 -0
  98. rf3/model/RF3_blocks.py +92 -0
  99. rf3/model/RF3_structure.py +303 -0
  100. rf3/model/layers/af3_auxiliary_heads.py +255 -0
  101. rf3/model/layers/af3_diffusion_transformer.py +544 -0
  102. rf3/model/layers/attention.py +313 -0
  103. rf3/model/layers/layer_utils.py +127 -0
  104. rf3/model/layers/mlff.py +118 -0
  105. rf3/model/layers/outer_product.py +59 -0
  106. rf3/model/layers/pairformer_layers.py +783 -0
  107. rf3/model/layers/structure_bias.py +56 -0
  108. rf3/scoring.py +1787 -0
  109. rf3/symmetry/resolve.py +284 -0
  110. rf3/train.py +194 -0
  111. rf3/trainers/rf3.py +570 -0
  112. rf3/util_module.py +47 -0
  113. rf3/utils/frames.py +109 -0
  114. rf3/utils/inference.py +665 -0
  115. rf3/utils/io.py +198 -0
  116. rf3/utils/loss.py +72 -0
  117. rf3/utils/predict_and_score.py +165 -0
  118. rf3/utils/predicted_error.py +673 -0
  119. rf3/utils/recycling.py +42 -0
  120. rf3/validate.py +140 -0
  121. rfd3/.gitignore +7 -0
  122. rfd3/Makefile +76 -0
  123. rfd3/__init__.py +12 -0
  124. rfd3/callbacks.py +66 -0
  125. rfd3/cli.py +41 -0
  126. rfd3/constants.py +212 -0
  127. rfd3/engine.py +543 -0
  128. rfd3/inference/datasets.py +193 -0
  129. rfd3/inference/input_parsing.py +1123 -0
  130. rfd3/inference/legacy_input_parsing.py +717 -0
  131. rfd3/inference/parsing.py +165 -0
  132. rfd3/inference/symmetry/atom_array.py +298 -0
  133. rfd3/inference/symmetry/checks.py +241 -0
  134. rfd3/inference/symmetry/contigs.py +63 -0
  135. rfd3/inference/symmetry/frames.py +355 -0
  136. rfd3/inference/symmetry/symmetry_utils.py +398 -0
  137. rfd3/metrics/design_metrics.py +465 -0
  138. rfd3/metrics/hbonds_hbplus_metrics.py +308 -0
  139. rfd3/metrics/hbonds_metrics.py +389 -0
  140. rfd3/metrics/losses.py +325 -0
  141. rfd3/metrics/metrics_utils.py +118 -0
  142. rfd3/metrics/sidechain_metrics.py +349 -0
  143. rfd3/model/RFD3.py +105 -0
  144. rfd3/model/RFD3_diffusion_module.py +387 -0
  145. rfd3/model/cfg_utils.py +81 -0
  146. rfd3/model/inference_sampler.py +635 -0
  147. rfd3/model/layers/attention.py +577 -0
  148. rfd3/model/layers/block_utils.py +580 -0
  149. rfd3/model/layers/blocks.py +777 -0
  150. rfd3/model/layers/chunked_pairwise.py +377 -0
  151. rfd3/model/layers/encoders.py +417 -0
  152. rfd3/model/layers/layer_utils.py +197 -0
  153. rfd3/model/layers/pairformer_layers.py +128 -0
  154. rfd3/run_inference.py +45 -0
  155. rfd3/testing/debug.py +139 -0
  156. rfd3/testing/debug_utils.py +73 -0
  157. rfd3/testing/testing_utils.py +356 -0
  158. rfd3/train.py +194 -0
  159. rfd3/trainer/dump_validation_structures.py +154 -0
  160. rfd3/trainer/fabric_trainer.py +923 -0
  161. rfd3/trainer/recycling.py +42 -0
  162. rfd3/trainer/rfd3.py +485 -0
  163. rfd3/trainer/trainer_utils.py +502 -0
  164. rfd3/transforms/conditioning_base.py +508 -0
  165. rfd3/transforms/conditioning_utils.py +200 -0
  166. rfd3/transforms/design_transforms.py +807 -0
  167. rfd3/transforms/dna_crop.py +523 -0
  168. rfd3/transforms/hbonds.py +407 -0
  169. rfd3/transforms/hbonds_hbplus.py +246 -0
  170. rfd3/transforms/ncaa_transforms.py +153 -0
  171. rfd3/transforms/pipelines.py +632 -0
  172. rfd3/transforms/ppi_transforms.py +541 -0
  173. rfd3/transforms/rasa.py +116 -0
  174. rfd3/transforms/symmetry.py +76 -0
  175. rfd3/transforms/training_conditions.py +552 -0
  176. rfd3/transforms/util_transforms.py +498 -0
  177. rfd3/transforms/virtual_atoms.py +305 -0
  178. rfd3/utils/inference.py +648 -0
  179. rfd3/utils/io.py +245 -0
  180. rfd3/utils/vizualize.py +276 -0
rfd3/engine.py ADDED
@@ -0,0 +1,543 @@
+ import json
+ import logging
+ import os
+ import time
+ from dataclasses import dataclass, field
+ from os import PathLike
+ from pathlib import Path
+ from typing import Any, Dict, List, Optional
+
+ import torch
+ import yaml
+ from atomworks.io.utils.io_utils import to_cif_file
+ from biotite.structure import AtomArray, AtomArrayStack
+ from toolz import merge_with
+
+ from foundry.common import exists
+ from foundry.inference_engines.base import BaseInferenceEngine
+ from foundry.inference_engines.checkpoint_registry import REGISTERED_CHECKPOINTS
+ from foundry.utils.alignment import weighted_rigid_align
+ from foundry.utils.ddp import RankedLogger
+ from rfd3.constants import SAVED_CONDITIONING_ANNOTATIONS
+ from rfd3.inference.datasets import (
+     assemble_distributed_inference_loader_from_json,
+ )
+ from rfd3.inference.input_parsing import DesignInputSpecification
+ from rfd3.model.inference_sampler import SampleDiffusionConfig
+ from rfd3.utils.inference import ensure_input_is_abspath
+ from rfd3.utils.io import (
+     CIF_LIKE_EXTENSIONS,
+     build_stack_from_atom_array_and_batched_coords,
+     extract_example_id_from_path,
+     find_files_with_extension,
+ )
+
+ logging.basicConfig(level=logging.INFO)
+ ranked_logger = RankedLogger(__name__, rank_zero_only=True)
+
+
+ @dataclass(kw_only=True)
+ class RFD3InferenceConfig:
+     ckpt_path: str | Path = "rfd3"  # Defaults to foundry installation upon instantiation
+     diffusion_batch_size: int = 16
+
+     # RFD3 specific
+     skip_existing: bool = True
+     json_keys_subset: Optional[List[str]] = None
+     specification: Optional[dict] = field(default_factory=dict)
+     inference_sampler: SampleDiffusionConfig | dict = field(default_factory=dict)
+
+     # Saving args
+     cleanup_guideposts: bool = True
+     cleanup_virtual_atoms: bool = True
+     read_sequence_from_sequence_head: bool = True
+     output_full_json: bool = True
+
+     # Prefix to add to all output samples
+     # Default: None -> f'{jsonfilebasename}_{jsonkey}_{batch}_{model}'
+     # Otherwise: string -> f'{string}{jsonkey}_{batch}_{model}'
+     # e.g. Empty string -> f'{jsonkey}_{batch}_{model}'
+     # e.g. Chunk string -> f'{chunkprefix_}{jsonkey}_{batch}_{model}' (pipelines usage)
+     global_prefix: Optional[str] = None
+     dump_prediction_metadata_json: bool = True
+     dump_trajectories: bool = False
+     align_trajectory_structures: bool = False
+     prevalidate_inputs: bool = True
+     low_memory_mode: bool = (
+         False  # False for standard mode, True for memory efficient tokenization mode
+     )
+
+     # Other:
+     num_nodes: int = 1
+     devices_per_node: int = 1
+     verbose: bool = False
+     seed: Optional[int] = None
+
+     # For use as a mapping:
+     def keys(self):
+         return self.__dataclass_fields__.keys()
+
+     def __getitem__(self, key):
+         return getattr(self, key)
+
+
+ @dataclass
+ class RFD3Output:
+     atom_array: AtomArray
+     metadata: dict
+     example_id: str
+     denoised_trajectory_stack: Optional[AtomArrayStack] = None
+     noisy_trajectory_stack: Optional[AtomArrayStack] = None
+
+     def dump(
+         self,
+         out_dir,
+         verbose=True,
+     ):
+         base_path = os.path.join(out_dir, self.example_id)
+         base_path = Path(base_path).absolute()
+         to_cif_file(
+             self.atom_array,
+             base_path,
+             file_type="cif.gz",
+             include_entity_poly=False,
+             extra_fields=SAVED_CONDITIONING_ANNOTATIONS,
+         )
+         if self.metadata:
+             with open(f"{base_path}.json", "w") as f:
+                 json.dump(self.metadata, f, indent=4)
+
+         # Trajectory saving: split "<prefix>_model_<idx>" so trajectories can be
+         # written as "<prefix>_denoised_model_<idx>" / "<prefix>_noisy_model_<idx>".
+         # (rpartition avoids the character-set pitfall of str.rstrip("_model_").)
+         prefix, _, suffix = str(base_path).rpartition("_model_")
+         if self.denoised_trajectory_stack is not None:
+             to_cif_file(
+                 self.denoised_trajectory_stack,
+                 "_denoised_model_".join([prefix, suffix]),
+                 file_type="cif.gz",
+                 include_entity_poly=False,
+             )
+
+         if self.noisy_trajectory_stack is not None:
+             to_cif_file(
+                 self.noisy_trajectory_stack,
+                 "_noisy_model_".join([prefix, suffix]),
+                 file_type="cif.gz",
+                 include_entity_poly=False,
+             )
+
+         if verbose:
+             ranked_logger.info(f"Outputs for {self.example_id} written to {base_path}.")
+
+
+ class RFD3InferenceEngine(BaseInferenceEngine):
+     """Inference engine for RFdiffusion3"""
+
+     def __init__(
+         self,
+         *,
+         # Default input handling args
+         skip_existing: bool,
+         json_keys_subset: None | List[str],
+         prevalidate_inputs: bool,
+         # Base inference engine args
+         diffusion_batch_size: int,
+         inference_sampler: dict,
+         specification: dict | None,
+         # Structure dumping arguments
+         global_prefix: str | None,
+         cleanup_guideposts: bool,
+         cleanup_virtual_atoms: bool,
+         read_sequence_from_sequence_head: bool,
+         output_full_json: bool,
+         dump_prediction_metadata_json: bool,
+         dump_trajectories: bool,
+         align_trajectory_structures: bool,
+         low_memory_mode: bool,
+         **kwargs,
+     ):
+         super().__init__(
+             transform_overrides={"diffusion_batch_size": diffusion_batch_size},
+             inference_sampler_overrides={**inference_sampler},
+             trainer_overrides={
+                 "cleanup_guideposts": cleanup_guideposts,
+                 "cleanup_virtual_atoms": cleanup_virtual_atoms,
+                 "read_sequence_from_sequence_head": read_sequence_from_sequence_head,
+                 "output_full_json": output_full_json,
+             },
+             **kwargs,
+         )
+         # Save specification overrides
+         self.specification_overrides = dict(specification or {})
+
+         # Setup output directories and args
+         self.global_prefix = global_prefix
+         self.json_keys_subset = json_keys_subset
+         self.prevalidate_inputs = prevalidate_inputs
+         self.skip_existing = skip_existing
+
+         # Saving / other args
+         self.dump_prediction_metadata_json = dump_prediction_metadata_json
+         self.dump_trajectories = dump_trajectories
+         self.align_trajectory_structures = align_trajectory_structures
+         if not cleanup_guideposts:
+             ranked_logger.warning(
+                 "Guideposts will not be cleaned up. This is intended for debugging purposes."
+             )
+         if not cleanup_virtual_atoms:
+             ranked_logger.warning(
+                 "Virtual atoms will not be cleaned up. Some tools like MPNN may run, but outputs will not be like native structures."
+             )
+
+         if low_memory_mode:
+             ranked_logger.info("Low memory mode enabled.")
+             # HACK: signal low-memory mode to the diffusion module via an env var
+             os.environ["RFD3_LOW_MEMORY_MODE"] = "1"
+
+     def run(
+         self,
+         *,
+         inputs: str | PathLike | AtomArray | DesignInputSpecification,
+         n_batches: int | None = None,
+         out_dir: str | PathLike | None = None,
+     ):
+         self._set_out_dir(out_dir)
+         inputs = self._canonicalize_inputs(inputs)
+         design_specifications = self._multiply_specifications(
+             inputs=inputs,
+             n_batches=n_batches,
+         )
+         # Initialize the model and trainer before running inference
+         self.initialize()
+         outputs = self._run_multi(design_specifications)
+         return outputs
+
+     def _set_out_dir(self, out_dir: str | PathLike | None):
+         out_dir = Path(out_dir) if out_dir else None
+         if out_dir:
+             out_dir.mkdir(parents=True, exist_ok=True)
+             ranked_logger.info(f"Outputs will be written to {out_dir.resolve()}.")
+         self.out_dir = out_dir
+
+     def _run_multi(self, specs) -> None | Dict[str, List[RFD3Output]]:
+         # ==============================================================================
+         # Prepare pipeline and inference loader
+         # ==============================================================================
+         loader = assemble_distributed_inference_loader_from_json(
+             # Passed directly to ContigJSONDataset
+             data=specs,
+             transform=self.pipeline,
+             name="inference-dataset",
+             cif_parser_args=None,
+             subset_to_keys=None,
+             eval_every_n=1,
+             # Sampler args
+             world_size=self.trainer.fabric.world_size,
+             rank=self.trainer.fabric.global_rank,
+         )
+         loader = self.trainer.fabric.setup_dataloaders(
+             loader,
+             use_distributed_sampler=False,
+         )
+
+         # ==============================================================================
+         # Evaluate, using `validation_step`
+         # ==============================================================================
+         outputs = {}
+         for batch_idx, batch in enumerate(loader):
+             pipeline_output = batch[0]
+             example_id = pipeline_output["example_id"]
+
+             # Run model
+             output_list = self._model_forward(pipeline_output)
+             if self.out_dir:
+                 for output in output_list:
+                     output.dump(out_dir=self.out_dir)
+             else:
+                 outputs[example_id] = output_list
+         return outputs
+
+     def _model_forward(self, pipeline_output) -> List[RFD3Output]:
+         # Wraps around the trainer validation step to create atom arrays for saving.
+         t0 = time.time()
+         with torch.no_grad():
+             pipeline_output = self.trainer.fabric.to_device(pipeline_output)
+             output_val = self.trainer.validation_step(
+                 batch=pipeline_output,
+                 batch_idx=0,
+                 compute_metrics=False,
+             )
+         t_end = time.time()
+
+         # Stack per-step trajectories if requested
+         if self.dump_trajectories:
+             X_noisy_L_traj = torch.stack(
+                 output_val["network_output"]["X_noisy_L_traj"]
+             ).transpose(0, 1)  # [D, N_steps, L, 3]
+             X_denoised_L_traj = torch.stack(
+                 output_val["network_output"]["X_denoised_L_traj"]
+             ).transpose(0, 1)  # [D, N_steps, L, 3]
+
+         outputs = []
+         for idx in range(len(output_val["predicted_atom_array_stack"])):
+             # Add additional information to prediction metadata
+             if self.dump_prediction_metadata_json:
+                 ckpt = Path(self.ckpt_path)
+                 if ckpt.is_symlink():
+                     ckpt = ckpt.resolve(strict=True)  # follow symlink to target
+                 output_val["prediction_metadata"][idx]["ckpt_path"] = str(ckpt)
+                 output_val["prediction_metadata"][idx]["seed"] = self.seed
+
+             # Append to outputs (note: denoised coords come from the denoised
+             # trajectory and noisy coords from the noisy one; the two were
+             # previously crossed)
+             if self.dump_trajectories:
+                 X_denoised_L_traj_i = _reshape_trajectory(
+                     X_denoised_L_traj[idx], self.align_trajectory_structures
+                 )
+                 X_noisy_L_traj_i = _reshape_trajectory(X_noisy_L_traj[idx], False)
+                 denoised_trajectory_stack = (
+                     build_stack_from_atom_array_and_batched_coords(
+                         X_denoised_L_traj_i, pipeline_output["atom_array"]
+                     )
+                 )
+                 noisy_trajectory_stack = build_stack_from_atom_array_and_batched_coords(
+                     X_noisy_L_traj_i, pipeline_output["atom_array"]
+                 )
+             else:
+                 denoised_trajectory_stack = None
+                 noisy_trajectory_stack = None
+
+             outputs.append(
+                 RFD3Output(
+                     example_id=f"{pipeline_output['example_id']}_model_{idx}",
+                     atom_array=output_val["predicted_atom_array_stack"][idx],
+                     metadata=output_val["prediction_metadata"][idx]
+                     if self.dump_prediction_metadata_json
+                     else {},
+                     denoised_trajectory_stack=denoised_trajectory_stack,
+                     noisy_trajectory_stack=noisy_trajectory_stack,
+                 )
+             )
+
+         ranked_logger.info(f"Finished inference batch in {t_end - t0:.2f} seconds.")
+         return outputs
+
+     ###############################################
+     # Input merging
+     ###############################################
+
+     def _canonicalize_inputs(
+         self, inputs
+     ) -> Dict[str, dict | DesignInputSpecification]:
+         is_json_like = (isinstance(inputs, (str, PathLike, Path))) or (
+             isinstance(inputs, list)
+             and all(isinstance(i, (str, PathLike, Path)) for i in inputs)
+         )
+         is_specification_like = isinstance(inputs, DesignInputSpecification) or (
+             isinstance(inputs, list)
+             and all(isinstance(i, DesignInputSpecification) for i in inputs)
+         )
+         is_atom_array_like = isinstance(inputs, AtomArray) or (
+             isinstance(inputs, list) and all(isinstance(i, AtomArray) for i in inputs)
+         )
+         if inputs is None:
+             # Create empty specification dictionary
+             return {"": {**self.specification_overrides}}
+         elif is_json_like:
+             # List of file paths
+             inputs = process_input(
+                 inputs,
+                 json_keys_subset=self.json_keys_subset,
+                 global_prefix=self.global_prefix,
+                 specification_overrides=self.specification_overrides,
+                 validate=self.prevalidate_inputs,
+             )  # any -> Dict[Name: DesignInputSpecification]
+         elif is_specification_like:
+             # List of DesignInputSpecifications
+             if isinstance(inputs, DesignInputSpecification):
+                 inputs = [inputs]
+             inputs = {f"backbone_{i}": spec for i, spec in enumerate(inputs)}
+         elif is_atom_array_like:
+             raise NotImplementedError("AtomArray inputs not yet supported.")
+         else:
+             raise ValueError(
+                 f"Invalid input type: {type(inputs)}. Expected JSON/YAML file paths, AtomArray, or DesignInputSpecification.\nInput: {inputs}"
+             )
+
+         return inputs
+
+     def _multiply_specifications(
+         self, inputs: Dict[str, dict | DesignInputSpecification], n_batches=None
+     ) -> Dict[str, Dict[str, Any]]:
+         # Find existing example IDs in the output directory
+         if exists(self.out_dir):
+             existing_example_ids = set(
+                 extract_example_id_from_path(path, CIF_LIKE_EXTENSIONS)
+                 for path in find_files_with_extension(self.out_dir, CIF_LIKE_EXTENSIONS)
+             )
+             ranked_logger.info(
+                 f"Found {len(existing_example_ids)} existing example IDs in the output directory."
+             )
+
+         # Based on inputs, construct the specifications to loop through
+         design_specifications = {}
+         for prefix, example_spec in inputs.items():
+             # ... Create n_batches copies of each example
+             for batch_id in range(n_batches if exists(n_batches) else 1):
+                 # ... Example ID
+                 example_id = f"{prefix}_{batch_id}" if exists(n_batches) else prefix
+
+                 if (
+                     self.skip_existing
+                     and exists(self.out_dir)
+                     and example_id in existing_example_ids
+                 ):
+                     ranked_logger.info(
+                         f"Skipping design specification for example {example_id} | Already exists."
+                     )
+                     continue
+                 design_specifications[example_id] = example_spec
+         return design_specifications
+
+
+ def normalize_inputs(inputs: str | list | None) -> list[str | None]:
+     """
+     inputs: str | list[str] | None
+         - Can be:
+             - A single path to a JSON, YAML, or regular input file (cif or pdb)
+             - A comma-separated string of paths (e.g. "a.json,b.json")
+             - A list of file paths
+             - None or an empty list, in which case a dummy input is added (used for e.g. motif-only design)
+         - Returns a list of paths, or [None] if no inputs are provided
+     """
+     if inputs is None or (isinstance(inputs, list) and len(inputs) == 0):
+         inputs = [None]
+     elif isinstance(inputs, str):
+         inputs = inputs.split(",")
+     elif not isinstance(inputs, list):
+         raise ValueError(
+             f"Invalid input type: {type(inputs)}. Expected str, list, or None.\nInput: {inputs}"
+         )
+     return inputs
+
+
+ def process_input(
+     inputs: str | list | None,
+     json_keys_subset: str | list | None = None,
+     global_prefix: str | None = None,
+     specification_overrides: dict | None = None,
+     validate: bool = True,
+ ) -> Dict[str, dict]:
+     """
+     inputs: Any -> list[str | None] (see normalize_inputs)
+     json_keys_subset: extract only a subset of JSON keys. None keeps all keys
+     global_prefix: if provided, prefix all example ids with this string
+     specification_overrides: args merged over each example's own args
+     validate: if True, pre-validate each spec via DesignInputSpecification.safe_init
+
+     Returns: dictionary of specification args, pre-batching:
+         {
+             'jsonfile_jsonkey1': {
+                 **args_from_key1
+             },
+             'jsonfile_jsonkey2': {
+                 **args_from_key2
+             }
+         }
+     """
+     specification_overrides = dict(specification_overrides or {})
+
+     def merge_args(example_args: dict) -> dict:
+         return merge_with(lambda x: x[-1], example_args, specification_overrides)
+
+     # Normalize to a list of paths (or [None] for a dummy input); this also
+     # splits comma-separated strings, so no further splitting is needed here.
+     inputs = normalize_inputs(inputs)
+
+     # If global_prefix is not provided, default to using the basename of the JSON or YAML file (when provided)
+     use_json_basename_prefix = global_prefix is None
+
+     # ... Determine prefix of sample to create
+     all_specs = {}
+     for input in inputs:
+         if exists(input) and (input.endswith(".json") or input.endswith(".yaml")):
+             # ... Load JSON or YAML file
+             with open(input, "r") as f:
+                 data = json.load(f) if input.endswith(".json") else yaml.safe_load(f)
+
+             # ... Apply any global args for this input file
+             if "global_args" in data:
+                 global_args = data.pop("global_args")
+                 for example in data:
+                     data[example].update(global_args)
+
+             # ... Subset to keys
+             if json_keys_subset is not None:
+                 json_keys_subset = (
+                     json_keys_subset.split(",")
+                     if isinstance(json_keys_subset, str)
+                     else json_keys_subset
+                 )
+                 data = {
+                     example: data[example]
+                     for example in json_keys_subset
+                     if example in data
+                 }
+
+             # ... Extract each accumulated example in data.
+             for example, args in data.items():
+                 args = ensure_input_is_abspath(args, input)
+                 if use_json_basename_prefix:
+                     name = os.path.splitext(os.path.basename(input))[0]
+                     prefix = f"{name}_{example}"
+                 else:
+                     prefix = f"{global_prefix}{example}"
+                 args["extra"] = args.get("extra", {}) | {"example": example}
+                 all_specs[prefix] = dict(merge_args(args))
+
+         elif exists(input):
+             prefix = os.path.basename(os.path.splitext(input)[0])
+             if global_prefix is not None:
+                 prefix = f"{global_prefix}{prefix}"
+             all_specs[prefix] = dict(merge_args({"input": input}))
+         else:
+             all_specs["backbone"] = dict(specification_overrides)
+
+     if validate:
+         for prefix, example_spec in all_specs.items():
+             ranked_logger.info(
+                 f"Prevalidating design specification for example: {prefix}"
+             )
+             DesignInputSpecification.safe_init(**example_spec)
+
+     return all_specs
+
+
+ def _reshape_trajectory(traj, align_structures: bool):
+     traj = [traj[i] for i in range(len(traj))]
+     n_steps = len(traj)
+     max_frames = 100
+
+     if align_structures:
+         # ... align the trajectories on the last prediction
+         for step in range(n_steps - 1):
+             traj[step] = weighted_rigid_align(
+                 X_L=traj[-1],
+                 X_gt_L=traj[step],
+             )
+     traj = traj[::-1]  # reverse to go from noised -> denoised
+     if n_steps > max_frames:
+         selected_indices = torch.linspace(0, n_steps - 1, max_frames).long().tolist()
+         traj = [traj[i] for i in selected_indices]
+
+     traj = torch.stack(traj).cpu().numpy()
+     return traj
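
For orientation, here is a minimal usage sketch of the engine defined in this file. It is illustrative only and not part of the released package: it assumes RFD3InferenceConfig is meant to be unpacked into RFD3InferenceEngine via its mapping protocol (keys()/__getitem__ are defined for exactly that), and that the input JSON follows the layout process_input parses (per-example keys, plus an optional global_args block merged into every example). The file names and per-example arguments below are hypothetical.

    from rfd3.engine import RFD3InferenceConfig, RFD3InferenceEngine

    # designs.json (hypothetical; layout as parsed by process_input above):
    # {
    #     "global_args": { ...args merged into every example... },
    #     "binder_a": { ...per-example DesignInputSpecification args... },
    #     "binder_b": { ...per-example DesignInputSpecification args... }
    # }

    config = RFD3InferenceConfig(
        diffusion_batch_size=8,  # diffusion samples per design specification
        dump_trajectories=False,
        seed=0,
    )

    # RFD3InferenceConfig defines keys()/__getitem__, so it unpacks like a dict.
    engine = RFD3InferenceEngine(**config)

    # Each JSON key becomes one design specification; with n_batches=2 the example
    # ids get _0/_1 suffixes. Results are written to out_dir as cif.gz plus a
    # metadata JSON, or returned as {example_id: [RFD3Output, ...]} if out_dir is None.
    outputs = engine.run(
        inputs="designs.json",
        n_batches=2,
        out_dir="./rfd3_outputs",
    )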