boltz-vsynthes 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (112) hide show
  1. boltz/__init__.py +7 -0
  2. boltz/data/__init__.py +0 -0
  3. boltz/data/const.py +1184 -0
  4. boltz/data/crop/__init__.py +0 -0
  5. boltz/data/crop/affinity.py +164 -0
  6. boltz/data/crop/boltz.py +296 -0
  7. boltz/data/crop/cropper.py +45 -0
  8. boltz/data/feature/__init__.py +0 -0
  9. boltz/data/feature/featurizer.py +1230 -0
  10. boltz/data/feature/featurizerv2.py +2208 -0
  11. boltz/data/feature/symmetry.py +602 -0
  12. boltz/data/filter/__init__.py +0 -0
  13. boltz/data/filter/dynamic/__init__.py +0 -0
  14. boltz/data/filter/dynamic/date.py +76 -0
  15. boltz/data/filter/dynamic/filter.py +24 -0
  16. boltz/data/filter/dynamic/max_residues.py +37 -0
  17. boltz/data/filter/dynamic/resolution.py +34 -0
  18. boltz/data/filter/dynamic/size.py +38 -0
  19. boltz/data/filter/dynamic/subset.py +42 -0
  20. boltz/data/filter/static/__init__.py +0 -0
  21. boltz/data/filter/static/filter.py +26 -0
  22. boltz/data/filter/static/ligand.py +37 -0
  23. boltz/data/filter/static/polymer.py +299 -0
  24. boltz/data/module/__init__.py +0 -0
  25. boltz/data/module/inference.py +307 -0
  26. boltz/data/module/inferencev2.py +429 -0
  27. boltz/data/module/training.py +684 -0
  28. boltz/data/module/trainingv2.py +660 -0
  29. boltz/data/mol.py +900 -0
  30. boltz/data/msa/__init__.py +0 -0
  31. boltz/data/msa/mmseqs2.py +235 -0
  32. boltz/data/pad.py +84 -0
  33. boltz/data/parse/__init__.py +0 -0
  34. boltz/data/parse/a3m.py +134 -0
  35. boltz/data/parse/csv.py +100 -0
  36. boltz/data/parse/fasta.py +138 -0
  37. boltz/data/parse/mmcif.py +1239 -0
  38. boltz/data/parse/mmcif_with_constraints.py +1607 -0
  39. boltz/data/parse/schema.py +1851 -0
  40. boltz/data/parse/yaml.py +68 -0
  41. boltz/data/sample/__init__.py +0 -0
  42. boltz/data/sample/cluster.py +283 -0
  43. boltz/data/sample/distillation.py +57 -0
  44. boltz/data/sample/random.py +39 -0
  45. boltz/data/sample/sampler.py +49 -0
  46. boltz/data/tokenize/__init__.py +0 -0
  47. boltz/data/tokenize/boltz.py +195 -0
  48. boltz/data/tokenize/boltz2.py +396 -0
  49. boltz/data/tokenize/tokenizer.py +24 -0
  50. boltz/data/types.py +777 -0
  51. boltz/data/write/__init__.py +0 -0
  52. boltz/data/write/mmcif.py +305 -0
  53. boltz/data/write/pdb.py +171 -0
  54. boltz/data/write/utils.py +23 -0
  55. boltz/data/write/writer.py +330 -0
  56. boltz/main.py +1292 -0
  57. boltz/model/__init__.py +0 -0
  58. boltz/model/layers/__init__.py +0 -0
  59. boltz/model/layers/attention.py +132 -0
  60. boltz/model/layers/attentionv2.py +111 -0
  61. boltz/model/layers/confidence_utils.py +231 -0
  62. boltz/model/layers/dropout.py +34 -0
  63. boltz/model/layers/initialize.py +100 -0
  64. boltz/model/layers/outer_product_mean.py +98 -0
  65. boltz/model/layers/pair_averaging.py +135 -0
  66. boltz/model/layers/pairformer.py +337 -0
  67. boltz/model/layers/relative.py +58 -0
  68. boltz/model/layers/transition.py +78 -0
  69. boltz/model/layers/triangular_attention/__init__.py +0 -0
  70. boltz/model/layers/triangular_attention/attention.py +189 -0
  71. boltz/model/layers/triangular_attention/primitives.py +409 -0
  72. boltz/model/layers/triangular_attention/utils.py +380 -0
  73. boltz/model/layers/triangular_mult.py +212 -0
  74. boltz/model/loss/__init__.py +0 -0
  75. boltz/model/loss/bfactor.py +49 -0
  76. boltz/model/loss/confidence.py +590 -0
  77. boltz/model/loss/confidencev2.py +621 -0
  78. boltz/model/loss/diffusion.py +171 -0
  79. boltz/model/loss/diffusionv2.py +134 -0
  80. boltz/model/loss/distogram.py +48 -0
  81. boltz/model/loss/distogramv2.py +105 -0
  82. boltz/model/loss/validation.py +1025 -0
  83. boltz/model/models/__init__.py +0 -0
  84. boltz/model/models/boltz1.py +1286 -0
  85. boltz/model/models/boltz2.py +1249 -0
  86. boltz/model/modules/__init__.py +0 -0
  87. boltz/model/modules/affinity.py +223 -0
  88. boltz/model/modules/confidence.py +481 -0
  89. boltz/model/modules/confidence_utils.py +181 -0
  90. boltz/model/modules/confidencev2.py +495 -0
  91. boltz/model/modules/diffusion.py +844 -0
  92. boltz/model/modules/diffusion_conditioning.py +116 -0
  93. boltz/model/modules/diffusionv2.py +677 -0
  94. boltz/model/modules/encoders.py +639 -0
  95. boltz/model/modules/encodersv2.py +565 -0
  96. boltz/model/modules/transformers.py +322 -0
  97. boltz/model/modules/transformersv2.py +261 -0
  98. boltz/model/modules/trunk.py +688 -0
  99. boltz/model/modules/trunkv2.py +828 -0
  100. boltz/model/modules/utils.py +303 -0
  101. boltz/model/optim/__init__.py +0 -0
  102. boltz/model/optim/ema.py +389 -0
  103. boltz/model/optim/scheduler.py +99 -0
  104. boltz/model/potentials/__init__.py +0 -0
  105. boltz/model/potentials/potentials.py +497 -0
  106. boltz/model/potentials/schedules.py +32 -0
  107. boltz_vsynthes-1.0.0.dist-info/METADATA +151 -0
  108. boltz_vsynthes-1.0.0.dist-info/RECORD +112 -0
  109. boltz_vsynthes-1.0.0.dist-info/WHEEL +5 -0
  110. boltz_vsynthes-1.0.0.dist-info/entry_points.txt +2 -0
  111. boltz_vsynthes-1.0.0.dist-info/licenses/LICENSE +21 -0
  112. boltz_vsynthes-1.0.0.dist-info/top_level.txt +1 -0
