boltz-vsynthes 0.0.12__py3-none-any.whl → 0.0.14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,35 +1,1286 @@
1
- #!/usr/bin/env python3
2
- """
3
- boltz-sdf-to-pre-affinity: Convert a docked SDF file (protein-ligand complex) to Boltz pre_affinity_[ligand_id].npz format.
1
+ import multiprocessing
2
+ import os
3
+ import pickle
4
+ import platform
5
+ import tarfile
6
+ import urllib.request
7
+ import warnings
8
+ from dataclasses import asdict, dataclass
9
+ from functools import partial
10
+ from multiprocessing import Pool
11
+ from pathlib import Path
12
+ from typing import Literal, Optional
4
13
 
5
- Usage:
6
- boltz-sdf-to-pre-affinity --sdf docked_pose.sdf --ligand_id ligand1 --output pre_affinity_ligand1.npz
14
+ import click
15
+ import torch
16
+ from pytorch_lightning import Trainer, seed_everything
17
+ from pytorch_lightning.strategies import DDPStrategy
18
+ from pytorch_lightning.utilities import rank_zero_only
19
+ from rdkit import Chem
20
+ from tqdm import tqdm
7
21
 
8
- This command is available after installing boltz with pip.
9
- """
10
- import argparse
11
- from pathlib import Path
22
+ from boltz.data import const
23
+ from boltz.data.module.inference import BoltzInferenceDataModule
24
+ from boltz.data.module.inferencev2 import Boltz2InferenceDataModule
25
+ from boltz.data.mol import load_canonicals
26
+ from boltz.data.msa.mmseqs2 import run_mmseqs2
27
+ from boltz.data.parse.a3m import parse_a3m
28
+ from boltz.data.parse.csv import parse_csv
29
+ from boltz.data.parse.fasta import parse_fasta
30
+ from boltz.data.parse.yaml import parse_yaml
31
+ from boltz.data.types import MSA, Manifest, Record
32
+ from boltz.data.write.writer import BoltzAffinityWriter, BoltzWriter
33
+ from boltz.model.models.boltz1 import Boltz1
34
+ from boltz.model.models.boltz2 import Boltz2
35
+
36
# Remote locations of the static assets that are fetched into the cache.
CCD_URL = "https://huggingface.co/boltz-community/boltz-1/resolve/main/ccd.pkl"
MOL_URL = "https://huggingface.co/boltz-community/boltz-2/resolve/main/mols.tar"

# Checkpoint mirrors. The download helpers walk each list in order and only
# give up once the last entry has failed.
BOLTZ1_URL_WITH_FALLBACK = [
    "https://model-gateway.boltz.bio/boltz1_conf.ckpt",
    "https://huggingface.co/boltz-community/boltz-1/resolve/main/boltz1_conf.ckpt",
]

BOLTZ2_URL_WITH_FALLBACK = [
    "https://model-gateway.boltz.bio/boltz2_conf.ckpt",
    "https://huggingface.co/boltz-community/boltz-2/resolve/main/boltz2_conf.ckpt",
]

BOLTZ2_AFFINITY_URL_WITH_FALLBACK = [
    "https://model-gateway.boltz.bio/boltz2_aff.ckpt",
    "https://huggingface.co/boltz-community/boltz-2/resolve/main/boltz2_aff.ckpt",
]
53
+
54
+
55
@dataclass
class BoltzProcessedInput:
    """Bundle of the manifest and the directories holding processed inputs.

    Only ``manifest``, ``targets_dir`` and ``msa_dir`` are required; the
    remaining directories default to ``None`` when they are not provided.
    """

    # Required locations.
    manifest: Manifest
    targets_dir: Path
    msa_dir: Path

    # Optional locations, ``None`` when absent.
    constraints_dir: Optional[Path] = None
    template_dir: Optional[Path] = None
    extra_mols_dir: Optional[Path] = None
65
+
66
+
67
@dataclass
class PairformerArgs:
    """Pairformer settings with ``v2`` disabled (48 blocks by default).

    Every field has a default, so ``PairformerArgs()`` is a complete
    configuration that can be turned into a plain dict with ``asdict``.
    """

    num_blocks: int = 48
    num_heads: int = 16
    dropout: float = 0.0

    # Both switches are off unless explicitly enabled.
    activation_checkpointing: bool = False
    offload_to_cpu: bool = False

    # This variant always defaults to the non-v2 setting.
    v2: bool = False
77
+
78
+
79
@dataclass
class PairformerArgsV2:
    """Pairformer settings with ``v2`` enabled (64 blocks by default).

    Every field has a default, so ``PairformerArgsV2()`` is a complete
    configuration that can be turned into a plain dict with ``asdict``.
    """

    num_blocks: int = 64
    num_heads: int = 16
    dropout: float = 0.0

    # Both switches are off unless explicitly enabled.
    activation_checkpointing: bool = False
    offload_to_cpu: bool = False

    # This variant defaults to the v2 setting.
    v2: bool = True
89
+
90
+
91
@dataclass
class MSAModuleArgs:
    """Default settings for the MSA module.

    Every field has a default, so ``MSAModuleArgs()`` is a complete
    configuration that can be turned into a plain dict with ``asdict``.
    """

    msa_s: int = 64
    msa_blocks: int = 4

    # Dropout rates, disabled by default.
    msa_dropout: float = 0.0
    z_dropout: float = 0.0

    use_paired_feature: bool = True
    pairwise_head_width: int = 32
    pairwise_num_heads: int = 4

    # Both switches are off unless explicitly enabled.
    activation_checkpointing: bool = False
    offload_to_cpu: bool = False

    # MSA subsampling is off by default; 1024 is the default subsample size.
    subsample_msa: bool = False
    num_subsampled_msa: int = 1024
106
+
107
+
108
@dataclass
class BoltzDiffusionParams:
    """Default diffusion-process parameters (``step_scale`` 1.638).

    Every field has a default, so ``BoltzDiffusionParams()`` is a complete
    parameter set that can be turned into a plain dict with ``asdict``.
    """

    gamma_0: float = 0.605
    gamma_min: float = 1.107
    noise_scale: float = 0.901
    rho: float = 8
    step_scale: float = 1.638

    # Sigma-related values.
    sigma_min: float = 0.0004
    sigma_max: float = 160.0
    sigma_data: float = 16.0

    P_mean: float = -1.2
    P_std: float = 1.5

    # Boolean switches, all enabled by default.
    coordinate_augmentation: bool = True
    alignment_reverse_diff: bool = True
    synchronize_sigmas: bool = True
    use_inference_model_cache: bool = True
126
+
127
+
128
@dataclass
class Boltz2DiffusionParams:
    """Default diffusion-process parameters (``step_scale`` 1.5).

    Every field has a default, so ``Boltz2DiffusionParams()`` is a complete
    parameter set that can be turned into a plain dict with ``asdict``.
    """

    gamma_0: float = 0.8
    gamma_min: float = 1.0
    noise_scale: float = 1.003
    rho: float = 7
    step_scale: float = 1.5

    # Sigma-related values.
    sigma_min: float = 0.0001
    sigma_max: float = 160.0
    sigma_data: float = 16.0

    P_mean: float = -1.2
    P_std: float = 1.5

    # Boolean switches, all enabled by default.
    coordinate_augmentation: bool = True
    alignment_reverse_diff: bool = True
    synchronize_sigmas: bool = True
