molcraft 0.1.0rc9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of molcraft might be problematic. Click here for more details.

@@ -0,0 +1,624 @@
1
+ import warnings
2
+ import collections
3
+ import functools
4
+ import inspect
5
+ import keras
6
+ import json
7
+ import abc
8
+ import typing
9
+ import numpy as np
10
+ import pandas as pd
11
+ import tensorflow as tf
12
+ import multiprocessing as mp
13
+
14
+ from pathlib import Path
15
+
16
+ from molcraft import tensors
17
+ from molcraft import features
18
+ from molcraft import records
19
+ from molcraft import chem
20
+ from molcraft import descriptors
21
+
22
+
23
@keras.saving.register_keras_serializable(package='molcraft')
class GraphFeaturizer(abc.ABC):

    """Base graph featurizer.

    Subclasses implement `call`, which featurizes a single input. Calling
    the featurizer itself accepts either a single input or a collection
    (list, numpy array, pandas series/data frame or generator) of inputs.
    """

    @abc.abstractmethod
    def call(self, x: typing.Any, context: dict) -> tensors.GraphTensor:
        """Featurize a single input `x`; implemented by subclasses."""
        pass

    def get_config(self) -> dict:
        """Return the constructor arguments as a dict (empty for the base)."""
        return {}

    @classmethod
    def from_config(cls, config: dict) -> 'GraphFeaturizer':
        """Instantiate the featurizer by passing `config` as keyword arguments."""
        return cls(**config)

    def save(self, filepath: str | Path, *args, **kwargs) -> None:
        """Save this featurizer via the module-level `save_featurizer`."""
        save_featurizer(self, filepath, *args, **kwargs)

    @staticmethod
    def load(filepath: str | Path, *args, **kwargs) -> 'GraphFeaturizer':
        """Load a featurizer via the module-level `load_featurizer`."""
        return load_featurizer(filepath, *args, **kwargs)

    def write_records(self, inputs: typing.Iterable, path: str | Path, **kwargs) -> None:
        """Forward `inputs` and this featurizer to `records.write`."""
        records.write(
            inputs, featurizer=self, path=path, **kwargs
        )

    @staticmethod
    def read_records(path: str | Path, **kwargs) -> tf.data.Dataset:
        """Forward `path` and `kwargs` to `records.read` and return its result."""
        return records.read(
            path=path, **kwargs
        )

    def _call(self, inputs: typing.Any) -> tensors.GraphTensor:
        """Featurize one input, splitting off its context first.

        `context` is only passed to `call` when its signature accepts it.
        Any result that is not a `tensors.GraphTensor` is handed to
        `tensors.from_dict`.
        """
        inputs, context = _unpack_inputs(inputs)
        if _call_kwargs(self.call):
            graph = self.call(inputs, context=context)
        else:
            graph = self.call(inputs)
        if not isinstance(graph, tensors.GraphTensor):
            graph = tensors.from_dict(graph)
        return graph

    def __call__(
        self,
        inputs: typing.Iterable,
        *,
        multiprocessing: bool = False,
        processes: int | None = None,
        device: str = '/cpu:0',
    ) -> tensors.GraphTensor:
        """Featurize a single input or a collection of inputs.

        Args:
            inputs:
                A single input, or a list, numpy array, pandas series,
                pandas data frame or generator of inputs. Anything else
                (including a tuple) is treated as one single input.
            multiprocessing:
                Whether to featurize the inputs with a `multiprocessing.Pool`.
            processes:
                Number of worker processes passed to `multiprocessing.Pool`.
            device:
                Device scope (`tf.device`) entered around the pool mapping.
                It is not applied when `multiprocessing` is False.

        Returns:
            The graph of the single input, or the graphs of all inputs
            stacked (when the first graph is scalar according to
            `tensors.is_scalar`) or concatenated along axis 0.

        Raises:
            ValueError: If a collection yields no graphs at all.
        """
        if not isinstance(
            inputs, (list, np.ndarray, pd.Series, pd.DataFrame, typing.Generator)
        ):
            return self._call(inputs)

        if isinstance(inputs, (np.ndarray, pd.Series)):
            inputs = inputs.tolist()
        elif isinstance(inputs, pd.DataFrame):
            # Rows arrive as (index, series) pairs, which `_unpack_inputs`
            # recognises and turns into a molecule plus context.
            inputs = inputs.iterrows()

        if not multiprocessing:
            outputs = [self._call(x) for x in inputs]
        else:
            with tf.device(device):
                with mp.Pool(processes) as pool:
                    outputs = pool.map(func=self._call, iterable=inputs)
        outputs = [x for x in outputs if x is not None]
        if not outputs:
            # Without this guard `outputs[0]` below fails with a bare
            # IndexError that says nothing about the actual cause.
            raise ValueError(
                'No graphs were produced: `inputs` is empty or every input '
                'was featurized to `None`.'
            )
        if tensors.is_scalar(outputs[0]):
            return tf.stack(outputs, axis=0)
        return tf.concat(outputs, axis=0)
