rc-foundry 0.1.6__py3-none-any.whl → 0.1.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- foundry/inference_engines/checkpoint_registry.py +58 -11
- foundry/utils/alignment.py +10 -2
- foundry/version.py +2 -2
- foundry_cli/download_checkpoints.py +66 -66
- {rc_foundry-0.1.6.dist-info → rc_foundry-0.1.7.dist-info}/METADATA +25 -20
- {rc_foundry-0.1.6.dist-info → rc_foundry-0.1.7.dist-info}/RECORD +22 -22
- rfd3/configs/datasets/train/pdb/af3_train_interface.yaml +1 -1
- rfd3/configs/inference_engine/rfdiffusion3.yaml +2 -2
- rfd3/configs/model/samplers/symmetry.yaml +1 -1
- rfd3/engine.py +19 -11
- rfd3/inference/input_parsing.py +1 -1
- rfd3/inference/legacy_input_parsing.py +17 -1
- rfd3/inference/parsing.py +1 -0
- rfd3/inference/symmetry/atom_array.py +1 -5
- rfd3/inference/symmetry/checks.py +53 -28
- rfd3/inference/symmetry/frames.py +8 -5
- rfd3/inference/symmetry/symmetry_utils.py +38 -60
- rfd3/run_inference.py +3 -1
- rfd3/utils/inference.py +23 -0
- {rc_foundry-0.1.6.dist-info → rc_foundry-0.1.7.dist-info}/WHEEL +0 -0
- {rc_foundry-0.1.6.dist-info → rc_foundry-0.1.7.dist-info}/entry_points.txt +0 -0
- {rc_foundry-0.1.6.dist-info → rc_foundry-0.1.7.dist-info}/licenses/LICENSE.md +0 -0
|
@@ -3,20 +3,62 @@
|
|
|
3
3
|
import os
|
|
4
4
|
from dataclasses import dataclass
|
|
5
5
|
from pathlib import Path
|
|
6
|
+
from typing import Iterable, List
|
|
6
7
|
|
|
8
|
+
import dotenv
|
|
7
9
|
|
|
8
|
-
|
|
9
|
-
|
|
10
|
+
DEFAULT_CHECKPOINT_DIR = Path.home() / ".foundry" / "checkpoints"
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def _normalize_paths(paths: Iterable[Path]) -> list[Path]:
|
|
14
|
+
"""Return absolute, deduplicated paths in order."""
|
|
15
|
+
seen = set()
|
|
16
|
+
normalized: List[Path] = []
|
|
17
|
+
for path in paths:
|
|
18
|
+
resolved = path.expanduser().absolute()
|
|
19
|
+
if resolved not in seen:
|
|
20
|
+
normalized.append(resolved)
|
|
21
|
+
seen.add(resolved)
|
|
22
|
+
return normalized
|
|
10
23
|
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
|
|
24
|
+
|
|
25
|
+
def get_default_checkpoint_dirs() -> list[Path]:
    """Return the checkpoint search path, highest priority first.

    Directories listed in the colon-separated ``FOUNDRY_CHECKPOINT_DIRS``
    environment variable come first, in the order given. If that variable is
    unset or empty, the legacy ``FOUNDRY_CHECKPOINTS_DIR`` variable is read
    instead. ``DEFAULT_CHECKPOINT_DIR`` is appended last unless it already
    appears among the configured directories, so it is the first entry only
    when no other directory is configured.

    Returns:
        Absolute, de-duplicated directories in search order.
    """
    # Fall back to the legacy variable name for backward compatibility.
    env_dirs = os.environ.get("FOUNDRY_CHECKPOINT_DIRS") or os.environ.get(
        "FOUNDRY_CHECKPOINTS_DIR", ""
    )
    # Blank entries (e.g. produced by a trailing ":") are ignored.
    extra_dirs = [Path(entry.strip()) for entry in env_dirs.split(":") if entry.strip()]
    return _normalize_paths([*extra_dirs, DEFAULT_CHECKPOINT_DIR])
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
def get_default_checkpoint_dir() -> Path:
    """Return the highest-priority checkpoint directory.

    Kept for backward compatibility with callers that expect a single
    directory: the result is the first entry of the list returned by
    ``get_default_checkpoint_dirs()``.
    """
    search_path = get_default_checkpoint_dirs()
    return search_path[0]
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
def append_checkpoint_to_env(checkpoint_dirs: list[Path]) -> bool:
    """Persist the checkpoint search path to a ``.env`` file, if one exists.

    The directories are normalized and written, colon-separated, as the value
    of ``FOUNDRY_CHECKPOINT_DIRS`` in the file located by
    ``dotenv.find_dotenv()``. The key is set to exactly the given list; no
    merging with a previously stored value happens here.

    Args:
        checkpoint_dirs: Directories to record, in priority order.

    Returns:
        ``True`` if a ``.env`` file was found and the key was written,
        ``False`` if no ``.env`` file was located.
    """
    env_file = dotenv.find_dotenv()
    if not env_file:
        return False

    joined_dirs = ":".join(str(path) for path in _normalize_paths(checkpoint_dirs))
    dotenv.set_key(
        dotenv_path=env_file,
        key_to_set="FOUNDRY_CHECKPOINT_DIRS",
        value_to_set=joined_dirs,
        export=False,
    )
    return True
|
|
20
62
|
|
|
21
63
|
|
|
22
64
|
@dataclass
|
|
@@ -27,7 +69,12 @@ class RegisteredCheckpoint:
|
|
|
27
69
|
sha256: None = None # Optional: add checksum for verification
|
|
28
70
|
|
|
29
71
|
def get_default_path(self):
    """Return where this checkpoint's file is, or should be, on disk.

    Every directory from ``get_default_checkpoint_dirs()`` is probed in order
    and the first existing ``<dir>/<self.filename>`` is returned. If the file
    exists in none of them, the path inside the first directory is returned
    instead (that path does not exist at the time of the call).
    """
    search_dirs = get_default_checkpoint_dirs()
    existing = (
        path
        for path in (directory / self.filename for directory in search_dirs)
        if path.exists()
    )
    return next(existing, search_dirs[0] / self.filename)
