boltz-vsynthes 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (112):
  1. boltz/__init__.py +7 -0
  2. boltz/data/__init__.py +0 -0
  3. boltz/data/const.py +1184 -0
  4. boltz/data/crop/__init__.py +0 -0
  5. boltz/data/crop/affinity.py +164 -0
  6. boltz/data/crop/boltz.py +296 -0
  7. boltz/data/crop/cropper.py +45 -0
  8. boltz/data/feature/__init__.py +0 -0
  9. boltz/data/feature/featurizer.py +1230 -0
  10. boltz/data/feature/featurizerv2.py +2208 -0
  11. boltz/data/feature/symmetry.py +602 -0
  12. boltz/data/filter/__init__.py +0 -0
  13. boltz/data/filter/dynamic/__init__.py +0 -0
  14. boltz/data/filter/dynamic/date.py +76 -0
  15. boltz/data/filter/dynamic/filter.py +24 -0
  16. boltz/data/filter/dynamic/max_residues.py +37 -0
  17. boltz/data/filter/dynamic/resolution.py +34 -0
  18. boltz/data/filter/dynamic/size.py +38 -0
  19. boltz/data/filter/dynamic/subset.py +42 -0
  20. boltz/data/filter/static/__init__.py +0 -0
  21. boltz/data/filter/static/filter.py +26 -0
  22. boltz/data/filter/static/ligand.py +37 -0
  23. boltz/data/filter/static/polymer.py +299 -0
  24. boltz/data/module/__init__.py +0 -0
  25. boltz/data/module/inference.py +307 -0
  26. boltz/data/module/inferencev2.py +429 -0
  27. boltz/data/module/training.py +684 -0
  28. boltz/data/module/trainingv2.py +660 -0
  29. boltz/data/mol.py +900 -0
  30. boltz/data/msa/__init__.py +0 -0
  31. boltz/data/msa/mmseqs2.py +235 -0
  32. boltz/data/pad.py +84 -0
  33. boltz/data/parse/__init__.py +0 -0
  34. boltz/data/parse/a3m.py +134 -0
  35. boltz/data/parse/csv.py +100 -0
  36. boltz/data/parse/fasta.py +138 -0
  37. boltz/data/parse/mmcif.py +1239 -0
  38. boltz/data/parse/mmcif_with_constraints.py +1607 -0
  39. boltz/data/parse/schema.py +1851 -0
  40. boltz/data/parse/yaml.py +68 -0
  41. boltz/data/sample/__init__.py +0 -0
  42. boltz/data/sample/cluster.py +283 -0
  43. boltz/data/sample/distillation.py +57 -0
  44. boltz/data/sample/random.py +39 -0
  45. boltz/data/sample/sampler.py +49 -0
  46. boltz/data/tokenize/__init__.py +0 -0
  47. boltz/data/tokenize/boltz.py +195 -0
  48. boltz/data/tokenize/boltz2.py +396 -0
  49. boltz/data/tokenize/tokenizer.py +24 -0
  50. boltz/data/types.py +777 -0
  51. boltz/data/write/__init__.py +0 -0
  52. boltz/data/write/mmcif.py +305 -0
  53. boltz/data/write/pdb.py +171 -0
  54. boltz/data/write/utils.py +23 -0
  55. boltz/data/write/writer.py +330 -0
  56. boltz/main.py +1292 -0
  57. boltz/model/__init__.py +0 -0
  58. boltz/model/layers/__init__.py +0 -0
  59. boltz/model/layers/attention.py +132 -0
  60. boltz/model/layers/attentionv2.py +111 -0
  61. boltz/model/layers/confidence_utils.py +231 -0
  62. boltz/model/layers/dropout.py +34 -0
  63. boltz/model/layers/initialize.py +100 -0
  64. boltz/model/layers/outer_product_mean.py +98 -0
  65. boltz/model/layers/pair_averaging.py +135 -0
  66. boltz/model/layers/pairformer.py +337 -0
  67. boltz/model/layers/relative.py +58 -0
  68. boltz/model/layers/transition.py +78 -0
  69. boltz/model/layers/triangular_attention/__init__.py +0 -0
  70. boltz/model/layers/triangular_attention/attention.py +189 -0
  71. boltz/model/layers/triangular_attention/primitives.py +409 -0
  72. boltz/model/layers/triangular_attention/utils.py +380 -0
  73. boltz/model/layers/triangular_mult.py +212 -0
  74. boltz/model/loss/__init__.py +0 -0
  75. boltz/model/loss/bfactor.py +49 -0
  76. boltz/model/loss/confidence.py +590 -0
  77. boltz/model/loss/confidencev2.py +621 -0
  78. boltz/model/loss/diffusion.py +171 -0
  79. boltz/model/loss/diffusionv2.py +134 -0
  80. boltz/model/loss/distogram.py +48 -0
  81. boltz/model/loss/distogramv2.py +105 -0
  82. boltz/model/loss/validation.py +1025 -0
  83. boltz/model/models/__init__.py +0 -0
  84. boltz/model/models/boltz1.py +1286 -0
  85. boltz/model/models/boltz2.py +1249 -0
  86. boltz/model/modules/__init__.py +0 -0
  87. boltz/model/modules/affinity.py +223 -0
  88. boltz/model/modules/confidence.py +481 -0
  89. boltz/model/modules/confidence_utils.py +181 -0
  90. boltz/model/modules/confidencev2.py +495 -0
  91. boltz/model/modules/diffusion.py +844 -0
  92. boltz/model/modules/diffusion_conditioning.py +116 -0
  93. boltz/model/modules/diffusionv2.py +677 -0
  94. boltz/model/modules/encoders.py +639 -0
  95. boltz/model/modules/encodersv2.py +565 -0
  96. boltz/model/modules/transformers.py +322 -0
  97. boltz/model/modules/transformersv2.py +261 -0
  98. boltz/model/modules/trunk.py +688 -0
  99. boltz/model/modules/trunkv2.py +828 -0
  100. boltz/model/modules/utils.py +303 -0
  101. boltz/model/optim/__init__.py +0 -0
  102. boltz/model/optim/ema.py +389 -0
  103. boltz/model/optim/scheduler.py +99 -0
  104. boltz/model/potentials/__init__.py +0 -0
  105. boltz/model/potentials/potentials.py +497 -0
  106. boltz/model/potentials/schedules.py +32 -0
  107. boltz_vsynthes-1.0.0.dist-info/METADATA +151 -0
  108. boltz_vsynthes-1.0.0.dist-info/RECORD +112 -0
  109. boltz_vsynthes-1.0.0.dist-info/WHEEL +5 -0
  110. boltz_vsynthes-1.0.0.dist-info/entry_points.txt +2 -0
  111. boltz_vsynthes-1.0.0.dist-info/licenses/LICENSE +21 -0
  112. boltz_vsynthes-1.0.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,330 @@
