molcraft-0.1.0a16.tar.gz → molcraft-0.1.0a17.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of molcraft might be problematic. Click here for more details.

Files changed (34)
  1. {molcraft-0.1.0a16 → molcraft-0.1.0a17}/PKG-INFO +13 -12
  2. {molcraft-0.1.0a16 → molcraft-0.1.0a17}/README.md +12 -11
  3. {molcraft-0.1.0a16 → molcraft-0.1.0a17}/molcraft/__init__.py +1 -2
  4. molcraft-0.1.0a17/molcraft/applications/chromatography.py +0 -0
  5. {molcraft-0.1.0a16 → molcraft-0.1.0a17}/molcraft/applications/proteomics.py +47 -92
  6. {molcraft-0.1.0a16 → molcraft-0.1.0a17}/molcraft/chem.py +17 -22
  7. {molcraft-0.1.0a16 → molcraft-0.1.0a17}/molcraft/datasets.py +6 -6
  8. {molcraft-0.1.0a16 → molcraft-0.1.0a17}/molcraft/descriptors.py +14 -0
  9. {molcraft-0.1.0a16 → molcraft-0.1.0a17}/molcraft/features.py +50 -58
  10. molcraft-0.1.0a17/molcraft/featurizers.py +523 -0
  11. {molcraft-0.1.0a16 → molcraft-0.1.0a17}/molcraft/layers.py +1 -1
  12. {molcraft-0.1.0a16 → molcraft-0.1.0a17}/molcraft/models.py +2 -0
  13. {molcraft-0.1.0a16 → molcraft-0.1.0a17}/molcraft/records.py +24 -15
  14. {molcraft-0.1.0a16 → molcraft-0.1.0a17}/molcraft.egg-info/PKG-INFO +13 -12
  15. {molcraft-0.1.0a16 → molcraft-0.1.0a17}/molcraft.egg-info/SOURCES.txt +1 -1
  16. {molcraft-0.1.0a16 → molcraft-0.1.0a17}/tests/test_featurizers.py +10 -17
  17. molcraft-0.1.0a16/molcraft/conformers.py +0 -151
  18. molcraft-0.1.0a16/molcraft/featurizers.py +0 -753
  19. {molcraft-0.1.0a16 → molcraft-0.1.0a17}/LICENSE +0 -0
  20. {molcraft-0.1.0a16 → molcraft-0.1.0a17}/molcraft/applications/__init__.py +0 -0
  21. {molcraft-0.1.0a16 → molcraft-0.1.0a17}/molcraft/callbacks.py +0 -0
  22. {molcraft-0.1.0a16 → molcraft-0.1.0a17}/molcraft/losses.py +0 -0
  23. {molcraft-0.1.0a16 → molcraft-0.1.0a17}/molcraft/ops.py +0 -0
  24. {molcraft-0.1.0a16 → molcraft-0.1.0a17}/molcraft/tensors.py +0 -0
  25. {molcraft-0.1.0a16 → molcraft-0.1.0a17}/molcraft.egg-info/dependency_links.txt +0 -0
  26. {molcraft-0.1.0a16 → molcraft-0.1.0a17}/molcraft.egg-info/requires.txt +0 -0
  27. {molcraft-0.1.0a16 → molcraft-0.1.0a17}/molcraft.egg-info/top_level.txt +0 -0
  28. {molcraft-0.1.0a16 → molcraft-0.1.0a17}/pyproject.toml +0 -0
  29. {molcraft-0.1.0a16 → molcraft-0.1.0a17}/setup.cfg +0 -0
  30. {molcraft-0.1.0a16 → molcraft-0.1.0a17}/tests/test_chem.py +0 -0
  31. {molcraft-0.1.0a16 → molcraft-0.1.0a17}/tests/test_layers.py +0 -0
  32. {molcraft-0.1.0a16 → molcraft-0.1.0a17}/tests/test_losses.py +0 -0
  33. {molcraft-0.1.0a16 → molcraft-0.1.0a17}/tests/test_models.py +0 -0
  34. {molcraft-0.1.0a16 → molcraft-0.1.0a17}/tests/test_tensors.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: molcraft
3
- Version: 0.1.0a16
3
+ Version: 0.1.0a17
4
4
  Summary: Graph Neural Networks for Molecular Machine Learning
5
5
  Author-email: Alexander Kensert <alexander.kensert@gmail.com>