|
|
31
78
|
|
|
32
79
|
|
|
33
80
|
REGISTERED_CHECKPOINTS = {
|
foundry/utils/alignment.py
CHANGED
|
@@ -18,14 +18,19 @@ def weighted_rigid_align(
|
|
|
18
18
|
Returns:
|
|
19
19
|
X_align_L: [B, L, 3]
|
|
20
20
|
"""
|
|
21
|
-
assert X_L.shape == X_gt_L.shape
|
|
22
|
-
assert X_L.shape[:-1] == w_L.shape
|
|
23
21
|
|
|
22
|
+
# Canonicalize dimensions
|
|
23
|
+
if X_L.ndim == 2:
|
|
24
|
+
X_L = X_L[None]
|
|
25
|
+
if X_gt_L.ndim == 2:
|
|
26
|
+
X_gt_L = X_gt_L[None]
|
|
24
27
|
if X_exists_L is None:
|
|
25
28
|
X_exists_L = torch.ones((X_L.shape[-2]), dtype=torch.bool)
|
|
26
29
|
if w_L is None:
|
|
27
30
|
w_L = torch.ones_like(X_L[..., 0])
|
|
28
31
|
else:
|
|
32
|
+
if w_L.ndim == 1:
|
|
33
|
+
w_L = w_L[None]
|
|
29
34
|
w_L = w_L.to(torch.float32)
|
|
30
35
|
|
|
31
36
|
# Assert `X_exists_L` is a boolean mask
|
|
@@ -33,6 +38,9 @@ def weighted_rigid_align(
|
|
|
33
38
|
X_exists_L.dtype == torch.bool
|
|
34
39
|
), "X_exists_L should be a boolean mask! Otherwise, the alignment will be incorrect (silent failure)!"
|
|
35
40
|
|
|
41
|
+
assert X_L.shape == X_gt_L.shape
|
|
42
|
+
assert X_L.shape[:-1] == w_L.shape
|
|
43
|
+
|
|
36
44
|
X_resolved = X_L[:, X_exists_L]
|
|
37
45
|
X_gt_resolved = X_gt_L[:, X_exists_L]
|
|
38
46
|
w_resolved = w_L[:, X_exists_L]
|
foundry/version.py
CHANGED
|
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
|
|
|
28
28
|
commit_id: COMMIT_ID
|
|
29
29
|
__commit_id__: COMMIT_ID
|
|
30
30
|
|
|
31
|
-
__version__ = version = '0.1.
|
|
32
|
-
__version_tuple__ = version_tuple = (0, 1,
|
|
31
|
+
__version__ = version = '0.1.7'
|
|
32
|
+
__version_tuple__ = version_tuple = (0, 1, 7)
|
|
33
33
|
|
|
34
34
|
__commit_id__ = commit_id = None
|
|
@@ -6,7 +6,7 @@ from typing import Optional
|
|
|
6
6
|
from urllib.request import urlopen
|
|
7
7
|
|
|
8
8
|
import typer
|
|
9
|
-
from dotenv import
|
|
9
|
+
from dotenv import load_dotenv
|
|
10
10
|
from rich.console import Console
|
|
11
11
|
from rich.progress import (
|
|
12
12
|
BarColumn,
|
|
@@ -20,7 +20,8 @@ from rich.progress import (
|
|
|
20
20
|
|
|
21
21
|
from foundry.inference_engines.checkpoint_registry import (
|
|
22
22
|
REGISTERED_CHECKPOINTS,
|
|
23
|
-
|
|
23
|
+
append_checkpoint_to_env,
|
|
24
|
+
get_default_checkpoint_dirs,
|
|
24
25
|
)
|
|
25
26
|
|
|
26
27
|
load_dotenv(override=True)
|
|
@@ -29,6 +30,27 @@ app = typer.Typer(help="Foundry model checkpoint installation utilities")
|
|
|
29
30
|
console = Console()
|
|
30
31
|
|
|
31
32
|
|
|
33
|
+
def _resolve_checkpoint_dirs(checkpoint_dir: Optional[Path]) -> list[Path]:
    """Return the checkpoint search path with ``checkpoint_dir`` in front.

    Starts from ``get_default_checkpoint_dirs()``. When ``checkpoint_dir`` is
    given, it is user-expanded, made absolute, and placed at index 0 (moved
    there if it was already present). The resulting list is then offered to
    ``append_checkpoint_to_env``; if that reports success, the tracked
    directories are printed to the console.

    Args:
        checkpoint_dir: Directory explicitly requested by the user, or
            ``None`` to use the default search path unchanged.

    Returns:
        The search path; its first element is the primary directory.
    """
    checkpoint_dirs = get_default_checkpoint_dirs()
    if checkpoint_dir is not None:
        resolved = checkpoint_dir.expanduser().absolute()
        # Drop any existing occurrence so the explicit directory is moved to
        # the front rather than duplicated.
        if resolved in checkpoint_dirs:
            checkpoint_dirs.remove(resolved)
        checkpoint_dirs.insert(0, resolved)

    # Try to persist checkpoint dir to .env (optional, may not exist in Colab etc.)
    if append_checkpoint_to_env(checkpoint_dirs):
        console.print(
            f"Tracked checkpoint directories: {':'.join(str(path) for path in checkpoint_dirs)}"
        )

    return checkpoint_dirs
|
|
52
|
+
|
|
53
|
+
|
|
32
54
|
def download_file(url: str, dest: Path, verify_hash: Optional[str] = None) -> None:
|
|
33
55
|
"""Download a file with progress bar and optional hash verification.
|
|
34
56
|
|
|
@@ -123,134 +145,112 @@ def install_model(model_name: str, checkpoint_dir: Path, force: bool = False) ->
|
|
|
123
145
|
def install(
    models: list[str] = typer.Argument(
        ...,
        help="Models to install: 'all', 'base-models', 'rfd3', 'rf3', 'mpnn', or a combination thereof",
    ),
    checkpoint_dir: Optional[Path] = typer.Option(
        None,
        "--checkpoint-dir",
        "-d",
        help="Directory to save checkpoints (default search path: ~/.foundry/checkpoints plus any $FOUNDRY_CHECKPOINT_DIRS entries)",
    ),
    force: bool = typer.Option(
        False, "--force", "-f", help="Overwrite existing checkpoints"
    ),
):
    """Install model checkpoints for foundry.
    Examples:
    foundry install all
    foundry install rfd3 rf3
    foundry install proteinmpnn --checkpoint-dir ./checkpoints
    """
    # Developer notes (kept out of the docstring, which is the CLI help text):
    # - `models` may contain registry names and the aliases 'all' and
    #   'base-models'.
    # - Checkpoints are written to the first entry of the resolved search
    #   path: `checkpoint_dir` when given, otherwise the first default.
    # - `force` is forwarded unchanged to `install_model`.

    # Determine checkpoint directory
    checkpoint_dirs = _resolve_checkpoint_dirs(checkpoint_dir)
    primary_checkpoint_dir = checkpoint_dirs[0]

    console.print(f"[bold]Install target:[/bold] {primary_checkpoint_dir}\n")

    base_models = ("rfd3", "proteinmpnn", "ligandmpnn", "rf3")
    if "all" in models:
        # 'all' covers every registered checkpoint, so other names are moot.
        models_to_install = list(REGISTERED_CHECKPOINTS.keys())
    else:
        # Expand the 'base-models' alias in place so that names given
        # alongside it are still installed instead of being silently dropped.
        expanded: list[str] = []
        for requested in models:
            if requested == "base-models":
                expanded.extend(base_models)
            else:
                expanded.append(requested)
        # Keep the first occurrence of each name so nothing is installed twice.
        models_to_install = list(dict.fromkeys(expanded))

    # Install each model
    for model_name in models_to_install:
        install_model(model_name, primary_checkpoint_dir, force)
        console.print()

    console.print("[bold green]Installation complete![/bold green]")
|
|
178
186
|
|
|
179
187
|
|
|
180
|
-
@app.command(name="list")
|
|
181
|
-
def
|
|
188
|
+
@app.command(name="list-available")
def list_available():
    """List available model checkpoints."""
    # Prints one line per entry of REGISTERED_CHECKPOINTS: the registry key,
    # padded to at least 8 characters, followed by the entry's description.
    # Nothing is read from disk here; see `list-installed` for local files.
    console.print("[bold]Available models:[/bold]\n")
    for name, info in REGISTERED_CHECKPOINTS.items():
        console.print(f" [cyan]{name:8}[/cyan] - {info.description}")
|
|
186
194
|
|
|
187
195
|
|
|
188
|
-
@app.command()
|
|
189
|
-
def
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
"--checkpoint-dir",
|
|
193
|
-
"-d",
|
|
194
|
-
help="Checkpoint directory to show",
|
|
195
|
-
),
|
|
196
|
-
):
|
|
197
|
-
"""Show installed checkpoints."""
|
|
198
|
-
if checkpoint_dir is None:
|
|
199
|
-
checkpoint_dir = get_default_checkpoint_dir()
|
|
196
|
+
@app.command(name="list-installed")
def list_installed():
    """List installed checkpoints and their sizes."""
    # Collect (path, size in GB) for every *.ckpt and *.pt file found directly
    # inside each existing directory of the resolved search path.
    bytes_per_gb = 1024**3
    found: list[tuple[Path, float]] = [
        (ckpt_path, ckpt_path.stat().st_size / bytes_per_gb)
        for directory in _resolve_checkpoint_dirs(None)
        if directory.exists()
        for pattern in ("*.ckpt", "*.pt")
        for ckpt_path in directory.glob(pattern)
    ]

    if not found:
        console.print(
            "[yellow]No checkpoint files found in any checkpoint directory[/yellow]"
        )
        raise typer.Exit(0)

    console.print("[bold]Installed checkpoints:[/bold]\n")
    # Report files ordered by their full path string, accumulating the total.
    found.sort(key=lambda entry: str(entry[0]))
    total_size = 0
    for ckpt_path, size_gb in found:
        total_size += size_gb
        console.print(f" {ckpt_path} {size_gb:8.2f} GB")

    console.print(f"\n[bold]Total:[/bold] {total_size:.2f} GB")
