bayesianflow-for-chem 1.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of bayesianflow-for-chem might be problematic. Click here for more details.
- bayesianflow_for_chem/__init__.py +11 -0
- bayesianflow_for_chem/data.py +250 -0
- bayesianflow_for_chem/model.py +927 -0
- bayesianflow_for_chem/scorer.py +134 -0
- bayesianflow_for_chem/tool.py +470 -0
- bayesianflow_for_chem/train.py +243 -0
- bayesianflow_for_chem/vocab.txt +246 -0
- bayesianflow_for_chem-1.2.0.dist-info/METADATA +162 -0
- bayesianflow_for_chem-1.2.0.dist-info/RECORD +11 -0
- bayesianflow_for_chem-1.2.0.dist-info/WHEEL +5 -0
- bayesianflow_for_chem-1.2.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,134 @@
|
|
|
1
|
+
# -*- coding: utf-8 -*-
|
|
2
|
+
# Author: Nianze A. TAO (Omozawa SUENO)
|
|
3
|
+
"""
|
|
4
|
+
Scorers.
|
|
5
|
+
"""
|
|
6
|
+
from typing import List, Callable, Union, Optional
|
|
7
|
+
import torch
|
|
8
|
+
from torch import Tensor
|
|
9
|
+
from rdkit import RDLogger
|
|
10
|
+
from rdkit.Contrib.SA_Score import sascorer # type: ignore
|
|
11
|
+
from rdkit.Chem import MolFromSmiles, QED
|
|
12
|
+
|
|
13
|
+
# Turn off every RDKit log channel matching "rdApp.*". The scorers below feed
# arbitrary (possibly invalid) strings to MolFromSmiles; presumably this keeps
# RDKit's parse messages off the console.
RDLogger.DisableLog("rdApp.*")  # type: ignore
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def smiles_valid(smiles: str) -> int:
    """
    Tell whether a SMILES string can be parsed by RDKit.

    An empty string is always reported as invalid.

    :param smiles: SMILES string
    :type smiles: str
    :return: `1` if the string is non-empty and RDKit parses it, otherwise `0`
    :rtype: int
    """
    molecule = MolFromSmiles(smiles)
    if molecule and smiles:
        return 1
    return 0
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def qed_score(smiles: str) -> float:
    """
    Compute RDKit's quantitative estimate of drug-likeness (QED) for a SMILES string.

    No validity check is done here; the string is expected to be parseable.

    :param smiles: SMILES string
    :type smiles: str
    :return: QED score
    :rtype: float
    """
    molecule = MolFromSmiles(smiles)
    return QED.qed(molecule)
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def sa_score(smiles: str) -> float:
    """
    Compute the synthetic accessibility (SA) score of a SMILES string.

    The score comes from RDKit's contributed `sascorer` module. No validity
    check is done here; the string is expected to be parseable.

    :param smiles: SMILES string
    :type smiles: str
    :return: SA score
    :rtype: float
    """
    molecule = MolFromSmiles(smiles)
    return sascorer.calculateScore(molecule)
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
class Scorer:
    def __init__(
        self,
        scorers: List[Callable[[str], Union[int, float]]],
        score_criteria: List[Callable[[Union[int, float]], float]],
        vocab_keys: List[str],
        vocab_separator: str = "",
        valid_checker: Optional[Callable[[str], int]] = None,
        eta: float = 1e-2,
        name: str = "scorer",
    ) -> None:
        """
        Scorer class.
        e.g.

        ```python
        scorer = Scorer(
            scorers=[smiles_valid, qed_score],
            score_criteria=[lambda x: float(x == 1), lambda x: float(x > 0.5)],
            vocab_keys=VOCAB_KEYS,
        )
        ```

        :param scorers: a list of scorer(s); each maps a decoded string to a number
        :param score_criteria: a list of score criterion (or criteria) in the same order of scorers;
                               each maps a score to a float (1 means fully satisfied)
        :param vocab_keys: a list of (ordered) vocabulary
        :param vocab_separator: token separator; default is `""`
        :param valid_checker: a callable to check the validity of sequences (returns 0 for invalid);
                              default is `None`, i.e. every sequence is treated as valid
        :param eta: the coefficient to be multiplied to the loss; it is only stored here
                    and not applied by `calc_score_loss`
        :param name: the name of this scorer
        :type scorers: list
        :type score_criteria: list
        :type vocab_keys: list
        :type vocab_separator: str
        :type eta: float
        :type name: str
        :type valid_checker: typing.Callable | None
        """
        assert len(scorers) == len(
            score_criteria
        ), "The number of scores should match that of criteria."
        self.scorers = scorers
        self.score_criteria = score_criteria
        self.vocab_keys = vocab_keys
        self.vocab_separator = vocab_separator
        self.valid_checker = valid_checker
        self.eta = eta
        self.name = name

    def calc_score_loss(self, p: Tensor) -> Tensor:
        """
        Calculate the score loss.

        Each sequence is decoded from its argmax tokens. Decoding keeps the text
        after the last `<start>`+separator and before the first separator+`<end>`,
        and drops `<pad>`. Each decoded string then receives one penalty per scorer.
        The penalty is `1` if the string is invalid, otherwise
        `1 - criterion(scorer(string))`. The loss of a sequence is the mean
        probability of its argmax tokens times its mean penalty. The loss is
        averaged over the batch.

        :param p: token probability distributions; shape: (n_b, n_t, n_vocab)
        :type p: torch.Tensor
        :return: score loss; shape: ()
        :rtype: torch.Tensor
        """
        tokens = p.argmax(-1)
        e_k = torch.nn.functional.one_hot(tokens, len(self.vocab_keys)).float()
        seqs = [
            self.vocab_separator.join([self.vocab_keys[i] for i in j])
            .split("<start>" + self.vocab_separator)[-1]
            .split(self.vocab_separator + "<end>")[0]
            .replace("<pad>", "")
            for j in tokens
        ]
        valid = [
            1 if self.valid_checker is None else self.valid_checker(i) for i in seqs
        ]
        # scores[i][j]: penalty of sequence j under scorer i
        scores = [
            [
                1 if valid[j] == 0 else 1 - self.score_criteria[i](scorer(seq))
                for j, seq in enumerate(seqs)
            ]
            for i, scorer in enumerate(self.scorers)
        ]
        # Force a floating dtype. If every sequence is invalid, all penalties are
        # the Python int 1. torch.tensor would then infer int64, and .mean() raises
        # on integer tensors. The default float dtype is what torch.tensor
        # infers whenever any penalty is a float.
        penalty = torch.tensor(
            scores, dtype=torch.get_default_dtype(), device=p.device
        ).mean(0)
        # (e_k * p).sum(2) picks the probability of the argmax token at each position
        loss = (e_k * p).sum(2).mean(1) * penalty
        return loss.mean()
