boltz-vsynthes 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (112)
  1. boltz/__init__.py +7 -0
  2. boltz/data/__init__.py +0 -0
  3. boltz/data/const.py +1184 -0
  4. boltz/data/crop/__init__.py +0 -0
  5. boltz/data/crop/affinity.py +164 -0
  6. boltz/data/crop/boltz.py +296 -0
  7. boltz/data/crop/cropper.py +45 -0
  8. boltz/data/feature/__init__.py +0 -0
  9. boltz/data/feature/featurizer.py +1230 -0
  10. boltz/data/feature/featurizerv2.py +2208 -0
  11. boltz/data/feature/symmetry.py +602 -0
  12. boltz/data/filter/__init__.py +0 -0
  13. boltz/data/filter/dynamic/__init__.py +0 -0
  14. boltz/data/filter/dynamic/date.py +76 -0
  15. boltz/data/filter/dynamic/filter.py +24 -0
  16. boltz/data/filter/dynamic/max_residues.py +37 -0
  17. boltz/data/filter/dynamic/resolution.py +34 -0
  18. boltz/data/filter/dynamic/size.py +38 -0
  19. boltz/data/filter/dynamic/subset.py +42 -0
  20. boltz/data/filter/static/__init__.py +0 -0
  21. boltz/data/filter/static/filter.py +26 -0
  22. boltz/data/filter/static/ligand.py +37 -0
  23. boltz/data/filter/static/polymer.py +299 -0
  24. boltz/data/module/__init__.py +0 -0
  25. boltz/data/module/inference.py +307 -0
  26. boltz/data/module/inferencev2.py +429 -0
  27. boltz/data/module/training.py +684 -0
  28. boltz/data/module/trainingv2.py +660 -0
  29. boltz/data/mol.py +900 -0
  30. boltz/data/msa/__init__.py +0 -0
  31. boltz/data/msa/mmseqs2.py +235 -0
  32. boltz/data/pad.py +84 -0
  33. boltz/data/parse/__init__.py +0 -0
  34. boltz/data/parse/a3m.py +134 -0
  35. boltz/data/parse/csv.py +100 -0
  36. boltz/data/parse/fasta.py +138 -0
  37. boltz/data/parse/mmcif.py +1239 -0
  38. boltz/data/parse/mmcif_with_constraints.py +1607 -0
  39. boltz/data/parse/schema.py +1851 -0
  40. boltz/data/parse/yaml.py +68 -0
  41. boltz/data/sample/__init__.py +0 -0
  42. boltz/data/sample/cluster.py +283 -0
  43. boltz/data/sample/distillation.py +57 -0
  44. boltz/data/sample/random.py +39 -0
  45. boltz/data/sample/sampler.py +49 -0
  46. boltz/data/tokenize/__init__.py +0 -0
  47. boltz/data/tokenize/boltz.py +195 -0
  48. boltz/data/tokenize/boltz2.py +396 -0
  49. boltz/data/tokenize/tokenizer.py +24 -0
  50. boltz/data/types.py +777 -0
  51. boltz/data/write/__init__.py +0 -0
  52. boltz/data/write/mmcif.py +305 -0
  53. boltz/data/write/pdb.py +171 -0
  54. boltz/data/write/utils.py +23 -0
  55. boltz/data/write/writer.py +330 -0
  56. boltz/main.py +1292 -0
  57. boltz/model/__init__.py +0 -0
  58. boltz/model/layers/__init__.py +0 -0
  59. boltz/model/layers/attention.py +132 -0
  60. boltz/model/layers/attentionv2.py +111 -0
  61. boltz/model/layers/confidence_utils.py +231 -0
  62. boltz/model/layers/dropout.py +34 -0
  63. boltz/model/layers/initialize.py +100 -0
  64. boltz/model/layers/outer_product_mean.py +98 -0
  65. boltz/model/layers/pair_averaging.py +135 -0
  66. boltz/model/layers/pairformer.py +337 -0
  67. boltz/model/layers/relative.py +58 -0
  68. boltz/model/layers/transition.py +78 -0
  69. boltz/model/layers/triangular_attention/__init__.py +0 -0
  70. boltz/model/layers/triangular_attention/attention.py +189 -0
  71. boltz/model/layers/triangular_attention/primitives.py +409 -0
  72. boltz/model/layers/triangular_attention/utils.py +380 -0
  73. boltz/model/layers/triangular_mult.py +212 -0
  74. boltz/model/loss/__init__.py +0 -0
  75. boltz/model/loss/bfactor.py +49 -0
  76. boltz/model/loss/confidence.py +590 -0
  77. boltz/model/loss/confidencev2.py +621 -0
  78. boltz/model/loss/diffusion.py +171 -0
  79. boltz/model/loss/diffusionv2.py +134 -0
  80. boltz/model/loss/distogram.py +48 -0
  81. boltz/model/loss/distogramv2.py +105 -0
  82. boltz/model/loss/validation.py +1025 -0
  83. boltz/model/models/__init__.py +0 -0
  84. boltz/model/models/boltz1.py +1286 -0
  85. boltz/model/models/boltz2.py +1249 -0
  86. boltz/model/modules/__init__.py +0 -0
  87. boltz/model/modules/affinity.py +223 -0
  88. boltz/model/modules/confidence.py +481 -0
  89. boltz/model/modules/confidence_utils.py +181 -0
  90. boltz/model/modules/confidencev2.py +495 -0
  91. boltz/model/modules/diffusion.py +844 -0
  92. boltz/model/modules/diffusion_conditioning.py +116 -0
  93. boltz/model/modules/diffusionv2.py +677 -0
  94. boltz/model/modules/encoders.py +639 -0
  95. boltz/model/modules/encodersv2.py +565 -0
  96. boltz/model/modules/transformers.py +322 -0
  97. boltz/model/modules/transformersv2.py +261 -0
  98. boltz/model/modules/trunk.py +688 -0
  99. boltz/model/modules/trunkv2.py +828 -0
  100. boltz/model/modules/utils.py +303 -0
  101. boltz/model/optim/__init__.py +0 -0
  102. boltz/model/optim/ema.py +389 -0
  103. boltz/model/optim/scheduler.py +99 -0
  104. boltz/model/potentials/__init__.py +0 -0
  105. boltz/model/potentials/potentials.py +497 -0
  106. boltz/model/potentials/schedules.py +32 -0
  107. boltz_vsynthes-1.0.0.dist-info/METADATA +151 -0
  108. boltz_vsynthes-1.0.0.dist-info/RECORD +112 -0
  109. boltz_vsynthes-1.0.0.dist-info/WHEEL +5 -0
  110. boltz_vsynthes-1.0.0.dist-info/entry_points.txt +2 -0
  111. boltz_vsynthes-1.0.0.dist-info/licenses/LICENSE +21 -0
  112. boltz_vsynthes-1.0.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,429 @@
