molcraft 0.1.0a1__py3-none-any.whl → 0.1.0a3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
- molcraft/__init__.py +2 -1
- molcraft/datasets.py +123 -0
- molcraft/experimental/peptides.py +28 -67
- molcraft/features.py +5 -3
- molcraft/featurizers.py +68 -27
- molcraft/layers.py +1299 -647
- molcraft/models.py +35 -5
- molcraft/tensors.py +33 -12
- {molcraft-0.1.0a1.dist-info → molcraft-0.1.0a3.dist-info}/METADATA +68 -1
- molcraft-0.1.0a3.dist-info/RECORD +20 -0
- molcraft-0.1.0a1.dist-info/RECORD +0 -19
- {molcraft-0.1.0a1.dist-info → molcraft-0.1.0a3.dist-info}/WHEEL +0 -0
- {molcraft-0.1.0a1.dist-info → molcraft-0.1.0a3.dist-info}/licenses/LICENSE +0 -0
- {molcraft-0.1.0a1.dist-info → molcraft-0.1.0a3.dist-info}/top_level.txt +0 -0
molcraft/layers.py
CHANGED
@@ -60,25 +60,20 @@ class GraphLayer(keras.layers.Layer):
     May use built-in methods such as `get_weight`, `get_dense` and `get_einsum_dense`.

     Optionally implemented by subclass. If implemented, it is recommended to
-
-
+    If sub-layers are built (via `build` or `build_from_spec`), set `built`
+    to True. If not, symbolic input will be passed through the layer to build them.

     Args:
         spec:
-            A `GraphTensor.Spec` instance, corresponding to the
-
+            A `GraphTensor.Spec` instance, corresponding to the `GraphTensor`
+            passed to `propagate`.
     """
-
+
     def build(self, spec: tensors.GraphTensor.Spec) -> None:

         self._custom_build_config = {'spec': _serialize_spec(spec)}

-        invoke_build_from_spec = (
-            GraphLayer.build_from_spec != self.__class__.build_from_spec
-        )
-        if invoke_build_from_spec:
-            self.build_from_spec(spec)
-            self.built = True
+        self.build_from_spec(spec)

         if not self.built:
             # Automatically build layer or model by calling it on symbolic inputs
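For orientation, a minimal sketch of the contract this hunk documents, assuming the `GraphLayer` API exactly as it appears in this diff (`build_from_spec`, `get_dense`, `propagate`); the subclass name and the 16-unit width are illustrative only:

    from molcraft import layers, tensors

    class ScaleNodes(layers.GraphLayer):

        def build_from_spec(self, spec: tensors.GraphTensor.Spec) -> None:
            # Build sub-layers eagerly from the spec, then mark the layer
            # as built so no symbolic forward pass is needed.
            node_feature_dim = spec.node['feature'].shape[-1]
            self._dense = self.get_dense(16)
            self._dense.build([None, node_feature_dim])
            self.built = True

        def propagate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
            # Transform node features and return the updated graph tensor.
            feature = self._dense(tensor.node['feature'])
            return tensor.update({'node': {'feature': feature}})

If `built` is left False, the base class instead builds the sub-layers by passing symbolic input through the layer, as the new docstring describes.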
@@ -206,12 +201,66 @@ class GraphLayer(keras.layers.Layer):
 class GraphConv(GraphLayer):

     """Base graph neural network layer.
+
+    For normalization and skip connection to work, the `GraphConv` subclass
+    requires the (node feature) output of `aggregate` and `update` to have a
+    dimension of `self.units`, respectively.
+
+    Args:
+        units:
+            The number of units.
+        normalize:
+            Whether `LayerNormalization` should be applied to the (node feature) output
+            of the `aggregate` step. While normalization is recommended, it is not used
+            by default.
+        skip_connection:
+            Whether (node feature) input should be added to the (node feature) output.
+            If (node feature) input dim is not equal to `units`, a projection layer will
+            automatically project the residual before adding it to the output. While skip
+            connection is recommended, it is not used by default.
+        kwargs:
+            See arguments of `GraphLayer`.
     """

-    def __init__(
+    def __init__(
+        self,
+        units: int = None,
+        normalize: bool = False,
+        skip_connection: bool = False,
+        **kwargs
+    ) -> None:
         super().__init__(**kwargs)
         self.units = units
-
+        self._normalize_aggregate = normalize
+        self._skip_connection = skip_connection
+
+    def build(self, spec: tensors.GraphTensor.Spec) -> None:
+        if not self.units:
+            raise ValueError(
+                f'`self.units` needs to be a positive integer. Found: {self.units}.'
+            )
+        node_feature_dim = spec.node['feature'].shape[-1]
+        self._project_input_node_feature = (
+            self._skip_connection and (node_feature_dim != self.units)
+        )
+        if self._project_input_node_feature:
+            warn(
+                '`skip_connection` is set to `True`, but found incompatible dim '
+                'between input (node feature dim) and output (`self.units`). '
+                'Automatically applying a projection layer to residual to '
+                'match input and output. '
+            )
+            self._residual_projection = self.get_dense(
+                self.units, name='residual_projection'
+            )
+        if self._normalize_aggregate:
+            self._aggregation_norm = keras.layers.LayerNormalization(
+                name='aggregation_normalization'
+            )
+            self._aggregation_norm.build([None, self.units])
+
+        super().build(spec)
+
     @abc.abstractmethod
     def message(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
         """Compute messages.
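A hedged usage sketch of the two new constructor arguments, via the `GIConv` subclass added further down in this diff; `units=128` is an arbitrary choice:

    from molcraft import layers

    # LayerNormalization is applied to the output of `aggregate`, and the
    # node-feature input is added back to the output; if the input dim does
    # not equal `units`, `GraphConv.build` warns and projects the residual.
    conv = layers.GIConv(units=128, normalize=True, skip_connection=True)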
@@ -256,206 +305,1123 @@ class GraphConv(GraphLayer):
             tensor:
                 A `GraphTensor` instance.
         """
+
+        if self._skip_connection:
+            input_node_feature = tensor.node['feature']
+            if self._project_input_node_feature:
+                input_node_feature = self._residual_projection(input_node_feature)
+
         tensor = self.message(tensor)
         tensor = self.aggregate(tensor)
-        tensor = self.update(tensor)
-        return tensor

-    def get_config(self) -> dict:
-        config = super().get_config()
-        config.update({
-            'units': self.units
-        })
-        return config
-
-
-@keras.saving.register_keras_serializable(package='molcraft')
-class Projection(GraphLayer):
-    """Base graph projection layer.
-    """
-    def __init__(
-        self,
-        units: int = None,
-        activation: str = None,
-        field: str = 'node',
-        **kwargs
-    ) -> None:
-        super().__init__(**kwargs)
-        self.units = units
-        self._activation = keras.activations.get(activation)
-        self.field = field
+        if self._normalize_aggregate:
+            normalized_node_feature = self._aggregation_norm(tensor.node['feature'])
+            tensor = tensor.update({'node': {'feature': normalized_node_feature}})

-    def build_from_spec(self, spec: tensors.GraphTensor.Spec) -> None:
-        """Builds the layer.
-        """
-        data = getattr(spec, self.field, None)
-        if data is None:
-            raise ValueError(f'Could not access field {self.field!r}.')
-        feature_dim = data['feature'].shape[-1]
-        if not self.units:
-            self.units = feature_dim
-        self._dense = self.get_dense(self.units)
-        self._dense.build([None, feature_dim])
+        tensor = self.update(tensor)

-    def propagate(self, tensor: tensors.GraphTensor):
-        """Calls the layer.
-        """
-        feature = getattr(tensor, self.field)['feature']
-        feature = self._dense(feature)
-        feature = self._activation(feature)
+        if not self._skip_connection:
+            return tensor
+
+        updated_node_feature = tensor.node['feature']
         return tensor.update(
             {
-                self.field: {
-                    'feature': feature
+                'node': {
+                    'feature': updated_node_feature + input_node_feature
                 }
             }
-        )
+        )

     def get_config(self) -> dict:
         config = super().get_config()
         config.update({
             'units': self.units,
-            'activation': keras.activations.serialize(self._activation),
-            'field': self.field,
+            'normalize': self._normalize_aggregate,
+            'skip_connection': self._skip_connection,
         })
         return config


 @keras.saving.register_keras_serializable(package='molcraft')
-class GraphNetwork(GraphLayer):
-
-    """Graph neural network.
+class GIConv(GraphConv):

-    Sequentially calls graph layers (`GraphLayer`) and concatenates its output.
-
-    Args:
-        layers (list):
-            A list of graph layers.
+    """Graph isomorphism network layer.
     """

-    def __init__(
-
-
-
+    def __init__(
+        self,
+        units: int,
+        activation: keras.layers.Activation | str | None = 'relu',
+        use_bias: bool = True,
+        normalize: bool = True,
+        dropout: float = 0.0,
+        update_edge_feature: bool = True,
+        **kwargs,
+    ):
+        super().__init__(
+            units=units,
+            normalize=normalize,
+            use_bias=use_bias,
+            **kwargs
+        )
+        self._activation = keras.activations.get(activation)
+        self._dropout = dropout
+        self._update_edge_feature = update_edge_feature

     def build_from_spec(self, spec: tensors.GraphTensor.Spec) -> None:
         """Builds the layer.
         """
-        units = self.layers[0].units
         node_feature_dim = spec.node['feature'].shape[-1]
-        if node_feature_dim != units:
-            warn(
-                'Node feature dim does not match `units` of the first layer. '
-                'Automatically adding a node projection layer to match `units`.'
-            )
-            self._node_dense = self.get_dense(units)
-            self._update_node_feature = True
-        has_edge_feature = 'feature' in spec.edge
-        if has_edge_feature:
+
+        self.epsilon = self.add_weight(
+            name='epsilon',
+            shape=(),
+            initializer='zeros',
+            trainable=True,
+        )
+
+        self._has_edge_feature = 'feature' in spec.edge
+        if self._has_edge_feature:
             edge_feature_dim = spec.edge['feature'].shape[-1]
-            if edge_feature_dim != units:
-                warn(
-                    'Edge feature dim does not match `units` of the first layer. '
-                    'Automatically adding an edge projection layer to match `units`.'
-                )
-                self._edge_dense = self.get_dense(units)
-                self._update_edge_feature = True

-    def propagate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
-        """Calls the layer.
+            if not self._update_edge_feature:
+                if (edge_feature_dim != node_feature_dim):
+                    warn(
+                        'Found edge feature dim to be incompatible with node feature dim. '
+                        'Automatically adding an edge feature projection layer to match '
+                        'the dim of node features.'
+                    )
+                    self._update_edge_feature = True
+
+            if self._update_edge_feature:
+                self._edge_dense = self.get_dense(node_feature_dim)
+                self._edge_dense.build([None, edge_feature_dim])
+        else:
+            self._update_edge_feature = False
+
+        self._feedforward_intermediate_dense = self.get_dense(self.units)
+        self._feedforward_intermediate_dense.build([None, node_feature_dim])
+
+        has_overridden_update = self.__class__.update != GIConv.update
+        if not has_overridden_update:
+            self._feedforward_activation = self._activation
+            self._feedforward_dropout = keras.layers.Dropout(self._dropout)
+            self._feedforward_output_dense = self.get_dense(self.units)
+            self._feedforward_output_dense.build([None, self.units])
+
+        self.built = True
+
+    def message(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
+        """Computes messages.
         """
-        x = tensors.to_dict(tensor)
-        if self._update_node_feature:
-            x['node']['feature'] = self._node_dense(tensor.node['feature'])
+        message = tensor.gather('feature', 'source')
+        edge_feature = tensor.edge.get('feature')
         if self._update_edge_feature:
-            x['edge']['feature'] = self._edge_dense(tensor.edge['feature'])
-        outputs = [x['node']['feature']]
-        for layer in self.layers:
-            x = layer(x)
-            outputs.append(x['node']['feature'])
+            edge_feature = self._edge_dense(edge_feature)
+        if self._has_edge_feature:
+            message += edge_feature
         return tensor.update(
             {
-                'node': {
-                    'feature': keras.ops.concatenate(outputs, axis=-1)
-                }
+                'edge': {
+                    'message': message,
+                    'feature': edge_feature
+                }
             }
         )
-
-    def tape_propagate(
-        self,
-        tensor: tensors.GraphTensor,
-        tape: tf.GradientTape,
-        training: bool | None = None,
-    ) -> tuple[tensors.GraphTensor, list[tf.Tensor]]:
-        """Performs the propagation with a `GradientTape`.
-
-        Performs the same forward pass as `propagate` but with a `GradientTape`
-        watching intermediate node features.

-        Args:
-            tensor (tensors.GraphTensor):
-                The graph input.
+    def aggregate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
+        """Aggregates messages.
         """
-        if isinstance(tensor, tensors.GraphTensor):
-            x = tensors.to_dict(tensor)
-        else:
-            x = tensor
-        if self._update_node_feature:
-            x['node']['feature'] = self._node_dense(tensor.node['feature'])
-        if self._update_edge_feature:
-            x['edge']['feature'] = self._edge_dense(tensor.edge['feature'])
-        tape.watch(x['node']['feature'])
-        outputs = [x['node']['feature']]
-        for layer in self.layers:
-            x = layer(x, training=training)
-            tape.watch(x['node']['feature'])
-            outputs.append(x['node']['feature'])
-
-        tensor = tensor.update(
+        node_feature = tensor.aggregate('message')
+        node_feature += (1 + self.epsilon) * tensor.node['feature']
+        node_feature = self._feedforward_intermediate_dense(node_feature)
+        return tensor.update(
             {
                 'node': {
-                    'feature': keras.ops.concatenate(outputs, axis=-1)
+                    'feature': node_feature,
+                },
+                'edge': {
+                    'message': None,
                 }
             }
         )
-        return tensor, outputs

-    def get_config(self) -> dict:
-        config = super().get_config()
-        config.update(
+    def update(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
+        """Updates nodes.
+        """
+        node_feature = tensor.node['feature']
+        node_feature = self._feedforward_activation(node_feature)
+        node_feature = self._feedforward_dropout(node_feature)
+        node_feature = self._feedforward_output_dense(node_feature)
+        return tensor.update(
             {
-                'layers': [
-                    keras.layers.serialize(layer) for layer in self.layers
-                ]
+                'node': {
+                    'feature': node_feature,
+                }
             }
         )
-        return config
-
-    @classmethod
-    def from_config(cls, config: dict) -> 'GraphNetwork':
-        config['layers'] = [
-            keras.layers.deserialize(layer) for layer in config['layers']
-        ]
-        return super().from_config(config)

+    def get_config(self) -> dict:
+        config = super().get_config()
+        config.update({
+            'activation': keras.activations.serialize(self._activation),
+            'dropout': self._dropout,
+            'update_edge_feature': self._update_edge_feature
+        })
+        return config

-@keras.saving.register_keras_serializable(package='molcraft')
-class NodeEmbedding(GraphLayer):

-    """Node embedding layer.
+@keras.saving.register_keras_serializable(package='molgraphx')
+class GAConv(GraphConv):

-    Embeds nodes based on its initial features.
+    """Graph attention network layer.
     """

     def __init__(
-        self,
-        dim: int = None,
-        embed_context: bool = True,
-        allow_masking: bool = True,
+        self,
+        units: int,
+        heads: int = 8,
+        activation: keras.layers.Activation | str | None = "relu",
+        use_bias: bool = True,
+        normalize: bool = True,
+        dropout: float = 0.0,
+        update_edge_feature: bool = True,
+        attention_activation: keras.layers.Activation | str | None = "leaky_relu",
+        **kwargs,
+    ) -> None:
+        kwargs['skip_connection'] = False
+        super().__init__(
+            units=units,
+            normalize=normalize,
+            use_bias=use_bias,
+            **kwargs
+        )
+        self._heads = heads
+        if self.units % self.heads != 0:
+            raise ValueError(f"units need to be divisible by heads.")
+        self._head_units = self.units // self.heads
+        self._activation = keras.activations.get(activation)
+        self._dropout = dropout
+        self._normalize = normalize
+        self._update_edge_feature = update_edge_feature
+        self._attention_activation = keras.activations.get(attention_activation)
+
+    @property
+    def heads(self):
+        return self._heads
+
+    @property
+    def head_units(self):
+        return self._head_units
+
+    def build_from_spec(self, spec: tensors.GraphTensor.Spec) -> None:
+
+        node_feature_dim = spec.node['feature'].shape[-1]
+        attn_feature_dim = node_feature_dim + node_feature_dim
+
+        self._has_edge_feature = 'feature' in spec.edge
+        if self._has_edge_feature:
+            edge_feature_dim = spec.edge['feature'].shape[-1]
+            attn_feature_dim += edge_feature_dim
+            if self._update_edge_feature:
+                self._edge_dense = self.get_einsum_dense(
+                    'ijh,jkh->ikh', (self.head_units, self.heads)
+                )
+                self._edge_dense.build([None, self.head_units, self.heads])
+        else:
+            self._update_edge_feature = False
+
+        self._node_dense = self.get_einsum_dense(
+            'ij,jkh->ikh', (self.head_units, self.heads)
+        )
+        self._node_dense.build([None, node_feature_dim])
+
+        self._feature_dense = self.get_einsum_dense(
+            'ij,jkh->ikh', (self.head_units, self.heads)
+        )
+        self._feature_dense.build([None, attn_feature_dim])
+
+        self._attention_dense = self.get_einsum_dense(
+            'ijh,jkh->ikh', (1, self.heads)
+        )
+        self._attention_dense.build([None, self.head_units, self.heads])
+
+        self._node_self_dense = self.get_einsum_dense(
+            'ij,jkh->ikh', (self.head_units, self.heads)
+        )
+        self._node_self_dense.build([None, node_feature_dim])
+        self._dropout_layer = keras.layers.Dropout(self._dropout)
+
+        self.built = True
+
+    def message(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
+
+        attention_feature = keras.ops.concatenate(
+            [
+                tensor.gather('feature', 'source'),
+                tensor.gather('feature', 'target')
+            ],
+            axis=-1
+        )
+        if self._has_edge_feature:
+            attention_feature = keras.ops.concatenate(
+                [
+                    attention_feature,
+                    tensor.edge['feature']
+                ],
+                axis=-1
+            )
+
+        attention_feature = self._feature_dense(attention_feature)
+
+        edge_feature = tensor.edge.get('feature')
+
+        if self._update_edge_feature:
+            edge_feature = self._edge_dense(attention_feature)
+            edge_feature = keras.ops.reshape(edge_feature, (-1, self.units))
+
+        attention_feature = self._attention_activation(attention_feature)
+        attention_score = self._attention_dense(attention_feature)
+        attention_score = ops.edge_softmax(
+            score=attention_score, edge_target=tensor.edge['target']
+        )
+        node_feature = self._node_dense(tensor.node['feature'])
+        message = ops.gather(node_feature, tensor.edge['source'])
+        return tensor.update(
+            {
+                'edge': {
+                    'message': message,
+                    'weight': attention_score,
+                    'feature': edge_feature,
+                }
+            }
+        )
+
+    def aggregate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
+        node_feature = tensor.aggregate('message')
+        node_feature += self._node_self_dense(tensor.node['feature'])
+        node_feature = self._dropout_layer(node_feature)
+        node_feature = keras.ops.reshape(node_feature, (-1, self.units))
+        return tensor.update(
+            {
+                'node': {
+                    'feature': node_feature
+                },
+                'edge': {
+                    'message': None,
+                    'weight': None,
+                }
+            }
+        )
+
+    def update(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
+        node_feature = self._activation(tensor.node['feature'])
+        return tensor.update(
+            {
+                'node': {
+                    'feature': node_feature
+                }
+            }
+        )
+
+    def get_config(self) -> dict:
+        config = super().get_config()
+        config.update({
+            "heads": self._heads,
+            'activation': keras.activations.serialize(self._activation),
+            'dropout': self._dropout,
+            'update_edge_feature': self._update_edge_feature,
+            'attention_activation': keras.activations.serialize(self._attention_activation),
+        })
+        return config
+
+
+@keras.saving.register_keras_serializable(package='molcraft')
+class GTConv(GraphConv):
+
+    """Graph transformer layer.
+    """
+
+    def __init__(
+        self,
+        units: int,
+        heads: int = 8,
+        activation: keras.layers.Activation | str | None = "relu",
+        use_bias: bool = True,
+        normalize: bool = True,
+        dropout: float = 0.0,
+        attention_dropout: float = 0.0,
+        **kwargs,
+    ) -> None:
+        kwargs['skip_connection'] = False
+        super().__init__(
+            units=units,
+            normalize=normalize,
+            use_bias=use_bias,
+            **kwargs
+        )
+        self._heads = heads
+        if self.units % self.heads != 0:
+            raise ValueError(f"units need to be divisible by heads.")
+        self._head_units = self.units // self.heads
+        self._activation = keras.activations.get(activation)
+        self._dropout = dropout
+        self._attention_dropout = attention_dropout
+        self._normalize = normalize
+
+    @property
+    def heads(self):
+        return self._heads
+
+    @property
+    def head_units(self):
+        return self._head_units
+
+    def build_from_spec(self, spec):
+        """Builds the layer.
+        """
+        node_feature_dim = spec.node['feature'].shape[-1]
+        self.project_residual = node_feature_dim != self.units
+        if self.project_residual:
+            warn(
+                '`GTConv` uses residual connections, but found incompatible dim '
+                'between input (node feature dim) and output (`self.units`). '
+                'Automatically applying a projection layer to residual to '
+                'match input and output. '
+            )
+            self._residual_dense = self.get_dense(self.units)
+            self._residual_dense.build([None, node_feature_dim])
+
+        self._query_dense = self.get_einsum_dense(
+            'ij,jkh->ikh', (self.head_units, self.heads)
+        )
+        self._query_dense.build([None, node_feature_dim])
+
+        self._key_dense = self.get_einsum_dense(
+            'ij,jkh->ikh', (self.head_units, self.heads)
+        )
+        self._key_dense.build([None, node_feature_dim])
+
+        self._value_dense = self.get_einsum_dense(
+            'ij,jkh->ikh', (self.head_units, self.heads)
+        )
+        self._value_dense.build([None, node_feature_dim])
+
+        self._output_dense = self.get_dense(self.units)
+        self._output_dense.build([None, self.units])
+
+        self._softmax_dropout = keras.layers.Dropout(self._attention_dropout)
+
+        self._self_attention_dropout = keras.layers.Dropout(self._dropout)
+
+        self._add_bias = not 'bias' in spec.edge
+
+        if self._add_bias:
+            self._edge_bias = EdgeBias(biases=self.heads)
+            self._edge_bias.build_from_spec(spec)
+
+        has_overridden_update = self.__class__.update != GTConv.update
+        if not has_overridden_update:
+
+            if self._normalize:
+                self._feedforward_output_norm = keras.layers.LayerNormalization()
+                self._feedforward_output_norm.build([None, self.units])
+
+            self._feedforward_dropout = keras.layers.Dropout(self._dropout)
+
+            self._feedforward_intermediate_dense = self.get_dense(self.units)
+            self._feedforward_intermediate_dense.build([None, self.units])
+
+            self._feedforward_output_dense = self.get_dense(self.units)
+            self._feedforward_output_dense.build([None, self.units])
+
+        self.built = True
+
+    def message(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
+        """Computes messages.
+        """
+        if self._add_bias:
+            edge_bias = self._edge_bias(tensor)
+            tensor = tensor.update(
+                {
+                    'edge': {
+                        'bias': edge_bias
+                    }
+                }
+            )
+
+        node_feature = tensor.node['feature']
+
+        query = self._query_dense(node_feature)
+        key = self._key_dense(node_feature)
+        value = self._value_dense(node_feature)
+
+        query = ops.gather(query, tensor.edge['source'])
+        key = ops.gather(key, tensor.edge['target'])
+        value = ops.gather(value, tensor.edge['source'])
+
+        attention_score = keras.ops.sum(query * key, axis=1, keepdims=True)
+        attention_score /= keras.ops.sqrt(float(self.head_units))
+
+        attention_score += keras.ops.expand_dims(tensor.edge['bias'], axis=1)
+
+        attention = ops.edge_softmax(attention_score, tensor.edge['target'])
+        attention = self._softmax_dropout(attention)
+
+        return tensor.update(
+            {
+                'edge': {
+                    'message': value,
+                    'weight': attention,
+                },
+            }
+        )
+
+    def aggregate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
+        """Aggregates messages.
+        """
+        node_feature = tensor.aggregate('message')
+        node_feature = keras.ops.reshape(node_feature, (-1, self.units))
+        node_feature = self._output_dense(node_feature)
+        node_feature = self._self_attention_dropout(node_feature)
+        return tensor.update(
+            {
+                'node': {
+                    'feature': node_feature,
+                    'residual': tensor.node['feature']
+                },
+                'edge': {
+                    'message': None,
+                    'weight': None,
+                }
+            }
+        )
+
+    def update(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
+        """Updates nodes.
+        """
+        node_feature = tensor.node['feature']
+
+        residual = tensor.node['residual']
+        if self.project_residual:
+            residual = self._residual_dense(residual)
+
+        node_feature += residual
+        residual = node_feature
+
+        node_feature = self._feedforward_intermediate_dense(node_feature)
+        node_feature = self._activation(node_feature)
+        node_feature = self._feedforward_output_dense(node_feature)
+        node_feature = self._feedforward_dropout(node_feature)
+        if self._normalize:
+            node_feature = self._feedforward_output_norm(node_feature)
+
+        node_feature += residual
+
+        return tensor.update(
+            {
+                'node': {
+                    'feature': node_feature,
+                },
+            }
+        )
+
+    def get_config(self) -> dict:
+        config = super().get_config()
+        config.update({
+            "heads": self._heads,
+            'activation': keras.activations.serialize(self._activation),
+            'dropout': self._dropout,
+            'attention_dropout': self._attention_dropout,
+        })
+        return config
+
+
+@keras.saving.register_keras_serializable(package='molcraft')
+class GTConv3D(GTConv):
+
+    """Graph transformer 3D layer.
+    """
+
+    def build_from_spec(self, spec: tensors.GraphTensor.Spec) -> None:
+        super().build_from_spec(spec)
+        if self._add_bias:
+            node_feature_dim = spec.node['feature'].shape[-1]
+            kernels = self.units
+            self._gaussian_basis = GaussianDistance(kernels)
+            self._gaussian_basis.build_from_spec(spec)
+            self._centrality_dense = self.get_dense(units=node_feature_dim)
+            self._centrality_dense.build([None, kernels])
+            self._gaussian_edge_bias = self.get_dense(self.heads)
+            self._gaussian_edge_bias.build([None, kernels])
+
+    def message(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
+        """Computes messages.
+        """
+        node_feature = tensor.node['feature']
+
+        if self._add_bias:
+            gaussian = self._gaussian_basis(tensor)
+            centrality = keras.ops.segment_sum(
+                gaussian, tensor.edge['target'], tensor.num_nodes
+            )
+            node_feature += self._centrality_dense(centrality)
+
+            edge_bias = self._edge_bias(tensor) + self._gaussian_edge_bias(gaussian)
+            tensor = tensor.update({'edge': {'bias': edge_bias}})
+
+        query = self._query_dense(node_feature)
+        key = self._key_dense(node_feature)
+        value = self._value_dense(node_feature)
+
+        query = ops.gather(query, tensor.edge['source'])
+        key = ops.gather(key, tensor.edge['target'])
+        value = ops.gather(value, tensor.edge['source'])
+
+        attention_score = keras.ops.sum(query * key, axis=1, keepdims=True)
+        attention_score /= keras.ops.sqrt(float(self.head_units))
+
+        attention_score += keras.ops.expand_dims(tensor.edge['bias'], axis=1)
+
+        attention = ops.edge_softmax(attention_score, tensor.edge['target'])
+        attention = self._softmax_dropout(attention)
+
+        distance = keras.ops.subtract(
+            tensor.gather('coordinate', 'source'),
+            tensor.gather('coordinate', 'target')
+        )
+        euclidean_distance = ops.euclidean_distance(
+            tensor.gather('coordinate', 'source'),
+            tensor.gather('coordinate', 'target'),
+            axis=-1
+        )
+        distance /= euclidean_distance
+
+        attention *= keras.ops.expand_dims(distance, axis=-1)
+        attention = keras.ops.expand_dims(attention, axis=2)
+        value = keras.ops.expand_dims(value, axis=1)
+
+        return tensor.update(
+            {
+                'edge': {
+                    'message': value,
+                    'weight': attention,
+                },
+            }
+        )
+
+    def aggregate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
+        """Aggregates messages.
+        """
+        node_feature = tensor.aggregate('message')
+        node_feature = keras.ops.reshape(
+            node_feature, (tensor.num_nodes, -1, self.units)
+        )
+        node_feature = self._output_dense(node_feature)
+        node_feature = keras.ops.sum(node_feature, axis=1)
+        node_feature = self._self_attention_dropout(node_feature)
+        return tensor.update(
+            {
+                'node': {
+                    'feature': node_feature,
+                    'residual': tensor.node['feature']
+                },
+                'edge': {
+                    'message': None,
+                    'weight': None,
+                }
+            }
+        )
+
+
+@keras.saving.register_keras_serializable(package='molcraft')
+class MPConv(GraphConv):
+
+    """Message passing neural network layer.
+    """
+
+    def __init__(
+        self,
+        units: int = 128,
+        activation: keras.layers.Activation | str | None = None,
+        use_bias: bool = True,
+        normalize: bool = True,
+        dropout: float = 0.0,
+        **kwargs
+    ) -> None:
+        super().__init__(
+            units=units,
+            normalize=normalize,
+            use_bias=use_bias,
+            **kwargs
+        )
+        self._activation = keras.activations.get(activation)
+        self._dropout = dropout or 0.0
+
+    def build_from_spec(self, spec: tensors.GraphTensor.Spec) -> None:
+        node_feature_dim = spec.node['feature'].shape[-1]
+        self.message_fn = self.get_dense(self.units, activation=self._activation)
+        self.update_fn = keras.layers.GRUCell(self.units)
+        self._has_edge_feature = 'feature' in spec.edge
+        self.project_input_node_feature = node_feature_dim != self.units
+        if self.project_input_node_feature:
+            warn(
+                'Input node feature dim does not match updated node feature dim. '
+                'To make sure input node feature can be passed as `states` to the '
+                'GRU cell, it will automatically be projected prior to it.'
+            )
+            self._previous_node_dense = self.get_dense(self.units, activation=self._activation)
+        self.built = True
+
+    def message(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
+        feature = keras.ops.concatenate(
+            [
+                tensor.gather('feature', 'source'),
+                tensor.gather('feature', 'target'),
+            ],
+            axis=-1
+        )
+        if self._has_edge_feature:
+            feature = keras.ops.concatenate(
+                [
+                    feature,
+                    tensor.edge['feature']
+                ],
+                axis=-1
+            )
+        message = self.message_fn(feature)
+        return tensor.update(
+            {
+                'edge': {
+                    'message': message,
+                }
+            }
+        )
+
+    def aggregate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
+        aggregate = tensor.aggregate('message')
+        previous = tensor.node['feature']
+        if self.project_input_node_feature:
+            previous = self._previous_node_dense(previous)
+        return tensor.update(
+            {
+                'node': {
+                    'feature': aggregate,
+                    'previous_feature': previous,
+                }
+            }
+        )
+
+    def update(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
+        updated_node_feature, _ = self.update_fn(
+            inputs=tensor.node['feature'],
+            states=tensor.node['previous_feature']
+        )
+        return tensor.update(
+            {
+                'node': {
+                    'feature': updated_node_feature,
+                    'previous_feature': None,
+                }
+            }
+        )
+
+    def get_config(self) -> dict:
+        config = super().get_config()
+        config.update({
+            'activation': keras.activations.serialize(self._activation),
+            'dropout': self._dropout,
+        })
+        return config
+
+
+@keras.saving.register_keras_serializable(package='molcraft')
+class MPConv3D(MPConv):
+
+    """3D Message passing neural network layer.
+    """
+
+    def message(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
+        euclidean_distance = ops.euclidean_distance(
+            tensor.gather('coordinate', 'target'),
+            tensor.gather('coordinate', 'source'),
+            axis=-1
+        )
+        feature = keras.ops.concatenate(
+            [
+                tensor.gather('feature', 'source'),
+                tensor.gather('feature', 'target'),
+                euclidean_distance,
+            ],
+            axis=-1
+        )
+        if self._has_edge_feature:
+            feature = keras.ops.concatenate(
+                [
+                    feature,
+                    tensor.edge['feature']
+                ],
+                axis=-1
+            )
+        message = self.message_fn(feature)
+        return tensor.update(
+            {
+                'edge': {
+                    'message': message,
+                }
+            }
+        )
+
+
+@keras.saving.register_keras_serializable(package='molcraft')
+class EGConv3D(GraphConv):
+
+    """Equivariant graph neural network layer.
+    """
+
+    def __init__(
+        self,
+        units: int = 128,
+        activation: keras.layers.Activation | str | None = None,
+        use_bias: bool = True,
+        normalize: bool = True,
+        dropout: float = 0.0,
+        **kwargs
+    ) -> None:
+        super().__init__(
+            units=units,
+            normalize=normalize,
+            use_bias=use_bias,
+            **kwargs
+        )
+        self._activation = keras.activations.get(activation)
+        self._dropout = dropout or 0.0
+
+    def build_from_spec(self, spec: tensors.GraphTensor.Spec) -> None:
+        if 'coordinate' not in spec.node:
+            raise ValueError(
+                'Could not find `coordinate`s in node, '
+                'which is required for Conv3D layers.'
+            )
+        node_feature_dim = spec.node['feature'].shape[-1]
+        feature_dim = node_feature_dim + node_feature_dim + 1
+        if 'feature' in spec.edge:
+            self._has_edge_feature = True
+            edge_feature_dim = spec.edge['feature'].shape[-1]
+            feature_dim += edge_feature_dim
+        else:
+            self._has_edge_feature = False
+
+        self.message_fn = self.get_dense(self.units, activation=self._activation)
+        self.message_fn.build([None, feature_dim])
+        self.dense_position = self.get_dense(1)
+        self.dense_position.build([None, self.units])
+
+        has_overridden_update = self.__class__.update != EGConv3D.update
+        if not has_overridden_update:
+            self.update_fn = self.get_dense(self.units, activation=self._activation)
+            self.update_fn.build([None, node_feature_dim + self.units])
+            self._dropout_layer = keras.layers.Dropout(self._dropout)
+        self.built = True
+
+    def message(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
+        """Computes messages.
+        """
+        relative_node_coordinate = keras.ops.subtract(
+            tensor.gather('coordinate', 'target'),
+            tensor.gather('coordinate', 'source')
+        )
+        euclidean_distance = keras.ops.sum(
+            keras.ops.square(
+                relative_node_coordinate
+            ),
+            axis=-1,
+            keepdims=True
+        )
+        feature = keras.ops.concatenate(
+            [
+                tensor.gather('feature', 'target'),
+                tensor.gather('feature', 'source'),
+                euclidean_distance,
+            ],
+            axis=-1
+        )
+        if self._has_edge_feature:
+            feature = keras.ops.concatenate(
+                [
+                    feature,
+                    tensor.edge['feature']
+                ],
+                axis=-1
+            )
+        message = self.message_fn(feature)
+        relative_node_coordinate = keras.ops.multiply(
+            relative_node_coordinate,
+            self.dense_position(message)
+        )
+        return tensor.update(
+            {
+                'edge': {
+                    'message': message,
+                    'relative_node_coordinate': relative_node_coordinate
+                }
+            }
+        )
+
+    def aggregate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
+        """Aggregates messages.
+        """
+        coefficient = keras.ops.bincount(
+            tensor.edge['source'],
+            minlength=tensor.num_nodes
+        )
+        coefficient = keras.ops.cast(
+            coefficient, tensor.node['coordinate'].dtype
+        )
+        coefficient = keras.ops.expand_dims(
+            keras.ops.divide_no_nan(1, coefficient), axis=1
+        )
+
+        updated_coordinate = tensor.aggregate('relative_node_coordinate') * coefficient
+        updated_coordinate += tensor.node['coordinate']
+
+        aggregate = tensor.aggregate('message')
+        return tensor.update(
+            {
+                'node': {
+                    'feature': aggregate,
+                    'coordinate': updated_coordinate,
+                    'previous_feature': tensor.node['feature'],
+                },
+                'edge': {
+                    'message': None,
+                    'relative_node_coordinate': None
+                }
+            }
+        )
+
+    def update(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
+        """Updates nodes.
+        """
+        updated_node_feature = self.update_fn(
+            keras.ops.concatenate(
+                [
+                    tensor.node['feature'],
+                    tensor.node['previous_feature']
+                ],
+                axis=-1
+            )
+        )
+        updated_node_feature = self._dropout_layer(updated_node_feature)
+        return tensor.update(
+            {
+                'node': {
+                    'feature': updated_node_feature,
+                    'previous_feature': None,
+                },
+            }
+        )
+
+    def get_config(self) -> dict:
+        config = super().get_config()
+        config.update({
+            'activation': keras.activations.serialize(self._activation),
+            'dropout': self._dropout,
+        })
+        return config
+
+
+@keras.saving.register_keras_serializable(package='molcraft')
+class Projection(GraphLayer):
+    """Base graph projection layer.
+    """
+    def __init__(
+        self,
+        units: int = None,
+        activation: str = None,
+        field: str = 'node',
+        **kwargs
+    ) -> None:
+        super().__init__(**kwargs)
+        self.units = units
+        self._activation = keras.activations.get(activation)
+        self.field = field
+
+    def build_from_spec(self, spec: tensors.GraphTensor.Spec) -> None:
+        """Builds the layer.
+        """
+        data = getattr(spec, self.field, None)
+        if data is None:
+            raise ValueError(f'Could not access field {self.field!r}.')
+        feature_dim = data['feature'].shape[-1]
+        if not self.units:
+            self.units = feature_dim
+        self._dense = self.get_dense(self.units)
+        self._dense.build([None, feature_dim])
+        self.built = True
+
+    def propagate(self, tensor: tensors.GraphTensor):
+        """Calls the layer.
+        """
+        feature = getattr(tensor, self.field)['feature']
+        feature = self._dense(feature)
+        feature = self._activation(feature)
+        return tensor.update(
+            {
+                self.field: {
+                    'feature': feature
+                }
+            }
+        )
+
+    def get_config(self) -> dict:
+        config = super().get_config()
+        config.update({
+            'units': self.units,
+            'activation': keras.activations.serialize(self._activation),
+            'field': self.field,
+        })
+        return config
+
+
+@keras.saving.register_keras_serializable(package='molcraft')
+class GraphNetwork(GraphLayer):
+
+    """Graph neural network.
+
+    Sequentially calls graph layers (`GraphLayer`) and concatenates its output.
+
+    Args:
+        layers (list):
+            A list of graph layers.
+    """
+
+    def __init__(self, layers: list[GraphLayer], **kwargs) -> None:
+        super().__init__(**kwargs)
+        self.layers = layers
+        self._update_edge_feature = False
+
+    def build_from_spec(self, spec: tensors.GraphTensor.Spec) -> None:
+        """Builds the layer.
+        """
+        units = self.layers[0].units
+        node_feature_dim = spec.node['feature'].shape[-1]
+        if node_feature_dim != units:
+            warn(
+                'Node feature dim does not match `units` of the first layer. '
+                'Automatically adding a node projection layer to match `units`.'
+            )
+            self._node_dense = self.get_dense(units)
+            self._update_node_feature = True
+        has_edge_feature = 'feature' in spec.edge
+        if has_edge_feature:
+            edge_feature_dim = spec.edge['feature'].shape[-1]
+            if edge_feature_dim != units:
+                warn(
+                    'Edge feature dim does not match `units` of the first layer. '
+                    'Automatically adding an edge projection layer to match `units`.'
+                )
+                self._edge_dense = self.get_dense(units)
+                self._update_edge_feature = True
+        self.built = True
+
+    def propagate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
+        """Calls the layer.
+        """
+        x = tensors.to_dict(tensor)
+        if self._update_node_feature:
+            x['node']['feature'] = self._node_dense(tensor.node['feature'])
+        if self._update_edge_feature:
+            x['edge']['feature'] = self._edge_dense(tensor.edge['feature'])
+        outputs = [x['node']['feature']]
+        for layer in self.layers:
+            x = layer(x)
+            outputs.append(x['node']['feature'])
+        return tensor.update(
+            {
+                'node': {
+                    'feature': keras.ops.concatenate(outputs, axis=-1)
+                }
+            }
+        )
+
+    def tape_propagate(
+        self,
+        tensor: tensors.GraphTensor,
+        tape: tf.GradientTape,
+        training: bool | None = None,
+    ) -> tuple[tensors.GraphTensor, list[tf.Tensor]]:
+        """Performs the propagation with a `GradientTape`.
+
+        Performs the same forward pass as `propagate` but with a `GradientTape`
+        watching intermediate node features.
+
+        Args:
+            tensor (tensors.GraphTensor):
+                The graph input.
+        """
+        if isinstance(tensor, tensors.GraphTensor):
+            x = tensors.to_dict(tensor)
+        else:
+            x = tensor
+        if self._update_node_feature:
+            x['node']['feature'] = self._node_dense(tensor.node['feature'])
+        if self._update_edge_feature:
+            x['edge']['feature'] = self._edge_dense(tensor.edge['feature'])
+        tape.watch(x['node']['feature'])
+        outputs = [x['node']['feature']]
+        for layer in self.layers:
+            x = layer(x, training=training)
+            tape.watch(x['node']['feature'])
+            outputs.append(x['node']['feature'])
+
+        tensor = tensor.update(
+            {
+                'node': {
+                    'feature': keras.ops.concatenate(outputs, axis=-1)
+                }
+            }
+        )
+        return tensor, outputs
+
+    def get_config(self) -> dict:
+        config = super().get_config()
+        config.update(
+            {
+                'layers': [
+                    keras.layers.serialize(layer) for layer in self.layers
+                ]
+            }
+        )
+        return config
+
+    @classmethod
+    def from_config(cls, config: dict) -> 'GraphNetwork':
+        config['layers'] = [
+            keras.layers.deserialize(layer) for layer in config['layers']
+        ]
+        return super().from_config(config)
+
+
+@keras.saving.register_keras_serializable(package='molcraft')
+class NodeEmbedding(GraphLayer):
+
+    """Node embedding layer.
+
+    Embeds nodes based on its initial features.
+    """
+
+    def __init__(
+        self,
+        dim: int = None,
+        normalize: bool = True,
+        embed_context: bool = True,
+        allow_masking: bool = True,
         **kwargs
     ) -> None:
         super().__init__(**kwargs)
         self.dim = dim
+        self._normalize = normalize
         self._embed_context = embed_context
         self._masking_rate = None
         self._allow_masking = allow_masking
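To see how the layers introduced in this hunk compose, a hedged sketch using only constructors that appear above (the widths and head counts are arbitrary):

    from molcraft import layers

    # GraphNetwork calls each layer in sequence and concatenates the node
    # features produced at every step; mismatched input dims trigger the
    # automatic projection layers warned about in `build_from_spec`.
    gnn = layers.GraphNetwork([
        layers.GIConv(units=128),
        layers.GAConv(units=128, heads=8),
        layers.GTConv(units=128, heads=8),
    ])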
@@ -482,6 +1448,12 @@ class NodeEmbedding(GraphLayer):
             context_feature_dim = spec.context['feature'].shape[-1]
             self._context_dense = self.get_dense(self.dim)
             self._context_dense.build([None, context_feature_dim])
+
+        if self._normalize:
+            self._norm = keras.layers.LayerNormalization()
+            self._norm.build([None, self.dim])
+
+        self.built = True

     def propagate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
         """Calls the layer.
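A hedged sketch of the new `normalize` flag in context; `dim=128` is an arbitrary choice:

    from molcraft import layers

    # With normalize=True, a LayerNormalization built for [None, dim] is
    # applied to the embedded node features at the end of `propagate`.
    embedding = layers.NodeEmbedding(dim=128, normalize=True)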
@@ -515,6 +1487,9 @@ class NodeEmbedding(GraphLayer):
             # Silence warning of 'no gradients for variables'
             feature = feature + (self._mask_feature * 0.0)

+        if self._normalize:
+            feature = self._norm(feature)
+
         return tensor.update({'node': {'feature': feature}})

     @property
@@ -534,6 +1509,8 @@ class NodeEmbedding(GraphLayer):
         config = super().get_config()
         config.update({
             'dim': self.dim,
+            'normalize': self._normalize,
+            'embed_context': self._embed_context,
             'allow_masking': self._allow_masking
         })
         return config
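Since both new flags now round-trip through the config, standard Keras serialization should preserve them; a hedged sketch, assuming an `embedding` instance as above:

    config = embedding.get_config()
    assert config['normalize'] and config['embed_context']
    clone = layers.NodeEmbedding.from_config(config)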
@@ -544,503 +1521,210 @@ class EdgeEmbedding(GraphLayer):
|
|
|
544
1521
|
|
|
545
1522
|
"""Edge embedding layer.
|
|
546
1523
|
|
|
547
|
-
Embeds edges based on its initial features.
|
|
548
|
-
"""
|
|
549
|
-
|
|
550
|
-
def __init__(
|
|
551
|
-
self,
|
|
552
|
-
dim: int = None,
|
|
553
|
-
allow_masking: bool = True,
|
|
554
|
-
**kwargs
|
|
555
|
-
) -> None:
|
|
556
|
-
super().__init__(**kwargs)
|
|
557
|
-
self.dim = dim
|
|
558
|
-
self._masking_rate = None
|
|
559
|
-
self._allow_masking = allow_masking
|
|
560
|
-
|
|
561
|
-
def build_from_spec(self, spec: tensors.GraphTensor.Spec) -> None:
|
|
562
|
-
"""Builds the layer.
|
|
563
|
-
"""
|
|
564
|
-
feature_dim = spec.edge['feature'].shape[-1]
|
|
565
|
-
if not self.dim:
|
|
566
|
-
self.dim = feature_dim
|
|
567
|
-
self._edge_dense = self.get_dense(self.dim)
|
|
568
|
-
self._edge_dense.build([None, feature_dim])
|
|
569
|
-
|
|
570
|
-
self._has_super = 'super' in spec.edge
|
|
571
|
-
if self._has_super:
|
|
572
|
-
self._super_feature = self.get_weight(shape=[self.dim], name='super_edge_feature')
|
|
573
|
-
if self._allow_masking:
|
|
574
|
-
self._mask_feature = self.get_weight(shape=[self.dim], name='mask_edge_feature')
|
|
575
|
-
|
|
576
|
-
def propagate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
|
|
577
|
-
"""Calls the layer.
|
|
578
|
-
"""
|
|
579
|
-
feature = self._edge_dense(tensor.edge['feature'])
|
|
580
|
-
|
|
581
|
-
if self._has_super:
|
|
582
|
-
super_feature = self._super_feature
|
|
583
|
-
super_mask = keras.ops.expand_dims(tensor.edge['super'], 1)
|
|
584
|
-
feature = keras.ops.where(super_mask, super_feature, feature)
|
|
585
|
-
|
|
586
|
-
if (
|
|
587
|
-
self._allow_masking and
|
|
588
|
-
self._masking_rate is not None and
|
|
589
|
-
self._masking_rate > 0
|
|
590
|
-
):
|
|
591
|
-
random = keras.random.uniform(shape=[tensor.num_edges])
|
|
592
|
-
mask = random <= self._masking_rate
|
|
593
|
-
if self._has_super:
|
|
594
|
-
mask = keras.ops.logical_and(
|
|
595
|
-
mask, keras.ops.logical_not(tensor.edge['super'])
|
|
596
|
-
)
|
|
597
|
-
mask = keras.ops.expand_dims(mask, -1)
|
|
598
|
-
feature = keras.ops.where(mask, self._mask_feature, feature)
|
|
599
|
-
elif self._allow_masking:
|
|
600
|
-
# Slience warning of 'no gradients for variables'
|
|
601
|
-
feature = feature + (self._mask_feature * 0.0)
|
|
602
|
-
|
|
603
|
-
return tensor.update({'edge': {'feature': feature}})
|
|
604
|
-
|
|
605
|
-
@property
|
|
606
|
-
def masking_rate(self):
|
|
607
|
-
return self._masking_rate
|
|
608
|
-
|
|
609
|
-
@masking_rate.setter
|
|
610
|
-
def masking_rate(self, rate: float):
|
|
611
|
-
if not self._allow_masking and rate is not None:
|
|
612
|
-
raise ValueError(
|
|
613
|
-
f'Cannot set `masking_rate` for layer {self} '
|
|
614
|
-
'as `allow_masking` was set to `False`.'
|
|
615
|
-
)
|
|
616
|
-
self._masking_rate = float(rate)
|
|
617
|
-
|
|
618
|
-
def get_config(self) -> dict:
|
|
619
|
-
config = super().get_config()
|
|
620
|
-
config.update({
|
|
621
|
-
'dim': self.dim,
|
|
622
|
-
'allow_masking': self._allow_masking
|
|
623
|
-
})
|
|
624
|
-
return config
|
|
625
|
-
|
|
626
|
-
|
|
627
|
-
@keras.saving.register_keras_serializable(package='molcraft')
|
|
628
|
-
class ContextProjection(Projection):
|
|
629
|
-
"""Context projection layer.
|
|
630
|
-
"""
|
|
631
|
-
def __init__(self, units: int = None, activation: str = None, **kwargs):
|
|
632
|
-
super().__init__(units=units, activation=activation, field='context', **kwargs)
|
|
633
|
-
|
|
634
|
-
|
|
635
|
-
@keras.saving.register_keras_serializable(package='molcraft')
|
|
636
|
-
class NodeProjection(Projection):
|
|
637
|
-
"""Node projection layer.
|
|
638
|
-
"""
|
|
639
|
-
def __init__(self, units: int = None, activation: str = None, **kwargs):
|
|
640
|
-
super().__init__(units=units, activation=activation, field='node', **kwargs)
|
|
641
|
-
|
|
642
|
-
|
|
643
|
-
@keras.saving.register_keras_serializable(package='molcraft')
|
|
644
|
-
class EdgeProjection(Projection):
|
|
645
|
-
"""Edge projection layer.
|
|
646
|
-
"""
|
|
647
|
-
def __init__(self, units: int = None, activation: str = None, **kwargs):
|
|
648
|
-
super().__init__(units=units, activation=activation, field='edge', **kwargs)
|
|
649
|
-
|
|
650
|
-
-@keras.saving.register_keras_serializable(package='molcraft')
-class GINConv(GraphConv):
-
-    """Graph isomorphism network layer.
-    """
-
-    def __init__(
-        self,
-        units: int,
-        activation: keras.layers.Activation | str | None = 'relu',
-        dropout: float = 0.0,
-        normalize: bool = True,
-        update_edge_feature: bool = True,
-        **kwargs,
-    ):
-        super().__init__(units=units, **kwargs)
-        self._activation = keras.activations.get(activation)
-        self._normalize = normalize
-        self._dropout = dropout
-        self._update_edge_feature = update_edge_feature
-
-    def build_from_spec(self, spec: tensors.GraphTensor.Spec) -> None:
-        """Builds the layer.
-        """
-        node_feature_dim = spec.node['feature'].shape[-1]
-
-        self.epsilon = self.add_weight(
-            name='epsilon',
-            shape=(),
-            initializer='zeros',
-            trainable=True,
-        )
-
-        if 'feature' in spec.edge:
-            edge_feature_dim = spec.edge['feature'].shape[-1]
-
-            if not self._update_edge_feature:
-                if (edge_feature_dim != node_feature_dim):
-                    warn(
-                        'Found edge feature dim to be incompatible with node feature dim. '
-                        'Automatically adding an edge feature projection layer to match '
-                        'the dim of node features.'
-                    )
-                    self._update_edge_feature = True
-
-            if self._update_edge_feature:
-                self._edge_dense = self.get_dense(node_feature_dim)
-                self._edge_dense.build([None, edge_feature_dim])
-        else:
-            self._update_edge_feature = False
-
-        has_overridden_update = self.__class__.update != GINConv.update
-        if not has_overridden_update:
-            # Use default feedforward network
-            self._feedforward_intermediate_dense = self.get_dense(self.units)
-            self._feedforward_intermediate_dense.build([None, node_feature_dim])
-
-            if self._normalize:
-                self._feedforward_intermediate_norm = keras.layers.BatchNormalization()
-                self._feedforward_intermediate_norm.build([None, self.units])
-
-            self._feedforward_dropout = keras.layers.Dropout(self._dropout)
-            self._feedforward_activation = self._activation
-
-            self._feedforward_output_dense = self.get_dense(self.units)
-            self._feedforward_output_dense.build([None, self.units])
-
-    def message(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
-        """Compute messages.
-        """
-        message = tensor.gather('feature', 'source')
-        edge_feature = tensor.edge.get('feature')
-        if self._update_edge_feature:
-            edge_feature = self._edge_dense(edge_feature)
-        if edge_feature is not None:
-            message += edge_feature
-        return tensor.update(
-            {
-                'edge': {
-                    'message': message,
-                    'feature': edge_feature
-                }
-            }
-        )
-
-    def aggregate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
-        """Aggregates messages.
-        """
-        node_feature = tensor.aggregate('message')
-        node_feature += (1 + self.epsilon) * tensor.node['feature']
-        return tensor.update(
-            {
-                'node': {
-                    'feature': node_feature,
-                },
-                'edge': {
-                    'message': None,
-                }
-            }
-        )
-
-    def update(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
-        """Updates nodes.
-        """
-        node_feature = tensor.node['feature']
-        node_feature = self._feedforward_intermediate_dense(node_feature)
-        node_feature = self._feedforward_activation(node_feature)
-        if self._normalize:
-            node_feature = self._feedforward_intermediate_norm(node_feature)
-        node_feature = self._feedforward_dropout(node_feature)
-        node_feature = self._feedforward_output_dense(node_feature)
-        return tensor.update(
-            {
-                'node': {
-                    'feature': node_feature,
-                }
-            }
-        )
-
-    def get_config(self) -> dict:
-        config = super().get_config()
-        config.update({
-            'activation': keras.activations.serialize(self._activation),
-            'dropout': self._dropout,
-            'normalize': self._normalize,
-        })
-        return config
-
-
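Note on the removed `GINConv` (relocated in this release): it implements the standard graph isomorphism update, where each node adds `(1 + epsilon)` times its own features to the sum of incoming messages before a feedforward block. A minimal sketch of that aggregation step, written against plain Keras ops rather than molcraft's `GraphTensor` (the helper below is illustrative, not part of the package):

    import keras

    def gin_aggregate(node_feature, message, target, epsilon=0.0):
        # Sum incoming messages per target node (tensor.aggregate('message')),
        # then add the epsilon-weighted self-features, as in GINConv.aggregate.
        num_nodes = keras.ops.shape(node_feature)[0]
        aggregated = keras.ops.segment_sum(message, target, num_segments=num_nodes)
        return aggregated + (1.0 + epsilon) * node_feature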
-@keras.saving.register_keras_serializable(package='molcraft')
-class GTConv(GraphConv):
-
-    """Graph transformer layer.
+    Embeds edges based on their initial features.
     """
 
     def __init__(
-        self,
-
-        heads: int = 8,
-        activation: keras.layers.Activation | str | None = "relu",
-        dropout: float = 0.0,
-        attention_dropout: float = 0.0,
+        self,
+        dim: int = None,
         normalize: bool = True,
-
-        **kwargs
+        allow_masking: bool = True,
+        **kwargs
     ) -> None:
-        super().__init__(
-        self.
-        if self.units % self.heads != 0:
-            raise ValueError(f"units need to be divisible by heads.")
-        self._head_units = self.units // self.heads
-        self._activation = keras.activations.get(activation)
-        self._dropout = dropout
-        self._attention_dropout = attention_dropout
+        super().__init__(**kwargs)
+        self.dim = dim
         self._normalize = normalize
-        self.
+        self._masking_rate = None
+        self._allow_masking = allow_masking
 
-
-    def heads(self):
-        return self._heads
-
-    @property
-    def head_units(self):
-        return self._head_units
-
-    def build_from_spec(self, spec):
+    def build_from_spec(self, spec: tensors.GraphTensor.Spec) -> None:
         """Builds the layer.
         """
-
-
-
-
-
-                '`GTConv` uses residual connections, but input node feature dim '
-                'is incompatible with intermediate dim (`units`). '
-                'Automatically projecting first residual to match its dim with intermediate dim.'
-                ),
-                category=UserWarning,
-                stacklevel=1
-            )
-            self._residual_dense = self.get_dense(self.units)
-            self._residual_dense.build([None, node_feature_dim])
-            self._project_residual = True
-        else:
-            self._project_residual = False
-
-        self._query_dense = self.get_einsum_dense(
-            'ij,jkh->ikh', (self.head_units, self.heads)
-        )
-        self._query_dense.build([None, node_feature_dim])
-
-        self._key_dense = self.get_einsum_dense(
-            'ij,jkh->ikh', (self.head_units, self.heads)
-        )
-        self._key_dense.build([None, node_feature_dim])
-
-        self._value_dense = self.get_einsum_dense(
-            'ij,jkh->ikh', (self.head_units, self.heads)
-        )
-        self._value_dense.build([None, node_feature_dim])
-
-        self._output_dense = self.get_dense(self.units)
-        self._output_dense.build([None, self.units])
+        feature_dim = spec.edge['feature'].shape[-1]
+        if not self.dim:
+            self.dim = feature_dim
+        self._edge_dense = self.get_dense(self.dim)
+        self._edge_dense.build([None, feature_dim])
 
-        self.
+        self._has_super = 'super' in spec.edge
+        if self._has_super:
+            self._super_feature = self.get_weight(shape=[self.dim], name='super_edge_feature')
+        if self._allow_masking:
+            self._mask_feature = self.get_weight(shape=[self.dim], name='mask_edge_feature')
+        if self._normalize:
+            self._norm = keras.layers.LayerNormalization()
+            self._norm.build([None, self.dim])
 
-        self.
-        if self._normalize_first:
-            self._self_attention_norm.build([None, node_feature_dim])
-        else:
-            self._self_attention_norm.build([None, self.units])
+        self.built = True
 
-
+    def propagate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
+        """Calls the layer.
+        """
+        feature = self._edge_dense(tensor.edge['feature'])
 
-
-
-
-
-        self._has_edge_length = 'length' in spec.edge
-        if self._has_edge_length and 'bias' not in spec.edge:
-            edge_length_dim = spec.edge['length'].shape[-1]
-            self._spatial_encoding_dense = self.get_einsum_dense(
-                'ij,jkh->ikh', (1, self.heads), kernel_initializer='zeros'
-            )
-            self._spatial_encoding_dense.build([None, edge_length_dim])
+        if self._has_super:
+            super_feature = self._super_feature
+            super_mask = keras.ops.expand_dims(tensor.edge['super'], 1)
+            feature = keras.ops.where(super_mask, super_feature, feature)
 
-
-
-
-
-
+        if (
+            self._allow_masking and
+            self._masking_rate is not None and
+            self._masking_rate > 0
+        ):
+            random = keras.random.uniform(shape=[tensor.num_edges])
+            mask = random <= self._masking_rate
+            if self._has_super:
+                mask = keras.ops.logical_and(
+                    mask, keras.ops.logical_not(tensor.edge['super'])
                 )
-
-
-
-
-
-        self._feedforward_norm = keras.layers.LayerNormalization()
-        self._feedforward_norm.build([None, self.units])
-
-        self._feedforward_dropout = keras.layers.Dropout(self._dropout)
+            mask = keras.ops.expand_dims(mask, -1)
+            feature = keras.ops.where(mask, self._mask_feature, feature)
+        elif self._allow_masking:
+            # Silence warning of 'no gradients for variables'
+            feature = feature + (self._mask_feature * 0.0)
 
-
-        self.
+        if self._normalize:
+            feature = self._norm(feature)
 
-
-        self._feedforward_output_dense.build([None, self.units])
+        return tensor.update({'edge': {'feature': feature}})
 
-
-
+    @property
+    def masking_rate(self):
+        return self._masking_rate
 
-
-
-
-
-
-
-        if self._has_edge_feature and not self._has_edge_length:
-            edge_bias = self._edge_feature_dense(tensor.edge['feature'])
-        elif not self._has_edge_feature and self._has_edge_length:
-            edge_bias = self._spatial_encoding_dense(tensor.edge['length'])
-        else:
-            edge_bias = (
-                self._edge_feature_dense(tensor.edge['feature']) +
-                self._spatial_encoding_dense(tensor.edge['length'])
+    @masking_rate.setter
+    def masking_rate(self, rate: float):
+        if not self._allow_masking and rate is not None:
+            raise ValueError(
+                f'Cannot set `masking_rate` for layer {self} '
+                'as `allow_masking` was set to `False`.'
             )
-
-        return tensor.update(
-            {
-                'edge': {
-                    'bias': edge_bias
-                }
-            }
-        )
-
-    def message(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
-        """Compute messages.
-        """
-        tensor = self.add_edge_bias(tensor)
-        tensor = self.add_node_bias(tensor)
+        self._masking_rate = float(rate)
 
-
+    def get_config(self) -> dict:
+        config = super().get_config()
+        config.update({
+            'dim': self.dim,
+            'normalize': self._normalize,
+            'allow_masking': self._allow_masking
+        })
+        return config
+
 
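The block re-added above (new lines 1560-1613) is what enables masked-edge pretraining: a learned `mask_edge_feature` vector replaces the embedding of randomly selected non-super edges at the configured `masking_rate`. A hedged usage sketch — the class name `EdgeEmbedding` is inferred from the alias block elsewhere in this diff, not verified against the release:

    from molcraft import layers

    # Hypothetical pretraining setup: mask ~15% of edge embeddings.
    edge_embedding = layers.EdgeEmbedding(dim=128, allow_masking=True)
    edge_embedding.masking_rate = 0.15  # uses the masking_rate setter above
    # ... pretrain ...
    edge_embedding.masking_rate = 0.0   # disable masking for fine-tuning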
-
-
-
-
-
-
-        query = self._query_dense(node_feature)
-        key = self._key_dense(node_feature)
-        value = self._value_dense(node_feature)
+@keras.saving.register_keras_serializable(package='molcraft')
+class ContextProjection(Projection):
+    """Context projection layer.
+    """
+    def __init__(self, units: int = None, activation: str = None, **kwargs):
+        super().__init__(units=units, activation=activation, field='context', **kwargs)
 
-        query = ops.gather(query, tensor.edge['source'])
-        key = ops.gather(key, tensor.edge['target'])
-        value = ops.gather(value, tensor.edge['source'])
 
-
-
+@keras.saving.register_keras_serializable(package='molcraft')
+class NodeProjection(Projection):
+    """Node projection layer.
+    """
+    def __init__(self, units: int = None, activation: str = None, **kwargs):
+        super().__init__(units=units, activation=activation, field='node', **kwargs)
 
-        if 'bias' in tensor.edge:
-            attention_score += tensor.edge['bias']
-
-        attention = ops.edge_softmax(attention_score, tensor.edge['target'])
-        attention = self._softmax_dropout(attention)
 
-
-
-
-
-
-            }
-        )
+@keras.saving.register_keras_serializable(package='molcraft')
+class EdgeProjection(Projection):
+    """Edge projection layer.
+    """
+    def __init__(self, units: int = None, activation: str = None, **kwargs):
+        super().__init__(units=units, activation=activation, field='edge', **kwargs)
 
-    def aggregate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
-        """Aggregates messages.
-        """
-        node_feature = tensor.aggregate('message')
 
-
-        node_feature = self._self_attention_dropout(node_feature)
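The three projection classes re-added above are thin wrappers that fix the `field` argument of `Projection` to 'context', 'node' or 'edge'. In other words (a sketch, assuming `Projection` itself is unchanged from the previous release):

    from molcraft import layers

    # These two layers are equivalent by construction:
    node_proj = layers.NodeProjection(units=128, activation='relu')
    same_proj = layers.Projection(units=128, activation='relu', field='node')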
+@keras.saving.register_keras_serializable(package='molcraft')
+class EdgeBias(GraphLayer):
 
-
-
-
-        node_feature += residual
+    def __init__(self, biases: int, **kwargs):
+        super().__init__(**kwargs)
+        self.biases = biases
 
-
-
+    def build_from_spec(self, spec: tensors.GraphTensor.Spec) -> None:
+        self._has_edge_length = 'length' in spec.edge
+        self._has_edge_feature = 'feature' in spec.edge
+        if self._has_edge_feature:
+            self._edge_feature_dense = self.get_dense(self.biases)
+            self._edge_feature_dense.build([None, spec.edge['feature'].shape[-1]])
+        if self._has_edge_length:
+            self._edge_length_dense = self.get_dense(
+                self.biases, kernel_initializer='zeros'
+            )
+            self._edge_length_dense.build([None, spec.edge['length'].shape[-1]])
+        self.built = True
 
-
-
-
-
-                },
-                'edge': {
-                    'message': None,
-                    'weight': None,
-                }
-            }
+    def propagate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
+        bias = keras.ops.zeros(
+            shape=(tensor.num_edges, self.biases),
+            dtype=tensor.node['feature'].dtype
         )
-
-
-
-
-
-        node_feature = tensor.node['feature']
-
-        if self._normalize_first:
-            node_feature = self._feedforward_norm(node_feature)
+        if self._has_edge_feature:
+            bias += self._edge_feature_dense(tensor.edge['feature'])
+        if self._has_edge_length:
+            bias += self._edge_length_dense(tensor.edge['length'])
+        return bias
 
-
-
-
+    def get_config(self) -> dict:
+        config = super().get_config()
+        config.update({'biases': self.biases})
+        return config
+
 
-
-
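The new `EdgeBias` layer produces one bias per edge and attention head (`biases`), summing a learned projection of edge features with a zero-initialized projection of edge lengths, so geometry only influences attention once the length projection trains away from zero. A rough sketch of how such a bias is typically consumed (illustrative; the attention wiring is not part of this hunk, and molcraft's own per-target softmax is `ops.edge_softmax`):

    import keras

    def biased_edge_softmax(score, edge_bias, target, num_nodes):
        # score, edge_bias: [num_edges, heads]; the bias shifts the
        # attention logits before the per-target-node softmax.
        logits = score + edge_bias
        max_per_node = keras.ops.segment_max(logits, target, num_segments=num_nodes)
        logits -= keras.ops.take(max_per_node, target, axis=0)
        exp = keras.ops.exp(logits)
        denom = keras.ops.segment_sum(exp, target, num_segments=num_nodes)
        return exp / keras.ops.take(denom, target, axis=0)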
+@keras.saving.register_keras_serializable(package='molcraft')
+class GaussianDistance(GraphLayer):
 
-
-
+    def __init__(self, kernels: int, **kwargs):
+        super().__init__(**kwargs)
+        self.kernels = kernels
 
-
-
-
-
-
-
+    def build_from_spec(self, spec: tensors.GraphTensor.Spec) -> None:
+        self._loc = self.add_weight(
+            shape=[self.kernels],
+            initializer='zeros',
+            dtype='float32',
+            trainable=True
+        )
+        self._scale = self.add_weight(
+            shape=[self.kernels],
+            initializer='ones',
+            dtype='float32',
+            trainable=True
         )
-
+        self.built = True
+
+    def propagate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
+        euclidean_distance = ops.euclidean_distance(
+            tensor.gather('coordinate', 'source'),
+            tensor.gather('coordinate', 'target'),
+            axis=-1
+        )
+        return ops.gaussian(
+            euclidean_distance, self._loc, self._scale
+        )
+
     def get_config(self) -> dict:
         config = super().get_config()
         config.update({
-
-            'activation': keras.activations.serialize(self._activation),
-            'dropout': self._dropout,
-            'attention_dropout': self._attention_dropout,
-            'normalize': self._normalize,
-            'normalize_first': self._normalize_first,
+            'kernels': self.kernels,
         })
         return config
-
+
 
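`GaussianDistance` expands each source-target Euclidean distance into `kernels` radial-basis responses with learnable centers (`_loc`, zero-initialized) and widths (`_scale`, one-initialized). Assuming `ops.gaussian` evaluates an unnormalized Gaussian per kernel (the exact form lives in `molcraft.ops`), the computation is roughly:

    import keras

    def gaussian_expand(distance, loc, scale):
        # distance: [num_edges, 1]; loc and scale: [kernels].
        # Broadcasts to [num_edges, kernels] of RBF responses.
        z = (distance - loc) / scale
        return keras.ops.exp(-0.5 * keras.ops.square(z))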
 @keras.saving.register_keras_serializable(package='molcraft')
-class Readout(keras.layers.Layer):
+class Readout(GraphLayer):
+
+    """Readout layer.
+    """
 
     def __init__(self, mode: str | None = None, **kwargs):
+        kwargs['kernel_initializer'] = None
+        kwargs['bias_initializer'] = None
         super().__init__(**kwargs)
         self.mode = mode
-        if
-            self._reduce_fn = None
-        elif str(self.mode).lower().startswith('sum'):
+        if str(self.mode).lower().startswith('sum'):
             self._reduce_fn = keras.ops.segment_sum
         elif str(self.mode).lower().startswith('max'):
             self._reduce_fn = keras.ops.segment_max

@@ -1052,50 +1736,25 @@ class Readout(keras.layers.Layer):

     def build_from_spec(self, spec: tensors.GraphTensor.Spec) -> None:
         """Builds the layer.
         """
-
+        self.built = True
 
-    def
-
-
+    def propagate(self, tensor: tensors.GraphTensor) -> tf.Tensor:
+        """Calls the layer.
+        """
+        node_feature = tensor.node['feature']
         if str(self.mode).lower().startswith('super'):
             node_feature = keras.ops.where(
-                tensor.node['super'][:, None],
-            )
-        return self._reduce_fn(
-            node_feature, tensor.graph_indicator, tensor.num_subgraphs
+                tensor.node['super'][:, None], node_feature, 0.0
             )
         return self._reduce_fn(
-
+            node_feature, tensor.graph_indicator, tensor.num_subgraphs
         )
 
-    def build(self, input_shapes) -> None:
-        spec = tensors.GraphTensor.Spec.from_input_shape_dict(input_shapes)
-        self.build_from_spec(spec)
-        self.built = True
-
-    def call(self, graph) -> tf.Tensor:
-        graph_tensor = tensors.from_dict(graph)
-        if tensors.is_ragged(graph_tensor):
-            graph_tensor = graph_tensor.flatten()
-        return self.reduce(graph_tensor)
-
-    def __call__(
-        self,
-        graph: tensors.GraphTensor,
-        *args,
-        **kwargs
-    ) -> tensors.GraphTensor:
-        is_tensor = isinstance(graph, tensors.GraphTensor)
-        if is_tensor:
-            graph = tensors.to_dict(graph)
-        tensor = super().__call__(graph, *args, **kwargs)
-        return tensor
-
     def get_config(self) -> dict:
         config = super().get_config()
         config['mode'] = self.mode
         return config
-
+
 
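`Readout`, now a `GraphLayer`, collapses node features into one vector per subgraph via the configured segment reduction; 'super' modes first zero out every non-super node so only super-node features survive the reduction. A usage sketch (mode names beyond 'sum' and 'max' are only those visible in the startswith checks above):

    from molcraft import layers

    readout = layers.Readout(mode='sum')
    # graph_feature = readout(graph_tensor)
    # -> shape [num_subgraphs, node_feature_dim]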
 def Input(spec: tensors.GraphTensor.Spec) -> dict:
     """Used to specify inputs to model.

@@ -1212,13 +1871,6 @@ def _spec_from_inputs(inputs):

     return tensors.GraphTensor.Spec(**nested_specs)
 
 
-GraphTransformer =
-
-
-EdgeEmbed = EdgeEmbedding
-NodeEmbed = NodeEmbedding
-
-ContextDense = ContextProjection
-EdgeDense = EdgeProjection
-NodeDense = NodeProjection
+GraphTransformer = GTConv
+GraphTransformer3D = GTConv3D
 
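The alias block is rewritten: the old shorthand names (`EdgeEmbed`, `NodeEmbed`, `ContextDense`, `EdgeDense`, `NodeDense`) are dropped, and `GraphTransformer`/`GraphTransformer3D` now point at `GTConv`/`GTConv3D`. So, for example (assuming `GTConv` keeps the `GraphConv` constructor signature from this release):

    from molcraft import layers

    assert layers.GraphTransformer is layers.GTConv
    conv = layers.GraphTransformer(units=128)  # same as layers.GTConv(units=128)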