bayesianflow-for-chem 2.0.5__tar.gz → 2.2.1__tar.gz
This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
- {bayesianflow_for_chem-2.0.5 → bayesianflow_for_chem-2.2.1}/PKG-INFO +4 -1
- {bayesianflow_for_chem-2.0.5 → bayesianflow_for_chem-2.2.1}/README.md +3 -0
- {bayesianflow_for_chem-2.0.5 → bayesianflow_for_chem-2.2.1}/bayesianflow_for_chem/__init__.py +4 -3
- {bayesianflow_for_chem-2.0.5 → bayesianflow_for_chem-2.2.1}/bayesianflow_for_chem/cli.py +120 -31
- {bayesianflow_for_chem-2.0.5 → bayesianflow_for_chem-2.2.1}/bayesianflow_for_chem/data.py +1 -1
- {bayesianflow_for_chem-2.0.5 → bayesianflow_for_chem-2.2.1}/bayesianflow_for_chem/model.py +257 -113
- {bayesianflow_for_chem-2.0.5 → bayesianflow_for_chem-2.2.1}/bayesianflow_for_chem/tool.py +150 -89
- {bayesianflow_for_chem-2.0.5 → bayesianflow_for_chem-2.2.1}/bayesianflow_for_chem/train.py +8 -1
- {bayesianflow_for_chem-2.0.5 → bayesianflow_for_chem-2.2.1}/bayesianflow_for_chem.egg-info/PKG-INFO +4 -1
- {bayesianflow_for_chem-2.0.5 → bayesianflow_for_chem-2.2.1}/bayesianflow_for_chem.egg-info/SOURCES.txt +3 -1
- {bayesianflow_for_chem-2.0.5 → bayesianflow_for_chem-2.2.1}/pyproject.toml +3 -0
- {bayesianflow_for_chem-2.0.5 → bayesianflow_for_chem-2.2.1}/setup.py +1 -1
- bayesianflow_for_chem-2.2.1/test/test_cli_plugin.py +55 -0
- bayesianflow_for_chem-2.2.1/test/test_jit_compatibility.py +28 -0
- {bayesianflow_for_chem-2.0.5 → bayesianflow_for_chem-2.2.1}/LICENSE +0 -0
- {bayesianflow_for_chem-2.0.5/bayesianflow_for_chem → bayesianflow_for_chem-2.2.1/bayesianflow_for_chem/_data}/vocab.txt +0 -0
- {bayesianflow_for_chem-2.0.5 → bayesianflow_for_chem-2.2.1}/bayesianflow_for_chem/scorer.py +0 -0
- {bayesianflow_for_chem-2.0.5 → bayesianflow_for_chem-2.2.1}/bayesianflow_for_chem/spectra.py +0 -0
- {bayesianflow_for_chem-2.0.5 → bayesianflow_for_chem-2.2.1}/bayesianflow_for_chem.egg-info/dependency_links.txt +0 -0
- {bayesianflow_for_chem-2.0.5 → bayesianflow_for_chem-2.2.1}/bayesianflow_for_chem.egg-info/entry_points.txt +0 -0
- {bayesianflow_for_chem-2.0.5 → bayesianflow_for_chem-2.2.1}/bayesianflow_for_chem.egg-info/requires.txt +0 -0
- {bayesianflow_for_chem-2.0.5 → bayesianflow_for_chem-2.2.1}/bayesianflow_for_chem.egg-info/top_level.txt +0 -0
- {bayesianflow_for_chem-2.0.5 → bayesianflow_for_chem-2.2.1}/setup.cfg +0 -0
- {bayesianflow_for_chem-2.0.5 → bayesianflow_for_chem-2.2.1}/test/test_merge_lora.py +0 -0
- {bayesianflow_for_chem-2.0.5 → bayesianflow_for_chem-2.2.1}/test/test_molecular_embedding.py +0 -0
{bayesianflow_for_chem-2.0.5 → bayesianflow_for_chem-2.2.1}/PKG-INFO RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: bayesianflow_for_chem
-Version: 2.0.5
+Version: 2.2.1
 Summary: Bayesian flow network framework for Chemistry
 Home-page: https://augus1999.github.io/bayesian-flow-network-for-chemistry/
 Author: Nianze A. Tao
@@ -54,12 +54,14 @@ This is the repository of the PyTorch implementation of ChemBFN model.
 
 [](https://pypi.org/project/bayesianflow-for-chem/)
 
+[](https://augus1999.github.io/bayesian-flow-network-for-chemistry/)
 
 ## Features
 
 ChemBFN provides the state-of-the-art functionalities of
 * SMILES or SELFIES-based *de novo* molecule generation
 * Protein sequence *de novo* generation
+* Template optimisation (mol2mol)
 * Classifier-free guidance conditional generation (single or multi-objective optimisation)
 * Context-guided conditional generation (inpaint)
 * Outstanding out-of-distribution chemical space sampling
@@ -71,6 +73,7 @@ in an all-in-one-model style.
 
 ## News
 
+* [09/10/2025] A web app [`chembfn_webui`](https://github.com/Augus1999/ChemBFN-WebUI) for hosting ChemBFN models is available on [PyPI](https://pypi.org/project/chembfn-webui/).
 * [30/01/2025] The package `bayesianflow_for_chem` is available on [PyPI](https://pypi.org/project/bayesianflow-for-chem/).
 * [21/01/2025] Our first paper has been accepted by [JCIM](https://pubs.acs.org/doi/10.1021/acs.jcim.4c01792).
 * [17/12/2024] The second paper of out-of-distribution generation is available on [arxiv.org](https://arxiv.org/abs/2412.11439).
{bayesianflow_for_chem-2.0.5 → bayesianflow_for_chem-2.2.1}/README.md RENAMED
@@ -9,12 +9,14 @@ This is the repository of the PyTorch implementation of ChemBFN model.
 
 [](https://pypi.org/project/bayesianflow-for-chem/)
 
+[](https://augus1999.github.io/bayesian-flow-network-for-chemistry/)
 
 ## Features
 
 ChemBFN provides the state-of-the-art functionalities of
 * SMILES or SELFIES-based *de novo* molecule generation
 * Protein sequence *de novo* generation
+* Template optimisation (mol2mol)
 * Classifier-free guidance conditional generation (single or multi-objective optimisation)
 * Context-guided conditional generation (inpaint)
 * Outstanding out-of-distribution chemical space sampling
@@ -26,6 +28,7 @@ in an all-in-one-model style.
 
 ## News
 
+* [09/10/2025] A web app [`chembfn_webui`](https://github.com/Augus1999/ChemBFN-WebUI) for hosting ChemBFN models is available on [PyPI](https://pypi.org/project/chembfn-webui/).
 * [30/01/2025] The package `bayesianflow_for_chem` is available on [PyPI](https://pypi.org/project/bayesianflow-for-chem/).
 * [21/01/2025] Our first paper has been accepted by [JCIM](https://pubs.acs.org/doi/10.1021/acs.jcim.4c01792).
 * [17/12/2024] The second paper of out-of-distribution generation is available on [arxiv.org](https://arxiv.org/abs/2412.11439).
{bayesianflow_for_chem-2.0.5 → bayesianflow_for_chem-2.2.1}/bayesianflow_for_chem/__init__.py RENAMED
@@ -3,10 +3,8 @@
 """
 ChemBFN package.
 """
-import colorama
 from . import data, tool, train, scorer, spectra
 from .model import ChemBFN, MLP, EnsembleChemBFN
-from .cli import main_script
 
 __all__ = [
     "data",
@@ -18,7 +16,7 @@ __all__ = [
     "MLP",
     "EnsembleChemBFN",
 ]
-__version__ = "2.0.5"
+__version__ = "2.2.1"
 __author__ = "Nianze A. Tao (Omozawa Sueno)"
 
 
@@ -29,6 +27,9 @@ def main() -> None:
     :return:
     :rtype: None
     """
+    import colorama
+    from bayesianflow_for_chem.cli import main_script
+
     colorama.just_fix_windows_console()
     main_script(__version__)
     colorama.deinit()
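Net effect of the `__init__.py` change: `colorama` and the CLI module are now imported lazily inside `main()`, so importing the library no longer pays the CLI's start-up cost. A quick illustration of the behaviour inferred from the hunks above:

```python
# Illustrative: a plain library import no longer pulls in colorama or the CLI machinery.
import bayesianflow_for_chem as bfc

print(bfc.__version__)  # "2.2.1"
# The CLI dependencies load only when the console entry point calls bfc.main().
```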
{bayesianflow_for_chem-2.0.5 → bayesianflow_for_chem-2.2.1}/bayesianflow_for_chem/cli.py RENAMED
@@ -12,13 +12,8 @@ from pathlib import Path
 from functools import partial
 from typing import List, Tuple, Dict, Union, Callable
 import torch
-import lightning as L
 from rdkit.Chem import MolFromSmiles, CanonSmiles
-from torch.utils.data import DataLoader
-from lightning.pytorch import loggers
-from lightning.pytorch.callbacks import ModelCheckpoint
 from bayesianflow_for_chem import ChemBFN, MLP
-from bayesianflow_for_chem.train import Model
 from bayesianflow_for_chem.scorer import smiles_valid, Scorer
 from bayesianflow_for_chem.data import (
     VOCAB_COUNT,
@@ -32,7 +27,7 @@ from bayesianflow_for_chem.data import (
     collate,
     CSVData,
 )
-from bayesianflow_for_chem.tool import sample, inpaint
+from bayesianflow_for_chem.tool import sample, inpaint, optimise, adjust_lora_
 
 
 """
@@ -90,6 +85,7 @@ checkpoint_save_path = "home/user/project/ckpt"
 train_strategy = "auto" # or any strategy supported by Lightning, e.g., "ddp"
 accumulate_grad_batches = 1
 enable_progress_bar = false
+plugin_script = "" # define customised behaviours of dataset, dataloader, etc. in a Python script
 
 # Remove this table if inference is unnecessary
 [inference]
@@ -99,9 +95,11 @@ sample_size = 1000 # the minimum number of samples you want
 sample_step = 100
 sample_method = "ODE:0.5" # ODE-solver with temperature of 0.5; another choice is "BFN"
 semi_autoregressive = false
+lora_scaling = 1.0 # LoRA scaling if applied
 guidance_objective = [-0.023, 0.09, 0.113] # if no objective is needed set it to empty array []
 guidance_objective_strength = 4.0 # unnecessary if guidance_objective = []
 guidance_scaffold = "c1ccccc1" # if no scaffold is used set it to empty string ""
+sample_template = "" # template for mol2mol task; leave it blank if scaffold is used
 unwanted_token = []
 exclude_invalid = true # to only store valid samples
 exclude_duplicate = true # to only store unique samples
@@ -118,6 +116,32 @@ madmadmadmadmadmadmadmadmadmadmadmadmadmadmad
 madmadmadmadmadmadmadmadmadmadmadmadmadmadmad
 """
 
+_ALLOWED_PLUGINS = [
+    "collate_fn",
+    "num_workers",
+    "max_sequence_length",
+    "shuffle",
+    "CustomData",
+]
+
+
+def _load_plugin(plugin_file: str) -> Dict[str, Union[int, Callable, object, None]]:
+    if not plugin_file:
+        return {n: None for n in _ALLOWED_PLUGINS}
+    from importlib import util as iutil
+
+    spec = iutil.spec_from_file_location(Path(plugin_file).stem, plugin_file)
+    plugins = iutil.module_from_spec(spec)
+    spec.loader.exec_module(plugins)
+    plugin_names: List[str] = plugins.__all__
+    plugin_dict = {}
+    for n in _ALLOWED_PLUGINS:
+        if n in plugin_names:
+            plugin_dict[n] = getattr(plugins, n)
+        else:
+            plugin_dict[n] = None
+    return plugin_dict
+
 
 def parse_cli(version: str) -> argparse.Namespace:
     """
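The loader above honours only names that appear both in the plugin module's `__all__` and in `_ALLOWED_PLUGINS`; anything else is ignored, and missing names fall back to the CLI defaults (`collate`, 4 workers, shuffling on, `CSVData`, and the `lmax` scan over the dataset). A minimal sketch of a plugin script for the `plugin_script` option — the file name, the values, and the dataset body are illustrative, and `CustomData` only approximates `CSVData`'s contract (the CLI later calls `dataset.map()`):

```python
# my_plugins.py -- illustrative plugin script; point plugin_script in [train] at this file.
# Only names listed in __all__ AND in the CLI's _ALLOWED_PLUGINS are picked up;
# any name left out falls back to the CLI default.
from torch.utils.data import Dataset

__all__ = ["num_workers", "shuffle", "max_sequence_length", "CustomData"]

num_workers = 0            # replaces the default of 4
shuffle = False            # replaces the default of True
max_sequence_length = 128  # skips the CSV scan that otherwise computes lmax


class CustomData(Dataset):
    """Stand-in dataset; assumes CSVData's contract, including the map() hook."""

    def __init__(self, file: str) -> None:
        with open(file, "r") as f:
            self.lines = [line.strip() for line in f.readlines()[1:]]  # drop the CSV header
        self._fn = None

    def map(self, fn) -> None:
        self._fn = fn  # encoding function registered by the CLI via dataset.map(...)

    def __len__(self) -> int:
        return len(self.lines)

    def __getitem__(self, idx: int):
        item = self.lines[idx]
        return self._fn(item) if self._fn else item
```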
@@ -130,7 +154,7 @@ def parse_cli(version: str) -> argparse.Namespace:
     """
     parser = argparse.ArgumentParser(
         description="Madmol: a CLI molecular design tool for "
-        "de novo design, R-group replacement, and sequence in-filling, "
+        "de novo design, R-group replacement, molecule optimisation, and sequence in-filling, "
         "based on generative route of ChemBFN method. "
         "Let's make some craziest molecules.",
         epilog=f"Madmol {version}, developed in Hiroshima University by chemists for chemists. "
@@ -157,7 +181,7 @@ def parse_cli(version: str) -> argparse.Namespace:
         "-D",
         "--dryrun",
         action="store_true",
-        help="dry-run to check the configurations",
+        help="dry-run to check the configurations and exit",
     )
     parser.add_argument("-V", "--version", action="version", version=version)
     return parser.parse_args()
@@ -265,6 +289,14 @@ def load_runtime_config(
                 f"\033[0;31mCritical\033[0;0m in {config_file}: Restart checkpoint file {ckpt_file} does not exist."
             )
             flag_critical += 1
+        # ↓ added in v2.2.0; needs to be compatible with old versions.
+        plugin_script: str = config["train"].get("plugin_script", "")
+        if plugin_script:
+            if not os.path.exists(plugin_script):
+                print(
+                    f"\033[0;31mCritical\033[0;0m in {config_file}: Plugin script {plugin_script} does not exist."
+                )
+                flag_critical += 1
     if "inference" in config:
         if not "train" in config:
             if not isinstance(config["inference"]["sequence_length"], int):
@@ -284,6 +316,14 @@ def load_runtime_config(
                 f"\033[0;33mWarning\033[0;0m in {config_file}: Directory {result_dir} to save the result does not exist."
             )
             flag_warning += 1
+        if (
+            config["inference"]["guidance_scaffold"] != ""
+            and config["inference"]["sample_template"] != ""
+        ):
+            print(
+                f"\033[0;33mWarning\033[0;0m in {config_file}: Inpaint task or mol2mol task?"
+            )
+            flag_warning += 1
     return config, flag_critical, flag_warning
 
 
@@ -325,14 +365,15 @@ def main_script(version: str) -> None:
             )
             flag_warning += 1
         if not os.path.exists(runtime_config["train"]["checkpoint_save_path"]):
-            os.makedirs(runtime_config["train"]["checkpoint_save_path"])
+            if not parser.dryrun: # only create it in real tasks
+                os.makedirs(runtime_config["train"]["checkpoint_save_path"])
     else:
         if not model_config["ChemBFN"]["base_model"]:
             print(
                 f"\033[0;33mWarning\033[0;0m in {parser.model_config}: You should load a pretrained ChemBFN model."
             )
             flag_warning += 1
-        if not model_config["MLP"]["base_model"]:
+        if "MLP" in model_config and not model_config["MLP"]["base_model"]:
             print(
                 f"\033[0;33mWarning\033[0;0m in {parser.model_config}: You should load a pretrained MLP."
             )
@@ -340,13 +381,17 @@ def main_script(version: str) -> None:
     if "inference" in runtime_config:
         if runtime_config["inference"]["guidance_objective"]:
             if not "MLP" in model_config:
-                print(f"\033[0;33mWarning\033[0;0m in {parser.model_config}: Oh no, you don't have a MLP.")
+                print(
+                    f"\033[0;33mWarning\033[0;0m in {parser.model_config}: Oh no, you don't have a MLP."
+                )
                 flag_warning += 1
     if parser.dryrun:
         if flag_critical != 0:
             print("Configuration check failed!")
         elif flag_warning != 0:
-            print("Your job will probably run, but it may not follow your expectations.")
+            print(
+                "Your job will probably run, but it may not follow your expectations."
+            )
         else:
             print("Configuration check passed.")
         return
@@ -405,6 +450,15 @@ def main_script(version: str) -> None:
         mlp = None
     # ------- train -------
     if "train" in runtime_config:
+        import lightning as L
+        from torch.utils.data import DataLoader
+        from lightning.pytorch import loggers
+        from lightning.pytorch.callbacks import ModelCheckpoint
+        from bayesianflow_for_chem.train import Model
+
+        # ####### get plugins #######
+        plugin_file = runtime_config["train"].get("plugin_script", "")
+        plugins = _load_plugin(plugin_file)
         # ####### build scorer #######
         if (tokeniser_name == "smiles" or tokeniser_name == "safe") and runtime_config[
             "train"
@@ -418,30 +472,43 @@
         mol_tag = runtime_config["train"]["molecule_tag"]
         obj_tag = runtime_config["train"]["objective_tag"]
         dataset_file = runtime_config["train"]["dataset"]
-        # [14 removed lines not preserved in this diff view: the unconditional lmax scan and dataset construction]
+        if plugins["max_sequence_length"]:
+            lmax = plugins["max_sequence_length"]
+        else:
+            with open(dataset_file, "r") as db:
+                _data = db.readlines()
+            _header = _data[0]
+            _mol_idx = []
+            for i, tag in enumerate(_header.replace("\n", "").split(",")):
+                if tag == mol_tag:
+                    _mol_idx.append(i)
+            _data_len = []
+            for i in _data[1:]:
+                i = i.replace("\n", "").split(",")
+                _mol = ".".join([i[j] for j in _mol_idx])
+                _data_len.append(tokeniser(_mol).shape[-1])
+            lmax = max(_data_len)
+            del _data, _data_len, _header, _mol_idx # clear memory
+        if plugins["CustomData"] is not None:
+            dataset = plugins["CustomData"](dataset_file)
+        else:
+            dataset = CSVData(dataset_file)
         dataset.map(
             partial(_encode, mol_tag=mol_tag, obj_tag=obj_tag, tokeniser=tokeniser)
         )
         dataloader = DataLoader(
             dataset,
             runtime_config["train"]["batch_size"],
-            True,
-            num_workers=4,
-            collate_fn=collate,
-            # [1 removed line not preserved in this diff view]
+            True if plugins["shuffle"] is None else plugins["shuffle"],
+            num_workers=4 if plugins["num_workers"] is None else plugins["num_workers"],
+            collate_fn=(
+                collate if plugins["collate_fn"] is None else plugins["collate_fn"]
+            ),
+            persistent_workers=(
+                True
+                if (plugins["num_workers"] is None or plugins["num_workers"] > 0)
+                else False
+            ),
         )
         # ####### build trainer #######
         logger_name = runtime_config["train"]["logger_name"].lower()
@@ -520,6 +587,8 @@ def main_script(version: str) -> None:
     if "train" in runtime_config:
         bfn = model.model
         mlp = model.mlp
+    # ↓ added in v2.1.0; needs to be compatible with old versions
+    lora_scaling = runtime_config["inference"].get("lora_scaling", 1.0)
     # ####### strat inference #######
     bfn.semi_autoregressive = runtime_config["inference"]["semi_autoregressive"]
     _device = (
@@ -550,8 +619,16 @@ def main_script(version: str) -> None:
                 x[:-1], (0, sequence_length - x.shape[-1] + 1), value=0
             )
             x = x[None, :].repeat(batch_size, 1)
+        # then sample template will be ignored.
+        elif runtime_config["inference"]["sample_template"]:
+            template = runtime_config["inference"]["sample_template"]
+            x = tokeniser(template)
+            x = torch.nn.functional.pad(x, (0, sequence_length - x.shape[-1]), value=0)
+            x = x[None, :].repeat(batch_size, 1)
         else:
             x = None
+        if bfn.lora_enabled:
+            adjust_lora_(bfn, lora_scaling)
         mols = []
         while len(mols) < runtime_config["inference"]["sample_size"]:
             if x is None:
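Outside the CLI, the same LoRA rescaling can be applied by hand before sampling. A minimal sketch, assuming a `ChemBFN` instance `bfn` whose LoRA adapters are already loaded (the scaling value is illustrative):

```python
# Illustrative: rescale LoRA adapters in place before sampling, as the CLI does above.
from bayesianflow_for_chem.tool import adjust_lora_

if bfn.lora_enabled:        # same guard the CLI uses
    adjust_lora_(bfn, 0.5)  # halve the adapters' contribution; 1.0 keeps checkpoint behaviour
```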
@@ -567,7 +644,7 @@ def main_script(version: str) -> None:
                     method=sample_method,
                     allowed_tokens=allowed_token,
                 )
-            else:
+            elif runtime_config["inference"]["guidance_scaffold"]:
                 s = inpaint(
                     bfn,
                     x,
@@ -579,6 +656,18 @@ def main_script(version: str) -> None:
                     method=sample_method,
                     allowed_tokens=allowed_token,
                 )
+            else:
+                s = optimise(
+                    bfn,
+                    x,
+                    sample_step,
+                    y,
+                    guidance_strength,
+                    _device,
+                    vocab_keys,
+                    method=sample_method,
+                    allowed_tokens=allowed_token,
+                )
             if runtime_config["inference"]["exclude_invalid"]:
                 s = [i for i in s if i]
             if tokeniser_name == "smiles" or tokeniser_name == "safe":
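Pieced together from the branches above, a direct mol2mol call looks roughly like the sketch below. It is a sketch only: `smiles2token` is assumed to be the package's SMILES tokeniser, `bfn` a loaded `ChemBFN` model, and the template, sizes, and conditioning vector `y` are placeholders (the CLI derives `y` from `guidance_objective` and the MLP):

```python
# Illustrative mol2mol call mirroring the CLI's template branch and optimise() branch.
import torch
from bayesianflow_for_chem.data import VOCAB_KEYS, smiles2token  # smiles2token assumed
from bayesianflow_for_chem.tool import optimise

sequence_length, batch_size = 64, 16
x = smiles2token("c1ccccc1O")  # template molecule (placeholder)
x = torch.nn.functional.pad(x, (0, sequence_length - x.shape[-1]), value=0)
x = x[None, :].repeat(batch_size, 1)  # one template per sample in the batch
y = None  # or an MLP-encoded objective tensor, as in the CLI
s = optimise(bfn, x, 100, y, 4.0, "cuda", VOCAB_KEYS, method="ODE:0.5")
```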
{bayesianflow_for_chem-2.0.5 → bayesianflow_for_chem-2.2.1}/bayesianflow_for_chem/data.py RENAMED
@@ -61,7 +61,7 @@ def load_vocab(
     }
 
 
-_DEFUALT_VOCAB = load_vocab(__filedir__ / "vocab.txt")
+_DEFUALT_VOCAB = load_vocab(__filedir__ / "_data/vocab.txt")
 VOCAB_KEYS: List[str] = _DEFUALT_VOCAB["vocab_keys"]
 VOCAB_DICT: Dict[str, int] = _DEFUALT_VOCAB["vocab_dict"]
 VOCAB_COUNT: int = _DEFUALT_VOCAB["vocab_count"]