molcraft 0.1.0a5__py3-none-any.whl → 0.1.0a7__py3-none-any.whl

This diff compares the contents of two publicly released versions of this package, as they appear in their public registry, and is provided for informational purposes only.

Potentially problematic release: this version of molcraft might be problematic.

molcraft/layers.py CHANGED
@@ -125,23 +125,45 @@ class GraphLayer(keras.layers.Layer):
125
125
  return tensors.to_dict(outputs)
126
126
  return outputs
127
127
 
128
- def __call__(self, inputs, **kwargs):
128
+ def __call__(
129
+ self,
130
+ graph: dict[str, dict[str, tf.Tensor]] | tensors.GraphTensor,
131
+ **kwargs
132
+ ) -> tf.Tensor | dict[str, dict[str, tf.Tensor]] | tensors.GraphTensor:
129
133
  if not self.built:
130
- spec = _spec_from_inputs(inputs)
134
+ spec = _spec_from_inputs(graph)
131
135
  self.build(spec)
132
- convert = isinstance(inputs, tensors.GraphTensor)
133
- if convert:
134
- inputs = tensors.to_dict(inputs)
136
+
137
+ is_graph_tensor = isinstance(graph, tensors.GraphTensor)
138
+ if is_graph_tensor:
139
+ graph = tensors.to_dict(graph)
140
+ else:
141
+ graph = {field: dict(data) for (field, data) in graph.items()}
142
+
135
143
  if isinstance(self, functional.Functional):
136
- inputs, left_out_inputs = _match_functional_input(self.input, inputs)
137
- outputs = super().__call__(inputs, **kwargs)
144
+ # As a functional model is strict about what input can
145
+ # be passed to it, we need to temporarily pop some of the
146
+ # inputs and add them back afterwards.
147
+ label = graph['context'].pop('label', None)
148
+ weight = graph['context'].pop('weight', None)
149
+ tf.nest.assert_same_structure(self.input, graph)
150
+
151
+ outputs = super().__call__(graph, **kwargs)
152
+
138
153
  if not tensors.is_graph(outputs):
139
154
  return outputs
155
+
156
+ graph = outputs
140
157
  if isinstance(self, functional.Functional):
141
- outputs = _add_left_out_inputs(outputs, left_out_inputs)
142
- if convert:
143
- outputs = tensors.from_dict(outputs)
144
- return outputs
158
+ if label is not None:
159
+ graph['context']['label'] = label
160
+ if weight is not None:
161
+ graph['context']['weight'] = weight
162
+
163
+ if is_graph_tensor:
164
+ return tensors.from_dict(graph)
165
+
166
+ return graph
145
167
 
146
168
  def get_build_config(self) -> dict:
147
169
  if self._custom_build_config:
@@ -256,10 +278,10 @@ class GraphConv(GraphLayer):
256
278
  Default to `None`.
257
279
  use_bias (bool):
258
280
  Whether bias should be used in dense layers. Default to `True`.
259
- normalization (bool, str):
281
+ normalize (bool, str):
260
282
  Whether `LayerNormalization` should be applied to the final node feature output.
261
283
  To use `BatchNormalization`, specify `batch_norm`. Default to `False`.
262
- skip_connection (bool, str):
284
+ skip_connect (bool, str):
263
285
  Whether node feature input should be added to the node feature output.
264
286
  If node feature input dim is not equal to `units` (node feature output dim),
265
287
  a projection layer will automatically project the residual before adding it
@@ -294,14 +316,14 @@ class GraphConv(GraphLayer):
294
316
  units: int = None,
295
317
  activation: str | keras.layers.Activation | None = None,
296
318
  use_bias: bool = True,
297
- normalization: bool | str = False,
298
- skip_connection: bool | str = True,
319
+ normalize: bool | str = False,
320
+ skip_connect: bool | str = True,
299
321
  **kwargs
300
322
  ) -> None:
301
323
  super().__init__(use_bias=use_bias, **kwargs)
302
324
  self._units = units
303
- self._normalization = normalization
304
- self._skip_connection = skip_connection
325
+ self._normalize = normalize
326
+ self._skip_connect = skip_connect
305
327
  self._activation = keras.activations.get(activation)
306
328
 
307
329
  def __init_subclass__(cls, **kwargs):
@@ -319,36 +341,6 @@ class GraphConv(GraphLayer):
319
341
  def units(self):
320
342
  return self._units
321
343
 
322
- def propagate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
323
- """Forward pass.
324
-
325
- Invokes `message(graph)`, `aggregate(graph)` and `update(graph)` in sequence.
326
-
327
- Arguments:
328
- tensor:
329
- A `GraphTensor` instance.
330
- """
331
- if self._skip_connection:
332
- input_node_feature = tensor.node['feature']
333
- if self._project_input_node_feature:
334
- input_node_feature = self._residual_projection(input_node_feature)
335
-
336
- tensor = self.message(tensor)
337
- tensor = self.aggregate(tensor)
338
- tensor = self.update(tensor)
339
-
340
- updated_node_feature = tensor.node['feature']
341
-
342
- if self._skip_connection:
343
- if self._use_weighted_skip_connection:
344
- input_node_feature *= self._skip_connection_weight
345
- updated_node_feature += input_node_feature
346
-
347
- if self._normalization:
348
- updated_node_feature = self._output_norm(updated_node_feature)
349
-
350
- return tensor.update({'node': {'feature': updated_node_feature}})
351
-
352
344
  def build(self, spec: tensors.GraphTensor.Spec) -> None:
