molcraft 0.1.0rc10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,727 @@
1
+ import warnings
2
+ import inspect
3
+ import keras
4
+ import json
5
+ import abc
6
+ import re
7
+ import typing
8
+ import numpy as np
9
+ import pandas as pd
10
+ import tensorflow as tf
11
+ import multiprocessing as mp
12
+
13
+ from rdkit.Chem import rdChemReactions
14
+
15
+ from pathlib import Path
16
+
17
+ from molcraft import tensors
18
+ from molcraft import features
19
+ from molcraft import records
20
+ from molcraft import chem
21
+ from molcraft import descriptors
22
+
23
+
24
@keras.saving.register_keras_serializable(package='molcraft')
class GraphFeaturizer(abc.ABC):

    """Abstract base class for graph featurizers.

    Subclasses implement `call`, which converts one input into a
    `tensors.GraphTensor` (or into something `tensors.from_dict` accepts).
    Calling an instance featurizes either a single input or a collection
    (list, numpy array, pandas Series/DataFrame or generator) of inputs.
    """

    @abc.abstractmethod
    def call(self, x: typing.Any, context: dict) -> tensors.GraphTensor:
        """Featurize a single input; must be provided by subclasses."""
        pass

    def get_config(self) -> dict:
        """Return the constructor arguments of the featurizer (none here)."""
        return dict()

    @classmethod
    def from_config(cls, config: dict) -> 'GraphFeaturizer':
        """Instantiate the featurizer from a `get_config` dictionary."""
        return cls(**config)

    def save(self, filepath: str | Path, *args, **kwargs) -> None:
        """Save this featurizer via `save_featurizer`."""
        save_featurizer(self, filepath, *args, **kwargs)

    @staticmethod
    def load(filepath: str | Path, *args, **kwargs) -> 'GraphFeaturizer':
        """Load a featurizer via `load_featurizer`."""
        return load_featurizer(filepath, *args, **kwargs)

    def write_records(self, inputs: typing.Iterable, path: str | Path, **kwargs) -> None:
        """Forward `inputs` and this featurizer to `records.write`."""
        records.write(inputs, featurizer=self, path=path, **kwargs)

    @staticmethod
    def read_records(path: str | Path, **kwargs) -> tf.data.Dataset:
        """Forward to `records.read` and return its result."""
        return records.read(path=path, **kwargs)

    def _call(self, inputs: typing.Any) -> tensors.GraphTensor:
        """Featurize one (possibly packed) input and return a graph tensor."""
        item, extras = _unpack_inputs(inputs)
        # Only hand over the context if `call` is able to receive it.
        if _call_kwargs(self.call):
            result = self.call(item, context=extras)
        else:
            result = self.call(item)
        if isinstance(result, tensors.GraphTensor):
            return result
        return tensors.from_dict(result)

    def __call__(
        self,
        inputs: typing.Iterable,
        *,
        multiprocessing: bool = False,
        processes: int | None = None,
        device: str = '/cpu:0',
    ) -> tensors.GraphTensor:
        """Featurize a single input or a collection of inputs.

        Args:
            inputs:
                A single input, or a list, numpy array, pandas Series,
                pandas DataFrame (iterated row by row) or generator of inputs.
            multiprocessing:
                Whether to featurize the collection with a process pool.
            processes:
                Number of worker processes passed to `multiprocessing.Pool`.
            device:
                Device scope entered while the pool is running.

        Returns:
            The graph of a single input, or the per-input graphs merged with
            `tf.stack` (scalar graphs) or `tf.concat` (otherwise).
        """
        collection_types = (
            list, np.ndarray, pd.Series, pd.DataFrame, typing.Generator
        )
        if not isinstance(inputs, collection_types):
            return self._call(inputs)

        if isinstance(inputs, pd.DataFrame):
            inputs = inputs.iterrows()
        elif isinstance(inputs, (np.ndarray, pd.Series)):
            inputs = inputs.tolist()

        if multiprocessing:
            with tf.device(device), mp.Pool(processes) as pool:
                graphs = pool.map(func=self._call, iterable=inputs)
        else:
            graphs = [self._call(item) for item in inputs]

        graphs = [graph for graph in graphs if graph is not None]
        # NOTE(review): an empty collection fails on `graphs[0]` below.
        merge = tf.stack if tensors.is_scalar(graphs[0]) else tf.concat
        return merge(graphs, axis=0)
