molcraft 0.1.0a2__tar.gz → 0.1.0a4__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of molcraft might be problematic. Click here for more details.

Files changed (32)
  1. {molcraft-0.1.0a2 → molcraft-0.1.0a4}/PKG-INFO +6 -6
  2. {molcraft-0.1.0a2 → molcraft-0.1.0a4}/README.md +5 -5
  3. {molcraft-0.1.0a2 → molcraft-0.1.0a4}/molcraft/__init__.py +2 -1
  4. {molcraft-0.1.0a2 → molcraft-0.1.0a4}/molcraft/callbacks.py +12 -0
  5. {molcraft-0.1.0a2 → molcraft-0.1.0a4}/molcraft/descriptors.py +24 -23
  6. {molcraft-0.1.0a2 → molcraft-0.1.0a4}/molcraft/experimental/peptides.py +96 -79
  7. {molcraft-0.1.0a2 → molcraft-0.1.0a4}/molcraft/features.py +5 -3
  8. {molcraft-0.1.0a2 → molcraft-0.1.0a4}/molcraft/featurizers.py +61 -38
  9. {molcraft-0.1.0a2 → molcraft-0.1.0a4}/molcraft/layers.py +1004 -425
  10. {molcraft-0.1.0a2 → molcraft-0.1.0a4}/molcraft/models.py +47 -3
  11. {molcraft-0.1.0a2 → molcraft-0.1.0a4}/molcraft/ops.py +14 -3
  12. {molcraft-0.1.0a2 → molcraft-0.1.0a4}/molcraft/tensors.py +3 -3
  13. {molcraft-0.1.0a2 → molcraft-0.1.0a4}/molcraft.egg-info/PKG-INFO +6 -6
  14. molcraft-0.1.0a4/tests/test_featurizers.py +197 -0
  15. molcraft-0.1.0a4/tests/test_layers.py +287 -0
  16. {molcraft-0.1.0a2 → molcraft-0.1.0a4}/tests/test_models.py +41 -1
  17. molcraft-0.1.0a2/tests/test_featurizers.py +0 -111
  18. molcraft-0.1.0a2/tests/test_layers.py +0 -143
  19. {molcraft-0.1.0a2 → molcraft-0.1.0a4}/LICENSE +0 -0
  20. {molcraft-0.1.0a2 → molcraft-0.1.0a4}/molcraft/chem.py +0 -0
  21. {molcraft-0.1.0a2 → molcraft-0.1.0a4}/molcraft/conformers.py +0 -0
  22. {molcraft-0.1.0a2 → molcraft-0.1.0a4}/molcraft/datasets.py +0 -0
  23. {molcraft-0.1.0a2 → molcraft-0.1.0a4}/molcraft/experimental/__init__.py +0 -0
  24. {molcraft-0.1.0a2 → molcraft-0.1.0a4}/molcraft/records.py +0 -0
  25. {molcraft-0.1.0a2 → molcraft-0.1.0a4}/molcraft.egg-info/SOURCES.txt +0 -0
  26. {molcraft-0.1.0a2 → molcraft-0.1.0a4}/molcraft.egg-info/dependency_links.txt +0 -0
  27. {molcraft-0.1.0a2 → molcraft-0.1.0a4}/molcraft.egg-info/requires.txt +0 -0
  28. {molcraft-0.1.0a2 → molcraft-0.1.0a4}/molcraft.egg-info/top_level.txt +0 -0
  29. {molcraft-0.1.0a2 → molcraft-0.1.0a4}/pyproject.toml +0 -0
  30. {molcraft-0.1.0a2 → molcraft-0.1.0a4}/setup.cfg +0 -0
  31. {molcraft-0.1.0a2 → molcraft-0.1.0a4}/tests/test_chem.py +0 -0
  32. {molcraft-0.1.0a2 → molcraft-0.1.0a4}/tests/test_tensors.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: molcraft
3
- Version: 0.1.0a2
3
+ Version: 0.1.0a4
4
4
  Summary: Graph Neural Networks for Molecular Machine Learning
5
5
  Author-email: Alexander Kensert <alexander.kensert@gmail.com>
6
6
  License: MIT License
@@ -51,11 +51,11 @@ Dynamic: license-file
51
51
 
52
52
  ## Highlights
53
53
  - Compatible with **Keras 3**
