molcraft 0.1.0a16__py3-none-any.whl → 0.1.0a18__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of molcraft might be problematic. Click here for more details.

molcraft/featurizers.py CHANGED
@@ -14,63 +14,47 @@ from pathlib import Path
14
14
  from molcraft import tensors
15
15
  from molcraft import features
16
16
  from molcraft import chem
17
- from molcraft import conformers
18
17
  from molcraft import descriptors
19
18
 
20
19
 
21
20
  @keras.saving.register_keras_serializable(package='molcraft')
22
- class Featurizer(abc.ABC):
21
+ class GraphFeaturizer(abc.ABC):
23
22
 
24
- """Base class for featurizers.
23
+ """Base graph featurizer.
25
24
  """
26
25
 
27
26
  @abc.abstractmethod
28
- def call(
29
- self,
30
- x: tensors.GraphTensor
31
- ) -> tensors.GraphTensor | list[tensors.GraphTensor]:
27
+ def call(self, x: str | chem.Mol | tuple) -> tensors.GraphTensor:
32
28
  pass
33
29
 
34
- @abc.abstractmethod
35
- def stack(
36
- self,
37
- call_outputs: list[tensors.GraphTensor]
38
- ) -> tensors.GraphTensor:
39
- pass
40
-
41
30
  def get_config(self) -> dict:
42
31
  return {}
43
32
 
44
33
  @classmethod
45
- def from_config(cls, config: dict) -> 'Featurizer':
34
+ def from_config(cls, config: dict) -> 'GraphFeaturizer':
46
35
  return cls(**config)
47
36
 
48
37
  def save(self, filepath: str | Path, *args, **kwargs) -> None:
49
- save_featurizer(
50
- self, filepath, *args, **kwargs
51
- )
38
+ save_featurizer(self, filepath, *args, **kwargs)
52
39
 
53
40
  @staticmethod
54
- def load(filepath: str | Path, *args, **kwargs) -> 'Featurizer':
55
- return load_featurizer(
56
- filepath, *args, **kwargs
57
- )
41
+ def load(filepath: str | Path, *args, **kwargs) -> 'GraphFeaturizer':
42
+ return load_featurizer(filepath, *args, **kwargs)
58
43
 
59
44
  def __call__(
60
45
  self,
61
- inputs: str | tuple | list | np.ndarray | pd.DataFrame | pd.Series,
46
+ inputs: str | chem.Mol | tuple | typing.Iterable,
62
47
  *,
63
48
  multiprocessing: bool = False,
64
49
  processes: int | None = None,
65
50
  device: str = '/cpu:0',
66
- **kwargs
67
51
  ) -> tensors.GraphTensor:
68
52
  if isinstance(inputs, (str, tuple)):
69
53
  return self.call(inputs)
70
54
  if isinstance(inputs, (pd.DataFrame, pd.Series)):
71
- inputs = inputs.values
72
- if isinstance(inputs, np.ndarray):
73
- inputs = list(inputs)
55
+ inputs = inputs.values.tolist()
56
+ elif isinstance(inputs, np.ndarray):
57
+ inputs = inputs.tolist()
74
58
  if not multiprocessing:
75
59
  outputs = [self.call(x) for x in inputs]
76
60
  else:
@@ -78,19 +62,19 @@ class Featurizer(abc.ABC):
78
62
  with mp.Pool(processes) as pool:
79
63
  outputs = pool.map(func=self.call, iterable=inputs)
80
64
  outputs = [x for x in outputs if x is not None]
81
- return self.stack(outputs)
65
+ if tensors.is_scalar(outputs[0]):
66
+ return tf.stack(outputs, axis=0)
67
+ return tf.concat(outputs, axis=0)
82
68
 
83
69
 
84
70
  @keras.saving.register_keras_serializable(package='molcraft')
85
- class MolGraphFeaturizer(Featurizer):
71
+ class MolGraphFeaturizer(GraphFeaturizer):
86
72
 
87
73
  """Molecular graph featurizer.
88
74
 
89
75
  Converts SMILES or InChI strings to a molecular graph.
90
76
 
91
77
  The molecular graph may encode a single molecule or a batch of molecules.
92
- In either case, it is a single graph, with each molecule corresponding to
93
- a subgraph within the graph.
94
78
 
95
79
  Example:
96
80
 
@@ -99,10 +83,14 @@ class MolGraphFeaturizer(Featurizer):
99
83
  >>> featurizer = molcraft.featurizers.MolGraphFeaturizer(
100
84
  ... atom_features=[
101
85
  ... molcraft.features.AtomType(),
102
- ... molcraft.features.TotalNumHs(),
86
+ ... molcraft.features.NumHydrogens(),
103
87
  ... molcraft.features.Degree(),
104
88
  ... ],
105
- ... radius=1
89
+ ... bond_features=[
90
+ ... molcraft.features.BondType(),
91
+ ... ],
92
+ ... super_node=False,
93
+ ... self_loops=False,
106
94
  ... )
107
95
  >>>
108
96
  >>> graph = featurizer(["N[C@@H](C)C(=O)O", "N[C@@H](CS)C(=O)O"])
