boltz-vsynthes 0.0.12__py3-none-any.whl → 0.0.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- boltz/utils/sdf_to_pre_affinity_npz.py +1267 -25
- {boltz_vsynthes-0.0.12.dist-info → boltz_vsynthes-0.0.13.dist-info}/METADATA +1 -1
- {boltz_vsynthes-0.0.12.dist-info → boltz_vsynthes-0.0.13.dist-info}/RECORD +7 -7
- {boltz_vsynthes-0.0.12.dist-info → boltz_vsynthes-0.0.13.dist-info}/entry_points.txt +1 -1
- {boltz_vsynthes-0.0.12.dist-info → boltz_vsynthes-0.0.13.dist-info}/WHEEL +0 -0
- {boltz_vsynthes-0.0.12.dist-info → boltz_vsynthes-0.0.13.dist-info}/licenses/LICENSE +0 -0
- {boltz_vsynthes-0.0.12.dist-info → boltz_vsynthes-0.0.13.dist-info}/top_level.txt +0 -0
@@ -1,35 +1,1277 @@
|
|
1
|
-
|
2
|
-
|
3
|
-
|
1
|
+
import multiprocessing
|
2
|
+
import os
|
3
|
+
import pickle
|
4
|
+
import platform
|
5
|
+
import tarfile
|
6
|
+
import urllib.request
|
7
|
+
import warnings
|
8
|
+
from dataclasses import asdict, dataclass
|
9
|
+
from functools import partial
|
10
|
+
from multiprocessing import Pool
|
11
|
+
from pathlib import Path
|
12
|
+
from typing import Literal, Optional
|
4
13
|
|
5
|
-
|
6
|
-
|
14
|
+
import click
|
15
|
+
import torch
|
16
|
+
from pytorch_lightning import Trainer, seed_everything
|
17
|
+
from pytorch_lightning.strategies import DDPStrategy
|
18
|
+
from pytorch_lightning.utilities import rank_zero_only
|
19
|
+
from rdkit import Chem
|
20
|
+
from tqdm import tqdm
|
7
21
|
|
8
|
-
|
9
|
-
|
10
|
-
import
|
11
|
-
from
|
22
|
+
from boltz.data import const
|
23
|
+
from boltz.data.module.inference import BoltzInferenceDataModule
|
24
|
+
from boltz.data.module.inferencev2 import Boltz2InferenceDataModule
|
25
|
+
from boltz.data.mol import load_canonicals
|
26
|
+
from boltz.data.msa.mmseqs2 import run_mmseqs2
|
27
|
+
from boltz.data.parse.a3m import parse_a3m
|
28
|
+
from boltz.data.parse.csv import parse_csv
|
29
|
+
from boltz.data.parse.fasta import parse_fasta
|
30
|
+
from boltz.data.parse.yaml import parse_yaml
|
31
|
+
from boltz.data.types import MSA, Manifest, Record
|
32
|
+
from boltz.data.write.writer import BoltzAffinityWriter, BoltzWriter
|
33
|
+
from boltz.model.models.boltz1 import Boltz1
|
34
|
+
from boltz.model.models.boltz2 import Boltz2
|
35
|
+
|
36
|
+
# Boltz-1 CCD dictionary pickle (downloaded to <cache>/ccd.pkl).
CCD_URL: str = "https://huggingface.co/boltz-community/boltz-1/resolve/main/ccd.pkl"
# Boltz-2 CCD data tarball (downloaded to <cache>/mols.tar, extracted into <cache>).
MOL_URL: str = "https://huggingface.co/boltz-community/boltz-2/resolve/main/mols.tar"

# Checkpoint download locations. Entries are tried in order; later entries
# are only used when the earlier ones fail.
BOLTZ1_URL_WITH_FALLBACK: list[str] = [
    "https://model-gateway.boltz.bio/boltz1_conf.ckpt",
    "https://huggingface.co/boltz-community/boltz-1/resolve/main/boltz1_conf.ckpt",
]

BOLTZ2_URL_WITH_FALLBACK: list[str] = [
    "https://model-gateway.boltz.bio/boltz2_conf.ckpt",
    "https://huggingface.co/boltz-community/boltz-2/resolve/main/boltz2_conf.ckpt",
]

BOLTZ2_AFFINITY_URL_WITH_FALLBACK: list[str] = [
    "https://model-gateway.boltz.bio/boltz2_aff.ckpt",
    "https://huggingface.co/boltz-community/boltz-2/resolve/main/boltz2_aff.ckpt",
]
|
53
|
+
|
54
|
+
|
55
|
+
@dataclass
class BoltzProcessedInput:
    """Processed input data.

    Groups the manifest with the directories that hold the processed
    artifacts. ``manifest``, ``targets_dir`` and ``msa_dir`` are required;
    the remaining directories are optional and default to ``None``.
    """

    # Required fields
    manifest: Manifest
    targets_dir: Path
    msa_dir: Path

    # Optional directories
    constraints_dir: Optional[Path] = None
    template_dir: Optional[Path] = None
    extra_mols_dir: Optional[Path] = None
|
65
|
+
|
66
|
+
|
67
|
+
@dataclass
class PairformerArgs:
    """Pairformer arguments.

    Defaults: 48 blocks, 16 heads, no dropout, all toggles off and
    ``v2=False``.
    """

    # Block / head counts
    num_blocks: int = 48
    num_heads: int = 16

    # Dropout rate
    dropout: float = 0.0

    # Optional toggles, disabled by default
    activation_checkpointing: bool = False
    offload_to_cpu: bool = False

    # Variant flag (False here; PairformerArgsV2 sets it to True)
    v2: bool = False
|
77
|
+
|
78
|
+
|
79
|
+
@dataclass
class PairformerArgsV2:
    """Pairformer arguments.

    Defaults: 64 blocks, 16 heads, no dropout, all toggles off and
    ``v2=True``.
    """

    # Block / head counts
    num_blocks: int = 64
    num_heads: int = 16

    # Dropout rate
    dropout: float = 0.0

    # Optional toggles, disabled by default
    activation_checkpointing: bool = False
    offload_to_cpu: bool = False

    # Variant flag (True here, unlike PairformerArgs)
    v2: bool = True
|
89
|
+
|
90
|
+
|
91
|
+
@dataclass
class MSAModuleArgs:
    """MSA module arguments.

    Field order is significant for positional construction; defaults are
    listed inline.
    """

    # Size / depth
    msa_s: int = 64
    msa_blocks: int = 4

    # Dropout rates
    msa_dropout: float = 0.0
    z_dropout: float = 0.0

    # Pairing feature and pairwise attention shape
    use_paired_feature: bool = True
    pairwise_head_width: int = 32
    pairwise_num_heads: int = 4

    # Optional toggles, disabled by default
    activation_checkpointing: bool = False
    offload_to_cpu: bool = False

    # MSA subsampling (off by default; 1024 sequences when enabled)
    subsample_msa: bool = False
    num_subsampled_msa: int = 1024
|
106
|
+
|
107
|
+
|
108
|
+
@dataclass
class BoltzDiffusionParams:
    """Diffusion process parameters.

    Default values used for Boltz-1 (see ``Boltz2DiffusionParams`` for the
    Boltz-2 set). Field order is significant for positional construction.
    """

    # Noise schedule / sampler constants
    gamma_0: float = 0.605
    gamma_min: float = 1.107
    noise_scale: float = 0.901
    rho: float = 8  # kept as the int 8, exactly as before
    step_scale: float = 1.638

    # Sigma range and data scale
    sigma_min: float = 0.0004
    sigma_max: float = 160.0
    sigma_data: float = 16.0

    # P_mean / P_std pair
    P_mean: float = -1.2
    P_std: float = 1.5

    # Boolean switches, all enabled by default
    coordinate_augmentation: bool = True
    alignment_reverse_diff: bool = True
    synchronize_sigmas: bool = True
    use_inference_model_cache: bool = True