boltz/main.py ADDED
@@ -0,0 +1,1292 @@
1
+ import multiprocessing
2
+ import os
3
+ import pickle
4
+ import platform
5
+ import tarfile
6
+ import urllib.request
7
+ import warnings
8
+ from dataclasses import asdict, dataclass
9
+ from functools import partial
10
+ from multiprocessing import Pool
11
+ from pathlib import Path
12
+ from typing import Literal, Optional
13
+
14
+ import click
15
+ import torch
16
+ from pytorch_lightning import Trainer, seed_everything
17
+ from pytorch_lightning.strategies import DDPStrategy
18
+ from pytorch_lightning.utilities import rank_zero_only
19
+ from rdkit import Chem
20
+ from tqdm import tqdm
21
+
22
+ from boltz.data import const
23
+ from boltz.data.module.inference import BoltzInferenceDataModule
24
+ from boltz.data.module.inferencev2 import Boltz2InferenceDataModule
25
+ from boltz.data.mol import load_canonicals
26
+ from boltz.data.msa.mmseqs2 import run_mmseqs2
27
+ from boltz.data.parse.a3m import parse_a3m
28
+ from boltz.data.parse.csv import parse_csv
29
+ from boltz.data.parse.fasta import parse_fasta
30
+ from boltz.data.parse.yaml import parse_yaml
31
+ from boltz.data.types import MSA, Manifest, Record
32
+ from boltz.data.write.writer import BoltzAffinityWriter, BoltzWriter
33
+ from boltz.model.models.boltz1 import Boltz1
34
+ from boltz.model.models.boltz2 import Boltz2
35
+
36
# Remote location of the pickled CCD dictionary fetched by ``download_boltz1``.
CCD_URL = "https://huggingface.co/boltz-community/boltz-1/resolve/main/ccd.pkl"
# Remote tarball fetched and extracted by ``download_boltz2``.
MOL_URL = "https://huggingface.co/boltz-community/boltz-2/resolve/main/mols.tar"

# Checkpoint URLs are tried in list order; later entries act as fallbacks.
BOLTZ1_URL_WITH_FALLBACK = [
    "https://model-gateway.boltz.bio/boltz1_conf.ckpt",
    "https://huggingface.co/boltz-community/boltz-1/resolve/main/boltz1_conf.ckpt",
]

BOLTZ2_URL_WITH_FALLBACK = [
    "https://model-gateway.boltz.bio/boltz2_conf.ckpt",
    "https://huggingface.co/boltz-community/boltz-2/resolve/main/boltz2_conf.ckpt",
]

BOLTZ2_AFFINITY_URL_WITH_FALLBACK = [
    "https://model-gateway.boltz.bio/boltz2_aff.ckpt",
    "https://huggingface.co/boltz-community/boltz-2/resolve/main/boltz2_aff.ckpt",
]
53
+
54
+
55
@dataclass
class BoltzProcessedInput:
    """Container bundling a manifest with the directories of processed inputs.

    ``manifest``, ``targets_dir`` and ``msa_dir`` are required; the remaining
    directories are optional and default to ``None``.
    """

    # Required fields (positional order is part of the constructor interface).
    manifest: Manifest
    targets_dir: Path
    msa_dir: Path

    # Optional directories; ``None`` when not provided.
    constraints_dir: Optional[Path] = None
    template_dir: Optional[Path] = None
    extra_mols_dir: Optional[Path] = None
65
+
66
+
67
@dataclass
class PairformerArgs:
    """Default Pairformer arguments (48 blocks, ``v2`` disabled)."""

    # Field order is preserved because dataclass ``__init__`` is positional.
    num_blocks: int = 48
    num_heads: int = 16
    dropout: float = 0.0
    activation_checkpointing: bool = False
    offload_to_cpu: bool = False
    v2: bool = False
77
+
78
+
79
@dataclass
class PairformerArgsV2:
    """Default Pairformer arguments with ``v2`` enabled (64 blocks)."""

    # Field order is preserved because dataclass ``__init__`` is positional.
    num_blocks: int = 64
    num_heads: int = 16
    dropout: float = 0.0
    activation_checkpointing: bool = False
    offload_to_cpu: bool = False
    v2: bool = True
89
+
90
+
91
@dataclass
class MSAModuleArgs:
    """Default arguments for the MSA module."""

    # Field order is preserved because dataclass ``__init__`` is positional.
    msa_s: int = 64
    msa_blocks: int = 4
    msa_dropout: float = 0.0
    z_dropout: float = 0.0
    use_paired_feature: bool = True
    pairwise_head_width: int = 32
    pairwise_num_heads: int = 4
    activation_checkpointing: bool = False
    offload_to_cpu: bool = False
    subsample_msa: bool = False
    num_subsampled_msa: int = 1024
106
+
107
+
108
@dataclass
class BoltzDiffusionParams:
    """Default diffusion process parameters (``step_scale`` of 1.638)."""

    # Field order is preserved because dataclass ``__init__`` is positional.
    gamma_0: float = 0.605
    gamma_min: float = 1.107
    noise_scale: float = 0.901
    rho: float = 8
    step_scale: float = 1.638
    sigma_min: float = 0.0004
    sigma_max: float = 160.0
    sigma_data: float = 16.0
    P_mean: float = -1.2
    P_std: float = 1.5
    coordinate_augmentation: bool = True
    alignment_reverse_diff: bool = True
    synchronize_sigmas: bool = True
    # This field has no counterpart in Boltz2DiffusionParams.
    use_inference_model_cache: bool = True
126
+
127
+
128
@dataclass
class Boltz2DiffusionParams:
    """Default diffusion process parameters (``step_scale`` of 1.5)."""

    # Field order is preserved because dataclass ``__init__`` is positional.
    gamma_0: float = 0.8
    gamma_min: float = 1.0
    noise_scale: float = 1.003
    rho: float = 7
    step_scale: float = 1.5
    sigma_min: float = 0.0001
    sigma_max: float = 160.0
    sigma_data: float = 16.0
    P_mean: float = -1.2
    P_std: float = 1.5
    coordinate_augmentation: bool = True
    alignment_reverse_diff: bool = True
    synchronize_sigmas: bool = True
145
+
146
+
147
@dataclass
class BoltzSteeringParams:
    """Default steering parameters."""

    # Field order is preserved because dataclass ``__init__`` is positional.
    fk_steering: bool = True
    num_particles: int = 3
    fk_lambda: float = 4.0
    fk_resampling_interval: int = 3
    guidance_update: bool = True
    num_gd_steps: int = 20
