molcraft 0.1.0rc10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- molcraft/__init__.py +18 -0
- molcraft/callbacks.py +100 -0
- molcraft/chem.py +714 -0
- molcraft/datasets.py +132 -0
- molcraft/descriptors.py +149 -0
- molcraft/features.py +379 -0
- molcraft/featurizers.py +727 -0
- molcraft/layers.py +2034 -0
- molcraft/losses.py +37 -0
- molcraft/models.py +627 -0
- molcraft/ops.py +195 -0
- molcraft/records.py +187 -0
- molcraft/tensors.py +561 -0
- molcraft/trainers.py +212 -0
- molcraft-0.1.0rc10.dist-info/METADATA +118 -0
- molcraft-0.1.0rc10.dist-info/RECORD +19 -0
- molcraft-0.1.0rc10.dist-info/WHEEL +5 -0
- molcraft-0.1.0rc10.dist-info/licenses/LICENSE +21 -0
- molcraft-0.1.0rc10.dist-info/top_level.txt +1 -0
molcraft/featurizers.py
ADDED
|
@@ -0,0 +1,727 @@
|
|
|
1
|
+
import warnings
|
|
2
|
+
import inspect
|
|
3
|
+
import keras
|
|
4
|
+
import json
|
|
5
|
+
import abc
|
|
6
|
+
import re
|
|
7
|
+
import typing
|
|
8
|
+
import numpy as np
|
|
9
|
+
import pandas as pd
|
|
10
|
+
import tensorflow as tf
|
|
11
|
+
import multiprocessing as mp
|
|
12
|
+
|
|
13
|
+
from rdkit.Chem import rdChemReactions
|
|
14
|
+
|
|
15
|
+
from pathlib import Path
|
|
16
|
+
|
|
17
|
+
from molcraft import tensors
|
|
18
|
+
from molcraft import features
|
|
19
|
+
from molcraft import records
|
|
20
|
+
from molcraft import chem
|
|
21
|
+
from molcraft import descriptors
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
@keras.saving.register_keras_serializable(package='molcraft')
class GraphFeaturizer(abc.ABC):

    """Base graph featurizer.

    Subclasses implement `call`, which converts one (already unpacked) input
    into a `tensors.GraphTensor` or into a dict accepted by
    `tensors.from_dict`. Calling the featurizer on a collection of inputs
    featurizes each element and combines the results into a single graph.
    """

    @abc.abstractmethod
    def call(self, x: typing.Any, context: dict) -> tensors.GraphTensor:
        """Featurize a single input.

        Args:
            x: One input, e.g. a molecule encoding.
            context: Extra values unpacked from the raw input (e.g. label,
                sample weight). Implementations may omit this parameter; it
                is only passed when `call` declares `context` or `**kwargs`.

        Returns:
            A `tensors.GraphTensor`, or a dict convertible with
            `tensors.from_dict`.
        """

    def get_config(self) -> dict:
        """Return the constructor arguments needed to rebuild this featurizer."""
        return {}

    @classmethod
    def from_config(cls, config: dict) -> 'GraphFeaturizer':
        """Rebuild a featurizer from the output of `get_config`."""
        return cls(**config)

    def save(self, filepath: str | Path, *args, **kwargs) -> None:
        """Save this featurizer as JSON (see `save_featurizer`)."""
        save_featurizer(self, filepath, *args, **kwargs)

    @staticmethod
    def load(filepath: str | Path, *args, **kwargs) -> 'GraphFeaturizer':
        """Load a featurizer from JSON (see `load_featurizer`)."""
        return load_featurizer(filepath, *args, **kwargs)

    def write_records(self, inputs: typing.Iterable, path: str | Path, **kwargs) -> None:
        """Featurize `inputs` with this featurizer and write them via `records.write`."""
        records.write(
            inputs, featurizer=self, path=path, **kwargs
        )

    @staticmethod
    def read_records(path: str | Path, **kwargs) -> tf.data.Dataset:
        """Read records previously written with `write_records` (via `records.read`)."""
        return records.read(
            path=path, **kwargs
        )

    def _call(self, inputs: typing.Any) -> tensors.GraphTensor:
        """Unpack one raw input, featurize it and coerce the result to a graph."""
        inputs, context = _unpack_inputs(inputs)
        # Only forward `context` to implementations that can receive it.
        if _call_kwargs(self.call):
            graph = self.call(inputs, context=context)
        else:
            graph = self.call(inputs)
        if not isinstance(graph, tensors.GraphTensor):
            graph = tensors.from_dict(graph)
        return graph

    def __call__(
        self,
        inputs: typing.Iterable,
        *,
        multiprocessing: bool = False,
        processes: int | None = None,
        device: str = '/cpu:0',
    ) -> tensors.GraphTensor:
        """Featurize a single input or a collection of inputs.

        Inputs that are not a list, NumPy array, pandas Series/DataFrame or
        iterator are treated as one example and featurized directly. Arrays
        and Series are converted to lists; DataFrames are iterated row by
        row with `iterrows` (the first column is used as the molecule, see
        `_unpack_inputs`).

        Args:
            inputs: One example or a collection of examples.
            multiprocessing: Whether to featurize with a `multiprocessing.Pool`.
            processes: Number of worker processes passed to
                `multiprocessing.Pool` (None uses its default).
            device: TensorFlow device scope entered while the pool runs.

        Returns:
            The featurized graph(s) combined into a single `GraphTensor`.

        Raises:
            ValueError: If a collection of inputs yields no graphs (e.g. it
                is empty).
        """
        # A 0-d array wraps a single value (e.g. `np.asarray('CCO')`). Without
        # this, `.tolist()` below would return a bare string that is then
        # iterated character by character.
        if isinstance(inputs, np.ndarray) and inputs.ndim == 0:
            return self._call(inputs.item())

        if not isinstance(
            inputs, (list, np.ndarray, pd.Series, pd.DataFrame, typing.Iterator)
        ):
            return self._call(inputs)

        if isinstance(inputs, (np.ndarray, pd.Series)):
            inputs = inputs.tolist()
        elif isinstance(inputs, pd.DataFrame):
            # Yields (index, row) pairs, recognised by `_unpack_inputs`.
            inputs = inputs.iterrows()

        if not multiprocessing:
            outputs = [self._call(x) for x in inputs]
        else:
            with tf.device(device):
                with mp.Pool(processes) as pool:
                    outputs = pool.map(func=self._call, iterable=inputs)

        outputs = [x for x in outputs if x is not None]
        if not outputs:
            raise ValueError(
                'Cannot featurize an empty collection of inputs: '
                'no graphs were produced.'
            )
        # If the first graph is scalar (per `tensors.is_scalar`) the graphs
        # are stacked into a batch; otherwise they are concatenated.
        if tensors.is_scalar(outputs[0]):
            return tf.stack(outputs, axis=0)
        return tf.concat(outputs, axis=0)
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
@keras.saving.register_keras_serializable(package='molcraft')
class MolGraphFeaturizer(GraphFeaturizer):

    """Molecular graph featurizer.

    Converts SMILES or InChI strings to a molecular graph.

    The molecular graph may encode a single molecule or a batch of molecules.

    Example:

    >>> import molcraft
    >>>
    >>> featurizer = molcraft.featurizers.MolGraphFeaturizer(
    ...     atom_features=[
    ...         molcraft.features.AtomType(),
    ...         molcraft.features.NumHydrogens(),
    ...         molcraft.features.Degree(),
    ...     ],
    ...     bond_features=[
    ...         molcraft.features.BondType(),
    ...     ],
    ...     super_node=False,
    ...     self_loops=False,
    ... )
    >>>
    >>> graph = featurizer(["N[C@@H](C)C(=O)O", "N[C@@H](CS)C(=O)O"])
    >>> graph
    GraphTensor(
        context={
            'size': <tf.Tensor: shape=[2], dtype=int32>
        },
        node={
            'feature': <tf.Tensor: shape=[13, 129], dtype=float32>
        },
        edge={
            'source': <tf.Tensor: shape=[22], dtype=int32>,
            'target': <tf.Tensor: shape=[22], dtype=int32>,
            'feature': <tf.Tensor: shape=[22, 5], dtype=float32>
        }
    )

    Args:
        atom_features:
            A list of `features.Feature` encoded as the node features, or
            'auto'/'default' for `AtomType`, `Degree` and (unless
            `include_hydrogens` is set) `NumHydrogens`.
        bond_features:
            A list of `features.Feature` encoded as the edge features,
            'auto'/'default' for `BondType`, or None / an empty list for no
            edge features.
        molecule_features:
            A list of `descriptors.Descriptor` encoded as the context feature,
            'auto'/'default' for a standard descriptor set, or None / an empty
            list for no context feature.
        super_node:
            A boolean specifying whether to include a super node.
        self_loops:
            A boolean specifying whether self loops exist.
        include_hydrogens:
            A boolean specifying whether hydrogens should be encoded as nodes.
        wildcards:
            A boolean specifying whether wildcards exist. If True, wildcard labels will
            be encoded in the graph and separately embedded in `layers.NodeEmbedding`.

    Raises:
        ValueError: If a feature argument is a string other than 'auto' or
            'default'.
    """

    def __init__(
        self,
        atom_features: list[features.Feature] | str = 'auto',
        bond_features: list[features.Feature] | str | None = 'auto',
        molecule_features: list[descriptors.Descriptor] | str | None = None,
        super_node: bool = False,
        self_loops: bool = False,
        include_hydrogens: bool = False,
        wildcards: bool = False,
    ) -> None:
        if self._use_default(atom_features, 'atom_features'):
            atom_features = [features.AtomType(), features.Degree()]
            # Hydrogen counts are only added when hydrogens are not
            # themselves encoded as nodes.
            if not include_hydrogens:
                atom_features.append(features.NumHydrogens())

        if self._use_default(bond_features, 'bond_features'):
            bond_features = [features.BondType()]

        if self._use_default(molecule_features, 'molecule_features'):
            molecule_features = [
                descriptors.MolWeight(),
                descriptors.TotalPolarSurfaceArea(),
                descriptors.LogP(),
                descriptors.MolarRefractivity(),
                descriptors.NumHeavyAtoms(),
                descriptors.NumHeteroatoms(),
                descriptors.NumHydrogenDonors(),
                descriptors.NumHydrogenAcceptors(),
                descriptors.NumRotatableBonds(),
                descriptors.NumRings(),
            ]

        self._atom_features = atom_features
        self._bond_features = bond_features
        self._molecule_features = molecule_features
        self._include_hydrogens = include_hydrogens
        self._wildcards = wildcards
        self._self_loops = self_loops
        self._super_node = super_node

    @staticmethod
    def _use_default(value: typing.Any, name: str) -> bool:
        """Return True if `value` requests the default feature set.

        Raises:
            ValueError: If `value` is a string other than 'auto'/'default'
                (it would otherwise be iterated character by character).
        """
        if not isinstance(value, str):
            return False
        if value in ('auto', 'default'):
            return True
        raise ValueError(
            f'Invalid value {value!r} for `{name}`. '
            "Expected a list of features, None, 'auto' or 'default'."
        )

    def call(
        self,
        mol: str | chem.Mol | tuple,
        context: dict | None = None
    ) -> tensors.GraphTensor:
        """Featurize a single molecule.

        Args:
            mol: A SMILES/InChI string (parsed with `chem.Mol.from_encoding`),
                an RDKit molecule (cast to `chem.Mol`) or a `chem.Mol`.
            context: Optional extra values merged into the graph context
                (they override computed entries with the same key).

        Returns:
            A `tensors.GraphTensor` with node features, edges in both
            directions (`fill='full'`) and, if configured, edge features,
            a context feature, wildcard labels and a super node.
        """
        if isinstance(mol, str):
            mol = chem.Mol.from_encoding(
                mol, explicit_hs=self._include_hydrogens
            )
        elif isinstance(mol, chem.RDKitMol):
            mol = chem.Mol.cast(mol)

        data = {'context': {}, 'node': {}, 'edge': {}}

        data['context']['size'] = np.asarray(mol.num_atoms)

        # An empty feature list is treated like None: concatenating zero
        # arrays would raise.
        if self._molecule_features:
            data['context']['feature'] = np.concatenate(
                [f(mol) for f in self._molecule_features], axis=-1
            )

        if context:
            data['context'].update(context)

        data['node']['feature'] = np.concatenate(
            [f(mol) for f in self._atom_features], axis=-1
        )

        if self._wildcards:
            # 0 for regular atoms; '*' atoms get (label or 0) + 1.
            wildcard_labels = np.asarray([
                (atom.label or 0) + 1 if atom.symbol == "*" else 0
                for atom in mol.atoms
            ])
            data['node']['wildcard'] = wildcard_labels
            # Wildcard atoms get all-zero chemical features.
            data['node']['feature'] = np.where(
                wildcard_labels[:, None],
                np.zeros_like(data['node']['feature']),
                data['node']['feature']
            )

        data['edge']['source'], data['edge']['target'] = mol.adjacency(
            fill='full', sparse=True, self_loops=self._self_loops
        )

        if self._bond_features:
            bond_features = np.concatenate(
                [f(mol) for f in self._bond_features], axis=-1
            )
            if self._self_loops:
                # Append an all-zero row; self-loop edges index it with -1.
                bond_features = np.pad(bond_features, [(0, 1), (0, 0)])

            bond_indices = [
                mol.get_bond_between_atoms(i, j).index if (i != j) else -1
                for (i, j) in zip(data['edge']['source'], data['edge']['target'])
            ]

            data['edge']['feature'] = bond_features[bond_indices]

        if self._super_node:
            data = _add_super_node(data)

        return tensors.GraphTensor(**_convert_dtypes(data))

    def get_config(self):
        """Return the constructor arguments (features serialized with Keras)."""
        config = super().get_config()
        config.update({
            'atom_features': keras.saving.serialize_keras_object(
                self._atom_features
            ),
            'bond_features': keras.saving.serialize_keras_object(
                self._bond_features
            ),
            'molecule_features': keras.saving.serialize_keras_object(
                self._molecule_features
            ),
            'super_node': self._super_node,
            'self_loops': self._self_loops,
            'include_hydrogens': self._include_hydrogens,
            'wildcards': self._wildcards,
        })
        return config

    @classmethod
    def from_config(cls, config: dict):
        """Rebuild a featurizer from `get_config` output.

        The passed-in `config` is not modified.
        """
        config = dict(config)
        for key in ('atom_features', 'bond_features', 'molecule_features'):
            config[key] = keras.saving.deserialize_keras_object(config[key])
        return cls(**config)
