boltz-vsynthes 0.0.17__py3-none-any.whl → 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- boltz/data/module/inferencev2.py +7 -7
- boltz/main.py +59 -24
- boltz/model/models/boltz2.py +1 -1
- {boltz_vsynthes-0.0.17.dist-info → boltz_vsynthes-0.1.0.dist-info}/METADATA +4 -4
- {boltz_vsynthes-0.0.17.dist-info → boltz_vsynthes-0.1.0.dist-info}/RECORD +9 -9
- {boltz_vsynthes-0.0.17.dist-info → boltz_vsynthes-0.1.0.dist-info}/WHEEL +0 -0
- {boltz_vsynthes-0.0.17.dist-info → boltz_vsynthes-0.1.0.dist-info}/entry_points.txt +0 -0
- {boltz_vsynthes-0.0.17.dist-info → boltz_vsynthes-0.1.0.dist-info}/licenses/LICENSE +0 -0
- {boltz_vsynthes-0.0.17.dist-info → boltz_vsynthes-0.1.0.dist-info}/top_level.txt +0 -0
boltz/data/module/inferencev2.py
CHANGED
@@ -59,16 +59,16 @@ def load_input(
|
|
59
59
|
|
60
60
|
"""
|
61
61
|
# Load the structure
|
62
|
-
# if affinity:
|
63
|
-
# structure = StructureV2.load(
|
64
|
-
# target_dir / record.id / f"pre_affinity_{record.id}.npz"
|
65
|
-
# )
|
66
62
|
if affinity:
|
67
|
-
if target_dir.name == "predictions":
|
68
|
-
target_dir = target_dir.parent / "processed"
|
69
63
|
structure = StructureV2.load(
|
70
|
-
target_dir / f"
|
64
|
+
target_dir / record.id / f"pre_affinity_{record.id}.npz"
|
71
65
|
)
|
66
|
+
# if affinity:
|
67
|
+
# if target_dir.name == "predictions":
|
68
|
+
# target_dir = target_dir.parent / "processed"
|
69
|
+
# structure = StructureV2.load(
|
70
|
+
# target_dir / f"structures/{record.id}.npz"
|
71
|
+
# )
|
72
72
|
else:
|
73
73
|
structure = StructureV2.load(target_dir / f"{record.id}.npz")
|
74
74
|
|
boltz/main.py
CHANGED
@@ -9,7 +9,7 @@ from dataclasses import asdict, dataclass
|
|
9
9
|
from functools import partial
|
10
10
|
from multiprocessing import Pool
|
11
11
|
from pathlib import Path
|
12
|
-
from typing import Literal, Optional
|
12
|
+
from typing import Literal, Optional, List
|
13
13
|
|
14
14
|
import click
|
15
15
|
import torch
|
@@ -18,6 +18,8 @@ from pytorch_lightning.strategies import DDPStrategy
|
|
18
18
|
from pytorch_lightning.utilities import rank_zero_only
|
19
19
|
from rdkit import Chem
|
20
20
|
from tqdm import tqdm
|
21
|
+
import time
|
22
|
+
from datetime import datetime
|
21
23
|
|
22
24
|
from boltz.data import const
|
23
25
|
from boltz.data.module.inference import BoltzInferenceDataModule
|
@@ -203,22 +205,21 @@ def download_boltz2(cache: Path) -> None:
|
|
203
205
|
The cache directory.
|
204
206
|
|
205
207
|
"""
|
208
|
+
# Use /tmp if possible for faster local disk I/O
|
209
|
+
if str(cache).startswith("/home") or str(cache).startswith("/mnt"):
|
210
|
+
cache = Path("/tmp/boltz_cache")
|
211
|
+
cache.mkdir(parents=True, exist_ok=True)
|
212
|
+
|
206
213
|
# Download CCD
|
207
214
|
mols = cache / "mols"
|
208
215
|
tar_mols = cache / "mols.tar"
|
209
|
-
if not tar_mols.exists():
|
210
|
-
click.echo(
|
211
|
-
f"Downloading the CCD data to {tar_mols}. "
|
212
|
-
"This may take a bit of time. You may change the cache directory "
|
213
|
-
"with the --cache flag."
|
214
|
-
)
|
215
|
-
urllib.request.urlretrieve(MOL_URL, str(tar_mols)) # noqa: S310
|
216
216
|
if not mols.exists():
|
217
217
|
click.echo(
|
218
|
-
f"
|
218
|
+
f"Downloading and extracting the CCD data to {mols}. "
|
219
219
|
"This may take a bit of time. You may change the cache directory "
|
220
220
|
"with the --cache flag."
|
221
221
|
)
|
222
|
+
urllib.request.urlretrieve(MOL_URL, str(tar_mols)) # noqa: S310
|
222
223
|
with tarfile.open(str(tar_mols), "r") as tar:
|
223
224
|
tar.extractall(cache) # noqa: S202
|
224
225
|
|
@@ -983,6 +984,7 @@ def predict( # noqa: C901, PLR0915, PLR0912
|
|
983
984
|
torch.set_grad_enabled(False)
|
984
985
|
|
985
986
|
# Ignore matmul precision warning
|
987
|
+
# torch.set_float32_matmul_precision('medium')
|
986
988
|
torch.set_float32_matmul_precision("highest")
|
987
989
|
|
988
990
|
# Set rdkit pickle logic
|
@@ -1029,14 +1031,16 @@ def predict( # noqa: C901, PLR0915, PLR0912
|
|
1029
1031
|
msg = f"Method {method} not supported. Supported: {method_names}"
|
1030
1032
|
raise ValueError(msg)
|
1031
1033
|
|
1032
|
-
#
|
1034
|
+
# 1. Before and after process_inputs
|
1035
|
+
t_process_inputs = time.time()
|
1033
1036
|
ccd_path = cache / "ccd.pkl"
|
1034
1037
|
mol_dir = cache / "mols"
|
1038
|
+
print(f"[{datetime.now().strftime('%H:%M:%S')}] Starting process_inputs...")
|
1035
1039
|
process_inputs(
|
1036
1040
|
data=data,
|
1037
1041
|
out_dir=out_dir,
|
1038
|
-
ccd_path=
|
1039
|
-
mol_dir=
|
1042
|
+
ccd_path=cache / "ccd.pkl",
|
1043
|
+
mol_dir=cache / "mols",
|
1040
1044
|
use_msa_server=use_msa_server,
|
1041
1045
|
msa_server_url=msa_server_url,
|
1042
1046
|
msa_pairing_strategy=msa_pairing_strategy,
|
@@ -1044,18 +1048,27 @@ def predict( # noqa: C901, PLR0915, PLR0912
|
|
1044
1048
|
preprocessing_threads=preprocessing_threads,
|
1045
1049
|
max_msa_seqs=max_msa_seqs,
|
1046
1050
|
)
|
1051
|
+
print(f"[{datetime.now().strftime('%H:%M:%S')}] process_inputs finished in {time.time() - t_process_inputs:.2f} seconds")
|
1047
1052
|
|
1048
|
-
#
|
1053
|
+
# 2. Before and after load manifest
|
1054
|
+
t_manifest = time.time()
|
1055
|
+
print(f"[{datetime.now().strftime('%H:%M:%S')}] Loading manifest...")
|
1049
1056
|
manifest = Manifest.load(out_dir / "processed" / "manifest.json")
|
1057
|
+
print(f"[{datetime.now().strftime('%H:%M:%S')}] Manifest loaded in {time.time() - t_manifest:.2f} seconds")
|
1050
1058
|
|
1051
|
-
# Filter out existing predictions
|
1059
|
+
# 3. Before and after Filter out existing predictions
|
1060
|
+
t_filter = time.time()
|
1061
|
+
print(f"[{datetime.now().strftime('%H:%M:%S')}] Filtering out existing predictions...")
|
1052
1062
|
filtered_manifest = filter_inputs_structure(
|
1053
1063
|
manifest=manifest,
|
1054
1064
|
outdir=out_dir,
|
1055
1065
|
override=override,
|
1056
1066
|
)
|
1067
|
+
print(f"[{datetime.now().strftime('%H:%M:%S')}] Filtering finished in {time.time() - t_filter:.2f} seconds")
|
1057
1068
|
|
1058
|
-
#
|
1069
|
+
# 4. Before and after load processed data
|
1070
|
+
t_processed = time.time()
|
1071
|
+
print(f"[{datetime.now().strftime('%H:%M:%S')}] Loading processed data...")
|
1059
1072
|
processed_dir = out_dir / "processed"
|
1060
1073
|
processed = BoltzProcessedInput(
|
1061
1074
|
manifest=filtered_manifest,
|
@@ -1075,8 +1088,22 @@ def predict( # noqa: C901, PLR0915, PLR0912
|
|
1075
1088
|
(processed_dir / "mols") if (processed_dir / "mols").exists() else None
|
1076
1089
|
),
|
1077
1090
|
)
|
1091
|
+
print(f"[{datetime.now().strftime('%H:%M:%S')}] Processed data loaded in {time.time() - t_processed:.2f} seconds")
|
1078
1092
|
|
1079
|
-
#
|
1093
|
+
# 5. Before and after create prediction writer
|
1094
|
+
t_writer = time.time()
|
1095
|
+
print(f"[{datetime.now().strftime('%H:%M:%S')}] Creating prediction writer...")
|
1096
|
+
pred_writer = BoltzWriter(
|
1097
|
+
data_dir=processed.targets_dir,
|
1098
|
+
output_dir=out_dir / "predictions",
|
1099
|
+
output_format=output_format,
|
1100
|
+
boltz2=model == "boltz2",
|
1101
|
+
)
|
1102
|
+
print(f"[{datetime.now().strftime('%H:%M:%S')}] Prediction writer created in {time.time() - t_writer:.2f} seconds")
|
1103
|
+
|
1104
|
+
# 6. Before and after set up trainer
|
1105
|
+
t_trainer = time.time()
|
1106
|
+
print(f"[{datetime.now().strftime('%H:%M:%S')}] Setting up trainer...")
|
1080
1107
|
strategy = "auto"
|
1081
1108
|
if (isinstance(devices, int) and devices > 1) or (
|
1082
1109
|
isinstance(devices, list) and len(devices) > 1
|
@@ -1129,13 +1156,16 @@ def predict( # noqa: C901, PLR0915, PLR0912
|
|
1129
1156
|
devices=devices,
|
1130
1157
|
precision=32 if model == "boltz1" else "bf16-mixed",
|
1131
1158
|
)
|
1159
|
+
print(f"[{datetime.now().strftime('%H:%M:%S')}] Trainer set up in {time.time() - t_trainer:.2f} seconds")
|
1132
1160
|
|
1133
1161
|
if filtered_manifest.records:
|
1134
1162
|
msg = f"Running structure prediction for {len(filtered_manifest.records)} input"
|
1135
1163
|
msg += "s." if len(filtered_manifest.records) > 1 else "."
|
1136
1164
|
click.echo(msg)
|
1137
1165
|
|
1138
|
-
#
|
1166
|
+
# 7. Before and after create data module
|
1167
|
+
t_datamodule = time.time()
|
1168
|
+
print(f"[{datetime.now().strftime('%H:%M:%S')}] Creating data module...")
|
1139
1169
|
if model == "boltz2":
|
1140
1170
|
data_module = Boltz2InferenceDataModule(
|
1141
1171
|
manifest=processed.manifest,
|
@@ -1156,14 +1186,16 @@ def predict( # noqa: C901, PLR0915, PLR0912
|
|
1156
1186
|
num_workers=num_workers,
|
1157
1187
|
constraints_dir=processed.constraints_dir,
|
1158
1188
|
)
|
1189
|
+
print(f"[{datetime.now().strftime('%H:%M:%S')}] Data module created in {time.time() - t_datamodule:.2f} seconds")
|
1159
1190
|
|
1160
|
-
#
|
1191
|
+
# 8. Before and after load model
|
1192
|
+
t_model = time.time()
|
1193
|
+
print(f"[{datetime.now().strftime('%H:%M:%S')}] Loading model...")
|
1161
1194
|
if checkpoint is None:
|
1162
1195
|
if model == "boltz2":
|
1163
1196
|
checkpoint = cache / "boltz2_conf.ckpt"
|
1164
1197
|
else:
|
1165
1198
|
checkpoint = cache / "boltz1_conf.ckpt"
|
1166
|
-
|
1167
1199
|
predict_args = {
|
1168
1200
|
"recycling_steps": recycling_steps,
|
1169
1201
|
"sampling_steps": sampling_steps,
|
@@ -1173,12 +1205,10 @@ def predict( # noqa: C901, PLR0915, PLR0912
|
|
1173
1205
|
"write_full_pae": write_full_pae,
|
1174
1206
|
"write_full_pde": write_full_pde,
|
1175
1207
|
}
|
1176
|
-
|
1177
1208
|
steering_args = BoltzSteeringParams()
|
1178
1209
|
steering_args.fk_steering = use_potentials
|
1179
1210
|
steering_args.guidance_update = use_potentials
|
1180
|
-
|
1181
|
-
model_cls = Boltz2 if model == "boltz2" else Boltz1
|
1211
|
+
model_cls = Boltz2 if model == "boltz2" else Boltz1)
|
1182
1212
|
model_module = model_cls.load_from_checkpoint(
|
1183
1213
|
checkpoint,
|
1184
1214
|
strict=True,
|
@@ -1192,13 +1222,17 @@ def predict( # noqa: C901, PLR0915, PLR0912
|
|
1192
1222
|
steering_args=asdict(steering_args),
|
1193
1223
|
)
|
1194
1224
|
model_module.eval()
|
1225
|
+
print(f"[{datetime.now().strftime('%H:%M:%S')}] Model loaded in {time.time() - t_model:.2f} seconds")
|
1195
1226
|
|
1196
|
-
#
|
1227
|
+
# 9. Before and after compute structure predictions (predict)
|
1228
|
+
t_predict = time.time()
|
1229
|
+
print(f"[{datetime.now().strftime('%H:%M:%S')}] Computing structure predictions...")
|
1197
1230
|
trainer.predict(
|
1198
1231
|
model_module,
|
1199
1232
|
datamodule=data_module,
|
1200
1233
|
return_predictions=False,
|
1201
1234
|
)
|
1235
|
+
print(f"[{datetime.now().strftime('%H:%M:%S')}] Structure predictions computed in {time.time() - t_predict:.2f} seconds")
|
1202
1236
|
|
1203
1237
|
# Check if affinity predictions are needed
|
1204
1238
|
if any(r.affinity for r in manifest.records):
|
@@ -1251,6 +1285,8 @@ def predict( # noqa: C901, PLR0915, PLR0912
|
|
1251
1285
|
if affinity_checkpoint is None:
|
1252
1286
|
affinity_checkpoint = cache / "boltz2_aff.ckpt"
|
1253
1287
|
|
1288
|
+
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
1289
|
+
print(f"[{datetime.now().strftime('%H:%M:%S')}] Using device: {device}")
|
1254
1290
|
model_module = Boltz2.load_from_checkpoint(
|
1255
1291
|
affinity_checkpoint,
|
1256
1292
|
strict=True,
|
@@ -1272,6 +1308,5 @@ def predict( # noqa: C901, PLR0915, PLR0912
|
|
1272
1308
|
return_predictions=False,
|
1273
1309
|
)
|
1274
1310
|
|
1275
|
-
|
1276
1311
|
if __name__ == "__main__":
|
1277
1312
|
cli()
|
boltz/model/models/boltz2.py
CHANGED
@@ -63,7 +63,7 @@ class Boltz2(LightningModule):
|
|
63
63
|
num_val_datasets: int = 1,
|
64
64
|
atom_feature_dim: int = 128,
|
65
65
|
template_args: Optional[dict] = None,
|
66
|
-
confidence_prediction: bool = True,
|
66
|
+
confidence_prediction: bool = True, #TODO: change to False
|
67
67
|
affinity_prediction: bool = False,
|
68
68
|
affinity_ensemble: bool = False,
|
69
69
|
affinity_mw_correction: bool = True,
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: boltz-vsynthes
|
3
|
-
Version: 0.0
|
3
|
+
Version: 0.1.0
|
4
4
|
Summary: Boltz for VSYNTHES
|
5
5
|
Requires-Python: <3.13,>=3.10
|
6
6
|
Description-Content-Type: text/markdown
|
@@ -28,9 +28,9 @@ Requires-Dist: numba==0.61.0
|
|
28
28
|
Requires-Dist: gemmi==0.6.5
|
29
29
|
Requires-Dist: scikit-learn==1.6.1
|
30
30
|
Requires-Dist: chembl_structure_pipeline==1.2.2
|
31
|
-
Requires-Dist: cuequivariance_ops_cu12
|
32
|
-
Requires-Dist: cuequivariance_ops_torch_cu12
|
33
|
-
Requires-Dist: cuequivariance_torch
|
31
|
+
Requires-Dist: cuequivariance_ops_cu12==0.5.0
|
32
|
+
Requires-Dist: cuequivariance_ops_torch_cu12==0.5.0
|
33
|
+
Requires-Dist: cuequivariance_torch==0.5.0
|
34
34
|
Provides-Extra: lint
|
35
35
|
Requires-Dist: ruff; extra == "lint"
|
36
36
|
Provides-Extra: test
|
@@ -1,5 +1,5 @@
|
|
1
1
|
boltz/__init__.py,sha256=F_-so3S40iZrSZ89Ge4TS6aZqwWyZXq_H4AXGDlbA_g,187
|
2
|
-
boltz/main.py,sha256=
|
2
|
+
boltz/main.py,sha256=BApAG6y3m_V5RuwTtoBL_f2Ud69BVPT3pPSfsG1R718,42706
|
3
3
|
boltz/data/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
4
4
|
boltz/data/const.py,sha256=1M-88Z6HkfKY6MkNtqcj3b9P-oX9xEXluh3qM_u8dNU,26779
|
5
5
|
boltz/data/mol.py,sha256=maOpPHEGX1VVXCIFY6pQNGF7gUBZPAfgSvuPf2QO1yc,34268
|
@@ -27,7 +27,7 @@ boltz/data/filter/static/ligand.py,sha256=LamC-Z9IjYj3DmfxwMFmPbKBBhRMby3uWQj74w
|
|
27
27
|
boltz/data/filter/static/polymer.py,sha256=LNsQMsOOnhYpeKuM9AStktoTQPMZE3H0yu4mRg-jwPc,9386
|
28
28
|
boltz/data/module/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
29
29
|
boltz/data/module/inference.py,sha256=xk8ZJ8UhjPiPTdOluH_v4gnV8GtTX3sr1WZ1s5Ox8I8,8100
|
30
|
-
boltz/data/module/inferencev2.py,sha256=
|
30
|
+
boltz/data/module/inferencev2.py,sha256=aLUm1WR6E1814JUrF6sJfoe5y8y7d_s4zlQ3pdFBVy8,12742
|
31
31
|
boltz/data/module/training.py,sha256=iNzmq9ufs20S4M947CCzdYzGTFjmCTf2tFExJ2PtXnA,22428
|
32
32
|
boltz/data/module/trainingv2.py,sha256=ZsYUHYXxfuPgIpbTwCj5QLO0XK__xjsqIw6GARSNGW0,21276
|
33
33
|
boltz/data/msa/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
@@ -85,7 +85,7 @@ boltz/model/loss/distogramv2.py,sha256=dFgMGwpdLK4-skHJwvpERG10KfF3ZUN1T9_hUj-iW
|
|
85
85
|
boltz/model/loss/validation.py,sha256=gYpbag9mulg5HJPXjOUFaMV9XSYX_s2bIQ0iYjiAow0,33501
|
86
86
|
boltz/model/models/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
87
87
|
boltz/model/models/boltz1.py,sha256=x-x0b3VAXiAkPUBePnF56k1aYEPNgX1M6GtNCYVdCso,51718
|
88
|
-
boltz/model/models/boltz2.py,sha256=
|
88
|
+
boltz/model/models/boltz2.py,sha256=hD1kF4XSox9PPeLN_v02YrMg8zFnGyI8ZYW_AwOGJnc,51546
|
89
89
|
boltz/model/modules/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
90
90
|
boltz/model/modules/affinity.py,sha256=FktI2wrkDqsjGHJOuvzVrZK78MOPjU65QN0l6sB1QPQ,7041
|
91
91
|
boltz/model/modules/confidence.py,sha256=sXGymZiiMtfXPkUvHpa2KCCvNY79D8jXXEx9Gz2rNFs,17475
|
@@ -110,9 +110,9 @@ boltz/model/potentials/schedules.py,sha256=m7XJjfuF9uTX3bR9VisXv1rvzJjxiD8PobXRp
|
|
110
110
|
boltz/utils/sdf_splitter.py,sha256=ZHn_syOcmm-fDnJ3YEGyGv_vYz2IRzUW7vbbMSU2JBY,2108
|
111
111
|
boltz/utils/sdf_to_pre_affinity_npz.py,sha256=ro0KGe24JexbJm47J8S8w8Lmr_KaQbzOAb_dKZO2G9I,40384
|
112
112
|
boltz/utils/yaml_generator.py,sha256=ermWIG-BE6nNWHFvpEwpk92N9J-YATpGXZGLvD1I2oQ,4012
|
113
|
-
boltz_vsynthes-0.0.
|
114
|
-
boltz_vsynthes-0.0.
|
115
|
-
boltz_vsynthes-0.0.
|
116
|
-
boltz_vsynthes-0.0.
|
117
|
-
boltz_vsynthes-0.0.
|
118
|
-
boltz_vsynthes-0.0.
|
113
|
+
boltz_vsynthes-0.1.0.dist-info/licenses/LICENSE,sha256=8GZ_1eZsUeG6jdqgJJxtciWzADfgLEV4LY8sKUOsJhc,1102
|
114
|
+
boltz_vsynthes-0.1.0.dist-info/METADATA,sha256=KKvZxnnHx3XX6WsLY4WA8lE2bY4OaTZa4RnYSP_rsfs,7234
|
115
|
+
boltz_vsynthes-0.1.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
116
|
+
boltz_vsynthes-0.1.0.dist-info/entry_points.txt,sha256=nZNYPKKrmAr-MVA0K-ClNRT2p90FV1_14d7HpsESZFQ,211
|
117
|
+
boltz_vsynthes-0.1.0.dist-info/top_level.txt,sha256=MgU3Jfb-ctWm07YGMts68PMjSh9v26D0gfG3dFRmVFA,6
|
118
|
+
boltz_vsynthes-0.1.0.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|