boltz-vsynthes 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (112):
  1. boltz/__init__.py +7 -0
  2. boltz/data/__init__.py +0 -0
  3. boltz/data/const.py +1184 -0
  4. boltz/data/crop/__init__.py +0 -0
  5. boltz/data/crop/affinity.py +164 -0
  6. boltz/data/crop/boltz.py +296 -0
  7. boltz/data/crop/cropper.py +45 -0
  8. boltz/data/feature/__init__.py +0 -0
  9. boltz/data/feature/featurizer.py +1230 -0
  10. boltz/data/feature/featurizerv2.py +2208 -0
  11. boltz/data/feature/symmetry.py +602 -0
  12. boltz/data/filter/__init__.py +0 -0
  13. boltz/data/filter/dynamic/__init__.py +0 -0
  14. boltz/data/filter/dynamic/date.py +76 -0
  15. boltz/data/filter/dynamic/filter.py +24 -0
  16. boltz/data/filter/dynamic/max_residues.py +37 -0
  17. boltz/data/filter/dynamic/resolution.py +34 -0
  18. boltz/data/filter/dynamic/size.py +38 -0
  19. boltz/data/filter/dynamic/subset.py +42 -0
  20. boltz/data/filter/static/__init__.py +0 -0
  21. boltz/data/filter/static/filter.py +26 -0
  22. boltz/data/filter/static/ligand.py +37 -0
  23. boltz/data/filter/static/polymer.py +299 -0
  24. boltz/data/module/__init__.py +0 -0
  25. boltz/data/module/inference.py +307 -0
  26. boltz/data/module/inferencev2.py +429 -0
  27. boltz/data/module/training.py +684 -0
  28. boltz/data/module/trainingv2.py +660 -0
  29. boltz/data/mol.py +900 -0
  30. boltz/data/msa/__init__.py +0 -0
  31. boltz/data/msa/mmseqs2.py +235 -0
  32. boltz/data/pad.py +84 -0
  33. boltz/data/parse/__init__.py +0 -0
  34. boltz/data/parse/a3m.py +134 -0
  35. boltz/data/parse/csv.py +100 -0
  36. boltz/data/parse/fasta.py +138 -0
  37. boltz/data/parse/mmcif.py +1239 -0
  38. boltz/data/parse/mmcif_with_constraints.py +1607 -0
  39. boltz/data/parse/schema.py +1851 -0
  40. boltz/data/parse/yaml.py +68 -0
  41. boltz/data/sample/__init__.py +0 -0
  42. boltz/data/sample/cluster.py +283 -0
  43. boltz/data/sample/distillation.py +57 -0
  44. boltz/data/sample/random.py +39 -0
  45. boltz/data/sample/sampler.py +49 -0
  46. boltz/data/tokenize/__init__.py +0 -0
  47. boltz/data/tokenize/boltz.py +195 -0
  48. boltz/data/tokenize/boltz2.py +396 -0
  49. boltz/data/tokenize/tokenizer.py +24 -0
  50. boltz/data/types.py +777 -0
  51. boltz/data/write/__init__.py +0 -0
  52. boltz/data/write/mmcif.py +305 -0
  53. boltz/data/write/pdb.py +171 -0
  54. boltz/data/write/utils.py +23 -0
  55. boltz/data/write/writer.py +330 -0
  56. boltz/main.py +1292 -0
  57. boltz/model/__init__.py +0 -0
  58. boltz/model/layers/__init__.py +0 -0
  59. boltz/model/layers/attention.py +132 -0
  60. boltz/model/layers/attentionv2.py +111 -0
  61. boltz/model/layers/confidence_utils.py +231 -0
  62. boltz/model/layers/dropout.py +34 -0
  63. boltz/model/layers/initialize.py +100 -0
  64. boltz/model/layers/outer_product_mean.py +98 -0
  65. boltz/model/layers/pair_averaging.py +135 -0
  66. boltz/model/layers/pairformer.py +337 -0
  67. boltz/model/layers/relative.py +58 -0
  68. boltz/model/layers/transition.py +78 -0
  69. boltz/model/layers/triangular_attention/__init__.py +0 -0
  70. boltz/model/layers/triangular_attention/attention.py +189 -0
  71. boltz/model/layers/triangular_attention/primitives.py +409 -0
  72. boltz/model/layers/triangular_attention/utils.py +380 -0
  73. boltz/model/layers/triangular_mult.py +212 -0
  74. boltz/model/loss/__init__.py +0 -0
  75. boltz/model/loss/bfactor.py +49 -0
  76. boltz/model/loss/confidence.py +590 -0
  77. boltz/model/loss/confidencev2.py +621 -0
  78. boltz/model/loss/diffusion.py +171 -0
  79. boltz/model/loss/diffusionv2.py +134 -0
  80. boltz/model/loss/distogram.py +48 -0
  81. boltz/model/loss/distogramv2.py +105 -0
  82. boltz/model/loss/validation.py +1025 -0
  83. boltz/model/models/__init__.py +0 -0
  84. boltz/model/models/boltz1.py +1286 -0
  85. boltz/model/models/boltz2.py +1249 -0
  86. boltz/model/modules/__init__.py +0 -0
  87. boltz/model/modules/affinity.py +223 -0
  88. boltz/model/modules/confidence.py +481 -0
  89. boltz/model/modules/confidence_utils.py +181 -0
  90. boltz/model/modules/confidencev2.py +495 -0
  91. boltz/model/modules/diffusion.py +844 -0
  92. boltz/model/modules/diffusion_conditioning.py +116 -0
  93. boltz/model/modules/diffusionv2.py +677 -0
  94. boltz/model/modules/encoders.py +639 -0
  95. boltz/model/modules/encodersv2.py +565 -0
  96. boltz/model/modules/transformers.py +322 -0
  97. boltz/model/modules/transformersv2.py +261 -0
  98. boltz/model/modules/trunk.py +688 -0
  99. boltz/model/modules/trunkv2.py +828 -0
  100. boltz/model/modules/utils.py +303 -0
  101. boltz/model/optim/__init__.py +0 -0
  102. boltz/model/optim/ema.py +389 -0
  103. boltz/model/optim/scheduler.py +99 -0
  104. boltz/model/potentials/__init__.py +0 -0
  105. boltz/model/potentials/potentials.py +497 -0
  106. boltz/model/potentials/schedules.py +32 -0
  107. boltz_vsynthes-1.0.0.dist-info/METADATA +151 -0
  108. boltz_vsynthes-1.0.0.dist-info/RECORD +112 -0
  109. boltz_vsynthes-1.0.0.dist-info/WHEEL +5 -0
  110. boltz_vsynthes-1.0.0.dist-info/entry_points.txt +2 -0
  111. boltz_vsynthes-1.0.0.dist-info/licenses/LICENSE +21 -0
  112. boltz_vsynthes-1.0.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,307 @@