6
6
  License: MIT License
@@ -43,9 +43,9 @@ Provides-Extra: gpu
43
43
  Requires-Dist: tensorflow[and-cuda]>=2.16; extra == "gpu"
44
44
  Dynamic: license-file
45
45
 
46
- <img src="https://github.com/akensert/molcraft/blob/main/docs/_static/molcraft-logo.png" alt="molcraft-logo">
46
+ <img src="https://github.com/akensert/molcraft/blob/main/docs/_static/molcraft-logo.png" alt="molcraft-logo", width="90%">
47
47
 
48
- **Deep Learning on Molecules**: A Minimalistic GNN package for Molecular ML.
48
+ **Deep Learning on Molecules**: A Minimalistic GNN package for Molecular ML.
49
49
 
50
50
  > [!NOTE]
51
51
  > In progress.
@@ -83,11 +83,12 @@ featurizer = featurizers.MolGraphFeaturizer(
83
83
  features.BondType(),
84
84
  features.IsRotatable(),
85
85
  ],
86
- super_atom=True,
86
+ super_node=True,
87
87
  self_loops=True,
88
+ include_hydrogens=False,
88
89
  )
89
90
 
90
- graph = featurizer([('N[C@@H](C)C(=O)O', 2.0), ('N[C@@H](CS)C(=O)O', 1.0)])
91
+ graph = featurizer([('N[C@@H](C)C(=O)O', 2.5), ('N[C@@H](CS)C(=O)O', 1.5)])
91
92
  print(graph)
92
93
 
93
94
  model = models.GraphModel.from_layers(
@@ -95,13 +96,13 @@ model = models.GraphModel.from_layers(
95
96
  layers.Input(graph.spec),
96
97
  layers.NodeEmbedding(dim=128),
97
98
  layers.EdgeEmbedding(dim=128),
98
- layers.GraphTransformer(units=128),
99
- layers.GraphTransformer(units=128),
100
- layers.GraphTransformer(units=128),
101
- layers.GraphTransformer(units=128),
102
- layers.Readout(mode='mean'),
103
- keras.layers.Dense(units=1024, activation='relu'),
104
- keras.layers.Dense(units=1024, activation='relu'),
99
+ layers.GraphConv(units=128),
100
+ layers.GraphConv(units=128),
101
+ layers.GraphConv(units=128),
102
+ layers.GraphConv(units=128),
103
+ layers.Readout(),
104
+ keras.layers.Dense(units=1024, activation='elu'),
105
+ keras.layers.Dense(units=1024, activation='elu'),
105
106
  keras.layers.Dense(1)
106
107
  ]
107
108
  )
@@ -1,6 +1,6 @@
1
- <img src="https://github.com/akensert/molcraft/blob/main/docs/_static/molcraft-logo.png" alt="molcraft-logo">
1
+ <img src="https://github.com/akensert/molcraft/blob/main/docs/_static/molcraft-logo.png" alt="molcraft-logo", width="90%">
2
2
 
3
- **Deep Learning on Molecules**: A Minimalistic GNN package for Molecular ML.
3
+ **Deep Learning on Molecules**: A Minimalistic GNN package for Molecular ML.
4
4
 
5
5
  > [!NOTE]
6
6
  > In progress.
@@ -38,11 +38,12 @@ featurizer = featurizers.MolGraphFeaturizer(
38
38
  features.BondType(),
39
39
  features.IsRotatable(),
40
40
  ],
41
- super_atom=True,
41
+ super_node=True,
42
42
  self_loops=True,
43
+ include_hydrogens=False,
43
44
  )
44
45
 
45
- graph = featurizer([('N[C@@H](C)C(=O)O', 2.0), ('N[C@@H](CS)C(=O)O', 1.0)])
46
+ graph = featurizer([('N[C@@H](C)C(=O)O', 2.5), ('N[C@@H](CS)C(=O)O', 1.5)])
46
47
  print(graph)
47
48
 
