bayesianflow-for-chem 1.2.7__py3-none-any.whl → 1.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of bayesianflow-for-chem might be problematic. See the package registry's advisory page for this release for more details.

@@ -4,8 +4,8 @@
4
4
  ChemBFN package.
5
5
  """
6
6
  from . import data, tool, train, scorer
7
- from .model import ChemBFN, MLP
7
+ from .model import ChemBFN, MLP, EnsembleChemBFN
8
8
 
9
- __all__ = ["data", "tool", "train", "scorer", "ChemBFN", "MLP"]
10
- __version__ = "1.2.7"
9
+ __all__ = ["data", "tool", "train", "scorer", "ChemBFN", "MLP", "EnsembleChemBFN"]
10
+ __version__ = "1.4.0"
11
11
  __author__ = "Nianze A. Tao (Omozawa Sueno)"
@@ -1,7 +1,7 @@
1
1
  # -*- coding: utf-8 -*-
2
2
  # Author: Nianze A. TAO (Omozawa SUENO)
3
3
  """
4
- Tokenise SMILES/SAFE/SELFIES/GEO2SEQ/protein-sequence strings.
4
+ Tokenise SMILES/SAFE/SELFIES/protein-sequence strings.
5
5
  """
6
6
  import os
7
7
  import re
@@ -32,30 +32,14 @@ SMI_REGEX_PATTERN = (
32
32
  r"~|@|\?|>>?|\*|\$|\%[0-9]{2}|[0-9])"
33
33
  )
34
34
  SEL_REGEX_PATTERN = r"(\[[^\]]+]|\.)"
35
- GEO_REGEX_PATTERN = (
36
- r"(H[e,f,g,s,o]?|"
37
- r"L[i,v,a,r,u]|"
38
- r"B[e,r,a,i,h,k]?|"
39
- r"C[l,a,r,o,u,d,s,n,e,m,f]?|"
40
- r"N[e,a,i,b,h,d,o,p]?|"
41
- r"O[s,g]?|S[i,c,e,r,n,m,b,g]?|"
42
- r"K[r]?|T[i,c,e,a,l,b,h,m,s]|"
43
- r"G[a,e,d]|R[b,u,h,e,n,a,f,g]|"
44
- r"Yb?|Z[n,r]|P[t,o,d,r,a,u,b,m]?|"
45
- r"F[e,r,l,m]?|M[g,n,o,t,c,d]|"
46
- r"A[l,r,s,g,u,t,c,m]|I[n,r]?|"
47
- r"W|X[e]|E[u,r,s]|U|D[b,s,y]|"
48
- r"-|.| |[0-9])"
49
- )
50
35
  AA_REGEX_PATTERN = r"(A|B|C|D|E|F|G|H|I|K|L|M|N|P|Q|R|S|T|V|W|Y|Z|-|.)"
51
36
  smi_regex = re.compile(SMI_REGEX_PATTERN)
52
37
  sel_regex = re.compile(SEL_REGEX_PATTERN)
53
- geo_regex = re.compile(GEO_REGEX_PATTERN)
54
38
  aa_regex = re.compile(AA_REGEX_PATTERN)
55
39
 
56
40
 
57
41
  def load_vocab(
58
- vocab_file: Union[str, Path]
42
+ vocab_file: Union[str, Path],
59
43
  ) -> Dict[str, Union[int, List[str], Dict[str, int]]]:
60
44
  """
61
45
  Load vocabulary from source file.
@@ -86,9 +70,6 @@ AA_VOCAB_KEYS = (
86
70
  )
87
71
  AA_VOCAB_COUNT = len(AA_VOCAB_KEYS)
88
72
  AA_VOCAB_DICT = dict(zip(AA_VOCAB_KEYS, range(AA_VOCAB_COUNT)))
89
- GEO_VOCAB_KEYS = VOCAB_KEYS[0:3] + [" "] + VOCAB_KEYS[22:150] + [".", "-"]
90
- GEO_VOCAB_COUNT = len(GEO_VOCAB_KEYS)
91
- GEO_VOCAB_DICT = dict(zip(GEO_VOCAB_KEYS, range(GEO_VOCAB_COUNT)))
92
73
 
93
74
 
94
75
  def smiles2vec(smiles: str) -> List[int]:
@@ -104,19 +85,6 @@ def smiles2vec(smiles: str) -> List[int]:
104
85
  return [VOCAB_DICT[token] for token in tokens]
105
86
 
106
87
 
107
- def geo2vec(geo2seq: str) -> List[int]:
108
- """
109
- Geo2Seq tokenisation using a dataset-independent regex pattern.
110
-
111
- :param geo2seq: Geo2Seq string
112
- :type geo2seq: str
113
- :return: tokens w/o `<start>` and `<end>`
114
- :rtype: list
115
- """
116
- tokens = [token for token in geo_regex.findall(geo2seq)]
117
- return [GEO_VOCAB_DICT[token] for token in tokens]
118
-
119
-
120
88
  def aa2vec(aa_seq: str) -> List[int]:
121
89
  """
122
90
  Protein sequence tokenisation using a dataset-independent regex pattern.
@@ -147,11 +115,6 @@ def smiles2token(smiles: str) -> Tensor:
147
115
  return torch.tensor([1] + smiles2vec(smiles) + [2], dtype=torch.long)
148
116
 
149
117
 
150
- def geo2token(geo2seq: str) -> Tensor:
151
- # start token: <start> = 1; end token: <esc> = 2
152
- return torch.tensor([1] + geo2vec(geo2seq) + [2], dtype=torch.long)
153
-
154
-
155
118
  def aa2token(aa_seq: str) -> Tensor:
156
119
  # start token: <start> = 1; end token: <end> = 2
157
120
  return torch.tensor([1] + aa2vec(aa_seq) + [2], dtype=torch.long)
@@ -4,7 +4,8 @@
4
4
  Define Bayesian Flow Network for Chemistry (ChemBFN) model.
5
5
  """
6
6
  from pathlib import Path
7
- from typing import List, Tuple, Optional, Union
7
+ from copy import deepcopy
8
+ from typing import List, Tuple, Dict, Optional, Union, Callable
8
9
  import torch
9
10
  import torch.nn as nn
10
11
  from torch import Tensor
@@ -161,8 +162,8 @@ class Attention(nn.Module):
161
162
  :return: attentioned output; shape: (n_b, n_t, n_f)
162
163
  :rtype: torch.Tensor