|
|
303
|
+
|
|
304
|
+
|
|
305
|
+
@keras.saving.register_keras_serializable(package='molcraft')
class MolGraphFeaturizer3D(MolGraphFeaturizer):

    """3D Molecular graph featurizer.

    Converts SMILES or InChI strings to a 3d molecular graph.

    The molecular graph may encode a single molecule or a batch of molecules.

    Example:

    >>> import molcraft
    >>>
    >>> featurizer = molcraft.featurizers.MolGraphFeaturizer3D(
    ...     atom_features=[
    ...         molcraft.features.AtomType(),
    ...         molcraft.features.NumHydrogens(),
    ...         molcraft.features.Degree(),
    ...     ],
    ...     radius=5.0,
    ...     random_seed=42,
    ... )
    >>>
    >>> graph = featurizer(["N[C@@H](C)C(=O)O", "N[C@@H](CS)C(=O)O"])
    >>> graph
    GraphTensor(
        context={
            'size': <tf.Tensor: shape=[2], dtype=int32>
        },
        node={
            'feature': <tf.Tensor: shape=[13, 129], dtype=float32>,
            'coordinate': <tf.Tensor: shape=[13, 3], dtype=float32>
        },
        edge={
            'source': <tf.Tensor: shape=[72], dtype=int32>,
            'target': <tf.Tensor: shape=[72], dtype=int32>,
            'feature': <tf.Tensor: shape=[72, 12], dtype=float32>
        }
    )

    Args:
        atom_features:
            A list of `features.Feature` encoded as the node features.
        pair_features:
            A list of `features.PairFeature` encoded as the edge features,
            'auto'/'default' for `PairDistance`, or None / an empty list for
            no edge features.
        molecule_features:
            A list of `descriptors.Descriptor` encoded as the context feature.
        super_node:
            A boolean specifying whether to include a super node.
        self_loops:
            A boolean specifying whether self loops exist.
        include_hydrogens:
            A boolean specifying whether hydrogens should be encoded as nodes.
        wildcards:
            A boolean specifying whether wildcards exist. If True, wildcard labels will
            be encoded in the graph and separately embedded in `layers.NodeEmbedding`.
        radius:
            A floating point value specifying maximum edge length. A falsy
            value (None or 0) disables the cutoff.
        random_seed:
            An integer specifying the random seed for the conformer generation.
        **kwargs:
            `bond_features` is accepted and ignored (edges are spatial, not
            chemical bonds). Any other keyword triggers a warning and is
            ignored.
    """

    def __init__(
        self,
        atom_features: list[features.Feature] | str = 'auto',
        pair_features: list[features.PairFeature] | str | None = 'auto',
        molecule_features: list[descriptors.Descriptor] | str | None = None,
        super_node: bool = False,
        self_loops: bool = False,
        include_hydrogens: bool = False,
        wildcards: bool = False,
        radius: int | float | None = 6.0,
        random_seed: int | None = None,
        **kwargs,
    ) -> None:
        # `bond_features` appears in configs produced by the base class'
        # `get_config`; it has no meaning here and is dropped.
        kwargs.pop('bond_features', None)
        if kwargs:
            # Previously these were swallowed silently, hiding typos such
            # as `radious=...`.
            warnings.warn(
                f'{type(self).__name__} ignores unexpected keyword '
                f'arguments: {sorted(kwargs)}.',
                stacklevel=2,
            )
        super().__init__(
            atom_features=atom_features,
            bond_features=None,
            molecule_features=molecule_features,
            super_node=super_node,
            self_loops=self_loops,
            include_hydrogens=include_hydrogens,
            wildcards=wildcards,
        )

        if isinstance(pair_features, str):
            if pair_features not in ('auto', 'default'):
                raise ValueError(
                    f'Invalid value {pair_features!r} for `pair_features`. '
                    "Expected a list of features, None, 'auto' or 'default'."
                )
            pair_features = [features.PairDistance()]

        self._pair_features = pair_features
        self._radius = float(radius) if radius else None
        self._random_seed = random_seed

    def call(
        self,
        mol: str | chem.Mol | tuple,
        context: dict | None = None
    ) -> tensors.GraphTensor:
        """Featurize a single molecule as a 3D graph.

        Strings are parsed with explicit hydrogens; a conformer is embedded
        (seeded by `random_seed`) if the molecule has none; hydrogens are then
        removed unless `include_hydrogens` is set. Edges are the non-zero
        entries of the conformer's dense adjacency matrix for `radius`.

        Args:
            mol: A SMILES/InChI string, an RDKit molecule or a `chem.Mol`.
            context: Optional extra values merged into the graph context.

        Returns:
            A `tensors.GraphTensor` that additionally stores node
            coordinates under ``node['coordinate']``.
        """
        if isinstance(mol, str):
            mol = chem.Mol.from_encoding(
                mol, explicit_hs=True
            )
        elif isinstance(mol, chem.RDKitMol):
            mol = chem.Mol.cast(mol)

        if mol.num_conformers == 0:
            mol = chem.embed_conformers(
                mol, num_conformers=1, random_seed=self._random_seed
            )

        if not self._include_hydrogens:
            mol = chem.remove_hs(mol)

        data = {'context': {}, 'node': {}, 'edge': {}}

        data['context']['size'] = np.asarray(mol.num_atoms)

        # An empty feature list is treated like None: concatenating zero
        # arrays would raise.
        if self._molecule_features:
            data['context']['feature'] = np.concatenate(
                [f(mol) for f in self._molecule_features], axis=-1
            )

        if context:
            data['context'].update(context)

        conformer = mol.get_conformer()

        data['node']['feature'] = np.concatenate(
            [f(mol) for f in self._atom_features], axis=-1
        )
        data['node']['coordinate'] = conformer.coordinates

        if self._wildcards:
            # 0 for regular atoms; '*' atoms get (label or 0) + 1.
            wildcard_labels = np.asarray([
                (atom.label or 0) + 1 if atom.symbol == "*" else 0
                for atom in mol.atoms
            ])
            data['node']['wildcard'] = wildcard_labels
            data['node']['feature'] = np.where(
                wildcard_labels[:, None],
                np.zeros_like(data['node']['feature']),
                data['node']['feature']
            )

        adjacency_matrix = conformer.adjacency(
            fill='full', radius=self._radius, sparse=False, self_loops=self._self_loops,
        )

        data['edge']['source'], data['edge']['target'] = np.where(adjacency_matrix)

        if self._pair_features:
            pair_features = np.concatenate(
                [f(mol) for f in self._pair_features], axis=-1
            )
            # Assumes pair features are laid out row-major over all
            # (source, target) atom pairs, matching the ordering of
            # `np.where` above — TODO confirm against `features.PairFeature`.
            pair_keep = adjacency_matrix.reshape(-1).astype(bool)
            data['edge']['feature'] = pair_features[pair_keep]

        if self._super_node:
            data = _add_super_node(data)
            # The super node (appended last) sits at the conformer centroid.
            data['node']['coordinate'] = np.concatenate(
                [data['node']['coordinate'], conformer.centroid[None]], axis=0
            )

        return tensors.GraphTensor(**_convert_dtypes(data))

    @property
    def random_seed(self) -> int | None:
        """Seed used when embedding conformers (None: unseeded)."""
        return self._random_seed

    @random_seed.setter
    def random_seed(self, value: int) -> None:
        self._random_seed = value

    def get_config(self):
        """Return the constructor arguments, including the 3D-specific ones."""
        config = super().get_config()
        config['radius'] = self._radius
        config['pair_features'] = keras.saving.serialize_keras_object(
            self._pair_features
        )
        config['random_seed'] = self._random_seed
        return config

    @classmethod
    def from_config(cls, config: dict):
        """Rebuild a featurizer from `get_config` output.

        The passed-in `config` is not modified.
        """
        config = dict(config)
        config['pair_features'] = keras.saving.deserialize_keras_object(
            config['pair_features']
        )
        return super().from_config(config)