|
|
131
|
+
|
|
132
|
+
|
|
133
|
+
if __name__ == "__main__":
    # Placeholder: executing this module as a script does nothing.
    ...
|
|
@@ -0,0 +1,470 @@
|
|
|
1
|
+
# -*- coding: utf-8 -*-
|
|
2
|
+
# Author: Nianze A. TAO (Omozawa SUENO)
|
|
3
|
+
"""
|
|
4
|
+
Tools.
|
|
5
|
+
"""
|
|
6
|
+
import re
|
|
7
|
+
import csv
|
|
8
|
+
import random
|
|
9
|
+
from pathlib import Path
|
|
10
|
+
from typing import List, Dict, Tuple, Union, Optional
|
|
11
|
+
import torch
|
|
12
|
+
import numpy as np
|
|
13
|
+
from torch import cuda, Tensor, softmax
|
|
14
|
+
from torch.utils.data import DataLoader
|
|
15
|
+
from rdkit.Chem import rdDetermineBonds, Bond, MolFromXYZBlock, CanonicalRankAtoms
|
|
16
|
+
from rdkit.Chem.Scaffolds.MurckoScaffold import MurckoScaffoldSmiles # type: ignore
|
|
17
|
+
from sklearn.metrics import (
|
|
18
|
+
roc_auc_score,
|
|
19
|
+
auc,
|
|
20
|
+
precision_recall_curve,
|
|
21
|
+
r2_score,
|
|
22
|
+
mean_absolute_error,
|
|
23
|
+
root_mean_squared_error,
|
|
24
|
+
)
|
|
25
|
+
|
|
26
|
+
try:
|
|
27
|
+
from pynauty import Graph, canon_label # type: ignore
|
|
28
|
+
|
|
29
|
+
_use_pynauty = True
|
|
30
|
+
except ImportError:
|
|
31
|
+
import warnings
|
|
32
|
+
|
|
33
|
+
_use_pynauty = False
|
|
34
|
+
|
|
35
|
+
from .data import VOCAB_KEYS
|
|
36
|
+
from .model import ChemBFN, MLP
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
_atom_regex_pattern = (
|
|
40
|
+
r"(H[e,f,g,s,o]?|"
|
|
41
|
+
r"L[i,v,a,r,u]|"
|
|
42
|
+
r"B[e,r,a,i,h,k]?|"
|
|
43
|
+
r"C[l,a,r,o,u,d,s,n,e,m,f]?|"
|
|
44
|
+
r"N[e,a,i,b,h,d,o,p]?|"
|
|
45
|
+
r"O[s,g]?|S[i,c,e,r,n,m,b,g]?|"
|
|
46
|
+
r"K[r]?|T[i,c,e,a,l,b,h,m,s]|"
|
|
47
|
+
r"G[a,e,d]|R[b,u,h,e,n,a,f,g]|"
|
|
48
|
+
r"Yb?|Z[n,r]|P[t,o,d,r,a,u,b,m]?|"
|
|
49
|
+
r"F[e,r,l,m]?|M[g,n,o,t,c,d]|"
|
|
50
|
+
r"A[l,r,s,g,u,t,c,m]|I[n,r]?|"
|
|
51
|
+
r"W|X[e]|E[u,r,s]|U|D[b,s,y])"
|
|
52
|
+
)
|
|
53
|
+
_atom_regex = re.compile(_atom_regex_pattern)
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
def _find_device() -> torch.device:
|
|
57
|
+
if cuda.is_available():
|
|
58
|
+
return torch.device("cuda")
|
|
59
|
+
elif torch.backends.mps.is_available():
|
|
60
|
+
return torch.device("mps")
|
|
61
|
+
return torch.device("cpu")
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
def _bond_pair_idx(bonds: Bond) -> List[List[int]]:
    """
    List `[begin_atom_index, end_atom_index]` for every bond in `bonds`.

    `bonds` is iterated, e.g. the result of `mol.GetBonds()`.
    """
    return [[bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()] for bond in bonds]
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
@torch.no_grad()
def test(
    model: ChemBFN,
    mlp: MLP,
    data: DataLoader,
    mode: str = "regression",
    device: Union[str, torch.device, None] = None,
) -> Dict[str, List[float]]:
    """
    Test the trained network.

    :param model: pretrained ChemBFN model
    :param mlp: trained MLP model for testing
    :param data: DataLoader instance yielding dicts with `"token"` and `"value"` entries
    :param mode: testing mode chosen from `'regression'` and `'classification'`
    :param device: hardware accelerator; picked automatically when `None`
    :type model: bayesianflow_for_chem.model.ChemBFN
    :type mlp: bayesianflow_for_chem.model.MLP
    :type data: torch.utils.data.DataLoader
    :type mode: str
    :type device: str | torch.device | None
    :return: MAE & RMSE & R^2 / ROC-AUC & PRC-AUC; each value is a list with one
             entry per label column (PRC-AUC is empty if it cannot be computed)
    :rtype: dict
    :raises AssertionError: if `mode` is not one of the two supported modes
    """
    # Without this check an unknown mode left `y_hat` unbound or returned None.
    assert mode in ("regression", "classification"), f"Unknown testing mode: {mode}"
    if device is None:
        device = _find_device()
    model.to(device).eval()
    mlp.to(device).eval()
    predict_y, label_y = [], []
    for d in data:
        x, y = d["token"].to(device), d["value"]
        label_y.append(y)
        if mode == "regression":
            y_hat = model.inference(x, mlp)
        else:
            n_b, n_y = y.shape
            # softmax over the class outputs of each (sample, label column) pair,
            # then lay the per-column probabilities side by side again
            y_hat = softmax(model.inference(x, mlp).reshape(n_b * n_y, -1), -1)
            y_hat = y_hat.reshape(n_b, -1)
        predict_y.append(y_hat.detach().to("cpu"))
    predict_y, label_y = torch.cat(predict_y, 0), torch.cat(label_y, 0).split(1, -1)
    if mode == "regression":
        # labels equal to `inf` are excluded (per column) from the metrics
        predict_y = [
            predict[label_y[i] != torch.inf]
            for (i, predict) in enumerate(predict_y.split(1, -1))
        ]
        label_y = [label[label != torch.inf] for label in label_y]
        y_zipped = list(zip(label_y, predict_y))
        mae = [mean_absolute_error(label, predict) for (label, predict) in y_zipped]
        rmse = [
            root_mean_squared_error(label, predict) for (label, predict) in y_zipped
        ]
        r2 = [r2_score(label, predict) for (label, predict) in y_zipped]
        return {"MAE": mae, "RMSE": rmse, "R^2": r2}
    # classification: one chunk of class probabilities per label column
    n_c = len(label_y)
    predict_y = predict_y.chunk(n_c, -1)
    y_zipped = list(zip(label_y, predict_y))
    roc_auc = [
        roc_auc_score(
            label.flatten(),
            predict[:, 1] if predict.shape[-1] == 2 else predict,
            multi_class="raise" if predict.shape[-1] == 2 else "ovo",
            labels=None if predict.shape[-1] == 2 else range(predict.shape[-1]),
        )
        for (label, predict) in y_zipped
    ]
    try:
        prc = [
            precision_recall_curve(label.flatten(), predict[:, 1])[:2]
            for (label, predict) in y_zipped
        ]
        prc_auc = [auc(recall, precision) for (precision, recall) in prc]
    except ValueError:
        # e.g. multi-class labels, which precision_recall_curve rejects
        prc_auc = []
    return {"ROC-AUC": roc_auc, "PRC-AUC": prc_auc}
