molcraft 0.1.0a1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of molcraft might be problematic. Click here for more details.
- molcraft/__init__.py +16 -0
- molcraft/callbacks.py +21 -0
- molcraft/chem.py +600 -0
- molcraft/conformers.py +155 -0
- molcraft/descriptors.py +90 -0
- molcraft/experimental/__init__.py +1 -0
- molcraft/experimental/peptides.py +303 -0
- molcraft/features.py +387 -0
- molcraft/featurizers.py +693 -0
- molcraft/layers.py +1224 -0
- molcraft/models.py +441 -0
- molcraft/ops.py +129 -0
- molcraft/records.py +169 -0
- molcraft/tensors.py +527 -0
- molcraft-0.1.0a1.dist-info/METADATA +58 -0
- molcraft-0.1.0a1.dist-info/RECORD +19 -0
- molcraft-0.1.0a1.dist-info/WHEEL +5 -0
- molcraft-0.1.0a1.dist-info/licenses/LICENSE +21 -0
- molcraft-0.1.0a1.dist-info/top_level.txt +1 -0
molcraft/layers.py
ADDED
|
@@ -0,0 +1,1224 @@
|
|
|
1
|
+
import abc
|
|
2
|
+
import keras
|
|
3
|
+
import tensorflow as tf
|
|
4
|
+
import warnings
|
|
5
|
+
from keras.src.models import functional
|
|
6
|
+
|
|
7
|
+
from molcraft import tensors
|
|
8
|
+
from molcraft import ops
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
@keras.saving.register_keras_serializable(package='molcraft')
class GraphLayer(keras.layers.Layer):
    """Base class for layers that operate on graph data.

    Subclasses implement `propagate` (the forward pass on a `GraphTensor`)
    and, optionally, `build_from_spec` (weight creation). `build`, `call`
    and `__call__` are implemented here and dispatch to those two hooks.

    The constructor arguments are only relevant if the derived layer
    invokes `get_dense_kwargs`, `get_weight`, `get_dense` or
    `get_einsum_dense`.

    Args:
        use_bias:
            Forwarded as `use_bias` to layers created by the helpers.
        kernel_initializer:
            Initializer (or identifier) for kernels and for weights
            created via `get_weight`.
        bias_initializer:
            Initializer (or identifier) for biases.
        kernel_regularizer:
            Regularizer for kernels and for weights created via `get_weight`.
        bias_regularizer:
            Regularizer for biases.
        activity_regularizer:
            Passed to the keras base layer and forwarded to sub-layers
            created by the helpers.
        kernel_constraint:
            Constraint for kernels and for weights created via `get_weight`.
        bias_constraint:
            Constraint for biases.
        **kwargs:
            Forwarded to `keras.layers.Layer`.
    """

    def __init__(
        self,
        use_bias: bool = True,
        kernel_initializer: keras.initializers.Initializer | str = "glorot_uniform",
        bias_initializer: keras.initializers.Initializer | str = "zeros",
        kernel_regularizer: keras.regularizers.Regularizer | None = None,
        bias_regularizer: keras.regularizers.Regularizer | None = None,
        activity_regularizer: keras.regularizers.Regularizer | None = None,
        kernel_constraint: keras.constraints.Constraint | None = None,
        bias_constraint: keras.constraints.Constraint | None = None,
        **kwargs,
    ) -> None:
        super().__init__(activity_regularizer=activity_regularizer, **kwargs)
        self._use_bias = use_bias
        # Resolve identifiers (e.g. strings) into keras objects once, up front.
        self._kernel_initializer = keras.initializers.get(kernel_initializer)
        self._kernel_regularizer = keras.regularizers.get(kernel_regularizer)
        self._kernel_constraint = keras.constraints.get(kernel_constraint)
        self._bias_initializer = keras.initializers.get(bias_initializer)
        self._bias_regularizer = keras.regularizers.get(bias_regularizer)
        self._bias_constraint = keras.constraints.get(bias_constraint)
        self.built = False
        # TODO: Add warning if build is implemented in subclass
        # TODO: Add warning if call is implemented in subclass

    def propagate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
        """Runs the forward pass of the layer.

        Must be implemented by the subclass.

        Args:
            tensor:
                A `GraphTensor` instance.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError('`propagate` needs to be implemented.')

    def build_from_spec(self, spec: tensors.GraphTensor.Spec) -> None:
        """Creates the weights and sub-layers of the layer.

        May use the helpers `get_weight`, `get_dense` and `get_einsum_dense`.

        Optional hook. When a subclass overrides it, `build` calls it and
        marks the layer as built, so sub-layers should be built here (e.g.
        via `build([None, input_dim])`). When it is not overridden, `build`
        instead passes symbolic input through the layer to build it.

        Args:
            spec:
                A `GraphTensor.Spec` instance, corresponding to the input
                `GraphTensor` of the `propagate` method.
        """

    def build(self, spec: tensors.GraphTensor.Spec) -> None:
        """Builds the layer from a `GraphTensor.Spec`.

        Stores the serialized spec for `get_build_config`, then either calls
        the subclass' `build_from_spec` or, if that hook is not overridden
        and the layer is not built yet, calls the layer on symbolic input.

        Args:
            spec:
                The spec of the graph input the layer will be called on.
        """
        self._custom_build_config = {'spec': _serialize_spec(spec)}

        overrides_build_from_spec = (
            self.__class__.build_from_spec is not GraphLayer.build_from_spec
        )
        if overrides_build_from_spec:
            self.build_from_spec(spec)
            self.built = True
        elif not self.built:
            # Automatically build layer or model by calling it on symbolic
            # inputs. `built` is set first so the nested `__call__` does not
            # re-enter `build`.
            self.built = True
            self(Input(spec))

    def get_build_config(self) -> dict:
        """Returns the stored spec config, or keras' default if never built here."""
        if hasattr(self, '_custom_build_config'):
            return self._custom_build_config
        return super().get_build_config()

    def build_from_config(self, config: dict) -> None:
        """Rebuilds the layer from a config produced by `get_build_config`.

        Configs carrying a `'spec'` entry are deserialized and passed to
        `build`; any other config is delegated to the keras base class.
        """
        if 'spec' in config:
            self.build(_deserialize_spec(config['spec']))
        else:
            super().build_from_config(config)

    def call(
        self,
        graph: dict[str, dict[str, tf.Tensor]]
    ) -> dict[str, dict[str, tf.Tensor]]:
        """Wraps `propagate` so that it works on the dict form of a graph.

        Args:
            graph:
                The graph as nested dicts; converted with `tensors.from_dict`.

        Returns:
            The result of `propagate`, converted back with `tensors.to_dict`
            when it is a `GraphTensor`, otherwise returned unchanged.
        """
        outputs = self.propagate(tensors.from_dict(graph))
        if not isinstance(outputs, tensors.GraphTensor):
            return outputs
        return tensors.to_dict(outputs)

    def __call__(self, inputs, **kwargs):
        """Calls the layer on a `GraphTensor` or on its dict form.

        Builds the layer on first use. A `GraphTensor` input is converted to
        a dict for keras and, if the output is a graph, converted back so
        the caller gets the type it passed in.
        """
        if not self.built:
            self.build(_spec_from_inputs(inputs))

        received_graph_tensor = isinstance(inputs, tensors.GraphTensor)
        if received_graph_tensor:
            inputs = tensors.to_dict(inputs)

        is_functional = isinstance(self, functional.Functional)
        if is_functional:
            # presumably splits off the entries the functional model does not
            # declare as input — helper is defined outside this view
            inputs, left_out_inputs = _match_functional_input(self.input, inputs)

        outputs = super().__call__(inputs, **kwargs)
        if not tensors.is_graph(outputs):
            return outputs

        if is_functional:
            outputs = _add_left_out_inputs(outputs, left_out_inputs)
        if received_graph_tensor:
            outputs = tensors.from_dict(outputs)
        return outputs

    def get_weight(
        self,
        shape: tf.TensorShape,
        **kwargs,
    ) -> tf.Variable:
        """Adds a weight using the layer's kernel initializer/regularizer/constraint.

        Args:
            shape:
                Shape of the weight.
            **kwargs:
                Forwarded to `add_weight`; these override the defaults.
        """
        dense_kwargs = self.get_dense_kwargs()
        weight_kwargs = {
            'initializer': dense_kwargs['kernel_initializer'],
            'regularizer': dense_kwargs['kernel_regularizer'],
            'constraint': dense_kwargs['kernel_constraint'],
            **kwargs,
        }
        return self.add_weight(shape=shape, **weight_kwargs)

    def get_dense(
        self,
        units: int,
        **kwargs
    ) -> keras.layers.Dense:
        """Creates a `Dense` layer configured with `get_dense_kwargs`.

        Args:
            units:
                Output dimension of the dense layer.
            **kwargs:
                Forwarded to `keras.layers.Dense`; these override the defaults.
        """
        return keras.layers.Dense(units, **{**self.get_dense_kwargs(), **kwargs})

    def get_einsum_dense(
        self,
        equation: str,
        output_shape: tf.TensorShape,
        **kwargs
    ) -> keras.layers.EinsumDense:
        """Creates an `EinsumDense` layer configured with `get_dense_kwargs`.

        `EinsumDense` is given `bias_axes` rather than `use_bias`, so
        `use_bias` is removed from the keyword arguments and, when it is
        truthy and no `bias_axes` was supplied, `bias_axes` is derived from
        the equation.

        Args:
            equation:
                The einsum equation.
            output_shape:
                Forwarded as the output shape of `EinsumDense`.
            **kwargs:
                Forwarded to `keras.layers.EinsumDense`; these override
                the defaults.
        """
        layer_kwargs = {**self.get_dense_kwargs(), **kwargs}
        wants_bias = layer_kwargs.pop('use_bias', False)
        if wants_bias and 'bias_axes' not in layer_kwargs:
            # Output subscripts of the equation with the first one dropped.
            layer_kwargs['bias_axes'] = equation.split('->')[-1][1:] or None
        return keras.layers.EinsumDense(equation, output_shape, **layer_kwargs)

    def get_dense_kwargs(self) -> dict:
        """Returns the keyword arguments shared by the dense helpers.

        The initializers are re-created from their configs on every call, so
        each sub-layer receives its own initializer instance.
        """
        def _fresh_copy(initializer):
            return initializer.__class__.from_config(initializer.get_config())

        return dict(
            use_bias=self._use_bias,
            kernel_regularizer=self._kernel_regularizer,
            bias_regularizer=self._bias_regularizer,
            activity_regularizer=self.activity_regularizer,
            kernel_constraint=self._kernel_constraint,
            bias_constraint=self._bias_constraint,
            kernel_initializer=_fresh_copy(self._kernel_initializer),
            bias_initializer=_fresh_copy(self._bias_initializer),
        )

    def get_config(self) -> dict:
        """Returns the serializable configuration of the layer."""
        config = super().get_config()
        config["use_bias"] = self._use_bias
        config["kernel_initializer"] = keras.initializers.serialize(
            self._kernel_initializer
        )
        config["bias_initializer"] = keras.initializers.serialize(
            self._bias_initializer
        )
        config["kernel_regularizer"] = keras.regularizers.serialize(
            self._kernel_regularizer
        )
        config["bias_regularizer"] = keras.regularizers.serialize(
            self._bias_regularizer
        )
        config["kernel_constraint"] = keras.constraints.serialize(
            self._kernel_constraint
        )
        config["bias_constraint"] = keras.constraints.serialize(
            self._bias_constraint
        )
        return config
