molcraft 0.1.0a3__py3-none-any.whl → 0.1.0a4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- molcraft/__init__.py +1 -1
- molcraft/callbacks.py +12 -0
- molcraft/descriptors.py +24 -23
- molcraft/experimental/peptides.py +96 -79
- molcraft/featurizers.py +59 -37
- molcraft/layers.py +692 -565
- molcraft/models.py +14 -1
- molcraft/ops.py +14 -3
- molcraft/tensors.py +3 -3
- {molcraft-0.1.0a3.dist-info → molcraft-0.1.0a4.dist-info}/METADATA +6 -6
- molcraft-0.1.0a4.dist-info/RECORD +20 -0
- {molcraft-0.1.0a3.dist-info → molcraft-0.1.0a4.dist-info}/WHEEL +1 -1
- molcraft-0.1.0a3.dist-info/RECORD +0 -20
- {molcraft-0.1.0a3.dist-info → molcraft-0.1.0a4.dist-info}/licenses/LICENSE +0 -0
- {molcraft-0.1.0a3.dist-info → molcraft-0.1.0a4.dist-info}/top_level.txt +0 -0
molcraft/layers.py
CHANGED
@@ -1,7 +1,7 @@
-import abc
 import keras
 import tensorflow as tf
 import warnings
+import functools
 from keras.src.models import functional

 from molcraft import tensors
@@ -12,11 +12,39 @@ from molcraft import ops
 class GraphLayer(keras.layers.Layer):
     """Base graph layer.

-
+    Subclasses must implement a forward pass via **propagate(graph)**.
+
+    Subclasses may create dense layers and weights in **build(graph_spec)**.
+
+    Note: `GraphLayer` currently only supports `GraphTensor` input.

-    The list of arguments are only relevant if the derived layer
+    The list of arguments below are only relevant if the derived layer
     invokes 'get_dense_kwargs`, `get_dense` or `get_einsum_dense`.

+    Arguments:
+        use_bias (bool):
+            Whether bias should be used in dense layers. Default to `True`.
+        kernel_initializer (keras.initializers.Initializer, str):
+            Initializer for the kernel weight matrix of the dense layers.
+            Default to `glorot_uniform`.
+        bias_initializer (keras.initializers.Initializer, str):
+            Initializer for the bias weight vector of the dense layers.
+            Default to `zeros`.
+        kernel_regularizer (keras.regularizers.Regularizer, None):
+            Regularizer function applied to the kernel weight matrix.
+            Default to `None`.
+        bias_regularizer (keras.regularizers.Regularizer, None):
+            Regularizer function applied to the bias weight vector.
+            Default to `None`.
+        activity_regularizer (keras.regularizers.Regularizer, None):
+            Regularizer function applied to the output of the dense layers.
+            Default to `None`.
+        kernel_constraint (keras.constraints.Constraint, None):
+            Constraint function applied to the kernel weight matrix.
+            Default to `None`.
+        bias_constraint (keras.constraints.Constraint, None):
+            Constraint function applied to the bias weight vector.
+            Default to `None`.
     """

     def __init__(
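The docstring above spells out the subclass contract: implement `propagate`, and optionally create weights in `build`. A minimal sketch of a conforming subclass (illustrative only; the class name and the projection it applies are invented here, while `get_dense` and `GraphTensor.update` are taken from elsewhere in this diff):

```python
from molcraft import layers, tensors

class ProjectNodeFeatures(layers.GraphLayer):
    """Hypothetical GraphLayer subclass following the propagate/build contract."""

    def build(self, spec: tensors.GraphTensor.Spec) -> None:
        # Optional: create sub-layers from the input spec.
        dim = spec.node['feature'].shape[-1]
        self._dense = self.get_dense(dim)

    def propagate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
        # Required: transform and return the GraphTensor.
        feature = self._dense(tensor.node['feature'])
        return tensor.update({'node': {'feature': feature}})
```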
@@ -31,68 +59,61 @@ class GraphLayer(keras.layers.Layer):
         bias_constraint: keras.constraints.Constraint | None = None,
         **kwargs,
     ) -> None:
-        super().__init__(
+        super().__init__(**kwargs)
         self._use_bias = use_bias
         self._kernel_initializer = keras.initializers.get(kernel_initializer)
         self._bias_initializer = keras.initializers.get(bias_initializer)
         self._kernel_regularizer = keras.regularizers.get(kernel_regularizer)
         self._bias_regularizer = keras.regularizers.get(bias_regularizer)
+        self._activity_regularizer = keras.regularizers.get(activity_regularizer)
         self._kernel_constraint = keras.constraints.get(kernel_constraint)
         self._bias_constraint = keras.constraints.get(bias_constraint)
+        self._custom_build_config = {}
         self.built = False
-
-
+
+    def __init_subclass__(cls, **kwargs):
+        super().__init_subclass__(**kwargs)
+        subclass_build = cls.build
+
+        @functools.wraps(subclass_build)
+        def build_wrapper(self: GraphLayer, spec: tensors.GraphTensor.Spec | None):
+            GraphLayer.build(self, spec)
+            subclass_build(self, spec)
+            if not self.built and isinstance(self, keras.Model):
+                symbolic_inputs = Input(spec)
+                self.built = True
+                self(symbolic_inputs)
+
+        cls.build = build_wrapper

     def propagate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
-        """
+        """Forward pass.

-
+        Must be implemented by subclass.

-
+        Arguments:
             tensor:
                 A `GraphTensor` instance.
         """
-        raise NotImplementedError(
+        raise NotImplementedError(
+            'The forward pass of the layer is not implemented. '
+            'Please implement `propagate`.'
+        )

-    def
+    def build(self, tensor_spec: tensors.GraphTensor.Spec) -> None:
         """Builds the layer.

         May use built-in methods such as `get_weight`, `get_dense` and `get_einsum_dense`.

-        Optionally implemented by subclass.
-        If sub-layers are built (via `build` or `build_from_spec`), set `built`
-        to True. If not, symbolic input will be passed through the layer to build them.
+        Optionally implemented by subclass.

-
-
-                A `GraphTensor.Spec` instance
+        Arguments:
+            tensor_spec:
+                A `GraphTensor.Spec` instance corresponding to the `GraphTensor`
                 passed to `propagate`.
         """
-
-
-
-        self._custom_build_config = {'spec': _serialize_spec(spec)}
-
-        self.build_from_spec(spec)
-
-        if not self.built:
-            # Automatically build layer or model by calling it on symbolic inputs
-            self.built = True
-            symbolic_inputs = Input(spec)
-            self(symbolic_inputs)
-
-    def get_build_config(self) -> dict:
-        if not hasattr(self, '_custom_build_config'):
-            return super().get_build_config()
-        return self._custom_build_config
-
-    def build_from_config(self, config: dict) -> None:
-        use_custom_build_from_config = ('spec' in config)
-        if not use_custom_build_from_config:
-            super().build_from_config(config)
-        else:
-            spec = _deserialize_spec(config['spec'])
-            self.build(spec)
+        if isinstance(tensor_spec, tensors.GraphTensor.Spec):
+            self._custom_build_config['spec'] = _serialize_spec(tensor_spec)

     def call(
         self,
@@ -122,6 +143,19 @@ class GraphLayer(keras.layers.Layer):
         outputs = tensors.from_dict(outputs)
         return outputs

+    def get_build_config(self) -> dict:
+        if self._custom_build_config:
+            return self._custom_build_config
+        return super().get_build_config()
+
+    def build_from_config(self, config: dict) -> None:
+        serialized_spec = config.get('spec')
+        if serialized_spec is not None:
+            spec = _deserialize_spec(serialized_spec)
+            self.build(spec)
+        else:
+            super().build_from_config(config)
+
     def get_weight(
         self,
         shape: tf.TensorShape,
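The relocated `get_build_config`/`build_from_config` pair persists the serialized `GraphTensor.Spec`, so a layer deserialized from its config can be rebuilt without real data. A hypothetical round trip (this assumes `GraphTensor` exposes its spec as `.spec`, which this diff does not show):

```python
from molcraft import layers, tensors

graph = tensors.GraphTensor(
    context={'size': [2]},
    node={'feature': [[1.], [2.]]},
    edge={'source': [0, 1], 'target': [1, 0]},
)
conv = layers.GraphConv(units=4)
conv.build(graph.spec)              # records {'spec': _serialize_spec(...)}

clone = layers.GraphConv.from_config(conv.get_config())
clone.build_from_config(conv.get_build_config())  # rebuilds from the stored spec
```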
@@ -163,7 +197,7 @@ class GraphLayer(keras.layers.Layer):
             use_bias=self._use_bias,
             kernel_regularizer=self._kernel_regularizer,
             bias_regularizer=self._bias_regularizer,
-            activity_regularizer=self.
+            activity_regularizer=self._activity_regularizer,
             kernel_constraint=self._kernel_constraint,
             bias_constraint=self._bias_constraint,
         )
@@ -189,55 +223,136 @@ class GraphLayer(keras.layers.Layer):
                 keras.regularizers.serialize(self._kernel_regularizer),
             "bias_regularizer":
                 keras.regularizers.serialize(self._bias_regularizer),
+            "activity_regularizer":
+                keras.regularizers.serialize(self._activity_regularizer),
             "kernel_constraint":
                 keras.constraints.serialize(self._kernel_constraint),
             "bias_constraint":
                 keras.constraints.serialize(self._bias_constraint),
         })
         return config
-
+

 @keras.saving.register_keras_serializable(package='molcraft')
 class GraphConv(GraphLayer):

     """Base graph neural network layer.

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    This layer implements the three basic steps of a graph neural network layer, each of which
+    can (optionally) be overridden by the `GraphConv` subclass:
+
+    1. **message(graph)**, which computes the *messages* to be passed to target nodes;
+    2. **aggregate(graph)**, which *aggregates* messages to target nodes;
+    3. **update(graph)**, which further *updates* (target) nodes.
+
+    Note: for skip connection to work, the `GraphConv` subclass requires final node feature
+    output dimension to be equal to `units`.
+
+    Arguments:
+        units (int):
+            Dimensionality of the output space.
+        activation (keras.layers.Activation, str, None):
+            Activation function to use. If not specified, a linear activation (a(x) = x) is used.
+            Default to `None`.
+        use_bias (bool):
+            Whether bias should be used in dense layers. Default to `True`.
+        normalization (bool, str):
+            Whether `LayerNormalization` should be applied to the final node feature output.
+            To use `BatchNormalization`, specify `batch_norm`. Default to `False`.
+        skip_connection (bool, str):
+            Whether node feature input should be added to the node feature output.
+            If node feature input dim is not equal to `units` (node feature output dim),
+            a projection layer will automatically project the residual before adding it
+            to the output. To use weighted skip connection,
+            specify `weighted`. The weight multiplied with the skip connection is a
+            learnable scalar. Default to `True`.
+        kernel_initializer (keras.initializers.Initializer, str):
+            Initializer for the kernel weight matrix of the dense layers.
+            Default to `glorot_uniform`.
+        bias_initializer (keras.initializers.Initializer, str):
+            Initializer for the bias weight vector of the dense layers.
+            Default to `zeros`.
+        kernel_regularizer (keras.regularizers.Regularizer, None):
+            Regularizer function applied to the kernel weight matrix.
+            Default to `None`.
+        bias_regularizer (keras.regularizers.Regularizer, None):
+            Regularizer function applied to the bias weight vector.
+            Default to `None`.
+        activity_regularizer (keras.regularizers.Regularizer, None):
+            Regularizer function applied to the output of the dense layers.
+            Default to `None`.
+        kernel_constraint (keras.constraints.Constraint, None):
+            Constraint function applied to the kernel weight matrix.
+            Default to `None`.
+        bias_constraint (keras.constraints.Constraint, None):
+            Constraint function applied to the bias weight vector.
+            Default to `None`.
     """

     def __init__(
         self,
         units: int = None,
-
-
+        activation: str | keras.layers.Activation | None = None,
+        use_bias: bool = True,
+        normalization: bool | str = False,
+        skip_connection: bool | str = True,
         **kwargs
     ) -> None:
-        super().__init__(**kwargs)
-        self.
-        self.
+        super().__init__(use_bias=use_bias, **kwargs)
+        self._units = units
+        self._normalization = normalization
         self._skip_connection = skip_connection
+        self._activation = keras.activations.get(activation)
+
+    def __init_subclass__(cls, **kwargs):
+        super().__init_subclass__(**kwargs)
+        subclass_build = cls.build
+
+        @functools.wraps(subclass_build)
+        def build_wrapper(self, spec):
+            GraphConv.build(self, spec)
+            return subclass_build(self, spec)
+
+        cls.build = build_wrapper
+
+    @property
+    def units(self):
+        return self._units
+
+    def propagate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
+        """Forward pass.
+
+        Invokes `message(graph)`, `aggregate(graph)` and `update(graph)` in sequence.
+
+        Arguments:
+            tensor:
+                A `GraphTensor` instance.
+        """
+        if self._skip_connection:
+            input_node_feature = tensor.node['feature']
+            if self._project_input_node_feature:
+                input_node_feature = self._residual_projection(input_node_feature)
+
+        tensor = self.message(tensor)
+        tensor = self.aggregate(tensor)
+        tensor = self.update(tensor)
+
+        updated_node_feature = tensor.node['feature']
+
+        if self._skip_connection:
+            if self._use_weighted_skip_connection:
+                input_node_feature *= self._skip_connection_weight
+            updated_node_feature += input_node_feature
+
+        if self._normalization:
+            updated_node_feature = self._output_norm(updated_node_feature)
+
+        return tensor.update({'node': {'feature': updated_node_feature}})

     def build(self, spec: tensors.GraphTensor.Spec) -> None:
         if not self.units:
             raise ValueError(
-                f'`self.units` needs to be a positive integer.
+                f'`self.units` needs to be a positive integer. Found: {self.units}.'
             )
         node_feature_dim = spec.node['feature'].shape[-1]
         self._project_input_node_feature = (
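With the message/aggregate/update split documented above, a `GraphConv` subclass only overrides the steps it needs; the others fall back to the defaults added below. A sketch (hypothetical class name; the mean aggregation mirrors the default `aggregate` in this diff):

```python
from molcraft import layers, tensors

class MeanConv(layers.GraphConv):
    """Hypothetical GraphConv subclass overriding only `aggregate`."""

    def aggregate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
        # Average incoming messages per target node, then clear the messages.
        aggregate = tensor.aggregate('message', mode='mean')
        return tensor.update(
            {
                'node': {'feature': aggregate},
                'edge': {'message': None},
            }
        )
```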
@@ -253,81 +368,115 @@ class GraphConv(GraphLayer):
             self._residual_projection = self.get_dense(
                 self.units, name='residual_projection'
             )
-
-
-
+
+        skip_connection = str(self._skip_connection).lower()
+        self._use_weighted_skip_connection = skip_connection.startswith('weight')
+        if self._use_weighted_skip_connection:
+            self._skip_connection_weight = self.add_weight(
+                name='skip_connection_weight',
+                shape=(),
+                initializer='ones',
+                trainable=True,
             )
-            self._aggregation_norm.build([None, self.units])

-
+        if self._normalization:
+            if str(self._normalization).lower().startswith('batch'):
+                self._output_norm = keras.layers.BatchNormalization(
+                    name='output_batch_norm'
+                )
+            else:
+                self._output_norm = keras.layers.LayerNormalization(
+                    name='output_layer_norm'
+                )
+
+        self._has_edge_feature = 'edge' in spec.edge
+
+        has_overridden_message = self.__class__.message != GraphConv.message
+        if not has_overridden_message:
+            self._message_dense = self.get_dense(self.units)
+
+        has_overridden_update = self.__class__.update != GraphConv.update
+        if not has_overridden_update:
+            self._output_dense = self.get_dense(self.units)
+            self._output_activation = self._activation

-    @abc.abstractmethod
     def message(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
         """Compute messages.

-        This method
+        This method may be overridden by subclass.

-
+        Arguments:
             tensor:
                 The inputted `GraphTensor` instance.
         """
-
-
+        if not self._has_edge_feature:
+            message_feature = tensor.gather('feature', 'source')
+        else:
+            message_feature = keras.ops.concatenate(
+                [
+                    tensor.gather('feature', 'source'),
+                    tensor.edge['feature']
+                ],
+                axis=-1
+            )
+        message = self._message_dense(message_feature)
+        return tensor.update(
+            {
+                'edge': {
+                    'message': message
+                }
+            }
+        )
+
     def aggregate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
         """Aggregates messages.

-        This method
+        This method may be overridden by subclass.

-
+        Arguments:
             tensor:
                 A `GraphTensor` instance containing a message.
         """
+        aggregate = tensor.aggregate('message', mode='mean')
+        return tensor.update(
+            {
+                'node': {
+                    'feature': aggregate,
+                    'previous_feature': tensor.node['feature']
+                },
+                'edge': {
+                    'message': None
+                }
+            }
+        )

-    @abc.abstractmethod
     def update(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
         """Updates nodes.

-        This method
+        This method may be overridden by subclass.

-
+        Arguments:
             tensor:
                 A `GraphTensor` instance containing aggregated messages
                 (updated node features).
         """
-
-
-
-
-
-
-
-
-
-
-
-
-        if self._skip_connection:
-            input_node_feature = tensor.node['feature']
-            if self._project_input_node_feature:
-                input_node_feature = self._residual_projection(input_node_feature)
-
-        tensor = self.message(tensor)
-        tensor = self.aggregate(tensor)
-
-        if self._normalize_aggregate:
-            normalized_node_feature = self._aggregation_norm(tensor.node['feature'])
-            tensor = tensor.update({'node': {'feature': normalized_node_feature}})
-
-        tensor = self.update(tensor)
-
-        if not self._skip_connection:
-            return tensor
-
-        updated_node_feature = tensor.node['feature']
+        if not 'previous_feature' in tensor.node:
+            feature = tensor.node['feature']
+        else:
+            feature = keras.ops.concatenate(
+                [
+                    tensor.node['feature'],
+                    tensor.node['previous_feature']
+                ],
+                axis=-1
+            )
+        update = self._output_dense(feature)
+        update = self._output_activation(update)
         return tensor.update(
             {
                 'node': {
-                    'feature':
+                    'feature': update,
+                    'previous_feature': None,
                 }
             }
         )
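Hypothetical constructor usage of the options wired up above (`GraphConv` is concrete now that the `abc.abstractmethod` markers are gone; the option values come from the docstring):

```python
from molcraft import layers

# Weighted skip connection: the residual is scaled by a learnable scalar.
conv = layers.GraphConv(units=128, skip_connection='weighted')

# Batch normalization on the node-feature output instead of layer norm.
conv = layers.GraphConv(units=128, normalization='batch_norm')
```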
@@ -336,7 +485,8 @@ class GraphConv(GraphLayer):
         config = super().get_config()
         config.update({
             'units': self.units,
-            '
+            'activation': keras.activations.serialize(self._activation),
+            'normalization': self._normalization,
             'skip_connection': self._skip_connection,
         })
         return config
@@ -346,6 +496,33 @@ class GraphConv(GraphLayer):
 class GIConv(GraphConv):

     """Graph isomorphism network layer.
+
+    >>> graph = molcraft.tensors.GraphTensor(
+    ...     context={
+    ...         'size': [2]
+    ...     },
+    ...     node={
+    ...         'feature': [[1.], [2.]]
+    ...     },
+    ...     edge={
+    ...         'source': [0, 1],
+    ...         'target': [1, 0],
+    ...     }
+    ... )
+    >>> conv = molcraft.layers.GIConv(units=4)
+    >>> conv(graph)
+    GraphTensor(
+        context={
+            'size': <tf.Tensor: shape=[1], dtype=int32>
+        },
+        node={
+            'feature': <tf.Tensor: shape=[2, 4], dtype=float32>
+        },
+        edge={
+            'source': <tf.Tensor: shape=[2], dtype=int32>,
+            'target': <tf.Tensor: shape=[2], dtype=int32>
+        }
+    )
     """

     def __init__(
@@ -353,24 +530,20 @@ class GIConv(GraphConv):
         units: int,
         activation: keras.layers.Activation | str | None = 'relu',
         use_bias: bool = True,
-
-        dropout: float = 0.0,
+        normalization: bool = False,
         update_edge_feature: bool = True,
         **kwargs,
     ):
         super().__init__(
             units=units,
-
+            activation=activation,
+            normalization=normalization,
             use_bias=use_bias,
             **kwargs
         )
-        self._activation = keras.activations.get(activation)
-        self._dropout = dropout
         self._update_edge_feature = update_edge_feature

-    def
-        """Builds the layer.
-        """
+    def build(self, spec: tensors.GraphTensor.Spec) -> None:
         node_feature_dim = spec.node['feature'].shape[-1]

         self.epsilon = self.add_weight(
@@ -395,25 +568,16 @@ class GIConv(GraphConv):

         if self._update_edge_feature:
             self._edge_dense = self.get_dense(node_feature_dim)
-            self._edge_dense.build([None, edge_feature_dim])
         else:
             self._update_edge_feature = False
-
-        self._feedforward_intermediate_dense = self.get_dense(self.units)
-        self._feedforward_intermediate_dense.build([None, node_feature_dim])

         has_overridden_update = self.__class__.update != GIConv.update
         if not has_overridden_update:
+            self._feedforward_intermediate_dense = self.get_dense(self.units)
             self._feedforward_activation = self._activation
-            self._feedforward_dropout = keras.layers.Dropout(self._dropout)
             self._feedforward_output_dense = self.get_dense(self.units)
-            self._feedforward_output_dense.build([None, self.units])
-
-        self.built = True

     def message(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
-        """Computes messages.
-        """
         message = tensor.gather('feature', 'source')
         edge_feature = tensor.edge.get('feature')
         if self._update_edge_feature:
@@ -430,11 +594,8 @@ class GIConv(GraphConv):
         )

     def aggregate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
-
-        """
-        node_feature = tensor.aggregate('message')
+        node_feature = tensor.aggregate('message', mode='mean')
         node_feature += (1 + self.epsilon) * tensor.node['feature']
-        node_feature = self._feedforward_intermediate_dense(node_feature)
         return tensor.update(
             {
                 'node': {
@@ -447,11 +608,9 @@ class GIConv(GraphConv):
         )

     def update(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
-        """Updates nodes.
-        """
         node_feature = tensor.node['feature']
+        node_feature = self._feedforward_intermediate_dense(node_feature)
         node_feature = self._feedforward_activation(node_feature)
-        node_feature = self._feedforward_dropout(node_feature)
         node_feature = self._feedforward_output_dense(node_feature)
         return tensor.update(
             {
@@ -464,8 +623,6 @@ class GIConv(GraphConv):
     def get_config(self) -> dict:
         config = super().get_config()
         config.update({
-            'activation': keras.activations.serialize(self._activation),
-            'dropout': self._dropout,
             'update_edge_feature': self._update_edge_feature
         })
         return config
@@ -475,6 +632,33 @@ class GIConv(GraphConv):
 class GAConv(GraphConv):

     """Graph attention network layer.
+
+    >>> graph = molcraft.tensors.GraphTensor(
+    ...     context={
+    ...         'size': [2]
+    ...     },
+    ...     node={
+    ...         'feature': [[1.], [2.]]
+    ...     },
+    ...     edge={
+    ...         'source': [0, 1],
+    ...         'target': [1, 0],
+    ...     }
+    ... )
+    >>> conv = molcraft.layers.GAConv(units=4, heads=2)
+    >>> conv(graph)
+    GraphTensor(
+        context={
+            'size': <tf.Tensor: shape=[1], dtype=int32>
+        },
+        node={
+            'feature': <tf.Tensor: shape=[2, 4], dtype=float32>
+        },
+        edge={
+            'source': <tf.Tensor: shape=[2], dtype=int32>,
+            'target': <tf.Tensor: shape=[2], dtype=int32>
+        }
+    )
     """

     def __init__(
@@ -483,8 +667,7 @@ class GAConv(GraphConv):
         heads: int = 8,
         activation: keras.layers.Activation | str | None = "relu",
         use_bias: bool = True,
-
-        dropout: float = 0.0,
+        normalization: bool = False,
         update_edge_feature: bool = True,
         attention_activation: keras.layers.Activation | str | None = "leaky_relu",
         **kwargs,
@@ -492,17 +675,15 @@ class GAConv(GraphConv):
         kwargs['skip_connection'] = False
         super().__init__(
             units=units,
-
+            activation=activation,
             use_bias=use_bias,
+            normalization=normalization,
             **kwargs
         )
         self._heads = heads
         if self.units % self.heads != 0:
             raise ValueError(f"units need to be divisible by heads.")
         self._head_units = self.units // self.heads
-        self._activation = keras.activations.get(activation)
-        self._dropout = dropout
-        self._normalize = normalize
         self._update_edge_feature = update_edge_feature
         self._attention_activation = keras.activations.get(attention_activation)

@@ -514,48 +695,33 @@ class GAConv(GraphConv):
     def head_units(self):
         return self._head_units

-    def
-
-        node_feature_dim = spec.node['feature'].shape[-1]
-        attn_feature_dim = node_feature_dim + node_feature_dim
-
+    def build(self, spec: tensors.GraphTensor.Spec) -> None:
         self._has_edge_feature = 'feature' in spec.edge
-
-
-
-
-
-                'ijh,jkh->ikh', (self.head_units, self.heads)
-            )
-            self._edge_dense.build([None, self.head_units, self.heads])
-        else:
-            self._update_edge_feature = False
-
+        self._update_edge_feature = self._has_edge_feature and self._update_edge_feature
+        if self._update_edge_feature:
+            self._edge_dense = self.get_einsum_dense(
+                'ijh,jkh->ikh', (self.head_units, self.heads)
+            )
         self._node_dense = self.get_einsum_dense(
             'ij,jkh->ikh', (self.head_units, self.heads)
         )
-        self._node_dense.build([None, node_feature_dim])
-
         self._feature_dense = self.get_einsum_dense(
             'ij,jkh->ikh', (self.head_units, self.heads)
         )
-        self._feature_dense.build([None, attn_feature_dim])
-
         self._attention_dense = self.get_einsum_dense(
             'ijh,jkh->ikh', (1, self.heads)
         )
-        self._attention_dense.build([None, self.head_units, self.heads])
-
         self._node_self_dense = self.get_einsum_dense(
             'ij,jkh->ikh', (self.head_units, self.heads)
         )
-        self._node_self_dense.build([None, node_feature_dim])
-        self._dropout_layer = keras.layers.Dropout(self._dropout)

-        self.
+        has_overridden_update = self.__class__.update != GAConv.update
+        if not has_overridden_update:
+            self._feedforward_intermediate_dense = self.get_dense(self.units)
+            self._feedforward_activation = self._activation
+            self._feedforward_output_dense = self.get_dense(self.units)

     def message(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
-
         attention_feature = keras.ops.concatenate(
             [
                 tensor.gather('feature', 'source'),
@@ -598,9 +764,8 @@ class GAConv(GraphConv):
         )

     def aggregate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
-        node_feature = tensor.aggregate('message')
+        node_feature = tensor.aggregate('message', mode='sum')
         node_feature += self._node_self_dense(tensor.node['feature'])
-        node_feature = self._dropout_layer(node_feature)
         node_feature = keras.ops.reshape(node_feature, (-1, self.units))
         return tensor.update(
             {
@@ -615,7 +780,10 @@ class GAConv(GraphConv):
         )

     def update(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
-        node_feature =
+        node_feature = tensor.node['feature']
+        node_feature = self._feedforward_intermediate_dense(node_feature)
+        node_feature = self._feedforward_activation(node_feature)
+        node_feature = self._feedforward_output_dense(node_feature)
         return tensor.update(
             {
                 'node': {
@@ -628,8 +796,6 @@ class GAConv(GraphConv):
         config = super().get_config()
         config.update({
             "heads": self._heads,
-            'activation': keras.activations.serialize(self._activation),
-            'dropout': self._dropout,
             'update_edge_feature': self._update_edge_feature,
             'attention_activation': keras.activations.serialize(self._attention_activation),
         })
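Both attention layers split `units` evenly across `heads` (`head_units = units // heads`), so `units` must be divisible by `heads`; a quick hypothetical check:

```python
from molcraft import layers

conv = layers.GAConv(units=128, heads=8)  # head_units = 128 // 8 = 16
try:
    layers.GAConv(units=100, heads=8)     # 100 % 8 != 0
except ValueError as error:
    print(error)                          # units need to be divisible by heads.
```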
@@ -640,6 +806,34 @@ class GAConv(GraphConv):
 class GTConv(GraphConv):

     """Graph transformer layer.
+
+    >>> graph = molcraft.tensors.GraphTensor(
+    ...     context={
+    ...         'size': [2]
+    ...     },
+    ...     node={
+    ...         'feature': [[1.], [2.]]
+    ...     },
+    ...     edge={
+    ...         'source': [0, 1],
+    ...         'target': [1, 0],
+    ...     }
+    ... )
+    >>> conv = molcraft.layers.GTConv(units=4, heads=2)
+    >>> conv(graph)
+    GraphTensor(
+        context={
+            'size': <tf.Tensor: shape=[1], dtype=int32>
+        },
+        node={
+            'feature': <tf.Tensor: shape=[2, 4], dtype=float32>
+        },
+        edge={
+            'source': <tf.Tensor: shape=[2], dtype=int32>,
+            'target': <tf.Tensor: shape=[2], dtype=int32>
+        }
+    )
+
     """

     def __init__(
@@ -648,26 +842,22 @@ class GTConv(GraphConv):
         heads: int = 8,
         activation: keras.layers.Activation | str | None = "relu",
         use_bias: bool = True,
-
-        dropout: float = 0.0,
+        normalization: bool = False,
         attention_dropout: float = 0.0,
         **kwargs,
     ) -> None:
-        kwargs['skip_connection'] = False
         super().__init__(
             units=units,
-
+            activation=activation,
             use_bias=use_bias,
+            normalization=normalization,
             **kwargs
         )
         self._heads = heads
         if self.units % self.heads != 0:
             raise ValueError(f"units need to be divisible by heads.")
         self._head_units = self.units // self.heads
-        self._activation = keras.activations.get(activation)
-        self._dropout = dropout
         self._attention_dropout = attention_dropout
-        self._normalize = normalize

     @property
     def heads(self):
@@ -677,69 +867,31 @@ class GTConv(GraphConv):
     def head_units(self):
         return self._head_units

-    def
-        """Builds the layer.
-        """
-        node_feature_dim = spec.node['feature'].shape[-1]
-        self.project_residual = node_feature_dim != self.units
-        if self.project_residual:
-            warn(
-                '`GTConv` uses residual connections, but found incompatible dim '
-                'between input (node feature dim) and output (`self.units`). '
-                'Automatically applying a projection layer to residual to '
-                'match input and output. '
-            )
-            self._residual_dense = self.get_dense(self.units)
-            self._residual_dense.build([None, node_feature_dim])
-
+    def build(self, spec: tensors.GraphTensor.Spec) -> None:
         self._query_dense = self.get_einsum_dense(
             'ij,jkh->ikh', (self.head_units, self.heads)
         )
-        self._query_dense.build([None, node_feature_dim])
-
         self._key_dense = self.get_einsum_dense(
             'ij,jkh->ikh', (self.head_units, self.heads)
         )
-        self._key_dense.build([None, node_feature_dim])
-
         self._value_dense = self.get_einsum_dense(
             'ij,jkh->ikh', (self.head_units, self.heads)
         )
-        self._value_dense.build([None, node_feature_dim])
-
         self._output_dense = self.get_dense(self.units)
-        self._output_dense.build([None, self.units])
-
         self._softmax_dropout = keras.layers.Dropout(self._attention_dropout)

-        self._self_attention_dropout = keras.layers.Dropout(self._dropout)
-
         self._add_bias = not 'bias' in spec.edge

         if self._add_bias:
             self._edge_bias = EdgeBias(biases=self.heads)
-            self._edge_bias.build_from_spec(spec)

         has_overridden_update = self.__class__.update != GTConv.update
         if not has_overridden_update:
-
-            if self._normalize:
-                self._feedforward_output_norm = keras.layers.LayerNormalization()
-                self._feedforward_output_norm.build([None, self.units])
-
-            self._feedforward_dropout = keras.layers.Dropout(self._dropout)
-
             self._feedforward_intermediate_dense = self.get_dense(self.units)
-            self.
-
+            self._feedforward_activation = self._activation
             self._feedforward_output_dense = self.get_dense(self.units)
-            self._feedforward_output_dense.build([None, self.units])
-
-        self.built = True

     def message(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
-        """Computes messages.
-        """
         if self._add_bias:
             edge_bias = self._edge_bias(tensor)
             tensor = tensor.update(
@@ -764,7 +916,6 @@ class GTConv(GraphConv):
         attention_score /= keras.ops.sqrt(float(self.head_units))

         attention_score += keras.ops.expand_dims(tensor.edge['bias'], axis=1)
-
         attention = ops.edge_softmax(attention_score, tensor.edge['target'])
         attention = self._softmax_dropout(attention)

@@ -778,12 +929,9 @@ class GTConv(GraphConv):
         )

     def aggregate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
-
-        """
-        node_feature = tensor.aggregate('message')
+        node_feature = tensor.aggregate('message', mode='sum')
         node_feature = keras.ops.reshape(node_feature, (-1, self.units))
         node_feature = self._output_dense(node_feature)
-        node_feature = self._self_attention_dropout(node_feature)
         return tensor.update(
             {
                 'node': {
@@ -798,26 +946,10 @@ class GTConv(GraphConv):
         )

     def update(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
-        """Updates nodes.
-        """
         node_feature = tensor.node['feature']
-
-        residual = tensor.node['residual']
-        if self.project_residual:
-            residual = self._residual_dense(residual)
-
-        node_feature += residual
-        residual = node_feature
-
         node_feature = self._feedforward_intermediate_dense(node_feature)
-        node_feature = self.
+        node_feature = self._feedforward_activation(node_feature)
         node_feature = self._feedforward_output_dense(node_feature)
-        node_feature = self._feedforward_dropout(node_feature)
-        if self._normalize:
-            node_feature = self._feedforward_output_norm(node_feature)
-
-        node_feature += residual
-
         return tensor.update(
             {
                 'node': {
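As reconstructed from `GTConv.message` above, the per-edge attention is scaled dot-product attention with an additive edge bias, normalized over each target node's incoming edges; with d denoting `head_units`, q the source query and k the target key:

```latex
\alpha_{ij} = \operatorname{softmax}_{i \in \mathcal{N}(j)}
    \left( \frac{q_i^{\top} k_j}{\sqrt{d}} + b_{ij} \right)
```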
@@ -830,148 +962,48 @@ class GTConv(GraphConv):
         config = super().get_config()
         config.update({
             "heads": self._heads,
-            'activation': keras.activations.serialize(self._activation),
-            'dropout': self._dropout,
             'attention_dropout': self._attention_dropout,
         })
         return config


 @keras.saving.register_keras_serializable(package='molcraft')
-class
+class MPConv(GraphConv):

-    """
+    """Message passing neural network layer.
     """

-    def
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        node_feature = tensor.node['feature']
+    def __init__(
+        self,
+        units: int = 128,
+        activation: keras.layers.Activation | str | None = None,
+        use_bias: bool = True,
+        normalization: bool = False,
+        **kwargs
+    ) -> None:
+        super().__init__(
+            units=units,
+            activation=activation,
+            use_bias=use_bias,
+            normalization=normalization,
+            **kwargs
+        )

-
-
-
-
+    def build(self, spec: tensors.GraphTensor.Spec) -> None:
+        node_feature_dim = spec.node['feature'].shape[-1]
+        self.message_fn = self.get_dense(self.units, activation=self._activation)
+        self.update_fn = keras.layers.GRUCell(self.units)
+        self._has_edge_feature = 'feature' in spec.edge
+        self.project_input_node_feature = node_feature_dim != self.units
+        if self.project_input_node_feature:
+            warn(
+                'Input node feature dim does not match updated node feature dim. '
+                'To make sure input node feature can be passed as `states` to the '
+                'GRU cell, it will automatically be projected prior to it.'
+            )
+            self._previous_node_dense = self.get_dense(
+                self.units, activation=self._activation
             )
-        node_feature += self._centrality_dense(centrality)
-
-        edge_bias = self._edge_bias(tensor) + self._gaussian_edge_bias(gaussian)
-        tensor = tensor.update({'edge': {'bias': edge_bias}})
-
-        query = self._query_dense(node_feature)
-        key = self._key_dense(node_feature)
-        value = self._value_dense(node_feature)
-
-        query = ops.gather(query, tensor.edge['source'])
-        key = ops.gather(key, tensor.edge['target'])
-        value = ops.gather(value, tensor.edge['source'])
-
-        attention_score = keras.ops.sum(query * key, axis=1, keepdims=True)
-        attention_score /= keras.ops.sqrt(float(self.head_units))
-
-        attention_score += keras.ops.expand_dims(tensor.edge['bias'], axis=1)
-
-        attention = ops.edge_softmax(attention_score, tensor.edge['target'])
-        attention = self._softmax_dropout(attention)
-
-        distance = keras.ops.subtract(
-            tensor.gather('coordinate', 'source'),
-            tensor.gather('coordinate', 'target')
-        )
-        euclidean_distance = ops.euclidean_distance(
-            tensor.gather('coordinate', 'source'),
-            tensor.gather('coordinate', 'target'),
-            axis=-1
-        )
-        distance /= euclidean_distance
-
-        attention *= keras.ops.expand_dims(distance, axis=-1)
-        attention = keras.ops.expand_dims(attention, axis=2)
-        value = keras.ops.expand_dims(value, axis=1)
-
-        return tensor.update(
-            {
-                'edge': {
-                    'message': value,
-                    'weight': attention,
-                },
-            }
-        )
-
-    def aggregate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
-        """Aggregates messages.
-        """
-        node_feature = tensor.aggregate('message')
-        node_feature = keras.ops.reshape(
-            node_feature, (tensor.num_nodes, -1, self.units)
-        )
-        node_feature = self._output_dense(node_feature)
-        node_feature = keras.ops.sum(node_feature, axis=1)
-        node_feature = self._self_attention_dropout(node_feature)
-        return tensor.update(
-            {
-                'node': {
-                    'feature': node_feature,
-                    'residual': tensor.node['feature']
-                },
-                'edge': {
-                    'message': None,
-                    'weight': None,
-                }
-            }
-        )
-
-
-@keras.saving.register_keras_serializable(package='molcraft')
-class MPConv(GraphConv):
-
-    """Message passing neural network layer.
-    """
-
-    def __init__(
-        self,
-        units: int = 128,
-        activation: keras.layers.Activation | str | None = None,
-        use_bias: bool = True,
-        normalize: bool = True,
-        dropout: float = 0.0,
-        **kwargs
-    ) -> None:
-        super().__init__(
-            units=units,
-            normalize=normalize,
-            use_bias=use_bias,
-            **kwargs
-        )
-        self._activation = keras.activations.get(activation)
-        self._dropout = dropout or 0.0
-
-    def build_from_spec(self, spec: tensors.GraphTensor.Spec) -> None:
-        node_feature_dim = spec.node['feature'].shape[-1]
-        self.message_fn = self.get_dense(self.units, activation=self._activation)
-        self.update_fn = keras.layers.GRUCell(self.units)
-        self._has_edge_feature = 'feature' in spec.edge
-        self.project_input_node_feature = node_feature_dim != self.units
-        if self.project_input_node_feature:
-            warn(
-                'Input node feature dim does not match updated node feature dim. '
-                'To make sure input node feature can be passed as `states` to the '
-                'GRU cell, it will automatically be projected prior to it.'
-            )
-        self._previous_node_dense = self.get_dense(self.units, activation=self._activation)
-        self.built = True

     def message(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
         feature = keras.ops.concatenate(
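`MPConv` now funnels the mean-aggregated messages through a GRU cell, with the (optionally projected) input node features as the recurrent state. A standalone shape sketch of that update step (hypothetical sizes; not molcraft code, and the layer's exact call may differ):

```python
import keras

units, num_nodes = 128, 5
cell = keras.layers.GRUCell(units)

aggregate = keras.ops.zeros((num_nodes, units))  # mean of incoming messages
previous = keras.ops.zeros((num_nodes, units))   # (projected) input node features

# The GRU cell treats each node as one timestep:
# input = aggregated messages, state = previous node features.
updated, _ = cell(aggregate, states=[previous])  # -> shape (5, 128)
```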
@@ -999,7 +1031,7 @@ class MPConv(GraphConv):
         )

     def aggregate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
-        aggregate = tensor.aggregate('message')
+        aggregate = tensor.aggregate('message', mode='mean')
         previous = tensor.node['feature']
         if self.project_input_node_feature:
             previous = self._previous_node_dense(previous)
@@ -1028,17 +1060,105 @@ class MPConv(GraphConv):

     def get_config(self) -> dict:
         config = super().get_config()
-        config.update({
-            'activation': keras.activations.serialize(self._activation),
-            'dropout': self._dropout,
-        })
+        config.update({})
         return config


+@keras.saving.register_keras_serializable(package='molcraft')
+class GTConv3D(GTConv):
+
+    """Graph transformer layer 3D.
+    """
+
+    def build(self, spec: tensors.GraphTensor.Spec) -> None:
+        """Builds the layer.
+        """
+        super().build(spec)
+        if self._add_bias:
+            node_feature_dim = spec.node['feature'].shape[-1]
+            kernels = self.units
+            self._gaussian_basis = GaussianDistance(kernels)
+            self._centrality_dense = self.get_dense(units=node_feature_dim)
+            self._gaussian_edge_bias = self.get_dense(self.heads)
+
+    def message(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
+        node_feature = tensor.node['feature']
+
+        if self._add_bias:
+            gaussian = self._gaussian_basis(tensor)
+            centrality = keras.ops.segment_sum(
+                gaussian, tensor.edge['target'], tensor.num_nodes
+            )
+            node_feature += self._centrality_dense(centrality)
+
+            edge_bias = self._edge_bias(tensor) + self._gaussian_edge_bias(gaussian)
+            tensor = tensor.update({'edge': {'bias': edge_bias}})
+
+        query = self._query_dense(node_feature)
+        key = self._key_dense(node_feature)
+        value = self._value_dense(node_feature)
+
+        query = ops.gather(query, tensor.edge['source'])
+        key = ops.gather(key, tensor.edge['target'])
+        value = ops.gather(value, tensor.edge['source'])
+
+        attention_score = keras.ops.sum(query * key, axis=1, keepdims=True)
+        attention_score /= keras.ops.sqrt(float(self.head_units))
+
+        attention_score += keras.ops.expand_dims(tensor.edge['bias'], axis=1)
+
+        attention = ops.edge_softmax(attention_score, tensor.edge['target'])
+        attention = self._softmax_dropout(attention)
+
+        distance = keras.ops.subtract(
+            tensor.gather('coordinate', 'source'),
+            tensor.gather('coordinate', 'target')
+        )
+        euclidean_distance = ops.euclidean_distance(
+            tensor.gather('coordinate', 'source'),
+            tensor.gather('coordinate', 'target'),
+            axis=-1
+        )
+        distance /= euclidean_distance
+
+        attention *= keras.ops.expand_dims(distance, axis=-1)
+        attention = keras.ops.expand_dims(attention, axis=2)
+        value = keras.ops.expand_dims(value, axis=1)
+
+        return tensor.update(
+            {
+                'edge': {
+                    'message': value,
+                    'weight': attention,
+                },
+            }
+        )
+
+    def aggregate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
+        node_feature = tensor.aggregate('message', mode='sum')
+        node_feature = keras.ops.reshape(
+            node_feature, (tensor.num_nodes, -1, self.units)
+        )
+        node_feature = self._output_dense(node_feature)
+        node_feature = keras.ops.sum(node_feature, axis=1)
+        return tensor.update(
+            {
+                'node': {
+                    'feature': node_feature,
+                    'residual': tensor.node['feature']
+                },
+                'edge': {
+                    'message': None,
+                    'weight': None,
+                }
+            }
+        )
+
+
 @keras.saving.register_keras_serializable(package='molcraft')
 class MPConv3D(MPConv):

-    """
+    """Message passing neural network layer 3D.
     """

     def message(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
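The new `GTConv3D.message` reweights attention by the unit vector between the endpoints' coordinates, so the summed messages become direction-aware. As reconstructed from the code above, with x_i the source and x_j the target coordinate:

```latex
\hat{\alpha}_{ij} = \alpha_{ij} \cdot
    \frac{x_i - x_j}{\lVert x_i - x_j \rVert}
```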
@@ -1076,7 +1196,7 @@ class MPConv3D(MPConv):
|
|
|
1076
1196
|
@keras.saving.register_keras_serializable(package='molcraft')
|
|
1077
1197
|
class EGConv3D(GraphConv):
|
|
1078
1198
|
|
|
1079
|
-
"""Equivariant graph neural network layer.
|
|
1199
|
+
"""Equivariant graph neural network layer 3D.
|
|
1080
1200
|
"""
|
|
1081
1201
|
|
|
1082
1202
|
def __init__(
|
|
@@ -1084,49 +1204,33 @@ class EGConv3D(GraphConv):
|
|
|
1084
1204
|
units: int = 128,
|
|
1085
1205
|
activation: keras.layers.Activation | str | None = None,
|
|
1086
1206
|
use_bias: bool = True,
|
|
1087
|
-
|
|
1088
|
-
dropout: float = 0.0,
|
|
1207
|
+
normalization: bool = False,
|
|
1089
1208
|
**kwargs
|
|
1090
1209
|
) -> None:
|
|
1091
1210
|
super().__init__(
|
|
1092
1211
|
units=units,
|
|
1093
|
-
|
|
1212
|
+
activation=activation,
|
|
1094
1213
|
use_bias=use_bias,
|
|
1214
|
+
normalization=normalization,
|
|
1095
1215
|
**kwargs
|
|
1096
1216
|
)
|
|
1097
|
-
self._activation = keras.activations.get(activation)
|
|
1098
|
-
self._dropout = dropout or 0.0
|
|
1099
1217
|
|
|
1100
|
-
def
|
|
1218
|
+
def build(self, spec: tensors.GraphTensor.Spec) -> None:
|
|
1101
1219
|
if 'coordinate' not in spec.node:
|
|
1102
1220
|
raise ValueError(
|
|
1103
1221
|
'Could not find `coordinate`s in node, '
|
|
1104
1222
|
'which is required for Conv3D layers.'
|
|
1105
1223
|
)
|
|
1106
|
-
|
|
1107
|
-
feature_dim = node_feature_dim + node_feature_dim + 1
|
|
1108
|
-
if 'feature' in spec.edge:
|
|
1109
|
-
self._has_edge_feature = True
|
|
1110
|
-
edge_feature_dim = spec.edge['feature'].shape[-1]
|
|
1111
|
-
feature_dim += edge_feature_dim
|
|
1112
|
-
else:
|
|
1113
|
-
self._has_edge_feature = False
|
|
1114
|
-
|
|
1224
|
+
self._has_edge_feature = 'feature' in spec.edge
|
|
1115
1225
|
self.message_fn = self.get_dense(self.units, activation=self._activation)
|
|
1116
|
-
self.message_fn.build([None, feature_dim])
|
|
1117
1226
|
self.dense_position = self.get_dense(1)
|
|
1118
|
-
self.dense_position.build([None, self.units])
|
|
1119
1227
|
|
|
1120
1228
|
has_overridden_update = self.__class__.update != EGConv3D.update
|
|
1121
1229
|
if not has_overridden_update:
|
|
1122
1230
|
self.update_fn = self.get_dense(self.units, activation=self._activation)
|
|
1123
|
-
self.
|
|
1124
|
-
self._dropout_layer = keras.layers.Dropout(self._dropout)
|
|
1125
|
-
self.built = True
|
|
1231
|
+
self.output_dense = self.get_dense(self.units)
|
|
1126
1232
|
|
|
1127
1233
|
def message(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
|
|
1128
|
-
"""Computes messages.
|
|
1129
|
-
"""
|
|
1130
1234
|
relative_node_coordinate = keras.ops.subtract(
|
|
1131
1235
|
tensor.gather('coordinate', 'target'),
|
|
1132
1236
|
tensor.gather('coordinate', 'source')
|
|
@@ -1169,8 +1273,6 @@ class EGConv3D(GraphConv):
|
|
|
1169
1273
|
)
|
|
1170
1274
|
|
|
1171
1275
|
def aggregate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
|
|
1172
|
-
"""Aggregates messages.
|
|
1173
|
-
"""
|
|
1174
1276
|
coefficient = keras.ops.bincount(
|
|
1175
1277
|
tensor.edge['source'],
|
|
1176
1278
|
minlength=tensor.num_nodes
|
|
@@ -1185,7 +1287,7 @@ class EGConv3D(GraphConv):
|
|
|
1185
1287
|
updated_coordinate = tensor.aggregate('relative_node_coordinate') * coefficient
|
|
1186
1288
|
updated_coordinate += tensor.node['coordinate']
|
|
1187
1289
|
|
|
1188
|
-
aggregate = tensor.aggregate('message')
|
|
1290
|
+
aggregate = tensor.aggregate('message', mode='mean')
|
|
1189
1291
|
return tensor.update(
|
|
1190
1292
|
{
|
|
1191
1293
|
'node': {
|
|
@@ -1201,8 +1303,6 @@ class EGConv3D(GraphConv):
|
|
|
1201
1303
|
)
|
|
1202
1304
|
|
|
1203
1305
|
def update(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
|
|
1204
|
-
"""Updates nodes.
|
|
1205
|
-
"""
|
|
1206
1306
|
updated_node_feature = self.update_fn(
|
|
1207
1307
|
keras.ops.concatenate(
|
|
1208
1308
|
[
|
|
@@ -1212,7 +1312,7 @@ class EGConv3D(GraphConv):
|
|
|
1212
1312
|
axis=-1
|
|
1213
1313
|
)
|
|
1214
1314
|
)
|
|
1215
|
-
updated_node_feature = self.
|
|
1315
|
+
updated_node_feature = self.output_dense(updated_node_feature)
|
|
1216
1316
|
return tensor.update(
|
|
1217
1317
|
{
|
|
1218
1318
|
'node': {
|
|
@@ -1224,66 +1324,46 @@ class EGConv3D(GraphConv):
 
     def get_config(self) -> dict:
         config = super().get_config()
-        config.update({
-            'activation': keras.activations.serialize(self._activation),
-            'dropout': self._dropout,
-        })
+        config.update({})
         return config
 
 
 @keras.saving.register_keras_serializable(package='molcraft')
-class
-
+class Readout(GraphLayer):
+
+    """Readout layer.
     """
-    def __init__(
-        self,
-        units: int = None,
-        activation: str = None,
-        field: str = 'node',
-        **kwargs
-    ) -> None:
-        super().__init__(**kwargs)
-        self.units = units
-        self._activation = keras.activations.get(activation)
-        self.field = field
 
-    def
-
-
-
-
-
-
-
-        self.
-
-
-
+    def __init__(self, mode: str | None = None, **kwargs):
+        kwargs['kernel_initializer'] = None
+        kwargs['bias_initializer'] = None
+        super().__init__(**kwargs)
+        self.mode = mode
+        if str(self.mode).lower().startswith('sum'):
+            self._reduce_fn = keras.ops.segment_sum
+        elif str(self.mode).lower().startswith('max'):
+            self._reduce_fn = keras.ops.segment_max
+        elif str(self.mode).lower().startswith('super'):
+            self._reduce_fn = keras.ops.segment_sum
+        else:
+            self._reduce_fn = ops.segment_mean
 
-    def propagate(self, tensor: tensors.GraphTensor):
-
-
-
-
-
-        return
-
-
-            'feature': feature
-            }
-        }
-        )
+    def propagate(self, tensor: tensors.GraphTensor) -> tf.Tensor:
+        node_feature = tensor.node['feature']
+        if str(self.mode).lower().startswith('super'):
+            node_feature = keras.ops.where(
+                tensor.node['super'][:, None], node_feature, 0.0
+            )
+        return self._reduce_fn(
+            node_feature, tensor.graph_indicator, tensor.num_subgraphs
+        )
 
     def get_config(self) -> dict:
         config = super().get_config()
-        config.
-
-
-            'field': self.field,
-        })
-        return config
+        config['mode'] = self.mode
+        return config
+
 
-
 @keras.saving.register_keras_serializable(package='molcraft')
 class GraphNetwork(GraphLayer):
 
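`Readout` (moved here from the bottom of the module) now does nothing but a segment reduction over node features, keyed by `mode`. The snippet below sketches the `mode='super'` branch on toy arrays, assuming the semantics shown in the hunk: non-super nodes are zeroed and a segment sum then picks out each graph's super-node feature.

    import numpy as np
    import keras

    node_feature = np.array([[1.0], [2.0], [9.0]], dtype='float32')
    is_super = np.array([False, False, True])  # one super node per graph
    graph_indicator = np.array([0, 0, 0])      # all three nodes in graph 0

    masked = keras.ops.where(is_super[:, None], node_feature, 0.0)
    pooled = keras.ops.segment_sum(masked, graph_indicator, num_segments=1)
    print(np.asarray(pooled))  # [[9.]] -- only the super node's feature survives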
@@ -1291,7 +1371,7 @@ class GraphNetwork(GraphLayer):
 
     Sequentially calls graph layers (`GraphLayer`) and concatenates its output.
 
-
+    Arguments:
         layers (list):
             A list of graph layers.
     """
@@ -1301,37 +1381,32 @@ class GraphNetwork(GraphLayer):
         self.layers = layers
         self._update_edge_feature = False
 
-    def
-        """Builds the layer.
-        """
+    def build(self, spec: tensors.GraphTensor.Spec) -> None:
         units = self.layers[0].units
         node_feature_dim = spec.node['feature'].shape[-1]
-
+        self._update_node_feature = node_feature_dim != units
+        if self._update_node_feature:
             warn(
                 'Node feature dim does not match `units` of the first layer. '
                 'Automatically adding a node projection layer to match `units`.'
             )
             self._node_dense = self.get_dense(units)
-
-
-        if has_edge_feature:
+        self._has_edge_feature = 'feature' in spec.edge
+        if self._has_edge_feature:
             edge_feature_dim = spec.edge['feature'].shape[-1]
-
+            self._update_edge_feature = edge_feature_dim != units
+            if self._update_edge_feature:
                 warn(
                     'Edge feature dim does not match `units` of the first layer. '
                     'Automatically adding a edge projection layer to match `units`.'
                 )
                 self._edge_dense = self.get_dense(units)
-            self._update_edge_feature = True
-        self.built = True
 
     def propagate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
-        """Calls the layer.
-        """
         x = tensors.to_dict(tensor)
         if self._update_node_feature:
             x['node']['feature'] = self._node_dense(tensor.node['feature'])
-        if self._update_edge_feature:
+        if self._has_edge_feature and self._update_edge_feature:
             x['edge']['feature'] = self._edge_dense(tensor.edge['feature'])
         outputs = [x['node']['feature']]
         for layer in self.layers:
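The rewritten `build` makes the projection decision explicit: a dense projection is added only when a feature dimension disagrees with the first layer's `units`. A toy illustration of the same check, outside molcraft:

    import numpy as np
    import keras

    units = 128
    node_feature = np.random.rand(5, 64).astype('float32')  # 5 nodes, dim 64

    if node_feature.shape[-1] != units:
        node_dense = keras.layers.Dense(units)  # the auto-added projection
        node_feature = node_dense(node_feature)
    print(node_feature.shape)  # (5, 128): now matches layers[0].units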
@@ -1356,7 +1431,7 @@ class GraphNetwork(GraphLayer):
         Performs the same forward pass as `propagate` but with a `GradientTape`
         watching intermediate node features.
 
-
+        Arguments:
             tensor (tensors.GraphTensor):
                 The graph input.
         """
@@ -1414,26 +1489,25 @@ class NodeEmbedding(GraphLayer):
     def __init__(
         self,
         dim: int = None,
-
+        normalization: bool = False,
         embed_context: bool = True,
+        allow_reconstruction: bool = False,
         allow_masking: bool = True,
         **kwargs
     ) -> None:
         super().__init__(**kwargs)
         self.dim = dim
-        self.
+        self._normalization = normalization
         self._embed_context = embed_context
         self._masking_rate = None
         self._allow_masking = allow_masking
+        self._allow_reconstruction = allow_reconstruction
 
-    def
-        """Builds the layer.
-        """
+    def build(self, spec: tensors.GraphTensor.Spec) -> None:
         feature_dim = spec.node['feature'].shape[-1]
         if not self.dim:
             self.dim = feature_dim
         self._node_dense = self.get_dense(self.dim)
-        self._node_dense.build([None, feature_dim])
 
         self._has_super = 'super' in spec.node
         has_context_feature = 'feature' in spec.context
@@ -1447,17 +1521,18 @@ class NodeEmbedding(GraphLayer):
         if self._embed_context:
             context_feature_dim = spec.context['feature'].shape[-1]
             self._context_dense = self.get_dense(self.dim)
-            self._context_dense.build([None, context_feature_dim])
 
-        if self.
-            self.
-
-
-
+        if self._normalization:
+            if str(self._normalization).lower().startswith('batch'):
+                self._norm = keras.layers.BatchNormalization(
+                    name='output_batch_norm'
+                )
+            else:
+                self._norm = keras.layers.LayerNormalization(
+                    name='output_layer_norm'
+                )
 
     def propagate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
-        """Calls the layer.
-        """
         feature = self._node_dense(tensor.node['feature'])
 
         if self._has_super:
@@ -1487,11 +1562,13 @@ class NodeEmbedding(GraphLayer):
             # Slience warning of 'no gradients for variables'
             feature = feature + (self._mask_feature * 0.0)
 
-        if self.
+        if self._normalization:
             feature = self._norm(feature)
 
-
-
+        if not self._allow_reconstruction:
+            return tensor.update({'node': {'feature': feature}})
+        return tensor.update({'node': {'feature': feature, 'target_feature': feature}})
+
     @property
     def masking_rate(self):
         return self._masking_rate
@@ -1509,9 +1586,10 @@ class NodeEmbedding(GraphLayer):
         config = super().get_config()
         config.update({
             'dim': self.dim,
-            '
+            'normalization': self._normalization,
             'embed_context': self._embed_context,
-            'allow_masking': self._allow_masking
+            'allow_masking': self._allow_masking,
+            'allow_reconstruction': self._allow_reconstruction,
         })
         return config
 
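Taken together, the `NodeEmbedding` hunks replace the old boolean normalize flag with a `normalization` argument (a value starting with 'batch' selects `BatchNormalization`, any other truthy value `LayerNormalization`) and add `allow_reconstruction`, which stores a copy of the embedded features under `node['target_feature']`. A hedged usage sketch, with argument values assumed from this diff rather than from molcraft's documentation:

    from molcraft import layers

    node_embedding = layers.NodeEmbedding(
        dim=128,
        normalization='batch',      # or True for LayerNormalization
        allow_reconstruction=True,  # keeps 'target_feature' for Reconstruction
    )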
@@ -1527,39 +1605,39 @@ class EdgeEmbedding(GraphLayer):
     def __init__(
         self,
         dim: int = None,
-
+        normalization: bool = False,
         allow_masking: bool = True,
         **kwargs
     ) -> None:
         super().__init__(**kwargs)
         self.dim = dim
-        self.
+        self._normalization = normalization
         self._masking_rate = None
         self._allow_masking = allow_masking
 
-    def
-        """Builds the layer.
-        """
+    def build(self, spec: tensors.GraphTensor.Spec) -> None:
         feature_dim = spec.edge['feature'].shape[-1]
         if not self.dim:
             self.dim = feature_dim
         self._edge_dense = self.get_dense(self.dim)
-        self._edge_dense.build([None, feature_dim])
 
         self._has_super = 'super' in spec.edge
         if self._has_super:
             self._super_feature = self.get_weight(shape=[self.dim], name='super_edge_feature')
         if self._allow_masking:
             self._mask_feature = self.get_weight(shape=[self.dim], name='mask_edge_feature')
-        if self._normalize:
-            self._norm = keras.layers.LayerNormalization()
-            self._norm.build([None, self.dim])
 
-        self.
+        if self._normalization:
+            if str(self._normalization).lower().startswith('batch'):
+                self._norm = keras.layers.BatchNormalization(
+                    name='output_batch_norm'
+                )
+            else:
+                self._norm = keras.layers.LayerNormalization(
+                    name='output_layer_norm'
+                )
 
     def propagate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
-        """Calls the layer.
-        """
         feature = self._edge_dense(tensor.edge['feature'])
 
         if self._has_super:
@@ -1584,10 +1662,10 @@ class EdgeEmbedding(GraphLayer):
             # Slience warning of 'no gradients for variables'
             feature = feature + (self._mask_feature * 0.0)
 
-        if self.
+        if self._normalization:
             feature = self._norm(feature)
 
-        return tensor.update({'edge': {'feature': feature}})
+        return tensor.update({'edge': {'feature': feature, 'embedding': feature}})
 
     @property
     def masking_rate(self):
@@ -1606,18 +1684,67 @@ class EdgeEmbedding(GraphLayer):
         config = super().get_config()
         config.update({
             'dim': self.dim,
-            '
+            'normalization': self._normalization,
             'allow_masking': self._allow_masking
         })
         return config
 
 
+@keras.saving.register_keras_serializable(package='molcraft')
+class Projection(GraphLayer):
+    """Base graph projection layer.
+    """
+    def __init__(
+        self,
+        units: int = None,
+        activation: str | keras.layers.Activation | None = None,
+        use_bias: bool = True,
+        field: str = 'node',
+        **kwargs
+    ) -> None:
+        super().__init__(use_bias=use_bias, **kwargs)
+        self.units = units
+        self._activation = keras.activations.get(activation)
+        self.field = field
+
+    def build(self, spec: tensors.GraphTensor.Spec) -> None:
+        data = getattr(spec, self.field, None)
+        if data is None:
+            raise ValueError(f'Could not access field {self.field!r}.')
+        feature_dim = data['feature'].shape[-1]
+        if not self.units:
+            self.units = feature_dim
+        self._dense = self.get_dense(self.units)
+
+    def propagate(self, tensor: tensors.GraphTensor):
+        feature = getattr(tensor, self.field)['feature']
+        feature = self._dense(feature)
+        feature = self._activation(feature)
+        return tensor.update(
+            {
+                self.field: {
+                    'feature': feature
+                }
+            }
+        )
+
+    def get_config(self) -> dict:
+        config = super().get_config()
+        config.update({
+            'units': self.units,
+            'activation': keras.activations.serialize(self._activation),
+            'field': self.field,
+        })
+        return config
+
+
 @keras.saving.register_keras_serializable(package='molcraft')
 class ContextProjection(Projection):
     """Context projection layer.
     """
     def __init__(self, units: int = None, activation: str = None, **kwargs):
-
+        kwargs['field'] = 'context'
+        super().__init__(units=units, activation=activation, **kwargs)
 
 
 @keras.saving.register_keras_serializable(package='molcraft')
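The new `Projection` base class factors the old field-specific projection logic into one layer parameterized by `field`; the subclasses in the following hunks merely pin that argument. A hedged sketch of the resulting API (constructor arguments as they appear in this diff, not verified against molcraft's docs):

    from molcraft import layers

    node_proj = layers.NodeProjection(units=128, activation='relu')
    edge_proj = layers.EdgeProjection(units=128, activation='relu')
    context_proj = layers.ContextProjection(units=128)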
@@ -1625,7 +1752,8 @@ class NodeProjection(Projection):
     """Node projection layer.
     """
     def __init__(self, units: int = None, activation: str = None, **kwargs):
-
+        kwargs['field'] = 'node'
+        super().__init__(units=units, activation=activation, **kwargs)
 
 
 @keras.saving.register_keras_serializable(package='molcraft')
@@ -1633,9 +1761,55 @@ class EdgeProjection(Projection):
     """Edge projection layer.
     """
     def __init__(self, units: int = None, activation: str = None, **kwargs):
-
+        kwargs['field'] = 'edge'
+        super().__init__(units=units, activation=activation, **kwargs)
 
 
+@keras.saving.register_keras_serializable(package='molcraft')
+class Reconstruction(GraphLayer):
+
+    def __init__(
+        self,
+        loss: keras.losses.Loss | str = 'mse',
+        loss_weight: float = 0.5,
+        **kwargs
+    ):
+        super().__init__(**kwargs)
+        self._loss_fn = keras.losses.get(loss)
+        self._loss_weight = loss_weight
+
+    def build(self, spec: tensors.GraphTensor.Spec) -> None:
+        has_target_node_feature = 'target_feature' in spec.node
+        if not has_target_node_feature:
+            raise ValueError(
+                'Could not find `target_feature` in `spec.node`. '
+                'Add a `target_feature` via `NodeEmbedding` by setting '
+                '`allow_reconstruction` to `True`.'
+            )
+        output_dim = spec.node['target_feature'].shape[-1]
+        self._dense = self.get_dense(output_dim)
+
+    def propagate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
+        target_node_feature = tensor.node['target_feature']
+        transformed_node_feature = tensor.node['feature']
+
+        reconstructed_node_feature = self._dense(
+            transformed_node_feature
+        )
+
+        loss = self._loss_fn(
+            target_node_feature, reconstructed_node_feature
+        )
+        self.add_loss(keras.ops.sum(loss) * self._loss_weight)
+        return tensor.update({'node': {'feature': transformed_node_feature}})
+
+    def get_config(self):
+        config = super().get_config()
+        config['loss'] = keras.losses.serialize(self._loss_fn)
+        config['loss_weight'] = self._loss_weight
+        return config
+
+
 @keras.saving.register_keras_serializable(package='molcraft')
 class EdgeBias(GraphLayer):
 
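`Reconstruction` closes the loop opened by `NodeEmbedding(allow_reconstruction=True)`: it regresses the transformed node features back onto the stored `target_feature` and registers `loss_weight * loss` as an auxiliary objective via `add_loss`. A hedged pairing sketch (the layer-list shape is illustrative, not molcraft's documented model-building API):

    from molcraft import layers

    graph_layers = [
        layers.NodeEmbedding(dim=128, allow_reconstruction=True),
        # ... message-passing layers ...
        layers.Reconstruction(loss='mse', loss_weight=0.5),
    ]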
@@ -1643,18 +1817,15 @@ class EdgeBias(GraphLayer):
         super().__init__(**kwargs)
         self.biases = biases
 
-    def
+    def build(self, spec: tensors.GraphTensor.Spec) -> None:
         self._has_edge_length = 'length' in spec.edge
         self._has_edge_feature = 'feature' in spec.edge
         if self._has_edge_feature:
             self._edge_feature_dense = self.get_dense(self.biases)
-            self._edge_feature_dense.build([None, spec.edge['feature'].shape[-1]])
         if self._has_edge_length:
             self._edge_length_dense = self.get_dense(
                 self.biases, kernel_initializer='zeros'
             )
-            self._edge_length_dense.build([None, spec.edge['length'].shape[-1]])
-        self.built = True
 
     def propagate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
         bias = keras.ops.zeros(
@@ -1680,7 +1851,7 @@ class GaussianDistance(GraphLayer):
         super().__init__(**kwargs)
         self.kernels = kernels
 
-    def
+    def build(self, spec: tensors.GraphTensor.Spec) -> None:
         self._loc = self.add_weight(
             shape=[self.kernels],
             initializer='zeros',
@@ -1693,8 +1864,7 @@ class GaussianDistance(GraphLayer):
             dtype='float32',
             trainable=True
         )
-
-
+
     def propagate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
         euclidean_distance = ops.euclidean_distance(
             tensor.gather('coordinate', 'source'),
@@ -1711,49 +1881,6 @@ class GaussianDistance(GraphLayer):
             'kernels': self.kernels,
         })
         return config
-
-
-@keras.saving.register_keras_serializable(package='molcraft')
-class Readout(GraphLayer):
-
-    """Readout layer.
-    """
-
-    def __init__(self, mode: str | None = None, **kwargs):
-        kwargs['kernel_initializer'] = None
-        kwargs['bias_initializer'] = None
-        super().__init__(**kwargs)
-        self.mode = mode
-        if str(self.mode).lower().startswith('sum'):
-            self._reduce_fn = keras.ops.segment_sum
-        elif str(self.mode).lower().startswith('max'):
-            self._reduce_fn = keras.ops.segment_max
-        elif str(self.mode).lower().startswith('super'):
-            self._reduce_fn = keras.ops.segment_sum
-        else:
-            self._reduce_fn = ops.segment_mean
-
-    def build_from_spec(self, spec: tensors.GraphTensor.Spec) -> None:
-        """Builds the layer.
-        """
-        self.built = True
-
-    def propagate(self, tensor: tensors.GraphTensor) -> tf.Tensor:
-        """Calls the layer.
-        """
-        node_feature = tensor.node['feature']
-        if str(self.mode).lower().startswith('super'):
-            node_feature = keras.ops.where(
-                tensor.node['super'][:, None], node_feature, 0.0
-            )
-        return self._reduce_fn(
-            node_feature, tensor.graph_indicator, tensor.num_subgraphs
-        )
-
-    def get_config(self) -> dict:
-        config = super().get_config()
-        config['mode'] = self.mode
-        return config
 
 
 def Input(spec: tensors.GraphTensor.Spec) -> dict: