molcraft 0.1.0a6__py3-none-any.whl → 0.1.0a8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of molcraft might be problematic.
- molcraft/__init__.py +1 -1
- molcraft/callbacks.py +67 -0
- molcraft/chem.py +45 -30
- molcraft/conformers.py +0 -4
- molcraft/features.py +3 -9
- molcraft/featurizers.py +18 -26
- molcraft/layers.py +466 -801
- molcraft/models.py +16 -1
- molcraft/ops.py +14 -3
- {molcraft-0.1.0a6.dist-info → molcraft-0.1.0a8.dist-info}/METADATA +2 -2
- molcraft-0.1.0a8.dist-info/RECORD +19 -0
- {molcraft-0.1.0a6.dist-info → molcraft-0.1.0a8.dist-info}/WHEEL +1 -1
- molcraft-0.1.0a6.dist-info/RECORD +0 -19
- {molcraft-0.1.0a6.dist-info → molcraft-0.1.0a8.dist-info}/licenses/LICENSE +0 -0
- {molcraft-0.1.0a6.dist-info → molcraft-0.1.0a8.dist-info}/top_level.txt +0 -0
molcraft/layers.py
CHANGED
@@ -274,20 +274,14 @@ class GraphConv(GraphLayer):
         units (int):
             Dimensionality of the output space.
         activation (keras.layers.Activation, str, None):
-            Activation function to
-            Default to `
+            Activation function to be accessed via `self.activation`, and used for the
+            `message()` and `update()` methods, if not overriden. Default to `relu`.
         use_bias (bool):
-            Whether bias should be used in dense layers. Default to `True`.
+            Whether bias should be used in the dense layers. Default to `True`.
         normalize (bool, str):
-            Whether
-
-
-            Whether node feature input should be added to the node feature output.
-            If node feature input dim is not equal to `units` (node feature output dim),
-            a projection layer will automatically project the residual before adding it
-            to the output. To use weighted skip connection,
-            specify `weighted`. The weight multiplied with the skip connection is a
-            learnable scalar. Default to `True`.
+            Whether normalization should be applied to the final output. Default to `False`.
+        skip_connect (bool):
+            Whether node feature input should be added to the node feature output. Default to `True`.
         kernel_initializer (keras.initializers.Initializer, str):
             Initializer for the kernel weight matrix of the dense layers.
             Default to `glorot_uniform`.
@@ -314,10 +308,10 @@ class GraphConv(GraphLayer):
     def __init__(
         self,
         units: int = None,
-        activation: str | keras.layers.Activation | None =
+        activation: str | keras.layers.Activation | None = 'relu',
         use_bias: bool = True,
-        normalize: bool
-        skip_connect: bool
+        normalize: bool = False,
+        skip_connect: bool = True,
         **kwargs
     ) -> None:
        super().__init__(use_bias=use_bias, **kwargs)
@@ -341,56 +335,56 @@ class GraphConv(GraphLayer):
     def units(self):
         return self._units
 
+    @property
+    def activation(self):
+        return self._activation
+
     def build(self, spec: tensors.GraphTensor.Spec) -> None:
         if not self.units:
             raise ValueError(
                 f'`self.units` needs to be a positive integer. Found: {self.units}.'
             )
         node_feature_dim = spec.node['feature'].shape[-1]
-        self.
+        self._project_residual = (
             self._skip_connect and (node_feature_dim != self.units)
         )
-        if self.
-            warn(
+        if self._project_residual:
+            warnings.warn(
                 '`skip_connect` is set to `True`, but found incompatible dim '
                 'between input (node feature dim) and output (`self.units`). '
                 'Automatically applying a projection layer to residual to '
-                'match input and output. '
+                'match input and output. ',
+                stacklevel=2,
             )
-            self.
-                self.units, name='
+            self._residual_dense = self.get_dense(
+                self.units, name='residual_dense'
             )
 
-
-        self.
-        if self._use_weighted_skip_connection:
-            self._skip_connection_weight = self.add_weight(
-                name='skip_connection_weight',
-                shape=(),
-                initializer='ones',
-                trainable=True,
-            )
-
-        if self._normalize:
-            if str(self._normalize).lower().startswith('batch'):
-                self._output_norm = keras.layers.BatchNormalization(
-                    name='output_batch_norm'
-                )
-            else:
-                self._output_norm = keras.layers.LayerNormalization(
-                    name='output_layer_norm'
-                )
-
-        self._has_edge_feature = 'feature' in spec.edge
+        self.has_edge_feature = 'feature' in spec.edge
+        self.has_node_coordinate = 'coordinate' in spec.node
 
         has_overridden_message = self.__class__.message != GraphConv.message
         if not has_overridden_message:
-            self.
+            self._message_intermediate_dense = self.get_dense(self.units)
+            self._message_intermediate_activation = self.activation
+            self._message_final_dense = self.get_dense(self.units)
+
+        has_overridden_aggregate = self.__class__.message != GraphConv.aggregate
+        if not has_overridden_aggregate:
+            pass
 
         has_overridden_update = self.__class__.update != GraphConv.update
         if not has_overridden_update:
-            self.
-            self.
+            self._update_intermediate_dense = self.get_dense(self.units)
+            self._update_intermediate_activation = self.activation
+            self._update_final_dense = self.get_dense(self.units)
+
+        if not self._normalize:
+            self._normalization = keras.layers.Identity()
+        elif str(self._normalize).lower().startswith('layer'):
+            self._normalization = keras.layers.LayerNormalization()
+        else:
+            self._normalization = keras.layers.BatchNormalization()
 
     def propagate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
         """Forward pass.
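The new `build` wires up default message/update sublayers only when the subclass has not overridden the corresponding method, using plain function-identity comparison. (Note that the added `has_overridden_aggregate` check compares `message` against `GraphConv.aggregate`, so it always evaluates to `True`; its branch body is `pass` in any case.) A minimal standalone sketch of the override-detection pattern, with illustrative class names not taken from molcraft:

    class Base:
        def message(self):
            return 'base message'

    class Child(Base):
        def message(self):
            return 'child message'

    # An overriding subclass rebinds the attribute, so identity comparison detects it.
    print(Child.message != Base.message)  # True: Child overrides message()
    print(Base.message != Base.message)   # False: no override
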
@@ -402,10 +396,10 @@ class GraphConv(GraphLayer):
             A `GraphTensor` instance.
         """
         if self._skip_connect:
-
-            if self.
-
-
+            residual = tensor.node['feature']
+            if self._project_residual:
+                residual = self._residual_dense(residual)
+
         message = self.message(tensor)
         if not isinstance(message, tensors.GraphTensor):
             message = tensor.update({'edge': {'message': message}})
@@ -417,24 +411,24 @@ class GraphConv(GraphLayer):
             aggregate = tensor.update({'node': {'aggregate': aggregate}})
         elif not 'aggregate' in aggregate.node:
             raise ValueError('Could not find `aggregate` in `node` output.')
-
+
         update = self.update(aggregate)
         if not isinstance(update, tensors.GraphTensor):
             update = tensor.update({'node': {'feature': update}})
         elif not 'feature' in update.node:
             raise ValueError('Could not find `feature` in `node` output.')
-
-
+
+        if update.node['feature'].shape[-1] != self.units:
+            raise ValueError('Updated node `feature` is not equal to `self.units`.')
+
+        feature = update.node['feature']
 
         if self._skip_connect:
-
-
-
-
-            if self._normalize:
-                updated_node_feature = self._output_norm(updated_node_feature)
+            feature += residual
+
+        feature = self._normalization(feature)
 
-        return update.update({'node': {'feature':
+        return update.update({'node': {'feature': feature}})
 
     def message(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
         """Compute messages.
@@ -445,24 +439,38 @@ class GraphConv(GraphLayer):
         tensor:
             The inputted `GraphTensor` instance.
         """
-
-
-
-
+        message = keras.ops.concatenate(
+            [
+                tensor.gather('feature', 'source'),
+                tensor.gather('feature', 'target'),
+            ],
+            axis=-1
+        )
+        if self.has_edge_feature:
+            message = keras.ops.concatenate(
                 [
-
+                    message,
                     tensor.edge['feature']
                 ],
                 axis=-1
             )
-
-
-
-            '
-
-
-
-
+        if self.has_node_coordinate:
+            euclidean_distance = ops.euclidean_distance(
+                tensor.gather('coordinate', 'target'),
+                tensor.gather('coordinate', 'source'),
+                axis=-1
+            )
+            message = keras.ops.concatenate(
+                [
+                    message,
+                    euclidean_distance
+                ],
+                axis=-1
+            )
+        message = self._message_intermediate_dense(message)
+        message = self._message_intermediate_activation(message)
+        message = self._message_final_dense(message)
+        return tensor.update({'edge': {'message': message}})
 
     def aggregate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
         """Aggregates messages.
@@ -473,14 +481,16 @@ class GraphConv(GraphLayer):
         tensor:
             A `GraphTensor` instance containing a message.
         """
+        previous = tensor.node['feature']
         aggregate = tensor.aggregate('message', mode='mean')
+        aggregate = keras.ops.concatenate([aggregate, previous], axis=-1)
         return tensor.update(
             {
                 'node': {
-                    'aggregate': aggregate,
+                    'aggregate': aggregate,
                 },
                 'edge': {
-                    'message': None
+                    'message': None,
                 }
             }
         )
@@ -495,21 +505,16 @@ class GraphConv(GraphLayer):
         A `GraphTensor` instance containing aggregated messages
         (updated node features).
         """
-
-
-
-
-            ],
-            axis=-1
-        )
-        update = self._output_dense(feature)
-        update = self._output_activation(update)
+        aggregate = tensor.node['aggregate']
+        node_feature = self._update_intermediate_dense(aggregate)
+        node_feature = self._update_intermediate_activation(node_feature)
+        node_feature = self._update_final_dense(node_feature)
         return tensor.update(
             {
                 'node': {
-                    'feature':
+                    'feature': node_feature,
                     'aggregate': None,
-                }
+                },
             }
         )
 
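With no overrides, both `message()` and `update()` now funnel features through the same two-layer pattern: dense, activation, dense. A minimal standalone sketch of that block in plain Keras; layer sizes and variable names here are illustrative, not molcraft's:

    import keras
    import numpy as np

    units = 8
    intermediate = keras.layers.Dense(units)
    activation = keras.layers.Activation('relu')
    final = keras.layers.Dense(units)

    aggregate = np.random.rand(4, 8).astype('float32')  # aggregated messages, one row per node
    node_feature = final(activation(intermediate(aggregate)))
    print(node_feature.shape)  # (4, 8)
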
@@ -563,14 +568,16 @@ class GIConv(GraphConv):
         activation: keras.layers.Activation | str | None = 'relu',
         use_bias: bool = True,
         normalize: bool = False,
+        skip_connect: bool = True,
         update_edge_feature: bool = True,
         **kwargs,
     ):
         super().__init__(
             units=units,
             activation=activation,
-            normalize=normalize,
             use_bias=use_bias,
+            normalize=normalize,
+            skip_connect=skip_connect,
             **kwargs
         )
         self._update_edge_feature = update_edge_feature
@@ -585,16 +592,16 @@ class GIConv(GraphConv):
             trainable=True,
         )
 
-        self.
-        if self._has_edge_feature:
+        if self.has_edge_feature:
             edge_feature_dim = spec.edge['feature'].shape[-1]
 
             if not self._update_edge_feature:
                 if (edge_feature_dim != node_feature_dim):
-                    warn(
+                    warnings.warn(
                         'Found edge feature dim to be incompatible with node feature dim. '
                         'Automatically adding a edge feature projection layer to match '
-                        'the dim of node features.'
+                        'the dim of node features.',
+                        stacklevel=2,
                     )
                     self._update_edge_feature = True
@@ -603,19 +610,14 @@ class GIConv(GraphConv):
         else:
             self._update_edge_feature = False
 
-        has_overridden_update = self.__class__.update != GIConv.update
-        if not has_overridden_update:
-            self._feedforward_intermediate_dense = self.get_dense(self.units)
-            self._feedforward_activation = self._activation
-            self._feedforward_output_dense = self.get_dense(self.units)
-
     def message(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
         message = tensor.gather('feature', 'source')
         edge_feature = tensor.edge.get('feature')
         if self._update_edge_feature:
             edge_feature = self._edge_dense(edge_feature)
-        if self.
+        if self.has_edge_feature:
             message += edge_feature
+        message = keras.ops.relu(message)
         return tensor.update(
             {
                 'edge': {
@@ -639,20 +641,6 @@ class GIConv(GraphConv):
             }
         )
 
-    def update(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
-        node_feature = tensor.node['aggregate']
-        node_feature = self._feedforward_intermediate_dense(node_feature)
-        node_feature = self._feedforward_activation(node_feature)
-        node_feature = self._feedforward_output_dense(node_feature)
-        return tensor.update(
-            {
-                'node': {
-                    'feature': node_feature,
-                    'aggregate': None,
-                }
-            }
-        )
-
     def get_config(self) -> dict:
         config = super().get_config()
         config.update({
@@ -701,15 +689,16 @@ class GAConv(GraphConv):
         activation: keras.layers.Activation | str | None = "relu",
         use_bias: bool = True,
         normalize: bool = False,
+        skip_connect: bool = True,
         update_edge_feature: bool = True,
-        attention_activation: keras.layers.Activation | str | None = "leaky_relu",
         **kwargs,
     ) -> None:
         super().__init__(
             units=units,
             activation=activation,
-            use_bias=use_bias,
             normalize=normalize,
+            use_bias=use_bias,
+            skip_connect=skip_connect,
             **kwargs
         )
         self._heads = heads
@@ -717,7 +706,6 @@ class GAConv(GraphConv):
             raise ValueError(f"units need to be divisible by heads.")
         self._head_units = self.units // self.heads
         self._update_edge_feature = update_edge_feature
-        self._attention_activation = keras.activations.get(attention_activation)
 
     @property
     def heads(self):
@@ -728,8 +716,7 @@ class GAConv(GraphConv):
         return self._head_units
 
     def build(self, spec: tensors.GraphTensor.Spec) -> None:
-        self.
-        self._update_edge_feature = self._has_edge_feature and self._update_edge_feature
+        self._update_edge_feature = self.has_edge_feature and self._update_edge_feature
         if self._update_edge_feature:
             self._edge_dense = self.get_einsum_dense(
                 'ijh,jkh->ikh', (self.head_units, self.heads)
@@ -743,15 +730,6 @@ class GAConv(GraphConv):
         self._attention_dense = self.get_einsum_dense(
             'ijh,jkh->ikh', (1, self.heads)
         )
-        self._node_self_dense = self.get_einsum_dense(
-            'ij,jkh->ikh', (self.head_units, self.heads)
-        )
-
-        has_overridden_update = self.__class__.update != GAConv.update
-        if not has_overridden_update:
-            self._feedforward_intermediate_dense = self.get_dense(self.units)
-            self._feedforward_activation = self._activation
-            self._feedforward_output_dense = self.get_dense(self.units)
 
     def message(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
         attention_feature = keras.ops.concatenate(
@@ -761,7 +739,7 @@ class GAConv(GraphConv):
             ],
             axis=-1
         )
-        if self.
+        if self.has_edge_feature:
             attention_feature = keras.ops.concatenate(
                 [
                     attention_feature,
@@ -778,7 +756,7 @@ class GAConv(GraphConv):
             edge_feature = self._edge_dense(attention_feature)
             edge_feature = keras.ops.reshape(edge_feature, (-1, self.units))
 
-        attention_feature =
+        attention_feature = keras.ops.leaky_relu(attention_feature)
         attention_score = self._attention_dense(attention_feature)
         attention_score = ops.edge_softmax(
             score=attention_score, edge_target=tensor.edge['target']
@@ -797,7 +775,6 @@ class GAConv(GraphConv):
 
     def aggregate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
         node_feature = tensor.aggregate('message', mode='sum')
-        node_feature += self._node_self_dense(tensor.node['feature'])
         node_feature = keras.ops.reshape(node_feature, (-1, self.units))
         return tensor.update(
             {
@@ -810,28 +787,123 @@ class GAConv(GraphConv):
             }
         )
 
+    def get_config(self) -> dict:
+        config = super().get_config()
+        config.update({
+            "heads": self._heads,
+            'update_edge_feature': self._update_edge_feature,
+        })
+        return config
+
+
+@keras.saving.register_keras_serializable(package='molcraft')
+class MPConv(GraphConv):
+
+    """Message passing neural network layer.
+
+    Also supports 3D molecular graphs.
+
+    >>> graph = molcraft.tensors.GraphTensor(
+    ...     context={
+    ...         'size': [2]
+    ...     },
+    ...     node={
+    ...         'feature': [[1.], [2.]]
+    ...     },
+    ...     edge={
+    ...         'source': [0, 1],
+    ...         'target': [1, 0],
+    ...     }
+    ... )
+    >>> conv = molcraft.layers.MPConv(units=4)
+    >>> conv(graph)
+    GraphTensor(
+        context={
+            'size': <tf.Tensor: shape=[1], dtype=int32>
+        },
+        node={
+            'feature': <tf.Tensor: shape=[2, 4], dtype=float32>
+        },
+        edge={
+            'source': <tf.Tensor: shape=[2], dtype=int32>,
+            'target': <tf.Tensor: shape=[2], dtype=int32>
+        }
+    )
+    """
+
+    def __init__(
+        self,
+        units: int = 128,
+        activation: keras.layers.Activation | str | None = 'relu',
+        use_bias: bool = True,
+        normalize: bool = False,
+        skip_connect: bool = True,
+        **kwargs
+    ) -> None:
+        super().__init__(
+            units=units,
+            activation=activation,
+            use_bias=use_bias,
+            normalize=normalize,
+            skip_connect=skip_connect,
+            **kwargs
+        )
+
+    def build(self, spec: tensors.GraphTensor.Spec) -> None:
+        node_feature_dim = spec.node['feature'].shape[-1]
+        self.update_fn = keras.layers.GRUCell(self.units)
+        self._project_previous_node_feature = node_feature_dim != self.units
+        if self._project_previous_node_feature:
+            warnings.warn(
+                'Input node feature dim does not match updated node feature dim. '
+                'To make sure input node feature can be passed as `states` to the '
+                'GRU cell, it will automatically be projected prior to it.',
+                stacklevel=2
+            )
+            self._previous_node_dense = self.get_dense(self.units)
+
+    def aggregate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
+        """Aggregates messages.
+
+        This method may be overridden by subclass.
+
+        Arguments:
+            tensor:
+                A `GraphTensor` instance containing a message.
+        """
+        aggregate = tensor.aggregate('message', mode='mean')
+        return tensor.update(
+            {
+                'node': {
+                    'aggregate': aggregate,
+                },
+                'edge': {
+                    'message': None,
+                }
+            }
+        )
+
     def update(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
-
-
-
-
+        previous = tensor.node['feature']
+        aggregate = tensor.node['aggregate']
+        if self._project_previous_node_feature:
+            previous = self._previous_node_dense(previous)
+        updated_node_feature, _ = self.update_fn(
+            inputs=aggregate, states=previous
+        )
         return tensor.update(
             {
                 'node': {
-                    'feature':
+                    'feature': updated_node_feature,
                     'aggregate': None,
                 }
             }
         )
-
+
     def get_config(self) -> dict:
         config = super().get_config()
-        config.update({
-
-            'update_edge_feature': self._update_edge_feature,
-            'attention_activation': keras.activations.serialize(self._attention_activation),
-        })
-        return config
+        config.update({})
+        return config
 
 
 @keras.saving.register_keras_serializable(package='molcraft')
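The new `MPConv.update` treats the aggregated messages as the GRU input and the previous node features as the recurrent state. A minimal standalone sketch of that update step in plain Keras; shapes and variable names are illustrative:

    import keras
    import numpy as np

    units = 4
    cell = keras.layers.GRUCell(units)

    aggregate = np.random.rand(2, units).astype('float32')  # aggregated messages per node
    previous = np.random.rand(2, units).astype('float32')   # previous node features
    # GRUCell returns (output, new_states); the output becomes the updated node feature.
    updated, _ = cell(aggregate, [previous])
    print(updated.shape)  # (2, 4)
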
@@ -839,6 +911,8 @@ class GTConv(GraphConv):
 
     """Graph transformer layer.
 
+    Also supports 3D molecular graphs.
+
     >>> graph = molcraft.tensors.GraphTensor(
     ...     context={
     ...         'size': [2]
@@ -862,10 +936,9 @@ class GTConv(GraphConv):
         },
         edge={
             'source': <tf.Tensor: shape=[2], dtype=int32>,
-            'target': <tf.Tensor: shape=[2], dtype=int32
+            'target': <tf.Tensor: shape=[2], dtype=int32>,
         }
     )
-
     """
 
     def __init__(
@@ -875,14 +948,16 @@ class GTConv(GraphConv):
         activation: keras.layers.Activation | str | None = "relu",
         use_bias: bool = True,
         normalize: bool = False,
+        skip_connect: bool = True,
         attention_dropout: float = 0.0,
         **kwargs,
     ) -> None:
         super().__init__(
             units=units,
             activation=activation,
-            use_bias=use_bias,
             normalize=normalize,
+            use_bias=use_bias,
+            skip_connect=skip_connect,
             **kwargs
         )
         self._heads = heads
@@ -900,6 +975,8 @@ class GTConv(GraphConv):
         return self._head_units
 
     def build(self, spec: tensors.GraphTensor.Spec) -> None:
+        """Builds the layer.
+        """
         self._query_dense = self.get_einsum_dense(
             'ij,jkh->ikh', (self.head_units, self.heads)
         )
@@ -912,29 +989,36 @@ class GTConv(GraphConv):
         self._output_dense = self.get_dense(self.units)
         self._softmax_dropout = keras.layers.Dropout(self._attention_dropout)
 
-        self.
-
-        if self._add_bias:
-            self._edge_bias = EdgeBias(biases=self.heads)
+        if self.has_edge_feature:
+            self._attention_bias_dense_1 = self.get_einsum_dense('ij,jkh->ikh', (1, self.heads))
 
-
-
-
-        self.
-
+        if self.has_node_coordinate:
+            node_feature_dim = spec.node['feature'].shape[-1]
+            num_kernels = self.units
+            self._gaussian_loc = self.add_weight(
+                shape=[num_kernels], initializer='zeros', dtype='float32', trainable=True
+            )
+            self._gaussian_scale = self.add_weight(
+                shape=[num_kernels], initializer='ones', dtype='float32', trainable=True
+            )
+            self._centrality_dense = self.get_dense(units=node_feature_dim)
+            self._attention_bias_dense_2 = self.get_einsum_dense('ij,jkh->ikh', (1, self.heads))
 
     def message(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
-        if self._add_bias:
-            edge_bias = self._edge_bias(tensor)
-            tensor = tensor.update(
-                {
-                    'edge': {
-                        'bias': edge_bias
-                    }
-                }
-            )
         node_feature = tensor.node['feature']
-
+
+        if self.has_node_coordinate:
+            euclidean_distance = ops.euclidean_distance(
+                tensor.gather('coordinate', 'target'),
+                tensor.gather('coordinate', 'source'),
+                axis=-1
+            )
+            gaussian = ops.gaussian(
+                euclidean_distance, self._gaussian_loc, self._gaussian_scale
+            )
+            centrality = keras.ops.segment_sum(gaussian, tensor.edge['target'], tensor.num_nodes)
+            node_feature += self._centrality_dense(centrality)
+
         query = self._query_dense(node_feature)
         key = self._key_dense(node_feature)
         value = self._value_dense(node_feature)
@@ -946,23 +1030,45 @@ class GTConv(GraphConv):
         attention_score = keras.ops.sum(query * key, axis=1, keepdims=True)
         attention_score /= keras.ops.sqrt(float(self.head_units))
 
-
+        if self.has_edge_feature:
+            attention_score += self._attention_bias_dense_1(tensor.edge['feature'])
+
+        if self.has_node_coordinate:
+            attention_score += self._attention_bias_dense_2(gaussian)
+
         attention = ops.edge_softmax(attention_score, tensor.edge['target'])
         attention = self._softmax_dropout(attention)
-        message = ops.edge_weight(value, attention)
 
+        if self.has_node_coordinate:
+            displacement = ops.displacement(
+                tensor.gather('coordinate', 'target'),
+                tensor.gather('coordinate', 'source'),
+                normalize=True
+            )
+            attention *= keras.ops.expand_dims(displacement, axis=-1)
+            attention = keras.ops.expand_dims(attention, axis=2)
+            value = keras.ops.expand_dims(value, axis=1)
+
+        message = ops.edge_weight(value, attention)
+
         return tensor.update(
             {
                 'edge': {
-                    'message': message
+                    'message': message,
                 },
             }
         )
 
     def aggregate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
         node_feature = tensor.aggregate('message', mode='sum')
-
+        if self.has_node_coordinate:
+            shape = (tensor.num_nodes, -1, self.units)
+        else:
+            shape = (tensor.num_nodes, self.units)
+        node_feature = keras.ops.reshape(node_feature, shape)
         node_feature = self._output_dense(node_feature)
+        if self.has_node_coordinate:
+            node_feature = keras.ops.sum(node_feature, axis=1)
         return tensor.update(
             {
                 'node': {
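For 3D graphs, `GTConv` now biases the attention scores with learnable Gaussian kernels over interatomic distances; the exact form of molcraft's `ops.gaussian` is not shown in this diff, but a conventional Gaussian basis expansion would look like the following sketch (function name and shapes are assumptions for illustration):

    import keras

    def gaussian_basis(distance, loc, scale):
        # distance: (num_edges, 1); loc/scale: (num_kernels,).
        # Broadcasts each distance against every kernel, yielding (num_edges, num_kernels).
        z = (distance - loc) / scale
        return keras.ops.exp(-0.5 * keras.ops.square(z))

    distance = keras.ops.reshape(keras.ops.arange(0.0, 3.0), (-1, 1))
    loc = keras.ops.zeros((4,))
    scale = keras.ops.ones((4,))
    print(gaussian_basis(distance, loc, scale).shape)  # (3, 4)
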
@@ -973,20 +1079,6 @@ class GTConv(GraphConv):
                 }
             }
         )
-
-    def update(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
-        node_feature = tensor.node['aggregate']
-        node_feature = self._feedforward_intermediate_dense(node_feature)
-        node_feature = self._feedforward_activation(node_feature)
-        node_feature = self._feedforward_output_dense(node_feature)
-        return tensor.update(
-            {
-                'node': {
-                    'feature': node_feature,
-                    'aggregate': None,
-                },
-            }
-        )
 
     def get_config(self) -> dict:
         config = super().get_config()
@@ -998,17 +1090,49 @@ class GTConv(GraphConv):
 
 
 @keras.saving.register_keras_serializable(package='molcraft')
-class
+class EGConv(GraphConv):
 
-    """
+    """Equivariant graph neural network layer 3D.
+
+    Only supports 3D molecular graphs.
+
+    >>> graph = molcraft.tensors.GraphTensor(
+    ...     context={
+    ...         'size': [2]
+    ...     },
+    ...     node={
+    ...         'feature': [[1.], [2.]],
+    ...         'coordinate': [[0.1, -0.1, 0.5], [1.2, -0.5, 2.1]],
+    ...     },
+    ...     edge={
+    ...         'source': [0, 1],
+    ...         'target': [1, 0],
+    ...     }
+    ... )
+    >>> conv = molcraft.layers.EGConv(units=4)
+    >>> conv(graph)
+    GraphTensor(
+        context={
+            'size': <tf.Tensor: shape=[1], dtype=int32>
+        },
+        node={
+            'feature': <tf.Tensor: shape=[2, 4], dtype=float32>,
+            'coordinate': <tf.Tensor: shape=[2, 3], dtype=float32>
+        },
+        edge={
+            'source': <tf.Tensor: shape=[2], dtype=int32>,
+            'target': <tf.Tensor: shape=[2], dtype=int32>
+        }
+    )
     """
 
     def __init__(
         self,
         units: int = 128,
-        activation: keras.layers.Activation | str | None =
+        activation: keras.layers.Activation | str | None = 'silu',
         use_bias: bool = True,
         normalize: bool = False,
+        skip_connect: bool = True,
         **kwargs
     ) -> None:
         super().__init__(
@@ -1016,270 +1140,59 @@ class MPConv(GraphConv):
             activation=activation,
             use_bias=use_bias,
             normalize=normalize,
+            skip_connect=skip_connect,
             **kwargs
         )
 
     def build(self, spec: tensors.GraphTensor.Spec) -> None:
-
-
-
-
-        self.project_input_node_feature = node_feature_dim != self.units
-        if self.project_input_node_feature:
-            warn(
-                'Input node feature dim does not match updated node feature dim. '
-                'To make sure input node feature can be passed as `states` to the '
-                'GRU cell, it will automatically be projected prior to it.'
-            )
-            self._previous_node_dense = self.get_dense(
-                self.units, activation=self._activation
+        if not self.has_node_coordinate:
+            raise ValueError(
+                'Could not find `coordinate`s in node, '
+                'which is required for Conv3D layers.'
             )
+        self._message_feedforward_intermediate = self.get_dense(
+            self.units, activation=self.activation
+        )
+        self._message_feedforward_final = self.get_dense(
+            self.units, activation=self.activation
+        )
 
-
-
-
-
-
-            ],
-            axis=-1
-        )
-        if self._has_edge_feature:
-            feature = keras.ops.concatenate(
-                [
-                    feature,
-                    tensor.edge['feature']
-                ],
-                axis=-1
-            )
-        message = self.message_fn(feature)
-        return tensor.update(
-            {
-                'edge': {
-                    'message': message,
-                }
-            }
-        )
-
-    def aggregate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
-        aggregate = tensor.aggregate('message', mode='mean')
-        feature = tensor.node['feature']
-        if self.project_input_node_feature:
-            feature = self._previous_node_dense(feature)
-        return tensor.update(
-            {
-                'node': {
-                    'aggregate': aggregate,
-                    'feature': feature,
-                }
-            }
-        )
-
-    def update(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
-        updated_node_feature, _ = self.update_fn(
-            inputs=tensor.node['aggregate'],
-            states=tensor.node['feature']
-        )
-        return tensor.update(
-            {
-                'node': {
-                    'feature': updated_node_feature,
-                    'aggregate': None,
-                }
-            }
-        )
-
-    def get_config(self) -> dict:
-        config = super().get_config()
-        config.update({})
-        return config
-
-
-@keras.saving.register_keras_serializable(package='molcraft')
-class GTConv3D(GTConv):
-
-    """Graph transformer layer 3D.
-    """
-
-    def build(self, spec: tensors.GraphTensor.Spec) -> None:
-        """Builds the layer.
-        """
-        super().build(spec)
-        if self._add_bias:
-            node_feature_dim = spec.node['feature'].shape[-1]
-            kernels = self.units
-            self._gaussian_basis = GaussianDistance(kernels)
-            self._centrality_dense = self.get_dense(units=node_feature_dim)
-            self._gaussian_edge_bias = self.get_dense(self.heads)
-
-    def message(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
-        node_feature = tensor.node['feature']
-
-        if self._add_bias:
-            gaussian = self._gaussian_basis(tensor)
-            centrality = keras.ops.segment_sum(
-                gaussian, tensor.edge['target'], tensor.num_nodes
-            )
-            node_feature += self._centrality_dense(centrality)
-
-            edge_bias = self._edge_bias(tensor) + self._gaussian_edge_bias(gaussian)
-            tensor = tensor.update({'edge': {'bias': edge_bias}})
-
-        query = self._query_dense(node_feature)
-        key = self._key_dense(node_feature)
-        value = self._value_dense(node_feature)
-
-        query = ops.gather(query, tensor.edge['source'])
-        key = ops.gather(key, tensor.edge['target'])
-        value = ops.gather(value, tensor.edge['source'])
-
-        attention_score = keras.ops.sum(query * key, axis=1, keepdims=True)
-        attention_score /= keras.ops.sqrt(float(self.head_units))
-
-        attention_score += keras.ops.expand_dims(tensor.edge['bias'], axis=1)
-
-        attention = ops.edge_softmax(attention_score, tensor.edge['target'])
-        attention = self._softmax_dropout(attention)
-
-        distance = keras.ops.subtract(
-            tensor.gather('coordinate', 'source'),
-            tensor.gather('coordinate', 'target')
-        )
-        euclidean_distance = ops.euclidean_distance(
-            tensor.gather('coordinate', 'source'),
-            tensor.gather('coordinate', 'target'),
-            axis=-1
-        )
-        distance /= euclidean_distance
-
-        attention *= keras.ops.expand_dims(distance, axis=-1)
-        attention = keras.ops.expand_dims(attention, axis=2)
-        value = keras.ops.expand_dims(value, axis=1)
-
-        message = ops.edge_weight(value, attention)
-
-        return tensor.update(
-            {
-                'edge': {
-                    'message': message,
-                },
-            }
-        )
-
-    def aggregate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
-        node_feature = tensor.aggregate('message', mode='sum')
-        node_feature = keras.ops.reshape(
-            node_feature, (tensor.num_nodes, -1, self.units)
-        )
-        node_feature = self._output_dense(node_feature)
-        node_feature = keras.ops.sum(node_feature, axis=1)
-        return tensor.update(
-            {
-                'node': {
-                    'aggregate': node_feature,
-                },
-                'edge': {
-                    'message': None,
-                }
-            }
-        )
-
-
-@keras.saving.register_keras_serializable(package='molcraft')
-class MPConv3D(MPConv):
-
-    """Message passing neural network layer 3D.
-    """
-
-    def message(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
-        euclidean_distance = ops.euclidean_distance(
-            tensor.gather('coordinate', 'target'),
-            tensor.gather('coordinate', 'source'),
-            axis=-1
-        )
-        feature = keras.ops.concatenate(
-            [
-                tensor.gather('feature', 'source'),
-                tensor.gather('feature', 'target'),
-                euclidean_distance,
-            ],
-            axis=-1
-        )
-        if self._has_edge_feature:
-            feature = keras.ops.concatenate(
-                [
-                    feature,
-                    tensor.edge['feature']
-                ],
-                axis=-1
-            )
-        message = self.message_fn(feature)
-        return tensor.update(
-            {
-                'edge': {
-                    'message': message,
-                }
-            }
-        )
-
-
-@keras.saving.register_keras_serializable(package='molcraft')
-class EGConv3D(GraphConv):
-
-    """Equivariant graph neural network layer 3D.
-    """
-
-    def __init__(
-        self,
-        units: int = 128,
-        activation: keras.layers.Activation | str | None = None,
-        use_bias: bool = True,
-        normalize: bool = False,
-        **kwargs
-    ) -> None:
-        super().__init__(
-            units=units,
-            activation=activation,
-            use_bias=use_bias,
-            normalize=normalize,
-            **kwargs
+        self._coord_feedforward_intermediate = self.get_dense(
+            self.units, activation=self.activation
+        )
+        self._coord_feedforward_final = self.get_dense(
+            1, use_bias=False, activation='tanh'
         )
 
-    def build(self, spec: tensors.GraphTensor.Spec) -> None:
-        if 'coordinate' not in spec.node:
-            raise ValueError(
-                'Could not find `coordinate`s in node, '
-                'which is required for Conv3D layers.'
-            )
-        self._has_edge_feature = 'feature' in spec.edge
-        self.message_fn = self.get_dense(self.units, activation=self._activation)
-        self.dense_position = self.get_dense(1, use_bias=False, kernel_initializer='zeros')
-
-        has_overridden_update = self.__class__.update != EGConv3D.update
-        if not has_overridden_update:
-            self.update_fn = self.get_dense(self.units, activation=self._activation)
-            self.output_dense = self.get_dense(self.units)
-
     def message(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
         relative_node_coordinate = keras.ops.subtract(
             tensor.gather('coordinate', 'target'),
             tensor.gather('coordinate', 'source')
         )
-
-            keras.ops.square(
-                relative_node_coordinate
-            ),
+        squared_distance = keras.ops.sum(
+            keras.ops.square(relative_node_coordinate),
             axis=-1,
             keepdims=True
         )
+
+        # For numerical stability (i.e., to prevent NaN losses), this implementation of `EGConv3D`
+        # either needs to apply a `tanh` activation to the output of `self._coord_feedforward_final`,
+        # or normalize `relative_node_cordinate` as follows:
+        #
+        # norm = keras.ops.sqrt(squared_distance) + keras.backend.epsilon()
+        # relative_node_coordinate /= norm
+        #
+        # For now, this implementation does the former.
+
         feature = keras.ops.concatenate(
             [
                 tensor.gather('feature', 'target'),
                 tensor.gather('feature', 'source'),
-
+                squared_distance,
             ],
             axis=-1
         )
-        if self.
+        if self.has_edge_feature:
             feature = keras.ops.concatenate(
                 [
                     feature,
|
|
|
1287
1200
|
],
|
|
1288
1201
|
axis=-1
|
|
1289
1202
|
)
|
|
1290
|
-
message = self.
|
|
1203
|
+
message = self._message_feedforward_final(
|
|
1204
|
+
self._message_feedforward_intermediate(feature)
|
|
1205
|
+
)
|
|
1206
|
+
|
|
1291
1207
|
relative_node_coordinate = keras.ops.multiply(
|
|
1292
|
-
relative_node_coordinate,
|
|
1293
|
-
self.
|
|
1208
|
+
relative_node_coordinate,
|
|
1209
|
+
self._coord_feedforward_final(
|
|
1210
|
+
self._coord_feedforward_intermediate(message)
|
|
1211
|
+
)
|
|
1294
1212
|
)
|
|
1295
1213
|
return tensor.update(
|
|
1296
1214
|
{
|
|
@@ -1302,26 +1220,26 @@ class EGConv3D(GraphConv):
         )
 
     def aggregate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
-
-
-
-        #
-        #
-        #
-        #
-        #
-        #
-        #
-
-        updated_coordinate = tensor.aggregate('relative_node_coordinate', mode='mean')# * coefficient
-        updated_coordinate += tensor.node['coordinate']
-
+        coordinate = tensor.node['coordinate']
+        coordinate += tensor.aggregate('relative_node_coordinate', mode='mean')
+
+        # Original implementation seems to apply sum aggregation, which does not
+        # seem work well for this implementation of `EGConv3D`, as it causes
+        # large output values and large initial losses. The magnitude of the
+        # aggregated values of a sum aggregation depends on the number of
+        # neighbors, which may be many and may differ from node to node (or
+        # graph to graph). Therefore, a mean mean aggregation is performed
+        # instead:
         aggregate = tensor.aggregate('message', mode='mean')
+        aggregate = keras.ops.concatenate([aggregate, tensor.node['feature']], axis=-1)
+        # Simply added to silence warning ('no gradients for variables ...')
+        aggregate += (0.0 * keras.ops.sum(coordinate))
+
         return tensor.update(
             {
                 'node': {
                     'aggregate': aggregate,
-                    'coordinate':
+                    'coordinate': coordinate,
                 },
                 'edge': {
                     'message': None,
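The rationale in the added comment is easy to verify numerically: with sum aggregation, the magnitude of the aggregate grows linearly with neighbor count, while mean aggregation stays bounded regardless of degree. A quick illustration:

    import numpy as np

    messages = np.ones((1000, 8), dtype='float32')  # 1000 incoming messages of magnitude 1
    print(messages.sum(axis=0)[0])   # 1000.0 -- scales with the number of neighbors
    print(messages.mean(axis=0)[0])  # 1.0    -- independent of the number of neighbors
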
@@ -1329,26 +1247,6 @@ class EGConv3D(GraphConv):
             }
         }
     )
-
-    def update(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
-        updated_node_feature = self.update_fn(
-            keras.ops.concatenate(
-                [
-                    tensor.node['aggregate'],
-                    tensor.node['feature']
-                ],
-                axis=-1
-            )
-        )
-        updated_node_feature = self.output_dense(updated_node_feature)
-        return tensor.update(
-            {
-                'node': {
-                    'feature': updated_node_feature,
-                    'aggregate': None,
-                },
-            }
-        )
 
     def get_config(self) -> dict:
         config = super().get_config()
@@ -1391,146 +1289,6 @@ class Readout(GraphLayer):
         config['mode'] = self.mode
         return config
 
-
-@keras.saving.register_keras_serializable(package='molcraft')
-class GraphNetwork(GraphLayer):
-
-    """Graph neural network.
-
-    Sequentially calls graph layers (`GraphLayer`) and concatenates its output.
-
-    Arguments:
-        layers (list):
-            A list of graph layers.
-    """
-
-    def __init__(self, layers: list[GraphLayer], **kwargs) -> None:
-        super().__init__(**kwargs)
-        self.layers = layers
-        self._update_edge_feature = False
-
-    def build(self, spec: tensors.GraphTensor.Spec) -> None:
-        units = self.layers[0].units
-        node_feature_dim = spec.node['feature'].shape[-1]
-        self._update_node_feature = node_feature_dim != units
-        if self._update_node_feature:
-            warn(
-                'Node feature dim does not match `units` of the first layer. '
-                'Automatically adding a node projection layer to match `units`.'
-            )
-            self._node_dense = self.get_dense(units)
-        self._has_edge_feature = 'feature' in spec.edge
-        if self._has_edge_feature:
-            edge_feature_dim = spec.edge['feature'].shape[-1]
-            self._update_edge_feature = edge_feature_dim != units
-            if self._update_edge_feature:
-                warn(
-                    'Edge feature dim does not match `units` of the first layer. '
-                    'Automatically adding a edge projection layer to match `units`.'
-                )
-                self._edge_dense = self.get_dense(units)
-
-    def propagate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
-        x = tensors.to_dict(tensor)
-        if self._update_node_feature:
-            x['node']['feature'] = self._node_dense(tensor.node['feature'])
-        if self._has_edge_feature and self._update_edge_feature:
-            x['edge']['feature'] = self._edge_dense(tensor.edge['feature'])
-        outputs = [x['node']['feature']]
-        for layer in self.layers:
-            x = layer(x)
-            outputs.append(x['node']['feature'])
-        return tensor.update(
-            {
-                'node': {
-                    'feature': keras.ops.concatenate(outputs, axis=-1)
-                }
-            }
-        )
-
-    def tape_propagate(
-        self,
-        tensor: tensors.GraphTensor,
-        tape: tf.GradientTape,
-        training: bool | None = None,
-    ) -> tuple[tensors.GraphTensor, list[tf.Tensor]]:
-        """Performs the propagation with a `GradientTape`.
-
-        Performs the same forward pass as `propagate` but with a `GradientTape`
-        watching intermediate node features.
-
-        Arguments:
-            tensor (tensors.GraphTensor):
-                The graph input.
-        """
-        if isinstance(tensor, tensors.GraphTensor):
-            x = tensors.to_dict(tensor)
-        else:
-            x = tensor
-        if self._update_node_feature:
-            x['node']['feature'] = self._node_dense(tensor.node['feature'])
-        if self._update_edge_feature:
-            x['edge']['feature'] = self._edge_dense(tensor.edge['feature'])
-        tape.watch(x['node']['feature'])
-        outputs = [x['node']['feature']]
-        for layer in self.layers:
-            x = layer(x, training=training)
-            tape.watch(x['node']['feature'])
-            outputs.append(x['node']['feature'])
-
-        tensor = tensor.update(
-            {
-                'node': {
-                    'feature': keras.ops.concatenate(outputs, axis=-1)
-                }
-            }
-        )
-        return tensor, outputs
-
-    def get_config(self) -> dict:
-        config = super().get_config()
-        config.update(
-            {
-                'layers': [
-                    keras.layers.serialize(layer) for layer in self.layers
-                ]
-            }
-        )
-        return config
-
-    @classmethod
-    def from_config(cls, config: dict) -> 'GraphNetwork':
-        config['layers'] = [
-            keras.layers.deserialize(layer) for layer in config['layers']
-        ]
-        return super().from_config(config)
-
-
-@keras.saving.register_keras_serializable(package='molcraft')
-class Extraction(GraphLayer):
-
-    def __init__(
-        self,
-        field: str,
-        inner_field: str | None = None,
-        **kwargs
-    ) -> None:
-        super().__init__(**kwargs)
-        self.field = field
-        self.inner_field = inner_field
-
-    def propagate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
-        data = dict(getattr(tensor, self.field))
-        if not self.inner_field:
-            return data
-        return data[self.inner_field]
-
-    def get_config(self):
-        config = super().get_config()
-        config['field'] = self.field
-        config['inner_field'] = self.inner_field
-        return config
-
 
 @keras.saving.register_keras_serializable(package='molcraft')
 class NodeEmbedding(GraphLayer):
@@ -1574,15 +1332,12 @@ class NodeEmbedding(GraphLayer):
         if self._embed_context:
             self._context_dense = self.get_dense(self.dim)
 
-        if self._normalize:
-
-
-
-
-
-            self._norm = keras.layers.LayerNormalization(
-                name='output_layer_norm'
-            )
+        if not self._normalize:
+            self._norm = keras.layers.Identity()
+        elif str(self._normalize).lower().startswith('layer'):
+            self._norm = keras.layers.LayerNormalization()
+        else:
+            self._norm = keras.layers.BatchNormalization()
 
     def propagate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
         feature = self._node_dense(tensor.node['feature'])
@@ -1604,8 +1359,7 @@ class NodeEmbedding(GraphLayer):
         elif self._allow_masking:
             feature = feature + (self._mask_feature * 0.0)
 
-
-            feature = self._norm(feature)
+        feature = self._norm(feature)
 
         if not self._allow_reconstruction:
             return tensor.update({'node': {'feature': feature}})
@@ -1694,8 +1448,8 @@ class EdgeEmbedding(GraphLayer):
             mask = keras.ops.expand_dims(mask, -1)
             feature = keras.ops.where(mask, self._mask_feature, feature)
         elif self._allow_masking:
-            #
-            feature
+            # Simply added to silence warning ('no gradients for variables ...')
+            feature += (0.0 * self._mask_feature)
 
         if self._normalize:
             feature = self._norm(feature)
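Both embeddings use the same trick when masking is enabled but no mask is active in a batch: adding `0.0 * variable` keeps the mask feature tied into the computation graph, so the optimizer receives a zero gradient instead of `None`. A minimal sketch of the idea, shown with TensorFlow since the diff targets tf tensors; the variable name is illustrative:

    import tensorflow as tf

    mask_feature = tf.Variable(tf.zeros((4,)))

    with tf.GradientTape() as tape:
        feature = tf.ones((4,))
        feature += 0.0 * mask_feature  # ties the otherwise-unused variable into the graph
        loss = tf.reduce_sum(feature)

    # Without the line above, this gradient would be None, triggering the warning.
    print(tape.gradient(loss, mask_feature))  # zeros, not None
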
@@ -1726,196 +1480,119 @@ class EdgeEmbedding(GraphLayer):
 
 
 @keras.saving.register_keras_serializable(package='molcraft')
-class
-    """Base graph projection layer.
-    """
-    def __init__(
-        self,
-        units: int = None,
-        activation: str | keras.layers.Activation | None = None,
-        use_bias: bool = True,
-        field: str = 'node',
-        **kwargs
-    ) -> None:
-        super().__init__(use_bias=use_bias, **kwargs)
-        self.units = units
-        self._activation = keras.activations.get(activation)
-        self.field = field
-
-    def build(self, spec: tensors.GraphTensor.Spec) -> None:
-        data = getattr(spec, self.field, None)
-        if data is None:
-            raise ValueError('Could not access field {self.field!r}.')
-        feature_dim = data['feature'].shape[-1]
-        if not self.units:
-            self.units = feature_dim
-        self._dense = self.get_dense(self.units)
-
-    def propagate(self, tensor: tensors.GraphTensor):
-        feature = getattr(tensor, self.field)['feature']
-        feature = self._dense(feature)
-        feature = self._activation(feature)
-        return tensor.update(
-            {
-                self.field: {
-                    'feature': feature
-                }
-            }
-        )
-
-    def get_config(self) -> dict:
-        config = super().get_config()
-        config.update({
-            'units': self.units,
-            'activation': keras.activations.serialize(self._activation),
-            'field': self.field,
-        })
-        return config
-
-
-@keras.saving.register_keras_serializable(package='molcraft')
-class ContextProjection(Projection):
-    """Context projection layer.
-    """
-    def __init__(self, units: int = None, activation: str = None, **kwargs):
-        kwargs['field'] = 'context'
-        super().__init__(units=units, activation=activation, **kwargs)
-
+class GraphNetwork(GraphLayer):
 
-
-class NodeProjection(Projection):
-    """Node projection layer.
-    """
-    def __init__(self, units: int = None, activation: str = None, **kwargs):
-        kwargs['field'] = 'node'
-        super().__init__(units=units, activation=activation, **kwargs)
+    """Graph neural network.
 
+    Sequentially calls graph layers (`GraphLayer`) and concatenates its output.
 
-
-
-
+    Arguments:
+        layers (list):
+            A list of graph layers.
     """
-    def __init__(self, units: int = None, activation: str = None, **kwargs):
-        kwargs['field'] = 'edge'
-        super().__init__(units=units, activation=activation, **kwargs)
-
-
-@keras.saving.register_keras_serializable(package='molcraft')
-class Reconstruction(GraphLayer):
 
-    def __init__(
-        self,
-        loss: keras.losses.Loss | str = 'mse',
-        loss_weight: float = 0.5,
-        **kwargs
-    ):
-        super().__init__(**kwargs)
-        self._loss_fn = keras.losses.get(loss)
-        self._loss_weight = loss_weight
-
-    def build(self, spec: tensors.GraphTensor.Spec) -> None:
-        has_target_node_feature = 'target_feature' in spec.node
-        if not has_target_node_feature:
-            raise ValueError(
-                'Could not find `target_feature` in `spec.node`. '
-                'Add a `target_feature` via `NodeEmbedding` by setting '
-                '`allow_reconstruction` to `True`.'
-            )
-        output_dim = spec.node['target_feature'].shape[-1]
-        self._dense = self.get_dense(output_dim)
-
-    def propagate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
-        target_node_feature = tensor.node['target_feature']
-        transformed_node_feature = tensor.node['feature']
-
-        reconstructed_node_feature = self._dense(
-            transformed_node_feature
-        )
-
-        loss = self._loss_fn(
-            target_node_feature, reconstructed_node_feature
-        )
-        self.add_loss(keras.ops.sum(loss) * self._loss_weight)
-        return tensor.update({'node': {'feature': transformed_node_feature}})
-
-    def get_config(self):
-        config = super().get_config()
-        config['loss'] = keras.losses.serialize(self._loss_fn)
-        config['loss_weight'] = self._loss_weight
-        return config
-
-
-@keras.saving.register_keras_serializable(package='molcraft')
-class EdgeBias(GraphLayer):
-
-    def __init__(self, biases: int, **kwargs):
+    def __init__(self, layers: list[GraphLayer], **kwargs) -> None:
         super().__init__(**kwargs)
-        self.
+        self.layers = layers
+        self._update_edge_feature = False
 
     def build(self, spec: tensors.GraphTensor.Spec) -> None:
-
-
-
-
-
-
-
+        units = self.layers[0].units
+        node_feature_dim = spec.node['feature'].shape[-1]
+        self._update_node_feature = node_feature_dim != units
+        if self._update_node_feature:
+            warnings.warn(
+                'Node feature dim does not match `units` of the first layer. '
+                'Automatically adding a node projection layer to match `units`.',
+                stacklevel=2
             )
+            self._node_dense = self.get_dense(units)
+        self._has_edge_feature = 'feature' in spec.edge
+        if self._has_edge_feature:
+            edge_feature_dim = spec.edge['feature'].shape[-1]
+            self._update_edge_feature = edge_feature_dim != units
+            if self._update_edge_feature:
+                warnings.warn(
+                    'Edge feature dim does not match `units` of the first layer. '
+                    'Automatically adding a edge projection layer to match `units`.',
+                    stacklevel=2
+                )
+                self._edge_dense = self.get_dense(units)
 
     def propagate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
-
-
-
+        x = tensors.to_dict(tensor)
+        if self._update_node_feature:
+            x['node']['feature'] = self._node_dense(tensor.node['feature'])
x['node']['feature'] = self._node_dense(tensor.node['feature'])
|
|
1526
|
+
if self._has_edge_feature and self._update_edge_feature:
|
|
1527
|
+
x['edge']['feature'] = self._edge_dense(tensor.edge['feature'])
|
|
1528
|
+
outputs = [x['node']['feature']]
|
|
1529
|
+
for layer in self.layers:
|
|
1530
|
+
x = layer(x)
|
|
1531
|
+
outputs.append(x['node']['feature'])
|
|
1532
|
+
return tensor.update(
|
|
1533
|
+
{
|
|
1534
|
+
'node': {
|
|
1535
|
+
'feature': keras.ops.concatenate(outputs, axis=-1)
|
|
1536
|
+
}
|
|
1537
|
+
}
|
|
1869
1538
|
)
|
|
1870
|
-
if self._has_edge_feature:
|
|
1871
|
-
bias += self._edge_feature_dense(tensor.edge['feature'])
|
|
1872
|
-
if self._has_edge_length:
|
|
1873
|
-
bias += self._edge_length_dense(tensor.edge['length'])
|
|
1874
|
-
return bias
|
|
1875
|
-
|
|
1876
|
-
def get_config(self) -> dict:
|
|
1877
|
-
config = super().get_config()
|
|
1878
|
-
config.update({'biases': self.biases})
|
|
1879
|
-
return config
|
|
1880
1539
|
|
|
1540
|
+
def tape_propagate(
|
|
1541
|
+
self,
|
|
1542
|
+
tensor: tensors.GraphTensor,
|
|
1543
|
+
tape: tf.GradientTape,
|
|
1544
|
+
training: bool | None = None,
|
|
1545
|
+
) -> tuple[tensors.GraphTensor, list[tf.Tensor]]:
|
|
1546
|
+
"""Performs the propagation with a `GradientTape`.
|
|
1881
1547
|
|
|
1882
|
-
|
|
1883
|
-
|
|
1884
|
-
|
|
1885
|
-
def __init__(self, kernels: int, **kwargs):
|
|
1886
|
-
super().__init__(**kwargs)
|
|
1887
|
-
self.kernels = kernels
|
|
1548
|
+
Performs the same forward pass as `propagate` but with a `GradientTape`
|
|
1549
|
+
watching intermediate node features.
|
|
1888
1550
|
|
|
1889
|
-
|
|
1890
|
-
|
|
1891
|
-
|
|
1892
|
-
|
|
1893
|
-
|
|
1894
|
-
|
|
1895
|
-
|
|
1896
|
-
|
|
1897
|
-
|
|
1898
|
-
|
|
1899
|
-
|
|
1900
|
-
|
|
1901
|
-
)
|
|
1551
|
+
Arguments:
|
|
1552
|
+
tensor (tensors.GraphTensor):
|
|
1553
|
+
The graph input.
|
|
1554
|
+
"""
|
|
1555
|
+
if isinstance(tensor, tensors.GraphTensor):
|
|
1556
|
+
x = tensors.to_dict(tensor)
|
|
1557
|
+
else:
|
|
1558
|
+
x = tensor
|
|
1559
|
+
if self._update_node_feature:
|
|
1560
|
+
x['node']['feature'] = self._node_dense(tensor.node['feature'])
|
|
1561
|
+
if self._update_edge_feature:
|
|
1562
|
+
x['edge']['feature'] = self._edge_dense(tensor.edge['feature'])
|
|
1563
|
+
tape.watch(x['node']['feature'])
|
|
1564
|
+
outputs = [x['node']['feature']]
|
|
1565
|
+
for layer in self.layers:
|
|
1566
|
+
x = layer(x, training=training)
|
|
1567
|
+
tape.watch(x['node']['feature'])
|
|
1568
|
+
outputs.append(x['node']['feature'])
|
|
1902
1569
|
|
|
1903
|
-
|
|
1904
|
-
|
|
1905
|
-
|
|
1906
|
-
|
|
1907
|
-
|
|
1908
|
-
|
|
1909
|
-
return ops.gaussian(
|
|
1910
|
-
euclidean_distance, self._loc, self._scale
|
|
1570
|
+
tensor = tensor.update(
|
|
1571
|
+
{
|
|
1572
|
+
'node': {
|
|
1573
|
+
'feature': keras.ops.concatenate(outputs, axis=-1)
|
|
1574
|
+
}
|
|
1575
|
+
}
|
|
1911
1576
|
)
|
|
1912
|
-
|
|
1577
|
+
return tensor, outputs
|
|
1578
|
+
|
|
1913
1579
|
def get_config(self) -> dict:
|
|
1914
1580
|
config = super().get_config()
|
|
1915
|
-
config.update(
|
|
1916
|
-
|
|
1917
|
-
|
|
1581
|
+
config.update(
|
|
1582
|
+
{
|
|
1583
|
+
'layers': [
|
|
1584
|
+
keras.layers.serialize(layer) for layer in self.layers
|
|
1585
|
+
]
|
|
1586
|
+
}
|
|
1587
|
+
)
|
|
1918
1588
|
return config
|
|
1589
|
+
|
|
1590
|
+
@classmethod
|
|
1591
|
+
def from_config(cls, config: dict) -> 'GraphNetwork':
|
|
1592
|
+
config['layers'] = [
|
|
1593
|
+
keras.layers.deserialize(layer) for layer in config['layers']
|
|
1594
|
+
]
|
|
1595
|
+
return super().from_config(config)
|
|
1919
1596
|
|
|
1920
1597
|
|
|
1921
1598
|
@keras.saving.register_keras_serializable(package='molcraft')
|
|
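The new `GraphNetwork` consolidates what the removed projection layers did: `build` auto-projects node (and edge) features to the first layer's `units` when the dimensions differ, `propagate` runs the layer stack and concatenates the (projected) input with every layer's node output, and `tape_propagate` exposes each intermediate node feature to a `GradientTape` for gradient-based attribution. A rough usage sketch; the layer choice, `units` value, and the `graph` placeholder are illustrative assumptions, not molcraft documentation:

import tensorflow as tf
import keras
from molcraft import layers

net = layers.GraphNetwork([
    layers.GraphTransformer(units=128),  # GraphTransformer = GTConv (see the last hunk)
    layers.GraphTransformer(units=128),
])

graph = ...  # placeholder: a tensors.GraphTensor from a molcraft featurizer

out = net(graph)
# out.node['feature'] concatenates the (projected) input with each
# layer's node output, so here it has 3 * 128 columns.

with tf.GradientTape() as tape:
    out, intermediates = net.tape_propagate(graph, tape)
    target = keras.ops.sum(out.node['feature'])
grads = tape.gradient(target, intermediates)  # one gradient per watched feature

Serialization round-trips through the `get_config`/`from_config` pair above, which serialize and deserialize the wrapped layers individually.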
@@ -1966,7 +1643,7 @@ class GaussianParams(keras.layers.Dense):
         config['units'] = None
         config['activation'] = keras.activations.serialize(self.loc_activation)
         return config
-
+
 
 def Input(spec: tensors.GraphTensor.Spec) -> dict:
     """Used to specify inputs to model.
@@ -1999,14 +1676,11 @@ def Input(spec: tensors.GraphTensor.Spec) -> dict:
     for outer_field, data in spec.__dict__.items():
         inputs[outer_field] = {}
         for inner_field, nested_spec in data.items():
-            if inner_field in ['label', 'weight']:
+            if outer_field == 'context' and inner_field in ['label', 'weight']:
                 # Remove context label and weight from the symbolic input
                 # as a functional model is strict for what input can be passed.
-                # We want to
-
-                # temporarily pops label and weight to avoid errors.
-                if outer_field == 'context':
-                    continue
+                # (We want to train and predict with the model.)
+                continue
             kwargs = {
                 'shape': nested_spec.shape[1:],
                 'dtype': nested_spec.dtype,
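Note the behavioral tightening: only the context-level `label` and `weight` are now skipped, so a node or edge field that happens to share those names would get a symbolic input. A small sketch of the intended effect (assuming a `spec` obtained from featurized data; not taken from molcraft's docs):

from molcraft import layers

# Assumed: `spec` comes from featurized data, e.g. a GraphTensor's spec.
inputs = layers.Input(spec)

# Only the context-level 'label' and 'weight' entries are omitted, so the
# functional model built from these inputs works for both fit() and predict().
assert 'label' not in inputs.get('context', {})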
@@ -2024,13 +1698,6 @@ def Input(spec: tensors.GraphTensor.Spec) -> dict:
     return inputs
 
 
-def warn(message: str) -> None:
-    warnings.warn(
-        message=message,
-        category=UserWarning,
-        stacklevel=1
-    )
-
 def _serialize_spec(spec: tensors.GraphTensor.Spec) -> dict:
     serialized_spec = {}
     for outer_field, data in spec.__dict__.items():
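The module-level `warn` helper is gone; call sites (such as `GraphNetwork.build` above) now invoke `warnings.warn(..., stacklevel=2)` directly. The difference matters for where Python reports the warning, as this standalone snippet (independent of molcraft) shows:

import warnings

def helper(msg):
    # stacklevel=1 (the removed helper's behavior) attributes the
    # warning to this line, inside the helper.
    warnings.warn(msg, UserWarning, stacklevel=1)

def better_helper(msg):
    # stacklevel=2 attributes the warning to the caller, which is
    # what the new in-place calls achieve.
    warnings.warn(msg, UserWarning, stacklevel=2)

helper('reported inside helper')
better_helper('reported at this call site')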
@@ -2072,5 +1739,3 @@ def _spec_from_inputs(inputs):
 
 
 GraphTransformer = GTConv
-GraphTransformer3D = GTConv3D
-