bayesianflow-for-chem 2.1.0__tar.gz → 2.2.2__tar.gz
- {bayesianflow_for_chem-2.1.0 → bayesianflow_for_chem-2.2.2}/PKG-INFO +5 -2
- {bayesianflow_for_chem-2.1.0 → bayesianflow_for_chem-2.2.2}/README.md +3 -0
- {bayesianflow_for_chem-2.1.0 → bayesianflow_for_chem-2.2.2}/bayesianflow_for_chem/__init__.py +4 -3
- {bayesianflow_for_chem-2.1.0 → bayesianflow_for_chem-2.2.2}/bayesianflow_for_chem/cli.py +110 -34
- {bayesianflow_for_chem-2.1.0 → bayesianflow_for_chem-2.2.2}/bayesianflow_for_chem/data.py +23 -22
- {bayesianflow_for_chem-2.1.0 → bayesianflow_for_chem-2.2.2}/bayesianflow_for_chem/model.py +131 -154
- {bayesianflow_for_chem-2.1.0 → bayesianflow_for_chem-2.2.2}/bayesianflow_for_chem/tool.py +100 -127
- {bayesianflow_for_chem-2.1.0 → bayesianflow_for_chem-2.2.2}/bayesianflow_for_chem/train.py +8 -1
- {bayesianflow_for_chem-2.1.0 → bayesianflow_for_chem-2.2.2}/bayesianflow_for_chem.egg-info/PKG-INFO +5 -2
- {bayesianflow_for_chem-2.1.0 → bayesianflow_for_chem-2.2.2}/bayesianflow_for_chem.egg-info/SOURCES.txt +2 -1
- {bayesianflow_for_chem-2.1.0 → bayesianflow_for_chem-2.2.2}/pyproject.toml +3 -0
- {bayesianflow_for_chem-2.1.0 → bayesianflow_for_chem-2.2.2}/setup.py +1 -1
- bayesianflow_for_chem-2.2.2/test/test_cli_plugin.py +55 -0
- {bayesianflow_for_chem-2.1.0 → bayesianflow_for_chem-2.2.2}/LICENSE +0 -0
- {bayesianflow_for_chem-2.1.0/bayesianflow_for_chem → bayesianflow_for_chem-2.2.2/bayesianflow_for_chem/_data}/vocab.txt +0 -0
- {bayesianflow_for_chem-2.1.0 → bayesianflow_for_chem-2.2.2}/bayesianflow_for_chem/scorer.py +0 -0
- {bayesianflow_for_chem-2.1.0 → bayesianflow_for_chem-2.2.2}/bayesianflow_for_chem/spectra.py +0 -0
- {bayesianflow_for_chem-2.1.0 → bayesianflow_for_chem-2.2.2}/bayesianflow_for_chem.egg-info/dependency_links.txt +0 -0
- {bayesianflow_for_chem-2.1.0 → bayesianflow_for_chem-2.2.2}/bayesianflow_for_chem.egg-info/entry_points.txt +0 -0
- {bayesianflow_for_chem-2.1.0 → bayesianflow_for_chem-2.2.2}/bayesianflow_for_chem.egg-info/requires.txt +0 -0
- {bayesianflow_for_chem-2.1.0 → bayesianflow_for_chem-2.2.2}/bayesianflow_for_chem.egg-info/top_level.txt +0 -0
- {bayesianflow_for_chem-2.1.0 → bayesianflow_for_chem-2.2.2}/setup.cfg +0 -0
- {bayesianflow_for_chem-2.1.0 → bayesianflow_for_chem-2.2.2}/test/test_jit_compatibility.py +0 -0
- {bayesianflow_for_chem-2.1.0 → bayesianflow_for_chem-2.2.2}/test/test_merge_lora.py +0 -0
- {bayesianflow_for_chem-2.1.0 → bayesianflow_for_chem-2.2.2}/test/test_molecular_embedding.py +0 -0
{bayesianflow_for_chem-2.1.0 → bayesianflow_for_chem-2.2.2}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: bayesianflow_for_chem
-Version: 2.1.0
+Version: 2.2.2
 Summary: Bayesian flow network framework for Chemistry
 Home-page: https://augus1999.github.io/bayesian-flow-network-for-chemistry/
 Author: Nianze A. Tao
@@ -54,12 +54,14 @@ This is the repository of the PyTorch implementation of ChemBFN model.

 [](https://pypi.org/project/bayesianflow-for-chem/)

+[](https://augus1999.github.io/bayesian-flow-network-for-chemistry/)

 ## Features

 ChemBFN provides the state-of-the-art functionalities of
 * SMILES or SELFIES-based *de novo* molecule generation
 * Protein sequence *de novo* generation
+* Template optimisation (mol2mol)
 * Classifier-free guidance conditional generation (single or multi-objective optimisation)
 * Context-guided conditional generation (inpaint)
 * Outstanding out-of-distribution chemical space sampling
@@ -71,6 +73,7 @@ in an all-in-one-model style.

 ## News

+* [09/10/2025] A web app [`chembfn_webui`](https://github.com/Augus1999/ChemBFN-WebUI) for hosting ChemBFN models is available on [PyPI](https://pypi.org/project/chembfn-webui/).
 * [30/01/2025] The package `bayesianflow_for_chem` is available on [PyPI](https://pypi.org/project/bayesianflow-for-chem/).
 * [21/01/2025] Our first paper has been accepted by [JCIM](https://pubs.acs.org/doi/10.1021/acs.jcim.4c01792).
 * [17/12/2024] The second paper of out-of-distribution generation is available on [arxiv.org](https://arxiv.org/abs/2412.11439).
@@ -93,7 +96,7 @@ You can find pretrained models on our [🤗Hugging Face model page](https://hugg

 ## Dataset Handling

-We provide a Python class [`CSVData`](https://github.com/Augus1999/bayesian-flow-network-for-chemistry/blob/main/bayesianflow_for_chem/data.py#
+We provide a Python class [`CSVData`](https://github.com/Augus1999/bayesian-flow-network-for-chemistry/blob/main/bayesianflow_for_chem/data.py#L153) to handle data stored in CSV or similar format containing headers to identify the entities. The following is a quickstart.

 1. Download your dataset file (e.g., ESOL from [MoleculeNet](https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/delaney-processed.csv)) and split the file:
 ```python
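The quickstart cut off above continues in the README itself. As a rough sketch of step 1, assuming a simple 9:1 random split (the actual quickstart may split differently) and using only the standard csv module plus the `CSVData` class named in this diff:

```python
# Hypothetical step-1 sketch: split delaney-processed.csv (ESOL) into train/test CSVs,
# keeping the header row, then wrap each split with CSVData.
import csv
import random
from bayesianflow_for_chem.data import CSVData

with open("delaney-processed.csv", newline="") as f:
    rows = list(csv.reader(f))
header, body = rows[0], rows[1:]
random.seed(0)
random.shuffle(body)
cut = int(0.9 * len(body))

for name, part in (("train.csv", body[:cut]), ("test.csv", body[cut:])):
    with open(name, "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(header)   # CSVData relies on the header to identify entities
        writer.writerows(part)

train_data = CSVData("train.csv")  # header-aware dataset wrapper used later by the CLI
test_data = CSVData("test.csv")
```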
{bayesianflow_for_chem-2.1.0 → bayesianflow_for_chem-2.2.2}/README.md
@@ -9,12 +9,14 @@ This is the repository of the PyTorch implementation of ChemBFN model.

 [](https://pypi.org/project/bayesianflow-for-chem/)

+[](https://augus1999.github.io/bayesian-flow-network-for-chemistry/)

 ## Features

 ChemBFN provides the state-of-the-art functionalities of
 * SMILES or SELFIES-based *de novo* molecule generation
 * Protein sequence *de novo* generation
+* Template optimisation (mol2mol)
 * Classifier-free guidance conditional generation (single or multi-objective optimisation)
 * Context-guided conditional generation (inpaint)
 * Outstanding out-of-distribution chemical space sampling
@@ -26,6 +28,7 @@ in an all-in-one-model style.

 ## News

+* [09/10/2025] A web app [`chembfn_webui`](https://github.com/Augus1999/ChemBFN-WebUI) for hosting ChemBFN models is available on [PyPI](https://pypi.org/project/chembfn-webui/).
 * [30/01/2025] The package `bayesianflow_for_chem` is available on [PyPI](https://pypi.org/project/bayesianflow-for-chem/).
 * [21/01/2025] Our first paper has been accepted by [JCIM](https://pubs.acs.org/doi/10.1021/acs.jcim.4c01792).
 * [17/12/2024] The second paper of out-of-distribution generation is available on [arxiv.org](https://arxiv.org/abs/2412.11439).
{bayesianflow_for_chem-2.1.0 → bayesianflow_for_chem-2.2.2}/bayesianflow_for_chem/__init__.py
RENAMED
@@ -3,10 +3,8 @@
 """
 ChemBFN package.
 """
-import colorama
 from . import data, tool, train, scorer, spectra
 from .model import ChemBFN, MLP, EnsembleChemBFN
-from .cli import main_script

 __all__ = [
     "data",
@@ -18,7 +16,7 @@ __all__ = [
     "MLP",
     "EnsembleChemBFN",
 ]
-__version__ = "2.1.0"
+__version__ = "2.2.2"
 __author__ = "Nianze A. Tao (Omozawa Sueno)"


@@ -29,6 +27,9 @@ def main() -> None:
     :return:
     :rtype: None
     """
+    import colorama
+    from bayesianflow_for_chem.cli import main_script
+
     colorama.just_fix_windows_console()
     main_script(__version__)
     colorama.deinit()
{bayesianflow_for_chem-2.1.0 → bayesianflow_for_chem-2.2.2}/bayesianflow_for_chem/cli.py
@@ -12,22 +12,17 @@ from pathlib import Path
 from functools import partial
 from typing import List, Tuple, Dict, Union, Callable
 import torch
-import lightning as L
 from rdkit.Chem import MolFromSmiles, CanonSmiles
-from torch.utils.data import DataLoader
-from lightning.pytorch import loggers
-from lightning.pytorch.callbacks import ModelCheckpoint
 from bayesianflow_for_chem import ChemBFN, MLP
-from bayesianflow_for_chem.train import Model
 from bayesianflow_for_chem.scorer import smiles_valid, Scorer
 from bayesianflow_for_chem.data import (
     VOCAB_COUNT,
     VOCAB_KEYS,
-
-
+    FASTA_VOCAB_COUNT,
+    FASTA_VOCAB_KEYS,
     load_vocab,
     smiles2token,
-
+    fasta2token,
     split_selfies,
     collate,
     CSVData,
@@ -90,6 +85,7 @@ checkpoint_save_path = "home/user/project/ckpt"
 train_strategy = "auto" # or any strategy supported by Lightning, e.g., "ddp"
 accumulate_grad_batches = 1
 enable_progress_bar = false
+plugin_script = "" # define customised behaviours of dataset, datasetloader, etc in a python script

 # Remove this table if inference is unnecessary
 [inference]
@@ -120,6 +116,49 @@ madmadmadmadmadmadmadmadmadmadmadmadmadmadmad
 madmadmadmadmadmadmadmadmadmadmadmadmadmadmad
 """

+_END_MESSAGE = r"""
+If you find this project helpful, please cite us:
+1. N. Tao, and M. Abe, J. Chem. Inf. Model., 2025, 65, 1178-1187.
+2. N. Tao, 2024, arXiv:2412.11439.
+"""
+
+_ERROR_MESSAGE = r"""
+Some who believe in inductive logic are anxious to point out, with
+Reichenbach, that 'the principle of induction is unreservedly accepted
+by the whole of science and that no man can seriously doubt this
+principle in everyday life either'. Yet even supposing this were the
+case—for after all, 'the whole of science' might err—I should still
+contend that a principle of induction is superfluous, and that it must
+lead to logical inconsistencies.
+                                                  -- Karl Popper --
+"""
+
+_ALLOWED_PLUGINS = [
+    "collate_fn",
+    "num_workers",
+    "max_sequence_length",
+    "shuffle",
+    "CustomData",
+]
+
+
+def _load_plugin(plugin_file: str) -> Dict[str, Union[int, Callable, object, None]]:
+    if not plugin_file:
+        return {n: None for n in _ALLOWED_PLUGINS}
+    from importlib import util as iutil
+
+    spec = iutil.spec_from_file_location(Path(plugin_file).stem, plugin_file)
+    plugins = iutil.module_from_spec(spec)
+    spec.loader.exec_module(plugins)
+    plugin_names: List[str] = plugins.__all__
+    plugin_dict = {}
+    for n in _ALLOWED_PLUGINS:
+        if n in plugin_names:
+            plugin_dict[n] = getattr(plugins, n)
+        else:
+            plugin_dict[n] = None
+    return plugin_dict
+

 def parse_cli(version: str) -> argparse.Namespace:
     """
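The new `plugin_script` key in the `[train]` table points at an ordinary Python file. `_load_plugin` imports that file and collects only the names that are listed in its `__all__` and whitelisted in `_ALLOWED_PLUGINS`. Below is a minimal sketch of such a plugin, assuming a file called `my_plugin.py`; the concrete override values and the pass-through `collate_fn` are illustrative, not part of the package:

```python
# my_plugin.py -- hypothetical plugin script referenced by plugin_script = "my_plugin.py".
# Only names exported via __all__ and whitelisted in _ALLOWED_PLUGINS are used by the CLI.
from typing import Dict, List
from torch import Tensor
from bayesianflow_for_chem.data import CSVData, collate

__all__ = ["num_workers", "shuffle", "max_sequence_length", "collate_fn", "CustomData"]

num_workers = 0            # replaces the default of 4 DataLoader workers
shuffle = False            # keep the dataset order during training
max_sequence_length = 128  # skips the full pass over the CSV that otherwise infers lmax


def collate_fn(batch: List[Dict[str, Tensor]]) -> Dict[str, Tensor]:
    # Fall back to the package's own collate; a real plugin could pad or filter differently.
    return collate(batch)


class CustomData(CSVData):
    """Drop-in dataset; the CLI instantiates it as CustomData(dataset_file)."""
```

With `plugin_script = "my_plugin.py"` set in the `[train]` table, the DataLoader built later in `main_script` (see the `@@ -428,30 +489,43 @@` hunk below) uses these values in place of the built-in defaults.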
@@ -267,6 +306,14 @@ def load_runtime_config(
                 f"\033[0;31mCritical\033[0;0m in {config_file}: Restart checkpoint file {ckpt_file} does not exist."
             )
             flag_critical += 1
+        # ↓ added in v2.2.0; need to be compatible with old versions.
+        plugin_script: str = config["train"].get("plugin_script", "")
+        if plugin_script:
+            if not os.path.exists(plugin_script):
+                print(
+                    f"\033[0;31mCritical\033[0;0m in {config_file}: Plugin script {plugin_script} does not exist."
+                )
+                flag_critical += 1
     if "inference" in config:
         if not "train" in config:
             if not isinstance(config["inference"]["sequence_length"], int):
@@ -335,14 +382,15 @@ def main_script(version: str) -> None:
             )
             flag_warning += 1
         if not os.path.exists(runtime_config["train"]["checkpoint_save_path"]):
-            os.makedirs(runtime_config["train"]["checkpoint_save_path"])
+            if not parser.dryrun: # only create it in real tasks
+                os.makedirs(runtime_config["train"]["checkpoint_save_path"])
     else:
         if not model_config["ChemBFN"]["base_model"]:
             print(
                 f"\033[0;33mWarning\033[0;0m in {parser.model_config}: You should load a pretrained ChemBFN model."
             )
             flag_warning += 1
-        if not model_config["MLP"]["base_model"]:
+        if "MLP" in model_config and not model_config["MLP"]["base_model"]:
             print(
                 f"\033[0;33mWarning\033[0;0m in {parser.model_config}: You should load a pretrained MLP."
             )
@@ -350,18 +398,22 @@ def main_script(version: str) -> None:
     if "inference" in runtime_config:
         if runtime_config["inference"]["guidance_objective"]:
             if not "MLP" in model_config:
-                print(
+                print(
+                    f"\033[0;33mWarning\033[0;0m in {parser.model_config}: Oh no, you don't have a MLP."
+                )
                 flag_warning += 1
     if parser.dryrun:
         if flag_critical != 0:
             print("Configuration check failed!")
         elif flag_warning != 0:
-            print(
+            print(
+                "Your job will probably run, but it may not follow your expectations."
+            )
         else:
             print("Configuration check passed.")
         return
     if flag_critical != 0:
-        raise RuntimeError
+        raise RuntimeError(_ERROR_MESSAGE)
     print(_MESSAGE.format(version))
     # ####### build tokeniser #######
     tokeniser_config = runtime_config["tokeniser"]
@@ -371,9 +423,9 @@ def main_script(version: str) -> None:
         vocab_keys = VOCAB_KEYS
         tokeniser = smiles2token
     if tokeniser_name == "fasta":
-        num_vocab =
-        vocab_keys =
-        tokeniser =
+        num_vocab = FASTA_VOCAB_COUNT
+        vocab_keys = FASTA_VOCAB_KEYS
+        tokeniser = fasta2token
     if tokeniser_name == "selfies":
         vocab_data = load_vocab(tokeniser_config["vocab"])
         num_vocab = vocab_data["vocab_count"]
@@ -415,6 +467,15 @@ def main_script(version: str) -> None:
     mlp = None
     # ------- train -------
     if "train" in runtime_config:
+        import lightning as L
+        from torch.utils.data import DataLoader
+        from lightning.pytorch import loggers
+        from lightning.pytorch.callbacks import ModelCheckpoint
+        from bayesianflow_for_chem.train import Model
+
+        # ####### get plugins #######
+        plugin_file = runtime_config["train"].get("plugin_script", "")
+        plugins = _load_plugin(plugin_file)
         # ####### build scorer #######
         if (tokeniser_name == "smiles" or tokeniser_name == "safe") and runtime_config[
             "train"
@@ -428,30 +489,43 @@ def main_script(version: str) -> None:
         mol_tag = runtime_config["train"]["molecule_tag"]
         obj_tag = runtime_config["train"]["objective_tag"]
         dataset_file = runtime_config["train"]["dataset"]
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        if plugins["max_sequence_length"]:
+            lmax = plugins["max_sequence_length"]
+        else:
+            with open(dataset_file, "r") as db:
+                _data = db.readlines()
+            _header = _data[0]
+            _mol_idx = []
+            for i, tag in enumerate(_header.replace("\n", "").split(",")):
+                if tag == mol_tag:
+                    _mol_idx.append(i)
+            _data_len = []
+            for i in _data[1:]:
+                i = i.replace("\n", "").split(",")
+                _mol = ".".join([i[j] for j in _mol_idx])
+                _data_len.append(tokeniser(_mol).shape[-1])
+            lmax = max(_data_len)
+            del _data, _data_len, _header, _mol_idx # clear memory
+        if plugins["CustomData"] is not None:
+            dataset = plugins["CustomData"](dataset_file)
+        else:
+            dataset = CSVData(dataset_file)
         dataset.map(
             partial(_encode, mol_tag=mol_tag, obj_tag=obj_tag, tokeniser=tokeniser)
         )
         dataloader = DataLoader(
             dataset,
             runtime_config["train"]["batch_size"],
-            True,
-            num_workers=4,
-            collate_fn=
-
+            True if plugins["shuffle"] is None else plugins["shuffle"],
+            num_workers=4 if plugins["num_workers"] is None else plugins["num_workers"],
+            collate_fn=(
+                collate if plugins["collate_fn"] is None else plugins["collate_fn"]
+            ),
+            persistent_workers=(
+                True
+                if (plugins["num_workers"] is None or plugins["num_workers"] > 0)
+                else False
+            ),
         )
         # ####### build trainer #######
         logger_name = runtime_config["train"]["logger_name"].lower()
@@ -530,6 +604,7 @@ def main_script(version: str) -> None:
     if "train" in runtime_config:
         bfn = model.model
         mlp = model.mlp
+    # ↓ added in v2.1.0; need to be compatible with old versions
     lora_scaling = runtime_config["inference"].get("lora_scaling", 1.0)
     # ####### strat inference #######
     bfn.semi_autoregressive = runtime_config["inference"]["semi_autoregressive"]
@@ -622,6 +697,7 @@ def main_script(version: str) -> None:
         f.write("\n".join(mols))
     # ------- finished -------
     print(" ####### job finished #######")
+    print(_END_MESSAGE)


 if __name__ == "__main__":
{bayesianflow_for_chem-2.1.0 → bayesianflow_for_chem-2.2.2}/bayesianflow_for_chem/data.py
@@ -1,7 +1,7 @@
 # -*- coding: utf-8 -*-
 # Author: Nianze A. TAO (Omozawa SUENO)
 """
-Tokenise SMILES/SAFE/SELFIES/
+Tokenise SMILES/SAFE/SELFIES/FASTA strings.
 """
 import os
 import re
@@ -14,7 +14,7 @@ from torch.utils.data import Dataset

 __filedir__ = Path(__file__).parent

-SMI_REGEX_PATTERN = (
+_SMI_REGEX_PATTERN = (
     r"(\[|\]|H[e,f,g,s,o]?|"
     r"L[i,v,a,r,u]|"
     r"B[e,r,a,i,h,k]?|"
@@ -31,11 +31,11 @@ SMI_REGEX_PATTERN = (
     r"\(|\)|\.|=|#|-|\+|\\|\/|:|"
     r"~|@|\?|>>?|\*|\$|\%[0-9]{2}|[0-9])"
 )
-
-
-
-
-
+_SEL_REGEX_PATTERN = r"(\[[^\]]+]|\.)"
+_FAS_REGEX_PATTERN = r"(A|B|C|D|E|F|G|H|I|J|K|L|M|N|O|P|Q|R|S|T|U|V|W|X|Y|Z|-|\*|\.)"
+_smi_regex = re.compile(_SMI_REGEX_PATTERN)
+_sel_regex = re.compile(_SEL_REGEX_PATTERN)
+_fas_regex = re.compile(_FAS_REGEX_PATTERN)


 def load_vocab(
@@ -61,15 +61,16 @@ def load_vocab(
     }


-_DEFUALT_VOCAB = load_vocab(__filedir__ / "vocab.txt")
+_DEFUALT_VOCAB = load_vocab(__filedir__ / "_data/vocab.txt")
 VOCAB_KEYS: List[str] = _DEFUALT_VOCAB["vocab_keys"]
 VOCAB_DICT: Dict[str, int] = _DEFUALT_VOCAB["vocab_dict"]
 VOCAB_COUNT: int = _DEFUALT_VOCAB["vocab_count"]
-
-    VOCAB_KEYS[0:3]
+FASTA_VOCAB_KEYS = (
+    VOCAB_KEYS[0:3]
+    + "A B C D E F G H I K L M N P Q R S T V W Y Z - . J O U X *".split()
 )
-
-
+FASTA_VOCAB_COUNT = len(FASTA_VOCAB_KEYS)
+FASTA_VOCAB_DICT = dict(zip(FASTA_VOCAB_KEYS, range(FASTA_VOCAB_COUNT)))


 def smiles2vec(smiles: str) -> List[int]:
@@ -81,21 +82,21 @@ def smiles2vec(smiles: str) -> List[int]:
     :return: tokens w/o `<start>` and `<end>`
     :rtype: list
     """
-    tokens = [token for token in
+    tokens = [token for token in _smi_regex.findall(smiles)]
     return [VOCAB_DICT[token] for token in tokens]


-def
+def fasta2vec(fasta: str) -> List[int]:
     """
-
+    FASTA sequence tokenisation using a dataset-independent regex pattern.

-    :param
-    :type
+    :param fasta: protein (amino acid) sequence
+    :type fasta: str
     :return: tokens w/o `<start>` and `<end>`
     :rtype: list
     """
-    tokens = [token for token in
-    return [
+    tokens = [token for token in _fas_regex.findall(fasta)]
+    return [FASTA_VOCAB_DICT[token] for token in tokens]


 def split_selfies(selfies: str) -> List[str]:
@@ -107,7 +108,7 @@ def split_selfies(selfies: str) -> List[str]:
     :return: SELFIES vocab
     :rtype: list
     """
-    return [token for token in
+    return [token for token in _sel_regex.findall(selfies)]


 def smiles2token(smiles: str) -> Tensor:
@@ -115,9 +116,9 @@ def smiles2token(smiles: str) -> Tensor:
     return torch.tensor([1] + smiles2vec(smiles) + [2], dtype=torch.long)


-def
+def fasta2token(fasta: str) -> Tensor:
     # start token: <start> = 1; end token: <end> = 2
-    return torch.tensor([1] +
+    return torch.tensor([1] + fasta2vec(fasta) + [2], dtype=torch.long)


 def collate(batch: List[Dict[str, Tensor]]) -> Dict[str, Tensor]: