boltz-vsynthes 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- boltz/__init__.py +7 -0
- boltz/data/__init__.py +0 -0
- boltz/data/const.py +1184 -0
- boltz/data/crop/__init__.py +0 -0
- boltz/data/crop/affinity.py +164 -0
- boltz/data/crop/boltz.py +296 -0
- boltz/data/crop/cropper.py +45 -0
- boltz/data/feature/__init__.py +0 -0
- boltz/data/feature/featurizer.py +1230 -0
- boltz/data/feature/featurizerv2.py +2208 -0
- boltz/data/feature/symmetry.py +602 -0
- boltz/data/filter/__init__.py +0 -0
- boltz/data/filter/dynamic/__init__.py +0 -0
- boltz/data/filter/dynamic/date.py +76 -0
- boltz/data/filter/dynamic/filter.py +24 -0
- boltz/data/filter/dynamic/max_residues.py +37 -0
- boltz/data/filter/dynamic/resolution.py +34 -0
- boltz/data/filter/dynamic/size.py +38 -0
- boltz/data/filter/dynamic/subset.py +42 -0
- boltz/data/filter/static/__init__.py +0 -0
- boltz/data/filter/static/filter.py +26 -0
- boltz/data/filter/static/ligand.py +37 -0
- boltz/data/filter/static/polymer.py +299 -0
- boltz/data/module/__init__.py +0 -0
- boltz/data/module/inference.py +307 -0
- boltz/data/module/inferencev2.py +429 -0
- boltz/data/module/training.py +684 -0
- boltz/data/module/trainingv2.py +660 -0
- boltz/data/mol.py +900 -0
- boltz/data/msa/__init__.py +0 -0
- boltz/data/msa/mmseqs2.py +235 -0
- boltz/data/pad.py +84 -0
- boltz/data/parse/__init__.py +0 -0
- boltz/data/parse/a3m.py +134 -0
- boltz/data/parse/csv.py +100 -0
- boltz/data/parse/fasta.py +138 -0
- boltz/data/parse/mmcif.py +1239 -0
- boltz/data/parse/mmcif_with_constraints.py +1607 -0
- boltz/data/parse/schema.py +1851 -0
- boltz/data/parse/yaml.py +68 -0
- boltz/data/sample/__init__.py +0 -0
- boltz/data/sample/cluster.py +283 -0
- boltz/data/sample/distillation.py +57 -0
- boltz/data/sample/random.py +39 -0
- boltz/data/sample/sampler.py +49 -0
- boltz/data/tokenize/__init__.py +0 -0
- boltz/data/tokenize/boltz.py +195 -0
- boltz/data/tokenize/boltz2.py +396 -0
- boltz/data/tokenize/tokenizer.py +24 -0
- boltz/data/types.py +777 -0
- boltz/data/write/__init__.py +0 -0
- boltz/data/write/mmcif.py +305 -0
- boltz/data/write/pdb.py +171 -0
- boltz/data/write/utils.py +23 -0
- boltz/data/write/writer.py +330 -0
- boltz/main.py +1292 -0
- boltz/model/__init__.py +0 -0
- boltz/model/layers/__init__.py +0 -0
- boltz/model/layers/attention.py +132 -0
- boltz/model/layers/attentionv2.py +111 -0
- boltz/model/layers/confidence_utils.py +231 -0
- boltz/model/layers/dropout.py +34 -0
- boltz/model/layers/initialize.py +100 -0
- boltz/model/layers/outer_product_mean.py +98 -0
- boltz/model/layers/pair_averaging.py +135 -0
- boltz/model/layers/pairformer.py +337 -0
- boltz/model/layers/relative.py +58 -0
- boltz/model/layers/transition.py +78 -0
- boltz/model/layers/triangular_attention/__init__.py +0 -0
- boltz/model/layers/triangular_attention/attention.py +189 -0
- boltz/model/layers/triangular_attention/primitives.py +409 -0
- boltz/model/layers/triangular_attention/utils.py +380 -0
- boltz/model/layers/triangular_mult.py +212 -0
- boltz/model/loss/__init__.py +0 -0
- boltz/model/loss/bfactor.py +49 -0
- boltz/model/loss/confidence.py +590 -0
- boltz/model/loss/confidencev2.py +621 -0
- boltz/model/loss/diffusion.py +171 -0
- boltz/model/loss/diffusionv2.py +134 -0
- boltz/model/loss/distogram.py +48 -0
- boltz/model/loss/distogramv2.py +105 -0
- boltz/model/loss/validation.py +1025 -0
- boltz/model/models/__init__.py +0 -0
- boltz/model/models/boltz1.py +1286 -0
- boltz/model/models/boltz2.py +1249 -0
- boltz/model/modules/__init__.py +0 -0
- boltz/model/modules/affinity.py +223 -0
- boltz/model/modules/confidence.py +481 -0
- boltz/model/modules/confidence_utils.py +181 -0
- boltz/model/modules/confidencev2.py +495 -0
- boltz/model/modules/diffusion.py +844 -0
- boltz/model/modules/diffusion_conditioning.py +116 -0
- boltz/model/modules/diffusionv2.py +677 -0
- boltz/model/modules/encoders.py +639 -0
- boltz/model/modules/encodersv2.py +565 -0
- boltz/model/modules/transformers.py +322 -0
- boltz/model/modules/transformersv2.py +261 -0
- boltz/model/modules/trunk.py +688 -0
- boltz/model/modules/trunkv2.py +828 -0
- boltz/model/modules/utils.py +303 -0
- boltz/model/optim/__init__.py +0 -0
- boltz/model/optim/ema.py +389 -0
- boltz/model/optim/scheduler.py +99 -0
- boltz/model/potentials/__init__.py +0 -0
- boltz/model/potentials/potentials.py +497 -0
- boltz/model/potentials/schedules.py +32 -0
- boltz_vsynthes-1.0.0.dist-info/METADATA +151 -0
- boltz_vsynthes-1.0.0.dist-info/RECORD +112 -0
- boltz_vsynthes-1.0.0.dist-info/WHEEL +5 -0
- boltz_vsynthes-1.0.0.dist-info/entry_points.txt +2 -0
- boltz_vsynthes-1.0.0.dist-info/licenses/LICENSE +21 -0
- boltz_vsynthes-1.0.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,307 @@
|
|
1
|
+
from pathlib import Path
|
2
|
+
from typing import Optional
|
3
|
+
|
4
|
+
import numpy as np
|
5
|
+
import pytorch_lightning as pl
|
6
|
+
import torch
|
7
|
+
from torch import Tensor
|
8
|
+
from torch.utils.data import DataLoader
|
9
|
+
|
10
|
+
from boltz.data import const
|
11
|
+
from boltz.data.feature.featurizer import BoltzFeaturizer
|
12
|
+
from boltz.data.pad import pad_to_max
|
13
|
+
from boltz.data.tokenize.boltz import BoltzTokenizer
|
14
|
+
from boltz.data.types import (
|
15
|
+
MSA,
|
16
|
+
Connection,
|
17
|
+
Input,
|
18
|
+
Manifest,
|
19
|
+
Record,
|
20
|
+
ResidueConstraints,
|
21
|
+
Structure,
|
22
|
+
)
|
23
|
+
|
24
|
+
|
25
|
+
def load_input(
    record: Record,
    target_dir: Path,
    msa_dir: Path,
    constraints_dir: Optional[Path] = None,
) -> Input:
    """Load the given input data.

    Parameters
    ----------
    record : Record
        The record to load.
    target_dir : Path
        The directory holding the ``{record.id}.npz`` structure file.
    msa_dir : Path
        The directory holding the ``{msa_id}.npz`` MSA files.
    constraints_dir : Optional[Path]
        If given, the directory holding the ``{record.id}.npz`` residue
        constraints file; otherwise no residue constraints are loaded.

    Returns
    -------
    Input
        The loaded input.

    """
    # Load the structure. ``np.load`` on an ``.npz`` archive returns a lazy
    # NpzFile that keeps the file open, so close it via the context manager;
    # every array used below is read out while the archive is still open.
    with np.load(target_dir / f"{record.id}.npz") as data:
        structure = Structure(
            atoms=data["atoms"],
            bonds=data["bonds"],
            residues=data["residues"],
            chains=data["chains"],
            connections=data["connections"].astype(Connection),
            interfaces=data["interfaces"],
            mask=data["mask"],
        )

    msas = {}
    for chain in record.chains:
        msa_id = chain.msa_id
        # An msa_id of -1 means this chain has no MSA to load
        if msa_id != -1:
            # ``**msa`` reads every array before the archive is closed
            with np.load(msa_dir / f"{msa_id}.npz") as msa:
                msas[chain.chain_id] = MSA(**msa)

    residue_constraints = None
    if constraints_dir is not None:
        residue_constraints = ResidueConstraints.load(
            constraints_dir / f"{record.id}.npz"
        )

    return Input(structure, msas, record, residue_constraints)
