rc-foundry 0.1.3__py3-none-any.whl → 0.1.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- foundry/inference_engines/base.py +15 -5
- foundry/inference_engines/checkpoint_registry.py +31 -27
- foundry/version.py +2 -2
- foundry_cli/download_checkpoints.py +16 -23
- mpnn/inference_engines/mpnn.py +10 -3
- {rc_foundry-0.1.3.dist-info → rc_foundry-0.1.5.dist-info}/METADATA +2 -2
- {rc_foundry-0.1.3.dist-info → rc_foundry-0.1.5.dist-info}/RECORD +12 -12
- rfd3/engine.py +3 -2
- rfd3/run_inference.py +3 -7
- {rc_foundry-0.1.3.dist-info → rc_foundry-0.1.5.dist-info}/WHEEL +0 -0
- {rc_foundry-0.1.3.dist-info → rc_foundry-0.1.5.dist-info}/entry_points.txt +0 -0
- {rc_foundry-0.1.3.dist-info → rc_foundry-0.1.5.dist-info}/licenses/LICENSE.md +0 -0
|
@@ -69,15 +69,25 @@ class BaseInferenceEngine:
|
|
|
69
69
|
self.verbose = verbose
|
|
70
70
|
|
|
71
71
|
# Resolve checkpoint path
|
|
72
|
-
if
|
|
72
|
+
if "." not in str(ckpt_path):
|
|
73
73
|
# Assume registered model
|
|
74
74
|
name = str(ckpt_path)
|
|
75
|
-
assert
|
|
75
|
+
assert (
|
|
76
|
+
name in REGISTERED_CHECKPOINTS
|
|
77
|
+
), "Checkpoint provided not and not in registered checkpoints"
|
|
76
78
|
ckpt = REGISTERED_CHECKPOINTS[name]
|
|
77
|
-
|
|
79
|
+
|
|
78
80
|
ckpt_path = ckpt.get_default_path()
|
|
79
|
-
ranked_logger.info(
|
|
80
|
-
|
|
81
|
+
ranked_logger.info(
|
|
82
|
+
"Using checkpoint from default installation directory, got: {}".format(
|
|
83
|
+
str(ckpt_path)
|
|
84
|
+
)
|
|
85
|
+
)
|
|
86
|
+
assert os.path.exists(
|
|
87
|
+
ckpt_path
|
|
88
|
+
), "Invalid checkpoint: {}. And could not find checkpoint in default installation location: {}".format(
|
|
89
|
+
name, ckpt_path
|
|
90
|
+
)
|
|
81
91
|
self.ckpt_path = Path(ckpt_path).resolve()
|
|
82
92
|
|
|
83
93
|
# Set random seed (only if seed is not None)
|
|
@@ -1,4 +1,5 @@
|
|
|
1
|
-
|
|
1
|
+
"""Management of checkpoints"""
|
|
2
|
+
|
|
2
3
|
import os
|
|
3
4
|
from dataclasses import dataclass
|
|
4
5
|
from pathlib import Path
|
|
@@ -11,10 +12,13 @@ def get_default_checkpoint_dir() -> Path:
|
|
|
11
12
|
1. FOUNDRY_CHECKPOINTS_DIR environment variable
|
|
12
13
|
2. ~/.foundry/checkpoints
|
|
13
14
|
"""
|
|
14
|
-
if "FOUNDRY_CHECKPOINTS_DIR" in os.environ and os.environ.get(
|
|
15
|
+
if "FOUNDRY_CHECKPOINTS_DIR" in os.environ and os.environ.get(
|
|
16
|
+
"FOUNDRY_CHECKPOINTS_DIR"
|
|
17
|
+
):
|
|
15
18
|
return Path(os.environ["FOUNDRY_CHECKPOINTS_DIR"]).absolute()
|
|
16
19
|
return Path.home() / ".foundry" / "checkpoints"
|
|
17
20
|
|
|
21
|
+
|
|
18
22
|
@dataclass
|
|
19
23
|
class RegisteredCheckpoint:
|
|
20
24
|
url: str
|
|
@@ -28,39 +32,39 @@ class RegisteredCheckpoint:
|
|
|
28
32
|
|
|
29
33
|
REGISTERED_CHECKPOINTS = {
|
|
30
34
|
"rfd3": RegisteredCheckpoint(
|
|
31
|
-
url
|
|
32
|
-
filename
|
|
33
|
-
description
|
|
35
|
+
url="https://files.ipd.uw.edu/pub/rfd3/rfd3_foundry_2025_12_01_remapped.ckpt",
|
|
36
|
+
filename="rfd3_latest.ckpt",
|
|
37
|
+
description="RFdiffusion3 checkpoint",
|
|
34
38
|
),
|
|
35
|
-
"rf3":
|
|
36
|
-
url
|
|
37
|
-
filename=
|
|
38
|
-
description=
|
|
39
|
+
"rf3": RegisteredCheckpoint(
|
|
40
|
+
url="https://files.ipd.uw.edu/pub/rf3/rf3_foundry_01_24_latest_remapped.ckpt",
|
|
41
|
+
filename="rf3_foundry_01_24_latest_remapped.ckpt",
|
|
42
|
+
description="latest RF3 checkpoint trained with data until 1/2024 (expect best performance)",
|
|
39
43
|
),
|
|
40
|
-
"proteinmpnn":
|
|
41
|
-
url
|
|
42
|
-
filename=
|
|
43
|
-
description=
|
|
44
|
+
"proteinmpnn": RegisteredCheckpoint(
|
|
45
|
+
url="https://files.ipd.uw.edu/pub/ligandmpnn/proteinmpnn_v_48_020.pt",
|
|
46
|
+
filename="proteinmpnn_v_48_020.pt",
|
|
47
|
+
description="ProteinMPNN checkpoint",
|
|
44
48
|
),
|
|
45
49
|
"ligandmpnn": RegisteredCheckpoint(
|
|
46
|
-
url
|
|
47
|
-
filename=
|
|
48
|
-
description=
|
|
50
|
+
url="https://files.ipd.uw.edu/pub/ligandmpnn/ligandmpnn_v_32_010_25.pt",
|
|
51
|
+
filename="ligandmpnn_v_32_010_25.pt",
|
|
52
|
+
description="LigandMPNN checkpoint",
|
|
49
53
|
),
|
|
50
54
|
# Other models
|
|
51
55
|
"rf3_preprint_921": RegisteredCheckpoint(
|
|
52
|
-
url
|
|
53
|
-
filename
|
|
54
|
-
description
|
|
56
|
+
url="https://files.ipd.uw.edu/pub/rf3/rf3_foundry_09_21_preprint_remapped.ckpt",
|
|
57
|
+
filename="rf3_foundry_09_21_preprint_remapped.ckpt",
|
|
58
|
+
description="RF3 preprint checkpoint trained with data until 9/2021",
|
|
55
59
|
),
|
|
56
60
|
"rf3_preprint_124": RegisteredCheckpoint(
|
|
57
|
-
url
|
|
58
|
-
filename
|
|
59
|
-
description=
|
|
61
|
+
url="https://files.ipd.uw.edu/pub/rf3/rf3_foundry_01_24_preprint_remapped.ckpt",
|
|
62
|
+
filename="rf3_foundry_01_24_preprint_remapped.ckpt",
|
|
63
|
+
description="RF3 preprint checkpoint trained with data until 1/2024",
|
|
64
|
+
),
|
|
65
|
+
"solublempnn": RegisteredCheckpoint(
|
|
66
|
+
url="https://files.ipd.uw.edu/pub/ligandmpnn/solublempnn_v_48_020.pt",
|
|
67
|
+
filename="solublempnn_v_48_020.pt",
|
|
68
|
+
description="SolubleMPNN checkpoint",
|
|
60
69
|
),
|
|
61
|
-
"solublempnn": RegisteredCheckpoint(
|
|
62
|
-
url = "https://files.ipd.uw.edu/pub/ligandmpnn/solublempnn_v_48_020.pt",
|
|
63
|
-
filename= "solublempnn_v_48_020.pt",
|
|
64
|
-
description= "SolubleMPNN checkpoint"
|
|
65
|
-
)
|
|
66
70
|
}
|
foundry/version.py
CHANGED
|
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
|
|
|
28
28
|
commit_id: COMMIT_ID
|
|
29
29
|
__commit_id__: COMMIT_ID
|
|
30
30
|
|
|
31
|
-
__version__ = version = '0.1.
|
|
32
|
-
__version_tuple__ = version_tuple = (0, 1,
|
|
31
|
+
__version__ = version = '0.1.5'
|
|
32
|
+
__version_tuple__ = version_tuple = (0, 1, 5)
|
|
33
33
|
|
|
34
34
|
__commit_id__ = commit_id = None
|
|
@@ -5,7 +5,6 @@ from pathlib import Path
|
|
|
5
5
|
from typing import Optional
|
|
6
6
|
from urllib.request import urlopen
|
|
7
7
|
|
|
8
|
-
import rootutils
|
|
9
8
|
import typer
|
|
10
9
|
from dotenv import find_dotenv, load_dotenv, set_key
|
|
11
10
|
from rich.console import Console
|
|
@@ -29,6 +28,7 @@ load_dotenv(override=True)
|
|
|
29
28
|
app = typer.Typer(help="Foundry model checkpoint installation utilities")
|
|
30
29
|
console = Console()
|
|
31
30
|
|
|
31
|
+
|
|
32
32
|
def download_file(url: str, dest: Path, verify_hash: Optional[str] = None) -> None:
|
|
33
33
|
"""Download a file with progress bar and optional hash verification.
|
|
34
34
|
|
|
@@ -81,9 +81,7 @@ def download_file(url: str, dest: Path, verify_hash: Optional[str] = None) -> No
|
|
|
81
81
|
console.print("[green]✓[/green] Hash verification passed")
|
|
82
82
|
|
|
83
83
|
|
|
84
|
-
def install_model(
|
|
85
|
-
model_name: str, checkpoint_dir: Path, force: bool = False
|
|
86
|
-
) -> None:
|
|
84
|
+
def install_model(model_name: str, checkpoint_dir: Path, force: bool = False) -> None:
|
|
87
85
|
"""Install a single model checkpoint.
|
|
88
86
|
|
|
89
87
|
Args:
|
|
@@ -112,9 +110,7 @@ def install_model(
|
|
|
112
110
|
)
|
|
113
111
|
|
|
114
112
|
try:
|
|
115
|
-
download_file(
|
|
116
|
-
checkpoint_info.url, dest_path, checkpoint_info.sha256
|
|
117
|
-
)
|
|
113
|
+
download_file(checkpoint_info.url, dest_path, checkpoint_info.sha256)
|
|
118
114
|
console.print(
|
|
119
115
|
f"[green]✓[/green] Successfully installed {model_name} to {dest_path}"
|
|
120
116
|
)
|
|
@@ -158,7 +154,7 @@ def install(
|
|
|
158
154
|
|
|
159
155
|
# Expand 'all' to all available models
|
|
160
156
|
if "all" in models:
|
|
161
|
-
models_to_install = [
|
|
157
|
+
models_to_install = ["rfd3", "proteinmpnn", "ligandmpnn", "rf3"]
|
|
162
158
|
else:
|
|
163
159
|
models_to_install = models
|
|
164
160
|
|
|
@@ -167,15 +163,16 @@ def install(
|
|
|
167
163
|
install_model(model_name, checkpoint_dir, force)
|
|
168
164
|
console.print()
|
|
169
165
|
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
166
|
+
# Try to persist checkpoint dir to .env (optional, may not exist in Colab etc.)
|
|
167
|
+
dotenv_path = find_dotenv()
|
|
168
|
+
if dotenv_path:
|
|
169
|
+
set_key(
|
|
170
|
+
dotenv_path=dotenv_path,
|
|
171
|
+
key_to_set="FOUNDRY_CHECKPOINTS_DIR",
|
|
172
|
+
value_to_set=str(checkpoint_dir),
|
|
173
|
+
export=False,
|
|
174
|
+
)
|
|
175
|
+
console.print(f"Saved FOUNDRY_CHECKPOINTS_DIR to {dotenv_path}")
|
|
179
176
|
|
|
180
177
|
console.print("[bold green]Installation complete![/bold green]")
|
|
181
178
|
|
|
@@ -209,9 +206,7 @@ def show(
|
|
|
209
206
|
|
|
210
207
|
checkpoint_files = list(checkpoint_dir.glob("*.ckpt"))
|
|
211
208
|
if not checkpoint_files:
|
|
212
|
-
console.print(
|
|
213
|
-
f"[yellow]No checkpoint files found in {checkpoint_dir}[/yellow]"
|
|
214
|
-
)
|
|
209
|
+
console.print(f"[yellow]No checkpoint files found in {checkpoint_dir}[/yellow]")
|
|
215
210
|
raise typer.Exit(0)
|
|
216
211
|
|
|
217
212
|
console.print(f"[bold]Installed checkpoints in {checkpoint_dir}:[/bold]\n")
|
|
@@ -247,9 +242,7 @@ def clean(
|
|
|
247
242
|
# List files to delete
|
|
248
243
|
checkpoint_files = list(checkpoint_dir.glob("*.ckpt"))
|
|
249
244
|
if not checkpoint_files:
|
|
250
|
-
console.print(
|
|
251
|
-
f"[yellow]No checkpoint files found in {checkpoint_dir}[/yellow]"
|
|
252
|
-
)
|
|
245
|
+
console.print(f"[yellow]No checkpoint files found in {checkpoint_dir}[/yellow]")
|
|
253
246
|
raise typer.Exit(0)
|
|
254
247
|
|
|
255
248
|
console.print("[bold]Files to delete:[/bold]")
|
mpnn/inference_engines/mpnn.py
CHANGED
|
@@ -54,11 +54,18 @@ class MPNNInferenceEngine:
|
|
|
54
54
|
self.out_directory = out_directory
|
|
55
55
|
self.write_fasta = write_fasta
|
|
56
56
|
self.write_structures = write_structures
|
|
57
|
-
|
|
57
|
+
|
|
58
58
|
# allow null for checkpoint path when foundry-installed
|
|
59
59
|
# TODO: Currently this assumes the model type is the key in the registered path. Rework needed
|
|
60
|
-
self.checkpoint_path =
|
|
61
|
-
|
|
60
|
+
self.checkpoint_path = (
|
|
61
|
+
str(
|
|
62
|
+
REGISTERED_CHECKPOINTS[
|
|
63
|
+
self.model_type.replace("_", "")
|
|
64
|
+
].get_default_path()
|
|
65
|
+
)
|
|
66
|
+
if not checkpoint_path
|
|
67
|
+
else checkpoint_path
|
|
68
|
+
)
|
|
62
69
|
|
|
63
70
|
# Determine the device.
|
|
64
71
|
if device is not None:
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: rc-foundry
|
|
3
|
-
Version: 0.1.
|
|
3
|
+
Version: 0.1.5
|
|
4
4
|
Summary: Shared utilities and training infrastructure for biomolecular structure prediction models.
|
|
5
5
|
Author-email: Institute for Protein Design <contact@ipd.uw.edu>
|
|
6
6
|
License: BSD 3-Clause License
|
|
@@ -123,7 +123,7 @@ foundry install rfd3 ligandmpnn rf3 --checkpoint_dir <path/to/ckpt/dir>
|
|
|
123
123
|
> *See [models/rfd3/README.md](models/rfd3/README.md) for complete documentation.*
|
|
124
124
|
|
|
125
125
|
<div align="center">
|
|
126
|
-
<img src="docs/
|
|
126
|
+
<img src="models/rfd3/docs/.assets/trajectory.png" alt="RFdiffusion3 generation trajectory." width="700">
|
|
127
127
|
</div>
|
|
128
128
|
|
|
129
129
|
### ProteinMPNN
|
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
foundry/__init__.py,sha256=H8S1nl5v6YeW8ggn1jKy4GdtH7c-FGS-j7CqUCAEnAU,1926
|
|
2
2
|
foundry/common.py,sha256=Aur8mH-CNmcUqSsw7VgaCQSW5sH1Bqf8Da91jzxPV1Y,3035
|
|
3
3
|
foundry/constants.py,sha256=0n1wBKCvNuw3QaQehSbmsHYkIdaGn3tLeRFItBrdeHY,913
|
|
4
|
-
foundry/version.py,sha256=
|
|
4
|
+
foundry/version.py,sha256=rdxBMYpwzYxiWk08QbPLHSAxHoDfeKWwyaJIAM0lSic,704
|
|
5
5
|
foundry/callbacks/__init__.py,sha256=VsRT1e4sqlJHPcTCsfupMEx82Iz-LoOAGPpwvf_OJeE,126
|
|
6
6
|
foundry/callbacks/callback.py,sha256=xZBo_suP4bLrP6gl5uJPbaXm00DXigePa6dMeDxucgg,3890
|
|
7
7
|
foundry/callbacks/health_logging.py,sha256=tEtkByOlaAA7nnelxb7PbM9_dcIgOsdbxCdQY3K5pMc,16664
|
|
@@ -9,8 +9,8 @@ foundry/callbacks/metrics_logging.py,sha256=Vekzs831d-HE7TfLJZnQ45iPeG9ziQWLQaMB
|
|
|
9
9
|
foundry/callbacks/timing_logging.py,sha256=u-r0hKp7fWOY3mLk7CcuIwHgZbhte13m5M09xNgatZA,2343
|
|
10
10
|
foundry/callbacks/train_logging.py,sha256=Xs3tmZA88qLxmdSOwt-x8YKN4NKb1kVm59uptNXl4Qo,10399
|
|
11
11
|
foundry/hydra/resolvers.py,sha256=xyJzo6OeWAc_LOu8RiHhX7_CRNoLZ22626AvYHXYl4U,2186
|
|
12
|
-
foundry/inference_engines/base.py,sha256=
|
|
13
|
-
foundry/inference_engines/checkpoint_registry.py,sha256=
|
|
12
|
+
foundry/inference_engines/base.py,sha256=ZHdlmGUqH4-p3v4RdrLH-Ps8_zalr7j5mQ4x-S53N4M,8375
|
|
13
|
+
foundry/inference_engines/checkpoint_registry.py,sha256=kt2Z1JhrAjoOiEpkIIQ0sLttie1ceL8OgXUBmmyA6iw,2544
|
|
14
14
|
foundry/metrics/__init__.py,sha256=qL4wwaiQ7EtR30pmZ9MCknqx909BJcNvHVmNJUaz_WM,236
|
|
15
15
|
foundry/metrics/losses.py,sha256=2CLUmf7oCdFUCvgJukdNkff0FVG3BlATI-NI60TtpVY,903
|
|
16
16
|
foundry/metrics/metric.py,sha256=23pKh_Ra0EcHGo5cSzYQQrUGr5zWRxeufKSJ58tfXXo,12687
|
|
@@ -34,12 +34,12 @@ foundry/utils/squashfs.py,sha256=QlcwuJyVe-QVfIOS7o1QfLhaCQPNzzox7ln4n8dcYEg,523
|
|
|
34
34
|
foundry/utils/torch.py,sha256=OLsqoxw4CTXbGzWUHernLUT7uQjLu0tVPtD8h8747DI,11211
|
|
35
35
|
foundry/utils/weights.py,sha256=btz4S02xff2vgiq4xMfiXuhK1ERafqQPtmimo1DmoWY,10381
|
|
36
36
|
foundry_cli/__init__.py,sha256=0BxY2RUKJLaMXUGgypPCwlTskTEFdVnkhTR4C4ft2Kw,52
|
|
37
|
-
foundry_cli/download_checkpoints.py,sha256=
|
|
37
|
+
foundry_cli/download_checkpoints.py,sha256=UCNdy4VZyJe1PH_lnVLqy-VSMuTu875mGGd99ma7fTQ,8426
|
|
38
38
|
mpnn/__init__.py,sha256=hgQcXFaCbAxFrhydVAy0xj8yC7UJF-GCCFhqD0sZ7I4,57
|
|
39
39
|
mpnn/inference.py,sha256=wPtGR325eVRVeesXoWtBK6b_-VcU8BZae5IfQN3-mvA,1669
|
|
40
40
|
mpnn/train.py,sha256=9eQGBd3rdNF5Zr2w8oUgETbqxBavNBajtA6Vbc5zESE,10239
|
|
41
41
|
mpnn/collate/feature_collator.py,sha256=LpzAFWo1VMa06dJLmfUWZsKe4xvLZjHbx4RICg2lgbQ,10510
|
|
42
|
-
mpnn/inference_engines/mpnn.py,sha256=
|
|
42
|
+
mpnn/inference_engines/mpnn.py,sha256=PmDEsIFipdk2fY57FA-vCp4evoU83DVVuUVmlViUtWk,21725
|
|
43
43
|
mpnn/loss/nll_loss.py,sha256=KmdNe-BCzGYtijjappzBArQcT1gHVlJnKdY1PYQ4mhU,5947
|
|
44
44
|
mpnn/metrics/nll.py,sha256=T6oMeUOEeHZzOMTH8NHFtsH9vUwLAsHQDPszzj4YKXI,15299
|
|
45
45
|
mpnn/metrics/sequence_recovery.py,sha256=YDw_LmH-a3ajBYWK0mucJEQvw0_VEyxvrBN7da4vX8Q,19034
|
|
@@ -119,8 +119,8 @@ rfd3/__init__.py,sha256=2Wto2IsUIj2lGag9m_gqgdCwBNl5p21-Xnr7W_RpU3c,348
|
|
|
119
119
|
rfd3/callbacks.py,sha256=Zjt8RiaYWquoKOwRmC_wCUbRbov-V4zd2_73zjhgDHE,2783
|
|
120
120
|
rfd3/cli.py,sha256=TZpZouXGmwAMFaH8hp4r3q9tbUi1xlcN8n_r8hO2q8c,1424
|
|
121
121
|
rfd3/constants.py,sha256=wLvDzrThpOrK8T3wGFNQeGrhAXOJQze8l3v_7pjIdMM,13141
|
|
122
|
-
rfd3/engine.py,sha256=
|
|
123
|
-
rfd3/run_inference.py,sha256=
|
|
122
|
+
rfd3/engine.py,sha256=La_dB48Ewz0IdY1ocxvSWg-PXVAsySm0OGvwyz42lI8,20824
|
|
123
|
+
rfd3/run_inference.py,sha256=ljzsCKEtrlfAvP0SDFPeQwTM3rV_X3ewHOhcRFVI37c,1258
|
|
124
124
|
rfd3/train.py,sha256=rHswffIUhOae3_iYyvAiQ3jALoFuzrcRUgMlbJLinlI,7947
|
|
125
125
|
rfd3/inference/datasets.py,sha256=u-2U7deHXu-iOs7doiKKynewP-NEyJfdORSTDzUSaQI,6538
|
|
126
126
|
rfd3/inference/input_parsing.py,sha256=mk3HBvo7MPTFEET7NagCo5TSjb47w-hxUDoeQxUW_h4,45449
|
|
@@ -173,8 +173,8 @@ rfd3/transforms/virtual_atoms.py,sha256=UpmxzPPd5FaJigcRoxgLSHHrLLOqsCvZ5PPZfQSG
|
|
|
173
173
|
rfd3/utils/inference.py,sha256=RQp5CCy6Z6uHVZ2Mx0zmmGluYEOrASke4bABtfRjpy0,26448
|
|
174
174
|
rfd3/utils/io.py,sha256=wbdjUTQkDc3RCSM7gdogA-XOKR68HeQ-cfvyN4pP90w,9849
|
|
175
175
|
rfd3/utils/vizualize.py,sha256=HPlczrA3zkOuxV5X05eOvy_Oga9e3cPnFUXOEP4RR_g,11046
|
|
176
|
-
rc_foundry-0.1.
|
|
177
|
-
rc_foundry-0.1.
|
|
178
|
-
rc_foundry-0.1.
|
|
179
|
-
rc_foundry-0.1.
|
|
180
|
-
rc_foundry-0.1.
|
|
176
|
+
rc_foundry-0.1.5.dist-info/METADATA,sha256=aUG8GCa8x-SjkduDG2FaLlLx5hT7uZW5qTZuoWv39l0,10585
|
|
177
|
+
rc_foundry-0.1.5.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
|
|
178
|
+
rc_foundry-0.1.5.dist-info/entry_points.txt,sha256=BmiWCbWGtrd_lSOFMuCLBXyo84B7Nco-alj7hB0Yw9A,130
|
|
179
|
+
rc_foundry-0.1.5.dist-info/licenses/LICENSE.md,sha256=NKtPCJ7QMysFmzeDg56ZfUStvgzbq5sOvRQv7_ddZOs,1533
|
|
180
|
+
rc_foundry-0.1.5.dist-info/RECORD,,
|
rfd3/engine.py
CHANGED
|
@@ -15,7 +15,6 @@ from toolz import merge_with
|
|
|
15
15
|
|
|
16
16
|
from foundry.common import exists
|
|
17
17
|
from foundry.inference_engines.base import BaseInferenceEngine
|
|
18
|
-
from foundry.inference_engines.checkpoint_registry import REGISTERED_CHECKPOINTS
|
|
19
18
|
from foundry.utils.alignment import weighted_rigid_align
|
|
20
19
|
from foundry.utils.ddp import RankedLogger
|
|
21
20
|
from rfd3.constants import SAVED_CONDITIONING_ANNOTATIONS
|
|
@@ -38,7 +37,9 @@ ranked_logger = RankedLogger(__name__, rank_zero_only=True)
|
|
|
38
37
|
|
|
39
38
|
@dataclass(kw_only=True)
|
|
40
39
|
class RFD3InferenceConfig:
|
|
41
|
-
ckpt_path: str | Path =
|
|
40
|
+
ckpt_path: str | Path = (
|
|
41
|
+
"rfd3" # Defaults to foundry installation upon instantiation
|
|
42
|
+
)
|
|
42
43
|
diffusion_batch_size: int = 16
|
|
43
44
|
|
|
44
45
|
# RFD3 specific
|
rfd3/run_inference.py
CHANGED
|
@@ -3,20 +3,16 @@
|
|
|
3
3
|
import os
|
|
4
4
|
|
|
5
5
|
import hydra
|
|
6
|
-
import rootutils
|
|
7
6
|
from dotenv import load_dotenv
|
|
8
7
|
from omegaconf import DictConfig, OmegaConf
|
|
9
8
|
|
|
10
9
|
from rfd3.engine import RFD3InferenceConfig, RFD3InferenceEngine
|
|
11
10
|
|
|
12
|
-
# Setup root dir and environment variables (more info: https://github.com/ashleve/rootutils)
|
|
13
|
-
# NOTE: Sets the `PROJECT_ROOT` environment variable to the root directory of the project (where `.project-root` is located)
|
|
14
|
-
rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
|
|
15
|
-
|
|
16
11
|
load_dotenv(override=True)
|
|
17
12
|
|
|
18
|
-
#
|
|
19
|
-
|
|
13
|
+
# For pip-installed package, configs should be relative to this file
|
|
14
|
+
# Adjust this path based on where configs are bundled in the package
|
|
15
|
+
_config_path = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__))), "configs")
|
|
20
16
|
|
|
21
17
|
|
|
22
18
|
@hydra.main(
|
|
File without changes
|
|
File without changes
|
|
File without changes
|