157
+
158
+
159
@rank_zero_only
def download_boltz1(cache: Path) -> None:
    """Download the CCD dictionary and the Boltz-1 checkpoint into ``cache``.

    Files already present in the cache are not downloaded again. Each file is
    first written to a sibling ``*.part`` path and renamed to its final name
    only after the transfer finished, so an interrupted download cannot leave a
    truncated file that a later run would mistake for a complete one.

    Parameters
    ----------
    cache : Path
        The cache directory.

    Raises
    ------
    RuntimeError
        If the model checkpoint cannot be downloaded from any of the URLs in
        ``BOLTZ1_URL_WITH_FALLBACK``.

    """

    def _retrieve(url: str, dest: Path) -> None:
        """Download ``url`` to ``dest`` via a temporary ``.part`` file."""
        tmp = dest.with_name(dest.name + ".part")
        try:
            urllib.request.urlretrieve(url, str(tmp))  # noqa: S310
            tmp.replace(dest)
        finally:
            # No-op after a successful rename; removes the partial file otherwise.
            tmp.unlink(missing_ok=True)

    # Download CCD
    ccd = cache / "ccd.pkl"
    if not ccd.exists():
        click.echo(
            f"Downloading the CCD dictionary to {ccd}. You may "
            "change the cache directory with the --cache flag."
        )
        _retrieve(CCD_URL, ccd)

    # Download model
    model = cache / "boltz1_conf.ckpt"
    if not model.exists():
        click.echo(
            f"Downloading the model weights to {model}. You may "
            "change the cache directory with the --cache flag."
        )
        # Try each URL in order; only the failure of the last one is fatal.
        for i, url in enumerate(BOLTZ1_URL_WITH_FALLBACK):
            try:
                _retrieve(url, model)
                break
            except Exception as e:  # noqa: BLE001
                if i == len(BOLTZ1_URL_WITH_FALLBACK) - 1:
                    msg = f"Failed to download model from all URLs. Last error: {e}"
                    raise RuntimeError(msg) from e
                continue
194
+
195
+
196
@rank_zero_only
def download_boltz2(cache: Path) -> None:
    """Download the molecule data and the Boltz-2 checkpoints into ``cache``.

    Artifacts already present in the cache are not downloaded again. Each
    download is first written to a sibling ``*.part`` path and renamed to its
    final name only after the transfer finished, so an interrupted download
    cannot leave a truncated file that a later run would mistake for a
    complete one.

    Parameters
    ----------
    cache : Path
        The cache directory.

    Raises
    ------
    RuntimeError
        If a checkpoint cannot be downloaded from any of its candidate URLs.

    """

    def _retrieve(url: str, dest: Path) -> None:
        """Download ``url`` to ``dest`` via a temporary ``.part`` file."""
        tmp = dest.with_name(dest.name + ".part")
        try:
            urllib.request.urlretrieve(url, str(tmp))  # noqa: S310
            tmp.replace(dest)
        finally:
            # No-op after a successful rename; removes the partial file otherwise.
            tmp.unlink(missing_ok=True)

    def _retrieve_first(urls: list, dest: Path) -> None:
        """Try ``urls`` in order; raise only when the last one fails too."""
        for i, url in enumerate(urls):
            try:
                _retrieve(url, dest)
                break
            except Exception as e:  # noqa: BLE001
                if i == len(urls) - 1:
                    msg = f"Failed to download model from all URLs. Last error: {e}"
                    raise RuntimeError(msg) from e
                continue

    # Download CCD
    mols = cache / "mols"
    tar_mols = cache / "mols.tar"
    if not mols.exists():
        click.echo(
            f"Downloading and extracting the CCD data to {mols}. "
            "This may take a bit of time. You may change the cache directory "
            "with the --cache flag."
        )
        _retrieve(MOL_URL, tar_mols)
        # NOTE(review): extractall trusts the archive's member paths; the
        # archive comes from MOL_URL. An interrupted extraction can still leave
        # a partial ``mols`` directory behind.
        with tarfile.open(str(tar_mols), "r") as tar:
            tar.extractall(cache)  # noqa: S202

    # Download model
    model = cache / "boltz2_conf.ckpt"
    if not model.exists():
        click.echo(
            f"Downloading the Boltz-2 weights to {model}. You may "
            "change the cache directory with the --cache flag."
        )
        _retrieve_first(BOLTZ2_URL_WITH_FALLBACK, model)

    # Download affinity model
    affinity_model = cache / "boltz2_aff.ckpt"
    if not affinity_model.exists():
        click.echo(
            f"Downloading the Boltz-2 affinity weights to {affinity_model}. You may "
            "change the cache directory with the --cache flag."
        )
        _retrieve_first(BOLTZ2_AFFINITY_URL_WITH_FALLBACK, affinity_model)
252
+
253
+
254
def get_cache_path() -> str:
    """Return the boltz cache location, preferring ``$BOLTZ_CACHE`` when set.

    Returns
    -------
    str
        The expanded and resolved ``BOLTZ_CACHE`` value when that variable is
        set to a non-empty string, otherwise the expanded ``~/.boltz``.

    Raises
    ------
    ValueError
        If the resolved ``BOLTZ_CACHE`` path is not absolute.

    """
    configured = os.environ.get("BOLTZ_CACHE")

    # Unset or empty variable: fall back to the default location.
    if not configured:
        return str(Path("~/.boltz").expanduser())

    cache_dir = Path(configured).expanduser().resolve()
    if cache_dir.is_absolute():
        return str(cache_dir)

    msg = f"BOLTZ_CACHE must be an absolute path, got: {configured}"
    raise ValueError(msg)
272
+
273
+
274
def check_inputs(data: Path) -> list[Path]:
    """Collect the input files to process from a file or a directory.

    Parameters
    ----------
    data : Path
        Either a single ``.fa``/``.fas``/``.fasta``/``.yml``/``.yaml`` file, or
        a directory that is searched recursively for such files.

    Returns
    -------
    list[Path]
        The list of input data.

    Raises
    ------
    RuntimeError
        If a single file has an unsupported suffix, or if a directory contains
        no supported input files.

    """
    supported = (".fa", ".fas", ".fasta", ".yml", ".yaml")
    click.echo("Checking input data.")

    if data.is_dir():
        # Keep the directory in ``data`` and collect matches in a separate
        # list; rebinding ``data`` to the list would make ``.glob`` fail.
        candidates: list[Path] = []
        for ext in supported:
            candidates.extend(data.glob(f"**/*{ext}"))

        if not candidates:
            msg = f"No .fasta or .yaml files found in {data}"
            raise RuntimeError(msg)

        # Filter out directories and invalid file types
        files: list[Path] = []
        for d in candidates:
            if d.is_dir():
                # A directory whose name ends in a supported suffix also matches.
                msg = f"Found directory {d} instead of .fasta or .yaml."
                click.echo(f"Warning: {msg}")
                continue
            if d.suffix not in supported:
                msg = f"Warning: Skipping file with unsupported extension {d.suffix}"
                click.echo(msg)
                continue
            files.append(d)

        if not files:
            msg = "No valid input files found after filtering."
            raise RuntimeError(msg)
    else:
        # Single file case
        if data.suffix not in supported:
            msg = (
                f"Unable to parse filetype {data.suffix}, "
                "please provide a .fasta or .yaml file."
            )
            raise RuntimeError(msg)
        files = [data]

    click.echo(f"Found {len(files)} valid input files.")
    return files