|
|
497
|
+
|
|
498
|
+
|
|
499
|
+
@keras.saving.register_keras_serializable(package='molcraft')
class PeptideGraphFeaturizer(MolGraphFeaturizer):

    """Peptide graph featurizer.

    Converts a monomer sequence (e.g. ``"ACDK"``) into a single molecule. It
    does this by condensing the monomers listed in `monomers` with an RDKit
    reaction, and then featurizes that molecule like `MolGraphFeaturizer`.

    Each node additionally receives a ``subgraph_indicator`` entry. It is read
    from the RDKit ``react_idx`` atom property of the reaction product. For
    the super node, if present, it is ``-1``.

    Registered as Keras-serializable (like the other featurizers) so that
    `save`/`load` can round-trip it.

    Args:
        atom_features, bond_features, molecule_features, super_node,
        self_loops, include_hydrogens, wildcards:
            As in `MolGraphFeaturizer`.
        monomers:
            Optional mapping from sequence symbols to monomer SMILES/InChI
            strings. Entries extend or override the 20 default amino acids.
    """

    def __init__(
        self,
        atom_features: list[features.Feature] | str = 'auto',
        bond_features: list[features.Feature] | str | None = 'auto',
        molecule_features: list[descriptors.Descriptor] | str | None = None,
        super_node: bool = False,
        self_loops: bool = False,
        include_hydrogens: bool = False,
        wildcards: bool = False,
        monomers: dict[str, str] | None = None
    ) -> None:
        super().__init__(
            atom_features=atom_features,
            bond_features=bond_features,
            molecule_features=molecule_features,
            super_node=super_node,
            self_loops=self_loops,
            include_hydrogens=include_hydrogens,
            wildcards=wildcards,
        )
        # User entries take precedence over the defaults; the caller's dict
        # is not modified.
        self.monomers = {**_default_monomers, **(monomers or {})}

    def call(self, mol, context=None):
        """Featurize a peptide sequence.

        Args:
            mol: A sequence string (see `mol_from_sequence`).
            context: Optional extra context values forwarded to
                `MolGraphFeaturizer.call`.

        Returns:
            The graph from `MolGraphFeaturizer.call`, updated with a node
            entry ``'subgraph_indicator'``.
        """
        mol = self.mol_from_sequence(mol)
        # NOTE(review): 'react_idx' is expected to be set by RDKit's
        # RunReactants to the index of the reactant (here: the residue
        # position) each atom came from — confirm for the RDKit version used.
        subgraph_indicator = [
            int(atom.GetProp('react_idx')) for atom in mol.atoms
        ]
        if self._super_node:
            # The super node is appended after all atoms and belongs to no
            # residue.
            subgraph_indicator.append(-1)
        graph = super().call(mol, context=context)
        return graph.update({'node': {'subgraph_indicator': subgraph_indicator}})

    def get_config(self) -> dict:
        """Return the constructor arguments, including the full monomer table."""
        config = super().get_config()
        config['monomers'] = dict(self.monomers)
        return config

    def mol_from_sequence(self, sequence: str) -> chem.Mol:
        """Build the polymer molecule for a monomer sequence.

        The sequence is tokenized with `_monomer_pattern`; characters that
        match no token are skipped. Consecutive monomers are joined by
        condensing the carboxyl group of one with the amine of the next.

        Args:
            sequence: The monomer sequence, e.g. ``"ACDK"``.

        Returns:
            The sanitized polymer (via `chem.sanitize_mol`).

        Raises:
            ValueError: If `sequence` contains no monomer symbols, or the
                reaction yields no product.
            KeyError: If a symbol has no entry in `self.monomers`.
        """
        symbols = [
            match.group(0) for match in re.finditer(_monomer_pattern, sequence)
        ]
        if not symbols:
            raise ValueError(
                f'Could not find any monomer symbols in sequence {sequence!r}.'
            )

        monomers = []
        for symbol in symbols:
            try:
                encoding = self.monomers[symbol]
            except KeyError:
                raise KeyError(
                    f'Unknown monomer {symbol!r} in sequence {sequence!r}. '
                    'Provide it via the `monomers` argument.'
                ) from None
            monomers.append(chem.Mol.from_encoding(encoding))

        # One reactant template per monomer. Every backbone atom gets its own
        # atom-map number so it is carried into the product. Numbering starts
        # at 1 because map number 0 conventionally denotes an unmapped atom.
        backbone_template = '[N:{0}][C:{1}][C:{2}](=[O:{3}])'
        reactant_templates = []
        product_templates = []
        last = len(monomers) - 1
        for i in range(len(monomers)):
            first_map = 4 * i + 1
            backbone = backbone_template.format(
                first_map, first_map + 1, first_map + 2, first_map + 3
            )
            if i == last:
                # The C-terminal residue keeps its hydroxyl oxygen.
                terminal_oxygen = f'[O:{first_map + 4}]'
                reactant_templates.append(backbone + terminal_oxygen)
                product_templates.append(backbone + terminal_oxygen)
            else:
                # The unmapped hydroxyl oxygen is dropped, and the carbonyl
                # carbon bonds to the next residue's nitrogen in the product.
                reactant_templates.append(backbone + '[O]')
                product_templates.append(backbone)

        reaction = rdChemReactions.ReactionFromSmarts(
            '.'.join(reactant_templates) + '>>' + ''.join(product_templates)
        )
        products = reaction.RunReactants(monomers)
        if not len(products):
            raise ValueError(f'Could not obtain polymer from monomers: {monomers}.')
        polymer = products[0][0]
        return chem.sanitize_mol(polymer)
