boltz_vsynthes-1.0.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- boltz/__init__.py +7 -0
- boltz/data/__init__.py +0 -0
- boltz/data/const.py +1184 -0
- boltz/data/crop/__init__.py +0 -0
- boltz/data/crop/affinity.py +164 -0
- boltz/data/crop/boltz.py +296 -0
- boltz/data/crop/cropper.py +45 -0
- boltz/data/feature/__init__.py +0 -0
- boltz/data/feature/featurizer.py +1230 -0
- boltz/data/feature/featurizerv2.py +2208 -0
- boltz/data/feature/symmetry.py +602 -0
- boltz/data/filter/__init__.py +0 -0
- boltz/data/filter/dynamic/__init__.py +0 -0
- boltz/data/filter/dynamic/date.py +76 -0
- boltz/data/filter/dynamic/filter.py +24 -0
- boltz/data/filter/dynamic/max_residues.py +37 -0
- boltz/data/filter/dynamic/resolution.py +34 -0
- boltz/data/filter/dynamic/size.py +38 -0
- boltz/data/filter/dynamic/subset.py +42 -0
- boltz/data/filter/static/__init__.py +0 -0
- boltz/data/filter/static/filter.py +26 -0
- boltz/data/filter/static/ligand.py +37 -0
- boltz/data/filter/static/polymer.py +299 -0
- boltz/data/module/__init__.py +0 -0
- boltz/data/module/inference.py +307 -0
- boltz/data/module/inferencev2.py +429 -0
- boltz/data/module/training.py +684 -0
- boltz/data/module/trainingv2.py +660 -0
- boltz/data/mol.py +900 -0
- boltz/data/msa/__init__.py +0 -0
- boltz/data/msa/mmseqs2.py +235 -0
- boltz/data/pad.py +84 -0
- boltz/data/parse/__init__.py +0 -0
- boltz/data/parse/a3m.py +134 -0
- boltz/data/parse/csv.py +100 -0
- boltz/data/parse/fasta.py +138 -0
- boltz/data/parse/mmcif.py +1239 -0
- boltz/data/parse/mmcif_with_constraints.py +1607 -0
- boltz/data/parse/schema.py +1851 -0
- boltz/data/parse/yaml.py +68 -0
- boltz/data/sample/__init__.py +0 -0
- boltz/data/sample/cluster.py +283 -0
- boltz/data/sample/distillation.py +57 -0
- boltz/data/sample/random.py +39 -0
- boltz/data/sample/sampler.py +49 -0
- boltz/data/tokenize/__init__.py +0 -0
- boltz/data/tokenize/boltz.py +195 -0
- boltz/data/tokenize/boltz2.py +396 -0
- boltz/data/tokenize/tokenizer.py +24 -0
- boltz/data/types.py +777 -0
- boltz/data/write/__init__.py +0 -0
- boltz/data/write/mmcif.py +305 -0
- boltz/data/write/pdb.py +171 -0
- boltz/data/write/utils.py +23 -0
- boltz/data/write/writer.py +330 -0
- boltz/main.py +1292 -0
- boltz/model/__init__.py +0 -0
- boltz/model/layers/__init__.py +0 -0
- boltz/model/layers/attention.py +132 -0
- boltz/model/layers/attentionv2.py +111 -0
- boltz/model/layers/confidence_utils.py +231 -0
- boltz/model/layers/dropout.py +34 -0
- boltz/model/layers/initialize.py +100 -0
- boltz/model/layers/outer_product_mean.py +98 -0
- boltz/model/layers/pair_averaging.py +135 -0
- boltz/model/layers/pairformer.py +337 -0
- boltz/model/layers/relative.py +58 -0
- boltz/model/layers/transition.py +78 -0
- boltz/model/layers/triangular_attention/__init__.py +0 -0
- boltz/model/layers/triangular_attention/attention.py +189 -0
- boltz/model/layers/triangular_attention/primitives.py +409 -0
- boltz/model/layers/triangular_attention/utils.py +380 -0
- boltz/model/layers/triangular_mult.py +212 -0
- boltz/model/loss/__init__.py +0 -0
- boltz/model/loss/bfactor.py +49 -0
- boltz/model/loss/confidence.py +590 -0
- boltz/model/loss/confidencev2.py +621 -0
- boltz/model/loss/diffusion.py +171 -0
- boltz/model/loss/diffusionv2.py +134 -0
- boltz/model/loss/distogram.py +48 -0
- boltz/model/loss/distogramv2.py +105 -0
- boltz/model/loss/validation.py +1025 -0
- boltz/model/models/__init__.py +0 -0
- boltz/model/models/boltz1.py +1286 -0
- boltz/model/models/boltz2.py +1249 -0
- boltz/model/modules/__init__.py +0 -0
- boltz/model/modules/affinity.py +223 -0
- boltz/model/modules/confidence.py +481 -0
- boltz/model/modules/confidence_utils.py +181 -0
- boltz/model/modules/confidencev2.py +495 -0
- boltz/model/modules/diffusion.py +844 -0
- boltz/model/modules/diffusion_conditioning.py +116 -0
- boltz/model/modules/diffusionv2.py +677 -0
- boltz/model/modules/encoders.py +639 -0
- boltz/model/modules/encodersv2.py +565 -0
- boltz/model/modules/transformers.py +322 -0
- boltz/model/modules/transformersv2.py +261 -0
- boltz/model/modules/trunk.py +688 -0
- boltz/model/modules/trunkv2.py +828 -0
- boltz/model/modules/utils.py +303 -0
- boltz/model/optim/__init__.py +0 -0
- boltz/model/optim/ema.py +389 -0
- boltz/model/optim/scheduler.py +99 -0
- boltz/model/potentials/__init__.py +0 -0
- boltz/model/potentials/potentials.py +497 -0
- boltz/model/potentials/schedules.py +32 -0
- boltz_vsynthes-1.0.0.dist-info/METADATA +151 -0
- boltz_vsynthes-1.0.0.dist-info/RECORD +112 -0
- boltz_vsynthes-1.0.0.dist-info/WHEEL +5 -0
- boltz_vsynthes-1.0.0.dist-info/entry_points.txt +2 -0
- boltz_vsynthes-1.0.0.dist-info/licenses/LICENSE +21 -0
- boltz_vsynthes-1.0.0.dist-info/top_level.txt +1 -0
boltz/data/module/training.py
@@ -0,0 +1,684 @@
from dataclasses import dataclass
from pathlib import Path
from typing import Optional