|
75
|
+
|
76
|
+
|
77
|
+
def collate(data: list[dict[str, Tensor]]) -> dict[str, Tensor]:
    """Merge a list of per-sample feature dicts into one batched dict.

    The keys are taken from the first sample. Keys named in ``passthrough``
    are returned as plain Python lists, one entry per sample. Every other
    key is stacked along a new leading dimension when all samples share a
    shape; otherwise the values are handed to ``pad_to_max`` (called with
    ``0``, presumably the padding value) and its first return value is kept.

    Parameters
    ----------
    data : List[Dict[str, Tensor]]
        The per-sample features to merge.

    Returns
    -------
    Dict[str, Tensor]
        The batched features, keyed in the order of ``data[0]``.

    """
    passthrough = (
        "all_coords",
        "all_resolved_mask",
        "crop_to_all_atom_map",
        "chain_symmetries",
        "amino_acids_symmetries",
        "ligand_symmetries",
        "record",
    )

    def _merge(name: str):
        # Gather this key's value from every sample, in sample order
        column = [sample[name] for sample in data]
        if name in passthrough:
            return column
        first_shape = column[0].shape
        if any(item.shape != first_shape for item in column):
            # Ragged shapes cannot be stacked directly
            padded, _mask = pad_to_max(column, 0)
            return padded
        return torch.stack(column, dim=0)

    return {name: _merge(name) for name in data[0]}
