molcraft 0.1.0a2__py3-none-any.whl → 0.1.0a3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of molcraft might be problematic. Click here for more details.

molcraft/__init__.py CHANGED
@@ -1,4 +1,4 @@
1
- __version__ = '0.1.0a2'
1
+ __version__ = '0.1.0a3'
2
2
 
3
3
  import os
4
4
  os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
@@ -14,3 +14,4 @@ from molcraft import ops
14
14
  from molcraft import records
15
15
  from molcraft import tensors
16
16
  from molcraft import callbacks
17
+ from molcraft import datasets
molcraft/features.py CHANGED
@@ -155,9 +155,11 @@ class Distance(EdgeFeature):
155
155
  encode_oov: bool = True,
156
156
  **kwargs,
157
157
  ) -> None:
158
- if max_distance is None:
159
- max_distance = 20
160
- vocab = list(range(max_distance + 1))
158
+ vocab = kwargs.pop('vocab', None)
159
+ if not vocab:
160
+ if max_distance is None:
161
+ max_distance = 20
162
+ vocab = list(range(max_distance + 1))
161
163
  super().__init__(
162
164
  vocab=vocab,
163
165
  allow_oov=allow_oov,
molcraft/featurizers.py CHANGED
@@ -402,6 +402,7 @@ class MolGraphFeaturizer(Featurizer):
402
402
  return cls(**config)
403
403
 
404
404
 
405
+ @keras.saving.register_keras_serializable(package='molcraft')
405
406
  class MolGraphFeaturizer3D(MolGraphFeaturizer):
406
407
 
407
408
  """Molecular 3d-graph featurizer.
@@ -627,7 +628,7 @@ class MolGraphFeaturizer3D(MolGraphFeaturizer):
627
628
  config['conformer_generator'] = keras.saving.deserialize_keras_object(
628
629
  config['conformer_generator']
629
630
  )
630
- return super().from_config(**config)
631
+ return super().from_config(config)
631
632
 
632
633
 
633
634
  def save_featurizer(
molcraft/layers.py CHANGED
@@ -60,25 +60,20 @@ class GraphLayer(keras.layers.Layer):
60
60
  May use built-in methods such as `get_weight`, `get_dense` and `get_einsum_dense`.
61
61
 
62
62
  Optionally implemented by subclass. If implemented, it is recommended to
63
- build the sub-layers via `build([None, input_dim])`. If sub-layers are not
64
- built, symbolic input will be passed through the layer to build it.
63
+ If sub-layers are built (via `build` or `build_from_spec`), set `built`
64
+ to True. If not, symbolic input will be passed through the layer to build them.
65
65
 
66
66
  Args:
67
67
  spec:
68
- A `GraphTensor.Spec` instance, corresponding to the input `GraphTensor`
69
- of the `propagate` method.
68
+ A `GraphTensor.Spec` instance, corresponding to the `GraphTensor`
69
+ passed to `propagate`.
70
70
  """
71
-
71
+
72
72
  def build(self, spec: tensors.GraphTensor.Spec) -> None:
73
73
 
74
74
  self._custom_build_config = {'spec': _serialize_spec(spec)}
75
75
 
76
- invoke_build_from_spec = (
77
- GraphLayer.build_from_spec != self.__class__.build_from_spec
78
- )
79
- if invoke_build_from_spec:
80
- self.build_from_spec(spec)
81
- self.built = True
76
+ self.build_from_spec(spec)
82
77
 
83
78
  if not self.built:
84
79
  # Automatically build layer or model by calling it on symbolic inputs
@@ -229,7 +224,7 @@ class GraphConv(GraphLayer):
229
224
 
230
225
  def __init__(
231
226
  self,
232
- units: int,
227
+ units: int = None,
233
228
  normalize: bool = False,
234
229
  skip_connection: bool = False,
235
230
  **kwargs
@@ -240,6 +235,10 @@ class GraphConv(GraphLayer):
240
235
  self._skip_connection = skip_connection
241
236
 
242
237
  def build(self, spec: tensors.GraphTensor.Spec) -> None:
238
+ if not self.units:
239
+ raise ValueError(
240
+ f'`self.units` needs to be a positive integer. ound: {self.units}.'
241
+ )
243
242
  node_feature_dim = spec.node['feature'].shape[-1]
244
243
  self._project_input_node_feature = (
245
244
  self._skip_connection and (node_feature_dim != self.units)
@@ -256,7 +255,7 @@ class GraphConv(GraphLayer):
256
255
  )
257
256
  if self._normalize_aggregate:
258
257
  self._aggregation_norm = keras.layers.LayerNormalization(
259
- name='aggregation_normalizer'
258
+ name='aggregation_normalization'
260
259
  )
261
260
  self._aggregation_norm.build([None, self.units])
262
261
 
@@ -344,7 +343,7 @@ class GraphConv(GraphLayer):
344
343
 
345
344
 
346
345
  @keras.saving.register_keras_serializable(package='molcraft')
347
- class GINConv(GraphConv):
346
+ class GIConv(GraphConv):
348
347
 
349
348
  """Graph isomorphism network layer.
350
349
  """
@@ -381,7 +380,8 @@ class GINConv(GraphConv):
381
380
  trainable=True,
382
381
  )
383
382
 
384
- if 'feature' in spec.edge:
383
+ self._has_edge_feature = 'feature' in spec.edge
384
+ if self._has_edge_feature:
385
385
  edge_feature_dim = spec.edge['feature'].shape[-1]
386
386
 
387
387
  if not self._update_edge_feature:
@@ -402,16 +402,15 @@ class GINConv(GraphConv):
402
402
  self._feedforward_intermediate_dense = self.get_dense(self.units)
403
403
  self._feedforward_intermediate_dense.build([None, node_feature_dim])
404
404
 
405
- has_overridden_update = self.__class__.update != GINConv.update
405
+ has_overridden_update = self.__class__.update != GIConv.update
406
406
  if not has_overridden_update:
407
- # Use default feedforward network
408
-
409
- self._feedforward_dropout = keras.layers.Dropout(self._dropout)
410
407
  self._feedforward_activation = self._activation
411
-
408
+ self._feedforward_dropout = keras.layers.Dropout(self._dropout)
412
409
  self._feedforward_output_dense = self.get_dense(self.units)
413
410
  self._feedforward_output_dense.build([None, self.units])
414
-
411
+
412
+ self.built = True
413
+
415
414
  def message(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
416
415
  """Computes messages.
417
416
  """
@@ -419,7 +418,7 @@ class GINConv(GraphConv):
419
418
  edge_feature = tensor.edge.get('feature')
420
419
  if self._update_edge_feature:
421
420
  edge_feature = self._edge_dense(edge_feature)
422
- if edge_feature is not None:
421
+ if self._has_edge_feature:
423
422
  message += edge_feature
424
423
  return tensor.update(
425
424
  {
@@ -436,7 +435,6 @@ class GINConv(GraphConv):
436
435
  node_feature = tensor.aggregate('message')
437
436
  node_feature += (1 + self.epsilon) * tensor.node['feature']
438
437
  node_feature = self._feedforward_intermediate_dense(node_feature)
439
- node_feature = self._feedforward_activation(node_feature)
440
438
  return tensor.update(
441
439
  {
442
440
  'node': {
@@ -452,6 +450,7 @@ class GINConv(GraphConv):
452
450
  """Updates nodes.
453
451
  """
454
452
  node_feature = tensor.node['feature']
453
+ node_feature = self._feedforward_activation(node_feature)
455
454
  node_feature = self._feedforward_dropout(node_feature)
456
455
  node_feature = self._feedforward_output_dense(node_feature)
457
456
  return tensor.update(
@@ -472,6 +471,171 @@ class GINConv(GraphConv):
472
471
  return config
473
472
 
474
473
 
474
+ @keras.saving.register_keras_serializable(package='molgraphx')
475
+ class GAConv(GraphConv):
476
+
477
+ """Graph attention network layer.
478
+ """
479
+
480
+ def __init__(
481
+ self,
482
+ units: int,
483
+ heads: int = 8,
484
+ activation: keras.layers.Activation | str | None = "relu",
485
+ use_bias: bool = True,
486
+ normalize: bool = True,
487
+ dropout: float = 0.0,
488
+ update_edge_feature: bool = True,
489
+ attention_activation: keras.layers.Activation | str | None = "leaky_relu",
490
+ **kwargs,
491
+ ) -> None:
492
+ kwargs['skip_connection'] = False
493
+ super().__init__(
494
+ units=units,
495
+ normalize=normalize,
496
+ use_bias=use_bias,
497
+ **kwargs
498
+ )
499
+ self._heads = heads
500
+ if self.units % self.heads != 0:
501
+ raise ValueError(f"units need to be divisible by heads.")
502
+ self._head_units = self.units // self.heads
503
+ self._activation = keras.activations.get(activation)
504
+ self._dropout = dropout
505
+ self._normalize = normalize
506
+ self._update_edge_feature = update_edge_feature
507
+ self._attention_activation = keras.activations.get(attention_activation)
508
+
509
+ @property
510
+ def heads(self):
511
+ return self._heads
512
+
513
+ @property
514
+ def head_units(self):
515
+ return self._head_units
516
+
517
+ def build_from_spec(self, spec: tensors.GraphTensor.Spec) -> None:
518
+
519
+ node_feature_dim = spec.node['feature'].shape[-1]
520
+ attn_feature_dim = node_feature_dim + node_feature_dim
521
+
522
+ self._has_edge_feature = 'feature' in spec.edge
523
+ if self._has_edge_feature:
524
+ edge_feature_dim = spec.edge['feature'].shape[-1]
525
+ attn_feature_dim += edge_feature_dim
526
+ if self._update_edge_feature:
527
+ self._edge_dense = self.get_einsum_dense(
528
+ 'ijh,jkh->ikh', (self.head_units, self.heads)
529
+ )
530
+ self._edge_dense.build([None, self.head_units, self.heads])
531
+ else:
532
+ self._update_edge_feature = False
533
+
534
+ self._node_dense = self.get_einsum_dense(
535
+ 'ij,jkh->ikh', (self.head_units, self.heads)
536
+ )
537
+ self._node_dense.build([None, node_feature_dim])
538
+
539
+ self._feature_dense = self.get_einsum_dense(
540
+ 'ij,jkh->ikh', (self.head_units, self.heads)
541
+ )
542
+ self._feature_dense.build([None, attn_feature_dim])
543
+
544
+ self._attention_dense = self.get_einsum_dense(
545
+ 'ijh,jkh->ikh', (1, self.heads)
546
+ )
547
+ self._attention_dense.build([None, self.head_units, self.heads])
548
+
549
+ self._node_self_dense = self.get_einsum_dense(
550
+ 'ij,jkh->ikh', (self.head_units, self.heads)
551
+ )
552
+ self._node_self_dense.build([None, node_feature_dim])
553
+ self._dropout_layer = keras.layers.Dropout(self._dropout)
554
+
555
+ self.built = True
556
+
557
+ def message(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
558
+
559
+ attention_feature = keras.ops.concatenate(
560
+ [
561
+ tensor.gather('feature', 'source'),
562
+ tensor.gather('feature', 'target')
563
+ ],
564
+ axis=-1
565
+ )
566
+ if self._has_edge_feature:
567
+ attention_feature = keras.ops.concatenate(
568
+ [
569
+ attention_feature,
570
+ tensor.edge['feature']
571
+ ],
572
+ axis=-1
573
+ )
574
+
575
+ attention_feature = self._feature_dense(attention_feature)
576
+
577
+ edge_feature = tensor.edge.get('feature')
578
+
579
+ if self._update_edge_feature:
580
+ edge_feature = self._edge_dense(attention_feature)
581
+ edge_feature = keras.ops.reshape(edge_feature, (-1, self.units))
582
+
583
+ attention_feature = self._attention_activation(attention_feature)
584
+ attention_score = self._attention_dense(attention_feature)
585
+ attention_score = ops.edge_softmax(
586
+ score=attention_score, edge_target=tensor.edge['target']
587
+ )
588
+ node_feature = self._node_dense(tensor.node['feature'])
589
+ message = ops.gather(node_feature, tensor.edge['source'])
590
+ return tensor.update(
591
+ {
592
+ 'edge': {
593
+ 'message': message,
594
+ 'weight': attention_score,
595
+ 'feature': edge_feature,
596
+ }
597
+ }
598
+ )
599
+
600
+ def aggregate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
601
+ node_feature = tensor.aggregate('message')
602
+ node_feature += self._node_self_dense(tensor.node['feature'])
603
+ node_feature = self._dropout_layer(node_feature)
604
+ node_feature = keras.ops.reshape(node_feature, (-1, self.units))
605
+ return tensor.update(
606
+ {
607
+ 'node': {
608
+ 'feature': node_feature
609
+ },
610
+ 'edge': {
611
+ 'message': None,
612
+ 'weight': None,
613
+ }
614
+ }
615
+ )
616
+
617
+ def update(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
618
+ node_feature = self._activation(tensor.node['feature'])
619
+ return tensor.update(
620
+ {
621
+ 'node': {
622
+ 'feature': node_feature
623
+ }
624
+ }
625
+ )
626
+
627
+ def get_config(self) -> dict:
628
+ config = super().get_config()
629
+ config.update({
630
+ "heads": self._heads,
631
+ 'activation': keras.activations.serialize(self._activation),
632
+ 'dropout': self._dropout,
633
+ 'update_edge_feature': self._update_edge_feature,
634
+ 'attention_activation': keras.activations.serialize(self._attention_activation),
635
+ })
636
+ return config
637
+
638
+
475
639
  @keras.saving.register_keras_serializable(package='molcraft')
476
640
  class GTConv(GraphConv):
477
641
 
@@ -550,10 +714,11 @@ class GTConv(GraphConv):
550
714
 
551
715
  self._self_attention_dropout = keras.layers.Dropout(self._dropout)
552
716
 
553
- self._add_edge_bias = not 'bias' in spec.edge
554
- if self._add_edge_bias:
555
- self._add_edge_bias = AddEdgeBias()
556
- self._add_edge_bias.build_from_spec(spec)
717
+ self._add_bias = not 'bias' in spec.edge
718
+
719
+ if self._add_bias:
720
+ self._edge_bias = EdgeBias(biases=self.heads)
721
+ self._edge_bias.build_from_spec(spec)
557
722
 
558
723
  has_overridden_update = self.__class__.update != GTConv.update
559
724
  if not has_overridden_update:
@@ -570,11 +735,21 @@ class GTConv(GraphConv):
570
735
  self._feedforward_output_dense = self.get_dense(self.units)
571
736
  self._feedforward_output_dense.build([None, self.units])
572
737
 
738
+ self.built = True
573
739
 
574
740
  def message(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
575
741
  """Computes messages.
576
742
  """
577
-
743
+ if self._add_bias:
744
+ edge_bias = self._edge_bias(tensor)
745
+ tensor = tensor.update(
746
+ {
747
+ 'edge': {
748
+ 'bias': edge_bias
749
+ }
750
+ }
751
+ )
752
+
578
753
  node_feature = tensor.node['feature']
579
754
 
580
755
  query = self._query_dense(node_feature)
@@ -587,11 +762,8 @@ class GTConv(GraphConv):
587
762
 
588
763
  attention_score = keras.ops.sum(query * key, axis=1, keepdims=True)
589
764
  attention_score /= keras.ops.sqrt(float(self.head_units))
590
-
591
- if self._add_edge_bias:
592
- tensor = self._add_edge_bias(tensor)
593
765
 
594
- attention_score += keras.ops.expand_dims(tensor.edge['bias'], -1)
766
+ attention_score += keras.ops.expand_dims(tensor.edge['bias'], axis=1)
595
767
 
596
768
  attention = ops.edge_softmax(attention_score, tensor.edge['target'])
597
769
  attention = self._softmax_dropout(attention)
@@ -665,6 +837,242 @@ class GTConv(GraphConv):
665
837
  return config
666
838
 
667
839
 
840
+ @keras.saving.register_keras_serializable(package='molcraft')
841
+ class GTConv3D(GTConv):
842
+
843
+ """Graph transformer 3D layer.
844
+ """
845
+
846
+ def build_from_spec(self, spec: tensors.GraphTensor.Spec) -> None:
847
+ super().build_from_spec(spec)
848
+ if self._add_bias:
849
+ node_feature_dim = spec.node['feature'].shape[-1]
850
+ kernels = self.units
851
+ self._gaussian_basis = GaussianDistance(kernels)
852
+ self._gaussian_basis.build_from_spec(spec)
853
+ self._centrality_dense = self.get_dense(units=node_feature_dim)
854
+ self._centrality_dense.build([None, kernels])
855
+ self._gaussian_edge_bias = self.get_dense(self.heads)
856
+ self._gaussian_edge_bias.build([None, kernels])
857
+
858
+ def message(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
859
+ """Computes messages.
860
+ """
861
+ node_feature = tensor.node['feature']
862
+
863
+ if self._add_bias:
864
+ gaussian = self._gaussian_basis(tensor)
865
+ centrality = keras.ops.segment_sum(
866
+ gaussian, tensor.edge['target'], tensor.num_nodes
867
+ )
868
+ node_feature += self._centrality_dense(centrality)
869
+
870
+ edge_bias = self._edge_bias(tensor) + self._gaussian_edge_bias(gaussian)
871
+ tensor = tensor.update({'edge': {'bias': edge_bias}})
872
+
873
+ query = self._query_dense(node_feature)
874
+ key = self._key_dense(node_feature)
875
+ value = self._value_dense(node_feature)
876
+
877
+ query = ops.gather(query, tensor.edge['source'])
878
+ key = ops.gather(key, tensor.edge['target'])
879
+ value = ops.gather(value, tensor.edge['source'])
880
+
881
+ attention_score = keras.ops.sum(query * key, axis=1, keepdims=True)
882
+ attention_score /= keras.ops.sqrt(float(self.head_units))
883
+
884
+ attention_score += keras.ops.expand_dims(tensor.edge['bias'], axis=1)
885
+
886
+ attention = ops.edge_softmax(attention_score, tensor.edge['target'])
887
+ attention = self._softmax_dropout(attention)
888
+
889
+ distance = keras.ops.subtract(
890
+ tensor.gather('coordinate', 'source'),
891
+ tensor.gather('coordinate', 'target')
892
+ )
893
+ euclidean_distance = ops.euclidean_distance(
894
+ tensor.gather('coordinate', 'source'),
895
+ tensor.gather('coordinate', 'target'),
896
+ axis=-1
897
+ )
898
+ distance /= euclidean_distance
899
+
900
+ attention *= keras.ops.expand_dims(distance, axis=-1)
901
+ attention = keras.ops.expand_dims(attention, axis=2)
902
+ value = keras.ops.expand_dims(value, axis=1)
903
+
904
+ return tensor.update(
905
+ {
906
+ 'edge': {
907
+ 'message': value,
908
+ 'weight': attention,
909
+ },
910
+ }
911
+ )
912
+
913
+ def aggregate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
914
+ """Aggregates messages.
915
+ """
916
+ node_feature = tensor.aggregate('message')
917
+ node_feature = keras.ops.reshape(
918
+ node_feature, (tensor.num_nodes, -1, self.units)
919
+ )
920
+ node_feature = self._output_dense(node_feature)
921
+ node_feature = keras.ops.sum(node_feature, axis=1)
922
+ node_feature = self._self_attention_dropout(node_feature)
923
+ return tensor.update(
924
+ {
925
+ 'node': {
926
+ 'feature': node_feature,
927
+ 'residual': tensor.node['feature']
928
+ },
929
+ 'edge': {
930
+ 'message': None,
931
+ 'weight': None,
932
+ }
933
+ }
934
+ )
935
+
936
+
937
+ @keras.saving.register_keras_serializable(package='molcraft')
938
+ class MPConv(GraphConv):
939
+
940
+ """Message passing neural network layer.
941
+ """
942
+
943
+ def __init__(
944
+ self,
945
+ units: int = 128,
946
+ activation: keras.layers.Activation | str | None = None,
947
+ use_bias: bool = True,
948
+ normalize: bool = True,
949
+ dropout: float = 0.0,
950
+ **kwargs
951
+ ) -> None:
952
+ super().__init__(
953
+ units=units,
954
+ normalize=normalize,
955
+ use_bias=use_bias,
956
+ **kwargs
957
+ )
958
+ self._activation = keras.activations.get(activation)
959
+ self._dropout = dropout or 0.0
960
+
961
+ def build_from_spec(self, spec: tensors.GraphTensor.Spec) -> None:
962
+ node_feature_dim = spec.node['feature'].shape[-1]
963
+ self.message_fn = self.get_dense(self.units, activation=self._activation)
964
+ self.update_fn = keras.layers.GRUCell(self.units)
965
+ self._has_edge_feature = 'feature' in spec.edge
966
+ self.project_input_node_feature = node_feature_dim != self.units
967
+ if self.project_input_node_feature:
968
+ warn(
969
+ 'Input node feature dim does not match updated node feature dim. '
970
+ 'To make sure input node feature can be passed as `states` to the '
971
+ 'GRU cell, it will automatically be projected prior to it.'
972
+ )
973
+ self._previous_node_dense = self.get_dense(self.units, activation=self._activation)
974
+ self.built = True
975
+
976
+ def message(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
977
+ feature = keras.ops.concatenate(
978
+ [
979
+ tensor.gather('feature', 'source'),
980
+ tensor.gather('feature', 'target'),
981
+ ],
982
+ axis=-1
983
+ )
984
+ if self._has_edge_feature:
985
+ feature = keras.ops.concatenate(
986
+ [
987
+ feature,
988
+ tensor.edge['feature']
989
+ ],
990
+ axis=-1
991
+ )
992
+ message = self.message_fn(feature)
993
+ return tensor.update(
994
+ {
995
+ 'edge': {
996
+ 'message': message,
997
+ }
998
+ }
999
+ )
1000
+
1001
+ def aggregate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
1002
+ aggregate = tensor.aggregate('message')
1003
+ previous = tensor.node['feature']
1004
+ if self.project_input_node_feature:
1005
+ previous = self._previous_node_dense(previous)
1006
+ return tensor.update(
1007
+ {
1008
+ 'node': {
1009
+ 'feature': aggregate,
1010
+ 'previous_feature': previous,
1011
+ }
1012
+ }
1013
+ )
1014
+
1015
+ def update(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
1016
+ updated_node_feature, _ = self.update_fn(
1017
+ inputs=tensor.node['feature'],
1018
+ states=tensor.node['previous_feature']
1019
+ )
1020
+ return tensor.update(
1021
+ {
1022
+ 'node': {
1023
+ 'feature': updated_node_feature,
1024
+ 'previous_feature': None,
1025
+ }
1026
+ }
1027
+ )
1028
+
1029
+ def get_config(self) -> dict:
1030
+ config = super().get_config()
1031
+ config.update({
1032
+ 'activation': keras.activations.serialize(self._activation),
1033
+ 'dropout': self._dropout,
1034
+ })
1035
+ return config
1036
+
1037
+
1038
+ @keras.saving.register_keras_serializable(package='molcraft')
1039
+ class MPConv3D(MPConv):
1040
+
1041
+ """3D Message passing neural network layer.
1042
+ """
1043
+
1044
+ def message(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
1045
+ euclidean_distance = ops.euclidean_distance(
1046
+ tensor.gather('coordinate', 'target'),
1047
+ tensor.gather('coordinate', 'source'),
1048
+ axis=-1
1049
+ )
1050
+ feature = keras.ops.concatenate(
1051
+ [
1052
+ tensor.gather('feature', 'source'),
1053
+ tensor.gather('feature', 'target'),
1054
+ euclidean_distance,
1055
+ ],
1056
+ axis=-1
1057
+ )
1058
+ if self._has_edge_feature:
1059
+ feature = keras.ops.concatenate(
1060
+ [
1061
+ feature,
1062
+ tensor.edge['feature']
1063
+ ],
1064
+ axis=-1
1065
+ )
1066
+ message = self.message_fn(feature)
1067
+ return tensor.update(
1068
+ {
1069
+ 'edge': {
1070
+ 'message': message,
1071
+ }
1072
+ }
1073
+ )
1074
+
1075
+
668
1076
  @keras.saving.register_keras_serializable(package='molcraft')
669
1077
  class EGConv3D(GraphConv):
670
1078
 
@@ -714,6 +1122,7 @@ class EGConv3D(GraphConv):
714
1122
  self.update_fn = self.get_dense(self.units, activation=self._activation)
715
1123
  self.update_fn.build([None, node_feature_dim + self.units])
716
1124
  self._dropout_layer = keras.layers.Dropout(self._dropout)
1125
+ self.built = True
717
1126
 
718
1127
  def message(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
719
1128
  """Computes messages.
@@ -849,6 +1258,7 @@ class Projection(GraphLayer):
849
1258
  self.units = feature_dim
850
1259
  self._dense = self.get_dense(self.units)
851
1260
  self._dense.build([None, feature_dim])
1261
+ self.built = True
852
1262
 
853
1263
  def propagate(self, tensor: tensors.GraphTensor):
854
1264
  """Calls the layer.
@@ -913,6 +1323,7 @@ class GraphNetwork(GraphLayer):
913
1323
  )
914
1324
  self._edge_dense = self.get_dense(units)
915
1325
  self._update_edge_feature = True
1326
+ self.built = True
916
1327
 
917
1328
  def propagate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
918
1329
  """Calls the layer.
@@ -1003,12 +1414,14 @@ class NodeEmbedding(GraphLayer):
1003
1414
  def __init__(
1004
1415
  self,
1005
1416
  dim: int = None,
1417
+ normalize: bool = True,
1006
1418
  embed_context: bool = True,
1007
1419
  allow_masking: bool = True,
1008
1420
  **kwargs
1009
1421
  ) -> None:
1010
1422
  super().__init__(**kwargs)
1011
1423
  self.dim = dim
1424
+ self._normalize = normalize
1012
1425
  self._embed_context = embed_context
1013
1426
  self._masking_rate = None
1014
1427
  self._allow_masking = allow_masking
@@ -1035,6 +1448,12 @@ class NodeEmbedding(GraphLayer):
1035
1448
  context_feature_dim = spec.context['feature'].shape[-1]
1036
1449
  self._context_dense = self.get_dense(self.dim)
1037
1450
  self._context_dense.build([None, context_feature_dim])
1451
+
1452
+ if self._normalize:
1453
+ self._norm = keras.layers.LayerNormalization()
1454
+ self._norm.build([None, self.dim])
1455
+
1456
+ self.built = True
1038
1457
 
1039
1458
  def propagate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
1040
1459
  """Calls the layer.
@@ -1068,6 +1487,9 @@ class NodeEmbedding(GraphLayer):
1068
1487
  # Slience warning of 'no gradients for variables'
1069
1488
  feature = feature + (self._mask_feature * 0.0)
1070
1489
 
1490
+ if self._normalize:
1491
+ feature = self._norm(feature)
1492
+
1071
1493
  return tensor.update({'node': {'feature': feature}})
1072
1494
 
1073
1495
  @property
@@ -1087,6 +1509,8 @@ class NodeEmbedding(GraphLayer):
1087
1509
  config = super().get_config()
1088
1510
  config.update({
1089
1511
  'dim': self.dim,
1512
+ 'normalize': self._normalize,
1513
+ 'embed_context': self._embed_context,
1090
1514
  'allow_masking': self._allow_masking
1091
1515
  })
1092
1516
  return config
@@ -1103,11 +1527,13 @@ class EdgeEmbedding(GraphLayer):
1103
1527
  def __init__(
1104
1528
  self,
1105
1529
  dim: int = None,
1530
+ normalize: bool = True,
1106
1531
  allow_masking: bool = True,
1107
1532
  **kwargs
1108
1533
  ) -> None:
1109
1534
  super().__init__(**kwargs)
1110
1535
  self.dim = dim
1536
+ self._normalize = normalize
1111
1537
  self._masking_rate = None
1112
1538
  self._allow_masking = allow_masking
1113
1539
 
@@ -1125,6 +1551,11 @@ class EdgeEmbedding(GraphLayer):
1125
1551
  self._super_feature = self.get_weight(shape=[self.dim], name='super_edge_feature')
1126
1552
  if self._allow_masking:
1127
1553
  self._mask_feature = self.get_weight(shape=[self.dim], name='mask_edge_feature')
1554
+ if self._normalize:
1555
+ self._norm = keras.layers.LayerNormalization()
1556
+ self._norm.build([None, self.dim])
1557
+
1558
+ self.built = True
1128
1559
 
1129
1560
  def propagate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
1130
1561
  """Calls the layer.
@@ -1153,6 +1584,9 @@ class EdgeEmbedding(GraphLayer):
1153
1584
  # Slience warning of 'no gradients for variables'
1154
1585
  feature = feature + (self._mask_feature * 0.0)
1155
1586
 
1587
+ if self._normalize:
1588
+ feature = self._norm(feature)
1589
+
1156
1590
  return tensor.update({'edge': {'feature': feature}})
1157
1591
 
1158
1592
  @property
@@ -1172,6 +1606,7 @@ class EdgeEmbedding(GraphLayer):
1172
1606
  config = super().get_config()
1173
1607
  config.update({
1174
1608
  'dim': self.dim,
1609
+ 'normalize': self._normalize,
1175
1610
  'allow_masking': self._allow_masking
1176
1611
  })
1177
1612
  return config
@@ -1199,17 +1634,97 @@ class EdgeProjection(Projection):
1199
1634
  """
1200
1635
  def __init__(self, units: int = None, activation: str = None, **kwargs):
1201
1636
  super().__init__(units=units, activation=activation, field='edge', **kwargs)
1637
+
1638
+
1639
+ @keras.saving.register_keras_serializable(package='molcraft')
1640
+ class EdgeBias(GraphLayer):
1641
+
1642
+ def __init__(self, biases: int, **kwargs):
1643
+ super().__init__(**kwargs)
1644
+ self.biases = biases
1645
+
1646
+ def build_from_spec(self, spec: tensors.GraphTensor.Spec) -> None:
1647
+ self._has_edge_length = 'length' in spec.edge
1648
+ self._has_edge_feature = 'feature' in spec.edge
1649
+ if self._has_edge_feature:
1650
+ self._edge_feature_dense = self.get_dense(self.biases)
1651
+ self._edge_feature_dense.build([None, spec.edge['feature'].shape[-1]])
1652
+ if self._has_edge_length:
1653
+ self._edge_length_dense = self.get_dense(
1654
+ self.biases, kernel_initializer='zeros'
1655
+ )
1656
+ self._edge_length_dense.build([None, spec.edge['length'].shape[-1]])
1657
+ self.built = True
1658
+
1659
+ def propagate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
1660
+ bias = keras.ops.zeros(
1661
+ shape=(tensor.num_edges, self.biases),
1662
+ dtype=tensor.node['feature'].dtype
1663
+ )
1664
+ if self._has_edge_feature:
1665
+ bias += self._edge_feature_dense(tensor.edge['feature'])
1666
+ if self._has_edge_length:
1667
+ bias += self._edge_length_dense(tensor.edge['length'])
1668
+ return bias
1669
+
1670
+ def get_config(self) -> dict:
1671
+ config = super().get_config()
1672
+ config.update({'biases': self.biases})
1673
+ return config
1202
1674
 
1203
1675
 
1204
1676
  @keras.saving.register_keras_serializable(package='molcraft')
1205
- class Readout(keras.layers.Layer):
1677
+ class GaussianDistance(GraphLayer):
1678
+
1679
+ def __init__(self, kernels: int, **kwargs):
1680
+ super().__init__(**kwargs)
1681
+ self.kernels = kernels
1682
+
1683
+ def build_from_spec(self, spec: tensors.GraphTensor.Spec) -> None:
1684
+ self._loc = self.add_weight(
1685
+ shape=[self.kernels],
1686
+ initializer='zeros',
1687
+ dtype='float32',
1688
+ trainable=True
1689
+ )
1690
+ self._scale = self.add_weight(
1691
+ shape=[self.kernels],
1692
+ initializer='ones',
1693
+ dtype='float32',
1694
+ trainable=True
1695
+ )
1696
+ self.built = True
1697
+
1698
+ def propagate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
1699
+ euclidean_distance = ops.euclidean_distance(
1700
+ tensor.gather('coordinate', 'source'),
1701
+ tensor.gather('coordinate', 'target'),
1702
+ axis=-1
1703
+ )
1704
+ return ops.gaussian(
1705
+ euclidean_distance, self._loc, self._scale
1706
+ )
1707
+
1708
+ def get_config(self) -> dict:
1709
+ config = super().get_config()
1710
+ config.update({
1711
+ 'kernels': self.kernels,
1712
+ })
1713
+ return config
1714
+
1715
+
1716
+ @keras.saving.register_keras_serializable(package='molcraft')
1717
+ class Readout(GraphLayer):
1718
+
1719
+ """Readout layer.
1720
+ """
1206
1721
 
1207
1722
  def __init__(self, mode: str | None = None, **kwargs):
1723
+ kwargs['kernel_initializer'] = None
1724
+ kwargs['bias_initializer'] = None
1208
1725
  super().__init__(**kwargs)
1209
1726
  self.mode = mode
1210
- if not self.mode:
1211
- self._reduce_fn = None
1212
- elif str(self.mode).lower().startswith('sum'):
1727
+ if str(self.mode).lower().startswith('sum'):
1213
1728
  self._reduce_fn = keras.ops.segment_sum
1214
1729
  elif str(self.mode).lower().startswith('max'):
1215
1730
  self._reduce_fn = keras.ops.segment_max
@@ -1221,80 +1736,24 @@ class Readout(keras.layers.Layer):
1221
1736
  def build_from_spec(self, spec: tensors.GraphTensor.Spec) -> None:
1222
1737
  """Builds the layer.
1223
1738
  """
1224
- pass
1739
+ self.built = True
1225
1740
 
1226
- def reduce(self, tensor: tensors.GraphTensor) -> tf.Tensor:
1227
- if self._reduce_fn is None:
1228
- raise NotImplementedError("Need to define a reduce method.")
1741
+ def propagate(self, tensor: tensors.GraphTensor) -> tf.Tensor:
1742
+ """Calls the layer.
1743
+ """
1744
+ node_feature = tensor.node['feature']
1229
1745
  if str(self.mode).lower().startswith('super'):
1230
1746
  node_feature = keras.ops.where(
1231
- tensor.node['super'][:, None], tensor.node['feature'], 0.0
1232
- )
1233
- return self._reduce_fn(
1234
- node_feature, tensor.graph_indicator, tensor.num_subgraphs
1747
+ tensor.node['super'][:, None], node_feature, 0.0
1235
1748
  )
1236
1749
  return self._reduce_fn(
1237
- tensor.node['feature'], tensor.graph_indicator, tensor.num_subgraphs
1750
+ node_feature, tensor.graph_indicator, tensor.num_subgraphs
1238
1751
  )
1239
1752
 
1240
- def build(self, input_shapes) -> None:
1241
- spec = tensors.GraphTensor.Spec.from_input_shape_dict(input_shapes)
1242
- self.build_from_spec(spec)
1243
- self.built = True
1244
-
1245
- def call(self, graph) -> tf.Tensor:
1246
- graph_tensor = tensors.from_dict(graph)
1247
- if tensors.is_ragged(graph_tensor):
1248
- graph_tensor = graph_tensor.flatten()
1249
- return self.reduce(graph_tensor)
1250
-
1251
- def __call__(
1252
- self,
1253
- graph: tensors.GraphTensor,
1254
- *args,
1255
- **kwargs
1256
- ) -> tensors.GraphTensor:
1257
- is_tensor = isinstance(graph, tensors.GraphTensor)
1258
- if is_tensor:
1259
- graph = tensors.to_dict(graph)
1260
- tensor = super().__call__(graph, *args, **kwargs)
1261
- return tensor
1262
-
1263
1753
  def get_config(self) -> dict:
1264
1754
  config = super().get_config()
1265
1755
  config['mode'] = self.mode
1266
1756
  return config
1267
-
1268
-
1269
- @keras.saving.register_keras_serializable(package='molcraft')
1270
- class AddEdgeBias(GraphLayer):
1271
-
1272
- def build_from_spec(self, spec: tensors.GraphTensor.Spec) -> None:
1273
- self._has_edge_length = 'length' in spec.edge
1274
- self._has_edge_feature = 'feature' in spec.edge
1275
- if self._has_edge_feature:
1276
- self._edge_feature_dense = self.get_dense(units=1)
1277
- if self._has_edge_length:
1278
- self._edge_length_dense = self.get_dense(
1279
- units=1, kernel_initializer='zeros'
1280
- )
1281
-
1282
- def propagate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
1283
- bias = keras.ops.zeros(
1284
- shape=(tensor.num_edges, 1),
1285
- dtype=tensor.node['feature'].dtype
1286
- )
1287
- if self._has_edge_feature:
1288
- bias += self._edge_feature_dense(tensor.edge['feature'])
1289
- if self._has_edge_length:
1290
- bias += self._edge_length_dense(tensor.edge['length'])
1291
- return tensor.update(
1292
- {
1293
- 'edge': {
1294
- 'bias': bias
1295
- }
1296
- }
1297
- )
1298
1757
 
1299
1758
 
1300
1759
  def Input(spec: tensors.GraphTensor.Spec) -> dict:
@@ -1412,13 +1871,6 @@ def _spec_from_inputs(inputs):
1412
1871
  return tensors.GraphTensor.Spec(**nested_specs)
1413
1872
 
1414
1873
 
1415
- GraphTransformer = GTConvolution = GTConv
1416
- GINConvolution = GINConv
1417
-
1418
- EdgeEmbed = EdgeEmbedding
1419
- NodeEmbed = NodeEmbedding
1420
-
1421
- ContextDense = ContextProjection
1422
- EdgeDense = EdgeProjection
1423
- NodeDense = NodeProjection
1874
+ GraphTransformer = GTConv
1875
+ GraphTransformer3D = GTConv3D
1424
1876
 
molcraft/models.py CHANGED
@@ -315,10 +315,16 @@ class FunctionalGraphModel(functional.Functional, GraphModel):
315
315
  ]
316
316
 
317
317
 
318
- def save_model(model: keras.Model, filepath: str | Path, *args, **kwargs) -> None:
318
+ def save_model(model: GraphModel, filepath: str | Path, *args, **kwargs) -> None:
319
+ if not model.built:
320
+ raise ValueError(
321
+ 'Model and its layers have not yet been (fully) built. '
322
+ 'Build the model before saving it: `model.build(graph_spec)` '
323
+ 'or `model(graph)`.'
324
+ )
319
325
  keras.models.save_model(model, filepath, *args, **kwargs)
320
326
 
321
- def load_model(filepath: str | Path, inputs=None, *args, **kwargs) -> None:
327
+ def load_model(filepath: str | Path, inputs=None, *args, **kwargs) -> GraphModel:
322
328
  return keras.models.load_model(filepath, *args, **kwargs)
323
329
 
324
330
  def create(
@@ -333,7 +339,7 @@ def create(
333
339
  def interpret(
334
340
  model: GraphModel,
335
341
  graph_tensor: tensors.GraphTensor,
336
- ) -> tuple[tf.Tensor | tf.RaggedTensor | np.ndarray, tf.Tensor | np.ndarray]:
342
+ ) -> tensors.GraphTensor:
337
343
  x = graph_tensor
338
344
  if tensors.is_ragged(x):
339
345
  x = x.flatten()
@@ -373,6 +379,31 @@ def interpret(
373
379
  }
374
380
  )
375
381
 
382
+ def saliency(
383
+ model: GraphModel,
384
+ graph_tensor: tensors.GraphTensor
385
+ ) -> tensors.GraphTensor:
386
+ x = graph_tensor
387
+ if tensors.is_ragged(x):
388
+ x = x.flatten()
389
+ y_true = x.context.get('label')
390
+ with tf.GradientTape(watch_accessed_variables=False) as tape:
391
+ tape.watch(x.node['feature'])
392
+ y_pred = model(x, training=False)
393
+ if y_true is not None and len(y_true.shape) > 1:
394
+ target = tf.gather_nd(y_pred, tf.where(y_true != 0))
395
+ else:
396
+ target = y_pred
397
+ gradients = tape.gradient(target, x.node['feature'])
398
+ gradients = keras.ops.absolute(gradients)
399
+ return graph_tensor.update(
400
+ {
401
+ 'node': {
402
+ 'feature_saliency': gradients
403
+ }
404
+ }
405
+ )
406
+
376
407
  def predict(
377
408
  model: GraphModel,
378
409
  x: tensors.GraphTensor | tf.data.Dataset,
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: molcraft
3
- Version: 0.1.0a2
3
+ Version: 0.1.0a3
4
4
  Summary: Graph Neural Networks for Molecular Machine Learning
5
5
  Author-email: Alexander Kensert <alexander.kensert@gmail.com>
6
6
  License: MIT License
@@ -1,20 +1,20 @@
1
- molcraft/__init__.py,sha256=lE7_mCo7lLcP1AopGZtGyWqzAN1qgjZnH5juymdjrJc,406
1
+ molcraft/__init__.py,sha256=2ZNfWBjGl8DscOwjdDiRkgIsuPnKit29Q3MhZyP336Q,435
2
2
  molcraft/callbacks.py,sha256=6gwCwdsHGb-fVB4m1QGmtBwQwZ9mFq9QUkmPKSMn05U,849
3
3
  molcraft/chem.py,sha256=_UO5O-I7KUtGf3vRrFEYoAUGlW5xi2x8ylu5f-Ybumo,18696
4
4
  molcraft/conformers.py,sha256=p09gOQOdxLSj3yohZOMkxxLriHsZ1ZqOoiWLi73OpIg,4325
5
5
  molcraft/datasets.py,sha256=rFgXTC1ZheLhfgQgcCspP_wEE54a33PIneH7OplbS-8,4047
6
6
  molcraft/descriptors.py,sha256=x6RfZ-gK7D_WSvmK6sh6yHyEjQqovPnRU0xwC3dAKfg,2880
7
- molcraft/features.py,sha256=nZDfX9fsWWjhUbUbrWSUI0ny1QIDbxb4MO8umjcdQqw,13572
8
- molcraft/featurizers.py,sha256=gAUe7Ui8gF32aotuiDAUoRUuw8bTbkMgB2C2BO1VWDM,26176
9
- molcraft/layers.py,sha256=zs6Ae6p7ASeAy3eF113f35d55yQmyk2Z7vUUfkfJUmY,49677
10
- molcraft/models.py,sha256=Nvm5LKCtH-xj395f1OvIEmYVTTrnutoSthL2DxGicnY,16519
7
+ molcraft/features.py,sha256=69oV_GHNdBKPA4sp6Tpo6brvNmaauk_IVIzNjX7VDmg,13648
8
+ molcraft/featurizers.py,sha256=Yu8I6I_zkzB__WYSiqz-FDGjvKFOmyWFxojRBr39Aw8,26236
9
+ molcraft/layers.py,sha256=HjnAtqhuP0uZ5yP4L33k3xT4IUdLavWBrjd3wO9_Rmw,64915
10
+ molcraft/models.py,sha256=DXqWR_XnMVXQseVR91XnDLXvmHa1hv-6_Y_wvpQZBFI,17476
11
11
  molcraft/ops.py,sha256=iiE6zgA2P7cmjKO1RHmL9GE_Tv7Tyuo_xDoxB_ELZQM,3824
12
12
  molcraft/records.py,sha256=w4-bcWZEC0oVInrE1e0kQBroIaSCA0PN1JBPOtO6VUY,5251
13
13
  molcraft/tensors.py,sha256=b7PO-YOvV72s9g057ILJACKS2n2fn10VkO35gHXpssI,22312
14
14
  molcraft/experimental/__init__.py,sha256=x5h6LOO8bo3NPjkKKM9M1H-Kz6R3yxYhRSePoxHCdRE,42
15
15
  molcraft/experimental/peptides.py,sha256=RCuOTSwoYHGSdeYi6TWHdPIv2WC3avCZjKLdhEZQeXw,8997
16
- molcraft-0.1.0a2.dist-info/licenses/LICENSE,sha256=sbVeqlrtZ0V63uYhZGL5dCxUm8rBAOqe2avyA1zIQNk,1074
17
- molcraft-0.1.0a2.dist-info/METADATA,sha256=TYf32YHTSrK9OaaGKCCk89uPlf_REWsK-LKf93c6V4M,4088
18
- molcraft-0.1.0a2.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
19
- molcraft-0.1.0a2.dist-info/top_level.txt,sha256=dENV6MfOceshM6MQCgJlcN1ojZkiCL9B4F7XyUge3QM,9
20
- molcraft-0.1.0a2.dist-info/RECORD,,
16
+ molcraft-0.1.0a3.dist-info/licenses/LICENSE,sha256=sbVeqlrtZ0V63uYhZGL5dCxUm8rBAOqe2avyA1zIQNk,1074
17
+ molcraft-0.1.0a3.dist-info/METADATA,sha256=f_5sBinpFcGSqKLaSqGkZJ83gGQtZw1Pb3fkgq9aCBM,4088
18
+ molcraft-0.1.0a3.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
19
+ molcraft-0.1.0a3.dist-info/top_level.txt,sha256=dENV6MfOceshM6MQCgJlcN1ojZkiCL9B4F7XyUge3QM,9
20
+ molcraft-0.1.0a3.dist-info/RECORD,,