boltz-vsynthes 0.0.17__py3-none-any.whl → 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -59,16 +59,16 @@ def load_input(
59
59
 
60
60
  """
61
61
  # Load the structure
62
- # if affinity:
63
- # structure = StructureV2.load(
64
- # target_dir / record.id / f"pre_affinity_{record.id}.npz"
65
- # )
66
62
  if affinity:
67
- if target_dir.name == "predictions":
68
- target_dir = target_dir.parent / "processed"
69
63
  structure = StructureV2.load(
70
- target_dir / f"structures/{record.id}.npz"
64
+ target_dir / record.id / f"pre_affinity_{record.id}.npz"
71
65
  )
66
+ # if affinity:
67
+ # if target_dir.name == "predictions":
68
+ # target_dir = target_dir.parent / "processed"
69
+ # structure = StructureV2.load(
70
+ # target_dir / f"structures/{record.id}.npz"
71
+ # )
72
72
  else:
73
73
  structure = StructureV2.load(target_dir / f"{record.id}.npz")
74
74
 
boltz/main.py CHANGED
@@ -9,7 +9,7 @@ from dataclasses import asdict, dataclass
9
9
  from functools import partial
10
10
  from multiprocessing import Pool
11
11
  from pathlib import Path
12
- from typing import Literal, Optional
12
+ from typing import Literal, Optional, List
13
13
 
14
14
  import click
15
15
  import torch
@@ -18,6 +18,8 @@ from pytorch_lightning.strategies import DDPStrategy
18
18
  from pytorch_lightning.utilities import rank_zero_only
19
19
  from rdkit import Chem
20
20
  from tqdm import tqdm
21
+ import time
22
+ from datetime import datetime
21
23
 
22
24
  from boltz.data import const
23
25
  from boltz.data.module.inference import BoltzInferenceDataModule
@@ -203,22 +205,21 @@ def download_boltz2(cache: Path) -> None:
203
205
  The cache directory.
204
206
 
205
207
  """
208
+ # Use /tmp if possible for faster local disk I/O
209
+ if str(cache).startswith("/home") or str(cache).startswith("/mnt"):
210
+ cache = Path("/tmp/boltz_cache")
211
+ cache.mkdir(parents=True, exist_ok=True)
212
+
206
213
  # Download CCD
207
214
  mols = cache / "mols"
208
215
  tar_mols = cache / "mols.tar"
209
- if not tar_mols.exists():
210
- click.echo(
211
- f"Downloading the CCD data to {tar_mols}. "
212
- "This may take a bit of time. You may change the cache directory "
213
- "with the --cache flag."
214
- )
215
- urllib.request.urlretrieve(MOL_URL, str(tar_mols)) # noqa: S310
216
216
  if not mols.exists():
217
217
  click.echo(
218
- f"Extracting the CCD data to {mols}. "
218
+ f"Downloading and extracting the CCD data to {mols}. "
219
219
  "This may take a bit of time. You may change the cache directory "
220
220
  "with the --cache flag."
221
221
  )
222
+ urllib.request.urlretrieve(MOL_URL, str(tar_mols)) # noqa: S310
222
223
  with tarfile.open(str(tar_mols), "r") as tar:
223
224
  tar.extractall(cache) # noqa: S202
224
225
 
@@ -983,6 +984,7 @@ def predict( # noqa: C901, PLR0915, PLR0912
983
984
  torch.set_grad_enabled(False)
984
985
 
985
986
  # Ignore matmul precision warning
987
+ # torch.set_float32_matmul_precision('medium')
986
988
  torch.set_float32_matmul_precision("highest")
987
989
 
988
990
  # Set rdkit pickle logic
@@ -1029,14 +1031,16 @@ def predict( # noqa: C901, PLR0915, PLR0912
1029
1031
  msg = f"Method {method} not supported. Supported: {method_names}"
1030
1032
  raise ValueError(msg)
1031
1033
 
1032
- # Process inputs
1034
+ # 1. Before and after process_inputs
1035
+ t_process_inputs = time.time()
1033
1036
  ccd_path = cache / "ccd.pkl"
1034
1037
  mol_dir = cache / "mols"
1038
+ print(f"[{datetime.now().strftime('%H:%M:%S')}] Starting process_inputs...")
1035
1039
  process_inputs(
1036
1040
  data=data,
1037
1041
  out_dir=out_dir,
1038
- ccd_path=ccd_path,
1039
- mol_dir=mol_dir,
1042
+ ccd_path=cache / "ccd.pkl",
1043
+ mol_dir=cache / "mols",
1040
1044
  use_msa_server=use_msa_server,
1041
1045
  msa_server_url=msa_server_url,
1042
1046
  msa_pairing_strategy=msa_pairing_strategy,
@@ -1044,18 +1048,27 @@ def predict( # noqa: C901, PLR0915, PLR0912
1044
1048
  preprocessing_threads=preprocessing_threads,
1045
1049
  max_msa_seqs=max_msa_seqs,
1046
1050
  )
1051
+ print(f"[{datetime.now().strftime('%H:%M:%S')}] process_inputs finished in {time.time() - t_process_inputs:.2f} seconds")
1047
1052
 
1048
- # Load manifest
1053
+ # 2. Before and after load manifest
1054
+ t_manifest = time.time()
1055
+ print(f"[{datetime.now().strftime('%H:%M:%S')}] Loading manifest...")
1049
1056
  manifest = Manifest.load(out_dir / "processed" / "manifest.json")
1057
+ print(f"[{datetime.now().strftime('%H:%M:%S')}] Manifest loaded in {time.time() - t_manifest:.2f} seconds")
1050
1058
 
1051
- # Filter out existing predictions
1059
+ # 3. Before and after Filter out existing predictions
1060
+ t_filter = time.time()
1061
+ print(f"[{datetime.now().strftime('%H:%M:%S')}] Filtering out existing predictions...")
1052
1062
  filtered_manifest = filter_inputs_structure(
1053
1063
  manifest=manifest,
1054
1064
  outdir=out_dir,
1055
1065
  override=override,
1056
1066
  )
1067
+ print(f"[{datetime.now().strftime('%H:%M:%S')}] Filtering finished in {time.time() - t_filter:.2f} seconds")
1057
1068
 
1058
- # Load processed data
1069
+ # 4. Before and after load processed data
1070
+ t_processed = time.time()
1071
+ print(f"[{datetime.now().strftime('%H:%M:%S')}] Loading processed data...")
1059
1072
  processed_dir = out_dir / "processed"
1060
1073
  processed = BoltzProcessedInput(
1061
1074
  manifest=filtered_manifest,
@@ -1075,8 +1088,22 @@ def predict( # noqa: C901, PLR0915, PLR0912
1075
1088
  (processed_dir / "mols") if (processed_dir / "mols").exists() else None
1076
1089
  ),
1077
1090
  )
1091
+ print(f"[{datetime.now().strftime('%H:%M:%S')}] Processed data loaded in {time.time() - t_processed:.2f} seconds")
1078
1092
 
1079
- # Set up trainer
1093
+ # 5. Before and after create prediction writer
1094
+ t_writer = time.time()
1095
+ print(f"[{datetime.now().strftime('%H:%M:%S')}] Creating prediction writer...")
1096
+ pred_writer = BoltzWriter(
1097
+ data_dir=processed.targets_dir,
1098
+ output_dir=out_dir / "predictions",
1099
+ output_format=output_format,
1100
+ boltz2=model == "boltz2",
1101
+ )
1102
+ print(f"[{datetime.now().strftime('%H:%M:%S')}] Prediction writer created in {time.time() - t_writer:.2f} seconds")
1103
+
1104
+ # 6. Before and after set up trainer
1105
+ t_trainer = time.time()
1106
+ print(f"[{datetime.now().strftime('%H:%M:%S')}] Setting up trainer...")
1080
1107
  strategy = "auto"
1081
1108
  if (isinstance(devices, int) and devices > 1) or (
1082
1109
  isinstance(devices, list) and len(devices) > 1
@@ -1129,13 +1156,16 @@ def predict( # noqa: C901, PLR0915, PLR0912
1129
1156
  devices=devices,
1130
1157
  precision=32 if model == "boltz1" else "bf16-mixed",
1131
1158
  )
1159
+ print(f"[{datetime.now().strftime('%H:%M:%S')}] Trainer set up in {time.time() - t_trainer:.2f} seconds")
1132
1160
 
1133
1161
  if filtered_manifest.records:
1134
1162
  msg = f"Running structure prediction for {len(filtered_manifest.records)} input"
1135
1163
  msg += "s." if len(filtered_manifest.records) > 1 else "."
1136
1164
  click.echo(msg)
1137
1165
 
1138
- # Create data module
1166
+ # 7. Before and after create data module
1167
+ t_datamodule = time.time()
1168
+ print(f"[{datetime.now().strftime('%H:%M:%S')}] Creating data module...")
1139
1169
  if model == "boltz2":
1140
1170
  data_module = Boltz2InferenceDataModule(
1141
1171
  manifest=processed.manifest,
@@ -1156,14 +1186,16 @@ def predict( # noqa: C901, PLR0915, PLR0912
1156
1186
  num_workers=num_workers,
1157
1187
  constraints_dir=processed.constraints_dir,
1158
1188
  )
1189
+ print(f"[{datetime.now().strftime('%H:%M:%S')}] Data module created in {time.time() - t_datamodule:.2f} seconds")
1159
1190
 
1160
- # Load model
1191
+ # 8. Before and after load model
1192
+ t_model = time.time()
1193
+ print(f"[{datetime.now().strftime('%H:%M:%S')}] Loading model...")
1161
1194
  if checkpoint is None:
1162
1195
  if model == "boltz2":
1163
1196
  checkpoint = cache / "boltz2_conf.ckpt"
1164
1197
  else:
1165
1198
  checkpoint = cache / "boltz1_conf.ckpt"
1166
-
1167
1199
  predict_args = {
1168
1200
  "recycling_steps": recycling_steps,
1169
1201
  "sampling_steps": sampling_steps,
@@ -1173,12 +1205,10 @@ def predict( # noqa: C901, PLR0915, PLR0912
1173
1205
  "write_full_pae": write_full_pae,
1174
1206
  "write_full_pde": write_full_pde,
1175
1207
  }
1176
-
1177
1208
  steering_args = BoltzSteeringParams()
1178
1209
  steering_args.fk_steering = use_potentials
1179
1210
  steering_args.guidance_update = use_potentials
1180
-
1181
- model_cls = Boltz2 if model == "boltz2" else Boltz1
1211
+ model_cls = Boltz2 if model == "boltz2" else Boltz1)
1182
1212
  model_module = model_cls.load_from_checkpoint(
1183
1213
  checkpoint,
1184
1214
  strict=True,
@@ -1192,13 +1222,17 @@ def predict( # noqa: C901, PLR0915, PLR0912
1192
1222
  steering_args=asdict(steering_args),
1193
1223
  )
1194
1224
  model_module.eval()
1225
+ print(f"[{datetime.now().strftime('%H:%M:%S')}] Model loaded in {time.time() - t_model:.2f} seconds")
1195
1226
 
1196
- # Compute structure predictions
1227
+ # 9. Before and after compute structure predictions (predict)
1228
+ t_predict = time.time()
1229
+ print(f"[{datetime.now().strftime('%H:%M:%S')}] Computing structure predictions...")
1197
1230
  trainer.predict(
1198
1231
  model_module,
1199
1232
  datamodule=data_module,
1200
1233
  return_predictions=False,
1201
1234
  )
1235
+ print(f"[{datetime.now().strftime('%H:%M:%S')}] Structure predictions computed in {time.time() - t_predict:.2f} seconds")
1202
1236
 
1203
1237
  # Check if affinity predictions are needed
1204
1238
  if any(r.affinity for r in manifest.records):
@@ -1251,6 +1285,8 @@ def predict( # noqa: C901, PLR0915, PLR0912
1251
1285
  if affinity_checkpoint is None:
1252
1286
  affinity_checkpoint = cache / "boltz2_aff.ckpt"
1253
1287
 
1288
+ # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
1289
+ print(f"[{datetime.now().strftime('%H:%M:%S')}] Using device: {device}")
1254
1290
  model_module = Boltz2.load_from_checkpoint(
1255
1291
  affinity_checkpoint,
1256
1292
  strict=True,
@@ -1272,6 +1308,5 @@ def predict( # noqa: C901, PLR0915, PLR0912
1272
1308
  return_predictions=False,
1273
1309
  )
1274
1310
 
1275
-
1276
1311
  if __name__ == "__main__":
1277
1312
  cli()
@@ -63,7 +63,7 @@ class Boltz2(LightningModule):
63
63
  num_val_datasets: int = 1,
64
64
  atom_feature_dim: int = 128,
65
65
  template_args: Optional[dict] = None,
66
- confidence_prediction: bool = True,
66
+ confidence_prediction: bool = True, #TODO: change to False
67
67
  affinity_prediction: bool = False,
68
68
  affinity_ensemble: bool = False,
69
69
  affinity_mw_correction: bool = True,
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: boltz-vsynthes
3
- Version: 0.0.17
3
+ Version: 0.1.0
4
4
  Summary: Boltz for VSYNTHES
5
5
  Requires-Python: <3.13,>=3.10
6
6
  Description-Content-Type: text/markdown
@@ -28,9 +28,9 @@ Requires-Dist: numba==0.61.0
28
28
  Requires-Dist: gemmi==0.6.5
29
29
  Requires-Dist: scikit-learn==1.6.1
30
30
  Requires-Dist: chembl_structure_pipeline==1.2.2
31
- Requires-Dist: cuequivariance_ops_cu12>=0.5.0
32
- Requires-Dist: cuequivariance_ops_torch_cu12>=0.5.0
33
- Requires-Dist: cuequivariance_torch>=0.5.0
31
+ Requires-Dist: cuequivariance_ops_cu12==0.5.0
32
+ Requires-Dist: cuequivariance_ops_torch_cu12==0.5.0
33
+ Requires-Dist: cuequivariance_torch==0.5.0
34
34
  Provides-Extra: lint
35
35
  Requires-Dist: ruff; extra == "lint"
36
36
  Provides-Extra: test
@@ -1,5 +1,5 @@
1
1
  boltz/__init__.py,sha256=F_-so3S40iZrSZ89Ge4TS6aZqwWyZXq_H4AXGDlbA_g,187
2
- boltz/main.py,sha256=zBLxa6T8hxcBs7gj1BnWfgJSx6uki8iV-QgClvoaiSA,39951
2
+ boltz/main.py,sha256=BApAG6y3m_V5RuwTtoBL_f2Ud69BVPT3pPSfsG1R718,42706
3
3
  boltz/data/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
4
4
  boltz/data/const.py,sha256=1M-88Z6HkfKY6MkNtqcj3b9P-oX9xEXluh3qM_u8dNU,26779
5
5
  boltz/data/mol.py,sha256=maOpPHEGX1VVXCIFY6pQNGF7gUBZPAfgSvuPf2QO1yc,34268
@@ -27,7 +27,7 @@ boltz/data/filter/static/ligand.py,sha256=LamC-Z9IjYj3DmfxwMFmPbKBBhRMby3uWQj74w
27
27
  boltz/data/filter/static/polymer.py,sha256=LNsQMsOOnhYpeKuM9AStktoTQPMZE3H0yu4mRg-jwPc,9386
28
28
  boltz/data/module/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
29
29
  boltz/data/module/inference.py,sha256=xk8ZJ8UhjPiPTdOluH_v4gnV8GtTX3sr1WZ1s5Ox8I8,8100
30
- boltz/data/module/inferencev2.py,sha256=QmtvyUuHkY2f1ulE2xkIjo87xRIuOj3Yddo6_mbILYg,12738
30
+ boltz/data/module/inferencev2.py,sha256=aLUm1WR6E1814JUrF6sJfoe5y8y7d_s4zlQ3pdFBVy8,12742
31
31
  boltz/data/module/training.py,sha256=iNzmq9ufs20S4M947CCzdYzGTFjmCTf2tFExJ2PtXnA,22428
32
32
  boltz/data/module/trainingv2.py,sha256=ZsYUHYXxfuPgIpbTwCj5QLO0XK__xjsqIw6GARSNGW0,21276
33
33
  boltz/data/msa/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -85,7 +85,7 @@ boltz/model/loss/distogramv2.py,sha256=dFgMGwpdLK4-skHJwvpERG10KfF3ZUN1T9_hUj-iW
85
85
  boltz/model/loss/validation.py,sha256=gYpbag9mulg5HJPXjOUFaMV9XSYX_s2bIQ0iYjiAow0,33501
86
86
  boltz/model/models/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
87
87
  boltz/model/models/boltz1.py,sha256=x-x0b3VAXiAkPUBePnF56k1aYEPNgX1M6GtNCYVdCso,51718
88
- boltz/model/models/boltz2.py,sha256=3XOWjUWaSJquw8Xdp7JItDUnVDyoC0qtx3q4MFQrd38,51523
88
+ boltz/model/models/boltz2.py,sha256=hD1kF4XSox9PPeLN_v02YrMg8zFnGyI8ZYW_AwOGJnc,51546
89
89
  boltz/model/modules/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
90
90
  boltz/model/modules/affinity.py,sha256=FktI2wrkDqsjGHJOuvzVrZK78MOPjU65QN0l6sB1QPQ,7041
91
91
  boltz/model/modules/confidence.py,sha256=sXGymZiiMtfXPkUvHpa2KCCvNY79D8jXXEx9Gz2rNFs,17475
@@ -110,9 +110,9 @@ boltz/model/potentials/schedules.py,sha256=m7XJjfuF9uTX3bR9VisXv1rvzJjxiD8PobXRp
110
110
  boltz/utils/sdf_splitter.py,sha256=ZHn_syOcmm-fDnJ3YEGyGv_vYz2IRzUW7vbbMSU2JBY,2108
111
111
  boltz/utils/sdf_to_pre_affinity_npz.py,sha256=ro0KGe24JexbJm47J8S8w8Lmr_KaQbzOAb_dKZO2G9I,40384
112
112
  boltz/utils/yaml_generator.py,sha256=ermWIG-BE6nNWHFvpEwpk92N9J-YATpGXZGLvD1I2oQ,4012
113
- boltz_vsynthes-0.0.17.dist-info/licenses/LICENSE,sha256=8GZ_1eZsUeG6jdqgJJxtciWzADfgLEV4LY8sKUOsJhc,1102
114
- boltz_vsynthes-0.0.17.dist-info/METADATA,sha256=C3WBB1KJrsEeeiuDnk78uhHAdFRMp8ApLNnOS8uLx1c,7235
115
- boltz_vsynthes-0.0.17.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
116
- boltz_vsynthes-0.0.17.dist-info/entry_points.txt,sha256=nZNYPKKrmAr-MVA0K-ClNRT2p90FV1_14d7HpsESZFQ,211
117
- boltz_vsynthes-0.0.17.dist-info/top_level.txt,sha256=MgU3Jfb-ctWm07YGMts68PMjSh9v26D0gfG3dFRmVFA,6
118
- boltz_vsynthes-0.0.17.dist-info/RECORD,,
113
+ boltz_vsynthes-0.1.0.dist-info/licenses/LICENSE,sha256=8GZ_1eZsUeG6jdqgJJxtciWzADfgLEV4LY8sKUOsJhc,1102
114
+ boltz_vsynthes-0.1.0.dist-info/METADATA,sha256=KKvZxnnHx3XX6WsLY4WA8lE2bY4OaTZa4RnYSP_rsfs,7234
115
+ boltz_vsynthes-0.1.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
116
+ boltz_vsynthes-0.1.0.dist-info/entry_points.txt,sha256=nZNYPKKrmAr-MVA0K-ClNRT2p90FV1_14d7HpsESZFQ,211
117
+ boltz_vsynthes-0.1.0.dist-info/top_level.txt,sha256=MgU3Jfb-ctWm07YGMts68PMjSh9v26D0gfG3dFRmVFA,6
118
+ boltz_vsynthes-0.1.0.dist-info/RECORD,,