96
+
97
+
98
@keras.saving.register_keras_serializable(package='molcraft')
class MolGraphFeaturizer(GraphFeaturizer):

    """Molecular graph featurizer.

    Converts SMILES or InChI strings to a molecular graph.

    The molecular graph may encode a single molecule or a batch of molecules.

    Example:

    >>> import molcraft
    >>>
    >>> featurizer = molcraft.featurizers.MolGraphFeaturizer(
    ...     atom_features=[
    ...         molcraft.features.AtomType(),
    ...         molcraft.features.NumHydrogens(),
    ...         molcraft.features.Degree(),
    ...     ],
    ...     bond_features=[
    ...         molcraft.features.BondType(),
    ...     ],
    ...     super_node=False,
    ...     self_loops=False,
    ... )
    >>>
    >>> graph = featurizer(["N[C@@H](C)C(=O)O", "N[C@@H](CS)C(=O)O"])
    >>> graph
    GraphTensor(
        context={
            'size': <tf.Tensor: shape=[2], dtype=int32>
        },
        node={
            'feature': <tf.Tensor: shape=[13, 129], dtype=float32>
        },
        edge={
            'source': <tf.Tensor: shape=[22], dtype=int32>,
            'target': <tf.Tensor: shape=[22], dtype=int32>,
            'feature': <tf.Tensor: shape=[22, 5], dtype=float32>
        }
    )

    Args:
        atom_features:
            A list of `features.Feature` encoded as the node features.
            `'auto'` or `'default'` selects atom type and degree, plus the
            number of hydrogens when `include_hydrogens` is False.
        bond_features:
            A list of `features.Feature` encoded as the edge features.
            `'auto'` or `'default'` selects the bond type; `None` omits
            edge features.
        molecule_features:
            A list of `descriptors.Descriptor` encoded as the context feature.
            `'auto'` or `'default'` selects a built-in list of descriptors;
            `None` omits the context feature.
        super_node:
            A boolean specifying whether to include a super node.
        self_loops:
            A boolean specifying whether self loops exist.
        include_hydrogens:
            A boolean specifying whether hydrogens should be encoded as nodes.
        wildcards:
            A boolean specifying whether wildcards exist. If True, wildcard labels will
            be encoded in the graph and separately embedded in `layers.NodeEmbedding`.
    """

    def __init__(
        self,
        atom_features: list[features.Feature] | str = 'auto',
        bond_features: list[features.Feature] | str | None = 'auto',
        molecule_features: list[descriptors.Descriptor] | str | None = None,
        super_node: bool = False,
        self_loops: bool = False,
        include_hydrogens: bool = False,
        wildcards: bool = False,
    ) -> None:
        use_default_atom_features = (
            atom_features == 'auto' or atom_features == 'default'
        )
        if use_default_atom_features:
            atom_features = [features.AtomType(), features.Degree()]
            if not include_hydrogens:
                # Hydrogens are not nodes in this case, so their count is
                # added as an atom feature instead.
                atom_features += [features.NumHydrogens()]

        use_default_bond_features = (
            bond_features == 'auto' or bond_features == 'default'
        )
        if use_default_bond_features:
            bond_features = [features.BondType()]

        use_default_molecule_features = (
            molecule_features == 'auto' or molecule_features == 'default'
        )
        if use_default_molecule_features:
            molecule_features = [
                descriptors.MolWeight(),
                descriptors.TotalPolarSurfaceArea(),
                descriptors.LogP(),
                descriptors.MolarRefractivity(),
                descriptors.NumHeavyAtoms(),
                descriptors.NumHeteroatoms(),
                descriptors.NumHydrogenDonors(),
                descriptors.NumHydrogenAcceptors(),
                descriptors.NumRotatableBonds(),
                descriptors.NumRings(),
            ]

        self._atom_features = atom_features
        self._bond_features = bond_features
        self._molecule_features = molecule_features
        self._include_hydrogens = include_hydrogens
        self._wildcards = wildcards
        self._self_loops = self_loops
        self._super_node = super_node

    def call(
        self,
        mol: str | chem.Mol | tuple,
        context: dict | None = None
    ) -> tensors.GraphTensor:
        """Featurize a single molecule.

        Args:
            mol:
                A string encoding (parsed with `chem.Mol.from_encoding`),
                a `chem.RDKitMol` (cast with `chem.Mol.cast`) or an object
                that already exposes the `chem.Mol` interface.
            context:
                Optional extra entries merged into the graph context. They
                override computed context entries with the same key.

        Returns:
            A `tensors.GraphTensor` with `context`, `node` and `edge` data.
        """

        if isinstance(mol, str):
            mol = chem.Mol.from_encoding(
                mol, explicit_hs=self._include_hydrogens
            )
        elif isinstance(mol, chem.RDKitMol):
            mol = chem.Mol.cast(mol)

        data = {'context': {}, 'node': {}, 'edge': {}}

        data['context']['size'] = np.asarray(mol.num_atoms)

        if self._molecule_features is not None:
            data['context']['feature'] = np.concatenate(
                [f(mol) for f in self._molecule_features], axis=-1
            )

        if context:
            data['context'].update(context)

        data['node']['feature'] = np.concatenate(
            [f(mol) for f in self._atom_features], axis=-1
        )

        if self._wildcards:
            # Wildcard atoms ("*") get label + 1 (a missing label counts as
            # 0); every other atom gets 0.
            wildcard_labels = np.asarray([
                (atom.label or 0) + 1 if atom.symbol == "*" else 0
                for atom in mol.atoms
            ])
            data['node']['wildcard'] = wildcard_labels
            # Zero out the regular node features of wildcard atoms.
            data['node']['feature'] = np.where(
                wildcard_labels[:, None],
                np.zeros_like(data['node']['feature']),
                data['node']['feature']
            )

        data['edge']['source'], data['edge']['target'] = mol.adjacency(
            fill='full', sparse=True, self_loops=self._self_loops
        )

        if self._bond_features is not None:
            bond_features = np.concatenate(
                [f(mol) for f in self._bond_features], axis=-1
            )
            if self._self_loops:
                # Append one all-zero row; self-loop edges select it below
                # through index -1.
                bond_features = np.pad(bond_features, [(0, 1), (0, 0)])

            bond_indices = [
                mol.get_bond_between_atoms(i, j).index if (i != j) else -1
                for (i, j) in zip(data['edge']['source'], data['edge']['target'])
            ]

            data['edge']['feature'] = bond_features[bond_indices]

        if self._super_node:
            data = _add_super_node(data)

        return tensors.GraphTensor(**_convert_dtypes(data))

    def get_config(self):
        """Return the constructor arguments, with feature lists serialized."""
        config = super().get_config()
        config.update({
            'atom_features': keras.saving.serialize_keras_object(
                self._atom_features
            ),
            'bond_features': keras.saving.serialize_keras_object(
                self._bond_features
            ),
            'molecule_features': keras.saving.serialize_keras_object(
                self._molecule_features
            ),
            'super_node': self._super_node,
            'self_loops': self._self_loops,
            'include_hydrogens': self._include_hydrogens,
            'wildcards': self._wildcards,
        })
        return config

    @classmethod
    def from_config(cls, config: dict):
        """Rebuild a featurizer from the dict produced by `get_config`.

        The feature entries are deserialized on a shallow copy, so the
        caller's `config` keeps its serialized form and can be reused.
        """
        config = dict(config)
        for key in ('atom_features', 'bond_features', 'molecule_features'):
            config[key] = keras.saving.deserialize_keras_object(config[key])
        return cls(**config)