|
|
566
|
+
|
|
567
|
+
|
|
568
|
+
def save_featurizer(
    featurizer: GraphFeaturizer,
    filepath: str | Path,
    overwrite: bool = True,
    **kwargs
) -> None:
    """Serialize `featurizer` to a JSON file.

    Args:
        featurizer: The featurizer to save; it is serialized with
            `keras.saving.serialize_keras_object`.
        filepath: Destination path, which must end in '.json'. A missing
            parent directory is created.
        overwrite: When False and `filepath` already exists, nothing is
            written.
        **kwargs: Accepted for API compatibility; unused.

    Raises:
        ValueError: If `filepath` does not have a '.json' suffix.
    """
    path = Path(filepath)
    if path.suffix != '.json':
        raise ValueError(
            'Invalid `filepath` extension for saving a `GraphFeaturizer`. '
            'A `GraphFeaturizer` should be saved as a JSON file.'
        )

    parent = path.parent
    if not parent.exists():
        parent.mkdir(parents=True, exist_ok=True)

    should_skip = path.exists() and not overwrite
    if should_skip:
        return

    payload = keras.saving.serialize_keras_object(featurizer)
    with open(path, 'w') as handle:
        json.dump(payload, handle, indent=4)
|
|
587
|
+
|
|
588
|
+
def load_featurizer(
|
|
589
|
+
filepath: str | Path,
|
|
590
|
+
**kwargs
|
|
591
|
+
) -> GraphFeaturizer:
|
|
592
|
+
filepath = Path(filepath)
|
|
593
|
+
if filepath.suffix != '.json':
|
|
594
|
+
raise ValueError(
|
|
595
|
+
'Invalid `filepath` extension for loading a `GraphFeaturizer`. '
|
|
596
|
+
'A `GraphFeaturizer` should be saved as a JSON file.'
|
|
597
|
+
)
|
|
598
|
+
if not filepath.exists():
|
|
599
|
+
return
|
|
600
|
+
with open(filepath, 'r') as f:
|
|
601
|
+
config = json.load(f)
|
|
602
|
+
return keras.saving.deserialize_keras_object(config)
|
|
603
|
+
|
|
604
|
+
def _add_super_node(
    data: dict[str, dict[str, np.ndarray]]
) -> dict[str, dict[str, np.ndarray]]:
    """Append a super node connected to every existing node.

    `data` is modified in place and returned. The super node gets the next
    free index and an all-zero feature row. It is flagged in
    ``node['super']`` and gets wildcard label 0 when wildcards are present.

    New edges are appended in this order:
      * a super-node self loop, but only if the graph already has self loops;
      * node -> super edges;
      * super -> node edges.
    New edges are flagged in ``edge['super']`` and get all-zero edge features
    when edge features exist.
    """
    node = data['node']
    edge = data['edge']

    data['context']['size'] += 1

    n_nodes = node['feature'].shape[0]
    n_original_edges = edge['source'].shape[0]
    hub = n_nodes

    regular = list(range(n_nodes))
    hubs = [hub] * n_nodes

    # Follow the graph's existing convention for self loops.
    has_self_loops = np.any(edge['source'] == edge['target'])
    loop = [hub] if has_self_loops else []

    edge['source'] = np.append(edge['source'], loop + regular + hubs)
    edge['target'] = np.append(edge['target'], loop + hubs + regular)

    node['feature'] = np.pad(node['feature'], [(0, 1), (0, 0)])
    node['super'] = np.asarray([False] * n_nodes + [True])
    if 'wildcard' in node:
        node['wildcard'] = np.pad(node['wildcard'], [(0, 1)])

    n_added = edge['source'].shape[0] - n_original_edges
    edge['super'] = np.asarray([False] * n_original_edges + [True] * n_added)

    if 'feature' in edge:
        edge['feature'] = np.pad(edge['feature'], [(0, n_added), (0, 0)])

    return data
