bayesianflow-for-chem 2.0.5__py3-none-any.whl → 2.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


bayesianflow_for_chem/__init__.py

@@ -3,10 +3,8 @@
 """
 ChemBFN package.
 """
-import colorama
 from . import data, tool, train, scorer, spectra
 from .model import ChemBFN, MLP, EnsembleChemBFN
-from .cli import main_script
 
 __all__ = [
     "data",
@@ -18,7 +16,7 @@ __all__ = [
     "MLP",
     "EnsembleChemBFN",
 ]
-__version__ = "2.0.5"
+__version__ = "2.2.1"
 __author__ = "Nianze A. Tao (Omozawa Sueno)"
 
 
@@ -29,6 +27,9 @@ def main() -> None:
     :return:
     :rtype: None
     """
+    import colorama
+    from bayesianflow_for_chem.cli import main_script
+
    colorama.just_fix_windows_console()
    main_script(__version__)
    colorama.deinit()
bayesianflow_for_chem/cli.py

@@ -12,13 +12,8 @@ from pathlib import Path
 from functools import partial
 from typing import List, Tuple, Dict, Union, Callable
 import torch
-import lightning as L
 from rdkit.Chem import MolFromSmiles, CanonSmiles
-from torch.utils.data import DataLoader
-from lightning.pytorch import loggers
-from lightning.pytorch.callbacks import ModelCheckpoint
 from bayesianflow_for_chem import ChemBFN, MLP
-from bayesianflow_for_chem.train import Model
 from bayesianflow_for_chem.scorer import smiles_valid, Scorer
 from bayesianflow_for_chem.data import (
     VOCAB_COUNT,
@@ -32,7 +27,7 @@ from bayesianflow_for_chem.data import (
     collate,
     CSVData,
 )
-from bayesianflow_for_chem.tool import sample, inpaint
+from bayesianflow_for_chem.tool import sample, inpaint, optimise, adjust_lora_
 
 
 """
@@ -90,6 +85,7 @@ checkpoint_save_path = "home/user/project/ckpt"
 train_strategy = "auto" # or any strategy supported by Lightning, e.g., "ddp"
 accumulate_grad_batches = 1
 enable_progress_bar = false
+plugin_script = "" # define customised behaviours of dataset, datasetloader, etc in a python script
 
 # Remove this table if inference is unnecessary
 [inference]
@@ -99,9 +95,11 @@ sample_size = 1000 # the minimum number of samples you want
 sample_step = 100
 sample_method = "ODE:0.5" # ODE-solver with temperature of 0.5; another choice is "BFN"
 semi_autoregressive = false
+lora_scaling = 1.0 # LoRA scaling if applied
 guidance_objective = [-0.023, 0.09, 0.113] # if no objective is needed set it to empty array []
 guidance_objective_strength = 4.0 # unnecessary if guidance_objective = []
 guidance_scaffold = "c1ccccc1" # if no scaffold is used set it to empty string ""
+sample_template = "" # template for mol2mol task; leave it blank if scaffold is used
 unwanted_token = []
 exclude_invalid = true # to only store valid samples
 exclude_duplicate = true # to only store unique samples
@@ -118,6 +116,32 @@ madmadmadmadmadmadmadmadmadmadmadmadmadmadmad
 madmadmadmadmadmadmadmadmadmadmadmadmadmadmad
 """
 
+_ALLOWED_PLUGINS = [
+    "collate_fn",
+    "num_workers",
+    "max_sequence_length",
+    "shuffle",
+    "CustomData",
+]
+
+
+def _load_plugin(plugin_file: str) -> Dict[str, Union[int, Callable, object, None]]:
+    if not plugin_file:
+        return {n: None for n in _ALLOWED_PLUGINS}
+    from importlib import util as iutil
+
+    spec = iutil.spec_from_file_location(Path(plugin_file).stem, plugin_file)
+    plugins = iutil.module_from_spec(spec)
+    spec.loader.exec_module(plugins)
+    plugin_names: List[str] = plugins.__all__
+    plugin_dict = {}
+    for n in _ALLOWED_PLUGINS:
+        if n in plugin_names:
+            plugin_dict[n] = getattr(plugins, n)
+        else:
+            plugin_dict[n] = None
+    return plugin_dict
+
 
 def parse_cli(version: str) -> argparse.Namespace:
     """
@@ -130,7 +154,7 @@ def parse_cli(version: str) -> argparse.Namespace:
     """
     parser = argparse.ArgumentParser(
         description="Madmol: a CLI molecular design tool for "
-        "de novo design, R-group replacement, and sequence in-filling, "
+        "de novo design, R-group replacement, molecule optimisation, and sequence in-filling, "
         "based on generative route of ChemBFN method. "
         "Let's make some craziest molecules.",
         epilog=f"Madmol {version}, developed in Hiroshima University by chemists for chemists. "
@@ -157,7 +181,7 @@ def parse_cli(version: str) -> argparse.Namespace:
         "-D",
         "--dryrun",
         action="store_true",
-        help="dry-run to check the configurations",
+        help="dry-run to check the configurations and exit",
     )
     parser.add_argument("-V", "--version", action="version", version=version)
     return parser.parse_args()
@@ -265,6 +289,14 @@ def load_runtime_config(
                 f"\033[0;31mCritical\033[0;0m in {config_file}: Restart checkpoint file {ckpt_file} does not exist."
             )
             flag_critical += 1
+        # ↓ added in v2.2.0; need to be compatible with old versions.
+        plugin_script: str = config["train"].get("plugin_script", "")
+        if plugin_script:
+            if not os.path.exists(plugin_script):
+                print(
+                    f"\033[0;31mCritical\033[0;0m in {config_file}: Plugin script {plugin_script} does not exist."
+                )
+                flag_critical += 1
     if "inference" in config:
         if not "train" in config:
             if not isinstance(config["inference"]["sequence_length"], int):
@@ -284,6 +316,14 @@ def load_runtime_config(
                 f"\033[0;33mWarning\033[0;0m in {config_file}: Directory {result_dir} to save the result does not exist."
            )
            flag_warning += 1
+        if (
+            config["inference"]["guidance_scaffold"] != ""
+            and config["inference"]["sample_template"] != ""
+        ):
+            print(
+                f"\033[0;33mWarning\033[0;0m in {config_file}: Inpaint task or mol2mol task?"
+            )
+            flag_warning += 1
     return config, flag_critical, flag_warning
 
 
@@ -325,14 +365,15 @@ def main_script(version: str) -> None:
             )
             flag_warning += 1
         if not os.path.exists(runtime_config["train"]["checkpoint_save_path"]):
-            os.makedirs(runtime_config["train"]["checkpoint_save_path"])
+            if not parser.dryrun:  # only create it in real tasks
+                os.makedirs(runtime_config["train"]["checkpoint_save_path"])
     else:
         if not model_config["ChemBFN"]["base_model"]:
             print(
                 f"\033[0;33mWarning\033[0;0m in {parser.model_config}: You should load a pretrained ChemBFN model."
             )
             flag_warning += 1
-        if not model_config["MLP"]["base_model"]:
+        if "MLP" in model_config and not model_config["MLP"]["base_model"]:
             print(
                 f"\033[0;33mWarning\033[0;0m in {parser.model_config}: You should load a pretrained MLP."
             )
@@ -340,13 +381,17 @@ def main_script(version: str) -> None:
     if "inference" in runtime_config:
         if runtime_config["inference"]["guidance_objective"]:
             if not "MLP" in model_config:
-                print(f"Warning in {parser.model_config}: Oh no, you don't have a MLP.")
+                print(
+                    f"\033[0;33mWarning\033[0;0m in {parser.model_config}: Oh no, you don't have a MLP."
+                )
                 flag_warning += 1
     if parser.dryrun:
         if flag_critical != 0:
             print("Configuration check failed!")
         elif flag_warning != 0:
-            print("Your job will probably run, but it may not follow your expectation.")
+            print(
+                "Your job will probably run, but it may not follow your expectations."
+            )
         else:
             print("Configuration check passed.")
         return
@@ -405,6 +450,15 @@ def main_script(version: str) -> None:
         mlp = None
     # ------- train -------
     if "train" in runtime_config:
+        import lightning as L
+        from torch.utils.data import DataLoader
+        from lightning.pytorch import loggers
+        from lightning.pytorch.callbacks import ModelCheckpoint
+        from bayesianflow_for_chem.train import Model
+
+        # ####### get plugins #######
+        plugin_file = runtime_config["train"].get("plugin_script", "")
+        plugins = _load_plugin(plugin_file)
         # ####### build scorer #######
         if (tokeniser_name == "smiles" or tokeniser_name == "safe") and runtime_config[
             "train"
@@ -418,30 +472,43 @@ def main_script(version: str) -> None:
         mol_tag = runtime_config["train"]["molecule_tag"]
         obj_tag = runtime_config["train"]["objective_tag"]
         dataset_file = runtime_config["train"]["dataset"]
-        with open(dataset_file, "r") as db:
-            _data = db.readlines()
-        header = _data[0]
-        mol_idx = []
-        for i, tag in enumerate(header.replace("\n", "").split(",")):
-            if tag == mol_tag:
-                mol_idx.append(i)
-        _data_len = []
-        for i in _data[1:]:
-            i = i.replace("\n", "").split(",")
-            _mol = ".".join([i[j] for j in mol_idx])
-            _data_len.append(tokeniser(_mol).shape[-1])
-        lmax = max(_data_len)
-        dataset = CSVData(dataset_file)
+        if plugins["max_sequence_length"]:
+            lmax = plugins["max_sequence_length"]
+        else:
+            with open(dataset_file, "r") as db:
+                _data = db.readlines()
+            _header = _data[0]
+            _mol_idx = []
+            for i, tag in enumerate(_header.replace("\n", "").split(",")):
+                if tag == mol_tag:
+                    _mol_idx.append(i)
+            _data_len = []
+            for i in _data[1:]:
+                i = i.replace("\n", "").split(",")
+                _mol = ".".join([i[j] for j in _mol_idx])
+                _data_len.append(tokeniser(_mol).shape[-1])
+            lmax = max(_data_len)
+            del _data, _data_len, _header, _mol_idx  # clear memory
+        if plugins["CustomData"] is not None:
+            dataset = plugins["CustomData"](dataset_file)
+        else:
+            dataset = CSVData(dataset_file)
         dataset.map(
             partial(_encode, mol_tag=mol_tag, obj_tag=obj_tag, tokeniser=tokeniser)
         )
         dataloader = DataLoader(
             dataset,
             runtime_config["train"]["batch_size"],
-            True,
-            num_workers=4,
-            collate_fn=collate,
-            persistent_workers=True,
+            True if plugins["shuffle"] is None else plugins["shuffle"],
+            num_workers=4 if plugins["num_workers"] is None else plugins["num_workers"],
+            collate_fn=(
+                collate if plugins["collate_fn"] is None else plugins["collate_fn"]
+            ),
+            persistent_workers=(
+                True
+                if (plugins["num_workers"] is None or plugins["num_workers"] > 0)
+                else False
+            ),
         )
         # ####### build trainer #######
         logger_name = runtime_config["train"]["logger_name"].lower()
@@ -520,6 +587,8 @@ def main_script(version: str) -> None:
         if "train" in runtime_config:
             bfn = model.model
             mlp = model.mlp
+        # ↓ added in v2.1.0; need to be compatible with old versions
+        lora_scaling = runtime_config["inference"].get("lora_scaling", 1.0)
         # ####### strat inference #######
         bfn.semi_autoregressive = runtime_config["inference"]["semi_autoregressive"]
         _device = (
@@ -550,8 +619,16 @@
                 x[:-1], (0, sequence_length - x.shape[-1] + 1), value=0
             )
            x = x[None, :].repeat(batch_size, 1)
+        # then sample template will be ignored.
+        elif runtime_config["inference"]["sample_template"]:
+            template = runtime_config["inference"]["sample_template"]
+            x = tokeniser(template)
+            x = torch.nn.functional.pad(x, (0, sequence_length - x.shape[-1]), value=0)
+            x = x[None, :].repeat(batch_size, 1)
         else:
             x = None
+        if bfn.lora_enabled:
+            adjust_lora_(bfn, lora_scaling)
         mols = []
         while len(mols) < runtime_config["inference"]["sample_size"]:
             if x is None:
@@ -567,7 +644,7 @@
                     method=sample_method,
                     allowed_tokens=allowed_token,
                 )
-            else:
+            elif runtime_config["inference"]["guidance_scaffold"]:
                 s = inpaint(
                     bfn,
                     x,
@@ -579,6 +656,18 @@
                     method=sample_method,
                     allowed_tokens=allowed_token,
                 )
+            else:
+                s = optimise(
+                    bfn,
+                    x,
+                    sample_step,
+                    y,
+                    guidance_strength,
+                    _device,
+                    vocab_keys,
+                    method=sample_method,
+                    allowed_tokens=allowed_token,
+                )
             if runtime_config["inference"]["exclude_invalid"]:
                 s = [i for i in s if i]
             if tokeniser_name == "smiles" or tokeniser_name == "safe":
bayesianflow_for_chem/data.py

@@ -61,7 +61,7 @@ def load_vocab(
     }
 
 
-_DEFUALT_VOCAB = load_vocab(__filedir__ / "vocab.txt")
+_DEFUALT_VOCAB = load_vocab(__filedir__ / "_data/vocab.txt")
 VOCAB_KEYS: List[str] = _DEFUALT_VOCAB["vocab_keys"]
 VOCAB_DICT: Dict[str, int] = _DEFUALT_VOCAB["vocab_dict"]
 VOCAB_COUNT: int = _DEFUALT_VOCAB["vocab_count"]
bayesianflow_for_chem/model.py

@@ -659,12 +659,91 @@ class ChemBFN(nn.Module):
         return (-logits.gather(-1, x[..., :1]).squeeze(-1)).mean()
 
     @staticmethod
-    def reshape_y(y: Tensor) -> Tensor:
+    def _reshape(y: Tensor) -> Tensor:
         assert y.dim() <= 3  # this doesn't work if the model is frezen in JIT.
         if y.dim() == 2:
             return y[:, None, :]
         return y
 
+    def _process(
+        self,
+        theta: Tensor,
+        mask: Optional[Tuple[Tensor, Tensor]],
+        y: Optional[Tensor],
+        sample_step: int,
+        guidance_strength: float,
+        token_mask: Optional[Tensor],
+    ) -> Tuple[Tensor, Tensor]:
+        # BFN inference process.
+        #
+        # theta: piror distribution; shape: (n_b, n_t, n_vocab)
+        # mask: masked condition distribution; shape: (n_b, n_t, n_vocab)
+        #       condition distribution mask; shape: (n_b, n_t, 1)
+        n_b = theta.shape[0]
+        if y is not None:
+            y = self._reshape(y)
+        for i in torch.linspace(1, sample_step, sample_step, device=self.beta.device):
+            t = (i - 1).view(1, 1, 1).repeat(n_b, 1, 1) / sample_step
+            p = self.discrete_output_distribution(theta, t, y, guidance_strength)
+            if token_mask is not None:
+                p = p.masked_fill_(token_mask, 0.0)
+            alpha = self.calc_discrete_alpha(t, t + 1 / sample_step)
+            e_k = nn.functional.one_hot(torch.argmax(p, -1), self.K).float()
+            mu = alpha * (self.K * e_k - 1)
+            sigma = (alpha * self.K).sqrt()
+            theta = (mu + sigma * torch.randn_like(mu)).exp() * theta
+            theta = theta / theta.sum(-1, True)
+            if mask is not None:
+                x_onehot, x_mask = mask
+                theta = x_onehot + (1 - x_mask) * theta
+        t_final = torch.ones((n_b, 1, 1), device=self.beta.device)
+        p = self.discrete_output_distribution(theta, t_final, y, guidance_strength)
+        entropy = -(p * p.log()).sum(-1).mean(-1)
+        if token_mask is not None:
+            p = p.masked_fill_(token_mask, 0.0)
+        return torch.argmax(p, -1), entropy
+
+    def _ode_process(
+        self,
+        z: Tensor,
+        mask: Optional[Tuple[Tensor, Tensor]],
+        y: Optional[Tensor],
+        sample_step: int,
+        guidance_strength: float,
+        token_mask: Optional[Tensor],
+        temperature: float,
+    ) -> Tuple[Tensor, Tensor]:
+        # ODE-solver engaged inference process.
+        #
+        # z: prior latent vector; shape: (n_b, n_t, n_vocab)
+        # mask: masked condition distribution; shape: (n_b, n_t, n_vocab)
+        #       condition distribution mask; shape: (n_b, n_t, 1)
+        n_b = z.shape[0]
+        if y is not None:
+            y = self._reshape(y)
+        for i in torch.linspace(1, sample_step, sample_step, device=self.beta.device):
+            t = (i - 1).view(1, 1, 1).repeat(n_b, 1, 1) / sample_step
+            theta = softmax(z, -1)
+            if mask is not None:
+                x_onehot, x_mask = mask
+                theta = x_onehot + (1 - x_mask) * theta
+            beta = self.calc_beta(t + 1 / sample_step)
+            p = self.discrete_output_distribution(theta, t, y, guidance_strength)
+            if token_mask is not None:
+                p = p.masked_fill_(token_mask, 0.0)
+            u = torch.randn_like(z)
+            z = (self.K * p - 1) * beta + (self.K * beta * temperature).sqrt() * u
+        t_final = torch.ones((n_b, 1, 1), device=self.beta.device)
+        theta = softmax(z, -1)
+        if mask is not None:
+            x_onehot, x_mask = mask
+            theta = x_onehot + (1 - x_mask) * theta
+        p = self.discrete_output_distribution(theta, t_final, y, guidance_strength)
+        entropy = -(p * p.log()).sum(-1).mean(-1)
+        if token_mask is not None:
+            p = p.masked_fill_(token_mask, 0.0)
+        return torch.argmax(p, -1), entropy
+
     @torch.jit.export
     def sample(
         self,
@@ -676,48 +755,30 @@ class ChemBFN(nn.Module):
         token_mask: Optional[Tensor] = None,
     ) -> Tuple[Tensor, Tensor]:
         """
-        Sample from a piror distribution.
+        Sample from a uniform piror distribution.
 
         :param batch_size: batch size
         :param sequence_size: max sequence length
-        :param y: conditioning vector; shape: (n_b, 1, n_f) or (n_b, n_f)
+        :param y: conditioning vector; shape: (n_b, 1, n_f) or (n_b, n_f)
         :param sample_step: number of sampling steps
         :param guidance_strength: strength of conditional generation. It is not used if y is null.
         :param token_mask: token mask assigning unwanted token(s) with `True`;
-            shape: (1, 1, n_vocab)
+            shape: (1, 1, n_vocab)
         :type batch_size: int
         :type sequence_size: int
         :type y: torch.Tensor | None
         :type sample_step: int
         :type guidance_strength: float
         :type token_mask: torch.Tensor | None
-        :return: sampled token indices; shape: (n_b, n_t) \n
-            entropy of the tokens; shape: (n_b)
+        :return: sampled token indices; shape: (n_b, n_t) \n
+            entropy of the tokens; shape: (n_b)
         :rtype: tuple
         """
         theta = (
             torch.ones((batch_size, sequence_size, self.K), device=self.beta.device)
             / self.K
         )
-        if y is not None:
-            y = self.reshape_y(y)
-        for i in torch.linspace(1, sample_step, sample_step, device=self.beta.device):
-            t = (i - 1).view(1, 1, 1).repeat(batch_size, 1, 1) / sample_step
-            p = self.discrete_output_distribution(theta, t, y, guidance_strength)
-            if token_mask is not None:
-                p = p.masked_fill_(token_mask, 0.0)
-            alpha = self.calc_discrete_alpha(t, t + 1 / sample_step)
-            e_k = nn.functional.one_hot(torch.argmax(p, -1), self.K).float()
-            mu = alpha * (self.K * e_k - 1)
-            sigma = (alpha * self.K).sqrt()
-            theta = (mu + sigma * torch.randn_like(mu)).exp() * theta
-            theta = theta / theta.sum(-1, True)
-        t_final = torch.ones((batch_size, 1, 1), device=self.beta.device)
-        p = self.discrete_output_distribution(theta, t_final, y, guidance_strength)
-        entropy = -(p * p.log()).sum(-1).mean(-1)
-        if token_mask is not None:
-            p = p.masked_fill_(token_mask, 0.0)
-        return torch.argmax(p, -1), entropy
+        return self._process(theta, None, y, sample_step, guidance_strength, token_mask)
 
     @torch.jit.export
     def ode_sample(
@@ -735,11 +796,11 @@ class ChemBFN(nn.Module):
 
         :param batch_size: batch size
         :param sequence_size: max sequence length
-        :param y: conditioning vector; shape: (n_b, 1, n_f) or (n_b, n_f)
+        :param y: conditioning vector; shape: (n_b, 1, n_f) or (n_b, n_f)
         :param sample_step: number of sampling steps
         :param guidance_strength: strength of conditional generation. It is not used if y is null.
         :param token_mask: token mask assigning unwanted token(s) with `True`;
-            shape: (1, 1, n_vocab)
+            shape: (1, 1, n_vocab)
         :param temperature: sampling temperature
         :type batch_size: int
         :type sequence_size: int
@@ -748,29 +809,14 @@ class ChemBFN(nn.Module):
         :type guidance_strength: float
         :type token_mask: torch.Tensor | None
         :type temperature: float
-        :return: sampled token indices; shape: (n_b, n_t) \n
-            entropy of the tokens; shape: (n_b)
+        :return: sampled token indices; shape: (n_b, n_t) \n
+            entropy of the tokens; shape: (n_b)
         :rtype: tuple
         """
         z = torch.zeros((batch_size, sequence_size, self.K), device=self.beta.device)
-        if y is not None:
-            y = self.reshape_y(y)
-        for i in torch.linspace(1, sample_step, sample_step, device=self.beta.device):
-            t = (i - 1).view(1, 1, 1).repeat(batch_size, 1, 1) / sample_step
-            theta = torch.softmax(z, -1)
-            beta = self.calc_beta(t + 1 / sample_step)
-            p = self.discrete_output_distribution(theta, t, y, guidance_strength)
-            if token_mask is not None:
-                p = p.masked_fill_(token_mask, 0.0)
-            u = torch.randn_like(z)
-            z = (self.K * p - 1) * beta + (self.K * beta * temperature).sqrt() * u
-        t_final = torch.ones((batch_size, 1, 1), device=self.beta.device)
-        theta = torch.softmax(z, -1)
-        p = self.discrete_output_distribution(theta, t_final, y, guidance_strength)
-        entropy = -(p * p.log()).sum(-1).mean(-1)
-        if token_mask is not None:
-            p = p.masked_fill_(token_mask, 0.0)
-        return torch.argmax(p, -1), entropy
+        return self._ode_process(
+            z, None, y, sample_step, guidance_strength, token_mask, temperature
+        )
 
     @torch.jit.export
     def inpaint(
@@ -800,30 +846,12 @@ class ChemBFN(nn.Module):
         :rtype: tuple
         """
         n_b, n_t = x.shape
-        mask = (x != 0).float()[..., None]
+        x_mask = (x != 0).float()[..., None]
         theta = torch.ones((n_b, n_t, self.K), device=x.device) / self.K
-        x_onehot = nn.functional.one_hot(x, self.K) * mask
-        theta = x_onehot + (1 - mask) * theta
-        if y is not None:
-            y = self.reshape_y(y)
-        for i in torch.linspace(1, sample_step, sample_step, device=x.device):
-            t = (i - 1).view(1, 1, 1).repeat(n_b, 1, 1) / sample_step
-            p = self.discrete_output_distribution(theta, t, y, guidance_strength)
-            if token_mask is not None:
-                p = p.masked_fill_(token_mask, 0.0)
-            alpha = self.calc_discrete_alpha(t, t + 1 / sample_step)
-            e_k = nn.functional.one_hot(torch.argmax(p, -1), self.K).float()
-            mu = alpha * (self.K * e_k - 1)
-            sigma = (alpha * self.K).sqrt()
-            theta = (mu + sigma * torch.randn_like(mu)).exp() * theta
-            theta = theta / theta.sum(-1, True)
-            theta = x_onehot + (1 - mask) * theta
-        t_final = torch.ones((n_b, 1, 1), device=x.device)
-        p = self.discrete_output_distribution(theta, t_final, y, guidance_strength)
-        entropy = -(p * p.log()).sum(-1).mean(-1)
-        if token_mask is not None:
-            p = p.masked_fill_(token_mask, 0.0)
-        return torch.argmax(p, -1), entropy
+        x_onehot = nn.functional.one_hot(x, self.K) * x_mask
+        theta = x_onehot + (1 - x_mask) * theta
+        mask = (x_onehot, x_mask)
+        return self._process(theta, mask, y, sample_step, guidance_strength, token_mask)
 
     @torch.jit.export
     def ode_inpaint(
@@ -856,29 +884,80 @@ class ChemBFN(nn.Module):
         :rtype: tuple
         """
         n_b, n_t = x.shape
-        mask = (x != 0).float()[..., None]
-        x_onehot = nn.functional.one_hot(x, self.K) * mask
+        x_mask = (x != 0).float()[..., None]
+        x_onehot = nn.functional.one_hot(x, self.K) * x_mask
         z = torch.zeros((n_b, n_t, self.K), device=self.beta.device)
-        if y is not None:
-            y = self.reshape_y(y)
-        for i in torch.linspace(1, sample_step, sample_step, device=self.beta.device):
-            t = (i - 1).view(1, 1, 1).repeat(n_b, 1, 1) / sample_step
-            theta = torch.softmax(z, -1)
-            theta = x_onehot + (1 - mask) * theta
-            beta = self.calc_beta(t + 1 / sample_step)
-            p = self.discrete_output_distribution(theta, t, y, guidance_strength)
-            if token_mask is not None:
-                p = p.masked_fill_(token_mask, 0.0)
-            u = torch.randn_like(z)
-            z = (self.K * p - 1) * beta + (self.K * beta * temperature).sqrt() * u
-        t_final = torch.ones((n_b, 1, 1), device=self.beta.device)
-        theta = torch.softmax(z, -1)
-        theta = x_onehot + (1 - mask) * theta
-        p = self.discrete_output_distribution(theta, t_final, y, guidance_strength)
-        entropy = -(p * p.log()).sum(-1).mean(-1)
-        if token_mask is not None:
-            p = p.masked_fill_(token_mask, 0.0)
-        return torch.argmax(p, -1), entropy
+        mask = (x_onehot, x_mask)
+        return self._ode_process(
+            z, mask, y, sample_step, guidance_strength, token_mask, temperature
+        )
+
+    @torch.jit.export
+    def optimise(
+        self,
+        x: Tensor,
+        y: Optional[Tensor] = None,
+        sample_step: int = 100,
+        guidance_strength: float = 4.0,
+        token_mask: Optional[Tensor] = None,
+    ) -> Tuple[Tensor, Tensor]:
+        """
+        Optimise the template molecule (mol2mol). \n
+        This method is equivalent to sampling from a customised prior distribution.
+
+        :param x: categorical indices of template; shape: (n_b, n_t)
+        :param y: conditioning vector; shape: (n_b, 1, n_f) or (n_b, n_f)
+        :param sample_step: number of sampling steps
+        :param guidance_strength: strength of conditional generation. It is not used if y is null.
+        :param token_mask: token mask assigning unwanted token(s) with `True`;
+            shape: (1, 1, n_vocab)
+        :type x: torch.Tensor
+        :type y: torch.Tensor | None
+        :type sample_step: int
+        :type guidance_strength: float
+        :type token_mask: torch.Tensor | None
+        :return: sampled token indices; shape: (n_b, n_t) \n
+            entropy of the tokens; shape: (n_b)
+        :rtype: tuple
+        """
+        x_onehot = nn.functional.one_hot(x, self.K).float()
+        theta = softmax(x_onehot, -1)
+        return self._process(theta, None, y, sample_step, guidance_strength, token_mask)
+
+    @torch.jit.export
+    def ode_optimise(
+        self,
+        x: Tensor,
+        y: Optional[Tensor] = None,
+        sample_step: int = 100,
+        guidance_strength: float = 4.0,
+        token_mask: Optional[Tensor] = None,
+        temperature: float = 0.5,
+    ) -> Tuple[Tensor, Tensor]:
+        """
+        ODE mol2mol.
+
+        :param x: categorical indices of template; shape: (n_b, n_t)
+        :param y: conditioning vector; shape: (n_b, 1, n_f) or (n_b, n_f)
+        :param sample_step: number of sampling steps
+        :param guidance_strength: strength of conditional generation. It is not used if y is null.
+        :param token_mask: token mask assigning unwanted token(s) with `True`;
+            shape: (1, 1, n_vocab)
+        :param temperature: sampling temperature
+        :type x: torch.Tensor
+        :type y: torch.Tensor | None
+        :type sample_step: int
+        :type guidance_strength: float
+        :type token_mask: torch.Tensor | None
+        :type temperature: float
+        :return: sampled token indices; shape: (n_b, n_t) \n
+            entropy of the tokens; shape: (n_b)
+        :rtype: tuple
+        """
+        z = nn.functional.one_hot(x, self.K).float()
+        return self._ode_process(
+            z, None, y, sample_step, guidance_strength, token_mask, temperature
+        )
 
     def inference(
         self, x: Tensor, mlp: MLP, embed_fn: Optional[Callable[[Tensor], Tensor]] = None
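For orientation: after this refactor the public samplers above differ only in the prior state and the optional mask they hand to the two shared routines. Summarised below, paraphrasing the code shown in this hunk (not additional package code):

    # sample / ode_sample:     uniform prior (theta = 1/K, z = 0),        mask = None
    # inpaint / ode_inpaint:   prior seeded with the one-hot scaffold,    mask = (x_onehot, x_mask)
    #                          so fixed positions are re-imposed at every step
    # optimise / ode_optimise: prior built from the one-hot template,     mask = None
    #                          so every position is free to change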
@@ -1052,22 +1131,6 @@ class EnsembleChemBFN(ChemBFN):
                 module.lora_dropout = None
             v.lora_enabled = False
 
-    def construct_y(
-        self, c: Union[List[Tensor], Dict[str, Tensor]]
-    ) -> Dict[str, Tensor]:
-        assert (
-            isinstance(c, dict) is self._label_is_dict
-        ), f"`c` should be a {'`dict` instance' if self._label_is_dict else '`list` instance'} but got {type(c)} instand."
-        out: Dict[str, Tensor] = {}
-        if isinstance(c, list):
-            c = dict(zip([f"val_{i}" for i in range(len(c))], c))
-        for name, model in self.cond_heads.items():
-            y = model.forward(c[name])
-            if y.dim() == 2:
-                y = y[:, None, :]
-            out[name] = y
-        return out
-
     def discrete_output_distribution(
         self, theta: Tensor, t: Tensor, y: Dict[str, Tensor], w: float
     ) -> Tensor:
@@ -1102,8 +1165,24 @@ class EnsembleChemBFN(ChemBFN):
             p_cond += p_cond_ * self.adapter_weights[name]
         return softmax((1 + w) * p_cond - w * p_uncond, -1)
 
+    def _map_to_dict(
+        self, c: Union[List[Tensor], Dict[str, Tensor]]
+    ) -> Dict[str, Tensor]:
+        assert (
+            isinstance(c, dict) is self._label_is_dict
+        ), f"`c` should be a {'`dict` instance' if self._label_is_dict else '`list` instance'} but got {type(c)} instand."
+        out: Dict[str, Tensor] = {}
+        if isinstance(c, list):
+            c = dict(zip([f"val_{i}" for i in range(len(c))], c))
+        for name, model in self.cond_heads.items():
+            y = model.forward(c[name])
+            if y.dim() == 2:
+                y = y[:, None, :]
+            out[name] = y
+        return out
+
     @staticmethod
-    def reshape_y(y: Dict[str, Tensor]) -> Dict[str, Tensor]:
+    def _reshape(y: Dict[str, Tensor]) -> Dict[str, Tensor]:
         for k in y:
             assert y[k].dim() <= 3
             if y[k].dim() == 2:
@@ -1139,7 +1218,7 @@ class EnsembleChemBFN(ChemBFN):
             entropy of the tokens; shape: (n_b)
         :rtype: tuple
         """
-        y = self.construct_y(conditions)
+        y = self._map_to_dict(conditions)
         return super().sample(
             batch_size, sequence_size, y, sample_step, guidance_strength, token_mask
         )
@@ -1176,7 +1255,7 @@ class EnsembleChemBFN(ChemBFN):
             entropy of the tokens; shape: (n_b)
         :rtype: tuple
         """
-        y = self.construct_y(conditions)
+        y = self._map_to_dict(conditions)
         return super().ode_sample(
             batch_size,
             sequence_size,
@@ -1213,7 +1292,7 @@ class EnsembleChemBFN(ChemBFN):
             entropy of the tokens; shape: (n_b)
         :rtype: tuple
         """
-        y = self.construct_y(conditions)
+        y = self._map_to_dict(conditions)
         return super().inpaint(x, y, sample_step, guidance_strength, token_mask)
 
     @torch.inference_mode()
@@ -1245,11 +1324,76 @@ class EnsembleChemBFN(ChemBFN):
             entropy of the tokens; shape: (n_b)
         :rtype: tuple
         """
-        y = self.construct_y(conditions)
+        y = self._map_to_dict(conditions)
         return super().ode_inpaint(
             x, y, sample_step, guidance_strength, token_mask, temperature
         )
 
+    @torch.inference_mode()
+    def optimise(
+        self,
+        x: Tensor,
+        conditions: Union[List[Tensor], Dict[str, Tensor]],
+        sample_step: int = 100,
+        guidance_strength: float = 4.0,
+        token_mask: Optional[Tensor] = None,
+    ) -> Tuple[Tensor, Tensor]:
+        """
+        Optimise the template molecule (mol2mol). \n
+        This method is equivalent to sampling from a customised prior distribution.
+
+        :param x: categorical indices of template; shape: (n_b, n_t)
+        :param conditions: conditioning vector; shape: (n_b, n_c) * n_h
+        :param sample_step: number of sampling steps
+        :param guidance_strength: strength of conditional generation. It is not used if y is null.
+        :param token_mask: token mask assigning unwanted token(s) with `True`;
+            shape: (1, 1, n_vocab)
+        :type x: torch.Tensor
+        :type y: torch.Tensor | None
+        :type sample_step: int
+        :type guidance_strength: float
+        :type token_mask: torch.Tensor | None
+        :return: sampled token indices; shape: (n_b, n_t) \n
+            entropy of the tokens; shape: (n_b)
+        :rtype: tuple
+        """
+        y = self._map_to_dict(conditions)
+        return super().optimise(x, y, sample_step, guidance_strength, token_mask)
+
+    @torch.inference_mode()
+    def ode_optimise(
+        self,
+        x: Tensor,
+        conditions: Union[List[Tensor], Dict[str, Tensor]],
+        sample_step: int = 100,
+        guidance_strength: float = 4.0,
+        token_mask: Optional[Tensor] = None,
+        temperature: float = 0.5,
+    ) -> Tuple[Tensor, Tensor]:
+        """
+        ODE inpainting.
+
+        :param x: categorical indices of template; shape: (n_b, n_t)
+        :param conditions: conditioning vector; shape: (n_b, n_c) * n_h
+        :param sample_step: number of sampling steps
+        :param guidance_strength: strength of conditional generation. It is not used if y is null.
+        :param token_mask: token mask; shape: (1, 1, n_vocab)
+        :param temperature: sampling temperature
+        :type x: torch.Tensor
+        :type conditions: list | dict
+        :type sample_step: int
+        :type guidance_strength: float
+        :type token_mask: torch.Tensor | None
+        :type temperature: float
+        :return: sampled token indices; shape: (n_b, n_t) \n
+            entropy of the tokens; shape: (n_b)
+        :rtype: tuple
+        """
+        y = self._map_to_dict(conditions)
+        return super().ode_optimise(
+            x, y, sample_step, guidance_strength, token_mask, temperature
+        )
+
     def quantise(
         self, quantise_method: Optional[Callable[[ChemBFN], None]] = None
     ) -> None:
bayesianflow_for_chem/tool.py

@@ -9,7 +9,6 @@ import warnings
 from pathlib import Path
 from typing import List, Dict, Tuple, Union, Optional
 import torch
-import colorama
 import numpy as np
 from torch import cuda, Tensor, softmax
 from torch.utils.data import DataLoader
@@ -24,15 +23,7 @@ from rdkit.Chem import (
     AddHs,
     Mol,
 )
-from rdkit.Chem.Scaffolds.MurckoScaffold import MurckoScaffoldSmiles  # type: ignore
-from sklearn.metrics import (
-    roc_auc_score,
-    auc,
-    precision_recall_curve,
-    r2_score,
-    mean_absolute_error,
-    root_mean_squared_error,
-)
+from rdkit.Chem.Scaffolds.MurckoScaffold import MurckoScaffoldSmiles
 from .data import VOCAB_KEYS
 from .model import ChemBFN, MLP, EnsembleChemBFN
 
@@ -45,6 +36,68 @@ def _find_device() -> torch.device:
     return torch.device("cpu")
 
 
+def _parse_and_assert_param(
+    model: Union[ChemBFN, EnsembleChemBFN],
+    y: Optional[Union[Tensor, Dict[str, Tensor], List[Tensor]]],
+    method: str,
+) -> Optional[float]:
+    assert method.split(":")[0].lower() in ("ode", "bfn")
+    if isinstance(model, EnsembleChemBFN):
+        assert y is not None, "conditioning is required while using an ensemble model."
+        assert isinstance(y, list) or isinstance(y, dict)
+    else:
+        assert isinstance(y, Tensor) or (y is None)
+    if "ode" in method.lower():
+        tp = float(method.split(":")[-1])
+        assert tp > 0, "Sampling temperature should be higher than 0."
+        return tp
+    return None
+
+
+def _map_to_device(
+    y: Optional[Union[Tensor, Dict[str, Tensor], List[Tensor]]],
+    device: Union[str, torch.device],
+) -> Optional[Union[Tensor, Dict[str, Tensor], List[Tensor]]]:
+    if y is not None:
+        if isinstance(y, Tensor):
+            y = y.to(device)
+        elif isinstance(y, list):
+            y = [i.to(device) for i in y]
+        elif isinstance(y, dict):
+            y = {k: v.to(device) for k, v in y.items()}
+        else:
+            raise NotImplementedError
+    return y
+
+
+def _build_token_mask(
+    allowed_tokens: Union[str, List[str]],
+    vocab_keys: List[str],
+    device: Union[str, torch.tensor],
+) -> Optional[Tensor]:
+    if isinstance(allowed_tokens, list):
+        token_mask = [0 if i in allowed_tokens else 1 for i in vocab_keys]
+        token_mask = torch.tensor([[token_mask]], dtype=torch.bool).to(device)
+    else:
+        token_mask = None
+    return token_mask
+
+
+def _token_to_seq(
+    tokens: Tensor, entropy: Tensor, vocab_keys: List[str], separator: str, sort: bool
+) -> List[str]:
+    if sort:
+        sorted_idx = entropy.argsort(stable=True)
+        tokens = tokens[sorted_idx]
+    return [
+        separator.join([vocab_keys[i] for i in j])
+        .split("<start>" + separator)[-1]
+        .split(separator + "<end>")[0]
+        .replace("<pad>", "")
+        for j in tokens
+    ]
+
+
 @torch.no_grad()
 def test(
     model: ChemBFN,
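A note on the new `_build_token_mask` helper: it marks every vocabulary entry not listed in `allowed_tokens` as forbidden, and the samplers then zero those probabilities via `masked_fill_`. A small self-contained sketch of the equivalent behaviour (the particular `allowed` list is illustrative only):

    import torch
    from bayesianflow_for_chem.data import VOCAB_KEYS

    allowed = ["<start>", "<end>", "<pad>", "C", "c", "1", "2", "(", ")", "=", "O", "N"]
    # True marks a token that must never be emitted, mirroring _build_token_mask
    mask = torch.tensor([[[key not in allowed for key in VOCAB_KEYS]]], dtype=torch.bool)
    print(mask.shape)  # (1, 1, n_vocab)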
@@ -86,6 +139,12 @@ def test(
             predict_y.append(y_hat.detach().to("cpu"))
     predict_y, label_y = torch.cat(predict_y, 0), torch.cat(label_y, 0).split(1, -1)
     if mode == "regression":
+        from sklearn.metrics import (
+            r2_score,
+            mean_absolute_error,
+            root_mean_squared_error,
+        )
+
         predict_y = [
             predict[label_y[i] != torch.inf]
             for (i, predict) in enumerate(predict_y.split(1, -1))
@@ -99,6 +158,8 @@ def test(
         r2 = [r2_score(label, predict) for (label, predict) in y_zipped]
         return {"MAE": mae, "RMSE": rmse, "R^2": r2}
     if mode == "classification":
+        from sklearn.metrics import roc_auc_score, auc, precision_recall_curve
+
         n_c = len(label_y)
         predict_y = predict_y.chunk(n_c, -1)
         y_zipped = list(zip(label_y, predict_y))
@@ -142,7 +203,6 @@ def split_dataset(
     assert file.endswith(".csv")
     assert len(split_ratio) == 3
     assert method in ("random", "scaffold")
-    colorama.just_fix_windows_console()
     with open(file, "r") as f:
         data = list(csv.reader(f))
     header = data[0]
@@ -167,10 +227,8 @@ def split_dataset(
         # compute Bemis-Murcko scaffold
         if len(smiles_idx) > 1:
             warnings.warn(
-                "\033[32;1m"
                 f"We found {len(smiles_idx)} SMILES strings in a row!"
-                " Only the first SMILES will be used to compute the molecular scaffold."
-                "\033[0m",
+                " Only the first SMILES will be used to compute the molecular scaffold.",
                 stacklevel=2,
             )
         try:
@@ -197,10 +255,10 @@ def split_dataset(
     with open(file.replace(".csv", "_test.csv"), "w", newline="") as fte:
         writer = csv.writer(fte)
         writer.writerows([header] + test_set)
-    with open(file.replace(".csv", "_val.csv"), "w", newline="") as fva:
-        writer = csv.writer(fva)
-        writer.writerows([header] + val_set)
-    colorama.deinit()
+    if val_set:
+        with open(file.replace(".csv", "_val.csv"), "w", newline="") as fva:
+            writer = csv.writer(fva)
+            writer.writerows([header] + val_set)
 
 
 @torch.no_grad()
@@ -219,7 +277,7 @@ def sample(
     sort: bool = False,
 ) -> List[str]:
     """
-    Sampling.
+    Sampling molecules.
 
     :param model: trained ChemBFN model
     :param batch_size: batch size
@@ -250,32 +308,12 @@ def sample(
     :return: a list of generated molecular strings
     :rtype: list
     """
-    assert method.split(":")[0].lower() in ("ode", "bfn")
-    if isinstance(model, EnsembleChemBFN):
-        assert y is not None, "conditioning is required while using an ensemble model."
-        assert isinstance(y, list) or isinstance(y, dict)
-    else:
-        assert isinstance(y, Tensor) or y is None
-    if device is None:
-        device = _find_device()
+    tp = _parse_and_assert_param(model, y, method)
+    device = _find_device() if device is None else device
     model.to(device).eval()
-    if y is not None:
-        if isinstance(y, Tensor):
-            y = y.to(device)
-        elif isinstance(y, list):
-            y = [i.to(device) for i in y]
-        elif isinstance(y, dict):
-            y = {k: v.to(device) for k, v in y.items()}
-        else:
-            raise NotImplementedError
-    if isinstance(allowed_tokens, list):
-        token_mask = [0 if i in allowed_tokens else 1 for i in vocab_keys]
-        token_mask = torch.tensor([[token_mask]], dtype=torch.bool).to(device)
-    else:
-        token_mask = None
-    if "ode" in method.lower():
-        tp = float(method.split(":")[-1])
-        assert tp > 0, "Sampling temperature should be higher than 0."
+    y = _map_to_device(y, device)
+    token_mask = _build_token_mask(allowed_tokens, vocab_keys, device)
+    if tp:
         tokens, entropy = model.ode_sample(
             batch_size, sequence_size, y, sample_step, guidance_strength, token_mask, tp
         )
@@ -283,16 +321,7 @@ def sample(
         tokens, entropy = model.sample(
             batch_size, sequence_size, y, sample_step, guidance_strength, token_mask
         )
-    if sort:
-        sorted_idx = entropy.argsort(stable=True)
-        tokens = tokens[sorted_idx]
-    return [
-        seperator.join([vocab_keys[i] for i in j])
-        .split("<start>" + seperator)[-1]
-        .split(seperator + "<end>")[0]
-        .replace("<pad>", "")
-        for j in tokens
-    ]
+    return _token_to_seq(tokens, entropy, vocab_keys, seperator, sort)
 
 
 @torch.no_grad()
@@ -339,33 +368,13 @@ def inpaint(
     :return: a list of generated molecular strings
     :rtype: list
     """
-    assert method.split(":")[0].lower() in ("ode", "bfn")
-    if isinstance(model, EnsembleChemBFN):
-        assert y is not None, "conditioning is required while using an ensemble model."
-        assert isinstance(y, list) or isinstance(y, dict)
-    else:
-        assert isinstance(y, Tensor) or y is None
-    if device is None:
-        device = _find_device()
+    tp = _parse_and_assert_param(model, y, method)
+    device = _find_device() if device is None else device
     model.to(device).eval()
     x = x.to(device)
-    if y is not None:
-        if isinstance(y, Tensor):
-            y = y.to(device)
-        elif isinstance(y, list):
-            y = [i.to(device) for i in y]
-        elif isinstance(y, dict):
-            y = {k: v.to(device) for k, v in y.items()}
-        else:
-            raise NotImplementedError
-    if isinstance(allowed_tokens, list):
-        token_mask = [0 if i in allowed_tokens else 1 for i in vocab_keys]
-        token_mask = torch.tensor([[token_mask]], dtype=torch.bool).to(device)
-    else:
-        token_mask = None
-    if "ode" in method.lower():
-        tp = float(method.split(":")[-1])
-        assert tp > 0, "Sampling temperature should be higher than 0."
+    y = _map_to_device(y, device)
+    token_mask = _build_token_mask(allowed_tokens, vocab_keys, device)
+    if tp:
         tokens, entropy = model.ode_inpaint(
             x, y, sample_step, guidance_strength, token_mask, tp
         )
@@ -373,16 +382,68 @@ def inpaint(
         tokens, entropy = model.inpaint(
             x, y, sample_step, guidance_strength, token_mask
         )
-    if sort:
-        sorted_idx = entropy.argsort(stable=True)
-        tokens = tokens[sorted_idx]
-    return [
-        separator.join([vocab_keys[i] for i in j])
-        .split("<start>" + separator)[-1]
-        .split(separator + "<end>")[0]
-        .replace("<pad>", "")
-        for j in tokens
-    ]
+    return _token_to_seq(tokens, entropy, vocab_keys, separator, sort)
+
+
+@torch.no_grad()
+def optimise(
+    model: Union[ChemBFN, EnsembleChemBFN],
+    x: Tensor,
+    sample_step: int = 100,
+    y: Optional[Union[Tensor, Dict[str, Tensor], List[Tensor]]] = None,
+    guidance_strength: float = 4.0,
+    device: Union[str, torch.device, None] = None,
+    vocab_keys: List[str] = VOCAB_KEYS,
+    separator: str = "",
+    method: str = "BFN",
+    allowed_tokens: Union[str, List[str]] = "all",
+    sort: bool = False,
+) -> List[str]:
+    """
+    Optimising template molecules (mol2mol).
+
+    :param model: trained ChemBFN model
+    :param x: categorical indices of template; shape: (n_b, n_t)
+    :param sample_step: number of sampling steps
+    :param y: conditioning vector; shape: (n_b, 1, n_f) or (n_b, n_f) \n
+        or a list/`dict` of conditions; shape: (n_b, n_c) * n_h
+
+    :param guidance_strength: strength of conditional generation. It is not used if y is null.
+    :param device: hardware accelerator
+    :param vocab_keys: a list of (ordered) vocabulary
+    :param separator: token separator; default is `""`
+    :param method: sampling method chosen from `"ODE:x"` or `"BFN"` where `x` is the value of sampling temperature; default is `"BFN"`
+    :param allowed_tokens: a list of allowed tokens; default is `"all"`
+    :param sort: whether to sort the samples according to entropy values; default is `False`
+    :type model: bayesianflow_for_chem.model.ChemBFN | bayesianflow_for_chem.model.EnsembleChemBFN
+    :type x: torch.Tensor
+    :type sample_step: int
+    :type y: torch.Tensor | list | dict | None
+    :type guidance_strength: float
+    :type device: str | torch.device | None
+    :type vocab_keys: list
+    :type separator: str
+    :type method: str
+    :type allowed_tokens: str | list
+    :type sort: bool
+    :return: a list of generated molecular strings
+    :rtype: list
+    """
+    tp = _parse_and_assert_param(model, y, method)
+    device = _find_device() if device is None else device
+    model.to(device).eval()
+    x = x.to(device)
+    y = _map_to_device(y, device)
+    token_mask = _build_token_mask(allowed_tokens, vocab_keys, device)
+    if tp:
+        tokens, entropy = model.ode_optimise(
+            x, y, sample_step, guidance_strength, token_mask, tp
+        )
+    else:
+        tokens, entropy = model.optimise(
+            x, y, sample_step, guidance_strength, token_mask
        )
+    return _token_to_seq(tokens, entropy, vocab_keys, separator, sort)
 
 
 def quantise_model_(model: ChemBFN) -> None:
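A minimal usage sketch of the new `optimise` helper, mirroring how cli.py builds the template tensor earlier in this diff. Here `model` is a trained ChemBFN instance and `tokeniser` any callable mapping a molecular string to token indices (both obtained elsewhere, e.g. as in cli.py); `sequence_length` and `batch_size` are illustrative values:

    import torch
    from bayesianflow_for_chem.tool import optimise

    sequence_length, batch_size = 64, 16
    x = tokeniser("c1ccccc1O")  # template molecule -> token indices
    x = torch.nn.functional.pad(x, (0, sequence_length - x.shape[-1]), value=0)
    x = x[None, :].repeat(batch_size, 1)  # shape (n_b, n_t)
    candidates = optimise(model, x, sample_step=100, method="ODE:0.5", sort=True)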
bayesianflow_for_chem/train.py

@@ -8,7 +8,6 @@ from typing import Dict, Tuple, Union, Optional
 import torch
 import torch.optim as op
 import torch.nn.functional as F
-from loralib import lora_state_dict, mark_only_lora_as_trainable
 from torch import Tensor
 from torch.optim.lr_scheduler import ReduceLROnPlateau
 from lightning import LightningModule
@@ -55,6 +54,8 @@ class Model(LightningModule):
         self.scorer = scorer
         self.save_hyperparameters(hparam, ignore=["model", "mlp", "scorer"])
         if model.lora_enabled:
+            from loralib import mark_only_lora_as_trainable
+
             mark_only_lora_as_trainable(self.model)
         self.use_scorer = self.scorer is not None
 
@@ -107,6 +108,8 @@ class Model(LightningModule):
         :rtype: None
         """
         if self.model.lora_enabled:
+            from loralib import lora_state_dict
+
             torch.save(
                 {
                     "lora_nn": lora_state_dict(self.model),
@@ -152,6 +155,8 @@ class Regressor(LightningModule):
         self.model.requires_grad_(not hparam["freeze"])
         self.save_hyperparameters(hparam, ignore=["model", "mlp"])
         if model.lora_enabled:
+            from loralib import mark_only_lora_as_trainable
+
             mark_only_lora_as_trainable(self.model)
         assert hparam["mode"] in ("regression", "classification")
 
@@ -231,6 +236,8 @@ class Regressor(LightningModule):
             )
         if not self.hparams.freeze:
             if self.model.lora_enabled:
+                from loralib import lora_state_dict
+
                 torch.save(
                     {
                         "lora_nn": lora_state_dict(self.model),
bayesianflow_for_chem-2.2.1.dist-info/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: bayesianflow_for_chem
-Version: 2.0.5
+Version: 2.2.1
 Summary: Bayesian flow network framework for Chemistry
 Home-page: https://augus1999.github.io/bayesian-flow-network-for-chemistry/
 Author: Nianze A. Tao
@@ -54,12 +54,14 @@ This is the repository of the PyTorch implementation of ChemBFN model.
 
 [![PyPI](https://img.shields.io/pypi/v/bayesianflow-for-chem?color=ff69b4)](https://pypi.org/project/bayesianflow-for-chem/)
 ![pytest](https://github.com/Augus1999/bayesian-flow-network-for-chemistry/actions/workflows/pytest.yml/badge.svg)
+[![document](https://github.com/Augus1999/bayesian-flow-network-for-chemistry/actions/workflows/pages/pages-build-deployment/badge.svg)](https://augus1999.github.io/bayesian-flow-network-for-chemistry/)
 
 ## Features
 
 ChemBFN provides the state-of-the-art functionalities of
 * SMILES or SELFIES-based *de novo* molecule generation
 * Protein sequence *de novo* generation
+* Template optimisation (mol2mol)
 * Classifier-free guidance conditional generation (single or multi-objective optimisation)
 * Context-guided conditional generation (inpaint)
 * Outstanding out-of-distribution chemical space sampling
@@ -71,6 +73,7 @@ in an all-in-one-model style.
 
 ## News
 
+* [09/10/2025] A web app [`chembfn_webui`](https://github.com/Augus1999/ChemBFN-WebUI) for hosting ChemBFN models is available on [PyPI](https://pypi.org/project/chembfn-webui/).
 * [30/01/2025] The package `bayesianflow_for_chem` is available on [PyPI](https://pypi.org/project/bayesianflow-for-chem/).
 * [21/01/2025] Our first paper has been accepted by [JCIM](https://pubs.acs.org/doi/10.1021/acs.jcim.4c01792).
 * [17/12/2024] The second paper of out-of-distribution generation is available on [arxiv.org](https://arxiv.org/abs/2412.11439).
bayesianflow_for_chem-2.2.1.dist-info/RECORD (new file)

@@ -0,0 +1,15 @@
+bayesianflow_for_chem/__init__.py,sha256=l6twtXTdnix9SC8oL0Aqr0dYJ1uBUmMtk2QbUaRrhkE,642
+bayesianflow_for_chem/cli.py,sha256=jZPBUVwOl4qnjUMb3Lavu6owv1RP9jGdrDlM7KDA99c,26600
+bayesianflow_for_chem/data.py,sha256=RM3YaZUkTJ4-RNLH37GwhM4sRddOOusrC996ZPAe1lI,6590
+bayesianflow_for_chem/model.py,sha256=M35G4u4mX4btl9vOK3Iqs6yOSuIKI_OoCTmLhmjbwNk,57559
+bayesianflow_for_chem/scorer.py,sha256=gQFUlkyxitch02ntqcRh1ZS8aondKLynW5U6NfTQTb4,4084
+bayesianflow_for_chem/spectra.py,sha256=Ba9ib1aDvTtDYbH3b4d-lIty3ZSQMu7jwehuV2KmhwA,1781
+bayesianflow_for_chem/tool.py,sha256=VQo6xShwwmJbBsC-xTRL5ysCI7OLoOpmwQ-A7AuAPI8,23587
+bayesianflow_for_chem/train.py,sha256=7AU0A-eZwzSYsLyIe3OxGTNWPnhGpHmVUaQLplV2Fn8,9886
+bayesianflow_for_chem/_data/vocab.txt,sha256=HgtAZmpWYk4y8PqEVC4vqut1vE75DfRKE_10s2UW0rU,790
+bayesianflow_for_chem-2.2.1.dist-info/licenses/LICENSE,sha256=hIahDEOTzuHCU5J2nd07LWwkLW7Hko4UFO__ffsvB-8,34523
+bayesianflow_for_chem-2.2.1.dist-info/METADATA,sha256=YKuXgdklakB9DH4cFpTy3EbYdmo8dz9BgQQ4bI9CVQs,6476
+bayesianflow_for_chem-2.2.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+bayesianflow_for_chem-2.2.1.dist-info/entry_points.txt,sha256=N63RMoJsr8rxuKxc7Fj802SL8J5AlpCoPkS8E3IFPLI,54
+bayesianflow_for_chem-2.2.1.dist-info/top_level.txt,sha256=KHsanI3BMCt8D9Qpze2ycrF6nMa3PyojgO6eS1c8kco,22
+bayesianflow_for_chem-2.2.1.dist-info/RECORD,,

bayesianflow_for_chem-2.0.5.dist-info/RECORD (removed file)

@@ -1,15 +0,0 @@
-bayesianflow_for_chem/__init__.py,sha256=U8ExDm5IRa-OPOLgt8VfjMDAiCmARiJonIKnZ8AVDQ0,612
-bayesianflow_for_chem/cli.py,sha256=g5pjXymycpchlgRE6SKr0LjTBjUl6MFHnFdQiKjcE3Q,22803
-bayesianflow_for_chem/data.py,sha256=Pl0gGWHmMKTKHpsxznvLgYPCwwlLNL7nqH19Vipjkxs,6584
-bayesianflow_for_chem/model.py,sha256=QF15BLpUjEpUCneTOHoU6MswvxArPfbFMiOHjwJ9JrM,52230
-bayesianflow_for_chem/scorer.py,sha256=gQFUlkyxitch02ntqcRh1ZS8aondKLynW5U6NfTQTb4,4084
-bayesianflow_for_chem/spectra.py,sha256=Ba9ib1aDvTtDYbH3b4d-lIty3ZSQMu7jwehuV2KmhwA,1781
-bayesianflow_for_chem/tool.py,sha256=JSlKrar1vVouTRUJBYU6lc5AQmAIaexvpUA3W3oQcKs,21284
-bayesianflow_for_chem/train.py,sha256=jYkhSguW50lrcTEydCQ20yig_mmc1j7WH9KmVwBCTAo,9727
-bayesianflow_for_chem/vocab.txt,sha256=HgtAZmpWYk4y8PqEVC4vqut1vE75DfRKE_10s2UW0rU,790
-bayesianflow_for_chem-2.0.5.dist-info/licenses/LICENSE,sha256=hIahDEOTzuHCU5J2nd07LWwkLW7Hko4UFO__ffsvB-8,34523
-bayesianflow_for_chem-2.0.5.dist-info/METADATA,sha256=qCd0LnD66eQbWlTvKDpHZHtJKmy9NQogm_k9_amEt2s,6057
-bayesianflow_for_chem-2.0.5.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
-bayesianflow_for_chem-2.0.5.dist-info/entry_points.txt,sha256=N63RMoJsr8rxuKxc7Fj802SL8J5AlpCoPkS8E3IFPLI,54
-bayesianflow_for_chem-2.0.5.dist-info/top_level.txt,sha256=KHsanI3BMCt8D9Qpze2ycrF6nMa3PyojgO6eS1c8kco,22
-bayesianflow_for_chem-2.0.5.dist-info/RECORD,,