molcraft 0.1.0a16__py3-none-any.whl → 0.1.0a18__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of molcraft might be problematic. Click here for more details.
- molcraft/__init__.py +4 -3
- molcraft/applications/chromatography.py +0 -0
- molcraft/applications/proteomics.py +141 -106
- molcraft/chem.py +17 -22
- molcraft/datasets.py +6 -6
- molcraft/descriptors.py +14 -0
- molcraft/features.py +50 -58
- molcraft/featurizers.py +257 -487
- molcraft/layers.py +95 -40
- molcraft/models.py +2 -0
- molcraft/records.py +24 -15
- {molcraft-0.1.0a16.dist-info → molcraft-0.1.0a18.dist-info}/METADATA +13 -12
- molcraft-0.1.0a18.dist-info/RECORD +21 -0
- molcraft/conformers.py +0 -151
- molcraft-0.1.0a16.dist-info/RECORD +0 -21
- {molcraft-0.1.0a16.dist-info → molcraft-0.1.0a18.dist-info}/WHEEL +0 -0
- {molcraft-0.1.0a16.dist-info → molcraft-0.1.0a18.dist-info}/licenses/LICENSE +0 -0
- {molcraft-0.1.0a16.dist-info → molcraft-0.1.0a18.dist-info}/top_level.txt +0 -0
molcraft/featurizers.py
CHANGED
|
@@ -14,63 +14,47 @@ from pathlib import Path
|
|
|
14
14
|
from molcraft import tensors
|
|
15
15
|
from molcraft import features
|
|
16
16
|
from molcraft import chem
|
|
17
|
-
from molcraft import conformers
|
|
18
17
|
from molcraft import descriptors
|
|
19
18
|
|
|
20
19
|
|
|
21
20
|
@keras.saving.register_keras_serializable(package='molcraft')
|
|
22
|
-
class
|
|
21
|
+
class GraphFeaturizer(abc.ABC):
|
|
23
22
|
|
|
24
|
-
"""Base
|
|
23
|
+
"""Base graph featurizer.
|
|
25
24
|
"""
|
|
26
25
|
|
|
27
26
|
@abc.abstractmethod
|
|
28
|
-
def call(
|
|
29
|
-
self,
|
|
30
|
-
x: tensors.GraphTensor
|
|
31
|
-
) -> tensors.GraphTensor | list[tensors.GraphTensor]:
|
|
27
|
+
def call(self, x: str | chem.Mol | tuple) -> tensors.GraphTensor:
|
|
32
28
|
pass
|
|
33
29
|
|
|
34
|
-
@abc.abstractmethod
|
|
35
|
-
def stack(
|
|
36
|
-
self,
|
|
37
|
-
call_outputs: list[tensors.GraphTensor]
|
|
38
|
-
) -> tensors.GraphTensor:
|
|
39
|
-
pass
|
|
40
|
-
|
|
41
30
|
def get_config(self) -> dict:
|
|
42
31
|
return {}
|
|
43
32
|
|
|
44
33
|
@classmethod
|
|
45
|
-
def from_config(cls, config: dict) -> '
|
|
34
|
+
def from_config(cls, config: dict) -> 'GraphFeaturizer':
|
|
46
35
|
return cls(**config)
|
|
47
36
|
|
|
48
37
|
def save(self, filepath: str | Path, *args, **kwargs) -> None:
|
|
49
|
-
save_featurizer(
|
|
50
|
-
self, filepath, *args, **kwargs
|
|
51
|
-
)
|
|
38
|
+
save_featurizer(self, filepath, *args, **kwargs)
|
|
52
39
|
|
|
53
40
|
@staticmethod
|
|
54
|
-
def load(filepath: str | Path, *args, **kwargs) -> '
|
|
55
|
-
return load_featurizer(
|
|
56
|
-
filepath, *args, **kwargs
|
|
57
|
-
)
|
|
41
|
+
def load(filepath: str | Path, *args, **kwargs) -> 'GraphFeaturizer':
|
|
42
|
+
return load_featurizer(filepath, *args, **kwargs)
|
|
58
43
|
|
|
59
44
|
def __call__(
|
|
60
45
|
self,
|
|
61
|
-
inputs: str |
|
|
46
|
+
inputs: str | chem.Mol | tuple | typing.Iterable,
|
|
62
47
|
*,
|
|
63
48
|
multiprocessing: bool = False,
|
|
64
49
|
processes: int | None = None,
|
|
65
50
|
device: str = '/cpu:0',
|
|
66
|
-
**kwargs
|
|
67
51
|
) -> tensors.GraphTensor:
|
|
68
52
|
if isinstance(inputs, (str, tuple)):
|
|
69
53
|
return self.call(inputs)
|
|
70
54
|
if isinstance(inputs, (pd.DataFrame, pd.Series)):
|
|
71
|
-
inputs = inputs.values
|
|
72
|
-
|
|
73
|
-
inputs =
|
|
55
|
+
inputs = inputs.values.tolist()
|
|
56
|
+
elif isinstance(inputs, np.ndarray):
|
|
57
|
+
inputs = inputs.tolist()
|
|
74
58
|
if not multiprocessing:
|
|
75
59
|
outputs = [self.call(x) for x in inputs]
|
|
76
60
|
else:
|
|
@@ -78,19 +62,19 @@ class Featurizer(abc.ABC):
|
|
|
78
62
|
with mp.Pool(processes) as pool:
|
|
79
63
|
outputs = pool.map(func=self.call, iterable=inputs)
|
|
80
64
|
outputs = [x for x in outputs if x is not None]
|
|
81
|
-
|
|
65
|
+
if tensors.is_scalar(outputs[0]):
|
|
66
|
+
return tf.stack(outputs, axis=0)
|
|
67
|
+
return tf.concat(outputs, axis=0)
|
|
82
68
|
|
|
83
69
|
|
|
84
70
|
@keras.saving.register_keras_serializable(package='molcraft')
|
|
85
|
-
class MolGraphFeaturizer(
|
|
71
|
+
class MolGraphFeaturizer(GraphFeaturizer):
|
|
86
72
|
|
|
87
73
|
"""Molecular graph featurizer.
|
|
88
74
|
|
|
89
75
|
Converts SMILES or InChI strings to a molecular graph.
|
|
90
76
|
|
|
91
77
|
The molecular graph may encode a single molecule or a batch of molecules.
|
|
92
|
-
In either case, it is a single graph, with each molecule corresponding to
|
|
93
|
-
a subgraph within the graph.
|
|
94
78
|
|
|
95
79
|
Example:
|
|
96
80
|
|
|
@@ -99,10 +83,14 @@ class MolGraphFeaturizer(Featurizer):
|
|
|
99
83
|
>>> featurizer = molcraft.featurizers.MolGraphFeaturizer(
|
|
100
84
|
... atom_features=[
|
|
101
85
|
... molcraft.features.AtomType(),
|
|
102
|
-
... molcraft.features.
|
|
86
|
+
... molcraft.features.NumHydrogens(),
|
|
103
87
|
... molcraft.features.Degree(),
|
|
104
88
|
... ],
|
|
105
|
-
...
|
|
89
|
+
... bond_features=[
|
|
90
|
+
... molcraft.features.BondType(),
|
|
91
|
+
... ],
|
|
92
|
+
... super_node=False,
|
|
93
|
+
... self_loops=False,
|
|
106
94
|
... )
|
|
107
95
|
>>>
|
|
108
96
|
>>> graph = featurizer(["N[C@@H](C)C(=O)O", "N[C@@H](CS)C(=O)O"])
|
|
@@ -112,7 +100,7 @@ class MolGraphFeaturizer(Featurizer):
|
|
|
112
100
|
'size': <tf.Tensor: shape=[2], dtype=int32>
|
|
113
101
|
},
|
|
114
102
|
node={
|
|
115
|
-
'feature': <tf.Tensor: shape=[13,
|
|
103
|
+
'feature': <tf.Tensor: shape=[13, 129], dtype=float32>
|
|
116
104
|
},
|
|
117
105
|
edge={
|
|
118
106
|
'source': <tf.Tensor: shape=[22], dtype=int32>,
|
|
@@ -123,73 +111,46 @@ class MolGraphFeaturizer(Featurizer):
|
|
|
123
111
|
|
|
124
112
|
Args:
|
|
125
113
|
atom_features:
|
|
126
|
-
A list of `features.Feature`
|
|
114
|
+
A list of `features.Feature` encoded as the node features.
|
|
127
115
|
bond_features:
|
|
128
|
-
A list of `features.Feature`
|
|
116
|
+
A list of `features.Feature` encoded as the edge features.
|
|
129
117
|
molecule_features:
|
|
130
|
-
A `
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
super_atom:
|
|
134
|
-
A boolean specifying whether super atoms exist and should be embedded
|
|
135
|
-
via `NodeEmbedding`.
|
|
136
|
-
radius:
|
|
137
|
-
An integer specifying how many bond lengths should be considered as an
|
|
138
|
-
edge. The default is None (or 1), which specifies that only bonds should
|
|
139
|
-
be considered an edge.
|
|
118
|
+
A list of `descriptors.Descriptor` encoded as the context feature.
|
|
119
|
+
super_node:
|
|
120
|
+
A boolean specifying whether to include a super node.
|
|
140
121
|
self_loops:
|
|
141
|
-
A boolean specifying whether self loops exist.
|
|
142
|
-
|
|
143
|
-
include_hs:
|
|
122
|
+
A boolean specifying whether self loops exist.
|
|
123
|
+
include_hydrogens:
|
|
144
124
|
A boolean specifying whether hydrogens should be encoded as nodes.
|
|
145
125
|
"""
|
|
146
126
|
|
|
147
127
|
def __init__(
|
|
148
128
|
self,
|
|
149
|
-
atom_features: list[features.Feature] | str
|
|
129
|
+
atom_features: list[features.Feature] | str = 'auto',
|
|
150
130
|
bond_features: list[features.Feature] | str | None = 'auto',
|
|
151
|
-
molecule_features:
|
|
152
|
-
|
|
153
|
-
radius: int | float | None = None,
|
|
131
|
+
molecule_features: list[descriptors.Descriptor] | str | None = None,
|
|
132
|
+
super_node: bool = False,
|
|
154
133
|
self_loops: bool = False,
|
|
155
|
-
|
|
156
|
-
**kwargs,
|
|
134
|
+
include_hydrogens: bool = False,
|
|
157
135
|
) -> None:
|
|
158
|
-
|
|
159
|
-
molecule_features = kwargs.pop('mol_features', None)
|
|
160
|
-
|
|
161
|
-
self.radius = int(max(radius or 1, 1))
|
|
162
|
-
self.include_hs = include_hs
|
|
163
|
-
self.self_loops = self_loops
|
|
164
|
-
self.super_atom = super_atom
|
|
165
|
-
|
|
166
|
-
default_atom_features = (
|
|
136
|
+
use_default_atom_features = (
|
|
167
137
|
atom_features == 'auto' or atom_features == 'default'
|
|
168
138
|
)
|
|
169
|
-
if
|
|
170
|
-
atom_features = [features.AtomType()]
|
|
171
|
-
if not
|
|
172
|
-
atom_features
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
features.BondType(vocab)
|
|
182
|
-
]
|
|
183
|
-
if not default_bond_features and self.radius > 1:
|
|
184
|
-
warnings.warn(
|
|
185
|
-
'Replacing user-specified bond features with default bond features, '
|
|
186
|
-
'as `radius`>1. When `radius`>1, only bond types are considered.',
|
|
187
|
-
stacklevel=2
|
|
188
|
-
)
|
|
189
|
-
default_molecule_features = (
|
|
139
|
+
if use_default_atom_features:
|
|
140
|
+
atom_features = [features.AtomType(), features.Degree()]
|
|
141
|
+
if not include_hydrogens:
|
|
142
|
+
atom_features += [features.NumHydrogens()]
|
|
143
|
+
|
|
144
|
+
use_default_bond_features = (
|
|
145
|
+
bond_features == 'auto' or bond_features == 'default'
|
|
146
|
+
)
|
|
147
|
+
if use_default_bond_features:
|
|
148
|
+
bond_features = [features.BondType()]
|
|
149
|
+
|
|
150
|
+
use_default_molecule_features = (
|
|
190
151
|
molecule_features == 'auto' or molecule_features == 'default'
|
|
191
152
|
)
|
|
192
|
-
if
|
|
153
|
+
if use_default_molecule_features:
|
|
193
154
|
molecule_features = [
|
|
194
155
|
descriptors.MolWeight(),
|
|
195
156
|
descriptors.TotalPolarSurfaceArea(),
|
|
@@ -202,182 +163,65 @@ class MolGraphFeaturizer(Featurizer):
|
|
|
202
163
|
descriptors.NumRotatableBonds(),
|
|
203
164
|
descriptors.NumRings(),
|
|
204
165
|
]
|
|
166
|
+
|
|
205
167
|
self._atom_features = atom_features
|
|
206
168
|
self._bond_features = bond_features
|
|
207
169
|
self._molecule_features = molecule_features
|
|
208
|
-
self.
|
|
209
|
-
self.
|
|
170
|
+
self._include_hydrogens = include_hydrogens
|
|
171
|
+
self._self_loops = self_loops
|
|
172
|
+
self._super_node = super_node
|
|
210
173
|
|
|
211
174
|
def call(self, inputs: str | tuple) -> tensors.GraphTensor:
|
|
212
|
-
|
|
213
|
-
|
|
214
|
-
|
|
215
|
-
context = copy.deepcopy(context[0])
|
|
216
|
-
else:
|
|
217
|
-
x, context = inputs, None
|
|
175
|
+
|
|
176
|
+
if isinstance(inputs, str):
|
|
177
|
+
inputs = (inputs,)
|
|
218
178
|
|
|
219
|
-
|
|
179
|
+
inputs, *context_inputs = inputs
|
|
220
180
|
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
'Returning `None` (proceeding without it).',
|
|
225
|
-
stacklevel=2
|
|
226
|
-
)
|
|
227
|
-
return None
|
|
228
|
-
|
|
229
|
-
atom_feature = self.atom_features(mol)
|
|
230
|
-
bond_feature = self.bond_features(mol)
|
|
231
|
-
molecule_feature = self.molecule_feature(mol)
|
|
232
|
-
molecule_size = self.num_atoms(mol)
|
|
233
|
-
|
|
234
|
-
if isinstance(context, dict):
|
|
235
|
-
if 'x' in context:
|
|
236
|
-
context['feature'] = context.pop('x')
|
|
237
|
-
if 'y' in context:
|
|
238
|
-
context['label'] = context.pop('y')
|
|
239
|
-
if 'sample_weight' in context:
|
|
240
|
-
context['weight'] = context.pop('sample_weight')
|
|
241
|
-
context = {
|
|
242
|
-
**{'size': molecule_size},
|
|
243
|
-
**context
|
|
244
|
-
}
|
|
245
|
-
elif isinstance(context, list):
|
|
246
|
-
context = {
|
|
247
|
-
**{'size': molecule_size},
|
|
248
|
-
**{key: value for (key, value) in zip(['label', 'weight'], context)}
|
|
249
|
-
}
|
|
250
|
-
else:
|
|
251
|
-
context = {'size': molecule_size}
|
|
252
|
-
|
|
253
|
-
if molecule_feature is not None:
|
|
254
|
-
if 'feature' in context:
|
|
255
|
-
warnings.warn(
|
|
256
|
-
'Found both inputted and computed context feature. '
|
|
257
|
-
'Overwriting inputted context feature with computed '
|
|
258
|
-
'context feature (based on `molecule_features`).',
|
|
259
|
-
stacklevel=2
|
|
260
|
-
)
|
|
261
|
-
context['feature'] = molecule_feature
|
|
262
|
-
|
|
263
|
-
node = {}
|
|
264
|
-
node['feature'] = atom_feature
|
|
181
|
+
mol = chem.Mol.from_encoding(inputs, explicit_hs=self._include_hydrogens)
|
|
182
|
+
|
|
183
|
+
data = {'context': {}, 'node': {}, 'edge': {}}
|
|
265
184
|
|
|
266
|
-
|
|
267
|
-
|
|
268
|
-
|
|
269
|
-
|
|
270
|
-
|
|
271
|
-
|
|
272
|
-
|
|
273
|
-
|
|
274
|
-
|
|
275
|
-
|
|
276
|
-
|
|
277
|
-
bond_indices.append(-1)
|
|
278
|
-
else:
|
|
279
|
-
bond_indices.append(
|
|
280
|
-
mol.get_bond_between_atoms(atom_i, atom_j).index
|
|
281
|
-
)
|
|
282
|
-
edge['feature'] = bond_feature[bond_indices]
|
|
283
|
-
else:
|
|
284
|
-
paths = chem.get_shortest_paths(
|
|
285
|
-
mol, radius=self.radius, self_loops=self.self_loops
|
|
286
|
-
)
|
|
287
|
-
edge['source'] = np.asarray(
|
|
288
|
-
[path[0] for path in paths], dtype=self.index_dtype
|
|
289
|
-
)
|
|
290
|
-
edge['target'] = np.asarray(
|
|
291
|
-
[path[-1] for path in paths], dtype=self.index_dtype
|
|
185
|
+
data['context']['size'] = np.asarray(mol.num_atoms)
|
|
186
|
+
|
|
187
|
+
if len(context_inputs) == 1:
|
|
188
|
+
data['context']['label'] = np.asarray(context_inputs[0])
|
|
189
|
+
elif len(context_inputs) == 2:
|
|
190
|
+
data['context']['label'] = np.asarray(context_inputs[0])
|
|
191
|
+
data['context']['weight'] = np.asarray(context_inputs[1])
|
|
192
|
+
|
|
193
|
+
if self._molecule_features is not None:
|
|
194
|
+
data['context']['feature'] = np.concatenate(
|
|
195
|
+
[f(mol) for f in self._molecule_features], axis=-1
|
|
292
196
|
)
|
|
293
|
-
|
|
294
|
-
|
|
295
|
-
[[1., 0., 0., 0., 0.]], dtype=bond_feature.dtype
|
|
296
|
-
)
|
|
297
|
-
bond_feature = np.concatenate(
|
|
298
|
-
[bond_feature, zero_bond_feature], axis=0
|
|
299
|
-
)
|
|
300
|
-
edge['feature'] = self._expand_bond_features(
|
|
301
|
-
mol, paths, bond_feature,
|
|
302
|
-
)
|
|
303
|
-
|
|
304
|
-
if self.super_atom:
|
|
305
|
-
node, edge = self._add_super_atom(node, edge)
|
|
306
|
-
context['size'] += 1
|
|
307
|
-
|
|
308
|
-
return tensors.GraphTensor(context, node, edge)
|
|
309
|
-
|
|
310
|
-
def stack(self, outputs):
|
|
311
|
-
if tensors.is_scalar(outputs[0]):
|
|
312
|
-
return tf.stack(outputs, axis=0)
|
|
313
|
-
return tf.concat(outputs, axis=0)
|
|
314
|
-
|
|
315
|
-
def atom_features(self, mol: chem.Mol) -> np.ndarray:
|
|
316
|
-
atom_feature: np.ndarray = np.concatenate(
|
|
197
|
+
|
|
198
|
+
data['node']['feature'] = np.concatenate(
|
|
317
199
|
[f(mol) for f in self._atom_features], axis=-1
|
|
318
200
|
)
|
|
319
|
-
return atom_feature.astype(self.feature_dtype)
|
|
320
201
|
|
|
321
|
-
|
|
322
|
-
|
|
323
|
-
return None
|
|
324
|
-
bond_feature: np.ndarray = np.concatenate(
|
|
325
|
-
[f(mol) for f in self._bond_features], axis=-1
|
|
326
|
-
)
|
|
327
|
-
return bond_feature.astype(self.feature_dtype)
|
|
328
|
-
|
|
329
|
-
def molecule_feature(self, mol: chem.Mol) -> np.ndarray:
|
|
330
|
-
if self._molecule_features is None:
|
|
331
|
-
return None
|
|
332
|
-
molecule_feature: np.ndarray = np.concatenate(
|
|
333
|
-
[f(mol) for f in self._molecule_features], axis=-1
|
|
202
|
+
data['edge']['source'], data['edge']['target'] = mol.adjacency(
|
|
203
|
+
fill='full', sparse=True, self_loops=self._self_loops
|
|
334
204
|
)
|
|
335
|
-
|
|
205
|
+
|
|
206
|
+
if self._bond_features is not None:
|
|
207
|
+
bond_features = np.concatenate(
|
|
208
|
+
[f(mol) for f in self._bond_features], axis=-1
|
|
209
|
+
)
|
|
210
|
+
if self._self_loops:
|
|
211
|
+
bond_features = np.pad(bond_features, [(0, 1), (0, 0)])
|
|
336
212
|
|
|
337
|
-
|
|
338
|
-
|
|
339
|
-
|
|
340
|
-
def num_bonds(self, mol: chem.Mol) -> np.ndarray:
|
|
341
|
-
return np.asarray(mol.num_bonds, dtype=self.index_dtype)
|
|
342
|
-
|
|
343
|
-
def _expand_bond_features(
|
|
344
|
-
self,
|
|
345
|
-
mol: chem.Mol,
|
|
346
|
-
paths: list[list[int]],
|
|
347
|
-
bond_feature: np.ndarray,
|
|
348
|
-
) -> np.ndarray:
|
|
349
|
-
|
|
350
|
-
def bond_feature_lookup(path):
|
|
351
|
-
path_bond_indices = [
|
|
352
|
-
mol.get_bond_between_atoms(path[i], path[i + 1]).index
|
|
353
|
-
for i in range(len(path) - 1)
|
|
213
|
+
bond_indices = [
|
|
214
|
+
mol.get_bond_between_atoms(i, j).index if (i != j) else -1
|
|
215
|
+
for (i, j) in zip(data['edge']['source'], data['edge']['target'])
|
|
354
216
|
]
|
|
355
|
-
padding = [-1] * (self.radius - len(path) + 1)
|
|
356
|
-
path_bond_indices += padding
|
|
357
|
-
return bond_feature[path_bond_indices].reshape(-1)
|
|
358
|
-
|
|
359
|
-
edge_feature = np.asarray(
|
|
360
|
-
[
|
|
361
|
-
bond_feature_lookup(path) for path in paths
|
|
362
|
-
],
|
|
363
|
-
dtype=self.feature_dtype
|
|
364
|
-
).reshape((-1, bond_feature.shape[-1] * self.radius))
|
|
365
|
-
|
|
366
|
-
return edge_feature
|
|
367
|
-
|
|
368
|
-
def _add_super_atom(
|
|
369
|
-
self,
|
|
370
|
-
node: dict[str, np.ndarray],
|
|
371
|
-
edge: dict[str, np.ndarray],
|
|
372
|
-
) -> tuple[dict[str, np.ndarray]]:
|
|
373
|
-
num_super_nodes = 1
|
|
374
|
-
num_nodes = node['feature'].shape[0]
|
|
375
|
-
node = _add_super_nodes(node, num_super_nodes)
|
|
376
|
-
edge = _add_super_edges(
|
|
377
|
-
edge, num_nodes, num_super_nodes, self.feature_dtype, self.index_dtype, self.self_loops
|
|
378
|
-
)
|
|
379
|
-
return node, edge
|
|
380
217
|
|
|
218
|
+
data['edge']['feature'] = bond_features[bond_indices]
|
|
219
|
+
|
|
220
|
+
if self._super_node:
|
|
221
|
+
data = _add_super_node(data)
|
|
222
|
+
|
|
223
|
+
return tensors.GraphTensor(**_convert_dtypes(data))
|
|
224
|
+
|
|
381
225
|
def get_config(self):
|
|
382
226
|
config = super().get_config()
|
|
383
227
|
config.update({
|
|
@@ -390,10 +234,9 @@ class MolGraphFeaturizer(Featurizer):
|
|
|
390
234
|
'molecule_features': keras.saving.serialize_keras_object(
|
|
391
235
|
self._molecule_features
|
|
392
236
|
),
|
|
393
|
-
'
|
|
394
|
-
'
|
|
395
|
-
'
|
|
396
|
-
'include_hs': self.include_hs,
|
|
237
|
+
'super_node': self._super_node,
|
|
238
|
+
'self_loops': self._self_loops,
|
|
239
|
+
'include_hydrogens': self._include_hydrogens,
|
|
397
240
|
})
|
|
398
241
|
return config
|
|
399
242
|
|
|
@@ -414,15 +257,11 @@ class MolGraphFeaturizer(Featurizer):
|
|
|
414
257
|
@keras.saving.register_keras_serializable(package='molcraft')
|
|
415
258
|
class MolGraphFeaturizer3D(MolGraphFeaturizer):
|
|
416
259
|
|
|
417
|
-
"""Molecular
|
|
260
|
+
"""3D Molecular graph featurizer.
|
|
418
261
|
|
|
419
|
-
Converts SMILES or InChI strings to a molecular graph
|
|
420
|
-
Namely, in addition to the information encoded in a standard molecular
|
|
421
|
-
graph, cartesian coordinates are also included.
|
|
262
|
+
Converts SMILES or InChI strings to a 3d molecular graph.
|
|
422
263
|
|
|
423
264
|
The molecular graph may encode a single molecule or a batch of molecules.
|
|
424
|
-
In either case, it is a single graph, with each molecule corresponding to
|
|
425
|
-
a subgraph within the graph.
|
|
426
265
|
|
|
427
266
|
Example:
|
|
428
267
|
|
|
@@ -431,226 +270,166 @@ class MolGraphFeaturizer3D(MolGraphFeaturizer):
|
|
|
431
270
|
>>> featurizer = molcraft.featurizers.MolGraphFeaturizer3D(
|
|
432
271
|
... atom_features=[
|
|
433
272
|
... molcraft.features.AtomType(),
|
|
434
|
-
... molcraft.features.
|
|
273
|
+
... molcraft.features.NumHydrogens(),
|
|
435
274
|
... molcraft.features.Degree(),
|
|
436
275
|
... ],
|
|
437
|
-
... radius=5.0
|
|
276
|
+
... radius=5.0,
|
|
277
|
+
... random_seed=42,
|
|
438
278
|
... )
|
|
439
279
|
>>>
|
|
440
280
|
>>> graph = featurizer(["N[C@@H](C)C(=O)O", "N[C@@H](CS)C(=O)O"])
|
|
441
281
|
>>> graph
|
|
442
282
|
GraphTensor(
|
|
443
283
|
context={
|
|
444
|
-
'size': <tf.Tensor: shape=[
|
|
284
|
+
'size': <tf.Tensor: shape=[2], dtype=int32>
|
|
445
285
|
},
|
|
446
286
|
node={
|
|
447
|
-
'feature': <tf.Tensor: shape=[
|
|
448
|
-
'coordinate': <tf.Tensor: shape=[
|
|
287
|
+
'feature': <tf.Tensor: shape=[13, 129], dtype=float32>,
|
|
288
|
+
'coordinate': <tf.Tensor: shape=[13, 3], dtype=float32>
|
|
449
289
|
},
|
|
450
290
|
edge={
|
|
451
|
-
'source': <tf.Tensor: shape=[
|
|
452
|
-
'target': <tf.Tensor: shape=[
|
|
453
|
-
'feature': <tf.Tensor: shape=[
|
|
291
|
+
'source': <tf.Tensor: shape=[72], dtype=int32>,
|
|
292
|
+
'target': <tf.Tensor: shape=[72], dtype=int32>,
|
|
293
|
+
'feature': <tf.Tensor: shape=[72, 12], dtype=float32>
|
|
454
294
|
}
|
|
455
295
|
)
|
|
456
296
|
|
|
457
297
|
Args:
|
|
458
298
|
atom_features:
|
|
459
|
-
A list of `features.Feature`
|
|
460
|
-
|
|
461
|
-
A list of `features.
|
|
299
|
+
A list of `features.Feature` encoded as the node features.
|
|
300
|
+
pair_features:
|
|
301
|
+
A list of `features.PairFeature` encoded as the edge features.
|
|
462
302
|
molecule_features:
|
|
463
|
-
A `
|
|
464
|
-
|
|
465
|
-
|
|
466
|
-
conformer_generator:
|
|
467
|
-
A `conformers.ConformerGenerator` which produces conformers. If `auto`
|
|
468
|
-
a `conformers.ConformerEmbedder` will be used. If None, it is assumed
|
|
469
|
-
that the molecule already has conformer(s).
|
|
470
|
-
super_atom:
|
|
471
|
-
A boolean specifying whether super atoms exist and should be embedded
|
|
472
|
-
via `NodeEmbedding`.
|
|
473
|
-
radius:
|
|
474
|
-
A float specifying, for each atom, the maximum distance (in angstroms)
|
|
475
|
-
that another atom should be within to be considered an edge. Default
|
|
476
|
-
is set to 6.0 as this should cover most interactions. This parameter
|
|
477
|
-
can be though of as the receptive field. If None, the radius will be
|
|
478
|
-
infinite so all the receptive field will be the entire space (graph).
|
|
303
|
+
A list of `descriptors.Descriptor` encoded as the context feature.
|
|
304
|
+
super_node:
|
|
305
|
+
A boolean specifying whether to include a super node.
|
|
479
306
|
self_loops:
|
|
480
|
-
A boolean specifying whether self loops exist.
|
|
481
|
-
|
|
482
|
-
include_hs:
|
|
307
|
+
A boolean specifying whether self loops exist.
|
|
308
|
+
include_hydrogens:
|
|
483
309
|
A boolean specifying whether hydrogens should be encoded as nodes.
|
|
310
|
+
radius:
|
|
311
|
+
A floating point value specifying maximum edge length.
|
|
312
|
+
random_seed:
|
|
313
|
+
An integer specifying the random seed for the conformer generation.
|
|
484
314
|
"""
|
|
485
315
|
|
|
486
316
|
def __init__(
|
|
487
317
|
self,
|
|
488
|
-
atom_features: list[features.Feature] | str
|
|
489
|
-
|
|
490
|
-
molecule_features: features.Feature | str = None,
|
|
491
|
-
|
|
492
|
-
super_atom: bool = False,
|
|
493
|
-
radius: int | float | None = 6.0,
|
|
318
|
+
atom_features: list[features.Feature] | str = 'auto',
|
|
319
|
+
pair_features: list[features.PairFeature] | str = 'auto',
|
|
320
|
+
molecule_features: features.Feature | str | None = None,
|
|
321
|
+
super_node: bool = False,
|
|
494
322
|
self_loops: bool = False,
|
|
495
|
-
|
|
496
|
-
|
|
323
|
+
include_hydrogens: bool = False,
|
|
324
|
+
radius: int | float | None = 6.0,
|
|
325
|
+
random_seed: int | None = None,
|
|
497
326
|
) -> None:
|
|
498
|
-
if bond_features == 'auto':
|
|
499
|
-
bond_features = [
|
|
500
|
-
features.Distance()
|
|
501
|
-
]
|
|
502
327
|
super().__init__(
|
|
503
328
|
atom_features=atom_features,
|
|
504
|
-
bond_features=
|
|
329
|
+
bond_features=None,
|
|
505
330
|
molecule_features=molecule_features,
|
|
506
|
-
|
|
507
|
-
radius=radius,
|
|
331
|
+
super_node=super_node,
|
|
508
332
|
self_loops=self_loops,
|
|
509
|
-
|
|
510
|
-
**kwargs,
|
|
333
|
+
include_hydrogens=include_hydrogens,
|
|
511
334
|
)
|
|
512
|
-
|
|
513
|
-
|
|
514
|
-
|
|
515
|
-
|
|
516
|
-
|
|
517
|
-
|
|
518
|
-
|
|
519
|
-
|
|
520
|
-
|
|
521
|
-
self.
|
|
522
|
-
|
|
523
|
-
self.radius = float(radius) if radius else None
|
|
524
|
-
|
|
335
|
+
|
|
336
|
+
use_default_pair_features = (
|
|
337
|
+
pair_features == 'auto' or pair_features == 'default'
|
|
338
|
+
)
|
|
339
|
+
if use_default_pair_features:
|
|
340
|
+
pair_features = [features.PairDistance()]
|
|
341
|
+
|
|
342
|
+
self._pair_features = pair_features
|
|
343
|
+
self._radius = float(radius) if radius else None
|
|
344
|
+
self._random_seed = random_seed
|
|
345
|
+
|
|
525
346
|
def call(self, inputs: str | tuple) -> tensors.GraphTensor:
|
|
526
347
|
|
|
527
|
-
if isinstance(inputs,
|
|
528
|
-
|
|
529
|
-
if len(context) and isinstance(context[0], dict):
|
|
530
|
-
context = copy.deepcopy(context[0])
|
|
531
|
-
else:
|
|
532
|
-
x, context = inputs, None
|
|
348
|
+
if isinstance(inputs, str):
|
|
349
|
+
inputs = [inputs]
|
|
533
350
|
|
|
534
|
-
|
|
535
|
-
mol = chem.Mol.from_encoding(x, explicit_hs=explicit_hs)
|
|
536
|
-
|
|
537
|
-
if mol is None:
|
|
538
|
-
warnings.warn(
|
|
539
|
-
f'Could not obtain `chem.Mol` from {x}. '
|
|
540
|
-
'Proceeding without it.',
|
|
541
|
-
stacklevel=2
|
|
542
|
-
)
|
|
543
|
-
return None
|
|
351
|
+
inputs, *context_inputs = inputs
|
|
544
352
|
|
|
545
|
-
|
|
546
|
-
mol = self.conformer_generator(mol)
|
|
547
|
-
if not self.include_hs:
|
|
548
|
-
mol = chem.remove_hs(mol)
|
|
353
|
+
mol = chem.Mol.from_encoding(inputs, explicit_hs=True)
|
|
549
354
|
|
|
550
355
|
if mol.num_conformers == 0:
|
|
551
|
-
|
|
552
|
-
|
|
553
|
-
'Make sure to pass a `ConformerGenerator` to the constructor '
|
|
554
|
-
'of the `Featurizer` or input a 3D representation of the molecule. '
|
|
356
|
+
mol = chem.embed_conformers(
|
|
357
|
+
mol, num_conformers=1, random_seed=self._random_seed
|
|
555
358
|
)
|
|
556
359
|
|
|
557
|
-
|
|
558
|
-
|
|
559
|
-
|
|
560
|
-
|
|
561
|
-
|
|
562
|
-
|
|
563
|
-
|
|
564
|
-
|
|
565
|
-
|
|
566
|
-
|
|
567
|
-
|
|
568
|
-
context =
|
|
569
|
-
|
|
570
|
-
|
|
571
|
-
|
|
572
|
-
|
|
573
|
-
context = {
|
|
574
|
-
**{'size': molecule_size},
|
|
575
|
-
**{key: value for (key, value) in zip(['label', 'weight'], context)}
|
|
576
|
-
}
|
|
577
|
-
else:
|
|
578
|
-
context = {'size': molecule_size}
|
|
579
|
-
|
|
580
|
-
if molecule_feature is not None:
|
|
581
|
-
if 'feature' in context:
|
|
582
|
-
warnings.warn(
|
|
583
|
-
'Found both inputted and computed context feature. '
|
|
584
|
-
'Overwriting inputted context feature with computed '
|
|
585
|
-
'context feature (based on `molecule_features`).',
|
|
586
|
-
stacklevel=2
|
|
587
|
-
)
|
|
588
|
-
context['feature'] = molecule_feature
|
|
589
|
-
|
|
590
|
-
node = {}
|
|
591
|
-
node['feature'] = self.atom_features(mol)
|
|
592
|
-
|
|
593
|
-
if self._bond_features:
|
|
594
|
-
edge_feature = self.bond_features(mol)
|
|
595
|
-
|
|
596
|
-
edge = {}
|
|
597
|
-
mols = chem.unpack_conformers(mol)
|
|
598
|
-
tensor_list = []
|
|
599
|
-
for i, mol in enumerate(mols):
|
|
600
|
-
node_conformer = copy.deepcopy(node)
|
|
601
|
-
edge_conformer = copy.deepcopy(edge)
|
|
602
|
-
conformer = mol.get_conformer()
|
|
603
|
-
adjacency_matrix = conformer.adjacency(
|
|
604
|
-
fill='full',
|
|
605
|
-
radius=self.radius,
|
|
606
|
-
sparse=False,
|
|
607
|
-
self_loops=self.self_loops,
|
|
608
|
-
dtype=bool
|
|
360
|
+
if not self._include_hydrogens:
|
|
361
|
+
mol = chem.remove_hs(mol)
|
|
362
|
+
|
|
363
|
+
data = {'context': {}, 'node': {}, 'edge': {}}
|
|
364
|
+
|
|
365
|
+
data['context']['size'] = np.asarray(mol.num_atoms)
|
|
366
|
+
|
|
367
|
+
if len(context_inputs) == 1:
|
|
368
|
+
data['context']['label'] = np.asarray(context_inputs[0])
|
|
369
|
+
elif len(context_inputs) == 2:
|
|
370
|
+
data['context']['label'] = np.asarray(context_inputs[0])
|
|
371
|
+
data['context']['weight'] = np.asarray(context_inputs[1])
|
|
372
|
+
|
|
373
|
+
if self._molecule_features is not None:
|
|
374
|
+
data['context']['feature'] = np.concatenate(
|
|
375
|
+
[f(mol) for f in self._molecule_features], axis=-1
|
|
609
376
|
)
|
|
610
|
-
|
|
611
|
-
|
|
612
|
-
|
|
613
|
-
|
|
614
|
-
|
|
615
|
-
|
|
616
|
-
|
|
617
|
-
|
|
618
|
-
|
|
619
|
-
|
|
620
|
-
|
|
621
|
-
|
|
622
|
-
|
|
623
|
-
|
|
624
|
-
|
|
625
|
-
|
|
626
|
-
|
|
627
|
-
|
|
377
|
+
|
|
378
|
+
conformer = mol.get_conformer()
|
|
379
|
+
|
|
380
|
+
data['node']['feature'] = np.concatenate(
|
|
381
|
+
[f(mol) for f in self._atom_features], axis=-1
|
|
382
|
+
)
|
|
383
|
+
data['node']['coordinate'] = conformer.coordinates
|
|
384
|
+
|
|
385
|
+
adjacency_matrix = conformer.adjacency(
|
|
386
|
+
fill='full', radius=self._radius, sparse=False, self_loops=self._self_loops,
|
|
387
|
+
)
|
|
388
|
+
|
|
389
|
+
data['edge']['source'], data['edge']['target'] = np.where(adjacency_matrix)
|
|
390
|
+
|
|
391
|
+
if self._pair_features is not None:
|
|
392
|
+
pair_features = np.concatenate(
|
|
393
|
+
[f(mol) for f in self._pair_features], axis=-1
|
|
394
|
+
)
|
|
395
|
+
pair_keep = adjacency_matrix.reshape(-1).astype(bool)
|
|
396
|
+
data['edge']['feature'] = pair_features[pair_keep]
|
|
397
|
+
|
|
398
|
+
if self._super_node:
|
|
399
|
+
data = _add_super_node(data)
|
|
400
|
+
data['node']['coordinate'] = np.concatenate(
|
|
401
|
+
[data['node']['coordinate'], conformer.centroid[None]], axis=0
|
|
628
402
|
)
|
|
629
|
-
|
|
630
|
-
return
|
|
403
|
+
|
|
404
|
+
return tensors.GraphTensor(**_convert_dtypes(data))
|
|
631
405
|
|
|
632
|
-
|
|
633
|
-
|
|
634
|
-
|
|
635
|
-
return super().stack(outputs)
|
|
406
|
+
@property
|
|
407
|
+
def random_seed(self) -> int | None:
|
|
408
|
+
return self._random_seed
|
|
636
409
|
|
|
410
|
+
@random_seed.setter
|
|
411
|
+
def random_seed(self, value: int) -> None:
|
|
412
|
+
self._random_seed = value
|
|
413
|
+
|
|
637
414
|
def get_config(self):
|
|
638
415
|
config = super().get_config()
|
|
639
|
-
config['
|
|
640
|
-
|
|
416
|
+
config['radius'] = self._radius
|
|
417
|
+
config['pair_features'] = keras.saving.serialize_keras_object(
|
|
418
|
+
self._pair_features
|
|
641
419
|
)
|
|
420
|
+
config['random_seed'] = self._random_seed
|
|
642
421
|
return config
|
|
643
422
|
|
|
644
423
|
@classmethod
|
|
645
424
|
def from_config(cls, config: dict):
|
|
646
|
-
config['
|
|
647
|
-
config['
|
|
425
|
+
config['pair_features'] = keras.saving.deserialize_keras_object(
|
|
426
|
+
config['pair_features']
|
|
648
427
|
)
|
|
649
428
|
return super().from_config(config)
|
|
650
|
-
|
|
429
|
+
|
|
651
430
|
|
|
652
431
|
def save_featurizer(
|
|
653
|
-
featurizer:
|
|
432
|
+
featurizer: GraphFeaturizer,
|
|
654
433
|
filepath: str | Path,
|
|
655
434
|
overwrite: bool = True,
|
|
656
435
|
**kwargs
|
|
@@ -658,8 +437,8 @@ def save_featurizer(
|
|
|
658
437
|
filepath = Path(filepath)
|
|
659
438
|
if filepath.suffix != '.json':
|
|
660
439
|
raise ValueError(
|
|
661
|
-
'Invalid `filepath` extension for saving a `
|
|
662
|
-
'A `
|
|
440
|
+
'Invalid `filepath` extension for saving a `GraphFeaturizer`. '
|
|
441
|
+
'A `GraphFeaturizer` should be saved as a JSON file.'
|
|
663
442
|
)
|
|
664
443
|
if not filepath.parent.exists():
|
|
665
444
|
filepath.parent.mkdir(parents=True, exist_ok=True)
|
|
@@ -672,12 +451,12 @@ def save_featurizer(
|
|
|
672
451
|
def load_featurizer(
|
|
673
452
|
filepath: str | Path,
|
|
674
453
|
**kwargs
|
|
675
|
-
) ->
|
|
454
|
+
) -> GraphFeaturizer:
|
|
676
455
|
filepath = Path(filepath)
|
|
677
456
|
if filepath.suffix != '.json':
|
|
678
457
|
raise ValueError(
|
|
679
|
-
'Invalid `filepath` extension for loading a `
|
|
680
|
-
'A `
|
|
458
|
+
'Invalid `filepath` extension for loading a `GraphFeaturizer`. '
|
|
459
|
+
'A `GraphFeaturizer` should be saved as a JSON file.'
|
|
681
460
|
)
|
|
682
461
|
if not filepath.exists():
|
|
683
462
|
return
|
|
@@ -685,69 +464,60 @@ def load_featurizer(
|
|
|
685
464
|
config = json.load(f)
|
|
686
465
|
return keras.saving.deserialize_keras_object(config)
|
|
687
466
|
|
|
688
|
-
def
|
|
689
|
-
|
|
690
|
-
|
|
691
|
-
|
|
692
|
-
|
|
693
|
-
|
|
694
|
-
|
|
695
|
-
|
|
696
|
-
|
|
697
|
-
|
|
698
|
-
|
|
699
|
-
|
|
700
|
-
)
|
|
701
|
-
node['feature'] = np.concatenate([node['feature'], super_node_feature])
|
|
702
|
-
return node
|
|
703
|
-
|
|
704
|
-
def _add_super_edges(
|
|
705
|
-
edge: dict[str, np.ndarray],
|
|
706
|
-
num_nodes: int,
|
|
707
|
-
num_super_nodes: int,
|
|
708
|
-
feature_dtype: str,
|
|
709
|
-
index_dtype: str,
|
|
710
|
-
self_loops: bool,
|
|
711
|
-
) -> dict[str, np.ndarray]:
|
|
712
|
-
edge = copy.deepcopy(edge)
|
|
713
|
-
|
|
714
|
-
super_node_indices = np.arange(num_super_nodes) + num_nodes
|
|
715
|
-
if self_loops:
|
|
716
|
-
edge['source'] = np.concatenate([edge['source'], super_node_indices])
|
|
717
|
-
edge['target'] = np.concatenate([edge['target'], super_node_indices])
|
|
718
|
-
super_node_indices = np.repeat(super_node_indices, [num_nodes])
|
|
719
|
-
node_indices = (
|
|
720
|
-
np.tile(np.arange(num_nodes), [num_super_nodes])
|
|
467
|
+
def _add_super_node(
|
|
468
|
+
data: dict[str, dict[str, np.ndarray]]
|
|
469
|
+
) -> dict[str, dict[str, np.ndarray]]:
|
|
470
|
+
|
|
471
|
+
data['context']['size'] += 1
|
|
472
|
+
|
|
473
|
+
num_nodes = data['node']['feature'].shape[0]
|
|
474
|
+
num_edges = data['edge']['source'].shape[0]
|
|
475
|
+
super_node_index = num_nodes
|
|
476
|
+
|
|
477
|
+
add_self_loops = np.any(
|
|
478
|
+
data['edge']['source'] == data['edge']['target']
|
|
721
479
|
)
|
|
722
|
-
|
|
723
|
-
[edge['source']
|
|
724
|
-
|
|
725
|
-
|
|
726
|
-
edge['target'] = np.concatenate(
|
|
727
|
-
[edge['target'], super_node_indices, node_indices]
|
|
728
|
-
).astype(index_dtype)
|
|
729
|
-
|
|
730
|
-
if 'feature' in edge:
|
|
731
|
-
num_edges = int(edge['feature'].shape[0])
|
|
732
|
-
num_super_edges = int(num_super_nodes * num_nodes * 2)
|
|
733
|
-
if self_loops:
|
|
734
|
-
num_super_edges += num_super_nodes
|
|
735
|
-
edge['super'] = np.asarray(
|
|
736
|
-
([False] * num_edges + [True] * num_super_edges),
|
|
737
|
-
dtype=bool
|
|
480
|
+
if add_self_loops:
|
|
481
|
+
data['edge']['source'] = np.append(
|
|
482
|
+
data['edge']['source'], super_node_index
|
|
738
483
|
)
|
|
739
|
-
edge['
|
|
740
|
-
[
|
|
741
|
-
edge['feature'],
|
|
742
|
-
np.zeros(
|
|
743
|
-
shape=(num_super_edges, edge['feature'].shape[-1]),
|
|
744
|
-
dtype=edge['feature'].dtype
|
|
745
|
-
)
|
|
746
|
-
]
|
|
484
|
+
data['edge']['target'] = np.append(
|
|
485
|
+
data['edge']['target'], super_node_index
|
|
747
486
|
)
|
|
748
487
|
|
|
749
|
-
|
|
488
|
+
data['node']['feature'] = np.pad(data['node']['feature'], [(0, 1), (0, 0)])
|
|
489
|
+
data['node']['super'] = np.asarray([False] * num_nodes + [True])
|
|
750
490
|
|
|
491
|
+
node_indices = list(range(num_nodes))
|
|
492
|
+
super_node_indices = [super_node_index] * num_nodes
|
|
751
493
|
|
|
752
|
-
|
|
753
|
-
|
|
494
|
+
data['edge']['source'] = np.append(
|
|
495
|
+
data['edge']['source'], node_indices + super_node_indices
|
|
496
|
+
)
|
|
497
|
+
data['edge']['target'] = np.append(
|
|
498
|
+
data['edge']['target'], super_node_indices + node_indices
|
|
499
|
+
)
|
|
500
|
+
|
|
501
|
+
total_num_edges = data['edge']['source'].shape[0]
|
|
502
|
+
num_super_edges = (total_num_edges - num_edges)
|
|
503
|
+
data['edge']['super'] = np.asarray(
|
|
504
|
+
[False] * num_edges + [True] * num_super_edges
|
|
505
|
+
)
|
|
506
|
+
|
|
507
|
+
if 'feature' in data['edge']:
|
|
508
|
+
data['edge']['feature'] = np.pad(
|
|
509
|
+
data['edge']['feature'], [(0, num_super_edges), (0, 0)]
|
|
510
|
+
)
|
|
511
|
+
|
|
512
|
+
return data
|
|
513
|
+
|
|
514
|
+
def _convert_dtypes(data: dict[str, dict[str, np.ndarray]]) -> np.ndarray:
|
|
515
|
+
for outer_key, inner_dict in data.items():
|
|
516
|
+
for inner_key, inner_value in inner_dict.items():
|
|
517
|
+
if inner_key in ['source', 'target', 'size']:
|
|
518
|
+
data[outer_key][inner_key] = inner_value.astype(np.int32)
|
|
519
|
+
elif np.issubdtype(inner_value.dtype, np.integer):
|
|
520
|
+
data[outer_key][inner_key] = inner_value.astype(np.int32)
|
|
521
|
+
elif np.issubdtype(inner_value.dtype, np.floating):
|
|
522
|
+
data[outer_key][inner_key] = inner_value.astype(np.float32)
|
|
523
|
+
return data
|