302
+
303
+
304
@keras.saving.register_keras_serializable(package='molcraft')
class MolGraphFeaturizer3D(MolGraphFeaturizer):

    """3D Molecular graph featurizer.

    Converts SMILES or InChI strings to a 3D molecular graph.

    The molecular graph may encode a single molecule or a batch of molecules.

    Example:

    >>> import molcraft
    >>>
    >>> featurizer = molcraft.featurizers.MolGraphFeaturizer3D(
    ...     atom_features=[
    ...         molcraft.features.AtomType(),
    ...         molcraft.features.NumHydrogens(),
    ...         molcraft.features.Degree(),
    ...     ],
    ...     radius=5.0,
    ...     random_seed=42,
    ... )
    >>>
    >>> graph = featurizer(["N[C@@H](C)C(=O)O", "N[C@@H](CS)C(=O)O"])
    >>> graph
    GraphTensor(
        context={
            'size': <tf.Tensor: shape=[2], dtype=int32>
        },
        node={
            'feature': <tf.Tensor: shape=[13, 129], dtype=float32>,
            'coordinate': <tf.Tensor: shape=[13, 3], dtype=float32>
        },
        edge={
            'source': <tf.Tensor: shape=[72], dtype=int32>,
            'target': <tf.Tensor: shape=[72], dtype=int32>,
            'feature': <tf.Tensor: shape=[72, 12], dtype=float32>
        }
    )

    Args:
        atom_features:
            A list of `features.Feature` encoded as the node features.
        pair_features:
            A list of `features.PairFeature` encoded as the edge features.
            `'auto'` or `'default'` selects the pair distance.
        molecule_features:
            A list of `descriptors.Descriptor` encoded as the context feature.
        super_node:
            A boolean specifying whether to include a super node.
        self_loops:
            A boolean specifying whether self loops exist.
        include_hydrogens:
            A boolean specifying whether hydrogens should be encoded as nodes.
        wildcards:
            A boolean specifying whether wildcards exist. If True, wildcard labels will
            be encoded in the graph and separately embedded in `layers.NodeEmbedding`.
        radius:
            A floating point value specifying maximum edge length. A falsy
            value (`None` or `0`) is stored as `None`.
        random_seed:
            An integer specifying the random seed for the conformer generation.
        **kwargs:
            Accepted so that a config containing `bond_features` can be
            passed back in; all extra keyword arguments are ignored.
    """

    def __init__(
        self,
        atom_features: list[features.Feature] | str = 'auto',
        pair_features: list[features.PairFeature] | str | None = 'auto',
        molecule_features: list[descriptors.Descriptor] | str | None = None,
        super_node: bool = False,
        self_loops: bool = False,
        include_hydrogens: bool = False,
        wildcards: bool = False,
        radius: int | float | None = 6.0,
        random_seed: int | None = None,
        **kwargs,
    ) -> None:
        # The inherited `get_config` emits a `bond_features` entry; drop it,
        # since bond features are always disabled for this featurizer.
        kwargs.pop('bond_features', None)
        super().__init__(
            atom_features=atom_features,
            bond_features=None,
            molecule_features=molecule_features,
            super_node=super_node,
            self_loops=self_loops,
            include_hydrogens=include_hydrogens,
            wildcards=wildcards,
        )

        use_default_pair_features = (
            pair_features == 'auto' or pair_features == 'default'
        )
        if use_default_pair_features:
            pair_features = [features.PairDistance()]

        self._pair_features = pair_features
        self._radius = float(radius) if radius else None
        self._random_seed = random_seed

    def call(
        self,
        mol: str | chem.Mol | tuple,
        context: dict | None = None
    ) -> tensors.GraphTensor:
        """Featurize a single molecule together with conformer coordinates.

        Args:
            mol:
                A string encoding (parsed with explicit hydrogens), a
                `chem.RDKitMol` (cast with `chem.Mol.cast`) or an object
                that already exposes the `chem.Mol` interface.
            context:
                Optional extra entries merged into the graph context. They
                override computed context entries with the same key.

        Returns:
            A `tensors.GraphTensor` whose nodes also carry a `coordinate`
            entry taken from the molecule's conformer.
        """

        if isinstance(mol, str):
            mol = chem.Mol.from_encoding(
                mol, explicit_hs=True
            )
        elif isinstance(mol, chem.RDKitMol):
            mol = chem.Mol.cast(mol)

        if mol.num_conformers == 0:
            # No conformer supplied with the molecule: embed a single one.
            mol = chem.embed_conformers(
                mol, num_conformers=1, random_seed=self._random_seed
            )

        if not self._include_hydrogens:
            # Hydrogens are removed only after the conformer exists.
            mol = chem.remove_hs(mol)

        data = {'context': {}, 'node': {}, 'edge': {}}

        data['context']['size'] = np.asarray(mol.num_atoms)

        if self._molecule_features is not None:
            data['context']['feature'] = np.concatenate(
                [f(mol) for f in self._molecule_features], axis=-1
            )

        if context:
            data['context'].update(context)

        conformer = mol.get_conformer()

        data['node']['feature'] = np.concatenate(
            [f(mol) for f in self._atom_features], axis=-1
        )
        data['node']['coordinate'] = conformer.coordinates

        if self._wildcards:
            # Wildcard atoms ("*") get label + 1 (a missing label counts as
            # 0); every other atom gets 0.
            wildcard_labels = np.asarray([
                (atom.label or 0) + 1 if atom.symbol == "*" else 0
                for atom in mol.atoms
            ])
            data['node']['wildcard'] = wildcard_labels
            # Zero out the regular node features of wildcard atoms.
            data['node']['feature'] = np.where(
                wildcard_labels[:, None],
                np.zeros_like(data['node']['feature']),
                data['node']['feature']
            )

        adjacency_matrix = conformer.adjacency(
            fill='full', radius=self._radius, sparse=False, self_loops=self._self_loops,
        )

        data['edge']['source'], data['edge']['target'] = np.where(adjacency_matrix)

        if self._pair_features is not None:
            pair_features = np.concatenate(
                [f(mol) for f in self._pair_features], axis=-1
            )
            # Keep only the pairs that are edges; the flattened adjacency
            # matrix is used as a row mask (assumes one pair-feature row per
            # matrix entry in the same order -- TODO confirm).
            pair_keep = adjacency_matrix.reshape(-1).astype(bool)
            data['edge']['feature'] = pair_features[pair_keep]

        if self._super_node:
            data = _add_super_node(data)
            # The super node is placed at the conformer centroid.
            data['node']['coordinate'] = np.concatenate(
                [data['node']['coordinate'], conformer.centroid[None]], axis=0
            )

        return tensors.GraphTensor(**_convert_dtypes(data))

    @property
    def random_seed(self) -> int | None:
        """Random seed passed to the conformer embedding."""
        return self._random_seed

    @random_seed.setter
    def random_seed(self, value: int) -> None:
        self._random_seed = value

    def get_config(self):
        """Return the parent config extended with the 3D-specific arguments."""
        config = super().get_config()
        config['radius'] = self._radius
        config['pair_features'] = keras.saving.serialize_keras_object(
            self._pair_features
        )
        config['random_seed'] = self._random_seed
        return config

    @classmethod
    def from_config(cls, config: dict):
        """Rebuild a featurizer from the dict produced by `get_config`.

        `pair_features` is deserialized on a shallow copy, so the caller's
        `config` keeps its serialized form and can be reused.
        """
        config = dict(config)
        config['pair_features'] = keras.saving.deserialize_keras_object(
            config['pair_features']
        )
        return super().from_config(config)
