molcraft 0.1.0a5__py3-none-any.whl → 0.1.0a6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of molcraft might be problematic.

molcraft/layers.py CHANGED
@@ -125,23 +125,45 @@ class GraphLayer(keras.layers.Layer):
125
125
  return tensors.to_dict(outputs)
126
126
  return outputs
127
127
 
128
- def __call__(self, inputs, **kwargs):
128
+ def __call__(
129
+ self,
130
+ graph: dict[str, dict[str, tf.Tensor]] | tensors.GraphTensor,
131
+ **kwargs
132
+ ) -> tf.Tensor | dict[str, dict[str, tf.Tensor]] | tensors.GraphTensor:
129
133
  if not self.built:
130
- spec = _spec_from_inputs(inputs)
134
+ spec = _spec_from_inputs(graph)
131
135
  self.build(spec)
132
- convert = isinstance(inputs, tensors.GraphTensor)
133
- if convert:
134
- inputs = tensors.to_dict(inputs)
136
+
137
+ is_graph_tensor = isinstance(graph, tensors.GraphTensor)
138
+ if is_graph_tensor:
139
+ graph = tensors.to_dict(graph)
140
+ else:
141
+ graph = {field: dict(data) for (field, data) in graph.items()}
142
+
135
143
  if isinstance(self, functional.Functional):
136
- inputs, left_out_inputs = _match_functional_input(self.input, inputs)
137
- outputs = super().__call__(inputs, **kwargs)
144
+ # As a functional model is strict about what input can
145
+ # be passed to it, we need to temporarily pop some of the
146
+ # input and add it afterwards.
147
+ label = graph['context'].pop('label', None)
148
+ weight = graph['context'].pop('weight', None)
149
+ tf.nest.assert_same_structure(self.input, graph)
150
+
151
+ outputs = super().__call__(graph, **kwargs)
152
+
138
153
  if not tensors.is_graph(outputs):
139
154
  return outputs
155
+
156
+ graph = outputs
140
157
  if isinstance(self, functional.Functional):
141
- outputs = _add_left_out_inputs(outputs, left_out_inputs)
142
- if convert:
143
- outputs = tensors.from_dict(outputs)
144
- return outputs
158
+ if label is not None:
159
+ graph['context']['label'] = label
160
+ if weight is not None:
161
+ graph['context']['weight'] = weight
162
+
163
+ if is_graph_tensor:
164
+ return tensors.from_dict(graph)
165
+
166
+ return graph
145
167
 
146
168
  def get_build_config(self) -> dict:
147
169
  if self._custom_build_config:
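A minimal, self-contained sketch of the label/weight handling introduced in `__call__` above: the inner dicts are shallow-copied, the context extras are popped before the strict functional call and reattached afterwards. The tensor values and the commented-out model call are illustrative only, not from the package.

    import tensorflow as tf

    # Hypothetical toy graph in the nested-dict form accepted by GraphLayer.__call__.
    graph = {
        'context': {
            'feature': tf.zeros([1, 4]),
            'label': tf.constant([1.0]),    # extras a strict functional model would reject
            'weight': tf.constant([0.5]),
        },
        'node': {'feature': tf.zeros([3, 8])},
        'edge': {'source': tf.constant([0, 1, 2]), 'target': tf.constant([1, 2, 0])},
    }

    # Shallow copy, exactly as in the diff, so popping below never mutates the caller's dict.
    graph = {field: dict(data) for field, data in graph.items()}
    label = graph['context'].pop('label', None)
    weight = graph['context'].pop('weight', None)
    # outputs = super().__call__(graph, **kwargs)   # structure now matches self.input
    # if label is not None:
    #     outputs['context']['label'] = label       # reattached to the graph output

The shallow copy is what lets `__call__` pop `label` and `weight` without side effects on the dictionary the caller passed in.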
@@ -256,10 +278,10 @@ class GraphConv(GraphLayer):
256
278
  Default to `None`.
257
279
  use_bias (bool):
258
280
  Whether bias should be used in dense layers. Default to `True`.
259
- normalization (bool, str):
281
+ normalize (bool, str):
260
282
  Whether `LayerNormalization` should be applied to the final node feature output.
261
283
  To use `BatchNormalization`, specify `batch_norm`. Default to `False`.
262
- skip_connection (bool, str):
284
+ skip_connect (bool, str):
263
285
  Whether node feature input should be added to the node feature output.
264
286
  If node feature input dim is not equal to `units` (node feature output dim),
265
287
  a projection layer will automatically project the residual before adding it
@@ -294,14 +316,14 @@ class GraphConv(GraphLayer):
294
316
  units: int = None,
295
317
  activation: str | keras.layers.Activation | None = None,
296
318
  use_bias: bool = True,
297
- normalization: bool | str = False,
298
- skip_connection: bool | str = True,
319
+ normalize: bool | str = False,
320
+ skip_connect: bool | str = True,
299
321
  **kwargs
300
322
  ) -> None:
301
323
  super().__init__(use_bias=use_bias, **kwargs)
302
324
  self._units = units
303
- self._normalization = normalization
304
- self._skip_connection = skip_connection
325
+ self._normalize = normalize
326
+ self._skip_connect = skip_connect
305
327
  self._activation = keras.activations.get(activation)
306
328
 
307
329
  def __init_subclass__(cls, **kwargs):
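For reference, a usage sketch of the renamed constructor arguments (`normalization` → `normalize`, `skip_connection` → `skip_connect`). Only the keyword names and accepted values come from this diff; the import path is an assumption.

    from molcraft import layers  # assumed import path

    conv = layers.GraphConv(
        units=128,
        activation='relu',
        normalize='batch_norm',   # False, True (LayerNormalization), or 'batch...' for BatchNormalization
        skip_connect='weighted',  # False, True, or a 'weight...' string for a learnable skip weight
    )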
@@ -319,36 +341,6 @@ class GraphConv(GraphLayer):
319
341
  def units(self):
320
342
  return self._units
321
343
 
322
- def propagate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
323
- """Forward pass.
324
-
325
- Invokes `message(graph)`, `aggregate(graph)` and `update(graph)` in sequence.
326
-
327
- Arguments:
328
- tensor:
329
- A `GraphTensor` instance.
330
- """
331
- if self._skip_connection:
332
- input_node_feature = tensor.node['feature']
333
- if self._project_input_node_feature:
334
- input_node_feature = self._residual_projection(input_node_feature)
335
-
336
- tensor = self.message(tensor)
337
- tensor = self.aggregate(tensor)
338
- tensor = self.update(tensor)
339
-
340
- updated_node_feature = tensor.node['feature']
341
-
342
- if self._skip_connection:
343
- if self._use_weighted_skip_connection:
344
- input_node_feature *= self._skip_connection_weight
345
- updated_node_feature += input_node_feature
346
-
347
- if self._normalization:
348
- updated_node_feature = self._output_norm(updated_node_feature)
349
-
350
- return tensor.update({'node': {'feature': updated_node_feature}})
351
-
352
344
  def build(self, spec: tensors.GraphTensor.Spec) -> None:
353
345
  if not self.units:
354
346
  raise ValueError(
@@ -356,11 +348,11 @@ class GraphConv(GraphLayer):
356
348
  )
357
349
  node_feature_dim = spec.node['feature'].shape[-1]
358
350
  self._project_input_node_feature = (
359
- self._skip_connection and (node_feature_dim != self.units)
351
+ self._skip_connect and (node_feature_dim != self.units)
360
352
  )
361
353
  if self._project_input_node_feature:
362
354
  warn(
363
- '`skip_connection` is set to `True`, but found incompatible dim '
355
+ '`skip_connect` is set to `True`, but found incompatible dim '
364
356
  'between input (node feature dim) and output (`self.units`). '
365
357
  'Automatically applying a projection layer to residual to '
366
358
  'match input and output. '
@@ -369,8 +361,8 @@ class GraphConv(GraphLayer):
369
361
  self.units, name='residual_projection'
370
362
  )
371
363
 
372
- skip_connection = str(self._skip_connection).lower()
373
- self._use_weighted_skip_connection = skip_connection.startswith('weight')
364
+ skip_connect = str(self._skip_connect).lower()
365
+ self._use_weighted_skip_connection = skip_connect.startswith('weight')
374
366
  if self._use_weighted_skip_connection:
375
367
  self._skip_connection_weight = self.add_weight(
376
368
  name='skip_connection_weight',
@@ -379,8 +371,8 @@ class GraphConv(GraphLayer):
379
371
  trainable=True,
380
372
  )
381
373
 
382
- if self._normalization:
383
- if str(self._normalization).lower().startswith('batch'):
374
+ if self._normalize:
375
+ if str(self._normalize).lower().startswith('batch'):
384
376
  self._output_norm = keras.layers.BatchNormalization(
385
377
  name='output_batch_norm'
386
378
  )
@@ -389,7 +381,7 @@ class GraphConv(GraphLayer):
389
381
  name='output_layer_norm'
390
382
  )
391
383
 
392
- self._has_edge_feature = 'edge' in spec.edge
384
+ self._has_edge_feature = 'feature' in spec.edge
393
385
 
394
386
  has_overridden_message = self.__class__.message != GraphConv.message
395
387
  if not has_overridden_message:
@@ -400,6 +392,50 @@ class GraphConv(GraphLayer):
400
392
  self._output_dense = self.get_dense(self.units)
401
393
  self._output_activation = self._activation
402
394
 
395
+ def propagate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
396
+ """Forward pass.
397
+
398
+ Invokes `message(graph)`, `aggregate(graph)` and `update(graph)` in sequence.
399
+
400
+ Arguments:
401
+ tensor:
402
+ A `GraphTensor` instance.
403
+ """
404
+ if self._skip_connect:
405
+ input_node_feature = tensor.node['feature']
406
+ if self._project_input_node_feature:
407
+ input_node_feature = self._residual_projection(input_node_feature)
408
+
409
+ message = self.message(tensor)
410
+ if not isinstance(message, tensors.GraphTensor):
411
+ message = tensor.update({'edge': {'message': message}})
412
+ elif not 'message' in message.edge:
413
+ raise ValueError('Could not find `message` in `edge` output.')
414
+
415
+ aggregate = self.aggregate(message)
416
+ if not isinstance(aggregate, tensors.GraphTensor):
417
+ aggregate = tensor.update({'node': {'aggregate': aggregate}})
418
+ elif not 'aggregate' in aggregate.node:
419
+ raise ValueError('Could not find `aggregate` in `node` output.')
420
+
421
+ update = self.update(aggregate)
422
+ if not isinstance(update, tensors.GraphTensor):
423
+ update = tensor.update({'node': {'feature': update}})
424
+ elif not 'feature' in update.node:
425
+ raise ValueError('Could not find `feature` in `node` output.')
426
+
427
+ updated_node_feature = update.node['feature']
428
+
429
+ if self._skip_connect:
430
+ if self._use_weighted_skip_connection:
431
+ input_node_feature *= self._skip_connection_weight
432
+ updated_node_feature += input_node_feature
433
+
434
+ if self._normalize:
435
+ updated_node_feature = self._output_norm(updated_node_feature)
436
+
437
+ return update.update({'node': {'feature': updated_node_feature}})
438
+
403
439
  def message(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
404
440
  """Compute messages.
405
441
 
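The reworked `propagate` above wraps plain-tensor returns from `message`, `aggregate` and `update` into `edge['message']`, `node['aggregate']` and `node['feature']` respectively, and validates those keys when a `GraphTensor` is returned instead. A hedged subclass sketch follows; the import paths, `get_dense`, `ops.gather` and `tensor.aggregate` usage are inferred from the surrounding diff rather than documented API.

    import keras
    from molcraft import layers, ops  # assumed import paths

    class MyConv(layers.GraphConv):
        """Hypothetical convolution relying on the relaxed return types."""

        def build(self, spec):
            super().build(spec)
            self._update_dense = self.get_dense(self.units)

        def message(self, tensor):
            # plain tensor return: propagate wraps it as edge['message']
            return ops.gather(tensor.node['feature'], tensor.edge['source'])

        def aggregate(self, tensor):
            # plain tensor return: propagate wraps it as node['aggregate']
            return tensor.aggregate('message', mode='mean')

        def update(self, tensor):
            # plain tensor return: propagate wraps it as node['feature']
            feature = keras.ops.concatenate(
                [tensor.node['aggregate'], tensor.node['feature']], axis=-1
            )
            return self._update_dense(feature)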
@@ -441,8 +477,7 @@ class GraphConv(GraphLayer):
441
477
  return tensor.update(
442
478
  {
443
479
  'node': {
444
- 'feature': aggregate,
445
- 'previous_feature': tensor.node['feature']
480
+ 'aggregate': aggregate,
446
481
  },
447
482
  'edge': {
448
483
  'message': None
@@ -460,23 +495,20 @@ class GraphConv(GraphLayer):
460
495
  A `GraphTensor` instance containing aggregated messages
461
496
  (updated node features).
462
497
  """
463
- if not 'previous_feature' in tensor.node:
464
- feature = tensor.node['feature']
465
- else:
466
- feature = keras.ops.concatenate(
467
- [
468
- tensor.node['feature'],
469
- tensor.node['previous_feature']
470
- ],
471
- axis=-1
472
- )
498
+ feature = keras.ops.concatenate(
499
+ [
500
+ tensor.node['aggregate'],
501
+ tensor.node['feature']
502
+ ],
503
+ axis=-1
504
+ )
473
505
  update = self._output_dense(feature)
474
506
  update = self._output_activation(update)
475
507
  return tensor.update(
476
508
  {
477
509
  'node': {
478
510
  'feature': update,
479
- 'previous_feature': None,
511
+ 'aggregate': None,
480
512
  }
481
513
  }
482
514
  )
@@ -486,8 +518,8 @@ class GraphConv(GraphLayer):
486
518
  config.update({
487
519
  'units': self.units,
488
520
  'activation': keras.activations.serialize(self._activation),
489
- 'normalization': self._normalization,
490
- 'skip_connection': self._skip_connection,
521
+ 'normalize': self._normalize,
522
+ 'skip_connect': self._skip_connect,
491
523
  })
492
524
  return config
493
525
 
@@ -530,14 +562,14 @@ class GIConv(GraphConv):
530
562
  units: int,
531
563
  activation: keras.layers.Activation | str | None = 'relu',
532
564
  use_bias: bool = True,
533
- normalization: bool = False,
565
+ normalize: bool = False,
534
566
  update_edge_feature: bool = True,
535
567
  **kwargs,
536
568
  ):
537
569
  super().__init__(
538
570
  units=units,
539
571
  activation=activation,
540
- normalization=normalization,
572
+ normalize=normalize,
541
573
  use_bias=use_bias,
542
574
  **kwargs
543
575
  )
@@ -599,7 +631,7 @@ class GIConv(GraphConv):
599
631
  return tensor.update(
600
632
  {
601
633
  'node': {
602
- 'feature': node_feature,
634
+ 'aggregate': node_feature,
603
635
  },
604
636
  'edge': {
605
637
  'message': None,
@@ -608,7 +640,7 @@ class GIConv(GraphConv):
608
640
  )
609
641
 
610
642
  def update(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
611
- node_feature = tensor.node['feature']
643
+ node_feature = tensor.node['aggregate']
612
644
  node_feature = self._feedforward_intermediate_dense(node_feature)
613
645
  node_feature = self._feedforward_activation(node_feature)
614
646
  node_feature = self._feedforward_output_dense(node_feature)
@@ -616,6 +648,7 @@ class GIConv(GraphConv):
616
648
  {
617
649
  'node': {
618
650
  'feature': node_feature,
651
+ 'aggregate': None,
619
652
  }
620
653
  }
621
654
  )
@@ -667,17 +700,16 @@ class GAConv(GraphConv):
667
700
  heads: int = 8,
668
701
  activation: keras.layers.Activation | str | None = "relu",
669
702
  use_bias: bool = True,
670
- normalization: bool = False,
703
+ normalize: bool = False,
671
704
  update_edge_feature: bool = True,
672
705
  attention_activation: keras.layers.Activation | str | None = "leaky_relu",
673
706
  **kwargs,
674
707
  ) -> None:
675
- kwargs['skip_connection'] = False
676
708
  super().__init__(
677
709
  units=units,
678
710
  activation=activation,
679
711
  use_bias=use_bias,
680
- normalization=normalization,
712
+ normalize=normalize,
681
713
  **kwargs
682
714
  )
683
715
  self._heads = heads
@@ -753,11 +785,11 @@ class GAConv(GraphConv):
753
785
  )
754
786
  node_feature = self._node_dense(tensor.node['feature'])
755
787
  message = ops.gather(node_feature, tensor.edge['source'])
788
+ message = ops.edge_weight(message, attention_score)
756
789
  return tensor.update(
757
790
  {
758
791
  'edge': {
759
792
  'message': message,
760
- 'weight': attention_score,
761
793
  'feature': edge_feature,
762
794
  }
763
795
  }
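Here (and in `GTConv`/`GTConv3D` below) the former `edge['weight']` field is folded directly into the message via `ops.edge_weight`. The diff does not show that helper's implementation; the sketch below only states the assumed behaviour, namely a per-edge, broadcastable scaling of messages by their attention scores.

    def edge_weight_sketch(message, weight):
        # Assumed behaviour of molcraft's ops.edge_weight (not shown in this diff):
        # scale each edge's message by its attention weight, broadcasting over
        # head/feature axes, before aggregation to target nodes.
        return message * weight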
@@ -770,24 +802,24 @@ class GAConv(GraphConv):
770
802
  return tensor.update(
771
803
  {
772
804
  'node': {
773
- 'feature': node_feature
805
+ 'aggregate': node_feature
774
806
  },
775
807
  'edge': {
776
808
  'message': None,
777
- 'weight': None,
778
809
  }
779
810
  }
780
811
  )
781
812
 
782
813
  def update(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
783
- node_feature = tensor.node['feature']
814
+ node_feature = tensor.node['aggregate']
784
815
  node_feature = self._feedforward_intermediate_dense(node_feature)
785
816
  node_feature = self._feedforward_activation(node_feature)
786
817
  node_feature = self._feedforward_output_dense(node_feature)
787
818
  return tensor.update(
788
819
  {
789
820
  'node': {
790
- 'feature': node_feature
821
+ 'feature': node_feature,
822
+ 'aggregate': None,
791
823
  }
792
824
  }
793
825
  )
@@ -842,7 +874,7 @@ class GTConv(GraphConv):
842
874
  heads: int = 8,
843
875
  activation: keras.layers.Activation | str | None = "relu",
844
876
  use_bias: bool = True,
845
- normalization: bool = False,
877
+ normalize: bool = False,
846
878
  attention_dropout: float = 0.0,
847
879
  **kwargs,
848
880
  ) -> None:
@@ -850,7 +882,7 @@ class GTConv(GraphConv):
850
882
  units=units,
851
883
  activation=activation,
852
884
  use_bias=use_bias,
853
- normalization=normalization,
885
+ normalize=normalize,
854
886
  **kwargs
855
887
  )
856
888
  self._heads = heads
@@ -901,7 +933,6 @@ class GTConv(GraphConv):
901
933
  }
902
934
  }
903
935
  )
904
-
905
936
  node_feature = tensor.node['feature']
906
937
 
907
938
  query = self._query_dense(node_feature)
@@ -918,12 +949,12 @@ class GTConv(GraphConv):
918
949
  attention_score += keras.ops.expand_dims(tensor.edge['bias'], axis=1)
919
950
  attention = ops.edge_softmax(attention_score, tensor.edge['target'])
920
951
  attention = self._softmax_dropout(attention)
952
+ message = ops.edge_weight(value, attention)
921
953
 
922
954
  return tensor.update(
923
955
  {
924
956
  'edge': {
925
- 'message': value,
926
- 'weight': attention,
957
+ 'message': message
927
958
  },
928
959
  }
929
960
  )
@@ -935,18 +966,16 @@ class GTConv(GraphConv):
935
966
  return tensor.update(
936
967
  {
937
968
  'node': {
938
- 'feature': node_feature,
939
- 'residual': tensor.node['feature']
969
+ 'aggregate': node_feature,
940
970
  },
941
971
  'edge': {
942
972
  'message': None,
943
- 'weight': None,
944
973
  }
945
974
  }
946
975
  )
947
976
 
948
977
  def update(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
949
- node_feature = tensor.node['feature']
978
+ node_feature = tensor.node['aggregate']
950
979
  node_feature = self._feedforward_intermediate_dense(node_feature)
951
980
  node_feature = self._feedforward_activation(node_feature)
952
981
  node_feature = self._feedforward_output_dense(node_feature)
@@ -954,6 +983,7 @@ class GTConv(GraphConv):
954
983
  {
955
984
  'node': {
956
985
  'feature': node_feature,
986
+ 'aggregate': None,
957
987
  },
958
988
  }
959
989
  )
@@ -978,14 +1008,14 @@ class MPConv(GraphConv):
978
1008
  units: int = 128,
979
1009
  activation: keras.layers.Activation | str | None = None,
980
1010
  use_bias: bool = True,
981
- normalization: bool = False,
1011
+ normalize: bool = False,
982
1012
  **kwargs
983
1013
  ) -> None:
984
1014
  super().__init__(
985
1015
  units=units,
986
1016
  activation=activation,
987
1017
  use_bias=use_bias,
988
- normalization=normalization,
1018
+ normalize=normalize,
989
1019
  **kwargs
990
1020
  )
991
1021
 
@@ -1032,28 +1062,28 @@ class MPConv(GraphConv):
1032
1062
 
1033
1063
  def aggregate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
1034
1064
  aggregate = tensor.aggregate('message', mode='mean')
1035
- previous = tensor.node['feature']
1065
+ feature = tensor.node['feature']
1036
1066
  if self.project_input_node_feature:
1037
- previous = self._previous_node_dense(previous)
1067
+ feature = self._previous_node_dense(feature)
1038
1068
  return tensor.update(
1039
1069
  {
1040
1070
  'node': {
1041
- 'feature': aggregate,
1042
- 'previous_feature': previous,
1071
+ 'aggregate': aggregate,
1072
+ 'feature': feature,
1043
1073
  }
1044
1074
  }
1045
1075
  )
1046
1076
 
1047
1077
  def update(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
1048
1078
  updated_node_feature, _ = self.update_fn(
1049
- inputs=tensor.node['feature'],
1050
- states=tensor.node['previous_feature']
1079
+ inputs=tensor.node['aggregate'],
1080
+ states=tensor.node['feature']
1051
1081
  )
1052
1082
  return tensor.update(
1053
1083
  {
1054
1084
  'node': {
1055
1085
  'feature': updated_node_feature,
1056
- 'previous_feature': None,
1086
+ 'aggregate': None,
1057
1087
  }
1058
1088
  }
1059
1089
  )
@@ -1124,12 +1154,13 @@ class GTConv3D(GTConv):
1124
1154
  attention *= keras.ops.expand_dims(distance, axis=-1)
1125
1155
  attention = keras.ops.expand_dims(attention, axis=2)
1126
1156
  value = keras.ops.expand_dims(value, axis=1)
1157
+
1158
+ message = ops.edge_weight(value, attention)
1127
1159
 
1128
1160
  return tensor.update(
1129
1161
  {
1130
1162
  'edge': {
1131
- 'message': value,
1132
- 'weight': attention,
1163
+ 'message': message,
1133
1164
  },
1134
1165
  }
1135
1166
  )
@@ -1144,12 +1175,10 @@ class GTConv3D(GTConv):
1144
1175
  return tensor.update(
1145
1176
  {
1146
1177
  'node': {
1147
- 'feature': node_feature,
1148
- 'residual': tensor.node['feature']
1178
+ 'aggregate': node_feature,
1149
1179
  },
1150
1180
  'edge': {
1151
1181
  'message': None,
1152
- 'weight': None,
1153
1182
  }
1154
1183
  }
1155
1184
  )
@@ -1204,14 +1233,14 @@ class EGConv3D(GraphConv):
1204
1233
  units: int = 128,
1205
1234
  activation: keras.layers.Activation | str | None = None,
1206
1235
  use_bias: bool = True,
1207
- normalization: bool = False,
1236
+ normalize: bool = False,
1208
1237
  **kwargs
1209
1238
  ) -> None:
1210
1239
  super().__init__(
1211
1240
  units=units,
1212
1241
  activation=activation,
1213
1242
  use_bias=use_bias,
1214
- normalization=normalization,
1243
+ normalize=normalize,
1215
1244
  **kwargs
1216
1245
  )
1217
1246
 
@@ -1223,7 +1252,7 @@ class EGConv3D(GraphConv):
1223
1252
  )
1224
1253
  self._has_edge_feature = 'feature' in spec.edge
1225
1254
  self.message_fn = self.get_dense(self.units, activation=self._activation)
1226
- self.dense_position = self.get_dense(1)
1255
+ self.dense_position = self.get_dense(1, use_bias=False, kernel_initializer='zeros')
1227
1256
 
1228
1257
  has_overridden_update = self.__class__.update != EGConv3D.update
1229
1258
  if not has_overridden_update:
@@ -1273,27 +1302,26 @@ class EGConv3D(GraphConv):
1273
1302
  )
1274
1303
 
1275
1304
  def aggregate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
1276
- coefficient = keras.ops.bincount(
1277
- tensor.edge['source'],
1278
- minlength=tensor.num_nodes
1279
- )
1280
- coefficient = keras.ops.cast(
1281
- coefficient, tensor.node['coordinate'].dtype
1282
- )
1283
- coefficient = keras.ops.expand_dims(
1284
- keras.ops.divide_no_nan(1, coefficient), axis=1
1285
- )
1286
-
1287
- updated_coordinate = tensor.aggregate('relative_node_coordinate') * coefficient
1305
+ # coefficient = keras.ops.bincount(
1306
+ # tensor.edge['source'],
1307
+ # minlength=tensor.num_nodes
1308
+ # )
1309
+ # coefficient = keras.ops.cast(
1310
+ # coefficient, tensor.node['coordinate'].dtype
1311
+ # )
1312
+ # coefficient = keras.ops.expand_dims(
1313
+ # keras.ops.divide_no_nan(1, coefficient), axis=1
1314
+ # )
1315
+
1316
+ updated_coordinate = tensor.aggregate('relative_node_coordinate', mode='mean')# * coefficient
1288
1317
  updated_coordinate += tensor.node['coordinate']
1289
1318
 
1290
1319
  aggregate = tensor.aggregate('message', mode='mean')
1291
1320
  return tensor.update(
1292
1321
  {
1293
1322
  'node': {
1294
- 'feature': aggregate,
1323
+ 'aggregate': aggregate,
1295
1324
  'coordinate': updated_coordinate,
1296
- 'previous_feature': tensor.node['feature'],
1297
1325
  },
1298
1326
  'edge': {
1299
1327
  'message': None,
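The commented-out block above is the manual 1/degree scaling that `mode='mean'` now replaces. A small self-contained sketch of why the two are equivalent, using `keras.ops` directly and made-up edge indices:

    import keras

    edge_source = keras.ops.convert_to_tensor([0, 0, 1])            # two edges on node 0, one on node 1
    relative = keras.ops.convert_to_tensor([[2.0], [4.0], [6.0]])   # per-edge quantities
    num_nodes = 3

    summed = keras.ops.segment_sum(relative, edge_source, num_segments=num_nodes)
    counts = keras.ops.cast(keras.ops.bincount(edge_source, minlength=num_nodes), 'float32')
    manual = summed * keras.ops.expand_dims(keras.ops.divide_no_nan(1.0, counts), axis=1)
    # manual == [[3.0], [6.0], [0.0]], i.e. a segment mean, which is what
    # tensor.aggregate(..., mode='mean') is assumed to compute directly.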
@@ -1306,8 +1334,8 @@ class EGConv3D(GraphConv):
1306
1334
  updated_node_feature = self.update_fn(
1307
1335
  keras.ops.concatenate(
1308
1336
  [
1309
- tensor.node['feature'],
1310
- tensor.node['previous_feature']
1337
+ tensor.node['aggregate'],
1338
+ tensor.node['feature']
1311
1339
  ],
1312
1340
  axis=-1
1313
1341
  )
@@ -1317,7 +1345,7 @@ class EGConv3D(GraphConv):
1317
1345
  {
1318
1346
  'node': {
1319
1347
  'feature': updated_node_feature,
1320
- 'previous_feature': None,
1348
+ 'aggregate': None,
1321
1349
  },
1322
1350
  }
1323
1351
  )
@@ -1478,6 +1506,32 @@ class GraphNetwork(GraphLayer):
1478
1506
  return super().from_config(config)
1479
1507
 
1480
1508
 
1509
+ @keras.saving.register_keras_serializable(package='molcraft')
1510
+ class Extraction(GraphLayer):
1511
+
1512
+ def __init__(
1513
+ self,
1514
+ field: str,
1515
+ inner_field: str | None = None,
1516
+ **kwargs
1517
+ ) -> None:
1518
+ super().__init__(**kwargs)
1519
+ self.field = field
1520
+ self.inner_field = inner_field
1521
+
1522
+ def propagate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
1523
+ data = dict(getattr(tensor, self.field))
1524
+ if not self.inner_field:
1525
+ return data
1526
+ return data[self.inner_field]
1527
+
1528
+ def get_config(self):
1529
+ config = super().get_config()
1530
+ config['field'] = self.field
1531
+ config['inner_field'] = self.inner_field
1532
+ return config
1533
+
1534
+
1481
1535
  @keras.saving.register_keras_serializable(package='molcraft')
1482
1536
  class NodeEmbedding(GraphLayer):
1483
1537
 
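A hedged usage sketch of the `Extraction` layer added above, which returns a field of the propagated graph rather than an updated graph; the import path and the downstream wiring are assumptions.

    from molcraft import layers  # assumed import path

    # Pull a single nested field (e.g. the context feature) out of a graph so it
    # can feed ordinary Keras layers downstream.
    readout = layers.Extraction(field='context', inner_field='feature')
    # context_feature = readout(graph)                      # plain tensor
    # node_data = layers.Extraction(field='node')(graph)    # whole inner dict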
@@ -1489,15 +1543,15 @@ class NodeEmbedding(GraphLayer):
1489
1543
  def __init__(
1490
1544
  self,
1491
1545
  dim: int = None,
1492
- normalization: bool = False,
1493
- embed_context: bool = True,
1546
+ normalize: bool = False,
1547
+ embed_context: bool = False,
1494
1548
  allow_reconstruction: bool = False,
1495
- allow_masking: bool = True,
1549
+ allow_masking: bool = False,
1496
1550
  **kwargs
1497
1551
  ) -> None:
1498
1552
  super().__init__(**kwargs)
1499
1553
  self.dim = dim
1500
- self._normalization = normalization
1554
+ self._normalize = normalize
1501
1555
  self._embed_context = embed_context
1502
1556
  self._masking_rate = None
1503
1557
  self._allow_masking = allow_masking
@@ -1517,13 +1571,11 @@ class NodeEmbedding(GraphLayer):
1517
1571
  self._super_feature = self.get_weight(shape=[self.dim], name='super_node_feature')
1518
1572
  if self._allow_masking:
1519
1573
  self._mask_feature = self.get_weight(shape=[self.dim], name='mask_node_feature')
1520
-
1521
1574
  if self._embed_context:
1522
- context_feature_dim = spec.context['feature'].shape[-1]
1523
1575
  self._context_dense = self.get_dense(self.dim)
1524
1576
 
1525
- if self._normalization:
1526
- if str(self._normalization).lower().startswith('batch'):
1577
+ if self._normalize:
1578
+ if str(self._normalize).lower().startswith('batch'):
1527
1579
  self._norm = keras.layers.BatchNormalization(
1528
1580
  name='output_batch_norm'
1529
1581
  )
@@ -1545,48 +1597,25 @@ class NodeEmbedding(GraphLayer):
1545
1597
  feature = ops.scatter_update(feature, tensor.node['super'], context_feature)
1546
1598
  tensor = tensor.update({'context': {'feature': None}})
1547
1599
 
1548
- if (
1549
- self._allow_masking and
1550
- self._masking_rate is not None and
1551
- self._masking_rate > 0
1552
- ):
1553
- random = keras.random.uniform(shape=[tensor.num_nodes])
1554
- mask = random <= self._masking_rate
1555
- if self._has_super:
1556
- mask = keras.ops.logical_and(
1557
- mask, keras.ops.logical_not(tensor.node['super'])
1558
- )
1559
- mask = keras.ops.expand_dims(mask, -1)
1600
+ apply_mask = (self._allow_masking and 'mask' in tensor.node)
1601
+ if apply_mask:
1602
+ mask = keras.ops.expand_dims(tensor.node['mask'], -1)
1560
1603
  feature = keras.ops.where(mask, self._mask_feature, feature)
1561
1604
  elif self._allow_masking:
1562
- # Slience warning of 'no gradients for variables'
1563
1605
  feature = feature + (self._mask_feature * 0.0)
1564
1606
 
1565
- if self._normalization:
1607
+ if self._normalize:
1566
1608
  feature = self._norm(feature)
1567
1609
 
1568
1610
  if not self._allow_reconstruction:
1569
1611
  return tensor.update({'node': {'feature': feature}})
1570
1612
  return tensor.update({'node': {'feature': feature, 'target_feature': feature}})
1571
-
1572
- @property
1573
- def masking_rate(self):
1574
- return self._masking_rate
1575
-
1576
- @masking_rate.setter
1577
- def masking_rate(self, rate: float):
1578
- if not self._allow_masking and rate is not None:
1579
- raise ValueError(
1580
- f'Cannot set `masking_rate` for layer {self} '
1581
- 'as `allow_masking` was set to `False`.'
1582
- )
1583
- self._masking_rate = float(rate)
1584
1613
 
1585
1614
  def get_config(self) -> dict:
1586
1615
  config = super().get_config()
1587
1616
  config.update({
1588
1617
  'dim': self.dim,
1589
- 'normalization': self._normalization,
1618
+ 'normalize': self._normalize,
1590
1619
  'embed_context': self._embed_context,
1591
1620
  'allow_masking': self._allow_masking,
1592
1621
  'allow_reconstruction': self._allow_reconstruction,
@@ -1605,13 +1634,13 @@ class EdgeEmbedding(GraphLayer):
1605
1634
  def __init__(
1606
1635
  self,
1607
1636
  dim: int = None,
1608
- normalization: bool = False,
1637
+ normalize: bool = False,
1609
1638
  allow_masking: bool = True,
1610
1639
  **kwargs
1611
1640
  ) -> None:
1612
1641
  super().__init__(**kwargs)
1613
1642
  self.dim = dim
1614
- self._normalization = normalization
1643
+ self._normalize = normalize
1615
1644
  self._masking_rate = None
1616
1645
  self._allow_masking = allow_masking
1617
1646
 
@@ -1622,13 +1651,16 @@ class EdgeEmbedding(GraphLayer):
1622
1651
  self._edge_dense = self.get_dense(self.dim)
1623
1652
 
1624
1653
  self._has_super = 'super' in spec.edge
1654
+ self._has_self_loop = 'self_loop' in spec.edge
1625
1655
  if self._has_super:
1626
1656
  self._super_feature = self.get_weight(shape=[self.dim], name='super_edge_feature')
1657
+ if self._has_self_loop:
1658
+ self._self_loop_feature = self.get_weight(shape=[self.dim], name='self_loop_edge_feature')
1627
1659
  if self._allow_masking:
1628
1660
  self._mask_feature = self.get_weight(shape=[self.dim], name='mask_edge_feature')
1629
1661
 
1630
- if self._normalization:
1631
- if str(self._normalization).lower().startswith('batch'):
1662
+ if self._normalize:
1663
+ if str(self._normalize).lower().startswith('batch'):
1632
1664
  self._norm = keras.layers.BatchNormalization(
1633
1665
  name='output_batch_norm'
1634
1666
  )
@@ -1641,10 +1673,13 @@ class EdgeEmbedding(GraphLayer):
1641
1673
  feature = self._edge_dense(tensor.edge['feature'])
1642
1674
 
1643
1675
  if self._has_super:
1644
- super_feature = self._super_feature
1645
1676
  super_mask = keras.ops.expand_dims(tensor.edge['super'], 1)
1646
- feature = keras.ops.where(super_mask, super_feature, feature)
1677
+ feature = keras.ops.where(super_mask, self._super_feature, feature)
1647
1678
 
1679
+ if self._has_self_loop:
1680
+ self_loop_mask = keras.ops.expand_dims(tensor.edge['self_loop'], 1)
1681
+ feature = keras.ops.where(self_loop_mask, self._self_loop_feature, feature)
1682
+
1648
1683
  if (
1649
1684
  self._allow_masking and
1650
1685
  self._masking_rate is not None and
@@ -1662,7 +1697,7 @@ class EdgeEmbedding(GraphLayer):
1662
1697
  # Silence warning of 'no gradients for variables'
1663
1698
  feature = feature + (self._mask_feature * 0.0)
1664
1699
 
1665
- if self._normalization:
1700
+ if self._normalize:
1666
1701
  feature = self._norm(feature)
1667
1702
 
1668
1703
  return tensor.update({'edge': {'feature': feature, 'embedding': feature}})
@@ -1684,7 +1719,7 @@ class EdgeEmbedding(GraphLayer):
1684
1719
  config = super().get_config()
1685
1720
  config.update({
1686
1721
  'dim': self.dim,
1687
- 'normalization': self._normalization,
1722
+ 'normalize': self._normalize,
1688
1723
  'allow_masking': self._allow_masking
1689
1724
  })
1690
1725
  return config
@@ -1883,6 +1918,56 @@ class GaussianDistance(GraphLayer):
1883
1918
  return config
1884
1919
 
1885
1920
 
1921
+ @keras.saving.register_keras_serializable(package='molcraft')
1922
+ class GaussianParams(keras.layers.Dense):
1923
+ '''Gaussian parameters.
1924
+
1925
+ Computes loc and scale via a dense layer. Should be implemented
1926
+ as the last layer in a model and paired with `losses.GaussianNLL`.
1927
+
1928
+ The loc and scale parameters (resulting from this layer) are concatenated
1929
+ together along the last axis, resulting in a single output tensor.
1930
+
1931
+ Args:
1932
+ events (int):
1933
+ The number of events. If the model makes a single prediction per example,
1934
+ then the number of events should be 1. If the model makes multiple predictions
1935
+ per example, then the number of events should be greater than 1.
1936
+ Default to 1.
1937
+ kwargs:
1938
+ See `keras.layers.Dense` documentation. `activation` will be applied
1939
+ to `loc` only. `scale` is automatically softplus activated.
1940
+ '''
1941
+ def __init__(self, events: int = 1, **kwargs):
1942
+ units = kwargs.pop('units', None)
1943
+ activation = kwargs.pop('activation', None)
1944
+ if units:
1945
+ if units % 2 != 0:
1946
+ raise ValueError(
1947
+ '`units` needs to be divisible by 2 as `units` = 2 x `events`.'
1948
+ )
1949
+ else:
1950
+ units = int(events * 2)
1951
+ super().__init__(units=units, **kwargs)
1952
+ self.events = events
1953
+ self.loc_activation = keras.activations.get(activation)
1954
+
1955
+ def call(self, inputs, **kwargs):
1956
+ loc_and_scale = super().call(inputs, **kwargs)
1957
+ loc = loc_and_scale[..., :self.events]
1958
+ scale = loc_and_scale[..., self.events:]
1959
+ scale = keras.ops.softplus(scale) + keras.backend.epsilon()
1960
+ loc = self.loc_activation(loc)
1961
+ return keras.ops.concatenate([loc, scale], axis=-1)
1962
+
1963
+ def get_config(self):
1964
+ config = super().get_config()
1965
+ config['events'] = self.events
1966
+ config['units'] = None
1967
+ config['activation'] = keras.activations.serialize(self.loc_activation)
1968
+ return config
1969
+
1970
+
1886
1971
  def Input(spec: tensors.GraphTensor.Spec) -> dict:
1887
1972
  """Used to specify inputs to model.
1888
1973
 
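A hedged sketch of how the new `GaussianParams` head is meant to be wired up, based only on its docstring above; `losses.GaussianNLL` is named there, but its exact signature and the import paths are assumptions.

    import keras
    from molcraft import layers, losses  # assumed import paths

    head = layers.GaussianParams(events=1)    # dense layer emitting 2 * events units
    # out = head(hidden)                      # (..., 2): loc, then softplus(scale) + epsilon
    # loc, scale = out[..., :1], out[..., 1:]
    # model.compile(optimizer='adam', loss=losses.GaussianNLL())  # pairing per the docstring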
@@ -1915,6 +2000,11 @@ def Input(spec: tensors.GraphTensor.Spec) -> dict:
1915
2000
  inputs[outer_field] = {}
1916
2001
  for inner_field, nested_spec in data.items():
1917
2002
  if inner_field in ['label', 'weight']:
2003
+ # Remove context label and weight from the symbolic input
2004
+ # as a functional model is strict for what input can be passed.
2005
+ # We want to be able to pass a graph with or without labels
2006
+ # and sample weights. The __call__ method of the `GraphModel`
2007
+ # temporarily pops label and weight to avoid errors.
1918
2008
  if outer_field == 'context':
1919
2009
  continue
1920
2010
  kwargs = {
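A hedged sketch of the behaviour the comment above describes: the symbolic inputs built by `Input` omit `context['label']` and `context['weight']`, while a concrete graph passed at call time may still carry them because `GraphModel.__call__` pops and restores them. The spec construction and model wiring here are illustrative assumptions.

    from molcraft import layers  # assumed import path

    # inputs = layers.Input(graph.spec)          # nested dict of keras.Input tensors
    # sorted(inputs['context'])                  # e.g. ['feature'] -- no 'label'/'weight'
    # model = layers.GraphModel(inputs, outputs) # hypothetical functional wrapper
    # model(graph_with_labels)                   # labels/weights survive the call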
@@ -1941,23 +2031,6 @@ def warn(message: str) -> None:
1941
2031
  stacklevel=1
1942
2032
  )
1943
2033
 
1944
- def _match_functional_input(functional_input, inputs):
1945
- matching_inputs = {}
1946
- for outer_field, data in functional_input.items():
1947
- matching_inputs[outer_field] = {}
1948
- for inner_field, _ in data.items():
1949
- call_input = inputs[outer_field].pop(inner_field)
1950
- matching_inputs[outer_field][inner_field] = call_input
1951
- unmatching_inputs = inputs
1952
- return matching_inputs, unmatching_inputs
1953
-
1954
- def _add_left_out_inputs(outputs, inputs):
1955
- for outer_field, data in inputs.items():
1956
- for inner_field, value in data.items():
1957
- if inner_field in ['label', 'weight']:
1958
- outputs[outer_field][inner_field] = value
1959
- return outputs
1960
-
1961
2034
  def _serialize_spec(spec: tensors.GraphTensor.Spec) -> dict:
1962
2035
  serialized_spec = {}
1963
2036
  for outer_field, data in spec.__dict__.items():