boltz_vsynthes-1.0.0-py3-none-any.whl

This diff represents the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between package versions as they appear in their public registries.
Files changed (112)
  1. boltz/__init__.py +7 -0
  2. boltz/data/__init__.py +0 -0
  3. boltz/data/const.py +1184 -0
  4. boltz/data/crop/__init__.py +0 -0
  5. boltz/data/crop/affinity.py +164 -0
  6. boltz/data/crop/boltz.py +296 -0
  7. boltz/data/crop/cropper.py +45 -0
  8. boltz/data/feature/__init__.py +0 -0
  9. boltz/data/feature/featurizer.py +1230 -0
  10. boltz/data/feature/featurizerv2.py +2208 -0
  11. boltz/data/feature/symmetry.py +602 -0
  12. boltz/data/filter/__init__.py +0 -0
  13. boltz/data/filter/dynamic/__init__.py +0 -0
  14. boltz/data/filter/dynamic/date.py +76 -0
  15. boltz/data/filter/dynamic/filter.py +24 -0
  16. boltz/data/filter/dynamic/max_residues.py +37 -0
  17. boltz/data/filter/dynamic/resolution.py +34 -0
  18. boltz/data/filter/dynamic/size.py +38 -0
  19. boltz/data/filter/dynamic/subset.py +42 -0
  20. boltz/data/filter/static/__init__.py +0 -0
  21. boltz/data/filter/static/filter.py +26 -0
  22. boltz/data/filter/static/ligand.py +37 -0
  23. boltz/data/filter/static/polymer.py +299 -0
  24. boltz/data/module/__init__.py +0 -0
  25. boltz/data/module/inference.py +307 -0
  26. boltz/data/module/inferencev2.py +429 -0
  27. boltz/data/module/training.py +684 -0
  28. boltz/data/module/trainingv2.py +660 -0
  29. boltz/data/mol.py +900 -0
  30. boltz/data/msa/__init__.py +0 -0
  31. boltz/data/msa/mmseqs2.py +235 -0
  32. boltz/data/pad.py +84 -0
  33. boltz/data/parse/__init__.py +0 -0
  34. boltz/data/parse/a3m.py +134 -0
  35. boltz/data/parse/csv.py +100 -0
  36. boltz/data/parse/fasta.py +138 -0
  37. boltz/data/parse/mmcif.py +1239 -0
  38. boltz/data/parse/mmcif_with_constraints.py +1607 -0
  39. boltz/data/parse/schema.py +1851 -0
  40. boltz/data/parse/yaml.py +68 -0
  41. boltz/data/sample/__init__.py +0 -0
  42. boltz/data/sample/cluster.py +283 -0
  43. boltz/data/sample/distillation.py +57 -0
  44. boltz/data/sample/random.py +39 -0
  45. boltz/data/sample/sampler.py +49 -0
  46. boltz/data/tokenize/__init__.py +0 -0
  47. boltz/data/tokenize/boltz.py +195 -0
  48. boltz/data/tokenize/boltz2.py +396 -0
  49. boltz/data/tokenize/tokenizer.py +24 -0
  50. boltz/data/types.py +777 -0
  51. boltz/data/write/__init__.py +0 -0
  52. boltz/data/write/mmcif.py +305 -0
  53. boltz/data/write/pdb.py +171 -0
  54. boltz/data/write/utils.py +23 -0
  55. boltz/data/write/writer.py +330 -0
  56. boltz/main.py +1292 -0
  57. boltz/model/__init__.py +0 -0
  58. boltz/model/layers/__init__.py +0 -0
  59. boltz/model/layers/attention.py +132 -0
  60. boltz/model/layers/attentionv2.py +111 -0
  61. boltz/model/layers/confidence_utils.py +231 -0
  62. boltz/model/layers/dropout.py +34 -0
  63. boltz/model/layers/initialize.py +100 -0
  64. boltz/model/layers/outer_product_mean.py +98 -0
  65. boltz/model/layers/pair_averaging.py +135 -0
  66. boltz/model/layers/pairformer.py +337 -0
  67. boltz/model/layers/relative.py +58 -0
  68. boltz/model/layers/transition.py +78 -0
  69. boltz/model/layers/triangular_attention/__init__.py +0 -0
  70. boltz/model/layers/triangular_attention/attention.py +189 -0
  71. boltz/model/layers/triangular_attention/primitives.py +409 -0
  72. boltz/model/layers/triangular_attention/utils.py +380 -0
  73. boltz/model/layers/triangular_mult.py +212 -0
  74. boltz/model/loss/__init__.py +0 -0
  75. boltz/model/loss/bfactor.py +49 -0
  76. boltz/model/loss/confidence.py +590 -0
  77. boltz/model/loss/confidencev2.py +621 -0
  78. boltz/model/loss/diffusion.py +171 -0
  79. boltz/model/loss/diffusionv2.py +134 -0
  80. boltz/model/loss/distogram.py +48 -0
  81. boltz/model/loss/distogramv2.py +105 -0
  82. boltz/model/loss/validation.py +1025 -0
  83. boltz/model/models/__init__.py +0 -0
  84. boltz/model/models/boltz1.py +1286 -0
  85. boltz/model/models/boltz2.py +1249 -0
  86. boltz/model/modules/__init__.py +0 -0
  87. boltz/model/modules/affinity.py +223 -0
  88. boltz/model/modules/confidence.py +481 -0
  89. boltz/model/modules/confidence_utils.py +181 -0
  90. boltz/model/modules/confidencev2.py +495 -0
  91. boltz/model/modules/diffusion.py +844 -0
  92. boltz/model/modules/diffusion_conditioning.py +116 -0
  93. boltz/model/modules/diffusionv2.py +677 -0
  94. boltz/model/modules/encoders.py +639 -0
  95. boltz/model/modules/encodersv2.py +565 -0
  96. boltz/model/modules/transformers.py +322 -0
  97. boltz/model/modules/transformersv2.py +261 -0
  98. boltz/model/modules/trunk.py +688 -0
  99. boltz/model/modules/trunkv2.py +828 -0
  100. boltz/model/modules/utils.py +303 -0
  101. boltz/model/optim/__init__.py +0 -0
  102. boltz/model/optim/ema.py +389 -0
  103. boltz/model/optim/scheduler.py +99 -0
  104. boltz/model/potentials/__init__.py +0 -0
  105. boltz/model/potentials/potentials.py +497 -0
  106. boltz/model/potentials/schedules.py +32 -0
  107. boltz_vsynthes-1.0.0.dist-info/METADATA +151 -0
  108. boltz_vsynthes-1.0.0.dist-info/RECORD +112 -0
  109. boltz_vsynthes-1.0.0.dist-info/WHEEL +5 -0
  110. boltz_vsynthes-1.0.0.dist-info/entry_points.txt +2 -0
  111. boltz_vsynthes-1.0.0.dist-info/licenses/LICENSE +21 -0
  112. boltz_vsynthes-1.0.0.dist-info/top_level.txt +1 -0
boltz/data/module/training.py
@@ -0,0 +1,684 @@
from dataclasses import dataclass
from pathlib import Path
from typing import Optional

import numpy as np
import pytorch_lightning as pl
import torch
from torch import Tensor
from torch.utils.data import DataLoader

from boltz.data.crop.cropper import Cropper
from boltz.data.feature.featurizer import BoltzFeaturizer
from boltz.data.feature.symmetry import get_symmetries
from boltz.data.filter.dynamic.filter import DynamicFilter
from boltz.data.pad import pad_to_max
from boltz.data.sample.sampler import Sample, Sampler
from boltz.data.tokenize.tokenizer import Tokenizer
from boltz.data.types import MSA, Connection, Input, Manifest, Record, Structure


@dataclass
class DatasetConfig:
    """Dataset configuration."""

    target_dir: str
    msa_dir: str
    prob: float
    sampler: Sampler
    cropper: Cropper
    filters: Optional[list] = None
    split: Optional[str] = None
    manifest_path: Optional[str] = None


@dataclass
class DataConfig:
    """Data configuration."""

    datasets: list[DatasetConfig]
    filters: list[DynamicFilter]
    featurizer: BoltzFeaturizer
    tokenizer: Tokenizer
    max_atoms: int
    max_tokens: int
    max_seqs: int
    samples_per_epoch: int
    batch_size: int
    num_workers: int
    random_seed: int
    pin_memory: bool
    symmetries: str
    atoms_per_window_queries: int
    min_dist: float
    max_dist: float
    num_bins: int
    overfit: Optional[int] = None
    pad_to_max_tokens: bool = False
    pad_to_max_atoms: bool = False
    pad_to_max_seqs: bool = False
    crop_validation: bool = False
    return_train_symmetries: bool = False
    return_val_symmetries: bool = True
    train_binder_pocket_conditioned_prop: float = 0.0
    val_binder_pocket_conditioned_prop: float = 0.0
    binder_pocket_cutoff: float = 6.0
    binder_pocket_sampling_geometric_p: float = 0.0
    val_batch_size: int = 1

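For orientation, a minimal sketch of assembling these configs follows. The sampler, cropper, and tokenizer class names (RandomSampler, BoltzCropper, BoltzTokenizer) are assumptions inferred from the module paths in the file list above, and every path and hyperparameter is purely illustrative:

```python
# Hedged sketch, not from the package: verify the class names against
# boltz/data/sample/random.py, boltz/data/crop/boltz.py, and
# boltz/data/tokenize/boltz.py before use.
from boltz.data.crop.boltz import BoltzCropper        # assumed name
from boltz.data.sample.random import RandomSampler    # assumed name
from boltz.data.tokenize.boltz import BoltzTokenizer  # assumed name
from boltz.data.feature.featurizer import BoltzFeaturizer

dataset_cfg = DatasetConfig(
    target_dir="/data/targets",  # expected to contain structures/*.npz and manifest.json
    msa_dir="/data/msas",
    prob=1.0,                    # sampling probability among datasets
    sampler=RandomSampler(),
    cropper=BoltzCropper(),
)

data_cfg = DataConfig(
    datasets=[dataset_cfg],
    filters=[],
    featurizer=BoltzFeaturizer(),
    tokenizer=BoltzTokenizer(),
    max_atoms=4608,
    max_tokens=512,
    max_seqs=2048,
    samples_per_epoch=100_000,
    batch_size=1,
    num_workers=4,
    random_seed=42,
    pin_memory=True,
    symmetries="/data/symmetry.pkl",  # path consumed by get_symmetries below
    atoms_per_window_queries=32,
    min_dist=2.0,
    max_dist=22.0,
    num_bins=64,
)
```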
@dataclass
class Dataset:
    """Data holder."""

    target_dir: Path
    msa_dir: Path
    manifest: Manifest
    prob: float
    sampler: Sampler
    cropper: Cropper
    tokenizer: Tokenizer
    featurizer: BoltzFeaturizer


def load_input(record: Record, target_dir: Path, msa_dir: Path) -> Input:
    """Load the given input data.

    Parameters
    ----------
    record : Record
        The record to load.
    target_dir : Path
        The path to the data directory.
    msa_dir : Path
        The path to the MSA directory.

    Returns
    -------
    Input
        The loaded input.

    """
    # Load the structure
    structure = np.load(target_dir / "structures" / f"{record.id}.npz")

    # Add cyclic_period to the chains array if the field does not exist
    chains = structure["chains"]
    if "cyclic_period" not in chains.dtype.names:
        # Create a new dtype with the additional field
        new_dtype = chains.dtype.descr + [("cyclic_period", "i4")]
        # Create a new array with the new dtype
        new_chains = np.empty(chains.shape, dtype=new_dtype)
        # Copy over existing fields
        for name in chains.dtype.names:
            new_chains[name] = chains[name]
        # Set the new field to 0
        new_chains["cyclic_period"] = 0
        # Replace the old chains array with the new one
        chains = new_chains

    structure = Structure(
        atoms=structure["atoms"],
        bonds=structure["bonds"],
        residues=structure["residues"],
        chains=chains,  # chains array accounting for a missing cyclic_period
        connections=structure["connections"].astype(Connection),
        interfaces=structure["interfaces"],
        mask=structure["mask"],
    )

    msas = {}
    for chain in record.chains:
        msa_id = chain.msa_id
        # Load the MSA for this chain, if any
        if msa_id != -1 and msa_id != "":
            msa = np.load(msa_dir / f"{msa_id}.npz")
            msas[chain.chain_id] = MSA(**msa)

    return Input(structure, msas)

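The cyclic_period back-fill above is the standard numpy recipe for appending a field to a structured array (widen the dtype, copy the old columns, fill the new one). A self-contained illustration of the same trick, with a toy table in place of the package's chains array:

```python
import numpy as np

# Toy structured array resembling a "chains" table without cyclic_period
chains = np.array([(0, 12), (1, 7)], dtype=[("asym_id", "i4"), ("res_num", "i4")])

if "cyclic_period" not in chains.dtype.names:
    # Widen the dtype, copy the existing columns, zero-fill the new one
    new_dtype = chains.dtype.descr + [("cyclic_period", "i4")]
    new_chains = np.empty(chains.shape, dtype=new_dtype)
    for name in chains.dtype.names:
        new_chains[name] = chains[name]
    new_chains["cyclic_period"] = 0
    chains = new_chains

print(chains["cyclic_period"])  # [0 0]
```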
def collate(data: list[dict[str, Tensor]]) -> dict[str, Tensor]:
    """Collate the data.

    Parameters
    ----------
    data : list[dict[str, Tensor]]
        The data to collate.

    Returns
    -------
    dict[str, Tensor]
        The collated data.

    """
    # Get the keys
    keys = data[0].keys()

    # Collate the data
    collated = {}
    for key in keys:
        values = [d[key] for d in data]

        if key not in [
            "all_coords",
            "all_resolved_mask",
            "crop_to_all_atom_map",
            "chain_symmetries",
            "amino_acids_symmetries",
            "ligand_symmetries",
        ]:
            # Check if all values have the same shape
            shape = values[0].shape
            if not all(v.shape == shape for v in values):
                values, _ = pad_to_max(values, 0)
            else:
                values = torch.stack(values, dim=0)

        # Stack the values
        collated[key] = values

    return collated

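A quick sanity check of the collate behavior with toy inputs: tensors with matching shapes are stacked along a new batch dimension, keys in the exclusion list pass through as plain Python lists, and mismatched shapes would instead go through pad_to_max:

```python
import torch

batch = [
    {"coords": torch.zeros(5, 3), "chain_symmetries": [("A", "B")]},
    {"coords": torch.ones(5, 3), "chain_symmetries": [("A",)]},
]
out = collate(batch)
print(out["coords"].shape)      # torch.Size([2, 5, 3]) -- stacked
print(out["chain_symmetries"])  # [[('A', 'B')], [('A',)]] -- left as a list
```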
class TrainingDataset(torch.utils.data.Dataset):
    """Base iterable dataset."""

    def __init__(
        self,
        datasets: list[Dataset],
        samples_per_epoch: int,
        symmetries: dict,
        max_atoms: int,
        max_tokens: int,
        max_seqs: int,
        pad_to_max_atoms: bool = False,
        pad_to_max_tokens: bool = False,
        pad_to_max_seqs: bool = False,
        atoms_per_window_queries: int = 32,
        min_dist: float = 2.0,
        max_dist: float = 22.0,
        num_bins: int = 64,
        overfit: Optional[int] = None,
        binder_pocket_conditioned_prop: Optional[float] = 0.0,
        binder_pocket_cutoff: Optional[float] = 6.0,
        binder_pocket_sampling_geometric_p: Optional[float] = 0.0,
        return_symmetries: Optional[bool] = False,
        compute_constraint_features: bool = False,
    ) -> None:
        """Initialize the training dataset."""
        super().__init__()
        self.datasets = datasets
        self.probs = [d.prob for d in datasets]
        self.samples_per_epoch = samples_per_epoch
        self.symmetries = symmetries
        self.max_tokens = max_tokens
        self.max_seqs = max_seqs
        self.max_atoms = max_atoms
        self.pad_to_max_tokens = pad_to_max_tokens
        self.pad_to_max_atoms = pad_to_max_atoms
        self.pad_to_max_seqs = pad_to_max_seqs
        self.atoms_per_window_queries = atoms_per_window_queries
        self.min_dist = min_dist
        self.max_dist = max_dist
        self.num_bins = num_bins
        self.binder_pocket_conditioned_prop = binder_pocket_conditioned_prop
        self.binder_pocket_cutoff = binder_pocket_cutoff
        self.binder_pocket_sampling_geometric_p = binder_pocket_sampling_geometric_p
        self.return_symmetries = return_symmetries
        self.compute_constraint_features = compute_constraint_features
        self.samples = []
        for dataset in datasets:
            records = dataset.manifest.records
            if overfit is not None:
                records = records[:overfit]
            iterator = dataset.sampler.sample(records, np.random)
            self.samples.append(iterator)

    def __getitem__(self, idx: int) -> dict[str, Tensor]:
        """Get an item from the dataset.

        Parameters
        ----------
        idx : int
            The data index.

        Returns
        -------
        dict[str, Tensor]
            The sampled data features.

        """
        # Pick a random dataset
        dataset_idx = np.random.choice(
            len(self.datasets),
            p=self.probs,
        )
        dataset = self.datasets[dataset_idx]

        # Get a sample from the dataset
        sample: Sample = next(self.samples[dataset_idx])

        # Get the structure
        try:
            input_data = load_input(sample.record, dataset.target_dir, dataset.msa_dir)
        except Exception as e:
            print(
                f"Failed to load input for {sample.record.id} with error {e}. Skipping."
            )
            return self.__getitem__(idx)

        # Tokenize structure
        try:
            tokenized = dataset.tokenizer.tokenize(input_data)
        except Exception as e:
            print(f"Tokenizer failed on {sample.record.id} with error {e}. Skipping.")
            return self.__getitem__(idx)

        # Compute crop
        try:
            if self.max_tokens is not None:
                tokenized = dataset.cropper.crop(
                    tokenized,
                    max_atoms=self.max_atoms,
                    max_tokens=self.max_tokens,
                    random=np.random,
                    chain_id=sample.chain_id,
                    interface_id=sample.interface_id,
                )
        except Exception as e:
            print(f"Cropper failed on {sample.record.id} with error {e}. Skipping.")
            return self.__getitem__(idx)

        # Check if there are tokens
        if len(tokenized.tokens) == 0:
            msg = "No tokens in cropped structure."
            raise ValueError(msg)

        # Compute features
        try:
            features = dataset.featurizer.process(
                tokenized,
                training=True,
                max_atoms=self.max_atoms if self.pad_to_max_atoms else None,
                max_tokens=self.max_tokens if self.pad_to_max_tokens else None,
                max_seqs=self.max_seqs,
                pad_to_max_seqs=self.pad_to_max_seqs,
                symmetries=self.symmetries,
                atoms_per_window_queries=self.atoms_per_window_queries,
                min_dist=self.min_dist,
                max_dist=self.max_dist,
                num_bins=self.num_bins,
                compute_symmetries=self.return_symmetries,
                binder_pocket_conditioned_prop=self.binder_pocket_conditioned_prop,
                binder_pocket_cutoff=self.binder_pocket_cutoff,
                binder_pocket_sampling_geometric_p=self.binder_pocket_sampling_geometric_p,
                compute_constraint_features=self.compute_constraint_features,
            )
        except Exception as e:
            print(f"Featurizer failed on {sample.record.id} with error {e}. Skipping.")
            return self.__getitem__(idx)

        return features

    def __len__(self) -> int:
        """Get the length of the dataset.

        Returns
        -------
        int
            The length of the dataset.

        """
        return self.samples_per_epoch

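Note that __getitem__ ignores idx: each call draws a dataset by probability, pulls the next record from that dataset's sampler iterator, and retries recursively on any failure. A hedged sketch of the same resample-on-failure logic written as a loop, which avoids deep recursion when many consecutive records fail (not part of the package; it assumes the attributes set in __init__, and the featurizer call is abbreviated):

```python
def sample_features(self) -> dict:
    """Iterative equivalent of the retry logic in TrainingDataset.__getitem__."""
    while True:
        # Draw a dataset by probability, then the next record from its sampler
        dataset_idx = np.random.choice(len(self.datasets), p=self.probs)
        dataset = self.datasets[dataset_idx]
        sample = next(self.samples[dataset_idx])
        try:
            input_data = load_input(sample.record, dataset.target_dir, dataset.msa_dir)
            tokenized = dataset.tokenizer.tokenize(input_data)
            if self.max_tokens is not None:
                tokenized = dataset.cropper.crop(
                    tokenized,
                    max_atoms=self.max_atoms,
                    max_tokens=self.max_tokens,
                    random=np.random,
                    chain_id=sample.chain_id,
                    interface_id=sample.interface_id,
                )
            if len(tokenized.tokens) == 0:
                raise ValueError("No tokens in cropped structure.")
            # Featurizer kwargs abbreviated; pass the same arguments as
            # __getitem__ above in real use.
            return dataset.featurizer.process(tokenized, training=True)
        except Exception as e:  # noqa: BLE001
            print(f"Skipping {sample.record.id}: {e}")
```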
class ValidationDataset(torch.utils.data.Dataset):
    """Base iterable dataset."""

    def __init__(
        self,
        datasets: list[Dataset],
        seed: int,
        symmetries: dict,
        max_atoms: Optional[int] = None,
        max_tokens: Optional[int] = None,
        max_seqs: Optional[int] = None,
        pad_to_max_atoms: bool = False,
        pad_to_max_tokens: bool = False,
        pad_to_max_seqs: bool = False,
        atoms_per_window_queries: int = 32,
        min_dist: float = 2.0,
        max_dist: float = 22.0,
        num_bins: int = 64,
        overfit: Optional[int] = None,
        crop_validation: bool = False,
        return_symmetries: Optional[bool] = False,
        binder_pocket_conditioned_prop: Optional[float] = 0.0,
        binder_pocket_cutoff: Optional[float] = 6.0,
        compute_constraint_features: bool = False,
    ) -> None:
        """Initialize the validation dataset."""
        super().__init__()
        self.datasets = datasets
        self.max_atoms = max_atoms
        self.max_tokens = max_tokens
        self.max_seqs = max_seqs
        self.seed = seed
        self.symmetries = symmetries
        self.random = np.random if overfit else np.random.RandomState(self.seed)
        self.pad_to_max_tokens = pad_to_max_tokens
        self.pad_to_max_atoms = pad_to_max_atoms
        self.pad_to_max_seqs = pad_to_max_seqs
        self.overfit = overfit
        self.crop_validation = crop_validation
        self.atoms_per_window_queries = atoms_per_window_queries
        self.min_dist = min_dist
        self.max_dist = max_dist
        self.num_bins = num_bins
        self.return_symmetries = return_symmetries
        self.binder_pocket_conditioned_prop = binder_pocket_conditioned_prop
        self.binder_pocket_cutoff = binder_pocket_cutoff
        self.compute_constraint_features = compute_constraint_features

    def __getitem__(self, idx: int) -> dict[str, Tensor]:
        """Get an item from the dataset.

        Parameters
        ----------
        idx : int
            The data index.

        Returns
        -------
        dict[str, Tensor]
            The sampled data features.

        """
        # Pick dataset based on idx
        for dataset in self.datasets:
            size = len(dataset.manifest.records)
            if self.overfit is not None:
                size = min(size, self.overfit)
            if idx < size:
                break
            idx -= size

        # Get a sample from the dataset
        record = dataset.manifest.records[idx]

        # Get the structure
        try:
            input_data = load_input(record, dataset.target_dir, dataset.msa_dir)
        except Exception as e:
            print(f"Failed to load input for {record.id} with error {e}. Skipping.")
            return self.__getitem__(0)

        # Tokenize structure
        try:
            tokenized = dataset.tokenizer.tokenize(input_data)
        except Exception as e:
            print(f"Tokenizer failed on {record.id} with error {e}. Skipping.")
            return self.__getitem__(0)

        # Compute crop
        try:
            if self.crop_validation and (self.max_tokens is not None):
                tokenized = dataset.cropper.crop(
                    tokenized,
                    max_tokens=self.max_tokens,
                    random=self.random,
                    max_atoms=self.max_atoms,
                )
        except Exception as e:
            print(f"Cropper failed on {record.id} with error {e}. Skipping.")
            return self.__getitem__(0)

        # Check if there are tokens
        if len(tokenized.tokens) == 0:
            msg = "No tokens in cropped structure."
            raise ValueError(msg)

        # Compute features
        try:
            pad_atoms = self.crop_validation and self.pad_to_max_atoms
            pad_tokens = self.crop_validation and self.pad_to_max_tokens

            features = dataset.featurizer.process(
                tokenized,
                training=False,
                max_atoms=self.max_atoms if pad_atoms else None,
                max_tokens=self.max_tokens if pad_tokens else None,
                max_seqs=self.max_seqs,
                pad_to_max_seqs=self.pad_to_max_seqs,
                symmetries=self.symmetries,
                atoms_per_window_queries=self.atoms_per_window_queries,
                min_dist=self.min_dist,
                max_dist=self.max_dist,
                num_bins=self.num_bins,
                compute_symmetries=self.return_symmetries,
                binder_pocket_conditioned_prop=self.binder_pocket_conditioned_prop,
                binder_pocket_cutoff=self.binder_pocket_cutoff,
                binder_pocket_sampling_geometric_p=1.0,  # this will only sample a single pocket token
                only_ligand_binder_pocket=True,
                compute_constraint_features=self.compute_constraint_features,
            )
        except Exception as e:
            print(f"Featurizer failed on {record.id} with error {e}. Skipping.")
            return self.__getitem__(0)

        return features

    def __len__(self) -> int:
        """Get the length of the dataset.

        Returns
        -------
        int
            The length of the dataset.

        """
        if self.overfit is not None:
            length = sum(len(d.manifest.records[: self.overfit]) for d in self.datasets)
        else:
            length = sum(len(d.manifest.records) for d in self.datasets)

        return length

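The index walk at the top of ValidationDataset.__getitem__ maps a global index onto a (dataset, local index) pair across the concatenated manifests. The same logic in isolation, as a small pure function:

```python
def locate(idx: int, sizes: list[int]) -> tuple[int, int]:
    """Map a global index to (dataset index, local index) over concatenated datasets."""
    for dataset_idx, size in enumerate(sizes):
        if idx < size:
            return dataset_idx, idx
        idx -= size  # skip past this dataset's records
    raise IndexError("index out of range")

print(locate(0, [3, 5]))  # (0, 0): first record of the first dataset
print(locate(4, [3, 5]))  # (1, 1): second record of the second dataset
```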
class BoltzTrainingDataModule(pl.LightningDataModule):
    """DataModule for boltz."""

    def __init__(self, cfg: DataConfig) -> None:
        """Initialize the DataModule.

        Parameters
        ----------
        cfg : DataConfig
            The data configuration.

        """
        super().__init__()
        self.cfg = cfg

        assert self.cfg.val_batch_size == 1, "Validation only works with batch size=1."

        # Load symmetries
        symmetries = get_symmetries(cfg.symmetries)

        # Load datasets
        train: list[Dataset] = []
        val: list[Dataset] = []

        for data_config in cfg.datasets:
            # Set target_dir
            target_dir = Path(data_config.target_dir)
            msa_dir = Path(data_config.msa_dir)

            # Load manifest
            if data_config.manifest_path is not None:
                path = Path(data_config.manifest_path)
            else:
                path = target_dir / "manifest.json"
            manifest: Manifest = Manifest.load(path)

            # Split records if given
            if data_config.split is not None:
                with Path(data_config.split).open("r") as f:
                    split = {x.lower() for x in f.read().splitlines()}

                train_records = []
                val_records = []
                for record in manifest.records:
                    if record.id.lower() in split:
                        val_records.append(record)
                    else:
                        train_records.append(record)
            else:
                train_records = manifest.records
                val_records = []

            # Filter training records with the global filters
            train_records = [
                record
                for record in train_records
                if all(f.filter(record) for f in cfg.filters)
            ]
            # Filter training records with the per-dataset filters
            if data_config.filters is not None:
                train_records = [
                    record
                    for record in train_records
                    if all(f.filter(record) for f in data_config.filters)
                ]

            # Create train dataset
            train_manifest = Manifest(train_records)
            train.append(
                Dataset(
                    target_dir,
                    msa_dir,
                    train_manifest,
                    data_config.prob,
                    data_config.sampler,
                    data_config.cropper,
                    cfg.tokenizer,
                    cfg.featurizer,
                )
            )

            # Create validation dataset
            if val_records:
                val_manifest = Manifest(val_records)
                val.append(
                    Dataset(
                        target_dir,
                        msa_dir,
                        val_manifest,
                        data_config.prob,
                        data_config.sampler,
                        data_config.cropper,
                        cfg.tokenizer,
                        cfg.featurizer,
                    )
                )

        # Print dataset sizes
        for dataset in train:
            dataset: Dataset
            print(f"Training dataset size: {len(dataset.manifest.records)}")

        for dataset in val:
            dataset: Dataset
            print(f"Validation dataset size: {len(dataset.manifest.records)}")

        # Create wrapper datasets
        self._train_set = TrainingDataset(
            datasets=train,
            samples_per_epoch=cfg.samples_per_epoch,
            max_atoms=cfg.max_atoms,
            max_tokens=cfg.max_tokens,
            max_seqs=cfg.max_seqs,
            pad_to_max_atoms=cfg.pad_to_max_atoms,
            pad_to_max_tokens=cfg.pad_to_max_tokens,
            pad_to_max_seqs=cfg.pad_to_max_seqs,
            symmetries=symmetries,
            atoms_per_window_queries=cfg.atoms_per_window_queries,
            min_dist=cfg.min_dist,
            max_dist=cfg.max_dist,
            num_bins=cfg.num_bins,
            overfit=cfg.overfit,
            binder_pocket_conditioned_prop=cfg.train_binder_pocket_conditioned_prop,
            binder_pocket_cutoff=cfg.binder_pocket_cutoff,
            binder_pocket_sampling_geometric_p=cfg.binder_pocket_sampling_geometric_p,
            return_symmetries=cfg.return_train_symmetries,
        )
        self._val_set = ValidationDataset(
            datasets=train if cfg.overfit is not None else val,
            seed=cfg.random_seed,
            max_atoms=cfg.max_atoms,
            max_tokens=cfg.max_tokens,
            max_seqs=cfg.max_seqs,
            pad_to_max_atoms=cfg.pad_to_max_atoms,
            pad_to_max_tokens=cfg.pad_to_max_tokens,
            pad_to_max_seqs=cfg.pad_to_max_seqs,
            symmetries=symmetries,
            atoms_per_window_queries=cfg.atoms_per_window_queries,
            min_dist=cfg.min_dist,
            max_dist=cfg.max_dist,
            num_bins=cfg.num_bins,
            overfit=cfg.overfit,
            crop_validation=cfg.crop_validation,
            return_symmetries=cfg.return_val_symmetries,
            binder_pocket_conditioned_prop=cfg.val_binder_pocket_conditioned_prop,
            binder_pocket_cutoff=cfg.binder_pocket_cutoff,
        )

    def setup(self, stage: Optional[str] = None) -> None:
        """Run the setup for the DataModule.

        Parameters
        ----------
        stage : str, optional
            The stage, one of 'fit', 'validate', 'test'.

        """
        return

    def train_dataloader(self) -> DataLoader:
        """Get the training dataloader.

        Returns
        -------
        DataLoader
            The training dataloader.

        """
        return DataLoader(
            self._train_set,
            batch_size=self.cfg.batch_size,
            num_workers=self.cfg.num_workers,
            pin_memory=self.cfg.pin_memory,
            shuffle=False,
            collate_fn=collate,
        )

    def val_dataloader(self) -> DataLoader:
        """Get the validation dataloader.

        Returns
        -------
        DataLoader
            The validation dataloader.

        """
        return DataLoader(
            self._val_set,
            batch_size=self.cfg.val_batch_size,
            num_workers=self.cfg.num_workers,
            pin_memory=self.cfg.pin_memory,
            shuffle=False,
            collate_fn=collate,
        )