rc_foundry-0.1.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (180)
  1. foundry/__init__.py +57 -0
  2. foundry/callbacks/__init__.py +5 -0
  3. foundry/callbacks/callback.py +116 -0
  4. foundry/callbacks/health_logging.py +419 -0
  5. foundry/callbacks/metrics_logging.py +211 -0
  6. foundry/callbacks/timing_logging.py +67 -0
  7. foundry/callbacks/train_logging.py +278 -0
  8. foundry/common.py +108 -0
  9. foundry/constants.py +28 -0
  10. foundry/hydra/resolvers.py +77 -0
  11. foundry/inference_engines/base.py +235 -0
  12. foundry/inference_engines/checkpoint_registry.py +66 -0
  13. foundry/metrics/__init__.py +12 -0
  14. foundry/metrics/losses.py +30 -0
  15. foundry/metrics/metric.py +319 -0
  16. foundry/model/layers/blocks.py +47 -0
  17. foundry/testing/__init__.py +6 -0
  18. foundry/testing/fixtures.py +19 -0
  19. foundry/testing/pytest_hooks.py +15 -0
  20. foundry/trainers/fabric.py +923 -0
  21. foundry/training/EMA.py +67 -0
  22. foundry/training/checkpoint.py +61 -0
  23. foundry/training/schedulers.py +91 -0
  24. foundry/utils/alignment.py +86 -0
  25. foundry/utils/components.py +415 -0
  26. foundry/utils/datasets.py +405 -0
  27. foundry/utils/ddp.py +103 -0
  28. foundry/utils/instantiators.py +72 -0
  29. foundry/utils/logging.py +279 -0
  30. foundry/utils/rigid.py +1460 -0
  31. foundry/utils/rotation_augmentation.py +65 -0
  32. foundry/utils/squashfs.py +172 -0
  33. foundry/utils/torch.py +317 -0
  34. foundry/utils/weights.py +271 -0
  35. foundry/version.py +34 -0
  36. foundry_cli/__init__.py +3 -0
  37. foundry_cli/download_checkpoints.py +281 -0
  38. mpnn/__init__.py +1 -0
  39. mpnn/collate/feature_collator.py +265 -0
  40. mpnn/inference.py +53 -0
  41. mpnn/inference_engines/mpnn.py +549 -0
  42. mpnn/loss/nll_loss.py +122 -0
  43. mpnn/metrics/nll.py +369 -0
  44. mpnn/metrics/sequence_recovery.py +440 -0
  45. mpnn/model/layers/graph_embeddings.py +2372 -0
  46. mpnn/model/layers/message_passing.py +332 -0
  47. mpnn/model/layers/position_wise_feed_forward.py +44 -0
  48. mpnn/model/layers/positional_encoding.py +98 -0
  49. mpnn/model/mpnn.py +2632 -0
  50. mpnn/pipelines/mpnn.py +162 -0
  51. mpnn/samplers/samplers.py +167 -0
  52. mpnn/train.py +341 -0
  53. mpnn/trainers/mpnn.py +193 -0
  54. mpnn/transforms/feature_aggregation/mpnn.py +184 -0
  55. mpnn/transforms/feature_aggregation/polymer_ligand_interface.py +76 -0
  56. mpnn/transforms/feature_aggregation/token_encodings.py +132 -0
  57. mpnn/transforms/feature_aggregation/user_settings.py +347 -0
  58. mpnn/transforms/polymer_ligand_interface.py +164 -0
  59. mpnn/utils/inference.py +2397 -0
  60. mpnn/utils/probability.py +37 -0
  61. mpnn/utils/weights.py +309 -0
  62. rc_foundry-0.1.1.dist-info/METADATA +239 -0
  63. rc_foundry-0.1.1.dist-info/RECORD +180 -0
  64. rc_foundry-0.1.1.dist-info/WHEEL +4 -0
  65. rc_foundry-0.1.1.dist-info/entry_points.txt +5 -0
  66. rc_foundry-0.1.1.dist-info/licenses/LICENSE.md +28 -0
  67. rf3/__init__.py +3 -0
  68. rf3/_version.py +33 -0
  69. rf3/alignment.py +79 -0
  70. rf3/callbacks/dump_validation_structures.py +101 -0
  71. rf3/callbacks/metrics_logging.py +324 -0
  72. rf3/chemical.py +1529 -0
  73. rf3/cli.py +77 -0
  74. rf3/data/cyclic_transform.py +78 -0
  75. rf3/data/extra_xforms.py +36 -0
  76. rf3/data/ground_truth_template.py +463 -0
  77. rf3/data/paired_msa.py +206 -0
  78. rf3/data/pipeline_utils.py +128 -0
  79. rf3/data/pipelines.py +558 -0
  80. rf3/diffusion_samplers/inference_sampler.py +222 -0
  81. rf3/inference.py +65 -0
  82. rf3/inference_engines/__init__.py +5 -0
  83. rf3/inference_engines/rf3.py +735 -0
  84. rf3/kinematics.py +354 -0
  85. rf3/loss/af3_confidence_loss.py +515 -0
  86. rf3/loss/af3_losses.py +655 -0
  87. rf3/loss/loss.py +179 -0
  88. rf3/metrics/chiral.py +179 -0
  89. rf3/metrics/clashing_chains.py +68 -0
  90. rf3/metrics/distogram.py +421 -0
  91. rf3/metrics/lddt.py +523 -0
  92. rf3/metrics/metadata.py +43 -0
  93. rf3/metrics/metric_utils.py +192 -0
  94. rf3/metrics/predicted_error.py +134 -0
  95. rf3/metrics/rasa.py +108 -0
  96. rf3/metrics/selected_distances.py +91 -0
  97. rf3/model/RF3.py +527 -0
  98. rf3/model/RF3_blocks.py +92 -0
  99. rf3/model/RF3_structure.py +303 -0
  100. rf3/model/layers/af3_auxiliary_heads.py +255 -0
  101. rf3/model/layers/af3_diffusion_transformer.py +544 -0
  102. rf3/model/layers/attention.py +313 -0
  103. rf3/model/layers/layer_utils.py +127 -0
  104. rf3/model/layers/mlff.py +118 -0
  105. rf3/model/layers/outer_product.py +59 -0
  106. rf3/model/layers/pairformer_layers.py +783 -0
  107. rf3/model/layers/structure_bias.py +56 -0
  108. rf3/scoring.py +1787 -0
  109. rf3/symmetry/resolve.py +284 -0
  110. rf3/train.py +194 -0
  111. rf3/trainers/rf3.py +570 -0
  112. rf3/util_module.py +47 -0
  113. rf3/utils/frames.py +109 -0
  114. rf3/utils/inference.py +665 -0
  115. rf3/utils/io.py +198 -0
  116. rf3/utils/loss.py +72 -0
  117. rf3/utils/predict_and_score.py +165 -0
  118. rf3/utils/predicted_error.py +673 -0
  119. rf3/utils/recycling.py +42 -0
  120. rf3/validate.py +140 -0
  121. rfd3/.gitignore +7 -0
  122. rfd3/Makefile +76 -0
  123. rfd3/__init__.py +12 -0
  124. rfd3/callbacks.py +66 -0
  125. rfd3/cli.py +41 -0
  126. rfd3/constants.py +212 -0
  127. rfd3/engine.py +543 -0
  128. rfd3/inference/datasets.py +193 -0
  129. rfd3/inference/input_parsing.py +1123 -0
  130. rfd3/inference/legacy_input_parsing.py +717 -0
  131. rfd3/inference/parsing.py +165 -0
  132. rfd3/inference/symmetry/atom_array.py +298 -0
  133. rfd3/inference/symmetry/checks.py +241 -0
  134. rfd3/inference/symmetry/contigs.py +63 -0
  135. rfd3/inference/symmetry/frames.py +355 -0
  136. rfd3/inference/symmetry/symmetry_utils.py +398 -0
  137. rfd3/metrics/design_metrics.py +465 -0
  138. rfd3/metrics/hbonds_hbplus_metrics.py +308 -0
  139. rfd3/metrics/hbonds_metrics.py +389 -0
  140. rfd3/metrics/losses.py +325 -0
  141. rfd3/metrics/metrics_utils.py +118 -0
  142. rfd3/metrics/sidechain_metrics.py +349 -0
  143. rfd3/model/RFD3.py +105 -0
  144. rfd3/model/RFD3_diffusion_module.py +387 -0
  145. rfd3/model/cfg_utils.py +81 -0
  146. rfd3/model/inference_sampler.py +635 -0
  147. rfd3/model/layers/attention.py +577 -0
  148. rfd3/model/layers/block_utils.py +580 -0
  149. rfd3/model/layers/blocks.py +777 -0
  150. rfd3/model/layers/chunked_pairwise.py +377 -0
  151. rfd3/model/layers/encoders.py +417 -0
  152. rfd3/model/layers/layer_utils.py +197 -0
  153. rfd3/model/layers/pairformer_layers.py +128 -0
  154. rfd3/run_inference.py +45 -0
  155. rfd3/testing/debug.py +139 -0
  156. rfd3/testing/debug_utils.py +73 -0
  157. rfd3/testing/testing_utils.py +356 -0
  158. rfd3/train.py +194 -0
  159. rfd3/trainer/dump_validation_structures.py +154 -0
  160. rfd3/trainer/fabric_trainer.py +923 -0
  161. rfd3/trainer/recycling.py +42 -0
  162. rfd3/trainer/rfd3.py +485 -0
  163. rfd3/trainer/trainer_utils.py +502 -0
  164. rfd3/transforms/conditioning_base.py +508 -0
  165. rfd3/transforms/conditioning_utils.py +200 -0
  166. rfd3/transforms/design_transforms.py +807 -0
  167. rfd3/transforms/dna_crop.py +523 -0
  168. rfd3/transforms/hbonds.py +407 -0
  169. rfd3/transforms/hbonds_hbplus.py +246 -0
  170. rfd3/transforms/ncaa_transforms.py +153 -0
  171. rfd3/transforms/pipelines.py +632 -0
  172. rfd3/transforms/ppi_transforms.py +541 -0
  173. rfd3/transforms/rasa.py +116 -0
  174. rfd3/transforms/symmetry.py +76 -0
  175. rfd3/transforms/training_conditions.py +552 -0
  176. rfd3/transforms/util_transforms.py +498 -0
  177. rfd3/transforms/virtual_atoms.py +305 -0
  178. rfd3/utils/inference.py +648 -0
  179. rfd3/utils/io.py +245 -0
  180. rfd3/utils/vizualize.py +276 -0
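Before the per-file diffs below, a quick orientation: per the file list, the wheel ships four top-level packages (foundry, mpnn, rf3, rfd3) plus the foundry_cli helpers. A minimal sanity-check sketch, assuming the distribution installs under the name rc-foundry; only the import names are taken from the listing above, nothing else is guaranteed:

import importlib

# Import each top-level package shipped in the wheel and report where it loads from.
for pkg in ("foundry", "foundry_cli", "mpnn", "rf3", "rfd3"):
    module = importlib.import_module(pkg)
    print(f"{pkg} -> {module.__file__}")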
rfd3/testing/testing_utils.py ADDED
@@ -0,0 +1,356 @@
+ import copy
+ import getpass
+ import json
+ import logging
+ import os
+ import sys
+ import tempfile
+ from pathlib import Path
+ from unittest.mock import patch
+
+ import hydra
+ import numpy as np
+ import rootutils
+ import torch
+ from hydra import compose, initialize
+ from hydra.core.global_hydra import GlobalHydra
+ from lightning.fabric import seed_everything
+ from omegaconf import DictConfig
+
+ sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../../..")
+ sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../../../src")
+
+ import atomworks
+ from atomworks import parse
+ from atomworks.io.parser import STANDARD_PARSER_ARGS
+ from atomworks.io.utils.io_utils import to_cif_file
+ from hydra.core.hydra_config import HydraConfig
+ from omegaconf import OmegaConf
+ from rfd3.inference.input_parsing import (
+     DesignInputSpecification,
+     create_atom_array_from_design_specification,
+ )
+ from rfd3.transforms.pipelines import (
+     build_atom14_base_pipeline,
+ )
+
+ from foundry.hydra.resolvers import register_resolvers
+
+ register_resolvers()
+
+ print("Atomworks version:", atomworks.__version__)
+
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+
+ # Same as train.py
+ rootutils.setup_root(__file__ + "/../..", indicator=".project-root", pythonpath=True)
+ _config_path = "pkg://configs"
+
+ print("_config_path:", _config_path)
+ print(os.getcwd())
+
+ # Set globals
+ DEBUG_MODE_USERS = ["jbutch"]
+ DEBUG_MODE = getpass.getuser() in DEBUG_MODE_USERS
+ seed = 42
+ np.random.seed(seed)
+ torch.manual_seed(seed)
+
+ BASE_HYDRA_OVERRIDES = ["hydra.job.num=0", "hydra.runtime.output_dir=/dev/null"]
+
+ PATH_TO_REPO_ROOT = os.path.dirname(os.path.abspath(__file__)) + "/../../.."
+
+
+ def load_inference_cfg():
+     GlobalHydra.instance().clear()  # Clear existing context
+     with patch.object(sys, "argv", ["compile_cfg"]):
+         initialize(config_path=_config_path, version_base="1.3", job_name="engine_cfg")
+         cfg = compose(
+             config_name="inference",
+             overrides=["ckpt_path=null", "out_dir=null", "inputs=null"]
+             + BASE_HYDRA_OVERRIDES,
+             return_hydra_config=True,
+         )
+     HydraConfig.instance().set_config(cfg)
+     return cfg
+
+
+ def load_train_or_val_cfg(name="pretrain", is_val_cfg=False):
+     GlobalHydra.instance().clear()  # Clear existing context
+     with patch.object(sys, "argv", ["compile_cfg"]):
+         initialize(config_path=_config_path, version_base="1.3", job_name="trainer_cfg")
+         cfg = compose(
+             config_name="validate" if is_val_cfg else "train",
+             overrides=["experiment={}".format(name)] + BASE_HYDRA_OVERRIDES,
+             return_hydra_config=True,
+         )
+     HydraConfig.instance().set_config(cfg)
+     return cfg
+
+
+ TEST_CFG_INFERENCE = load_inference_cfg()
+ TEST_CFG_TRAIN = load_train_or_val_cfg()
+
+
+ ##########################################################################################
+ # Additional utilities
+ ##########################################################################################
+
+ DIRS = [
+     os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../../tests"),
+     os.path.join(os.path.dirname(os.path.abspath(__file__))),
+     TEST_CFG_TRAIN.paths.data.design_benchmark_data_dir,
+ ]
+
+
+ def load_test_json():
+     test_files = ["demo.json", "demo_extended.json", "tests.json"]
+     test_files += [
+         "mcsa_41.json",
+         "rfd_unindexed.json",
+         "sym_tests.json",
+         "brk_regression.json",
+     ]
+     test_json_data = {}
+     for dir in DIRS:
+         test_data_dir = Path(dir, "test_data")
+
+         for test_file in test_files:
+             test_json = test_data_dir / test_file
+             if not test_json.exists():
+                 print(f"Test json file {test_json} does not exist.")
+                 continue
+
+             print(f"Loading test json file {test_json}...")
+             with open(test_json, "r") as f:
+                 test_json_data_ = json.load(f)
+
+             print("Found test json data:", test_json_data.keys())
+             assert test_json_data_, "Test json data is empty."
+
+             nonexistent_keys = []
+             for k, d in test_json_data_.items():
+                 if "input" in d:
+                     file = test_data_dir / d["input"]
+                     test_json_data_[k]["input"] = file
+                     if not file.exists():
+                         nonexistent_keys.append(k)
+
+             if nonexistent_keys:
+                 test_json_data_ = {
+                     k: v
+                     for k, v in test_json_data_.items()
+                     if k not in nonexistent_keys
+                 }
+                 print(f"Test input files {nonexistent_keys} do not exist.")
+                 print("Skipping these examples.")
+
+             test_json_data = test_json_data | test_json_data_
+     return test_json_data
+
+
+ TEST_JSON_DATA = load_test_json()
+ assert TEST_JSON_DATA, "No test json data loaded!"
+
+
+ def filter_inference_args(args):
+     return {k: v for k, v in args.items() if k in valid_keys_}
+
+
+ def instantiate_example(args, is_inference=True):
+     """Saves and reloads atom array for pipeline forwarding"""
+     args = copy.deepcopy(args)
+
+     if is_inference:
+         input = DesignInputSpecification.safe_init(**args).to_pipeline_input(
+             example_id=args.get("example_id", "example")
+         )
+     else:
+         file = args.get("input")
+         if file is None:
+             # args = filter_inference_args(args)
+             atom_array, spec = create_atom_array_from_design_specification(**args)
+             outdir = tempfile.TemporaryDirectory()
+             outdir = Path(outdir.name)
+             tmpfile = outdir / "tmp.cif"
+             to_cif_file(
+                 atom_array,
+                 tmpfile,
+                 id="example",
+                 extra_fields="all",
+             )
+             file = tmpfile
+         result_dict = parse(
+             filename=file,
+             build_assembly=("1",),
+             **STANDARD_PARSER_ARGS,
+         )
+         input = {
+             "atom_array": result_dict["assemblies"]["1"][0],  # First model
+             "chain_info": result_dict["chain_info"],
+             "ligand_info": result_dict["ligand_info"],
+             "metadata": result_dict["metadata"],
+         }
+     return input
+
+
+ # Pre-build pipeline | in debug mode, exclude random augmentation so that the noise scales are clearer
+
+
+ def build_pipelines(
+     cfg_name="pretrain",
+     train_dataset_nested_keys=["pdb", "sub_datasets", "interface"],
+     val_dataset_nested_keys=["unconditional"],
+     composed_config: DictConfig | None = None,
+     standardize_crop_size: bool = True,
+     **transform_kwargs,
+ ):
+     pipes = {}
+     for is_validation in [True, False]:
+         if composed_config is None:
+             config = load_train_or_val_cfg(name=cfg_name, is_val_cfg=is_validation)
+         else:
+             config = composed_config
+         dataset_config = config["datasets"]["val" if is_validation else "train"]
+         for key in (
+             val_dataset_nested_keys if is_validation else train_dataset_nested_keys
+         ):
+             dataset_config = dataset_config[key]
+         transforms_config = dataset_config["dataset"]["transform"]
+
+         OmegaConf.set_struct(transforms_config, False)
+         transforms_config.return_atom_array = True
+         transforms_config.seed = seed
+         transforms_config.pop("_target_")
+
+         if standardize_crop_size:
+             transforms_config.crop_size = 9999
+             transforms_config.max_atoms_in_crop = 99999
+
+         for k, v in transform_kwargs.items():
+             transforms_config[k] = v
+
+         pipes |= {
+             is_validation: build_atom14_base_pipeline(
+                 **(transforms_config),
+             )
+         }
+
+     if DEBUG_MODE:
+         from rfd3.testing.debug_utils import to_debug_pipe
+
+         for is_inference in [True, False]:
+             pipes[is_inference] = to_debug_pipe(pipes[is_inference])
+     return pipes
+
+
+ def get_train_dataloader(cfg_name="pretrain", seed=42):
+     print("This function is deprecated!")
+     cfg = load_train_or_val_cfg(name=cfg_name, is_val_cfg=False)
+     from foundry.utils.datasets import (
+         assemble_distributed_loader,
+         recursively_instantiate_datasets_and_samplers,
+     )
+
+     seed_everything(seed, workers=True, verbose=True)
+
+     # Number of examples per epoch (across all GPUs)
+     # (We must sample this many indices from our sampler)
+     n_examples_per_epoch = cfg.trainer.n_examples_per_epoch
+
+     # HACK: I'm getting OOMs using the full parquets, so overwrite with these
+     # NOTE: If using this test, check that your desired use-case is represented in these mini-datasets!
+     try:
+         PATH_TO_DATADIR_ROOT = "/projects/ml/datahub/dfs/af3_splits/2024_12_16/"
+         if "interface" in cfg.datasets.train.pdb.sub_datasets:
+             path = PATH_TO_DATADIR_ROOT + "/interfaces_df_train.parquet"
+             assert os.path.exists(path), f"Path {path} does not exist!"
+             cfg.datasets.train.pdb.sub_datasets.interface.dataset.dataset.data = path
+         if "pn_unit" in cfg.datasets.train.pdb.sub_datasets:
+             path = PATH_TO_DATADIR_ROOT + "/pn_units_df_train.parquet"
+             assert os.path.exists(path), f"Path {path} does not exist!"
+             cfg.datasets.train.pdb.sub_datasets.pn_unit.dataset.dataset.data = path
+     except Exception as e:
+         print("Could not overwrite dataset paths for train dataloader:", e)
+
+     # ... build the train dataset
+     # assert (
+     #     "train" in cfg.datasets and cfg.datasets.train
+     # ), "No 'train' dataloader configuration provided! If only performing validation, use `validate.py` instead."
+     dataset_and_sampler = recursively_instantiate_datasets_and_samplers(
+         cfg.datasets.train
+     )
+
+     train_dataset = dataset_and_sampler["dataset"]
+     train_sampler = dataset_and_sampler["sampler"]
+
+     train_loader = assemble_distributed_loader(
+         dataset=train_dataset,
+         sampler=train_sampler,
+         rank=0,
+         world_size=1,
+         n_examples_per_epoch=n_examples_per_epoch,
+         loader_cfg=cfg.dataloader["train"],
+     )
+
+     return train_loader
+
+
+ def assert_tensor_or_array_equal(actual, expected, error_msg: str):
+     """Compare tensors/arrays with dtype-aware tolerances."""
+     if torch.is_tensor(actual):
+         if actual.dtype == torch.bool or actual.dtype in (torch.int32, torch.int64):
+             assert torch.equal(actual, expected), error_msg
+         else:
+             assert torch.allclose(
+                 actual, expected, atol=1e-5, rtol=1e-5, equal_nan=True
+             ), error_msg
+     elif isinstance(actual, np.ndarray):
+         if actual.dtype.kind in {"U", "S"}:
+             assert np.array_equal(actual, expected), error_msg
+         elif actual.dtype == bool or np.issubdtype(actual.dtype, np.integer):
+             assert np.array_equal(actual, expected), error_msg
+         else:
+             assert np.allclose(
+                 actual, expected, atol=1e-5, rtol=1e-5, equal_nan=True
+             ), error_msg
+     else:
+         assert actual == expected, error_msg
+
+
+ def _assert_tensor_or_array_equal(actual, expected, error_msg: str):
+     """Backwards-compatible alias for legacy imports."""
+     return assert_tensor_or_array_equal(actual, expected, error_msg)
+
+
+ def build_pipes_constant():
+     cfgs = {
+         True: copy.deepcopy(
+             TEST_CFG_TRAIN.datasets.val.unconditional.dataset.transform
+         ),
+         False: copy.deepcopy(
+             TEST_CFG_TRAIN.datasets.train.pdb.sub_datasets.pn_unit.dataset.transform
+         ),
+     }
+     # Standardize crop size and atom array return
+     pipes = {}
+     for is_inference, cfg in cfgs.items():
+         OmegaConf.set_struct(cfg, False)
+         cfg.seed = seed
+         cfg.return_atom_array = True
+         cfg.crop_size = 9999
+         cfg.max_atoms_in_crop = 99999
+         pipes[is_inference] = hydra.utils.instantiate(cfg)
+         # pipes[is_inference] = build_atom14_base_pipeline(
+         #     **cfg,
+         # )
+     if DEBUG_MODE:
+         from rfd3.testing.debug_utils import to_debug_pipe
+
+         for is_inference in [True, False]:
+             pipes[is_inference] = to_debug_pipe(pipes[is_inference])
+     return pipes
+
+
+ PIPES = build_pipes_constant()
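The module above doubles as a shared fixture layer: TEST_JSON_DATA holds design specifications loaded from the test JSON files, and PIPES holds pre-built transform pipelines keyed by is_inference. A minimal usage sketch; the module path comes from the file list (rfd3/testing/testing_utils.py), but the JSON key "demo_unconditional" and the way the pipeline object is invoked are assumptions:

from rfd3.testing.testing_utils import PIPES, TEST_JSON_DATA, instantiate_example

# Pick one example specification and convert it into a pipeline input.
example_args = TEST_JSON_DATA["demo_unconditional"]  # hypothetical key
pipeline_input = instantiate_example(example_args, is_inference=True)

# Run it through the pre-built inference/validation pipeline (keyed by True).
output = PIPES[True](pipeline_input)  # assumes the pipeline is callable
print(sorted(output.keys()))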
rfd3/train.py ADDED
@@ -0,0 +1,194 @@
+ #!/usr/bin/env -S /bin/sh -c '"$(dirname "$0")/../../../../.ipd/shebang/rfd3_exec.sh" "$0" "$@"'
+
+ import logging
+ import os
+
+ import hydra
+ import rootutils
+ from dotenv import load_dotenv
+ from omegaconf import DictConfig
+
+ from foundry.utils.logging import suppress_warnings
+ from foundry.utils.weights import CheckpointConfig
+
+ # Setup root dir and environment variables (more info: https://github.com/ashleve/rootutils)
+ # NOTE: Sets the `PROJECT_ROOT` environment variable to the root directory of the project (where `.project-root` is located)
+ rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
+
+ load_dotenv(override=True)
+
+ _config_path = os.path.join(os.environ["PROJECT_ROOT"], "models/rfd3/configs")
+
+ _spawning_process_logger = logging.getLogger(__name__)
+
+
+ @hydra.main(config_path=_config_path, config_name="train", version_base="1.3")
+ def train(cfg: DictConfig) -> None:
+     # ==============================================================================
+     # Import dependencies and resolve Hydra configuration
+     # ==============================================================================
+
+     _spawning_process_logger.info("Importing dependencies...")
+
+     # Lazy imports to make config generation fast
+     import torch
+     from lightning.fabric import seed_everything
+     from lightning.fabric.loggers import Logger
+
+     # If training on DIGS L40, set precision of matrix multiplication to balance speed and accuracy
+     # Reference: https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
+     torch.set_float32_matmul_precision("medium")
+
+     from foundry.callbacks.callback import BaseCallback  # noqa
+     from foundry.utils.instantiators import instantiate_loggers, instantiate_callbacks  # noqa
+     from foundry.utils.logging import (
+         print_config_tree,
+         log_hyperparameters_with_all_loggers,
+     )  # noqa
+     from foundry.utils.ddp import RankedLogger  # noqa
+     from foundry.utils.ddp import is_rank_zero, set_accelerator_based_on_availability  # noqa
+     from foundry.utils.datasets import (
+         recursively_instantiate_datasets_and_samplers,
+         assemble_distributed_loader,
+         subset_dataset_to_example_ids,
+         assemble_val_loader_dict,
+     )  # noqa
+
+     set_accelerator_based_on_availability(cfg)
+
+     ranked_logger = RankedLogger(__name__, rank_zero_only=True)
+     _spawning_process_logger.info("Completed dependency imports...")
+
+     # ... print the configuration tree (NOTE: Only prints for rank 0)
+     print_config_tree(cfg, resolve=True)
+
+     # ==============================================================================
+     # Logging and Callback instantiation
+     # ==============================================================================
+
+     # Reduce the logging level for all dataset and sampler loggers (unless rank 0)
+     # We will still see messages from Rank 0; they are identical, since all ranks load and sample from the same datasets
+     if not is_rank_zero():
+         dataset_logger = logging.getLogger("datasets")
+         sampler_logger = logging.getLogger("atomworks.ml.samplers")
+         dataset_logger.setLevel(logging.WARNING)
+         sampler_logger.setLevel(logging.ERROR)
+
+     # ... seed everything (NOTE: By setting `workers=True`, we ensure that the dataloaders are seeded as well)
+     # (The `PL_GLOBAL_SEED` environment variable will be passed to the spawned subprocesses; e.g., through the `ddp_spawn` backend)
+     if cfg.get("seed"):
+         ranked_logger.info(f"Seeding everything with seed={cfg.seed}...")
+         seed_everything(cfg.seed, workers=True, verbose=True)
+     else:
+         ranked_logger.warning("No seed provided - Not seeding anything!")
+
+     ranked_logger.info("Instantiating loggers...")
+     loggers: list[Logger] = instantiate_loggers(cfg.get("logger"))
+
+     ranked_logger.info("Instantiating callbacks...")
+     callbacks: list[BaseCallback] = instantiate_callbacks(cfg.get("callbacks"))
+
+     # ==============================================================================
+     # Trainer and model instantiation
+     # ==============================================================================
+
+     # ... instantiate the trainer
+     ranked_logger.info("Instantiating trainer...")
+     trainer = hydra.utils.instantiate(
+         cfg.trainer,
+         loggers=loggers or None,
+         callbacks=callbacks or None,
+         _convert_="partial",
+         _recursive_=False,
+     )
+     # (Store the Hydra configuration in the trainer state)
+     trainer.initialize_or_update_trainer_state({"train_cfg": cfg})
+
+     # ... spawn processes for distributed training
+     # (We spawn here, rather than within `fit`, so we can use Fabric's `init_module` to efficiently initialize the model on the appropriate device)
+     ranked_logger.info(
+         f"Spawning {trainer.fabric.world_size} processes from {trainer.fabric.global_rank}..."
+     )
+     trainer.fabric.launch()
+
+     # ... construct the model
+     trainer.construct_model()
+
+     # ... construct the optimizer and scheduler (which requires the model to be constructed)
+     trainer.construct_optimizer()
+     trainer.construct_scheduler()
+
+     # ==============================================================================
+     # Dataset instantiation
+     # ==============================================================================
+
+     # Number of examples per epoch (across all GPUs)
+     # (We must sample this many indices from our sampler)
+     n_examples_per_epoch = cfg.trainer.n_examples_per_epoch
+
+     # ... build the train dataset
+     assert (
+         "train" in cfg.datasets and cfg.datasets.train
+     ), "No 'train' dataloader configuration provided! If only performing validation, use `validate.py` instead."
+     dataset_and_sampler = recursively_instantiate_datasets_and_samplers(
+         cfg.datasets.train
+     )
+     train_dataset, train_sampler = (
+         dataset_and_sampler["dataset"],
+         dataset_and_sampler["sampler"],
+     )
+
+     # ... compose the train loader
+     if "subset_to_example_ids" in cfg.datasets:
+         # Backdoor for debugging and overfitting: subset the dataset to a specific set of example IDs
+         train_dataset = subset_dataset_to_example_ids(
+             train_dataset, cfg.datasets.subset_to_example_ids
+         )
+         train_sampler = None  # Sampler is no longer valid, since we are using a subset of the dataset
+
+     train_loader = assemble_distributed_loader(
+         dataset=train_dataset,
+         sampler=train_sampler,
+         rank=trainer.fabric.global_rank,
+         world_size=trainer.fabric.world_size,
+         n_examples_per_epoch=n_examples_per_epoch,
+         loader_cfg=cfg.dataloader["train"],
+     )
+
+     # ... compose the validation loader(s)
+     if "val" in cfg.datasets and cfg.datasets.val:
+         val_loaders = assemble_val_loader_dict(
+             cfg=cfg.datasets.val,
+             rank=trainer.fabric.global_rank,
+             world_size=trainer.fabric.world_size,
+             loader_cfg=cfg.dataloader["val"],
+         )
+     else:
+         ranked_logger.warning("No validation datasets provided! Skipping validation...")
+         val_loaders = None
+
+     ranked_logger.info("Logging hyperparameters...")
+     log_hyperparameters_with_all_loggers(
+         trainer=trainer, cfg=cfg, model=trainer.state["model"]
+     )
+
+     # ... load the checkpoint configuration
+     ckpt_config = None
+     if "ckpt_config" in cfg and cfg.ckpt_config:
+         ckpt_config = hydra.utils.instantiate(cfg.ckpt_config)
+     elif "ckpt_path" in cfg and cfg.ckpt_path:
+         # Just a checkpoint path
+         if cfg.ckpt_path is not None:
+             ckpt_config = CheckpointConfig(path=cfg.ckpt_path)
+
+     # ... train the model
+     ranked_logger.info("Training model...")
+
+     with suppress_warnings():
+         trainer.fit(
+             train_loader=train_loader, val_loaders=val_loaders, ckpt_config=ckpt_config
+         )
+
+
+ if __name__ == "__main__":
+     train()
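Because train is a standard @hydra.main entry point, it reads its overrides from sys.argv. A minimal smoke-run sketch that mirrors the sys.argv-patching pattern used in rfd3/testing/testing_utils.py above; the override names (experiment, trainer.max_epochs) are assumptions about the shipped configs, not documented options:

import sys
from unittest.mock import patch

from rfd3.train import train

# Hypothetical overrides for a tiny smoke run; adjust to the actual config schema.
overrides = ["experiment=pretrain", "trainer.max_epochs=1", "seed=42"]

if __name__ == "__main__":
    # hydra.main parses sys.argv, so patch it to pass the overrides programmatically.
    with patch.object(sys, "argv", ["train.py", *overrides]):
        train()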
rfd3/trainer/dump_validation_structures.py ADDED
@@ -0,0 +1,154 @@
+ import functools
+ from os import PathLike
+ from pathlib import Path
+
+ from atomworks.ml.example_id import parse_example_id
+ from beartype.typing import Any, Literal
+ from rfd3.utils.io import (
+     dump_metadata,
+     dump_structures,
+     dump_trajectories,
+ )
+
+ from foundry.callbacks.callback import BaseCallback
+ from foundry.utils.ddp import RankedLogger
+
+ ranked_logger = RankedLogger(__name__, rank_zero_only=True)
+
+
+ class DumpValidationStructuresCallback(BaseCallback):
+     """Dump predicted structures and/or diffusion trajectories during validation"""
+
+     def __init__(
+         self,
+         save_dir: PathLike,
+         dump_predictions: bool = False,
+         dump_prediction_metadata_json: bool = True,
+         one_model_per_file: bool = False,
+         dump_trajectories: bool = False,
+         dump_denoised_trajectories_only: bool = False,
+         dump_every_n: int = 1,
+         align_trajectories: bool = True,
+         verbose: bool = True,
+         extra_fields: list[str] | Literal["all"] = [],
+     ):
+         """
+         Args:
+             dump_predictions: Whether to dump structures (CIF files) after validation batches.
+             one_model_per_file: If True, write each structure within a diffusion batch to its own CIF file. If False,
+                 include each structure within a diffusion batch as a separate model within one CIF file.
+             dump_trajectories: Whether to dump denoising trajectories after validation batches.
+             dump_every_n: Dump validation structures every n epochs.
+             extra_fields: List of extra annotations to include in the saved CIF file.
+         """
+         super().__init__()
+         self.save_dir = Path(save_dir)
+         self.dump_predictions = dump_predictions
+         self.dump_prediction_metadata_json = dump_prediction_metadata_json
+         self.dump_trajectories = dump_trajectories
+         self.one_model_per_file = one_model_per_file
+         self.dump_denoised_trajectories_only = dump_denoised_trajectories_only
+         self.dump_every_n = dump_every_n
+         self.align_trajectories = align_trajectories
+         self.verbose = verbose
+         self.extra_fields = extra_fields
+
+     def _build_path_from_example_id(
+         self,
+         example_id,
+         dir: str,
+         extra: str = "",
+         epoch: str | None = None,
+         dataset_name: str = "",
+     ) -> Path:
+         """Helper function to build a path from a training or validation example_id."""
+         try:
+             # ... try to extract the PDB ID and assembly ID from the example ID
+             parsed_id = parse_example_id(example_id)
+             identifier = f"{parsed_id['pdb_id']}_{parsed_id['assembly_id']}"
+         except (KeyError, ValueError):
+             # ... if parsing fails, fall back to the original example ID
+             identifier = example_id
+
+         # ... assemble the output path from the identifier, epoch, and dataset name
+         epoch_str = "epoch_{}".format(epoch) if epoch else ""
+         return (
+             self.save_dir / dir / f"{epoch_str}" / dataset_name / f"{identifier}{extra}"
+         )
+
+     def on_validation_batch_end(
+         self,
+         *,
+         trainer,
+         outputs: dict,
+         batch: Any,
+         dataset_name: str,
+         **_,
+     ):
+         if (not self.dump_predictions) and (not self.dump_trajectories):
+             return  # Nothing to do
+         current_epoch = trainer.state["current_epoch"]
+         if current_epoch % self.dump_every_n != 0:
+             ranked_logger.debug(
+                 f"Skipping validation batch dump at epoch {current_epoch} (only dumping every {self.dump_every_n} epochs)"
+             )
+             return
+
+         assert (
+             "network_output" in outputs
+         ), "Validation outputs must contain `network_output` to dump structures!"
+
+         network_output = outputs["network_output"]
+         atom_array_stack = outputs["predicted_atom_array_stack"]
+         prediction_metadata = outputs["prediction_metadata"]
+         example = batch[0]  # Assume batch size = 1
+
+         _build_path_from_example_id = functools.partial(
+             self._build_path_from_example_id,
+             example_id=example["example_id"],
+             epoch=current_epoch,
+             dataset_name=dataset_name,
+         )
+
+         if self.dump_predictions:
+             dump_structures(
+                 atom_arrays=atom_array_stack,
+                 base_path=(f := _build_path_from_example_id(dir="predictions")),
+                 one_model_per_file=self.one_model_per_file,
+                 extra_fields=self.extra_fields,
+             )
+             if self.verbose:
+                 ranked_logger.info(f"Dumped validation predictions to {f}")
+
+             if self.dump_prediction_metadata_json and prediction_metadata:
+                 dump_metadata(
+                     prediction_metadata=prediction_metadata,
+                     base_path=(f := _build_path_from_example_id(dir="predictions")),
+                     one_model_per_file=self.one_model_per_file,
+                 )
+                 if self.verbose:
+                     ranked_logger.info(f"Dumped validation metadata to {f}")
+
+         if self.dump_trajectories:
+             dump_trajectories(
+                 trajectory_list=network_output["X_denoised_L_traj"],
+                 atom_array=example["atom_array"],
+                 base_path=(
+                     f := _build_path_from_example_id(
+                         dir="trajectories", extra="_denoised"
+                     )
+                 ),
+                 align_structures=self.align_trajectories,
+             )
+             if not self.dump_denoised_trajectories_only:
+                 dump_trajectories(
+                     trajectory_list=network_output["X_noisy_L_traj"],
+                     atom_array=example["atom_array"],
+                     base_path=_build_path_from_example_id(
+                         dir="trajectories", extra="_noisy"
+                     ),
+                     align_structures=self.align_trajectories,
+                 )
+
+             if self.verbose:
+                 ranked_logger.info(f"Dumped validation trajectories to {f}")