molcraft-0.1.0a1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of molcraft might be problematic. Click here for more details.
- molcraft/__init__.py +16 -0
- molcraft/callbacks.py +21 -0
- molcraft/chem.py +600 -0
- molcraft/conformers.py +155 -0
- molcraft/descriptors.py +90 -0
- molcraft/experimental/__init__.py +1 -0
- molcraft/experimental/peptides.py +303 -0
- molcraft/features.py +387 -0
- molcraft/featurizers.py +693 -0
- molcraft/layers.py +1224 -0
- molcraft/models.py +441 -0
- molcraft/ops.py +129 -0
- molcraft/records.py +169 -0
- molcraft/tensors.py +527 -0
- molcraft-0.1.0a1.dist-info/METADATA +58 -0
- molcraft-0.1.0a1.dist-info/RECORD +19 -0
- molcraft-0.1.0a1.dist-info/WHEEL +5 -0
- molcraft-0.1.0a1.dist-info/licenses/LICENSE +21 -0
- molcraft-0.1.0a1.dist-info/top_level.txt +1 -0
molcraft/featurizers.py
ADDED
|
@@ -0,0 +1,693 @@
|
|
|
1
|
+
import keras
|
|
2
|
+
import json
|
|
3
|
+
import abc
|
|
4
|
+
import typing
|
|
5
|
+
import copy
|
|
6
|
+
import warnings
|
|
7
|
+
import numpy as np
|
|
8
|
+
import pandas as pd
|
|
9
|
+
import tensorflow as tf
|
|
10
|
+
import multiprocessing as mp
|
|
11
|
+
|
|
12
|
+
from pathlib import Path
|
|
13
|
+
|
|
14
|
+
from molcraft import tensors
|
|
15
|
+
from molcraft import features
|
|
16
|
+
from molcraft import chem
|
|
17
|
+
from molcraft import conformers
|
|
18
|
+
from molcraft import descriptors
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
@keras.saving.register_keras_serializable(package='molcraft')
class Featurizer(abc.ABC):

    """Base class for featurizers.

    Subclasses implement `call`, which featurizes a single input, and
    `stack`, which combines the per-input outputs of `call` into a single
    graph. Calling a featurizer instance dispatches a single input straight
    to `call`, and a collection of inputs through `call` followed by `stack`
    (optionally using a process pool).
    """

    @abc.abstractmethod
    def call(
        self,
        x: str | tuple
    ) -> tensors.GraphTensor | list[tensors.GraphTensor]:
        """Featurize a single input.

        Implementations may return None for inputs that cannot be
        featurized; such outputs are dropped by `__call__` before stacking.
        """
        pass

    @abc.abstractmethod
    def stack(
        self,
        call_outputs: list[tensors.GraphTensor]
    ) -> tensors.GraphTensor:
        """Combine the (non-None) outputs of `call` into a single graph."""
        pass

    def get_config(self) -> dict:
        """Return the constructor arguments needed to rebuild this object."""
        return {}

    @classmethod
    def from_config(cls, config: dict) -> 'Featurizer':
        """Rebuild a featurizer from the output of `get_config`."""
        return cls(**config)

    def save(self, filepath: str | Path, *args, **kwargs) -> None:
        """Save this featurizer as JSON (see `save_featurizer`)."""
        save_featurizer(
            self, filepath, *args, **kwargs
        )

    @staticmethod
    def load(filepath: str | Path, *args, **kwargs) -> 'Featurizer':
        """Load a featurizer from JSON (see `load_featurizer`)."""
        return load_featurizer(
            filepath, *args, **kwargs
        )

    def __call__(
        self,
        inputs: str | tuple | list | np.ndarray | pd.DataFrame | pd.Series,
        *,
        multiprocessing: bool = False,
        processes: int | None = None,
        device: str = '/cpu:0',
        **kwargs
    ) -> tensors.GraphTensor:
        """Featurize one input or a collection of inputs.

        Args:
            inputs: A single input (a string, or a tuple whose first item is
                the string) which is passed directly to `call`; or a
                collection (list, array, DataFrame, Series) whose elements
                are each passed to `call` and then combined with `stack`.
            multiprocessing: If True, map `call` over the inputs with a
                `multiprocessing.Pool`.
            processes: Number of worker processes for the pool (None lets
                `multiprocessing` decide).
            device: TensorFlow device scope entered while the pool runs.
            **kwargs: Accepted but ignored.

        Returns:
            The output of `call` for a single input; otherwise the result of
            `stack` over all non-None outputs.

        Raises:
            ValueError: If a collection was given but none of its elements
                produced an output (e.g. it was empty or no element could be
                parsed), so there is nothing to stack.
        """
        if isinstance(inputs, (str, tuple)):
            return self.call(inputs)
        if isinstance(inputs, (pd.DataFrame, pd.Series)):
            inputs = inputs.values
        if isinstance(inputs, np.ndarray):
            inputs = list(inputs)
        if not multiprocessing:
            outputs = [self.call(x) for x in inputs]
        else:
            with tf.device(device):
                with mp.Pool(processes) as pool:
                    outputs = pool.map(func=self.call, iterable=inputs)
        # Inputs that could not be featurized come back as None; drop them.
        outputs = [x for x in outputs if x is not None]
        if not outputs:
            # `stack` indexes outputs[0]; fail with a clear message instead
            # of an opaque IndexError.
            raise ValueError(
                'None of the inputs could be featurized, so there is '
                'nothing to stack. Check that the inputs are non-empty and '
                'contain valid molecule encodings.'
            )
        return self.stack(outputs)
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
@keras.saving.register_keras_serializable(package='molcraft')
class MolGraphFeaturizer(Featurizer):

    """Molecular graph featurizer.

    Converts SMILES or InChI strings to a molecular graph.

    The molecular graph may encode a single molecule or a batch of molecules.
    In either case, it is a single graph, with each molecule corresponding to
    a subgraph within the graph.

    Example:

    >>> import molcraft
    >>>
    >>> featurizer = molcraft.featurizers.MolGraphFeaturizer(
    ...     atom_features=[
    ...         molcraft.features.AtomType(),
    ...         molcraft.features.TotalNumHs(),
    ...         molcraft.features.Degree(),
    ...     ],
    ...     radius=1
    ... )
    >>>
    >>> graph = featurizer(["N[C@@H](C)C(=O)O", "N[C@@H](CS)C(=O)O"])
    >>> graph
    GraphTensor(
        context={
            'size': <tf.Tensor: shape=[2], dtype=int32>
        },
        node={
            'feature': <tf.Tensor: shape=[13, 133], dtype=float32>
        },
        edge={
            'source': <tf.Tensor: shape=[22], dtype=int32>,
            'target': <tf.Tensor: shape=[22], dtype=int32>,
            'feature': <tf.Tensor: shape=[22, 5], dtype=float32>
        }
    )

    Args:
        atom_features:
            A list of `features.Feature` encoding the nodes of the molecular
            graph. 'auto' (or 'default') selects `AtomType`, `TotalNumHs`
            (only when hydrogens are not nodes) and `Degree`.
        bond_features:
            A list of `features.Feature` encoding the edges of the molecular
            graph, or None for no edge features. 'auto' (or 'default')
            selects `BondType`. When `radius > 1` or `self_loops=True`, a
            `BondType` encoding with a 'zero' (non-bond) category is always
            used, because edges that are not bonds are encoded as that
            category; a warning is issued if other features were requested.
        molecule_features:
            A `features.Feature`, or a list of them, encoding the molecule
            (or `context`) of the graph; the result is stored as
            `context['feature']`. 'auto' (or 'default') selects a set of
            molecular descriptors. The keyword `mol_features` is accepted
            as an alias.
        super_atom:
            A boolean specifying whether a super atom should be added to each
            molecule: an extra node with zero features, connected to every
            atom in both directions. Super nodes (and, when edge features
            exist, super edges) are flagged by a boolean `'super'` entry.
        radius:
            An integer specifying how many bond lengths should be considered as an
            edge. The default is None (or 1), which specifies that only bonds should
            be considered an edge. For `radius > 1` the edge also carries a
            one-hot `'length'` entry encoding the path length.
        self_loops:
            A boolean specifying whether self loops exist. If True, this means that
            each node (atom) has an edge (bond) to itself.
        include_hs:
            A boolean specifying whether hydrogens should be encoded as nodes.
    """

    def __init__(
        self,
        atom_features: list[features.Feature] | str | None = 'auto',
        bond_features: list[features.Feature] | str | None = 'auto',
        molecule_features: features.Feature | list[features.Feature] | str | None = None,
        super_atom: bool = False,
        radius: int | float | None = None,
        self_loops: bool = False,
        include_hs: bool = False,
        **kwargs,
    ) -> None:
        if molecule_features is None:
            # `mol_features` is accepted as an alternative keyword.
            molecule_features = kwargs.pop('mol_features', None)

        def is_default(spec) -> bool:
            # 'auto' and 'default' both request the built-in feature set.
            return isinstance(spec, str) and spec in ('auto', 'default')

        # Radius in bonds: None, 0 or values below 1 become 1; floats are
        # truncated.
        self.radius = int(max(radius or 1, 1))
        self.include_hs = include_hs
        self.self_loops = self_loops
        self.super_atom = super_atom

        if is_default(atom_features):
            atom_features = [features.AtomType()]
            if not self.include_hs:
                # Hydrogens are not nodes, so count them as an atom feature.
                atom_features.append(features.TotalNumHs())
            atom_features.append(features.Degree())

        if not isinstance(self, MolGraphFeaturizer3D):
            default_bond_features = is_default(bond_features)
            if default_bond_features or self.radius > 1 or self.self_loops:
                # `call` appends a hard-coded 5-dim 'zero' bond row for
                # non-bonded edges, so the bond encoding must match it.
                already_bond_type = (
                    isinstance(bond_features, (list, tuple))
                    and len(bond_features) == 1
                    and isinstance(bond_features[0], features.BondType)
                )
                if not (default_bond_features or already_bond_type):
                    warn(
                        'With `radius > 1` or `self_loops=True`, bond '
                        'features are fixed to `BondType` (non-bonded edges '
                        'are encoded as the zero bond type); the supplied '
                        f'`bond_features` ({bond_features!r}) are ignored.'
                    )
                vocab = ['zero', 'single', 'double', 'triple', 'aromatic']
                bond_features = [
                    features.BondType(vocab)
                ]

        if is_default(molecule_features):
            molecule_features = [
                descriptors.MolWeight(),
                descriptors.MolTPSA(),
                descriptors.MolLogP(),
                descriptors.NumHeavyAtoms(),
                descriptors.NumHydrogenDonors(),
                descriptors.NumHydrogenAcceptors(),
                descriptors.NumRotatableBonds(),
                descriptors.NumRings(),
            ]
        elif (
            molecule_features is not None
            and not isinstance(molecule_features, (list, tuple))
        ):
            # A single feature is accepted; `context_feature` iterates, so
            # wrap it in a list.
            molecule_features = [molecule_features]

        self._atom_features = atom_features
        self._bond_features = bond_features
        self._molecule_features = molecule_features
        self.feature_dtype = 'float32'
        self.index_dtype = 'int32'

    def call(self, x: str | typing.Tuple) -> tensors.GraphTensor:
        """Featurize a single molecule.

        Args:
            x: A SMILES/InChI string, or a sequence whose first item is the
                string and whose remaining items (if any) are stored as
                `context['label']` and `context['weight']`, in that order.

        Returns:
            A `tensors.GraphTensor` for the molecule, or None (after a
            warning) if no `chem.Mol` could be obtained from `x`.
        """
        if isinstance(x, (tuple, list, np.ndarray)):
            x, *args = x
        else:
            args = []

        mol = chem.Mol.from_encoding(x, explicit_hs=self.include_hs)

        if mol is None:
            warn(
                f'Could not obtain `chem.Mol` from {x}. '
                'Proceeding without it.'
            )
            return None

        atom_feature = self.atom_features(mol)
        bond_feature = self.bond_features(mol)
        context_feature = self.context_feature(mol)
        molecule_size = self.num_atoms(mol)

        context, node, edge = {}, {}, {}
        for field, value in zip(['size', 'label', 'weight'], [molecule_size] + args):
            context[field] = value

        if context_feature is not None:
            context['feature'] = context_feature

        node['feature'] = atom_feature

        if bond_feature is not None and (self.radius > 1 or self.self_loops):
            # Append 'zero order' bond feature encoding, which encodes non-bonds.
            # It becomes the last row, so index -1 selects it below.
            zero_bond_feature = np.array(
                [[1., 0., 0., 0., 0.]], dtype=bond_feature.dtype
            )
            bond_feature = np.concatenate(
                [bond_feature, zero_bond_feature], axis=0
            )

        if self.radius == 1:
            edge['source'], edge['target'] = mol.adjacency(
                fill='full', sparse=True, self_loops=self.self_loops, dtype=self.index_dtype
            )
            if bond_feature is not None:
                bond_indices = []
                for (atom_i, atom_j) in zip(edge['source'], edge['target']):
                    if atom_i == atom_j:
                        # Self loop: use the appended zero-bond row.
                        bond_indices.append(-1)
                    else:
                        bond_indices.append(
                            mol.get_bond_between_atoms(atom_i, atom_j).index
                        )
                edge['feature'] = bond_feature[bond_indices]
        else:
            paths = chem.get_shortest_paths(
                mol, radius=self.radius, self_loops=self.self_loops
            )
            edge['source'] = np.asarray(
                [path[0] for path in paths], dtype=self.index_dtype
            )
            edge['target'] = np.asarray(
                [path[-1] for path in paths], dtype=self.index_dtype
            )
            edge['length'] = np.asarray(
                [len(path) - 1 for path in paths], dtype=self.index_dtype
            )
            if bond_feature is not None:
                edge['feature'] = self._expand_bond_features(
                    mol, paths, bond_feature,
                )
            # One-hot path length, in the feature dtype so the output graph
            # does not carry float64 (the super-atom path casts it as well).
            edge['length'] = np.eye(
                self.radius + 1, dtype=self.feature_dtype
            )[edge['length']]

        if self.super_atom:
            node, edge = self._add_super_atom(node, edge)
            context['size'] += 1

        return tensors.GraphTensor(context, node, edge)

    def stack(self, outputs):
        """Stack per-molecule graphs (scalar) or concatenate batched graphs."""
        if tensors.is_scalar(outputs[0]):
            return tf.stack(outputs, axis=0)
        return tf.concat(outputs, axis=0)

    def atom_features(self, mol: chem.Mol) -> np.ndarray:
        """Concatenate all atom features along the last axis."""
        atom_feature: np.ndarray = np.concatenate(
            [f(mol) for f in self._atom_features], axis=-1
        )
        return atom_feature.astype(self.feature_dtype)

    def bond_features(self, mol: chem.Mol) -> np.ndarray | None:
        """Concatenate all bond features along the last axis (None if unset)."""
        if self._bond_features is None:
            return None
        bond_feature: np.ndarray = np.concatenate(
            [f(mol) for f in self._bond_features], axis=-1
        )
        return bond_feature.astype(self.feature_dtype)

    def context_feature(self, mol: chem.Mol) -> np.ndarray | None:
        """Concatenate all molecule features along the last axis (None if unset)."""
        if self._molecule_features is None:
            return None
        context_feature: np.ndarray = np.concatenate(
            [f(mol) for f in self._molecule_features], axis=-1
        )
        return context_feature.astype(self.feature_dtype)

    def num_atoms(self, mol: chem.Mol) -> np.ndarray:
        """Number of atoms in `mol` as an index-dtype array."""
        return np.asarray(mol.num_atoms, dtype=self.index_dtype)

    def num_bonds(self, mol: chem.Mol) -> np.ndarray:
        """Number of bonds in `mol` as an index-dtype array."""
        return np.asarray(mol.num_bonds, dtype=self.index_dtype)

    def _expand_bond_features(
        self,
        mol: chem.Mol,
        paths: list[list[int]],
        bond_feature: np.ndarray,
    ) -> np.ndarray:
        """Encode each path as the concatenated features of its bonds.

        Every path is padded to `self.radius` bonds with index -1, i.e. the
        zero-bond row appended in `call`, giving rows of width
        `bond_feature.shape[-1] * self.radius`.
        """

        def bond_feature_lookup(path):
            path_bond_indices = [
                mol.get_bond_between_atoms(path[i], path[i + 1]).index
                for i in range(len(path) - 1)
            ]
            padding = [-1] * (self.radius - len(path) + 1)
            path_bond_indices += padding
            return bond_feature[path_bond_indices].reshape(-1)

        edge_feature = np.asarray(
            [
                bond_feature_lookup(path) for path in paths
            ],
            dtype=self.feature_dtype
        ).reshape((-1, bond_feature.shape[-1] * self.radius))

        return edge_feature

    def _add_super_atom(
        self,
        node: dict[str, np.ndarray],
        edge: dict[str, np.ndarray],
    ) -> tuple[dict[str, np.ndarray], dict[str, np.ndarray]]:
        """Append one super node and connect it to every existing node."""
        num_super_nodes = 1
        num_nodes = node['feature'].shape[0]
        node = _add_super_nodes(
            node, num_super_nodes, self.feature_dtype
        )
        edge = _add_super_edges(
            edge, num_nodes, num_super_nodes, self.feature_dtype, self.index_dtype
        )
        return node, edge

    def get_config(self):
        config = super().get_config()
        config.update({
            'atom_features': keras.saving.serialize_keras_object(
                self._atom_features
            ),
            'bond_features': keras.saving.serialize_keras_object(
                self._bond_features
            ),
            'molecule_features': keras.saving.serialize_keras_object(
                self._molecule_features
            ),
            'super_atom': self.super_atom,
            'radius': self.radius,
            'self_loops': self.self_loops,
            'include_hs': self.include_hs,
        })
        return config

    @classmethod
    def from_config(cls, config: dict):
        # Work on a copy so the caller's config dict is not mutated.
        config = dict(config)
        config['atom_features'] = keras.saving.deserialize_keras_object(
            config['atom_features']
        )
        config['bond_features'] = keras.saving.deserialize_keras_object(
            config['bond_features']
        )
        config['molecule_features'] = keras.saving.deserialize_keras_object(
            config['molecule_features']
        )
        return cls(**config)