145
+
146
+
147
@dataclass
class BoltzSteeringParams:
    """Default steering parameters.

    Every field has a default, so ``BoltzSteeringParams()`` is a complete
    parameter set that can be turned into a plain dict with ``asdict``.
    """

    # ``fk_*`` settings; steering is on by default.
    fk_steering: bool = True
    num_particles: int = 3
    fk_lambda: float = 4.0
    fk_resampling_interval: int = 3

    # Guidance settings; updates are on by default.
    guidance_update: bool = True
    num_gd_steps: int = 20
157
+
158
+
159
def _download_atomic(url: str, dest: Path) -> None:
    """Download ``url`` to ``dest`` without ever exposing a partial file.

    The payload is written to a sibling ``<name>.part`` file and renamed onto
    ``dest`` only after the transfer has finished. An interrupted download
    therefore never leaves a truncated file that a later ``dest.exists()``
    check would mistake for a complete one.

    Parameters
    ----------
    url : str
        The URL to fetch.
    dest : Path
        The final location of the downloaded file.

    Raises
    ------
    Exception
        Whatever ``urllib.request.urlretrieve`` or ``os.replace`` raises; the
        temporary file is removed before the error propagates.

    """
    tmp_path = dest.with_name(f"{dest.name}.part")
    try:
        urllib.request.urlretrieve(url, str(tmp_path))  # noqa: S310
        os.replace(tmp_path, dest)
    except BaseException:
        # Also clean up on KeyboardInterrupt, then let the error propagate.
        tmp_path.unlink(missing_ok=True)
        raise


def _download_with_fallback(urls: list[str], dest: Path) -> None:
    """Download ``dest`` from the first URL in ``urls`` that succeeds.

    Parameters
    ----------
    urls : list[str]
        Candidate URLs, tried in order.
    dest : Path
        The final location of the downloaded file.

    Raises
    ------
    RuntimeError
        If the last URL fails as well; chained to that last error.

    """
    for i, url in enumerate(urls):
        try:
            _download_atomic(url, dest)
        except Exception as e:  # noqa: BLE001
            if i == len(urls) - 1:
                msg = f"Failed to download model from all URLs. Last error: {e}"
                raise RuntimeError(msg) from e
        else:
            return


@rank_zero_only
def download_boltz1(cache: Path) -> None:
    """Download the CCD dictionary and the Boltz-1 weights if missing.

    Parameters
    ----------
    cache : Path
        The cache directory.

    Raises
    ------
    RuntimeError
        If the model weights cannot be downloaded from any URL.

    """
    # Download CCD
    ccd = cache / "ccd.pkl"
    if not ccd.exists():
        click.echo(
            f"Downloading the CCD dictionary to {ccd}. You may "
            "change the cache directory with the --cache flag."
        )
        _download_atomic(CCD_URL, ccd)

    # Download model
    model = cache / "boltz1_conf.ckpt"
    if not model.exists():
        click.echo(
            f"Downloading the model weights to {model}. You may "
            "change the cache directory with the --cache flag."
        )
        _download_with_fallback(BOLTZ1_URL_WITH_FALLBACK, model)


@rank_zero_only
def download_boltz2(cache: Path) -> None:
    """Download the CCD molecules and both Boltz-2 checkpoints if missing.

    Parameters
    ----------
    cache : Path
        The cache directory.

    Raises
    ------
    RuntimeError
        If a checkpoint cannot be downloaded from any URL.

    """
    # Download CCD
    mols = cache / "mols"
    tar_mols = cache / "mols.tar"
    if not tar_mols.exists():
        click.echo(
            f"Downloading the CCD data to {tar_mols}. "
            "This may take a bit of time. You may change the cache directory "
            "with the --cache flag."
        )
        _download_atomic(MOL_URL, tar_mols)
    if not mols.exists():
        click.echo(
            f"Extracting the CCD data to {mols}. "
            "This may take a bit of time. You may change the cache directory "
            "with the --cache flag."
        )
        with tarfile.open(str(tar_mols), "r") as tar:
            # The archive comes from the network: use the "data" extraction
            # filter where the running Python provides it, so members cannot
            # escape the cache directory.
            if hasattr(tarfile, "data_filter"):
                tar.extractall(cache, filter="data")
            else:
                tar.extractall(cache)  # noqa: S202

    # Download model
    model = cache / "boltz2_conf.ckpt"
    if not model.exists():
        click.echo(
            f"Downloading the Boltz-2 weights to {model}. You may "
            "change the cache directory with the --cache flag."
        )
        _download_with_fallback(BOLTZ2_URL_WITH_FALLBACK, model)

    # Download affinity model
    affinity_model = cache / "boltz2_aff.ckpt"
    if not affinity_model.exists():
        click.echo(
            f"Downloading the Boltz-2 affinity weights to {affinity_model}. You may "
            "change the cache directory with the --cache flag."
        )
        _download_with_fallback(BOLTZ2_AFFINITY_URL_WITH_FALLBACK, affinity_model)
258
+
259
+
260
def get_cache_path() -> str:
    """Determine the cache path, prioritising the BOLTZ_CACHE environment variable.

    Returns
    -------
    str
        Path to use for boltz cache location: the resolved ``BOLTZ_CACHE``
        value when that variable is set and non-empty, ``~/.boltz`` (with the
        home directory expanded) otherwise.

    Raises
    ------
    ValueError
        If ``BOLTZ_CACHE`` is set to a relative path.

    """
    env_cache = os.environ.get("BOLTZ_CACHE")
    if not env_cache:
        return str(Path("~/.boltz").expanduser())

    # Validate BEFORE resolving: ``resolve()`` always yields an absolute path,
    # so checking afterwards could never reject a relative value.
    expanded_cache = Path(env_cache).expanduser()
    if not expanded_cache.is_absolute():
        msg = f"BOLTZ_CACHE must be an absolute path, got: {env_cache}"
        raise ValueError(msg)
    return str(expanded_cache.resolve())
278
+
279
+
280
def check_inputs(data: Path) -> list[Path]:
    """Expand the input path into the list of files to process.

    A single file is returned as a one-element list without inspecting its
    suffix. A directory is expanded to its direct entries, each of which must
    be a FASTA or YAML file.

    Parameters
    ----------
    data : Path
        The input file or directory.

    Returns
    -------
    list[Path]
        The list of input files; sorted by path when ``data`` is a directory.

    Raises
    ------
    RuntimeError
        If the directory contains a sub-directory or a file whose suffix is
        not one of ``.fa``, ``.fas``, ``.fasta``, ``.yml``, ``.yaml``.

    """
    click.echo("Checking input data.")

    if not data.is_dir():
        return [data]

    # ``glob`` yields entries in filesystem order; sort them so the processing
    # order is reproducible across machines.
    inputs = sorted(data.glob("*"))

    # Raise an error on directories and unsupported file types
    for d in inputs:
        if d.is_dir():
            msg = f"Found directory {d} instead of .fasta or .yaml."
            raise RuntimeError(msg)
        if d.suffix not in (".fa", ".fas", ".fasta", ".yml", ".yaml"):
            msg = (
                f"Unable to parse filetype {d.suffix}, "
                "please provide a .fasta or .yaml file."
            )
            raise RuntimeError(msg)

    return inputs