|
126
|
+
|
127
|
+
|
128
|
+
@dataclass
class Boltz2DiffusionParams:
    """Diffusion process parameters.

    Default values used for Boltz-2. Unlike ``BoltzDiffusionParams`` there
    is no ``use_inference_model_cache`` field. Field order is significant
    for positional construction.
    """

    # Noise schedule / sampler constants
    gamma_0: float = 0.8
    gamma_min: float = 1.0
    noise_scale: float = 1.003
    rho: float = 7  # kept as the int 7, exactly as before
    step_scale: float = 1.5

    # Sigma range and data scale
    sigma_min: float = 0.0001
    sigma_max: float = 160.0
    sigma_data: float = 16.0

    # P_mean / P_std pair
    P_mean: float = -1.2
    P_std: float = 1.5

    # Boolean switches, all enabled by default
    coordinate_augmentation: bool = True
    alignment_reverse_diff: bool = True
    synchronize_sigmas: bool = True
|
145
|
+
|
146
|
+
|
147
|
+
@dataclass
class BoltzSteeringParams:
    """Steering parameters.

    Field order is significant for positional construction.
    """

    # FK steering settings
    fk_steering: bool = True
    num_particles: int = 3
    fk_lambda: float = 4.0
    fk_resampling_interval: int = 3

    # Guidance update settings
    guidance_update: bool = True
    num_gd_steps: int = 20
|
157
|
+
|
158
|
+
|
159
|
+
def _fetch_atomically(url: str, dest: Path) -> None:
    """Download ``url`` to ``dest`` through a temporary ``.part`` file.

    ``dest`` only appears once the transfer has completed, so an interrupted
    or failed download is never mistaken for a valid cached file by the
    ``exists()`` checks in the download functions below.

    Parameters
    ----------
    url : str
        The URL to download.
    dest : Path
        The final destination path.

    """
    partial_path = dest.with_name(dest.name + ".part")
    try:
        urllib.request.urlretrieve(url, str(partial_path))  # noqa: S310
        partial_path.replace(dest)
    finally:
        # No-op after a successful replace; removes leftovers after a failure.
        partial_path.unlink(missing_ok=True)


def _fetch_with_fallback(urls: list[str], dest: Path) -> None:
    """Download ``dest`` from the first URL in ``urls`` that succeeds.

    Parameters
    ----------
    urls : list[str]
        Candidate URLs, tried in order.
    dest : Path
        The final destination path.

    Raises
    ------
    RuntimeError
        If every URL fails (chained from the last error).

    """
    for i, url in enumerate(urls):
        try:
            _fetch_atomically(url, dest)
            return
        except Exception as e:  # noqa: BLE001
            if i == len(urls) - 1:
                msg = f"Failed to download model from all URLs. Last error: {e}"
                raise RuntimeError(msg) from e


@rank_zero_only
def download_boltz1(cache: Path) -> None:
    """Download all the required data.

    Fetches the CCD dictionary and the Boltz-1 checkpoint into ``cache``
    unless they are already present. Files are written atomically, so an
    interrupted run does not leave a truncated file behind.

    Parameters
    ----------
    cache : Path
        The cache directory.

    """
    # Download CCD
    ccd = cache / "ccd.pkl"
    if not ccd.exists():
        click.echo(
            f"Downloading the CCD dictionary to {ccd}. You may "
            "change the cache directory with the --cache flag."
        )
        _fetch_atomically(CCD_URL, ccd)

    # Download model
    model = cache / "boltz1_conf.ckpt"
    if not model.exists():
        click.echo(
            f"Downloading the model weights to {model}. You may "
            "change the cache directory with the --cache flag."
        )
        _fetch_with_fallback(BOLTZ1_URL_WITH_FALLBACK, model)


@rank_zero_only
def download_boltz2(cache: Path) -> None:
    """Download all the required data.

    Fetches and extracts the CCD molecule data, the Boltz-2 structure
    checkpoint and the Boltz-2 affinity checkpoint into ``cache`` unless
    they are already present. Files are written atomically, so an
    interrupted download does not leave a truncated file behind.

    Parameters
    ----------
    cache : Path
        The cache directory.

    """
    # Download and extract CCD data
    mols = cache / "mols"
    tar_mols = cache / "mols.tar"
    if not mols.exists():
        # The tarball is only needed to (re)create the extracted directory, so
        # it is not re-downloaded when the data is already extracted.
        if not tar_mols.exists():
            click.echo(
                f"Downloading the CCD data to {tar_mols}. "
                "This may take a bit of time. You may change the cache directory "
                "with the --cache flag."
            )
            _fetch_atomically(MOL_URL, tar_mols)
        click.echo(
            f"Extracting the CCD data to {mols}. "
            "This may take a bit of time. You may change the cache directory "
            "with the --cache flag."
        )
        # NOTE: an interrupted extraction leaves a partial `mols` directory
        # that the check above will treat as complete; delete it to retry.
        with tarfile.open(str(tar_mols), "r") as tar:
            tar.extractall(cache)  # noqa: S202

    # Download model
    model = cache / "boltz2_conf.ckpt"
    if not model.exists():
        click.echo(
            f"Downloading the Boltz-2 weights to {model}. You may "
            "change the cache directory with the --cache flag."
        )
        _fetch_with_fallback(BOLTZ2_URL_WITH_FALLBACK, model)

    # Download affinity model
    affinity_model = cache / "boltz2_aff.ckpt"
    if not affinity_model.exists():
        click.echo(
            f"Downloading the Boltz-2 affinity weights to {affinity_model}. You may "
            "change the cache directory with the --cache flag."
        )
        _fetch_with_fallback(BOLTZ2_AFFINITY_URL_WITH_FALLBACK, affinity_model)
|
258
|
+
|
259
|
+
|
260
|
+
def get_cache_path() -> str:
    """Determine the cache path, prioritising the BOLTZ_CACHE environment variable.

    A non-empty ``BOLTZ_CACHE`` is ``~``-expanded and resolved to an absolute
    path; a relative value is therefore interpreted against the current
    working directory. If the variable is unset or empty, ``~/.boltz``
    (expanded) is used.

    Returns
    -------
    str
        Path to use for boltz cache location.

    """
    env_cache = os.environ.get("BOLTZ_CACHE")
    if env_cache:
        # Path.resolve() always returns an absolute path, so no separate
        # absoluteness check is needed here.
        return str(Path(env_cache).expanduser().resolve())

    return str(Path("~/.boltz").expanduser())
|
278
|
+
|
279
|
+
|
280
|
+
def check_inputs(data: Path) -> list[Path]:
    """Check the input data and output directory.

    A single file is returned as a one-element list without further checks.
    A directory is expanded to its direct entries, each of which must be a
    file with a ``.fa``, ``.fas``, ``.fasta``, ``.yml`` or ``.yaml`` suffix.

    Parameters
    ----------
    data : Path
        The input file or directory.

    Returns
    -------
    list[Path]
        The list of input data.

    Raises
    ------
    RuntimeError
        If the directory contains a sub-directory or an unsupported file type.

    """
    click.echo("Checking input data.")

    if not data.is_dir():
        return [data]

    entries: list[Path] = list(data.glob("*"))
    for entry in entries:
        if entry.is_dir():
            msg = f"Found directory {entry} instead of .fasta or .yaml."
            raise RuntimeError(msg)
        if entry.suffix not in (".fa", ".fas", ".fasta", ".yml", ".yaml"):
            msg = (
                f"Unable to parse filetype {entry.suffix}, "
                "please provide a .fasta or .yaml file."
            )
            raise RuntimeError(msg)

    return entries