|
|
652
|
+
|
|
653
|
+
def _convert_dtypes(
    data: dict[str, dict[str, typing.Any]]
) -> dict[str, dict[str, typing.Any]]:
    """Cast graph arrays to 32-bit dtypes, in place, and return `data`.

    Conversion rules:
      * Entries named 'source', 'target' or 'size', and any other integer
        arrays, become int32.
      * Floating arrays become float32.
      * Everything else (e.g. boolean masks, strings) is left unchanged.

    Values that are not NumPy arrays yet (Python scalars or lists) are passed
    through `np.asarray` before the dtype check, instead of failing on a
    missing `.dtype`.
    """
    index_keys = ('source', 'target', 'size')
    for fields in data.values():
        for key, value in fields.items():
            array = np.asarray(value)
            if key in index_keys or np.issubdtype(array.dtype, np.integer):
                fields[key] = array.astype(np.int32)
            elif np.issubdtype(array.dtype, np.floating):
                fields[key] = array.astype(np.float32)
    return data
|
|
663
|
+
|
|
664
|
+
def _unpack_inputs(inputs) -> tuple:
    """Split one raw input into ``(mol, context)``.

    Accepted forms:
      * A single value (e.g. a SMILES string, a molecule, or a 0-d array
        holding one): returns ``(value, {})``.
      * A list/tuple/1-d array ``(mol, label[, sample_weight])``: context has
        the keys 'label' and 'sample_weight'. Extra items are ignored.
      * An ``(index, row)`` pair yielded by `pd.DataFrame.iterrows`: the first
        column is the molecule, the remaining columns become context entries
        (with snake-cased names), and the row index is stored under 'index'.

    Returns:
        A ``(mol, context)`` tuple whose context values are NumPy arrays.
    """
    if isinstance(inputs, np.ndarray):
        value = inputs.tolist()
        # A 0-d array yields a bare value here (e.g. a string). Turning that
        # into a tuple would split a SMILES string into characters.
        inputs = tuple(value) if isinstance(value, list) else value
    elif isinstance(inputs, list):
        inputs = tuple(inputs)

    if not isinstance(inputs, tuple):
        return inputs, {}

    # `iterrows` yields (label, row) with `row.name` set to that very label
    # object. Checking this accepts non-integer indices (e.g. strings), which
    # were previously misread as (mol, label), without mistaking a
    # (mol, label-Series) pair for a DataFrame row.
    is_dataframe_row = (
        len(inputs) == 2
        and isinstance(inputs[1], pd.Series)
        and (
            isinstance(inputs[0], (int, np.integer))
            or inputs[1].name is inputs[0]
        )
    )

    if is_dataframe_row:
        index, series = inputs
        mol = series.values[0]
        context = dict(
            zip(
                map(_snake_case, series.index[1:]),
                map(np.asarray, series.values[1:])
            )
        )
        context['index'] = np.asarray(index)
        return mol, context

    mol, *rest = inputs
    context = dict(zip(['label', 'sample_weight'], map(np.asarray, rest)))
    return mol, context