316
+
317
+
318
def filter_inputs_structure(
    manifest: Manifest,
    outdir: Path,
    override: bool = False,
) -> Manifest:
    """Drop records whose prediction directory already exists.

    Parameters
    ----------
    manifest : Manifest
        The manifest of the input data.
    outdir : Path
        The output directory; ``outdir / "predictions"`` is scanned
        recursively for directory names.
    override: bool
        When set, existing predictions are reported but nothing is dropped.

    Returns
    -------
    Manifest
        The input manifest when nothing was found or ``override`` is set,
        otherwise a new manifest limited to the records without a match.

    """
    # Collect the names of every directory below the predictions folder
    done = set()
    for entry in (outdir / "predictions").rglob("*"):
        if entry.is_dir():
            done.add(entry.name)

    if not done:
        return manifest

    if override:
        click.echo(f"Found {len(done)} existing predictions, will override.")
        return manifest

    # Keep only the records that have no matching directory yet
    remaining = Manifest(
        [record for record in manifest.records if record.id not in done]
    )
    click.echo(
        f"Found some existing predictions ({len(done)}), "
        f"skipping and running only the missing ones, "
        "if any. If you wish to override these existing "
        "predictions, please set the --override flag."
    )
    return remaining
359
+
360
+
361
def filter_inputs_affinity(
    manifest: Manifest,
    outdir: Path,
    override: bool = False,
) -> Manifest:
    """Filter out records that already have an affinity prediction.

    A record counts as done when it has ``affinity`` set and
    ``outdir / "predictions" / <id> / affinity_<id>.json`` exists.

    Parameters
    ----------
    manifest : Manifest
        The manifest.
    outdir : Path
        The output directory.
    override: bool
        Whether to override existing predictions. When set, existing
        predictions are reported but every record is kept.

    Returns
    -------
    Manifest
        The manifest of the filtered input data.

    """
    click.echo("Checking input data for affinity.")

    # Get all affinity targets that already have an output file
    existing = {
        r.id
        for r in manifest.records
        if r.affinity
        and (outdir / "predictions" / r.id / f"affinity_{r.id}.json").exists()
    }

    if not existing:
        return manifest

    if override:
        # Keep every record: the existing predictions are to be recomputed.
        # (Previously they were still dropped, so --override had no effect.)
        msg = "Found existing affinity predictions, will override."
        click.echo(msg)
        return manifest

    # Remove them from the input data
    num_skipped = len(existing)
    msg = (
        f"Found some existing affinity predictions ({num_skipped}), "
        f"skipping and running only the missing ones, "
        "if any. If you wish to override these existing "
        "affinity predictions, please set the --override flag."
    )
    click.echo(msg)
    return Manifest([r for r in manifest.records if r.id not in existing])
408
+
409
+
410
def compute_msa(
    data: dict[str, str],
    target_id: str,
    msa_dir: Path,
    msa_server_url: str,
    msa_pairing_strategy: str,
) -> None:
    """Query the MSA server and write one ``<name>.csv`` per entry of ``data``.

    Each CSV has a ``key,sequence`` header. Paired rows carry their row index
    as key, unpaired rows carry ``-1``.

    Parameters
    ----------
    data : dict[str, str]
        The input protein sequences, keyed by the name used for the CSV file.
    target_id : str
        The target id, used to name the temporary server directories.
    msa_dir : Path
        The msa directory the CSV files are written to.
    msa_server_url : str
        The MSA server URL.
    msa_pairing_strategy : str
        The MSA pairing strategy.

    """
    # A paired search is only requested when there is more than one sequence
    if len(data) > 1:
        paired_msas = run_mmseqs2(
            list(data.values()),
            msa_dir / f"{target_id}_paired_tmp",
            use_env=True,
            use_pairing=True,
            host_url=msa_server_url,
            pairing_strategy=msa_pairing_strategy,
        )
    else:
        paired_msas = [""] * len(data)

    unpaired_msa = run_mmseqs2(
        list(data.values()),
        msa_dir / f"{target_id}_unpaired_tmp",
        use_env=True,
        use_pairing=False,
        host_url=msa_server_url,
        pairing_strategy=msa_pairing_strategy,
    )

    for position, name in enumerate(data):
        # Every second line is a sequence, the others are headers
        paired_rows = paired_msas[position].strip().splitlines()[1::2]
        paired_rows = paired_rows[: const.max_paired_seqs]

        # Remember the row index of each sequence that is not all gaps
        kept = [
            (row, seq)
            for row, seq in enumerate(paired_rows)
            if seq != "-" * len(seq)
        ]
        paired = [seq for _, seq in kept]

        # Fill the remaining budget with unpaired sequences
        unpaired = unpaired_msa[position].strip().splitlines()[1::2]
        unpaired = unpaired[: (const.max_msa_seqs - len(paired))]
        if paired:
            # The query is already present in the paired rows
            unpaired = unpaired[1:]

        # Assemble the CSV: paired rows first, then unpaired rows keyed -1
        lines = ["key,sequence"]
        lines.extend(f"{row},{seq}" for row, seq in kept)
        lines.extend(f"-1,{seq}" for seq in unpaired)

        with (msa_dir / f"{name}.csv").open("w") as f:
            f.write("\n".join(lines))
