bayesianflow-for-chem 1.2.2__py3-none-any.whl → 1.2.4__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.

Potentially problematic release: this version of bayesianflow-for-chem might be problematic.

bayesianflow_for_chem/__init__.py

@@ -7,5 +7,5 @@ from . import data, tool, train, scorer
 from .model import ChemBFN, MLP

 __all__ = ["data", "tool", "train", "scorer", "ChemBFN", "MLP"]
-__version__ = "1.2.2"
+__version__ = "1.2.4"
 __author__ = "Nianze A. Tao (Omozawa Sueno)"
bayesianflow_for_chem/model.py

@@ -847,7 +847,7 @@ class ChemBFN(nn.Module):
         with open(ckpt, "rb") as f:
             state = torch.load(f, "cpu", weights_only=True)
         nn, hparam = state["nn"], state["hparam"]
-        model = ChemBFN(
+        model = cls(
             hparam["num_vocab"],
             hparam["channel"],
             hparam["num_layer"],
@@ -926,7 +926,7 @@ class MLP(nn.Module):
         with open(ckpt, "rb") as f:
             state = torch.load(f, "cpu", weights_only=True)
         nn, hparam = state["nn"], state["hparam"]
-        model = MLP(hparam["size"], hparam["class_input"], hparam["dropout"])
+        model = cls(hparam["size"], hparam["class_input"], hparam["dropout"])
         model.load_state_dict(nn, strict)
         return model
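Both `from_checkpoint` hunks make the same fix: the factory classmethod now instantiates `cls` instead of a hard-coded class name, so it also works for subclasses of `ChemBFN` and `MLP`. A minimal sketch of the pattern; the `Base`/`Sub` classes and the checkpoint layout here are illustrative, not the package's actual code:

import torch
import torch.nn as nn

class Base(nn.Module):
    def __init__(self, size: int) -> None:
        super().__init__()
        self.layer = nn.Linear(size, size)

    @classmethod
    def from_checkpoint(cls, ckpt: str) -> "Base":
        state = torch.load(ckpt, map_location="cpu", weights_only=True)
        # `cls` is whatever class the method was called on, so
        # Sub.from_checkpoint(...) builds a Sub, not a Base.
        model = cls(state["hparam"]["size"])
        model.load_state_dict(state["nn"])
        return model

class Sub(Base):
    pass  # e.g. a subclass adding custom behaviour on top of Base

# With the old hard-coded constructor, Sub.from_checkpoint("x.pt")
# would have silently returned a Base instance.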
bayesianflow_for_chem/tool.py

@@ -6,17 +6,16 @@ Tools.
 import re
 import csv
 import random
+from copy import deepcopy
 from pathlib import Path
 from typing import List, Dict, Tuple, Union, Optional
 import torch
 import numpy as np
+import torch.nn as nn
 from torch import cuda, Tensor, softmax
+from torch.ao import quantization
 from torch.utils.data import DataLoader
-from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e
-from torch.ao.quantization.quantizer.xnnpack_quantizer import (
-    XNNPACKQuantizer,
-    get_symmetric_quantization_config,
-)
+from typing_extensions import Self
 from rdkit.Chem import rdDetermineBonds, Bond, MolFromXYZBlock, CanonicalRankAtoms
 from rdkit.Chem.Scaffolds.MurckoScaffold import MurckoScaffoldSmiles  # type: ignore
 from sklearn.metrics import (
@@ -38,7 +37,7 @@ except ImportError:
     _use_pynauty = False

 from .data import VOCAB_KEYS
-from .model import ChemBFN, MLP
+from .model import ChemBFN, MLP, Linear


 _atom_regex_pattern = (
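The import changes track the quantisation rewrite below: the PT2E export-based stack (`prepare_pt2e`/`convert_pt2e` with an `XNNPACKQuantizer`) is dropped in favour of the eager-mode `torch.ao.quantization` API, plus `deepcopy` (to clone the default module mapping), `torch.nn` (for the custom quantised module), and `typing_extensions.Self` (the return annotation of `from_float`). For reference, a minimal sketch of the eager-mode dynamic API on a toy model, unrelated to ChemBFN:

import torch
import torch.nn as nn
from torch.ao import quantization

# Dynamic quantisation stores weights as int8 and quantises activations
# on the fly at inference; unlike the PT2E flow it needs no example
# inputs, graph export, or calibration pass.
float_model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 4))
quantised = quantization.quantize_dynamic(
    float_model, {nn.Linear}, dtype=torch.qint8
)
print(quantised)  # Linear layers are now DynamicQuantizedLinear modules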
@@ -385,10 +384,7 @@ def sample(
     assert method.split(":")[0].lower() in ("ode", "bfn")
     if device is None:
         device = _find_device()
-    model.to(device)
-    if not isinstance(model, torch.fx.GraphModule):
-        model.eval()  # Calling eval() is not supported for GraphModule
-    # model.to(device).eval()
+    model.to(device).eval()
     if y is not None:
         y = y.to(device)
     if isinstance(allowed_tokens, list):
@@ -463,10 +459,7 @@ def inpaint(
     assert method.split(":")[0].lower() in ("ode", "bfn")
     if device is None:
         device = _find_device()
-    model.to(device)
-    if not isinstance(model, torch.fx.GraphModule):
-        model.eval()  # Calling eval() is not supported for GraphModule
-    # model.to(device).eval()
+    model.to(device).eval()
     x = x.to(device)
     if y is not None:
         y = y.to(device)
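`sample` and `inpaint` drop the same workaround: the PT2E flow returned a `torch.fx.GraphModule`, on which `eval()` raises, so the old code had to branch on the model type. `quantize_dynamic` returns an ordinary `nn.Module`, making the chained call safe again. A small self-contained check of that property (toy model, not ChemBFN):

import torch
import torch.nn as nn
from torch.ao import quantization

model = nn.Sequential(nn.Linear(8, 8))
qmodel = quantization.quantize_dynamic(model, {nn.Linear}, dtype=torch.qint8)

# A dynamically quantised model is a plain nn.Module, so the one-liner
# now used by sample()/inpaint() works; note the dynamic quantised
# kernels are CPU-only, so "cpu" is the sensible device here.
qmodel.to("cpu").eval()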
@@ -497,52 +490,124 @@ def inpaint(
     ]


-def quantise_model(
-    model: ChemBFN, dataloader: DataLoader, mlp: Optional[MLP] = None
-) -> torch.fx.GraphModule:
+def quantise_model(model: ChemBFN) -> nn.Module:
     """
-    Static quantisation of the input model.
+    Dynamic quantisation of the trained model.

     :param model: trained ChemBFN model
-    :param dataloader: DataLoader instance containing example data for calibration
-    :param mlp: trained MLP model (guidance) if applied
     :type model: bayesianflow_for_chem.model.ChemBFN
-    :type dataloader: torch.utils.data.DataLoader
-    :type mlp: bayesianflow_for_chem.model.MLP | None
     :return: quantised model
-    :rtype: torch.fx.GraphModule
+    :rtype: torch.nn.Module
     """
-    nb, nt = dataloader._get_iterator()._next_data()["token"].shape
-    x = 2 * softmax(torch.rand((nb, nt, model.K)), -1) - 1
-    t = torch.rand((nb, 1, 1))
-    y = torch.randn(nb, 1, model.embedding.weight.shape[0]) if mlp is not None else None
-    example_input = (2 * x - 1, t, None, y)
-    graph_model = torch.export.export_for_training(model, example_input).module()
-    quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config())
-    prepared_model = prepare_pt2e(graph_model, quantizer)
-    # ------- calibration -------
-    with torch.inference_mode():
-        for data in dataloader:
-            x = data["token"]
-            if x.shape[0] != nb:
-                break
-            if mlp is not None:
-                y = mlp(data["value"])[:, None, :]
+    from torch.ao.nn.quantized.modules.utils import _quantize_weight
+    from torch.ao.nn.quantized import dynamic
+
+    class QuantisedLinear(dynamic.Linear):
+        # Modified from https://github.com/pytorch/pytorch/blob/main/torch/ao/nn/quantized/dynamic/modules/linear.py
+        # We made it compatible with our LoRA linear layer.
+        # LoRA parameters will not be quantised.
+        def __init__(
+            self,
+            in_features: int,
+            out_features: int,
+            bias_: bool = True,
+            dtype: torch.dtype = torch.qint8,
+        ) -> None:
+            super().__init__(in_features, out_features, bias_, dtype=dtype)
+            self.version = self._version
+            self.lora_enabled: bool = False
+            self.lora_A: Optional[nn.Parameter] = None
+            self.lora_B: Optional[nn.Parameter] = None
+            self.scaling: Optional[float] = None
+            self.lora_dropout: Optional[float] = None
+
+        def enable_lora(
+            self, r: int = 8, lora_alpha: int = 1, lora_dropout: float = 0.0
+        ) -> None:
+            assert r > 0, "Rank should be larger than 0."
+            device = self._weight_bias()[0].device
+            self.lora_A = nn.Parameter(
+                torch.zeros((r, self.in_features), device=device)
+            )
+            self.lora_B = nn.Parameter(
+                torch.zeros((self.out_features, r), device=device)
+            )
+            self.scaling = lora_alpha / r
+            self.lora_dropout = lora_dropout
+            self.lora_enabled = True
+            nn.init.kaiming_uniform_(self.lora_A, a=5**0.5)
+            nn.init.zeros_(self.lora_B)
+            self._packed_params.requires_grad_(False)
+
+        def forward(self, x: Tensor) -> Tensor:
+            # Note that we can handle self.bias == None case.
+            if self._packed_params.dtype == torch.qint8:
+                if self.version is None or self.version < 4:
+                    Y = torch.ops.quantized.linear_dynamic(
+                        x, self._packed_params._packed_params
+                    )
+                else:
+                    Y = torch.ops.quantized.linear_dynamic(
+                        x, self._packed_params._packed_params, reduce_range=True
+                    )
+            elif self._packed_params.dtype == torch.float16:
+                Y = torch.ops.quantized.linear_dynamic_fp16(
+                    x, self._packed_params._packed_params
+                )
+            else:
+                raise RuntimeError("Unsupported dtype on dynamic quantized linear!")
+            result = Y.to(x.dtype)
+            if self.lora_enabled and isinstance(self.lora_dropout, float):
+                result += (
+                    nn.functional.dropout(x, self.lora_dropout, self.training)
+                    @ self.lora_A.transpose(0, 1)
+                    @ self.lora_B.transpose(0, 1)
+                ) * self.scaling
+            return result
+
+        @classmethod
+        def from_float(
+            cls, mod: Linear, use_precomputed_fake_quant: bool = False
+        ) -> Self:
+            assert hasattr(
+                mod, "qconfig"
+            ), "Input float module must have qconfig defined"
+            if mod.qconfig is not None and mod.qconfig.weight is not None:
+                weight_observer = mod.qconfig.weight()
             else:
-                y = None
-            t = torch.rand((x.shape[0], 1, 1))
-            beta = model.calc_beta(t)
-            e_x = torch.nn.functional.one_hot(x, model.K).float()
-            mu = beta * (model.K * e_x - 1)
-            sigma = (beta * model.K).sqrt()
-            theta = softmax(mu + sigma * torch.randn_like(mu), -1)
-            prepared_model(2 * theta - 1, t, None, y)
-    quantised_model = convert_pt2e(prepared_model)
-    quantised_model = torch.export.export_for_training(
-        quantised_model, example_input
-    ).module()  # remove the weights of original model
-    quantised_model.sample = model.sample
-    quantised_model.ode_sample = model.ode_sample
-    quantised_model.inpaint = model.inpaint
-    quantised_model.ode_inpaint = model.ode_inpaint
+                # We have the circular import issues if we import the qconfig in the beginning of this file:
+                # https://github.com/pytorch/pytorch/pull/24231. The current workaround is to postpone the
+                # import until we need it.
+                from torch.ao.quantization.qconfig import default_dynamic_qconfig
+
+                weight_observer = default_dynamic_qconfig.weight()
+            dtype = weight_observer.dtype
+            assert dtype in [torch.qint8, torch.float16], (
+                "The only supported dtypes for "
+                f"dynamic quantized linear are qint8 and float16 got: {dtype}"
+            )
+            weight_observer(mod.weight)
+            if dtype == torch.qint8:
+                qweight = _quantize_weight(mod.weight.float(), weight_observer)
+            elif dtype == torch.float16:
+                qweight = mod.weight.float()
+            else:
+                raise RuntimeError(
+                    "Unsupported dtype specified for dynamic quantized Linear!"
+                )
+            qlinear = cls(mod.in_features, mod.out_features, dtype=dtype)
+            qlinear.set_weight_bias(qweight, mod.bias)
+            if mod.lora_enabled:
+                qlinear.lora_enabled = True
+                qlinear.lora_A = mod.lora_A
+                qlinear.lora_B = mod.lora_B
+                qlinear.scaling = mod.scaling
+                qlinear.lora_dropout = mod.lora_dropout
+            return qlinear
+
+    mapping = deepcopy(quantization.DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS)
+    mapping[Linear] = QuantisedLinear
+    quantised_model = quantization.quantize_dynamic(
+        model, {nn.Linear, Linear}, torch.qint8, mapping
+    )
     return quantised_model
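The new `quantise_model` swaps each float `Linear` (including the package's LoRA-aware `Linear`) for the `QuantisedLinear` defined above via a copy of `DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS`; `from_float` packs the int8 weight but carries the LoRA tensors over in float, and `forward` adds the low-rank update `(dropout(x) @ A.T @ B.T) * scaling` on top of the quantised matmul, so fine-tuned adapters survive quantisation. A sketch of the intended call pattern; the checkpoint path is a placeholder:

from bayesianflow_for_chem.model import ChemBFN
from bayesianflow_for_chem.tool import quantise_model

# "chembfn.pt" is a hypothetical checkpoint path.
model = ChemBFN.from_checkpoint("chembfn.pt")
qmodel = quantise_model(model)  # Linear layers -> int8 QuantisedLinear
# LoRA A/B matrices, if enabled on a layer, stay in float and are applied
# after the quantised matmul inside QuantisedLinear.forward().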
bayesianflow_for_chem-1.2.2.dist-info/METADATA → bayesianflow_for_chem-1.2.4.dist-info/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.2
 Name: bayesianflow_for_chem
-Version: 1.2.2
+Version: 1.2.4
 Summary: Bayesian flow network framework for Chemistry
 Home-page: https://augus1999.github.io/bayesian-flow-network-for-chemistry/
 Author: Nianze A. Tao
bayesianflow_for_chem-1.2.4.dist-info/RECORD (added)

@@ -0,0 +1,12 @@
+bayesianflow_for_chem/__init__.py,sha256=-_0xD4lo_Vn2GrlXG-y13MCTwDfj391kzgTnyLplkNk,293
+bayesianflow_for_chem/data.py,sha256=9tpRba40lxwrB6aPSJMkxUglEVC3VEQC9wWxhDuz3Q8,7760
+bayesianflow_for_chem/model.py,sha256=HvEvW_xRbkv4eSv5lhd72BJMZkg-ZACEi1DAW3p5Q1Y,35918
+bayesianflow_for_chem/scorer.py,sha256=mV1vX8aBGFra2BE7N8WHihVIo3dXmUdPQIGfSaiuNdk,4084
+bayesianflow_for_chem/tool.py,sha256=d-g47Ctn6qb_j1bWCWV99ytUxJ23zJ32SJacQ_WXONk,23028
+bayesianflow_for_chem/train.py,sha256=kj6icGqymUUYopDtpre1oE_wpvpeNilbpzgffBsd1tk,9589
+bayesianflow_for_chem/vocab.txt,sha256=HgtAZmpWYk4y8PqEVC4vqut1vE75DfRKE_10s2UW0rU,790
+bayesianflow_for_chem-1.2.4.dist-info/LICENSE,sha256=hIahDEOTzuHCU5J2nd07LWwkLW7Hko4UFO__ffsvB-8,34523
+bayesianflow_for_chem-1.2.4.dist-info/METADATA,sha256=78FGoGjMsdwBavH4rSDtQ_psRYLSUdcg6cdR7KRmgVQ,5890
+bayesianflow_for_chem-1.2.4.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
+bayesianflow_for_chem-1.2.4.dist-info/top_level.txt,sha256=KHsanI3BMCt8D9Qpze2ycrF6nMa3PyojgO6eS1c8kco,22
+bayesianflow_for_chem-1.2.4.dist-info/RECORD,,
@@ -1,12 +0,0 @@
1
- bayesianflow_for_chem/__init__.py,sha256=sPILW44_x_imRo2kKPMKWQ45C4aNfRQDo1it5Smqqmo,293
2
- bayesianflow_for_chem/data.py,sha256=9tpRba40lxwrB6aPSJMkxUglEVC3VEQC9wWxhDuz3Q8,7760
3
- bayesianflow_for_chem/model.py,sha256=CEwqUMahNEcVOZaFjv1JcBokktjW9LspFsYzKjzNmZk,35922
4
- bayesianflow_for_chem/scorer.py,sha256=mV1vX8aBGFra2BE7N8WHihVIo3dXmUdPQIGfSaiuNdk,4084
5
- bayesianflow_for_chem/tool.py,sha256=kjR-BUenSjqkwI-TB0QwYXEMy9qdPjL6y4BZVCVfzHA,20237
6
- bayesianflow_for_chem/train.py,sha256=kj6icGqymUUYopDtpre1oE_wpvpeNilbpzgffBsd1tk,9589
7
- bayesianflow_for_chem/vocab.txt,sha256=HgtAZmpWYk4y8PqEVC4vqut1vE75DfRKE_10s2UW0rU,790
8
- bayesianflow_for_chem-1.2.2.dist-info/LICENSE,sha256=hIahDEOTzuHCU5J2nd07LWwkLW7Hko4UFO__ffsvB-8,34523
9
- bayesianflow_for_chem-1.2.2.dist-info/METADATA,sha256=vRxX8mUrOJJwg_vkgXmla8s2vKhmZAVIYH_N3htAElQ,5890
10
- bayesianflow_for_chem-1.2.2.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
11
- bayesianflow_for_chem-1.2.2.dist-info/top_level.txt,sha256=KHsanI3BMCt8D9Qpze2ycrF6nMa3PyojgO6eS1c8kco,22
12
- bayesianflow_for_chem-1.2.2.dist-info/RECORD,,