97
+
98
+
99
@keras.saving.register_keras_serializable(package='molcraft')
class MolGraphFeaturizer(GraphFeaturizer):

    """Molecular graph featurizer.

    Converts SMILES or InChI strings to a molecular graph.

    The molecular graph may encode a single molecule or a batch of molecules.

    Example:

    >>> import molcraft
    >>>
    >>> featurizer = molcraft.featurizers.MolGraphFeaturizer(
    ...     atom_features=[
    ...         molcraft.features.AtomType(),
    ...         molcraft.features.NumHydrogens(),
    ...         molcraft.features.Degree(),
    ...     ],
    ...     bond_features=[
    ...         molcraft.features.BondType(),
    ...     ],
    ...     super_node=False,
    ...     self_loops=False,
    ... )
    >>>
    >>> graph = featurizer(["N[C@@H](C)C(=O)O", "N[C@@H](CS)C(=O)O"])
    >>> graph
    GraphTensor(
        context={
            'size': <tf.Tensor: shape=[2], dtype=int32>
        },
        node={
            'feature': <tf.Tensor: shape=[13, 129], dtype=float32>
        },
        edge={
            'source': <tf.Tensor: shape=[22], dtype=int32>,
            'target': <tf.Tensor: shape=[22], dtype=int32>,
            'feature': <tf.Tensor: shape=[22, 5], dtype=float32>
        }
    )

    Args:
        atom_features:
            A list of `features.Feature` encoded as the node features.
            The strings 'auto' and 'default' select a built-in set.
        bond_features:
            A list of `features.Feature` encoded as the edge features.
            The strings 'auto' and 'default' select a built-in set.
        molecule_features:
            A list of `descriptors.Descriptor` encoded as the context feature.
            The strings 'auto' and 'default' select a built-in set.
        super_node:
            A boolean specifying whether to include a super node.
        self_loops:
            A boolean specifying whether self loops exist.
        include_hydrogens:
            A boolean specifying whether hydrogens should be encoded as nodes.
        wildcards:
            A boolean specifying whether wildcards exist. If True, wildcard labels will
            be encoded in the graph and separately embedded in `layers.NodeEmbedding`.
    """

    def __init__(
        self,
        atom_features: list[features.Feature] | str = 'auto',
        bond_features: list[features.Feature] | str | None = 'auto',
        molecule_features: list[descriptors.Descriptor] | str | None = None,
        super_node: bool = False,
        self_loops: bool = False,
        include_hydrogens: bool = False,
        wildcards: bool = False,
    ) -> None:
        # Only real strings are compared against the shortcuts, so array-like
        # feature collections are never compared element-wise with a string.
        shortcuts = ('auto', 'default')

        if isinstance(atom_features, str) and atom_features in shortcuts:
            atom_features = [features.AtomType(), features.Degree()]
            if not include_hydrogens:
                # Hydrogens are not nodes, so encode their count per atom.
                atom_features += [features.NumHydrogens()]

        if isinstance(bond_features, str) and bond_features in shortcuts:
            bond_features = [features.BondType()]

        if isinstance(molecule_features, str) and molecule_features in shortcuts:
            molecule_features = [
                descriptors.MolWeight(),
                descriptors.TotalPolarSurfaceArea(),
                descriptors.LogP(),
                descriptors.MolarRefractivity(),
                descriptors.NumHeavyAtoms(),
                descriptors.NumHeteroatoms(),
                descriptors.NumHydrogenDonors(),
                descriptors.NumHydrogenAcceptors(),
                descriptors.NumRotatableBonds(),
                descriptors.NumRings(),
            ]

        self._atom_features = atom_features
        self._bond_features = bond_features
        self._molecule_features = molecule_features
        self._include_hydrogens = include_hydrogens
        self._wildcards = wildcards
        self._self_loops = self_loops
        self._super_node = super_node

    def call(
        self,
        mol: str | chem.Mol | tuple,
        context: dict | None = None
    ) -> tensors.GraphTensor:
        """Featurize a single molecule.

        Args:
            mol:
                A string encoding of the molecule, an RDKit molecule, or an
                already constructed `chem.Mol`.
            context:
                Optional extra entries merged into the graph context.

        Returns:
            A `tensors.GraphTensor` with 'context', 'node' and 'edge' data.
        """
        if isinstance(mol, str):
            mol = chem.Mol.from_encoding(
                mol, explicit_hs=self._include_hydrogens
            )
        elif isinstance(mol, chem.RDKitMol):
            mol = chem.Mol.cast(mol)

        data = {'context': {}, 'node': {}, 'edge': {}}

        data['context']['size'] = np.asarray(mol.num_atoms)

        if self._molecule_features is not None:
            data['context']['feature'] = np.concatenate(
                [f(mol) for f in self._molecule_features], axis=-1
            )

        if context:
            # User supplied context may override the entries set above.
            data['context'].update(context)

        data['node']['feature'] = np.concatenate(
            [f(mol) for f in self._atom_features], axis=-1
        )

        if self._wildcards:
            # Wildcard atoms get label + 1 (0 is reserved for regular atoms).
            wildcard_labels = np.asarray([
                (atom.label or 0) + 1 if atom.symbol == "*" else 0
                for atom in mol.atoms
            ])
            data['node']['wildcard'] = wildcard_labels
            # Zero out the regular features of wildcard atoms.
            data['node']['feature'] = np.where(
                wildcard_labels[:, None],
                np.zeros_like(data['node']['feature']),
                data['node']['feature']
            )

        data['edge']['source'], data['edge']['target'] = mol.adjacency(
            fill='full', sparse=True, self_loops=self._self_loops
        )

        if self._bond_features is not None:
            bond_features = np.concatenate(
                [f(mol) for f in self._bond_features], axis=-1
            )
            if self._self_loops:
                # Extra all-zero row, addressed with index -1 by self loops.
                bond_features = np.pad(bond_features, [(0, 1), (0, 0)])

            bond_indices = [
                mol.get_bond_between_atoms(i, j).index if (i != j) else -1
                for (i, j) in zip(data['edge']['source'], data['edge']['target'])
            ]

            data['edge']['feature'] = bond_features[bond_indices]

        if self._super_node:
            data = _add_super_node(data)

        return tensors.GraphTensor(**_convert_dtypes(data))

    def get_config(self):
        """Return the constructor arguments, with features serialized."""
        config = super().get_config()
        config.update({
            'atom_features': keras.saving.serialize_keras_object(
                self._atom_features
            ),
            'bond_features': keras.saving.serialize_keras_object(
                self._bond_features
            ),
            'molecule_features': keras.saving.serialize_keras_object(
                self._molecule_features
            ),
            'super_node': self._super_node,
            'self_loops': self._self_loops,
            'include_hydrogens': self._include_hydrogens,
            'wildcards': self._wildcards,
        })
        return config

    @classmethod
    def from_config(cls, config: dict):
        """Instantiate the featurizer from a `get_config` dictionary.

        The passed dictionary is left untouched: deserialized features are
        written to a shallow copy, so the same config can be reused.
        """
        config = dict(config)
        for key in ('atom_features', 'bond_features', 'molecule_features'):
            config[key] = keras.saving.deserialize_keras_object(config[key])
        return cls(**config)