|
|
220
223
|
|
|
221
224
|
|
|
222
|
-
@app.command()
|
|
225
|
+
@app.command(name="clean")
|
|
223
226
|
def clean(
|
|
224
|
-
checkpoint_dir: Optional[Path] = typer.Option(
|
|
225
|
-
None,
|
|
226
|
-
"--checkpoint-dir",
|
|
227
|
-
"-d",
|
|
228
|
-
help="Checkpoint directory to clean",
|
|
229
|
-
),
|
|
230
227
|
confirm: bool = typer.Option(
|
|
231
228
|
True, "--confirm/--no-confirm", help="Ask for confirmation before deleting"
|
|
232
229
|
),
|
|
233
230
|
):
|
|
234
231
|
"""Remove all downloaded checkpoints."""
|
|
235
|
-
|
|
236
|
-
checkpoint_dir = get_default_checkpoint_dir()
|
|
237
|
-
|
|
238
|
-
if not checkpoint_dir.exists():
|
|
239
|
-
console.print(f"[yellow]No checkpoints found at {checkpoint_dir}[/yellow]")
|
|
240
|
-
raise typer.Exit(0)
|
|
232
|
+
checkpoint_dirs = _resolve_checkpoint_dirs(None)
|
|
241
233
|
|
|
242
234
|
# List files to delete
|
|
243
|
-
checkpoint_files =
|
|
235
|
+
checkpoint_files: list[Path] = []
|
|
236
|
+
for checkpoint_dir in checkpoint_dirs:
|
|
237
|
+
if not checkpoint_dir.exists():
|
|
238
|
+
continue
|
|
239
|
+
checkpoint_files.extend(checkpoint_dir.glob("*.ckpt"))
|
|
240
|
+
checkpoint_files.extend(checkpoint_dir.glob("*.pt"))
|
|
241
|
+
|
|
244
242
|
if not checkpoint_files:
|
|
245
|
-
console.print(
|
|
243
|
+
console.print(
|
|
244
|
+
"[yellow]No checkpoint files found in any checkpoint directory[/yellow]"
|
|
245
|
+
)
|
|
246
246
|
raise typer.Exit(0)
|
|
247
247
|
|
|
248
248
|
console.print("[bold]Files to delete:[/bold]")
|
|
249
249
|
total_size = 0
|
|
250
|
-
for ckpt in checkpoint_files:
|
|
250
|
+
for ckpt in sorted(checkpoint_files, key=str):
|
|
251
251
|
size = ckpt.stat().st_size / (1024**3) # GB
|
|
252
252
|
total_size += size
|
|
253
|
-
console.print(f" {ckpt
|
|
253
|
+
console.print(f" {ckpt} ({size:.2f} GB)")
|
|
254
254
|
|
|
255
255
|
console.print(f"\n[bold]Total:[/bold] {total_size:.2f} GB")
|
|
256
256
|
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: rc-foundry
|
|
3
|
-
Version: 0.1.
|
|
3
|
+
Version: 0.1.7
|
|
4
4
|
Summary: Shared utilities and training infrastructure for biomolecular structure prediction models.
|
|
5
5
|
Author-email: Institute for Protein Design <contact@ipd.uw.edu>
|
|
6
6
|
License: BSD 3-Clause License
|
|
@@ -104,33 +104,34 @@ All models within Foundry rely on [AtomWorks](https://github.com/RosettaCommons/
|
|
|
104
104
|
pip install rc-foundry[all]
|
|
105
105
|
```
|
|
106
106
|
|
|
107
|
-
**Downloading weights**
|
|
108
|
-
|
|
107
|
+
**Downloading weights** Models can be downloaded to a target folder with:
|
|
108
|
+
```
|
|
109
|
+
foundry install base-models --checkpoint-dir <path/to/ckpt/dir>
|
|
110
|
+
```
|
|
111
|
+
where `checkpoint-dir` will be `~/.foundry/checkpoints` by default. Foundry always searches `~/.foundry/checkpoints` plus any colon-separated entries in `$FOUNDRY_CHECKPOINT_DIRS` during inference or subsequent commands to find checkpoints. `base-models` installs the latest RFD3, RF3 and MPNN variants - you can also download all of the models supported (including multiple checkpoints of RF3) with `all`, or by listing the models sequentially (e.g. `foundry install rfd3 rf3 ...`).
|
|
112
|
+
To list the registry of available checkpoints:
|
|
109
113
|
```
|
|
110
|
-
foundry
|
|
114
|
+
foundry list-available
|
|
111
115
|
```
|
|
112
|
-
|
|
116
|
+
To check what you already have downloaded (searches `~/.foundry/checkpoints` plus `$FOUNDRY_CHECKPOINT_DIRS` if set):
|
|
113
117
|
```
|
|
114
|
-
foundry
|
|
118
|
+
foundry list-installed
|
|
115
119
|
```
|
|
116
120
|
|
|
117
|
-
>*See `examples/all.ipynb` for how to run each model in a notebook.*
|
|
121
|
+
>*See `examples/all.ipynb` for how to run each model and design proteins end-to-end in a notebook.*
|
|
122
|
+
|
|
123
|
+
### Google Colab
|
|
124
|
+
For an interactive Google Colab notebook walking through a basic design pipeline with RFD3, MPNN, and RF3, please see the [IPD Design Pipeline Tutorial](https://colab.research.google.com/drive/1ZwIMV3n9h0ZOnIXX0GyKUuoiahgifBxh?usp=sharing).
|
|
118
125
|
|
|
119
126
|
### RFdiffusion3 (RFD3)
|
|
120
127
|
|
|
121
128
|
[RFdiffusion3](https://www.biorxiv.org/content/10.1101/2025.09.18.676967v2) is an all-atom generative model capable of designing protein structures under complex constraints.
|
|
122
129
|
|
|
123
|
-
> *See [models/rfd3/README.md](models/rfd3/README.md) for complete documentation.*
|
|
124
|
-
|
|
125
130
|
<div align="center">
|
|
126
|
-
<img src="
|
|
131
|
+
<img src="docs/_static/cover.png" alt="RFdiffusion3 generation trajectory." width="700">
|
|
127
132
|
</div>
|
|
128
133
|
|
|
129
|
-
|
|
130
|
-
[ProteinMPNN](https://www.science.org/doi/10.1126/science.add2187) and [LigandMPNN](https://www.nature.com/articles/s41592-025-02626-1) are lightweight inverse-folding models which can be use to design diverse sequences for backbones under constrained conditions.
|
|
131
|
-
|
|
132
|
-
> *See [models/mpnn/README.md](models/mpnn/README.md) for complete documentation.*
|
|
133
|
-
|
|
134
|
+
> *See [models/rfd3/README.md](models/rfd3/README.md) for complete documentation.*
|
|
134
135
|
|
|
135
136
|
### RosettaFold3 (RF3)
|
|
136
137
|
|
|
@@ -142,6 +143,11 @@ foundry install rfd3 ligandmpnn rf3 --checkpoint_dir <path/to/ckpt/dir>
|
|
|
142
143
|
|
|
143
144
|
> *See [models/rf3/README.md](models/rf3/README.md) for complete documentation.*
|
|
144
145
|
|
|
146
|
+
### ProteinMPNN
|
|
147
|
+
[ProteinMPNN](https://www.science.org/doi/10.1126/science.add2187) and [LigandMPNN](https://www.nature.com/articles/s41592-025-02626-1) are lightweight inverse-folding models which can be used to design diverse sequences for backbones under constrained conditions.
|
|
148
|
+
|
|
149
|
+
> *See [models/mpnn/README.md](models/mpnn/README.md) for complete documentation.*
|
|
150
|
+
|
|
145
151
|
---
|
|
146
152
|
|
|
147
153
|
## Development
|
|
@@ -159,11 +165,7 @@ foundry install rfd3 ligandmpnn rf3 --checkpoint_dir <path/to/ckpt/dir>
|
|
|
159
165
|
Install both `foundry` and models in editable mode for development:
|
|
160
166
|
|
|
161
167
|
```bash
|
|
162
|
-
|
|
163
|
-
uv pip install -e . -e ./models/rf3 -e ./models/rfd3 -e ./models/mpnn
|
|
164
|
-
|
|
165
|
-
# Or install only foundry (no models)
|
|
166
|
-
uv pip install -e .
|
|
168
|
+
uv pip install -e '.[all,dev]'
|
|
167
169
|
```
|
|
168
170
|
|
|
169
171
|
This approach allows you to:
|
|
@@ -171,6 +173,9 @@ This approach allows you to:
|
|
|
171
173
|
- Work on specific models without installing all models
|
|
172
174
|
- Add new models as independent packages in `models/`
|
|
173
175
|
|
|
176
|
+
> [!NOTE]
|
|
177
|
+
> Running tests is not currently supported, test files may be missing.
|
|
178
|
+
|
|
174
179
|
### Adding New Models
|
|
175
180
|
|
|
176
181
|
To add a new model:
|
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
foundry/__init__.py,sha256=H8S1nl5v6YeW8ggn1jKy4GdtH7c-FGS-j7CqUCAEnAU,1926
|
|
2
2
|
foundry/common.py,sha256=Aur8mH-CNmcUqSsw7VgaCQSW5sH1Bqf8Da91jzxPV1Y,3035
|
|
3
3
|
foundry/constants.py,sha256=0n1wBKCvNuw3QaQehSbmsHYkIdaGn3tLeRFItBrdeHY,913
|
|
4
|
-
foundry/version.py,sha256=
|
|
4
|
+
foundry/version.py,sha256=szvPIs2C82UunpzuvVg3MbF4QhzbBYTsVJ8DmPfq6_E,704
|
|
5
5
|
foundry/callbacks/__init__.py,sha256=VsRT1e4sqlJHPcTCsfupMEx82Iz-LoOAGPpwvf_OJeE,126
|
|
6
6
|
foundry/callbacks/callback.py,sha256=xZBo_suP4bLrP6gl5uJPbaXm00DXigePa6dMeDxucgg,3890
|
|
7
7
|
foundry/callbacks/health_logging.py,sha256=tEtkByOlaAA7nnelxb7PbM9_dcIgOsdbxCdQY3K5pMc,16664
|
|
@@ -10,7 +10,7 @@ foundry/callbacks/timing_logging.py,sha256=u-r0hKp7fWOY3mLk7CcuIwHgZbhte13m5M09x
|
|
|
10
10
|
foundry/callbacks/train_logging.py,sha256=Xs3tmZA88qLxmdSOwt-x8YKN4NKb1kVm59uptNXl4Qo,10399
|
|
11
11
|
foundry/hydra/resolvers.py,sha256=xyJzo6OeWAc_LOu8RiHhX7_CRNoLZ22626AvYHXYl4U,2186
|
|
12
12
|
foundry/inference_engines/base.py,sha256=ZHdlmGUqH4-p3v4RdrLH-Ps8_zalr7j5mQ4x-S53N4M,8375
|
|
13
|
-
foundry/inference_engines/checkpoint_registry.py,sha256=
|
|
13
|
+
foundry/inference_engines/checkpoint_registry.py,sha256=c_me8Uz2NWXAaELhQ4bT1HMPfY8XrH67kvCKdDPrD8g,4149
|
|
14
14
|
foundry/metrics/__init__.py,sha256=qL4wwaiQ7EtR30pmZ9MCknqx909BJcNvHVmNJUaz_WM,236
|
|
15
15
|
foundry/metrics/losses.py,sha256=2CLUmf7oCdFUCvgJukdNkff0FVG3BlATI-NI60TtpVY,903
|
|
16
16
|
foundry/metrics/metric.py,sha256=23pKh_Ra0EcHGo5cSzYQQrUGr5zWRxeufKSJ58tfXXo,12687
|
|
@@ -22,7 +22,7 @@ foundry/trainers/fabric.py,sha256=cjaTHbGuJEQwaGBvIAXD_il4bHtY-crsTY14Xn77uXA,40
|
|
|
22
22
|
foundry/training/EMA.py,sha256=3OWA9Pz7XuDr-SRxbz24tZf55DmhSa2fKy9r5v2IXqA,2651
|
|
23
23
|
foundry/training/checkpoint.py,sha256=mUiObg-qcF3tvMfVu77sD9m3yVRp71czv07ccliU7qQ,1791
|
|
24
24
|
foundry/training/schedulers.py,sha256=StmXegPfIdLAv31FreCTrDh9dsOvNUfzG4YGa61Y4oE,3647
|
|
25
|
-
foundry/utils/alignment.py,sha256=
|
|
25
|
+
foundry/utils/alignment.py,sha256=2anqy0mn9zeFEiVWS_EG7zHiyPk1C_gbUu-SRvQ5mAM,2502
|
|
26
26
|
foundry/utils/components.py,sha256=Piw2TfQF26uuxC3hXG3iv_4rgud1lVO-cv6N-p05EDY,15200
|
|
27
27
|
foundry/utils/datasets.py,sha256=pLBxVezm-TSrYuC5gFnJZdGnNWV7aPH2QiWIVE2hkdQ,16629
|
|
28
28
|
foundry/utils/ddp.py,sha256=ydHrO6peGbRnWAwgH5rmpHuQd55g2gFzzoZJYypn7GU,3970
|
|
@@ -34,7 +34,7 @@ foundry/utils/squashfs.py,sha256=QlcwuJyVe-QVfIOS7o1QfLhaCQPNzzox7ln4n8dcYEg,523
|
|
|
34
34
|
foundry/utils/torch.py,sha256=OLsqoxw4CTXbGzWUHernLUT7uQjLu0tVPtD8h8747DI,11211
|
|
35
35
|
foundry/utils/weights.py,sha256=btz4S02xff2vgiq4xMfiXuhK1ERafqQPtmimo1DmoWY,10381
|
|
36
36
|
foundry_cli/__init__.py,sha256=0BxY2RUKJLaMXUGgypPCwlTskTEFdVnkhTR4C4ft2Kw,52
|
|
37
|
-
foundry_cli/download_checkpoints.py,sha256=
|
|
37
|
+
foundry_cli/download_checkpoints.py,sha256=CxU9dKBa1vAkVd450tfH5aZAlQIUTrHsDGTbmxzd_JQ,8922
|
|
38
38
|
mpnn/__init__.py,sha256=hgQcXFaCbAxFrhydVAy0xj8yC7UJF-GCCFhqD0sZ7I4,57
|
|
39
39
|
mpnn/inference.py,sha256=wPtGR325eVRVeesXoWtBK6b_-VcU8BZae5IfQN3-mvA,1669
|
|
40
40
|
mpnn/train.py,sha256=9eQGBd3rdNF5Zr2w8oUgETbqxBavNBajtA6Vbc5zESE,10239
|
|
@@ -119,18 +119,18 @@ rfd3/__init__.py,sha256=2Wto2IsUIj2lGag9m_gqgdCwBNl5p21-Xnr7W_RpU3c,348
|
|
|
119
119
|
rfd3/callbacks.py,sha256=Zjt8RiaYWquoKOwRmC_wCUbRbov-V4zd2_73zjhgDHE,2783
|
|
120
120
|
rfd3/cli.py,sha256=ka3K5H117fzDYIDXFpOpJV21w_XBrHYJZdFE0thsGBI,1644
|
|
121
121
|
rfd3/constants.py,sha256=wLvDzrThpOrK8T3wGFNQeGrhAXOJQze8l3v_7pjIdMM,13141
|
|
122
|
-
rfd3/engine.py,sha256=
|
|
123
|
-
rfd3/run_inference.py,sha256=
|
|
122
|
+
rfd3/engine.py,sha256=NwATrhYFyqT7C9Bie8mWtUiqqzXgs9x6nOCkmZYPiT4,21224
|
|
123
|
+
rfd3/run_inference.py,sha256=HfRMQ30_SAHfc-VFzBV52F-aLaNdG6PW8VkdMyB__wE,1264
|
|
124
124
|
rfd3/train.py,sha256=rHswffIUhOae3_iYyvAiQ3jALoFuzrcRUgMlbJLinlI,7947
|
|
125
125
|
rfd3/inference/datasets.py,sha256=u-2U7deHXu-iOs7doiKKynewP-NEyJfdORSTDzUSaQI,6538
|
|
126
|
-
rfd3/inference/input_parsing.py,sha256=
|
|
127
|
-
rfd3/inference/legacy_input_parsing.py,sha256=
|
|
128
|
-
rfd3/inference/parsing.py,sha256=
|
|
129
|
-
rfd3/inference/symmetry/atom_array.py,sha256=
|
|
130
|
-
rfd3/inference/symmetry/checks.py,sha256=
|
|
126
|
+
rfd3/inference/input_parsing.py,sha256=TyEzCzeCaNhuNi0RjMcq9fF2j3Sp36KbuZ1FUjlBTZ8,45442
|
|
127
|
+
rfd3/inference/legacy_input_parsing.py,sha256=G2XxkrjdIpL6i1YY7xEmkFitVv__Pc45ow6IKKPHw64,28855
|
|
128
|
+
rfd3/inference/parsing.py,sha256=ktAMUuZE3Pe4bKAjjV3zjqcEDmGlMZ-cotIUhJsEQQA,5402
|
|
129
|
+
rfd3/inference/symmetry/atom_array.py,sha256=HfFagFUB5yB-Y4IfUM5nuVGWHC5AEkyHqt0JcIqTQ_E,10922
|
|
130
|
+
rfd3/inference/symmetry/checks.py,sha256=y-Kq0l5OhEmmxsPBBsMMB0qaAt18FeEicD3-jSMQFa0,9900
|
|
131
131
|
rfd3/inference/symmetry/contigs.py,sha256=6OvbZ2dJg-a0mvvKAC0VkzUH5HpUDxOJvkByIst_roU,2127
|
|
132
|
-
rfd3/inference/symmetry/frames.py,sha256=
|
|
133
|
-
rfd3/inference/symmetry/symmetry_utils.py,sha256=
|
|
132
|
+
rfd3/inference/symmetry/frames.py,sha256=aEwkmlUsYexERX9hu09JMhisC8QTpHPVhfITbL80-EE,10819
|
|
133
|
+
rfd3/inference/symmetry/symmetry_utils.py,sha256=p_PkxU3sw6gYGO2EmZTrbNQdLjz1mdTWEIl5MjQdIuY,14664
|
|
134
134
|
rfd3/metrics/design_metrics.py,sha256=O1RqZdjQPNlAWYRg6UJTERYg_gUI1_hVleKsm9xbWBY,16836
|
|
135
135
|
rfd3/metrics/hbonds_hbplus_metrics.py,sha256=Sewy9KzmrA1OnfkasN-fmWrQ9IRx9G7Yyhe2ua0mk28,11518
|
|
136
136
|
rfd3/metrics/hbonds_metrics.py,sha256=SIR4BnDhYdpVSqwXXRYpQ_tB-M0_fVyugGl08WivCmE,15257
|
|
@@ -170,7 +170,7 @@ rfd3/transforms/symmetry.py,sha256=GSnMF7oAnUxPozfafsRuHEv0yKXW0BpLTI6wsKGZrbc,2
|
|
|
170
170
|
rfd3/transforms/training_conditions.py,sha256=UXiUPjDwrNKM95tRe0eXrMeRN8XlTPc_MXUvo6UpePo,19510
|
|
171
171
|
rfd3/transforms/util_transforms.py,sha256=2AcLkzx-73ZFgcWD1cIHv7NyniRPI4_zThHK8azyQaY,18119
|
|
172
172
|
rfd3/transforms/virtual_atoms.py,sha256=UpmxzPPd5FaJigcRoxgLSHHrLLOqsCvZ5PPZfQSGqII,12547
|
|
173
|
-
rfd3/utils/inference.py,sha256
|
|
173
|
+
rfd3/utils/inference.py,sha256=-8IKzkB9ulhLEJgapvnZSdIaIPQDPMpyPpHTQlFS7r0,27317
|
|
174
174
|
rfd3/utils/io.py,sha256=wbdjUTQkDc3RCSM7gdogA-XOKR68HeQ-cfvyN4pP90w,9849
|
|
175
175
|
rfd3/utils/vizualize.py,sha256=HPlczrA3zkOuxV5X05eOvy_Oga9e3cPnFUXOEP4RR_g,11046
|
|
176
176
|
rf3/configs/inference.yaml,sha256=JmEZdkAnbnOrX79lGS5xrYYho9aBFfVxfUp-8KjJV5I,309
|
|
@@ -248,7 +248,7 @@ rfd3/configs/datasets/conditions/sequence_design.yaml,sha256=D1K6WOysmSAQ4LogltU
|
|
|
248
248
|
rfd3/configs/datasets/conditions/tipatom.yaml,sha256=0010o7UUL-l75qI8HCjC_tdBXFWysm2dgVXzE7bQyZ0,650
|
|
249
249
|
rfd3/configs/datasets/conditions/unconditional.yaml,sha256=z1eVHylswLyludXWFs1AMt3mTMu3EbAUHrP8J3XBsRU,446
|
|
250
250
|
rfd3/configs/datasets/train/rfd3_monomer_distillation.yaml,sha256=1f61uFeRB8OD6sifFuIKFov8D7PcHpqRT4Z-M5EzO4w,1207
|
|
251
|
-
rfd3/configs/datasets/train/pdb/af3_train_interface.yaml,sha256=
|
|
251
|
+
rfd3/configs/datasets/train/pdb/af3_train_interface.yaml,sha256=DSIpXW2SQ3drDp12490y0tFbjbugecyA7TI_x3WrKng,1546
|
|
252
252
|
rfd3/configs/datasets/train/pdb/af3_train_pn_unit.yaml,sha256=DPoEhLlyBu0RdBkkJeWB8pkOV4z0DBc6XmclLgww9II,1324
|
|
253
253
|
rfd3/configs/datasets/train/pdb/base.yaml,sha256=2VUEAKADyvjJmWP4FeOJwRat9r6F3_GXuyGYjvMvArw,291
|
|
254
254
|
rfd3/configs/datasets/train/pdb/base_no_weights.yaml,sha256=8HchN7DqYESBK520vShdg7xidWBSogGRAxfaxa5pKdE,554
|
|
@@ -285,7 +285,7 @@ rfd3/configs/hydra/default.yaml,sha256=SYDTSU8bAw20QssrtTi7lptiBD5H3XNyzApsyy0br
|
|
|
285
285
|
rfd3/configs/hydra/no_logging.yaml,sha256=MUXDFcw-QwaRPz9HcE-1tdZwbNha1mexTe31G-Zt9_w,120
|
|
286
286
|
rfd3/configs/inference_engine/base.yaml,sha256=ekP5U7bAALpeJGpwyj1v0N5LiEtptl5loRCtM8FRzRM,246
|
|
287
287
|
rfd3/configs/inference_engine/dev.yaml,sha256=-2snClOTwj5TQt7jnwSrI4pzAiI4nFulXKJflmgIyUw,304
|
|
288
|
-
rfd3/configs/inference_engine/rfdiffusion3.yaml,sha256=
|
|
288
|
+
rfd3/configs/inference_engine/rfdiffusion3.yaml,sha256=3bHIAhzFhFDIag0xQWYxHBUMSc71fjClHXKbZ-tpHzA,2112
|
|
289
289
|
rfd3/configs/logger/csv.yaml,sha256=DtcywAIS4OxLXP2QxSEvqdrjhMpT6xHiGspoYw5qkus,245
|
|
290
290
|
rfd3/configs/logger/default.yaml,sha256=pSyHyxT-J_T-g4_6TtD2yzN3rzxgY6rOG_Vh4RjZeFY,17
|
|
291
291
|
rfd3/configs/logger/wandb.yaml,sha256=RhCnFtO0hNc3R75ts417l5ICZeGm74lOj9Bfe7ZvRNA,652
|
|
@@ -294,7 +294,7 @@ rfd3/configs/model/components/ema.yaml,sha256=AIzf4RZLKP8AcfaxdvZBS1rFw3AlSo431r
|
|
|
294
294
|
rfd3/configs/model/components/rfd3_net.yaml,sha256=95FF4U7aWmLCoHvyxsRoE74n-bxTPD6KlAhPKNemVH4,3275
|
|
295
295
|
rfd3/configs/model/optimizers/adam.yaml,sha256=cTRNo4_4lNgLv0b329v-KiC_MCQtTVVTxeer5Au_FIM,145
|
|
296
296
|
rfd3/configs/model/samplers/edm.yaml,sha256=QycHAIrfhRgx0mJygTOs56FT93tGCWTGxrQSKBOA7Mc,483
|
|
297
|
-
rfd3/configs/model/samplers/symmetry.yaml,sha256=
|
|
297
|
+
rfd3/configs/model/samplers/symmetry.yaml,sha256=pI0Ens6jmbpAIl8E4eYsJR1SqIppe5OsWh91KfpjNjs,214
|
|
298
298
|
rfd3/configs/model/schedulers/af3.yaml,sha256=xEtRb--KPjg_5pW_IJvN9AHWVqCtOM4QOnXlMH2KrEg,149
|
|
299
299
|
rfd3/configs/paths/default.yaml,sha256=bjB04SNu_5E6W_v4mRBjwce0xmdKwO5wsVf4gfaRl0Y,1045
|
|
300
300
|
rfd3/configs/paths/data/default.yaml,sha256=jfs1dbbcOqHja4_6lXheyRg4t0YExqVn2w0rZEWL6XE,788
|
|
@@ -304,8 +304,8 @@ rfd3/configs/trainer/rfd3_base.yaml,sha256=R3lZxdyjUirjlLU31qWlnZgHaz4GcWTGGIz4f
|
|
|
304
304
|
rfd3/configs/trainer/loss/losses/diffusion_loss.yaml,sha256=FE4FCEfurE0ekwZ4YfS6wCvPSNqxClwg_kc73cPql5Y,323
|
|
305
305
|
rfd3/configs/trainer/loss/losses/sequence_loss.yaml,sha256=kezbQcqwAZ0VKQPUBr2MsNr9DcDL3ENIP1i-j7h-6Co,64
|
|
306
306
|
rfd3/configs/trainer/metrics/design_metrics.yaml,sha256=xVDpClhHqSHvsf-8StL26z51Vn-iuWMDG9KMB-kqOI0,719
|
|
307
|
-
rc_foundry-0.1.
|
|
308
|
-
rc_foundry-0.1.
|
|
309
|
-
rc_foundry-0.1.
|
|
310
|
-
rc_foundry-0.1.
|
|
311
|
-
rc_foundry-0.1.
|
|
307
|
+
rc_foundry-0.1.7.dist-info/METADATA,sha256=zlvCxfZ5-Ow7WuGKskfW6P1DGhZB9OfLIIBUBGncFeQ,11309
|
|
308
|
+
rc_foundry-0.1.7.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
|
|
309
|
+
rc_foundry-0.1.7.dist-info/entry_points.txt,sha256=BmiWCbWGtrd_lSOFMuCLBXyo84B7Nco-alj7hB0Yw9A,130
|
|
310
|
+
rc_foundry-0.1.7.dist-info/licenses/LICENSE.md,sha256=NKtPCJ7QMysFmzeDg56ZfUStvgzbq5sOvRQv7_ddZOs,1533
|
|
311
|
+
rc_foundry-0.1.7.dist-info/RECORD,,
|
|
@@ -7,7 +7,7 @@ dataset:
|
|
|
7
7
|
base_dir: ${paths.data.pdb_data_dir}
|
|
8
8
|
dataset:
|
|
9
9
|
name: interface
|
|
10
|
-
data: ${paths.data.pdb_parquet_dir}/
|
|
10
|
+
data: ${paths.data.pdb_parquet_dir}/interfaces_df.parquet
|
|
11
11
|
filters:
|
|
12
12
|
# filters common across all PDB datasets
|
|
13
13
|
- "deposition_date < '2021-09-30'"
|
|
@@ -7,7 +7,7 @@ _target_: rfd3.engine.RFD3InferenceEngine
|
|
|
7
7
|
|
|
8
8
|
out_dir: ???
|
|
9
9
|
inputs: ??? # null, json, pdb or
|
|
10
|
-
ckpt_path:
|
|
10
|
+
ckpt_path: rfd3
|
|
11
11
|
json_keys_subset: null
|
|
12
12
|
skip_existing: True
|
|
13
13
|
|
|
@@ -61,5 +61,5 @@ global_prefix: null
|
|
|
61
61
|
dump_prediction_metadata_json: True
|
|
62
62
|
dump_trajectories: False
|
|
63
63
|
align_trajectory_structures: False
|
|
64
|
-
prevalidate_inputs:
|
|
64
|
+
prevalidate_inputs: False
|
|
65
65
|
low_memory_mode: False # False for standard mode, True for memory efficient tokenization mode
|
rfd3/engine.py
CHANGED
|
@@ -23,7 +23,10 @@ from rfd3.inference.datasets import (
|
|
|
23
23
|
)
|
|
24
24
|
from rfd3.inference.input_parsing import DesignInputSpecification
|
|
25
25
|
from rfd3.model.inference_sampler import SampleDiffusionConfig
|
|
26
|
-
from rfd3.utils.inference import
|
|
26
|
+
from rfd3.utils.inference import (
|
|
27
|
+
ensure_inference_sampler_matches_design_spec,
|
|
28
|
+
ensure_input_is_abspath,
|
|
29
|
+
)
|
|
27
30
|
from rfd3.utils.io import (
|
|
28
31
|
CIF_LIKE_EXTENSIONS,
|
|
29
32
|
build_stack_from_atom_array_and_batched_coords,
|
|
@@ -171,6 +174,7 @@ class RFD3InferenceEngine(BaseInferenceEngine):
|
|
|
171
174
|
)
|
|
172
175
|
# save
|
|
173
176
|
self.specification_overrides = dict(specification or {})
|
|
177
|
+
self.inference_sampler_overrides = dict(inference_sampler or {})
|
|
174
178
|
|
|
175
179
|
# Setup output directories and args
|
|
176
180
|
self.global_prefix = global_prefix
|
|
@@ -210,6 +214,9 @@ class RFD3InferenceEngine(BaseInferenceEngine):
|
|
|
210
214
|
inputs=inputs,
|
|
211
215
|
n_batches=n_batches,
|
|
212
216
|
)
|
|
217
|
+
ensure_inference_sampler_matches_design_spec(
|
|
218
|
+
design_specifications, self.inference_sampler_overrides
|
|
219
|
+
)
|
|
213
220
|
# init before
|
|
214
221
|
self.initialize()
|
|
215
222
|
outputs = self._run_multi(design_specifications)
|
|
@@ -383,6 +390,9 @@ class RFD3InferenceEngine(BaseInferenceEngine):
|
|
|
383
390
|
# Based on inputs, construct the specifications to loop through
|
|
384
391
|
design_specifications = {}
|
|
385
392
|
for prefix, example_spec in inputs.items():
|
|
393
|
+
# Record task name in the specification
|
|
394
|
+
example_spec["extra"]["task_name"] = prefix
|
|
395
|
+
|
|
386
396
|
# ... Create n_batches for example
|
|
387
397
|
for batch_id in range((n_batches) if exists(n_batches) else 1):
|
|
388
398
|
# ... Example ID
|
|
@@ -524,21 +534,19 @@ def process_input(
|
|
|
524
534
|
|
|
525
535
|
|
|
526
536
|
def _reshape_trajectory(traj, align_structures: bool):
|
|
527
|
-
traj = [traj[i] for i in range(len(traj))]
|
|
528
|
-
n_steps = len(traj)
|
|
537
|
+
traj = [traj[i] for i in range(len(traj))] # make list of arrays
|
|
529
538
|
max_frames = 100
|
|
530
|
-
|
|
539
|
+
if len(traj) > max_frames:
|
|
540
|
+
selected_indices = torch.linspace(0, len(traj) - 1, max_frames).long().tolist()
|
|
541
|
+
traj = [traj[i] for i in selected_indices]
|
|
531
542
|
if align_structures:
|
|
532
543
|
# ... align the trajectories on the last prediction
|
|
533
|
-
for step in range(
|
|
544
|
+
for step in range(len(traj) - 1):
|
|
534
545
|
traj[step] = weighted_rigid_align(
|
|
535
|
-
X_L=traj[-1],
|
|
536
|
-
X_gt_L=traj[step],
|
|
537
|
-
)
|
|
546
|
+
X_L=traj[-1][None],
|
|
547
|
+
X_gt_L=traj[step][None],
|
|
548
|
+
).squeeze(0)
|
|
538
549
|
traj = traj[::-1] # reverse to go from noised -> denoised
|
|
539
|
-
if n_steps > max_frames:
|
|
540
|
-
selected_indices = torch.linspace(0, n_steps - 1, max_frames).long().tolist()
|
|
541
|
-
traj = [traj[i] for i in selected_indices]
|
|
542
550
|
|
|
543
551
|
traj = torch.stack(traj).cpu().numpy()
|
|
544
552
|
return traj
|
rfd3/inference/input_parsing.py
CHANGED
|
@@ -696,7 +696,7 @@ class DesignInputSpecification(BaseModel):
|
|
|
696
696
|
# Partial diffusion: use COM, keep all coordinates
|
|
697
697
|
if exists(self.symmetry) and self.symmetry.id:
|
|
698
698
|
# For symmetric structures, avoid COM centering that would collapse chains
|
|
699
|
-
|
|
699
|
+
logger.info(
|
|
700
700
|
"Partial diffusion with symmetry: skipping COM centering to preserve chain spacing"
|
|
701
701
|
)
|
|
702
702
|
else:
|
|
@@ -139,13 +139,18 @@ def fetch_motif_residue_(
|
|
|
139
139
|
subarray, motif=True, unindexed=False, dtype=int
|
|
140
140
|
) # all values init to True (fix all)
|
|
141
141
|
|
|
142
|
+
to_unindex = f"{src_chain}{src_resid}" in unindexed_components
|
|
143
|
+
to_index = f"{src_chain}{src_resid}" in components
|
|
144
|
+
|
|
142
145
|
# Assign is motif atom and sequence
|
|
143
146
|
if exists(atoms := fixed_atoms.get(f"{src_chain}{src_resid}")):
|
|
147
|
+
# If specified, we set fixed atoms in the residue to be motif atoms
|
|
144
148
|
atom_mask = get_name_mask(subarray.atom_name, atoms, res_name)
|
|
145
149
|
subarray.set_annotation("is_motif_atom", atom_mask)
|
|
146
150
|
# subarray.set_annotation("is_motif_atom_with_fixed_coord", atom_mask) # BUGFIX: uncomment
|
|
147
151
|
|
|
148
152
|
elif redesign_motif_sidechains and res_name in STANDARD_AA:
|
|
153
|
+
# If redesign_motif_sidechains is True, we only make the backbone atoms to be motif atoms
|
|
149
154
|
n_atoms = subarray.shape[0]
|
|
150
155
|
diffuse_oxygen = False
|
|
151
156
|
if n_atoms < 3:
|
|
@@ -178,6 +183,18 @@ def fetch_motif_residue_(
|
|
|
178
183
|
subarray.set_annotation(
|
|
179
184
|
"is_motif_atom_with_fixed_seq", np.zeros(subarray.shape[0], dtype=int)
|
|
180
185
|
)
|
|
186
|
+
elif to_index or to_unindex:
|
|
187
|
+
# If the residue is in the contig or unindexed components,
|
|
188
|
+
# we set all atoms in the residue to be motif atoms
|
|
189
|
+
subarray.set_annotation("is_motif_atom", np.ones(subarray.shape[0], dtype=int))
|
|
190
|
+
else:
|
|
191
|
+
if to_unindex and not (
|
|
192
|
+
unfix_all or f"{src_chain}{src_resid}" in unfix_residues
|
|
193
|
+
):
|
|
194
|
+
raise ValueError(
|
|
195
|
+
f"{src_chain}{src_resid} is not found in fixed_atoms, contig or unindex contig."
|
|
196
|
+
"Please check your input and contig specification."
|
|
197
|
+
)
|
|
181
198
|
if unfix_all or f"{src_chain}{src_resid}" in unfix_residues:
|
|
182
199
|
subarray.set_annotation(
|
|
183
200
|
"is_motif_atom_with_fixed_coord", np.zeros(subarray.shape[0], dtype=int)
|
|
@@ -197,7 +214,6 @@ def fetch_motif_residue_(
|
|
|
197
214
|
subarray.set_annotation(
|
|
198
215
|
"is_flexible_motif_atom", np.zeros(subarray.shape[0], dtype=bool)
|
|
199
216
|
)
|
|
200
|
-
to_unindex = f"{src_chain}{src_resid}" in unindexed_components
|
|
201
217
|
if to_unindex:
|
|
202
218
|
subarray.set_annotation(
|
|
203
219
|
"is_motif_atom_unindexed", subarray.is_motif_atom.copy()
|
rfd3/inference/parsing.py
CHANGED
|
@@ -117,6 +117,7 @@ def from_any_(v: Any, atom_array: AtomArray):
|
|
|
117
117
|
|
|
118
118
|
# Split to atom names
|
|
119
119
|
data_split[idx] = token.atom_name[comp_mask_subset].tolist()
|
|
120
|
+
# TODO: there is a bug where when you select specifc atoms within a ligand, output ligand is fragmented
|
|
120
121
|
|
|
121
122
|
# Update mask & token dictionary
|
|
122
123
|
mask[comp_mask] = comp_mask_subset
|
|
@@ -4,12 +4,8 @@ from rfd3.inference.symmetry.frames import (
|
|
|
4
4
|
get_symmetry_frames_from_symmetry_id,
|
|
5
5
|
)
|
|
6
6
|
|
|
7
|
-
from foundry.utils.ddp import RankedLogger
|
|
8
|
-
|
|
9
7
|
FIXED_TRANSFORM_ID = -1
|
|
10
8
|
FIXED_ENTITY_ID = -1
|
|
11
|
-
ranked_logger = RankedLogger(__name__, rank_zero_only=True)
|
|
12
|
-
|
|
13
9
|
|
|
14
10
|
########################################################
|
|
15
11
|
# Symmetry annotations
|
|
@@ -28,7 +24,7 @@ def add_sym_annotations(atom_array, sym_conf):
|
|
|
28
24
|
is_asu = np.full(n, True, dtype=np.bool_)
|
|
29
25
|
atom_array.set_annotation("is_sym_asu", is_asu)
|
|
30
26
|
# symmetry_id
|
|
31
|
-
symmetry_ids = np.full(n, sym_conf.
|
|
27
|
+
symmetry_ids = np.full(n, sym_conf.id, dtype="U6")
|
|
32
28
|
atom_array.set_annotation("symmetry_id", symmetry_ids)
|
|
33
29
|
return atom_array
|
|
34
30
|
|
|
@@ -1,10 +1,13 @@
|
|
|
1
1
|
import numpy as np
|
|
2
|
-
from rfd3.inference.symmetry.contigs import
|
|
2
|
+
from rfd3.inference.symmetry.contigs import (
|
|
3
|
+
expand_contig_unsym_motif,
|
|
4
|
+
get_unsym_motif_mask,
|
|
5
|
+
)
|
|
3
6
|
from rfd3.transforms.conditioning_base import get_motif_features
|
|
4
7
|
|
|
5
8
|
from foundry.utils.ddp import RankedLogger
|
|
6
9
|
|
|
7
|
-
MIN_ATOMS_ALIGN =
|
|
10
|
+
MIN_ATOMS_ALIGN = 30
|
|
8
11
|
MAX_TRANSFORMS = 10
|
|
9
12
|
RMSD_CUT = 1.0 # Angstroms
|
|
10
13
|
|
|
@@ -18,29 +21,33 @@ def check_symmetry_config(
|
|
|
18
21
|
Check if the symmetry configuration is valid. Add all basic checks here.
|
|
19
22
|
"""
|
|
20
23
|
|
|
21
|
-
assert sym_conf.
|
|
24
|
+
assert sym_conf.id, "symmetry_id is required. e.g. {'id': 'C2'}"
|
|
22
25
|
# if unsym motif is provided, check that each motif name is in the atom array
|
|
23
|
-
|
|
26
|
+
|
|
27
|
+
is_unsym_motif = np.zeros(atom_array.shape[0], dtype=bool)
|
|
28
|
+
if sym_conf.is_unsym_motif:
|
|
24
29
|
assert (
|
|
25
30
|
src_atom_array is not None
|
|
26
31
|
), "Source atom array must be provided for symmetric motifs"
|
|
27
|
-
unsym_motif_names = sym_conf
|
|
32
|
+
unsym_motif_names = sym_conf.is_unsym_motif.split(",")
|
|
28
33
|
unsym_motif_names = expand_contig_unsym_motif(unsym_motif_names)
|
|
34
|
+
is_unsym_motif = get_unsym_motif_mask(atom_array, unsym_motif_names)
|
|
29
35
|
for n in unsym_motif_names:
|
|
30
36
|
if (sm and n not in sm.split(",")) and (n not in atom_array.src_component):
|
|
31
37
|
raise ValueError(f"Unsym motif {n} not found in atom_array")
|
|
38
|
+
|
|
39
|
+
is_motif_token = get_motif_features(atom_array)["is_motif_token"]
|
|
32
40
|
if (
|
|
33
|
-
|
|
34
|
-
and not sym_conf.
|
|
41
|
+
is_motif_token[~is_unsym_motif].any()
|
|
42
|
+
and not sym_conf.is_symmetric_motif
|
|
35
43
|
and not has_dist_cond
|
|
36
44
|
):
|
|
37
45
|
raise ValueError(
|
|
38
|
-
"Asymmetric motif inputs should be distance constrained.
|
|
46
|
+
"Asymmetric motif inputs should be distance constrained."
|
|
39
47
|
"Use atomwise_fixed_dist to constrain the distance between the motif atoms."
|
|
40
48
|
)
|
|
41
|
-
# else: if unconditional symmetry, no need to have symmetric input motif
|
|
42
49
|
|
|
43
|
-
if partial and not sym_conf.
|
|
50
|
+
if partial and not sym_conf.is_symmetric_motif:
|
|
44
51
|
raise ValueError(
|
|
45
52
|
"Partial diffusion with symmetry is only supported for symmetric inputs."
|
|
46
53
|
)
|
|
@@ -54,9 +61,6 @@ def check_atom_array_is_symmetric(atom_array):
|
|
|
54
61
|
Returns:
|
|
55
62
|
bool: True if the atom array is symmetric, False otherwise
|
|
56
63
|
"""
|
|
57
|
-
# TODO: Implement something like this https://github.com/baker-laboratory/ipd/blob/main/ipd/sym/sym_detect.py#L303
|
|
58
|
-
# and maybe this https://github.com/baker-laboratory/ipd/blob/main/ipd/sym/sym_detect.py#L231
|
|
59
|
-
|
|
60
64
|
import biotite.structure as struc
|
|
61
65
|
from rfd3.inference.symmetry.atom_array import (
|
|
62
66
|
apply_symmetry_to_atomarray_coord,
|
|
@@ -68,8 +72,10 @@ def check_atom_array_is_symmetric(atom_array):
|
|
|
68
72
|
# remove hetero atoms
|
|
69
73
|
atom_array = atom_array[~atom_array.hetero]
|
|
70
74
|
if len(atom_array) == 0:
|
|
71
|
-
ranked_logger.
|
|
72
|
-
|
|
75
|
+
ranked_logger.warning(
|
|
76
|
+
"Atom array has no protein chains. Please check your input."
|
|
77
|
+
)
|
|
78
|
+
return True
|
|
73
79
|
|
|
74
80
|
chains = np.unique(atom_array.chain_id)
|
|
75
81
|
asu_mask = atom_array.chain_id == chains[0]
|
|
@@ -162,16 +168,22 @@ def find_optimal_rotation(coords1, coords2, max_points=1000):
|
|
|
162
168
|
return None
|
|
163
169
|
|
|
164
170
|
|
|
165
|
-
def check_input_frames_match_symmetry_frames(
|
|
171
|
+
def check_input_frames_match_symmetry_frames(
|
|
172
|
+
computed_frames, original_frames, nids_by_entity
|
|
173
|
+
) -> None:
|
|
166
174
|
"""
|
|
167
175
|
Check if the atom array matches the symmetry_id.
|
|
168
176
|
Arguments:
|
|
169
177
|
computed_frames: list of computed frames
|
|
170
178
|
original_frames: list of original frames
|
|
171
179
|
"""
|
|
172
|
-
assert len(computed_frames) == len(
|
|
173
|
-
|
|
174
|
-
|
|
180
|
+
assert len(computed_frames) == len(original_frames), (
|
|
181
|
+
"Number of computed frames does not match number of original frames.\n"
|
|
182
|
+
f"Computed Frames: {len(computed_frames)}. Original Frames: {len(original_frames)}.\n"
|
|
183
|
+
"If the computed frames are not as expected, please check if you have one-to-one mapping "
|
|
184
|
+
"(size, sequence, folding) of an entity across all chains.\n"
|
|
185
|
+
f"Computed Entity Mapping: {nids_by_entity}."
|
|
186
|
+
)
|
|
175
187
|
|
|
176
188
|
|
|
177
189
|
def check_valid_multiplicity(nids_by_entity) -> None:
|
|
@@ -184,25 +196,35 @@ def check_valid_multiplicity(nids_by_entity) -> None:
|
|
|
184
196
|
multiplicity = min([len(i) for i in nids_by_entity.values()])
|
|
185
197
|
if multiplicity == 1: # no possible symmetry
|
|
186
198
|
raise ValueError(
|
|
187
|
-
"Input has no possible symmetry. If asymmetric motif, please use 2D conditioning inference instead
|
|
199
|
+
"Input has no possible symmetry. If asymmetric motif, please use 2D conditioning inference instead.\n"
|
|
200
|
+
"Multiplicity: 1"
|
|
188
201
|
)
|
|
189
202
|
|
|
190
203
|
# Check that the input is not asymmetric
|
|
191
204
|
multiplicity_good = [len(i) % multiplicity == 0 for i in nids_by_entity.values()]
|
|
192
205
|
if not all(multiplicity_good):
|
|
193
|
-
raise ValueError(
|
|
206
|
+
raise ValueError(
|
|
207
|
+
"Expected multiplicity does not match for some entities.\n"
|
|
208
|
+
"Please modify your input to have one-to-one mapping (size, sequence, folding) of an entity across all chains.\n"
|
|
209
|
+
f"Expected Multiplicity: {multiplicity}.\n"
|
|
210
|
+
f"Computed Entity Mapping: {nids_by_entity}."
|
|
211
|
+
)
|
|
194
212
|
|
|
195
213
|
|
|
196
214
|
def check_valid_subunit_size(nids_by_entity, pn_unit_id) -> None:
|
|
197
215
|
"""
|
|
198
216
|
Check that the subunits in the input are of the same size.
|
|
199
217
|
Arguments:
|
|
200
|
-
nids_by_entity: dict mapping entity to ids
|
|
218
|
+
nids_by_entity: dict mapping entity to ids. e.g. {0: (['A_1', 'B_1', 'C_1']), 1: (['A_2', 'B_2', 'C_2'])}
|
|
219
|
+
pn_unit_id: array of ids. e.g. ['A_1', 'B_1', 'C_1', 'A_2', 'B_2', 'C_2']
|
|
201
220
|
"""
|
|
202
|
-
for
|
|
203
|
-
for
|
|
204
|
-
if (pn_unit_id == js[0]).sum() != (pn_unit_id ==
|
|
205
|
-
raise ValueError(
|
|
221
|
+
for js in nids_by_entity.values():
|
|
222
|
+
for js_i in js[1:]:
|
|
223
|
+
if (pn_unit_id == js[0]).sum() != (pn_unit_id == js_i).sum():
|
|
224
|
+
raise ValueError(
|
|
225
|
+
f"Size mismatch between chain {js[0]} ({(pn_unit_id == js[0]).sum()} atoms) "
|
|
226
|
+
f"and chain {js_i} ({(pn_unit_id == js_i).sum()} atoms). Please check your input file."
|
|
227
|
+
)
|
|
206
228
|
|
|
207
229
|
|
|
208
230
|
def check_min_atoms_to_align(natm_per_unique, reference_entity) -> None:
|
|
@@ -212,7 +234,10 @@ def check_min_atoms_to_align(natm_per_unique, reference_entity) -> None:
|
|
|
212
234
|
nids_by_entity: dict mapping entity to ids
|
|
213
235
|
"""
|
|
214
236
|
if natm_per_unique[reference_entity] < MIN_ATOMS_ALIGN:
|
|
215
|
-
raise ValueError(
|
|
237
|
+
raise ValueError(
|
|
238
|
+
f"Not enough atoms to align < {MIN_ATOMS_ALIGN} atoms."
|
|
239
|
+
f"Please provide a input with at least {MIN_ATOMS_ALIGN} atoms."
|
|
240
|
+
)
|
|
216
241
|
|
|
217
242
|
|
|
218
243
|
def check_max_transforms(chains_to_consider) -> None:
|
|
@@ -224,7 +249,7 @@ def check_max_transforms(chains_to_consider) -> None:
|
|
|
224
249
|
"""
|
|
225
250
|
if len(chains_to_consider) > MAX_TRANSFORMS:
|
|
226
251
|
raise ValueError(
|
|
227
|
-
"Number of transforms exceeds the max number of transforms (
|
|
252
|
+
f"Number of transforms exceeds the max number of transforms ({MAX_TRANSFORMS})."
|
|
228
253
|
)
|
|
229
254
|
|
|
230
255
|
|
|
@@ -10,12 +10,13 @@ def get_symmetry_frames_from_symmetry_id(symmetry_id):
|
|
|
10
10
|
Returns:
|
|
11
11
|
frames: list of rotation matrices
|
|
12
12
|
"""
|
|
13
|
+
from rfd3.inference.symmetry.symmetry_utils import SymmetryConfig
|
|
13
14
|
|
|
14
15
|
# Get frames from symmetry id
|
|
15
16
|
sym_conf = {}
|
|
16
|
-
if isinstance(symmetry_id,
|
|
17
|
+
if isinstance(symmetry_id, SymmetryConfig):
|
|
17
18
|
sym_conf = symmetry_id
|
|
18
|
-
symmetry_id = symmetry_id.
|
|
19
|
+
symmetry_id = symmetry_id.id
|
|
19
20
|
|
|
20
21
|
if symmetry_id.lower().startswith("c"):
|
|
21
22
|
order = int(symmetry_id[1:])
|
|
@@ -25,9 +26,9 @@ def get_symmetry_frames_from_symmetry_id(symmetry_id):
|
|
|
25
26
|
frames = get_dihedral_frames(order)
|
|
26
27
|
elif symmetry_id.lower() == "input_defined":
|
|
27
28
|
assert (
|
|
28
|
-
|
|
29
|
+
sym_conf.symmetry_file is not None
|
|
29
30
|
), "symmetry_file is required for input_defined symmetry"
|
|
30
|
-
frames = get_frames_from_file(sym_conf.
|
|
31
|
+
frames = get_frames_from_file(sym_conf.symmetry_file)
|
|
31
32
|
else:
|
|
32
33
|
raise ValueError(f"Symmetry id {symmetry_id} not supported")
|
|
33
34
|
|
|
@@ -120,7 +121,9 @@ def get_symmetry_frames_from_atom_array(src_atom_array, input_frames):
|
|
|
120
121
|
computed_frames = [(R, np.array([0, 0, 0])) for R in Rs]
|
|
121
122
|
|
|
122
123
|
# check that the computed frames match the input frames
|
|
123
|
-
check_input_frames_match_symmetry_frames(
|
|
124
|
+
check_input_frames_match_symmetry_frames(
|
|
125
|
+
computed_frames, input_frames, nids_by_entity
|
|
126
|
+
)
|
|
124
127
|
|
|
125
128
|
return computed_frames
|
|
126
129
|
|
|
@@ -39,18 +39,36 @@ ranked_logger = RankedLogger(__name__, rank_zero_only=True)
|
|
|
39
39
|
|
|
40
40
|
|
|
41
41
|
class SymmetryConfig(BaseModel):
|
|
42
|
-
# AM / HE TODO: feel free to flesh this out and add validation as needed
|
|
43
42
|
model_config = ConfigDict(
|
|
44
43
|
arbitrary_types_allowed=True,
|
|
45
44
|
extra="allow",
|
|
46
45
|
)
|
|
47
|
-
id: Optional[str] = Field(
|
|
48
|
-
|
|
49
|
-
|
|
46
|
+
id: Optional[str] = Field(
|
|
47
|
+
None,
|
|
48
|
+
description="Symmetry group ID. e.g. 'C3', 'D2'. Only C and D symmetry types are supported currently.",
|
|
49
|
+
)
|
|
50
|
+
is_unsym_motif: Optional[str] = Field(
|
|
51
|
+
None,
|
|
52
|
+
description="Comma separated list of contig/ligand names that should not be symmetrized such as DNA strands. \
|
|
53
|
+
e.g. 'HEM' or 'Y1-11,Z16-25'",
|
|
54
|
+
)
|
|
55
|
+
is_symmetric_motif: bool = Field(
|
|
56
|
+
True,
|
|
57
|
+
description="If True, the input motifs are expected to be already symmetric and won't be symmetrized. \
|
|
58
|
+
If False, the all input motifs are expected to be ASU and will be symmetrized.",
|
|
59
|
+
)
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
def convery_sym_conf_to_symmetry_config(sym_conf: dict):
|
|
63
|
+
return SymmetryConfig(**sym_conf)
|
|
50
64
|
|
|
51
65
|
|
|
52
66
|
def make_symmetric_atom_array(
|
|
53
|
-
asu_atom_array,
|
|
67
|
+
asu_atom_array,
|
|
68
|
+
sym_conf: SymmetryConfig | dict,
|
|
69
|
+
sm=None,
|
|
70
|
+
has_dist_cond=False,
|
|
71
|
+
src_atom_array=None,
|
|
54
72
|
):
|
|
55
73
|
"""
|
|
56
74
|
apply symmetry to an atom array.
|
|
@@ -58,39 +76,33 @@ def make_symmetric_atom_array(
|
|
|
58
76
|
asu_atom_array: atom array of the asymmetric unit
|
|
59
77
|
sym_conf: symmetry configuration (dict, "id" key is required)
|
|
60
78
|
sm: optional small molecule names (str, comma separated)
|
|
61
|
-
|
|
79
|
+
has_dist_cond: whether to add 2d entity annotations
|
|
62
80
|
Returns:
|
|
63
81
|
new_asu_atom_array: atom array with symmetry applied
|
|
64
82
|
"""
|
|
65
|
-
|
|
66
|
-
sym_conf
|
|
67
|
-
) # TODO: JB: remove this line to keep as symmetry config for cleaner syntax(?)
|
|
68
|
-
ranked_logger.info(f"Symmetry Configs: {sym_conf}")
|
|
83
|
+
if not isinstance(sym_conf, SymmetryConfig):
|
|
84
|
+
sym_conf = convery_sym_conf_to_symmetry_config(sym_conf)
|
|
69
85
|
|
|
70
|
-
# Making sure that the symmetry config is valid
|
|
71
86
|
check_symmetry_config(
|
|
72
|
-
asu_atom_array,
|
|
73
|
-
sym_conf,
|
|
74
|
-
sm,
|
|
75
|
-
has_dist_cond=has_2d,
|
|
76
|
-
src_atom_array=src_atom_array,
|
|
87
|
+
asu_atom_array, sym_conf, sm, has_dist_cond, src_atom_array=src_atom_array
|
|
77
88
|
)
|
|
78
89
|
# Adding utility annotations to the asu atom array
|
|
79
90
|
asu_atom_array = _add_util_annotations(asu_atom_array, sym_conf, sm)
|
|
80
91
|
|
|
81
|
-
if
|
|
92
|
+
if has_dist_cond: # NB: this will only work for asymmetric motifs at the moment - need to add functionality for symmetric motifs
|
|
82
93
|
asu_atom_array = add_2d_entity_annotations(asu_atom_array)
|
|
83
94
|
|
|
84
95
|
frames = get_symmetry_frames_from_symmetry_id(sym_conf)
|
|
85
96
|
|
|
86
97
|
# If the motif is symmetric, we get the frames instead from the source atom array.
|
|
87
|
-
if sym_conf.
|
|
98
|
+
if sym_conf.is_symmetric_motif:
|
|
88
99
|
assert (
|
|
89
100
|
src_atom_array is not None
|
|
90
101
|
), "Source atom array must be provided for symmetric motifs"
|
|
91
|
-
# if symmetric motif is provided, get the frames from the src atom array
|
|
102
|
+
# if symmetric motif is provided, get the frames from the src atom array.
|
|
92
103
|
frames = get_symmetry_frames_from_atom_array(src_atom_array, frames)
|
|
93
|
-
|
|
104
|
+
elif (asu_atom_array._is_motif[~asu_atom_array._is_unsym_motif]).any():
|
|
105
|
+
# if the motifs that's not unsym motifs are present.
|
|
94
106
|
raise NotImplementedError(
|
|
95
107
|
"Asymmetric motif inputs are not implemented yet. please symmetrize the motif."
|
|
96
108
|
)
|
|
@@ -101,7 +113,7 @@ def make_symmetric_atom_array(
|
|
|
101
113
|
# Extracting all things at this moment that we will not want to symmetrize.
|
|
102
114
|
# This includes: 1) unsym motifs, 2) ligands
|
|
103
115
|
unsym_atom_arrays = []
|
|
104
|
-
if sym_conf.
|
|
116
|
+
if sym_conf.is_unsym_motif:
|
|
105
117
|
# unsym_motif_atom_array = get_unsym_motif(asu_atom_array, asu_atom_array._is_unsym_motif)
|
|
106
118
|
# Now remove the unsym motifs from the asu atom array
|
|
107
119
|
unsym_atom_arrays.append(asu_atom_array[asu_atom_array._is_unsym_motif])
|
|
@@ -128,7 +140,7 @@ def make_symmetric_atom_array(
|
|
|
128
140
|
symmetrized_atom_array = struc.concatenate(symmetry_unit_list)
|
|
129
141
|
|
|
130
142
|
# add 2D conditioning annotations
|
|
131
|
-
if
|
|
143
|
+
if has_dist_cond:
|
|
132
144
|
symmetrized_atom_array = reannotate_2d_conditions(symmetrized_atom_array)
|
|
133
145
|
|
|
134
146
|
# set all motifs to not have any symmetrization applied to them
|
|
@@ -183,7 +195,7 @@ def make_symmetric_atom_array_for_partial_diffusion(atom_array, sym_conf):
|
|
|
183
195
|
frames = get_symmetry_frames_from_symmetry_id(sym_conf)
|
|
184
196
|
|
|
185
197
|
# Add symmetry ID
|
|
186
|
-
symmetry_ids = np.full(n, sym_conf.
|
|
198
|
+
symmetry_ids = np.full(n, sym_conf.id, dtype="U6")
|
|
187
199
|
atom_array.set_annotation("symmetry_id", symmetry_ids)
|
|
188
200
|
|
|
189
201
|
# Initialize transform annotations (use same format as original system)
|
|
@@ -244,7 +256,7 @@ def _add_util_annotations(asu_atom_array, sym_conf, sm):
|
|
|
244
256
|
"""
|
|
245
257
|
n = asu_atom_array.shape[0]
|
|
246
258
|
is_motif = get_motif_features(asu_atom_array)["is_motif_atom"].astype(np.bool_)
|
|
247
|
-
is_sm = np.zeros(
|
|
259
|
+
is_sm = np.zeros(n, dtype=bool)
|
|
248
260
|
is_asu = np.ones(n, dtype=bool)
|
|
249
261
|
is_unsym_motif = np.zeros(n, dtype=bool)
|
|
250
262
|
|
|
@@ -257,8 +269,8 @@ def _add_util_annotations(asu_atom_array, sym_conf, sm):
|
|
|
257
269
|
)
|
|
258
270
|
|
|
259
271
|
# assign unsym motifs
|
|
260
|
-
if sym_conf.
|
|
261
|
-
unsym_motif_names = sym_conf
|
|
272
|
+
if sym_conf.is_unsym_motif:
|
|
273
|
+
unsym_motif_names = sym_conf.is_unsym_motif.split(",")
|
|
262
274
|
unsym_motif_names = expand_contig_unsym_motif(unsym_motif_names)
|
|
263
275
|
is_unsym_motif = get_unsym_motif_mask(asu_atom_array, unsym_motif_names)
|
|
264
276
|
|
|
@@ -361,38 +373,4 @@ def apply_symmetry_to_xyz_atomwise(X_L, sym_feats, partial_diffusion=False):
|
|
|
361
373
|
"blc,cd->bld", asu_xyz, sym_transforms[target_id][0].to(asu_xyz.dtype)
|
|
362
374
|
) + sym_transforms[target_id][1].to(asu_xyz.dtype)
|
|
363
375
|
|
|
364
|
-
# Log inter-chain distances for debugging - use actual chain annotations
|
|
365
|
-
if sym_X_L.shape[1] > 100: # Only for large structures
|
|
366
|
-
# Use symmetry entity annotations to find different chains
|
|
367
|
-
sym_entity_id = sym_feats["sym_entity_id"]
|
|
368
|
-
unique_entities = torch.unique(sym_entity_id)
|
|
369
|
-
|
|
370
|
-
if len(unique_entities) >= 2:
|
|
371
|
-
# Get atoms from first two different entities
|
|
372
|
-
entity_0_mask = sym_entity_id == unique_entities[0]
|
|
373
|
-
entity_1_mask = sym_entity_id == unique_entities[1]
|
|
374
|
-
|
|
375
|
-
if entity_0_mask.sum() > 0 and entity_1_mask.sum() > 0:
|
|
376
|
-
entity_0_atoms = sym_X_L[0, entity_0_mask, :]
|
|
377
|
-
entity_1_atoms = sym_X_L[0, entity_1_mask, :]
|
|
378
|
-
|
|
379
|
-
# Sample subset to avoid memory issues
|
|
380
|
-
entity_0_sample = entity_0_atoms[: min(50, entity_0_atoms.shape[0]), :]
|
|
381
|
-
entity_1_sample = entity_1_atoms[: min(50, entity_1_atoms.shape[0]), :]
|
|
382
|
-
|
|
383
|
-
min_distance = (
|
|
384
|
-
torch.cdist(entity_0_sample, entity_1_sample).min().item()
|
|
385
|
-
)
|
|
386
|
-
ranked_logger.info(
|
|
387
|
-
f"Min inter-chain distance after symmetry: {min_distance:.2f} Å"
|
|
388
|
-
)
|
|
389
|
-
|
|
390
|
-
# Also log the centers of each entity
|
|
391
|
-
entity_0_center = entity_0_atoms.mean(dim=0)
|
|
392
|
-
entity_1_center = entity_1_atoms.mean(dim=0)
|
|
393
|
-
center_distance = torch.norm(entity_0_center - entity_1_center).item()
|
|
394
|
-
ranked_logger.info(
|
|
395
|
-
f"Distance between chain centers: {center_distance:.2f} Å"
|
|
396
|
-
)
|
|
397
|
-
|
|
398
376
|
return sym_X_L
|
rfd3/run_inference.py
CHANGED
|
@@ -12,7 +12,9 @@ load_dotenv(override=True)
|
|
|
12
12
|
|
|
13
13
|
# For pip-installed package, configs should be relative to this file
|
|
14
14
|
# Adjust this path based on where configs are bundled in the package
|
|
15
|
-
_config_path = os.path.join(
|
|
15
|
+
_config_path = os.path.join(
|
|
16
|
+
os.path.dirname(os.path.dirname(os.path.dirname(__file__))), "configs"
|
|
17
|
+
)
|
|
16
18
|
|
|
17
19
|
|
|
18
20
|
@hydra.main(
|
rfd3/utils/inference.py
CHANGED
|
@@ -391,6 +391,29 @@ def ensure_input_is_abspath(args: dict, path: PathLike | None):
|
|
|
391
391
|
return args
|
|
392
392
|
|
|
393
393
|
|
|
394
|
+
def ensure_inference_sampler_matches_design_spec(
|
|
395
|
+
design_spec: dict, inference_sampler: dict | None = None
|
|
396
|
+
):
|
|
397
|
+
"""
|
|
398
|
+
Ensure the inference sampler is set to the correct sampler for the design specification.
|
|
399
|
+
Args:
|
|
400
|
+
design_spec: Design specification dictionary
|
|
401
|
+
inference_sampler: Inference sampler dictionary
|
|
402
|
+
"""
|
|
403
|
+
has_symmetry_specification = [
|
|
404
|
+
True if "symmetry" in item.keys() else False for item in design_spec.values()
|
|
405
|
+
]
|
|
406
|
+
if any(has_symmetry_specification):
|
|
407
|
+
if (
|
|
408
|
+
inference_sampler is None
|
|
409
|
+
or inference_sampler.get("kind", "default") != "symmetry"
|
|
410
|
+
):
|
|
411
|
+
raise ValueError(
|
|
412
|
+
"You requested for symmetric designs, but inference sampler is not set to symmetry. "
|
|
413
|
+
"Please add inference_sampler.kind='symmetry' to your command."
|
|
414
|
+
)
|
|
415
|
+
|
|
416
|
+
|
|
394
417
|
#################################################################################
|
|
395
418
|
# Custom infer_ori functions
|
|
396
419
|
#################################################################################
|
|
File without changes
|
|
File without changes
|
|
File without changes
|