bayesianflow-for-chem 1.2.1__tar.gz → 1.2.3__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of bayesianflow-for-chem might be problematic. Click here for more details.

Files changed (18)
  1. {bayesianflow_for_chem-1.2.1 → bayesianflow_for_chem-1.2.3}/PKG-INFO +1 -1
  2. {bayesianflow_for_chem-1.2.1 → bayesianflow_for_chem-1.2.3}/bayesianflow_for_chem/__init__.py +1 -1
  3. {bayesianflow_for_chem-1.2.1 → bayesianflow_for_chem-1.2.3}/bayesianflow_for_chem/tool.py +77 -2
  4. {bayesianflow_for_chem-1.2.1 → bayesianflow_for_chem-1.2.3}/bayesianflow_for_chem.egg-info/PKG-INFO +1 -1
  5. {bayesianflow_for_chem-1.2.1 → bayesianflow_for_chem-1.2.3}/LICENSE +0 -0
  6. {bayesianflow_for_chem-1.2.1 → bayesianflow_for_chem-1.2.3}/README.md +0 -0
  7. {bayesianflow_for_chem-1.2.1 → bayesianflow_for_chem-1.2.3}/bayesianflow_for_chem/data.py +0 -0
  8. {bayesianflow_for_chem-1.2.1 → bayesianflow_for_chem-1.2.3}/bayesianflow_for_chem/model.py +0 -0
  9. {bayesianflow_for_chem-1.2.1 → bayesianflow_for_chem-1.2.3}/bayesianflow_for_chem/scorer.py +0 -0
  10. {bayesianflow_for_chem-1.2.1 → bayesianflow_for_chem-1.2.3}/bayesianflow_for_chem/train.py +0 -0
  11. {bayesianflow_for_chem-1.2.1 → bayesianflow_for_chem-1.2.3}/bayesianflow_for_chem/vocab.txt +0 -0
  12. {bayesianflow_for_chem-1.2.1 → bayesianflow_for_chem-1.2.3}/bayesianflow_for_chem.egg-info/SOURCES.txt +0 -0
  13. {bayesianflow_for_chem-1.2.1 → bayesianflow_for_chem-1.2.3}/bayesianflow_for_chem.egg-info/dependency_links.txt +0 -0
  14. {bayesianflow_for_chem-1.2.1 → bayesianflow_for_chem-1.2.3}/bayesianflow_for_chem.egg-info/requires.txt +0 -0
  15. {bayesianflow_for_chem-1.2.1 → bayesianflow_for_chem-1.2.3}/bayesianflow_for_chem.egg-info/top_level.txt +0 -0
  16. {bayesianflow_for_chem-1.2.1 → bayesianflow_for_chem-1.2.3}/pyproject.toml +0 -0
  17. {bayesianflow_for_chem-1.2.1 → bayesianflow_for_chem-1.2.3}/setup.cfg +0 -0
  18. {bayesianflow_for_chem-1.2.1 → bayesianflow_for_chem-1.2.3}/setup.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.2
2
2
  Name: bayesianflow_for_chem
3
- Version: 1.2.1
3
+ Version: 1.2.3
4
4
  Summary: Bayesian flow network framework for Chemistry
5
5
  Home-page: https://augus1999.github.io/bayesian-flow-network-for-chemistry/
6
6
  Author: Nianze A. Tao
@@ -7,5 +7,5 @@ from . import data, tool, train, scorer
7
7
  from .model import ChemBFN, MLP
8
8
 
9
9
  __all__ = ["data", "tool", "train", "scorer", "ChemBFN", "MLP"]
10
- __version__ = "1.2.1"
10
+ __version__ = "1.2.3"
11
11
  __author__ = "Nianze A. Tao (Omozawa Sueno)"
@@ -11,7 +11,13 @@ from typing import List, Dict, Tuple, Union, Optional
11
11
  import torch
12
12
  import numpy as np
13
13
  from torch import cuda, Tensor, softmax
14
+ from torch.ao.quantization import move_exported_model_to_eval
14
15
  from torch.utils.data import DataLoader
16
+ from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e
17
+ from torch.ao.quantization.quantizer.xnnpack_quantizer import (
18
+ XNNPACKQuantizer,
19
+ get_symmetric_quantization_config,
20
+ )
15
21
  from rdkit.Chem import rdDetermineBonds, Bond, MolFromXYZBlock, CanonicalRankAtoms
16
22
  from rdkit.Chem.Scaffolds.MurckoScaffold import MurckoScaffoldSmiles # type: ignore
