molcraft 0.1.0a2__py3-none-any.whl → 0.1.0a4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of molcraft might be problematic.

molcraft/layers.py CHANGED
@@ -1,7 +1,7 @@
-import abc
 import keras
 import tensorflow as tf
 import warnings
+import functools
 from keras.src.models import functional
 
 from molcraft import tensors
@@ -12,11 +12,39 @@ from molcraft import ops
 class GraphLayer(keras.layers.Layer):
     """Base graph layer.
 
-    Currently, the `GraphLayer` only supports `GraphTensor` input.
+    Subclasses must implement a forward pass via **propagate(graph)**.
+
+    Subclasses may create dense layers and weights in **build(graph_spec)**.
+
+    Note: `GraphLayer` currently only supports `GraphTensor` input.
 
-    The list of arguments are only relevant if the derived layer
+    The list of arguments below is only relevant if the derived layer
     invokes `get_dense_kwargs`, `get_dense` or `get_einsum_dense`.
 
+    Arguments:
+        use_bias (bool):
+            Whether bias should be used in dense layers. Defaults to `True`.
+        kernel_initializer (keras.initializers.Initializer, str):
+            Initializer for the kernel weight matrix of the dense layers.
+            Defaults to `glorot_uniform`.
+        bias_initializer (keras.initializers.Initializer, str):
+            Initializer for the bias weight vector of the dense layers.
+            Defaults to `zeros`.
+        kernel_regularizer (keras.regularizers.Regularizer, None):
+            Regularizer function applied to the kernel weight matrix.
+            Defaults to `None`.
+        bias_regularizer (keras.regularizers.Regularizer, None):
+            Regularizer function applied to the bias weight vector.
+            Defaults to `None`.
+        activity_regularizer (keras.regularizers.Regularizer, None):
+            Regularizer function applied to the output of the dense layers.
+            Defaults to `None`.
+        kernel_constraint (keras.constraints.Constraint, None):
+            Constraint function applied to the kernel weight matrix.
+            Defaults to `None`.
+        bias_constraint (keras.constraints.Constraint, None):
+            Constraint function applied to the bias weight vector.
+            Defaults to `None`.
     """
 
     def __init__(
@@ -31,73 +59,61 @@ class GraphLayer(keras.layers.Layer):
         bias_constraint: keras.constraints.Constraint | None = None,
         **kwargs,
     ) -> None:
-        super().__init__(activity_regularizer=activity_regularizer, **kwargs)
+        super().__init__(**kwargs)
         self._use_bias = use_bias
         self._kernel_initializer = keras.initializers.get(kernel_initializer)
         self._bias_initializer = keras.initializers.get(bias_initializer)
         self._kernel_regularizer = keras.regularizers.get(kernel_regularizer)
         self._bias_regularizer = keras.regularizers.get(bias_regularizer)
+        self._activity_regularizer = keras.regularizers.get(activity_regularizer)
         self._kernel_constraint = keras.constraints.get(kernel_constraint)
         self._bias_constraint = keras.constraints.get(bias_constraint)
+        self._custom_build_config = {}
         self.built = False
-        # TODO: Add warning if build is implemented in subclass
-        # TODO: Add warning if call is implemented in subclass
+
+    def __init_subclass__(cls, **kwargs):
+        super().__init_subclass__(**kwargs)
+        subclass_build = cls.build
+
+        @functools.wraps(subclass_build)
+        def build_wrapper(self: GraphLayer, spec: tensors.GraphTensor.Spec | None):
+            GraphLayer.build(self, spec)
+            subclass_build(self, spec)
+            if not self.built and isinstance(self, keras.Model):
+                symbolic_inputs = Input(spec)
+                self.built = True
+                self(symbolic_inputs)
+
+        cls.build = build_wrapper
 
     def propagate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
-        """Calls the layer.
+        """Forward pass.
 
-        Needs to be implemented by subclass.
+        Must be implemented by subclass.
 
-        Args:
+        Arguments:
             tensor:
                 A `GraphTensor` instance.
         """
-        raise NotImplementedError('`propagate` needs to be implemented.')
+        raise NotImplementedError(
+            'The forward pass of the layer is not implemented. '
+            'Please implement `propagate`.'
+        )
 
-    def build_from_spec(self, spec: tensors.GraphTensor.Spec) -> None:
+    def build(self, tensor_spec: tensors.GraphTensor.Spec) -> None:
         """Builds the layer.
 
         May use built-in methods such as `get_weight`, `get_dense` and `get_einsum_dense`.
 
-        Optionally implemented by subclass. If implemented, it is recommended to
-        build the sub-layers via `build([None, input_dim])`. If sub-layers are not
-        built, symbolic input will be passed through the layer to build it.
+        Optionally implemented by subclass.
 
-        Args:
-            spec:
-                A `GraphTensor.Spec` instance, corresponding to the input `GraphTensor`
-                of the `propagate` method.
+        Arguments:
+            tensor_spec:
+                A `GraphTensor.Spec` instance corresponding to the `GraphTensor`
+                passed to `propagate`.
         """
-
-    def build(self, spec: tensors.GraphTensor.Spec) -> None:
-
-        self._custom_build_config = {'spec': _serialize_spec(spec)}
-
-        invoke_build_from_spec = (
-            GraphLayer.build_from_spec != self.__class__.build_from_spec
-        )
-        if invoke_build_from_spec:
-            self.build_from_spec(spec)
-            self.built = True
-
-        if not self.built:
-            # Automatically build layer or model by calling it on symbolic inputs
-            self.built = True
-            symbolic_inputs = Input(spec)
-            self(symbolic_inputs)
-
-    def get_build_config(self) -> dict:
-        if not hasattr(self, '_custom_build_config'):
-            return super().get_build_config()
-        return self._custom_build_config
-
-    def build_from_config(self, config: dict) -> None:
-        use_custom_build_from_config = ('spec' in config)
-        if not use_custom_build_from_config:
-            super().build_from_config(config)
-        else:
-            spec = _deserialize_spec(config['spec'])
-            self.build(spec)
+        if isinstance(tensor_spec, tensors.GraphTensor.Spec):
+            self._custom_build_config['spec'] = _serialize_spec(tensor_spec)
 
     def call(
         self,
@@ -127,6 +143,19 @@ class GraphLayer(keras.layers.Layer):
         outputs = tensors.from_dict(outputs)
         return outputs
 
+    def get_build_config(self) -> dict:
+        if self._custom_build_config:
+            return self._custom_build_config
+        return super().get_build_config()
+
+    def build_from_config(self, config: dict) -> None:
+        serialized_spec = config.get('spec')
+        if serialized_spec is not None:
+            spec = _deserialize_spec(serialized_spec)
+            self.build(spec)
+        else:
+            super().build_from_config(config)
+
     def get_weight(
         self,
         shape: tf.TensorShape,
@@ -168,7 +197,7 @@ class GraphLayer(keras.layers.Layer):
             use_bias=self._use_bias,
             kernel_regularizer=self._kernel_regularizer,
             bias_regularizer=self._bias_regularizer,
-            activity_regularizer=self.activity_regularizer,
+            activity_regularizer=self._activity_regularizer,
             kernel_constraint=self._kernel_constraint,
             bias_constraint=self._bias_constraint,
         )
@@ -194,52 +223,137 @@ class GraphLayer(keras.layers.Layer):
                 keras.regularizers.serialize(self._kernel_regularizer),
             "bias_regularizer":
                 keras.regularizers.serialize(self._bias_regularizer),
+            "activity_regularizer":
+                keras.regularizers.serialize(self._activity_regularizer),
             "kernel_constraint":
                 keras.constraints.serialize(self._kernel_constraint),
             "bias_constraint":
                 keras.constraints.serialize(self._bias_constraint),
         })
         return config
-
+
 
 @keras.saving.register_keras_serializable(package='molcraft')
 class GraphConv(GraphLayer):
 
     """Base graph neural network layer.
 
-    For normalization and skip connection to work, the `GraphConv` subclass
-    requires the (node feature) output of `aggregate` and `update` to have a
-    dimension of `self.units`, respectively.
-
-    Args:
-        units:
-            The number of units.
-        normalize:
-            Whether `LayerNormalization` should be applied to the (node feature) output
-            of the `aggregate` step. While normalization is recommended, it is not used
-            by default.
-        skip_connection:
-            Whether (node feature) input should be added to the (node feature) output.
-            If (node feature) input dim is not equal to `units`, a projection layer will
-            automatically project the residual before adding it to the output. While skip
-            connection is recommended, it is not used by default.
-        kwargs:
-            See arguments of `GraphLayer`.
+    This layer implements the three basic steps of a graph neural network layer, each of which
+    can (optionally) be overridden by the `GraphConv` subclass:
+
+    1. **message(graph)**, which computes the *messages* to be passed to target nodes;
+    2. **aggregate(graph)**, which *aggregates* messages to target nodes;
+    3. **update(graph)**, which further *updates* (target) nodes.
+
+    Note: for the skip connection to work, the `GraphConv` subclass requires the final
+    node feature output dimension to be equal to `units`.
+
+    Arguments:
+        units (int):
+            Dimensionality of the output space.
+        activation (keras.layers.Activation, str, None):
+            Activation function to use. If not specified, a linear activation (a(x) = x) is used.
+            Defaults to `None`.
+        use_bias (bool):
+            Whether bias should be used in dense layers. Defaults to `True`.
+        normalization (bool, str):
+            Whether `LayerNormalization` should be applied to the final node feature output.
+            To use `BatchNormalization`, specify `batch_norm`. Defaults to `False`.
+        skip_connection (bool, str):
+            Whether node feature input should be added to the node feature output.
+            If node feature input dim is not equal to `units` (node feature output dim),
+            a projection layer will automatically project the residual before adding it
+            to the output. To use a weighted skip connection, specify `weighted`.
+            The weight multiplied with the skip connection is a learnable scalar.
+            Defaults to `True`.
+        kernel_initializer (keras.initializers.Initializer, str):
+            Initializer for the kernel weight matrix of the dense layers.
+            Defaults to `glorot_uniform`.
+        bias_initializer (keras.initializers.Initializer, str):
+            Initializer for the bias weight vector of the dense layers.
+            Defaults to `zeros`.
+        kernel_regularizer (keras.regularizers.Regularizer, None):
+            Regularizer function applied to the kernel weight matrix.
+            Defaults to `None`.
+        bias_regularizer (keras.regularizers.Regularizer, None):
+            Regularizer function applied to the bias weight vector.
+            Defaults to `None`.
+        activity_regularizer (keras.regularizers.Regularizer, None):
+            Regularizer function applied to the output of the dense layers.
+            Defaults to `None`.
+        kernel_constraint (keras.constraints.Constraint, None):
+            Constraint function applied to the kernel weight matrix.
+            Defaults to `None`.
+        bias_constraint (keras.constraints.Constraint, None):
+            Constraint function applied to the bias weight vector.
+            Defaults to `None`.
     """
 
     def __init__(
         self,
-        units: int,
-        normalize: bool = False,
-        skip_connection: bool = False,
+        units: int | None = None,
+        activation: str | keras.layers.Activation | None = None,
+        use_bias: bool = True,
+        normalization: bool | str = False,
+        skip_connection: bool | str = True,
         **kwargs
     ) -> None:
-        super().__init__(**kwargs)
-        self.units = units
-        self._normalize_aggregate = normalize
+        super().__init__(use_bias=use_bias, **kwargs)
+        self._units = units
+        self._normalization = normalization
         self._skip_connection = skip_connection
+        self._activation = keras.activations.get(activation)
+
+    def __init_subclass__(cls, **kwargs):
+        super().__init_subclass__(**kwargs)
+        subclass_build = cls.build
+
+        @functools.wraps(subclass_build)
+        def build_wrapper(self, spec):
+            GraphConv.build(self, spec)
+            return subclass_build(self, spec)
+
+        cls.build = build_wrapper
+
+    @property
+    def units(self):
+        return self._units
+
+    def propagate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
+        """Forward pass.
+
+        Invokes `message(graph)`, `aggregate(graph)` and `update(graph)` in sequence.
+
+        Arguments:
+            tensor:
+                A `GraphTensor` instance.
+        """
+        if self._skip_connection:
+            input_node_feature = tensor.node['feature']
+            if self._project_input_node_feature:
+                input_node_feature = self._residual_projection(input_node_feature)
+
+        tensor = self.message(tensor)
+        tensor = self.aggregate(tensor)
+        tensor = self.update(tensor)
+
+        updated_node_feature = tensor.node['feature']
+
+        if self._skip_connection:
+            if self._use_weighted_skip_connection:
+                input_node_feature *= self._skip_connection_weight
+            updated_node_feature += input_node_feature
+
+        if self._normalization:
+            updated_node_feature = self._output_norm(updated_node_feature)
+
+        return tensor.update({'node': {'feature': updated_node_feature}})
 
     def build(self, spec: tensors.GraphTensor.Spec) -> None:
+        if not self.units:
+            raise ValueError(
+                f'`self.units` needs to be a positive integer. Found: {self.units}.'
+            )
         node_feature_dim = spec.node['feature'].shape[-1]
         self._project_input_node_feature = (
             self._skip_connection and (node_feature_dim != self.units)
@@ -254,81 +368,115 @@ class GraphConv(GraphLayer):
             self._residual_projection = self.get_dense(
                 self.units, name='residual_projection'
             )
-        if self._normalize_aggregate:
-            self._aggregation_norm = keras.layers.LayerNormalization(
-                name='aggregation_normalizer'
+
+        skip_connection = str(self._skip_connection).lower()
+        self._use_weighted_skip_connection = skip_connection.startswith('weight')
+        if self._use_weighted_skip_connection:
+            self._skip_connection_weight = self.add_weight(
+                name='skip_connection_weight',
+                shape=(),
+                initializer='ones',
+                trainable=True,
             )
-            self._aggregation_norm.build([None, self.units])
 
-        super().build(spec)
+        if self._normalization:
+            if str(self._normalization).lower().startswith('batch'):
+                self._output_norm = keras.layers.BatchNormalization(
+                    name='output_batch_norm'
+                )
+            else:
+                self._output_norm = keras.layers.LayerNormalization(
+                    name='output_layer_norm'
+                )
+
+        self._has_edge_feature = 'feature' in spec.edge
+
+        has_overridden_message = self.__class__.message != GraphConv.message
+        if not has_overridden_message:
+            self._message_dense = self.get_dense(self.units)
+
+        has_overridden_update = self.__class__.update != GraphConv.update
+        if not has_overridden_update:
+            self._output_dense = self.get_dense(self.units)
+            self._output_activation = self._activation
 
-    @abc.abstractmethod
     def message(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
         """Compute messages.
 
-        This method needs to be implemented by subclass.
+        This method may be overridden by subclass.
 
-        Args:
+        Arguments:
             tensor:
                 The inputted `GraphTensor` instance.
         """
-
-    @abc.abstractmethod
+        if not self._has_edge_feature:
+            message_feature = tensor.gather('feature', 'source')
+        else:
+            message_feature = keras.ops.concatenate(
+                [
+                    tensor.gather('feature', 'source'),
+                    tensor.edge['feature']
+                ],
+                axis=-1
+            )
+        message = self._message_dense(message_feature)
+        return tensor.update(
+            {
+                'edge': {
+                    'message': message
+                }
+            }
+        )
+
     def aggregate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
         """Aggregates messages.
 
-        This method needs to be implemented by subclass.
+        This method may be overridden by subclass.
 
-        Args:
+        Arguments:
             tensor:
                 A `GraphTensor` instance containing a message.
         """
+        aggregate = tensor.aggregate('message', mode='mean')
+        return tensor.update(
+            {
+                'node': {
+                    'feature': aggregate,
+                    'previous_feature': tensor.node['feature']
+                },
+                'edge': {
+                    'message': None
+                }
+            }
+        )
 
-    @abc.abstractmethod
     def update(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
         """Updates nodes.
 
-        This method needs to be implemented by subclass.
+        This method may be overridden by subclass.
 
-        Args:
+        Arguments:
             tensor:
                 A `GraphTensor` instance containing aggregated messages
                 (updated node features).
         """
-
-    def propagate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
-        """Calls the layer.
-
-        The `GraphConv` layer invokes `message`, `aggregate` and `update`
-        in sequence.
-
-        Args:
-            tensor:
-                A `GraphTensor` instance.
-        """
-
-        if self._skip_connection:
-            input_node_feature = tensor.node['feature']
-            if self._project_input_node_feature:
-                input_node_feature = self._residual_projection(input_node_feature)
-
-        tensor = self.message(tensor)
-        tensor = self.aggregate(tensor)
-
-        if self._normalize_aggregate:
-            normalized_node_feature = self._aggregation_norm(tensor.node['feature'])
-            tensor = tensor.update({'node': {'feature': normalized_node_feature}})
-
-        tensor = self.update(tensor)
-
-        if not self._skip_connection:
-            return tensor
-
-        updated_node_feature = tensor.node['feature']
+        if 'previous_feature' not in tensor.node:
+            feature = tensor.node['feature']
+        else:
+            feature = keras.ops.concatenate(
+                [
+                    tensor.node['feature'],
+                    tensor.node['previous_feature']
+                ],
+                axis=-1
+            )
+        update = self._output_dense(feature)
+        update = self._output_activation(update)
         return tensor.update(
             {
                 'node': {
-                    'feature': updated_node_feature + input_node_feature
+                    'feature': update,
+                    'previous_feature': None,
                 }
             }
         )
@@ -337,16 +485,44 @@ class GraphConv(GraphLayer):
         config = super().get_config()
         config.update({
             'units': self.units,
-            'normalize': self._normalize_aggregate,
+            'activation': keras.activations.serialize(self._activation),
+            'normalization': self._normalization,
             'skip_connection': self._skip_connection,
         })
         return config
 
 
 @keras.saving.register_keras_serializable(package='molcraft')
-class GINConv(GraphConv):
+class GIConv(GraphConv):
 
     """Graph isomorphism network layer.
+
+    >>> graph = molcraft.tensors.GraphTensor(
+    ...     context={
+    ...         'size': [2]
+    ...     },
+    ...     node={
+    ...         'feature': [[1.], [2.]]
+    ...     },
+    ...     edge={
+    ...         'source': [0, 1],
+    ...         'target': [1, 0],
+    ...     }
+    ... )
+    >>> conv = molcraft.layers.GIConv(units=4)
+    >>> conv(graph)
+    GraphTensor(
+        context={
+            'size': <tf.Tensor: shape=[1], dtype=int32>
+        },
+        node={
+            'feature': <tf.Tensor: shape=[2, 4], dtype=float32>
+        },
+        edge={
+            'source': <tf.Tensor: shape=[2], dtype=int32>,
+            'target': <tf.Tensor: shape=[2], dtype=int32>
+        }
+    )
     """
 
     def __init__(
@@ -354,24 +530,20 @@ class GINConv(GraphConv):
         units: int,
         activation: keras.layers.Activation | str | None = 'relu',
         use_bias: bool = True,
-        normalize: bool = True,
-        dropout: float = 0.0,
+        normalization: bool = False,
         update_edge_feature: bool = True,
         **kwargs,
     ):
         super().__init__(
             units=units,
-            normalize=normalize,
+            activation=activation,
+            normalization=normalization,
             use_bias=use_bias,
             **kwargs
         )
-        self._activation = keras.activations.get(activation)
-        self._dropout = dropout
         self._update_edge_feature = update_edge_feature
 
-    def build_from_spec(self, spec: tensors.GraphTensor.Spec) -> None:
-        """Builds the layer.
-        """
+    def build(self, spec: tensors.GraphTensor.Spec) -> None:
         node_feature_dim = spec.node['feature'].shape[-1]
 
         self.epsilon = self.add_weight(
@@ -381,7 +553,8 @@ class GINConv(GraphConv):
             trainable=True,
         )
 
-        if 'feature' in spec.edge:
+        self._has_edge_feature = 'feature' in spec.edge
+        if self._has_edge_feature:
             edge_feature_dim = spec.edge['feature'].shape[-1]
 
             if not self._update_edge_feature:
@@ -395,31 +568,21 @@ class GINConv(GraphConv):
         if self._update_edge_feature:
             self._edge_dense = self.get_dense(node_feature_dim)
-            self._edge_dense.build([None, edge_feature_dim])
         else:
             self._update_edge_feature = False
-
-        self._feedforward_intermediate_dense = self.get_dense(self.units)
-        self._feedforward_intermediate_dense.build([None, node_feature_dim])
 
-        has_overridden_update = self.__class__.update != GINConv.update
+        has_overridden_update = self.__class__.update != GIConv.update
         if not has_overridden_update:
-            # Use default feedforward network
-
-            self._feedforward_dropout = keras.layers.Dropout(self._dropout)
+            self._feedforward_intermediate_dense = self.get_dense(self.units)
             self._feedforward_activation = self._activation
-
             self._feedforward_output_dense = self.get_dense(self.units)
-            self._feedforward_output_dense.build([None, self.units])
-
+
     def message(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
-        """Computes messages.
-        """
         message = tensor.gather('feature', 'source')
         edge_feature = tensor.edge.get('feature')
         if self._update_edge_feature:
             edge_feature = self._edge_dense(edge_feature)
-        if edge_feature is not None:
+        if self._has_edge_feature:
             message += edge_feature
         return tensor.update(
             {
@@ -431,12 +594,8 @@ class GINConv(GraphConv):
         )
 
     def aggregate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
-        """Aggregates messages.
-        """
-        node_feature = tensor.aggregate('message')
+        node_feature = tensor.aggregate('message', mode='mean')
         node_feature += (1 + self.epsilon) * tensor.node['feature']
-        node_feature = self._feedforward_intermediate_dense(node_feature)
-        node_feature = self._feedforward_activation(node_feature)
         return tensor.update(
             {
                 'node': {
@@ -449,10 +608,9 @@ class GINConv(GraphConv):
         )
 
     def update(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
-        """Updates nodes.
-        """
         node_feature = tensor.node['feature']
-        node_feature = self._feedforward_dropout(node_feature)
+        node_feature = self._feedforward_intermediate_dense(node_feature)
+        node_feature = self._feedforward_activation(node_feature)
         node_feature = self._feedforward_output_dense(node_feature)
         return tensor.update(
             {
@@ -465,17 +623,217 @@ class GINConv(GraphConv):
     def get_config(self) -> dict:
         config = super().get_config()
         config.update({
-            'activation': keras.activations.serialize(self._activation),
-            'dropout': self._dropout,
             'update_edge_feature': self._update_edge_feature
         })
         return config
 
 
+@keras.saving.register_keras_serializable(package='molgraphx')
+class GAConv(GraphConv):
+
+    """Graph attention network layer.
+
+    >>> graph = molcraft.tensors.GraphTensor(
+    ...     context={
+    ...         'size': [2]
+    ...     },
+    ...     node={
+    ...         'feature': [[1.], [2.]]
+    ...     },
+    ...     edge={
+    ...         'source': [0, 1],
+    ...         'target': [1, 0],
+    ...     }
+    ... )
+    >>> conv = molcraft.layers.GAConv(units=4, heads=2)
+    >>> conv(graph)
+    GraphTensor(
+        context={
+            'size': <tf.Tensor: shape=[1], dtype=int32>
+        },
+        node={
+            'feature': <tf.Tensor: shape=[2, 4], dtype=float32>
+        },
+        edge={
+            'source': <tf.Tensor: shape=[2], dtype=int32>,
+            'target': <tf.Tensor: shape=[2], dtype=int32>
+        }
+    )
+    """
+
+    def __init__(
+        self,
+        units: int,
+        heads: int = 8,
+        activation: keras.layers.Activation | str | None = "relu",
+        use_bias: bool = True,
+        normalization: bool = False,
+        update_edge_feature: bool = True,
+        attention_activation: keras.layers.Activation | str | None = "leaky_relu",
+        **kwargs,
+    ) -> None:
+        kwargs['skip_connection'] = False
+        super().__init__(
+            units=units,
+            activation=activation,
+            use_bias=use_bias,
+            normalization=normalization,
+            **kwargs
+        )
+        self._heads = heads
+        if self.units % self.heads != 0:
+            raise ValueError("units need to be divisible by heads.")
+        self._head_units = self.units // self.heads
+        self._update_edge_feature = update_edge_feature
+        self._attention_activation = keras.activations.get(attention_activation)
+
+    @property
+    def heads(self):
+        return self._heads
+
+    @property
+    def head_units(self):
+        return self._head_units
+
+    def build(self, spec: tensors.GraphTensor.Spec) -> None:
+        self._has_edge_feature = 'feature' in spec.edge
+        self._update_edge_feature = self._has_edge_feature and self._update_edge_feature
+        if self._update_edge_feature:
+            self._edge_dense = self.get_einsum_dense(
+                'ijh,jkh->ikh', (self.head_units, self.heads)
+            )
+        self._node_dense = self.get_einsum_dense(
+            'ij,jkh->ikh', (self.head_units, self.heads)
+        )
+        self._feature_dense = self.get_einsum_dense(
+            'ij,jkh->ikh', (self.head_units, self.heads)
+        )
+        self._attention_dense = self.get_einsum_dense(
+            'ijh,jkh->ikh', (1, self.heads)
+        )
+        self._node_self_dense = self.get_einsum_dense(
+            'ij,jkh->ikh', (self.head_units, self.heads)
+        )
+
+        has_overridden_update = self.__class__.update != GAConv.update
+        if not has_overridden_update:
+            self._feedforward_intermediate_dense = self.get_dense(self.units)
+            self._feedforward_activation = self._activation
+            self._feedforward_output_dense = self.get_dense(self.units)
+
+    def message(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
+        attention_feature = keras.ops.concatenate(
+            [
+                tensor.gather('feature', 'source'),
+                tensor.gather('feature', 'target')
+            ],
+            axis=-1
+        )
+        if self._has_edge_feature:
+            attention_feature = keras.ops.concatenate(
+                [
+                    attention_feature,
+                    tensor.edge['feature']
+                ],
+                axis=-1
+            )
+
+        attention_feature = self._feature_dense(attention_feature)
+
+        edge_feature = tensor.edge.get('feature')
+
+        if self._update_edge_feature:
+            edge_feature = self._edge_dense(attention_feature)
+            edge_feature = keras.ops.reshape(edge_feature, (-1, self.units))
+
+        attention_feature = self._attention_activation(attention_feature)
+        attention_score = self._attention_dense(attention_feature)
+        attention_score = ops.edge_softmax(
+            score=attention_score, edge_target=tensor.edge['target']
+        )
+        node_feature = self._node_dense(tensor.node['feature'])
+        message = ops.gather(node_feature, tensor.edge['source'])
+        return tensor.update(
+            {
+                'edge': {
+                    'message': message,
+                    'weight': attention_score,
+                    'feature': edge_feature,
+                }
+            }
+        )
+
+    def aggregate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
+        node_feature = tensor.aggregate('message', mode='sum')
+        node_feature += self._node_self_dense(tensor.node['feature'])
+        node_feature = keras.ops.reshape(node_feature, (-1, self.units))
+        return tensor.update(
+            {
+                'node': {
+                    'feature': node_feature
+                },
+                'edge': {
+                    'message': None,
+                    'weight': None,
+                }
+            }
+        )
+
+    def update(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
+        node_feature = tensor.node['feature']
+        node_feature = self._feedforward_intermediate_dense(node_feature)
+        node_feature = self._feedforward_activation(node_feature)
+        node_feature = self._feedforward_output_dense(node_feature)
+        return tensor.update(
+            {
+                'node': {
+                    'feature': node_feature
+                }
+            }
+        )
+
+    def get_config(self) -> dict:
+        config = super().get_config()
+        config.update({
+            "heads": self._heads,
+            'update_edge_feature': self._update_edge_feature,
+            'attention_activation': keras.activations.serialize(self._attention_activation),
+        })
+        return config
+
+
 @keras.saving.register_keras_serializable(package='molcraft')
 class GTConv(GraphConv):
 
     """Graph transformer layer.
+
+    >>> graph = molcraft.tensors.GraphTensor(
+    ...     context={
+    ...         'size': [2]
+    ...     },
+    ...     node={
+    ...         'feature': [[1.], [2.]]
+    ...     },
+    ...     edge={
+    ...         'source': [0, 1],
+    ...         'target': [1, 0],
+    ...     }
+    ... )
+    >>> conv = molcraft.layers.GTConv(units=4, heads=2)
+    >>> conv(graph)
+    GraphTensor(
+        context={
+            'size': <tf.Tensor: shape=[1], dtype=int32>
+        },
+        node={
+            'feature': <tf.Tensor: shape=[2, 4], dtype=float32>
+        },
+        edge={
+            'source': <tf.Tensor: shape=[2], dtype=int32>,
+            'target': <tf.Tensor: shape=[2], dtype=int32>
+        }
+    )
+
     """
 
     def __init__(
@@ -484,26 +842,22 @@ class GTConv(GraphConv):
         heads: int = 8,
         activation: keras.layers.Activation | str | None = "relu",
         use_bias: bool = True,
-        normalize: bool = True,
-        dropout: float = 0.0,
+        normalization: bool = False,
         attention_dropout: float = 0.0,
         **kwargs,
     ) -> None:
-        kwargs['skip_connection'] = False
         super().__init__(
             units=units,
-            normalize=normalize,
+            activation=activation,
             use_bias=use_bias,
+            normalization=normalization,
             **kwargs
         )
         self._heads = heads
         if self.units % self.heads != 0:
             raise ValueError("units need to be divisible by heads.")
         self._head_units = self.units // self.heads
-        self._activation = keras.activations.get(activation)
-        self._dropout = dropout
         self._attention_dropout = attention_dropout
-        self._normalize = normalize
 
     @property
     def heads(self):
@@ -513,68 +867,41 @@ class GTConv(GraphConv):
     def head_units(self):
         return self._head_units
 
-    def build_from_spec(self, spec):
-        """Builds the layer.
-        """
-        node_feature_dim = spec.node['feature'].shape[-1]
-        self.project_residual = node_feature_dim != self.units
-        if self.project_residual:
-            warn(
-                '`GTConv` uses residual connections, but found incompatible dim '
-                'between input (node feature dim) and output (`self.units`). '
-                'Automatically applying a projection layer to residual to '
-                'match input and output. '
-            )
-            self._residual_dense = self.get_dense(self.units)
-            self._residual_dense.build([None, node_feature_dim])
-
+    def build(self, spec: tensors.GraphTensor.Spec) -> None:
         self._query_dense = self.get_einsum_dense(
             'ij,jkh->ikh', (self.head_units, self.heads)
         )
-        self._query_dense.build([None, node_feature_dim])
-
         self._key_dense = self.get_einsum_dense(
             'ij,jkh->ikh', (self.head_units, self.heads)
         )
-        self._key_dense.build([None, node_feature_dim])
-
         self._value_dense = self.get_einsum_dense(
             'ij,jkh->ikh', (self.head_units, self.heads)
         )
-        self._value_dense.build([None, node_feature_dim])
-
         self._output_dense = self.get_dense(self.units)
-        self._output_dense.build([None, self.units])
-
         self._softmax_dropout = keras.layers.Dropout(self._attention_dropout)
 
-        self._self_attention_dropout = keras.layers.Dropout(self._dropout)
+        self._add_bias = 'bias' not in spec.edge
 
-        self._add_edge_bias = not 'bias' in spec.edge
-        if self._add_edge_bias:
-            self._add_edge_bias = AddEdgeBias()
-            self._add_edge_bias.build_from_spec(spec)
+        if self._add_bias:
+            self._edge_bias = EdgeBias(biases=self.heads)
 
         has_overridden_update = self.__class__.update != GTConv.update
         if not has_overridden_update:
-
-            if self._normalize:
-                self._feedforward_output_norm = keras.layers.LayerNormalization()
-                self._feedforward_output_norm.build([None, self.units])
-
-            self._feedforward_dropout = keras.layers.Dropout(self._dropout)
-
             self._feedforward_intermediate_dense = self.get_dense(self.units)
-            self._feedforward_intermediate_dense.build([None, self.units])
-
+            self._feedforward_activation = self._activation
             self._feedforward_output_dense = self.get_dense(self.units)
-            self._feedforward_output_dense.build([None, self.units])
-
 
     def message(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
-        """Computes messages.
-        """
-
+        if self._add_bias:
+            edge_bias = self._edge_bias(tensor)
+            tensor = tensor.update(
+                {
+                    'edge': {
+                        'bias': edge_bias
+                    }
+                }
+            )
+
         node_feature = tensor.node['feature']
 
         query = self._query_dense(node_feature)
@@ -587,12 +914,8 @@ class GTConv(GraphConv):
 
         attention_score = keras.ops.sum(query * key, axis=1, keepdims=True)
         attention_score /= keras.ops.sqrt(float(self.head_units))
-
-        if self._add_edge_bias:
-            tensor = self._add_edge_bias(tensor)
 
-        attention_score += keras.ops.expand_dims(tensor.edge['bias'], -1)
-
+        attention_score += keras.ops.expand_dims(tensor.edge['bias'], axis=1)
         attention = ops.edge_softmax(attention_score, tensor.edge['target'])
         attention = self._softmax_dropout(attention)
@@ -606,12 +929,9 @@ class GTConv(GraphConv):
         )
 
     def aggregate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
-        """Aggregates messages.
-        """
-        node_feature = tensor.aggregate('message')
+        node_feature = tensor.aggregate('message', mode='sum')
         node_feature = keras.ops.reshape(node_feature, (-1, self.units))
         node_feature = self._output_dense(node_feature)
-        node_feature = self._self_attention_dropout(node_feature)
         return tensor.update(
             {
                 'node': {
@@ -626,49 +946,257 @@ class GTConv(GraphConv):
         )
 
     def update(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
-        """Updates nodes.
+        node_feature = tensor.node['feature']
+        node_feature = self._feedforward_intermediate_dense(node_feature)
+        node_feature = self._feedforward_activation(node_feature)
+        node_feature = self._feedforward_output_dense(node_feature)
+        return tensor.update(
+            {
+                'node': {
+                    'feature': node_feature,
+                },
+            }
+        )
+
+    def get_config(self) -> dict:
+        config = super().get_config()
+        config.update({
+            "heads": self._heads,
+            'attention_dropout': self._attention_dropout,
+        })
+        return config
+
+
+@keras.saving.register_keras_serializable(package='molcraft')
+class MPConv(GraphConv):
+
+    """Message passing neural network layer.
+    """
+
+    def __init__(
+        self,
+        units: int = 128,
+        activation: keras.layers.Activation | str | None = None,
+        use_bias: bool = True,
+        normalization: bool = False,
+        **kwargs
+    ) -> None:
+        super().__init__(
+            units=units,
+            activation=activation,
+            use_bias=use_bias,
+            normalization=normalization,
+            **kwargs
+        )
+
+    def build(self, spec: tensors.GraphTensor.Spec) -> None:
+        node_feature_dim = spec.node['feature'].shape[-1]
+        self.message_fn = self.get_dense(self.units, activation=self._activation)
+        self.update_fn = keras.layers.GRUCell(self.units)
+        self._has_edge_feature = 'feature' in spec.edge
+        self.project_input_node_feature = node_feature_dim != self.units
+        if self.project_input_node_feature:
+            warn(
+                'Input node feature dim does not match updated node feature dim. '
+                'To make sure input node feature can be passed as `states` to the '
+                'GRU cell, it will automatically be projected prior to it.'
+            )
+            self._previous_node_dense = self.get_dense(
+                self.units, activation=self._activation
+            )
+
+    def message(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
+        feature = keras.ops.concatenate(
+            [
+                tensor.gather('feature', 'source'),
+                tensor.gather('feature', 'target'),
+            ],
+            axis=-1
+        )
+        if self._has_edge_feature:
+            feature = keras.ops.concatenate(
+                [
+                    feature,
+                    tensor.edge['feature']
+                ],
+                axis=-1
+            )
+        message = self.message_fn(feature)
+        return tensor.update(
+            {
+                'edge': {
+                    'message': message,
+                }
+            }
+        )
+
+    def aggregate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
+        aggregate = tensor.aggregate('message', mode='mean')
+        previous = tensor.node['feature']
+        if self.project_input_node_feature:
+            previous = self._previous_node_dense(previous)
+        return tensor.update(
+            {
+                'node': {
+                    'feature': aggregate,
+                    'previous_feature': previous,
+                }
+            }
+        )
+
+    def update(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
+        updated_node_feature, _ = self.update_fn(
+            inputs=tensor.node['feature'],
+            states=tensor.node['previous_feature']
+        )
+        return tensor.update(
+            {
+                'node': {
+                    'feature': updated_node_feature,
+                    'previous_feature': None,
+                }
+            }
+        )
+
+    def get_config(self) -> dict:
+        config = super().get_config()
+        config.update({})
+        return config
+
+
+@keras.saving.register_keras_serializable(package='molcraft')
+class GTConv3D(GTConv):
+
+    """Graph transformer layer 3D.
+    """
+
+    def build(self, spec: tensors.GraphTensor.Spec) -> None:
+        """Builds the layer.
         """
+        super().build(spec)
+        if self._add_bias:
+            node_feature_dim = spec.node['feature'].shape[-1]
+            kernels = self.units
+            self._gaussian_basis = GaussianDistance(kernels)
+            self._centrality_dense = self.get_dense(units=node_feature_dim)
+            self._gaussian_edge_bias = self.get_dense(self.heads)
+
+    def message(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
         node_feature = tensor.node['feature']
 
-        residual = tensor.node['residual']
-        if self.project_residual:
-            residual = self._residual_dense(residual)
+        if self._add_bias:
+            gaussian = self._gaussian_basis(tensor)
+            centrality = keras.ops.segment_sum(
+                gaussian, tensor.edge['target'], tensor.num_nodes
+            )
+            node_feature += self._centrality_dense(centrality)
 
-        node_feature += residual
-        residual = node_feature
+            edge_bias = self._edge_bias(tensor) + self._gaussian_edge_bias(gaussian)
+            tensor = tensor.update({'edge': {'bias': edge_bias}})
+
+        query = self._query_dense(node_feature)
+        key = self._key_dense(node_feature)
+        value = self._value_dense(node_feature)
 
-        node_feature = self._feedforward_intermediate_dense(node_feature)
-        node_feature = self._activation(node_feature)
-        node_feature = self._feedforward_output_dense(node_feature)
-        node_feature = self._feedforward_dropout(node_feature)
-        if self._normalize:
-            node_feature = self._feedforward_output_norm(node_feature)
+        query = ops.gather(query, tensor.edge['source'])
+        key = ops.gather(key, tensor.edge['target'])
+        value = ops.gather(value, tensor.edge['source'])
+
+        attention_score = keras.ops.sum(query * key, axis=1, keepdims=True)
+        attention_score /= keras.ops.sqrt(float(self.head_units))
+
+        attention_score += keras.ops.expand_dims(tensor.edge['bias'], axis=1)
+
+        attention = ops.edge_softmax(attention_score, tensor.edge['target'])
+        attention = self._softmax_dropout(attention)
+
+        distance = keras.ops.subtract(
+            tensor.gather('coordinate', 'source'),
+            tensor.gather('coordinate', 'target')
+        )
+        euclidean_distance = ops.euclidean_distance(
+            tensor.gather('coordinate', 'source'),
+            tensor.gather('coordinate', 'target'),
+            axis=-1
+        )
+        distance /= euclidean_distance
+
+        attention *= keras.ops.expand_dims(distance, axis=-1)
+        attention = keras.ops.expand_dims(attention, axis=2)
+        value = keras.ops.expand_dims(value, axis=1)
+
+        return tensor.update(
+            {
+                'edge': {
+                    'message': value,
+                    'weight': attention,
+                },
+            }
+        )
+
+    def aggregate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
+        node_feature = tensor.aggregate('message', mode='sum')
+        node_feature = keras.ops.reshape(
+            node_feature, (tensor.num_nodes, -1, self.units)
+        )
+        node_feature = self._output_dense(node_feature)
+        node_feature = keras.ops.sum(node_feature, axis=1)
+        return tensor.update(
+            {
+                'node': {
+                    'feature': node_feature,
+                    'residual': tensor.node['feature']
+                },
+                'edge': {
+                    'message': None,
+                    'weight': None,
+                }
+            }
+        )
+
 
-        node_feature += residual
+@keras.saving.register_keras_serializable(package='molcraft')
+class MPConv3D(MPConv):
 
+    """Message passing neural network layer 3D.
+    """
+
+    def message(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
+        euclidean_distance = ops.euclidean_distance(
+            tensor.gather('coordinate', 'target'),
+            tensor.gather('coordinate', 'source'),
+            axis=-1
+        )
+        feature = keras.ops.concatenate(
+            [
+                tensor.gather('feature', 'source'),
+                tensor.gather('feature', 'target'),
+                euclidean_distance,
+            ],
+            axis=-1
+        )
+        if self._has_edge_feature:
+            feature = keras.ops.concatenate(
+                [
+                    feature,
+                    tensor.edge['feature']
+                ],
+                axis=-1
+            )
+        message = self.message_fn(feature)
         return tensor.update(
             {
-                'node': {
-                    'feature': node_feature,
-                },
+                'edge': {
+                    'message': message,
+                }
             }
         )
 
-    def get_config(self) -> dict:
-        config = super().get_config()
-        config.update({
-            "heads": self._heads,
-            'activation': keras.activations.serialize(self._activation),
-            'dropout': self._dropout,
-            'attention_dropout': self._attention_dropout,
-        })
-        return config
-
 
 @keras.saving.register_keras_serializable(package='molcraft')
 class EGConv3D(GraphConv):
 
-    """Equivariant graph neural network layer.
+    """Equivariant graph neural network layer 3D.
     """
 
     def __init__(
@@ -676,48 +1204,33 @@ class EGConv3D(GraphConv):
         units: int = 128,
         activation: keras.layers.Activation | str | None = None,
         use_bias: bool = True,
-        normalize: bool = True,
-        dropout: float = 0.0,
+        normalization: bool = False,
         **kwargs
     ) -> None:
         super().__init__(
             units=units,
-            normalize=normalize,
+            activation=activation,
             use_bias=use_bias,
+            normalization=normalization,
             **kwargs
         )
-        self._activation = keras.activations.get(activation)
-        self._dropout = dropout or 0.0
 
-    def build_from_spec(self, spec: tensors.GraphTensor.Spec) -> None:
+    def build(self, spec: tensors.GraphTensor.Spec) -> None:
         if 'coordinate' not in spec.node:
             raise ValueError(
                 'Could not find `coordinate`s in node, '
                 'which is required for Conv3D layers.'
             )
-        node_feature_dim = spec.node['feature'].shape[-1]
-        feature_dim = node_feature_dim + node_feature_dim + 1
-        if 'feature' in spec.edge:
-            self._has_edge_feature = True
-            edge_feature_dim = spec.edge['feature'].shape[-1]
-            feature_dim += edge_feature_dim
-        else:
-            self._has_edge_feature = False
-
+        self._has_edge_feature = 'feature' in spec.edge
         self.message_fn = self.get_dense(self.units, activation=self._activation)
-        self.message_fn.build([None, feature_dim])
         self.dense_position = self.get_dense(1)
-        self.dense_position.build([None, self.units])
 
         has_overridden_update = self.__class__.update != EGConv3D.update
         if not has_overridden_update:
             self.update_fn = self.get_dense(self.units, activation=self._activation)
-            self.update_fn.build([None, node_feature_dim + self.units])
-            self._dropout_layer = keras.layers.Dropout(self._dropout)
+            self.output_dense = self.get_dense(self.units)
 
     def message(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
-        """Computes messages.
-        """
         relative_node_coordinate = keras.ops.subtract(
             tensor.gather('coordinate', 'target'),
             tensor.gather('coordinate', 'source')
@@ -760,8 +1273,6 @@ class EGConv3D(GraphConv):
         )
 
     def aggregate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
-        """Aggregates messages.
-        """
         coefficient = keras.ops.bincount(
             tensor.edge['source'],
             minlength=tensor.num_nodes
@@ -776,7 +1287,7 @@ class EGConv3D(GraphConv):
         updated_coordinate = tensor.aggregate('relative_node_coordinate') * coefficient
         updated_coordinate += tensor.node['coordinate']
 
-        aggregate = tensor.aggregate('message')
+        aggregate = tensor.aggregate('message', mode='mean')
         return tensor.update(
             {
                 'node': {
@@ -792,8 +1303,6 @@ class EGConv3D(GraphConv):
         )
 
     def update(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
-        """Updates nodes.
-        """
         updated_node_feature = self.update_fn(
             keras.ops.concatenate(
                 [
@@ -803,7 +1312,7 @@ class EGConv3D(GraphConv):
                 axis=-1
             )
         )
-        updated_node_feature = self._dropout_layer(updated_node_feature)
+        updated_node_feature = self.output_dense(updated_node_feature)
        return tensor.update(
            {
                'node': {
@@ -815,65 +1324,46 @@ class EGConv3D(GraphConv):
     def get_config(self) -> dict:
         config = super().get_config()
-        config.update({
-            'activation': keras.activations.serialize(self._activation),
-            'dropout': self._dropout,
-        })
+        config.update({})
         return config
 
 
 @keras.saving.register_keras_serializable(package='molcraft')
-class Projection(GraphLayer):
-    """Base graph projection layer.
+class Readout(GraphLayer):
+
+    """Readout layer.
     """
-    def __init__(
-        self,
-        units: int = None,
-        activation: str = None,
-        field: str = 'node',
-        **kwargs
-    ) -> None:
-        super().__init__(**kwargs)
-        self.units = units
-        self._activation = keras.activations.get(activation)
-        self.field = field
 
-    def build_from_spec(self, spec: tensors.GraphTensor.Spec) -> None:
-        """Builds the layer.
-        """
-        data = getattr(spec, self.field, None)
-        if data is None:
-            raise ValueError('Could not access field {self.field!r}.')
-        feature_dim = data['feature'].shape[-1]
-        if not self.units:
-            self.units = feature_dim
-        self._dense = self.get_dense(self.units)
-        self._dense.build([None, feature_dim])
+    def __init__(self, mode: str | None = None, **kwargs):
+        kwargs['kernel_initializer'] = None
+        kwargs['bias_initializer'] = None
+        super().__init__(**kwargs)
+        self.mode = mode
+        if str(self.mode).lower().startswith('sum'):
+            self._reduce_fn = keras.ops.segment_sum
+        elif str(self.mode).lower().startswith('max'):
+            self._reduce_fn = keras.ops.segment_max
+        elif str(self.mode).lower().startswith('super'):
+            self._reduce_fn = keras.ops.segment_sum
+        else:
+            self._reduce_fn = ops.segment_mean
 
-    def propagate(self, tensor: tensors.GraphTensor):
-        """Calls the layer.
-        """
-        feature = getattr(tensor, self.field)['feature']
-        feature = self._dense(feature)
-        feature = self._activation(feature)
-        return tensor.update(
-            {
-                self.field: {
-                    'feature': feature
-                }
-            }
-        )
+    def propagate(self, tensor: tensors.GraphTensor) -> tf.Tensor:
+        node_feature = tensor.node['feature']
+        if str(self.mode).lower().startswith('super'):
+            node_feature = keras.ops.where(
+                tensor.node['super'][:, None], node_feature, 0.0
+            )
+        return self._reduce_fn(
+            node_feature, tensor.graph_indicator, tensor.num_subgraphs
+        )
 
     def get_config(self) -> dict:
         config = super().get_config()
-        config.update({
-            'units': self.units,
-            'activation': keras.activations.serialize(self._activation),
-            'field': self.field,
-        })
-        return config
+        config['mode'] = self.mode
+        return config
+
 
-
 @keras.saving.register_keras_serializable(package='molcraft')
 class GraphNetwork(GraphLayer):
@@ -881,7 +1371,7 @@ class GraphNetwork(GraphLayer):
 
     Sequentially calls graph layers (`GraphLayer`) and concatenates their outputs.
 
-    Args:
+    Arguments:
         layers (list):
             A list of graph layers.
     """
@@ -891,36 +1381,32 @@ class GraphNetwork(GraphLayer):
         self.layers = layers
         self._update_edge_feature = False
 
-    def build_from_spec(self, spec: tensors.GraphTensor.Spec) -> None:
-        """Builds the layer.
-        """
+    def build(self, spec: tensors.GraphTensor.Spec) -> None:
         units = self.layers[0].units
         node_feature_dim = spec.node['feature'].shape[-1]
-        if node_feature_dim != units:
+        self._update_node_feature = node_feature_dim != units
+        if self._update_node_feature:
             warn(
                 'Node feature dim does not match `units` of the first layer. '
                 'Automatically adding a node projection layer to match `units`.'
             )
             self._node_dense = self.get_dense(units)
-            self._update_node_feature = True
-        has_edge_feature = 'feature' in spec.edge
-        if has_edge_feature:
+        self._has_edge_feature = 'feature' in spec.edge
+        if self._has_edge_feature:
             edge_feature_dim = spec.edge['feature'].shape[-1]
-            if edge_feature_dim != units:
+            self._update_edge_feature = edge_feature_dim != units
+            if self._update_edge_feature:
                 warn(
                     'Edge feature dim does not match `units` of the first layer. '
                     'Automatically adding an edge projection layer to match `units`.'
                 )
                 self._edge_dense = self.get_dense(units)
-                self._update_edge_feature = True
 
     def propagate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
-        """Calls the layer.
-        """
         x = tensors.to_dict(tensor)
         if self._update_node_feature:
             x['node']['feature'] = self._node_dense(tensor.node['feature'])
-        if self._update_edge_feature:
+        if self._has_edge_feature and self._update_edge_feature:
             x['edge']['feature'] = self._edge_dense(tensor.edge['feature'])
         outputs = [x['node']['feature']]
         for layer in self.layers:
@@ -945,7 +1431,7 @@ class GraphNetwork(GraphLayer):
         Performs the same forward pass as `propagate` but with a `GradientTape`
         watching intermediate node features.
 
-        Args:
+        Arguments:
             tensor (tensors.GraphTensor):
                 The graph input.
         """
@@ -1003,24 +1489,25 @@ class NodeEmbedding(GraphLayer):
     def __init__(
         self,
         dim: int = None,
+        normalization: bool = False,
         embed_context: bool = True,
+        allow_reconstruction: bool = False,
         allow_masking: bool = True,
         **kwargs
     ) -> None:
         super().__init__(**kwargs)
         self.dim = dim
+        self._normalization = normalization
         self._embed_context = embed_context
         self._masking_rate = None
         self._allow_masking = allow_masking
+        self._allow_reconstruction = allow_reconstruction
 
-    def build_from_spec(self, spec: tensors.GraphTensor.Spec) -> None:
-        """Builds the layer.
-        """
+    def build(self, spec: tensors.GraphTensor.Spec) -> None:
         feature_dim = spec.node['feature'].shape[-1]
         if not self.dim:
             self.dim = feature_dim
         self._node_dense = self.get_dense(self.dim)
-        self._node_dense.build([None, feature_dim])
 
         self._has_super = 'super' in spec.node
         has_context_feature = 'feature' in spec.context
@@ -1034,11 +1521,18 @@ class NodeEmbedding(GraphLayer):
         if self._embed_context:
             context_feature_dim = spec.context['feature'].shape[-1]
             self._context_dense = self.get_dense(self.dim)
-            self._context_dense.build([None, context_feature_dim])
+
+        if self._normalization:
+            if str(self._normalization).lower().startswith('batch'):
+                self._norm = keras.layers.BatchNormalization(
+                    name='output_batch_norm'
+                )
+            else:
+                self._norm = keras.layers.LayerNormalization(
+                    name='output_layer_norm'
+                )
 
     def propagate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
-        """Calls the layer.
-        """
         feature = self._node_dense(tensor.node['feature'])
 
         if self._has_super:
@@ -1068,8 +1562,13 @@ class NodeEmbedding(GraphLayer):
             # Silence warning of 'no gradients for variables'
             feature = feature + (self._mask_feature * 0.0)
 
-        return tensor.update({'node': {'feature': feature}})
+        if self._normalization:
+            feature = self._norm(feature)
 
+        if not self._allow_reconstruction:
+            return tensor.update({'node': {'feature': feature}})
+        return tensor.update({'node': {'feature': feature, 'target_feature': feature}})
+
     @property
     def masking_rate(self):
         return self._masking_rate
@@ -1087,7 +1586,10 @@ class NodeEmbedding(GraphLayer):
1087
1586
  config = super().get_config()
1088
1587
  config.update({
1089
1588
  'dim': self.dim,
1090
- 'allow_masking': self._allow_masking
1589
+ 'normalization': self._normalization,
1590
+ 'embed_context': self._embed_context,
1591
+ 'allow_masking': self._allow_masking,
1592
+ 'allow_reconstruction': self._allow_reconstruction,
1091
1593
  })
1092
1594
  return config
1093
1595
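
Note that `normalization` is annotated as `bool` but is dispatched by string matching, so it also accepts values such as 'batch_norm'; any other truthy value, including `True` itself (whose string form is 'true'), selects layer normalization. A small sketch of the dispatch used in `build` above:

    import keras

    def make_norm(normalization):
        # Mirrors the branch above: only values whose string form starts
        # with 'batch' select BatchNormalization; every other truthy value
        # falls through to LayerNormalization.
        if str(normalization).lower().startswith('batch'):
            return keras.layers.BatchNormalization(name='output_batch_norm')
        return keras.layers.LayerNormalization(name='output_layer_norm')

    print(type(make_norm(True)).__name__)          # LayerNormalization
    print(type(make_norm('batch_norm')).__name__)  # BatchNormalization
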
 
@@ -1103,22 +1605,21 @@ class EdgeEmbedding(GraphLayer):
     def __init__(
         self,
         dim: int = None,
+        normalization: bool = False,
         allow_masking: bool = True,
         **kwargs
     ) -> None:
         super().__init__(**kwargs)
         self.dim = dim
+        self._normalization = normalization
         self._masking_rate = None
         self._allow_masking = allow_masking
 
-    def build_from_spec(self, spec: tensors.GraphTensor.Spec) -> None:
-        """Builds the layer.
-        """
+    def build(self, spec: tensors.GraphTensor.Spec) -> None:
         feature_dim = spec.edge['feature'].shape[-1]
         if not self.dim:
             self.dim = feature_dim
         self._edge_dense = self.get_dense(self.dim)
-        self._edge_dense.build([None, feature_dim])
 
         self._has_super = 'super' in spec.edge
         if self._has_super:
@@ -1126,9 +1627,17 @@ class EdgeEmbedding(GraphLayer):
         if self._allow_masking:
             self._mask_feature = self.get_weight(shape=[self.dim], name='mask_edge_feature')
 
+        if self._normalization:
+            if str(self._normalization).lower().startswith('batch'):
+                self._norm = keras.layers.BatchNormalization(
+                    name='output_batch_norm'
+                )
+            else:
+                self._norm = keras.layers.LayerNormalization(
+                    name='output_layer_norm'
+                )
+
     def propagate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
-        """Calls the layer.
-        """
         feature = self._edge_dense(tensor.edge['feature'])
 
         if self._has_super:
@@ -1153,7 +1662,10 @@ class EdgeEmbedding(GraphLayer):
             # Silence warning of 'no gradients for variables'
             feature = feature + (self._mask_feature * 0.0)
 
-        return tensor.update({'edge': {'feature': feature}})
+        if self._normalization:
+            feature = self._norm(feature)
+
+        return tensor.update({'edge': {'feature': feature, 'embedding': feature}})
 
     @property
     def masking_rate(self):
@@ -1172,17 +1684,67 @@ class EdgeEmbedding(GraphLayer):
         config = super().get_config()
         config.update({
             'dim': self.dim,
+            'normalization': self._normalization,
             'allow_masking': self._allow_masking
         })
         return config
 
 
+@keras.saving.register_keras_serializable(package='molcraft')
+class Projection(GraphLayer):
+    """Base graph projection layer.
+    """
+    def __init__(
+        self,
+        units: int = None,
+        activation: str | keras.layers.Activation | None = None,
+        use_bias: bool = True,
+        field: str = 'node',
+        **kwargs
+    ) -> None:
+        super().__init__(use_bias=use_bias, **kwargs)
+        self.units = units
+        self._activation = keras.activations.get(activation)
+        self.field = field
+
+    def build(self, spec: tensors.GraphTensor.Spec) -> None:
+        data = getattr(spec, self.field, None)
+        if data is None:
+            raise ValueError(f'Could not access field {self.field!r}.')
+        feature_dim = data['feature'].shape[-1]
+        if not self.units:
+            self.units = feature_dim
+        self._dense = self.get_dense(self.units)
+
+    def propagate(self, tensor: tensors.GraphTensor):
+        feature = getattr(tensor, self.field)['feature']
+        feature = self._dense(feature)
+        feature = self._activation(feature)
+        return tensor.update(
+            {
+                self.field: {
+                    'feature': feature
+                }
+            }
+        )
+
+    def get_config(self) -> dict:
+        config = super().get_config()
+        config.update({
+            'units': self.units,
+            'activation': keras.activations.serialize(self._activation),
+            'field': self.field,
+        })
+        return config
+
+
 @keras.saving.register_keras_serializable(package='molcraft')
 class ContextProjection(Projection):
     """Context projection layer.
     """
     def __init__(self, units: int = None, activation: str = None, **kwargs):
-        super().__init__(units=units, activation=activation, field='context', **kwargs)
+        kwargs['field'] = 'context'
+        super().__init__(units=units, activation=activation, **kwargs)
 
 
 @keras.saving.register_keras_serializable(package='molcraft')
@@ -1190,7 +1752,8 @@ class NodeProjection(Projection):
     """Node projection layer.
     """
     def __init__(self, units: int = None, activation: str = None, **kwargs):
-        super().__init__(units=units, activation=activation, field='node', **kwargs)
+        kwargs['field'] = 'node'
+        super().__init__(units=units, activation=activation, **kwargs)
 
 
 @keras.saving.register_keras_serializable(package='molcraft')
@@ -1198,103 +1761,126 @@ class EdgeProjection(Projection):
     """Edge projection layer.
     """
     def __init__(self, units: int = None, activation: str = None, **kwargs):
-        super().__init__(units=units, activation=activation, field='edge', **kwargs)
-
+        kwargs['field'] = 'edge'
+        super().__init__(units=units, activation=activation, **kwargs)
+
 
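
Routing `field` through `kwargs` in the three subclasses above is most plausibly a deserialization fix: `Projection.get_config` now serializes `field`, so `from_config` passes it back through `**kwargs`, and an explicit `field='context'` keyword in the `super().__init__` call would then collide with the saved value ("got multiple values for keyword argument"). A usage sketch, assuming a `GraphTensor` named `graph` built elsewhere and the layer being callable on it as elsewhere in this module:

    from molcraft import layers

    # Project node features to 128 units with a ReLU; other fields of the
    # graph tensor pass through unchanged.
    proj = layers.NodeProjection(units=128, activation='relu')
    graph = proj(graph)  # graph.node['feature'] now has last dim 128
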
 @keras.saving.register_keras_serializable(package='molcraft')
-class Readout(keras.layers.Layer):
+class Reconstruction(GraphLayer):
 
-    def __init__(self, mode: str | None = None, **kwargs):
+    def __init__(
+        self,
+        loss: keras.losses.Loss | str = 'mse',
+        loss_weight: float = 0.5,
+        **kwargs
+    ):
         super().__init__(**kwargs)
-        self.mode = mode
-        if not self.mode:
-            self._reduce_fn = None
-        elif str(self.mode).lower().startswith('sum'):
-            self._reduce_fn = keras.ops.segment_sum
-        elif str(self.mode).lower().startswith('max'):
-            self._reduce_fn = keras.ops.segment_max
-        elif str(self.mode).lower().startswith('super'):
-            self._reduce_fn = keras.ops.segment_sum
-        else:
-            self._reduce_fn = ops.segment_mean
-
-    def build_from_spec(self, spec: tensors.GraphTensor.Spec) -> None:
-        """Builds the layer.
-        """
-        pass
-
-    def reduce(self, tensor: tensors.GraphTensor) -> tf.Tensor:
-        if self._reduce_fn is None:
-            raise NotImplementedError("Need to define a reduce method.")
-        if str(self.mode).lower().startswith('super'):
-            node_feature = keras.ops.where(
-                tensor.node['super'][:, None], tensor.node['feature'], 0.0
-            )
-            return self._reduce_fn(
-                node_feature, tensor.graph_indicator, tensor.num_subgraphs
+        self._loss_fn = keras.losses.get(loss)
+        self._loss_weight = loss_weight
+
+    def build(self, spec: tensors.GraphTensor.Spec) -> None:
+        has_target_node_feature = 'target_feature' in spec.node
+        if not has_target_node_feature:
+            raise ValueError(
+                'Could not find `target_feature` in `spec.node`. '
+                'Add a `target_feature` via `NodeEmbedding` by setting '
+                '`allow_reconstruction` to `True`.'
             )
-        return self._reduce_fn(
-            tensor.node['feature'], tensor.graph_indicator, tensor.num_subgraphs
-        )
+        output_dim = spec.node['target_feature'].shape[-1]
+        self._dense = self.get_dense(output_dim)
 
-    def build(self, input_shapes) -> None:
-        spec = tensors.GraphTensor.Spec.from_input_shape_dict(input_shapes)
-        self.build_from_spec(spec)
-        self.built = True
+    def propagate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
+        target_node_feature = tensor.node['target_feature']
+        transformed_node_feature = tensor.node['feature']
 
-    def call(self, graph) -> tf.Tensor:
-        graph_tensor = tensors.from_dict(graph)
-        if tensors.is_ragged(graph_tensor):
-            graph_tensor = graph_tensor.flatten()
-        return self.reduce(graph_tensor)
+        reconstructed_node_feature = self._dense(
+            transformed_node_feature
+        )
 
-    def __call__(
-        self,
-        graph: tensors.GraphTensor,
-        *args,
-        **kwargs
-    ) -> tensors.GraphTensor:
-        is_tensor = isinstance(graph, tensors.GraphTensor)
-        if is_tensor:
-            graph = tensors.to_dict(graph)
-        tensor = super().__call__(graph, *args, **kwargs)
-        return tensor
+        loss = self._loss_fn(
+            target_node_feature, reconstructed_node_feature
+        )
+        self.add_loss(keras.ops.sum(loss) * self._loss_weight)
+        return tensor.update({'node': {'feature': transformed_node_feature}})
 
-    def get_config(self) -> dict:
+    def get_config(self):
         config = super().get_config()
-        config['mode'] = self.mode
-        return config
-
+        config['loss'] = keras.losses.serialize(self._loss_fn)
+        config['loss_weight'] = self._loss_weight
+        return config
+
 
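
`Reconstruction` pairs with `NodeEmbedding(allow_reconstruction=True)`, which stores a copy of the embedded features under `target_feature`; once later layers have transformed `feature`, this layer decodes it back and registers the reconstruction error as an auxiliary loss. A self-contained sketch of that `add_loss` pattern in plain Keras (the class and argument names here are illustrative, not molcraft API):

    import keras

    class AuxReconstruction(keras.layers.Layer):
        # Decode features back to a stored target and add a weighted loss.
        def __init__(self, target_dim, loss='mse', loss_weight=0.5, **kwargs):
            super().__init__(**kwargs)
            self._dense = keras.layers.Dense(target_dim)
            self._loss_fn = keras.losses.get(loss)
            self._loss_weight = loss_weight

        def call(self, feature, target_feature):
            reconstructed = self._dense(feature)
            loss = self._loss_fn(target_feature, reconstructed)
            self.add_loss(keras.ops.sum(loss) * self._loss_weight)
            return feature  # features pass through unchanged

    layer = AuxReconstruction(target_dim=8)
    out = layer(keras.random.normal([4, 16]), keras.random.normal([4, 8]))
    print(layer.losses)  # contains the weighted reconstruction loss
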
 @keras.saving.register_keras_serializable(package='molcraft')
-class AddEdgeBias(GraphLayer):
+class EdgeBias(GraphLayer):
 
-    def build_from_spec(self, spec: tensors.GraphTensor.Spec) -> None:
+    def __init__(self, biases: int, **kwargs):
+        super().__init__(**kwargs)
+        self.biases = biases
+
+    def build(self, spec: tensors.GraphTensor.Spec) -> None:
         self._has_edge_length = 'length' in spec.edge
         self._has_edge_feature = 'feature' in spec.edge
         if self._has_edge_feature:
-            self._edge_feature_dense = self.get_dense(units=1)
+            self._edge_feature_dense = self.get_dense(self.biases)
         if self._has_edge_length:
             self._edge_length_dense = self.get_dense(
-                units=1, kernel_initializer='zeros'
+                self.biases, kernel_initializer='zeros'
             )
-
+
     def propagate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
         bias = keras.ops.zeros(
-            shape=(tensor.num_edges, 1),
+            shape=(tensor.num_edges, self.biases),
             dtype=tensor.node['feature'].dtype
         )
         if self._has_edge_feature:
            bias += self._edge_feature_dense(tensor.edge['feature'])
         if self._has_edge_length:
            bias += self._edge_length_dense(tensor.edge['length'])
-        return tensor.update(
-            {
-                'edge': {
-                    'bias': bias
-                }
-            }
+        return bias
+
+    def get_config(self) -> dict:
+        config = super().get_config()
+        config.update({'biases': self.biases})
+        return config
+
+
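
Two behavioral changes above are worth flagging: `EdgeBias` is now head-aware (it produces `biases` values per edge instead of one), and `propagate` returns the bias tensor itself rather than an updated graph, even though its return annotation still says `GraphTensor`. A sketch of what the projection computes, in plain Keras with illustrative shapes:

    import keras

    num_edges, num_heads = 10, 8                # e.g. one bias per attention head
    edge_feature = keras.random.normal([num_edges, 16])
    edge_dense = keras.layers.Dense(num_heads)  # like `self._edge_feature_dense`

    bias = keras.ops.zeros((num_edges, num_heads)) + edge_dense(edge_feature)
    print(bias.shape)                           # (10, 8): one bias per edge and head
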
+@keras.saving.register_keras_serializable(package='molcraft')
+class GaussianDistance(GraphLayer):
+
+    def __init__(self, kernels: int, **kwargs):
+        super().__init__(**kwargs)
+        self.kernels = kernels
+
+    def build(self, spec: tensors.GraphTensor.Spec) -> None:
+        self._loc = self.add_weight(
+            shape=[self.kernels],
+            initializer='zeros',
+            dtype='float32',
+            trainable=True
+        )
+        self._scale = self.add_weight(
+            shape=[self.kernels],
+            initializer='ones',
+            dtype='float32',
+            trainable=True
+        )
+
+    def propagate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
+        euclidean_distance = ops.euclidean_distance(
+            tensor.gather('coordinate', 'source'),
+            tensor.gather('coordinate', 'target'),
+            axis=-1
+        )
+        return ops.gaussian(
+            euclidean_distance, self._loc, self._scale
         )
+
+    def get_config(self) -> dict:
+        config = super().get_config()
+        config.update({
+            'kernels': self.kernels,
+        })
+        return config
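
`GaussianDistance` expands each scalar edge distance into `kernels` soft bins with trainable centers (`loc`) and widths (`scale`). The exact form of `ops.gaussian` is not shown in this diff; under the standard radial-basis assumption it behaves like this sketch:

    import keras

    def gaussian(distance, loc, scale):
        # Assumed RBF form: exp(-(d - loc)^2 / (2 * scale^2)), broadcast so
        # every distance is scored against every kernel.
        return keras.ops.exp(
            -((distance[..., None] - loc) ** 2) / (2.0 * scale ** 2)
        )

    distance = keras.ops.convert_to_tensor([0.5, 1.2, 2.0])  # per-edge distances
    loc = keras.ops.linspace(0.0, 3.0, 16)                    # kernel centers
    scale = keras.ops.ones([16])                              # kernel widths
    print(gaussian(distance, loc, scale).shape)               # (3, 16)
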
 
 
 def Input(spec: tensors.GraphTensor.Spec) -> dict:
@@ -1412,13 +1998,6 @@ def _spec_from_inputs(inputs):
     return tensors.GraphTensor.Spec(**nested_specs)
 
 
-GraphTransformer = GTConvolution = GTConv
-GINConvolution = GINConv
-
-EdgeEmbed = EdgeEmbedding
-NodeEmbed = NodeEmbedding
-
-ContextDense = ContextProjection
-EdgeDense = EdgeProjection
-NodeDense = NodeProjection
+GraphTransformer = GTConv
+GraphTransformer3D = GTConv3D