boltz-vsynthes 0.0.13__py3-none-any.whl → 0.0.15__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- boltz/data/module/inferencev2.py +7 -1
- boltz/utils/sdf_to_pre_affinity_npz.py +108 -99
- {boltz_vsynthes-0.0.13.dist-info → boltz_vsynthes-0.0.15.dist-info}/METADATA +1 -1
- {boltz_vsynthes-0.0.13.dist-info → boltz_vsynthes-0.0.15.dist-info}/RECORD +8 -8
- {boltz_vsynthes-0.0.13.dist-info → boltz_vsynthes-0.0.15.dist-info}/WHEEL +0 -0
- {boltz_vsynthes-0.0.13.dist-info → boltz_vsynthes-0.0.15.dist-info}/entry_points.txt +0 -0
- {boltz_vsynthes-0.0.13.dist-info → boltz_vsynthes-0.0.15.dist-info}/licenses/LICENSE +0 -0
- {boltz_vsynthes-0.0.13.dist-info → boltz_vsynthes-0.0.15.dist-info}/top_level.txt +0 -0
boltz/data/module/inferencev2.py
CHANGED
@@ -59,9 +59,15 @@ def load_input(
|
|
59
59
|
|
60
60
|
"""
|
61
61
|
# Load the structure
|
62
|
+
# if affinity:
|
63
|
+
# structure = StructureV2.load(
|
64
|
+
# target_dir / record.id / f"pre_affinity_{record.id}.npz"
|
65
|
+
# )
|
62
66
|
if affinity:
|
67
|
+
if target_dir.name == "predictions":
|
68
|
+
target_dir = target_dir.parent / "processed"
|
63
69
|
structure = StructureV2.load(
|
64
|
-
target_dir / record.id / f"
|
70
|
+
target_dir / record.id / f"{record.id}.npz"
|
65
71
|
)
|
66
72
|
else:
|
67
73
|
structure = StructureV2.load(target_dir / f"{record.id}.npz")
|
@@ -1076,41 +1076,41 @@ def predict( # noqa: C901, PLR0915, PLR0912
|
|
1076
1076
|
),
|
1077
1077
|
)
|
1078
1078
|
|
1079
|
-
#
|
1080
|
-
|
1081
|
-
|
1082
|
-
|
1083
|
-
|
1084
|
-
|
1085
|
-
|
1086
|
-
|
1087
|
-
|
1088
|
-
|
1089
|
-
|
1090
|
-
|
1091
|
-
|
1092
|
-
|
1093
|
-
|
1094
|
-
|
1095
|
-
|
1096
|
-
|
1097
|
-
#
|
1098
|
-
|
1099
|
-
|
1100
|
-
|
1101
|
-
|
1102
|
-
|
1103
|
-
|
1104
|
-
|
1105
|
-
|
1106
|
-
|
1107
|
-
|
1108
|
-
|
1109
|
-
|
1110
|
-
|
1111
|
-
|
1112
|
-
|
1113
|
-
|
1079
|
+
# Set up trainer
|
1080
|
+
strategy = "auto"
|
1081
|
+
if (isinstance(devices, int) and devices > 1) or (
|
1082
|
+
isinstance(devices, list) and len(devices) > 1
|
1083
|
+
):
|
1084
|
+
start_method = "fork" if platform.system() != "win32" else "spawn"
|
1085
|
+
strategy = DDPStrategy(start_method=start_method)
|
1086
|
+
if len(filtered_manifest.records) < devices:
|
1087
|
+
msg = (
|
1088
|
+
"Number of requested devices is greater "
|
1089
|
+
"than the number of predictions, taking the minimum."
|
1090
|
+
)
|
1091
|
+
click.echo(msg)
|
1092
|
+
if isinstance(devices, list):
|
1093
|
+
devices = devices[: max(1, len(filtered_manifest.records))]
|
1094
|
+
else:
|
1095
|
+
devices = max(1, min(len(filtered_manifest.records), devices))
|
1096
|
+
|
1097
|
+
# Set up model parameters
|
1098
|
+
if model == "boltz2":
|
1099
|
+
diffusion_params = Boltz2DiffusionParams()
|
1100
|
+
step_scale = 1.5 if step_scale is None else step_scale
|
1101
|
+
diffusion_params.step_scale = step_scale
|
1102
|
+
pairformer_args = PairformerArgsV2()
|
1103
|
+
else:
|
1104
|
+
diffusion_params = BoltzDiffusionParams()
|
1105
|
+
step_scale = 1.638 if step_scale is None else step_scale
|
1106
|
+
diffusion_params.step_scale = step_scale
|
1107
|
+
pairformer_args = PairformerArgs()
|
1108
|
+
|
1109
|
+
msa_args = MSAModuleArgs(
|
1110
|
+
subsample_msa=subsample_msa,
|
1111
|
+
num_subsampled_msa=num_subsampled_msa,
|
1112
|
+
use_paired_feature=model == "boltz2",
|
1113
|
+
)
|
1114
1114
|
|
1115
1115
|
# # Create prediction writer
|
1116
1116
|
# pred_writer = BoltzWriter(
|
@@ -1200,77 +1200,86 @@ def predict( # noqa: C901, PLR0915, PLR0912
|
|
1200
1200
|
# return_predictions=False,
|
1201
1201
|
# )
|
1202
1202
|
|
1203
|
-
#
|
1204
|
-
|
1205
|
-
|
1206
|
-
|
1203
|
+
# Check if affinity predictions are needed
|
1204
|
+
if any(r.affinity for r in manifest.records):
|
1205
|
+
# Print header
|
1206
|
+
click.echo("\nPredicting property: affinity\n")
|
1207
1207
|
|
1208
|
-
|
1209
|
-
|
1210
|
-
|
1211
|
-
|
1212
|
-
|
1213
|
-
|
1214
|
-
|
1215
|
-
|
1216
|
-
|
1217
|
-
|
1218
|
-
# msg = f"Running affinity prediction for {len(manifest_filtered.records)} input"
|
1219
|
-
# msg += "s." if len(manifest_filtered.records) > 1 else "."
|
1220
|
-
# click.echo(msg)
|
1208
|
+
# Validate inputs
|
1209
|
+
manifest_filtered = filter_inputs_affinity(
|
1210
|
+
manifest=manifest,
|
1211
|
+
outdir=out_dir,
|
1212
|
+
override=override,
|
1213
|
+
)
|
1214
|
+
if not manifest_filtered.records:
|
1215
|
+
click.echo("Found existing affinity predictions for all inputs, skipping.")
|
1216
|
+
return
|
1221
1217
|
|
1222
|
-
|
1223
|
-
|
1224
|
-
|
1225
|
-
# )
|
1218
|
+
msg = f"Running affinity prediction for {len(manifest_filtered.records)} input"
|
1219
|
+
msg += "s." if len(manifest_filtered.records) > 1 else "."
|
1220
|
+
click.echo(msg)
|
1226
1221
|
|
1227
|
-
|
1228
|
-
|
1229
|
-
|
1230
|
-
|
1231
|
-
# mol_dir=mol_dir,
|
1232
|
-
# num_workers=num_workers,
|
1233
|
-
# constraints_dir=processed.constraints_dir,
|
1234
|
-
# template_dir=processed.template_dir,
|
1235
|
-
# extra_mols_dir=processed.extra_mols_dir,
|
1236
|
-
# override_method="other",
|
1237
|
-
# affinity=True,
|
1238
|
-
# )
|
1222
|
+
pred_writer = BoltzAffinityWriter(
|
1223
|
+
data_dir=processed.targets_dir,
|
1224
|
+
output_dir=out_dir / "predictions",
|
1225
|
+
)
|
1239
1226
|
|
1240
|
-
|
1241
|
-
|
1242
|
-
|
1243
|
-
|
1244
|
-
|
1245
|
-
|
1246
|
-
|
1247
|
-
|
1248
|
-
# }
|
1227
|
+
trainer = Trainer(
|
1228
|
+
default_root_dir=out_dir,
|
1229
|
+
strategy=strategy,
|
1230
|
+
callbacks=[pred_writer],
|
1231
|
+
accelerator=accelerator,
|
1232
|
+
devices=devices,
|
1233
|
+
precision=32 if model == "boltz1" else "bf16-mixed",
|
1234
|
+
)
|
1249
1235
|
|
1250
|
-
|
1251
|
-
|
1252
|
-
|
1236
|
+
data_module = Boltz2InferenceDataModule(
|
1237
|
+
manifest=manifest_filtered,
|
1238
|
+
target_dir=out_dir / "predictions",
|
1239
|
+
msa_dir=processed.msa_dir,
|
1240
|
+
mol_dir=mol_dir,
|
1241
|
+
num_workers=num_workers,
|
1242
|
+
constraints_dir=processed.constraints_dir,
|
1243
|
+
template_dir=processed.template_dir,
|
1244
|
+
extra_mols_dir=processed.extra_mols_dir,
|
1245
|
+
override_method="other",
|
1246
|
+
affinity=True,
|
1247
|
+
)
|
1253
1248
|
|
1254
|
-
|
1255
|
-
|
1256
|
-
|
1257
|
-
|
1258
|
-
|
1259
|
-
|
1260
|
-
|
1261
|
-
|
1262
|
-
|
1263
|
-
|
1264
|
-
|
1265
|
-
|
1266
|
-
|
1249
|
+
predict_affinity_args = {
|
1250
|
+
"recycling_steps": 5,
|
1251
|
+
"sampling_steps": sampling_steps_affinity,
|
1252
|
+
"diffusion_samples": diffusion_samples_affinity,
|
1253
|
+
"max_parallel_samples": 1,
|
1254
|
+
"write_confidence_summary": False,
|
1255
|
+
"write_full_pae": False,
|
1256
|
+
"write_full_pde": False,
|
1257
|
+
}
|
1258
|
+
|
1259
|
+
# Load affinity model
|
1260
|
+
if affinity_checkpoint is None:
|
1261
|
+
affinity_checkpoint = cache / "boltz2_aff.ckpt"
|
1262
|
+
|
1263
|
+
model_module = Boltz2.load_from_checkpoint(
|
1264
|
+
affinity_checkpoint,
|
1265
|
+
strict=True,
|
1266
|
+
predict_args=predict_affinity_args,
|
1267
|
+
map_location="cpu",
|
1268
|
+
diffusion_process_args=asdict(diffusion_params),
|
1269
|
+
ema=False,
|
1270
|
+
pairformer_args=asdict(pairformer_args),
|
1271
|
+
msa_args=asdict(msa_args),
|
1272
|
+
steering_args={"fk_steering": False, "guidance_update": False},
|
1273
|
+
affinity_mw_correction=affinity_mw_correction,
|
1274
|
+
)
|
1275
|
+
model_module.eval()
|
1267
1276
|
|
1268
|
-
|
1269
|
-
|
1270
|
-
|
1271
|
-
|
1272
|
-
|
1273
|
-
|
1277
|
+
trainer.callbacks[0] = pred_writer
|
1278
|
+
trainer.predict(
|
1279
|
+
model_module,
|
1280
|
+
datamodule=data_module,
|
1281
|
+
return_predictions=False,
|
1282
|
+
)
|
1274
1283
|
|
1275
1284
|
|
1276
1285
|
if __name__ == "__main__":
|
@@ -27,7 +27,7 @@ boltz/data/filter/static/ligand.py,sha256=LamC-Z9IjYj3DmfxwMFmPbKBBhRMby3uWQj74w
|
|
27
27
|
boltz/data/filter/static/polymer.py,sha256=LNsQMsOOnhYpeKuM9AStktoTQPMZE3H0yu4mRg-jwPc,9386
|
28
28
|
boltz/data/module/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
29
29
|
boltz/data/module/inference.py,sha256=xk8ZJ8UhjPiPTdOluH_v4gnV8GtTX3sr1WZ1s5Ox8I8,8100
|
30
|
-
boltz/data/module/inferencev2.py,sha256=
|
30
|
+
boltz/data/module/inferencev2.py,sha256=y1bN4EchiQ3BsWd6Tz6Ug41I_TItP6WQTc4amBuRN0E,12739
|
31
31
|
boltz/data/module/training.py,sha256=iNzmq9ufs20S4M947CCzdYzGTFjmCTf2tFExJ2PtXnA,22428
|
32
32
|
boltz/data/module/trainingv2.py,sha256=ZsYUHYXxfuPgIpbTwCj5QLO0XK__xjsqIw6GARSNGW0,21276
|
33
33
|
boltz/data/msa/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
@@ -108,11 +108,11 @@ boltz/model/potentials/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3
|
|
108
108
|
boltz/model/potentials/potentials.py,sha256=vev8Vjfs-ML1hyrdv_R8DynG4wSFahJ6nzPWp7CYQqw,17507
|
109
109
|
boltz/model/potentials/schedules.py,sha256=m7XJjfuF9uTX3bR9VisXv1rvzJjxiD8PobXRpcBBu1c,968
|
110
110
|
boltz/utils/sdf_splitter.py,sha256=ZHn_syOcmm-fDnJ3YEGyGv_vYz2IRzUW7vbbMSU2JBY,2108
|
111
|
-
boltz/utils/sdf_to_pre_affinity_npz.py,sha256=
|
111
|
+
boltz/utils/sdf_to_pre_affinity_npz.py,sha256=ro0KGe24JexbJm47J8S8w8Lmr_KaQbzOAb_dKZO2G9I,40384
|
112
112
|
boltz/utils/yaml_generator.py,sha256=ermWIG-BE6nNWHFvpEwpk92N9J-YATpGXZGLvD1I2oQ,4012
|
113
|
-
boltz_vsynthes-0.0.
|
114
|
-
boltz_vsynthes-0.0.
|
115
|
-
boltz_vsynthes-0.0.
|
116
|
-
boltz_vsynthes-0.0.
|
117
|
-
boltz_vsynthes-0.0.
|
118
|
-
boltz_vsynthes-0.0.
|
113
|
+
boltz_vsynthes-0.0.15.dist-info/licenses/LICENSE,sha256=8GZ_1eZsUeG6jdqgJJxtciWzADfgLEV4LY8sKUOsJhc,1102
|
114
|
+
boltz_vsynthes-0.0.15.dist-info/METADATA,sha256=NJSqvHhTsV6-kN89Wa1LdFW8gmXutClh54r_iKsjo94,7235
|
115
|
+
boltz_vsynthes-0.0.15.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
116
|
+
boltz_vsynthes-0.0.15.dist-info/entry_points.txt,sha256=nZNYPKKrmAr-MVA0K-ClNRT2p90FV1_14d7HpsESZFQ,211
|
117
|
+
boltz_vsynthes-0.0.15.dist-info/top_level.txt,sha256=MgU3Jfb-ctWm07YGMts68PMjSh9v26D0gfG3dFRmVFA,6
|
118
|
+
boltz_vsynthes-0.0.15.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|