163
164
  """
164
- n_b, n_a, _ = shape = x.shape
165
- split = (n_b, n_a, self.nh, self.d)
165
+ n_b, n_t, _ = shape = x.shape
166
+ split = (n_b, n_t, self.nh, self.d)
166
167
  q, k, v = self.qkv(x).chunk(3, -1)
167
168
  q = q.view(split).permute(2, 0, 1, 3).contiguous()
168
169
  k = k.view(split).permute(2, 0, 1, 3).contiguous()
@@ -427,12 +428,12 @@ class ChemBFN(nn.Module):
427
428
  c = self.time_embed(t)
428
429
  if y is not None:
429
430
  c += y
430
- pe = self.position(x.shape[1])
431
+ pe = self.position(n_t)
431
432
  x = self.embedding(x)
432
433
  attn_mask: Optional[Tensor] = None
433
434
  if self.semi_autoregressive:
434
435
  attn_mask = torch.tril(
435
- torch.ones((1, n_b, n_t, n_t), device=self.beta.device), diagonal=0
436
+ torch.ones((1, n_b, n_t, n_t), device=x.device), diagonal=0
436
437
  )
437
438
  else:
438
439
  if mask is not None:
@@ -592,6 +593,13 @@ class ChemBFN(nn.Module):
592
593
  x, logits = torch.broadcast_tensors(x[..., None], logits)
593
594
  return (-logits.gather(-1, x[..., :1]).squeeze(-1)).mean()
594
595
 
596
+ @staticmethod
597
+ def reshape_y(y: Tensor) -> Tensor:
598
+ assert y.dim() <= 3 # this doesn't work if the model is frezen in JIT.
599
+ if y.dim() == 2:
600
+ return y[:, None, :]
601
+ return y
602
+
595
603
  @torch.jit.export
596
604
  def sample(
597
605
  self,
@@ -607,7 +615,7 @@ class ChemBFN(nn.Module):
607
615
 
608
616
  :param batch_size: batch size
609
617
  :param sequence_size: max sequence length
610
- :param y: conditioning vector; shape: (n_b, 1, n_f)
618
+ :param y: conditioning vector; shape: (n_b, 1, n_f) or (n_b, n_f)
611
619
  :param sample_step: number of sampling steps
612
620
  :param guidance_strength: strength of conditional generation. It is not used if y is null.
613
621
  :param token_mask: token mask; shape: (1, 1, n_vocab)
@@ -626,9 +634,7 @@ class ChemBFN(nn.Module):
626
634
  / self.K
627
635
  )
628
636
  if y is not None:
629
- assert y.dim() == 3 # this doesn't work if the model is frezen in JIT.
630
- if y.shape[0] == 1:
631
- y = y.repeat(batch_size, 1, 1)
637
+ y = self.reshape_y(y)
632
638
  for i in torch.linspace(1, sample_step, sample_step, device=self.beta.device):
633
639
  t = (i - 1).view(1, 1, 1).repeat(batch_size, 1, 1) / sample_step
634
640
  p = self.discrete_output_distribution(theta, t, y, guidance_strength)
@@ -663,7 +669,7 @@ class ChemBFN(nn.Module):
663
669
 
664
670
  :param batch_size: batch size
665
671
  :param sequence_size: max sequence length
666
- :param y: conditioning vector; shape: (n_b, 1, n_f)
672
+ :param y: conditioning vector; shape: (n_b, 1, n_f) or (n_b, n_f)
667
673
  :param sample_step: number of sampling steps
668
674
  :param guidance_strength: strength of conditional generation. It is not used if y is null.
669
675
  :param token_mask: token mask; shape: (1, 1, n_vocab)
@@ -681,9 +687,7 @@ class ChemBFN(nn.Module):
681
687
  """
682
688
  z = torch.zeros((batch_size, sequence_size, self.K), device=self.beta.device)
683
689
  if y is not None:
684
- assert y.dim() == 3 # this doesn't work if the model is frezen in JIT.
685
- if y.shape[0] == 1:
686
- y = y.repeat(batch_size, 1, 1)
690
+ y = self.reshape_y(y)
687
691
  for i in torch.linspace(1, sample_step, sample_step, device=self.beta.device):
688
692
  t = (i - 1).view(1, 1, 1).repeat(batch_size, 1, 1) / sample_step
689
693
  theta = torch.softmax(z, -1)
@@ -714,7 +718,7 @@ class ChemBFN(nn.Module):
714
718
  Molecule inpaint functionality.
715
719
 
716
720
  :param x: categorical indices of scaffold; shape: (n_b, n_t)
717
- :param y: conditioning vector; shape: (n_b, 1, n_f)
721
+ :param y: conditioning vector; shape: (n_b, 1, n_f) or (n_b, n_f)
718
722
  :param sample_step: number of sampling steps
719
723
  :param guidance_strength: strength of conditional generation. It is not used if y is null.
720
724
  :param token_mask: token mask; shape: (1, 1, n_vocab)
@@ -733,9 +737,7 @@ class ChemBFN(nn.Module):
733
737
  x_onehot = nn.functional.one_hot(x, self.K) * mask
734
738
  theta = x_onehot + (1 - mask) * theta
735
739
  if y is not None:
736
- assert y.dim() == 3 # this doesn't work if the model is frezen in JIT.
737
- if y.shape[0] == 1:
738
- y = y.repeat(n_b, 1, 1)
740
+ y = self.reshape_y(y)
739
741
  for i in torch.linspace(1, sample_step, sample_step, device=x.device):
740
742
  t = (i - 1).view(1, 1, 1).repeat(n_b, 1, 1) / sample_step
741
743
  p = self.discrete_output_distribution(theta, t, y, guidance_strength)
@@ -769,7 +771,7 @@ class ChemBFN(nn.Module):
769
771
  ODE inpainting.
770
772
 
771
773
  :param x: categorical indices of scaffold; shape: (n_b, n_t)
772
- :param y: conditioning vector; shape: (n_b, 1, n_f)
774
+ :param y: conditioning vector; shape: (n_b, 1, n_f) or (n_b, n_f)
773
775
  :param sample_step: number of sampling steps
774
776
  :param guidance_strength: strength of conditional generation. It is not used if y is null.
775
777
  :param token_mask: token mask; shape: (1, 1, n_vocab)
@@ -789,9 +791,7 @@ class ChemBFN(nn.Module):
789
791
  x_onehot = nn.functional.one_hot(x, self.K) * mask
790
792
  z = torch.zeros((n_b, n_t, self.K), device=self.beta.device)
791
793
  if y is not None:
792
- assert y.dim() == 3 # this doesn't work if the model is frezen in JIT.
793
- if y.shape[0] == 1:
794
- y = y.repeat(n_b, 1, 1)
794
+ y = self.reshape_y(y)
795
795
  for i in torch.linspace(1, sample_step, sample_step, device=self.beta.device):
796
796
  t = (i - 1).view(1, 1, 1).repeat(n_b, 1, 1) / sample_step
797
797
  theta = torch.softmax(z, -1)
@@ -847,13 +847,7 @@ class ChemBFN(nn.Module):
847
847
  with open(ckpt, "rb") as f:
848
848
  state = torch.load(f, "cpu", weights_only=True)
849
849
  nn, hparam = state["nn"], state["hparam"]
850
- model = cls(
851
- hparam["num_vocab"],
852
- hparam["channel"],
853
- hparam["num_layer"],
854
- hparam["num_head"],
855
- hparam["dropout"],
856
- )
850
+ model = cls(**hparam)
857
851
  model.load_state_dict(nn, False)
858
852
  if ckpt_lora:
859
853
  with open(ckpt_lora, "rb") as g:
@@ -908,7 +902,7 @@ class MLP(nn.Module):
908
902
  if self.class_input:
909
903
  x = x.to(dtype=torch.long)
910
904
  for layer in self.layers[:-1]:
911
- x = torch.selu(layer(x))
905
+ x = torch.selu(layer.forward(x))
912
906
  return self.layers[-1](x)
913
907
 
914
908
  @classmethod
@@ -926,10 +920,382 @@ class MLP(nn.Module):
926
920
  with open(ckpt, "rb") as f:
927
921
  state = torch.load(f, "cpu", weights_only=True)
928
922
  nn, hparam = state["nn"], state["hparam"]
929
- model = cls(hparam["size"], hparam["class_input"], hparam["dropout"])
923
+ model = cls(**hparam)
930
924
  model.load_state_dict(nn, strict)
931
925
  return model
932
926
 
933
927
 