|
316
|
+
|
317
|
+
|
318
|
+
def filter_inputs_structure(
    manifest: Manifest,
    outdir: Path,
    override: bool = False,
) -> Manifest:
    """Filter the manifest to only include missing predictions.

    Every directory found (recursively) under ``outdir / "predictions"``
    counts as an existing prediction for the record whose id equals the
    directory name. Such records are dropped unless ``override`` is set.

    Parameters
    ----------
    manifest : Manifest
        The manifest of the input data.
    outdir : Path
        The output directory.
    override: bool
        Whether to override existing predictions.

    Returns
    -------
    Manifest
        The manifest of the filtered input data.

    """
    found = {
        entry.name
        for entry in (outdir / "predictions").rglob("*")
        if entry.is_dir()
    }

    if not found:
        return manifest

    if override:
        click.echo(f"Found {len(found)} existing predictions, will override.")
        return manifest

    remaining = Manifest([r for r in manifest.records if r.id not in found])
    click.echo(
        f"Found some existing predictions ({len(found)}), "
        f"skipping and running only the missing ones, "
        "if any. If you wish to override these existing "
        "predictions, please set the --override flag."
    )
    return remaining
|
359
|
+
|
360
|
+
|
361
|
+
def filter_inputs_affinity(
    manifest: Manifest,
    outdir: Path,
    override: bool = False,
) -> Manifest:
    """Check the input data and output directory for affinity.

    A record counts as already predicted when it has affinity targets and
    ``outdir/predictions/<id>/affinity_<id>.json`` exists. Those records are
    removed from the manifest unless ``override`` is set, in which case the
    manifest is returned unchanged so they are predicted again.

    Parameters
    ----------
    manifest : Manifest
        The manifest.
    outdir : Path
        The output directory.
    override: bool
        Whether to override existing predictions.

    Returns
    -------
    Manifest
        The manifest of the filtered input data.

    """
    click.echo("Checking input data for affinity.")

    # Get all affinity targets that already have a prediction on disk
    existing = {
        r.id
        for r in manifest.records
        if r.affinity
        and (outdir / "predictions" / r.id / f"affinity_{r.id}.json").exists()
    }

    # Remove them from the input data, unless asked to override
    if existing and not override:
        num_skipped = len(existing)
        msg = (
            f"Found some existing affinity predictions ({num_skipped}), "
            f"skipping and running only the missing ones, "
            "if any. If you wish to override these existing "
            "affinity predictions, please set the --override flag."
        )
        click.echo(msg)
        return Manifest([r for r in manifest.records if r.id not in existing])

    if existing and override:
        msg = "Found existing affinity predictions, will override."
        click.echo(msg)

    # Nothing to skip (or override requested): keep every record.
    return manifest
|
408
|
+
|
409
|
+
|
410
|
+
def compute_msa(
    data: dict[str, str],
    target_id: str,
    msa_dir: Path,
    msa_server_url: str,
    msa_pairing_strategy: str,
) -> None:
    """Compute the MSA for the input data.

    Queries the MMSeqs2 server for unpaired MSAs of every sequence (and for
    paired MSAs when there is more than one sequence), then writes one
    ``<name>.csv`` per entry of ``data`` into ``msa_dir`` with a
    ``key,sequence`` header. Paired rows keep their row index as key;
    unpaired rows get key ``-1``.

    Parameters
    ----------
    data : dict[str, str]
        The input protein sequences.
    target_id : str
        The target id.
    msa_dir : Path
        The msa directory.
    msa_server_url : str
        The MSA server URL.
    msa_pairing_strategy : str
        The MSA pairing strategy.

    """
    # Pairing is only requested when there are at least two sequences.
    if len(data) <= 1:
        paired_msas = [""] * len(data)
    else:
        paired_msas = run_mmseqs2(
            list(data.values()),
            msa_dir / f"{target_id}_paired_tmp",
            use_env=True,
            use_pairing=True,
            host_url=msa_server_url,
            pairing_strategy=msa_pairing_strategy,
        )

    unpaired_msa = run_mmseqs2(
        list(data.values()),
        msa_dir / f"{target_id}_unpaired_tmp",
        use_env=True,
        use_pairing=False,
        host_url=msa_server_url,
        pairing_strategy=msa_pairing_strategy,
    )

    for entity_idx, name in enumerate(data):
        # Paired block: keep every second line (sequences, not headers), capped.
        paired_rows = paired_msas[entity_idx].strip().splitlines()[1::2]
        paired_rows = paired_rows[: const.max_paired_seqs]

        # Drop all-gap rows, remembering each surviving row's original index.
        kept = [
            (row_idx, seq)
            for row_idx, seq in enumerate(paired_rows)
            if seq != "-" * len(seq)
        ]
        keys = [row_idx for row_idx, _ in kept]
        paired = [seq for _, seq in kept]

        # Unpaired block fills the remaining budget.
        unpaired = unpaired_msa[entity_idx].strip().splitlines()[1::2]
        unpaired = unpaired[: (const.max_msa_seqs - len(paired))]
        if paired:
            # The query is already present as the first paired row.
            unpaired = unpaired[1:]

        seqs = paired + unpaired
        keys += [-1] * len(unpaired)

        lines = ["key,sequence"]
        lines.extend(f"{key},{seq}" for key, seq in zip(keys, seqs))

        with (msa_dir / f"{name}.csv").open("w") as handle:
            handle.write("\n".join(lines))
