molcraft 0.1.0a2__py3-none-any.whl → 0.1.0a4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- molcraft/__init__.py +2 -1
- molcraft/callbacks.py +12 -0
- molcraft/descriptors.py +24 -23
- molcraft/experimental/peptides.py +96 -79
- molcraft/features.py +5 -3
- molcraft/featurizers.py +61 -38
- molcraft/layers.py +1004 -425
- molcraft/models.py +47 -3
- molcraft/ops.py +14 -3
- molcraft/tensors.py +3 -3
- {molcraft-0.1.0a2.dist-info → molcraft-0.1.0a4.dist-info}/METADATA +6 -6
- molcraft-0.1.0a4.dist-info/RECORD +20 -0
- {molcraft-0.1.0a2.dist-info → molcraft-0.1.0a4.dist-info}/WHEEL +1 -1
- molcraft-0.1.0a2.dist-info/RECORD +0 -20
- {molcraft-0.1.0a2.dist-info → molcraft-0.1.0a4.dist-info}/licenses/LICENSE +0 -0
- {molcraft-0.1.0a2.dist-info → molcraft-0.1.0a4.dist-info}/top_level.txt +0 -0
molcraft/layers.py
CHANGED
@@ -1,7 +1,7 @@
-import abc
 import keras
 import tensorflow as tf
 import warnings
+import functools
 from keras.src.models import functional
 
 from molcraft import tensors
@@ -12,11 +12,39 @@ from molcraft import ops
 class GraphLayer(keras.layers.Layer):
     """Base graph layer.
 
-
+    Subclasses must implement a forward pass via **propagate(graph)**.
+
+    Subclasses may create dense layers and weights in **build(graph_spec)**.
+
+    Note: `GraphLayer` currently only supports `GraphTensor` input.
 
-    The list of arguments are only relevant if the derived layer
+    The list of arguments below are only relevant if the derived layer
     invokes 'get_dense_kwargs`, `get_dense` or `get_einsum_dense`.
 
+    Arguments:
+        use_bias (bool):
+            Whether bias should be used in dense layers. Default to `True`.
+        kernel_initializer (keras.initializers.Initializer, str):
+            Initializer for the kernel weight matrix of the dense layers.
+            Default to `glorot_uniform`.
+        bias_initializer (keras.initializers.Initializer, str):
+            Initializer for the bias weight vector of the dense layers.
+            Default to `zeros`.
+        kernel_regularizer (keras.regularizers.Regularizer, None):
+            Regularizer function applied to the kernel weight matrix.
+            Default to `None`.
+        bias_regularizer (keras.regularizers.Regularizer, None):
+            Regularizer function applied to the bias weight vector.
+            Default to `None`.
+        activity_regularizer (keras.regularizers.Regularizer, None):
+            Regularizer function applied to the output of the dense layers.
+            Default to `None`.
+        kernel_constraint (keras.constraints.Constraint, None):
+            Constraint function applied to the kernel weight matrix.
+            Default to `None`.
+        bias_constraint (keras.constraints.Constraint, None):
+            Constraint function applied to the bias weight vector.
+            Default to `None`.
     """
 
     def __init__(
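As a sketch of the subclass contract documented above (implement `propagate`, optionally create weights in `build`), assuming only the `get_dense` and `tensor.update` calls that appear elsewhere in this file; the class name and output width are illustrative, not part of the package:

    import keras
    from molcraft import layers, tensors

    @keras.saving.register_keras_serializable(package='example')
    class NodeProjection(layers.GraphLayer):
        """Hypothetical layer: projects node features to a fixed width."""

        def build(self, spec: tensors.GraphTensor.Spec) -> None:
            # get_dense picks up the initializer/regularizer/constraint
            # arguments documented in the GraphLayer docstring.
            self._dense = self.get_dense(32)

        def propagate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
            feature = self._dense(tensor.node['feature'])
            return tensor.update({'node': {'feature': feature}})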
@@ -31,73 +59,61 @@ class GraphLayer(keras.layers.Layer):
         bias_constraint: keras.constraints.Constraint | None = None,
         **kwargs,
     ) -> None:
-        super().__init__(
+        super().__init__(**kwargs)
         self._use_bias = use_bias
         self._kernel_initializer = keras.initializers.get(kernel_initializer)
         self._bias_initializer = keras.initializers.get(bias_initializer)
         self._kernel_regularizer = keras.regularizers.get(kernel_regularizer)
         self._bias_regularizer = keras.regularizers.get(bias_regularizer)
+        self._activity_regularizer = keras.regularizers.get(activity_regularizer)
         self._kernel_constraint = keras.constraints.get(kernel_constraint)
         self._bias_constraint = keras.constraints.get(bias_constraint)
+        self._custom_build_config = {}
         self.built = False
-
-
+
+    def __init_subclass__(cls, **kwargs):
+        super().__init_subclass__(**kwargs)
+        subclass_build = cls.build
+
+        @functools.wraps(subclass_build)
+        def build_wrapper(self: GraphLayer, spec: tensors.GraphTensor.Spec | None):
+            GraphLayer.build(self, spec)
+            subclass_build(self, spec)
+            if not self.built and isinstance(self, keras.Model):
+                symbolic_inputs = Input(spec)
+                self.built = True
+                self(symbolic_inputs)
+
+        cls.build = build_wrapper
 
     def propagate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
-        """
+        """Forward pass.
 
-
+        Must be implemented by subclass.
 
-
+        Arguments:
             tensor:
                 A `GraphTensor` instance.
         """
-        raise NotImplementedError(
+        raise NotImplementedError(
+            'The forward pass of the layer is not implemented. '
+            'Please implement `propagate`.'
+        )
 
-    def
+    def build(self, tensor_spec: tensors.GraphTensor.Spec) -> None:
         """Builds the layer.
 
         May use built-in methods such as `get_weight`, `get_dense` and `get_einsum_dense`.
 
-        Optionally implemented by subclass.
-        build the sub-layers via `build([None, input_dim])`. If sub-layers are not
-        built, symbolic input will be passed through the layer to build it.
+        Optionally implemented by subclass.
 
-
-
-                A `GraphTensor.Spec` instance
-
+        Arguments:
+            tensor_spec:
+                A `GraphTensor.Spec` instance corresponding to the `GraphTensor`
+                passed to `propagate`.
         """
-
-
-
-        self._custom_build_config = {'spec': _serialize_spec(spec)}
-
-        invoke_build_from_spec = (
-            GraphLayer.build_from_spec != self.__class__.build_from_spec
-        )
-        if invoke_build_from_spec:
-            self.build_from_spec(spec)
-            self.built = True
-
-        if not self.built:
-            # Automatically build layer or model by calling it on symbolic inputs
-            self.built = True
-            symbolic_inputs = Input(spec)
-            self(symbolic_inputs)
-
-    def get_build_config(self) -> dict:
-        if not hasattr(self, '_custom_build_config'):
-            return super().get_build_config()
-        return self._custom_build_config
-
-    def build_from_config(self, config: dict) -> None:
-        use_custom_build_from_config = ('spec' in config)
-        if not use_custom_build_from_config:
-            super().build_from_config(config)
-        else:
-            spec = _deserialize_spec(config['spec'])
-            self.build(spec)
+        if isinstance(tensor_spec, tensors.GraphTensor.Spec):
+            self._custom_build_config['spec'] = _serialize_spec(tensor_spec)
 
     def call(
         self,
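The `__init_subclass__` hook added above rebinds every subclass's `build` so that the base-class bookkeeping (recording the spec for serialization) always runs before the subclass body. A self-contained illustration of this idiom, with made-up class names, not the package's code:

    import functools

    class Base:
        def build(self, spec):
            self._recorded_spec = spec  # base bookkeeping the wrapper guarantees

        def __init_subclass__(cls, **kwargs):
            super().__init_subclass__(**kwargs)
            subclass_build = cls.build

            @functools.wraps(subclass_build)
            def build_wrapper(self, spec):
                Base.build(self, spec)       # base runs first, unconditionally
                subclass_build(self, spec)   # then the subclass's own build
            cls.build = build_wrapper

    class Child(Base):
        def build(self, spec):
            self.dim = spec

    c = Child()
    c.build(8)
    assert c._recorded_spec == 8 and c.dim == 8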
@@ -127,6 +143,19 @@ class GraphLayer(keras.layers.Layer):
         outputs = tensors.from_dict(outputs)
         return outputs
 
+    def get_build_config(self) -> dict:
+        if self._custom_build_config:
+            return self._custom_build_config
+        return super().get_build_config()
+
+    def build_from_config(self, config: dict) -> None:
+        serialized_spec = config.get('spec')
+        if serialized_spec is not None:
+            spec = _deserialize_spec(serialized_spec)
+            self.build(spec)
+        else:
+            super().build_from_config(config)
+
     def get_weight(
         self,
         shape: tf.TensorShape,
@@ -168,7 +197,7 @@ class GraphLayer(keras.layers.Layer):
             use_bias=self._use_bias,
             kernel_regularizer=self._kernel_regularizer,
             bias_regularizer=self._bias_regularizer,
-            activity_regularizer=self.
+            activity_regularizer=self._activity_regularizer,
             kernel_constraint=self._kernel_constraint,
             bias_constraint=self._bias_constraint,
         )
@@ -194,52 +223,137 @@ class GraphLayer(keras.layers.Layer):
                 keras.regularizers.serialize(self._kernel_regularizer),
             "bias_regularizer":
                 keras.regularizers.serialize(self._bias_regularizer),
+            "activity_regularizer":
+                keras.regularizers.serialize(self._activity_regularizer),
             "kernel_constraint":
                 keras.constraints.serialize(self._kernel_constraint),
             "bias_constraint":
                 keras.constraints.serialize(self._bias_constraint),
         })
         return config
-
+
 
 @keras.saving.register_keras_serializable(package='molcraft')
 class GraphConv(GraphLayer):
 
     """Base graph neural network layer.
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    This layer implements the three basic steps of a graph neural network layer, each of which
+    can (optionally) be overridden by the `GraphConv` subclass:
+
+    1. **message(graph)**, which computes the *messages* to be passed to target nodes;
+    2. **aggregate(graph)**, which *aggregates* messages to target nodes;
+    3. **update(graph)**, which further *updates* (target) nodes.
+
+    Note: for skip connection to work, the `GraphConv` subclass requires final node feature
+    output dimension to be equal to `units`.
+
+    Arguments:
+        units (int):
+            Dimensionality of the output space.
+        activation (keras.layers.Activation, str, None):
+            Activation function to use. If not specified, a linear activation (a(x) = x) is used.
+            Default to `None`.
+        use_bias (bool):
+            Whether bias should be used in dense layers. Default to `True`.
+        normalization (bool, str):
+            Whether `LayerNormalization` should be applied to the final node feature output.
+            To use `BatchNormalization`, specify `batch_norm`. Default to `False`.
+        skip_connection (bool, str):
+            Whether node feature input should be added to the node feature output.
+            If node feature input dim is not equal to `units` (node feature output dim),
+            a projection layer will automatically project the residual before adding it
+            to the output. To use weighted skip connection,
+            specify `weighted`. The weight multiplied with the skip connection is a
+            learnable scalar. Default to `True`.
+        kernel_initializer (keras.initializers.Initializer, str):
+            Initializer for the kernel weight matrix of the dense layers.
+            Default to `glorot_uniform`.
+        bias_initializer (keras.initializers.Initializer, str):
+            Initializer for the bias weight vector of the dense layers.
+            Default to `zeros`.
+        kernel_regularizer (keras.regularizers.Regularizer, None):
+            Regularizer function applied to the kernel weight matrix.
+            Default to `None`.
+        bias_regularizer (keras.regularizers.Regularizer, None):
+            Regularizer function applied to the bias weight vector.
+            Default to `None`.
+        activity_regularizer (keras.regularizers.Regularizer, None):
+            Regularizer function applied to the output of the dense layers.
+            Default to `None`.
+        kernel_constraint (keras.constraints.Constraint, None):
+            Constraint function applied to the kernel weight matrix.
+            Default to `None`.
+        bias_constraint (keras.constraints.Constraint, None):
+            Constraint function applied to the bias weight vector.
+            Default to `None`.
     """
 
     def __init__(
         self,
-        units: int,
-
-
+        units: int = None,
+        activation: str | keras.layers.Activation | None = None,
+        use_bias: bool = True,
+        normalization: bool | str = False,
+        skip_connection: bool | str = True,
         **kwargs
     ) -> None:
-        super().__init__(**kwargs)
-        self.
-        self.
+        super().__init__(use_bias=use_bias, **kwargs)
+        self._units = units
+        self._normalization = normalization
         self._skip_connection = skip_connection
+        self._activation = keras.activations.get(activation)
+
+    def __init_subclass__(cls, **kwargs):
+        super().__init_subclass__(**kwargs)
+        subclass_build = cls.build
+
+        @functools.wraps(subclass_build)
+        def build_wrapper(self, spec):
+            GraphConv.build(self, spec)
+            return subclass_build(self, spec)
+
+        cls.build = build_wrapper
+
+    @property
+    def units(self):
+        return self._units
+
+    def propagate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
+        """Forward pass.
+
+        Invokes `message(graph)`, `aggregate(graph)` and `update(graph)` in sequence.
+
+        Arguments:
+            tensor:
+                A `GraphTensor` instance.
+        """
+        if self._skip_connection:
+            input_node_feature = tensor.node['feature']
+            if self._project_input_node_feature:
+                input_node_feature = self._residual_projection(input_node_feature)
+
+        tensor = self.message(tensor)
+        tensor = self.aggregate(tensor)
+        tensor = self.update(tensor)
+
+        updated_node_feature = tensor.node['feature']
+
+        if self._skip_connection:
+            if self._use_weighted_skip_connection:
+                input_node_feature *= self._skip_connection_weight
+            updated_node_feature += input_node_feature
+
+        if self._normalization:
+            updated_node_feature = self._output_norm(updated_node_feature)
+
+        return tensor.update({'node': {'feature': updated_node_feature}})
 
     def build(self, spec: tensors.GraphTensor.Spec) -> None:
+        if not self.units:
+            raise ValueError(
+                f'`self.units` needs to be a positive integer. Found: {self.units}.'
+            )
         node_feature_dim = spec.node['feature'].shape[-1]
         self._project_input_node_feature = (
             self._skip_connection and (node_feature_dim != self.units)
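A sketch of the three-step contract documented above, assuming the `GraphConv` API in this file: override only `message`, and let the default `aggregate`/`update` machinery (created in `GraphConv.build` when those methods are not overridden) handle the rest. The class name is hypothetical:

    from molcraft import layers, tensors

    class SourceOnlyConv(layers.GraphConv):
        """Hypothetical conv: customizes step 1 (message) only."""

        def build(self, spec: tensors.GraphTensor.Spec) -> None:
            # GraphConv.build runs first (via the build wrapper above) and,
            # since update is not overridden, creates the default output dense.
            self._dense = self.get_dense(self.units)

        def message(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
            message = self._dense(tensor.gather('feature', 'source'))
            return tensor.update({'edge': {'message': message}})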
@@ -254,81 +368,115 @@ class GraphConv(GraphLayer):
             self._residual_projection = self.get_dense(
                 self.units, name='residual_projection'
             )
-
-
-
+
+        skip_connection = str(self._skip_connection).lower()
+        self._use_weighted_skip_connection = skip_connection.startswith('weight')
+        if self._use_weighted_skip_connection:
+            self._skip_connection_weight = self.add_weight(
+                name='skip_connection_weight',
+                shape=(),
+                initializer='ones',
+                trainable=True,
             )
-        self._aggregation_norm.build([None, self.units])
 
-
+        if self._normalization:
+            if str(self._normalization).lower().startswith('batch'):
+                self._output_norm = keras.layers.BatchNormalization(
+                    name='output_batch_norm'
+                )
+            else:
+                self._output_norm = keras.layers.LayerNormalization(
+                    name='output_layer_norm'
+                )
+
+        self._has_edge_feature = 'edge' in spec.edge
+
+        has_overridden_message = self.__class__.message != GraphConv.message
+        if not has_overridden_message:
+            self._message_dense = self.get_dense(self.units)
+
+        has_overridden_update = self.__class__.update != GraphConv.update
+        if not has_overridden_update:
+            self._output_dense = self.get_dense(self.units)
+            self._output_activation = self._activation
 
-    @abc.abstractmethod
     def message(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
         """Compute messages.
 
-        This method
+        This method may be overridden by subclass.
 
-
+        Arguments:
             tensor:
                 The inputted `GraphTensor` instance.
         """
-
-
+        if not self._has_edge_feature:
+            message_feature = tensor.gather('feature', 'source')
+        else:
+            message_feature = keras.ops.concatenate(
+                [
+                    tensor.gather('feature', 'source'),
+                    tensor.edge['feature']
+                ],
+                axis=-1
+            )
+        message = self._message_dense(message_feature)
+        return tensor.update(
+            {
+                'edge': {
+                    'message': message
+                }
+            }
+        )
+
     def aggregate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
         """Aggregates messages.
 
-        This method
+        This method may be overridden by subclass.
 
-
+        Arguments:
             tensor:
                 A `GraphTensor` instance containing a message.
         """
+        aggregate = tensor.aggregate('message', mode='mean')
+        return tensor.update(
+            {
+                'node': {
+                    'feature': aggregate,
+                    'previous_feature': tensor.node['feature']
+                },
+                'edge': {
+                    'message': None
+                }
+            }
+        )
 
-    @abc.abstractmethod
     def update(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
         """Updates nodes.
 
-        This method
+        This method may be overridden by subclass.
 
-
+        Arguments:
             tensor:
                 A `GraphTensor` instance containing aggregated messages
                 (updated node features).
         """
-
-
-
-
-
-
-
-
-
-
-
-
-        if self._skip_connection:
-            input_node_feature = tensor.node['feature']
-            if self._project_input_node_feature:
-                input_node_feature = self._residual_projection(input_node_feature)
-
-        tensor = self.message(tensor)
-        tensor = self.aggregate(tensor)
-
-        if self._normalize_aggregate:
-            normalized_node_feature = self._aggregation_norm(tensor.node['feature'])
-            tensor = tensor.update({'node': {'feature': normalized_node_feature}})
-
-        tensor = self.update(tensor)
-
-        if not self._skip_connection:
-            return tensor
-
-        updated_node_feature = tensor.node['feature']
+        if not 'previous_feature' in tensor.node:
+            feature = tensor.node['feature']
+        else:
+            feature = keras.ops.concatenate(
+                [
+                    tensor.node['feature'],
+                    tensor.node['previous_feature']
+                ],
+                axis=-1
+            )
+        update = self._output_dense(feature)
+        update = self._output_activation(update)
         return tensor.update(
             {
                 'node': {
-                    'feature':
+                    'feature': update,
+                    'previous_feature': None,
                 }
             }
         )
@@ -337,16 +485,44 @@ class GraphConv(GraphLayer):
         config = super().get_config()
         config.update({
             'units': self.units,
-            '
+            'activation': keras.activations.serialize(self._activation),
+            'normalization': self._normalization,
             'skip_connection': self._skip_connection,
         })
         return config
 
 
 @keras.saving.register_keras_serializable(package='molcraft')
-class
+class GIConv(GraphConv):
 
     """Graph isomorphism network layer.
+
+    >>> graph = molcraft.tensors.GraphTensor(
+    ...     context={
+    ...         'size': [2]
+    ...     },
+    ...     node={
+    ...         'feature': [[1.], [2.]]
+    ...     },
+    ...     edge={
+    ...         'source': [0, 1],
+    ...         'target': [1, 0],
+    ...     }
+    ... )
+    >>> conv = molcraft.layers.GIConv(units=4)
+    >>> conv(graph)
+    GraphTensor(
+        context={
+            'size': <tf.Tensor: shape=[1], dtype=int32>
+        },
+        node={
+            'feature': <tf.Tensor: shape=[2, 4], dtype=float32>
+        },
+        edge={
+            'source': <tf.Tensor: shape=[2], dtype=int32>,
+            'target': <tf.Tensor: shape=[2], dtype=int32>
+        }
+    )
     """
 
     def __init__(
@@ -354,24 +530,20 @@ class GINConv(GraphConv):
         units: int,
         activation: keras.layers.Activation | str | None = 'relu',
         use_bias: bool = True,
-
-        dropout: float = 0.0,
+        normalization: bool = False,
         update_edge_feature: bool = True,
         **kwargs,
     ):
         super().__init__(
             units=units,
-
+            activation=activation,
+            normalization=normalization,
             use_bias=use_bias,
             **kwargs
         )
-        self._activation = keras.activations.get(activation)
-        self._dropout = dropout
         self._update_edge_feature = update_edge_feature
 
-    def
-        """Builds the layer.
-        """
+    def build(self, spec: tensors.GraphTensor.Spec) -> None:
         node_feature_dim = spec.node['feature'].shape[-1]
 
         self.epsilon = self.add_weight(
@@ -381,7 +553,8 @@ class GINConv(GraphConv):
             trainable=True,
         )
 
-
+        self._has_edge_feature = 'feature' in spec.edge
+        if self._has_edge_feature:
             edge_feature_dim = spec.edge['feature'].shape[-1]
 
             if not self._update_edge_feature:
@@ -395,31 +568,21 @@ class GINConv(GraphConv):
 
         if self._update_edge_feature:
             self._edge_dense = self.get_dense(node_feature_dim)
-            self._edge_dense.build([None, edge_feature_dim])
         else:
             self._update_edge_feature = False
-
-        self._feedforward_intermediate_dense = self.get_dense(self.units)
-        self._feedforward_intermediate_dense.build([None, node_feature_dim])
 
-        has_overridden_update = self.__class__.update !=
+        has_overridden_update = self.__class__.update != GIConv.update
         if not has_overridden_update:
-
-
-            self._feedforward_dropout = keras.layers.Dropout(self._dropout)
+            self._feedforward_intermediate_dense = self.get_dense(self.units)
             self._feedforward_activation = self._activation
-
             self._feedforward_output_dense = self.get_dense(self.units)
-
-
+
     def message(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
-        """Computes messages.
-        """
         message = tensor.gather('feature', 'source')
         edge_feature = tensor.edge.get('feature')
         if self._update_edge_feature:
             edge_feature = self._edge_dense(edge_feature)
-        if
+        if self._has_edge_feature:
             message += edge_feature
         return tensor.update(
             {
@@ -431,12 +594,8 @@ class GINConv(GraphConv):
         )
 
     def aggregate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
-
-        """
-        node_feature = tensor.aggregate('message')
+        node_feature = tensor.aggregate('message', mode='mean')
         node_feature += (1 + self.epsilon) * tensor.node['feature']
-        node_feature = self._feedforward_intermediate_dense(node_feature)
-        node_feature = self._feedforward_activation(node_feature)
         return tensor.update(
             {
                 'node': {
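For context: with the learnable `epsilon` above, this is the GIN-style update h_v' = MLP((1 + ε) · h_v + aggregate({m_u : u → v})), except that aggregation here uses `mode='mean'` where the original GIN formulation sums over neighbors; the MLP half of the update is applied in `update` in the next hunk.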
@@ -449,10 +608,9 @@ class GINConv(GraphConv):
         )
 
     def update(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
-        """Updates nodes.
-        """
         node_feature = tensor.node['feature']
-        node_feature = self.
+        node_feature = self._feedforward_intermediate_dense(node_feature)
+        node_feature = self._feedforward_activation(node_feature)
         node_feature = self._feedforward_output_dense(node_feature)
         return tensor.update(
             {
@@ -465,17 +623,217 @@ class GINConv(GraphConv):
     def get_config(self) -> dict:
         config = super().get_config()
         config.update({
-            'activation': keras.activations.serialize(self._activation),
-            'dropout': self._dropout,
             'update_edge_feature': self._update_edge_feature
         })
         return config
 
 
+@keras.saving.register_keras_serializable(package='molgraphx')
+class GAConv(GraphConv):
+
+    """Graph attention network layer.
+
+    >>> graph = molcraft.tensors.GraphTensor(
+    ...     context={
+    ...         'size': [2]
+    ...     },
+    ...     node={
+    ...         'feature': [[1.], [2.]]
+    ...     },
+    ...     edge={
+    ...         'source': [0, 1],
+    ...         'target': [1, 0],
+    ...     }
+    ... )
+    >>> conv = molcraft.layers.GAConv(units=4, heads=2)
+    >>> conv(graph)
+    GraphTensor(
+        context={
+            'size': <tf.Tensor: shape=[1], dtype=int32>
+        },
+        node={
+            'feature': <tf.Tensor: shape=[2, 4], dtype=float32>
+        },
+        edge={
+            'source': <tf.Tensor: shape=[2], dtype=int32>,
+            'target': <tf.Tensor: shape=[2], dtype=int32>
+        }
+    )
+    """
+
+    def __init__(
+        self,
+        units: int,
+        heads: int = 8,
+        activation: keras.layers.Activation | str | None = "relu",
+        use_bias: bool = True,
+        normalization: bool = False,
+        update_edge_feature: bool = True,
+        attention_activation: keras.layers.Activation | str | None = "leaky_relu",
+        **kwargs,
+    ) -> None:
+        kwargs['skip_connection'] = False
+        super().__init__(
+            units=units,
+            activation=activation,
+            use_bias=use_bias,
+            normalization=normalization,
+            **kwargs
+        )
+        self._heads = heads
+        if self.units % self.heads != 0:
+            raise ValueError(f"units need to be divisible by heads.")
+        self._head_units = self.units // self.heads
+        self._update_edge_feature = update_edge_feature
+        self._attention_activation = keras.activations.get(attention_activation)
+
+    @property
+    def heads(self):
+        return self._heads
+
+    @property
+    def head_units(self):
+        return self._head_units
+
+    def build(self, spec: tensors.GraphTensor.Spec) -> None:
+        self._has_edge_feature = 'feature' in spec.edge
+        self._update_edge_feature = self._has_edge_feature and self._update_edge_feature
+        if self._update_edge_feature:
+            self._edge_dense = self.get_einsum_dense(
+                'ijh,jkh->ikh', (self.head_units, self.heads)
+            )
+        self._node_dense = self.get_einsum_dense(
+            'ij,jkh->ikh', (self.head_units, self.heads)
+        )
+        self._feature_dense = self.get_einsum_dense(
+            'ij,jkh->ikh', (self.head_units, self.heads)
+        )
+        self._attention_dense = self.get_einsum_dense(
+            'ijh,jkh->ikh', (1, self.heads)
+        )
+        self._node_self_dense = self.get_einsum_dense(
+            'ij,jkh->ikh', (self.head_units, self.heads)
+        )
+
+        has_overridden_update = self.__class__.update != GAConv.update
+        if not has_overridden_update:
+            self._feedforward_intermediate_dense = self.get_dense(self.units)
+            self._feedforward_activation = self._activation
+            self._feedforward_output_dense = self.get_dense(self.units)
+
+    def message(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
+        attention_feature = keras.ops.concatenate(
+            [
+                tensor.gather('feature', 'source'),
+                tensor.gather('feature', 'target')
+            ],
+            axis=-1
+        )
+        if self._has_edge_feature:
+            attention_feature = keras.ops.concatenate(
+                [
+                    attention_feature,
+                    tensor.edge['feature']
+                ],
+                axis=-1
+            )
+
+        attention_feature = self._feature_dense(attention_feature)
+
+        edge_feature = tensor.edge.get('feature')
+
+        if self._update_edge_feature:
+            edge_feature = self._edge_dense(attention_feature)
+            edge_feature = keras.ops.reshape(edge_feature, (-1, self.units))
+
+        attention_feature = self._attention_activation(attention_feature)
+        attention_score = self._attention_dense(attention_feature)
+        attention_score = ops.edge_softmax(
+            score=attention_score, edge_target=tensor.edge['target']
+        )
+        node_feature = self._node_dense(tensor.node['feature'])
+        message = ops.gather(node_feature, tensor.edge['source'])
+        return tensor.update(
+            {
+                'edge': {
+                    'message': message,
+                    'weight': attention_score,
+                    'feature': edge_feature,
+                }
+            }
+        )
+
+    def aggregate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
+        node_feature = tensor.aggregate('message', mode='sum')
+        node_feature += self._node_self_dense(tensor.node['feature'])
+        node_feature = keras.ops.reshape(node_feature, (-1, self.units))
+        return tensor.update(
+            {
+                'node': {
+                    'feature': node_feature
+                },
+                'edge': {
+                    'message': None,
+                    'weight': None,
+                }
+            }
+        )
+
+    def update(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
+        node_feature = tensor.node['feature']
+        node_feature = self._feedforward_intermediate_dense(node_feature)
+        node_feature = self._feedforward_activation(node_feature)
+        node_feature = self._feedforward_output_dense(node_feature)
+        return tensor.update(
+            {
+                'node': {
+                    'feature': node_feature
+                }
+            }
+        )
+
+    def get_config(self) -> dict:
+        config = super().get_config()
+        config.update({
+            "heads": self._heads,
+            'update_edge_feature': self._update_edge_feature,
+            'attention_activation': keras.activations.serialize(self._attention_activation),
+        })
+        return config
+
+
 @keras.saving.register_keras_serializable(package='molcraft')
 class GTConv(GraphConv):
 
     """Graph transformer layer.
+
+    >>> graph = molcraft.tensors.GraphTensor(
+    ...     context={
+    ...         'size': [2]
+    ...     },
+    ...     node={
+    ...         'feature': [[1.], [2.]]
+    ...     },
+    ...     edge={
+    ...         'source': [0, 1],
+    ...         'target': [1, 0],
+    ...     }
+    ... )
+    >>> conv = molcraft.layers.GTConv(units=4, heads=2)
+    >>> conv(graph)
+    GraphTensor(
+        context={
+            'size': <tf.Tensor: shape=[1], dtype=int32>
+        },
+        node={
+            'feature': <tf.Tensor: shape=[2, 4], dtype=float32>
+        },
+        edge={
+            'source': <tf.Tensor: shape=[2], dtype=int32>,
+            'target': <tf.Tensor: shape=[2], dtype=int32>
+        }
+    )
+
     """
 
     def __init__(
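A worked example of the multi-head bookkeeping enforced by the divisibility check in `GAConv` (and `GTConv` below); the values are illustrative:

    units, heads = 8, 2          # as in GAConv(units=8, heads=2)
    assert units % heads == 0    # otherwise the ValueError above is raised
    head_units = units // heads  # 4 features per head
    # Per-edge tensors carry a trailing heads axis, e.g. attention scores of
    # shape (num_edges, 1, heads) and messages of shape
    # (num_edges, head_units, heads); after aggregation the heads are
    # flattened back to `units` via keras.ops.reshape(node_feature, (-1, units)).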
@@ -484,26 +842,22 @@ class GTConv(GraphConv):
         heads: int = 8,
         activation: keras.layers.Activation | str | None = "relu",
         use_bias: bool = True,
-
-        dropout: float = 0.0,
+        normalization: bool = False,
         attention_dropout: float = 0.0,
         **kwargs,
     ) -> None:
-        kwargs['skip_connection'] = False
         super().__init__(
             units=units,
-
+            activation=activation,
             use_bias=use_bias,
+            normalization=normalization,
             **kwargs
         )
         self._heads = heads
         if self.units % self.heads != 0:
             raise ValueError(f"units need to be divisible by heads.")
         self._head_units = self.units // self.heads
-        self._activation = keras.activations.get(activation)
-        self._dropout = dropout
         self._attention_dropout = attention_dropout
-        self._normalize = normalize
 
     @property
     def heads(self):
@@ -513,68 +867,41 @@ class GTConv(GraphConv):
     def head_units(self):
         return self._head_units
 
-    def
-        """Builds the layer.
-        """
-        node_feature_dim = spec.node['feature'].shape[-1]
-        self.project_residual = node_feature_dim != self.units
-        if self.project_residual:
-            warn(
-                '`GTConv` uses residual connections, but found incompatible dim '
-                'between input (node feature dim) and output (`self.units`). '
-                'Automatically applying a projection layer to residual to '
-                'match input and output. '
-            )
-            self._residual_dense = self.get_dense(self.units)
-            self._residual_dense.build([None, node_feature_dim])
-
+    def build(self, spec: tensors.GraphTensor.Spec) -> None:
         self._query_dense = self.get_einsum_dense(
             'ij,jkh->ikh', (self.head_units, self.heads)
         )
-        self._query_dense.build([None, node_feature_dim])
-
         self._key_dense = self.get_einsum_dense(
             'ij,jkh->ikh', (self.head_units, self.heads)
         )
-        self._key_dense.build([None, node_feature_dim])
-
         self._value_dense = self.get_einsum_dense(
             'ij,jkh->ikh', (self.head_units, self.heads)
         )
-        self._value_dense.build([None, node_feature_dim])
-
         self._output_dense = self.get_dense(self.units)
-        self._output_dense.build([None, self.units])
-
         self._softmax_dropout = keras.layers.Dropout(self._attention_dropout)
 
-        self.
+        self._add_bias = not 'bias' in spec.edge
 
-        self.
-
-        self._add_edge_bias = AddEdgeBias()
-        self._add_edge_bias.build_from_spec(spec)
+        if self._add_bias:
+            self._edge_bias = EdgeBias(biases=self.heads)
 
         has_overridden_update = self.__class__.update != GTConv.update
         if not has_overridden_update:
-
-            if self._normalize:
-                self._feedforward_output_norm = keras.layers.LayerNormalization()
-                self._feedforward_output_norm.build([None, self.units])
-
-            self._feedforward_dropout = keras.layers.Dropout(self._dropout)
-
             self._feedforward_intermediate_dense = self.get_dense(self.units)
-            self.
-
+            self._feedforward_activation = self._activation
             self._feedforward_output_dense = self.get_dense(self.units)
-            self._feedforward_output_dense.build([None, self.units])
-
 
     def message(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
-
-
-
+        if self._add_bias:
+            edge_bias = self._edge_bias(tensor)
+            tensor = tensor.update(
+                {
+                    'edge': {
+                        'bias': edge_bias
+                    }
+                }
+            )
+
         node_feature = tensor.node['feature']
 
         query = self._query_dense(node_feature)
@@ -587,12 +914,8 @@ class GTConv(GraphConv):
 
         attention_score = keras.ops.sum(query * key, axis=1, keepdims=True)
         attention_score /= keras.ops.sqrt(float(self.head_units))
-
-        if self._add_edge_bias:
-            tensor = self._add_edge_bias(tensor)
 
-        attention_score += keras.ops.expand_dims(tensor.edge['bias'],
-
+        attention_score += keras.ops.expand_dims(tensor.edge['bias'], axis=1)
         attention = ops.edge_softmax(attention_score, tensor.edge['target'])
         attention = self._softmax_dropout(attention)
 
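In other words, the block above is scaled dot-product attention computed per edge: score_e = (q_source · k_target) / sqrt(head_units) + bias_e, normalized by `ops.edge_softmax` over each target node's incoming edges and then passed through dropout.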
@@ -606,12 +929,9 @@ class GTConv(GraphConv):
         )
 
     def aggregate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
-
-        """
-        node_feature = tensor.aggregate('message')
+        node_feature = tensor.aggregate('message', mode='sum')
         node_feature = keras.ops.reshape(node_feature, (-1, self.units))
         node_feature = self._output_dense(node_feature)
-        node_feature = self._self_attention_dropout(node_feature)
         return tensor.update(
             {
                 'node': {
@@ -626,49 +946,257 @@ class GTConv(GraphConv):
         )
 
     def update(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
-
+        node_feature = tensor.node['feature']
+        node_feature = self._feedforward_intermediate_dense(node_feature)
+        node_feature = self._feedforward_activation(node_feature)
+        node_feature = self._feedforward_output_dense(node_feature)
+        return tensor.update(
+            {
+                'node': {
+                    'feature': node_feature,
+                },
+            }
+        )
+
+    def get_config(self) -> dict:
+        config = super().get_config()
+        config.update({
+            "heads": self._heads,
+            'attention_dropout': self._attention_dropout,
+        })
+        return config
+
+
+@keras.saving.register_keras_serializable(package='molcraft')
+class MPConv(GraphConv):
+
+    """Message passing neural network layer.
+    """
+
+    def __init__(
+        self,
+        units: int = 128,
+        activation: keras.layers.Activation | str | None = None,
+        use_bias: bool = True,
+        normalization: bool = False,
+        **kwargs
+    ) -> None:
+        super().__init__(
+            units=units,
+            activation=activation,
+            use_bias=use_bias,
+            normalization=normalization,
+            **kwargs
+        )
+
+    def build(self, spec: tensors.GraphTensor.Spec) -> None:
+        node_feature_dim = spec.node['feature'].shape[-1]
+        self.message_fn = self.get_dense(self.units, activation=self._activation)
+        self.update_fn = keras.layers.GRUCell(self.units)
+        self._has_edge_feature = 'feature' in spec.edge
+        self.project_input_node_feature = node_feature_dim != self.units
+        if self.project_input_node_feature:
+            warn(
+                'Input node feature dim does not match updated node feature dim. '
+                'To make sure input node feature can be passed as `states` to the '
+                'GRU cell, it will automatically be projected prior to it.'
+            )
+            self._previous_node_dense = self.get_dense(
+                self.units, activation=self._activation
+            )
+
+    def message(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
+        feature = keras.ops.concatenate(
+            [
+                tensor.gather('feature', 'source'),
+                tensor.gather('feature', 'target'),
+            ],
+            axis=-1
+        )
+        if self._has_edge_feature:
+            feature = keras.ops.concatenate(
+                [
+                    feature,
+                    tensor.edge['feature']
+                ],
+                axis=-1
+            )
+        message = self.message_fn(feature)
+        return tensor.update(
+            {
+                'edge': {
+                    'message': message,
+                }
+            }
+        )
+
+    def aggregate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
+        aggregate = tensor.aggregate('message', mode='mean')
+        previous = tensor.node['feature']
+        if self.project_input_node_feature:
+            previous = self._previous_node_dense(previous)
+        return tensor.update(
+            {
+                'node': {
+                    'feature': aggregate,
+                    'previous_feature': previous,
+                }
+            }
+        )
+
+    def update(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
+        updated_node_feature, _ = self.update_fn(
+            inputs=tensor.node['feature'],
+            states=tensor.node['previous_feature']
+        )
+        return tensor.update(
+            {
+                'node': {
+                    'feature': updated_node_feature,
+                    'previous_feature': None,
+                }
+            }
+        )
+
+    def get_config(self) -> dict:
+        config = super().get_config()
+        config.update({})
+        return config
+
+
+@keras.saving.register_keras_serializable(package='molcraft')
+class GTConv3D(GTConv):
+
+    """Graph transformer layer 3D.
+    """
+
+    def build(self, spec: tensors.GraphTensor.Spec) -> None:
+        """Builds the layer.
         """
+        super().build(spec)
+        if self._add_bias:
+            node_feature_dim = spec.node['feature'].shape[-1]
+            kernels = self.units
+            self._gaussian_basis = GaussianDistance(kernels)
+            self._centrality_dense = self.get_dense(units=node_feature_dim)
+            self._gaussian_edge_bias = self.get_dense(self.heads)
+
+    def message(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
         node_feature = tensor.node['feature']
 
-
-
-
+        if self._add_bias:
+            gaussian = self._gaussian_basis(tensor)
+            centrality = keras.ops.segment_sum(
+                gaussian, tensor.edge['target'], tensor.num_nodes
+            )
+            node_feature += self._centrality_dense(centrality)
 
-
-
+            edge_bias = self._edge_bias(tensor) + self._gaussian_edge_bias(gaussian)
+            tensor = tensor.update({'edge': {'bias': edge_bias}})
+
+        query = self._query_dense(node_feature)
+        key = self._key_dense(node_feature)
+        value = self._value_dense(node_feature)
 
-
-
-
-
-
-
+        query = ops.gather(query, tensor.edge['source'])
+        key = ops.gather(key, tensor.edge['target'])
+        value = ops.gather(value, tensor.edge['source'])
+
+        attention_score = keras.ops.sum(query * key, axis=1, keepdims=True)
+        attention_score /= keras.ops.sqrt(float(self.head_units))
+
+        attention_score += keras.ops.expand_dims(tensor.edge['bias'], axis=1)
+
+        attention = ops.edge_softmax(attention_score, tensor.edge['target'])
+        attention = self._softmax_dropout(attention)
+
+        distance = keras.ops.subtract(
+            tensor.gather('coordinate', 'source'),
+            tensor.gather('coordinate', 'target')
+        )
+        euclidean_distance = ops.euclidean_distance(
+            tensor.gather('coordinate', 'source'),
+            tensor.gather('coordinate', 'target'),
+            axis=-1
+        )
+        distance /= euclidean_distance
+
+        attention *= keras.ops.expand_dims(distance, axis=-1)
+        attention = keras.ops.expand_dims(attention, axis=2)
+        value = keras.ops.expand_dims(value, axis=1)
+
+        return tensor.update(
+            {
+                'edge': {
+                    'message': value,
+                    'weight': attention,
+                },
+            }
+        )
+
+    def aggregate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
+        node_feature = tensor.aggregate('message', mode='sum')
+        node_feature = keras.ops.reshape(
+            node_feature, (tensor.num_nodes, -1, self.units)
+        )
+        node_feature = self._output_dense(node_feature)
+        node_feature = keras.ops.sum(node_feature, axis=1)
+        return tensor.update(
+            {
+                'node': {
+                    'feature': node_feature,
+                    'residual': tensor.node['feature']
+                },
+                'edge': {
+                    'message': None,
+                    'weight': None,
+                }
+            }
+        )
+
 
-
+@keras.saving.register_keras_serializable(package='molcraft')
+class MPConv3D(MPConv):
 
+    """Message passing neural network layer 3D.
+    """
+
+    def message(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
+        euclidean_distance = ops.euclidean_distance(
+            tensor.gather('coordinate', 'target'),
+            tensor.gather('coordinate', 'source'),
+            axis=-1
+        )
+        feature = keras.ops.concatenate(
+            [
+                tensor.gather('feature', 'source'),
+                tensor.gather('feature', 'target'),
+                euclidean_distance,
+            ],
+            axis=-1
+        )
+        if self._has_edge_feature:
+            feature = keras.ops.concatenate(
+                [
+                    feature,
+                    tensor.edge['feature']
+                ],
+                axis=-1
+            )
+        message = self.message_fn(feature)
         return tensor.update(
             {
-                '
-                '
-                }
+                'edge': {
+                    'message': message,
+                }
             }
         )
 
-    def get_config(self) -> dict:
-        config = super().get_config()
-        config.update({
-            "heads": self._heads,
-            'activation': keras.activations.serialize(self._activation),
-            'dropout': self._dropout,
-            'attention_dropout': self._attention_dropout,
-        })
-        return config
-
 
 @keras.saving.register_keras_serializable(package='molcraft')
 class EGConv3D(GraphConv):
 
-    """Equivariant graph neural network layer.
+    """Equivariant graph neural network layer 3D.
     """
 
     def __init__(
@@ -676,48 +1204,33 @@ class EGConv3D(GraphConv):
         units: int = 128,
         activation: keras.layers.Activation | str | None = None,
         use_bias: bool = True,
-
-        dropout: float = 0.0,
+        normalization: bool = False,
         **kwargs
     ) -> None:
         super().__init__(
             units=units,
-
+            activation=activation,
             use_bias=use_bias,
+            normalization=normalization,
             **kwargs
         )
-        self._activation = keras.activations.get(activation)
-        self._dropout = dropout or 0.0
 
-    def
+    def build(self, spec: tensors.GraphTensor.Spec) -> None:
         if 'coordinate' not in spec.node:
             raise ValueError(
                 'Could not find `coordinate`s in node, '
                 'which is required for Conv3D layers.'
             )
-
-        feature_dim = node_feature_dim + node_feature_dim + 1
-        if 'feature' in spec.edge:
-            self._has_edge_feature = True
-            edge_feature_dim = spec.edge['feature'].shape[-1]
-            feature_dim += edge_feature_dim
-        else:
-            self._has_edge_feature = False
-
+        self._has_edge_feature = 'feature' in spec.edge
         self.message_fn = self.get_dense(self.units, activation=self._activation)
-        self.message_fn.build([None, feature_dim])
         self.dense_position = self.get_dense(1)
-        self.dense_position.build([None, self.units])
 
         has_overridden_update = self.__class__.update != EGConv3D.update
         if not has_overridden_update:
             self.update_fn = self.get_dense(self.units, activation=self._activation)
-            self.
-            self._dropout_layer = keras.layers.Dropout(self._dropout)
+            self.output_dense = self.get_dense(self.units)
 
     def message(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
-        """Computes messages.
-        """
         relative_node_coordinate = keras.ops.subtract(
             tensor.gather('coordinate', 'target'),
             tensor.gather('coordinate', 'source')
@@ -760,8 +1273,6 @@ class EGConv3D(GraphConv):
         )
 
     def aggregate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
-        """Aggregates messages.
-        """
         coefficient = keras.ops.bincount(
             tensor.edge['source'],
             minlength=tensor.num_nodes
@@ -776,7 +1287,7 @@ class EGConv3D(GraphConv):
         updated_coordinate = tensor.aggregate('relative_node_coordinate') * coefficient
         updated_coordinate += tensor.node['coordinate']
 
-        aggregate = tensor.aggregate('message')
+        aggregate = tensor.aggregate('message', mode='mean')
         return tensor.update(
             {
                 'node': {
@@ -792,8 +1303,6 @@ class EGConv3D(GraphConv):
         )
 
     def update(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
-        """Updates nodes.
-        """
         updated_node_feature = self.update_fn(
             keras.ops.concatenate(
                 [
@@ -803,7 +1312,7 @@ class EGConv3D(GraphConv):
                 axis=-1
             )
         )
-        updated_node_feature = self.
+        updated_node_feature = self.output_dense(updated_node_feature)
         return tensor.update(
             {
                 'node': {
@@ -815,65 +1324,46 @@ class EGConv3D(GraphConv):
 
     def get_config(self) -> dict:
         config = super().get_config()
-        config.update({
-            'activation': keras.activations.serialize(self._activation),
-            'dropout': self._dropout,
-        })
+        config.update({})
         return config
 
 
 @keras.saving.register_keras_serializable(package='molcraft')
-class
-
+class Readout(GraphLayer):
+
+    """Readout layer.
     """
-    def __init__(
-        self,
-        units: int = None,
-        activation: str = None,
-        field: str = 'node',
-        **kwargs
-    ) -> None:
-        super().__init__(**kwargs)
-        self.units = units
-        self._activation = keras.activations.get(activation)
-        self.field = field
 
-    def
-
-
-
-
-
-
-        self.
-
-
+    def __init__(self, mode: str | None = None, **kwargs):
+        kwargs['kernel_initializer'] = None
+        kwargs['bias_initializer'] = None
+        super().__init__(**kwargs)
+        self.mode = mode
+        if str(self.mode).lower().startswith('sum'):
+            self._reduce_fn = keras.ops.segment_sum
+        elif str(self.mode).lower().startswith('max'):
+            self._reduce_fn = keras.ops.segment_max
+        elif str(self.mode).lower().startswith('super'):
+            self._reduce_fn = keras.ops.segment_sum
+        else:
+            self._reduce_fn = ops.segment_mean
 
-    def propagate(self, tensor: tensors.GraphTensor):
-
-
-
-
-        return
-
-
-                    'feature': feature
-                }
-            }
-        )
+    def propagate(self, tensor: tensors.GraphTensor) -> tf.Tensor:
+        node_feature = tensor.node['feature']
+        if str(self.mode).lower().startswith('super'):
+            node_feature = keras.ops.where(
+                tensor.node['super'][:, None], node_feature, 0.0
+            )
+        return self._reduce_fn(
+            node_feature, tensor.graph_indicator, tensor.num_subgraphs
+        )
 
     def get_config(self) -> dict:
         config = super().get_config()
-        config.
-
-
-            'field': self.field,
-        })
-        return config
+        config['mode'] = self.mode
+        return config
+
 
-
 @keras.saving.register_keras_serializable(package='molcraft')
 class GraphNetwork(GraphLayer):
 
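A minimal usage sketch of the `Readout` layer above, reusing the `GraphTensor` construction from the doctests earlier in this diff; per the constructor, unrecognized `mode` values fall back to mean pooling:

    import molcraft

    graph = molcraft.tensors.GraphTensor(
        context={'size': [2]},
        node={'feature': [[1.], [2.]]},
        edge={'source': [0, 1], 'target': [1, 0]},
    )
    readout = molcraft.layers.Readout(mode='mean')  # or 'sum' / 'max' / 'super'
    pooled = readout(graph)  # one pooled feature vector per subgraph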
@@ -881,7 +1371,7 @@ class GraphNetwork(GraphLayer):
 
     Sequentially calls graph layers (`GraphLayer`) and concatenates its output.
 
-
+    Arguments:
         layers (list):
             A list of graph layers.
     """
@@ -891,36 +1381,32 @@ class GraphNetwork(GraphLayer):
|
|
|
891
1381
|
self.layers = layers
|
|
892
1382
|
self._update_edge_feature = False
|
|
893
1383
|
|
|
894
|
-
def
|
|
895
|
-
"""Builds the layer.
|
|
896
|
-
"""
|
|
1384
|
+
def build(self, spec: tensors.GraphTensor.Spec) -> None:
|
|
897
1385
|
units = self.layers[0].units
|
|
898
1386
|
node_feature_dim = spec.node['feature'].shape[-1]
|
|
899
|
-
|
|
1387
|
+
self._update_node_feature = node_feature_dim != units
|
|
1388
|
+
if self._update_node_feature:
|
|
900
1389
|
warn(
|
|
901
1390
|
'Node feature dim does not match `units` of the first layer. '
|
|
902
1391
|
'Automatically adding a node projection layer to match `units`.'
|
|
903
1392
|
)
|
|
904
1393
|
self._node_dense = self.get_dense(units)
|
|
905
|
-
|
|
906
|
-
|
|
907
|
-
if has_edge_feature:
|
|
1394
|
+
self._has_edge_feature = 'feature' in spec.edge
|
|
1395
|
+
if self._has_edge_feature:
|
|
908
1396
|
edge_feature_dim = spec.edge['feature'].shape[-1]
|
|
909
|
-
|
|
1397
|
+
self._update_edge_feature = edge_feature_dim != units
|
|
1398
|
+
if self._update_edge_feature:
|
|
910
1399
|
warn(
|
|
911
1400
|
'Edge feature dim does not match `units` of the first layer. '
|
|
912
1401
|
'Automatically adding a edge projection layer to match `units`.'
|
|
913
1402
|
)
|
|
914
1403
|
self._edge_dense = self.get_dense(units)
|
|
915
|
-
self._update_edge_feature = True
|
|
916
1404
|
|
|
917
1405
|
def propagate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
|
|
918
|
-
"""Calls the layer.
|
|
919
|
-
"""
|
|
920
1406
|
x = tensors.to_dict(tensor)
|
|
921
1407
|
if self._update_node_feature:
|
|
922
1408
|
x['node']['feature'] = self._node_dense(tensor.node['feature'])
|
|
923
|
-
if self._update_edge_feature:
|
|
1409
|
+
if self._has_edge_feature and self._update_edge_feature:
|
|
924
1410
|
x['edge']['feature'] = self._edge_dense(tensor.edge['feature'])
|
|
925
1411
|
outputs = [x['node']['feature']]
|
|
926
1412
|
for layer in self.layers:
|
|
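`GraphNetwork.build` now decides up front whether input projections are needed, instead of tracking `_update_edge_feature` imperatively during the forward pass. A sketch of just that decision logic, restated outside the class (the function and names here are illustrative, not molcraft API):

```python
import keras

def make_projections(node_dim: int, edge_dim: int | None, units: int):
    """Mirror of the dimension checks in GraphNetwork.build above."""
    node_dense = keras.layers.Dense(units) if node_dim != units else None
    edge_dense = (
        keras.layers.Dense(units)
        if edge_dim is not None and edge_dim != units
        else None
    )
    return node_dense, edge_dense

# E.g. 11-dim raw atom features feeding 128-unit conv layers: only the
# node projection is created; a graph without edge features gets neither.
node_dense, edge_dense = make_projections(node_dim=11, edge_dim=None, units=128)
assert node_dense is not None and edge_dense is None
```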
@@ -945,7 +1431,7 @@ class GraphNetwork(GraphLayer):
     Performs the same forward pass as `propagate` but with a `GradientTape`
     watching intermediate node features.
 
-
+    Arguments:
         tensor (tensors.GraphTensor):
             The graph input.
     """
@@ -1003,24 +1489,25 @@ class NodeEmbedding(GraphLayer):
     def __init__(
         self,
         dim: int = None,
+        normalization: bool = False,
         embed_context: bool = True,
+        allow_reconstruction: bool = False,
         allow_masking: bool = True,
         **kwargs
     ) -> None:
         super().__init__(**kwargs)
         self.dim = dim
+        self._normalization = normalization
         self._embed_context = embed_context
         self._masking_rate = None
         self._allow_masking = allow_masking
+        self._allow_reconstruction = allow_reconstruction
 
-    def
-        """Builds the layer.
-        """
+    def build(self, spec: tensors.GraphTensor.Spec) -> None:
         feature_dim = spec.node['feature'].shape[-1]
         if not self.dim:
             self.dim = feature_dim
         self._node_dense = self.get_dense(self.dim)
-        self._node_dense.build([None, feature_dim])
 
         self._has_super = 'super' in spec.node
         has_context_feature = 'feature' in spec.context
@@ -1034,11 +1521,18 @@ class NodeEmbedding(GraphLayer):
         if self._embed_context:
             context_feature_dim = spec.context['feature'].shape[-1]
             self._context_dense = self.get_dense(self.dim)
-
+
+        if self._normalization:
+            if str(self._normalization).lower().startswith('batch'):
+                self._norm = keras.layers.BatchNormalization(
+                    name='output_batch_norm'
+                )
+            else:
+                self._norm = keras.layers.LayerNormalization(
+                    name='output_layer_norm'
+                )
 
     def propagate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
-        """Calls the layer.
-        """
         feature = self._node_dense(tensor.node['feature'])
 
         if self._has_super:
@@ -1068,8 +1562,13 @@ class NodeEmbedding(GraphLayer):
             # Slience warning of 'no gradients for variables'
             feature = feature + (self._mask_feature * 0.0)
 
-
+        if self._normalization:
+            feature = self._norm(feature)
 
+        if not self._allow_reconstruction:
+            return tensor.update({'node': {'feature': feature}})
+        return tensor.update({'node': {'feature': feature, 'target_feature': feature}})
+
     @property
     def masking_rate(self):
         return self._masking_rate
@@ -1087,7 +1586,10 @@ class NodeEmbedding(GraphLayer):
         config = super().get_config()
         config.update({
            'dim': self.dim,
-            '
+            'normalization': self._normalization,
+            'embed_context': self._embed_context,
+            'allow_masking': self._allow_masking,
+            'allow_reconstruction': self._allow_reconstruction,
         })
         return config
 
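`NodeEmbedding` gains two opt-in behaviors: an output normalization (any value whose string form starts with `'batch'` selects `BatchNormalization`, any other truthy value `LayerNormalization`) and `allow_reconstruction`, which duplicates the embedded features under `node['target_feature']` for the new `Reconstruction` layer further down this diff. A hypothetical construction; the argument values are illustrative:

```python
from molcraft import layers

node_embedding = layers.NodeEmbedding(
    dim=128,
    normalization='batch',      # 'batch*' -> BatchNormalization; e.g. 'layer' or True -> LayerNormalization
    allow_reconstruction=True,  # also writes node['target_feature']
)
```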
@@ -1103,22 +1605,21 @@ class EdgeEmbedding(GraphLayer):
     def __init__(
         self,
         dim: int = None,
+        normalization: bool = False,
         allow_masking: bool = True,
         **kwargs
     ) -> None:
         super().__init__(**kwargs)
         self.dim = dim
+        self._normalization = normalization
         self._masking_rate = None
         self._allow_masking = allow_masking
 
-    def
-        """Builds the layer.
-        """
+    def build(self, spec: tensors.GraphTensor.Spec) -> None:
         feature_dim = spec.edge['feature'].shape[-1]
         if not self.dim:
             self.dim = feature_dim
         self._edge_dense = self.get_dense(self.dim)
-        self._edge_dense.build([None, feature_dim])
 
         self._has_super = 'super' in spec.edge
         if self._has_super:
@@ -1126,9 +1627,17 @@ class EdgeEmbedding(GraphLayer):
         if self._allow_masking:
             self._mask_feature = self.get_weight(shape=[self.dim], name='mask_edge_feature')
 
+        if self._normalization:
+            if str(self._normalization).lower().startswith('batch'):
+                self._norm = keras.layers.BatchNormalization(
+                    name='output_batch_norm'
+                )
+            else:
+                self._norm = keras.layers.LayerNormalization(
+                    name='output_layer_norm'
+                )
+
     def propagate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
-        """Calls the layer.
-        """
         feature = self._edge_dense(tensor.edge['feature'])
 
         if self._has_super:
@@ -1153,7 +1662,10 @@ class EdgeEmbedding(GraphLayer):
             # Slience warning of 'no gradients for variables'
             feature = feature + (self._mask_feature * 0.0)
 
-
+        if self._normalization:
+            feature = self._norm(feature)
+
+        return tensor.update({'edge': {'feature': feature, 'embedding': feature}})
 
     @property
     def masking_rate(self):
@@ -1172,17 +1684,67 @@ class EdgeEmbedding(GraphLayer):
         config = super().get_config()
         config.update({
             'dim': self.dim,
+            'normalization': self._normalization,
             'allow_masking': self._allow_masking
         })
         return config
 
 
+@keras.saving.register_keras_serializable(package='molcraft')
+class Projection(GraphLayer):
+    """Base graph projection layer.
+    """
+    def __init__(
+        self,
+        units: int = None,
+        activation: str | keras.layers.Activation | None = None,
+        use_bias: bool = True,
+        field: str = 'node',
+        **kwargs
+    ) -> None:
+        super().__init__(use_bias=use_bias, **kwargs)
+        self.units = units
+        self._activation = keras.activations.get(activation)
+        self.field = field
+
+    def build(self, spec: tensors.GraphTensor.Spec) -> None:
+        data = getattr(spec, self.field, None)
+        if data is None:
+            raise ValueError('Could not access field {self.field!r}.')
+        feature_dim = data['feature'].shape[-1]
+        if not self.units:
+            self.units = feature_dim
+        self._dense = self.get_dense(self.units)
+
+    def propagate(self, tensor: tensors.GraphTensor):
+        feature = getattr(tensor, self.field)['feature']
+        feature = self._dense(feature)
+        feature = self._activation(feature)
+        return tensor.update(
+            {
+                self.field: {
+                    'feature': feature
+                }
+            }
+        )
+
+    def get_config(self) -> dict:
+        config = super().get_config()
+        config.update({
+            'units': self.units,
+            'activation': keras.activations.serialize(self._activation),
+            'field': self.field,
+        })
+        return config
+
+
 @keras.saving.register_keras_serializable(package='molcraft')
 class ContextProjection(Projection):
     """Context projection layer.
     """
     def __init__(self, units: int = None, activation: str = None, **kwargs):
-
+        kwargs['field'] = 'context'
+        super().__init__(units=units, activation=activation, **kwargs)
 
 
 @keras.saving.register_keras_serializable(package='molcraft')
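The new `Projection` base class folds the three former standalone projection layers into one implementation parameterized by `field`; the subclasses below now just pin that field. Note that the `ValueError` message in `build` lacks an `f` prefix, so `{self.field!r}` will not be interpolated. Hypothetical equivalent constructions; the `units` and `activation` values are illustrative:

```python
from molcraft import layers

# Each subclass is now Projection with a fixed `field`:
node_proj = layers.NodeProjection(units=64, activation='relu')  # field='node'
edge_proj = layers.EdgeProjection(units=64)                     # field='edge'
context_proj = layers.ContextProjection(units=64)               # field='context'
```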
@@ -1190,7 +1752,8 @@ class NodeProjection(Projection):
     """Node projection layer.
     """
     def __init__(self, units: int = None, activation: str = None, **kwargs):
-
+        kwargs['field'] = 'node'
+        super().__init__(units=units, activation=activation, **kwargs)
 
 
 @keras.saving.register_keras_serializable(package='molcraft')
@@ -1198,103 +1761,126 @@ class EdgeProjection(Projection):
     """Edge projection layer.
     """
     def __init__(self, units: int = None, activation: str = None, **kwargs):
-
-
+        kwargs['field'] = 'edge'
+        super().__init__(units=units, activation=activation, **kwargs)
+
 
 @keras.saving.register_keras_serializable(package='molcraft')
-class
+class Reconstruction(GraphLayer):
 
-    def __init__(
+    def __init__(
+        self,
+        loss: keras.losses.Loss | str = 'mse',
+        loss_weight: float = 0.5,
+        **kwargs
+    ):
         super().__init__(**kwargs)
-        self.
-
-
-
-
-
-
-
-
-
-        self._reduce_fn = ops.segment_mean
-
-    def build_from_spec(self, spec: tensors.GraphTensor.Spec) -> None:
-        """Builds the layer.
-        """
-        pass
-
-    def reduce(self, tensor: tensors.GraphTensor) -> tf.Tensor:
-        if self._reduce_fn is None:
-            raise NotImplementedError("Need to define a reduce method.")
-        if str(self.mode).lower().startswith('super'):
-            node_feature = keras.ops.where(
-                tensor.node['super'][:, None], tensor.node['feature'], 0.0
+        self._loss_fn = keras.losses.get(loss)
+        self._loss_weight = loss_weight
+
+    def build(self, spec: tensors.GraphTensor.Spec) -> None:
+        has_target_node_feature = 'target_feature' in spec.node
+        if not has_target_node_feature:
+            raise ValueError(
+                'Could not find `target_feature` in `spec.node`. '
+                'Add a `target_feature` via `NodeEmbedding` by setting '
+                '`allow_reconstruction` to `True`.'
             )
-        return self._reduce_fn(
-            node_feature, tensor.graph_indicator, tensor.num_subgraphs
-        )
+        output_dim = spec.node['target_feature'].shape[-1]
+        self._dense = self.get_dense(output_dim)
 
-    def
-
-
-        self.built = True
+    def propagate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
+        target_node_feature = tensor.node['target_feature']
+        transformed_node_feature = tensor.node['feature']
 
-
-
-
-        graph_tensor = graph_tensor.flatten()
-        return self.reduce(graph_tensor)
+        reconstructed_node_feature = self._dense(
+            transformed_node_feature
+        )
 
-
-
-
-        *
-
-    ) -> tensors.GraphTensor:
-        is_tensor = isinstance(graph, tensors.GraphTensor)
-        if is_tensor:
-            graph = tensors.to_dict(graph)
-        tensor = super().__call__(graph, *args, **kwargs)
-        return tensor
+        loss = self._loss_fn(
+            target_node_feature, reconstructed_node_feature
+        )
+        self.add_loss(keras.ops.sum(loss) * self._loss_weight)
+        return tensor.update({'node': {'feature': transformed_node_feature}})
 
-    def get_config(self)
+    def get_config(self):
         config = super().get_config()
-        config['
-
-
+        config['loss'] = keras.losses.serialize(self._loss_fn)
+        config['loss_weight'] = self._loss_weight
+        return config
+
 
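`Reconstruction` replaces the old reduce-style base class: it projects the current node features back to the dimensionality of `node['target_feature']` and registers the weighted reconstruction error via `add_loss`. The pairing implied by the `ValueError` message above, sketched hypothetically:

```python
from molcraft import layers

embedding = layers.NodeEmbedding(dim=128, allow_reconstruction=True)
reconstruction = layers.Reconstruction(loss='mse', loss_weight=0.5)
# Placed after the message-passing layers, the auxiliary loss compares the
# final node features (projected) against the initial embedding snapshot.
```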
 @keras.saving.register_keras_serializable(package='molcraft')
-class
+class EdgeBias(GraphLayer):
 
-    def
+    def __init__(self, biases: int, **kwargs):
+        super().__init__(**kwargs)
+        self.biases = biases
+
+    def build(self, spec: tensors.GraphTensor.Spec) -> None:
         self._has_edge_length = 'length' in spec.edge
         self._has_edge_feature = 'feature' in spec.edge
         if self._has_edge_feature:
-            self._edge_feature_dense = self.get_dense(
+            self._edge_feature_dense = self.get_dense(self.biases)
         if self._has_edge_length:
             self._edge_length_dense = self.get_dense(
-
+                self.biases, kernel_initializer='zeros'
             )
-
+
     def propagate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
         bias = keras.ops.zeros(
-            shape=(tensor.num_edges,
+            shape=(tensor.num_edges, self.biases),
             dtype=tensor.node['feature'].dtype
         )
         if self._has_edge_feature:
             bias += self._edge_feature_dense(tensor.edge['feature'])
         if self._has_edge_length:
             bias += self._edge_length_dense(tensor.edge['length'])
-        return
-
-
-
-
-
+        return bias
+
+    def get_config(self) -> dict:
+        config = super().get_config()
+        config.update({'biases': self.biases})
+        return config
+
+
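`EdgeBias` now takes an explicit `biases` width and returns a plain `(num_edges, biases)` tensor instead of updating the graph; edge-feature and (zero-initialized) edge-length contributions are summed into it. One bias column per attention head would be a typical use, though the diff itself does not tie `biases` to heads. Construction only, as a sketch:

```python
from molcraft import layers

edge_bias = layers.EdgeBias(biases=8)  # propagate(...) -> tensor of shape (num_edges, 8)
```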
+@keras.saving.register_keras_serializable(package='molcraft')
+class GaussianDistance(GraphLayer):
+
+    def __init__(self, kernels: int, **kwargs):
+        super().__init__(**kwargs)
+        self.kernels = kernels
+
+    def build(self, spec: tensors.GraphTensor.Spec) -> None:
+        self._loc = self.add_weight(
+            shape=[self.kernels],
+            initializer='zeros',
+            dtype='float32',
+            trainable=True
+        )
+        self._scale = self.add_weight(
+            shape=[self.kernels],
+            initializer='ones',
+            dtype='float32',
+            trainable=True
+        )
+
+    def propagate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
+        euclidean_distance = ops.euclidean_distance(
+            tensor.gather('coordinate', 'source'),
+            tensor.gather('coordinate', 'target'),
+            axis=-1
+        )
+        return ops.gaussian(
+            euclidean_distance, self._loc, self._scale
         )
+
+    def get_config(self) -> dict:
+        config = super().get_config()
+        config.update({
+            'kernels': self.kernels,
+        })
+        return config
 
 
 def Input(spec: tensors.GraphTensor.Spec) -> dict:
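`GaussianDistance` expands each source-target distance into `kernels` radial basis features with trainable centers and widths. `ops.gaussian` itself is not shown in this diff; a presumed equivalent, as a standalone sketch:

```python
import numpy as np

def gaussian(distance, loc, scale):
    """Presumed Gaussian RBF expansion: (num_edges, 1) -> (num_edges, kernels)."""
    return np.exp(-0.5 * ((distance - loc) / scale) ** 2)

distance = np.array([[1.2], [2.5]])          # two edge lengths
loc, scale = np.zeros(4), np.ones(4)         # matches the 'zeros'/'ones' initializers above
print(gaussian(distance, loc, scale).shape)  # (2, 4)
```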
@@ -1412,13 +1998,6 @@ def _spec_from_inputs(inputs):
     return tensors.GraphTensor.Spec(**nested_specs)
 
 
-GraphTransformer =
-
-
-EdgeEmbed = EdgeEmbedding
-NodeEmbed = NodeEmbedding
-
-ContextDense = ContextProjection
-EdgeDense = EdgeProjection
-NodeDense = NodeProjection
+GraphTransformer = GTConv
+GraphTransformer3D = GTConv3D
 
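Only the `GraphTransformer` names survive this release; the other 0.1.0a2 shorthand aliases are removed. A hypothetical migration sketch for downstream code:

```python
from molcraft import layers

layers.NodeEmbedding     # was also exposed as layers.NodeEmbed
layers.NodeProjection    # was also exposed as layers.NodeDense
layers.GraphTransformer  # still available; now an alias of layers.GTConv
```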