481
+
482
+
483
def process_input(  # noqa: C901, PLR0912, PLR0915
    path: Path,
    ccd: dict,
    msa_dir: Path,
    mol_dir: Path,
    boltz2: bool,
    use_msa_server: bool,
    msa_server_url: str,
    msa_pairing_strategy: str,
    max_msa_seqs: int,
    processed_msa_dir: Path,
    processed_constraints_dir: Path,
    processed_templates_dir: Path,
    processed_mols_dir: Path,
    structure_dir: Path,
    records_dir: Path,
) -> None:
    """Parse one input file and dump all of its processed artefacts.

    Any exception is caught, printed with its traceback, and the input is
    skipped, so one bad file does not abort a batch.

    Parameters
    ----------
    path : Path
        The FASTA (``.fa``/``.fas``/``.fasta``) or YAML (``.yml``/``.yaml``)
        input file.
    ccd : dict
        Forwarded to ``parse_fasta`` / ``parse_yaml``.
    msa_dir : Path
        Directory where generated raw MSAs (``<target>_<entity>.csv``) live.
    mol_dir : Path
        Forwarded to ``parse_fasta`` / ``parse_yaml``.
    boltz2 : bool
        Forwarded to ``parse_fasta`` / ``parse_yaml``.
    use_msa_server : bool
        Whether missing MSAs may be generated through the MSA server.
    msa_server_url : str
        The MSA server URL.
    msa_pairing_strategy : str
        The MSA pairing strategy.
    max_msa_seqs : int
        Maximum number of sequences kept when parsing an MSA.
    processed_msa_dir : Path
        Output directory for processed MSAs (``.npz``).
    processed_constraints_dir : Path
        Output directory for residue constraints (``.npz``).
    processed_templates_dir : Path
        Output directory for templates (``.npz``).
    processed_mols_dir : Path
        Output directory for the pickled extra molecules.
    structure_dir : Path
        Output directory for structures (``.npz``).
    records_dir : Path
        Output directory for records (``.json``).

    """
    try:
        # Parse data
        if path.suffix in (".fa", ".fas", ".fasta"):
            target = parse_fasta(path, ccd, mol_dir, boltz2)
        elif path.suffix in (".yml", ".yaml"):
            target = parse_yaml(path, ccd, mol_dir, boltz2)
        elif path.is_dir():
            msg = f"Found directory {path} instead of .fasta or .yaml, skipping."
            raise RuntimeError(msg)  # noqa: TRY301
        else:
            msg = (
                f"Unable to parse filetype {path.suffix}, "
                "please provide a .fasta or .yaml file."
            )
            raise RuntimeError(msg)  # noqa: TRY301

        # Get target id
        target_id = target.record.id

        # Get all MSA ids and decide whether to generate MSA
        to_generate = {}
        prot_id = const.chain_type_ids["PROTEIN"]
        for chain in target.record.chains:
            # Add to generate list, assigning entity id
            if (chain.mol_type == prot_id) and (chain.msa_id == 0):
                entity_id = chain.entity_id
                msa_id = f"{target_id}_{entity_id}"
                to_generate[msa_id] = target.sequences[entity_id]
                # Store the path as ``str``. A ``Path`` here would make the
                # ``sorted`` call below raise ``TypeError`` if another chain
                # carries its MSA path as a string, since the two types
                # cannot be ordered against each other.
                chain.msa_id = str(msa_dir / f"{msa_id}.csv")

            # We do not support msa generation for non-protein chains
            elif chain.msa_id == 0:
                chain.msa_id = -1

        # Generate MSA
        if to_generate and not use_msa_server:
            msg = "Missing MSA's in input and --use_msa_server flag not set."
            raise RuntimeError(msg)  # noqa: TRY301

        if to_generate:
            msg = f"Generating MSA for {path} with {len(to_generate)} protein entities."
            click.echo(msg)
            compute_msa(
                data=to_generate,
                target_id=target_id,
                msa_dir=msa_dir,
                msa_server_url=msa_server_url,
                msa_pairing_strategy=msa_pairing_strategy,
            )

        # Parse MSA data
        msas = sorted({c.msa_id for c in target.record.chains if c.msa_id != -1})
        msa_id_map = {}
        for msa_idx, msa_id in enumerate(msas):
            # Check that raw MSA exists
            msa_path = Path(msa_id)
            if not msa_path.exists():
                msg = f"MSA file {msa_path} not found."
                raise FileNotFoundError(msg)  # noqa: TRY301

            # Dump processed MSA
            processed = processed_msa_dir / f"{target_id}_{msa_idx}.npz"
            msa_id_map[msa_id] = f"{target_id}_{msa_idx}"
            if not processed.exists():
                # Parse A3M
                if msa_path.suffix == ".a3m":
                    msa: MSA = parse_a3m(
                        msa_path,
                        taxonomy=None,
                        max_seqs=max_msa_seqs,
                    )
                elif msa_path.suffix == ".csv":
                    msa: MSA = parse_csv(msa_path, max_seqs=max_msa_seqs)
                else:
                    msg = f"MSA file {msa_path} not supported, only a3m or csv."
                    raise RuntimeError(msg)  # noqa: TRY301

                msa.dump(processed)

        # Modify records to point to processed MSA
        for c in target.record.chains:
            if (c.msa_id != -1) and (c.msa_id in msa_id_map):
                c.msa_id = msa_id_map[c.msa_id]

        # Dump templates
        for template_id, template in target.templates.items():
            name = f"{target.record.id}_{template_id}.npz"
            template_path = processed_templates_dir / name
            template.dump(template_path)

        # Dump constraints
        constraints_path = processed_constraints_dir / f"{target.record.id}.npz"
        target.residue_constraints.dump(constraints_path)

        # Dump extra molecules
        Chem.SetDefaultPickleProperties(Chem.PropertyPickleOptions.AllProps)
        with (processed_mols_dir / f"{target.record.id}.pkl").open("wb") as f:
            pickle.dump(target.extra_mols, f)

        # Dump structure
        struct_path = structure_dir / f"{target.record.id}.npz"
        target.structure.dump(struct_path)

        # Dump record last; its presence marks the input as processed
        record_path = records_dir / f"{target.record.id}.json"
        target.record.dump(record_path)

    except Exception as e:  # noqa: BLE001
        # Deliberate best-effort: report the failure and skip this input.
        import traceback

        traceback.print_exc()
        print(f"Failed to process {path}. Skipping. Error: {e}.")  # noqa: T201
612
+
613
+
614
@rank_zero_only
def process_inputs(
    data: list[Path],
    out_dir: Path,
    ccd_path: Path,
    mol_dir: Path,
    msa_server_url: str,
    msa_pairing_strategy: str,
    max_msa_seqs: int = 8192,
    use_msa_server: bool = False,
    boltz2: bool = False,
    preprocessing_threads: int = 1,
) -> Manifest:
    """Process the input data and write the manifest to the output directory.

    Inputs whose file stem matches the id of an already stored record are
    skipped. The manifest is written to ``out_dir/processed/manifest.json``.

    Parameters
    ----------
    data : list[Path]
        The input data.
    out_dir : Path
        The output directory.
    ccd_path : Path
        The path to the CCD dictionary (pickle), read when ``boltz2`` is False.
    mol_dir : Path
        The molecule directory, loaded when ``boltz2`` is True.
    msa_server_url : str
        The MSA server URL.
    msa_pairing_strategy : str
        The MSA pairing strategy.
    max_msa_seqs : int, optional
        Max number of MSA sequences, by default 8192.
    use_msa_server : bool, optional
        Whether to use the MMSeqs2 server for MSA generation, by default False.
    boltz2: bool, optional
        Whether to use Boltz2, by default False.
    preprocessing_threads: int, optional
        The number of worker processes to use for preprocessing, by default 1.

    Returns
    -------
    Manifest
        The manifest of the processed input data.

    """
    processed_dir = out_dir / "processed"
    manifest_path = processed_dir / "manifest.json"

    # Output directories
    msa_dir = out_dir / "msa"
    records_dir = processed_dir / "records"
    structure_dir = processed_dir / "structures"
    processed_msa_dir = processed_dir / "msa"
    processed_constraints_dir = processed_dir / "constraints"
    processed_templates_dir = processed_dir / "templates"
    processed_mols_dir = processed_dir / "mols"
    predictions_dir = out_dir / "predictions"

    # Check if records exist at output path
    existing: list[Record] = []
    all_processed = False
    if records_dir.exists():
        # Load existing records
        existing = [Record.load(p) for p in records_dir.glob("*.json")]
        processed_ids = {record.id for record in existing}

        # Filter to missing only
        data = [d for d in data if d.stem not in processed_ids]

        if data:
            click.echo(
                f"Found {len(existing)} existing processed inputs, skipping them."
            )
        else:
            click.echo("All inputs are already processed.")
            all_processed = True

    # Create output directories
    for directory in (
        out_dir,
        msa_dir,
        records_dir,
        structure_dir,
        processed_msa_dir,
        processed_constraints_dir,
        processed_templates_dir,
        processed_mols_dir,
        predictions_dir,
    ):
        directory.mkdir(parents=True, exist_ok=True)

    # Nothing to do: update the manifest and return without loading the CCD
    if all_processed:
        updated_manifest = Manifest(existing)
        updated_manifest.dump(manifest_path)
        return updated_manifest

    # Load CCD
    if boltz2:
        ccd = load_canonicals(mol_dir)
    else:
        # NOTE(review): this unpickles a file from the cache directory;
        # pickle must only ever be used on trusted data.
        with ccd_path.open("rb") as file:
            ccd = pickle.load(file)  # noqa: S301

    # Create partial function
    process_input_partial = partial(
        process_input,
        ccd=ccd,
        msa_dir=msa_dir,
        mol_dir=mol_dir,
        boltz2=boltz2,
        use_msa_server=use_msa_server,
        msa_server_url=msa_server_url,
        msa_pairing_strategy=msa_pairing_strategy,
        max_msa_seqs=max_msa_seqs,
        processed_msa_dir=processed_msa_dir,
        processed_constraints_dir=processed_constraints_dir,
        processed_templates_dir=processed_templates_dir,
        processed_mols_dir=processed_mols_dir,
        structure_dir=structure_dir,
        records_dir=records_dir,
    )

    # Parse input data, never starting more workers than there are inputs
    preprocessing_threads = min(preprocessing_threads, len(data))
    click.echo(f"Processing {len(data)} inputs with {preprocessing_threads} threads.")

    if preprocessing_threads > 1 and len(data) > 1:
        with Pool(preprocessing_threads) as pool:
            list(tqdm(pool.imap(process_input_partial, data), total=len(data)))
    else:
        for path in tqdm(data):
            process_input_partial(path)

    # Load all records and write manifest
    records = [Record.load(p) for p in records_dir.glob("*.json")]
    manifest = Manifest(records)
    manifest.dump(manifest_path)
    return manifest
