molcraft 0.1.0a6__tar.gz → 0.1.0a8__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of molcraft might be problematic. See the registry's release page for more details.

Files changed (31)
  1. {molcraft-0.1.0a6 → molcraft-0.1.0a8}/PKG-INFO +2 -2
  2. {molcraft-0.1.0a6 → molcraft-0.1.0a8}/README.md +1 -1
  3. {molcraft-0.1.0a6 → molcraft-0.1.0a8}/molcraft/__init__.py +1 -1
  4. molcraft-0.1.0a8/molcraft/callbacks.py +100 -0
  5. {molcraft-0.1.0a6 → molcraft-0.1.0a8}/molcraft/chem.py +45 -30
  6. {molcraft-0.1.0a6 → molcraft-0.1.0a8}/molcraft/conformers.py +0 -4
  7. {molcraft-0.1.0a6 → molcraft-0.1.0a8}/molcraft/features.py +3 -9
  8. {molcraft-0.1.0a6 → molcraft-0.1.0a8}/molcraft/featurizers.py +18 -26
  9. {molcraft-0.1.0a6 → molcraft-0.1.0a8}/molcraft/layers.py +466 -801
  10. {molcraft-0.1.0a6 → molcraft-0.1.0a8}/molcraft/models.py +16 -1
  11. {molcraft-0.1.0a6 → molcraft-0.1.0a8}/molcraft/ops.py +14 -3
  12. {molcraft-0.1.0a6 → molcraft-0.1.0a8}/molcraft.egg-info/PKG-INFO +2 -2
  13. {molcraft-0.1.0a6 → molcraft-0.1.0a8}/tests/test_layers.py +9 -3
  14. {molcraft-0.1.0a6 → molcraft-0.1.0a8}/tests/test_models.py +60 -3
  15. molcraft-0.1.0a6/molcraft/callbacks.py +0 -33
  16. {molcraft-0.1.0a6 → molcraft-0.1.0a8}/LICENSE +0 -0
  17. {molcraft-0.1.0a6 → molcraft-0.1.0a8}/molcraft/datasets.py +0 -0
  18. {molcraft-0.1.0a6 → molcraft-0.1.0a8}/molcraft/descriptors.py +0 -0
  19. {molcraft-0.1.0a6 → molcraft-0.1.0a8}/molcraft/losses.py +0 -0
  20. {molcraft-0.1.0a6 → molcraft-0.1.0a8}/molcraft/records.py +0 -0
  21. {molcraft-0.1.0a6 → molcraft-0.1.0a8}/molcraft/tensors.py +0 -0
  22. {molcraft-0.1.0a6 → molcraft-0.1.0a8}/molcraft.egg-info/SOURCES.txt +0 -0
  23. {molcraft-0.1.0a6 → molcraft-0.1.0a8}/molcraft.egg-info/dependency_links.txt +0 -0
  24. {molcraft-0.1.0a6 → molcraft-0.1.0a8}/molcraft.egg-info/requires.txt +0 -0
  25. {molcraft-0.1.0a6 → molcraft-0.1.0a8}/molcraft.egg-info/top_level.txt +0 -0
  26. {molcraft-0.1.0a6 → molcraft-0.1.0a8}/pyproject.toml +0 -0
  27. {molcraft-0.1.0a6 → molcraft-0.1.0a8}/setup.cfg +0 -0
  28. {molcraft-0.1.0a6 → molcraft-0.1.0a8}/tests/test_chem.py +0 -0
  29. {molcraft-0.1.0a6 → molcraft-0.1.0a8}/tests/test_featurizers.py +0 -0
  30. {molcraft-0.1.0a6 → molcraft-0.1.0a8}/tests/test_losses.py +0 -0
  31. {molcraft-0.1.0a6 → molcraft-0.1.0a8}/tests/test_tensors.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: molcraft
3
- Version: 0.1.0a6
3
+ Version: 0.1.0a8
4
4
  Summary: Graph Neural Networks for Molecular Machine Learning
