molcraft 0.1.0a5__py3-none-any.whl → 0.1.0a7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of molcraft might be problematic. Click here for more details.
- molcraft/__init__.py +3 -2
- molcraft/callbacks.py +60 -0
- molcraft/chem.py +103 -21
- molcraft/conformers.py +1 -5
- molcraft/featurizers.py +20 -14
- molcraft/layers.py +307 -211
- molcraft/losses.py +36 -0
- molcraft/models.py +135 -9
- molcraft/ops.py +12 -2
- molcraft/records.py +32 -31
- molcraft/tensors.py +1 -1
- {molcraft-0.1.0a5.dist-info → molcraft-0.1.0a7.dist-info}/METADATA +4 -17
- molcraft-0.1.0a7.dist-info/RECORD +19 -0
- {molcraft-0.1.0a5.dist-info → molcraft-0.1.0a7.dist-info}/WHEEL +1 -1
- molcraft-0.1.0a5.dist-info/RECORD +0 -18
- {molcraft-0.1.0a5.dist-info → molcraft-0.1.0a7.dist-info}/licenses/LICENSE +0 -0
- {molcraft-0.1.0a5.dist-info → molcraft-0.1.0a7.dist-info}/top_level.txt +0 -0
molcraft/__init__.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
__version__ = '0.1.
|
|
1
|
+
__version__ = '0.1.0a7'
|
|
2
2
|
|
|
3
3
|
import os
|
|
4
4
|
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
|
|
@@ -14,4 +14,5 @@ from molcraft import ops
|
|
|
14
14
|
from molcraft import records
|
|
15
15
|
from molcraft import tensors
|
|
16
16
|
from molcraft import callbacks
|
|
17
|
-
from molcraft import datasets
|
|
17
|
+
from molcraft import datasets
|
|
18
|
+
from molcraft import losses
|
molcraft/callbacks.py
CHANGED
|
@@ -1,4 +1,6 @@
|
|
|
1
1
|
import keras
|
|
2
|
+
import warnings
|
|
3
|
+
import numpy as np
|
|
2
4
|
|
|
3
5
|
|
|
4
6
|
class TensorBoard(keras.callbacks.TensorBoard):
|
|
@@ -31,3 +33,61 @@ class LearningRateDecay(keras.callbacks.LearningRateScheduler):
|
|
|
31
33
|
return float(lr * keras.ops.exp(-rate))
|
|
32
34
|
|
|
33
35
|
super().__init__(schedule=lr_schedule, **kwargs)
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
class Rollback(keras.callbacks.Callback):
    """Restore the best-seen training state when the monitored loss degrades.

    At the end of every epoch the callback reads `val_loss` from the logs
    (falling back to `loss`) and then, in this order:

    1. rolls back if the loss is NaN or infinite;
    2. rolls back if the loss exceeds the best recorded loss by more than
       `tolerance`, measured relative to the magnitude of the best loss;
    3. rolls back if `epoch` is a non-zero multiple of `frequency`;
    4. otherwise records the current state as the new best if the loss
       improved.

    Args:
        frequency: Interval, in epochs, of the unconditional periodic
            rollback. A falsy value (`None` or `0`) effectively disables it.
        tolerance: Largest accepted relative increase over the best loss.
        rollback_optimizer: Whether optimizer variables are saved on
            improvement and restored on rollback together with the weights.
    """

    def __init__(
        self,
        frequency: int = None,
        tolerance: float = 0.5,
        rollback_optimizer: bool = True,
    ):
        super().__init__()
        # A falsy frequency is replaced by an interval no realistic run reaches.
        self.frequency = frequency or 1_000_000_000
        self.tolerance = tolerance
        self.rollback_optimizer = rollback_optimizer

    def on_train_begin(self, logs=None):
        """Snapshot the initial state and reset the best loss to infinity."""
        self.rollback_weights = self._get_model_vars()
        self.rollback_optimizer_vars = self._get_optimizer_vars()
        self.rollback_loss = float('inf')

    def on_epoch_end(self, epoch: int, logs: dict = None):
        """Decide whether to roll back or to record a new best state."""
        # `logs` defaults to None; treat that like an empty dict.
        logs = logs or {}
        current_loss = logs.get('val_loss', logs.get('loss'))
        if current_loss is None:
            # No loss was reported this epoch, so there is nothing to judge.
            return

        if np.isnan(current_loss) or np.isinf(current_loss):
            self._rollback()
            print("\nRolling back model, found nan or inf loss.\n")
            return

        deviation = self._relative_deviation(current_loss)
        if deviation > self.tolerance:
            self._rollback()
            print(f"\nRolling back model, found too large deviation: {deviation:.3f}\n")
            # Already rolled back; skip the periodic check to avoid doing
            # (and reporting) the same rollback twice in one epoch.
            return

        if epoch and epoch % self.frequency == 0:
            self._rollback()
            print(f"\nRolling back model, {epoch} % {self.frequency} == 0\n")
            return

        if current_loss < self.rollback_loss:
            self._save_state(current_loss)

    def _relative_deviation(self, current_loss: float) -> float:
        """Return how far `current_loss` lies above the best loss, relatively.

        The result is NaN while no finite best loss exists, so the comparison
        against `tolerance` is False until a state has been recorded. The
        denominator is the magnitude of the best loss, floored at machine
        epsilon, which avoids a division by zero and keeps the sign meaningful
        for negative losses.
        """
        best_loss = self.rollback_loss
        if not np.isfinite(best_loss):
            return float('nan')
        scale = max(abs(best_loss), np.finfo(float).eps)
        return float((current_loss - best_loss) / scale)

    def _save_state(self, current_loss: float) -> None:
        """Record the current weights (and optimizer state) as the new best."""
        self.rollback_loss = current_loss
        self.rollback_weights = self._get_model_vars()
        if self.rollback_optimizer:
            self.rollback_optimizer_vars = self._get_optimizer_vars()

    def _rollback(self) -> None:
        """Load the recorded best weights (and optimizer state) into the model."""
        self.model.set_weights(self.rollback_weights)
        if self.rollback_optimizer:
            self.model.optimizer.set_weights(self.rollback_optimizer_vars)

    def _get_optimizer_vars(self):
        """Return the optimizer variables converted with `.numpy()`."""
        return [v.numpy() for v in self.model.optimizer.variables]

    def _get_model_vars(self):
        """Return the model weights as given by `get_weights()`."""
        return self.model.get_weights()
