bayesianflow-for-chem 2.1.0__py3-none-any.whl → 2.2.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of bayesianflow-for-chem might be problematic.
- bayesianflow_for_chem/__init__.py +4 -3
- bayesianflow_for_chem/cli.py +110 -34
- bayesianflow_for_chem/data.py +23 -22
- bayesianflow_for_chem/model.py +131 -154
- bayesianflow_for_chem/tool.py +100 -127
- bayesianflow_for_chem/train.py +8 -1
- {bayesianflow_for_chem-2.1.0.dist-info → bayesianflow_for_chem-2.2.2.dist-info}/METADATA +5 -2
- bayesianflow_for_chem-2.2.2.dist-info/RECORD +15 -0
- bayesianflow_for_chem-2.1.0.dist-info/RECORD +0 -15
- /bayesianflow_for_chem/{vocab.txt → _data/vocab.txt} +0 -0
- {bayesianflow_for_chem-2.1.0.dist-info → bayesianflow_for_chem-2.2.2.dist-info}/WHEEL +0 -0
- {bayesianflow_for_chem-2.1.0.dist-info → bayesianflow_for_chem-2.2.2.dist-info}/entry_points.txt +0 -0
- {bayesianflow_for_chem-2.1.0.dist-info → bayesianflow_for_chem-2.2.2.dist-info}/licenses/LICENSE +0 -0
- {bayesianflow_for_chem-2.1.0.dist-info → bayesianflow_for_chem-2.2.2.dist-info}/top_level.txt +0 -0
bayesianflow_for_chem/__init__.py
CHANGED
@@ -3,10 +3,8 @@
 """
 ChemBFN package.
 """
-import colorama
 from . import data, tool, train, scorer, spectra
 from .model import ChemBFN, MLP, EnsembleChemBFN
-from .cli import main_script
 
 __all__ = [
     "data",
@@ -18,7 +16,7 @@ __all__ = [
     "MLP",
     "EnsembleChemBFN",
 ]
-__version__ = "2.
+__version__ = "2.2.2"
 __author__ = "Nianze A. Tao (Omozawa Sueno)"
 
 
@@ -29,6 +27,9 @@ def main() -> None:
     :return:
     :rtype: None
     """
+    import colorama
+    from bayesianflow_for_chem.cli import main_script
+
    colorama.just_fix_windows_console()
    main_script(__version__)
    colorama.deinit()
bayesianflow_for_chem/cli.py
CHANGED
@@ -12,22 +12,17 @@ from pathlib import Path
 from functools import partial
 from typing import List, Tuple, Dict, Union, Callable
 import torch
-import lightning as L
 from rdkit.Chem import MolFromSmiles, CanonSmiles
-from torch.utils.data import DataLoader
-from lightning.pytorch import loggers
-from lightning.pytorch.callbacks import ModelCheckpoint
 from bayesianflow_for_chem import ChemBFN, MLP
-from bayesianflow_for_chem.train import Model
 from bayesianflow_for_chem.scorer import smiles_valid, Scorer
 from bayesianflow_for_chem.data import (
     VOCAB_COUNT,
     VOCAB_KEYS,
-
-
+    FASTA_VOCAB_COUNT,
+    FASTA_VOCAB_KEYS,
     load_vocab,
     smiles2token,
-
+    fasta2token,
     split_selfies,
     collate,
     CSVData,
@@ -90,6 +85,7 @@ checkpoint_save_path = "home/user/project/ckpt"
 train_strategy = "auto" # or any strategy supported by Lightning, e.g., "ddp"
 accumulate_grad_batches = 1
 enable_progress_bar = false
+plugin_script = "" # define customised behaviours of dataset, datasetloader, etc in a python script
 
 # Remove this table if inference is unnecessary
 [inference]
@@ -120,6 +116,49 @@ madmadmadmadmadmadmadmadmadmadmadmadmadmadmad
 madmadmadmadmadmadmadmadmadmadmadmadmadmadmad
 """
 
+_END_MESSAGE = r"""
+If you find this project helpful, please cite us:
+1. N. Tao, and M. Abe, J. Chem. Inf. Model., 2025, 65, 1178-1187.
+2. N. Tao, 2024, arXiv:2412.11439.
+"""
+
+_ERROR_MESSAGE = r"""
+Some who believe in inductive logic are anxious to point out, with
+Reichenbach, that 'the principle of induction is unreservedly accepted
+by the whole of science and that no man can seriously doubt this
+principle in everyday life either'. Yet even supposing this were the
+case—for after all, 'the whole of science' might err—I should still
+contend that a principle of induction is superfluous, and that it must
+lead to logical inconsistencies.
+                                            -- Karl Popper --
+"""
+
+_ALLOWED_PLUGINS = [
+    "collate_fn",
+    "num_workers",
+    "max_sequence_length",
+    "shuffle",
+    "CustomData",
+]
+
+
+def _load_plugin(plugin_file: str) -> Dict[str, Union[int, Callable, object, None]]:
+    if not plugin_file:
+        return {n: None for n in _ALLOWED_PLUGINS}
+    from importlib import util as iutil
+
+    spec = iutil.spec_from_file_location(Path(plugin_file).stem, plugin_file)
+    plugins = iutil.module_from_spec(spec)
+    spec.loader.exec_module(plugins)
+    plugin_names: List[str] = plugins.__all__
+    plugin_dict = {}
+    for n in _ALLOWED_PLUGINS:
+        if n in plugin_names:
+            plugin_dict[n] = getattr(plugins, n)
+        else:
+            plugin_dict[n] = None
+    return plugin_dict
+
 
 def parse_cli(version: str) -> argparse.Namespace:
     """
@@ -267,6 +306,14 @@ def load_runtime_config(
                 f"\033[0;31mCritical\033[0;0m in {config_file}: Restart checkpoint file {ckpt_file} does not exist."
             )
            flag_critical += 1
+        # ↓ added in v2.2.0; need to be compatible with old versions.
+        plugin_script: str = config["train"].get("plugin_script", "")
+        if plugin_script:
+            if not os.path.exists(plugin_script):
+                print(
+                    f"\033[0;31mCritical\033[0;0m in {config_file}: Plugin script {plugin_script} does not exist."
+                )
+                flag_critical += 1
    if "inference" in config:
        if not "train" in config:
            if not isinstance(config["inference"]["sequence_length"], int):
@@ -335,14 +382,15 @@ def main_script(version: str) -> None:
            )
            flag_warning += 1
        if not os.path.exists(runtime_config["train"]["checkpoint_save_path"]):
-
+            if not parser.dryrun:  # only create it in real tasks
+                os.makedirs(runtime_config["train"]["checkpoint_save_path"])
    else:
        if not model_config["ChemBFN"]["base_model"]:
            print(
                f"\033[0;33mWarning\033[0;0m in {parser.model_config}: You should load a pretrained ChemBFN model."
            )
            flag_warning += 1
-        if not model_config["MLP"]["base_model"]:
+        if "MLP" in model_config and not model_config["MLP"]["base_model"]:
            print(
                f"\033[0;33mWarning\033[0;0m in {parser.model_config}: You should load a pretrained MLP."
            )
@@ -350,18 +398,22 @@ def main_script(version: str) -> None:
    if "inference" in runtime_config:
        if runtime_config["inference"]["guidance_objective"]:
            if not "MLP" in model_config:
-                print(
+                print(
+                    f"\033[0;33mWarning\033[0;0m in {parser.model_config}: Oh no, you don't have a MLP."
+                )
                flag_warning += 1
    if parser.dryrun:
        if flag_critical != 0:
            print("Configuration check failed!")
        elif flag_warning != 0:
-            print(
+            print(
+                "Your job will probably run, but it may not follow your expectations."
+            )
        else:
            print("Configuration check passed.")
        return
    if flag_critical != 0:
-        raise RuntimeError
+        raise RuntimeError(_ERROR_MESSAGE)
    print(_MESSAGE.format(version))
    # ####### build tokeniser #######
    tokeniser_config = runtime_config["tokeniser"]
@@ -371,9 +423,9 @@ def main_script(version: str) -> None:
        vocab_keys = VOCAB_KEYS
        tokeniser = smiles2token
    if tokeniser_name == "fasta":
-        num_vocab =
-        vocab_keys =
-        tokeniser =
+        num_vocab = FASTA_VOCAB_COUNT
+        vocab_keys = FASTA_VOCAB_KEYS
+        tokeniser = fasta2token
    if tokeniser_name == "selfies":
        vocab_data = load_vocab(tokeniser_config["vocab"])
        num_vocab = vocab_data["vocab_count"]
@@ -415,6 +467,15 @@ def main_script(version: str) -> None:
        mlp = None
    # ------- train -------
    if "train" in runtime_config:
+        import lightning as L
+        from torch.utils.data import DataLoader
+        from lightning.pytorch import loggers
+        from lightning.pytorch.callbacks import ModelCheckpoint
+        from bayesianflow_for_chem.train import Model
+
+        # ####### get plugins #######
+        plugin_file = runtime_config["train"].get("plugin_script", "")
+        plugins = _load_plugin(plugin_file)
        # ####### build scorer #######
        if (tokeniser_name == "smiles" or tokeniser_name == "safe") and runtime_config[
            "train"
@@ -428,30 +489,43 @@ def main_script(version: str) -> None:
        mol_tag = runtime_config["train"]["molecule_tag"]
        obj_tag = runtime_config["train"]["objective_tag"]
        dataset_file = runtime_config["train"]["dataset"]
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        if plugins["max_sequence_length"]:
+            lmax = plugins["max_sequence_length"]
+        else:
+            with open(dataset_file, "r") as db:
+                _data = db.readlines()
+            _header = _data[0]
+            _mol_idx = []
+            for i, tag in enumerate(_header.replace("\n", "").split(",")):
+                if tag == mol_tag:
+                    _mol_idx.append(i)
+            _data_len = []
+            for i in _data[1:]:
+                i = i.replace("\n", "").split(",")
+                _mol = ".".join([i[j] for j in _mol_idx])
+                _data_len.append(tokeniser(_mol).shape[-1])
+            lmax = max(_data_len)
+            del _data, _data_len, _header, _mol_idx  # clear memory
+        if plugins["CustomData"] is not None:
+            dataset = plugins["CustomData"](dataset_file)
+        else:
+            dataset = CSVData(dataset_file)
        dataset.map(
            partial(_encode, mol_tag=mol_tag, obj_tag=obj_tag, tokeniser=tokeniser)
        )
        dataloader = DataLoader(
            dataset,
            runtime_config["train"]["batch_size"],
-            True,
-            num_workers=4,
-            collate_fn=
-
+            True if plugins["shuffle"] is None else plugins["shuffle"],
+            num_workers=4 if plugins["num_workers"] is None else plugins["num_workers"],
+            collate_fn=(
+                collate if plugins["collate_fn"] is None else plugins["collate_fn"]
+            ),
+            persistent_workers=(
+                True
+                if (plugins["num_workers"] is None or plugins["num_workers"] > 0)
+                else False
+            ),
        )
        # ####### build trainer #######
        logger_name = runtime_config["train"]["logger_name"].lower()
@@ -530,6 +604,7 @@ def main_script(version: str) -> None:
    if "train" in runtime_config:
        bfn = model.model
        mlp = model.mlp
+    # ↓ added in v2.1.0; need to be compatible with old versions
    lora_scaling = runtime_config["inference"].get("lora_scaling", 1.0)
    # ####### strat inference #######
    bfn.semi_autoregressive = runtime_config["inference"]["semi_autoregressive"]
@@ -622,6 +697,7 @@ def main_script(version: str) -> None:
            f.write("\n".join(mols))
    # ------- finished -------
    print(" ####### job finished #######")
+    print(_END_MESSAGE)
 
 
 if __name__ == "__main__":
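Note on the new `plugin_script` hook above: `_load_plugin` imports the given file, reads its `__all__`, and only picks up the names listed in `_ALLOWED_PLUGINS` (`collate_fn`, `num_workers`, `max_sequence_length`, `shuffle`, `CustomData`); anything not provided falls back to the CLI defaults. A minimal sketch of such a plugin file — the file name and the values below are illustrative, not part of the package:

```python
# my_plugins.py -- hypothetical plugin script referenced by `plugin_script` in the
# [train] table. Only names listed in __all__ *and* in _ALLOWED_PLUGINS are used;
# everything omitted keeps the built-in behaviour (e.g. the default collate function).
__all__ = ["num_workers", "max_sequence_length", "shuffle"]

num_workers = 8            # overrides the default of 4 dataloader workers
max_sequence_length = 128  # skips the CSV scan that otherwise computes lmax
shuffle = True             # explicit, same as the default
```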
bayesianflow_for_chem/data.py
CHANGED
@@ -1,7 +1,7 @@
 # -*- coding: utf-8 -*-
 # Author: Nianze A. TAO (Omozawa SUENO)
 """
-Tokenise SMILES/SAFE/SELFIES/
+Tokenise SMILES/SAFE/SELFIES/FASTA strings.
 """
 import os
 import re
@@ -14,7 +14,7 @@ from torch.utils.data import Dataset
 
 __filedir__ = Path(__file__).parent
 
-
+_SMI_REGEX_PATTERN = (
     r"(\[|\]|H[e,f,g,s,o]?|"
     r"L[i,v,a,r,u]|"
     r"B[e,r,a,i,h,k]?|"
@@ -31,11 +31,11 @@ SMI_REGEX_PATTERN = (
     r"\(|\)|\.|=|#|-|\+|\\|\/|:|"
     r"~|@|\?|>>?|\*|\$|\%[0-9]{2}|[0-9])"
 )
-
-
-
-
-
+_SEL_REGEX_PATTERN = r"(\[[^\]]+]|\.)"
+_FAS_REGEX_PATTERN = r"(A|B|C|D|E|F|G|H|I|J|K|L|M|N|O|P|Q|R|S|T|U|V|W|X|Y|Z|-|\*|\.)"
+_smi_regex = re.compile(_SMI_REGEX_PATTERN)
+_sel_regex = re.compile(_SEL_REGEX_PATTERN)
+_fas_regex = re.compile(_FAS_REGEX_PATTERN)
 
 
 def load_vocab(
@@ -61,15 +61,16 @@ def load_vocab(
 }
 
 
-_DEFUALT_VOCAB = load_vocab(__filedir__ / "vocab.txt")
+_DEFUALT_VOCAB = load_vocab(__filedir__ / "_data/vocab.txt")
 VOCAB_KEYS: List[str] = _DEFUALT_VOCAB["vocab_keys"]
 VOCAB_DICT: Dict[str, int] = _DEFUALT_VOCAB["vocab_dict"]
 VOCAB_COUNT: int = _DEFUALT_VOCAB["vocab_count"]
-
-    VOCAB_KEYS[0:3]
+FASTA_VOCAB_KEYS = (
+    VOCAB_KEYS[0:3]
+    + "A B C D E F G H I K L M N P Q R S T V W Y Z - . J O U X *".split()
 )
-
-
+FASTA_VOCAB_COUNT = len(FASTA_VOCAB_KEYS)
+FASTA_VOCAB_DICT = dict(zip(FASTA_VOCAB_KEYS, range(FASTA_VOCAB_COUNT)))
 
 
 def smiles2vec(smiles: str) -> List[int]:
@@ -81,21 +82,21 @@ def smiles2vec(smiles: str) -> List[int]:
     :return: tokens w/o `<start>` and `<end>`
     :rtype: list
     """
-    tokens = [token for token in
+    tokens = [token for token in _smi_regex.findall(smiles)]
     return [VOCAB_DICT[token] for token in tokens]
 
 
-def
+def fasta2vec(fasta: str) -> List[int]:
     """
-
+    FASTA sequence tokenisation using a dataset-independent regex pattern.
 
-    :param
-    :type
+    :param fasta: protein (amino acid) sequence
+    :type fasta: str
     :return: tokens w/o `<start>` and `<end>`
     :rtype: list
     """
-    tokens = [token for token in
-    return [
+    tokens = [token for token in _fas_regex.findall(fasta)]
+    return [FASTA_VOCAB_DICT[token] for token in tokens]
 
 
 def split_selfies(selfies: str) -> List[str]:
@@ -107,7 +108,7 @@ def split_selfies(selfies: str) -> List[str]:
     :return: SELFIES vocab
     :rtype: list
     """
-    return [token for token in
+    return [token for token in _sel_regex.findall(selfies)]
 
 
 def smiles2token(smiles: str) -> Tensor:
@@ -115,9 +116,9 @@ def smiles2token(smiles: str) -> Tensor:
     return torch.tensor([1] + smiles2vec(smiles) + [2], dtype=torch.long)
 
 
-def
+def fasta2token(fasta: str) -> Tensor:
     # start token: <start> = 1; end token: <end> = 2
-    return torch.tensor([1] +
+    return torch.tensor([1] + fasta2vec(fasta) + [2], dtype=torch.long)
 
 
 def collate(batch: List[Dict[str, Tensor]]) -> Dict[str, Tensor]:
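To illustrate the new FASTA tokenisation added in data.py (a minimal sketch; the amino-acid sequence is arbitrary): `fasta2token` wraps `fasta2vec` with the `<start>`/`<end>` tokens, and the indices map back through `FASTA_VOCAB_KEYS`.

```python
from bayesianflow_for_chem.data import fasta2token, FASTA_VOCAB_KEYS

# Tokenise an arbitrary 10-residue sequence; token 1 is <start>, token 2 is <end>.
tokens = fasta2token("MKTAYIAKQR")
print(tokens.shape)                                # torch.Size([12]) -> 10 residues + 2 special tokens
print([FASTA_VOCAB_KEYS[int(i)] for i in tokens])  # map indices back to symbols
```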
bayesianflow_for_chem/model.py
CHANGED
@@ -659,12 +659,91 @@ class ChemBFN(nn.Module):
         return (-logits.gather(-1, x[..., :1]).squeeze(-1)).mean()
 
     @staticmethod
-    def
+    def _reshape(y: Tensor) -> Tensor:
         assert y.dim() <= 3  # this doesn't work if the model is frezen in JIT.
         if y.dim() == 2:
             return y[:, None, :]
         return y
 
+    def _process(
+        self,
+        theta: Tensor,
+        mask: Optional[Tuple[Tensor, Tensor]],
+        y: Optional[Tensor],
+        sample_step: int,
+        guidance_strength: float,
+        token_mask: Optional[Tensor],
+    ) -> Tuple[Tensor, Tensor]:
+        # BFN inference process.
+        #
+        # theta: piror distribution; shape: (n_b, n_t, n_vocab)
+        # mask: masked condition distribution; shape: (n_b, n_t, n_vocab)
+        #       condition distribution mask; shape: (n_b, n_t, 1)
+        n_b = theta.shape[0]
+        if y is not None:
+            y = self._reshape(y)
+        for i in torch.linspace(1, sample_step, sample_step, device=self.beta.device):
+            t = (i - 1).view(1, 1, 1).repeat(n_b, 1, 1) / sample_step
+            p = self.discrete_output_distribution(theta, t, y, guidance_strength)
+            if token_mask is not None:
+                p = p.masked_fill_(token_mask, 0.0)
+            alpha = self.calc_discrete_alpha(t, t + 1 / sample_step)
+            e_k = nn.functional.one_hot(torch.argmax(p, -1), self.K).float()
+            mu = alpha * (self.K * e_k - 1)
+            sigma = (alpha * self.K).sqrt()
+            theta = (mu + sigma * torch.randn_like(mu)).exp() * theta
+            theta = theta / theta.sum(-1, True)
+            if mask is not None:
+                x_onehot, x_mask = mask
+                theta = x_onehot + (1 - x_mask) * theta
+        t_final = torch.ones((n_b, 1, 1), device=self.beta.device)
+        p = self.discrete_output_distribution(theta, t_final, y, guidance_strength)
+        entropy = -(p * p.log()).sum(-1).mean(-1)
+        if token_mask is not None:
+            p = p.masked_fill_(token_mask, 0.0)
+        return torch.argmax(p, -1), entropy
+
+    def _ode_process(
+        self,
+        z: Tensor,
+        mask: Optional[Tuple[Tensor, Tensor]],
+        y: Optional[Tensor],
+        sample_step: int,
+        guidance_strength: float,
+        token_mask: Optional[Tensor],
+        temperature: float,
+    ) -> Tuple[Tensor, Tensor]:
+        # ODE-solver engaged inference process.
+        #
+        # z: prior latent vector; shape: (n_b, n_t, n_vocab)
+        # mask: masked condition distribution; shape: (n_b, n_t, n_vocab)
+        #       condition distribution mask; shape: (n_b, n_t, 1)
+        n_b = z.shape[0]
+        if y is not None:
+            y = self._reshape(y)
+        for i in torch.linspace(1, sample_step, sample_step, device=self.beta.device):
+            t = (i - 1).view(1, 1, 1).repeat(n_b, 1, 1) / sample_step
+            theta = softmax(z, -1)
+            if mask is not None:
+                x_onehot, x_mask = mask
+                theta = x_onehot + (1 - x_mask) * theta
+            beta = self.calc_beta(t + 1 / sample_step)
+            p = self.discrete_output_distribution(theta, t, y, guidance_strength)
+            if token_mask is not None:
+                p = p.masked_fill_(token_mask, 0.0)
+            u = torch.randn_like(z)
+            z = (self.K * p - 1) * beta + (self.K * beta * temperature).sqrt() * u
+        t_final = torch.ones((n_b, 1, 1), device=self.beta.device)
+        theta = softmax(z, -1)
+        if mask is not None:
+            x_onehot, x_mask = mask
+            theta = x_onehot + (1 - x_mask) * theta
+        p = self.discrete_output_distribution(theta, t_final, y, guidance_strength)
+        entropy = -(p * p.log()).sum(-1).mean(-1)
+        if token_mask is not None:
+            p = p.masked_fill_(token_mask, 0.0)
+        return torch.argmax(p, -1), entropy
+
     @torch.jit.export
     def sample(
         self,
@@ -680,44 +759,26 @@ class ChemBFN(nn.Module):
 
         :param batch_size: batch size
         :param sequence_size: max sequence length
-        :param y: conditioning vector;
+        :param y: conditioning vector; shape: (n_b, 1, n_f) or (n_b, n_f)
         :param sample_step: number of sampling steps
         :param guidance_strength: strength of conditional generation. It is not used if y is null.
         :param token_mask: token mask assigning unwanted token(s) with `True`;
-
+            shape: (1, 1, n_vocab)
         :type batch_size: int
         :type sequence_size: int
        :type y: torch.Tensor | None
        :type sample_step: int
        :type guidance_strength: float
        :type token_mask: torch.Tensor | None
-        :return: sampled token indices;
-            entropy of the tokens;
+        :return: sampled token indices; shape: (n_b, n_t) \n
+            entropy of the tokens; shape: (n_b)
        :rtype: tuple
        """
        theta = (
            torch.ones((batch_size, sequence_size, self.K), device=self.beta.device)
            / self.K
        )
-
-        y = self.reshape_y(y)
-        for i in torch.linspace(1, sample_step, sample_step, device=self.beta.device):
-            t = (i - 1).view(1, 1, 1).repeat(batch_size, 1, 1) / sample_step
-            p = self.discrete_output_distribution(theta, t, y, guidance_strength)
-            if token_mask is not None:
-                p = p.masked_fill_(token_mask, 0.0)
-            alpha = self.calc_discrete_alpha(t, t + 1 / sample_step)
-            e_k = nn.functional.one_hot(torch.argmax(p, -1), self.K).float()
-            mu = alpha * (self.K * e_k - 1)
-            sigma = (alpha * self.K).sqrt()
-            theta = (mu + sigma * torch.randn_like(mu)).exp() * theta
-            theta = theta / theta.sum(-1, True)
-        t_final = torch.ones((batch_size, 1, 1), device=self.beta.device)
-        p = self.discrete_output_distribution(theta, t_final, y, guidance_strength)
-        entropy = -(p * p.log()).sum(-1).mean(-1)
-        if token_mask is not None:
-            p = p.masked_fill_(token_mask, 0.0)
-        return torch.argmax(p, -1), entropy
+        return self._process(theta, None, y, sample_step, guidance_strength, token_mask)
 
     @torch.jit.export
     def ode_sample(
@@ -735,11 +796,11 @@ class ChemBFN(nn.Module):
 
         :param batch_size: batch size
         :param sequence_size: max sequence length
-        :param y: conditioning vector;
+        :param y: conditioning vector; shape: (n_b, 1, n_f) or (n_b, n_f)
        :param sample_step: number of sampling steps
        :param guidance_strength: strength of conditional generation. It is not used if y is null.
        :param token_mask: token mask assigning unwanted token(s) with `True`;
-
+            shape: (1, 1, n_vocab)
        :param temperature: sampling temperature
        :type batch_size: int
        :type sequence_size: int
@@ -748,29 +809,14 @@ class ChemBFN(nn.Module):
        :type guidance_strength: float
        :type token_mask: torch.Tensor | None
        :type temperature: float
-        :return: sampled token indices;
-            entropy of the tokens;
+        :return: sampled token indices; shape: (n_b, n_t) \n
+            entropy of the tokens; shape: (n_b)
        :rtype: tuple
        """
        z = torch.zeros((batch_size, sequence_size, self.K), device=self.beta.device)
-
-        y
-
-            t = (i - 1).view(1, 1, 1).repeat(batch_size, 1, 1) / sample_step
-            theta = torch.softmax(z, -1)
-            beta = self.calc_beta(t + 1 / sample_step)
-            p = self.discrete_output_distribution(theta, t, y, guidance_strength)
-            if token_mask is not None:
-                p = p.masked_fill_(token_mask, 0.0)
-            u = torch.randn_like(z)
-            z = (self.K * p - 1) * beta + (self.K * beta * temperature).sqrt() * u
-        t_final = torch.ones((batch_size, 1, 1), device=self.beta.device)
-        theta = torch.softmax(z, -1)
-        p = self.discrete_output_distribution(theta, t_final, y, guidance_strength)
-        entropy = -(p * p.log()).sum(-1).mean(-1)
-        if token_mask is not None:
-            p = p.masked_fill_(token_mask, 0.0)
-        return torch.argmax(p, -1), entropy
+        return self._ode_process(
+            z, None, y, sample_step, guidance_strength, token_mask, temperature
+        )
 
     @torch.jit.export
     def inpaint(
@@ -800,30 +846,12 @@ class ChemBFN(nn.Module):
        :rtype: tuple
        """
        n_b, n_t = x.shape
-
+        x_mask = (x != 0).float()[..., None]
        theta = torch.ones((n_b, n_t, self.K), device=x.device) / self.K
-        x_onehot = nn.functional.one_hot(x, self.K) *
-        theta = x_onehot + (1 -
-
-
-        for i in torch.linspace(1, sample_step, sample_step, device=x.device):
-            t = (i - 1).view(1, 1, 1).repeat(n_b, 1, 1) / sample_step
-            p = self.discrete_output_distribution(theta, t, y, guidance_strength)
-            if token_mask is not None:
-                p = p.masked_fill_(token_mask, 0.0)
-            alpha = self.calc_discrete_alpha(t, t + 1 / sample_step)
-            e_k = nn.functional.one_hot(torch.argmax(p, -1), self.K).float()
-            mu = alpha * (self.K * e_k - 1)
-            sigma = (alpha * self.K).sqrt()
-            theta = (mu + sigma * torch.randn_like(mu)).exp() * theta
-            theta = theta / theta.sum(-1, True)
-            theta = x_onehot + (1 - mask) * theta
-        t_final = torch.ones((n_b, 1, 1), device=x.device)
-        p = self.discrete_output_distribution(theta, t_final, y, guidance_strength)
-        entropy = -(p * p.log()).sum(-1).mean(-1)
-        if token_mask is not None:
-            p = p.masked_fill_(token_mask, 0.0)
-        return torch.argmax(p, -1), entropy
+        x_onehot = nn.functional.one_hot(x, self.K) * x_mask
+        theta = x_onehot + (1 - x_mask) * theta
+        mask = (x_onehot, x_mask)
+        return self._process(theta, mask, y, sample_step, guidance_strength, token_mask)
 
     @torch.jit.export
     def ode_inpaint(
@@ -856,29 +884,13 @@ class ChemBFN(nn.Module):
        :rtype: tuple
        """
        n_b, n_t = x.shape
-
-        x_onehot = nn.functional.one_hot(x, self.K) *
+        x_mask = (x != 0).float()[..., None]
+        x_onehot = nn.functional.one_hot(x, self.K) * x_mask
        z = torch.zeros((n_b, n_t, self.K), device=self.beta.device)
-
-
-
-
-            theta = torch.softmax(z, -1)
-            theta = x_onehot + (1 - mask) * theta
-            beta = self.calc_beta(t + 1 / sample_step)
-            p = self.discrete_output_distribution(theta, t, y, guidance_strength)
-            if token_mask is not None:
-                p = p.masked_fill_(token_mask, 0.0)
-            u = torch.randn_like(z)
-            z = (self.K * p - 1) * beta + (self.K * beta * temperature).sqrt() * u
-        t_final = torch.ones((n_b, 1, 1), device=self.beta.device)
-        theta = torch.softmax(z, -1)
-        theta = x_onehot + (1 - mask) * theta
-        p = self.discrete_output_distribution(theta, t_final, y, guidance_strength)
-        entropy = -(p * p.log()).sum(-1).mean(-1)
-        if token_mask is not None:
-            p = p.masked_fill_(token_mask, 0.0)
-        return torch.argmax(p, -1), entropy
+        mask = (x_onehot, x_mask)
+        return self._ode_process(
+            z, mask, y, sample_step, guidance_strength, token_mask, temperature
+        )
 
     @torch.jit.export
     def optimise(
@@ -908,28 +920,9 @@ class ChemBFN(nn.Module):
            entropy of the tokens; shape: (n_b)
        :rtype: tuple
        """
-        n_b = x.shape[0]
        x_onehot = nn.functional.one_hot(x, self.K).float()
-        theta =
-
-        y = self.reshape_y(y)
-        for i in torch.linspace(1, sample_step, sample_step, device=x.device):
-            t = (i - 1).view(1, 1, 1).repeat(n_b, 1, 1) / sample_step
-            p = self.discrete_output_distribution(theta, t, y, guidance_strength)
-            if token_mask is not None:
-                p = p.masked_fill_(token_mask, 0.0)
-            alpha = self.calc_discrete_alpha(t, t + 1 / sample_step)
-            e_k = nn.functional.one_hot(torch.argmax(p, -1), self.K).float()
-            mu = alpha * (self.K * e_k - 1)
-            sigma = (alpha * self.K).sqrt()
-            theta = (mu + sigma * torch.randn_like(mu)).exp() * theta
-            theta = theta / theta.sum(-1, True)
-        t_final = torch.ones((n_b, 1, 1), device=x.device)
-        p = self.discrete_output_distribution(theta, t_final, y, guidance_strength)
-        entropy = -(p * p.log()).sum(-1).mean(-1)
-        if token_mask is not None:
-            p = p.masked_fill_(token_mask, 0.0)
-        return torch.argmax(p, -1), entropy
+        theta = softmax(x_onehot, -1)
+        return self._process(theta, None, y, sample_step, guidance_strength, token_mask)
 
     @torch.jit.export
     def ode_optimise(
@@ -961,26 +954,10 @@ class ChemBFN(nn.Module):
            entropy of the tokens; shape: (n_b)
        :rtype: tuple
        """
-        n_b = x.shape[0]
        z = nn.functional.one_hot(x, self.K).float()
-
-        y
-
-            t = (i - 1).view(1, 1, 1).repeat(n_b, 1, 1) / sample_step
-            theta = torch.softmax(z, -1)
-            beta = self.calc_beta(t + 1 / sample_step)
-            p = self.discrete_output_distribution(theta, t, y, guidance_strength)
-            if token_mask is not None:
-                p = p.masked_fill_(token_mask, 0.0)
-            u = torch.randn_like(z)
-            z = (self.K * p - 1) * beta + (self.K * beta * temperature).sqrt() * u
-        t_final = torch.ones((n_b, 1, 1), device=self.beta.device)
-        theta = torch.softmax(z, -1)
-        p = self.discrete_output_distribution(theta, t_final, y, guidance_strength)
-        entropy = -(p * p.log()).sum(-1).mean(-1)
-        if token_mask is not None:
-            p = p.masked_fill_(token_mask, 0.0)
-        return torch.argmax(p, -1), entropy
+        return self._ode_process(
+            z, None, y, sample_step, guidance_strength, token_mask, temperature
+        )
 
     def inference(
         self, x: Tensor, mlp: MLP, embed_fn: Optional[Callable[[Tensor], Tensor]] = None
@@ -1154,22 +1131,6 @@ class EnsembleChemBFN(ChemBFN):
                module.lora_dropout = None
            v.lora_enabled = False
 
-    def construct_y(
-        self, c: Union[List[Tensor], Dict[str, Tensor]]
-    ) -> Dict[str, Tensor]:
-        assert (
-            isinstance(c, dict) is self._label_is_dict
-        ), f"`c` should be a {'`dict` instance' if self._label_is_dict else '`list` instance'} but got {type(c)} instand."
-        out: Dict[str, Tensor] = {}
-        if isinstance(c, list):
-            c = dict(zip([f"val_{i}" for i in range(len(c))], c))
-        for name, model in self.cond_heads.items():
-            y = model.forward(c[name])
-            if y.dim() == 2:
-                y = y[:, None, :]
-            out[name] = y
-        return out
-
    def discrete_output_distribution(
        self, theta: Tensor, t: Tensor, y: Dict[str, Tensor], w: float
    ) -> Tensor:
@@ -1204,8 +1165,24 @@ class EnsembleChemBFN(ChemBFN):
            p_cond += p_cond_ * self.adapter_weights[name]
        return softmax((1 + w) * p_cond - w * p_uncond, -1)
 
+    def _map_to_dict(
+        self, c: Union[List[Tensor], Dict[str, Tensor]]
+    ) -> Dict[str, Tensor]:
+        assert (
+            isinstance(c, dict) is self._label_is_dict
+        ), f"`c` should be a {'`dict` instance' if self._label_is_dict else '`list` instance'} but got {type(c)} instand."
+        out: Dict[str, Tensor] = {}
+        if isinstance(c, list):
+            c = dict(zip([f"val_{i}" for i in range(len(c))], c))
+        for name, model in self.cond_heads.items():
+            y = model.forward(c[name])
+            if y.dim() == 2:
+                y = y[:, None, :]
+            out[name] = y
+        return out
+
    @staticmethod
-    def
+    def _reshape(y: Dict[str, Tensor]) -> Dict[str, Tensor]:
        for k in y:
            assert y[k].dim() <= 3
            if y[k].dim() == 2:
@@ -1241,7 +1218,7 @@ class EnsembleChemBFN(ChemBFN):
            entropy of the tokens; shape: (n_b)
        :rtype: tuple
        """
-        y = self.
+        y = self._map_to_dict(conditions)
        return super().sample(
            batch_size, sequence_size, y, sample_step, guidance_strength, token_mask
        )
@@ -1278,7 +1255,7 @@ class EnsembleChemBFN(ChemBFN):
            entropy of the tokens; shape: (n_b)
        :rtype: tuple
        """
-        y = self.
+        y = self._map_to_dict(conditions)
        return super().ode_sample(
            batch_size,
            sequence_size,
@@ -1315,7 +1292,7 @@ class EnsembleChemBFN(ChemBFN):
            entropy of the tokens; shape: (n_b)
        :rtype: tuple
        """
-        y = self.
+        y = self._map_to_dict(conditions)
        return super().inpaint(x, y, sample_step, guidance_strength, token_mask)
 
    @torch.inference_mode()
@@ -1347,7 +1324,7 @@ class EnsembleChemBFN(ChemBFN):
            entropy of the tokens; shape: (n_b)
        :rtype: tuple
        """
-        y = self.
+        y = self._map_to_dict(conditions)
        return super().ode_inpaint(
            x, y, sample_step, guidance_strength, token_mask, temperature
        )
@@ -1380,7 +1357,7 @@ class EnsembleChemBFN(ChemBFN):
            entropy of the tokens; shape: (n_b)
        :rtype: tuple
        """
-        y = self.
+        y = self._map_to_dict(conditions)
        return super().optimise(x, y, sample_step, guidance_strength, token_mask)
 
    @torch.inference_mode()
@@ -1412,7 +1389,7 @@ class EnsembleChemBFN(ChemBFN):
            entropy of the tokens; shape: (n_b)
        :rtype: tuple
        """
-        y = self.
+        y = self._map_to_dict(conditions)
        return super().ode_optimise(
            x, y, sample_step, guidance_strength, token_mask, temperature
        )
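The refactor above routes `sample`, `ode_sample`, `inpaint`, `ode_inpaint`, `optimise`, and `ode_optimise` through the shared `_process`/`_ode_process` helpers without changing the public call shape. As a reminder of that shape, a hedged sketch of unconditional sampling — the step count and sizes are arbitrary choices, and `model` is assumed to be an already-constructed `ChemBFN`:

```python
from typing import Tuple
from torch import Tensor
from bayesianflow_for_chem import ChemBFN

def unconditional_sample(model: ChemBFN) -> Tuple[Tensor, Tensor]:
    # Positional order as used by EnsembleChemBFN.sample -> super().sample(...):
    # (batch_size, sequence_size, y, sample_step, guidance_strength, token_mask).
    # 4 sequences of up to 64 tokens, 100 sampling steps, no guidance, no token mask.
    return model.sample(4, 64, None, 100, 0.0, None)
```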
bayesianflow_for_chem/tool.py
CHANGED
@@ -7,9 +7,8 @@ import csv
 import random
 import warnings
 from pathlib import Path
-from typing import List, Dict, Tuple, Union, Optional
+from typing import List, Dict, Tuple, Union, Optional, Literal
 import torch
-import colorama
 import numpy as np
 from torch import cuda, Tensor, softmax
 from torch.utils.data import DataLoader
@@ -24,15 +23,7 @@ from rdkit.Chem import (
     AddHs,
     Mol,
 )
-from rdkit.Chem.Scaffolds.MurckoScaffold import MurckoScaffoldSmiles
-from sklearn.metrics import (
-    roc_auc_score,
-    auc,
-    precision_recall_curve,
-    r2_score,
-    mean_absolute_error,
-    root_mean_squared_error,
-)
+from rdkit.Chem.Scaffolds.MurckoScaffold import MurckoScaffoldSmiles
 from .data import VOCAB_KEYS
 from .model import ChemBFN, MLP, EnsembleChemBFN
 
@@ -45,12 +36,74 @@ def _find_device() -> torch.device:
     return torch.device("cpu")
 
 
+def _parse_and_assert_param(
+    model: Union[ChemBFN, EnsembleChemBFN],
+    y: Optional[Union[Tensor, Dict[str, Tensor], List[Tensor]]],
+    method: str,
+) -> Optional[float]:
+    assert method.split(":")[0].lower() in ("ode", "bfn")
+    if isinstance(model, EnsembleChemBFN):
+        assert y is not None, "conditioning is required while using an ensemble model."
+        assert isinstance(y, list) or isinstance(y, dict)
+    else:
+        assert isinstance(y, Tensor) or (y is None)
+    if "ode" in method.lower():
+        tp = float(method.split(":")[-1])
+        assert tp > 0, "Sampling temperature should be higher than 0."
+        return tp
+    return None
+
+
+def _map_to_device(
+    y: Optional[Union[Tensor, Dict[str, Tensor], List[Tensor]]],
+    device: Union[str, torch.device],
+) -> Optional[Union[Tensor, Dict[str, Tensor], List[Tensor]]]:
+    if y is not None:
+        if isinstance(y, Tensor):
+            y = y.to(device)
+        elif isinstance(y, list):
+            y = [i.to(device) for i in y]
+        elif isinstance(y, dict):
+            y = {k: v.to(device) for k, v in y.items()}
+        else:
+            raise NotImplementedError
+    return y
+
+
+def _build_token_mask(
+    allowed_tokens: Union[str, List[str]],
+    vocab_keys: List[str],
+    device: Union[str, torch.tensor],
+) -> Optional[Tensor]:
+    if isinstance(allowed_tokens, list):
+        token_mask = [0 if i in allowed_tokens else 1 for i in vocab_keys]
+        token_mask = torch.tensor([[token_mask]], dtype=torch.bool).to(device)
+    else:
+        token_mask = None
+    return token_mask
+
+
+def _token_to_seq(
+    tokens: Tensor, entropy: Tensor, vocab_keys: List[str], separator: str, sort: bool
+) -> List[str]:
+    if sort:
+        sorted_idx = entropy.argsort(stable=True)
+        tokens = tokens[sorted_idx]
+    return [
+        separator.join([vocab_keys[i] for i in j])
+        .split("<start>" + separator)[-1]
+        .split(separator + "<end>")[0]
+        .replace("<pad>", "")
+        for j in tokens
+    ]
+
+
 @torch.no_grad()
 def test(
     model: ChemBFN,
     mlp: MLP,
     data: DataLoader,
-    mode:
+    mode: Literal["regression", "classification"] = "regression",
     device: Union[str, torch.device, None] = None,
 ) -> Dict[str, float]:
     """
@@ -86,6 +139,12 @@ def test(
         predict_y.append(y_hat.detach().to("cpu"))
     predict_y, label_y = torch.cat(predict_y, 0), torch.cat(label_y, 0).split(1, -1)
     if mode == "regression":
+        from sklearn.metrics import (
+            r2_score,
+            mean_absolute_error,
+            root_mean_squared_error,
+        )
+
         predict_y = [
             predict[label_y[i] != torch.inf]
             for (i, predict) in enumerate(predict_y.split(1, -1))
@@ -99,6 +158,8 @@ def test(
         r2 = [r2_score(label, predict) for (label, predict) in y_zipped]
         return {"MAE": mae, "RMSE": rmse, "R^2": r2}
     if mode == "classification":
+        from sklearn.metrics import roc_auc_score, auc, precision_recall_curve
+
         n_c = len(label_y)
         predict_y = predict_y.chunk(n_c, -1)
         y_zipped = list(zip(label_y, predict_y))
@@ -123,7 +184,9 @@ def test(
 
 
 def split_dataset(
-    file: Union[str, Path],
+    file: Union[str, Path],
+    split_ratio: List[int] = [8, 1, 1],
+    method: Literal["random", "scaffold"] = "random",
 ) -> None:
     """
     Split a dataset.
@@ -142,7 +205,6 @@ def split_dataset(
     assert file.endswith(".csv")
     assert len(split_ratio) == 3
     assert method in ("random", "scaffold")
-    colorama.just_fix_windows_console()
     with open(file, "r") as f:
         data = list(csv.reader(f))
     header = data[0]
@@ -167,10 +229,8 @@ def split_dataset(
         # compute Bemis-Murcko scaffold
         if len(smiles_idx) > 1:
             warnings.warn(
-                "\033[32;1m"
                 f"We found {len(smiles_idx)} SMILES strings in a row!"
-                " Only the first SMILES will be used to compute the molecular scaffold."
-                "\033[0m",
+                " Only the first SMILES will be used to compute the molecular scaffold.",
                 stacklevel=2,
             )
         try:
@@ -197,10 +257,10 @@ def split_dataset(
     with open(file.replace(".csv", "_test.csv"), "w", newline="") as fte:
         writer = csv.writer(fte)
         writer.writerows([header] + test_set)
-
-
-
-
+    if val_set:
+        with open(file.replace(".csv", "_val.csv"), "w", newline="") as fva:
+            writer = csv.writer(fva)
+            writer.writerows([header] + val_set)
 
 
 @torch.no_grad()
@@ -250,32 +310,12 @@ def sample(
     :return: a list of generated molecular strings
     :rtype: list
     """
-
-
-        assert y is not None, "conditioning is required while using an ensemble model."
-        assert isinstance(y, list) or isinstance(y, dict)
-    else:
-        assert isinstance(y, Tensor) or y is None
-    if device is None:
-        device = _find_device()
+    tp = _parse_and_assert_param(model, y, method)
+    device = _find_device() if device is None else device
     model.to(device).eval()
-
-
-
-        elif isinstance(y, list):
-            y = [i.to(device) for i in y]
-        elif isinstance(y, dict):
-            y = {k: v.to(device) for k, v in y.items()}
-        else:
-            raise NotImplementedError
-    if isinstance(allowed_tokens, list):
-        token_mask = [0 if i in allowed_tokens else 1 for i in vocab_keys]
-        token_mask = torch.tensor([[token_mask]], dtype=torch.bool).to(device)
-    else:
-        token_mask = None
-    if "ode" in method.lower():
-        tp = float(method.split(":")[-1])
-        assert tp > 0, "Sampling temperature should be higher than 0."
+    y = _map_to_device(y, device)
+    token_mask = _build_token_mask(allowed_tokens, vocab_keys, device)
+    if tp:
        tokens, entropy = model.ode_sample(
            batch_size, sequence_size, y, sample_step, guidance_strength, token_mask, tp
        )
@@ -283,16 +323,7 @@ def sample(
        tokens, entropy = model.sample(
            batch_size, sequence_size, y, sample_step, guidance_strength, token_mask
        )
-
-    sorted_idx = entropy.argsort(stable=True)
-    tokens = tokens[sorted_idx]
-    return [
-        seperator.join([vocab_keys[i] for i in j])
-        .split("<start>" + seperator)[-1]
-        .split(seperator + "<end>")[0]
-        .replace("<pad>", "")
-        for j in tokens
-    ]
+    return _token_to_seq(tokens, entropy, vocab_keys, seperator, sort)
 
 
 @torch.no_grad()
@@ -339,33 +370,13 @@ def inpaint(
    :return: a list of generated molecular strings
    :rtype: list
    """
-
-
-        assert y is not None, "conditioning is required while using an ensemble model."
-        assert isinstance(y, list) or isinstance(y, dict)
-    else:
-        assert isinstance(y, Tensor) or y is None
-    if device is None:
-        device = _find_device()
+    tp = _parse_and_assert_param(model, y, method)
+    device = _find_device() if device is None else device
    model.to(device).eval()
    x = x.to(device)
-
-
-
-        elif isinstance(y, list):
-            y = [i.to(device) for i in y]
-        elif isinstance(y, dict):
-            y = {k: v.to(device) for k, v in y.items()}
-        else:
-            raise NotImplementedError
-    if isinstance(allowed_tokens, list):
-        token_mask = [0 if i in allowed_tokens else 1 for i in vocab_keys]
-        token_mask = torch.tensor([[token_mask]], dtype=torch.bool).to(device)
-    else:
-        token_mask = None
-    if "ode" in method.lower():
-        tp = float(method.split(":")[-1])
-        assert tp > 0, "Sampling temperature should be higher than 0."
+    y = _map_to_device(y, device)
+    token_mask = _build_token_mask(allowed_tokens, vocab_keys, device)
+    if tp:
        tokens, entropy = model.ode_inpaint(
            x, y, sample_step, guidance_strength, token_mask, tp
        )
@@ -373,16 +384,7 @@ def inpaint(
        tokens, entropy = model.inpaint(
            x, y, sample_step, guidance_strength, token_mask
        )
-
-    sorted_idx = entropy.argsort(stable=True)
-    tokens = tokens[sorted_idx]
-    return [
-        separator.join([vocab_keys[i] for i in j])
-        .split("<start>" + separator)[-1]
-        .split(separator + "<end>")[0]
-        .replace("<pad>", "")
-        for j in tokens
-    ]
+    return _token_to_seq(tokens, entropy, vocab_keys, separator, sort)
 
 
 @torch.no_grad()
@@ -429,33 +431,13 @@ def optimise(
    :return: a list of generated molecular strings
    :rtype: list
    """
-
-
-        assert y is not None, "conditioning is required while using an ensemble model."
-        assert isinstance(y, list) or isinstance(y, dict)
-    else:
-        assert isinstance(y, Tensor) or y is None
-    if device is None:
-        device = _find_device()
+    tp = _parse_and_assert_param(model, y, method)
+    device = _find_device() if device is None else device
    model.to(device).eval()
    x = x.to(device)
-
-
-
-        elif isinstance(y, list):
-            y = [i.to(device) for i in y]
-        elif isinstance(y, dict):
-            y = {k: v.to(device) for k, v in y.items()}
-        else:
-            raise NotImplementedError
-    if isinstance(allowed_tokens, list):
-        token_mask = [0 if i in allowed_tokens else 1 for i in vocab_keys]
-        token_mask = torch.tensor([[token_mask]], dtype=torch.bool).to(device)
-    else:
-        token_mask = None
-    if "ode" in method.lower():
-        tp = float(method.split(":")[-1])
-        assert tp > 0, "Sampling temperature should be higher than 0."
+    y = _map_to_device(y, device)
+    token_mask = _build_token_mask(allowed_tokens, vocab_keys, device)
+    if tp:
        tokens, entropy = model.ode_optimise(
            x, y, sample_step, guidance_strength, token_mask, tp
        )
@@ -463,16 +445,7 @@ def optimise(
        tokens, entropy = model.optimise(
            x, y, sample_step, guidance_strength, token_mask
        )
-
-    sorted_idx = entropy.argsort(stable=True)
-    tokens = tokens[sorted_idx]
-    return [
-        separator.join([vocab_keys[i] for i in j])
-        .split("<start>" + separator)[-1]
-        .split(separator + "<end>")[0]
-        .replace("<pad>", "")
-        for j in tokens
-    ]
+    return _token_to_seq(tokens, entropy, vocab_keys, separator, sort)
 
 
 def quantise_model_(model: ChemBFN) -> None:
@@ -555,7 +528,7 @@ class GeometryConverter:
    def smiles2cartesian(
        smiles: str,
        num_conformers: int = 250,
-        rdkit_ff_type:
+        rdkit_ff_type: Literal["MMFF", "UFF"] = "MMFF",
        refine_with_crest: bool = False,
        spin: float = 0.0,
    ) -> Tuple[List[str], np.ndarray]:
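The `method` argument handled by the new `_parse_and_assert_param` helper above selects the sampler: `"BFN"` (any case) uses the discrete BFN process and yields no temperature, while `"ODE:<temperature>"` switches to the ODE solver and carries a positive temperature. A standalone restatement of that parsing rule (illustrative only, not the package API):

```python
from typing import Optional

def parse_method(method: str) -> Optional[float]:
    """Return the ODE temperature, or None for the plain BFN sampler."""
    kind = method.split(":")[0].lower()
    assert kind in ("ode", "bfn"), f"unknown sampling method: {method}"
    if kind == "ode":
        temperature = float(method.split(":")[-1])
        assert temperature > 0, "Sampling temperature should be higher than 0."
        return temperature
    return None

print(parse_method("BFN"))      # None  -> discrete BFN sampler
print(parse_method("ODE:0.5"))  # 0.5   -> ODE sampler at temperature 0.5
```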
bayesianflow_for_chem/train.py
CHANGED
@@ -8,7 +8,6 @@ from typing import Dict, Tuple, Union, Optional
 import torch
 import torch.optim as op
 import torch.nn.functional as F
-from loralib import lora_state_dict, mark_only_lora_as_trainable
 from torch import Tensor
 from torch.optim.lr_scheduler import ReduceLROnPlateau
 from lightning import LightningModule
@@ -55,6 +54,8 @@ class Model(LightningModule):
         self.scorer = scorer
         self.save_hyperparameters(hparam, ignore=["model", "mlp", "scorer"])
         if model.lora_enabled:
+            from loralib import mark_only_lora_as_trainable
+
             mark_only_lora_as_trainable(self.model)
         self.use_scorer = self.scorer is not None
@@ -107,6 +108,8 @@ class Model(LightningModule):
         :rtype: None
         """
         if self.model.lora_enabled:
+            from loralib import lora_state_dict
+
             torch.save(
                 {
                     "lora_nn": lora_state_dict(self.model),
@@ -152,6 +155,8 @@ class Regressor(LightningModule):
         self.model.requires_grad_(not hparam["freeze"])
         self.save_hyperparameters(hparam, ignore=["model", "mlp"])
         if model.lora_enabled:
+            from loralib import mark_only_lora_as_trainable
+
             mark_only_lora_as_trainable(self.model)
         assert hparam["mode"] in ("regression", "classification")
@@ -231,6 +236,8 @@ class Regressor(LightningModule):
         )
         if not self.hparams.freeze:
             if self.model.lora_enabled:
+                from loralib import lora_state_dict
+
                 torch.save(
                     {
                         "lora_nn": lora_state_dict(self.model),
{bayesianflow_for_chem-2.1.0.dist-info → bayesianflow_for_chem-2.2.2.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: bayesianflow_for_chem
-Version: 2.
+Version: 2.2.2
 Summary: Bayesian flow network framework for Chemistry
 Home-page: https://augus1999.github.io/bayesian-flow-network-for-chemistry/
 Author: Nianze A. Tao
@@ -54,12 +54,14 @@ This is the repository of the PyTorch implementation of ChemBFN model.
 
 [](https://pypi.org/project/bayesianflow-for-chem/)
 
+[](https://augus1999.github.io/bayesian-flow-network-for-chemistry/)
 
 ## Features
 
 ChemBFN provides the state-of-the-art functionalities of
 * SMILES or SELFIES-based *de novo* molecule generation
 * Protein sequence *de novo* generation
+* Template optimisation (mol2mol)
 * Classifier-free guidance conditional generation (single or multi-objective optimisation)
 * Context-guided conditional generation (inpaint)
 * Outstanding out-of-distribution chemical space sampling
@@ -71,6 +73,7 @@ in an all-in-one-model style.
 
 ## News
 
+* [09/10/2025] A web app [`chembfn_webui`](https://github.com/Augus1999/ChemBFN-WebUI) for hosting ChemBFN models is available on [PyPI](https://pypi.org/project/chembfn-webui/).
 * [30/01/2025] The package `bayesianflow_for_chem` is available on [PyPI](https://pypi.org/project/bayesianflow-for-chem/).
 * [21/01/2025] Our first paper has been accepted by [JCIM](https://pubs.acs.org/doi/10.1021/acs.jcim.4c01792).
 * [17/12/2024] The second paper of out-of-distribution generation is available on [arxiv.org](https://arxiv.org/abs/2412.11439).
@@ -93,7 +96,7 @@ You can find pretrained models on our [🤗Hugging Face model page](https://hugg
 
 ## Dataset Handling
 
-We provide a Python class [`CSVData`](https://github.com/Augus1999/bayesian-flow-network-for-chemistry/blob/main/bayesianflow_for_chem/data.py#
+We provide a Python class [`CSVData`](https://github.com/Augus1999/bayesian-flow-network-for-chemistry/blob/main/bayesianflow_for_chem/data.py#L153) to handle data stored in CSV or similar format containing headers to identify the entities. The following is a quickstart.
 
 1. Download your dataset file (e.g., ESOL from [MoleculeNet](https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/delaney-processed.csv)) and split the file:
 ```python
bayesianflow_for_chem-2.2.2.dist-info/RECORD
ADDED
@@ -0,0 +1,15 @@
+bayesianflow_for_chem/__init__.py,sha256=5nyapMRF-NM610Y_FzZQxu97LQvFaQDxwsyYFBJFxdw,642
+bayesianflow_for_chem/cli.py,sha256=wSDQ5EpETB0-o_YeSIuFt4hP1gI4if566a3qehspgB0,27353
+bayesianflow_for_chem/data.py,sha256=jOzcOO5FDNju8hnaimT_WI8sjdaiOHDalDIOEOpLjEE,6643
+bayesianflow_for_chem/model.py,sha256=M35G4u4mX4btl9vOK3Iqs6yOSuIKI_OoCTmLhmjbwNk,57559
+bayesianflow_for_chem/scorer.py,sha256=gQFUlkyxitch02ntqcRh1ZS8aondKLynW5U6NfTQTb4,4084
+bayesianflow_for_chem/spectra.py,sha256=Ba9ib1aDvTtDYbH3b4d-lIty3ZSQMu7jwehuV2KmhwA,1781
+bayesianflow_for_chem/tool.py,sha256=pAEGfYzEiquu9cTM0Te8EAAr2RPRRObGCzLk9uXaw8o,23686
+bayesianflow_for_chem/train.py,sha256=7AU0A-eZwzSYsLyIe3OxGTNWPnhGpHmVUaQLplV2Fn8,9886
+bayesianflow_for_chem/_data/vocab.txt,sha256=HgtAZmpWYk4y8PqEVC4vqut1vE75DfRKE_10s2UW0rU,790
+bayesianflow_for_chem-2.2.2.dist-info/licenses/LICENSE,sha256=hIahDEOTzuHCU5J2nd07LWwkLW7Hko4UFO__ffsvB-8,34523
+bayesianflow_for_chem-2.2.2.dist-info/METADATA,sha256=nh6i_LRZTBSoJr3KP3iD0Q-CgrveQSj8AHrdg75FsU4,6476
+bayesianflow_for_chem-2.2.2.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+bayesianflow_for_chem-2.2.2.dist-info/entry_points.txt,sha256=N63RMoJsr8rxuKxc7Fj802SL8J5AlpCoPkS8E3IFPLI,54
+bayesianflow_for_chem-2.2.2.dist-info/top_level.txt,sha256=KHsanI3BMCt8D9Qpze2ycrF6nMa3PyojgO6eS1c8kco,22
+bayesianflow_for_chem-2.2.2.dist-info/RECORD,,
bayesianflow_for_chem-2.1.0.dist-info/RECORD
DELETED
@@ -1,15 +0,0 @@
-bayesianflow_for_chem/__init__.py,sha256=bmqERRnnmyK7UgUYn3BFH_YipKfTxNNjRzu9__RVid4,612
-bayesianflow_for_chem/cli.py,sha256=A1cFz6jhpLEpbK9r8GxCLdbnPCzQ4RrsavLKg_lssVg,24208
-bayesianflow_for_chem/data.py,sha256=Pl0gGWHmMKTKHpsxznvLgYPCwwlLNL7nqH19Vipjkxs,6584
-bayesianflow_for_chem/model.py,sha256=UW5hfAofYK9dH9euDPYWfJedVMRFxk8WtY427fObf70,59641
-bayesianflow_for_chem/scorer.py,sha256=gQFUlkyxitch02ntqcRh1ZS8aondKLynW5U6NfTQTb4,4084
-bayesianflow_for_chem/spectra.py,sha256=Ba9ib1aDvTtDYbH3b4d-lIty3ZSQMu7jwehuV2KmhwA,1781
-bayesianflow_for_chem/tool.py,sha256=bqoIMas8bmcjYBghuQWLh75Eq8ZlG6mh9ZeDzWGOmuw,24790
-bayesianflow_for_chem/train.py,sha256=jYkhSguW50lrcTEydCQ20yig_mmc1j7WH9KmVwBCTAo,9727
-bayesianflow_for_chem/vocab.txt,sha256=HgtAZmpWYk4y8PqEVC4vqut1vE75DfRKE_10s2UW0rU,790
-bayesianflow_for_chem-2.1.0.dist-info/licenses/LICENSE,sha256=hIahDEOTzuHCU5J2nd07LWwkLW7Hko4UFO__ffsvB-8,34523
-bayesianflow_for_chem-2.1.0.dist-info/METADATA,sha256=ctwor5jnPCmAo-1wuIGkX70aqkXj-8YQxr27AJOdEjM,6057
-bayesianflow_for_chem-2.1.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
-bayesianflow_for_chem-2.1.0.dist-info/entry_points.txt,sha256=N63RMoJsr8rxuKxc7Fj802SL8J5AlpCoPkS8E3IFPLI,54
-bayesianflow_for_chem-2.1.0.dist-info/top_level.txt,sha256=KHsanI3BMCt8D9Qpze2ycrF6nMa3PyojgO6eS1c8kco,22
-bayesianflow_for_chem-2.1.0.dist-info/RECORD,,
/bayesianflow_for_chem/{vocab.txt → _data/vocab.txt}
RENAMED
File without changes
{bayesianflow_for_chem-2.1.0.dist-info → bayesianflow_for_chem-2.2.2.dist-info}/WHEEL
RENAMED
File without changes
{bayesianflow_for_chem-2.1.0.dist-info → bayesianflow_for_chem-2.2.2.dist-info}/entry_points.txt
RENAMED
File without changes
{bayesianflow_for_chem-2.1.0.dist-info → bayesianflow_for_chem-2.2.2.dist-info}/licenses/LICENSE
RENAMED
File without changes
{bayesianflow_for_chem-2.1.0.dist-info → bayesianflow_for_chem-2.2.2.dist-info}/top_level.txt
RENAMED
File without changes