bayesianflow-for-chem 2.1.0__tar.gz → 2.2.2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


Files changed (25)
  1. {bayesianflow_for_chem-2.1.0 → bayesianflow_for_chem-2.2.2}/PKG-INFO +5 -2
  2. {bayesianflow_for_chem-2.1.0 → bayesianflow_for_chem-2.2.2}/README.md +3 -0
  3. {bayesianflow_for_chem-2.1.0 → bayesianflow_for_chem-2.2.2}/bayesianflow_for_chem/__init__.py +4 -3
  4. {bayesianflow_for_chem-2.1.0 → bayesianflow_for_chem-2.2.2}/bayesianflow_for_chem/cli.py +110 -34
  5. {bayesianflow_for_chem-2.1.0 → bayesianflow_for_chem-2.2.2}/bayesianflow_for_chem/data.py +23 -22
  6. {bayesianflow_for_chem-2.1.0 → bayesianflow_for_chem-2.2.2}/bayesianflow_for_chem/model.py +131 -154
  7. {bayesianflow_for_chem-2.1.0 → bayesianflow_for_chem-2.2.2}/bayesianflow_for_chem/tool.py +100 -127
  8. {bayesianflow_for_chem-2.1.0 → bayesianflow_for_chem-2.2.2}/bayesianflow_for_chem/train.py +8 -1
  9. {bayesianflow_for_chem-2.1.0 → bayesianflow_for_chem-2.2.2}/bayesianflow_for_chem.egg-info/PKG-INFO +5 -2
  10. {bayesianflow_for_chem-2.1.0 → bayesianflow_for_chem-2.2.2}/bayesianflow_for_chem.egg-info/SOURCES.txt +2 -1
  11. {bayesianflow_for_chem-2.1.0 → bayesianflow_for_chem-2.2.2}/pyproject.toml +3 -0
  12. {bayesianflow_for_chem-2.1.0 → bayesianflow_for_chem-2.2.2}/setup.py +1 -1
  13. bayesianflow_for_chem-2.2.2/test/test_cli_plugin.py +55 -0
  14. {bayesianflow_for_chem-2.1.0 → bayesianflow_for_chem-2.2.2}/LICENSE +0 -0
  15. {bayesianflow_for_chem-2.1.0/bayesianflow_for_chem → bayesianflow_for_chem-2.2.2/bayesianflow_for_chem/_data}/vocab.txt +0 -0
  16. {bayesianflow_for_chem-2.1.0 → bayesianflow_for_chem-2.2.2}/bayesianflow_for_chem/scorer.py +0 -0
  17. {bayesianflow_for_chem-2.1.0 → bayesianflow_for_chem-2.2.2}/bayesianflow_for_chem/spectra.py +0 -0
  18. {bayesianflow_for_chem-2.1.0 → bayesianflow_for_chem-2.2.2}/bayesianflow_for_chem.egg-info/dependency_links.txt +0 -0
  19. {bayesianflow_for_chem-2.1.0 → bayesianflow_for_chem-2.2.2}/bayesianflow_for_chem.egg-info/entry_points.txt +0 -0
  20. {bayesianflow_for_chem-2.1.0 → bayesianflow_for_chem-2.2.2}/bayesianflow_for_chem.egg-info/requires.txt +0 -0
  21. {bayesianflow_for_chem-2.1.0 → bayesianflow_for_chem-2.2.2}/bayesianflow_for_chem.egg-info/top_level.txt +0 -0
  22. {bayesianflow_for_chem-2.1.0 → bayesianflow_for_chem-2.2.2}/setup.cfg +0 -0
  23. {bayesianflow_for_chem-2.1.0 → bayesianflow_for_chem-2.2.2}/test/test_jit_compatibility.py +0 -0
  24. {bayesianflow_for_chem-2.1.0 → bayesianflow_for_chem-2.2.2}/test/test_merge_lora.py +0 -0
  25. {bayesianflow_for_chem-2.1.0 → bayesianflow_for_chem-2.2.2}/test/test_molecular_embedding.py +0 -0
--- bayesianflow_for_chem-2.1.0/PKG-INFO
+++ bayesianflow_for_chem-2.2.2/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: bayesianflow_for_chem
-Version: 2.1.0
+Version: 2.2.2
 Summary: Bayesian flow network framework for Chemistry
 Home-page: https://augus1999.github.io/bayesian-flow-network-for-chemistry/
 Author: Nianze A. Tao
@@ -54,12 +54,14 @@ This is the repository of the PyTorch implementation of ChemBFN model.
 
 [![PyPI](https://img.shields.io/pypi/v/bayesianflow-for-chem?color=ff69b4)](https://pypi.org/project/bayesianflow-for-chem/)
 ![pytest](https://github.com/Augus1999/bayesian-flow-network-for-chemistry/actions/workflows/pytest.yml/badge.svg)
+[![document](https://github.com/Augus1999/bayesian-flow-network-for-chemistry/actions/workflows/pages/pages-build-deployment/badge.svg)](https://augus1999.github.io/bayesian-flow-network-for-chemistry/)
 
 ## Features
 
 ChemBFN provides the state-of-the-art functionalities of
 * SMILES or SELFIES-based *de novo* molecule generation
 * Protein sequence *de novo* generation
+* Template optimisation (mol2mol)
 * Classifier-free guidance conditional generation (single or multi-objective optimisation)
 * Context-guided conditional generation (inpaint)
 * Outstanding out-of-distribution chemical space sampling
@@ -71,6 +73,7 @@ in an all-in-one-model style.
 
 ## News
 
+* [09/10/2025] A web app [`chembfn_webui`](https://github.com/Augus1999/ChemBFN-WebUI) for hosting ChemBFN models is available on [PyPI](https://pypi.org/project/chembfn-webui/).
 * [30/01/2025] The package `bayesianflow_for_chem` is available on [PyPI](https://pypi.org/project/bayesianflow-for-chem/).
 * [21/01/2025] Our first paper has been accepted by [JCIM](https://pubs.acs.org/doi/10.1021/acs.jcim.4c01792).
 * [17/12/2024] The second paper of out-of-distribution generation is available on [arxiv.org](https://arxiv.org/abs/2412.11439).
@@ -93,7 +96,7 @@ You can find pretrained models on our [🤗Hugging Face model page](https://hugg
 
 ## Dataset Handling
 
-We provide a Python class [`CSVData`](https://github.com/Augus1999/bayesian-flow-network-for-chemistry/blob/main/bayesianflow_for_chem/data.py#L152) to handle data stored in CSV or similar format containing headers to identify the entities. The following is a quickstart.
+We provide a Python class [`CSVData`](https://github.com/Augus1999/bayesian-flow-network-for-chemistry/blob/main/bayesianflow_for_chem/data.py#L153) to handle data stored in CSV or similar format containing headers to identify the entities. The following is a quickstart.
 
 1. Download your dataset file (e.g., ESOL from [MoleculeNet](https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/delaney-processed.csv)) and split the file:
 ```python
--- bayesianflow_for_chem-2.1.0/README.md
+++ bayesianflow_for_chem-2.2.2/README.md
@@ -9,12 +9,14 @@ This is the repository of the PyTorch implementation of ChemBFN model.
 
 [![PyPI](https://img.shields.io/pypi/v/bayesianflow-for-chem?color=ff69b4)](https://pypi.org/project/bayesianflow-for-chem/)
 ![pytest](https://github.com/Augus1999/bayesian-flow-network-for-chemistry/actions/workflows/pytest.yml/badge.svg)
+[![document](https://github.com/Augus1999/bayesian-flow-network-for-chemistry/actions/workflows/pages/pages-build-deployment/badge.svg)](https://augus1999.github.io/bayesian-flow-network-for-chemistry/)
 
 ## Features
 
 ChemBFN provides the state-of-the-art functionalities of
 * SMILES or SELFIES-based *de novo* molecule generation
 * Protein sequence *de novo* generation
+* Template optimisation (mol2mol)
 * Classifier-free guidance conditional generation (single or multi-objective optimisation)
 * Context-guided conditional generation (inpaint)
 * Outstanding out-of-distribution chemical space sampling
@@ -26,6 +28,7 @@ in an all-in-one-model style.
 
 ## News
 
+* [09/10/2025] A web app [`chembfn_webui`](https://github.com/Augus1999/ChemBFN-WebUI) for hosting ChemBFN models is available on [PyPI](https://pypi.org/project/chembfn-webui/).
 * [30/01/2025] The package `bayesianflow_for_chem` is available on [PyPI](https://pypi.org/project/bayesianflow-for-chem/).
 * [21/01/2025] Our first paper has been accepted by [JCIM](https://pubs.acs.org/doi/10.1021/acs.jcim.4c01792).
 * [17/12/2024] The second paper of out-of-distribution generation is available on [arxiv.org](https://arxiv.org/abs/2412.11439).
--- bayesianflow_for_chem-2.1.0/bayesianflow_for_chem/__init__.py
+++ bayesianflow_for_chem-2.2.2/bayesianflow_for_chem/__init__.py
@@ -3,10 +3,8 @@
 """
 ChemBFN package.
 """
-import colorama
 from . import data, tool, train, scorer, spectra
 from .model import ChemBFN, MLP, EnsembleChemBFN
-from .cli import main_script
 
 __all__ = [
     "data",
@@ -18,7 +16,7 @@ __all__ = [
     "MLP",
     "EnsembleChemBFN",
 ]
-__version__ = "2.1.0"
+__version__ = "2.2.2"
 __author__ = "Nianze A. Tao (Omozawa Sueno)"
 
 
@@ -29,6 +27,9 @@ def main() -> None:
     :return:
     :rtype: None
     """
+    import colorama
+    from bayesianflow_for_chem.cli import main_script
+
     colorama.just_fix_windows_console()
     main_script(__version__)
     colorama.deinit()
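
Note that moving the `colorama` and `bayesianflow_for_chem.cli` imports from module level into `main()` defers them (and everything `cli.py` pulls in) until the console script actually runs, so a plain `import bayesianflow_for_chem` no longer pays that cost. The same pattern appears below in `cli.py`, where the Lightning imports move into the training branch.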
--- bayesianflow_for_chem-2.1.0/bayesianflow_for_chem/cli.py
+++ bayesianflow_for_chem-2.2.2/bayesianflow_for_chem/cli.py
@@ -12,22 +12,17 @@ from pathlib import Path
 from functools import partial
 from typing import List, Tuple, Dict, Union, Callable
 import torch
-import lightning as L
 from rdkit.Chem import MolFromSmiles, CanonSmiles
-from torch.utils.data import DataLoader
-from lightning.pytorch import loggers
-from lightning.pytorch.callbacks import ModelCheckpoint
 from bayesianflow_for_chem import ChemBFN, MLP
-from bayesianflow_for_chem.train import Model
 from bayesianflow_for_chem.scorer import smiles_valid, Scorer
 from bayesianflow_for_chem.data import (
     VOCAB_COUNT,
     VOCAB_KEYS,
-    AA_VOCAB_COUNT,
-    AA_VOCAB_KEYS,
+    FASTA_VOCAB_COUNT,
+    FASTA_VOCAB_KEYS,
     load_vocab,
     smiles2token,
-    aa2token,
+    fasta2token,
     split_selfies,
     collate,
     CSVData,
@@ -90,6 +85,7 @@ checkpoint_save_path = "home/user/project/ckpt"
 train_strategy = "auto" # or any strategy supported by Lightning, e.g., "ddp"
 accumulate_grad_batches = 1
 enable_progress_bar = false
+plugin_script = "" # define customised behaviours of the dataset, dataloader, etc. in a Python script
 
 # Remove this table if inference is unnecessary
 [inference]
@@ -120,6 +116,49 @@ madmadmadmadmadmadmadmadmadmadmadmadmadmadmad
 madmadmadmadmadmadmadmadmadmadmadmadmadmadmad
 """
 
+_END_MESSAGE = r"""
+If you find this project helpful, please cite us:
+1. N. Tao, and M. Abe, J. Chem. Inf. Model., 2025, 65, 1178-1187.
+2. N. Tao, 2024, arXiv:2412.11439.
+"""
+
+_ERROR_MESSAGE = r"""
+Some who believe in inductive logic are anxious to point out, with
+Reichenbach, that 'the principle of induction is unreservedly accepted
+by the whole of science and that no man can seriously doubt this
+principle in everyday life either'. Yet even supposing this were the
+case—for after all, 'the whole of science' might err—I should still
+contend that a principle of induction is superfluous, and that it must
+lead to logical inconsistencies.
+                                               -- Karl Popper --
+"""
+
+_ALLOWED_PLUGINS = [
+    "collate_fn",
+    "num_workers",
+    "max_sequence_length",
+    "shuffle",
+    "CustomData",
+]
+
+
+def _load_plugin(plugin_file: str) -> Dict[str, Union[int, Callable, object, None]]:
+    if not plugin_file:
+        return {n: None for n in _ALLOWED_PLUGINS}
+    from importlib import util as iutil
+
+    spec = iutil.spec_from_file_location(Path(plugin_file).stem, plugin_file)
+    plugins = iutil.module_from_spec(spec)
+    spec.loader.exec_module(plugins)
+    plugin_names: List[str] = plugins.__all__
+    plugin_dict = {}
+    for n in _ALLOWED_PLUGINS:
+        if n in plugin_names:
+            plugin_dict[n] = getattr(plugins, n)
+        else:
+            plugin_dict[n] = None
+    return plugin_dict
+
 
 def parse_cli(version: str) -> argparse.Namespace:
     """
@@ -267,6 +306,14 @@ def load_runtime_config(
                 f"\033[0;31mCritical\033[0;0m in {config_file}: Restart checkpoint file {ckpt_file} does not exist."
             )
             flag_critical += 1
+        # ↓ added in v2.2.0; needs to be compatible with old versions.
+        plugin_script: str = config["train"].get("plugin_script", "")
+        if plugin_script:
+            if not os.path.exists(plugin_script):
+                print(
+                    f"\033[0;31mCritical\033[0;0m in {config_file}: Plugin script {plugin_script} does not exist."
+                )
+                flag_critical += 1
     if "inference" in config:
         if not "train" in config:
             if not isinstance(config["inference"]["sequence_length"], int):
@@ -335,14 +382,15 @@ def main_script(version: str) -> None:
             )
             flag_warning += 1
         if not os.path.exists(runtime_config["train"]["checkpoint_save_path"]):
-            os.makedirs(runtime_config["train"]["checkpoint_save_path"])
+            if not parser.dryrun:  # only create it in real tasks
+                os.makedirs(runtime_config["train"]["checkpoint_save_path"])
     else:
         if not model_config["ChemBFN"]["base_model"]:
             print(
                 f"\033[0;33mWarning\033[0;0m in {parser.model_config}: You should load a pretrained ChemBFN model."
             )
             flag_warning += 1
-        if not model_config["MLP"]["base_model"]:
+        if "MLP" in model_config and not model_config["MLP"]["base_model"]:
             print(
                 f"\033[0;33mWarning\033[0;0m in {parser.model_config}: You should load a pretrained MLP."
             )
@@ -350,18 +398,22 @@
     if "inference" in runtime_config:
         if runtime_config["inference"]["guidance_objective"]:
             if not "MLP" in model_config:
-                print(f"Warning in {parser.model_config}: Oh no, you don't have a MLP.")
+                print(
+                    f"\033[0;33mWarning\033[0;0m in {parser.model_config}: Oh no, you don't have a MLP."
+                )
                 flag_warning += 1
     if parser.dryrun:
         if flag_critical != 0:
             print("Configuration check failed!")
         elif flag_warning != 0:
-            print("Your job will probably run, but it may not follow your expectation.")
+            print(
+                "Your job will probably run, but it may not follow your expectations."
+            )
         else:
             print("Configuration check passed.")
         return
     if flag_critical != 0:
-        raise RuntimeError
+        raise RuntimeError(_ERROR_MESSAGE)
     print(_MESSAGE.format(version))
     # ####### build tokeniser #######
     tokeniser_config = runtime_config["tokeniser"]
@@ -371,9 +423,9 @@
     vocab_keys = VOCAB_KEYS
     tokeniser = smiles2token
     if tokeniser_name == "fasta":
-        num_vocab = AA_VOCAB_COUNT
-        vocab_keys = AA_VOCAB_KEYS
-        tokeniser = aa2token
+        num_vocab = FASTA_VOCAB_COUNT
+        vocab_keys = FASTA_VOCAB_KEYS
+        tokeniser = fasta2token
     if tokeniser_name == "selfies":
         vocab_data = load_vocab(tokeniser_config["vocab"])
         num_vocab = vocab_data["vocab_count"]
@@ -415,6 +467,15 @@
         mlp = None
     # ------- train -------
     if "train" in runtime_config:
+        import lightning as L
+        from torch.utils.data import DataLoader
+        from lightning.pytorch import loggers
+        from lightning.pytorch.callbacks import ModelCheckpoint
+        from bayesianflow_for_chem.train import Model
+
+        # ####### get plugins #######
+        plugin_file = runtime_config["train"].get("plugin_script", "")
+        plugins = _load_plugin(plugin_file)
         # ####### build scorer #######
         if (tokeniser_name == "smiles" or tokeniser_name == "safe") and runtime_config[
             "train"
@@ -428,30 +489,43 @@
         mol_tag = runtime_config["train"]["molecule_tag"]
         obj_tag = runtime_config["train"]["objective_tag"]
         dataset_file = runtime_config["train"]["dataset"]
-        with open(dataset_file, "r") as db:
-            _data = db.readlines()
-        header = _data[0]
-        mol_idx = []
-        for i, tag in enumerate(header.replace("\n", "").split(",")):
-            if tag == mol_tag:
-                mol_idx.append(i)
-        _data_len = []
-        for i in _data[1:]:
-            i = i.replace("\n", "").split(",")
-            _mol = ".".join([i[j] for j in mol_idx])
-            _data_len.append(tokeniser(_mol).shape[-1])
-        lmax = max(_data_len)
-        dataset = CSVData(dataset_file)
+        if plugins["max_sequence_length"]:
+            lmax = plugins["max_sequence_length"]
+        else:
+            with open(dataset_file, "r") as db:
+                _data = db.readlines()
+            _header = _data[0]
+            _mol_idx = []
+            for i, tag in enumerate(_header.replace("\n", "").split(",")):
+                if tag == mol_tag:
+                    _mol_idx.append(i)
+            _data_len = []
+            for i in _data[1:]:
+                i = i.replace("\n", "").split(",")
+                _mol = ".".join([i[j] for j in _mol_idx])
+                _data_len.append(tokeniser(_mol).shape[-1])
+            lmax = max(_data_len)
+            del _data, _data_len, _header, _mol_idx  # clear memory
+        if plugins["CustomData"] is not None:
+            dataset = plugins["CustomData"](dataset_file)
+        else:
+            dataset = CSVData(dataset_file)
         dataset.map(
             partial(_encode, mol_tag=mol_tag, obj_tag=obj_tag, tokeniser=tokeniser)
         )
         dataloader = DataLoader(
             dataset,
             runtime_config["train"]["batch_size"],
-            True,
-            num_workers=4,
-            collate_fn=collate,
-            persistent_workers=True,
+            True if plugins["shuffle"] is None else plugins["shuffle"],
+            num_workers=4 if plugins["num_workers"] is None else plugins["num_workers"],
+            collate_fn=(
+                collate if plugins["collate_fn"] is None else plugins["collate_fn"]
+            ),
+            persistent_workers=(
+                True
+                if (plugins["num_workers"] is None or plugins["num_workers"] > 0)
+                else False
+            ),
         )
         # ####### build trainer #######
         logger_name = runtime_config["train"]["logger_name"].lower()
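
One detail worth spelling out: the `persistent_workers` expression above is not just style. PyTorch's `DataLoader` raises a `ValueError` when `persistent_workers=True` is combined with `num_workers=0`, so a plugin that sets `num_workers = 0` must also switch persistence off. A compact equivalent of that keyword handling, as a sketch under the same assumptions:

```python
# Equivalent formulation of the DataLoader keyword handling (sketch).
num_workers = 4 if plugins["num_workers"] is None else plugins["num_workers"]
dataloader = DataLoader(
    dataset,
    runtime_config["train"]["batch_size"],
    True if plugins["shuffle"] is None else plugins["shuffle"],
    num_workers=num_workers,
    collate_fn=collate if plugins["collate_fn"] is None else plugins["collate_fn"],
    persistent_workers=num_workers > 0,  # DataLoader forbids True with 0 workers
)
```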
@@ -530,6 +604,7 @@
     if "train" in runtime_config:
         bfn = model.model
         mlp = model.mlp
+    # ↓ added in v2.1.0; needs to be compatible with old versions
     lora_scaling = runtime_config["inference"].get("lora_scaling", 1.0)
     # ####### start inference #######
     bfn.semi_autoregressive = runtime_config["inference"]["semi_autoregressive"]
@@ -622,6 +697,7 @@
             f.write("\n".join(mols))
     # ------- finished -------
     print(" ####### job finished #######")
+    print(_END_MESSAGE)
 
 
 if __name__ == "__main__":
--- bayesianflow_for_chem-2.1.0/bayesianflow_for_chem/data.py
+++ bayesianflow_for_chem-2.2.2/bayesianflow_for_chem/data.py
@@ -1,7 +1,7 @@
 # -*- coding: utf-8 -*-
 # Author: Nianze A. TAO (Omozawa SUENO)
 """
-Tokenise SMILES/SAFE/SELFIES/protein-sequence strings.
+Tokenise SMILES/SAFE/SELFIES/FASTA strings.
 """
 import os
 import re
@@ -14,7 +14,7 @@ from torch.utils.data import Dataset
 
 __filedir__ = Path(__file__).parent
 
-SMI_REGEX_PATTERN = (
+_SMI_REGEX_PATTERN = (
     r"(\[|\]|H[e,f,g,s,o]?|"
     r"L[i,v,a,r,u]|"
     r"B[e,r,a,i,h,k]?|"
@@ -31,11 +31,11 @@ SMI_REGEX_PATTERN = (
     r"\(|\)|\.|=|#|-|\+|\\|\/|:|"
     r"~|@|\?|>>?|\*|\$|\%[0-9]{2}|[0-9])"
 )
-SEL_REGEX_PATTERN = r"(\[[^\]]+]|\.)"
-AA_REGEX_PATTERN = r"(A|B|C|D|E|F|G|H|I|K|L|M|N|P|Q|R|S|T|V|W|Y|Z|-|.)"
-smi_regex = re.compile(SMI_REGEX_PATTERN)
-sel_regex = re.compile(SEL_REGEX_PATTERN)
-aa_regex = re.compile(AA_REGEX_PATTERN)
+_SEL_REGEX_PATTERN = r"(\[[^\]]+]|\.)"
+_FAS_REGEX_PATTERN = r"(A|B|C|D|E|F|G|H|I|J|K|L|M|N|O|P|Q|R|S|T|U|V|W|X|Y|Z|-|\*|\.)"
+_smi_regex = re.compile(_SMI_REGEX_PATTERN)
+_sel_regex = re.compile(_SEL_REGEX_PATTERN)
+_fas_regex = re.compile(_FAS_REGEX_PATTERN)
@@ -61,15 +61,16 @@ def load_vocab(
     }
 
 
-_DEFUALT_VOCAB = load_vocab(__filedir__ / "vocab.txt")
+_DEFUALT_VOCAB = load_vocab(__filedir__ / "_data/vocab.txt")
 VOCAB_KEYS: List[str] = _DEFUALT_VOCAB["vocab_keys"]
 VOCAB_DICT: Dict[str, int] = _DEFUALT_VOCAB["vocab_dict"]
 VOCAB_COUNT: int = _DEFUALT_VOCAB["vocab_count"]
-AA_VOCAB_KEYS = (
-    VOCAB_KEYS[0:3] + "A B C D E F G H I K L M N P Q R S T V W Y Z - .".split()
+FASTA_VOCAB_KEYS = (
+    VOCAB_KEYS[0:3]
+    + "A B C D E F G H I K L M N P Q R S T V W Y Z - . J O U X *".split()
 )
-AA_VOCAB_COUNT = len(AA_VOCAB_KEYS)
-AA_VOCAB_DICT = dict(zip(AA_VOCAB_KEYS, range(AA_VOCAB_COUNT)))
+FASTA_VOCAB_COUNT = len(FASTA_VOCAB_KEYS)
+FASTA_VOCAB_DICT = dict(zip(FASTA_VOCAB_KEYS, range(FASTA_VOCAB_COUNT)))
 
 
 def smiles2vec(smiles: str) -> List[int]:
@@ -81,21 +82,21 @@ def smiles2vec(smiles: str) -> List[int]:
     :return: tokens w/o `<start>` and `<end>`
     :rtype: list
     """
-    tokens = [token for token in smi_regex.findall(smiles)]
+    tokens = [token for token in _smi_regex.findall(smiles)]
     return [VOCAB_DICT[token] for token in tokens]
 
 
-def aa2vec(aa_seq: str) -> List[int]:
+def fasta2vec(fasta: str) -> List[int]:
     """
-    Protein sequence tokenisation using a dataset-independent regex pattern.
+    FASTA sequence tokenisation using a dataset-independent regex pattern.
 
-    :param aa_seq: protein (amino acid) sequence
-    :type aa_seq: str
+    :param fasta: protein (amino acid) sequence
+    :type fasta: str
     :return: tokens w/o `<start>` and `<end>`
     :rtype: list
     """
-    tokens = [token for token in aa_regex.findall(aa_seq)]
-    return [AA_VOCAB_DICT[token] for token in tokens]
+    tokens = [token for token in _fas_regex.findall(fasta)]
+    return [FASTA_VOCAB_DICT[token] for token in tokens]
 
 
 def split_selfies(selfies: str) -> List[str]:
@@ -107,7 +108,7 @@ def split_selfies(selfies: str) -> List[str]:
     :return: SELFIES vocab
     :rtype: list
     """
-    return [token for token in sel_regex.findall(selfies)]
+    return [token for token in _sel_regex.findall(selfies)]
 
 
 def smiles2token(smiles: str) -> Tensor:
@@ -115,9 +116,9 @@
     return torch.tensor([1] + smiles2vec(smiles) + [2], dtype=torch.long)
 
 
-def aa2token(aa_seq: str) -> Tensor:
+def fasta2token(fasta: str) -> Tensor:
     # start token: <start> = 1; end token: <end> = 2
-    return torch.tensor([1] + aa2vec(aa_seq) + [2], dtype=torch.long)
+    return torch.tensor([1] + fasta2vec(fasta) + [2], dtype=torch.long)
 
 
 def collate(batch: List[Dict[str, Tensor]]) -> Dict[str, Tensor]:
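
To make the renamed API concrete, here is what the new tokeniser functions return. Since `FASTA_VOCAB_DICT` maps each key to its position (`dict(zip(FASTA_VOCAB_KEYS, range(FASTA_VOCAB_COUNT)))`) and the first three entries come from `VOCAB_KEYS[0:3]`, the letter `A` lands at index 3, `C` at 5, and so on; the exact sequence below is illustrative:

```python
from bayesianflow_for_chem.data import fasta2vec, fasta2token

# "A" is the 4th vocab entry (index 3) because FASTA_VOCAB_KEYS starts
# with the three special tokens taken from VOCAB_KEYS[0:3].
print(fasta2vec("ACD"))    # [3, 5, 6]
print(fasta2token("ACD"))  # tensor([1, 3, 5, 6, 2]) -- <start>=1, <end>=2
```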