boltz-vsynthes 0.0.13__py3-none-any.whl → 0.0.15__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -59,9 +59,15 @@ def load_input(
59
59
 
60
60
  """
61
61
  # Load the structure
62
+ # if affinity:
63
+ # structure = StructureV2.load(
64
+ # target_dir / record.id / f"pre_affinity_{record.id}.npz"
65
+ # )
62
66
  if affinity:
67
+ if target_dir.name == "predictions":
68
+ target_dir = target_dir.parent / "processed"
63
69
  structure = StructureV2.load(
64
- target_dir / record.id / f"pre_affinity_{record.id}.npz"
70
+ target_dir / record.id / f"{record.id}.npz"
65
71
  )
66
72
  else:
67
73
  structure = StructureV2.load(target_dir / f"{record.id}.npz")
@@ -1076,41 +1076,41 @@ def predict( # noqa: C901, PLR0915, PLR0912
1076
1076
  ),
1077
1077
  )
1078
1078
 
1079
- # # Set up trainer
1080
- # strategy = "auto"
1081
- # if (isinstance(devices, int) and devices > 1) or (
1082
- # isinstance(devices, list) and len(devices) > 1
1083
- # ):
1084
- # start_method = "fork" if platform.system() != "win32" else "spawn"
1085
- # strategy = DDPStrategy(start_method=start_method)
1086
- # if len(filtered_manifest.records) < devices:
1087
- # msg = (
1088
- # "Number of requested devices is greater "
1089
- # "than the number of predictions, taking the minimum."
1090
- # )
1091
- # click.echo(msg)
1092
- # if isinstance(devices, list):
1093
- # devices = devices[: max(1, len(filtered_manifest.records))]
1094
- # else:
1095
- # devices = max(1, min(len(filtered_manifest.records), devices))
1096
-
1097
- # # Set up model parameters
1098
- # if model == "boltz2":
1099
- # diffusion_params = Boltz2DiffusionParams()
1100
- # step_scale = 1.5 if step_scale is None else step_scale
1101
- # diffusion_params.step_scale = step_scale
1102
- # pairformer_args = PairformerArgsV2()
1103
- # else:
1104
- # diffusion_params = BoltzDiffusionParams()
1105
- # step_scale = 1.638 if step_scale is None else step_scale
1106
- # diffusion_params.step_scale = step_scale
1107
- # pairformer_args = PairformerArgs()
1108
-
1109
- # msa_args = MSAModuleArgs(
1110
- # subsample_msa=subsample_msa,
1111
- # num_subsampled_msa=num_subsampled_msa,
1112
- # use_paired_feature=model == "boltz2",
1113
- # )
1079
+ # Set up trainer
1080
+ strategy = "auto"
1081
+ if (isinstance(devices, int) and devices > 1) or (
1082
+ isinstance(devices, list) and len(devices) > 1
1083
+ ):
1084
+ start_method = "fork" if platform.system() != "win32" else "spawn"
1085
+ strategy = DDPStrategy(start_method=start_method)
1086
+ if len(filtered_manifest.records) < devices:
1087
+ msg = (
1088
+ "Number of requested devices is greater "
1089
+ "than the number of predictions, taking the minimum."
1090
+ )
1091
+ click.echo(msg)
1092
+ if isinstance(devices, list):
1093
+ devices = devices[: max(1, len(filtered_manifest.records))]
1094
+ else:
1095
+ devices = max(1, min(len(filtered_manifest.records), devices))
1096
+
1097
+ # Set up model parameters
1098
+ if model == "boltz2":
1099
+ diffusion_params = Boltz2DiffusionParams()
1100
+ step_scale = 1.5 if step_scale is None else step_scale
1101
+ diffusion_params.step_scale = step_scale
1102
+ pairformer_args = PairformerArgsV2()
1103
+ else:
1104
+ diffusion_params = BoltzDiffusionParams()
1105
+ step_scale = 1.638 if step_scale is None else step_scale
1106
+ diffusion_params.step_scale = step_scale
1107
+ pairformer_args = PairformerArgs()
1108
+
1109
+ msa_args = MSAModuleArgs(
1110
+ subsample_msa=subsample_msa,
1111
+ num_subsampled_msa=num_subsampled_msa,
1112
+ use_paired_feature=model == "boltz2",
1113
+ )
1114
1114
 
1115
1115
  # # Create prediction writer
1116
1116
  # pred_writer = BoltzWriter(
@@ -1200,77 +1200,86 @@ def predict( # noqa: C901, PLR0915, PLR0912
1200
1200
  # return_predictions=False,
1201
1201
  # )
1202
1202
 
1203
- # # Check if affinity predictions are needed
1204
- # if any(r.affinity for r in manifest.records):
1205
- # # Print header
1206
- # click.echo("\nPredicting property: affinity\n")
1203
+ # Check if affinity predictions are needed
1204
+ if any(r.affinity for r in manifest.records):
1205
+ # Print header
1206
+ click.echo("\nPredicting property: affinity\n")
1207
1207
 
1208
- # # Validate inputs
1209
- # manifest_filtered = filter_inputs_affinity(
1210
- # manifest=manifest,
1211
- # outdir=out_dir,
1212
- # override=override,
1213
- # )
1214
- # if not manifest_filtered.records:
1215
- # click.echo("Found existing affinity predictions for all inputs, skipping.")
1216
- # return
1217
-
1218
- # msg = f"Running affinity prediction for {len(manifest_filtered.records)} input"
1219
- # msg += "s." if len(manifest_filtered.records) > 1 else "."
1220
- # click.echo(msg)
1208
+ # Validate inputs
1209
+ manifest_filtered = filter_inputs_affinity(
1210
+ manifest=manifest,
1211
+ outdir=out_dir,
1212
+ override=override,
1213
+ )
1214
+ if not manifest_filtered.records:
1215
+ click.echo("Found existing affinity predictions for all inputs, skipping.")
1216
+ return
1221
1217
 
1222
- # pred_writer = BoltzAffinityWriter(
1223
- # data_dir=processed.targets_dir,
1224
- # output_dir=out_dir / "predictions",
1225
- # )
1218
+ msg = f"Running affinity prediction for {len(manifest_filtered.records)} input"
1219
+ msg += "s." if len(manifest_filtered.records) > 1 else "."
1220
+ click.echo(msg)
1226
1221
 
1227
- # data_module = Boltz2InferenceDataModule(
1228
- # manifest=manifest_filtered,
1229
- # target_dir=out_dir / "predictions",
1230
- # msa_dir=processed.msa_dir,
1231
- # mol_dir=mol_dir,
1232
- # num_workers=num_workers,
1233
- # constraints_dir=processed.constraints_dir,
1234
- # template_dir=processed.template_dir,
1235
- # extra_mols_dir=processed.extra_mols_dir,
1236
- # override_method="other",
1237
- # affinity=True,
1238
- # )
1222
+ pred_writer = BoltzAffinityWriter(
1223
+ data_dir=processed.targets_dir,
1224
+ output_dir=out_dir / "predictions",
1225
+ )
1239
1226
 
1240
- # predict_affinity_args = {
1241
- # "recycling_steps": 5,
1242
- # "sampling_steps": sampling_steps_affinity,
1243
- # "diffusion_samples": diffusion_samples_affinity,
1244
- # "max_parallel_samples": 1,
1245
- # "write_confidence_summary": False,
1246
- # "write_full_pae": False,
1247
- # "write_full_pde": False,
1248
- # }
1227
+ trainer = Trainer(
1228
+ default_root_dir=out_dir,
1229
+ strategy=strategy,
1230
+ callbacks=[pred_writer],
1231
+ accelerator=accelerator,
1232
+ devices=devices,
1233
+ precision=32 if model == "boltz1" else "bf16-mixed",
1234
+ )
1249
1235
 
1250
- # # Load affinity model
1251
- # if affinity_checkpoint is None:
1252
- # affinity_checkpoint = cache / "boltz2_aff.ckpt"
1236
+ data_module = Boltz2InferenceDataModule(
1237
+ manifest=manifest_filtered,
1238
+ target_dir=out_dir / "predictions",
1239
+ msa_dir=processed.msa_dir,
1240
+ mol_dir=mol_dir,
1241
+ num_workers=num_workers,
1242
+ constraints_dir=processed.constraints_dir,
1243
+ template_dir=processed.template_dir,
1244
+ extra_mols_dir=processed.extra_mols_dir,
1245
+ override_method="other",
1246
+ affinity=True,
1247
+ )
1253
1248
 
1254
- # model_module = Boltz2.load_from_checkpoint(
1255
- # affinity_checkpoint,
1256
- # strict=True,
1257
- # predict_args=predict_affinity_args,
1258
- # map_location="cpu",
1259
- # diffusion_process_args=asdict(diffusion_params),
1260
- # ema=False,
1261
- # pairformer_args=asdict(pairformer_args),
1262
- # msa_args=asdict(msa_args),
1263
- # steering_args={"fk_steering": False, "guidance_update": False},
1264
- # affinity_mw_correction=affinity_mw_correction,
1265
- # )
1266
- # model_module.eval()
1249
+ predict_affinity_args = {
1250
+ "recycling_steps": 5,
1251
+ "sampling_steps": sampling_steps_affinity,
1252
+ "diffusion_samples": diffusion_samples_affinity,
1253
+ "max_parallel_samples": 1,
1254
+ "write_confidence_summary": False,
1255
+ "write_full_pae": False,
1256
+ "write_full_pde": False,
1257
+ }
1258
+
1259
+ # Load affinity model
1260
+ if affinity_checkpoint is None:
1261
+ affinity_checkpoint = cache / "boltz2_aff.ckpt"
1262
+
1263
+ model_module = Boltz2.load_from_checkpoint(
1264
+ affinity_checkpoint,
1265
+ strict=True,
1266
+ predict_args=predict_affinity_args,
1267
+ map_location="cpu",
1268
+ diffusion_process_args=asdict(diffusion_params),
1269
+ ema=False,
1270
+ pairformer_args=asdict(pairformer_args),
1271
+ msa_args=asdict(msa_args),
1272
+ steering_args={"fk_steering": False, "guidance_update": False},
1273
+ affinity_mw_correction=affinity_mw_correction,
1274
+ )
1275
+ model_module.eval()
1267
1276
 
1268
- # trainer.callbacks[0] = pred_writer
1269
- # trainer.predict(
1270
- # model_module,
1271
- # datamodule=data_module,
1272
- # return_predictions=False,
1273
- # )
1277
+ trainer.callbacks[0] = pred_writer
1278
+ trainer.predict(
1279
+ model_module,
1280
+ datamodule=data_module,
1281
+ return_predictions=False,
1282
+ )
1274
1283
 
1275
1284
 
1276
1285
  if __name__ == "__main__":
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: boltz-vsynthes
3
- Version: 0.0.13
3
+ Version: 0.0.15
4
4
  Summary: Boltz for VSYNTHES
5
5
  Requires-Python: <3.13,>=3.10
6
6
  Description-Content-Type: text/markdown
@@ -27,7 +27,7 @@ boltz/data/filter/static/ligand.py,sha256=LamC-Z9IjYj3DmfxwMFmPbKBBhRMby3uWQj74w
27
27
  boltz/data/filter/static/polymer.py,sha256=LNsQMsOOnhYpeKuM9AStktoTQPMZE3H0yu4mRg-jwPc,9386
28
28
  boltz/data/module/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
29
29
  boltz/data/module/inference.py,sha256=xk8ZJ8UhjPiPTdOluH_v4gnV8GtTX3sr1WZ1s5Ox8I8,8100
30
- boltz/data/module/inferencev2.py,sha256=3p-jyPstcNzUeaOshEzAexXRBFjvpr-9tP3n8hxT6nw,12508
30
+ boltz/data/module/inferencev2.py,sha256=y1bN4EchiQ3BsWd6Tz6Ug41I_TItP6WQTc4amBuRN0E,12739
31
31
  boltz/data/module/training.py,sha256=iNzmq9ufs20S4M947CCzdYzGTFjmCTf2tFExJ2PtXnA,22428
32
32
  boltz/data/module/trainingv2.py,sha256=ZsYUHYXxfuPgIpbTwCj5QLO0XK__xjsqIw6GARSNGW0,21276
33
33
  boltz/data/msa/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -108,11 +108,11 @@ boltz/model/potentials/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3
108
108
  boltz/model/potentials/potentials.py,sha256=vev8Vjfs-ML1hyrdv_R8DynG4wSFahJ6nzPWp7CYQqw,17507
109
109
  boltz/model/potentials/schedules.py,sha256=m7XJjfuF9uTX3bR9VisXv1rvzJjxiD8PobXRpcBBu1c,968
110
110
  boltz/utils/sdf_splitter.py,sha256=ZHn_syOcmm-fDnJ3YEGyGv_vYz2IRzUW7vbbMSU2JBY,2108
111
- boltz/utils/sdf_to_pre_affinity_npz.py,sha256=ENljAVhA7ZtDkUCp1xuJufTyVbuaQHZAe_vAl6ck-WE,40301
111
+ boltz/utils/sdf_to_pre_affinity_npz.py,sha256=ro0KGe24JexbJm47J8S8w8Lmr_KaQbzOAb_dKZO2G9I,40384
112
112
  boltz/utils/yaml_generator.py,sha256=ermWIG-BE6nNWHFvpEwpk92N9J-YATpGXZGLvD1I2oQ,4012
113
- boltz_vsynthes-0.0.13.dist-info/licenses/LICENSE,sha256=8GZ_1eZsUeG6jdqgJJxtciWzADfgLEV4LY8sKUOsJhc,1102
114
- boltz_vsynthes-0.0.13.dist-info/METADATA,sha256=-ZCZOOLVwXKOYf2E0XY8Q1ZAlHQgZFp0Qj4_mPfeIqU,7235
115
- boltz_vsynthes-0.0.13.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
116
- boltz_vsynthes-0.0.13.dist-info/entry_points.txt,sha256=nZNYPKKrmAr-MVA0K-ClNRT2p90FV1_14d7HpsESZFQ,211
117
- boltz_vsynthes-0.0.13.dist-info/top_level.txt,sha256=MgU3Jfb-ctWm07YGMts68PMjSh9v26D0gfG3dFRmVFA,6
118
- boltz_vsynthes-0.0.13.dist-info/RECORD,,
113
+ boltz_vsynthes-0.0.15.dist-info/licenses/LICENSE,sha256=8GZ_1eZsUeG6jdqgJJxtciWzADfgLEV4LY8sKUOsJhc,1102
114
+ boltz_vsynthes-0.0.15.dist-info/METADATA,sha256=NJSqvHhTsV6-kN89Wa1LdFW8gmXutClh54r_iKsjo94,7235
115
+ boltz_vsynthes-0.0.15.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
116
+ boltz_vsynthes-0.0.15.dist-info/entry_points.txt,sha256=nZNYPKKrmAr-MVA0K-ClNRT2p90FV1_14d7HpsESZFQ,211
117
+ boltz_vsynthes-0.0.15.dist-info/top_level.txt,sha256=MgU3Jfb-ctWm07YGMts68PMjSh9v26D0gfG3dFRmVFA,6
118
+ boltz_vsynthes-0.0.15.dist-info/RECORD,,