bayesianflow-for-chem 1.2.7__tar.gz → 1.4.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of bayesianflow-for-chem might be problematic.
- {bayesianflow_for_chem-1.2.7 → bayesianflow_for_chem-1.4.0}/PKG-INFO +4 -8
- {bayesianflow_for_chem-1.2.7 → bayesianflow_for_chem-1.4.0}/README.md +2 -2
- {bayesianflow_for_chem-1.2.7 → bayesianflow_for_chem-1.4.0}/bayesianflow_for_chem/__init__.py +3 -3
- {bayesianflow_for_chem-1.2.7 → bayesianflow_for_chem-1.4.0}/bayesianflow_for_chem/data.py +2 -39
- {bayesianflow_for_chem-1.2.7 → bayesianflow_for_chem-1.4.0}/bayesianflow_for_chem/model.py +396 -30
- {bayesianflow_for_chem-1.2.7 → bayesianflow_for_chem-1.4.0}/bayesianflow_for_chem/scorer.py +1 -1
- {bayesianflow_for_chem-1.2.7 → bayesianflow_for_chem-1.4.0}/bayesianflow_for_chem/tool.py +141 -176
- {bayesianflow_for_chem-1.2.7 → bayesianflow_for_chem-1.4.0}/bayesianflow_for_chem/train.py +5 -3
- {bayesianflow_for_chem-1.2.7 → bayesianflow_for_chem-1.4.0}/bayesianflow_for_chem.egg-info/PKG-INFO +4 -8
- {bayesianflow_for_chem-1.2.7 → bayesianflow_for_chem-1.4.0}/bayesianflow_for_chem.egg-info/requires.txt +0 -3
- {bayesianflow_for_chem-1.2.7 → bayesianflow_for_chem-1.4.0}/pyproject.toml +1 -1
- {bayesianflow_for_chem-1.2.7 → bayesianflow_for_chem-1.4.0}/setup.py +2 -3
- {bayesianflow_for_chem-1.2.7 → bayesianflow_for_chem-1.4.0}/LICENSE +0 -0
- {bayesianflow_for_chem-1.2.7 → bayesianflow_for_chem-1.4.0}/bayesianflow_for_chem/vocab.txt +0 -0
- {bayesianflow_for_chem-1.2.7 → bayesianflow_for_chem-1.4.0}/bayesianflow_for_chem.egg-info/SOURCES.txt +0 -0
- {bayesianflow_for_chem-1.2.7 → bayesianflow_for_chem-1.4.0}/bayesianflow_for_chem.egg-info/dependency_links.txt +0 -0
- {bayesianflow_for_chem-1.2.7 → bayesianflow_for_chem-1.4.0}/bayesianflow_for_chem.egg-info/top_level.txt +0 -0
- {bayesianflow_for_chem-1.2.7 → bayesianflow_for_chem-1.4.0}/setup.cfg +0 -0
{bayesianflow_for_chem-1.2.7 → bayesianflow_for_chem-1.4.0}/PKG-INFO
RENAMED

@@ -1,16 +1,15 @@
 Metadata-Version: 2.4
 Name: bayesianflow_for_chem
-Version: 1.2.7
+Version: 1.4.0
 Summary: Bayesian flow network framework for Chemistry
 Home-page: https://augus1999.github.io/bayesian-flow-network-for-chemistry/
 Author: Nianze A. Tao
 Author-email: tao-nianze@hiroshima-u.ac.jp
-License: AGPL-3.0
+License: AGPL-3.0-or-later
 Project-URL: Source, https://github.com/Augus1999/bayesian-flow-network-for-chemistry
 Keywords: Chemistry,CLM,ChemBFN
 Classifier: Development Status :: 5 - Production/Stable
 Classifier: Intended Audience :: Science/Research
-Classifier: License :: OSI Approved :: GNU Affero General Public License v3
 Classifier: Natural Language :: English
 Classifier: Programming Language :: Python :: 3
 Classifier: Programming Language :: Python :: 3.9
@@ -29,8 +28,6 @@ Requires-Dist: loralib>=0.1.2
 Requires-Dist: lightning>=2.2.0
 Requires-Dist: scikit-learn>=1.5.0
 Requires-Dist: typing_extensions>=4.8.0
-Provides-Extra: geo2seq
-Requires-Dist: pynauty>=2.8.8.1; extra == "geo2seq"
 Dynamic: author
 Dynamic: author-email
 Dynamic: classifier
@@ -41,7 +38,6 @@ Dynamic: keywords
 Dynamic: license
 Dynamic: license-file
 Dynamic: project-url
-Dynamic: provides-extra
 Dynamic: requires-dist
 Dynamic: requires-python
 Dynamic: summary
@@ -87,13 +83,13 @@ You can find example scripts in [📁example](./example) folder.

 ## Pre-trained Model

-You can find pretrained models
+You can find pretrained models on our [🤗Hugging Face model page](https://huggingface.co/suenoomozawa/ChemBFN).

 ## Dataset Handling

 We provide a Python class [`CSVData`](./bayesianflow_for_chem/data.py) to handle data stored in CSV or similar format containing headers to identify the entities. The following is a quickstart.

-1. Download your dataset file (e.g., ESOL
+1. Download your dataset file (e.g., ESOL from [MoleculeNet](https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/delaney-processed.csv)) and split the file:
 ```python
 >>> from bayesianflow_for_chem.tool import split_data


{bayesianflow_for_chem-1.2.7 → bayesianflow_for_chem-1.4.0}/README.md
RENAMED

@@ -39,13 +39,13 @@ You can find example scripts in [📁example](./example) folder.

 ## Pre-trained Model

-You can find pretrained models
+You can find pretrained models on our [🤗Hugging Face model page](https://huggingface.co/suenoomozawa/ChemBFN).

 ## Dataset Handling

 We provide a Python class [`CSVData`](./bayesianflow_for_chem/data.py) to handle data stored in CSV or similar format containing headers to identify the entities. The following is a quickstart.

-1. Download your dataset file (e.g., ESOL
+1. Download your dataset file (e.g., ESOL from [MoleculeNet](https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/delaney-processed.csv)) and split the file:
 ```python
 >>> from bayesianflow_for_chem.tool import split_data

{bayesianflow_for_chem-1.2.7 → bayesianflow_for_chem-1.4.0}/bayesianflow_for_chem/__init__.py
RENAMED

@@ -4,8 +4,8 @@
 ChemBFN package.
 """
 from . import data, tool, train, scorer
-from .model import ChemBFN, MLP
+from .model import ChemBFN, MLP, EnsembleChemBFN

-__all__ = ["data", "tool", "train", "scorer", "ChemBFN", "MLP"]
-__version__ = "1.2.7"
+__all__ = ["data", "tool", "train", "scorer", "ChemBFN", "MLP", "EnsembleChemBFN"]
+__version__ = "1.4.0"
 __author__ = "Nianze A. Tao (Omozawa Sueno)"
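After this change the ensemble model is importable from the package root alongside the existing classes. A quick sanity check (assuming the new wheel is installed):

```python
import bayesianflow_for_chem as bfc

print(bfc.__version__)      # "1.4.0"
print(bfc.EnsembleChemBFN)  # exported next to ChemBFN and MLP
```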
{bayesianflow_for_chem-1.2.7 → bayesianflow_for_chem-1.4.0}/bayesianflow_for_chem/data.py
RENAMED

@@ -1,7 +1,7 @@
 # -*- coding: utf-8 -*-
 # Author: Nianze A. TAO (Omozawa SUENO)
 """
-Tokenise SMILES/SAFE/SELFIES/
+Tokenise SMILES/SAFE/SELFIES/protein-sequence strings.
 """
 import os
 import re
@@ -32,30 +32,14 @@ SMI_REGEX_PATTERN = (
     r"~|@|\?|>>?|\*|\$|\%[0-9]{2}|[0-9])"
 )
 SEL_REGEX_PATTERN = r"(\[[^\]]+]|\.)"
-GEO_REGEX_PATTERN = (
-    r"(H[e,f,g,s,o]?|"
-    r"L[i,v,a,r,u]|"
-    r"B[e,r,a,i,h,k]?|"
-    r"C[l,a,r,o,u,d,s,n,e,m,f]?|"
-    r"N[e,a,i,b,h,d,o,p]?|"
-    r"O[s,g]?|S[i,c,e,r,n,m,b,g]?|"
-    r"K[r]?|T[i,c,e,a,l,b,h,m,s]|"
-    r"G[a,e,d]|R[b,u,h,e,n,a,f,g]|"
-    r"Yb?|Z[n,r]|P[t,o,d,r,a,u,b,m]?|"
-    r"F[e,r,l,m]?|M[g,n,o,t,c,d]|"
-    r"A[l,r,s,g,u,t,c,m]|I[n,r]?|"
-    r"W|X[e]|E[u,r,s]|U|D[b,s,y]|"
-    r"-|.| |[0-9])"
-)
 AA_REGEX_PATTERN = r"(A|B|C|D|E|F|G|H|I|K|L|M|N|P|Q|R|S|T|V|W|Y|Z|-|.)"
 smi_regex = re.compile(SMI_REGEX_PATTERN)
 sel_regex = re.compile(SEL_REGEX_PATTERN)
-geo_regex = re.compile(GEO_REGEX_PATTERN)
 aa_regex = re.compile(AA_REGEX_PATTERN)


 def load_vocab(
-    vocab_file: Union[str, Path]
+    vocab_file: Union[str, Path],
 ) -> Dict[str, Union[int, List[str], Dict[str, int]]]:
     """
     Load vocabulary from source file.
@@ -86,9 +70,6 @@ AA_VOCAB_KEYS = (
 )
 AA_VOCAB_COUNT = len(AA_VOCAB_KEYS)
 AA_VOCAB_DICT = dict(zip(AA_VOCAB_KEYS, range(AA_VOCAB_COUNT)))
-GEO_VOCAB_KEYS = VOCAB_KEYS[0:3] + [" "] + VOCAB_KEYS[22:150] + [".", "-"]
-GEO_VOCAB_COUNT = len(GEO_VOCAB_KEYS)
-GEO_VOCAB_DICT = dict(zip(GEO_VOCAB_KEYS, range(GEO_VOCAB_COUNT)))


 def smiles2vec(smiles: str) -> List[int]:
@@ -104,19 +85,6 @@ def smiles2vec(smiles: str) -> List[int]:
     return [VOCAB_DICT[token] for token in tokens]


-def geo2vec(geo2seq: str) -> List[int]:
-    """
-    Geo2Seq tokenisation using a dataset-independent regex pattern.
-
-    :param geo2seq: Geo2Seq string
-    :type geo2seq: str
-    :return: tokens w/o `<start>` and `<end>`
-    :rtype: list
-    """
-    tokens = [token for token in geo_regex.findall(geo2seq)]
-    return [GEO_VOCAB_DICT[token] for token in tokens]
-
-
 def aa2vec(aa_seq: str) -> List[int]:
     """
     Protein sequence tokenisation using a dataset-independent regex pattern.
@@ -147,11 +115,6 @@ def smiles2token(smiles: str) -> Tensor:
     return torch.tensor([1] + smiles2vec(smiles) + [2], dtype=torch.long)


-def geo2token(geo2seq: str) -> Tensor:
-    # start token: <start> = 1; end token: <esc> = 2
-    return torch.tensor([1] + geo2vec(geo2seq) + [2], dtype=torch.long)
-
-
 def aa2token(aa_seq: str) -> Tensor:
     # start token: <start> = 1; end token: <end> = 2
     return torch.tensor([1] + aa2vec(aa_seq) + [2], dtype=torch.long)
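The tokenisers that remain after this cleanup (`smiles2vec`/`smiles2token` for SMILES-like strings, `aa2vec`/`aa2token` for protein sequences) can be exercised directly. A minimal sketch — the example strings are arbitrary, and the integer indices depend on the shipped `vocab.txt`:

```python
from bayesianflow_for_chem.data import smiles2vec, smiles2token, aa2token

# SMILES -> vocabulary indices (without <start>/<end>)
ids = smiles2vec("CCO")
# SMILES -> torch.long tensor wrapped with <start> = 1 and <end> = 2
token = smiles2token("CCO")      # shape: (len(ids) + 2,)
# protein sequences use the amino-acid tokeniser instead
aa_token = aa2token("MKT-A")
print(token.tolist(), aa_token.shape)
```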
{bayesianflow_for_chem-1.2.7 → bayesianflow_for_chem-1.4.0}/bayesianflow_for_chem/model.py
RENAMED

@@ -4,7 +4,8 @@
 Define Bayesian Flow Network for Chemistry (ChemBFN) model.
 """
 from pathlib import Path
-from typing import List, Tuple, Dict, Optional, Union
+from copy import deepcopy
+from typing import List, Tuple, Dict, Optional, Union, Callable
 import torch
 import torch.nn as nn
 from torch import Tensor
@@ -161,8 +162,8 @@ class Attention(nn.Module):
         :return: attentioned output; shape: (n_b, n_t, n_f)
         :rtype: torch.Tensor
         """
-        n_b,
-        split = (n_b,
+        n_b, n_t, _ = shape = x.shape
+        split = (n_b, n_t, self.nh, self.d)
         q, k, v = self.qkv(x).chunk(3, -1)
         q = q.view(split).permute(2, 0, 1, 3).contiguous()
         k = k.view(split).permute(2, 0, 1, 3).contiguous()
@@ -427,12 +428,12 @@ class ChemBFN(nn.Module):
         c = self.time_embed(t)
         if y is not None:
             c += y
-        pe = self.position(
+        pe = self.position(n_t)
         x = self.embedding(x)
         attn_mask: Optional[Tensor] = None
         if self.semi_autoregressive:
             attn_mask = torch.tril(
-                torch.ones((1, n_b, n_t, n_t), device=
+                torch.ones((1, n_b, n_t, n_t), device=x.device), diagonal=0
             )
         else:
             if mask is not None:
@@ -592,6 +593,13 @@ class ChemBFN(nn.Module):
         x, logits = torch.broadcast_tensors(x[..., None], logits)
         return (-logits.gather(-1, x[..., :1]).squeeze(-1)).mean()

+    @staticmethod
+    def reshape_y(y: Tensor) -> Tensor:
+        assert y.dim() <= 3  # this doesn't work if the model is frezen in JIT.
+        if y.dim() == 2:
+            return y[:, None, :]
+        return y
+
     @torch.jit.export
     def sample(
         self,
@@ -607,7 +615,7 @@ class ChemBFN(nn.Module):

         :param batch_size: batch size
         :param sequence_size: max sequence length
-        :param y: conditioning vector; shape: (n_b, 1, n_f)
+        :param y: conditioning vector; shape: (n_b, 1, n_f) or (n_b, n_f)
         :param sample_step: number of sampling steps
         :param guidance_strength: strength of conditional generation. It is not used if y is null.
         :param token_mask: token mask; shape: (1, 1, n_vocab)
@@ -626,9 +634,7 @@ class ChemBFN(nn.Module):
             / self.K
         )
         if y is not None:
-
-            if y.shape[0] == 1:
-                y = y.repeat(batch_size, 1, 1)
+            y = self.reshape_y(y)
         for i in torch.linspace(1, sample_step, sample_step, device=self.beta.device):
             t = (i - 1).view(1, 1, 1).repeat(batch_size, 1, 1) / sample_step
             p = self.discrete_output_distribution(theta, t, y, guidance_strength)
@@ -663,7 +669,7 @@ class ChemBFN(nn.Module):

         :param batch_size: batch size
         :param sequence_size: max sequence length
-        :param y: conditioning vector; shape: (n_b, 1, n_f)
+        :param y: conditioning vector; shape: (n_b, 1, n_f) or (n_b, n_f)
         :param sample_step: number of sampling steps
         :param guidance_strength: strength of conditional generation. It is not used if y is null.
         :param token_mask: token mask; shape: (1, 1, n_vocab)
@@ -681,9 +687,7 @@ class ChemBFN(nn.Module):
         """
         z = torch.zeros((batch_size, sequence_size, self.K), device=self.beta.device)
         if y is not None:
-
-            if y.shape[0] == 1:
-                y = y.repeat(batch_size, 1, 1)
+            y = self.reshape_y(y)
         for i in torch.linspace(1, sample_step, sample_step, device=self.beta.device):
             t = (i - 1).view(1, 1, 1).repeat(batch_size, 1, 1) / sample_step
             theta = torch.softmax(z, -1)
@@ -714,7 +718,7 @@ class ChemBFN(nn.Module):
         Molecule inpaint functionality.

         :param x: categorical indices of scaffold; shape: (n_b, n_t)
-        :param y: conditioning vector; shape: (n_b, 1, n_f)
+        :param y: conditioning vector; shape: (n_b, 1, n_f) or (n_b, n_f)
         :param sample_step: number of sampling steps
         :param guidance_strength: strength of conditional generation. It is not used if y is null.
         :param token_mask: token mask; shape: (1, 1, n_vocab)
@@ -733,9 +737,7 @@ class ChemBFN(nn.Module):
         x_onehot = nn.functional.one_hot(x, self.K) * mask
         theta = x_onehot + (1 - mask) * theta
         if y is not None:
-
-            if y.shape[0] == 1:
-                y = y.repeat(n_b, 1, 1)
+            y = self.reshape_y(y)
         for i in torch.linspace(1, sample_step, sample_step, device=x.device):
             t = (i - 1).view(1, 1, 1).repeat(n_b, 1, 1) / sample_step
             p = self.discrete_output_distribution(theta, t, y, guidance_strength)
@@ -769,7 +771,7 @@ class ChemBFN(nn.Module):
         ODE inpainting.

         :param x: categorical indices of scaffold; shape: (n_b, n_t)
-        :param y: conditioning vector; shape: (n_b, 1, n_f)
+        :param y: conditioning vector; shape: (n_b, 1, n_f) or (n_b, n_f)
         :param sample_step: number of sampling steps
         :param guidance_strength: strength of conditional generation. It is not used if y is null.
         :param token_mask: token mask; shape: (1, 1, n_vocab)
@@ -789,9 +791,7 @@ class ChemBFN(nn.Module):
         x_onehot = nn.functional.one_hot(x, self.K) * mask
         z = torch.zeros((n_b, n_t, self.K), device=self.beta.device)
         if y is not None:
-
-            if y.shape[0] == 1:
-                y = y.repeat(n_b, 1, 1)
+            y = self.reshape_y(y)
         for i in torch.linspace(1, sample_step, sample_step, device=self.beta.device):
             t = (i - 1).view(1, 1, 1).repeat(n_b, 1, 1) / sample_step
             theta = torch.softmax(z, -1)
@@ -847,13 +847,7 @@ class ChemBFN(nn.Module):
         with open(ckpt, "rb") as f:
             state = torch.load(f, "cpu", weights_only=True)
         nn, hparam = state["nn"], state["hparam"]
-        model = cls(
-            hparam["num_vocab"],
-            hparam["channel"],
-            hparam["num_layer"],
-            hparam["num_head"],
-            hparam["dropout"],
-        )
+        model = cls(**hparam)
         model.load_state_dict(nn, False)
         if ckpt_lora:
             with open(ckpt_lora, "rb") as g:
@@ -908,7 +902,7 @@ class MLP(nn.Module):
         if self.class_input:
             x = x.to(dtype=torch.long)
         for layer in self.layers[:-1]:
-            x = torch.selu(layer(x))
+            x = torch.selu(layer.forward(x))
         return self.layers[-1](x)

     @classmethod
@@ -926,10 +920,382 @@ class MLP(nn.Module):
         with open(ckpt, "rb") as f:
             state = torch.load(f, "cpu", weights_only=True)
         nn, hparam = state["nn"], state["hparam"]
-        model = cls(hparam
+        model = cls(**hparam)
         model.load_state_dict(nn, strict)
         return model


+class EnsembleChemBFN(ChemBFN):
+    """
+    This module does not fully support `torch.jit.script`. We have `EnsembleChemBFN.jit()`
+    method to JIT compile the submodels.
+    `torch.compile()` is a better choice to compiling the whole model.
+    """
+
+    def __init__(
+        self,
+        base_model_path: Union[str, Path],
+        lora_paths: Union[List[Union[str, Path]], Dict[str, Union[str, Path]]],
+        cond_heads: Union[List[nn.Module], Dict[str, nn.Module]],
+        adapter_weights: Optional[Union[List[float], Dict[str, float]]] = None,
+        semi_autoregressive_flags: Optional[Union[List[bool], Dict[str, bool]]] = None,
+    ) -> None:
+        """
+        Ensemble of ChemBFN models from LoRA checkpoints.
+
+        :param base_model_path: base model checkpoint file
+        :param lora_paths: a list of LoRA checkpoint files or a `dict` instance of these files
+        :param cond_heads: a list of conditioning network heads or a `dict` instance of these networks
+        :param adapter_weights: a list of weights of each LoRA finetuned model or a `dict` instance of these weights; default is equally weighted
+        :param semi_autoregressive_flags: a list of the semi-autoregressive behaviour states of each LoRA finetuned model or a `dict` instance of these states; default is all `False`
+        :type base_model_path: str | pathlib.Path
+        :type lora_paths: list | dict
+        :type cond_heads: list | dict
+        :type adapter_weights: list | dict | None
+        :type semi_autoregressive_flags: list | dict | None
+        """
+        n = len(lora_paths)
+        assert type(lora_paths) == type(
+            cond_heads
+        ), "`lora_paths` and `cond_heads` should have the same type!"
+        assert n == len(
+            cond_heads
+        ), "`lora_paths` and `cond_heads` should have the same length!"
+        if adapter_weights:
+            assert type(lora_paths) == type(
+                adapter_weights
+            ), "`lora_paths` and `adapter_weights` should have the same type!"
+            assert n == len(
+                adapter_weights
+            ), "`lora_paths` and `adapter_weights` should have the same length!"
+        if semi_autoregressive_flags:
+            assert type(lora_paths) == type(
+                semi_autoregressive_flags
+            ), "`lora_paths` and `semi_autoregressive_flags` should have the same type!"
+            assert n == len(
+                semi_autoregressive_flags
+            ), "`lora_paths` and `semi_autoregressive_flags` should have the same length!"
+        _label_is_dict = isinstance(lora_paths, dict)
+        if isinstance(lora_paths, list):
+            names = tuple(f"val_{i}" for i in range(n))
+            lora_paths = dict(zip(names, lora_paths))
+            cond_heads = dict(zip(names, cond_heads))
+            if not adapter_weights:
+                adapter_weights = (1 / n for _ in names)
+            if not semi_autoregressive_flags:
+                semi_autoregressive_flags = (False for _ in names)
+            adapter_weights = dict(zip(names, adapter_weights))
+            semi_autoregressive_flags = dict(zip(names, semi_autoregressive_flags))
+        else:
+            names = tuple(lora_paths.keys())
+            if not adapter_weights:
+                adapter_weights = dict(zip(names, (1 / n for _ in names)))
+            if not semi_autoregressive_flags:
+                semi_autoregressive_flags = dict(zip(names, (False for _ in names)))
+        base_model = ChemBFN.from_checkpoint(base_model_path)
+        models = dict(zip(names, (deepcopy(base_model.eval()) for _ in names)))
+        for k in names:
+            with open(lora_paths[k], "rb") as f:
+                state = torch.load(f, "cpu", weights_only=True)
+            lora_nn, lora_param = state["lora_nn"], state["lora_param"]
+            models[k].enable_lora(**lora_param)
+            models[k].load_state_dict(lora_nn, False)
+            models[k].semi_autoregressive = semi_autoregressive_flags[k]
+        super().__init__(**base_model.hparam)
+        self.cond_heads = nn.ModuleDict(cond_heads)
+        self.models = nn.ModuleDict(models)
+        self.adapter_weights = adapter_weights
+        self._label_is_dict = _label_is_dict  # flag
+        # ------- remove unnecessary submodules -------
+        self.embedding = None
+        self.time_embed = None
+        self.position = None
+        self.encoder_layers = None
+        self.final_layer = None
+        self.__delattr__("embedding")
+        self.__delattr__("time_embed")
+        self.__delattr__("position")
+        self.__delattr__("encoder_layers")
+        self.__delattr__("final_layer")
+        # ------- remove unused attributes -------
+        self.__delattr__("semi_autoregressive")
+        self.__delattr__("lora_enabled")
+        self.__delattr__("lora_param")
+        self.__delattr__("hparam")
+
+    def construct_y(
+        self, c: Union[List[Tensor], Dict[str, Tensor]]
+    ) -> Dict[str, Tensor]:
+        assert (
+            isinstance(c, dict) is self._label_is_dict
+        ), f"`c` should be a {'`dict` instance' if self._label_is_dict else '`list` instance'} but got {type(c)} instand."
+        out: Dict[str, Tensor] = {}
+        if isinstance(c, list):
+            c = dict(zip([f"val_{i}" for i in range(len(c))], c))
+        for name, model in self.cond_heads.items():
+            y = model.forward(c[name])
+            if y.dim() == 2:
+                y = y[:, None, :]
+            out[name] = y
+        return out
+
+    def discrete_output_distribution(
+        self, theta: Tensor, t: Tensor, y: Dict[str, Tensor], w: float
+    ) -> Tensor:
+        """
+        :param theta: input distribution; shape: (n_b, n_t, n_vocab)
+        :param t: continuous time in [0, 1]; shape: (n_b, 1, 1)
+        :param y: a dict of conditioning vectors; shape: (n_b, 1, n_f) * n_h
+        :param w: guidance strength controlling the conditional generation
+        :type theta: torch.Tensor
+        :type t: torch.Tensor
+        :type y: dict
+        :type w: float
+        :return: output distribution; shape: (n_b, n_t, n_vocab)
+        :rtype: torch.Tensor
+        """
+        theta = 2 * theta - 1  # rescale to [-1, 1]
+        p_uncond, p_cond = torch.zeros_like(theta), torch.zeros_like(theta)
+        # Q: Why not use `torch.vmap`? It's faster than doing the loop, isn't it?
+        #
+        # A: We have quite a few reasons to avoid using `vmap`:
+        #    1. JIT doesn't support vmap;
+        #    2. It's harder to switch on/off semi-autroregssive behaviours for individual
+        #       models when all models are stacked into one (we have a solution but it's not
+        #       that elegant);
+        #    3. We just found that the result from vmap was not identical to doing the loop;
+        #    4. vmap requires all models have the same size but it's not always that case
+        #       since we sometimes use different ranks of LoRA in finetuning.
+        for name, model in self.models.items():
+            p_uncond_ = model.forward(theta, t, None, None)
+            p_uncond += p_uncond_ * self.adapter_weights[name]
+            p_cond_ = model.forward(theta, t, None, y[name])
+            p_cond += p_cond_ * self.adapter_weights[name]
+        return softmax((1 + w) * p_cond - w * p_uncond, -1)
+
+    @staticmethod
+    def reshape_y(y: Dict[str, Tensor]) -> Dict[str, Tensor]:
+        for k in y:
+            assert y[k].dim() <= 3
+            if y[k].dim() == 2:
+                y[k] = y[k][:, None, :]
+        return y
+
+    @torch.inference_mode()
+    def sample(
+        self,
+        batch_size: int,
+        sequence_size: int,
+        conditions: Union[List[Tensor], Dict[str, Tensor]],
+        sample_step: int = 100,
+        guidance_strength: float = 4.0,
+        token_mask: Optional[Tensor] = None,
+    ) -> Tuple[Tensor, Tensor]:
+        """
+        Sample from a piror distribution.
+
+        :param batch_size: batch size
+        :param sequence_size: max sequence length
+        :param conditions: guidance conditions; shape: (n_b, n_c) * n_h
+        :param sample_step: number of sampling steps
+        :param guidance_strength: strength of conditional generation. It is not used if y is null.
+        :param token_mask: token mask; shape: (1, 1, n_vocab)
+        :type batch_size: int
+        :type sequence_size: int
+        :type conditions: list | dict
+        :type sample_step: int
+        :type guidance_strength: float
+        :type token_mask: torch.Tensor | None
+        :return: sampled token indices; shape: (n_b, n_t) \n
+                 entropy of the tokens; shape: (n_b)
+        :rtype: tuple
+        """
+        y = self.construct_y(conditions)
+        return super().sample(
+            batch_size, sequence_size, y, sample_step, guidance_strength, token_mask
+        )
+
+    @torch.inference_mode()
+    def ode_sample(
+        self,
+        batch_size: int,
+        sequence_size: int,
+        conditions: Union[List[Tensor], Dict[str, Tensor]],
+        sample_step: int = 100,
+        guidance_strength: float = 4.0,
+        token_mask: Optional[Tensor] = None,
+        temperature: float = 0.5,
+    ) -> Tuple[Tensor, Tensor]:
+        """
+        ODE-based sampling.
+
+        :param batch_size: batch size
+        :param sequence_size: max sequence length
+        :param conditions: conditioning vector; shape: (n_b, n_c) * n_h
+        :param sample_step: number of sampling steps
+        :param guidance_strength: strength of conditional generation. It is not used if y is null.
+        :param token_mask: token mask; shape: (1, 1, n_vocab)
+        :param temperature: sampling temperature
+        :type batch_size: int
+        :type sequence_size: int
+        :type conditions: list | dict
+        :type sample_step: int
+        :type guidance_strength: float
+        :type token_mask: torch.Tensor | None
+        :type temperature: float
+        :return: sampled token indices; shape: (n_b, n_t) \n
+                 entropy of the tokens; shape: (n_b)
+        :rtype: tuple
+        """
+        y = self.construct_y(conditions)
+        return super().ode_sample(
+            batch_size,
+            sequence_size,
+            y,
+            sample_step,
+            guidance_strength,
+            token_mask,
+            temperature,
+        )
+
+    @torch.inference_mode()
+    def inpaint(
+        self,
+        x: Tensor,
+        conditions: Union[List[Tensor], Dict[str, Tensor]],
+        sample_step: int = 100,
+        guidance_strength: float = 4.0,
+        token_mask: Optional[Tensor] = None,
+    ) -> Tuple[Tensor, Tensor]:
+        """
+        Molecule inpaint functionality.
+
+        :param x: categorical indices of scaffold; shape: (n_b, n_t)
+        :param conditions: conditioning vector; shape: (n_b, n_c) * n_h
+        :param sample_step: number of sampling steps
+        :param guidance_strength: strength of conditional generation. It is not used if y is null.
+        :param token_mask: token mask; shape: (1, 1, n_vocab)
+        :type x: torch.Tensor
+        :type conditions: list | dict
+        :type sample_step: int
+        :type guidance_strength: float
+        :type token_mask: torch.Tensor | None
+        :return: sampled token indices; shape: (n_b, n_t) \n
+                 entropy of the tokens; shape: (n_b)
+        :rtype: tuple
+        """
+        y = self.construct_y(conditions)
+        return super().inpaint(x, y, sample_step, guidance_strength, token_mask)
+
+    @torch.inference_mode()
+    def ode_inpaint(
+        self,
+        x: Tensor,
+        conditions: Union[List[Tensor], Dict[str, Tensor]],
+        sample_step: int = 100,
+        guidance_strength: float = 4.0,
+        token_mask: Optional[Tensor] = None,
+        temperature: float = 0.5,
+    ) -> Tuple[Tensor, Tensor]:
+        """
+        ODE inpainting.
+
+        :param x: categorical indices of scaffold; shape: (n_b, n_t)
+        :param conditions: conditioning vector; shape: (n_b, n_c) * n_h
+        :param sample_step: number of sampling steps
+        :param guidance_strength: strength of conditional generation. It is not used if y is null.
+        :param token_mask: token mask; shape: (1, 1, n_vocab)
+        :param temperature: sampling temperature
+        :type x: torch.Tensor
+        :type conditions: list | dict
+        :type sample_step: int
+        :type guidance_strength: float
+        :type token_mask: torch.Tensor | None
+        :type temperature: float
+        :return: sampled token indices; shape: (n_b, n_t) \n
+                 entropy of the tokens; shape: (n_b)
+        :rtype: tuple
+        """
+        y = self.construct_y(conditions)
+        return super().ode_inpaint(
+            x, y, sample_step, guidance_strength, token_mask, temperature
+        )
+
+    def quantise(
+        self, quantise_method: Optional[Callable[[ChemBFN], nn.Module]] = None
+    ) -> None:
+        """
+        Quantise the submodels. \n
+        This method should be called, if necessary, before `torch.compile()`.
+
+        :param quantise_method: quantisation method; default is `bayesianflow_for_chem.tool.quantise_model`
+        :type quantise_method: callable | None
+        :return:
+        :rtype: None
+        """
+        if quantise_method is None:
+            from bayesianflow_for_chem.tool import quantise_model
+
+            quantise_method = quantise_model
+        for k, v in self.models.items():
+            self.models[k] = quantise_method(v)
+
+    def jit(self, freeze: bool = False) -> None:
+        """
+        JIT compile the submodels. \n
+        This method should be called, if necessary, before `quantise()` method is called if applied.
+
+        :param freeze: whether to freeze the submodels; default is `False`. If set to `True` this
+                       method should be called before moving the model to a different device.
+        :type freeze: bool
+        :return:
+        :rtype: None
+        """
+        for k, v in self.models.items():
+            self.models[k] = torch.jit.script(v)
+            if freeze:
+                self.models[k] = torch.jit.freeze(
+                    self.models[k], ["semi_autoregressive"]
+                )
+
+    @torch.jit.ignore
+    def forward(self, *_, **__) -> None:
+        """
+        Don't use this method!
+        """
+        raise NotImplementedError("There's nothing here!")
+
+    def cts_loss(self, *_, **__) -> None:
+        """
+        Don't use this method!
+        """
+        raise NotImplementedError("There's nothing here!")
+
+    def reconstruction_loss(self, *_, **__) -> None:
+        """
+        Don't use this method!
+        """
+        raise NotImplementedError("There's nothing here!")
+
+    def enable_lora(self, *_, **__) -> None:
+        """
+        Don't use this method!
+        """
+        raise NotImplementedError("There's nothing here!")
+
+    def inference(self, *_, **__) -> None:
+        """
+        Don't use this method!
+        """
+        raise NotImplementedError("There's nothing here!")
+
+    @classmethod
+    def from_checkpoint(cls, *_, **__) -> None:
+        """
+        Don't use this method!
+        """
+        raise NotImplementedError("There's nothing here!")
+
+
 if __name__ == "__main__":
     ...
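The new `EnsembleChemBFN` wires several LoRA-finetuned copies of one base ChemBFN together, mixing the guided and unguided logits of each adapter with per-adapter weights before the classifier-free-guidance softmax. A minimal usage sketch under stated assumptions: every checkpoint path below is a placeholder, the conditioning heads are assumed to be `MLP` checkpoints exported from earlier fine-tuning runs, and the property names/keys are illustrative only:

```python
import torch
from bayesianflow_for_chem.model import MLP, EnsembleChemBFN

# Placeholder file names: substitute your own base model, LoRA adapters
# and per-property conditioning heads.
ensemble = EnsembleChemBFN(
    base_model_path="model.pt",
    lora_paths={"logp": "lora_logp.pt", "tpsa": "lora_tpsa.pt"},
    cond_heads={
        "logp": MLP.from_checkpoint("readout_logp.pt"),
        "tpsa": MLP.from_checkpoint("readout_tpsa.pt"),
    },
    adapter_weights={"logp": 0.5, "tpsa": 0.5},
)
# One raw condition tensor per head, shape (n_b, n_c); keys must match lora_paths.
conditions = {"logp": torch.tensor([[2.5]]), "tpsa": torch.tensor([[60.0]])}
tokens, entropy = ensemble.sample(
    batch_size=1, sequence_size=60, conditions=conditions, guidance_strength=4.0
)
```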
{bayesianflow_for_chem-1.2.7 → bayesianflow_for_chem-1.4.0}/bayesianflow_for_chem/tool.py
RENAMED

@@ -1,11 +1,11 @@
 # -*- coding: utf-8 -*-
 # Author: Nianze A. TAO (Omozawa SUENO)
 """
-
+Essential tools.
 """
-import re
 import csv
 import random
+import warnings
 from copy import deepcopy
 from pathlib import Path
 from typing import List, Dict, Tuple, Union, Optional
@@ -16,7 +16,16 @@ from torch import cuda, Tensor, softmax
 from torch.ao import quantization
 from torch.utils.data import DataLoader
 from typing_extensions import Self
-from rdkit.Chem import
+from rdkit.Chem.rdchem import Mol, Bond
+from rdkit.Chem import (
+    rdDetermineBonds,
+    MolFromXYZBlock,
+    MolFromSmiles,
+    MolToSmiles,
+    CanonSmiles,
+    AllChem,
+    AddHs,
+)
 from rdkit.Chem.Scaffolds.MurckoScaffold import MurckoScaffoldSmiles  # type: ignore
 from sklearn.metrics import (
     roc_auc_score,
@@ -26,35 +35,8 @@ from sklearn.metrics import (
     mean_absolute_error,
     root_mean_squared_error,
 )
-
-try:
-    from pynauty import Graph, canon_label  # type: ignore
-
-    _use_pynauty = True
-except ImportError:
-    import warnings
-
-    _use_pynauty = False
-
 from .data import VOCAB_KEYS
-from .model import ChemBFN, MLP, Linear
-
-
-_atom_regex_pattern = (
-    r"(H[e,f,g,s,o]?|"
-    r"L[i,v,a,r,u]|"
-    r"B[e,r,a,i,h,k]?|"
-    r"C[l,a,r,o,u,d,s,n,e,m,f]?|"
-    r"N[e,a,i,b,h,d,o,p]?|"
-    r"O[s,g]?|S[i,c,e,r,n,m,b,g]?|"
-    r"K[r]?|T[i,c,e,a,l,b,h,m,s]|"
-    r"G[a,e,d]|R[b,u,h,e,n,a,f,g]|"
-    r"Yb?|Z[n,r]|P[t,o,d,r,a,u,b,m]?|"
-    r"F[e,r,l,m]?|M[g,n,o,t,c,d]|"
-    r"A[l,r,s,g,u,t,c,m]|I[n,r]?|"
-    r"W|X[e]|E[u,r,s]|U|D[b,s,y])"
-)
-_atom_regex = re.compile(_atom_regex_pattern)
+from .model import ChemBFN, MLP, Linear, EnsembleChemBFN


 def _find_device() -> torch.device:
@@ -65,10 +47,6 @@ def _find_device() -> torch.device:
     return torch.device("cpu")


-def _bond_pair_idx(bonds: Bond) -> List[List[int]]:
-    return [[i.GetBeginAtomIdx(), i.GetEndAtomIdx()] for i in bonds]
-
-
 @torch.no_grad()
 def test(
     model: ChemBFN,
@@ -196,11 +174,14 @@ def split_dataset(
                 "\033[0m",
                 stacklevel=2,
             )
-
-
-        scaffolds
-
-
+        try:
+            scaffold = MurckoScaffoldSmiles(d[smiles_idx[0]])
+            if scaffold in scaffolds:
+                scaffolds[scaffold].append(key)
+            else:
+                scaffolds[scaffold] = [key]
+        except ValueError:  # do nothing when SMILES is not valid
+            ...
     scaffolds = {key: sorted(value) for key, value in scaffolds.items()}
     train_set, test_set, val_set = [], [], []
     for idxs in scaffolds.values():
@@ -222,137 +203,13 @@ def split_dataset(
     writer.writerows([header] + val_set)


-def geo2seq(
-    symbols: List[str],
-    coordinates: np.ndarray,
-    decimals: int = 2,
-    angle_unit: str = "degree",
-) -> str:
-    """
-    Geometry-to-sequence function.\n
-    The algorithm follows the descriptions in paper: https://arxiv.org/abs/2408.10120.
-
-    :param symbols: a list of atomic symbols
-    :param coordinates: Cartesian coordinates; shape: (n_a, 3)
-    :param decimals: number of decimal places to round to
-    :param angle_unit: `'degree'` or `'radian'`
-    :type symbols: list
-    :type coordinates: numpy.ndarray
-    :type decimals: int
-    :type angle_unit: str
-    :return: `Geo2Seq` string
-    :rtype: str
-    """
-    assert angle_unit in ("degree", "radian")
-    angle_scale = 180 / np.pi if angle_unit == "degree" else 1.0
-    n = len(symbols)
-    if n == 1:
-        return f"{symbols[0]} {'0.0'} {'0.0'} {'0.0'}"
-    xyz_block = [str(n), ""]
-    for i, atom in enumerate(symbols):
-        xyz_block.append(
-            f"{atom} {'%.10f' % coordinates[i][0].item()} {'%.10f' % coordinates[i][1].item()} {'%.10f' % coordinates[i][2].item()}"
-        )
-    mol = MolFromXYZBlock("\n".join(xyz_block))
-    rdDetermineBonds.DetermineConnectivity(mol)
-    # ------- Canonicalization -------
-    if _use_pynauty:
-        pair_idx = np.array(_bond_pair_idx(mol.GetBonds())).T.tolist()
-        pair_dict: Dict[int, List[int]] = {}
-        for key, i in enumerate(pair_idx[0]):
-            if i not in pair_dict:
-                pair_dict[i] = [pair_idx[1][key]]
-            else:
-                pair_dict[i].append(pair_idx[1][key])
-        g = Graph(n, adjacency_dict=pair_dict)
-        cl = canon_label(g)  # type: list
-    else:
-        warnings.warn(
-            "\033[32;1m"
-            "`pynauty` is not installed."
-            " Switched to canonicalization function provided by `rdkit`."
-            " This is the expected behaviour only if you are working on Windows platform."
-            "\033[0m",
-            stacklevel=2,
-        )
-        cl = list(CanonicalRankAtoms(mol, breakTies=True))
-    symbols = np.array([[s] for s in symbols])[cl].flatten().tolist()
-    coordinates = coordinates[cl]
-    # ------- Find global coordinate frame -------
-    if n == 2:
-        d = np.round(np.linalg.norm(coordinates[0] - coordinates[1], 2), decimals)
-        return f"{symbols[0]} {'0.0'} {'0.0'} {'0.0'} {symbols[1]} {d} {'0.0'} {'0.0'}"
-    for idx_0 in range(n - 2):
-        _vec0 = coordinates[idx_0] - coordinates[idx_0 + 1]
-        _vec1 = coordinates[idx_0] - coordinates[idx_0 + 2]
-        _d1 = np.linalg.norm(_vec0, 2)
-        _d2 = np.linalg.norm(_vec1, 2)
-        if 1 - np.abs(np.dot(_vec0, _vec1) / (_d1 * _d2)) > 1e-6:
-            break
-    x = (coordinates[idx_0 + 1] - coordinates[idx_0]) / _d1
-    y = np.cross((coordinates[idx_0 + 2] - coordinates[idx_0]), x)
-    y_d = np.linalg.norm(y, 2)
-    y = y / np.ma.filled(np.ma.array(y_d, mask=y_d == 0), np.inf)
-    z = np.cross(x, y)
-    # ------- Build spherical coordinates -------
-    vec = coordinates - coordinates[idx_0]
-    d = np.linalg.norm(vec, 2, axis=-1)
-    _d = np.ma.filled(np.ma.array(d, mask=d == 0), np.inf)
-    theta = angle_scale * np.arccos(np.dot(vec, z) / _d)  # in [0, \pi]
-    phi = angle_scale * np.arctan2(np.dot(vec, y), np.dot(vec, x))  # in [-\pi, \pi]
-    info = np.vstack([d, theta, phi]).T
-    info[idx_0] = np.zeros(3)
-    info = [
-        f"{symbols[i]} {r[0]} {r[1]} {r[2]}"
-        for i, r in enumerate(np.round(info, decimals))
-    ]
-    return " ".join(info)
-
-
-def seq2geo(
-    seq: str, angle_unit: str = "degree"
-) -> Optional[Tuple[List[str], List[List[float]]]]:
-    """
-    Sequence-to-geometry function.\n
-    The method follows the descriptions in paper: https://arxiv.org/abs/2408.10120.
-
-    :param seq: `Geo2Seq` string
-    :param angle_unit: `'degree'` or `'radian'`
-    :type seq: str
-    :type angle_unit: str
-    :return: (symbols, coordinates) if `seq` is valid
-    :rtype: tuple | None
-    """
-    assert angle_unit in ("degree", "radian")
-    angle_scale = np.pi / 180 if angle_unit == "degree" else 1.0
-    tokens = seq.split()
-    if len(tokens) % 4 == 0:
-        tokens = np.array(tokens).reshape(-1, 4).tolist()
-        symbols, coordinates = [], []
-        for i in tokens:
-            symbol = i[0]
-            if len(_atom_regex.findall(symbol)) != 1:
-                return None
-            symbols.append(symbol)
-            try:
-                d, theta, phi = float(i[1]), float(i[2]), float(i[3])
-                x = d * np.sin(theta * angle_scale) * np.cos(phi * angle_scale)
-                y = d * np.sin(theta * angle_scale) * np.sin(phi * angle_scale)
-                z = d * np.cos(theta * angle_scale)
-                coordinates.append([x.item(), y.item(), z.item()])
-            except ValueError:
-                return None
-        return symbols, coordinates
-    return None
-
-
 @torch.no_grad()
 def sample(
-    model: ChemBFN,
+    model: Union[ChemBFN, EnsembleChemBFN],
     batch_size: int,
     sequence_size: int,
     sample_step: int = 100,
-    y: Optional[Tensor] = None,
+    y: Optional[Union[Tensor, Dict[str, Tensor], List[Tensor]]] = None,
     guidance_strength: float = 4.0,
     device: Union[str, torch.device, None] = None,
     vocab_keys: List[str] = VOCAB_KEYS,
@@ -368,7 +225,9 @@ def sample(
     :param batch_size: batch size
     :param sequence_size: max sequence length
     :param sample_step: number of sampling steps
-    :param y: conditioning vector;
+    :param y: conditioning vector; shape: (n_b, 1, n_f) or (n_b, n_f) \n
+        or a list/`dict` of conditions; shape: (n_b, n_c) * n_h
+
     :param guidance_strength: strength of conditional generation. It is not used if y is null.
     :param device: hardware accelerator
     :param vocab_keys: a list of (ordered) vocabulary
@@ -376,11 +235,11 @@ def sample(
     :param method: sampling method chosen from `"ODE:x"` or `"BFN"` where `x` is the value of sampling temperature; default is `"BFN"`
     :param allowed_tokens: a list of allowed tokens; default is `"all"`
     :param sort: whether to sort the samples according to entropy values; default is `False`
-    :type model: bayesianflow_for_chem.model.ChemBFN
+    :type model: bayesianflow_for_chem.model.ChemBFN | bayesianflow_for_chem.model.EnsembleChemBFN
     :type batch_size: int
     :type sequence_size: int
    :type sample_step: int
-    :type y: torch.Tensor | None
+    :type y: torch.Tensor | list | dict | None
     :type guidance_strength: float
     :type device: str | torch.device | None
     :type vocab_keys: list
@@ -392,11 +251,23 @@ def sample(
     :rtype: list
     """
     assert method.split(":")[0].lower() in ("ode", "bfn")
+    if isinstance(model, EnsembleChemBFN):
+        assert y is not None, "conditioning is required while using an ensemble model."
+        assert isinstance(y, list) or isinstance(y, dict)
+    else:
+        assert isinstance(y, Tensor) or y is None
     if device is None:
         device = _find_device()
     model.to(device).eval()
     if y is not None:
-        y
+        if isinstance(y, Tensor):
+            y = y.to(device)
+        elif isinstance(y, list):
+            y = [i.to(device) for i in y]
+        elif isinstance(y, dict):
+            y = {k: v.to(device) for k, v in y.items()}
+        else:
+            raise NotImplementedError
     if isinstance(allowed_tokens, list):
         token_mask = [0 if i in allowed_tokens else 1 for i in vocab_keys]
         token_mask = torch.tensor([[token_mask]], dtype=torch.bool).to(device)
@@ -426,10 +297,10 @@ def sample(

 @torch.no_grad()
 def inpaint(
-    model: ChemBFN,
+    model: Union[ChemBFN, EnsembleChemBFN],
     x: Tensor,
     sample_step: int = 100,
-    y: Optional[Tensor] = None,
+    y: Optional[Union[Tensor, Dict[str, Tensor], List[Tensor]]] = None,
     guidance_strength: float = 4.0,
     device: Union[str, torch.device, None] = None,
     vocab_keys: List[str] = VOCAB_KEYS,
@@ -444,7 +315,9 @@ def inpaint(
     :param model: trained ChemBFN model
     :param x: categorical indices of scaffold; shape: (n_b, n_t)
     :param sample_step: number of sampling steps
-    :param y: conditioning vector; shape: (n_b, 1, n_f)
+    :param y: conditioning vector; shape: (n_b, 1, n_f) or (n_b, n_f) \n
+        or a list/`dict` of conditions; shape: (n_b, n_c) * n_h
+
     :param guidance_strength: strength of conditional generation. It is not used if y is null.
     :param device: hardware accelerator
     :param vocab_keys: a list of (ordered) vocabulary
@@ -452,10 +325,10 @@ def inpaint(
     :param method: sampling method chosen from `"ODE:x"` or `"BFN"` where `x` is the value of sampling temperature; default is `"BFN"`
     :param allowed_tokens: a list of allowed tokens; default is `"all"`
     :param sort: whether to sort the samples according to entropy values; default is `False`
-    :type model: bayesianflow_for_chem.model.ChemBFN
+    :type model: bayesianflow_for_chem.model.ChemBFN | bayesianflow_for_chem.model.EnsembleChemBFN
     :type x: torch.Tensor
     :type sample_step: int
-    :type y: torch.Tensor | None
+    :type y: torch.Tensor | list | dict | None
     :type guidance_strength: float
     :type device: str | torch.device | None
     :type vocab_keys: list
@@ -467,12 +340,24 @@ def inpaint(
     :rtype: list
     """
     assert method.split(":")[0].lower() in ("ode", "bfn")
+    if isinstance(model, EnsembleChemBFN):
+        assert y is not None, "conditioning is required while using an ensemble model."
+        assert isinstance(y, list) or isinstance(y, dict)
+    else:
+        assert isinstance(y, Tensor) or y is None
     if device is None:
         device = _find_device()
     model.to(device).eval()
     x = x.to(device)
     if y is not None:
-        y
+        if isinstance(y, Tensor):
+            y = y.to(device)
+        elif isinstance(y, list):
+            y = [i.to(device) for i in y]
+        elif isinstance(y, dict):
+            y = {k: v.to(device) for k, v in y.items()}
+        else:
+            raise NotImplementedError
     if isinstance(allowed_tokens, list):
         token_mask = [0 if i in allowed_tokens else 1 for i in vocab_keys]
         token_mask = torch.tensor([[token_mask]], dtype=torch.bool).to(device)
@@ -585,6 +470,8 @@ def quantise_model(model: ChemBFN) -> nn.Module:
         assert hasattr(
             mod, "qconfig"
         ), "Input float module must have qconfig defined"
+        if use_precomputed_fake_quant:
+            warnings.warn("Fake quantize operator is not implemented.")
         if mod.qconfig is not None and mod.qconfig.weight is not None:
             weight_observer = mod.qconfig.weight()
         else:
@@ -637,3 +524,81 @@ def quantise_model(model: ChemBFN) -> nn.Module:
         model, {nn.Linear, Linear}, torch.qint8, mapping
     )
     return quantised_model
+
+
+class GeometryConverter:
+    """
+    Converting between different 2D/3D molecular representations.
+    """
+
+    @staticmethod
+    def _xyz2mol(symbols: List[str], coordinates: np.ndarray) -> Mol:
+        xyz_block = [str(len(symbols)), ""]
+        r = coordinates
+        for i, atom in enumerate(symbols):
+            xyz_block.append(f"{atom} {r[i][0]:.10f} {r[i][1]:.10f} {r[i][2]:.10f}")
+        return MolFromXYZBlock("\n".join(xyz_block))
+
+    @staticmethod
+    def _bond_pair_idx(bonds: Bond) -> List[List[int]]:
+        return [[i.GetBeginAtomIdx(), i.GetEndAtomIdx()] for i in bonds]
+
+    @staticmethod
+    def smiles2cartesian(
+        smiles: str, num_conformers: int = 50, random_seed: int = 42
+    ) -> Tuple[List[str], np.ndarray]:
+        """
+        Guess the 3D geometry from SMILES string via MMFF conformer search.
+
+        :param smiles: a valid SMILES string
+        :param num_conformers: number of initial conformers
+        :param random_seed: random seed used to generate conformers
+        :type smiles: str
+        :type num_conformers: int
+        :type random_seed: int
+        :return: atomic symbols \n
+                 cartesian coordinates; shape: (n_a, 3)
+        :rtype: tuple
+        """
+        mol = MolFromSmiles(smiles)
+        mol = AddHs(mol)
+        AllChem.EmbedMultipleConfs(mol, numConfs=num_conformers, randomSeed=random_seed)
+        symbols = [atom.GetSymbol() for atom in mol.GetAtoms()]
+        energies = []
+        for conf_id in range(num_conformers):
+            ff = AllChem.MMFFGetMoleculeForceField(
+                mol, AllChem.MMFFGetMoleculeProperties(mol), confId=conf_id
+            )
+            energy = ff.CalcEnergy()
+            energies.append((conf_id, energy))
+        lowest_energy_conf = min(energies, key=lambda x: x[1])
+        coordinates = mol.GetConformer(id=lowest_energy_conf[0]).GetPositions()
+        return symbols, coordinates
+
+    def cartesian2smiles(
+        self,
+        symbols: List[str],
+        coordinates: np.ndarray,
+        charge: int = 0,
+        canonical: bool = True,
+    ) -> str:
+        """
+        Transform (guess out) molecular geometry to SMILES string.
+
+        :param symbols: a list of atomic symbols
+        :param coordinates: Cartesian coordinates; shape: (n_a, 3)
+        :param charge: net charge
+        :param canonical: whether to canonicalise the SMILES
+        :type symbols: list
+        :type coordinates: numpy.ndarray
+        :type charge: int
+        :type canonical: bool
+        :return: SMILES string
+        :rtype: str
+        """
+        mol = self._xyz2mol(symbols, coordinates)
+        rdDetermineBonds.DetermineBonds(mol, charge=charge)
+        smiles = MolToSmiles(mol)
+        if canonical:
+            smiles = CanonSmiles(smiles)
+        return smiles

{bayesianflow_for_chem-1.2.7 → bayesianflow_for_chem-1.4.0}/bayesianflow_for_chem/train.py
RENAMED

@@ -37,7 +37,8 @@ class Model(LightningModule):
     """
     A `~lightning.LightningModule` wrapper of bayesian flow network for chemistry model.\n
     This module is used in training stage only. By calling `Model(...).export_model(YOUR_WORK_DIR)` after training,
-    the model(s) will be saved to `YOUR_WORK_DIR/model.pt`
+    the model(s) will be saved to `YOUR_WORK_DIR/model.pt` (if LoRA is enabled then `YOUR_WORK_DIR/lora.pt`)
+    and (if exists) `YOUR_WORK_DIR/mlp.pt`.

     :param model: `~bayesianflow_for_chem.model.ChemBFN` instance.
     :param mlp: `~bayesianflow_for_chem.model.MLP` instance or `None`.
@@ -135,7 +136,8 @@ class Regressor(LightningModule):
     """
     A `~lightning.LightningModule` wrapper of bayesian flow network for chemistry regression model.\n
     This module is used in training stage only. By calling `Regressor(...).export_model(YOUR_WORK_DIR)` after training,
-    the models will be saved to `YOUR_WORK_DIR/
+    the models will be saved to `YOUR_WORK_DIR/model_ft.pt` (if LoRA is enabled then `YOUR_WORK_DIR/lora.pt`)
+    and `YOUR_WORK_DIR/readout.pt`.

     :param model: `~bayesianflow_for_chem.model.ChemBFN` instance.
     :param mlp: `~bayesianflow_for_chem.model.MLP` instance.
@@ -218,7 +220,7 @@ class Regressor(LightningModule):
         """
         Save the trained model.

-        :param workdir: the directory to save the
+        :param workdir: the directory to save the models
         :type workdir: pathlib.Path
         :return:
         :rtype: None
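The `GeometryConverter` added to `tool.py` replaces the removed `geo2seq`/`seq2geo` utilities with RDKit-only round trips between SMILES strings and Cartesian geometries. A small sketch, using ethanol as an arbitrary example molecule:

```python
from bayesianflow_for_chem.tool import GeometryConverter

converter = GeometryConverter()
# SMILES -> lowest-MMFF-energy conformer (atomic symbols + Cartesian coordinates)
symbols, coords = converter.smiles2cartesian("CCO", num_conformers=20)
# ...and back: perceive bonds from the geometry and emit a canonical SMILES
smiles = converter.cartesian2smiles(symbols, coords, charge=0, canonical=True)
print(smiles)
```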
{bayesianflow_for_chem-1.2.7 → bayesianflow_for_chem-1.4.0}/bayesianflow_for_chem.egg-info/PKG-INFO
RENAMED

@@ -1,16 +1,15 @@
 Metadata-Version: 2.4
 Name: bayesianflow_for_chem
-Version: 1.2.7
+Version: 1.4.0
 Summary: Bayesian flow network framework for Chemistry
 Home-page: https://augus1999.github.io/bayesian-flow-network-for-chemistry/
 Author: Nianze A. Tao
 Author-email: tao-nianze@hiroshima-u.ac.jp
-License: AGPL-3.0
+License: AGPL-3.0-or-later
 Project-URL: Source, https://github.com/Augus1999/bayesian-flow-network-for-chemistry
 Keywords: Chemistry,CLM,ChemBFN
 Classifier: Development Status :: 5 - Production/Stable
 Classifier: Intended Audience :: Science/Research
-Classifier: License :: OSI Approved :: GNU Affero General Public License v3
 Classifier: Natural Language :: English
 Classifier: Programming Language :: Python :: 3
 Classifier: Programming Language :: Python :: 3.9
@@ -29,8 +28,6 @@ Requires-Dist: loralib>=0.1.2
 Requires-Dist: lightning>=2.2.0
 Requires-Dist: scikit-learn>=1.5.0
 Requires-Dist: typing_extensions>=4.8.0
-Provides-Extra: geo2seq
-Requires-Dist: pynauty>=2.8.8.1; extra == "geo2seq"
 Dynamic: author
 Dynamic: author-email
 Dynamic: classifier
@@ -41,7 +38,6 @@ Dynamic: keywords
 Dynamic: license
 Dynamic: license-file
 Dynamic: project-url
-Dynamic: provides-extra
 Dynamic: requires-dist
 Dynamic: requires-python
 Dynamic: summary
@@ -87,13 +83,13 @@ You can find example scripts in [📁example](./example) folder.

 ## Pre-trained Model

-You can find pretrained models
+You can find pretrained models on our [🤗Hugging Face model page](https://huggingface.co/suenoomozawa/ChemBFN).

 ## Dataset Handling

 We provide a Python class [`CSVData`](./bayesianflow_for_chem/data.py) to handle data stored in CSV or similar format containing headers to identify the entities. The following is a quickstart.

-1. Download your dataset file (e.g., ESOL
+1. Download your dataset file (e.g., ESOL from [MoleculeNet](https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/delaney-processed.csv)) and split the file:
 ```python
 >>> from bayesianflow_for_chem.tool import split_data

{bayesianflow_for_chem-1.2.7 → bayesianflow_for_chem-1.4.0}/setup.py
RENAMED

@@ -28,7 +28,8 @@ setup(
     description="Bayesian flow network framework for Chemistry",
     long_description=long_description,
     long_description_content_type="text/markdown",
-    license="AGPL-3.0
+    license="AGPL-3.0-or-later",
+    license_files=["LICEN[CS]E*"],
     package_dir={"bayesianflow_for_chem": "bayesianflow_for_chem"},
     package_data={"bayesianflow_for_chem": ["./*.txt", "./*.py"]},
     include_package_data=True,
@@ -45,14 +46,12 @@ setup(
         "scikit-learn>=1.5.0",
         "typing_extensions>=4.8.0",
     ],
-    extras_require={"geo2seq": ["pynauty>=2.8.8.1"]},
     project_urls={
         "Source": "https://github.com/Augus1999/bayesian-flow-network-for-chemistry"
     },
     classifiers=[
         "Development Status :: 5 - Production/Stable",
         "Intended Audience :: Science/Research",
-        "License :: OSI Approved :: GNU Affero General Public License v3",
         "Natural Language :: English",
         "Programming Language :: Python :: 3",
         "Programming Language :: Python :: 3.9",

The remaining files — LICENSE, bayesianflow_for_chem/vocab.txt, bayesianflow_for_chem.egg-info/SOURCES.txt, bayesianflow_for_chem.egg-info/dependency_links.txt, bayesianflow_for_chem.egg-info/top_level.txt, and setup.cfg — are unchanged between the two versions.