import numpy as np
import pytorch_lightning as pl
import torch
from torch import Tensor
from torch.utils.data import DataLoader

from boltz.data.crop.cropper import Cropper
from boltz.data.feature.featurizer import BoltzFeaturizer
from boltz.data.feature.symmetry import get_symmetries
from boltz.data.filter.dynamic.filter import DynamicFilter
from boltz.data.pad import pad_to_max
from boltz.data.sample.sampler import Sample, Sampler
from boltz.data.tokenize.tokenizer import Tokenizer
from boltz.data.types import MSA, Connection, Input, Manifest, Record, Structure


@dataclass
class DatasetConfig:
    """Dataset configuration."""

    target_dir: str
    msa_dir: str
    prob: float
    sampler: Sampler
    cropper: Cropper
    filters: Optional[list] = None
    split: Optional[str] = None
    manifest_path: Optional[str] = None


@dataclass
class DataConfig:
    """Data configuration."""

    datasets: list[DatasetConfig]
    filters: list[DynamicFilter]
    featurizer: BoltzFeaturizer
    tokenizer: Tokenizer
    max_atoms: int
    max_tokens: int
    max_seqs: int
    samples_per_epoch: int
    batch_size: int
    num_workers: int
    random_seed: int
    pin_memory: bool
    symmetries: str
    atoms_per_window_queries: int
    min_dist: float
    max_dist: float
    num_bins: int
    overfit: Optional[int] = None
    pad_to_max_tokens: bool = False
    pad_to_max_atoms: bool = False
    pad_to_max_seqs: bool = False
    crop_validation: bool = False
    return_train_symmetries: bool = False
    return_val_symmetries: bool = True
    train_binder_pocket_conditioned_prop: float = 0.0
    val_binder_pocket_conditioned_prop: float = 0.0
    binder_pocket_cutoff: float = 6.0
    binder_pocket_sampling_geometric_p: float = 0.0
    val_batch_size: int = 1