1
+ import pickle
2
+ from pathlib import Path
3
+ from typing import Optional
4
+
5
+ import numpy as np
6
+ import pytorch_lightning as pl
7
+ import torch
8
+ from torch import Tensor
9
+ from torch.utils.data import DataLoader
10
+
11
+ from boltz.data import const
12
+ from boltz.data.crop.affinity import AffinityCropper
13
+ from boltz.data.feature.featurizerv2 import Boltz2Featurizer
14
+ from boltz.data.mol import load_canonicals, load_molecules
15
+ from boltz.data.pad import pad_to_max
16
+ from boltz.data.tokenize.boltz2 import Boltz2Tokenizer
17
+ from boltz.data.types import (
18
+ MSA,
19
+ Input,
20
+ Manifest,
21
+ Record,
22
+ ResidueConstraints,
23
+ StructureV2,
24
+ )
25
+
26
+
27
def load_input(
    record: Record,
    target_dir: Path,
    msa_dir: Path,
    constraints_dir: Optional[Path] = None,
    template_dir: Optional[Path] = None,
    extra_mols_dir: Optional[Path] = None,
    affinity: bool = False,
) -> Input:
    """Read every on-disk artifact belonging to one record into an ``Input``.

    Files are read in this order: structure, MSAs, templates, residue
    constraints, extra molecules.

    Parameters
    ----------
    record : Record
        The record whose files are read.
    target_dir : Path
        Directory holding the structures. In affinity mode the file is
        ``<target_dir>/<id>/pre_affinity_<id>.npz``, otherwise
        ``<target_dir>/<id>.npz``.
    msa_dir : Path
        Directory holding ``<msa_id>.npz`` files. Chains whose ``msa_id`` is
        -1 get no MSA.
    constraints_dir : Optional[Path]
        If given, ``<constraints_dir>/<id>.npz`` is loaded as
        ``ResidueConstraints``.
    template_dir : Optional[Path]
        If given and the record lists templates, each template is loaded from
        ``<template_dir>/<id>_<template name>.npz``.
    extra_mols_dir : Optional[Path]
        If given and ``<extra_mols_dir>/<id>.pkl`` exists, that pickle is
        loaded as the extra-molecule mapping.
    affinity : bool
        Selects the pre-affinity structure path described above.

    Returns
    -------
    Input
        The assembled input. ``templates`` and ``residue_constraints`` are
        ``None`` when not loaded, and ``extra_mols`` is ``{}`` when not loaded.

    """
    structure_path = (
        target_dir / record.id / f"pre_affinity_{record.id}.npz"
        if affinity
        else target_dir / f"{record.id}.npz"
    )
    structure = StructureV2.load(structure_path)

    # MSAs are keyed by chain id; a msa_id of -1 means "no MSA for this chain".
    msas = {
        chain.chain_id: MSA.load(msa_dir / f"{chain.msa_id}.npz")
        for chain in record.chains
        if chain.msa_id != -1
    }

    templates = None
    if record.templates and template_dir is not None:
        templates = {
            info.name: StructureV2.load(template_dir / f"{record.id}_{info.name}.npz")
            for info in record.templates
        }

    residue_constraints = (
        None
        if constraints_dir is None
        else ResidueConstraints.load(constraints_dir / f"{record.id}.npz")
    )

    extra_mols = {}
    if extra_mols_dir is not None:
        pickle_path = extra_mols_dir / f"{record.id}.pkl"
        if pickle_path.exists():
            # NOTE(review): unpickling executes arbitrary code; this assumes
            # extra_mols_dir is produced locally and trusted — confirm.
            with pickle_path.open("rb") as handle:
                extra_mols = pickle.load(handle)  # noqa: S301

    return Input(
        structure,
        msas,
        record=record,
        residue_constraints=residue_constraints,
        templates=templates,
        extra_mols=extra_mols,
    )