|
|
93
|
+
|
molcraft/chem.py
CHANGED
|
@@ -11,6 +11,7 @@ from rdkit.Chem import rdMolTransforms
|
|
|
11
11
|
from rdkit.Chem import rdPartialCharges
|
|
12
12
|
from rdkit.Chem import rdMolDescriptors
|
|
13
13
|
from rdkit.Chem import rdForceFieldHelpers
|
|
14
|
+
from rdkit.Chem import rdFingerprintGenerator
|
|
14
15
|
|
|
15
16
|
|
|
16
17
|
class Mol(Chem.Mol):
|
|
@@ -399,7 +400,6 @@ def embed_conformers(
|
|
|
399
400
|
mol: Mol,
|
|
400
401
|
num_conformers: int,
|
|
401
402
|
method: str = 'ETKDGv3',
|
|
402
|
-
force: bool = True,
|
|
403
403
|
**kwargs
|
|
404
404
|
) -> None:
|
|
405
405
|
available_embedding_methods = {
|
|
@@ -410,27 +410,39 @@ def embed_conformers(
|
|
|
410
410
|
'srETKDGv3': rdDistGeom.srETKDGv3(),
|
|
411
411
|
'KDG': rdDistGeom.KDG()
|
|
412
412
|
}
|
|
413
|
-
default_embedding_method = 'ETKDGv3'
|
|
414
413
|
mol = Mol(mol)
|
|
415
|
-
|
|
416
|
-
if
|
|
417
|
-
|
|
418
|
-
f
|
|
419
|
-
|
|
414
|
+
embedding_method = available_embedding_methods.get(method)
|
|
415
|
+
if embedding_method is None:
|
|
416
|
+
raise ValueError(
|
|
417
|
+
f'Could not find `method` {method!r}. Specify either of: '
|
|
418
|
+
'`ETDG`, `ETKDG`, `ETKDGv2`, `ETKDGv3`, `srETKDGv3` or `KDG`.'
|
|
420
419
|
)
|
|
421
|
-
|
|
420
|
+
|
|
422
421
|
for key, value in kwargs.items():
|
|
423
|
-
setattr(
|
|
422
|
+
setattr(embedding_method, key, value)
|
|
424
423
|
|
|
425
|
-
success = rdDistGeom.EmbedMultipleConfs(
|
|
424
|
+
success = rdDistGeom.EmbedMultipleConfs(
|
|
425
|
+
mol, numConfs=num_conformers, params=embedding_method
|
|
426
|
+
)
|
|
426
427
|
if not len(success):
|
|
427
|
-
|
|
428
|
-
|
|
429
|
-
|
|
428
|
+
warn(
|
|
429
|
+
f'Could not embed conformer(s) for {mol.canonical_smiles!r} using the '
|
|
430
|
+
'speified method. Giving it another try with more permissive methods.'
|
|
431
|
+
)
|
|
432
|
+
max_attempts = (20 * mol.num_atoms) # increasing it from 10xN to 20xN
|
|
433
|
+
for fallback_method in [method, 'ETDG', 'KDG']:
|
|
434
|
+
fallback_embedding_method = available_embedding_methods[fallback_method]
|
|
435
|
+
fallback_embedding_method.useRandomCoords = True
|
|
436
|
+
fallback_embedding_method.maxAttempts = max_attempts
|
|
437
|
+
success = rdDistGeom.EmbedMultipleConfs(
|
|
438
|
+
mol, numConfs=num_conformers, params=fallback_embedding_method
|
|
439
|
+
)
|
|
440
|
+
if len(success):
|
|
441
|
+
break
|
|
430
442
|
else:
|
|
431
|
-
|
|
432
|
-
|
|
433
|
-
|
|
443
|
+
raise RuntimeError(
|
|
444
|
+
f'Could not embed conformer(s) for {mol.canonical_smiles!r}. '
|
|
445
|
+
)
|
|
434
446
|
return mol
|
|
435
447
|
|
|
436
448
|
def optimize_conformers(
|
|
@@ -444,6 +456,11 @@ def optimize_conformers(
|
|
|
444
456
|
available_force_field_methods = [
|
|
445
457
|
'MMFF', 'MMFF94', 'MMFF94s', 'UFF'
|
|
446
458
|
]
|
|
459
|
+
if method not in available_force_field_methods:
|
|
460
|
+
raise ValueError(
|
|
461
|
+
f'Could not find `method` {method!r}. Specify either of: '
|
|
462
|
+
'`UFF`, `MMFF`, `MMFF94` or `MMFF94s`.'
|
|
463
|
+
)
|
|
447
464
|
mol = Mol(mol)
|
|
448
465
|
try:
|
|
449
466
|
if method.startswith('MMFF'):
|
|
@@ -468,7 +485,7 @@ def optimize_conformers(
|
|
|
468
485
|
except RuntimeError as e:
|
|
469
486
|
warn(
|
|
470
487
|
f'{method} force field minimization raised {e}. '
|
|
471
|
-
'\nProceeding without force field minimization
|
|
488
|
+
'\nProceeding without force field minimization.'
|
|
472
489
|
)
|
|
473
490
|
return mol
|
|
474
491
|
|
|
@@ -579,8 +596,7 @@ def _calc_mmff_energies(
|
|
|
579
596
|
energies.append(float('nan'))
|
|
580
597
|
return energies
|
|
581
598
|
|
|
582
|
-
|
|
583
|
-
def _split_mol_by_confs(mol: Mol) -> list[Mol]:
|
|
599
|
+
def unpack_conformers(mol: Mol) -> list[Mol]:
|
|
584
600
|
mols = []
|
|
585
601
|
for conf in mol.get_conformers():
|
|
586
602
|
new_mol = Chem.Mol(mol)
|
|
@@ -590,11 +606,77 @@ def _split_mol_by_confs(mol: Mol) -> list[Mol]:
|
|
|
590
606
|
mols.append(new_mol)
|
|
591
607
|
return mols
|
|
592
608
|
|
|
609
|
+
# Registry mapping a fingerprint type name to the RDKit factory that builds
# the corresponding fingerprint generator.
_fingerprint_types = dict(
    rdkit=rdFingerprintGenerator.GetRDKitFPGenerator,
    morgan=rdFingerprintGenerator.GetMorganGenerator,
    topological_torsion=rdFingerprintGenerator.GetTopologicalTorsionGenerator,
    atom_pair=rdFingerprintGenerator.GetAtomPairGenerator,
)
|
|
615
|
+
|
|
616
|
+
def _get_fingerprint(
    mol: Mol,
    fp_type: str = 'morgan',
    binary: bool = True,
    dtype: str = 'float32',
    **kwargs,
) -> np.ndarray:
    """Compute a fingerprint of `mol` as a NumPy array.

    Args:
        mol: A `Mol`; any other value is first converted with
            `Mol.from_encoding`.
        fp_type: Key into `_fingerprint_types` selecting the generator
            factory. An unknown key raises `KeyError`.
        binary: If truthy, use the generator's `GetFingerprintAsNumPy`;
            otherwise use `GetCountFingerprintAsNumPy`.
        dtype: dtype the resulting array is cast to.
        **kwargs: Forwarded unchanged to the generator factory.

    Returns:
        The fingerprint array cast to `dtype`.
    """
    # Build the generator first, so an unknown `fp_type` or a bad factory
    # argument fails before the molecule is touched.
    generator_factory = _fingerprint_types[fp_type]
    generator = generator_factory(**kwargs)

    molecule = mol if isinstance(mol, Mol) else Mol.from_encoding(mol)

    compute = (
        generator.GetFingerprintAsNumPy
        if binary
        else generator.GetCountFingerprintAsNumPy
    )
    return compute(molecule).astype(dtype)
|
|
633
|
+
|
|
634
|
+
def _rdkit_fingerprint(
    mol: Chem.Mol,
    size: int = 2048,
    *,
    min_path: int = 1,
    max_path: int = 7,
    binary: bool = True,
    dtype: str = 'float32',
) -> np.ndarray:
    """Compute the `'rdkit'` fingerprint of `mol` via `_get_fingerprint`.

    `size`, `min_path` and `max_path` are handed to the generator factory as
    `fpSize`, `minPath` and `maxPath`; `binary` and `dtype` are passed through.
    """
    return _get_fingerprint(
        mol,
        'rdkit',
        binary,
        dtype,
        fpSize=size,
        minPath=min_path,
        maxPath=max_path,
    )
|
|
645
|
+
|
|
646
|
+
def _morgan_fingerprint(
    mol: Chem.Mol,
    size: int = 2048,
    *,
    radius: int = 3,
    binary: bool = True,
    dtype: str = 'float32',
) -> np.ndarray:
    """Compute the `'morgan'` fingerprint of `mol` via `_get_fingerprint`.

    `radius` and `size` are handed to the generator factory as `radius` and
    `fpSize`; `binary` and `dtype` are passed through.
    """
    return _get_fingerprint(
        mol,
        'morgan',
        binary,
        dtype,
        radius=radius,
        fpSize=size,
    )
|
|
656
|
+
|
|
657
|
+
def _topological_torsion_fingerprint(
    mol: Chem.Mol,
    size: int = 2048,
    *,
    binary: bool = True,
    dtype: str = 'float32',
) -> np.ndarray:
    """Compute the `'topological_torsion'` fingerprint of `mol`.

    Delegates to `_get_fingerprint`, handing `size` to the generator factory
    as `fpSize`; `binary` and `dtype` are passed through.
    """
    return _get_fingerprint(
        mol,
        'topological_torsion',
        binary,
        dtype,
        fpSize=size,
    )
|
|
666
|
+
|
|
667
|
+
def _atom_pair_fingerprint(
    mol: Chem.Mol,
    size: int = 2048,
    *,
    binary: bool = True,
    dtype: str = 'float32',
) -> np.ndarray:
    """Compute the `'atom_pair'` fingerprint of `mol`.

    Delegates to `_get_fingerprint`, handing `size` to the generator factory
    as `fpSize`; `binary` and `dtype` are passed through.
    """
    return _get_fingerprint(
        mol,
        'atom_pair',
        binary,
        dtype,
        fpSize=size,
    )
|
|
676
|
+
|
|
593
677
|
def warn(message: str) -> None:
    """Emit `message` as a `UserWarning` attributed to the caller.

    Args:
        message: Text of the warning.
    """
    warnings.warn(
        message=message,
        category=UserWarning,
        # stacklevel=2 reports the line that called `warn`, not this helper;
        # with stacklevel=1 every warning would point at this function.
        stacklevel=2,
    )
|
|
599
|
-
|
|
600
|
-
|
molcraft/conformers.py
CHANGED
|
@@ -23,20 +23,17 @@ class ConformerEmbedder(ConformerProcessor):
|
|
|
23
23
|
def __init__(
|
|
24
24
|
self,
|
|
25
25
|
method: str = 'ETKDGv3',
|
|
26
|
-
num_conformers: int =
|
|
27
|
-
force: bool = True,
|
|
26
|
+
num_conformers: int = 5,
|
|
28
27
|
**kwargs,
|
|
29
28
|
) -> None:
|
|
30
29
|
self.method = method
|
|
31
30
|
self.num_conformers = num_conformers
|
|
32
|
-
self.force = force
|
|
33
31
|
self.kwargs = kwargs
|
|
34
32
|
|
|
35
33
|
def get_config(self) -> dict:
|
|
36
34
|
config = {
|
|
37
35
|
'method': self.method,
|
|
38
36
|
'num_conformers': self.num_conformers,
|
|
39
|
-
'force': self.force,
|
|
40
37
|
}
|
|
41
38
|
config.update({
|
|
42
39
|
k: v for (k, v) in self.kwargs.items()
|
|
@@ -48,7 +45,6 @@ class ConformerEmbedder(ConformerProcessor):
|
|
|
48
45
|
mol,
|
|
49
46
|
method=self.method,
|
|
50
47
|
num_conformers=self.num_conformers,
|
|
51
|
-
force=self.force,
|
|
52
48
|
**self.kwargs,
|
|
53
49
|
)
|
|
54
50
|
|
molcraft/featurizers.py
CHANGED
|
@@ -175,7 +175,7 @@ class MolGraphFeaturizer(Featurizer):
|
|
|
175
175
|
default_bond_features = (
|
|
176
176
|
bond_features == 'auto' or bond_features == 'default'
|
|
177
177
|
)
|
|
178
|
-
if default_bond_features or self.radius > 1
|
|
178
|
+
if default_bond_features or self.radius > 1:
|
|
179
179
|
vocab = ['zero', 'single', 'double', 'triple', 'aromatic']
|
|
180
180
|
bond_features = [
|
|
181
181
|
features.BondType(vocab)
|
|
@@ -215,7 +215,7 @@ class MolGraphFeaturizer(Featurizer):
|
|
|
215
215
|
if mol is None:
|
|
216
216
|
warn(
|
|
217
217
|
f'Could not obtain `chem.Mol` from {x}. '
|
|
218
|
-
'
|
|
218
|
+
'Returning `None` (proceeding without it).'
|
|
219
219
|
)
|
|
220
220
|
return None
|
|
221
221
|
|
|
@@ -254,24 +254,17 @@ class MolGraphFeaturizer(Featurizer):
|
|
|
254
254
|
|
|
255
255
|
node = {}
|
|
256
256
|
node['feature'] = atom_feature
|
|
257
|
-
|
|
258
|
-
if bond_feature is not None and (self.radius > 1 or self.self_loops):
|
|
259
|
-
# Append 'zero order' bond feature encoding, which encodes non-bonds.
|
|
260
|
-
zero_bond_feature = np.array(
|
|
261
|
-
[[1., 0., 0., 0., 0.]], dtype=bond_feature.dtype
|
|
262
|
-
)
|
|
263
|
-
bond_feature = np.concatenate(
|
|
264
|
-
[bond_feature, zero_bond_feature], axis=0
|
|
265
|
-
)
|
|
266
257
|
|
|
267
258
|
edge = {}
|
|
268
259
|
if self.radius == 1:
|
|
269
260
|
edge['source'], edge['target'] = mol.adjacency(
|
|
270
261
|
fill='full', sparse=True, self_loops=self.self_loops, dtype=self.index_dtype
|
|
271
262
|
)
|
|
263
|
+
if self.self_loops:
|
|
264
|
+
bond_feature = np.pad(bond_feature, [(0, 1), (0, 0)])
|
|
272
265
|
if bond_feature is not None:
|
|
273
266
|
bond_indices = []
|
|
274
|
-
for
|
|
267
|
+
for atom_i, atom_j in zip(edge['source'], edge['target']):
|
|
275
268
|
if atom_i == atom_j:
|
|
276
269
|
bond_indices.append(-1)
|
|
277
270
|
else:
|
|
@@ -279,6 +272,8 @@ class MolGraphFeaturizer(Featurizer):
|
|
|
279
272
|
mol.get_bond_between_atoms(atom_i, atom_j).index
|
|
280
273
|
)
|
|
281
274
|
edge['feature'] = bond_feature[bond_indices]
|
|
275
|
+
if self.self_loops:
|
|
276
|
+
edge['self_loop'] = (edge['source'] == edge['target'])
|
|
282
277
|
else:
|
|
283
278
|
paths = chem.get_shortest_paths(
|
|
284
279
|
mol, radius=self.radius, self_loops=self.self_loops
|
|
@@ -293,6 +288,12 @@ class MolGraphFeaturizer(Featurizer):
|
|
|
293
288
|
[len(path) - 1 for path in paths], dtype=self.index_dtype
|
|
294
289
|
)
|
|
295
290
|
if bond_feature is not None:
|
|
291
|
+
zero_bond_feature = np.array(
|
|
292
|
+
[[1., 0., 0., 0., 0.]], dtype=bond_feature.dtype
|
|
293
|
+
)
|
|
294
|
+
bond_feature = np.concatenate(
|
|
295
|
+
[bond_feature, zero_bond_feature], axis=0
|
|
296
|
+
)
|
|
296
297
|
edge['feature'] = self._expand_bond_features(
|
|
297
298
|
mol, paths, bond_feature,
|
|
298
299
|
)
|
|
@@ -511,7 +512,7 @@ class MolGraphFeaturizer3D(MolGraphFeaturizer):
|
|
|
511
512
|
steps=[
|
|
512
513
|
conformers.ConformerEmbedder(
|
|
513
514
|
method='ETKDGv3',
|
|
514
|
-
num_conformers=
|
|
515
|
+
num_conformers=5
|
|
515
516
|
),
|
|
516
517
|
]
|
|
517
518
|
)
|
|
@@ -588,7 +589,7 @@ class MolGraphFeaturizer3D(MolGraphFeaturizer):
|
|
|
588
589
|
edge_feature = self.bond_features(mol)
|
|
589
590
|
|
|
590
591
|
edge = {}
|
|
591
|
-
mols = chem.
|
|
592
|
+
mols = chem.unpack_conformers(mol)
|
|
592
593
|
tensor_list = []
|
|
593
594
|
for i, mol in enumerate(mols):
|
|
594
595
|
node_conformer = copy.deepcopy(node)
|
|
@@ -734,6 +735,11 @@ def _add_super_edges(
|
|
|
734
735
|
]
|
|
735
736
|
)
|
|
736
737
|
|
|
738
|
+
if 'self_loop' in edge:
|
|
739
|
+
edge['self_loop'] = np.pad(
|
|
740
|
+
edge['self_loop'], [(0, num_nodes * num_super_nodes * 2)],
|
|
741
|
+
constant_values=False,
|
|
742
|
+
)
|
|
737
743
|
if 'length' in edge:
|
|
738
744
|
edge['length'] = np.pad(edge['length'], [(0, 0), (1, 0)])
|
|
739
745
|
zero_array = np.zeros([num_nodes * num_super_nodes * 2], dtype='int32')
|