molcraft 0.1.0a1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of molcraft might be problematic. Click here for more details.

@@ -0,0 +1,693 @@
1
+ import keras
2
+ import json
3
+ import abc
4
+ import typing
5
+ import copy
6
+ import warnings
7
+ import numpy as np
8
+ import pandas as pd
9
+ import tensorflow as tf
10
+ import multiprocessing as mp
11
+
12
+ from pathlib import Path
13
+
14
+ from molcraft import tensors
15
+ from molcraft import features
16
+ from molcraft import chem
17
+ from molcraft import conformers
18
+ from molcraft import descriptors
19
+
20
+
21
@keras.saving.register_keras_serializable(package='molcraft')
class Featurizer(abc.ABC):

    """Abstract interface shared by all featurizers.

    Subclasses implement `call`, which featurizes a single input, and
    `stack`, which merges the per-input results into a single output.
    Invoking the instance (`featurizer(inputs)`) dispatches to these two
    methods.
    """

    @abc.abstractmethod
    def call(
        self,
        x: tensors.GraphTensor
    ) -> tensors.GraphTensor | list[tensors.GraphTensor]:
        """Featurize a single input."""

    @abc.abstractmethod
    def stack(
        self,
        call_outputs: list[tensors.GraphTensor]
    ) -> tensors.GraphTensor:
        """Merge the outputs of several `call` invocations into one."""

    def get_config(self) -> dict:
        """Return the constructor arguments; the base class has none."""
        return dict()

    @classmethod
    def from_config(cls, config: dict) -> 'Featurizer':
        """Instantiate the featurizer from a `get_config` dictionary."""
        return cls(**config)

    def save(self, filepath: str | Path, *args, **kwargs) -> None:
        """Write this featurizer to `filepath` via `save_featurizer`."""
        save_featurizer(self, filepath, *args, **kwargs)

    @staticmethod
    def load(filepath: str | Path, *args, **kwargs) -> 'Featurizer':
        """Read a featurizer from `filepath` via `load_featurizer`."""
        return load_featurizer(filepath, *args, **kwargs)

    def __call__(
        self,
        inputs: str | tuple | list | np.ndarray | pd.DataFrame | pd.Series,
        *,
        multiprocessing: bool = False,
        processes: int | None = None,
        device: str = '/cpu:0',
        **kwargs
    ) -> tensors.GraphTensor:
        """Featurize one input or a collection of inputs.

        Args:
            inputs:
                A `str` or `tuple` is treated as a single example and passed
                straight to `call`. Any other input is treated as a
                collection of examples; pandas objects are reduced to their
                `.values` and arrays are converted to lists first.
            multiprocessing:
                Whether to map `call` over the examples with a process pool.
            processes:
                Number of worker processes, forwarded to `mp.Pool`.
            device:
                Device scope that is entered around the process pool.

        Returns:
            The output of `call` for a single example; otherwise the result
            of `stack` applied to every non-None `call` output.
        """
        # A lone string/tuple is one example, not a collection of examples.
        if isinstance(inputs, (str, tuple)):
            return self.call(inputs)

        if isinstance(inputs, (pd.DataFrame, pd.Series)):
            inputs = inputs.values
        if isinstance(inputs, np.ndarray):
            inputs = list(inputs)

        if multiprocessing:
            with tf.device(device), mp.Pool(processes) as pool:
                results = pool.map(func=self.call, iterable=inputs)
        else:
            results = [self.call(example) for example in inputs]

        # `call` may return None for inputs it could not featurize.
        kept = [result for result in results if result is not None]
        return self.stack(kept)
