rc-foundry 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- foundry/__init__.py +57 -0
- foundry/callbacks/__init__.py +5 -0
- foundry/callbacks/callback.py +116 -0
- foundry/callbacks/health_logging.py +419 -0
- foundry/callbacks/metrics_logging.py +211 -0
- foundry/callbacks/timing_logging.py +67 -0
- foundry/callbacks/train_logging.py +278 -0
- foundry/common.py +108 -0
- foundry/constants.py +28 -0
- foundry/hydra/resolvers.py +77 -0
- foundry/inference_engines/base.py +235 -0
- foundry/inference_engines/checkpoint_registry.py +66 -0
- foundry/metrics/__init__.py +12 -0
- foundry/metrics/losses.py +30 -0
- foundry/metrics/metric.py +319 -0
- foundry/model/layers/blocks.py +47 -0
- foundry/testing/__init__.py +6 -0
- foundry/testing/fixtures.py +19 -0
- foundry/testing/pytest_hooks.py +15 -0
- foundry/trainers/fabric.py +923 -0
- foundry/training/EMA.py +67 -0
- foundry/training/checkpoint.py +61 -0
- foundry/training/schedulers.py +91 -0
- foundry/utils/alignment.py +86 -0
- foundry/utils/components.py +415 -0
- foundry/utils/datasets.py +405 -0
- foundry/utils/ddp.py +103 -0
- foundry/utils/instantiators.py +72 -0
- foundry/utils/logging.py +279 -0
- foundry/utils/rigid.py +1460 -0
- foundry/utils/rotation_augmentation.py +65 -0
- foundry/utils/squashfs.py +172 -0
- foundry/utils/torch.py +317 -0
- foundry/utils/weights.py +271 -0
- foundry/version.py +34 -0
- foundry_cli/__init__.py +3 -0
- foundry_cli/download_checkpoints.py +281 -0
- mpnn/__init__.py +1 -0
- mpnn/collate/feature_collator.py +265 -0
- mpnn/inference.py +53 -0
- mpnn/inference_engines/mpnn.py +549 -0
- mpnn/loss/nll_loss.py +122 -0
- mpnn/metrics/nll.py +369 -0
- mpnn/metrics/sequence_recovery.py +440 -0
- mpnn/model/layers/graph_embeddings.py +2372 -0
- mpnn/model/layers/message_passing.py +332 -0
- mpnn/model/layers/position_wise_feed_forward.py +44 -0
- mpnn/model/layers/positional_encoding.py +98 -0
- mpnn/model/mpnn.py +2632 -0
- mpnn/pipelines/mpnn.py +162 -0
- mpnn/samplers/samplers.py +167 -0
- mpnn/train.py +341 -0
- mpnn/trainers/mpnn.py +193 -0
- mpnn/transforms/feature_aggregation/mpnn.py +184 -0
- mpnn/transforms/feature_aggregation/polymer_ligand_interface.py +76 -0
- mpnn/transforms/feature_aggregation/token_encodings.py +132 -0
- mpnn/transforms/feature_aggregation/user_settings.py +347 -0
- mpnn/transforms/polymer_ligand_interface.py +164 -0
- mpnn/utils/inference.py +2397 -0
- mpnn/utils/probability.py +37 -0
- mpnn/utils/weights.py +309 -0
- rc_foundry-0.1.1.dist-info/METADATA +239 -0
- rc_foundry-0.1.1.dist-info/RECORD +180 -0
- rc_foundry-0.1.1.dist-info/WHEEL +4 -0
- rc_foundry-0.1.1.dist-info/entry_points.txt +5 -0
- rc_foundry-0.1.1.dist-info/licenses/LICENSE.md +28 -0
- rf3/__init__.py +3 -0
- rf3/_version.py +33 -0
- rf3/alignment.py +79 -0
- rf3/callbacks/dump_validation_structures.py +101 -0
- rf3/callbacks/metrics_logging.py +324 -0
- rf3/chemical.py +1529 -0
- rf3/cli.py +77 -0
- rf3/data/cyclic_transform.py +78 -0
- rf3/data/extra_xforms.py +36 -0
- rf3/data/ground_truth_template.py +463 -0
- rf3/data/paired_msa.py +206 -0
- rf3/data/pipeline_utils.py +128 -0
- rf3/data/pipelines.py +558 -0
- rf3/diffusion_samplers/inference_sampler.py +222 -0
- rf3/inference.py +65 -0
- rf3/inference_engines/__init__.py +5 -0
- rf3/inference_engines/rf3.py +735 -0
- rf3/kinematics.py +354 -0
- rf3/loss/af3_confidence_loss.py +515 -0
- rf3/loss/af3_losses.py +655 -0
- rf3/loss/loss.py +179 -0
- rf3/metrics/chiral.py +179 -0
- rf3/metrics/clashing_chains.py +68 -0
- rf3/metrics/distogram.py +421 -0
- rf3/metrics/lddt.py +523 -0
- rf3/metrics/metadata.py +43 -0
- rf3/metrics/metric_utils.py +192 -0
- rf3/metrics/predicted_error.py +134 -0
- rf3/metrics/rasa.py +108 -0
- rf3/metrics/selected_distances.py +91 -0
- rf3/model/RF3.py +527 -0
- rf3/model/RF3_blocks.py +92 -0
- rf3/model/RF3_structure.py +303 -0
- rf3/model/layers/af3_auxiliary_heads.py +255 -0
- rf3/model/layers/af3_diffusion_transformer.py +544 -0
- rf3/model/layers/attention.py +313 -0
- rf3/model/layers/layer_utils.py +127 -0
- rf3/model/layers/mlff.py +118 -0
- rf3/model/layers/outer_product.py +59 -0
- rf3/model/layers/pairformer_layers.py +783 -0
- rf3/model/layers/structure_bias.py +56 -0
- rf3/scoring.py +1787 -0
- rf3/symmetry/resolve.py +284 -0
- rf3/train.py +194 -0
- rf3/trainers/rf3.py +570 -0
- rf3/util_module.py +47 -0
- rf3/utils/frames.py +109 -0
- rf3/utils/inference.py +665 -0
- rf3/utils/io.py +198 -0
- rf3/utils/loss.py +72 -0
- rf3/utils/predict_and_score.py +165 -0
- rf3/utils/predicted_error.py +673 -0
- rf3/utils/recycling.py +42 -0
- rf3/validate.py +140 -0
- rfd3/.gitignore +7 -0
- rfd3/Makefile +76 -0
- rfd3/__init__.py +12 -0
- rfd3/callbacks.py +66 -0
- rfd3/cli.py +41 -0
- rfd3/constants.py +212 -0
- rfd3/engine.py +543 -0
- rfd3/inference/datasets.py +193 -0
- rfd3/inference/input_parsing.py +1123 -0
- rfd3/inference/legacy_input_parsing.py +717 -0
- rfd3/inference/parsing.py +165 -0
- rfd3/inference/symmetry/atom_array.py +298 -0
- rfd3/inference/symmetry/checks.py +241 -0
- rfd3/inference/symmetry/contigs.py +63 -0
- rfd3/inference/symmetry/frames.py +355 -0
- rfd3/inference/symmetry/symmetry_utils.py +398 -0
- rfd3/metrics/design_metrics.py +465 -0
- rfd3/metrics/hbonds_hbplus_metrics.py +308 -0
- rfd3/metrics/hbonds_metrics.py +389 -0
- rfd3/metrics/losses.py +325 -0
- rfd3/metrics/metrics_utils.py +118 -0
- rfd3/metrics/sidechain_metrics.py +349 -0
- rfd3/model/RFD3.py +105 -0
- rfd3/model/RFD3_diffusion_module.py +387 -0
- rfd3/model/cfg_utils.py +81 -0
- rfd3/model/inference_sampler.py +635 -0
- rfd3/model/layers/attention.py +577 -0
- rfd3/model/layers/block_utils.py +580 -0
- rfd3/model/layers/blocks.py +777 -0
- rfd3/model/layers/chunked_pairwise.py +377 -0
- rfd3/model/layers/encoders.py +417 -0
- rfd3/model/layers/layer_utils.py +197 -0
- rfd3/model/layers/pairformer_layers.py +128 -0
- rfd3/run_inference.py +45 -0
- rfd3/testing/debug.py +139 -0
- rfd3/testing/debug_utils.py +73 -0
- rfd3/testing/testing_utils.py +356 -0
- rfd3/train.py +194 -0
- rfd3/trainer/dump_validation_structures.py +154 -0
- rfd3/trainer/fabric_trainer.py +923 -0
- rfd3/trainer/recycling.py +42 -0
- rfd3/trainer/rfd3.py +485 -0
- rfd3/trainer/trainer_utils.py +502 -0
- rfd3/transforms/conditioning_base.py +508 -0
- rfd3/transforms/conditioning_utils.py +200 -0
- rfd3/transforms/design_transforms.py +807 -0
- rfd3/transforms/dna_crop.py +523 -0
- rfd3/transforms/hbonds.py +407 -0
- rfd3/transforms/hbonds_hbplus.py +246 -0
- rfd3/transforms/ncaa_transforms.py +153 -0
- rfd3/transforms/pipelines.py +632 -0
- rfd3/transforms/ppi_transforms.py +541 -0
- rfd3/transforms/rasa.py +116 -0
- rfd3/transforms/symmetry.py +76 -0
- rfd3/transforms/training_conditions.py +552 -0
- rfd3/transforms/util_transforms.py +498 -0
- rfd3/transforms/virtual_atoms.py +305 -0
- rfd3/utils/inference.py +648 -0
- rfd3/utils/io.py +245 -0
- rfd3/utils/vizualize.py +276 -0
foundry/utils/weights.py
ADDED
@@ -0,0 +1,271 @@
"""Utils for loading weights from checkpoints."""

import re
from dataclasses import dataclass, field
from enum import StrEnum, auto
from os import PathLike

import torch
from beartype.typing import Pattern
from torch import nn

from foundry.utils.ddp import RankedLogger

ranked_logger = RankedLogger(__name__, rank_zero_only=True)


class WeightLoadingError(Exception):
    """Exception raised when there's an error loading weights."""

    pass


class WeightLoadingPolicy(StrEnum):
    """Policy for handling weights when loading checkpoints."""

    # Always keep default initialization, regardless of whether the parameter is in the checkpoint or shapes match
    REINIT = auto()

    # Always zero-initialize, regardless of whether the parameter is in the checkpoint or shapes match
    ZERO_INIT = auto()

    # Copy from checkpoint only when shapes match exactly, otherwise error
    COPY = auto()

    # Copy from checkpoint if tensors are the same rank, padding with zeros if shapes don't match exactly
    COPY_AND_ZERO_PAD = auto()


class _PatternPolicyMixin:
    """Mixin for handling glob-to-regex pattern compilation and matching for parameter policies.

    Patterns can use the following wildcards:
    - * matches any sequence of characters
    - ? matches any single character
    - [abc] matches any character in the brackets
    - . matches a literal dot

    Examples:
    - "model.*.weight" matches any weight parameter in the model
    - "model.encoder?.weight" matches encoder1.weight, encoder2.weight, etc.
    - "model.encoder[12].weight" matches encoder1.weight and encoder2.weight
    - "model.encoder.*.bias" matches any bias parameter in encoder submodules
    """

    _compiled_patterns: dict[Pattern, any]

    @staticmethod
    def _glob_to_regex(pattern: str) -> str:
        # Convert glob pattern to regex string
        return (
            pattern.replace(".", r"\.")
            .replace("*", ".*")
            .replace("?", ".")
            .replace("[", "[")
            .replace("]", "]")
        )

    def _compile_patterns(self, policy_dict: dict[str, any]) -> dict[Pattern, any]:
        compiled = {}
        for pattern, value in list(policy_dict.items()):
            if any(c in pattern for c in ["*", "?", "[", "]"]):
                regex = self._glob_to_regex(pattern)
                compiled[re.compile(f"^{regex}$")] = value
        return compiled

    def _get_policy_by_pattern(
        self, param_name: str, policy_dict: dict[str, any], default: any
    ) -> any:
        # Exact match first
        if policy_dict and param_name in policy_dict:
            return policy_dict[param_name]
        # Pattern match
        for pattern, value in self._compiled_patterns.items():
            if pattern.match(param_name):
                return value
        return default


@dataclass
class WeightLoadingConfig(_PatternPolicyMixin):
    """Configuration for handling weights when loading a checkpoint."""

    default_policy: WeightLoadingPolicy | str = WeightLoadingPolicy.COPY
    fallback_policy: WeightLoadingPolicy | str = WeightLoadingPolicy.REINIT
    param_policies: dict[str, WeightLoadingPolicy | str] = field(default_factory=dict)
    _compiled_patterns: dict[Pattern, WeightLoadingPolicy] = field(
        default_factory=dict, repr=False
    )

    def __post_init__(self):
        if isinstance(self.default_policy, str):
            self.default_policy = WeightLoadingPolicy(self.default_policy)
        if isinstance(self.fallback_policy, str):
            self.fallback_policy = WeightLoadingPolicy(self.fallback_policy)
        for key, value in self.param_policies.items():
            if isinstance(value, str):
                self.param_policies[key] = WeightLoadingPolicy(value)
        self._compiled_patterns = self._compile_patterns(self.param_policies)

    def get_policy(self, param_name: str) -> WeightLoadingPolicy:
        policy = self._get_policy_by_pattern(
            param_name, self.param_policies, self.default_policy
        )
        assert isinstance(policy, WeightLoadingPolicy)
        return policy


@dataclass
class ParameterFreezingConfig(_PatternPolicyMixin):
    """Configuration for freezing model parameters after loading weights.

    Allows specifying which parameters to freeze (set requires_grad=False) by exact name or pattern.
    Patterns use glob-style wildcards (*, ?).

    Attributes:
        param_policies: Dict mapping parameter names or patterns to True (freeze) or False (do not freeze).
        freeze_by_default: Whether to freeze parameters not matched by any pattern (default: False).
    """

    param_policies: dict[str, bool] = field(default_factory=dict)
    freeze_by_default: bool = False
    _compiled_patterns: dict[Pattern, bool] = field(default_factory=dict, repr=False)

    def __post_init__(self):
        self._compiled_patterns = self._compile_patterns(self.param_policies)

    def is_frozen(self, param_name: str) -> bool:
        """Get whether a parameter is frozen according to the config."""
        is_frozen = self._get_policy_by_pattern(
            param_name, self.param_policies, self.freeze_by_default
        )
        assert isinstance(is_frozen, bool)
        return is_frozen


def freeze_parameters_with_config(
    model: nn.Module, config: ParameterFreezingConfig, verbose: bool = True
) -> None:
    """Freeze (set requires_grad=False) or unfreeze parameters according to config.

    Args:
        model: The model whose parameters to freeze/unfreeze.
        config: ParameterFreezingConfig specifying which parameters to freeze.
        verbose: Whether to log which parameters have non-default policies applied.
    """
    for name, param in model.named_parameters():
        is_frozen = config.is_frozen(name)
        param.requires_grad = not is_frozen

        if is_frozen != config.freeze_by_default and verbose:
            ranked_logger.info(f"Non-default freezing applied to {name}: {is_frozen}")


def load_weights_with_policies(
    model: nn.Module,
    ckpt: dict[str, torch.Tensor],
    config: WeightLoadingConfig | None = None,
) -> dict:
    """Load checkpoint weights into model according to the specified configuration.

    Allows for partial loading of weights and zero-initialization of mismatched and arbitrary parameters.

    Args:
        model: The model to load weights INTO. By default, all model weights are re-initialized; we overwrite
            with the checkpoint weights where appropriate
        ckpt: Dictionary mapping parameter names to tensors (loaded from checkpoint on disk)
        config: Configuration for handling weight loading. If None, uses default config

    Returns:
        dict: The updated state_dict (not loaded into model yet)
    """
    if config is None:
        # (Initialize default config if not provided)
        config = WeightLoadingConfig()

    current_state = model.state_dict()
    updated_state = {}  # We will update this with the new weights

    def _apply_policy(
        name: str,
        current_param: torch.Tensor,
        checkpoint_param: torch.Tensor | None,
        policy: WeightLoadingPolicy,
    ) -> torch.Tensor:
        """Apply a weight loading policy and return the resulting tensor.

        Raises WeightLoadingError for any policy application failures.
        """
        if policy == WeightLoadingPolicy.REINIT:
            # Keep original initialization
            return current_param

        elif policy == WeightLoadingPolicy.ZERO_INIT:
            # Zero-initialize
            return torch.zeros_like(current_param)

        elif policy == WeightLoadingPolicy.COPY:
            # Must have checkpoint param and shapes must match
            if checkpoint_param is None:
                raise WeightLoadingError(f"Parameter '{name}' not found in checkpoint")
            if current_param.shape != checkpoint_param.shape:
                raise WeightLoadingError(
                    f"Shape mismatch for '{name}': model {current_param.shape} vs checkpoint {checkpoint_param.shape}"
                )
            return checkpoint_param

        elif policy == WeightLoadingPolicy.COPY_AND_ZERO_PAD:
            # Must have checkpoint param and same number of dimensions
            if checkpoint_param is None:
                raise WeightLoadingError(f"Parameter '{name}' not found in checkpoint")
            if len(current_param.shape) != len(checkpoint_param.shape):
                raise WeightLoadingError(
                    f"Different dimensions for '{name}': model {len(current_param.shape)}D vs checkpoint {len(checkpoint_param.shape)}D"
                )

            # Copy where shapes match, zero-init the rest
            new_param = torch.zeros_like(current_param)
            slices = tuple(
                slice(0, min(d_ckpt, d_current))
                for d_ckpt, d_current in zip(
                    checkpoint_param.shape, current_param.shape
                )
            )
            new_param[slices] = checkpoint_param[slices]
            return new_param

    # ... loop through all named parameters in the model
    for name, current_param in current_state.items():
        # Get the policy for this parameter
        policy = config.get_policy(name)

        # Get the corresponding parameter from the checkpoint
        checkpoint_param = ckpt.get(name, None)

        try:
            # Try to apply the primary policy
            result = _apply_policy(name, current_param, checkpoint_param, policy)
            updated_state[name] = result
        except WeightLoadingError as e:
            # Primary policy failed, try fallback
            ranked_logger.warning(
                f"Failed to apply policy: '{policy}' to '{name}': {str(e)}. Falling back to policy: '{config.fallback_policy}'."
            )
            result = _apply_policy(
                name, current_param, checkpoint_param, config.fallback_policy
            )
            updated_state[name] = result

    return updated_state


@dataclass
class CheckpointConfig:
    """Configuration for loading checkpoints.

    TODO: Implement reset_scheduler and reset_ema
    """

    path: PathLike
    reset_optimizer: bool = False
    weight_loading_config: WeightLoadingConfig | None = None
    parameter_freezing_config: ParameterFreezingConfig | None = None
foundry/version.py
ADDED
@@ -0,0 +1,34 @@
# file generated by setuptools-scm
# don't change, don't track in version control

__all__ = [
    "__version__",
    "__version_tuple__",
    "version",
    "version_tuple",
    "__commit_id__",
    "commit_id",
]

TYPE_CHECKING = False
if TYPE_CHECKING:
    from typing import Tuple
    from typing import Union

    VERSION_TUPLE = Tuple[Union[int, str], ...]
    COMMIT_ID = Union[str, None]
else:
    VERSION_TUPLE = object
    COMMIT_ID = object

version: str
__version__: str
__version_tuple__: VERSION_TUPLE
version_tuple: VERSION_TUPLE
commit_id: COMMIT_ID
__commit_id__: COMMIT_ID

__version__ = version = '0.1.1'
__version_tuple__ = version_tuple = (0, 1, 1)

__commit_id__ = commit_id = None
foundry_cli/download_checkpoints.py
ADDED
@@ -0,0 +1,281 @@
"""CLI for foundry model checkpoint installation and management."""

import hashlib
from pathlib import Path
from typing import Optional
from urllib.request import urlopen

import rootutils
import typer
from dotenv import find_dotenv, load_dotenv, set_key
from rich.console import Console
from rich.progress import (
    BarColumn,
    DownloadColumn,
    Progress,
    SpinnerColumn,
    TextColumn,
    TimeRemainingColumn,
    TransferSpeedColumn,
)

from foundry.inference_engines.checkpoint_registry import (
    REGISTERED_CHECKPOINTS,
    get_default_checkpoint_dir,
)

rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
load_dotenv(override=True)

app = typer.Typer(help="Foundry model checkpoint installation utilities")
console = Console()


def download_file(url: str, dest: Path, verify_hash: Optional[str] = None) -> None:
    """Download a file with progress bar and optional hash verification.

    Args:
        url: URL to download from
        dest: Destination file path
        verify_hash: Optional SHA256 hash to verify against

    Raises:
        ValueError: If hash verification fails
    """
    dest.parent.mkdir(parents=True, exist_ok=True)

    with Progress(
        SpinnerColumn(),
        TextColumn("[progress.description]{task.description}"),
        BarColumn(),
        DownloadColumn(),
        TransferSpeedColumn(),
        TimeRemainingColumn(),
    ) as progress:
        # Get file size
        with urlopen(url) as response:
            file_size = int(response.headers.get("Content-Length", 0))

            task = progress.add_task(
                f"Downloading {dest.name}", total=file_size, start=True
            )

            # Download with progress
            hasher = hashlib.sha256() if verify_hash else None
            with open(dest, "wb") as f:
                while True:
                    chunk = response.read(8192)
                    if not chunk:
                        break
                    f.write(chunk)
                    if hasher:
                        hasher.update(chunk)
                    progress.update(task, advance=len(chunk))

    # Verify hash if provided
    if verify_hash:
        computed_hash = hasher.hexdigest()
        if computed_hash != verify_hash:
            dest.unlink()  # Remove corrupted file
            raise ValueError(
                f"Hash mismatch! Expected {verify_hash}, got {computed_hash}"
            )
        console.print("[green]✓[/green] Hash verification passed")


def install_model(
    model_name: str, checkpoint_dir: Path, force: bool = False
) -> None:
    """Install a single model checkpoint.

    Args:
        model_name: Name of the model (rfd3, rf3, mpnn)
        checkpoint_dir: Directory to save checkpoints
        force: Overwrite existing checkpoint if it exists
    """
    if model_name not in REGISTERED_CHECKPOINTS:
        console.print(f"[red]Error:[/red] Unknown model '{model_name}'")
        console.print(f"Available models: {', '.join(REGISTERED_CHECKPOINTS.keys())}")
        raise typer.Exit(1)

    checkpoint_info = REGISTERED_CHECKPOINTS[model_name]
    dest_path = checkpoint_dir / checkpoint_info["filename"]

    # Check if already exists
    if dest_path.exists() and not force:
        console.print(
            f"[yellow]⚠[/yellow] {model_name} checkpoint already exists at {dest_path}"
        )
        console.print("Use --force to overwrite")
        return

    console.print(
        f"[cyan]Installing {model_name}:[/cyan] {checkpoint_info['description']}"
    )

    try:
        download_file(
            checkpoint_info["url"], dest_path, checkpoint_info.get("sha256")
        )
        console.print(
            f"[green]✓[/green] Successfully installed {model_name} to {dest_path}"
        )
    except Exception as e:
        console.print(f"[red]✗[/red] Failed to install {model_name}: {e}")
        raise typer.Exit(1)


@app.command()
def install(
    models: list[str] = typer.Argument(
        ...,
        help="Models to install: 'all', 'rfd3', 'rf3', 'mpnn', or combination",
    ),
    checkpoint_dir: Optional[Path] = typer.Option(
        None,
        "--checkpoint-dir",
        "-d",
        help="Directory to save checkpoints (default: $FOUNDRY_CHECKPOINTS_DIR or ~/.foundry/checkpoints)",
    ),
    force: bool = typer.Option(
        False, "--force", "-f", help="Overwrite existing checkpoints"
    ),
):
    """Install model checkpoints for foundry.

    Examples:

        foundry install all

        foundry install rfd3 rf3

        foundry install proteinmpnn --checkpoint-dir ./checkpoints
    """
    # Determine checkpoint directory
    if checkpoint_dir is None:
        checkpoint_dir = get_default_checkpoint_dir()

    console.print(f"[bold]Checkpoint directory:[/bold] {checkpoint_dir}")
    console.print()

    # Expand 'all' to all available models
    if "all" in models:
        models_to_install = ['rfd3', 'proteinmpnn', 'ligandmpnn', 'rf3']
    else:
        models_to_install = models

    # Install each model
    for model_name in models_to_install:
        install_model(model_name, checkpoint_dir, force)
        console.print()

    set_key(
        dotenv_path=find_dotenv(),
        key_to_set='FOUNDRY_CHECKPOINTS_DIR',
        value_to_set=str(checkpoint_dir),
        export=False,
    )
    console.print(
        f"Set checkpoint installation directory to: {checkpoint_dir}"
    )

    console.print("[bold green]Installation complete![/bold green]")


@app.command(name="list")
def list_models():
    """List available model checkpoints."""
    console.print("[bold]Available models:[/bold]\n")
    for name, info in REGISTERED_CHECKPOINTS.items():
        console.print(f"  [cyan]{name:8}[/cyan] - {info['description']}")


@app.command()
def show(
    checkpoint_dir: Optional[Path] = typer.Option(
        None,
        "--checkpoint-dir",
        "-d",
        help="Checkpoint directory to show",
    ),
):
    """Show installed checkpoints."""
    if checkpoint_dir is None:
        checkpoint_dir = get_default_checkpoint_dir()

    if not checkpoint_dir.exists():
        console.print(
            f"[yellow]No checkpoints directory found at {checkpoint_dir}[/yellow]"
        )
        raise typer.Exit(0)

    checkpoint_files = list(checkpoint_dir.glob("*.ckpt"))
    if not checkpoint_files:
        console.print(
            f"[yellow]No checkpoint files found in {checkpoint_dir}[/yellow]"
        )
        raise typer.Exit(0)

    console.print(f"[bold]Installed checkpoints in {checkpoint_dir}:[/bold]\n")
    total_size = 0
    for ckpt in sorted(checkpoint_files):
        size = ckpt.stat().st_size / (1024**3)  # GB
        total_size += size
        console.print(f"  {ckpt.name:30} {size:8.2f} GB")

    console.print(f"\n[bold]Total:[/bold] {total_size:.2f} GB")


@app.command()
def clean(
    checkpoint_dir: Optional[Path] = typer.Option(
        None,
        "--checkpoint-dir",
        "-d",
        help="Checkpoint directory to clean",
    ),
    confirm: bool = typer.Option(
        True, "--confirm/--no-confirm", help="Ask for confirmation before deleting"
    ),
):
    """Remove all downloaded checkpoints."""
    if checkpoint_dir is None:
        checkpoint_dir = get_default_checkpoint_dir()

    if not checkpoint_dir.exists():
        console.print(f"[yellow]No checkpoints found at {checkpoint_dir}[/yellow]")
        raise typer.Exit(0)

    # List files to delete
    checkpoint_files = list(checkpoint_dir.glob("*.ckpt"))
    if not checkpoint_files:
        console.print(
            f"[yellow]No checkpoint files found in {checkpoint_dir}[/yellow]"
        )
        raise typer.Exit(0)

    console.print("[bold]Files to delete:[/bold]")
    total_size = 0
    for ckpt in checkpoint_files:
        size = ckpt.stat().st_size / (1024**3)  # GB
        total_size += size
        console.print(f"  {ckpt.name} ({size:.2f} GB)")

    console.print(f"\n[bold]Total:[/bold] {total_size:.2f} GB")

    # Confirm deletion
    if confirm:
        should_delete = typer.confirm("\nDelete these files?")
        if not should_delete:
            console.print("[yellow]Cancelled[/yellow]")
            raise typer.Exit(0)

    # Delete files
    for ckpt in checkpoint_files:
        ckpt.unlink()
        console.print(f"[red]✗[/red] Deleted {ckpt.name}")

    console.print("[green]✓[/green] Cleanup complete")


if __name__ == "__main__":
    app()
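
The commands above are exposed through a Typer app. Below is a small, illustrative sketch of driving that app in-process (for example in a test); it assumes the module path foundry_cli/download_checkpoints.py from the file listing and uses typer.testing.CliRunner, Typer's standard test helper. It is not part of the packaged code.

# Illustrative sketch only: invoke the checkpoint CLI in-process via Typer's test runner.
from typer.testing import CliRunner

from foundry_cli.download_checkpoints import app

runner = CliRunner()

# Analogous to the `foundry install ...` shell examples in the install() docstring,
# but using the read-only commands so nothing is downloaded or deleted.
result = runner.invoke(app, ["list"])
print(result.output)

result = runner.invoke(app, ["show", "--checkpoint-dir", "./checkpoints"])
print(result.output)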
mpnn/__init__.py
ADDED
@@ -0,0 +1 @@
"""MPNN - ProteinMPNN and LigandMPNN implementations."""