733
+
734
+
735
@click.group()
def cli() -> None:
    """Boltz."""
    # Intentionally empty: the group only hosts the sub-commands.
739
+
740
+
741
+ @cli.command()
742
+ @click.argument("data", type=click.Path(exists=True))
743
+ @click.option(
744
+ "--out_dir",
745
+ type=click.Path(exists=False),
746
+ help="The path where to save the predictions.",
747
+ default="./",
748
+ )
749
+ @click.option(
750
+ "--cache",
751
+ type=click.Path(exists=False),
752
+ help=(
753
+ "The directory where to download the data and model. "
754
+ "Default is ~/.boltz, or $BOLTZ_CACHE if set."
755
+ ),
756
+ default=get_cache_path,
757
+ )
758
+ @click.option(
759
+ "--checkpoint",
760
+ type=click.Path(exists=True),
761
+ help="An optional checkpoint, will use the provided Boltz-1 model by default.",
762
+ default=None,
763
+ )
764
+ @click.option(
765
+ "--devices",
766
+ type=int,
767
+ help="The number of devices to use for prediction. Default is 1.",
768
+ default=1,
769
+ )
770
+ @click.option(
771
+ "--accelerator",
772
+ type=click.Choice(["gpu", "cpu", "tpu"]),
773
+ help="The accelerator to use for prediction. Default is gpu.",
774
+ default="gpu",
775
+ )
776
+ @click.option(
777
+ "--recycling_steps",
778
+ type=int,
779
+ help="The number of recycling steps to use for prediction. Default is 3.",
780
+ default=3,
781
+ )
782
+ @click.option(
783
+ "--sampling_steps",
784
+ type=int,
785
+ help="The number of sampling steps to use for prediction. Default is 200.",
786
+ default=200,
787
+ )
788
+ @click.option(
789
+ "--diffusion_samples",
790
+ type=int,
791
+ help="The number of diffusion samples to use for prediction. Default is 1.",
792
+ default=1,
793
+ )
794
+ @click.option(
795
+ "--max_parallel_samples",
796
+ type=int,
797
+ help="The maximum number of samples to predict in parallel. Default is None.",
798
+ default=5,
799
+ )
800
+ @click.option(
801
+ "--step_scale",
802
+ type=float,
803
+ help=(
804
+ "The step size is related to the temperature at "
805
+ "which the diffusion process samples the distribution. "
806
+ "The lower the higher the diversity among samples "
807
+ "(recommended between 1 and 2). "
808
+ "Default is 1.638 for Boltz-1 and 1.5 for Boltz-2. "
809
+ "If not provided, the default step size will be used."
810
+ ),
811
+ default=None,
812
+ )
813
+ @click.option(
814
+ "--write_full_pae",
815
+ type=bool,
816
+ is_flag=True,
817
+ help="Whether to dump the pae into a npz file. Default is True.",
818
+ )
819
+ @click.option(
820
+ "--write_full_pde",
821
+ type=bool,
822
+ is_flag=True,
823
+ help="Whether to dump the pde into a npz file. Default is False.",
824
+ )
825
+ @click.option(
826
+ "--output_format",
827
+ type=click.Choice(["pdb", "mmcif"]),
828
+ help="The output format to use for the predictions. Default is mmcif.",
829
+ default="mmcif",
830
+ )
831
+ @click.option(
832
+ "--num_workers",
833
+ type=int,
834
+ help="The number of dataloader workers to use for prediction. Default is 2.",
835
+ default=2,
836
+ )
837
+ @click.option(
838
+ "--override",
839
+ is_flag=True,
840
+ help="Whether to override existing found predictions. Default is False.",
841
+ )
842
+ @click.option(
843
+ "--seed",
844
+ type=int,
845
+ help="Seed to use for random number generator. Default is None (no seeding).",
846
+ default=None,
847
+ )
848
+ @click.option(
849
+ "--use_msa_server",
850
+ is_flag=True,
851
+ help="Whether to use the MMSeqs2 server for MSA generation. Default is False.",
852
+ )
853
+ @click.option(
854
+ "--msa_server_url",
855
+ type=str,
856
+ help="MSA server url. Used only if --use_msa_server is set. ",
857
+ default="https://api.colabfold.com",
858
+ )
859
+ @click.option(
860
+ "--msa_pairing_strategy",
861
+ type=str,
862
+ help=(
863
+ "Pairing strategy to use. Used only if --use_msa_server is set. "
864
+ "Options are 'greedy' and 'complete'"
865
+ ),
866
+ default="greedy",
867
+ )
868
+ @click.option(
869
+ "--use_potentials",
870
+ is_flag=True,
871
+ help="Whether to not use potentials for steering. Default is False.",
872
+ )
873
+ @click.option(
874
+ "--model",
875
+ default="boltz2",
876
+ type=click.Choice(["boltz1", "boltz2"]),
877
+ help="The model to use for prediction. Default is boltz2.",
878
+ )
879
+ @click.option(
880
+ "--method",
881
+ type=str,
882
+ help="The method to use for prediction. Default is None.",
883
+ default=None,
884
+ )
885
+ @click.option(
886
+ "--preprocessing-threads",
887
+ type=int,
888
+ help="The number of threads to use for preprocessing. Default is 1.",
889
+ default=multiprocessing.cpu_count(),
890
+ )
891
+ @click.option(
892
+ "--affinity_mw_correction",
893
+ is_flag=True,
894
+ type=bool,
895
+ help="Whether to add the Molecular Weight correction to the affinity value head.",
896
+ )
897
+ @click.option(
898
+ "--sampling_steps_affinity",
899
+ type=int,
900
+ help="The number of sampling steps to use for affinity prediction. Default is 200.",
901
+ default=200,
902
+ )
903
+ @click.option(
904
+ "--diffusion_samples_affinity",
905
+ type=int,
906
+ help="The number of diffusion samples to use for affinity prediction. Default is 5.",
907
+ default=5,
908
+ )
909
+ @click.option(
910
+ "--affinity_checkpoint",
911
+ type=click.Path(exists=True),
912
+ help="An optional checkpoint, will use the provided Boltz-1 model by default.",
913
+ default=None,
914
+ )
915
+ @click.option(
916
+ "--max_msa_seqs",
917
+ type=int,
918
+ help="The maximum number of MSA sequences to use for prediction. Default is 8192.",
919
+ default=8192,
920
+ )
921
+ @click.option(
922
+ "--subsample_msa",
923
+ is_flag=True,
924
+ help="Whether to subsample the MSA. Default is True.",
925
+ )
926
+ @click.option(
927
+ "--num_subsampled_msa",
928
+ type=int,
929
+ help="The number of MSA sequences to subsample. Default is 1024.",
930
+ default=1024,
931
+ )
932
+ @click.option(
933
+ "--no_kernels",
934
+ is_flag=True,
935
+ help="Whether to disable the kernels. Default False",
936
+ )
937
+ def predict( # noqa: C901, PLR0915, PLR0912
938
+ data: str,
939
+ out_dir: str,
940
+ cache: str = "~/.boltz",
941
+ checkpoint: Optional[str] = None,
942
+ affinity_checkpoint: Optional[str] = None,
943
+ devices: int = 1,
944
+ accelerator: str = "gpu",
945
+ recycling_steps: int = 3,
946
+ sampling_steps: int = 200,
947
+ diffusion_samples: int = 1,
948
+ sampling_steps_affinity: int = 200,
949
+ diffusion_samples_affinity: int = 3,
950
+ max_parallel_samples: Optional[int] = None,
951
+ step_scale: Optional[float] = None,
952
+ write_full_pae: bool = False,
953
+ write_full_pde: bool = False,
954
+ output_format: Literal["pdb", "mmcif"] = "mmcif",
955
+ num_workers: int = 2,
956
+ override: bool = False,
957
+ seed: Optional[int] = None,
958
+ use_msa_server: bool = False,
959
+ msa_server_url: str = "https://api.colabfold.com",
960
+ msa_pairing_strategy: str = "greedy",
961
+ use_potentials: bool = False,
962
+ model: Literal["boltz1", "boltz2"] = "boltz2",
963
+ method: Optional[str] = None,
964
+ affinity_mw_correction: Optional[bool] = False,
965
+ preprocessing_threads: int = 1,
966
+ max_msa_seqs: int = 8192,
967
+ subsample_msa: bool = True,
968
+ num_subsampled_msa: int = 1024,
969
+ no_kernels: bool = False,
970
+ ) -> None:
971
+ """Run predictions with Boltz."""
972
+ # If cpu, write a friendly warning
973
+ if accelerator == "cpu":
974
+ msg = "Running on CPU, this will be slow. Consider using a GPU."
975
+ click.echo(msg)
976
+
977
+ # Supress some lightning warnings
978
+ warnings.filterwarnings(
979
+ "ignore", ".*that has Tensor Cores. To properly utilize them.*"
980
+ )
981
+
982
+ # Set no grad
983
+ torch.set_grad_enabled(False)
984
+
985
+ # Ignore matmul precision warning
986
+ torch.set_float32_matmul_precision("highest")
987
+
988
+ # Set rdkit pickle logic
989
+ Chem.SetDefaultPickleProperties(Chem.PropertyPickleOptions.AllProps)
990
+
991
+ # Set seed if desired
992
+ if seed is not None:
993
+ seed_everything(seed)
994
+
995
+ for key in ["CUEQ_DEFAULT_CONFIG", "CUEQ_DISABLE_AOT_TUNING"]:
996
+ # Disable kernel tuning by default,
997
+ # but do not modify envvar if already set by caller
998
+ os.environ[key] = os.environ.get(key, "1")
999
+
1000
+ # Set cache path
1001
+ cache = Path(cache).expanduser()
1002
+ cache.mkdir(parents=True, exist_ok=True)
1003
+
1004
+ # Create output directories
1005
+ data = Path(data).expanduser()
1006
+ out_dir = Path(out_dir).expanduser()
1007
+ out_dir = out_dir / f"boltz_results_{data.stem}"
1008
+ out_dir.mkdir(parents=True, exist_ok=True)
1009
+
1010
+ # Download necessary data and model
1011
+ if model == "boltz1":
1012
+ download_boltz1(cache)
1013
+ elif model == "boltz2":
1014
+ download_boltz2(cache)
1015
+ else:
1016
+ msg = f"Model {model} not supported. Supported: boltz1, boltz2."
1017
+ raise ValueError(f"Model {model} not supported.")
1018
+
1019
+ # Validate inputs
1020
+ data = check_inputs(data)
1021
+
1022
+ # Check method
1023
+ if method is not None:
1024
+ if model == "boltz1":
1025
+ msg = "Method conditioning is not supported for Boltz-1."
1026
+ raise ValueError(msg)
1027
+ if method.lower() not in const.method_types_ids:
1028
+ method_names = list(const.method_types_ids.keys())
1029
+ msg = f"Method {method} not supported. Supported: {method_names}"
1030
+ raise ValueError(msg)
1031
+
1032
+ # Process inputs
1033
+ ccd_path = cache / "ccd.pkl"
1034
+ mol_dir = cache / "mols"
1035
+ process_inputs(
1036
+ data=data,
1037
+ out_dir=out_dir,
1038
+ ccd_path=ccd_path,
1039
+ mol_dir=mol_dir,
1040
+ use_msa_server=use_msa_server,
1041
+ msa_server_url=msa_server_url,
1042
+ msa_pairing_strategy=msa_pairing_strategy,
1043
+ boltz2=model == "boltz2",
1044
+ preprocessing_threads=preprocessing_threads,
1045
+ max_msa_seqs=max_msa_seqs,
1046
+ )
1047
+
1048
+ # Load manifest
1049
+ manifest = Manifest.load(out_dir / "processed" / "manifest.json")
1050
+
1051
+ # Filter out existing predictions
1052
+ filtered_manifest = filter_inputs_structure(
1053
+ manifest=manifest,
1054
+ outdir=out_dir,
1055
+ override=override,
1056
+ )
1057
+
1058
+ # Load processed data
1059
+ processed_dir = out_dir / "processed"
1060
+ processed = BoltzProcessedInput(
1061
+ manifest=filtered_manifest,
1062
+ targets_dir=processed_dir / "structures",
1063
+ msa_dir=processed_dir / "msa",
1064
+ constraints_dir=(
1065
+ (processed_dir / "constraints")
1066
+ if (processed_dir / "constraints").exists()
1067
+ else None
1068
+ ),
1069
+ template_dir=(
1070
+ (processed_dir / "templates")
1071
+ if (processed_dir / "templates").exists()
1072
+ else None
1073
+ ),
1074
+ extra_mols_dir=(
1075
+ (processed_dir / "mols") if (processed_dir / "mols").exists() else None
1076
+ ),
1077
+ )
1078
+
1079
+ # Set up trainer
1080
+ strategy = "auto"
1081
+ if (isinstance(devices, int) and devices > 1) or (
1082
+ isinstance(devices, list) and len(devices) > 1
1083
+ ):
1084
+ start_method = "fork" if platform.system() != "win32" else "spawn"
1085
+ strategy = DDPStrategy(start_method=start_method)
1086
+ if len(filtered_manifest.records) < devices:
1087
+ msg = (
1088
+ "Number of requested devices is greater "
1089
+ "than the number of predictions, taking the minimum."
1090
+ )
1091
+ click.echo(msg)
1092
+ if isinstance(devices, list):
1093
+ devices = devices[: max(1, len(filtered_manifest.records))]
1094
+ else:
1095
+ devices = max(1, min(len(filtered_manifest.records), devices))
1096
+
1097
+ # Set up model parameters
1098
+ if model == "boltz2":
1099
+ diffusion_params = Boltz2DiffusionParams()
1100
+ step_scale = 1.5 if step_scale is None else step_scale
1101
+ diffusion_params.step_scale = step_scale
1102
+ pairformer_args = PairformerArgsV2()
1103
+ else:
1104
+ diffusion_params = BoltzDiffusionParams()
1105
+ step_scale = 1.638 if step_scale is None else step_scale
1106
+ diffusion_params.step_scale = step_scale
1107
+ pairformer_args = PairformerArgs()
1108
+
1109
+ msa_args = MSAModuleArgs(
1110
+ subsample_msa=subsample_msa,
1111
+ num_subsampled_msa=num_subsampled_msa,
1112
+ use_paired_feature=model == "boltz2",
1113
+ )
1114
+
1115
+ # # Create prediction writer
1116
+ # pred_writer = BoltzWriter(
1117
+ # data_dir=processed.targets_dir,
1118
+ # output_dir=out_dir / "predictions",
1119
+ # output_format=output_format,
1120
+ # boltz2=model == "boltz2",
1121
+ # )
1122
+
1123
+ # # Set up trainer
1124
+ # trainer = Trainer(
1125
+ # default_root_dir=out_dir,
1126
+ # strategy=strategy,
1127
+ # callbacks=[pred_writer],
1128
+ # accelerator=accelerator,
1129
+ # devices=devices,
1130
+ # precision=32 if model == "boltz1" else "bf16-mixed",
1131
+ # )
1132
+
1133
+ # if filtered_manifest.records:
1134
+ # msg = f"Running structure prediction for {len(filtered_manifest.records)} input"
1135
+ # msg += "s." if len(filtered_manifest.records) > 1 else "."
1136
+ # click.echo(msg)
1137
+
1138
+ # # Create data module
1139
+ # if model == "boltz2":
1140
+ # data_module = Boltz2InferenceDataModule(
1141
+ # manifest=processed.manifest,
1142
+ # target_dir=processed.targets_dir,
1143
+ # msa_dir=processed.msa_dir,
1144
+ # mol_dir=mol_dir,
1145
+ # num_workers=num_workers,
1146
+ # constraints_dir=processed.constraints_dir,
1147
+ # template_dir=processed.template_dir,
1148
+ # extra_mols_dir=processed.extra_mols_dir,
1149
+ # override_method=method,
1150
+ # )
1151
+ # else:
1152
+ # data_module = BoltzInferenceDataModule(
1153
+ # manifest=processed.manifest,
1154
+ # target_dir=processed.targets_dir,
1155
+ # msa_dir=processed.msa_dir,
1156
+ # num_workers=num_workers,
1157
+ # constraints_dir=processed.constraints_dir,
1158
+ # )
1159
+
1160
+ # # Load model
1161
+ # if checkpoint is None:
1162
+ # if model == "boltz2":
1163
+ # checkpoint = cache / "boltz2_conf.ckpt"
1164
+ # else:
1165
+ # checkpoint = cache / "boltz1_conf.ckpt"
1166
+
1167
+ # predict_args = {
1168
+ # "recycling_steps": recycling_steps,
1169
+ # "sampling_steps": sampling_steps,
1170
+ # "diffusion_samples": diffusion_samples,
1171
+ # "max_parallel_samples": max_parallel_samples,
1172
+ # "write_confidence_summary": True,
1173
+ # "write_full_pae": write_full_pae,
1174
+ # "write_full_pde": write_full_pde,
1175
+ # }
1176
+
1177
+ # steering_args = BoltzSteeringParams()
1178
+ # steering_args.fk_steering = use_potentials
1179
+ # steering_args.guidance_update = use_potentials
1180
+
1181
+ # model_cls = Boltz2 if model == "boltz2" else Boltz1
1182
+ # model_module = model_cls.load_from_checkpoint(
1183
+ # checkpoint,
1184
+ # strict=True,
1185
+ # predict_args=predict_args,
1186
+ # map_location="cpu",
1187
+ # diffusion_process_args=asdict(diffusion_params),
1188
+ # ema=False,
1189
+ # use_kernels=not no_kernels,
1190
+ # pairformer_args=asdict(pairformer_args),
1191
+ # msa_args=asdict(msa_args),
1192
+ # steering_args=asdict(steering_args),
1193
+ # )
1194
+ # model_module.eval()
1195
+
1196
+ # # Compute structure predictions
1197
+ # trainer.predict(
1198
+ # model_module,
1199
+ # datamodule=data_module,
1200
+ # return_predictions=False,
1201
+ # )
1202
+
1203
+ # Check if affinity predictions are needed
1204
+ if any(r.affinity for r in manifest.records):
1205
+ # Print header
1206
+ click.echo("\nPredicting property: affinity\n")
1207
+
1208
+ # Validate inputs
1209
+ manifest_filtered = filter_inputs_affinity(
1210
+ manifest=manifest,
1211
+ outdir=out_dir,
1212
+ override=override,
1213
+ )
1214
+ if not manifest_filtered.records:
1215
+ click.echo("Found existing affinity predictions for all inputs, skipping.")
1216
+ return
1217
+
1218
+ msg = f"Running affinity prediction for {len(manifest_filtered.records)} input"
1219
+ msg += "s." if len(manifest_filtered.records) > 1 else "."
1220
+ click.echo(msg)
1221
+
1222
+ pred_writer = BoltzAffinityWriter(
1223
+ data_dir=processed.targets_dir,
1224
+ output_dir=out_dir / "predictions",
1225
+ )
1226
+
1227
+ trainer = Trainer(
1228
+ default_root_dir=out_dir,
1229
+ strategy=strategy,
1230
+ callbacks=[pred_writer],
1231
+ accelerator=accelerator,
1232
+ devices=devices,
1233
+ precision=32 if model == "boltz1" else "bf16-mixed",
1234
+ )
12
1235
 
