rc-foundry 0.1.6__py3-none-any.whl → 0.1.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -3,20 +3,62 @@
3
3
  import os
4
4
  from dataclasses import dataclass
5
5
  from pathlib import Path
6
+ from typing import Iterable, List
6
7
 
8
+ import dotenv
7
9
 
8
- def get_default_checkpoint_dir() -> Path:
9
- """Get the default checkpoint directory.
10
+ DEFAULT_CHECKPOINT_DIR = Path.home() / ".foundry" / "checkpoints"
11
+
12
+
13
+ def _normalize_paths(paths: Iterable[Path]) -> list[Path]:
14
+ """Return absolute, deduplicated paths in order."""
15
+ seen = set()
16
+ normalized: List[Path] = []
17
+ for path in paths:
18
+ resolved = path.expanduser().absolute()
19
+ if resolved not in seen:
20
+ normalized.append(resolved)
21
+ seen.add(resolved)
22
+ return normalized
10
23
 
11
- Priority:
12
- 1. FOUNDRY_CHECKPOINTS_DIR environment variable
13
- 2. ~/.foundry/checkpoints
24
+
25
+ def get_default_checkpoint_dirs() -> list[Path]:
26
+ """Return checkpoint search paths.
27
+
28
+ Directories listed in the colon-separated FOUNDRY_CHECKPOINT_DIRS environment
29
+ variable come first (deduplicated, in order), followed by the default
30
+ ~/.foundry/checkpoints directory.
14
31
  """
15
- if "FOUNDRY_CHECKPOINTS_DIR" in os.environ and os.environ.get(
16
- "FOUNDRY_CHECKPOINTS_DIR"
17
- ):
18
- return Path(os.environ["FOUNDRY_CHECKPOINTS_DIR"]).absolute()
19
- return Path.home() / ".foundry" / "checkpoints"
32
+ env_dirs = os.environ.get("FOUNDRY_CHECKPOINT_DIRS", "")
33
+
34
+ # For backward compatibility, also check FOUNDRY_CHECKPOINTS_DIR
35
+ if not env_dirs:
36
+ env_dirs = os.environ.get("FOUNDRY_CHECKPOINTS_DIR", "")
37
+
38
+ extra_dirs: list[Path] = []
39
+ if env_dirs:
40
+ extra_dirs = [Path(p.strip()) for p in env_dirs.split(":") if p.strip()]
41
+ return _normalize_paths([*extra_dirs, DEFAULT_CHECKPOINT_DIR])
42
+
43
+
44
+ def get_default_checkpoint_dir() -> Path:
45
+ """Backward-compatible helper returning the primary checkpoint directory."""
46
+ return get_default_checkpoint_dirs()[0]
47
+
48
+
49
+ def append_checkpoint_to_env(checkpoint_dirs: list[Path]) -> bool:
50
+ dotenv_path = dotenv.find_dotenv()
51
+ if dotenv_path:
52
+ checkpoint_dirs = _normalize_paths(checkpoint_dirs)
53
+ dotenv.set_key(
54
+ dotenv_path=dotenv_path,
55
+ key_to_set="FOUNDRY_CHECKPOINT_DIRS",
56
+ value_to_set=":".join(str(path) for path in checkpoint_dirs),
57
+ export=False,
58
+ )
59
+ return True
60
+ else:
61
+ return False
20
62
 
21
63
 
22
64
  @dataclass
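The hunk above replaces the single `FOUNDRY_CHECKPOINTS_DIR` lookup with a multi-directory search path. A minimal sketch of the resulting resolution order, assuming rc-foundry 0.1.7 is installed; the paths are illustrative:

```python
# Illustrative only: environment entries come first (deduplicated, order kept),
# with ~/.foundry/checkpoints always appended as the final fallback.
import os

os.environ["FOUNDRY_CHECKPOINT_DIRS"] = "/data/ckpts:/scratch/ckpts:/data/ckpts"

from foundry.inference_engines.checkpoint_registry import get_default_checkpoint_dirs

print(get_default_checkpoint_dirs())
# [PosixPath('/data/ckpts'), PosixPath('/scratch/ckpts'),
#  PosixPath('/home/<user>/.foundry/checkpoints')]
```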
@@ -27,7 +69,12 @@ class RegisteredCheckpoint:
27
69
  sha256: None = None # Optional: add checksum for verification
28
70
 
29
71
  def get_default_path(self):
30
- return get_default_checkpoint_dir() / self.filename
72
+ checkpoint_dirs = get_default_checkpoint_dirs()
73
+ for checkpoint_dir in checkpoint_dirs:
74
+ candidate = checkpoint_dir / self.filename
75
+ if candidate.exists():
76
+ return candidate
77
+ return checkpoint_dirs[0] / self.filename
31
78
 
32
79
 
33
80
  REGISTERED_CHECKPOINTS = {
@@ -18,14 +18,19 @@ def weighted_rigid_align(
18
18
  Returns:
19
19
  X_align_L: [B, L, 3]
20
20
  """
21
- assert X_L.shape == X_gt_L.shape
22
- assert X_L.shape[:-1] == w_L.shape
23
21
 
22
+ # Canonicalize dimensions
23
+ if X_L.ndim == 2:
24
+ X_L = X_L[None]
25
+ if X_gt_L.ndim == 2:
26
+ X_gt_L = X_gt_L[None]
24
27
  if X_exists_L is None:
25
28
  X_exists_L = torch.ones((X_L.shape[-2]), dtype=torch.bool)
26
29
  if w_L is None:
27
30
  w_L = torch.ones_like(X_L[..., 0])
28
31
  else:
32
+ if w_L.ndim == 1:
33
+ w_L = w_L[None]
29
34
  w_L = w_L.to(torch.float32)
30
35
 
31
36
  # Assert `X_exists_L` is a boolean mask
@@ -33,6 +38,9 @@ def weighted_rigid_align(
33
38
  X_exists_L.dtype == torch.bool
34
39
  ), "X_exists_L should be a boolean mask! Otherwise, the alignment will be incorrect (silent failure)!"
35
40
 
41
+ assert X_L.shape == X_gt_L.shape
42
+ assert X_L.shape[:-1] == w_L.shape
43
+
36
44
  X_resolved = X_L[:, X_exists_L]
37
45
  X_gt_resolved = X_gt_L[:, X_exists_L]
38
46
  w_resolved = w_L[:, X_exists_L]
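For callers, the practical effect of this hunk is that unbatched `[L, 3]` coordinates (and `[L]` weights) are promoted to a batch of one before the shape checks run. A standalone sketch mirroring the new canonicalization (illustrative, not the library function itself):

```python
import torch

def _canonicalize(X_L, X_gt_L, w_L=None):
    """Mirror of the dimension handling added above (illustrative)."""
    if X_L.ndim == 2:
        X_L = X_L[None]          # [L, 3] -> [1, L, 3]
    if X_gt_L.ndim == 2:
        X_gt_L = X_gt_L[None]
    if w_L is None:
        w_L = torch.ones_like(X_L[..., 0])
    elif w_L.ndim == 1:
        w_L = w_L[None]          # [L] -> [1, L]
    w_L = w_L.to(torch.float32)
    # The asserts now run after canonicalization, so 2D inputs no longer trip them.
    assert X_L.shape == X_gt_L.shape
    assert X_L.shape[:-1] == w_L.shape
    return X_L, X_gt_L, w_L

X, X_gt, w = _canonicalize(torch.randn(8, 3), torch.randn(8, 3), torch.ones(8))
print(X.shape, w.shape)  # torch.Size([1, 8, 3]) torch.Size([1, 8])
```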
foundry/version.py CHANGED
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
28
28
  commit_id: COMMIT_ID
29
29
  __commit_id__: COMMIT_ID
30
30
 
31
- __version__ = version = '0.1.6'
32
- __version_tuple__ = version_tuple = (0, 1, 6)
31
+ __version__ = version = '0.1.7'
32
+ __version_tuple__ = version_tuple = (0, 1, 7)
33
33
 
34
34
  __commit_id__ = commit_id = None
@@ -6,7 +6,7 @@ from typing import Optional
6
6
  from urllib.request import urlopen
7
7
 
8
8
  import typer
9
- from dotenv import find_dotenv, load_dotenv, set_key
9
+ from dotenv import load_dotenv
10
10
  from rich.console import Console
11
11
  from rich.progress import (
12
12
  BarColumn,
@@ -20,7 +20,8 @@ from rich.progress import (
20
20
 
21
21
  from foundry.inference_engines.checkpoint_registry import (
22
22
  REGISTERED_CHECKPOINTS,
23
- get_default_checkpoint_dir,
23
+ append_checkpoint_to_env,
24
+ get_default_checkpoint_dirs,
24
25
  )
25
26
 
26
27
  load_dotenv(override=True)
@@ -29,6 +30,27 @@ app = typer.Typer(help="Foundry model checkpoint installation utilities")
29
30
  console = Console()
30
31
 
31
32
 
33
+ def _resolve_checkpoint_dirs(checkpoint_dir: Optional[Path]) -> list[Path]:
34
+ """Return checkpoint search path with defaults first."""
35
+ checkpoint_dirs = get_default_checkpoint_dirs()
36
+ if checkpoint_dir is not None:
37
+ resolved = checkpoint_dir.expanduser().absolute()
38
+ if resolved not in checkpoint_dirs:
39
+ checkpoint_dirs.insert(0, resolved)
40
+ else:
41
+ # Move to front
42
+ checkpoint_dirs.remove(resolved)
43
+ checkpoint_dirs.insert(0, resolved)
44
+
45
+ # Try to persist checkpoint dir to .env (optional, may not exist in Colab etc.)
46
+ if append_checkpoint_to_env(checkpoint_dirs):
47
+ console.print(
48
+ f"Tracked checkpoint directories: {':'.join(str(path) for path in checkpoint_dirs)}"
49
+ )
50
+
51
+ return checkpoint_dirs
52
+
53
+
32
54
  def download_file(url: str, dest: Path, verify_hash: Optional[str] = None) -> None:
33
55
  """Download a file with progress bar and optional hash verification.
34
56
 
@@ -123,134 +145,112 @@ def install_model(model_name: str, checkpoint_dir: Path, force: bool = False) ->
123
145
  def install(
124
146
  models: list[str] = typer.Argument(
125
147
  ...,
126
- help="Models to install: 'all', 'rfd3', 'rf3', 'mpnn', or combination",
148
+ help="Models to install: 'all', 'rfd3', 'rf3', 'mpnn', or a combination thereof",
127
149
  ),
128
150
  checkpoint_dir: Optional[Path] = typer.Option(
129
151
  None,
130
152
  "--checkpoint-dir",
131
153
  "-d",
132
- help="Directory to save checkpoints (default: $FOUNDRY_CHECKPOINTS_DIR or ~/.foundry/checkpoints)",
154
+ help="Directory to save checkpoints (default search path: ~/.foundry/checkpoints plus any $FOUNDRY_CHECKPOINT_DIRS entries)",
133
155
  ),
134
156
  force: bool = typer.Option(
135
157
  False, "--force", "-f", help="Overwrite existing checkpoints"
136
158
  ),
137
159
  ):
138
160
  """Install model checkpoints for foundry.
139
-
140
161
  Examples:
141
-
142
162
  foundry install all
143
-
144
163
  foundry install rfd3 rf3
145
-
146
164
  foundry install proteinmpnn --checkpoint-dir ./checkpoints
147
165
  """
148
166
  # Determine checkpoint directory
149
- if checkpoint_dir is None:
150
- checkpoint_dir = get_default_checkpoint_dir()
167
+ checkpoint_dirs = _resolve_checkpoint_dirs(checkpoint_dir)
168
+ primary_checkpoint_dir = checkpoint_dirs[0]
151
169
 
152
- console.print(f"[bold]Checkpoint directory:[/bold] {checkpoint_dir}")
153
- console.print()
170
+ console.print(f"[bold]Install target:[/bold] {primary_checkpoint_dir}\n")
154
171
 
155
172
  # Expand 'all' to all available models
156
173
  if "all" in models:
174
+ models_to_install = list(REGISTERED_CHECKPOINTS.keys())
175
+ elif "base-models" in models:
157
176
  models_to_install = ["rfd3", "proteinmpnn", "ligandmpnn", "rf3"]
158
177
  else:
159
178
  models_to_install = models
160
179
 
161
180
  # Install each model
162
181
  for model_name in models_to_install:
163
- install_model(model_name, checkpoint_dir, force)
182
+ install_model(model_name, primary_checkpoint_dir, force)
164
183
  console.print()
165
184
 
166
- # Try to persist checkpoint dir to .env (optional, may not exist in Colab etc.)
167
- dotenv_path = find_dotenv()
168
- if dotenv_path:
169
- set_key(
170
- dotenv_path=dotenv_path,
171
- key_to_set="FOUNDRY_CHECKPOINTS_DIR",
172
- value_to_set=str(checkpoint_dir),
173
- export=False,
174
- )
175
- console.print(f"Saved FOUNDRY_CHECKPOINTS_DIR to {dotenv_path}")
176
-
177
185
  console.print("[bold green]Installation complete![/bold green]")
178
186
 
179
187
 
180
- @app.command(name="list")
181
- def list_models():
188
+ @app.command(name="list-available")
189
+ def list_available():
182
190
  """List available model checkpoints."""
183
191
  console.print("[bold]Available models:[/bold]\n")
184
192
  for name, info in REGISTERED_CHECKPOINTS.items():
185
193
  console.print(f" [cyan]{name:8}[/cyan] - {info.description}")
186
194
 
187
195
 
188
- @app.command()
189
- def show(
190
- checkpoint_dir: Optional[Path] = typer.Option(
191
- None,
192
- "--checkpoint-dir",
193
- "-d",
194
- help="Checkpoint directory to show",
195
- ),
196
- ):
197
- """Show installed checkpoints."""
198
- if checkpoint_dir is None:
199
- checkpoint_dir = get_default_checkpoint_dir()
196
+ @app.command(name="list-installed")
197
+ def list_installed():
198
+ """List installed checkpoints and their sizes."""
199
+ checkpoint_dirs = _resolve_checkpoint_dirs(None)
200
200
 
201
- if not checkpoint_dir.exists():
202
- console.print(
203
- f"[yellow]No checkpoints directory found at {checkpoint_dir}[/yellow]"
204
- )
205
- raise typer.Exit(0)
201
+ checkpoint_files: list[tuple[Path, float]] = []
202
+ for checkpoint_dir in checkpoint_dirs:
203
+ if not checkpoint_dir.exists():
204
+ continue
205
+ ckpts = list(checkpoint_dir.glob("*.ckpt")) + list(checkpoint_dir.glob("*.pt"))
206
+ for ckpt in ckpts:
207
+ size = ckpt.stat().st_size / (1024**3) # GB
208
+ checkpoint_files.append((ckpt, size))
206
209
 
207
- checkpoint_files = list(checkpoint_dir.glob("*.ckpt"))
208
210
  if not checkpoint_files:
209
- console.print(f"[yellow]No checkpoint files found in {checkpoint_dir}[/yellow]")
211
+ console.print(
212
+ "[yellow]No checkpoint files found in any checkpoint directory[/yellow]"
213
+ )
210
214
  raise typer.Exit(0)
211
215
 
212
- console.print(f"[bold]Installed checkpoints in {checkpoint_dir}:[/bold]\n")
216
+ console.print("[bold]Installed checkpoints:[/bold]\n")
213
217
  total_size = 0
214
- for ckpt in sorted(checkpoint_files):
215
- size = ckpt.stat().st_size / (1024**3) # GB
218
+ for ckpt, size in sorted(checkpoint_files, key=lambda item: str(item[0])):
216
219
  total_size += size
217
- console.print(f" {ckpt.name:30} {size:8.2f} GB")
220
+ console.print(f" {ckpt} {size:8.2f} GB")
218
221
 
219
222
  console.print(f"\n[bold]Total:[/bold] {total_size:.2f} GB")
220
223
 
221
224
 
222
- @app.command()
225
+ @app.command(name="clean")
223
226
  def clean(
224
- checkpoint_dir: Optional[Path] = typer.Option(
225
- None,
226
- "--checkpoint-dir",
227
- "-d",
228
- help="Checkpoint directory to clean",
229
- ),
230
227
  confirm: bool = typer.Option(
231
228
  True, "--confirm/--no-confirm", help="Ask for confirmation before deleting"
232
229
  ),
233
230
  ):
234
231
  """Remove all downloaded checkpoints."""
235
- if checkpoint_dir is None:
236
- checkpoint_dir = get_default_checkpoint_dir()
237
-
238
- if not checkpoint_dir.exists():
239
- console.print(f"[yellow]No checkpoints found at {checkpoint_dir}[/yellow]")
240
- raise typer.Exit(0)
232
+ checkpoint_dirs = _resolve_checkpoint_dirs(None)
241
233
 
242
234
  # List files to delete
243
- checkpoint_files = list(checkpoint_dir.glob("*.ckpt"))
235
+ checkpoint_files: list[Path] = []
236
+ for checkpoint_dir in checkpoint_dirs:
237
+ if not checkpoint_dir.exists():
238
+ continue
239
+ checkpoint_files.extend(checkpoint_dir.glob("*.ckpt"))
240
+ checkpoint_files.extend(checkpoint_dir.glob("*.pt"))
241
+
244
242
  if not checkpoint_files:
245
- console.print(f"[yellow]No checkpoint files found in {checkpoint_dir}[/yellow]")
243
+ console.print(
244
+ "[yellow]No checkpoint files found in any checkpoint directory[/yellow]"
245
+ )
246
246
  raise typer.Exit(0)
247
247
 
248
248
  console.print("[bold]Files to delete:[/bold]")
249
249
  total_size = 0
250
- for ckpt in checkpoint_files:
250
+ for ckpt in sorted(checkpoint_files, key=str):
251
251
  size = ckpt.stat().st_size / (1024**3) # GB
252
252
  total_size += size
253
- console.print(f" {ckpt.name} ({size:.2f} GB)")
253
+ console.print(f" {ckpt} ({size:.2f} GB)")
254
254
 
255
255
  console.print(f"\n[bold]Total:[/bold] {total_size:.2f} GB")
256
256
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: rc-foundry
3
- Version: 0.1.6
3
+ Version: 0.1.7
4
4
  Summary: Shared utilities and training infrastructure for biomolecular structure prediction models.
5
5
  Author-email: Institute for Protein Design <contact@ipd.uw.edu>
6
6
  License: BSD 3-Clause License
@@ -104,33 +104,34 @@ All models within Foundry rely on [AtomWorks](https://github.com/RosettaCommons/
104
104
  pip install rc-foundry[all]
105
105
  ```
106
106
 
107
- **Downloading weights** All models can be downloaded to a target folder with:
108
-
107
+ **Downloading weights** Models can be downloaded to a target folder with:
108
+ ```
109
+ foundry install base-models --checkpoint-dir <path/to/ckpt/dir>
110
+ ```
111
+ where `checkpoint-dir` defaults to `~/.foundry/checkpoints`. During inference and subsequent commands, Foundry always searches `~/.foundry/checkpoints` plus any colon-separated entries in `$FOUNDRY_CHECKPOINT_DIRS` to find checkpoints. `base-models` installs the latest RFD3, RF3 and MPNN variants - you can also download all supported models (including multiple RF3 checkpoints) with `all`, or by listing models individually (e.g. `foundry install rfd3 rf3 ...`).
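A hedged shell sketch of pre-seeding that search path via the environment (directory names are illustrative):

```bash
# Optional: make extra checkpoint locations visible to Foundry (colon-separated);
# ~/.foundry/checkpoints is always searched as well.
export FOUNDRY_CHECKPOINT_DIRS="/shared/foundry/ckpts:$HOME/scratch/ckpts"
```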
112
+ To list the registry of available checkpoints:
109
113
  ```
110
- foundry install all --checkpoint_dir <path/to/ckpt/dir>
114
+ foundry list-available
111
115
  ```
112
- This will download all the models supported (including multiple checkpoints of RF3) but as a beginner you can start with:
116
+ To check what you already have downloaded (searches `~/.foundry/checkpoints` plus `$FOUNDRY_CHECKPOINT_DIRS` if set):
113
117
  ```
114
- foundry install rfd3 ligandmpnn rf3 --checkpoint_dir <path/to/ckpt/dir>
118
+ foundry list-installed
115
119
  ```
116
120
 
117
- >*See `examples/all.ipynb` for how to run each model in a notebook.*
121
+ >*See `examples/all.ipynb` for how to run each model and design proteins end-to-end in a notebook.*
122
+
123
+ ### Google Colab
124
+ For an interactive Google Colab notebook walking through a basic design pipeline with RFD3, MPNN, and RF3, please see the [IPD Design Pipeline Tutorial](https://colab.research.google.com/drive/1ZwIMV3n9h0ZOnIXX0GyKUuoiahgifBxh?usp=sharing).
118
125
 
119
126
  ### RFdiffusion3 (RFD3)
120
127
 
121
128
  [RFdiffusion3](https://www.biorxiv.org/content/10.1101/2025.09.18.676967v2) is an all-atom generative model capable of designing protein structures under complex constraints.
122
129
 
123
- > *See [models/rfd3/README.md](models/rfd3/README.md) for complete documentation.*
124
-
125
130
  <div align="center">
126
- <img src="models/rfd3/docs/.assets/trajectory.png" alt="RFdiffusion3 generation trajectory." width="700">
131
+ <img src="docs/_static/cover.png" alt="RFdiffusion3 generation trajectory." width="700">
127
132
  </div>
128
133
 
129
- ### ProteinMPNN
130
- [ProteinMPNN](https://www.science.org/doi/10.1126/science.add2187) and [LigandMPNN](https://www.nature.com/articles/s41592-025-02626-1) are lightweight inverse-folding models which can be use to design diverse sequences for backbones under constrained conditions.
131
-
132
- > *See [models/mpnn/README.md](models/mpnn/README.md) for complete documentation.*
133
-
134
+ > *See [models/rfd3/README.md](models/rfd3/README.md) for complete documentation.*
134
135
 
135
136
  ### RosettaFold3 (RF3)
136
137
 
@@ -142,6 +143,11 @@ foundry install rfd3 ligandmpnn rf3 --checkpoint_dir <path/to/ckpt/dir>
142
143
 
143
144
  > *See [models/rf3/README.md](models/rf3/README.md) for complete documentation.*
144
145
 
146
+ ### ProteinMPNN
147
+ [ProteinMPNN](https://www.science.org/doi/10.1126/science.add2187) and [LigandMPNN](https://www.nature.com/articles/s41592-025-02626-1) are lightweight inverse-folding models which can be used to design diverse sequences for backbones under constrained conditions.
148
+
149
+ > *See [models/mpnn/README.md](models/mpnn/README.md) for complete documentation.*
150
+
145
151
  ---
146
152
 
147
153
  ## Development
@@ -159,11 +165,7 @@ foundry install rfd3 ligandmpnn rf3 --checkpoint_dir <path/to/ckpt/dir>
159
165
  Install both `foundry` and models in editable mode for development:
160
166
 
161
167
  ```bash
162
- # Install foundry and RF3 in editable mode
163
- uv pip install -e . -e ./models/rf3 -e ./models/rfd3 -e ./models/mpnn
164
-
165
- # Or install only foundry (no models)
166
- uv pip install -e .
168
+ uv pip install -e '.[all,dev]'
167
169
  ```
168
170
 
169
171
  This approach allows you to:
@@ -171,6 +173,9 @@ This approach allows you to:
171
173
  - Work on specific models without installing all models
172
174
  - Add new models as independent packages in `models/`
173
175
 
176
+ > [!NOTE]
177
+ > Running tests is not currently supported; test files may be missing.
178
+
174
179
  ### Adding New Models
175
180
 
176
181
  To add a new model:
@@ -1,7 +1,7 @@
1
1
  foundry/__init__.py,sha256=H8S1nl5v6YeW8ggn1jKy4GdtH7c-FGS-j7CqUCAEnAU,1926
2
2
  foundry/common.py,sha256=Aur8mH-CNmcUqSsw7VgaCQSW5sH1Bqf8Da91jzxPV1Y,3035
3
3
  foundry/constants.py,sha256=0n1wBKCvNuw3QaQehSbmsHYkIdaGn3tLeRFItBrdeHY,913
4
- foundry/version.py,sha256=riGXiVTWXmtdoju9hVCWvTxpszEMAAIK0sZZWoLKlnU,704
4
+ foundry/version.py,sha256=szvPIs2C82UunpzuvVg3MbF4QhzbBYTsVJ8DmPfq6_E,704
5
5
  foundry/callbacks/__init__.py,sha256=VsRT1e4sqlJHPcTCsfupMEx82Iz-LoOAGPpwvf_OJeE,126
6
6
  foundry/callbacks/callback.py,sha256=xZBo_suP4bLrP6gl5uJPbaXm00DXigePa6dMeDxucgg,3890
7
7
  foundry/callbacks/health_logging.py,sha256=tEtkByOlaAA7nnelxb7PbM9_dcIgOsdbxCdQY3K5pMc,16664
@@ -10,7 +10,7 @@ foundry/callbacks/timing_logging.py,sha256=u-r0hKp7fWOY3mLk7CcuIwHgZbhte13m5M09x
10
10
  foundry/callbacks/train_logging.py,sha256=Xs3tmZA88qLxmdSOwt-x8YKN4NKb1kVm59uptNXl4Qo,10399
11
11
  foundry/hydra/resolvers.py,sha256=xyJzo6OeWAc_LOu8RiHhX7_CRNoLZ22626AvYHXYl4U,2186
12
12
  foundry/inference_engines/base.py,sha256=ZHdlmGUqH4-p3v4RdrLH-Ps8_zalr7j5mQ4x-S53N4M,8375
13
- foundry/inference_engines/checkpoint_registry.py,sha256=kt2Z1JhrAjoOiEpkIIQ0sLttie1ceL8OgXUBmmyA6iw,2544
13
+ foundry/inference_engines/checkpoint_registry.py,sha256=c_me8Uz2NWXAaELhQ4bT1HMPfY8XrH67kvCKdDPrD8g,4149
14
14
  foundry/metrics/__init__.py,sha256=qL4wwaiQ7EtR30pmZ9MCknqx909BJcNvHVmNJUaz_WM,236
15
15
  foundry/metrics/losses.py,sha256=2CLUmf7oCdFUCvgJukdNkff0FVG3BlATI-NI60TtpVY,903
16
16
  foundry/metrics/metric.py,sha256=23pKh_Ra0EcHGo5cSzYQQrUGr5zWRxeufKSJ58tfXXo,12687
@@ -22,7 +22,7 @@ foundry/trainers/fabric.py,sha256=cjaTHbGuJEQwaGBvIAXD_il4bHtY-crsTY14Xn77uXA,40
22
22
  foundry/training/EMA.py,sha256=3OWA9Pz7XuDr-SRxbz24tZf55DmhSa2fKy9r5v2IXqA,2651
23
23
  foundry/training/checkpoint.py,sha256=mUiObg-qcF3tvMfVu77sD9m3yVRp71czv07ccliU7qQ,1791
24
24
  foundry/training/schedulers.py,sha256=StmXegPfIdLAv31FreCTrDh9dsOvNUfzG4YGa61Y4oE,3647
25
- foundry/utils/alignment.py,sha256=OAN7H2TqraGxP1uMXUpwLO7g0qS0cxUVjuV33pY16z0,2316
25
+ foundry/utils/alignment.py,sha256=2anqy0mn9zeFEiVWS_EG7zHiyPk1C_gbUu-SRvQ5mAM,2502
26
26
  foundry/utils/components.py,sha256=Piw2TfQF26uuxC3hXG3iv_4rgud1lVO-cv6N-p05EDY,15200
27
27
  foundry/utils/datasets.py,sha256=pLBxVezm-TSrYuC5gFnJZdGnNWV7aPH2QiWIVE2hkdQ,16629
28
28
  foundry/utils/ddp.py,sha256=ydHrO6peGbRnWAwgH5rmpHuQd55g2gFzzoZJYypn7GU,3970
@@ -34,7 +34,7 @@ foundry/utils/squashfs.py,sha256=QlcwuJyVe-QVfIOS7o1QfLhaCQPNzzox7ln4n8dcYEg,523
34
34
  foundry/utils/torch.py,sha256=OLsqoxw4CTXbGzWUHernLUT7uQjLu0tVPtD8h8747DI,11211
35
35
  foundry/utils/weights.py,sha256=btz4S02xff2vgiq4xMfiXuhK1ERafqQPtmimo1DmoWY,10381
36
36
  foundry_cli/__init__.py,sha256=0BxY2RUKJLaMXUGgypPCwlTskTEFdVnkhTR4C4ft2Kw,52
37
- foundry_cli/download_checkpoints.py,sha256=UCNdy4VZyJe1PH_lnVLqy-VSMuTu875mGGd99ma7fTQ,8426
37
+ foundry_cli/download_checkpoints.py,sha256=CxU9dKBa1vAkVd450tfH5aZAlQIUTrHsDGTbmxzd_JQ,8922
38
38
  mpnn/__init__.py,sha256=hgQcXFaCbAxFrhydVAy0xj8yC7UJF-GCCFhqD0sZ7I4,57
39
39
  mpnn/inference.py,sha256=wPtGR325eVRVeesXoWtBK6b_-VcU8BZae5IfQN3-mvA,1669
40
40
  mpnn/train.py,sha256=9eQGBd3rdNF5Zr2w8oUgETbqxBavNBajtA6Vbc5zESE,10239
@@ -119,18 +119,18 @@ rfd3/__init__.py,sha256=2Wto2IsUIj2lGag9m_gqgdCwBNl5p21-Xnr7W_RpU3c,348
119
119
  rfd3/callbacks.py,sha256=Zjt8RiaYWquoKOwRmC_wCUbRbov-V4zd2_73zjhgDHE,2783
120
120
  rfd3/cli.py,sha256=ka3K5H117fzDYIDXFpOpJV21w_XBrHYJZdFE0thsGBI,1644
121
121
  rfd3/constants.py,sha256=wLvDzrThpOrK8T3wGFNQeGrhAXOJQze8l3v_7pjIdMM,13141
122
- rfd3/engine.py,sha256=La_dB48Ewz0IdY1ocxvSWg-PXVAsySm0OGvwyz42lI8,20824
123
- rfd3/run_inference.py,sha256=ljzsCKEtrlfAvP0SDFPeQwTM3rV_X3ewHOhcRFVI37c,1258
122
+ rfd3/engine.py,sha256=NwATrhYFyqT7C9Bie8mWtUiqqzXgs9x6nOCkmZYPiT4,21224
123
+ rfd3/run_inference.py,sha256=HfRMQ30_SAHfc-VFzBV52F-aLaNdG6PW8VkdMyB__wE,1264
124
124
  rfd3/train.py,sha256=rHswffIUhOae3_iYyvAiQ3jALoFuzrcRUgMlbJLinlI,7947
125
125
  rfd3/inference/datasets.py,sha256=u-2U7deHXu-iOs7doiKKynewP-NEyJfdORSTDzUSaQI,6538
126
- rfd3/inference/input_parsing.py,sha256=mk3HBvo7MPTFEET7NagCo5TSjb47w-hxUDoeQxUW_h4,45449
127
- rfd3/inference/legacy_input_parsing.py,sha256=1wf_KF7qWnGLaVM8IXDl8fIsWCmxtOi2YlAiHEVELqw,28046
128
- rfd3/inference/parsing.py,sha256=Nq8CYmimnql4RM-5ZfPAvOFvCae4_CC2pYDzE6iCpWU,5290
129
- rfd3/inference/symmetry/atom_array.py,sha256=HH50Z07bTUnNUgCwAGslADbvMYHgsXn9s-fqwx6BvKw,11034
130
- rfd3/inference/symmetry/checks.py,sha256=wb7K327GnMwGG9bgOvvDAbaPsFj4nGZpEAolICUapNc,8908
126
+ rfd3/inference/input_parsing.py,sha256=TyEzCzeCaNhuNi0RjMcq9fF2j3Sp36KbuZ1FUjlBTZ8,45442
127
+ rfd3/inference/legacy_input_parsing.py,sha256=G2XxkrjdIpL6i1YY7xEmkFitVv__Pc45ow6IKKPHw64,28855
128
+ rfd3/inference/parsing.py,sha256=ktAMUuZE3Pe4bKAjjV3zjqcEDmGlMZ-cotIUhJsEQQA,5402
129
+ rfd3/inference/symmetry/atom_array.py,sha256=HfFagFUB5yB-Y4IfUM5nuVGWHC5AEkyHqt0JcIqTQ_E,10922
130
+ rfd3/inference/symmetry/checks.py,sha256=y-Kq0l5OhEmmxsPBBsMMB0qaAt18FeEicD3-jSMQFa0,9900
131
131
  rfd3/inference/symmetry/contigs.py,sha256=6OvbZ2dJg-a0mvvKAC0VkzUH5HpUDxOJvkByIst_roU,2127
132
- rfd3/inference/symmetry/frames.py,sha256=G55p-aOXqEYG4kCyKxrgWAsS-gW9-gOTlBME6nhbKyU,10716
133
- rfd3/inference/symmetry/symmetry_utils.py,sha256=KwgxrdfO766RCEwF3VElAE85oEKiopPGRQDhJbKZaUA,15810
132
+ rfd3/inference/symmetry/frames.py,sha256=aEwkmlUsYexERX9hu09JMhisC8QTpHPVhfITbL80-EE,10819
133
+ rfd3/inference/symmetry/symmetry_utils.py,sha256=p_PkxU3sw6gYGO2EmZTrbNQdLjz1mdTWEIl5MjQdIuY,14664
134
134
  rfd3/metrics/design_metrics.py,sha256=O1RqZdjQPNlAWYRg6UJTERYg_gUI1_hVleKsm9xbWBY,16836
135
135
  rfd3/metrics/hbonds_hbplus_metrics.py,sha256=Sewy9KzmrA1OnfkasN-fmWrQ9IRx9G7Yyhe2ua0mk28,11518
136
136
  rfd3/metrics/hbonds_metrics.py,sha256=SIR4BnDhYdpVSqwXXRYpQ_tB-M0_fVyugGl08WivCmE,15257
@@ -170,7 +170,7 @@ rfd3/transforms/symmetry.py,sha256=GSnMF7oAnUxPozfafsRuHEv0yKXW0BpLTI6wsKGZrbc,2
170
170
  rfd3/transforms/training_conditions.py,sha256=UXiUPjDwrNKM95tRe0eXrMeRN8XlTPc_MXUvo6UpePo,19510
171
171
  rfd3/transforms/util_transforms.py,sha256=2AcLkzx-73ZFgcWD1cIHv7NyniRPI4_zThHK8azyQaY,18119
172
172
  rfd3/transforms/virtual_atoms.py,sha256=UpmxzPPd5FaJigcRoxgLSHHrLLOqsCvZ5PPZfQSGqII,12547
173
- rfd3/utils/inference.py,sha256=RQp5CCy6Z6uHVZ2Mx0zmmGluYEOrASke4bABtfRjpy0,26448
173
+ rfd3/utils/inference.py,sha256=-8IKzkB9ulhLEJgapvnZSdIaIPQDPMpyPpHTQlFS7r0,27317
174
174
  rfd3/utils/io.py,sha256=wbdjUTQkDc3RCSM7gdogA-XOKR68HeQ-cfvyN4pP90w,9849
175
175
  rfd3/utils/vizualize.py,sha256=HPlczrA3zkOuxV5X05eOvy_Oga9e3cPnFUXOEP4RR_g,11046
176
176
  rf3/configs/inference.yaml,sha256=JmEZdkAnbnOrX79lGS5xrYYho9aBFfVxfUp-8KjJV5I,309
@@ -248,7 +248,7 @@ rfd3/configs/datasets/conditions/sequence_design.yaml,sha256=D1K6WOysmSAQ4LogltU
248
248
  rfd3/configs/datasets/conditions/tipatom.yaml,sha256=0010o7UUL-l75qI8HCjC_tdBXFWysm2dgVXzE7bQyZ0,650
249
249
  rfd3/configs/datasets/conditions/unconditional.yaml,sha256=z1eVHylswLyludXWFs1AMt3mTMu3EbAUHrP8J3XBsRU,446
250
250
  rfd3/configs/datasets/train/rfd3_monomer_distillation.yaml,sha256=1f61uFeRB8OD6sifFuIKFov8D7PcHpqRT4Z-M5EzO4w,1207
251
- rfd3/configs/datasets/train/pdb/af3_train_interface.yaml,sha256=mwbdGJQ9SXc8WvO3qqSWzS--K4rvbFsM0MR371FUrr0,1552
251
+ rfd3/configs/datasets/train/pdb/af3_train_interface.yaml,sha256=DSIpXW2SQ3drDp12490y0tFbjbugecyA7TI_x3WrKng,1546
252
252
  rfd3/configs/datasets/train/pdb/af3_train_pn_unit.yaml,sha256=DPoEhLlyBu0RdBkkJeWB8pkOV4z0DBc6XmclLgww9II,1324
253
253
  rfd3/configs/datasets/train/pdb/base.yaml,sha256=2VUEAKADyvjJmWP4FeOJwRat9r6F3_GXuyGYjvMvArw,291
254
254
  rfd3/configs/datasets/train/pdb/base_no_weights.yaml,sha256=8HchN7DqYESBK520vShdg7xidWBSogGRAxfaxa5pKdE,554
@@ -285,7 +285,7 @@ rfd3/configs/hydra/default.yaml,sha256=SYDTSU8bAw20QssrtTi7lptiBD5H3XNyzApsyy0br
285
285
  rfd3/configs/hydra/no_logging.yaml,sha256=MUXDFcw-QwaRPz9HcE-1tdZwbNha1mexTe31G-Zt9_w,120
286
286
  rfd3/configs/inference_engine/base.yaml,sha256=ekP5U7bAALpeJGpwyj1v0N5LiEtptl5loRCtM8FRzRM,246
287
287
  rfd3/configs/inference_engine/dev.yaml,sha256=-2snClOTwj5TQt7jnwSrI4pzAiI4nFulXKJflmgIyUw,304
288
- rfd3/configs/inference_engine/rfdiffusion3.yaml,sha256=h2e9U9RFCcvXjKAJ6U8puj-8O-U57ZxeZLA0HLB2txA,2161
288
+ rfd3/configs/inference_engine/rfdiffusion3.yaml,sha256=3bHIAhzFhFDIag0xQWYxHBUMSc71fjClHXKbZ-tpHzA,2112
289
289
  rfd3/configs/logger/csv.yaml,sha256=DtcywAIS4OxLXP2QxSEvqdrjhMpT6xHiGspoYw5qkus,245
290
290
  rfd3/configs/logger/default.yaml,sha256=pSyHyxT-J_T-g4_6TtD2yzN3rzxgY6rOG_Vh4RjZeFY,17
291
291
  rfd3/configs/logger/wandb.yaml,sha256=RhCnFtO0hNc3R75ts417l5ICZeGm74lOj9Bfe7ZvRNA,652
@@ -294,7 +294,7 @@ rfd3/configs/model/components/ema.yaml,sha256=AIzf4RZLKP8AcfaxdvZBS1rFw3AlSo431r
294
294
  rfd3/configs/model/components/rfd3_net.yaml,sha256=95FF4U7aWmLCoHvyxsRoE74n-bxTPD6KlAhPKNemVH4,3275
295
295
  rfd3/configs/model/optimizers/adam.yaml,sha256=cTRNo4_4lNgLv0b329v-KiC_MCQtTVVTxeer5Au_FIM,145
296
296
  rfd3/configs/model/samplers/edm.yaml,sha256=QycHAIrfhRgx0mJygTOs56FT93tGCWTGxrQSKBOA7Mc,483
297
- rfd3/configs/model/samplers/symmetry.yaml,sha256=BZZOIhk2ndAvIntf-16nnqCuOW43iWTB7iDU-RsxOcc,214
297
+ rfd3/configs/model/samplers/symmetry.yaml,sha256=pI0Ens6jmbpAIl8E4eYsJR1SqIppe5OsWh91KfpjNjs,214
298
298
  rfd3/configs/model/schedulers/af3.yaml,sha256=xEtRb--KPjg_5pW_IJvN9AHWVqCtOM4QOnXlMH2KrEg,149
299
299
  rfd3/configs/paths/default.yaml,sha256=bjB04SNu_5E6W_v4mRBjwce0xmdKwO5wsVf4gfaRl0Y,1045
300
300
  rfd3/configs/paths/data/default.yaml,sha256=jfs1dbbcOqHja4_6lXheyRg4t0YExqVn2w0rZEWL6XE,788
@@ -304,8 +304,8 @@ rfd3/configs/trainer/rfd3_base.yaml,sha256=R3lZxdyjUirjlLU31qWlnZgHaz4GcWTGGIz4f
304
304
  rfd3/configs/trainer/loss/losses/diffusion_loss.yaml,sha256=FE4FCEfurE0ekwZ4YfS6wCvPSNqxClwg_kc73cPql5Y,323
305
305
  rfd3/configs/trainer/loss/losses/sequence_loss.yaml,sha256=kezbQcqwAZ0VKQPUBr2MsNr9DcDL3ENIP1i-j7h-6Co,64
306
306
  rfd3/configs/trainer/metrics/design_metrics.yaml,sha256=xVDpClhHqSHvsf-8StL26z51Vn-iuWMDG9KMB-kqOI0,719
307
- rc_foundry-0.1.6.dist-info/METADATA,sha256=EEOkAi2nABzo70kEP-n9t5aXZ8a4Gqr5wYZ2mjIBqp4,10585
308
- rc_foundry-0.1.6.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
309
- rc_foundry-0.1.6.dist-info/entry_points.txt,sha256=BmiWCbWGtrd_lSOFMuCLBXyo84B7Nco-alj7hB0Yw9A,130
310
- rc_foundry-0.1.6.dist-info/licenses/LICENSE.md,sha256=NKtPCJ7QMysFmzeDg56ZfUStvgzbq5sOvRQv7_ddZOs,1533
311
- rc_foundry-0.1.6.dist-info/RECORD,,
307
+ rc_foundry-0.1.7.dist-info/METADATA,sha256=zlvCxfZ5-Ow7WuGKskfW6P1DGhZB9OfLIIBUBGncFeQ,11309
308
+ rc_foundry-0.1.7.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
309
+ rc_foundry-0.1.7.dist-info/entry_points.txt,sha256=BmiWCbWGtrd_lSOFMuCLBXyo84B7Nco-alj7hB0Yw9A,130
310
+ rc_foundry-0.1.7.dist-info/licenses/LICENSE.md,sha256=NKtPCJ7QMysFmzeDg56ZfUStvgzbq5sOvRQv7_ddZOs,1533
311
+ rc_foundry-0.1.7.dist-info/RECORD,,
@@ -7,7 +7,7 @@ dataset:
7
7
  base_dir: ${paths.data.pdb_data_dir}
8
8
  dataset:
9
9
  name: interface
10
- data: ${paths.data.pdb_parquet_dir}/interfaces_df_train.parquet
10
+ data: ${paths.data.pdb_parquet_dir}/interfaces_df.parquet
11
11
  filters:
12
12
  # filters common across all PDB datasets
13
13
  - "deposition_date < '2021-09-30'"
@@ -7,7 +7,7 @@ _target_: rfd3.engine.RFD3InferenceEngine
7
7
 
8
8
  out_dir: ???
9
9
  inputs: ??? # null, json, pdb or
10
- ckpt_path: /projects/ml/aa_design/models/rfd3_latest_cleaned.ckpt
10
+ ckpt_path: rfd3
11
11
  json_keys_subset: null
12
12
  skip_existing: True
13
13
 
@@ -61,5 +61,5 @@ global_prefix: null
61
61
  dump_prediction_metadata_json: True
62
62
  dump_trajectories: False
63
63
  align_trajectory_structures: False
64
- prevalidate_inputs: True
64
+ prevalidate_inputs: False
65
65
  low_memory_mode: False # False for standard mode, True for memory efficient tokenization mode
@@ -4,7 +4,7 @@ defaults:
4
4
 
5
5
  kind: symmetry
6
6
  num_timesteps: 200
7
- gamma_0: 1.0 # 1.0 for SDE sampling
7
+ gamma_0: 0.6 # 1.0 for SDE sampling
8
8
  gamma_min: 1.0
9
9
  gamma_min2: 0.0
10
10
  sym_step_frac: 0.9 # when 0.9, 90% of the trajectory from the start is symmetrized
rfd3/engine.py CHANGED
@@ -23,7 +23,10 @@ from rfd3.inference.datasets import (
23
23
  )
24
24
  from rfd3.inference.input_parsing import DesignInputSpecification
25
25
  from rfd3.model.inference_sampler import SampleDiffusionConfig
26
- from rfd3.utils.inference import ensure_input_is_abspath
26
+ from rfd3.utils.inference import (
27
+ ensure_inference_sampler_matches_design_spec,
28
+ ensure_input_is_abspath,
29
+ )
27
30
  from rfd3.utils.io import (
28
31
  CIF_LIKE_EXTENSIONS,
29
32
  build_stack_from_atom_array_and_batched_coords,
@@ -171,6 +174,7 @@ class RFD3InferenceEngine(BaseInferenceEngine):
171
174
  )
172
175
  # save
173
176
  self.specification_overrides = dict(specification or {})
177
+ self.inference_sampler_overrides = dict(inference_sampler or {})
174
178
 
175
179
  # Setup output directories and args
176
180
  self.global_prefix = global_prefix
@@ -210,6 +214,9 @@ class RFD3InferenceEngine(BaseInferenceEngine):
210
214
  inputs=inputs,
211
215
  n_batches=n_batches,
212
216
  )
217
+ ensure_inference_sampler_matches_design_spec(
218
+ design_specifications, self.inference_sampler_overrides
219
+ )
213
220
  # init before
214
221
  self.initialize()
215
222
  outputs = self._run_multi(design_specifications)
@@ -383,6 +390,9 @@ class RFD3InferenceEngine(BaseInferenceEngine):
383
390
  # Based on inputs, construct the specifications to loop through
384
391
  design_specifications = {}
385
392
  for prefix, example_spec in inputs.items():
393
+ # Record task name in the specification
394
+ example_spec["extra"]["task_name"] = prefix
395
+
386
396
  # ... Create n_batches for example
387
397
  for batch_id in range((n_batches) if exists(n_batches) else 1):
388
398
  # ... Example ID
@@ -524,21 +534,19 @@ def process_input(
524
534
 
525
535
 
526
536
  def _reshape_trajectory(traj, align_structures: bool):
527
- traj = [traj[i] for i in range(len(traj))]
528
- n_steps = len(traj)
537
+ traj = [traj[i] for i in range(len(traj))] # make list of arrays
529
538
  max_frames = 100
530
-
539
+ if len(traj) > max_frames:
540
+ selected_indices = torch.linspace(0, len(traj) - 1, max_frames).long().tolist()
541
+ traj = [traj[i] for i in selected_indices]
531
542
  if align_structures:
532
543
  # ... align the trajectories on the last prediction
533
- for step in range(n_steps - 1):
544
+ for step in range(len(traj) - 1):
534
545
  traj[step] = weighted_rigid_align(
535
- X_L=traj[-1],
536
- X_gt_L=traj[step],
537
- )
546
+ X_L=traj[-1][None],
547
+ X_gt_L=traj[step][None],
548
+ ).squeeze(0)
538
549
  traj = traj[::-1] # reverse to go from noised -> denoised
539
- if n_steps > max_frames:
540
- selected_indices = torch.linspace(0, n_steps - 1, max_frames).long().tolist()
541
- traj = [traj[i] for i in selected_indices]
542
550
 
543
551
  traj = torch.stack(traj).cpu().numpy()
544
552
  return traj
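With this reordering, long trajectories are thinned to at most 100 evenly spaced frames before any alignment happens. A small standalone sketch of the subsampling step (dummy tensors, not the engine code):

```python
import torch

traj = [torch.randn(16, 3) for _ in range(250)]  # e.g. 250 recorded diffusion steps
max_frames = 100

if len(traj) > max_frames:
    # Evenly spaced indices over the full trajectory, endpoints included
    keep = torch.linspace(0, len(traj) - 1, max_frames).long().tolist()
    traj = [traj[i] for i in keep]

print(len(traj))  # 100
```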
@@ -696,7 +696,7 @@ class DesignInputSpecification(BaseModel):
696
696
  # Partial diffusion: use COM, keep all coordinates
697
697
  if exists(self.symmetry) and self.symmetry.id:
698
698
  # For symmetric structures, avoid COM centering that would collapse chains
699
- ranked_logger.info(
699
+ logger.info(
700
700
  "Partial diffusion with symmetry: skipping COM centering to preserve chain spacing"
701
701
  )
702
702
  else:
@@ -139,13 +139,18 @@ def fetch_motif_residue_(
139
139
  subarray, motif=True, unindexed=False, dtype=int
140
140
  ) # all values init to True (fix all)
141
141
 
142
+ to_unindex = f"{src_chain}{src_resid}" in unindexed_components
143
+ to_index = f"{src_chain}{src_resid}" in components
144
+
142
145
  # Assign is motif atom and sequence
143
146
  if exists(atoms := fixed_atoms.get(f"{src_chain}{src_resid}")):
147
+ # If specified, we mark the fixed atoms in the residue as motif atoms
144
148
  atom_mask = get_name_mask(subarray.atom_name, atoms, res_name)
145
149
  subarray.set_annotation("is_motif_atom", atom_mask)
146
150
  # subarray.set_annotation("is_motif_atom_with_fixed_coord", atom_mask) # BUGFIX: uncomment
147
151
 
148
152
  elif redesign_motif_sidechains and res_name in STANDARD_AA:
153
+ # If redesign_motif_sidechains is True, we only mark the backbone atoms as motif atoms
149
154
  n_atoms = subarray.shape[0]
150
155
  diffuse_oxygen = False
151
156
  if n_atoms < 3:
@@ -178,6 +183,18 @@ def fetch_motif_residue_(
178
183
  subarray.set_annotation(
179
184
  "is_motif_atom_with_fixed_seq", np.zeros(subarray.shape[0], dtype=int)
180
185
  )
186
+ elif to_index or to_unindex:
187
+ # If the residue is in the contig or unindexed components,
188
+ # we set all atoms in the residue to be motif atoms
189
+ subarray.set_annotation("is_motif_atom", np.ones(subarray.shape[0], dtype=int))
190
+ else:
191
+ if to_unindex and not (
192
+ unfix_all or f"{src_chain}{src_resid}" in unfix_residues
193
+ ):
194
+ raise ValueError(
195
+ f"{src_chain}{src_resid} is not found in fixed_atoms, contig or unindex contig."
196
+ "Please check your input and contig specification."
197
+ )
181
198
  if unfix_all or f"{src_chain}{src_resid}" in unfix_residues:
182
199
  subarray.set_annotation(
183
200
  "is_motif_atom_with_fixed_coord", np.zeros(subarray.shape[0], dtype=int)
@@ -197,7 +214,6 @@ def fetch_motif_residue_(
197
214
  subarray.set_annotation(
198
215
  "is_flexible_motif_atom", np.zeros(subarray.shape[0], dtype=bool)
199
216
  )
200
- to_unindex = f"{src_chain}{src_resid}" in unindexed_components
201
217
  if to_unindex:
202
218
  subarray.set_annotation(
203
219
  "is_motif_atom_unindexed", subarray.is_motif_atom.copy()
rfd3/inference/parsing.py CHANGED
@@ -117,6 +117,7 @@ def from_any_(v: Any, atom_array: AtomArray):
117
117
 
118
118
  # Split to atom names
119
119
  data_split[idx] = token.atom_name[comp_mask_subset].tolist()
120
+ # TODO: there is a bug where, when you select specific atoms within a ligand, the output ligand is fragmented
120
121
 
121
122
  # Update mask & token dictionary
122
123
  mask[comp_mask] = comp_mask_subset
@@ -4,12 +4,8 @@ from rfd3.inference.symmetry.frames import (
4
4
  get_symmetry_frames_from_symmetry_id,
5
5
  )
6
6
 
7
- from foundry.utils.ddp import RankedLogger
8
-
9
7
  FIXED_TRANSFORM_ID = -1
10
8
  FIXED_ENTITY_ID = -1
11
- ranked_logger = RankedLogger(__name__, rank_zero_only=True)
12
-
13
9
 
14
10
  ########################################################
15
11
  # Symmetry annotations
@@ -28,7 +24,7 @@ def add_sym_annotations(atom_array, sym_conf):
28
24
  is_asu = np.full(n, True, dtype=np.bool_)
29
25
  atom_array.set_annotation("is_sym_asu", is_asu)
30
26
  # symmetry_id
31
- symmetry_ids = np.full(n, sym_conf.get("id"), dtype="U6")
27
+ symmetry_ids = np.full(n, sym_conf.id, dtype="U6")
32
28
  atom_array.set_annotation("symmetry_id", symmetry_ids)
33
29
  return atom_array
34
30
 
@@ -1,10 +1,13 @@
1
1
  import numpy as np
2
- from rfd3.inference.symmetry.contigs import expand_contig_unsym_motif
2
+ from rfd3.inference.symmetry.contigs import (
3
+ expand_contig_unsym_motif,
4
+ get_unsym_motif_mask,
5
+ )
3
6
  from rfd3.transforms.conditioning_base import get_motif_features
4
7
 
5
8
  from foundry.utils.ddp import RankedLogger
6
9
 
7
- MIN_ATOMS_ALIGN = 100
10
+ MIN_ATOMS_ALIGN = 30
8
11
  MAX_TRANSFORMS = 10
9
12
  RMSD_CUT = 1.0 # Angstroms
10
13
 
@@ -18,29 +21,33 @@ def check_symmetry_config(
18
21
  Check if the symmetry configuration is valid. Add all basic checks here.
19
22
  """
20
23
 
21
- assert sym_conf.get("id"), "symmetry_id is required. e.g. {'id': 'C2'}"
24
+ assert sym_conf.id, "symmetry_id is required. e.g. {'id': 'C2'}"
22
25
  # if unsym motif is provided, check that each motif name is in the atom array
23
- if sym_conf.get("is_unsym_motif"):
26
+
27
+ is_unsym_motif = np.zeros(atom_array.shape[0], dtype=bool)
28
+ if sym_conf.is_unsym_motif:
24
29
  assert (
25
30
  src_atom_array is not None
26
31
  ), "Source atom array must be provided for symmetric motifs"
27
- unsym_motif_names = sym_conf["is_unsym_motif"].split(",")
32
+ unsym_motif_names = sym_conf.is_unsym_motif.split(",")
28
33
  unsym_motif_names = expand_contig_unsym_motif(unsym_motif_names)
34
+ is_unsym_motif = get_unsym_motif_mask(atom_array, unsym_motif_names)
29
35
  for n in unsym_motif_names:
30
36
  if (sm and n not in sm.split(",")) and (n not in atom_array.src_component):
31
37
  raise ValueError(f"Unsym motif {n} not found in atom_array")
38
+
39
+ is_motif_token = get_motif_features(atom_array)["is_motif_token"]
32
40
  if (
33
- get_motif_features(atom_array)["is_motif_token"].any()
34
- and not sym_conf.get("is_symmetric_motif")
41
+ is_motif_token[~is_unsym_motif].any()
42
+ and not sym_conf.is_symmetric_motif
35
43
  and not has_dist_cond
36
44
  ):
37
45
  raise ValueError(
38
- "Asymmetric motif inputs should be distance constrained. "
46
+ "Asymmetric motif inputs should be distance constrained."
39
47
  "Use atomwise_fixed_dist to constrain the distance between the motif atoms."
40
48
  )
41
- # else: if unconditional symmetry, no need to have symmetric input motif
42
49
 
43
- if partial and not sym_conf.get("is_symmetric_motif"):
50
+ if partial and not sym_conf.is_symmetric_motif:
44
51
  raise ValueError(
45
52
  "Partial diffusion with symmetry is only supported for symmetric inputs."
46
53
  )
@@ -54,9 +61,6 @@ def check_atom_array_is_symmetric(atom_array):
54
61
  Returns:
55
62
  bool: True if the atom array is symmetric, False otherwise
56
63
  """
57
- # TODO: Implement something like this https://github.com/baker-laboratory/ipd/blob/main/ipd/sym/sym_detect.py#L303
58
- # and maybe this https://github.com/baker-laboratory/ipd/blob/main/ipd/sym/sym_detect.py#L231
59
-
60
64
  import biotite.structure as struc
61
65
  from rfd3.inference.symmetry.atom_array import (
62
66
  apply_symmetry_to_atomarray_coord,
@@ -68,8 +72,10 @@ def check_atom_array_is_symmetric(atom_array):
68
72
  # remove hetero atoms
69
73
  atom_array = atom_array[~atom_array.hetero]
70
74
  if len(atom_array) == 0:
71
- ranked_logger.info("Atom array has no protein chains. Please check your input.")
72
- return False
75
+ ranked_logger.warning(
76
+ "Atom array has no protein chains. Please check your input."
77
+ )
78
+ return True
73
79
 
74
80
  chains = np.unique(atom_array.chain_id)
75
81
  asu_mask = atom_array.chain_id == chains[0]
@@ -162,16 +168,22 @@ def find_optimal_rotation(coords1, coords2, max_points=1000):
162
168
  return None
163
169
 
164
170
 
165
- def check_input_frames_match_symmetry_frames(computed_frames, original_frames) -> None:
171
+ def check_input_frames_match_symmetry_frames(
172
+ computed_frames, original_frames, nids_by_entity
173
+ ) -> None:
166
174
  """
167
175
  Check if the atom array matches the symmetry_id.
168
176
  Arguments:
169
177
  computed_frames: list of computed frames
170
178
  original_frames: list of original frames
171
179
  """
172
- assert len(computed_frames) == len(
173
- original_frames
174
- ), "Number of computed frames does not match number of original frames"
180
+ assert len(computed_frames) == len(original_frames), (
181
+ "Number of computed frames does not match number of original frames.\n"
182
+ f"Computed Frames: {len(computed_frames)}. Original Frames: {len(original_frames)}.\n"
183
+ "If the computed frames are not as expected, please check if you have one-to-one mapping "
184
+ "(size, sequence, folding) of an entity across all chains.\n"
185
+ f"Computed Entity Mapping: {nids_by_entity}."
186
+ )
175
187
 
176
188
 
177
189
  def check_valid_multiplicity(nids_by_entity) -> None:
@@ -184,25 +196,35 @@ def check_valid_multiplicity(nids_by_entity) -> None:
184
196
  multiplicity = min([len(i) for i in nids_by_entity.values()])
185
197
  if multiplicity == 1: # no possible symmetry
186
198
  raise ValueError(
187
- "Input has no possible symmetry. If asymmetric motif, please use 2D conditioning inference instead."
199
+ "Input has no possible symmetry. If asymmetric motif, please use 2D conditioning inference instead.\n"
200
+ "Multiplicity: 1"
188
201
  )
189
202
 
190
203
  # Check that the input is not asymmetric
191
204
  multiplicity_good = [len(i) % multiplicity == 0 for i in nids_by_entity.values()]
192
205
  if not all(multiplicity_good):
193
- raise ValueError("Invalid multiplicities of subunits. Please check your input.")
206
+ raise ValueError(
207
+ "Expected multiplicity does not match for some entities.\n"
208
+ "Please modify your input to have one-to-one mapping (size, sequence, folding) of an entity across all chains.\n"
209
+ f"Expected Multiplicity: {multiplicity}.\n"
210
+ f"Computed Entity Mapping: {nids_by_entity}."
211
+ )
194
212
 
195
213
 
196
214
  def check_valid_subunit_size(nids_by_entity, pn_unit_id) -> None:
197
215
  """
198
216
  Check that the subunits in the input are of the same size.
199
217
  Arguments:
200
- nids_by_entity: dict mapping entity to ids
218
+ nids_by_entity: dict mapping entity to ids. e.g. {0: (['A_1', 'B_1', 'C_1']), 1: (['A_2', 'B_2', 'C_2'])}
219
+ pn_unit_id: array of ids. e.g. ['A_1', 'B_1', 'C_1', 'A_2', 'B_2', 'C_2']
201
220
  """
202
- for i, js in nids_by_entity.items():
203
- for j in js[1:]:
204
- if (pn_unit_id == js[0]).sum() != (pn_unit_id == j).sum():
205
- raise ValueError("Size mismatch in the input. Please check your file.")
221
+ for js in nids_by_entity.values():
222
+ for js_i in js[1:]:
223
+ if (pn_unit_id == js[0]).sum() != (pn_unit_id == js_i).sum():
224
+ raise ValueError(
225
+ f"Size mismatch between chain {js[0]} ({(pn_unit_id == js[0]).sum()} atoms) "
226
+ f"and chain {js_i} ({(pn_unit_id == js_i).sum()} atoms). Please check your input file."
227
+ )
206
228
 
207
229
 
208
230
  def check_min_atoms_to_align(natm_per_unique, reference_entity) -> None:
@@ -212,7 +234,10 @@ def check_min_atoms_to_align(natm_per_unique, reference_entity) -> None:
212
234
  nids_by_entity: dict mapping entity to ids
213
235
  """
214
236
  if natm_per_unique[reference_entity] < MIN_ATOMS_ALIGN:
215
- raise ValueError("Not enough atoms to align. Please check your input.")
237
+ raise ValueError(
238
+ f"Not enough atoms to align < {MIN_ATOMS_ALIGN} atoms."
239
+ f"Please provide a input with at least {MIN_ATOMS_ALIGN} atoms."
240
+ )
216
241
 
217
242
 
218
243
  def check_max_transforms(chains_to_consider) -> None:
@@ -224,7 +249,7 @@ def check_max_transforms(chains_to_consider) -> None:
224
249
  """
225
250
  if len(chains_to_consider) > MAX_TRANSFORMS:
226
251
  raise ValueError(
227
- "Number of transforms exceeds the max number of transforms (10)"
252
+ f"Number of transforms exceeds the max number of transforms ({MAX_TRANSFORMS})."
228
253
  )
229
254
 
230
255
 
@@ -10,12 +10,13 @@ def get_symmetry_frames_from_symmetry_id(symmetry_id):
10
10
  Returns:
11
11
  frames: list of rotation matrices
12
12
  """
13
+ from rfd3.inference.symmetry.symmetry_utils import SymmetryConfig
13
14
 
14
15
  # Get frames from symmetry id
15
16
  sym_conf = {}
16
- if isinstance(symmetry_id, dict):
17
+ if isinstance(symmetry_id, SymmetryConfig):
17
18
  sym_conf = symmetry_id
18
- symmetry_id = symmetry_id.get("id")
19
+ symmetry_id = symmetry_id.id
19
20
 
20
21
  if symmetry_id.lower().startswith("c"):
21
22
  order = int(symmetry_id[1:])
@@ -25,9 +26,9 @@ def get_symmetry_frames_from_symmetry_id(symmetry_id):
25
26
  frames = get_dihedral_frames(order)
26
27
  elif symmetry_id.lower() == "input_defined":
27
28
  assert (
28
- "symmetry_file" in sym_conf
29
+ sym_conf.symmetry_file is not None
29
30
  ), "symmetry_file is required for input_defined symmetry"
30
- frames = get_frames_from_file(sym_conf.get("symmetry_file"))
31
+ frames = get_frames_from_file(sym_conf.symmetry_file)
31
32
  else:
32
33
  raise ValueError(f"Symmetry id {symmetry_id} not supported")
33
34
 
@@ -120,7 +121,9 @@ def get_symmetry_frames_from_atom_array(src_atom_array, input_frames):
120
121
  computed_frames = [(R, np.array([0, 0, 0])) for R in Rs]
121
122
 
122
123
  # check that the computed frames match the input frames
123
- check_input_frames_match_symmetry_frames(computed_frames, input_frames)
124
+ check_input_frames_match_symmetry_frames(
125
+ computed_frames, input_frames, nids_by_entity
126
+ )
124
127
 
125
128
  return computed_frames
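The helper above now accepts either a plain symmetry ID string or the new `SymmetryConfig` object (defined later in this diff). A hedged usage sketch, assuming rc-foundry 0.1.7 is installed; the exact contents of each frame are an implementation detail:

```python
from rfd3.inference.symmetry.frames import get_symmetry_frames_from_symmetry_id

# For cyclic groups the ID encodes the order, so "C3" presumably yields one frame
# per subunit (three here); "input_defined" would instead require a symmetry_file.
frames = get_symmetry_frames_from_symmetry_id("C3")
print(len(frames))
```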
126
129
 
@@ -39,18 +39,36 @@ ranked_logger = RankedLogger(__name__, rank_zero_only=True)
39
39
 
40
40
 
41
41
  class SymmetryConfig(BaseModel):
42
- # AM / HE TODO: feel free to flesh this out and add validation as needed
43
42
  model_config = ConfigDict(
44
43
  arbitrary_types_allowed=True,
45
44
  extra="allow",
46
45
  )
47
- id: Optional[str] = Field(None)
48
- # is_unsym_motif: Optional[np.ndarray[bool]] = Field(...)
49
- # is_symmetric_motif: bool = Field(...)
46
+ id: Optional[str] = Field(
47
+ None,
48
+ description="Symmetry group ID. e.g. 'C3', 'D2'. Only C and D symmetry types are supported currently.",
49
+ )
50
+ is_unsym_motif: Optional[str] = Field(
51
+ None,
52
+ description="Comma separated list of contig/ligand names that should not be symmetrized such as DNA strands. \
53
+ e.g. 'HEM' or 'Y1-11,Z16-25'",
54
+ )
55
+ is_symmetric_motif: bool = Field(
56
+ True,
57
+ description="If True, the input motifs are expected to be already symmetric and won't be symmetrized. \
58
+ If False, all input motifs are expected to be the ASU and will be symmetrized.",
59
+ )
60
+
61
+
62
+ def convery_sym_conf_to_symmetry_config(sym_conf: dict):
63
+ return SymmetryConfig(**sym_conf)
50
64
 
51
65
 
52
66
  def make_symmetric_atom_array(
53
- asu_atom_array, sym_conf: SymmetryConfig, sm=None, has_2d=False, src_atom_array=None
67
+ asu_atom_array,
68
+ sym_conf: SymmetryConfig | dict,
69
+ sm=None,
70
+ has_dist_cond=False,
71
+ src_atom_array=None,
54
72
  ):
55
73
  """
56
74
  apply symmetry to an atom array.
@@ -58,39 +76,33 @@ def make_symmetric_atom_array(
58
76
  asu_atom_array: atom array of the asymmetric unit
59
77
  sym_conf: symmetry configuration (dict, "id" key is required)
60
78
  sm: optional small molecule names (str, comma separated)
61
- has_2d: whether to add 2d entity annotations
79
+ has_dist_cond: whether to add 2d entity annotations
62
80
  Returns:
63
81
  new_asu_atom_array: atom array with symmetry applied
64
82
  """
65
- sym_conf = (
66
- sym_conf.model_dump()
67
- ) # TODO: JB: remove this line to keep as symmetry config for cleaner syntax(?)
68
- ranked_logger.info(f"Symmetry Configs: {sym_conf}")
83
+ if not isinstance(sym_conf, SymmetryConfig):
84
+ sym_conf = convery_sym_conf_to_symmetry_config(sym_conf)
69
85
 
70
- # Making sure that the symmetry config is valid
71
86
  check_symmetry_config(
72
- asu_atom_array,
73
- sym_conf,
74
- sm,
75
- has_dist_cond=has_2d,
76
- src_atom_array=src_atom_array,
87
+ asu_atom_array, sym_conf, sm, has_dist_cond, src_atom_array=src_atom_array
77
88
  )
78
89
  # Adding utility annotations to the asu atom array
79
90
  asu_atom_array = _add_util_annotations(asu_atom_array, sym_conf, sm)
80
91
 
81
- if has_2d: # NB: this will only work for asymmetric motifs at the moment - need to add functionality for symmetric motifs
92
+ if has_dist_cond: # NB: this will only work for asymmetric motifs at the moment - need to add functionality for symmetric motifs
82
93
  asu_atom_array = add_2d_entity_annotations(asu_atom_array)
83
94
 
84
95
  frames = get_symmetry_frames_from_symmetry_id(sym_conf)
85
96
 
86
97
  # If the motif is symmetric, we get the frames instead from the source atom array.
87
- if sym_conf.get("is_symmetric_motif"):
98
+ if sym_conf.is_symmetric_motif:
88
99
  assert (
89
100
  src_atom_array is not None
90
101
  ), "Source atom array must be provided for symmetric motifs"
91
- # if symmetric motif is provided, get the frames from the src atom array
102
+ # if symmetric motif is provided, get the frames from the src atom array.
92
103
  frames = get_symmetry_frames_from_atom_array(src_atom_array, frames)
93
- else:
104
+ elif (asu_atom_array._is_motif[~asu_atom_array._is_unsym_motif]).any():
105
+ # if any motif atoms that are not unsym motifs are present.
94
106
  raise NotImplementedError(
95
107
  "Asymmetric motif inputs are not implemented yet. please symmetrize the motif."
96
108
  )
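For reference, constructing the `SymmetryConfig` model introduced earlier in this file; field names come from this diff, while the values are made up for illustration:

```python
from rfd3.inference.symmetry.symmetry_utils import SymmetryConfig

sym_conf = SymmetryConfig(
    id="C3",                  # cyclic symmetry of order 3
    is_unsym_motif="HEM",     # keep the HEM component out of the symmetrization
    is_symmetric_motif=True,  # input motif is already symmetric (the default)
)
print(sym_conf.id, sym_conf.is_unsym_motif, sym_conf.is_symmetric_motif)
```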
@@ -101,7 +113,7 @@ def make_symmetric_atom_array(
101
113
  # Extracting all things at this moment that we will not want to symmetrize.
102
114
  # This includes: 1) unsym motifs, 2) ligands
103
115
  unsym_atom_arrays = []
104
- if sym_conf.get("is_unsym_motif"):
116
+ if sym_conf.is_unsym_motif:
105
117
  # unsym_motif_atom_array = get_unsym_motif(asu_atom_array, asu_atom_array._is_unsym_motif)
106
118
  # Now remove the unsym motifs from the asu atom array
107
119
  unsym_atom_arrays.append(asu_atom_array[asu_atom_array._is_unsym_motif])
@@ -128,7 +140,7 @@ def make_symmetric_atom_array(
128
140
  symmetrized_atom_array = struc.concatenate(symmetry_unit_list)
129
141
 
130
142
  # add 2D conditioning annotations
131
- if has_2d:
143
+ if has_dist_cond:
132
144
  symmetrized_atom_array = reannotate_2d_conditions(symmetrized_atom_array)
133
145
 
134
146
  # set all motifs to not have any symmetrization applied to them
@@ -183,7 +195,7 @@ def make_symmetric_atom_array_for_partial_diffusion(atom_array, sym_conf):
183
195
  frames = get_symmetry_frames_from_symmetry_id(sym_conf)
184
196
 
185
197
  # Add symmetry ID
186
- symmetry_ids = np.full(n, sym_conf.get("id"), dtype="U6")
198
+ symmetry_ids = np.full(n, sym_conf.id, dtype="U6")
187
199
  atom_array.set_annotation("symmetry_id", symmetry_ids)
188
200
 
189
201
  # Initialize transform annotations (use same format as original system)
@@ -244,7 +256,7 @@ def _add_util_annotations(asu_atom_array, sym_conf, sm):
244
256
  """
245
257
  n = asu_atom_array.shape[0]
246
258
  is_motif = get_motif_features(asu_atom_array)["is_motif_atom"].astype(np.bool_)
247
- is_sm = np.zeros(asu_atom_array.shape[0], dtype=bool)
259
+ is_sm = np.zeros(n, dtype=bool)
248
260
  is_asu = np.ones(n, dtype=bool)
249
261
  is_unsym_motif = np.zeros(n, dtype=bool)
250
262
 
@@ -257,8 +269,8 @@ def _add_util_annotations(asu_atom_array, sym_conf, sm):
257
269
  )
258
270
 
259
271
  # assign unsym motifs
260
- if sym_conf.get("is_unsym_motif"):
261
- unsym_motif_names = sym_conf["is_unsym_motif"].split(",")
272
+ if sym_conf.is_unsym_motif:
273
+ unsym_motif_names = sym_conf.is_unsym_motif.split(",")
262
274
  unsym_motif_names = expand_contig_unsym_motif(unsym_motif_names)
263
275
  is_unsym_motif = get_unsym_motif_mask(asu_atom_array, unsym_motif_names)
264
276
 
@@ -361,38 +373,4 @@ def apply_symmetry_to_xyz_atomwise(X_L, sym_feats, partial_diffusion=False):
361
373
  "blc,cd->bld", asu_xyz, sym_transforms[target_id][0].to(asu_xyz.dtype)
362
374
  ) + sym_transforms[target_id][1].to(asu_xyz.dtype)
363
375
 
364
- # Log inter-chain distances for debugging - use actual chain annotations
365
- if sym_X_L.shape[1] > 100: # Only for large structures
366
- # Use symmetry entity annotations to find different chains
367
- sym_entity_id = sym_feats["sym_entity_id"]
368
- unique_entities = torch.unique(sym_entity_id)
369
-
370
- if len(unique_entities) >= 2:
371
- # Get atoms from first two different entities
372
- entity_0_mask = sym_entity_id == unique_entities[0]
373
- entity_1_mask = sym_entity_id == unique_entities[1]
374
-
375
- if entity_0_mask.sum() > 0 and entity_1_mask.sum() > 0:
376
- entity_0_atoms = sym_X_L[0, entity_0_mask, :]
377
- entity_1_atoms = sym_X_L[0, entity_1_mask, :]
378
-
379
- # Sample subset to avoid memory issues
380
- entity_0_sample = entity_0_atoms[: min(50, entity_0_atoms.shape[0]), :]
381
- entity_1_sample = entity_1_atoms[: min(50, entity_1_atoms.shape[0]), :]
382
-
383
- min_distance = (
384
- torch.cdist(entity_0_sample, entity_1_sample).min().item()
385
- )
386
- ranked_logger.info(
387
- f"Min inter-chain distance after symmetry: {min_distance:.2f} Å"
388
- )
389
-
390
- # Also log the centers of each entity
391
- entity_0_center = entity_0_atoms.mean(dim=0)
392
- entity_1_center = entity_1_atoms.mean(dim=0)
393
- center_distance = torch.norm(entity_0_center - entity_1_center).item()
394
- ranked_logger.info(
395
- f"Distance between chain centers: {center_distance:.2f} Å"
396
- )
397
-
398
376
  return sym_X_L
rfd3/run_inference.py CHANGED
@@ -12,7 +12,9 @@ load_dotenv(override=True)
12
12
 
13
13
  # For pip-installed package, configs should be relative to this file
14
14
  # Adjust this path based on where configs are bundled in the package
15
- _config_path = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__))), "configs")
15
+ _config_path = os.path.join(
16
+ os.path.dirname(os.path.dirname(os.path.dirname(__file__))), "configs"
17
+ )
16
18
 
17
19
 
18
20
  @hydra.main(
rfd3/utils/inference.py CHANGED
@@ -391,6 +391,29 @@ def ensure_input_is_abspath(args: dict, path: PathLike | None):
391
391
  return args
392
392
 
393
393
 
394
+ def ensure_inference_sampler_matches_design_spec(
395
+ design_spec: dict, inference_sampler: dict | None = None
396
+ ):
397
+ """
398
+ Ensure the inference sampler is set to the correct sampler for the design specification.
399
+ Args:
400
+ design_spec: Design specification dictionary
401
+ inference_sampler: Inference sampler dictionary
402
+ """
403
+ has_symmetry_specification = [
404
+ True if "symmetry" in item.keys() else False for item in design_spec.values()
405
+ ]
406
+ if any(has_symmetry_specification):
407
+ if (
408
+ inference_sampler is None
409
+ or inference_sampler.get("kind", "default") != "symmetry"
410
+ ):
411
+ raise ValueError(
412
+ "You requested for symmetric designs, but inference sampler is not set to symmetry. "
413
+ "Please add inference_sampler.kind='symmetry' to your command."
414
+ )
415
+
416
+
394
417
  #################################################################################
395
418
  # Custom infer_ori functions
396
419
  #################################################################################
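A sketch of how the new guard behaves, with an illustrative design-spec shape (the real specifications are built by the engine; only the presence of a "symmetry" key in an entry matters here):

```python
from rfd3.utils.inference import ensure_inference_sampler_matches_design_spec

design_spec = {"task_0": {"symmetry": {"id": "C3"}}}

# Passes: a symmetric design paired with the symmetry sampler
ensure_inference_sampler_matches_design_spec(design_spec, {"kind": "symmetry"})

# Raises ValueError: symmetric design but the default sampler is selected
# ensure_inference_sampler_matches_design_spec(design_spec, {"kind": "default"})
```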