1
+ from pathlib import Path
2
+ from typing import Optional
3
+
4
+ import numpy as np
5
+ import pytorch_lightning as pl
6
+ import torch
7
+ from torch import Tensor
8
+ from torch.utils.data import DataLoader
9
+
10
+ from boltz.data import const
11
+ from boltz.data.feature.featurizer import BoltzFeaturizer
12
+ from boltz.data.pad import pad_to_max
13
+ from boltz.data.tokenize.boltz import BoltzTokenizer
14
+ from boltz.data.types import (
15
+ MSA,
16
+ Connection,
17
+ Input,
18
+ Manifest,
19
+ Record,
20
+ ResidueConstraints,
21
+ Structure,
22
+ )
23
+
24
+
25
def load_input(
    record: Record,
    target_dir: Path,
    msa_dir: Path,
    constraints_dir: Optional[Path] = None,
) -> Input:
    """Load the given input data.

    Parameters
    ----------
    record : Record
        The record to load.
    target_dir : Path
        The path to the data directory; the structure is read from
        ``{record.id}.npz`` inside it.
    msa_dir : Path
        The path to msa directory; each chain's MSA is read from
        ``{msa_id}.npz`` inside it.
    constraints_dir : Optional[Path]
        The path to the residue constraints directory. When given, the
        constraints are read from ``{record.id}.npz`` inside it; when None,
        no constraints are loaded.

    Returns
    -------
    Input
        The loaded input.

    """
    # Load the structure. The NpzFile is used as a context manager so its
    # file handle is closed once the arrays below have been read into memory.
    with np.load(target_dir / f"{record.id}.npz") as data:
        structure = Structure(
            atoms=data["atoms"],
            bonds=data["bonds"],
            residues=data["residues"],
            chains=data["chains"],
            connections=data["connections"].astype(Connection),
            interfaces=data["interfaces"],
            mask=data["mask"],
        )

    msas = {}
    for chain in record.chains:
        msa_id = chain.msa_id
        # Load the MSA for this chain, if any (an msa_id of -1 is skipped).
        if msa_id != -1:
            # ``**msa`` reads every array eagerly, so closing the file
            # afterwards is safe.
            with np.load(msa_dir / f"{msa_id}.npz") as msa:
                msas[chain.chain_id] = MSA(**msa)

    residue_constraints = None
    if constraints_dir is not None:
        residue_constraints = ResidueConstraints.load(
            constraints_dir / f"{record.id}.npz"
        )

    return Input(structure, msas, record, residue_constraints)
