bayesianflow-for-chem 1.2.1__py3-none-any.whl → 1.2.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of bayesianflow-for-chem might be problematic. Click here for more details.
- bayesianflow_for_chem/__init__.py +1 -1
- bayesianflow_for_chem/tool.py +64 -2
- {bayesianflow_for_chem-1.2.1.dist-info → bayesianflow_for_chem-1.2.2.dist-info}/METADATA +1 -1
- {bayesianflow_for_chem-1.2.1.dist-info → bayesianflow_for_chem-1.2.2.dist-info}/RECORD +7 -7
- {bayesianflow_for_chem-1.2.1.dist-info → bayesianflow_for_chem-1.2.2.dist-info}/LICENSE +0 -0
- {bayesianflow_for_chem-1.2.1.dist-info → bayesianflow_for_chem-1.2.2.dist-info}/WHEEL +0 -0
- {bayesianflow_for_chem-1.2.1.dist-info → bayesianflow_for_chem-1.2.2.dist-info}/top_level.txt +0 -0
bayesianflow_for_chem/tool.py
CHANGED
|
@@ -12,6 +12,11 @@ import torch
|
|
|
12
12
|
import numpy as np
|
|
13
13
|
from torch import cuda, Tensor, softmax
|
|
14
14
|
from torch.utils.data import DataLoader
|
|
15
|
+
from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e
|
|
16
|
+
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
|
|
17
|
+
XNNPACKQuantizer,
|
|
18
|
+
get_symmetric_quantization_config,
|
|
19
|
+
)
|
|
15
20
|
from rdkit.Chem import rdDetermineBonds, Bond, MolFromXYZBlock, CanonicalRankAtoms
|
|
16
21
|
from rdkit.Chem.Scaffolds.MurckoScaffold import MurckoScaffoldSmiles # type: ignore
|
|
17
22
|
from sklearn.metrics import (
|
|
@@ -380,7 +385,10 @@ def sample(
|
|
|
380
385
|
assert method.split(":")[0].lower() in ("ode", "bfn")
|
|
381
386
|
if device is None:
|
|
382
387
|
device = _find_device()
|
|
383
|
-
model.to(device)
|
|
388
|
+
model.to(device)
|
|
389
|
+
if not isinstance(model, torch.fx.GraphModule):
|
|
390
|
+
model.eval() # Calling eval() is not supported for GraphModule
|
|
391
|
+
# model.to(device).eval()
|
|
384
392
|
if y is not None:
|
|
385
393
|
y = y.to(device)
|
|
386
394
|
if isinstance(allowed_tokens, list):
|
|
@@ -455,7 +463,10 @@ def inpaint(
|
|
|
455
463
|
assert method.split(":")[0].lower() in ("ode", "bfn")
|
|
456
464
|
if device is None:
|
|
457
465
|
device = _find_device()
|
|
458
|
-
model.to(device)
|
|
466
|
+
model.to(device)
|
|
467
|
+
if not isinstance(model, torch.fx.GraphModule):
|
|
468
|
+
model.eval() # Calling eval() is not supported for GraphModule
|
|
469
|
+
# model.to(device).eval()
|
|
459
470
|
x = x.to(device)
|
|
460
471
|
if y is not None:
|
|
461
472
|
y = y.to(device)
|
|
@@ -484,3 +495,54 @@ def inpaint(
|
|
|
484
495
|
.replace("<pad>", "")
|
|
485
496
|
for j in tokens
|
|
486
497
|
]
|
|
498
|
+
|
|
499
|
+
|
|
500
|
+
def quantise_model(
    model: ChemBFN, dataloader: DataLoader, mlp: Optional[MLP] = None
) -> torch.fx.GraphModule:
    """
    Static quantisation of the input model.

    The model is exported with ``torch.export.export_for_training``, annotated
    with an XNNPACK symmetric quantisation config, calibrated on batches drawn
    from ``dataloader`` and converted into a quantised ``GraphModule``.

    The exported graph is traced with the batch size of the first batch, so
    calibration stops at the first batch whose size differs from it.

    :param model: trained ChemBFN model
    :param dataloader: DataLoader instance containing example data for calibration;
                       each batch is indexed with ``"token"`` (and ``"value"`` when
                       ``mlp`` is given)
    :param mlp: trained MLP model (guidance) if applied
    :type model: bayesianflow_for_chem.model.ChemBFN
    :type dataloader: torch.utils.data.DataLoader
    :type mlp: bayesianflow_for_chem.model.MLP | None
    :return: quantised model
    :rtype: torch.fx.GraphModule
    :raises ValueError: if ``dataloader`` yields no batch
    """
    # Use the public iterator protocol instead of private DataLoader internals.
    try:
        first_batch = next(iter(dataloader))
    except StopIteration:
        raise ValueError("quantise_model: the calibration dataloader is empty") from None
    nb, nt = first_batch["token"].shape
    # Example inputs only fix shapes/dtypes for tracing. Keep them in the same
    # [-1, 1] range as the calibration inputs ``2 * theta - 1`` used below
    # (``x`` is already rescaled, so it must not be rescaled a second time).
    x = 2 * softmax(torch.rand((nb, nt, model.K)), -1) - 1
    t = torch.rand((nb, 1, 1))
    y = torch.randn(nb, 1, model.embedding.weight.shape[0]) if mlp is not None else None
    example_input = (x, t, None, y)
    # Trace and calibrate in eval mode so that training-mode behaviour is not
    # captured in the exported graph nor seen by the observers; restore the
    # caller's modes afterwards so this function has no lasting side effect.
    model_was_training = model.training
    mlp_was_training = mlp.training if mlp is not None else False
    model.eval()
    if mlp is not None:
        mlp.eval()
    try:
        graph_model = torch.export.export_for_training(model, example_input).module()
        quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config())
        prepared_model = prepare_pt2e(graph_model, quantizer)
        # ------- calibration -------
        with torch.no_grad():
            for data in dataloader:
                tokens = data["token"]
                if tokens.shape[0] != nb:
                    # the exported graph was traced with a fixed batch size
                    break
                if mlp is not None:
                    y_batch = mlp(data["value"])[:, None, :]
                else:
                    y_batch = None
                t_batch = torch.rand((tokens.shape[0], 1, 1))
                # Build a noisy input distribution theta as used during sampling:
                # mu = beta * (K * e_x - 1), sigma = sqrt(beta * K).
                beta = model.calc_beta(t_batch)
                e_x = torch.nn.functional.one_hot(tokens, model.K).float()
                mu = beta * (model.K * e_x - 1)
                sigma = (beta * model.K).sqrt()
                theta = softmax(mu + sigma * torch.randn_like(mu), -1)
                prepared_model(2 * theta - 1, t_batch, None, y_batch)
        quantised_model = convert_pt2e(prepared_model)
        # Re-export the converted graph (intended to drop the float weights
        # held by the converted module).
        quantised_model = torch.export.export_for_training(
            quantised_model, example_input
        ).module()
    finally:
        model.train(model_was_training)
        if mlp is not None:
            mlp.train(mlp_was_training)
    # NOTE(review): these are methods bound to the original float ``model``.
    # They keep ``model`` (and its weights) alive, and if ChemBFN.sample & co.
    # call ``self``'s forward pass (likely — verify in model.py), calling
    # ``quantised_model.sample(...)`` runs the float network rather than the
    # quantised graph.
    quantised_model.sample = model.sample
    quantised_model.ode_sample = model.ode_sample
    quantised_model.inpaint = model.inpaint
    quantised_model.ode_inpaint = model.ode_inpaint
    return quantised_model
|
|
@@ -1,12 +1,12 @@
|
|
|
1
|
-
bayesianflow_for_chem/__init__.py,sha256=
|
|
1
|
+
bayesianflow_for_chem/__init__.py,sha256=sPILW44_x_imRo2kKPMKWQ45C4aNfRQDo1it5Smqqmo,293
|
|
2
2
|
bayesianflow_for_chem/data.py,sha256=9tpRba40lxwrB6aPSJMkxUglEVC3VEQC9wWxhDuz3Q8,7760
|
|
3
3
|
bayesianflow_for_chem/model.py,sha256=CEwqUMahNEcVOZaFjv1JcBokktjW9LspFsYzKjzNmZk,35922
|
|
4
4
|
bayesianflow_for_chem/scorer.py,sha256=mV1vX8aBGFra2BE7N8WHihVIo3dXmUdPQIGfSaiuNdk,4084
|
|
5
|
-
bayesianflow_for_chem/tool.py,sha256=
|
|
5
|
+
bayesianflow_for_chem/tool.py,sha256=kjR-BUenSjqkwI-TB0QwYXEMy9qdPjL6y4BZVCVfzHA,20237
|
|
6
6
|
bayesianflow_for_chem/train.py,sha256=kj6icGqymUUYopDtpre1oE_wpvpeNilbpzgffBsd1tk,9589
|
|
7
7
|
bayesianflow_for_chem/vocab.txt,sha256=HgtAZmpWYk4y8PqEVC4vqut1vE75DfRKE_10s2UW0rU,790
|
|
8
|
-
bayesianflow_for_chem-1.2.
|
|
9
|
-
bayesianflow_for_chem-1.2.
|
|
10
|
-
bayesianflow_for_chem-1.2.
|
|
11
|
-
bayesianflow_for_chem-1.2.
|
|
12
|
-
bayesianflow_for_chem-1.2.
|
|
8
|
+
bayesianflow_for_chem-1.2.2.dist-info/LICENSE,sha256=hIahDEOTzuHCU5J2nd07LWwkLW7Hko4UFO__ffsvB-8,34523
|
|
9
|
+
bayesianflow_for_chem-1.2.2.dist-info/METADATA,sha256=vRxX8mUrOJJwg_vkgXmla8s2vKhmZAVIYH_N3htAElQ,5890
|
|
10
|
+
bayesianflow_for_chem-1.2.2.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
|
|
11
|
+
bayesianflow_for_chem-1.2.2.dist-info/top_level.txt,sha256=KHsanI3BMCt8D9Qpze2ycrF6nMa3PyojgO6eS1c8kco,22
|
|
12
|
+
bayesianflow_for_chem-1.2.2.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
{bayesianflow_for_chem-1.2.1.dist-info → bayesianflow_for_chem-1.2.2.dist-info}/top_level.txt
RENAMED
|
File without changes
|