rc-foundry 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- foundry/__init__.py +57 -0
- foundry/callbacks/__init__.py +5 -0
- foundry/callbacks/callback.py +116 -0
- foundry/callbacks/health_logging.py +419 -0
- foundry/callbacks/metrics_logging.py +211 -0
- foundry/callbacks/timing_logging.py +67 -0
- foundry/callbacks/train_logging.py +278 -0
- foundry/common.py +108 -0
- foundry/constants.py +28 -0
- foundry/hydra/resolvers.py +77 -0
- foundry/inference_engines/base.py +235 -0
- foundry/inference_engines/checkpoint_registry.py +66 -0
- foundry/metrics/__init__.py +12 -0
- foundry/metrics/losses.py +30 -0
- foundry/metrics/metric.py +319 -0
- foundry/model/layers/blocks.py +47 -0
- foundry/testing/__init__.py +6 -0
- foundry/testing/fixtures.py +19 -0
- foundry/testing/pytest_hooks.py +15 -0
- foundry/trainers/fabric.py +923 -0
- foundry/training/EMA.py +67 -0
- foundry/training/checkpoint.py +61 -0
- foundry/training/schedulers.py +91 -0
- foundry/utils/alignment.py +86 -0
- foundry/utils/components.py +415 -0
- foundry/utils/datasets.py +405 -0
- foundry/utils/ddp.py +103 -0
- foundry/utils/instantiators.py +72 -0
- foundry/utils/logging.py +279 -0
- foundry/utils/rigid.py +1460 -0
- foundry/utils/rotation_augmentation.py +65 -0
- foundry/utils/squashfs.py +172 -0
- foundry/utils/torch.py +317 -0
- foundry/utils/weights.py +271 -0
- foundry/version.py +34 -0
- foundry_cli/__init__.py +3 -0
- foundry_cli/download_checkpoints.py +281 -0
- mpnn/__init__.py +1 -0
- mpnn/collate/feature_collator.py +265 -0
- mpnn/inference.py +53 -0
- mpnn/inference_engines/mpnn.py +549 -0
- mpnn/loss/nll_loss.py +122 -0
- mpnn/metrics/nll.py +369 -0
- mpnn/metrics/sequence_recovery.py +440 -0
- mpnn/model/layers/graph_embeddings.py +2372 -0
- mpnn/model/layers/message_passing.py +332 -0
- mpnn/model/layers/position_wise_feed_forward.py +44 -0
- mpnn/model/layers/positional_encoding.py +98 -0
- mpnn/model/mpnn.py +2632 -0
- mpnn/pipelines/mpnn.py +162 -0
- mpnn/samplers/samplers.py +167 -0
- mpnn/train.py +341 -0
- mpnn/trainers/mpnn.py +193 -0
- mpnn/transforms/feature_aggregation/mpnn.py +184 -0
- mpnn/transforms/feature_aggregation/polymer_ligand_interface.py +76 -0
- mpnn/transforms/feature_aggregation/token_encodings.py +132 -0
- mpnn/transforms/feature_aggregation/user_settings.py +347 -0
- mpnn/transforms/polymer_ligand_interface.py +164 -0
- mpnn/utils/inference.py +2397 -0
- mpnn/utils/probability.py +37 -0
- mpnn/utils/weights.py +309 -0
- rc_foundry-0.1.1.dist-info/METADATA +239 -0
- rc_foundry-0.1.1.dist-info/RECORD +180 -0
- rc_foundry-0.1.1.dist-info/WHEEL +4 -0
- rc_foundry-0.1.1.dist-info/entry_points.txt +5 -0
- rc_foundry-0.1.1.dist-info/licenses/LICENSE.md +28 -0
- rf3/__init__.py +3 -0
- rf3/_version.py +33 -0
- rf3/alignment.py +79 -0
- rf3/callbacks/dump_validation_structures.py +101 -0
- rf3/callbacks/metrics_logging.py +324 -0
- rf3/chemical.py +1529 -0
- rf3/cli.py +77 -0
- rf3/data/cyclic_transform.py +78 -0
- rf3/data/extra_xforms.py +36 -0
- rf3/data/ground_truth_template.py +463 -0
- rf3/data/paired_msa.py +206 -0
- rf3/data/pipeline_utils.py +128 -0
- rf3/data/pipelines.py +558 -0
- rf3/diffusion_samplers/inference_sampler.py +222 -0
- rf3/inference.py +65 -0
- rf3/inference_engines/__init__.py +5 -0
- rf3/inference_engines/rf3.py +735 -0
- rf3/kinematics.py +354 -0
- rf3/loss/af3_confidence_loss.py +515 -0
- rf3/loss/af3_losses.py +655 -0
- rf3/loss/loss.py +179 -0
- rf3/metrics/chiral.py +179 -0
- rf3/metrics/clashing_chains.py +68 -0
- rf3/metrics/distogram.py +421 -0
- rf3/metrics/lddt.py +523 -0
- rf3/metrics/metadata.py +43 -0
- rf3/metrics/metric_utils.py +192 -0
- rf3/metrics/predicted_error.py +134 -0
- rf3/metrics/rasa.py +108 -0
- rf3/metrics/selected_distances.py +91 -0
- rf3/model/RF3.py +527 -0
- rf3/model/RF3_blocks.py +92 -0
- rf3/model/RF3_structure.py +303 -0
- rf3/model/layers/af3_auxiliary_heads.py +255 -0
- rf3/model/layers/af3_diffusion_transformer.py +544 -0
- rf3/model/layers/attention.py +313 -0
- rf3/model/layers/layer_utils.py +127 -0
- rf3/model/layers/mlff.py +118 -0
- rf3/model/layers/outer_product.py +59 -0
- rf3/model/layers/pairformer_layers.py +783 -0
- rf3/model/layers/structure_bias.py +56 -0
- rf3/scoring.py +1787 -0
- rf3/symmetry/resolve.py +284 -0
- rf3/train.py +194 -0
- rf3/trainers/rf3.py +570 -0
- rf3/util_module.py +47 -0
- rf3/utils/frames.py +109 -0
- rf3/utils/inference.py +665 -0
- rf3/utils/io.py +198 -0
- rf3/utils/loss.py +72 -0
- rf3/utils/predict_and_score.py +165 -0
- rf3/utils/predicted_error.py +673 -0
- rf3/utils/recycling.py +42 -0
- rf3/validate.py +140 -0
- rfd3/.gitignore +7 -0
- rfd3/Makefile +76 -0
- rfd3/__init__.py +12 -0
- rfd3/callbacks.py +66 -0
- rfd3/cli.py +41 -0
- rfd3/constants.py +212 -0
- rfd3/engine.py +543 -0
- rfd3/inference/datasets.py +193 -0
- rfd3/inference/input_parsing.py +1123 -0
- rfd3/inference/legacy_input_parsing.py +717 -0
- rfd3/inference/parsing.py +165 -0
- rfd3/inference/symmetry/atom_array.py +298 -0
- rfd3/inference/symmetry/checks.py +241 -0
- rfd3/inference/symmetry/contigs.py +63 -0
- rfd3/inference/symmetry/frames.py +355 -0
- rfd3/inference/symmetry/symmetry_utils.py +398 -0
- rfd3/metrics/design_metrics.py +465 -0
- rfd3/metrics/hbonds_hbplus_metrics.py +308 -0
- rfd3/metrics/hbonds_metrics.py +389 -0
- rfd3/metrics/losses.py +325 -0
- rfd3/metrics/metrics_utils.py +118 -0
- rfd3/metrics/sidechain_metrics.py +349 -0
- rfd3/model/RFD3.py +105 -0
- rfd3/model/RFD3_diffusion_module.py +387 -0
- rfd3/model/cfg_utils.py +81 -0
- rfd3/model/inference_sampler.py +635 -0
- rfd3/model/layers/attention.py +577 -0
- rfd3/model/layers/block_utils.py +580 -0
- rfd3/model/layers/blocks.py +777 -0
- rfd3/model/layers/chunked_pairwise.py +377 -0
- rfd3/model/layers/encoders.py +417 -0
- rfd3/model/layers/layer_utils.py +197 -0
- rfd3/model/layers/pairformer_layers.py +128 -0
- rfd3/run_inference.py +45 -0
- rfd3/testing/debug.py +139 -0
- rfd3/testing/debug_utils.py +73 -0
- rfd3/testing/testing_utils.py +356 -0
- rfd3/train.py +194 -0
- rfd3/trainer/dump_validation_structures.py +154 -0
- rfd3/trainer/fabric_trainer.py +923 -0
- rfd3/trainer/recycling.py +42 -0
- rfd3/trainer/rfd3.py +485 -0
- rfd3/trainer/trainer_utils.py +502 -0
- rfd3/transforms/conditioning_base.py +508 -0
- rfd3/transforms/conditioning_utils.py +200 -0
- rfd3/transforms/design_transforms.py +807 -0
- rfd3/transforms/dna_crop.py +523 -0
- rfd3/transforms/hbonds.py +407 -0
- rfd3/transforms/hbonds_hbplus.py +246 -0
- rfd3/transforms/ncaa_transforms.py +153 -0
- rfd3/transforms/pipelines.py +632 -0
- rfd3/transforms/ppi_transforms.py +541 -0
- rfd3/transforms/rasa.py +116 -0
- rfd3/transforms/symmetry.py +76 -0
- rfd3/transforms/training_conditions.py +552 -0
- rfd3/transforms/util_transforms.py +498 -0
- rfd3/transforms/virtual_atoms.py +305 -0
- rfd3/utils/inference.py +648 -0
- rfd3/utils/io.py +245 -0
- rfd3/utils/vizualize.py +276 -0
|
@@ -0,0 +1,356 @@
|
|
|
1
|
+
import copy
|
|
2
|
+
import getpass
|
|
3
|
+
import json
|
|
4
|
+
import logging
|
|
5
|
+
import os
|
|
6
|
+
import sys
|
|
7
|
+
import tempfile
|
|
8
|
+
from pathlib import Path
|
|
9
|
+
from unittest.mock import patch
|
|
10
|
+
|
|
11
|
+
import hydra
|
|
12
|
+
import numpy as np
|
|
13
|
+
import rootutils
|
|
14
|
+
import torch
|
|
15
|
+
from hydra import compose, initialize
|
|
16
|
+
from hydra.core.global_hydra import GlobalHydra
|
|
17
|
+
from lightning.fabric import seed_everything
|
|
18
|
+
from omegaconf import DictConfig
|
|
19
|
+
|
|
20
|
+
sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../../..")
|
|
21
|
+
sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../../../src")
|
|
22
|
+
|
|
23
|
+
import atomworks
|
|
24
|
+
from atomworks import parse
|
|
25
|
+
from atomworks.io.parser import STANDARD_PARSER_ARGS
|
|
26
|
+
from atomworks.io.utils.io_utils import to_cif_file
|
|
27
|
+
from hydra.core.hydra_config import HydraConfig
|
|
28
|
+
from omegaconf import OmegaConf
|
|
29
|
+
from rfd3.inference.input_parsing import (
|
|
30
|
+
DesignInputSpecification,
|
|
31
|
+
create_atom_array_from_design_specification,
|
|
32
|
+
)
|
|
33
|
+
from rfd3.transforms.pipelines import (
|
|
34
|
+
build_atom14_base_pipeline,
|
|
35
|
+
)
|
|
36
|
+
|
|
37
|
+
from foundry.hydra.resolvers import register_resolvers
|
|
38
|
+
|
|
39
|
+
register_resolvers()
|
|
40
|
+
|
|
41
|
+
print("Atomworks version:", atomworks.__version__)
|
|
42
|
+
|
|
43
|
+
logging.basicConfig(level=logging.INFO)
|
|
44
|
+
logger = logging.getLogger(__name__)
|
|
45
|
+
|
|
46
|
+
# Same as train.py
|
|
47
|
+
rootutils.setup_root(__file__ + "/../..", indicator=".project-root", pythonpath=True)
|
|
48
|
+
_config_path = "pkg://configs"
|
|
49
|
+
|
|
50
|
+
print("_config_path:", _config_path)
|
|
51
|
+
print(os.getcwd())
|
|
52
|
+
|
|
53
|
+
# Set globals
|
|
54
|
+
DEBUG_MODE_USERS = ["jbutch"]
|
|
55
|
+
DEBUG_MODE = getpass.getuser() in DEBUG_MODE_USERS
|
|
56
|
+
seed = 42
|
|
57
|
+
np.random.seed(seed)
|
|
58
|
+
torch.manual_seed(seed)
|
|
59
|
+
|
|
60
|
+
BASE_HYDRA_OVERRIDES = ["hydra.job.num=0", "hydra.runtime.output_dir=/dev/null"]
|
|
61
|
+
|
|
62
|
+
PATH_TO_REPO_ROOT = os.path.dirname(os.path.abspath(__file__)) + "/../../.."
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
def load_inference_cfg():
|
|
66
|
+
GlobalHydra.instance().clear() # Clear existing context
|
|
67
|
+
with patch.object(sys, "argv", ["compile_cfg"]):
|
|
68
|
+
initialize(config_path=_config_path, version_base="1.3", job_name="engine_cfg")
|
|
69
|
+
cfg = compose(
|
|
70
|
+
config_name="inference",
|
|
71
|
+
overrides=["ckpt_path=null", "out_dir=null", "inputs=null"]
|
|
72
|
+
+ BASE_HYDRA_OVERRIDES,
|
|
73
|
+
return_hydra_config=True,
|
|
74
|
+
)
|
|
75
|
+
HydraConfig.instance().set_config(cfg)
|
|
76
|
+
return cfg
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
def load_train_or_val_cfg(name="pretrain", is_val_cfg=False):
|
|
80
|
+
GlobalHydra.instance().clear() # Clear existing context
|
|
81
|
+
with patch.object(sys, "argv", ["compile_cfg"]):
|
|
82
|
+
initialize(config_path=_config_path, version_base="1.3", job_name="trainer_cfg")
|
|
83
|
+
cfg = compose(
|
|
84
|
+
config_name="validate" if is_val_cfg else "train",
|
|
85
|
+
overrides=["experiment={}".format(name)] + BASE_HYDRA_OVERRIDES,
|
|
86
|
+
return_hydra_config=True,
|
|
87
|
+
)
|
|
88
|
+
HydraConfig.instance().set_config(cfg)
|
|
89
|
+
return cfg
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
TEST_CFG_INFERENCE = load_inference_cfg()
|
|
93
|
+
TEST_CFG_TRAIN = load_train_or_val_cfg()
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
##########################################################################################
|
|
97
|
+
# Additional utilities
|
|
98
|
+
##########################################################################################
|
|
99
|
+
|
|
100
|
+
DIRS = [
|
|
101
|
+
os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../../tests"),
|
|
102
|
+
os.path.join(os.path.dirname(os.path.abspath(__file__))),
|
|
103
|
+
TEST_CFG_TRAIN.paths.data.design_benchmark_data_dir,
|
|
104
|
+
]
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
def load_test_json():
|
|
108
|
+
test_files = ["demo.json", "demo_extended.json", "tests.json"]
|
|
109
|
+
test_files += [
|
|
110
|
+
"mcsa_41.json",
|
|
111
|
+
"rfd_unindexed.json",
|
|
112
|
+
"sym_tests.json",
|
|
113
|
+
"brk_regression.json",
|
|
114
|
+
]
|
|
115
|
+
test_json_data = {}
|
|
116
|
+
for dir in DIRS:
|
|
117
|
+
test_data_dir = Path(dir, "test_data")
|
|
118
|
+
|
|
119
|
+
for test_file in test_files:
|
|
120
|
+
test_json = test_data_dir / test_file
|
|
121
|
+
if not test_json.exists():
|
|
122
|
+
print(f"Test json file {test_json} does not exist.")
|
|
123
|
+
continue
|
|
124
|
+
|
|
125
|
+
print(f"Loading test json file {test_json}...")
|
|
126
|
+
with open(test_json, "r") as f:
|
|
127
|
+
test_json_data_ = json.load(f)
|
|
128
|
+
|
|
129
|
+
print("Found test json data:", test_json_data.keys())
|
|
130
|
+
assert test_json_data_, "Test json data is empty."
|
|
131
|
+
|
|
132
|
+
nonexistent_keys = []
|
|
133
|
+
for k, d in test_json_data_.items():
|
|
134
|
+
if "input" in d:
|
|
135
|
+
file = test_data_dir / d["input"]
|
|
136
|
+
test_json_data_[k]["input"] = file
|
|
137
|
+
if not file.exists():
|
|
138
|
+
nonexistent_keys.append(k)
|
|
139
|
+
|
|
140
|
+
if nonexistent_keys:
|
|
141
|
+
test_json_data_ = {
|
|
142
|
+
k: v
|
|
143
|
+
for k, v in test_json_data_.items()
|
|
144
|
+
if k not in nonexistent_keys
|
|
145
|
+
}
|
|
146
|
+
print(f"Test input files {nonexistent_keys} do not exist for test.")
|
|
147
|
+
print("Skipping the examples.")
|
|
148
|
+
|
|
149
|
+
test_json_data = test_json_data | test_json_data_
|
|
150
|
+
return test_json_data
|
|
151
|
+
|
|
152
|
+
|
|
153
|
+
TEST_JSON_DATA = load_test_json()
|
|
154
|
+
assert TEST_JSON_DATA, "No test json data loaded!"
|
|
155
|
+
|
|
156
|
+
|
|
157
|
+
def filter_inference_args(args):
|
|
158
|
+
return {k: v for k, v in args.items() if k in valid_keys_}
|
|
159
|
+
|
|
160
|
+
|
|
161
|
+
def instantiate_example(args, is_inference=True):
|
|
162
|
+
"""Saves and reloads atom array for pipeline forwarding"""
|
|
163
|
+
args = copy.deepcopy(args)
|
|
164
|
+
|
|
165
|
+
if is_inference:
|
|
166
|
+
input = DesignInputSpecification.safe_init(**args).to_pipeline_input(
|
|
167
|
+
example_id=args.get("example_id", "example")
|
|
168
|
+
)
|
|
169
|
+
else:
|
|
170
|
+
file = args.get("input")
|
|
171
|
+
if file is None:
|
|
172
|
+
# args = filter_inference_args(args)
|
|
173
|
+
atom_array, spec = create_atom_array_from_design_specification(**args)
|
|
174
|
+
outdir = tempfile.TemporaryDirectory()
|
|
175
|
+
outdir = Path(outdir.name)
|
|
176
|
+
tmpfile = outdir / "tmp.cif"
|
|
177
|
+
to_cif_file(
|
|
178
|
+
atom_array,
|
|
179
|
+
tmpfile,
|
|
180
|
+
id="example",
|
|
181
|
+
extra_fields="all",
|
|
182
|
+
)
|
|
183
|
+
file = tmpfile
|
|
184
|
+
result_dict = parse(
|
|
185
|
+
filename=file,
|
|
186
|
+
build_assembly=("1",),
|
|
187
|
+
**STANDARD_PARSER_ARGS,
|
|
188
|
+
)
|
|
189
|
+
input = {
|
|
190
|
+
"atom_array": result_dict["assemblies"]["1"][0], # First model
|
|
191
|
+
"chain_info": result_dict["chain_info"],
|
|
192
|
+
"ligand_info": result_dict["ligand_info"],
|
|
193
|
+
"metadata": result_dict["metadata"],
|
|
194
|
+
}
|
|
195
|
+
return input
|
|
196
|
+
|
|
197
|
+
|
|
198
|
+
# Pre-build pipeline | in debug mode, exclude random augmentation so that the noise scales are more clear
|
|
199
|
+
|
|
200
|
+
|
|
201
|
+
def build_pipelines(
|
|
202
|
+
cfg_name="pretrain",
|
|
203
|
+
train_dataset_nested_keys=["pdb", "sub_datasets", "interface"],
|
|
204
|
+
val_dataset_nested_keys=["unconditional"],
|
|
205
|
+
composed_config: DictConfig | None = None,
|
|
206
|
+
standardize_crop_size: bool = True,
|
|
207
|
+
**transform_kwargs,
|
|
208
|
+
):
|
|
209
|
+
pipes = {}
|
|
210
|
+
for is_validation in [True, False]:
|
|
211
|
+
if composed_config is None:
|
|
212
|
+
config = load_train_or_val_cfg(name=cfg_name, is_val_cfg=is_validation)
|
|
213
|
+
else:
|
|
214
|
+
config = composed_config
|
|
215
|
+
dataset_config = config["datasets"]["val" if is_validation else "train"]
|
|
216
|
+
for key in (
|
|
217
|
+
val_dataset_nested_keys if is_validation else train_dataset_nested_keys
|
|
218
|
+
):
|
|
219
|
+
dataset_config = dataset_config[key]
|
|
220
|
+
transforms_config = dataset_config["dataset"]["transform"]
|
|
221
|
+
|
|
222
|
+
OmegaConf.set_struct(transforms_config, False)
|
|
223
|
+
transforms_config.return_atom_array = True
|
|
224
|
+
transforms_config.seed = seed
|
|
225
|
+
transforms_config.pop("_target_")
|
|
226
|
+
|
|
227
|
+
if standardize_crop_size:
|
|
228
|
+
transforms_config.crop_size = 9999
|
|
229
|
+
transforms_config.max_atoms_in_crop = 99999
|
|
230
|
+
|
|
231
|
+
for k, v in transform_kwargs.items():
|
|
232
|
+
transforms_config[k] = v
|
|
233
|
+
|
|
234
|
+
pipes |= {
|
|
235
|
+
is_validation: build_atom14_base_pipeline(
|
|
236
|
+
**(transforms_config),
|
|
237
|
+
)
|
|
238
|
+
}
|
|
239
|
+
|
|
240
|
+
if DEBUG_MODE:
|
|
241
|
+
from rfd3.testing.debug_utils import to_debug_pipe
|
|
242
|
+
|
|
243
|
+
for is_inference in [True, False]:
|
|
244
|
+
pipes[is_inference] = to_debug_pipe(pipes[is_inference])
|
|
245
|
+
return pipes
|
|
246
|
+
|
|
247
|
+
|
|
248
|
+
def get_train_dataloader(cfg_name="pretrain", seed=42):
|
|
249
|
+
print("This function is deprecated!")
|
|
250
|
+
cfg = load_train_or_val_cfg(name=cfg_name, is_val_cfg=False)
|
|
251
|
+
from foundry.utils.datasets import (
|
|
252
|
+
assemble_distributed_loader,
|
|
253
|
+
recursively_instantiate_datasets_and_samplers,
|
|
254
|
+
)
|
|
255
|
+
|
|
256
|
+
seed_everything(seed, workers=True, verbose=True)
|
|
257
|
+
|
|
258
|
+
# Number of examples per epoch (accross all GPUs)
|
|
259
|
+
# (We must sample this many indices from our sampler)
|
|
260
|
+
n_examples_per_epoch = cfg.trainer.n_examples_per_epoch
|
|
261
|
+
|
|
262
|
+
# HACK: I'm getting OOMS using the full parquets, so overwrite with these
|
|
263
|
+
# NOTE: If using this test, check that your desired use-case is represented in these mini-datasets!
|
|
264
|
+
try:
|
|
265
|
+
PATH_TO_DATADIR_ROOT = "/projects/ml/datahub/dfs/af3_splits/2024_12_16/"
|
|
266
|
+
if "interface" in cfg.datasets.train.pdb.sub_datasets:
|
|
267
|
+
path = PATH_TO_DATADIR_ROOT + "/interfaces_df_train.parquet"
|
|
268
|
+
assert os.path.exists(path), f"Path {path} does not exist!"
|
|
269
|
+
cfg.datasets.train.pdb.sub_datasets.interface.dataset.dataset.data = path
|
|
270
|
+
if "pn_unit" in cfg.datasets.train.pdb.sub_datasets:
|
|
271
|
+
path = PATH_TO_DATADIR_ROOT + "/pn_units_df_train.parquet"
|
|
272
|
+
assert os.path.exists(path), f"Path {path} does not exist!"
|
|
273
|
+
cfg.datasets.train.pdb.sub_datasets.pn_unit.dataset.dataset.data = path
|
|
274
|
+
except Exception as e:
|
|
275
|
+
print("Could not overwrite dataset paths for train dataloader:", e)
|
|
276
|
+
|
|
277
|
+
# ... build the train dataset
|
|
278
|
+
# assert (
|
|
279
|
+
# "train" in cfg.datasets and cfg.datasets.train
|
|
280
|
+
# ), "No 'train' dataloader configuration provided! If only performing validation, use `validate.py` instead."
|
|
281
|
+
dataset_and_sampler = recursively_instantiate_datasets_and_samplers(
|
|
282
|
+
cfg.datasets.train
|
|
283
|
+
)
|
|
284
|
+
|
|
285
|
+
train_dataset = dataset_and_sampler["dataset"]
|
|
286
|
+
train_sampler = dataset_and_sampler["sampler"]
|
|
287
|
+
|
|
288
|
+
train_loader = assemble_distributed_loader(
|
|
289
|
+
dataset=train_dataset,
|
|
290
|
+
sampler=train_sampler,
|
|
291
|
+
rank=0,
|
|
292
|
+
world_size=1,
|
|
293
|
+
n_examples_per_epoch=n_examples_per_epoch,
|
|
294
|
+
loader_cfg=cfg.dataloader["train"],
|
|
295
|
+
)
|
|
296
|
+
|
|
297
|
+
return train_loader
|
|
298
|
+
|
|
299
|
+
|
|
300
|
+
def assert_tensor_or_array_equal(actual, expected, error_msg: str):
|
|
301
|
+
"""Compare tensors/arrays with dtype-aware tolerances."""
|
|
302
|
+
if torch.is_tensor(actual):
|
|
303
|
+
if actual.dtype == torch.bool or actual.dtype in (torch.int32, torch.int64):
|
|
304
|
+
assert torch.equal(actual, expected), error_msg
|
|
305
|
+
else:
|
|
306
|
+
assert torch.allclose(
|
|
307
|
+
actual, expected, atol=1e-5, rtol=1e-5, equal_nan=True
|
|
308
|
+
), error_msg
|
|
309
|
+
elif isinstance(actual, np.ndarray):
|
|
310
|
+
if actual.dtype.kind in {"U", "S"}:
|
|
311
|
+
assert np.array_equal(actual, expected), error_msg
|
|
312
|
+
elif actual.dtype == bool or np.issubdtype(actual.dtype, np.integer):
|
|
313
|
+
assert np.array_equal(actual, expected), error_msg
|
|
314
|
+
else:
|
|
315
|
+
assert np.allclose(
|
|
316
|
+
actual, expected, atol=1e-5, rtol=1e-5, equal_nan=True
|
|
317
|
+
), error_msg
|
|
318
|
+
else:
|
|
319
|
+
assert actual == expected, error_msg
|
|
320
|
+
|
|
321
|
+
|
|
322
|
+
def _assert_tensor_or_array_equal(actual, expected, error_msg: str):
|
|
323
|
+
"""Backwards-compatible alias for legacy imports."""
|
|
324
|
+
return assert_tensor_or_array_equal(actual, expected, error_msg)
|
|
325
|
+
|
|
326
|
+
|
|
327
|
+
def build_pipes_constant():
|
|
328
|
+
cfgs = {
|
|
329
|
+
True: copy.deepcopy(
|
|
330
|
+
TEST_CFG_TRAIN.datasets.val.unconditional.dataset.transform
|
|
331
|
+
),
|
|
332
|
+
False: copy.deepcopy(
|
|
333
|
+
TEST_CFG_TRAIN.datasets.train.pdb.sub_datasets.pn_unit.dataset.transform
|
|
334
|
+
),
|
|
335
|
+
}
|
|
336
|
+
# Standardize crop size and atom array return
|
|
337
|
+
pipes = {}
|
|
338
|
+
for is_inference, cfg in cfgs.items():
|
|
339
|
+
OmegaConf.set_struct(cfg, False)
|
|
340
|
+
cfg.seed = seed
|
|
341
|
+
cfg.return_atom_array = True
|
|
342
|
+
cfg.crop_size = 9999
|
|
343
|
+
cfg.max_atoms_in_crop = 99999
|
|
344
|
+
pipes[is_inference] = hydra.utils.instantiate(cfg)
|
|
345
|
+
# pipes[is_inference] = build_atom14_base_pipeline(
|
|
346
|
+
# **cfg,
|
|
347
|
+
# )
|
|
348
|
+
if DEBUG_MODE:
|
|
349
|
+
from rfd3.testing.debug_utils import to_debug_pipe
|
|
350
|
+
|
|
351
|
+
for is_inference in [True, False]:
|
|
352
|
+
pipes[is_inference] = to_debug_pipe(pipes[is_inference])
|
|
353
|
+
return pipes
|
|
354
|
+
|
|
355
|
+
|
|
356
|
+
PIPES = build_pipes_constant()
|
rfd3/train.py
ADDED
|
@@ -0,0 +1,194 @@
|
|
|
1
|
+
#!/usr/bin/env -S /bin/sh -c '"$(dirname "$0")/../../../../.ipd/shebang/rfd3_exec.sh" "$0" "$@"'
|
|
2
|
+
|
|
3
|
+
import logging
|
|
4
|
+
import os
|
|
5
|
+
|
|
6
|
+
import hydra
|
|
7
|
+
import rootutils
|
|
8
|
+
from dotenv import load_dotenv
|
|
9
|
+
from omegaconf import DictConfig
|
|
10
|
+
|
|
11
|
+
from foundry.utils.logging import suppress_warnings
|
|
12
|
+
from foundry.utils.weights import CheckpointConfig
|
|
13
|
+
|
|
14
|
+
# Setup root dir and environment variables (more info: https://github.com/ashleve/rootutils)
|
|
15
|
+
# NOTE: Sets the `PROJECT_ROOT` environment variable to the root directory of the project (where `.project-root` is located)
|
|
16
|
+
rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
|
|
17
|
+
|
|
18
|
+
load_dotenv(override=True)
|
|
19
|
+
|
|
20
|
+
_config_path = os.path.join(os.environ["PROJECT_ROOT"], "models/rfd3/configs")
|
|
21
|
+
|
|
22
|
+
_spawning_process_logger = logging.getLogger(__name__)
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
@hydra.main(config_path=_config_path, config_name="train", version_base="1.3")
|
|
26
|
+
def train(cfg: DictConfig) -> None:
|
|
27
|
+
# ==============================================================================
|
|
28
|
+
# Import dependencies and resolve Hydra configuration
|
|
29
|
+
# ==============================================================================
|
|
30
|
+
|
|
31
|
+
_spawning_process_logger.info("Importing dependencies...")
|
|
32
|
+
|
|
33
|
+
# Lazy imports to make config generation fast
|
|
34
|
+
import torch
|
|
35
|
+
from lightning.fabric import seed_everything
|
|
36
|
+
from lightning.fabric.loggers import Logger
|
|
37
|
+
|
|
38
|
+
# If training on DIGS L40, set precision of matrix multiplication to balance speed and accuracy
|
|
39
|
+
# Reference: https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
|
|
40
|
+
torch.set_float32_matmul_precision("medium")
|
|
41
|
+
|
|
42
|
+
from foundry.callbacks.callback import BaseCallback # noqa
|
|
43
|
+
from foundry.utils.instantiators import instantiate_loggers, instantiate_callbacks # noqa
|
|
44
|
+
from foundry.utils.logging import (
|
|
45
|
+
print_config_tree,
|
|
46
|
+
log_hyperparameters_with_all_loggers,
|
|
47
|
+
) # noqa
|
|
48
|
+
from foundry.utils.ddp import RankedLogger # noqa
|
|
49
|
+
from foundry.utils.ddp import is_rank_zero, set_accelerator_based_on_availability # noqa
|
|
50
|
+
from foundry.utils.datasets import (
|
|
51
|
+
recursively_instantiate_datasets_and_samplers,
|
|
52
|
+
assemble_distributed_loader,
|
|
53
|
+
subset_dataset_to_example_ids,
|
|
54
|
+
assemble_val_loader_dict,
|
|
55
|
+
) # noqa
|
|
56
|
+
|
|
57
|
+
set_accelerator_based_on_availability(cfg)
|
|
58
|
+
|
|
59
|
+
ranked_logger = RankedLogger(__name__, rank_zero_only=True)
|
|
60
|
+
_spawning_process_logger.info("Completed dependency imports ...")
|
|
61
|
+
|
|
62
|
+
# ... print the configuration tree (NOTE: Only prints for rank 0)
|
|
63
|
+
print_config_tree(cfg, resolve=True)
|
|
64
|
+
|
|
65
|
+
# ==============================================================================
|
|
66
|
+
# Logging and Callback instantiation
|
|
67
|
+
# ==============================================================================
|
|
68
|
+
|
|
69
|
+
# Reduce the logging level for all dataset and sampler loggers (unless rank 0)
|
|
70
|
+
# We will still see messages from Rank 0; they are identical, since all ranks load and sample from the same datasets
|
|
71
|
+
if not is_rank_zero():
|
|
72
|
+
dataset_logger = logging.getLogger("datasets")
|
|
73
|
+
sampler_logger = logging.getLogger("atomworks.ml.samplers")
|
|
74
|
+
dataset_logger.setLevel(logging.WARNING)
|
|
75
|
+
sampler_logger.setLevel(logging.ERROR)
|
|
76
|
+
|
|
77
|
+
# ... seed everything (NOTE: By setting `workers=True`, we ensure that the dataloaders are seeded as well)
|
|
78
|
+
# (`PL_GLOBAL_SEED` environment varaible will be passed to the spawned subprocessed; e.g., through `ddp_spawn` backend)
|
|
79
|
+
if cfg.get("seed"):
|
|
80
|
+
ranked_logger.info(f"Seeding everything with seed={cfg.seed}...")
|
|
81
|
+
seed_everything(cfg.seed, workers=True, verbose=True)
|
|
82
|
+
else:
|
|
83
|
+
ranked_logger.warning("No seed provided - Not seeding anything!")
|
|
84
|
+
|
|
85
|
+
ranked_logger.info("Instantiating loggers...")
|
|
86
|
+
loggers: list[Logger] = instantiate_loggers(cfg.get("logger"))
|
|
87
|
+
|
|
88
|
+
ranked_logger.info("Instantiating callbacks...")
|
|
89
|
+
callbacks: list[BaseCallback] = instantiate_callbacks(cfg.get("callbacks"))
|
|
90
|
+
|
|
91
|
+
# ==============================================================================
|
|
92
|
+
# Trainer and model instantiation
|
|
93
|
+
# ==============================================================================
|
|
94
|
+
|
|
95
|
+
# ... instantiate the trainer
|
|
96
|
+
ranked_logger.info("Instantiating trainer...")
|
|
97
|
+
trainer = hydra.utils.instantiate(
|
|
98
|
+
cfg.trainer,
|
|
99
|
+
loggers=loggers or None,
|
|
100
|
+
callbacks=callbacks or None,
|
|
101
|
+
_convert_="partial",
|
|
102
|
+
_recursive_=False,
|
|
103
|
+
)
|
|
104
|
+
# (Store the Hydra configuration in the trainer state)
|
|
105
|
+
trainer.initialize_or_update_trainer_state({"train_cfg": cfg})
|
|
106
|
+
|
|
107
|
+
# ... spawn processes for distributed training
|
|
108
|
+
# (We spawn here, rather than within `fit`, so we can use Fabric's `init_module` to efficiently initialize the model on the appropriate device)
|
|
109
|
+
ranked_logger.info(
|
|
110
|
+
f"Spawning {trainer.fabric.world_size} processes from {trainer.fabric.global_rank}..."
|
|
111
|
+
)
|
|
112
|
+
trainer.fabric.launch()
|
|
113
|
+
|
|
114
|
+
# ... construct the model
|
|
115
|
+
trainer.construct_model()
|
|
116
|
+
|
|
117
|
+
# ... construct the optimizer and schedule (which requires the model to be constructed)
|
|
118
|
+
trainer.construct_optimizer()
|
|
119
|
+
trainer.construct_scheduler()
|
|
120
|
+
|
|
121
|
+
# ==============================================================================
|
|
122
|
+
# Dataset instantiation
|
|
123
|
+
# ==============================================================================
|
|
124
|
+
|
|
125
|
+
# Number of examples per epoch (accross all GPUs)
|
|
126
|
+
# (We must sample this many indices from our sampler)
|
|
127
|
+
n_examples_per_epoch = cfg.trainer.n_examples_per_epoch
|
|
128
|
+
|
|
129
|
+
# ... build the train dataset
|
|
130
|
+
assert (
|
|
131
|
+
"train" in cfg.datasets and cfg.datasets.train
|
|
132
|
+
), "No 'train' dataloader configuration provided! If only performing validation, use `validate.py` instead."
|
|
133
|
+
dataset_and_sampler = recursively_instantiate_datasets_and_samplers(
|
|
134
|
+
cfg.datasets.train
|
|
135
|
+
)
|
|
136
|
+
train_dataset, train_sampler = (
|
|
137
|
+
dataset_and_sampler["dataset"],
|
|
138
|
+
dataset_and_sampler["sampler"],
|
|
139
|
+
)
|
|
140
|
+
|
|
141
|
+
# ... compose the train loader
|
|
142
|
+
if "subset_to_example_ids" in cfg.datasets:
|
|
143
|
+
# Backdoor for debugging and overfitting: subset the dataset to a specific set of example IDs
|
|
144
|
+
train_dataset = subset_dataset_to_example_ids(
|
|
145
|
+
train_dataset, cfg.datasets.subset_to_example_ids
|
|
146
|
+
)
|
|
147
|
+
train_sampler = None # Sampler is no longer valid, since we are using a subset of the dataset
|
|
148
|
+
|
|
149
|
+
train_loader = assemble_distributed_loader(
|
|
150
|
+
dataset=train_dataset,
|
|
151
|
+
sampler=train_sampler,
|
|
152
|
+
rank=trainer.fabric.global_rank,
|
|
153
|
+
world_size=trainer.fabric.world_size,
|
|
154
|
+
n_examples_per_epoch=n_examples_per_epoch,
|
|
155
|
+
loader_cfg=cfg.dataloader["train"],
|
|
156
|
+
)
|
|
157
|
+
|
|
158
|
+
# ... compose the validation loader(s)
|
|
159
|
+
if "val" in cfg.datasets and cfg.datasets.val:
|
|
160
|
+
val_loaders = assemble_val_loader_dict(
|
|
161
|
+
cfg=cfg.datasets.val,
|
|
162
|
+
rank=trainer.fabric.global_rank,
|
|
163
|
+
world_size=trainer.fabric.world_size,
|
|
164
|
+
loader_cfg=cfg.dataloader["val"],
|
|
165
|
+
)
|
|
166
|
+
else:
|
|
167
|
+
ranked_logger.warning("No validation datasets provided! Skipping validation...")
|
|
168
|
+
val_loaders = None
|
|
169
|
+
|
|
170
|
+
ranked_logger.info("Logging hyperparameters...")
|
|
171
|
+
log_hyperparameters_with_all_loggers(
|
|
172
|
+
trainer=trainer, cfg=cfg, model=trainer.state["model"]
|
|
173
|
+
)
|
|
174
|
+
|
|
175
|
+
# ... load the checkpoint configuration
|
|
176
|
+
ckpt_config = None
|
|
177
|
+
if "ckpt_config" in cfg and cfg.ckpt_config:
|
|
178
|
+
ckpt_config = hydra.utils.instantiate(cfg.ckpt_config)
|
|
179
|
+
elif "ckpt_path" in cfg and cfg.ckpt_path:
|
|
180
|
+
# Just a checkpoint path
|
|
181
|
+
if cfg.ckpt_path is not None:
|
|
182
|
+
ckpt_config = CheckpointConfig(path=cfg.ckpt_path)
|
|
183
|
+
|
|
184
|
+
# ... train the model
|
|
185
|
+
ranked_logger.info("Training model...")
|
|
186
|
+
|
|
187
|
+
with suppress_warnings():
|
|
188
|
+
trainer.fit(
|
|
189
|
+
train_loader=train_loader, val_loaders=val_loaders, ckpt_config=ckpt_config
|
|
190
|
+
)
|
|
191
|
+
|
|
192
|
+
|
|
193
|
+
if __name__ == "__main__":
|
|
194
|
+
train()
|
|
@@ -0,0 +1,154 @@
|
|
|
1
|
+
import functools
|
|
2
|
+
from os import PathLike
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
|
|
5
|
+
from atomworks.ml.example_id import parse_example_id
|
|
6
|
+
from beartype.typing import Any, Literal
|
|
7
|
+
from rfd3.utils.io import (
|
|
8
|
+
dump_metadata,
|
|
9
|
+
dump_structures,
|
|
10
|
+
dump_trajectories,
|
|
11
|
+
)
|
|
12
|
+
|
|
13
|
+
from foundry.callbacks.callback import BaseCallback
|
|
14
|
+
from foundry.utils.ddp import RankedLogger
|
|
15
|
+
|
|
16
|
+
ranked_logger = RankedLogger(__name__, rank_zero_only=True)
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class DumpValidationStructuresCallback(BaseCallback):
    """Dump predicted structures and/or diffusion trajectories during validation.

    Output base paths have the layout
    ``save_dir / <"predictions" | "trajectories"> / epoch_<N> / <dataset_name> / <identifier><extra>``.
    """

    def __init__(
        self,
        save_dir: PathLike,
        dump_predictions: bool = False,
        dump_prediction_metadata_json: bool = True,
        one_model_per_file: bool = False,
        dump_trajectories: bool = False,
        dump_denoised_trajectories_only: bool = False,
        dump_every_n: int = 1,
        align_trajectories: bool = True,
        verbose: bool = True,
        extra_fields: list[str] | Literal["all"] | None = None,
    ):
        """
        Args:
            save_dir: Root directory under which all dumps are written.
            dump_predictions: Whether to dump structures (CIF files) after validation batches.
            dump_prediction_metadata_json: Whether to also dump the batch's prediction metadata
                next to the dumped predictions (only when the metadata is non-empty).
            one_model_per_file: If True, write each structure within a diffusion batch to its own CIF files. If False,
                include each structure within a diffusion batch as a separate model within one CIF file.
            dump_trajectories: Whether to dump denoising trajectories after validation batches.
            dump_denoised_trajectories_only: If True, dump only the denoised trajectory
                (``X_denoised_L_traj``); otherwise also dump the noisy one (``X_noisy_L_traj``).
            dump_every_n: Dump validation structures every n epochs. Must be >= 1 when dumping is enabled.
            align_trajectories: Forwarded as ``align_structures`` to ``dump_trajectories``.
            verbose: If True, log the destination of every dump at INFO level.
            extra_fields: List of extra annotations to include in the saved CIF file, or ``"all"``.
                Defaults to no extra fields.

        Raises:
            ValueError: If dumping is enabled and ``dump_every_n`` is smaller than 1
                (it is used as a modulus on the epoch index).
        """
        super().__init__()
        if (dump_predictions or dump_trajectories) and dump_every_n < 1:
            raise ValueError(f"dump_every_n must be >= 1, got {dump_every_n}")

        self.save_dir = Path(save_dir)
        self.dump_predictions = dump_predictions
        self.dump_prediction_metadata_json = dump_prediction_metadata_json
        self.dump_trajectories = dump_trajectories
        self.one_model_per_file = one_model_per_file
        self.dump_denoised_trajectories_only = dump_denoised_trajectories_only
        self.dump_every_n = dump_every_n
        self.align_trajectories = align_trajectories
        self.verbose = verbose
        # Fresh list per instance: avoids sharing a mutable default between callbacks.
        self.extra_fields = [] if extra_fields is None else extra_fields

    def _build_path_from_example_id(
        self,
        example_id,
        dir: str,
        extra: str = "",
        epoch: int | None = None,
        dataset_name: str = "",
    ) -> Path:
        """Build the base output path for a training or validation example.

        The identifier is ``<pdb_id>_<assembly_id>`` when ``parse_example_id`` succeeds,
        and the raw ``example_id`` otherwise. The ``epoch_<epoch>`` component is omitted
        only when ``epoch`` is None (epoch 0 yields ``epoch_0``).

        Args:
            example_id: Example identifier from the batch.
            dir: Top-level sub-directory of ``save_dir`` (e.g. ``"predictions"``).
            extra: Suffix appended to the identifier (e.g. ``"_denoised"``).
            epoch: Epoch index, or None to omit the epoch directory.
            dataset_name: Dataset sub-directory; an empty string adds no component.

        Returns:
            The assembled base path.
        """
        try:
            # ... try to extract the PDB ID and assembly ID from the example ID
            parsed_id = parse_example_id(example_id)
            identifier = f"{parsed_id['pdb_id']}_{parsed_id['assembly_id']}"
        except (KeyError, ValueError):
            # ... if parsing fails, fall back to the original example ID
            identifier = example_id

        base = self.save_dir / dir
        # `is not None` (not truthiness) so that epoch 0 gets its own epoch_0 directory.
        if epoch is not None:
            base = base / f"epoch_{epoch}"
        # pathlib ignores empty components, so dataset_name="" adds no directory level.
        return base / dataset_name / f"{identifier}{extra}"

    def on_validation_batch_end(
        self,
        *,
        trainer,
        outputs: dict,
        batch: Any,
        dataset_name: str,
        **_,
    ):
        """Dump predictions and/or trajectories for a single validation batch.

        Does nothing unless ``dump_predictions`` or ``dump_trajectories`` is set and the
        current epoch is a multiple of ``dump_every_n``.

        Args:
            trainer: Trainer exposing ``state["current_epoch"]``.
            outputs: Validation outputs. Must contain ``network_output``;
                ``predicted_atom_array_stack`` is required only when dumping predictions,
                and ``prediction_metadata`` is optional.
            batch: Validation batch; only ``batch[0]`` is used (batch size 1 is assumed).
            dataset_name: Name of the validation dataset, used as a path component.
            **_: Any further callback keyword arguments are ignored.
        """
        if not (self.dump_predictions or self.dump_trajectories):
            return  # Nothing to do

        current_epoch = trainer.state["current_epoch"]
        if current_epoch % self.dump_every_n != 0:
            ranked_logger.debug(
                f"Skipping validation batch dump at epoch {current_epoch} "
                f"(not every {self.dump_every_n} epochs)"
            )
            return

        assert (
            "network_output" in outputs
        ), "Validation outputs must contain `network_output` to dump structures!"

        network_output = outputs["network_output"]
        example = batch[0]  # Assume batch size = 1

        # Bind the per-example arguments once; callers only vary `dir` and `extra`.
        build_path = functools.partial(
            self._build_path_from_example_id,
            example_id=example["example_id"],
            epoch=current_epoch,
            dataset_name=dataset_name,
        )

        if self.dump_predictions:
            predictions_path = build_path(dir="predictions")
            dump_structures(
                atom_arrays=outputs["predicted_atom_array_stack"],
                base_path=predictions_path,
                one_model_per_file=self.one_model_per_file,
                extra_fields=self.extra_fields,
            )
            if self.verbose:
                ranked_logger.info(f"Dumped validation predictions to {predictions_path}")

            # Metadata is optional in `outputs`; only dump it when present and non-empty.
            prediction_metadata = outputs.get("prediction_metadata")
            if self.dump_prediction_metadata_json and prediction_metadata:
                dump_metadata(
                    prediction_metadata=prediction_metadata,
                    base_path=predictions_path,
                    one_model_per_file=self.one_model_per_file,
                )
                if self.verbose:
                    ranked_logger.info(f"Dumped validation metadata to {predictions_path}")

        if self.dump_trajectories:
            denoised_path = build_path(dir="trajectories", extra="_denoised")
            dump_trajectories(
                trajectory_list=network_output["X_denoised_L_traj"],
                atom_array=example["atom_array"],
                base_path=denoised_path,
                align_structures=self.align_trajectories,
            )
            if not self.dump_denoised_trajectories_only:
                dump_trajectories(
                    trajectory_list=network_output["X_noisy_L_traj"],
                    atom_array=example["atom_array"],
                    base_path=build_path(dir="trajectories", extra="_noisy"),
                    align_structures=self.align_trajectories,
                )

            if self.verbose:
                ranked_logger.info(f"Dumped validation trajectories to {denoised_path}")
|