13
- from boltz.data.parse.sdf import parse_sdf
14
- from boltz.data.types import StructureV2
1236
+ data_module = Boltz2InferenceDataModule(
1237
+ manifest=manifest_filtered,
1238
+ target_dir=out_dir / "predictions",
1239
+ msa_dir=processed.msa_dir,
1240
+ mol_dir=mol_dir,
1241
+ num_workers=num_workers,
1242
+ constraints_dir=processed.constraints_dir,
1243
+ template_dir=processed.template_dir,
1244
+ extra_mols_dir=processed.extra_mols_dir,
1245
+ override_method="other",
1246
+ affinity=True,
1247
+ )
15
1248
 
1249
+ predict_affinity_args = {
1250
+ "recycling_steps": 5,
1251
+ "sampling_steps": sampling_steps_affinity,
1252
+ "diffusion_samples": diffusion_samples_affinity,
1253
+ "max_parallel_samples": 1,
1254
+ "write_confidence_summary": False,
1255
+ "write_full_pae": False,
1256
+ "write_full_pde": False,
1257
+ }
16
1258
 
17
- def main():
18
- parser = argparse.ArgumentParser(description="Convert SDF to Boltz pre_affinity_*.npz format.")
19
- parser.add_argument('--sdf', type=str, required=True, help='Input SDF file (protein-ligand complex)')
20
- parser.add_argument('--ligand_id', type=str, required=True, help='Ligand ID (used in output filename)')
21
- parser.add_argument('--output', type=str, required=True, help='Output .npz file path')
22
- args = parser.parse_args()
1259
+ # Load affinity model
1260
+ if affinity_checkpoint is None:
1261
+ affinity_checkpoint = cache / "boltz2_aff.ckpt"
23
1262
 