@dataclass
class Dataset:
    """Data holder."""

    target_dir: Path
    msa_dir: Path
    manifest: Manifest
    prob: float
    sampler: Sampler
    cropper: Cropper
    tokenizer: Tokenizer
    featurizer: BoltzFeaturizer


def load_input(record: Record, target_dir: Path, msa_dir: Path) -> Input:
    """Load the given input data.

    Parameters
    ----------
    record : Record
        The record to load.
    target_dir : Path
        The path to the data directory.
    msa_dir : Path
        The path to msa directory.

    Returns
    -------
    Input
        The loaded input.

    """
    # Load the structure
    structure = np.load(target_dir / "structures" / f"{record.id}.npz")

    # In order to add cyclic_period to chains if it does not exist
    # Extract the chains array
    chains = structure["chains"]
    # Check if the field exists
    if "cyclic_period" not in chains.dtype.names:
        # Create a new dtype with the additional field
        new_dtype = chains.dtype.descr + [("cyclic_period", "i4")]
        # Create a new array with the new dtype
        new_chains = np.empty(chains.shape, dtype=new_dtype)
        # Copy over existing fields
        for name in chains.dtype.names:
            new_chains[name] = chains[name]
        # Set the new field to 0
        new_chains["cyclic_period"] = 0
        # Replace old chains array with new one
        chains = new_chains

    structure = Structure(
        atoms=structure["atoms"],
        bonds=structure["bonds"],
        residues=structure["residues"],
        chains=chains,  # chains var accounting for missing cyclic_period
        connections=structure["connections"].astype(Connection),
        interfaces=structure["interfaces"],
        mask=structure["mask"],
    )

    msas = {}
    for chain in record.chains:
        msa_id = chain.msa_id
        # Load the MSA for this chain, if any
        if msa_id != -1 and msa_id != "":
            msa = np.load(msa_dir / f"{msa_id}.npz")
            msas[chain.chain_id] = MSA(**msa)

    return Input(structure, msas)


def collate(data: list[dict[str, Tensor]]) -> dict[str, Tensor]:
    """Collate the data.

    Parameters
    ----------
    data : list[dict[str, Tensor]]
        The data to collate.

    Returns
    -------
    dict[str, Tensor]
        The collated data.

    """
    # Get the keys
    keys = data[0].keys()

    # Collate the data
    collated = {}
    for key in keys:
        values = [d[key] for d in data]

        if key not in [
            "all_coords",
            "all_resolved_mask",
            "crop_to_all_atom_map",
            "chain_symmetries",
            "amino_acids_symmetries",
            "ligand_symmetries",
        ]:
            # Check if all have the same shape
            shape = values[0].shape
            if not all(v.shape == shape for v in values):
                values, _ = pad_to_max(values, 0)
            else:
                values = torch.stack(values, dim=0)

        # Stack the values
        collated[key] = values

    return collated