82
+
83
+
84
@keras.saving.register_keras_serializable(package='molcraft')
class MolGraphFeaturizer(Featurizer):

    """Molecular graph featurizer.

    Converts SMILES or InChI strings to a molecular graph.

    The molecular graph may encode a single molecule or a batch of molecules.
    In either case, it is a single graph, with each molecule corresponding to
    a subgraph within the graph.

    Example:

    >>> import molcraft
    >>>
    >>> featurizer = molcraft.featurizers.MolGraphFeaturizer(
    ...     atom_features=[
    ...         molcraft.features.AtomType(),
    ...         molcraft.features.TotalNumHs(),
    ...         molcraft.features.Degree(),
    ...     ],
    ...     radius=1
    ... )
    >>>
    >>> graph = featurizer(["N[C@@H](C)C(=O)O", "N[C@@H](CS)C(=O)O"])
    >>> graph
    GraphTensor(
        context={
            'size': <tf.Tensor: shape=[2], dtype=int32>
        },
        node={
            'feature': <tf.Tensor: shape=[13, 133], dtype=float32>
        },
        edge={
            'source': <tf.Tensor: shape=[22], dtype=int32>,
            'target': <tf.Tensor: shape=[22], dtype=int32>,
            'feature': <tf.Tensor: shape=[22, 5], dtype=float32>
        }
    )

    Args:
        atom_features:
            A list of `features.Feature` encoding the nodes of the molecular
            graph. If 'auto' or 'default', `AtomType`, `TotalNumHs` (only
            when `include_hs` is False) and `Degree` are used.
        bond_features:
            A list of `features.Feature` encoding the edges of the molecular
            graph. Note that it is replaced by a single `features.BondType`
            (with an additional 'zero' token) whenever it is 'auto' or
            'default', `radius` is greater than 1, or `self_loops` is True.
        molecule_features:
            A list of features encoding the molecule. The result is stored
            as the `feature` field of the graph `context`. If 'auto' or
            'default', a fixed set of `descriptors` is used. If None, the
            `mol_features` keyword is used instead when it is given.
        super_atom:
            A boolean specifying whether a super atom, connected to every
            atom of the molecule, should be added (to be embedded
            downstream, e.g. via `NodeEmbedding`).
        radius:
            An integer specifying how many bond lengths should be considered as an
            edge. The default is None (or 1), which specifies that only bonds should
            be considered an edge.
        self_loops:
            A boolean specifying whether self loops exist. If True, this means that
            each node (atom) has an edge (bond) to itself.
        include_hs:
            A boolean specifying whether hydrogens should be encoded as nodes.
    """

    def __init__(
        self,
        atom_features: list[features.Feature] | str | None = 'auto',
        bond_features: list[features.Feature] | str | None = 'auto',
        molecule_features: features.Feature | str | None = None,
        super_atom: bool = False,
        radius: int | float | None = None,
        self_loops: bool = False,
        include_hs: bool = False,
        **kwargs,
    ) -> None:
        if molecule_features is None:
            # Fall back to the `mol_features` keyword when it was supplied.
            molecule_features = kwargs.pop('mol_features', None)

        # `None`, 0 and anything below 1 all collapse to a radius of 1.
        self.radius = int(max(radius or 1, 1))
        self.include_hs = include_hs
        self.self_loops = self_loops
        self.super_atom = super_atom

        default_atom_features = (
            atom_features == 'auto' or atom_features == 'default'
        )
        if default_atom_features:
            atom_features = [features.AtomType()]
            if not self.include_hs:
                atom_features.append(features.TotalNumHs())
            atom_features.append(features.Degree())

        # The 3D subclass resolves its own bond features before calling here.
        if not isinstance(self, MolGraphFeaturizer3D):
            default_bond_features = (
                bond_features == 'auto' or bond_features == 'default'
            )
            # `call` appends a hard-coded five-dimensional 'zero order' row
            # when radius > 1 or self loops are used, so in those cases the
            # bond features are forced to the matching `BondType` encoding.
            if default_bond_features or self.radius > 1 or self.self_loops:
                vocab = ['zero', 'single', 'double', 'triple', 'aromatic']
                bond_features = [
                    features.BondType(vocab)
                ]

        default_molecule_features = (
            molecule_features == 'auto' or molecule_features == 'default'
        )
        if default_molecule_features:
            molecule_features = [
                descriptors.MolWeight(),
                descriptors.MolTPSA(),
                descriptors.MolLogP(),
                descriptors.NumHeavyAtoms(),
                descriptors.NumHydrogenDonors(),
                descriptors.NumHydrogenAcceptors(),
                descriptors.NumRotatableBonds(),
                descriptors.NumRings(),
            ]

        self._atom_features = atom_features
        self._bond_features = bond_features
        self._molecule_features = molecule_features
        self.feature_dtype = 'float32'
        self.index_dtype = 'int32'

    def call(self, x: str | typing.Tuple) -> tensors.GraphTensor | None:
        """Featurize a single molecule.

        Args:
            x:
                A molecule encoding, or a tuple/list/array whose first item
                is the encoding. Up to two additional items are stored in
                the graph context as `label` and `weight`, in that order.

        Returns:
            A `tensors.GraphTensor`, or None (after emitting a warning) if
            no `chem.Mol` could be obtained from the encoding.
        """
        if isinstance(x, (tuple, list, np.ndarray)):
            x, *args = x
        else:
            args = []

        mol = chem.Mol.from_encoding(x, explicit_hs=self.include_hs)

        if mol is None:
            warn(
                f'Could not obtain `chem.Mol` from {x}. '
                'Proceeding without it.'
            )
            return None

        atom_feature = self.atom_features(mol)
        bond_feature = self.bond_features(mol)
        context_feature = self.context_feature(mol)
        molecule_size = self.num_atoms(mol)

        context, node, edge = {}, {}, {}
        # `zip` stops at the shorter sequence, so `label` and `weight` are
        # only set when the corresponding extra inputs were provided.
        for field, value in zip(['size', 'label', 'weight'], [molecule_size] + args):
            context[field] = value

        if context_feature is not None:
            context['feature'] = context_feature

        node['feature'] = atom_feature

        if bond_feature is not None and (self.radius > 1 or self.self_loops):
            # Append 'zero order' bond feature encoding, which encodes non-bonds.
            # It becomes the last row, so it can be addressed with index -1.
            zero_bond_feature = np.array(
                [[1., 0., 0., 0., 0.]], dtype=bond_feature.dtype
            )
            bond_feature = np.concatenate(
                [bond_feature, zero_bond_feature], axis=0
            )

        if self.radius == 1:
            edge['source'], edge['target'] = mol.adjacency(
                fill='full', sparse=True, self_loops=self.self_loops, dtype=self.index_dtype
            )
            if bond_feature is not None:
                bond_indices = []
                for (atom_i, atom_j) in zip(edge['source'], edge['target']):
                    if atom_i == atom_j:
                        # Self loops use the 'zero order' row appended above.
                        bond_indices.append(-1)
                    else:
                        bond_indices.append(
                            mol.get_bond_between_atoms(atom_i, atom_j).index
                        )
                edge['feature'] = bond_feature[bond_indices]
        else:
            paths = chem.get_shortest_paths(
                mol, radius=self.radius, self_loops=self.self_loops
            )
            edge['source'] = np.asarray(
                [path[0] for path in paths], dtype=self.index_dtype
            )
            edge['target'] = np.asarray(
                [path[-1] for path in paths], dtype=self.index_dtype
            )
            edge['length'] = np.asarray(
                [len(path) - 1 for path in paths], dtype=self.index_dtype
            )
            if bond_feature is not None:
                edge['feature'] = self._expand_bond_features(
                    mol, paths, bond_feature,
                )
            # One-hot encode the path length. The identity matrix is created
            # in `feature_dtype` so that the encoding has the same dtype
            # with and without a super atom (`np.eye` defaults to float64).
            edge['length'] = np.eye(
                self.radius + 1, dtype=self.feature_dtype
            )[edge['length']]

        if self.super_atom:
            node, edge = self._add_super_atom(node, edge)
            context['size'] += 1

        return tensors.GraphTensor(context, node, edge)

    def stack(self, outputs):
        """Merge per-molecule graphs into a single graph.

        Scalar graphs (as judged by `tensors.is_scalar` on the first item)
        are stacked; all others are concatenated along the first axis.
        """
        if tensors.is_scalar(outputs[0]):
            return tf.stack(outputs, axis=0)
        return tf.concat(outputs, axis=0)

    def atom_features(self, mol: chem.Mol) -> np.ndarray:
        """Concatenate all atom features of `mol` along the last axis."""
        atom_feature: np.ndarray = np.concatenate(
            [f(mol) for f in self._atom_features], axis=-1
        )
        return atom_feature.astype(self.feature_dtype)

    def bond_features(self, mol: chem.Mol) -> np.ndarray | None:
        """Concatenate all bond features of `mol` along the last axis.

        Returns None if no bond features are configured.
        """
        if self._bond_features is None:
            return None
        bond_feature: np.ndarray = np.concatenate(
            [f(mol) for f in self._bond_features], axis=-1
        )
        return bond_feature.astype(self.feature_dtype)

    def context_feature(self, mol: chem.Mol) -> np.ndarray | None:
        """Concatenate all molecule features of `mol` along the last axis.

        Returns None if no molecule features are configured.
        """
        if self._molecule_features is None:
            return None
        context_feature: np.ndarray = np.concatenate(
            [f(mol) for f in self._molecule_features], axis=-1
        )
        return context_feature.astype(self.feature_dtype)

    def num_atoms(self, mol: chem.Mol) -> np.ndarray:
        """Return `mol.num_atoms` as an array of `index_dtype`."""
        return np.asarray(mol.num_atoms, dtype=self.index_dtype)

    def num_bonds(self, mol: chem.Mol) -> np.ndarray:
        """Return `mol.num_bonds` as an array of `index_dtype`."""
        return np.asarray(mol.num_bonds, dtype=self.index_dtype)

    def _expand_bond_features(
        self,
        mol: chem.Mol,
        paths: list[list[int]],
        bond_feature: np.ndarray,
    ) -> np.ndarray:
        """Build one edge feature per path from the bonds along that path.

        The features of the bonds between consecutive atoms of a path are
        concatenated. Paths with fewer than `radius` bonds are padded with
        the last row of `bond_feature` (index -1), which `call` sets to the
        'zero order' encoding.

        Returns:
            An array with one row per path and
            `bond_feature.shape[-1] * radius` columns.
        """

        def bond_feature_lookup(path):
            path_bond_indices = [
                mol.get_bond_between_atoms(path[i], path[i + 1]).index
                for i in range(len(path) - 1)
            ]
            # A path of k atoms has k - 1 bonds; pad up to `radius` bonds.
            padding = [-1] * (self.radius - len(path) + 1)
            path_bond_indices += padding
            return bond_feature[path_bond_indices].reshape(-1)

        # The final reshape keeps the output two-dimensional even when
        # `paths` is empty.
        edge_feature = np.asarray(
            [
                bond_feature_lookup(path) for path in paths
            ],
            dtype=self.feature_dtype
        ).reshape((-1, bond_feature.shape[-1] * self.radius))

        return edge_feature

    def _add_super_atom(
        self,
        node: dict[str, np.ndarray],
        edge: dict[str, np.ndarray],
    ) -> tuple[dict[str, np.ndarray], dict[str, np.ndarray]]:
        """Add a single super node, and its edges, to `node` and `edge`.

        Returns:
            The updated `(node, edge)` dictionaries.
        """
        num_super_nodes = 1
        # Count the atoms before the super node is appended.
        num_nodes = node['feature'].shape[0]
        node = _add_super_nodes(
            node, num_super_nodes, self.feature_dtype
        )
        edge = _add_super_edges(
            edge, num_nodes, num_super_nodes, self.feature_dtype, self.index_dtype
        )
        return node, edge

    def get_config(self):
        """Return the serializable constructor arguments."""
        config = super().get_config()
        config.update({
            'atom_features': keras.saving.serialize_keras_object(
                self._atom_features
            ),
            'bond_features': keras.saving.serialize_keras_object(
                self._bond_features
            ),
            'molecule_features': keras.saving.serialize_keras_object(
                self._molecule_features
            ),
            'super_atom': self.super_atom,
            'radius': self.radius,
            'self_loops': self.self_loops,
            'include_hs': self.include_hs,
        })
        return config

    @classmethod
    def from_config(cls, config: dict):
        """Instantiate the featurizer from a `get_config` dictionary.

        The passed dictionary is left untouched; the features are
        deserialized into a shallow copy.
        """
        config = dict(config)
        for key in ('atom_features', 'bond_features', 'molecule_features'):
            config[key] = keras.saving.deserialize_keras_object(config[key])
        return cls(**config)