48
49
  model = models.GraphModel.from_layers(
@@ -50,13 +51,13 @@ model = models.GraphModel.from_layers(
50
51
  layers.Input(graph.spec),
51
52
  layers.NodeEmbedding(dim=128),
52
53
  layers.EdgeEmbedding(dim=128),
53
- layers.GraphTransformer(units=128),
54
- layers.GraphTransformer(units=128),
55
- layers.GraphTransformer(units=128),
56
- layers.GraphTransformer(units=128),
57
- layers.Readout(mode='mean'),
58
- keras.layers.Dense(units=1024, activation='relu'),
59
- keras.layers.Dense(units=1024, activation='relu'),
54
+ layers.GraphConv(units=128),
55
+ layers.GraphConv(units=128),
56
+ layers.GraphConv(units=128),
57
+ layers.GraphConv(units=128),
58
+ layers.Readout(),
59
+ keras.layers.Dense(units=1024, activation='elu'),
60
+ keras.layers.Dense(units=1024, activation='elu'),
60
61
  keras.layers.Dense(1)
61
62
  ]
62
63
  )
@@ -1,4 +1,4 @@
1
- __version__ = '0.1.0a16'
1
+ __version__ = '0.1.0a17'
2
2
 
3
3
  import os
4
4
  os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
@@ -6,7 +6,6 @@ os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
6
6
  from molcraft import chem
7
7
  from molcraft import features
8
8
  from molcraft import descriptors
9
- from molcraft import conformers
10
9
  from molcraft import featurizers
11
10
  from molcraft import layers
12
11
  from molcraft import models
@@ -3,7 +3,6 @@ import keras
3
3
  import numpy as np
4
4
  import tensorflow as tf
5
5
  import tensorflow_text as tf_text
6
- import json
7
6
 
8
7
  from molcraft import featurizers
9
8
  from molcraft import tensors
@@ -46,7 +45,22 @@ default_residues: dict[str, str] = {
46
45
  "Y": "N[C@@H](Cc1ccc(O)cc1)C(=O)O",
47
46
  }
48
47
 
49
-
48
+ def register_residues(residues: dict[str, str]) -> None:
49
+ # TODO: Implement functions that check if residue has N- or C-terminal mod
50
+ # if C-terminal mod, no need to enforce concatenatable perm.
51
+ # if N-terminal mod, enforce only 'C(=O)O'
52
+ # if normal mod, enforce concatenateable perm ('N[C@@H]' and 'C(=O)O)).
53
+ for residue, smiles in residues.items():
54
+ if residue.startswith('P'):
55
+ smiles.startswith('N'), f'Incorrect SMILES permutation for {residue}.'
56
+ elif not residue.startswith('['):
57
+ smiles.startswith('N[C@@H]'), f'Incorrect SMILES permutation for {residue}.'
58
+ if len(residue) > 1 and not residue[1] == "-":
59
+ assert smiles.endswith('C(=O)O'), f'Incorrect SMILES permutation for {residue}.'
60
+ registered_residues[residue] = smiles
61
+ registered_residues[residue + '*'] = smiles.strip('O')
62
+
63
+
50
64
  class Peptide(chem.Mol):
51
65
 
52
66
  @classmethod
@@ -72,40 +86,51 @@ class ResidueEmbedding(keras.layers.Layer):
72
86
  self,
73
87
  featurizer: featurizers.MolGraphFeaturizer,
74
88
  embedder: models.GraphModel,
89
+ residues: dict[str, str] | None = None,
75
90
  **kwargs
76
91
  ) -> None:
77
- residues = kwargs.pop('_residues', None)
78
92
  super().__init__(**kwargs)
79
93
  if residues is None:
80
- residues = registered_residues.copy()
81
- self._residues = residues
94
+ residues = {}
95
+ self._residue_dict = {**default_residues, **residues}
82
96
  self.embedder = embedder
83
97
  self.featurizer = featurizer
98
+ self.embedding_dim = self.embedder.output.shape[-1]
84
99
  self.ragged_split = SequenceSplitter(pad=False)
85
100
  self.split = SequenceSplitter(pad=True)
101
+ self.use_cached_embeddings = tf.Variable(False)
86
102
  self.supports_masking = True
87
103
 
88
- def build(self, input_shape) -> None:
89
- embedding_dim = self.embedder.output.shape[-1]
90
- residues = sorted(self._residues.keys())
91
- smiles = [self._residues[residue] for residue in residues]
104
+ @property
105
+ def residues(self) -> dict[str, str]:
106
+ return self._residue_dict
107
+
108
+ @residues.setter
109
+ def residues(self, residues: dict[str, str]) -> None:
110
+ self._residue_dict = residues
92
111
  num_residues = len(residues)