class TrainingDataset(torch.utils.data.Dataset):
    """Base iterable dataset."""

    def __init__(
        self,
        datasets: list[Dataset],
        samples_per_epoch: int,
        symmetries: dict,
        max_atoms: int,
        max_tokens: int,
        max_seqs: int,
        pad_to_max_atoms: bool = False,
        pad_to_max_tokens: bool = False,
        pad_to_max_seqs: bool = False,
        atoms_per_window_queries: int = 32,
        min_dist: float = 2.0,
        max_dist: float = 22.0,
        num_bins: int = 64,
        overfit: Optional[int] = None,
        binder_pocket_conditioned_prop: Optional[float] = 0.0,
        binder_pocket_cutoff: Optional[float] = 6.0,
        binder_pocket_sampling_geometric_p: Optional[float] = 0.0,
        return_symmetries: Optional[bool] = False,
        compute_constraint_features: bool = False,
    ) -> None:
        """Initialize the training dataset."""
        super().__init__()
        self.datasets = datasets
        self.probs = [d.prob for d in datasets]
        self.samples_per_epoch = samples_per_epoch
        self.symmetries = symmetries
        self.max_tokens = max_tokens
        self.max_seqs = max_seqs
        self.max_atoms = max_atoms
        self.pad_to_max_tokens = pad_to_max_tokens
        self.pad_to_max_atoms = pad_to_max_atoms
        self.pad_to_max_seqs = pad_to_max_seqs
        self.atoms_per_window_queries = atoms_per_window_queries
        self.min_dist = min_dist
        self.max_dist = max_dist
        self.num_bins = num_bins
        self.binder_pocket_conditioned_prop = binder_pocket_conditioned_prop
        self.binder_pocket_cutoff = binder_pocket_cutoff
        self.binder_pocket_sampling_geometric_p = binder_pocket_sampling_geometric_p
        self.return_symmetries = return_symmetries
        self.compute_constraint_features = compute_constraint_features
        self.samples = []
        for dataset in datasets:
            records = dataset.manifest.records
            if overfit is not None:
                records = records[:overfit]
            iterator = dataset.sampler.sample(records, np.random)
            self.samples.append(iterator)

    def __getitem__(self, idx: int) -> dict[str, Tensor]:
        """Get an item from the dataset.

        Parameters
        ----------
        idx : int
            The data index.

        Returns
        -------
        dict[str, Tensor]
            The sampled data features.

        """
        # Pick a random dataset
        dataset_idx = np.random.choice(
            len(self.datasets),
            p=self.probs,
        )
        dataset = self.datasets[dataset_idx]

        # Get a sample from the dataset
        sample: Sample = next(self.samples[dataset_idx])

        # Get the structure
        try:
            input_data = load_input(sample.record, dataset.target_dir, dataset.msa_dir)
        except Exception as e:
            print(
                f"Failed to load input for {sample.record.id} with error {e}. Skipping."
            )
            return self.__getitem__(idx)

        # Tokenize structure
        try:
            tokenized = dataset.tokenizer.tokenize(input_data)
        except Exception as e:
            print(f"Tokenizer failed on {sample.record.id} with error {e}. Skipping.")
            return self.__getitem__(idx)

        # Compute crop
        try:
            if self.max_tokens is not None:
                tokenized = dataset.cropper.crop(
                    tokenized,
                    max_atoms=self.max_atoms,
                    max_tokens=self.max_tokens,
                    random=np.random,
                    chain_id=sample.chain_id,
                    interface_id=sample.interface_id,
                )
        except Exception as e:
            print(f"Cropper failed on {sample.record.id} with error {e}. Skipping.")
            return self.__getitem__(idx)

        # Check if there are tokens
        if len(tokenized.tokens) == 0:
            msg = "No tokens in cropped structure."
            raise ValueError(msg)

        # Compute features
        try:
            features = dataset.featurizer.process(
                tokenized,
                training=True,
                max_atoms=self.max_atoms if self.pad_to_max_atoms else None,
                max_tokens=self.max_tokens if self.pad_to_max_tokens else None,
                max_seqs=self.max_seqs,
                pad_to_max_seqs=self.pad_to_max_seqs,
                symmetries=self.symmetries,
                atoms_per_window_queries=self.atoms_per_window_queries,
                min_dist=self.min_dist,
                max_dist=self.max_dist,
                num_bins=self.num_bins,
                compute_symmetries=self.return_symmetries,
                binder_pocket_conditioned_prop=self.binder_pocket_conditioned_prop,
                binder_pocket_cutoff=self.binder_pocket_cutoff,
                binder_pocket_sampling_geometric_p=self.binder_pocket_sampling_geometric_p,
                compute_constraint_features=self.compute_constraint_features,
            )
        except Exception as e:
            print(f"Featurizer failed on {sample.record.id} with error {e}. Skipping.")
            return self.__getitem__(idx)

        return features

    def __len__(self) -> int:
        """Get the length of the dataset.

        Returns
        -------
        int
            The length of the dataset.

        """
        return self.samples_per_epoch


