bayesianflow-for-chem 1.2.5__py3-none-any.whl → 1.2.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of bayesianflow-for-chem might be problematic; see the package registry's advisory page for more details.

@@ -7,5 +7,5 @@ from . import data, tool, train, scorer
7
7
  from .model import ChemBFN, MLP
8
8
 
9
9
  __all__ = ["data", "tool", "train", "scorer", "ChemBFN", "MLP"]
10
- __version__ = "1.2.5"
10
+ __version__ = "1.2.7"
11
11
  __author__ = "Nianze A. Tao (Omozawa Sueno)"
@@ -161,6 +161,8 @@ def split_dataset(
161
161
  :return:
162
162
  :rtype: None
163
163
  """
164
+ if isinstance(file, Path):
165
+ file = file.__str__()
164
166
  assert file.endswith(".csv")
165
167
  assert len(split_ratio) == 3
166
168
  assert method in ("random", "scaffold")
@@ -170,7 +172,7 @@ def split_dataset(
170
172
  raw_data = data[1:]
171
173
  smiles_idx = [] # only first index will be used
172
174
  for key, h in enumerate(header):
173
- if h.lower() == "smiles":
175
+ if "smiles" in h.lower():
174
176
  smiles_idx.append(key)
175
177
  assert len(smiles_idx) > 0
176
178
  data_len = len(raw_data)
@@ -186,6 +188,14 @@ def split_dataset(
186
188
  scaffolds: Dict[str, List] = {}
187
189
  for key, d in enumerate(raw_data):
188
190
  # compute Bemis-Murcko scaffold
191
+ if len(smiles_idx) > 1:
192
+ warnings.warn(
193
+ "\033[32;1m"
194
+ f"We found {len(smiles_idx)} SMILES strings in a row!"
195
+ " Only the first SMILES will be used to compute the molecular scaffold."
196
+ "\033[0m",
197
+ stacklevel=2,
198
+ )
189
199
  scaffold = MurckoScaffoldSmiles(d[smiles_idx[0]])
190
200
  if scaffold in scaffolds:
191
201
  scaffolds[scaffold].append(key)
@@ -492,15 +502,16 @@ def inpaint(
492
502
 
493
503
  def quantise_model(model: ChemBFN) -> nn.Module:
494
504
  """
495
- Dynamic quantisation of the trained model.
505
+ Dynamic quantisation of the trained model to `torch.qint8` data type.
496
506
 
497
507
  :param model: trained ChemBFN model
498
508
  :type model: bayesianflow_for_chem.model.ChemBFN
499
509
  :return: quantised model
500
510
  :rtype: torch.nn.Module
501
511
  """
502
- from torch.ao.nn.quantized.modules.utils import _quantize_weight
503
512
  from torch.ao.nn.quantized import dynamic
513
+ from torch.ao.nn.quantized.modules.utils import _quantize_weight
514
+ from torch.ao.quantization.qconfig import default_dynamic_qconfig
504
515
 
505
516
  class QuantisedLinear(dynamic.Linear):
506
517
  # Modified from https://github.com/pytorch/pytorch/blob/main/torch/ao/nn/quantized/dynamic/modules/linear.py
@@ -543,7 +554,22 @@ def quantise_model(model: ChemBFN) -> nn.Module:
543
554
  self._packed_params.requires_grad_(False)
544
555
 
545
556
  def forward(self, x: Tensor) -> Tensor:
546
- result = dynamic.Linear.forward(self, x)
557
+ if self._packed_params.dtype == torch.qint8:
558
+ if self.version is None or self.version < 4:
559
+ Y = torch.ops.quantized.linear_dynamic(
560
+ x, self._packed_params._packed_params
561
+ )
562
+ else:
563
+ Y = torch.ops.quantized.linear_dynamic(
564
+ x, self._packed_params._packed_params, reduce_range=True
565
+ )
566
+ elif self._packed_params.dtype == torch.float16:
567
+ Y = torch.ops.quantized.linear_dynamic_fp16(
568
+ x, self._packed_params._packed_params
569
+ )
570
+ else:
571
+ raise RuntimeError("Unsupported dtype on dynamic quantized linear!")
572
+ result = Y.to(x.dtype)
547
573
  if self.lora_enabled and isinstance(self.lora_dropout, float):
548
574
  result += (
549
575
  nn.functional.dropout(x, self.lora_dropout, self.training)
@@ -562,11 +588,6 @@ def quantise_model(model: ChemBFN) -> nn.Module:
562
588
  if mod.qconfig is not None and mod.qconfig.weight is not None:
563
589
  weight_observer = mod.qconfig.weight()
564
590
  else:
565
- # We have the circular import issues if we import the qconfig in the beginning of this file:
566
- # https://github.com/pytorch/pytorch/pull/24231. The current workaround is to postpone the
567
- # import until we need it.
568
- from torch.ao.quantization.qconfig import default_dynamic_qconfig
569
-
570
591
  weight_observer = default_dynamic_qconfig.weight()
571
592
  dtype = weight_observer.dtype
572
593
  assert dtype in [torch.qint8, torch.float16], (
@@ -1,6 +1,6 @@
1
- Metadata-Version: 2.2
1
+ Metadata-Version: 2.4
2
2
  Name: bayesianflow_for_chem
3
- Version: 1.2.5
3
+ Version: 1.2.7
4
4
  Summary: Bayesian flow network framework for Chemistry
5
5
  Home-page: https://augus1999.github.io/bayesian-flow-network-for-chemistry/
6
6
  Author: Nianze A. Tao
@@ -39,6 +39,7 @@ Dynamic: description-content-type
39
39
  Dynamic: home-page
40
40
  Dynamic: keywords
41
41
  Dynamic: license
42
+ Dynamic: license-file
42
43
  Dynamic: project-url
43
44
  Dynamic: provides-extra
44
45
  Dynamic: requires-dist
@@ -0,0 +1,12 @@
1
+ bayesianflow_for_chem/__init__.py,sha256=xYC8F86oe8y40GGqzGGjbbjSXPK16Qci8XqDMjrbxK8,293
2
+ bayesianflow_for_chem/data.py,sha256=9tpRba40lxwrB6aPSJMkxUglEVC3VEQC9wWxhDuz3Q8,7760
3
+ bayesianflow_for_chem/model.py,sha256=HvEvW_xRbkv4eSv5lhd72BJMZkg-ZACEi1DAW3p5Q1Y,35918
4
+ bayesianflow_for_chem/scorer.py,sha256=mV1vX8aBGFra2BE7N8WHihVIo3dXmUdPQIGfSaiuNdk,4084
5
+ bayesianflow_for_chem/tool.py,sha256=yZrWCI3Zi6EHxo3zqCU_ebmzVECaco8Vbx-oTg-rHhg,24118
6
+ bayesianflow_for_chem/train.py,sha256=kj6icGqymUUYopDtpre1oE_wpvpeNilbpzgffBsd1tk,9589
7
+ bayesianflow_for_chem/vocab.txt,sha256=HgtAZmpWYk4y8PqEVC4vqut1vE75DfRKE_10s2UW0rU,790
8
+ bayesianflow_for_chem-1.2.7.dist-info/licenses/LICENSE,sha256=hIahDEOTzuHCU5J2nd07LWwkLW7Hko4UFO__ffsvB-8,34523
9
+ bayesianflow_for_chem-1.2.7.dist-info/METADATA,sha256=9v-CEHo1DJGmgwopQiQ68sFEaUZkHzFIhfiNTL2r6mc,5913
10
+ bayesianflow_for_chem-1.2.7.dist-info/WHEEL,sha256=0CuiUZ_p9E4cD6NyLD6UG80LBXYyiSYZOKDm5lp32xk,91
11
+ bayesianflow_for_chem-1.2.7.dist-info/top_level.txt,sha256=KHsanI3BMCt8D9Qpze2ycrF6nMa3PyojgO6eS1c8kco,22
12
+ bayesianflow_for_chem-1.2.7.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (75.8.0)
2
+ Generator: setuptools (80.3.1)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5
 
@@ -1,12 +0,0 @@
1
- bayesianflow_for_chem/__init__.py,sha256=GMGe5nU963qFL6vJ9OZSfqfSyEImC_P2zyUS0cyP3Mg,293
2
- bayesianflow_for_chem/data.py,sha256=9tpRba40lxwrB6aPSJMkxUglEVC3VEQC9wWxhDuz3Q8,7760
3
- bayesianflow_for_chem/model.py,sha256=HvEvW_xRbkv4eSv5lhd72BJMZkg-ZACEi1DAW3p5Q1Y,35918
4
- bayesianflow_for_chem/scorer.py,sha256=mV1vX8aBGFra2BE7N8WHihVIo3dXmUdPQIGfSaiuNdk,4084
5
- bayesianflow_for_chem/tool.py,sha256=tJjb8q3_orNkj2BYJwz5VxqeaOv55dvqO93_uigLJIk,23221
6
- bayesianflow_for_chem/train.py,sha256=kj6icGqymUUYopDtpre1oE_wpvpeNilbpzgffBsd1tk,9589
7
- bayesianflow_for_chem/vocab.txt,sha256=HgtAZmpWYk4y8PqEVC4vqut1vE75DfRKE_10s2UW0rU,790
8
- bayesianflow_for_chem-1.2.5.dist-info/LICENSE,sha256=hIahDEOTzuHCU5J2nd07LWwkLW7Hko4UFO__ffsvB-8,34523
9
- bayesianflow_for_chem-1.2.5.dist-info/METADATA,sha256=hwEEDW6ipmHpjRjQDKxWk5zqI9jwjsl-yxBpvYn93HQ,5890
10
- bayesianflow_for_chem-1.2.5.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
11
- bayesianflow_for_chem-1.2.5.dist-info/top_level.txt,sha256=KHsanI3BMCt8D9Qpze2ycrF6nMa3PyojgO6eS1c8kco,22
12
- bayesianflow_for_chem-1.2.5.dist-info/RECORD,,