93
- self.oov_index = np.where(np.array(residues) == "G")[0][0]
112
+ residue_keys = sorted(residues.keys())
113
+ oov_value = np.where(np.array(residue_keys) == "G")[0][0]
94
114
  self.mapping = tf.lookup.StaticHashTable(
95
115
  tf.lookup.KeyValueTensorInitializer(
96
- keys=residues,
116
+ keys=residue_keys,
97
117
  values=range(num_residues)
98
118
  ),
99
- default_value=-1,
119
+ default_value=oov_value,
100
120
  )
101
- self.graph = tf.stack([self.featurizer(s) for s in smiles], axis=0)
121
+ self.graph = tf.stack([
122
+ self.featurizer(residues[residue]) for residue in residue_keys
123
+ ], axis=0)
102
124
  self.cached_embeddings = tf.Variable(
103
- initial_value=tf.zeros((num_residues, embedding_dim))
125
+ initial_value=tf.zeros((num_residues, self.embedding_dim))
104
126
  )
105
- self.use_cached_embeddings = tf.Variable(False)
127
+ _ = self.cache_and_get_embeddings()
128
+
129
+ def build(self, input_shape) -> None:
130
+ self.residues = self._residue_dict
106
131
  super().build(input_shape)
107
132
 
108
- def call(self, sequences, training=None) -> tensors.GraphTensor:
133
+ def call(self, sequences: tf.Tensor, training: bool = None) -> tf.Tensor:
109
134
  if training is False:
110
135
  self.use_cached_embeddings.assign(True)
111
136
  else:
@@ -113,17 +138,16 @@ class ResidueEmbedding(keras.layers.Layer):
113
138
  embeddings = tf.cond(
114
139
  pred=self.use_cached_embeddings,
115
140
  true_fn=lambda: self.cached_embeddings,
116
- false_fn=lambda: self.embeddings(),
141
+ false_fn=lambda: self.cache_and_get_embeddings(),
117
142
  )
118
143
  sequences = self.ragged_split(sequences)
119
144
  sequences = keras.ops.concatenate([
120
145
  tf.strings.join([sequences[:, :-1], '*']), sequences[:, -1:]
121
146
  ], axis=1)
122
147
  indices = self.mapping.lookup(sequences)
123
- indices = keras.ops.where(indices == -1, self.oov_index, indices)
124
148
  return tf.gather(embeddings, indices).to_tensor()
125
149
 
126
- def embeddings(self) -> tf.Tensor:
150
+ def cache_and_get_embeddings(self) -> tf.Tensor:
127
151
  embeddings = self.embedder(self.graph)
128
152
  self.cached_embeddings.assign(embeddings)
129
153
  return embeddings
@@ -139,9 +163,9 @@ class ResidueEmbedding(keras.layers.Layer):
139
163
  def get_config(self) -> dict:
140
164
  config = super().get_config()
141
165
  config.update({
142
- '_residues': self._residues,
143
166
  'featurizer': keras.saving.serialize_keras_object(self.featurizer),
144
- 'embedder': keras.saving.serialize_keras_object(self.embedder)
167
+ 'embedder': keras.saving.serialize_keras_object(self.embedder),
168
+ 'residues': self._residue_dict,
145
169
  })
146
170
  return config
147
171
 
@@ -153,87 +177,18 @@ class ResidueEmbedding(keras.layers.Layer):
153
177
 
154
178
 
155
179
  @keras.saving.register_keras_serializable(package='proteomics')
156
- class SequenceSplitter(keras.layers.Layer):
180
+ class SequenceSplitter(keras.layers.Layer):
157
181
 
158
182
  def __init__(self, pad: bool, **kwargs):
159
183
  super().__init__(**kwargs)
160
184
  self.pad = pad
161
185
 
162
- def call(self, inputs):
186
+ def call(self, inputs: tf.Tensor) -> tf.Tensor | tf.RaggedTensor:
163
187
  inputs = tf_text.regex_split(inputs, residue_pattern, residue_pattern)
164
188
  if self.pad:
165
189
  inputs = inputs.to_tensor()
166
190
  return inputs
167
191
 
168
192
 