class ValidationDataset(torch.utils.data.Dataset):
    """Base iterable dataset."""

    def __init__(
        self,
        datasets: list[Dataset],
        seed: int,
        symmetries: dict,
        max_atoms: Optional[int] = None,
        max_tokens: Optional[int] = None,
        max_seqs: Optional[int] = None,
        pad_to_max_atoms: bool = False,
        pad_to_max_tokens: bool = False,
        pad_to_max_seqs: bool = False,
        atoms_per_window_queries: int = 32,
        min_dist: float = 2.0,
        max_dist: float = 22.0,
        num_bins: int = 64,
        overfit: Optional[int] = None,
        crop_validation: bool = False,
        return_symmetries: Optional[bool] = False,
        binder_pocket_conditioned_prop: Optional[float] = 0.0,
        binder_pocket_cutoff: Optional[float] = 6.0,
        compute_constraint_features: bool = False,
    ) -> None:
        """Initialize the validation dataset."""
        super().__init__()
        self.datasets = datasets
        self.max_atoms = max_atoms
        self.max_tokens = max_tokens
        self.max_seqs = max_seqs
        self.seed = seed
        self.symmetries = symmetries
        self.random = np.random if overfit else np.random.RandomState(self.seed)
        self.pad_to_max_tokens = pad_to_max_tokens
        self.pad_to_max_atoms = pad_to_max_atoms
        self.pad_to_max_seqs = pad_to_max_seqs
        self.overfit = overfit
        self.crop_validation = crop_validation
        self.atoms_per_window_queries = atoms_per_window_queries
        self.min_dist = min_dist
        self.max_dist = max_dist
        self.num_bins = num_bins
        self.return_symmetries = return_symmetries
        self.binder_pocket_conditioned_prop = binder_pocket_conditioned_prop
        self.binder_pocket_cutoff = binder_pocket_cutoff
        self.compute_constraint_features = compute_constraint_features

    def __getitem__(self, idx: int) -> dict[str, Tensor]:
        """Get an item from the dataset.

        Parameters
        ----------
        idx : int
            The data index.

        Returns
        -------
        dict[str, Tensor]
            The sampled data features.

        """
        # Pick dataset based on idx
        for dataset in self.datasets:
            size = len(dataset.manifest.records)
            if self.overfit is not None:
                size = min(size, self.overfit)
            if idx < size:
                break
            idx -= size

        # Get a sample from the dataset
        record = dataset.manifest.records[idx]

        # Get the structure
        try:
            input_data = load_input(record, dataset.target_dir, dataset.msa_dir)
        except Exception as e:
            print(f"Failed to load input for {record.id} with error {e}. Skipping.")
            return self.__getitem__(0)

        # Tokenize structure
        try:
            tokenized = dataset.tokenizer.tokenize(input_data)
        except Exception as e:
            print(f"Tokenizer failed on {record.id} with error {e}. Skipping.")
            return self.__getitem__(0)

        # Compute crop
        try:
            if self.crop_validation and (self.max_tokens is not None):
                tokenized = dataset.cropper.crop(
                    tokenized,
                    max_tokens=self.max_tokens,
                    random=self.random,
                    max_atoms=self.max_atoms,
                )
        except Exception as e:
            print(f"Cropper failed on {record.id} with error {e}. Skipping.")
            return self.__getitem__(0)

        # Check if there are tokens
        if len(tokenized.tokens) == 0:
            msg = "No tokens in cropped structure."
            raise ValueError(msg)

        # Compute features
        try:
            pad_atoms = self.crop_validation and self.pad_to_max_atoms
            pad_tokens = self.crop_validation and self.pad_to_max_tokens

            features = dataset.featurizer.process(
                tokenized,
                training=False,
                max_atoms=self.max_atoms if pad_atoms else None,
                max_tokens=self.max_tokens if pad_tokens else None,
                max_seqs=self.max_seqs,
                pad_to_max_seqs=self.pad_to_max_seqs,
                symmetries=self.symmetries,
                atoms_per_window_queries=self.atoms_per_window_queries,
                min_dist=self.min_dist,
                max_dist=self.max_dist,
                num_bins=self.num_bins,
                compute_symmetries=self.return_symmetries,
                binder_pocket_conditioned_prop=self.binder_pocket_conditioned_prop,
                binder_pocket_cutoff=self.binder_pocket_cutoff,
                binder_pocket_sampling_geometric_p=1.0,  # this will only sample a single pocket token
                only_ligand_binder_pocket=True,
                compute_constraint_features=self.compute_constraint_features,
            )
        except Exception as e:
            print(f"Featurizer failed on {record.id} with error {e}. Skipping.")
            return self.__getitem__(0)

        return features

    def __len__(self) -> int:
        """Get the length of the dataset.

        Returns
        -------
        int
            The length of the dataset.

        """
        if self.overfit is not None:
            length = sum(len(d.manifest.records[: self.overfit]) for d in self.datasets)
        else:
            length = sum(len(d.manifest.records) for d in self.datasets)

        return length