|
|
143
|
+
|
|
144
|
+
|
|
145
|
+
def split_dataset(
    file: Union[str, Path], split_ratio: List[int] = [8, 1, 1], method: str = "random"
) -> None:
    """
    Split a dataset.

    The CSV must have a header row with a column named `smiles`
    (case-insensitive; the first such column is used). Three files are written
    next to the input, `<name>_train.csv`, `<name>_test.csv` and `<name>_val.csv`,
    each starting with the original header.

    :param file: dataset file <file>.csv
    :param split_ratio: training-testing-validation ratio
    :param method: chosen from `'random'` and `'scaffold'`
    :type file: str | pathlib.Path
    :type split_ratio: list
    :type method: str
    :return:
    :rtype: None
    """
    # `pathlib.Path` has no `.endswith`, and its `.replace` renames the file,
    # so work on the string form of the path.
    file = str(file)
    assert file.endswith(".csv")
    assert len(split_ratio) == 3
    assert method in ("random", "scaffold")
    with open(file, "r", newline="") as f:
        data = list(csv.reader(f))
    header, raw_data = data[0], data[1:]
    smiles_col = next(
        (key for key, h in enumerate(header) if h.lower() == "smiles"), None
    )
    assert smiles_col is not None
    data_len = len(raw_data)
    total = sum(split_ratio)
    # end index (exclusive) of the training part and of the testing part
    train_idx = int(data_len * (split_ratio[0] / total))
    test_idx = int(data_len * (sum(split_ratio[:2]) / total))
    if method == "random":
        random.shuffle(raw_data)
        train_set = raw_data[:train_idx]
        test_set = raw_data[train_idx:test_idx]
        val_set = raw_data[test_idx:]
    else:
        # group rows by Bemis-Murcko scaffold; groups keep first-seen order and
        # row indices are appended in ascending order
        scaffolds: Dict[str, List[int]] = {}
        for key, d in enumerate(raw_data):
            scaffolds.setdefault(MurckoScaffoldSmiles(d[smiles_col]), []).append(key)
        train_set, test_set, val_set = [], [], []
        # whole scaffold groups go to train while they fit, then test, then val
        for idxs in scaffolds.values():
            group = [raw_data[i] for i in idxs]
            if len(train_set) + len(idxs) <= train_idx:
                train_set += group
            elif len(train_set) + len(test_set) + len(idxs) <= test_idx:
                test_set += group
            else:
                val_set += group
    # only the trailing ".csv" is replaced (str.replace would hit every occurrence)
    stem = file[: -len(".csv")]
    outputs = (("_train.csv", train_set), ("_test.csv", test_set), ("_val.csv", val_set))
    for suffix, subset in outputs:
        with open(stem + suffix, "w", newline="") as fo:
            csv.writer(fo).writerows([header] + subset)