|
|
385
|
+
|
|
386
|
+
|
|
387
|
+
@keras.saving.register_keras_serializable(package='molcraft')
class MolGraphFeaturizer3D(MolGraphFeaturizer):

    """Molecular 3d-graph featurizer.

    Converts SMILES or InChI strings to a molecular graph in 3d space.
    Namely, in addition to the information encoded in a standard molecular
    graph, cartesian coordinates are also included.

    The molecular graph may encode a single molecule or a batch of molecules.
    In either case, it is a single graph, with each molecule corresponding to
    a subgraph within the graph. Each conformer of a molecule becomes its own
    subgraph.

    Example:

    >>> import molcraft
    >>>
    >>> featurizer = molcraft.featurizers.MolGraphFeaturizer3D(
    ...     atom_features=[
    ...         molcraft.features.AtomType(),
    ...         molcraft.features.TotalNumHs(),
    ...         molcraft.features.Degree(),
    ...     ],
    ...     radius=5.0
    ... )
    >>>
    >>> graph = featurizer(["N[C@@H](C)C(=O)O", "N[C@@H](CS)C(=O)O"])
    >>> graph
    GraphTensor(
        context={
            'size': <tf.Tensor: shape=[20], dtype=int32>
        },
        node={
            'feature': <tf.Tensor: shape=[130, 133], dtype=float32>,
            'coordinate': <tf.Tensor: shape=[130, 3], dtype=float32>
        },
        edge={
            'source': <tf.Tensor: shape=[714], dtype=int32>,
            'target': <tf.Tensor: shape=[714], dtype=int32>,
            'feature': <tf.Tensor: shape=[714, 23], dtype=float32>
        }
    )

    Args:
        atom_features:
            A list of `features.Feature` encoding the nodes of the molecular graph.
        bond_features:
            A list of `features.Feature` encoding the edges of the molecular
            graph. 'auto' (or 'default') selects `features.Distance()`. These
            features are computed per conformer and masked by the flattened
            adjacency matrix, so they are presumably expected to yield one
            row per (atom, atom) pair.
        molecule_features:
            A `features.Feature` (or list of them) encoding the molecule (or
            `context`) of the graph, stored as `context['feature']`.
        conformer_generator:
            A `conformers.ConformerGenerator` which produces conformers. If
            'auto' (or 'default'), a `ConformerGenerator` running a single
            `conformers.ConformerEmbedder` (ETKDGv3, 10 conformers) is used.
            If None, it is assumed that the molecule already has conformer(s).
        super_atom:
            A boolean specifying whether a super atom should be added to each
            conformer graph: an extra node connected to every atom in both
            directions, placed at the conformer centroid.
        radius:
            A float specifying, for each atom, the maximum distance (in angstroms)
            that another atom should be within to be considered an edge. Default
            is set to 6.0 as this should cover most interactions. This parameter
            can be thought of as the receptive field. If None, the radius is
            infinite, so the receptive field is the entire space (graph).
        self_loops:
            A boolean specifying whether self loops exist. If True, this means that
            each node (atom) has an edge (bond) to itself.
        include_hs:
            A boolean specifying whether hydrogens should be encoded as nodes.
    """

    def __init__(
        self,
        atom_features: list[features.Feature] | str | None = 'auto',
        bond_features: list[features.Feature] | str | None = 'auto',
        molecule_features: features.Feature | str = None,
        conformer_generator: conformers.ConformerProcessor | str | None = 'auto',
        super_atom: bool = False,
        radius: int | float | None = 6.0,
        self_loops: bool = False,
        include_hs: bool = False,
        **kwargs,
    ) -> None:
        # The parent constructor skips bond-feature defaults for 3D graphs,
        # so resolve both spellings of the default here.
        if isinstance(bond_features, str) and bond_features in ('auto', 'default'):
            bond_features = [
                features.Distance()
            ]
        super().__init__(
            atom_features=atom_features,
            bond_features=bond_features,
            molecule_features=molecule_features,
            super_atom=super_atom,
            radius=radius,
            self_loops=self_loops,
            include_hs=include_hs,
            **kwargs,
        )
        if (
            isinstance(conformer_generator, str)
            and conformer_generator in ('auto', 'default')
        ):
            conformer_generator = conformers.ConformerGenerator(
                steps=[
                    conformers.ConformerEmbedder(
                        method='ETKDGv3',
                        num_conformers=10
                    ),
                ]
            )
        self.conformer_generator = conformer_generator
        self.embed_conformer = self.conformer_generator is not None
        # Euclidean cutoff in angstroms (None or 0 -> no cutoff). Overrides
        # the integer bond radius set by the parent constructor.
        self.radius = float(radius) if radius else None

    def call(self, x: str | typing.Tuple) -> list[tensors.GraphTensor]:
        """Featurize one molecule into one graph per conformer.

        Args:
            x: A SMILES/InChI string, or a sequence whose first item is the
                string and whose remaining items (if any) are stored as
                `context['label']` and `context['weight']`, in that order.

        Returns:
            A list with one `tensors.GraphTensor` per conformer, or None
            (after a warning) if no `chem.Mol` could be obtained from `x`.

        Raises:
            ValueError: If the molecule has no conformers after the optional
                conformer generation step.
        """
        if isinstance(x, (tuple, list, np.ndarray)):
            x, *args = x
        else:
            args = []

        # Conformer generation uses explicit hydrogens; they are removed
        # again afterwards unless `include_hs` is set.
        explicit_hs = (self.include_hs or self.embed_conformer)
        mol = chem.Mol.from_encoding(x, explicit_hs=explicit_hs)

        if mol is None:
            warn(
                f'Could not obtain `chem.Mol` from {x}. '
                'Proceeding without it.'
            )
            return None

        if self.embed_conformer:
            mol = self.conformer_generator(mol)
            if not self.include_hs:
                mol = chem.remove_hs(mol)

        if mol.num_conformers == 0:
            raise ValueError(
                'Cannot featurize a molecule without conformer(s). '
                'Make sure to pass a `ConformerGenerator` to the constructor '
                'of the `Featurizer` or input a 3D representation of the molecule. '
            )

        context, node = {}, {}

        context['size'] = self.num_atoms(mol) + int(self.super_atom)
        for field, value in zip(['label', 'weight'], args):
            context[field] = value

        # Atom features are computed once and shared by all conformers.
        node['feature'] = self.atom_features(mol)

        context_feature = self.context_feature(mol)
        if context_feature is not None:
            context['feature'] = context_feature

        tensor_list = []
        for conformer_mol in chem._split_mol_by_confs(mol):
            node_conformer = copy.deepcopy(node)
            edge_conformer = {}
            conformer = conformer_mol.get_conformer()
            adjacency_matrix = conformer.adjacency(
                fill='full',
                radius=self.radius,
                sparse=False,
                self_loops=self.self_loops,
                # `np.bool` does not exist on NumPy 1.24-1.26; `np.bool_`
                # is the boolean scalar type on every NumPy version.
                dtype=np.bool_
            )
            source, target = np.where(adjacency_matrix)
            edge_conformer['source'] = source.astype(self.index_dtype)
            edge_conformer['target'] = target.astype(self.index_dtype)
            node_conformer['coordinate'] = conformer.coordinates.astype(self.feature_dtype)

            if self._bond_features:
                # Computed per conformer so geometry-dependent features
                # (e.g. distances) match this conformer's coordinates
                # rather than those of a single shared conformer.
                edge_feature = self.bond_features(conformer_mol)
                edge_feature_keep = adjacency_matrix.reshape(-1)
                edge_conformer['feature'] = edge_feature[edge_feature_keep]

            if self.super_atom:
                node_conformer, edge_conformer = self._add_super_atom(
                    node_conformer, edge_conformer
                )
                # The super node sits at the conformer centroid; cast it so
                # the coordinates stay in the feature dtype.
                centroid = np.asarray(
                    conformer.centroid, dtype=self.feature_dtype
                )[None]
                node_conformer['coordinate'] = np.concatenate(
                    [node_conformer['coordinate'], centroid], axis=0
                )

            tensor_list.append(
                tensors.GraphTensor(context, node_conformer, edge_conformer)
            )

        return tensor_list

    def stack(self, outputs):
        """Flatten the per-molecule conformer lists, then stack them."""
        # Flatten list of lists (due to multiple conformers per molecule)
        outputs = [x for xs in outputs for x in xs]
        return super().stack(outputs)

    def get_config(self):
        config = super().get_config()
        config['conformer_generator'] = keras.saving.serialize_keras_object(
            self.conformer_generator
        )
        return config

    @classmethod
    def from_config(cls, config: dict):
        # Work on a copy so the caller's config dict is not mutated.
        config = dict(config)
        config['conformer_generator'] = keras.saving.deserialize_keras_object(
            config['conformer_generator']
        )
        # `from_config` takes the config dict positionally; unpacking it
        # into keyword arguments raised a TypeError.
        return super().from_config(config)