385
+
386
+
387
class MolGraphFeaturizer3D(MolGraphFeaturizer):

    """Molecular 3d-graph featurizer.

    Converts SMILES or InChI strings to a molecular graph in 3d space.
    Namely, in addition to the information encoded in a standard molecular
    graph, cartesian coordinates are also included.

    The molecular graph may encode a single molecule or a batch of molecules.
    In either case, it is a single graph, with each molecule corresponding to
    a subgraph within the graph.

    Example:

    >>> import molcraft
    >>>
    >>> featurizer = molcraft.featurizers.MolGraphFeaturizer3D(
    ...     atom_features=[
    ...         molcraft.features.AtomType(),
    ...         molcraft.features.TotalNumHs(),
    ...         molcraft.features.Degree(),
    ...     ],
    ...     radius=5.0
    ... )
    >>>
    >>> graph = featurizer(["N[C@@H](C)C(=O)O", "N[C@@H](CS)C(=O)O"])
    >>> graph
    GraphTensor(
        context={
            'size': <tf.Tensor: shape=[20], dtype=int32>
        },
        node={
            'feature': <tf.Tensor: shape=[130, 133], dtype=float32>,
            'coordinate': <tf.Tensor: shape=[130, 3], dtype=float32>
        },
        edge={
            'source': <tf.Tensor: shape=[714], dtype=int32>,
            'target': <tf.Tensor: shape=[714], dtype=int32>,
            'feature': <tf.Tensor: shape=[714, 23], dtype=float32>
        }
    )

    Args:
        atom_features:
            A list of `features.Feature` encoding the nodes of the molecular graph.
        bond_features:
            A list of `features.Feature` encoding the edges of the molecular
            graph. If 'auto' or 'default', `features.Distance` is used.
        molecule_features:
            A list of features encoding the molecule. The result is stored
            as the `feature` field of the graph `context`.
        conformer_generator:
            A `conformers.ConformerGenerator` which produces conformers. If `auto`
            a `conformers.ConformerEmbedder` (method 'ETKDGv3', 10 conformers)
            will be used. If None, it is assumed that the molecule already
            has conformer(s).
        super_atom:
            A boolean specifying whether a super atom, connected to every
            atom of the molecule, should be added (to be embedded
            downstream, e.g. via `NodeEmbedding`). Its coordinate is the
            centroid of the conformer.
        radius:
            A float specifying, for each atom, the maximum distance (in angstroms)
            that another atom should be within to be considered an edge. Default
            is set to 6.0 as this should cover most interactions. This parameter
            can be thought of as the receptive field. If None, the radius will be
            infinite so the receptive field will be the entire space (graph).
        self_loops:
            A boolean specifying whether self loops exist. If True, this means that
            each node (atom) has an edge (bond) to itself.
        include_hs:
            A boolean specifying whether hydrogens should be encoded as nodes.
    """

    def __init__(
        self,
        atom_features: list[features.Feature] | str | None = 'auto',
        bond_features: list[features.Feature] | str | None = 'auto',
        molecule_features: features.Feature | str = None,
        conformer_generator: conformers.ConformerProcessor | str | None = 'auto',
        super_atom: bool = False,
        radius: int | float | None = 6.0,
        self_loops: bool = False,
        include_hs: bool = False,
        **kwargs,
    ) -> None:
        # The base class skips bond feature defaults for this subclass, so
        # both default tokens have to be resolved here; otherwise the raw
        # string would be kept as the list of bond features.
        if bond_features == 'auto' or bond_features == 'default':
            bond_features = [
                features.Distance()
            ]
        super().__init__(
            atom_features=atom_features,
            bond_features=bond_features,
            molecule_features=molecule_features,
            super_atom=super_atom,
            radius=radius,
            self_loops=self_loops,
            include_hs=include_hs,
            **kwargs,
        )
        if conformer_generator == 'auto':
            conformer_generator = conformers.ConformerGenerator(
                steps=[
                    conformers.ConformerEmbedder(
                        method='ETKDGv3',
                        num_conformers=10
                    ),
                ]
            )
        self.conformer_generator = conformer_generator
        self.embed_conformer = self.conformer_generator is not None
        # Override the integer radius set by the base class: here the radius
        # is a distance, and a falsy value means 'no cutoff'.
        self.radius = float(radius) if radius else None

    def call(
        self,
        x: str | typing.Tuple
    ) -> list[tensors.GraphTensor] | None:
        """Featurize a single molecule, producing one graph per conformer.

        Args:
            x:
                A molecule encoding, or a tuple/list/array whose first item
                is the encoding. Up to two additional items are stored in
                the graph context as `label` and `weight`, in that order.

        Returns:
            A list with one `tensors.GraphTensor` per conformer, or None
            (after emitting a warning) if no `chem.Mol` could be obtained
            from the encoding.

        Raises:
            ValueError: If the molecule has no conformers.
        """
        if isinstance(x, (tuple, list, np.ndarray)):
            x, *args = x
        else:
            args = []

        # Conformers are generated with explicit hydrogens; the hydrogens
        # are removed again below unless `include_hs` is set.
        explicit_hs = (self.include_hs or self.embed_conformer)
        mol = chem.Mol.from_encoding(x, explicit_hs=explicit_hs)

        if mol is None:
            # Same warning as `MolGraphFeaturizer.call`, so that dropped
            # inputs do not go unnoticed.
            warn(
                f'Could not obtain `chem.Mol` from {x}. '
                'Proceeding without it.'
            )
            return None

        if self.embed_conformer:
            mol = self.conformer_generator(mol)
            if not self.include_hs:
                mol = chem.remove_hs(mol)

        if mol.num_conformers == 0:
            raise ValueError(
                'Cannot featurize a molecule without conformer(s). '
                'Make sure to pass a `ConformerGenerator` to the constructor '
                'of the `Featurizer` or input a 3D representation of the molecule. '
            )

        context, node, edge = {}, {}, {}

        context['size'] = self.num_atoms(mol) + int(self.super_atom)
        for field, value in zip(['label', 'weight'], args):
            context[field] = value

        node['feature'] = self.atom_features(mol)

        if self._bond_features:
            edge_feature = self.bond_features(mol)

        context_feature = self.context_feature(mol)
        if context_feature is not None:
            context['feature'] = context_feature

        tensor_list = []
        for conformer_mol in chem._split_mol_by_confs(mol):
            node_conformer = copy.deepcopy(node)
            edge_conformer = copy.deepcopy(edge)
            conformer = conformer_mol.get_conformer()
            # `np.bool_` rather than `np.bool`: the latter does not exist
            # in NumPy 1.24 - 1.26.
            adjacency_matrix = conformer.adjacency(
                fill='full',
                radius=self.radius,
                sparse=False,
                self_loops=self.self_loops,
                dtype=np.bool_
            )
            source, target = np.where(adjacency_matrix)
            edge_conformer['source'] = source.astype(self.index_dtype)
            edge_conformer['target'] = target.astype(self.index_dtype)
            node_conformer['coordinate'] = conformer.coordinates.astype(
                self.feature_dtype
            )

            if self._bond_features:
                # NOTE(review): assumes the bond features hold one row per
                # (source, target) atom pair, in the same row-major order
                # as the flattened adjacency matrix - TODO confirm.
                edge_feature_keep = adjacency_matrix.reshape(-1)
                edge_conformer['feature'] = edge_feature[edge_feature_keep]

            if self.super_atom:
                node_conformer, edge_conformer = self._add_super_atom(
                    node_conformer, edge_conformer
                )
                # Cast the centroid so that the concatenation does not
                # promote the coordinates away from `feature_dtype`.
                centroid = np.asarray(
                    conformer.centroid, dtype=self.feature_dtype
                )[None]
                node_conformer['coordinate'] = np.concatenate(
                    [node_conformer['coordinate'], centroid], axis=0
                )

            tensor_list.append(
                tensors.GraphTensor(context, node_conformer, edge_conformer)
            )

        return tensor_list

    def stack(self, outputs):
        """Merge the per-molecule lists of conformer graphs into one graph."""
        # Flatten list of lists (due to multiple conformers per molecule)
        outputs = [x for xs in outputs for x in xs]
        return super().stack(outputs)

    def get_config(self):
        """Return the serializable constructor arguments."""
        config = super().get_config()
        config['conformer_generator'] = keras.saving.serialize_keras_object(
            self.conformer_generator
        )
        return config

    @classmethod
    def from_config(cls, config: dict):
        """Instantiate the featurizer from a `get_config` dictionary.

        The passed dictionary is left untouched; the conformer generator is
        deserialized into a shallow copy.
        """
        config = dict(config)
        config['conformer_generator'] = keras.saving.deserialize_keras_object(
            config['conformer_generator']
        )
        # The parent expects the dictionary itself, not unpacked keywords.
        return super().from_config(config)