|
|
209
|
+
|
|
210
|
+
|
|
211
|
+
def geo2seq(
    symbols: List[str],
    coordinates: np.ndarray,
    decimals: int = 2,
    angle_unit: str = "degree",
) -> str:
    """
    Geometry-to-sequence function.\n
    The algorithm follows the descriptions in paper: https://arxiv.org/abs/2408.10120.

    Atoms are first put in a canonical order, obtained by canonicalizing the
    connectivity graph that RDKit perceives from the coordinates. Each atom is
    then written as `symbol d theta phi`. These are spherical coordinates about
    the first atom (in canonical order) whose next two atoms are not collinear
    with it.

    :param symbols: a list of atomic symbols
    :param coordinates: Cartesian coordinates; shape: (n_a, 3); any array-like accepted by `numpy.asarray`
    :param decimals: number of decimal places to round to
    :param angle_unit: `'degree'` or `'radian'`
    :type symbols: list
    :type coordinates: numpy.ndarray
    :type decimals: int
    :type angle_unit: str
    :return: `Geo2Seq` string
    :rtype: str
    """
    assert angle_unit in ("degree", "radian")
    angle_scale = 180 / np.pi if angle_unit == "degree" else 1.0
    n = len(symbols)
    if n == 1:
        return f"{symbols[0]} {'0.0'} {'0.0'} {'0.0'}"
    coordinates = np.asarray(coordinates)
    xyz_block = [str(n), ""]
    for i, atom in enumerate(symbols):
        xyz_block.append(
            f"{atom} {'%.10f' % coordinates[i][0].item()} {'%.10f' % coordinates[i][1].item()} {'%.10f' % coordinates[i][2].item()}"
        )
    mol = MolFromXYZBlock("\n".join(xyz_block))
    rdDetermineBonds.DetermineConnectivity(mol)
    # ------- Canonicalization -------
    # `cl` lists the original atom indices in canonical order
    # (the convention of pynauty's `canon_label`).
    if _use_pynauty:
        adjacency: Dict[int, List[int]] = {}
        # iterating the bonds directly also works when there are no bonds
        for begin, end in _bond_pair_idx(mol.GetBonds()):
            adjacency.setdefault(begin, []).append(end)
        cl = canon_label(Graph(n, adjacency_dict=adjacency))  # type: list
    else:
        warnings.warn(
            "\033[32;1m"
            "`pynauty` is not installed."
            " Switched to canonicalization function provided by `rdkit`."
            " This is the expected behaviour only if you are working on Windows platform."
            "\033[0m",
            stacklevel=2,
        )
        # CanonicalRankAtoms gives the canonical rank OF each atom. argsort turns
        # that into "which atom sits at each canonical position". Indexing by the
        # ranks directly applies the inverse permutation, which is not canonical.
        cl = np.argsort(list(CanonicalRankAtoms(mol, breakTies=True))).tolist()
    symbols = [symbols[i] for i in cl]
    coordinates = coordinates[cl]
    # ------- Find global coordinate frame -------
    if n == 2:
        d = np.round(np.linalg.norm(coordinates[0] - coordinates[1], 2), decimals)
        return f"{symbols[0]} {'0.0'} {'0.0'} {'0.0'} {symbols[1]} {d} {'0.0'} {'0.0'}"
    # anchor on the first atom whose next two atoms are not collinear with it;
    # if every triple is collinear the last candidate is used
    for idx_0 in range(n - 2):
        _vec0 = coordinates[idx_0] - coordinates[idx_0 + 1]
        _vec1 = coordinates[idx_0] - coordinates[idx_0 + 2]
        _d1 = np.linalg.norm(_vec0, 2)
        _d2 = np.linalg.norm(_vec1, 2)
        if 1 - np.abs(np.dot(_vec0, _vec1) / (_d1 * _d2)) > 1e-6:
            break
    x = (coordinates[idx_0 + 1] - coordinates[idx_0]) / _d1
    y = np.cross((coordinates[idx_0 + 2] - coordinates[idx_0]), x)
    y_d = np.linalg.norm(y, 2)
    # collinear geometry gives a zero y; dividing by inf keeps it zero instead of nan
    y = y / np.where(y_d == 0, np.inf, y_d)
    z = np.cross(x, y)
    # ------- Build spherical coordinates -------
    vec = coordinates - coordinates[idx_0]
    d = np.linalg.norm(vec, 2, axis=-1)
    _d = np.where(d == 0, np.inf, d)  # the anchor atom has d == 0
    # clip: rounding can push |cos| slightly above 1, where arccos returns nan
    cos_theta = np.clip(np.dot(vec, z) / _d, -1.0, 1.0)
    theta = angle_scale * np.arccos(cos_theta)  # in [0, \pi]
    phi = angle_scale * np.arctan2(np.dot(vec, y), np.dot(vec, x))  # in [-\pi, \pi]
    info = np.vstack([d, theta, phi]).T
    info[idx_0] = np.zeros(3)
    info = [
        f"{symbols[i]} {r[0]} {r[1]} {r[2]}"
        for i, r in enumerate(np.round(info, decimals))
    ]
    return " ".join(info)