5
5
  Author-email: Alexander Kensert <alexander.kensert@gmail.com>
6
6
  License: MIT License
@@ -78,7 +78,7 @@ featurizer = featurizers.MolGraphFeaturizer(
78
78
  features.IsRotatable(),
79
79
  ],
80
80
  super_atom=True,
81
- self_loops=False,
81
+ self_loops=True,
82
82
  )
83
83
 
84
84
  graph = featurizer([('N[C@@H](C)C(=O)O', 2.0), ('N[C@@H](CS)C(=O)O', 1.0)])
@@ -34,7 +34,7 @@ featurizer = featurizers.MolGraphFeaturizer(
34
34
  features.IsRotatable(),
35
35
  ],
36
36
  super_atom=True,
37
- self_loops=False,
37
+ self_loops=True,
38
38
  )
39
39
 
40
40
  graph = featurizer([('N[C@@H](C)C(=O)O', 2.0), ('N[C@@H](CS)C(=O)O', 1.0)])
@@ -1,4 +1,4 @@
1
- __version__ = '0.1.0a6'
1
+ __version__ = '0.1.0a8'
2
2
 
3
3
  import os
4
4
  os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
@@ -0,0 +1,100 @@
1
+ import keras
2
+ import warnings
3
+ import numpy as np
4
+
5
+
6
class TensorBoard(keras.callbacks.TensorBoard):
    """TensorBoard callback that tags logged weights with their variable path.

    Overrides ``_log_weights`` so that every weight histogram (and, when
    ``write_images`` is set, every weight image) is named after
    ``weight.path`` rather than ``weight.name``. Using the path keeps the
    weights of different layers apart in the TensorBoard UI.
    """

    def _log_weights(self, epoch):
        with self._train_writer.as_default():
            # Flatten layers -> weights while keeping the original visit order.
            model_weights = (
                w for layer in self.model.layers for w in layer.weights
            )
            for w in model_weights:
                prefix = w.path
                self.summary.histogram(prefix + "/histogram", w, step=epoch)
                if self.write_images:
                    self._log_weight_as_image(w, prefix + "/image", epoch)
            self._train_writer.flush()
24
+
25
+
26
class LearningRateDecay(keras.callbacks.LearningRateScheduler):
    """Exponential learning-rate decay with an optional initial delay.

    For epochs before ``delay`` the current learning rate is returned
    unchanged. From epoch ``delay`` onward, each invocation of the schedule
    multiplies the current learning rate by ``exp(-rate)``.

    Args:
        rate: Decay exponent applied at each scheduled epoch.
        delay: Number of leading epochs during which no decay is applied.
        **kwargs: Forwarded to ``keras.callbacks.LearningRateScheduler``.
    """

    def __init__(self, rate: float, delay: int = 0, **kwargs):

        def _schedule(epoch: int, lr: float):
            if epoch >= delay:
                return float(lr * keras.ops.exp(-rate))
            return float(lr)

        super().__init__(schedule=_schedule, **kwargs)
