bayesianflow-for-chem 1.2.2__py3-none-any.whl → 1.2.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of bayesianflow-for-chem might be problematic; see the package's page on the registry for more details.
- bayesianflow_for_chem/__init__.py +1 -1
- bayesianflow_for_chem/tool.py +17 -4
- {bayesianflow_for_chem-1.2.2.dist-info → bayesianflow_for_chem-1.2.3.dist-info}/METADATA +1 -1
- {bayesianflow_for_chem-1.2.2.dist-info → bayesianflow_for_chem-1.2.3.dist-info}/RECORD +7 -7
- {bayesianflow_for_chem-1.2.2.dist-info → bayesianflow_for_chem-1.2.3.dist-info}/LICENSE +0 -0
- {bayesianflow_for_chem-1.2.2.dist-info → bayesianflow_for_chem-1.2.3.dist-info}/WHEEL +0 -0
- {bayesianflow_for_chem-1.2.2.dist-info → bayesianflow_for_chem-1.2.3.dist-info}/top_level.txt +0 -0
bayesianflow_for_chem/tool.py
CHANGED
|
@@ -11,6 +11,7 @@ from typing import List, Dict, Tuple, Union, Optional
|
|
|
11
11
|
import torch
|
|
12
12
|
import numpy as np
|
|
13
13
|
from torch import cuda, Tensor, softmax
|
|
14
|
+
from torch.ao.quantization import move_exported_model_to_eval
|
|
14
15
|
from torch.utils.data import DataLoader
|
|
15
16
|
from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e
|
|
16
17
|
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
|
|
@@ -388,7 +389,6 @@ def sample(
|
|
|
388
389
|
model.to(device)
|
|
389
390
|
if not isinstance(model, torch.fx.GraphModule):
|
|
390
391
|
model.eval() # Calling eval() is not supported for GraphModule
|
|
391
|
-
# model.to(device).eval()
|
|
392
392
|
if y is not None:
|
|
393
393
|
y = y.to(device)
|
|
394
394
|
if isinstance(allowed_tokens, list):
|
|
@@ -466,7 +466,6 @@ def inpaint(
|
|
|
466
466
|
model.to(device)
|
|
467
467
|
if not isinstance(model, torch.fx.GraphModule):
|
|
468
468
|
model.eval() # Calling eval() is not supported for GraphModule
|
|
469
|
-
# model.to(device).eval()
|
|
470
469
|
x = x.to(device)
|
|
471
470
|
if y is not None:
|
|
472
471
|
y = y.to(device)
|
|
@@ -498,20 +497,29 @@ def inpaint(
|
|
|
498
497
|
|
|
499
498
|
|
|
500
499
|
def quantise_model(
|
|
501
|
-
model: ChemBFN,
|
|
500
|
+
model: ChemBFN,
|
|
501
|
+
dataloader: DataLoader,
|
|
502
|
+
mlp: Optional[MLP] = None,
|
|
503
|
+
save_model: bool = False,
|
|
504
|
+
save_model_file_path: Union[str, Path] = "qmodel.pt",
|
|
502
505
|
) -> torch.fx.GraphModule:
|
|
503
506
|
"""
|
|
504
|
-
Static quantisation of the
|
|
507
|
+
Static quantisation of the trained model.
|
|
505
508
|
|
|
506
509
|
:param model: trained ChemBFN model
|
|
507
510
|
:param dataloader: DataLoader instance containing example data for calibration
|
|
508
511
|
:param mlp: trained MLP model (guidance) if applied
|
|
512
|
+
:param save_model: whether to save the model
|
|
513
|
+
:param save_model_file_path: file name of the saved model; not used if `save_model=False`
|
|
509
514
|
:type model: bayesianflow_for_chem.model.ChemBFN
|
|
510
515
|
:type dataloader: torch.utils.data.DataLoader
|
|
511
516
|
:type mlp: bayesianflow_for_chem.model.MLP | None
|
|
517
|
+
:type save_model: bool
|
|
518
|
+
:type save_model_file_path: str | pathlib.Path
|
|
512
519
|
:return: quantised model
|
|
513
520
|
:rtype: torch.fx.GraphModule
|
|
514
521
|
"""
|
|
522
|
+
model.eval()
|
|
515
523
|
nb, nt = dataloader._get_iterator()._next_data()["token"].shape
|
|
516
524
|
x = 2 * softmax(torch.rand((nb, nt, model.K)), -1) - 1
|
|
517
525
|
t = torch.rand((nb, 1, 1))
|
|
@@ -522,6 +530,7 @@ def quantise_model(
|
|
|
522
530
|
prepared_model = prepare_pt2e(graph_model, quantizer)
|
|
523
531
|
# ------- calibration -------
|
|
524
532
|
with torch.inference_mode():
|
|
533
|
+
move_exported_model_to_eval(prepared_model)
|
|
525
534
|
for data in dataloader:
|
|
526
535
|
x = data["token"]
|
|
527
536
|
if x.shape[0] != nb:
|
|
@@ -537,6 +546,7 @@ def quantise_model(
|
|
|
537
546
|
sigma = (beta * model.K).sqrt()
|
|
538
547
|
theta = softmax(mu + sigma * torch.randn_like(mu), -1)
|
|
539
548
|
prepared_model(2 * theta - 1, t, None, y)
|
|
549
|
+
# ---------------------------
|
|
540
550
|
quantised_model = convert_pt2e(prepared_model)
|
|
541
551
|
quantised_model = torch.export.export_for_training(
|
|
542
552
|
quantised_model, example_input
|
|
@@ -545,4 +555,7 @@ def quantise_model(
|
|
|
545
555
|
quantised_model.ode_sample = model.ode_sample
|
|
546
556
|
quantised_model.inpaint = model.inpaint
|
|
547
557
|
quantised_model.ode_inpaint = model.ode_inpaint
|
|
558
|
+
if save_model:
|
|
559
|
+
quantised_ep = torch.export.export(quantised_model, example_input)
|
|
560
|
+
torch.export.save(quantised_ep, save_model_file_path)
|
|
548
561
|
return quantised_model
|
|
@@ -1,12 +1,12 @@
|
|
|
1
|
-
bayesianflow_for_chem/__init__.py,sha256=
|
|
1
|
+
bayesianflow_for_chem/__init__.py,sha256=g03hNao3V1EGkajm0W0Ydrdnn7rpj7NRoz-JjdTBNUE,293
|
|
2
2
|
bayesianflow_for_chem/data.py,sha256=9tpRba40lxwrB6aPSJMkxUglEVC3VEQC9wWxhDuz3Q8,7760
|
|
3
3
|
bayesianflow_for_chem/model.py,sha256=CEwqUMahNEcVOZaFjv1JcBokktjW9LspFsYzKjzNmZk,35922
|
|
4
4
|
bayesianflow_for_chem/scorer.py,sha256=mV1vX8aBGFra2BE7N8WHihVIo3dXmUdPQIGfSaiuNdk,4084
|
|
5
|
-
bayesianflow_for_chem/tool.py,sha256=
|
|
5
|
+
bayesianflow_for_chem/tool.py,sha256=K6COLatSqvBXwcXV7QtVQX2sJOxfibwiIq0yIh96kfg,20818
|
|
6
6
|
bayesianflow_for_chem/train.py,sha256=kj6icGqymUUYopDtpre1oE_wpvpeNilbpzgffBsd1tk,9589
|
|
7
7
|
bayesianflow_for_chem/vocab.txt,sha256=HgtAZmpWYk4y8PqEVC4vqut1vE75DfRKE_10s2UW0rU,790
|
|
8
|
-
bayesianflow_for_chem-1.2.
|
|
9
|
-
bayesianflow_for_chem-1.2.
|
|
10
|
-
bayesianflow_for_chem-1.2.
|
|
11
|
-
bayesianflow_for_chem-1.2.
|
|
12
|
-
bayesianflow_for_chem-1.2.
|
|
8
|
+
bayesianflow_for_chem-1.2.3.dist-info/LICENSE,sha256=hIahDEOTzuHCU5J2nd07LWwkLW7Hko4UFO__ffsvB-8,34523
|
|
9
|
+
bayesianflow_for_chem-1.2.3.dist-info/METADATA,sha256=QYbTrgY0QqfgpCFWYXbkeWjBVZ7wF9J7RraO6g7sbbI,5890
|
|
10
|
+
bayesianflow_for_chem-1.2.3.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
|
|
11
|
+
bayesianflow_for_chem-1.2.3.dist-info/top_level.txt,sha256=KHsanI3BMCt8D9Qpze2ycrF6nMa3PyojgO6eS1c8kco,22
|
|
12
|
+
bayesianflow_for_chem-1.2.3.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
{bayesianflow_for_chem-1.2.2.dist-info → bayesianflow_for_chem-1.2.3.dist-info}/top_level.txt
RENAMED
|
File without changes
|