|
|
296
|
+
|
|
297
|
+
|
|
298
|
+
def seq2geo(
    seq: str, angle_unit: str = "degree"
) -> Optional[Tuple[List[str], List[List[float]]]]:
    """
    Sequence-to-geometry function.\n
    The method follows the descriptions in paper: https://arxiv.org/abs/2408.10120.

    The string is read as consecutive groups of four whitespace-separated tokens,
    `symbol d theta phi`, which are converted from spherical to Cartesian coordinates.

    :param seq: `Geo2Seq` string
    :param angle_unit: `'degree'` or `'radian'`
    :type seq: str
    :type angle_unit: str
    :return: (symbols, coordinates) if `seq` is valid, otherwise `None`
             (token count not a multiple of 4, unknown element symbol,
             or a non-numeric value)
    :rtype: tuple | None
    """
    assert angle_unit in ("degree", "radian")
    angle_scale = np.pi / 180 if angle_unit == "degree" else 1.0
    tokens = seq.split()
    if len(tokens) % 4 != 0:
        return None
    symbols, coordinates = [], []
    for start in range(0, len(tokens), 4):
        symbol, *values = tokens[start : start + 4]
        # the whole token must be one element symbol; counting `findall` matches
        # would also accept tokens such as "C1" or "[C]"
        if _atom_regex.fullmatch(symbol) is None:
            return None
        symbols.append(symbol)
        try:
            d, theta, phi = float(values[0]), float(values[1]), float(values[2])
        except ValueError:
            return None
        x = d * np.sin(theta * angle_scale) * np.cos(phi * angle_scale)
        y = d * np.sin(theta * angle_scale) * np.sin(phi * angle_scale)
        z = d * np.cos(theta * angle_scale)
        coordinates.append([x.item(), y.item(), z.item()])
    return symbols, coordinates