928
+ class EnsembleChemBFN(ChemBFN):
929
+ """
930
+ This module does not fully support `torch.jit.script`. We have `EnsembleChemBFN.jit()`
931
+ method to JIT compile the submodels.
932
+ `torch.compile()` is a better choice to compiling the whole model.
933
+ """
934
+
935
+ def __init__(
936
+ self,
937
+ base_model_path: Union[str, Path],
938
+ lora_paths: Union[List[Union[str, Path]], Dict[str, Union[str, Path]]],
939
+ cond_heads: Union[List[nn.Module], Dict[str, nn.Module]],
940
+ adapter_weights: Optional[Union[List[float], Dict[str, float]]] = None,
941
+ semi_autoregressive_flags: Optional[Union[List[bool], Dict[str, bool]]] = None,
942
+ ) -> None:
943
+ """
944
+ Ensemble of ChemBFN models from LoRA checkpoints.
945
+
946
+ :param base_model_path: base model checkpoint file
947
+ :param lora_paths: a list of LoRA checkpoint files or a `dict` instance of these files
948
+ :param cond_heads: a list of conditioning network heads or a `dict` instance of these networks
949
+ :param adapter_weights: a list of weights of each LoRA finetuned model or a 'dict` instance of these weights; default is equally weighted
950
+ :param semi_autoregressive_flags: a list of the semi-autoregressive behaviour states of each LoRA finetuned model or a `dict` instance of these states; default is all `False`
951
+ :type base_model_path: str | pathlib.Path
952
+ :type lora_paths: list | dict
953
+ :type cond_heads: list | dict
954
+ :type adapter_weights: list | dict | None
955
+ :type semi_autoregressive_flags: list | dict | None
956
+ """
957
+ n = len(lora_paths)
958
+ assert type(lora_paths) == type(
959
+ cond_heads
960
+ ), "`lora_paths` and `cond_heads` should have the same type!"
961
+ assert n == len(
962
+ cond_heads
963
+ ), "`lora_paths` and `cond_heads` should have the same length!"
964
+ if adapter_weights:
965
+ assert type(lora_paths) == type(
966
+ adapter_weights
967
+ ), "`lora_paths` and `adapter_weights` should have the same type!"
968
+ assert n == len(
969
+ adapter_weights
970
+ ), "`lora_paths` and `adapter_weights` should have the same length!"
971
+ if semi_autoregressive_flags:
972
+ assert type(lora_paths) == type(
973
+ semi_autoregressive_flags
974
+ ), "`lora_paths` and `semi_autoregressive_flags` should have the same type!"
975
+ assert n == len(
976
+ semi_autoregressive_flags
977
+ ), "`lora_paths` and `semi_autoregressive_flags` should have the same length!"
978
+ _label_is_dict = isinstance(lora_paths, dict)
979
+ if isinstance(lora_paths, list):
980
+ names = tuple(f"val_{i}" for i in range(n))
981
+ lora_paths = dict(zip(names, lora_paths))
982
+ cond_heads = dict(zip(names, cond_heads))
983
+ if not adapter_weights:
984
+ adapter_weights = (1 / n for _ in names)
985
+ if not semi_autoregressive_flags:
986
+ semi_autoregressive_flags = (False for _ in names)
987
+ adapter_weights = dict(zip(names, adapter_weights))
988
+ semi_autoregressive_flags = dict(zip(names, semi_autoregressive_flags))
989
+ else:
990
+ names = tuple(lora_paths.keys())
991
+ if not adapter_weights:
992
+ adapter_weights = dict(zip(names, (1 / n for _ in names)))
993
+ if not semi_autoregressive_flags:
994
+ semi_autoregressive_flags = dict(zip(names, (False for _ in names)))
995
+ base_model = ChemBFN.from_checkpoint(base_model_path)
996
+ models = dict(zip(names, (deepcopy(base_model.eval()) for _ in names)))
997
+ for k in names:
998
+ with open(lora_paths[k], "rb") as f:
999
+ state = torch.load(f, "cpu", weights_only=True)
1000
+ lora_nn, lora_param = state["lora_nn"], state["lora_param"]
1001
+ models[k].enable_lora(**lora_param)
1002
+ models[k].load_state_dict(lora_nn, False)
1003
+ models[k].semi_autoregressive = semi_autoregressive_flags[k]
1004
+ super().__init__(**base_model.hparam)
1005
+ self.cond_heads = nn.ModuleDict(cond_heads)
1006
+ self.models = nn.ModuleDict(models)
1007
+ self.adapter_weights = adapter_weights
1008
+ self._label_is_dict = _label_is_dict # flag
1009
+ # ------- remove unnecessary submodules -------
1010
+ self.embedding = None
1011
+ self.time_embed = None
1012
+ self.position = None
1013
+ self.encoder_layers = None
1014
+ self.final_layer = None
1015
+ self.__delattr__("embedding")
1016
+ self.__delattr__("time_embed")
1017
+ self.__delattr__("position")
1018
+ self.__delattr__("encoder_layers")
1019
+ self.__delattr__("final_layer")
1020
+ # ------- remove unused attributes -------
1021
+ self.__delattr__("semi_autoregressive")
1022
+ self.__delattr__("lora_enabled")
1023
+ self.__delattr__("lora_param")
1024
+ self.__delattr__("hparam")
1025
+
1026
+ def construct_y(
1027
+ self, c: Union[List[Tensor], Dict[str, Tensor]]
1028
+ ) -> Dict[str, Tensor]:
1029
+ assert (
1030
+ isinstance(c, dict) is self._label_is_dict
1031
+ ), f"`c` should be a {'`dict` instance' if self._label_is_dict else '`list` instance'} but got {type(c)} instand."
1032
+ out: Dict[str, Tensor] = {}
1033
+ if isinstance(c, list):
1034
+ c = dict(zip([f"val_{i}" for i in range(len(c))], c))
1035
+ for name, model in self.cond_heads.items():
1036
+ y = model.forward(c[name])
1037
+ if y.dim() == 2:
1038
+ y = y[:, None, :]
1039
+ out[name] = y
1040
+ return out
1041
+
1042
+ def discrete_output_distribution(
1043
+ self, theta: Tensor, t: Tensor, y: Dict[str, Tensor], w: float
1044
+ ) -> Tensor:
1045
+ """
1046
+ :param theta: input distribution; shape: (n_b, n_t, n_vocab)
1047
+ :param t: continuous time in [0, 1]; shape: (n_b, 1, 1)
1048
+ :param y: a dict of conditioning vectors; shape: (n_b, 1, n_f) * n_h
1049
+ :param w: guidance strength controlling the conditional generation
1050
+ :type theta: torch.Tensor
1051
+ :type t: torch.Tensor
1052
+ :type y: dict
1053
+ :type w: float
1054
+ :return: output distribution; shape: (n_b, n_t, n_vocab)
1055
+ :rtype: torch.Tensor
1056
+ """
1057
+ theta = 2 * theta - 1 # rescale to [-1, 1]
1058
+ p_uncond, p_cond = torch.zeros_like(theta), torch.zeros_like(theta)
1059
+ # Q: Why not use `torch.vmap`? It's faster than doing the loop, isn't it?
1060
+ #
1061
+ # A: We have quite a few reasons to avoid using `vmap`:
1062
+ # 1. JIT doesn't support vmap;
1063
+ # 2. It's harder to switch on/off semi-autroregssive behaviours for individual
1064
+ # models when all models are stacked into one (we have a solution but it's not
1065
+ # that elegant);
1066
+ # 3. We just found that the result from vmap was not identical to doing the loop;
1067
+ # 4. vmap requires all models have the same size but it's not always that case
1068
+ # since we sometimes use different ranks of LoRA in finetuning.
1069
+ for name, model in self.models.items():
1070
+ p_uncond_ = model.forward(theta, t, None, None)
1071
+ p_uncond += p_uncond_ * self.adapter_weights[name]
1072
+ p_cond_ = model.forward(theta, t, None, y[name])
1073
+ p_cond += p_cond_ * self.adapter_weights[name]
1074
+ return softmax((1 + w) * p_cond - w * p_uncond, -1)
1075
+
1076
+ @staticmethod
1077
+ def reshape_y(y: Dict[str, Tensor]) -> Dict[str, Tensor]:
1078
+ for k in y:
1079
+ assert y[k].dim() <= 3
1080
+ if y[k].dim() == 2:
1081
+ y[k] = y[k][:, None, :]
1082
+ return y
1083
+
1084
+ @torch.inference_mode()
1085
+ def sample(
1086
+ self,
1087
+ batch_size: int,
1088
+ sequence_size: int,
1089
+ conditions: Union[List[Tensor], Dict[str, Tensor]],
1090
+ sample_step: int = 100,
1091
+ guidance_strength: float = 4.0,
1092
+ token_mask: Optional[Tensor] = None,
1093
+ ) -> Tuple[Tensor, Tensor]:
1094
+ """
1095
+ Sample from a piror distribution.
1096
+
1097
+ :param batch_size: batch size
1098
+ :param sequence_size: max sequence length
1099
+ :param conditions: guidance conditions; shape: (n_b, n_c) * n_h
1100
+ :param sample_step: number of sampling steps
1101
+ :param guidance_strength: strength of conditional generation. It is not used if y is null.
1102
+ :param token_mask: token mask; shape: (1, 1, n_vocab)
1103
+ :type batch_size: int
1104
+ :type sequence_size: int
1105
+ :type conditions: list | dict
1106
+ :type sample_step: int
1107
+ :type guidance_strength: float
1108
+ :type token_mask: torch.Tensor | None
1109
+ :return: sampled token indices; shape: (n_b, n_t) \n
1110
+ entropy of the tokens; shape: (n_b)
1111
+ :rtype: tuple
1112
+ """
1113
+ y = self.construct_y(conditions)
1114
+ return super().sample(
1115
+ batch_size, sequence_size, y, sample_step, guidance_strength, token_mask
1116
+ )
1117
+
1118
+ @torch.inference_mode()
1119
+ def ode_sample(
1120
+ self,
1121
+ batch_size: int,
1122
+ sequence_size: int,
1123
+ conditions: Union[List[Tensor], Dict[str, Tensor]],
1124
+ sample_step: int = 100,
1125
+ guidance_strength: float = 4.0,
1126
+ token_mask: Optional[Tensor] = None,
1127
+ temperature: float = 0.5,
1128
+ ) -> Tuple[Tensor, Tensor]:
1129
+ """
1130
+ ODE-based sampling.
1131
+
1132
+ :param batch_size: batch size
1133
+ :param sequence_size: max sequence length
1134
+ :param conditions: conditioning vector; shape: (n_b, n_c) * n_h
1135
+ :param sample_step: number of sampling steps
1136
+ :param guidance_strength: strength of conditional generation. It is not used if y is null.
1137
+ :param token_mask: token mask; shape: (1, 1, n_vocab)
1138
+ :param temperature: sampling temperature
1139
+ :type batch_size: int
1140
+ :type sequence_size: int
1141
+ :type conditions: list | dict
1142
+ :type sample_step: int
1143
+ :type guidance_strength: float
1144
+ :type token_mask: torch.Tensor | None
1145
+ :type temperature: float
1146
+ :return: sampled token indices; shape: (n_b, n_t) \n
1147
+ entropy of the tokens; shape: (n_b)
1148
+ :rtype: tuple
1149
+ """
1150
+ y = self.construct_y(conditions)
1151
+ return super().ode_sample(
1152
+ batch_size,
1153
+ sequence_size,
1154
+ y,
1155
+ sample_step,
1156
+ guidance_strength,
1157
+ token_mask,
1158
+ temperature,
1159
+ )
1160
+
1161
+ @torch.inference_mode()
1162
+ def inpaint(
1163
+ self,
1164
+ x: Tensor,
1165
+ conditions: Union[List[Tensor], Dict[str, Tensor]],
1166
+ sample_step: int = 100,
1167
+ guidance_strength: float = 4.0,
1168
+ token_mask: Optional[Tensor] = None,
1169
+ ) -> Tuple[Tensor, Tensor]:
1170
+ """
1171
+ Molecule inpaint functionality.
1172
+
1173
+ :param x: categorical indices of scaffold; shape: (n_b, n_t)
1174
+ :param conditions: conditioning vector; shape: (n_b, n_c) * n_h
1175
+ :param sample_step: number of sampling steps
1176
+ :param guidance_strength: strength of conditional generation. It is not used if y is null.
1177
+ :param token_mask: token mask; shape: (1, 1, n_vocab)
1178
+ :type x: torch.Tensor
1179
+ :type conditions: list | dict
1180
+ :type sample_step: int
1181
+ :type guidance_strength: float
1182
+ :type token_mask: torch.Tensor | None
1183
+ :return: sampled token indices; shape: (n_b, n_t) \n
1184
+ entropy of the tokens; shape: (n_b)
1185
+ :rtype: tuple
1186
+ """
1187
+ y = self.construct_y(conditions)
1188
+ return super().inpaint(x, y, sample_step, guidance_strength, token_mask)
1189
+
1190
+ @torch.inference_mode()
1191
+ def ode_inpaint(
1192
+ self,
1193
+ x: Tensor,
1194
+ conditions: Union[List[Tensor], Dict[str, Tensor]],
1195
+ sample_step: int = 100,
1196
+ guidance_strength: float = 4.0,
1197
+ token_mask: Optional[Tensor] = None,
1198
+ temperature: float = 0.5,
1199
+ ) -> Tuple[Tensor, Tensor]:
1200
+ """
1201
+ ODE inpainting.
1202
+
1203
+ :param x: categorical indices of scaffold; shape: (n_b, n_t)
1204
+ :param conditions: conditioning vector; shape: (n_b, n_c) * n_h
1205
+ :param sample_step: number of sampling steps
1206
+ :param guidance_strength: strength of conditional generation. It is not used if y is null.
1207
+ :param token_mask: token mask; shape: (1, 1, n_vocab)
1208
+ :param temperature: sampling temperature
1209
+ :type x: torch.Tensor
1210
+ :type conditions: list | dict
1211
+ :type sample_step: int
1212
+ :type guidance_strength: float
1213
+ :type token_mask: torch.Tensor | None
1214
+ :type temperature: float
1215
+ :return: sampled token indices; shape: (n_b, n_t) \n
1216
+ entropy of the tokens; shape: (n_b)
1217
+ :rtype: tuple
1218
+ """
1219
+ y = self.construct_y(conditions)
1220
+ return super().ode_inpaint(
1221
+ x, y, sample_step, guidance_strength, token_mask, temperature
1222
+ )
1223
+
1224
+ def quantise(
1225
+ self, quantise_method: Optional[Callable[[ChemBFN], nn.Module]] = None
1226
+ ) -> None:
1227
+ """
1228
+ Quantise the submodels. \n
1229
+ This method should be called, if necessary, before `torch.compile()`.
1230
+
1231
+ :param quantise_method: quantisation method; default is `bayesianflow_for_chem.tool.quantise_model`
1232
+ :type quantise_method: callable | None
1233
+ :return:
1234
+ :rtype: None
1235
+ """
1236
+ if quantise_method is None:
1237
+ from bayesianflow_for_chem.tool import quantise_model
1238
+
1239
+ quantise_method = quantise_model
1240
+ for k, v in self.models.items():
1241
+ self.models[k] = quantise_method(v)
1242
+
1243
+ def jit(self, freeze: bool = False) -> None:
1244
+ """
1245
+ JIT compile the submodels. \n
1246
+ This method should be called, if necessary, before `quantise()` method is called if applied.
1247
+
1248
+ :param freeze: whether to freeze the submodels; default is `False`. If set to `True` this
1249
+ method should be called before moving the model to a different device.
1250
+ :type freeze: bool
1251
+ :return:
1252
+ :rtype: None
1253
+ """
1254
+ for k, v in self.models.items():
1255
+ self.models[k] = torch.jit.script(v)
1256
+ if freeze:
1257
+ self.models[k] = torch.jit.freeze(
1258
+ self.models[k], ["semi_autoregressive"]
1259
+ )
1260
+
1261
+ @torch.jit.ignore
1262
+ def forward(self, *_, **__) -> None:
1263
+ """
1264
+ Don't use this method!
1265
+ """
1266
+ raise NotImplementedError("There's nothing here!")
1267
+
1268
+ def cts_loss(self, *_, **__) -> None:
1269
+ """
1270
+ Don't use this method!
1271
+ """
1272
+ raise NotImplementedError("There's nothing here!")
1273
+
1274
+ def reconstruction_loss(self, *_, **__) -> None:
1275
+ """
1276
+ Don't use this method!
1277
+ """
1278
+ raise NotImplementedError("There's nothing here!")
1279
+
1280
+ def enable_lora(self, *_, **__) -> None:
1281
+ """
1282
+ Don't use this method!
1283
+ """
1284
+ raise NotImplementedError("There's nothing here!")
1285
+
1286
+ def inference(self, *_, **__) -> None:
1287
+ """
1288
+ Don't use this method!
1289
+ """
1290
+ raise NotImplementedError("There's nothing here!")
1291
+
1292
+ @classmethod
1293
+ def from_checkpoint(cls, *_, **__) -> None:
1294
+ """
1295
+ Don't use this method!
1296
+ """
1297
+ raise NotImplementedError("There's nothing here!")
1298
+
1299
+
934
1300
  if __name__ == "__main__":