|
|
591
|
+
|
|
592
|
+
|
|
593
|
+
def save_featurizer(
    featurizer: Featurizer,
    filepath: str | Path,
    overwrite: bool = True,
    **kwargs
) -> None:
    """Serialize `featurizer` to a JSON file.

    Args:
        featurizer: The featurizer to save; it is converted with
            `keras.saving.serialize_keras_object`.
        filepath: Destination path. It must have a `.json` suffix; missing
            parent directories are created.
        overwrite: When False and `filepath` already exists, return without
            writing anything.
        **kwargs: Accepted but unused.

    Raises:
        ValueError: If `filepath` does not have a `.json` suffix.
    """
    path = Path(filepath)
    if path.suffix != '.json':
        raise ValueError(
            'Invalid `filepath` extension for saving a `Featurizer`. '
            'A `Featurizer` should be saved as a JSON file.'
        )

    parent_dir = path.parent
    if not parent_dir.exists():
        parent_dir.mkdir(parents=True, exist_ok=True)

    # Keep an existing file untouched unless overwriting was requested.
    if not overwrite and path.exists():
        return

    payload = keras.saving.serialize_keras_object(featurizer)
    with open(path, 'w') as handle:
        json.dump(payload, handle, indent=4)
|
|
612
|
+
|
|
613
|
+
def load_featurizer(
    filepath: str | Path,
    **kwargs
) -> Featurizer | None:
    """Load a featurizer previously written by `save_featurizer`.

    Args:
        filepath: Path to the JSON file; it must have a `.json` suffix.
        **kwargs: Accepted but unused.

    Returns:
        The deserialized featurizer, or None (with a `UserWarning`) when
        `filepath` does not exist.

    Raises:
        ValueError: If `filepath` does not have a `.json` suffix.
    """
    filepath = Path(filepath)
    if filepath.suffix != '.json':
        raise ValueError(
            'Invalid `filepath` extension for loading a `Featurizer`. '
            'A `Featurizer` should be saved as a JSON file.'
        )
    if not filepath.exists():
        # Returning None is kept for backward compatibility, but make the
        # missing file visible instead of failing silently.
        warnings.warn(
            f'No `Featurizer` found at {str(filepath)!r}; returning None.',
            UserWarning,
            stacklevel=2,
        )
        return None
    with open(filepath, 'r', encoding='utf-8') as f:
        config = json.load(f)
    return keras.saving.deserialize_keras_object(config)
