bayesianflow-for-chem 2.0.5__py3-none-any.whl → 2.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bayesianflow_for_chem/__init__.py +4 -3
- bayesianflow_for_chem/cli.py +120 -31
- bayesianflow_for_chem/data.py +1 -1
- bayesianflow_for_chem/model.py +257 -113
- bayesianflow_for_chem/tool.py +150 -89
- bayesianflow_for_chem/train.py +8 -1
- {bayesianflow_for_chem-2.0.5.dist-info → bayesianflow_for_chem-2.2.1.dist-info}/METADATA +4 -1
- bayesianflow_for_chem-2.2.1.dist-info/RECORD +15 -0
- bayesianflow_for_chem-2.0.5.dist-info/RECORD +0 -15
- /bayesianflow_for_chem/{vocab.txt → _data/vocab.txt} +0 -0
- {bayesianflow_for_chem-2.0.5.dist-info → bayesianflow_for_chem-2.2.1.dist-info}/WHEEL +0 -0
- {bayesianflow_for_chem-2.0.5.dist-info → bayesianflow_for_chem-2.2.1.dist-info}/entry_points.txt +0 -0
- {bayesianflow_for_chem-2.0.5.dist-info → bayesianflow_for_chem-2.2.1.dist-info}/licenses/LICENSE +0 -0
- {bayesianflow_for_chem-2.0.5.dist-info → bayesianflow_for_chem-2.2.1.dist-info}/top_level.txt +0 -0
bayesianflow_for_chem/__init__.py
CHANGED
@@ -3,10 +3,8 @@
 """
 ChemBFN package.
 """
-import colorama
 from . import data, tool, train, scorer, spectra
 from .model import ChemBFN, MLP, EnsembleChemBFN
-from .cli import main_script
 
 __all__ = [
     "data",
@@ -18,7 +16,7 @@ __all__ = [
     "MLP",
     "EnsembleChemBFN",
 ]
-__version__ = "2.0.5"
+__version__ = "2.2.1"
 __author__ = "Nianze A. Tao (Omozawa Sueno)"
 
 
@@ -29,6 +27,9 @@ def main() -> None:
     :return:
     :rtype: None
     """
+    import colorama
+    from bayesianflow_for_chem.cli import main_script
+
     colorama.just_fix_windows_console()
     main_script(__version__)
     colorama.deinit()

bayesianflow_for_chem/cli.py
CHANGED
@@ -12,13 +12,8 @@ from pathlib import Path
 from functools import partial
 from typing import List, Tuple, Dict, Union, Callable
 import torch
-import lightning as L
 from rdkit.Chem import MolFromSmiles, CanonSmiles
-from torch.utils.data import DataLoader
-from lightning.pytorch import loggers
-from lightning.pytorch.callbacks import ModelCheckpoint
 from bayesianflow_for_chem import ChemBFN, MLP
-from bayesianflow_for_chem.train import Model
 from bayesianflow_for_chem.scorer import smiles_valid, Scorer
 from bayesianflow_for_chem.data import (
     VOCAB_COUNT,
@@ -32,7 +27,7 @@ from bayesianflow_for_chem.data import (
     collate,
     CSVData,
 )
-from bayesianflow_for_chem.tool import sample, inpaint
+from bayesianflow_for_chem.tool import sample, inpaint, optimise, adjust_lora_
 
 
 """
@@ -90,6 +85,7 @@ checkpoint_save_path = "home/user/project/ckpt"
 train_strategy = "auto" # or any strategy supported by Lightning, e.g., "ddp"
 accumulate_grad_batches = 1
 enable_progress_bar = false
+plugin_script = "" # define customised behaviours of dataset, datasetloader, etc in a python script
 
 # Remove this table if inference is unnecessary
 [inference]
@@ -99,9 +95,11 @@ sample_size = 1000 # the minimum number of samples you want
 sample_step = 100
 sample_method = "ODE:0.5" # ODE-solver with temperature of 0.5; another choice is "BFN"
 semi_autoregressive = false
+lora_scaling = 1.0 # LoRA scaling if applied
 guidance_objective = [-0.023, 0.09, 0.113] # if no objective is needed set it to empty array []
 guidance_objective_strength = 4.0 # unnecessary if guidance_objective = []
 guidance_scaffold = "c1ccccc1" # if no scaffold is used set it to empty string ""
+sample_template = "" # template for mol2mol task; leave it blank if scaffold is used
 unwanted_token = []
 exclude_invalid = true # to only store valid samples
 exclude_duplicate = true # to only store unique samples
@@ -118,6 +116,32 @@ madmadmadmadmadmadmadmadmadmadmadmadmadmadmad
 madmadmadmadmadmadmadmadmadmadmadmadmadmadmad
 """
 
+_ALLOWED_PLUGINS = [
+    "collate_fn",
+    "num_workers",
+    "max_sequence_length",
+    "shuffle",
+    "CustomData",
+]
+
+
+def _load_plugin(plugin_file: str) -> Dict[str, Union[int, Callable, object, None]]:
+    if not plugin_file:
+        return {n: None for n in _ALLOWED_PLUGINS}
+    from importlib import util as iutil
+
+    spec = iutil.spec_from_file_location(Path(plugin_file).stem, plugin_file)
+    plugins = iutil.module_from_spec(spec)
+    spec.loader.exec_module(plugins)
+    plugin_names: List[str] = plugins.__all__
+    plugin_dict = {}
+    for n in _ALLOWED_PLUGINS:
+        if n in plugin_names:
+            plugin_dict[n] = getattr(plugins, n)
+        else:
+            plugin_dict[n] = None
+    return plugin_dict
+
 
 def parse_cli(version: str) -> argparse.Namespace:
     """
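To make the plugin mechanism concrete, a minimal sketch of a plugin_script file is shown below; it is an illustration, not part of the package. `_load_plugin` imports the file, reads its `__all__`, and keeps only the names listed in `_ALLOWED_PLUGINS`; every concrete value chosen here (worker count, sequence length, the trivial `CustomData` subclass) is an assumption for the example.

# my_plugins.py -- illustrative sketch; only names exported via __all__ and listed in
# _ALLOWED_PLUGINS ("collate_fn", "num_workers", "max_sequence_length", "shuffle",
# "CustomData") are read by _load_plugin.
from bayesianflow_for_chem.data import CSVData, collate

__all__ = ["collate_fn", "num_workers", "max_sequence_length", "shuffle", "CustomData"]

num_workers = 2            # forwarded to DataLoader(num_workers=...)
shuffle = True             # forwarded to DataLoader(shuffle=...)
max_sequence_length = 128  # skips the full-CSV scan for the longest tokenised row


def collate_fn(batch):
    # reuse the library default here; replace with custom batching/padding if needed
    return collate(batch)


class CustomData(CSVData):
    # instantiated by the CLI as CustomData(dataset_file); .map(...) is applied afterwards
    pass
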
@@ -130,7 +154,7 @@ def parse_cli(version: str) -> argparse.Namespace:
     """
     parser = argparse.ArgumentParser(
         description="Madmol: a CLI molecular design tool for "
-        "de novo design, R-group replacement, and sequence in-filling, "
+        "de novo design, R-group replacement, molecule optimisation, and sequence in-filling, "
         "based on generative route of ChemBFN method. "
         "Let's make some craziest molecules.",
         epilog=f"Madmol {version}, developed in Hiroshima University by chemists for chemists. "
@@ -157,7 +181,7 @@ def parse_cli(version: str) -> argparse.Namespace:
         "-D",
         "--dryrun",
         action="store_true",
-        help="dry-run to check the configurations",
+        help="dry-run to check the configurations and exit",
     )
     parser.add_argument("-V", "--version", action="version", version=version)
     return parser.parse_args()
@@ -265,6 +289,14 @@ def load_runtime_config(
                 f"\033[0;31mCritical\033[0;0m in {config_file}: Restart checkpoint file {ckpt_file} does not exist."
             )
             flag_critical += 1
+        # ↓ added in v2.2.0; need to be compatible with old versions.
+        plugin_script: str = config["train"].get("plugin_script", "")
+        if plugin_script:
+            if not os.path.exists(plugin_script):
+                print(
+                    f"\033[0;31mCritical\033[0;0m in {config_file}: Plugin script {plugin_script} does not exist."
+                )
+                flag_critical += 1
     if "inference" in config:
         if not "train" in config:
             if not isinstance(config["inference"]["sequence_length"], int):
@@ -284,6 +316,14 @@ def load_runtime_config(
                 f"\033[0;33mWarning\033[0;0m in {config_file}: Directory {result_dir} to save the result does not exist."
             )
             flag_warning += 1
+        if (
+            config["inference"]["guidance_scaffold"] != ""
+            and config["inference"]["sample_template"] != ""
+        ):
+            print(
+                f"\033[0;33mWarning\033[0;0m in {config_file}: Inpaint task or mol2mol task?"
+            )
+            flag_warning += 1
     return config, flag_critical, flag_warning
 
 
@@ -325,14 +365,15 @@ def main_script(version: str) -> None:
             )
             flag_warning += 1
         if not os.path.exists(runtime_config["train"]["checkpoint_save_path"]):
-            os.makedirs(runtime_config["train"]["checkpoint_save_path"])
+            if not parser.dryrun: # only create it in real tasks
+                os.makedirs(runtime_config["train"]["checkpoint_save_path"])
         else:
             if not model_config["ChemBFN"]["base_model"]:
                 print(
                     f"\033[0;33mWarning\033[0;0m in {parser.model_config}: You should load a pretrained ChemBFN model."
                 )
                 flag_warning += 1
-        if not model_config["MLP"]["base_model"]:
+        if "MLP" in model_config and not model_config["MLP"]["base_model"]:
             print(
                 f"\033[0;33mWarning\033[0;0m in {parser.model_config}: You should load a pretrained MLP."
             )
@@ -340,13 +381,17 @@ def main_script(version: str) -> None:
     if "inference" in runtime_config:
         if runtime_config["inference"]["guidance_objective"]:
             if not "MLP" in model_config:
-                print(
+                print(
+                    f"\033[0;33mWarning\033[0;0m in {parser.model_config}: Oh no, you don't have a MLP."
+                )
                 flag_warning += 1
     if parser.dryrun:
         if flag_critical != 0:
             print("Configuration check failed!")
         elif flag_warning != 0:
-            print(
+            print(
+                "Your job will probably run, but it may not follow your expectations."
+            )
         else:
             print("Configuration check passed.")
         return
@@ -405,6 +450,15 @@ def main_script(version: str) -> None:
         mlp = None
     # ------- train -------
     if "train" in runtime_config:
+        import lightning as L
+        from torch.utils.data import DataLoader
+        from lightning.pytorch import loggers
+        from lightning.pytorch.callbacks import ModelCheckpoint
+        from bayesianflow_for_chem.train import Model
+
+        # ####### get plugins #######
+        plugin_file = runtime_config["train"].get("plugin_script", "")
+        plugins = _load_plugin(plugin_file)
         # ####### build scorer #######
         if (tokeniser_name == "smiles" or tokeniser_name == "safe") and runtime_config[
             "train"
@@ -418,30 +472,43 @@ def main_script(version: str) -> None:
         mol_tag = runtime_config["train"]["molecule_tag"]
         obj_tag = runtime_config["train"]["objective_tag"]
         dataset_file = runtime_config["train"]["dataset"]
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        if plugins["max_sequence_length"]:
+            lmax = plugins["max_sequence_length"]
+        else:
+            with open(dataset_file, "r") as db:
+                _data = db.readlines()
+            _header = _data[0]
+            _mol_idx = []
+            for i, tag in enumerate(_header.replace("\n", "").split(",")):
+                if tag == mol_tag:
+                    _mol_idx.append(i)
+            _data_len = []
+            for i in _data[1:]:
+                i = i.replace("\n", "").split(",")
+                _mol = ".".join([i[j] for j in _mol_idx])
+                _data_len.append(tokeniser(_mol).shape[-1])
+            lmax = max(_data_len)
+            del _data, _data_len, _header, _mol_idx # clear memory
+        if plugins["CustomData"] is not None:
+            dataset = plugins["CustomData"](dataset_file)
+        else:
+            dataset = CSVData(dataset_file)
         dataset.map(
             partial(_encode, mol_tag=mol_tag, obj_tag=obj_tag, tokeniser=tokeniser)
         )
         dataloader = DataLoader(
             dataset,
             runtime_config["train"]["batch_size"],
-            True,
-            num_workers=4,
-            collate_fn=
-
+            True if plugins["shuffle"] is None else plugins["shuffle"],
+            num_workers=4 if plugins["num_workers"] is None else plugins["num_workers"],
+            collate_fn=(
+                collate if plugins["collate_fn"] is None else plugins["collate_fn"]
+            ),
+            persistent_workers=(
+                True
+                if (plugins["num_workers"] is None or plugins["num_workers"] > 0)
+                else False
+            ),
         )
         # ####### build trainer #######
         logger_name = runtime_config["train"]["logger_name"].lower()
@@ -520,6 +587,8 @@ def main_script(version: str) -> None:
     if "train" in runtime_config:
         bfn = model.model
         mlp = model.mlp
+    # ↓ added in v2.1.0; need to be compatible with old versions
+    lora_scaling = runtime_config["inference"].get("lora_scaling", 1.0)
     # ####### strat inference #######
     bfn.semi_autoregressive = runtime_config["inference"]["semi_autoregressive"]
     _device = (
@@ -550,8 +619,16 @@ def main_script(version: str) -> None:
             x[:-1], (0, sequence_length - x.shape[-1] + 1), value=0
         )
         x = x[None, :].repeat(batch_size, 1)
+        # then sample template will be ignored.
+    elif runtime_config["inference"]["sample_template"]:
+        template = runtime_config["inference"]["sample_template"]
+        x = tokeniser(template)
+        x = torch.nn.functional.pad(x, (0, sequence_length - x.shape[-1]), value=0)
+        x = x[None, :].repeat(batch_size, 1)
     else:
         x = None
+    if bfn.lora_enabled:
+        adjust_lora_(bfn, lora_scaling)
     mols = []
     while len(mols) < runtime_config["inference"]["sample_size"]:
         if x is None:
@@ -567,7 +644,7 @@ def main_script(version: str) -> None:
                 method=sample_method,
                 allowed_tokens=allowed_token,
             )
-
+        elif runtime_config["inference"]["guidance_scaffold"]:
             s = inpaint(
                 bfn,
                 x,
@@ -579,6 +656,18 @@ def main_script(version: str) -> None:
                 method=sample_method,
                 allowed_tokens=allowed_token,
             )
+        else:
+            s = optimise(
+                bfn,
+                x,
+                sample_step,
+                y,
+                guidance_strength,
+                _device,
+                vocab_keys,
+                method=sample_method,
+                allowed_tokens=allowed_token,
+            )
         if runtime_config["inference"]["exclude_invalid"]:
             s = [i for i in s if i]
         if tokeniser_name == "smiles" or tokeniser_name == "safe":

bayesianflow_for_chem/data.py
CHANGED
@@ -61,7 +61,7 @@ def load_vocab(
 }
 
 
-_DEFUALT_VOCAB = load_vocab(__filedir__ / "vocab.txt")
+_DEFUALT_VOCAB = load_vocab(__filedir__ / "_data/vocab.txt")
 VOCAB_KEYS: List[str] = _DEFUALT_VOCAB["vocab_keys"]
 VOCAB_DICT: Dict[str, int] = _DEFUALT_VOCAB["vocab_dict"]
 VOCAB_COUNT: int = _DEFUALT_VOCAB["vocab_count"]

bayesianflow_for_chem/model.py
CHANGED
@@ -659,12 +659,91 @@ class ChemBFN(nn.Module):
         return (-logits.gather(-1, x[..., :1]).squeeze(-1)).mean()
 
     @staticmethod
-    def
+    def _reshape(y: Tensor) -> Tensor:
         assert y.dim() <= 3  # this doesn't work if the model is frezen in JIT.
         if y.dim() == 2:
             return y[:, None, :]
         return y
 
+    def _process(
+        self,
+        theta: Tensor,
+        mask: Optional[Tuple[Tensor, Tensor]],
+        y: Optional[Tensor],
+        sample_step: int,
+        guidance_strength: float,
+        token_mask: Optional[Tensor],
+    ) -> Tuple[Tensor, Tensor]:
+        # BFN inference process.
+        #
+        # theta: piror distribution; shape: (n_b, n_t, n_vocab)
+        # mask: masked condition distribution; shape: (n_b, n_t, n_vocab)
+        #       condition distribution mask; shape: (n_b, n_t, 1)
+        n_b = theta.shape[0]
+        if y is not None:
+            y = self._reshape(y)
+        for i in torch.linspace(1, sample_step, sample_step, device=self.beta.device):
+            t = (i - 1).view(1, 1, 1).repeat(n_b, 1, 1) / sample_step
+            p = self.discrete_output_distribution(theta, t, y, guidance_strength)
+            if token_mask is not None:
+                p = p.masked_fill_(token_mask, 0.0)
+            alpha = self.calc_discrete_alpha(t, t + 1 / sample_step)
+            e_k = nn.functional.one_hot(torch.argmax(p, -1), self.K).float()
+            mu = alpha * (self.K * e_k - 1)
+            sigma = (alpha * self.K).sqrt()
+            theta = (mu + sigma * torch.randn_like(mu)).exp() * theta
+            theta = theta / theta.sum(-1, True)
+            if mask is not None:
+                x_onehot, x_mask = mask
+                theta = x_onehot + (1 - x_mask) * theta
+        t_final = torch.ones((n_b, 1, 1), device=self.beta.device)
+        p = self.discrete_output_distribution(theta, t_final, y, guidance_strength)
+        entropy = -(p * p.log()).sum(-1).mean(-1)
+        if token_mask is not None:
+            p = p.masked_fill_(token_mask, 0.0)
+        return torch.argmax(p, -1), entropy
+
+    def _ode_process(
+        self,
+        z: Tensor,
+        mask: Optional[Tuple[Tensor, Tensor]],
+        y: Optional[Tensor],
+        sample_step: int,
+        guidance_strength: float,
+        token_mask: Optional[Tensor],
+        temperature: float,
+    ) -> Tuple[Tensor, Tensor]:
+        # ODE-solver engaged inference process.
+        #
+        # z: prior latent vector; shape: (n_b, n_t, n_vocab)
+        # mask: masked condition distribution; shape: (n_b, n_t, n_vocab)
+        #       condition distribution mask; shape: (n_b, n_t, 1)
+        n_b = z.shape[0]
+        if y is not None:
+            y = self._reshape(y)
+        for i in torch.linspace(1, sample_step, sample_step, device=self.beta.device):
+            t = (i - 1).view(1, 1, 1).repeat(n_b, 1, 1) / sample_step
+            theta = softmax(z, -1)
+            if mask is not None:
+                x_onehot, x_mask = mask
+                theta = x_onehot + (1 - x_mask) * theta
+            beta = self.calc_beta(t + 1 / sample_step)
+            p = self.discrete_output_distribution(theta, t, y, guidance_strength)
+            if token_mask is not None:
+                p = p.masked_fill_(token_mask, 0.0)
+            u = torch.randn_like(z)
+            z = (self.K * p - 1) * beta + (self.K * beta * temperature).sqrt() * u
+        t_final = torch.ones((n_b, 1, 1), device=self.beta.device)
+        theta = softmax(z, -1)
+        if mask is not None:
+            x_onehot, x_mask = mask
+            theta = x_onehot + (1 - x_mask) * theta
+        p = self.discrete_output_distribution(theta, t_final, y, guidance_strength)
+        entropy = -(p * p.log()).sum(-1).mean(-1)
+        if token_mask is not None:
+            p = p.masked_fill_(token_mask, 0.0)
+        return torch.argmax(p, -1), entropy
+
     @torch.jit.export
     def sample(
         self,
@@ -676,48 +755,30 @@ class ChemBFN(nn.Module):
         token_mask: Optional[Tensor] = None,
     ) -> Tuple[Tensor, Tensor]:
         """
-        Sample from a piror distribution.
+        Sample from a uniform piror distribution.
 
         :param batch_size: batch size
         :param sequence_size: max sequence length
-        :param y: conditioning vector;
+        :param y: conditioning vector; shape: (n_b, 1, n_f) or (n_b, n_f)
         :param sample_step: number of sampling steps
         :param guidance_strength: strength of conditional generation. It is not used if y is null.
         :param token_mask: token mask assigning unwanted token(s) with `True`;
-
+                           shape: (1, 1, n_vocab)
         :type batch_size: int
         :type sequence_size: int
         :type y: torch.Tensor | None
         :type sample_step: int
         :type guidance_strength: float
         :type token_mask: torch.Tensor | None
-        :return: sampled token indices;
-                 entropy of the tokens;
+        :return: sampled token indices; shape: (n_b, n_t) \n
+                 entropy of the tokens; shape: (n_b)
         :rtype: tuple
         """
         theta = (
             torch.ones((batch_size, sequence_size, self.K), device=self.beta.device)
             / self.K
         )
-
-        y = self.reshape_y(y)
-        for i in torch.linspace(1, sample_step, sample_step, device=self.beta.device):
-            t = (i - 1).view(1, 1, 1).repeat(batch_size, 1, 1) / sample_step
-            p = self.discrete_output_distribution(theta, t, y, guidance_strength)
-            if token_mask is not None:
-                p = p.masked_fill_(token_mask, 0.0)
-            alpha = self.calc_discrete_alpha(t, t + 1 / sample_step)
-            e_k = nn.functional.one_hot(torch.argmax(p, -1), self.K).float()
-            mu = alpha * (self.K * e_k - 1)
-            sigma = (alpha * self.K).sqrt()
-            theta = (mu + sigma * torch.randn_like(mu)).exp() * theta
-            theta = theta / theta.sum(-1, True)
-        t_final = torch.ones((batch_size, 1, 1), device=self.beta.device)
-        p = self.discrete_output_distribution(theta, t_final, y, guidance_strength)
-        entropy = -(p * p.log()).sum(-1).mean(-1)
-        if token_mask is not None:
-            p = p.masked_fill_(token_mask, 0.0)
-        return torch.argmax(p, -1), entropy
+        return self._process(theta, None, y, sample_step, guidance_strength, token_mask)
 
     @torch.jit.export
     def ode_sample(
@@ -735,11 +796,11 @@ class ChemBFN(nn.Module):
 
         :param batch_size: batch size
         :param sequence_size: max sequence length
-        :param y: conditioning vector;
+        :param y: conditioning vector; shape: (n_b, 1, n_f) or (n_b, n_f)
         :param sample_step: number of sampling steps
         :param guidance_strength: strength of conditional generation. It is not used if y is null.
         :param token_mask: token mask assigning unwanted token(s) with `True`;
-
+                           shape: (1, 1, n_vocab)
         :param temperature: sampling temperature
         :type batch_size: int
         :type sequence_size: int
@@ -748,29 +809,14 @@ class ChemBFN(nn.Module):
         :type guidance_strength: float
         :type token_mask: torch.Tensor | None
         :type temperature: float
-        :return: sampled token indices;
-                 entropy of the tokens;
+        :return: sampled token indices; shape: (n_b, n_t) \n
+                 entropy of the tokens; shape: (n_b)
         :rtype: tuple
         """
         z = torch.zeros((batch_size, sequence_size, self.K), device=self.beta.device)
-
-        y
-
-            t = (i - 1).view(1, 1, 1).repeat(batch_size, 1, 1) / sample_step
-            theta = torch.softmax(z, -1)
-            beta = self.calc_beta(t + 1 / sample_step)
-            p = self.discrete_output_distribution(theta, t, y, guidance_strength)
-            if token_mask is not None:
-                p = p.masked_fill_(token_mask, 0.0)
-            u = torch.randn_like(z)
-            z = (self.K * p - 1) * beta + (self.K * beta * temperature).sqrt() * u
-        t_final = torch.ones((batch_size, 1, 1), device=self.beta.device)
-        theta = torch.softmax(z, -1)
-        p = self.discrete_output_distribution(theta, t_final, y, guidance_strength)
-        entropy = -(p * p.log()).sum(-1).mean(-1)
-        if token_mask is not None:
-            p = p.masked_fill_(token_mask, 0.0)
-        return torch.argmax(p, -1), entropy
+        return self._ode_process(
+            z, None, y, sample_step, guidance_strength, token_mask, temperature
+        )
 
     @torch.jit.export
     def inpaint(
@@ -800,30 +846,12 @@ class ChemBFN(nn.Module):
         :rtype: tuple
         """
         n_b, n_t = x.shape
-
+        x_mask = (x != 0).float()[..., None]
         theta = torch.ones((n_b, n_t, self.K), device=x.device) / self.K
-        x_onehot = nn.functional.one_hot(x, self.K) *
-        theta = x_onehot + (1 -
-
-
-        for i in torch.linspace(1, sample_step, sample_step, device=x.device):
-            t = (i - 1).view(1, 1, 1).repeat(n_b, 1, 1) / sample_step
-            p = self.discrete_output_distribution(theta, t, y, guidance_strength)
-            if token_mask is not None:
-                p = p.masked_fill_(token_mask, 0.0)
-            alpha = self.calc_discrete_alpha(t, t + 1 / sample_step)
-            e_k = nn.functional.one_hot(torch.argmax(p, -1), self.K).float()
-            mu = alpha * (self.K * e_k - 1)
-            sigma = (alpha * self.K).sqrt()
-            theta = (mu + sigma * torch.randn_like(mu)).exp() * theta
-            theta = theta / theta.sum(-1, True)
-            theta = x_onehot + (1 - mask) * theta
-        t_final = torch.ones((n_b, 1, 1), device=x.device)
-        p = self.discrete_output_distribution(theta, t_final, y, guidance_strength)
-        entropy = -(p * p.log()).sum(-1).mean(-1)
-        if token_mask is not None:
-            p = p.masked_fill_(token_mask, 0.0)
-        return torch.argmax(p, -1), entropy
+        x_onehot = nn.functional.one_hot(x, self.K) * x_mask
+        theta = x_onehot + (1 - x_mask) * theta
+        mask = (x_onehot, x_mask)
+        return self._process(theta, mask, y, sample_step, guidance_strength, token_mask)
 
     @torch.jit.export
     def ode_inpaint(
@@ -856,29 +884,80 @@ class ChemBFN(nn.Module):
         :rtype: tuple
         """
         n_b, n_t = x.shape
-
-        x_onehot = nn.functional.one_hot(x, self.K) *
+        x_mask = (x != 0).float()[..., None]
+        x_onehot = nn.functional.one_hot(x, self.K) * x_mask
         z = torch.zeros((n_b, n_t, self.K), device=self.beta.device)
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        mask = (x_onehot, x_mask)
+        return self._ode_process(
+            z, mask, y, sample_step, guidance_strength, token_mask, temperature
+        )
+
+    @torch.jit.export
+    def optimise(
+        self,
+        x: Tensor,
+        y: Optional[Tensor] = None,
+        sample_step: int = 100,
+        guidance_strength: float = 4.0,
+        token_mask: Optional[Tensor] = None,
+    ) -> Tuple[Tensor, Tensor]:
+        """
+        Optimise the template molecule (mol2mol). \n
+        This method is equivalent to sampling from a customised prior distribution.
+
+        :param x: categorical indices of template; shape: (n_b, n_t)
+        :param y: conditioning vector; shape: (n_b, 1, n_f) or (n_b, n_f)
+        :param sample_step: number of sampling steps
+        :param guidance_strength: strength of conditional generation. It is not used if y is null.
+        :param token_mask: token mask assigning unwanted token(s) with `True`;
+                           shape: (1, 1, n_vocab)
+        :type x: torch.Tensor
+        :type y: torch.Tensor | None
+        :type sample_step: int
+        :type guidance_strength: float
+        :type token_mask: torch.Tensor | None
+        :return: sampled token indices; shape: (n_b, n_t) \n
+                 entropy of the tokens; shape: (n_b)
+        :rtype: tuple
+        """
+        x_onehot = nn.functional.one_hot(x, self.K).float()
+        theta = softmax(x_onehot, -1)
+        return self._process(theta, None, y, sample_step, guidance_strength, token_mask)
+
+    @torch.jit.export
+    def ode_optimise(
+        self,
+        x: Tensor,
+        y: Optional[Tensor] = None,
+        sample_step: int = 100,
+        guidance_strength: float = 4.0,
+        token_mask: Optional[Tensor] = None,
+        temperature: float = 0.5,
+    ) -> Tuple[Tensor, Tensor]:
+        """
+        ODE mol2mol.
+
+        :param x: categorical indices of template; shape: (n_b, n_t)
+        :param y: conditioning vector; shape: (n_b, 1, n_f) or (n_b, n_f)
+        :param sample_step: number of sampling steps
+        :param guidance_strength: strength of conditional generation. It is not used if y is null.
+        :param token_mask: token mask assigning unwanted token(s) with `True`;
+                           shape: (1, 1, n_vocab)
+        :param temperature: sampling temperature
+        :type x: torch.Tensor
+        :type y: torch.Tensor | None
+        :type sample_step: int
+        :type guidance_strength: float
+        :type token_mask: torch.Tensor | None
+        :type temperature: float
+        :return: sampled token indices; shape: (n_b, n_t) \n
+                 entropy of the tokens; shape: (n_b)
+        :rtype: tuple
+        """
+        z = nn.functional.one_hot(x, self.K).float()
+        return self._ode_process(
+            z, None, y, sample_step, guidance_strength, token_mask, temperature
+        )
 
     def inference(
         self, x: Tensor, mlp: MLP, embed_fn: Optional[Callable[[Tensor], Tensor]] = None
@@ -1052,22 +1131,6 @@ class EnsembleChemBFN(ChemBFN):
                 module.lora_dropout = None
             v.lora_enabled = False
 
-    def construct_y(
-        self, c: Union[List[Tensor], Dict[str, Tensor]]
-    ) -> Dict[str, Tensor]:
-        assert (
-            isinstance(c, dict) is self._label_is_dict
-        ), f"`c` should be a {'`dict` instance' if self._label_is_dict else '`list` instance'} but got {type(c)} instand."
-        out: Dict[str, Tensor] = {}
-        if isinstance(c, list):
-            c = dict(zip([f"val_{i}" for i in range(len(c))], c))
-        for name, model in self.cond_heads.items():
-            y = model.forward(c[name])
-            if y.dim() == 2:
-                y = y[:, None, :]
-            out[name] = y
-        return out
-
     def discrete_output_distribution(
         self, theta: Tensor, t: Tensor, y: Dict[str, Tensor], w: float
     ) -> Tensor:
@@ -1102,8 +1165,24 @@ class EnsembleChemBFN(ChemBFN):
             p_cond += p_cond_ * self.adapter_weights[name]
         return softmax((1 + w) * p_cond - w * p_uncond, -1)
 
+    def _map_to_dict(
+        self, c: Union[List[Tensor], Dict[str, Tensor]]
+    ) -> Dict[str, Tensor]:
+        assert (
+            isinstance(c, dict) is self._label_is_dict
+        ), f"`c` should be a {'`dict` instance' if self._label_is_dict else '`list` instance'} but got {type(c)} instand."
+        out: Dict[str, Tensor] = {}
+        if isinstance(c, list):
+            c = dict(zip([f"val_{i}" for i in range(len(c))], c))
+        for name, model in self.cond_heads.items():
+            y = model.forward(c[name])
+            if y.dim() == 2:
+                y = y[:, None, :]
+            out[name] = y
+        return out
+
     @staticmethod
-    def
+    def _reshape(y: Dict[str, Tensor]) -> Dict[str, Tensor]:
         for k in y:
             assert y[k].dim() <= 3
             if y[k].dim() == 2:
@@ -1139,7 +1218,7 @@ class EnsembleChemBFN(ChemBFN):
                 entropy of the tokens; shape: (n_b)
         :rtype: tuple
        """
-        y = self.
+        y = self._map_to_dict(conditions)
        return super().sample(
            batch_size, sequence_size, y, sample_step, guidance_strength, token_mask
        )
@@ -1176,7 +1255,7 @@ class EnsembleChemBFN(ChemBFN):
                 entropy of the tokens; shape: (n_b)
         :rtype: tuple
        """
-        y = self.
+        y = self._map_to_dict(conditions)
        return super().ode_sample(
            batch_size,
            sequence_size,
@@ -1213,7 +1292,7 @@ class EnsembleChemBFN(ChemBFN):
                 entropy of the tokens; shape: (n_b)
         :rtype: tuple
        """
-        y = self.
+        y = self._map_to_dict(conditions)
        return super().inpaint(x, y, sample_step, guidance_strength, token_mask)

     @torch.inference_mode()
@@ -1245,11 +1324,76 @@ class EnsembleChemBFN(ChemBFN):
                 entropy of the tokens; shape: (n_b)
         :rtype: tuple
        """
-        y = self.
+        y = self._map_to_dict(conditions)
        return super().ode_inpaint(
            x, y, sample_step, guidance_strength, token_mask, temperature
        )

+    @torch.inference_mode()
+    def optimise(
+        self,
+        x: Tensor,
+        conditions: Union[List[Tensor], Dict[str, Tensor]],
+        sample_step: int = 100,
+        guidance_strength: float = 4.0,
+        token_mask: Optional[Tensor] = None,
+    ) -> Tuple[Tensor, Tensor]:
+        """
+        Optimise the template molecule (mol2mol). \n
+        This method is equivalent to sampling from a customised prior distribution.
+
+        :param x: categorical indices of template; shape: (n_b, n_t)
+        :param conditions: conditioning vector; shape: (n_b, n_c) * n_h
+        :param sample_step: number of sampling steps
+        :param guidance_strength: strength of conditional generation. It is not used if y is null.
+        :param token_mask: token mask assigning unwanted token(s) with `True`;
+                           shape: (1, 1, n_vocab)
+        :type x: torch.Tensor
+        :type y: torch.Tensor | None
+        :type sample_step: int
+        :type guidance_strength: float
+        :type token_mask: torch.Tensor | None
+        :return: sampled token indices; shape: (n_b, n_t) \n
+                 entropy of the tokens; shape: (n_b)
+        :rtype: tuple
+        """
+        y = self._map_to_dict(conditions)
+        return super().optimise(x, y, sample_step, guidance_strength, token_mask)
+
+    @torch.inference_mode()
+    def ode_optimise(
+        self,
+        x: Tensor,
+        conditions: Union[List[Tensor], Dict[str, Tensor]],
+        sample_step: int = 100,
+        guidance_strength: float = 4.0,
+        token_mask: Optional[Tensor] = None,
+        temperature: float = 0.5,
+    ) -> Tuple[Tensor, Tensor]:
+        """
+        ODE inpainting.
+
+        :param x: categorical indices of template; shape: (n_b, n_t)
+        :param conditions: conditioning vector; shape: (n_b, n_c) * n_h
+        :param sample_step: number of sampling steps
+        :param guidance_strength: strength of conditional generation. It is not used if y is null.
+        :param token_mask: token mask; shape: (1, 1, n_vocab)
+        :param temperature: sampling temperature
+        :type x: torch.Tensor
+        :type conditions: list | dict
+        :type sample_step: int
+        :type guidance_strength: float
+        :type token_mask: torch.Tensor | None
+        :type temperature: float
+        :return: sampled token indices; shape: (n_b, n_t) \n
+                 entropy of the tokens; shape: (n_b)
+        :rtype: tuple
+        """
+        y = self._map_to_dict(conditions)
+        return super().ode_optimise(
+            x, y, sample_step, guidance_strength, token_mask, temperature
+        )
+
     def quantise(
         self, quantise_method: Optional[Callable[[ChemBFN], None]] = None
     ) -> None:

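Read directly off the new helpers above (nothing beyond the shown code), the per-step updates that `_process` and `_ode_process` centralise — and that `sample`, `inpaint`, `optimise` and their ODE variants now share — can be written compactly, with $K$ the vocabulary size, $p_\phi$ the guided and token-masked output distribution, $\alpha_i$ from `calc_discrete_alpha`, $\beta_{i+1}$ from `calc_beta`, and $T$ the sampling temperature:

% discrete (BFN) path in _process, step i -> i+1
k^\ast = \operatorname{argmax}_k\, p_\phi(k \mid \theta_i, t_i), \qquad
\epsilon \sim \mathcal{N}(0,\mathbf{I}), \qquad
\theta_{i+1} \propto \exp\!\bigl(\alpha_i (K\,\mathbf{e}_{k^\ast} - \mathbf{1}) + \sqrt{\alpha_i K}\,\epsilon\bigr) \odot \theta_i

% ODE path in _ode_process
z_{i+1} = (K\,p_\phi - \mathbf{1})\,\beta_{i+1} + \sqrt{K\,\beta_{i+1}\,T}\;u, \qquad
u \sim \mathcal{N}(0,\mathbf{I}), \qquad
\theta_{i+1} = \operatorname{softmax}(z_{i+1})

`inpaint` re-imposes the one-hot template on the masked positions of $\theta$ after every step, while `optimise` only changes the prior: instead of the uniform $\theta_0 = \mathbf{1}/K$ it starts from the softmaxed one-hot template (or, on the ODE path, uses the one-hot vectors directly as the initial logits $z_0$).
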
bayesianflow_for_chem/tool.py
CHANGED
@@ -9,7 +9,6 @@ import warnings
 from pathlib import Path
 from typing import List, Dict, Tuple, Union, Optional
 import torch
-import colorama
 import numpy as np
 from torch import cuda, Tensor, softmax
 from torch.utils.data import DataLoader
@@ -24,15 +23,7 @@ from rdkit.Chem import (
     AddHs,
     Mol,
 )
-from rdkit.Chem.Scaffolds.MurckoScaffold import MurckoScaffoldSmiles
-from sklearn.metrics import (
-    roc_auc_score,
-    auc,
-    precision_recall_curve,
-    r2_score,
-    mean_absolute_error,
-    root_mean_squared_error,
-)
+from rdkit.Chem.Scaffolds.MurckoScaffold import MurckoScaffoldSmiles
 from .data import VOCAB_KEYS
 from .model import ChemBFN, MLP, EnsembleChemBFN
 
@@ -45,6 +36,68 @@ def _find_device() -> torch.device:
     return torch.device("cpu")
 
 
+def _parse_and_assert_param(
+    model: Union[ChemBFN, EnsembleChemBFN],
+    y: Optional[Union[Tensor, Dict[str, Tensor], List[Tensor]]],
+    method: str,
+) -> Optional[float]:
+    assert method.split(":")[0].lower() in ("ode", "bfn")
+    if isinstance(model, EnsembleChemBFN):
+        assert y is not None, "conditioning is required while using an ensemble model."
+        assert isinstance(y, list) or isinstance(y, dict)
+    else:
+        assert isinstance(y, Tensor) or (y is None)
+    if "ode" in method.lower():
+        tp = float(method.split(":")[-1])
+        assert tp > 0, "Sampling temperature should be higher than 0."
+        return tp
+    return None
+
+
+def _map_to_device(
+    y: Optional[Union[Tensor, Dict[str, Tensor], List[Tensor]]],
+    device: Union[str, torch.device],
+) -> Optional[Union[Tensor, Dict[str, Tensor], List[Tensor]]]:
+    if y is not None:
+        if isinstance(y, Tensor):
+            y = y.to(device)
+        elif isinstance(y, list):
+            y = [i.to(device) for i in y]
+        elif isinstance(y, dict):
+            y = {k: v.to(device) for k, v in y.items()}
+        else:
+            raise NotImplementedError
+    return y
+
+
+def _build_token_mask(
+    allowed_tokens: Union[str, List[str]],
+    vocab_keys: List[str],
+    device: Union[str, torch.tensor],
+) -> Optional[Tensor]:
+    if isinstance(allowed_tokens, list):
+        token_mask = [0 if i in allowed_tokens else 1 for i in vocab_keys]
+        token_mask = torch.tensor([[token_mask]], dtype=torch.bool).to(device)
+    else:
+        token_mask = None
+    return token_mask
+
+
+def _token_to_seq(
+    tokens: Tensor, entropy: Tensor, vocab_keys: List[str], separator: str, sort: bool
+) -> List[str]:
+    if sort:
+        sorted_idx = entropy.argsort(stable=True)
+        tokens = tokens[sorted_idx]
+    return [
+        separator.join([vocab_keys[i] for i in j])
+        .split("<start>" + separator)[-1]
+        .split(separator + "<end>")[0]
+        .replace("<pad>", "")
+        for j in tokens
+    ]
+
+
 @torch.no_grad()
 def test(
     model: ChemBFN,
@@ -86,6 +139,12 @@ def test(
         predict_y.append(y_hat.detach().to("cpu"))
     predict_y, label_y = torch.cat(predict_y, 0), torch.cat(label_y, 0).split(1, -1)
     if mode == "regression":
+        from sklearn.metrics import (
+            r2_score,
+            mean_absolute_error,
+            root_mean_squared_error,
+        )
+
         predict_y = [
             predict[label_y[i] != torch.inf]
             for (i, predict) in enumerate(predict_y.split(1, -1))
@@ -99,6 +158,8 @@ def test(
         r2 = [r2_score(label, predict) for (label, predict) in y_zipped]
         return {"MAE": mae, "RMSE": rmse, "R^2": r2}
     if mode == "classification":
+        from sklearn.metrics import roc_auc_score, auc, precision_recall_curve
+
         n_c = len(label_y)
         predict_y = predict_y.chunk(n_c, -1)
         y_zipped = list(zip(label_y, predict_y))
@@ -142,7 +203,6 @@ def split_dataset(
     assert file.endswith(".csv")
     assert len(split_ratio) == 3
     assert method in ("random", "scaffold")
-    colorama.just_fix_windows_console()
     with open(file, "r") as f:
         data = list(csv.reader(f))
     header = data[0]
@@ -167,10 +227,8 @@ def split_dataset(
             # compute Bemis-Murcko scaffold
             if len(smiles_idx) > 1:
                 warnings.warn(
-                    "\033[32;1m"
                     f"We found {len(smiles_idx)} SMILES strings in a row!"
-                    " Only the first SMILES will be used to compute the molecular scaffold."
-                    "\033[0m",
+                    " Only the first SMILES will be used to compute the molecular scaffold.",
                     stacklevel=2,
                 )
             try:
@@ -197,10 +255,10 @@ def split_dataset(
     with open(file.replace(".csv", "_test.csv"), "w", newline="") as fte:
         writer = csv.writer(fte)
         writer.writerows([header] + test_set)
-
-
-
-
+    if val_set:
+        with open(file.replace(".csv", "_val.csv"), "w", newline="") as fva:
+            writer = csv.writer(fva)
+            writer.writerows([header] + val_set)
 
 
 @torch.no_grad()
@@ -219,7 +277,7 @@ def sample(
     sort: bool = False,
 ) -> List[str]:
     """
-    Sampling.
+    Sampling molecules.
 
     :param model: trained ChemBFN model
     :param batch_size: batch size
@@ -250,32 +308,12 @@ def sample(
     :return: a list of generated molecular strings
     :rtype: list
     """
-
-
-        assert y is not None, "conditioning is required while using an ensemble model."
-        assert isinstance(y, list) or isinstance(y, dict)
-    else:
-        assert isinstance(y, Tensor) or y is None
-    if device is None:
-        device = _find_device()
+    tp = _parse_and_assert_param(model, y, method)
+    device = _find_device() if device is None else device
     model.to(device).eval()
-
-
-
-    elif isinstance(y, list):
-        y = [i.to(device) for i in y]
-    elif isinstance(y, dict):
-        y = {k: v.to(device) for k, v in y.items()}
-    else:
-        raise NotImplementedError
-    if isinstance(allowed_tokens, list):
-        token_mask = [0 if i in allowed_tokens else 1 for i in vocab_keys]
-        token_mask = torch.tensor([[token_mask]], dtype=torch.bool).to(device)
-    else:
-        token_mask = None
-    if "ode" in method.lower():
-        tp = float(method.split(":")[-1])
-        assert tp > 0, "Sampling temperature should be higher than 0."
+    y = _map_to_device(y, device)
+    token_mask = _build_token_mask(allowed_tokens, vocab_keys, device)
+    if tp:
         tokens, entropy = model.ode_sample(
             batch_size, sequence_size, y, sample_step, guidance_strength, token_mask, tp
         )
@@ -283,16 +321,7 @@ def sample(
         tokens, entropy = model.sample(
             batch_size, sequence_size, y, sample_step, guidance_strength, token_mask
         )
-
-    sorted_idx = entropy.argsort(stable=True)
-    tokens = tokens[sorted_idx]
-    return [
-        seperator.join([vocab_keys[i] for i in j])
-        .split("<start>" + seperator)[-1]
-        .split(seperator + "<end>")[0]
-        .replace("<pad>", "")
-        for j in tokens
-    ]
+    return _token_to_seq(tokens, entropy, vocab_keys, seperator, sort)
 
 
 @torch.no_grad()
@@ -339,33 +368,13 @@ def inpaint(
     :return: a list of generated molecular strings
     :rtype: list
     """
-
-
-        assert y is not None, "conditioning is required while using an ensemble model."
-        assert isinstance(y, list) or isinstance(y, dict)
-    else:
-        assert isinstance(y, Tensor) or y is None
-    if device is None:
-        device = _find_device()
+    tp = _parse_and_assert_param(model, y, method)
+    device = _find_device() if device is None else device
     model.to(device).eval()
     x = x.to(device)
-
-
-
-    elif isinstance(y, list):
-        y = [i.to(device) for i in y]
-    elif isinstance(y, dict):
-        y = {k: v.to(device) for k, v in y.items()}
-    else:
-        raise NotImplementedError
-    if isinstance(allowed_tokens, list):
-        token_mask = [0 if i in allowed_tokens else 1 for i in vocab_keys]
-        token_mask = torch.tensor([[token_mask]], dtype=torch.bool).to(device)
-    else:
-        token_mask = None
-    if "ode" in method.lower():
-        tp = float(method.split(":")[-1])
-        assert tp > 0, "Sampling temperature should be higher than 0."
+    y = _map_to_device(y, device)
+    token_mask = _build_token_mask(allowed_tokens, vocab_keys, device)
+    if tp:
         tokens, entropy = model.ode_inpaint(
             x, y, sample_step, guidance_strength, token_mask, tp
         )
@@ -373,16 +382,68 @@ def inpaint(
         tokens, entropy = model.inpaint(
            x, y, sample_step, guidance_strength, token_mask
        )
-
-
-
-
-
-
-
-
-
-
+    return _token_to_seq(tokens, entropy, vocab_keys, separator, sort)
+
+
+@torch.no_grad()
+def optimise(
+    model: Union[ChemBFN, EnsembleChemBFN],
+    x: Tensor,
+    sample_step: int = 100,
+    y: Optional[Union[Tensor, Dict[str, Tensor], List[Tensor]]] = None,
+    guidance_strength: float = 4.0,
+    device: Union[str, torch.device, None] = None,
+    vocab_keys: List[str] = VOCAB_KEYS,
+    separator: str = "",
+    method: str = "BFN",
+    allowed_tokens: Union[str, List[str]] = "all",
+    sort: bool = False,
+) -> List[str]:
+    """
+    Optimising template molecules (mol2mol).
+
+    :param model: trained ChemBFN model
+    :param x: categorical indices of template; shape: (n_b, n_t)
+    :param sample_step: number of sampling steps
+    :param y: conditioning vector; shape: (n_b, 1, n_f) or (n_b, n_f) \n
+              or a list/`dict` of conditions; shape: (n_b, n_c) * n_h
+
+    :param guidance_strength: strength of conditional generation. It is not used if y is null.
+    :param device: hardware accelerator
+    :param vocab_keys: a list of (ordered) vocabulary
+    :param separator: token separator; default is `""`
+    :param method: sampling method chosen from `"ODE:x"` or `"BFN"` where `x` is the value of sampling temperature; default is `"BFN"`
+    :param allowed_tokens: a list of allowed tokens; default is `"all"`
+    :param sort: whether to sort the samples according to entropy values; default is `False`
+    :type model: bayesianflow_for_chem.model.ChemBFN | bayesianflow_for_chem.model.EnsembleChemBFN
+    :type x: torch.Tensor
+    :type sample_step: int
+    :type y: torch.Tensor | list | dict | None
+    :type guidance_strength: float
+    :type device: str | torch.device | None
+    :type vocab_keys: list
+    :type separator: str
+    :type method: str
+    :type allowed_tokens: str | list
+    :type sort: bool
+    :return: a list of generated molecular strings
+    :rtype: list
+    """
+    tp = _parse_and_assert_param(model, y, method)
+    device = _find_device() if device is None else device
+    model.to(device).eval()
+    x = x.to(device)
+    y = _map_to_device(y, device)
+    token_mask = _build_token_mask(allowed_tokens, vocab_keys, device)
+    if tp:
+        tokens, entropy = model.ode_optimise(
+            x, y, sample_step, guidance_strength, token_mask, tp
+        )
+    else:
+        tokens, entropy = model.optimise(
+            x, y, sample_step, guidance_strength, token_mask
+        )
+    return _token_to_seq(tokens, entropy, vocab_keys, separator, sort)
 
 
 def quantise_model_(model: ChemBFN) -> None:

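As a usage sketch of the new mol2mol entry point (not taken from the package documentation): `bfn` below stands for an already-loaded, trained `ChemBFN` instance and `x_template` for a `LongTensor` of vocabulary indices with shape `(1, n_t)` produced by the same tokeniser used in training; both are assumptions of this example.

# mol2mol sketch: regenerate variations of a template molecule with the new optimise() helper.
from bayesianflow_for_chem.tool import optimise

batch_size = 32
x = x_template.repeat(batch_size, 1)  # one copy of the template per requested sample
candidates = optimise(
    bfn,                   # trained ChemBFN (an EnsembleChemBFN would require y as a list/dict)
    x,
    sample_step=100,
    y=None,                # optional conditioning vector for guided optimisation
    method="ODE:0.5",      # ODE solver at temperature 0.5; "BFN" is the other choice
    sort=True,             # order the returned strings by per-sample entropy
)
print(candidates[:5])      # generated molecular strings
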
bayesianflow_for_chem/train.py
CHANGED
@@ -8,7 +8,6 @@ from typing import Dict, Tuple, Union, Optional
 import torch
 import torch.optim as op
 import torch.nn.functional as F
-from loralib import lora_state_dict, mark_only_lora_as_trainable
 from torch import Tensor
 from torch.optim.lr_scheduler import ReduceLROnPlateau
 from lightning import LightningModule
@@ -55,6 +54,8 @@ class Model(LightningModule):
         self.scorer = scorer
         self.save_hyperparameters(hparam, ignore=["model", "mlp", "scorer"])
         if model.lora_enabled:
+            from loralib import mark_only_lora_as_trainable
+
             mark_only_lora_as_trainable(self.model)
         self.use_scorer = self.scorer is not None
 
@@ -107,6 +108,8 @@ class Model(LightningModule):
         :rtype: None
         """
         if self.model.lora_enabled:
+            from loralib import lora_state_dict
+
             torch.save(
                 {
                     "lora_nn": lora_state_dict(self.model),
@@ -152,6 +155,8 @@ class Regressor(LightningModule):
         self.model.requires_grad_(not hparam["freeze"])
         self.save_hyperparameters(hparam, ignore=["model", "mlp"])
         if model.lora_enabled:
+            from loralib import mark_only_lora_as_trainable
+
             mark_only_lora_as_trainable(self.model)
         assert hparam["mode"] in ("regression", "classification")
 
@@ -231,6 +236,8 @@ class Regressor(LightningModule):
             )
         if not self.hparams.freeze:
             if self.model.lora_enabled:
+                from loralib import lora_state_dict
+
                 torch.save(
                     {
                         "lora_nn": lora_state_dict(self.model),

{bayesianflow_for_chem-2.0.5.dist-info → bayesianflow_for_chem-2.2.1.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: bayesianflow_for_chem
-Version: 2.0.5
+Version: 2.2.1
 Summary: Bayesian flow network framework for Chemistry
 Home-page: https://augus1999.github.io/bayesian-flow-network-for-chemistry/
 Author: Nianze A. Tao
@@ -54,12 +54,14 @@ This is the repository of the PyTorch implementation of ChemBFN model.
 
 [](https://pypi.org/project/bayesianflow-for-chem/)
 
+[](https://augus1999.github.io/bayesian-flow-network-for-chemistry/)
 
 ## Features
 
 ChemBFN provides the state-of-the-art functionalities of
 * SMILES or SELFIES-based *de novo* molecule generation
 * Protein sequence *de novo* generation
+* Template optimisation (mol2mol)
 * Classifier-free guidance conditional generation (single or multi-objective optimisation)
 * Context-guided conditional generation (inpaint)
 * Outstanding out-of-distribution chemical space sampling
@@ -71,6 +73,7 @@ in an all-in-one-model style.
 
 ## News
 
+* [09/10/2025] A web app [`chembfn_webui`](https://github.com/Augus1999/ChemBFN-WebUI) for hosting ChemBFN models is available on [PyPI](https://pypi.org/project/chembfn-webui/).
 * [30/01/2025] The package `bayesianflow_for_chem` is available on [PyPI](https://pypi.org/project/bayesianflow-for-chem/).
 * [21/01/2025] Our first paper has been accepted by [JCIM](https://pubs.acs.org/doi/10.1021/acs.jcim.4c01792).
 * [17/12/2024] The second paper of out-of-distribution generation is available on [arxiv.org](https://arxiv.org/abs/2412.11439).

bayesianflow_for_chem-2.2.1.dist-info/RECORD
ADDED
@@ -0,0 +1,15 @@
+bayesianflow_for_chem/__init__.py,sha256=l6twtXTdnix9SC8oL0Aqr0dYJ1uBUmMtk2QbUaRrhkE,642
+bayesianflow_for_chem/cli.py,sha256=jZPBUVwOl4qnjUMb3Lavu6owv1RP9jGdrDlM7KDA99c,26600
+bayesianflow_for_chem/data.py,sha256=RM3YaZUkTJ4-RNLH37GwhM4sRddOOusrC996ZPAe1lI,6590
+bayesianflow_for_chem/model.py,sha256=M35G4u4mX4btl9vOK3Iqs6yOSuIKI_OoCTmLhmjbwNk,57559
+bayesianflow_for_chem/scorer.py,sha256=gQFUlkyxitch02ntqcRh1ZS8aondKLynW5U6NfTQTb4,4084
+bayesianflow_for_chem/spectra.py,sha256=Ba9ib1aDvTtDYbH3b4d-lIty3ZSQMu7jwehuV2KmhwA,1781
+bayesianflow_for_chem/tool.py,sha256=VQo6xShwwmJbBsC-xTRL5ysCI7OLoOpmwQ-A7AuAPI8,23587
+bayesianflow_for_chem/train.py,sha256=7AU0A-eZwzSYsLyIe3OxGTNWPnhGpHmVUaQLplV2Fn8,9886
+bayesianflow_for_chem/_data/vocab.txt,sha256=HgtAZmpWYk4y8PqEVC4vqut1vE75DfRKE_10s2UW0rU,790
+bayesianflow_for_chem-2.2.1.dist-info/licenses/LICENSE,sha256=hIahDEOTzuHCU5J2nd07LWwkLW7Hko4UFO__ffsvB-8,34523
+bayesianflow_for_chem-2.2.1.dist-info/METADATA,sha256=YKuXgdklakB9DH4cFpTy3EbYdmo8dz9BgQQ4bI9CVQs,6476
+bayesianflow_for_chem-2.2.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+bayesianflow_for_chem-2.2.1.dist-info/entry_points.txt,sha256=N63RMoJsr8rxuKxc7Fj802SL8J5AlpCoPkS8E3IFPLI,54
+bayesianflow_for_chem-2.2.1.dist-info/top_level.txt,sha256=KHsanI3BMCt8D9Qpze2ycrF6nMa3PyojgO6eS1c8kco,22
+bayesianflow_for_chem-2.2.1.dist-info/RECORD,,

bayesianflow_for_chem-2.0.5.dist-info/RECORD
REMOVED
@@ -1,15 +0,0 @@
-bayesianflow_for_chem/__init__.py,sha256=U8ExDm5IRa-OPOLgt8VfjMDAiCmARiJonIKnZ8AVDQ0,612
-bayesianflow_for_chem/cli.py,sha256=g5pjXymycpchlgRE6SKr0LjTBjUl6MFHnFdQiKjcE3Q,22803
-bayesianflow_for_chem/data.py,sha256=Pl0gGWHmMKTKHpsxznvLgYPCwwlLNL7nqH19Vipjkxs,6584
-bayesianflow_for_chem/model.py,sha256=QF15BLpUjEpUCneTOHoU6MswvxArPfbFMiOHjwJ9JrM,52230
-bayesianflow_for_chem/scorer.py,sha256=gQFUlkyxitch02ntqcRh1ZS8aondKLynW5U6NfTQTb4,4084
-bayesianflow_for_chem/spectra.py,sha256=Ba9ib1aDvTtDYbH3b4d-lIty3ZSQMu7jwehuV2KmhwA,1781
-bayesianflow_for_chem/tool.py,sha256=JSlKrar1vVouTRUJBYU6lc5AQmAIaexvpUA3W3oQcKs,21284
-bayesianflow_for_chem/train.py,sha256=jYkhSguW50lrcTEydCQ20yig_mmc1j7WH9KmVwBCTAo,9727
-bayesianflow_for_chem/vocab.txt,sha256=HgtAZmpWYk4y8PqEVC4vqut1vE75DfRKE_10s2UW0rU,790
-bayesianflow_for_chem-2.0.5.dist-info/licenses/LICENSE,sha256=hIahDEOTzuHCU5J2nd07LWwkLW7Hko4UFO__ffsvB-8,34523
-bayesianflow_for_chem-2.0.5.dist-info/METADATA,sha256=qCd0LnD66eQbWlTvKDpHZHtJKmy9NQogm_k9_amEt2s,6057
-bayesianflow_for_chem-2.0.5.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
-bayesianflow_for_chem-2.0.5.dist-info/entry_points.txt,sha256=N63RMoJsr8rxuKxc7Fj802SL8J5AlpCoPkS8E3IFPLI,54
-bayesianflow_for_chem-2.0.5.dist-info/top_level.txt,sha256=KHsanI3BMCt8D9Qpze2ycrF6nMa3PyojgO6eS1c8kco,22
-bayesianflow_for_chem-2.0.5.dist-info/RECORD,,

/bayesianflow_for_chem/{vocab.txt → _data/vocab.txt}
RENAMED
File without changes

{bayesianflow_for_chem-2.0.5.dist-info → bayesianflow_for_chem-2.2.1.dist-info}/WHEEL
RENAMED
File without changes

{bayesianflow_for_chem-2.0.5.dist-info → bayesianflow_for_chem-2.2.1.dist-info}/entry_points.txt
RENAMED
File without changes

{bayesianflow_for_chem-2.0.5.dist-info → bayesianflow_for_chem-2.2.1.dist-info}/licenses/LICENSE
RENAMED
File without changes

{bayesianflow_for_chem-2.0.5.dist-info → bayesianflow_for_chem-2.2.1.dist-info}/top_level.txt
RENAMED
File without changes