353
345
  if not self.units:
354
346
  raise ValueError(
@@ -356,11 +348,11 @@ class GraphConv(GraphLayer):
356
348
  )
357
349
  node_feature_dim = spec.node['feature'].shape[-1]
358
350
  self._project_input_node_feature = (
359
- self._skip_connection and (node_feature_dim != self.units)
351
+ self._skip_connect and (node_feature_dim != self.units)
360
352
  )
361
353
  if self._project_input_node_feature:
362
354
  warn(
363
- '`skip_connection` is set to `True`, but found incompatible dim '
355
+ '`skip_connect` is set to `True`, but found incompatible dim '
364
356
  'between input (node feature dim) and output (`self.units`). '
365
357
  'Automatically applying a projection layer to residual to '
366
358
  'match input and output. '
@@ -369,8 +361,8 @@ class GraphConv(GraphLayer):
369
361
  self.units, name='residual_projection'
370
362
  )
371
363
 
372
- skip_connection = str(self._skip_connection).lower()
373
- self._use_weighted_skip_connection = skip_connection.startswith('weight')
364
+ skip_connect = str(self._skip_connect).lower()
365
+ self._use_weighted_skip_connection = skip_connect.startswith('weight')
374
366
  if self._use_weighted_skip_connection:
375
367
  self._skip_connection_weight = self.add_weight(
376
368
  name='skip_connection_weight',
@@ -379,8 +371,8 @@ class GraphConv(GraphLayer):
379
371
  trainable=True,
380
372
  )
381
373
 
382
- if self._normalization:
383
- if str(self._normalization).lower().startswith('batch'):
374
+ if self._normalize:
375
+ if str(self._normalize).lower().startswith('batch'):
384
376
  self._output_norm = keras.layers.BatchNormalization(
385
377
  name='output_batch_norm'
386
378
  )
@@ -389,7 +381,7 @@ class GraphConv(GraphLayer):
389
381
  name='output_layer_norm'
390
382
  )
391
383
 
392
- self._has_edge_feature = 'edge' in spec.edge
384
+ self._has_edge_feature = 'feature' in spec.edge
393
385
 
394
386
  has_overridden_message = self.__class__.message != GraphConv.message
395
387
  if not has_overridden_message:
@@ -400,6 +392,50 @@ class GraphConv(GraphLayer):
400
392
  self._output_dense = self.get_dense(self.units)
401
393
  self._output_activation = self._activation
402
394
 
395
+ def propagate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
396
+ """Forward pass.
397
+
398
+ Invokes `message(graph)`, `aggregate(graph)` and `update(graph)` in sequence.
399
+
400
+ Arguments:
401
+ tensor:
402
+ A `GraphTensor` instance.
403
+ """
404
+ if self._skip_connect:
405
+ input_node_feature = tensor.node['feature']
406
+ if self._project_input_node_feature:
407
+ input_node_feature = self._residual_projection(input_node_feature)
408
+
409
+ message = self.message(tensor)
410
+ if not isinstance(message, tensors.GraphTensor):
411
+ message = tensor.update({'edge': {'message': message}})
412
+ elif not 'message' in message.edge:
413
+ raise ValueError('Could not find `message` in `edge` output.')
414
+
415
+ aggregate = self.aggregate(message)
416
+ if not isinstance(aggregate, tensors.GraphTensor):
417
+ aggregate = tensor.update({'node': {'aggregate': aggregate}})
418
+ elif not 'aggregate' in aggregate.node:
419
+ raise ValueError('Could not find `aggregate` in `node` output.')
420
+
421
+ update = self.update(aggregate)
422
+ if not isinstance(update, tensors.GraphTensor):
423
+ update = tensor.update({'node': {'feature': update}})
424
+ elif not 'feature' in update.node:
425
+ raise ValueError('Could not find `feature` in `node` output.')
426
+
427
+ updated_node_feature = update.node['feature']
428
+
429
+ if self._skip_connect:
430
+ if self._use_weighted_skip_connection:
431
+ input_node_feature *= self._skip_connection_weight
432
+ updated_node_feature += input_node_feature
433
+
434
+ if self._normalize:
435
+ updated_node_feature = self._output_norm(updated_node_feature)
436
+
437
+ return update.update({'node': {'feature': updated_node_feature}})
438
+
403
439
  def message(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
404
440
  """Compute messages.
405
441
 
@@ -441,8 +477,7 @@ class GraphConv(GraphLayer):
441
477
  return tensor.update(
442
478
  {
443
479
  'node': {
444
- 'feature': aggregate,
445
- 'previous_feature': tensor.node['feature']
480
+ 'aggregate': aggregate,
446
481
  },
447
482
  'edge': {
448
483
  'message': None
@@ -460,23 +495,20 @@ class GraphConv(GraphLayer):
460
495
  A `GraphTensor` instance containing aggregated messages
461
496
  (updated node features).
462
497
  """
463
- if not 'previous_feature' in tensor.node:
464
- feature = tensor.node['feature']
465
- else:
466
- feature = keras.ops.concatenate(
467
- [
468
- tensor.node['feature'],
469
- tensor.node['previous_feature']
470
- ],
471
- axis=-1
472
- )
498
+ feature = keras.ops.concatenate(
499
+ [
500
+ tensor.node['aggregate'],
501
+ tensor.node['feature']
502
+ ],
503
+ axis=-1
504
+ )
473
505
  update = self._output_dense(feature)
474
506
  update = self._output_activation(update)
475
507
  return tensor.update(
476
508
  {
477
509
  'node': {
478
510
  'feature': update,
479
- 'previous_feature': None,
511
+ 'aggregate': None,
480
512
  }
481
513
  }
482
514
  )
@@ -486,8 +518,8 @@ class GraphConv(GraphLayer):
486
518
  config.update({
487
519
  'units': self.units,
488
520
  'activation': keras.activations.serialize(self._activation),
489
- 'normalization': self._normalization,
490
- 'skip_connection': self._skip_connection,
521
+ 'normalize': self._normalize,
522
+ 'skip_connect': self._skip_connect,
491
523
  })
492
524
  return config
493
525
 
@@ -530,14 +562,14 @@ class GIConv(GraphConv):
530
562
  units: int,
531
563
  activation: keras.layers.Activation | str | None = 'relu',
532
564
  use_bias: bool = True,
533
- normalization: bool = False,
565
+ normalize: bool = False,
534
566
  update_edge_feature: bool = True,
535
567
  **kwargs,
536
568
  ):