591
+
592
+
593
def save_featurizer(
    featurizer: Featurizer,
    filepath: str | Path,
    overwrite: bool = True,
    **kwargs
) -> None:
    """Serialize `featurizer` to a JSON file.

    Args:
        featurizer: The featurizer to serialize.
        filepath: Destination path; must have the `.json` suffix. Missing
            parent directories are created.
        overwrite: If False and the file already exists, nothing is written.
        **kwargs: Accepted but not used.

    Raises:
        ValueError: If `filepath` does not have the `.json` suffix.
    """
    target = Path(filepath)
    if target.suffix != '.json':
        raise ValueError(
            'Invalid `filepath` extension for saving a `Featurizer`. '
            'A `Featurizer` should be saved as a JSON file.'
        )

    directory = target.parent
    if not directory.exists():
        directory.mkdir(parents=True, exist_ok=True)

    # Keep an existing file untouched unless overwriting is allowed.
    if not overwrite and target.exists():
        return

    payload = keras.saving.serialize_keras_object(featurizer)
    with open(target, 'w') as stream:
        json.dump(payload, stream, indent=4)
612
+
613
def load_featurizer(
    filepath: str | Path,
    **kwargs
) -> Featurizer:
    """Deserialize a featurizer from a JSON file.

    Args:
        filepath: Path of the file to read; must have the `.json` suffix.
        **kwargs: Accepted but not used.

    Returns:
        The deserialized featurizer, or None if the file does not exist.

    Raises:
        ValueError: If `filepath` does not have the `.json` suffix.
    """
    source = Path(filepath)
    if source.suffix != '.json':
        raise ValueError(
            'Invalid `filepath` extension for loading a `Featurizer`. '
            'A `Featurizer` should be saved as a JSON file.'
        )

    # A missing file is not an error: the caller receives None.
    if not source.exists():
        return None

    with open(source, 'r') as stream:
        serialized = json.load(stream)
    return keras.saving.deserialize_keras_object(serialized)
