bayesianflow-for-chem 1.4.0__py3-none-any.whl → 1.4.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of bayesianflow-for-chem might be problematic. Click here for more details.
- bayesianflow_for_chem/__init__.py +1 -1
- bayesianflow_for_chem/model.py +21 -26
- bayesianflow_for_chem/tool.py +86 -14
- {bayesianflow_for_chem-1.4.0.dist-info → bayesianflow_for_chem-1.4.2.dist-info}/METADATA +2 -1
- bayesianflow_for_chem-1.4.2.dist-info/RECORD +12 -0
- bayesianflow_for_chem-1.4.0.dist-info/RECORD +0 -12
- {bayesianflow_for_chem-1.4.0.dist-info → bayesianflow_for_chem-1.4.2.dist-info}/WHEEL +0 -0
- {bayesianflow_for_chem-1.4.0.dist-info → bayesianflow_for_chem-1.4.2.dist-info}/licenses/LICENSE +0 -0
- {bayesianflow_for_chem-1.4.0.dist-info → bayesianflow_for_chem-1.4.2.dist-info}/top_level.txt +0 -0
|
@@ -7,5 +7,5 @@ from . import data, tool, train, scorer
|
|
|
7
7
|
from .model import ChemBFN, MLP, EnsembleChemBFN
|
|
8
8
|
|
|
9
9
|
__all__ = ["data", "tool", "train", "scorer", "ChemBFN", "MLP", "EnsembleChemBFN"]
|
|
10
|
-
__version__ = "1.4.0"
|
|
10
|
+
__version__ = "1.4.2"
|
|
11
11
|
__author__ = "Nianze A. Tao (Omozawa Sueno)"
|
bayesianflow_for_chem/model.py
CHANGED
|
@@ -54,9 +54,19 @@ class Linear(nn.Linear):
|
|
|
54
54
|
:return:
|
|
55
55
|
:rtype: None
|
|
56
56
|
"""
|
|
57
|
+
from torchao.dtypes.affine_quantized_tensor import AffineQuantizedTensor
|
|
58
|
+
|
|
57
59
|
assert r > 0, "Rank should be larger than 0."
|
|
58
|
-
|
|
59
|
-
|
|
60
|
+
if isinstance(self.weight, AffineQuantizedTensor):
|
|
61
|
+
self.lora_A = nn.Parameter(
|
|
62
|
+
torch.zeros((r, self.in_features), device=self.weight.device)
|
|
63
|
+
)
|
|
64
|
+
self.lora_B = nn.Parameter(
|
|
65
|
+
torch.zeros((self.out_features, r), device=self.weight.device)
|
|
66
|
+
)
|
|
67
|
+
else:
|
|
68
|
+
self.lora_A = nn.Parameter(self.weight.new_zeros((r, self.in_features)))
|
|
69
|
+
self.lora_B = nn.Parameter(self.weight.new_zeros((self.out_features, r)))
|
|
60
70
|
self.scaling = lora_alpha / r
|
|
61
71
|
self.lora_dropout = lora_dropout
|
|
62
72
|
self.lora_enabled = True
|
|
@@ -169,16 +179,6 @@ class Attention(nn.Module):
|
|
|
169
179
|
k = k.view(split).permute(2, 0, 1, 3).contiguous()
|
|
170
180
|
v = v.view(split).permute(2, 0, 1, 3).contiguous()
|
|
171
181
|
q, k = self._rotate(q, k, pe) # position embedding
|
|
172
|
-
"""
|
|
173
|
-
# Original code. Maybe using `nn.functional.scaled_dot_product_attention(...)` is better.
|
|
174
|
-
|
|
175
|
-
k_t = k.transpose(-2, -1)
|
|
176
|
-
if mask is not None:
|
|
177
|
-
alpha = softmax((q @ k_t / self.tp).masked_fill_(mask, -torch.inf), -1)
|
|
178
|
-
else:
|
|
179
|
-
alpha = softmax(q @ k_t / self.tp, -1)
|
|
180
|
-
atten_out = (alpha @ v).permute(1, 2, 0, 3).contiguous().view(shape)
|
|
181
|
-
"""
|
|
182
182
|
atten_out = nn.functional.scaled_dot_product_attention(
|
|
183
183
|
q, k, v, mask, 0.0, False, scale=1 / self.tp
|
|
184
184
|
)
|
|
@@ -430,19 +430,14 @@ class ChemBFN(nn.Module):
|
|
|
430
430
|
c += y
|
|
431
431
|
pe = self.position(n_t)
|
|
432
432
|
x = self.embedding(x)
|
|
433
|
-
attn_mask: Optional[Tensor] = None
|
|
434
433
|
if self.semi_autoregressive:
|
|
435
434
|
attn_mask = torch.tril(
|
|
436
435
|
torch.ones((1, n_b, n_t, n_t), device=x.device), diagonal=0
|
|
437
436
|
)
|
|
437
|
+
elif mask is not None:
|
|
438
|
+
attn_mask = mask.transpose(-2, -1).repeat(1, n_t, 1)[None, ...] != 0
|
|
438
439
|
else:
|
|
439
|
-
|
|
440
|
-
"""
|
|
441
|
-
# Original Code.
|
|
442
|
-
|
|
443
|
-
attn_mask = mask.transpose(-2, -1).repeat(1, x.shape[1], 1)[None, ...] == 0
|
|
444
|
-
"""
|
|
445
|
-
attn_mask = mask.transpose(-2, -1).repeat(1, n_t, 1)[None, ...] != 0
|
|
440
|
+
attn_mask = None
|
|
446
441
|
for layer in self.encoder_layers:
|
|
447
442
|
x = layer(x, pe, c, attn_mask)
|
|
448
443
|
return self.final_layer(x, c, mask is None)
|
|
@@ -1222,23 +1217,23 @@ class EnsembleChemBFN(ChemBFN):
|
|
|
1222
1217
|
)
|
|
1223
1218
|
|
|
1224
1219
|
def quantise(
|
|
1225
|
-
self, quantise_method: Optional[Callable[[ChemBFN],
|
|
1220
|
+
self, quantise_method: Optional[Callable[[ChemBFN], None]] = None
|
|
1226
1221
|
) -> None:
|
|
1227
1222
|
"""
|
|
1228
1223
|
Quantise the submodels. \n
|
|
1229
1224
|
This method should be called, if necessary, before `torch.compile()`.
|
|
1230
1225
|
|
|
1231
|
-
:param quantise_method: quantisation method; default is `bayesianflow_for_chem.tool.quantise_model`
|
|
1226
|
+
:param quantise_method: quantisation method; default is `bayesianflow_for_chem.tool.quantise_model_`
|
|
1232
1227
|
:type quantise_method: callable | None
|
|
1233
1228
|
:return:
|
|
1234
1229
|
:rtype: None
|
|
1235
1230
|
"""
|
|
1236
1231
|
if quantise_method is None:
|
|
1237
|
-
from bayesianflow_for_chem.tool import quantise_model
|
|
1232
|
+
from bayesianflow_for_chem.tool import quantise_model_
|
|
1238
1233
|
|
|
1239
|
-
quantise_method =
|
|
1240
|
-
for
|
|
1241
|
-
|
|
1234
|
+
quantise_method = quantise_model_
|
|
1235
|
+
for _, v in self.models.items():
|
|
1236
|
+
quantise_method(v)
|
|
1242
1237
|
|
|
1243
1238
|
def jit(self, freeze: bool = False) -> None:
|
|
1244
1239
|
"""
|
bayesianflow_for_chem/tool.py
CHANGED
|
@@ -13,18 +13,18 @@ import torch
|
|
|
13
13
|
import numpy as np
|
|
14
14
|
import torch.nn as nn
|
|
15
15
|
from torch import cuda, Tensor, softmax
|
|
16
|
-
from torch.ao import quantization
|
|
17
16
|
from torch.utils.data import DataLoader
|
|
18
|
-
from typing_extensions import Self
|
|
19
|
-
from rdkit.Chem.rdchem import Mol, Bond
|
|
17
|
+
from typing_extensions import Self, deprecated
|
|
20
18
|
from rdkit.Chem import (
|
|
21
19
|
rdDetermineBonds,
|
|
20
|
+
GetFormalCharge,
|
|
22
21
|
MolFromXYZBlock,
|
|
23
22
|
MolFromSmiles,
|
|
24
23
|
MolToSmiles,
|
|
25
24
|
CanonSmiles,
|
|
26
25
|
AllChem,
|
|
27
26
|
AddHs,
|
|
27
|
+
Mol,
|
|
28
28
|
)
|
|
29
29
|
from rdkit.Chem.Scaffolds.MurckoScaffold import MurckoScaffoldSmiles # type: ignore
|
|
30
30
|
from sklearn.metrics import (
|
|
@@ -385,6 +385,11 @@ def inpaint(
|
|
|
385
385
|
]
|
|
386
386
|
|
|
387
387
|
|
|
388
|
+
@deprecated(
|
|
389
|
+
"Eager mode quantization from `torch.ao` is deprecated and will be remove in version 2.10, "
|
|
390
|
+
"so this fuction will stop working since that time. "
|
|
391
|
+
"Please use `quantise_model_` instead."
|
|
392
|
+
)
|
|
388
393
|
def quantise_model(model: ChemBFN) -> nn.Module:
|
|
389
394
|
"""
|
|
390
395
|
Dynamic quantisation of the trained model to `torch.qint8` data type.
|
|
@@ -394,6 +399,7 @@ def quantise_model(model: ChemBFN) -> nn.Module:
|
|
|
394
399
|
:return: quantised model
|
|
395
400
|
:rtype: torch.nn.Module
|
|
396
401
|
"""
|
|
402
|
+
from torch.ao import quantization
|
|
397
403
|
from torch.ao.nn.quantized import dynamic
|
|
398
404
|
from torch.ao.nn.quantized.modules.utils import _quantize_weight
|
|
399
405
|
from torch.ao.quantization.qconfig import default_dynamic_qconfig
|
|
@@ -526,6 +532,24 @@ def quantise_model(model: ChemBFN) -> nn.Module:
|
|
|
526
532
|
return quantised_model
|
|
527
533
|
|
|
528
534
|
|
|
535
|
+
def quantise_model_(model: ChemBFN) -> None:
|
|
536
|
+
"""
|
|
537
|
+
In-place dynamic quantisation of the trained model to `int8` data type. \n
|
|
538
|
+
Due to some limitations of `torchao` module, it is slower than method provided by `torch.ao`.
|
|
539
|
+
|
|
540
|
+
:param model: trained ChemBFN model
|
|
541
|
+
:type model: bayesianflow_for_chem.model.ChemBFN
|
|
542
|
+
:return:
|
|
543
|
+
:rtype: None
|
|
544
|
+
"""
|
|
545
|
+
from torchao.quantization.quant_api import (
|
|
546
|
+
quantize_,
|
|
547
|
+
Int8DynamicActivationInt8WeightConfig,
|
|
548
|
+
)
|
|
549
|
+
|
|
550
|
+
quantize_(model, Int8DynamicActivationInt8WeightConfig())
|
|
551
|
+
|
|
552
|
+
|
|
529
553
|
class GeometryConverter:
|
|
530
554
|
"""
|
|
531
555
|
Converting between different 2D/3D molecular representations.
|
|
@@ -539,40 +563,88 @@ class GeometryConverter:
|
|
|
539
563
|
xyz_block.append(f"{atom} {r[i][0]:.10f} {r[i][1]:.10f} {r[i][2]:.10f}")
|
|
540
564
|
return MolFromXYZBlock("\n".join(xyz_block))
|
|
541
565
|
|
|
542
|
-
@staticmethod
|
|
543
|
-
def _bond_pair_idx(bonds: Bond) -> List[List[int]]:
|
|
544
|
-
return [[i.GetBeginAtomIdx(), i.GetEndAtomIdx()] for i in bonds]
|
|
545
|
-
|
|
546
566
|
@staticmethod
|
|
547
567
|
def smiles2cartesian(
|
|
548
|
-
smiles: str,
|
|
568
|
+
smiles: str,
|
|
569
|
+
num_conformers: int = 50,
|
|
570
|
+
rdkit_ff_type: str = "MMFF",
|
|
571
|
+
refine_with_crest: bool = False,
|
|
572
|
+
spin: float = 0.0,
|
|
549
573
|
) -> Tuple[List[str], np.ndarray]:
|
|
550
574
|
"""
|
|
551
575
|
Guess the 3D geometry from SMILES string via MMFF conformer search.
|
|
552
576
|
|
|
553
577
|
:param smiles: a valid SMILES string
|
|
554
578
|
:param num_conformers: number of initial conformers
|
|
555
|
-
:param
|
|
579
|
+
:param rdkit_ff_type: force field type chosen in `'MMFF'` and `'UFF'`
|
|
580
|
+
:param refine_with_crest: find the best conformer via CREST
|
|
581
|
+
:param spin: total spin; only required when `refine_with_crest=True`
|
|
556
582
|
:type smiles: str
|
|
557
583
|
:type num_conformers: int
|
|
558
|
-
:type
|
|
584
|
+
:type rdkit_ff_type: str
|
|
585
|
+
:type refine_with_crest: bool
|
|
586
|
+
:type spin: float
|
|
559
587
|
:return: atomic symbols \n
|
|
560
588
|
cartesian coordinates; shape: (n_a, 3)
|
|
561
589
|
:rtype: tuple
|
|
562
590
|
"""
|
|
591
|
+
assert rdkit_ff_type.lower() in ("mmff", "uff")
|
|
592
|
+
if refine_with_crest:
|
|
593
|
+
from tempfile import TemporaryDirectory
|
|
594
|
+
from subprocess import run
|
|
595
|
+
|
|
596
|
+
# We need both CREST and xTB installed.
|
|
597
|
+
if run("crest --version", shell=True).returncode != 0:
|
|
598
|
+
raise RuntimeError(
|
|
599
|
+
"`CREST` is not found! Make sure it is installed and added into the PATH."
|
|
600
|
+
)
|
|
601
|
+
if run("xtb --version", shell=True).returncode != 0:
|
|
602
|
+
raise RuntimeError(
|
|
603
|
+
"`xTB` is not found! Make sure it is installed and added into the PATH."
|
|
604
|
+
)
|
|
563
605
|
mol = MolFromSmiles(smiles)
|
|
564
606
|
mol = AddHs(mol)
|
|
565
|
-
AllChem.EmbedMultipleConfs(mol, numConfs=num_conformers,
|
|
607
|
+
AllChem.EmbedMultipleConfs(mol, numConfs=num_conformers, params=AllChem.ETKDG())
|
|
566
608
|
symbols = [atom.GetSymbol() for atom in mol.GetAtoms()]
|
|
567
609
|
energies = []
|
|
568
610
|
for conf_id in range(num_conformers):
|
|
569
|
-
|
|
570
|
-
|
|
571
|
-
|
|
611
|
+
if rdkit_ff_type.lower() == "mmff":
|
|
612
|
+
ff = AllChem.MMFFGetMoleculeForceField(
|
|
613
|
+
mol, AllChem.MMFFGetMoleculeProperties(mol), confId=conf_id
|
|
614
|
+
)
|
|
615
|
+
else: # UFF
|
|
616
|
+
ff = AllChem.UFFGetMoleculeForceField(mol, confId=conf_id)
|
|
572
617
|
energy = ff.CalcEnergy()
|
|
573
618
|
energies.append((conf_id, energy))
|
|
574
619
|
lowest_energy_conf = min(energies, key=lambda x: x[1])
|
|
575
620
|
coordinates = mol.GetConformer(id=lowest_energy_conf[0]).GetPositions()
|
|
621
|
+
if refine_with_crest:
|
|
622
|
+
xyz = f"{len(symbols)}\n\n" + "\n".join(
|
|
623
|
+
f"{s} {coordinates[i][0]:.10f} {coordinates[i][1]:.10f} {coordinates[i][2]:.10f}"
|
|
624
|
+
for i, s in enumerate(symbols)
|
|
625
|
+
)
|
|
626
|
+
chrg = GetFormalCharge(mol)
|
|
627
|
+
uhf = int(spin * 2)
|
|
628
|
+
with TemporaryDirectory(dir=Path.cwd()) as temp_dir:
|
|
629
|
+
with open(Path(temp_dir) / "mol.xyz", "w", encoding="utf-8") as f:
|
|
630
|
+
f.write(xyz)
|
|
631
|
+
s = run(
|
|
632
|
+
f"crest mol.xyz -gfn2 -quick -prop ohess{f' --chrg {chrg}' if chrg != 0 else ''}{f' --uhf {uhf}' if uhf != 0 else ''}",
|
|
633
|
+
shell=True,
|
|
634
|
+
cwd=temp_dir,
|
|
635
|
+
)
|
|
636
|
+
if s.returncode == 0:
|
|
637
|
+
with open(Path(temp_dir) / "crest_property.xyz", "r") as f:
|
|
638
|
+
xyz = f.readlines()
|
|
639
|
+
xyz_data = []
|
|
640
|
+
for i in xyz[2:]:
|
|
641
|
+
if i == xyz[0]:
|
|
642
|
+
break
|
|
643
|
+
xyz_data.append(i.strip().split())
|
|
644
|
+
xyz_data = np.array(xyz_data)
|
|
645
|
+
symbols, coordinates = np.split(xyz_data, [1], axis=-1)
|
|
646
|
+
symbols = symbols.flatten().tolist()
|
|
647
|
+
coordinates = coordinates.astype(np.float64)
|
|
576
648
|
return symbols, coordinates
|
|
577
649
|
|
|
578
650
|
def cartesian2smiles(
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: bayesianflow_for_chem
|
|
3
|
-
Version: 1.4.0
|
|
3
|
+
Version: 1.4.2
|
|
4
4
|
Summary: Bayesian flow network framework for Chemistry
|
|
5
5
|
Home-page: https://augus1999.github.io/bayesian-flow-network-for-chemistry/
|
|
6
6
|
Author: Nianze A. Tao
|
|
@@ -23,6 +23,7 @@ Description-Content-Type: text/markdown
|
|
|
23
23
|
License-File: LICENSE
|
|
24
24
|
Requires-Dist: rdkit>=2023.9.6
|
|
25
25
|
Requires-Dist: torch>=2.3.1
|
|
26
|
+
Requires-Dist: torchao>=0.12
|
|
26
27
|
Requires-Dist: numpy>=1.26.4
|
|
27
28
|
Requires-Dist: loralib>=0.1.2
|
|
28
29
|
Requires-Dist: lightning>=2.2.0
|
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
bayesianflow_for_chem/__init__.py,sha256=IeIasLe6wLuGbH7DIlB38ehDPqvlMBT388hf58I3J30,329
|
|
2
|
+
bayesianflow_for_chem/data.py,sha256=WoOCOVmJX4WeHa2WeO4i66J2FS8rvRaYRCdlBN7ZeOM,6576
|
|
3
|
+
bayesianflow_for_chem/model.py,sha256=6pxGuIM7rKyawcz2hI8dT88rv3qFsnCvlLhDj1CB9YU,50595
|
|
4
|
+
bayesianflow_for_chem/scorer.py,sha256=7G1TVSwC0qONtNm6kiDZUWwvuFPzasNSjp4eJAk5TL0,4101
|
|
5
|
+
bayesianflow_for_chem/tool.py,sha256=Ne_ew1P8r6KWOqUZpb-BL_q7Dm6fnSTtxhJvgV1JHHs,26264
|
|
6
|
+
bayesianflow_for_chem/train.py,sha256=hGKyhGhLch-exSYPZdLXrLn3gf39Q1VLSJs2qtuikQE,9709
|
|
7
|
+
bayesianflow_for_chem/vocab.txt,sha256=HgtAZmpWYk4y8PqEVC4vqut1vE75DfRKE_10s2UW0rU,790
|
|
8
|
+
bayesianflow_for_chem-1.4.2.dist-info/licenses/LICENSE,sha256=hIahDEOTzuHCU5J2nd07LWwkLW7Hko4UFO__ffsvB-8,34523
|
|
9
|
+
bayesianflow_for_chem-1.4.2.dist-info/METADATA,sha256=s6k85HFXvasxvZBJD3Rj8cFNJXehS-utcMeKC6tP8F8,5673
|
|
10
|
+
bayesianflow_for_chem-1.4.2.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
11
|
+
bayesianflow_for_chem-1.4.2.dist-info/top_level.txt,sha256=KHsanI3BMCt8D9Qpze2ycrF6nMa3PyojgO6eS1c8kco,22
|
|
12
|
+
bayesianflow_for_chem-1.4.2.dist-info/RECORD,,
|
|
@@ -1,12 +0,0 @@
|
|
|
1
|
-
bayesianflow_for_chem/__init__.py,sha256=3sP8nM4_idOX-ksvpBJEApxPAVAPijKvQHxidTO5790,329
|
|
2
|
-
bayesianflow_for_chem/data.py,sha256=WoOCOVmJX4WeHa2WeO4i66J2FS8rvRaYRCdlBN7ZeOM,6576
|
|
3
|
-
bayesianflow_for_chem/model.py,sha256=fUrXKhn2U9FrVPJyb4lqACqPTePkIgI0v6_1jPs5c0Q,50784
|
|
4
|
-
bayesianflow_for_chem/scorer.py,sha256=7G1TVSwC0qONtNm6kiDZUWwvuFPzasNSjp4eJAk5TL0,4101
|
|
5
|
-
bayesianflow_for_chem/tool.py,sha256=NMMRHk2FJY0fyA76zCrz6tkcylCuExMUMj5hohWTnkE,23155
|
|
6
|
-
bayesianflow_for_chem/train.py,sha256=hGKyhGhLch-exSYPZdLXrLn3gf39Q1VLSJs2qtuikQE,9709
|
|
7
|
-
bayesianflow_for_chem/vocab.txt,sha256=HgtAZmpWYk4y8PqEVC4vqut1vE75DfRKE_10s2UW0rU,790
|
|
8
|
-
bayesianflow_for_chem-1.4.0.dist-info/licenses/LICENSE,sha256=hIahDEOTzuHCU5J2nd07LWwkLW7Hko4UFO__ffsvB-8,34523
|
|
9
|
-
bayesianflow_for_chem-1.4.0.dist-info/METADATA,sha256=1Y5mLIOaPsHcyCCm2SkWz7OCniQYVJ67-cVq3cUU0Mw,5643
|
|
10
|
-
bayesianflow_for_chem-1.4.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
11
|
-
bayesianflow_for_chem-1.4.0.dist-info/top_level.txt,sha256=KHsanI3BMCt8D9Qpze2ycrF6nMa3PyojgO6eS1c8kco,22
|
|
12
|
-
bayesianflow_for_chem-1.4.0.dist-info/RECORD,,
|
|
File without changes
|
{bayesianflow_for_chem-1.4.0.dist-info → bayesianflow_for_chem-1.4.2.dist-info}/licenses/LICENSE
RENAMED
|
File without changes
|
{bayesianflow_for_chem-1.4.0.dist-info → bayesianflow_for_chem-1.4.2.dist-info}/top_level.txt
RENAMED
|
File without changes
|