bayesianflow-for-chem 2.0.2.tar.gz → 2.0.4.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of bayesianflow-for-chem has been flagged as potentially problematic by the registry.
- {bayesianflow_for_chem-2.0.2 → bayesianflow_for_chem-2.0.4}/PKG-INFO +7 -1
- {bayesianflow_for_chem-2.0.2 → bayesianflow_for_chem-2.0.4}/README.md +5 -0
- {bayesianflow_for_chem-2.0.2 → bayesianflow_for_chem-2.0.4}/bayesianflow_for_chem/__init__.py +4 -1
- {bayesianflow_for_chem-2.0.2 → bayesianflow_for_chem-2.0.4}/bayesianflow_for_chem/cli.py +33 -18
- {bayesianflow_for_chem-2.0.2 → bayesianflow_for_chem-2.0.4}/bayesianflow_for_chem/model.py +13 -0
- {bayesianflow_for_chem-2.0.2 → bayesianflow_for_chem-2.0.4}/bayesianflow_for_chem/tool.py +4 -1
- {bayesianflow_for_chem-2.0.2 → bayesianflow_for_chem-2.0.4}/bayesianflow_for_chem/train.py +1 -1
- {bayesianflow_for_chem-2.0.2 → bayesianflow_for_chem-2.0.4}/bayesianflow_for_chem.egg-info/PKG-INFO +7 -1
- {bayesianflow_for_chem-2.0.2 → bayesianflow_for_chem-2.0.4}/bayesianflow_for_chem.egg-info/SOURCES.txt +3 -1
- {bayesianflow_for_chem-2.0.2 → bayesianflow_for_chem-2.0.4}/bayesianflow_for_chem.egg-info/requires.txt +1 -0
- {bayesianflow_for_chem-2.0.2 → bayesianflow_for_chem-2.0.4}/setup.py +1 -0
- bayesianflow_for_chem-2.0.4/test/test_merge_lora.py +40 -0
- bayesianflow_for_chem-2.0.4/test/test_molecular_embedding.py +67 -0
- {bayesianflow_for_chem-2.0.2 → bayesianflow_for_chem-2.0.4}/LICENSE +0 -0
- {bayesianflow_for_chem-2.0.2 → bayesianflow_for_chem-2.0.4}/bayesianflow_for_chem/data.py +0 -0
- {bayesianflow_for_chem-2.0.2 → bayesianflow_for_chem-2.0.4}/bayesianflow_for_chem/scorer.py +0 -0
- {bayesianflow_for_chem-2.0.2 → bayesianflow_for_chem-2.0.4}/bayesianflow_for_chem/spectra.py +0 -0
- {bayesianflow_for_chem-2.0.2 → bayesianflow_for_chem-2.0.4}/bayesianflow_for_chem/vocab.txt +0 -0
- {bayesianflow_for_chem-2.0.2 → bayesianflow_for_chem-2.0.4}/bayesianflow_for_chem.egg-info/dependency_links.txt +0 -0
- {bayesianflow_for_chem-2.0.2 → bayesianflow_for_chem-2.0.4}/bayesianflow_for_chem.egg-info/entry_points.txt +0 -0
- {bayesianflow_for_chem-2.0.2 → bayesianflow_for_chem-2.0.4}/bayesianflow_for_chem.egg-info/top_level.txt +0 -0
- {bayesianflow_for_chem-2.0.2 → bayesianflow_for_chem-2.0.4}/pyproject.toml +0 -0
- {bayesianflow_for_chem-2.0.2 → bayesianflow_for_chem-2.0.4}/setup.cfg +0 -0
{bayesianflow_for_chem-2.0.2 → bayesianflow_for_chem-2.0.4}/PKG-INFO
RENAMED
```diff
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: bayesianflow_for_chem
-Version: 2.0.2
+Version: 2.0.4
 Summary: Bayesian flow network framework for Chemistry
 Home-page: https://augus1999.github.io/bayesian-flow-network-for-chemistry/
 Author: Nianze A. Tao
@@ -23,6 +23,7 @@ License-File: LICENSE
 Requires-Dist: rdkit>=2025.3.5
 Requires-Dist: torch>=2.8.0
 Requires-Dist: torchao>=0.12
+Requires-Dist: colorama>=0.4.6
 Requires-Dist: numpy>=2.3.2
 Requires-Dist: scipy>=1.16.1
 Requires-Dist: loralib>=0.1.2
@@ -49,6 +50,11 @@ Dynamic: summary
 
 This is the repository of the PyTorch implementation of ChemBFN model.
 
+### Build State
+
+[](https://pypi.org/project/bayesianflow-for-chem/)
+
+
 ## Features
 
 ChemBFN provides the state-of-the-art functionalities of
```
(The two badge lines under "### Build State" lost their image URLs in extraction; only the PyPI link target survives.)
{bayesianflow_for_chem-2.0.2 → bayesianflow_for_chem-2.0.4}/README.md
RENAMED
```diff
@@ -5,6 +5,11 @@
 
 This is the repository of the PyTorch implementation of ChemBFN model.
 
+### Build State
+
+[](https://pypi.org/project/bayesianflow-for-chem/)
+
+
 ## Features
 
 ChemBFN provides the state-of-the-art functionalities of
```
{bayesianflow_for_chem-2.0.2 → bayesianflow_for_chem-2.0.4}/bayesianflow_for_chem/__init__.py
RENAMED
```diff
@@ -3,6 +3,7 @@
 """
 ChemBFN package.
 """
+import colorama
 from . import data, tool, train, scorer, spectra
 from .model import ChemBFN, MLP, EnsembleChemBFN
 from .cli import main_script
@@ -17,7 +18,7 @@ __all__ = [
     "MLP",
     "EnsembleChemBFN",
 ]
-__version__ = "2.0.2"
+__version__ = "2.0.4"
 __author__ = "Nianze A. Tao (Omozawa Sueno)"
 
 
@@ -28,4 +29,6 @@ def main() -> None:
     :return:
     :rtype: None
     """
+    colorama.just_fix_windows_console()
     main_script(__version__)
+    colorama.deinit()
```
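The new `main()` wiring follows colorama's recommended lifecycle, which is why `colorama>=0.4.6` enters the dependency list: `just_fix_windows_console()` (introduced in colorama 0.4.6) makes the raw ANSI escape sequences printed by the CLI render correctly on legacy Windows consoles and is a no-op elsewhere, while `deinit()` restores the original output streams afterwards. A minimal sketch of the same pattern with a hypothetical `run()` entry point; the try/finally is our addition, not upstream code:

```python
import colorama


def run() -> None:
    """Hypothetical entry point mirroring the pattern in main() above."""
    colorama.just_fix_windows_console()  # wrap stdout/stderr on old Windows consoles
    try:
        # Raw ANSI codes, like those used in cli.py, now display correctly everywhere.
        print("\033[0;31mCritical\033[0;0m example message")
    finally:
        colorama.deinit()  # restore the original streams
```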
{bayesianflow_for_chem-2.0.2 → bayesianflow_for_chem-2.0.4}/bayesianflow_for_chem/cli.py
RENAMED
```diff
@@ -130,10 +130,11 @@ def parse_cli(version: str) -> argparse.Namespace:
     """
     parser = argparse.ArgumentParser(
         description="Madmol: a CLI molecular design tool for "
-        "de novo design
+        "de novo design, R-group replacement, and sequence in-filling, "
         "based on generative route of ChemBFN method. "
         "Let's make some craziest molecules.",
-        epilog=f"Madmol {version}, developed in Hiroshima University
+        epilog=f"Madmol {version}, developed in Hiroshima University by chemists for chemists. "
+        "Visit https://augus1999.github.io/bayesian-flow-network-for-chemistry/ for more details.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
@@ -180,14 +181,16 @@ def load_model_config(
        model_config = tomllib.load(f)
    if model_config["ChemBFN"]["num_vocab"] != "match vocabulary size":
        if not isinstance(model_config["ChemBFN"]["num_vocab"], int):
-            print(
+            print(
+                f"\033[0;31mCritical\033[0;0m in {config_file}: You must specify num_vocab."
+            )
            flag_critical += 1
    if model_config["ChemBFN"]["base_model"]:
        model_file = model_config["ChemBFN"]["base_model"]
        for fn in model_file:
            if not os.path.exists(fn):
                print(
-                    f"
+                    f"\033[0;31mCritical\033[0;0m in {config_file}: Base model file {fn} does not exist."
                )
                flag_critical += 1
    if "MLP" in model_config:
@@ -195,14 +198,14 @@ def load_model_config(
        b = model_config["MLP"]["size"][-1]
        if a != b:
            print(
-                f"
+                f"\033[0;31mCritical\033[0;0m in {config_file}: MLP hidden size {b} should match ChemBFN hidden size {a}."
            )
            flag_critical += 1
        if model_config["MLP"]["base_model"]:
            model_file = model_config["MLP"]["base_model"]
            if not os.path.exists(model_file):
                print(
-                    f"
+                    f"\033[0;31mCritical\033[0;0m in {config_file}: Base model file {fn} does not exist."
                )
                flag_critical += 1
    return model_config, flag_critical, flag_warning
@@ -226,49 +229,61 @@ def load_runtime_config(
        config = tomllib.load(f)
    tokeniser_name = config["tokeniser"]["name"].lower()
    if not tokeniser_name in "smiles selfies safe fasta".split():
-        print(
+        print(
+            f"\033[0;31mCritical\033[0;0m in {config_file}: Unknown tokensier name: {tokeniser_name}."
+        )
        flag_critical += 1
    if tokeniser_name == "selfies":
        vocab = config["tokeniser"]["vocab"]
        if vocab.lower() == "default":
-            print(
+            print(
+                f"\033[0;31mCritical\033[0;0m in {config_file}: You should specify a vocabulary file."
+            )
            flag_critical += 1
        elif not os.path.exists(vocab):
-            print(
+            print(
+                f"\033[0;31mCritical\033[0;0m in {config_file}: Vocabulary file {vocab} does not exist."
+            )
            flag_critical += 1
    if "train" in config:
        dataset_file = config["train"]["dataset"]
        if not os.path.exists(dataset_file):
            print(
-                f"
+                f"\033[0;31mCritical\033[0;0m in {config_file}: Dataset file {dataset_file} does not exist."
            )
            flag_critical += 1
        logger_name = config["train"]["logger_name"].lower()
        if not logger_name in "csv tensorboard wandb".split():
-            print(
+            print(
+                f"\033[0;31mCritical\033[0;0m in {config_file}: Unknown logger: {logger_name}."
+            )
            flag_critical += 1
        if config["train"]["restart"]:
            ckpt_file = config["train"]["restart"]
            if not os.path.exists(ckpt_file):
                print(
-                    f"
+                    f"\033[0;31mCritical\033[0;0m in {config_file}: Restart checkpoint file {ckpt_file} does not exist."
                )
                flag_critical += 1
    if "inference" in config:
        if not "train" in config:
            if not isinstance(config["inference"]["sequence_length"], int):
                print(
-                    f"
+                    f"\033[0;31mCritical\033[0;0m in {config_file}: You must set an integer for sequence_length."
                )
                flag_critical += 1
        if config["inference"]["guidance_objective"]:
            if not "guidance_objective_strength" in config["inference"]:
                print(
-                    f"
+                    f"\033[0;31mCritical\033[0;0m in {config_file}: You need to add guidance_objective_strength."
                )
                flag_critical += 1
        result_dir = Path(config["inference"]["result_file"]).parent
-
+        if not os.path.exists(result_dir):
+            print(
+                f"\033[0;33mWarning\033[0;0m in {config_file}: Directory {result_dir} to save the result does not exist."
+            )
+            flag_warning += 1
    return config, flag_critical, flag_warning
@@ -306,7 +321,7 @@ def main_script(version: str) -> None:
    if runtime_config["train"]["enable_lora"]:
        if not model_config["ChemBFN"]["base_model"]:
            print(
-                f"
+                f"\033[0;33mWarning\033[0;0m in {parser.model_config}: You should load a pretrained model first."
            )
            flag_warning += 1
    if not os.path.exists(runtime_config["train"]["checkpoint_save_path"]):
@@ -314,12 +329,12 @@ def main_script(version: str) -> None:
    else:
        if not model_config["ChemBFN"]["base_model"]:
            print(
-                f"
+                f"\033[0;33mWarning\033[0;0m in {parser.model_config}: You should load a pretrained ChemBFN model."
            )
            flag_warning += 1
        if not model_config["MLP"]["base_model"]:
            print(
-                f"
+                f"\033[0;33mWarning\033[0;0m in {parser.model_config}: You should load a pretrained MLP."
            )
            flag_warning += 1
    if "inference" in runtime_config:
```
(Removed lines ending abruptly, such as `"de novo design` and the bare `f"` lines, were truncated by the diff extraction; the original 2.0.2 text of those lines is not recoverable from this page.)
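Every reworked diagnostic shares two escape sequences: `\033[0;31m` (red) prefixes Critical messages, `\033[0;33m` (yellow) prefixes Warning messages, and `\033[0;0m` resets the colour. The upstream code inlines these codes in each f-string; below is a hedged sketch of an equivalent helper, purely for illustration and not part of the package:

```python
RED, YELLOW, RESET = "\033[0;31m", "\033[0;33m", "\033[0;0m"


def critical(config_file: str, message: str) -> str:
    # Same shape as the inlined f-strings in load_model_config/load_runtime_config.
    return f"{RED}Critical{RESET} in {config_file}: {message}"


def warning(config_file: str, message: str) -> str:
    return f"{YELLOW}Warning{RESET} in {config_file}: {message}"


print(critical("model_config.toml", "You must specify num_vocab."))
```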
{bayesianflow_for_chem-2.0.2 → bayesianflow_for_chem-2.0.4}/bayesianflow_for_chem/model.py
RENAMED
```diff
@@ -1038,6 +1038,19 @@ class EnsembleChemBFN(ChemBFN):
         self.__delattr__("lora_enabled")
         self.__delattr__("lora_param")
         self.__delattr__("hparam")
+        # ------- merge LoRA parameters to reduce the latency -------
+        for _, v in self.models.items():
+            for module in v.modules():
+                if hasattr(module, "lora_A"):
+                    module.weight.data += (
+                        module.lora_B @ module.lora_A
+                    ) * module.scaling
+                    module.lora_enabled = False
+                    module.lora_A = None
+                    module.lora_B = None
+                    module.scaling = None
+                    module.lora_dropout = None
+            v.lora_enabled = False
 
     def construct_y(
         self, c: Union[List[Tensor], Dict[str, Tensor]]
```
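The merge relies on the standard LoRA identity: an adapted layer computes y = x Wᵀ + scaling · x (B A)ᵀ, so folding `scaling * (lora_B @ lora_A)` into `weight` once yields a plain linear layer with the same output and no extra matmul per forward pass; since the ensembled models run in eval mode, discarding `lora_dropout` changes nothing. A self-contained numerical check of that identity, independent of ChemBFN:

```python
import torch

torch.manual_seed(0)
W = torch.randn(64, 32)        # base weight, (out_features, in_features)
A = torch.randn(8, 32) * 0.01  # LoRA down-projection, rank r = 8
B = torch.randn(64, 8) * 0.01  # LoRA up-projection
scaling = 2.0                  # lora_alpha / r in loralib's convention
x = torch.randn(5, 32)

y_lora = x @ W.t() + scaling * (x @ A.t() @ B.t())  # unmerged forward pass
W_merged = W + scaling * (B @ A)                    # the fold performed above
y_merged = x @ W_merged.t()

assert torch.allclose(y_lora, y_merged, atol=1e-5)
```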
{bayesianflow_for_chem-2.0.2 → bayesianflow_for_chem-2.0.4}/bayesianflow_for_chem/tool.py
RENAMED
```diff
@@ -9,6 +9,7 @@ import warnings
 from pathlib import Path
 from typing import List, Dict, Tuple, Union, Optional
 import torch
+import colorama
 import numpy as np
 from torch import cuda, Tensor, softmax
 from torch.utils.data import DataLoader
@@ -141,6 +142,7 @@ def split_dataset(
     assert file.endswith(".csv")
     assert len(split_ratio) == 3
     assert method in ("random", "scaffold")
+    colorama.just_fix_windows_console()
     with open(file, "r") as f:
         data = list(csv.reader(f))
     header = data[0]
@@ -198,6 +200,7 @@ def split_dataset(
     with open(file.replace(".csv", "_val.csv"), "w", newline="") as fva:
         writer = csv.writer(fva)
         writer.writerows([header] + val_set)
+    colorama.deinit()
 
 
 @torch.no_grad()
@@ -467,7 +470,7 @@ class GeometryConverter:
         spin: float = 0.0,
     ) -> Tuple[List[str], np.ndarray]:
         """
-        Guess the 3D geometry from SMILES string via
+        Guess the 3D geometry from SMILES string via conformer search.
 
         :param smiles: a valid SMILES string
         :param num_conformers: number of initial conformers
```
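For context, `split_dataset` (the function gaining the colorama calls) splits one CSV into `*_train.csv`, `*_test.csv`, and `*_val.csv` files next to the input, and the asserts visible in the hunk pin down its contract: a `.csv` path, a three-element `split_ratio`, and `method` set to `"random"` or `"scaffold"`. A hedged usage sketch; treating `split_ratio` as relative train/test/val weights is an assumption:

```python
from bayesianflow_for_chem.tool import split_dataset

# Hypothetical invocation. Parameter names come from the asserts shown above;
# the semantics of split_ratio (relative weights) are assumed, not documented here.
split_dataset("qm9.csv", split_ratio=[8, 1, 1], method="scaffold")
# Expected outputs, per the file-writing code shown: qm9_train.csv, qm9_test.csv, qm9_val.csv
```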
{bayesianflow_for_chem-2.0.2 → bayesianflow_for_chem-2.0.4}/bayesianflow_for_chem/train.py
RENAMED
```diff
@@ -134,7 +134,7 @@ class Regressor(LightningModule):
         hparam: Dict[str, Union[str, int, float, bool]] = DEFAULT_REGRESSOR_HPARAM,
     ) -> None:
         """
-        A `~lightning.LightningModule` wrapper of bayesian flow network for chemistry regression model.\n
+        A `~lightning.LightningModule` wrapper of bayesian flow network for chemistry regression or classification model.\n
         This module is used in training stage only. By calling `Regressor(...).export_model(YOUR_WORK_DIR)` after training,
         the models will be saved to `YOUR_WORK_DIR/model_ft.pt` (if LoRA is enabled then `YOUR_WORK_DIR/lora.pt`)
         and `YOUR_WORK_DIR/readout.pt`.
```
{bayesianflow_for_chem-2.0.2 → bayesianflow_for_chem-2.0.4}/bayesianflow_for_chem.egg-info/PKG-INFO
RENAMED
```diff
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: bayesianflow_for_chem
-Version: 2.0.2
+Version: 2.0.4
 Summary: Bayesian flow network framework for Chemistry
 Home-page: https://augus1999.github.io/bayesian-flow-network-for-chemistry/
 Author: Nianze A. Tao
@@ -23,6 +23,7 @@ License-File: LICENSE
 Requires-Dist: rdkit>=2025.3.5
 Requires-Dist: torch>=2.8.0
 Requires-Dist: torchao>=0.12
+Requires-Dist: colorama>=0.4.6
 Requires-Dist: numpy>=2.3.2
 Requires-Dist: scipy>=1.16.1
 Requires-Dist: loralib>=0.1.2
@@ -49,6 +50,11 @@ Dynamic: summary
 
 This is the repository of the PyTorch implementation of ChemBFN model.
 
+### Build State
+
+[](https://pypi.org/project/bayesianflow-for-chem/)
+
+
 ## Features
 
 ChemBFN provides the state-of-the-art functionalities of
```
{bayesianflow_for_chem-2.0.2 → bayesianflow_for_chem-2.0.4}/bayesianflow_for_chem.egg-info/SOURCES.txt
RENAMED
```diff
@@ -16,4 +16,6 @@ bayesianflow_for_chem.egg-info/SOURCES.txt
 bayesianflow_for_chem.egg-info/dependency_links.txt
 bayesianflow_for_chem.egg-info/entry_points.txt
 bayesianflow_for_chem.egg-info/requires.txt
-bayesianflow_for_chem.egg-info/top_level.txt
+bayesianflow_for_chem.egg-info/top_level.txt
+test/test_merge_lora.py
+test/test_molecular_embedding.py
```
bayesianflow_for_chem-2.0.4/test/test_merge_lora.py
ADDED
```diff
@@ -0,0 +1,40 @@
+# -*- coding: utf-8 -*-
+# Author: Nianze A. Tao (Omozawa Sueno)
+"""
+Model output should be almost identical before and after emerging LoRA parameters into base model.
+"""
+import torch
+from bayesianflow_for_chem import ChemBFN, MLP
+from bayesianflow_for_chem.tool import merge_lora_
+from bayesianflow_for_chem.data import VOCAB_COUNT, smiles2token, collate
+
+torch.manual_seed(8964)
+
+model = ChemBFN(VOCAB_COUNT)
+model.enable_lora(r=8)
+model.eval()
+mlp = MLP([512, 256, 3], dropout=0.7)
+mlp.eval()
+for module in model.modules():
+    if hasattr(module, "lora_B"):
+        torch.nn.init.kaiming_uniform_(module.lora_B, a=5**0.5)
+
+x = collate(
+    [{"token": smiles2token("c1ccccc1O")}, {"token": smiles2token("[NH4+]CCCCCC[O-]")}]
+)["token"]
+
+
+@torch.inference_mode()
+def test():
+    model.semi_autoregressive = False
+    y1 = model.inference(x, mlp)
+    model.semi_autoregressive = True
+    y2 = model.inference(x, mlp)
+    merge_lora_(model)
+    model.semi_autoregressive = False
+    y3 = model.inference(x, mlp)
+    model.semi_autoregressive = True
+    y4 = model.inference(x, mlp)
+    assert not model.lora_enabled
+    assert (y1 - y3).abs().mean() < 1e-6
+    assert (y2 - y4).abs().mean() < 1e-6
```
bayesianflow_for_chem-2.0.4/test/test_molecular_embedding.py
ADDED
```diff
@@ -0,0 +1,67 @@
+# -*- coding: utf-8 -*-
+# Author: Nianze A. Tao (Omozawa Sueno)
+"""
+Molecular embedding vectors should not be affected by <pad> tokens.
+"""
+from functools import partial
+import torch
+from bayesianflow_for_chem import ChemBFN, MLP
+from bayesianflow_for_chem.data import VOCAB_COUNT, smiles2token
+
+torch.manual_seed(8964)
+
+model = ChemBFN(VOCAB_COUNT)
+model.eval()
+mlp1 = MLP([512, 256, 3], dropout=0.7)
+mlp1.eval()
+mlp2 = MLP([1024, 512, 3], dropout=0.7)
+mlp2.eval()
+
+x = smiles2token("c1ccccc1O.[NH4+]CCCCCC[O-]")
+x1 = x[None, ...]
+x2 = torch.nn.functional.pad(x1, (0, 7, 0, 0))
+
+
+def embed_fn(z, sar_flag, mask, x):
+    mb0 = z[x == 2].view(z.shape[0], -1) if sar_flag else z[::, 0]
+    mb1 = (z * mask[..., None]).sum(1) / (mask != 0).float().sum(1, True)
+    return torch.cat([mb0, mb1], -1)
+
+
+@torch.inference_mode()
+def test():
+    model.semi_autoregressive = False
+    y1 = model.inference(x1, mlp1)
+    y2 = model.inference(x2, mlp1)
+    assert (y1 != y2).sum() == 0
+    model.semi_autoregressive = True
+    y1 = model.inference(x1, mlp1)
+    y2 = model.inference(x2, mlp1)
+    assert (y1 != y2).sum() == 0
+    # ------- customised embedding extraction -------
+    mask1 = torch.tensor([[0] + [0.7] * 9 + [0] + [0.3] * 16 + [0]])
+    mask2 = torch.tensor([[0] + [0.7] * 9 + [0] + [0.3] * 16 + [0] * 8])
+    model.semi_autoregressive = False
+    y1 = model.inference(
+        x1,
+        mlp2,
+        partial(embed_fn, sar_flag=model.semi_autoregressive, mask=mask1, x=x1),
+    )
+    y2 = model.inference(
+        x2,
+        mlp2,
+        partial(embed_fn, sar_flag=model.semi_autoregressive, mask=mask2, x=x2),
+    )
+    assert (y1 != y2).sum() == 0
+    model.semi_autoregressive = True
+    y1 = model.inference(
+        x1,
+        mlp2,
+        partial(embed_fn, sar_flag=model.semi_autoregressive, mask=mask1, x=x1),
+    )
+    y2 = model.inference(
+        x2,
+        mlp2,
+        partial(embed_fn, sar_flag=model.semi_autoregressive, mask=mask2, x=x2),
+    )
+    assert (y1 != y2).sum() == 0
```
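The custom `embed_fn` exercises both embedding routes: in semi-autoregressive mode it gathers the hidden state at token id 2 (assumed here to be the `<end>` token), otherwise it takes position 0, and in both cases it concatenates a mask-weighted mean pool. The pooling term is what guarantees `<pad>` invariance, since padded positions receive zero weight in both the numerator and the denominator. A tiny standalone illustration of that property:

```python
import torch

z = torch.randn(1, 4, 3)  # hypothetical hidden states: (batch, length, dim)
z_padded = torch.cat([z, torch.randn(1, 2, 3)], dim=1)  # two extra <pad> positions
mask = torch.tensor([[0.0, 1.0, 1.0, 0.0]])
mask_padded = torch.tensor([[0.0, 1.0, 1.0, 0.0, 0.0, 0.0]])


def pool(z: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    # Zero-masked weighted mean, as in embed_fn's second term.
    return (z * mask[..., None]).sum(1) / (mask != 0).float().sum(1, True)


# Padded positions carry zero weight, so the pooled vectors match exactly.
assert torch.equal(pool(z, mask), pool(z_padded, mask_padded))
```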
The remaining files listed above with +0 -0 (LICENSE, data.py, scorer.py, spectra.py, vocab.txt, the egg-info metadata files, pyproject.toml, and setup.cfg) are unchanged between 2.0.2 and 2.0.4.