molcraft 0.1.0a7__py3-none-any.whl → 0.1.0a9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of molcraft might be problematic.
- molcraft/__init__.py +1 -1
- molcraft/callbacks.py +33 -26
- molcraft/chem.py +15 -16
- molcraft/features.py +3 -9
- molcraft/featurizers.py +28 -38
- molcraft/layers.py +439 -858
- molcraft/ops.py +12 -1
- {molcraft-0.1.0a7.dist-info → molcraft-0.1.0a9.dist-info}/METADATA +2 -2
- molcraft-0.1.0a9.dist-info/RECORD +19 -0
- molcraft-0.1.0a7.dist-info/RECORD +0 -19
- {molcraft-0.1.0a7.dist-info → molcraft-0.1.0a9.dist-info}/WHEEL +0 -0
- {molcraft-0.1.0a7.dist-info → molcraft-0.1.0a9.dist-info}/licenses/LICENSE +0 -0
- {molcraft-0.1.0a7.dist-info → molcraft-0.1.0a9.dist-info}/top_level.txt +0 -0
molcraft/layers.py
CHANGED
@@ -274,20 +274,14 @@ class GraphConv(GraphLayer):
         units (int):
             Dimensionality of the output space.
         activation (keras.layers.Activation, str, None):
-            Activation function to
-            Default to `
+            Activation function to be accessed via `self.activation`, and used for the
+            `message()` and `update()` methods, if not overriden. Default to `relu`.
         use_bias (bool):
-            Whether bias should be used in dense layers. Default to `True`.
+            Whether bias should be used in the dense layers. Default to `True`.
         normalize (bool, str):
-            Whether
-
-
-            Whether node feature input should be added to the node feature output.
-            If node feature input dim is not equal to `units` (node feature output dim),
-            a projection layer will automatically project the residual before adding it
-            to the output. To use weighted skip connection,
-            specify `weighted`. The weight multiplied with the skip connection is a
-            learnable scalar. Default to `True`.
+            Whether normalization should be applied to the final output. Default to `False`.
+        skip_connect (bool):
+            Whether node feature input should be added to the node feature output. Default to `True`.
         kernel_initializer (keras.initializers.Initializer, str):
             Initializer for the kernel weight matrix of the dense layers.
             Default to `glorot_uniform`.
@@ -314,10 +308,10 @@ class GraphConv(GraphLayer):
     def __init__(
         self,
         units: int = None,
-        activation: str | keras.layers.Activation | None =
+        activation: str | keras.layers.Activation | None = 'relu',
         use_bias: bool = True,
-        normalize: bool
-        skip_connect: bool
+        normalize: bool = False,
+        skip_connect: bool = True,
         **kwargs
     ) -> None:
         super().__init__(use_bias=use_bias, **kwargs)
@@ -341,56 +335,56 @@ class GraphConv(GraphLayer):
     def units(self):
         return self._units

+    @property
+    def activation(self):
+        return self._activation
+
     def build(self, spec: tensors.GraphTensor.Spec) -> None:
         if not self.units:
             raise ValueError(
                 f'`self.units` needs to be a positive integer. Found: {self.units}.'
             )
         node_feature_dim = spec.node['feature'].shape[-1]
-        self.
+        self._project_residual = (
             self._skip_connect and (node_feature_dim != self.units)
         )
-        if self.
-            warn(
+        if self._project_residual:
+            warnings.warn(
                 '`skip_connect` is set to `True`, but found incompatible dim '
                 'between input (node feature dim) and output (`self.units`). '
                 'Automatically applying a projection layer to residual to '
-                'match input and output. '
+                'match input and output. ',
+                stacklevel=2,
             )
-            self.
-                self.units, name='
+            self._residual_dense = self.get_dense(
+                self.units, name='residual_dense'
            )

-
-        self.
-        if self._use_weighted_skip_connection:
-            self._skip_connection_weight = self.add_weight(
-                name='skip_connection_weight',
-                shape=(),
-                initializer='ones',
-                trainable=True,
-            )
-
-        if self._normalize:
-            if str(self._normalize).lower().startswith('batch'):
-                self._output_norm = keras.layers.BatchNormalization(
-                    name='output_batch_norm'
-                )
-            else:
-                self._output_norm = keras.layers.LayerNormalization(
-                    name='output_layer_norm'
-                )
-
-        self._has_edge_feature = 'feature' in spec.edge
+        self.has_edge_feature = 'feature' in spec.edge
+        self.has_node_coordinate = 'coordinate' in spec.node

         has_overridden_message = self.__class__.message != GraphConv.message
         if not has_overridden_message:
-            self.
+            self._message_intermediate_dense = self.get_dense(self.units)
+            self._message_intermediate_activation = self.activation
+            self._message_final_dense = self.get_dense(self.units)
+
+        has_overridden_aggregate = self.__class__.message != GraphConv.aggregate
+        if not has_overridden_aggregate:
+            pass

         has_overridden_update = self.__class__.update != GraphConv.update
         if not has_overridden_update:
-            self.
-            self.
+            self._update_intermediate_dense = self.get_dense(self.units)
+            self._update_intermediate_activation = self.activation
+            self._update_final_dense = self.get_dense(self.units)
+
+        if not self._normalize:
+            self._normalization = keras.layers.Identity()
+        elif str(self._normalize).lower().startswith('layer'):
+            self._normalization = keras.layers.LayerNormalization()
+        else:
+            self._normalization = keras.layers.BatchNormalization()

     def propagate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
         """Forward pass.
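The hunk above replaces the old weighted-skip-connection machinery with a single `_project_residual` flag: the residual is projected only when the input node-feature dim differs from `units`. A minimal standalone Keras sketch of that rule (illustrative names, not molcraft's code):

    import keras

    units = 128
    node_feature = keras.random.normal((4, 64))   # input dim 64 != units
    residual = node_feature
    if residual.shape[-1] != units:
        # mirrors `_residual_dense`: project the residual to `units`
        residual = keras.layers.Dense(units, name='residual_dense')(residual)
    conv_output = keras.random.normal((4, units))  # stand-in for the conv output
    conv_output = conv_output + residual           # skip connection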
@@ -402,10 +396,10 @@ class GraphConv(GraphLayer):
             A `GraphTensor` instance.
         """
         if self._skip_connect:
-
-            if self.
-
-
+            residual = tensor.node['feature']
+            if self._project_residual:
+                residual = self._residual_dense(residual)
+
         message = self.message(tensor)
         if not isinstance(message, tensors.GraphTensor):
             message = tensor.update({'edge': {'message': message}})
@@ -417,24 +411,24 @@ class GraphConv(GraphLayer):
             aggregate = tensor.update({'node': {'aggregate': aggregate}})
         elif not 'aggregate' in aggregate.node:
             raise ValueError('Could not find `aggregate` in `node` output.')
-
+
         update = self.update(aggregate)
         if not isinstance(update, tensors.GraphTensor):
             update = tensor.update({'node': {'feature': update}})
         elif not 'feature' in update.node:
             raise ValueError('Could not find `feature` in `node` output.')
-
-
+
+        if update.node['feature'].shape[-1] != self.units:
+            raise ValueError('Updated node `feature` is not equal to `self.units`.')
+
+        feature = update.node['feature']

         if self._skip_connect:
-
-
-
-
-            if self._normalize:
-                updated_node_feature = self._output_norm(updated_node_feature)
+            feature += residual
+
+        feature = self._normalization(feature)

-        return update.update({'node': {'feature':
+        return update.update({'node': {'feature': feature}})
@@ -445,24 +439,38 @@ class GraphConv(GraphLayer):
         tensor:
             The inputted `GraphTensor` instance.
         """
-
-
-
-
+        message = keras.ops.concatenate(
+            [
+                tensor.gather('feature', 'source'),
+                tensor.gather('feature', 'target'),
+            ],
+            axis=-1
+        )
+        if self.has_edge_feature:
+            message = keras.ops.concatenate(
                 [
-
+                    message,
                     tensor.edge['feature']
                 ],
                 axis=-1
             )
-
-
-
-            '
-
-
-
-
+        if self.has_node_coordinate:
+            euclidean_distance = ops.euclidean_distance(
+                tensor.gather('coordinate', 'target'),
+                tensor.gather('coordinate', 'source'),
+                axis=-1
+            )
+            message = keras.ops.concatenate(
+                [
+                    message,
+                    euclidean_distance
+                ],
+                axis=-1
+            )
+        message = self._message_intermediate_dense(message)
+        message = self._message_intermediate_activation(message)
+        message = self._message_final_dense(message)
+        return tensor.update({'edge': {'message': message}})

     def aggregate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
         """Aggregates messages.
@@ -473,14 +481,16 @@ class GraphConv(GraphLayer):
         tensor:
             A `GraphTensor` instance containing a message.
         """
+        previous = tensor.node['feature']
         aggregate = tensor.aggregate('message', mode='mean')
+        aggregate = keras.ops.concatenate([aggregate, previous], axis=-1)
         return tensor.update(
             {
                 'node': {
-                    'aggregate': aggregate,
+                    'aggregate': aggregate,
                 },
                 'edge': {
-                    'message': None
+                    'message': None,
                 }
             }
         )
@@ -495,21 +505,16 @@ class GraphConv(GraphLayer):
             A `GraphTensor` instance containing aggregated messages
             (updated node features).
         """
-
-
-
-
-            ],
-            axis=-1
-        )
-        update = self._output_dense(feature)
-        update = self._output_activation(update)
+        aggregate = tensor.node['aggregate']
+        node_feature = self._update_intermediate_dense(aggregate)
+        node_feature = self._update_intermediate_activation(node_feature)
+        node_feature = self._update_final_dense(node_feature)
         return tensor.update(
             {
                 'node': {
-                    'feature':
+                    'feature': node_feature,
                     'aggregate': None,
-                }
+                },
             }
         )
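Taken together, the hunks above define the default `message -> aggregate -> update` pipeline, and the `has_overridden_*` checks in `build` create the default dense stacks only when a subclass keeps the default methods. A hedged sketch of a subclass against the API shown in this diff (the class name and message rule are illustrative):

    import keras
    from molcraft import layers, tensors

    @keras.saving.register_keras_serializable(package='molcraft')
    class SourceConv(layers.GraphConv):
        def message(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
            # simplest possible message: reuse the gathered source node feature;
            # GraphConv.build then skips building its default message dense stack
            message = tensor.gather('feature', 'source')
            return tensor.update({'edge': {'message': message}})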
@@ -563,14 +568,16 @@ class GIConv(GraphConv):
         activation: keras.layers.Activation | str | None = 'relu',
         use_bias: bool = True,
         normalize: bool = False,
+        skip_connect: bool = True,
         update_edge_feature: bool = True,
         **kwargs,
     ):
         super().__init__(
             units=units,
             activation=activation,
-            normalize=normalize,
             use_bias=use_bias,
+            normalize=normalize,
+            skip_connect=skip_connect,
             **kwargs
         )
         self._update_edge_feature = update_edge_feature
@@ -585,16 +592,16 @@ class GIConv(GraphConv):
             trainable=True,
         )

-        self.
-        if self._has_edge_feature:
+        if self.has_edge_feature:
             edge_feature_dim = spec.edge['feature'].shape[-1]

             if not self._update_edge_feature:
                 if (edge_feature_dim != node_feature_dim):
-                    warn(
+                    warnings.warn(
                         'Found edge feature dim to be incompatible with node feature dim. '
                         'Automatically adding a edge feature projection layer to match '
-                        'the dim of node features.'
+                        'the dim of node features.',
+                        stacklevel=2,
                     )
                     self._update_edge_feature = True
@@ -603,19 +610,14 @@ class GIConv(GraphConv):
         else:
             self._update_edge_feature = False

-        has_overridden_update = self.__class__.update != GIConv.update
-        if not has_overridden_update:
-            self._feedforward_intermediate_dense = self.get_dense(self.units)
-            self._feedforward_activation = self._activation
-            self._feedforward_output_dense = self.get_dense(self.units)
-
     def message(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
         message = tensor.gather('feature', 'source')
         edge_feature = tensor.edge.get('feature')
         if self._update_edge_feature:
             edge_feature = self._edge_dense(edge_feature)
-        if self.
+        if self.has_edge_feature:
             message += edge_feature
+        message = keras.ops.relu(message)
         return tensor.update(
             {
                 'edge': {
@@ -639,20 +641,6 @@ class GIConv(GraphConv):
             }
         )

-    def update(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
-        node_feature = tensor.node['aggregate']
-        node_feature = self._feedforward_intermediate_dense(node_feature)
-        node_feature = self._feedforward_activation(node_feature)
-        node_feature = self._feedforward_output_dense(node_feature)
-        return tensor.update(
-            {
-                'node': {
-                    'feature': node_feature,
-                    'aggregate': None,
-                }
-            }
-        )
-
     def get_config(self) -> dict:
         config = super().get_config()
         config.update({
@@ -701,15 +689,16 @@ class GAConv(GraphConv):
         activation: keras.layers.Activation | str | None = "relu",
         use_bias: bool = True,
         normalize: bool = False,
+        skip_connect: bool = True,
         update_edge_feature: bool = True,
-        attention_activation: keras.layers.Activation | str | None = "leaky_relu",
         **kwargs,
     ) -> None:
         super().__init__(
             units=units,
             activation=activation,
-            use_bias=use_bias,
             normalize=normalize,
+            use_bias=use_bias,
+            skip_connect=skip_connect,
             **kwargs
         )
         self._heads = heads
@@ -717,7 +706,6 @@ class GAConv(GraphConv):
             raise ValueError(f"units need to be divisible by heads.")
         self._head_units = self.units // self.heads
         self._update_edge_feature = update_edge_feature
-        self._attention_activation = keras.activations.get(attention_activation)

     @property
     def heads(self):
@@ -728,8 +716,7 @@ class GAConv(GraphConv):
         return self._head_units

     def build(self, spec: tensors.GraphTensor.Spec) -> None:
-        self.
-        self._update_edge_feature = self._has_edge_feature and self._update_edge_feature
+        self._update_edge_feature = self.has_edge_feature and self._update_edge_feature
         if self._update_edge_feature:
             self._edge_dense = self.get_einsum_dense(
                 'ijh,jkh->ikh', (self.head_units, self.heads)
@@ -743,15 +730,6 @@ class GAConv(GraphConv):
         self._attention_dense = self.get_einsum_dense(
             'ijh,jkh->ikh', (1, self.heads)
         )
-        self._node_self_dense = self.get_einsum_dense(
-            'ij,jkh->ikh', (self.head_units, self.heads)
-        )
-
-        has_overridden_update = self.__class__.update != GAConv.update
-        if not has_overridden_update:
-            self._feedforward_intermediate_dense = self.get_dense(self.units)
-            self._feedforward_activation = self._activation
-            self._feedforward_output_dense = self.get_dense(self.units)

     def message(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
         attention_feature = keras.ops.concatenate(
@@ -761,7 +739,7 @@ class GAConv(GraphConv):
             ],
             axis=-1
         )
-        if self.
+        if self.has_edge_feature:
             attention_feature = keras.ops.concatenate(
                 [
                     attention_feature,
@@ -778,7 +756,7 @@ class GAConv(GraphConv):
             edge_feature = self._edge_dense(attention_feature)
             edge_feature = keras.ops.reshape(edge_feature, (-1, self.units))

-        attention_feature =
+        attention_feature = keras.ops.leaky_relu(attention_feature)
         attention_score = self._attention_dense(attention_feature)
         attention_score = ops.edge_softmax(
             score=attention_score, edge_target=tensor.edge['target']
@@ -797,7 +775,6 @@ class GAConv(GraphConv):

     def aggregate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
         node_feature = tensor.aggregate('message', mode='sum')
-        node_feature += self._node_self_dense(tensor.node['feature'])
         node_feature = keras.ops.reshape(node_feature, (-1, self.units))
         return tensor.update(
             {
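`ops.edge_softmax(score, edge_target)` normalizes attention scores over the incoming edges of each target node. Its implementation is not part of this diff; a standard segment-softmax sketch using Keras 3 segment ops (with `num_nodes` supplied by the caller) would look like:

    import keras

    def edge_softmax(score, edge_target, num_nodes):
        # subtract each node's max incoming score for numerical stability
        score_max = keras.ops.segment_max(score, edge_target, num_segments=num_nodes)
        score = score - keras.ops.take(score_max, edge_target, axis=0)
        exp = keras.ops.exp(score)
        denom = keras.ops.segment_sum(exp, edge_target, num_segments=num_nodes)
        return exp / keras.ops.take(denom, edge_target, axis=0)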
@@ -810,28 +787,123 @@ class GAConv(GraphConv):
             }
         )

+    def get_config(self) -> dict:
+        config = super().get_config()
+        config.update({
+            "heads": self._heads,
+            'update_edge_feature': self._update_edge_feature,
+        })
+        return config
+
+
+@keras.saving.register_keras_serializable(package='molcraft')
+class MPConv(GraphConv):
+
+    """Message passing neural network layer.
+
+    Also supports 3D molecular graphs.
+
+    >>> graph = molcraft.tensors.GraphTensor(
+    ...     context={
+    ...         'size': [2]
+    ...     },
+    ...     node={
+    ...         'feature': [[1.], [2.]]
+    ...     },
+    ...     edge={
+    ...         'source': [0, 1],
+    ...         'target': [1, 0],
+    ...     }
+    ... )
+    >>> conv = molcraft.layers.MPConv(units=4)
+    >>> conv(graph)
+    GraphTensor(
+        context={
+            'size': <tf.Tensor: shape=[1], dtype=int32>
+        },
+        node={
+            'feature': <tf.Tensor: shape=[2, 4], dtype=float32>
+        },
+        edge={
+            'source': <tf.Tensor: shape=[2], dtype=int32>,
+            'target': <tf.Tensor: shape=[2], dtype=int32>
+        }
+    )
+    """
+
+    def __init__(
+        self,
+        units: int = 128,
+        activation: keras.layers.Activation | str | None = 'relu',
+        use_bias: bool = True,
+        normalize: bool = False,
+        skip_connect: bool = True,
+        **kwargs
+    ) -> None:
+        super().__init__(
+            units=units,
+            activation=activation,
+            use_bias=use_bias,
+            normalize=normalize,
+            skip_connect=skip_connect,
+            **kwargs
+        )
+
+    def build(self, spec: tensors.GraphTensor.Spec) -> None:
+        node_feature_dim = spec.node['feature'].shape[-1]
+        self.update_fn = keras.layers.GRUCell(self.units)
+        self._project_previous_node_feature = node_feature_dim != self.units
+        if self._project_previous_node_feature:
+            warnings.warn(
+                'Input node feature dim does not match updated node feature dim. '
+                'To make sure input node feature can be passed as `states` to the '
+                'GRU cell, it will automatically be projected prior to it.',
+                stacklevel=2
+            )
+            self._previous_node_dense = self.get_dense(self.units)
+
+    def aggregate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
+        """Aggregates messages.
+
+        This method may be overridden by subclass.
+
+        Arguments:
+            tensor:
+                A `GraphTensor` instance containing a message.
+        """
+        aggregate = tensor.aggregate('message', mode='mean')
+        return tensor.update(
+            {
+                'node': {
+                    'aggregate': aggregate,
+                },
+                'edge': {
+                    'message': None,
+                }
+            }
+        )
+
     def update(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
-
-
-
-
+        previous = tensor.node['feature']
+        aggregate = tensor.node['aggregate']
+        if self._project_previous_node_feature:
+            previous = self._previous_node_dense(previous)
+        updated_node_feature, _ = self.update_fn(
+            inputs=aggregate, states=previous
+        )
         return tensor.update(
             {
                 'node': {
-                    'feature':
+                    'feature': updated_node_feature,
                     'aggregate': None,
                 }
             }
         )
-
+
     def get_config(self) -> dict:
         config = super().get_config()
-        config.update({
-
-            'update_edge_feature': self._update_edge_feature,
-            'attention_activation': keras.activations.serialize(self._attention_activation),
-        })
-        return config
+        config.update({})
+        return config


 @keras.saving.register_keras_serializable(package='molcraft')
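The new `MPConv.update` above drives a GRU cell with the aggregated messages as input and the (possibly projected) previous node features as state. A minimal standalone sketch of that call pattern in plain Keras:

    import keras

    units = 4
    cell = keras.layers.GRUCell(units)
    aggregate = keras.random.normal((2, units))  # aggregated messages per node
    previous = keras.random.normal((2, units))   # previous node features
    updated, _ = cell(inputs=aggregate, states=previous)  # updated: shape (2, 4)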
@@ -839,6 +911,8 @@ class GTConv(GraphConv):

     """Graph transformer layer.

+    Also supports 3D molecular graphs.
+
     >>> graph = molcraft.tensors.GraphTensor(
     ...     context={
     ...         'size': [2]
@@ -862,10 +936,9 @@ class GTConv(GraphConv):
         },
         edge={
             'source': <tf.Tensor: shape=[2], dtype=int32>,
-            'target': <tf.Tensor: shape=[2], dtype=int32
+            'target': <tf.Tensor: shape=[2], dtype=int32>,
         }
     )
-
     """

     def __init__(
@@ -875,14 +948,16 @@ class GTConv(GraphConv):
         activation: keras.layers.Activation | str | None = "relu",
         use_bias: bool = True,
         normalize: bool = False,
+        skip_connect: bool = True,
         attention_dropout: float = 0.0,
         **kwargs,
     ) -> None:
         super().__init__(
             units=units,
             activation=activation,
-            use_bias=use_bias,
             normalize=normalize,
+            use_bias=use_bias,
+            skip_connect=skip_connect,
             **kwargs
         )
         self._heads = heads
@@ -900,6 +975,8 @@ class GTConv(GraphConv):
         return self._head_units

     def build(self, spec: tensors.GraphTensor.Spec) -> None:
+        """Builds the layer.
+        """
         self._query_dense = self.get_einsum_dense(
             'ij,jkh->ikh', (self.head_units, self.heads)
         )
@@ -912,29 +989,36 @@ class GTConv(GraphConv):
         self._output_dense = self.get_dense(self.units)
         self._softmax_dropout = keras.layers.Dropout(self._attention_dropout)

-        self.
-
-        if self._add_bias:
-            self._edge_bias = EdgeBias(biases=self.heads)
+        if self.has_edge_feature:
+            self._attention_bias_dense_1 = self.get_einsum_dense('ij,jkh->ikh', (1, self.heads))

-
-
-
-        self.
-
+        if self.has_node_coordinate:
+            node_feature_dim = spec.node['feature'].shape[-1]
+            num_kernels = self.units
+            self._gaussian_loc = self.add_weight(
+                shape=[num_kernels], initializer='zeros', dtype='float32', trainable=True
+            )
+            self._gaussian_scale = self.add_weight(
+                shape=[num_kernels], initializer='ones', dtype='float32', trainable=True
+            )
+            self._centrality_dense = self.get_dense(units=node_feature_dim)
+            self._attention_bias_dense_2 = self.get_einsum_dense('ij,jkh->ikh', (1, self.heads))

     def message(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
-        if self._add_bias:
-            edge_bias = self._edge_bias(tensor)
-            tensor = tensor.update(
-                {
-                    'edge': {
-                        'bias': edge_bias
-                    }
-                }
-            )
         node_feature = tensor.node['feature']
-
+
+        if self.has_node_coordinate:
+            euclidean_distance = ops.euclidean_distance(
+                tensor.gather('coordinate', 'target'),
+                tensor.gather('coordinate', 'source'),
+                axis=-1
+            )
+            gaussian = ops.gaussian(
+                euclidean_distance, self._gaussian_loc, self._gaussian_scale
+            )
+            centrality = keras.ops.segment_sum(gaussian, tensor.edge['target'], tensor.num_nodes)
+            node_feature += self._centrality_dense(centrality)
+
         query = self._query_dense(node_feature)
         key = self._key_dense(node_feature)
         value = self._value_dense(node_feature)
@@ -946,23 +1030,45 @@ class GTConv(GraphConv):
         attention_score = keras.ops.sum(query * key, axis=1, keepdims=True)
         attention_score /= keras.ops.sqrt(float(self.head_units))

-
+        if self.has_edge_feature:
+            attention_score += self._attention_bias_dense_1(tensor.edge['feature'])
+
+        if self.has_node_coordinate:
+            attention_score += self._attention_bias_dense_2(gaussian)
+
         attention = ops.edge_softmax(attention_score, tensor.edge['target'])
         attention = self._softmax_dropout(attention)
-        message = ops.edge_weight(value, attention)

+        if self.has_node_coordinate:
+            displacement = ops.displacement(
+                tensor.gather('coordinate', 'target'),
+                tensor.gather('coordinate', 'source'),
+                normalize=True
+            )
+            attention *= keras.ops.expand_dims(displacement, axis=-1)
+            attention = keras.ops.expand_dims(attention, axis=2)
+            value = keras.ops.expand_dims(value, axis=1)
+
+        message = ops.edge_weight(value, attention)
+
         return tensor.update(
             {
                 'edge': {
-                    'message': message
+                    'message': message,
                 },
             }
         )

     def aggregate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
         node_feature = tensor.aggregate('message', mode='sum')
-
+        if self.has_node_coordinate:
+            shape = (tensor.num_nodes, -1, self.units)
+        else:
+            shape = (tensor.num_nodes, self.units)
+        node_feature = keras.ops.reshape(node_feature, shape)
         node_feature = self._output_dense(node_feature)
+        if self.has_node_coordinate:
+            node_feature = keras.ops.sum(node_feature, axis=1)
         return tensor.update(
             {
                 'node': {
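`ops.gaussian` is not defined in this diff; given the `_gaussian_loc`/`_gaussian_scale` weights built above, it plausibly expands each pairwise distance into a basis of Gaussians with learnable location and scale, in the style of Graphormer's 3D bias. A hedged sketch:

    import keras

    def gaussian(distance, loc, scale):
        # distance: [num_edges, 1]; loc, scale: [num_kernels]
        z = (distance - loc) / scale
        return keras.ops.exp(-0.5 * z * z)  # -> [num_edges, num_kernels]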
@@ -973,20 +1079,6 @@ class GTConv(GraphConv):
                 }
             }
         )
-
-    def update(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
-        node_feature = tensor.node['aggregate']
-        node_feature = self._feedforward_intermediate_dense(node_feature)
-        node_feature = self._feedforward_activation(node_feature)
-        node_feature = self._feedforward_output_dense(node_feature)
-        return tensor.update(
-            {
-                'node': {
-                    'feature': node_feature,
-                    'aggregate': None,
-                },
-            }
-        )

     def get_config(self) -> dict:
         config = super().get_config()
@@ -998,17 +1090,49 @@ class GTConv(GraphConv):


 @keras.saving.register_keras_serializable(package='molcraft')
-class
+class EGConv(GraphConv):

-    """
+    """Equivariant graph neural network layer 3D.
+
+    Only supports 3D molecular graphs.
+
+    >>> graph = molcraft.tensors.GraphTensor(
+    ...     context={
+    ...         'size': [2]
+    ...     },
+    ...     node={
+    ...         'feature': [[1.], [2.]],
+    ...         'coordinate': [[0.1, -0.1, 0.5], [1.2, -0.5, 2.1]],
+    ...     },
+    ...     edge={
+    ...         'source': [0, 1],
+    ...         'target': [1, 0],
+    ...     }
+    ... )
+    >>> conv = molcraft.layers.EGConv(units=4)
+    >>> conv(graph)
+    GraphTensor(
+        context={
+            'size': <tf.Tensor: shape=[1], dtype=int32>
+        },
+        node={
+            'feature': <tf.Tensor: shape=[2, 4], dtype=float32>,
+            'coordinate': <tf.Tensor: shape=[2, 3], dtype=float32>
+        },
+        edge={
+            'source': <tf.Tensor: shape=[2], dtype=int32>,
+            'target': <tf.Tensor: shape=[2], dtype=int32>
+        }
+    )
     """

     def __init__(
         self,
         units: int = 128,
-        activation: keras.layers.Activation | str | None =
+        activation: keras.layers.Activation | str | None = 'silu',
         use_bias: bool = True,
         normalize: bool = False,
+        skip_connect: bool = True,
         **kwargs
     ) -> None:
         super().__init__(
@@ -1016,262 +1140,30 @@ class MPConv(GraphConv):
         activation=activation,
         use_bias=use_bias,
         normalize=normalize,
+        skip_connect=skip_connect,
         **kwargs
     )

     def build(self, spec: tensors.GraphTensor.Spec) -> None:
-
-
-
-
-        self.project_input_node_feature = node_feature_dim != self.units
-        if self.project_input_node_feature:
-            warn(
-                'Input node feature dim does not match updated node feature dim. '
-                'To make sure input node feature can be passed as `states` to the '
-                'GRU cell, it will automatically be projected prior to it.'
-            )
-            self._previous_node_dense = self.get_dense(
-                self.units, activation=self._activation
+        if not self.has_node_coordinate:
+            raise ValueError(
+                'Could not find `coordinate`s in node, '
+                'which is required for Conv3D layers.'
             )
-
-
-
-
-
-
-            ],
-            axis=-1
-        )
-        if self._has_edge_feature:
-            feature = keras.ops.concatenate(
-                [
-                    feature,
-                    tensor.edge['feature']
-                ],
-                axis=-1
-            )
-        message = self.message_fn(feature)
-        return tensor.update(
-            {
-                'edge': {
-                    'message': message,
-                }
-            }
-        )
-
-    def aggregate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
-        aggregate = tensor.aggregate('message', mode='mean')
-        feature = tensor.node['feature']
-        if self.project_input_node_feature:
-            feature = self._previous_node_dense(feature)
-        return tensor.update(
-            {
-                'node': {
-                    'aggregate': aggregate,
-                    'feature': feature,
-                }
-            }
-        )
-
-    def update(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
-        updated_node_feature, _ = self.update_fn(
-            inputs=tensor.node['aggregate'],
-            states=tensor.node['feature']
-        )
-        return tensor.update(
-            {
-                'node': {
-                    'feature': updated_node_feature,
-                    'aggregate': None,
-                }
-            }
-        )
-
-    def get_config(self) -> dict:
-        config = super().get_config()
-        config.update({})
-        return config
-
-
-@keras.saving.register_keras_serializable(package='molcraft')
-class GTConv3D(GTConv):
-
-    """Graph transformer layer 3D.
-    """
-
-    def build(self, spec: tensors.GraphTensor.Spec) -> None:
-        """Builds the layer.
-        """
-        super().build(spec)
-        if self._add_bias:
-            node_feature_dim = spec.node['feature'].shape[-1]
-            kernels = self.units
-            self._gaussian_basis = GaussianDistance(kernels)
-            self._centrality_dense = self.get_dense(units=node_feature_dim)
-            self._gaussian_edge_bias = self.get_dense(self.heads)
-
-    def message(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
-        node_feature = tensor.node['feature']
-
-        if self._add_bias:
-            gaussian = self._gaussian_basis(tensor)
-            centrality = keras.ops.segment_sum(
-                gaussian, tensor.edge['target'], tensor.num_nodes
-            )
-            node_feature += self._centrality_dense(centrality)
-
-            edge_bias = self._edge_bias(tensor) + self._gaussian_edge_bias(gaussian)
-            tensor = tensor.update({'edge': {'bias': edge_bias}})
-
-        query = self._query_dense(node_feature)
-        key = self._key_dense(node_feature)
-        value = self._value_dense(node_feature)
-
-        query = ops.gather(query, tensor.edge['source'])
-        key = ops.gather(key, tensor.edge['target'])
-        value = ops.gather(value, tensor.edge['source'])
-
-        attention_score = keras.ops.sum(query * key, axis=1, keepdims=True)
-        attention_score /= keras.ops.sqrt(float(self.head_units))
-
-        attention_score += keras.ops.expand_dims(tensor.edge['bias'], axis=1)
-
-        attention = ops.edge_softmax(attention_score, tensor.edge['target'])
-        attention = self._softmax_dropout(attention)
-
-        distance = keras.ops.subtract(
-            tensor.gather('coordinate', 'source'),
-            tensor.gather('coordinate', 'target')
-        )
-        euclidean_distance = ops.euclidean_distance(
-            tensor.gather('coordinate', 'source'),
-            tensor.gather('coordinate', 'target'),
-            axis=-1
-        )
-        distance /= euclidean_distance
-
-        attention *= keras.ops.expand_dims(distance, axis=-1)
-        attention = keras.ops.expand_dims(attention, axis=2)
-        value = keras.ops.expand_dims(value, axis=1)
-
-        message = ops.edge_weight(value, attention)
-
-        return tensor.update(
-            {
-                'edge': {
-                    'message': message,
-                },
-            }
-        )
-
-    def aggregate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
-        node_feature = tensor.aggregate('message', mode='sum')
-        node_feature = keras.ops.reshape(
-            node_feature, (tensor.num_nodes, -1, self.units)
-        )
-        node_feature = self._output_dense(node_feature)
-        node_feature = keras.ops.sum(node_feature, axis=1)
-        return tensor.update(
-            {
-                'node': {
-                    'aggregate': node_feature,
-                },
-                'edge': {
-                    'message': None,
-                }
-            }
-        )
-
-
-@keras.saving.register_keras_serializable(package='molcraft')
-class MPConv3D(MPConv):
-
-    """Message passing neural network layer 3D.
-    """
-
-    def message(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
-        euclidean_distance = ops.euclidean_distance(
-            tensor.gather('coordinate', 'target'),
-            tensor.gather('coordinate', 'source'),
-            axis=-1
-        )
-        feature = keras.ops.concatenate(
-            [
-                tensor.gather('feature', 'source'),
-                tensor.gather('feature', 'target'),
-                euclidean_distance,
-            ],
-            axis=-1
-        )
-        if self._has_edge_feature:
-            feature = keras.ops.concatenate(
-                [
-                    feature,
-                    tensor.edge['feature']
-                ],
-                axis=-1
-            )
-        message = self.message_fn(feature)
-        return tensor.update(
-            {
-                'edge': {
-                    'message': message,
-                }
-            }
-        )
-
-
-@keras.saving.register_keras_serializable(package='molcraft')
-class EGConv3D(GraphConv):
-
-    """Equivariant graph neural network layer 3D.
-    """
-
-    def __init__(
-        self,
-        units: int = 128,
-        activation: keras.layers.Activation | str | None = 'silu',
-        use_bias: bool = True,
-        normalize: bool = False,
-        **kwargs
-    ) -> None:
-        super().__init__(
-            units=units,
-            activation=activation,
-            use_bias=use_bias,
-            normalize=normalize,
-            **kwargs
-        )
-
-    def build(self, spec: tensors.GraphTensor.Spec) -> None:
-        if 'coordinate' not in spec.node:
-            raise ValueError(
-                'Could not find `coordinate`s in node, '
-                'which is required for Conv3D layers.'
-            )
-        self._has_edge_feature = 'feature' in spec.edge
-        self._message_feedforward_intermediate = self.get_dense(
-            self.units, activation=self._activation
-        )
-        self._message_feedforward_final = self.get_dense(
-            self.units, activation=self._activation
-        )
+        self._message_feedforward_intermediate = self.get_dense(
+            self.units, activation=self.activation
+        )
+        self._message_feedforward_final = self.get_dense(
+            self.units, activation=self.activation
+        )

         self._coord_feedforward_intermediate = self.get_dense(
-            self.units, activation=self.
+            self.units, activation=self.activation
         )
         self._coord_feedforward_final = self.get_dense(
             1, use_bias=False, activation='tanh'
         )

-        has_overridden_update = self.__class__.update != EGConv3D.update
-        if not has_overridden_update:
-            self._feedforward_intermediate = self.get_dense(
-                self.units, activation=self._activation
-            )
-            self._feedforward_output = self.get_dense(self.units)
-
     def message(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
         relative_node_coordinate = keras.ops.subtract(
             tensor.gather('coordinate', 'target'),
@@ -1300,7 +1192,7 @@ class EGConv3D(GraphConv):
             ],
             axis=-1
         )
-        if self.
+        if self.has_edge_feature:
             feature = keras.ops.concatenate(
                 [
                     feature,
@@ -1339,7 +1231,7 @@ class EGConv3D(GraphConv):
         # graph to graph). Therefore, a mean mean aggregation is performed
         # instead:
         aggregate = tensor.aggregate('message', mode='mean')
-
+        aggregate = keras.ops.concatenate([aggregate, tensor.node['feature']], axis=-1)
         # Simply added to silence warning ('no gradients for variables ...')
         aggregate += (0.0 * keras.ops.sum(coordinate))
@@ -1355,26 +1247,6 @@ class EGConv3D(GraphConv):
             }
         }
     )
-
-    def update(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
-        feature = keras.ops.concatenate(
-            [
-                tensor.node['aggregate'],
-                tensor.node['feature']
-            ],
-            axis=-1
-        )
-        updated_node_feature = self._feedforward_output(
-            self._feedforward_intermediate(feature)
-        )
-        return tensor.update(
-            {
-                'node': {
-                    'feature': updated_node_feature,
-                    'aggregate': None,
-                },
-            }
-        )

     def get_config(self) -> dict:
         config = super().get_config()
@@ -1417,146 +1289,6 @@ class Readout(GraphLayer):
         config['mode'] = self.mode
         return config

-
-@keras.saving.register_keras_serializable(package='molcraft')
-class GraphNetwork(GraphLayer):
-
-    """Graph neural network.
-
-    Sequentially calls graph layers (`GraphLayer`) and concatenates its output.
-
-    Arguments:
-        layers (list):
-            A list of graph layers.
-    """
-
-    def __init__(self, layers: list[GraphLayer], **kwargs) -> None:
-        super().__init__(**kwargs)
-        self.layers = layers
-        self._update_edge_feature = False
-
-    def build(self, spec: tensors.GraphTensor.Spec) -> None:
-        units = self.layers[0].units
-        node_feature_dim = spec.node['feature'].shape[-1]
-        self._update_node_feature = node_feature_dim != units
-        if self._update_node_feature:
-            warn(
-                'Node feature dim does not match `units` of the first layer. '
-                'Automatically adding a node projection layer to match `units`.'
-            )
-            self._node_dense = self.get_dense(units)
-        self._has_edge_feature = 'feature' in spec.edge
-        if self._has_edge_feature:
-            edge_feature_dim = spec.edge['feature'].shape[-1]
-            self._update_edge_feature = edge_feature_dim != units
-            if self._update_edge_feature:
-                warn(
-                    'Edge feature dim does not match `units` of the first layer. '
-                    'Automatically adding a edge projection layer to match `units`.'
-                )
-                self._edge_dense = self.get_dense(units)
-
-    def propagate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
-        x = tensors.to_dict(tensor)
-        if self._update_node_feature:
-            x['node']['feature'] = self._node_dense(tensor.node['feature'])
-        if self._has_edge_feature and self._update_edge_feature:
-            x['edge']['feature'] = self._edge_dense(tensor.edge['feature'])
-        outputs = [x['node']['feature']]
-        for layer in self.layers:
-            x = layer(x)
-            outputs.append(x['node']['feature'])
-        return tensor.update(
-            {
-                'node': {
-                    'feature': keras.ops.concatenate(outputs, axis=-1)
-                }
-            }
-        )
-
-    def tape_propagate(
-        self,
-        tensor: tensors.GraphTensor,
-        tape: tf.GradientTape,
-        training: bool | None = None,
-    ) -> tuple[tensors.GraphTensor, list[tf.Tensor]]:
-        """Performs the propagation with a `GradientTape`.
-
-        Performs the same forward pass as `propagate` but with a `GradientTape`
-        watching intermediate node features.
-
-        Arguments:
-            tensor (tensors.GraphTensor):
-                The graph input.
-        """
-        if isinstance(tensor, tensors.GraphTensor):
-            x = tensors.to_dict(tensor)
-        else:
-            x = tensor
-        if self._update_node_feature:
-            x['node']['feature'] = self._node_dense(tensor.node['feature'])
-        if self._update_edge_feature:
-            x['edge']['feature'] = self._edge_dense(tensor.edge['feature'])
-        tape.watch(x['node']['feature'])
-        outputs = [x['node']['feature']]
-        for layer in self.layers:
-            x = layer(x, training=training)
-            tape.watch(x['node']['feature'])
-            outputs.append(x['node']['feature'])
-
-        tensor = tensor.update(
-            {
-                'node': {
-                    'feature': keras.ops.concatenate(outputs, axis=-1)
-                }
-            }
-        )
-        return tensor, outputs
-
-    def get_config(self) -> dict:
-        config = super().get_config()
-        config.update(
-            {
-                'layers': [
-                    keras.layers.serialize(layer) for layer in self.layers
-                ]
-            }
-        )
-        return config
-
-    @classmethod
-    def from_config(cls, config: dict) -> 'GraphNetwork':
-        config['layers'] = [
-            keras.layers.deserialize(layer) for layer in config['layers']
-        ]
-        return super().from_config(config)
-
-
-@keras.saving.register_keras_serializable(package='molcraft')
-class Extraction(GraphLayer):
-
-    def __init__(
-        self,
-        field: str,
-        inner_field: str | None = None,
-        **kwargs
-    ) -> None:
-        super().__init__(**kwargs)
-        self.field = field
-        self.inner_field = inner_field
-
-    def propagate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
-        data = dict(getattr(tensor, self.field))
-        if not self.inner_field:
-            return data
-        return data[self.inner_field]
-
-    def get_config(self):
-        config = super().get_config()
-        config['field'] = self.field
-        config['inner_field'] = self.inner_field
-        return config
-

 @keras.saving.register_keras_serializable(package='molcraft')
 class NodeEmbedding(GraphLayer):
@@ -1571,17 +1303,12 @@ class NodeEmbedding(GraphLayer):
         dim: int = None,
         normalize: bool = False,
         embed_context: bool = False,
-        allow_reconstruction: bool = False,
-        allow_masking: bool = False,
         **kwargs
     ) -> None:
         super().__init__(**kwargs)
         self.dim = dim
         self._normalize = normalize
         self._embed_context = embed_context
-        self._masking_rate = None
-        self._allow_masking = allow_masking
-        self._allow_reconstruction = allow_reconstruction

     def build(self, spec: tensors.GraphTensor.Spec) -> None:
         feature_dim = spec.node['feature'].shape[-1]
@@ -1595,47 +1322,31 @@ class NodeEmbedding(GraphLayer):
             self._embed_context = False
         if self._has_super and not self._embed_context:
             self._super_feature = self.get_weight(shape=[self.dim], name='super_node_feature')
-        if self._allow_masking:
-            self._mask_feature = self.get_weight(shape=[self.dim], name='mask_node_feature')
         if self._embed_context:
             self._context_dense = self.get_dense(self.dim)

-        if self._normalize:
-
-
-
-
-
-            self._norm = keras.layers.LayerNormalization(
-                name='output_layer_norm'
-            )
+        if not self._normalize:
+            self._norm = keras.layers.Identity()
+        elif str(self._normalize).lower().startswith('layer'):
+            self._norm = keras.layers.LayerNormalization()
+        else:
+            self._norm = keras.layers.BatchNormalization()

     def propagate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
         feature = self._node_dense(tensor.node['feature'])

-        if self._has_super:
-            super_feature = (0 if self._embed_context else self._super_feature)
+        if self._has_super and not self._embed_context:
             super_mask = keras.ops.expand_dims(tensor.node['super'], 1)
-            feature = keras.ops.where(super_mask,
+            feature = keras.ops.where(super_mask, self._super_feature, feature)

         if self._embed_context:
             context_feature = self._context_dense(tensor.context['feature'])
             feature = ops.scatter_update(feature, tensor.node['super'], context_feature)
             tensor = tensor.update({'context': {'feature': None}})

-
-        if apply_mask:
-            mask = keras.ops.expand_dims(tensor.node['mask'], -1)
-            feature = keras.ops.where(mask, self._mask_feature, feature)
-        elif self._allow_masking:
-            feature = feature + (self._mask_feature * 0.0)
+        feature = self._norm(feature)

-
-        feature = self._norm(feature)
-
-        if not self._allow_reconstruction:
-            return tensor.update({'node': {'feature': feature}})
-        return tensor.update({'node': {'feature': feature, 'target_feature': feature}})
+        return tensor.update({'node': {'feature': feature}})

     def get_config(self) -> dict:
         config = super().get_config()
@@ -1643,8 +1354,6 @@ class NodeEmbedding(GraphLayer):
             'dim': self.dim,
             'normalize': self._normalize,
             'embed_context': self._embed_context,
-            'allow_masking': self._allow_masking,
-            'allow_reconstruction': self._allow_reconstruction,
         })
         return config
@@ -1661,39 +1370,30 @@ class EdgeEmbedding(GraphLayer):
         self,
         dim: int = None,
         normalize: bool = False,
-        allow_masking: bool = True,
         **kwargs
     ) -> None:
         super().__init__(**kwargs)
         self.dim = dim
         self._normalize = normalize
-        self._masking_rate = None
-        self._allow_masking = allow_masking

     def build(self, spec: tensors.GraphTensor.Spec) -> None:
         feature_dim = spec.edge['feature'].shape[-1]
         if not self.dim:
             self.dim = feature_dim
-        self._edge_dense = self.get_dense(self.dim)
+        self._edge_dense = self.get_dense(self.dim)
+
+        self._self_loop_feature = self.get_weight(shape=[self.dim], name='self_loop_edge_feature')

         self._has_super = 'super' in spec.edge
-        self._has_self_loop = 'self_loop' in spec.edge
         if self._has_super:
             self._super_feature = self.get_weight(shape=[self.dim], name='super_edge_feature')
-
-
-
-
-
-
-
-        self._norm = keras.layers.BatchNormalization(
-            name='output_batch_norm'
-        )
-        else:
-            self._norm = keras.layers.LayerNormalization(
-                name='output_layer_norm'
-            )
+
+        if not self._normalize:
+            self._norm = keras.layers.Identity()
+        elif str(self._normalize).lower().startswith('layer'):
+            self._norm = keras.layers.LayerNormalization()
+        else:
+            self._norm = keras.layers.BatchNormalization()

     def propagate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
         feature = self._edge_dense(tensor.edge['feature'])
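The same normalize-resolution pattern (`False` -> `Identity`, a string starting with `'layer'` -> `LayerNormalization`, any other truthy value -> `BatchNormalization`) now appears in `GraphConv`, `NodeEmbedding`, and `EdgeEmbedding`. As a standalone helper it reads (an illustrative refactoring, not molcraft code):

    import keras

    def resolve_norm(normalize: bool | str) -> keras.layers.Layer:
        if not normalize:
            return keras.layers.Identity()
        if str(normalize).lower().startswith('layer'):
            return keras.layers.LayerNormalization()
        return keras.layers.BatchNormalization()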
@@ -1702,246 +1402,136 @@ class EdgeEmbedding(GraphLayer):
             super_mask = keras.ops.expand_dims(tensor.edge['super'], 1)
             feature = keras.ops.where(super_mask, self._super_feature, feature)
 
-
-
-            feature = keras.ops.where(self_loop_mask, self._self_loop_feature, feature)
-
-        if (
-            self._allow_masking and
-            self._masking_rate is not None and
-            self._masking_rate > 0
-        ):
-            random = keras.random.uniform(shape=[tensor.num_edges])
-            mask = random <= self._masking_rate
-            if self._has_super:
-                mask = keras.ops.logical_and(
-                    mask, keras.ops.logical_not(tensor.edge['super'])
-                )
-            mask = keras.ops.expand_dims(mask, -1)
-            feature = keras.ops.where(mask, self._mask_feature, feature)
-        elif self._allow_masking:
-            # Simply added to silence warning ('no gradients for variables ...')
-            feature += (0.0 * self._mask_feature)
+        self_loop_mask = keras.ops.expand_dims(tensor.edge['source'] == tensor.edge['target'], 1)
+        feature = keras.ops.where(self_loop_mask, self._self_loop_feature, feature)
 
-
-        feature = self._norm(feature)
+        feature = self._norm(feature)
 
-        return tensor.update({'edge': {'feature': feature}})
-
-    @property
-    def masking_rate(self):
-        return self._masking_rate
-
-    @masking_rate.setter
-    def masking_rate(self, rate: float):
-        if not self._allow_masking and rate is not None:
-            raise ValueError(
-                f'Cannot set `masking_rate` for layer {self} '
-                'as `allow_masking` was set to `False`.'
-            )
-        self._masking_rate = float(rate)
+        return tensor.update({'edge': {'feature': feature}})
 
     def get_config(self) -> dict:
         config = super().get_config()
         config.update({
             'dim': self.dim,
             'normalize': self._normalize,
-            'allow_masking': self._allow_masking
         })
         return config
 
 
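Self-loop edges are now detected on the fly by comparing source and target indices, replacing the precomputed 'self_loop' edge field, and the random edge-masking path (with its `masking_rate` property and the `allow_masking` config entry) is gone. A small sketch of the new mask, assuming integer index arrays like `tensor.edge['source']` and `tensor.edge['target']` (the literal values here are made up):

    import keras

    source = keras.ops.convert_to_tensor([0, 1, 2, 2])
    target = keras.ops.convert_to_tensor([1, 1, 0, 2])

    # Shape [num_edges, 1], so it broadcasts against [num_edges, dim]
    # features inside keras.ops.where(...); edges 1 and 3 are self-loops.
    self_loop_mask = keras.ops.expand_dims(source == target, 1)
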
 @keras.saving.register_keras_serializable(package='molcraft')
-class Projection(GraphLayer):
-    """Base graph projection layer.
-    """
-    def __init__(
-        self,
-        units: int = None,
-        activation: str | keras.layers.Activation | None = None,
-        use_bias: bool = True,
-        field: str = 'node',
-        **kwargs
-    ) -> None:
-        super().__init__(use_bias=use_bias, **kwargs)
-        self.units = units
-        self._activation = keras.activations.get(activation)
-        self.field = field
-
-    def build(self, spec: tensors.GraphTensor.Spec) -> None:
-        data = getattr(spec, self.field, None)
-        if data is None:
-            raise ValueError('Could not access field {self.field!r}.')
-        feature_dim = data['feature'].shape[-1]
-        if not self.units:
-            self.units = feature_dim
-        self._dense = self.get_dense(self.units)
-
-    def propagate(self, tensor: tensors.GraphTensor):
-        feature = getattr(tensor, self.field)['feature']
-        feature = self._dense(feature)
-        feature = self._activation(feature)
-        return tensor.update(
-            {
-                self.field: {
-                    'feature': feature
-                }
-            }
-        )
-
-    def get_config(self) -> dict:
-        config = super().get_config()
-        config.update({
-            'units': self.units,
-            'activation': keras.activations.serialize(self._activation),
-            'field': self.field,
-        })
-        return config
-
-
-@keras.saving.register_keras_serializable(package='molcraft')
-class ContextProjection(Projection):
-    """Context projection layer.
-    """
-    def __init__(self, units: int = None, activation: str = None, **kwargs):
-        kwargs['field'] = 'context'
-        super().__init__(units=units, activation=activation, **kwargs)
-
+class GraphNetwork(GraphLayer):
 
-
-class NodeProjection(Projection):
-    """Node projection layer.
-    """
-    def __init__(self, units: int = None, activation: str = None, **kwargs):
-        kwargs['field'] = 'node'
-        super().__init__(units=units, activation=activation, **kwargs)
+    """Graph neural network.
 
+    Sequentially calls graph layers (`GraphLayer`) and concatenates its output.
 
-
-
-
+    Arguments:
+        layers (list):
+            A list of graph layers.
     """
-    def __init__(self, units: int = None, activation: str = None, **kwargs):
-        kwargs['field'] = 'edge'
-        super().__init__(units=units, activation=activation, **kwargs)
-
-
-@keras.saving.register_keras_serializable(package='molcraft')
-class Reconstruction(GraphLayer):
-
-    def __init__(
-        self,
-        loss: keras.losses.Loss | str = 'mse',
-        loss_weight: float = 0.5,
-        **kwargs
-    ):
-        super().__init__(**kwargs)
-        self._loss_fn = keras.losses.get(loss)
-        self._loss_weight = loss_weight
-
-    def build(self, spec: tensors.GraphTensor.Spec) -> None:
-        has_target_node_feature = 'target_feature' in spec.node
-        if not has_target_node_feature:
-            raise ValueError(
-                'Could not find `target_feature` in `spec.node`. '
-                'Add a `target_feature` via `NodeEmbedding` by setting '
-                '`allow_reconstruction` to `True`.'
-            )
-        output_dim = spec.node['target_feature'].shape[-1]
-        self._dense = self.get_dense(output_dim)
-
-    def propagate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
-        target_node_feature = tensor.node['target_feature']
-        transformed_node_feature = tensor.node['feature']
-
-        reconstructed_node_feature = self._dense(
-            transformed_node_feature
-        )
-
-        loss = self._loss_fn(
-            target_node_feature, reconstructed_node_feature
-        )
-        self.add_loss(keras.ops.sum(loss) * self._loss_weight)
-        return tensor.update({'node': {'feature': transformed_node_feature}})
-
-    def get_config(self):
-        config = super().get_config()
-        config['loss'] = keras.losses.serialize(self._loss_fn)
-        config['loss_weight'] = self._loss_weight
-        return config
-
-
-@keras.saving.register_keras_serializable(package='molcraft')
-class EdgeBias(GraphLayer):
 
-    def __init__(self,
+    def __init__(self, layers: list[GraphLayer], **kwargs) -> None:
         super().__init__(**kwargs)
-        self.
+        self.layers = layers
+        self._update_edge_feature = False
 
     def build(self, spec: tensors.GraphTensor.Spec) -> None:
-
-
-
-
-
-
-
+        units = self.layers[0].units
+        node_feature_dim = spec.node['feature'].shape[-1]
+        self._update_node_feature = node_feature_dim != units
+        if self._update_node_feature:
+            warnings.warn(
+                'Node feature dim does not match `units` of the first layer. '
+                'Automatically adding a node projection layer to match `units`.',
+                stacklevel=2
             )
+            self._node_dense = self.get_dense(units)
+        self._has_edge_feature = 'feature' in spec.edge
+        if self._has_edge_feature:
+            edge_feature_dim = spec.edge['feature'].shape[-1]
+            self._update_edge_feature = edge_feature_dim != units
+            if self._update_edge_feature:
+                warnings.warn(
+                    'Edge feature dim does not match `units` of the first layer. '
+                    'Automatically adding a edge projection layer to match `units`.',
+                    stacklevel=2
+                )
+                self._edge_dense = self.get_dense(units)
 
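The removed `Projection` layers are effectively absorbed by the new `GraphNetwork`: when the incoming node (or edge) feature dimension differs from the first layer's `units`, `build` warns and inserts a dense projection automatically. A minimal usage sketch, assuming `GTConv` (aliased as `GraphTransformer` at the bottom of this module) accepts `units` like the other convolution layers:

    from molcraft import layers

    # Hypothetical stack; any GraphLayer subclasses sharing `units` work.
    network = layers.GraphNetwork([
        layers.GTConv(units=128),
        layers.GTConv(units=128),
        layers.GTConv(units=128),
    ])
    # If the featurized nodes carry, say, 64-dim features, build() warns
    # and projects them to 128 before the first message-passing step.
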
     def propagate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
-
-
-
+        x = tensors.to_dict(tensor)
+        if self._update_node_feature:
+            x['node']['feature'] = self._node_dense(tensor.node['feature'])
+        if self._has_edge_feature and self._update_edge_feature:
+            x['edge']['feature'] = self._edge_dense(tensor.edge['feature'])
+        outputs = [x['node']['feature']]
+        for layer in self.layers:
+            x = layer(x)
+            outputs.append(x['node']['feature'])
+        return tensor.update(
+            {
+                'node': {
+                    'feature': keras.ops.concatenate(outputs, axis=-1)
+                }
+            }
         )
-        if self._has_edge_feature:
-            bias += self._edge_feature_dense(tensor.edge['feature'])
-        if self._has_edge_length:
-            bias += self._edge_length_dense(tensor.edge['length'])
-        return bias
-
-    def get_config(self) -> dict:
-        config = super().get_config()
-        config.update({'biases': self.biases})
-        return config
 
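`propagate` keeps the (possibly projected) input node features plus every layer's output and concatenates them along the feature axis, a jumping-knowledge-style readout. Assuming each layer emits `units`-dimensional node features, the output width is easy to check:

    # 3 layers of units=128 plus the projected input -> (3 + 1) * 128 = 512.
    num_layers, units = 3, 128
    assert (num_layers + 1) * units == 512
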
+    def tape_propagate(
+        self,
+        tensor: tensors.GraphTensor,
+        tape: tf.GradientTape,
+        training: bool | None = None,
+    ) -> tuple[tensors.GraphTensor, list[tf.Tensor]]:
+        """Performs the propagation with a `GradientTape`.
 
-
-
-
-    def __init__(self, kernels: int, **kwargs):
-        super().__init__(**kwargs)
-        self.kernels = kernels
+        Performs the same forward pass as `propagate` but with a `GradientTape`
+        watching intermediate node features.
 
-
-
-
-
-
-
-
-
-
-
-
-
-        )
+        Arguments:
+            tensor (tensors.GraphTensor):
+                The graph input.
+        """
+        if isinstance(tensor, tensors.GraphTensor):
+            x = tensors.to_dict(tensor)
+        else:
+            x = tensor
+        if self._update_node_feature:
+            x['node']['feature'] = self._node_dense(tensor.node['feature'])
+        if self._update_edge_feature:
+            x['edge']['feature'] = self._edge_dense(tensor.edge['feature'])
+        tape.watch(x['node']['feature'])
+        outputs = [x['node']['feature']]
+        for layer in self.layers:
+            x = layer(x, training=training)
+            tape.watch(x['node']['feature'])
+            outputs.append(x['node']['feature'])
 
-
-
-
-
-
-
-        return ops.gaussian(
-            euclidean_distance, self._loc, self._scale
+        tensor = tensor.update(
+            {
+                'node': {
+                    'feature': keras.ops.concatenate(outputs, axis=-1)
+                }
+            }
         )
-
+        return tensor, outputs
+
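`tape_propagate` runs the same forward pass while registering every intermediate node-feature tensor with a `tf.GradientTape`, which makes gradient-based attribution straightforward. A sketch assuming a TensorFlow backend; `network`, `graph_tensor`, and `readout` are hypothetical names:

    import tensorflow as tf

    with tf.GradientTape(persistent=True) as tape:
        graph, node_states = network.tape_propagate(graph_tensor, tape)
        prediction = readout(graph)  # hypothetical scalar readout

    # One gradient per watched intermediate node-feature tensor.
    gradients = [tape.gradient(prediction, states) for states in node_states]
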
     def get_config(self) -> dict:
         config = super().get_config()
-        config.update(
-
-
+        config.update(
+            {
+                'layers': [
+                    keras.layers.serialize(layer) for layer in self.layers
+                ]
+            }
+        )
         return config
+
+    @classmethod
+    def from_config(cls, config: dict) -> 'GraphNetwork':
+        config['layers'] = [
+            keras.layers.deserialize(layer) for layer in config['layers']
+        ]
+        return super().from_config(config)
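Because `layers` holds Keras layer instances, `get_config` serializes each one, and the matching `from_config` override deserializes them before the base constructor runs; without it, the saved `config['layers']` would be restored as plain dicts. A round-trip sketch, reusing the hypothetical `network` from the earlier sketch:

    import keras

    clone = keras.layers.deserialize(keras.layers.serialize(network))
    assert isinstance(clone, layers.GraphNetwork)
    assert len(clone.layers) == len(network.layers)
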
 
 
 @keras.saving.register_keras_serializable(package='molcraft')
@@ -1992,7 +1582,7 @@ class GaussianParams(keras.layers.Dense):
         config['units'] = None
         config['activation'] = keras.activations.serialize(self.loc_activation)
         return config
-
+
 
 def Input(spec: tensors.GraphTensor.Spec) -> dict:
     """Used to specify inputs to model.
@@ -2047,13 +1637,6 @@ def Input(spec: tensors.GraphTensor.Spec) -> dict:
     return inputs
 
 
-def warn(message: str) -> None:
-    warnings.warn(
-        message=message,
-        category=UserWarning,
-        stacklevel=1
-    )
-
 def _serialize_spec(spec: tensors.GraphTensor.Spec) -> dict:
     serialized_spec = {}
     for outer_field, data in spec.__dict__.items():
@@ -2095,5 +1678,3 @@ def _spec_from_inputs(inputs):
 
 
 GraphTransformer = GTConv
-GraphTransformer3D = GTConv3D
-