rc-foundry 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- foundry/__init__.py +57 -0
- foundry/callbacks/__init__.py +5 -0
- foundry/callbacks/callback.py +116 -0
- foundry/callbacks/health_logging.py +419 -0
- foundry/callbacks/metrics_logging.py +211 -0
- foundry/callbacks/timing_logging.py +67 -0
- foundry/callbacks/train_logging.py +278 -0
- foundry/common.py +108 -0
- foundry/constants.py +28 -0
- foundry/hydra/resolvers.py +77 -0
- foundry/inference_engines/base.py +235 -0
- foundry/inference_engines/checkpoint_registry.py +66 -0
- foundry/metrics/__init__.py +12 -0
- foundry/metrics/losses.py +30 -0
- foundry/metrics/metric.py +319 -0
- foundry/model/layers/blocks.py +47 -0
- foundry/testing/__init__.py +6 -0
- foundry/testing/fixtures.py +19 -0
- foundry/testing/pytest_hooks.py +15 -0
- foundry/trainers/fabric.py +923 -0
- foundry/training/EMA.py +67 -0
- foundry/training/checkpoint.py +61 -0
- foundry/training/schedulers.py +91 -0
- foundry/utils/alignment.py +86 -0
- foundry/utils/components.py +415 -0
- foundry/utils/datasets.py +405 -0
- foundry/utils/ddp.py +103 -0
- foundry/utils/instantiators.py +72 -0
- foundry/utils/logging.py +279 -0
- foundry/utils/rigid.py +1460 -0
- foundry/utils/rotation_augmentation.py +65 -0
- foundry/utils/squashfs.py +172 -0
- foundry/utils/torch.py +317 -0
- foundry/utils/weights.py +271 -0
- foundry/version.py +34 -0
- foundry_cli/__init__.py +3 -0
- foundry_cli/download_checkpoints.py +281 -0
- mpnn/__init__.py +1 -0
- mpnn/collate/feature_collator.py +265 -0
- mpnn/inference.py +53 -0
- mpnn/inference_engines/mpnn.py +549 -0
- mpnn/loss/nll_loss.py +122 -0
- mpnn/metrics/nll.py +369 -0
- mpnn/metrics/sequence_recovery.py +440 -0
- mpnn/model/layers/graph_embeddings.py +2372 -0
- mpnn/model/layers/message_passing.py +332 -0
- mpnn/model/layers/position_wise_feed_forward.py +44 -0
- mpnn/model/layers/positional_encoding.py +98 -0
- mpnn/model/mpnn.py +2632 -0
- mpnn/pipelines/mpnn.py +162 -0
- mpnn/samplers/samplers.py +167 -0
- mpnn/train.py +341 -0
- mpnn/trainers/mpnn.py +193 -0
- mpnn/transforms/feature_aggregation/mpnn.py +184 -0
- mpnn/transforms/feature_aggregation/polymer_ligand_interface.py +76 -0
- mpnn/transforms/feature_aggregation/token_encodings.py +132 -0
- mpnn/transforms/feature_aggregation/user_settings.py +347 -0
- mpnn/transforms/polymer_ligand_interface.py +164 -0
- mpnn/utils/inference.py +2397 -0
- mpnn/utils/probability.py +37 -0
- mpnn/utils/weights.py +309 -0
- rc_foundry-0.1.1.dist-info/METADATA +239 -0
- rc_foundry-0.1.1.dist-info/RECORD +180 -0
- rc_foundry-0.1.1.dist-info/WHEEL +4 -0
- rc_foundry-0.1.1.dist-info/entry_points.txt +5 -0
- rc_foundry-0.1.1.dist-info/licenses/LICENSE.md +28 -0
- rf3/__init__.py +3 -0
- rf3/_version.py +33 -0
- rf3/alignment.py +79 -0
- rf3/callbacks/dump_validation_structures.py +101 -0
- rf3/callbacks/metrics_logging.py +324 -0
- rf3/chemical.py +1529 -0
- rf3/cli.py +77 -0
- rf3/data/cyclic_transform.py +78 -0
- rf3/data/extra_xforms.py +36 -0
- rf3/data/ground_truth_template.py +463 -0
- rf3/data/paired_msa.py +206 -0
- rf3/data/pipeline_utils.py +128 -0
- rf3/data/pipelines.py +558 -0
- rf3/diffusion_samplers/inference_sampler.py +222 -0
- rf3/inference.py +65 -0
- rf3/inference_engines/__init__.py +5 -0
- rf3/inference_engines/rf3.py +735 -0
- rf3/kinematics.py +354 -0
- rf3/loss/af3_confidence_loss.py +515 -0
- rf3/loss/af3_losses.py +655 -0
- rf3/loss/loss.py +179 -0
- rf3/metrics/chiral.py +179 -0
- rf3/metrics/clashing_chains.py +68 -0
- rf3/metrics/distogram.py +421 -0
- rf3/metrics/lddt.py +523 -0
- rf3/metrics/metadata.py +43 -0
- rf3/metrics/metric_utils.py +192 -0
- rf3/metrics/predicted_error.py +134 -0
- rf3/metrics/rasa.py +108 -0
- rf3/metrics/selected_distances.py +91 -0
- rf3/model/RF3.py +527 -0
- rf3/model/RF3_blocks.py +92 -0
- rf3/model/RF3_structure.py +303 -0
- rf3/model/layers/af3_auxiliary_heads.py +255 -0
- rf3/model/layers/af3_diffusion_transformer.py +544 -0
- rf3/model/layers/attention.py +313 -0
- rf3/model/layers/layer_utils.py +127 -0
- rf3/model/layers/mlff.py +118 -0
- rf3/model/layers/outer_product.py +59 -0
- rf3/model/layers/pairformer_layers.py +783 -0
- rf3/model/layers/structure_bias.py +56 -0
- rf3/scoring.py +1787 -0
- rf3/symmetry/resolve.py +284 -0
- rf3/train.py +194 -0
- rf3/trainers/rf3.py +570 -0
- rf3/util_module.py +47 -0
- rf3/utils/frames.py +109 -0
- rf3/utils/inference.py +665 -0
- rf3/utils/io.py +198 -0
- rf3/utils/loss.py +72 -0
- rf3/utils/predict_and_score.py +165 -0
- rf3/utils/predicted_error.py +673 -0
- rf3/utils/recycling.py +42 -0
- rf3/validate.py +140 -0
- rfd3/.gitignore +7 -0
- rfd3/Makefile +76 -0
- rfd3/__init__.py +12 -0
- rfd3/callbacks.py +66 -0
- rfd3/cli.py +41 -0
- rfd3/constants.py +212 -0
- rfd3/engine.py +543 -0
- rfd3/inference/datasets.py +193 -0
- rfd3/inference/input_parsing.py +1123 -0
- rfd3/inference/legacy_input_parsing.py +717 -0
- rfd3/inference/parsing.py +165 -0
- rfd3/inference/symmetry/atom_array.py +298 -0
- rfd3/inference/symmetry/checks.py +241 -0
- rfd3/inference/symmetry/contigs.py +63 -0
- rfd3/inference/symmetry/frames.py +355 -0
- rfd3/inference/symmetry/symmetry_utils.py +398 -0
- rfd3/metrics/design_metrics.py +465 -0
- rfd3/metrics/hbonds_hbplus_metrics.py +308 -0
- rfd3/metrics/hbonds_metrics.py +389 -0
- rfd3/metrics/losses.py +325 -0
- rfd3/metrics/metrics_utils.py +118 -0
- rfd3/metrics/sidechain_metrics.py +349 -0
- rfd3/model/RFD3.py +105 -0
- rfd3/model/RFD3_diffusion_module.py +387 -0
- rfd3/model/cfg_utils.py +81 -0
- rfd3/model/inference_sampler.py +635 -0
- rfd3/model/layers/attention.py +577 -0
- rfd3/model/layers/block_utils.py +580 -0
- rfd3/model/layers/blocks.py +777 -0
- rfd3/model/layers/chunked_pairwise.py +377 -0
- rfd3/model/layers/encoders.py +417 -0
- rfd3/model/layers/layer_utils.py +197 -0
- rfd3/model/layers/pairformer_layers.py +128 -0
- rfd3/run_inference.py +45 -0
- rfd3/testing/debug.py +139 -0
- rfd3/testing/debug_utils.py +73 -0
- rfd3/testing/testing_utils.py +356 -0
- rfd3/train.py +194 -0
- rfd3/trainer/dump_validation_structures.py +154 -0
- rfd3/trainer/fabric_trainer.py +923 -0
- rfd3/trainer/recycling.py +42 -0
- rfd3/trainer/rfd3.py +485 -0
- rfd3/trainer/trainer_utils.py +502 -0
- rfd3/transforms/conditioning_base.py +508 -0
- rfd3/transforms/conditioning_utils.py +200 -0
- rfd3/transforms/design_transforms.py +807 -0
- rfd3/transforms/dna_crop.py +523 -0
- rfd3/transforms/hbonds.py +407 -0
- rfd3/transforms/hbonds_hbplus.py +246 -0
- rfd3/transforms/ncaa_transforms.py +153 -0
- rfd3/transforms/pipelines.py +632 -0
- rfd3/transforms/ppi_transforms.py +541 -0
- rfd3/transforms/rasa.py +116 -0
- rfd3/transforms/symmetry.py +76 -0
- rfd3/transforms/training_conditions.py +552 -0
- rfd3/transforms/util_transforms.py +498 -0
- rfd3/transforms/virtual_atoms.py +305 -0
- rfd3/utils/inference.py +648 -0
- rfd3/utils/io.py +245 -0
- rfd3/utils/vizualize.py +276 -0
|
@@ -0,0 +1,77 @@
|
|
|
1
|
+
"""Resolvers for Hydra configuration files.
|
|
2
|
+
|
|
3
|
+
Documentation on custom resolvers:
|
|
4
|
+
- https://omegaconf.readthedocs.io/en/latest/custom_resolvers.html
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import importlib
|
|
8
|
+
|
|
9
|
+
from atomworks.enums import ChainType, ChainTypeInfo
|
|
10
|
+
from beartype.typing import Any
|
|
11
|
+
from omegaconf import OmegaConf
|
|
12
|
+
|
|
13
|
+
from ..common import run_once
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
# (Custom resolvers)
|
|
17
|
+
@run_once
def register_resolvers():
    """Register this module's custom OmegaConf resolvers.

    After this call, config interpolations can use ``resolve_import`` and
    ``chain_type_info_to_regex`` by those names. The function is wrapped in
    ``run_once`` (presumably so repeated calls do not re-register — see
    ``..common.run_once``).
    """
    for resolver_name, resolver_fn in (
        ("resolve_import", resolve_import),
        ("chain_type_info_to_regex", chain_type_info_to_regex),
    ):
        OmegaConf.register_new_resolver(resolver_name, resolver_fn)
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def resolve_import(module_path: str, attribute_path: str = None) -> Any:
|
|
29
|
+
"""
|
|
30
|
+
Import a module and access a specific attribute from it.
|
|
31
|
+
|
|
32
|
+
Args:
|
|
33
|
+
module_path (str): The path to the module.
|
|
34
|
+
attribute_path (str): The path to the attribute within the module.
|
|
35
|
+
|
|
36
|
+
Returns:
|
|
37
|
+
The imported attribute.
|
|
38
|
+
"""
|
|
39
|
+
module = importlib.import_module(module_path)
|
|
40
|
+
if attribute_path is not None:
|
|
41
|
+
# Split the attribute path to navigate through nested attributes
|
|
42
|
+
attributes = attribute_path.split(".")
|
|
43
|
+
attr = module
|
|
44
|
+
for attr_name in attributes:
|
|
45
|
+
attr = getattr(attr, attr_name)
|
|
46
|
+
return attr
|
|
47
|
+
else:
|
|
48
|
+
return module
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
def chain_type_info_to_regex(*args) -> str:
    """Convert a combination of ChainType or ChainTypeInfo attributes to a regex string.

    Primarily used for filtering a dataset by chain type prior to training/validation.

    Each argument must name either an attribute of ``ChainType`` (a member whose
    ``.value`` is used) or an attribute of ``ChainTypeInfo`` holding an iterable of
    such members. Every selected ``value`` is converted to ``str``, regex-escaped and
    becomes one alternative of the returned pattern. Alternatives appear in argument
    order, with duplicates removed.

    The pattern is anchored at both ends, ``^(?:v1|v2|...)$``, so it only matches a
    complete value. An unanchored alternation used with ``str.match`` anchors only the
    start, so an alternative such as ``1`` would also match ``10`` and ``11``.

    Example filter:
    - "pn_unit_1_type.astype('str').str.match('${chain_type_info_to_regex:PROTEINS}')"

    Raises:
        ValueError: If no argument is given, or if an argument is found on neither
            ``ChainType`` nor ``ChainTypeInfo``.
    """
    import re

    # An empty pattern would silently match every row of a filtered dataset.
    if not args:
        raise ValueError(
            "chain_type_info_to_regex requires at least one ChainType or ChainTypeInfo name."
        )

    alternatives: list[str] = []
    for arg in args:
        if hasattr(ChainType, arg):
            chain_types = [getattr(ChainType, arg)]
        elif hasattr(ChainTypeInfo, arg):
            chain_types = getattr(ChainTypeInfo, arg)
        else:
            raise ValueError(
                f"Attribute not found for ChainType or ChainTypeInfo: {arg}."
            )
        # Escape so values containing regex metacharacters match literally.
        alternatives.extend(re.escape(str(ct.value)) for ct in chain_types)

    # dict.fromkeys drops duplicates (e.g. overlapping groups) while keeping order.
    return "^(?:" + "|".join(dict.fromkeys(alternatives)) + ")$"
|
|
@@ -0,0 +1,235 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
import os
|
|
3
|
+
from os import PathLike
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
from typing import Any, Dict
|
|
6
|
+
|
|
7
|
+
import hydra
|
|
8
|
+
import torch
|
|
9
|
+
from biotite.structure import AtomArray
|
|
10
|
+
from lightning.fabric import seed_everything
|
|
11
|
+
from omegaconf import OmegaConf
|
|
12
|
+
|
|
13
|
+
from foundry.inference_engines.checkpoint_registry import REGISTERED_CHECKPOINTS
|
|
14
|
+
from foundry.utils.ddp import RankedLogger, set_accelerator_based_on_availability
|
|
15
|
+
from foundry.utils.logging import (
|
|
16
|
+
configure_minimal_inference_logging,
|
|
17
|
+
print_config_tree,
|
|
18
|
+
)
|
|
19
|
+
|
|
20
|
+
# NOTE(review): this configures the *root* logger as an import-time side effect of
# the module. ``basicConfig`` does nothing if the root logger already has handlers.
logging.basicConfig(
    datefmt="%H:%M:%S",
    format="%(asctime)s %(levelname)s %(name)s: %(message)s",
    level=logging.INFO,
)

# Module-level logger. ``rank_zero_only=True`` presumably restricts output to the
# rank-0 process in distributed runs (see foundry.utils.ddp.RankedLogger).
ranked_logger = RankedLogger(__name__, rank_zero_only=True)
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def merge(cfg, overrides: dict):
    """Merge the plain-dict ``overrides`` on top of ``cfg`` and return the result.

    ``overrides`` is first wrapped with ``OmegaConf.create`` so that nested mappings
    are merged key by key via ``OmegaConf.merge``.
    """
    override_cfg = OmegaConf.create(overrides)
    merged_cfg = OmegaConf.merge(cfg, override_cfg)
    return merged_cfg
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class BaseInferenceEngine:
    """
    Base inference engine.

    Separates model setup (expensive, once) from inference (can run multiple times).

    ``__init__`` resolves the checkpoint path, optionally seeds RNGs and records
    config overrides. The checkpoint is loaded lazily by :meth:`initialize`, which
    builds ``self.pipeline`` and ``self.trainer`` and caches the merged config in
    ``self.cfg``. Subclasses must implement :meth:`run`. The engine can be used as
    a context manager; leaving the ``with`` block drops the trainer, pipeline and
    cached config.
    """

    def __init__(
        self,
        ckpt_path: PathLike,
        num_nodes: int = 1,
        devices_per_node: int = 1,
        # Config overrides (None means "no overrides"; avoids shared mutable defaults)
        transform_overrides: dict | None = None,
        inference_sampler_overrides: dict | None = None,
        trainer_overrides: dict | None = None,
        # Debug
        verbose: bool = False,
        seed: int | None = None,
    ):
        """Initialize the inference engine.

        The model is not loaded here; that happens in :meth:`initialize`. The model
        config is taken from the checkpoint and overridden with the parameters
        provided here.

        Args:
            ckpt_path: Path to a model checkpoint file, or the name of a registered
                checkpoint. Any value whose string form contains no ``.`` is treated
                as a name, looked up in ``REGISTERED_CHECKPOINTS`` and resolved to
                its default installation path.
            num_nodes: Number of nodes for distributed inference. Defaults to ``1``.
            devices_per_node: Number of devices per node. Defaults to ``1``.
            transform_overrides: Merged into the transform config of the first
                validation dataset when the pipeline is built. Defaults to no overrides.
            inference_sampler_overrides: Each ``key: value`` is applied to the
                config as ``model.net.inference_sampler.<key>``. Defaults to no overrides.
            trainer_overrides: Each ``key: value`` is applied to the config as
                ``trainer.<key>``. Defaults to no overrides.
            verbose: If True, show detailed logging and config trees. Defaults to ``False``.
            seed: Random seed. If None, uses external RNG state. Defaults to ``None``.

        Raises:
            AssertionError: If ``ckpt_path`` is treated as a registered name but is
                not registered, or if the registered checkpoint file does not exist
                at its default installation path.
        """
        if not verbose:
            configure_minimal_inference_logging()

        # Normalize optional override mappings.
        if transform_overrides is None:
            transform_overrides = {}
        if inference_sampler_overrides is None:
            inference_sampler_overrides = {}
        if trainer_overrides is None:
            trainer_overrides = {}

        # Set attrs
        self.initialized_ = False
        self.trainer = None
        self.pipeline = None
        self.verbose = verbose

        # Resolve checkpoint path. A value without any '.' (i.e. no file extension)
        # is treated as the name of a registered checkpoint.
        if "." not in str(ckpt_path):
            name = str(ckpt_path)
            assert name in REGISTERED_CHECKPOINTS, (
                f"Checkpoint {name!r} has no file extension, so it was treated as a "
                f"registered checkpoint name, but it is not registered. "
                f"Registered checkpoints: {sorted(REGISTERED_CHECKPOINTS)}"
            )
            ckpt = REGISTERED_CHECKPOINTS[name]

            ckpt_path = ckpt.get_default_path()
            ranked_logger.info(
                "Using checkpoint from default installation directory, got: {}".format(
                    str(ckpt_path)
                )
            )
            assert os.path.exists(ckpt_path), (
                "Invalid checkpoint: {}. And could not find checkpoint in default installation location: {}".format(
                    name, ckpt_path
                )
            )
        self.ckpt_path = Path(ckpt_path).resolve()

        # Set random seed (only if seed is not None)
        if seed is not None:
            ranked_logger.info(f"Seeding everything with seed={seed}...")
            seed_everything(seed, workers=True, verbose=True)
        else:
            ranked_logger.info("Seed is None - using external RNG state")
        self.seed = seed

        # Stored for later; applied in _construct_pipeline.
        self.transform_overrides = transform_overrides
        # Nested dict of config overrides, merged into the checkpoint config.
        self.overrides: dict[str, Any] = {}

        base_overrides = {
            "trainer.seed": seed,
            "trainer.metrics": {},
            "trainer.loss": None,
            "trainer.num_nodes": num_nodes,
            "trainer.devices_per_node": devices_per_node,
        }
        for key, value in base_overrides.items():
            self._assign_override(key, value)

        for key, value in trainer_overrides.items():
            self._assign_override(f"trainer.{key}", value)

        for key, value in inference_sampler_overrides.items():
            self._assign_override(f"model.net.inference_sampler.{key}", value)

    ###################################################################################
    # Required subclass methods
    ###################################################################################

    def initialize(self):
        """Load the checkpoint and build the pipeline and the trainer/model.

        Returns:
            The checkpoint's training config with this engine's overrides applied.
            Subsequent calls do not reload anything and return the cached config.
        """
        if self.initialized_:
            # Return the config cached by the first successful call.
            return getattr(self, "cfg", None)

        # Load checkpoint and config
        ranked_logger.info(
            f"Loading checkpoint from {Path(self.ckpt_path).resolve()}..."
        )
        # NOTE: weights_only=False unpickles arbitrary Python objects (the checkpoint
        # stores the training config under "train_cfg"); only load trusted checkpoints.
        checkpoint = torch.load(self.ckpt_path, "cpu", weights_only=False)
        cfg = self._override_checkpoint_config(checkpoint["train_cfg"])

        # Load pipeline first before trainer/model
        self._construct_pipeline(cfg)
        self._construct_trainer(cfg, checkpoint=checkpoint)

        ranked_logger.info("Model loaded and ready for inference.")
        self.cfg = cfg
        self.initialized_ = True
        return cfg

    def run(
        self,
        inputs: (
            Dict[str, dict] | AtomArray | list[AtomArray] | PathLike | list[PathLike]
        ),
        *_,
    ) -> dict[str, dict] | None:
        """Run inference on ``inputs``; must be implemented by subclasses.

        The base implementation initializes the engine and then raises.

        Raises:
            NotImplementedError: Always, in the base class.
        """
        self.initialize()
        raise NotImplementedError(
            "Subclasses must implement inference logic in `run` method."
        )

    ###################################################################################
    # Util methods
    ###################################################################################

    def _override_checkpoint_config(self, cfg):
        """Merge ``self.overrides`` into ``cfg``, then apply ``set_accelerator_based_on_availability``."""
        cfg = merge(cfg, self.overrides)
        cfg = set_accelerator_based_on_availability(cfg)
        return cfg

    def _construct_trainer(self, cfg, checkpoint=None):
        """
        Instantiate the trainer from ``cfg.trainer`` and load the model weights.

        Weights come from ``checkpoint`` if given, otherwise from ``self.ckpt_path``.
        Optimizer state is discarded and the model is put in eval mode.
        Sets attr self.trainer.
        """
        # Instantiate trainer
        ranked_logger.info("Instantiating trainer...")
        if self.verbose:
            print_config_tree(
                cfg.trainer, resolve=True, title="INFERENCE TRAINER CONFIGURATION"
            )
        trainer = hydra.utils.instantiate(
            cfg.trainer,
            _convert_="partial",
            _recursive_=False,
        )

        # Setup model
        ranked_logger.info("Setting up model...")
        trainer.fabric.launch()
        trainer.initialize_or_update_trainer_state(
            {"train_cfg": cfg}
        )  # config from training stores net params
        trainer.construct_model()

        ranked_logger.info("Loading model weights from checkpoint...")
        trainer.load_checkpoint(checkpoint=checkpoint or self.ckpt_path)

        # Ensure optimizer isn't loaded
        trainer.state["optimizer"] = None
        trainer.state["train_cfg"].model.optimizer = None
        trainer.setup_model_optimizers_and_schedulers()
        trainer.state["model"].eval()
        self.trainer = trainer

    def _assign_override(self, dotted_key: str, value: Any) -> None:
        """Assign ``value`` into ``self.overrides`` using a dotted path.

        For example ``"a.b.c"`` sets ``self.overrides["a"]["b"]["c"] = value``.
        Missing intermediate entries, or ones that are not dicts, are replaced by
        new empty dicts.
        """
        target = self.overrides
        keys = dotted_key.split(".")
        for key in keys[:-1]:
            if key not in target or not isinstance(target[key], dict):
                target[key] = {}
            target = target[key]
        target[keys[-1]] = value

    def _construct_pipeline(self, cfg):
        """
        Build the transform pipeline from the first validation dataset's transform config.

        ``self.transform_overrides`` is merged into that config before instantiation.
        Sets attr self.pipeline.

        Raises:
            ValueError: If ``cfg.datasets.val`` contains no datasets.
        """
        # Construct pipeline
        ranked_logger.info("Building Transform pipeline...")
        try:
            first_val_dataset_key, first_val_dataset = next(
                iter(cfg.datasets.val.items())
            )
        except StopIteration:
            raise ValueError(
                "Checkpoint config has no validation datasets (cfg.datasets.val is "
                "empty); cannot determine the inference transform pipeline."
            ) from None
        ranked_logger.info(
            f"Using settings from validation dataset: {first_val_dataset_key}."
        )
        transform = first_val_dataset.dataset.transform
        transform = merge(transform, self.transform_overrides)

        if self.verbose:
            print_config_tree(
                transform,
                resolve=True,
                title="INFERENCE TRANSFORM PIPELINE",
            )

        self.pipeline = hydra.utils.instantiate(transform)

    # aliases for run
    def forward(self, *args, **kwargs):
        return self.run(*args, **kwargs)

    def __call__(self, *args, **kwargs):
        return self.run(*args, **kwargs)

    # for use as a context manager: e.g. `with BaseInferenceEngine(...) as engine:` to automatically cleanup
    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        # Drop references to heavy objects; a later initialize() rebuilds them.
        # Returns None, so exceptions raised inside the `with` block propagate.
        self.trainer = None
        self.pipeline = None
        self.cfg = None
        self.initialized_ = False
|
|
@@ -0,0 +1,66 @@
|
|
|
1
|
+
'''Management of checkpoints'''
|
|
2
|
+
import os
|
|
3
|
+
from dataclasses import dataclass
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def get_default_checkpoint_dir() -> Path:
    """Get the default checkpoint directory.

    Priority:
        1. FOUNDRY_CHECKPOINTS_DIR environment variable, if set and non-empty.
           A leading ``~`` is expanded to the user's home directory and relative
           paths are made absolute against the current working directory.
        2. ~/.foundry/checkpoints
    """
    env_dir = os.environ.get("FOUNDRY_CHECKPOINTS_DIR")
    if env_dir:
        # expanduser() first so "~/ckpts" is not treated as a literal "~" directory
        # relative to the current working directory.
        return Path(env_dir).expanduser().absolute()
    return Path.home() / ".foundry" / "checkpoints"
|
|
17
|
+
|
|
18
|
+
@dataclass
|
|
19
|
+
class RegisteredCheckpoint:
|
|
20
|
+
url: str
|
|
21
|
+
filename: str
|
|
22
|
+
description: str
|
|
23
|
+
sha256: None = None # Optional: add checksum for verification
|
|
24
|
+
|
|
25
|
+
def get_default_path(self):
|
|
26
|
+
return get_default_checkpoint_dir() / self.filename
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
# Checkpoints addressable by name (e.g. passing "rf3" instead of a file path).
# Insertion order is preserved: primary models first, then other models.
REGISTERED_CHECKPOINTS = dict(
    rfd3=RegisteredCheckpoint(
        url="https://files.ipd.uw.edu/pub/rfd3/rfd3_foundry_2025_12_01_remapped.ckpt",
        filename="rfd3_latest.ckpt",
        description="RFdiffusion3 checkpoint",
    ),
    rf3=RegisteredCheckpoint(
        url="https://files.ipd.uw.edu/pub/rf3/rf3_foundry_01_24_latest_remapped.ckpt",
        filename="rf3_foundry_01_24_latest_remapped.ckpt",
        description="latest RF3 checkpoint trained with data until 1/2024 (expect best performance)",
    ),
    proteinmpnn=RegisteredCheckpoint(
        url="https://files.ipd.uw.edu/pub/ligandmpnn/proteinmpnn_v_48_020.pt",
        filename="proteinmpnn_v_48_020.pt",
        description="ProteinMPNN checkpoint",
    ),
    ligandmpnn=RegisteredCheckpoint(
        url="https://files.ipd.uw.edu/pub/ligandmpnn/ligandmpnn_v_32_010_25.pt",
        filename="ligandmpnn_v_32_010_25.pt",
        description="LigandMPNN checkpoint",
    ),
    # Other models
    rf3_preprint_921=RegisteredCheckpoint(
        url="https://files.ipd.uw.edu/pub/rf3/rf3_foundry_09_21_preprint_remapped.ckpt",
        filename="rf3_foundry_09_21_preprint_remapped.ckpt",
        description="RF3 preprint checkpoint trained with data until 9/2021",
    ),
    rf3_preprint_124=RegisteredCheckpoint(
        url="https://files.ipd.uw.edu/pub/rf3/rf3_foundry_01_24_preprint_remapped.ckpt",
        filename="rf3_foundry_01_24_preprint_remapped.ckpt",
        description="RF3 preprint checkpoint trained with data until 1/2024",
    ),
    solublempnn=RegisteredCheckpoint(
        url="https://files.ipd.uw.edu/pub/ligandmpnn/solublempnn_v_48_020.pt",
        filename="solublempnn_v_48_020.pt",
        description="SolubleMPNN checkpoint",
    ),
)
|
|
@@ -0,0 +1,30 @@
|
|
|
1
|
+
import hydra
|
|
2
|
+
import torch.nn as nn
|
|
3
|
+
from omegaconf import DictConfig
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class Loss(nn.Module):
    """Sum of several configurable loss terms.

    Each keyword argument is a config passed to ``hydra.utils.instantiate``. The
    resulting callable is invoked in ``forward`` as
    ``loss_fn(network_input, network_output, loss_input)`` and must return a pair
    ``(loss, loss_dict)``.
    """

    def __init__(self, **losses):
        super().__init__()
        # NOTE(review): kept as a plain list rather than nn.ModuleList, so any
        # parameters/buffers held by the instantiated loss callables are not
        # registered on this module (e.g. not moved by .to()).
        self.to_compute = []
        for loss_name, loss in losses.items():
            loss_fn = hydra.utils.instantiate(loss)
            # Validate before storing so an un-instantiated config never ends up
            # in the list of callables.
            assert not isinstance(
                loss_fn, DictConfig
            ), f"Loss {loss_name} was instantiated as a DictConfig. Is _target_ present?."
            self.to_compute.append(loss_fn)

    def forward(
        self,
        network_input,
        network_output,
        loss_input,
    ):
        """Evaluate every configured loss term and sum the results.

        Returns:
            ``(loss, loss_dict)``: the summed loss, plus the merged per-term logging
            dicts with an added ``"total_loss"`` entry. For tensors this entry is
            detached. Later terms overwrite earlier ones on duplicate keys.
        """
        loss_dict = {}
        loss = 0
        for loss_fn in self.to_compute:
            loss_, loss_dict_ = loss_fn(network_input, network_output, loss_input)
            loss += loss_
            loss_dict.update(loss_dict_)
        # With no configured terms ``loss`` is still the int 0, which has no
        # ``.detach()``; only detach tensor-like results.
        loss_dict["total_loss"] = loss.detach() if hasattr(loss, "detach") else loss
        return loss, loss_dict
|