496
+
497
+
498
def save_featurizer(
    featurizer: GraphFeaturizer,
    filepath: str | Path,
    overwrite: bool = True,
    **kwargs
) -> None:
    """Serialize `featurizer` with keras and write it to a JSON file.

    Args:
        featurizer: The featurizer to serialize.
        filepath: Destination path; its suffix must be `.json`.
        overwrite: If False and the file already exists, nothing is written.
        **kwargs: Accepted but unused.

    Raises:
        ValueError: If `filepath` does not end in `.json`.
    """
    destination = Path(filepath)
    if destination.suffix != '.json':
        raise ValueError(
            'Invalid `filepath` extension for saving a `GraphFeaturizer`. '
            'A `GraphFeaturizer` should be saved as a JSON file.'
        )

    folder = destination.parent
    if not folder.exists():
        folder.mkdir(parents=True, exist_ok=True)

    keep_existing = destination.exists() and not overwrite
    if keep_existing:
        return

    # Serialize before opening, so a serialization failure cannot leave a
    # truncated file behind.
    payload = keras.saving.serialize_keras_object(featurizer)
    with open(destination, 'w') as stream:
        json.dump(payload, stream, indent=4)
517
+
518
def load_featurizer(
    filepath: str | Path,
    **kwargs
) -> GraphFeaturizer:
    """Read a JSON file and deserialize it with keras.

    Args:
        filepath: Path of the file to read; its suffix must be `.json`.
        **kwargs: Accepted but unused.

    Returns:
        The deserialized object, or `None` when `filepath` does not exist.

    Raises:
        ValueError: If `filepath` does not end in `.json`.
    """
    source = Path(filepath)
    if source.suffix != '.json':
        raise ValueError(
            'Invalid `filepath` extension for loading a `GraphFeaturizer`. '
            'A `GraphFeaturizer` should be saved as a JSON file.'
        )
    if source.exists():
        with open(source, 'r') as stream:
            serialized = json.load(stream)
        return keras.saving.deserialize_keras_object(serialized)
    # A missing file yields None rather than an error.
    return None