537
569
  super().__init__(
538
570
  units=units,
539
571
  activation=activation,
540
- normalization=normalization,
572
+ normalize=normalize,
541
573
  use_bias=use_bias,
542
574
  **kwargs
543
575
  )
@@ -599,7 +631,7 @@ class GIConv(GraphConv):
599
631
  return tensor.update(
600
632
  {
601
633
  'node': {
602
- 'feature': node_feature,
634
+ 'aggregate': node_feature,
603
635
  },
604
636
  'edge': {
605
637
  'message': None,
@@ -608,7 +640,7 @@ class GIConv(GraphConv):
608
640
  )
609
641
 
610
642
  def update(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
611
- node_feature = tensor.node['feature']
643
+ node_feature = tensor.node['aggregate']
612
644
  node_feature = self._feedforward_intermediate_dense(node_feature)
613
645
  node_feature = self._feedforward_activation(node_feature)
614
646
  node_feature = self._feedforward_output_dense(node_feature)
@@ -616,6 +648,7 @@ class GIConv(GraphConv):
616
648
  {
617
649
  'node': {
618
650
  'feature': node_feature,
651
+ 'aggregate': None,
619
652
  }
620
653
  }
621
654
  )
@@ -667,17 +700,16 @@ class GAConv(GraphConv):
667
700
  heads: int = 8,
668
701
  activation: keras.layers.Activation | str | None = "relu",
669
702
  use_bias: bool = True,
670
- normalization: bool = False,
703
+ normalize: bool = False,
671
704
  update_edge_feature: bool = True,
672
705
  attention_activation: keras.layers.Activation | str | None = "leaky_relu",
673
706
  **kwargs,
674
707
  ) -> None:
675
- kwargs['skip_connection'] = False
676
708
  super().__init__(
677
709
  units=units,
678
710
  activation=activation,
679
711
  use_bias=use_bias,
680
- normalization=normalization,
712
+ normalize=normalize,
681
713
  **kwargs
682
714
  )
683
715
  self._heads = heads
@@ -753,11 +785,11 @@ class GAConv(GraphConv):
753
785
  )
754
786
  node_feature = self._node_dense(tensor.node['feature'])
755
787
  message = ops.gather(node_feature, tensor.edge['source'])
788
+ message = ops.edge_weight(message, attention_score)
756
789
  return tensor.update(
757
790
  {
758
791
  'edge': {
759
792
  'message': message,
760
- 'weight': attention_score,
761
793
  'feature': edge_feature,
762
794
  }
763
795
  }
@@ -770,24 +802,24 @@ class GAConv(GraphConv):
770
802
  return tensor.update(
771
803
  {
772
804
  'node': {
773
- 'feature': node_feature
805
+ 'aggregate': node_feature
774
806
  },
775
807
  'edge': {
776
808
  'message': None,
777
- 'weight': None,
778
809
  }
779
810
  }
780
811
  )
781
812
 
782
813
  def update(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
783
- node_feature = tensor.node['feature']
814
+ node_feature = tensor.node['aggregate']
784
815
  node_feature = self._feedforward_intermediate_dense(node_feature)
785
816
  node_feature = self._feedforward_activation(node_feature)
786
817
  node_feature = self._feedforward_output_dense(node_feature)
787
818
  return tensor.update(
788
819
  {
789
820
  'node': {
790
- 'feature': node_feature
821
+ 'feature': node_feature,
822
+ 'aggregate': None,
791
823
  }
792
824
  }
793
825
  )
@@ -842,7 +874,7 @@ class GTConv(GraphConv):
842
874
  heads: int = 8,
843
875
  activation: keras.layers.Activation | str | None = "relu",
844
876
  use_bias: bool = True,
845
- normalization: bool = False,
877
+ normalize: bool = False,
846
878
  attention_dropout: float = 0.0,
847
879
  **kwargs,
848
880
  ) -> None:
@@ -850,7 +882,7 @@ class GTConv(GraphConv):
850
882
  units=units,
851
883
  activation=activation,
852
884
  use_bias=use_bias,
853
- normalization=normalization,
885
+ normalize=normalize,
854
886
  **kwargs
855
887
  )
