bayesianflow-for-chem 1.2.1.tar.gz → 1.2.3.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {bayesianflow_for_chem-1.2.1 → bayesianflow_for_chem-1.2.3}/PKG-INFO +1 -1
- {bayesianflow_for_chem-1.2.1 → bayesianflow_for_chem-1.2.3}/bayesianflow_for_chem/__init__.py +1 -1
- {bayesianflow_for_chem-1.2.1 → bayesianflow_for_chem-1.2.3}/bayesianflow_for_chem/tool.py +77 -2
- {bayesianflow_for_chem-1.2.1 → bayesianflow_for_chem-1.2.3}/bayesianflow_for_chem.egg-info/PKG-INFO +1 -1
- {bayesianflow_for_chem-1.2.1 → bayesianflow_for_chem-1.2.3}/LICENSE +0 -0
- {bayesianflow_for_chem-1.2.1 → bayesianflow_for_chem-1.2.3}/README.md +0 -0
- {bayesianflow_for_chem-1.2.1 → bayesianflow_for_chem-1.2.3}/bayesianflow_for_chem/data.py +0 -0
- {bayesianflow_for_chem-1.2.1 → bayesianflow_for_chem-1.2.3}/bayesianflow_for_chem/model.py +0 -0
- {bayesianflow_for_chem-1.2.1 → bayesianflow_for_chem-1.2.3}/bayesianflow_for_chem/scorer.py +0 -0
- {bayesianflow_for_chem-1.2.1 → bayesianflow_for_chem-1.2.3}/bayesianflow_for_chem/train.py +0 -0
- {bayesianflow_for_chem-1.2.1 → bayesianflow_for_chem-1.2.3}/bayesianflow_for_chem/vocab.txt +0 -0
- {bayesianflow_for_chem-1.2.1 → bayesianflow_for_chem-1.2.3}/bayesianflow_for_chem.egg-info/SOURCES.txt +0 -0
- {bayesianflow_for_chem-1.2.1 → bayesianflow_for_chem-1.2.3}/bayesianflow_for_chem.egg-info/dependency_links.txt +0 -0
- {bayesianflow_for_chem-1.2.1 → bayesianflow_for_chem-1.2.3}/bayesianflow_for_chem.egg-info/requires.txt +0 -0
- {bayesianflow_for_chem-1.2.1 → bayesianflow_for_chem-1.2.3}/bayesianflow_for_chem.egg-info/top_level.txt +0 -0
- {bayesianflow_for_chem-1.2.1 → bayesianflow_for_chem-1.2.3}/pyproject.toml +0 -0
- {bayesianflow_for_chem-1.2.1 → bayesianflow_for_chem-1.2.3}/setup.cfg +0 -0
- {bayesianflow_for_chem-1.2.1 → bayesianflow_for_chem-1.2.3}/setup.py +0 -0
{bayesianflow_for_chem-1.2.1 → bayesianflow_for_chem-1.2.3}/bayesianflow_for_chem/tool.py

```diff
@@ -11,7 +11,13 @@ from typing import List, Dict, Tuple, Union, Optional
 import torch
 import numpy as np
 from torch import cuda, Tensor, softmax
+from torch.ao.quantization import move_exported_model_to_eval
 from torch.utils.data import DataLoader
+from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e
+from torch.ao.quantization.quantizer.xnnpack_quantizer import (
+    XNNPACKQuantizer,
+    get_symmetric_quantization_config,
+)
 from rdkit.Chem import rdDetermineBonds, Bond, MolFromXYZBlock, CanonicalRankAtoms
 from rdkit.Chem.Scaffolds.MurckoScaffold import MurckoScaffoldSmiles  # type: ignore
 from sklearn.metrics import (
```
```diff
@@ -380,7 +386,9 @@ def sample(
     assert method.split(":")[0].lower() in ("ode", "bfn")
     if device is None:
         device = _find_device()
-    model.to(device)
+    model.to(device)
+    if not isinstance(model, torch.fx.GraphModule):
+        model.eval()  # Calling eval() is not supported for GraphModule
     if y is not None:
         y = y.to(device)
     if isinstance(allowed_tokens, list):
```
```diff
@@ -455,7 +463,9 @@ def inpaint(
     assert method.split(":")[0].lower() in ("ode", "bfn")
     if device is None:
         device = _find_device()
-    model.to(device)
+    model.to(device)
+    if not isinstance(model, torch.fx.GraphModule):
+        model.eval()  # Calling eval() is not supported for GraphModule
     x = x.to(device)
     if y is not None:
         y = y.to(device)
```
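Both `sample` and `inpaint` now skip `model.eval()` when the model is a `torch.fx.GraphModule`, since PT2E-exported graphs do not implement `eval()`/`train()`; the newly imported `move_exported_model_to_eval` helper is PyTorch's replacement for that call. A minimal sketch of the distinction follows, using a hypothetical toy module that is not part of this package and assuming a PyTorch version that provides `torch.export.export_for_training`:

```python
import torch
from torch.ao.quantization import move_exported_model_to_eval


class Toy(torch.nn.Module):  # hypothetical module, for illustration only
    def __init__(self):
        super().__init__()
        self.drop = torch.nn.Dropout(0.1)

    def forward(self, x):
        return self.drop(x)


eager = Toy()
exported = torch.export.export_for_training(eager, (torch.randn(2, 4),)).module()

for m in (eager, exported):
    if isinstance(m, torch.fx.GraphModule):
        # Exported graphs do not support eval()/train(); use the helper instead.
        move_exported_model_to_eval(m)
    else:
        m.eval()
```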
```diff
@@ -484,3 +494,68 @@ def inpaint(
             .replace("<pad>", "")
             for j in tokens
         ]
+
+
+def quantise_model(
+    model: ChemBFN,
+    dataloader: DataLoader,
+    mlp: Optional[MLP] = None,
+    save_model: bool = False,
+    save_model_file_path: Union[str, Path] = "qmodel.pt",
+) -> torch.fx.GraphModule:
+    """
+    Static quantisation of the trained model.
+
+    :param model: trained ChemBFN model
+    :param dataloader: DataLoader instance containing example data for calibration
+    :param mlp: trained MLP model (guidance) if applied
+    :param save_model: whether to save the model
+    :param save_model_file_path: file name of the saved model; not used if `save_model=False`
+    :type model: bayesianflow_for_chem.model.ChemBFN
+    :type dataloader: torch.utils.data.DataLoader
+    :type mlp: bayesianflow_for_chem.model.MLP | None
+    :type save_model: bool
+    :type save_model_file_path: str | pathlib.Path
+    :return: quantised model
+    :rtype: torch.fx.GraphModule
+    """
+    model.eval()
+    nb, nt = dataloader._get_iterator()._next_data()["token"].shape
+    x = 2 * softmax(torch.rand((nb, nt, model.K)), -1) - 1
+    t = torch.rand((nb, 1, 1))
+    y = torch.randn(nb, 1, model.embedding.weight.shape[0]) if mlp is not None else None
+    example_input = (2 * x - 1, t, None, y)
+    graph_model = torch.export.export_for_training(model, example_input).module()
+    quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config())
+    prepared_model = prepare_pt2e(graph_model, quantizer)
+    # ------- calibration -------
+    with torch.inference_mode():
+        move_exported_model_to_eval(prepared_model)
+        for data in dataloader:
+            x = data["token"]
+            if x.shape[0] != nb:
+                break
+            if mlp is not None:
+                y = mlp(data["value"])[:, None, :]
+            else:
+                y = None
+            t = torch.rand((x.shape[0], 1, 1))
+            beta = model.calc_beta(t)
+            e_x = torch.nn.functional.one_hot(x, model.K).float()
+            mu = beta * (model.K * e_x - 1)
+            sigma = (beta * model.K).sqrt()
+            theta = softmax(mu + sigma * torch.randn_like(mu), -1)
+            prepared_model(2 * theta - 1, t, None, y)
+    # ---------------------------
+    quantised_model = convert_pt2e(prepared_model)
+    quantised_model = torch.export.export_for_training(
+        quantised_model, example_input
+    ).module()  # remove the weights of original model
+    quantised_model.sample = model.sample
+    quantised_model.ode_sample = model.ode_sample
+    quantised_model.inpaint = model.inpaint
+    quantised_model.ode_inpaint = model.ode_inpaint
+    if save_model:
+        quantised_ep = torch.export.export(quantised_model, example_input)
+        torch.export.save(quantised_ep, save_model_file_path)
+    return quantised_model
```
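For context, here is a hedged usage sketch of the new `quantise_model` helper added in 1.2.3. The checkpoint path, the random calibration tokens, and the tiny dict-returning dataset are placeholders (real calibration should use batches drawn from the training or validation set), and the `ChemBFN.from_checkpoint` loading call is an assumption about how a trained model is loaded rather than something shown in this diff:

```python
import torch
from torch.utils.data import DataLoader, Dataset
from bayesianflow_for_chem.model import ChemBFN
from bayesianflow_for_chem.tool import quantise_model


class TokenDictDataset(Dataset):
    """Placeholder dataset: quantise_model reads each batch as a dict with a "token" key."""

    def __init__(self, tokens: torch.Tensor):
        self.tokens = tokens

    def __len__(self) -> int:
        return len(self.tokens)

    def __getitem__(self, idx: int) -> dict:
        return {"token": self.tokens[idx]}


model = ChemBFN.from_checkpoint("model.pt")    # assumed loading API; path is a placeholder
tokens = torch.randint(0, model.K, (128, 64))  # placeholder tokens; use real data in practice
# drop_last keeps every batch the same size, since calibration stops at the first odd-sized batch
loader = DataLoader(TokenDictDataset(tokens), batch_size=16, drop_last=True)

qmodel = quantise_model(model, loader, save_model=True, save_model_file_path="qmodel.pt")
```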