rc-foundry 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (180)
  1. foundry/__init__.py +57 -0
  2. foundry/callbacks/__init__.py +5 -0
  3. foundry/callbacks/callback.py +116 -0
  4. foundry/callbacks/health_logging.py +419 -0
  5. foundry/callbacks/metrics_logging.py +211 -0
  6. foundry/callbacks/timing_logging.py +67 -0
  7. foundry/callbacks/train_logging.py +278 -0
  8. foundry/common.py +108 -0
  9. foundry/constants.py +28 -0
  10. foundry/hydra/resolvers.py +77 -0
  11. foundry/inference_engines/base.py +235 -0
  12. foundry/inference_engines/checkpoint_registry.py +66 -0
  13. foundry/metrics/__init__.py +12 -0
  14. foundry/metrics/losses.py +30 -0
  15. foundry/metrics/metric.py +319 -0
  16. foundry/model/layers/blocks.py +47 -0
  17. foundry/testing/__init__.py +6 -0
  18. foundry/testing/fixtures.py +19 -0
  19. foundry/testing/pytest_hooks.py +15 -0
  20. foundry/trainers/fabric.py +923 -0
  21. foundry/training/EMA.py +67 -0
  22. foundry/training/checkpoint.py +61 -0
  23. foundry/training/schedulers.py +91 -0
  24. foundry/utils/alignment.py +86 -0
  25. foundry/utils/components.py +415 -0
  26. foundry/utils/datasets.py +405 -0
  27. foundry/utils/ddp.py +103 -0
  28. foundry/utils/instantiators.py +72 -0
  29. foundry/utils/logging.py +279 -0
  30. foundry/utils/rigid.py +1460 -0
  31. foundry/utils/rotation_augmentation.py +65 -0
  32. foundry/utils/squashfs.py +172 -0
  33. foundry/utils/torch.py +317 -0
  34. foundry/utils/weights.py +271 -0
  35. foundry/version.py +34 -0
  36. foundry_cli/__init__.py +3 -0
  37. foundry_cli/download_checkpoints.py +281 -0
  38. mpnn/__init__.py +1 -0
  39. mpnn/collate/feature_collator.py +265 -0
  40. mpnn/inference.py +53 -0
  41. mpnn/inference_engines/mpnn.py +549 -0
  42. mpnn/loss/nll_loss.py +122 -0
  43. mpnn/metrics/nll.py +369 -0
  44. mpnn/metrics/sequence_recovery.py +440 -0
  45. mpnn/model/layers/graph_embeddings.py +2372 -0
  46. mpnn/model/layers/message_passing.py +332 -0
  47. mpnn/model/layers/position_wise_feed_forward.py +44 -0
  48. mpnn/model/layers/positional_encoding.py +98 -0
  49. mpnn/model/mpnn.py +2632 -0
  50. mpnn/pipelines/mpnn.py +162 -0
  51. mpnn/samplers/samplers.py +167 -0
  52. mpnn/train.py +341 -0
  53. mpnn/trainers/mpnn.py +193 -0
  54. mpnn/transforms/feature_aggregation/mpnn.py +184 -0
  55. mpnn/transforms/feature_aggregation/polymer_ligand_interface.py +76 -0
  56. mpnn/transforms/feature_aggregation/token_encodings.py +132 -0
  57. mpnn/transforms/feature_aggregation/user_settings.py +347 -0
  58. mpnn/transforms/polymer_ligand_interface.py +164 -0
  59. mpnn/utils/inference.py +2397 -0
  60. mpnn/utils/probability.py +37 -0
  61. mpnn/utils/weights.py +309 -0
  62. rc_foundry-0.1.1.dist-info/METADATA +239 -0
  63. rc_foundry-0.1.1.dist-info/RECORD +180 -0
  64. rc_foundry-0.1.1.dist-info/WHEEL +4 -0
  65. rc_foundry-0.1.1.dist-info/entry_points.txt +5 -0
  66. rc_foundry-0.1.1.dist-info/licenses/LICENSE.md +28 -0
  67. rf3/__init__.py +3 -0
  68. rf3/_version.py +33 -0
  69. rf3/alignment.py +79 -0
  70. rf3/callbacks/dump_validation_structures.py +101 -0
  71. rf3/callbacks/metrics_logging.py +324 -0
  72. rf3/chemical.py +1529 -0
  73. rf3/cli.py +77 -0
  74. rf3/data/cyclic_transform.py +78 -0
  75. rf3/data/extra_xforms.py +36 -0
  76. rf3/data/ground_truth_template.py +463 -0
  77. rf3/data/paired_msa.py +206 -0
  78. rf3/data/pipeline_utils.py +128 -0
  79. rf3/data/pipelines.py +558 -0
  80. rf3/diffusion_samplers/inference_sampler.py +222 -0
  81. rf3/inference.py +65 -0
  82. rf3/inference_engines/__init__.py +5 -0
  83. rf3/inference_engines/rf3.py +735 -0
  84. rf3/kinematics.py +354 -0
  85. rf3/loss/af3_confidence_loss.py +515 -0
  86. rf3/loss/af3_losses.py +655 -0
  87. rf3/loss/loss.py +179 -0
  88. rf3/metrics/chiral.py +179 -0
  89. rf3/metrics/clashing_chains.py +68 -0
  90. rf3/metrics/distogram.py +421 -0
  91. rf3/metrics/lddt.py +523 -0
  92. rf3/metrics/metadata.py +43 -0
  93. rf3/metrics/metric_utils.py +192 -0
  94. rf3/metrics/predicted_error.py +134 -0
  95. rf3/metrics/rasa.py +108 -0
  96. rf3/metrics/selected_distances.py +91 -0
  97. rf3/model/RF3.py +527 -0
  98. rf3/model/RF3_blocks.py +92 -0
  99. rf3/model/RF3_structure.py +303 -0
  100. rf3/model/layers/af3_auxiliary_heads.py +255 -0
  101. rf3/model/layers/af3_diffusion_transformer.py +544 -0
  102. rf3/model/layers/attention.py +313 -0
  103. rf3/model/layers/layer_utils.py +127 -0
  104. rf3/model/layers/mlff.py +118 -0
  105. rf3/model/layers/outer_product.py +59 -0
  106. rf3/model/layers/pairformer_layers.py +783 -0
  107. rf3/model/layers/structure_bias.py +56 -0
  108. rf3/scoring.py +1787 -0
  109. rf3/symmetry/resolve.py +284 -0
  110. rf3/train.py +194 -0
  111. rf3/trainers/rf3.py +570 -0
  112. rf3/util_module.py +47 -0
  113. rf3/utils/frames.py +109 -0
  114. rf3/utils/inference.py +665 -0
  115. rf3/utils/io.py +198 -0
  116. rf3/utils/loss.py +72 -0
  117. rf3/utils/predict_and_score.py +165 -0
  118. rf3/utils/predicted_error.py +673 -0
  119. rf3/utils/recycling.py +42 -0
  120. rf3/validate.py +140 -0
  121. rfd3/.gitignore +7 -0
  122. rfd3/Makefile +76 -0
  123. rfd3/__init__.py +12 -0
  124. rfd3/callbacks.py +66 -0
  125. rfd3/cli.py +41 -0
  126. rfd3/constants.py +212 -0
  127. rfd3/engine.py +543 -0
  128. rfd3/inference/datasets.py +193 -0
  129. rfd3/inference/input_parsing.py +1123 -0
  130. rfd3/inference/legacy_input_parsing.py +717 -0
  131. rfd3/inference/parsing.py +165 -0
  132. rfd3/inference/symmetry/atom_array.py +298 -0
  133. rfd3/inference/symmetry/checks.py +241 -0
  134. rfd3/inference/symmetry/contigs.py +63 -0
  135. rfd3/inference/symmetry/frames.py +355 -0
  136. rfd3/inference/symmetry/symmetry_utils.py +398 -0
  137. rfd3/metrics/design_metrics.py +465 -0
  138. rfd3/metrics/hbonds_hbplus_metrics.py +308 -0
  139. rfd3/metrics/hbonds_metrics.py +389 -0
  140. rfd3/metrics/losses.py +325 -0
  141. rfd3/metrics/metrics_utils.py +118 -0
  142. rfd3/metrics/sidechain_metrics.py +349 -0
  143. rfd3/model/RFD3.py +105 -0
  144. rfd3/model/RFD3_diffusion_module.py +387 -0
  145. rfd3/model/cfg_utils.py +81 -0
  146. rfd3/model/inference_sampler.py +635 -0
  147. rfd3/model/layers/attention.py +577 -0
  148. rfd3/model/layers/block_utils.py +580 -0
  149. rfd3/model/layers/blocks.py +777 -0
  150. rfd3/model/layers/chunked_pairwise.py +377 -0
  151. rfd3/model/layers/encoders.py +417 -0
  152. rfd3/model/layers/layer_utils.py +197 -0
  153. rfd3/model/layers/pairformer_layers.py +128 -0
  154. rfd3/run_inference.py +45 -0
  155. rfd3/testing/debug.py +139 -0
  156. rfd3/testing/debug_utils.py +73 -0
  157. rfd3/testing/testing_utils.py +356 -0
  158. rfd3/train.py +194 -0
  159. rfd3/trainer/dump_validation_structures.py +154 -0
  160. rfd3/trainer/fabric_trainer.py +923 -0
  161. rfd3/trainer/recycling.py +42 -0
  162. rfd3/trainer/rfd3.py +485 -0
  163. rfd3/trainer/trainer_utils.py +502 -0
  164. rfd3/transforms/conditioning_base.py +508 -0
  165. rfd3/transforms/conditioning_utils.py +200 -0
  166. rfd3/transforms/design_transforms.py +807 -0
  167. rfd3/transforms/dna_crop.py +523 -0
  168. rfd3/transforms/hbonds.py +407 -0
  169. rfd3/transforms/hbonds_hbplus.py +246 -0
  170. rfd3/transforms/ncaa_transforms.py +153 -0
  171. rfd3/transforms/pipelines.py +632 -0
  172. rfd3/transforms/ppi_transforms.py +541 -0
  173. rfd3/transforms/rasa.py +116 -0
  174. rfd3/transforms/symmetry.py +76 -0
  175. rfd3/transforms/training_conditions.py +552 -0
  176. rfd3/transforms/util_transforms.py +498 -0
  177. rfd3/transforms/virtual_atoms.py +305 -0
  178. rfd3/utils/inference.py +648 -0
  179. rfd3/utils/io.py +245 -0
  180. rfd3/utils/vizualize.py +276 -0
foundry/utils/weights.py ADDED
@@ -0,0 +1,271 @@
+ """Utils for loading weights from checkpoints."""
+
+ import re
+ from dataclasses import dataclass, field
+ from enum import StrEnum, auto
+ from os import PathLike
+
+ import torch
+ from beartype.typing import Pattern
+ from torch import nn
+
+ from foundry.utils.ddp import RankedLogger
+
+ ranked_logger = RankedLogger(__name__, rank_zero_only=True)
+
+
+ class WeightLoadingError(Exception):
+     """Exception raised when there's an error loading weights."""
+
+     pass
+
+
+ class WeightLoadingPolicy(StrEnum):
+     """Policy for handling weights when loading checkpoints."""
+
+     # Always keep default initialization, regardless of whether the parameter is in the checkpoint or shapes match
+     REINIT = auto()
+
+     # Always zero-initialize, regardless of whether the parameter is in the checkpoint or shapes match
+     ZERO_INIT = auto()
+
+     # Copy from checkpoint only when shapes match exactly, otherwise error
+     COPY = auto()
+
+     # Copy from checkpoint if tensors are the same rank, padding with zeros if shapes don't match exactly
+     COPY_AND_ZERO_PAD = auto()
+
+
+ class _PatternPolicyMixin:
+     """Mixin for handling glob-to-regex pattern compilation and matching for parameter policies.
+
+     Patterns can use the following wildcards:
+     - * matches any sequence of characters
+     - ? matches any single character
+     - [abc] matches any character in the brackets
+     - . matches a literal dot
+
+     Examples:
+     - "model.*.weight" matches any weight parameter in the model
+     - "model.encoder?.weight" matches encoder1.weight, encoder2.weight, etc.
+     - "model.encoder[12].weight" matches encoder1.weight and encoder2.weight
+     - "model.encoder.*.bias" matches any bias parameter in encoder submodules
+     """
+
+     _compiled_patterns: dict[Pattern, any]
+
+     @staticmethod
+     def _glob_to_regex(pattern: str) -> str:
+         # Convert glob pattern to regex string
+         return (
+             pattern.replace(".", r"\.")
+             .replace("*", ".*")
+             .replace("?", ".")
+             .replace("[", "[")
+             .replace("]", "]")
+         )
+
+     def _compile_patterns(self, policy_dict: dict[str, any]) -> dict[Pattern, any]:
+         compiled = {}
+         for pattern, value in list(policy_dict.items()):
+             if any(c in pattern for c in ["*", "?", "[", "]"]):
+                 regex = self._glob_to_regex(pattern)
+                 compiled[re.compile(f"^{regex}$")] = value
+         return compiled
+
+     def _get_policy_by_pattern(
+         self, param_name: str, policy_dict: dict[str, any], default: any
+     ) -> any:
+         # Exact match first
+         if policy_dict and param_name in policy_dict:
+             return policy_dict[param_name]
+         # Pattern match
+         for pattern, value in self._compiled_patterns.items():
+             if pattern.match(param_name):
+                 return value
+         return default
+
+
+ @dataclass
+ class WeightLoadingConfig(_PatternPolicyMixin):
+     """Configuration for handling weights when loading a checkpoint."""
+
+     default_policy: WeightLoadingPolicy | str = WeightLoadingPolicy.COPY
+     fallback_policy: WeightLoadingPolicy | str = WeightLoadingPolicy.REINIT
+     param_policies: dict[str, WeightLoadingPolicy | str] = field(default_factory=dict)
+     _compiled_patterns: dict[Pattern, WeightLoadingPolicy] = field(
+         default_factory=dict, repr=False
+     )
+
+     def __post_init__(self):
+         if isinstance(self.default_policy, str):
+             self.default_policy = WeightLoadingPolicy(self.default_policy)
+         if isinstance(self.fallback_policy, str):
+             self.fallback_policy = WeightLoadingPolicy(self.fallback_policy)
+         for key, value in self.param_policies.items():
+             if isinstance(value, str):
+                 self.param_policies[key] = WeightLoadingPolicy(value)
+         self._compiled_patterns = self._compile_patterns(self.param_policies)
+
+     def get_policy(self, param_name: str) -> WeightLoadingPolicy:
+         policy = self._get_policy_by_pattern(
+             param_name, self.param_policies, self.default_policy
+         )
+         assert isinstance(policy, WeightLoadingPolicy)
+         return policy
+
+
+ @dataclass
+ class ParameterFreezingConfig(_PatternPolicyMixin):
+     """Configuration for freezing model parameters after loading weights.
+
+     Allows specifying which parameters to freeze (set requires_grad=False) by exact name or pattern.
+     Patterns use glob-style wildcards (*, ?).
+
+     Attributes:
+         param_policies: Dict mapping parameter names or patterns to True (freeze) or False (do not freeze).
+         freeze_by_default: Whether to freeze parameters not matched by any pattern (default: False).
+     """
+
+     param_policies: dict[str, bool] = field(default_factory=dict)
+     freeze_by_default: bool = False
+     _compiled_patterns: dict[Pattern, bool] = field(default_factory=dict, repr=False)
+
+     def __post_init__(self):
+         self._compiled_patterns = self._compile_patterns(self.param_policies)
+
+     def is_frozen(self, param_name: str) -> bool:
+         """Get whether a parameter is frozen according to the config."""
+         is_frozen = self._get_policy_by_pattern(
+             param_name, self.param_policies, self.freeze_by_default
+         )
+         assert isinstance(is_frozen, bool)
+         return is_frozen
+
+
+ def freeze_parameters_with_config(
+     model: nn.Module, config: ParameterFreezingConfig, verbose: bool = True
+ ) -> None:
+     """Freeze (set requires_grad=False) or unfreeze parameters according to config.
+
+     Args:
+         model: The model whose parameters to freeze/unfreeze.
+         config: ParameterFreezingConfig specifying which parameters to freeze.
+         verbose: Whether to log which parameters have non-default policies applied.
+     """
+     for name, param in model.named_parameters():
+         is_frozen = config.is_frozen(name)
+         param.requires_grad = not is_frozen
+
+         if is_frozen != config.freeze_by_default and verbose:
+             ranked_logger.info(f"Non-default freezing applied to {name}: {is_frozen}")
+
+
+ def load_weights_with_policies(
+     model: nn.Module,
+     ckpt: dict[str, torch.Tensor],
+     config: WeightLoadingConfig | None = None,
+ ) -> dict:
+     """Load checkpoint weights into model according to the specified configuration.
+
+     Allows for partial loading of weights and zero-initialization of mismatched and arbitrary parameters.
+
+     Args:
+         model: The model to load weights INTO. By default, all model weights are re-initialized; we overwrite
+             with the checkpoint weights where appropriate
+         ckpt: Dictionary mapping parameter names to tensors (loaded from checkpoint on disk)
+         config: Configuration for handling weight loading. If None, uses default config
+     Returns:
+         dict: The updated state_dict (not loaded into model yet)
+     """
+     if config is None:
+         # (Initialize default config if not provided)
+         config = WeightLoadingConfig()
+
+     current_state = model.state_dict()
+     updated_state = {}  # We will update this with the new weights
+
+     def _apply_policy(
+         name: str,
+         current_param: torch.Tensor,
+         checkpoint_param: torch.Tensor | None,
+         policy: WeightLoadingPolicy,
+     ) -> torch.Tensor:
+         """Apply a weight loading policy and return the resulting tensor.
+
+         Raises WeightLoadingError for any policy application failures.
+         """
+         if policy == WeightLoadingPolicy.REINIT:
+             # Keep original initialization
+             return current_param
+
+         elif policy == WeightLoadingPolicy.ZERO_INIT:
+             # Zero-initialize
+             return torch.zeros_like(current_param)
+
+         elif policy == WeightLoadingPolicy.COPY:
+             # Must have checkpoint param and shapes must match
+             if checkpoint_param is None:
+                 raise WeightLoadingError(f"Parameter '{name}' not found in checkpoint")
+             if current_param.shape != checkpoint_param.shape:
+                 raise WeightLoadingError(
+                     f"Shape mismatch for '{name}': model {current_param.shape} vs checkpoint {checkpoint_param.shape}"
+                 )
+             return checkpoint_param
+
+         elif policy == WeightLoadingPolicy.COPY_AND_ZERO_PAD:
+             # Must have checkpoint param and same number of dimensions
+             if checkpoint_param is None:
+                 raise WeightLoadingError(f"Parameter '{name}' not found in checkpoint")
+             if len(current_param.shape) != len(checkpoint_param.shape):
+                 raise WeightLoadingError(
+                     f"Different dimensions for '{name}': model {len(current_param.shape)}D vs checkpoint {len(checkpoint_param.shape)}D"
+                 )
+
+             # Copy where shapes match, zero-init the rest
+             new_param = torch.zeros_like(current_param)
+             slices = tuple(
+                 slice(0, min(d_ckpt, d_current))
+                 for d_ckpt, d_current in zip(
+                     checkpoint_param.shape, current_param.shape
+                 )
+             )
+             new_param[slices] = checkpoint_param[slices]
+             return new_param
+
+     # ... loop through all named parameters in the model
+     for name, current_param in current_state.items():
+         # Get the policy for this parameter
+         policy = config.get_policy(name)
+
+         # Get the corresponding parameter from the checkpoint
+         checkpoint_param = ckpt.get(name, None)
+
+         try:
+             # Try to apply the primary policy
+             result = _apply_policy(name, current_param, checkpoint_param, policy)
+             updated_state[name] = result
+         except WeightLoadingError as e:
+             # Primary policy failed, try fallback
+             ranked_logger.warning(
+                 f"Failed to apply policy: '{policy}' to '{name}': {str(e)}. Falling back to policy: '{config.fallback_policy}'."
+             )
+             result = _apply_policy(
+                 name, current_param, checkpoint_param, config.fallback_policy
+             )
+             updated_state[name] = result
+
+     return updated_state
+
+
+ @dataclass
+ class CheckpointConfig:
+     """Configuration for loading checkpoints.
+
+     TODO: Implement reset_scheduler and reset_ema
+     """
+
+     path: PathLike
+     reset_optimizer: bool = False
+     weight_loading_config: WeightLoadingConfig | None = None
+     parameter_freezing_config: ParameterFreezingConfig | None = None
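Taken together, these configs select weights purely by parameter-name globs. Below is a minimal usage sketch, assuming a flat name-to-tensor checkpoint and a made-up two-layer model whose head ("1.*") was resized relative to the checkpoint; none of the model, file names, or patterns come from the package itself.

import torch
from torch import nn

from foundry.utils.weights import (
    ParameterFreezingConfig,
    WeightLoadingConfig,
    freeze_parameters_with_config,
    load_weights_with_policies,
)

# Hypothetical model; its parameter names are "0.weight", "0.bias", "1.weight", "1.bias".
model = nn.Sequential(nn.Linear(16, 32), nn.Linear(32, 8))

config = WeightLoadingConfig(
    default_policy="copy",                        # exact-shape copy everywhere else
    fallback_policy="reinit",                     # keep the fresh init if a policy fails
    param_policies={"1.*": "copy_and_zero_pad"},  # glob: zero-pad the resized head
)

ckpt = torch.load("checkpoint.pt", map_location="cpu")  # assumed to already be a flat name->tensor dict
state_dict = load_weights_with_policies(model, ckpt, config)
model.load_state_dict(state_dict)  # the returned dict still has to be loaded explicitly

# Freeze everything except the head, again by glob pattern.
freeze_cfg = ParameterFreezingConfig(param_policies={"1.*": False}, freeze_by_default=True)
freeze_parameters_with_config(model, freeze_cfg)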
foundry/version.py ADDED
@@ -0,0 +1,34 @@
+ # file generated by setuptools-scm
+ # don't change, don't track in version control
+
+ __all__ = [
+     "__version__",
+     "__version_tuple__",
+     "version",
+     "version_tuple",
+     "__commit_id__",
+     "commit_id",
+ ]
+
+ TYPE_CHECKING = False
+ if TYPE_CHECKING:
+     from typing import Tuple
+     from typing import Union
+
+     VERSION_TUPLE = Tuple[Union[int, str], ...]
+     COMMIT_ID = Union[str, None]
+ else:
+     VERSION_TUPLE = object
+     COMMIT_ID = object
+
+ version: str
+ __version__: str
+ __version_tuple__: VERSION_TUPLE
+ version_tuple: VERSION_TUPLE
+ commit_id: COMMIT_ID
+ __commit_id__: COMMIT_ID
+
+ __version__ = version = '0.1.1'
+ __version_tuple__ = version_tuple = (0, 1, 1)
+
+ __commit_id__ = commit_id = None
foundry_cli/__init__.py ADDED
@@ -0,0 +1,3 @@
+ """Foundry CLI utilities."""
+
+ __version__ = "0.1.0"
foundry_cli/download_checkpoints.py ADDED
@@ -0,0 +1,281 @@
+ """CLI for foundry model checkpoint installation and management."""
+
+ import hashlib
+ from pathlib import Path
+ from typing import Optional
+ from urllib.request import urlopen
+
+ import rootutils
+ import typer
+ from dotenv import find_dotenv, load_dotenv, set_key
+ from rich.console import Console
+ from rich.progress import (
+     BarColumn,
+     DownloadColumn,
+     Progress,
+     SpinnerColumn,
+     TextColumn,
+     TimeRemainingColumn,
+     TransferSpeedColumn,
+ )
+
+ from foundry.inference_engines.checkpoint_registry import (
+     REGISTERED_CHECKPOINTS,
+     get_default_checkpoint_dir,
+ )
+
+ rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
+ load_dotenv(override=True)
+
+ app = typer.Typer(help="Foundry model checkpoint installation utilities")
+ console = Console()
+
+
+ def download_file(url: str, dest: Path, verify_hash: Optional[str] = None) -> None:
+     """Download a file with progress bar and optional hash verification.
+
+     Args:
+         url: URL to download from
+         dest: Destination file path
+         verify_hash: Optional SHA256 hash to verify against
+
+     Raises:
+         ValueError: If hash verification fails
+     """
+     dest.parent.mkdir(parents=True, exist_ok=True)
+
+     with Progress(
+         SpinnerColumn(),
+         TextColumn("[progress.description]{task.description}"),
+         BarColumn(),
+         DownloadColumn(),
+         TransferSpeedColumn(),
+         TimeRemainingColumn(),
+     ) as progress:
+         # Get file size, then stream the body while the response is still open
+         with urlopen(url) as response:
+             file_size = int(response.headers.get("Content-Length", 0))
+
+             task = progress.add_task(
+                 f"Downloading {dest.name}", total=file_size, start=True
+             )
+
+             # Download with progress
+             hasher = hashlib.sha256() if verify_hash else None
+             with open(dest, "wb") as f:
+                 while True:
+                     chunk = response.read(8192)
+                     if not chunk:
+                         break
+                     f.write(chunk)
+                     if hasher:
+                         hasher.update(chunk)
+                     progress.update(task, advance=len(chunk))
+
+     # Verify hash if provided
+     if verify_hash:
+         computed_hash = hasher.hexdigest()
+         if computed_hash != verify_hash:
+             dest.unlink()  # Remove corrupted file
+             raise ValueError(
+                 f"Hash mismatch! Expected {verify_hash}, got {computed_hash}"
+             )
+         console.print("[green]✓[/green] Hash verification passed")
+
+
+ def install_model(
+     model_name: str, checkpoint_dir: Path, force: bool = False
+ ) -> None:
+     """Install a single model checkpoint.
+
+     Args:
+         model_name: Name of the model (rfd3, rf3, mpnn)
+         checkpoint_dir: Directory to save checkpoints
+         force: Overwrite existing checkpoint if it exists
+     """
+     if model_name not in REGISTERED_CHECKPOINTS:
+         console.print(f"[red]Error:[/red] Unknown model '{model_name}'")
+         console.print(f"Available models: {', '.join(REGISTERED_CHECKPOINTS.keys())}")
+         raise typer.Exit(1)
+
+     checkpoint_info = REGISTERED_CHECKPOINTS[model_name]
+     dest_path = checkpoint_dir / checkpoint_info["filename"]
+
+     # Check if already exists
+     if dest_path.exists() and not force:
+         console.print(
+             f"[yellow]⚠[/yellow] {model_name} checkpoint already exists at {dest_path}"
+         )
+         console.print("Use --force to overwrite")
+         return
+
+     console.print(
+         f"[cyan]Installing {model_name}:[/cyan] {checkpoint_info['description']}"
+     )
+
+     try:
+         download_file(
+             checkpoint_info["url"], dest_path, checkpoint_info.get("sha256")
+         )
+         console.print(
+             f"[green]✓[/green] Successfully installed {model_name} to {dest_path}"
+         )
+     except Exception as e:
+         console.print(f"[red]✗[/red] Failed to install {model_name}: {e}")
+         raise typer.Exit(1)
+
+
+ @app.command()
+ def install(
+     models: list[str] = typer.Argument(
+         ...,
+         help="Models to install: 'all', 'rfd3', 'rf3', 'mpnn', or combination",
+     ),
+     checkpoint_dir: Optional[Path] = typer.Option(
+         None,
+         "--checkpoint-dir",
+         "-d",
+         help="Directory to save checkpoints (default: $FOUNDRY_CHECKPOINTS_DIR or ~/.foundry/checkpoints)",
+     ),
+     force: bool = typer.Option(
+         False, "--force", "-f", help="Overwrite existing checkpoints"
+     ),
+ ):
+     """Install model checkpoints for foundry.
+
+     Examples:
+
+         foundry install all
+
+         foundry install rfd3 rf3
+
+         foundry install proteinmpnn --checkpoint-dir ./checkpoints
+     """
+     # Determine checkpoint directory
+     if checkpoint_dir is None:
+         checkpoint_dir = get_default_checkpoint_dir()
+
+     console.print(f"[bold]Checkpoint directory:[/bold] {checkpoint_dir}")
+     console.print()
+
+     # Expand 'all' to all available models
+     if "all" in models:
+         models_to_install = ["rfd3", "proteinmpnn", "ligandmpnn", "rf3"]
+     else:
+         models_to_install = models
+
+     # Install each model
+     for model_name in models_to_install:
+         install_model(model_name, checkpoint_dir, force)
+         console.print()
+
+     set_key(
+         dotenv_path=find_dotenv(),
+         key_to_set="FOUNDRY_CHECKPOINTS_DIR",
+         value_to_set=str(checkpoint_dir),
+         export=False,
+     )
+     console.print(
+         f"Set checkpoint installation directory to: {checkpoint_dir}"
+     )
+
+     console.print("[bold green]Installation complete![/bold green]")
+
+
+ @app.command(name="list")
+ def list_models():
+     """List available model checkpoints."""
+     console.print("[bold]Available models:[/bold]\n")
+     for name, info in REGISTERED_CHECKPOINTS.items():
+         console.print(f" [cyan]{name:8}[/cyan] - {info['description']}")
+
+
+ @app.command()
+ def show(
+     checkpoint_dir: Optional[Path] = typer.Option(
+         None,
+         "--checkpoint-dir",
+         "-d",
+         help="Checkpoint directory to show",
+     ),
+ ):
+     """Show installed checkpoints."""
+     if checkpoint_dir is None:
+         checkpoint_dir = get_default_checkpoint_dir()
+
+     if not checkpoint_dir.exists():
+         console.print(
+             f"[yellow]No checkpoints directory found at {checkpoint_dir}[/yellow]"
+         )
+         raise typer.Exit(0)
+
+     checkpoint_files = list(checkpoint_dir.glob("*.ckpt"))
+     if not checkpoint_files:
+         console.print(
+             f"[yellow]No checkpoint files found in {checkpoint_dir}[/yellow]"
+         )
+         raise typer.Exit(0)
+
+     console.print(f"[bold]Installed checkpoints in {checkpoint_dir}:[/bold]\n")
+     total_size = 0
+     for ckpt in sorted(checkpoint_files):
+         size = ckpt.stat().st_size / (1024**3)  # GB
+         total_size += size
+         console.print(f" {ckpt.name:30} {size:8.2f} GB")
+
+     console.print(f"\n[bold]Total:[/bold] {total_size:.2f} GB")
+
+
+ @app.command()
+ def clean(
+     checkpoint_dir: Optional[Path] = typer.Option(
+         None,
+         "--checkpoint-dir",
+         "-d",
+         help="Checkpoint directory to clean",
+     ),
+     confirm: bool = typer.Option(
+         True, "--confirm/--no-confirm", help="Ask for confirmation before deleting"
+     ),
+ ):
+     """Remove all downloaded checkpoints."""
+     if checkpoint_dir is None:
+         checkpoint_dir = get_default_checkpoint_dir()
+
+     if not checkpoint_dir.exists():
+         console.print(f"[yellow]No checkpoints found at {checkpoint_dir}[/yellow]")
+         raise typer.Exit(0)
+
+     # List files to delete
+     checkpoint_files = list(checkpoint_dir.glob("*.ckpt"))
+     if not checkpoint_files:
+         console.print(
+             f"[yellow]No checkpoint files found in {checkpoint_dir}[/yellow]"
+         )
+         raise typer.Exit(0)
+
+     console.print("[bold]Files to delete:[/bold]")
+     total_size = 0
+     for ckpt in checkpoint_files:
+         size = ckpt.stat().st_size / (1024**3)  # GB
+         total_size += size
+         console.print(f" {ckpt.name} ({size:.2f} GB)")
+
+     console.print(f"\n[bold]Total:[/bold] {total_size:.2f} GB")
+
+     # Confirm deletion
+     if confirm:
+         should_delete = typer.confirm("\nDelete these files?")
+         if not should_delete:
+             console.print("[yellow]Cancelled[/yellow]")
+             raise typer.Exit(0)
+
+     # Delete files
+     for ckpt in checkpoint_files:
+         ckpt.unlink()
+         console.print(f"[red]✗[/red] Deleted {ckpt.name}")
+
+     console.print("[green]✓[/green] Cleanup complete")
+
+
+ if __name__ == "__main__":
+     app()
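install_model and download_file above only rely on a few keys of each REGISTERED_CHECKPOINTS entry. A hedged sketch of an entry shape that would satisfy this CLI follows; the field names are taken from the code above, while the model key, URL, and digest are placeholders rather than the package's real values.

# Hypothetical registry entry (illustration only, not the package's actual data).
REGISTERED_CHECKPOINTS = {
    "rfd3": {
        "filename": "rfd3.ckpt",                             # saved as <checkpoint_dir>/rfd3.ckpt
        "url": "https://example.com/checkpoints/rfd3.ckpt",  # placeholder download URL
        "description": "RFD3 design model weights",
        "sha256": "<expected-sha256-hex-digest>",            # optional; omit to skip verification
    },
}

With entries of this shape, running "foundry install rfd3" (the command name shown in the install docstring) downloads into the default checkpoint directory, verifies the digest when "sha256" is present, and records the chosen directory as FOUNDRY_CHECKPOINTS_DIR in the project's .env via set_key.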
mpnn/__init__.py ADDED
@@ -0,0 +1 @@
+ """MPNN - ProteinMPNN and LigandMPNN implementations."""