169
- def interpret(model: keras.models.Model, sequence: list[str]) -> tensors.GraphTensor:
170
-
171
- if not tf.is_tensor(sequence):
172
- sequence = keras.ops.convert_to_tensor(sequence)
173
-
174
- # Find embedding layer
175
- for layer in model.layers:
176
- if isinstance(layer, ResidueEmbedding):
177
- break
178
-
179
- # Use embedding layer to convert the sequence to a graph
180
- residues = layer.ragged_split(sequence)
181
- residues = keras.ops.concatenate([
182
- tf.strings.join([residues[:, :-1], '*']), residues[:, -1:]
183
- ], axis=1)
184
- indices = layer.mapping.lookup(residues)
185
- graph = tf.concat([
186
- layer.graph[residue_ids] for residue_ids in indices
187
- ], axis=0)
188
-
189
- # Define layer which reshapes data into sequences of residue embeddings
190
- num_residues = indices.row_lengths()
191
- to_sequence = (
192
- lambda x: tf.RaggedTensor.from_row_lengths(x, num_residues).to_tensor()
193
- )
194
- reshape = keras.layers.Lambda(to_sequence)
195
-
196
- # Obtain the embedder part of the original model
197
- embedder = layer.embedder
198
- # Obtain the remaining part of the original model
199
- predictor = keras.models.Model(embedder.output, model.output)
200
- # Obtain an 'interpretable model', based on the original model
201
- inputs = layers.Input(graph.spec)
202
- x = inputs
203
- for layer in embedder.layers: # Loop over layers to expose them
204
- x = layer(x)
205
- x = reshape(x)
206
- outputs = predictor(x)
207
- interpretable_model = models.GraphModel(inputs, outputs)
208
-
209
- # Interpret original model through the 'interpretable model'
210
- graph = models.interpret(interpretable_model, graph)
211
- del interpretable_model
212
-
213
- # Update 'size' field with new sizes corresponding to peptides for convenience
214
- # Allows the user to obtain n:th peptide graph using indexing: nth_peptide = graph[n]
215
- peptide_indices = range(len(num_residues))
216
- peptide_indicator = keras.ops.repeat(peptide_indices, num_residues)
217
- residue_sizes = graph.context['size']
218
- peptide_sizes = keras.ops.segment_sum(residue_sizes, peptide_indicator)
219
- return graph.update({'context': {'size': peptide_sizes, 'sequence': sequence}})
220
-
221
-
222
- def register_residues(residues: dict[str, str]) -> None:
223
- # TODO: Implement functions that check if residue has N- or C-terminal mod
224
- # if C-terminal mod, no need to enforce concatenatable perm.
225
- # if N-terminal mod, enforce only 'C(=O)O'
226
- # if normal mod, enforce concatenateable perm ('N[C@@H]' and 'C(=O)O)).
227
- for residue, smiles in residues.items():
228
- if residue.startswith('P'):
229
- smiles.startswith('N'), f'Incorrect SMILES permutation for {residue}.'
230
- elif not residue.startswith('['):
231
- smiles.startswith('N[C@@H]'), f'Incorrect SMILES permutation for {residue}.'
232
- if len(residue) > 1 and not residue[1] == "-":
233
- assert smiles.endswith('C(=O)O'), f'Incorrect SMILES permutation for {residue}.'
234
- registered_residues[residue] = smiles
235
- registered_residues[residue + '*'] = smiles.strip('O')
236
-
237
-
238
193
  registered_residues: dict[str, str] = {}
239
194
  register_residues(default_residues)
@@ -19,8 +19,6 @@ class Mol(Chem.Mol):
19
19
  @classmethod
20
20
  def from_encoding(cls, encoding: str, explicit_hs: bool = False, **kwargs) -> 'Mol':
21
21
  rdkit_mol = get_mol(encoding, **kwargs)
22
- if not rdkit_mol:
23
- return None
24
22
  if explicit_hs:
25
23
  rdkit_mol = Chem.AddHs(rdkit_mol)
26
24
  rdkit_mol.__class__ = cls
@@ -102,21 +100,13 @@ class Mol(Chem.Mol):
102
100
 
103
101
  def get_conformer(self, index: int = 0) -> 'Conformer':
104
102
  if self.num_conformers == 0:
105
- warnings.warn(
106
- 'Molecule has no conformer. To embed conformer(s), invoke the `embed` method, '
107
- 'and optionally followed by `minimize()` to perform force field minimization.',
108
- stacklevel=2
109
- )
103
+ warnings.warn('Molecule has no conformer.')
110
104
  return None
111
105
  return Conformer.cast(self.GetConformer(index))
112
106
 