75
+
76
+
77
def collate(data: list[dict[str, Tensor]]) -> dict[str, Tensor]:
    """Merge a list of per-sample feature dicts into one batch dict.

    Keys are taken from the first sample, in its order. For most keys the
    per-sample tensors are stacked along a new leading dimension; when their
    shapes disagree they are padded with ``pad_to_max`` instead. A fixed set
    of keys (symmetry data, full-atom coordinates, the record) is left as a
    plain list of the per-sample values.

    Parameters
    ----------
    data : List[Dict[str, Tensor]]
        The per-sample feature dictionaries.

    Returns
    -------
    Dict[str, Tensor]
        The batched features (pass-through keys hold lists).

    """
    passthrough = (
        "all_coords",
        "all_resolved_mask",
        "crop_to_all_atom_map",
        "chain_symmetries",
        "amino_acids_symmetries",
        "ligand_symmetries",
        "record",
    )

    def _merge(name):
        column = [sample[name] for sample in data]
        if name in passthrough:
            return column
        reference = column[0].shape
        if any(item.shape != reference for item in column):
            padded, _ = pad_to_max(column, 0)
            return padded
        return torch.stack(column, dim=0)

    return {name: _merge(name) for name in data[0].keys()}
119
+
120
+
121
class PredictionDataset(torch.utils.data.Dataset):
    """Map-style dataset that turns manifest records into inference features."""

    def __init__(
        self,
        manifest: Manifest,
        target_dir: Path,
        msa_dir: Path,
        constraints_dir: Optional[Path] = None,
    ) -> None:
        """Initialize the prediction dataset.

        Parameters
        ----------
        manifest : Manifest
            The manifest to load data from.
        target_dir : Path
            The path to the target directory.
        msa_dir : Path
            The path to the msa directory.
        constraints_dir : Optional[Path]
            The path to the residue constraints directory, or None to skip
            loading constraints.

        """
        super().__init__()
        self.manifest = manifest
        self.target_dir = target_dir
        self.msa_dir = msa_dir
        self.constraints_dir = constraints_dir
        self.tokenizer = BoltzTokenizer()
        self.featurizer = BoltzFeaturizer()

    def __getitem__(self, idx: int) -> dict:
        """Get an item from the dataset.

        If record ``idx`` fails to load, tokenize or featurize, the failure is
        printed and the features of record 0 are returned in its place.

        Parameters
        ----------
        idx : int
            Index of the record in the manifest.

        Returns
        -------
        Dict[str, Tensor]
            The sampled data features, with the source record under "record".

        Raises
        ------
        RuntimeError
            If neither record ``idx`` nor the fallback record 0 can be
            processed.

        """
        features = self._process(self.manifest.records[idx])
        if features is not None:
            return features

        # Fall back to the first record, but only once: if record 0 itself
        # fails, retrying it would recurse without end.
        # NOTE(review): the fallback sample carries record 0, so downstream
        # consumers see record 0 a second time; confirm this is intended.
        if idx != 0:
            features = self._process(self.manifest.records[0])
            if features is not None:
                return features

        msg = f"Could not process record at index {idx} nor fallback record 0."
        raise RuntimeError(msg)

    def _process(self, record: Record) -> Optional[dict]:
        """Load, tokenize and featurize one record; return None on failure."""
        # Get the structure
        try:
            input_data = load_input(
                record,
                self.target_dir,
                self.msa_dir,
                self.constraints_dir,
            )
        except Exception as e:  # noqa: BLE001
            print(f"Failed to load input for {record.id} with error {e}. Skipping.")  # noqa: T201
            return None

        # Tokenize structure
        try:
            tokenized = self.tokenizer.tokenize(input_data)
        except Exception as e:  # noqa: BLE001
            print(f"Tokenizer failed on {record.id} with error {e}. Skipping.")  # noqa: T201
            return None

        # Inference specific options: only the first pocket constraint is used.
        options = record.inference_options
        if options is None or len(options.pocket_constraints) == 0:
            binder, pocket = None, None
        else:
            binder, pocket = options.pocket_constraints[0][0], options.pocket_constraints[0][1]

        # Compute features
        try:
            features = self.featurizer.process(
                tokenized,
                training=False,
                max_atoms=None,
                max_tokens=None,
                max_seqs=const.max_msa_seqs,
                pad_to_max_seqs=False,
                symmetries={},
                compute_symmetries=False,
                inference_binder=binder,
                inference_pocket=pocket,
                compute_constraint_features=True,
            )
        except Exception as e:  # noqa: BLE001
            print(f"Featurizer failed on {record.id} with error {e}. Skipping.")  # noqa: T201
            return None

        features["record"] = record
        return features

    def __len__(self) -> int:
        """Get the length of the dataset.

        Returns
        -------
        int
            The number of records in the manifest.

        """
        return len(self.manifest.records)
