rc-foundry 0.1.3__py3-none-any.whl → 0.1.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -69,15 +69,25 @@ class BaseInferenceEngine:
69
69
  self.verbose = verbose
70
70
 
71
71
  # Resolve checkpoint path
72
- if '.' not in str(ckpt_path):
72
+ if "." not in str(ckpt_path):
73
73
  # Assume registered model
74
74
  name = str(ckpt_path)
75
- assert name in REGISTERED_CHECKPOINTS, 'Checkpoint provided not and not in registered checkpoints'
75
+ assert (
76
+ name in REGISTERED_CHECKPOINTS
77
+ ), "Checkpoint provided not and not in registered checkpoints"
76
78
  ckpt = REGISTERED_CHECKPOINTS[name]
77
-
79
+
78
80
  ckpt_path = ckpt.get_default_path()
79
- ranked_logger.info("Using checkpoint from default installation directory, got: {}".format(str(ckpt_path)))
80
- assert os.path.exists(ckpt_path), 'Invalid checkpoint: {}. And could not find checkpoint in default installation location: {}'.format(name, ckpt_path)
81
+ ranked_logger.info(
82
+ "Using checkpoint from default installation directory, got: {}".format(
83
+ str(ckpt_path)
84
+ )
85
+ )
86
+ assert os.path.exists(
87
+ ckpt_path
88
+ ), "Invalid checkpoint: {}. And could not find checkpoint in default installation location: {}".format(
89
+ name, ckpt_path
90
+ )
81
91
  self.ckpt_path = Path(ckpt_path).resolve()
82
92
 
83
93
  # Set random seed (only if seed is not None)
@@ -1,4 +1,5 @@
1
- '''Management of checkpoints'''
1
+ """Management of checkpoints"""
2
+
2
3
  import os
3
4
  from dataclasses import dataclass
4
5
  from pathlib import Path
@@ -11,10 +12,13 @@ def get_default_checkpoint_dir() -> Path:
11
12
  1. FOUNDRY_CHECKPOINTS_DIR environment variable
12
13
  2. ~/.foundry/checkpoints
