bayesianflow-for-chem 1.4.0__py3-none-any.whl → 1.4.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of bayesianflow-for-chem might be problematic. Click here for more details.
- bayesianflow_for_chem/__init__.py +1 -1
- bayesianflow_for_chem/model.py +3 -18
- bayesianflow_for_chem/tool.py +61 -12
- {bayesianflow_for_chem-1.4.0.dist-info → bayesianflow_for_chem-1.4.1.dist-info}/METADATA +1 -1
- bayesianflow_for_chem-1.4.1.dist-info/RECORD +12 -0
- bayesianflow_for_chem-1.4.0.dist-info/RECORD +0 -12
- {bayesianflow_for_chem-1.4.0.dist-info → bayesianflow_for_chem-1.4.1.dist-info}/WHEEL +0 -0
- {bayesianflow_for_chem-1.4.0.dist-info → bayesianflow_for_chem-1.4.1.dist-info}/licenses/LICENSE +0 -0
- {bayesianflow_for_chem-1.4.0.dist-info → bayesianflow_for_chem-1.4.1.dist-info}/top_level.txt +0 -0
|
@@ -7,5 +7,5 @@ from . import data, tool, train, scorer
|
|
|
7
7
|
from .model import ChemBFN, MLP, EnsembleChemBFN
|
|
8
8
|
|
|
9
9
|
__all__ = ["data", "tool", "train", "scorer", "ChemBFN", "MLP", "EnsembleChemBFN"]
|
|
10
|
-
__version__ = "1.4.0"
|
|
10
|
+
__version__ = "1.4.1"
|
|
11
11
|
__author__ = "Nianze A. Tao (Omozawa Sueno)"
|
bayesianflow_for_chem/model.py
CHANGED
|
@@ -169,16 +169,6 @@ class Attention(nn.Module):
|
|
|
169
169
|
k = k.view(split).permute(2, 0, 1, 3).contiguous()
|
|
170
170
|
v = v.view(split).permute(2, 0, 1, 3).contiguous()
|
|
171
171
|
q, k = self._rotate(q, k, pe) # position embedding
|
|
172
|
-
"""
|
|
173
|
-
# Original code. Maybe using `nn.functional.scaled_dot_product_attention(...)` is better.
|
|
174
|
-
|
|
175
|
-
k_t = k.transpose(-2, -1)
|
|
176
|
-
if mask is not None:
|
|
177
|
-
alpha = softmax((q @ k_t / self.tp).masked_fill_(mask, -torch.inf), -1)
|
|
178
|
-
else:
|
|
179
|
-
alpha = softmax(q @ k_t / self.tp, -1)
|
|
180
|
-
atten_out = (alpha @ v).permute(1, 2, 0, 3).contiguous().view(shape)
|
|
181
|
-
"""
|
|
182
172
|
atten_out = nn.functional.scaled_dot_product_attention(
|
|
183
173
|
q, k, v, mask, 0.0, False, scale=1 / self.tp
|
|
184
174
|
)
|
|
@@ -430,19 +420,14 @@ class ChemBFN(nn.Module):
|
|
|
430
420
|
c += y
|
|
431
421
|
pe = self.position(n_t)
|
|
432
422
|
x = self.embedding(x)
|
|
433
|
-
attn_mask: Optional[Tensor] = None
|
|
434
423
|
if self.semi_autoregressive:
|
|
435
424
|
attn_mask = torch.tril(
|
|
436
425
|
torch.ones((1, n_b, n_t, n_t), device=x.device), diagonal=0
|
|
437
426
|
)
|
|
427
|
+
elif mask is not None:
|
|
428
|
+
attn_mask = mask.transpose(-2, -1).repeat(1, n_t, 1)[None, ...] != 0
|
|
438
429
|
else:
|
|
439
|
-
|
|
440
|
-
"""
|
|
441
|
-
# Original Code.
|
|
442
|
-
|
|
443
|
-
attn_mask = mask.transpose(-2, -1).repeat(1, x.shape[1], 1)[None, ...] == 0
|
|
444
|
-
"""
|
|
445
|
-
attn_mask = mask.transpose(-2, -1).repeat(1, n_t, 1)[None, ...] != 0
|
|
430
|
+
attn_mask = None
|
|
446
431
|
for layer in self.encoder_layers:
|
|
447
432
|
x = layer(x, pe, c, attn_mask)
|
|
448
433
|
return self.final_layer(x, c, mask is None)
|
bayesianflow_for_chem/tool.py
CHANGED
|
@@ -16,15 +16,16 @@ from torch import cuda, Tensor, softmax
|
|
|
16
16
|
from torch.ao import quantization
|
|
17
17
|
from torch.utils.data import DataLoader
|
|
18
18
|
from typing_extensions import Self
|
|
19
|
-
from rdkit.Chem.rdchem import Mol, Bond
|
|
20
19
|
from rdkit.Chem import (
|
|
21
20
|
rdDetermineBonds,
|
|
21
|
+
GetFormalCharge,
|
|
22
22
|
MolFromXYZBlock,
|
|
23
23
|
MolFromSmiles,
|
|
24
24
|
MolToSmiles,
|
|
25
25
|
CanonSmiles,
|
|
26
26
|
AllChem,
|
|
27
27
|
AddHs,
|
|
28
|
+
Mol,
|
|
28
29
|
)
|
|
29
30
|
from rdkit.Chem.Scaffolds.MurckoScaffold import MurckoScaffoldSmiles # type: ignore
|
|
30
31
|
from sklearn.metrics import (
|
|
@@ -539,40 +540,88 @@ class GeometryConverter:
|
|
|
539
540
|
xyz_block.append(f"{atom} {r[i][0]:.10f} {r[i][1]:.10f} {r[i][2]:.10f}")
|
|
540
541
|
return MolFromXYZBlock("\n".join(xyz_block))
|
|
541
542
|
|
|
542
|
-
@staticmethod
|
|
543
|
-
def _bond_pair_idx(bonds: Bond) -> List[List[int]]:
|
|
544
|
-
return [[i.GetBeginAtomIdx(), i.GetEndAtomIdx()] for i in bonds]
|
|
545
|
-
|
|
546
543
|
@staticmethod
|
|
547
544
|
def smiles2cartesian(
|
|
548
|
-
smiles: str,
|
|
545
|
+
smiles: str,
|
|
546
|
+
num_conformers: int = 50,
|
|
547
|
+
rdkit_ff_type: str = "MMFF",
|
|
548
|
+
refine_with_crest: bool = False,
|
|
549
|
+
spin: float = 0.0,
|
|
549
550
|
) -> Tuple[List[str], np.ndarray]:
|
|
550
551
|
"""
|
|
551
552
|
Guess the 3D geometry from SMILES string via MMFF conformer search.
|
|
552
553
|
|
|
553
554
|
:param smiles: a valid SMILES string
|
|
554
555
|
:param num_conformers: number of initial conformers
|
|
555
|
-
:param
|
|
556
|
+
:param rdkit_ff_type: force field type chosen in `'MMFF'` and `'UFF'`
|
|
557
|
+
:param refine_with_crest: find the best conformer via CREST
|
|
558
|
+
:param spin: total spin; only required when `refine_with_crest=True`
|
|
556
559
|
:type smiles: str
|
|
557
560
|
:type num_conformers: int
|
|
558
|
-
:type
|
|
561
|
+
:type rdkit_ff_type: str
|
|
562
|
+
:type refine_with_crest: bool
|
|
563
|
+
:type spin: float
|
|
559
564
|
:return: atomic symbols \n
|
|
560
565
|
cartesian coordinates; shape: (n_a, 3)
|
|
561
566
|
:rtype: tuple
|
|
562
567
|
"""
|
|
568
|
+
assert rdkit_ff_type.lower() in ("mmff", "uff")
|
|
569
|
+
if refine_with_crest:
|
|
570
|
+
from tempfile import TemporaryDirectory
|
|
571
|
+
from subprocess import run
|
|
572
|
+
|
|
573
|
+
# We need both CREST and xTB installed.
|
|
574
|
+
if run("crest --version", shell=True).returncode != 0:
|
|
575
|
+
raise RuntimeError(
|
|
576
|
+
"`CREST` is not found! Make sure it is installed and added into the PATH."
|
|
577
|
+
)
|
|
578
|
+
if run("xtb --version", shell=True).returncode != 0:
|
|
579
|
+
raise RuntimeError(
|
|
580
|
+
"`xTB` is not found! Make sure it is installed and added into the PATH."
|
|
581
|
+
)
|
|
563
582
|
mol = MolFromSmiles(smiles)
|
|
564
583
|
mol = AddHs(mol)
|
|
565
|
-
AllChem.EmbedMultipleConfs(mol, numConfs=num_conformers,
|
|
584
|
+
AllChem.EmbedMultipleConfs(mol, numConfs=num_conformers, params=AllChem.ETKDG())
|
|
566
585
|
symbols = [atom.GetSymbol() for atom in mol.GetAtoms()]
|
|
567
586
|
energies = []
|
|
568
587
|
for conf_id in range(num_conformers):
|
|
569
|
-
|
|
570
|
-
|
|
571
|
-
|
|
588
|
+
if rdkit_ff_type.lower() == "mmff":
|
|
589
|
+
ff = AllChem.MMFFGetMoleculeForceField(
|
|
590
|
+
mol, AllChem.MMFFGetMoleculeProperties(mol), confId=conf_id
|
|
591
|
+
)
|
|
592
|
+
else: # UFF
|
|
593
|
+
ff = AllChem.UFFGetMoleculeForceField(mol, confId=conf_id)
|
|
572
594
|
energy = ff.CalcEnergy()
|
|
573
595
|
energies.append((conf_id, energy))
|
|
574
596
|
lowest_energy_conf = min(energies, key=lambda x: x[1])
|
|
575
597
|
coordinates = mol.GetConformer(id=lowest_energy_conf[0]).GetPositions()
|
|
598
|
+
if refine_with_crest:
|
|
599
|
+
xyz = f"{len(symbols)}\n\n" + "\n".join(
|
|
600
|
+
f"{s} {coordinates[i][0]:.10f} {coordinates[i][1]:.10f} {coordinates[i][2]:.10f}"
|
|
601
|
+
for i, s in enumerate(symbols)
|
|
602
|
+
)
|
|
603
|
+
chrg = GetFormalCharge(mol)
|
|
604
|
+
uhf = int(spin * 2)
|
|
605
|
+
with TemporaryDirectory(dir=Path.cwd()) as temp_dir:
|
|
606
|
+
with open(Path(temp_dir) / "mol.xyz", "w", encoding="utf-8") as f:
|
|
607
|
+
f.write(xyz)
|
|
608
|
+
s = run(
|
|
609
|
+
f"crest mol.xyz -gfn2 -quick -prop ohess{f' --chrg {chrg}' if chrg != 0 else ''}{f' --uhf {uhf}' if uhf != 0 else ''}",
|
|
610
|
+
shell=True,
|
|
611
|
+
cwd=temp_dir,
|
|
612
|
+
)
|
|
613
|
+
if s.returncode == 0:
|
|
614
|
+
with open(Path(temp_dir) / "crest_property.xyz", "r") as f:
|
|
615
|
+
xyz = f.readlines()
|
|
616
|
+
xyz_data = []
|
|
617
|
+
for i in xyz[2:]:
|
|
618
|
+
if i == xyz[0]:
|
|
619
|
+
break
|
|
620
|
+
xyz_data.append(i.strip().split())
|
|
621
|
+
xyz_data = np.array(xyz_data)
|
|
622
|
+
symbols, coordinates = np.split(xyz_data, [1], axis=-1)
|
|
623
|
+
symbols = symbols.flatten().tolist()
|
|
624
|
+
coordinates = coordinates.astype(np.float64)
|
|
576
625
|
return symbols, coordinates
|
|
577
626
|
|
|
578
627
|
def cartesian2smiles(
|
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
bayesianflow_for_chem/__init__.py,sha256=N_7P9Ea0eUmdC0wQKXIHiuMzPK4p9_cBF_YOexjo5yo,329
|
|
2
|
+
bayesianflow_for_chem/data.py,sha256=WoOCOVmJX4WeHa2WeO4i66J2FS8rvRaYRCdlBN7ZeOM,6576
|
|
3
|
+
bayesianflow_for_chem/model.py,sha256=zJkcUnZcxFa4iTo9_-BHzAM1MkJm1pbEGiczVgyUcxo,50186
|
|
4
|
+
bayesianflow_for_chem/scorer.py,sha256=7G1TVSwC0qONtNm6kiDZUWwvuFPzasNSjp4eJAk5TL0,4101
|
|
5
|
+
bayesianflow_for_chem/tool.py,sha256=Ma4dEBWP5GFKk-Tj5vBJgxGs_WAp4F87-b1UcsqUAn4,25486
|
|
6
|
+
bayesianflow_for_chem/train.py,sha256=hGKyhGhLch-exSYPZdLXrLn3gf39Q1VLSJs2qtuikQE,9709
|
|
7
|
+
bayesianflow_for_chem/vocab.txt,sha256=HgtAZmpWYk4y8PqEVC4vqut1vE75DfRKE_10s2UW0rU,790
|
|
8
|
+
bayesianflow_for_chem-1.4.1.dist-info/licenses/LICENSE,sha256=hIahDEOTzuHCU5J2nd07LWwkLW7Hko4UFO__ffsvB-8,34523
|
|
9
|
+
bayesianflow_for_chem-1.4.1.dist-info/METADATA,sha256=460yUOjHG9PTavIddJJ2Ufdq0bkLBZqbmMugyq6LVPQ,5643
|
|
10
|
+
bayesianflow_for_chem-1.4.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
11
|
+
bayesianflow_for_chem-1.4.1.dist-info/top_level.txt,sha256=KHsanI3BMCt8D9Qpze2ycrF6nMa3PyojgO6eS1c8kco,22
|
|
12
|
+
bayesianflow_for_chem-1.4.1.dist-info/RECORD,,
|
|
@@ -1,12 +0,0 @@
|
|
|
1
|
-
bayesianflow_for_chem/__init__.py,sha256=3sP8nM4_idOX-ksvpBJEApxPAVAPijKvQHxidTO5790,329
|
|
2
|
-
bayesianflow_for_chem/data.py,sha256=WoOCOVmJX4WeHa2WeO4i66J2FS8rvRaYRCdlBN7ZeOM,6576
|
|
3
|
-
bayesianflow_for_chem/model.py,sha256=fUrXKhn2U9FrVPJyb4lqACqPTePkIgI0v6_1jPs5c0Q,50784
|
|
4
|
-
bayesianflow_for_chem/scorer.py,sha256=7G1TVSwC0qONtNm6kiDZUWwvuFPzasNSjp4eJAk5TL0,4101
|
|
5
|
-
bayesianflow_for_chem/tool.py,sha256=NMMRHk2FJY0fyA76zCrz6tkcylCuExMUMj5hohWTnkE,23155
|
|
6
|
-
bayesianflow_for_chem/train.py,sha256=hGKyhGhLch-exSYPZdLXrLn3gf39Q1VLSJs2qtuikQE,9709
|
|
7
|
-
bayesianflow_for_chem/vocab.txt,sha256=HgtAZmpWYk4y8PqEVC4vqut1vE75DfRKE_10s2UW0rU,790
|
|
8
|
-
bayesianflow_for_chem-1.4.0.dist-info/licenses/LICENSE,sha256=hIahDEOTzuHCU5J2nd07LWwkLW7Hko4UFO__ffsvB-8,34523
|
|
9
|
-
bayesianflow_for_chem-1.4.0.dist-info/METADATA,sha256=1Y5mLIOaPsHcyCCm2SkWz7OCniQYVJ67-cVq3cUU0Mw,5643
|
|
10
|
-
bayesianflow_for_chem-1.4.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
11
|
-
bayesianflow_for_chem-1.4.0.dist-info/top_level.txt,sha256=KHsanI3BMCt8D9Qpze2ycrF6nMa3PyojgO6eS1c8kco,22
|
|
12
|
-
bayesianflow_for_chem-1.4.0.dist-info/RECORD,,
|
|
File without changes
|
{bayesianflow_for_chem-1.4.0.dist-info → bayesianflow_for_chem-1.4.1.dist-info}/licenses/LICENSE
RENAMED
|
File without changes
|
{bayesianflow_for_chem-1.4.0.dist-info → bayesianflow_for_chem-1.4.1.dist-info}/top_level.txt
RENAMED
|
File without changes
|