bayesianflow-for-chem 1.2.2__py3-none-any.whl → 1.2.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of bayesianflow-for-chem might be problematic.
- bayesianflow_for_chem/__init__.py +1 -1
- bayesianflow_for_chem/model.py +2 -2
- bayesianflow_for_chem/tool.py +120 -55
- {bayesianflow_for_chem-1.2.2.dist-info → bayesianflow_for_chem-1.2.4.dist-info}/METADATA +1 -1
- bayesianflow_for_chem-1.2.4.dist-info/RECORD +12 -0
- bayesianflow_for_chem-1.2.2.dist-info/RECORD +0 -12
- {bayesianflow_for_chem-1.2.2.dist-info → bayesianflow_for_chem-1.2.4.dist-info}/LICENSE +0 -0
- {bayesianflow_for_chem-1.2.2.dist-info → bayesianflow_for_chem-1.2.4.dist-info}/WHEEL +0 -0
- {bayesianflow_for_chem-1.2.2.dist-info → bayesianflow_for_chem-1.2.4.dist-info}/top_level.txt +0 -0
bayesianflow_for_chem/model.py
CHANGED
@@ -847,7 +847,7 @@ class ChemBFN(nn.Module):
         with open(ckpt, "rb") as f:
             state = torch.load(f, "cpu", weights_only=True)
         nn, hparam = state["nn"], state["hparam"]
-        model =
+        model = cls(
             hparam["num_vocab"],
             hparam["channel"],
             hparam["num_layer"],
@@ -926,7 +926,7 @@ class MLP(nn.Module):
         with open(ckpt, "rb") as f:
             state = torch.load(f, "cpu", weights_only=True)
         nn, hparam = state["nn"], state["hparam"]
-        model =
+        model = cls(hparam["size"], hparam["class_input"], hparam["dropout"])
         model.load_state_dict(nn, strict)
         return model
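The functional change in both hunks is that the checkpoint loaders now construct the network with cls(...) instead of naming the class directly, so the classmethod also works when called on a subclass. A minimal, self-contained sketch of that pattern under assumed names (Base, Sub and the "channel" hyperparameter are illustrative, not the package's actual classes):

import torch
import torch.nn as nn


class Base(nn.Module):
    def __init__(self, channel: int) -> None:
        super().__init__()
        self.linear = nn.Linear(channel, channel)

    @classmethod
    def from_checkpoint(cls, ckpt: str) -> "Base":
        # Mirrors the loader shown in the hunk above: state dict plus hyperparameters.
        with open(ckpt, "rb") as f:
            state = torch.load(f, "cpu", weights_only=True)
        nn_state, hparam = state["nn"], state["hparam"]
        # cls(...) rather than Base(...): Sub.from_checkpoint(...) returns a Sub
        # instance instead of an instance of the base class.
        model = cls(hparam["channel"])
        model.load_state_dict(nn_state)
        return model


class Sub(Base):
    pass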
bayesianflow_for_chem/tool.py
CHANGED
@@ -6,17 +6,16 @@ Tools.
 import re
 import csv
 import random
+from copy import deepcopy
 from pathlib import Path
 from typing import List, Dict, Tuple, Union, Optional
 import torch
 import numpy as np
+import torch.nn as nn
 from torch import cuda, Tensor, softmax
+from torch.ao import quantization
 from torch.utils.data import DataLoader
-from
-from torch.ao.quantization.quantizer.xnnpack_quantizer import (
-    XNNPACKQuantizer,
-    get_symmetric_quantization_config,
-)
+from typing_extensions import Self
 from rdkit.Chem import rdDetermineBonds, Bond, MolFromXYZBlock, CanonicalRankAtoms
 from rdkit.Chem.Scaffolds.MurckoScaffold import MurckoScaffoldSmiles  # type: ignore
 from sklearn.metrics import (
@@ -38,7 +37,7 @@ except ImportError:
     _use_pynauty = False

 from .data import VOCAB_KEYS
-from .model import ChemBFN, MLP
+from .model import ChemBFN, MLP, Linear


 _atom_regex_pattern = (
@@ -385,10 +384,7 @@ def sample(
     assert method.split(":")[0].lower() in ("ode", "bfn")
     if device is None:
         device = _find_device()
-    model.to(device)
-    if not isinstance(model, torch.fx.GraphModule):
-        model.eval()  # Calling eval() is not supported for GraphModule
-    # model.to(device).eval()
+    model.to(device).eval()
     if y is not None:
         y = y.to(device)
     if isinstance(allowed_tokens, list):
@@ -463,10 +459,7 @@ def inpaint(
     assert method.split(":")[0].lower() in ("ode", "bfn")
     if device is None:
         device = _find_device()
-    model.to(device)
-    if not isinstance(model, torch.fx.GraphModule):
-        model.eval()  # Calling eval() is not supported for GraphModule
-    # model.to(device).eval()
+    model.to(device).eval()
     x = x.to(device)
     if y is not None:
         y = y.to(device)
@@ -497,52 +490,124 @@
 ]


-def quantise_model(
-    model: ChemBFN, dataloader: DataLoader, mlp: Optional[MLP] = None
-) -> torch.fx.GraphModule:
+def quantise_model(model: ChemBFN) -> nn.Module:
     """
-    [1 removed line, content not shown in this diff view]
+    Dynamic quantisation of the trained model.

     :param model: trained ChemBFN model
-    :param dataloader: DataLoader instance containing example data for calibration
-    :param mlp: trained MLP model (guidance) if applied
     :type model: bayesianflow_for_chem.model.ChemBFN
-    :type dataloader: torch.utils.data.DataLoader
-    :type mlp: bayesianflow_for_chem.model.MLP | None
     :return: quantised model
-    :rtype: torch.
+    :rtype: torch.nn.Module
     """
-    [removed lines 515-530, content not shown in this diff view]
+    from torch.ao.nn.quantized.modules.utils import _quantize_weight
+    from torch.ao.nn.quantized import dynamic
+
+    class QuantisedLinear(dynamic.Linear):
+        # Modified from https://github.com/pytorch/pytorch/blob/main/torch/ao/nn/quantized/dynamic/modules/linear.py
+        # We made it compatible with our LoRA linear layer.
+        # LoRA parameters will not be quantised.
+        def __init__(
+            self,
+            in_features: int,
+            out_features: int,
+            bias_: bool = True,
+            dtype: torch.dtype = torch.qint8,
+        ) -> None:
+            super().__init__(in_features, out_features, bias_, dtype=dtype)
+            self.version = self._version
+            self.lora_enabled: bool = False
+            self.lora_A: Optional[nn.Parameter] = None
+            self.lora_B: Optional[nn.Parameter] = None
+            self.scaling: Optional[float] = None
+            self.lora_dropout: Optional[float] = None
+
+        def enable_lora(
+            self, r: int = 8, lora_alpha: int = 1, lora_dropout: float = 0.0
+        ) -> None:
+            assert r > 0, "Rank should be larger than 0."
+            device = self._weight_bias()[0].device
+            self.lora_A = nn.Parameter(
+                torch.zeros((r, self.in_features), device=device)
+            )
+            self.lora_B = nn.Parameter(
+                torch.zeros((self.out_features, r), device=device)
+            )
+            self.scaling = lora_alpha / r
+            self.lora_dropout = lora_dropout
+            self.lora_enabled = True
+            nn.init.kaiming_uniform_(self.lora_A, a=5**0.5)
+            nn.init.zeros_(self.lora_B)
+            self._packed_params.requires_grad_(False)
+
+        def forward(self, x: Tensor) -> Tensor:
+            # Note that we can handle self.bias == None case.
+            if self._packed_params.dtype == torch.qint8:
+                if self.version is None or self.version < 4:
+                    Y = torch.ops.quantized.linear_dynamic(
+                        x, self._packed_params._packed_params
+                    )
+                else:
+                    Y = torch.ops.quantized.linear_dynamic(
+                        x, self._packed_params._packed_params, reduce_range=True
+                    )
+            elif self._packed_params.dtype == torch.float16:
+                Y = torch.ops.quantized.linear_dynamic_fp16(
+                    x, self._packed_params._packed_params
+                )
+            else:
+                raise RuntimeError("Unsupported dtype on dynamic quantized linear!")
+            result = Y.to(x.dtype)
+            if self.lora_enabled and isinstance(self.lora_dropout, float):
+                result += (
+                    nn.functional.dropout(x, self.lora_dropout, self.training)
+                    @ self.lora_A.transpose(0, 1)
+                    @ self.lora_B.transpose(0, 1)
+                ) * self.scaling
+            return result
+
+        @classmethod
+        def from_float(
+            cls, mod: Linear, use_precomputed_fake_quant: bool = False
+        ) -> Self:
+            assert hasattr(
+                mod, "qconfig"
+            ), "Input float module must have qconfig defined"
+            if mod.qconfig is not None and mod.qconfig.weight is not None:
+                weight_observer = mod.qconfig.weight()
             else:
-    [removed lines 532-547, content not shown in this diff view]
+                # We have the circular import issues if we import the qconfig in the beginning of this file:
+                # https://github.com/pytorch/pytorch/pull/24231. The current workaround is to postpone the
+                # import until we need it.
+                from torch.ao.quantization.qconfig import default_dynamic_qconfig
+
+                weight_observer = default_dynamic_qconfig.weight()
+            dtype = weight_observer.dtype
+            assert dtype in [torch.qint8, torch.float16], (
+                "The only supported dtypes for "
+                f"dynamic quantized linear are qint8 and float16 got: {dtype}"
+            )
+            weight_observer(mod.weight)
+            if dtype == torch.qint8:
+                qweight = _quantize_weight(mod.weight.float(), weight_observer)
+            elif dtype == torch.float16:
+                qweight = mod.weight.float()
+            else:
+                raise RuntimeError(
+                    "Unsupported dtype specified for dynamic quantized Linear!"
+                )
+            qlinear = cls(mod.in_features, mod.out_features, dtype=dtype)
+            qlinear.set_weight_bias(qweight, mod.bias)
+            if mod.lora_enabled:
+                qlinear.lora_enabled = True
+                qlinear.lora_A = mod.lora_A
+                qlinear.lora_B = mod.lora_B
+                qlinear.scaling = mod.scaling
+                qlinear.lora_dropout = mod.lora_dropout
+            return qlinear
+
+    mapping = deepcopy(quantization.DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS)
+    mapping[Linear] = QuantisedLinear
+    quantised_model = quantization.quantize_dynamic(
+        model, {nn.Linear, Linear}, torch.qint8, mapping
+    )
     return quantised_model
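The rewritten quantise_model drops the earlier calibration-based flow (note the removed dataloader and mlp parameters and the removed XNNPACKQuantizer imports) in favour of eager-mode dynamic quantisation, with LoRA parameters deliberately left in floating point. For orientation, a self-contained sketch of the torch.ao.quantization.quantize_dynamic call pattern the new code relies on, applied to a toy network (the toy module and shapes are illustrative; only the call pattern matches the code above):

import torch
import torch.nn as nn
from torch.ao import quantization

# Toy stand-in for a trained model; not the package's ChemBFN.
toy = nn.Sequential(nn.Linear(16, 32), nn.GELU(), nn.Linear(32, 4))

# Dynamic quantisation: weights are stored as int8 and activations are
# quantised on the fly at inference time; the result is still a plain nn.Module.
quantised = quantization.quantize_dynamic(toy, {nn.Linear}, dtype=torch.qint8)

x = torch.randn(2, 16)
with torch.no_grad():
    out = quantised(x)  # shape (2, 4)

Because the quantised model remains an ordinary nn.Module rather than a torch.fx.GraphModule, the simplified model.to(device).eval() calls in sample() and inpaint() above cover it as well.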
bayesianflow_for_chem-1.2.4.dist-info/RECORD
ADDED
@@ -0,0 +1,12 @@
+bayesianflow_for_chem/__init__.py,sha256=-_0xD4lo_Vn2GrlXG-y13MCTwDfj391kzgTnyLplkNk,293
+bayesianflow_for_chem/data.py,sha256=9tpRba40lxwrB6aPSJMkxUglEVC3VEQC9wWxhDuz3Q8,7760
+bayesianflow_for_chem/model.py,sha256=HvEvW_xRbkv4eSv5lhd72BJMZkg-ZACEi1DAW3p5Q1Y,35918
+bayesianflow_for_chem/scorer.py,sha256=mV1vX8aBGFra2BE7N8WHihVIo3dXmUdPQIGfSaiuNdk,4084
+bayesianflow_for_chem/tool.py,sha256=d-g47Ctn6qb_j1bWCWV99ytUxJ23zJ32SJacQ_WXONk,23028
+bayesianflow_for_chem/train.py,sha256=kj6icGqymUUYopDtpre1oE_wpvpeNilbpzgffBsd1tk,9589
+bayesianflow_for_chem/vocab.txt,sha256=HgtAZmpWYk4y8PqEVC4vqut1vE75DfRKE_10s2UW0rU,790
+bayesianflow_for_chem-1.2.4.dist-info/LICENSE,sha256=hIahDEOTzuHCU5J2nd07LWwkLW7Hko4UFO__ffsvB-8,34523
+bayesianflow_for_chem-1.2.4.dist-info/METADATA,sha256=78FGoGjMsdwBavH4rSDtQ_psRYLSUdcg6cdR7KRmgVQ,5890
+bayesianflow_for_chem-1.2.4.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
+bayesianflow_for_chem-1.2.4.dist-info/top_level.txt,sha256=KHsanI3BMCt8D9Qpze2ycrF6nMa3PyojgO6eS1c8kco,22
+bayesianflow_for_chem-1.2.4.dist-info/RECORD,,
bayesianflow_for_chem-1.2.2.dist-info/RECORD
DELETED
@@ -1,12 +0,0 @@
-bayesianflow_for_chem/__init__.py,sha256=sPILW44_x_imRo2kKPMKWQ45C4aNfRQDo1it5Smqqmo,293
-bayesianflow_for_chem/data.py,sha256=9tpRba40lxwrB6aPSJMkxUglEVC3VEQC9wWxhDuz3Q8,7760
-bayesianflow_for_chem/model.py,sha256=CEwqUMahNEcVOZaFjv1JcBokktjW9LspFsYzKjzNmZk,35922
-bayesianflow_for_chem/scorer.py,sha256=mV1vX8aBGFra2BE7N8WHihVIo3dXmUdPQIGfSaiuNdk,4084
-bayesianflow_for_chem/tool.py,sha256=kjR-BUenSjqkwI-TB0QwYXEMy9qdPjL6y4BZVCVfzHA,20237
-bayesianflow_for_chem/train.py,sha256=kj6icGqymUUYopDtpre1oE_wpvpeNilbpzgffBsd1tk,9589
-bayesianflow_for_chem/vocab.txt,sha256=HgtAZmpWYk4y8PqEVC4vqut1vE75DfRKE_10s2UW0rU,790
-bayesianflow_for_chem-1.2.2.dist-info/LICENSE,sha256=hIahDEOTzuHCU5J2nd07LWwkLW7Hko4UFO__ffsvB-8,34523
-bayesianflow_for_chem-1.2.2.dist-info/METADATA,sha256=vRxX8mUrOJJwg_vkgXmla8s2vKhmZAVIYH_N3htAElQ,5890
-bayesianflow_for_chem-1.2.2.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
-bayesianflow_for_chem-1.2.2.dist-info/top_level.txt,sha256=KHsanI3BMCt8D9Qpze2ycrF6nMa3PyojgO6eS1c8kco,22
-bayesianflow_for_chem-1.2.2.dist-info/RECORD,,
{bayesianflow_for_chem-1.2.2.dist-info → bayesianflow_for_chem-1.2.4.dist-info}/LICENSE
RENAMED
File without changes
{bayesianflow_for_chem-1.2.2.dist-info → bayesianflow_for_chem-1.2.4.dist-info}/WHEEL
RENAMED
File without changes
{bayesianflow_for_chem-1.2.2.dist-info → bayesianflow_for_chem-1.2.4.dist-info}/top_level.txt
RENAMED
File without changes