@@ -112,7 +100,7 @@ class MolGraphFeaturizer(Featurizer):
112
100
  'size': <tf.Tensor: shape=[2], dtype=int32>
113
101
  },
114
102
  node={
115
- 'feature': <tf.Tensor: shape=[13, 133], dtype=float32>
103
+ 'feature': <tf.Tensor: shape=[13, 129], dtype=float32>
116
104
  },
117
105
  edge={
118
106
  'source': <tf.Tensor: shape=[22], dtype=int32>,
@@ -123,73 +111,46 @@ class MolGraphFeaturizer(Featurizer):
123
111
 
124
112
  Args:
125
113
  atom_features:
126
- A list of `features.Feature` encoding the nodes of the molecular graph.
114
+ A list of `features.Feature` encoded as the node features.
127
115
  bond_features:
128
- A list of `features.Feature` encoding the edges of the molecular graph.
116
+ A list of `features.Feature` encoded as the edge features.
129
117
  molecule_features:
130
- A `features.Feature` encoding the molecule (or `context`) of the graph.
131
- If `contextual_super_atom` is set to `True`, then this feature will be
132
- embedded, via `NodeEmbedding`, as a super node in the molecular graph.
133
- super_atom:
134
- A boolean specifying whether super atoms exist and should be embedded
135
- via `NodeEmbedding`.
136
- radius:
137
- An integer specifying how many bond lengths should be considered as an
138
- edge. The default is None (or 1), which specifies that only bonds should
139
- be considered an edge.
118
+ A list of `descriptors.Descriptor` encoded as the context feature.
119
+ super_node:
120
+ A boolean specifying whether to include a super node.
140
121
  self_loops:
141
- A boolean specifying whether self loops exist. If True, this means that
142
- each node (atom) has an edge (bond) to itself.
143
- include_hs:
122
+ A boolean specifying whether self loops exist.
123
+ include_hydrogens:
144
124
  A boolean specifying whether hydrogens should be encoded as nodes.
145
125
  """
146
126
 
147
127
  def __init__(
148
128
  self,
149
- atom_features: list[features.Feature] | str | None = 'auto',
129
+ atom_features: list[features.Feature] | str = 'auto',
150
130
  bond_features: list[features.Feature] | str | None = 'auto',
151
- molecule_features: features.Feature | str | None = None,
152
- super_atom: bool = False,
153
- radius: int | float | None = None,
131
+ molecule_features: list[descriptors.Descriptor] | str | None = None,
132
+ super_node: bool = False,
154
133
  self_loops: bool = False,
155
- include_hs: bool = False,
156
- **kwargs,
134
+ include_hydrogens: bool = False,
157
135
  ) -> None:
158
- if molecule_features is None:
159
- molecule_features = kwargs.pop('mol_features', None)
160
-
161
- self.radius = int(max(radius or 1, 1))
162
- self.include_hs = include_hs
163
- self.self_loops = self_loops
164
- self.super_atom = super_atom
165
-
166
- default_atom_features = (
136
+ use_default_atom_features = (
167
137
  atom_features == 'auto' or atom_features == 'default'
168
138
  )
169
- if default_atom_features:
170
- atom_features = [features.AtomType()]
171
- if not self.include_hs:
172
- atom_features.append(features.NumHydrogens())
173
- atom_features.append(features.Degree())
174
- if not isinstance(self, MolGraphFeaturizer3D):
175
- default_bond_features = (
176
- bond_features == 'auto' or bond_features == 'default'
177
- )
178
- if default_bond_features or self.radius > 1:
179
- vocab = ['zero', 'single', 'double', 'triple', 'aromatic']
180
- bond_features = [
181
- features.BondType(vocab)
182
- ]
183
- if not default_bond_features and self.radius > 1:
184
- warnings.warn(
185
- 'Replacing user-specified bond features with default bond features, '
186
- 'as `radius`>1. When `radius`>1, only bond types are considered.',
187
- stacklevel=2
188
- )
189
- default_molecule_features = (
139
+ if use_default_atom_features:
140
+ atom_features = [features.AtomType(), features.Degree()]
141
+ if not include_hydrogens:
142
+ atom_features += [features.NumHydrogens()]
143
+
144
+ use_default_bond_features = (
145
+ bond_features == 'auto' or bond_features == 'default'
146
+ )
147
+ if use_default_bond_features:
148
+ bond_features = [features.BondType()]
149
+
150
+ use_default_molecule_features = (
190
151
  molecule_features == 'auto' or molecule_features == 'default'
191
152
  )
192
- if default_molecule_features:
153
+ if use_default_molecule_features:
193
154
  molecule_features = [
194
155
  descriptors.MolWeight(),
195
156
  descriptors.TotalPolarSurfaceArea(),
@@ -202,182 +163,65 @@ class MolGraphFeaturizer(Featurizer):
202
163
  descriptors.NumRotatableBonds(),
203
164
  descriptors.NumRings(),
204
165
  ]
166
+
205
167
  self._atom_features = atom_features
206
168
  self._bond_features = bond_features
207
169
  self._molecule_features = molecule_features
208
- self.feature_dtype = 'float32'
209
- self.index_dtype = 'int32'
170
+ self._include_hydrogens = include_hydrogens
171
+ self._self_loops = self_loops
172
+ self._super_node = super_node
210
173
 
211
174
  def call(self, inputs: str | tuple) -> tensors.GraphTensor:
212
- if isinstance(inputs, (tuple, list, np.ndarray)):
213
- x, *context = inputs
214
- if len(context) and isinstance(context[0], dict):
215
- context = copy.deepcopy(context[0])
216
- else:
217
- x, context = inputs, None
175
+
176
+ if isinstance(inputs, str):
177
+ inputs = (inputs,)
218
178
 
219
- mol = chem.Mol.from_encoding(x, explicit_hs=self.include_hs)
179
+ inputs, *context_inputs = inputs
220
180
 
221
- if mol is None:
222
- warnings.warn(
223
- f'Could not obtain `chem.Mol` from {x}. '
224
- 'Returning `None` (proceeding without it).',
225
- stacklevel=2
226
- )
227
- return None
228
-
229
- atom_feature = self.atom_features(mol)
230
- bond_feature = self.bond_features(mol)
231
- molecule_feature = self.molecule_feature(mol)
232
- molecule_size = self.num_atoms(mol)
233
-
234
- if isinstance(context, dict):
235
- if 'x' in context:
236
- context['feature'] = context.pop('x')
237
- if 'y' in context:
238
- context['label'] = context.pop('y')
239
- if 'sample_weight' in context:
240
- context['weight'] = context.pop('sample_weight')
241
- context = {
242
- **{'size': molecule_size},
243
- **context
244
- }
245
- elif isinstance(context, list):
246
- context = {
247
- **{'size': molecule_size},
248
- **{key: value for (key, value) in zip(['label', 'weight'], context)}
249
- }
250
- else:
251
- context = {'size': molecule_size}
252
-
253
- if molecule_feature is not None:
254
- if 'feature' in context:
255
- warnings.warn(
256
- 'Found both inputted and computed context feature. '
257
- 'Overwriting inputted context feature with computed '
258
- 'context feature (based on `molecule_features`).',
259
- stacklevel=2
260
- )
261
- context['feature'] = molecule_feature
262
-
263
- node = {}
264
- node['feature'] = atom_feature
181
+ mol = chem.Mol.from_encoding(inputs, explicit_hs=self._include_hydrogens)
182
+
183
+ data = {'context': {}, 'node': {}, 'edge': {}}
265
184
 
266
- edge = {}
267
- if self.radius == 1:
268
- edge['source'], edge['target'] = mol.adjacency(
269
- fill='full', sparse=True, self_loops=self.self_loops, dtype=self.index_dtype
270
- )
271
- if self.self_loops:
272
- bond_feature = np.pad(bond_feature, [(0, 1), (0, 0)])
273
- if bond_feature is not None:
274
- bond_indices = []
275
- for atom_i, atom_j in zip(edge['source'], edge['target']):
276
- if atom_i == atom_j:
277
- bond_indices.append(-1)
278
- else:
279
- bond_indices.append(
280
- mol.get_bond_between_atoms(atom_i, atom_j).index
281
- )
282
- edge['feature'] = bond_feature[bond_indices]
283
- else:
284
- paths = chem.get_shortest_paths(
285
- mol, radius=self.radius, self_loops=self.self_loops
286
- )
287
- edge['source'] = np.asarray(
288
- [path[0] for path in paths], dtype=self.index_dtype
289
- )
290
- edge['target'] = np.asarray(
291
- [path[-1] for path in paths], dtype=self.index_dtype
185
+ data['context']['size'] = np.asarray(mol.num_atoms)
186
+
187
+ if len(context_inputs) == 1:
188
+ data['context']['label'] = np.asarray(context_inputs[0])
189
+ elif len(context_inputs) == 2:
190
+ data['context']['label'] = np.asarray(context_inputs[0])
191
+ data['context']['weight'] = np.asarray(context_inputs[1])
192
+
193
+ if self._molecule_features is not None:
194
+ data['context']['feature'] = np.concatenate(
195
+ [f(mol) for f in self._molecule_features], axis=-1
292
196
  )
293
- if bond_feature is not None:
294
- zero_bond_feature = np.array(
295
- [[1., 0., 0., 0., 0.]], dtype=bond_feature.dtype
296
- )
297
- bond_feature = np.concatenate(
298
- [bond_feature, zero_bond_feature], axis=0
299
- )
300
- edge['feature'] = self._expand_bond_features(
301
- mol, paths, bond_feature,
302
- )
303
-
304
- if self.super_atom:
305
- node, edge = self._add_super_atom(node, edge)
306
- context['size'] += 1
307
-
308
- return tensors.GraphTensor(context, node, edge)
309
-
310
- def stack(self, outputs):
311
- if tensors.is_scalar(outputs[0]):
312
- return tf.stack(outputs, axis=0)
313
- return tf.concat(outputs, axis=0)
314
-
315
- def atom_features(self, mol: chem.Mol) -> np.ndarray:
316
- atom_feature: np.ndarray = np.concatenate(
197
+
198
+ data['node']['feature'] = np.concatenate(
317
199
  [f(mol) for f in self._atom_features], axis=-1
318
200
  )
319
- return atom_feature.astype(self.feature_dtype)
320
201
 
321
- def bond_features(self, mol: chem.Mol) -> np.ndarray:
322
- if self._bond_features is None:
323
- return None
324
- bond_feature: np.ndarray = np.concatenate(
325
- [f(mol) for f in self._bond_features], axis=-1
326
- )
327
- return bond_feature.astype(self.feature_dtype)
328
-
329
- def molecule_feature(self, mol: chem.Mol) -> np.ndarray:
330
- if self._molecule_features is None:
331
- return None
332
- molecule_feature: np.ndarray = np.concatenate(
333
- [f(mol) for f in self._molecule_features], axis=-1
202
+ data['edge']['source'], data['edge']['target'] = mol.adjacency(
203
+ fill='full', sparse=True, self_loops=self._self_loops
334
204
  )
335
- return molecule_feature.astype(self.feature_dtype)
205
+
206
+ if self._bond_features is not None:
207
+ bond_features = np.concatenate(
208
+ [f(mol) for f in self._bond_features], axis=-1
209
+ )
210
+ if self._self_loops:
211
+ bond_features = np.pad(bond_features, [(0, 1), (0, 0)])
336
212
 
337
- def num_atoms(self, mol: chem.Mol) -> np.ndarray:
338
- return np.asarray(mol.num_atoms, dtype=self.index_dtype)
339
-
340
- def num_bonds(self, mol: chem.Mol) -> np.ndarray:
341
- return np.asarray(mol.num_bonds, dtype=self.index_dtype)
342
-
343
- def _expand_bond_features(
344
- self,
345
- mol: chem.Mol,
346
- paths: list[list[int]],
347
- bond_feature: np.ndarray,
348
- ) -> np.ndarray:
349
-
350
- def bond_feature_lookup(path):
351
- path_bond_indices = [
352
- mol.get_bond_between_atoms(path[i], path[i + 1]).index
353
- for i in range(len(path) - 1)
213
+ bond_indices = [
214
+ mol.get_bond_between_atoms(i, j).index if (i != j) else -1
215
+ for (i, j) in zip(data['edge']['source'], data['edge']['target'])
354
216
  ]
355
- padding = [-1] * (self.radius - len(path) + 1)
356
- path_bond_indices += padding
357
- return bond_feature[path_bond_indices].reshape(-1)
358
-
359
- edge_feature = np.asarray(
360
- [
361
- bond_feature_lookup(path) for path in paths
362
- ],
363
- dtype=self.feature_dtype
364
- ).reshape((-1, bond_feature.shape[-1] * self.radius))
365
-
366
- return edge_feature
367
-
368
- def _add_super_atom(
369
- self,
370
- node: dict[str, np.ndarray],
371
- edge: dict[str, np.ndarray],
372
- ) -> tuple[dict[str, np.ndarray]]:
373
- num_super_nodes = 1
374
- num_nodes = node['feature'].shape[0]
375
- node = _add_super_nodes(node, num_super_nodes)
376
- edge = _add_super_edges(
377
- edge, num_nodes, num_super_nodes, self.feature_dtype, self.index_dtype, self.self_loops
378
- )
379
- return node, edge
380
217
 
218
+ data['edge']['feature'] = bond_features[bond_indices]
219
+
220
+ if self._super_node:
221
+ data = _add_super_node(data)
222
+
223
+ return tensors.GraphTensor(**_convert_dtypes(data))
224
+
381
225
  def get_config(self):
382
226
  config = super().get_config()
383
227
  config.update({
@@ -390,10 +234,9 @@ class MolGraphFeaturizer(Featurizer):
390
234
  'molecule_features': keras.saving.serialize_keras_object(
391
235
  self._molecule_features
392
236
  ),
393
- 'super_atom': self.super_atom,
394
- 'radius': self.radius,
395
- 'self_loops': self.self_loops,
396
- 'include_hs': self.include_hs,
237
+ 'super_node': self._super_node,
238
+ 'self_loops': self._self_loops,
239
+ 'include_hydrogens': self._include_hydrogens,
397
240
  })
398
241
  return config
399
242
 
@@ -414,15 +257,11 @@ class MolGraphFeaturizer(Featurizer):
414
257
  @keras.saving.register_keras_serializable(package='molcraft')
415
258
  class MolGraphFeaturizer3D(MolGraphFeaturizer):
416
259
 
417
- """Molecular 3d-graph featurizer.
260
+ """3D Molecular graph featurizer.
418
261
 
419
- Converts SMILES or InChI strings to a molecular graph in 3d space.
420
- Namely, in addition to the information encoded in a standard molecular
421
- graph, cartesian coordinates are also included.
262
+ Converts SMILES or InChI strings to a 3d molecular graph.
422
263
 
423
264
  The molecular graph may encode a single molecule or a batch of molecules.
424
- In either case, it is a single graph, with each molecule corresponding to
425
- a subgraph within the graph.
426
265
 
427
266
  Example:
428
267
 
@@ -431,226 +270,166 @@ class MolGraphFeaturizer3D(MolGraphFeaturizer):
431
270
  >>> featurizer = molcraft.featurizers.MolGraphFeaturizer3D(
432
271
  ... atom_features=[
433
272
  ... molcraft.features.AtomType(),
434
- ... molcraft.features.TotalNumHs(),
273
+ ... molcraft.features.NumHydrogens(),
435
274
  ... molcraft.features.Degree(),
436
275
  ... ],
437
- ... radius=5.0
276
+ ... radius=5.0,
277
+ ... random_seed=42,
438
278
  ... )
439
279
  >>>
440
280
  >>> graph = featurizer(["N[C@@H](C)C(=O)O", "N[C@@H](CS)C(=O)O"])
441
281
  >>> graph
442
282
  GraphTensor(
443
283
  context={
444
- 'size': <tf.Tensor: shape=[20], dtype=int32>
284
+ 'size': <tf.Tensor: shape=[2], dtype=int32>
445
285
  },
446
286
  node={
447
- 'feature': <tf.Tensor: shape=[130, 133], dtype=float32>,
448
- 'coordinate': <tf.Tensor: shape=[130, 3], dtype=float32>
287
+ 'feature': <tf.Tensor: shape=[13, 129], dtype=float32>,
288
+ 'coordinate': <tf.Tensor: shape=[13, 3], dtype=float32>
449
289
  },
450
290
  edge={
451
- 'source': <tf.Tensor: shape=[714], dtype=int32>,
452
- 'target': <tf.Tensor: shape=[714], dtype=int32>,
453
- 'feature': <tf.Tensor: shape=[714, 23], dtype=float32>
291
+ 'source': <tf.Tensor: shape=[72], dtype=int32>,
292
+ 'target': <tf.Tensor: shape=[72], dtype=int32>,
293
+ 'feature': <tf.Tensor: shape=[72, 12], dtype=float32>
454
294
  }
455
295
  )
456
296
 
457
297
  Args:
458
298
  atom_features:
459
- A list of `features.Feature` encoding the nodes of the molecular graph.
460
- bond_features:
461
- A list of `features.Feature` encoding the edges of the molecular graph.
299
+ A list of `features.Feature` encoded as the node features.
300
+ pair_features:
301
+ A list of `features.PairFeature` encoded as the edge features.
462
302
  molecule_features:
463
- A `features.Feature` encoding the molecule (or `context`) of the graph.
464
- If `contextual_super_atom` is set to `True`, then this feature will be
465
- embedded, via `NodeEmbedding`, as a super node in the molecular graph.
466
- conformer_generator:
467
- A `conformers.ConformerGenerator` which produces conformers. If `auto`
468
- a `conformers.ConformerEmbedder` will be used. If None, it is assumed
469
- that the molecule already has conformer(s).
470
- super_atom:
471
- A boolean specifying whether super atoms exist and should be embedded
472
- via `NodeEmbedding`.
473
- radius:
474
- A float specifying, for each atom, the maximum distance (in angstroms)
475
- that another atom should be within to be considered an edge. Default
476
- is set to 6.0 as this should cover most interactions. This parameter
477
- can be though of as the receptive field. If None, the radius will be
478
- infinite so all the receptive field will be the entire space (graph).
303
+ A list of `descriptors.Descriptor` encoded as the context feature.
304
+ super_node:
305
+ A boolean specifying whether to include a super node.
479
306
  self_loops:
480
- A boolean specifying whether self loops exist. If True, this means that
481
- each node (atom) has an edge (bond) to itself.
482
- include_hs:
307
+ A boolean specifying whether self loops exist.
308
+ include_hydrogens:
483
309
  A boolean specifying whether hydrogens should be encoded as nodes.
310
+ radius:
311
+ A floating point value specifying maximum edge length.
312
+ random_seed:
313
+ An integer specifying the random seed for the conformer generation.
484
314
  """
485
315
 
486
316
  def __init__(
487
317
  self,
488
- atom_features: list[features.Feature] | str | None = 'auto',
489
- bond_features: list[features.Feature] | str | None = 'auto',
490
- molecule_features: features.Feature | str = None,
491
- conformer_generator: conformers.ConformerProcessor | str | None = 'auto',
492
- super_atom: bool = False,
493
- radius: int | float | None = 6.0,
318
+ atom_features: list[features.Feature] | str = 'auto',
319
+ pair_features: list[features.PairFeature] | str = 'auto',
320
+ molecule_features: features.Feature | str | None = None,
321
+ super_node: bool = False,
494
322
  self_loops: bool = False,
495
- include_hs: bool = False,
496
- **kwargs,
323
+ include_hydrogens: bool = False,
324
+ radius: int | float | None = 6.0,
325
+ random_seed: int | None = None,
497
326
  ) -> None:
498
- if bond_features == 'auto':
499
- bond_features = [
500
- features.Distance()
501
- ]
502
327
  super().__init__(
503
328
  atom_features=atom_features,
504
- bond_features=bond_features,
329
+ bond_features=None,
505
330
  molecule_features=molecule_features,
506
- super_atom=super_atom,
507
- radius=radius,
331
+ super_node=super_node,
508
332
  self_loops=self_loops,
509
- include_hs=include_hs,
510
- **kwargs,
333
+ include_hydrogens=include_hydrogens,
511
334
  )
512
- if conformer_generator == 'auto':
513
- conformer_generator = conformers.ConformerGenerator(
514
- steps=[
515
- conformers.ConformerEmbedder(
516
- method='ETKDGv3',
517
- num_conformers=5
518
- ),
519
- ]
520
- )
521
- self.conformer_generator = conformer_generator
522
- self.embed_conformer = self.conformer_generator is not None
523
- self.radius = float(radius) if radius else None
524
-
335
+
336
+ use_default_pair_features = (
337
+ pair_features == 'auto' or pair_features == 'default'
338
+ )
339
+ if use_default_pair_features:
340
+ pair_features = [features.PairDistance()]
341
+
342
+ self._pair_features = pair_features
343
+ self._radius = float(radius) if radius else None
344
+ self._random_seed = random_seed
345
+
525
346
  def call(self, inputs: str | tuple) -> tensors.GraphTensor:
526
347
 
527
- if isinstance(inputs, (tuple, list, np.ndarray)):
528
- x, *context = inputs
529
- if len(context) and isinstance(context[0], dict):
530
- context = copy.deepcopy(context[0])
531
- else:
532
- x, context = inputs, None
348
+ if isinstance(inputs, str):
349
+ inputs = [inputs]
533
350
 
534
- explicit_hs = (self.include_hs or self.embed_conformer)
535
- mol = chem.Mol.from_encoding(x, explicit_hs=explicit_hs)
536
-
537
- if mol is None:
538
- warnings.warn(
539
- f'Could not obtain `chem.Mol` from {x}. '
540
- 'Proceeding without it.',
541
- stacklevel=2
542
- )
543
- return None
351
+ inputs, *context_inputs = inputs
544
352
 
545
- if self.embed_conformer:
546
- mol = self.conformer_generator(mol)
547
- if not self.include_hs:
548
- mol = chem.remove_hs(mol)
353
+ mol = chem.Mol.from_encoding(inputs, explicit_hs=True)
549
354
 
550
355
  if mol.num_conformers == 0:
551
- raise ValueError(
552
- 'Cannot featurize a molecule without conformer(s). '
553
- 'Make sure to pass a `ConformerGenerator` to the constructor '
554
- 'of the `Featurizer` or input a 3D representation of the molecule. '
356
+ mol = chem.embed_conformers(
357
+ mol, num_conformers=1, random_seed=self._random_seed
555
358
  )
556
359
 
557
- molecule_feature = self.molecule_feature(mol)
558
- molecule_size = self.num_atoms(mol) + int(self.super_atom)
559
- molecule_size = molecule_size.astype(self.index_dtype)
560
-
561
- if isinstance(context, dict):
562
- if 'x' in context:
563
- context['feature'] = context.pop('x')
564
- if 'y' in context:
565
- context['label'] = context.pop('y')
566
- if 'sample_weight' in context:
567
- context['weight'] = context.pop('sample_weight')
568
- context = {
569
- **{'size': molecule_size},
570
- **context
571
- }
572
- elif isinstance(context, list):
573
- context = {
574
- **{'size': molecule_size},
575
- **{key: value for (key, value) in zip(['label', 'weight'], context)}
576
- }
577
- else:
578
- context = {'size': molecule_size}
579
-
580
- if molecule_feature is not None:
581
- if 'feature' in context:
582
- warnings.warn(
583
- 'Found both inputted and computed context feature. '
584
- 'Overwriting inputted context feature with computed '
585
- 'context feature (based on `molecule_features`).',
586
- stacklevel=2
587
- )
588
- context['feature'] = molecule_feature
589
-
590
- node = {}
591
- node['feature'] = self.atom_features(mol)
592
-
593
- if self._bond_features:
594
- edge_feature = self.bond_features(mol)
595
-
596
- edge = {}
597
- mols = chem.unpack_conformers(mol)
598
- tensor_list = []
599
- for i, mol in enumerate(mols):
600
- node_conformer = copy.deepcopy(node)
601
- edge_conformer = copy.deepcopy(edge)
602
- conformer = mol.get_conformer()
603
- adjacency_matrix = conformer.adjacency(
604
- fill='full',
605
- radius=self.radius,
606
- sparse=False,
607
- self_loops=self.self_loops,
608
- dtype=bool
360
+ if not self._include_hydrogens:
361
+ mol = chem.remove_hs(mol)
362
+
363
+ data = {'context': {}, 'node': {}, 'edge': {}}
364
+
365
+ data['context']['size'] = np.asarray(mol.num_atoms)
366
+
367
+ if len(context_inputs) == 1:
368
+ data['context']['label'] = np.asarray(context_inputs[0])
369
+ elif len(context_inputs) == 2:
370
+ data['context']['label'] = np.asarray(context_inputs[0])
371
+ data['context']['weight'] = np.asarray(context_inputs[1])
372
+
373
+ if self._molecule_features is not None:
374
+ data['context']['feature'] = np.concatenate(
375
+ [f(mol) for f in self._molecule_features], axis=-1
609
376
  )
610
- edge_conformer['source'], edge_conformer['target'] = np.where(adjacency_matrix)
611
- edge_conformer['source'] = edge_conformer['source'].astype(self.index_dtype)
612
- edge_conformer['target'] = edge_conformer['target'].astype(self.index_dtype)
613
- node_conformer['coordinate'] = conformer.coordinates.astype(self.feature_dtype)
614
-
615
- if self._bond_features:
616
- edge_feature_keep = adjacency_matrix.reshape(-1)
617
- edge_conformer['feature'] = edge_feature[edge_feature_keep]
618
-
619
- if self.super_atom:
620
- node_conformer, edge_conformer = self._add_super_atom(
621
- node_conformer, edge_conformer
622
- )
623
- node_conformer['coordinate'] = np.concatenate(
624
- [node_conformer['coordinate'], conformer.centroid[None]], axis=0
625
- ).astype(self.feature_dtype)
626
- tensor_list.append(
627
- tensors.GraphTensor(context, node_conformer, edge_conformer)
377
+
378
+ conformer = mol.get_conformer()
379
+
380
+ data['node']['feature'] = np.concatenate(
381
+ [f(mol) for f in self._atom_features], axis=-1
382
+ )
383
+ data['node']['coordinate'] = conformer.coordinates
384
+
385
+ adjacency_matrix = conformer.adjacency(
386
+ fill='full', radius=self._radius, sparse=False, self_loops=self._self_loops,
387
+ )
388
+
389
+ data['edge']['source'], data['edge']['target'] = np.where(adjacency_matrix)
390
+
391
+ if self._pair_features is not None:
392
+ pair_features = np.concatenate(
393
+ [f(mol) for f in self._pair_features], axis=-1
394
+ )
395
+ pair_keep = adjacency_matrix.reshape(-1).astype(bool)
396
+ data['edge']['feature'] = pair_features[pair_keep]
397
+
398
+ if self._super_node:
399
+ data = _add_super_node(data)
400
+ data['node']['coordinate'] = np.concatenate(
401
+ [data['node']['coordinate'], conformer.centroid[None]], axis=0
628
402
  )
629
-
630
- return tensor_list
403
+
404
+ return tensors.GraphTensor(**_convert_dtypes(data))
631
405
 
632
- def stack(self, outputs):
633
- # Flatten list of lists (due to multiple conformers per molecule)
634
- outputs = [x for xs in outputs for x in xs]
635
- return super().stack(outputs)
406
+ @property
407
+ def random_seed(self) -> int | None:
408
+ return self._random_seed
636
409
 
410
+ @random_seed.setter
411
+ def random_seed(self, value: int) -> None:
412
+ self._random_seed = value
413
+
637
414
  def get_config(self):
638
415
  config = super().get_config()
639
- config['conformer_generator'] = keras.saving.serialize_keras_object(
640
- self.conformer_generator
416
+ config['radius'] = self._radius
417
+ config['pair_features'] = keras.saving.serialize_keras_object(
418
+ self._pair_features
641
419
  )
420
+ config['random_seed'] = self._random_seed
642
421
  return config
643
422
 
644
423
  @classmethod
645
424
  def from_config(cls, config: dict):
646
- config['conformer_generator'] = keras.saving.deserialize_keras_object(
647
- config['conformer_generator']
425
+ config['pair_features'] = keras.saving.deserialize_keras_object(
426
+ config['pair_features']
648
427
  )
649
428
  return super().from_config(config)
650
-
429
+
651
430
 
652
431
  def save_featurizer(
653
- featurizer: Featurizer,
432
+ featurizer: GraphFeaturizer,
654
433
  filepath: str | Path,
655
434
  overwrite: bool = True,
656
435
  **kwargs
@@ -658,8 +437,8 @@ def save_featurizer(
658
437
  filepath = Path(filepath)
659
438
  if filepath.suffix != '.json':
660
439
  raise ValueError(
661
- 'Invalid `filepath` extension for saving a `Featurizer`. '
662
- 'A `Featurizer` should be saved as a JSON file.'
440
+ 'Invalid `filepath` extension for saving a `GraphFeaturizer`. '
441
+ 'A `GraphFeaturizer` should be saved as a JSON file.'
663
442
  )
664
443
  if not filepath.parent.exists():
665
444
  filepath.parent.mkdir(parents=True, exist_ok=True)
@@ -672,12 +451,12 @@ def save_featurizer(
672
451
  def load_featurizer(
673
452
  filepath: str | Path,
674
453
  **kwargs
675
- ) -> Featurizer:
454
+ ) -> GraphFeaturizer:
676
455
  filepath = Path(filepath)
677
456
  if filepath.suffix != '.json':
678
457
  raise ValueError(
679
- 'Invalid `filepath` extension for loading a `Featurizer`. '
680
- 'A `Featurizer` should be saved as a JSON file.'
458
+ 'Invalid `filepath` extension for loading a `GraphFeaturizer`. '
459
+ 'A `GraphFeaturizer` should be saved as a JSON file.'
681
460
  )
682
461
  if not filepath.exists():
683
462
  return
@@ -685,69 +464,60 @@ def load_featurizer(
685
464
  config = json.load(f)
686
465
  return keras.saving.deserialize_keras_object(config)
687
466
 
688
- def _add_super_nodes(
689
- node: dict[str, np.ndarray],
690
- num_super_nodes: int = 1,
691
- ) -> dict[str, np.ndarray]:
692
- node = copy.deepcopy(node)
693
- node['super'] = np.array(
694
- [False] * len(node['feature']) + [True] * num_super_nodes,
695
- dtype=bool
696
- )
697
- super_node_feature = np.zeros(
698
- [num_super_nodes, node['feature'].shape[-1]],
699
- dtype=node['feature'].dtype
700
- )
701
- node['feature'] = np.concatenate([node['feature'], super_node_feature])
702
- return node
703
-
704
- def _add_super_edges(
705
- edge: dict[str, np.ndarray],
706
- num_nodes: int,
707
- num_super_nodes: int,
708
- feature_dtype: str,
709
- index_dtype: str,
710
- self_loops: bool,
711
- ) -> dict[str, np.ndarray]:
712
- edge = copy.deepcopy(edge)
713
-
714
- super_node_indices = np.arange(num_super_nodes) + num_nodes
715
- if self_loops:
716
- edge['source'] = np.concatenate([edge['source'], super_node_indices])
717
- edge['target'] = np.concatenate([edge['target'], super_node_indices])
718
- super_node_indices = np.repeat(super_node_indices, [num_nodes])
719
- node_indices = (
720
- np.tile(np.arange(num_nodes), [num_super_nodes])
467
+ def _add_super_node(
468
+ data: dict[str, dict[str, np.ndarray]]
469
+ ) -> dict[str, dict[str, np.ndarray]]:
470
+
471
+ data['context']['size'] += 1
472
+
473
+ num_nodes = data['node']['feature'].shape[0]
474
+ num_edges = data['edge']['source'].shape[0]
475
+ super_node_index = num_nodes
476
+
477
+ add_self_loops = np.any(
478
+ data['edge']['source'] == data['edge']['target']
721
479
  )
722
- edge['source'] = np.concatenate(
723
- [edge['source'], node_indices, super_node_indices]
724
- ).astype(index_dtype)
725
-
726
- edge['target'] = np.concatenate(
727
- [edge['target'], super_node_indices, node_indices]
728
- ).astype(index_dtype)
729
-
730
- if 'feature' in edge:
731
- num_edges = int(edge['feature'].shape[0])
732
- num_super_edges = int(num_super_nodes * num_nodes * 2)
733
- if self_loops:
734
- num_super_edges += num_super_nodes
735
- edge['super'] = np.asarray(
736
- ([False] * num_edges + [True] * num_super_edges),
737
- dtype=bool
480
+ if add_self_loops:
481
+ data['edge']['source'] = np.append(
482
+ data['edge']['source'], super_node_index
738
483
  )
739
- edge['feature'] = np.concatenate(
740
- [
741
- edge['feature'],
742
- np.zeros(
743
- shape=(num_super_edges, edge['feature'].shape[-1]),
744
- dtype=edge['feature'].dtype
745
- )
746
- ]
484
+ data['edge']['target'] = np.append(
485
+ data['edge']['target'], super_node_index
747
486
  )
748
487
 
749
- return edge
488
+ data['node']['feature'] = np.pad(data['node']['feature'], [(0, 1), (0, 0)])
489
+ data['node']['super'] = np.asarray([False] * num_nodes + [True])
750
490
 
491
+ node_indices = list(range(num_nodes))
492
+ super_node_indices = [super_node_index] * num_nodes
751
493
 
752
- MolFeaturizer = MolGraphFeaturizer
753
- MolFeaturizer3D = MolGraphFeaturizer3D
494
+ data['edge']['source'] = np.append(
495
+ data['edge']['source'], node_indices + super_node_indices
496
+ )
497
+ data['edge']['target'] = np.append(
498
+ data['edge']['target'], super_node_indices + node_indices
499
+ )
500
+
501
+ total_num_edges = data['edge']['source'].shape[0]
502
+ num_super_edges = (total_num_edges - num_edges)
503
+ data['edge']['super'] = np.asarray(
504
+ [False] * num_edges + [True] * num_super_edges
505
+ )
506
+
507
+ if 'feature' in data['edge']:
508
+ data['edge']['feature'] = np.pad(
509
+ data['edge']['feature'], [(0, num_super_edges), (0, 0)]
510
+ )
511
+
512
+ return data
513
+
514
+ def _convert_dtypes(data: dict[str, dict[str, np.ndarray]]) -> np.ndarray:
515
+ for outer_key, inner_dict in data.items():
516
+ for inner_key, inner_value in inner_dict.items():
517
+ if inner_key in ['source', 'target', 'size']:
518
+ data[outer_key][inner_key] = inner_value.astype(np.int32)
519
+ elif np.issubdtype(inner_value.dtype, np.integer):
520
+ data[outer_key][inner_key] = inner_value.astype(np.int32)
521
+ elif np.issubdtype(inner_value.dtype, np.floating):
522
+ data[outer_key][inner_key] = inner_value.astype(np.float32)
523
+ return data