boltz-vsynthes 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- boltz/__init__.py +7 -0
- boltz/data/__init__.py +0 -0
- boltz/data/const.py +1184 -0
- boltz/data/crop/__init__.py +0 -0
- boltz/data/crop/affinity.py +164 -0
- boltz/data/crop/boltz.py +296 -0
- boltz/data/crop/cropper.py +45 -0
- boltz/data/feature/__init__.py +0 -0
- boltz/data/feature/featurizer.py +1230 -0
- boltz/data/feature/featurizerv2.py +2208 -0
- boltz/data/feature/symmetry.py +602 -0
- boltz/data/filter/__init__.py +0 -0
- boltz/data/filter/dynamic/__init__.py +0 -0
- boltz/data/filter/dynamic/date.py +76 -0
- boltz/data/filter/dynamic/filter.py +24 -0
- boltz/data/filter/dynamic/max_residues.py +37 -0
- boltz/data/filter/dynamic/resolution.py +34 -0
- boltz/data/filter/dynamic/size.py +38 -0
- boltz/data/filter/dynamic/subset.py +42 -0
- boltz/data/filter/static/__init__.py +0 -0
- boltz/data/filter/static/filter.py +26 -0
- boltz/data/filter/static/ligand.py +37 -0
- boltz/data/filter/static/polymer.py +299 -0
- boltz/data/module/__init__.py +0 -0
- boltz/data/module/inference.py +307 -0
- boltz/data/module/inferencev2.py +429 -0
- boltz/data/module/training.py +684 -0
- boltz/data/module/trainingv2.py +660 -0
- boltz/data/mol.py +900 -0
- boltz/data/msa/__init__.py +0 -0
- boltz/data/msa/mmseqs2.py +235 -0
- boltz/data/pad.py +84 -0
- boltz/data/parse/__init__.py +0 -0
- boltz/data/parse/a3m.py +134 -0
- boltz/data/parse/csv.py +100 -0
- boltz/data/parse/fasta.py +138 -0
- boltz/data/parse/mmcif.py +1239 -0
- boltz/data/parse/mmcif_with_constraints.py +1607 -0
- boltz/data/parse/schema.py +1851 -0
- boltz/data/parse/yaml.py +68 -0
- boltz/data/sample/__init__.py +0 -0
- boltz/data/sample/cluster.py +283 -0
- boltz/data/sample/distillation.py +57 -0
- boltz/data/sample/random.py +39 -0
- boltz/data/sample/sampler.py +49 -0
- boltz/data/tokenize/__init__.py +0 -0
- boltz/data/tokenize/boltz.py +195 -0
- boltz/data/tokenize/boltz2.py +396 -0
- boltz/data/tokenize/tokenizer.py +24 -0
- boltz/data/types.py +777 -0
- boltz/data/write/__init__.py +0 -0
- boltz/data/write/mmcif.py +305 -0
- boltz/data/write/pdb.py +171 -0
- boltz/data/write/utils.py +23 -0
- boltz/data/write/writer.py +330 -0
- boltz/main.py +1292 -0
- boltz/model/__init__.py +0 -0
- boltz/model/layers/__init__.py +0 -0
- boltz/model/layers/attention.py +132 -0
- boltz/model/layers/attentionv2.py +111 -0
- boltz/model/layers/confidence_utils.py +231 -0
- boltz/model/layers/dropout.py +34 -0
- boltz/model/layers/initialize.py +100 -0
- boltz/model/layers/outer_product_mean.py +98 -0
- boltz/model/layers/pair_averaging.py +135 -0
- boltz/model/layers/pairformer.py +337 -0
- boltz/model/layers/relative.py +58 -0
- boltz/model/layers/transition.py +78 -0
- boltz/model/layers/triangular_attention/__init__.py +0 -0
- boltz/model/layers/triangular_attention/attention.py +189 -0
- boltz/model/layers/triangular_attention/primitives.py +409 -0
- boltz/model/layers/triangular_attention/utils.py +380 -0
- boltz/model/layers/triangular_mult.py +212 -0
- boltz/model/loss/__init__.py +0 -0
- boltz/model/loss/bfactor.py +49 -0
- boltz/model/loss/confidence.py +590 -0
- boltz/model/loss/confidencev2.py +621 -0
- boltz/model/loss/diffusion.py +171 -0
- boltz/model/loss/diffusionv2.py +134 -0
- boltz/model/loss/distogram.py +48 -0
- boltz/model/loss/distogramv2.py +105 -0
- boltz/model/loss/validation.py +1025 -0
- boltz/model/models/__init__.py +0 -0
- boltz/model/models/boltz1.py +1286 -0
- boltz/model/models/boltz2.py +1249 -0
- boltz/model/modules/__init__.py +0 -0
- boltz/model/modules/affinity.py +223 -0
- boltz/model/modules/confidence.py +481 -0
- boltz/model/modules/confidence_utils.py +181 -0
- boltz/model/modules/confidencev2.py +495 -0
- boltz/model/modules/diffusion.py +844 -0
- boltz/model/modules/diffusion_conditioning.py +116 -0
- boltz/model/modules/diffusionv2.py +677 -0
- boltz/model/modules/encoders.py +639 -0
- boltz/model/modules/encodersv2.py +565 -0
- boltz/model/modules/transformers.py +322 -0
- boltz/model/modules/transformersv2.py +261 -0
- boltz/model/modules/trunk.py +688 -0
- boltz/model/modules/trunkv2.py +828 -0
- boltz/model/modules/utils.py +303 -0
- boltz/model/optim/__init__.py +0 -0
- boltz/model/optim/ema.py +389 -0
- boltz/model/optim/scheduler.py +99 -0
- boltz/model/potentials/__init__.py +0 -0
- boltz/model/potentials/potentials.py +497 -0
- boltz/model/potentials/schedules.py +32 -0
- boltz_vsynthes-1.0.0.dist-info/METADATA +151 -0
- boltz_vsynthes-1.0.0.dist-info/RECORD +112 -0
- boltz_vsynthes-1.0.0.dist-info/WHEEL +5 -0
- boltz_vsynthes-1.0.0.dist-info/entry_points.txt +2 -0
- boltz_vsynthes-1.0.0.dist-info/licenses/LICENSE +21 -0
- boltz_vsynthes-1.0.0.dist-info/top_level.txt +1 -0
boltz/data/module/inferencev2.py
@@ -0,0 +1,429 @@
import pickle
from pathlib import Path
from typing import Optional

import numpy as np
import pytorch_lightning as pl
import torch
from torch import Tensor
from torch.utils.data import DataLoader

from boltz.data import const
from boltz.data.crop.affinity import AffinityCropper
from boltz.data.feature.featurizerv2 import Boltz2Featurizer
from boltz.data.mol import load_canonicals, load_molecules
from boltz.data.pad import pad_to_max
from boltz.data.tokenize.boltz2 import Boltz2Tokenizer
from boltz.data.types import (
    MSA,
    Input,
    Manifest,
    Record,
    ResidueConstraints,
    StructureV2,
)


def load_input(
    record: Record,
    target_dir: Path,
    msa_dir: Path,
    constraints_dir: Optional[Path] = None,
    template_dir: Optional[Path] = None,
    extra_mols_dir: Optional[Path] = None,
    affinity: bool = False,
) -> Input:
    """Load the given input data.

    Parameters
    ----------
    record : Record
        The record to load.
    target_dir : Path
        The path to the data directory.
    msa_dir : Path
        The path to the MSA directory.
    constraints_dir : Optional[Path]
        The path to the constraints directory.
    template_dir : Optional[Path]
        The path to the template directory.
    extra_mols_dir : Optional[Path]
        The path to the extra molecules directory.
    affinity : bool
        Whether to load the affinity data.

    Returns
    -------
    Input
        The loaded input.

    """
    # Load the structure
    if affinity:
        structure = StructureV2.load(
            target_dir / record.id / f"pre_affinity_{record.id}.npz"
        )
    else:
        structure = StructureV2.load(target_dir / f"{record.id}.npz")

    msas = {}
    for chain in record.chains:
        msa_id = chain.msa_id
        # Load the MSA for this chain, if any
        if msa_id != -1:
            msa = MSA.load(msa_dir / f"{msa_id}.npz")
            msas[chain.chain_id] = msa

    # Load templates
    templates = None
    if record.templates and template_dir is not None:
        templates = {}
        for template_info in record.templates:
            template_id = template_info.name
            template_path = template_dir / f"{record.id}_{template_id}.npz"
            template = StructureV2.load(template_path)
            templates[template_id] = template

    # Load residue constraints
    residue_constraints = None
    if constraints_dir is not None:
        residue_constraints = ResidueConstraints.load(
            constraints_dir / f"{record.id}.npz"
        )

    # Load extra molecules
    extra_mols = {}
    if extra_mols_dir is not None:
        extra_mol_path = extra_mols_dir / f"{record.id}.pkl"
        if extra_mol_path.exists():
            with extra_mol_path.open("rb") as f:
                extra_mols = pickle.load(f)  # noqa: S301

    return Input(
        structure,
        msas,
        record=record,
        residue_constraints=residue_constraints,
        templates=templates,
        extra_mols=extra_mols,
    )


def collate(data: list[dict[str, Tensor]]) -> dict[str, Tensor]:
    """Collate the data.

    Parameters
    ----------
    data : List[Dict[str, Tensor]]
        The data to collate.

    Returns
    -------
    Dict[str, Tensor]
        The collated data.

    """
    # Get the keys
    keys = data[0].keys()

    # Collate the data
    collated = {}
    for key in keys:
        values = [d[key] for d in data]

        if key not in [
            "all_coords",
            "all_resolved_mask",
            "crop_to_all_atom_map",
            "chain_symmetries",
            "amino_acids_symmetries",
            "ligand_symmetries",
            "record",
            "affinity_mw",
        ]:
            # Check if all have the same shape
            shape = values[0].shape
            if not all(v.shape == shape for v in values):
                values, _ = pad_to_max(values, 0)
            else:
                values = torch.stack(values, dim=0)

        # Stack the values
        collated[key] = values

    return collated


class PredictionDataset(torch.utils.data.Dataset):
    """Dataset used for inference prediction."""

    def __init__(
        self,
        manifest: Manifest,
        target_dir: Path,
        msa_dir: Path,
        mol_dir: Path,
        constraints_dir: Optional[Path] = None,
        template_dir: Optional[Path] = None,
        extra_mols_dir: Optional[Path] = None,
        override_method: Optional[str] = None,
        affinity: bool = False,
    ) -> None:
        """Initialize the prediction dataset.

        Parameters
        ----------
        manifest : Manifest
            The manifest to load data from.
        target_dir : Path
            The path to the target directory.
        msa_dir : Path
            The path to the MSA directory.
        mol_dir : Path
            The path to the molecule directory.
        constraints_dir : Optional[Path]
            The path to the constraints directory.
        template_dir : Optional[Path]
            The path to the template directory.

        """
        super().__init__()
        self.manifest = manifest
        self.target_dir = target_dir
        self.msa_dir = msa_dir
        self.mol_dir = mol_dir
        self.constraints_dir = constraints_dir
        self.template_dir = template_dir
        self.tokenizer = Boltz2Tokenizer()
        self.featurizer = Boltz2Featurizer()
        self.canonicals = load_canonicals(self.mol_dir)
        self.extra_mols_dir = extra_mols_dir
        self.override_method = override_method
        self.affinity = affinity
        if self.affinity:
            self.cropper = AffinityCropper()

    def __getitem__(self, idx: int) -> dict:
        """Get an item from the dataset.

        Returns
        -------
        Dict[str, Tensor]
            The sampled data features.

        """
        # Get record
        record = self.manifest.records[idx]

        # Finalize input data
        input_data = load_input(
            record=record,
            target_dir=self.target_dir,
            msa_dir=self.msa_dir,
            constraints_dir=self.constraints_dir,
            template_dir=self.template_dir,
            extra_mols_dir=self.extra_mols_dir,
            affinity=self.affinity,
        )

        # Tokenize structure
        try:
            tokenized = self.tokenizer.tokenize(input_data)
        except Exception as e:  # noqa: BLE001
            print(  # noqa: T201
                f"Tokenizer failed on {record.id} with error {e}. Skipping."
            )
            return self.__getitem__(0)

        if self.affinity:
            try:
                tokenized = self.cropper.crop(
                    tokenized,
                    max_tokens=256,
                    max_atoms=2048,
                )
            except Exception as e:  # noqa: BLE001
                print(f"Cropper failed on {record.id} with error {e}. Skipping.")  # noqa: T201
                return self.__getitem__(0)

        # Load conformers
        try:
            molecules = {}
            molecules.update(self.canonicals)
            molecules.update(input_data.extra_mols)
            mol_names = set(tokenized.tokens["res_name"].tolist())
            mol_names = mol_names - set(molecules.keys())
            molecules.update(load_molecules(self.mol_dir, mol_names))
        except Exception as e:  # noqa: BLE001
            print(f"Molecule loading failed for {record.id} with error {e}. Skipping.")
            return self.__getitem__(0)

        # Inference specific options
        options = record.inference_options
        if options is None:
            pocket_constraints = None, None
        else:
            pocket_constraints = options.pocket_constraints

        # Get random seed
        seed = 42
        random = np.random.default_rng(seed)

        # Compute features
        try:
            features = self.featurizer.process(
                tokenized,
                molecules=molecules,
                random=random,
                training=False,
                max_atoms=None,
                max_tokens=None,
                max_seqs=const.max_msa_seqs,
                pad_to_max_seqs=False,
                single_sequence_prop=0.0,
                compute_frames=True,
                inference_pocket_constraints=pocket_constraints,
                compute_constraint_features=True,
                override_method=self.override_method,
                compute_affinity=self.affinity,
            )
        except Exception as e:  # noqa: BLE001
            import traceback

            traceback.print_exc()
            print(f"Featurizer failed on {record.id} with error {e}. Skipping.")  # noqa: T201
            return self.__getitem__(0)

        # Add record
        features["record"] = record
        return features

    def __len__(self) -> int:
        """Get the length of the dataset.

        Returns
        -------
        int
            The length of the dataset.

        """
        return len(self.manifest.records)


class Boltz2InferenceDataModule(pl.LightningDataModule):
    """DataModule for Boltz2 inference."""

    def __init__(
        self,
        manifest: Manifest,
        target_dir: Path,
        msa_dir: Path,
        mol_dir: Path,
        num_workers: int,
        constraints_dir: Optional[Path] = None,
        template_dir: Optional[Path] = None,
        extra_mols_dir: Optional[Path] = None,
        override_method: Optional[str] = None,
        affinity: bool = False,
    ) -> None:
        """Initialize the DataModule.

        Parameters
        ----------
        manifest : Manifest
            The manifest to load data from.
        target_dir : Path
            The path to the target directory.
        msa_dir : Path
            The path to the MSA directory.
        mol_dir : Path
            The path to the molecule directory.
        num_workers : int
            The number of workers to use.
        constraints_dir : Optional[Path]
            The path to the constraints directory.
        template_dir : Optional[Path]
            The path to the template directory.
        extra_mols_dir : Optional[Path]
            The path to the extra molecules directory.
        override_method : Optional[str]
            The method to override.

        """
        super().__init__()
        self.num_workers = num_workers
        self.manifest = manifest
        self.target_dir = target_dir
        self.msa_dir = msa_dir
        self.mol_dir = mol_dir
        self.constraints_dir = constraints_dir
        self.template_dir = template_dir
        self.extra_mols_dir = extra_mols_dir
        self.override_method = override_method
        self.affinity = affinity

    def predict_dataloader(self) -> DataLoader:
        """Get the prediction dataloader.

        Returns
        -------
        DataLoader
            The prediction dataloader.

        """
        dataset = PredictionDataset(
            manifest=self.manifest,
            target_dir=self.target_dir,
            msa_dir=self.msa_dir,
            mol_dir=self.mol_dir,
            constraints_dir=self.constraints_dir,
            template_dir=self.template_dir,
            extra_mols_dir=self.extra_mols_dir,
            override_method=self.override_method,
            affinity=self.affinity,
        )
        return DataLoader(
            dataset,
            batch_size=1,
            num_workers=self.num_workers,
            pin_memory=True,
            shuffle=False,
            collate_fn=collate,
        )

    def transfer_batch_to_device(
        self,
        batch: dict,
        device: torch.device,
        dataloader_idx: int,  # noqa: ARG002
    ) -> dict:
        """Transfer a batch to the given device.

        Parameters
        ----------
        batch : Dict
            The batch to transfer.
        device : torch.device
            The device to transfer to.
        dataloader_idx : int
            The dataloader index.

        Returns
        -------
        dict
            The transferred batch.

        """
        for key in batch:
            if key not in [
                "all_coords",
                "all_resolved_mask",
                "crop_to_all_atom_map",
                "chain_symmetries",
                "amino_acids_symmetries",
                "ligand_symmetries",
                "record",
                "affinity_mw",
            ]:
                batch[key] = batch[key].to(device)
        return batch
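For orientation, the sketch below shows one way the DataModule in this file could be wired up for prediction. It is a minimal, hypothetical example, not part of the package diff: the directory names are placeholders, and how the `Manifest` is produced (the Boltz preprocessing step) is assumed rather than shown; only the constructor signature and `predict_dataloader()` from the diff above are used.

```python
from pathlib import Path

from boltz.data.module.inferencev2 import Boltz2InferenceDataModule
from boltz.data.types import Manifest

# Placeholder paths: assumed to be produced by the Boltz preprocessing pipeline
# (per-record structure .npz files, MSA .npz files, and a molecule cache).
out_dir = Path("./boltz_processed")

# Assumption: a Manifest is obtained from the preprocessing step; the loading
# call is not part of this file, so it is left abstract here.
manifest: Manifest = ...

data_module = Boltz2InferenceDataModule(
    manifest=manifest,
    target_dir=out_dir / "structures",
    msa_dir=out_dir / "msa",
    mol_dir=out_dir / "mols",
    num_workers=2,
    constraints_dir=None,  # optional per-record constraint .npz files
    template_dir=None,     # optional per-record template structures
    extra_mols_dir=None,   # optional per-record .pkl with extra molecules
    affinity=False,        # True routes each item through AffinityCropper
)

# Each batch is a dict keyed by feature name; non-tensor entries such as
# "record" and the symmetry lists are left untouched by collate and
# transfer_batch_to_device.
loader = data_module.predict_dataloader()
batch = next(iter(loader))
print(sorted(batch.keys()))
```

Note that the dataloader is fixed at `batch_size=1` and uses the custom `collate` above, which pads ragged tensors with `pad_to_max` and stacks the rest, while skipping the record and symmetry entries entirely.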