17
23
  from sklearn.metrics import (
@@ -380,7 +386,9 @@ def sample(
380
386
  assert method.split(":")[0].lower() in ("ode", "bfn")
381
387
  if device is None:
382
388
  device = _find_device()
383
- model.to(device).eval()
389
+ model.to(device)
390
+ if not isinstance(model, torch.fx.GraphModule):
391
+ model.eval() # Calling eval() is not supported for GraphModule
384
392
  if y is not None:
385
393
  y = y.to(device)
386
394
  if isinstance(allowed_tokens, list):
@@ -455,7 +463,9 @@ def inpaint(
455
463
  assert method.split(":")[0].lower() in ("ode", "bfn")
456
464
  if device is None:
457
465
  device = _find_device()
458
- model.to(device).eval()
466
+ model.to(device)
467
+ if not isinstance(model, torch.fx.GraphModule):
468
+ model.eval() # Calling eval() is not supported for GraphModule
459
469
  x = x.to(device)
460
470
  if y is not None:
461
471
  y = y.to(device)
@@ -484,3 +494,68 @@ def inpaint(
484
494
  .replace("<pad>", "")
485
495
  for j in tokens
486
496
  ]
497
+
498
+
499
def quantise_model(
    model: ChemBFN,
    dataloader: DataLoader,
    mlp: Optional[MLP] = None,
    save_model: bool = False,
    save_model_file_path: Union[str, Path] = "qmodel.pt",
) -> torch.fx.GraphModule:
    """
    Static quantisation of the trained model.

    The model is exported with ``torch.export.export_for_training``, annotated with
    the XNNPACK symmetric quantisation config (``prepare_pt2e``), calibrated on the
    batches yielded by ``dataloader`` and converted with ``convert_pt2e``.
    The graph is traced with static input shapes taken from the first batch of
    ``dataloader``; calibration batches with a different token shape are skipped.

    :param model: trained ChemBFN model
    :param dataloader: DataLoader instance containing example data for calibration;
                       each batch is read as a dict with a ``"token"`` entry
                       (and a ``"value"`` entry when ``mlp`` is given)
    :param mlp: trained MLP model (guidance) if applied
    :param save_model: whether to save the model
    :param save_model_file_path: file name of the saved model; not used if `save_model=False`
    :type model: bayesianflow_for_chem.model.ChemBFN
    :type dataloader: torch.utils.data.DataLoader
    :type mlp: bayesianflow_for_chem.model.MLP | None
    :type save_model: bool
    :type save_model_file_path: str | pathlib.Path
    :return: quantised model
    :rtype: torch.fx.GraphModule
    :raises ValueError: if ``dataloader`` yields no batch
    """
    model.eval()
    # Peek at the first batch through the public iterator API; its token shape
    # fixes the (static) input shape of the exported graph.
    try:
        first_batch = next(iter(dataloader))
    except StopIteration:
        raise ValueError(
            "`dataloader` must yield at least one batch for calibration"
        ) from None
    nb, nt = first_batch["token"].shape
    # Example inputs only need the right shapes/dtypes for tracing. `x` is already
    # mapped to [-1, 1], the same range as the `2 * theta - 1` calibration inputs
    # below, so it is passed as-is (not rescaled a second time).
    x = 2 * softmax(torch.rand((nb, nt, model.K)), -1) - 1
    t = torch.rand((nb, 1, 1))
    # NOTE(review): assumes dim 0 of `model.embedding.weight` is the width of the
    # guidance vector the model expects — confirm against ChemBFN.
    y = torch.randn(nb, 1, model.embedding.weight.shape[0]) if mlp is not None else None
    example_input = (x, t, None, y)
    graph_model = torch.export.export_for_training(model, example_input).module()
    quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config())
    prepared_model = prepare_pt2e(graph_model, quantizer)
    # ------- calibration -------
    with torch.inference_mode():
        move_exported_model_to_eval(prepared_model)
        for data in dataloader:
            tokens = data["token"]
            # The graph was traced with static shapes; skip batches (e.g. a short
            # final batch, or a different padded length) that do not match.
            if tuple(tokens.shape) != (nb, nt):
                continue
            cond = mlp(data["value"])[:, None, :] if mlp is not None else None
            t_cal = torch.rand((nb, 1, 1))
            # Sample a noisy input distribution theta for the one-hot tokens at
            # accuracy beta(t), as fed to the network during training/sampling.
            beta = model.calc_beta(t_cal)
            e_x = torch.nn.functional.one_hot(tokens, model.K).float()
            mu = beta * (model.K * e_x - 1)
            sigma = (beta * model.K).sqrt()
            theta = softmax(mu + sigma * torch.randn_like(mu), -1)
            prepared_model(2 * theta - 1, t_cal, None, cond)
    # ---------------------------
    quantised_model = convert_pt2e(prepared_model)
    # Re-export so the returned graph module holds only the quantised graph
    # (drops the weights of the original model from the graph).
    quantised_model = torch.export.export_for_training(
        quantised_model, example_input
    ).module()
    # NOTE(review): these are methods bound to the original float `model`, so when
    # called through `quantised_model` they execute with `self` being the float
    # model (and keep its weights alive). Confirm whether ChemBFN's sampling
    # methods should instead drive the quantised graph.
    quantised_model.sample = model.sample
    quantised_model.ode_sample = model.ode_sample
    quantised_model.inpaint = model.inpaint
    quantised_model.ode_inpaint = model.ode_inpaint
    if save_model:
        quantised_ep = torch.export.export(quantised_model, example_input)
        torch.export.save(quantised_ep, save_model_file_path)
    return quantised_model
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.2
2
2
  Name: bayesianflow_for_chem
3
- Version: 1.2.1
3
+ Version: 1.2.3
4
4
  Summary: Bayesian flow network framework for Chemistry
5
5
  Home-page: https://augus1999.github.io/bayesian-flow-network-for-chemistry/
6
6
  Author: Nianze A. Tao