1
+ import json
2
+ from dataclasses import asdict, replace
3
+ from pathlib import Path
4
+ from typing import Literal
5
+
6
+ import numpy as np
7
+ import torch
8
+ from pytorch_lightning import LightningModule, Trainer
9
+ from pytorch_lightning.callbacks import BasePredictionWriter
10
+ from torch import Tensor
11
+
12
+ from boltz.data.types import Coords, Interface, Record, Structure, StructureV2
13
+ from boltz.data.write.mmcif import to_mmcif
14
+ from boltz.data.write.pdb import to_pdb
15
+
16
+
17
class BoltzWriter(BasePredictionWriter):
    """Prediction writer that saves predicted structures and confidence outputs.

    For every record in a batch a sub-directory ``<output_dir>/<record.id>`` is
    created, and each sampled model is written as
    ``<record.id>_model_<rank>.{pdb,cif}``. ``rank`` 0 is the model with the
    highest ``confidence_score`` when that score is present in the prediction;
    otherwise models keep their sampling order.
    """

    # Scalar confidence metrics copied into the per-model JSON summary, in the
    # order they appear in the output file.
    _CONFIDENCE_KEYS = (
        "confidence_score",
        "ptm",
        "iptm",
        "ligand_iptm",
        "protein_iptm",
        "complex_plddt",
        "complex_iplddt",
        "complex_pde",
        "complex_ipde",
    )

    def __init__(
        self,
        data_dir: str,
        output_dir: str,
        output_format: Literal["pdb", "mmcif"] = "mmcif",
        boltz2: bool = False,
    ) -> None:
        """Initialize the writer.

        Parameters
        ----------
        data_dir : str
            Directory containing the input ``<record.id>.npz`` structures.
        output_dir : str
            The directory to save the predictions (created if missing).
        output_format : {"pdb", "mmcif"}
            File format used for the written structures.
        boltz2 : bool
            If True, load and write ``StructureV2`` objects (with a separate
            ``coords`` table) instead of ``Structure``.

        Raises
        ------
        ValueError
            If ``output_format`` is neither ``"pdb"`` nor ``"mmcif"``.

        """
        super().__init__(write_interval="batch")
        if output_format not in ["pdb", "mmcif"]:
            msg = f"Invalid output format: {output_format}"
            raise ValueError(msg)

        self.data_dir = Path(data_dir)
        self.output_dir = Path(output_dir)
        self.output_format = output_format
        self.failed = 0
        self.boltz2 = boltz2
        self.output_dir.mkdir(parents=True, exist_ok=True)

    def write_on_batch_end(
        self,
        trainer: Trainer,  # noqa: ARG002
        pl_module: LightningModule,  # noqa: ARG002
        prediction: dict[str, Tensor],
        batch_indices: list[int],  # noqa: ARG002
        batch: dict[str, Tensor],
        batch_idx: int,  # noqa: ARG002
        dataloader_idx: int,  # noqa: ARG002
    ) -> None:
        """Write the predictions to disk.

        Failed predictions (``prediction["exception"]`` truthy) are only
        counted in ``self.failed``; nothing is written for them.
        """
        if prediction["exception"]:
            self.failed += 1
            return

        records: list[Record] = batch["record"]

        # Add a leading batch axis so the zip below pairs a record with all of
        # its sampled models. NOTE(review): this assumes one record per batch,
        # with prediction["coords"] shaped (num_models, num_atoms, 3) — confirm.
        coords = prediction["coords"].unsqueeze(0)
        pad_masks = prediction["masks"]
        num_models = coords.shape[1]

        # Map model index -> output rank (0 = most confident). The fallback
        # must cover every model index, not the number of records.
        if "confidence_score" in prediction:
            argsort = torch.argsort(prediction["confidence_score"], descending=True)
            idx_to_rank = {idx.item(): rank for rank, idx in enumerate(argsort)}
        else:
            idx_to_rank = {i: i for i in range(num_models)}

        for record, coord, pad_mask in zip(records, coords, pad_masks):
            path = self.data_dir / f"{record.id}.npz"
            if self.boltz2:
                structure = StructureV2.load(path)
            else:
                structure = Structure.load(path)

            # Drop masked chains completely before writing.
            structure = structure.remove_invalid_chains()

            for model_idx in range(coord.shape[0]):
                rank = idx_to_rank[model_idx]

                # Keep only the non-padded atoms of this model.
                coord_unpad = coord[model_idx][pad_mask.bool()].cpu().numpy()

                # New atom table (modified in place on the loaded structure).
                atoms = structure.atoms
                atoms["coords"] = coord_unpad
                atoms["is_present"] = True

                # New residue table.
                residues = structure.residues
                residues["is_present"] = True

                # Interfaces are not written; replace with an empty table.
                interfaces = np.array([], dtype=Interface)
                if self.boltz2:
                    # StructureV2 stores coordinates in a separate Coords table.
                    coords_table = np.array(
                        [(x,) for x in coord_unpad], dtype=Coords
                    )
                    new_structure = replace(
                        structure,
                        atoms=atoms,
                        residues=residues,
                        interfaces=interfaces,
                        coords=coords_table,
                    )
                else:
                    new_structure = replace(
                        structure,
                        atoms=atoms,
                        residues=residues,
                        interfaces=interfaces,
                    )

                struct_dir = self.output_dir / record.id
                struct_dir.mkdir(exist_ok=True)

                plddts = prediction["plddt"][model_idx] if "plddt" in prediction else None

                outname = f"{record.id}_model_{rank}"
                if self.output_format == "pdb":
                    path = struct_dir / f"{outname}.pdb"
                    with path.open("w") as f:
                        f.write(
                            to_pdb(new_structure, plddts=plddts, boltz2=self.boltz2)
                        )
                elif self.output_format == "mmcif":
                    path = struct_dir / f"{outname}.cif"
                    with path.open("w") as f:
                        f.write(
                            to_mmcif(new_structure, plddts=plddts, boltz2=self.boltz2)
                        )
                else:
                    # Not reachable through __init__ (format is validated);
                    # kept as a raw-array fallback.
                    path = struct_dir / f"{outname}.npz"
                    np.savez_compressed(path, **asdict(new_structure))

                # Save the top-ranked structure as input for affinity prediction.
                if self.boltz2 and record.affinity and rank == 0:
                    path = struct_dir / f"pre_affinity_{record.id}.npz"
                    np.savez_compressed(path, **asdict(new_structure))

                self._write_confidence(prediction, struct_dir, record.id, model_idx, rank)

    def _write_confidence(
        self,
        prediction: dict[str, Tensor],
        struct_dir: Path,
        record_id: str,
        model_idx: int,
        rank: int,
    ) -> None:
        """Save the confidence JSON and plddt/pae/pde arrays for one model.

        The JSON summary and the plddt array are written only when ``"plddt"``
        is in ``prediction``; pae and pde are each written when present.
        """
        suffix = f"{record_id}_model_{rank}"

        if "plddt" in prediction:
            summary = {
                key: prediction[key][model_idx].item()
                for key in self._CONFIDENCE_KEYS
            }
            pair_chains_iptm = prediction["pair_chains_iptm"]
            # Per-chain pTM is read from the diagonal of the chain-pair matrix.
            summary["chains_ptm"] = {
                idx: pair_chains_iptm[idx][idx][model_idx].item()
                for idx in pair_chains_iptm
            }
            summary["pair_chains_iptm"] = {
                idx1: {
                    idx2: pair_chains_iptm[idx1][idx2][model_idx].item()
                    for idx2 in pair_chains_iptm[idx1]
                }
                for idx1 in pair_chains_iptm
            }
            path = struct_dir / f"confidence_{suffix}.json"
            with path.open("w") as f:
                f.write(json.dumps(summary, indent=4))

            plddt = prediction["plddt"][model_idx]
            np.savez_compressed(
                struct_dir / f"plddt_{suffix}.npz", plddt=plddt.cpu().numpy()
            )

        for key in ("pae", "pde"):
            if key in prediction:
                values = prediction[key][model_idx]
                np.savez_compressed(
                    struct_dir / f"{key}_{suffix}.npz",
                    **{key: values.cpu().numpy()},
                )

    def on_predict_epoch_end(
        self,
        trainer: Trainer,  # noqa: ARG002
        pl_module: LightningModule,  # noqa: ARG002
    ) -> None:
        """Print the number of failed examples."""
        print(f"Number of failed examples: {self.failed}")  # noqa: T201