110
+
111
+
112
def collate(data: list[dict[str, Tensor]]) -> dict[str, Tensor]:
    """Merge per-sample feature dictionaries into a single batch dictionary.

    The keys of the first sample determine the batch keys. Entries named in
    ``passthrough`` are kept as plain Python lists with one element per
    sample. Every other entry is stacked along a new leading dimension with
    ``torch.stack`` when all samples share a shape. Otherwise it is padded to
    a common shape with ``pad_to_max(values, 0)``, and only the first return
    value is kept.

    Parameters
    ----------
    data : List[Dict[str, Tensor]]
        The per-sample dictionaries. Must be non-empty, because the first
        sample is indexed to obtain the keys.

    Returns
    -------
    Dict[str, Tensor]
        The collated batch.

    """
    passthrough = frozenset(
        (
            "all_coords",
            "all_resolved_mask",
            "crop_to_all_atom_map",
            "chain_symmetries",
            "amino_acids_symmetries",
            "ligand_symmetries",
            "record",
            "affinity_mw",
        )
    )

    batch = {}
    for name in data[0]:
        column = [sample[name] for sample in data]
        if name in passthrough:
            batch[name] = column
            continue

        reference_shape = column[0].shape
        if all(item.shape == reference_shape for item in column):
            batch[name] = torch.stack(column, dim=0)
        else:
            batch[name], _ = pad_to_max(column, 0)

    return batch