113
107
  def get_conformers(self) -> list['Conformer']:
114
108
  if self.num_conformers == 0:
115
- warnings.warn(
116
- 'Molecule has no conformers. To embed conformers, invoke the `embed` method, '
117
- 'and optionally followed by `minimize()` to perform force field minimization.',
118
- stacklevel=2
119
- )
109
+ warnings.warn('Molecule has no conformer.')
120
110
  return []
121
111
  return [Conformer.cast(x) for x in self.GetConformers()]
122
112
 
@@ -222,11 +212,10 @@ def get_mol(
222
212
  else:
223
213
  mol = Chem.MolFromSmiles(encoding, sanitize=False)
224
214
  if mol is not None:
225
- return sanitize_mol(mol, strict, assign_stereo_chemistry)
226
- raise ValueError(
227
- f"{encoding} is invalid; "
228
- f"make sure {encoding} is a valid SMILES or InChI string."
229
- )
215
+ mol = sanitize_mol(mol, strict, assign_stereo_chemistry)
216
+ if mol is not None:
217
+ return mol
218
+ raise ValueError(f'Could not obtain `chem.Mol` from {encoding}.')
230
219
 
231
220
  def get_adjacency_matrix(
232
221
  mol: Chem.Mol,
@@ -402,8 +391,9 @@ def embed_conformers(
402
391
  mol: Mol,
403
392
  num_conformers: int,
404
393
  method: str = 'ETKDGv3',
394
+ random_seed: int | None = None,
405
395
  **kwargs
406
- ) -> None:
396
+ ) -> Mol:
407
397
  available_embedding_methods = {
408
398
  'ETDG': rdDistGeom.ETDG(),
409
399
  'ETKDG': rdDistGeom.ETKDG(),
@@ -423,6 +413,9 @@ def embed_conformers(
423
413
  for key, value in kwargs.items():
424
414
  setattr(embedding_method, key, value)
425
415
 
416
+ if random_seed is not None:
417
+ embedding_method.randomSeed = random_seed
418
+
426
419
  success = rdDistGeom.EmbedMultipleConfs(
427
420
  mol, numConfs=num_conformers, params=embedding_method
428
421
  )
@@ -440,6 +433,8 @@ def embed_conformers(
440
433
  fallback_embedding_method.useRandomCoords = True
441
434
  fallback_embedding_method.maxAttempts = max_attempts
442
435
  fallback_embedding_method.clearConfs = False
436
+ if random_seed is not None:
437
+ fallback_embedding_method.randomSeed = random_seed
443
438
  success = rdDistGeom.EmbedMultipleConfs(
444
439
  mol, numConfs=(num_conformers - num_successes), params=fallback_embedding_method
445
440
  )
@@ -459,7 +454,7 @@ def optimize_conformers(
459
454
  num_threads: bool = 1,
460
455
  ignore_interfragment_interactions: bool = True,
461
456
  vdw_threshold: float = 10.0,
462
- ):
457
+ ) -> Mol:
463
458
  available_force_field_methods = [
464
459
  'MMFF', 'MMFF94', 'MMFF94s', 'UFF'
465
460
  ]
@@ -502,7 +497,7 @@ def prune_conformers(
502
497
  keep: int = 1,
503
498
  threshold: float = 0.0,
504
499
  energy_force_field: str = 'UFF',
505
- ):
500
+ ) -> Mol:
506
501
  if mol.num_conformers == 0:
507
502
  warnings.warn(
508
503
  'Molecule has no conformers. To embed conformers, invoke the `embed` method, '
@@ -539,7 +534,7 @@ def _uff_optimize_conformers(
539
534
  vdw_threshold: float = 10.0,
540
535
  ignore_interfragment_interactions: bool = True,
541
536
  **kwargs,
542
- ) -> Mol:
537
+ ) -> tuple[list[float], list[bool]]:
543
538
  """Universal Force Field Minimization.
544
539
  """
545
540
  results = rdForceFieldHelpers.UFFOptimizeMoleculeConfs(
@@ -560,7 +555,7 @@ def _mmff_optimize_conformers(
560
555
  variant: str = 'MMFF94',
561
556
  ignore_interfragment_interactions: bool = True,
562
557
  **kwargs,
563
- ) -> Mol:
558
+ ) -> tuple[list[float], list[bool]]:
564
559
  """Merck Molecular Force Field Minimization.
565
560
  """
566
561
  if not rdForceFieldHelpers.MMFFHasAllMoleculeParams(mol):
@@ -11,7 +11,7 @@ def split(
11
11
  test_size: float | None = None,
12
12
  groups: str | np.ndarray = None,
13
13
  shuffle: bool = False,
14
- random_state: int | None = None,
14
+ random_seed: int | None = None,
15
15
  ) -> tuple[np.ndarray | pd.DataFrame, ...]:
16
16
  """Splits the dataset into subsets.
17
17
 
@@ -28,7 +28,7 @@ def split(
28
28
  The groups to perform the splitting on.
29
29
  shuffle:
30
30
  Whether the dataset should be shuffled prior to splitting.
31
- random_state:
31
+ random_seed:
32
32
  The random state/seed. Only applicable if shuffling.
33
33
  """
34
34
  if not isinstance(data, (pd.DataFrame, np.ndarray)):
@@ -69,7 +69,7 @@ def split(
69
69
  train_size += remainder
70
70
 
71
71
  if shuffle:
72
- np.random.seed(random_state)
72
+ np.random.seed(random_seed)
73
73
  np.random.shuffle(indices)
74
74
 
75
75
  train_mask = np.isin(groups, indices[:train_size])
@@ -84,7 +84,7 @@ def cv_split(
84
84
  num_splits: int = 10,
85
85
  groups: str | np.ndarray = None,
86
86
  shuffle: bool = False,
87
- random_state: int | None = None,
87
+ random_seed: int | None = None,
88
88
  ) -> typing.Iterator[
89
89
  tuple[np.ndarray | pd.DataFrame, np.ndarray | pd.DataFrame]
90
90
  ]:
@@ -99,7 +99,7 @@ def cv_split(
99
99
  The groups to perform the splitting on.
100
100
  shuffle:
101
101
  Whether the dataset should be shuffled prior to splitting.
102
- random_state:
102
+ random_seed:
103
103
  The random state/seed. Only applicable if shuffling.
104
104
  """
105
105
  if not isinstance(data, (pd.DataFrame, np.ndarray)):
@@ -119,7 +119,7 @@ def cv_split(
119
119
  f'the data size or the number of groups ({size}).'
120
120
  )
121
121
  if shuffle:
122
- np.random.seed(random_state)
122
+ np.random.seed(random_seed)
123
123
  np.random.shuffle(indices)
124
124
 
125
125
  indices_splits = np.array_split(indices, num_splits)
@@ -91,3 +91,17 @@ class NumRings(Descriptor):
91
91
  def call(self, mol: chem.Mol) -> np.ndarray:
92
92
  return rdMolDescriptors.CalcNumRings(mol)
93
93
 
94
+
95
+ @keras.saving.register_keras_serializable(package='molcraft')
96
+ class AtomCount(Descriptor):
97
+
98
+ def __init__(self, atom_type: str, **kwargs):
99
+ super().__init__(**kwargs)
100
+ self.atom_type = atom_type
101
+
102
+ def call(self, mol: chem.Mol) -> np.ndarray:
103
+ count = 0
104
+ for atom in mol.atoms:
105
+ if atom.GetSymbol() == self.atom_type:
106
+ count += 1
107
+ return count
@@ -41,11 +41,7 @@ class Feature(abc.ABC):
41
41
 
42
42
  def __call__(self, mol: chem.Mol) -> np.ndarray:
43
43
  if not isinstance(mol, chem.Mol):
44
- raise ValueError(
45
- f'Input to {self.name} needs to be a `chem.Mol`, which '
46
- 'implements two properties that should be iterated over '
47
- 'to compute features: `atoms` and `bonds`.'
48
- )
44
+ raise TypeError(f'Input to {self.name} must be a `chem.Mol` instance.')
49
45
  features = self.call(mol)
50
46
  if len(features) != mol.num_atoms and len(features) != mol.num_bonds:
51
47
  raise ValueError(
@@ -119,59 +115,6 @@ class Feature(abc.ABC):
119
115
  return np.asarray([value], dtype=self.dtype)
120
116
 
121
117
 
122
- @keras.saving.register_keras_serializable(package='molcraft')
123
- class EdgeFeature(Feature):
124
-
125
- def __call__(self, mol: chem.Mol) -> np.ndarray:
126
- if not isinstance(mol, chem.Mol):
127
- raise ValueError(
128
- f'Input to {self.name} needs to be a `chem.Mol`, which '
129
- 'implements two properties that should be iterated over '
130
- 'to compute features: `atoms` and `bonds`.'
131
- )
132
- features = self.call(mol)
133
- if len(features) != int(mol.num_atoms**2):
134
- raise ValueError(
135
- f'The number of features computed by {self.name} does not '
136
- 'match the number of node pairs in the `chem.Mol` object. '
137
- f'Make sure the list of items returned by {self.name}(input) '
138
- 'correspond to node/atom pairs: '
139
- '[(0, 0), (0, 1), ..., (0, N), (1, 0), ... (N, N)], '
140
- 'where N denotes the number of nodes/atoms.'
141
- )
142
- func = (
143
- self._featurize_categorical if self.vocab else
144
- self._featurize_floating
145
- )
146
- return np.asarray([func(x) for x in features], dtype=self.dtype)
147
-
148
-
149
- @keras.saving.register_keras_serializable(package='molcraft')
150
- class Distance(EdgeFeature):
151
-
152
- def __init__(
153
- self,
154
- max_distance: int = None,
155
- allow_oov: int = True,
156
- encode_oov: bool = True,
157
- **kwargs,
158
- ) -> None:
159
- vocab = kwargs.pop('vocab', None)
160
- if not vocab:
161
- if max_distance is None:
162
- max_distance = 20
163
- vocab = list(range(max_distance + 1))
164
- super().__init__(
165
- vocab=vocab,
166
- allow_oov=allow_oov,
167
- encode_oov=encode_oov,
168
- **kwargs
169
- )
170
-
171
- def call(self, mol: chem.Mol) -> list[int]:
172
- return [int(x) for x in chem.get_distances(mol).reshape(-1)]
173
-
174
-
175
118
  @keras.saving.register_keras_serializable(package='molcraft')
176
119
  class AtomType(Feature):
177
120
  def call(self, mol: chem.Mol) -> list[int, float, str]:
@@ -340,6 +283,55 @@ class IsRotatable(Feature):
340
283
  return chem.rotatable_bonds(mol)
341
284
 
342
285
 
286
+ @keras.saving.register_keras_serializable(package='molcraft')
287
+ class PairFeature(Feature):
288
+
289
+ def __call__(self, mol: chem.Mol) -> np.ndarray:
290
+ if not isinstance(mol, chem.Mol):
291
+ raise TypeError(f'Input to {self.name} must be a `chem.Mol` instance.')
292
+ features = self.call(mol)
293
+ if len(features) != int(mol.num_atoms**2):
294
+ raise ValueError(
295
+ f'The number of features computed by {self.name} does not '
296
+ 'match the number of node/atom pairs in the `chem.Mol` object. '
297
+ f'Make sure the list of items returned by {self.name}(input) '
298
+ 'correspond to node/atom pairs: '
299
+ '[(0, 0), (0, 1), ..., (0, N), (1, 0), ... (N, N)], '
300
+ 'where N denotes the number of nodes/atoms.'
301
+ )
302
+ func = (
303
+ self._featurize_categorical if self.vocab else
304
+ self._featurize_floating
305
+ )
306
+ return np.asarray([func(x) for x in features], dtype=self.dtype)
307
+
308
+
309
+ @keras.saving.register_keras_serializable(package='molcraft')
310
+ class PairDistance(PairFeature):
311
+
312
+ def __init__(
313
+ self,
314
+ max_distance: int = None,
315
+ allow_oov: int = True,
316
+ encode_oov: bool = True,
317
+ **kwargs,
318
+ ) -> None:
319
+ vocab = kwargs.pop('vocab', None)
320
+ if not vocab:
321
+ if max_distance is None:
322
+ max_distance = 10
323
+ vocab = list(range(max_distance + 1))
324
+ super().__init__(
325
+ vocab=vocab,
326
+ allow_oov=allow_oov,
327
+ encode_oov=encode_oov,
328
+ **kwargs
329
+ )
330
+
331
+ def call(self, mol: chem.Mol) -> list[int]:
332
+ return [int(x) for x in chem.get_distances(mol).reshape(-1)]
333
+
334
+
343
335
  default_vocabulary = {
344
336
  'AtomType': [
345
337
  '*', 'H', 'He', 'Li', 'Be', 'B', 'C', 'N', 'O', 'F', 'Ne', 'Na',