bayesianflow-for-chem 1.3.0__py3-none-any.whl → 1.4.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of bayesianflow-for-chem might be problematic. Click here for more details.
- bayesianflow_for_chem/__init__.py +1 -1
- bayesianflow_for_chem/data.py +1 -38
- bayesianflow_for_chem/model.py +7 -22
- bayesianflow_for_chem/tool.py +97 -166
- {bayesianflow_for_chem-1.3.0.dist-info → bayesianflow_for_chem-1.4.1.dist-info}/METADATA +2 -5
- bayesianflow_for_chem-1.4.1.dist-info/RECORD +12 -0
- bayesianflow_for_chem-1.3.0.dist-info/RECORD +0 -12
- {bayesianflow_for_chem-1.3.0.dist-info → bayesianflow_for_chem-1.4.1.dist-info}/WHEEL +0 -0
- {bayesianflow_for_chem-1.3.0.dist-info → bayesianflow_for_chem-1.4.1.dist-info}/licenses/LICENSE +0 -0
- {bayesianflow_for_chem-1.3.0.dist-info → bayesianflow_for_chem-1.4.1.dist-info}/top_level.txt +0 -0
|
@@ -7,5 +7,5 @@ from . import data, tool, train, scorer
|
|
|
7
7
|
from .model import ChemBFN, MLP, EnsembleChemBFN
|
|
8
8
|
|
|
9
9
|
__all__ = ["data", "tool", "train", "scorer", "ChemBFN", "MLP", "EnsembleChemBFN"]
|
|
10
|
-
__version__ = "1.
|
|
10
|
+
__version__ = "1.4.1"
|
|
11
11
|
__author__ = "Nianze A. Tao (Omozawa Sueno)"
|
bayesianflow_for_chem/data.py
CHANGED
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
# -*- coding: utf-8 -*-
|
|
2
2
|
# Author: Nianze A. TAO (Omozawa SUENO)
|
|
3
3
|
"""
|
|
4
|
-
Tokenise SMILES/SAFE/SELFIES/
|
|
4
|
+
Tokenise SMILES/SAFE/SELFIES/protein-sequence strings.
|
|
5
5
|
"""
|
|
6
6
|
import os
|
|
7
7
|
import re
|
|
@@ -32,25 +32,9 @@ SMI_REGEX_PATTERN = (
|
|
|
32
32
|
r"~|@|\?|>>?|\*|\$|\%[0-9]{2}|[0-9])"
|
|
33
33
|
)
|
|
34
34
|
SEL_REGEX_PATTERN = r"(\[[^\]]+]|\.)"
|
|
35
|
-
GEO_REGEX_PATTERN = (
|
|
36
|
-
r"(H[e,f,g,s,o]?|"
|
|
37
|
-
r"L[i,v,a,r,u]|"
|
|
38
|
-
r"B[e,r,a,i,h,k]?|"
|
|
39
|
-
r"C[l,a,r,o,u,d,s,n,e,m,f]?|"
|
|
40
|
-
r"N[e,a,i,b,h,d,o,p]?|"
|
|
41
|
-
r"O[s,g]?|S[i,c,e,r,n,m,b,g]?|"
|
|
42
|
-
r"K[r]?|T[i,c,e,a,l,b,h,m,s]|"
|
|
43
|
-
r"G[a,e,d]|R[b,u,h,e,n,a,f,g]|"
|
|
44
|
-
r"Yb?|Z[n,r]|P[t,o,d,r,a,u,b,m]?|"
|
|
45
|
-
r"F[e,r,l,m]?|M[g,n,o,t,c,d]|"
|
|
46
|
-
r"A[l,r,s,g,u,t,c,m]|I[n,r]?|"
|
|
47
|
-
r"W|X[e]|E[u,r,s]|U|D[b,s,y]|"
|
|
48
|
-
r"-|.| |[0-9])"
|
|
49
|
-
)
|
|
50
35
|
AA_REGEX_PATTERN = r"(A|B|C|D|E|F|G|H|I|K|L|M|N|P|Q|R|S|T|V|W|Y|Z|-|.)"
|
|
51
36
|
smi_regex = re.compile(SMI_REGEX_PATTERN)
|
|
52
37
|
sel_regex = re.compile(SEL_REGEX_PATTERN)
|
|
53
|
-
geo_regex = re.compile(GEO_REGEX_PATTERN)
|
|
54
38
|
aa_regex = re.compile(AA_REGEX_PATTERN)
|
|
55
39
|
|
|
56
40
|
|
|
@@ -86,9 +70,6 @@ AA_VOCAB_KEYS = (
|
|
|
86
70
|
)
|
|
87
71
|
AA_VOCAB_COUNT = len(AA_VOCAB_KEYS)
|
|
88
72
|
AA_VOCAB_DICT = dict(zip(AA_VOCAB_KEYS, range(AA_VOCAB_COUNT)))
|
|
89
|
-
GEO_VOCAB_KEYS = VOCAB_KEYS[0:3] + [" "] + VOCAB_KEYS[22:150] + [".", "-"]
|
|
90
|
-
GEO_VOCAB_COUNT = len(GEO_VOCAB_KEYS)
|
|
91
|
-
GEO_VOCAB_DICT = dict(zip(GEO_VOCAB_KEYS, range(GEO_VOCAB_COUNT)))
|
|
92
73
|
|
|
93
74
|
|
|
94
75
|
def smiles2vec(smiles: str) -> List[int]:
|
|
@@ -104,19 +85,6 @@ def smiles2vec(smiles: str) -> List[int]:
|
|
|
104
85
|
return [VOCAB_DICT[token] for token in tokens]
|
|
105
86
|
|
|
106
87
|
|
|
107
|
-
def geo2vec(geo2seq: str) -> List[int]:
|
|
108
|
-
"""
|
|
109
|
-
Geo2Seq tokenisation using a dataset-independent regex pattern.
|
|
110
|
-
|
|
111
|
-
:param geo2seq: `GEO2SEQ` string
|
|
112
|
-
:type geo2seq: str
|
|
113
|
-
:return: tokens w/o `<start>` and `<end>`
|
|
114
|
-
:rtype: list
|
|
115
|
-
"""
|
|
116
|
-
tokens = [token for token in geo_regex.findall(geo2seq)]
|
|
117
|
-
return [GEO_VOCAB_DICT[token] for token in tokens]
|
|
118
|
-
|
|
119
|
-
|
|
120
88
|
def aa2vec(aa_seq: str) -> List[int]:
|
|
121
89
|
"""
|
|
122
90
|
Protein sequence tokenisation using a dataset-independent regex pattern.
|
|
@@ -147,11 +115,6 @@ def smiles2token(smiles: str) -> Tensor:
|
|
|
147
115
|
return torch.tensor([1] + smiles2vec(smiles) + [2], dtype=torch.long)
|
|
148
116
|
|
|
149
117
|
|
|
150
|
-
def geo2token(geo2seq: str) -> Tensor:
|
|
151
|
-
# start token: <start> = 1; end token: <esc> = 2
|
|
152
|
-
return torch.tensor([1] + geo2vec(geo2seq) + [2], dtype=torch.long)
|
|
153
|
-
|
|
154
|
-
|
|
155
118
|
def aa2token(aa_seq: str) -> Tensor:
|
|
156
119
|
# start token: <start> = 1; end token: <end> = 2
|
|
157
120
|
return torch.tensor([1] + aa2vec(aa_seq) + [2], dtype=torch.long)
|
bayesianflow_for_chem/model.py
CHANGED
|
@@ -162,23 +162,13 @@ class Attention(nn.Module):
|
|
|
162
162
|
:return: attentioned output; shape: (n_b, n_t, n_f)
|
|
163
163
|
:rtype: torch.Tensor
|
|
164
164
|
"""
|
|
165
|
-
n_b,
|
|
166
|
-
split = (n_b,
|
|
165
|
+
n_b, n_t, _ = shape = x.shape
|
|
166
|
+
split = (n_b, n_t, self.nh, self.d)
|
|
167
167
|
q, k, v = self.qkv(x).chunk(3, -1)
|
|
168
168
|
q = q.view(split).permute(2, 0, 1, 3).contiguous()
|
|
169
169
|
k = k.view(split).permute(2, 0, 1, 3).contiguous()
|
|
170
170
|
v = v.view(split).permute(2, 0, 1, 3).contiguous()
|
|
171
171
|
q, k = self._rotate(q, k, pe) # position embedding
|
|
172
|
-
"""
|
|
173
|
-
# Original code. Maybe using `nn.functional.scaled_dot_product_attention(...)` is better.
|
|
174
|
-
|
|
175
|
-
k_t = k.transpose(-2, -1)
|
|
176
|
-
if mask is not None:
|
|
177
|
-
alpha = softmax((q @ k_t / self.tp).masked_fill_(mask, -torch.inf), -1)
|
|
178
|
-
else:
|
|
179
|
-
alpha = softmax(q @ k_t / self.tp, -1)
|
|
180
|
-
atten_out = (alpha @ v).permute(1, 2, 0, 3).contiguous().view(shape)
|
|
181
|
-
"""
|
|
182
172
|
atten_out = nn.functional.scaled_dot_product_attention(
|
|
183
173
|
q, k, v, mask, 0.0, False, scale=1 / self.tp
|
|
184
174
|
)
|
|
@@ -428,21 +418,16 @@ class ChemBFN(nn.Module):
|
|
|
428
418
|
c = self.time_embed(t)
|
|
429
419
|
if y is not None:
|
|
430
420
|
c += y
|
|
431
|
-
pe = self.position(
|
|
421
|
+
pe = self.position(n_t)
|
|
432
422
|
x = self.embedding(x)
|
|
433
|
-
attn_mask: Optional[Tensor] = None
|
|
434
423
|
if self.semi_autoregressive:
|
|
435
424
|
attn_mask = torch.tril(
|
|
436
|
-
torch.ones((1, n_b, n_t, n_t), device=
|
|
425
|
+
torch.ones((1, n_b, n_t, n_t), device=x.device), diagonal=0
|
|
437
426
|
)
|
|
427
|
+
elif mask is not None:
|
|
428
|
+
attn_mask = mask.transpose(-2, -1).repeat(1, n_t, 1)[None, ...] != 0
|
|
438
429
|
else:
|
|
439
|
-
|
|
440
|
-
"""
|
|
441
|
-
# Original Code.
|
|
442
|
-
|
|
443
|
-
attn_mask = mask.transpose(-2, -1).repeat(1, x.shape[1], 1)[None, ...] == 0
|
|
444
|
-
"""
|
|
445
|
-
attn_mask = mask.transpose(-2, -1).repeat(1, n_t, 1)[None, ...] != 0
|
|
430
|
+
attn_mask = None
|
|
446
431
|
for layer in self.encoder_layers:
|
|
447
432
|
x = layer(x, pe, c, attn_mask)
|
|
448
433
|
return self.final_layer(x, c, mask is None)
|
bayesianflow_for_chem/tool.py
CHANGED
|
@@ -3,7 +3,6 @@
|
|
|
3
3
|
"""
|
|
4
4
|
Essential tools.
|
|
5
5
|
"""
|
|
6
|
-
import re
|
|
7
6
|
import csv
|
|
8
7
|
import random
|
|
9
8
|
import warnings
|
|
@@ -17,8 +16,17 @@ from torch import cuda, Tensor, softmax
|
|
|
17
16
|
from torch.ao import quantization
|
|
18
17
|
from torch.utils.data import DataLoader
|
|
19
18
|
from typing_extensions import Self
|
|
20
|
-
from rdkit.Chem
|
|
21
|
-
|
|
19
|
+
from rdkit.Chem import (
|
|
20
|
+
rdDetermineBonds,
|
|
21
|
+
GetFormalCharge,
|
|
22
|
+
MolFromXYZBlock,
|
|
23
|
+
MolFromSmiles,
|
|
24
|
+
MolToSmiles,
|
|
25
|
+
CanonSmiles,
|
|
26
|
+
AllChem,
|
|
27
|
+
AddHs,
|
|
28
|
+
Mol,
|
|
29
|
+
)
|
|
22
30
|
from rdkit.Chem.Scaffolds.MurckoScaffold import MurckoScaffoldSmiles # type: ignore
|
|
23
31
|
from sklearn.metrics import (
|
|
24
32
|
roc_auc_score,
|
|
@@ -28,36 +36,10 @@ from sklearn.metrics import (
|
|
|
28
36
|
mean_absolute_error,
|
|
29
37
|
root_mean_squared_error,
|
|
30
38
|
)
|
|
31
|
-
|
|
32
|
-
try:
|
|
33
|
-
from pynauty import Graph, canon_label # type: ignore
|
|
34
|
-
|
|
35
|
-
_use_pynauty = True
|
|
36
|
-
except ImportError:
|
|
37
|
-
import platform
|
|
38
|
-
|
|
39
|
-
_use_pynauty = False
|
|
40
39
|
from .data import VOCAB_KEYS
|
|
41
40
|
from .model import ChemBFN, MLP, Linear, EnsembleChemBFN
|
|
42
41
|
|
|
43
42
|
|
|
44
|
-
_atom_regex_pattern = (
|
|
45
|
-
r"(H[e,f,g,s,o]?|"
|
|
46
|
-
r"L[i,v,a,r,u]|"
|
|
47
|
-
r"B[e,r,a,i,h,k]?|"
|
|
48
|
-
r"C[l,a,r,o,u,d,s,n,e,m,f]?|"
|
|
49
|
-
r"N[e,a,i,b,h,d,o,p]?|"
|
|
50
|
-
r"O[s,g]?|S[i,c,e,r,n,m,b,g]?|"
|
|
51
|
-
r"K[r]?|T[i,c,e,a,l,b,h,m,s]|"
|
|
52
|
-
r"G[a,e,d]|R[b,u,h,e,n,a,f,g]|"
|
|
53
|
-
r"Yb?|Z[n,r]|P[t,o,d,r,a,u,b,m]?|"
|
|
54
|
-
r"F[e,r,l,m]?|M[g,n,o,t,c,d]|"
|
|
55
|
-
r"A[l,r,s,g,u,t,c,m]|I[n,r]?|"
|
|
56
|
-
r"W|X[e]|E[u,r,s]|U|D[b,s,y])"
|
|
57
|
-
)
|
|
58
|
-
_atom_regex = re.compile(_atom_regex_pattern)
|
|
59
|
-
|
|
60
|
-
|
|
61
43
|
def _find_device() -> torch.device:
|
|
62
44
|
if cuda.is_available():
|
|
63
45
|
return torch.device("cuda")
|
|
@@ -66,10 +48,6 @@ def _find_device() -> torch.device:
|
|
|
66
48
|
return torch.device("cpu")
|
|
67
49
|
|
|
68
50
|
|
|
69
|
-
def _bond_pair_idx(bonds: Bond) -> List[List[int]]:
|
|
70
|
-
return [[i.GetBeginAtomIdx(), i.GetEndAtomIdx()] for i in bonds]
|
|
71
|
-
|
|
72
|
-
|
|
73
51
|
@torch.no_grad()
|
|
74
52
|
def test(
|
|
75
53
|
model: ChemBFN,
|
|
@@ -493,6 +471,8 @@ def quantise_model(model: ChemBFN) -> nn.Module:
|
|
|
493
471
|
assert hasattr(
|
|
494
472
|
mod, "qconfig"
|
|
495
473
|
), "Input float module must have qconfig defined"
|
|
474
|
+
if use_precomputed_fake_quant:
|
|
475
|
+
warnings.warn("Fake quantize operator is not implemented.")
|
|
496
476
|
if mod.qconfig is not None and mod.qconfig.weight is not None:
|
|
497
477
|
weight_observer = mod.qconfig.weight()
|
|
498
478
|
else:
|
|
@@ -560,6 +540,90 @@ class GeometryConverter:
|
|
|
560
540
|
xyz_block.append(f"{atom} {r[i][0]:.10f} {r[i][1]:.10f} {r[i][2]:.10f}")
|
|
561
541
|
return MolFromXYZBlock("\n".join(xyz_block))
|
|
562
542
|
|
|
543
|
+
@staticmethod
|
|
544
|
+
def smiles2cartesian(
|
|
545
|
+
smiles: str,
|
|
546
|
+
num_conformers: int = 50,
|
|
547
|
+
rdkit_ff_type: str = "MMFF",
|
|
548
|
+
refine_with_crest: bool = False,
|
|
549
|
+
spin: float = 0.0,
|
|
550
|
+
) -> Tuple[List[str], np.ndarray]:
|
|
551
|
+
"""
|
|
552
|
+
Guess the 3D geometry from SMILES string via MMFF conformer search.
|
|
553
|
+
|
|
554
|
+
:param smiles: a valid SMILES string
|
|
555
|
+
:param num_conformers: number of initial conformers
|
|
556
|
+
:param rdkit_ff_type: force field type chosen in `'MMFF'` and `'UFF'`
|
|
557
|
+
:param refine_with_crest: find the best conformer via CREST
|
|
558
|
+
:param spin: total spin; only required when `refine_with_crest=True`
|
|
559
|
+
:type smiles: str
|
|
560
|
+
:type num_conformers: int
|
|
561
|
+
:type rdkit_ff_type: str
|
|
562
|
+
:type refine_with_crest: bool
|
|
563
|
+
:type spin: float
|
|
564
|
+
:return: atomic symbols \n
|
|
565
|
+
cartesian coordinates; shape: (n_a, 3)
|
|
566
|
+
:rtype: tuple
|
|
567
|
+
"""
|
|
568
|
+
assert rdkit_ff_type.lower() in ("mmff", "uff")
|
|
569
|
+
if refine_with_crest:
|
|
570
|
+
from tempfile import TemporaryDirectory
|
|
571
|
+
from subprocess import run
|
|
572
|
+
|
|
573
|
+
# We need both CREST and xTB installed.
|
|
574
|
+
if run("crest --version", shell=True).returncode != 0:
|
|
575
|
+
raise RuntimeError(
|
|
576
|
+
"`CREST` is not found! Make sure it is installed and added into the PATH."
|
|
577
|
+
)
|
|
578
|
+
if run("xtb --version", shell=True).returncode != 0:
|
|
579
|
+
raise RuntimeError(
|
|
580
|
+
"`xTB` is not found! Make sure it is installed and added into the PATH."
|
|
581
|
+
)
|
|
582
|
+
mol = MolFromSmiles(smiles)
|
|
583
|
+
mol = AddHs(mol)
|
|
584
|
+
AllChem.EmbedMultipleConfs(mol, numConfs=num_conformers, params=AllChem.ETKDG())
|
|
585
|
+
symbols = [atom.GetSymbol() for atom in mol.GetAtoms()]
|
|
586
|
+
energies = []
|
|
587
|
+
for conf_id in range(num_conformers):
|
|
588
|
+
if rdkit_ff_type.lower() == "mmff":
|
|
589
|
+
ff = AllChem.MMFFGetMoleculeForceField(
|
|
590
|
+
mol, AllChem.MMFFGetMoleculeProperties(mol), confId=conf_id
|
|
591
|
+
)
|
|
592
|
+
else: # UFF
|
|
593
|
+
ff = AllChem.UFFGetMoleculeForceField(mol, confId=conf_id)
|
|
594
|
+
energy = ff.CalcEnergy()
|
|
595
|
+
energies.append((conf_id, energy))
|
|
596
|
+
lowest_energy_conf = min(energies, key=lambda x: x[1])
|
|
597
|
+
coordinates = mol.GetConformer(id=lowest_energy_conf[0]).GetPositions()
|
|
598
|
+
if refine_with_crest:
|
|
599
|
+
xyz = f"{len(symbols)}\n\n" + "\n".join(
|
|
600
|
+
f"{s} {coordinates[i][0]:.10f} {coordinates[i][1]:.10f} {coordinates[i][2]:.10f}"
|
|
601
|
+
for i, s in enumerate(symbols)
|
|
602
|
+
)
|
|
603
|
+
chrg = GetFormalCharge(mol)
|
|
604
|
+
uhf = int(spin * 2)
|
|
605
|
+
with TemporaryDirectory(dir=Path.cwd()) as temp_dir:
|
|
606
|
+
with open(Path(temp_dir) / "mol.xyz", "w", encoding="utf-8") as f:
|
|
607
|
+
f.write(xyz)
|
|
608
|
+
s = run(
|
|
609
|
+
f"crest mol.xyz -gfn2 -quick -prop ohess{f' --chrg {chrg}' if chrg != 0 else ''}{f' --uhf {uhf}' if uhf != 0 else ''}",
|
|
610
|
+
shell=True,
|
|
611
|
+
cwd=temp_dir,
|
|
612
|
+
)
|
|
613
|
+
if s.returncode == 0:
|
|
614
|
+
with open(Path(temp_dir) / "crest_property.xyz", "r") as f:
|
|
615
|
+
xyz = f.readlines()
|
|
616
|
+
xyz_data = []
|
|
617
|
+
for i in xyz[2:]:
|
|
618
|
+
if i == xyz[0]:
|
|
619
|
+
break
|
|
620
|
+
xyz_data.append(i.strip().split())
|
|
621
|
+
xyz_data = np.array(xyz_data)
|
|
622
|
+
symbols, coordinates = np.split(xyz_data, [1], axis=-1)
|
|
623
|
+
symbols = symbols.flatten().tolist()
|
|
624
|
+
coordinates = coordinates.astype(np.float64)
|
|
625
|
+
return symbols, coordinates
|
|
626
|
+
|
|
563
627
|
def cartesian2smiles(
|
|
564
628
|
self,
|
|
565
629
|
symbols: List[str],
|
|
@@ -587,136 +651,3 @@ class GeometryConverter:
|
|
|
587
651
|
if canonical:
|
|
588
652
|
smiles = CanonSmiles(smiles)
|
|
589
653
|
return smiles
|
|
590
|
-
|
|
591
|
-
def canonicalise(
|
|
592
|
-
self, symbols: List[str], coordinates: np.ndarray
|
|
593
|
-
) -> Tuple[List[str], np.ndarray]:
|
|
594
|
-
"""
|
|
595
|
-
Canonicalising the 3D molecular graph.
|
|
596
|
-
|
|
597
|
-
:param symbols: a list of atomic symbols
|
|
598
|
-
:param coordinates: Cartesian coordinates; shape: (n_a, 3)
|
|
599
|
-
:type symbols: list
|
|
600
|
-
:type coordinates: numpy.ndarray
|
|
601
|
-
:return: canonicalised symbols \n
|
|
602
|
-
canonicalised coordinates; shape: (n_a, 3)
|
|
603
|
-
:rtype: tuple
|
|
604
|
-
"""
|
|
605
|
-
if not _use_pynauty:
|
|
606
|
-
if platform.system() == "Windows":
|
|
607
|
-
raise NotImplementedError(
|
|
608
|
-
"This method is not implemented on Windows platform."
|
|
609
|
-
)
|
|
610
|
-
else:
|
|
611
|
-
raise ImportError("`pynauty` is not installed.")
|
|
612
|
-
n = len(symbols)
|
|
613
|
-
if n == 1:
|
|
614
|
-
return symbols, coordinates
|
|
615
|
-
mol = self._xyz2mol(symbols, coordinates)
|
|
616
|
-
rdDetermineBonds.DetermineConnectivity(mol)
|
|
617
|
-
# ------- Canonicalization -------
|
|
618
|
-
pair_idx = np.array(_bond_pair_idx(mol.GetBonds())).T.tolist()
|
|
619
|
-
pair_dict: Dict[int, List[int]] = {}
|
|
620
|
-
for key, i in enumerate(pair_idx[0]):
|
|
621
|
-
if i not in pair_dict:
|
|
622
|
-
pair_dict[i] = [pair_idx[1][key]]
|
|
623
|
-
else:
|
|
624
|
-
pair_dict[i].append(pair_idx[1][key])
|
|
625
|
-
g = Graph(n, adjacency_dict=pair_dict)
|
|
626
|
-
cl = canon_label(g) # type: list
|
|
627
|
-
symbols = np.array([[s] for s in symbols])[cl].flatten().tolist()
|
|
628
|
-
coordinates = coordinates[cl]
|
|
629
|
-
return symbols, coordinates
|
|
630
|
-
|
|
631
|
-
@staticmethod
|
|
632
|
-
def cartesian2spherical(coordinates: np.ndarray) -> np.ndarray:
|
|
633
|
-
"""
|
|
634
|
-
Transforming Cartesian coordinate to spherical form.\n
|
|
635
|
-
The method is adapted from the paper: https://arxiv.org/abs/2408.10120.
|
|
636
|
-
|
|
637
|
-
:param coordinates: Cartesian coordinates; shape: (n_a, 3)
|
|
638
|
-
:type coordinates: numpy.ndarray
|
|
639
|
-
:return: spherical coordinates; shape: (n_a, 3)
|
|
640
|
-
:rtype: numpy.ndarray
|
|
641
|
-
"""
|
|
642
|
-
n = coordinates.shape[0]
|
|
643
|
-
if n == 1:
|
|
644
|
-
return np.array([[0.0, 0.0, 0.0]])
|
|
645
|
-
# ------- Find global coordinate frame -------
|
|
646
|
-
if n == 2:
|
|
647
|
-
d = np.linalg.norm(coordinates[0] - coordinates[1], 2)
|
|
648
|
-
return np.array([[0.0, 0.0, 0.0], [d, 0.0, 0.0]])
|
|
649
|
-
for idx_0 in range(n - 2):
|
|
650
|
-
_vec0 = coordinates[idx_0] - coordinates[idx_0 + 1]
|
|
651
|
-
_vec1 = coordinates[idx_0] - coordinates[idx_0 + 2]
|
|
652
|
-
_d1 = np.linalg.norm(_vec0, 2)
|
|
653
|
-
_d2 = np.linalg.norm(_vec1, 2)
|
|
654
|
-
if 1 - np.abs(np.dot(_vec0, _vec1) / (_d1 * _d2)) > 1e-6:
|
|
655
|
-
break
|
|
656
|
-
x = (coordinates[idx_0 + 1] - coordinates[idx_0]) / _d1
|
|
657
|
-
y = np.cross((coordinates[idx_0 + 2] - coordinates[idx_0]), x)
|
|
658
|
-
y_d = np.linalg.norm(y, 2)
|
|
659
|
-
y = y / np.ma.filled(np.ma.array(y_d, mask=y_d == 0), np.inf)
|
|
660
|
-
z = np.cross(x, y)
|
|
661
|
-
# ------- Build spherical coordinates -------
|
|
662
|
-
vec = coordinates - coordinates[idx_0]
|
|
663
|
-
d = np.linalg.norm(vec, 2, axis=-1)
|
|
664
|
-
_d = np.ma.filled(np.ma.array(d, mask=d == 0), np.inf)
|
|
665
|
-
theta = np.arccos(np.dot(vec, z) / _d) # in [0, \pi]
|
|
666
|
-
phi = np.arctan2(np.dot(vec, y), np.dot(vec, x)) # in [-\pi, \pi]
|
|
667
|
-
info = np.vstack([d, theta, phi]).T
|
|
668
|
-
info[idx_0] = np.zeros_like(info[idx_0])
|
|
669
|
-
return info
|
|
670
|
-
|
|
671
|
-
def geo2seq(
|
|
672
|
-
self, symbols: List[str], coordinates: np.ndarray, decimals: int = 2
|
|
673
|
-
) -> str:
|
|
674
|
-
"""
|
|
675
|
-
Geometry-to-sequence function.\n
|
|
676
|
-
The algorithm follows the descriptions in paper: https://arxiv.org/abs/2408.10120.
|
|
677
|
-
|
|
678
|
-
:param symbols: a list of atomic symbols
|
|
679
|
-
:param coordinates: Cartesian coordinates; shape: (n_a, 3)
|
|
680
|
-
:param decimals: the maxmium number of decimals to keep; default is 2
|
|
681
|
-
:type symbols: list
|
|
682
|
-
:type coordinates: numpy.ndarray
|
|
683
|
-
:type decimals: int
|
|
684
|
-
:return: `Geo2Seq` string
|
|
685
|
-
:rtype: str
|
|
686
|
-
"""
|
|
687
|
-
symbols, coordinates = self.canonicalise(symbols, coordinates)
|
|
688
|
-
info = self.cartesian2spherical(coordinates)
|
|
689
|
-
info = [
|
|
690
|
-
f"{symbols[i]} {r[0]} {r[1]} {r[2]}"
|
|
691
|
-
for i, r in enumerate(np.round(info, decimals))
|
|
692
|
-
]
|
|
693
|
-
return " ".join(info)
|
|
694
|
-
|
|
695
|
-
@staticmethod
|
|
696
|
-
def seq2geo(seq: str) -> Tuple[Optional[List[str]], Optional[np.ndarray]]:
|
|
697
|
-
"""
|
|
698
|
-
Sequence-to-geometry function.\n
|
|
699
|
-
The method follows the descriptions in paper: https://arxiv.org/abs/2408.10120.
|
|
700
|
-
|
|
701
|
-
:param seq: `Geo2Seq` string
|
|
702
|
-
:type seq: str
|
|
703
|
-
:return: (symbols, coordinates) if `seq` is valid
|
|
704
|
-
:rtype: tuple
|
|
705
|
-
"""
|
|
706
|
-
tokens = seq.split()
|
|
707
|
-
if len(tokens) % 4 != 0:
|
|
708
|
-
return None, None
|
|
709
|
-
tokens = np.array(tokens).reshape(-1, 4)
|
|
710
|
-
symbols, coordinates = tokens[::, 0], tokens[::, 1:]
|
|
711
|
-
if sum([len(_atom_regex.findall(sym)) for sym in symbols]) != len(symbols):
|
|
712
|
-
return None, None
|
|
713
|
-
try:
|
|
714
|
-
coord = [[float(i) for i in j] for j in coordinates]
|
|
715
|
-
coord = np.array(coord)
|
|
716
|
-
except ValueError:
|
|
717
|
-
return None, None
|
|
718
|
-
d, theta, phi = coord[::, 0, None], coord[::, 1, None], coord[::, 2, None]
|
|
719
|
-
x = d * np.sin(theta) * np.cos(phi)
|
|
720
|
-
y = d * np.sin(theta) * np.sin(phi)
|
|
721
|
-
z = d * np.cos(theta)
|
|
722
|
-
return symbols, np.concatenate([x, y, z], -1)
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: bayesianflow_for_chem
|
|
3
|
-
Version: 1.
|
|
3
|
+
Version: 1.4.1
|
|
4
4
|
Summary: Bayesian flow network framework for Chemistry
|
|
5
5
|
Home-page: https://augus1999.github.io/bayesian-flow-network-for-chemistry/
|
|
6
6
|
Author: Nianze A. Tao
|
|
@@ -28,8 +28,6 @@ Requires-Dist: loralib>=0.1.2
|
|
|
28
28
|
Requires-Dist: lightning>=2.2.0
|
|
29
29
|
Requires-Dist: scikit-learn>=1.5.0
|
|
30
30
|
Requires-Dist: typing_extensions>=4.8.0
|
|
31
|
-
Provides-Extra: geo2seq
|
|
32
|
-
Requires-Dist: pynauty>=2.8.8.1; extra == "geo2seq"
|
|
33
31
|
Dynamic: author
|
|
34
32
|
Dynamic: author-email
|
|
35
33
|
Dynamic: classifier
|
|
@@ -40,7 +38,6 @@ Dynamic: keywords
|
|
|
40
38
|
Dynamic: license
|
|
41
39
|
Dynamic: license-file
|
|
42
40
|
Dynamic: project-url
|
|
43
|
-
Dynamic: provides-extra
|
|
44
41
|
Dynamic: requires-dist
|
|
45
42
|
Dynamic: requires-python
|
|
46
43
|
Dynamic: summary
|
|
@@ -92,7 +89,7 @@ You can find pretrained models on our [🤗Hugging Face model page](https://hugg
|
|
|
92
89
|
|
|
93
90
|
We provide a Python class [`CSVData`](./bayesianflow_for_chem/data.py) to handle data stored in CSV or similar format containing headers to identify the entities. The following is a quickstart.
|
|
94
91
|
|
|
95
|
-
1. Download your dataset file (e.g., ESOL
|
|
92
|
+
1. Download your dataset file (e.g., ESOL from [MoleculeNet](https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/delaney-processed.csv)) and split the file:
|
|
96
93
|
```python
|
|
97
94
|
>>> from bayesianflow_for_chem.tool import split_data
|
|
98
95
|
|
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
bayesianflow_for_chem/__init__.py,sha256=N_7P9Ea0eUmdC0wQKXIHiuMzPK4p9_cBF_YOexjo5yo,329
|
|
2
|
+
bayesianflow_for_chem/data.py,sha256=WoOCOVmJX4WeHa2WeO4i66J2FS8rvRaYRCdlBN7ZeOM,6576
|
|
3
|
+
bayesianflow_for_chem/model.py,sha256=zJkcUnZcxFa4iTo9_-BHzAM1MkJm1pbEGiczVgyUcxo,50186
|
|
4
|
+
bayesianflow_for_chem/scorer.py,sha256=7G1TVSwC0qONtNm6kiDZUWwvuFPzasNSjp4eJAk5TL0,4101
|
|
5
|
+
bayesianflow_for_chem/tool.py,sha256=Ma4dEBWP5GFKk-Tj5vBJgxGs_WAp4F87-b1UcsqUAn4,25486
|
|
6
|
+
bayesianflow_for_chem/train.py,sha256=hGKyhGhLch-exSYPZdLXrLn3gf39Q1VLSJs2qtuikQE,9709
|
|
7
|
+
bayesianflow_for_chem/vocab.txt,sha256=HgtAZmpWYk4y8PqEVC4vqut1vE75DfRKE_10s2UW0rU,790
|
|
8
|
+
bayesianflow_for_chem-1.4.1.dist-info/licenses/LICENSE,sha256=hIahDEOTzuHCU5J2nd07LWwkLW7Hko4UFO__ffsvB-8,34523
|
|
9
|
+
bayesianflow_for_chem-1.4.1.dist-info/METADATA,sha256=460yUOjHG9PTavIddJJ2Ufdq0bkLBZqbmMugyq6LVPQ,5643
|
|
10
|
+
bayesianflow_for_chem-1.4.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
11
|
+
bayesianflow_for_chem-1.4.1.dist-info/top_level.txt,sha256=KHsanI3BMCt8D9Qpze2ycrF6nMa3PyojgO6eS1c8kco,22
|
|
12
|
+
bayesianflow_for_chem-1.4.1.dist-info/RECORD,,
|
|
@@ -1,12 +0,0 @@
|
|
|
1
|
-
bayesianflow_for_chem/__init__.py,sha256=3BW4-ri8OcMZAIPJBT2q-48L3LAY776xluMDC6WXaZU,329
|
|
2
|
-
bayesianflow_for_chem/data.py,sha256=EbCfhA1bCieVHVOYVk7nvgsaOzhKyFdnHd261qNR4BY,7763
|
|
3
|
-
bayesianflow_for_chem/model.py,sha256=fFcfg1RZuoJeptAtglo2U8j1EGNSGjItMHqlKdLGGhU,50799
|
|
4
|
-
bayesianflow_for_chem/scorer.py,sha256=7G1TVSwC0qONtNm6kiDZUWwvuFPzasNSjp4eJAk5TL0,4101
|
|
5
|
-
bayesianflow_for_chem/tool.py,sha256=Z9qF80qzK-CJk9MJaWuSNOLnA-LPiD6CiC7S3sZbBP8,27704
|
|
6
|
-
bayesianflow_for_chem/train.py,sha256=hGKyhGhLch-exSYPZdLXrLn3gf39Q1VLSJs2qtuikQE,9709
|
|
7
|
-
bayesianflow_for_chem/vocab.txt,sha256=HgtAZmpWYk4y8PqEVC4vqut1vE75DfRKE_10s2UW0rU,790
|
|
8
|
-
bayesianflow_for_chem-1.3.0.dist-info/licenses/LICENSE,sha256=hIahDEOTzuHCU5J2nd07LWwkLW7Hko4UFO__ffsvB-8,34523
|
|
9
|
-
bayesianflow_for_chem-1.3.0.dist-info/METADATA,sha256=2BDjaVhIkd0TLolVETa2kb7xUGYhn8kdlq2CMfF-i7Y,5746
|
|
10
|
-
bayesianflow_for_chem-1.3.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
11
|
-
bayesianflow_for_chem-1.3.0.dist-info/top_level.txt,sha256=KHsanI3BMCt8D9Qpze2ycrF6nMa3PyojgO6eS1c8kco,22
|
|
12
|
-
bayesianflow_for_chem-1.3.0.dist-info/RECORD,,
|
|
File without changes
|
{bayesianflow_for_chem-1.3.0.dist-info → bayesianflow_for_chem-1.4.1.dist-info}/licenses/LICENSE
RENAMED
|
File without changes
|
{bayesianflow_for_chem-1.3.0.dist-info → bayesianflow_for_chem-1.4.1.dist-info}/top_level.txt
RENAMED
|
File without changes
|