bayesianflow-for-chem 2.1.0-py3-none-any.whl → 2.2.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of bayesianflow-for-chem has been flagged as possibly problematic by the registry.
- bayesianflow_for_chem/__init__.py +4 -3
- bayesianflow_for_chem/cli.py +85 -27
- bayesianflow_for_chem/data.py +1 -1
- bayesianflow_for_chem/model.py +131 -154
- bayesianflow_for_chem/tool.py +94 -123
- bayesianflow_for_chem/train.py +8 -1
- {bayesianflow_for_chem-2.1.0.dist-info → bayesianflow_for_chem-2.2.1.dist-info}/METADATA +4 -1
- bayesianflow_for_chem-2.2.1.dist-info/RECORD +15 -0
- bayesianflow_for_chem-2.1.0.dist-info/RECORD +0 -15
- /bayesianflow_for_chem/{vocab.txt → _data/vocab.txt} +0 -0
- {bayesianflow_for_chem-2.1.0.dist-info → bayesianflow_for_chem-2.2.1.dist-info}/WHEEL +0 -0
- {bayesianflow_for_chem-2.1.0.dist-info → bayesianflow_for_chem-2.2.1.dist-info}/entry_points.txt +0 -0
- {bayesianflow_for_chem-2.1.0.dist-info → bayesianflow_for_chem-2.2.1.dist-info}/licenses/LICENSE +0 -0
- {bayesianflow_for_chem-2.1.0.dist-info → bayesianflow_for_chem-2.2.1.dist-info}/top_level.txt +0 -0
bayesianflow_for_chem/__init__.py CHANGED

@@ -3,10 +3,8 @@
 """
 ChemBFN package.
 """
-import colorama
 from . import data, tool, train, scorer, spectra
 from .model import ChemBFN, MLP, EnsembleChemBFN
-from .cli import main_script
 
 __all__ = [
     "data",
@@ -18,7 +16,7 @@ __all__ = [
     "MLP",
     "EnsembleChemBFN",
 ]
-__version__ = "2.1.0"
+__version__ = "2.2.1"
 __author__ = "Nianze A. Tao (Omozawa Sueno)"
 
 
@@ -29,6 +27,9 @@ def main() -> None:
     :return:
     :rtype: None
     """
+    import colorama
+    from bayesianflow_for_chem.cli import main_script
+
     colorama.just_fix_windows_console()
     main_script(__version__)
     colorama.deinit()
bayesianflow_for_chem/cli.py CHANGED

@@ -12,13 +12,8 @@ from pathlib import Path
 from functools import partial
 from typing import List, Tuple, Dict, Union, Callable
 import torch
-import lightning as L
 from rdkit.Chem import MolFromSmiles, CanonSmiles
-from torch.utils.data import DataLoader
-from lightning.pytorch import loggers
-from lightning.pytorch.callbacks import ModelCheckpoint
 from bayesianflow_for_chem import ChemBFN, MLP
-from bayesianflow_for_chem.train import Model
 from bayesianflow_for_chem.scorer import smiles_valid, Scorer
 from bayesianflow_for_chem.data import (
     VOCAB_COUNT,
@@ -90,6 +85,7 @@ checkpoint_save_path = "home/user/project/ckpt"
 train_strategy = "auto"  # or any strategy supported by Lightning, e.g., "ddp"
 accumulate_grad_batches = 1
 enable_progress_bar = false
+plugin_script = ""  # define customised behaviours of dataset, datasetloader, etc in a python script
 
 # Remove this table if inference is unnecessary
 [inference]
@@ -120,6 +116,32 @@ madmadmadmadmadmadmadmadmadmadmadmadmadmadmad
 madmadmadmadmadmadmadmadmadmadmadmadmadmadmad
 """
 
+_ALLOWED_PLUGINS = [
+    "collate_fn",
+    "num_workers",
+    "max_sequence_length",
+    "shuffle",
+    "CustomData",
+]
+
+
+def _load_plugin(plugin_file: str) -> Dict[str, Union[int, Callable, object, None]]:
+    if not plugin_file:
+        return {n: None for n in _ALLOWED_PLUGINS}
+    from importlib import util as iutil
+
+    spec = iutil.spec_from_file_location(Path(plugin_file).stem, plugin_file)
+    plugins = iutil.module_from_spec(spec)
+    spec.loader.exec_module(plugins)
+    plugin_names: List[str] = plugins.__all__
+    plugin_dict = {}
+    for n in _ALLOWED_PLUGINS:
+        if n in plugin_names:
+            plugin_dict[n] = getattr(plugins, n)
+        else:
+            plugin_dict[n] = None
+    return plugin_dict
+
 
 def parse_cli(version: str) -> argparse.Namespace:
     """
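To make the new hook concrete: a plugin script referenced by `plugin_script` in the `[train]` table is an ordinary Python module whose `__all__` exports any subset of the names in `_ALLOWED_PLUGINS`; `_load_plugin` imports the file, and every name it does not export falls back to the built-in default (shuffling on, 4 workers, the package collate function, `CSVData`, and a max sequence length scanned from the dataset). The sketch below is illustrative only: the file name `my_plugins.py`, the values, and the assumption that `CSVData` is importable from `bayesianflow_for_chem.data` and can be subclassed this way are not part of the diff.

# my_plugins.py -- a hypothetical plugin script, wired up via plugin_script = "my_plugins.py"
from bayesianflow_for_chem.data import CSVData  # assumed import path for the default dataset class

__all__ = ["num_workers", "shuffle", "max_sequence_length", "CustomData"]

num_workers = 8            # replaces the DataLoader default of 4 workers
shuffle = True             # same as the default; set False to disable shuffling
max_sequence_length = 128  # skips the CSV scan that otherwise derives lmax from the data


class CustomData(CSVData):
    """Custom dataset; cli.py only requires CustomData(dataset_file) and a .map(...) method."""

An exported `collate_fn` would be handed to the `DataLoader` unchanged, so it has to accept the batches produced after `dataset.map(_encode, ...)`.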
@@ -267,6 +289,14 @@ def load_runtime_config(
                 f"\033[0;31mCritical\033[0;0m in {config_file}: Restart checkpoint file {ckpt_file} does not exist."
             )
             flag_critical += 1
+        # ↓ added in v2.2.0; need to be compatible with old versions.
+        plugin_script: str = config["train"].get("plugin_script", "")
+        if plugin_script:
+            if not os.path.exists(plugin_script):
+                print(
+                    f"\033[0;31mCritical\033[0;0m in {config_file}: Plugin script {plugin_script} does not exist."
+                )
+                flag_critical += 1
     if "inference" in config:
         if not "train" in config:
             if not isinstance(config["inference"]["sequence_length"], int):
@@ -335,14 +365,15 @@ def main_script(version: str) -> None:
             )
             flag_warning += 1
         if not os.path.exists(runtime_config["train"]["checkpoint_save_path"]):
-
+            if not parser.dryrun:  # only create it in real tasks
+                os.makedirs(runtime_config["train"]["checkpoint_save_path"])
     else:
         if not model_config["ChemBFN"]["base_model"]:
             print(
                 f"\033[0;33mWarning\033[0;0m in {parser.model_config}: You should load a pretrained ChemBFN model."
             )
             flag_warning += 1
-        if not model_config["MLP"]["base_model"]:
+        if "MLP" in model_config and not model_config["MLP"]["base_model"]:
             print(
                 f"\033[0;33mWarning\033[0;0m in {parser.model_config}: You should load a pretrained MLP."
             )
@@ -350,13 +381,17 @@ def main_script(version: str) -> None:
     if "inference" in runtime_config:
         if runtime_config["inference"]["guidance_objective"]:
             if not "MLP" in model_config:
-                print(
+                print(
+                    f"\033[0;33mWarning\033[0;0m in {parser.model_config}: Oh no, you don't have a MLP."
+                )
                 flag_warning += 1
     if parser.dryrun:
         if flag_critical != 0:
             print("Configuration check failed!")
         elif flag_warning != 0:
-            print(
+            print(
+                "Your job will probably run, but it may not follow your expectations."
+            )
         else:
             print("Configuration check passed.")
         return
@@ -415,6 +450,15 @@ def main_script(version: str) -> None:
         mlp = None
     # ------- train -------
     if "train" in runtime_config:
+        import lightning as L
+        from torch.utils.data import DataLoader
+        from lightning.pytorch import loggers
+        from lightning.pytorch.callbacks import ModelCheckpoint
+        from bayesianflow_for_chem.train import Model
+
+        # ####### get plugins #######
+        plugin_file = runtime_config["train"].get("plugin_script", "")
+        plugins = _load_plugin(plugin_file)
         # ####### build scorer #######
         if (tokeniser_name == "smiles" or tokeniser_name == "safe") and runtime_config[
             "train"
@@ -428,30 +472,43 @@ def main_script(version: str) -> None:
         mol_tag = runtime_config["train"]["molecule_tag"]
         obj_tag = runtime_config["train"]["objective_tag"]
         dataset_file = runtime_config["train"]["dataset"]
-        [removed lines 431-444; their content was not preserved in this rendering]
+        if plugins["max_sequence_length"]:
+            lmax = plugins["max_sequence_length"]
+        else:
+            with open(dataset_file, "r") as db:
+                _data = db.readlines()
+            _header = _data[0]
+            _mol_idx = []
+            for i, tag in enumerate(_header.replace("\n", "").split(",")):
+                if tag == mol_tag:
+                    _mol_idx.append(i)
+            _data_len = []
+            for i in _data[1:]:
+                i = i.replace("\n", "").split(",")
+                _mol = ".".join([i[j] for j in _mol_idx])
+                _data_len.append(tokeniser(_mol).shape[-1])
+            lmax = max(_data_len)
+            del _data, _data_len, _header, _mol_idx  # clear memory
+        if plugins["CustomData"] is not None:
+            dataset = plugins["CustomData"](dataset_file)
+        else:
+            dataset = CSVData(dataset_file)
         dataset.map(
             partial(_encode, mol_tag=mol_tag, obj_tag=obj_tag, tokeniser=tokeniser)
         )
         dataloader = DataLoader(
             dataset,
             runtime_config["train"]["batch_size"],
-            True,
-            num_workers=4,
-            collate_fn=
-
+            True if plugins["shuffle"] is None else plugins["shuffle"],
+            num_workers=4 if plugins["num_workers"] is None else plugins["num_workers"],
+            collate_fn=(
+                collate if plugins["collate_fn"] is None else plugins["collate_fn"]
+            ),
+            persistent_workers=(
+                True
+                if (plugins["num_workers"] is None or plugins["num_workers"] > 0)
+                else False
+            ),
         )
         # ####### build trainer #######
         logger_name = runtime_config["train"]["logger_name"].lower()
@@ -530,6 +587,7 @@ def main_script(version: str) -> None:
     if "train" in runtime_config:
         bfn = model.model
         mlp = model.mlp
+    # ↓ added in v2.1.0; need to be compatible with old versions
     lora_scaling = runtime_config["inference"].get("lora_scaling", 1.0)
     # ####### strat inference #######
     bfn.semi_autoregressive = runtime_config["inference"]["semi_autoregressive"]
bayesianflow_for_chem/data.py CHANGED

@@ -61,7 +61,7 @@ def load_vocab(
 }
 
 
-_DEFUALT_VOCAB = load_vocab(__filedir__ / "vocab.txt")
+_DEFUALT_VOCAB = load_vocab(__filedir__ / "_data/vocab.txt")
 VOCAB_KEYS: List[str] = _DEFUALT_VOCAB["vocab_keys"]
 VOCAB_DICT: Dict[str, int] = _DEFUALT_VOCAB["vocab_dict"]
 VOCAB_COUNT: int = _DEFUALT_VOCAB["vocab_count"]
bayesianflow_for_chem/model.py CHANGED

@@ -659,12 +659,91 @@ class ChemBFN(nn.Module):
         return (-logits.gather(-1, x[..., :1]).squeeze(-1)).mean()
 
     @staticmethod
-    def
+    def _reshape(y: Tensor) -> Tensor:
         assert y.dim() <= 3  # this doesn't work if the model is frezen in JIT.
         if y.dim() == 2:
             return y[:, None, :]
         return y
 
+    def _process(
+        self,
+        theta: Tensor,
+        mask: Optional[Tuple[Tensor, Tensor]],
+        y: Optional[Tensor],
+        sample_step: int,
+        guidance_strength: float,
+        token_mask: Optional[Tensor],
+    ) -> Tuple[Tensor, Tensor]:
+        # BFN inference process.
+        #
+        # theta: piror distribution; shape: (n_b, n_t, n_vocab)
+        # mask: masked condition distribution; shape: (n_b, n_t, n_vocab)
+        #       condition distribution mask; shape: (n_b, n_t, 1)
+        n_b = theta.shape[0]
+        if y is not None:
+            y = self._reshape(y)
+        for i in torch.linspace(1, sample_step, sample_step, device=self.beta.device):
+            t = (i - 1).view(1, 1, 1).repeat(n_b, 1, 1) / sample_step
+            p = self.discrete_output_distribution(theta, t, y, guidance_strength)
+            if token_mask is not None:
+                p = p.masked_fill_(token_mask, 0.0)
+            alpha = self.calc_discrete_alpha(t, t + 1 / sample_step)
+            e_k = nn.functional.one_hot(torch.argmax(p, -1), self.K).float()
+            mu = alpha * (self.K * e_k - 1)
+            sigma = (alpha * self.K).sqrt()
+            theta = (mu + sigma * torch.randn_like(mu)).exp() * theta
+            theta = theta / theta.sum(-1, True)
+            if mask is not None:
+                x_onehot, x_mask = mask
+                theta = x_onehot + (1 - x_mask) * theta
+        t_final = torch.ones((n_b, 1, 1), device=self.beta.device)
+        p = self.discrete_output_distribution(theta, t_final, y, guidance_strength)
+        entropy = -(p * p.log()).sum(-1).mean(-1)
+        if token_mask is not None:
+            p = p.masked_fill_(token_mask, 0.0)
+        return torch.argmax(p, -1), entropy
+
+    def _ode_process(
+        self,
+        z: Tensor,
+        mask: Optional[Tuple[Tensor, Tensor]],
+        y: Optional[Tensor],
+        sample_step: int,
+        guidance_strength: float,
+        token_mask: Optional[Tensor],
+        temperature: float,
+    ) -> Tuple[Tensor, Tensor]:
+        # ODE-solver engaged inference process.
+        #
+        # z: prior latent vector; shape: (n_b, n_t, n_vocab)
+        # mask: masked condition distribution; shape: (n_b, n_t, n_vocab)
+        #       condition distribution mask; shape: (n_b, n_t, 1)
+        n_b = z.shape[0]
+        if y is not None:
+            y = self._reshape(y)
+        for i in torch.linspace(1, sample_step, sample_step, device=self.beta.device):
+            t = (i - 1).view(1, 1, 1).repeat(n_b, 1, 1) / sample_step
+            theta = softmax(z, -1)
+            if mask is not None:
+                x_onehot, x_mask = mask
+                theta = x_onehot + (1 - x_mask) * theta
+            beta = self.calc_beta(t + 1 / sample_step)
+            p = self.discrete_output_distribution(theta, t, y, guidance_strength)
+            if token_mask is not None:
+                p = p.masked_fill_(token_mask, 0.0)
+            u = torch.randn_like(z)
+            z = (self.K * p - 1) * beta + (self.K * beta * temperature).sqrt() * u
+        t_final = torch.ones((n_b, 1, 1), device=self.beta.device)
+        theta = softmax(z, -1)
+        if mask is not None:
+            x_onehot, x_mask = mask
+            theta = x_onehot + (1 - x_mask) * theta
+        p = self.discrete_output_distribution(theta, t_final, y, guidance_strength)
+        entropy = -(p * p.log()).sum(-1).mean(-1)
+        if token_mask is not None:
+            p = p.masked_fill_(token_mask, 0.0)
+        return torch.argmax(p, -1), entropy
+
     @torch.jit.export
     def sample(
         self,
@@ -680,44 +759,26 @@ class ChemBFN(nn.Module):
 
         :param batch_size: batch size
         :param sequence_size: max sequence length
-        :param y: conditioning vector;
+        :param y: conditioning vector; shape: (n_b, 1, n_f) or (n_b, n_f)
         :param sample_step: number of sampling steps
         :param guidance_strength: strength of conditional generation. It is not used if y is null.
         :param token_mask: token mask assigning unwanted token(s) with `True`;
-
+            shape: (1, 1, n_vocab)
         :type batch_size: int
         :type sequence_size: int
         :type y: torch.Tensor | None
         :type sample_step: int
         :type guidance_strength: float
        :type token_mask: torch.Tensor | None
-        :return: sampled token indices;
-            entropy of the tokens;
+        :return: sampled token indices; shape: (n_b, n_t) \n
+            entropy of the tokens; shape: (n_b)
         :rtype: tuple
         """
         theta = (
             torch.ones((batch_size, sequence_size, self.K), device=self.beta.device)
             / self.K
         )
-
-        y = self.reshape_y(y)
-        for i in torch.linspace(1, sample_step, sample_step, device=self.beta.device):
-            t = (i - 1).view(1, 1, 1).repeat(batch_size, 1, 1) / sample_step
-            p = self.discrete_output_distribution(theta, t, y, guidance_strength)
-            if token_mask is not None:
-                p = p.masked_fill_(token_mask, 0.0)
-            alpha = self.calc_discrete_alpha(t, t + 1 / sample_step)
-            e_k = nn.functional.one_hot(torch.argmax(p, -1), self.K).float()
-            mu = alpha * (self.K * e_k - 1)
-            sigma = (alpha * self.K).sqrt()
-            theta = (mu + sigma * torch.randn_like(mu)).exp() * theta
-            theta = theta / theta.sum(-1, True)
-        t_final = torch.ones((batch_size, 1, 1), device=self.beta.device)
-        p = self.discrete_output_distribution(theta, t_final, y, guidance_strength)
-        entropy = -(p * p.log()).sum(-1).mean(-1)
-        if token_mask is not None:
-            p = p.masked_fill_(token_mask, 0.0)
-        return torch.argmax(p, -1), entropy
+        return self._process(theta, None, y, sample_step, guidance_strength, token_mask)
 
     @torch.jit.export
     def ode_sample(
@@ -735,11 +796,11 @@ class ChemBFN(nn.Module):
 
         :param batch_size: batch size
         :param sequence_size: max sequence length
-        :param y: conditioning vector;
+        :param y: conditioning vector; shape: (n_b, 1, n_f) or (n_b, n_f)
         :param sample_step: number of sampling steps
         :param guidance_strength: strength of conditional generation. It is not used if y is null.
         :param token_mask: token mask assigning unwanted token(s) with `True`;
-
+            shape: (1, 1, n_vocab)
         :param temperature: sampling temperature
         :type batch_size: int
         :type sequence_size: int
@@ -748,29 +809,14 @@ class ChemBFN(nn.Module):
         :type guidance_strength: float
         :type token_mask: torch.Tensor | None
         :type temperature: float
-        :return: sampled token indices;
-            entropy of the tokens;
+        :return: sampled token indices; shape: (n_b, n_t) \n
+            entropy of the tokens; shape: (n_b)
         :rtype: tuple
         """
         z = torch.zeros((batch_size, sequence_size, self.K), device=self.beta.device)
-
-        y
-
-            t = (i - 1).view(1, 1, 1).repeat(batch_size, 1, 1) / sample_step
-            theta = torch.softmax(z, -1)
-            beta = self.calc_beta(t + 1 / sample_step)
-            p = self.discrete_output_distribution(theta, t, y, guidance_strength)
-            if token_mask is not None:
-                p = p.masked_fill_(token_mask, 0.0)
-            u = torch.randn_like(z)
-            z = (self.K * p - 1) * beta + (self.K * beta * temperature).sqrt() * u
-        t_final = torch.ones((batch_size, 1, 1), device=self.beta.device)
-        theta = torch.softmax(z, -1)
-        p = self.discrete_output_distribution(theta, t_final, y, guidance_strength)
-        entropy = -(p * p.log()).sum(-1).mean(-1)
-        if token_mask is not None:
-            p = p.masked_fill_(token_mask, 0.0)
-        return torch.argmax(p, -1), entropy
+        return self._ode_process(
+            z, None, y, sample_step, guidance_strength, token_mask, temperature
+        )
 
     @torch.jit.export
     def inpaint(
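Net effect of the refactor above: `sample`, `inpaint`, and `optimise` now delegate to `_process`, and their `ode_*` counterparts to `_ode_process`, instead of each repeating the generation loop. For orientation, the positional call patterns exercised by tool.py look like this (a sketch only; `bfn` stands for any trained ChemBFN instance and the numbers are arbitrary):

# 8 unconditional sequences of up to 64 tokens, 100 generation steps, no guidance, no token mask
tokens, entropy = bfn.sample(8, 64, None, 100, 0.0, None)

# the ODE-solver variant takes one extra trailing argument, the sampling temperature
tokens, entropy = bfn.ode_sample(8, 64, None, 100, 0.0, None, 0.5)

Both return token indices of shape (n_b, n_t) plus a per-sequence entropy of shape (n_b), which tool.py later uses to sort the generated strings.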
@@ -800,30 +846,12 @@ class ChemBFN(nn.Module):
         :rtype: tuple
         """
         n_b, n_t = x.shape
-
+        x_mask = (x != 0).float()[..., None]
         theta = torch.ones((n_b, n_t, self.K), device=x.device) / self.K
-        x_onehot = nn.functional.one_hot(x, self.K) *
-        theta = x_onehot + (1 -
-
-
-        for i in torch.linspace(1, sample_step, sample_step, device=x.device):
-            t = (i - 1).view(1, 1, 1).repeat(n_b, 1, 1) / sample_step
-            p = self.discrete_output_distribution(theta, t, y, guidance_strength)
-            if token_mask is not None:
-                p = p.masked_fill_(token_mask, 0.0)
-            alpha = self.calc_discrete_alpha(t, t + 1 / sample_step)
-            e_k = nn.functional.one_hot(torch.argmax(p, -1), self.K).float()
-            mu = alpha * (self.K * e_k - 1)
-            sigma = (alpha * self.K).sqrt()
-            theta = (mu + sigma * torch.randn_like(mu)).exp() * theta
-            theta = theta / theta.sum(-1, True)
-            theta = x_onehot + (1 - mask) * theta
-        t_final = torch.ones((n_b, 1, 1), device=x.device)
-        p = self.discrete_output_distribution(theta, t_final, y, guidance_strength)
-        entropy = -(p * p.log()).sum(-1).mean(-1)
-        if token_mask is not None:
-            p = p.masked_fill_(token_mask, 0.0)
-        return torch.argmax(p, -1), entropy
+        x_onehot = nn.functional.one_hot(x, self.K) * x_mask
+        theta = x_onehot + (1 - x_mask) * theta
+        mask = (x_onehot, x_mask)
+        return self._process(theta, mask, y, sample_step, guidance_strength, token_mask)
 
     @torch.jit.export
     def ode_inpaint(
@@ -856,29 +884,13 @@ class ChemBFN(nn.Module):
         :rtype: tuple
         """
         n_b, n_t = x.shape
-
-        x_onehot = nn.functional.one_hot(x, self.K) *
+        x_mask = (x != 0).float()[..., None]
+        x_onehot = nn.functional.one_hot(x, self.K) * x_mask
         z = torch.zeros((n_b, n_t, self.K), device=self.beta.device)
-
-
-
-
-            theta = torch.softmax(z, -1)
-            theta = x_onehot + (1 - mask) * theta
-            beta = self.calc_beta(t + 1 / sample_step)
-            p = self.discrete_output_distribution(theta, t, y, guidance_strength)
-            if token_mask is not None:
-                p = p.masked_fill_(token_mask, 0.0)
-            u = torch.randn_like(z)
-            z = (self.K * p - 1) * beta + (self.K * beta * temperature).sqrt() * u
-        t_final = torch.ones((n_b, 1, 1), device=self.beta.device)
-        theta = torch.softmax(z, -1)
-        theta = x_onehot + (1 - mask) * theta
-        p = self.discrete_output_distribution(theta, t_final, y, guidance_strength)
-        entropy = -(p * p.log()).sum(-1).mean(-1)
-        if token_mask is not None:
-            p = p.masked_fill_(token_mask, 0.0)
-        return torch.argmax(p, -1), entropy
+        mask = (x_onehot, x_mask)
+        return self._ode_process(
+            z, mask, y, sample_step, guidance_strength, token_mask, temperature
+        )
 
     @torch.jit.export
     def optimise(
@@ -908,28 +920,9 @@ class ChemBFN(nn.Module):
             entropy of the tokens; shape: (n_b)
         :rtype: tuple
         """
-        n_b = x.shape[0]
         x_onehot = nn.functional.one_hot(x, self.K).float()
-        theta =
-
-        y = self.reshape_y(y)
-        for i in torch.linspace(1, sample_step, sample_step, device=x.device):
-            t = (i - 1).view(1, 1, 1).repeat(n_b, 1, 1) / sample_step
-            p = self.discrete_output_distribution(theta, t, y, guidance_strength)
-            if token_mask is not None:
-                p = p.masked_fill_(token_mask, 0.0)
-            alpha = self.calc_discrete_alpha(t, t + 1 / sample_step)
-            e_k = nn.functional.one_hot(torch.argmax(p, -1), self.K).float()
-            mu = alpha * (self.K * e_k - 1)
-            sigma = (alpha * self.K).sqrt()
-            theta = (mu + sigma * torch.randn_like(mu)).exp() * theta
-            theta = theta / theta.sum(-1, True)
-        t_final = torch.ones((n_b, 1, 1), device=x.device)
-        p = self.discrete_output_distribution(theta, t_final, y, guidance_strength)
-        entropy = -(p * p.log()).sum(-1).mean(-1)
-        if token_mask is not None:
-            p = p.masked_fill_(token_mask, 0.0)
-        return torch.argmax(p, -1), entropy
+        theta = softmax(x_onehot, -1)
+        return self._process(theta, None, y, sample_step, guidance_strength, token_mask)
 
     @torch.jit.export
     def ode_optimise(
@@ -961,26 +954,10 @@ class ChemBFN(nn.Module):
             entropy of the tokens; shape: (n_b)
         :rtype: tuple
         """
-        n_b = x.shape[0]
         z = nn.functional.one_hot(x, self.K).float()
-
-        y
-
-            t = (i - 1).view(1, 1, 1).repeat(n_b, 1, 1) / sample_step
-            theta = torch.softmax(z, -1)
-            beta = self.calc_beta(t + 1 / sample_step)
-            p = self.discrete_output_distribution(theta, t, y, guidance_strength)
-            if token_mask is not None:
-                p = p.masked_fill_(token_mask, 0.0)
-            u = torch.randn_like(z)
-            z = (self.K * p - 1) * beta + (self.K * beta * temperature).sqrt() * u
-        t_final = torch.ones((n_b, 1, 1), device=self.beta.device)
-        theta = torch.softmax(z, -1)
-        p = self.discrete_output_distribution(theta, t_final, y, guidance_strength)
-        entropy = -(p * p.log()).sum(-1).mean(-1)
-        if token_mask is not None:
-            p = p.masked_fill_(token_mask, 0.0)
-        return torch.argmax(p, -1), entropy
+        return self._ode_process(
+            z, None, y, sample_step, guidance_strength, token_mask, temperature
+        )
 
     def inference(
         self, x: Tensor, mlp: MLP, embed_fn: Optional[Callable[[Tensor], Tensor]] = None
@@ -1154,22 +1131,6 @@ class EnsembleChemBFN(ChemBFN):
                 module.lora_dropout = None
             v.lora_enabled = False
 
-    def construct_y(
-        self, c: Union[List[Tensor], Dict[str, Tensor]]
-    ) -> Dict[str, Tensor]:
-        assert (
-            isinstance(c, dict) is self._label_is_dict
-        ), f"`c` should be a {'`dict` instance' if self._label_is_dict else '`list` instance'} but got {type(c)} instand."
-        out: Dict[str, Tensor] = {}
-        if isinstance(c, list):
-            c = dict(zip([f"val_{i}" for i in range(len(c))], c))
-        for name, model in self.cond_heads.items():
-            y = model.forward(c[name])
-            if y.dim() == 2:
-                y = y[:, None, :]
-            out[name] = y
-        return out
-
     def discrete_output_distribution(
         self, theta: Tensor, t: Tensor, y: Dict[str, Tensor], w: float
     ) -> Tensor:
@@ -1204,8 +1165,24 @@ class EnsembleChemBFN(ChemBFN):
             p_cond += p_cond_ * self.adapter_weights[name]
         return softmax((1 + w) * p_cond - w * p_uncond, -1)
 
+    def _map_to_dict(
+        self, c: Union[List[Tensor], Dict[str, Tensor]]
+    ) -> Dict[str, Tensor]:
+        assert (
+            isinstance(c, dict) is self._label_is_dict
+        ), f"`c` should be a {'`dict` instance' if self._label_is_dict else '`list` instance'} but got {type(c)} instand."
+        out: Dict[str, Tensor] = {}
+        if isinstance(c, list):
+            c = dict(zip([f"val_{i}" for i in range(len(c))], c))
+        for name, model in self.cond_heads.items():
+            y = model.forward(c[name])
+            if y.dim() == 2:
+                y = y[:, None, :]
+            out[name] = y
+        return out
+
     @staticmethod
-    def
+    def _reshape(y: Dict[str, Tensor]) -> Dict[str, Tensor]:
         for k in y:
             assert y[k].dim() <= 3
             if y[k].dim() == 2:
@@ -1241,7 +1218,7 @@ class EnsembleChemBFN(ChemBFN):
             entropy of the tokens; shape: (n_b)
         :rtype: tuple
         """
-        y = self.
+        y = self._map_to_dict(conditions)
         return super().sample(
             batch_size, sequence_size, y, sample_step, guidance_strength, token_mask
         )
@@ -1278,7 +1255,7 @@ class EnsembleChemBFN(ChemBFN):
             entropy of the tokens; shape: (n_b)
         :rtype: tuple
         """
-        y = self.
+        y = self._map_to_dict(conditions)
         return super().ode_sample(
             batch_size,
             sequence_size,
@@ -1315,7 +1292,7 @@ class EnsembleChemBFN(ChemBFN):
             entropy of the tokens; shape: (n_b)
         :rtype: tuple
         """
-        y = self.
+        y = self._map_to_dict(conditions)
         return super().inpaint(x, y, sample_step, guidance_strength, token_mask)
 
     @torch.inference_mode()
@@ -1347,7 +1324,7 @@ class EnsembleChemBFN(ChemBFN):
             entropy of the tokens; shape: (n_b)
         :rtype: tuple
         """
-        y = self.
+        y = self._map_to_dict(conditions)
        return super().ode_inpaint(
             x, y, sample_step, guidance_strength, token_mask, temperature
         )
@@ -1380,7 +1357,7 @@ class EnsembleChemBFN(ChemBFN):
             entropy of the tokens; shape: (n_b)
         :rtype: tuple
         """
-        y = self.
+        y = self._map_to_dict(conditions)
         return super().optimise(x, y, sample_step, guidance_strength, token_mask)
 
     @torch.inference_mode()
@@ -1412,7 +1389,7 @@ class EnsembleChemBFN(ChemBFN):
             entropy of the tokens; shape: (n_b)
         :rtype: tuple
         """
-        y = self.
+        y = self._map_to_dict(conditions)
         return super().ode_optimise(
             x, y, sample_step, guidance_strength, token_mask, temperature
         )
bayesianflow_for_chem/tool.py CHANGED

@@ -9,7 +9,6 @@ import warnings
 from pathlib import Path
 from typing import List, Dict, Tuple, Union, Optional
 import torch
-import colorama
 import numpy as np
 from torch import cuda, Tensor, softmax
 from torch.utils.data import DataLoader
@@ -24,15 +23,7 @@ from rdkit.Chem import (
     AddHs,
     Mol,
 )
-from rdkit.Chem.Scaffolds.MurckoScaffold import MurckoScaffoldSmiles
-from sklearn.metrics import (
-    roc_auc_score,
-    auc,
-    precision_recall_curve,
-    r2_score,
-    mean_absolute_error,
-    root_mean_squared_error,
-)
+from rdkit.Chem.Scaffolds.MurckoScaffold import MurckoScaffoldSmiles
 from .data import VOCAB_KEYS
 from .model import ChemBFN, MLP, EnsembleChemBFN
 
@@ -45,6 +36,68 @@ def _find_device() -> torch.device:
     return torch.device("cpu")
 
 
+def _parse_and_assert_param(
+    model: Union[ChemBFN, EnsembleChemBFN],
+    y: Optional[Union[Tensor, Dict[str, Tensor], List[Tensor]]],
+    method: str,
+) -> Optional[float]:
+    assert method.split(":")[0].lower() in ("ode", "bfn")
+    if isinstance(model, EnsembleChemBFN):
+        assert y is not None, "conditioning is required while using an ensemble model."
+        assert isinstance(y, list) or isinstance(y, dict)
+    else:
+        assert isinstance(y, Tensor) or (y is None)
+    if "ode" in method.lower():
+        tp = float(method.split(":")[-1])
+        assert tp > 0, "Sampling temperature should be higher than 0."
+        return tp
+    return None
+
+
+def _map_to_device(
+    y: Optional[Union[Tensor, Dict[str, Tensor], List[Tensor]]],
+    device: Union[str, torch.device],
+) -> Optional[Union[Tensor, Dict[str, Tensor], List[Tensor]]]:
+    if y is not None:
+        if isinstance(y, Tensor):
+            y = y.to(device)
+        elif isinstance(y, list):
+            y = [i.to(device) for i in y]
+        elif isinstance(y, dict):
+            y = {k: v.to(device) for k, v in y.items()}
+        else:
+            raise NotImplementedError
+    return y
+
+
+def _build_token_mask(
+    allowed_tokens: Union[str, List[str]],
+    vocab_keys: List[str],
+    device: Union[str, torch.tensor],
+) -> Optional[Tensor]:
+    if isinstance(allowed_tokens, list):
+        token_mask = [0 if i in allowed_tokens else 1 for i in vocab_keys]
+        token_mask = torch.tensor([[token_mask]], dtype=torch.bool).to(device)
+    else:
+        token_mask = None
+    return token_mask
+
+
+def _token_to_seq(
+    tokens: Tensor, entropy: Tensor, vocab_keys: List[str], separator: str, sort: bool
+) -> List[str]:
+    if sort:
+        sorted_idx = entropy.argsort(stable=True)
+        tokens = tokens[sorted_idx]
+    return [
+        separator.join([vocab_keys[i] for i in j])
+        .split("<start>" + separator)[-1]
+        .split(separator + "<end>")[0]
+        .replace("<pad>", "")
+        for j in tokens
+    ]
+
+
 @torch.no_grad()
 def test(
     model: ChemBFN,
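The four helpers above factor out the argument checking, device placement, token masking, and detokenisation that were previously copy-pasted inside `sample`, `inpaint`, and `optimise`. A rough illustration of their behaviour, assuming `model` is a plain ChemBFN (the ensemble path additionally requires `y`); not part of the diff, and the token whitelist is an arbitrary example:

# the method string picks the sampler; an "ode:<temperature>" value returns the temperature
_parse_and_assert_param(model, None, "bfn")      # -> None, i.e. use model.sample / inpaint / optimise
_parse_and_assert_param(model, None, "ode:0.5")  # -> 0.5, i.e. use the ode_* variants at temperature 0.5

# a whitelist of allowed tokens becomes a boolean (1, 1, n_vocab) mask; a non-list value means no mask
token_mask = _build_token_mask(["C", "O", "N", "=", "(", ")", "1"], VOCAB_KEYS, "cpu")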
@@ -86,6 +139,12 @@ def test(
         predict_y.append(y_hat.detach().to("cpu"))
     predict_y, label_y = torch.cat(predict_y, 0), torch.cat(label_y, 0).split(1, -1)
     if mode == "regression":
+        from sklearn.metrics import (
+            r2_score,
+            mean_absolute_error,
+            root_mean_squared_error,
+        )
+
         predict_y = [
             predict[label_y[i] != torch.inf]
             for (i, predict) in enumerate(predict_y.split(1, -1))
@@ -99,6 +158,8 @@ def test(
         r2 = [r2_score(label, predict) for (label, predict) in y_zipped]
         return {"MAE": mae, "RMSE": rmse, "R^2": r2}
     if mode == "classification":
+        from sklearn.metrics import roc_auc_score, auc, precision_recall_curve
+
         n_c = len(label_y)
         predict_y = predict_y.chunk(n_c, -1)
         y_zipped = list(zip(label_y, predict_y))
@@ -142,7 +203,6 @@ def split_dataset(
     assert file.endswith(".csv")
     assert len(split_ratio) == 3
     assert method in ("random", "scaffold")
-    colorama.just_fix_windows_console()
     with open(file, "r") as f:
         data = list(csv.reader(f))
     header = data[0]
@@ -167,10 +227,8 @@ def split_dataset(
         # compute Bemis-Murcko scaffold
         if len(smiles_idx) > 1:
             warnings.warn(
-                "\033[32;1m"
                 f"We found {len(smiles_idx)} SMILES strings in a row!"
-                " Only the first SMILES will be used to compute the molecular scaffold."
-                "\033[0m",
+                " Only the first SMILES will be used to compute the molecular scaffold.",
                 stacklevel=2,
             )
         try:
@@ -197,10 +255,10 @@ def split_dataset(
     with open(file.replace(".csv", "_test.csv"), "w", newline="") as fte:
         writer = csv.writer(fte)
         writer.writerows([header] + test_set)
-
-
-
-
+    if val_set:
+        with open(file.replace(".csv", "_val.csv"), "w", newline="") as fva:
+            writer = csv.writer(fva)
+            writer.writerows([header] + val_set)
 
 
 @torch.no_grad()
@@ -250,32 +308,12 @@ def sample(
     :return: a list of generated molecular strings
     :rtype: list
     """
-
-
-        assert y is not None, "conditioning is required while using an ensemble model."
-        assert isinstance(y, list) or isinstance(y, dict)
-    else:
-        assert isinstance(y, Tensor) or y is None
-    if device is None:
-        device = _find_device()
+    tp = _parse_and_assert_param(model, y, method)
+    device = _find_device() if device is None else device
     model.to(device).eval()
-
-
-
-    elif isinstance(y, list):
-        y = [i.to(device) for i in y]
-    elif isinstance(y, dict):
-        y = {k: v.to(device) for k, v in y.items()}
-    else:
-        raise NotImplementedError
-    if isinstance(allowed_tokens, list):
-        token_mask = [0 if i in allowed_tokens else 1 for i in vocab_keys]
-        token_mask = torch.tensor([[token_mask]], dtype=torch.bool).to(device)
-    else:
-        token_mask = None
-    if "ode" in method.lower():
-        tp = float(method.split(":")[-1])
-        assert tp > 0, "Sampling temperature should be higher than 0."
+    y = _map_to_device(y, device)
+    token_mask = _build_token_mask(allowed_tokens, vocab_keys, device)
+    if tp:
         tokens, entropy = model.ode_sample(
             batch_size, sequence_size, y, sample_step, guidance_strength, token_mask, tp
         )
@@ -283,16 +321,7 @@ def sample(
         tokens, entropy = model.sample(
             batch_size, sequence_size, y, sample_step, guidance_strength, token_mask
         )
-
-        sorted_idx = entropy.argsort(stable=True)
-        tokens = tokens[sorted_idx]
-    return [
-        seperator.join([vocab_keys[i] for i in j])
-        .split("<start>" + seperator)[-1]
-        .split(seperator + "<end>")[0]
-        .replace("<pad>", "")
-        for j in tokens
-    ]
+    return _token_to_seq(tokens, entropy, vocab_keys, seperator, sort)
 
 
 @torch.no_grad()
@@ -339,33 +368,13 @@ def inpaint(
     :return: a list of generated molecular strings
     :rtype: list
     """
-
-
-        assert y is not None, "conditioning is required while using an ensemble model."
-        assert isinstance(y, list) or isinstance(y, dict)
-    else:
-        assert isinstance(y, Tensor) or y is None
-    if device is None:
-        device = _find_device()
+    tp = _parse_and_assert_param(model, y, method)
+    device = _find_device() if device is None else device
     model.to(device).eval()
     x = x.to(device)
-
-
-
-    elif isinstance(y, list):
-        y = [i.to(device) for i in y]
-    elif isinstance(y, dict):
-        y = {k: v.to(device) for k, v in y.items()}
-    else:
-        raise NotImplementedError
-    if isinstance(allowed_tokens, list):
-        token_mask = [0 if i in allowed_tokens else 1 for i in vocab_keys]
-        token_mask = torch.tensor([[token_mask]], dtype=torch.bool).to(device)
-    else:
-        token_mask = None
-    if "ode" in method.lower():
-        tp = float(method.split(":")[-1])
-        assert tp > 0, "Sampling temperature should be higher than 0."
+    y = _map_to_device(y, device)
+    token_mask = _build_token_mask(allowed_tokens, vocab_keys, device)
+    if tp:
         tokens, entropy = model.ode_inpaint(
             x, y, sample_step, guidance_strength, token_mask, tp
         )
@@ -373,16 +382,7 @@ def inpaint(
         tokens, entropy = model.inpaint(
             x, y, sample_step, guidance_strength, token_mask
         )
-
-        sorted_idx = entropy.argsort(stable=True)
-        tokens = tokens[sorted_idx]
-    return [
-        separator.join([vocab_keys[i] for i in j])
-        .split("<start>" + separator)[-1]
-        .split(separator + "<end>")[0]
-        .replace("<pad>", "")
-        for j in tokens
-    ]
+    return _token_to_seq(tokens, entropy, vocab_keys, separator, sort)
 
 
 @torch.no_grad()
@@ -429,33 +429,13 @@ def optimise(
     :return: a list of generated molecular strings
     :rtype: list
     """
-
-
-        assert y is not None, "conditioning is required while using an ensemble model."
-        assert isinstance(y, list) or isinstance(y, dict)
-    else:
-        assert isinstance(y, Tensor) or y is None
-    if device is None:
-        device = _find_device()
+    tp = _parse_and_assert_param(model, y, method)
+    device = _find_device() if device is None else device
     model.to(device).eval()
     x = x.to(device)
-
-
-
-    elif isinstance(y, list):
-        y = [i.to(device) for i in y]
-    elif isinstance(y, dict):
-        y = {k: v.to(device) for k, v in y.items()}
-    else:
-        raise NotImplementedError
-    if isinstance(allowed_tokens, list):
-        token_mask = [0 if i in allowed_tokens else 1 for i in vocab_keys]
-        token_mask = torch.tensor([[token_mask]], dtype=torch.bool).to(device)
-    else:
-        token_mask = None
-    if "ode" in method.lower():
-        tp = float(method.split(":")[-1])
-        assert tp > 0, "Sampling temperature should be higher than 0."
+    y = _map_to_device(y, device)
+    token_mask = _build_token_mask(allowed_tokens, vocab_keys, device)
+    if tp:
         tokens, entropy = model.ode_optimise(
             x, y, sample_step, guidance_strength, token_mask, tp
         )
@@ -463,16 +443,7 @@ def optimise(
         tokens, entropy = model.optimise(
             x, y, sample_step, guidance_strength, token_mask
         )
-
-        sorted_idx = entropy.argsort(stable=True)
-        tokens = tokens[sorted_idx]
-    return [
-        separator.join([vocab_keys[i] for i in j])
-        .split("<start>" + separator)[-1]
-        .split(separator + "<end>")[0]
-        .replace("<pad>", "")
-        for j in tokens
-    ]
+    return _token_to_seq(tokens, entropy, vocab_keys, separator, sort)
 
 
 def quantise_model_(model: ChemBFN) -> None:
bayesianflow_for_chem/train.py CHANGED

@@ -8,7 +8,6 @@ from typing import Dict, Tuple, Union, Optional
 import torch
 import torch.optim as op
 import torch.nn.functional as F
-from loralib import lora_state_dict, mark_only_lora_as_trainable
 from torch import Tensor
 from torch.optim.lr_scheduler import ReduceLROnPlateau
 from lightning import LightningModule
@@ -55,6 +54,8 @@ class Model(LightningModule):
         self.scorer = scorer
         self.save_hyperparameters(hparam, ignore=["model", "mlp", "scorer"])
         if model.lora_enabled:
+            from loralib import mark_only_lora_as_trainable
+
             mark_only_lora_as_trainable(self.model)
         self.use_scorer = self.scorer is not None
 
@@ -107,6 +108,8 @@ class Model(LightningModule):
         :rtype: None
         """
         if self.model.lora_enabled:
+            from loralib import lora_state_dict
+
             torch.save(
                 {
                     "lora_nn": lora_state_dict(self.model),
@@ -152,6 +155,8 @@ class Regressor(LightningModule):
         self.model.requires_grad_(not hparam["freeze"])
         self.save_hyperparameters(hparam, ignore=["model", "mlp"])
         if model.lora_enabled:
+            from loralib import mark_only_lora_as_trainable
+
             mark_only_lora_as_trainable(self.model)
         assert hparam["mode"] in ("regression", "classification")
 
@@ -231,6 +236,8 @@ class Regressor(LightningModule):
             )
         if not self.hparams.freeze:
             if self.model.lora_enabled:
+                from loralib import lora_state_dict
+
                 torch.save(
                     {
                         "lora_nn": lora_state_dict(self.model),
{bayesianflow_for_chem-2.1.0.dist-info → bayesianflow_for_chem-2.2.1.dist-info}/METADATA CHANGED

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: bayesianflow_for_chem
-Version: 2.1.0
+Version: 2.2.1
 Summary: Bayesian flow network framework for Chemistry
 Home-page: https://augus1999.github.io/bayesian-flow-network-for-chemistry/
 Author: Nianze A. Tao
@@ -54,12 +54,14 @@ This is the repository of the PyTorch implementation of ChemBFN model.
 
 [](https://pypi.org/project/bayesianflow-for-chem/)
 
+[](https://augus1999.github.io/bayesian-flow-network-for-chemistry/)
 
 ## Features
 
 ChemBFN provides the state-of-the-art functionalities of
 * SMILES or SELFIES-based *de novo* molecule generation
 * Protein sequence *de novo* generation
+* Template optimisation (mol2mol)
 * Classifier-free guidance conditional generation (single or multi-objective optimisation)
 * Context-guided conditional generation (inpaint)
 * Outstanding out-of-distribution chemical space sampling
@@ -71,6 +73,7 @@ in an all-in-one-model style.
 
 ## News
 
+* [09/10/2025] A web app [`chembfn_webui`](https://github.com/Augus1999/ChemBFN-WebUI) for hosting ChemBFN models is available on [PyPI](https://pypi.org/project/chembfn-webui/).
 * [30/01/2025] The package `bayesianflow_for_chem` is available on [PyPI](https://pypi.org/project/bayesianflow-for-chem/).
 * [21/01/2025] Our first paper has been accepted by [JCIM](https://pubs.acs.org/doi/10.1021/acs.jcim.4c01792).
 * [17/12/2024] The second paper of out-of-distribution generation is available on [arxiv.org](https://arxiv.org/abs/2412.11439).
bayesianflow_for_chem-2.2.1.dist-info/RECORD ADDED

@@ -0,0 +1,15 @@
+bayesianflow_for_chem/__init__.py,sha256=l6twtXTdnix9SC8oL0Aqr0dYJ1uBUmMtk2QbUaRrhkE,642
+bayesianflow_for_chem/cli.py,sha256=jZPBUVwOl4qnjUMb3Lavu6owv1RP9jGdrDlM7KDA99c,26600
+bayesianflow_for_chem/data.py,sha256=RM3YaZUkTJ4-RNLH37GwhM4sRddOOusrC996ZPAe1lI,6590
+bayesianflow_for_chem/model.py,sha256=M35G4u4mX4btl9vOK3Iqs6yOSuIKI_OoCTmLhmjbwNk,57559
+bayesianflow_for_chem/scorer.py,sha256=gQFUlkyxitch02ntqcRh1ZS8aondKLynW5U6NfTQTb4,4084
+bayesianflow_for_chem/spectra.py,sha256=Ba9ib1aDvTtDYbH3b4d-lIty3ZSQMu7jwehuV2KmhwA,1781
+bayesianflow_for_chem/tool.py,sha256=VQo6xShwwmJbBsC-xTRL5ysCI7OLoOpmwQ-A7AuAPI8,23587
+bayesianflow_for_chem/train.py,sha256=7AU0A-eZwzSYsLyIe3OxGTNWPnhGpHmVUaQLplV2Fn8,9886
+bayesianflow_for_chem/_data/vocab.txt,sha256=HgtAZmpWYk4y8PqEVC4vqut1vE75DfRKE_10s2UW0rU,790
+bayesianflow_for_chem-2.2.1.dist-info/licenses/LICENSE,sha256=hIahDEOTzuHCU5J2nd07LWwkLW7Hko4UFO__ffsvB-8,34523
+bayesianflow_for_chem-2.2.1.dist-info/METADATA,sha256=YKuXgdklakB9DH4cFpTy3EbYdmo8dz9BgQQ4bI9CVQs,6476
+bayesianflow_for_chem-2.2.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+bayesianflow_for_chem-2.2.1.dist-info/entry_points.txt,sha256=N63RMoJsr8rxuKxc7Fj802SL8J5AlpCoPkS8E3IFPLI,54
+bayesianflow_for_chem-2.2.1.dist-info/top_level.txt,sha256=KHsanI3BMCt8D9Qpze2ycrF6nMa3PyojgO6eS1c8kco,22
+bayesianflow_for_chem-2.2.1.dist-info/RECORD,,

bayesianflow_for_chem-2.1.0.dist-info/RECORD REMOVED

@@ -1,15 +0,0 @@
-bayesianflow_for_chem/__init__.py,sha256=bmqERRnnmyK7UgUYn3BFH_YipKfTxNNjRzu9__RVid4,612
-bayesianflow_for_chem/cli.py,sha256=A1cFz6jhpLEpbK9r8GxCLdbnPCzQ4RrsavLKg_lssVg,24208
-bayesianflow_for_chem/data.py,sha256=Pl0gGWHmMKTKHpsxznvLgYPCwwlLNL7nqH19Vipjkxs,6584
-bayesianflow_for_chem/model.py,sha256=UW5hfAofYK9dH9euDPYWfJedVMRFxk8WtY427fObf70,59641
-bayesianflow_for_chem/scorer.py,sha256=gQFUlkyxitch02ntqcRh1ZS8aondKLynW5U6NfTQTb4,4084
-bayesianflow_for_chem/spectra.py,sha256=Ba9ib1aDvTtDYbH3b4d-lIty3ZSQMu7jwehuV2KmhwA,1781
-bayesianflow_for_chem/tool.py,sha256=bqoIMas8bmcjYBghuQWLh75Eq8ZlG6mh9ZeDzWGOmuw,24790
-bayesianflow_for_chem/train.py,sha256=jYkhSguW50lrcTEydCQ20yig_mmc1j7WH9KmVwBCTAo,9727
-bayesianflow_for_chem/vocab.txt,sha256=HgtAZmpWYk4y8PqEVC4vqut1vE75DfRKE_10s2UW0rU,790
-bayesianflow_for_chem-2.1.0.dist-info/licenses/LICENSE,sha256=hIahDEOTzuHCU5J2nd07LWwkLW7Hko4UFO__ffsvB-8,34523
-bayesianflow_for_chem-2.1.0.dist-info/METADATA,sha256=ctwor5jnPCmAo-1wuIGkX70aqkXj-8YQxr27AJOdEjM,6057
-bayesianflow_for_chem-2.1.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
-bayesianflow_for_chem-2.1.0.dist-info/entry_points.txt,sha256=N63RMoJsr8rxuKxc7Fj802SL8J5AlpCoPkS8E3IFPLI,54
-bayesianflow_for_chem-2.1.0.dist-info/top_level.txt,sha256=KHsanI3BMCt8D9Qpze2ycrF6nMa3PyojgO6eS1c8kco,22
-bayesianflow_for_chem-2.1.0.dist-info/RECORD,,

/bayesianflow_for_chem/{vocab.txt → _data/vocab.txt} RENAMED
File without changes
{bayesianflow_for_chem-2.1.0.dist-info → bayesianflow_for_chem-2.2.1.dist-info}/WHEEL RENAMED
File without changes
{bayesianflow_for_chem-2.1.0.dist-info → bayesianflow_for_chem-2.2.1.dist-info}/entry_points.txt RENAMED
File without changes
{bayesianflow_for_chem-2.1.0.dist-info → bayesianflow_for_chem-2.2.1.dist-info}/licenses/LICENSE RENAMED
File without changes
{bayesianflow_for_chem-2.1.0.dist-info → bayesianflow_for_chem-2.2.1.dist-info}/top_level.txt RENAMED
File without changes