bayesianflow-for-chem 1.2.5__py3-none-any.whl → 1.2.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of bayesianflow-for-chem might be problematic. Click here for more details.
- bayesianflow_for_chem/__init__.py +1 -1
- bayesianflow_for_chem/tool.py +30 -9
- {bayesianflow_for_chem-1.2.5.dist-info → bayesianflow_for_chem-1.2.7.dist-info}/METADATA +3 -2
- bayesianflow_for_chem-1.2.7.dist-info/RECORD +12 -0
- {bayesianflow_for_chem-1.2.5.dist-info → bayesianflow_for_chem-1.2.7.dist-info}/WHEEL +1 -1
- bayesianflow_for_chem-1.2.5.dist-info/RECORD +0 -12
- {bayesianflow_for_chem-1.2.5.dist-info → bayesianflow_for_chem-1.2.7.dist-info/licenses}/LICENSE +0 -0
- {bayesianflow_for_chem-1.2.5.dist-info → bayesianflow_for_chem-1.2.7.dist-info}/top_level.txt +0 -0
bayesianflow_for_chem/tool.py
CHANGED
|
@@ -161,6 +161,8 @@ def split_dataset(
|
|
|
161
161
|
:return:
|
|
162
162
|
:rtype: None
|
|
163
163
|
"""
|
|
164
|
+
if isinstance(file, Path):
|
|
165
|
+
file = file.__str__()
|
|
164
166
|
assert file.endswith(".csv")
|
|
165
167
|
assert len(split_ratio) == 3
|
|
166
168
|
assert method in ("random", "scaffold")
|
|
@@ -170,7 +172,7 @@ def split_dataset(
|
|
|
170
172
|
raw_data = data[1:]
|
|
171
173
|
smiles_idx = [] # only first index will be used
|
|
172
174
|
for key, h in enumerate(header):
|
|
173
|
-
if h.lower()
|
|
175
|
+
if "smiles" in h.lower():
|
|
174
176
|
smiles_idx.append(key)
|
|
175
177
|
assert len(smiles_idx) > 0
|
|
176
178
|
data_len = len(raw_data)
|
|
@@ -186,6 +188,14 @@ def split_dataset(
|
|
|
186
188
|
scaffolds: Dict[str, List] = {}
|
|
187
189
|
for key, d in enumerate(raw_data):
|
|
188
190
|
# compute Bemis-Murcko scaffold
|
|
191
|
+
if len(smiles_idx) > 1:
|
|
192
|
+
warnings.warn(
|
|
193
|
+
"\033[32;1m"
|
|
194
|
+
f"We found {len(smiles_idx)} SMILES strings in a row!"
|
|
195
|
+
" Only the first SMILES will be used to compute the molecular scaffold."
|
|
196
|
+
"\033[0m",
|
|
197
|
+
stacklevel=2,
|
|
198
|
+
)
|
|
189
199
|
scaffold = MurckoScaffoldSmiles(d[smiles_idx[0]])
|
|
190
200
|
if scaffold in scaffolds:
|
|
191
201
|
scaffolds[scaffold].append(key)
|
|
@@ -492,15 +502,16 @@ def inpaint(
|
|
|
492
502
|
|
|
493
503
|
def quantise_model(model: ChemBFN) -> nn.Module:
|
|
494
504
|
"""
|
|
495
|
-
Dynamic quantisation of the trained model.
|
|
505
|
+
Dynamic quantisation of the trained model to `torch.qint8` data type.
|
|
496
506
|
|
|
497
507
|
:param model: trained ChemBFN model
|
|
498
508
|
:type model: bayesianflow_for_chem.model.ChemBFN
|
|
499
509
|
:return: quantised model
|
|
500
510
|
:rtype: torch.nn.Module
|
|
501
511
|
"""
|
|
502
|
-
from torch.ao.nn.quantized.modules.utils import _quantize_weight
|
|
503
512
|
from torch.ao.nn.quantized import dynamic
|
|
513
|
+
from torch.ao.nn.quantized.modules.utils import _quantize_weight
|
|
514
|
+
from torch.ao.quantization.qconfig import default_dynamic_qconfig
|
|
504
515
|
|
|
505
516
|
class QuantisedLinear(dynamic.Linear):
|
|
506
517
|
# Modified from https://github.com/pytorch/pytorch/blob/main/torch/ao/nn/quantized/dynamic/modules/linear.py
|
|
@@ -543,7 +554,22 @@ def quantise_model(model: ChemBFN) -> nn.Module:
|
|
|
543
554
|
self._packed_params.requires_grad_(False)
|
|
544
555
|
|
|
545
556
|
def forward(self, x: Tensor) -> Tensor:
|
|
546
|
-
|
|
557
|
+
if self._packed_params.dtype == torch.qint8:
|
|
558
|
+
if self.version is None or self.version < 4:
|
|
559
|
+
Y = torch.ops.quantized.linear_dynamic(
|
|
560
|
+
x, self._packed_params._packed_params
|
|
561
|
+
)
|
|
562
|
+
else:
|
|
563
|
+
Y = torch.ops.quantized.linear_dynamic(
|
|
564
|
+
x, self._packed_params._packed_params, reduce_range=True
|
|
565
|
+
)
|
|
566
|
+
elif self._packed_params.dtype == torch.float16:
|
|
567
|
+
Y = torch.ops.quantized.linear_dynamic_fp16(
|
|
568
|
+
x, self._packed_params._packed_params
|
|
569
|
+
)
|
|
570
|
+
else:
|
|
571
|
+
raise RuntimeError("Unsupported dtype on dynamic quantized linear!")
|
|
572
|
+
result = Y.to(x.dtype)
|
|
547
573
|
if self.lora_enabled and isinstance(self.lora_dropout, float):
|
|
548
574
|
result += (
|
|
549
575
|
nn.functional.dropout(x, self.lora_dropout, self.training)
|
|
@@ -562,11 +588,6 @@ def quantise_model(model: ChemBFN) -> nn.Module:
|
|
|
562
588
|
if mod.qconfig is not None and mod.qconfig.weight is not None:
|
|
563
589
|
weight_observer = mod.qconfig.weight()
|
|
564
590
|
else:
|
|
565
|
-
# We have the circular import issues if we import the qconfig in the beginning of this file:
|
|
566
|
-
# https://github.com/pytorch/pytorch/pull/24231. The current workaround is to postpone the
|
|
567
|
-
# import until we need it.
|
|
568
|
-
from torch.ao.quantization.qconfig import default_dynamic_qconfig
|
|
569
|
-
|
|
570
591
|
weight_observer = default_dynamic_qconfig.weight()
|
|
571
592
|
dtype = weight_observer.dtype
|
|
572
593
|
assert dtype in [torch.qint8, torch.float16], (
|
|
@@ -1,6 +1,6 @@
|
|
|
1
|
-
Metadata-Version: 2.
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
2
|
Name: bayesianflow_for_chem
|
|
3
|
-
Version: 1.2.5
|
|
3
|
+
Version: 1.2.7
|
|
4
4
|
Summary: Bayesian flow network framework for Chemistry
|
|
5
5
|
Home-page: https://augus1999.github.io/bayesian-flow-network-for-chemistry/
|
|
6
6
|
Author: Nianze A. Tao
|
|
@@ -39,6 +39,7 @@ Dynamic: description-content-type
|
|
|
39
39
|
Dynamic: home-page
|
|
40
40
|
Dynamic: keywords
|
|
41
41
|
Dynamic: license
|
|
42
|
+
Dynamic: license-file
|
|
42
43
|
Dynamic: project-url
|
|
43
44
|
Dynamic: provides-extra
|
|
44
45
|
Dynamic: requires-dist
|
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
bayesianflow_for_chem/__init__.py,sha256=xYC8F86oe8y40GGqzGGjbbjSXPK16Qci8XqDMjrbxK8,293
|
|
2
|
+
bayesianflow_for_chem/data.py,sha256=9tpRba40lxwrB6aPSJMkxUglEVC3VEQC9wWxhDuz3Q8,7760
|
|
3
|
+
bayesianflow_for_chem/model.py,sha256=HvEvW_xRbkv4eSv5lhd72BJMZkg-ZACEi1DAW3p5Q1Y,35918
|
|
4
|
+
bayesianflow_for_chem/scorer.py,sha256=mV1vX8aBGFra2BE7N8WHihVIo3dXmUdPQIGfSaiuNdk,4084
|
|
5
|
+
bayesianflow_for_chem/tool.py,sha256=yZrWCI3Zi6EHxo3zqCU_ebmzVECaco8Vbx-oTg-rHhg,24118
|
|
6
|
+
bayesianflow_for_chem/train.py,sha256=kj6icGqymUUYopDtpre1oE_wpvpeNilbpzgffBsd1tk,9589
|
|
7
|
+
bayesianflow_for_chem/vocab.txt,sha256=HgtAZmpWYk4y8PqEVC4vqut1vE75DfRKE_10s2UW0rU,790
|
|
8
|
+
bayesianflow_for_chem-1.2.7.dist-info/licenses/LICENSE,sha256=hIahDEOTzuHCU5J2nd07LWwkLW7Hko4UFO__ffsvB-8,34523
|
|
9
|
+
bayesianflow_for_chem-1.2.7.dist-info/METADATA,sha256=9v-CEHo1DJGmgwopQiQ68sFEaUZkHzFIhfiNTL2r6mc,5913
|
|
10
|
+
bayesianflow_for_chem-1.2.7.dist-info/WHEEL,sha256=0CuiUZ_p9E4cD6NyLD6UG80LBXYyiSYZOKDm5lp32xk,91
|
|
11
|
+
bayesianflow_for_chem-1.2.7.dist-info/top_level.txt,sha256=KHsanI3BMCt8D9Qpze2ycrF6nMa3PyojgO6eS1c8kco,22
|
|
12
|
+
bayesianflow_for_chem-1.2.7.dist-info/RECORD,,
|
|
@@ -1,12 +0,0 @@
|
|
|
1
|
-
bayesianflow_for_chem/__init__.py,sha256=GMGe5nU963qFL6vJ9OZSfqfSyEImC_P2zyUS0cyP3Mg,293
|
|
2
|
-
bayesianflow_for_chem/data.py,sha256=9tpRba40lxwrB6aPSJMkxUglEVC3VEQC9wWxhDuz3Q8,7760
|
|
3
|
-
bayesianflow_for_chem/model.py,sha256=HvEvW_xRbkv4eSv5lhd72BJMZkg-ZACEi1DAW3p5Q1Y,35918
|
|
4
|
-
bayesianflow_for_chem/scorer.py,sha256=mV1vX8aBGFra2BE7N8WHihVIo3dXmUdPQIGfSaiuNdk,4084
|
|
5
|
-
bayesianflow_for_chem/tool.py,sha256=tJjb8q3_orNkj2BYJwz5VxqeaOv55dvqO93_uigLJIk,23221
|
|
6
|
-
bayesianflow_for_chem/train.py,sha256=kj6icGqymUUYopDtpre1oE_wpvpeNilbpzgffBsd1tk,9589
|
|
7
|
-
bayesianflow_for_chem/vocab.txt,sha256=HgtAZmpWYk4y8PqEVC4vqut1vE75DfRKE_10s2UW0rU,790
|
|
8
|
-
bayesianflow_for_chem-1.2.5.dist-info/LICENSE,sha256=hIahDEOTzuHCU5J2nd07LWwkLW7Hko4UFO__ffsvB-8,34523
|
|
9
|
-
bayesianflow_for_chem-1.2.5.dist-info/METADATA,sha256=hwEEDW6ipmHpjRjQDKxWk5zqI9jwjsl-yxBpvYn93HQ,5890
|
|
10
|
-
bayesianflow_for_chem-1.2.5.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
|
|
11
|
-
bayesianflow_for_chem-1.2.5.dist-info/top_level.txt,sha256=KHsanI3BMCt8D9Qpze2ycrF6nMa3PyojgO6eS1c8kco,22
|
|
12
|
-
bayesianflow_for_chem-1.2.5.dist-info/RECORD,,
|
{bayesianflow_for_chem-1.2.5.dist-info → bayesianflow_for_chem-1.2.7.dist-info/licenses}/LICENSE
RENAMED
|
File without changes
|
{bayesianflow_for_chem-1.2.5.dist-info → bayesianflow_for_chem-1.2.7.dist-info}/top_level.txt
RENAMED
|
File without changes
|