bayesianflow-for-chem 2.1.0__py3-none-any.whl → 2.2.2__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in the public registry. It is provided for informational purposes only.

This release of bayesianflow-for-chem has been flagged as potentially problematic.

@@ -3,10 +3,8 @@
3
3
  """
4
4
  ChemBFN package.
5
5
  """
6
- import colorama
7
6
  from . import data, tool, train, scorer, spectra
8
7
  from .model import ChemBFN, MLP, EnsembleChemBFN
9
- from .cli import main_script
10
8
 
11
9
  __all__ = [
12
10
  "data",
@@ -18,7 +16,7 @@ __all__ = [
18
16
  "MLP",
19
17
  "EnsembleChemBFN",
20
18
  ]
21
- __version__ = "2.1.0"
19
+ __version__ = "2.2.2"
22
20
  __author__ = "Nianze A. Tao (Omozawa Sueno)"
23
21
 
24
22
 
@@ -29,6 +27,9 @@ def main() -> None:
29
27
  :return:
30
28
  :rtype: None
31
29
  """
30
+ import colorama
31
+ from bayesianflow_for_chem.cli import main_script
32
+
32
33
  colorama.just_fix_windows_console()
33
34
  main_script(__version__)
34
35
  colorama.deinit()
@@ -12,22 +12,17 @@ from pathlib import Path
12
12
  from functools import partial
13
13
  from typing import List, Tuple, Dict, Union, Callable
14
14
  import torch
15
- import lightning as L
16
15
  from rdkit.Chem import MolFromSmiles, CanonSmiles
17
- from torch.utils.data import DataLoader
18
- from lightning.pytorch import loggers
19
- from lightning.pytorch.callbacks import ModelCheckpoint
20
16
  from bayesianflow_for_chem import ChemBFN, MLP
21
- from bayesianflow_for_chem.train import Model
22
17
  from bayesianflow_for_chem.scorer import smiles_valid, Scorer
23
18
  from bayesianflow_for_chem.data import (
24
19
  VOCAB_COUNT,
25
20
  VOCAB_KEYS,
26
- AA_VOCAB_COUNT,
27
- AA_VOCAB_KEYS,
21
+ FASTA_VOCAB_COUNT,
22
+ FASTA_VOCAB_KEYS,
28
23
  load_vocab,
29
24
  smiles2token,
30
- aa2token,
25
+ fasta2token,
31
26
  split_selfies,
32
27
  collate,
33
28
  CSVData,
@@ -90,6 +85,7 @@ checkpoint_save_path = "home/user/project/ckpt"
90
85
  train_strategy = "auto" # or any strategy supported by Lightning, e.g., "ddp"
91
86
  accumulate_grad_batches = 1
92
87
  enable_progress_bar = false
88
+ plugin_script = "" # define customised behaviours of dataset, datasetloader, etc in a python script
93
89
 
94
90
  # Remove this table if inference is unnecessary
95
91
  [inference]
@@ -120,6 +116,49 @@ madmadmadmadmadmadmadmadmadmadmadmadmadmadmad
120
116
  madmadmadmadmadmadmadmadmadmadmadmadmadmadmad
121
117
  """
122
118
 
119
+ _END_MESSAGE = r"""
120
+ If you find this project helpful, please cite us:
121
+ 1. N. Tao, and M. Abe, J. Chem. Inf. Model., 2025, 65, 1178-1187.
122
+ 2. N. Tao, 2024, arXiv:2412.11439.
123
+ """
124
+
125
+ _ERROR_MESSAGE = r"""
126
+ Some who believe in inductive logic are anxious to point out, with
127
+ Reichenbach, that 'the principle of induction is unreservedly accepted
128
+ by the whole of science and that no man can seriously doubt this
129
+ principle in everyday life either'. Yet even supposing this were the
130
+ case—for after all, 'the whole of science' might err—I should still
131
+ contend that a principle of induction is superfluous, and that it must
132
+ lead to logical inconsistencies.
133
+ -- Karl Popper --
134
+ """
135
+
136
+ _ALLOWED_PLUGINS = [
137
+ "collate_fn",
138
+ "num_workers",
139
+ "max_sequence_length",
140
+ "shuffle",
141
+ "CustomData",
142
+ ]
143
+
144
+
145
+ def _load_plugin(plugin_file: str) -> Dict[str, Union[int, Callable, object, None]]:
146
+ if not plugin_file:
147
+ return {n: None for n in _ALLOWED_PLUGINS}
148
+ from importlib import util as iutil
149
+
150
+ spec = iutil.spec_from_file_location(Path(plugin_file).stem, plugin_file)
151
+ plugins = iutil.module_from_spec(spec)
152
+ spec.loader.exec_module(plugins)
153
+ plugin_names: List[str] = plugins.__all__
154
+ plugin_dict = {}
155
+ for n in _ALLOWED_PLUGINS:
156
+ if n in plugin_names:
157
+ plugin_dict[n] = getattr(plugins, n)
158
+ else:
159
+ plugin_dict[n] = None
160
+ return plugin_dict
161
+
123
162
 
124
163
  def parse_cli(version: str) -> argparse.Namespace:
125
164
  """
@@ -267,6 +306,14 @@ def load_runtime_config(
267
306
  f"\033[0;31mCritical\033[0;0m in {config_file}: Restart checkpoint file {ckpt_file} does not exist."
268
307
  )
269
308
  flag_critical += 1
309
+ # ↓ added in v2.2.0; needs to stay compatible with old versions.
310
+ plugin_script: str = config["train"].get("plugin_script", "")
311
+ if plugin_script:
312
+ if not os.path.exists(plugin_script):
313
+ print(
314
+ f"\033[0;31mCritical\033[0;0m in {config_file}: Plugin script {plugin_script} does not exist."
315
+ )
316
+ flag_critical += 1
270
317
  if "inference" in config:
271
318
  if not "train" in config:
272
319
  if not isinstance(config["inference"]["sequence_length"], int):
@@ -335,14 +382,15 @@ def main_script(version: str) -> None:
335
382
  )
336
383
  flag_warning += 1
337
384
  if not os.path.exists(runtime_config["train"]["checkpoint_save_path"]):
338
- os.makedirs(runtime_config["train"]["checkpoint_save_path"])
385
+ if not parser.dryrun: # only create it in real tasks
386
+ os.makedirs(runtime_config["train"]["checkpoint_save_path"])
339
387
  else:
340
388
  if not model_config["ChemBFN"]["base_model"]:
341
389
  print(
342
390
  f"\033[0;33mWarning\033[0;0m in {parser.model_config}: You should load a pretrained ChemBFN model."
343
391
  )
344
392
  flag_warning += 1
345
- if not model_config["MLP"]["base_model"]:
393
+ if "MLP" in model_config and not model_config["MLP"]["base_model"]:
346
394
  print(
347
395
  f"\033[0;33mWarning\033[0;0m in {parser.model_config}: You should load a pretrained MLP."
348
396
  )
@@ -350,18 +398,22 @@ def main_script(version: str) -> None:
350
398
  if "inference" in runtime_config:
351
399
  if runtime_config["inference"]["guidance_objective"]:
352
400
  if not "MLP" in model_config:
353
- print(f"Warning in {parser.model_config}: Oh no, you don't have a MLP.")
401
+ print(
402
+ f"\033[0;33mWarning\033[0;0m in {parser.model_config}: Oh no, you don't have a MLP."
403
+ )
354
404
  flag_warning += 1
355
405
  if parser.dryrun:
356
406
  if flag_critical != 0:
357
407
  print("Configuration check failed!")
358
408
  elif flag_warning != 0:
359
- print("Your job will probably run, but it may not follow your expectation.")
409
+ print(
410
+ "Your job will probably run, but it may not follow your expectations."
411
+ )
360
412
  else:
361
413
  print("Configuration check passed.")
362
414
  return
363
415
  if flag_critical != 0:
364
- raise RuntimeError
416
+ raise RuntimeError(_ERROR_MESSAGE)
365
417
  print(_MESSAGE.format(version))
366
418
  # ####### build tokeniser #######
367
419
  tokeniser_config = runtime_config["tokeniser"]
@@ -371,9 +423,9 @@ def main_script(version: str) -> None:
371
423
  vocab_keys = VOCAB_KEYS
372
424
  tokeniser = smiles2token
373
425
  if tokeniser_name == "fasta":
374
- num_vocab = AA_VOCAB_COUNT
375
- vocab_keys = AA_VOCAB_KEYS
376
- tokeniser = aa2token
426
+ num_vocab = FASTA_VOCAB_COUNT
427
+ vocab_keys = FASTA_VOCAB_KEYS
428
+ tokeniser = fasta2token
377
429
  if tokeniser_name == "selfies":
378
430
  vocab_data = load_vocab(tokeniser_config["vocab"])
379
431
  num_vocab = vocab_data["vocab_count"]
@@ -415,6 +467,15 @@ def main_script(version: str) -> None:
415
467
  mlp = None
416
468
  # ------- train -------
417
469
  if "train" in runtime_config:
470
+ import lightning as L
471
+ from torch.utils.data import DataLoader
472
+ from lightning.pytorch import loggers
473
+ from lightning.pytorch.callbacks import ModelCheckpoint
474
+ from bayesianflow_for_chem.train import Model
475
+
476
+ # ####### get plugins #######
477
+ plugin_file = runtime_config["train"].get("plugin_script", "")
478
+ plugins = _load_plugin(plugin_file)
418
479
  # ####### build scorer #######
419
480
  if (tokeniser_name == "smiles" or tokeniser_name == "safe") and runtime_config[
420
481
  "train"
@@ -428,30 +489,43 @@ def main_script(version: str) -> None:
428
489
  mol_tag = runtime_config["train"]["molecule_tag"]
429
490
  obj_tag = runtime_config["train"]["objective_tag"]
430
491
  dataset_file = runtime_config["train"]["dataset"]
431
- with open(dataset_file, "r") as db:
432
- _data = db.readlines()
433
- header = _data[0]
434
- mol_idx = []
435
- for i, tag in enumerate(header.replace("\n", "").split(",")):
436
- if tag == mol_tag:
437
- mol_idx.append(i)
438
- _data_len = []
439
- for i in _data[1:]:
440
- i = i.replace("\n", "").split(",")
441
- _mol = ".".join([i[j] for j in mol_idx])
442
- _data_len.append(tokeniser(_mol).shape[-1])
443
- lmax = max(_data_len)
444
- dataset = CSVData(dataset_file)
492
+ if plugins["max_sequence_length"]:
493
+ lmax = plugins["max_sequence_length"]
494
+ else:
495
+ with open(dataset_file, "r") as db:
496
+ _data = db.readlines()
497
+ _header = _data[0]
498
+ _mol_idx = []
499
+ for i, tag in enumerate(_header.replace("\n", "").split(",")):
500
+ if tag == mol_tag:
501
+ _mol_idx.append(i)
502
+ _data_len = []
503
+ for i in _data[1:]:
504
+ i = i.replace("\n", "").split(",")
505
+ _mol = ".".join([i[j] for j in _mol_idx])
506
+ _data_len.append(tokeniser(_mol).shape[-1])
507
+ lmax = max(_data_len)
508
+ del _data, _data_len, _header, _mol_idx # clear memory
509
+ if plugins["CustomData"] is not None:
510
+ dataset = plugins["CustomData"](dataset_file)
511
+ else:
512
+ dataset = CSVData(dataset_file)
445
513
  dataset.map(
446
514
  partial(_encode, mol_tag=mol_tag, obj_tag=obj_tag, tokeniser=tokeniser)
447
515
  )
448
516
  dataloader = DataLoader(
449
517
  dataset,
450
518
  runtime_config["train"]["batch_size"],
451
- True,
452
- num_workers=4,
453
- collate_fn=collate,
454
- persistent_workers=True,
519
+ True if plugins["shuffle"] is None else plugins["shuffle"],
520
+ num_workers=4 if plugins["num_workers"] is None else plugins["num_workers"],
521
+ collate_fn=(
522
+ collate if plugins["collate_fn"] is None else plugins["collate_fn"]
523
+ ),
524
+ persistent_workers=(
525
+ True
526
+ if (plugins["num_workers"] is None or plugins["num_workers"] > 0)
527
+ else False
528
+ ),
455
529
  )
456
530
  # ####### build trainer #######
457
531
  logger_name = runtime_config["train"]["logger_name"].lower()
@@ -530,6 +604,7 @@ def main_script(version: str) -> None:
530
604
  if "train" in runtime_config:
531
605
  bfn = model.model
532
606
  mlp = model.mlp
607
+ # ↓ added in v2.1.0; needs to stay compatible with old versions
533
608
  lora_scaling = runtime_config["inference"].get("lora_scaling", 1.0)
534
609
  # ####### start inference #######
535
610
  bfn.semi_autoregressive = runtime_config["inference"]["semi_autoregressive"]
@@ -622,6 +697,7 @@ def main_script(version: str) -> None:
622
697
  f.write("\n".join(mols))
623
698
  # ------- finished -------
624
699
  print(" ####### job finished #######")
700
+ print(_END_MESSAGE)
625
701
 
626
702
 
627
703
  if __name__ == "__main__":
@@ -1,7 +1,7 @@
1
1
  # -*- coding: utf-8 -*-
2
2
  # Author: Nianze A. TAO (Omozawa SUENO)
3
3
  """
4
- Tokenise SMILES/SAFE/SELFIES/protein-sequence strings.
4
+ Tokenise SMILES/SAFE/SELFIES/FASTA strings.
5
5
  """
6
6
  import os
7
7
  import re
@@ -14,7 +14,7 @@ from torch.utils.data import Dataset
14
14
 
15
15
  __filedir__ = Path(__file__).parent
16
16
 
17
- SMI_REGEX_PATTERN = (
17
+ _SMI_REGEX_PATTERN = (
18
18
  r"(\[|\]|H[e,f,g,s,o]?|"
19
19
  r"L[i,v,a,r,u]|"
20
20
  r"B[e,r,a,i,h,k]?|"
@@ -31,11 +31,11 @@ SMI_REGEX_PATTERN = (
31
31
  r"\(|\)|\.|=|#|-|\+|\\|\/|:|"
32
32
  r"~|@|\?|>>?|\*|\$|\%[0-9]{2}|[0-9])"
33
33
  )
34
- SEL_REGEX_PATTERN = r"(\[[^\]]+]|\.)"
35
- AA_REGEX_PATTERN = r"(A|B|C|D|E|F|G|H|I|K|L|M|N|P|Q|R|S|T|V|W|Y|Z|-|.)"
36
- smi_regex = re.compile(SMI_REGEX_PATTERN)
37
- sel_regex = re.compile(SEL_REGEX_PATTERN)
38
- aa_regex = re.compile(AA_REGEX_PATTERN)
34
+ _SEL_REGEX_PATTERN = r"(\[[^\]]+]|\.)"
35
+ _FAS_REGEX_PATTERN = r"(A|B|C|D|E|F|G|H|I|J|K|L|M|N|O|P|Q|R|S|T|U|V|W|X|Y|Z|-|\*|\.)"
36
+ _smi_regex = re.compile(_SMI_REGEX_PATTERN)
37
+ _sel_regex = re.compile(_SEL_REGEX_PATTERN)
38
+ _fas_regex = re.compile(_FAS_REGEX_PATTERN)
39
39
 
40
40
 
41
41
  def load_vocab(
@@ -61,15 +61,16 @@ def load_vocab(
61
61
  }
62
62
 
63
63
 
64
- _DEFUALT_VOCAB = load_vocab(__filedir__ / "vocab.txt")
64
+ _DEFUALT_VOCAB = load_vocab(__filedir__ / "_data/vocab.txt")
65
65
  VOCAB_KEYS: List[str] = _DEFUALT_VOCAB["vocab_keys"]
66
66
  VOCAB_DICT: Dict[str, int] = _DEFUALT_VOCAB["vocab_dict"]
67
67
  VOCAB_COUNT: int = _DEFUALT_VOCAB["vocab_count"]
68
- AA_VOCAB_KEYS = (
69
- VOCAB_KEYS[0:3] + "A B C D E F G H I K L M N P Q R S T V W Y Z - .".split()
68
+ FASTA_VOCAB_KEYS = (
69
+ VOCAB_KEYS[0:3]
70
+ + "A B C D E F G H I K L M N P Q R S T V W Y Z - . J O U X *".split()
70
71
  )
71
- AA_VOCAB_COUNT = len(AA_VOCAB_KEYS)
72
- AA_VOCAB_DICT = dict(zip(AA_VOCAB_KEYS, range(AA_VOCAB_COUNT)))
72
+ FASTA_VOCAB_COUNT = len(FASTA_VOCAB_KEYS)
73
+ FASTA_VOCAB_DICT = dict(zip(FASTA_VOCAB_KEYS, range(FASTA_VOCAB_COUNT)))
73
74
 
74
75
 
75
76
  def smiles2vec(smiles: str) -> List[int]:
@@ -81,21 +82,21 @@ def smiles2vec(smiles: str) -> List[int]:
81
82
  :return: tokens w/o `<start>` and `<end>`
82
83
  :rtype: list
83
84
  """
84
- tokens = [token for token in smi_regex.findall(smiles)]
85
+ tokens = [token for token in _smi_regex.findall(smiles)]
85
86
  return [VOCAB_DICT[token] for token in tokens]
86
87
 
87
88
 
88
- def aa2vec(aa_seq: str) -> List[int]:
89
+ def fasta2vec(fasta: str) -> List[int]:
89
90
  """
90
- Protein sequence tokenisation using a dataset-independent regex pattern.
91
+ FASTA sequence tokenisation using a dataset-independent regex pattern.
91
92
 
92
- :param aa_seq: protein (amino acid) sequence
93
- :type aa_seq: str
93
+ :param fasta: protein (amino acid) sequence
94
+ :type fasta: str
94
95
  :return: tokens w/o `<start>` and `<end>`
95
96
  :rtype: list
96
97
  """
97
- tokens = [token for token in aa_regex.findall(aa_seq)]
98
- return [AA_VOCAB_DICT[token] for token in tokens]
98
+ tokens = [token for token in _fas_regex.findall(fasta)]
99
+ return [FASTA_VOCAB_DICT[token] for token in tokens]
99
100
 
100
101
 
101
102
  def split_selfies(selfies: str) -> List[str]:
@@ -107,7 +108,7 @@ def split_selfies(selfies: str) -> List[str]:
107
108
  :return: SELFIES vocab
108
109
  :rtype: list
109
110
  """
110
- return [token for token in sel_regex.findall(selfies)]
111
+ return [token for token in _sel_regex.findall(selfies)]
111
112
 
112
113
 
113
114
  def smiles2token(smiles: str) -> Tensor:
@@ -115,9 +116,9 @@ def smiles2token(smiles: str) -> Tensor:
115
116
  return torch.tensor([1] + smiles2vec(smiles) + [2], dtype=torch.long)
116
117
 
117
118
 
118
- def aa2token(aa_seq: str) -> Tensor:
119
+ def fasta2token(fasta: str) -> Tensor:
119
120
  # start token: <start> = 1; end token: <end> = 2
120
- return torch.tensor([1] + aa2vec(aa_seq) + [2], dtype=torch.long)
121
+ return torch.tensor([1] + fasta2vec(fasta) + [2], dtype=torch.long)
121
122
 
122
123
 
123
124
  def collate(batch: List[Dict[str, Tensor]]) -> Dict[str, Tensor]:
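With the `aa*` names renamed to `fasta*` above, the tokeniser is used exactly as before; a small usage sketch (assuming the 2.2.2 package is installed):

```python
# quick check of the renamed FASTA tokeniser (names taken from the hunks above)
from bayesianflow_for_chem.data import FASTA_VOCAB_COUNT, FASTA_VOCAB_KEYS, fasta2token

tokens = fasta2token("MKTAYIAKQR")          # <start> + one token per residue + <end>
print(tokens.dtype, tokens.shape)           # torch.int64, 1-D tensor of length 12
print(FASTA_VOCAB_COUNT)                    # vocabulary now also covers J, O, U, X and *
print([FASTA_VOCAB_KEYS[i] for i in tokens.tolist()])
```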
@@ -659,12 +659,91 @@ class ChemBFN(nn.Module):
659
659
  return (-logits.gather(-1, x[..., :1]).squeeze(-1)).mean()
660
660
 
661
661
  @staticmethod
662
- def reshape_y(y: Tensor) -> Tensor:
662
+ def _reshape(y: Tensor) -> Tensor:
663
663
  assert y.dim() <= 3 # this doesn't work if the model is frozen in JIT.
664
664
  if y.dim() == 2:
665
665
  return y[:, None, :]
666
666
  return y
667
667
 
668
+ def _process(
669
+ self,
670
+ theta: Tensor,
671
+ mask: Optional[Tuple[Tensor, Tensor]],
672
+ y: Optional[Tensor],
673
+ sample_step: int,
674
+ guidance_strength: float,
675
+ token_mask: Optional[Tensor],
676
+ ) -> Tuple[Tensor, Tensor]:
677
+ # BFN inference process.
678
+ #
679
+ # theta: prior distribution; shape: (n_b, n_t, n_vocab)
680
+ # mask: masked condition distribution; shape: (n_b, n_t, n_vocab)
681
+ # condition distribution mask; shape: (n_b, n_t, 1)
682
+ n_b = theta.shape[0]
683
+ if y is not None:
684
+ y = self._reshape(y)
685
+ for i in torch.linspace(1, sample_step, sample_step, device=self.beta.device):
686
+ t = (i - 1).view(1, 1, 1).repeat(n_b, 1, 1) / sample_step
687
+ p = self.discrete_output_distribution(theta, t, y, guidance_strength)
688
+ if token_mask is not None:
689
+ p = p.masked_fill_(token_mask, 0.0)
690
+ alpha = self.calc_discrete_alpha(t, t + 1 / sample_step)
691
+ e_k = nn.functional.one_hot(torch.argmax(p, -1), self.K).float()
692
+ mu = alpha * (self.K * e_k - 1)
693
+ sigma = (alpha * self.K).sqrt()
694
+ theta = (mu + sigma * torch.randn_like(mu)).exp() * theta
695
+ theta = theta / theta.sum(-1, True)
696
+ if mask is not None:
697
+ x_onehot, x_mask = mask
698
+ theta = x_onehot + (1 - x_mask) * theta
699
+ t_final = torch.ones((n_b, 1, 1), device=self.beta.device)
700
+ p = self.discrete_output_distribution(theta, t_final, y, guidance_strength)
701
+ entropy = -(p * p.log()).sum(-1).mean(-1)
702
+ if token_mask is not None:
703
+ p = p.masked_fill_(token_mask, 0.0)
704
+ return torch.argmax(p, -1), entropy
705
+
706
+ def _ode_process(
707
+ self,
708
+ z: Tensor,
709
+ mask: Optional[Tuple[Tensor, Tensor]],
710
+ y: Optional[Tensor],
711
+ sample_step: int,
712
+ guidance_strength: float,
713
+ token_mask: Optional[Tensor],
714
+ temperature: float,
715
+ ) -> Tuple[Tensor, Tensor]:
716
+ # ODE-solver engaged inference process.
717
+ #
718
+ # z: prior latent vector; shape: (n_b, n_t, n_vocab)
719
+ # mask: masked condition distribution; shape: (n_b, n_t, n_vocab)
720
+ # condition distribution mask; shape: (n_b, n_t, 1)
721
+ n_b = z.shape[0]
722
+ if y is not None:
723
+ y = self._reshape(y)
724
+ for i in torch.linspace(1, sample_step, sample_step, device=self.beta.device):
725
+ t = (i - 1).view(1, 1, 1).repeat(n_b, 1, 1) / sample_step
726
+ theta = softmax(z, -1)
727
+ if mask is not None:
728
+ x_onehot, x_mask = mask
729
+ theta = x_onehot + (1 - x_mask) * theta
730
+ beta = self.calc_beta(t + 1 / sample_step)
731
+ p = self.discrete_output_distribution(theta, t, y, guidance_strength)
732
+ if token_mask is not None:
733
+ p = p.masked_fill_(token_mask, 0.0)
734
+ u = torch.randn_like(z)
735
+ z = (self.K * p - 1) * beta + (self.K * beta * temperature).sqrt() * u
736
+ t_final = torch.ones((n_b, 1, 1), device=self.beta.device)
737
+ theta = softmax(z, -1)
738
+ if mask is not None:
739
+ x_onehot, x_mask = mask
740
+ theta = x_onehot + (1 - x_mask) * theta
741
+ p = self.discrete_output_distribution(theta, t_final, y, guidance_strength)
742
+ entropy = -(p * p.log()).sum(-1).mean(-1)
743
+ if token_mask is not None:
744
+ p = p.masked_fill_(token_mask, 0.0)
745
+ return torch.argmax(p, -1), entropy
746
+
668
747
  @torch.jit.export
669
748
  def sample(
670
749
  self,
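For reference, the loop that `_process` factors out of `sample`, `inpaint`, and `optimise` is the discrete-variable BFN receiver update. Reading it directly off the code above, and assuming `calc_discrete_alpha(t, t + 1/n)` returns the accuracy increment α for that step, one iteration is

$$
\mathbf{e}_k = \operatorname{one\_hot}\!\big(\arg\max\nolimits_k p\big),\qquad
\boldsymbol{\mu} = \alpha\,(K\,\mathbf{e}_k - \mathbf{1}),\qquad
\sigma = \sqrt{\alpha K},
$$
$$
\theta \;\leftarrow\; \frac{\exp(\boldsymbol{\mu} + \sigma\,\boldsymbol{\epsilon})\odot\theta}{\sum_{k}\big[\exp(\boldsymbol{\mu} + \sigma\,\boldsymbol{\epsilon})\odot\theta\big]_k},
\qquad \boldsymbol{\epsilon}\sim\mathcal{N}(\mathbf{0},\mathbf{I}),
$$

while `_ode_process` instead propagates the latent as $z = (Kp - 1)\,\beta(t + 1/n) + \sqrt{K\,\beta(t + 1/n)\,\tau}\,u$ with $\theta = \mathrm{softmax}(z)$, where τ is the sampling temperature. The inpainting variants additionally overwrite the known positions of θ with the one-hot context after every step.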
@@ -680,44 +759,26 @@ class ChemBFN(nn.Module):
680
759
 
681
760
  :param batch_size: batch size
682
761
  :param sequence_size: max sequence length
683
- :param y: conditioning vector; shape: (n_b, 1, n_f) or (n_b, n_f)
762
+ :param y: conditioning vector; shape: (n_b, 1, n_f) or (n_b, n_f)
684
763
  :param sample_step: number of sampling steps
685
764
  :param guidance_strength: strength of conditional generation. It is not used if y is null.
686
765
  :param token_mask: token mask assigning unwanted token(s) with `True`;
687
- shape: (1, 1, n_vocab)
766
+ shape: (1, 1, n_vocab)
688
767
  :type batch_size: int
689
768
  :type sequence_size: int
690
769
  :type y: torch.Tensor | None
691
770
  :type sample_step: int
692
771
  :type guidance_strength: float
693
772
  :type token_mask: torch.Tensor | None
694
- :return: sampled token indices; shape: (n_b, n_t) \n
695
- entropy of the tokens; shape: (n_b)
773
+ :return: sampled token indices; shape: (n_b, n_t) \n
774
+ entropy of the tokens; shape: (n_b)
696
775
  :rtype: tuple
697
776
  """
698
777
  theta = (
699
778
  torch.ones((batch_size, sequence_size, self.K), device=self.beta.device)
700
779
  / self.K
701
780
  )
702
- if y is not None:
703
- y = self.reshape_y(y)
704
- for i in torch.linspace(1, sample_step, sample_step, device=self.beta.device):
705
- t = (i - 1).view(1, 1, 1).repeat(batch_size, 1, 1) / sample_step
706
- p = self.discrete_output_distribution(theta, t, y, guidance_strength)
707
- if token_mask is not None:
708
- p = p.masked_fill_(token_mask, 0.0)
709
- alpha = self.calc_discrete_alpha(t, t + 1 / sample_step)
710
- e_k = nn.functional.one_hot(torch.argmax(p, -1), self.K).float()
711
- mu = alpha * (self.K * e_k - 1)
712
- sigma = (alpha * self.K).sqrt()
713
- theta = (mu + sigma * torch.randn_like(mu)).exp() * theta
714
- theta = theta / theta.sum(-1, True)
715
- t_final = torch.ones((batch_size, 1, 1), device=self.beta.device)
716
- p = self.discrete_output_distribution(theta, t_final, y, guidance_strength)
717
- entropy = -(p * p.log()).sum(-1).mean(-1)
718
- if token_mask is not None:
719
- p = p.masked_fill_(token_mask, 0.0)
720
- return torch.argmax(p, -1), entropy
781
+ return self._process(theta, None, y, sample_step, guidance_strength, token_mask)
721
782
 
722
783
  @torch.jit.export
723
784
  def ode_sample(
@@ -735,11 +796,11 @@ class ChemBFN(nn.Module):
735
796
 
736
797
  :param batch_size: batch size
737
798
  :param sequence_size: max sequence length
738
- :param y: conditioning vector; shape: (n_b, 1, n_f) or (n_b, n_f)
799
+ :param y: conditioning vector; shape: (n_b, 1, n_f) or (n_b, n_f)
739
800
  :param sample_step: number of sampling steps
740
801
  :param guidance_strength: strength of conditional generation. It is not used if y is null.
741
802
  :param token_mask: token mask assigning unwanted token(s) with `True`;
742
- shape: (1, 1, n_vocab)
803
+ shape: (1, 1, n_vocab)
743
804
  :param temperature: sampling temperature
744
805
  :type batch_size: int
745
806
  :type sequence_size: int
@@ -748,29 +809,14 @@ class ChemBFN(nn.Module):
748
809
  :type guidance_strength: float
749
810
  :type token_mask: torch.Tensor | None
750
811
  :type temperature: float
751
- :return: sampled token indices; shape: (n_b, n_t) \n
752
- entropy of the tokens; shape: (n_b)
812
+ :return: sampled token indices; shape: (n_b, n_t) \n
813
+ entropy of the tokens; shape: (n_b)
753
814
  :rtype: tuple
754
815
  """
755
816
  z = torch.zeros((batch_size, sequence_size, self.K), device=self.beta.device)
756
- if y is not None:
757
- y = self.reshape_y(y)
758
- for i in torch.linspace(1, sample_step, sample_step, device=self.beta.device):
759
- t = (i - 1).view(1, 1, 1).repeat(batch_size, 1, 1) / sample_step
760
- theta = torch.softmax(z, -1)
761
- beta = self.calc_beta(t + 1 / sample_step)
762
- p = self.discrete_output_distribution(theta, t, y, guidance_strength)
763
- if token_mask is not None:
764
- p = p.masked_fill_(token_mask, 0.0)
765
- u = torch.randn_like(z)
766
- z = (self.K * p - 1) * beta + (self.K * beta * temperature).sqrt() * u
767
- t_final = torch.ones((batch_size, 1, 1), device=self.beta.device)
768
- theta = torch.softmax(z, -1)
769
- p = self.discrete_output_distribution(theta, t_final, y, guidance_strength)
770
- entropy = -(p * p.log()).sum(-1).mean(-1)
771
- if token_mask is not None:
772
- p = p.masked_fill_(token_mask, 0.0)
773
- return torch.argmax(p, -1), entropy
817
+ return self._ode_process(
818
+ z, None, y, sample_step, guidance_strength, token_mask, temperature
819
+ )
774
820
 
775
821
  @torch.jit.export
776
822
  def inpaint(
@@ -800,30 +846,12 @@ class ChemBFN(nn.Module):
800
846
  :rtype: tuple
801
847
  """
802
848
  n_b, n_t = x.shape
803
- mask = (x != 0).float()[..., None]
849
+ x_mask = (x != 0).float()[..., None]
804
850
  theta = torch.ones((n_b, n_t, self.K), device=x.device) / self.K
805
- x_onehot = nn.functional.one_hot(x, self.K) * mask
806
- theta = x_onehot + (1 - mask) * theta
807
- if y is not None:
808
- y = self.reshape_y(y)
809
- for i in torch.linspace(1, sample_step, sample_step, device=x.device):
810
- t = (i - 1).view(1, 1, 1).repeat(n_b, 1, 1) / sample_step
811
- p = self.discrete_output_distribution(theta, t, y, guidance_strength)
812
- if token_mask is not None:
813
- p = p.masked_fill_(token_mask, 0.0)
814
- alpha = self.calc_discrete_alpha(t, t + 1 / sample_step)
815
- e_k = nn.functional.one_hot(torch.argmax(p, -1), self.K).float()
816
- mu = alpha * (self.K * e_k - 1)
817
- sigma = (alpha * self.K).sqrt()
818
- theta = (mu + sigma * torch.randn_like(mu)).exp() * theta
819
- theta = theta / theta.sum(-1, True)
820
- theta = x_onehot + (1 - mask) * theta
821
- t_final = torch.ones((n_b, 1, 1), device=x.device)
822
- p = self.discrete_output_distribution(theta, t_final, y, guidance_strength)
823
- entropy = -(p * p.log()).sum(-1).mean(-1)
824
- if token_mask is not None:
825
- p = p.masked_fill_(token_mask, 0.0)
826
- return torch.argmax(p, -1), entropy
851
+ x_onehot = nn.functional.one_hot(x, self.K) * x_mask
852
+ theta = x_onehot + (1 - x_mask) * theta
853
+ mask = (x_onehot, x_mask)
854
+ return self._process(theta, mask, y, sample_step, guidance_strength, token_mask)
827
855
 
828
856
  @torch.jit.export
829
857
  def ode_inpaint(
@@ -856,29 +884,13 @@ class ChemBFN(nn.Module):
856
884
  :rtype: tuple
857
885
  """
858
886
  n_b, n_t = x.shape
859
- mask = (x != 0).float()[..., None]
860
- x_onehot = nn.functional.one_hot(x, self.K) * mask
887
+ x_mask = (x != 0).float()[..., None]
888
+ x_onehot = nn.functional.one_hot(x, self.K) * x_mask
861
889
  z = torch.zeros((n_b, n_t, self.K), device=self.beta.device)
862
- if y is not None:
863
- y = self.reshape_y(y)
864
- for i in torch.linspace(1, sample_step, sample_step, device=self.beta.device):
865
- t = (i - 1).view(1, 1, 1).repeat(n_b, 1, 1) / sample_step
866
- theta = torch.softmax(z, -1)
867
- theta = x_onehot + (1 - mask) * theta
868
- beta = self.calc_beta(t + 1 / sample_step)
869
- p = self.discrete_output_distribution(theta, t, y, guidance_strength)
870
- if token_mask is not None:
871
- p = p.masked_fill_(token_mask, 0.0)
872
- u = torch.randn_like(z)
873
- z = (self.K * p - 1) * beta + (self.K * beta * temperature).sqrt() * u
874
- t_final = torch.ones((n_b, 1, 1), device=self.beta.device)
875
- theta = torch.softmax(z, -1)
876
- theta = x_onehot + (1 - mask) * theta
877
- p = self.discrete_output_distribution(theta, t_final, y, guidance_strength)
878
- entropy = -(p * p.log()).sum(-1).mean(-1)
879
- if token_mask is not None:
880
- p = p.masked_fill_(token_mask, 0.0)
881
- return torch.argmax(p, -1), entropy
890
+ mask = (x_onehot, x_mask)
891
+ return self._ode_process(
892
+ z, mask, y, sample_step, guidance_strength, token_mask, temperature
893
+ )
882
894
 
883
895
  @torch.jit.export
884
896
  def optimise(
@@ -908,28 +920,9 @@ class ChemBFN(nn.Module):
908
920
  entropy of the tokens; shape: (n_b)
909
921
  :rtype: tuple
910
922
  """
911
- n_b = x.shape[0]
912
923
  x_onehot = nn.functional.one_hot(x, self.K).float()
913
- theta = nn.functional.softmax(x_onehot, -1)
914
- if y is not None:
915
- y = self.reshape_y(y)
916
- for i in torch.linspace(1, sample_step, sample_step, device=x.device):
917
- t = (i - 1).view(1, 1, 1).repeat(n_b, 1, 1) / sample_step
918
- p = self.discrete_output_distribution(theta, t, y, guidance_strength)
919
- if token_mask is not None:
920
- p = p.masked_fill_(token_mask, 0.0)
921
- alpha = self.calc_discrete_alpha(t, t + 1 / sample_step)
922
- e_k = nn.functional.one_hot(torch.argmax(p, -1), self.K).float()
923
- mu = alpha * (self.K * e_k - 1)
924
- sigma = (alpha * self.K).sqrt()
925
- theta = (mu + sigma * torch.randn_like(mu)).exp() * theta
926
- theta = theta / theta.sum(-1, True)
927
- t_final = torch.ones((n_b, 1, 1), device=x.device)
928
- p = self.discrete_output_distribution(theta, t_final, y, guidance_strength)
929
- entropy = -(p * p.log()).sum(-1).mean(-1)
930
- if token_mask is not None:
931
- p = p.masked_fill_(token_mask, 0.0)
932
- return torch.argmax(p, -1), entropy
924
+ theta = softmax(x_onehot, -1)
925
+ return self._process(theta, None, y, sample_step, guidance_strength, token_mask)
933
926
 
934
927
  @torch.jit.export
935
928
  def ode_optimise(
@@ -961,26 +954,10 @@ class ChemBFN(nn.Module):
961
954
  entropy of the tokens; shape: (n_b)
962
955
  :rtype: tuple
963
956
  """
964
- n_b = x.shape[0]
965
957
  z = nn.functional.one_hot(x, self.K).float()
966
- if y is not None:
967
- y = self.reshape_y(y)
968
- for i in torch.linspace(1, sample_step, sample_step, device=self.beta.device):
969
- t = (i - 1).view(1, 1, 1).repeat(n_b, 1, 1) / sample_step
970
- theta = torch.softmax(z, -1)
971
- beta = self.calc_beta(t + 1 / sample_step)
972
- p = self.discrete_output_distribution(theta, t, y, guidance_strength)
973
- if token_mask is not None:
974
- p = p.masked_fill_(token_mask, 0.0)
975
- u = torch.randn_like(z)
976
- z = (self.K * p - 1) * beta + (self.K * beta * temperature).sqrt() * u
977
- t_final = torch.ones((n_b, 1, 1), device=self.beta.device)
978
- theta = torch.softmax(z, -1)
979
- p = self.discrete_output_distribution(theta, t_final, y, guidance_strength)
980
- entropy = -(p * p.log()).sum(-1).mean(-1)
981
- if token_mask is not None:
982
- p = p.masked_fill_(token_mask, 0.0)
983
- return torch.argmax(p, -1), entropy
958
+ return self._ode_process(
959
+ z, None, y, sample_step, guidance_strength, token_mask, temperature
960
+ )
984
961
 
985
962
  def inference(
986
963
  self, x: Tensor, mlp: MLP, embed_fn: Optional[Callable[[Tensor], Tensor]] = None
@@ -1154,22 +1131,6 @@ class EnsembleChemBFN(ChemBFN):
1154
1131
  module.lora_dropout = None
1155
1132
  v.lora_enabled = False
1156
1133
 
1157
- def construct_y(
1158
- self, c: Union[List[Tensor], Dict[str, Tensor]]
1159
- ) -> Dict[str, Tensor]:
1160
- assert (
1161
- isinstance(c, dict) is self._label_is_dict
1162
- ), f"`c` should be a {'`dict` instance' if self._label_is_dict else '`list` instance'} but got {type(c)} instand."
1163
- out: Dict[str, Tensor] = {}
1164
- if isinstance(c, list):
1165
- c = dict(zip([f"val_{i}" for i in range(len(c))], c))
1166
- for name, model in self.cond_heads.items():
1167
- y = model.forward(c[name])
1168
- if y.dim() == 2:
1169
- y = y[:, None, :]
1170
- out[name] = y
1171
- return out
1172
-
1173
1134
  def discrete_output_distribution(
1174
1135
  self, theta: Tensor, t: Tensor, y: Dict[str, Tensor], w: float
1175
1136
  ) -> Tensor:
@@ -1204,8 +1165,24 @@ class EnsembleChemBFN(ChemBFN):
1204
1165
  p_cond += p_cond_ * self.adapter_weights[name]
1205
1166
  return softmax((1 + w) * p_cond - w * p_uncond, -1)
1206
1167
 
1168
+ def _map_to_dict(
1169
+ self, c: Union[List[Tensor], Dict[str, Tensor]]
1170
+ ) -> Dict[str, Tensor]:
1171
+ assert (
1172
+ isinstance(c, dict) is self._label_is_dict
1173
+ ), f"`c` should be a {'`dict` instance' if self._label_is_dict else '`list` instance'} but got {type(c)} instand."
1174
+ out: Dict[str, Tensor] = {}
1175
+ if isinstance(c, list):
1176
+ c = dict(zip([f"val_{i}" for i in range(len(c))], c))
1177
+ for name, model in self.cond_heads.items():
1178
+ y = model.forward(c[name])
1179
+ if y.dim() == 2:
1180
+ y = y[:, None, :]
1181
+ out[name] = y
1182
+ return out
1183
+
1207
1184
  @staticmethod
1208
- def reshape_y(y: Dict[str, Tensor]) -> Dict[str, Tensor]:
1185
+ def _reshape(y: Dict[str, Tensor]) -> Dict[str, Tensor]:
1209
1186
  for k in y:
1210
1187
  assert y[k].dim() <= 3
1211
1188
  if y[k].dim() == 2:
@@ -1241,7 +1218,7 @@ class EnsembleChemBFN(ChemBFN):
1241
1218
  entropy of the tokens; shape: (n_b)
1242
1219
  :rtype: tuple
1243
1220
  """
1244
- y = self.construct_y(conditions)
1221
+ y = self._map_to_dict(conditions)
1245
1222
  return super().sample(
1246
1223
  batch_size, sequence_size, y, sample_step, guidance_strength, token_mask
1247
1224
  )
@@ -1278,7 +1255,7 @@ class EnsembleChemBFN(ChemBFN):
1278
1255
  entropy of the tokens; shape: (n_b)
1279
1256
  :rtype: tuple
1280
1257
  """
1281
- y = self.construct_y(conditions)
1258
+ y = self._map_to_dict(conditions)
1282
1259
  return super().ode_sample(
1283
1260
  batch_size,
1284
1261
  sequence_size,
@@ -1315,7 +1292,7 @@ class EnsembleChemBFN(ChemBFN):
1315
1292
  entropy of the tokens; shape: (n_b)
1316
1293
  :rtype: tuple
1317
1294
  """
1318
- y = self.construct_y(conditions)
1295
+ y = self._map_to_dict(conditions)
1319
1296
  return super().inpaint(x, y, sample_step, guidance_strength, token_mask)
1320
1297
 
1321
1298
  @torch.inference_mode()
@@ -1347,7 +1324,7 @@ class EnsembleChemBFN(ChemBFN):
1347
1324
  entropy of the tokens; shape: (n_b)
1348
1325
  :rtype: tuple
1349
1326
  """
1350
- y = self.construct_y(conditions)
1327
+ y = self._map_to_dict(conditions)
1351
1328
  return super().ode_inpaint(
1352
1329
  x, y, sample_step, guidance_strength, token_mask, temperature
1353
1330
  )
@@ -1380,7 +1357,7 @@ class EnsembleChemBFN(ChemBFN):
1380
1357
  entropy of the tokens; shape: (n_b)
1381
1358
  :rtype: tuple
1382
1359
  """
1383
- y = self.construct_y(conditions)
1360
+ y = self._map_to_dict(conditions)
1384
1361
  return super().optimise(x, y, sample_step, guidance_strength, token_mask)
1385
1362
 
1386
1363
  @torch.inference_mode()
@@ -1412,7 +1389,7 @@ class EnsembleChemBFN(ChemBFN):
1412
1389
  entropy of the tokens; shape: (n_b)
1413
1390
  :rtype: tuple
1414
1391
  """
1415
- y = self.construct_y(conditions)
1392
+ y = self._map_to_dict(conditions)
1416
1393
  return super().ode_optimise(
1417
1394
  x, y, sample_step, guidance_strength, token_mask, temperature
1418
1395
  )
@@ -7,9 +7,8 @@ import csv
7
7
  import random
8
8
  import warnings
9
9
  from pathlib import Path
10
- from typing import List, Dict, Tuple, Union, Optional
10
+ from typing import List, Dict, Tuple, Union, Optional, Literal
11
11
  import torch
12
- import colorama
13
12
  import numpy as np
14
13
  from torch import cuda, Tensor, softmax
15
14
  from torch.utils.data import DataLoader
@@ -24,15 +23,7 @@ from rdkit.Chem import (
24
23
  AddHs,
25
24
  Mol,
26
25
  )
27
- from rdkit.Chem.Scaffolds.MurckoScaffold import MurckoScaffoldSmiles # type: ignore
28
- from sklearn.metrics import (
29
- roc_auc_score,
30
- auc,
31
- precision_recall_curve,
32
- r2_score,
33
- mean_absolute_error,
34
- root_mean_squared_error,
35
- )
26
+ from rdkit.Chem.Scaffolds.MurckoScaffold import MurckoScaffoldSmiles
36
27
  from .data import VOCAB_KEYS
37
28
  from .model import ChemBFN, MLP, EnsembleChemBFN
38
29
 
@@ -45,12 +36,74 @@ def _find_device() -> torch.device:
45
36
  return torch.device("cpu")
46
37
 
47
38
 
39
+ def _parse_and_assert_param(
40
+ model: Union[ChemBFN, EnsembleChemBFN],
41
+ y: Optional[Union[Tensor, Dict[str, Tensor], List[Tensor]]],
42
+ method: str,
43
+ ) -> Optional[float]:
44
+ assert method.split(":")[0].lower() in ("ode", "bfn")
45
+ if isinstance(model, EnsembleChemBFN):
46
+ assert y is not None, "conditioning is required while using an ensemble model."
47
+ assert isinstance(y, list) or isinstance(y, dict)
48
+ else:
49
+ assert isinstance(y, Tensor) or (y is None)
50
+ if "ode" in method.lower():
51
+ tp = float(method.split(":")[-1])
52
+ assert tp > 0, "Sampling temperature should be higher than 0."
53
+ return tp
54
+ return None
55
+
56
+
57
+ def _map_to_device(
58
+ y: Optional[Union[Tensor, Dict[str, Tensor], List[Tensor]]],
59
+ device: Union[str, torch.device],
60
+ ) -> Optional[Union[Tensor, Dict[str, Tensor], List[Tensor]]]:
61
+ if y is not None:
62
+ if isinstance(y, Tensor):
63
+ y = y.to(device)
64
+ elif isinstance(y, list):
65
+ y = [i.to(device) for i in y]
66
+ elif isinstance(y, dict):
67
+ y = {k: v.to(device) for k, v in y.items()}
68
+ else:
69
+ raise NotImplementedError
70
+ return y
71
+
72
+
73
+ def _build_token_mask(
74
+ allowed_tokens: Union[str, List[str]],
75
+ vocab_keys: List[str],
76
+ device: Union[str, torch.tensor],
77
+ ) -> Optional[Tensor]:
78
+ if isinstance(allowed_tokens, list):
79
+ token_mask = [0 if i in allowed_tokens else 1 for i in vocab_keys]
80
+ token_mask = torch.tensor([[token_mask]], dtype=torch.bool).to(device)
81
+ else:
82
+ token_mask = None
83
+ return token_mask
84
+
85
+
86
+ def _token_to_seq(
87
+ tokens: Tensor, entropy: Tensor, vocab_keys: List[str], separator: str, sort: bool
88
+ ) -> List[str]:
89
+ if sort:
90
+ sorted_idx = entropy.argsort(stable=True)
91
+ tokens = tokens[sorted_idx]
92
+ return [
93
+ separator.join([vocab_keys[i] for i in j])
94
+ .split("<start>" + separator)[-1]
95
+ .split(separator + "<end>")[0]
96
+ .replace("<pad>", "")
97
+ for j in tokens
98
+ ]
99
+
100
+
48
101
  @torch.no_grad()
49
102
  def test(
50
103
  model: ChemBFN,
51
104
  mlp: MLP,
52
105
  data: DataLoader,
53
- mode: str = "regression",
106
+ mode: Literal["regression", "classification"] = "regression",
54
107
  device: Union[str, torch.device, None] = None,
55
108
  ) -> Dict[str, float]:
56
109
  """
@@ -86,6 +139,12 @@ def test(
86
139
  predict_y.append(y_hat.detach().to("cpu"))
87
140
  predict_y, label_y = torch.cat(predict_y, 0), torch.cat(label_y, 0).split(1, -1)
88
141
  if mode == "regression":
142
+ from sklearn.metrics import (
143
+ r2_score,
144
+ mean_absolute_error,
145
+ root_mean_squared_error,
146
+ )
147
+
89
148
  predict_y = [
90
149
  predict[label_y[i] != torch.inf]
91
150
  for (i, predict) in enumerate(predict_y.split(1, -1))
@@ -99,6 +158,8 @@ def test(
99
158
  r2 = [r2_score(label, predict) for (label, predict) in y_zipped]
100
159
  return {"MAE": mae, "RMSE": rmse, "R^2": r2}
101
160
  if mode == "classification":
161
+ from sklearn.metrics import roc_auc_score, auc, precision_recall_curve
162
+
102
163
  n_c = len(label_y)
103
164
  predict_y = predict_y.chunk(n_c, -1)
104
165
  y_zipped = list(zip(label_y, predict_y))
@@ -123,7 +184,9 @@ def test(
123
184
 
124
185
 
125
186
  def split_dataset(
126
- file: Union[str, Path], split_ratio: List[int] = [8, 1, 1], method: str = "random"
187
+ file: Union[str, Path],
188
+ split_ratio: List[int] = [8, 1, 1],
189
+ method: Literal["random", "scaffold"] = "random",
127
190
  ) -> None:
128
191
  """
129
192
  Split a dataset.
@@ -142,7 +205,6 @@ def split_dataset(
142
205
  assert file.endswith(".csv")
143
206
  assert len(split_ratio) == 3
144
207
  assert method in ("random", "scaffold")
145
- colorama.just_fix_windows_console()
146
208
  with open(file, "r") as f:
147
209
  data = list(csv.reader(f))
148
210
  header = data[0]
@@ -167,10 +229,8 @@ def split_dataset(
167
229
  # compute Bemis-Murcko scaffold
168
230
  if len(smiles_idx) > 1:
169
231
  warnings.warn(
170
- "\033[32;1m"
171
232
  f"We found {len(smiles_idx)} SMILES strings in a row!"
172
- " Only the first SMILES will be used to compute the molecular scaffold."
173
- "\033[0m",
233
+ " Only the first SMILES will be used to compute the molecular scaffold.",
174
234
  stacklevel=2,
175
235
  )
176
236
  try:
@@ -197,10 +257,10 @@ def split_dataset(
197
257
  with open(file.replace(".csv", "_test.csv"), "w", newline="") as fte:
198
258
  writer = csv.writer(fte)
199
259
  writer.writerows([header] + test_set)
200
- with open(file.replace(".csv", "_val.csv"), "w", newline="") as fva:
201
- writer = csv.writer(fva)
202
- writer.writerows([header] + val_set)
203
- colorama.deinit()
260
+ if val_set:
261
+ with open(file.replace(".csv", "_val.csv"), "w", newline="") as fva:
262
+ writer = csv.writer(fva)
263
+ writer.writerows([header] + val_set)
204
264
 
205
265
 
206
266
  @torch.no_grad()
@@ -250,32 +310,12 @@ def sample(
250
310
  :return: a list of generated molecular strings
251
311
  :rtype: list
252
312
  """
253
- assert method.split(":")[0].lower() in ("ode", "bfn")
254
- if isinstance(model, EnsembleChemBFN):
255
- assert y is not None, "conditioning is required while using an ensemble model."
256
- assert isinstance(y, list) or isinstance(y, dict)
257
- else:
258
- assert isinstance(y, Tensor) or y is None
259
- if device is None:
260
- device = _find_device()
313
+ tp = _parse_and_assert_param(model, y, method)
314
+ device = _find_device() if device is None else device
261
315
  model.to(device).eval()
262
- if y is not None:
263
- if isinstance(y, Tensor):
264
- y = y.to(device)
265
- elif isinstance(y, list):
266
- y = [i.to(device) for i in y]
267
- elif isinstance(y, dict):
268
- y = {k: v.to(device) for k, v in y.items()}
269
- else:
270
- raise NotImplementedError
271
- if isinstance(allowed_tokens, list):
272
- token_mask = [0 if i in allowed_tokens else 1 for i in vocab_keys]
273
- token_mask = torch.tensor([[token_mask]], dtype=torch.bool).to(device)
274
- else:
275
- token_mask = None
276
- if "ode" in method.lower():
277
- tp = float(method.split(":")[-1])
278
- assert tp > 0, "Sampling temperature should be higher than 0."
316
+ y = _map_to_device(y, device)
317
+ token_mask = _build_token_mask(allowed_tokens, vocab_keys, device)
318
+ if tp:
279
319
  tokens, entropy = model.ode_sample(
280
320
  batch_size, sequence_size, y, sample_step, guidance_strength, token_mask, tp
281
321
  )
@@ -283,16 +323,7 @@ def sample(
283
323
  tokens, entropy = model.sample(
284
324
  batch_size, sequence_size, y, sample_step, guidance_strength, token_mask
285
325
  )
286
- if sort:
287
- sorted_idx = entropy.argsort(stable=True)
288
- tokens = tokens[sorted_idx]
289
- return [
290
- seperator.join([vocab_keys[i] for i in j])
291
- .split("<start>" + seperator)[-1]
292
- .split(seperator + "<end>")[0]
293
- .replace("<pad>", "")
294
- for j in tokens
295
- ]
326
+ return _token_to_seq(tokens, entropy, vocab_keys, seperator, sort)
296
327
 
297
328
 
298
329
  @torch.no_grad()
@@ -339,33 +370,13 @@ def inpaint(
339
370
  :return: a list of generated molecular strings
340
371
  :rtype: list
341
372
  """
342
- assert method.split(":")[0].lower() in ("ode", "bfn")
343
- if isinstance(model, EnsembleChemBFN):
344
- assert y is not None, "conditioning is required while using an ensemble model."
345
- assert isinstance(y, list) or isinstance(y, dict)
346
- else:
347
- assert isinstance(y, Tensor) or y is None
348
- if device is None:
349
- device = _find_device()
373
+ tp = _parse_and_assert_param(model, y, method)
374
+ device = _find_device() if device is None else device
350
375
  model.to(device).eval()
351
376
  x = x.to(device)
352
- if y is not None:
353
- if isinstance(y, Tensor):
354
- y = y.to(device)
355
- elif isinstance(y, list):
356
- y = [i.to(device) for i in y]
357
- elif isinstance(y, dict):
358
- y = {k: v.to(device) for k, v in y.items()}
359
- else:
360
- raise NotImplementedError
361
- if isinstance(allowed_tokens, list):
362
- token_mask = [0 if i in allowed_tokens else 1 for i in vocab_keys]
363
- token_mask = torch.tensor([[token_mask]], dtype=torch.bool).to(device)
364
- else:
365
- token_mask = None
366
- if "ode" in method.lower():
367
- tp = float(method.split(":")[-1])
368
- assert tp > 0, "Sampling temperature should be higher than 0."
377
+ y = _map_to_device(y, device)
378
+ token_mask = _build_token_mask(allowed_tokens, vocab_keys, device)
379
+ if tp:
369
380
  tokens, entropy = model.ode_inpaint(
370
381
  x, y, sample_step, guidance_strength, token_mask, tp
371
382
  )
@@ -373,16 +384,7 @@ def inpaint(
373
384
  tokens, entropy = model.inpaint(
374
385
  x, y, sample_step, guidance_strength, token_mask
375
386
  )
376
- if sort:
377
- sorted_idx = entropy.argsort(stable=True)
378
- tokens = tokens[sorted_idx]
379
- return [
380
- separator.join([vocab_keys[i] for i in j])
381
- .split("<start>" + separator)[-1]
382
- .split(separator + "<end>")[0]
383
- .replace("<pad>", "")
384
- for j in tokens
385
- ]
387
+ return _token_to_seq(tokens, entropy, vocab_keys, separator, sort)
386
388
 
387
389
 
388
390
  @torch.no_grad()
@@ -429,33 +431,13 @@ def optimise(
429
431
  :return: a list of generated molecular strings
430
432
  :rtype: list
431
433
  """
432
- assert method.split(":")[0].lower() in ("ode", "bfn")
433
- if isinstance(model, EnsembleChemBFN):
434
- assert y is not None, "conditioning is required while using an ensemble model."
435
- assert isinstance(y, list) or isinstance(y, dict)
436
- else:
437
- assert isinstance(y, Tensor) or y is None
438
- if device is None:
439
- device = _find_device()
434
+ tp = _parse_and_assert_param(model, y, method)
435
+ device = _find_device() if device is None else device
440
436
  model.to(device).eval()
441
437
  x = x.to(device)
442
- if y is not None:
443
- if isinstance(y, Tensor):
444
- y = y.to(device)
445
- elif isinstance(y, list):
446
- y = [i.to(device) for i in y]
447
- elif isinstance(y, dict):
448
- y = {k: v.to(device) for k, v in y.items()}
449
- else:
450
- raise NotImplementedError
451
- if isinstance(allowed_tokens, list):
452
- token_mask = [0 if i in allowed_tokens else 1 for i in vocab_keys]
453
- token_mask = torch.tensor([[token_mask]], dtype=torch.bool).to(device)
454
- else:
455
- token_mask = None
456
- if "ode" in method.lower():
457
- tp = float(method.split(":")[-1])
458
- assert tp > 0, "Sampling temperature should be higher than 0."
438
+ y = _map_to_device(y, device)
439
+ token_mask = _build_token_mask(allowed_tokens, vocab_keys, device)
440
+ if tp:
459
441
  tokens, entropy = model.ode_optimise(
460
442
  x, y, sample_step, guidance_strength, token_mask, tp
461
443
  )
@@ -463,16 +445,7 @@ def optimise(
463
445
  tokens, entropy = model.optimise(
464
446
  x, y, sample_step, guidance_strength, token_mask
465
447
  )
466
- if sort:
467
- sorted_idx = entropy.argsort(stable=True)
468
- tokens = tokens[sorted_idx]
469
- return [
470
- separator.join([vocab_keys[i] for i in j])
471
- .split("<start>" + separator)[-1]
472
- .split(separator + "<end>")[0]
473
- .replace("<pad>", "")
474
- for j in tokens
475
- ]
448
+ return _token_to_seq(tokens, entropy, vocab_keys, separator, sort)
476
449
 
477
450
 
478
451
  def quantise_model_(model: ChemBFN) -> None:
@@ -555,7 +528,7 @@ class GeometryConverter:
555
528
  def smiles2cartesian(
556
529
  smiles: str,
557
530
  num_conformers: int = 250,
558
- rdkit_ff_type: str = "MMFF",
531
+ rdkit_ff_type: Literal["MMFF", "UFF"] = "MMFF",
559
532
  refine_with_crest: bool = False,
560
533
  spin: float = 0.0,
561
534
  ) -> Tuple[List[str], np.ndarray]:
@@ -8,7 +8,6 @@ from typing import Dict, Tuple, Union, Optional
8
8
  import torch
9
9
  import torch.optim as op
10
10
  import torch.nn.functional as F
11
- from loralib import lora_state_dict, mark_only_lora_as_trainable
12
11
  from torch import Tensor
13
12
  from torch.optim.lr_scheduler import ReduceLROnPlateau
14
13
  from lightning import LightningModule
@@ -55,6 +54,8 @@ class Model(LightningModule):
55
54
  self.scorer = scorer
56
55
  self.save_hyperparameters(hparam, ignore=["model", "mlp", "scorer"])
57
56
  if model.lora_enabled:
57
+ from loralib import mark_only_lora_as_trainable
58
+
58
59
  mark_only_lora_as_trainable(self.model)
59
60
  self.use_scorer = self.scorer is not None
60
61
 
@@ -107,6 +108,8 @@ class Model(LightningModule):
107
108
  :rtype: None
108
109
  """
109
110
  if self.model.lora_enabled:
111
+ from loralib import lora_state_dict
112
+
110
113
  torch.save(
111
114
  {
112
115
  "lora_nn": lora_state_dict(self.model),
@@ -152,6 +155,8 @@ class Regressor(LightningModule):
152
155
  self.model.requires_grad_(not hparam["freeze"])
153
156
  self.save_hyperparameters(hparam, ignore=["model", "mlp"])
154
157
  if model.lora_enabled:
158
+ from loralib import mark_only_lora_as_trainable
159
+
155
160
  mark_only_lora_as_trainable(self.model)
156
161
  assert hparam["mode"] in ("regression", "classification")
157
162
 
@@ -231,6 +236,8 @@ class Regressor(LightningModule):
231
236
  )
232
237
  if not self.hparams.freeze:
233
238
  if self.model.lora_enabled:
239
+ from loralib import lora_state_dict
240
+
234
241
  torch.save(
235
242
  {
236
243
  "lora_nn": lora_state_dict(self.model),
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: bayesianflow_for_chem
3
- Version: 2.1.0
3
+ Version: 2.2.2
4
4
  Summary: Bayesian flow network framework for Chemistry
5
5
  Home-page: https://augus1999.github.io/bayesian-flow-network-for-chemistry/
6
6
  Author: Nianze A. Tao
@@ -54,12 +54,14 @@ This is the repository of the PyTorch implementation of ChemBFN model.
54
54
 
55
55
  [![PyPI](https://img.shields.io/pypi/v/bayesianflow-for-chem?color=ff69b4)](https://pypi.org/project/bayesianflow-for-chem/)
56
56
  ![pytest](https://github.com/Augus1999/bayesian-flow-network-for-chemistry/actions/workflows/pytest.yml/badge.svg)
57
+ [![document](https://github.com/Augus1999/bayesian-flow-network-for-chemistry/actions/workflows/pages/pages-build-deployment/badge.svg)](https://augus1999.github.io/bayesian-flow-network-for-chemistry/)
57
58
 
58
59
  ## Features
59
60
 
60
61
  ChemBFN provides the state-of-the-art functionalities of
61
62
  * SMILES or SELFIES-based *de novo* molecule generation
62
63
  * Protein sequence *de novo* generation
64
+ * Template optimisation (mol2mol)
63
65
  * Classifier-free guidance conditional generation (single or multi-objective optimisation)
64
66
  * Context-guided conditional generation (inpaint)
65
67
  * Outstanding out-of-distribution chemical space sampling
@@ -71,6 +73,7 @@ in an all-in-one-model style.
71
73
 
72
74
  ## News
73
75
 
76
+ * [09/10/2025] A web app [`chembfn_webui`](https://github.com/Augus1999/ChemBFN-WebUI) for hosting ChemBFN models is available on [PyPI](https://pypi.org/project/chembfn-webui/).
74
77
  * [30/01/2025] The package `bayesianflow_for_chem` is available on [PyPI](https://pypi.org/project/bayesianflow-for-chem/).
75
78
  * [21/01/2025] Our first paper has been accepted by [JCIM](https://pubs.acs.org/doi/10.1021/acs.jcim.4c01792).
76
79
  * [17/12/2024] The second paper of out-of-distribution generation is available on [arxiv.org](https://arxiv.org/abs/2412.11439).
@@ -93,7 +96,7 @@ You can find pretrained models on our [🤗Hugging Face model page](https://hugg
93
96
 
94
97
  ## Dataset Handling
95
98
 
96
- We provide a Python class [`CSVData`](https://github.com/Augus1999/bayesian-flow-network-for-chemistry/blob/main/bayesianflow_for_chem/data.py#L152) to handle data stored in CSV or similar format containing headers to identify the entities. The following is a quickstart.
99
+ We provide a Python class [`CSVData`](https://github.com/Augus1999/bayesian-flow-network-for-chemistry/blob/main/bayesianflow_for_chem/data.py#L153) to handle data stored in CSV or similar format containing headers to identify the entities. The following is a quickstart.
97
100
 
98
101
  1. Download your dataset file (e.g., ESOL from [MoleculeNet](https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/delaney-processed.csv)) and split the file:
99
102
  ```python
@@ -0,0 +1,15 @@
1
+ bayesianflow_for_chem/__init__.py,sha256=5nyapMRF-NM610Y_FzZQxu97LQvFaQDxwsyYFBJFxdw,642
2
+ bayesianflow_for_chem/cli.py,sha256=wSDQ5EpETB0-o_YeSIuFt4hP1gI4if566a3qehspgB0,27353
3
+ bayesianflow_for_chem/data.py,sha256=jOzcOO5FDNju8hnaimT_WI8sjdaiOHDalDIOEOpLjEE,6643
4
+ bayesianflow_for_chem/model.py,sha256=M35G4u4mX4btl9vOK3Iqs6yOSuIKI_OoCTmLhmjbwNk,57559
5
+ bayesianflow_for_chem/scorer.py,sha256=gQFUlkyxitch02ntqcRh1ZS8aondKLynW5U6NfTQTb4,4084
6
+ bayesianflow_for_chem/spectra.py,sha256=Ba9ib1aDvTtDYbH3b4d-lIty3ZSQMu7jwehuV2KmhwA,1781
7
+ bayesianflow_for_chem/tool.py,sha256=pAEGfYzEiquu9cTM0Te8EAAr2RPRRObGCzLk9uXaw8o,23686
8
+ bayesianflow_for_chem/train.py,sha256=7AU0A-eZwzSYsLyIe3OxGTNWPnhGpHmVUaQLplV2Fn8,9886
9
+ bayesianflow_for_chem/_data/vocab.txt,sha256=HgtAZmpWYk4y8PqEVC4vqut1vE75DfRKE_10s2UW0rU,790
10
+ bayesianflow_for_chem-2.2.2.dist-info/licenses/LICENSE,sha256=hIahDEOTzuHCU5J2nd07LWwkLW7Hko4UFO__ffsvB-8,34523
11
+ bayesianflow_for_chem-2.2.2.dist-info/METADATA,sha256=nh6i_LRZTBSoJr3KP3iD0Q-CgrveQSj8AHrdg75FsU4,6476
12
+ bayesianflow_for_chem-2.2.2.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
13
+ bayesianflow_for_chem-2.2.2.dist-info/entry_points.txt,sha256=N63RMoJsr8rxuKxc7Fj802SL8J5AlpCoPkS8E3IFPLI,54
14
+ bayesianflow_for_chem-2.2.2.dist-info/top_level.txt,sha256=KHsanI3BMCt8D9Qpze2ycrF6nMa3PyojgO6eS1c8kco,22
15
+ bayesianflow_for_chem-2.2.2.dist-info/RECORD,,
@@ -1,15 +0,0 @@
1
- bayesianflow_for_chem/__init__.py,sha256=bmqERRnnmyK7UgUYn3BFH_YipKfTxNNjRzu9__RVid4,612
2
- bayesianflow_for_chem/cli.py,sha256=A1cFz6jhpLEpbK9r8GxCLdbnPCzQ4RrsavLKg_lssVg,24208
3
- bayesianflow_for_chem/data.py,sha256=Pl0gGWHmMKTKHpsxznvLgYPCwwlLNL7nqH19Vipjkxs,6584
4
- bayesianflow_for_chem/model.py,sha256=UW5hfAofYK9dH9euDPYWfJedVMRFxk8WtY427fObf70,59641
5
- bayesianflow_for_chem/scorer.py,sha256=gQFUlkyxitch02ntqcRh1ZS8aondKLynW5U6NfTQTb4,4084
6
- bayesianflow_for_chem/spectra.py,sha256=Ba9ib1aDvTtDYbH3b4d-lIty3ZSQMu7jwehuV2KmhwA,1781
7
- bayesianflow_for_chem/tool.py,sha256=bqoIMas8bmcjYBghuQWLh75Eq8ZlG6mh9ZeDzWGOmuw,24790
8
- bayesianflow_for_chem/train.py,sha256=jYkhSguW50lrcTEydCQ20yig_mmc1j7WH9KmVwBCTAo,9727
9
- bayesianflow_for_chem/vocab.txt,sha256=HgtAZmpWYk4y8PqEVC4vqut1vE75DfRKE_10s2UW0rU,790
10
- bayesianflow_for_chem-2.1.0.dist-info/licenses/LICENSE,sha256=hIahDEOTzuHCU5J2nd07LWwkLW7Hko4UFO__ffsvB-8,34523
11
- bayesianflow_for_chem-2.1.0.dist-info/METADATA,sha256=ctwor5jnPCmAo-1wuIGkX70aqkXj-8YQxr27AJOdEjM,6057
12
- bayesianflow_for_chem-2.1.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
13
- bayesianflow_for_chem-2.1.0.dist-info/entry_points.txt,sha256=N63RMoJsr8rxuKxc7Fj802SL8J5AlpCoPkS8E3IFPLI,54
14
- bayesianflow_for_chem-2.1.0.dist-info/top_level.txt,sha256=KHsanI3BMCt8D9Qpze2ycrF6nMa3PyojgO6eS1c8kco,22
15
- bayesianflow_for_chem-2.1.0.dist-info/RECORD,,