856
888
  self._heads = heads
@@ -901,7 +933,6 @@ class GTConv(GraphConv):
901
933
  }
902
934
  }
903
935
  )
904
-
905
936
  node_feature = tensor.node['feature']
906
937
 
907
938
  query = self._query_dense(node_feature)
@@ -918,12 +949,12 @@ class GTConv(GraphConv):
918
949
  attention_score += keras.ops.expand_dims(tensor.edge['bias'], axis=1)
919
950
  attention = ops.edge_softmax(attention_score, tensor.edge['target'])
920
951
  attention = self._softmax_dropout(attention)
952
+ message = ops.edge_weight(value, attention)
921
953
 
922
954
  return tensor.update(
923
955
  {
924
956
  'edge': {
925
- 'message': value,
926
- 'weight': attention,
957
+ 'message': message
927
958
  },
928
959
  }
929
960
  )
@@ -935,18 +966,16 @@ class GTConv(GraphConv):
935
966
  return tensor.update(
936
967
  {
937
968
  'node': {
938
- 'feature': node_feature,
939
- 'residual': tensor.node['feature']
969
+ 'aggregate': node_feature,
940
970
  },
941
971
  'edge': {
942
972
  'message': None,
943
- 'weight': None,
944
973
  }
945
974
  }
946
975
  )
947
976
 
948
977
  def update(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
949
- node_feature = tensor.node['feature']
978
+ node_feature = tensor.node['aggregate']
950
979
  node_feature = self._feedforward_intermediate_dense(node_feature)
951
980
  node_feature = self._feedforward_activation(node_feature)
952
981
  node_feature = self._feedforward_output_dense(node_feature)
@@ -954,6 +983,7 @@ class GTConv(GraphConv):
954
983
  {
955
984
  'node': {
956
985
  'feature': node_feature,
986
+ 'aggregate': None,
957
987
  },
958
988
  }
959
989
  )
@@ -978,14 +1008,14 @@ class MPConv(GraphConv):
978
1008
  units: int = 128,
979
1009
  activation: keras.layers.Activation | str | None = None,
980
1010
  use_bias: bool = True,
981
- normalization: bool = False,
1011
+ normalize: bool = False,
982
1012
  **kwargs
983
1013
  ) -> None:
984
1014
  super().__init__(
985
1015
  units=units,
986
1016
  activation=activation,
987
1017
  use_bias=use_bias,
988
- normalization=normalization,
1018
+ normalize=normalize,
989
1019
  **kwargs
990
1020
  )
991
1021
 
@@ -1032,28 +1062,28 @@ class MPConv(GraphConv):
1032
1062
 
1033
1063
  def aggregate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
1034
1064
  aggregate = tensor.aggregate('message', mode='mean')
1035
- previous = tensor.node['feature']
1065
+ feature = tensor.node['feature']
1036
1066
  if self.project_input_node_feature:
1037
- previous = self._previous_node_dense(previous)
1067
+ feature = self._previous_node_dense(feature)
1038
1068
  return tensor.update(
1039
1069
  {
1040
1070
  'node': {
1041
- 'feature': aggregate,
1042
- 'previous_feature': previous,
1071
+ 'aggregate': aggregate,
1072
+ 'feature': feature,
1043
1073
  }
1044
1074
  }
1045
1075
  )
1046
1076
 
1047
1077
  def update(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
1048
1078
  updated_node_feature, _ = self.update_fn(
1049
- inputs=tensor.node['feature'],
1050
- states=tensor.node['previous_feature']
1079
+ inputs=tensor.node['aggregate'],
1080
+ states=tensor.node['feature']
1051
1081
  )
1052
1082
  return tensor.update(
1053
1083
  {
1054
1084
  'node': {
1055
1085
  'feature': updated_node_feature,
1056
- 'previous_feature': None,
1086
+ 'aggregate': None,
1057
1087
  }
1058
1088
  }
1059
1089
  )
@@ -1124,12 +1154,13 @@ class GTConv3D(GTConv):
1124
1154
  attention *= keras.ops.expand_dims(distance, axis=-1)
1125
1155
  attention = keras.ops.expand_dims(attention, axis=2)
1126
1156
  value = keras.ops.expand_dims(value, axis=1)
1157
+
1158
+ message = ops.edge_weight(value, attention)
1127
1159
 
1128
1160
  return tensor.update(
1129
1161
  {
1130
1162
  'edge': {
1131
- 'message': value,
1132
- 'weight': attention,
1163
+ 'message': message,
1133
1164
  },
1134
1165
  }
1135
1166
  )
@@ -1144,12 +1175,10 @@ class GTConv3D(GTConv):
1144
1175
  return tensor.update(
1145
1176
  {
1146
1177
  'node': {
1147
- 'feature': node_feature,
1148
- 'residual': tensor.node['feature']
1178
+ 'aggregate': node_feature,
1149
1179
  },
1150
1180
  'edge': {
1151
1181
  'message': None,
1152
- 'weight': None,
1153
1182
  }
1154
1183
  }
1155
1184
  )
@@ -1202,16 +1231,16 @@ class EGConv3D(GraphConv):
1202
1231
  def __init__(
1203
1232
  self,
1204
1233
  units: int = 128,
1205
- activation: keras.layers.Activation | str | None = None,
1234
+ activation: keras.layers.Activation | str | None = 'silu',
1206
1235
  use_bias: bool = True,
1207
- normalization: bool = False,
1236
+ normalize: bool = False,
1208
1237
  **kwargs
1209
1238
  ) -> None:
1210
1239
  super().__init__(
1211
1240
  units=units,
1212
1241
  activation=activation,
1213
1242
  use_bias=use_bias,
1214
- normalization=normalization,
1243
+ normalize=normalize,
1215
1244
  **kwargs
1216
1245
  )
1217
1246
 
@@ -1222,31 +1251,52 @@ class EGConv3D(GraphConv):
1222
1251
  'which is required for Conv3D layers.'
1223
1252
  )