|
481
|
+
|
482
|
+
|
483
|
+
def process_input(  # noqa: C901, PLR0912, PLR0915
    path: Path,
    ccd: dict,
    msa_dir: Path,
    mol_dir: Path,
    boltz2: bool,
    use_msa_server: bool,
    msa_server_url: str,
    msa_pairing_strategy: str,
    max_msa_seqs: int,
    processed_msa_dir: Path,
    processed_constraints_dir: Path,
    processed_templates_dir: Path,
    processed_mols_dir: Path,
    structure_dir: Path,
    records_dir: Path,
) -> None:
    """Parse one input file and dump all of its processed artifacts.

    Parses the file as FASTA or YAML, generates missing protein MSAs through
    the MSA server (when allowed), converts MSAs to ``.npz`` and writes the
    templates, constraints, extra molecules, structure and record to their
    output directories.

    Any exception is caught: its traceback is printed and the input is
    skipped, so one bad file does not abort a batch.

    Parameters
    ----------
    path : Path
        The ``.fa``/``.fas``/``.fasta`` or ``.yml``/``.yaml`` input file.
    ccd : dict
        The loaded CCD data, forwarded to the parsers.
    msa_dir : Path
        Directory where generated raw MSAs (``.csv``) are written.
    mol_dir : Path
        Molecule directory, forwarded to the parsers.
    boltz2 : bool
        Whether to parse for Boltz-2.
    use_msa_server : bool
        Whether missing protein MSAs may be generated with the MSA server.
    msa_server_url : str
        The MSA server URL.
    msa_pairing_strategy : str
        The MSA pairing strategy.
    max_msa_seqs : int
        Max number of MSA sequences kept when parsing MSAs.
    processed_msa_dir, processed_constraints_dir, processed_templates_dir,
    processed_mols_dir, structure_dir, records_dir : Path
        Output directories for the respective processed artifacts.

    """
    try:
        # Parse data
        if path.suffix in (".fa", ".fas", ".fasta"):
            target = parse_fasta(path, ccd, mol_dir, boltz2)
        elif path.suffix in (".yml", ".yaml"):
            target = parse_yaml(path, ccd, mol_dir, boltz2)
        elif path.is_dir():
            msg = f"Found directory {path} instead of .fasta or .yaml, skipping."
            raise RuntimeError(msg)  # noqa: TRY301
        else:
            msg = (
                f"Unable to parse filetype {path.suffix}, "
                "please provide a .fasta or .yaml file."
            )
            raise RuntimeError(msg)  # noqa: TRY301

        # Get target id
        target_id = target.record.id

        # Get all MSA ids and decide whether to generate MSA
        to_generate = {}
        prot_id = const.chain_type_ids["PROTEIN"]
        for chain in target.record.chains:
            # Add to generate list, assigning entity id
            if (chain.mol_type == prot_id) and (chain.msa_id == 0):
                entity_id = chain.entity_id
                msa_id = f"{target_id}_{entity_id}"
                to_generate[msa_id] = target.sequences[entity_id]
                chain.msa_id = msa_dir / f"{msa_id}.csv"

            # We do not support msa generation for non-protein chains
            elif chain.msa_id == 0:
                chain.msa_id = -1

        # Generate MSA
        if to_generate and not use_msa_server:
            msg = "Missing MSA's in input and --use_msa_server flag not set."
            raise RuntimeError(msg)  # noqa: TRY301

        if to_generate:
            msg = f"Generating MSA for {path} with {len(to_generate)} protein entities."
            click.echo(msg)
            compute_msa(
                data=to_generate,
                target_id=target_id,
                msa_dir=msa_dir,
                msa_server_url=msa_server_url,
                msa_pairing_strategy=msa_pairing_strategy,
            )

        # Parse MSA data. msa_id values can mix Path objects (assigned above)
        # with values set by the parser (presumably str paths for
        # user-supplied MSAs -- TODO confirm). Path and str are not mutually
        # orderable, so sort on the string form to avoid a TypeError.
        msas = sorted(
            {c.msa_id for c in target.record.chains if c.msa_id != -1},
            key=str,
        )
        msa_id_map = {}
        for msa_idx, msa_id in enumerate(msas):
            # Check that raw MSA exists
            msa_path = Path(msa_id)
            if not msa_path.exists():
                msg = f"MSA file {msa_path} not found."
                raise FileNotFoundError(msg)  # noqa: TRY301

            # Dump processed MSA
            processed = processed_msa_dir / f"{target_id}_{msa_idx}.npz"
            msa_id_map[msa_id] = f"{target_id}_{msa_idx}"
            if not processed.exists():
                # Parse A3M
                if msa_path.suffix == ".a3m":
                    msa: MSA = parse_a3m(
                        msa_path,
                        taxonomy=None,
                        max_seqs=max_msa_seqs,
                    )
                elif msa_path.suffix == ".csv":
                    msa: MSA = parse_csv(msa_path, max_seqs=max_msa_seqs)
                else:
                    msg = f"MSA file {msa_path} not supported, only a3m or csv."
                    raise RuntimeError(msg)  # noqa: TRY301

                msa.dump(processed)

        # Modify records to point to processed MSA
        for c in target.record.chains:
            if (c.msa_id != -1) and (c.msa_id in msa_id_map):
                c.msa_id = msa_id_map[c.msa_id]

        # Dump templates
        for template_id, template in target.templates.items():
            name = f"{target.record.id}_{template_id}.npz"
            template_path = processed_templates_dir / name
            template.dump(template_path)

        # Dump constraints
        constraints_path = processed_constraints_dir / f"{target.record.id}.npz"
        target.residue_constraints.dump(constraints_path)

        # Dump extra molecules
        Chem.SetDefaultPickleProperties(Chem.PropertyPickleOptions.AllProps)
        with (processed_mols_dir / f"{target.record.id}.pkl").open("wb") as f:
            pickle.dump(target.extra_mols, f)

        # Dump structure
        struct_path = structure_dir / f"{target.record.id}.npz"
        target.structure.dump(struct_path)

        # Dump record
        record_path = records_dir / f"{target.record.id}.json"
        target.record.dump(record_path)

    except Exception as e:  # noqa: BLE001
        # Best-effort batch processing: report and skip this input.
        import traceback

        traceback.print_exc()
        print(f"Failed to process {path}. Skipping. Error: {e}.")  # noqa: T201
|
612
|
+
|
613
|
+
|
614
|
+
@rank_zero_only
def process_inputs(
    data: list[Path],
    out_dir: Path,
    ccd_path: Path,
    mol_dir: Path,
    msa_server_url: str,
    msa_pairing_strategy: str,
    max_msa_seqs: int = 8192,
    use_msa_server: bool = False,
    boltz2: bool = False,
    preprocessing_threads: int = 1,
) -> Manifest:
    """Process the input data and output directory.

    Inputs whose file stem matches an existing record in
    ``out_dir/processed/records`` are skipped. If nothing is left to process,
    the manifest is refreshed and returned right away. Otherwise the CCD
    data is loaded and the remaining inputs are processed (in a process pool
    when more than one thread and input are available). Finally the manifest
    of all records is written to ``out_dir/processed/manifest.json``.

    Parameters
    ----------
    data : list[Path]
        The input data.
    out_dir : Path
        The output directory.
    ccd_path : Path
        The path to the CCD dictionary (loaded when ``boltz2`` is False).
    mol_dir : Path
        The molecule directory (canonicals are loaded from it when ``boltz2``
        is True; it is also forwarded to the parsers).
    msa_server_url : str
        The MSA server URL.
    msa_pairing_strategy : str
        The MSA pairing strategy.
    max_msa_seqs : int, optional
        Max number of MSA sequences, by default 8192.
    use_msa_server : bool, optional
        Whether to use the MMSeqs2 server for MSA generation, by default False.
    boltz2: bool, optional
        Whether to use Boltz2, by default False.
    preprocessing_threads: int, optional
        The number of threads to use for preprocessing, by default 1.

    Returns
    -------
    Manifest
        The manifest of the processed input data. Because of the
        ``rank_zero_only`` decorator, callers on other ranks should not rely
        on the return value.

    """
    # Output layout
    msa_dir = out_dir / "msa"
    records_dir = out_dir / "processed" / "records"
    structure_dir = out_dir / "processed" / "structures"
    processed_msa_dir = out_dir / "processed" / "msa"
    processed_constraints_dir = out_dir / "processed" / "constraints"
    processed_templates_dir = out_dir / "processed" / "templates"
    processed_mols_dir = out_dir / "processed" / "mols"
    predictions_dir = out_dir / "predictions"
    manifest_path = out_dir / "processed" / "manifest.json"

    # Check if records exist at output path and drop already-processed inputs
    all_processed = False
    existing: list[Record] = []
    if records_dir.exists():
        existing = [Record.load(p) for p in records_dir.glob("*.json")]
        processed_ids = {record.id for record in existing}

        # Filter to missing only
        data = [d for d in data if d.stem not in processed_ids]

        if data:
            click.echo(
                f"Found {len(existing)} existing processed inputs, skipping them."
            )
        else:
            click.echo("All inputs are already processed.")
            all_processed = True

    # Create output directories
    for directory in (
        out_dir,
        msa_dir,
        records_dir,
        structure_dir,
        processed_msa_dir,
        processed_constraints_dir,
        processed_templates_dir,
        processed_mols_dir,
        predictions_dir,
    ):
        directory.mkdir(parents=True, exist_ok=True)

    # Nothing to do: update the manifest and return without loading the CCD
    # data or starting any workers.
    if all_processed:
        updated_manifest = Manifest(existing)
        updated_manifest.dump(manifest_path)
        return updated_manifest

    # Load CCD
    if boltz2:
        ccd = load_canonicals(mol_dir)
    else:
        with ccd_path.open("rb") as file:
            ccd = pickle.load(file)  # noqa: S301

    # Create partial function
    process_input_partial = partial(
        process_input,
        ccd=ccd,
        msa_dir=msa_dir,
        mol_dir=mol_dir,
        boltz2=boltz2,
        use_msa_server=use_msa_server,
        msa_server_url=msa_server_url,
        msa_pairing_strategy=msa_pairing_strategy,
        max_msa_seqs=max_msa_seqs,
        processed_msa_dir=processed_msa_dir,
        processed_constraints_dir=processed_constraints_dir,
        processed_templates_dir=processed_templates_dir,
        processed_mols_dir=processed_mols_dir,
        structure_dir=structure_dir,
        records_dir=records_dir,
    )

    # Parse input data
    preprocessing_threads = min(preprocessing_threads, len(data))
    click.echo(f"Processing {len(data)} inputs with {preprocessing_threads} threads.")

    if preprocessing_threads > 1 and len(data) > 1:
        with Pool(preprocessing_threads) as pool:
            list(tqdm(pool.imap(process_input_partial, data), total=len(data)))
    else:
        for path in tqdm(data):
            process_input_partial(path)

    # Load all records and write manifest
    records = [Record.load(p) for p in records_dir.glob("*.json")]
    manifest = Manifest(records)
    manifest.dump(manifest_path)
    return manifest
