bayesianflow-for-chem 2.0.5__tar.gz → 2.2.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


Files changed (25)
  1. {bayesianflow_for_chem-2.0.5 → bayesianflow_for_chem-2.2.1}/PKG-INFO +4 -1
  2. {bayesianflow_for_chem-2.0.5 → bayesianflow_for_chem-2.2.1}/README.md +3 -0
  3. {bayesianflow_for_chem-2.0.5 → bayesianflow_for_chem-2.2.1}/bayesianflow_for_chem/__init__.py +4 -3
  4. {bayesianflow_for_chem-2.0.5 → bayesianflow_for_chem-2.2.1}/bayesianflow_for_chem/cli.py +120 -31
  5. {bayesianflow_for_chem-2.0.5 → bayesianflow_for_chem-2.2.1}/bayesianflow_for_chem/data.py +1 -1
  6. {bayesianflow_for_chem-2.0.5 → bayesianflow_for_chem-2.2.1}/bayesianflow_for_chem/model.py +257 -113
  7. {bayesianflow_for_chem-2.0.5 → bayesianflow_for_chem-2.2.1}/bayesianflow_for_chem/tool.py +150 -89
  8. {bayesianflow_for_chem-2.0.5 → bayesianflow_for_chem-2.2.1}/bayesianflow_for_chem/train.py +8 -1
  9. {bayesianflow_for_chem-2.0.5 → bayesianflow_for_chem-2.2.1}/bayesianflow_for_chem.egg-info/PKG-INFO +4 -1
  10. {bayesianflow_for_chem-2.0.5 → bayesianflow_for_chem-2.2.1}/bayesianflow_for_chem.egg-info/SOURCES.txt +3 -1
  11. {bayesianflow_for_chem-2.0.5 → bayesianflow_for_chem-2.2.1}/pyproject.toml +3 -0
  12. {bayesianflow_for_chem-2.0.5 → bayesianflow_for_chem-2.2.1}/setup.py +1 -1
  13. bayesianflow_for_chem-2.2.1/test/test_cli_plugin.py +55 -0
  14. bayesianflow_for_chem-2.2.1/test/test_jit_compatibility.py +28 -0
  15. {bayesianflow_for_chem-2.0.5 → bayesianflow_for_chem-2.2.1}/LICENSE +0 -0
  16. {bayesianflow_for_chem-2.0.5/bayesianflow_for_chem → bayesianflow_for_chem-2.2.1/bayesianflow_for_chem/_data}/vocab.txt +0 -0
  17. {bayesianflow_for_chem-2.0.5 → bayesianflow_for_chem-2.2.1}/bayesianflow_for_chem/scorer.py +0 -0
  18. {bayesianflow_for_chem-2.0.5 → bayesianflow_for_chem-2.2.1}/bayesianflow_for_chem/spectra.py +0 -0
  19. {bayesianflow_for_chem-2.0.5 → bayesianflow_for_chem-2.2.1}/bayesianflow_for_chem.egg-info/dependency_links.txt +0 -0
  20. {bayesianflow_for_chem-2.0.5 → bayesianflow_for_chem-2.2.1}/bayesianflow_for_chem.egg-info/entry_points.txt +0 -0
  21. {bayesianflow_for_chem-2.0.5 → bayesianflow_for_chem-2.2.1}/bayesianflow_for_chem.egg-info/requires.txt +0 -0
  22. {bayesianflow_for_chem-2.0.5 → bayesianflow_for_chem-2.2.1}/bayesianflow_for_chem.egg-info/top_level.txt +0 -0
  23. {bayesianflow_for_chem-2.0.5 → bayesianflow_for_chem-2.2.1}/setup.cfg +0 -0
  24. {bayesianflow_for_chem-2.0.5 → bayesianflow_for_chem-2.2.1}/test/test_merge_lora.py +0 -0
  25. {bayesianflow_for_chem-2.0.5 → bayesianflow_for_chem-2.2.1}/test/test_molecular_embedding.py +0 -0
{bayesianflow_for_chem-2.0.5 → bayesianflow_for_chem-2.2.1}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: bayesianflow_for_chem
-Version: 2.0.5
+Version: 2.2.1
 Summary: Bayesian flow network framework for Chemistry
 Home-page: https://augus1999.github.io/bayesian-flow-network-for-chemistry/
 Author: Nianze A. Tao
@@ -54,12 +54,14 @@ This is the repository of the PyTorch implementation of ChemBFN model.
 
 [![PyPI](https://img.shields.io/pypi/v/bayesianflow-for-chem?color=ff69b4)](https://pypi.org/project/bayesianflow-for-chem/)
 ![pytest](https://github.com/Augus1999/bayesian-flow-network-for-chemistry/actions/workflows/pytest.yml/badge.svg)
+[![document](https://github.com/Augus1999/bayesian-flow-network-for-chemistry/actions/workflows/pages/pages-build-deployment/badge.svg)](https://augus1999.github.io/bayesian-flow-network-for-chemistry/)
 
 ## Features
 
 ChemBFN provides the state-of-the-art functionalities of
 * SMILES or SELFIES-based *de novo* molecule generation
 * Protein sequence *de novo* generation
+* Template optimisation (mol2mol)
 * Classifier-free guidance conditional generation (single or multi-objective optimisation)
 * Context-guided conditional generation (inpaint)
 * Outstanding out-of-distribution chemical space sampling
@@ -71,6 +73,7 @@ in an all-in-one-model style.
 
 ## News
 
+* [09/10/2025] A web app [`chembfn_webui`](https://github.com/Augus1999/ChemBFN-WebUI) for hosting ChemBFN models is available on [PyPI](https://pypi.org/project/chembfn-webui/).
 * [30/01/2025] The package `bayesianflow_for_chem` is available on [PyPI](https://pypi.org/project/bayesianflow-for-chem/).
 * [21/01/2025] Our first paper has been accepted by [JCIM](https://pubs.acs.org/doi/10.1021/acs.jcim.4c01792).
 * [17/12/2024] The second paper of out-of-distribution generation is available on [arxiv.org](https://arxiv.org/abs/2412.11439).
{bayesianflow_for_chem-2.0.5 → bayesianflow_for_chem-2.2.1}/README.md

@@ -9,12 +9,14 @@ This is the repository of the PyTorch implementation of ChemBFN model.
 
 [![PyPI](https://img.shields.io/pypi/v/bayesianflow-for-chem?color=ff69b4)](https://pypi.org/project/bayesianflow-for-chem/)
 ![pytest](https://github.com/Augus1999/bayesian-flow-network-for-chemistry/actions/workflows/pytest.yml/badge.svg)
+[![document](https://github.com/Augus1999/bayesian-flow-network-for-chemistry/actions/workflows/pages/pages-build-deployment/badge.svg)](https://augus1999.github.io/bayesian-flow-network-for-chemistry/)
 
 ## Features
 
 ChemBFN provides the state-of-the-art functionalities of
 * SMILES or SELFIES-based *de novo* molecule generation
 * Protein sequence *de novo* generation
+* Template optimisation (mol2mol)
 * Classifier-free guidance conditional generation (single or multi-objective optimisation)
 * Context-guided conditional generation (inpaint)
 * Outstanding out-of-distribution chemical space sampling
@@ -26,6 +28,7 @@ in an all-in-one-model style.
 
 ## News
 
+* [09/10/2025] A web app [`chembfn_webui`](https://github.com/Augus1999/ChemBFN-WebUI) for hosting ChemBFN models is available on [PyPI](https://pypi.org/project/chembfn-webui/).
 * [30/01/2025] The package `bayesianflow_for_chem` is available on [PyPI](https://pypi.org/project/bayesianflow-for-chem/).
 * [21/01/2025] Our first paper has been accepted by [JCIM](https://pubs.acs.org/doi/10.1021/acs.jcim.4c01792).
 * [17/12/2024] The second paper of out-of-distribution generation is available on [arxiv.org](https://arxiv.org/abs/2412.11439).
{bayesianflow_for_chem-2.0.5 → bayesianflow_for_chem-2.2.1}/bayesianflow_for_chem/__init__.py

@@ -3,10 +3,8 @@
 """
 ChemBFN package.
 """
-import colorama
 from . import data, tool, train, scorer, spectra
 from .model import ChemBFN, MLP, EnsembleChemBFN
-from .cli import main_script
 
 __all__ = [
     "data",
@@ -18,7 +16,7 @@ __all__ = [
     "MLP",
     "EnsembleChemBFN",
 ]
-__version__ = "2.0.5"
+__version__ = "2.2.1"
 __author__ = "Nianze A. Tao (Omozawa Sueno)"
 
 
@@ -29,6 +27,9 @@ def main() -> None:
     :return:
     :rtype: None
     """
+    import colorama
+    from bayesianflow_for_chem.cli import main_script
+
     colorama.just_fix_windows_console()
     main_script(__version__)
     colorama.deinit()
{bayesianflow_for_chem-2.0.5 → bayesianflow_for_chem-2.2.1}/bayesianflow_for_chem/cli.py

@@ -12,13 +12,8 @@ from pathlib import Path
 from functools import partial
 from typing import List, Tuple, Dict, Union, Callable
 import torch
-import lightning as L
 from rdkit.Chem import MolFromSmiles, CanonSmiles
-from torch.utils.data import DataLoader
-from lightning.pytorch import loggers
-from lightning.pytorch.callbacks import ModelCheckpoint
 from bayesianflow_for_chem import ChemBFN, MLP
-from bayesianflow_for_chem.train import Model
 from bayesianflow_for_chem.scorer import smiles_valid, Scorer
 from bayesianflow_for_chem.data import (
     VOCAB_COUNT,
@@ -32,7 +27,7 @@ from bayesianflow_for_chem.data import (
     collate,
     CSVData,
 )
-from bayesianflow_for_chem.tool import sample, inpaint
+from bayesianflow_for_chem.tool import sample, inpaint, optimise, adjust_lora_
 
 
 """
@@ -90,6 +85,7 @@ checkpoint_save_path = "home/user/project/ckpt"
 train_strategy = "auto" # or any strategy supported by Lightning, e.g., "ddp"
 accumulate_grad_batches = 1
 enable_progress_bar = false
+plugin_script = "" # define customised behaviours of dataset, dataloader, etc. in a Python script
 
 # Remove this table if inference is unnecessary
 [inference]
@@ -99,9 +95,11 @@ sample_size = 1000 # the minimum number of samples you want
 sample_step = 100
 sample_method = "ODE:0.5" # ODE-solver with temperature of 0.5; another choice is "BFN"
 semi_autoregressive = false
+lora_scaling = 1.0 # LoRA scaling if applied
 guidance_objective = [-0.023, 0.09, 0.113] # if no objective is needed set it to empty array []
 guidance_objective_strength = 4.0 # unnecessary if guidance_objective = []
 guidance_scaffold = "c1ccccc1" # if no scaffold is used set it to empty string ""
+sample_template = "" # template for mol2mol task; leave it blank if scaffold is used
 unwanted_token = []
 exclude_invalid = true # to only store valid samples
 exclude_duplicate = true # to only store unique samples
@@ -118,6 +116,32 @@ madmadmadmadmadmadmadmadmadmadmadmadmadmadmad
 madmadmadmadmadmadmadmadmadmadmadmadmadmadmad
 """
 
+_ALLOWED_PLUGINS = [
+    "collate_fn",
+    "num_workers",
+    "max_sequence_length",
+    "shuffle",
+    "CustomData",
+]
+
+
+def _load_plugin(plugin_file: str) -> Dict[str, Union[int, Callable, object, None]]:
+    if not plugin_file:
+        return {n: None for n in _ALLOWED_PLUGINS}
+    from importlib import util as iutil
+
+    spec = iutil.spec_from_file_location(Path(plugin_file).stem, plugin_file)
+    plugins = iutil.module_from_spec(spec)
+    spec.loader.exec_module(plugins)
+    plugin_names: List[str] = plugins.__all__
+    plugin_dict = {}
+    for n in _ALLOWED_PLUGINS:
+        if n in plugin_names:
+            plugin_dict[n] = getattr(plugins, n)
+        else:
+            plugin_dict[n] = None
+    return plugin_dict
+
 
 def parse_cli(version: str) -> argparse.Namespace:
     """
@@ -130,7 +154,7 @@ def parse_cli(version: str) -> argparse.Namespace:
     """
     parser = argparse.ArgumentParser(
         description="Madmol: a CLI molecular design tool for "
-        "de novo design, R-group replacement, and sequence in-filling, "
+        "de novo design, R-group replacement, molecule optimisation, and sequence in-filling, "
         "based on generative route of ChemBFN method. "
         "Let's make some craziest molecules.",
         epilog=f"Madmol {version}, developed in Hiroshima University by chemists for chemists. "
@@ -157,7 +181,7 @@ def parse_cli(version: str) -> argparse.Namespace:
         "-D",
         "--dryrun",
         action="store_true",
-        help="dry-run to check the configurations",
+        help="dry-run to check the configurations and exit",
     )
     parser.add_argument("-V", "--version", action="version", version=version)
     return parser.parse_args()
@@ -265,6 +289,14 @@ def load_runtime_config(
                 f"\033[0;31mCritical\033[0;0m in {config_file}: Restart checkpoint file {ckpt_file} does not exist."
             )
             flag_critical += 1
+        # ↓ added in v2.2.0; need to be compatible with old versions.
+        plugin_script: str = config["train"].get("plugin_script", "")
+        if plugin_script:
+            if not os.path.exists(plugin_script):
+                print(
+                    f"\033[0;31mCritical\033[0;0m in {config_file}: Plugin script {plugin_script} does not exist."
+                )
+                flag_critical += 1
     if "inference" in config:
         if not "train" in config:
             if not isinstance(config["inference"]["sequence_length"], int):
@@ -284,6 +316,14 @@
                 f"\033[0;33mWarning\033[0;0m in {config_file}: Directory {result_dir} to save the result does not exist."
             )
             flag_warning += 1
+        if (
+            config["inference"]["guidance_scaffold"] != ""
+            and config["inference"]["sample_template"] != ""
+        ):
+            print(
+                f"\033[0;33mWarning\033[0;0m in {config_file}: Inpaint task or mol2mol task?"
+            )
+            flag_warning += 1
     return config, flag_critical, flag_warning
 
 
@@ -325,14 +365,15 @@ def main_script(version: str) -> None:
             )
             flag_warning += 1
         if not os.path.exists(runtime_config["train"]["checkpoint_save_path"]):
-            os.makedirs(runtime_config["train"]["checkpoint_save_path"])
+            if not parser.dryrun:  # only create it in real tasks
+                os.makedirs(runtime_config["train"]["checkpoint_save_path"])
     else:
         if not model_config["ChemBFN"]["base_model"]:
             print(
                 f"\033[0;33mWarning\033[0;0m in {parser.model_config}: You should load a pretrained ChemBFN model."
             )
             flag_warning += 1
-        if not model_config["MLP"]["base_model"]:
+        if "MLP" in model_config and not model_config["MLP"]["base_model"]:
             print(
                 f"\033[0;33mWarning\033[0;0m in {parser.model_config}: You should load a pretrained MLP."
             )
@@ -340,13 +381,17 @@
     if "inference" in runtime_config:
         if runtime_config["inference"]["guidance_objective"]:
             if not "MLP" in model_config:
-                print(f"Warning in {parser.model_config}: Oh no, you don't have a MLP.")
+                print(
+                    f"\033[0;33mWarning\033[0;0m in {parser.model_config}: Oh no, you don't have a MLP."
+                )
                 flag_warning += 1
     if parser.dryrun:
         if flag_critical != 0:
             print("Configuration check failed!")
         elif flag_warning != 0:
-            print("Your job will probably run, but it may not follow your expectation.")
+            print(
+                "Your job will probably run, but it may not follow your expectations."
+            )
         else:
             print("Configuration check passed.")
         return
@@ -405,6 +450,15 @@
     mlp = None
     # ------- train -------
     if "train" in runtime_config:
+        import lightning as L
+        from torch.utils.data import DataLoader
+        from lightning.pytorch import loggers
+        from lightning.pytorch.callbacks import ModelCheckpoint
+        from bayesianflow_for_chem.train import Model
+
+        # ####### get plugins #######
+        plugin_file = runtime_config["train"].get("plugin_script", "")
+        plugins = _load_plugin(plugin_file)
         # ####### build scorer #######
         if (tokeniser_name == "smiles" or tokeniser_name == "safe") and runtime_config[
             "train"
@@ -418,30 +472,43 @@
         mol_tag = runtime_config["train"]["molecule_tag"]
         obj_tag = runtime_config["train"]["objective_tag"]
         dataset_file = runtime_config["train"]["dataset"]
-        with open(dataset_file, "r") as db:
-            _data = db.readlines()
-        header = _data[0]
-        mol_idx = []
-        for i, tag in enumerate(header.replace("\n", "").split(",")):
-            if tag == mol_tag:
-                mol_idx.append(i)
-        _data_len = []
-        for i in _data[1:]:
-            i = i.replace("\n", "").split(",")
-            _mol = ".".join([i[j] for j in mol_idx])
-            _data_len.append(tokeniser(_mol).shape[-1])
-        lmax = max(_data_len)
-        dataset = CSVData(dataset_file)
+        if plugins["max_sequence_length"]:
+            lmax = plugins["max_sequence_length"]
+        else:
+            with open(dataset_file, "r") as db:
+                _data = db.readlines()
+            _header = _data[0]
+            _mol_idx = []
+            for i, tag in enumerate(_header.replace("\n", "").split(",")):
+                if tag == mol_tag:
+                    _mol_idx.append(i)
+            _data_len = []
+            for i in _data[1:]:
+                i = i.replace("\n", "").split(",")
+                _mol = ".".join([i[j] for j in _mol_idx])
+                _data_len.append(tokeniser(_mol).shape[-1])
+            lmax = max(_data_len)
+            del _data, _data_len, _header, _mol_idx  # clear memory
+        if plugins["CustomData"] is not None:
+            dataset = plugins["CustomData"](dataset_file)
+        else:
+            dataset = CSVData(dataset_file)
         dataset.map(
             partial(_encode, mol_tag=mol_tag, obj_tag=obj_tag, tokeniser=tokeniser)
         )
         dataloader = DataLoader(
             dataset,
             runtime_config["train"]["batch_size"],
-            True,
-            num_workers=4,
-            collate_fn=collate,
-            persistent_workers=True,
+            True if plugins["shuffle"] is None else plugins["shuffle"],
+            num_workers=4 if plugins["num_workers"] is None else plugins["num_workers"],
+            collate_fn=(
+                collate if plugins["collate_fn"] is None else plugins["collate_fn"]
+            ),
+            persistent_workers=(
+                True
+                if (plugins["num_workers"] is None or plugins["num_workers"] > 0)
+                else False
+            ),
         )
         # ####### build trainer #######
         logger_name = runtime_config["train"]["logger_name"].lower()
@@ -520,6 +587,8 @@
     if "train" in runtime_config:
         bfn = model.model
         mlp = model.mlp
+    # ↓ added in v2.1.0; need to be compatible with old versions
+    lora_scaling = runtime_config["inference"].get("lora_scaling", 1.0)
     # ####### strat inference #######
     bfn.semi_autoregressive = runtime_config["inference"]["semi_autoregressive"]
     _device = (
@@ -550,8 +619,16 @@
             x[:-1], (0, sequence_length - x.shape[-1] + 1), value=0
         )
         x = x[None, :].repeat(batch_size, 1)
+    # then sample template will be ignored.
+    elif runtime_config["inference"]["sample_template"]:
+        template = runtime_config["inference"]["sample_template"]
+        x = tokeniser(template)
+        x = torch.nn.functional.pad(x, (0, sequence_length - x.shape[-1]), value=0)
+        x = x[None, :].repeat(batch_size, 1)
     else:
         x = None
+    if bfn.lora_enabled:
+        adjust_lora_(bfn, lora_scaling)
     mols = []
     while len(mols) < runtime_config["inference"]["sample_size"]:
         if x is None:
@@ -567,7 +644,7 @@
                 method=sample_method,
                 allowed_tokens=allowed_token,
             )
-        else:
+        elif runtime_config["inference"]["guidance_scaffold"]:
             s = inpaint(
                 bfn,
                 x,
@@ -579,6 +656,18 @@
                 method=sample_method,
                 allowed_tokens=allowed_token,
             )
+        else:
+            s = optimise(
+                bfn,
+                x,
+                sample_step,
+                y,
+                guidance_strength,
+                _device,
+                vocab_keys,
+                method=sample_method,
+                allowed_tokens=allowed_token,
+            )
         if runtime_config["inference"]["exclude_invalid"]:
             s = [i for i in s if i]
         if tokeniser_name == "smiles" or tokeniser_name == "safe":
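
With this hunk the sampling dispatch has three routes: no context falls to `sample`, a non-empty `guidance_scaffold` selects `inpaint`, and a non-empty `sample_template` reaches `optimise` (the mol2mol route). A minimal sketch of driving the template route directly follows; the positional argument order of `optimise` is taken from this diff, while the helper name, the default hyperparameters, and passing `y=None` for unconditioned optimisation are assumptions:

import torch
from bayesianflow_for_chem.data import VOCAB_KEYS
from bayesianflow_for_chem.tool import optimise


def mol2mol(bfn, tokeniser, template, sequence_length, batch_size,
            y=None, guidance_strength=4.0, device=torch.device("cpu"),
            sample_step=100):
    # Tokenise the template, pad it to the model sequence length, and tile it
    # over the batch -- the same preprocessing as the CLI branch above.
    x = tokeniser(template)
    x = torch.nn.functional.pad(x, (0, sequence_length - x.shape[-1]), value=0)
    x = x[None, :].repeat(batch_size, 1)
    return optimise(
        bfn, x, sample_step, y, guidance_strength, device, VOCAB_KEYS,
        method="ODE:0.5",
    )
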
{bayesianflow_for_chem-2.0.5 → bayesianflow_for_chem-2.2.1}/bayesianflow_for_chem/data.py

@@ -61,7 +61,7 @@ def load_vocab(
     }
 
 
-_DEFUALT_VOCAB = load_vocab(__filedir__ / "vocab.txt")
+_DEFUALT_VOCAB = load_vocab(__filedir__ / "_data/vocab.txt")
 VOCAB_KEYS: List[str] = _DEFUALT_VOCAB["vocab_keys"]
 VOCAB_DICT: Dict[str, int] = _DEFUALT_VOCAB["vocab_dict"]
 VOCAB_COUNT: int = _DEFUALT_VOCAB["vocab_count"]