|
119
|
+
|
120
|
+
|
121
|
+
class PredictionDataset(torch.utils.data.Dataset):
    """Map-style dataset that featurizes manifest records for inference."""

    def __init__(
        self,
        manifest: Manifest,
        target_dir: Path,
        msa_dir: Path,
        constraints_dir: Optional[Path] = None,
    ) -> None:
        """Initialize the prediction dataset.

        Parameters
        ----------
        manifest : Manifest
            The manifest to load records from.
        target_dir : Path
            The directory holding the ``{record.id}.npz`` structure files.
        msa_dir : Path
            The directory holding the ``{msa_id}.npz`` MSA files.
        constraints_dir : Optional[Path]
            Optional directory holding ``{record.id}.npz`` residue
            constraint files.

        """
        super().__init__()
        self.manifest = manifest
        self.target_dir = target_dir
        self.msa_dir = msa_dir
        self.constraints_dir = constraints_dir
        self.tokenizer = BoltzTokenizer()
        self.featurizer = BoltzFeaturizer()

    def __getitem__(self, idx: int) -> dict:
        """Get an item from the dataset.

        If loading, tokenizing or featurizing the record fails, the failure
        is printed and the features of record 0 are returned instead.

        Parameters
        ----------
        idx : int
            The index of the record in ``self.manifest.records``.

        Returns
        -------
        Dict[str, Tensor]
            The sampled data features, plus the source ``record`` under the
            ``"record"`` key.

        Raises
        ------
        RuntimeError
            If record 0 itself cannot be processed, so no fallback exists.

        """
        # Get a sample from the dataset
        record = self.manifest.records[idx]

        # Get the structure
        try:
            input_data = load_input(
                record,
                self.target_dir,
                self.msa_dir,
                self.constraints_dir,
            )
        except Exception as e:  # noqa: BLE001
            return self._skip(idx, f"Failed to load input for {record.id}", e)

        # Tokenize structure
        try:
            tokenized = self.tokenizer.tokenize(input_data)
        except Exception as e:  # noqa: BLE001
            return self._skip(idx, f"Tokenizer failed on {record.id}", e)

        # Inference specific options: only the first pocket constraint is used
        options = record.inference_options
        if options is None or len(options.pocket_constraints) == 0:
            binder, pocket = None, None
        else:
            binder = options.pocket_constraints[0][0]
            pocket = options.pocket_constraints[0][1]

        # Compute features
        try:
            features = self.featurizer.process(
                tokenized,
                training=False,
                max_atoms=None,
                max_tokens=None,
                max_seqs=const.max_msa_seqs,
                pad_to_max_seqs=False,
                symmetries={},
                compute_symmetries=False,
                inference_binder=binder,
                inference_pocket=pocket,
                compute_constraint_features=True,
            )
        except Exception as e:  # noqa: BLE001
            return self._skip(idx, f"Featurizer failed on {record.id}", e)

        features["record"] = record
        return features

    def _skip(self, idx: int, message: str, error: Exception) -> dict:
        """Report a failed record and substitute the features of record 0.

        The substituted sample carries record 0's ``"record"`` entry, so the
        failed record is effectively dropped from what this dataset yields.

        Parameters
        ----------
        idx : int
            The index of the record that failed.
        message : str
            The failure description, printed before the error.
        error : Exception
            The exception raised while processing the record.

        Returns
        -------
        Dict[str, Tensor]
            The features of record 0.

        Raises
        ------
        RuntimeError
            If ``idx`` is 0: falling back to the record that just failed
            would otherwise recurse without end.

        """
        print(f"{message} with error {error}. Skipping.")  # noqa: T201
        if idx == 0:
            msg = "The fallback record at index 0 could not be processed."
            raise RuntimeError(msg) from error
        return self.__getitem__(0)

    def __len__(self) -> int:
        """Get the length of the dataset.

        Returns
        -------
        int
            The number of records in the manifest.

        """
        return len(self.manifest.records)