155
+
156
+
157
class PredictionDataset(torch.utils.data.Dataset):
    """Map-style dataset that featurizes manifest records for Boltz-2 inference.

    ``__getitem__`` does the following for one record:

    1. Loads its input files.
    2. Tokenizes them.
    3. Crops the tokens with ``AffinityCropper`` (affinity mode only).
    4. Collects the needed molecules.
    5. Runs ``Boltz2Featurizer.process``.

    If any step from 2 onward raises, a message is printed and the features
    of record 0 are returned in its place.
    """

    def __init__(
        self,
        manifest: Manifest,
        target_dir: Path,
        msa_dir: Path,
        mol_dir: Path,
        constraints_dir: Optional[Path] = None,
        template_dir: Optional[Path] = None,
        extra_mols_dir: Optional[Path] = None,
        override_method: Optional[str] = None,
        affinity: bool = False,
    ) -> None:
        """Initialize the prediction dataset.

        Parameters
        ----------
        manifest : Manifest
            The manifest whose records are served, in order.
        target_dir : Path
            The path to the directory holding the record structures.
        msa_dir : Path
            The path to the msa directory.
        mol_dir : Path
            The path to the molecule directory. Canonical molecules are
            loaded from it here, and other molecules on demand per item.
        constraints_dir : Optional[Path]
            The path to the constraints directory.
        template_dir : Optional[Path]
            The path to the template directory.
        extra_mols_dir : Optional[Path]
            The path to the directory of per-record extra-molecule pickles.
        override_method : Optional[str]
            Forwarded unchanged to the featurizer as ``override_method``.
        affinity : bool
            Whether to load the pre-affinity structure, crop it with
            ``AffinityCropper``, and request affinity features.

        """
        super().__init__()
        self.manifest = manifest
        self.target_dir = target_dir
        self.msa_dir = msa_dir
        self.mol_dir = mol_dir
        self.constraints_dir = constraints_dir
        self.template_dir = template_dir
        self.tokenizer = Boltz2Tokenizer()
        self.featurizer = Boltz2Featurizer()
        self.canonicals = load_canonicals(self.mol_dir)
        self.extra_mols_dir = extra_mols_dir
        self.override_method = override_method
        self.affinity = affinity
        if self.affinity:
            self.cropper = AffinityCropper()

    @staticmethod
    def _resolve_pocket_constraints(options):
        """Return the pocket constraints to hand to the featurizer.

        Parameters
        ----------
        options
            The record's ``inference_options``, or ``None``.

        Returns
        -------
        The options' ``pocket_constraints``, or ``None`` when there are no
        options. (Previously this path produced the tuple ``(None, None)``,
        which is not ``None``.)

        """
        if options is None:
            return None
        return options.pocket_constraints

    def _skip(self, idx: int, message: str) -> dict:
        """Report a failed record and fall back to the features of record 0.

        Parameters
        ----------
        idx : int
            Index of the record that failed.
        message : str
            The message to print.

        Returns
        -------
        dict
            The result of ``self.__getitem__(0)``.

        Raises
        ------
        RuntimeError
            If record 0 itself failed. Falling back to it would otherwise
            recurse until ``RecursionError``.

        """
        print(message)  # noqa: T201
        if idx == 0:
            msg = "Record 0 failed to process; no fallback record is available."
            raise RuntimeError(msg)
        return self.__getitem__(0)

    def __getitem__(self, idx: int) -> dict:
        """Get the features for one record.

        Parameters
        ----------
        idx : int
            Index into ``manifest.records``.

        Returns
        -------
        Dict[str, Tensor]
            The featurized record, with the ``Record`` itself stored under
            ``"record"``. If processing fails, the features of record 0 are
            returned instead.

        Raises
        ------
        RuntimeError
            If record 0 fails to process, since there is nothing to fall
            back to.

        """
        record = self.manifest.records[idx]

        # Input loading is not wrapped in try/except: its errors propagate.
        input_data = load_input(
            record=record,
            target_dir=self.target_dir,
            msa_dir=self.msa_dir,
            constraints_dir=self.constraints_dir,
            template_dir=self.template_dir,
            extra_mols_dir=self.extra_mols_dir,
            affinity=self.affinity,
        )

        try:
            tokenized = self.tokenizer.tokenize(input_data)
        except Exception as e:  # noqa: BLE001
            return self._skip(
                idx, f"Tokenizer failed on {record.id} with error {e}. Skipping."
            )

        if self.affinity:
            try:
                tokenized = self.cropper.crop(
                    tokenized,
                    max_tokens=256,
                    max_atoms=2048,
                )
            except Exception as e:  # noqa: BLE001
                return self._skip(
                    idx, f"Cropper failed on {record.id} with error {e}. Skipping."
                )

        # Molecule lookup order: canonicals, then the record's extra molecules,
        # then any residue name still missing is loaded from mol_dir.
        try:
            molecules = {}
            molecules.update(self.canonicals)
            molecules.update(input_data.extra_mols)
            mol_names = set(tokenized.tokens["res_name"].tolist())
            mol_names = mol_names - set(molecules.keys())
            molecules.update(load_molecules(self.mol_dir, mol_names))
        except Exception as e:  # noqa: BLE001
            return self._skip(
                idx,
                f"Molecule loading failed for {record.id} with error {e}. Skipping.",
            )

        pocket_constraints = self._resolve_pocket_constraints(
            record.inference_options
        )

        # Fixed seed: every call starts from the same RNG state.
        seed = 42
        random = np.random.default_rng(seed)

        try:
            features = self.featurizer.process(
                tokenized,
                molecules=molecules,
                random=random,
                training=False,
                max_atoms=None,
                max_tokens=None,
                max_seqs=const.max_msa_seqs,
                pad_to_max_seqs=False,
                single_sequence_prop=0.0,
                compute_frames=True,
                inference_pocket_constraints=pocket_constraints,
                compute_constraint_features=True,
                override_method=self.override_method,
                compute_affinity=self.affinity,
            )
        except Exception as e:  # noqa: BLE001
            import traceback

            traceback.print_exc()
            return self._skip(
                idx, f"Featurizer failed on {record.id} with error {e}. Skipping."
            )

        features["record"] = record
        return features

    def __len__(self) -> int:
        """Get the length of the dataset.

        Returns
        -------
        int
            The number of records in the manifest.

        """
        return len(self.manifest.records)