|
|
685
|
+
|
|
686
|
+
def _snake_case(x: str) -> str:
    """Lower-case `x` and join its whitespace-separated words with '_'.

    For example, ``'Sample Weight'`` becomes ``'sample_weight'``.
    """
    return '_'.join(word.lower() for word in x.split())
|
|
688
|
+
|
|
689
|
+
def _call_kwargs(func) -> bool:
    """Return True if `func` should receive a ``context=`` keyword.

    This holds when `func` declares a parameter named ``context`` or accepts
    arbitrary keyword arguments (``**kwargs``).
    """
    for parameter in inspect.signature(func).parameters.values():
        if parameter.name == 'context':
            return True
        if parameter.kind == inspect.Parameter.VAR_KEYWORD:
            return True
    return False
|
|
695
|
+
|
|
696
|
+
# Tokenizer for monomer sequences. Alternatives are tried left to right, so
# the longest modified forms come first and a bare residue letter is the
# fallback. "Mod" is a bracketed alphanumeric modification tag.
_monomer_pattern = '|'.join((
    r'(\[[A-Za-z0-9]+\]-[A-Z]\[[A-Za-z0-9]+\])',  # [Mod]-A[Mod]
    r'([A-Z]\[[A-Za-z0-9]+\]-\[[A-Za-z0-9]+\])',  # A[Mod]-[Mod]
    r'(\[[A-Za-z0-9]+\]-[A-Z])',                  # [Mod]-A
    r'([A-Z]-\[[A-Za-z0-9]+\])',                  # A-[Mod]
    r'([A-Z]\[[A-Za-z0-9]+\])',                   # A[Mod]
    r'([A-Z])',                                   # A
    r'\(.*?\)',                                   # (A)
))
|
|
705
|
+
|
|
706
|
+
# The 20 standard amino acids, keyed by one-letter code, as SMILES of the
# free amino acid (backbone written N...C(=O)O).
_default_monomers = dict(
    A="N[C@@H](C)C(=O)O",
    C="N[C@@H](CS)C(=O)O",
    D="N[C@@H](CC(=O)O)C(=O)O",
    E="N[C@@H](CCC(=O)O)C(=O)O",
    F="N[C@@H](Cc1ccccc1)C(=O)O",
    G="NCC(=O)O",
    H="N[C@@H](CC1=CN=C-N1)C(=O)O",
    I="N[C@@H](C(CC)C)C(=O)O",
    K="N[C@@H](CCCCN)C(=O)O",
    L="N[C@@H](CC(C)C)C(=O)O",
    M="N[C@@H](CCSC)C(=O)O",
    N="N[C@@H](CC(=O)N)C(=O)O",
    P="N1[C@@H](CCC1)C(=O)O",
    Q="N[C@@H](CCC(=O)N)C(=O)O",
    R="N[C@@H](CCCNC(=N)N)C(=O)O",
    S="N[C@@H](CO)C(=O)O",
    T="N[C@@H](C(O)C)C(=O)O",
    V="N[C@@H](C(C)C)C(=O)O",
    W="N[C@@H](CC(=CN2)C1=C2C=CC=C1)C(=O)O",
    Y="N[C@@H](Cc1ccc(O)cc1)C(=O)O",
)
|