935
1301
  ...
@@ -1,7 +1,7 @@
1
1
  # -*- coding: utf-8 -*-
2
2
  # Author: Nianze A. TAO (Omozawa SUENO)
3
3
  """
4
- Scorers.
4
+ Define essential scorers.
5
5
  """
6
6
  from typing import List, Callable, Union, Optional
7
7
  import torch
@@ -1,11 +1,11 @@
1
1
  # -*- coding: utf-8 -*-
2
2
  # Author: Nianze A. TAO (Omozawa SUENO)
3
3
  """
4
- Tools.
4
+ Essential tools.
5
5
  """
6
- import re
7
6
  import csv
8
7
  import random
8
+ import warnings
9
9
  from copy import deepcopy
10
10
  from pathlib import Path
11
11
  from typing import List, Dict, Tuple, Union, Optional
@@ -16,7 +16,16 @@ from torch import cuda, Tensor, softmax
16
16
  from torch.ao import quantization
17
17
  from torch.utils.data import DataLoader
18
18
  from typing_extensions import Self
19
- from rdkit.Chem import rdDetermineBonds, Bond, MolFromXYZBlock, CanonicalRankAtoms
19
+ from rdkit.Chem.rdchem import Mol, Bond
20
+ from rdkit.Chem import (
21
+ rdDetermineBonds,
22
+ MolFromXYZBlock,
23
+ MolFromSmiles,
24
+ MolToSmiles,
25
+ CanonSmiles,
26
+ AllChem,
27
+ AddHs,
28
+ )
20
29
  from rdkit.Chem.Scaffolds.MurckoScaffold import MurckoScaffoldSmiles # type: ignore
