molcraft 0.1.0a7__py3-none-any.whl → 0.1.0a9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of molcraft might be problematic. Click here for more details.
- molcraft/__init__.py +1 -1
- molcraft/callbacks.py +33 -26
- molcraft/chem.py +15 -16
- molcraft/features.py +3 -9
- molcraft/featurizers.py +28 -38
- molcraft/layers.py +439 -858
- molcraft/ops.py +12 -1
- {molcraft-0.1.0a7.dist-info → molcraft-0.1.0a9.dist-info}/METADATA +2 -2
- molcraft-0.1.0a9.dist-info/RECORD +19 -0
- molcraft-0.1.0a7.dist-info/RECORD +0 -19
- {molcraft-0.1.0a7.dist-info → molcraft-0.1.0a9.dist-info}/WHEEL +0 -0
- {molcraft-0.1.0a7.dist-info → molcraft-0.1.0a9.dist-info}/licenses/LICENSE +0 -0
- {molcraft-0.1.0a7.dist-info → molcraft-0.1.0a9.dist-info}/top_level.txt +0 -0
molcraft/__init__.py
CHANGED
molcraft/callbacks.py
CHANGED
|
@@ -36,58 +36,65 @@ class LearningRateDecay(keras.callbacks.LearningRateScheduler):
|
|
|
36
36
|
|
|
37
37
|
|
|
38
38
|
class Rollback(keras.callbacks.Callback):
|
|
39
|
+
"""Rollback callback.
|
|
40
|
+
|
|
41
|
+
Currently, this callback simply restores the model and (optionally) the optimizer
|
|
42
|
+
variables if current loss deviates too much from the best observed loss.
|
|
43
|
+
|
|
44
|
+
This callback might be useful in situations where the loss tends to spike and put
|
|
45
|
+
the model in an undesired/problematic high-loss parameter space.
|
|
46
|
+
|
|
47
|
+
Args:
|
|
48
|
+
tolerance (float):
|
|
49
|
+
The threshold for when the restoration is triggered. The deviation is
|
|
50
|
+
calculated as follows: (current_loss - best_loss) / best_loss.
|
|
51
|
+
"""
|
|
39
52
|
|
|
40
53
|
def __init__(
|
|
41
|
-
self,
|
|
42
|
-
|
|
43
|
-
tolerance: float = 0.5,
|
|
54
|
+
self,
|
|
55
|
+
tolerance: float = 0.5,
|
|
44
56
|
rollback_optimizer: bool = True,
|
|
45
57
|
):
|
|
46
58
|
super().__init__()
|
|
47
|
-
self.frequency = frequency or 1_000_000_000
|
|
48
59
|
self.tolerance = tolerance
|
|
49
60
|
self.rollback_optimizer = rollback_optimizer
|
|
50
61
|
|
|
51
62
|
def on_train_begin(self, logs=None):
|
|
52
|
-
self.
|
|
53
|
-
|
|
54
|
-
|
|
63
|
+
self._rollback_weights = self._get_model_vars()
|
|
64
|
+
if self.rollback_optimizer:
|
|
65
|
+
self._rollback_optimizer_vars = self._get_optimizer_vars()
|
|
66
|
+
self._rollback_loss = float('inf')
|
|
55
67
|
|
|
56
68
|
def on_epoch_end(self, epoch: int, logs: dict = None):
|
|
57
69
|
current_loss = logs.get('val_loss', logs.get('loss'))
|
|
58
|
-
deviation = (current_loss - self.
|
|
70
|
+
deviation = (current_loss - self._rollback_loss) / self._rollback_loss
|
|
59
71
|
|
|
60
72
|
if np.isnan(current_loss) or np.isinf(current_loss):
|
|
61
73
|
self._rollback()
|
|
62
|
-
|
|
63
|
-
return
|
|
64
|
-
|
|
74
|
+
# Rolling back model because of nan or inf loss
|
|
75
|
+
return
|
|
76
|
+
|
|
65
77
|
if deviation > self.tolerance:
|
|
66
78
|
self._rollback()
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
print(f"\nRolling back model, {epoch} % {self.frequency} == 0\n")
|
|
72
|
-
return
|
|
73
|
-
|
|
74
|
-
if current_loss < self.rollback_loss:
|
|
79
|
+
# Rolling back model because of large loss deviation.
|
|
80
|
+
return
|
|
81
|
+
|
|
82
|
+
if current_loss < self._rollback_loss:
|
|
75
83
|
self._save_state(current_loss)
|
|
76
84
|
|
|
77
85
|
def _save_state(self, current_loss: float) -> None:
|
|
78
|
-
self.
|
|
79
|
-
self.
|
|
86
|
+
self._rollback_loss = current_loss
|
|
87
|
+
self._rollback_weights = self._get_model_vars()
|
|
80
88
|
if self.rollback_optimizer:
|
|
81
|
-
self.
|
|
89
|
+
self._rollback_optimizer_vars = self._get_optimizer_vars()
|
|
82
90
|
|
|
83
91
|
def _rollback(self) -> None:
|
|
84
|
-
self.model.set_weights(self.
|
|
92
|
+
self.model.set_weights(self._rollback_weights)
|
|
85
93
|
if self.rollback_optimizer:
|
|
86
|
-
self.model.optimizer.set_weights(self.
|
|
94
|
+
self.model.optimizer.set_weights(self._rollback_optimizer_vars)
|
|
87
95
|
|
|
88
96
|
def _get_optimizer_vars(self):
|
|
89
97
|
return [v.numpy() for v in self.model.optimizer.variables]
|
|
90
|
-
|
|
98
|
+
|
|
91
99
|
def _get_model_vars(self):
|
|
92
100
|
return self.model.get_weights()
|
|
93
|
-
|
molcraft/chem.py
CHANGED
|
@@ -102,18 +102,20 @@ class Mol(Chem.Mol):
|
|
|
102
102
|
|
|
103
103
|
def get_conformer(self, index: int = 0) -> 'Conformer':
|
|
104
104
|
if self.num_conformers == 0:
|
|
105
|
-
warn(
|
|
105
|
+
warnings.warn(
|
|
106
106
|
'Molecule has no conformer. To embed conformer(s), invoke the `embed` method, '
|
|
107
|
-
'and optionally followed by `minimize()` to perform force field minimization.'
|
|
107
|
+
'and optionally followed by `minimize()` to perform force field minimization.',
|
|
108
|
+
stacklevel=2
|
|
108
109
|
)
|
|
109
110
|
return None
|
|
110
111
|
return Conformer.cast(self.GetConformer(index))
|
|
111
112
|
|
|
112
113
|
def get_conformers(self) -> list['Conformer']:
|
|
113
114
|
if self.num_conformers == 0:
|
|
114
|
-
warn(
|
|
115
|
+
warnings.warn(
|
|
115
116
|
'Molecule has no conformers. To embed conformers, invoke the `embed` method, '
|
|
116
|
-
'and optionally followed by `minimize()` to perform force field minimization.'
|
|
117
|
+
'and optionally followed by `minimize()` to perform force field minimization.',
|
|
118
|
+
stacklevel=2
|
|
117
119
|
)
|
|
118
120
|
return []
|
|
119
121
|
return [Conformer.cast(x) for x in self.GetConformers()]
|
|
@@ -425,9 +427,10 @@ def embed_conformers(
|
|
|
425
427
|
mol, numConfs=num_conformers, params=embedding_method
|
|
426
428
|
)
|
|
427
429
|
if not len(success):
|
|
428
|
-
warn(
|
|
430
|
+
warnings.warn(
|
|
429
431
|
f'Could not embed conformer(s) for {mol.canonical_smiles!r} using the '
|
|
430
|
-
'speified method. Giving it another try with more permissive methods.'
|
|
432
|
+
'speified method. Giving it another try with more permissive methods.',
|
|
433
|
+
stacklevel=2
|
|
431
434
|
)
|
|
432
435
|
max_attempts = (20 * mol.num_atoms) # increasing it from 10xN to 20xN
|
|
433
436
|
for fallback_method in [method, 'ETDG', 'KDG']:
|
|
@@ -483,9 +486,10 @@ def optimize_conformers(
|
|
|
483
486
|
ignore_interfragment_interactions=ignore_interfragment_interactions,
|
|
484
487
|
)
|
|
485
488
|
except RuntimeError as e:
|
|
486
|
-
warn(
|
|
489
|
+
warnings.warn(
|
|
487
490
|
f'{method} force field minimization raised {e}. '
|
|
488
|
-
'\nProceeding without force field minimization.'
|
|
491
|
+
'\nProceeding without force field minimization.',
|
|
492
|
+
stacklevel=2
|
|
489
493
|
)
|
|
490
494
|
return mol
|
|
491
495
|
|
|
@@ -496,9 +500,10 @@ def prune_conformers(
|
|
|
496
500
|
energy_force_field: str = 'UFF',
|
|
497
501
|
):
|
|
498
502
|
if mol.num_conformers == 0:
|
|
499
|
-
warn(
|
|
503
|
+
warnings.warn(
|
|
500
504
|
'Molecule has no conformers. To embed conformers, invoke the `embed` method, '
|
|
501
|
-
'and optionally followed by `minimize()` to perform force field minimization.'
|
|
505
|
+
'and optionally followed by `minimize()` to perform force field minimization.',
|
|
506
|
+
stacklevel=2
|
|
502
507
|
)
|
|
503
508
|
return mol
|
|
504
509
|
|
|
@@ -674,9 +679,3 @@ def _atom_pair_fingerprint(
|
|
|
674
679
|
fp_param = {'fpSize': size}
|
|
675
680
|
return _get_fingerprint(mol, 'atom_pair', binary, dtype, **fp_param)
|
|
676
681
|
|
|
677
|
-
def warn(message: str) -> None:
|
|
678
|
-
warnings.warn(
|
|
679
|
-
message=message,
|
|
680
|
-
category=UserWarning,
|
|
681
|
-
stacklevel=1,
|
|
682
|
-
)
|
molcraft/features.py
CHANGED
|
@@ -110,9 +110,10 @@ class Feature(abc.ABC):
|
|
|
110
110
|
'type `float`, `int`, `bool` or `None`.'
|
|
111
111
|
)
|
|
112
112
|
if not math.isfinite(value):
|
|
113
|
-
warn(
|
|
113
|
+
warnings.warn(
|
|
114
114
|
f'Found value of {self.name} to be non-finite. '
|
|
115
|
-
f'Value received: {value}. Converting it to a value of 0.'
|
|
115
|
+
f'Value received: {value}. Converting it to a value of 0.',
|
|
116
|
+
stacklevel=2
|
|
116
117
|
)
|
|
117
118
|
value = 0.0
|
|
118
119
|
return np.asarray([value], dtype=self.dtype)
|
|
@@ -380,10 +381,3 @@ default_vocabulary = {
|
|
|
380
381
|
],
|
|
381
382
|
}
|
|
382
383
|
|
|
383
|
-
|
|
384
|
-
def warn(message: str) -> None:
|
|
385
|
-
warnings.warn(
|
|
386
|
-
message=message,
|
|
387
|
-
category=UserWarning,
|
|
388
|
-
stacklevel=1
|
|
389
|
-
)
|
molcraft/featurizers.py
CHANGED
|
@@ -180,6 +180,12 @@ class MolGraphFeaturizer(Featurizer):
|
|
|
180
180
|
bond_features = [
|
|
181
181
|
features.BondType(vocab)
|
|
182
182
|
]
|
|
183
|
+
if not default_bond_features and self.radius > 1:
|
|
184
|
+
warnings.warn(
|
|
185
|
+
'Replacing user-specified bond features with default bond features, '
|
|
186
|
+
'as `radius`>1. When `radius`>1, only bond types are considered.',
|
|
187
|
+
stacklevel=2
|
|
188
|
+
)
|
|
183
189
|
default_molecule_features = (
|
|
184
190
|
molecule_features == 'auto' or molecule_features == 'default'
|
|
185
191
|
)
|
|
@@ -213,9 +219,10 @@ class MolGraphFeaturizer(Featurizer):
|
|
|
213
219
|
mol = chem.Mol.from_encoding(x, explicit_hs=self.include_hs)
|
|
214
220
|
|
|
215
221
|
if mol is None:
|
|
216
|
-
warn(
|
|
222
|
+
warnings.warn(
|
|
217
223
|
f'Could not obtain `chem.Mol` from {x}. '
|
|
218
|
-
'Returning `None` (proceeding without it).'
|
|
224
|
+
'Returning `None` (proceeding without it).',
|
|
225
|
+
stacklevel=2
|
|
219
226
|
)
|
|
220
227
|
return None
|
|
221
228
|
|
|
@@ -245,10 +252,11 @@ class MolGraphFeaturizer(Featurizer):
|
|
|
245
252
|
|
|
246
253
|
if molecule_feature is not None:
|
|
247
254
|
if 'feature' in context:
|
|
248
|
-
warn(
|
|
255
|
+
warnings.warn(
|
|
249
256
|
'Found both inputted and computed context feature. '
|
|
250
257
|
'Overwriting inputted context feature with computed '
|
|
251
|
-
'context feature (based on `molecule_features`).'
|
|
258
|
+
'context feature (based on `molecule_features`).',
|
|
259
|
+
stacklevel=2
|
|
252
260
|
)
|
|
253
261
|
context['feature'] = molecule_feature
|
|
254
262
|
|
|
@@ -272,8 +280,6 @@ class MolGraphFeaturizer(Featurizer):
|
|
|
272
280
|
mol.get_bond_between_atoms(atom_i, atom_j).index
|
|
273
281
|
)
|
|
274
282
|
edge['feature'] = bond_feature[bond_indices]
|
|
275
|
-
if self.self_loops:
|
|
276
|
-
edge['self_loop'] = (edge['source'] == edge['target'])
|
|
277
283
|
else:
|
|
278
284
|
paths = chem.get_shortest_paths(
|
|
279
285
|
mol, radius=self.radius, self_loops=self.self_loops
|
|
@@ -284,9 +290,6 @@ class MolGraphFeaturizer(Featurizer):
|
|
|
284
290
|
edge['target'] = np.asarray(
|
|
285
291
|
[path[-1] for path in paths], dtype=self.index_dtype
|
|
286
292
|
)
|
|
287
|
-
edge['length'] = np.asarray(
|
|
288
|
-
[len(path) - 1 for path in paths], dtype=self.index_dtype
|
|
289
|
-
)
|
|
290
293
|
if bond_feature is not None:
|
|
291
294
|
zero_bond_feature = np.array(
|
|
292
295
|
[[1., 0., 0., 0., 0.]], dtype=bond_feature.dtype
|
|
@@ -297,7 +300,6 @@ class MolGraphFeaturizer(Featurizer):
|
|
|
297
300
|
edge['feature'] = self._expand_bond_features(
|
|
298
301
|
mol, paths, bond_feature,
|
|
299
302
|
)
|
|
300
|
-
edge['length'] = np.eye(self.radius + 1, dtype=self.feature_dtype)[edge['length']]
|
|
301
303
|
|
|
302
304
|
if self.super_atom:
|
|
303
305
|
node, edge = self._add_super_atom(node, edge)
|
|
@@ -372,7 +374,7 @@ class MolGraphFeaturizer(Featurizer):
|
|
|
372
374
|
num_nodes = node['feature'].shape[0]
|
|
373
375
|
node = _add_super_nodes(node, num_super_nodes)
|
|
374
376
|
edge = _add_super_edges(
|
|
375
|
-
edge, num_nodes, num_super_nodes, self.feature_dtype, self.index_dtype
|
|
377
|
+
edge, num_nodes, num_super_nodes, self.feature_dtype, self.index_dtype, self.self_loops
|
|
376
378
|
)
|
|
377
379
|
return node, edge
|
|
378
380
|
|
|
@@ -533,9 +535,10 @@ class MolGraphFeaturizer3D(MolGraphFeaturizer):
|
|
|
533
535
|
mol = chem.Mol.from_encoding(x, explicit_hs=explicit_hs)
|
|
534
536
|
|
|
535
537
|
if mol is None:
|
|
536
|
-
warn(
|
|
538
|
+
warnings.warn(
|
|
537
539
|
f'Could not obtain `chem.Mol` from {x}. '
|
|
538
|
-
'Proceeding without it.'
|
|
540
|
+
'Proceeding without it.',
|
|
541
|
+
stacklevel=2
|
|
539
542
|
)
|
|
540
543
|
return None
|
|
541
544
|
|
|
@@ -575,10 +578,11 @@ class MolGraphFeaturizer3D(MolGraphFeaturizer):
|
|
|
575
578
|
|
|
576
579
|
if molecule_feature is not None:
|
|
577
580
|
if 'feature' in context:
|
|
578
|
-
warn(
|
|
581
|
+
warnings.warn(
|
|
579
582
|
'Found both inputted and computed context feature. '
|
|
580
583
|
'Overwriting inputted context feature with computed '
|
|
581
|
-
'context feature (based on `molecule_features`).'
|
|
584
|
+
'context feature (based on `molecule_features`).',
|
|
585
|
+
stacklevel=2
|
|
582
586
|
)
|
|
583
587
|
context['feature'] = molecule_feature
|
|
584
588
|
|
|
@@ -702,11 +706,15 @@ def _add_super_edges(
|
|
|
702
706
|
num_super_nodes: int,
|
|
703
707
|
feature_dtype: str,
|
|
704
708
|
index_dtype: str,
|
|
709
|
+
self_loops: bool,
|
|
705
710
|
) -> dict[str, np.ndarray]:
|
|
706
711
|
edge = copy.deepcopy(edge)
|
|
707
|
-
|
|
708
|
-
|
|
709
|
-
|
|
712
|
+
|
|
713
|
+
super_node_indices = np.arange(num_super_nodes) + num_nodes
|
|
714
|
+
if self_loops:
|
|
715
|
+
edge['source'] = np.concatenate([edge['source'], super_node_indices])
|
|
716
|
+
edge['target'] = np.concatenate([edge['target'], super_node_indices])
|
|
717
|
+
super_node_indices = np.repeat(super_node_indices, [num_nodes])
|
|
710
718
|
node_indices = (
|
|
711
719
|
np.tile(np.arange(num_nodes), [num_super_nodes])
|
|
712
720
|
)
|
|
@@ -721,6 +729,8 @@ def _add_super_edges(
|
|
|
721
729
|
if 'feature' in edge:
|
|
722
730
|
num_edges = int(edge['feature'].shape[0])
|
|
723
731
|
num_super_edges = int(num_super_nodes * num_nodes * 2)
|
|
732
|
+
if self_loops:
|
|
733
|
+
num_super_edges += num_super_nodes
|
|
724
734
|
edge['super'] = np.asarray(
|
|
725
735
|
([False] * num_edges + [True] * num_super_edges),
|
|
726
736
|
dtype=bool
|
|
@@ -735,28 +745,8 @@ def _add_super_edges(
|
|
|
735
745
|
]
|
|
736
746
|
)
|
|
737
747
|
|
|
738
|
-
if 'self_loop' in edge:
|
|
739
|
-
edge['self_loop'] = np.pad(
|
|
740
|
-
edge['self_loop'], [(0, num_nodes * num_super_nodes * 2)],
|
|
741
|
-
constant_values=False,
|
|
742
|
-
)
|
|
743
|
-
if 'length' in edge:
|
|
744
|
-
edge['length'] = np.pad(edge['length'], [(0, 0), (1, 0)])
|
|
745
|
-
zero_array = np.zeros([num_nodes * num_super_nodes * 2], dtype='int32')
|
|
746
|
-
edge_length_dim = edge['length'].shape[1]
|
|
747
|
-
virtual_edge_length = np.eye(edge_length_dim)[zero_array]
|
|
748
|
-
edge['length'] = np.concatenate([edge['length'], virtual_edge_length])
|
|
749
|
-
edge['length'] = edge['length'].astype(feature_dtype)
|
|
750
|
-
|
|
751
748
|
return edge
|
|
752
749
|
|
|
753
|
-
|
|
754
|
-
def warn(message: str) -> None:
|
|
755
|
-
warnings.warn(
|
|
756
|
-
message=message,
|
|
757
|
-
category=UserWarning,
|
|
758
|
-
stacklevel=1
|
|
759
|
-
)
|
|
760
750
|
|
|
761
751
|
MolFeaturizer = MolGraphFeaturizer
|
|
762
752
|
MolFeaturizer3D = MolGraphFeaturizer3D
|