|
221
|
+
|
222
|
+
|
223
|
+
class BoltzInferenceDataModule(pl.LightningDataModule):
    """Lightning DataModule that serves manifest records for Boltz inference."""

    # Batch entries that transfer_batch_to_device leaves untouched instead of
    # moving them with ``.to(device)``.
    _HOST_ONLY_KEYS = frozenset(
        {
            "all_coords",
            "all_resolved_mask",
            "crop_to_all_atom_map",
            "chain_symmetries",
            "amino_acids_symmetries",
            "ligand_symmetries",
            "record",
        }
    )

    def __init__(
        self,
        manifest: Manifest,
        target_dir: Path,
        msa_dir: Path,
        num_workers: int,
        constraints_dir: Optional[Path] = None,
    ) -> None:
        """Initialize the DataModule.

        Parameters
        ----------
        manifest : Manifest
            The manifest of records to predict.
        target_dir : Path
            The directory of per-record structure ``.npz`` files.
        msa_dir : Path
            The directory of MSA ``.npz`` files.
        num_workers : int
            The number of DataLoader worker processes.
        constraints_dir : Optional[Path]
            Optional directory of per-record residue constraint files.

        """
        super().__init__()
        self.num_workers = num_workers
        self.manifest = manifest
        self.target_dir = target_dir
        self.msa_dir = msa_dir
        self.constraints_dir = constraints_dir

    def predict_dataloader(self) -> DataLoader:
        """Get the prediction dataloader.

        Returns
        -------
        DataLoader
            An unshuffled loader yielding one collated record per batch.

        """
        dataset = PredictionDataset(
            manifest=self.manifest,
            target_dir=self.target_dir,
            msa_dir=self.msa_dir,
            constraints_dir=self.constraints_dir,
        )
        return DataLoader(
            dataset,
            batch_size=1,
            num_workers=self.num_workers,
            pin_memory=True,
            shuffle=False,
            collate_fn=collate,
        )

    def transfer_batch_to_device(
        self,
        batch: dict,
        device: torch.device,
        dataloader_idx: int,  # noqa: ARG002
    ) -> dict:
        """Transfer a batch to the given device.

        Entries listed in ``_HOST_ONLY_KEYS`` are left as they are; every
        other entry is replaced by ``value.to(device)``.

        Parameters
        ----------
        batch : Dict
            The batch to transfer; it is updated in place.
        device : torch.device
            The device to transfer to.
        dataloader_idx : int
            The dataloader index (unused).

        Returns
        -------
        Dict
            The same batch dict, with its transferable entries on ``device``.

        """
        for key in batch:
            if key not in self._HOST_ONLY_KEYS:
                batch[key] = batch[key].to(device)
        return batch
|