|
|
333
|
+
|
|
334
|
+
|
|
335
|
+
@torch.no_grad()
def sample(
    model: ChemBFN,
    batch_size: int,
    sequence_size: int,
    sample_step: int = 100,
    y: Optional[Tensor] = None,
    guidance_strength: float = 4.0,
    device: Union[str, torch.device, None] = None,
    vocab_keys: List[str] = VOCAB_KEYS,
    seperator: str = "",
    method: str = "BFN",
    allowed_tokens: Union[str, List[str]] = "all",
    separator: Optional[str] = None,
) -> List[str]:
    """
    Sampling.

    Each generated token sequence is joined with the separator. Only the text
    after the last `<start>`+separator and before the first separator+`<end>`
    is kept, with `<pad>` removed.

    :param model: trained ChemBFN model
    :param batch_size: batch size
    :param sequence_size: max sequence length
    :param sample_step: number of sampling steps
    :param y: conditioning vector; shape: (n_b, 1, n_f)
    :param guidance_strength: strength of conditional generation. It is not used if y is null.
    :param device: hardware accelerator; picked automatically when `None`
    :param vocab_keys: a list of (ordered) vocabulary
    :param seperator: token separator (legacy spelling, kept for compatibility); default is `""`
    :param method: sampling method chosen from `"ODE:x"` or `"BFN"` where `x` is the value of sampling temperature; default is `"BFN"`
    :param allowed_tokens: a list of allowed tokens; any non-list value (e.g. the default `"all"`) allows every token
    :param separator: token separator; overrides `seperator` when given; default is `None`
    :type model: bayesianflow_for_chem.model.ChemBFN
    :type batch_size: int
    :type sequence_size: int
    :type sample_step: int
    :type y: torch.Tensor | None
    :type guidance_strength: float
    :type device: str | torch.device | None
    :type vocab_keys: list
    :type seperator: str
    :type method: str
    :type allowed_tokens: str | list
    :type separator: str | None
    :return: a list of generated molecular strings
    :rtype: list
    """
    method_name = method.split(":")[0].lower()
    assert method_name in ("ode", "bfn")
    # `separator` is the correctly spelled alias and wins when provided
    sep = seperator if separator is None else separator
    if device is None:
        device = _find_device()
    model.to(device).eval()
    if y is not None:
        y = y.to(device)
    if isinstance(allowed_tokens, list):
        allowed = set(allowed_tokens)
        # True marks a token that must NOT be generated
        token_mask = [0 if i in allowed else 1 for i in vocab_keys]
        token_mask = torch.tensor([[token_mask]], dtype=torch.bool).to(device)
    else:
        token_mask = None
    if method_name == "ode":
        tp = float(method.split(":")[-1])
        assert tp > 0, "Sampling temperature should be higher than 0."
        tokens = model.ode_sample(
            batch_size, sequence_size, y, sample_step, guidance_strength, token_mask, tp
        )
    else:
        tokens = model.sample(
            batch_size, sequence_size, y, sample_step, guidance_strength, token_mask
        )
    return [
        sep.join([vocab_keys[i] for i in j])
        .split("<start>" + sep)[-1]
        .split(sep + "<end>")[0]
        .replace("<pad>", "")
        for j in tokens
    ]