class BoltzTrainingDataModule(pl.LightningDataModule):
    """DataModule for boltz."""

    def __init__(self, cfg: DataConfig) -> None:
        """Initialize the DataModule.

        Parameters
        ----------
        config : DataConfig
            The data configuration.

        """
        super().__init__()
        self.cfg = cfg

        assert self.cfg.val_batch_size == 1, "Validation only works with batch size=1."

        # Load symmetries
        symmetries = get_symmetries(cfg.symmetries)

        # Load datasets
        train: list[Dataset] = []
        val: list[Dataset] = []

        for data_config in cfg.datasets:
            # Set target_dir
            target_dir = Path(data_config.target_dir)
            msa_dir = Path(data_config.msa_dir)

            # Load manifest
            if data_config.manifest_path is not None:
                path = Path(data_config.manifest_path)
            else:
                path = target_dir / "manifest.json"
            manifest: Manifest = Manifest.load(path)

            # Split records if given
            if data_config.split is not None:
                with Path(data_config.split).open("r") as f:
                    split = {x.lower() for x in f.read().splitlines()}

                train_records = []
                val_records = []
                for record in manifest.records:
                    if record.id.lower() in split:
                        val_records.append(record)
                    else:
                        train_records.append(record)
            else:
                train_records = manifest.records
                val_records = []

            # Filter training records
            train_records = [
                record
                for record in train_records
                if all(f.filter(record) for f in cfg.filters)
            ]
            # Filter training records
            if data_config.filters is not None:
                train_records = [
                    record
                    for record in train_records
                    if all(f.filter(record) for f in data_config.filters)
                ]

            # Create train dataset
            train_manifest = Manifest(train_records)
            train.append(
                Dataset(
                    target_dir,
                    msa_dir,
                    train_manifest,
                    data_config.prob,
                    data_config.sampler,
                    data_config.cropper,
                    cfg.tokenizer,
                    cfg.featurizer,
                )
            )

            # Create validation dataset
            if val_records:
                val_manifest = Manifest(val_records)
                val.append(
                    Dataset(
                        target_dir,
                        msa_dir,
                        val_manifest,
                        data_config.prob,
                        data_config.sampler,
                        data_config.cropper,
                        cfg.tokenizer,
                        cfg.featurizer,
                    )
                )

        # Print dataset sizes
        for dataset in train:
            dataset: Dataset
            print(f"Training dataset size: {len(dataset.manifest.records)}")

        for dataset in val:
            dataset: Dataset
            print(f"Validation dataset size: {len(dataset.manifest.records)}")

        # Create wrapper datasets
        self._train_set = TrainingDataset(
            datasets=train,
            samples_per_epoch=cfg.samples_per_epoch,
            max_atoms=cfg.max_atoms,
            max_tokens=cfg.max_tokens,
            max_seqs=cfg.max_seqs,
            pad_to_max_atoms=cfg.pad_to_max_atoms,
            pad_to_max_tokens=cfg.pad_to_max_tokens,
            pad_to_max_seqs=cfg.pad_to_max_seqs,
            symmetries=symmetries,
            atoms_per_window_queries=cfg.atoms_per_window_queries,
            min_dist=cfg.min_dist,
            max_dist=cfg.max_dist,
            num_bins=cfg.num_bins,
            overfit=cfg.overfit,
            binder_pocket_conditioned_prop=cfg.train_binder_pocket_conditioned_prop,
            binder_pocket_cutoff=cfg.binder_pocket_cutoff,
            binder_pocket_sampling_geometric_p=cfg.binder_pocket_sampling_geometric_p,
            return_symmetries=cfg.return_train_symmetries,
        )
        self._val_set = ValidationDataset(
            datasets=train if cfg.overfit is not None else val,
            seed=cfg.random_seed,
            max_atoms=cfg.max_atoms,
            max_tokens=cfg.max_tokens,
            max_seqs=cfg.max_seqs,
            pad_to_max_atoms=cfg.pad_to_max_atoms,
            pad_to_max_tokens=cfg.pad_to_max_tokens,
            pad_to_max_seqs=cfg.pad_to_max_seqs,
            symmetries=symmetries,
            atoms_per_window_queries=cfg.atoms_per_window_queries,
            min_dist=cfg.min_dist,
            max_dist=cfg.max_dist,
            num_bins=cfg.num_bins,
            overfit=cfg.overfit,
            crop_validation=cfg.crop_validation,
            return_symmetries=cfg.return_val_symmetries,
            binder_pocket_conditioned_prop=cfg.val_binder_pocket_conditioned_prop,
            binder_pocket_cutoff=cfg.binder_pocket_cutoff,
        )

    def setup(self, stage: Optional[str] = None) -> None:
        """Run the setup for the DataModule.

        Parameters
        ----------
        stage : str, optional
            The stage, one of 'fit', 'validate', 'test'.

        """
        return

    def train_dataloader(self) -> DataLoader:
        """Get the training dataloader.

        Returns
        -------
        DataLoader
            The training dataloader.

        """
        return DataLoader(
            self._train_set,
            batch_size=self.cfg.batch_size,
            num_workers=self.cfg.num_workers,
            pin_memory=self.cfg.pin_memory,
            shuffle=False,
            collate_fn=collate,
        )

    def val_dataloader(self) -> DataLoader:
        """Get the validation dataloader.

        Returns
        -------
        DataLoader
            The validation dataloader.

        """
        return DataLoader(
            self._val_set,
            batch_size=self.cfg.val_batch_size,
            num_workers=self.cfg.num_workers,
            pin_memory=self.cfg.pin_memory,
            shuffle=False,
            collate_fn=collate,
        )