boltz-vsynthes 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- boltz/__init__.py +7 -0
- boltz/data/__init__.py +0 -0
- boltz/data/const.py +1184 -0
- boltz/data/crop/__init__.py +0 -0
- boltz/data/crop/affinity.py +164 -0
- boltz/data/crop/boltz.py +296 -0
- boltz/data/crop/cropper.py +45 -0
- boltz/data/feature/__init__.py +0 -0
- boltz/data/feature/featurizer.py +1230 -0
- boltz/data/feature/featurizerv2.py +2208 -0
- boltz/data/feature/symmetry.py +602 -0
- boltz/data/filter/__init__.py +0 -0
- boltz/data/filter/dynamic/__init__.py +0 -0
- boltz/data/filter/dynamic/date.py +76 -0
- boltz/data/filter/dynamic/filter.py +24 -0
- boltz/data/filter/dynamic/max_residues.py +37 -0
- boltz/data/filter/dynamic/resolution.py +34 -0
- boltz/data/filter/dynamic/size.py +38 -0
- boltz/data/filter/dynamic/subset.py +42 -0
- boltz/data/filter/static/__init__.py +0 -0
- boltz/data/filter/static/filter.py +26 -0
- boltz/data/filter/static/ligand.py +37 -0
- boltz/data/filter/static/polymer.py +299 -0
- boltz/data/module/__init__.py +0 -0
- boltz/data/module/inference.py +307 -0
- boltz/data/module/inferencev2.py +429 -0
- boltz/data/module/training.py +684 -0
- boltz/data/module/trainingv2.py +660 -0
- boltz/data/mol.py +900 -0
- boltz/data/msa/__init__.py +0 -0
- boltz/data/msa/mmseqs2.py +235 -0
- boltz/data/pad.py +84 -0
- boltz/data/parse/__init__.py +0 -0
- boltz/data/parse/a3m.py +134 -0
- boltz/data/parse/csv.py +100 -0
- boltz/data/parse/fasta.py +138 -0
- boltz/data/parse/mmcif.py +1239 -0
- boltz/data/parse/mmcif_with_constraints.py +1607 -0
- boltz/data/parse/schema.py +1851 -0
- boltz/data/parse/yaml.py +68 -0
- boltz/data/sample/__init__.py +0 -0
- boltz/data/sample/cluster.py +283 -0
- boltz/data/sample/distillation.py +57 -0
- boltz/data/sample/random.py +39 -0
- boltz/data/sample/sampler.py +49 -0
- boltz/data/tokenize/__init__.py +0 -0
- boltz/data/tokenize/boltz.py +195 -0
- boltz/data/tokenize/boltz2.py +396 -0
- boltz/data/tokenize/tokenizer.py +24 -0
- boltz/data/types.py +777 -0
- boltz/data/write/__init__.py +0 -0
- boltz/data/write/mmcif.py +305 -0
- boltz/data/write/pdb.py +171 -0
- boltz/data/write/utils.py +23 -0
- boltz/data/write/writer.py +330 -0
- boltz/main.py +1292 -0
- boltz/model/__init__.py +0 -0
- boltz/model/layers/__init__.py +0 -0
- boltz/model/layers/attention.py +132 -0
- boltz/model/layers/attentionv2.py +111 -0
- boltz/model/layers/confidence_utils.py +231 -0
- boltz/model/layers/dropout.py +34 -0
- boltz/model/layers/initialize.py +100 -0
- boltz/model/layers/outer_product_mean.py +98 -0
- boltz/model/layers/pair_averaging.py +135 -0
- boltz/model/layers/pairformer.py +337 -0
- boltz/model/layers/relative.py +58 -0
- boltz/model/layers/transition.py +78 -0
- boltz/model/layers/triangular_attention/__init__.py +0 -0
- boltz/model/layers/triangular_attention/attention.py +189 -0
- boltz/model/layers/triangular_attention/primitives.py +409 -0
- boltz/model/layers/triangular_attention/utils.py +380 -0
- boltz/model/layers/triangular_mult.py +212 -0
- boltz/model/loss/__init__.py +0 -0
- boltz/model/loss/bfactor.py +49 -0
- boltz/model/loss/confidence.py +590 -0
- boltz/model/loss/confidencev2.py +621 -0
- boltz/model/loss/diffusion.py +171 -0
- boltz/model/loss/diffusionv2.py +134 -0
- boltz/model/loss/distogram.py +48 -0
- boltz/model/loss/distogramv2.py +105 -0
- boltz/model/loss/validation.py +1025 -0
- boltz/model/models/__init__.py +0 -0
- boltz/model/models/boltz1.py +1286 -0
- boltz/model/models/boltz2.py +1249 -0
- boltz/model/modules/__init__.py +0 -0
- boltz/model/modules/affinity.py +223 -0
- boltz/model/modules/confidence.py +481 -0
- boltz/model/modules/confidence_utils.py +181 -0
- boltz/model/modules/confidencev2.py +495 -0
- boltz/model/modules/diffusion.py +844 -0
- boltz/model/modules/diffusion_conditioning.py +116 -0
- boltz/model/modules/diffusionv2.py +677 -0
- boltz/model/modules/encoders.py +639 -0
- boltz/model/modules/encodersv2.py +565 -0
- boltz/model/modules/transformers.py +322 -0
- boltz/model/modules/transformersv2.py +261 -0
- boltz/model/modules/trunk.py +688 -0
- boltz/model/modules/trunkv2.py +828 -0
- boltz/model/modules/utils.py +303 -0
- boltz/model/optim/__init__.py +0 -0
- boltz/model/optim/ema.py +389 -0
- boltz/model/optim/scheduler.py +99 -0
- boltz/model/potentials/__init__.py +0 -0
- boltz/model/potentials/potentials.py +497 -0
- boltz/model/potentials/schedules.py +32 -0
- boltz_vsynthes-1.0.0.dist-info/METADATA +151 -0
- boltz_vsynthes-1.0.0.dist-info/RECORD +112 -0
- boltz_vsynthes-1.0.0.dist-info/WHEEL +5 -0
- boltz_vsynthes-1.0.0.dist-info/entry_points.txt +2 -0
- boltz_vsynthes-1.0.0.dist-info/licenses/LICENSE +21 -0
- boltz_vsynthes-1.0.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,330 @@
|
|
1
|
+
import json
|
2
|
+
from dataclasses import asdict, replace
|
3
|
+
from pathlib import Path
|
4
|
+
from typing import Literal
|
5
|
+
|
6
|
+
import numpy as np
|
7
|
+
import torch
|
8
|
+
from pytorch_lightning import LightningModule, Trainer
|
9
|
+
from pytorch_lightning.callbacks import BasePredictionWriter
|
10
|
+
from torch import Tensor
|
11
|
+
|
12
|
+
from boltz.data.types import Coords, Interface, Record, Structure, StructureV2
|
13
|
+
from boltz.data.write.mmcif import to_mmcif
|
14
|
+
from boltz.data.write.pdb import to_pdb
|
15
|
+
|
16
|
+
|
17
|
+
class BoltzWriter(BasePredictionWriter):
    """Prediction writer that saves structures and confidence outputs per batch."""

    def __init__(
        self,
        data_dir: str,
        output_dir: str,
        output_format: Literal["pdb", "mmcif"] = "mmcif",
        boltz2: bool = False,
    ) -> None:
        """Initialize the writer.

        Parameters
        ----------
        data_dir : str
            The directory containing the input ``<record.id>.npz`` structures.
        output_dir : str
            The directory to save the predictions. It is created if missing.
        output_format : {"pdb", "mmcif"}
            The file format used for the predicted structures.
        boltz2 : bool
            If True, inputs are loaded as ``StructureV2`` and the flag is
            forwarded to the structure serializers.

        Raises
        ------
        ValueError
            If ``output_format`` is neither ``"pdb"`` nor ``"mmcif"``.

        """
        super().__init__(write_interval="batch")
        if output_format not in ["pdb", "mmcif"]:
            msg = f"Invalid output format: {output_format}"
            raise ValueError(msg)

        self.data_dir = Path(data_dir)
        self.output_dir = Path(output_dir)
        self.output_format = output_format
        self.failed = 0
        self.boltz2 = boltz2
        self.output_dir.mkdir(parents=True, exist_ok=True)

    def write_on_batch_end(
        self,
        trainer: Trainer,  # noqa: ARG002
        pl_module: LightningModule,  # noqa: ARG002
        prediction: dict[str, Tensor],
        batch_indices: list[int],  # noqa: ARG002
        batch: dict[str, Tensor],
        batch_idx: int,  # noqa: ARG002
        dataloader_idx: int,  # noqa: ARG002
    ) -> None:
        """Write the predictions to disk.

        For every record of the batch and every predicted model, the
        structure is written to ``output_dir/<record.id>/``. When the
        corresponding keys are present in ``prediction``, the confidence
        summary (JSON) and the pLDDT, PAE and PDE arrays (NPZ) are written
        next to it. A prediction whose ``"exception"`` entry is truthy is
        only counted in ``self.failed``; nothing is written for it.
        """
        if prediction["exception"]:
            self.failed += 1
            return

        # Get the records
        records: list[Record] = batch["record"]

        # Get the predictions. The extra leading axis lets the zip below pair
        # each record with its own stack of models.
        # NOTE(review): presumably the batch holds a single record -- confirm.
        coords = prediction["coords"]
        coords = coords.unsqueeze(0)

        pad_masks = prediction["masks"]

        # Get ranking
        if "confidence_score" in prediction:
            argsort = torch.argsort(prediction["confidence_score"], descending=True)
            idx_to_rank = {idx.item(): rank for rank, idx in enumerate(argsort)}
        else:
            # Handles cases where confidence summary is False: models keep
            # their original order. The rank is resolved per model index in
            # the loop below, because the number of models (coord.shape[0])
            # is unrelated to len(records); a table sized by the records
            # raised KeyError for every model index >= len(records).
            idx_to_rank = None

        # Iterate over the records
        for record, coord, pad_mask in zip(records, coords, pad_masks):
            # Load the structure
            path = self.data_dir / f"{record.id}.npz"
            if self.boltz2:
                structure = StructureV2.load(path)
            else:
                structure = Structure.load(path)

            # Compute chain map with masked removed, to be used later
            chain_map = {}
            for i, mask in enumerate(structure.mask):
                if mask:
                    chain_map[len(chain_map)] = i

            # Remove masked chains completely
            structure = structure.remove_invalid_chains()

            for model_idx in range(coord.shape[0]):
                # Get model coord
                model_coord = coord[model_idx]
                # Unpad
                coord_unpad = model_coord[pad_mask.bool()]
                coord_unpad = coord_unpad.cpu().numpy()

                # New atom table (the structure's atom array is updated in
                # place and overwritten again for the next model)
                atoms = structure.atoms
                atoms["coords"] = coord_unpad
                atoms["is_present"] = True
                if self.boltz2:
                    # Re-pack the coordinates as records of the Coords dtype
                    coord_unpad = [(x,) for x in coord_unpad]
                    coord_unpad = np.array(coord_unpad, dtype=Coords)

                # New residue table
                residues = structure.residues
                residues["is_present"] = True

                # Update the structure
                interfaces = np.array([], dtype=Interface)
                if self.boltz2:
                    new_structure = replace(
                        structure,
                        atoms=atoms,
                        residues=residues,
                        interfaces=interfaces,
                        coords=coord_unpad,
                    )
                else:
                    new_structure = replace(
                        structure,
                        atoms=atoms,
                        residues=residues,
                        interfaces=interfaces,
                    )

                # Update chain info
                # NOTE(review): chain_info is built but not consumed anywhere
                # in this method; kept because the lookups fail loudly when
                # the chain map and the record disagree.
                chain_info = []
                for chain in new_structure.chains:
                    old_chain_idx = chain_map[chain["asym_id"]]
                    old_chain_info = record.chains[old_chain_idx]
                    new_chain_info = replace(
                        old_chain_info,
                        chain_id=int(chain["asym_id"]),
                        valid=True,
                    )
                    chain_info.append(new_chain_info)

                # Save the structure
                struct_dir = self.output_dir / record.id
                struct_dir.mkdir(exist_ok=True)

                # Get plddt's
                plddts = None
                if "plddt" in prediction:
                    plddts = prediction["plddt"][model_idx]

                # Create path name
                rank = model_idx if idx_to_rank is None else idx_to_rank[model_idx]
                outname = f"{record.id}_model_{rank}"

                # Save the structure
                if self.output_format == "pdb":
                    path = struct_dir / f"{outname}.pdb"
                    with path.open("w") as f:
                        f.write(
                            to_pdb(new_structure, plddts=plddts, boltz2=self.boltz2)
                        )
                elif self.output_format == "mmcif":
                    path = struct_dir / f"{outname}.cif"
                    with path.open("w") as f:
                        f.write(
                            to_mmcif(new_structure, plddts=plddts, boltz2=self.boltz2)
                        )
                else:
                    path = struct_dir / f"{outname}.npz"
                    np.savez_compressed(path, **asdict(new_structure))

                # Keep the top-ranked structure for records with affinity set
                if self.boltz2 and record.affinity and rank == 0:
                    path = struct_dir / f"pre_affinity_{record.id}.npz"
                    np.savez_compressed(path, **asdict(new_structure))

                # Save confidence summary
                if "plddt" in prediction:
                    path = struct_dir / f"confidence_{record.id}_model_{rank}.json"
                    confidence_summary_dict = {}
                    for key in [
                        "confidence_score",
                        "ptm",
                        "iptm",
                        "ligand_iptm",
                        "protein_iptm",
                        "complex_plddt",
                        "complex_iplddt",
                        "complex_pde",
                        "complex_ipde",
                    ]:
                        confidence_summary_dict[key] = prediction[key][model_idx].item()
                    # Per-chain values are the diagonal of the pairwise table
                    confidence_summary_dict["chains_ptm"] = {
                        idx: prediction["pair_chains_iptm"][idx][idx][model_idx].item()
                        for idx in prediction["pair_chains_iptm"]
                    }
                    confidence_summary_dict["pair_chains_iptm"] = {
                        idx1: {
                            idx2: prediction["pair_chains_iptm"][idx1][idx2][
                                model_idx
                            ].item()
                            for idx2 in prediction["pair_chains_iptm"][idx1]
                        }
                        for idx1 in prediction["pair_chains_iptm"]
                    }
                    with path.open("w") as f:
                        f.write(
                            json.dumps(
                                confidence_summary_dict,
                                indent=4,
                            )
                        )

                    # Save plddt
                    plddt = prediction["plddt"][model_idx]
                    path = struct_dir / f"plddt_{record.id}_model_{rank}.npz"
                    np.savez_compressed(path, plddt=plddt.cpu().numpy())

                # Save pae
                if "pae" in prediction:
                    pae = prediction["pae"][model_idx]
                    path = struct_dir / f"pae_{record.id}_model_{rank}.npz"
                    np.savez_compressed(path, pae=pae.cpu().numpy())

                # Save pde
                if "pde" in prediction:
                    pde = prediction["pde"][model_idx]
                    path = struct_dir / f"pde_{record.id}_model_{rank}.npz"
                    np.savez_compressed(path, pde=pde.cpu().numpy())

    def on_predict_epoch_end(
        self,
        trainer: Trainer,  # noqa: ARG002
        pl_module: LightningModule,  # noqa: ARG002
    ) -> None:
        """Print the number of failed examples."""
        # Print number of failed examples
        print(f"Number of failed examples: {self.failed}")  # noqa: T201