24
- sdf_path = Path(args.sdf)
25
- output_path = Path(args.output)
1263
+ model_module = Boltz2.load_from_checkpoint(
1264
+ affinity_checkpoint,
1265
+ strict=True,
1266
+ predict_args=predict_affinity_args,
1267
+ map_location="cpu",
1268
+ diffusion_process_args=asdict(diffusion_params),
1269
+ ema=False,
1270
+ pairformer_args=asdict(pairformer_args),
1271
+ msa_args=asdict(msa_args),
1272
+ steering_args={"fk_steering": False, "guidance_update": False},
1273
+ affinity_mw_correction=affinity_mw_correction,
1274
+ )
1275
+ model_module.eval()
26
1276
 
27
- # Parse the SDF file to StructureV2
28
- structure: StructureV2 = parse_sdf(sdf_path)
1277
+ trainer.callbacks[0] = pred_writer
1278
+ trainer.predict(
1279
+ model_module,
1280
+ datamodule=data_module,
1281
+ return_predictions=False,
1282
+ )
29
1283
 
30
- # Save as pre_affinity_[ligand_id].npz
31
- structure.dump(output_path)
32
- print(f"Saved: {output_path}")
33
1284
 
34
1285
  if __name__ == "__main__":
35
- main()
1286
+ cli()
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: boltz-vsynthes
3
- Version: 0.0.12
3
+ Version: 0.0.14
4
4
  Summary: Boltz for VSYNTHES
