bayesianflow-for-chem 1.2.6-py3-none-any.whl → 1.3.0-py3-none-any.whl
- bayesianflow_for_chem/__init__.py +3 -3
- bayesianflow_for_chem/data.py +2 -2
- bayesianflow_for_chem/model.py +392 -26
- bayesianflow_for_chem/scorer.py +1 -1
- bayesianflow_for_chem/tool.py +240 -147
- bayesianflow_for_chem/train.py +5 -3
- {bayesianflow_for_chem-1.2.6.dist-info → bayesianflow_for_chem-1.3.0.dist-info}/METADATA +5 -5
- bayesianflow_for_chem-1.3.0.dist-info/RECORD +12 -0
- {bayesianflow_for_chem-1.2.6.dist-info → bayesianflow_for_chem-1.3.0.dist-info}/WHEEL +1 -1
- bayesianflow_for_chem-1.2.6.dist-info/RECORD +0 -12
- {bayesianflow_for_chem-1.2.6.dist-info → bayesianflow_for_chem-1.3.0.dist-info/licenses}/LICENSE +0 -0
- {bayesianflow_for_chem-1.2.6.dist-info → bayesianflow_for_chem-1.3.0.dist-info}/top_level.txt +0 -0
bayesianflow_for_chem/__init__.py
CHANGED

@@ -4,8 +4,8 @@
 ChemBFN package.
 """
 from . import data, tool, train, scorer
-from .model import ChemBFN, MLP
+from .model import ChemBFN, MLP, EnsembleChemBFN

-__all__ = ["data", "tool", "train", "scorer", "ChemBFN", "MLP"]
-__version__ = "1.2.6"
+__all__ = ["data", "tool", "train", "scorer", "ChemBFN", "MLP", "EnsembleChemBFN"]
+__version__ = "1.3.0"
 __author__ = "Nianze A. Tao (Omozawa Sueno)"
bayesianflow_for_chem/data.py
CHANGED
@@ -55,7 +55,7 @@ aa_regex = re.compile(AA_REGEX_PATTERN)


 def load_vocab(
-    vocab_file: Union[str, Path]
+    vocab_file: Union[str, Path],
 ) -> Dict[str, Union[int, List[str], Dict[str, int]]]:
     """
     Load vocabulary from source file.
@@ -108,7 +108,7 @@ def geo2vec(geo2seq: str) -> List[int]:
     """
     Geo2Seq tokenisation using a dataset-independent regex pattern.

-    :param geo2seq:
+    :param geo2seq: `GEO2SEQ` string
     :type geo2seq: str
     :return: tokens w/o `<start>` and `<end>`
     :rtype: list
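Note: `load_vocab` now explicitly accepts either a `str` or a `pathlib.Path`. A minimal sketch of reading the packaged vocabulary; the file location used here is an assumption for illustration, and the concrete keys of the returned mapping are not spelled out, only its annotated value types.

    from pathlib import Path
    from bayesianflow_for_chem.data import load_vocab

    # A str path would work just as well as a Path object.
    vocab = load_vocab(Path("bayesianflow_for_chem") / "vocab.txt")
    # Per the annotated return type, the values mix an int, a list of tokens,
    # and a token-to-index dict.
    print(type(vocab), list(vocab)[:3])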
bayesianflow_for_chem/model.py
CHANGED
@@ -4,7 +4,8 @@
 Define Bayesian Flow Network for Chemistry (ChemBFN) model.
 """
 from pathlib import Path
-from
+from copy import deepcopy
+from typing import List, Tuple, Dict, Optional, Union, Callable
 import torch
 import torch.nn as nn
 from torch import Tensor
@@ -592,6 +593,13 @@ class ChemBFN(nn.Module):
         x, logits = torch.broadcast_tensors(x[..., None], logits)
         return (-logits.gather(-1, x[..., :1]).squeeze(-1)).mean()

+    @staticmethod
+    def reshape_y(y: Tensor) -> Tensor:
+        assert y.dim() <= 3  # this doesn't work if the model is frozen in JIT.
+        if y.dim() == 2:
+            return y[:, None, :]
+        return y
+
     @torch.jit.export
     def sample(
         self,
@@ -607,7 +615,7 @@

         :param batch_size: batch size
         :param sequence_size: max sequence length
-        :param y: conditioning vector; shape: (n_b, 1, n_f)
+        :param y: conditioning vector; shape: (n_b, 1, n_f) or (n_b, n_f)
         :param sample_step: number of sampling steps
         :param guidance_strength: strength of conditional generation. It is not used if y is null.
         :param token_mask: token mask; shape: (1, 1, n_vocab)
@@ -626,9 +634,7 @@
             / self.K
         )
         if y is not None:
-
-            if y.shape[0] == 1:
-                y = y.repeat(batch_size, 1, 1)
+            y = self.reshape_y(y)
         for i in torch.linspace(1, sample_step, sample_step, device=self.beta.device):
             t = (i - 1).view(1, 1, 1).repeat(batch_size, 1, 1) / sample_step
             p = self.discrete_output_distribution(theta, t, y, guidance_strength)
@@ -663,7 +669,7 @@

         :param batch_size: batch size
         :param sequence_size: max sequence length
-        :param y: conditioning vector; shape: (n_b, 1, n_f)
+        :param y: conditioning vector; shape: (n_b, 1, n_f) or (n_b, n_f)
         :param sample_step: number of sampling steps
         :param guidance_strength: strength of conditional generation. It is not used if y is null.
         :param token_mask: token mask; shape: (1, 1, n_vocab)
@@ -681,9 +687,7 @@
         """
         z = torch.zeros((batch_size, sequence_size, self.K), device=self.beta.device)
         if y is not None:
-
-            if y.shape[0] == 1:
-                y = y.repeat(batch_size, 1, 1)
+            y = self.reshape_y(y)
         for i in torch.linspace(1, sample_step, sample_step, device=self.beta.device):
             t = (i - 1).view(1, 1, 1).repeat(batch_size, 1, 1) / sample_step
             theta = torch.softmax(z, -1)
@@ -714,7 +718,7 @@
         Molecule inpaint functionality.

         :param x: categorical indices of scaffold; shape: (n_b, n_t)
-        :param y: conditioning vector; shape: (n_b, 1, n_f)
+        :param y: conditioning vector; shape: (n_b, 1, n_f) or (n_b, n_f)
         :param sample_step: number of sampling steps
         :param guidance_strength: strength of conditional generation. It is not used if y is null.
         :param token_mask: token mask; shape: (1, 1, n_vocab)
@@ -733,9 +737,7 @@
         x_onehot = nn.functional.one_hot(x, self.K) * mask
         theta = x_onehot + (1 - mask) * theta
         if y is not None:
-
-            if y.shape[0] == 1:
-                y = y.repeat(n_b, 1, 1)
+            y = self.reshape_y(y)
         for i in torch.linspace(1, sample_step, sample_step, device=x.device):
             t = (i - 1).view(1, 1, 1).repeat(n_b, 1, 1) / sample_step
             p = self.discrete_output_distribution(theta, t, y, guidance_strength)
@@ -769,7 +771,7 @@
         ODE inpainting.

         :param x: categorical indices of scaffold; shape: (n_b, n_t)
-        :param y: conditioning vector; shape: (n_b, 1, n_f)
+        :param y: conditioning vector; shape: (n_b, 1, n_f) or (n_b, n_f)
         :param sample_step: number of sampling steps
         :param guidance_strength: strength of conditional generation. It is not used if y is null.
         :param token_mask: token mask; shape: (1, 1, n_vocab)
@@ -789,9 +791,7 @@
         x_onehot = nn.functional.one_hot(x, self.K) * mask
         z = torch.zeros((n_b, n_t, self.K), device=self.beta.device)
         if y is not None:
-
-            if y.shape[0] == 1:
-                y = y.repeat(n_b, 1, 1)
+            y = self.reshape_y(y)
         for i in torch.linspace(1, sample_step, sample_step, device=self.beta.device):
             t = (i - 1).view(1, 1, 1).repeat(n_b, 1, 1) / sample_step
             theta = torch.softmax(z, -1)
@@ -847,13 +847,7 @@
         with open(ckpt, "rb") as f:
             state = torch.load(f, "cpu", weights_only=True)
         nn, hparam = state["nn"], state["hparam"]
-        model = cls(
-            hparam["num_vocab"],
-            hparam["channel"],
-            hparam["num_layer"],
-            hparam["num_head"],
-            hparam["dropout"],
-        )
+        model = cls(**hparam)
         model.load_state_dict(nn, False)
         if ckpt_lora:
             with open(ckpt_lora, "rb") as g:
@@ -908,7 +902,7 @@ class MLP(nn.Module):
         if self.class_input:
             x = x.to(dtype=torch.long)
         for layer in self.layers[:-1]:
-            x = torch.selu(layer(x))
+            x = torch.selu(layer.forward(x))
         return self.layers[-1](x)

     @classmethod
@@ -926,10 +920,382 @@ class MLP(nn.Module):
         with open(ckpt, "rb") as f:
             state = torch.load(f, "cpu", weights_only=True)
         nn, hparam = state["nn"], state["hparam"]
-        model = cls(hparam
+        model = cls(**hparam)
         model.load_state_dict(nn, strict)
         return model


+class EnsembleChemBFN(ChemBFN):
+    """
+    This module does not fully support `torch.jit.script`. We provide the `EnsembleChemBFN.jit()`
+    method to JIT-compile the submodels.
+    `torch.compile()` is a better choice for compiling the whole model.
+    """
+
+    def __init__(
+        self,
+        base_model_path: Union[str, Path],
+        lora_paths: Union[List[Union[str, Path]], Dict[str, Union[str, Path]]],
+        cond_heads: Union[List[nn.Module], Dict[str, nn.Module]],
+        adapter_weights: Optional[Union[List[float], Dict[str, float]]] = None,
+        semi_autoregressive_flags: Optional[Union[List[bool], Dict[str, bool]]] = None,
+    ) -> None:
+        """
+        Ensemble of ChemBFN models from LoRA checkpoints.
+
+        :param base_model_path: base model checkpoint file
+        :param lora_paths: a list of LoRA checkpoint files or a `dict` instance of these files
+        :param cond_heads: a list of conditioning network heads or a `dict` instance of these networks
+        :param adapter_weights: a list of weights of each LoRA finetuned model or a `dict` instance of these weights; default is equally weighted
+        :param semi_autoregressive_flags: a list of the semi-autoregressive behaviour states of each LoRA finetuned model or a `dict` instance of these states; default is all `False`
+        :type base_model_path: str | pathlib.Path
+        :type lora_paths: list | dict
+        :type cond_heads: list | dict
+        :type adapter_weights: list | dict | None
+        :type semi_autoregressive_flags: list | dict | None
+        """
+        n = len(lora_paths)
+        assert type(lora_paths) == type(
+            cond_heads
+        ), "`lora_paths` and `cond_heads` should have the same type!"
+        assert n == len(
+            cond_heads
+        ), "`lora_paths` and `cond_heads` should have the same length!"
+        if adapter_weights:
+            assert type(lora_paths) == type(
+                adapter_weights
+            ), "`lora_paths` and `adapter_weights` should have the same type!"
+            assert n == len(
+                adapter_weights
+            ), "`lora_paths` and `adapter_weights` should have the same length!"
+        if semi_autoregressive_flags:
+            assert type(lora_paths) == type(
+                semi_autoregressive_flags
+            ), "`lora_paths` and `semi_autoregressive_flags` should have the same type!"
+            assert n == len(
+                semi_autoregressive_flags
+            ), "`lora_paths` and `semi_autoregressive_flags` should have the same length!"
+        _label_is_dict = isinstance(lora_paths, dict)
+        if isinstance(lora_paths, list):
+            names = tuple(f"val_{i}" for i in range(n))
+            lora_paths = dict(zip(names, lora_paths))
+            cond_heads = dict(zip(names, cond_heads))
+            if not adapter_weights:
+                adapter_weights = (1 / n for _ in names)
+            if not semi_autoregressive_flags:
+                semi_autoregressive_flags = (False for _ in names)
+            adapter_weights = dict(zip(names, adapter_weights))
+            semi_autoregressive_flags = dict(zip(names, semi_autoregressive_flags))
+        else:
+            names = tuple(lora_paths.keys())
+            if not adapter_weights:
+                adapter_weights = dict(zip(names, (1 / n for _ in names)))
+            if not semi_autoregressive_flags:
+                semi_autoregressive_flags = dict(zip(names, (False for _ in names)))
+        base_model = ChemBFN.from_checkpoint(base_model_path)
+        models = dict(zip(names, (deepcopy(base_model.eval()) for _ in names)))
+        for k in names:
+            with open(lora_paths[k], "rb") as f:
+                state = torch.load(f, "cpu", weights_only=True)
+            lora_nn, lora_param = state["lora_nn"], state["lora_param"]
+            models[k].enable_lora(**lora_param)
+            models[k].load_state_dict(lora_nn, False)
+            models[k].semi_autoregressive = semi_autoregressive_flags[k]
+        super().__init__(**base_model.hparam)
+        self.cond_heads = nn.ModuleDict(cond_heads)
+        self.models = nn.ModuleDict(models)
+        self.adapter_weights = adapter_weights
+        self._label_is_dict = _label_is_dict  # flag
+        # ------- remove unnecessary submodules -------
+        self.embedding = None
+        self.time_embed = None
+        self.position = None
+        self.encoder_layers = None
+        self.final_layer = None
+        self.__delattr__("embedding")
+        self.__delattr__("time_embed")
+        self.__delattr__("position")
+        self.__delattr__("encoder_layers")
+        self.__delattr__("final_layer")
+        # ------- remove unused attributes -------
+        self.__delattr__("semi_autoregressive")
+        self.__delattr__("lora_enabled")
+        self.__delattr__("lora_param")
+        self.__delattr__("hparam")
+
+    def construct_y(
+        self, c: Union[List[Tensor], Dict[str, Tensor]]
+    ) -> Dict[str, Tensor]:
+        assert (
+            isinstance(c, dict) is self._label_is_dict
+        ), f"`c` should be a {'`dict` instance' if self._label_is_dict else '`list` instance'} but got {type(c)} instead."
+        out: Dict[str, Tensor] = {}
+        if isinstance(c, list):
+            c = dict(zip([f"val_{i}" for i in range(len(c))], c))
+        for name, model in self.cond_heads.items():
+            y = model.forward(c[name])
+            if y.dim() == 2:
+                y = y[:, None, :]
+            out[name] = y
+        return out
+
+    def discrete_output_distribution(
+        self, theta: Tensor, t: Tensor, y: Dict[str, Tensor], w: float
+    ) -> Tensor:
+        """
+        :param theta: input distribution; shape: (n_b, n_t, n_vocab)
+        :param t: continuous time in [0, 1]; shape: (n_b, 1, 1)
+        :param y: a dict of conditioning vectors; shape: (n_b, 1, n_f) * n_h
+        :param w: guidance strength controlling the conditional generation
+        :type theta: torch.Tensor
+        :type t: torch.Tensor
+        :type y: dict
+        :type w: float
+        :return: output distribution; shape: (n_b, n_t, n_vocab)
+        :rtype: torch.Tensor
+        """
+        theta = 2 * theta - 1  # rescale to [-1, 1]
+        p_uncond, p_cond = torch.zeros_like(theta), torch.zeros_like(theta)
+        # Q: Why not use `torch.vmap`? It's faster than doing the loop, isn't it?
+        #
+        # A: We have quite a few reasons to avoid using `vmap`:
+        # 1. JIT doesn't support vmap;
+        # 2. It's harder to switch on/off semi-autoregressive behaviours for individual
+        #    models when all models are stacked into one (we have a solution but it's not
+        #    that elegant);
+        # 3. We just found that the result from vmap was not identical to doing the loop;
+        # 4. vmap requires all models to have the same size but that is not always the case
+        #    since we sometimes use different ranks of LoRA in finetuning.
+        for name, model in self.models.items():
+            p_uncond_ = model.forward(theta, t, None, None)
+            p_uncond += p_uncond_ * self.adapter_weights[name]
+            p_cond_ = model.forward(theta, t, None, y[name])
+            p_cond += p_cond_ * self.adapter_weights[name]
+        return softmax((1 + w) * p_cond - w * p_uncond, -1)
+
+    @staticmethod
+    def reshape_y(y: Dict[str, Tensor]) -> Dict[str, Tensor]:
+        for k in y:
+            assert y[k].dim() <= 3
+            if y[k].dim() == 2:
+                y[k] = y[k][:, None, :]
+        return y
+
+    @torch.inference_mode()
+    def sample(
+        self,
+        batch_size: int,
+        sequence_size: int,
+        conditions: Union[List[Tensor], Dict[str, Tensor]],
+        sample_step: int = 100,
+        guidance_strength: float = 4.0,
+        token_mask: Optional[Tensor] = None,
+    ) -> Tuple[Tensor, Tensor]:
+        """
+        Sample from a prior distribution.
+
+        :param batch_size: batch size
+        :param sequence_size: max sequence length
+        :param conditions: guidance conditions; shape: (n_b, n_c) * n_h
+        :param sample_step: number of sampling steps
+        :param guidance_strength: strength of conditional generation. It is not used if y is null.
+        :param token_mask: token mask; shape: (1, 1, n_vocab)
+        :type batch_size: int
+        :type sequence_size: int
+        :type conditions: list | dict
+        :type sample_step: int
+        :type guidance_strength: float
+        :type token_mask: torch.Tensor | None
+        :return: sampled token indices; shape: (n_b, n_t) \n
+            entropy of the tokens; shape: (n_b)
+        :rtype: tuple
+        """
+        y = self.construct_y(conditions)
+        return super().sample(
+            batch_size, sequence_size, y, sample_step, guidance_strength, token_mask
+        )
+
+    @torch.inference_mode()
+    def ode_sample(
+        self,
+        batch_size: int,
+        sequence_size: int,
+        conditions: Union[List[Tensor], Dict[str, Tensor]],
+        sample_step: int = 100,
+        guidance_strength: float = 4.0,
+        token_mask: Optional[Tensor] = None,
+        temperature: float = 0.5,
+    ) -> Tuple[Tensor, Tensor]:
+        """
+        ODE-based sampling.
+
+        :param batch_size: batch size
+        :param sequence_size: max sequence length
+        :param conditions: conditioning vector; shape: (n_b, n_c) * n_h
+        :param sample_step: number of sampling steps
+        :param guidance_strength: strength of conditional generation. It is not used if y is null.
+        :param token_mask: token mask; shape: (1, 1, n_vocab)
+        :param temperature: sampling temperature
+        :type batch_size: int
+        :type sequence_size: int
+        :type conditions: list | dict
+        :type sample_step: int
+        :type guidance_strength: float
+        :type token_mask: torch.Tensor | None
+        :type temperature: float
+        :return: sampled token indices; shape: (n_b, n_t) \n
+            entropy of the tokens; shape: (n_b)
+        :rtype: tuple
+        """
+        y = self.construct_y(conditions)
+        return super().ode_sample(
+            batch_size,
+            sequence_size,
+            y,
+            sample_step,
+            guidance_strength,
+            token_mask,
+            temperature,
+        )
+
+    @torch.inference_mode()
+    def inpaint(
+        self,
+        x: Tensor,
+        conditions: Union[List[Tensor], Dict[str, Tensor]],
+        sample_step: int = 100,
+        guidance_strength: float = 4.0,
+        token_mask: Optional[Tensor] = None,
+    ) -> Tuple[Tensor, Tensor]:
+        """
+        Molecule inpaint functionality.
+
+        :param x: categorical indices of scaffold; shape: (n_b, n_t)
+        :param conditions: conditioning vector; shape: (n_b, n_c) * n_h
+        :param sample_step: number of sampling steps
+        :param guidance_strength: strength of conditional generation. It is not used if y is null.
+        :param token_mask: token mask; shape: (1, 1, n_vocab)
+        :type x: torch.Tensor
+        :type conditions: list | dict
+        :type sample_step: int
+        :type guidance_strength: float
+        :type token_mask: torch.Tensor | None
+        :return: sampled token indices; shape: (n_b, n_t) \n
+            entropy of the tokens; shape: (n_b)
+        :rtype: tuple
+        """
+        y = self.construct_y(conditions)
+        return super().inpaint(x, y, sample_step, guidance_strength, token_mask)
+
+    @torch.inference_mode()
+    def ode_inpaint(
+        self,
+        x: Tensor,
+        conditions: Union[List[Tensor], Dict[str, Tensor]],
+        sample_step: int = 100,
+        guidance_strength: float = 4.0,
+        token_mask: Optional[Tensor] = None,
+        temperature: float = 0.5,
+    ) -> Tuple[Tensor, Tensor]:
+        """
+        ODE inpainting.
+
+        :param x: categorical indices of scaffold; shape: (n_b, n_t)
+        :param conditions: conditioning vector; shape: (n_b, n_c) * n_h
+        :param sample_step: number of sampling steps
+        :param guidance_strength: strength of conditional generation. It is not used if y is null.
+        :param token_mask: token mask; shape: (1, 1, n_vocab)
+        :param temperature: sampling temperature
+        :type x: torch.Tensor
+        :type conditions: list | dict
+        :type sample_step: int
+        :type guidance_strength: float
+        :type token_mask: torch.Tensor | None
+        :type temperature: float
+        :return: sampled token indices; shape: (n_b, n_t) \n
+            entropy of the tokens; shape: (n_b)
+        :rtype: tuple
+        """
+        y = self.construct_y(conditions)
+        return super().ode_inpaint(
+            x, y, sample_step, guidance_strength, token_mask, temperature
+        )
+
+    def quantise(
+        self, quantise_method: Optional[Callable[[ChemBFN], nn.Module]] = None
+    ) -> None:
+        """
+        Quantise the submodels. \n
+        This method should be called, if necessary, before `torch.compile()`.
+
+        :param quantise_method: quantisation method; default is `bayesianflow_for_chem.tool.quantise_model`
+        :type quantise_method: callable | None
+        :return:
+        :rtype: None
+        """
+        if quantise_method is None:
+            from bayesianflow_for_chem.tool import quantise_model
+
+            quantise_method = quantise_model
+        for k, v in self.models.items():
+            self.models[k] = quantise_method(v)
+
+    def jit(self, freeze: bool = False) -> None:
+        """
+        JIT compile the submodels. \n
+        If used together with `quantise()`, this method should be called first.
+
+        :param freeze: whether to freeze the submodels; default is `False`. If set to `True` this
+            method should be called before moving the model to a different device.
+        :type freeze: bool
+        :return:
+        :rtype: None
+        """
+        for k, v in self.models.items():
+            self.models[k] = torch.jit.script(v)
+            if freeze:
+                self.models[k] = torch.jit.freeze(
+                    self.models[k], ["semi_autoregressive"]
+                )
+
+    @torch.jit.ignore
+    def forward(self, *_, **__) -> None:
+        """
+        Don't use this method!
+        """
+        raise NotImplementedError("There's nothing here!")
+
+    def cts_loss(self, *_, **__) -> None:
+        """
+        Don't use this method!
+        """
+        raise NotImplementedError("There's nothing here!")
+
+    def reconstruction_loss(self, *_, **__) -> None:
+        """
+        Don't use this method!
+        """
+        raise NotImplementedError("There's nothing here!")
+
+    def enable_lora(self, *_, **__) -> None:
+        """
+        Don't use this method!
+        """
+        raise NotImplementedError("There's nothing here!")
+
+    def inference(self, *_, **__) -> None:
+        """
+        Don't use this method!
+        """
+        raise NotImplementedError("There's nothing here!")
+
+    @classmethod
+    def from_checkpoint(cls, *_, **__) -> None:
+        """
+        Don't use this method!
+        """
+        raise NotImplementedError("There's nothing here!")
+
+
 if __name__ == "__main__":
     ...
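Note: with the new `reshape_y` helper, a conditioning vector of shape (n_b, n_f) is accepted directly by `ChemBFN.sample`, `ode_sample`, `inpaint`, and `ode_inpaint`; the extra middle dimension is added internally. The `EnsembleChemBFN` class combines one base checkpoint with several LoRA adapters, each paired with its own conditioning head, and is sampled with per-head conditions. The sketch below is only illustrative: every file name, the property names "logp" and "qed", the use of `MLP.from_checkpoint` for the heads, and the 1-dimensional condition width are assumptions, not values taken from the package.

    import torch
    from bayesianflow_for_chem.model import MLP, EnsembleChemBFN

    model = EnsembleChemBFN(
        base_model_path="model.pt",                       # assumed base checkpoint
        lora_paths={"logp": "lora_logp.pt", "qed": "lora_qed.pt"},
        cond_heads={
            "logp": MLP.from_checkpoint("head_logp.pt"),  # assumed conditioning heads
            "qed": MLP.from_checkpoint("head_qed.pt"),
        },
        adapter_weights={"logp": 0.5, "qed": 0.5},
    )
    # One condition tensor per head, keyed by the same names; each row is one sample.
    conditions = {
        "logp": torch.full((16, 1), 2.5),
        "qed": torch.full((16, 1), 0.8),
    }
    tokens, entropy = model.sample(
        batch_size=16, sequence_size=64, conditions=conditions, sample_step=100
    )
    # Per the method docstrings, the optional compilation order is:
    # model.jit() first, then model.quantise(), and only then torch.compile(model).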
bayesianflow_for_chem/scorer.py
CHANGED
bayesianflow_for_chem/tool.py
CHANGED
@@ -1,11 +1,12 @@
 # -*- coding: utf-8 -*-
 # Author: Nianze A. TAO (Omozawa SUENO)
 """
-
+Essential tools.
 """
 import re
 import csv
 import random
+import warnings
 from copy import deepcopy
 from pathlib import Path
 from typing import List, Dict, Tuple, Union, Optional
@@ -16,7 +17,8 @@ from torch import cuda, Tensor, softmax
 from torch.ao import quantization
 from torch.utils.data import DataLoader
 from typing_extensions import Self
-from rdkit.Chem import
+from rdkit.Chem.rdchem import Mol, Bond
+from rdkit.Chem import rdDetermineBonds, MolFromXYZBlock, MolToSmiles, CanonSmiles
 from rdkit.Chem.Scaffolds.MurckoScaffold import MurckoScaffoldSmiles  # type: ignore
 from sklearn.metrics import (
     roc_auc_score,
@@ -32,12 +34,11 @@ try:

     _use_pynauty = True
 except ImportError:
-    import
+    import platform

     _use_pynauty = False
-
 from .data import VOCAB_KEYS
-from .model import ChemBFN, MLP, Linear
+from .model import ChemBFN, MLP, Linear, EnsembleChemBFN


 _atom_regex_pattern = (
@@ -161,6 +162,8 @@ def split_dataset(
     :return:
     :rtype: None
     """
+    if isinstance(file, Path):
+        file = file.__str__()
     assert file.endswith(".csv")
     assert len(split_ratio) == 3
     assert method in ("random", "scaffold")
@@ -170,7 +173,7 @@ def split_dataset(
     raw_data = data[1:]
     smiles_idx = []  # only first index will be used
     for key, h in enumerate(header):
-        if h.lower()
+        if "smiles" in h.lower():
             smiles_idx.append(key)
     assert len(smiles_idx) > 0
     data_len = len(raw_data)
@@ -186,11 +189,22 @@ def split_dataset(
         scaffolds: Dict[str, List] = {}
         for key, d in enumerate(raw_data):
             # compute Bemis-Murcko scaffold
-
-
-
-
-
+            if len(smiles_idx) > 1:
+                warnings.warn(
+                    "\033[32;1m"
+                    f"We found {len(smiles_idx)} SMILES strings in a row!"
+                    " Only the first SMILES will be used to compute the molecular scaffold."
+                    "\033[0m",
+                    stacklevel=2,
+                )
+            try:
+                scaffold = MurckoScaffoldSmiles(d[smiles_idx[0]])
+                if scaffold in scaffolds:
+                    scaffolds[scaffold].append(key)
+                else:
+                    scaffolds[scaffold] = [key]
+            except ValueError:  # do nothing when SMILES is not valid
+                ...
         scaffolds = {key: sorted(value) for key, value in scaffolds.items()}
         train_set, test_set, val_set = [], [], []
         for idxs in scaffolds.values():
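Note: the scaffold split now accepts a `pathlib.Path`, detects any column whose header contains "smiles" (case-insensitive), warns when several SMILES columns are present, and silently skips rows whose SMILES cannot be turned into a Bemis-Murcko scaffold. A minimal sketch of a scaffold split; the CSV file name, the keyword-argument calling convention, and the ratio format (three numbers) are assumptions inferred from the assertions above, not documented values.

    from pathlib import Path
    from bayesianflow_for_chem.tool import split_dataset

    # Writes train/test/validation CSV files next to the input (assumed behaviour).
    split_dataset(Path("esol.csv"), split_ratio=[8, 1, 1], method="scaffold")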
@@ -212,137 +226,13 @@ def split_dataset(
         writer.writerows([header] + val_set)


-def geo2seq(
-    symbols: List[str],
-    coordinates: np.ndarray,
-    decimals: int = 2,
-    angle_unit: str = "degree",
-) -> str:
-    """
-    Geometry-to-sequence function.\n
-    The algorithm follows the descriptions in paper: https://arxiv.org/abs/2408.10120.
-
-    :param symbols: a list of atomic symbols
-    :param coordinates: Cartesian coordinates; shape: (n_a, 3)
-    :param decimals: number of decimal places to round to
-    :param angle_unit: `'degree'` or `'radian'`
-    :type symbols: list
-    :type coordinates: numpy.ndarray
-    :type decimals: int
-    :type angle_unit: str
-    :return: `Geo2Seq` string
-    :rtype: str
-    """
-    assert angle_unit in ("degree", "radian")
-    angle_scale = 180 / np.pi if angle_unit == "degree" else 1.0
-    n = len(symbols)
-    if n == 1:
-        return f"{symbols[0]} {'0.0'} {'0.0'} {'0.0'}"
-    xyz_block = [str(n), ""]
-    for i, atom in enumerate(symbols):
-        xyz_block.append(
-            f"{atom} {'%.10f' % coordinates[i][0].item()} {'%.10f' % coordinates[i][1].item()} {'%.10f' % coordinates[i][2].item()}"
-        )
-    mol = MolFromXYZBlock("\n".join(xyz_block))
-    rdDetermineBonds.DetermineConnectivity(mol)
-    # ------- Canonicalization -------
-    if _use_pynauty:
-        pair_idx = np.array(_bond_pair_idx(mol.GetBonds())).T.tolist()
-        pair_dict: Dict[int, List[int]] = {}
-        for key, i in enumerate(pair_idx[0]):
-            if i not in pair_dict:
-                pair_dict[i] = [pair_idx[1][key]]
-            else:
-                pair_dict[i].append(pair_idx[1][key])
-        g = Graph(n, adjacency_dict=pair_dict)
-        cl = canon_label(g)  # type: list
-    else:
-        warnings.warn(
-            "\033[32;1m"
-            "`pynauty` is not installed."
-            " Switched to canonicalization function provided by `rdkit`."
-            " This is the expected behaviour only if you are working on Windows platform."
-            "\033[0m",
-            stacklevel=2,
-        )
-        cl = list(CanonicalRankAtoms(mol, breakTies=True))
-    symbols = np.array([[s] for s in symbols])[cl].flatten().tolist()
-    coordinates = coordinates[cl]
-    # ------- Find global coordinate frame -------
-    if n == 2:
-        d = np.round(np.linalg.norm(coordinates[0] - coordinates[1], 2), decimals)
-        return f"{symbols[0]} {'0.0'} {'0.0'} {'0.0'} {symbols[1]} {d} {'0.0'} {'0.0'}"
-    for idx_0 in range(n - 2):
-        _vec0 = coordinates[idx_0] - coordinates[idx_0 + 1]
-        _vec1 = coordinates[idx_0] - coordinates[idx_0 + 2]
-        _d1 = np.linalg.norm(_vec0, 2)
-        _d2 = np.linalg.norm(_vec1, 2)
-        if 1 - np.abs(np.dot(_vec0, _vec1) / (_d1 * _d2)) > 1e-6:
-            break
-    x = (coordinates[idx_0 + 1] - coordinates[idx_0]) / _d1
-    y = np.cross((coordinates[idx_0 + 2] - coordinates[idx_0]), x)
-    y_d = np.linalg.norm(y, 2)
-    y = y / np.ma.filled(np.ma.array(y_d, mask=y_d == 0), np.inf)
-    z = np.cross(x, y)
-    # ------- Build spherical coordinates -------
-    vec = coordinates - coordinates[idx_0]
-    d = np.linalg.norm(vec, 2, axis=-1)
-    _d = np.ma.filled(np.ma.array(d, mask=d == 0), np.inf)
-    theta = angle_scale * np.arccos(np.dot(vec, z) / _d)  # in [0, \pi]
-    phi = angle_scale * np.arctan2(np.dot(vec, y), np.dot(vec, x))  # in [-\pi, \pi]
-    info = np.vstack([d, theta, phi]).T
-    info[idx_0] = np.zeros(3)
-    info = [
-        f"{symbols[i]} {r[0]} {r[1]} {r[2]}"
-        for i, r in enumerate(np.round(info, decimals))
-    ]
-    return " ".join(info)
-
-
-def seq2geo(
-    seq: str, angle_unit: str = "degree"
-) -> Optional[Tuple[List[str], List[List[float]]]]:
-    """
-    Sequence-to-geometry function.\n
-    The method follows the descriptions in paper: https://arxiv.org/abs/2408.10120.
-
-    :param seq: `Geo2Seq` string
-    :param angle_unit: `'degree'` or `'radian'`
-    :type seq: str
-    :type angle_unit: str
-    :return: (symbols, coordinates) if `seq` is valid
-    :rtype: tuple | None
-    """
-    assert angle_unit in ("degree", "radian")
-    angle_scale = np.pi / 180 if angle_unit == "degree" else 1.0
-    tokens = seq.split()
-    if len(tokens) % 4 == 0:
-        tokens = np.array(tokens).reshape(-1, 4).tolist()
-        symbols, coordinates = [], []
-        for i in tokens:
-            symbol = i[0]
-            if len(_atom_regex.findall(symbol)) != 1:
-                return None
-            symbols.append(symbol)
-            try:
-                d, theta, phi = float(i[1]), float(i[2]), float(i[3])
-                x = d * np.sin(theta * angle_scale) * np.cos(phi * angle_scale)
-                y = d * np.sin(theta * angle_scale) * np.sin(phi * angle_scale)
-                z = d * np.cos(theta * angle_scale)
-                coordinates.append([x.item(), y.item(), z.item()])
-            except ValueError:
-                return None
-        return symbols, coordinates
-    return None
-
-
 @torch.no_grad()
 def sample(
-    model: ChemBFN,
+    model: Union[ChemBFN, EnsembleChemBFN],
     batch_size: int,
     sequence_size: int,
     sample_step: int = 100,
-    y: Optional[Tensor] = None,
+    y: Optional[Union[Tensor, Dict[str, Tensor], List[Tensor]]] = None,
     guidance_strength: float = 4.0,
     device: Union[str, torch.device, None] = None,
     vocab_keys: List[str] = VOCAB_KEYS,
@@ -358,7 +248,9 @@ def sample(
     :param batch_size: batch size
     :param sequence_size: max sequence length
     :param sample_step: number of sampling steps
-    :param y: conditioning vector;
+    :param y: conditioning vector; shape: (n_b, 1, n_f) or (n_b, n_f) \n
+        or a list/`dict` of conditions; shape: (n_b, n_c) * n_h
+
     :param guidance_strength: strength of conditional generation. It is not used if y is null.
     :param device: hardware accelerator
     :param vocab_keys: a list of (ordered) vocabulary
@@ -366,11 +258,11 @@ def sample(
     :param method: sampling method chosen from `"ODE:x"` or `"BFN"` where `x` is the value of sampling temperature; default is `"BFN"`
     :param allowed_tokens: a list of allowed tokens; default is `"all"`
     :param sort: whether to sort the samples according to entropy values; default is `False`
-    :type model: bayesianflow_for_chem.model.ChemBFN
+    :type model: bayesianflow_for_chem.model.ChemBFN | bayesianflow_for_chem.model.EnsembleChemBFN
     :type batch_size: int
     :type sequence_size: int
     :type sample_step: int
-    :type y: torch.Tensor | None
+    :type y: torch.Tensor | list | dict | None
     :type guidance_strength: float
     :type device: str | torch.device | None
    :type vocab_keys: list
@@ -382,11 +274,23 @@ def sample(
     :rtype: list
     """
     assert method.split(":")[0].lower() in ("ode", "bfn")
+    if isinstance(model, EnsembleChemBFN):
+        assert y is not None, "conditioning is required while using an ensemble model."
+        assert isinstance(y, list) or isinstance(y, dict)
+    else:
+        assert isinstance(y, Tensor) or y is None
     if device is None:
         device = _find_device()
     model.to(device).eval()
     if y is not None:
-        y
+        if isinstance(y, Tensor):
+            y = y.to(device)
+        elif isinstance(y, list):
+            y = [i.to(device) for i in y]
+        elif isinstance(y, dict):
+            y = {k: v.to(device) for k, v in y.items()}
+        else:
+            raise NotImplementedError
     if isinstance(allowed_tokens, list):
         token_mask = [0 if i in allowed_tokens else 1 for i in vocab_keys]
         token_mask = torch.tensor([[token_mask]], dtype=torch.bool).to(device)
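Note: the high-level `sample` helper now dispatches on the model type, so an `EnsembleChemBFN` can be used with the same call as a plain `ChemBFN`, provided the conditions are a list or dict. A minimal sketch, reusing the assumed "logp"/"qed" ensemble from the model.py note above; batch size, sequence length, and condition values are illustrative only.

    import torch
    from bayesianflow_for_chem.tool import sample

    # `model` is the EnsembleChemBFN built earlier; a plain tensor `y` is only
    # valid for a single ChemBFN.
    generated = sample(
        model,
        batch_size=16,
        sequence_size=64,
        sample_step=100,
        y={"logp": torch.full((16, 1), 2.5), "qed": torch.full((16, 1), 0.8)},
        guidance_strength=4.0,
    )
    # `generated` is a list of decoded strings (sorted by entropy if sort=True).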
@@ -416,10 +320,10 @@ def sample(

 @torch.no_grad()
 def inpaint(
-    model: ChemBFN,
+    model: Union[ChemBFN, EnsembleChemBFN],
     x: Tensor,
     sample_step: int = 100,
-    y: Optional[Tensor] = None,
+    y: Optional[Union[Tensor, Dict[str, Tensor], List[Tensor]]] = None,
     guidance_strength: float = 4.0,
     device: Union[str, torch.device, None] = None,
     vocab_keys: List[str] = VOCAB_KEYS,
@@ -434,7 +338,9 @@ def inpaint(
     :param model: trained ChemBFN model
     :param x: categorical indices of scaffold; shape: (n_b, n_t)
     :param sample_step: number of sampling steps
-    :param y: conditioning vector; shape: (n_b, 1, n_f)
+    :param y: conditioning vector; shape: (n_b, 1, n_f) or (n_b, n_f) \n
+        or a list/`dict` of conditions; shape: (n_b, n_c) * n_h
+
     :param guidance_strength: strength of conditional generation. It is not used if y is null.
     :param device: hardware accelerator
     :param vocab_keys: a list of (ordered) vocabulary
@@ -442,10 +348,10 @@ def inpaint(
     :param method: sampling method chosen from `"ODE:x"` or `"BFN"` where `x` is the value of sampling temperature; default is `"BFN"`
     :param allowed_tokens: a list of allowed tokens; default is `"all"`
     :param sort: whether to sort the samples according to entropy values; default is `False`
-    :type model: bayesianflow_for_chem.model.ChemBFN
+    :type model: bayesianflow_for_chem.model.ChemBFN | bayesianflow_for_chem.model.EnsembleChemBFN
     :type x: torch.Tensor
     :type sample_step: int
-    :type y: torch.Tensor | None
+    :type y: torch.Tensor | list | dict | None
     :type guidance_strength: float
     :type device: str | torch.device | None
     :type vocab_keys: list
@@ -457,12 +363,24 @@ def inpaint(
     :rtype: list
     """
     assert method.split(":")[0].lower() in ("ode", "bfn")
+    if isinstance(model, EnsembleChemBFN):
+        assert y is not None, "conditioning is required while using an ensemble model."
+        assert isinstance(y, list) or isinstance(y, dict)
+    else:
+        assert isinstance(y, Tensor) or y is None
     if device is None:
         device = _find_device()
     model.to(device).eval()
     x = x.to(device)
     if y is not None:
-        y
+        if isinstance(y, Tensor):
+            y = y.to(device)
+        elif isinstance(y, list):
+            y = [i.to(device) for i in y]
+        elif isinstance(y, dict):
+            y = {k: v.to(device) for k, v in y.items()}
+        else:
+            raise NotImplementedError
     if isinstance(allowed_tokens, list):
         token_mask = [0 if i in allowed_tokens else 1 for i in vocab_keys]
         token_mask = torch.tensor([[token_mask]], dtype=torch.bool).to(device)
@@ -627,3 +545,178 @@ def quantise_model(model: ChemBFN) -> nn.Module:
         model, {nn.Linear, Linear}, torch.qint8, mapping
     )
     return quantised_model
+
+
+class GeometryConverter:
+    """
+    Converting between different 2D/3D molecular representations.
+    """
+
+    @staticmethod
+    def _xyz2mol(symbols: List[str], coordinates: np.ndarray) -> Mol:
+        xyz_block = [str(len(symbols)), ""]
+        r = coordinates
+        for i, atom in enumerate(symbols):
+            xyz_block.append(f"{atom} {r[i][0]:.10f} {r[i][1]:.10f} {r[i][2]:.10f}")
+        return MolFromXYZBlock("\n".join(xyz_block))
+
+    def cartesian2smiles(
+        self,
+        symbols: List[str],
+        coordinates: np.ndarray,
+        charge: int = 0,
+        canonical: bool = True,
+    ) -> str:
+        """
+        Transform (guess out) molecular geometry to SMILES string.
+
+        :param symbols: a list of atomic symbols
+        :param coordinates: Cartesian coordinates; shape: (n_a, 3)
+        :param charge: net charge
+        :param canonical: whether to canonicalise the SMILES
+        :type symbols: list
+        :type coordinates: numpy.ndarray
+        :type charge: int
+        :type canonical: bool
+        :return: SMILES string
+        :rtype: str
+        """
+        mol = self._xyz2mol(symbols, coordinates)
+        rdDetermineBonds.DetermineBonds(mol, charge=charge)
+        smiles = MolToSmiles(mol)
+        if canonical:
+            smiles = CanonSmiles(smiles)
+        return smiles
+
+    def canonicalise(
+        self, symbols: List[str], coordinates: np.ndarray
+    ) -> Tuple[List[str], np.ndarray]:
+        """
+        Canonicalising the 3D molecular graph.
+
+        :param symbols: a list of atomic symbols
+        :param coordinates: Cartesian coordinates; shape: (n_a, 3)
+        :type symbols: list
+        :type coordinates: numpy.ndarray
+        :return: canonicalised symbols \n
+            canonicalised coordinates; shape: (n_a, 3)
+        :rtype: tuple
+        """
+        if not _use_pynauty:
+            if platform.system() == "Windows":
+                raise NotImplementedError(
+                    "This method is not implemented on Windows platform."
+                )
+            else:
+                raise ImportError("`pynauty` is not installed.")
+        n = len(symbols)
+        if n == 1:
+            return symbols, coordinates
+        mol = self._xyz2mol(symbols, coordinates)
+        rdDetermineBonds.DetermineConnectivity(mol)
+        # ------- Canonicalization -------
+        pair_idx = np.array(_bond_pair_idx(mol.GetBonds())).T.tolist()
+        pair_dict: Dict[int, List[int]] = {}
+        for key, i in enumerate(pair_idx[0]):
+            if i not in pair_dict:
+                pair_dict[i] = [pair_idx[1][key]]
+            else:
+                pair_dict[i].append(pair_idx[1][key])
+        g = Graph(n, adjacency_dict=pair_dict)
+        cl = canon_label(g)  # type: list
+        symbols = np.array([[s] for s in symbols])[cl].flatten().tolist()
+        coordinates = coordinates[cl]
+        return symbols, coordinates
+
+    @staticmethod
+    def cartesian2spherical(coordinates: np.ndarray) -> np.ndarray:
+        """
+        Transforming Cartesian coordinates to spherical form.\n
+        The method is adapted from the paper: https://arxiv.org/abs/2408.10120.
+
+        :param coordinates: Cartesian coordinates; shape: (n_a, 3)
+        :type coordinates: numpy.ndarray
+        :return: spherical coordinates; shape: (n_a, 3)
+        :rtype: numpy.ndarray
+        """
+        n = coordinates.shape[0]
+        if n == 1:
+            return np.array([[0.0, 0.0, 0.0]])
+        # ------- Find global coordinate frame -------
+        if n == 2:
+            d = np.linalg.norm(coordinates[0] - coordinates[1], 2)
+            return np.array([[0.0, 0.0, 0.0], [d, 0.0, 0.0]])
+        for idx_0 in range(n - 2):
+            _vec0 = coordinates[idx_0] - coordinates[idx_0 + 1]
+            _vec1 = coordinates[idx_0] - coordinates[idx_0 + 2]
+            _d1 = np.linalg.norm(_vec0, 2)
+            _d2 = np.linalg.norm(_vec1, 2)
+            if 1 - np.abs(np.dot(_vec0, _vec1) / (_d1 * _d2)) > 1e-6:
+                break
+        x = (coordinates[idx_0 + 1] - coordinates[idx_0]) / _d1
+        y = np.cross((coordinates[idx_0 + 2] - coordinates[idx_0]), x)
+        y_d = np.linalg.norm(y, 2)
+        y = y / np.ma.filled(np.ma.array(y_d, mask=y_d == 0), np.inf)
+        z = np.cross(x, y)
+        # ------- Build spherical coordinates -------
+        vec = coordinates - coordinates[idx_0]
+        d = np.linalg.norm(vec, 2, axis=-1)
+        _d = np.ma.filled(np.ma.array(d, mask=d == 0), np.inf)
+        theta = np.arccos(np.dot(vec, z) / _d)  # in [0, \pi]
+        phi = np.arctan2(np.dot(vec, y), np.dot(vec, x))  # in [-\pi, \pi]
+        info = np.vstack([d, theta, phi]).T
+        info[idx_0] = np.zeros_like(info[idx_0])
+        return info
+
+    def geo2seq(
+        self, symbols: List[str], coordinates: np.ndarray, decimals: int = 2
+    ) -> str:
+        """
+        Geometry-to-sequence function.\n
+        The algorithm follows the descriptions in paper: https://arxiv.org/abs/2408.10120.
+
+        :param symbols: a list of atomic symbols
+        :param coordinates: Cartesian coordinates; shape: (n_a, 3)
+        :param decimals: the maximum number of decimals to keep; default is 2
+        :type symbols: list
+        :type coordinates: numpy.ndarray
+        :type decimals: int
+        :return: `Geo2Seq` string
+        :rtype: str
+        """
+        symbols, coordinates = self.canonicalise(symbols, coordinates)
+        info = self.cartesian2spherical(coordinates)
+        info = [
+            f"{symbols[i]} {r[0]} {r[1]} {r[2]}"
+            for i, r in enumerate(np.round(info, decimals))
+        ]
+        return " ".join(info)
+
+    @staticmethod
+    def seq2geo(seq: str) -> Tuple[Optional[List[str]], Optional[np.ndarray]]:
+        """
+        Sequence-to-geometry function.\n
+        The method follows the descriptions in paper: https://arxiv.org/abs/2408.10120.
+
+        :param seq: `Geo2Seq` string
+        :type seq: str
+        :return: (symbols, coordinates) if `seq` is valid
+        :rtype: tuple
+        """
+        tokens = seq.split()
+        if len(tokens) % 4 != 0:
+            return None, None
+        tokens = np.array(tokens).reshape(-1, 4)
+        symbols, coordinates = tokens[::, 0], tokens[::, 1:]
+        if sum([len(_atom_regex.findall(sym)) for sym in symbols]) != len(symbols):
+            return None, None
+        try:
+            coord = [[float(i) for i in j] for j in coordinates]
+            coord = np.array(coord)
+        except ValueError:
+            return None, None
+        d, theta, phi = coord[::, 0, None], coord[::, 1, None], coord[::, 2, None]
+        x = d * np.sin(theta) * np.cos(phi)
+        y = d * np.sin(theta) * np.sin(phi)
+        z = d * np.cos(theta)
+        return symbols, np.concatenate([x, y, z], -1)
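Note: the former module-level `geo2seq`/`seq2geo` functions are replaced by the `GeometryConverter` class, whose `geo2seq` requires `pynauty` (and therefore, per `canonicalise`, is unavailable on Windows). A minimal round-trip sketch; the toy water-like geometry and its coordinate values are assumptions for illustration.

    import numpy as np
    from bayesianflow_for_chem.tool import GeometryConverter

    converter = GeometryConverter()
    symbols = ["O", "H", "H"]
    coordinates = np.array(
        [[0.00, 0.00, 0.00], [0.96, 0.00, 0.00], [-0.24, 0.93, 0.00]]
    )
    seq = converter.geo2seq(symbols, coordinates, decimals=2)   # Geo2Seq string
    atoms, xyz = converter.seq2geo(seq)                         # (None, None) if seq is malformed
    smiles = converter.cartesian2smiles(symbols, coordinates, charge=0)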
bayesianflow_for_chem/train.py
CHANGED
@@ -37,7 +37,8 @@ class Model(LightningModule):
     """
     A `~lightning.LightningModule` wrapper of bayesian flow network for chemistry model.\n
     This module is used in training stage only. By calling `Model(...).export_model(YOUR_WORK_DIR)` after training,
-    the model(s) will be saved to `YOUR_WORK_DIR/model.pt`
+    the model(s) will be saved to `YOUR_WORK_DIR/model.pt` (if LoRA is enabled then `YOUR_WORK_DIR/lora.pt`)
+    and (if it exists) `YOUR_WORK_DIR/mlp.pt`.

     :param model: `~bayesianflow_for_chem.model.ChemBFN` instance.
     :param mlp: `~bayesianflow_for_chem.model.MLP` instance or `None`.
@@ -135,7 +136,8 @@ class Regressor(LightningModule):
     """
     A `~lightning.LightningModule` wrapper of bayesian flow network for chemistry regression model.\n
     This module is used in training stage only. By calling `Regressor(...).export_model(YOUR_WORK_DIR)` after training,
-    the models will be saved to `YOUR_WORK_DIR/
+    the models will be saved to `YOUR_WORK_DIR/model_ft.pt` (if LoRA is enabled then `YOUR_WORK_DIR/lora.pt`)
+    and `YOUR_WORK_DIR/readout.pt`.

     :param model: `~bayesianflow_for_chem.model.ChemBFN` instance.
     :param mlp: `~bayesianflow_for_chem.model.MLP` instance.
@@ -218,7 +220,7 @@ class Regressor(LightningModule):
         """
         Save the trained model.

-        :param workdir: the directory to save the
+        :param workdir: the directory to save the models
         :type workdir: pathlib.Path
         :return:
         :rtype: None

{bayesianflow_for_chem-1.2.6.dist-info → bayesianflow_for_chem-1.3.0.dist-info}/METADATA
CHANGED

@@ -1,16 +1,15 @@
-Metadata-Version: 2.
+Metadata-Version: 2.4
 Name: bayesianflow_for_chem
-Version: 1.2.6
+Version: 1.3.0
 Summary: Bayesian flow network framework for Chemistry
 Home-page: https://augus1999.github.io/bayesian-flow-network-for-chemistry/
 Author: Nianze A. Tao
 Author-email: tao-nianze@hiroshima-u.ac.jp
-License: AGPL-3.0
+License: AGPL-3.0-or-later
 Project-URL: Source, https://github.com/Augus1999/bayesian-flow-network-for-chemistry
 Keywords: Chemistry,CLM,ChemBFN
 Classifier: Development Status :: 5 - Production/Stable
 Classifier: Intended Audience :: Science/Research
-Classifier: License :: OSI Approved :: GNU Affero General Public License v3
 Classifier: Natural Language :: English
 Classifier: Programming Language :: Python :: 3
 Classifier: Programming Language :: Python :: 3.9
@@ -39,6 +38,7 @@ Dynamic: description-content-type
 Dynamic: home-page
 Dynamic: keywords
 Dynamic: license
+Dynamic: license-file
 Dynamic: project-url
 Dynamic: provides-extra
 Dynamic: requires-dist
@@ -86,7 +86,7 @@ You can find example scripts in [📁example](./example) folder.

 ## Pre-trained Model

-You can find pretrained models
+You can find pretrained models on our [🤗Hugging Face model page](https://huggingface.co/suenoomozawa/ChemBFN).

 ## Dataset Handling


bayesianflow_for_chem-1.3.0.dist-info/RECORD
ADDED

@@ -0,0 +1,12 @@
+bayesianflow_for_chem/__init__.py,sha256=3BW4-ri8OcMZAIPJBT2q-48L3LAY776xluMDC6WXaZU,329
+bayesianflow_for_chem/data.py,sha256=EbCfhA1bCieVHVOYVk7nvgsaOzhKyFdnHd261qNR4BY,7763
+bayesianflow_for_chem/model.py,sha256=fFcfg1RZuoJeptAtglo2U8j1EGNSGjItMHqlKdLGGhU,50799
+bayesianflow_for_chem/scorer.py,sha256=7G1TVSwC0qONtNm6kiDZUWwvuFPzasNSjp4eJAk5TL0,4101
+bayesianflow_for_chem/tool.py,sha256=Z9qF80qzK-CJk9MJaWuSNOLnA-LPiD6CiC7S3sZbBP8,27704
+bayesianflow_for_chem/train.py,sha256=hGKyhGhLch-exSYPZdLXrLn3gf39Q1VLSJs2qtuikQE,9709
+bayesianflow_for_chem/vocab.txt,sha256=HgtAZmpWYk4y8PqEVC4vqut1vE75DfRKE_10s2UW0rU,790
+bayesianflow_for_chem-1.3.0.dist-info/licenses/LICENSE,sha256=hIahDEOTzuHCU5J2nd07LWwkLW7Hko4UFO__ffsvB-8,34523
+bayesianflow_for_chem-1.3.0.dist-info/METADATA,sha256=2BDjaVhIkd0TLolVETa2kb7xUGYhn8kdlq2CMfF-i7Y,5746
+bayesianflow_for_chem-1.3.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+bayesianflow_for_chem-1.3.0.dist-info/top_level.txt,sha256=KHsanI3BMCt8D9Qpze2ycrF6nMa3PyojgO6eS1c8kco,22
+bayesianflow_for_chem-1.3.0.dist-info/RECORD,,

bayesianflow_for_chem-1.2.6.dist-info/RECORD
REMOVED

@@ -1,12 +0,0 @@
-bayesianflow_for_chem/__init__.py,sha256=sdyCK-Zd32-FNOcjuSB02ABx8vn53phorQeVqyWMWk4,293
-bayesianflow_for_chem/data.py,sha256=9tpRba40lxwrB6aPSJMkxUglEVC3VEQC9wWxhDuz3Q8,7760
-bayesianflow_for_chem/model.py,sha256=HvEvW_xRbkv4eSv5lhd72BJMZkg-ZACEi1DAW3p5Q1Y,35918
-bayesianflow_for_chem/scorer.py,sha256=mV1vX8aBGFra2BE7N8WHihVIo3dXmUdPQIGfSaiuNdk,4084
-bayesianflow_for_chem/tool.py,sha256=VuEqbT7Qraa4vnKMHbToyAYIiRoQI7gEPLKEBCWGmVg,23706
-bayesianflow_for_chem/train.py,sha256=kj6icGqymUUYopDtpre1oE_wpvpeNilbpzgffBsd1tk,9589
-bayesianflow_for_chem/vocab.txt,sha256=HgtAZmpWYk4y8PqEVC4vqut1vE75DfRKE_10s2UW0rU,790
-bayesianflow_for_chem-1.2.6.dist-info/LICENSE,sha256=hIahDEOTzuHCU5J2nd07LWwkLW7Hko4UFO__ffsvB-8,34523
-bayesianflow_for_chem-1.2.6.dist-info/METADATA,sha256=Akoh5dQW_0jeYuGC4ZKKYHS1WJn0xRwGDr7ut-Q-5sc,5890
-bayesianflow_for_chem-1.2.6.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
-bayesianflow_for_chem-1.2.6.dist-info/top_level.txt,sha256=KHsanI3BMCt8D9Qpze2ycrF6nMa3PyojgO6eS1c8kco,22
-bayesianflow_for_chem-1.2.6.dist-info/RECORD,,
{bayesianflow_for_chem-1.2.6.dist-info → bayesianflow_for_chem-1.3.0.dist-info/licenses}/LICENSE
RENAMED
File without changes

{bayesianflow_for_chem-1.2.6.dist-info → bayesianflow_for_chem-1.3.0.dist-info}/top_level.txt
RENAMED
File without changes