molcraft 0.1.0a7__py3-none-any.whl → 0.1.0a9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of molcraft might be problematic.

molcraft/layers.py CHANGED
@@ -274,20 +274,14 @@ class GraphConv(GraphLayer):
         units (int):
             Dimensionality of the output space.
         activation (keras.layers.Activation, str, None):
-            Activation function to use. If not specified, a linear activation (a(x) = x) is used.
-            Default to `None`.
+            Activation function to be accessed via `self.activation`, and used for the
+            `message()` and `update()` methods, if not overridden. Default to `relu`.
         use_bias (bool):
-            Whether bias should be used in dense layers. Default to `True`.
+            Whether bias should be used in the dense layers. Default to `True`.
         normalize (bool, str):
-            Whether `LayerNormalization` should be applied to the final node feature output.
-            To use `BatchNormalization`, specify `batch_norm`. Default to `False`.
-        skip_connect (bool, str):
-            Whether node feature input should be added to the node feature output.
-            If node feature input dim is not equal to `units` (node feature output dim),
-            a projection layer will automatically project the residual before adding it
-            to the output. To use weighted skip connection,
-            specify `weighted`. The weight multiplied with the skip connection is a
-            learnable scalar. Default to `True`.
+            Whether normalization should be applied to the final output. Default to `False`.
+        skip_connect (bool):
+            Whether node feature input should be added to the node feature output. Default to `True`.
         kernel_initializer (keras.initializers.Initializer, str):
             Initializer for the kernel weight matrix of the dense layers.
             Default to `glorot_uniform`.
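A minimal sketch of what the revised docstring above implies for callers of `GraphConv` in 0.1.0a9 (the layer name comes from this diff; the `units` value is an arbitrary example):

```python
import molcraft

# New 0.1.0a9 defaults: activation='relu', normalize=False, skip_connect=True.
# `normalize` and `skip_connect` are documented as plain booleans now; the old
# string forms ('batch_norm', 'weighted') are no longer part of the docstring.
conv = molcraft.layers.GraphConv(units=128)
```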
@@ -314,10 +308,10 @@ class GraphConv(GraphLayer):
     def __init__(
         self,
         units: int = None,
-        activation: str | keras.layers.Activation | None = None,
+        activation: str | keras.layers.Activation | None = 'relu',
         use_bias: bool = True,
-        normalize: bool | str = False,
-        skip_connect: bool | str = True,
+        normalize: bool = False,
+        skip_connect: bool = True,
         **kwargs
     ) -> None:
         super().__init__(use_bias=use_bias, **kwargs)
@@ -341,56 +335,56 @@ class GraphConv(GraphLayer):
     def units(self):
         return self._units
 
+    @property
+    def activation(self):
+        return self._activation
+
     def build(self, spec: tensors.GraphTensor.Spec) -> None:
         if not self.units:
             raise ValueError(
                 f'`self.units` needs to be a positive integer. Found: {self.units}.'
             )
         node_feature_dim = spec.node['feature'].shape[-1]
-        self._project_input_node_feature = (
+        self._project_residual = (
             self._skip_connect and (node_feature_dim != self.units)
         )
-        if self._project_input_node_feature:
-            warn(
+        if self._project_residual:
+            warnings.warn(
                 '`skip_connect` is set to `True`, but found incompatible dim '
                 'between input (node feature dim) and output (`self.units`). '
                 'Automatically applying a projection layer to residual to '
-                'match input and output. '
+                'match input and output. ',
+                stacklevel=2,
             )
-            self._residual_projection = self.get_dense(
-                self.units, name='residual_projection'
+            self._residual_dense = self.get_dense(
+                self.units, name='residual_dense'
             )
 
-        skip_connect = str(self._skip_connect).lower()
-        self._use_weighted_skip_connection = skip_connect.startswith('weight')
-        if self._use_weighted_skip_connection:
-            self._skip_connection_weight = self.add_weight(
-                name='skip_connection_weight',
-                shape=(),
-                initializer='ones',
-                trainable=True,
-            )
-
-        if self._normalize:
-            if str(self._normalize).lower().startswith('batch'):
-                self._output_norm = keras.layers.BatchNormalization(
-                    name='output_batch_norm'
-                )
-            else:
-                self._output_norm = keras.layers.LayerNormalization(
-                    name='output_layer_norm'
-                )
-
-        self._has_edge_feature = 'feature' in spec.edge
+        self.has_edge_feature = 'feature' in spec.edge
+        self.has_node_coordinate = 'coordinate' in spec.node
 
         has_overridden_message = self.__class__.message != GraphConv.message
         if not has_overridden_message:
-            self._message_dense = self.get_dense(self.units)
+            self._message_intermediate_dense = self.get_dense(self.units)
+            self._message_intermediate_activation = self.activation
+            self._message_final_dense = self.get_dense(self.units)
+
+        has_overridden_aggregate = self.__class__.aggregate != GraphConv.aggregate
+        if not has_overridden_aggregate:
+            pass
 
         has_overridden_update = self.__class__.update != GraphConv.update
         if not has_overridden_update:
-            self._output_dense = self.get_dense(self.units)
-            self._output_activation = self._activation
+            self._update_intermediate_dense = self.get_dense(self.units)
+            self._update_intermediate_activation = self.activation
+            self._update_final_dense = self.get_dense(self.units)
+
+        if not self._normalize:
+            self._normalization = keras.layers.Identity()
+        elif str(self._normalize).lower().startswith('layer'):
+            self._normalization = keras.layers.LayerNormalization()
+        else:
+            self._normalization = keras.layers.BatchNormalization()
 
     def propagate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
         """Forward pass.
@@ -402,10 +396,10 @@ class GraphConv(GraphLayer):
             A `GraphTensor` instance.
         """
         if self._skip_connect:
-            input_node_feature = tensor.node['feature']
-            if self._project_input_node_feature:
-                input_node_feature = self._residual_projection(input_node_feature)
-
+            residual = tensor.node['feature']
+            if self._project_residual:
+                residual = self._residual_dense(residual)
+
         message = self.message(tensor)
         if not isinstance(message, tensors.GraphTensor):
             message = tensor.update({'edge': {'message': message}})
@@ -417,24 +411,24 @@ class GraphConv(GraphLayer):
             aggregate = tensor.update({'node': {'aggregate': aggregate}})
         elif not 'aggregate' in aggregate.node:
             raise ValueError('Could not find `aggregate` in `node` output.')
-
+
         update = self.update(aggregate)
         if not isinstance(update, tensors.GraphTensor):
             update = tensor.update({'node': {'feature': update}})
         elif not 'feature' in update.node:
             raise ValueError('Could not find `feature` in `node` output.')
-
-        updated_node_feature = update.node['feature']
+
+        if update.node['feature'].shape[-1] != self.units:
+            raise ValueError('Updated node `feature` is not equal to `self.units`.')
+
+        feature = update.node['feature']
 
         if self._skip_connect:
-            if self._use_weighted_skip_connection:
-                input_node_feature *= self._skip_connection_weight
-            updated_node_feature += input_node_feature
-
-        if self._normalize:
-            updated_node_feature = self._output_norm(updated_node_feature)
+            feature += residual
+
+        feature = self._normalization(feature)
 
-        return update.update({'node': {'feature': updated_node_feature}})
+        return update.update({'node': {'feature': feature}})
 
     def message(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
         """Compute messages.
@@ -445,24 +439,38 @@ class GraphConv(GraphLayer):
         tensor:
             The inputted `GraphTensor` instance.
         """
-        if not self._has_edge_feature:
-            message_feature = tensor.gather('feature', 'source')
-        else:
-            message_feature = keras.ops.concatenate(
+        message = keras.ops.concatenate(
+            [
+                tensor.gather('feature', 'source'),
+                tensor.gather('feature', 'target'),
+            ],
+            axis=-1
+        )
+        if self.has_edge_feature:
+            message = keras.ops.concatenate(
                 [
-                    tensor.gather('feature', 'source'),
+                    message,
                     tensor.edge['feature']
                 ],
                 axis=-1
             )
-        message = self._message_dense(message_feature)
-        return tensor.update(
-            {
-                'edge': {
-                    'message': message
-                }
-            }
-        )
+        if self.has_node_coordinate:
+            euclidean_distance = ops.euclidean_distance(
+                tensor.gather('coordinate', 'target'),
+                tensor.gather('coordinate', 'source'),
+                axis=-1
+            )
+            message = keras.ops.concatenate(
+                [
+                    message,
+                    euclidean_distance
+                ],
+                axis=-1
+            )
+        message = self._message_intermediate_dense(message)
+        message = self._message_intermediate_activation(message)
+        message = self._message_final_dense(message)
+        return tensor.update({'edge': {'message': message}})
 
     def aggregate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
         """Aggregates messages.
@@ -473,14 +481,16 @@ class GraphConv(GraphLayer):
         tensor:
             A `GraphTensor` instance containing a message.
         """
+        previous = tensor.node['feature']
         aggregate = tensor.aggregate('message', mode='mean')
+        aggregate = keras.ops.concatenate([aggregate, previous], axis=-1)
         return tensor.update(
             {
                 'node': {
-                    'aggregate': aggregate,
+                    'aggregate': aggregate,
                 },
                 'edge': {
-                    'message': None
+                    'message': None,
                 }
             }
         )
@@ -495,21 +505,16 @@ class GraphConv(GraphLayer):
             A `GraphTensor` instance containing aggregated messages
             (updated node features).
         """
-        feature = keras.ops.concatenate(
-            [
-                tensor.node['aggregate'],
-                tensor.node['feature']
-            ],
-            axis=-1
-        )
-        update = self._output_dense(feature)
-        update = self._output_activation(update)
+        aggregate = tensor.node['aggregate']
+        node_feature = self._update_intermediate_dense(aggregate)
+        node_feature = self._update_intermediate_activation(node_feature)
+        node_feature = self._update_final_dense(node_feature)
         return tensor.update(
             {
                 'node': {
-                    'feature': update,
+                    'feature': node_feature,
                     'aggregate': None,
-                }
+                },
             }
         )
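The `has_overridden_*` checks in `build()` above mean that a subclass overriding `message()` skips the built-in message MLP entirely. A sketch of that extension point, assuming only the hooks visible in this diff (`gather`, the edge `'message'` wrapping in `propagate()`); the subclass name and package are hypothetical:

```python
import keras
import molcraft


@keras.saving.register_keras_serializable(package='example')
class SourceOnlyConv(molcraft.layers.GraphConv):
    # Because `message` is overridden, GraphConv.build() does not create
    # `_message_intermediate_dense` / `_message_final_dense`; the returned
    # array is wrapped into the edge 'message' field by `propagate()`, and
    # the built-in `aggregate()` and `update()` still run afterwards.
    def message(self, tensor):
        return tensor.gather('feature', 'source')
```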
 
@@ -563,14 +568,16 @@ class GIConv(GraphConv):
         activation: keras.layers.Activation | str | None = 'relu',
         use_bias: bool = True,
         normalize: bool = False,
+        skip_connect: bool = True,
         update_edge_feature: bool = True,
         **kwargs,
     ):
         super().__init__(
             units=units,
             activation=activation,
-            normalize=normalize,
             use_bias=use_bias,
+            normalize=normalize,
+            skip_connect=skip_connect,
             **kwargs
         )
         self._update_edge_feature = update_edge_feature
@@ -585,16 +592,16 @@ class GIConv(GraphConv):
             trainable=True,
         )
 
-        self._has_edge_feature = 'feature' in spec.edge
-        if self._has_edge_feature:
+        if self.has_edge_feature:
             edge_feature_dim = spec.edge['feature'].shape[-1]
 
             if not self._update_edge_feature:
                 if (edge_feature_dim != node_feature_dim):
-                    warn(
+                    warnings.warn(
                         'Found edge feature dim to be incompatible with node feature dim. '
                         'Automatically adding an edge feature projection layer to match '
-                        'the dim of node features.'
+                        'the dim of node features.',
+                        stacklevel=2,
                     )
                     self._update_edge_feature = True
@@ -603,19 +610,14 @@ class GIConv(GraphConv):
         else:
             self._update_edge_feature = False
 
-        has_overridden_update = self.__class__.update != GIConv.update
-        if not has_overridden_update:
-            self._feedforward_intermediate_dense = self.get_dense(self.units)
-            self._feedforward_activation = self._activation
-            self._feedforward_output_dense = self.get_dense(self.units)
-
     def message(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
         message = tensor.gather('feature', 'source')
         edge_feature = tensor.edge.get('feature')
         if self._update_edge_feature:
             edge_feature = self._edge_dense(edge_feature)
-        if self._has_edge_feature:
+        if self.has_edge_feature:
             message += edge_feature
+        message = keras.ops.relu(message)
         return tensor.update(
             {
                 'edge': {
@@ -639,20 +641,6 @@ class GIConv(GraphConv):
             }
         )
 
-    def update(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
-        node_feature = tensor.node['aggregate']
-        node_feature = self._feedforward_intermediate_dense(node_feature)
-        node_feature = self._feedforward_activation(node_feature)
-        node_feature = self._feedforward_output_dense(node_feature)
-        return tensor.update(
-            {
-                'node': {
-                    'feature': node_feature,
-                    'aggregate': None,
-                }
-            }
-        )
-
     def get_config(self) -> dict:
         config = super().get_config()
         config.update({
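A usage sketch of `GIConv` after these changes (values are illustrative); note the GIN-style `relu` in `message()` is now hard-coded rather than configurable:

```python
import molcraft

# `skip_connect` is now forwarded to GraphConv, and GIConv no longer defines
# its own feedforward `update()`; the shared GraphConv update path is used.
conv = molcraft.layers.GIConv(
    units=128, skip_connect=True, update_edge_feature=True
)
```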
@@ -701,15 +689,16 @@ class GAConv(GraphConv):
         activation: keras.layers.Activation | str | None = "relu",
         use_bias: bool = True,
         normalize: bool = False,
+        skip_connect: bool = True,
         update_edge_feature: bool = True,
-        attention_activation: keras.layers.Activation | str | None = "leaky_relu",
         **kwargs,
     ) -> None:
         super().__init__(
             units=units,
             activation=activation,
-            use_bias=use_bias,
             normalize=normalize,
+            use_bias=use_bias,
+            skip_connect=skip_connect,
             **kwargs
         )
         self._heads = heads
@@ -717,7 +706,6 @@ class GAConv(GraphConv):
             raise ValueError(f"units need to be divisible by heads.")
         self._head_units = self.units // self.heads
         self._update_edge_feature = update_edge_feature
-        self._attention_activation = keras.activations.get(attention_activation)
 
     @property
     def heads(self):
@@ -728,8 +716,7 @@ class GAConv(GraphConv):
         return self._head_units
 
     def build(self, spec: tensors.GraphTensor.Spec) -> None:
-        self._has_edge_feature = 'feature' in spec.edge
-        self._update_edge_feature = self._has_edge_feature and self._update_edge_feature
+        self._update_edge_feature = self.has_edge_feature and self._update_edge_feature
         if self._update_edge_feature:
             self._edge_dense = self.get_einsum_dense(
                 'ijh,jkh->ikh', (self.head_units, self.heads)
@@ -743,15 +730,6 @@ class GAConv(GraphConv):
         self._attention_dense = self.get_einsum_dense(
             'ijh,jkh->ikh', (1, self.heads)
         )
-        self._node_self_dense = self.get_einsum_dense(
-            'ij,jkh->ikh', (self.head_units, self.heads)
-        )
-
-        has_overridden_update = self.__class__.update != GAConv.update
-        if not has_overridden_update:
-            self._feedforward_intermediate_dense = self.get_dense(self.units)
-            self._feedforward_activation = self._activation
-            self._feedforward_output_dense = self.get_dense(self.units)
 
     def message(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
         attention_feature = keras.ops.concatenate(
@@ -761,7 +739,7 @@ class GAConv(GraphConv):
             ],
             axis=-1
         )
-        if self._has_edge_feature:
+        if self.has_edge_feature:
             attention_feature = keras.ops.concatenate(
                 [
                     attention_feature,
@@ -778,7 +756,7 @@ class GAConv(GraphConv):
             edge_feature = self._edge_dense(attention_feature)
             edge_feature = keras.ops.reshape(edge_feature, (-1, self.units))
 
-        attention_feature = self._attention_activation(attention_feature)
+        attention_feature = keras.ops.leaky_relu(attention_feature)
         attention_score = self._attention_dense(attention_feature)
         attention_score = ops.edge_softmax(
             score=attention_score, edge_target=tensor.edge['target']
@@ -797,7 +775,6 @@ class GAConv(GraphConv):
 
     def aggregate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
         node_feature = tensor.aggregate('message', mode='sum')
-        node_feature += self._node_self_dense(tensor.node['feature'])
         node_feature = keras.ops.reshape(node_feature, (-1, self.units))
         return tensor.update(
             {
@@ -810,28 +787,123 @@ class GAConv(GraphConv):
             }
         )
 
+    def get_config(self) -> dict:
+        config = super().get_config()
+        config.update({
+            "heads": self._heads,
+            'update_edge_feature': self._update_edge_feature,
+        })
+        return config
+
+
+@keras.saving.register_keras_serializable(package='molcraft')
+class MPConv(GraphConv):
+
+    """Message passing neural network layer.
+
+    Also supports 3D molecular graphs.
+
+    >>> graph = molcraft.tensors.GraphTensor(
+    ...     context={
+    ...         'size': [2]
+    ...     },
+    ...     node={
+    ...         'feature': [[1.], [2.]]
+    ...     },
+    ...     edge={
+    ...         'source': [0, 1],
+    ...         'target': [1, 0],
+    ...     }
+    ... )
+    >>> conv = molcraft.layers.MPConv(units=4)
+    >>> conv(graph)
+    GraphTensor(
+      context={
+        'size': <tf.Tensor: shape=[1], dtype=int32>
+      },
+      node={
+        'feature': <tf.Tensor: shape=[2, 4], dtype=float32>
+      },
+      edge={
+        'source': <tf.Tensor: shape=[2], dtype=int32>,
+        'target': <tf.Tensor: shape=[2], dtype=int32>
+      }
+    )
+    """
+
+    def __init__(
+        self,
+        units: int = 128,
+        activation: keras.layers.Activation | str | None = 'relu',
+        use_bias: bool = True,
+        normalize: bool = False,
+        skip_connect: bool = True,
+        **kwargs
+    ) -> None:
+        super().__init__(
+            units=units,
+            activation=activation,
+            use_bias=use_bias,
+            normalize=normalize,
+            skip_connect=skip_connect,
+            **kwargs
+        )
+
+    def build(self, spec: tensors.GraphTensor.Spec) -> None:
+        node_feature_dim = spec.node['feature'].shape[-1]
+        self.update_fn = keras.layers.GRUCell(self.units)
+        self._project_previous_node_feature = node_feature_dim != self.units
+        if self._project_previous_node_feature:
+            warnings.warn(
+                'Input node feature dim does not match updated node feature dim. '
+                'To make sure input node feature can be passed as `states` to the '
+                'GRU cell, it will automatically be projected prior to it.',
+                stacklevel=2
+            )
+            self._previous_node_dense = self.get_dense(self.units)
+
+    def aggregate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
+        """Aggregates messages.
+
+        This method may be overridden by subclass.
+
+        Arguments:
+            tensor:
+                A `GraphTensor` instance containing a message.
+        """
+        aggregate = tensor.aggregate('message', mode='mean')
+        return tensor.update(
+            {
+                'node': {
+                    'aggregate': aggregate,
+                },
+                'edge': {
+                    'message': None,
+                }
+            }
+        )
+
     def update(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
-        node_feature = tensor.node['aggregate']
-        node_feature = self._feedforward_intermediate_dense(node_feature)
-        node_feature = self._feedforward_activation(node_feature)
-        node_feature = self._feedforward_output_dense(node_feature)
+        previous = tensor.node['feature']
+        aggregate = tensor.node['aggregate']
+        if self._project_previous_node_feature:
+            previous = self._previous_node_dense(previous)
+        updated_node_feature, _ = self.update_fn(
+            inputs=aggregate, states=previous
+        )
         return tensor.update(
             {
                 'node': {
-                    'feature': node_feature,
+                    'feature': updated_node_feature,
                     'aggregate': None,
                 }
             }
         )
-
+
     def get_config(self) -> dict:
         config = super().get_config()
-        config.update({
-            "heads": self._heads,
-            'update_edge_feature': self._update_edge_feature,
-            'attention_activation': keras.activations.serialize(self._attention_activation),
-        })
-        return config
+        config.update({})
+        return config
 
 
 @keras.saving.register_keras_serializable(package='molcraft')
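The new `MPConv.update()` drives a `keras.layers.GRUCell`, with the aggregated message as input and the (possibly projected) previous node feature as the recurrent state. A standalone sketch of that mechanic using plain Keras (shapes are illustrative):

```python
import keras
import numpy as np

units = 4
cell = keras.layers.GRUCell(units)

aggregate = np.random.rand(2, units).astype('float32')  # aggregated messages
previous = np.random.rand(2, units).astype('float32')   # previous node features

# GRUCell returns (output, new_states); MPConv.update() keeps only the output
# as the updated node feature.
updated, _ = cell(inputs=aggregate, states=previous)
print(updated.shape)  # (2, 4)
```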
@@ -839,6 +911,8 @@ class GTConv(GraphConv):
 
     """Graph transformer layer.
 
+    Also supports 3D molecular graphs.
+
     >>> graph = molcraft.tensors.GraphTensor(
     ...     context={
     ...         'size': [2]
@@ -862,10 +936,9 @@ class GTConv(GraphConv):
      },
      edge={
        'source': <tf.Tensor: shape=[2], dtype=int32>,
-        'target': <tf.Tensor: shape=[2], dtype=int32>
+        'target': <tf.Tensor: shape=[2], dtype=int32>,
      }
    )
-
    """
 
    def __init__(
@@ -875,14 +948,16 @@ class GTConv(GraphConv):
         activation: keras.layers.Activation | str | None = "relu",
         use_bias: bool = True,
         normalize: bool = False,
+        skip_connect: bool = True,
         attention_dropout: float = 0.0,
         **kwargs,
     ) -> None:
         super().__init__(
             units=units,
             activation=activation,
-            use_bias=use_bias,
             normalize=normalize,
+            use_bias=use_bias,
+            skip_connect=skip_connect,
             **kwargs
         )
         self._heads = heads
@@ -900,6 +975,8 @@ class GTConv(GraphConv):
         return self._head_units
 
     def build(self, spec: tensors.GraphTensor.Spec) -> None:
+        """Builds the layer.
+        """
         self._query_dense = self.get_einsum_dense(
             'ij,jkh->ikh', (self.head_units, self.heads)
         )
@@ -912,29 +989,36 @@ class GTConv(GraphConv):
         self._output_dense = self.get_dense(self.units)
         self._softmax_dropout = keras.layers.Dropout(self._attention_dropout)
 
-        self._add_bias = not 'bias' in spec.edge
-
-        if self._add_bias:
-            self._edge_bias = EdgeBias(biases=self.heads)
+        if self.has_edge_feature:
+            self._attention_bias_dense_1 = self.get_einsum_dense('ij,jkh->ikh', (1, self.heads))
 
-        has_overridden_update = self.__class__.update != GTConv.update
-        if not has_overridden_update:
-            self._feedforward_intermediate_dense = self.get_dense(self.units)
-            self._feedforward_activation = self._activation
-            self._feedforward_output_dense = self.get_dense(self.units)
+        if self.has_node_coordinate:
+            node_feature_dim = spec.node['feature'].shape[-1]
+            num_kernels = self.units
+            self._gaussian_loc = self.add_weight(
+                shape=[num_kernels], initializer='zeros', dtype='float32', trainable=True
+            )
+            self._gaussian_scale = self.add_weight(
+                shape=[num_kernels], initializer='ones', dtype='float32', trainable=True
+            )
+            self._centrality_dense = self.get_dense(units=node_feature_dim)
+            self._attention_bias_dense_2 = self.get_einsum_dense('ij,jkh->ikh', (1, self.heads))
 
     def message(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
-        if self._add_bias:
-            edge_bias = self._edge_bias(tensor)
-            tensor = tensor.update(
-                {
-                    'edge': {
-                        'bias': edge_bias
-                    }
-                }
-            )
         node_feature = tensor.node['feature']
-
+
+        if self.has_node_coordinate:
+            euclidean_distance = ops.euclidean_distance(
+                tensor.gather('coordinate', 'target'),
+                tensor.gather('coordinate', 'source'),
+                axis=-1
+            )
+            gaussian = ops.gaussian(
+                euclidean_distance, self._gaussian_loc, self._gaussian_scale
+            )
+            centrality = keras.ops.segment_sum(gaussian, tensor.edge['target'], tensor.num_nodes)
+            node_feature += self._centrality_dense(centrality)
+
         query = self._query_dense(node_feature)
         key = self._key_dense(node_feature)
         value = self._value_dense(node_feature)
@@ -946,23 +1030,45 @@ class GTConv(GraphConv):
         attention_score = keras.ops.sum(query * key, axis=1, keepdims=True)
         attention_score /= keras.ops.sqrt(float(self.head_units))
 
-        attention_score += keras.ops.expand_dims(tensor.edge['bias'], axis=1)
+        if self.has_edge_feature:
+            attention_score += self._attention_bias_dense_1(tensor.edge['feature'])
+
+        if self.has_node_coordinate:
+            attention_score += self._attention_bias_dense_2(gaussian)
+
         attention = ops.edge_softmax(attention_score, tensor.edge['target'])
         attention = self._softmax_dropout(attention)
-        message = ops.edge_weight(value, attention)
 
+        if self.has_node_coordinate:
+            displacement = ops.displacement(
+                tensor.gather('coordinate', 'target'),
+                tensor.gather('coordinate', 'source'),
+                normalize=True
+            )
+            attention *= keras.ops.expand_dims(displacement, axis=-1)
+            attention = keras.ops.expand_dims(attention, axis=2)
+            value = keras.ops.expand_dims(value, axis=1)
+
+        message = ops.edge_weight(value, attention)
+
         return tensor.update(
             {
                 'edge': {
-                    'message': message
+                    'message': message,
                 },
             }
         )
 
     def aggregate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
         node_feature = tensor.aggregate('message', mode='sum')
-        node_feature = keras.ops.reshape(node_feature, (-1, self.units))
+        if self.has_node_coordinate:
+            shape = (tensor.num_nodes, -1, self.units)
+        else:
+            shape = (tensor.num_nodes, self.units)
+        node_feature = keras.ops.reshape(node_feature, shape)
         node_feature = self._output_dense(node_feature)
+        if self.has_node_coordinate:
+            node_feature = keras.ops.sum(node_feature, axis=1)
         return tensor.update(
             {
                 'node': {
@@ -973,20 +1079,6 @@ class GTConv(GraphConv):
                 }
             }
         )
-
-    def update(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
-        node_feature = tensor.node['aggregate']
-        node_feature = self._feedforward_intermediate_dense(node_feature)
-        node_feature = self._feedforward_activation(node_feature)
-        node_feature = self._feedforward_output_dense(node_feature)
-        return tensor.update(
-            {
-                'node': {
-                    'feature': node_feature,
-                    'aggregate': None,
-                },
-            }
-        )
 
     def get_config(self) -> dict:
         config = super().get_config()
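With these changes a single `GTConv` covers both 2D and 3D graphs: the Gaussian-distance weights, centrality term, and second attention bias are only created when the node spec contains `coordinate`. A usage sketch in the style of this file's doctests (coordinate values are made up):

```python
import molcraft

graph = molcraft.tensors.GraphTensor(
    context={'size': [2]},
    node={
        'feature': [[1.], [2.]],
        'coordinate': [[0.1, -0.1, 0.5], [1.2, -0.5, 2.1]],
    },
    edge={'source': [0, 1], 'target': [1, 0]},
)
# Same class for both cases; the 3D path (gaussian bias, displacement-weighted
# attention) activates because 'coordinate' is present in the node spec.
conv = molcraft.layers.GTConv(units=4, heads=2)
out = conv(graph)
```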
@@ -998,17 +1090,49 @@
 
 
 @keras.saving.register_keras_serializable(package='molcraft')
-class MPConv(GraphConv):
+class EGConv(GraphConv):
 
-    """Message passing neural network layer.
+    """Equivariant graph neural network layer 3D.
+
+    Only supports 3D molecular graphs.
+
+    >>> graph = molcraft.tensors.GraphTensor(
+    ...     context={
+    ...         'size': [2]
+    ...     },
+    ...     node={
+    ...         'feature': [[1.], [2.]],
+    ...         'coordinate': [[0.1, -0.1, 0.5], [1.2, -0.5, 2.1]],
+    ...     },
+    ...     edge={
+    ...         'source': [0, 1],
+    ...         'target': [1, 0],
+    ...     }
+    ... )
+    >>> conv = molcraft.layers.EGConv(units=4)
+    >>> conv(graph)
+    GraphTensor(
+      context={
+        'size': <tf.Tensor: shape=[1], dtype=int32>
+      },
+      node={
+        'feature': <tf.Tensor: shape=[2, 4], dtype=float32>,
+        'coordinate': <tf.Tensor: shape=[2, 3], dtype=float32>
+      },
+      edge={
+        'source': <tf.Tensor: shape=[2], dtype=int32>,
+        'target': <tf.Tensor: shape=[2], dtype=int32>
+      }
+    )
     """
 
     def __init__(
         self,
         units: int = 128,
-        activation: keras.layers.Activation | str | None = None,
+        activation: keras.layers.Activation | str | None = 'silu',
         use_bias: bool = True,
         normalize: bool = False,
+        skip_connect: bool = True,
         **kwargs
     ) -> None:
         super().__init__(
@@ -1016,262 +1140,30 @@ class MPConv(GraphConv):
             activation=activation,
             use_bias=use_bias,
             normalize=normalize,
+            skip_connect=skip_connect,
             **kwargs
         )
 
     def build(self, spec: tensors.GraphTensor.Spec) -> None:
-        node_feature_dim = spec.node['feature'].shape[-1]
-        self.message_fn = self.get_dense(self.units, activation=self._activation)
-        self.update_fn = keras.layers.GRUCell(self.units)
-        self._has_edge_feature = 'feature' in spec.edge
-        self.project_input_node_feature = node_feature_dim != self.units
-        if self.project_input_node_feature:
-            warn(
-                'Input node feature dim does not match updated node feature dim. '
-                'To make sure input node feature can be passed as `states` to the '
-                'GRU cell, it will automatically be projected prior to it.'
-            )
-            self._previous_node_dense = self.get_dense(
-                self.units, activation=self._activation
+        if not self.has_node_coordinate:
+            raise ValueError(
+                'Could not find `coordinate`s in node, '
+                'which is required for Conv3D layers.'
             )
-
-    def message(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
-        feature = keras.ops.concatenate(
-            [
-                tensor.gather('feature', 'source'),
-                tensor.gather('feature', 'target'),
-            ],
-            axis=-1
-        )
-        if self._has_edge_feature:
-            feature = keras.ops.concatenate(
-                [
-                    feature,
-                    tensor.edge['feature']
-                ],
-                axis=-1
-            )
-        message = self.message_fn(feature)
-        return tensor.update(
-            {
-                'edge': {
-                    'message': message,
-                }
-            }
-        )
-
-    def aggregate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
-        aggregate = tensor.aggregate('message', mode='mean')
-        feature = tensor.node['feature']
-        if self.project_input_node_feature:
-            feature = self._previous_node_dense(feature)
-        return tensor.update(
-            {
-                'node': {
-                    'aggregate': aggregate,
-                    'feature': feature,
-                }
-            }
-        )
-
-    def update(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
-        updated_node_feature, _ = self.update_fn(
-            inputs=tensor.node['aggregate'],
-            states=tensor.node['feature']
-        )
-        return tensor.update(
-            {
-                'node': {
-                    'feature': updated_node_feature,
-                    'aggregate': None,
-                }
-            }
-        )
-
-    def get_config(self) -> dict:
-        config = super().get_config()
-        config.update({})
-        return config
-
-
-@keras.saving.register_keras_serializable(package='molcraft')
-class GTConv3D(GTConv):
-
-    """Graph transformer layer 3D.
-    """
-
-    def build(self, spec: tensors.GraphTensor.Spec) -> None:
-        """Builds the layer.
-        """
-        super().build(spec)
-        if self._add_bias:
-            node_feature_dim = spec.node['feature'].shape[-1]
-            kernels = self.units
-            self._gaussian_basis = GaussianDistance(kernels)
-            self._centrality_dense = self.get_dense(units=node_feature_dim)
-            self._gaussian_edge_bias = self.get_dense(self.heads)
-
-    def message(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
-        node_feature = tensor.node['feature']
-
-        if self._add_bias:
-            gaussian = self._gaussian_basis(tensor)
-            centrality = keras.ops.segment_sum(
-                gaussian, tensor.edge['target'], tensor.num_nodes
-            )
-            node_feature += self._centrality_dense(centrality)
-
-            edge_bias = self._edge_bias(tensor) + self._gaussian_edge_bias(gaussian)
-            tensor = tensor.update({'edge': {'bias': edge_bias}})
-
-        query = self._query_dense(node_feature)
-        key = self._key_dense(node_feature)
-        value = self._value_dense(node_feature)
-
-        query = ops.gather(query, tensor.edge['source'])
-        key = ops.gather(key, tensor.edge['target'])
-        value = ops.gather(value, tensor.edge['source'])
-
-        attention_score = keras.ops.sum(query * key, axis=1, keepdims=True)
-        attention_score /= keras.ops.sqrt(float(self.head_units))
-
-        attention_score += keras.ops.expand_dims(tensor.edge['bias'], axis=1)
-
-        attention = ops.edge_softmax(attention_score, tensor.edge['target'])
-        attention = self._softmax_dropout(attention)
-
-        distance = keras.ops.subtract(
-            tensor.gather('coordinate', 'source'),
-            tensor.gather('coordinate', 'target')
-        )
-        euclidean_distance = ops.euclidean_distance(
-            tensor.gather('coordinate', 'source'),
-            tensor.gather('coordinate', 'target'),
-            axis=-1
-        )
-        distance /= euclidean_distance
-
-        attention *= keras.ops.expand_dims(distance, axis=-1)
-        attention = keras.ops.expand_dims(attention, axis=2)
-        value = keras.ops.expand_dims(value, axis=1)
-
-        message = ops.edge_weight(value, attention)
-
-        return tensor.update(
-            {
-                'edge': {
-                    'message': message,
-                },
-            }
-        )
-
-    def aggregate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
-        node_feature = tensor.aggregate('message', mode='sum')
-        node_feature = keras.ops.reshape(
-            node_feature, (tensor.num_nodes, -1, self.units)
-        )
-        node_feature = self._output_dense(node_feature)
-        node_feature = keras.ops.sum(node_feature, axis=1)
-        return tensor.update(
-            {
-                'node': {
-                    'aggregate': node_feature,
-                },
-                'edge': {
-                    'message': None,
-                }
-            }
-        )
-
-
-@keras.saving.register_keras_serializable(package='molcraft')
-class MPConv3D(MPConv):
-
-    """Message passing neural network layer 3D.
-    """
-
-    def message(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
-        euclidean_distance = ops.euclidean_distance(
-            tensor.gather('coordinate', 'target'),
-            tensor.gather('coordinate', 'source'),
-            axis=-1
-        )
-        feature = keras.ops.concatenate(
-            [
-                tensor.gather('feature', 'source'),
-                tensor.gather('feature', 'target'),
-                euclidean_distance,
-            ],
-            axis=-1
-        )
-        if self._has_edge_feature:
-            feature = keras.ops.concatenate(
-                [
-                    feature,
-                    tensor.edge['feature']
-                ],
-                axis=-1
-            )
-        message = self.message_fn(feature)
-        return tensor.update(
-            {
-                'edge': {
-                    'message': message,
-                }
-            }
-        )
-
-
-@keras.saving.register_keras_serializable(package='molcraft')
-class EGConv3D(GraphConv):
-
-    """Equivariant graph neural network layer 3D.
-    """
-
-    def __init__(
-        self,
-        units: int = 128,
-        activation: keras.layers.Activation | str | None = 'silu',
-        use_bias: bool = True,
-        normalize: bool = False,
-        **kwargs
-    ) -> None:
-        super().__init__(
-            units=units,
-            activation=activation,
-            use_bias=use_bias,
-            normalize=normalize,
-            **kwargs
-        )
-
-    def build(self, spec: tensors.GraphTensor.Spec) -> None:
-        if 'coordinate' not in spec.node:
-            raise ValueError(
-                'Could not find `coordinate`s in node, '
-                'which is required for Conv3D layers.'
-            )
-        self._has_edge_feature = 'feature' in spec.edge
-        self._message_feedforward_intermediate = self.get_dense(
-            self.units, activation=self._activation
-        )
-        self._message_feedforward_final = self.get_dense(
-            self.units, activation=self._activation
-        )
+        self._message_feedforward_intermediate = self.get_dense(
+            self.units, activation=self.activation
+        )
+        self._message_feedforward_final = self.get_dense(
+            self.units, activation=self.activation
+        )
 
         self._coord_feedforward_intermediate = self.get_dense(
-            self.units, activation=self._activation
+            self.units, activation=self.activation
         )
         self._coord_feedforward_final = self.get_dense(
             1, use_bias=False, activation='tanh'
         )
 
-        has_overridden_update = self.__class__.update != EGConv3D.update
-        if not has_overridden_update:
-            self._feedforward_intermediate = self.get_dense(
-                self.units, activation=self._activation
-            )
-            self._feedforward_output = self.get_dense(self.units)
-
     def message(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
         relative_node_coordinate = keras.ops.subtract(
             tensor.gather('coordinate', 'target'),
@@ -1300,7 +1192,7 @@ class EGConv3D(GraphConv):
             ],
             axis=-1
         )
-        if self._has_edge_feature:
+        if self.has_edge_feature:
             feature = keras.ops.concatenate(
                 [
                     feature,
@@ -1339,7 +1231,7 @@ class EGConv3D(GraphConv):
         # graph to graph). Therefore, a mean aggregation is performed
         # instead:
         aggregate = tensor.aggregate('message', mode='mean')
-
+        aggregate = keras.ops.concatenate([aggregate, tensor.node['feature']], axis=-1)
         # Simply added to silence warning ('no gradients for variables ...')
         aggregate += (0.0 * keras.ops.sum(coordinate))
 
@@ -1355,26 +1247,6 @@
                 }
             }
         )
-
-    def update(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
-        feature = keras.ops.concatenate(
-            [
-                tensor.node['aggregate'],
-                tensor.node['feature']
-            ],
-            axis=-1
-        )
-        updated_node_feature = self._feedforward_output(
-            self._feedforward_intermediate(feature)
-        )
-        return tensor.update(
-            {
-                'node': {
-                    'feature': updated_node_feature,
-                    'aggregate': None,
-                },
-            }
-        )
 
     def get_config(self) -> dict:
         config = super().get_config()
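`EGConv` (previously `EGConv3D`) still refuses to build without node coordinates, per the `build()` check above. A sketch of the failure mode on a coordinate-free graph:

```python
import molcraft

graph_2d = molcraft.tensors.GraphTensor(
    context={'size': [2]},
    node={'feature': [[1.], [2.]]},            # no 'coordinate' entry
    edge={'source': [0, 1], 'target': [1, 0]},
)
conv = molcraft.layers.EGConv(units=4)
try:
    conv(graph_2d)
except ValueError as err:
    print(err)  # "Could not find `coordinate`s in node, ..."
```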
@@ -1417,146 +1289,6 @@ class Readout(GraphLayer):
         config['mode'] = self.mode
         return config
 
-
-@keras.saving.register_keras_serializable(package='molcraft')
-class GraphNetwork(GraphLayer):
-
-    """Graph neural network.
-
-    Sequentially calls graph layers (`GraphLayer`) and concatenates its output.
-
-    Arguments:
-        layers (list):
-            A list of graph layers.
-    """
-
-    def __init__(self, layers: list[GraphLayer], **kwargs) -> None:
-        super().__init__(**kwargs)
-        self.layers = layers
-        self._update_edge_feature = False
-
-    def build(self, spec: tensors.GraphTensor.Spec) -> None:
-        units = self.layers[0].units
-        node_feature_dim = spec.node['feature'].shape[-1]
-        self._update_node_feature = node_feature_dim != units
-        if self._update_node_feature:
-            warn(
-                'Node feature dim does not match `units` of the first layer. '
-                'Automatically adding a node projection layer to match `units`.'
-            )
-            self._node_dense = self.get_dense(units)
-        self._has_edge_feature = 'feature' in spec.edge
-        if self._has_edge_feature:
-            edge_feature_dim = spec.edge['feature'].shape[-1]
-            self._update_edge_feature = edge_feature_dim != units
-            if self._update_edge_feature:
-                warn(
-                    'Edge feature dim does not match `units` of the first layer. '
-                    'Automatically adding a edge projection layer to match `units`.'
-                )
-                self._edge_dense = self.get_dense(units)
-
-    def propagate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
-        x = tensors.to_dict(tensor)
-        if self._update_node_feature:
-            x['node']['feature'] = self._node_dense(tensor.node['feature'])
-        if self._has_edge_feature and self._update_edge_feature:
-            x['edge']['feature'] = self._edge_dense(tensor.edge['feature'])
-        outputs = [x['node']['feature']]
-        for layer in self.layers:
-            x = layer(x)
-            outputs.append(x['node']['feature'])
-        return tensor.update(
-            {
-                'node': {
-                    'feature': keras.ops.concatenate(outputs, axis=-1)
-                }
-            }
-        )
-
-    def tape_propagate(
-        self,
-        tensor: tensors.GraphTensor,
-        tape: tf.GradientTape,
-        training: bool | None = None,
-    ) -> tuple[tensors.GraphTensor, list[tf.Tensor]]:
-        """Performs the propagation with a `GradientTape`.
-
-        Performs the same forward pass as `propagate` but with a `GradientTape`
-        watching intermediate node features.
-
-        Arguments:
-            tensor (tensors.GraphTensor):
-                The graph input.
-        """
-        if isinstance(tensor, tensors.GraphTensor):
-            x = tensors.to_dict(tensor)
-        else:
-            x = tensor
-        if self._update_node_feature:
-            x['node']['feature'] = self._node_dense(tensor.node['feature'])
-        if self._update_edge_feature:
-            x['edge']['feature'] = self._edge_dense(tensor.edge['feature'])
-        tape.watch(x['node']['feature'])
-        outputs = [x['node']['feature']]
-        for layer in self.layers:
-            x = layer(x, training=training)
-            tape.watch(x['node']['feature'])
-            outputs.append(x['node']['feature'])
-
-        tensor = tensor.update(
-            {
-                'node': {
-                    'feature': keras.ops.concatenate(outputs, axis=-1)
-                }
-            }
-        )
-        return tensor, outputs
-
-    def get_config(self) -> dict:
-        config = super().get_config()
-        config.update(
-            {
-                'layers': [
-                    keras.layers.serialize(layer) for layer in self.layers
-                ]
-            }
-        )
-        return config
-
-    @classmethod
-    def from_config(cls, config: dict) -> 'GraphNetwork':
-        config['layers'] = [
-            keras.layers.deserialize(layer) for layer in config['layers']
-        ]
-        return super().from_config(config)
-
-
-@keras.saving.register_keras_serializable(package='molcraft')
-class Extraction(GraphLayer):
-
-    def __init__(
-        self,
-        field: str,
-        inner_field: str | None = None,
-        **kwargs
-    ) -> None:
-        super().__init__(**kwargs)
-        self.field = field
-        self.inner_field = inner_field
-
-    def propagate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
-        data = dict(getattr(tensor, self.field))
-        if not self.inner_field:
-            return data
-        return data[self.inner_field]
-
-    def get_config(self):
-        config = super().get_config()
-        config['field'] = self.field
-        config['inner_field'] = self.inner_field
-        return config
-
 
 @keras.saving.register_keras_serializable(package='molcraft')
 class NodeEmbedding(GraphLayer):
@@ -1571,17 +1303,12 @@ class NodeEmbedding(GraphLayer):
         dim: int = None,
         normalize: bool = False,
         embed_context: bool = False,
-        allow_reconstruction: bool = False,
-        allow_masking: bool = False,
         **kwargs
     ) -> None:
         super().__init__(**kwargs)
         self.dim = dim
         self._normalize = normalize
         self._embed_context = embed_context
-        self._masking_rate = None
-        self._allow_masking = allow_masking
-        self._allow_reconstruction = allow_reconstruction
 
     def build(self, spec: tensors.GraphTensor.Spec) -> None:
         feature_dim = spec.node['feature'].shape[-1]
@@ -1595,47 +1322,31 @@ class NodeEmbedding(GraphLayer):
             self._embed_context = False
         if self._has_super and not self._embed_context:
             self._super_feature = self.get_weight(shape=[self.dim], name='super_node_feature')
-        if self._allow_masking:
-            self._mask_feature = self.get_weight(shape=[self.dim], name='mask_node_feature')
         if self._embed_context:
             self._context_dense = self.get_dense(self.dim)
 
-        if self._normalize:
-            if str(self._normalize).lower().startswith('batch'):
-                self._norm = keras.layers.BatchNormalization(
-                    name='output_batch_norm'
-                )
-            else:
-                self._norm = keras.layers.LayerNormalization(
-                    name='output_layer_norm'
-                )
+        if not self._normalize:
+            self._norm = keras.layers.Identity()
+        elif str(self._normalize).lower().startswith('layer'):
+            self._norm = keras.layers.LayerNormalization()
+        else:
+            self._norm = keras.layers.BatchNormalization()
 
     def propagate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
         feature = self._node_dense(tensor.node['feature'])
 
-        if self._has_super:
-            super_feature = (0 if self._embed_context else self._super_feature)
+        if self._has_super and not self._embed_context:
             super_mask = keras.ops.expand_dims(tensor.node['super'], 1)
-            feature = keras.ops.where(super_mask, super_feature, feature)
+            feature = keras.ops.where(super_mask, self._super_feature, feature)
 
         if self._embed_context:
             context_feature = self._context_dense(tensor.context['feature'])
             feature = ops.scatter_update(feature, tensor.node['super'], context_feature)
             tensor = tensor.update({'context': {'feature': None}})
 
-        apply_mask = (self._allow_masking and 'mask' in tensor.node)
-        if apply_mask:
-            mask = keras.ops.expand_dims(tensor.node['mask'], -1)
-            feature = keras.ops.where(mask, self._mask_feature, feature)
-        elif self._allow_masking:
-            feature = feature + (self._mask_feature * 0.0)
+        feature = self._norm(feature)
 
-        if self._normalize:
-            feature = self._norm(feature)
-
-        if not self._allow_reconstruction:
-            return tensor.update({'node': {'feature': feature}})
-        return tensor.update({'node': {'feature': feature, 'target_feature': feature}})
+        return tensor.update({'node': {'feature': feature}})
 
     def get_config(self) -> dict:
         config = super().get_config()
@@ -1643,8 +1354,6 @@ class NodeEmbedding(GraphLayer):
             'dim': self.dim,
             'normalize': self._normalize,
             'embed_context': self._embed_context,
-            'allow_masking': self._allow_masking,
-            'allow_reconstruction': self._allow_reconstruction,
         })
         return config
 
@@ -1661,39 +1370,30 @@ class EdgeEmbedding(GraphLayer):
         self,
         dim: int = None,
         normalize: bool = False,
-        allow_masking: bool = True,
         **kwargs
     ) -> None:
         super().__init__(**kwargs)
         self.dim = dim
         self._normalize = normalize
-        self._masking_rate = None
-        self._allow_masking = allow_masking
 
     def build(self, spec: tensors.GraphTensor.Spec) -> None:
         feature_dim = spec.edge['feature'].shape[-1]
         if not self.dim:
             self.dim = feature_dim
-        self._edge_dense = self.get_dense(self.dim)
+        self._edge_dense = self.get_dense(self.dim)
+
+        self._self_loop_feature = self.get_weight(shape=[self.dim], name='self_loop_edge_feature')
 
         self._has_super = 'super' in spec.edge
-        self._has_self_loop = 'self_loop' in spec.edge
         if self._has_super:
             self._super_feature = self.get_weight(shape=[self.dim], name='super_edge_feature')
-        if self._has_self_loop:
-            self._self_loop_feature = self.get_weight(shape=[self.dim], name='self_loop_edge_feature')
-        if self._allow_masking:
-            self._mask_feature = self.get_weight(shape=[self.dim], name='mask_edge_feature')
-
-        if self._normalize:
-            if str(self._normalize).lower().startswith('batch'):
-                self._norm = keras.layers.BatchNormalization(
-                    name='output_batch_norm'
-                )
-            else:
-                self._norm = keras.layers.LayerNormalization(
-                    name='output_layer_norm'
-                )
+
+        if not self._normalize:
+            self._norm = keras.layers.Identity()
+        elif str(self._normalize).lower().startswith('layer'):
+            self._norm = keras.layers.LayerNormalization()
+        else:
+            self._norm = keras.layers.BatchNormalization()
 
     def propagate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
         feature = self._edge_dense(tensor.edge['feature'])
@@ -1702,246 +1402,136 @@ class EdgeEmbedding(GraphLayer):
             super_mask = keras.ops.expand_dims(tensor.edge['super'], 1)
             feature = keras.ops.where(super_mask, self._super_feature, feature)
 
-        if self._has_self_loop:
-            self_loop_mask = keras.ops.expand_dims(tensor.edge['self_loop'], 1)
-            feature = keras.ops.where(self_loop_mask, self._self_loop_feature, feature)
-
-        if (
-            self._allow_masking and
-            self._masking_rate is not None and
-            self._masking_rate > 0
-        ):
-            random = keras.random.uniform(shape=[tensor.num_edges])
-            mask = random <= self._masking_rate
-            if self._has_super:
-                mask = keras.ops.logical_and(
-                    mask, keras.ops.logical_not(tensor.edge['super'])
-                )
-            mask = keras.ops.expand_dims(mask, -1)
-            feature = keras.ops.where(mask, self._mask_feature, feature)
-        elif self._allow_masking:
-            # Simply added to silence warning ('no gradients for variables ...')
-            feature += (0.0 * self._mask_feature)
+        self_loop_mask = keras.ops.expand_dims(tensor.edge['source'] == tensor.edge['target'], 1)
+        feature = keras.ops.where(self_loop_mask, self._self_loop_feature, feature)
 
-        if self._normalize:
-            feature = self._norm(feature)
+        feature = self._norm(feature)
 
-        return tensor.update({'edge': {'feature': feature, 'embedding': feature}})
-
-    @property
-    def masking_rate(self):
-        return self._masking_rate
-
-    @masking_rate.setter
-    def masking_rate(self, rate: float):
-        if not self._allow_masking and rate is not None:
-            raise ValueError(
-                f'Cannot set `masking_rate` for layer {self} '
-                'as `allow_masking` was set to `False`.'
-            )
-        self._masking_rate = float(rate)
+        return tensor.update({'edge': {'feature': feature}})
 
     def get_config(self) -> dict:
         config = super().get_config()
         config.update({
             'dim': self.dim,
             'normalize': self._normalize,
-            'allow_masking': self._allow_masking
         })
         return config
 
 
 @keras.saving.register_keras_serializable(package='molcraft')
-class Projection(GraphLayer):
-    """Base graph projection layer.
-    """
-    def __init__(
-        self,
-        units: int = None,
-        activation: str | keras.layers.Activation | None = None,
-        use_bias: bool = True,
-        field: str = 'node',
-        **kwargs
-    ) -> None:
-        super().__init__(use_bias=use_bias, **kwargs)
-        self.units = units
-        self._activation = keras.activations.get(activation)
-        self.field = field
-
-    def build(self, spec: tensors.GraphTensor.Spec) -> None:
-        data = getattr(spec, self.field, None)
-        if data is None:
-            raise ValueError('Could not access field {self.field!r}.')
-        feature_dim = data['feature'].shape[-1]
-        if not self.units:
-            self.units = feature_dim
-        self._dense = self.get_dense(self.units)
-
-    def propagate(self, tensor: tensors.GraphTensor):
-        feature = getattr(tensor, self.field)['feature']
-        feature = self._dense(feature)
-        feature = self._activation(feature)
-        return tensor.update(
-            {
-                self.field: {
-                    'feature': feature
-                }
-            }
-        )
-
-    def get_config(self) -> dict:
-        config = super().get_config()
-        config.update({
-            'units': self.units,
-            'activation': keras.activations.serialize(self._activation),
-            'field': self.field,
-        })
-        return config
-
-
-@keras.saving.register_keras_serializable(package='molcraft')
-class ContextProjection(Projection):
-    """Context projection layer.
-    """
-    def __init__(self, units: int = None, activation: str = None, **kwargs):
-        kwargs['field'] = 'context'
-        super().__init__(units=units, activation=activation, **kwargs)
-
+class GraphNetwork(GraphLayer):
 
-@keras.saving.register_keras_serializable(package='molcraft')
-class NodeProjection(Projection):
-    """Node projection layer.
-    """
-    def __init__(self, units: int = None, activation: str = None, **kwargs):
-        kwargs['field'] = 'node'
-        super().__init__(units=units, activation=activation, **kwargs)
+    """Graph neural network.
 
+    Sequentially calls graph layers (`GraphLayer`) and concatenates its output.
 
-@keras.saving.register_keras_serializable(package='molcraft')
-class EdgeProjection(Projection):
-    """Edge projection layer.
+    Arguments:
+        layers (list):
+            A list of graph layers.
     """
-    def __init__(self, units: int = None, activation: str = None, **kwargs):
-        kwargs['field'] = 'edge'
-        super().__init__(units=units, activation=activation, **kwargs)
-
-
-@keras.saving.register_keras_serializable(package='molcraft')
-class Reconstruction(GraphLayer):
-
-    def __init__(
-        self,
-        loss: keras.losses.Loss | str = 'mse',
-        loss_weight: float = 0.5,
-        **kwargs
-    ):
-        super().__init__(**kwargs)
-        self._loss_fn = keras.losses.get(loss)
-        self._loss_weight = loss_weight
-
-    def build(self, spec: tensors.GraphTensor.Spec) -> None:
-        has_target_node_feature = 'target_feature' in spec.node
-        if not has_target_node_feature:
-            raise ValueError(
-                'Could not find `target_feature` in `spec.node`. '
-                'Add a `target_feature` via `NodeEmbedding` by setting '
-                '`allow_reconstruction` to `True`.'
-            )
-        output_dim = spec.node['target_feature'].shape[-1]
-        self._dense = self.get_dense(output_dim)
-
-    def propagate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
-        target_node_feature = tensor.node['target_feature']
-        transformed_node_feature = tensor.node['feature']
-
-        reconstructed_node_feature = self._dense(
-            transformed_node_feature
-        )
-
-        loss = self._loss_fn(
-            target_node_feature, reconstructed_node_feature
-        )
-        self.add_loss(keras.ops.sum(loss) * self._loss_weight)
-        return tensor.update({'node': {'feature': transformed_node_feature}})
-
-    def get_config(self):
-        config = super().get_config()
-        config['loss'] = keras.losses.serialize(self._loss_fn)
-        config['loss_weight'] = self._loss_weight
-        return config
-
-
-@keras.saving.register_keras_serializable(package='molcraft')
-class EdgeBias(GraphLayer):
 
-    def __init__(self, biases: int, **kwargs):
+    def __init__(self, layers: list[GraphLayer], **kwargs) -> None:
         super().__init__(**kwargs)
-        self.biases = biases
+        self.layers = layers
+        self._update_edge_feature = False
 
     def build(self, spec: tensors.GraphTensor.Spec) -> None:
-        self._has_edge_length = 'length' in spec.edge
-        self._has_edge_feature = 'feature' in spec.edge
-        if self._has_edge_feature:
-            self._edge_feature_dense = self.get_dense(self.biases)
-        if self._has_edge_length:
-            self._edge_length_dense = self.get_dense(
-                self.biases, kernel_initializer='zeros'
+        units = self.layers[0].units
+        node_feature_dim = spec.node['feature'].shape[-1]
+        self._update_node_feature = node_feature_dim != units
+        if self._update_node_feature:
+            warnings.warn(
+                'Node feature dim does not match `units` of the first layer. '
+                'Automatically adding a node projection layer to match `units`.',
+                stacklevel=2
             )
+            self._node_dense = self.get_dense(units)
+        self._has_edge_feature = 'feature' in spec.edge
+        if self._has_edge_feature:
+            edge_feature_dim = spec.edge['feature'].shape[-1]
+            self._update_edge_feature = edge_feature_dim != units
+            if self._update_edge_feature:
+                warnings.warn(
+                    'Edge feature dim does not match `units` of the first layer. '
+                    'Automatically adding an edge projection layer to match `units`.',
+                    stacklevel=2
+                )
+                self._edge_dense = self.get_dense(units)
 
     def propagate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
-        bias = keras.ops.zeros(
-            shape=(tensor.num_edges, self.biases),
-            dtype=tensor.node['feature'].dtype
+        x = tensors.to_dict(tensor)
+        if self._update_node_feature:
+            x['node']['feature'] = self._node_dense(tensor.node['feature'])
+        if self._has_edge_feature and self._update_edge_feature:
+            x['edge']['feature'] = self._edge_dense(tensor.edge['feature'])
+        outputs = [x['node']['feature']]
+        for layer in self.layers:
+            x = layer(x)
+            outputs.append(x['node']['feature'])
+        return tensor.update(
+            {
+                'node': {
+                    'feature': keras.ops.concatenate(outputs, axis=-1)
+                }
+            }
         )
-        if self._has_edge_feature:
-            bias += self._edge_feature_dense(tensor.edge['feature'])
-        if self._has_edge_length:
-            bias += self._edge_length_dense(tensor.edge['length'])
-        return bias
-
-    def get_config(self) -> dict:
-        config = super().get_config()
-        config.update({'biases': self.biases})
-        return config
 
+    def tape_propagate(
+        self,
+        tensor: tensors.GraphTensor,
+        tape: tf.GradientTape,
+        training: bool | None = None,
+    ) -> tuple[tensors.GraphTensor, list[tf.Tensor]]:
+        """Performs the propagation with a `GradientTape`.
 
-@keras.saving.register_keras_serializable(package='molcraft')
-class GaussianDistance(GraphLayer):
-
-    def __init__(self, kernels: int, **kwargs):
-        super().__init__(**kwargs)
-        self.kernels = kernels
+        Performs the same forward pass as `propagate` but with a `GradientTape`
+        watching intermediate node features.
 
-    def build(self, spec: tensors.GraphTensor.Spec) -> None:
-        self._loc = self.add_weight(
-            shape=[self.kernels],
-            initializer='zeros',
-            dtype='float32',
-            trainable=True
-        )
-        self._scale = self.add_weight(
-            shape=[self.kernels],
-            initializer='ones',
-            dtype='float32',
-            trainable=True
-        )
+        Arguments:
+            tensor (tensors.GraphTensor):
+                The graph input.
+        """
+        if isinstance(tensor, tensors.GraphTensor):
+            x = tensors.to_dict(tensor)
+        else:
+            x = tensor
+        if self._update_node_feature:
+            x['node']['feature'] = self._node_dense(tensor.node['feature'])
+        if self._update_edge_feature:
+            x['edge']['feature'] = self._edge_dense(tensor.edge['feature'])
+        tape.watch(x['node']['feature'])
+        outputs = [x['node']['feature']]
+        for layer in self.layers:
+            x = layer(x, training=training)
+            tape.watch(x['node']['feature'])
+            outputs.append(x['node']['feature'])
 
-    def propagate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
-        euclidean_distance = ops.euclidean_distance(
-            tensor.gather('coordinate', 'source'),
-            tensor.gather('coordinate', 'target'),
-            axis=-1
-        )
-        return ops.gaussian(
-            euclidean_distance, self._loc, self._scale
+        tensor = tensor.update(
+            {
+                'node': {
+                    'feature': keras.ops.concatenate(outputs, axis=-1)
+                }
+            }
         )
-
+        return tensor, outputs
+
     def get_config(self) -> dict:
         config = super().get_config()
-        config.update({
-            'kernels': self.kernels,
-        })
+        config.update(
+            {
+                'layers': [
+                    keras.layers.serialize(layer) for layer in self.layers
+                ]
+            }
+        )
         return config
+
+    @classmethod
+    def from_config(cls, config: dict) -> 'GraphNetwork':
+        config['layers'] = [
+            keras.layers.deserialize(layer) for layer in config['layers']
+        ]
+        return super().from_config(config)
 
 
 @keras.saving.register_keras_serializable(package='molcraft')
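`GraphNetwork` moved within the module but its behavior is unchanged apart from the `stacklevel=2` warnings: it applies each layer in turn and concatenates every intermediate node feature (jumping-knowledge style). A usage sketch, assuming a `graph` built as in the doctests above:

```python
import molcraft

gnn = molcraft.layers.GraphNetwork(
    layers=[
        molcraft.layers.GIConv(units=128),
        molcraft.layers.GIConv(units=128),
    ]
)
# Output node feature dim is 128 * (len(layers) + 1): the (projected) input
# plus each layer's output, concatenated along the feature axis.
output = gnn(graph)
```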
@@ -1992,7 +1582,7 @@ class GaussianParams(keras.layers.Dense):
         config['units'] = None
         config['activation'] = keras.activations.serialize(self.loc_activation)
         return config
-
+
 
 def Input(spec: tensors.GraphTensor.Spec) -> dict:
     """Used to specify inputs to model.
@@ -2047,13 +1637,6 @@ def Input(spec: tensors.GraphTensor.Spec) -> dict:
     return inputs
 
 
-def warn(message: str) -> None:
-    warnings.warn(
-        message=message,
-        category=UserWarning,
-        stacklevel=1
-    )
-
 def _serialize_spec(spec: tensors.GraphTensor.Spec) -> dict:
     serialized_spec = {}
     for outer_field, data in spec.__dict__.items():
@@ -2095,5 +1678,3 @@ def _spec_from_inputs(inputs):
 
 
 GraphTransformer = GTConv
-GraphTransformer3D = GTConv3D
-
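The 3D-specific classes removed in this release were folded into their base classes, so code pinned to 0.1.0a7 needs only a rename on upgrade. The mapping below is inferred from this diff:

```python
import molcraft

# Inferred mapping (the 3D code paths now activate automatically when the
# node spec contains 'coordinate'):
#   GTConv3D           -> GTConv
#   MPConv3D           -> MPConv
#   EGConv3D           -> EGConv
#   GraphTransformer3D -> GraphTransformer (alias of GTConv)
conv = molcraft.layers.GTConv(units=128)  # where 0.1.0a7 code used GTConv3D
```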