21
30
  from sklearn.metrics import (
22
31
  roc_auc_score,
@@ -26,35 +35,8 @@ from sklearn.metrics import (
26
35
  mean_absolute_error,
27
36
  root_mean_squared_error,
28
37
  )
29
-
30
- try:
31
- from pynauty import Graph, canon_label # type: ignore
32
-
33
- _use_pynauty = True
34
- except ImportError:
35
- import warnings
36
-
37
- _use_pynauty = False
38
-
39
38
  from .data import VOCAB_KEYS
40
- from .model import ChemBFN, MLP, Linear
41
-
42
-
43
- _atom_regex_pattern = (
44
- r"(H[e,f,g,s,o]?|"
45
- r"L[i,v,a,r,u]|"
46
- r"B[e,r,a,i,h,k]?|"
47
- r"C[l,a,r,o,u,d,s,n,e,m,f]?|"
48
- r"N[e,a,i,b,h,d,o,p]?|"
49
- r"O[s,g]?|S[i,c,e,r,n,m,b,g]?|"
50
- r"K[r]?|T[i,c,e,a,l,b,h,m,s]|"
51
- r"G[a,e,d]|R[b,u,h,e,n,a,f,g]|"
52
- r"Yb?|Z[n,r]|P[t,o,d,r,a,u,b,m]?|"
53
- r"F[e,r,l,m]?|M[g,n,o,t,c,d]|"
54
- r"A[l,r,s,g,u,t,c,m]|I[n,r]?|"
55
- r"W|X[e]|E[u,r,s]|U|D[b,s,y])"
56
- )
57
- _atom_regex = re.compile(_atom_regex_pattern)
39
+ from .model import ChemBFN, MLP, Linear, EnsembleChemBFN
58
40
 
59
41
 
60
42
  def _find_device() -> torch.device:
@@ -65,10 +47,6 @@ def _find_device() -> torch.device:
65
47
  return torch.device("cpu")
66
48
 
67
49
 
68
- def _bond_pair_idx(bonds: Bond) -> List[List[int]]:
69
- return [[i.GetBeginAtomIdx(), i.GetEndAtomIdx()] for i in bonds]
70
-
71
-
72
50
  @torch.no_grad()
73
51
  def test(
74
52
  model: ChemBFN,
@@ -196,11 +174,14 @@ def split_dataset(
196
174
  "\033[0m",
197
175
  stacklevel=2,
198
176
  )
199
- scaffold = MurckoScaffoldSmiles(d[smiles_idx[0]])
200
- if scaffold in scaffolds:
201
- scaffolds[scaffold].append(key)
202
- else:
203
- scaffolds[scaffold] = [key]
177
+ try:
178
+ scaffold = MurckoScaffoldSmiles(d[smiles_idx[0]])
179
+ if scaffold in scaffolds:
180
+ scaffolds[scaffold].append(key)
181
+ else:
182
+ scaffolds[scaffold] = [key]
183
+ except ValueError: # do nothing when SMILES is not valid
184
+ ...
204
185
  scaffolds = {key: sorted(value) for key, value in scaffolds.items()}
205
186
  train_set, test_set, val_set = [], [], []
206
187
  for idxs in scaffolds.values():
@@ -222,137 +203,13 @@ def split_dataset(
222
203
  writer.writerows([header] + val_set)
223
204
 
224
205
 
225
- def geo2seq(
226
- symbols: List[str],
227
- coordinates: np.ndarray,
228
- decimals: int = 2,
229
- angle_unit: str = "degree",
230
- ) -> str:
231
- """
232
- Geometry-to-sequence function.\n
233
- The algorithm follows the descriptions in paper: https://arxiv.org/abs/2408.10120.
234
-
235
- :param symbols: a list of atomic symbols
236
- :param coordinates: Cartesian coordinates; shape: (n_a, 3)
237
- :param decimals: number of decimal places to round to
238
- :param angle_unit: `'degree'` or `'radian'`
239
- :type symbols: list
240
- :type coordinates: numpy.ndarray
241
- :type decimals: int
242
- :type angle_unit: str
243
- :return: `Geo2Seq` string
244
- :rtype: str
245
- """
246
- assert angle_unit in ("degree", "radian")
247
- angle_scale = 180 / np.pi if angle_unit == "degree" else 1.0
248
- n = len(symbols)
249
- if n == 1:
250
- return f"{symbols[0]} {'0.0'} {'0.0'} {'0.0'}"
251
- xyz_block = [str(n), ""]
252
- for i, atom in enumerate(symbols):
253
- xyz_block.append(
254
- f"{atom} {'%.10f' % coordinates[i][0].item()} {'%.10f' % coordinates[i][1].item()} {'%.10f' % coordinates[i][2].item()}"
255
- )
256
- mol = MolFromXYZBlock("\n".join(xyz_block))
257
- rdDetermineBonds.DetermineConnectivity(mol)
258
- # ------- Canonicalization -------
259
- if _use_pynauty:
260
- pair_idx = np.array(_bond_pair_idx(mol.GetBonds())).T.tolist()
261
- pair_dict: Dict[int, List[int]] = {}
262
- for key, i in enumerate(pair_idx[0]):
263
- if i not in pair_dict:
264
- pair_dict[i] = [pair_idx[1][key]]
265
- else:
266
- pair_dict[i].append(pair_idx[1][key])
267
- g = Graph(n, adjacency_dict=pair_dict)
268
- cl = canon_label(g) # type: list
269
- else:
270
- warnings.warn(
271
- "\033[32;1m"
272
- "`pynauty` is not installed."
273
- " Switched to canonicalization function provided by `rdkit`."
274
- " This is the expected behaviour only if you are working on Windows platform."
275
- "\033[0m",
276
- stacklevel=2,
277
- )
278
- cl = list(CanonicalRankAtoms(mol, breakTies=True))
279
- symbols = np.array([[s] for s in symbols])[cl].flatten().tolist()
280
- coordinates = coordinates[cl]
281
- # ------- Find global coordinate frame -------
282
- if n == 2:
283
- d = np.round(np.linalg.norm(coordinates[0] - coordinates[1], 2), decimals)
284
- return f"{symbols[0]} {'0.0'} {'0.0'} {'0.0'} {symbols[1]} {d} {'0.0'} {'0.0'}"
285
- for idx_0 in range(n - 2):
286
- _vec0 = coordinates[idx_0] - coordinates[idx_0 + 1]
287
- _vec1 = coordinates[idx_0] - coordinates[idx_0 + 2]
288
- _d1 = np.linalg.norm(_vec0, 2)
289
- _d2 = np.linalg.norm(_vec1, 2)
290
- if 1 - np.abs(np.dot(_vec0, _vec1) / (_d1 * _d2)) > 1e-6:
291
- break
292
- x = (coordinates[idx_0 + 1] - coordinates[idx_0]) / _d1
293
- y = np.cross((coordinates[idx_0 + 2] - coordinates[idx_0]), x)
294
- y_d = np.linalg.norm(y, 2)
295
- y = y / np.ma.filled(np.ma.array(y_d, mask=y_d == 0), np.inf)
296
- z = np.cross(x, y)
297
- # ------- Build spherical coordinates -------
298
- vec = coordinates - coordinates[idx_0]
299
- d = np.linalg.norm(vec, 2, axis=-1)
300
- _d = np.ma.filled(np.ma.array(d, mask=d == 0), np.inf)
301
- theta = angle_scale * np.arccos(np.dot(vec, z) / _d) # in [0, \pi]
302
- phi = angle_scale * np.arctan2(np.dot(vec, y), np.dot(vec, x)) # in [-\pi, \pi]
303
- info = np.vstack([d, theta, phi]).T
304
- info[idx_0] = np.zeros(3)
305
- info = [
306
- f"{symbols[i]} {r[0]} {r[1]} {r[2]}"
307
- for i, r in enumerate(np.round(info, decimals))
308
- ]
309
- return " ".join(info)
310
-
311
-
312
- def seq2geo(
313
- seq: str, angle_unit: str = "degree"
314
- ) -> Optional[Tuple[List[str], List[List[float]]]]:
315
- """
316
- Sequence-to-geometry function.\n
317
- The method follows the descriptions in paper: https://arxiv.org/abs/2408.10120.
318
-
319
- :param seq: `Geo2Seq` string
320
- :param angle_unit: `'degree'` or `'radian'`
321
- :type seq: str
322
- :type angle_unit: str
323
- :return: (symbols, coordinates) if `seq` is valid
324
- :rtype: tuple | None
325
- """
326
- assert angle_unit in ("degree", "radian")
327
- angle_scale = np.pi / 180 if angle_unit == "degree" else 1.0
328
- tokens = seq.split()
329
- if len(tokens) % 4 == 0:
330
- tokens = np.array(tokens).reshape(-1, 4).tolist()
331
- symbols, coordinates = [], []
332
- for i in tokens:
333
- symbol = i[0]
334
- if len(_atom_regex.findall(symbol)) != 1:
335
- return None
336
- symbols.append(symbol)
337
- try:
338
- d, theta, phi = float(i[1]), float(i[2]), float(i[3])
339
- x = d * np.sin(theta * angle_scale) * np.cos(phi * angle_scale)
340
- y = d * np.sin(theta * angle_scale) * np.sin(phi * angle_scale)
341
- z = d * np.cos(theta * angle_scale)
342
- coordinates.append([x.item(), y.item(), z.item()])
343
- except ValueError:
344
- return None
345
- return symbols, coordinates
346
- return None
347
-
348
-
349
206
  @torch.no_grad()
350
207
  def sample(
351
- model: ChemBFN,
208
+ model: Union[ChemBFN, EnsembleChemBFN],
352
209
  batch_size: int,
353
210
  sequence_size: int,
354
211
  sample_step: int = 100,
355
- y: Optional[Tensor] = None,
212
+ y: Optional[Union[Tensor, Dict[str, Tensor], List[Tensor]]] = None,
356
213
  guidance_strength: float = 4.0,
357
214
  device: Union[str, torch.device, None] = None,
358
215
  vocab_keys: List[str] = VOCAB_KEYS,
@@ -368,7 +225,9 @@ def sample(
368
225
  :param batch_size: batch size
369
226
  :param sequence_size: max sequence length
370
227
  :param sample_step: number of sampling steps
371
- :param y: conditioning vector; shape: (n_b, 1, n_f)
228
+ :param y: conditioning vector; shape: (n_b, 1, n_f) or (n_b, n_f) \n
229
+ or a list/`dict` of conditions; shape: (n_b, n_c) * n_h
230
+
372
231
  :param guidance_strength: strength of conditional generation. It is not used if y is null.
373
232
  :param device: hardware accelerator
374
233
  :param vocab_keys: a list of (ordered) vocabulary
@@ -376,11 +235,11 @@ def sample(
376
235
  :param method: sampling method chosen from `"ODE:x"` or `"BFN"` where `x` is the value of sampling temperature; default is `"BFN"`
377
236
  :param allowed_tokens: a list of allowed tokens; default is `"all"`
378
237
  :param sort: whether to sort the samples according to entropy values; default is `False`
379
- :type model: bayesianflow_for_chem.model.ChemBFN
238
+ :type model: bayesianflow_for_chem.model.ChemBFN | bayesianflow_for_chem.model.EnsembleChemBFN
380
239
  :type batch_size: int
381
240
  :type sequence_size: int
382
241
  :type sample_step: int
383
- :type y: torch.Tensor | None
242
+ :type y: torch.Tensor | list | dict | None
384
243
  :type guidance_strength: float
385
244
  :type device: str | torch.device | None
386
245
  :type vocab_keys: list
@@ -392,11 +251,23 @@ def sample(
392
251
  :rtype: list
393
252
  """
394
253
  assert method.split(":")[0].lower() in ("ode", "bfn")
254
+ if isinstance(model, EnsembleChemBFN):
255
+ assert y is not None, "conditioning is required while using an ensemble model."
256
+ assert isinstance(y, list) or isinstance(y, dict)
257
+ else:
258
+ assert isinstance(y, Tensor) or y is None
395
259
  if device is None:
396
260
  device = _find_device()
397
261
  model.to(device).eval()
398
262
  if y is not None:
399
- y = y.to(device)
263
+ if isinstance(y, Tensor):
264
+ y = y.to(device)
265
+ elif isinstance(y, list):
266
+ y = [i.to(device) for i in y]
267
+ elif isinstance(y, dict):
268
+ y = {k: v.to(device) for k, v in y.items()}
269
+ else:
270
+ raise NotImplementedError
400
271
  if isinstance(allowed_tokens, list):
401
272
  token_mask = [0 if i in allowed_tokens else 1 for i in vocab_keys]
402
273
  token_mask = torch.tensor([[token_mask]], dtype=torch.bool).to(device)
@@ -426,10 +297,10 @@ def sample(
426
297
 
427
298
  @torch.no_grad()
428
299
  def inpaint(
429
- model: ChemBFN,
300
+ model: Union[ChemBFN, EnsembleChemBFN],
430
301
  x: Tensor,
431
302
  sample_step: int = 100,
432
- y: Optional[Tensor] = None,
303
+ y: Optional[Union[Tensor, Dict[str, Tensor], List[Tensor]]] = None,
433
304
  guidance_strength: float = 4.0,
434
305
  device: Union[str, torch.device, None] = None,
435
306
  vocab_keys: List[str] = VOCAB_KEYS,
@@ -444,7 +315,9 @@ def inpaint(
444
315
  :param model: trained ChemBFN model
445
316
  :param x: categorical indices of scaffold; shape: (n_b, n_t)
446
317
  :param sample_step: number of sampling steps
447
- :param y: conditioning vector; shape: (n_b, 1, n_f)
318
+ :param y: conditioning vector; shape: (n_b, 1, n_f) or (n_b, n_f) \n
319
+ or a list/`dict` of conditions; shape: (n_b, n_c) * n_h
320
+
448
321
  :param guidance_strength: strength of conditional generation. It is not used if y is null.
449
322
  :param device: hardware accelerator
450
323
  :param vocab_keys: a list of (ordered) vocabulary
@@ -452,10 +325,10 @@ def inpaint(
452
325
  :param method: sampling method chosen from `"ODE:x"` or `"BFN"` where `x` is the value of sampling temperature; default is `"BFN"`
453
326
  :param allowed_tokens: a list of allowed tokens; default is `"all"`
454
327
  :param sort: whether to sort the samples according to entropy values; default is `False`
455
- :type model: bayesianflow_for_chem.model.ChemBFN
328
+ :type model: bayesianflow_for_chem.model.ChemBFN | bayesianflow_for_chem.model.EnsembleChemBFN
456
329
  :type x: torch.Tensor
457
330
  :type sample_step: int
458
- :type y: torch.Tensor | None
331
+ :type y: torch.Tensor | list | dict | None
459
332
  :type guidance_strength: float
460
333
  :type device: str | torch.device | None
461
334
  :type vocab_keys: list
@@ -467,12 +340,24 @@ def inpaint(
467
340
  :rtype: list
468
341
  """
469
342
  assert method.split(":")[0].lower() in ("ode", "bfn")
343
+ if isinstance(model, EnsembleChemBFN):
344
+ assert y is not None, "conditioning is required while using an ensemble model."
345
+ assert isinstance(y, list) or isinstance(y, dict)
346
+ else:
347
+ assert isinstance(y, Tensor) or y is None
470
348
  if device is None:
471
349
  device = _find_device()
472
350
  model.to(device).eval()
473
351
  x = x.to(device)
474
352
  if y is not None:
475
- y = y.to(device)
353
+ if isinstance(y, Tensor):
354
+ y = y.to(device)
355
+ elif isinstance(y, list):
356
+ y = [i.to(device) for i in y]
357
+ elif isinstance(y, dict):
358
+ y = {k: v.to(device) for k, v in y.items()}
359
+ else:
360
+ raise NotImplementedError
476
361
  if isinstance(allowed_tokens, list):
477
362
  token_mask = [0 if i in allowed_tokens else 1 for i in vocab_keys]
478
363
  token_mask = torch.tensor([[token_mask]], dtype=torch.bool).to(device)
@@ -585,6 +470,8 @@ def quantise_model(model: ChemBFN) -> nn.Module:
585
470
  assert hasattr(
586
471
  mod, "qconfig"
587
472
  ), "Input float module must have qconfig defined"
473
+ if use_precomputed_fake_quant:
474
+ warnings.warn("Fake quantize operator is not implemented.")
588
475
  if mod.qconfig is not None and mod.qconfig.weight is not None:
589
476
  weight_observer = mod.qconfig.weight()
590
477
  else:
@@ -637,3 +524,81 @@ def quantise_model(model: ChemBFN) -> nn.Module:
637
524
  model, {nn.Linear, Linear}, torch.qint8, mapping
638
525
  )
639
526
  return quantised_model
527
+
528
+
529
+ class GeometryConverter:
530
+ """
531
+ Converting between different 2D/3D molecular representations.
532
+ """
533
+
534
+ @staticmethod
535
+ def _xyz2mol(symbols: List[str], coordinates: np.ndarray) -> Mol:
536
+ xyz_block = [str(len(symbols)), ""]
537
+ r = coordinates
538
+ for i, atom in enumerate(symbols):
539
+ xyz_block.append(f"{atom} {r[i][0]:.10f} {r[i][1]:.10f} {r[i][2]:.10f}")
540
+ return MolFromXYZBlock("\n".join(xyz_block))
541
+
542
+ @staticmethod
543
+ def _bond_pair_idx(bonds: Bond) -> List[List[int]]:
544
+ return [[i.GetBeginAtomIdx(), i.GetEndAtomIdx()] for i in bonds]
545
+
546
+ @staticmethod
547
+ def smiles2cartesian(
548
+ smiles: str, num_conformers: int = 50, random_seed: int = 42
549
+ ) -> Tuple[List[str], np.ndarray]:
550
+ """
551
+ Guess the 3D geometry from SMILES string via MMFF conformer search.
552
+
553
+ :param smiles: a valid SMILES string
554
+ :param num_conformers: number of initial conformers
555
+ :param random_seed: random seed used to generate conformers
556
+ :type smiles: str
557
+ :type num_conformers: int
558
+ :type random_seed: int
559
+ :return: atomic symbols \n
560
+ cartesian coordinates; shape: (n_a, 3)
561
+ :rtype: tuple
562
+ """
563
+ mol = MolFromSmiles(smiles)
564
+ mol = AddHs(mol)
565
+ AllChem.EmbedMultipleConfs(mol, numConfs=num_conformers, randomSeed=random_seed)
566
+ symbols = [atom.GetSymbol() for atom in mol.GetAtoms()]
567
+ energies = []
568
+ for conf_id in range(num_conformers):
569
+ ff = AllChem.MMFFGetMoleculeForceField(
570
+ mol, AllChem.MMFFGetMoleculeProperties(mol), confId=conf_id
571
+ )
572
+ energy = ff.CalcEnergy()
573
+ energies.append((conf_id, energy))
574
+ lowest_energy_conf = min(energies, key=lambda x: x[1])
575
+ coordinates = mol.GetConformer(id=lowest_energy_conf[0]).GetPositions()
576
+ return symbols, coordinates
577
+
578
+ def cartesian2smiles(
579
+ self,
580
+ symbols: List[str],
581
+ coordinates: np.ndarray,
582
+ charge: int = 0,
583
+ canonical: bool = True,
584
+ ) -> str:
585
+ """
586
+ Transform (guess out) molecular geometry to SMILES string.
587
+
588
+ :param symbols: a list of atomic symbols
589
+ :param coordinates: Cartesian coordinates; shape: (n_a, 3)
590
+ :param charge: net charge
591
+ :param canonical: whether to canonicalise the SMILES
592
+ :type symbols: list
593
+ :type coordinates: numpy.ndarray
594
+ :type charge: int
595
+ :type canonical: bool
596
+ :return: SMILES string
597
+ :rtype: str
598
+ """
599
+ mol = self._xyz2mol(symbols, coordinates)
600
+ rdDetermineBonds.DetermineBonds(mol, charge=charge)
601
+ smiles = MolToSmiles(mol)
602
+ if canonical:
603
+ smiles = CanonSmiles(smiles)
604
+ return smiles
@@ -37,7 +37,8 @@ class Model(LightningModule):
37
37
  """
38
38
  A `~lightning.LightningModule` wrapper of bayesian flow network for chemistry model.\n
39
39
  This module is used in training stage only. By calling `Model(...).export_model(YOUR_WORK_DIR)` after training,
40
- the model(s) will be saved to `YOUR_WORK_DIR/model.pt` and (if exists) `YOUR_WORK_DIR/mlp.pt`.
40
+ the model(s) will be saved to `YOUR_WORK_DIR/model.pt` (if LoRA is enabled then `YOUR_WORK_DIR/lora.pt`)
41
+ and (if exists) `YOUR_WORK_DIR/mlp.pt`.
41
42
 
42
43
  :param model: `~bayesianflow_for_chem.model.ChemBFN` instance.
43
44
  :param mlp: `~bayesianflow_for_chem.model.MLP` instance or `None`.
@@ -135,7 +136,8 @@ class Regressor(LightningModule):
135
136
  """
136
137
  A `~lightning.LightningModule` wrapper of bayesian flow network for chemistry regression model.\n
137
138
  This module is used in training stage only. By calling `Regressor(...).export_model(YOUR_WORK_DIR)` after training,
138
- the models will be saved to `YOUR_WORK_DIR/model.pt` and `YOUR_WORK_DIR/readout.pt`.
139
+ the models will be saved to `YOUR_WORK_DIR/model_ft.pt` (if LoRA is enabled then `YOUR_WORK_DIR/lora.pt`)
140
+ and `YOUR_WORK_DIR/readout.pt`.
139
141
 
140
142
  :param model: `~bayesianflow_for_chem.model.ChemBFN` instance.
141
143
  :param mlp: `~bayesianflow_for_chem.model.MLP` instance.
@@ -218,7 +220,7 @@ class Regressor(LightningModule):
218
220
  """
219
221
  Save the trained model.
220
222
 
221
- :param workdir: the directory to save the model
223
+ :param workdir: the directory to save the models
222
224
  :type workdir: pathlib.Path
223
225
  :return:
224
226
  :rtype: None
@@ -1,16 +1,15 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: bayesianflow_for_chem
3
- Version: 1.2.7
3
+ Version: 1.4.0
4
4
  Summary: Bayesian flow network framework for Chemistry
5
5
  Home-page: https://augus1999.github.io/bayesian-flow-network-for-chemistry/
6
6
  Author: Nianze A. Tao
7
7
  Author-email: tao-nianze@hiroshima-u.ac.jp
8
- License: AGPL-3.0 licence
8
+ License: AGPL-3.0-or-later
9
9
  Project-URL: Source, https://github.com/Augus1999/bayesian-flow-network-for-chemistry
10
10
  Keywords: Chemistry,CLM,ChemBFN
11
11
  Classifier: Development Status :: 5 - Production/Stable
12
12
  Classifier: Intended Audience :: Science/Research
13
- Classifier: License :: OSI Approved :: GNU Affero General Public License v3
14
13
  Classifier: Natural Language :: English
15
14
  Classifier: Programming Language :: Python :: 3
16
15
  Classifier: Programming Language :: Python :: 3.9
@@ -29,8 +28,6 @@ Requires-Dist: loralib>=0.1.2
29
28
  Requires-Dist: lightning>=2.2.0
30
29
  Requires-Dist: scikit-learn>=1.5.0
31
30
  Requires-Dist: typing_extensions>=4.8.0
32
- Provides-Extra: geo2seq
33
- Requires-Dist: pynauty>=2.8.8.1; extra == "geo2seq"
34
31
  Dynamic: author
35
32
  Dynamic: author-email
36
33
  Dynamic: classifier
@@ -41,7 +38,6 @@ Dynamic: keywords
41
38
  Dynamic: license
42
39
  Dynamic: license-file
43
40
  Dynamic: project-url
44
- Dynamic: provides-extra
45
41
  Dynamic: requires-dist
46
42
  Dynamic: requires-python
47
43
  Dynamic: summary
@@ -87,13 +83,13 @@ You can find example scripts in [📁example](./example) folder.
87
83
 
88
84
  ## Pre-trained Model
89
85
 
90
- You can find pretrained models in [release](https://github.com/Augus1999/bayesian-flow-network-for-chemistry/releases) or on our [🤗Hugging Face model page](https://huggingface.co/suenoomozawa/ChemBFN).
86
+ You can find pretrained models on our [🤗Hugging Face model page](https://huggingface.co/suenoomozawa/ChemBFN).
91
87
 
92
88
  ## Dataset Handling
93
89
 
94
90
  We provide a Python class [`CSVData`](./bayesianflow_for_chem/data.py) to handle data stored in CSV or similar format containing headers to identify the entities. The following is a quickstart.
95
91
 
96
- 1. Download your dataset file (e.g., ESOL form [MoleculeNet](https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/delaney-processed.csv)) and split the file:
92
+ 1. Download your dataset file (e.g., ESOL from [MoleculeNet](https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/delaney-processed.csv)) and split the file:
97
93
  ```python
98
94
  >>> from bayesianflow_for_chem.tool import split_data
99
95
 
@@ -0,0 +1,12 @@
1
+ bayesianflow_for_chem/__init__.py,sha256=3sP8nM4_idOX-ksvpBJEApxPAVAPijKvQHxidTO5790,329
2
+ bayesianflow_for_chem/data.py,sha256=WoOCOVmJX4WeHa2WeO4i66J2FS8rvRaYRCdlBN7ZeOM,6576
3
+ bayesianflow_for_chem/model.py,sha256=fUrXKhn2U9FrVPJyb4lqACqPTePkIgI0v6_1jPs5c0Q,50784
4
+ bayesianflow_for_chem/scorer.py,sha256=7G1TVSwC0qONtNm6kiDZUWwvuFPzasNSjp4eJAk5TL0,4101
5
+ bayesianflow_for_chem/tool.py,sha256=NMMRHk2FJY0fyA76zCrz6tkcylCuExMUMj5hohWTnkE,23155
6
+ bayesianflow_for_chem/train.py,sha256=hGKyhGhLch-exSYPZdLXrLn3gf39Q1VLSJs2qtuikQE,9709
7
+ bayesianflow_for_chem/vocab.txt,sha256=HgtAZmpWYk4y8PqEVC4vqut1vE75DfRKE_10s2UW0rU,790
8
+ bayesianflow_for_chem-1.4.0.dist-info/licenses/LICENSE,sha256=hIahDEOTzuHCU5J2nd07LWwkLW7Hko4UFO__ffsvB-8,34523
9
+ bayesianflow_for_chem-1.4.0.dist-info/METADATA,sha256=1Y5mLIOaPsHcyCCm2SkWz7OCniQYVJ67-cVq3cUU0Mw,5643
10
+ bayesianflow_for_chem-1.4.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
11
+ bayesianflow_for_chem-1.4.0.dist-info/top_level.txt,sha256=KHsanI3BMCt8D9Qpze2ycrF6nMa3PyojgO6eS1c8kco,22
12
+ bayesianflow_for_chem-1.4.0.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (80.3.1)
2
+ Generator: setuptools (80.9.0)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5
 
@@ -1,12 +0,0 @@
1
- bayesianflow_for_chem/__init__.py,sha256=xYC8F86oe8y40GGqzGGjbbjSXPK16Qci8XqDMjrbxK8,293
2
- bayesianflow_for_chem/data.py,sha256=9tpRba40lxwrB6aPSJMkxUglEVC3VEQC9wWxhDuz3Q8,7760
3
- bayesianflow_for_chem/model.py,sha256=HvEvW_xRbkv4eSv5lhd72BJMZkg-ZACEi1DAW3p5Q1Y,35918
4
- bayesianflow_for_chem/scorer.py,sha256=mV1vX8aBGFra2BE7N8WHihVIo3dXmUdPQIGfSaiuNdk,4084
5
- bayesianflow_for_chem/tool.py,sha256=yZrWCI3Zi6EHxo3zqCU_ebmzVECaco8Vbx-oTg-rHhg,24118
6
- bayesianflow_for_chem/train.py,sha256=kj6icGqymUUYopDtpre1oE_wpvpeNilbpzgffBsd1tk,9589
7
- bayesianflow_for_chem/vocab.txt,sha256=HgtAZmpWYk4y8PqEVC4vqut1vE75DfRKE_10s2UW0rU,790
8
- bayesianflow_for_chem-1.2.7.dist-info/licenses/LICENSE,sha256=hIahDEOTzuHCU5J2nd07LWwkLW7Hko4UFO__ffsvB-8,34523
9
- bayesianflow_for_chem-1.2.7.dist-info/METADATA,sha256=9v-CEHo1DJGmgwopQiQ68sFEaUZkHzFIhfiNTL2r6mc,5913
10
- bayesianflow_for_chem-1.2.7.dist-info/WHEEL,sha256=0CuiUZ_p9E4cD6NyLD6UG80LBXYyiSYZOKDm5lp32xk,91
11
- bayesianflow_for_chem-1.2.7.dist-info/top_level.txt,sha256=KHsanI3BMCt8D9Qpze2ycrF6nMa3PyojgO6eS1c8kco,22
12
- bayesianflow_for_chem-1.2.7.dist-info/RECORD,,