boltz-vsynthes 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (112)
  1. boltz/__init__.py +7 -0
  2. boltz/data/__init__.py +0 -0
  3. boltz/data/const.py +1184 -0
  4. boltz/data/crop/__init__.py +0 -0
  5. boltz/data/crop/affinity.py +164 -0
  6. boltz/data/crop/boltz.py +296 -0
  7. boltz/data/crop/cropper.py +45 -0
  8. boltz/data/feature/__init__.py +0 -0
  9. boltz/data/feature/featurizer.py +1230 -0
  10. boltz/data/feature/featurizerv2.py +2208 -0
  11. boltz/data/feature/symmetry.py +602 -0
  12. boltz/data/filter/__init__.py +0 -0
  13. boltz/data/filter/dynamic/__init__.py +0 -0
  14. boltz/data/filter/dynamic/date.py +76 -0
  15. boltz/data/filter/dynamic/filter.py +24 -0
  16. boltz/data/filter/dynamic/max_residues.py +37 -0
  17. boltz/data/filter/dynamic/resolution.py +34 -0
  18. boltz/data/filter/dynamic/size.py +38 -0
  19. boltz/data/filter/dynamic/subset.py +42 -0
  20. boltz/data/filter/static/__init__.py +0 -0
  21. boltz/data/filter/static/filter.py +26 -0
  22. boltz/data/filter/static/ligand.py +37 -0
  23. boltz/data/filter/static/polymer.py +299 -0
  24. boltz/data/module/__init__.py +0 -0
  25. boltz/data/module/inference.py +307 -0
  26. boltz/data/module/inferencev2.py +429 -0
  27. boltz/data/module/training.py +684 -0
  28. boltz/data/module/trainingv2.py +660 -0
  29. boltz/data/mol.py +900 -0
  30. boltz/data/msa/__init__.py +0 -0
  31. boltz/data/msa/mmseqs2.py +235 -0
  32. boltz/data/pad.py +84 -0
  33. boltz/data/parse/__init__.py +0 -0
  34. boltz/data/parse/a3m.py +134 -0
  35. boltz/data/parse/csv.py +100 -0
  36. boltz/data/parse/fasta.py +138 -0
  37. boltz/data/parse/mmcif.py +1239 -0
  38. boltz/data/parse/mmcif_with_constraints.py +1607 -0
  39. boltz/data/parse/schema.py +1851 -0
  40. boltz/data/parse/yaml.py +68 -0
  41. boltz/data/sample/__init__.py +0 -0
  42. boltz/data/sample/cluster.py +283 -0
  43. boltz/data/sample/distillation.py +57 -0
  44. boltz/data/sample/random.py +39 -0
  45. boltz/data/sample/sampler.py +49 -0
  46. boltz/data/tokenize/__init__.py +0 -0
  47. boltz/data/tokenize/boltz.py +195 -0
  48. boltz/data/tokenize/boltz2.py +396 -0
  49. boltz/data/tokenize/tokenizer.py +24 -0
  50. boltz/data/types.py +777 -0
  51. boltz/data/write/__init__.py +0 -0
  52. boltz/data/write/mmcif.py +305 -0
  53. boltz/data/write/pdb.py +171 -0
  54. boltz/data/write/utils.py +23 -0
  55. boltz/data/write/writer.py +330 -0
  56. boltz/main.py +1292 -0
  57. boltz/model/__init__.py +0 -0
  58. boltz/model/layers/__init__.py +0 -0
  59. boltz/model/layers/attention.py +132 -0
  60. boltz/model/layers/attentionv2.py +111 -0
  61. boltz/model/layers/confidence_utils.py +231 -0
  62. boltz/model/layers/dropout.py +34 -0
  63. boltz/model/layers/initialize.py +100 -0
  64. boltz/model/layers/outer_product_mean.py +98 -0
  65. boltz/model/layers/pair_averaging.py +135 -0
  66. boltz/model/layers/pairformer.py +337 -0
  67. boltz/model/layers/relative.py +58 -0
  68. boltz/model/layers/transition.py +78 -0
  69. boltz/model/layers/triangular_attention/__init__.py +0 -0
  70. boltz/model/layers/triangular_attention/attention.py +189 -0
  71. boltz/model/layers/triangular_attention/primitives.py +409 -0
  72. boltz/model/layers/triangular_attention/utils.py +380 -0
  73. boltz/model/layers/triangular_mult.py +212 -0
  74. boltz/model/loss/__init__.py +0 -0
  75. boltz/model/loss/bfactor.py +49 -0
  76. boltz/model/loss/confidence.py +590 -0
  77. boltz/model/loss/confidencev2.py +621 -0
  78. boltz/model/loss/diffusion.py +171 -0
  79. boltz/model/loss/diffusionv2.py +134 -0
  80. boltz/model/loss/distogram.py +48 -0
  81. boltz/model/loss/distogramv2.py +105 -0
  82. boltz/model/loss/validation.py +1025 -0
  83. boltz/model/models/__init__.py +0 -0
  84. boltz/model/models/boltz1.py +1286 -0
  85. boltz/model/models/boltz2.py +1249 -0
  86. boltz/model/modules/__init__.py +0 -0
  87. boltz/model/modules/affinity.py +223 -0
  88. boltz/model/modules/confidence.py +481 -0
  89. boltz/model/modules/confidence_utils.py +181 -0
  90. boltz/model/modules/confidencev2.py +495 -0
  91. boltz/model/modules/diffusion.py +844 -0
  92. boltz/model/modules/diffusion_conditioning.py +116 -0
  93. boltz/model/modules/diffusionv2.py +677 -0
  94. boltz/model/modules/encoders.py +639 -0
  95. boltz/model/modules/encodersv2.py +565 -0
  96. boltz/model/modules/transformers.py +322 -0
  97. boltz/model/modules/transformersv2.py +261 -0
  98. boltz/model/modules/trunk.py +688 -0
  99. boltz/model/modules/trunkv2.py +828 -0
  100. boltz/model/modules/utils.py +303 -0
  101. boltz/model/optim/__init__.py +0 -0
  102. boltz/model/optim/ema.py +389 -0
  103. boltz/model/optim/scheduler.py +99 -0
  104. boltz/model/potentials/__init__.py +0 -0
  105. boltz/model/potentials/potentials.py +497 -0
  106. boltz/model/potentials/schedules.py +32 -0
  107. boltz_vsynthes-1.0.0.dist-info/METADATA +151 -0
  108. boltz_vsynthes-1.0.0.dist-info/RECORD +112 -0
  109. boltz_vsynthes-1.0.0.dist-info/WHEEL +5 -0
  110. boltz_vsynthes-1.0.0.dist-info/entry_points.txt +2 -0
  111. boltz_vsynthes-1.0.0.dist-info/licenses/LICENSE +21 -0
  112. boltz_vsynthes-1.0.0.dist-info/top_level.txt +1 -0
boltz/data/module/trainingv2.py
@@ -0,0 +1,660 @@
+ from dataclasses import dataclass
+ from pathlib import Path
+ from typing import Optional
+
+ import numpy as np
+ import pytorch_lightning as pl
+ import torch
+ from torch import Tensor
+ from torch.utils.data import DataLoader
+
+ from boltz.data.crop.cropper import Cropper
+ from boltz.data.feature.featurizer import BoltzFeaturizer
+ from boltz.data.feature.symmetry import get_symmetries
+ from boltz.data.filter.dynamic.filter import DynamicFilter
+ from boltz.data.pad import pad_to_max
+ from boltz.data.sample.sampler import Sample, Sampler
+ from boltz.data.tokenize.tokenizer import Tokenizer
+ from boltz.data.types import MSA, Connection, Input, Manifest, Record, Structure
+
+
+ @dataclass
+ class DatasetConfig:
+     """Dataset configuration."""
+
+     target_dir: str
+     msa_dir: str
+     prob: float
+     sampler: Sampler
+     cropper: Cropper
+     filters: Optional[list] = None
+     split: Optional[str] = None
+     manifest_path: Optional[str] = None
+
+
+ @dataclass
+ class DataConfig:
+     """Data configuration."""
+
+     datasets: list[DatasetConfig]
+     filters: list[DynamicFilter]
+     featurizer: BoltzFeaturizer
+     tokenizer: Tokenizer
+     max_atoms: int
+     max_tokens: int
+     max_seqs: int
+     samples_per_epoch: int
+     batch_size: int
+     num_workers: int
+     random_seed: int
+     pin_memory: bool
+     symmetries: str
+     atoms_per_window_queries: int
+     min_dist: float
+     max_dist: float
+     num_bins: int
+     overfit: Optional[int] = None
+     pad_to_max_tokens: bool = False
+     pad_to_max_atoms: bool = False
+     pad_to_max_seqs: bool = False
+     crop_validation: bool = False
+     return_train_symmetries: bool = False
+     return_val_symmetries: bool = True
+     train_binder_pocket_conditioned_prop: float = 0.0
+     val_binder_pocket_conditioned_prop: float = 0.0
+     binder_pocket_cutoff: float = 6.0
+     binder_pocket_sampling_geometric_p: float = 0.0
+     val_batch_size: int = 1
+
+
+ @dataclass
+ class Dataset:
+     """Data holder."""
+
+     target_dir: Path
+     msa_dir: Path
+     manifest: Manifest
+     prob: float
+     sampler: Sampler
+     cropper: Cropper
+     tokenizer: Tokenizer
+     featurizer: BoltzFeaturizer
+
+
+ def load_input(record: Record, target_dir: Path, msa_dir: Path) -> Input:
+     """Load the given input data.
+
+     Parameters
+     ----------
+     record : Record
+         The record to load.
+     target_dir : Path
+         The path to the data directory.
+     msa_dir : Path
+         The path to the MSA directory.
+
+     Returns
+     -------
+     Input
+         The loaded input.
+
+     """
+     # Load the structure
+     structure = np.load(target_dir / "structures" / f"{record.id}.npz")
+     structure = Structure(
+         atoms=structure["atoms"],
+         bonds=structure["bonds"],
+         residues=structure["residues"],
+         chains=structure["chains"],
+         connections=structure["connections"].astype(Connection),
+         interfaces=structure["interfaces"],
+         mask=structure["mask"],
+     )
+
+     msas = {}
+     for chain in record.chains:
+         msa_id = chain.msa_id
+         # Load the MSA for this chain, if any
+         if msa_id != -1 and msa_id != "":
+             msa = np.load(msa_dir / f"{msa_id}.npz")
+             msas[chain.chain_id] = MSA(**msa)
+
+     return Input(structure, msas)
+
+
+ def collate(data: list[dict[str, Tensor]]) -> dict[str, Tensor]:
+     """Collate the data.
+
+     Parameters
+     ----------
+     data : list[dict[str, Tensor]]
+         The data to collate.
+
+     Returns
+     -------
+     dict[str, Tensor]
+         The collated data.
+
+     """
+     # Get the keys
+     keys = data[0].keys()
+
+     # Collate the data
+     collated = {}
+     for key in keys:
+         values = [d[key] for d in data]
+
+         if key not in [
+             "all_coords",
+             "all_resolved_mask",
+             "crop_to_all_atom_map",
+             "chain_symmetries",
+             "amino_acids_symmetries",
+             "ligand_symmetries",
+         ]:
+             # Check if all have the same shape
+             shape = values[0].shape
+             if not all(v.shape == shape for v in values):
+                 values, _ = pad_to_max(values, 0)
+             else:
+                 values = torch.stack(values, dim=0)
+
+         # Store the values (excluded keys are kept as lists)
+         collated[key] = values
+
+     return collated
+
+
+ class TrainingDataset(torch.utils.data.Dataset):
+     """Training dataset."""
+
+     def __init__(
+         self,
+         datasets: list[Dataset],
+         samples_per_epoch: int,
+         symmetries: dict,
+         max_atoms: int,
+         max_tokens: int,
+         max_seqs: int,
+         pad_to_max_atoms: bool = False,
+         pad_to_max_tokens: bool = False,
+         pad_to_max_seqs: bool = False,
+         atoms_per_window_queries: int = 32,
+         min_dist: float = 2.0,
+         max_dist: float = 22.0,
+         num_bins: int = 64,
+         overfit: Optional[int] = None,
+         binder_pocket_conditioned_prop: Optional[float] = 0.0,
+         binder_pocket_cutoff: Optional[float] = 6.0,
+         binder_pocket_sampling_geometric_p: Optional[float] = 0.0,
+         return_symmetries: Optional[bool] = False,
+     ) -> None:
+         """Initialize the training dataset."""
+         super().__init__()
+         self.datasets = datasets
+         self.probs = [d.prob for d in datasets]
+         self.samples_per_epoch = samples_per_epoch
+         self.symmetries = symmetries
+         self.max_tokens = max_tokens
+         self.max_seqs = max_seqs
+         self.max_atoms = max_atoms
+         self.pad_to_max_tokens = pad_to_max_tokens
+         self.pad_to_max_atoms = pad_to_max_atoms
+         self.pad_to_max_seqs = pad_to_max_seqs
+         self.atoms_per_window_queries = atoms_per_window_queries
+         self.min_dist = min_dist
+         self.max_dist = max_dist
+         self.num_bins = num_bins
+         self.binder_pocket_conditioned_prop = binder_pocket_conditioned_prop
+         self.binder_pocket_cutoff = binder_pocket_cutoff
+         self.binder_pocket_sampling_geometric_p = binder_pocket_sampling_geometric_p
+         self.return_symmetries = return_symmetries
+
+         # Build a sample iterator for each dataset
+         self.samples = []
+         for dataset in datasets:
+             records = dataset.manifest.records
+             if overfit is not None:
+                 records = records[:overfit]
+             iterator = dataset.sampler.sample(records, np.random)
+             self.samples.append(iterator)
+
+     def __getitem__(self, idx: int) -> dict[str, Tensor]:
+         """Get an item from the dataset.
+
+         Parameters
+         ----------
+         idx : int
+             The data index.
+
+         Returns
+         -------
+         dict[str, Tensor]
+             The sampled data features.
+
+         """
+         # Pick a random dataset
+         dataset_idx = np.random.choice(
+             len(self.datasets),
+             p=self.probs,
+         )
+         dataset = self.datasets[dataset_idx]
+
+         # Get a sample from the dataset
+         sample: Sample = next(self.samples[dataset_idx])
+
+         # Get the structure
+         try:
+             input_data = load_input(sample.record, dataset.target_dir, dataset.msa_dir)
+         except Exception as e:
+             print(
+                 f"Failed to load input for {sample.record.id} with error {e}. Skipping."
+             )
+             return self.__getitem__(idx)
+
+         # Tokenize structure
+         try:
+             tokenized = dataset.tokenizer.tokenize(input_data)
+         except Exception as e:
+             print(f"Tokenizer failed on {sample.record.id} with error {e}. Skipping.")
+             return self.__getitem__(idx)
+
+         # Compute crop
+         try:
+             if self.max_tokens is not None:
+                 tokenized = dataset.cropper.crop(
+                     tokenized,
+                     max_atoms=self.max_atoms,
+                     max_tokens=self.max_tokens,
+                     random=np.random,
+                     chain_id=sample.chain_id,
+                     interface_id=sample.interface_id,
+                 )
+         except Exception as e:
+             print(f"Cropper failed on {sample.record.id} with error {e}. Skipping.")
+             return self.__getitem__(idx)
+
+         # Check if there are tokens
+         if len(tokenized.tokens) == 0:
+             msg = "No tokens in cropped structure."
+             raise ValueError(msg)
+
+         # Compute features
+         try:
+             features = dataset.featurizer.process(
+                 tokenized,
+                 training=True,
+                 max_atoms=self.max_atoms if self.pad_to_max_atoms else None,
+                 max_tokens=self.max_tokens if self.pad_to_max_tokens else None,
+                 max_seqs=self.max_seqs,
+                 pad_to_max_seqs=self.pad_to_max_seqs,
+                 symmetries=self.symmetries,
+                 atoms_per_window_queries=self.atoms_per_window_queries,
+                 min_dist=self.min_dist,
+                 max_dist=self.max_dist,
+                 num_bins=self.num_bins,
+                 compute_symmetries=self.return_symmetries,
+                 binder_pocket_conditioned_prop=self.binder_pocket_conditioned_prop,
+                 binder_pocket_cutoff=self.binder_pocket_cutoff,
+                 binder_pocket_sampling_geometric_p=self.binder_pocket_sampling_geometric_p,
+             )
+         except Exception as e:
+             print(f"Featurizer failed on {sample.record.id} with error {e}. Skipping.")
+             return self.__getitem__(idx)
+
+         return features
+
+     def __len__(self) -> int:
+         """Get the length of the dataset.
+
+         Returns
+         -------
+         int
+             The length of the dataset.
+
+         """
+         return self.samples_per_epoch
+
+
+ class ValidationDataset(torch.utils.data.Dataset):
+     """Validation dataset."""
+
+     def __init__(
+         self,
+         datasets: list[Dataset],
+         seed: int,
+         symmetries: dict,
+         max_atoms: Optional[int] = None,
+         max_tokens: Optional[int] = None,
+         max_seqs: Optional[int] = None,
+         pad_to_max_atoms: bool = False,
+         pad_to_max_tokens: bool = False,
+         pad_to_max_seqs: bool = False,
+         atoms_per_window_queries: int = 32,
+         min_dist: float = 2.0,
+         max_dist: float = 22.0,
+         num_bins: int = 64,
+         overfit: Optional[int] = None,
+         crop_validation: bool = False,
+         return_symmetries: Optional[bool] = False,
+         binder_pocket_conditioned_prop: Optional[float] = 0.0,
+         binder_pocket_cutoff: Optional[float] = 6.0,
+     ) -> None:
+         """Initialize the validation dataset."""
+         super().__init__()
+         self.datasets = datasets
+         self.max_atoms = max_atoms
+         self.max_tokens = max_tokens
+         self.max_seqs = max_seqs
+         self.seed = seed
+         self.symmetries = symmetries
+         self.random = np.random if overfit else np.random.RandomState(self.seed)
+         self.pad_to_max_tokens = pad_to_max_tokens
+         self.pad_to_max_atoms = pad_to_max_atoms
+         self.pad_to_max_seqs = pad_to_max_seqs
+         self.overfit = overfit
+         self.crop_validation = crop_validation
+         self.atoms_per_window_queries = atoms_per_window_queries
+         self.min_dist = min_dist
+         self.max_dist = max_dist
+         self.num_bins = num_bins
+         self.return_symmetries = return_symmetries
+         self.binder_pocket_conditioned_prop = binder_pocket_conditioned_prop
+         self.binder_pocket_cutoff = binder_pocket_cutoff
+
+     def __getitem__(self, idx: int) -> dict[str, Tensor]:
+         """Get an item from the dataset.
+
+         Parameters
+         ----------
+         idx : int
+             The data index.
+
+         Returns
+         -------
+         dict[str, Tensor]
+             The sampled data features.
+
+         """
+         # Map the global index to a dataset and a local index
+         for dataset in self.datasets:
+             size = len(dataset.manifest.records)
+             if self.overfit is not None:
+                 size = min(size, self.overfit)
+             if idx < size:
+                 break
+             idx -= size
+
+         # Get a sample from the dataset
+         record = dataset.manifest.records[idx]
+
+         # Get the structure
+         try:
+             input_data = load_input(record, dataset.target_dir, dataset.msa_dir)
+         except Exception as e:
+             print(f"Failed to load input for {record.id} with error {e}. Skipping.")
+             return self.__getitem__(0)
+
+         # Tokenize structure
+         try:
+             tokenized = dataset.tokenizer.tokenize(input_data)
+         except Exception as e:
+             print(f"Tokenizer failed on {record.id} with error {e}. Skipping.")
+             return self.__getitem__(0)
+
+         # Compute crop
+         try:
+             if self.crop_validation and (self.max_tokens is not None):
+                 tokenized = dataset.cropper.crop(
+                     tokenized,
+                     max_tokens=self.max_tokens,
+                     random=self.random,
+                     max_atoms=self.max_atoms,
+                 )
+         except Exception as e:
+             print(f"Cropper failed on {record.id} with error {e}. Skipping.")
+             return self.__getitem__(0)
+
+         # Check if there are tokens
+         if len(tokenized.tokens) == 0:
+             msg = "No tokens in cropped structure."
+             raise ValueError(msg)
+
+         # Compute features
+         try:
+             pad_atoms = self.crop_validation and self.pad_to_max_atoms
+             pad_tokens = self.crop_validation and self.pad_to_max_tokens
+
+             features = dataset.featurizer.process(
+                 tokenized,
+                 training=False,
+                 max_atoms=self.max_atoms if pad_atoms else None,
+                 max_tokens=self.max_tokens if pad_tokens else None,
+                 max_seqs=self.max_seqs,
+                 pad_to_max_seqs=self.pad_to_max_seqs,
+                 symmetries=self.symmetries,
+                 atoms_per_window_queries=self.atoms_per_window_queries,
+                 min_dist=self.min_dist,
+                 max_dist=self.max_dist,
+                 num_bins=self.num_bins,
+                 compute_symmetries=self.return_symmetries,
+                 binder_pocket_conditioned_prop=self.binder_pocket_conditioned_prop,
+                 binder_pocket_cutoff=self.binder_pocket_cutoff,
+                 binder_pocket_sampling_geometric_p=1.0,  # samples a single pocket token
+                 only_ligand_binder_pocket=True,
+             )
+         except Exception as e:
+             print(f"Featurizer failed on {record.id} with error {e}. Skipping.")
+             return self.__getitem__(0)
+
+         return features
+
+     def __len__(self) -> int:
+         """Get the length of the dataset.
+
+         Returns
+         -------
+         int
+             The length of the dataset.
+
+         """
+         if self.overfit is not None:
+             length = sum(len(d.manifest.records[: self.overfit]) for d in self.datasets)
+         else:
+             length = sum(len(d.manifest.records) for d in self.datasets)
+
+         return length
+
+
+ class BoltzTrainingDataModule(pl.LightningDataModule):
+     """DataModule for boltz."""
+
+     def __init__(self, cfg: DataConfig) -> None:
+         """Initialize the DataModule.
+
+         Parameters
+         ----------
+         cfg : DataConfig
+             The data configuration.
+
+         """
+         super().__init__()
+         self.cfg = cfg
+
+         assert self.cfg.val_batch_size == 1, "Validation only works with batch size=1."
+
+         # Load symmetries
+         symmetries = get_symmetries(cfg.symmetries)
+
+         # Load datasets
+         train: list[Dataset] = []
+         val: list[Dataset] = []
+
+         for data_config in cfg.datasets:
+             # Set the target and MSA directories
+             target_dir = Path(data_config.target_dir)
+             msa_dir = Path(data_config.msa_dir)
+
+             # Load the manifest
+             if data_config.manifest_path is not None:
+                 path = Path(data_config.manifest_path)
+             else:
+                 path = target_dir / "manifest.json"
+             manifest: Manifest = Manifest.load(path)
+
+             # Split records if a split file is given
+             if data_config.split is not None:
+                 with Path(data_config.split).open("r") as f:
+                     split = {x.lower() for x in f.read().splitlines()}
+
+                 train_records = []
+                 val_records = []
+                 for record in manifest.records:
+                     if record.id.lower() in split:
+                         val_records.append(record)
+                     else:
+                         train_records.append(record)
+             else:
+                 train_records = manifest.records
+                 val_records = []
+
+             # Apply the global filters to the training records
+             train_records = [
+                 record
+                 for record in train_records
+                 if all(f.filter(record) for f in cfg.filters)
+             ]
+
+             # Apply the dataset-specific filters, if any
+             if data_config.filters is not None:
+                 train_records = [
+                     record
+                     for record in train_records
+                     if all(f.filter(record) for f in data_config.filters)
+                 ]
+
+             # Create the train dataset
+             train_manifest = Manifest(train_records)
+             train.append(
+                 Dataset(
+                     target_dir,
+                     msa_dir,
+                     train_manifest,
+                     data_config.prob,
+                     data_config.sampler,
+                     data_config.cropper,
+                     cfg.tokenizer,
+                     cfg.featurizer,
+                 )
+             )
+
+             # Create the validation dataset
+             if val_records:
+                 val_manifest = Manifest(val_records)
+                 val.append(
+                     Dataset(
+                         target_dir,
+                         msa_dir,
+                         val_manifest,
+                         data_config.prob,
+                         data_config.sampler,
+                         data_config.cropper,
+                         cfg.tokenizer,
+                         cfg.featurizer,
+                     )
+                 )
+
+         # Print dataset sizes
+         for dataset in train:
+             print(f"Training dataset size: {len(dataset.manifest.records)}")
+
+         for dataset in val:
+             print(f"Validation dataset size: {len(dataset.manifest.records)}")
+
+         # Create wrapper datasets
+         self._train_set = TrainingDataset(
+             datasets=train,
+             samples_per_epoch=cfg.samples_per_epoch,
+             max_atoms=cfg.max_atoms,
+             max_tokens=cfg.max_tokens,
+             max_seqs=cfg.max_seqs,
+             pad_to_max_atoms=cfg.pad_to_max_atoms,
+             pad_to_max_tokens=cfg.pad_to_max_tokens,
+             pad_to_max_seqs=cfg.pad_to_max_seqs,
+             symmetries=symmetries,
+             atoms_per_window_queries=cfg.atoms_per_window_queries,
+             min_dist=cfg.min_dist,
+             max_dist=cfg.max_dist,
+             num_bins=cfg.num_bins,
+             overfit=cfg.overfit,
+             binder_pocket_conditioned_prop=cfg.train_binder_pocket_conditioned_prop,
+             binder_pocket_cutoff=cfg.binder_pocket_cutoff,
+             binder_pocket_sampling_geometric_p=cfg.binder_pocket_sampling_geometric_p,
+             return_symmetries=cfg.return_train_symmetries,
+         )
+         self._val_set = ValidationDataset(
+             datasets=train if cfg.overfit is not None else val,
+             seed=cfg.random_seed,
+             max_atoms=cfg.max_atoms,
+             max_tokens=cfg.max_tokens,
+             max_seqs=cfg.max_seqs,
+             pad_to_max_atoms=cfg.pad_to_max_atoms,
+             pad_to_max_tokens=cfg.pad_to_max_tokens,
+             pad_to_max_seqs=cfg.pad_to_max_seqs,
+             symmetries=symmetries,
+             atoms_per_window_queries=cfg.atoms_per_window_queries,
+             min_dist=cfg.min_dist,
+             max_dist=cfg.max_dist,
+             num_bins=cfg.num_bins,
+             overfit=cfg.overfit,
+             crop_validation=cfg.crop_validation,
+             return_symmetries=cfg.return_val_symmetries,
+             binder_pocket_conditioned_prop=cfg.val_binder_pocket_conditioned_prop,
+             binder_pocket_cutoff=cfg.binder_pocket_cutoff,
+         )
+
+     def setup(self, stage: Optional[str] = None) -> None:
+         """Run the setup for the DataModule.
+
+         Parameters
+         ----------
+         stage : str, optional
+             The stage, one of 'fit', 'validate', 'test'.
+
+         """
+         return
+
+     def train_dataloader(self) -> DataLoader:
+         """Get the training dataloader.
+
+         Returns
+         -------
+         DataLoader
+             The training dataloader.
+
+         """
+         return DataLoader(
+             self._train_set,
+             batch_size=self.cfg.batch_size,
+             num_workers=self.cfg.num_workers,
+             pin_memory=self.cfg.pin_memory,
+             shuffle=False,
+             collate_fn=collate,
+         )
+
+     def val_dataloader(self) -> DataLoader:
+         """Get the validation dataloader.
+
+         Returns
+         -------
+         DataLoader
+             The validation dataloader.
+
+         """
+         return DataLoader(
+             self._val_set,
+             batch_size=self.cfg.val_batch_size,
+             num_workers=self.cfg.num_workers,
+             pin_memory=self.cfg.pin_memory,
+             shuffle=False,
+             collate_fn=collate,
+         )
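
For orientation, below is a minimal sketch of how the pieces in this file compose. DatasetConfig, DataConfig, and BoltzTrainingDataModule come from the file above; the imported sampler, cropper, and tokenizer class names (RandomSampler, BoltzCropper, BoltzTokenizer) are guesses based on the module paths in the file list and may differ in the actual package, and all paths and numeric settings are illustrative placeholders, not recommended values.

# Hypothetical wiring sketch; class names and constructor signatures below
# are assumptions inferred from the file list, not verified package API.
from boltz.data.crop.boltz import BoltzCropper          # assumed class name
from boltz.data.feature.featurizer import BoltzFeaturizer
from boltz.data.sample.random import RandomSampler      # assumed class name
from boltz.data.tokenize.boltz import BoltzTokenizer    # assumed class name

# One entry per source dataset; `prob` weights the random dataset choice
# made in TrainingDataset.__getitem__.
dataset = DatasetConfig(
    target_dir="/data/targets",  # expects structures/<record_id>.npz and manifest.json
    msa_dir="/data/msas",        # expects <msa_id>.npz per chain with an MSA
    prob=1.0,
    sampler=RandomSampler(),
    cropper=BoltzCropper(),
)

cfg = DataConfig(
    datasets=[dataset],
    filters=[],                   # global DynamicFilter instances, applied to train records
    featurizer=BoltzFeaturizer(),
    tokenizer=BoltzTokenizer(),
    max_atoms=4608,               # illustrative crop/pad budgets
    max_tokens=512,
    max_seqs=2048,
    samples_per_epoch=10_000,     # TrainingDataset.__len__ returns this
    batch_size=1,
    num_workers=4,
    random_seed=42,
    pin_memory=True,
    symmetries="/data/symmetries.pkl",  # path handed to get_symmetries
    atoms_per_window_queries=32,
    min_dist=2.0,
    max_dist=22.0,
    num_bins=64,
)

datamodule = BoltzTrainingDataModule(cfg)
batch = next(iter(datamodule.train_dataloader()))  # dict[str, Tensor] via collate()

Note that with no `split` file and `overfit` unset, every record lands in the training set and `val` stays empty, so the validation dataloader yields zero batches; supplying a split file of record IDs routes those records to validation instead.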