molcraft 0.1.0a7__py3-none-any.whl → 0.1.0a9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of molcraft might be problematic. Refer to the registry's notice for this release for more details.

molcraft/__init__.py CHANGED
@@ -1,4 +1,4 @@
1
- __version__ = '0.1.0a7'
1
+ __version__ = '0.1.0a9'
2
2
 
3
3
  import os
4
4
  os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
molcraft/callbacks.py CHANGED
@@ -36,58 +36,65 @@ class LearningRateDecay(keras.callbacks.LearningRateScheduler):
36
36
 
37
37
 
38
38
  class Rollback(keras.callbacks.Callback):
39
+ """Rollback callback.
40
+
41
+ Currently, this callback simply restores the model and (optionally) the optimizer
42
+ variables if current loss deviates too much from the best observed loss.
43
+
44
+ This callback might be useful in situations where the loss tend to spike and put
45
+ the model in an undesired/problematic high-loss parameter space.
46
+
47
+ Args:
48
+ tolerance (float):
49
+ The threshold for when the restoration is triggered. The devaiation is
50
+ calculated as follows: (current_loss - best_loss) / best_loss.
51
+ """
39
52
 
40
53
  def __init__(
41
- self,
42
- frequency: int = None,
43
- tolerance: float = 0.5,
54
+ self,
55
+ tolerance: float = 0.5,
44
56
  rollback_optimizer: bool = True,
45
57
  ):
46
58
  super().__init__()
47
- self.frequency = frequency or 1_000_000_000
48
59
  self.tolerance = tolerance
49
60
  self.rollback_optimizer = rollback_optimizer
50
61
 
51
62
  def on_train_begin(self, logs=None):
52
- self.rollback_weights = self._get_model_vars()
53
- self.rollback_optimizer_vars = self._get_optimizer_vars()
54
- self.rollback_loss = float('inf')
63
+ self._rollback_weights = self._get_model_vars()
64
+ if self.rollback_optimizer:
65
+ self._rollback_optimizer_vars = self._get_optimizer_vars()
66
+ self._rollback_loss = float('inf')
55
67
 
56
68
  def on_epoch_end(self, epoch: int, logs: dict = None):
57
69
  current_loss = logs.get('val_loss', logs.get('loss'))
58
- deviation = (current_loss - self.rollback_loss) / self.rollback_loss
70
+ deviation = (current_loss - self._rollback_loss) / self._rollback_loss
59
71
 
60
72
  if np.isnan(current_loss) or np.isinf(current_loss):
61
73
  self._rollback()
62
- print("\nRolling back model, found nan or inf loss.\n")
63
- return
64
-
74
+ # Rolling back model because of nan or inf loss
75
+ return
76
+
65
77
  if deviation > self.tolerance:
66
78
  self._rollback()
67
- print(f"\nRolling back model, found too large deviation: {deviation:.3f}\n")
68
-
69
- if epoch and epoch % self.frequency == 0:
70
- self._rollback()
71
- print(f"\nRolling back model, {epoch} % {self.frequency} == 0\n")
72
- return
73
-
74
- if current_loss < self.rollback_loss:
79
+ # Rolling back model because of large loss deviation.
80
+ return
81
+
82
+ if current_loss < self._rollback_loss:
75
83
  self._save_state(current_loss)
76
84
 
77
85
  def _save_state(self, current_loss: float) -> None:
78
- self.rollback_loss = current_loss
79
- self.rollback_weights = self._get_model_vars()
86
+ self._rollback_loss = current_loss
87
+ self._rollback_weights = self._get_model_vars()
80
88
  if self.rollback_optimizer:
81
- self.rollback_optimizer_vars = self._get_optimizer_vars()
89
+ self._rollback_optimizer_vars = self._get_optimizer_vars()
82
90
 
83
91
  def _rollback(self) -> None:
84
- self.model.set_weights(self.rollback_weights)
92
+ self.model.set_weights(self._rollback_weights)
85
93
  if self.rollback_optimizer:
86
- self.model.optimizer.set_weights(self.rollback_optimizer_vars)
94
+ self.model.optimizer.set_weights(self._rollback_optimizer_vars)
87
95
 
88
96
  def _get_optimizer_vars(self):
89
97
  return [v.numpy() for v in self.model.optimizer.variables]
90
-
98
+
91
99
  def _get_model_vars(self):
92
100
  return self.model.get_weights()
93
-
molcraft/chem.py CHANGED
@@ -102,18 +102,20 @@ class Mol(Chem.Mol):
102
102
 
103
103
  def get_conformer(self, index: int = 0) -> 'Conformer':
104
104
  if self.num_conformers == 0:
105
- warn(
105
+ warnings.warn(
106
106
  'Molecule has no conformer. To embed conformer(s), invoke the `embed` method, '
107
- 'and optionally followed by `minimize()` to perform force field minimization.'
107
+ 'and optionally followed by `minimize()` to perform force field minimization.',
108
+ stacklevel=2
108
109
  )
109
110
  return None
110
111
  return Conformer.cast(self.GetConformer(index))
111
112
 
112
113
  def get_conformers(self) -> list['Conformer']:
113
114
  if self.num_conformers == 0:
114
- warn(
115
+ warnings.warn(
115
116
  'Molecule has no conformers. To embed conformers, invoke the `embed` method, '
116
- 'and optionally followed by `minimize()` to perform force field minimization.'
117
+ 'and optionally followed by `minimize()` to perform force field minimization.',
118
+ stacklevel=2
117
119
  )
118
120
  return []
119
121
  return [Conformer.cast(x) for x in self.GetConformers()]
@@ -425,9 +427,10 @@ def embed_conformers(
425
427
  mol, numConfs=num_conformers, params=embedding_method
426
428
  )
427
429
  if not len(success):
428
- warn(
430
+ warnings.warn(
429
431
  f'Could not embed conformer(s) for {mol.canonical_smiles!r} using the '
430
- 'speified method. Giving it another try with more permissive methods.'
432
+ 'speified method. Giving it another try with more permissive methods.',
433
+ stacklevel=2
431
434
  )
432
435
  max_attempts = (20 * mol.num_atoms) # increasing it from 10xN to 20xN
433
436
  for fallback_method in [method, 'ETDG', 'KDG']:
@@ -483,9 +486,10 @@ def optimize_conformers(
483
486
  ignore_interfragment_interactions=ignore_interfragment_interactions,
484
487
  )
485
488
  except RuntimeError as e:
486
- warn(
489
+ warnings.warn(
487
490
  f'{method} force field minimization raised {e}. '
488
- '\nProceeding without force field minimization.'
491
+ '\nProceeding without force field minimization.',
492
+ stacklevel=2
489
493
  )
490
494
  return mol
491
495
 
@@ -496,9 +500,10 @@ def prune_conformers(
496
500
  energy_force_field: str = 'UFF',
497
501
  ):
498
502
  if mol.num_conformers == 0:
499
- warn(
503
+ warnings.warn(
500
504
  'Molecule has no conformers. To embed conformers, invoke the `embed` method, '
501
- 'and optionally followed by `minimize()` to perform force field minimization.'
505
+ 'and optionally followed by `minimize()` to perform force field minimization.',
506
+ stacklevel=2
502
507
  )
503
508
  return mol
504
509
 
@@ -674,9 +679,3 @@ def _atom_pair_fingerprint(
674
679
  fp_param = {'fpSize': size}
675
680
  return _get_fingerprint(mol, 'atom_pair', binary, dtype, **fp_param)
676
681
 
677
- def warn(message: str) -> None:
678
- warnings.warn(
679
- message=message,
680
- category=UserWarning,
681
- stacklevel=1,
682
- )
molcraft/features.py CHANGED
@@ -110,9 +110,10 @@ class Feature(abc.ABC):
110
110
  'type `float`, `int`, `bool` or `None`.'
111
111
  )
112
112
  if not math.isfinite(value):
113
- warn(
113
+ warnings.warn(
114
114
  f'Found value of {self.name} to be non-finite. '
115
- f'Value received: {value}. Converting it to a value of 0.'
115
+ f'Value received: {value}. Converting it to a value of 0.',
116
+ stacklevel=2
116
117
  )
117
118
  value = 0.0
118
119
  return np.asarray([value], dtype=self.dtype)
@@ -380,10 +381,3 @@ default_vocabulary = {
380
381
  ],
381
382
  }
382
383
 
383
-
384
- def warn(message: str) -> None:
385
- warnings.warn(
386
- message=message,
387
- category=UserWarning,
388
- stacklevel=1
389
- )
molcraft/featurizers.py CHANGED
@@ -180,6 +180,12 @@ class MolGraphFeaturizer(Featurizer):
180
180
  bond_features = [
181
181
  features.BondType(vocab)
182
182
  ]
183
+ if not default_bond_features and self.radius > 1:
184
+ warnings.warn(
185
+ 'Replacing user-specified bond features with default bond features, '
186
+ 'as `radius`>1. When `radius`>1, only bond types are considered.',
187
+ stacklevel=2
188
+ )
183
189
  default_molecule_features = (
184
190
  molecule_features == 'auto' or molecule_features == 'default'
185
191
  )
@@ -213,9 +219,10 @@ class MolGraphFeaturizer(Featurizer):
213
219
  mol = chem.Mol.from_encoding(x, explicit_hs=self.include_hs)
214
220
 
215
221
  if mol is None:
216
- warn(
222
+ warnings.warn(
217
223
  f'Could not obtain `chem.Mol` from {x}. '
218
- 'Returning `None` (proceeding without it).'
224
+ 'Returning `None` (proceeding without it).',
225
+ stacklevel=2
219
226
  )
220
227
  return None
221
228
 
@@ -245,10 +252,11 @@ class MolGraphFeaturizer(Featurizer):
245
252
 
246
253
  if molecule_feature is not None:
247
254
  if 'feature' in context:
248
- warn(
255
+ warnings.warn(
249
256
  'Found both inputted and computed context feature. '
250
257
  'Overwriting inputted context feature with computed '
251
- 'context feature (based on `molecule_features`).'
258
+ 'context feature (based on `molecule_features`).',
259
+ stacklevel=2
252
260
  )
253
261
  context['feature'] = molecule_feature
254
262
 
@@ -272,8 +280,6 @@ class MolGraphFeaturizer(Featurizer):
272
280
  mol.get_bond_between_atoms(atom_i, atom_j).index
273
281
  )
274
282
  edge['feature'] = bond_feature[bond_indices]
275
- if self.self_loops:
276
- edge['self_loop'] = (edge['source'] == edge['target'])
277
283
  else:
278
284
  paths = chem.get_shortest_paths(
279
285
  mol, radius=self.radius, self_loops=self.self_loops
@@ -284,9 +290,6 @@ class MolGraphFeaturizer(Featurizer):
284
290
  edge['target'] = np.asarray(
285
291
  [path[-1] for path in paths], dtype=self.index_dtype
286
292
  )
287
- edge['length'] = np.asarray(
288
- [len(path) - 1 for path in paths], dtype=self.index_dtype
289
- )
290
293
  if bond_feature is not None:
291
294
  zero_bond_feature = np.array(
292
295
  [[1., 0., 0., 0., 0.]], dtype=bond_feature.dtype
@@ -297,7 +300,6 @@ class MolGraphFeaturizer(Featurizer):
297
300
  edge['feature'] = self._expand_bond_features(
298
301
  mol, paths, bond_feature,
299
302
  )
300
- edge['length'] = np.eye(self.radius + 1, dtype=self.feature_dtype)[edge['length']]
301
303
 
302
304
  if self.super_atom:
303
305
  node, edge = self._add_super_atom(node, edge)
@@ -372,7 +374,7 @@ class MolGraphFeaturizer(Featurizer):
372
374
  num_nodes = node['feature'].shape[0]
373
375
  node = _add_super_nodes(node, num_super_nodes)
374
376
  edge = _add_super_edges(
375
- edge, num_nodes, num_super_nodes, self.feature_dtype, self.index_dtype
377
+ edge, num_nodes, num_super_nodes, self.feature_dtype, self.index_dtype, self.self_loops
376
378
  )
377
379
  return node, edge
378
380
 
@@ -533,9 +535,10 @@ class MolGraphFeaturizer3D(MolGraphFeaturizer):
533
535
  mol = chem.Mol.from_encoding(x, explicit_hs=explicit_hs)
534
536
 
535
537
  if mol is None:
536
- warn(
538
+ warnings.warn(
537
539
  f'Could not obtain `chem.Mol` from {x}. '
538
- 'Proceeding without it.'
540
+ 'Proceeding without it.',
541
+ stacklevel=2
539
542
  )
540
543
  return None
541
544
 
@@ -575,10 +578,11 @@ class MolGraphFeaturizer3D(MolGraphFeaturizer):
575
578
 
576
579
  if molecule_feature is not None:
577
580
  if 'feature' in context:
578
- warn(
581
+ warnings.warn(
579
582
  'Found both inputted and computed context feature. '
580
583
  'Overwriting inputted context feature with computed '
581
- 'context feature (based on `molecule_features`).'
584
+ 'context feature (based on `molecule_features`).',
585
+ stacklevel=2
582
586
  )
583
587
  context['feature'] = molecule_feature
584
588
 
@@ -702,11 +706,15 @@ def _add_super_edges(
702
706
  num_super_nodes: int,
703
707
  feature_dtype: str,
704
708
  index_dtype: str,
709
+ self_loops: bool,
705
710
  ) -> dict[str, np.ndarray]:
706
711
  edge = copy.deepcopy(edge)
707
- super_node_indices = (
708
- np.repeat(np.arange(num_super_nodes), [num_nodes]) + num_nodes
709
- )
712
+
713
+ super_node_indices = np.arange(num_super_nodes) + num_nodes
714
+ if self_loops:
715
+ edge['source'] = np.concatenate([edge['source'], super_node_indices])
716
+ edge['target'] = np.concatenate([edge['target'], super_node_indices])
717
+ super_node_indices = np.repeat(super_node_indices, [num_nodes])
710
718
  node_indices = (
711
719
  np.tile(np.arange(num_nodes), [num_super_nodes])
712
720
  )
@@ -721,6 +729,8 @@ def _add_super_edges(
721
729
  if 'feature' in edge:
722
730
  num_edges = int(edge['feature'].shape[0])
723
731
  num_super_edges = int(num_super_nodes * num_nodes * 2)
732
+ if self_loops:
733
+ num_super_edges += num_super_nodes
724
734
  edge['super'] = np.asarray(
725
735
  ([False] * num_edges + [True] * num_super_edges),
726
736
  dtype=bool
@@ -735,28 +745,8 @@ def _add_super_edges(
735
745
  ]
736
746
  )
737
747
 
738
- if 'self_loop' in edge:
739
- edge['self_loop'] = np.pad(
740
- edge['self_loop'], [(0, num_nodes * num_super_nodes * 2)],
741
- constant_values=False,
742
- )
743
- if 'length' in edge:
744
- edge['length'] = np.pad(edge['length'], [(0, 0), (1, 0)])
745
- zero_array = np.zeros([num_nodes * num_super_nodes * 2], dtype='int32')
746
- edge_length_dim = edge['length'].shape[1]
747
- virtual_edge_length = np.eye(edge_length_dim)[zero_array]
748
- edge['length'] = np.concatenate([edge['length'], virtual_edge_length])
749
- edge['length'] = edge['length'].astype(feature_dtype)
750
-
751
748
  return edge
752
749
 
753
-
754
- def warn(message: str) -> None:
755
- warnings.warn(
756
- message=message,
757
- category=UserWarning,
758
- stacklevel=1
759
- )
760
750
 
761
751
  MolFeaturizer = MolGraphFeaturizer
762
752
  MolFeaturizer3D = MolGraphFeaturizer3D