bayesianflow-for-chem 1.2.3-py3-none-any.whl → 1.2.5-py3-none-any.whl

This diff represents the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the package versions as they appear in the public registry.

Potentially problematic release.

bayesianflow_for_chem/__init__.py
@@ -7,5 +7,5 @@ from . import data, tool, train, scorer
  from .model import ChemBFN, MLP
 
  __all__ = ["data", "tool", "train", "scorer", "ChemBFN", "MLP"]
- __version__ = "1.2.3"
+ __version__ = "1.2.5"
  __author__ = "Nianze A. Tao (Omozawa Sueno)"

bayesianflow_for_chem/model.py
@@ -847,7 +847,7 @@ class ChemBFN(nn.Module):
  with open(ckpt, "rb") as f:
  state = torch.load(f, "cpu", weights_only=True)
  nn, hparam = state["nn"], state["hparam"]
- model = ChemBFN(
+ model = cls(
  hparam["num_vocab"],
  hparam["channel"],
  hparam["num_layer"],
@@ -926,7 +926,7 @@ class MLP(nn.Module):
  with open(ckpt, "rb") as f:
  state = torch.load(f, "cpu", weights_only=True)
  nn, hparam = state["nn"], state["hparam"]
- model = MLP(hparam["size"], hparam["class_input"], hparam["dropout"])
+ model = cls(hparam["size"], hparam["class_input"], hparam["dropout"])
  model.load_state_dict(nn, strict)
  return model
 

bayesianflow_for_chem/tool.py
@@ -6,18 +6,16 @@ Tools.
  import re
  import csv
  import random
+ from copy import deepcopy
  from pathlib import Path
  from typing import List, Dict, Tuple, Union, Optional
  import torch
  import numpy as np
+ import torch.nn as nn
  from torch import cuda, Tensor, softmax
- from torch.ao.quantization import move_exported_model_to_eval
+ from torch.ao import quantization
  from torch.utils.data import DataLoader
- from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e
- from torch.ao.quantization.quantizer.xnnpack_quantizer import (
- XNNPACKQuantizer,
- get_symmetric_quantization_config,
- )
+ from typing_extensions import Self
  from rdkit.Chem import rdDetermineBonds, Bond, MolFromXYZBlock, CanonicalRankAtoms
  from rdkit.Chem.Scaffolds.MurckoScaffold import MurckoScaffoldSmiles # type: ignore
  from sklearn.metrics import (
@@ -39,7 +37,7 @@ except ImportError:
  _use_pynauty = False
 
  from .data import VOCAB_KEYS
- from .model import ChemBFN, MLP
+ from .model import ChemBFN, MLP, Linear
 
 
  _atom_regex_pattern = (
@@ -386,9 +384,7 @@ def sample(
  assert method.split(":")[0].lower() in ("ode", "bfn")
  if device is None:
  device = _find_device()
- model.to(device)
- if not isinstance(model, torch.fx.GraphModule):
- model.eval() # Calling eval() is not supported for GraphModule
+ model.to(device).eval()
  if y is not None:
  y = y.to(device)
  if isinstance(allowed_tokens, list):
@@ -463,9 +459,7 @@ def inpaint(
  assert method.split(":")[0].lower() in ("ode", "bfn")
  if device is None:
  device = _find_device()
- model.to(device)
- if not isinstance(model, torch.fx.GraphModule):
- model.eval() # Calling eval() is not supported for GraphModule
+ model.to(device).eval()
  x = x.to(device)
  if y is not None:
  y = y.to(device)
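
In both sample and inpaint the GraphModule special case is dropped: quantise_model now returns an ordinary torch.nn.Module (see the next hunk), eval() is supported on any nn.Module, and nn.Module.to() returns the module itself, so the calls can be chained. A minimal sketch of why the chained call also works for a dynamically quantised module (toy module, illustrative only):

import torch
import torch.nn as nn

float_model = nn.Linear(4, 2)
# Dynamic quantisation returns a regular nn.Module with the Linear layer swapped out.
qmodel = torch.ao.quantization.quantize_dynamic(float_model, {nn.Linear}, torch.qint8)

device = torch.device("cpu")
# to() returns the module itself, so eval() can be chained on the result.
assert qmodel.to(device).eval() is qmodel
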
@@ -496,66 +490,129 @@
  ]
 
 
- def quantise_model(
- model: ChemBFN,
- dataloader: DataLoader,
- mlp: Optional[MLP] = None,
- save_model: bool = False,
- save_model_file_path: Union[str, Path] = "qmodel.pt",
- ) -> torch.fx.GraphModule:
+ def quantise_model(model: ChemBFN) -> nn.Module:
  """
- Static quantisation of the trained model.
+ Dynamic quantisation of the trained model.
 
  :param model: trained ChemBFN model
- :param dataloader: DataLoader instance containing example data for calibration
- :param mlp: trained MLP model (guidance) if applied
- :param save_model: whether to save the model
- :param save_model_file_path: file name of the saved model; not used if `save_model=False`
  :type model: bayesianflow_for_chem.model.ChemBFN
- :type dataloader: torch.utils.data.DataLoader
- :type mlp: bayesianflow_for_chem.model.MLP | None
- :type save_model: bool
- :type save_model_file_path: str | pathlib.Path
  :return: quantised model
- :rtype: torch.fx.GraphModule
+ :rtype: torch.nn.Module
  """
- model.eval()
- nb, nt = dataloader._get_iterator()._next_data()["token"].shape
- x = 2 * softmax(torch.rand((nb, nt, model.K)), -1) - 1
- t = torch.rand((nb, 1, 1))
- y = torch.randn(nb, 1, model.embedding.weight.shape[0]) if mlp is not None else None
- example_input = (2 * x - 1, t, None, y)
- graph_model = torch.export.export_for_training(model, example_input).module()
- quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config())
- prepared_model = prepare_pt2e(graph_model, quantizer)
- # ------- calibration -------
- with torch.inference_mode():
- move_exported_model_to_eval(prepared_model)
- for data in dataloader:
- x = data["token"]
- if x.shape[0] != nb:
- break
- if mlp is not None:
- y = mlp(data["value"])[:, None, :]
+ from torch.ao.nn.quantized.modules.utils import _quantize_weight
+ from torch.ao.nn.quantized import dynamic
+
+ class QuantisedLinear(dynamic.Linear):
+ # Modified from https://github.com/pytorch/pytorch/blob/main/torch/ao/nn/quantized/dynamic/modules/linear.py
+ # We made it compatible with our LoRA linear layer.
+ # LoRA parameters will not be quantised.
+ def __init__(
+ self,
+ in_features: int,
+ out_features: int,
+ bias_: bool = True,
+ dtype: torch.dtype = torch.qint8,
+ ) -> None:
+ super().__init__(in_features, out_features, bias_, dtype=dtype)
+ self.version = self._version
+ self.lora_enabled: bool = False
+ self.lora_A: Optional[nn.Parameter] = None
+ self.lora_B: Optional[nn.Parameter] = None
+ self.scaling: Optional[float] = None
+ self.lora_dropout: Optional[float] = None
+
+ def _get_name(self) -> str:
+ return "DynamicQuantizedLoRALinear"
+
+ def enable_lora(
+ self, r: int = 8, lora_alpha: int = 1, lora_dropout: float = 0.0
+ ) -> None:
+ assert r > 0, "Rank should be larger than 0."
+ device = self._weight_bias()[0].device
+ self.lora_A = nn.Parameter(
+ torch.zeros((r, self.in_features), device=device)
+ )
+ self.lora_B = nn.Parameter(
+ torch.zeros((self.out_features, r), device=device)
+ )
+ self.scaling = lora_alpha / r
+ self.lora_dropout = lora_dropout
+ self.lora_enabled = True
+ nn.init.kaiming_uniform_(self.lora_A, a=5**0.5)
+ nn.init.zeros_(self.lora_B)
+ self._packed_params.requires_grad_(False)
+
+ def forward(self, x: Tensor) -> Tensor:
+ result = dynamic.Linear.forward(self, x)
+ if self.lora_enabled and isinstance(self.lora_dropout, float):
+ result += (
+ nn.functional.dropout(x, self.lora_dropout, self.training)
+ @ self.lora_A.transpose(0, 1)
+ @ self.lora_B.transpose(0, 1)
+ ) * self.scaling
+ return result
+
+ @classmethod
+ def from_float(
+ cls, mod: Linear, use_precomputed_fake_quant: bool = False
+ ) -> Self:
+ assert hasattr(
+ mod, "qconfig"
+ ), "Input float module must have qconfig defined"
+ if mod.qconfig is not None and mod.qconfig.weight is not None:
+ weight_observer = mod.qconfig.weight()
+ else:
+ # We have the circular import issues if we import the qconfig in the beginning of this file:
+ # https://github.com/pytorch/pytorch/pull/24231. The current workaround is to postpone the
+ # import until we need it.
+ from torch.ao.quantization.qconfig import default_dynamic_qconfig
+
+ weight_observer = default_dynamic_qconfig.weight()
+ dtype = weight_observer.dtype
+ assert dtype in [torch.qint8, torch.float16], (
+ "The only supported dtypes for "
+ f"dynamic quantized linear are qint8 and float16 got: {dtype}"
+ )
+ weight_observer(mod.weight)
+ if dtype == torch.qint8:
+ qweight = _quantize_weight(mod.weight.float(), weight_observer)
+ elif dtype == torch.float16:
+ qweight = mod.weight.float()
  else:
- y = None
- t = torch.rand((x.shape[0], 1, 1))
- beta = model.calc_beta(t)
- e_x = torch.nn.functional.one_hot(x, model.K).float()
- mu = beta * (model.K * e_x - 1)
- sigma = (beta * model.K).sqrt()
- theta = softmax(mu + sigma * torch.randn_like(mu), -1)
- prepared_model(2 * theta - 1, t, None, y)
- # ---------------------------
- quantised_model = convert_pt2e(prepared_model)
- quantised_model = torch.export.export_for_training(
- quantised_model, example_input
- ).module() # remove the weights of original model
- quantised_model.sample = model.sample
- quantised_model.ode_sample = model.ode_sample
- quantised_model.inpaint = model.inpaint
- quantised_model.ode_inpaint = model.ode_inpaint
- if save_model:
- quantised_ep = torch.export.export(quantised_model, example_input)
- torch.export.save(quantised_ep, save_model_file_path)
+ raise RuntimeError(
+ "Unsupported dtype specified for dynamic quantized Linear!"
+ )
+ qlinear = cls(mod.in_features, mod.out_features, dtype=dtype)
+ qlinear.set_weight_bias(qweight, mod.bias)
+ if mod.lora_enabled:
+ qlinear.lora_enabled = True
+ qlinear.lora_A = nn.Parameter(mod.lora_A.clone().detach_())
+ qlinear.lora_B = nn.Parameter(mod.lora_B.clone().detach_())
+ qlinear.scaling = deepcopy(mod.scaling)
+ qlinear.lora_dropout = deepcopy(mod.lora_dropout)
+ return qlinear
+
+ @classmethod
+ def from_reference(cls, ref_qlinear: Self) -> Self:
+ qlinear = cls(
+ ref_qlinear.in_features,
+ ref_qlinear.out_features,
+ dtype=ref_qlinear.weight_dtype,
+ )
+ qweight = ref_qlinear.get_quantized_weight()
+ bias = ref_qlinear.bias
+ qlinear.set_weight_bias(qweight, bias)
+ if ref_qlinear.lora_enabled:
+ qlinear.lora_enabled = True
+ qlinear.lora_A = nn.Parameter(ref_qlinear.lora_A.clone().detach_())
+ qlinear.lora_B = nn.Parameter(ref_qlinear.lora_B.clone().detach_())
+ qlinear.scaling = deepcopy(ref_qlinear.scaling)
+ qlinear.lora_dropout = deepcopy(ref_qlinear.lora_dropout)
+ return qlinear
+
+ mapping = deepcopy(quantization.DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS)
+ mapping[Linear] = QuantisedLinear
+ quantised_model = quantization.quantize_dynamic(
+ model, {nn.Linear, Linear}, torch.qint8, mapping
+ )
  return quantised_model
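
The rewritten quantise_model replaces PT2E static quantisation with eager-mode dynamic quantisation, so it no longer needs a calibration DataLoader, an optional guidance MLP, or save flags, and it keeps LoRA parameters unquantised via the custom QuantisedLinear mapping. A hypothetical usage sketch; the checkpoint path is a placeholder and the name of the checkpoint-loading classmethod shown in the model.py hunks is assumed:

import torch
from bayesianflow_for_chem.model import ChemBFN
from bayesianflow_for_chem.tool import quantise_model

# Load a trained model (classmethod name and file path are assumptions for illustration).
model = ChemBFN.from_checkpoint("chembfn.pt")

# Dynamic quantisation requires no calibration data.
qmodel = quantise_model(model)

# Persisting the quantised weights is now left to the caller
# (the old save_model / save_model_file_path options were removed).
torch.save(qmodel.state_dict(), "qmodel.pt")
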

bayesianflow_for_chem-1.2.3.dist-info/METADATA → bayesianflow_for_chem-1.2.5.dist-info/METADATA
@@ -1,6 +1,6 @@
  Metadata-Version: 2.2
  Name: bayesianflow_for_chem
- Version: 1.2.3
+ Version: 1.2.5
  Summary: Bayesian flow network framework for Chemistry
  Home-page: https://augus1999.github.io/bayesian-flow-network-for-chemistry/
  Author: Nianze A. Tao

bayesianflow_for_chem-1.2.5.dist-info/RECORD
@@ -0,0 +1,12 @@
+ bayesianflow_for_chem/__init__.py,sha256=GMGe5nU963qFL6vJ9OZSfqfSyEImC_P2zyUS0cyP3Mg,293
+ bayesianflow_for_chem/data.py,sha256=9tpRba40lxwrB6aPSJMkxUglEVC3VEQC9wWxhDuz3Q8,7760
+ bayesianflow_for_chem/model.py,sha256=HvEvW_xRbkv4eSv5lhd72BJMZkg-ZACEi1DAW3p5Q1Y,35918
+ bayesianflow_for_chem/scorer.py,sha256=mV1vX8aBGFra2BE7N8WHihVIo3dXmUdPQIGfSaiuNdk,4084
+ bayesianflow_for_chem/tool.py,sha256=tJjb8q3_orNkj2BYJwz5VxqeaOv55dvqO93_uigLJIk,23221
+ bayesianflow_for_chem/train.py,sha256=kj6icGqymUUYopDtpre1oE_wpvpeNilbpzgffBsd1tk,9589
+ bayesianflow_for_chem/vocab.txt,sha256=HgtAZmpWYk4y8PqEVC4vqut1vE75DfRKE_10s2UW0rU,790
+ bayesianflow_for_chem-1.2.5.dist-info/LICENSE,sha256=hIahDEOTzuHCU5J2nd07LWwkLW7Hko4UFO__ffsvB-8,34523
+ bayesianflow_for_chem-1.2.5.dist-info/METADATA,sha256=hwEEDW6ipmHpjRjQDKxWk5zqI9jwjsl-yxBpvYn93HQ,5890
+ bayesianflow_for_chem-1.2.5.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
+ bayesianflow_for_chem-1.2.5.dist-info/top_level.txt,sha256=KHsanI3BMCt8D9Qpze2ycrF6nMa3PyojgO6eS1c8kco,22
+ bayesianflow_for_chem-1.2.5.dist-info/RECORD,,

bayesianflow_for_chem-1.2.3.dist-info/RECORD
@@ -1,12 +0,0 @@
- bayesianflow_for_chem/__init__.py,sha256=g03hNao3V1EGkajm0W0Ydrdnn7rpj7NRoz-JjdTBNUE,293
- bayesianflow_for_chem/data.py,sha256=9tpRba40lxwrB6aPSJMkxUglEVC3VEQC9wWxhDuz3Q8,7760
- bayesianflow_for_chem/model.py,sha256=CEwqUMahNEcVOZaFjv1JcBokktjW9LspFsYzKjzNmZk,35922
- bayesianflow_for_chem/scorer.py,sha256=mV1vX8aBGFra2BE7N8WHihVIo3dXmUdPQIGfSaiuNdk,4084
- bayesianflow_for_chem/tool.py,sha256=K6COLatSqvBXwcXV7QtVQX2sJOxfibwiIq0yIh96kfg,20818
- bayesianflow_for_chem/train.py,sha256=kj6icGqymUUYopDtpre1oE_wpvpeNilbpzgffBsd1tk,9589
- bayesianflow_for_chem/vocab.txt,sha256=HgtAZmpWYk4y8PqEVC4vqut1vE75DfRKE_10s2UW0rU,790
- bayesianflow_for_chem-1.2.3.dist-info/LICENSE,sha256=hIahDEOTzuHCU5J2nd07LWwkLW7Hko4UFO__ffsvB-8,34523
- bayesianflow_for_chem-1.2.3.dist-info/METADATA,sha256=QYbTrgY0QqfgpCFWYXbkeWjBVZ7wF9J7RraO6g7sbbI,5890
- bayesianflow_for_chem-1.2.3.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
- bayesianflow_for_chem-1.2.3.dist-info/top_level.txt,sha256=KHsanI3BMCt8D9Qpze2ycrF6nMa3PyojgO6eS1c8kco,22
- bayesianflow_for_chem-1.2.3.dist-info/RECORD,,