303
+
304
+
305
@keras.saving.register_keras_serializable(package='molcraft')
class MolGraphFeaturizer3D(MolGraphFeaturizer):

    """3D Molecular graph featurizer.

    Converts SMILES or InChI strings to a 3d molecular graph.

    The molecular graph may encode a single molecule or a batch of molecules.

    Example:

    >>> import molcraft
    >>>
    >>> featurizer = molcraft.featurizers.MolGraphFeaturizer3D(
    ...     atom_features=[
    ...         molcraft.features.AtomType(),
    ...         molcraft.features.NumHydrogens(),
    ...         molcraft.features.Degree(),
    ...     ],
    ...     radius=5.0,
    ...     random_seed=42,
    ... )
    >>>
    >>> graph = featurizer(["N[C@@H](C)C(=O)O", "N[C@@H](CS)C(=O)O"])
    >>> graph
    GraphTensor(
        context={
            'size': <tf.Tensor: shape=[2], dtype=int32>
        },
        node={
            'feature': <tf.Tensor: shape=[13, 129], dtype=float32>,
            'coordinate': <tf.Tensor: shape=[13, 3], dtype=float32>
        },
        edge={
            'source': <tf.Tensor: shape=[72], dtype=int32>,
            'target': <tf.Tensor: shape=[72], dtype=int32>,
            'feature': <tf.Tensor: shape=[72, 12], dtype=float32>
        }
    )

    Args:
        atom_features:
            A list of `features.Feature` encoded as the node features.
        pair_features:
            A list of `features.PairFeature` encoded as the edge features.
        molecule_features:
            A list of `descriptors.Descriptor` encoded as the context feature.
        super_node:
            A boolean specifying whether to include a super node.
        self_loops:
            A boolean specifying whether self loops exist.
        include_hydrogens:
            A boolean specifying whether hydrogens should be encoded as nodes.
        wildcards:
            A boolean specifying whether wildcards exist. If True, wildcard labels will
            be encoded in the graph and separately embedded in `layers.NodeEmbedding`.
        radius:
            A floating point value specifying maximum edge length.
        random_seed:
            An integer specifying the random seed for the conformer generation.
    """

    def __init__(
        self,
        atom_features: list[features.Feature] | str = 'auto',
        pair_features: list[features.PairFeature] | str = 'auto',
        molecule_features: list[descriptors.Descriptor] | str | None = None,
        super_node: bool = False,
        self_loops: bool = False,
        include_hydrogens: bool = False,
        wildcards: bool = False,
        radius: int | float | None = 6.0,
        random_seed: int | None = None,
        **kwargs,
    ) -> None:
        # `bond_features` shows up when rebuilding from a config (the parent
        # `get_config` stores it); it is discarded because edges are described
        # by pair features here.
        # NOTE(review): any other unexpected keyword is silently ignored.
        kwargs.pop('bond_features', None)
        super().__init__(
            atom_features=atom_features,
            bond_features=None,
            molecule_features=molecule_features,
            super_node=super_node,
            self_loops=self_loops,
            include_hydrogens=include_hydrogens,
            wildcards=wildcards,
        )

        if isinstance(pair_features, str) and pair_features in ('auto', 'default'):
            pair_features = [features.PairDistance()]

        self._pair_features = pair_features
        # A falsy radius (None or 0) disables the radius argument.
        self._radius = float(radius) if radius else None
        self._random_seed = random_seed

    def call(
        self,
        mol: str | chem.Mol | tuple,
        context: dict | None = None
    ) -> tensors.GraphTensor:
        """Featurize a single molecule, including atom coordinates.

        Args:
            mol:
                A string encoding of the molecule, an RDKit molecule, or an
                already constructed `chem.Mol`.
            context:
                Optional extra entries merged into the graph context.

        Returns:
            A `tensors.GraphTensor` with 'context', 'node' and 'edge' data;
            node data additionally holds 'coordinate'.
        """
        if isinstance(mol, str):
            # Always parsed with explicit hydrogens; they are stripped again
            # below (after conformer embedding) if they should not be nodes.
            mol = chem.Mol.from_encoding(
                mol, explicit_hs=True
            )
        elif isinstance(mol, chem.RDKitMol):
            mol = chem.Mol.cast(mol)

        if mol.num_conformers == 0:
            mol = chem.embed_conformers(
                mol, num_conformers=1, random_seed=self._random_seed
            )

        if not self._include_hydrogens:
            mol = chem.remove_hs(mol)

        data = {'context': {}, 'node': {}, 'edge': {}}

        data['context']['size'] = np.asarray(mol.num_atoms)

        if self._molecule_features is not None:
            data['context']['feature'] = np.concatenate(
                [f(mol) for f in self._molecule_features], axis=-1
            )

        if context:
            # User supplied context may override the entries set above.
            data['context'].update(context)

        conformer = mol.get_conformer()

        data['node']['feature'] = np.concatenate(
            [f(mol) for f in self._atom_features], axis=-1
        )
        data['node']['coordinate'] = conformer.coordinates

        if self._wildcards:
            # Wildcard atoms get label + 1 (0 is reserved for regular atoms).
            wildcard_labels = np.asarray([
                (atom.label or 0) + 1 if atom.symbol == "*" else 0
                for atom in mol.atoms
            ])
            data['node']['wildcard'] = wildcard_labels
            # Zero out the regular features of wildcard atoms.
            data['node']['feature'] = np.where(
                wildcard_labels[:, None],
                np.zeros_like(data['node']['feature']),
                data['node']['feature']
            )

        adjacency_matrix = conformer.adjacency(
            fill='full', radius=self._radius, sparse=False, self_loops=self._self_loops,
        )

        data['edge']['source'], data['edge']['target'] = np.where(adjacency_matrix)

        if self._pair_features is not None:
            pair_features = np.concatenate(
                [f(mol) for f in self._pair_features], axis=-1
            )
            # Keep only the pairs that are edges in the adjacency matrix.
            # NOTE(review): assumes pair features are laid out in the same
            # flattened order as `adjacency_matrix` - confirm.
            pair_keep = adjacency_matrix.reshape(-1).astype(bool)
            data['edge']['feature'] = pair_features[pair_keep]

        if self._super_node:
            data = _add_super_node(data)
            # The super node is placed at the conformer centroid.
            data['node']['coordinate'] = np.concatenate(
                [data['node']['coordinate'], conformer.centroid[None]], axis=0
            )

        return tensors.GraphTensor(**_convert_dtypes(data))

    @property
    def random_seed(self) -> int | None:
        """Random seed passed to the conformer embedding."""
        return self._random_seed

    @random_seed.setter
    def random_seed(self, value: int) -> None:
        self._random_seed = value

    def get_config(self):
        """Return the constructor arguments, with features serialized."""
        config = super().get_config()
        config['radius'] = self._radius
        config['pair_features'] = keras.saving.serialize_keras_object(
            self._pair_features
        )
        config['random_seed'] = self._random_seed
        return config

    @classmethod
    def from_config(cls, config: dict):
        """Instantiate the featurizer from a `get_config` dictionary.

        The passed dictionary is left untouched: deserialization works on a
        shallow copy, so the same config can be reused.
        """
        config = dict(config)
        config['pair_features'] = keras.saving.deserialize_keras_object(
            config['pair_features']
        )
        return super().from_config(config)