5
5
  Requires-Python: <3.13,>=3.10
6
6
  Description-Content-Type: text/markdown
@@ -108,11 +108,11 @@ boltz/model/potentials/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3
108
108
  boltz/model/potentials/potentials.py,sha256=vev8Vjfs-ML1hyrdv_R8DynG4wSFahJ6nzPWp7CYQqw,17507
109
109
  boltz/model/potentials/schedules.py,sha256=m7XJjfuF9uTX3bR9VisXv1rvzJjxiD8PobXRpcBBu1c,968
110
110
  boltz/utils/sdf_splitter.py,sha256=ZHn_syOcmm-fDnJ3YEGyGv_vYz2IRzUW7vbbMSU2JBY,2108
111
- boltz/utils/sdf_to_pre_affinity_npz.py,sha256=ev0s2pS8NoB-_3BkSjRKk3GMnqjOxTRCH-r9avKYGOg,1212
111
+ boltz/utils/sdf_to_pre_affinity_npz.py,sha256=ro0KGe24JexbJm47J8S8w8Lmr_KaQbzOAb_dKZO2G9I,40384
112
112
  boltz/utils/yaml_generator.py,sha256=ermWIG-BE6nNWHFvpEwpk92N9J-YATpGXZGLvD1I2oQ,4012
113
- boltz_vsynthes-0.0.12.dist-info/licenses/LICENSE,sha256=8GZ_1eZsUeG6jdqgJJxtciWzADfgLEV4LY8sKUOsJhc,1102
114
- boltz_vsynthes-0.0.12.dist-info/METADATA,sha256=XPcHZExavOK9oz8BaZnp31f-acFOXbCKwmDORbDxaP8,7235
115
- boltz_vsynthes-0.0.12.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
116
- boltz_vsynthes-0.0.12.dist-info/entry_points.txt,sha256=WlXNdHTJF1ahGN-H1QpgMbZm_5SobXY05eiuT8buycw,212
117
- boltz_vsynthes-0.0.12.dist-info/top_level.txt,sha256=MgU3Jfb-ctWm07YGMts68PMjSh9v26D0gfG3dFRmVFA,6
118
- boltz_vsynthes-0.0.12.dist-info/RECORD,,
113
+ boltz_vsynthes-0.0.14.dist-info/licenses/LICENSE,sha256=8GZ_1eZsUeG6jdqgJJxtciWzADfgLEV4LY8sKUOsJhc,1102
114
+ boltz_vsynthes-0.0.14.dist-info/METADATA,sha256=hopmZgbr8M8tfFZ3kjGbeYSNvIldCDN1qkNEN4v8ePY,7235
115
+ boltz_vsynthes-0.0.14.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
116
+ boltz_vsynthes-0.0.14.dist-info/entry_points.txt,sha256=nZNYPKKrmAr-MVA0K-ClNRT2p90FV1_14d7HpsESZFQ,211
117
+ boltz_vsynthes-0.0.14.dist-info/top_level.txt,sha256=MgU3Jfb-ctWm07YGMts68PMjSh9v26D0gfG3dFRmVFA,6
118
+ boltz_vsynthes-0.0.14.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  [console_scripts]
2
2
  boltz = boltz.main:cli
3
3
  boltz-generate-yaml = boltz.utils.yaml_generator:main
4
- boltz-sdf-to-pre-affinity = boltz.utils.sdf_to_pre_affinity_npz:main
4
+ boltz-sdf-to-pre-affinity = boltz.utils.sdf_to_pre_affinity_npz:cli
5
5
  boltz-split-sdf = boltz.utils.sdf_splitter:main