bayesianflow-for-chem 2.1.0__py3-none-any.whl → 2.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


@@ -3,10 +3,8 @@
  """
  ChemBFN package.
  """
- import colorama
  from . import data, tool, train, scorer, spectra
  from .model import ChemBFN, MLP, EnsembleChemBFN
- from .cli import main_script
 
  __all__ = [
  "data",
@@ -18,7 +16,7 @@ __all__ = [
  "MLP",
  "EnsembleChemBFN",
  ]
- __version__ = "2.1.0"
+ __version__ = "2.2.1"
  __author__ = "Nianze A. Tao (Omozawa Sueno)"
 
 
@@ -29,6 +27,9 @@ def main() -> None:
  :return:
  :rtype: None
  """
+ import colorama
+ from bayesianflow_for_chem.cli import main_script
+
  colorama.just_fix_windows_console()
  main_script(__version__)
  colorama.deinit()
@@ -12,13 +12,8 @@ from pathlib import Path
  from functools import partial
  from typing import List, Tuple, Dict, Union, Callable
  import torch
- import lightning as L
  from rdkit.Chem import MolFromSmiles, CanonSmiles
- from torch.utils.data import DataLoader
- from lightning.pytorch import loggers
- from lightning.pytorch.callbacks import ModelCheckpoint
  from bayesianflow_for_chem import ChemBFN, MLP
- from bayesianflow_for_chem.train import Model
  from bayesianflow_for_chem.scorer import smiles_valid, Scorer
  from bayesianflow_for_chem.data import (
  VOCAB_COUNT,
@@ -90,6 +85,7 @@ checkpoint_save_path = "home/user/project/ckpt"
  train_strategy = "auto" # or any strategy supported by Lightning, e.g., "ddp"
  accumulate_grad_batches = 1
  enable_progress_bar = false
+ plugin_script = "" # define customised behaviours of dataset, datasetloader, etc in a python script
 
  # Remove this table if inference is unnecessary
  [inference]
@@ -120,6 +116,32 @@ madmadmadmadmadmadmadmadmadmadmadmadmadmadmad
  madmadmadmadmadmadmadmadmadmadmadmadmadmadmad
  """
 
+ _ALLOWED_PLUGINS = [
+ "collate_fn",
+ "num_workers",
+ "max_sequence_length",
+ "shuffle",
+ "CustomData",
+ ]
+
+
+ def _load_plugin(plugin_file: str) -> Dict[str, Union[int, Callable, object, None]]:
+ if not plugin_file:
+ return {n: None for n in _ALLOWED_PLUGINS}
+ from importlib import util as iutil
+
+ spec = iutil.spec_from_file_location(Path(plugin_file).stem, plugin_file)
+ plugins = iutil.module_from_spec(spec)
+ spec.loader.exec_module(plugins)
+ plugin_names: List[str] = plugins.__all__
+ plugin_dict = {}
+ for n in _ALLOWED_PLUGINS:
+ if n in plugin_names:
+ plugin_dict[n] = getattr(plugins, n)
+ else:
+ plugin_dict[n] = None
+ return plugin_dict
+
 
  def parse_cli(version: str) -> argparse.Namespace:
  """
@@ -267,6 +289,14 @@ def load_runtime_config(
  f"\033[0;31mCritical\033[0;0m in {config_file}: Restart checkpoint file {ckpt_file} does not exist."
  )
  flag_critical += 1
+ # ↓ added in v2.2.0; need to be compatible with old versions.
+ plugin_script: str = config["train"].get("plugin_script", "")
+ if plugin_script:
+ if not os.path.exists(plugin_script):
+ print(
+ f"\033[0;31mCritical\033[0;0m in {config_file}: Plugin script {plugin_script} does not exist."
+ )
+ flag_critical += 1
  if "inference" in config:
  if not "train" in config:
  if not isinstance(config["inference"]["sequence_length"], int):
@@ -335,14 +365,15 @@ def main_script(version: str) -> None:
  )
  flag_warning += 1
  if not os.path.exists(runtime_config["train"]["checkpoint_save_path"]):
- os.makedirs(runtime_config["train"]["checkpoint_save_path"])
+ if not parser.dryrun: # only create it in real tasks
+ os.makedirs(runtime_config["train"]["checkpoint_save_path"])
  else:
  if not model_config["ChemBFN"]["base_model"]:
  print(
  f"\033[0;33mWarning\033[0;0m in {parser.model_config}: You should load a pretrained ChemBFN model."
  )
  flag_warning += 1
- if not model_config["MLP"]["base_model"]:
+ if "MLP" in model_config and not model_config["MLP"]["base_model"]:
  print(
  f"\033[0;33mWarning\033[0;0m in {parser.model_config}: You should load a pretrained MLP."
  )
@@ -350,13 +381,17 @@ def main_script(version: str) -> None:
  if "inference" in runtime_config:
  if runtime_config["inference"]["guidance_objective"]:
  if not "MLP" in model_config:
- print(f"Warning in {parser.model_config}: Oh no, you don't have a MLP.")
+ print(
+ f"\033[0;33mWarning\033[0;0m in {parser.model_config}: Oh no, you don't have a MLP."
+ )
  flag_warning += 1
  if parser.dryrun:
  if flag_critical != 0:
  print("Configuration check failed!")
  elif flag_warning != 0:
- print("Your job will probably run, but it may not follow your expectation.")
+ print(
+ "Your job will probably run, but it may not follow your expectations."
+ )
  else:
  print("Configuration check passed.")
  return
@@ -415,6 +450,15 @@ def main_script(version: str) -> None:
  mlp = None
  # ------- train -------
  if "train" in runtime_config:
+ import lightning as L
+ from torch.utils.data import DataLoader
+ from lightning.pytorch import loggers
+ from lightning.pytorch.callbacks import ModelCheckpoint
+ from bayesianflow_for_chem.train import Model
+
+ # ####### get plugins #######
+ plugin_file = runtime_config["train"].get("plugin_script", "")
+ plugins = _load_plugin(plugin_file)
  # ####### build scorer #######
  if (tokeniser_name == "smiles" or tokeniser_name == "safe") and runtime_config[
  "train"
@@ -428,30 +472,43 @@ def main_script(version: str) -> None:
  mol_tag = runtime_config["train"]["molecule_tag"]
  obj_tag = runtime_config["train"]["objective_tag"]
  dataset_file = runtime_config["train"]["dataset"]
- with open(dataset_file, "r") as db:
- _data = db.readlines()
- header = _data[0]
- mol_idx = []
- for i, tag in enumerate(header.replace("\n", "").split(",")):
- if tag == mol_tag:
- mol_idx.append(i)
- _data_len = []
- for i in _data[1:]:
- i = i.replace("\n", "").split(",")
- _mol = ".".join([i[j] for j in mol_idx])
- _data_len.append(tokeniser(_mol).shape[-1])
- lmax = max(_data_len)
- dataset = CSVData(dataset_file)
+ if plugins["max_sequence_length"]:
+ lmax = plugins["max_sequence_length"]
+ else:
+ with open(dataset_file, "r") as db:
+ _data = db.readlines()
+ _header = _data[0]
+ _mol_idx = []
+ for i, tag in enumerate(_header.replace("\n", "").split(",")):
+ if tag == mol_tag:
+ _mol_idx.append(i)
+ _data_len = []
+ for i in _data[1:]:
+ i = i.replace("\n", "").split(",")
+ _mol = ".".join([i[j] for j in _mol_idx])
+ _data_len.append(tokeniser(_mol).shape[-1])
+ lmax = max(_data_len)
+ del _data, _data_len, _header, _mol_idx # clear memory
+ if plugins["CustomData"] is not None:
+ dataset = plugins["CustomData"](dataset_file)
+ else:
+ dataset = CSVData(dataset_file)
  dataset.map(
  partial(_encode, mol_tag=mol_tag, obj_tag=obj_tag, tokeniser=tokeniser)
  )
  dataloader = DataLoader(
  dataset,
  runtime_config["train"]["batch_size"],
- True,
- num_workers=4,
- collate_fn=collate,
- persistent_workers=True,
+ True if plugins["shuffle"] is None else plugins["shuffle"],
+ num_workers=4 if plugins["num_workers"] is None else plugins["num_workers"],
+ collate_fn=(
+ collate if plugins["collate_fn"] is None else plugins["collate_fn"]
+ ),
+ persistent_workers=(
+ True
+ if (plugins["num_workers"] is None or plugins["num_workers"] > 0)
+ else False
+ ),
  )
  # ####### build trainer #######
  logger_name = runtime_config["train"]["logger_name"].lower()
@@ -530,6 +587,7 @@ def main_script(version: str) -> None:
  if "train" in runtime_config:
  bfn = model.model
  mlp = model.mlp
+ # ↓ added in v2.1.0; need to be compatible with old versions
  lora_scaling = runtime_config["inference"].get("lora_scaling", 1.0)
  # ####### strat inference #######
  bfn.semi_autoregressive = runtime_config["inference"]["semi_autoregressive"]
@@ -61,7 +61,7 @@ def load_vocab(
  }
 
 
- _DEFUALT_VOCAB = load_vocab(__filedir__ / "vocab.txt")
+ _DEFUALT_VOCAB = load_vocab(__filedir__ / "_data/vocab.txt")
  VOCAB_KEYS: List[str] = _DEFUALT_VOCAB["vocab_keys"]
  VOCAB_DICT: Dict[str, int] = _DEFUALT_VOCAB["vocab_dict"]
  VOCAB_COUNT: int = _DEFUALT_VOCAB["vocab_count"]
@@ -659,12 +659,91 @@ class ChemBFN(nn.Module):
  return (-logits.gather(-1, x[..., :1]).squeeze(-1)).mean()
 
  @staticmethod
- def reshape_y(y: Tensor) -> Tensor:
+ def _reshape(y: Tensor) -> Tensor:
  assert y.dim() <= 3 # this doesn't work if the model is frezen in JIT.
  if y.dim() == 2:
  return y[:, None, :]
  return y
 
+ def _process(
+ self,
+ theta: Tensor,
+ mask: Optional[Tuple[Tensor, Tensor]],
+ y: Optional[Tensor],
+ sample_step: int,
+ guidance_strength: float,
+ token_mask: Optional[Tensor],
+ ) -> Tuple[Tensor, Tensor]:
+ # BFN inference process.
+ #
+ # theta: piror distribution; shape: (n_b, n_t, n_vocab)
+ # mask: masked condition distribution; shape: (n_b, n_t, n_vocab)
+ # condition distribution mask; shape: (n_b, n_t, 1)
+ n_b = theta.shape[0]
+ if y is not None:
+ y = self._reshape(y)
+ for i in torch.linspace(1, sample_step, sample_step, device=self.beta.device):
+ t = (i - 1).view(1, 1, 1).repeat(n_b, 1, 1) / sample_step
+ p = self.discrete_output_distribution(theta, t, y, guidance_strength)
+ if token_mask is not None:
+ p = p.masked_fill_(token_mask, 0.0)
+ alpha = self.calc_discrete_alpha(t, t + 1 / sample_step)
+ e_k = nn.functional.one_hot(torch.argmax(p, -1), self.K).float()
+ mu = alpha * (self.K * e_k - 1)
+ sigma = (alpha * self.K).sqrt()
+ theta = (mu + sigma * torch.randn_like(mu)).exp() * theta
+ theta = theta / theta.sum(-1, True)
+ if mask is not None:
+ x_onehot, x_mask = mask
+ theta = x_onehot + (1 - x_mask) * theta
+ t_final = torch.ones((n_b, 1, 1), device=self.beta.device)
+ p = self.discrete_output_distribution(theta, t_final, y, guidance_strength)
+ entropy = -(p * p.log()).sum(-1).mean(-1)
+ if token_mask is not None:
+ p = p.masked_fill_(token_mask, 0.0)
+ return torch.argmax(p, -1), entropy
+
+ def _ode_process(
+ self,
+ z: Tensor,
+ mask: Optional[Tuple[Tensor, Tensor]],
+ y: Optional[Tensor],
+ sample_step: int,
+ guidance_strength: float,
+ token_mask: Optional[Tensor],
+ temperature: float,
+ ) -> Tuple[Tensor, Tensor]:
+ # ODE-solver engaged inference process.
+ #
+ # z: prior latent vector; shape: (n_b, n_t, n_vocab)
+ # mask: masked condition distribution; shape: (n_b, n_t, n_vocab)
+ # condition distribution mask; shape: (n_b, n_t, 1)
+ n_b = z.shape[0]
+ if y is not None:
+ y = self._reshape(y)
+ for i in torch.linspace(1, sample_step, sample_step, device=self.beta.device):
+ t = (i - 1).view(1, 1, 1).repeat(n_b, 1, 1) / sample_step
+ theta = softmax(z, -1)
+ if mask is not None:
+ x_onehot, x_mask = mask
+ theta = x_onehot + (1 - x_mask) * theta
+ beta = self.calc_beta(t + 1 / sample_step)
+ p = self.discrete_output_distribution(theta, t, y, guidance_strength)
+ if token_mask is not None:
+ p = p.masked_fill_(token_mask, 0.0)
+ u = torch.randn_like(z)
+ z = (self.K * p - 1) * beta + (self.K * beta * temperature).sqrt() * u
+ t_final = torch.ones((n_b, 1, 1), device=self.beta.device)
+ theta = softmax(z, -1)
+ if mask is not None:
+ x_onehot, x_mask = mask
+ theta = x_onehot + (1 - x_mask) * theta
+ p = self.discrete_output_distribution(theta, t_final, y, guidance_strength)
+ entropy = -(p * p.log()).sum(-1).mean(-1)
+ if token_mask is not None:
+ p = p.masked_fill_(token_mask, 0.0)
+ return torch.argmax(p, -1), entropy
+
 
  @torch.jit.export
  def sample(
@@ -680,44 +759,26 @@ class ChemBFN(nn.Module):
 
  :param batch_size: batch size
  :param sequence_size: max sequence length
- :param y: conditioning vector; shape: (n_b, 1, n_f) or (n_b, n_f)
+ :param y: conditioning vector; shape: (n_b, 1, n_f) or (n_b, n_f)
  :param sample_step: number of sampling steps
  :param guidance_strength: strength of conditional generation. It is not used if y is null.
  :param token_mask: token mask assigning unwanted token(s) with `True`;
- shape: (1, 1, n_vocab)
+ shape: (1, 1, n_vocab)
  :type batch_size: int
  :type sequence_size: int
  :type y: torch.Tensor | None
  :type sample_step: int
  :type guidance_strength: float
  :type token_mask: torch.Tensor | None
- :return: sampled token indices; shape: (n_b, n_t) \n
- entropy of the tokens; shape: (n_b)
+ :return: sampled token indices; shape: (n_b, n_t) \n
+ entropy of the tokens; shape: (n_b)
  :rtype: tuple
  """
  theta = (
  torch.ones((batch_size, sequence_size, self.K), device=self.beta.device)
  / self.K
  )
- if y is not None:
- y = self.reshape_y(y)
- for i in torch.linspace(1, sample_step, sample_step, device=self.beta.device):
- t = (i - 1).view(1, 1, 1).repeat(batch_size, 1, 1) / sample_step
- p = self.discrete_output_distribution(theta, t, y, guidance_strength)
- if token_mask is not None:
- p = p.masked_fill_(token_mask, 0.0)
- alpha = self.calc_discrete_alpha(t, t + 1 / sample_step)
- e_k = nn.functional.one_hot(torch.argmax(p, -1), self.K).float()
- mu = alpha * (self.K * e_k - 1)
- sigma = (alpha * self.K).sqrt()
- theta = (mu + sigma * torch.randn_like(mu)).exp() * theta
- theta = theta / theta.sum(-1, True)
- t_final = torch.ones((batch_size, 1, 1), device=self.beta.device)
- p = self.discrete_output_distribution(theta, t_final, y, guidance_strength)
- entropy = -(p * p.log()).sum(-1).mean(-1)
- if token_mask is not None:
- p = p.masked_fill_(token_mask, 0.0)
- return torch.argmax(p, -1), entropy
+ return self._process(theta, None, y, sample_step, guidance_strength, token_mask)
 
  @torch.jit.export
  def ode_sample(
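The public contract of `sample` is unchanged by the refactor: it still returns a `(tokens, entropy)` pair. A usage sketch, assuming `model` is an already loaded `ChemBFN` instance (argument names follow the docstring above; the numbers are arbitrary):

import torch

with torch.no_grad():
    tokens, entropy = model.sample(
        batch_size=16,
        sequence_size=64,
        y=None,                 # unconditional generation
        sample_step=100,
        guidance_strength=0.0,  # ignored when y is None
        token_mask=None,
    )
# tokens: (16, 64) token indices; entropy: (16,) per-sequence entropy, later used for sorting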
@@ -735,11 +796,11 @@ class ChemBFN(nn.Module):
 
  :param batch_size: batch size
  :param sequence_size: max sequence length
- :param y: conditioning vector; shape: (n_b, 1, n_f) or (n_b, n_f)
+ :param y: conditioning vector; shape: (n_b, 1, n_f) or (n_b, n_f)
  :param sample_step: number of sampling steps
  :param guidance_strength: strength of conditional generation. It is not used if y is null.
  :param token_mask: token mask assigning unwanted token(s) with `True`;
- shape: (1, 1, n_vocab)
+ shape: (1, 1, n_vocab)
  :param temperature: sampling temperature
  :type batch_size: int
  :type sequence_size: int
@@ -748,29 +809,14 @@ class ChemBFN(nn.Module):
  :type guidance_strength: float
  :type token_mask: torch.Tensor | None
  :type temperature: float
- :return: sampled token indices; shape: (n_b, n_t) \n
- entropy of the tokens; shape: (n_b)
+ :return: sampled token indices; shape: (n_b, n_t) \n
+ entropy of the tokens; shape: (n_b)
  :rtype: tuple
  """
  z = torch.zeros((batch_size, sequence_size, self.K), device=self.beta.device)
- if y is not None:
- y = self.reshape_y(y)
- for i in torch.linspace(1, sample_step, sample_step, device=self.beta.device):
- t = (i - 1).view(1, 1, 1).repeat(batch_size, 1, 1) / sample_step
- theta = torch.softmax(z, -1)
- beta = self.calc_beta(t + 1 / sample_step)
- p = self.discrete_output_distribution(theta, t, y, guidance_strength)
- if token_mask is not None:
- p = p.masked_fill_(token_mask, 0.0)
- u = torch.randn_like(z)
- z = (self.K * p - 1) * beta + (self.K * beta * temperature).sqrt() * u
- t_final = torch.ones((batch_size, 1, 1), device=self.beta.device)
- theta = torch.softmax(z, -1)
- p = self.discrete_output_distribution(theta, t_final, y, guidance_strength)
- entropy = -(p * p.log()).sum(-1).mean(-1)
- if token_mask is not None:
- p = p.masked_fill_(token_mask, 0.0)
- return torch.argmax(p, -1), entropy
+ return self._ode_process(
+ z, None, y, sample_step, guidance_strength, token_mask, temperature
+ )
 
  @torch.jit.export
  def inpaint(
@@ -800,30 +846,12 @@ class ChemBFN(nn.Module):
  :rtype: tuple
  """
  n_b, n_t = x.shape
- mask = (x != 0).float()[..., None]
+ x_mask = (x != 0).float()[..., None]
  theta = torch.ones((n_b, n_t, self.K), device=x.device) / self.K
- x_onehot = nn.functional.one_hot(x, self.K) * mask
- theta = x_onehot + (1 - mask) * theta
- if y is not None:
- y = self.reshape_y(y)
- for i in torch.linspace(1, sample_step, sample_step, device=x.device):
- t = (i - 1).view(1, 1, 1).repeat(n_b, 1, 1) / sample_step
- p = self.discrete_output_distribution(theta, t, y, guidance_strength)
- if token_mask is not None:
- p = p.masked_fill_(token_mask, 0.0)
- alpha = self.calc_discrete_alpha(t, t + 1 / sample_step)
- e_k = nn.functional.one_hot(torch.argmax(p, -1), self.K).float()
- mu = alpha * (self.K * e_k - 1)
- sigma = (alpha * self.K).sqrt()
- theta = (mu + sigma * torch.randn_like(mu)).exp() * theta
- theta = theta / theta.sum(-1, True)
- theta = x_onehot + (1 - mask) * theta
- t_final = torch.ones((n_b, 1, 1), device=x.device)
- p = self.discrete_output_distribution(theta, t_final, y, guidance_strength)
- entropy = -(p * p.log()).sum(-1).mean(-1)
- if token_mask is not None:
- p = p.masked_fill_(token_mask, 0.0)
- return torch.argmax(p, -1), entropy
+ x_onehot = nn.functional.one_hot(x, self.K) * x_mask
+ theta = x_onehot + (1 - x_mask) * theta
+ mask = (x_onehot, x_mask)
+ return self._process(theta, mask, y, sample_step, guidance_strength, token_mask)
 
  @torch.jit.export
  def ode_inpaint(
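As the rewritten body shows, `inpaint` treats token index 0 (the pad entry) as free to generate and every non-zero position as a fixed part of the scaffold; the `(x_onehot, x_mask)` pair is then re-applied after every `_process` step. A small sketch of preparing such a template (the index values are illustrative, not taken from the real vocabulary):

import torch

sequence_size = 32
scaffold_ids = [1, 24, 30]   # hypothetical fixed prefix, e.g. <start> plus two fixed tokens
x = torch.zeros((1, sequence_size), dtype=torch.long)
x[0, : len(scaffold_ids)] = torch.tensor(scaffold_ids)
# inside inpaint: x_mask = (x != 0) keeps these positions; the zero positions are filled by the model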
@@ -856,29 +884,13 @@ class ChemBFN(nn.Module):
  :rtype: tuple
  """
  n_b, n_t = x.shape
- mask = (x != 0).float()[..., None]
- x_onehot = nn.functional.one_hot(x, self.K) * mask
+ x_mask = (x != 0).float()[..., None]
+ x_onehot = nn.functional.one_hot(x, self.K) * x_mask
  z = torch.zeros((n_b, n_t, self.K), device=self.beta.device)
- if y is not None:
- y = self.reshape_y(y)
- for i in torch.linspace(1, sample_step, sample_step, device=self.beta.device):
- t = (i - 1).view(1, 1, 1).repeat(n_b, 1, 1) / sample_step
- theta = torch.softmax(z, -1)
- theta = x_onehot + (1 - mask) * theta
- beta = self.calc_beta(t + 1 / sample_step)
- p = self.discrete_output_distribution(theta, t, y, guidance_strength)
- if token_mask is not None:
- p = p.masked_fill_(token_mask, 0.0)
- u = torch.randn_like(z)
- z = (self.K * p - 1) * beta + (self.K * beta * temperature).sqrt() * u
- t_final = torch.ones((n_b, 1, 1), device=self.beta.device)
- theta = torch.softmax(z, -1)
- theta = x_onehot + (1 - mask) * theta
- p = self.discrete_output_distribution(theta, t_final, y, guidance_strength)
- entropy = -(p * p.log()).sum(-1).mean(-1)
- if token_mask is not None:
- p = p.masked_fill_(token_mask, 0.0)
- return torch.argmax(p, -1), entropy
+ mask = (x_onehot, x_mask)
+ return self._ode_process(
+ z, mask, y, sample_step, guidance_strength, token_mask, temperature
+ )
 
  @torch.jit.export
  def optimise(
@@ -908,28 +920,9 @@ class ChemBFN(nn.Module):
  entropy of the tokens; shape: (n_b)
  :rtype: tuple
  """
- n_b = x.shape[0]
  x_onehot = nn.functional.one_hot(x, self.K).float()
- theta = nn.functional.softmax(x_onehot, -1)
- if y is not None:
- y = self.reshape_y(y)
- for i in torch.linspace(1, sample_step, sample_step, device=x.device):
- t = (i - 1).view(1, 1, 1).repeat(n_b, 1, 1) / sample_step
- p = self.discrete_output_distribution(theta, t, y, guidance_strength)
- if token_mask is not None:
- p = p.masked_fill_(token_mask, 0.0)
- alpha = self.calc_discrete_alpha(t, t + 1 / sample_step)
- e_k = nn.functional.one_hot(torch.argmax(p, -1), self.K).float()
- mu = alpha * (self.K * e_k - 1)
- sigma = (alpha * self.K).sqrt()
- theta = (mu + sigma * torch.randn_like(mu)).exp() * theta
- theta = theta / theta.sum(-1, True)
- t_final = torch.ones((n_b, 1, 1), device=x.device)
- p = self.discrete_output_distribution(theta, t_final, y, guidance_strength)
- entropy = -(p * p.log()).sum(-1).mean(-1)
- if token_mask is not None:
- p = p.masked_fill_(token_mask, 0.0)
- return torch.argmax(p, -1), entropy
+ theta = softmax(x_onehot, -1)
+ return self._process(theta, None, y, sample_step, guidance_strength, token_mask)
 
  @torch.jit.export
  def ode_optimise(
@@ -961,26 +954,10 @@ class ChemBFN(nn.Module):
  entropy of the tokens; shape: (n_b)
  :rtype: tuple
  """
- n_b = x.shape[0]
  z = nn.functional.one_hot(x, self.K).float()
- if y is not None:
- y = self.reshape_y(y)
- for i in torch.linspace(1, sample_step, sample_step, device=self.beta.device):
- t = (i - 1).view(1, 1, 1).repeat(n_b, 1, 1) / sample_step
- theta = torch.softmax(z, -1)
- beta = self.calc_beta(t + 1 / sample_step)
- p = self.discrete_output_distribution(theta, t, y, guidance_strength)
- if token_mask is not None:
- p = p.masked_fill_(token_mask, 0.0)
- u = torch.randn_like(z)
- z = (self.K * p - 1) * beta + (self.K * beta * temperature).sqrt() * u
- t_final = torch.ones((n_b, 1, 1), device=self.beta.device)
- theta = torch.softmax(z, -1)
- p = self.discrete_output_distribution(theta, t_final, y, guidance_strength)
- entropy = -(p * p.log()).sum(-1).mean(-1)
- if token_mask is not None:
- p = p.masked_fill_(token_mask, 0.0)
- return torch.argmax(p, -1), entropy
+ return self._ode_process(
+ z, None, y, sample_step, guidance_strength, token_mask, temperature
+ )
 
  def inference(
  self, x: Tensor, mlp: MLP, embed_fn: Optional[Callable[[Tensor], Tensor]] = None
@@ -1154,22 +1131,6 @@ class EnsembleChemBFN(ChemBFN):
  module.lora_dropout = None
  v.lora_enabled = False
 
- def construct_y(
- self, c: Union[List[Tensor], Dict[str, Tensor]]
- ) -> Dict[str, Tensor]:
- assert (
- isinstance(c, dict) is self._label_is_dict
- ), f"`c` should be a {'`dict` instance' if self._label_is_dict else '`list` instance'} but got {type(c)} instand."
- out: Dict[str, Tensor] = {}
- if isinstance(c, list):
- c = dict(zip([f"val_{i}" for i in range(len(c))], c))
- for name, model in self.cond_heads.items():
- y = model.forward(c[name])
- if y.dim() == 2:
- y = y[:, None, :]
- out[name] = y
- return out
-
  def discrete_output_distribution(
  self, theta: Tensor, t: Tensor, y: Dict[str, Tensor], w: float
  ) -> Tensor:
@@ -1204,8 +1165,24 @@ class EnsembleChemBFN(ChemBFN):
  p_cond += p_cond_ * self.adapter_weights[name]
  return softmax((1 + w) * p_cond - w * p_uncond, -1)
 
+ def _map_to_dict(
+ self, c: Union[List[Tensor], Dict[str, Tensor]]
+ ) -> Dict[str, Tensor]:
+ assert (
+ isinstance(c, dict) is self._label_is_dict
+ ), f"`c` should be a {'`dict` instance' if self._label_is_dict else '`list` instance'} but got {type(c)} instand."
+ out: Dict[str, Tensor] = {}
+ if isinstance(c, list):
+ c = dict(zip([f"val_{i}" for i in range(len(c))], c))
+ for name, model in self.cond_heads.items():
+ y = model.forward(c[name])
+ if y.dim() == 2:
+ y = y[:, None, :]
+ out[name] = y
+ return out
+
  @staticmethod
- def reshape_y(y: Dict[str, Tensor]) -> Dict[str, Tensor]:
+ def _reshape(y: Dict[str, Tensor]) -> Dict[str, Tensor]:
  for k in y:
  assert y[k].dim() <= 3
  if y[k].dim() == 2:
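`_map_to_dict` is the former `construct_y` renamed to a private helper: a list of condition tensors is re-keyed as `val_0`, `val_1`, … before being passed through the per-adapter condition heads, while a dict must already use the names the ensemble's heads were registered with. An illustrative sketch (the head names and tensor shapes are assumptions for the example, not the package's API):

import torch

conditions_as_list = [torch.randn(8, 1), torch.randn(8, 1)]                 # mapped to {"val_0": ..., "val_1": ...}
conditions_as_dict = {"logp": torch.randn(8, 1), "qed": torch.randn(8, 1)}  # keys must match cond_heads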
@@ -1241,7 +1218,7 @@ class EnsembleChemBFN(ChemBFN):
  entropy of the tokens; shape: (n_b)
  :rtype: tuple
  """
- y = self.construct_y(conditions)
+ y = self._map_to_dict(conditions)
  return super().sample(
  batch_size, sequence_size, y, sample_step, guidance_strength, token_mask
  )
@@ -1278,7 +1255,7 @@ class EnsembleChemBFN(ChemBFN):
  entropy of the tokens; shape: (n_b)
  :rtype: tuple
  """
- y = self.construct_y(conditions)
+ y = self._map_to_dict(conditions)
  return super().ode_sample(
  batch_size,
  sequence_size,
@@ -1315,7 +1292,7 @@ class EnsembleChemBFN(ChemBFN):
  entropy of the tokens; shape: (n_b)
  :rtype: tuple
  """
- y = self.construct_y(conditions)
+ y = self._map_to_dict(conditions)
  return super().inpaint(x, y, sample_step, guidance_strength, token_mask)
 
  @torch.inference_mode()
@@ -1347,7 +1324,7 @@ class EnsembleChemBFN(ChemBFN):
  entropy of the tokens; shape: (n_b)
  :rtype: tuple
  """
- y = self.construct_y(conditions)
+ y = self._map_to_dict(conditions)
  return super().ode_inpaint(
  x, y, sample_step, guidance_strength, token_mask, temperature
  )
@@ -1380,7 +1357,7 @@ class EnsembleChemBFN(ChemBFN):
  entropy of the tokens; shape: (n_b)
  :rtype: tuple
  """
- y = self.construct_y(conditions)
+ y = self._map_to_dict(conditions)
  return super().optimise(x, y, sample_step, guidance_strength, token_mask)
 
  @torch.inference_mode()
@@ -1412,7 +1389,7 @@ class EnsembleChemBFN(ChemBFN):
  entropy of the tokens; shape: (n_b)
  :rtype: tuple
  """
- y = self.construct_y(conditions)
+ y = self._map_to_dict(conditions)
  return super().ode_optimise(
  x, y, sample_step, guidance_strength, token_mask, temperature
  )
@@ -9,7 +9,6 @@ import warnings
  from pathlib import Path
  from typing import List, Dict, Tuple, Union, Optional
  import torch
- import colorama
  import numpy as np
  from torch import cuda, Tensor, softmax
  from torch.utils.data import DataLoader
@@ -24,15 +23,7 @@ from rdkit.Chem import (
  AddHs,
  Mol,
  )
- from rdkit.Chem.Scaffolds.MurckoScaffold import MurckoScaffoldSmiles # type: ignore
- from sklearn.metrics import (
- roc_auc_score,
- auc,
- precision_recall_curve,
- r2_score,
- mean_absolute_error,
- root_mean_squared_error,
- )
+ from rdkit.Chem.Scaffolds.MurckoScaffold import MurckoScaffoldSmiles
  from .data import VOCAB_KEYS
  from .model import ChemBFN, MLP, EnsembleChemBFN
 
@@ -45,6 +36,68 @@ def _find_device() -> torch.device:
  return torch.device("cpu")
 
 
+ def _parse_and_assert_param(
+ model: Union[ChemBFN, EnsembleChemBFN],
+ y: Optional[Union[Tensor, Dict[str, Tensor], List[Tensor]]],
+ method: str,
+ ) -> Optional[float]:
+ assert method.split(":")[0].lower() in ("ode", "bfn")
+ if isinstance(model, EnsembleChemBFN):
+ assert y is not None, "conditioning is required while using an ensemble model."
+ assert isinstance(y, list) or isinstance(y, dict)
+ else:
+ assert isinstance(y, Tensor) or (y is None)
+ if "ode" in method.lower():
+ tp = float(method.split(":")[-1])
+ assert tp > 0, "Sampling temperature should be higher than 0."
+ return tp
+ return None
+
+
+ def _map_to_device(
+ y: Optional[Union[Tensor, Dict[str, Tensor], List[Tensor]]],
+ device: Union[str, torch.device],
+ ) -> Optional[Union[Tensor, Dict[str, Tensor], List[Tensor]]]:
+ if y is not None:
+ if isinstance(y, Tensor):
+ y = y.to(device)
+ elif isinstance(y, list):
+ y = [i.to(device) for i in y]
+ elif isinstance(y, dict):
+ y = {k: v.to(device) for k, v in y.items()}
+ else:
+ raise NotImplementedError
+ return y
+
+
+ def _build_token_mask(
+ allowed_tokens: Union[str, List[str]],
+ vocab_keys: List[str],
+ device: Union[str, torch.tensor],
+ ) -> Optional[Tensor]:
+ if isinstance(allowed_tokens, list):
+ token_mask = [0 if i in allowed_tokens else 1 for i in vocab_keys]
+ token_mask = torch.tensor([[token_mask]], dtype=torch.bool).to(device)
+ else:
+ token_mask = None
+ return token_mask
+
+
+ def _token_to_seq(
+ tokens: Tensor, entropy: Tensor, vocab_keys: List[str], separator: str, sort: bool
+ ) -> List[str]:
+ if sort:
+ sorted_idx = entropy.argsort(stable=True)
+ tokens = tokens[sorted_idx]
+ return [
+ separator.join([vocab_keys[i] for i in j])
+ .split("<start>" + separator)[-1]
+ .split(separator + "<end>")[0]
+ .replace("<pad>", "")
+ for j in tokens
+ ]
+
+
  @torch.no_grad()
  def test(
  model: ChemBFN,
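These four helpers factor out logic that `sample`, `inpaint`, and `optimise` previously duplicated: `method` is either "BFN" or "ODE:<temperature>" (the temperature is parsed after the colon and must be positive), and a list of `allowed_tokens` becomes a boolean mask over the vocabulary with `True` marking banned entries. A toy illustration of the mask construction (the six-token vocabulary is made up):

import torch

vocab_keys = ["<pad>", "<start>", "<end>", "C", "O", "N"]
allowed_tokens = ["<pad>", "<start>", "<end>", "C", "O"]
mask = [0 if key in allowed_tokens else 1 for key in vocab_keys]  # -> [0, 0, 0, 0, 0, 1]
token_mask = torch.tensor([[mask]], dtype=torch.bool)             # shape (1, 1, n_vocab); bans "N"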
@@ -86,6 +139,12 @@ def test(
  predict_y.append(y_hat.detach().to("cpu"))
  predict_y, label_y = torch.cat(predict_y, 0), torch.cat(label_y, 0).split(1, -1)
  if mode == "regression":
+ from sklearn.metrics import (
+ r2_score,
+ mean_absolute_error,
+ root_mean_squared_error,
+ )
+
  predict_y = [
  predict[label_y[i] != torch.inf]
  for (i, predict) in enumerate(predict_y.split(1, -1))
@@ -99,6 +158,8 @@ def test(
  r2 = [r2_score(label, predict) for (label, predict) in y_zipped]
  return {"MAE": mae, "RMSE": rmse, "R^2": r2}
  if mode == "classification":
+ from sklearn.metrics import roc_auc_score, auc, precision_recall_curve
+
  n_c = len(label_y)
  predict_y = predict_y.chunk(n_c, -1)
  y_zipped = list(zip(label_y, predict_y))
@@ -142,7 +203,6 @@ def split_dataset(
  assert file.endswith(".csv")
  assert len(split_ratio) == 3
  assert method in ("random", "scaffold")
- colorama.just_fix_windows_console()
  with open(file, "r") as f:
  data = list(csv.reader(f))
  header = data[0]
@@ -167,10 +227,8 @@ def split_dataset(
  # compute Bemis-Murcko scaffold
  if len(smiles_idx) > 1:
  warnings.warn(
- "\033[32;1m"
  f"We found {len(smiles_idx)} SMILES strings in a row!"
- " Only the first SMILES will be used to compute the molecular scaffold."
- "\033[0m",
+ " Only the first SMILES will be used to compute the molecular scaffold.",
  stacklevel=2,
  )
  try:
@@ -197,10 +255,10 @@ def split_dataset(
  with open(file.replace(".csv", "_test.csv"), "w", newline="") as fte:
  writer = csv.writer(fte)
  writer.writerows([header] + test_set)
- with open(file.replace(".csv", "_val.csv"), "w", newline="") as fva:
- writer = csv.writer(fva)
- writer.writerows([header] + val_set)
- colorama.deinit()
+ if val_set:
+ with open(file.replace(".csv", "_val.csv"), "w", newline="") as fva:
+ writer = csv.writer(fva)
+ writer.writerows([header] + val_set)
 
 
  @torch.no_grad()
@@ -250,32 +308,12 @@ def sample(
  :return: a list of generated molecular strings
  :rtype: list
  """
- assert method.split(":")[0].lower() in ("ode", "bfn")
- if isinstance(model, EnsembleChemBFN):
- assert y is not None, "conditioning is required while using an ensemble model."
- assert isinstance(y, list) or isinstance(y, dict)
- else:
- assert isinstance(y, Tensor) or y is None
- if device is None:
- device = _find_device()
+ tp = _parse_and_assert_param(model, y, method)
+ device = _find_device() if device is None else device
  model.to(device).eval()
- if y is not None:
- if isinstance(y, Tensor):
- y = y.to(device)
- elif isinstance(y, list):
- y = [i.to(device) for i in y]
- elif isinstance(y, dict):
- y = {k: v.to(device) for k, v in y.items()}
- else:
- raise NotImplementedError
- if isinstance(allowed_tokens, list):
- token_mask = [0 if i in allowed_tokens else 1 for i in vocab_keys]
- token_mask = torch.tensor([[token_mask]], dtype=torch.bool).to(device)
- else:
- token_mask = None
- if "ode" in method.lower():
- tp = float(method.split(":")[-1])
- assert tp > 0, "Sampling temperature should be higher than 0."
+ y = _map_to_device(y, device)
+ token_mask = _build_token_mask(allowed_tokens, vocab_keys, device)
+ if tp:
  tokens, entropy = model.ode_sample(
  batch_size, sequence_size, y, sample_step, guidance_strength, token_mask, tp
  )
@@ -283,16 +321,7 @@ def sample(
  tokens, entropy = model.sample(
  batch_size, sequence_size, y, sample_step, guidance_strength, token_mask
  )
- if sort:
- sorted_idx = entropy.argsort(stable=True)
- tokens = tokens[sorted_idx]
- return [
- seperator.join([vocab_keys[i] for i in j])
- .split("<start>" + seperator)[-1]
- .split(seperator + "<end>")[0]
- .replace("<pad>", "")
- for j in tokens
- ]
+ return _token_to_seq(tokens, entropy, vocab_keys, seperator, sort)
 
 
  @torch.no_grad()
@@ -339,33 +368,13 @@ def inpaint(
  :return: a list of generated molecular strings
  :rtype: list
  """
- assert method.split(":")[0].lower() in ("ode", "bfn")
- if isinstance(model, EnsembleChemBFN):
- assert y is not None, "conditioning is required while using an ensemble model."
- assert isinstance(y, list) or isinstance(y, dict)
- else:
- assert isinstance(y, Tensor) or y is None
- if device is None:
- device = _find_device()
+ tp = _parse_and_assert_param(model, y, method)
+ device = _find_device() if device is None else device
  model.to(device).eval()
  x = x.to(device)
- if y is not None:
- if isinstance(y, Tensor):
- y = y.to(device)
- elif isinstance(y, list):
- y = [i.to(device) for i in y]
- elif isinstance(y, dict):
- y = {k: v.to(device) for k, v in y.items()}
- else:
- raise NotImplementedError
- if isinstance(allowed_tokens, list):
- token_mask = [0 if i in allowed_tokens else 1 for i in vocab_keys]
- token_mask = torch.tensor([[token_mask]], dtype=torch.bool).to(device)
- else:
- token_mask = None
- if "ode" in method.lower():
- tp = float(method.split(":")[-1])
- assert tp > 0, "Sampling temperature should be higher than 0."
+ y = _map_to_device(y, device)
+ token_mask = _build_token_mask(allowed_tokens, vocab_keys, device)
+ if tp:
  tokens, entropy = model.ode_inpaint(
  x, y, sample_step, guidance_strength, token_mask, tp
  )
@@ -373,16 +382,7 @@ def inpaint(
  tokens, entropy = model.inpaint(
  x, y, sample_step, guidance_strength, token_mask
  )
- if sort:
- sorted_idx = entropy.argsort(stable=True)
- tokens = tokens[sorted_idx]
- return [
- separator.join([vocab_keys[i] for i in j])
- .split("<start>" + separator)[-1]
- .split(separator + "<end>")[0]
- .replace("<pad>", "")
- for j in tokens
- ]
+ return _token_to_seq(tokens, entropy, vocab_keys, separator, sort)
 
 
  @torch.no_grad()
@@ -429,33 +429,13 @@ def optimise(
  :return: a list of generated molecular strings
  :rtype: list
  """
- assert method.split(":")[0].lower() in ("ode", "bfn")
- if isinstance(model, EnsembleChemBFN):
- assert y is not None, "conditioning is required while using an ensemble model."
- assert isinstance(y, list) or isinstance(y, dict)
- else:
- assert isinstance(y, Tensor) or y is None
- if device is None:
- device = _find_device()
+ tp = _parse_and_assert_param(model, y, method)
+ device = _find_device() if device is None else device
  model.to(device).eval()
  x = x.to(device)
- if y is not None:
- if isinstance(y, Tensor):
- y = y.to(device)
- elif isinstance(y, list):
- y = [i.to(device) for i in y]
- elif isinstance(y, dict):
- y = {k: v.to(device) for k, v in y.items()}
- else:
- raise NotImplementedError
- if isinstance(allowed_tokens, list):
- token_mask = [0 if i in allowed_tokens else 1 for i in vocab_keys]
- token_mask = torch.tensor([[token_mask]], dtype=torch.bool).to(device)
- else:
- token_mask = None
- if "ode" in method.lower():
- tp = float(method.split(":")[-1])
- assert tp > 0, "Sampling temperature should be higher than 0."
+ y = _map_to_device(y, device)
+ token_mask = _build_token_mask(allowed_tokens, vocab_keys, device)
+ if tp:
  tokens, entropy = model.ode_optimise(
  x, y, sample_step, guidance_strength, token_mask, tp
  )
@@ -463,16 +443,7 @@ def optimise(
  tokens, entropy = model.optimise(
  x, y, sample_step, guidance_strength, token_mask
  )
- if sort:
- sorted_idx = entropy.argsort(stable=True)
- tokens = tokens[sorted_idx]
- return [
- separator.join([vocab_keys[i] for i in j])
- .split("<start>" + separator)[-1]
- .split(separator + "<end>")[0]
- .replace("<pad>", "")
- for j in tokens
- ]
+ return _token_to_seq(tokens, entropy, vocab_keys, separator, sort)
 
 
  def quantise_model_(model: ChemBFN) -> None:
@@ -8,7 +8,6 @@ from typing import Dict, Tuple, Union, Optional
  import torch
  import torch.optim as op
  import torch.nn.functional as F
- from loralib import lora_state_dict, mark_only_lora_as_trainable
  from torch import Tensor
  from torch.optim.lr_scheduler import ReduceLROnPlateau
  from lightning import LightningModule
@@ -55,6 +54,8 @@ class Model(LightningModule):
  self.scorer = scorer
  self.save_hyperparameters(hparam, ignore=["model", "mlp", "scorer"])
  if model.lora_enabled:
+ from loralib import mark_only_lora_as_trainable
+
  mark_only_lora_as_trainable(self.model)
  self.use_scorer = self.scorer is not None
 
@@ -107,6 +108,8 @@ class Model(LightningModule):
  :rtype: None
  """
  if self.model.lora_enabled:
+ from loralib import lora_state_dict
+
  torch.save(
  {
  "lora_nn": lora_state_dict(self.model),
@@ -152,6 +155,8 @@ class Regressor(LightningModule):
  self.model.requires_grad_(not hparam["freeze"])
  self.save_hyperparameters(hparam, ignore=["model", "mlp"])
  if model.lora_enabled:
+ from loralib import mark_only_lora_as_trainable
+
  mark_only_lora_as_trainable(self.model)
  assert hparam["mode"] in ("regression", "classification")
 
@@ -231,6 +236,8 @@ class Regressor(LightningModule):
  )
  if not self.hparams.freeze:
  if self.model.lora_enabled:
+ from loralib import lora_state_dict
+
  torch.save(
  {
  "lora_nn": lora_state_dict(self.model),
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: bayesianflow_for_chem
- Version: 2.1.0
+ Version: 2.2.1
  Summary: Bayesian flow network framework for Chemistry
  Home-page: https://augus1999.github.io/bayesian-flow-network-for-chemistry/
  Author: Nianze A. Tao
@@ -54,12 +54,14 @@ This is the repository of the PyTorch implementation of ChemBFN model.
 
  [![PyPI](https://img.shields.io/pypi/v/bayesianflow-for-chem?color=ff69b4)](https://pypi.org/project/bayesianflow-for-chem/)
  ![pytest](https://github.com/Augus1999/bayesian-flow-network-for-chemistry/actions/workflows/pytest.yml/badge.svg)
+ [![document](https://github.com/Augus1999/bayesian-flow-network-for-chemistry/actions/workflows/pages/pages-build-deployment/badge.svg)](https://augus1999.github.io/bayesian-flow-network-for-chemistry/)
 
  ## Features
 
  ChemBFN provides the state-of-the-art functionalities of
  * SMILES or SELFIES-based *de novo* molecule generation
  * Protein sequence *de novo* generation
+ * Template optimisation (mol2mol)
  * Classifier-free guidance conditional generation (single or multi-objective optimisation)
  * Context-guided conditional generation (inpaint)
  * Outstanding out-of-distribution chemical space sampling
@@ -71,6 +73,7 @@ in an all-in-one-model style.
 
  ## News
 
+ * [09/10/2025] A web app [`chembfn_webui`](https://github.com/Augus1999/ChemBFN-WebUI) for hosting ChemBFN models is available on [PyPI](https://pypi.org/project/chembfn-webui/).
  * [30/01/2025] The package `bayesianflow_for_chem` is available on [PyPI](https://pypi.org/project/bayesianflow-for-chem/).
  * [21/01/2025] Our first paper has been accepted by [JCIM](https://pubs.acs.org/doi/10.1021/acs.jcim.4c01792).
  * [17/12/2024] The second paper of out-of-distribution generation is available on [arxiv.org](https://arxiv.org/abs/2412.11439).
@@ -0,0 +1,15 @@
+ bayesianflow_for_chem/__init__.py,sha256=l6twtXTdnix9SC8oL0Aqr0dYJ1uBUmMtk2QbUaRrhkE,642
+ bayesianflow_for_chem/cli.py,sha256=jZPBUVwOl4qnjUMb3Lavu6owv1RP9jGdrDlM7KDA99c,26600
+ bayesianflow_for_chem/data.py,sha256=RM3YaZUkTJ4-RNLH37GwhM4sRddOOusrC996ZPAe1lI,6590
+ bayesianflow_for_chem/model.py,sha256=M35G4u4mX4btl9vOK3Iqs6yOSuIKI_OoCTmLhmjbwNk,57559
+ bayesianflow_for_chem/scorer.py,sha256=gQFUlkyxitch02ntqcRh1ZS8aondKLynW5U6NfTQTb4,4084
+ bayesianflow_for_chem/spectra.py,sha256=Ba9ib1aDvTtDYbH3b4d-lIty3ZSQMu7jwehuV2KmhwA,1781
+ bayesianflow_for_chem/tool.py,sha256=VQo6xShwwmJbBsC-xTRL5ysCI7OLoOpmwQ-A7AuAPI8,23587
+ bayesianflow_for_chem/train.py,sha256=7AU0A-eZwzSYsLyIe3OxGTNWPnhGpHmVUaQLplV2Fn8,9886
+ bayesianflow_for_chem/_data/vocab.txt,sha256=HgtAZmpWYk4y8PqEVC4vqut1vE75DfRKE_10s2UW0rU,790
+ bayesianflow_for_chem-2.2.1.dist-info/licenses/LICENSE,sha256=hIahDEOTzuHCU5J2nd07LWwkLW7Hko4UFO__ffsvB-8,34523
+ bayesianflow_for_chem-2.2.1.dist-info/METADATA,sha256=YKuXgdklakB9DH4cFpTy3EbYdmo8dz9BgQQ4bI9CVQs,6476
+ bayesianflow_for_chem-2.2.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+ bayesianflow_for_chem-2.2.1.dist-info/entry_points.txt,sha256=N63RMoJsr8rxuKxc7Fj802SL8J5AlpCoPkS8E3IFPLI,54
+ bayesianflow_for_chem-2.2.1.dist-info/top_level.txt,sha256=KHsanI3BMCt8D9Qpze2ycrF6nMa3PyojgO6eS1c8kco,22
+ bayesianflow_for_chem-2.2.1.dist-info/RECORD,,
@@ -1,15 +0,0 @@
- bayesianflow_for_chem/__init__.py,sha256=bmqERRnnmyK7UgUYn3BFH_YipKfTxNNjRzu9__RVid4,612
- bayesianflow_for_chem/cli.py,sha256=A1cFz6jhpLEpbK9r8GxCLdbnPCzQ4RrsavLKg_lssVg,24208
- bayesianflow_for_chem/data.py,sha256=Pl0gGWHmMKTKHpsxznvLgYPCwwlLNL7nqH19Vipjkxs,6584
- bayesianflow_for_chem/model.py,sha256=UW5hfAofYK9dH9euDPYWfJedVMRFxk8WtY427fObf70,59641
- bayesianflow_for_chem/scorer.py,sha256=gQFUlkyxitch02ntqcRh1ZS8aondKLynW5U6NfTQTb4,4084
- bayesianflow_for_chem/spectra.py,sha256=Ba9ib1aDvTtDYbH3b4d-lIty3ZSQMu7jwehuV2KmhwA,1781
- bayesianflow_for_chem/tool.py,sha256=bqoIMas8bmcjYBghuQWLh75Eq8ZlG6mh9ZeDzWGOmuw,24790
- bayesianflow_for_chem/train.py,sha256=jYkhSguW50lrcTEydCQ20yig_mmc1j7WH9KmVwBCTAo,9727
- bayesianflow_for_chem/vocab.txt,sha256=HgtAZmpWYk4y8PqEVC4vqut1vE75DfRKE_10s2UW0rU,790
- bayesianflow_for_chem-2.1.0.dist-info/licenses/LICENSE,sha256=hIahDEOTzuHCU5J2nd07LWwkLW7Hko4UFO__ffsvB-8,34523
- bayesianflow_for_chem-2.1.0.dist-info/METADATA,sha256=ctwor5jnPCmAo-1wuIGkX70aqkXj-8YQxr27AJOdEjM,6057
- bayesianflow_for_chem-2.1.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
- bayesianflow_for_chem-2.1.0.dist-info/entry_points.txt,sha256=N63RMoJsr8rxuKxc7Fj802SL8J5AlpCoPkS8E3IFPLI,54
- bayesianflow_for_chem-2.1.0.dist-info/top_level.txt,sha256=KHsanI3BMCt8D9Qpze2ycrF6nMa3PyojgO6eS1c8kco,22
- bayesianflow_for_chem-2.1.0.dist-info/RECORD,,