331
+
332
+
333
def filter_inputs_structure(
    manifest: Manifest,
    outdir: Path,
    override: bool = False,
) -> Manifest:
    """Drop records whose prediction directory already exists.

    Parameters
    ----------
    manifest : Manifest
        The manifest of the input data.
    outdir : Path
        The output directory; ``outdir / "predictions"`` is scanned recursively
        and every directory name found there counts as an existing prediction.
    override: bool
        Whether to override existing predictions.

    Returns
    -------
    Manifest
        The input manifest when nothing exists or ``override`` is set,
        otherwise a new manifest without the already predicted records.

    """
    # Names of every directory below the predictions folder
    predictions_root = outdir / "predictions"
    done = {entry.name for entry in predictions_root.rglob("*") if entry.is_dir()}

    if not done:
        return manifest

    if override:
        click.echo(f"Found {len(done)} existing predictions, will override.")
        return manifest

    # Keep only the records that still need a prediction
    remaining = Manifest([rec for rec in manifest.records if rec.id not in done])
    click.echo(
        f"Found some existing predictions ({len(done)}), "
        "skipping and running only the missing ones, "
        "if any. If you wish to override these existing "
        "predictions, please set the --override flag."
    )
    return remaining
374
+
375
+
376
def filter_inputs_affinity(
    manifest: Manifest,
    outdir: Path,
    override: bool = False,
) -> Manifest:
    """Drop records whose affinity prediction already exists.

    A record counts as existing when its ``affinity`` attribute is truthy and
    ``outdir / "predictions" / <id> / "affinity_<id>.json"`` is present.

    Parameters
    ----------
    manifest : Manifest
        The manifest.
    outdir : Path
        The output directory.
    override: bool
        Whether to override existing predictions. When set, records with an
        existing affinity prediction are kept so that they are recomputed.

    Returns
    -------
    Manifest
        The manifest of the filtered input data.

    """
    click.echo("Checking input data for affinity.")

    # Get all affinity targets
    existing = {
        r.id
        for r in manifest.records
        if r.affinity
        and (outdir / "predictions" / r.id / f"affinity_{r.id}.json").exists()
    }

    if existing and override:
        msg = "Found existing affinity predictions, will override."
        click.echo(msg)
        # Keep every record: dropping the existing ones here would contradict
        # the message above and make --override a no-op.
        return manifest

    if existing:
        num_skipped = len(existing)
        msg = (
            f"Found some existing affinity predictions ({num_skipped}), "
            f"skipping and running only the missing ones, "
            "if any. If you wish to override these existing "
            "affinity predictions, please set the --override flag."
        )
        click.echo(msg)

    # Remove them from the input data
    return Manifest([r for r in manifest.records if r.id not in existing])
423
+
424
+
425
def compute_msa(
    data: dict[str, str],
    target_id: str,
    msa_dir: Path,
    msa_server_url: str,
    msa_pairing_strategy: str,
) -> None:
    """Query the MSA server and write one ``<name>.csv`` per input sequence.

    Parameters
    ----------
    data : dict[str, str]
        The input protein sequences, keyed by the name used for the CSV file.
    target_id : str
        The target id, used to name the temporary query directories.
    msa_dir : Path
        The msa directory the CSV files are written to.
    msa_server_url : str
        The MSA server URL.
    msa_pairing_strategy : str
        The MSA pairing strategy.

    """
    # Pairing is only requested when there is more than one sequence.
    if len(data) > 1:
        paired_msas = run_mmseqs2(
            list(data.values()),
            msa_dir / f"{target_id}_paired_tmp",
            use_env=True,
            use_pairing=True,
            host_url=msa_server_url,
            pairing_strategy=msa_pairing_strategy,
        )
    else:
        paired_msas = [""] * len(data)

    unpaired_msa = run_mmseqs2(
        list(data.values()),
        msa_dir / f"{target_id}_unpaired_tmp",
        use_env=True,
        use_pairing=False,
        host_url=msa_server_url,
        pairing_strategy=msa_pairing_strategy,
    )

    for seq_idx, name in enumerate(data):
        # Every second line is a sequence; the lines in between are headers.
        paired_rows = paired_msas[seq_idx].strip().splitlines()[1::2]
        paired_rows = paired_rows[: const.max_paired_seqs]

        # Keep the row index as key and drop rows made only of gaps.
        kept = [
            (row, seq) for row, seq in enumerate(paired_rows) if seq != "-" * len(seq)
        ]
        keys = [row for row, _ in kept]
        paired = [seq for _, seq in kept]

        # Unpaired sequences fill the remaining budget.
        unpaired = unpaired_msa[seq_idx].strip().splitlines()[1::2]
        unpaired = unpaired[: (const.max_msa_seqs - len(paired))]
        if paired:
            # The query is already present in the paired rows.
            unpaired = unpaired[1:]

        # Unpaired rows all get the key -1.
        keys.extend([-1] * len(unpaired))
        rows = ["key,sequence"]
        rows.extend(f"{key},{seq}" for key, seq in zip(keys, paired + unpaired))

        # Dump MSA
        with (msa_dir / f"{name}.csv").open("w") as f:
            f.write("\n".join(rows))
