bayesianflow-for-chem 1.2.6__tar.gz → 1.3.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


Files changed (18)
  1. {bayesianflow_for_chem-1.2.6 → bayesianflow_for_chem-1.3.0}/PKG-INFO +5 -5
  2. {bayesianflow_for_chem-1.2.6 → bayesianflow_for_chem-1.3.0}/README.md +1 -1
  3. {bayesianflow_for_chem-1.2.6 → bayesianflow_for_chem-1.3.0}/bayesianflow_for_chem/__init__.py +3 -3
  4. {bayesianflow_for_chem-1.2.6 → bayesianflow_for_chem-1.3.0}/bayesianflow_for_chem/data.py +2 -2
  5. {bayesianflow_for_chem-1.2.6 → bayesianflow_for_chem-1.3.0}/bayesianflow_for_chem/model.py +392 -26
  6. {bayesianflow_for_chem-1.2.6 → bayesianflow_for_chem-1.3.0}/bayesianflow_for_chem/scorer.py +1 -1
  7. {bayesianflow_for_chem-1.2.6 → bayesianflow_for_chem-1.3.0}/bayesianflow_for_chem/tool.py +240 -147
  8. {bayesianflow_for_chem-1.2.6 → bayesianflow_for_chem-1.3.0}/bayesianflow_for_chem/train.py +5 -3
  9. {bayesianflow_for_chem-1.2.6 → bayesianflow_for_chem-1.3.0}/bayesianflow_for_chem.egg-info/PKG-INFO +5 -5
  10. {bayesianflow_for_chem-1.2.6 → bayesianflow_for_chem-1.3.0}/pyproject.toml +1 -1
  11. {bayesianflow_for_chem-1.2.6 → bayesianflow_for_chem-1.3.0}/setup.py +2 -2
  12. {bayesianflow_for_chem-1.2.6 → bayesianflow_for_chem-1.3.0}/LICENSE +0 -0
  13. {bayesianflow_for_chem-1.2.6 → bayesianflow_for_chem-1.3.0}/bayesianflow_for_chem/vocab.txt +0 -0
  14. {bayesianflow_for_chem-1.2.6 → bayesianflow_for_chem-1.3.0}/bayesianflow_for_chem.egg-info/SOURCES.txt +0 -0
  15. {bayesianflow_for_chem-1.2.6 → bayesianflow_for_chem-1.3.0}/bayesianflow_for_chem.egg-info/dependency_links.txt +0 -0
  16. {bayesianflow_for_chem-1.2.6 → bayesianflow_for_chem-1.3.0}/bayesianflow_for_chem.egg-info/requires.txt +0 -0
  17. {bayesianflow_for_chem-1.2.6 → bayesianflow_for_chem-1.3.0}/bayesianflow_for_chem.egg-info/top_level.txt +0 -0
  18. {bayesianflow_for_chem-1.2.6 → bayesianflow_for_chem-1.3.0}/setup.cfg +0 -0
@@ -1,16 +1,15 @@
1
- Metadata-Version: 2.2
1
+ Metadata-Version: 2.4
2
2
  Name: bayesianflow_for_chem
3
- Version: 1.2.6
3
+ Version: 1.3.0
4
4
  Summary: Bayesian flow network framework for Chemistry
5
5
  Home-page: https://augus1999.github.io/bayesian-flow-network-for-chemistry/
6
6
  Author: Nianze A. Tao
7
7
  Author-email: tao-nianze@hiroshima-u.ac.jp
8
- License: AGPL-3.0 licence
8
+ License: AGPL-3.0-or-later
9
9
  Project-URL: Source, https://github.com/Augus1999/bayesian-flow-network-for-chemistry
10
10
  Keywords: Chemistry,CLM,ChemBFN
11
11
  Classifier: Development Status :: 5 - Production/Stable
12
12
  Classifier: Intended Audience :: Science/Research
13
- Classifier: License :: OSI Approved :: GNU Affero General Public License v3
14
13
  Classifier: Natural Language :: English
15
14
  Classifier: Programming Language :: Python :: 3
16
15
  Classifier: Programming Language :: Python :: 3.9
@@ -39,6 +38,7 @@ Dynamic: description-content-type
39
38
  Dynamic: home-page
40
39
  Dynamic: keywords
41
40
  Dynamic: license
41
+ Dynamic: license-file
42
42
  Dynamic: project-url
43
43
  Dynamic: provides-extra
44
44
  Dynamic: requires-dist
@@ -86,7 +86,7 @@ You can find example scripts in [📁example](./example) folder.
86
86
 
87
87
  ## Pre-trained Model
88
88
 
89
- You can find pretrained models in [release](https://github.com/Augus1999/bayesian-flow-network-for-chemistry/releases) or on our [🤗Hugging Face model page](https://huggingface.co/suenoomozawa/ChemBFN).
89
+ You can find pretrained models on our [🤗Hugging Face model page](https://huggingface.co/suenoomozawa/ChemBFN).
90
90
 
91
91
  ## Dataset Handling
92
92
 
@@ -39,7 +39,7 @@ You can find example scripts in [📁example](./example) folder.
39
39
 
40
40
  ## Pre-trained Model
41
41
 
42
- You can find pretrained models in [release](https://github.com/Augus1999/bayesian-flow-network-for-chemistry/releases) or on our [🤗Hugging Face model page](https://huggingface.co/suenoomozawa/ChemBFN).
42
+ You can find pretrained models on our [🤗Hugging Face model page](https://huggingface.co/suenoomozawa/ChemBFN).
43
43
 
44
44
  ## Dataset Handling
45
45
 
@@ -4,8 +4,8 @@
4
4
  ChemBFN package.
5
5
  """
6
6
  from . import data, tool, train, scorer
7
- from .model import ChemBFN, MLP
7
+ from .model import ChemBFN, MLP, EnsembleChemBFN
8
8
 
9
- __all__ = ["data", "tool", "train", "scorer", "ChemBFN", "MLP"]
10
- __version__ = "1.2.6"
9
+ __all__ = ["data", "tool", "train", "scorer", "ChemBFN", "MLP", "EnsembleChemBFN"]
10
+ __version__ = "1.3.0"
11
11
  __author__ = "Nianze A. Tao (Omozawa Sueno)"
@@ -55,7 +55,7 @@ aa_regex = re.compile(AA_REGEX_PATTERN)
55
55
 
56
56
 
57
57
  def load_vocab(
58
- vocab_file: Union[str, Path]
58
+ vocab_file: Union[str, Path],
59
59
  ) -> Dict[str, Union[int, List[str], Dict[str, int]]]:
60
60
  """
61
61
  Load vocabulary from source file.
@@ -108,7 +108,7 @@ def geo2vec(geo2seq: str) -> List[int]:
108
108
  """
109
109
  Geo2Seq tokenisation using a dataset-independent regex pattern.
110
110
 
111
- :param geo2seq: Geo2Seq string
111
+ :param geo2seq: `GEO2SEQ` string
112
112
  :type geo2seq: str
113
113
  :return: tokens w/o `<start>` and `<end>`
114
114
  :rtype: list
@@ -4,7 +4,8 @@
4
4
  Define Bayesian Flow Network for Chemistry (ChemBFN) model.
5
5
  """
6
6
  from pathlib import Path
7
- from typing import List, Tuple, Optional, Union
7
+ from copy import deepcopy
8
+ from typing import List, Tuple, Dict, Optional, Union, Callable
8
9
  import torch
9
10
  import torch.nn as nn
10
11
  from torch import Tensor
@@ -592,6 +593,13 @@ class ChemBFN(nn.Module):
592
593
  x, logits = torch.broadcast_tensors(x[..., None], logits)
593
594
  return (-logits.gather(-1, x[..., :1]).squeeze(-1)).mean()
594
595
 
596
+ @staticmethod
597
+ def reshape_y(y: Tensor) -> Tensor:
598
+ assert y.dim() <= 3  # this doesn't work if the model is frozen in JIT.
599
+ if y.dim() == 2:
600
+ return y[:, None, :]
601
+ return y
602
+
595
603
  @torch.jit.export
596
604
  def sample(
597
605
  self,
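Note: the new `ChemBFN.reshape_y` helper simply promotes a 2-D conditioning tensor to the 3-D layout that `sample()`, `ode_sample()` and the inpainting methods expect, so callers may now pass either shape. A standalone sketch of the same behaviour in plain PyTorch (illustration only, not the package API):

```python
import torch

y = torch.randn(4, 16)      # (n_b, n_f) conditioning vector
if y.dim() == 2:            # what reshape_y does
    y = y[:, None, :]       # -> (n_b, 1, n_f)
print(y.shape)              # torch.Size([4, 1, 16])
```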
@@ -607,7 +615,7 @@ class ChemBFN(nn.Module):
607
615
 
608
616
  :param batch_size: batch size
609
617
  :param sequence_size: max sequence length
610
- :param y: conditioning vector; shape: (n_b, 1, n_f)
618
+ :param y: conditioning vector; shape: (n_b, 1, n_f) or (n_b, n_f)
611
619
  :param sample_step: number of sampling steps
612
620
  :param guidance_strength: strength of conditional generation. It is not used if y is null.
613
621
  :param token_mask: token mask; shape: (1, 1, n_vocab)
@@ -626,9 +634,7 @@ class ChemBFN(nn.Module):
626
634
  / self.K
627
635
  )
628
636
  if y is not None:
629
- assert y.dim() == 3 # this doesn't work if the model is frezen in JIT.
630
- if y.shape[0] == 1:
631
- y = y.repeat(batch_size, 1, 1)
637
+ y = self.reshape_y(y)
632
638
  for i in torch.linspace(1, sample_step, sample_step, device=self.beta.device):
633
639
  t = (i - 1).view(1, 1, 1).repeat(batch_size, 1, 1) / sample_step
634
640
  p = self.discrete_output_distribution(theta, t, y, guidance_strength)
@@ -663,7 +669,7 @@ class ChemBFN(nn.Module):
663
669
 
664
670
  :param batch_size: batch size
665
671
  :param sequence_size: max sequence length
666
- :param y: conditioning vector; shape: (n_b, 1, n_f)
672
+ :param y: conditioning vector; shape: (n_b, 1, n_f) or (n_b, n_f)
667
673
  :param sample_step: number of sampling steps
668
674
  :param guidance_strength: strength of conditional generation. It is not used if y is null.
669
675
  :param token_mask: token mask; shape: (1, 1, n_vocab)
@@ -681,9 +687,7 @@ class ChemBFN(nn.Module):
681
687
  """
682
688
  z = torch.zeros((batch_size, sequence_size, self.K), device=self.beta.device)
683
689
  if y is not None:
684
- assert y.dim() == 3 # this doesn't work if the model is frezen in JIT.
685
- if y.shape[0] == 1:
686
- y = y.repeat(batch_size, 1, 1)
690
+ y = self.reshape_y(y)
687
691
  for i in torch.linspace(1, sample_step, sample_step, device=self.beta.device):
688
692
  t = (i - 1).view(1, 1, 1).repeat(batch_size, 1, 1) / sample_step
689
693
  theta = torch.softmax(z, -1)
@@ -714,7 +718,7 @@ class ChemBFN(nn.Module):
714
718
  Molecule inpaint functionality.
715
719
 
716
720
  :param x: categorical indices of scaffold; shape: (n_b, n_t)
717
- :param y: conditioning vector; shape: (n_b, 1, n_f)
721
+ :param y: conditioning vector; shape: (n_b, 1, n_f) or (n_b, n_f)
718
722
  :param sample_step: number of sampling steps
719
723
  :param guidance_strength: strength of conditional generation. It is not used if y is null.
720
724
  :param token_mask: token mask; shape: (1, 1, n_vocab)
@@ -733,9 +737,7 @@ class ChemBFN(nn.Module):
733
737
  x_onehot = nn.functional.one_hot(x, self.K) * mask
734
738
  theta = x_onehot + (1 - mask) * theta
735
739
  if y is not None:
736
- assert y.dim() == 3 # this doesn't work if the model is frezen in JIT.
737
- if y.shape[0] == 1:
738
- y = y.repeat(n_b, 1, 1)
740
+ y = self.reshape_y(y)
739
741
  for i in torch.linspace(1, sample_step, sample_step, device=x.device):
740
742
  t = (i - 1).view(1, 1, 1).repeat(n_b, 1, 1) / sample_step
741
743
  p = self.discrete_output_distribution(theta, t, y, guidance_strength)
@@ -769,7 +771,7 @@ class ChemBFN(nn.Module):
769
771
  ODE inpainting.
770
772
 
771
773
  :param x: categorical indices of scaffold; shape: (n_b, n_t)
772
- :param y: conditioning vector; shape: (n_b, 1, n_f)
774
+ :param y: conditioning vector; shape: (n_b, 1, n_f) or (n_b, n_f)
773
775
  :param sample_step: number of sampling steps
774
776
  :param guidance_strength: strength of conditional generation. It is not used if y is null.
775
777
  :param token_mask: token mask; shape: (1, 1, n_vocab)
@@ -789,9 +791,7 @@ class ChemBFN(nn.Module):
789
791
  x_onehot = nn.functional.one_hot(x, self.K) * mask
790
792
  z = torch.zeros((n_b, n_t, self.K), device=self.beta.device)
791
793
  if y is not None:
792
- assert y.dim() == 3 # this doesn't work if the model is frezen in JIT.
793
- if y.shape[0] == 1:
794
- y = y.repeat(n_b, 1, 1)
794
+ y = self.reshape_y(y)
795
795
  for i in torch.linspace(1, sample_step, sample_step, device=self.beta.device):
796
796
  t = (i - 1).view(1, 1, 1).repeat(n_b, 1, 1) / sample_step
797
797
  theta = torch.softmax(z, -1)
@@ -847,13 +847,7 @@ class ChemBFN(nn.Module):
847
847
  with open(ckpt, "rb") as f:
848
848
  state = torch.load(f, "cpu", weights_only=True)
849
849
  nn, hparam = state["nn"], state["hparam"]
850
- model = cls(
851
- hparam["num_vocab"],
852
- hparam["channel"],
853
- hparam["num_layer"],
854
- hparam["num_head"],
855
- hparam["dropout"],
856
- )
850
+ model = cls(**hparam)
857
851
  model.load_state_dict(nn, False)
858
852
  if ckpt_lora:
859
853
  with open(ckpt_lora, "rb") as g:
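Note: `ChemBFN.from_checkpoint` now rebuilds the model with `cls(**hparam)`, so every hyperparameter stored under `"hparam"` is forwarded to the constructor instead of only the five hard-coded ones removed above. A hedged sketch of the checkpoint layout this relies on (the hyperparameter values below are made up for illustration; the keys mirror the arguments the removed positional call used and must match `ChemBFN` constructor keywords):

```python
import torch
from bayesianflow_for_chem import ChemBFN

# Illustrative hyperparameters only.
hparam = {"num_vocab": 247, "channel": 512, "num_layer": 12, "num_head": 8, "dropout": 0.01}
model = ChemBFN(**hparam)
torch.save({"nn": model.state_dict(), "hparam": hparam}, "model.pt")

restored = ChemBFN.from_checkpoint("model.pt")  # reads "nn" and "hparam" back
```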
@@ -908,7 +902,7 @@ class MLP(nn.Module):
908
902
  if self.class_input:
909
903
  x = x.to(dtype=torch.long)
910
904
  for layer in self.layers[:-1]:
911
- x = torch.selu(layer(x))
905
+ x = torch.selu(layer.forward(x))
912
906
  return self.layers[-1](x)
913
907
 
914
908
  @classmethod
@@ -926,10 +920,382 @@ class MLP(nn.Module):
926
920
  with open(ckpt, "rb") as f:
927
921
  state = torch.load(f, "cpu", weights_only=True)
928
922
  nn, hparam = state["nn"], state["hparam"]
929
- model = cls(hparam["size"], hparam["class_input"], hparam["dropout"])
923
+ model = cls(**hparam)
930
924
  model.load_state_dict(nn, strict)
931
925
  return model
932
926
 
933
927
 
928
+ class EnsembleChemBFN(ChemBFN):
929
+ """
930
+ This module does not fully support `torch.jit.script`. Use the `EnsembleChemBFN.jit()`
931
+ method to JIT-compile the submodels.
932
+ `torch.compile()` is a better choice for compiling the whole model.
933
+ """
934
+
935
+ def __init__(
936
+ self,
937
+ base_model_path: Union[str, Path],
938
+ lora_paths: Union[List[Union[str, Path]], Dict[str, Union[str, Path]]],
939
+ cond_heads: Union[List[nn.Module], Dict[str, nn.Module]],
940
+ adapter_weights: Optional[Union[List[float], Dict[str, float]]] = None,
941
+ semi_autoregressive_flags: Optional[Union[List[bool], Dict[str, bool]]] = None,
942
+ ) -> None:
943
+ """
944
+ Ensemble of ChemBFN models from LoRA checkpoints.
945
+
946
+ :param base_model_path: base model checkpoint file
947
+ :param lora_paths: a list of LoRA checkpoint files or a `dict` instance of these files
948
+ :param cond_heads: a list of conditioning network heads or a `dict` instance of these networks
949
+ :param adapter_weights: a list of weights of each LoRA finetuned model or a `dict` instance of these weights; default is equally weighted
950
+ :param semi_autoregressive_flags: a list of the semi-autoregressive behaviour states of each LoRA finetuned model or a `dict` instance of these states; default is all `False`
951
+ :type base_model_path: str | pathlib.Path
952
+ :type lora_paths: list | dict
953
+ :type cond_heads: list | dict
954
+ :type adapter_weights: list | dict | None
955
+ :type semi_autoregressive_flags: list | dict | None
956
+ """
957
+ n = len(lora_paths)
958
+ assert type(lora_paths) == type(
959
+ cond_heads
960
+ ), "`lora_paths` and `cond_heads` should have the same type!"
961
+ assert n == len(
962
+ cond_heads
963
+ ), "`lora_paths` and `cond_heads` should have the same length!"
964
+ if adapter_weights:
965
+ assert type(lora_paths) == type(
966
+ adapter_weights
967
+ ), "`lora_paths` and `adapter_weights` should have the same type!"
968
+ assert n == len(
969
+ adapter_weights
970
+ ), "`lora_paths` and `adapter_weights` should have the same length!"
971
+ if semi_autoregressive_flags:
972
+ assert type(lora_paths) == type(
973
+ semi_autoregressive_flags
974
+ ), "`lora_paths` and `semi_autoregressive_flags` should have the same type!"
975
+ assert n == len(
976
+ semi_autoregressive_flags
977
+ ), "`lora_paths` and `semi_autoregressive_flags` should have the same length!"
978
+ _label_is_dict = isinstance(lora_paths, dict)
979
+ if isinstance(lora_paths, list):
980
+ names = tuple(f"val_{i}" for i in range(n))
981
+ lora_paths = dict(zip(names, lora_paths))
982
+ cond_heads = dict(zip(names, cond_heads))
983
+ if not adapter_weights:
984
+ adapter_weights = (1 / n for _ in names)
985
+ if not semi_autoregressive_flags:
986
+ semi_autoregressive_flags = (False for _ in names)
987
+ adapter_weights = dict(zip(names, adapter_weights))
988
+ semi_autoregressive_flags = dict(zip(names, semi_autoregressive_flags))
989
+ else:
990
+ names = tuple(lora_paths.keys())
991
+ if not adapter_weights:
992
+ adapter_weights = dict(zip(names, (1 / n for _ in names)))
993
+ if not semi_autoregressive_flags:
994
+ semi_autoregressive_flags = dict(zip(names, (False for _ in names)))
995
+ base_model = ChemBFN.from_checkpoint(base_model_path)
996
+ models = dict(zip(names, (deepcopy(base_model.eval()) for _ in names)))
997
+ for k in names:
998
+ with open(lora_paths[k], "rb") as f:
999
+ state = torch.load(f, "cpu", weights_only=True)
1000
+ lora_nn, lora_param = state["lora_nn"], state["lora_param"]
1001
+ models[k].enable_lora(**lora_param)
1002
+ models[k].load_state_dict(lora_nn, False)
1003
+ models[k].semi_autoregressive = semi_autoregressive_flags[k]
1004
+ super().__init__(**base_model.hparam)
1005
+ self.cond_heads = nn.ModuleDict(cond_heads)
1006
+ self.models = nn.ModuleDict(models)
1007
+ self.adapter_weights = adapter_weights
1008
+ self._label_is_dict = _label_is_dict # flag
1009
+ # ------- remove unnecessary submodules -------
1010
+ self.embedding = None
1011
+ self.time_embed = None
1012
+ self.position = None
1013
+ self.encoder_layers = None
1014
+ self.final_layer = None
1015
+ self.__delattr__("embedding")
1016
+ self.__delattr__("time_embed")
1017
+ self.__delattr__("position")
1018
+ self.__delattr__("encoder_layers")
1019
+ self.__delattr__("final_layer")
1020
+ # ------- remove unused attributes -------
1021
+ self.__delattr__("semi_autoregressive")
1022
+ self.__delattr__("lora_enabled")
1023
+ self.__delattr__("lora_param")
1024
+ self.__delattr__("hparam")
1025
+
1026
+ def construct_y(
1027
+ self, c: Union[List[Tensor], Dict[str, Tensor]]
1028
+ ) -> Dict[str, Tensor]:
1029
+ assert (
1030
+ isinstance(c, dict) is self._label_is_dict
1031
+ ), f"`c` should be a {'`dict` instance' if self._label_is_dict else '`list` instance'} but got {type(c)} instand."
1032
+ out: Dict[str, Tensor] = {}
1033
+ if isinstance(c, list):
1034
+ c = dict(zip([f"val_{i}" for i in range(len(c))], c))
1035
+ for name, model in self.cond_heads.items():
1036
+ y = model.forward(c[name])
1037
+ if y.dim() == 2:
1038
+ y = y[:, None, :]
1039
+ out[name] = y
1040
+ return out
1041
+
1042
+ def discrete_output_distribution(
1043
+ self, theta: Tensor, t: Tensor, y: Dict[str, Tensor], w: float
1044
+ ) -> Tensor:
1045
+ """
1046
+ :param theta: input distribution; shape: (n_b, n_t, n_vocab)
1047
+ :param t: continuous time in [0, 1]; shape: (n_b, 1, 1)
1048
+ :param y: a dict of conditioning vectors; shape: (n_b, 1, n_f) * n_h
1049
+ :param w: guidance strength controlling the conditional generation
1050
+ :type theta: torch.Tensor
1051
+ :type t: torch.Tensor
1052
+ :type y: dict
1053
+ :type w: float
1054
+ :return: output distribution; shape: (n_b, n_t, n_vocab)
1055
+ :rtype: torch.Tensor
1056
+ """
1057
+ theta = 2 * theta - 1 # rescale to [-1, 1]
1058
+ p_uncond, p_cond = torch.zeros_like(theta), torch.zeros_like(theta)
1059
+ # Q: Why not use `torch.vmap`? It's faster than doing the loop, isn't it?
1060
+ #
1061
+ # A: We have quite a few reasons to avoid using `vmap`:
1062
+ # 1. JIT doesn't support vmap;
1063
+ # 2. It's harder to switch on/off semi-autoregressive behaviours for individual
1064
+ # models when all models are stacked into one (we have a solution but it's not
1065
+ # that elegant);
1066
+ # 3. We just found that the result from vmap was not identical to doing the loop;
1067
+ # 4. vmap requires all models to have the same size, but that's not always the case
1068
+ # since we sometimes use different ranks of LoRA in finetuning.
1069
+ for name, model in self.models.items():
1070
+ p_uncond_ = model.forward(theta, t, None, None)
1071
+ p_uncond += p_uncond_ * self.adapter_weights[name]
1072
+ p_cond_ = model.forward(theta, t, None, y[name])
1073
+ p_cond += p_cond_ * self.adapter_weights[name]
1074
+ return softmax((1 + w) * p_cond - w * p_uncond, -1)
1075
+
1076
+ @staticmethod
1077
+ def reshape_y(y: Dict[str, Tensor]) -> Dict[str, Tensor]:
1078
+ for k in y:
1079
+ assert y[k].dim() <= 3
1080
+ if y[k].dim() == 2:
1081
+ y[k] = y[k][:, None, :]
1082
+ return y
1083
+
1084
+ @torch.inference_mode()
1085
+ def sample(
1086
+ self,
1087
+ batch_size: int,
1088
+ sequence_size: int,
1089
+ conditions: Union[List[Tensor], Dict[str, Tensor]],
1090
+ sample_step: int = 100,
1091
+ guidance_strength: float = 4.0,
1092
+ token_mask: Optional[Tensor] = None,
1093
+ ) -> Tuple[Tensor, Tensor]:
1094
+ """
1095
+ Sample from a prior distribution.
1096
+
1097
+ :param batch_size: batch size
1098
+ :param sequence_size: max sequence length
1099
+ :param conditions: guidance conditions; shape: (n_b, n_c) * n_h
1100
+ :param sample_step: number of sampling steps
1101
+ :param guidance_strength: strength of conditional generation. It is not used if y is null.
1102
+ :param token_mask: token mask; shape: (1, 1, n_vocab)
1103
+ :type batch_size: int
1104
+ :type sequence_size: int
1105
+ :type conditions: list | dict
1106
+ :type sample_step: int
1107
+ :type guidance_strength: float
1108
+ :type token_mask: torch.Tensor | None
1109
+ :return: sampled token indices; shape: (n_b, n_t) \n
1110
+ entropy of the tokens; shape: (n_b)
1111
+ :rtype: tuple
1112
+ """
1113
+ y = self.construct_y(conditions)
1114
+ return super().sample(
1115
+ batch_size, sequence_size, y, sample_step, guidance_strength, token_mask
1116
+ )
1117
+
1118
+ @torch.inference_mode()
1119
+ def ode_sample(
1120
+ self,
1121
+ batch_size: int,
1122
+ sequence_size: int,
1123
+ conditions: Union[List[Tensor], Dict[str, Tensor]],
1124
+ sample_step: int = 100,
1125
+ guidance_strength: float = 4.0,
1126
+ token_mask: Optional[Tensor] = None,
1127
+ temperature: float = 0.5,
1128
+ ) -> Tuple[Tensor, Tensor]:
1129
+ """
1130
+ ODE-based sampling.
1131
+
1132
+ :param batch_size: batch size
1133
+ :param sequence_size: max sequence length
1134
+ :param conditions: conditioning vector; shape: (n_b, n_c) * n_h
1135
+ :param sample_step: number of sampling steps
1136
+ :param guidance_strength: strength of conditional generation. It is not used if y is null.
1137
+ :param token_mask: token mask; shape: (1, 1, n_vocab)
1138
+ :param temperature: sampling temperature
1139
+ :type batch_size: int
1140
+ :type sequence_size: int
1141
+ :type conditions: list | dict
1142
+ :type sample_step: int
1143
+ :type guidance_strength: float
1144
+ :type token_mask: torch.Tensor | None
1145
+ :type temperature: float
1146
+ :return: sampled token indices; shape: (n_b, n_t) \n
1147
+ entropy of the tokens; shape: (n_b)
1148
+ :rtype: tuple
1149
+ """
1150
+ y = self.construct_y(conditions)
1151
+ return super().ode_sample(
1152
+ batch_size,
1153
+ sequence_size,
1154
+ y,
1155
+ sample_step,
1156
+ guidance_strength,
1157
+ token_mask,
1158
+ temperature,
1159
+ )
1160
+
1161
+ @torch.inference_mode()
1162
+ def inpaint(
1163
+ self,
1164
+ x: Tensor,
1165
+ conditions: Union[List[Tensor], Dict[str, Tensor]],
1166
+ sample_step: int = 100,
1167
+ guidance_strength: float = 4.0,
1168
+ token_mask: Optional[Tensor] = None,
1169
+ ) -> Tuple[Tensor, Tensor]:
1170
+ """
1171
+ Molecule inpaint functionality.
1172
+
1173
+ :param x: categorical indices of scaffold; shape: (n_b, n_t)
1174
+ :param conditions: conditioning vector; shape: (n_b, n_c) * n_h
1175
+ :param sample_step: number of sampling steps
1176
+ :param guidance_strength: strength of conditional generation. It is not used if y is null.
1177
+ :param token_mask: token mask; shape: (1, 1, n_vocab)
1178
+ :type x: torch.Tensor
1179
+ :type conditions: list | dict
1180
+ :type sample_step: int
1181
+ :type guidance_strength: float
1182
+ :type token_mask: torch.Tensor | None
1183
+ :return: sampled token indices; shape: (n_b, n_t) \n
1184
+ entropy of the tokens; shape: (n_b)
1185
+ :rtype: tuple
1186
+ """
1187
+ y = self.construct_y(conditions)
1188
+ return super().inpaint(x, y, sample_step, guidance_strength, token_mask)
1189
+
1190
+ @torch.inference_mode()
1191
+ def ode_inpaint(
1192
+ self,
1193
+ x: Tensor,
1194
+ conditions: Union[List[Tensor], Dict[str, Tensor]],
1195
+ sample_step: int = 100,
1196
+ guidance_strength: float = 4.0,
1197
+ token_mask: Optional[Tensor] = None,
1198
+ temperature: float = 0.5,
1199
+ ) -> Tuple[Tensor, Tensor]:
1200
+ """
1201
+ ODE inpainting.
1202
+
1203
+ :param x: categorical indices of scaffold; shape: (n_b, n_t)
1204
+ :param conditions: conditioning vector; shape: (n_b, n_c) * n_h
1205
+ :param sample_step: number of sampling steps
1206
+ :param guidance_strength: strength of conditional generation. It is not used if y is null.
1207
+ :param token_mask: token mask; shape: (1, 1, n_vocab)
1208
+ :param temperature: sampling temperature
1209
+ :type x: torch.Tensor
1210
+ :type conditions: list | dict
1211
+ :type sample_step: int
1212
+ :type guidance_strength: float
1213
+ :type token_mask: torch.Tensor | None
1214
+ :type temperature: float
1215
+ :return: sampled token indices; shape: (n_b, n_t) \n
1216
+ entropy of the tokens; shape: (n_b)
1217
+ :rtype: tuple
1218
+ """
1219
+ y = self.construct_y(conditions)
1220
+ return super().ode_inpaint(
1221
+ x, y, sample_step, guidance_strength, token_mask, temperature
1222
+ )
1223
+
1224
+ def quantise(
1225
+ self, quantise_method: Optional[Callable[[ChemBFN], nn.Module]] = None
1226
+ ) -> None:
1227
+ """
1228
+ Quantise the submodels. \n
1229
+ This method should be called, if necessary, before `torch.compile()`.
1230
+
1231
+ :param quantise_method: quantisation method; default is `bayesianflow_for_chem.tool.quantise_model`
1232
+ :type quantise_method: callable | None
1233
+ :return:
1234
+ :rtype: None
1235
+ """
1236
+ if quantise_method is None:
1237
+ from bayesianflow_for_chem.tool import quantise_model
1238
+
1239
+ quantise_method = quantise_model
1240
+ for k, v in self.models.items():
1241
+ self.models[k] = quantise_method(v)
1242
+
1243
+ def jit(self, freeze: bool = False) -> None:
1244
+ """
1245
+ JIT compile the submodels. \n
1246
+ This method should be called, if necessary, before the `quantise()` method when both are used.
1247
+
1248
+ :param freeze: whether to freeze the submodels; default is `False`. If set to `True` this
1249
+ method should be called before moving the model to a different device.
1250
+ :type freeze: bool
1251
+ :return:
1252
+ :rtype: None
1253
+ """
1254
+ for k, v in self.models.items():
1255
+ self.models[k] = torch.jit.script(v)
1256
+ if freeze:
1257
+ self.models[k] = torch.jit.freeze(
1258
+ self.models[k], ["semi_autoregressive"]
1259
+ )
1260
+
1261
+ @torch.jit.ignore
1262
+ def forward(self, *_, **__) -> None:
1263
+ """
1264
+ Don't use this method!
1265
+ """
1266
+ raise NotImplementedError("There's nothing here!")
1267
+
1268
+ def cts_loss(self, *_, **__) -> None:
1269
+ """
1270
+ Don't use this method!
1271
+ """
1272
+ raise NotImplementedError("There's nothing here!")
1273
+
1274
+ def reconstruction_loss(self, *_, **__) -> None:
1275
+ """
1276
+ Don't use this method!
1277
+ """
1278
+ raise NotImplementedError("There's nothing here!")
1279
+
1280
+ def enable_lora(self, *_, **__) -> None:
1281
+ """
1282
+ Don't use this method!
1283
+ """
1284
+ raise NotImplementedError("There's nothing here!")
1285
+
1286
+ def inference(self, *_, **__) -> None:
1287
+ """
1288
+ Don't use this method!
1289
+ """
1290
+ raise NotImplementedError("There's nothing here!")
1291
+
1292
+ @classmethod
1293
+ def from_checkpoint(cls, *_, **__) -> None:
1294
+ """
1295
+ Don't use this method!
1296
+ """
1297
+ raise NotImplementedError("There's nothing here!")
1298
+
1299
+
934
1300
  if __name__ == "__main__":
935
1301
  ...
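Putting the new pieces together, a hedged usage sketch of `EnsembleChemBFN` (all file names, head names and layer sizes below are hypothetical, and `MLP` is assumed to take a list of layer sizes as its first argument, as its `from_checkpoint` suggests):

```python
import torch
from bayesianflow_for_chem import EnsembleChemBFN, MLP

# Two LoRA adapters, each paired with its own conditioning head.
lora_paths = {"logp": "lora_logp.pt", "qed": "lora_qed.pt"}           # hypothetical files
cond_heads = {"logp": MLP([1, 256, 512]), "qed": MLP([1, 256, 512])}  # assumed layer sizes

model = EnsembleChemBFN(
    base_model_path="model.pt",                  # base ChemBFN checkpoint
    lora_paths=lora_paths,
    cond_heads=cond_heads,
    adapter_weights={"logp": 0.5, "qed": 0.5},   # optional; defaults to equal weighting
)
# Optional speed-ups, in this order: model.jit() then model.quantise().

conditions = {"logp": torch.tensor([[2.5]]), "qed": torch.tensor([[0.8]])}  # (n_b, n_c) per head
tokens, entropy = model.sample(batch_size=1, sequence_size=60, conditions=conditions)
```

Because `lora_paths`, `cond_heads` and `conditions` are given as dictionaries here, the same keys must be used everywhere; when lists are passed instead, the implicit names `val_0`, `val_1`, … are used.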
@@ -1,7 +1,7 @@
1
1
  # -*- coding: utf-8 -*-
2
2
  # Author: Nianze A. TAO (Omozawa SUENO)
3
3
  """
4
- Scorers.
4
+ Define essential scorers.
5
5
  """
6
6
  from typing import List, Callable, Union, Optional
7
7
  import torch
@@ -1,11 +1,12 @@
1
1
  # -*- coding: utf-8 -*-
2
2
  # Author: Nianze A. TAO (Omozawa SUENO)
3
3
  """
4
- Tools.
4
+ Essential tools.
5
5
  """
6
6
  import re
7
7
  import csv
8
8
  import random
9
+ import warnings
9
10
  from copy import deepcopy
10
11
  from pathlib import Path
11
12
  from typing import List, Dict, Tuple, Union, Optional
@@ -16,7 +17,8 @@ from torch import cuda, Tensor, softmax
16
17
  from torch.ao import quantization
17
18
  from torch.utils.data import DataLoader
18
19
  from typing_extensions import Self
19
- from rdkit.Chem import rdDetermineBonds, Bond, MolFromXYZBlock, CanonicalRankAtoms
20
+ from rdkit.Chem.rdchem import Mol, Bond
21
+ from rdkit.Chem import rdDetermineBonds, MolFromXYZBlock, MolToSmiles, CanonSmiles
20
22
  from rdkit.Chem.Scaffolds.MurckoScaffold import MurckoScaffoldSmiles # type: ignore
21
23
  from sklearn.metrics import (
22
24
  roc_auc_score,
@@ -32,12 +34,11 @@ try:
32
34
 
33
35
  _use_pynauty = True
34
36
  except ImportError:
35
- import warnings
37
+ import platform
36
38
 
37
39
  _use_pynauty = False
38
-
39
40
  from .data import VOCAB_KEYS
40
- from .model import ChemBFN, MLP, Linear
41
+ from .model import ChemBFN, MLP, Linear, EnsembleChemBFN
41
42
 
42
43
 
43
44
  _atom_regex_pattern = (
@@ -161,6 +162,8 @@ def split_dataset(
161
162
  :return:
162
163
  :rtype: None
163
164
  """
165
+ if isinstance(file, Path):
166
+ file = file.__str__()
164
167
  assert file.endswith(".csv")
165
168
  assert len(split_ratio) == 3
166
169
  assert method in ("random", "scaffold")
@@ -170,7 +173,7 @@ def split_dataset(
170
173
  raw_data = data[1:]
171
174
  smiles_idx = [] # only first index will be used
172
175
  for key, h in enumerate(header):
173
- if h.lower() == "smiles":
176
+ if "smiles" in h.lower():
174
177
  smiles_idx.append(key)
175
178
  assert len(smiles_idx) > 0
176
179
  data_len = len(raw_data)
@@ -186,11 +189,22 @@ def split_dataset(
186
189
  scaffolds: Dict[str, List] = {}
187
190
  for key, d in enumerate(raw_data):
188
191
  # compute Bemis-Murcko scaffold
189
- scaffold = MurckoScaffoldSmiles(d[smiles_idx[0]])
190
- if scaffold in scaffolds:
191
- scaffolds[scaffold].append(key)
192
- else:
193
- scaffolds[scaffold] = [key]
192
+ if len(smiles_idx) > 1:
193
+ warnings.warn(
194
+ "\033[32;1m"
195
+ f"We found {len(smiles_idx)} SMILES strings in a row!"
196
+ " Only the first SMILES will be used to compute the molecular scaffold."
197
+ "\033[0m",
198
+ stacklevel=2,
199
+ )
200
+ try:
201
+ scaffold = MurckoScaffoldSmiles(d[smiles_idx[0]])
202
+ if scaffold in scaffolds:
203
+ scaffolds[scaffold].append(key)
204
+ else:
205
+ scaffolds[scaffold] = [key]
206
+ except ValueError: # do nothing when SMILES is not valid
207
+ ...
194
208
  scaffolds = {key: sorted(value) for key, value in scaffolds.items()}
195
209
  train_set, test_set, val_set = [], [], []
196
210
  for idxs in scaffolds.values():
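For reference, a minimal call of the updated splitter; the file name and the 8:1:1 form of `split_ratio` are assumptions, and the CSV must contain a header whose name includes "smiles":

```python
from pathlib import Path
from bayesianflow_for_chem.tool import split_dataset

# Path objects are now accepted, extra SMILES columns only trigger a warning,
# and rows with unparsable SMILES are skipped during scaffold splitting.
split_dataset(Path("dataset.csv"), split_ratio=[8, 1, 1], method="scaffold")
```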
@@ -212,137 +226,13 @@ def split_dataset(
212
226
  writer.writerows([header] + val_set)
213
227
 
214
228
 
215
- def geo2seq(
216
- symbols: List[str],
217
- coordinates: np.ndarray,
218
- decimals: int = 2,
219
- angle_unit: str = "degree",
220
- ) -> str:
221
- """
222
- Geometry-to-sequence function.\n
223
- The algorithm follows the descriptions in paper: https://arxiv.org/abs/2408.10120.
224
-
225
- :param symbols: a list of atomic symbols
226
- :param coordinates: Cartesian coordinates; shape: (n_a, 3)
227
- :param decimals: number of decimal places to round to
228
- :param angle_unit: `'degree'` or `'radian'`
229
- :type symbols: list
230
- :type coordinates: numpy.ndarray
231
- :type decimals: int
232
- :type angle_unit: str
233
- :return: `Geo2Seq` string
234
- :rtype: str
235
- """
236
- assert angle_unit in ("degree", "radian")
237
- angle_scale = 180 / np.pi if angle_unit == "degree" else 1.0
238
- n = len(symbols)
239
- if n == 1:
240
- return f"{symbols[0]} {'0.0'} {'0.0'} {'0.0'}"
241
- xyz_block = [str(n), ""]
242
- for i, atom in enumerate(symbols):
243
- xyz_block.append(
244
- f"{atom} {'%.10f' % coordinates[i][0].item()} {'%.10f' % coordinates[i][1].item()} {'%.10f' % coordinates[i][2].item()}"
245
- )
246
- mol = MolFromXYZBlock("\n".join(xyz_block))
247
- rdDetermineBonds.DetermineConnectivity(mol)
248
- # ------- Canonicalization -------
249
- if _use_pynauty:
250
- pair_idx = np.array(_bond_pair_idx(mol.GetBonds())).T.tolist()
251
- pair_dict: Dict[int, List[int]] = {}
252
- for key, i in enumerate(pair_idx[0]):
253
- if i not in pair_dict:
254
- pair_dict[i] = [pair_idx[1][key]]
255
- else:
256
- pair_dict[i].append(pair_idx[1][key])
257
- g = Graph(n, adjacency_dict=pair_dict)
258
- cl = canon_label(g) # type: list
259
- else:
260
- warnings.warn(
261
- "\033[32;1m"
262
- "`pynauty` is not installed."
263
- " Switched to canonicalization function provided by `rdkit`."
264
- " This is the expected behaviour only if you are working on Windows platform."
265
- "\033[0m",
266
- stacklevel=2,
267
- )
268
- cl = list(CanonicalRankAtoms(mol, breakTies=True))
269
- symbols = np.array([[s] for s in symbols])[cl].flatten().tolist()
270
- coordinates = coordinates[cl]
271
- # ------- Find global coordinate frame -------
272
- if n == 2:
273
- d = np.round(np.linalg.norm(coordinates[0] - coordinates[1], 2), decimals)
274
- return f"{symbols[0]} {'0.0'} {'0.0'} {'0.0'} {symbols[1]} {d} {'0.0'} {'0.0'}"
275
- for idx_0 in range(n - 2):
276
- _vec0 = coordinates[idx_0] - coordinates[idx_0 + 1]
277
- _vec1 = coordinates[idx_0] - coordinates[idx_0 + 2]
278
- _d1 = np.linalg.norm(_vec0, 2)
279
- _d2 = np.linalg.norm(_vec1, 2)
280
- if 1 - np.abs(np.dot(_vec0, _vec1) / (_d1 * _d2)) > 1e-6:
281
- break
282
- x = (coordinates[idx_0 + 1] - coordinates[idx_0]) / _d1
283
- y = np.cross((coordinates[idx_0 + 2] - coordinates[idx_0]), x)
284
- y_d = np.linalg.norm(y, 2)
285
- y = y / np.ma.filled(np.ma.array(y_d, mask=y_d == 0), np.inf)
286
- z = np.cross(x, y)
287
- # ------- Build spherical coordinates -------
288
- vec = coordinates - coordinates[idx_0]
289
- d = np.linalg.norm(vec, 2, axis=-1)
290
- _d = np.ma.filled(np.ma.array(d, mask=d == 0), np.inf)
291
- theta = angle_scale * np.arccos(np.dot(vec, z) / _d) # in [0, \pi]
292
- phi = angle_scale * np.arctan2(np.dot(vec, y), np.dot(vec, x)) # in [-\pi, \pi]
293
- info = np.vstack([d, theta, phi]).T
294
- info[idx_0] = np.zeros(3)
295
- info = [
296
- f"{symbols[i]} {r[0]} {r[1]} {r[2]}"
297
- for i, r in enumerate(np.round(info, decimals))
298
- ]
299
- return " ".join(info)
300
-
301
-
302
- def seq2geo(
303
- seq: str, angle_unit: str = "degree"
304
- ) -> Optional[Tuple[List[str], List[List[float]]]]:
305
- """
306
- Sequence-to-geometry function.\n
307
- The method follows the descriptions in paper: https://arxiv.org/abs/2408.10120.
308
-
309
- :param seq: `Geo2Seq` string
310
- :param angle_unit: `'degree'` or `'radian'`
311
- :type seq: str
312
- :type angle_unit: str
313
- :return: (symbols, coordinates) if `seq` is valid
314
- :rtype: tuple | None
315
- """
316
- assert angle_unit in ("degree", "radian")
317
- angle_scale = np.pi / 180 if angle_unit == "degree" else 1.0
318
- tokens = seq.split()
319
- if len(tokens) % 4 == 0:
320
- tokens = np.array(tokens).reshape(-1, 4).tolist()
321
- symbols, coordinates = [], []
322
- for i in tokens:
323
- symbol = i[0]
324
- if len(_atom_regex.findall(symbol)) != 1:
325
- return None
326
- symbols.append(symbol)
327
- try:
328
- d, theta, phi = float(i[1]), float(i[2]), float(i[3])
329
- x = d * np.sin(theta * angle_scale) * np.cos(phi * angle_scale)
330
- y = d * np.sin(theta * angle_scale) * np.sin(phi * angle_scale)
331
- z = d * np.cos(theta * angle_scale)
332
- coordinates.append([x.item(), y.item(), z.item()])
333
- except ValueError:
334
- return None
335
- return symbols, coordinates
336
- return None
337
-
338
-
339
229
  @torch.no_grad()
340
230
  def sample(
341
- model: ChemBFN,
231
+ model: Union[ChemBFN, EnsembleChemBFN],
342
232
  batch_size: int,
343
233
  sequence_size: int,
344
234
  sample_step: int = 100,
345
- y: Optional[Tensor] = None,
235
+ y: Optional[Union[Tensor, Dict[str, Tensor], List[Tensor]]] = None,
346
236
  guidance_strength: float = 4.0,
347
237
  device: Union[str, torch.device, None] = None,
348
238
  vocab_keys: List[str] = VOCAB_KEYS,
@@ -358,7 +248,9 @@ def sample(
358
248
  :param batch_size: batch size
359
249
  :param sequence_size: max sequence length
360
250
  :param sample_step: number of sampling steps
361
- :param y: conditioning vector; shape: (n_b, 1, n_f)
251
+ :param y: conditioning vector; shape: (n_b, 1, n_f) or (n_b, n_f) \n
252
+ or a list/`dict` of conditions; shape: (n_b, n_c) * n_h
253
+
362
254
  :param guidance_strength: strength of conditional generation. It is not used if y is null.
363
255
  :param device: hardware accelerator
364
256
  :param vocab_keys: a list of (ordered) vocabulary
@@ -366,11 +258,11 @@ def sample(
366
258
  :param method: sampling method chosen from `"ODE:x"` or `"BFN"` where `x` is the value of sampling temperature; default is `"BFN"`
367
259
  :param allowed_tokens: a list of allowed tokens; default is `"all"`
368
260
  :param sort: whether to sort the samples according to entropy values; default is `False`
369
- :type model: bayesianflow_for_chem.model.ChemBFN
261
+ :type model: bayesianflow_for_chem.model.ChemBFN | bayesianflow_for_chem.model.EnsembleChemBFN
370
262
  :type batch_size: int
371
263
  :type sequence_size: int
372
264
  :type sample_step: int
373
- :type y: torch.Tensor | None
265
+ :type y: torch.Tensor | list | dict | None
374
266
  :type guidance_strength: float
375
267
  :type device: str | torch.device | None
376
268
  :type vocab_keys: list
@@ -382,11 +274,23 @@ def sample(
382
274
  :rtype: list
383
275
  """
384
276
  assert method.split(":")[0].lower() in ("ode", "bfn")
277
+ if isinstance(model, EnsembleChemBFN):
278
+ assert y is not None, "conditioning is required while using an ensemble model."
279
+ assert isinstance(y, list) or isinstance(y, dict)
280
+ else:
281
+ assert isinstance(y, Tensor) or y is None
385
282
  if device is None:
386
283
  device = _find_device()
387
284
  model.to(device).eval()
388
285
  if y is not None:
389
- y = y.to(device)
286
+ if isinstance(y, Tensor):
287
+ y = y.to(device)
288
+ elif isinstance(y, list):
289
+ y = [i.to(device) for i in y]
290
+ elif isinstance(y, dict):
291
+ y = {k: v.to(device) for k, v in y.items()}
292
+ else:
293
+ raise NotImplementedError
390
294
  if isinstance(allowed_tokens, list):
391
295
  token_mask = [0 if i in allowed_tokens else 1 for i in vocab_keys]
392
296
  token_mask = torch.tensor([[token_mask]], dtype=torch.bool).to(device)
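A hedged sketch of the relaxed `y` argument of `tool.sample` (checkpoint names and the conditioning feature size are placeholders; the return value is a list of decoded strings, per the function's `:rtype: list`):

```python
import torch
from bayesianflow_for_chem import ChemBFN
from bayesianflow_for_chem.tool import sample

model = ChemBFN.from_checkpoint("model.pt")  # placeholder checkpoint

# Unconditional generation works as before.
smiles = sample(model, batch_size=16, sequence_size=60, sample_step=100)

# A single model now also accepts a 2-D (n_b, n_f) conditioning tensor ...
smiles_cond = sample(model, 16, 60, y=torch.randn(16, 512), guidance_strength=4.0)

# ... while an EnsembleChemBFN takes a list or dict of raw conditions instead, e.g.
# smiles_ens = sample(ensemble_model, 16, 60, y={"logp": torch.tensor([[2.5]])})
```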
@@ -416,10 +320,10 @@ def sample(
416
320
 
417
321
  @torch.no_grad()
418
322
  def inpaint(
419
- model: ChemBFN,
323
+ model: Union[ChemBFN, EnsembleChemBFN],
420
324
  x: Tensor,
421
325
  sample_step: int = 100,
422
- y: Optional[Tensor] = None,
326
+ y: Optional[Union[Tensor, Dict[str, Tensor], List[Tensor]]] = None,
423
327
  guidance_strength: float = 4.0,
424
328
  device: Union[str, torch.device, None] = None,
425
329
  vocab_keys: List[str] = VOCAB_KEYS,
@@ -434,7 +338,9 @@ def inpaint(
434
338
  :param model: trained ChemBFN model
435
339
  :param x: categorical indices of scaffold; shape: (n_b, n_t)
436
340
  :param sample_step: number of sampling steps
437
- :param y: conditioning vector; shape: (n_b, 1, n_f)
341
+ :param y: conditioning vector; shape: (n_b, 1, n_f) or (n_b, n_f) \n
342
+ or a list/`dict` of conditions; shape: (n_b, n_c) * n_h
343
+
438
344
  :param guidance_strength: strength of conditional generation. It is not used if y is null.
439
345
  :param device: hardware accelerator
440
346
  :param vocab_keys: a list of (ordered) vocabulary
@@ -442,10 +348,10 @@ def inpaint(
442
348
  :param method: sampling method chosen from `"ODE:x"` or `"BFN"` where `x` is the value of sampling temperature; default is `"BFN"`
443
349
  :param allowed_tokens: a list of allowed tokens; default is `"all"`
444
350
  :param sort: whether to sort the samples according to entropy values; default is `False`
445
- :type model: bayesianflow_for_chem.model.ChemBFN
351
+ :type model: bayesianflow_for_chem.model.ChemBFN | bayesianflow_for_chem.model.EnsembleChemBFN
446
352
  :type x: torch.Tensor
447
353
  :type sample_step: int
448
- :type y: torch.Tensor | None
354
+ :type y: torch.Tensor | list | dict | None
449
355
  :type guidance_strength: float
450
356
  :type device: str | torch.device | None
451
357
  :type vocab_keys: list
@@ -457,12 +363,24 @@ def inpaint(
457
363
  :rtype: list
458
364
  """
459
365
  assert method.split(":")[0].lower() in ("ode", "bfn")
366
+ if isinstance(model, EnsembleChemBFN):
367
+ assert y is not None, "conditioning is required while using an ensemble model."
368
+ assert isinstance(y, list) or isinstance(y, dict)
369
+ else:
370
+ assert isinstance(y, Tensor) or y is None
460
371
  if device is None:
461
372
  device = _find_device()
462
373
  model.to(device).eval()
463
374
  x = x.to(device)
464
375
  if y is not None:
465
- y = y.to(device)
376
+ if isinstance(y, Tensor):
377
+ y = y.to(device)
378
+ elif isinstance(y, list):
379
+ y = [i.to(device) for i in y]
380
+ elif isinstance(y, dict):
381
+ y = {k: v.to(device) for k, v in y.items()}
382
+ else:
383
+ raise NotImplementedError
466
384
  if isinstance(allowed_tokens, list):
467
385
  token_mask = [0 if i in allowed_tokens else 1 for i in vocab_keys]
468
386
  token_mask = torch.tensor([[token_mask]], dtype=torch.bool).to(device)
@@ -627,3 +545,178 @@ def quantise_model(model: ChemBFN) -> nn.Module:
627
545
  model, {nn.Linear, Linear}, torch.qint8, mapping
628
546
  )
629
547
  return quantised_model
548
+
549
+
550
+ class GeometryConverter:
551
+ """
552
+ Converting between different 2D/3D molecular representations.
553
+ """
554
+
555
+ @staticmethod
556
+ def _xyz2mol(symbols: List[str], coordinates: np.ndarray) -> Mol:
557
+ xyz_block = [str(len(symbols)), ""]
558
+ r = coordinates
559
+ for i, atom in enumerate(symbols):
560
+ xyz_block.append(f"{atom} {r[i][0]:.10f} {r[i][1]:.10f} {r[i][2]:.10f}")
561
+ return MolFromXYZBlock("\n".join(xyz_block))
562
+
563
+ def cartesian2smiles(
564
+ self,
565
+ symbols: List[str],
566
+ coordinates: np.ndarray,
567
+ charge: int = 0,
568
+ canonical: bool = True,
569
+ ) -> str:
570
+ """
571
+ Transform (guess out) molecular geometry to SMILES string.
572
+
573
+ :param symbols: a list of atomic symbols
574
+ :param coordinates: Cartesian coordinates; shape: (n_a, 3)
575
+ :param charge: net charge
576
+ :param canonical: whether to canonicalise the SMILES
577
+ :type symbols: list
578
+ :type coordinates: numpy.ndarray
579
+ :type charge: int
580
+ :type canonical: bool
581
+ :return: SMILES string
582
+ :rtype: str
583
+ """
584
+ mol = self._xyz2mol(symbols, coordinates)
585
+ rdDetermineBonds.DetermineBonds(mol, charge=charge)
586
+ smiles = MolToSmiles(mol)
587
+ if canonical:
588
+ smiles = CanonSmiles(smiles)
589
+ return smiles
590
+
591
+ def canonicalise(
592
+ self, symbols: List[str], coordinates: np.ndarray
593
+ ) -> Tuple[List[str], np.ndarray]:
594
+ """
595
+ Canonicalising the 3D molecular graph.
596
+
597
+ :param symbols: a list of atomic symbols
598
+ :param coordinates: Cartesian coordinates; shape: (n_a, 3)
599
+ :type symbols: list
600
+ :type coordinates: numpy.ndarray
601
+ :return: canonicalised symbols \n
602
+ canonicalised coordinates; shape: (n_a, 3)
603
+ :rtype: tuple
604
+ """
605
+ if not _use_pynauty:
606
+ if platform.system() == "Windows":
607
+ raise NotImplementedError(
608
+ "This method is not implemented on Windows platform."
609
+ )
610
+ else:
611
+ raise ImportError("`pynauty` is not installed.")
612
+ n = len(symbols)
613
+ if n == 1:
614
+ return symbols, coordinates
615
+ mol = self._xyz2mol(symbols, coordinates)
616
+ rdDetermineBonds.DetermineConnectivity(mol)
617
+ # ------- Canonicalization -------
618
+ pair_idx = np.array(_bond_pair_idx(mol.GetBonds())).T.tolist()
619
+ pair_dict: Dict[int, List[int]] = {}
620
+ for key, i in enumerate(pair_idx[0]):
621
+ if i not in pair_dict:
622
+ pair_dict[i] = [pair_idx[1][key]]
623
+ else:
624
+ pair_dict[i].append(pair_idx[1][key])
625
+ g = Graph(n, adjacency_dict=pair_dict)
626
+ cl = canon_label(g) # type: list
627
+ symbols = np.array([[s] for s in symbols])[cl].flatten().tolist()
628
+ coordinates = coordinates[cl]
629
+ return symbols, coordinates
630
+
631
+ @staticmethod
632
+ def cartesian2spherical(coordinates: np.ndarray) -> np.ndarray:
633
+ """
634
+ Transforming Cartesian coordinates to spherical form.\n
635
+ The method is adapted from the paper: https://arxiv.org/abs/2408.10120.
636
+
637
+ :param coordinates: Cartesian coordinates; shape: (n_a, 3)
638
+ :type coordinates: numpy.ndarray
639
+ :return: spherical coordinates; shape: (n_a, 3)
640
+ :rtype: numpy.ndarray
641
+ """
642
+ n = coordinates.shape[0]
643
+ if n == 1:
644
+ return np.array([[0.0, 0.0, 0.0]])
645
+ # ------- Find global coordinate frame -------
646
+ if n == 2:
647
+ d = np.linalg.norm(coordinates[0] - coordinates[1], 2)
648
+ return np.array([[0.0, 0.0, 0.0], [d, 0.0, 0.0]])
649
+ for idx_0 in range(n - 2):
650
+ _vec0 = coordinates[idx_0] - coordinates[idx_0 + 1]
651
+ _vec1 = coordinates[idx_0] - coordinates[idx_0 + 2]
652
+ _d1 = np.linalg.norm(_vec0, 2)
653
+ _d2 = np.linalg.norm(_vec1, 2)
654
+ if 1 - np.abs(np.dot(_vec0, _vec1) / (_d1 * _d2)) > 1e-6:
655
+ break
656
+ x = (coordinates[idx_0 + 1] - coordinates[idx_0]) / _d1
657
+ y = np.cross((coordinates[idx_0 + 2] - coordinates[idx_0]), x)
658
+ y_d = np.linalg.norm(y, 2)
659
+ y = y / np.ma.filled(np.ma.array(y_d, mask=y_d == 0), np.inf)
660
+ z = np.cross(x, y)
661
+ # ------- Build spherical coordinates -------
662
+ vec = coordinates - coordinates[idx_0]
663
+ d = np.linalg.norm(vec, 2, axis=-1)
664
+ _d = np.ma.filled(np.ma.array(d, mask=d == 0), np.inf)
665
+ theta = np.arccos(np.dot(vec, z) / _d) # in [0, \pi]
666
+ phi = np.arctan2(np.dot(vec, y), np.dot(vec, x)) # in [-\pi, \pi]
667
+ info = np.vstack([d, theta, phi]).T
668
+ info[idx_0] = np.zeros_like(info[idx_0])
669
+ return info
670
+
671
+ def geo2seq(
672
+ self, symbols: List[str], coordinates: np.ndarray, decimals: int = 2
673
+ ) -> str:
674
+ """
675
+ Geometry-to-sequence function.\n
676
+ The algorithm follows the descriptions in paper: https://arxiv.org/abs/2408.10120.
677
+
678
+ :param symbols: a list of atomic symbols
679
+ :param coordinates: Cartesian coordinates; shape: (n_a, 3)
680
+ :param decimals: the maximum number of decimals to keep; default is 2
681
+ :type symbols: list
682
+ :type coordinates: numpy.ndarray
683
+ :type decimals: int
684
+ :return: `Geo2Seq` string
685
+ :rtype: str
686
+ """
687
+ symbols, coordinates = self.canonicalise(symbols, coordinates)
688
+ info = self.cartesian2spherical(coordinates)
689
+ info = [
690
+ f"{symbols[i]} {r[0]} {r[1]} {r[2]}"
691
+ for i, r in enumerate(np.round(info, decimals))
692
+ ]
693
+ return " ".join(info)
694
+
695
+ @staticmethod
696
+ def seq2geo(seq: str) -> Tuple[Optional[List[str]], Optional[np.ndarray]]:
697
+ """
698
+ Sequence-to-geometry function.\n
699
+ The method follows the descriptions in paper: https://arxiv.org/abs/2408.10120.
700
+
701
+ :param seq: `Geo2Seq` string
702
+ :type seq: str
703
+ :return: (symbols, coordinates) if `seq` is valid
704
+ :rtype: tuple
705
+ """
706
+ tokens = seq.split()
707
+ if len(tokens) % 4 != 0:
708
+ return None, None
709
+ tokens = np.array(tokens).reshape(-1, 4)
710
+ symbols, coordinates = tokens[::, 0], tokens[::, 1:]
711
+ if sum([len(_atom_regex.findall(sym)) for sym in symbols]) != len(symbols):
712
+ return None, None
713
+ try:
714
+ coord = [[float(i) for i in j] for j in coordinates]
715
+ coord = np.array(coord)
716
+ except ValueError:
717
+ return None, None
718
+ d, theta, phi = coord[::, 0, None], coord[::, 1, None], coord[::, 2, None]
719
+ x = d * np.sin(theta) * np.cos(phi)
720
+ y = d * np.sin(theta) * np.sin(phi)
721
+ z = d * np.cos(theta)
722
+ return symbols, np.concatenate([x, y, z], -1)
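A short, illustrative round trip with the new `GeometryConverter` (the water-like geometry is made up; `canonicalise`, and therefore `geo2seq`, requires `pynauty`, while `cartesian2smiles` relies on RDKit's bond perception):

```python
import numpy as np
from bayesianflow_for_chem.tool import GeometryConverter

converter = GeometryConverter()
symbols = ["O", "H", "H"]                           # toy molecule
coords = np.array([[0.00, 0.00, 0.00],
                   [0.96, 0.00, 0.00],
                   [-0.24, 0.93, 0.00]])

seq = converter.geo2seq(symbols, coords, decimals=2)     # "Geo2Seq" string
symbols_back, coords_back = converter.seq2geo(seq)       # (None, None) if seq is invalid
smiles = converter.cartesian2smiles(symbols, coords, charge=0)
```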
@@ -37,7 +37,8 @@ class Model(LightningModule):
37
37
  """
38
38
  A `~lightning.LightningModule` wrapper of bayesian flow network for chemistry model.\n
39
39
  This module is used in training stage only. By calling `Model(...).export_model(YOUR_WORK_DIR)` after training,
40
- the model(s) will be saved to `YOUR_WORK_DIR/model.pt` and (if exists) `YOUR_WORK_DIR/mlp.pt`.
40
+ the model(s) will be saved to `YOUR_WORK_DIR/model.pt` (if LoRA is enabled then `YOUR_WORK_DIR/lora.pt`)
41
+ and (if exists) `YOUR_WORK_DIR/mlp.pt`.
41
42
 
42
43
  :param model: `~bayesianflow_for_chem.model.ChemBFN` instance.
43
44
  :param mlp: `~bayesianflow_for_chem.model.MLP` instance or `None`.
@@ -135,7 +136,8 @@ class Regressor(LightningModule):
135
136
  """
136
137
  A `~lightning.LightningModule` wrapper of bayesian flow network for chemistry regression model.\n
137
138
  This module is used in training stage only. By calling `Regressor(...).export_model(YOUR_WORK_DIR)` after training,
138
- the models will be saved to `YOUR_WORK_DIR/model.pt` and `YOUR_WORK_DIR/readout.pt`.
139
+ the models will be saved to `YOUR_WORK_DIR/model_ft.pt` (if LoRA is enabled then `YOUR_WORK_DIR/lora.pt`)
140
+ and `YOUR_WORK_DIR/readout.pt`.
139
141
 
140
142
  :param model: `~bayesianflow_for_chem.model.ChemBFN` instance.
141
143
  :param mlp: `~bayesianflow_for_chem.model.MLP` instance.
@@ -218,7 +220,7 @@ class Regressor(LightningModule):
218
220
  """
219
221
  Save the trained model.
220
222
 
221
- :param workdir: the directory to save the model
223
+ :param workdir: the directory to save the models
222
224
  :type workdir: pathlib.Path
223
225
  :return:
224
226
  :rtype: None
@@ -1,16 +1,15 @@
1
- Metadata-Version: 2.2
1
+ Metadata-Version: 2.4
2
2
  Name: bayesianflow_for_chem
3
- Version: 1.2.6
3
+ Version: 1.3.0
4
4
  Summary: Bayesian flow network framework for Chemistry
5
5
  Home-page: https://augus1999.github.io/bayesian-flow-network-for-chemistry/
6
6
  Author: Nianze A. Tao
7
7
  Author-email: tao-nianze@hiroshima-u.ac.jp
8
- License: AGPL-3.0 licence
8
+ License: AGPL-3.0-or-later
9
9
  Project-URL: Source, https://github.com/Augus1999/bayesian-flow-network-for-chemistry
10
10
  Keywords: Chemistry,CLM,ChemBFN
11
11
  Classifier: Development Status :: 5 - Production/Stable
12
12
  Classifier: Intended Audience :: Science/Research
13
- Classifier: License :: OSI Approved :: GNU Affero General Public License v3
14
13
  Classifier: Natural Language :: English
15
14
  Classifier: Programming Language :: Python :: 3
16
15
  Classifier: Programming Language :: Python :: 3.9
@@ -39,6 +38,7 @@ Dynamic: description-content-type
39
38
  Dynamic: home-page
40
39
  Dynamic: keywords
41
40
  Dynamic: license
41
+ Dynamic: license-file
42
42
  Dynamic: project-url
43
43
  Dynamic: provides-extra
44
44
  Dynamic: requires-dist
@@ -86,7 +86,7 @@ You can find example scripts in [📁example](./example) folder.
86
86
 
87
87
  ## Pre-trained Model
88
88
 
89
- You can find pretrained models in [release](https://github.com/Augus1999/bayesian-flow-network-for-chemistry/releases) or on our [🤗Hugging Face model page](https://huggingface.co/suenoomozawa/ChemBFN).
89
+ You can find pretrained models on our [🤗Hugging Face model page](https://huggingface.co/suenoomozawa/ChemBFN).
90
90
 
91
91
  ## Dataset Handling
92
92
 
@@ -1,3 +1,3 @@
1
1
  [build-system]
2
- requires = ["setuptools >= 75.0"]
2
+ requires = ["setuptools >= 77.0.3"]
3
3
  build-backend = "setuptools.build_meta"
@@ -28,7 +28,8 @@ setup(
28
28
  description="Bayesian flow network framework for Chemistry",
29
29
  long_description=long_description,
30
30
  long_description_content_type="text/markdown",
31
- license="AGPL-3.0 licence",
31
+ license="AGPL-3.0-or-later",
32
+ license_files=["LICEN[CS]E*"],
32
33
  package_dir={"bayesianflow_for_chem": "bayesianflow_for_chem"},
33
34
  package_data={"bayesianflow_for_chem": ["./*.txt", "./*.py"]},
34
35
  include_package_data=True,
@@ -52,7 +53,6 @@ setup(
52
53
  classifiers=[
53
54
  "Development Status :: 5 - Production/Stable",
54
55
  "Intended Audience :: Science/Research",
55
- "License :: OSI Approved :: GNU Affero General Public License v3",
56
56
  "Natural Language :: English",
57
57
  "Programming Language :: Python :: 3",
58
58
  "Programming Language :: Python :: 3.9",