54
- - Simplified API
55
- - Fast featurization
56
- - Modular graph **layers**
57
- - Serializable graph **featurizers** and **models**
58
- - Flexible **GraphTensor**
54
+ - Customizable and serializable **featurizers**
55
+ - Customizable and serializable **layers** and **models**
56
+ - Customizable **GraphTensor**
57
+ - Fast and efficient featurization of molecular graphs
58
+ - Efficient and easy-to-use input pipelines using TF **records**
59
59
 
60
60
  ## Examples
61
61
 
@@ -7,11 +7,11 @@
7
7
 
8
8
  ## Highlights
9
9
  - Compatible with **Keras 3**
10
- - Simplified API
11
- - Fast featurization
12
- - Modular graph **layers**
13
- - Serializable graph **featurizers** and **models**
14
- - Flexible **GraphTensor**
10
+ - Customizable and serializable **featurizers**
11
+ - Customizable and serializable **layers** and **models**
12
+ - Customizable **GraphTensor**
13
+ - Fast and efficient featurization of molecular graphs
14
+ - Efficient and easy-to-use input pipelines using TF **records**
15
15
 
16
16
  ## Examples
17
17
 
@@ -1,4 +1,4 @@
1
- __version__ = '0.1.0a2'
1
+ __version__ = '0.1.0a4'
2
2
 
3
3
  import os
4
4
  os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
@@ -14,3 +14,4 @@ from molcraft import ops
14
14
  from molcraft import records
15
15
  from molcraft import tensors
16
16
  from molcraft import callbacks
17
+ from molcraft import datasets
@@ -19,3 +19,15 @@ class TensorBoard(keras.callbacks.TensorBoard):
19
19
  weight, image_weight_name, epoch
20
20
  )
21
21
  self._train_writer.flush()
22
+
23
+
24
+ class LearningRateDecay(keras.callbacks.LearningRateScheduler):
25
+
26
+ def __init__(self, rate: float, delay: int = 0, **kwargs):
27
+
28
+ def lr_schedule(epoch: int, lr: float):
29
+ if epoch < delay:
30
+ return float(lr)
31
+ return float(lr * keras.ops.exp(-rate))
32
+
33
+ super().__init__(schedule=lr_schedule, **kwargs)
@@ -1,6 +1,6 @@
1
1
  import keras
2
2
  import numpy as np
3
- from rdkit.Chem import Descriptors
3
+ from rdkit.Chem import rdMolDescriptors
4
4
 
5
5
  from molcraft import chem
6
6
  from molcraft import features
@@ -8,9 +8,6 @@ from molcraft import features
8
8
 
9
9
  @keras.saving.register_keras_serializable(package='molcraft')
10
10
  class Descriptor(features.Feature):
11
- def __init__(self, scale: float | None = None, **kwargs):
12
- super().__init__(**kwargs)
13
- self.scale = scale
14
11
 
15
12
  def __call__(self, mol: chem.Mol) -> np.ndarray:
16
13
  if not isinstance(mol, chem.Mol):
@@ -24,67 +21,71 @@ class Descriptor(features.Feature):
24
21
  self._featurize_categorical if self.vocab else
25
22
  self._featurize_floating
26
23
  )
27
- scale_value = self.scale and not self.vocab
28
24
  if not isinstance(descriptor, (tuple, list, np.ndarray)):
29
25
  descriptor = [descriptor]
30
26
 
31
27
  descriptors = []
32
28
  for value in descriptor:
33
- if scale_value:
34
- value /= self.scale
35
29
  descriptors.append(func(value))
36
30
  return np.concatenate(descriptors)
37
31
 
38
- def get_config(self):
39
- config = super().get_config()
40
- config['scale'] = self.scale
41
- return config
42
-
43
32
 
44
33
  @keras.saving.register_keras_serializable(package='molcraft')
45
34
  class MolWeight(Descriptor):
46
35
  def call(self, mol: chem.Mol) -> np.ndarray:
47
- return Descriptors.MolWt(mol)
36
+ return rdMolDescriptors.CalcExactMolWt(mol)
48
37
 
49
38
 
50
39
  @keras.saving.register_keras_serializable(package='molcraft')
51
- class MolTPSA(Descriptor):
40
+ class TPSA(Descriptor):
52
41
  def call(self, mol: chem.Mol) -> np.ndarray:
