bayesianflow-for-chem 1.2.3__py3-none-any.whl → 1.2.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of bayesianflow-for-chem might be problematic.
- bayesianflow_for_chem/__init__.py +1 -1
- bayesianflow_for_chem/model.py +2 -2
- bayesianflow_for_chem/tool.py +120 -68
- {bayesianflow_for_chem-1.2.3.dist-info → bayesianflow_for_chem-1.2.4.dist-info}/METADATA +1 -1
- bayesianflow_for_chem-1.2.4.dist-info/RECORD +12 -0
- bayesianflow_for_chem-1.2.3.dist-info/RECORD +0 -12
- {bayesianflow_for_chem-1.2.3.dist-info → bayesianflow_for_chem-1.2.4.dist-info}/LICENSE +0 -0
- {bayesianflow_for_chem-1.2.3.dist-info → bayesianflow_for_chem-1.2.4.dist-info}/WHEEL +0 -0
- {bayesianflow_for_chem-1.2.3.dist-info → bayesianflow_for_chem-1.2.4.dist-info}/top_level.txt +0 -0
bayesianflow_for_chem/model.py
CHANGED

@@ -847,7 +847,7 @@ class ChemBFN(nn.Module):
         with open(ckpt, "rb") as f:
             state = torch.load(f, "cpu", weights_only=True)
         nn, hparam = state["nn"], state["hparam"]
-        model =
+        model = cls(
             hparam["num_vocab"],
             hparam["channel"],
             hparam["num_layer"],
@@ -926,7 +926,7 @@ class MLP(nn.Module):
         with open(ckpt, "rb") as f:
             state = torch.load(f, "cpu", weights_only=True)
         nn, hparam = state["nn"], state["hparam"]
-        model =
+        model = cls(hparam["size"], hparam["class_input"], hparam["dropout"])
         model.load_state_dict(nn, strict)
         return model
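In both hunks the checkpoint loader now constructs the model through `cls(...)` (the removed line is truncated in this view), so the classmethod returns an instance of whichever class it is invoked on. Below is a minimal sketch of that pattern; the class `TinyModel`, the method name `from_checkpoint`, and the hyperparameter values are illustrative assumptions, not the package's actual API.

import torch
import torch.nn as nn


class TinyModel(nn.Module):
    # Illustrative stand-in for a ChemBFN-like model; not the package's real class.
    def __init__(self, num_vocab: int, channel: int, num_layer: int) -> None:
        super().__init__()
        self.embedding = nn.Embedding(num_vocab, channel)
        self.layers = nn.ModuleList(
            nn.Linear(channel, channel) for _ in range(num_layer)
        )

    @classmethod
    def from_checkpoint(cls, ckpt: str, strict: bool = True) -> "TinyModel":
        # Using cls(...) instead of a hard-coded class name means that a
        # subclass calling Subclass.from_checkpoint(...) gets the subclass back.
        with open(ckpt, "rb") as f:
            state = torch.load(f, "cpu", weights_only=True)
        weights, hparam = state["nn"], state["hparam"]
        model = cls(hparam["num_vocab"], hparam["channel"], hparam["num_layer"])
        model.load_state_dict(weights, strict)
        return model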
bayesianflow_for_chem/tool.py
CHANGED

@@ -6,18 +6,16 @@ Tools.
 import re
 import csv
 import random
+from copy import deepcopy
 from pathlib import Path
 from typing import List, Dict, Tuple, Union, Optional
 import torch
 import numpy as np
+import torch.nn as nn
 from torch import cuda, Tensor, softmax
-from torch.ao
+from torch.ao import quantization
 from torch.utils.data import DataLoader
-from
-from torch.ao.quantization.quantizer.xnnpack_quantizer import (
-    XNNPACKQuantizer,
-    get_symmetric_quantization_config,
-)
+from typing_extensions import Self
 from rdkit.Chem import rdDetermineBonds, Bond, MolFromXYZBlock, CanonicalRankAtoms
 from rdkit.Chem.Scaffolds.MurckoScaffold import MurckoScaffoldSmiles  # type: ignore
 from sklearn.metrics import (
@@ -39,7 +37,7 @@ except ImportError:
     _use_pynauty = False

 from .data import VOCAB_KEYS
-from .model import ChemBFN, MLP
+from .model import ChemBFN, MLP, Linear


 _atom_regex_pattern = (
@@ -386,9 +384,7 @@ def sample(
     assert method.split(":")[0].lower() in ("ode", "bfn")
     if device is None:
         device = _find_device()
-    model.to(device)
-    if not isinstance(model, torch.fx.GraphModule):
-        model.eval()  # Calling eval() is not supported for GraphModule
+    model.to(device).eval()
     if y is not None:
         y = y.to(device)
     if isinstance(allowed_tokens, list):
@@ -463,9 +459,7 @@ def inpaint(
     assert method.split(":")[0].lower() in ("ode", "bfn")
     if device is None:
         device = _find_device()
-    model.to(device)
-    if not isinstance(model, torch.fx.GraphModule):
-        model.eval()  # Calling eval() is not supported for GraphModule
+    model.to(device).eval()
     x = x.to(device)
     if y is not None:
         y = y.to(device)
@@ -496,66 +490,124 @@ def inpaint(
     ]


-def quantise_model(
-    model: ChemBFN,
-    dataloader: DataLoader,
-    mlp: Optional[MLP] = None,
-    save_model: bool = False,
-    save_model_file_path: Union[str, Path] = "qmodel.pt",
-) -> torch.fx.GraphModule:
+def quantise_model(model: ChemBFN) -> nn.Module:
     """
-
+    Dynamic quantisation of the trained model.

     :param model: trained ChemBFN model
-    :param dataloader: DataLoader instance containing example data for calibration
-    :param mlp: trained MLP model (guidance) if applied
-    :param save_model: whether to save the model
-    :param save_model_file_path: file name of the saved model; not used if `save_model=False`
     :type model: bayesianflow_for_chem.model.ChemBFN
-    :type dataloader: torch.utils.data.DataLoader
-    :type mlp: bayesianflow_for_chem.model.MLP | None
-    :type save_model: bool
-    :type save_model_file_path: str | pathlib.Path
     :return: quantised model
-    :rtype: torch.
+    :rtype: torch.nn.Module
     """
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    from torch.ao.nn.quantized.modules.utils import _quantize_weight
+    from torch.ao.nn.quantized import dynamic
+
+    class QuantisedLinear(dynamic.Linear):
+        # Modified from https://github.com/pytorch/pytorch/blob/main/torch/ao/nn/quantized/dynamic/modules/linear.py
+        # We made it compatible with our LoRA linear layer.
+        # LoRA parameters will not be quantised.
+        def __init__(
+            self,
+            in_features: int,
+            out_features: int,
+            bias_: bool = True,
+            dtype: torch.dtype = torch.qint8,
+        ) -> None:
+            super().__init__(in_features, out_features, bias_, dtype=dtype)
+            self.version = self._version
+            self.lora_enabled: bool = False
+            self.lora_A: Optional[nn.Parameter] = None
+            self.lora_B: Optional[nn.Parameter] = None
+            self.scaling: Optional[float] = None
+            self.lora_dropout: Optional[float] = None
+
+        def enable_lora(
+            self, r: int = 8, lora_alpha: int = 1, lora_dropout: float = 0.0
+        ) -> None:
+            assert r > 0, "Rank should be larger than 0."
+            device = self._weight_bias()[0].device
+            self.lora_A = nn.Parameter(
+                torch.zeros((r, self.in_features), device=device)
+            )
+            self.lora_B = nn.Parameter(
+                torch.zeros((self.out_features, r), device=device)
+            )
+            self.scaling = lora_alpha / r
+            self.lora_dropout = lora_dropout
+            self.lora_enabled = True
+            nn.init.kaiming_uniform_(self.lora_A, a=5**0.5)
+            nn.init.zeros_(self.lora_B)
+            self._packed_params.requires_grad_(False)
+
+        def forward(self, x: Tensor) -> Tensor:
+            # Note that we can handle self.bias == None case.
+            if self._packed_params.dtype == torch.qint8:
+                if self.version is None or self.version < 4:
+                    Y = torch.ops.quantized.linear_dynamic(
+                        x, self._packed_params._packed_params
+                    )
+                else:
+                    Y = torch.ops.quantized.linear_dynamic(
+                        x, self._packed_params._packed_params, reduce_range=True
+                    )
+            elif self._packed_params.dtype == torch.float16:
+                Y = torch.ops.quantized.linear_dynamic_fp16(
+                    x, self._packed_params._packed_params
+                )
+            else:
+                raise RuntimeError("Unsupported dtype on dynamic quantized linear!")
+            result = Y.to(x.dtype)
+            if self.lora_enabled and isinstance(self.lora_dropout, float):
+                result += (
+                    nn.functional.dropout(x, self.lora_dropout, self.training)
+                    @ self.lora_A.transpose(0, 1)
+                    @ self.lora_B.transpose(0, 1)
+                ) * self.scaling
+            return result
+
+        @classmethod
+        def from_float(
+            cls, mod: Linear, use_precomputed_fake_quant: bool = False
+        ) -> Self:
+            assert hasattr(
+                mod, "qconfig"
+            ), "Input float module must have qconfig defined"
+            if mod.qconfig is not None and mod.qconfig.weight is not None:
+                weight_observer = mod.qconfig.weight()
             else:
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+                # We have the circular import issues if we import the qconfig in the beginning of this file:
+                # https://github.com/pytorch/pytorch/pull/24231. The current workaround is to postpone the
+                # import until we need it.
+                from torch.ao.quantization.qconfig import default_dynamic_qconfig
+
+                weight_observer = default_dynamic_qconfig.weight()
+            dtype = weight_observer.dtype
+            assert dtype in [torch.qint8, torch.float16], (
+                "The only supported dtypes for "
+                f"dynamic quantized linear are qint8 and float16 got: {dtype}"
+            )
+            weight_observer(mod.weight)
+            if dtype == torch.qint8:
+                qweight = _quantize_weight(mod.weight.float(), weight_observer)
+            elif dtype == torch.float16:
+                qweight = mod.weight.float()
+            else:
+                raise RuntimeError(
+                    "Unsupported dtype specified for dynamic quantized Linear!"
+                )
+            qlinear = cls(mod.in_features, mod.out_features, dtype=dtype)
+            qlinear.set_weight_bias(qweight, mod.bias)
+            if mod.lora_enabled:
+                qlinear.lora_enabled = True
+                qlinear.lora_A = mod.lora_A
+                qlinear.lora_B = mod.lora_B
+                qlinear.scaling = mod.scaling
+                qlinear.lora_dropout = mod.lora_dropout
+            return qlinear
+
+    mapping = deepcopy(quantization.DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS)
+    mapping[Linear] = QuantisedLinear
+    quantised_model = quantization.quantize_dynamic(
+        model, {nn.Linear, Linear}, torch.qint8, mapping
+    )
     return quantised_model
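The rewritten `quantise_model` drops the calibration DataLoader, optional MLP, and save-to-file arguments and instead applies PyTorch dynamic quantisation through `torch.ao.quantization.quantize_dynamic`, routing the package's LoRA-aware `Linear` to the custom `QuantisedLinear` above so that LoRA parameters stay in float. A self-contained sketch of that mechanism on a toy model, using only stock PyTorch (the toy module and its shapes are placeholders, not taken from the package):

from copy import deepcopy

import torch
import torch.nn as nn
from torch.ao import quantization

# Toy float model standing in for ChemBFN; only its nn.Linear layers are quantised.
toy = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 4))

# Start from the default dynamic-quantisation mapping; the package extends it with
# mapping[Linear] = QuantisedLinear so its LoRA linear layers are converted as well.
mapping = deepcopy(quantization.DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS)

# Dynamic INT8 quantisation of the listed module types: weights are quantised
# ahead of time, activations are handled dynamically at run time.
qtoy = quantization.quantize_dynamic(toy, {nn.Linear}, torch.qint8, mapping)

# The result is an ordinary nn.Module rather than a torch.fx.GraphModule, so
# .to(device).eval() chains as usual, matching the simplification in
# sample() and inpaint() above.
qtoy.eval()
out = qtoy(torch.randn(2, 16))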
bayesianflow_for_chem-1.2.4.dist-info/RECORD
ADDED

@@ -0,0 +1,12 @@
+bayesianflow_for_chem/__init__.py,sha256=-_0xD4lo_Vn2GrlXG-y13MCTwDfj391kzgTnyLplkNk,293
+bayesianflow_for_chem/data.py,sha256=9tpRba40lxwrB6aPSJMkxUglEVC3VEQC9wWxhDuz3Q8,7760
+bayesianflow_for_chem/model.py,sha256=HvEvW_xRbkv4eSv5lhd72BJMZkg-ZACEi1DAW3p5Q1Y,35918
+bayesianflow_for_chem/scorer.py,sha256=mV1vX8aBGFra2BE7N8WHihVIo3dXmUdPQIGfSaiuNdk,4084
+bayesianflow_for_chem/tool.py,sha256=d-g47Ctn6qb_j1bWCWV99ytUxJ23zJ32SJacQ_WXONk,23028
+bayesianflow_for_chem/train.py,sha256=kj6icGqymUUYopDtpre1oE_wpvpeNilbpzgffBsd1tk,9589
+bayesianflow_for_chem/vocab.txt,sha256=HgtAZmpWYk4y8PqEVC4vqut1vE75DfRKE_10s2UW0rU,790
+bayesianflow_for_chem-1.2.4.dist-info/LICENSE,sha256=hIahDEOTzuHCU5J2nd07LWwkLW7Hko4UFO__ffsvB-8,34523
+bayesianflow_for_chem-1.2.4.dist-info/METADATA,sha256=78FGoGjMsdwBavH4rSDtQ_psRYLSUdcg6cdR7KRmgVQ,5890
+bayesianflow_for_chem-1.2.4.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
+bayesianflow_for_chem-1.2.4.dist-info/top_level.txt,sha256=KHsanI3BMCt8D9Qpze2ycrF6nMa3PyojgO6eS1c8kco,22
+bayesianflow_for_chem-1.2.4.dist-info/RECORD,,
bayesianflow_for_chem-1.2.3.dist-info/RECORD
REMOVED

@@ -1,12 +0,0 @@
-bayesianflow_for_chem/__init__.py,sha256=g03hNao3V1EGkajm0W0Ydrdnn7rpj7NRoz-JjdTBNUE,293
-bayesianflow_for_chem/data.py,sha256=9tpRba40lxwrB6aPSJMkxUglEVC3VEQC9wWxhDuz3Q8,7760
-bayesianflow_for_chem/model.py,sha256=CEwqUMahNEcVOZaFjv1JcBokktjW9LspFsYzKjzNmZk,35922
-bayesianflow_for_chem/scorer.py,sha256=mV1vX8aBGFra2BE7N8WHihVIo3dXmUdPQIGfSaiuNdk,4084
-bayesianflow_for_chem/tool.py,sha256=K6COLatSqvBXwcXV7QtVQX2sJOxfibwiIq0yIh96kfg,20818
-bayesianflow_for_chem/train.py,sha256=kj6icGqymUUYopDtpre1oE_wpvpeNilbpzgffBsd1tk,9589
-bayesianflow_for_chem/vocab.txt,sha256=HgtAZmpWYk4y8PqEVC4vqut1vE75DfRKE_10s2UW0rU,790
-bayesianflow_for_chem-1.2.3.dist-info/LICENSE,sha256=hIahDEOTzuHCU5J2nd07LWwkLW7Hko4UFO__ffsvB-8,34523
-bayesianflow_for_chem-1.2.3.dist-info/METADATA,sha256=QYbTrgY0QqfgpCFWYXbkeWjBVZ7wF9J7RraO6g7sbbI,5890
-bayesianflow_for_chem-1.2.3.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
-bayesianflow_for_chem-1.2.3.dist-info/top_level.txt,sha256=KHsanI3BMCt8D9Qpze2ycrF6nMa3PyojgO6eS1c8kco,22
-bayesianflow_for_chem-1.2.3.dist-info/RECORD,,
{bayesianflow_for_chem-1.2.3.dist-info → bayesianflow_for_chem-1.2.4.dist-info}/LICENSE
RENAMED
File without changes
{bayesianflow_for_chem-1.2.3.dist-info → bayesianflow_for_chem-1.2.4.dist-info}/WHEEL
RENAMED
File without changes
{bayesianflow_for_chem-1.2.3.dist-info → bayesianflow_for_chem-1.2.4.dist-info}/top_level.txt
RENAMED
File without changes