bayesianflow-for-chem 1.2.2__py3-none-any.whl → 1.2.3__py3-none-any.whl

This diff shows the changes between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the package contents as they appear in their respective public registries.

Potentially problematic release.


This version of bayesianflow-for-chem might be problematic. Click here for more details.

@@ -7,5 +7,5 @@ from . import data, tool, train, scorer
7
7
  from .model import ChemBFN, MLP
8
8
 
9
9
  __all__ = ["data", "tool", "train", "scorer", "ChemBFN", "MLP"]
10
- __version__ = "1.2.2"
10
+ __version__ = "1.2.3"
11
11
  __author__ = "Nianze A. Tao (Omozawa Sueno)"
@@ -11,6 +11,7 @@ from typing import List, Dict, Tuple, Union, Optional
11
11
  import torch
12
12
  import numpy as np
13
13
  from torch import cuda, Tensor, softmax
14
+ from torch.ao.quantization import move_exported_model_to_eval
14
15
  from torch.utils.data import DataLoader
15
16
  from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e
16
17
  from torch.ao.quantization.quantizer.xnnpack_quantizer import (
@@ -388,7 +389,6 @@ def sample(
388
389
  model.to(device)
389
390
  if not isinstance(model, torch.fx.GraphModule):
390
391
  model.eval() # Calling eval() is not supported for GraphModule
391
- # model.to(device).eval()
392
392
  if y is not None:
393
393
  y = y.to(device)
394
394
  if isinstance(allowed_tokens, list):
@@ -466,7 +466,6 @@ def inpaint(
466
466
  model.to(device)
467
467
  if not isinstance(model, torch.fx.GraphModule):
468
468
  model.eval() # Calling eval() is not supported for GraphModule
469
- # model.to(device).eval()
470
469
  x = x.to(device)
471
470
  if y is not None:
472
471
  y = y.to(device)
@@ -498,20 +497,29 @@ def inpaint(
498
497
 
499
498
 
500
499
  def quantise_model(
501
- model: ChemBFN, dataloader: DataLoader, mlp: Optional[MLP] = None
500
+ model: ChemBFN,
501
+ dataloader: DataLoader,
502
+ mlp: Optional[MLP] = None,
503
+ save_model: bool = False,
504
+ save_model_file_path: Union[str, Path] = "qmodel.pt",
502
505
  ) -> torch.fx.GraphModule:
503
506
  """
504
- Static quantisation of the input model.
507
+ Static quantisation of the trained model.
505
508
 
506
509
  :param model: trained ChemBFN model
507
510
  :param dataloader: DataLoader instance containing example data for calibration
508
511
  :param mlp: trained MLP model (guidance) if applied
512
+ :param save_model: whether to save the model
513
+ :param save_model_file_path: file name of the saved model; not used if `save_model=False`
509
514
  :type model: bayesianflow_for_chem.model.ChemBFN
510
515
  :type dataloader: torch.utils.data.DataLoader
511
516
  :type mlp: bayesianflow_for_chem.model.MLP | None
517
+ :type save_model: bool
518
+ :type save_model_file_path: str | pathlib.Path
512
519
  :return: quantised model
513
520
  :rtype: torch.fx.GraphModule
514
521
  """
522
+ model.eval()
515
523
  nb, nt = dataloader._get_iterator()._next_data()["token"].shape
516
524
  x = 2 * softmax(torch.rand((nb, nt, model.K)), -1) - 1
517
525
  t = torch.rand((nb, 1, 1))
@@ -522,6 +530,7 @@ def quantise_model(
522
530
  prepared_model = prepare_pt2e(graph_model, quantizer)
523
531
  # ------- calibration -------
524
532
  with torch.inference_mode():
533
+ move_exported_model_to_eval(prepared_model)
525
534
  for data in dataloader:
526
535
  x = data["token"]
527
536
  if x.shape[0] != nb:
@@ -537,6 +546,7 @@ def quantise_model(
537
546
  sigma = (beta * model.K).sqrt()
538
547
  theta = softmax(mu + sigma * torch.randn_like(mu), -1)
539
548
  prepared_model(2 * theta - 1, t, None, y)
549
+ # ---------------------------
540
550
  quantised_model = convert_pt2e(prepared_model)
541
551
  quantised_model = torch.export.export_for_training(
542
552
  quantised_model, example_input
@@ -545,4 +555,7 @@ def quantise_model(
545
555
  quantised_model.ode_sample = model.ode_sample
546
556
  quantised_model.inpaint = model.inpaint
547
557
  quantised_model.ode_inpaint = model.ode_inpaint
558
+ if save_model:
559
+ quantised_ep = torch.export.export(quantised_model, example_input)
560
+ torch.export.save(quantised_ep, save_model_file_path)
548
561
  return quantised_model
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.2
2
2
  Name: bayesianflow_for_chem
3
- Version: 1.2.2
3
+ Version: 1.2.3
4
4
  Summary: Bayesian flow network framework for Chemistry
5
5
  Home-page: https://augus1999.github.io/bayesian-flow-network-for-chemistry/
6
6
  Author: Nianze A. Tao
@@ -1,12 +1,12 @@
1
- bayesianflow_for_chem/__init__.py,sha256=sPILW44_x_imRo2kKPMKWQ45C4aNfRQDo1it5Smqqmo,293
1
+ bayesianflow_for_chem/__init__.py,sha256=g03hNao3V1EGkajm0W0Ydrdnn7rpj7NRoz-JjdTBNUE,293
2
2
  bayesianflow_for_chem/data.py,sha256=9tpRba40lxwrB6aPSJMkxUglEVC3VEQC9wWxhDuz3Q8,7760
3
3
  bayesianflow_for_chem/model.py,sha256=CEwqUMahNEcVOZaFjv1JcBokktjW9LspFsYzKjzNmZk,35922
4
4
  bayesianflow_for_chem/scorer.py,sha256=mV1vX8aBGFra2BE7N8WHihVIo3dXmUdPQIGfSaiuNdk,4084
5
- bayesianflow_for_chem/tool.py,sha256=kjR-BUenSjqkwI-TB0QwYXEMy9qdPjL6y4BZVCVfzHA,20237
5
+ bayesianflow_for_chem/tool.py,sha256=K6COLatSqvBXwcXV7QtVQX2sJOxfibwiIq0yIh96kfg,20818
6
6
  bayesianflow_for_chem/train.py,sha256=kj6icGqymUUYopDtpre1oE_wpvpeNilbpzgffBsd1tk,9589
7
7
  bayesianflow_for_chem/vocab.txt,sha256=HgtAZmpWYk4y8PqEVC4vqut1vE75DfRKE_10s2UW0rU,790
8
- bayesianflow_for_chem-1.2.2.dist-info/LICENSE,sha256=hIahDEOTzuHCU5J2nd07LWwkLW7Hko4UFO__ffsvB-8,34523
9
- bayesianflow_for_chem-1.2.2.dist-info/METADATA,sha256=vRxX8mUrOJJwg_vkgXmla8s2vKhmZAVIYH_N3htAElQ,5890
10
- bayesianflow_for_chem-1.2.2.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
11
- bayesianflow_for_chem-1.2.2.dist-info/top_level.txt,sha256=KHsanI3BMCt8D9Qpze2ycrF6nMa3PyojgO6eS1c8kco,22
12
- bayesianflow_for_chem-1.2.2.dist-info/RECORD,,
8
+ bayesianflow_for_chem-1.2.3.dist-info/LICENSE,sha256=hIahDEOTzuHCU5J2nd07LWwkLW7Hko4UFO__ffsvB-8,34523
9
+ bayesianflow_for_chem-1.2.3.dist-info/METADATA,sha256=QYbTrgY0QqfgpCFWYXbkeWjBVZ7wF9J7RraO6g7sbbI,5890
10
+ bayesianflow_for_chem-1.2.3.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
11
+ bayesianflow_for_chem-1.2.3.dist-info/top_level.txt,sha256=KHsanI3BMCt8D9Qpze2ycrF6nMa3PyojgO6eS1c8kco,22
12
+ bayesianflow_for_chem-1.2.3.dist-info/RECORD,,