|
|
628
|
+
|
|
629
|
+
def _add_super_nodes(
|
|
630
|
+
node: dict[str, np.ndarray],
|
|
631
|
+
num_super_nodes: int = 1,
|
|
632
|
+
feature_dtype: str = 'float32',
|
|
633
|
+
) -> dict[str, np.ndarray]:
|
|
634
|
+
node = copy.deepcopy(node)
|
|
635
|
+
node['super'] = np.array([False] * len(node['feature']) + [True] * num_super_nodes)
|
|
636
|
+
super_node_feature = np.zeros(
|
|
637
|
+
[num_super_nodes, node['feature'].shape[-1]], dtype=feature_dtype
|
|
638
|
+
)
|
|
639
|
+
node['feature'] = np.concatenate([node['feature'], super_node_feature])
|
|
640
|
+
return node
|
|
641
|
+
|
|
642
|
+
def _add_super_edges(
|
|
643
|
+
edge: dict[str, np.ndarray],
|
|
644
|
+
num_nodes: int,
|
|
645
|
+
num_super_nodes: int,
|
|
646
|
+
feature_dtype: str,
|
|
647
|
+
index_dtype: str,
|
|
648
|
+
) -> dict[str, np.ndarray]:
|
|
649
|
+
edge = copy.deepcopy(edge)
|
|
650
|
+
super_node_indices = (
|
|
651
|
+
np.repeat(np.arange(num_super_nodes), [num_nodes]) + num_nodes
|
|
652
|
+
)
|
|
653
|
+
node_indices = (
|
|
654
|
+
np.tile(np.arange(num_nodes), [num_super_nodes])
|
|
655
|
+
)
|
|
656
|
+
edge['source'] = np.concatenate(
|
|
657
|
+
[
|
|
658
|
+
edge['source'],
|
|
659
|
+
node_indices,
|
|
660
|
+
super_node_indices,
|
|
661
|
+
]
|
|
662
|
+
)
|
|
663
|
+
edge['source'] = edge['source'].astype(index_dtype)
|
|
664
|
+
edge['target'] = np.concatenate(
|
|
665
|
+
[
|
|
666
|
+
edge['target'],
|
|
667
|
+
super_node_indices,
|
|
668
|
+
node_indices
|
|
669
|
+
]
|
|
670
|
+
)
|
|
671
|
+
edge['target'] = edge['target'].astype(index_dtype)
|
|
672
|
+
if 'feature' in edge:
|
|
673
|
+
edge['super'] = np.asarray([False] * edge['feature'].shape[0] + [True] * (num_super_nodes * num_nodes * 2))
|
|
674
|
+
edge['feature'] = np.concatenate([edge['feature'], np.zeros((num_super_nodes * num_nodes * 2, edge['feature'].shape[-1]))])
|
|
675
|
+
if 'length' in edge:
|
|
676
|
+
edge['length'] = np.pad(edge['length'], [(0, 0), (1, 0)])
|
|
677
|
+
zero_array = np.zeros((num_nodes * num_super_nodes * 2,), dtype='int32')
|
|
678
|
+
edge_length_dim = edge['length'].shape[1]
|
|
679
|
+
virtual_edge_length = np.eye(edge_length_dim)[zero_array]
|
|
680
|
+
edge['length'] = np.concatenate([edge['length'], virtual_edge_length])
|
|
681
|
+
edge['length'] = edge['length'].astype(feature_dtype)
|
|
682
|
+
return edge
|
|
683
|
+
|
|
684
|
+
|
|
685
|
+
def warn(message: str) -> None:
    """Emit `message` as a `UserWarning`.

    `stacklevel=2` attributes the warning to the code that called `warn`
    (e.g. the featurizer method that hit the problem) instead of to this
    helper, so the reported file/line is actually useful.
    """
    warnings.warn(
        message=message,
        category=UserWarning,
        stacklevel=2
    )
|
|
691
|
+
|
|
692
|
+
# Shorter aliases for the graph featurizers.
MolFeaturizer, MolFeaturizer3D = MolGraphFeaturizer, MolGraphFeaturizer3D
|