|
|
405
|
+
|
|
406
|
+
|
|
407
|
+
@torch.no_grad()
def inpaint(
    model: ChemBFN,
    x: Tensor,
    sample_step: int = 100,
    y: Optional[Tensor] = None,
    guidance_strength: float = 4.0,
    device: Union[str, torch.device, None] = None,
    vocab_keys: List[str] = VOCAB_KEYS,
    separator: str = "",
    method: str = "BFN",
    allowed_tokens: Union[str, List[str]] = "all",
) -> List[str]:
    """
    Inpaint (context guided) sampling.

    Each generated token sequence is joined with `separator`. Only the text
    after the last `<start>`+separator and before the first separator+`<end>`
    is kept, with `<pad>` removed.

    :param model: trained ChemBFN model
    :param x: categorical indices of scaffold; shape: (n_b, n_t)
    :param sample_step: number of sampling steps
    :param y: conditioning vector; shape: (n_b, 1, n_f)
    :param guidance_strength: strength of conditional generation. It is not used if y is null.
    :param device: hardware accelerator; picked automatically when `None`
    :param vocab_keys: a list of (ordered) vocabulary
    :param separator: token separator; default is `""`
    :param method: sampling method chosen from `"ODE:x"` or `"BFN"` where `x` is the value of sampling temperature; default is `"BFN"`
    :param allowed_tokens: a list of allowed tokens; any non-list value (e.g. the default `"all"`) allows every token
    :type model: bayesianflow_for_chem.model.ChemBFN
    :type x: torch.Tensor
    :type sample_step: int
    :type y: torch.Tensor | None
    :type guidance_strength: float
    :type device: str | torch.device | None
    :type vocab_keys: list
    :type separator: str
    :type method: str
    :type allowed_tokens: str | list
    :return: a list of generated molecular strings
    :rtype: list
    """
    method_name = method.split(":")[0].lower()
    assert method_name in ("ode", "bfn")
    if device is None:
        device = _find_device()
    model.to(device).eval()
    x = x.to(device)
    if y is not None:
        y = y.to(device)
    if isinstance(allowed_tokens, list):
        allowed = set(allowed_tokens)
        # True marks a token that must NOT be generated
        token_mask = [0 if i in allowed else 1 for i in vocab_keys]
        token_mask = torch.tensor([[token_mask]], dtype=torch.bool).to(device)
    else:
        token_mask = None
    if method_name == "ode":
        tp = float(method.split(":")[-1])
        assert tp > 0, "Sampling temperature should be higher than 0."
        tokens = model.ode_inpaint(x, y, sample_step, guidance_strength, token_mask, tp)
    else:
        tokens = model.inpaint(x, y, sample_step, guidance_strength, token_mask)
    return [
        separator.join([vocab_keys[i] for i in j])
        .split("<start>" + separator)[-1]
        .split(separator + "<end>")[0]
        .replace("<pad>", "")
        for j in tokens
    ]
|