496
+
497
+
498
def process_input(  # noqa: C901, PLR0912, PLR0915
    path: Path,
    ccd: dict,
    msa_dir: Path,
    mol_dir: Path,
    boltz2: bool,
    use_msa_server: bool,
    msa_server_url: str,
    msa_pairing_strategy: str,
    max_msa_seqs: int,
    processed_msa_dir: Path,
    processed_constraints_dir: Path,
    processed_templates_dir: Path,
    processed_mols_dir: Path,
    structure_dir: Path,
    records_dir: Path,
) -> None:
    """Parse one input file and write its processed artifacts to disk.

    The file is parsed as FASTA or YAML depending on its suffix. MSAs are
    generated for protein chains whose ``msa_id`` is ``0`` (only allowed when
    ``use_msa_server`` is set), then the MSAs, templates, constraints, extra
    molecules, structure and record are dumped into the given directories.

    Any exception is caught, its traceback is printed, and the input is
    skipped; nothing is raised to the caller.

    Parameters
    ----------
    path : Path
        The input file.
    ccd : dict
        The component dictionary passed to the parser.
    msa_dir : Path
        Directory where generated raw MSAs are written.
    mol_dir : Path
        The molecule directory passed to the parser.
    boltz2 : bool
        Passed to the parser.
    use_msa_server : bool
        Whether missing MSAs may be generated with the MSA server.
    msa_server_url : str
        The MSA server URL.
    msa_pairing_strategy : str
        The MSA pairing strategy.
    max_msa_seqs : int
        Maximum number of sequences read from each MSA file.
    processed_msa_dir, processed_constraints_dir, processed_templates_dir,
    processed_mols_dir, structure_dir, records_dir : Path
        Output directories for the respective processed artifacts.

    """
    try:
        # Parse data
        if path.suffix in (".fa", ".fas", ".fasta"):
            target = parse_fasta(path, ccd, mol_dir, boltz2)
        elif path.suffix in (".yml", ".yaml"):
            target = parse_yaml(path, ccd, mol_dir, boltz2)
        elif path.is_dir():
            msg = f"Found directory {path} instead of .fasta or .yaml, skipping."
            raise RuntimeError(msg)  # noqa: TRY301
        else:
            msg = (
                f"Unable to parse filetype {path.suffix}, "
                "please provide a .fasta or .yaml file."
            )
            raise RuntimeError(msg)  # noqa: TRY301

        # Get target id
        target_id = target.record.id

        # Get all MSA ids and decide whether to generate MSA
        to_generate = {}
        prot_id = const.chain_type_ids["PROTEIN"]
        for chain in target.record.chains:
            # Add to generate list, assigning entity id
            if (chain.mol_type == prot_id) and (chain.msa_id == 0):
                entity_id = chain.entity_id
                msa_id = f"{target_id}_{entity_id}"
                to_generate[msa_id] = target.sequences[entity_id]
                # Store the path as a string: a Path here cannot be sorted
                # together with string ids of other chains further below.
                chain.msa_id = str(msa_dir / f"{msa_id}.csv")

            # We do not support msa generation for non-protein chains
            elif chain.msa_id == 0:
                chain.msa_id = -1

        # Generate MSA
        if to_generate and not use_msa_server:
            msg = "Missing MSA's in input and --use_msa_server flag not set."
            raise RuntimeError(msg)  # noqa: TRY301

        if to_generate:
            msg = f"Generating MSA for {path} with {len(to_generate)} protein entities."
            click.echo(msg)
            compute_msa(
                data=to_generate,
                target_id=target_id,
                msa_dir=msa_dir,
                msa_server_url=msa_server_url,
                msa_pairing_strategy=msa_pairing_strategy,
            )

        # Parse MSA data
        msas = sorted({c.msa_id for c in target.record.chains if c.msa_id != -1})
        msa_id_map = {}
        for msa_idx, msa_id in enumerate(msas):
            # Check that raw MSA exists
            msa_path = Path(msa_id)
            if not msa_path.exists():
                msg = f"MSA file {msa_path} not found."
                raise FileNotFoundError(msg)  # noqa: TRY301

            # Dump processed MSA
            processed = processed_msa_dir / f"{target_id}_{msa_idx}.npz"
            msa_id_map[msa_id] = f"{target_id}_{msa_idx}"
            if not processed.exists():
                # Parse A3M
                if msa_path.suffix == ".a3m":
                    msa: MSA = parse_a3m(
                        msa_path,
                        taxonomy=None,
                        max_seqs=max_msa_seqs,
                    )
                elif msa_path.suffix == ".csv":
                    msa: MSA = parse_csv(msa_path, max_seqs=max_msa_seqs)
                else:
                    msg = f"MSA file {msa_path} not supported, only a3m or csv."
                    raise RuntimeError(msg)  # noqa: TRY301

                msa.dump(processed)

        # Modify records to point to processed MSA
        for c in target.record.chains:
            if (c.msa_id != -1) and (c.msa_id in msa_id_map):
                c.msa_id = msa_id_map[c.msa_id]

        # Dump templates
        for template_id, template in target.templates.items():
            name = f"{target.record.id}_{template_id}.npz"
            template_path = processed_templates_dir / name
            template.dump(template_path)

        # Dump constraints
        constraints_path = processed_constraints_dir / f"{target.record.id}.npz"
        target.residue_constraints.dump(constraints_path)

        # Dump extra molecules
        Chem.SetDefaultPickleProperties(Chem.PropertyPickleOptions.AllProps)
        with (processed_mols_dir / f"{target.record.id}.pkl").open("wb") as f:
            pickle.dump(target.extra_mols, f)

        # Dump structure
        struct_path = structure_dir / f"{target.record.id}.npz"
        target.structure.dump(struct_path)

        # Dump record
        record_path = records_dir / f"{target.record.id}.json"
        target.record.dump(record_path)

    except Exception as e:  # noqa: BLE001
        # Deliberate best effort: report the failure and skip this input.
        import traceback

        traceback.print_exc()
        print(f"Failed to process {path}. Skipping. Error: {e}.")  # noqa: T201