497
+
498
+
499
class PeptideGraphFeaturizer(MolGraphFeaturizer):

    """Peptide graph featurizer.

    Assembles a peptide from a sequence of monomer symbols and featurizes it
    like `MolGraphFeaturizer`. Each node additionally carries a
    'subgraph_indicator', read from the 'react_idx' property of its atom
    (-1 for the super node).

    Args:
        monomers:
            Optional mapping from monomer symbol to its string encoding. It is
            merged on top of the built-in monomers, so entries with an
            existing symbol replace the built-in one.

    The remaining arguments are forwarded to `MolGraphFeaturizer`.
    """

    # NOTE(review): unlike the other featurizers in this file, this class is
    # not decorated with `keras.saving.register_keras_serializable`; confirm
    # that saving/loading works as intended.

    def __init__(
        self,
        atom_features: list[features.Feature] | str = 'auto',
        bond_features: list[features.Feature] | str | None = 'auto',
        molecule_features: list[descriptors.Descriptor] | str | None = None,
        super_node: bool = False,
        self_loops: bool = False,
        include_hydrogens: bool = False,
        wildcards: bool = False,
        monomers: dict[str, str] | None = None
    ) -> None:
        super().__init__(
            atom_features=atom_features,
            bond_features=bond_features,
            molecule_features=molecule_features,
            super_node=super_node,
            self_loops=self_loops,
            include_hydrogens=include_hydrogens,
            wildcards=wildcards,
        )
        # User supplied monomers take precedence over the built-in ones.
        self.monomers = {**_default_monomers, **(monomers or {})}

    def call(self, mol, context=None):
        """Featurize a peptide given as a sequence of monomer symbols."""
        mol = self.mol_from_sequence(mol)
        subgraph_indicator = [
            int(atom.GetProp('react_idx')) for atom in mol.atoms
        ]
        if self._super_node:
            # The super node does not belong to any monomer.
            subgraph_indicator.append(-1)
        graph = super().call(mol, context=context)
        return graph.update({'node': {'subgraph_indicator': subgraph_indicator}})

    def get_config(self) -> dict:
        """Return the constructor arguments, including the monomer mapping."""
        config = super().get_config()
        config['monomers'] = dict(self.monomers)
        return config

    def mol_from_sequence(self, sequence: str) -> chem.Mol:
        """Assemble the peptide molecule described by `sequence`.

        Args:
            sequence:
                A string of monomer symbols as matched by `_monomer_pattern`.

        Returns:
            The sanitized polymer as returned by `chem.sanitize_mol`.

        Raises:
            ValueError: If no monomer symbol is found in `sequence`, or if
                the reaction yields no product.
            KeyError: If a symbol has no entry in `self.monomers`.
        """
        symbols = [
            match.group(0) for match in re.finditer(_monomer_pattern, sequence)
        ]
        if not symbols:
            raise ValueError(
                f'Could not obtain monomers from sequence: {sequence!r}.'
            )
        unknown = [s for s in symbols if s not in self.monomers]
        if unknown:
            raise KeyError(
                f'Unknown monomer symbol(s) {unknown} in sequence {sequence!r}.'
            )
        monomers = [
            chem.Mol.from_encoding(self.monomers[s]) for s in symbols
        ]

        # One reactant template per monomer; the product is the concatenated
        # backbone. Every non-terminal monomer loses its unmapped '[O]'.
        # NOTE(review): the first nitrogen is given atom map number 0; confirm
        # that RDKit treats map number 0 as a regular mapping here.
        last = len(monomers) - 1
        reactant_templates = []
        product_parts = []
        for i in range(len(monomers)):
            offset = i * 4
            backbone = (
                f'[N:{offset}][C:{offset + 1}][C:{offset + 2}]'
                f'(=[O:{offset + 3}])'
            )
            if i == last:
                c_terminal = f'[O:{offset + 4}]'
                reactant_templates.append(backbone + c_terminal)
                product_parts.append(backbone + c_terminal)
            else:
                reactant_templates.append(backbone + '[O]')
                product_parts.append(backbone)

        reaction = rdChemReactions.ReactionFromSmarts(
            '.'.join(reactant_templates) + '>>' + ''.join(product_parts)
        )
        products = reaction.RunReactants(monomers)
        if not len(products):
            raise ValueError(f'Could not obtain polymer from monomers: {monomers}.')
        polymer = products[0][0]
        return chem.sanitize_mol(polymer)