36
+
37
+
38
class Rollback(keras.callbacks.Callback):
    """Rollback callback.

    Restores the model weights and (optionally) the optimizer variables to the
    best state observed so far. This happens when the monitored loss becomes
    non-finite (nan/inf), or when it deviates too much from the best observed
    loss.

    This callback might be useful in situations where the loss tends to spike
    and put the model in an undesired/problematic high-loss parameter space.

    The monitored loss is ``val_loss`` when it is present in the epoch logs,
    otherwise ``loss``. Epochs whose logs contain neither are ignored.

    Args:
        tolerance (float):
            The threshold for when the restoration is triggered. The deviation
            is calculated as follows: (current_loss - best_loss) / |best_loss|.
            When the best loss is exactly zero, any increase is treated as an
            infinite deviation.
        rollback_optimizer (bool):
            Whether the optimizer variables are saved and restored together
            with the model weights.
    """

    def __init__(
        self,
        tolerance: float = 0.5,
        rollback_optimizer: bool = True,
    ):
        super().__init__()
        self.tolerance = tolerance
        self.rollback_optimizer = rollback_optimizer

    def on_train_begin(self, logs=None):
        # Snapshot the initial state so that a non-finite loss in the very
        # first epoch still has a state to roll back to.
        self._rollback_weights = self._get_model_vars()
        if self.rollback_optimizer:
            self._rollback_optimizer_vars = self._get_optimizer_vars()
        self._rollback_loss = float('inf')

    def on_epoch_end(self, epoch: int, logs: dict = None):
        """Roll back, save a new best state, or do nothing for this epoch."""
        logs = logs or {}
        current_loss = logs.get('val_loss', logs.get('loss'))
        if current_loss is None:
            # No loss was reported for this epoch; nothing to monitor.
            return
        current_loss = float(current_loss)

        # Check non-finite values first, instead of relying on nan/inf
        # propagating through the deviation arithmetic below.
        if not np.isfinite(current_loss):
            # Rolling back model because of nan or inf loss.
            self._rollback()
            return

        # A relative deviation is only meaningful once a finite best loss has
        # been recorded (``_rollback_loss`` starts at inf).
        if np.isfinite(self._rollback_loss):
            deviation = self._relative_deviation(
                current_loss, self._rollback_loss
            )
            if deviation > self.tolerance:
                # Rolling back model because of large loss deviation.
                self._rollback()
                return

        if current_loss < self._rollback_loss:
            self._save_state(current_loss)

    @staticmethod
    def _relative_deviation(current_loss: float, best_loss: float) -> float:
        """Return how far `current_loss` lies above a finite `best_loss`.

        The denominator is ``abs(best_loss)``, so a worse loss always yields a
        positive deviation, negative losses included. A zero best loss would
        otherwise divide by zero. Instead, any increase over it is reported as
        an infinite deviation, and anything else as no deviation.
        """
        if best_loss == 0.0:
            return float('inf') if current_loss > best_loss else 0.0
        return (current_loss - best_loss) / abs(best_loss)

    def _save_state(self, current_loss: float) -> None:
        """Record `current_loss` as the best loss and snapshot the current state."""
        self._rollback_loss = current_loss
        self._rollback_weights = self._get_model_vars()
        if self.rollback_optimizer:
            self._rollback_optimizer_vars = self._get_optimizer_vars()

    def _rollback(self) -> None:
        """Restore the most recently saved model (and optimizer) state."""
        self.model.set_weights(self._rollback_weights)
        if self.rollback_optimizer:
            self.model.optimizer.set_weights(self._rollback_optimizer_vars)

    def _get_optimizer_vars(self):
        """Return a NumPy copy of every optimizer variable."""
        return [v.numpy() for v in self.model.optimizer.variables]

    def _get_model_vars(self):
        """Return the model weights as returned by ``model.get_weights()``."""
        return self.model.get_weights()
@@ -102,18 +102,20 @@ class Mol(Chem.Mol):
102
102
 
103
103
  def get_conformer(self, index: int = 0) -> 'Conformer':
104
104
  if self.num_conformers == 0:
