bayesianflow-for-chem 1.2.0__py3-none-any.whl


@@ -0,0 +1,134 @@
+ # -*- coding: utf-8 -*-
+ # Author: Nianze A. TAO (Omozawa SUENO)
+ """
+ Scorers.
+ """
+ from typing import List, Callable, Union, Optional
+ import torch
+ from torch import Tensor
+ from rdkit import RDLogger
+ from rdkit.Contrib.SA_Score import sascorer  # type: ignore
+ from rdkit.Chem import MolFromSmiles, QED
+
+ RDLogger.DisableLog("rdApp.*")  # type: ignore
+
+
+ def smiles_valid(smiles: str) -> int:
+     """
+     Return the validity of a SMILES string.
+
+     :param smiles: SMILES string
+     :type smiles: str
+     :return: validity
+     :rtype: int
+     """
+     return 1 if (MolFromSmiles(smiles) and smiles) else 0
+
+
+ def qed_score(smiles: str) -> float:
+     """
+     Return the quantitative estimate of drug-likeness score of a SMILES string.
+
+     :param smiles: SMILES string
+     :type smiles: str
+     :return: QED score
+     :rtype: float
+     """
+     return QED.qed(MolFromSmiles(smiles))
+
+
+ def sa_score(smiles: str) -> float:
+     """
+     Return the synthetic accessibility score of a SMILES string.
+
+     :param smiles: SMILES string
+     :type smiles: str
+     :return: SA score
+     :rtype: float
+     """
+     return sascorer.calculateScore(MolFromSmiles(smiles))
+
+
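A quick editor's check of the three scorers (not part of the package; it assumes the functions above are in scope and that RDKit's `SA_Score` contrib is available):

```python
aspirin = "CC(=O)Oc1ccccc1C(=O)O"

print(smiles_valid(aspirin))          # -> 1
print(smiles_valid("not-a-smiles"))   # -> 0 (parse failure is silenced by DisableLog)
print(round(qed_score(aspirin), 3))   # QED in (0, 1); higher = more drug-like
print(round(sa_score(aspirin), 3))    # SA in [1, 10]; lower = easier to synthesise
```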
+ class Scorer:
+     def __init__(
+         self,
+         scorers: List[Callable[[str], Union[int, float]]],
+         score_criteria: List[Callable[[Union[int, float]], float]],
+         vocab_keys: List[str],
+         vocab_separator: str = "",
+         valid_checker: Optional[Callable[[str], int]] = None,
+         eta: float = 1e-2,
+         name: str = "scorer",
+     ) -> None:
+         """
+         Scorer class, e.g.
+
+         ```python
+         scorer = Scorer(
+             scorers=[smiles_valid, qed_score],
+             score_criteria=[lambda x: float(x == 1), lambda x: float(x > 0.5)],
+             vocab_keys=VOCAB_KEYS,
+         )
+         ```
+
+         :param scorers: a list of scorer(s)
+         :param score_criteria: a list of score criteria, in the same order as `scorers`
+         :param vocab_keys: a list of (ordered) vocabulary
+         :param vocab_separator: token separator; default is `""`
+         :param valid_checker: a callable that checks the validity of sequences; default is `None`
+         :param eta: the coefficient by which the score loss is scaled
+         :param name: the name of this scorer
+         :type scorers: list
+         :type score_criteria: list
+         :type vocab_keys: list
+         :type vocab_separator: str
+         :type valid_checker: typing.Callable | None
+         :type eta: float
+         :type name: str
+         """
+         assert len(scorers) == len(
+             score_criteria
+         ), "The number of scorers should match the number of criteria."
+         self.scorers = scorers
+         self.score_criteria = score_criteria
+         self.vocab_keys = vocab_keys
+         self.vocab_separator = vocab_separator
+         self.valid_checker = valid_checker
+         self.eta = eta
+         self.name = name
+
+     def calc_score_loss(self, p: Tensor) -> Tensor:
+         """
+         Calculate the score loss.
+
+         :param p: token probability distributions; shape: (n_b, n_t, n_vocab)
+         :type p: torch.Tensor
+         :return: score loss; shape: ()
+         :rtype: torch.Tensor
+         """
+         tokens = p.argmax(-1)
+         e_k = torch.nn.functional.one_hot(tokens, len(self.vocab_keys)).float()
+         # decode the argmax tokens into strings, dropping <start>/<end>/<pad>
+         seqs = [
+             self.vocab_separator.join([self.vocab_keys[i] for i in j])
+             .split("<start>" + self.vocab_separator)[-1]
+             .split(self.vocab_separator + "<end>")[0]
+             .replace("<pad>", "")
+             for j in tokens
+         ]
+         valid = [
+             1 if self.valid_checker is None else self.valid_checker(i) for i in seqs
+         ]
+         # an invalid sequence gets the full penalty (1.0); a valid one is
+         # penalised by how far it falls short of each criterion
+         scores = [
+             [
+                 1.0 if valid[j] == 0 else 1 - self.score_criteria[i](scorer(seq))
+                 for j, seq in enumerate(seqs)
+             ]
+             for i, scorer in enumerate(self.scorers)
+         ]
+         loss = (e_k * p).sum(2).mean(1) * torch.tensor(scores, device=p.device).mean(0)
+         return loss.mean()
+
+
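To make `calc_score_loss` concrete, a minimal editor's sketch (not package code) with a toy vocabulary standing in for the package's `VOCAB_KEYS`:

```python
import torch

# toy vocabulary; the real one comes from bayesianflow_for_chem.data
vocab = ["<pad>", "<start>", "<end>", "C", "O", "N", "(", ")", "=", "1"]

scorer = Scorer(
    scorers=[smiles_valid, qed_score],
    score_criteria=[lambda x: float(x == 1), lambda x: float(x > 0.5)],
    vocab_keys=vocab,
    valid_checker=smiles_valid,  # guards qed_score against invalid strings
)
p = torch.softmax(torch.randn(4, 16, len(vocab)), -1)  # (n_b, n_t, n_vocab)
loss = scorer.calc_score_loss(p)  # scalar; scale by scorer.eta before adding
print(loss.shape, loss.item())
```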
+ if __name__ == "__main__":
+     ...
@@ -0,0 +1,470 @@
+ # -*- coding: utf-8 -*-
+ # Author: Nianze A. TAO (Omozawa SUENO)
+ """
+ Tools.
+ """
+ import re
+ import csv
+ import random
+ from pathlib import Path
+ from typing import List, Dict, Tuple, Union, Optional
+ import torch
+ import numpy as np
+ from torch import cuda, Tensor, softmax
+ from torch.utils.data import DataLoader
+ from rdkit.Chem import rdDetermineBonds, Bond, MolFromXYZBlock, CanonicalRankAtoms
+ from rdkit.Chem.Scaffolds.MurckoScaffold import MurckoScaffoldSmiles  # type: ignore
+ from sklearn.metrics import (
+     roc_auc_score,
+     auc,
+     precision_recall_curve,
+     r2_score,
+     mean_absolute_error,
+     root_mean_squared_error,
+ )
+
+ try:
+     from pynauty import Graph, canon_label  # type: ignore
+
+     _use_pynauty = True
+ except ImportError:
+     import warnings
+
+     _use_pynauty = False
+
+ from .data import VOCAB_KEYS
+ from .model import ChemBFN, MLP
+
+
+ # element-symbol pattern used to validate atom tokens
+ _atom_regex_pattern = (
+     r"(H[efgso]?|"
+     r"L[ivaru]|"
+     r"B[eraihk]?|"
+     r"C[laroudsnemf]?|"
+     r"N[eaibhdop]?|"
+     r"O[sg]?|S[icernmbg]?|"
+     r"Kr?|T[icealbhms]|"
+     r"G[aed]|R[buhenafg]|"
+     r"Yb?|Z[nr]|P[todraubm]?|"
+     r"F[erlm]?|M[gnotcd]|"
+     r"A[lrsgutcm]|I[nr]?|"
+     r"W|Xe|E[urs]|U|D[bsy])"
+ )
+ _atom_regex = re.compile(_atom_regex_pattern)
+
+
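The compiled pattern is used by `seq2geo` below to validate atom tokens: a token passes only if it contains exactly one element-symbol match. A quick editor's check (not package code):

```python
for token in ("Cl", "Si", "Xe", "CO", "Qx"):
    print(token, _atom_regex.findall(token))
# Cl ['Cl']   Si ['Si']   Xe ['Xe']   CO ['C', 'O'] (two matches: rejected)   Qx []
```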
+ def _find_device() -> torch.device:
+     if cuda.is_available():
+         return torch.device("cuda")
+     elif torch.backends.mps.is_available():
+         return torch.device("mps")
+     return torch.device("cpu")
+
+
+ def _bond_pair_idx(bonds: Bond) -> List[List[int]]:
+     return [[i.GetBeginAtomIdx(), i.GetEndAtomIdx()] for i in bonds]
+
+
+ @torch.no_grad()
+ def test(
+     model: ChemBFN,
+     mlp: MLP,
+     data: DataLoader,
+     mode: str = "regression",
+     device: Union[str, torch.device, None] = None,
+ ) -> Dict[str, List[float]]:
+     """
+     Test the trained network.
+
+     :param model: pretrained ChemBFN model
+     :param mlp: trained MLP model for testing
+     :param data: DataLoader instance
+     :param mode: testing mode chosen from `'regression'` and `'classification'`
+     :param device: hardware accelerator
+     :type model: bayesianflow_for_chem.model.ChemBFN
+     :type mlp: bayesianflow_for_chem.model.MLP
+     :type data: torch.utils.data.DataLoader
+     :type mode: str
+     :type device: str | torch.device | None
+     :return: MAE, RMSE, and R^2 for regression; ROC-AUC and PRC-AUC for classification
+     :rtype: dict
+     """
+     assert mode in ("regression", "classification")
+     if device is None:
+         device = _find_device()
+     model.to(device).eval()
+     mlp.to(device).eval()
+     predict_y, label_y = [], []
+     for d in data:
+         x, y = d["token"].to(device), d["value"]
+         label_y.append(y)
+         if mode == "regression":
+             y_hat = model.inference(x, mlp)
+         if mode == "classification":
+             n_b, n_y = y.shape
+             y_hat = softmax(model.inference(x, mlp).reshape(n_b * n_y, -1), -1)
+             y_hat = y_hat.reshape(n_b, -1)
+         predict_y.append(y_hat.detach().to("cpu"))
+     predict_y, label_y = torch.cat(predict_y, 0), torch.cat(label_y, 0).split(1, -1)
+     if mode == "regression":
+         # labels equal to inf mark missing targets and are dropped per column
+         predict_y = [
+             predict[label_y[i] != torch.inf]
+             for (i, predict) in enumerate(predict_y.split(1, -1))
+         ]
+         label_y = [label[label != torch.inf] for label in label_y]
+         y_zipped = list(zip(label_y, predict_y))
+         mae = [mean_absolute_error(label, predict) for (label, predict) in y_zipped]
+         rmse = [
+             root_mean_squared_error(label, predict) for (label, predict) in y_zipped
+         ]
+         r2 = [r2_score(label, predict) for (label, predict) in y_zipped]
+         return {"MAE": mae, "RMSE": rmse, "R^2": r2}
+     if mode == "classification":
+         n_c = len(label_y)
+         predict_y = predict_y.chunk(n_c, -1)
+         y_zipped = list(zip(label_y, predict_y))
+         roc_auc = [
+             roc_auc_score(
+                 label.flatten(),
+                 predict[:, 1] if predict.shape[-1] == 2 else predict,
+                 multi_class="raise" if predict.shape[-1] == 2 else "ovo",
+                 labels=None if predict.shape[-1] == 2 else range(predict.shape[-1]),
+             )
+             for (label, predict) in y_zipped
+         ]
+         try:
+             prc = [
+                 precision_recall_curve(label.flatten(), predict[:, 1])[:2]
+                 for (label, predict) in y_zipped
+             ]
+             prc_auc = [auc(recall, precision) for (precision, recall) in prc]
+         except ValueError:
+             prc_auc = []
+         return {"ROC-AUC": roc_auc, "PRC-AUC": prc_auc}
+
+
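In the regression branch above, labels equal to `torch.inf` mark missing targets and are dropped column-by-column before the metrics are computed. A standalone editor's illustration of that masking (plain PyTorch, not package code):

```python
import torch

labels = torch.tensor([[0.5, torch.inf], [1.0, 2.0], [torch.inf, 3.0]])
preds = torch.tensor([[0.4, 9.9], [1.1, 2.1], [7.7, 2.9]])

label_cols = labels.split(1, -1)  # one (n, 1) tensor per task column
pred_cols = preds.split(1, -1)
for label, pred in zip(label_cols, pred_cols):
    mask = label != torch.inf
    print(label[mask].tolist(), pred[mask].tolist())
# -> [0.5, 1.0] [0.4, 1.1]   (column 0 drops row 2)
# -> [2.0, 3.0] [2.1, 2.9]   (column 1 drops row 0)
```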
+ def split_dataset(
+     file: Union[str, Path], split_ratio: List[int] = [8, 1, 1], method: str = "random"
+ ) -> None:
+     """
+     Split a dataset.
+
+     :param file: dataset file (CSV)
+     :param split_ratio: training-testing-validation ratio
+     :param method: chosen from `'random'` and `'scaffold'`
+     :type file: str | pathlib.Path
+     :type split_ratio: list
+     :type method: str
+     :return: None
+     :rtype: None
+     """
+     file = str(file)  # accept pathlib.Path as well as str
+     assert file.endswith(".csv")
+     assert len(split_ratio) == 3
+     assert method in ("random", "scaffold")
+     with open(file, "r") as f:
+         data = list(csv.reader(f))
+     header = data[0]
+     raw_data = data[1:]
+     smiles_idx = []  # only the first index will be used
+     for key, h in enumerate(header):
+         if h.lower() == "smiles":
+             smiles_idx.append(key)
+     assert len(smiles_idx) > 0
+     data_len = len(raw_data)
+     train_ratio = split_ratio[0] / sum(split_ratio)
+     test_ratio = sum(split_ratio[:2]) / sum(split_ratio)
+     train_idx, test_idx = int(data_len * train_ratio), int(data_len * test_ratio)
+     if method == "random":
+         random.shuffle(raw_data)
+         train_set = raw_data[:train_idx]
+         test_set = raw_data[train_idx:test_idx]
+         val_set = raw_data[test_idx:]
+     if method == "scaffold":
+         scaffolds: Dict[str, List] = {}
+         for key, d in enumerate(raw_data):
+             # compute the Bemis-Murcko scaffold
+             scaffold = MurckoScaffoldSmiles(d[smiles_idx[0]])
+             if scaffold in scaffolds:
+                 scaffolds[scaffold].append(key)
+             else:
+                 scaffolds[scaffold] = [key]
+         scaffolds = {key: sorted(value) for key, value in scaffolds.items()}
+         train_set, test_set, val_set = [], [], []
+         for idxs in scaffolds.values():
+             # molecules sharing a scaffold stay in the same subset
+             if len(train_set) + len(idxs) > train_idx:
+                 if len(train_set) + len(test_set) + len(idxs) > test_idx:
+                     val_set += [raw_data[i] for i in idxs]
+                 else:
+                     test_set += [raw_data[i] for i in idxs]
+             else:
+                 train_set += [raw_data[i] for i in idxs]
+     with open(file.replace(".csv", "_train.csv"), "w", newline="") as ftr:
+         writer = csv.writer(ftr)
+         writer.writerows([header] + train_set)
+     with open(file.replace(".csv", "_test.csv"), "w", newline="") as fte:
+         writer = csv.writer(fte)
+         writer.writerows([header] + test_set)
+     with open(file.replace(".csv", "_val.csv"), "w", newline="") as fva:
+         writer = csv.writer(fva)
+         writer.writerows([header] + val_set)
+
+
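A minimal editor's usage sketch for `split_dataset` (not package code), writing a toy CSV and splitting it 8:1:1 at random:

```python
import csv

rows = [["smiles", "value"]] + [["C" * (i % 5 + 1), str(i)] for i in range(20)]
with open("toy.csv", "w", newline="") as f:
    csv.writer(f).writerows(rows)

split_dataset("toy.csv", split_ratio=[8, 1, 1], method="random")
# -> toy_train.csv (16 rows), toy_test.csv (2 rows), toy_val.csv (2 rows)
```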
+ def geo2seq(
+     symbols: List[str],
+     coordinates: np.ndarray,
+     decimals: int = 2,
+     angle_unit: str = "degree",
+ ) -> str:
+     """
+     Geometry-to-sequence function.\n
+     The algorithm follows the descriptions in the paper https://arxiv.org/abs/2408.10120.
+
+     :param symbols: a list of atomic symbols
+     :param coordinates: Cartesian coordinates; shape: (n_a, 3)
+     :param decimals: number of decimal places to round to
+     :param angle_unit: `'degree'` or `'radian'`
+     :type symbols: list
+     :type coordinates: numpy.ndarray
+     :type decimals: int
+     :type angle_unit: str
+     :return: `Geo2Seq` string
+     :rtype: str
+     """
+     assert angle_unit in ("degree", "radian")
+     angle_scale = 180 / np.pi if angle_unit == "degree" else 1.0
+     n = len(symbols)
+     if n == 1:
+         return f"{symbols[0]} 0.0 0.0 0.0"
+     xyz_block = [str(n), ""]
+     for i, atom in enumerate(symbols):
+         xyz_block.append(
+             f"{atom} {coordinates[i][0].item():.10f} {coordinates[i][1].item():.10f} {coordinates[i][2].item():.10f}"
+         )
+     mol = MolFromXYZBlock("\n".join(xyz_block))
+     rdDetermineBonds.DetermineConnectivity(mol)
+     # ------- Canonicalization -------
+     if _use_pynauty:
+         pair_idx = np.array(_bond_pair_idx(mol.GetBonds())).T.tolist()
+         pair_dict: Dict[int, List[int]] = {}
+         for key, i in enumerate(pair_idx[0]):
+             if i not in pair_dict:
+                 pair_dict[i] = [pair_idx[1][key]]
+             else:
+                 pair_dict[i].append(pair_idx[1][key])
+         g = Graph(n, adjacency_dict=pair_dict)
+         cl = canon_label(g)  # type: list
+     else:
+         warnings.warn(
+             "\033[32;1m"
+             "`pynauty` is not installed."
+             " Switched to the canonicalization function provided by `rdkit`."
+             " This is the expected behaviour only if you are working on the Windows platform."
+             "\033[0m",
+             stacklevel=2,
+         )
+         cl = list(CanonicalRankAtoms(mol, breakTies=True))
+     symbols = np.array([[s] for s in symbols])[cl].flatten().tolist()
+     coordinates = coordinates[cl]
+     # ------- Find global coordinate frame -------
+     if n == 2:
+         d = np.round(np.linalg.norm(coordinates[0] - coordinates[1], 2), decimals)
+         return f"{symbols[0]} 0.0 0.0 0.0 {symbols[1]} {d} 0.0 0.0"
+     # find three consecutive non-collinear atoms to anchor the frame
+     for idx_0 in range(n - 2):
+         _vec0 = coordinates[idx_0] - coordinates[idx_0 + 1]
+         _vec1 = coordinates[idx_0] - coordinates[idx_0 + 2]
+         _d1 = np.linalg.norm(_vec0, 2)
+         _d2 = np.linalg.norm(_vec1, 2)
+         if 1 - np.abs(np.dot(_vec0, _vec1) / (_d1 * _d2)) > 1e-6:
+             break
+     x = (coordinates[idx_0 + 1] - coordinates[idx_0]) / _d1
+     y = np.cross((coordinates[idx_0 + 2] - coordinates[idx_0]), x)
+     y_d = np.linalg.norm(y, 2)
+     y = y / np.ma.filled(np.ma.array(y_d, mask=y_d == 0), np.inf)
+     z = np.cross(x, y)
+     # ------- Build spherical coordinates -------
+     vec = coordinates - coordinates[idx_0]
+     d = np.linalg.norm(vec, 2, axis=-1)
+     _d = np.ma.filled(np.ma.array(d, mask=d == 0), np.inf)
+     theta = angle_scale * np.arccos(np.dot(vec, z) / _d)  # in [0, \pi]
+     phi = angle_scale * np.arctan2(np.dot(vec, y), np.dot(vec, x))  # in [-\pi, \pi]
+     info = np.vstack([d, theta, phi]).T
+     info[idx_0] = np.zeros(3)
+     info = [
+         f"{symbols[i]} {r[0]} {r[1]} {r[2]}"
+         for i, r in enumerate(np.round(info, decimals))
+     ]
+     return " ".join(info)
+
+
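An editor's sketch of calling `geo2seq` on a water geometry (not package code; it assumes the listing above is in scope and RDKit is installed, with pynauty optional):

```python
import numpy as np

symbols = ["O", "H", "H"]
coords = np.array([
    [0.0, 0.0, 0.0],
    [0.9572, 0.0, 0.0],
    [-0.2400, 0.9266, 0.0],
])
print(geo2seq(symbols, coords))
# one "symbol d theta phi" block per atom, e.g. starting "O 0.0 0.0 0.0 H 0.96 ..."
```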
+ def seq2geo(
+     seq: str, angle_unit: str = "degree"
+ ) -> Optional[Tuple[List[str], List[List[float]]]]:
+     """
+     Sequence-to-geometry function.\n
+     The method follows the descriptions in the paper https://arxiv.org/abs/2408.10120.
+
+     :param seq: `Geo2Seq` string
+     :param angle_unit: `'degree'` or `'radian'`
+     :type seq: str
+     :type angle_unit: str
+     :return: (symbols, coordinates) if `seq` is valid
+     :rtype: tuple | None
+     """
+     assert angle_unit in ("degree", "radian")
+     angle_scale = np.pi / 180 if angle_unit == "degree" else 1.0
+     tokens = seq.split()
+     if len(tokens) % 4 == 0:
+         tokens = np.array(tokens).reshape(-1, 4).tolist()
+         symbols, coordinates = [], []
+         for i in tokens:
+             symbol = i[0]
+             if len(_atom_regex.findall(symbol)) != 1:
+                 return None
+             symbols.append(symbol)
+             try:
+                 # spherical (d, theta, phi) back to Cartesian (x, y, z)
+                 d, theta, phi = float(i[1]), float(i[2]), float(i[3])
+                 x = d * np.sin(theta * angle_scale) * np.cos(phi * angle_scale)
+                 y = d * np.sin(theta * angle_scale) * np.sin(phi * angle_scale)
+                 z = d * np.cos(theta * angle_scale)
+                 coordinates.append([x.item(), y.item(), z.item()])
+             except ValueError:
+                 return None
+         return symbols, coordinates
+     return None
+
+
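A small editor's check (not package code) that `seq2geo` parses a hand-written `Geo2Seq` string and rejects malformed ones:

```python
seq = "O 0.0 0.0 0.0 H 0.96 90.0 0.0 H 0.96 90.0 104.5"
out = seq2geo(seq)
assert out is not None
symbols, coords = out
print(symbols)    # ['O', 'H', 'H']
print(coords[1])  # ~[0.96, 0.0, 0.0] (theta = 90 deg, phi = 0 deg)

print(seq2geo("O 0.0 0.0"))       # None: token count not a multiple of 4
print(seq2geo("Qx 1.0 0.0 0.0"))  # None: "Qx" is not an element symbol
```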
+ @torch.no_grad()
+ def sample(
+     model: ChemBFN,
+     batch_size: int,
+     sequence_size: int,
+     sample_step: int = 100,
+     y: Optional[Tensor] = None,
+     guidance_strength: float = 4.0,
+     device: Union[str, torch.device, None] = None,
+     vocab_keys: List[str] = VOCAB_KEYS,
+     separator: str = "",
+     method: str = "BFN",
+     allowed_tokens: Union[str, List[str]] = "all",
+ ) -> List[str]:
+     """
+     Sampling.
+
+     :param model: trained ChemBFN model
+     :param batch_size: batch size
+     :param sequence_size: max sequence length
+     :param sample_step: number of sampling steps
+     :param y: conditioning vector; shape: (n_b, 1, n_f)
+     :param guidance_strength: strength of conditional generation; not used when `y` is `None`
+     :param device: hardware accelerator
+     :param vocab_keys: a list of (ordered) vocabulary
+     :param separator: token separator; default is `""`
+     :param method: sampling method chosen from `"ODE:x"` and `"BFN"`, where `x` is the sampling temperature; default is `"BFN"`
+     :param allowed_tokens: a list of allowed tokens; default is `"all"`
+     :type model: bayesianflow_for_chem.model.ChemBFN
+     :type batch_size: int
+     :type sequence_size: int
+     :type sample_step: int
+     :type y: torch.Tensor | None
+     :type guidance_strength: float
+     :type device: str | torch.device | None
+     :type vocab_keys: list
+     :type separator: str
+     :type method: str
+     :type allowed_tokens: str | list
+     :return: a list of generated molecular strings
+     :rtype: list
+     """
+     assert method.split(":")[0].lower() in ("ode", "bfn")
+     if device is None:
+         device = _find_device()
+     model.to(device).eval()
+     if y is not None:
+         y = y.to(device)
+     if isinstance(allowed_tokens, list):
+         # mask out every token that is not explicitly allowed
+         token_mask = [0 if i in allowed_tokens else 1 for i in vocab_keys]
+         token_mask = torch.tensor([[token_mask]], dtype=torch.bool).to(device)
+     else:
+         token_mask = None
+     if "ode" in method.lower():
+         tp = float(method.split(":")[-1])
+         assert tp > 0, "Sampling temperature should be higher than 0."
+         tokens = model.ode_sample(
+             batch_size, sequence_size, y, sample_step, guidance_strength, token_mask, tp
+         )
+     else:
+         tokens = model.sample(
+             batch_size, sequence_size, y, sample_step, guidance_strength, token_mask
+         )
+     return [
+         separator.join([vocab_keys[i] for i in j])
+         .split("<start>" + separator)[-1]
+         .split(separator + "<end>")[0]
+         .replace("<pad>", "")
+         for j in tokens
+     ]
+
+
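A hedged editor's sketch of generating with `sample` (not package code; it assumes `model` is a trained `ChemBFN` instance obtained elsewhere, and that the scorer helpers from the first file are also in scope):

```python
smiles_list = sample(
    model,
    batch_size=8,
    sequence_size=64,
    sample_step=100,
    method="ODE:0.5",  # ODE solver at temperature 0.5; "BFN" also works
)
valid = [s for s in smiles_list if smiles_valid(s)]
print(len(valid), "/", len(smiles_list), "valid")
```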
+ @torch.no_grad()
+ def inpaint(
+     model: ChemBFN,
+     x: Tensor,
+     sample_step: int = 100,
+     y: Optional[Tensor] = None,
+     guidance_strength: float = 4.0,
+     device: Union[str, torch.device, None] = None,
+     vocab_keys: List[str] = VOCAB_KEYS,
+     separator: str = "",
+     method: str = "BFN",
+     allowed_tokens: Union[str, List[str]] = "all",
+ ) -> List[str]:
+     """
+     Inpaint (context-guided) sampling.
+
+     :param model: trained ChemBFN model
+     :param x: categorical indices of the scaffold; shape: (n_b, n_t)
+     :param sample_step: number of sampling steps
+     :param y: conditioning vector; shape: (n_b, 1, n_f)
+     :param guidance_strength: strength of conditional generation; not used when `y` is `None`
+     :param device: hardware accelerator
+     :param vocab_keys: a list of (ordered) vocabulary
+     :param separator: token separator; default is `""`
+     :param method: sampling method chosen from `"ODE:x"` and `"BFN"`, where `x` is the sampling temperature; default is `"BFN"`
+     :param allowed_tokens: a list of allowed tokens; default is `"all"`
+     :type model: bayesianflow_for_chem.model.ChemBFN
+     :type x: torch.Tensor
+     :type sample_step: int
+     :type y: torch.Tensor | None
+     :type guidance_strength: float
+     :type device: str | torch.device | None
+     :type vocab_keys: list
+     :type separator: str
+     :type method: str
+     :type allowed_tokens: str | list
+     :return: a list of generated molecular strings
+     :rtype: list
+     """
+     assert method.split(":")[0].lower() in ("ode", "bfn")
+     if device is None:
+         device = _find_device()
+     model.to(device).eval()
+     x = x.to(device)
+     if y is not None:
+         y = y.to(device)
+     if isinstance(allowed_tokens, list):
+         token_mask = [0 if i in allowed_tokens else 1 for i in vocab_keys]
+         token_mask = torch.tensor([[token_mask]], dtype=torch.bool).to(device)
+     else:
+         token_mask = None
+     if "ode" in method.lower():
+         tp = float(method.split(":")[-1])
+         assert tp > 0, "Sampling temperature should be higher than 0."
+         tokens = model.ode_inpaint(x, y, sample_step, guidance_strength, token_mask, tp)
+     else:
+         tokens = model.inpaint(x, y, sample_step, guidance_strength, token_mask)
+     return [
+         separator.join([vocab_keys[i] for i in j])
+         .split("<start>" + separator)[-1]
+         .split(separator + "<end>")[0]
+         .replace("<pad>", "")
+         for j in tokens
+     ]
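
A matching editor's sketch for `inpaint` (not package code; `model` is again a trained `ChemBFN` instance). How positions to be generated are marked inside `x` is defined by the package's tokenizer, which is not part of this diff, so random indices stand in here purely to show the call shape:

```python
import torch

x = torch.randint(0, len(VOCAB_KEYS), (4, 32))  # placeholder scaffold indices, (n_b, n_t)
molecules = inpaint(model, x, sample_step=100, method="BFN")
print(molecules)
```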