221
+
222
+
223
class BoltzInferenceDataModule(pl.LightningDataModule):
    """DataModule for Boltz inference."""

    # Batch entries that are not moved to the device in
    # ``transfer_batch_to_device``; they may not be tensors (e.g. "record").
    _HOST_ONLY_KEYS = frozenset(
        {
            "all_coords",
            "all_resolved_mask",
            "crop_to_all_atom_map",
            "chain_symmetries",
            "amino_acids_symmetries",
            "ligand_symmetries",
            "record",
        }
    )

    def __init__(
        self,
        manifest: Manifest,
        target_dir: Path,
        msa_dir: Path,
        num_workers: int,
        constraints_dir: Optional[Path] = None,
    ) -> None:
        """Initialize the DataModule.

        Parameters
        ----------
        manifest : Manifest
            The manifest listing the records to predict.
        target_dir : Path
            Directory containing the per-record structure ``.npz`` files.
        msa_dir : Path
            Directory containing the MSA ``.npz`` files.
        num_workers : int
            Number of DataLoader worker processes.
        constraints_dir : Optional[Path]
            Directory containing residue constraint files, or None to skip
            loading constraints.

        """
        super().__init__()
        self.num_workers = num_workers
        self.manifest = manifest
        self.target_dir = target_dir
        self.msa_dir = msa_dir
        self.constraints_dir = constraints_dir

    def predict_dataloader(self) -> DataLoader:
        """Get the prediction dataloader.

        Returns
        -------
        DataLoader
            An unshuffled dataloader with batch size 1 over the manifest.

        """
        dataset = PredictionDataset(
            manifest=self.manifest,
            target_dir=self.target_dir,
            msa_dir=self.msa_dir,
            constraints_dir=self.constraints_dir,
        )
        return DataLoader(
            dataset,
            batch_size=1,
            num_workers=self.num_workers,
            pin_memory=True,
            shuffle=False,
            collate_fn=collate,
        )

    def transfer_batch_to_device(
        self,
        batch: dict,
        device: torch.device,
        dataloader_idx: int,  # noqa: ARG002
    ) -> dict:
        """Transfer a batch to the given device.

        Every entry except the host-only keys is moved with ``.to(device)``;
        the batch dict is updated in place and returned.

        Parameters
        ----------
        batch : Dict
            The batch to transfer.
        device : torch.device
            The device to transfer to.
        dataloader_idx : int
            The dataloader index (unused).

        Returns
        -------
        Dict
            The transferred batch.

        """
        for key in batch:
            if key not in self._HOST_ONLY_KEYS:
                batch[key] = batch[key].to(device)
        return batch