|
733
|
+
|
734
|
+
|
735
|
+
@click.group()
def cli() -> None:
    """Boltz."""
    # Click group entry point; sub-commands attach via ``@cli.command()``.
    # The docstring above is used by click as the group's help text, so it
    # is intentionally left unchanged.
|
739
|
+
|
740
|
+
|
741
|
+
@cli.command()
|
742
|
+
@click.argument("data", type=click.Path(exists=True))
|
743
|
+
@click.option(
|
744
|
+
"--out_dir",
|
745
|
+
type=click.Path(exists=False),
|
746
|
+
help="The path where to save the predictions.",
|
747
|
+
default="./",
|
748
|
+
)
|
749
|
+
@click.option(
|
750
|
+
"--cache",
|
751
|
+
type=click.Path(exists=False),
|
752
|
+
help=(
|
753
|
+
"The directory where to download the data and model. "
|
754
|
+
"Default is ~/.boltz, or $BOLTZ_CACHE if set."
|
755
|
+
),
|
756
|
+
default=get_cache_path,
|
757
|
+
)
|
758
|
+
@click.option(
|
759
|
+
"--checkpoint",
|
760
|
+
type=click.Path(exists=True),
|
761
|
+
help="An optional checkpoint, will use the provided Boltz-1 model by default.",
|
762
|
+
default=None,
|
763
|
+
)
|
764
|
+
@click.option(
|
765
|
+
"--devices",
|
766
|
+
type=int,
|
767
|
+
help="The number of devices to use for prediction. Default is 1.",
|
768
|
+
default=1,
|
769
|
+
)
|
770
|
+
@click.option(
|
771
|
+
"--accelerator",
|
772
|
+
type=click.Choice(["gpu", "cpu", "tpu"]),
|
773
|
+
help="The accelerator to use for prediction. Default is gpu.",
|
774
|
+
default="gpu",
|
775
|
+
)
|
776
|
+
@click.option(
|
777
|
+
"--recycling_steps",
|
778
|
+
type=int,
|
779
|
+
help="The number of recycling steps to use for prediction. Default is 3.",
|
780
|
+
default=3,
|
781
|
+
)
|
782
|
+
@click.option(
|
783
|
+
"--sampling_steps",
|
784
|
+
type=int,
|
785
|
+
help="The number of sampling steps to use for prediction. Default is 200.",
|
786
|
+
default=200,
|
787
|
+
)
|
788
|
+
@click.option(
|
789
|
+
"--diffusion_samples",
|
790
|
+
type=int,
|
791
|
+
help="The number of diffusion samples to use for prediction. Default is 1.",
|
792
|
+
default=1,
|
793
|
+
)
|
794
|
+
@click.option(
|
795
|
+
"--max_parallel_samples",
|
796
|
+
type=int,
|
797
|
+
help="The maximum number of samples to predict in parallel. Default is None.",
|
798
|
+
default=5,
|
799
|
+
)
|
800
|
+
@click.option(
|
801
|
+
"--step_scale",
|
802
|
+
type=float,
|
803
|
+
help=(
|
804
|
+
"The step size is related to the temperature at "
|
805
|
+
"which the diffusion process samples the distribution. "
|
806
|
+
"The lower the higher the diversity among samples "
|
807
|
+
"(recommended between 1 and 2). "
|
808
|
+
"Default is 1.638 for Boltz-1 and 1.5 for Boltz-2. "
|
809
|
+
"If not provided, the default step size will be used."
|
810
|
+
),
|
811
|
+
default=None,
|
812
|
+
)
|
813
|
+
@click.option(
|
814
|
+
"--write_full_pae",
|
815
|
+
type=bool,
|
816
|
+
is_flag=True,
|
817
|
+
help="Whether to dump the pae into a npz file. Default is True.",
|
818
|
+
)
|
819
|
+
@click.option(
|
820
|
+
"--write_full_pde",
|
821
|
+
type=bool,
|
822
|
+
is_flag=True,
|
823
|
+
help="Whether to dump the pde into a npz file. Default is False.",
|
824
|
+
)
|
825
|
+
@click.option(
|
826
|
+
"--output_format",
|
827
|
+
type=click.Choice(["pdb", "mmcif"]),
|
828
|
+
help="The output format to use for the predictions. Default is mmcif.",
|
829
|
+
default="mmcif",
|
830
|
+
)
|
831
|
+
@click.option(
|
832
|
+
"--num_workers",
|
833
|
+
type=int,
|
834
|
+
help="The number of dataloader workers to use for prediction. Default is 2.",
|
835
|
+
default=2,
|
836
|
+
)
|
837
|
+
@click.option(
|
838
|
+
"--override",
|
839
|
+
is_flag=True,
|
840
|
+
help="Whether to override existing found predictions. Default is False.",
|
841
|
+
)
|
842
|
+
@click.option(
|
843
|
+
"--seed",
|
844
|
+
type=int,
|
845
|
+
help="Seed to use for random number generator. Default is None (no seeding).",
|
846
|
+
default=None,
|
847
|
+
)
|
848
|
+
@click.option(
|
849
|
+
"--use_msa_server",
|
850
|
+
is_flag=True,
|
851
|
+
help="Whether to use the MMSeqs2 server for MSA generation. Default is False.",
|
852
|
+
)
|
853
|
+
@click.option(
|
854
|
+
"--msa_server_url",
|
855
|
+
type=str,
|
856
|
+
help="MSA server url. Used only if --use_msa_server is set. ",
|
857
|
+
default="https://api.colabfold.com",
|
858
|
+
)
|
859
|
+
@click.option(
|
860
|
+
"--msa_pairing_strategy",
|
861
|
+
type=str,
|
862
|
+
help=(
|
863
|
+
"Pairing strategy to use. Used only if --use_msa_server is set. "
|
864
|
+
"Options are 'greedy' and 'complete'"
|
865
|
+
),
|
866
|
+
default="greedy",
|
867
|
+
)
|
868
|
+
@click.option(
|
869
|
+
"--use_potentials",
|
870
|
+
is_flag=True,
|
871
|
+
help="Whether to not use potentials for steering. Default is False.",
|
872
|
+
)
|
873
|
+
@click.option(
|
874
|
+
"--model",
|
875
|
+
default="boltz2",
|
876
|
+
type=click.Choice(["boltz1", "boltz2"]),
|
877
|
+
help="The model to use for prediction. Default is boltz2.",
|
878
|
+
)
|
879
|
+
@click.option(
|
880
|
+
"--method",
|
881
|
+
type=str,
|
882
|
+
help="The method to use for prediction. Default is None.",
|
883
|
+
default=None,
|
884
|
+
)
|
885
|
+
@click.option(
|
886
|
+
"--preprocessing-threads",
|
887
|
+
type=int,
|
888
|
+
help="The number of threads to use for preprocessing. Default is 1.",
|
889
|
+
default=multiprocessing.cpu_count(),
|
890
|
+
)
|
891
|
+
@click.option(
|
892
|
+
"--affinity_mw_correction",
|
893
|
+
is_flag=True,
|
894
|
+
type=bool,
|
895
|
+
help="Whether to add the Molecular Weight correction to the affinity value head.",
|
896
|
+
)
|
897
|
+
@click.option(
|
898
|
+
"--sampling_steps_affinity",
|
899
|
+
type=int,
|
900
|
+
help="The number of sampling steps to use for affinity prediction. Default is 200.",
|
901
|
+
default=200,
|
902
|
+
)
|
903
|
+
@click.option(
|
904
|
+
"--diffusion_samples_affinity",
|
905
|
+
type=int,
|
906
|
+
help="The number of diffusion samples to use for affinity prediction. Default is 5.",
|
907
|
+
default=5,
|
908
|
+
)
|
909
|
+
@click.option(
|
910
|
+
"--affinity_checkpoint",
|
911
|
+
type=click.Path(exists=True),
|
912
|
+
help="An optional checkpoint, will use the provided Boltz-1 model by default.",
|
913
|
+
default=None,
|
914
|
+
)
|
915
|
+
@click.option(
|
916
|
+
"--max_msa_seqs",
|
917
|
+
type=int,
|
918
|
+
help="The maximum number of MSA sequences to use for prediction. Default is 8192.",
|
919
|
+
default=8192,
|
920
|
+
)
|
921
|
+
@click.option(
|
922
|
+
"--subsample_msa",
|
923
|
+
is_flag=True,
|
924
|
+
help="Whether to subsample the MSA. Default is True.",
|
925
|
+
)
|
926
|
+
@click.option(
|
927
|
+
"--num_subsampled_msa",
|
928
|
+
type=int,
|
929
|
+
help="The number of MSA sequences to subsample. Default is 1024.",
|
930
|
+
default=1024,
|
931
|
+
)
|
932
|
+
@click.option(
|
933
|
+
"--no_kernels",
|
934
|
+
is_flag=True,
|
935
|
+
help="Whether to disable the kernels. Default False",
|
936
|
+
)
|
937
|
+
def predict( # noqa: C901, PLR0915, PLR0912
|
938
|
+
data: str,
|
939
|
+
out_dir: str,
|
940
|
+
cache: str = "~/.boltz",
|
941
|
+
checkpoint: Optional[str] = None,
|
942
|
+
affinity_checkpoint: Optional[str] = None,
|
943
|
+
devices: int = 1,
|
944
|
+
accelerator: str = "gpu",
|
945
|
+
recycling_steps: int = 3,
|
946
|
+
sampling_steps: int = 200,
|
947
|
+
diffusion_samples: int = 1,
|
948
|
+
sampling_steps_affinity: int = 200,
|
949
|
+
diffusion_samples_affinity: int = 3,
|
950
|
+
max_parallel_samples: Optional[int] = None,
|
951
|
+
step_scale: Optional[float] = None,
|
952
|
+
write_full_pae: bool = False,
|
953
|
+
write_full_pde: bool = False,
|
954
|
+
output_format: Literal["pdb", "mmcif"] = "mmcif",
|
955
|
+
num_workers: int = 2,
|
956
|
+
override: bool = False,
|
957
|
+
seed: Optional[int] = None,
|
958
|
+
use_msa_server: bool = False,
|
959
|
+
msa_server_url: str = "https://api.colabfold.com",
|
960
|
+
msa_pairing_strategy: str = "greedy",
|
961
|
+
use_potentials: bool = False,
|
962
|
+
model: Literal["boltz1", "boltz2"] = "boltz2",
|
963
|
+
method: Optional[str] = None,
|
964
|
+
affinity_mw_correction: Optional[bool] = False,
|
965
|
+
preprocessing_threads: int = 1,
|
966
|
+
max_msa_seqs: int = 8192,
|
967
|
+
subsample_msa: bool = True,
|
968
|
+
num_subsampled_msa: int = 1024,
|
969
|
+
no_kernels: bool = False,
|
970
|
+
) -> None:
|
971
|
+
"""Run predictions with Boltz."""
|
972
|
+
# If cpu, write a friendly warning
|
973
|
+
if accelerator == "cpu":
|
974
|
+
msg = "Running on CPU, this will be slow. Consider using a GPU."
|
975
|
+
click.echo(msg)
|
976
|
+
|
977
|
+
# Supress some lightning warnings
|
978
|
+
warnings.filterwarnings(
|
979
|
+
"ignore", ".*that has Tensor Cores. To properly utilize them.*"
|
980
|
+
)
|
981
|
+
|
982
|
+
# Set no grad
|
983
|
+
torch.set_grad_enabled(False)
|
984
|
+
|
985
|
+
# Ignore matmul precision warning
|
986
|
+
torch.set_float32_matmul_precision("highest")
|
987
|
+
|
988
|
+
# Set rdkit pickle logic
|
989
|
+
Chem.SetDefaultPickleProperties(Chem.PropertyPickleOptions.AllProps)
|
990
|
+
|
991
|
+
# Set seed if desired
|
992
|
+
if seed is not None:
|
993
|
+
seed_everything(seed)
|
994
|
+
|
995
|
+
for key in ["CUEQ_DEFAULT_CONFIG", "CUEQ_DISABLE_AOT_TUNING"]:
|
996
|
+
# Disable kernel tuning by default,
|
997
|
+
# but do not modify envvar if already set by caller
|
998
|
+
os.environ[key] = os.environ.get(key, "1")
|
999
|
+
|
1000
|
+
# Set cache path
|
1001
|
+
cache = Path(cache).expanduser()
|
1002
|
+
cache.mkdir(parents=True, exist_ok=True)
|
1003
|
+
|
1004
|
+
# Create output directories
|
1005
|
+
data = Path(data).expanduser()
|
1006
|
+
out_dir = Path(out_dir).expanduser()
|
1007
|
+
out_dir = out_dir / f"boltz_results_{data.stem}"
|
1008
|
+
out_dir.mkdir(parents=True, exist_ok=True)
|
1009
|
+
|
1010
|
+
# Download necessary data and model
|
1011
|
+
if model == "boltz1":
|
1012
|
+
download_boltz1(cache)
|
1013
|
+
elif model == "boltz2":
|
1014
|
+
download_boltz2(cache)
|
1015
|
+
else:
|
1016
|
+
msg = f"Model {model} not supported. Supported: boltz1, boltz2."
|
1017
|
+
raise ValueError(f"Model {model} not supported.")
|
1018
|
+
|
1019
|
+
# Validate inputs
|
1020
|
+
data = check_inputs(data)
|
1021
|
+
|
1022
|
+
# Check method
|
1023
|
+
if method is not None:
|
1024
|
+
if model == "boltz1":
|
1025
|
+
msg = "Method conditioning is not supported for Boltz-1."
|
1026
|
+
raise ValueError(msg)
|
1027
|
+
if method.lower() not in const.method_types_ids:
|
1028
|
+
method_names = list(const.method_types_ids.keys())
|
1029
|
+
msg = f"Method {method} not supported. Supported: {method_names}"
|
1030
|
+
raise ValueError(msg)
|
1031
|
+
|
1032
|
+
# Process inputs
|
1033
|
+
ccd_path = cache / "ccd.pkl"
|
1034
|
+
mol_dir = cache / "mols"
|
1035
|
+
process_inputs(
|
1036
|
+
data=data,
|
1037
|
+
out_dir=out_dir,
|
1038
|
+
ccd_path=ccd_path,
|
1039
|
+
mol_dir=mol_dir,
|
1040
|
+
use_msa_server=use_msa_server,
|
1041
|
+
msa_server_url=msa_server_url,
|
1042
|
+
msa_pairing_strategy=msa_pairing_strategy,
|
1043
|
+
boltz2=model == "boltz2",
|
1044
|
+
preprocessing_threads=preprocessing_threads,
|
1045
|
+
max_msa_seqs=max_msa_seqs,
|
1046
|
+
)
|
1047
|
+
|
1048
|
+
# Load manifest
|
1049
|
+
manifest = Manifest.load(out_dir / "processed" / "manifest.json")
|
1050
|
+
|
1051
|
+
# Filter out existing predictions
|
1052
|
+
filtered_manifest = filter_inputs_structure(
|
1053
|
+
manifest=manifest,
|
1054
|
+
outdir=out_dir,
|
1055
|
+
override=override,
|
1056
|
+
)
|
1057
|
+
|
1058
|
+
# Load processed data
|
1059
|
+
processed_dir = out_dir / "processed"
|
1060
|
+
processed = BoltzProcessedInput(
|
1061
|
+
manifest=filtered_manifest,
|
1062
|
+
targets_dir=processed_dir / "structures",
|
1063
|
+
msa_dir=processed_dir / "msa",
|
1064
|
+
constraints_dir=(
|
1065
|
+
(processed_dir / "constraints")
|
1066
|
+
if (processed_dir / "constraints").exists()
|
1067
|
+
else None
|
1068
|
+
),
|
1069
|
+
template_dir=(
|
1070
|
+
(processed_dir / "templates")
|
1071
|
+
if (processed_dir / "templates").exists()
|
1072
|
+
else None
|
1073
|
+
),
|
1074
|
+
extra_mols_dir=(
|
1075
|
+
(processed_dir / "mols") if (processed_dir / "mols").exists() else None
|
1076
|
+
),
|
1077
|
+
)
|
1078
|
+
|
1079
|
+
# # Set up trainer
|
1080
|
+
# strategy = "auto"
|
1081
|
+
# if (isinstance(devices, int) and devices > 1) or (
|
1082
|
+
# isinstance(devices, list) and len(devices) > 1
|
1083
|
+
# ):
|
1084
|
+
# start_method = "fork" if platform.system() != "win32" else "spawn"
|
1085
|
+
# strategy = DDPStrategy(start_method=start_method)
|
1086
|
+
# if len(filtered_manifest.records) < devices:
|
1087
|
+
# msg = (
|
1088
|
+
# "Number of requested devices is greater "
|
1089
|
+
# "than the number of predictions, taking the minimum."
|
1090
|
+
# )
|
1091
|
+
# click.echo(msg)
|
1092
|
+
# if isinstance(devices, list):
|
1093
|
+
# devices = devices[: max(1, len(filtered_manifest.records))]
|
1094
|
+
# else:
|
1095
|
+
# devices = max(1, min(len(filtered_manifest.records), devices))
|
1096
|
+
|
1097
|
+
# # Set up model parameters
|
1098
|
+
# if model == "boltz2":
|
1099
|
+
# diffusion_params = Boltz2DiffusionParams()
|
1100
|
+
# step_scale = 1.5 if step_scale is None else step_scale
|
1101
|
+
# diffusion_params.step_scale = step_scale
|
1102
|
+
# pairformer_args = PairformerArgsV2()
|
1103
|
+
# else:
|
1104
|
+
# diffusion_params = BoltzDiffusionParams()
|
1105
|
+
# step_scale = 1.638 if step_scale is None else step_scale
|
1106
|
+
# diffusion_params.step_scale = step_scale
|
1107
|
+
# pairformer_args = PairformerArgs()
|
1108
|
+
|
1109
|
+
# msa_args = MSAModuleArgs(
|
1110
|
+
# subsample_msa=subsample_msa,
|
1111
|
+
# num_subsampled_msa=num_subsampled_msa,
|
1112
|
+
# use_paired_feature=model == "boltz2",
|
1113
|
+
# )
|
1114
|
+
|
1115
|
+
# # Create prediction writer
|
1116
|
+
# pred_writer = BoltzWriter(
|
1117
|
+
# data_dir=processed.targets_dir,
|
1118
|
+
# output_dir=out_dir / "predictions",
|
1119
|
+
# output_format=output_format,
|
1120
|
+
# boltz2=model == "boltz2",
|
1121
|
+
# )
|
1122
|
+
|
1123
|
+
# # Set up trainer
|
1124
|
+
# trainer = Trainer(
|
1125
|
+
# default_root_dir=out_dir,
|
1126
|
+
# strategy=strategy,
|
1127
|
+
# callbacks=[pred_writer],
|
1128
|
+
# accelerator=accelerator,
|
1129
|
+
# devices=devices,
|
1130
|
+
# precision=32 if model == "boltz1" else "bf16-mixed",
|
1131
|
+
# )
|
1132
|
+
|
1133
|
+
# if filtered_manifest.records:
|
1134
|
+
# msg = f"Running structure prediction for {len(filtered_manifest.records)} input"
|
1135
|
+
# msg += "s." if len(filtered_manifest.records) > 1 else "."
|
1136
|
+
# click.echo(msg)
|
1137
|
+
|
1138
|
+
# # Create data module
|
1139
|
+
# if model == "boltz2":
|
1140
|
+
# data_module = Boltz2InferenceDataModule(
|
1141
|
+
# manifest=processed.manifest,
|
1142
|
+
# target_dir=processed.targets_dir,
|
1143
|
+
# msa_dir=processed.msa_dir,
|
1144
|
+
# mol_dir=mol_dir,
|
1145
|
+
# num_workers=num_workers,
|
1146
|
+
# constraints_dir=processed.constraints_dir,
|
1147
|
+
# template_dir=processed.template_dir,
|
1148
|
+
# extra_mols_dir=processed.extra_mols_dir,
|
1149
|
+
# override_method=method,
|
1150
|
+
# )
|
1151
|
+
# else:
|
1152
|
+
# data_module = BoltzInferenceDataModule(
|
1153
|
+
# manifest=processed.manifest,
|
1154
|
+
# target_dir=processed.targets_dir,
|
1155
|
+
# msa_dir=processed.msa_dir,
|
1156
|
+
# num_workers=num_workers,
|
1157
|
+
# constraints_dir=processed.constraints_dir,
|
1158
|
+
# )
|
1159
|
+
|
1160
|
+
# # Load model
|
1161
|
+
# if checkpoint is None:
|
1162
|
+
# if model == "boltz2":
|
1163
|
+
# checkpoint = cache / "boltz2_conf.ckpt"
|
1164
|
+
# else:
|
1165
|
+
# checkpoint = cache / "boltz1_conf.ckpt"
|
1166
|
+
|
1167
|
+
# predict_args = {
|
1168
|
+
# "recycling_steps": recycling_steps,
|
1169
|
+
# "sampling_steps": sampling_steps,
|
1170
|
+
# "diffusion_samples": diffusion_samples,
|
1171
|
+
# "max_parallel_samples": max_parallel_samples,
|
1172
|
+
# "write_confidence_summary": True,
|
1173
|
+
# "write_full_pae": write_full_pae,
|
1174
|
+
# "write_full_pde": write_full_pde,
|
1175
|
+
# }
|
1176
|
+
|
1177
|
+
# steering_args = BoltzSteeringParams()
|
1178
|
+
# steering_args.fk_steering = use_potentials
|
1179
|
+
# steering_args.guidance_update = use_potentials
|
1180
|
+
|
1181
|
+
# model_cls = Boltz2 if model == "boltz2" else Boltz1
|
1182
|
+
# model_module = model_cls.load_from_checkpoint(
|
1183
|
+
# checkpoint,
|
1184
|
+
# strict=True,
|
1185
|
+
# predict_args=predict_args,
|
1186
|
+
# map_location="cpu",
|
1187
|
+
# diffusion_process_args=asdict(diffusion_params),
|
1188
|
+
# ema=False,
|
1189
|
+
# use_kernels=not no_kernels,
|
1190
|
+
# pairformer_args=asdict(pairformer_args),
|
1191
|
+
# msa_args=asdict(msa_args),
|
1192
|
+
# steering_args=asdict(steering_args),
|
1193
|
+
# )
|
1194
|
+
# model_module.eval()
|
1195
|
+
|
1196
|
+
# # Compute structure predictions
|
1197
|
+
# trainer.predict(
|
1198
|
+
# model_module,
|
1199
|
+
# datamodule=data_module,
|
1200
|
+
# return_predictions=False,
|
1201
|
+
# )
|
1202
|
+
|
1203
|
+
# # Check if affinity predictions are needed
|
1204
|
+
# if any(r.affinity for r in manifest.records):
|
1205
|
+
# # Print header
|
1206
|
+
# click.echo("\nPredicting property: affinity\n")
|
1207
|
+
|
1208
|
+
# # Validate inputs
|
1209
|
+
# manifest_filtered = filter_inputs_affinity(
|
1210
|
+
# manifest=manifest,
|
1211
|
+
# outdir=out_dir,
|
1212
|
+
# override=override,
|
1213
|
+
# )
|
1214
|
+
# if not manifest_filtered.records:
|
1215
|
+
# click.echo("Found existing affinity predictions for all inputs, skipping.")
|
1216
|
+
# return
|
1217
|
+
|
1218
|
+
# msg = f"Running affinity prediction for {len(manifest_filtered.records)} input"
|
1219
|
+
# msg += "s." if len(manifest_filtered.records) > 1 else "."
|
1220
|
+
# click.echo(msg)
|
1221
|
+
|
1222
|
+
# pred_writer = BoltzAffinityWriter(
|
1223
|
+
# data_dir=processed.targets_dir,
|
1224
|
+
# output_dir=out_dir / "predictions",
|
1225
|
+
# )
|
12
1226
|
|
13
|
-
|
14
|
-
|
1227
|
+
# data_module = Boltz2InferenceDataModule(
|
1228
|
+
# manifest=manifest_filtered,
|
1229
|
+
# target_dir=out_dir / "predictions",
|
1230
|
+
# msa_dir=processed.msa_dir,
|
1231
|
+
# mol_dir=mol_dir,
|
1232
|
+
# num_workers=num_workers,
|
1233
|
+
# constraints_dir=processed.constraints_dir,
|
1234
|
+
# template_dir=processed.template_dir,
|
1235
|
+
# extra_mols_dir=processed.extra_mols_dir,
|
1236
|
+
# override_method="other",
|
1237
|
+
# affinity=True,
|
1238
|
+
# )
|
15
1239
|
|
1240
|
+
# predict_affinity_args = {
|
1241
|
+
# "recycling_steps": 5,
|
1242
|
+
# "sampling_steps": sampling_steps_affinity,
|
1243
|
+
# "diffusion_samples": diffusion_samples_affinity,
|
1244
|
+
# "max_parallel_samples": 1,
|
1245
|
+
# "write_confidence_summary": False,
|
1246
|
+
# "write_full_pae": False,
|
1247
|
+
# "write_full_pde": False,
|
1248
|
+
# }
|
16
1249
|
|
17
|
-
|
18
|
-
|
19
|
-
|
20
|
-
parser.add_argument('--ligand_id', type=str, required=True, help='Ligand ID (used in output filename)')
|
21
|
-
parser.add_argument('--output', type=str, required=True, help='Output .npz file path')
|
22
|
-
args = parser.parse_args()
|
1250
|
+
# # Load affinity model
|
1251
|
+
# if affinity_checkpoint is None:
|
1252
|
+
# affinity_checkpoint = cache / "boltz2_aff.ckpt"
|
23
1253
|
|
24
|
-
|
25
|
-
|
1254
|
+
# model_module = Boltz2.load_from_checkpoint(
|
1255
|
+
# affinity_checkpoint,
|
1256
|
+
# strict=True,
|
1257
|
+
# predict_args=predict_affinity_args,
|
1258
|
+
# map_location="cpu",
|
1259
|
+
# diffusion_process_args=asdict(diffusion_params),
|
1260
|
+
# ema=False,
|
1261
|
+
# pairformer_args=asdict(pairformer_args),
|
1262
|
+
# msa_args=asdict(msa_args),
|
1263
|
+
# steering_args={"fk_steering": False, "guidance_update": False},
|
1264
|
+
# affinity_mw_correction=affinity_mw_correction,
|
1265
|
+
# )
|
1266
|
+
# model_module.eval()
|
26
1267
|
|
27
|
-
#
|
28
|
-
|
1268
|
+
# trainer.callbacks[0] = pred_writer
|
1269
|
+
# trainer.predict(
|
1270
|
+
# model_module,
|
1271
|
+
# datamodule=data_module,
|
1272
|
+
# return_predictions=False,
|
1273
|
+
# )
|
29
1274
|
|
30
|
-
# Save as pre_affinity_[ligand_id].npz
|
31
|
-
structure.dump(output_path)
|
32
|
-
print(f"Saved: {output_path}")
|
33
1275
|
|
34
1276
|
if __name__ == "__main__":
|
35
|
-
|
1277
|
+
cli()
|
@@ -108,11 +108,11 @@ boltz/model/potentials/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3
|
|
108
108
|
boltz/model/potentials/potentials.py,sha256=vev8Vjfs-ML1hyrdv_R8DynG4wSFahJ6nzPWp7CYQqw,17507
|
109
109
|
boltz/model/potentials/schedules.py,sha256=m7XJjfuF9uTX3bR9VisXv1rvzJjxiD8PobXRpcBBu1c,968
|
110
110
|
boltz/utils/sdf_splitter.py,sha256=ZHn_syOcmm-fDnJ3YEGyGv_vYz2IRzUW7vbbMSU2JBY,2108
|
111
|
-
boltz/utils/sdf_to_pre_affinity_npz.py,sha256=
|
111
|
+
boltz/utils/sdf_to_pre_affinity_npz.py,sha256=ENljAVhA7ZtDkUCp1xuJufTyVbuaQHZAe_vAl6ck-WE,40301
|
112
112
|
boltz/utils/yaml_generator.py,sha256=ermWIG-BE6nNWHFvpEwpk92N9J-YATpGXZGLvD1I2oQ,4012
|
113
|
-
boltz_vsynthes-0.0.
|
114
|
-
boltz_vsynthes-0.0.
|
115
|
-
boltz_vsynthes-0.0.
|
116
|
-
boltz_vsynthes-0.0.
|
117
|
-
boltz_vsynthes-0.0.
|
118
|
-
boltz_vsynthes-0.0.
|
113
|
+
boltz_vsynthes-0.0.13.dist-info/licenses/LICENSE,sha256=8GZ_1eZsUeG6jdqgJJxtciWzADfgLEV4LY8sKUOsJhc,1102
|
114
|
+
boltz_vsynthes-0.0.13.dist-info/METADATA,sha256=-ZCZOOLVwXKOYf2E0XY8Q1ZAlHQgZFp0Qj4_mPfeIqU,7235
|
115
|
+
boltz_vsynthes-0.0.13.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
116
|
+
boltz_vsynthes-0.0.13.dist-info/entry_points.txt,sha256=nZNYPKKrmAr-MVA0K-ClNRT2p90FV1_14d7HpsESZFQ,211
|
117
|
+
boltz_vsynthes-0.0.13.dist-info/top_level.txt,sha256=MgU3Jfb-ctWm07YGMts68PMjSh9v26D0gfG3dFRmVFA,6
|
118
|
+
boltz_vsynthes-0.0.13.dist-info/RECORD,,
|
@@ -1,5 +1,5 @@
|
|
1
1
|
[console_scripts]
|
2
2
|
boltz = boltz.main:cli
|
3
3
|
boltz-generate-yaml = boltz.utils.yaml_generator:main
|
4
|
-
boltz-sdf-to-pre-affinity = boltz.utils.sdf_to_pre_affinity_npz:
|
4
|
+
boltz-sdf-to-pre-affinity = boltz.utils.sdf_to_pre_affinity_npz:cli
|
5
5
|
boltz-split-sdf = boltz.utils.sdf_splitter:main
|
File without changes
|
File without changes
|
File without changes
|