|
|
203
|
+
|
|
204
|
+
|
|
205
|
+
@keras.saving.register_keras_serializable(package='molcraft')
class GraphConv(GraphLayer):
    """Base class for message-passing graph layers.

    `propagate` runs the three steps `message`, `aggregate` and `update`
    in that order; subclasses provide the steps.

    Args:
        units:
            Stored as `self.units` and included in the layer config.
        **kwargs:
            Forwarded to `GraphLayer`.
    """

    def __init__(self, units: int, **kwargs) -> None:
        super().__init__(**kwargs)
        self.units = units

    @abc.abstractmethod
    def message(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
        """Computes the messages.

        To be implemented by the subclass.

        Args:
            tensor:
                The `GraphTensor` passed to `propagate`.
        """

    @abc.abstractmethod
    def aggregate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
        """Aggregates the messages.

        To be implemented by the subclass.

        Args:
            tensor:
                The `GraphTensor` returned by `message`.
        """

    @abc.abstractmethod
    def update(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
        """Updates the nodes.

        To be implemented by the subclass.

        Args:
            tensor:
                The `GraphTensor` returned by `aggregate`
                (containing the aggregated messages).
        """

    def propagate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
        """Runs `message`, then `aggregate`, then `update`.

        Args:
            tensor:
                A `GraphTensor` instance.

        Returns:
            The result of `update`.
        """
        with_messages = self.message(tensor)
        aggregated = self.aggregate(with_messages)
        return self.update(aggregated)

    def get_config(self) -> dict:
        """Returns the base config extended with `units`."""
        config = super().get_config()
        config['units'] = self.units
        return config
|
|
270
|
+
|
|
271
|
+
|
|
272
|
+
@keras.saving.register_keras_serializable(package='molcraft')
class Projection(GraphLayer):
    """Base graph projection layer.

    Applies a dense layer followed by an activation to the `'feature'`
    entry of one field of the graph and writes the result back to it.

    Args:
        units:
            Output dimension. If falsy, it is set to the input feature
            dimension when the layer is built.
        activation:
            Activation identifier, resolved with `keras.activations.get`.
        field:
            Name of the attribute on the spec / tensor to project
            (default `'node'`).
        **kwargs:
            Forwarded to `GraphLayer`.
    """
    def __init__(
        self,
        units: int | None = None,
        activation: str | None = None,
        field: str = 'node',
        **kwargs
    ) -> None:
        super().__init__(**kwargs)
        self.units = units
        self._activation = keras.activations.get(activation)
        self.field = field

    def build_from_spec(self, spec: tensors.GraphTensor.Spec) -> None:
        """Builds the dense layer for the configured field.

        Args:
            spec:
                Spec of the input graph; must expose `self.field`.

        Raises:
            ValueError: If `spec` has no (non-None) attribute `self.field`.
        """
        data = getattr(spec, self.field, None)
        if data is None:
            # f-string so the offending field name actually appears in the
            # message (it was previously emitted as a literal placeholder).
            raise ValueError(f'Could not access field {self.field!r}.')
        feature_dim = data['feature'].shape[-1]
        if not self.units:
            # Default to a projection that keeps the feature dimension.
            self.units = feature_dim
        self._dense = self.get_dense(self.units)
        self._dense.build([None, feature_dim])

    def propagate(self, tensor: tensors.GraphTensor):
        """Projects the `'feature'` of the configured field.

        Args:
            tensor:
                A `GraphTensor` instance.

        Returns:
            The result of `tensor.update` with the projected feature.
        """
        feature = getattr(tensor, self.field)['feature']
        feature = self._dense(feature)
        feature = self._activation(feature)
        return tensor.update(
            {
                self.field: {
                    'feature': feature
                }
            }
        )

    def get_config(self) -> dict:
        """Returns the base config extended with `units`, `activation`, `field`."""
        config = super().get_config()
        config.update({
            'units': self.units,
            'activation': keras.activations.serialize(self._activation),
            'field': self.field,
        })
        return config
|
|
322
|
+
|
|
323
|
+
|
|
324
|
+
@keras.saving.register_keras_serializable(package='molcraft')
class GraphNetwork(GraphLayer):
    """Graph neural network.

    Sequentially calls graph layers (`GraphLayer`) and concatenates the
    node features: the (possibly projected) input node feature followed by
    the node feature produced by every layer, along the last axis.

    Args:
        layers (list):
            A list of graph layers. The first layer must expose `units`.
        **kwargs:
            Forwarded to `GraphLayer`.
    """

    def __init__(self, layers: list[GraphLayer], **kwargs) -> None:
        super().__init__(**kwargs)
        self.layers = layers
        # Both flags must exist before `build_from_spec` runs: it only ever
        # switches them on, and `propagate` reads them unconditionally.
        self._update_node_feature = False
        self._update_edge_feature = False

    def build_from_spec(self, spec: tensors.GraphTensor.Spec) -> None:
        """Adds input projections when feature dims differ from `units`.

        Compares the node (and, if present, edge) feature dimension of
        `spec` with `units` of the first layer, and creates a dense
        projection for each one that does not match.

        Args:
            spec:
                Spec of the input graph.
        """
        # NOTE(review): `warn` is not defined in the visible part of this
        # module; presumably a module-level helper further down — confirm.
        units = self.layers[0].units
        node_feature_dim = spec.node['feature'].shape[-1]
        if node_feature_dim != units:
            warn(
                'Node feature dim does not match `units` of the first layer. '
                'Automatically adding a node projection layer to match `units`.'
            )
            self._node_dense = self.get_dense(units)
            self._update_node_feature = True
        has_edge_feature = 'feature' in spec.edge
        if has_edge_feature:
            edge_feature_dim = spec.edge['feature'].shape[-1]
            if edge_feature_dim != units:
                warn(
                    'Edge feature dim does not match `units` of the first layer. '
                    'Automatically adding a edge projection layer to match `units`.'
                )
                self._edge_dense = self.get_dense(units)
                self._update_edge_feature = True

    def propagate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
        """Runs all layers and concatenates the intermediate node features.

        Args:
            tensor:
                A `GraphTensor` instance.

        Returns:
            `tensor` updated with the concatenated node features.
        """
        x = tensors.to_dict(tensor)
        if self._update_node_feature:
            x['node']['feature'] = self._node_dense(tensor.node['feature'])
        if self._update_edge_feature:
            x['edge']['feature'] = self._edge_dense(tensor.edge['feature'])
        outputs = [x['node']['feature']]
        for layer in self.layers:
            x = layer(x)
            outputs.append(x['node']['feature'])
        return tensor.update(
            {
                'node': {
                    'feature': keras.ops.concatenate(outputs, axis=-1)
                }
            }
        )

    def tape_propagate(
        self,
        tensor: tensors.GraphTensor,
        tape: tf.GradientTape,
        training: bool | None = None,
    ) -> tuple[tensors.GraphTensor, list[tf.Tensor]]:
        """Performs the propagation with a `GradientTape`.

        Performs the same forward pass as `propagate` but with a
        `GradientTape` watching the intermediate node features.

        Args:
            tensor (tensors.GraphTensor):
                The graph input. The dict form of a graph is accepted as
                well and converted with `tensors.from_dict`.
            tape:
                The tape that should watch each intermediate node feature.
            training:
                Forwarded to every layer call.

        Returns:
            A tuple of the updated `GraphTensor` (concatenated node
            features) and the list of watched node features.
        """
        if not isinstance(tensor, tensors.GraphTensor):
            # Normalize dict input: the code below needs `tensor.node`,
            # `tensor.edge` and `GraphTensor.update`, none of which a plain
            # dict provides (`dict.update` returns None).
            tensor = tensors.from_dict(tensor)
        x = tensors.to_dict(tensor)
        if self._update_node_feature:
            x['node']['feature'] = self._node_dense(tensor.node['feature'])
        if self._update_edge_feature:
            x['edge']['feature'] = self._edge_dense(tensor.edge['feature'])
        tape.watch(x['node']['feature'])
        outputs = [x['node']['feature']]
        for layer in self.layers:
            x = layer(x, training=training)
            tape.watch(x['node']['feature'])
            outputs.append(x['node']['feature'])

        tensor = tensor.update(
            {
                'node': {
                    'feature': keras.ops.concatenate(outputs, axis=-1)
                }
            }
        )
        return tensor, outputs

    def get_config(self) -> dict:
        """Returns the base config extended with the serialized layers."""
        config = super().get_config()
        config.update(
            {
                'layers': [
                    keras.layers.serialize(layer) for layer in self.layers
                ]
            }
        )
        return config

    @classmethod
    def from_config(cls, config: dict) -> 'GraphNetwork':
        """Creates the network from a config produced by `get_config`.

        Works on a shallow copy so the caller's dict is left untouched.
        """
        config = dict(config)
        config['layers'] = [
            keras.layers.deserialize(layer) for layer in config['layers']
        ]
        return super().from_config(config)
|
|
440
|
+
|
|
441
|
+
|
|
442
|
+
@keras.saving.register_keras_serializable(package='molcraft')
class NodeEmbedding(GraphLayer):
    """Node embedding layer.

    Embeds nodes based on their initial features, with optional handling of
    super nodes, context features and random masking.

    Args:
        dim:
            Embedding dimension. If falsy, it is set to the input node
            feature dimension when the layer is built.
        embed_context:
            Whether to embed the context feature into the node features.
            Switched off at build time if the spec has no context feature.
        allow_masking:
            Whether a mask embedding is created, which is required for
            setting `masking_rate`.
        **kwargs:
            Forwarded to `GraphLayer`.
    """

    def __init__(
        self,
        dim: int | None = None,
        embed_context: bool = True,
        allow_masking: bool = True,
        **kwargs
    ) -> None:
        super().__init__(**kwargs)
        self.dim = dim
        self._embed_context = embed_context
        self._masking_rate = None
        self._allow_masking = allow_masking

    def build_from_spec(self, spec: tensors.GraphTensor.Spec) -> None:
        """Builds the dense layers and the optional super/mask embeddings.

        Args:
            spec:
                Spec of the input graph.
        """
        feature_dim = spec.node['feature'].shape[-1]
        if not self.dim:
            self.dim = feature_dim
        self._node_dense = self.get_dense(self.dim)
        self._node_dense.build([None, feature_dim])

        self._has_super = 'super' in spec.node
        has_context_feature = 'feature' in spec.context
        if not has_context_feature:
            # Nothing to embed; fall back to the plain node embedding.
            self._embed_context = False
        if self._has_super and not self._embed_context:
            # A learned super-node embedding is only needed when the super
            # nodes are not filled with context features.
            self._super_feature = self.get_weight(shape=[self.dim], name='super_node_feature')
        if self._allow_masking:
            self._mask_feature = self.get_weight(shape=[self.dim], name='mask_node_feature')

        if self._embed_context:
            context_feature_dim = spec.context['feature'].shape[-1]
            self._context_dense = self.get_dense(self.dim)
            self._context_dense.build([None, context_feature_dim])

    def propagate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
        """Embeds the node features.

        Args:
            tensor:
                A `GraphTensor` instance.

        Returns:
            `tensor` updated with the embedded node features.
        """
        feature = self._node_dense(tensor.node['feature'])

        if self._has_super:
            # Super nodes get the learned embedding, or 0 when the context
            # feature is written into them below.
            super_feature = (0 if self._embed_context else self._super_feature)
            super_mask = keras.ops.expand_dims(tensor.node['super'], 1)
            feature = keras.ops.where(super_mask, super_feature, feature)

        if self._embed_context:
            # NOTE(review): reads `tensor.node['super']` even when
            # `_has_super` is False — confirm super nodes are always present
            # whenever a context feature is.
            context_feature = self._context_dense(tensor.context['feature'])
            # presumably writes the context embedding into the super-node
            # rows; `ops.scatter_update` is defined outside this view
            feature = ops.scatter_update(feature, tensor.node['super'], context_feature)
            tensor = tensor.update({'context': {'feature': None}})

        if (
            self._allow_masking and
            self._masking_rate is not None and
            self._masking_rate > 0
        ):
            random = keras.random.uniform(shape=[tensor.num_nodes])
            mask = random <= self._masking_rate
            if self._has_super:
                # Never mask super nodes.
                mask = keras.ops.logical_and(
                    mask, keras.ops.logical_not(tensor.node['super'])
                )
            mask = keras.ops.expand_dims(mask, -1)
            feature = keras.ops.where(mask, self._mask_feature, feature)
        elif self._allow_masking:
            # Silence warning of 'no gradients for variables'
            feature = feature + (self._mask_feature * 0.0)

        return tensor.update({'node': {'feature': feature}})

    @property
    def masking_rate(self):
        """The current masking rate, or `None` if masking is not active."""
        return self._masking_rate

    @masking_rate.setter
    def masking_rate(self, rate: float):
        """Sets the masking rate; `None` disables masking.

        Raises:
            ValueError: If a rate is given but `allow_masking` is `False`.
        """
        if not self._allow_masking and rate is not None:
            raise ValueError(
                f'Cannot set `masking_rate` for layer {self} '
                'as `allow_masking` was set to `False`.'
            )
        # `None` is a valid value (see the guard above and `propagate`);
        # only coerce actual numbers.
        self._masking_rate = None if rate is None else float(rate)

    def get_config(self) -> dict:
        """Returns the base config extended with this layer's arguments."""
        config = super().get_config()
        config.update({
            'dim': self.dim,
            # Included so that a reloaded layer builds the same weights.
            'embed_context': self._embed_context,
            'allow_masking': self._allow_masking
        })
        return config
|
|
540
|
+
|
|
541
|
+
|
|
542
|
+
@keras.saving.register_keras_serializable(package='molcraft')
class EdgeEmbedding(GraphLayer):
    """Edge embedding layer.

    Embeds edges based on their initial features, with optional handling of
    super edges and random masking.

    Args:
        dim:
            Embedding dimension. If falsy, it is set to the input edge
            feature dimension when the layer is built.
        allow_masking:
            Whether a mask embedding is created, which is required for
            setting `masking_rate`.
        **kwargs:
            Forwarded to `GraphLayer`.
    """

    def __init__(
        self,
        dim: int | None = None,
        allow_masking: bool = True,
        **kwargs
    ) -> None:
        super().__init__(**kwargs)
        self.dim = dim
        self._masking_rate = None
        self._allow_masking = allow_masking

    def build_from_spec(self, spec: tensors.GraphTensor.Spec) -> None:
        """Builds the dense layer and the optional super/mask embeddings.

        Args:
            spec:
                Spec of the input graph; must contain an edge feature.
        """
        feature_dim = spec.edge['feature'].shape[-1]
        if not self.dim:
            self.dim = feature_dim
        self._edge_dense = self.get_dense(self.dim)
        self._edge_dense.build([None, feature_dim])

        self._has_super = 'super' in spec.edge
        if self._has_super:
            self._super_feature = self.get_weight(shape=[self.dim], name='super_edge_feature')
        if self._allow_masking:
            self._mask_feature = self.get_weight(shape=[self.dim], name='mask_edge_feature')

    def propagate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
        """Embeds the edge features.

        Args:
            tensor:
                A `GraphTensor` instance.

        Returns:
            `tensor` updated with the embedded edge features.
        """
        feature = self._edge_dense(tensor.edge['feature'])

        if self._has_super:
            # Super edges get the learned super-edge embedding.
            super_feature = self._super_feature
            super_mask = keras.ops.expand_dims(tensor.edge['super'], 1)
            feature = keras.ops.where(super_mask, super_feature, feature)

        if (
            self._allow_masking and
            self._masking_rate is not None and
            self._masking_rate > 0
        ):
            random = keras.random.uniform(shape=[tensor.num_edges])
            mask = random <= self._masking_rate
            if self._has_super:
                # Never mask super edges.
                mask = keras.ops.logical_and(
                    mask, keras.ops.logical_not(tensor.edge['super'])
                )
            mask = keras.ops.expand_dims(mask, -1)
            feature = keras.ops.where(mask, self._mask_feature, feature)
        elif self._allow_masking:
            # Silence warning of 'no gradients for variables'
            feature = feature + (self._mask_feature * 0.0)

        return tensor.update({'edge': {'feature': feature}})

    @property
    def masking_rate(self):
        """The current masking rate, or `None` if masking is not active."""
        return self._masking_rate

    @masking_rate.setter
    def masking_rate(self, rate: float):
        """Sets the masking rate; `None` disables masking.

        Raises:
            ValueError: If a rate is given but `allow_masking` is `False`.
        """
        if not self._allow_masking and rate is not None:
            raise ValueError(
                f'Cannot set `masking_rate` for layer {self} '
                'as `allow_masking` was set to `False`.'
            )
        # `None` is a valid value (see the guard above and `propagate`);
        # only coerce actual numbers.
        self._masking_rate = None if rate is None else float(rate)

    def get_config(self) -> dict:
        """Returns the base config extended with `dim` and `allow_masking`."""
        config = super().get_config()
        config.update({
            'dim': self.dim,
            'allow_masking': self._allow_masking
        })
        return config
|
|
625
|
+
|
|
626
|
+
|
|
627
|
+
@keras.saving.register_keras_serializable(package='molcraft')
class ContextProjection(Projection):
    """Context projection layer.

    A `Projection` whose `field` is fixed to `'context'`.
    """
    def __init__(self, units: int | None = None, activation: str | None = None, **kwargs):
        # The serialized config carries a 'field' entry, which arrives here
        # through **kwargs on deserialization. Overwrite it instead of
        # passing `field=` a second time, which would raise a TypeError.
        kwargs['field'] = 'context'
        super().__init__(units=units, activation=activation, **kwargs)
|
|
633
|
+
|
|
634
|
+
|
|
635
|
+
@keras.saving.register_keras_serializable(package='molcraft')
class NodeProjection(Projection):
    """Node projection layer.

    A `Projection` whose `field` is fixed to `'node'`.
    """
    def __init__(self, units: int | None = None, activation: str | None = None, **kwargs):
        # The serialized config carries a 'field' entry, which arrives here
        # through **kwargs on deserialization. Overwrite it instead of
        # passing `field=` a second time, which would raise a TypeError.
        kwargs['field'] = 'node'
        super().__init__(units=units, activation=activation, **kwargs)
|
|
641
|
+
|
|
642
|
+
|
|
643
|
+
@keras.saving.register_keras_serializable(package='molcraft')
class EdgeProjection(Projection):
    """Edge projection layer.

    A `Projection` whose `field` is fixed to `'edge'`.
    """
    def __init__(self, units: int | None = None, activation: str | None = None, **kwargs):
        # The serialized config carries a 'field' entry, which arrives here
        # through **kwargs on deserialization. Overwrite it instead of
        # passing `field=` a second time, which would raise a TypeError.
        kwargs['field'] = 'edge'
        super().__init__(units=units, activation=activation, **kwargs)
|
|
649
|
+
|
|
650
|
+
|
|
651
|
+
@keras.saving.register_keras_serializable(package='molcraft')
class GINConv(GraphConv):
    """Graph isomorphism network layer.

    Messages are the source node features (plus the edge feature, if
    any); they are aggregated onto the nodes, combined with
    `(1 + epsilon)` times the node's own feature, and passed through a
    two-layer feedforward network.

    Args:
        units:
            Output dimension of both feedforward dense layers.
        activation:
            Activation applied after the first feedforward dense layer.
        dropout:
            Rate of the dropout applied before the output dense layer.
        normalize:
            Whether to apply batch normalization after the activation.
        update_edge_feature:
            Whether to project the edge feature to the node feature dim.
            Forced on at build time if the dims differ, and off if the
            graph has no edge feature.
        **kwargs:
            Forwarded to `GraphConv`.
    """

    def __init__(
        self,
        units: int,
        activation: keras.layers.Activation | str | None = 'relu',
        dropout: float = 0.0,
        normalize: bool = True,
        update_edge_feature: bool = True,
        **kwargs,
    ):
        super().__init__(units=units, **kwargs)
        self._activation = keras.activations.get(activation)
        self._normalize = normalize
        self._dropout = dropout
        self._update_edge_feature = update_edge_feature

    def build_from_spec(self, spec: tensors.GraphTensor.Spec) -> None:
        """Builds `epsilon`, the edge projection and the feedforward network.

        Args:
            spec:
                Spec of the input graph.
        """
        # NOTE(review): `warn` is not defined in the visible part of this
        # module; presumably a module-level helper further down — confirm.
        node_feature_dim = spec.node['feature'].shape[-1]

        self.epsilon = self.add_weight(
            name='epsilon',
            shape=(),
            initializer='zeros',
            trainable=True,
        )

        if 'feature' in spec.edge:
            edge_feature_dim = spec.edge['feature'].shape[-1]

            if not self._update_edge_feature:
                if (edge_feature_dim != node_feature_dim):
                    # Edge features are added to the gathered node features
                    # in `message`, so the dims have to agree.
                    warn(
                        'Found edge feature dim to be incompatible with node feature dim. '
                        'Automatically adding a edge feature projection layer to match '
                        'the dim of node features.'
                    )
                    self._update_edge_feature = True

            if self._update_edge_feature:
                self._edge_dense = self.get_dense(node_feature_dim)
                self._edge_dense.build([None, edge_feature_dim])
        else:
            self._update_edge_feature = False

        has_overridden_update = self.__class__.update != GINConv.update
        if not has_overridden_update:
            # Use default feedforward network
            self._feedforward_intermediate_dense = self.get_dense(self.units)
            self._feedforward_intermediate_dense.build([None, node_feature_dim])

            if self._normalize:
                self._feedforward_intermediate_norm = keras.layers.BatchNormalization()
                self._feedforward_intermediate_norm.build([None, self.units])

            self._feedforward_dropout = keras.layers.Dropout(self._dropout)
            self._feedforward_activation = self._activation

            self._feedforward_output_dense = self.get_dense(self.units)
            self._feedforward_output_dense.build([None, self.units])

    def message(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
        """Computes the messages.

        The message is the node feature gathered at `'source'`, plus the
        (optionally projected) edge feature when the graph has one.

        Args:
            tensor:
                A `GraphTensor` instance.

        Returns:
            `tensor` updated with the edge `'message'` and `'feature'`.
        """
        message = tensor.gather('feature', 'source')
        edge_feature = tensor.edge.get('feature')
        if self._update_edge_feature:
            edge_feature = self._edge_dense(edge_feature)
        if edge_feature is not None:
            message += edge_feature
        return tensor.update(
            {
                'edge': {
                    'message': message,
                    'feature': edge_feature
                }
            }
        )

    def aggregate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
        """Aggregates the messages.

        Adds `(1 + epsilon) * node feature` to the aggregated messages and
        clears the edge `'message'` entry.

        Args:
            tensor:
                The `GraphTensor` returned by `message`.
        """
        node_feature = tensor.aggregate('message')
        node_feature += (1 + self.epsilon) * tensor.node['feature']
        return tensor.update(
            {
                'node': {
                    'feature': node_feature,
                },
                'edge': {
                    'message': None,
                }
            }
        )

    def update(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
        """Updates the nodes with the default feedforward network.

        Order: dense, activation, optional batch norm, dropout, dense.

        Args:
            tensor:
                The `GraphTensor` returned by `aggregate`.
        """
        node_feature = tensor.node['feature']
        node_feature = self._feedforward_intermediate_dense(node_feature)
        node_feature = self._feedforward_activation(node_feature)
        if self._normalize:
            node_feature = self._feedforward_intermediate_norm(node_feature)
        node_feature = self._feedforward_dropout(node_feature)
        node_feature = self._feedforward_output_dense(node_feature)
        return tensor.update(
            {
                'node': {
                    'feature': node_feature,
                }
            }
        )

    def get_config(self) -> dict:
        """Returns the base config extended with this layer's arguments."""
        config = super().get_config()
        config.update({
            'activation': keras.activations.serialize(self._activation),
            'dropout': self._dropout,
            'normalize': self._normalize,
            # Included so that a reloaded layer builds the same weights.
            'update_edge_feature': self._update_edge_feature,
        })
        return config
|
|
778
|
+
|
|
779
|
+
|
|
780
|
+
@keras.saving.register_keras_serializable(package='molcraft')
class GTConv(GraphConv):

    """Graph transformer layer.

    Provides three steps: `message` computes multi-head attention weights
    and values along the edges, `aggregate` combines them per node and adds
    a residual connection, and `update` applies a two-layer feed-forward
    block with a second residual connection. How these steps are chained
    is decided by the parent class (not shown here).

    Args:
        units:
            Output dimension of the layer. Must be divisible by `heads`.
        heads:
            Number of attention heads.
        activation:
            Activation applied between the two feed-forward dense layers.
        dropout:
            Dropout rate applied to the attention output and to the
            feed-forward output.
        attention_dropout:
            Dropout rate applied to the attention weights.
        normalize:
            Stored and serialized. NOTE(review): no method of this class
            reads it, layer normalization is always applied here; whether
            the parent class uses it cannot be told from this class.
        normalize_first:
            If True, layer normalization is applied before the attention
            and feed-forward blocks, otherwise after each residual addition.
        **kwargs:
            Forwarded to the parent class.

    Raises:
        ValueError: If `units` is not divisible by `heads`.
    """

    def __init__(
        self,
        units: int,
        heads: int = 8,
        activation: keras.layers.Activation | str | None = "relu",
        dropout: float = 0.0,
        attention_dropout: float = 0.0,
        normalize: bool = True,
        normalize_first: bool = True,
        **kwargs,
    ) -> None:
        super().__init__(units=units, **kwargs)
        self._heads = heads
        if self.units % self.heads != 0:
            raise ValueError("units need to be divisible by heads.")
        self._head_units = self.units // self.heads
        self._activation = keras.activations.get(activation)
        self._dropout = dropout
        self._attention_dropout = attention_dropout
        self._normalize = normalize
        self._normalize_first = normalize_first

    @property
    def heads(self) -> int:
        """Number of attention heads."""
        return self._heads

    @property
    def head_units(self) -> int:
        """Number of units per attention head (`units // heads`)."""
        return self._head_units

    def build_from_spec(self, spec):
        """Builds the layer.

        Creates and builds the sublayers based on the node (and, if
        present, edge) feature dimensions found in `spec`.
        """
        node_feature_dim = spec.node['feature'].shape[-1]

        # The first residual connection adds the input node features to a
        # tensor of size `units`, so mismatching inputs are projected first.
        self._project_residual = node_feature_dim != self.units
        if self._project_residual:
            warnings.warn(
                message=(
                    '`GTConv` uses residual connections, but input node feature dim '
                    'is incompatible with intermediate dim (`units`). '
                    'Automatically projecting first residual to match its dim with intermediate dim.'
                ),
                category=UserWarning,
                stacklevel=1
            )
            self._residual_dense = self.get_dense(self.units)
            self._residual_dense.build([None, node_feature_dim])

        # Query, key and value projections share the same layout: one
        # `head_units`-sized vector per head.
        head_shape = (self.head_units, self.heads)

        self._query_dense = self.get_einsum_dense('ij,jkh->ikh', head_shape)
        self._query_dense.build([None, node_feature_dim])

        self._key_dense = self.get_einsum_dense('ij,jkh->ikh', head_shape)
        self._key_dense.build([None, node_feature_dim])

        self._value_dense = self.get_einsum_dense('ij,jkh->ikh', head_shape)
        self._value_dense.build([None, node_feature_dim])

        self._output_dense = self.get_dense(self.units)
        self._output_dense.build([None, self.units])

        self._softmax_dropout = keras.layers.Dropout(self._attention_dropout)

        # With pre-normalization the norm sees the raw input features;
        # with post-normalization it sees the `units`-sized residual sum.
        self._self_attention_norm = keras.layers.LayerNormalization()
        norm_dim = node_feature_dim if self._normalize_first else self.units
        self._self_attention_norm.build([None, norm_dim])

        self._self_attention_dropout = keras.layers.Dropout(self._dropout)

        # The default edge bias sublayers are only needed when a subclass
        # has not replaced `add_edge_bias`.
        has_overriden_edge_bias = (
            self.__class__.add_edge_bias != GTConv.add_edge_bias
        )
        if not has_overriden_edge_bias:
            needs_edge_bias = 'bias' not in spec.edge

            self._has_edge_length = 'length' in spec.edge
            if self._has_edge_length and needs_edge_bias:
                edge_length_dim = spec.edge['length'].shape[-1]
                self._spatial_encoding_dense = self.get_einsum_dense(
                    'ij,jkh->ikh', (1, self.heads), kernel_initializer='zeros'
                )
                self._spatial_encoding_dense.build([None, edge_length_dim])

            self._has_edge_feature = 'feature' in spec.edge
            if self._has_edge_feature and needs_edge_bias:
                edge_feature_dim = spec.edge['feature'].shape[-1]
                self._edge_feature_dense = self.get_einsum_dense(
                    'ij,jkh->ikh', (1, self.heads),
                )
                self._edge_feature_dense.build([None, edge_feature_dim])

        # Likewise, the feed-forward sublayers are only needed when a
        # subclass has not replaced `update`.
        has_overridden_update = self.__class__.update != GTConv.update
        if not has_overridden_update:

            self._feedforward_norm = keras.layers.LayerNormalization()
            self._feedforward_norm.build([None, self.units])

            self._feedforward_dropout = keras.layers.Dropout(self._dropout)

            self._feedforward_intermediate_dense = self.get_dense(self.units)
            self._feedforward_intermediate_dense.build([None, self.units])

            self._feedforward_output_dense = self.get_dense(self.units)
            self._feedforward_output_dense.build([None, self.units])

    def add_node_bias(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
        """Hook for adding a node bias; returns `tensor` unchanged.

        `message` adds `tensor.node['bias']` to the node features when that
        field is present, so an override can supply it here.
        """
        return tensor

    def add_edge_bias(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
        """Adds an attention bias to the edges, unless one already exists.

        The bias is the sum of the projected edge features and the
        projected edge lengths, using whichever of the two is available.
        If neither is available, `tensor` is returned unchanged.
        """
        if 'bias' in tensor.edge:
            return tensor
        if not (self._has_edge_feature or self._has_edge_length):
            return tensor

        edge_bias = None
        if self._has_edge_feature:
            edge_bias = self._edge_feature_dense(tensor.edge['feature'])
        if self._has_edge_length:
            spatial_bias = self._spatial_encoding_dense(tensor.edge['length'])
            if edge_bias is None:
                edge_bias = spatial_bias
            else:
                edge_bias = edge_bias + spatial_bias

        return tensor.update({'edge': {'bias': edge_bias}})

    def message(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
        """Compute messages.

        Stores the gathered values as `edge['message']` and the attention
        weights as `edge['weight']`.
        """
        tensor = self.add_edge_bias(tensor)
        tensor = self.add_node_bias(tensor)

        node_feature = tensor.node['feature']

        if 'bias' in tensor.node:
            node_feature = node_feature + tensor.node['bias']

        if self._normalize_first:
            node_feature = self._self_attention_norm(node_feature)

        query = self._query_dense(node_feature)
        key = self._key_dense(node_feature)
        value = self._value_dense(node_feature)

        query = ops.gather(query, tensor.edge['source'])
        key = ops.gather(key, tensor.edge['target'])
        value = ops.gather(value, tensor.edge['source'])

        # Per-head dot product: the sum runs over the `head_units` axis.
        attention_score = keras.ops.sum(query * key, axis=1, keepdims=True)
        # NOTE(review): scaled by sqrt(units), not sqrt(head_units) as in
        # standard scaled dot-product attention; kept as is to preserve the
        # numerical behavior of existing models.
        attention_score = attention_score / keras.ops.sqrt(float(self.units))

        if 'bias' in tensor.edge:
            attention_score = attention_score + tensor.edge['bias']

        # Scores are normalized over the edges sharing a target node.
        attention = ops.edge_softmax(attention_score, tensor.edge['target'])
        attention = self._softmax_dropout(attention)

        return tensor.update(
            {
                'edge': {
                    'message': value,
                    'weight': attention,
                },
            }
        )

    def aggregate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
        """Aggregates messages.

        Merges the heads of the aggregated messages, projects them, and
        adds the (possibly projected) input node features as a residual.
        """
        # Presumably applies `edge['weight']` while aggregating; the
        # details live in `GraphTensor.aggregate`.
        attended = tensor.aggregate('message')

        # Merge the per-head outputs into a single `units`-sized vector.
        attended = keras.ops.reshape(attended, (-1, self.units))
        attended = self._output_dense(attended)
        attended = self._self_attention_dropout(attended)

        residual = tensor.node['feature']
        if self._project_residual:
            residual = self._residual_dense(residual)
        attended = attended + residual

        if not self._normalize_first:
            attended = self._self_attention_norm(attended)

        return tensor.update(
            {
                'node': {
                    'feature': attended,
                },
                # Presumably drops the temporary edge fields again.
                'edge': {
                    'message': None,
                    'weight': None,
                }
            }
        )

    def update(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
        """Updates nodes.

        Applies the feed-forward block to the node features and adds the
        incoming node features as a residual.
        """
        residual = tensor.node['feature']
        hidden = residual

        if self._normalize_first:
            hidden = self._feedforward_norm(hidden)

        hidden = self._feedforward_intermediate_dense(hidden)
        hidden = self._activation(hidden)
        hidden = self._feedforward_output_dense(hidden)

        hidden = self._feedforward_dropout(hidden)
        hidden = hidden + residual

        if not self._normalize_first:
            hidden = self._feedforward_norm(hidden)

        return tensor.update(
            {
                'node': {
                    'feature': hidden,
                },
            }
        )

    def get_config(self) -> dict:
        """Returns the configuration of the layer, for serialization."""
        config = super().get_config()
        config.update({
            "heads": self._heads,
            'activation': keras.activations.serialize(self._activation),
            'dropout': self._dropout,
            'attention_dropout': self._attention_dropout,
            'normalize': self._normalize,
            'normalize_first': self._normalize_first,
        })
        return config
|
|
1033
|
+
|
|
1034
|
+
|
|
1035
|
+
@keras.saving.register_keras_serializable(package='molcraft')
class Readout(keras.layers.Layer):

    """Readout layer.

    Reduces the node features of a graph tensor to one row per subgraph,
    using the segment reduction selected by `mode`.

    Args:
        mode:
            Selects the reduction, matched case-insensitively by prefix:
            'sum' (segment sum), 'max' (segment max) or 'super' (segment
            sum over the nodes selected by `node['super']`). Any other
            non-empty value selects the segment mean. If empty or None,
            no reduction is set and `reduce` needs to be overridden.
        **kwargs:
            Forwarded to `keras.layers.Layer`.
    """

    def __init__(self, mode: str | None = None, **kwargs):
        super().__init__(**kwargs)
        self.mode = mode
        if not self.mode:
            self._reduce_fn = None
        else:
            mode_name = str(self.mode).lower()
            if mode_name.startswith(('sum', 'super')):
                # 'super' also sums, after masking (see `reduce`).
                self._reduce_fn = keras.ops.segment_sum
            elif mode_name.startswith('max'):
                self._reduce_fn = keras.ops.segment_max
            else:
                self._reduce_fn = ops.segment_mean

    def build_from_spec(self, spec: tensors.GraphTensor.Spec) -> None:
        """Builds the layer.

        Nothing to build by default.
        """
        pass

    def reduce(self, tensor: tensors.GraphTensor) -> tf.Tensor:
        """Reduces the node features of `tensor` per subgraph.

        Raises:
            NotImplementedError: If no reduction was selected via `mode`.
        """
        if self._reduce_fn is None:
            raise NotImplementedError("Need to define a reduce method.")

        node_feature = tensor.node['feature']
        if str(self.mode).lower().startswith('super'):
            # Zero out every node not selected by `node['super']`
            # (presumably a boolean mask marking super nodes), so that the
            # sum only picks up the selected ones.
            node_feature = keras.ops.where(
                tensor.node['super'][:, None], node_feature, 0.0
            )
        return self._reduce_fn(
            node_feature, tensor.graph_indicator, tensor.num_subgraphs
        )

    def build(self, input_shapes) -> None:
        """Builds the layer from a (nested) dictionary of input shapes."""
        spec = tensors.GraphTensor.Spec.from_input_shape_dict(input_shapes)
        self.build_from_spec(spec)
        self.built = True

    def call(self, graph) -> tf.Tensor:
        """Converts `graph` to a graph tensor and reduces it.

        A ragged graph tensor is flattened before the reduction.
        """
        graph_tensor = tensors.from_dict(graph)
        if tensors.is_ragged(graph_tensor):
            graph_tensor = graph_tensor.flatten()
        return self.reduce(graph_tensor)

    def __call__(
        self,
        graph: tensors.GraphTensor | dict,
        *args,
        **kwargs
    ) -> tf.Tensor:
        """Invokes the layer.

        A `GraphTensor` is unpacked to a dictionary before it is handed to
        Keras; any other input is passed through as is.
        """
        if isinstance(graph, tensors.GraphTensor):
            graph = tensors.to_dict(graph)
        return super().__call__(graph, *args, **kwargs)

    def get_config(self) -> dict:
        """Returns the configuration of the layer, for serialization."""
        config = super().get_config()
        config['mode'] = self.mode
        return config
|
|
1098
|
+
|
|
1099
|
+
|
|
1100
|
+
def Input(spec: tensors.GraphTensor.Spec) -> dict:
    """Used to specify inputs to model.

    Unpacks `spec` into a nested dictionary of `keras.Input`s, keyed by
    outer field and inner field. The fields 'label' and 'weight' of
    'context' are left out.

    Args:
        spec:
            The spec of the `GraphTensor` that will be passed to the model.

    Returns:
        A nested dictionary of symbolic Keras inputs.

    Raises:
        ValueError: If `keras.Input` rejects the arguments with a
            `TypeError`, which is expected for ragged specs.

    Example:

    >>> import molcraft
    >>> import keras
    >>>
    >>> featurizer = molcraft.featurizers.MolGraphFeaturizer()
    >>> graph = featurizer([('N[C@@H](C)C(=O)O', 1.0), ('N[C@@H](CS)C(=O)O', 2.0)])
    >>>
    >>> model = molcraft.models.GraphModel.from_layers([
    ...     molcraft.layers.Input(graph.spec),
    ...     molcraft.layers.NodeEmbedding(128),
    ...     molcraft.layers.EdgeEmbedding(128),
    ...     molcraft.layers.GraphTransformer(128),
    ...     molcraft.layers.GraphTransformer(128),
    ...     molcraft.layers.Readout('mean'),
    ...     molcraft.layers.Dense(1)
    ... ])
    """

    # Currently, Keras (3.8.0) does not support extension types.
    # So for now, this function will unpack the `GraphTensor.Spec` and
    # return a dictionary of nested tensor specs. However, the corresponding
    # nest of tensors will temporarily be converted to a `GraphTensor` by the
    # `GraphLayer`, to leverage the utility of a `GraphTensor` object.
    inputs = {}
    for outer_field, data in spec.__dict__.items():
        inputs[outer_field] = {}
        for inner_field, nested_spec in data.items():
            if outer_field == 'context' and inner_field in ('label', 'weight'):
                continue
            kwargs = {
                # `keras.Input` expects the shape without the leading
                # dimension.
                'shape': nested_spec.shape[1:],
                'dtype': nested_spec.dtype,
                'name': f'{outer_field}_{inner_field}'
            }
            if isinstance(nested_spec, tf.RaggedTensorSpec):
                kwargs['ragged'] = True
            try:
                inputs[outer_field][inner_field] = keras.Input(**kwargs)
            except TypeError as error:
                # Keep the original error attached for easier debugging.
                raise ValueError(
                    "`keras.Input` does not currently support ragged tensors. For now, "
                    "pass the `Spec` of a 'flat' `GraphTensor` to `Input`."
                ) from error
    return inputs
|
|
1149
|
+
|
|
1150
|
+
|
|
1151
|
+
def warn(message: str) -> None:
    """Issues a `UserWarning` attributed to the caller of this function.

    Args:
        message: The warning text.
    """
    warnings.warn(
        message=message,
        category=UserWarning,
        # Level 2 points at the code calling `warn`; level 1 would report
        # this helper as the origin of every warning.
        stacklevel=2
    )
|
|
1157
|
+
|
|
1158
|
+
def _match_functional_input(functional_input, inputs):
|
|
1159
|
+
matching_inputs = {}
|
|
1160
|
+
for outer_field, data in functional_input.items():
|
|
1161
|
+
matching_inputs[outer_field] = {}
|
|
1162
|
+
for inner_field, _ in data.items():
|
|
1163
|
+
call_input = inputs[outer_field].pop(inner_field)
|
|
1164
|
+
matching_inputs[outer_field][inner_field] = call_input
|
|
1165
|
+
unmatching_inputs = inputs
|
|
1166
|
+
return matching_inputs, unmatching_inputs
|
|
1167
|
+
|
|
1168
|
+
def _add_left_out_inputs(outputs, inputs):
|
|
1169
|
+
for outer_field, data in inputs.items():
|
|
1170
|
+
for inner_field, value in data.items():
|
|
1171
|
+
if inner_field in ['label', 'weight']:
|
|
1172
|
+
outputs[outer_field][inner_field] = value
|
|
1173
|
+
return outputs
|
|
1174
|
+
|
|
1175
|
+
def _serialize_spec(spec: tensors.GraphTensor.Spec) -> dict:
|
|
1176
|
+
serialized_spec = {}
|
|
1177
|
+
for outer_field, data in spec.__dict__.items():
|
|
1178
|
+
serialized_spec[outer_field] = {}
|
|
1179
|
+
for inner_field, inner_spec in data.items():
|
|
1180
|
+
serialized_spec[outer_field][inner_field] = {
|
|
1181
|
+
'shape': inner_spec.shape.as_list(),
|
|
1182
|
+
'dtype': inner_spec.dtype.name,
|
|
1183
|
+
'name': inner_spec.name,
|
|
1184
|
+
}
|
|
1185
|
+
return serialized_spec
|
|
1186
|
+
|
|
1187
|
+
def _deserialize_spec(serialized_spec: dict) -> tensors.GraphTensor.Spec:
    """Rebuilds a `GraphTensor.Spec` from its serialized form.

    Every inner entry, a dictionary with the keys 'shape', 'dtype' and
    'name', is turned into a `tf.TensorSpec`.
    """
    nested_specs = {
        outer_field: {
            inner_field: tf.TensorSpec(
                entry['shape'], entry['dtype'], entry['name']
            )
            for inner_field, entry in data.items()
        }
        for outer_field, data in serialized_spec.items()
    }
    return tensors.GraphTensor.Spec(**nested_specs)
|
|
1196
|
+
|
|
1197
|
+
def _spec_from_inputs(inputs):
    """Derives a `GraphTensor.Spec` from (a nest of) inputs.

    Whether the inputs are symbolic is decided from the first flattened
    element: Keras tensors are described by a `tf.TensorSpec` of their
    shape and dtype, anything else by `tf.type_spec_from_value`.
    """
    first_element = tf.nest.flatten(inputs)[0]

    if keras.backend.is_keras_tensor(first_element):
        def to_spec(value):
            return tf.TensorSpec(value.shape, value.dtype)
    else:
        to_spec = tf.type_spec_from_value

    nested_specs = tf.nest.map_structure(to_spec, inputs)

    # The mapping may already yield a complete spec; otherwise the nested
    # dictionary of specs is packed into one.
    if isinstance(nested_specs, tensors.GraphTensor.Spec):
        return nested_specs
    return tensors.GraphTensor.Spec(**nested_specs)
|
|
1213
|
+
|
|
1214
|
+
|
|
1215
|
+
# Alternative names for the convolution layers.
GraphTransformer = GTConv
GTConvolution = GTConv
GINConvolution = GINConv

# Short names for the embedding layers.
EdgeEmbed = EdgeEmbedding
NodeEmbed = NodeEmbedding

# Alternative names for the projection layers.
ContextDense = ContextProjection
EdgeDense = EdgeProjection
NodeDense = NodeProjection
|
|
1224
|
+
|