627
+
628
+
629
def process_inputs(
    data: list[Path],
    out_dir: Path,
    ccd_path: Path,
    mol_dir: Path,
    msa_server_url: str,
    msa_pairing_strategy: str,
    max_msa_seqs: int = 8192,
    use_msa_server: bool = False,
    boltz2: bool = False,
    preprocessing_threads: int = 1,
) -> Manifest:
    """Process every input file into its own sub-directory of ``out_dir``.

    For each input, ``out_dir / <stem>`` receives the ``msa``, ``processed/*``
    and ``predictions`` directories. Inputs whose record already exists under
    ``processed/records`` are not processed again. The combined manifest is
    written to ``out_dir / "processed" / "manifest.json"``.

    Parameters
    ----------
    data : list[Path]
        The input data.
    out_dir : Path
        The output directory.
    ccd_path : Path
        The path to the CCD dictionary (read only when ``boltz2`` is False).
    mol_dir : Path
        The molecule directory (used to load the CCD when ``boltz2`` is True).
    msa_server_url : str
        The MSA server URL.
    msa_pairing_strategy : str
        The MSA pairing strategy.
    max_msa_seqs : int, optional
        Max number of MSA sequences, by default 8192.
    use_msa_server : bool, optional
        Whether to use the MMSeqs2 server for MSA generation, by default False.
    boltz2: bool, optional
        Whether to use Boltz2, by default False.
    preprocessing_threads: int, optional
        Accepted for interface compatibility; inputs are currently processed
        sequentially and this value is not used.

    Returns
    -------
    Manifest
        The manifest of the processed input data.

    """
    all_records = []

    # The CCD does not depend on the input file, so it is loaded at most once,
    # and only if at least one input actually needs processing.
    ccd = None

    # Process each input file in its own directory
    for input_file in data:
        # Create a subdirectory for this input file
        file_out_dir = out_dir / input_file.stem
        file_out_dir.mkdir(parents=True, exist_ok=True)

        processed_dir = file_out_dir / "processed"
        records_dir = processed_dir / "records"

        # Check if records exist at output path
        if records_dir.exists():
            # Load existing records
            existing = [Record.load(p) for p in records_dir.glob("*.json")]
            processed_ids = {record.id for record in existing}

            # Skip if already processed
            if input_file.stem in processed_ids:
                click.echo(f"Found existing processed input for {input_file.stem}, skipping.")
                all_records.extend(existing)
                continue

        # Create output directories for this file
        msa_dir = file_out_dir / "msa"
        structure_dir = processed_dir / "structures"
        processed_msa_dir = processed_dir / "msa"
        processed_constraints_dir = processed_dir / "constraints"
        processed_templates_dir = processed_dir / "templates"
        processed_mols_dir = processed_dir / "mols"
        predictions_dir = file_out_dir / "predictions"

        for directory in (
            msa_dir,
            records_dir,
            structure_dir,
            processed_msa_dir,
            processed_constraints_dir,
            processed_templates_dir,
            processed_mols_dir,
            predictions_dir,
        ):
            directory.mkdir(parents=True, exist_ok=True)

        # Load CCD
        if ccd is None:
            if boltz2:
                ccd = load_canonicals(mol_dir)
            else:
                # NOTE: unpickling can execute arbitrary code; ccd_path must
                # point to a trusted file.
                with ccd_path.open("rb") as file:
                    ccd = pickle.load(file)  # noqa: S301

        # Process this input file
        click.echo(f"Processing {input_file.name}")
        process_input(
            input_file,
            ccd=ccd,
            msa_dir=msa_dir,
            mol_dir=mol_dir,
            boltz2=boltz2,
            use_msa_server=use_msa_server,
            msa_server_url=msa_server_url,
            msa_pairing_strategy=msa_pairing_strategy,
            max_msa_seqs=max_msa_seqs,
            processed_msa_dir=processed_msa_dir,
            processed_constraints_dir=processed_constraints_dir,
            processed_templates_dir=processed_templates_dir,
            processed_mols_dir=processed_mols_dir,
            structure_dir=structure_dir,
            records_dir=records_dir,
        )

        # Load records for this file
        records = [Record.load(p) for p in records_dir.glob("*.json")]
        all_records.extend(records)

    # Create combined manifest. Only per-input directories were created above,
    # so make sure the top-level target directory exists before writing to it.
    manifest = Manifest(all_records)
    manifest_dir = out_dir / "processed"
    manifest_dir.mkdir(parents=True, exist_ok=True)
    manifest.dump(manifest_dir / "manifest.json")
    return manifest
745
+
746
+
747
# Root command group; sub-commands are attached with ``@cli.command()``.
@click.group()
def cli() -> None:
    """Boltz."""