255
+
256
+
257
class BoltzAffinityWriter(BasePredictionWriter):
    """Prediction writer that saves affinity predictions as a JSON summary.

    The summary is written to ``<output_dir>/<record.id>/affinity_<record.id>.json``.
    """

    def __init__(
        self,
        data_dir: str,
        output_dir: str,
    ) -> None:
        """Initialize the writer.

        Parameters
        ----------
        data_dir : str
            Input data directory. Stored on the writer but not read by it.
        output_dir : str
            The directory to save the predictions (created if missing).

        """
        super().__init__(write_interval="batch")
        self.failed = 0
        self.data_dir = Path(data_dir)
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)

    def write_on_batch_end(
        self,
        trainer: Trainer,  # noqa: ARG002
        pl_module: LightningModule,  # noqa: ARG002
        prediction: dict[str, Tensor],
        batch_indices: list[int],  # noqa: ARG002
        batch: dict[str, Tensor],
        batch_idx: int,  # noqa: ARG002
        dataloader_idx: int,  # noqa: ARG002
    ) -> None:
        """Write the affinity summary of the batch to disk.

        Failed predictions (``prediction["exception"]`` truthy) are only
        counted in ``self.failed``.
        """
        if prediction["exception"]:
            self.failed += 1
            return

        affinity_summary = {
            "affinity_pred_value": prediction["affinity_pred_value"].item(),
            "affinity_probability_binary": prediction[
                "affinity_probability_binary"
            ].item(),
        }
        # Optional numbered outputs; when "affinity_pred_value1" is present the
        # "...1" and "...2" value/probability keys are all expected.
        if "affinity_pred_value1" in prediction:
            for suffix in ("1", "2"):
                value_key = f"affinity_pred_value{suffix}"
                prob_key = f"affinity_probability_binary{suffix}"
                affinity_summary[value_key] = prediction[value_key].item()
                affinity_summary[prob_key] = prediction[prob_key].item()

        # NOTE(review): only the first record of the batch is written; this
        # assumes affinity batches hold a single record — confirm.
        record_id = batch["record"][0].id
        struct_dir = self.output_dir / record_id
        struct_dir.mkdir(exist_ok=True)
        path = struct_dir / f"affinity_{record_id}.json"

        with path.open("w") as f:
            f.write(json.dumps(affinity_summary, indent=4))

    def on_predict_epoch_end(
        self,
        trainer: Trainer,  # noqa: ARG002
        pl_module: LightningModule,  # noqa: ARG002
    ) -> None:
        """Print the number of failed examples."""
        print(f"Number of failed examples: {self.failed}")  # noqa: T201