53
- return Descriptors.TPSA(mol)
42
+ return rdMolDescriptors.CalcTPSA(mol)
43
+
54
44
 
45
+ @keras.saving.register_keras_serializable(package='molcraft')
46
+ class CrippenLogP(Descriptor):
47
+ def call(self, mol: chem.Mol) -> np.ndarray:
48
+ return rdMolDescriptors.CalcCrippenDescriptors(mol)[0]
49
+
55
50
 
56
51
  @keras.saving.register_keras_serializable(package='molcraft')
57
- class MolLogP(Descriptor):
52
+ class CrippenMolarRefractivity(Descriptor):
58
53
  def call(self, mol: chem.Mol) -> np.ndarray:
59
- return Descriptors.MolLogP(mol)
54
+ return rdMolDescriptors.CalcCrippenDescriptors(mol)[1]
60
55
 
61
56
 
62
57
  @keras.saving.register_keras_serializable(package='molcraft')
63
58
  class NumHeavyAtoms(Descriptor):
64
59
  def call(self, mol: chem.Mol) -> np.ndarray:
65
- return Descriptors.HeavyAtomCount(mol)
60
+ return rdMolDescriptors.CalcNumHeavyAtoms(mol)
66
61
 
67
62
 
63
+ @keras.saving.register_keras_serializable(package='molcraft')
64
+ class NumHeteroAtoms(Descriptor):
65
+ def call(self, mol: chem.Mol) -> np.ndarray:
66
+ return rdMolDescriptors.CalcNumHeteroatoms(mol)
67
+
68
+
68
69
  @keras.saving.register_keras_serializable(package='molcraft')
69
70
  class NumHydrogenDonors(Descriptor):
70
71
  def call(self, mol: chem.Mol) -> np.ndarray:
71
- return Descriptors.NumHDonors(mol)
72
+ return rdMolDescriptors.CalcNumHBD(mol)
72
73
 
73
74
 
74
75
  @keras.saving.register_keras_serializable(package='molcraft')
75
76
  class NumHydrogenAcceptors(Descriptor):
76
77
  def call(self, mol: chem.Mol) -> np.ndarray:
77
- return Descriptors.NumHAcceptors(mol)
78
-
78
+ return rdMolDescriptors.CalcNumHBA(mol)
79
+
79
80
 
80
81
  @keras.saving.register_keras_serializable(package='molcraft')
81
82
  class NumRotatableBonds(Descriptor):
82
83
  def call(self, mol: chem.Mol) -> np.ndarray:
83
- return Descriptors.NumRotatableBonds(mol)
84
+ return rdMolDescriptors.CalcNumRotatableBonds(mol)
84
85
 
85
86
 
86
87
  @keras.saving.register_keras_serializable(package='molcraft')
87
88
  class NumRings(Descriptor):
88
89
  def call(self, mol: chem.Mol) -> np.ndarray:
89
- return Descriptors.RingCount(mol)
90
+ return rdMolDescriptors.CalcNumRings(mol)
90
91
 
@@ -2,6 +2,7 @@ import re
2
2
  import keras
3
3
  import numpy as np
4
4
  import tensorflow as tf
5
+ import tensorflow_text as tf_text
5
6
  from rdkit import Chem
6
7
 
7
8
  from molcraft import ops
@@ -10,12 +11,14 @@ from molcraft import features
10
11
  from molcraft import featurizers
11
12
  from molcraft import tensors
12
13
  from molcraft import descriptors
14
+ from molcraft import layers
13
15
 
14
16
 