566
+
567
+
568
+ def save_featurizer(
569
+ featurizer: GraphFeaturizer,
570
+ filepath: str | Path,
571
+ overwrite: bool = True,
572
+ **kwargs
573
+ ) -> None:
574
+ filepath = Path(filepath)
575
+ if filepath.suffix != '.json':
576
+ raise ValueError(
577
+ 'Invalid `filepath` extension for saving a `GraphFeaturizer`. '
578
+ 'A `GraphFeaturizer` should be saved as a JSON file.'
579
+ )
580
+ if not filepath.parent.exists():
581
+ filepath.parent.mkdir(parents=True, exist_ok=True)
582
+ if filepath.exists() and not overwrite:
583
+ return
584
+ serialized_featurizer = keras.saving.serialize_keras_object(featurizer)
585
+ with open(filepath, 'w') as f:
586
+ json.dump(serialized_featurizer, f, indent=4)
587
+
588
+ def load_featurizer(
589
+ filepath: str | Path,
590
+ **kwargs
591
+ ) -> GraphFeaturizer:
592
+ filepath = Path(filepath)
593
+ if filepath.suffix != '.json':
594
+ raise ValueError(
595
+ 'Invalid `filepath` extension for loading a `GraphFeaturizer`. '
596
+ 'A `GraphFeaturizer` should be saved as a JSON file.'
597
+ )
598
+ if not filepath.exists():
599
+ return
600
+ with open(filepath, 'r') as f:
601
+ config = json.load(f)
602
+ return keras.saving.deserialize_keras_object(config)
603
+
604
+ def _add_super_node(
605
+ data: dict[str, dict[str, np.ndarray]]
606
+ ) -> dict[str, dict[str, np.ndarray]]:
607
+
608
+ data['context']['size'] += 1
609
+
610
+ num_nodes = data['node']['feature'].shape[0]
611
+ num_edges = data['edge']['source'].shape[0]
612
+ super_node_index = num_nodes
613
+
614
+ add_self_loops = np.any(
615
+ data['edge']['source'] == data['edge']['target']
616
+ )
617
+ if add_self_loops:
618
+ data['edge']['source'] = np.append(
619
+ data['edge']['source'], super_node_index
620
+ )
621
+ data['edge']['target'] = np.append(
622
+ data['edge']['target'], super_node_index
623
+ )
624
+
625
+ data['node']['feature'] = np.pad(data['node']['feature'], [(0, 1), (0, 0)])
626
+ data['node']['super'] = np.asarray([False] * num_nodes + [True])
627
+ if 'wildcard' in data['node']:
628
+ data['node']['wildcard'] = np.pad(data['node']['wildcard'], [(0, 1)])
629
+
630
+ node_indices = list(range(num_nodes))
631
+ super_node_indices = [super_node_index] * num_nodes
632
+
633
+ data['edge']['source'] = np.append(
634
+ data['edge']['source'], node_indices + super_node_indices
635
+ )
636
+ data['edge']['target'] = np.append(
637
+ data['edge']['target'], super_node_indices + node_indices
638
+ )
639
+
640
+ total_num_edges = data['edge']['source'].shape[0]
641
+ num_super_edges = (total_num_edges - num_edges)
642
+ data['edge']['super'] = np.asarray(
643
+ [False] * num_edges + [True] * num_super_edges
644
+ )
645
+
646
+ if 'feature' in data['edge']:
647
+ data['edge']['feature'] = np.pad(
648
+ data['edge']['feature'], [(0, num_super_edges), (0, 0)]
649
+ )
650
+
651
+ return data
652
+
653
+ def _convert_dtypes(data: dict[str, dict[str, np.ndarray]]) -> np.ndarray:
654
+ for outer_key, inner_dict in data.items():
655
+ for inner_key, inner_value in inner_dict.items():
656
+ if inner_key in ['source', 'target', 'size']:
657
+ data[outer_key][inner_key] = inner_value.astype(np.int32)
658
+ elif np.issubdtype(inner_value.dtype, np.integer):
659
+ data[outer_key][inner_key] = inner_value.astype(np.int32)
660
+ elif np.issubdtype(inner_value.dtype, np.floating):
661
+ data[outer_key][inner_key] = inner_value.astype(np.float32)
662
+ return data
663
+
664
+ def _unpack_inputs(inputs) -> tuple:
665
+ if isinstance(inputs, np.ndarray):
666
+ inputs = tuple(inputs.tolist())
667
+ elif isinstance(inputs, list):
668
+ inputs = tuple(inputs)
669
+ if not isinstance(inputs, tuple):
670
+ mol, context = inputs, {}
671
+ elif isinstance(inputs[0], int) and isinstance(inputs[1], pd.Series):
672
+ index, series = inputs
673
+ mol = series.values[0]
674
+ context = dict(
675
+ zip(
676
+ map(_snake_case, series.index[1:]),
677
+ map(np.asarray, series.values[1:])
678
+ )
679
+ )
680
+ context['index'] = np.asarray(index)
681
+ else:
682
+ mol, *context = inputs
683
+ context = dict(zip(['label', 'sample_weight'], map(np.asarray, context)))
684
+ return mol, context
685
+
686
+ def _snake_case(x: str) -> str:
687
+ return '_'.join(x.lower().split())
688
+
689
+ def _call_kwargs(func) -> bool:
690
+ signature = inspect.signature(func)
691
+ return any(
692
+ (param.kind == inspect.Parameter.VAR_KEYWORD) or (param.name == 'context')
693
+ for param in signature.parameters.values()
694
+ )
695
+
696
+ _monomer_pattern = "|".join([
697
+ r'(\[[A-Za-z0-9]+\]-[A-Z]\[[A-Za-z0-9]+\])', # [Mod]-A[Mod]
698
+ r'([A-Z]\[[A-Za-z0-9]+\]-\[[A-Za-z0-9]+\])', # A[Mod]-[Mod]
699
+ r'(\[[A-Za-z0-9]+\]-[A-Z])', # [Mod]-A
700
+ r'([A-Z]-\[[A-Za-z0-9]+\])', # A-[Mod]
701
+ r'([A-Z]\[[A-Za-z0-9]+\])', # A[Mod]
702
+ r'([A-Z])', # A
703
+ r'\(.*?\)' # (A)
704
+ ])
705
+
706
+ _default_monomers = {
707
+ "A": "N[C@@H](C)C(=O)O",
708
+ "C": "N[C@@H](CS)C(=O)O",
709
+ "D": "N[C@@H](CC(=O)O)C(=O)O",
710
+ "E": "N[C@@H](CCC(=O)O)C(=O)O",
711
+ "F": "N[C@@H](Cc1ccccc1)C(=O)O",
712
+ "G": "NCC(=O)O",
713
+ "H": "N[C@@H](CC1=CN=C-N1)C(=O)O",
714
+ "I": "N[C@@H](C(CC)C)C(=O)O",
715
+ "K": "N[C@@H](CCCCN)C(=O)O",
716
+ "L": "N[C@@H](CC(C)C)C(=O)O",
717
+ "M": "N[C@@H](CCSC)C(=O)O",
718
+ "N": "N[C@@H](CC(=O)N)C(=O)O",
719
+ "P": "N1[C@@H](CCC1)C(=O)O",
720
+ "Q": "N[C@@H](CCC(=O)N)C(=O)O",
721
+ "R": "N[C@@H](CCCNC(=N)N)C(=O)O",
722
+ "S": "N[C@@H](CO)C(=O)O",
723
+ "T": "N[C@@H](C(O)C)C(=O)O",
724
+ "V": "N[C@@H](C(C)C)C(=O)O",
725
+ "W": "N[C@@H](CC(=CN2)C1=C2C=CC=C1)C(=O)O",
726
+ "Y": "N[C@@H](Cc1ccc(O)cc1)C(=O)O",
727
+ }