13
14
  """
14
- if "FOUNDRY_CHECKPOINTS_DIR" in os.environ and os.environ.get("FOUNDRY_CHECKPOINTS_DIR"):
15
+ if "FOUNDRY_CHECKPOINTS_DIR" in os.environ and os.environ.get(
16
+ "FOUNDRY_CHECKPOINTS_DIR"
17
+ ):
15
18
  return Path(os.environ["FOUNDRY_CHECKPOINTS_DIR"]).absolute()
16
19
  return Path.home() / ".foundry" / "checkpoints"
17
20
 
21
+
18
22
  @dataclass
19
23
  class RegisteredCheckpoint:
20
24
  url: str
@@ -28,39 +32,39 @@ class RegisteredCheckpoint:
28
32
 
29
33
  REGISTERED_CHECKPOINTS = {
30
34
  "rfd3": RegisteredCheckpoint(
31
- url = "https://files.ipd.uw.edu/pub/rfd3/rfd3_foundry_2025_12_01_remapped.ckpt",
32
- filename = "rfd3_latest.ckpt",
33
- description = "RFdiffusion3 checkpoint",
35
+ url="https://files.ipd.uw.edu/pub/rfd3/rfd3_foundry_2025_12_01_remapped.ckpt",
36
+ filename="rfd3_latest.ckpt",
37
+ description="RFdiffusion3 checkpoint",
34
38
  ),
35
- "rf3": RegisteredCheckpoint(
36
- url = "https://files.ipd.uw.edu/pub/rf3/rf3_foundry_01_24_latest_remapped.ckpt",
37
- filename= "rf3_foundry_01_24_latest_remapped.ckpt",
38
- description= "latest RF3 checkpoint trained with data until 1/2024 (expect best performance)",
39
+ "rf3": RegisteredCheckpoint(
40
+ url="https://files.ipd.uw.edu/pub/rf3/rf3_foundry_01_24_latest_remapped.ckpt",
41
+ filename="rf3_foundry_01_24_latest_remapped.ckpt",
42
+ description="latest RF3 checkpoint trained with data until 1/2024 (expect best performance)",
39
43
  ),
40
- "proteinmpnn": RegisteredCheckpoint(
41
- url = "https://files.ipd.uw.edu/pub/ligandmpnn/proteinmpnn_v_48_020.pt",
42
- filename= "proteinmpnn_v_48_020.pt",
43
- description= "ProteinMPNN checkpoint",
44
+ "proteinmpnn": RegisteredCheckpoint(
45
+ url="https://files.ipd.uw.edu/pub/ligandmpnn/proteinmpnn_v_48_020.pt",
46
+ filename="proteinmpnn_v_48_020.pt",
47
+ description="ProteinMPNN checkpoint",
44
48
  ),
45
49
  "ligandmpnn": RegisteredCheckpoint(
46
- url = "https://files.ipd.uw.edu/pub/ligandmpnn/ligandmpnn_v_32_010_25.pt",
47
- filename= "ligandmpnn_v_32_010_25.pt",
48
- description= "LigandMPNN checkpoint",
50
+ url="https://files.ipd.uw.edu/pub/ligandmpnn/ligandmpnn_v_32_010_25.pt",
51
+ filename="ligandmpnn_v_32_010_25.pt",
52
+ description="LigandMPNN checkpoint",
49
53
  ),
50
54
  # Other models
51
55
  "rf3_preprint_921": RegisteredCheckpoint(
52
- url = "https://files.ipd.uw.edu/pub/rf3/rf3_foundry_09_21_preprint_remapped.ckpt",
53
- filename = "rf3_foundry_09_21_preprint_remapped.ckpt",
54
- description = "RF3 preprint checkpoint trained with data until 9/2021",
56
+ url="https://files.ipd.uw.edu/pub/rf3/rf3_foundry_09_21_preprint_remapped.ckpt",
57
+ filename="rf3_foundry_09_21_preprint_remapped.ckpt",
58
+ description="RF3 preprint checkpoint trained with data until 9/2021",
55
59
  ),
56
60
  "rf3_preprint_124": RegisteredCheckpoint(
57
- url = "https://files.ipd.uw.edu/pub/rf3/rf3_foundry_01_24_preprint_remapped.ckpt",
58
- filename = "rf3_foundry_01_24_preprint_remapped.ckpt",
59
- description= "RF3 preprint checkpoint trained with data until 1/2024",
61
+ url="https://files.ipd.uw.edu/pub/rf3/rf3_foundry_01_24_preprint_remapped.ckpt",
62
+ filename="rf3_foundry_01_24_preprint_remapped.ckpt",
63
+ description="RF3 preprint checkpoint trained with data until 1/2024",
64
+ ),
65
+ "solublempnn": RegisteredCheckpoint(
66
+ url="https://files.ipd.uw.edu/pub/ligandmpnn/solublempnn_v_48_020.pt",
67
+ filename="solublempnn_v_48_020.pt",
68
+ description="SolubleMPNN checkpoint",
60
69
  ),
61
- "solublempnn": RegisteredCheckpoint(
62
- url = "https://files.ipd.uw.edu/pub/ligandmpnn/solublempnn_v_48_020.pt",
63
- filename= "solublempnn_v_48_020.pt",
64
- description= "SolubleMPNN checkpoint"
65
- )
66
70
  }
foundry/version.py CHANGED
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
28
28
  commit_id: COMMIT_ID
29
29
  __commit_id__: COMMIT_ID
30
30
 
31
- __version__ = version = '0.1.3'
32
- __version_tuple__ = version_tuple = (0, 1, 3)
31
+ __version__ = version = '0.1.4'
32
+ __version_tuple__ = version_tuple = (0, 1, 4)
33
33
 
34
34
  __commit_id__ = commit_id = None
@@ -5,7 +5,6 @@ from pathlib import Path
5
5
  from typing import Optional
6
6
  from urllib.request import urlopen
7
7
 
8
- import rootutils
9
8
  import typer
10
9
  from dotenv import find_dotenv, load_dotenv, set_key
11
10
  from rich.console import Console
@@ -29,6 +28,7 @@ load_dotenv(override=True)
29
28
  app = typer.Typer(help="Foundry model checkpoint installation utilities")
30
29
  console = Console()
31
30
 
31
+
32
32
  def download_file(url: str, dest: Path, verify_hash: Optional[str] = None) -> None:
33
33
  """Download a file with progress bar and optional hash verification.