15
17
  def Graph(
16
18
  inputs,
17
19
  atom_features: list[features.Feature] | str | None = 'auto',
18
20
  bond_features: list[features.Feature] | str | None = 'auto',
21
+ molecule_features: list[descriptors.Descriptor] | str | None = 'auto',
19
22
  super_atom: bool = True,
20
23
  radius: int | float | None = None,
21
24
  self_loops: bool = False,
@@ -25,7 +28,7 @@ def Graph(
25
28
  featurizer = featurizers.MolGraphFeaturizer(
26
29
  atom_features=atom_features,
27
30
  bond_features=bond_features,
28
- molecule_features=[AminoAcidType()],
31
+ molecule_features=molecule_features,
29
32
  super_atom=super_atom,
30
33
  radius=radius,
31
34
  self_loops=self_loops,
@@ -33,41 +36,55 @@ def Graph(
33
36
  **kwargs,
34
37
  )
35
38
 
36
- inputs = [
37
- residues[x] for x in ['G'] + inputs
39
+ tensor_list: list[tensors.GraphTensor] = [
40
+ featurizer(residues[tag]).update({'context': {'tag': tag}}) for tag in inputs
38
41
  ]
39
- tensor_list = [featurizer(x) for x in inputs]
40
42
  return tf.stack(tensor_list, axis=0)
43
+
41
44
 
42
-
43
- def GraphLookup(graph: tensors.GraphTensor) -> 'GraphLookupLayer':
44
- lookup = GraphLookupLayer()
45
+ def Lookup(graph: tensors.GraphTensor) -> 'LookupLayer':
46
+ lookup = LookupLayer()
45
47
  lookup._build(graph)
46
48
  return lookup
47
49
 
48
50
 
49
51
  @keras.saving.register_keras_serializable(package='molcraft')
50
- class GraphLookupLayer(keras.layers.Layer):
52
+ class SequenceSplit(keras.layers.Layer):
53
+
54
+ _pattern = "|".join([
55
+ r'(\[[A-Za-z0-9]+\]-[A-Z]\[[A-Za-z0-9]+\])', # N-term mod + mod
56
+ r'([A-Z]\[[A-Za-z0-9]+\]-\[[A-Za-z0-9]+\])', # C-term mod + mod
57
+ r'([A-Z]-\[[A-Za-z0-9]+\])', # C-term mod
58
+ r'(\[[A-Za-z0-9]+\]-[A-Z])', # N-term mod
59
+ r'([A-Z]\[[A-Za-z0-9]+\])', # Mod
60
+ r'([A-Z])', # No mod
61
+ ])
62
+
63
+ def call(self, inputs):
64
+ inputs = tf_text.regex_split(inputs, self._pattern, self._pattern)
65
+ inputs = keras.ops.concatenate([
66
+ tf.strings.join([inputs[:, :-1], '-[X]']),
67
+ inputs[:, -1:]
68
+ ], axis=1)
69
+ return inputs.to_tensor()
70
+
71
+
72
+
51
73
 
52
- def call(self, indices: tf.Tensor) -> tensors.GraphTensor:
53
- indices = tf.sort(tf.unique(tf.reshape(indices, [-1]))[0])
74
+ @keras.saving.register_keras_serializable(package='molcraft')
75
+ class LookupLayer(keras.layers.Layer):
76
+
77
+ def __init__(self, **kwargs):
78
+ super().__init__(**kwargs)
79
+ self._sequence_splitter = SequenceSplit()
80
+
81
+ def call(self, sequence: tf.Tensor) -> tensors.GraphTensor:
82
+ sequence = self._sequence_splitter(sequence)
83
+ indices = self._tag_to_index.lookup(sequence)
84
+ indices = tf.sort(tf.unique(tf.reshape(indices, [-1]))[0])[1:]
54
85
  graph = self.graph[indices]
55
- sizes = graph.context['size']
56
- max_index = keras.ops.max(indices)
57
- sizes = tf.tensor_scatter_nd_update(
58
- tensor=tf.zeros([max_index + 1], dtype=indices.dtype),
59
- indices=indices[:, None],
60
- updates=sizes
61
- )
62
- graph = graph.update(
63
- {
64
- 'context': {
65
- 'size': sizes
66
- }
67
- },
68
- )
69
86
  return tensors.to_dict(graph)
70
-
87
+
71
88
  def _build(self, x):
72
89
 
73
90
  if isinstance(x, tensors.GraphTensor):
@@ -98,7 +115,17 @@ class GraphLookupLayer(keras.layers.Layer):
98
115
  keras.ops.convert_to_tensor, self._graph
99
116
  )
100
117
  self._graph_tensor = tensors.from_dict(graph)
101
-
118
+
119
+ tags = self._graph_tensor.context['tag']
120
+
121
+ self._tag_to_index = tf.lookup.StaticHashTable(
122
+ tf.lookup.KeyValueTensorInitializer(
123
+ keys=tags,
124
+ values=range(len(tags)),
125
+ ),
126
+ default_value=-1,
127
+ )
128
+
102
129
  def get_config(self):
103
130
  config = super().get_config()
104
131
  spec = keras.saving.serialize_keras_object(self._spec)
@@ -106,7 +133,7 @@ class GraphLookupLayer(keras.layers.Layer):
106
133
  return config
107
134
 
108
135
  @classmethod
109
- def from_config(cls, config: dict) -> 'GraphLookupLayer':
136
+ def from_config(cls, config: dict) -> 'LookupLayer':
110
137
  spec = config.pop('spec')
111
138
  spec = keras.saving.deserialize_keras_object(spec)
112
139
  layer = cls(**config)
@@ -130,66 +157,48 @@ class Gather(keras.layers.Layer):
130
157
  super().__init__(**kwargs)
131
158
  self.padding = padding
132
159
  self.mask_value = mask_value
160
+ self._readout_layer = layers.Readout(mode='mean')
133
161
  self.supports_masking = True
162
+ self._sequence_splitter = SequenceSplit()
134
163
 
135
164
  def get_config(self):
136
165
  config = super().get_config()
137
166
  config['mask_value'] = self.mask_value
138
167
  config['padding'] = self.padding
139
168
  return config
140
-
141
- def call(self, inputs: tuple[tf.Tensor, tf.Tensor]) -> tf.Tensor:
142
- data, indices = inputs
143
- # if self.padding:
144
- # padding = self.padding
145
- # if isinstance(self.padding, int):
146
- # padding = [(self.padding, 0)]
147
- # if isinstance(self.padding, tuple):
148
- # padding = [self.padding]
149
- # data_rank = len(keras.ops.shape(data))
150
- # for _ in range(data_rank - len(padding)):
151
- # padding.append((0, 0))
152
- # data = keras.ops.pad(data, padding)
153
- return ops.gather(data, indices)
154
-
169
+
170
+ def call(self, inputs) -> tf.Tensor:
171
+
172
+ graph, sequence = inputs
173
+
174
+ tag = graph['context']['tag']
175
+ data = self._readout_layer(graph)
176
+
177
+ table = tf.lookup.experimental.MutableHashTable(
178
+ key_dtype=tf.string,
179
+ value_dtype=tf.int32,
180
+ default_value=-1
181
+ )
182
+
183
+ table.insert(tag, tf.range(tf.shape(tag)[0]))
184
+ sequence = self._sequence_splitter(sequence)
185
+ sequence = table.lookup(sequence)
186
+
187
+ readout = ops.gather(data, keras.ops.where(sequence == -1, 0, sequence))
188
+ readout = keras.ops.where(sequence[..., None] == -1, 0.0, readout)
189
+ return readout
190
+
155
191
  def compute_mask(
156
192
  self,
157
- inputs: tuple[tf.Tensor, tf.Tensor],
193
+ inputs: tensors.GraphTensor,
158
194
  mask: bool | None = None
159
195
  ) -> tf.Tensor | None:
160
196
  # if self.mask_value is None:
161
197
  # return None
162
- _, indices = inputs
163
- return keras.ops.not_equal(indices, self.mask_value)
164
-
198
+ _, sequence = inputs
199
+ sequence = self._sequence_splitter(sequence)
200
+ return keras.ops.not_equal(sequence, '')
165
201
 
166
- @keras.saving.register_keras_serializable(package='molcraft')
167
- class AminoAcidType(descriptors.Descriptor):
168
-
169
- def __init__(self, vocab=None, **kwargs):
170
- vocab = [
171
- "A", "C", "D", "E", "F", "G", "H", "I", "K", "L",
172
- "M", "N", "P", "Q", "R", "S", "T", "V", "W", "Y",
173
- ]
174
- super().__init__(vocab=vocab, **kwargs)
175
-
176
- def call(self, mol: chem.Mol) -> list[str]:
177
- residue = residues_reverse.get(mol.canonical_smiles)
178
- if not residue:
179
- raise KeyError(f'Could not find {mol.canonical_smiles} in `residues_reverse`.')
180
- mol = chem.remove_hs(mol)
181
- return _extract_residue_type(residues_reverse[mol.canonical_smiles])
182
-
183
- def sequence_split(sequence: str):
184
- patterns = [
185
- r'(\[[A-Za-z0-9]+\]-[A-Z]\[[A-Za-z0-9]+\])', # N-term mod + mod
186
- r'([A-Z]\[[A-Za-z0-9]+\]-\[[A-Za-z0-9]+\])', # C-term mod + mod
187
- r'([A-Z]-\[[A-Za-z0-9]+\])', # C-term mod
188
- r'(\[[A-Za-z0-9]+\]-[A-Z])', # N-term mod
189
- r'([A-Z]\[[A-Za-z0-9]+\])', # Mod
190
- r'([A-Z])', # No mod
191
- ]
192
- return [match.group(0) for match in re.finditer("|".join(patterns), sequence)]
193
202
 
194
203
  residues = {
195
204
  "A": "N[C@@H](C)C(=O)O",
@@ -252,13 +261,21 @@ residues = {
252
261
  }
253
262
 
254
263
  residues_reverse = {}
255
- def register_peptide_residues(residues: dict[str, str]):
256
- for residue, smiles in residues.items():
257
- residues[residue] = Chem.MolToSmiles(Chem.MolFromSmiles(smiles))
264
+ def register_peptide_residues(residues_: dict[str, str], canonicalize=True):
265
+ for residue, smiles in residues_.items():
266
+ if canonicalize:
267
+ smiles = Chem.MolToSmiles(Chem.MolFromSmiles(smiles))
268
+ residues[residue] = smiles
258
269
  residues_reverse[residues[residue]] = residue
259
270
 
260
- register_peptide_residues(residues)
271
+ register_peptide_residues(residues, canonicalize=False)
261
272
 
262
273
  def _extract_residue_type(residue_tag: str) -> str:
263
- pattern = r"(?<!\[)[A-Z](?![\w-])"
264
- return [match.group(0) for match in re.finditer(pattern, residue_tag)][0]
274
+ pattern = r"(?<!\[)[A-Z](?![^\[]*\])"
275
+ return [match.group(0) for match in re.finditer(pattern, residue_tag)][0]
276
+
277
+ special_residues = {}
278
+ for key, value in residues.items():
279
+ special_residues[key + '-[X]'] = value.rstrip('O')
280
+
281
+ register_peptide_residues(special_residues, canonicalize=False)
@@ -155,9 +155,11 @@ class Distance(EdgeFeature):
155
155
  encode_oov: bool = True,
156
156
  **kwargs,
157
157
  ) -> None:
158
- if max_distance is None:
159
- max_distance = 20
160
- vocab = list(range(max_distance + 1))
158
+ vocab = kwargs.pop('vocab', None)
159
+ if not vocab:
160
+ if max_distance is None:
161
+ max_distance = 20
162
+ vocab = list(range(max_distance + 1))
161
163
  super().__init__(
162
164
  vocab=vocab,
163
165
  allow_oov=allow_oov,
@@ -186,9 +186,11 @@ class MolGraphFeaturizer(Featurizer):
186
186
  if default_molecule_features:
187
187
  molecule_features = [
188
188
  descriptors.MolWeight(),
189
- descriptors.MolTPSA(),
190
- descriptors.MolLogP(),
189
+ descriptors.TPSA(),
190
+ descriptors.CrippenLogP(),
191
+ descriptors.CrippenMolarRefractivity(),
191
192
  descriptors.NumHeavyAtoms(),
193
+ descriptors.NumHeteroAtoms(),
192
194
  descriptors.NumHydrogenDonors(),
193
195
  descriptors.NumHydrogenAcceptors(),
194
196
  descriptors.NumRotatableBonds(),
@@ -219,7 +221,7 @@ class MolGraphFeaturizer(Featurizer):
219
221
 
220
222
  atom_feature = self.atom_features(mol)
221
223
  bond_feature = self.bond_features(mol)
222
- context_feature = self.context_feature(mol)
224
+ molecule_feature = self.molecule_feature(mol)
223
225
  molecule_size = self.num_atoms(mol)
224
226
 
225
227
  if isinstance(context, dict):
@@ -241,8 +243,14 @@ class MolGraphFeaturizer(Featurizer):
241
243
  else:
242
244
  context = {'size': molecule_size}
243
245
 
244
- if context_feature is not None:
245
- context['feature'] = context_feature
246
+ if molecule_feature is not None:
247
+ if 'feature' in context:
248
+ warn(
249
+ 'Found both inputted and computed context feature. '
250
+ 'Overwriting inputted context feature with computed '
251
+ 'context feature (based on `molecule_features`).'
252
+ )
253
+ context['feature'] = molecule_feature
246
254
 
247
255
  node = {}
248
256
  node['feature'] = atom_feature
@@ -288,7 +296,7 @@ class MolGraphFeaturizer(Featurizer):
288
296
  edge['feature'] = self._expand_bond_features(
289
297
  mol, paths, bond_feature,
290
298
  )
291
- edge['length'] = np.eye(self.radius + 1)[edge['length']]
299
+ edge['length'] = np.eye(self.radius + 1, dtype=self.feature_dtype)[edge['length']]
292
300
 
293
301
  if self.super_atom:
294
302
  node, edge = self._add_super_atom(node, edge)
@@ -315,13 +323,13 @@ class MolGraphFeaturizer(Featurizer):
315
323
  )
316
324
  return bond_feature.astype(self.feature_dtype)
317
325
 
318
- def context_feature(self, mol: chem.Mol) -> np.ndarray:
326
+ def molecule_feature(self, mol: chem.Mol) -> np.ndarray:
319
327
  if self._molecule_features is None:
320
328
  return None
321
- context_feature: np.ndarray = np.concatenate(
329
+ molecule_feature: np.ndarray = np.concatenate(
322
330
  [f(mol) for f in self._molecule_features], axis=-1
323
331
  )
324
- return context_feature.astype(self.feature_dtype)
332
+ return molecule_feature.astype(self.feature_dtype)
325
333
 
326
334
  def num_atoms(self, mol: chem.Mol) -> np.ndarray:
327
335
  return np.asarray(mol.num_atoms, dtype=self.index_dtype)
@@ -361,9 +369,7 @@ class MolGraphFeaturizer(Featurizer):
361
369
  ) -> tuple[dict[str, np.ndarray]]:
362
370
  num_super_nodes = 1
363
371
  num_nodes = node['feature'].shape[0]
364
- node = _add_super_nodes(
365
- node, num_super_nodes, self.feature_dtype
366
- )
372
+ node = _add_super_nodes(node, num_super_nodes)
367
373
  edge = _add_super_edges(
368
374
  edge, num_nodes, num_super_nodes, self.feature_dtype, self.index_dtype
369
375
  )
@@ -402,6 +408,7 @@ class MolGraphFeaturizer(Featurizer):
402
408
  return cls(**config)
403
409
 
404
410
 
411
+ @keras.saving.register_keras_serializable(package='molcraft')
405
412
  class MolGraphFeaturizer3D(MolGraphFeaturizer):
406
413
 
407
414
  """Molecular 3d-graph featurizer.
@@ -543,7 +550,7 @@ class MolGraphFeaturizer3D(MolGraphFeaturizer):
543
550
  'of the `Featurizer` or input a 3D representation of the molecule. '
544
551
  )
545
552
 
546
- context_feature = self.context_feature(mol)
553
+ molecule_feature = self.molecule_feature(mol)
547
554
  molecule_size = self.num_atoms(mol) + int(self.super_atom)
548
555
 
549
556
  if isinstance(context, dict):
@@ -565,8 +572,14 @@ class MolGraphFeaturizer3D(MolGraphFeaturizer):
565
572
  else:
566
573
  context = {'size': molecule_size}
567
574
 
568
- if context_feature is not None:
569
- context['feature'] = context_feature
575
+ if molecule_feature is not None:
576
+ if 'feature' in context:
577
+ warn(
578
+ 'Found both inputted and computed context feature. '
579
+ 'Overwriting inputted context feature with computed '
580
+ 'context feature (based on `molecule_features`).'
581
+ )
582
+ context['feature'] = molecule_feature
570
583
 
571
584
  node = {}
572
585
  node['feature'] = self.atom_features(mol)
@@ -586,7 +599,7 @@ class MolGraphFeaturizer3D(MolGraphFeaturizer):
586
599
  radius=self.radius,
587
600
  sparse=False,
588
601
  self_loops=self.self_loops,
589
- dtype=np.bool
602
+ dtype=bool
590
603
  )
591
604
  edge_conformer['source'], edge_conformer['target'] = np.where(adjacency_matrix)
592
605
  edge_conformer['source'] = edge_conformer['source'].astype(self.index_dtype)
@@ -603,7 +616,7 @@ class MolGraphFeaturizer3D(MolGraphFeaturizer):
603
616
  )
604
617
  node_conformer['coordinate'] = np.concatenate(
605
618
  [node_conformer['coordinate'], conformer.centroid[None]], axis=0
606
- )
619
+ ).astype(self.feature_dtype)
607
620
  tensor_list.append(
608
621
  tensors.GraphTensor(context, node_conformer, edge_conformer)
609
622
  )
@@ -627,7 +640,7 @@ class MolGraphFeaturizer3D(MolGraphFeaturizer):
627
640
  config['conformer_generator'] = keras.saving.deserialize_keras_object(
628
641
  config['conformer_generator']
629
642
  )
630
- return super().from_config(**config)
643
+ return super().from_config(config)
631
644
 
632
645
 
633
646
  def save_featurizer(
@@ -669,12 +682,15 @@ def load_featurizer(
669
682
  def _add_super_nodes(
670
683
  node: dict[str, np.ndarray],
671
684
  num_super_nodes: int = 1,
672
- feature_dtype: str = 'float32',
673
685
  ) -> dict[str, np.ndarray]:
674
686
  node = copy.deepcopy(node)
675
- node['super'] = np.array([False] * len(node['feature']) + [True] * num_super_nodes)
687
+ node['super'] = np.array(
688
+ [False] * len(node['feature']) + [True] * num_super_nodes,
689
+ dtype=bool
690
+ )
676
691
  super_node_feature = np.zeros(
677
- [num_super_nodes, node['feature'].shape[-1]], dtype=feature_dtype
692
+ [num_super_nodes, node['feature'].shape[-1]],
693
+ dtype=node['feature'].dtype
678
694
  )
679
695
  node['feature'] = np.concatenate([node['feature'], super_node_feature])
680
696
  return node
@@ -694,31 +710,38 @@ def _add_super_edges(
694
710
  np.tile(np.arange(num_nodes), [num_super_nodes])
695
711
  )
696
712
  edge['source'] = np.concatenate(
697
- [
698
- edge['source'],
699
- node_indices,
700
- super_node_indices,
701
- ]
702
- )
703
- edge['source'] = edge['source'].astype(index_dtype)
713
+ [edge['source'], node_indices, super_node_indices]
714
+ ).astype(index_dtype)
715
+
704
716
  edge['target'] = np.concatenate(
705
- [
706
- edge['target'],
707
- super_node_indices,
708
- node_indices
709
- ]
710
- )
711
- edge['target'] = edge['target'].astype(index_dtype)
717
+ [edge['target'], super_node_indices, node_indices]
718
+ ).astype(index_dtype)
719
+
712
720
  if 'feature' in edge:
713
- edge['super'] = np.asarray([False] * edge['feature'].shape[0] + [True] * (num_super_nodes * num_nodes * 2))
714
- edge['feature'] = np.concatenate([edge['feature'], np.zeros((num_super_nodes * num_nodes * 2, edge['feature'].shape[-1]))])
721
+ num_edges = int(edge['feature'].shape[0])
722
+ num_super_edges = int(num_super_nodes * num_nodes * 2)
723
+ edge['super'] = np.asarray(
724
+ ([False] * num_edges + [True] * num_super_edges),
725
+ dtype=bool
726
+ )
727
+ edge['feature'] = np.concatenate(
728
+ [
729
+ edge['feature'],
730
+ np.zeros(
731
+ shape=(num_super_edges, edge['feature'].shape[-1]),
732
+ dtype=edge['feature'].dtype
733
+ )
734
+ ]
735
+ )
736
+
715
737
  if 'length' in edge:
716
738
  edge['length'] = np.pad(edge['length'], [(0, 0), (1, 0)])
717
- zero_array = np.zeros((num_nodes * num_super_nodes * 2,), dtype='int32')
739
+ zero_array = np.zeros([num_nodes * num_super_nodes * 2], dtype='int32')
718
740
  edge_length_dim = edge['length'].shape[1]
719
741
  virtual_edge_length = np.eye(edge_length_dim)[zero_array]
720
742
  edge['length'] = np.concatenate([edge['length'], virtual_edge_length])
721
743
  edge['length'] = edge['length'].astype(feature_dtype)
744
+
722
745
  return edge
723
746
 
724
747