1224
1253
  self._has_edge_feature = 'feature' in spec.edge
1225
- self.message_fn = self.get_dense(self.units, activation=self._activation)
1226
- self.dense_position = self.get_dense(1)
1254
+ self._message_feedforward_intermediate = self.get_dense(
1255
+ self.units, activation=self._activation
1256
+ )
1257
+ self._message_feedforward_final = self.get_dense(
1258
+ self.units, activation=self._activation
1259
+ )
1260
+
1261
+ self._coord_feedforward_intermediate = self.get_dense(
1262
+ self.units, activation=self._activation
1263
+ )
1264
+ self._coord_feedforward_final = self.get_dense(
1265
+ 1, use_bias=False, activation='tanh'
1266
+ )
1227
1267
 
1228
1268
  has_overridden_update = self.__class__.update != EGConv3D.update
1229
1269
  if not has_overridden_update:
1230
- self.update_fn = self.get_dense(self.units, activation=self._activation)
1231
- self.output_dense = self.get_dense(self.units)
1270
+ self._feedforward_intermediate = self.get_dense(
1271
+ self.units, activation=self._activation
1272
+ )
1273
+ self._feedforward_output = self.get_dense(self.units)
1232
1274
 
1233
1275
  def message(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
1234
1276
  relative_node_coordinate = keras.ops.subtract(
1235
1277
  tensor.gather('coordinate', 'target'),
1236
1278
  tensor.gather('coordinate', 'source')
1237
1279
  )
1238
- euclidean_distance = keras.ops.sum(
1239
- keras.ops.square(
1240
- relative_node_coordinate
1241
- ),
1280
+ squared_distance = keras.ops.sum(
1281
+ keras.ops.square(relative_node_coordinate),
1242
1282
  axis=-1,
1243
1283
  keepdims=True
1244
1284
  )
1285
+
1286
+ # For numerical stability (i.e., to prevent NaN losses), this implementation of `EGConv3D`
1287
+ # either needs to apply a `tanh` activation to the output of `self._coord_feedforward_final`,
1288
+ # or normalize `relative_node_coordinate` as follows:
1289
+ #
1290
+ # norm = keras.ops.sqrt(squared_distance) + keras.backend.epsilon()
1291
+ # relative_node_coordinate /= norm
1292
+ #
1293
+ # For now, this implementation does the former.
1294
+
1245
1295
  feature = keras.ops.concatenate(
1246
1296
  [
1247
1297
  tensor.gather('feature', 'target'),
1248
1298
  tensor.gather('feature', 'source'),
1249
- euclidean_distance,
1299
+ squared_distance,
1250
1300
  ],
1251
1301
  axis=-1
1252
1302
  )
@@ -1258,10 +1308,15 @@ class EGConv3D(GraphConv):
1258
1308
  ],
1259
1309
  axis=-1
1260
1310
  )
1261
- message = self.message_fn(feature)
1311
+ message = self._message_feedforward_final(
1312
+ self._message_feedforward_intermediate(feature)
1313
+ )
1314
+
1262
1315
  relative_node_coordinate = keras.ops.multiply(
1263
- relative_node_coordinate,
1264
- self.dense_position(message)
1316
+ relative_node_coordinate,
1317
+ self._coord_feedforward_final(
1318
+ self._coord_feedforward_intermediate(message)
1319
+ )
1265
1320
  )