34
34
 
@@ -81,9 +81,7 @@ def download_file(url: str, dest: Path, verify_hash: Optional[str] = None) -> No
81
81
  console.print("[green]✓[/green] Hash verification passed")
82
82
 
83
83
 
84
- def install_model(
85
- model_name: str, checkpoint_dir: Path, force: bool = False
86
- ) -> None:
84
+ def install_model(model_name: str, checkpoint_dir: Path, force: bool = False) -> None:
87
85
  """Install a single model checkpoint.
88
86
 
89
87
  Args:
@@ -112,9 +110,7 @@ def install_model(
112
110
  )
113
111
 
114
112
  try:
115
- download_file(
116
- checkpoint_info.url, dest_path, checkpoint_info.sha256
117
- )
113
+ download_file(checkpoint_info.url, dest_path, checkpoint_info.sha256)
118
114
  console.print(
119
115
  f"[green]✓[/green] Successfully installed {model_name} to {dest_path}"
120
116
  )
@@ -158,7 +154,7 @@ def install(
158
154
 
159
155
  # Expand 'all' to all available models
160
156
  if "all" in models:
161
- models_to_install = ['rfd3', 'proteinmpnn', 'ligandmpnn', 'rf3']
157
+ models_to_install = ["rfd3", "proteinmpnn", "ligandmpnn", "rf3"]
162
158
  else:
163
159
  models_to_install = models
164
160
 
@@ -167,15 +163,16 @@ def install(
167
163
  install_model(model_name, checkpoint_dir, force)
168
164
  console.print()
169
165
 
170
- set_key(
171
- dotenv_path=find_dotenv(),
172
- key_to_set='FOUNDRY_CHECKPOINTS_DIR',
173
- value_to_set=str(checkpoint_dir),
174
- export = False,
175
- )
176
- console.print(
177
- f"Set checkpoint installation directory to: {checkpoint_dir}"
178
- )
166
+ # Try to persist checkpoint dir to .env (optional, may not exist in Colab etc.)
167
+ dotenv_path = find_dotenv()
168
+ if dotenv_path:
169
+ set_key(
170
+ dotenv_path=dotenv_path,
171
+ key_to_set="FOUNDRY_CHECKPOINTS_DIR",
172
+ value_to_set=str(checkpoint_dir),
173
+ export=False,
174
+ )
175
+ console.print(f"Saved FOUNDRY_CHECKPOINTS_DIR to {dotenv_path}")
179
176
 
180
177
  console.print("[bold green]Installation complete![/bold green]")
181
178
 
@@ -209,9 +206,7 @@ def show(
209
206
 
210
207
  checkpoint_files = list(checkpoint_dir.glob("*.ckpt"))
211
208
  if not checkpoint_files:
212
- console.print(
213
- f"[yellow]No checkpoint files found in {checkpoint_dir}[/yellow]"
214
- )
209
+ console.print(f"[yellow]No checkpoint files found in {checkpoint_dir}[/yellow]")
215
210
  raise typer.Exit(0)
216
211
 
217
212
  console.print(f"[bold]Installed checkpoints in {checkpoint_dir}:[/bold]\n")
@@ -247,9 +242,7 @@ def clean(
247
242
  # List files to delete
248
243
  checkpoint_files = list(checkpoint_dir.glob("*.ckpt"))
249
244
  if not checkpoint_files:
250
- console.print(
251
- f"[yellow]No checkpoint files found in {checkpoint_dir}[/yellow]"
252
- )
245
+ console.print(f"[yellow]No checkpoint files found in {checkpoint_dir}[/yellow]")
253
246
  raise typer.Exit(0)
254
247
 
255
248
  console.print("[bold]Files to delete:[/bold]")
@@ -54,11 +54,18 @@ class MPNNInferenceEngine:
54
54
  self.out_directory = out_directory
55
55
  self.write_fasta = write_fasta
56
56
  self.write_structures = write_structures
57
-
57
+
58
58
  # allow null for checkpoint path when foundry-installed
59
59
  # TODO: Currently this assumes the model type is the key in the registered path. Rework needed
60
- self.checkpoint_path = str(REGISTERED_CHECKPOINTS[self.model_type.replace('_', '')].get_default_path()) \
61
- if not checkpoint_path else checkpoint_path
60
+ self.checkpoint_path = (
61
+ str(
62
+ REGISTERED_CHECKPOINTS[
63
+ self.model_type.replace("_", "")
64
+ ].get_default_path()
65
+ )
66
+ if not checkpoint_path
67
+ else checkpoint_path
68
+ )
62
69
 
63
70
  # Determine the device.
64
71
  if device is not None:
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: rc-foundry
3
- Version: 0.1.3
3
+ Version: 0.1.4
4
4
  Summary: Shared utilities and training infrastructure for biomolecular structure prediction models.
5
5
  Author-email: Institute for Protein Design <contact@ipd.uw.edu>
6
6
  License: BSD 3-Clause License
@@ -123,7 +123,7 @@ foundry install rfd3 ligandmpnn rf3 --checkpoint_dir <path/to/ckpt/dir>
123
123
  > *See [models/rfd3/README.md](models/rfd3/README.md) for complete documentation.*
124
124
 
125
125
  <div align="center">
126
- <img src="docs/_static/rfd3_trajectory.png" alt="RFdiffusion3 generation trajectory." width="400">
126
+ <img src="models/rfd3/docs/.assets/trajectory.png" alt="RFdiffusion3 generation trajectory." width="700">
127
127
  </div>
128
128
 
129
129
  ### ProteinMPNN
@@ -1,7 +1,7 @@
1
1
  foundry/__init__.py,sha256=H8S1nl5v6YeW8ggn1jKy4GdtH7c-FGS-j7CqUCAEnAU,1926
2
2
  foundry/common.py,sha256=Aur8mH-CNmcUqSsw7VgaCQSW5sH1Bqf8Da91jzxPV1Y,3035
3
3
  foundry/constants.py,sha256=0n1wBKCvNuw3QaQehSbmsHYkIdaGn3tLeRFItBrdeHY,913
4
- foundry/version.py,sha256=q5nF98G8SoVeJqaknL0xdyxtv0egsqb0fK06_84Izu8,704
4
+ foundry/version.py,sha256=rLCrf4heo25FJtBY-2Ap7ZuWW-5FS7sqTjsolIUuI5c,704
5
5
  foundry/callbacks/__init__.py,sha256=VsRT1e4sqlJHPcTCsfupMEx82Iz-LoOAGPpwvf_OJeE,126
6
6
  foundry/callbacks/callback.py,sha256=xZBo_suP4bLrP6gl5uJPbaXm00DXigePa6dMeDxucgg,3890
7
7
  foundry/callbacks/health_logging.py,sha256=tEtkByOlaAA7nnelxb7PbM9_dcIgOsdbxCdQY3K5pMc,16664
@@ -9,8 +9,8 @@ foundry/callbacks/metrics_logging.py,sha256=Vekzs831d-HE7TfLJZnQ45iPeG9ziQWLQaMB
9
9
  foundry/callbacks/timing_logging.py,sha256=u-r0hKp7fWOY3mLk7CcuIwHgZbhte13m5M09xNgatZA,2343
10
10
  foundry/callbacks/train_logging.py,sha256=Xs3tmZA88qLxmdSOwt-x8YKN4NKb1kVm59uptNXl4Qo,10399
11
11
  foundry/hydra/resolvers.py,sha256=xyJzo6OeWAc_LOu8RiHhX7_CRNoLZ22626AvYHXYl4U,2186
12
- foundry/inference_engines/base.py,sha256=qv5Gnk6NIxMpIxZ3oeOJurqMMUzBCZgfzHckb7SSzmU,8227
13
- foundry/inference_engines/checkpoint_registry.py,sha256=xRN3PjtmcnN7aEEhDR0MKhp1yaMDXHRXdLAGTDxi_Yk,2563
12
+ foundry/inference_engines/base.py,sha256=ZHdlmGUqH4-p3v4RdrLH-Ps8_zalr7j5mQ4x-S53N4M,8375
13
+ foundry/inference_engines/checkpoint_registry.py,sha256=kt2Z1JhrAjoOiEpkIIQ0sLttie1ceL8OgXUBmmyA6iw,2544
14
14
  foundry/metrics/__init__.py,sha256=qL4wwaiQ7EtR30pmZ9MCknqx909BJcNvHVmNJUaz_WM,236
15
15
  foundry/metrics/losses.py,sha256=2CLUmf7oCdFUCvgJukdNkff0FVG3BlATI-NI60TtpVY,903
16
16
  foundry/metrics/metric.py,sha256=23pKh_Ra0EcHGo5cSzYQQrUGr5zWRxeufKSJ58tfXXo,12687
@@ -34,12 +34,12 @@ foundry/utils/squashfs.py,sha256=QlcwuJyVe-QVfIOS7o1QfLhaCQPNzzox7ln4n8dcYEg,523
34
34
  foundry/utils/torch.py,sha256=OLsqoxw4CTXbGzWUHernLUT7uQjLu0tVPtD8h8747DI,11211
35
35
  foundry/utils/weights.py,sha256=btz4S02xff2vgiq4xMfiXuhK1ERafqQPtmimo1DmoWY,10381
36
36
  foundry_cli/__init__.py,sha256=0BxY2RUKJLaMXUGgypPCwlTskTEFdVnkhTR4C4ft2Kw,52
37
- foundry_cli/download_checkpoints.py,sha256=CDMxm3otzpb_c46AE2OwQY7aG7K_vnCJLKjP4U3b0Oc,8380
37
+ foundry_cli/download_checkpoints.py,sha256=UCNdy4VZyJe1PH_lnVLqy-VSMuTu875mGGd99ma7fTQ,8426
38
38
  mpnn/__init__.py,sha256=hgQcXFaCbAxFrhydVAy0xj8yC7UJF-GCCFhqD0sZ7I4,57
39
39
  mpnn/inference.py,sha256=wPtGR325eVRVeesXoWtBK6b_-VcU8BZae5IfQN3-mvA,1669
40
40
  mpnn/train.py,sha256=9eQGBd3rdNF5Zr2w8oUgETbqxBavNBajtA6Vbc5zESE,10239
41
41
  mpnn/collate/feature_collator.py,sha256=LpzAFWo1VMa06dJLmfUWZsKe4xvLZjHbx4RICg2lgbQ,10510
42
- mpnn/inference_engines/mpnn.py,sha256=985Ce84dWKNZk_5_dk1eJScYIgBGUKUhy-xDVCu15PA,21631
42
+ mpnn/inference_engines/mpnn.py,sha256=PmDEsIFipdk2fY57FA-vCp4evoU83DVVuUVmlViUtWk,21725
43
43
  mpnn/loss/nll_loss.py,sha256=KmdNe-BCzGYtijjappzBArQcT1gHVlJnKdY1PYQ4mhU,5947
44
44
  mpnn/metrics/nll.py,sha256=T6oMeUOEeHZzOMTH8NHFtsH9vUwLAsHQDPszzj4YKXI,15299
45
45
  mpnn/metrics/sequence_recovery.py,sha256=YDw_LmH-a3ajBYWK0mucJEQvw0_VEyxvrBN7da4vX8Q,19034
@@ -119,7 +119,7 @@ rfd3/__init__.py,sha256=2Wto2IsUIj2lGag9m_gqgdCwBNl5p21-Xnr7W_RpU3c,348
119
119
  rfd3/callbacks.py,sha256=Zjt8RiaYWquoKOwRmC_wCUbRbov-V4zd2_73zjhgDHE,2783
120
120
  rfd3/cli.py,sha256=TZpZouXGmwAMFaH8hp4r3q9tbUi1xlcN8n_r8hO2q8c,1424
121
121
  rfd3/constants.py,sha256=wLvDzrThpOrK8T3wGFNQeGrhAXOJQze8l3v_7pjIdMM,13141
122
- rfd3/engine.py,sha256=HCIHgB_4xFBYz7d7DGnqibnFcdC7RGXQ6qHgUJEvm2M,20889
122
+ rfd3/engine.py,sha256=La_dB48Ewz0IdY1ocxvSWg-PXVAsySm0OGvwyz42lI8,20824
123
123
  rfd3/run_inference.py,sha256=dubMwEFkNPOq_yYf0ny37qvvEkRjNNPRFksZgmEFkUc,1520
124
124
  rfd3/train.py,sha256=rHswffIUhOae3_iYyvAiQ3jALoFuzrcRUgMlbJLinlI,7947
125
125
  rfd3/inference/datasets.py,sha256=u-2U7deHXu-iOs7doiKKynewP-NEyJfdORSTDzUSaQI,6538
@@ -173,8 +173,8 @@ rfd3/transforms/virtual_atoms.py,sha256=UpmxzPPd5FaJigcRoxgLSHHrLLOqsCvZ5PPZfQSG
173
173
  rfd3/utils/inference.py,sha256=RQp5CCy6Z6uHVZ2Mx0zmmGluYEOrASke4bABtfRjpy0,26448
174
174
  rfd3/utils/io.py,sha256=wbdjUTQkDc3RCSM7gdogA-XOKR68HeQ-cfvyN4pP90w,9849
175
175
  rfd3/utils/vizualize.py,sha256=HPlczrA3zkOuxV5X05eOvy_Oga9e3cPnFUXOEP4RR_g,11046
176
- rc_foundry-0.1.3.dist-info/METADATA,sha256=gY8d46JTRTfot6O9MqW34E0D1O64xNhMeVdpJrQ-G18,10578
177
- rc_foundry-0.1.3.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
178
- rc_foundry-0.1.3.dist-info/entry_points.txt,sha256=BmiWCbWGtrd_lSOFMuCLBXyo84B7Nco-alj7hB0Yw9A,130
179
- rc_foundry-0.1.3.dist-info/licenses/LICENSE.md,sha256=NKtPCJ7QMysFmzeDg56ZfUStvgzbq5sOvRQv7_ddZOs,1533
180
- rc_foundry-0.1.3.dist-info/RECORD,,
176
+ rc_foundry-0.1.4.dist-info/METADATA,sha256=hzcS1buvLzRRAv7rPRgKwYjeNDL_iTGyR6u8CRpL-Ic,10585
177
+ rc_foundry-0.1.4.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
178
+ rc_foundry-0.1.4.dist-info/entry_points.txt,sha256=BmiWCbWGtrd_lSOFMuCLBXyo84B7Nco-alj7hB0Yw9A,130
179
+ rc_foundry-0.1.4.dist-info/licenses/LICENSE.md,sha256=NKtPCJ7QMysFmzeDg56ZfUStvgzbq5sOvRQv7_ddZOs,1533
180
+ rc_foundry-0.1.4.dist-info/RECORD,,
rfd3/engine.py CHANGED
@@ -15,7 +15,6 @@ from toolz import merge_with
15
15
 
16
16
  from foundry.common import exists
17
17
  from foundry.inference_engines.base import BaseInferenceEngine
18
- from foundry.inference_engines.checkpoint_registry import REGISTERED_CHECKPOINTS
19
18
  from foundry.utils.alignment import weighted_rigid_align
20
19
  from foundry.utils.ddp import RankedLogger
21
20
  from rfd3.constants import SAVED_CONDITIONING_ANNOTATIONS
@@ -38,7 +37,9 @@ ranked_logger = RankedLogger(__name__, rank_zero_only=True)
38
37
 
39
38
  @dataclass(kw_only=True)
40
39
  class RFD3InferenceConfig:
41
- ckpt_path: str | Path = 'rfd3' # Defaults to foundry installation upon instantiation
40
+ ckpt_path: str | Path = (
41
+ "rfd3" # Defaults to foundry installation upon instantiation
42
+ )
42
43
  diffusion_batch_size: int = 16
43
44
 
44
45
  # RFD3 specific