|
255
|
+
|
256
|
+
|
257
|
+
class BoltzAffinityWriter(BasePredictionWriter):
    """Prediction writer that stores affinity results as one JSON file per record."""

    def __init__(
        self,
        data_dir: str,
        output_dir: str,
    ) -> None:
        """Set up the writer and make sure the output directory exists.

        Parameters
        ----------
        data_dir : str
            The input data directory; kept on the instance as a ``Path``.
        output_dir : str
            The directory to save the predictions.

        """
        super().__init__(write_interval="batch")
        self.failed = 0
        self.data_dir = Path(data_dir)
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)

    def write_on_batch_end(
        self,
        trainer: Trainer,  # noqa: ARG002
        pl_module: LightningModule,  # noqa: ARG002
        prediction: dict[str, Tensor],
        batch_indices: list[int],  # noqa: ARG002
        batch: dict[str, Tensor],
        batch_idx: int,  # noqa: ARG002
        dataloader_idx: int,  # noqa: ARG002
    ) -> None:
        """Dump the affinity summary of the batch as JSON.

        The file is written to
        ``output_dir/<record.id>/affinity_<record.id>.json`` using the first
        record of the batch. A prediction whose ``"exception"`` entry is
        truthy only increments ``self.failed``.
        """
        if prediction["exception"]:
            self.failed += 1
            return

        # The base pair of values is always reported; the numbered variants
        # are added only when the prediction carries them.
        summary_keys = ["affinity_pred_value", "affinity_probability_binary"]
        if "affinity_pred_value1" in prediction:
            summary_keys.extend(
                [
                    "affinity_pred_value1",
                    "affinity_probability_binary1",
                    "affinity_pred_value2",
                    "affinity_probability_binary2",
                ]
            )
        affinity_summary = {name: prediction[name].item() for name in summary_keys}

        # Save the affinity summary
        record_id = batch["record"][0].id
        struct_dir = self.output_dir / record_id
        struct_dir.mkdir(exist_ok=True)
        target = struct_dir / f"affinity_{record_id}.json"
        target.write_text(json.dumps(affinity_summary, indent=4))

    def on_predict_epoch_end(
        self,
        trainer: Trainer,  # noqa: ARG002
        pl_module: LightningModule,  # noqa: ARG002
    ) -> None:
        """Report how many examples failed during prediction."""
        print(f"Number of failed examples: {self.failed}")  # noqa: T201
|