751
+
752
+
753
+ @cli.command()
754
+ @click.argument("data", type=click.Path(exists=True))
755
+ @click.option(
756
+ "--out_dir",
757
+ type=click.Path(exists=False),
758
+ help="The path where to save the predictions.",
759
+ default="./",
760
+ )
761
+ @click.option(
762
+ "--cache",
763
+ type=click.Path(exists=False),
764
+ help=(
765
+ "The directory where to download the data and model. "
766
+ "Default is ~/.boltz, or $BOLTZ_CACHE if set."
767
+ ),
768
+ default=get_cache_path,
769
+ )
770
+ @click.option(
771
+ "--checkpoint",
772
+ type=click.Path(exists=True),
773
+ help="An optional checkpoint, will use the provided Boltz-1 model by default.",
774
+ default=None,
775
+ )
776
+ @click.option(
777
+ "--devices",
778
+ type=int,
779
+ help="The number of devices to use for prediction. Default is 1.",
780
+ default=1,
781
+ )
782
+ @click.option(
783
+ "--accelerator",
784
+ type=click.Choice(["gpu", "cpu", "tpu"]),
785
+ help="The accelerator to use for prediction. Default is gpu.",
786
+ default="gpu",
787
+ )
788
@click.option(
    "--recycling_steps",
    type=int,
    help="The number of recycling steps to use for prediction. Default is 3.",
    default=3,
)
@click.option(
    "--sampling_steps",
    type=int,
    help="The number of sampling steps to use for prediction. Default is 200.",
    default=200,
)
@click.option(
    "--diffusion_samples",
    type=int,
    help="The number of diffusion samples to use for prediction. Default is 1.",
    default=1,
)
@click.option(
    "--max_parallel_samples",
    type=int,
    help="The maximum number of samples to predict in parallel. Default is 5.",
    default=5,
)
@click.option(
    "--step_scale",
    type=float,
    help=(
        "The step size is related to the temperature at "
        "which the diffusion process samples the distribution. "
        "The lower the higher the diversity among samples "
        "(recommended between 1 and 2). "
        "Default is 1.638 for Boltz-1 and 1.5 for Boltz-2. "
        "If not provided, the default step size will be used."
    ),
    default=None,
)
@click.option(
    "--write_full_pae",
    type=bool,
    is_flag=True,
    help="Whether to dump the pae into a npz file. Default is False.",
)
@click.option(
    "--write_full_pde",
    type=bool,
    is_flag=True,
    help="Whether to dump the pde into a npz file. Default is False.",
)
@click.option(
    "--output_format",
    type=click.Choice(["pdb", "mmcif"]),
    help="The output format to use for the predictions. Default is mmcif.",
    default="mmcif",
)
@click.option(
    "--num_workers",
    type=int,
    help="The number of dataloader workers to use for prediction. Default is 2.",
    default=2,
)
@click.option(
    "--override",
    is_flag=True,
    help="Whether to override existing found predictions. Default is False.",
)
@click.option(
    "--seed",
    type=int,
    help="Seed to use for random number generator. Default is None (no seeding).",
    default=None,
)
@click.option(
    "--use_msa_server",
    is_flag=True,
    help="Whether to use the MMSeqs2 server for MSA generation. Default is False.",
)
@click.option(
    "--msa_server_url",
    type=str,
    help="MSA server url. Used only if --use_msa_server is set. ",
    default="https://api.colabfold.com",
)
@click.option(
    "--msa_pairing_strategy",
    type=str,
    help=(
        "Pairing strategy to use. Used only if --use_msa_server is set. "
        "Options are 'greedy' and 'complete'"
    ),
    default="greedy",
)
@click.option(
    "--use_potentials",
    is_flag=True,
    help="Whether to use potentials for steering. Default is False.",
)
@click.option(
    "--model",
    default="boltz2",
    type=click.Choice(["boltz1", "boltz2"]),
    help="The model to use for prediction. Default is boltz2.",
)
@click.option(
    "--method",
    type=str,
    help="The method to use for prediction. Default is None.",
    default=None,
)
@click.option(
    "--preprocessing-threads",
    type=int,
    help=(
        "The number of threads to use for preprocessing. "
        "Default is the number of CPU cores."
    ),
    default=multiprocessing.cpu_count(),
)
@click.option(
    "--affinity_mw_correction",
    is_flag=True,
    type=bool,
    help="Whether to add the Molecular Weight correction to the affinity value head.",
)
@click.option(
    "--sampling_steps_affinity",
    type=int,
    help="The number of sampling steps to use for affinity prediction. Default is 200.",
    default=200,
)
@click.option(
    "--diffusion_samples_affinity",
    type=int,
    help="The number of diffusion samples to use for affinity prediction. Default is 5.",
    default=5,
)
@click.option(
    "--affinity_checkpoint",
    type=click.Path(exists=True),
    help=(
        "An optional checkpoint, will use the provided Boltz-2 affinity model "
        "by default."
    ),
    default=None,
)
@click.option(
    "--max_msa_seqs",
    type=int,
    help="The maximum number of MSA sequences to use for prediction. Default is 8192.",
    default=8192,
)
@click.option(
    "--subsample_msa",
    is_flag=True,
    help="Whether to subsample the MSA. Default is False.",
)
@click.option(
    "--num_subsampled_msa",
    type=int,
    help="The number of MSA sequences to subsample. Default is 1024.",
    default=1024,
)
@click.option(
    "--no_kernels",
    is_flag=True,
    help="Whether to disable the kernels. Default False",
)
@click.option(
    "--skip_structure",
    is_flag=True,
    help="Skip structure prediction and go straight to affinity prediction. Default is False.",
)
def predict(  # noqa: C901, PLR0915, PLR0912
    data: str,
    out_dir: str,
    cache: str = "~/.boltz",
    checkpoint: Optional[str] = None,
    affinity_checkpoint: Optional[str] = None,
    devices: int = 1,
    accelerator: str = "gpu",
    recycling_steps: int = 3,
    sampling_steps: int = 200,
    diffusion_samples: int = 1,
    sampling_steps_affinity: int = 200,
    diffusion_samples_affinity: int = 3,
    max_parallel_samples: Optional[int] = None,
    step_scale: Optional[float] = None,
    write_full_pae: bool = False,
    write_full_pde: bool = False,
    output_format: Literal["pdb", "mmcif"] = "mmcif",
    num_workers: int = 2,
    override: bool = False,
    seed: Optional[int] = None,
    use_msa_server: bool = False,
    msa_server_url: str = "https://api.colabfold.com",
    msa_pairing_strategy: str = "greedy",
    use_potentials: bool = False,
    model: Literal["boltz1", "boltz2"] = "boltz2",
    method: Optional[str] = None,
    affinity_mw_correction: Optional[bool] = False,
    preprocessing_threads: int = 1,
    max_msa_seqs: int = 8192,
    subsample_msa: bool = True,
    num_subsampled_msa: int = 1024,
    no_kernels: bool = False,
    skip_structure: bool = False,
) -> None:
    """Run predictions with Boltz."""
    # NOTE: the docstring above is kept to a single line on purpose; click
    # normally shows a command's docstring as its --help text (unless the
    # command decorator, which is outside this block, overrides it).
    #
    # Overview of what this command does, in order:
    #   1. Configure process-wide state (warnings, torch, rdkit, seed, env).
    #   2. Resolve the cache/input/output paths and download model assets.
    #   3. Validate and preprocess the inputs into a manifest.
    #   4. Run structure prediction for inputs without existing results
    #      (skipped when ``skip_structure`` is set).
    #   5. Run affinity prediction for records that request it.
    #
    # Every parameter corresponds to a click option/argument of the same
    # name; see the option help strings for their meaning. When invoked via
    # the CLI, the click defaults are used, not the defaults in the signature.
    #
    # Raises ValueError for an unsupported ``model``, for ``method`` used
    # together with Boltz-1, or for an unknown ``method`` name.

    # If cpu, write a friendly warning
    if accelerator == "cpu":
        msg = "Running on CPU, this will be slow. Consider using a GPU."
        click.echo(msg)

    # Suppress some lightning warnings
    warnings.filterwarnings(
        "ignore", ".*that has Tensor Cores. To properly utilize them.*"
    )

    # Inference only: disable autograd globally
    torch.set_grad_enabled(False)

    # Request the highest float32 matmul precision
    torch.set_float32_matmul_precision("highest")

    # Set rdkit pickle logic
    Chem.SetDefaultPickleProperties(Chem.PropertyPickleOptions.AllProps)

    # Set seed if desired
    if seed is not None:
        seed_everything(seed)

    for key in ["CUEQ_DEFAULT_CONFIG", "CUEQ_DISABLE_AOT_TUNING"]:
        # Disable kernel tuning by default,
        # but do not modify envvar if already set by caller
        os.environ.setdefault(key, "1")

    # Set cache path
    cache = Path(cache).expanduser()
    cache.mkdir(parents=True, exist_ok=True)

    # Create output directories
    data = Path(data).expanduser()
    out_dir = Path(out_dir).expanduser()
    out_dir = out_dir / f"boltz_results_{data.stem}"
    out_dir.mkdir(parents=True, exist_ok=True)

    # Download necessary data and model
    if model == "boltz1":
        download_boltz1(cache)
    elif model == "boltz2":
        download_boltz2(cache)
    else:
        msg = f"Model {model} not supported. Supported: boltz1, boltz2."
        raise ValueError(msg)

    # Validate inputs
    data = check_inputs(data)

    # Check method
    if method is not None:
        if model == "boltz1":
            msg = "Method conditioning is not supported for Boltz-1."
            raise ValueError(msg)
        if method.lower() not in const.method_types_ids:
            method_names = list(const.method_types_ids.keys())
            msg = f"Method {method} not supported. Supported: {method_names}"
            raise ValueError(msg)

    # Process inputs
    ccd_path = cache / "ccd.pkl"
    mol_dir = cache / "mols"
    manifest = process_inputs(
        data=data,
        out_dir=out_dir,
        ccd_path=ccd_path,
        mol_dir=mol_dir,
        use_msa_server=use_msa_server,
        msa_server_url=msa_server_url,
        msa_pairing_strategy=msa_pairing_strategy,
        boltz2=model == "boltz2",
        preprocessing_threads=preprocessing_threads,
        max_msa_seqs=max_msa_seqs,
    )

    # Filter out existing predictions
    filtered_manifest = filter_inputs_structure(
        manifest=manifest,
        outdir=out_dir,
        override=override,
    )

    # Load processed data; optional sub-directories are passed only if present
    processed_dir = out_dir / "processed"
    processed = BoltzProcessedInput(
        manifest=filtered_manifest,
        targets_dir=processed_dir / "structures",
        msa_dir=processed_dir / "msa",
        constraints_dir=(
            (processed_dir / "constraints")
            if (processed_dir / "constraints").exists()
            else None
        ),
        template_dir=(
            (processed_dir / "templates")
            if (processed_dir / "templates").exists()
            else None
        ),
        extra_mols_dir=(
            (processed_dir / "mols") if (processed_dir / "mols").exists() else None
        ),
    )

    # Set up the distributed strategy. `devices` may be a count or a list of
    # device ids, so normalise it to a count before comparing.
    strategy = "auto"
    if isinstance(devices, list):
        num_devices = len(devices)
    elif isinstance(devices, int):
        num_devices = devices
    else:
        num_devices = 1

    if num_devices > 1:
        # platform.system() reports "Windows" (not "win32"); fork is not
        # available there, so fall back to spawn.
        start_method = "fork" if platform.system() != "Windows" else "spawn"
        strategy = DDPStrategy(start_method=start_method)
        num_records = len(filtered_manifest.records)
        if num_records < num_devices:
            msg = (
                "Number of requested devices is greater "
                "than the number of predictions, taking the minimum."
            )
            click.echo(msg)
            if isinstance(devices, list):
                devices = devices[: max(1, num_records)]
            else:
                devices = max(1, min(num_records, devices))

    # Set up model parameters
    if model == "boltz2":
        diffusion_params = Boltz2DiffusionParams()
        step_scale = 1.5 if step_scale is None else step_scale
        diffusion_params.step_scale = step_scale
        pairformer_args = PairformerArgsV2()
    else:
        diffusion_params = BoltzDiffusionParams()
        step_scale = 1.638 if step_scale is None else step_scale
        diffusion_params.step_scale = step_scale
        pairformer_args = PairformerArgs()

    msa_args = MSAModuleArgs(
        subsample_msa=subsample_msa,
        num_subsampled_msa=num_subsampled_msa,
        use_paired_feature=model == "boltz2",
    )

    # Create prediction writer
    pred_writer = BoltzWriter(
        data_dir=processed.targets_dir,
        output_dir=out_dir / "predictions",
        output_format=output_format,
        boltz2=model == "boltz2",
    )

    # Set up trainer
    trainer = Trainer(
        default_root_dir=out_dir,
        strategy=strategy,
        callbacks=[pred_writer],
        accelerator=accelerator,
        devices=devices,
        precision=32 if model == "boltz1" else "bf16-mixed",
    )

    if filtered_manifest.records and not skip_structure:
        msg = f"Running structure prediction for {len(filtered_manifest.records)} input"
        msg += "s." if len(filtered_manifest.records) > 1 else "."
        click.echo(msg)

        # Create data module
        if model == "boltz2":
            data_module = Boltz2InferenceDataModule(
                manifest=processed.manifest,
                target_dir=processed.targets_dir,
                msa_dir=processed.msa_dir,
                mol_dir=mol_dir,
                num_workers=num_workers,
                constraints_dir=processed.constraints_dir,
                template_dir=processed.template_dir,
                extra_mols_dir=processed.extra_mols_dir,
                override_method=method,
            )
        else:
            data_module = BoltzInferenceDataModule(
                manifest=processed.manifest,
                target_dir=processed.targets_dir,
                msa_dir=processed.msa_dir,
                num_workers=num_workers,
                constraints_dir=processed.constraints_dir,
            )

        # Load model, falling back to the checkpoint stored in the cache
        if checkpoint is None:
            if model == "boltz2":
                checkpoint = cache / "boltz2_conf.ckpt"
            else:
                checkpoint = cache / "boltz1_conf.ckpt"

        predict_args = {
            "recycling_steps": recycling_steps,
            "sampling_steps": sampling_steps,
            "diffusion_samples": diffusion_samples,
            "max_parallel_samples": max_parallel_samples,
            "write_confidence_summary": True,
            "write_full_pae": write_full_pae,
            "write_full_pde": write_full_pde,
        }

        # Both steering switches follow the single --use_potentials flag
        steering_args = BoltzSteeringParams()
        steering_args.fk_steering = use_potentials
        steering_args.guidance_update = use_potentials

        model_cls = Boltz2 if model == "boltz2" else Boltz1
        model_module = model_cls.load_from_checkpoint(
            checkpoint,
            strict=True,
            predict_args=predict_args,
            map_location="cpu",
            diffusion_process_args=asdict(diffusion_params),
            ema=False,
            use_kernels=not no_kernels,
            pairformer_args=asdict(pairformer_args),
            msa_args=asdict(msa_args),
            steering_args=asdict(steering_args),
        )
        model_module.eval()

        # Compute structure predictions
        trainer.predict(
            model_module,
            datamodule=data_module,
            return_predictions=False,
        )

    # Check if affinity predictions are needed. This looks at the full
    # manifest, not only the records that still needed structure prediction.
    if any(r.affinity for r in manifest.records):
        # Print header
        click.echo("\nPredicting property: affinity\n")

        # Validate inputs
        manifest_filtered = filter_inputs_affinity(
            manifest=manifest,
            outdir=out_dir,
            override=override,
        )
        if not manifest_filtered.records:
            click.echo("Found existing affinity predictions for all inputs, skipping.")
            return

        msg = f"Running affinity prediction for {len(manifest_filtered.records)} input"
        msg += "s." if len(manifest_filtered.records) > 1 else "."
        click.echo(msg)

        pred_writer = BoltzAffinityWriter(
            data_dir=processed.targets_dir,
            output_dir=out_dir / "predictions",
        )

        # The affinity data module reads its targets from the predictions
        # directory rather than from the processed structures directory.
        data_module = Boltz2InferenceDataModule(
            manifest=manifest_filtered,
            target_dir=out_dir / "predictions",
            msa_dir=processed.msa_dir,
            mol_dir=mol_dir,
            num_workers=num_workers,
            constraints_dir=processed.constraints_dir,
            template_dir=processed.template_dir,
            extra_mols_dir=processed.extra_mols_dir,
            override_method="other",
            affinity=True,
        )

        predict_affinity_args = {
            "recycling_steps": 5,
            "sampling_steps": sampling_steps_affinity,
            "diffusion_samples": diffusion_samples_affinity,
            "max_parallel_samples": 1,
            "write_confidence_summary": False,
            "write_full_pae": False,
            "write_full_pde": False,
        }

        # Load affinity model
        if affinity_checkpoint is None:
            affinity_checkpoint = cache / "boltz2_aff.ckpt"

        model_module = Boltz2.load_from_checkpoint(
            affinity_checkpoint,
            strict=True,
            predict_args=predict_affinity_args,
            map_location="cpu",
            diffusion_process_args=asdict(diffusion_params),
            ema=False,
            pairformer_args=asdict(pairformer_args),
            msa_args=asdict(msa_args),
            steering_args={"fk_steering": False, "guidance_update": False},
            affinity_mw_correction=affinity_mw_correction,
        )
        model_module.eval()

        # Reuse the trainer, swapping the structure writer for the affinity one
        trainer.callbacks[0] = pred_writer
        trainer.predict(
            model_module,
            datamodule=data_module,
            return_predictions=False,
        )
1289
+
1290
+
1291
# Script entry point: run the click command group when this module is
# executed directly rather than imported.
if __name__ == "__main__":
    cli()