105
- warn(
105
+ warnings.warn(
106
106
  'Molecule has no conformer. To embed conformer(s), invoke the `embed` method, '
107
- 'and optionally followed by `minimize()` to perform force field minimization.'
107
+ 'and optionally followed by `minimize()` to perform force field minimization.',
108
+ stacklevel=2
108
109
  )
109
110
  return None
110
111
  return Conformer.cast(self.GetConformer(index))
111
112
 
112
113
  def get_conformers(self) -> list['Conformer']:
113
114
  if self.num_conformers == 0:
114
- warn(
115
+ warnings.warn(
115
116
  'Molecule has no conformers. To embed conformers, invoke the `embed` method, '
116
- 'and optionally followed by `minimize()` to perform force field minimization.'
117
+ 'and optionally followed by `minimize()` to perform force field minimization.',
118
+ stacklevel=2
117
119
  )
118
120
  return []
119
121
  return [Conformer.cast(x) for x in self.GetConformers()]
@@ -400,7 +402,6 @@ def embed_conformers(
400
402
  mol: Mol,
401
403
  num_conformers: int,
402
404
  method: str = 'ETKDGv3',
403
- force: bool = True,
404
405
  **kwargs
405
406
  ) -> None:
406
407
  available_embedding_methods = {
@@ -411,27 +412,40 @@ def embed_conformers(
411
412
  'srETKDGv3': rdDistGeom.srETKDGv3(),
412
413
  'KDG': rdDistGeom.KDG()
413
414
  }
414
- default_embedding_method = 'ETKDGv3'
415
415
  mol = Mol(mol)
416
- params = available_embedding_methods.get(method)
417
- if params is None:
418
- warn(
419
- f"Could not find `method` {method}. "
420
- f"Automatically setting method to {default_embedding_method}."
416
+ embedding_method = available_embedding_methods.get(method)
417
+ if embedding_method is None:
418
+ raise ValueError(
419
+ f'Could not find `method` {method!r}. Specify either of: '
420
+ '`ETDG`, `ETKDG`, `ETKDGv2`, `ETKDGv3`, `srETKDGv3` or `KDG`.'
421
421
  )
422
- params = available_embedding_methods[default_embedding_method]
422
+
423
423
  for key, value in kwargs.items():
424
- setattr(params, key, value)
424
+ setattr(embedding_method, key, value)
425
425
 
426
- success = rdDistGeom.EmbedMultipleConfs(mol, numConfs=num_conformers, params=params)
426
+ success = rdDistGeom.EmbedMultipleConfs(
427
+ mol, numConfs=num_conformers, params=embedding_method
428
+ )
427
429
  if not len(success):
428
- warning = 'Could not embed conformer(s).'
429
- if not force:
430
- warn(warning)
430
+ warnings.warn(
431
+ f'Could not embed conformer(s) for {mol.canonical_smiles!r} using the '
432
+ 'speified method. Giving it another try with more permissive methods.',
433
+ stacklevel=2
434
+ )
435
+ max_attempts = (20 * mol.num_atoms) # increasing it from 10xN to 20xN
436
+ for fallback_method in [method, 'ETDG', 'KDG']:
437
+ fallback_embedding_method = available_embedding_methods[fallback_method]
438
+ fallback_embedding_method.useRandomCoords = True
439
+ fallback_embedding_method.maxAttempts = max_attempts
440
+ success = rdDistGeom.EmbedMultipleConfs(
441
+ mol, numConfs=num_conformers, params=fallback_embedding_method
442
+ )
443
+ if len(success):
444
+ break
431
445
  else:
432
- solution = ' Embedding a conformer (in 3D space) using (x, y) coordinates.'
433
- warn(warning + solution)
434
- rdDepictor.Compute2DCoords(mol)
446
+ raise RuntimeError(
447
+ f'Could not embed conformer(s) for {mol.canonical_smiles!r}. '
448
+ )
435
449
  return mol
436
450
 
437
451
  def optimize_conformers(
@@ -445,6 +459,11 @@ def optimize_conformers(
445
459
  available_force_field_methods = [
446
460
  'MMFF', 'MMFF94', 'MMFF94s', 'UFF'
447
461
  ]
462
+ if method not in available_force_field_methods:
463
+ raise ValueError(
464
+ f'Could not find `method` {method!r}. Specify either of: '
465
+ '`UFF`, `MMFF`, `MMFF94` or `MMFF94s`.'
466
+ )
448
467
  mol = Mol(mol)
449
468
  try:
450
469
  if method.startswith('MMFF'):
@@ -467,9 +486,10 @@ def optimize_conformers(
467
486
  ignore_interfragment_interactions=ignore_interfragment_interactions,
468
487
  )
469
488
  except RuntimeError as e:
470
- warn(
489
+ warnings.warn(
471
490
  f'{method} force field minimization raised {e}. '
472
- '\nProceeding without force field minimization...'
491
+ '\nProceeding without force field minimization.',
492
+ stacklevel=2
473
493
  )
474
494
  return mol
475
495
 
@@ -480,9 +500,10 @@ def prune_conformers(
480
500
  energy_force_field: str = 'UFF',
481
501
  ):
482
502
  if mol.num_conformers == 0:
483
- warn(
503
+ warnings.warn(
484
504
  'Molecule has no conformers. To embed conformers, invoke the `embed` method, '
485
- 'and optionally followed by `minimize()` to perform force field minimization.'
505
+ 'and optionally followed by `minimize()` to perform force field minimization.',
506
+ stacklevel=2
486
507
  )
487
508
  return mol
488
509
 
@@ -658,9 +679,3 @@ def _atom_pair_fingerprint(
658
679
  fp_param = {'fpSize': size}
659
680
  return _get_fingerprint(mol, 'atom_pair', binary, dtype, **fp_param)
660
681
 
661
- def warn(message: str) -> None:
662
- warnings.warn(
663
- message=message,
664
- category=UserWarning,
665
- stacklevel=1,
666
- )
@@ -24,19 +24,16 @@ class ConformerEmbedder(ConformerProcessor):
24
24
  self,
25
25
  method: str = 'ETKDGv3',
26
26
  num_conformers: int = 5,
27
- force: bool = True,
28
27
  **kwargs,
29
28
  ) -> None:
30
29
  self.method = method
31
30
  self.num_conformers = num_conformers
32
- self.force = force
33
31
  self.kwargs = kwargs
34
32
 
35
33
  def get_config(self) -> dict:
36
34
  config = {
37
35
  'method': self.method,
38
36
  'num_conformers': self.num_conformers,
39
- 'force': self.force,
40
37
  }
41
38
  config.update({
42
39
  k: v for (k, v) in self.kwargs.items()
@@ -48,7 +45,6 @@ class ConformerEmbedder(ConformerProcessor):
48
45
  mol,
49
46
  method=self.method,
50
47
  num_conformers=self.num_conformers,
51
- force=self.force,
52
48
  **self.kwargs,
53
49
  )
54
50
 
@@ -110,9 +110,10 @@ class Feature(abc.ABC):
110
110
  'type `float`, `int`, `bool` or `None`.'
111
111
  )
112
112
  if not math.isfinite(value):
113
- warn(
113
+ warnings.warn(
114
114
  f'Found value of {self.name} to be non-finite. '
115
- f'Value received: {value}. Converting it to a value of 0.'
115
+ f'Value received: {value}. Converting it to a value of 0.',
116
+ stacklevel=2
116
117
  )
117
118
  value = 0.0
118
119
  return np.asarray([value], dtype=self.dtype)
@@ -380,10 +381,3 @@ default_vocabulary = {
380
381
  ],
381
382
  }
382
383
 
383
-
384
- def warn(message: str) -> None:
385
- warnings.warn(
386
- message=message,
387
- category=UserWarning,
388
- stacklevel=1
389
- )
@@ -180,6 +180,12 @@ class MolGraphFeaturizer(Featurizer):
180
180
  bond_features = [
181
181
  features.BondType(vocab)
182
182
  ]
183
+ if not default_bond_features and self.radius > 1:
184
+ warnings.warn(
185
+ 'Replacing user-specified bond features with default bond features, '
186
+ 'as `radius`>1. When `radius`>1, only bond types are considered.',
187
+ stacklevel=2
188
+ )
183
189
  default_molecule_features = (
184
190
  molecule_features == 'auto' or molecule_features == 'default'
185
191
  )
@@ -213,9 +219,10 @@ class MolGraphFeaturizer(Featurizer):
213
219
  mol = chem.Mol.from_encoding(x, explicit_hs=self.include_hs)
214
220
 
215
221
  if mol is None:
216
- warn(
222
+ warnings.warn(
217
223
  f'Could not obtain `chem.Mol` from {x}. '
218
- 'Returning `None` (proceeding without it).'
224
+ 'Returning `None` (proceeding without it).',
225
+ stacklevel=2
219
226
  )
220
227
  return None
221
228
 
@@ -245,10 +252,11 @@ class MolGraphFeaturizer(Featurizer):
245
252
 
246
253
  if molecule_feature is not None:
247
254
  if 'feature' in context:
248
- warn(
255
+ warnings.warn(
249
256
  'Found both inputted and computed context feature. '
250
257
  'Overwriting inputted context feature with computed '
251
- 'context feature (based on `molecule_features`).'
258
+ 'context feature (based on `molecule_features`).',
259
+ stacklevel=2
252
260
  )
253
261
  context['feature'] = molecule_feature
254
262
 
@@ -284,9 +292,6 @@ class MolGraphFeaturizer(Featurizer):
284
292
  edge['target'] = np.asarray(
285
293
  [path[-1] for path in paths], dtype=self.index_dtype
286
294
  )
287
- edge['length'] = np.asarray(
288
- [len(path) - 1 for path in paths], dtype=self.index_dtype
289
- )
290
295
  if bond_feature is not None:
291
296
  zero_bond_feature = np.array(
292
297
  [[1., 0., 0., 0., 0.]], dtype=bond_feature.dtype
@@ -297,7 +302,6 @@ class MolGraphFeaturizer(Featurizer):
297
302
  edge['feature'] = self._expand_bond_features(
298
303
  mol, paths, bond_feature,
299
304
  )
300
- edge['length'] = np.eye(self.radius + 1, dtype=self.feature_dtype)[edge['length']]
301
305
 
302
306
  if self.super_atom:
303
307
  node, edge = self._add_super_atom(node, edge)
@@ -533,9 +537,10 @@ class MolGraphFeaturizer3D(MolGraphFeaturizer):
533
537
  mol = chem.Mol.from_encoding(x, explicit_hs=explicit_hs)
534
538
 
535
539
  if mol is None:
536
- warn(
540
+ warnings.warn(
537
541
  f'Could not obtain `chem.Mol` from {x}. '
538
- 'Proceeding without it.'
542
+ 'Proceeding without it.',
543
+ stacklevel=2
539
544
  )
540
545
  return None
541
546
 
@@ -575,10 +580,11 @@ class MolGraphFeaturizer3D(MolGraphFeaturizer):
575
580
 
576
581
  if molecule_feature is not None:
577
582
  if 'feature' in context:
578
- warn(
583
+ warnings.warn(
579
584
  'Found both inputted and computed context feature. '
580
585
  'Overwriting inputted context feature with computed '
581
- 'context feature (based on `molecule_features`).'
586
+ 'context feature (based on `molecule_features`).',
587
+ stacklevel=2
582
588
  )
583
589
  context['feature'] = molecule_feature
584
590
 
@@ -740,23 +746,9 @@ def _add_super_edges(
740
746
  edge['self_loop'], [(0, num_nodes * num_super_nodes * 2)],
741
747
  constant_values=False,
742
748
  )
743
- if 'length' in edge:
744
- edge['length'] = np.pad(edge['length'], [(0, 0), (1, 0)])
745
- zero_array = np.zeros([num_nodes * num_super_nodes * 2], dtype='int32')
746
- edge_length_dim = edge['length'].shape[1]
747
- virtual_edge_length = np.eye(edge_length_dim)[zero_array]
748
- edge['length'] = np.concatenate([edge['length'], virtual_edge_length])
749
- edge['length'] = edge['length'].astype(feature_dtype)
750
749
 
751
750
  return edge
752
751
 
753
-
754
- def warn(message: str) -> None:
755
- warnings.warn(
756
- message=message,
757
- category=UserWarning,
758
- stacklevel=1
759
- )
760
752
 
761
753
  MolFeaturizer = MolGraphFeaturizer
762
754
  MolFeaturizer3D = MolGraphFeaturizer3D