628
+
629
def _add_super_nodes(
    node: dict[str, np.ndarray],
    num_super_nodes: int = 1,
    feature_dtype: str = 'float32',
) -> dict[str, np.ndarray]:
    """Append super nodes with all-zero features to a copy of `node`.

    Args:
        node: Node data; must contain a two-dimensional `feature` array.
        num_super_nodes: Number of super nodes to append.
        feature_dtype: Dtype of the zero features of the super nodes.

    Returns:
        A deep copy of `node` with the extended `feature` array and a
        boolean `super` array which is True for the appended nodes.
    """
    augmented = copy.deepcopy(node)
    original_feature = augmented['feature']

    is_super = [False] * len(original_feature) + [True] * num_super_nodes
    augmented['super'] = np.array(is_super)

    placeholder_feature = np.zeros(
        [num_super_nodes, original_feature.shape[-1]], dtype=feature_dtype
    )
    augmented['feature'] = np.concatenate(
        [original_feature, placeholder_feature]
    )
    return augmented
641
+
642
def _add_super_edges(
    edge: dict[str, np.ndarray],
    num_nodes: int,
    num_super_nodes: int,
    feature_dtype: str,
    index_dtype: str,
) -> dict[str, np.ndarray]:
    """Connect every node to every super node, in both directions.

    Args:
        edge: Edge data with `source` and `target` index arrays, and
            optionally a two-dimensional `feature` and/or `length` array.
        num_nodes: Number of regular (non-super) nodes. Super nodes are
            given the indices `num_nodes`, `num_nodes + 1`, ...
        num_super_nodes: Number of super nodes.
        feature_dtype: Dtype of the returned `feature` and `length` arrays.
        index_dtype: Dtype of the returned `source` and `target` arrays.

    Returns:
        A deep copy of `edge` with `num_nodes * num_super_nodes * 2` edges
        appended. If `feature` is present, the new edges get all-zero
        features and a boolean `super` array marks them. If `length` is
        present, it gains a leading column which is set to one for the new
        edges only.
    """
    edge = copy.deepcopy(edge)
    num_super_edges = num_nodes * num_super_nodes * 2

    super_node_indices = (
        np.repeat(np.arange(num_super_nodes), [num_nodes]) + num_nodes
    )
    node_indices = (
        np.tile(np.arange(num_nodes), [num_super_nodes])
    )

    # First all node -> super node edges, then all super node -> node edges.
    edge['source'] = np.concatenate(
        [
            edge['source'],
            node_indices,
            super_node_indices,
        ]
    ).astype(index_dtype)
    edge['target'] = np.concatenate(
        [
            edge['target'],
            super_node_indices,
            node_indices,
        ]
    ).astype(index_dtype)

    if 'feature' in edge:
        num_edges = edge['feature'].shape[0]
        edge['super'] = np.asarray(
            [False] * num_edges + [True] * num_super_edges
        )
        # Create the zero rows in `feature_dtype` (and cast the result):
        # zeros of the default float64 dtype would silently promote the
        # whole feature array to float64.
        super_edge_feature = np.zeros(
            (num_super_edges, edge['feature'].shape[-1]), dtype=feature_dtype
        )
        edge['feature'] = np.concatenate(
            [edge['feature'], super_edge_feature]
        ).astype(feature_dtype)

    if 'length' in edge:
        # Prepend a column reserved for super edges to the existing rows.
        edge['length'] = np.pad(edge['length'], [(0, 0), (1, 0)])
        zero_array = np.zeros((num_super_edges,), dtype='int32')
        edge_length_dim = edge['length'].shape[1]
        # Each super edge is one-hot encoded in the reserved first column.
        virtual_edge_length = np.eye(edge_length_dim)[zero_array]
        edge['length'] = np.concatenate([edge['length'], virtual_edge_length])
        edge['length'] = edge['length'].astype(feature_dtype)

    return edge
683
+
684
+
685
def warn(message: str) -> None:
    """Emit `message` as a `UserWarning` attributed to the caller.

    Args:
        message: The text of the warning.
    """
    warnings.warn(
        message=message,
        category=UserWarning,
        # Level 2 reports the location of the code that called `warn`;
        # level 1 would always report this helper itself.
        stacklevel=2
    )
691
+
692
# Shorter aliases for the two molecular graph featurizer classes.
MolFeaturizer = MolGraphFeaturizer
MolFeaturizer3D = MolGraphFeaturizer3D