533
+
534
+ def _add_super_node(
535
+ data: dict[str, dict[str, np.ndarray]]
536
+ ) -> dict[str, dict[str, np.ndarray]]:
537
+
538
+ data['context']['size'] += 1
539
+
540
+ num_nodes = data['node']['feature'].shape[0]
541
+ num_edges = data['edge']['source'].shape[0]
542
+ super_node_index = num_nodes
543
+
544
+ add_self_loops = np.any(
545
+ data['edge']['source'] == data['edge']['target']
546
+ )
547
+ if add_self_loops:
548
+ data['edge']['source'] = np.append(
549
+ data['edge']['source'], super_node_index
550
+ )
551
+ data['edge']['target'] = np.append(
552
+ data['edge']['target'], super_node_index
553
+ )
554
+
555
+ data['node']['feature'] = np.pad(data['node']['feature'], [(0, 1), (0, 0)])
556
+ data['node']['super'] = np.asarray([False] * num_nodes + [True])
557
+ if 'wildcard' in data['node']:
558
+ data['node']['wildcard'] = np.pad(data['node']['wildcard'], [(0, 1)])
559
+
560
+ node_indices = list(range(num_nodes))
561
+ super_node_indices = [super_node_index] * num_nodes
562
+
563
+ data['edge']['source'] = np.append(
564
+ data['edge']['source'], node_indices + super_node_indices
565
+ )
566
+ data['edge']['target'] = np.append(
567
+ data['edge']['target'], super_node_indices + node_indices
568
+ )
569
+
570
+ total_num_edges = data['edge']['source'].shape[0]
571
+ num_super_edges = (total_num_edges - num_edges)
572
+ data['edge']['super'] = np.asarray(
573
+ [False] * num_edges + [True] * num_super_edges
574
+ )
575
+
576
+ if 'feature' in data['edge']:
577
+ data['edge']['feature'] = np.pad(
578
+ data['edge']['feature'], [(0, num_super_edges), (0, 0)]
579
+ )
580
+
581
+ return data
582
+
583
+ def _convert_dtypes(data: dict[str, dict[str, np.ndarray]]) -> np.ndarray:
584
+ for outer_key, inner_dict in data.items():
585
+ for inner_key, inner_value in inner_dict.items():
586
+ if inner_key in ['source', 'target', 'size']:
587
+ data[outer_key][inner_key] = inner_value.astype(np.int32)
588
+ elif np.issubdtype(inner_value.dtype, np.integer):
589
+ data[outer_key][inner_key] = inner_value.astype(np.int32)
590
+ elif np.issubdtype(inner_value.dtype, np.floating):
591
+ data[outer_key][inner_key] = inner_value.astype(np.float32)
592
+ return data
593
+
594
+ def _unpack_inputs(inputs) -> tuple:
595
+ if isinstance(inputs, np.ndarray):
596
+ inputs = tuple(inputs.tolist())
597
+ elif isinstance(inputs, list):
598
+ inputs = tuple(inputs)
599
+ if not isinstance(inputs, tuple):
600
+ mol, context = inputs, {}
601
+ elif isinstance(inputs[0], int) and isinstance(inputs[1], pd.Series):
602
+ index, series = inputs
603
+ mol = series.values[0]
604
+ context = dict(
605
+ zip(
606
+ map(_snake_case, series.index[1:]),
607
+ map(np.asarray, series.values[1:])
608
+ )
609
+ )
610
+ context['index'] = np.asarray(index)
611
+ else:
612
+ mol, *context = inputs
613
+ context = dict(zip(['label', 'sample_weight'], map(np.asarray, context)))
614
+ return mol, context
615
+
616
+ def _snake_case(x: str) -> str:
617
+ return '_'.join(x.lower().split())
618
+
619
+ def _call_kwargs(func) -> bool:
620
+ signature = inspect.signature(func)
621
+ return any(
622
+ (param.kind == inspect.Parameter.VAR_KEYWORD) or (param.name == 'context')
623
+ for param in signature.parameters.values()
624
+ )