311
+
312
+
313
class Boltz2InferenceDataModule(pl.LightningDataModule):
    """Lightning data module that serves Boltz-2 prediction batches.

    Only ``predict_dataloader`` is provided. It wraps a ``PredictionDataset``
    in a ``DataLoader`` with batch size 1 and the module-level ``collate``.
    """

    def __init__(
        self,
        manifest: Manifest,
        target_dir: Path,
        msa_dir: Path,
        mol_dir: Path,
        num_workers: int,
        constraints_dir: Optional[Path] = None,
        template_dir: Optional[Path] = None,
        extra_mols_dir: Optional[Path] = None,
        override_method: Optional[str] = None,
        affinity: bool = False,
    ) -> None:
        """Store the settings used to build the prediction dataset.

        Parameters
        ----------
        manifest : Manifest
            The manifest to load data from.
        target_dir : Path
            The path to the target directory.
        msa_dir : Path
            The path to the msa directory.
        mol_dir : Path
            The path to the molecule directory.
        num_workers : int
            The number of dataloader worker processes.
        constraints_dir : Optional[Path]
            The path to the constraints directory.
        template_dir : Optional[Path]
            The path to the template directory.
        extra_mols_dir : Optional[Path]
            The path to the extra molecules directory.
        override_method : Optional[str]
            Forwarded to the dataset (and from there to the featurizer).
        affinity : bool
            Forwarded to the dataset to enable affinity-mode processing.

        """
        super().__init__()
        self.manifest = manifest
        self.target_dir = target_dir
        self.msa_dir = msa_dir
        self.mol_dir = mol_dir
        self.constraints_dir = constraints_dir
        self.template_dir = template_dir
        self.extra_mols_dir = extra_mols_dir
        self.override_method = override_method
        self.affinity = affinity
        self.num_workers = num_workers

    def predict_dataloader(self) -> DataLoader:
        """Build the prediction dataloader.

        Returns
        -------
        DataLoader
            An unshuffled loader with batch size 1, pinned memory, and
            ``collate`` as its collate function.

        """
        dataset_kwargs = {
            "manifest": self.manifest,
            "target_dir": self.target_dir,
            "msa_dir": self.msa_dir,
            "mol_dir": self.mol_dir,
            "constraints_dir": self.constraints_dir,
            "template_dir": self.template_dir,
            "extra_mols_dir": self.extra_mols_dir,
            "override_method": self.override_method,
            "affinity": self.affinity,
        }
        return DataLoader(
            PredictionDataset(**dataset_kwargs),
            batch_size=1,
            num_workers=self.num_workers,
            pin_memory=True,
            shuffle=False,
            collate_fn=collate,
        )

    def transfer_batch_to_device(
        self,
        batch: dict,
        device: torch.device,
        dataloader_idx: int,  # noqa: ARG002
    ) -> dict:
        """Move the tensor entries of a batch to ``device``.

        Entries named in ``host_only`` are left untouched. Every other value
        has ``.to(device)`` called on it.

        Parameters
        ----------
        batch : Dict
            The batch to transfer. It is updated in place.
        device : torch.device
            The device to transfer to.
        dataloader_idx : int
            The dataloader index (unused).

        Returns
        -------
        dict
            The same ``batch`` object, after its entries are replaced.

        """
        host_only = {
            "all_coords",
            "all_resolved_mask",
            "crop_to_all_atom_map",
            "chain_symmetries",
            "amino_acids_symmetries",
            "ligand_symmetries",
            "record",
            "affinity_mw",
        }
        for name in batch:
            if name in host_only:
                continue
            batch[name] = batch[name].to(device)
        return batch