1266
1321
  return tensor.update(
1267
1322
  {
@@ -1273,27 +1328,26 @@ class EGConv3D(GraphConv):
1273
1328
  )
1274
1329
 
1275
1330
  def aggregate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
1276
- coefficient = keras.ops.bincount(
1277
- tensor.edge['source'],
1278
- minlength=tensor.num_nodes
1279
- )
1280
- coefficient = keras.ops.cast(
1281
- coefficient, tensor.node['coordinate'].dtype
1282
- )
1283
- coefficient = keras.ops.expand_dims(
1284
- keras.ops.divide_no_nan(1, coefficient), axis=1
1285
- )
1331
+ coordinate = tensor.node['coordinate']
1332
+ coordinate += tensor.aggregate('relative_node_coordinate', mode='mean')
1333
+
1334
+ # Original implementation seems to apply sum aggregation, which does not
1335
+ # seem to work well for this implementation of `EGConv3D`, as it causes
1336
+ # large output values and large initial losses. The magnitude of the
1337
+ # aggregated values of a sum aggregation depends on the number of
1338
+ # neighbors, which may be many and may differ from node to node (or
1339
+ # graph to graph). Therefore, a mean aggregation is performed
1340
+ # instead:
1341
+ aggregate = tensor.aggregate('message', mode='mean')
1286
1342
 
1287
- updated_coordinate = tensor.aggregate('relative_node_coordinate') * coefficient
1288
- updated_coordinate += tensor.node['coordinate']
1343
+ # Simply added to silence warning ('no gradients for variables ...')
1344
+ aggregate += (0.0 * keras.ops.sum(coordinate))
1289
1345
 
1290
- aggregate = tensor.aggregate('message', mode='mean')
1291
1346
  return tensor.update(
1292
1347
  {
1293
1348
  'node': {
1294
- 'feature': aggregate,
1295
- 'coordinate': updated_coordinate,
1296
- 'previous_feature': tensor.node['feature'],
1349
+ 'aggregate': aggregate,
1350
+ 'coordinate': coordinate,
1297
1351
  },
1298
1352
  'edge': {
1299
1353
  'message': None,
@@ -1303,21 +1357,21 @@ class EGConv3D(GraphConv):
1303
1357
  )
1304
1358
 
1305
1359
  def update(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
1306
- updated_node_feature = self.update_fn(
1307
- keras.ops.concatenate(
1308
- [
1309
- tensor.node['feature'],
1310
- tensor.node['previous_feature']
1311
- ],
1312
- axis=-1
1313
- )
1360
+ feature = keras.ops.concatenate(
1361
+ [
1362
+ tensor.node['aggregate'],
1363
+ tensor.node['feature']
1364
+ ],
1365
+ axis=-1
1366
+ )
1367
+ updated_node_feature = self._feedforward_output(
1368
+ self._feedforward_intermediate(feature)
1314
1369
  )
1315
- updated_node_feature = self.output_dense(updated_node_feature)
1316
1370
  return tensor.update(
1317
1371
  {
1318
1372
  'node': {
1319
1373
  'feature': updated_node_feature,
1320
- 'previous_feature': None,
1374
+ 'aggregate': None,
1321
1375
  },
1322
1376
  }
1323
1377
  )
@@ -1478,6 +1532,32 @@ class GraphNetwork(GraphLayer):
1478
1532
  return super().from_config(config)
1479
1533
 
1480
1534
 
1535
+ @keras.saving.register_keras_serializable(package='molcraft')
1536
+ class Extraction(GraphLayer):
1537
+
1538
+ def __init__(
1539
+ self,
1540
+ field: str,
1541
+ inner_field: str | None = None,
1542
+ **kwargs
1543
+ ) -> None:
1544
+ super().__init__(**kwargs)
1545
+ self.field = field
1546
+ self.inner_field = inner_field
1547
+
1548
+ def propagate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
1549
+ data = dict(getattr(tensor, self.field))
1550
+ if not self.inner_field:
1551
+ return data
1552
+ return data[self.inner_field]
1553
+
1554
+ def get_config(self):
1555
+ config = super().get_config()
1556
+ config['field'] = self.field
1557
+ config['inner_field'] = self.inner_field
1558
+ return config
1559
+
1560
+
1481
1561
  @keras.saving.register_keras_serializable(package='molcraft')
1482
1562
  class NodeEmbedding(GraphLayer):
1483
1563
 
@@ -1489,15 +1569,15 @@ class NodeEmbedding(GraphLayer):
1489
1569
  def __init__(
1490
1570
  self,
1491
1571
  dim: int = None,
1492
- normalization: bool = False,
1493
- embed_context: bool = True,
1572
+ normalize: bool = False,
1573
+ embed_context: bool = False,
1494
1574
  allow_reconstruction: bool = False,
1495
- allow_masking: bool = True,
1575
+ allow_masking: bool = False,
1496
1576
  **kwargs
1497
1577
  ) -> None:
1498
1578
  super().__init__(**kwargs)
1499
1579
  self.dim = dim
1500
- self._normalization = normalization
1580
+ self._normalize = normalize
1501
1581
  self._embed_context = embed_context
1502
1582
  self._masking_rate = None
1503
1583
  self._allow_masking = allow_masking
@@ -1517,13 +1597,11 @@ class NodeEmbedding(GraphLayer):
1517
1597
  self._super_feature = self.get_weight(shape=[self.dim], name='super_node_feature')
1518
1598
  if self._allow_masking:
1519
1599
  self._mask_feature = self.get_weight(shape=[self.dim], name='mask_node_feature')
1520
-
1521
1600
  if self._embed_context:
1522
- context_feature_dim = spec.context['feature'].shape[-1]
1523
1601
  self._context_dense = self.get_dense(self.dim)
1524
1602
 
1525
- if self._normalization:
1526
- if str(self._normalization).lower().startswith('batch'):
1603
+ if self._normalize:
1604
+ if str(self._normalize).lower().startswith('batch'):
1527
1605
  self._norm = keras.layers.BatchNormalization(
1528
1606
  name='output_batch_norm'
1529
1607
  )
@@ -1545,48 +1623,25 @@ class NodeEmbedding(GraphLayer):
1545
1623
  feature = ops.scatter_update(feature, tensor.node['super'], context_feature)
1546
1624
  tensor = tensor.update({'context': {'feature': None}})
1547
1625
 
1548
- if (
1549
- self._allow_masking and
1550
- self._masking_rate is not None and
1551
- self._masking_rate > 0
1552
- ):
1553
- random = keras.random.uniform(shape=[tensor.num_nodes])
1554
- mask = random <= self._masking_rate
1555
- if self._has_super:
1556
- mask = keras.ops.logical_and(
1557
- mask, keras.ops.logical_not(tensor.node['super'])
1558
- )
1559
- mask = keras.ops.expand_dims(mask, -1)
1626
+ apply_mask = (self._allow_masking and 'mask' in tensor.node)
1627
+ if apply_mask:
1628
+ mask = keras.ops.expand_dims(tensor.node['mask'], -1)
1560
1629
  feature = keras.ops.where(mask, self._mask_feature, feature)
1561
1630
  elif self._allow_masking:
1562
- # Slience warning of 'no gradients for variables'
1563
1631
  feature = feature + (self._mask_feature * 0.0)
1564
1632
 
1565
- if self._normalization:
1633
+ if self._normalize:
1566
1634
  feature = self._norm(feature)
1567
1635
 
1568
1636
  if not self._allow_reconstruction:
1569
1637
  return tensor.update({'node': {'feature': feature}})
1570
1638
  return tensor.update({'node': {'feature': feature, 'target_feature': feature}})
1571
-
1572
- @property
1573
- def masking_rate(self):
1574
- return self._masking_rate
1575
-
1576
- @masking_rate.setter
1577
- def masking_rate(self, rate: float):
1578
- if not self._allow_masking and rate is not None:
1579
- raise ValueError(
1580
- f'Cannot set `masking_rate` for layer {self} '
1581
- 'as `allow_masking` was set to `False`.'
1582
- )
1583
- self._masking_rate = float(rate)
1584
1639
 
1585
1640
  def get_config(self) -> dict:
1586
1641
  config = super().get_config()
1587
1642
  config.update({
1588
1643
  'dim': self.dim,
1589
- 'normalization': self._normalization,
1644
+ 'normalize': self._normalize,
1590
1645
  'embed_context': self._embed_context,
1591
1646
  'allow_masking': self._allow_masking,
1592
1647
  'allow_reconstruction': self._allow_reconstruction,
@@ -1605,13 +1660,13 @@ class EdgeEmbedding(GraphLayer):
1605
1660
  def __init__(
1606
1661
  self,
1607
1662
  dim: int = None,
1608
- normalization: bool = False,
1663
+ normalize: bool = False,
1609
1664
  allow_masking: bool = True,
1610
1665
  **kwargs
1611
1666
  ) -> None:
1612
1667
  super().__init__(**kwargs)
1613
1668
  self.dim = dim
1614
- self._normalization = normalization
1669
+ self._normalize = normalize
1615
1670
  self._masking_rate = None
1616
1671
  self._allow_masking = allow_masking
1617
1672
 
@@ -1622,13 +1677,16 @@ class EdgeEmbedding(GraphLayer):
1622
1677
  self._edge_dense = self.get_dense(self.dim)
1623
1678
 
1624
1679
  self._has_super = 'super' in spec.edge
1680
+ self._has_self_loop = 'self_loop' in spec.edge
1625
1681
  if self._has_super:
1626
1682
  self._super_feature = self.get_weight(shape=[self.dim], name='super_edge_feature')
1683
+ if self._has_self_loop:
1684
+ self._self_loop_feature = self.get_weight(shape=[self.dim], name='self_loop_edge_feature')
1627
1685
  if self._allow_masking:
1628
1686
  self._mask_feature = self.get_weight(shape=[self.dim], name='mask_edge_feature')
1629
1687
 
1630
- if self._normalization:
1631
- if str(self._normalization).lower().startswith('batch'):
1688
+ if self._normalize:
1689
+ if str(self._normalize).lower().startswith('batch'):
1632
1690
  self._norm = keras.layers.BatchNormalization(
1633
1691
  name='output_batch_norm'
1634
1692
  )
@@ -1641,10 +1699,13 @@ class EdgeEmbedding(GraphLayer):
1641
1699
  feature = self._edge_dense(tensor.edge['feature'])
1642
1700
 
1643
1701
  if self._has_super:
1644
- super_feature = self._super_feature
1645
1702
  super_mask = keras.ops.expand_dims(tensor.edge['super'], 1)
1646
- feature = keras.ops.where(super_mask, super_feature, feature)
1703
+ feature = keras.ops.where(super_mask, self._super_feature, feature)
1647
1704
 
1705
+ if self._has_self_loop:
1706
+ self_loop_mask = keras.ops.expand_dims(tensor.edge['self_loop'], 1)
1707
+ feature = keras.ops.where(self_loop_mask, self._self_loop_feature, feature)
1708
+
1648
1709
  if (
1649
1710
  self._allow_masking and
1650
1711
  self._masking_rate is not None and
@@ -1659,10 +1720,10 @@ class EdgeEmbedding(GraphLayer):
1659
1720
  mask = keras.ops.expand_dims(mask, -1)
1660
1721
  feature = keras.ops.where(mask, self._mask_feature, feature)
1661
1722
  elif self._allow_masking:
1662
- # Slience warning of 'no gradients for variables'
1663
- feature = feature + (self._mask_feature * 0.0)
1723
+ # Simply added to silence warning ('no gradients for variables ...')
1724
+ feature += (0.0 * self._mask_feature)
1664
1725
 
1665
- if self._normalization:
1726
+ if self._normalize:
1666
1727
  feature = self._norm(feature)
1667
1728
 
1668
1729
  return tensor.update({'edge': {'feature': feature, 'embedding': feature}})
@@ -1684,7 +1745,7 @@ class EdgeEmbedding(GraphLayer):
1684
1745
  config = super().get_config()
1685
1746
  config.update({
1686
1747
  'dim': self.dim,
1687
- 'normalization': self._normalization,
1748
+ 'normalize': self._normalize,
1688
1749
  'allow_masking': self._allow_masking
1689
1750
  })
1690
1751
  return config
@@ -1883,6 +1944,56 @@ class GaussianDistance(GraphLayer):
1883
1944
  return config
1884
1945
 
1885
1946
 
1947
+ @keras.saving.register_keras_serializable(package='molcraft')
1948
+ class GaussianParams(keras.layers.Dense):
1949
+ '''Gaussian parameters.
1950
+
1951
+ Computes loc and scale via a dense layer. Should be used
1952
+ as the last layer in a model and paired with `losses.GaussianNLL`.
1953
+
1954
+ The loc and scale parameters (resulting from this layer) are concatenated
1955
+ together along the last axis, resulting in a single output tensor.
1956
+
1957
+ Args:
1958
+ events (int):
1959
+ The number of events. If the model makes a single prediction per example,
1960
+ then the number of events should be 1. If the model makes multiple predictions
1961
+ per example, then the number of events should be greater than 1.
1962
+ Default to 1.
1963
+ kwargs:
1964
+ See `keras.layers.Dense` documentation. `activation` will be applied
1965
+ to `loc` only. `scale` is automatically softplus activated.
1966
+ '''
1967
+ def __init__(self, events: int = 1, **kwargs):
1968
+ units = kwargs.pop('units', None)
1969
+ activation = kwargs.pop('activation', None)
1970
+ if units:
1971
+ if units % 2 != 0:
1972
+ raise ValueError(
1973
+ '`units` needs to be divisible by 2 as `units` = 2 x `events`.'
1974
+ )
1975
+ else:
1976
+ units = int(events * 2)
1977
+ super().__init__(units=units, **kwargs)
1978
+ self.events = events
1979
+ self.loc_activation = keras.activations.get(activation)
1980
+
1981
+ def call(self, inputs, **kwargs):
1982
+ loc_and_scale = super().call(inputs, **kwargs)
1983
+ loc = loc_and_scale[..., :self.events]
1984
+ scale = loc_and_scale[..., self.events:]
1985
+ scale = keras.ops.softplus(scale) + keras.backend.epsilon()
1986
+ loc = self.loc_activation(loc)
1987
+ return keras.ops.concatenate([loc, scale], axis=-1)
1988
+
1989
+ def get_config(self):
1990
+ config = super().get_config()
1991
+ config['events'] = self.events
1992
+ config['units'] = None
1993
+ config['activation'] = keras.activations.serialize(self.loc_activation)
1994
+ return config
1995
+
1996
+
1886
1997
  def Input(spec: tensors.GraphTensor.Spec) -> dict:
1887
1998
  """Used to specify inputs to model.
1888
1999
 
@@ -1914,9 +2025,11 @@ def Input(spec: tensors.GraphTensor.Spec) -> dict:
1914
2025
  for outer_field, data in spec.__dict__.items():
1915
2026
  inputs[outer_field] = {}
1916
2027
  for inner_field, nested_spec in data.items():
1917
- if inner_field in ['label', 'weight']:
1918
- if outer_field == 'context':
1919
- continue
2028
+ if outer_field == 'context' and inner_field in ['label', 'weight']:
2029
+ # Remove context label and weight from the symbolic input
2030
+ # as a functional model is strict about what input can be passed.
2031
+ # (We want to train and predict with the model.)
2032
+ continue
1920
2033
  kwargs = {
1921
2034
  'shape': nested_spec.shape[1:],
1922
2035
  'dtype': nested_spec.dtype,
@@ -1941,23 +2054,6 @@ def warn(message: str) -> None:
1941
2054
  stacklevel=1
1942
2055
  )
1943
2056
 
1944
- def _match_functional_input(functional_input, inputs):
1945
- matching_inputs = {}
1946
- for outer_field, data in functional_input.items():
1947
- matching_inputs[outer_field] = {}
1948
- for inner_field, _ in data.items():
1949
- call_input = inputs[outer_field].pop(inner_field)
1950
- matching_inputs[outer_field][inner_field] = call_input
1951
- unmatching_inputs = inputs
1952
- return matching_inputs, unmatching_inputs
1953
-
1954
- def _add_left_out_inputs(outputs, inputs):
1955
- for outer_field, data in inputs.items():
1956
- for inner_field, value in data.items():
1957
- if inner_field in ['label', 'weight']:
1958
- outputs[outer_field][inner_field] = value
1959
- return outputs
1960
-
1961
2057
  def _serialize_spec(spec: tensors.GraphTensor.Spec) -> dict:
1962
2058
  serialized_spec = {}
1963
2059
  for outer_field, data in spec.__dict__.items():