molcraft 0.1.0a3__py3-none-any.whl → 0.1.0a4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


molcraft/layers.py CHANGED
@@ -1,7 +1,7 @@
- import abc
  import keras
  import tensorflow as tf
  import warnings
+ import functools
  from keras.src.models import functional

  from molcraft import tensors
@@ -12,11 +12,39 @@ from molcraft import ops
  class GraphLayer(keras.layers.Layer):
      """Base graph layer.

-     Currently, the `GraphLayer` only supports `GraphTensor` input.
+     Subclasses must implement a forward pass via **propagate(graph)**.
+
+     Subclasses may create dense layers and weights in **build(graph_spec)**.
+
+     Note: `GraphLayer` currently only supports `GraphTensor` input.

-     The list of arguments are only relevant if the derived layer
+     The list of arguments below is only relevant if the derived layer
      invokes `get_dense_kwargs`, `get_dense` or `get_einsum_dense`.

+     Arguments:
+         use_bias (bool):
+             Whether bias should be used in dense layers. Defaults to `True`.
+         kernel_initializer (keras.initializers.Initializer, str):
+             Initializer for the kernel weight matrix of the dense layers.
+             Defaults to `glorot_uniform`.
+         bias_initializer (keras.initializers.Initializer, str):
+             Initializer for the bias weight vector of the dense layers.
+             Defaults to `zeros`.
+         kernel_regularizer (keras.regularizers.Regularizer, None):
+             Regularizer function applied to the kernel weight matrix.
+             Defaults to `None`.
+         bias_regularizer (keras.regularizers.Regularizer, None):
+             Regularizer function applied to the bias weight vector.
+             Defaults to `None`.
+         activity_regularizer (keras.regularizers.Regularizer, None):
+             Regularizer function applied to the output of the dense layers.
+             Defaults to `None`.
+         kernel_constraint (keras.constraints.Constraint, None):
+             Constraint function applied to the kernel weight matrix.
+             Defaults to `None`.
+         bias_constraint (keras.constraints.Constraint, None):
+             Constraint function applied to the bias weight vector.
+             Defaults to `None`.
      """

      def __init__(
@@ -31,68 +59,61 @@ class GraphLayer(keras.layers.Layer):
          bias_constraint: keras.constraints.Constraint | None = None,
          **kwargs,
      ) -> None:
-         super().__init__(activity_regularizer=activity_regularizer, **kwargs)
+         super().__init__(**kwargs)
          self._use_bias = use_bias
          self._kernel_initializer = keras.initializers.get(kernel_initializer)
          self._bias_initializer = keras.initializers.get(bias_initializer)
          self._kernel_regularizer = keras.regularizers.get(kernel_regularizer)
          self._bias_regularizer = keras.regularizers.get(bias_regularizer)
+         self._activity_regularizer = keras.regularizers.get(activity_regularizer)
          self._kernel_constraint = keras.constraints.get(kernel_constraint)
          self._bias_constraint = keras.constraints.get(bias_constraint)
+         self._custom_build_config = {}
          self.built = False
-         # TODO: Add warning if build is implemented in subclass
-         # TODO: Add warning if call is implemented in subclass
+
+     def __init_subclass__(cls, **kwargs):
+         super().__init_subclass__(**kwargs)
+         subclass_build = cls.build
+
+         @functools.wraps(subclass_build)
+         def build_wrapper(self: GraphLayer, spec: tensors.GraphTensor.Spec | None):
+             GraphLayer.build(self, spec)
+             subclass_build(self, spec)
+             if not self.built and isinstance(self, keras.Model):
+                 symbolic_inputs = Input(spec)
+                 self.built = True
+                 self(symbolic_inputs)
+
+         cls.build = build_wrapper

      def propagate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
-         """Calls the layer.
+         """Forward pass.

-         Needs to be implemented by subclass.
+         Must be implemented by subclass.

-         Args:
+         Arguments:
              tensor:
                  A `GraphTensor` instance.
          """
-         raise NotImplementedError('`propagate` needs to be implemented.')
+         raise NotImplementedError(
+             'The forward pass of the layer is not implemented. '
+             'Please implement `propagate`.'
+         )

-     def build_from_spec(self, spec: tensors.GraphTensor.Spec) -> None:
+     def build(self, tensor_spec: tensors.GraphTensor.Spec) -> None:
          """Builds the layer.

          May use built-in methods such as `get_weight`, `get_dense` and `get_einsum_dense`.

-         Optionally implemented by subclass. If implemented, it is recommended to
-         If sub-layers are built (via `build` or `build_from_spec`), set `built`
-         to True. If not, symbolic input will be passed through the layer to build them.
+         Optionally implemented by subclass.

-         Args:
-             spec:
-                 A `GraphTensor.Spec` instance, corresponding to the `GraphTensor`
+         Arguments:
+             tensor_spec:
+                 A `GraphTensor.Spec` instance corresponding to the `GraphTensor`
                  passed to `propagate`.
          """
-
-     def build(self, spec: tensors.GraphTensor.Spec) -> None:
-
-         self._custom_build_config = {'spec': _serialize_spec(spec)}
-
-         self.build_from_spec(spec)
-
-         if not self.built:
-             # Automatically build layer or model by calling it on symbolic inputs
-             self.built = True
-             symbolic_inputs = Input(spec)
-             self(symbolic_inputs)
-
-     def get_build_config(self) -> dict:
-         if not hasattr(self, '_custom_build_config'):
-             return super().get_build_config()
-         return self._custom_build_config
-
-     def build_from_config(self, config: dict) -> None:
-         use_custom_build_from_config = ('spec' in config)
-         if not use_custom_build_from_config:
-             super().build_from_config(config)
-         else:
-             spec = _deserialize_spec(config['spec'])
-             self.build(spec)
+         if isinstance(tensor_spec, tensors.GraphTensor.Spec):
+             self._custom_build_config['spec'] = _serialize_spec(tensor_spec)

      def call(
          self,
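The new `build` / `__init_subclass__` machinery above means a `GraphLayer` subclass only implements `propagate` (the forward pass) and, optionally, `build`; the installed wrapper always runs the base `build` first so the input spec is recorded for serialization. A minimal sketch of that contract, based only on the API visible in this diff (`MeanSelfProjection` is a hypothetical example, not part of the package):

    import keras
    from molcraft import layers, tensors

    @keras.saving.register_keras_serializable(package='my_package')
    class MeanSelfProjection(layers.GraphLayer):

        def build(self, spec: tensors.GraphTensor.Spec) -> None:
            # Create weights from the input spec; the wrapper installed by
            # `__init_subclass__` has already invoked `GraphLayer.build(self, spec)`.
            dim = spec.node['feature'].shape[-1]
            self._dense = self.get_dense(dim)

        def propagate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
            # `propagate` (not `call`) is the forward pass to implement.
            feature = tensor.node['feature']
            feature = 0.5 * (feature + self._dense(feature))
            return tensor.update({'node': {'feature': feature}})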
@@ -122,6 +143,19 @@ class GraphLayer(keras.layers.Layer):
          outputs = tensors.from_dict(outputs)
          return outputs

+     def get_build_config(self) -> dict:
+         if self._custom_build_config:
+             return self._custom_build_config
+         return super().get_build_config()
+
+     def build_from_config(self, config: dict) -> None:
+         serialized_spec = config.get('spec')
+         if serialized_spec is not None:
+             spec = _deserialize_spec(serialized_spec)
+             self.build(spec)
+         else:
+             super().build_from_config(config)
+
      def get_weight(
          self,
          shape: tf.TensorShape,
@@ -163,7 +197,7 @@ class GraphLayer(keras.layers.Layer):
              use_bias=self._use_bias,
              kernel_regularizer=self._kernel_regularizer,
              bias_regularizer=self._bias_regularizer,
-             activity_regularizer=self.activity_regularizer,
+             activity_regularizer=self._activity_regularizer,
              kernel_constraint=self._kernel_constraint,
              bias_constraint=self._bias_constraint,
          )
@@ -189,55 +223,136 @@ class GraphLayer(keras.layers.Layer):
                  keras.regularizers.serialize(self._kernel_regularizer),
              "bias_regularizer":
                  keras.regularizers.serialize(self._bias_regularizer),
+             "activity_regularizer":
+                 keras.regularizers.serialize(self._activity_regularizer),
              "kernel_constraint":
                  keras.constraints.serialize(self._kernel_constraint),
              "bias_constraint":
                  keras.constraints.serialize(self._bias_constraint),
          })
          return config
-
+

  @keras.saving.register_keras_serializable(package='molcraft')
  class GraphConv(GraphLayer):

      """Base graph neural network layer.

-     For normalization and skip connection to work, the `GraphConv` subclass
-     requires the (node feature) output of `aggregate` and `update` to have a
-     dimension of `self.units`, respectively.
-
-     Args:
-         units:
-             The number of units.
-         normalize:
-             Whether `LayerNormalization` should be applied to the (node feature) output
-             of the `aggregate` step. While normalization is recommended, it is not used
-             by default.
-         skip_connection:
-             Whether (node feature) input should be added to the (node feature) output.
-             If (node feature) input dim is not equal to `units`, a projection layer will
-             automatically project the residual before adding it to the output. While skip
-             connection is recommended, it is not used by default.
-         kwargs:
-             See arguments of `GraphLayer`.
+     This layer implements the three basic steps of a graph neural network layer, each of
+     which can (optionally) be overridden by the `GraphConv` subclass:
+
+     1. **message(graph)**, which computes the *messages* to be passed to target nodes;
+     2. **aggregate(graph)**, which *aggregates* messages to target nodes;
+     3. **update(graph)**, which further *updates* (target) nodes.
+
+     Note: for the skip connection to work, the `GraphConv` subclass requires the final
+     node feature output dimension to be equal to `units`.
+
+     Arguments:
+         units (int):
+             Dimensionality of the output space.
+         activation (keras.layers.Activation, str, None):
+             Activation function to use. If not specified, a linear activation (a(x) = x)
+             is used. Defaults to `None`.
+         use_bias (bool):
+             Whether bias should be used in dense layers. Defaults to `True`.
+         normalization (bool, str):
+             Whether `LayerNormalization` should be applied to the final node feature output.
+             To use `BatchNormalization`, specify `batch_norm`. Defaults to `False`.
+         skip_connection (bool, str):
+             Whether node feature input should be added to the node feature output.
+             If the node feature input dim is not equal to `units` (the node feature output
+             dim), a projection layer will automatically project the residual before adding
+             it to the output. To use a weighted skip connection, specify `weighted`; the
+             weight multiplied with the skip connection is a learnable scalar.
+             Defaults to `True`.
+         kernel_initializer (keras.initializers.Initializer, str):
+             Initializer for the kernel weight matrix of the dense layers.
+             Defaults to `glorot_uniform`.
+         bias_initializer (keras.initializers.Initializer, str):
+             Initializer for the bias weight vector of the dense layers.
+             Defaults to `zeros`.
+         kernel_regularizer (keras.regularizers.Regularizer, None):
+             Regularizer function applied to the kernel weight matrix.
+             Defaults to `None`.
+         bias_regularizer (keras.regularizers.Regularizer, None):
+             Regularizer function applied to the bias weight vector.
+             Defaults to `None`.
+         activity_regularizer (keras.regularizers.Regularizer, None):
+             Regularizer function applied to the output of the dense layers.
+             Defaults to `None`.
+         kernel_constraint (keras.constraints.Constraint, None):
+             Constraint function applied to the kernel weight matrix.
+             Defaults to `None`.
+         bias_constraint (keras.constraints.Constraint, None):
+             Constraint function applied to the bias weight vector.
+             Defaults to `None`.
      """

      def __init__(
          self,
          units: int = None,
-         normalize: bool = False,
-         skip_connection: bool = False,
+         activation: str | keras.layers.Activation | None = None,
+         use_bias: bool = True,
+         normalization: bool | str = False,
+         skip_connection: bool | str = True,
          **kwargs
      ) -> None:
-         super().__init__(**kwargs)
-         self.units = units
-         self._normalize_aggregate = normalize
+         super().__init__(use_bias=use_bias, **kwargs)
+         self._units = units
+         self._normalization = normalization
          self._skip_connection = skip_connection
+         self._activation = keras.activations.get(activation)
+
+     def __init_subclass__(cls, **kwargs):
+         super().__init_subclass__(**kwargs)
+         subclass_build = cls.build
+
+         @functools.wraps(subclass_build)
+         def build_wrapper(self, spec):
+             GraphConv.build(self, spec)
+             return subclass_build(self, spec)
+
+         cls.build = build_wrapper
+
+     @property
+     def units(self):
+         return self._units
+
+     def propagate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
+         """Forward pass.
+
+         Invokes `message(graph)`, `aggregate(graph)` and `update(graph)` in sequence.
+
+         Arguments:
+             tensor:
+                 A `GraphTensor` instance.
+         """
+         if self._skip_connection:
+             input_node_feature = tensor.node['feature']
+             if self._project_input_node_feature:
+                 input_node_feature = self._residual_projection(input_node_feature)
+
+         tensor = self.message(tensor)
+         tensor = self.aggregate(tensor)
+         tensor = self.update(tensor)
+
+         updated_node_feature = tensor.node['feature']
+
+         if self._skip_connection:
+             if self._use_weighted_skip_connection:
+                 input_node_feature *= self._skip_connection_weight
+             updated_node_feature += input_node_feature
+
+         if self._normalization:
+             updated_node_feature = self._output_norm(updated_node_feature)
+
+         return tensor.update({'node': {'feature': updated_node_feature}})

      def build(self, spec: tensors.GraphTensor.Spec) -> None:
          if not self.units:
              raise ValueError(
-                 f'`self.units` needs to be a positive integer. ound: {self.units}.'
+                 f'`self.units` needs to be a positive integer. Found: {self.units}.'
              )
          node_feature_dim = spec.node['feature'].shape[-1]
          self._project_input_node_feature = (
@@ -253,81 +368,115 @@ class GraphConv(GraphLayer):
              self._residual_projection = self.get_dense(
                  self.units, name='residual_projection'
              )
-         if self._normalize_aggregate:
-             self._aggregation_norm = keras.layers.LayerNormalization(
-                 name='aggregation_normalization'
+
+         skip_connection = str(self._skip_connection).lower()
+         self._use_weighted_skip_connection = skip_connection.startswith('weight')
+         if self._use_weighted_skip_connection:
+             self._skip_connection_weight = self.add_weight(
+                 name='skip_connection_weight',
+                 shape=(),
+                 initializer='ones',
+                 trainable=True,
              )
-             self._aggregation_norm.build([None, self.units])

-         super().build(spec)
+         if self._normalization:
+             if str(self._normalization).lower().startswith('batch'):
+                 self._output_norm = keras.layers.BatchNormalization(
+                     name='output_batch_norm'
+                 )
+             else:
+                 self._output_norm = keras.layers.LayerNormalization(
+                     name='output_layer_norm'
+                 )
+
+         self._has_edge_feature = 'edge' in spec.edge
+
+         has_overridden_message = self.__class__.message != GraphConv.message
+         if not has_overridden_message:
+             self._message_dense = self.get_dense(self.units)
+
+         has_overridden_update = self.__class__.update != GraphConv.update
+         if not has_overridden_update:
+             self._output_dense = self.get_dense(self.units)
+             self._output_activation = self._activation

-     @abc.abstractmethod
      def message(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
          """Compute messages.

-         This method needs to be implemented by subclass.
+         This method may be overridden by subclass.

-         Args:
+         Arguments:
              tensor:
                  The inputted `GraphTensor` instance.
          """
-
-     @abc.abstractmethod
+         if not self._has_edge_feature:
+             message_feature = tensor.gather('feature', 'source')
+         else:
+             message_feature = keras.ops.concatenate(
+                 [
+                     tensor.gather('feature', 'source'),
+                     tensor.edge['feature']
+                 ],
+                 axis=-1
+             )
+         message = self._message_dense(message_feature)
+         return tensor.update(
+             {
+                 'edge': {
+                     'message': message
+                 }
+             }
+         )
+
      def aggregate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
          """Aggregates messages.

-         This method needs to be implemented by subclass.
+         This method may be overridden by subclass.

-         Args:
+         Arguments:
              tensor:
                  A `GraphTensor` instance containing a message.
          """
+         aggregate = tensor.aggregate('message', mode='mean')
+         return tensor.update(
+             {
+                 'node': {
+                     'feature': aggregate,
+                     'previous_feature': tensor.node['feature']
+                 },
+                 'edge': {
+                     'message': None
+                 }
+             }
+         )

-     @abc.abstractmethod
      def update(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
          """Updates nodes.

-         This method needs to be implemented by subclass.
+         This method may be overridden by subclass.

-         Args:
+         Arguments:
              tensor:
                  A `GraphTensor` instance containing aggregated messages
                  (updated node features).
          """
-
-     def propagate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
-         """Calls the layer.
-
-         The `GraphConv` layer invokes `message`, `aggregate` and `update`
-         in sequence.
-
-         Args:
-             tensor:
-                 A `GraphTensor` instance.
-         """
-
-         if self._skip_connection:
-             input_node_feature = tensor.node['feature']
-             if self._project_input_node_feature:
-                 input_node_feature = self._residual_projection(input_node_feature)
-
-         tensor = self.message(tensor)
-         tensor = self.aggregate(tensor)
-
-         if self._normalize_aggregate:
-             normalized_node_feature = self._aggregation_norm(tensor.node['feature'])
-             tensor = tensor.update({'node': {'feature': normalized_node_feature}})
-
-         tensor = self.update(tensor)
-
-         if not self._skip_connection:
-             return tensor
-
-         updated_node_feature = tensor.node['feature']
+         if not 'previous_feature' in tensor.node:
+             feature = tensor.node['feature']
+         else:
+             feature = keras.ops.concatenate(
+                 [
+                     tensor.node['feature'],
+                     tensor.node['previous_feature']
+                 ],
+                 axis=-1
+             )
+         update = self._output_dense(feature)
+         update = self._output_activation(update)
          return tensor.update(
              {
                  'node': {
-                     'feature': updated_node_feature + input_node_feature
+                     'feature': update,
+                     'previous_feature': None,
                  }
              }
          )
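With the base implementations above, the three steps are now optional overrides rather than abstract methods. A sketch of a custom convolution that overrides `message` and `aggregate` and inherits the default `update` (`SumConv` is hypothetical, assuming only the API in this diff):

    import keras
    from molcraft import layers, tensors

    class SumConv(layers.GraphConv):

        def build(self, spec: tensors.GraphTensor.Spec) -> None:
            self._dense = self.get_dense(self.units)

        def message(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
            # Messages: projected source-node features.
            message = self._dense(tensor.gather('feature', 'source'))
            return tensor.update({'edge': {'message': message}})

        def aggregate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
            # Sum incoming messages at each target node.
            feature = tensor.aggregate('message', mode='sum')
            return tensor.update(
                {'node': {'feature': feature}, 'edge': {'message': None}}
            )

Because `update` is not overridden, the base class builds `_output_dense` for it, and `propagate` wraps the three steps with the optional skip connection and normalization.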
@@ -336,7 +485,8 @@ class GraphConv(GraphLayer):
          config = super().get_config()
          config.update({
              'units': self.units,
-             'normalize': self._normalize_aggregate,
+             'activation': keras.activations.serialize(self._activation),
+             'normalization': self._normalization,
              'skip_connection': self._skip_connection,
          })
          return config
@@ -346,6 +496,33 @@ class GraphConv(GraphLayer):
  class GIConv(GraphConv):

      """Graph isomorphism network layer.
+
+     >>> graph = molcraft.tensors.GraphTensor(
+     ...     context={
+     ...         'size': [2]
+     ...     },
+     ...     node={
+     ...         'feature': [[1.], [2.]]
+     ...     },
+     ...     edge={
+     ...         'source': [0, 1],
+     ...         'target': [1, 0],
+     ...     }
+     ... )
+     >>> conv = molcraft.layers.GIConv(units=4)
+     >>> conv(graph)
+     GraphTensor(
+       context={
+         'size': <tf.Tensor: shape=[1], dtype=int32>
+       },
+       node={
+         'feature': <tf.Tensor: shape=[2, 4], dtype=float32>
+       },
+       edge={
+         'source': <tf.Tensor: shape=[2], dtype=int32>,
+         'target': <tf.Tensor: shape=[2], dtype=int32>
+       }
+     )
      """

      def __init__(
@@ -353,24 +530,20 @@ class GIConv(GraphConv):
          units: int,
          activation: keras.layers.Activation | str | None = 'relu',
          use_bias: bool = True,
-         normalize: bool = True,
-         dropout: float = 0.0,
+         normalization: bool = False,
          update_edge_feature: bool = True,
          **kwargs,
      ):
          super().__init__(
              units=units,
-             normalize=normalize,
+             activation=activation,
+             normalization=normalization,
              use_bias=use_bias,
              **kwargs
          )
-         self._activation = keras.activations.get(activation)
-         self._dropout = dropout
          self._update_edge_feature = update_edge_feature

-     def build_from_spec(self, spec: tensors.GraphTensor.Spec) -> None:
-         """Builds the layer.
-         """
+     def build(self, spec: tensors.GraphTensor.Spec) -> None:
          node_feature_dim = spec.node['feature'].shape[-1]

          self.epsilon = self.add_weight(
@@ -395,25 +568,16 @@ class GIConv(GraphConv):

          if self._update_edge_feature:
              self._edge_dense = self.get_dense(node_feature_dim)
-             self._edge_dense.build([None, edge_feature_dim])
          else:
              self._update_edge_feature = False
-
-         self._feedforward_intermediate_dense = self.get_dense(self.units)
-         self._feedforward_intermediate_dense.build([None, node_feature_dim])

          has_overridden_update = self.__class__.update != GIConv.update
          if not has_overridden_update:
+             self._feedforward_intermediate_dense = self.get_dense(self.units)
              self._feedforward_activation = self._activation
-             self._feedforward_dropout = keras.layers.Dropout(self._dropout)
              self._feedforward_output_dense = self.get_dense(self.units)
-             self._feedforward_output_dense.build([None, self.units])
-
-         self.built = True

      def message(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
-         """Computes messages.
-         """
          message = tensor.gather('feature', 'source')
          edge_feature = tensor.edge.get('feature')
          if self._update_edge_feature:
@@ -430,11 +594,8 @@ class GIConv(GraphConv):
          )

      def aggregate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
-         """Aggregates messages.
-         """
-         node_feature = tensor.aggregate('message')
+         node_feature = tensor.aggregate('message', mode='mean')
          node_feature += (1 + self.epsilon) * tensor.node['feature']
-         node_feature = self._feedforward_intermediate_dense(node_feature)
          return tensor.update(
              {
                  'node': {
@@ -447,11 +608,9 @@ class GIConv(GraphConv):
          )

      def update(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
-         """Updates nodes.
-         """
          node_feature = tensor.node['feature']
+         node_feature = self._feedforward_intermediate_dense(node_feature)
          node_feature = self._feedforward_activation(node_feature)
-         node_feature = self._feedforward_dropout(node_feature)
          node_feature = self._feedforward_output_dense(node_feature)
          return tensor.update(
              {
@@ -464,8 +623,6 @@ class GIConv(GraphConv):
      def get_config(self) -> dict:
          config = super().get_config()
          config.update({
-             'activation': keras.activations.serialize(self._activation),
-             'dropout': self._dropout,
              'update_edge_feature': self._update_edge_feature
          })
          return config
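The `(1 + self.epsilon)` term above is the learnable self-weighting of the graph isomorphism network: aggregated neighbor messages and the node's own feature are combined as roughly `h_v' = FFN((1 + ε) · h_v + agg(messages))`, with `ε` initialized to zero. Note that the original GIN formulation sums messages, whereas this layer aggregates with `mode='mean'`. A hedged numeric illustration (not taken from the package):

    # node feature h_v = 2.0, epsilon = 0.0, incoming messages {1.0, 3.0}
    # aggregate (mode='mean') -> (1.0 + 3.0) / 2 = 2.0
    # pre-feedforward value   -> (1 + 0.0) * 2.0 + 2.0 = 4.0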
@@ -475,6 +632,33 @@ class GIConv(GraphConv):
  class GAConv(GraphConv):

      """Graph attention network layer.
+
+     >>> graph = molcraft.tensors.GraphTensor(
+     ...     context={
+     ...         'size': [2]
+     ...     },
+     ...     node={
+     ...         'feature': [[1.], [2.]]
+     ...     },
+     ...     edge={
+     ...         'source': [0, 1],
+     ...         'target': [1, 0],
+     ...     }
+     ... )
+     >>> conv = molcraft.layers.GAConv(units=4, heads=2)
+     >>> conv(graph)
+     GraphTensor(
+       context={
+         'size': <tf.Tensor: shape=[1], dtype=int32>
+       },
+       node={
+         'feature': <tf.Tensor: shape=[2, 4], dtype=float32>
+       },
+       edge={
+         'source': <tf.Tensor: shape=[2], dtype=int32>,
+         'target': <tf.Tensor: shape=[2], dtype=int32>
+       }
+     )
      """

      def __init__(
@@ -483,8 +667,7 @@ class GAConv(GraphConv):
          heads: int = 8,
          activation: keras.layers.Activation | str | None = "relu",
          use_bias: bool = True,
-         normalize: bool = True,
-         dropout: float = 0.0,
+         normalization: bool = False,
          update_edge_feature: bool = True,
          attention_activation: keras.layers.Activation | str | None = "leaky_relu",
          **kwargs,
@@ -492,17 +675,15 @@ class GAConv(GraphConv):
          kwargs['skip_connection'] = False
          super().__init__(
              units=units,
-             normalize=normalize,
+             activation=activation,
              use_bias=use_bias,
+             normalization=normalization,
              **kwargs
          )
          self._heads = heads
          if self.units % self.heads != 0:
              raise ValueError(f"units need to be divisible by heads.")
          self._head_units = self.units // self.heads
-         self._activation = keras.activations.get(activation)
-         self._dropout = dropout
-         self._normalize = normalize
          self._update_edge_feature = update_edge_feature
          self._attention_activation = keras.activations.get(attention_activation)

@@ -514,48 +695,33 @@ class GAConv(GraphConv):
      def head_units(self):
          return self._head_units

-     def build_from_spec(self, spec: tensors.GraphTensor.Spec) -> None:
-
-         node_feature_dim = spec.node['feature'].shape[-1]
-         attn_feature_dim = node_feature_dim + node_feature_dim
-
+     def build(self, spec: tensors.GraphTensor.Spec) -> None:
          self._has_edge_feature = 'feature' in spec.edge
-         if self._has_edge_feature:
-             edge_feature_dim = spec.edge['feature'].shape[-1]
-             attn_feature_dim += edge_feature_dim
-             if self._update_edge_feature:
-                 self._edge_dense = self.get_einsum_dense(
-                     'ijh,jkh->ikh', (self.head_units, self.heads)
-                 )
-                 self._edge_dense.build([None, self.head_units, self.heads])
-         else:
-             self._update_edge_feature = False
-
+         self._update_edge_feature = self._has_edge_feature and self._update_edge_feature
+         if self._update_edge_feature:
+             self._edge_dense = self.get_einsum_dense(
+                 'ijh,jkh->ikh', (self.head_units, self.heads)
+             )
          self._node_dense = self.get_einsum_dense(
              'ij,jkh->ikh', (self.head_units, self.heads)
          )
-         self._node_dense.build([None, node_feature_dim])
-
          self._feature_dense = self.get_einsum_dense(
              'ij,jkh->ikh', (self.head_units, self.heads)
          )
-         self._feature_dense.build([None, attn_feature_dim])
-
          self._attention_dense = self.get_einsum_dense(
              'ijh,jkh->ikh', (1, self.heads)
          )
-         self._attention_dense.build([None, self.head_units, self.heads])
-
          self._node_self_dense = self.get_einsum_dense(
              'ij,jkh->ikh', (self.head_units, self.heads)
          )
-         self._node_self_dense.build([None, node_feature_dim])
-         self._dropout_layer = keras.layers.Dropout(self._dropout)

-         self.built = True
+         has_overridden_update = self.__class__.update != GAConv.update
+         if not has_overridden_update:
+             self._feedforward_intermediate_dense = self.get_dense(self.units)
+             self._feedforward_activation = self._activation
+             self._feedforward_output_dense = self.get_dense(self.units)

      def message(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
-
          attention_feature = keras.ops.concatenate(
              [
                  tensor.gather('feature', 'source'),
@@ -598,9 +764,8 @@ class GAConv(GraphConv):
          )

      def aggregate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
-         node_feature = tensor.aggregate('message')
+         node_feature = tensor.aggregate('message', mode='sum')
          node_feature += self._node_self_dense(tensor.node['feature'])
-         node_feature = self._dropout_layer(node_feature)
          node_feature = keras.ops.reshape(node_feature, (-1, self.units))
          return tensor.update(
              {
@@ -615,7 +780,10 @@ class GAConv(GraphConv):
          )

      def update(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
-         node_feature = self._activation(tensor.node['feature'])
+         node_feature = tensor.node['feature']
+         node_feature = self._feedforward_intermediate_dense(node_feature)
+         node_feature = self._feedforward_activation(node_feature)
+         node_feature = self._feedforward_output_dense(node_feature)
          return tensor.update(
              {
                  'node': {
@@ -628,8 +796,6 @@ class GAConv(GraphConv):
          config = super().get_config()
          config.update({
              "heads": self._heads,
-             'activation': keras.activations.serialize(self._activation),
-             'dropout': self._dropout,
              'update_edge_feature': self._update_edge_feature,
              'attention_activation': keras.activations.serialize(self._attention_activation),
          })
@@ -640,6 +806,34 @@ class GAConv(GraphConv):
  class GTConv(GraphConv):

      """Graph transformer layer.
+
+     >>> graph = molcraft.tensors.GraphTensor(
+     ...     context={
+     ...         'size': [2]
+     ...     },
+     ...     node={
+     ...         'feature': [[1.], [2.]]
+     ...     },
+     ...     edge={
+     ...         'source': [0, 1],
+     ...         'target': [1, 0],
+     ...     }
+     ... )
+     >>> conv = molcraft.layers.GTConv(units=4, heads=2)
+     >>> conv(graph)
+     GraphTensor(
+       context={
+         'size': <tf.Tensor: shape=[1], dtype=int32>
+       },
+       node={
+         'feature': <tf.Tensor: shape=[2, 4], dtype=float32>
+       },
+       edge={
+         'source': <tf.Tensor: shape=[2], dtype=int32>,
+         'target': <tf.Tensor: shape=[2], dtype=int32>
+       }
+     )
+
      """

      def __init__(
@@ -648,26 +842,22 @@ class GTConv(GraphConv):
          heads: int = 8,
          activation: keras.layers.Activation | str | None = "relu",
          use_bias: bool = True,
-         normalize: bool = True,
-         dropout: float = 0.0,
+         normalization: bool = False,
          attention_dropout: float = 0.0,
          **kwargs,
      ) -> None:
-         kwargs['skip_connection'] = False
          super().__init__(
              units=units,
-             normalize=normalize,
+             activation=activation,
              use_bias=use_bias,
+             normalization=normalization,
              **kwargs
          )
          self._heads = heads
          if self.units % self.heads != 0:
              raise ValueError(f"units need to be divisible by heads.")
          self._head_units = self.units // self.heads
-         self._activation = keras.activations.get(activation)
-         self._dropout = dropout
          self._attention_dropout = attention_dropout
-         self._normalize = normalize

      @property
      def heads(self):
@@ -677,69 +867,31 @@ class GTConv(GraphConv):
      def head_units(self):
          return self._head_units

-     def build_from_spec(self, spec):
-         """Builds the layer.
-         """
-         node_feature_dim = spec.node['feature'].shape[-1]
-         self.project_residual = node_feature_dim != self.units
-         if self.project_residual:
-             warn(
-                 '`GTConv` uses residual connections, but found incompatible dim '
-                 'between input (node feature dim) and output (`self.units`). '
-                 'Automatically applying a projection layer to residual to '
-                 'match input and output. '
-             )
-             self._residual_dense = self.get_dense(self.units)
-             self._residual_dense.build([None, node_feature_dim])
-
+     def build(self, spec: tensors.GraphTensor.Spec) -> None:
          self._query_dense = self.get_einsum_dense(
              'ij,jkh->ikh', (self.head_units, self.heads)
          )
-         self._query_dense.build([None, node_feature_dim])
-
          self._key_dense = self.get_einsum_dense(
              'ij,jkh->ikh', (self.head_units, self.heads)
          )
-         self._key_dense.build([None, node_feature_dim])
-
          self._value_dense = self.get_einsum_dense(
              'ij,jkh->ikh', (self.head_units, self.heads)
          )
-         self._value_dense.build([None, node_feature_dim])
-
          self._output_dense = self.get_dense(self.units)
-         self._output_dense.build([None, self.units])
-
          self._softmax_dropout = keras.layers.Dropout(self._attention_dropout)

-         self._self_attention_dropout = keras.layers.Dropout(self._dropout)
-
          self._add_bias = not 'bias' in spec.edge

          if self._add_bias:
              self._edge_bias = EdgeBias(biases=self.heads)
-             self._edge_bias.build_from_spec(spec)

          has_overridden_update = self.__class__.update != GTConv.update
          if not has_overridden_update:
-
-             if self._normalize:
-                 self._feedforward_output_norm = keras.layers.LayerNormalization()
-                 self._feedforward_output_norm.build([None, self.units])
-
-             self._feedforward_dropout = keras.layers.Dropout(self._dropout)
-
              self._feedforward_intermediate_dense = self.get_dense(self.units)
-             self._feedforward_intermediate_dense.build([None, self.units])
-
+             self._feedforward_activation = self._activation
              self._feedforward_output_dense = self.get_dense(self.units)
-             self._feedforward_output_dense.build([None, self.units])
-
-         self.built = True

      def message(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
-         """Computes messages.
-         """
          if self._add_bias:
              edge_bias = self._edge_bias(tensor)
              tensor = tensor.update(
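`GTConv` computes per-edge scaled dot-product attention and normalizes it with `ops.edge_softmax`, i.e. a softmax taken over each target node's incoming edges rather than over a dense sequence. A rough functional sketch of such an edge softmax using TensorFlow segment ops (an assumption about its behavior; the actual molcraft implementation may differ):

    import tensorflow as tf

    def edge_softmax_sketch(score, target, num_nodes):
        # Subtract the per-target max for numerical stability.
        score_max = tf.math.unsorted_segment_max(score, target, num_nodes)
        score = score - tf.gather(score_max, target)
        numerator = tf.exp(score)
        # Normalize by the sum of exponentials over each target's incoming edges.
        denominator = tf.math.unsorted_segment_sum(numerator, target, num_nodes)
        return numerator / tf.gather(denominator, target)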
@@ -764,7 +916,6 @@ class GTConv(GraphConv):
          attention_score /= keras.ops.sqrt(float(self.head_units))

          attention_score += keras.ops.expand_dims(tensor.edge['bias'], axis=1)
-
          attention = ops.edge_softmax(attention_score, tensor.edge['target'])
          attention = self._softmax_dropout(attention)

@@ -778,12 +929,9 @@ class GTConv(GraphConv):
          )

      def aggregate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
-         """Aggregates messages.
-         """
-         node_feature = tensor.aggregate('message')
+         node_feature = tensor.aggregate('message', mode='sum')
          node_feature = keras.ops.reshape(node_feature, (-1, self.units))
          node_feature = self._output_dense(node_feature)
-         node_feature = self._self_attention_dropout(node_feature)
          return tensor.update(
              {
                  'node': {
@@ -798,26 +946,10 @@ class GTConv(GraphConv):
          )

      def update(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
-         """Updates nodes.
-         """
          node_feature = tensor.node['feature']
-
-         residual = tensor.node['residual']
-         if self.project_residual:
-             residual = self._residual_dense(residual)
-
-         node_feature += residual
-         residual = node_feature
-
          node_feature = self._feedforward_intermediate_dense(node_feature)
-         node_feature = self._activation(node_feature)
+         node_feature = self._feedforward_activation(node_feature)
          node_feature = self._feedforward_output_dense(node_feature)
-         node_feature = self._feedforward_dropout(node_feature)
-         if self._normalize:
-             node_feature = self._feedforward_output_norm(node_feature)
-
-         node_feature += residual
-
          return tensor.update(
              {
                  'node': {
@@ -830,148 +962,48 @@ class GTConv(GraphConv):
          config = super().get_config()
          config.update({
              "heads": self._heads,
-             'activation': keras.activations.serialize(self._activation),
-             'dropout': self._dropout,
              'attention_dropout': self._attention_dropout,
          })
          return config


  @keras.saving.register_keras_serializable(package='molcraft')
- class GTConv3D(GTConv):
+ class MPConv(GraphConv):

-     """Graph transformer 3D layer.
+     """Message passing neural network layer.
      """

-     def build_from_spec(self, spec: tensors.GraphTensor.Spec) -> None:
-         super().build_from_spec(spec)
-         if self._add_bias:
-             node_feature_dim = spec.node['feature'].shape[-1]
-             kernels = self.units
-             self._gaussian_basis = GaussianDistance(kernels)
-             self._gaussian_basis.build_from_spec(spec)
-             self._centrality_dense = self.get_dense(units=node_feature_dim)
-             self._centrality_dense.build([None, kernels])
-             self._gaussian_edge_bias = self.get_dense(self.heads)
-             self._gaussian_edge_bias.build([None, kernels])
-
-     def message(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
-         """Computes messages.
-         """
-         node_feature = tensor.node['feature']
+     def __init__(
+         self,
+         units: int = 128,
+         activation: keras.layers.Activation | str | None = None,
+         use_bias: bool = True,
+         normalization: bool = False,
+         **kwargs
+     ) -> None:
+         super().__init__(
+             units=units,
+             activation=activation,
+             use_bias=use_bias,
+             normalization=normalization,
+             **kwargs
+         )

-         if self._add_bias:
-             gaussian = self._gaussian_basis(tensor)
-             centrality = keras.ops.segment_sum(
-                 gaussian, tensor.edge['target'], tensor.num_nodes
+     def build(self, spec: tensors.GraphTensor.Spec) -> None:
+         node_feature_dim = spec.node['feature'].shape[-1]
+         self.message_fn = self.get_dense(self.units, activation=self._activation)
+         self.update_fn = keras.layers.GRUCell(self.units)
+         self._has_edge_feature = 'feature' in spec.edge
+         self.project_input_node_feature = node_feature_dim != self.units
+         if self.project_input_node_feature:
+             warn(
+                 'Input node feature dim does not match updated node feature dim. '
+                 'To make sure input node feature can be passed as `states` to the '
+                 'GRU cell, it will automatically be projected prior to it.'
+             )
+             self._previous_node_dense = self.get_dense(
+                 self.units, activation=self._activation
              )
-             node_feature += self._centrality_dense(centrality)
-
-             edge_bias = self._edge_bias(tensor) + self._gaussian_edge_bias(gaussian)
-             tensor = tensor.update({'edge': {'bias': edge_bias}})
-
-         query = self._query_dense(node_feature)
-         key = self._key_dense(node_feature)
-         value = self._value_dense(node_feature)
-
-         query = ops.gather(query, tensor.edge['source'])
-         key = ops.gather(key, tensor.edge['target'])
-         value = ops.gather(value, tensor.edge['source'])
-
-         attention_score = keras.ops.sum(query * key, axis=1, keepdims=True)
-         attention_score /= keras.ops.sqrt(float(self.head_units))
-
-         attention_score += keras.ops.expand_dims(tensor.edge['bias'], axis=1)
-
-         attention = ops.edge_softmax(attention_score, tensor.edge['target'])
-         attention = self._softmax_dropout(attention)
-
-         distance = keras.ops.subtract(
-             tensor.gather('coordinate', 'source'),
-             tensor.gather('coordinate', 'target')
-         )
-         euclidean_distance = ops.euclidean_distance(
-             tensor.gather('coordinate', 'source'),
-             tensor.gather('coordinate', 'target'),
-             axis=-1
-         )
-         distance /= euclidean_distance
-
-         attention *= keras.ops.expand_dims(distance, axis=-1)
-         attention = keras.ops.expand_dims(attention, axis=2)
-         value = keras.ops.expand_dims(value, axis=1)
-
-         return tensor.update(
-             {
-                 'edge': {
-                     'message': value,
-                     'weight': attention,
-                 },
-             }
-         )
-
-     def aggregate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
-         """Aggregates messages.
-         """
-         node_feature = tensor.aggregate('message')
-         node_feature = keras.ops.reshape(
-             node_feature, (tensor.num_nodes, -1, self.units)
-         )
-         node_feature = self._output_dense(node_feature)
-         node_feature = keras.ops.sum(node_feature, axis=1)
-         node_feature = self._self_attention_dropout(node_feature)
-         return tensor.update(
-             {
-                 'node': {
-                     'feature': node_feature,
-                     'residual': tensor.node['feature']
-                 },
-                 'edge': {
-                     'message': None,
-                     'weight': None,
-                 }
-             }
-         )
-
-
- @keras.saving.register_keras_serializable(package='molcraft')
- class MPConv(GraphConv):
-
-     """Message passing neural network layer.
-     """
-
-     def __init__(
-         self,
-         units: int = 128,
-         activation: keras.layers.Activation | str | None = None,
-         use_bias: bool = True,
-         normalize: bool = True,
-         dropout: float = 0.0,
-         **kwargs
-     ) -> None:
-         super().__init__(
-             units=units,
-             normalize=normalize,
-             use_bias=use_bias,
-             **kwargs
-         )
-         self._activation = keras.activations.get(activation)
-         self._dropout = dropout or 0.0
-
-     def build_from_spec(self, spec: tensors.GraphTensor.Spec) -> None:
-         node_feature_dim = spec.node['feature'].shape[-1]
-         self.message_fn = self.get_dense(self.units, activation=self._activation)
-         self.update_fn = keras.layers.GRUCell(self.units)
-         self._has_edge_feature = 'feature' in spec.edge
-         self.project_input_node_feature = node_feature_dim != self.units
-         if self.project_input_node_feature:
-             warn(
-                 'Input node feature dim does not match updated node feature dim. '
-                 'To make sure input node feature can be passed as `states` to the '
-                 'GRU cell, it will automatically be projected prior to it.'
-             )
-             self._previous_node_dense = self.get_dense(self.units, activation=self._activation)
-         self.built = True

      def message(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
          feature = keras.ops.concatenate(
@@ -999,7 +1031,7 @@ class MPConv(GraphConv):
          )

      def aggregate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
-         aggregate = tensor.aggregate('message')
+         aggregate = tensor.aggregate('message', mode='mean')
          previous = tensor.node['feature']
          if self.project_input_node_feature:
              previous = self._previous_node_dense(previous)
@@ -1028,17 +1060,105 @@ class MPConv(GraphConv):

      def get_config(self) -> dict:
          config = super().get_config()
-         config.update({
-             'activation': keras.activations.serialize(self._activation),
-             'dropout': self._dropout,
-         })
+         config.update({})
          return config


+ @keras.saving.register_keras_serializable(package='molcraft')
+ class GTConv3D(GTConv):
+
+     """Graph transformer layer 3D.
+     """
+
+     def build(self, spec: tensors.GraphTensor.Spec) -> None:
+         """Builds the layer.
+         """
+         super().build(spec)
+         if self._add_bias:
+             node_feature_dim = spec.node['feature'].shape[-1]
+             kernels = self.units
+             self._gaussian_basis = GaussianDistance(kernels)
+             self._centrality_dense = self.get_dense(units=node_feature_dim)
+             self._gaussian_edge_bias = self.get_dense(self.heads)
+
+     def message(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
+         node_feature = tensor.node['feature']
+
+         if self._add_bias:
+             gaussian = self._gaussian_basis(tensor)
+             centrality = keras.ops.segment_sum(
+                 gaussian, tensor.edge['target'], tensor.num_nodes
+             )
+             node_feature += self._centrality_dense(centrality)
+
+             edge_bias = self._edge_bias(tensor) + self._gaussian_edge_bias(gaussian)
+             tensor = tensor.update({'edge': {'bias': edge_bias}})
+
+         query = self._query_dense(node_feature)
+         key = self._key_dense(node_feature)
+         value = self._value_dense(node_feature)
+
+         query = ops.gather(query, tensor.edge['source'])
+         key = ops.gather(key, tensor.edge['target'])
+         value = ops.gather(value, tensor.edge['source'])
+
+         attention_score = keras.ops.sum(query * key, axis=1, keepdims=True)
+         attention_score /= keras.ops.sqrt(float(self.head_units))
+
+         attention_score += keras.ops.expand_dims(tensor.edge['bias'], axis=1)
+
+         attention = ops.edge_softmax(attention_score, tensor.edge['target'])
+         attention = self._softmax_dropout(attention)
+
+         distance = keras.ops.subtract(
+             tensor.gather('coordinate', 'source'),
+             tensor.gather('coordinate', 'target')
+         )
+         euclidean_distance = ops.euclidean_distance(
+             tensor.gather('coordinate', 'source'),
+             tensor.gather('coordinate', 'target'),
+             axis=-1
+         )
+         distance /= euclidean_distance
+
+         attention *= keras.ops.expand_dims(distance, axis=-1)
+         attention = keras.ops.expand_dims(attention, axis=2)
+         value = keras.ops.expand_dims(value, axis=1)
+
+         return tensor.update(
+             {
+                 'edge': {
+                     'message': value,
+                     'weight': attention,
+                 },
+             }
+         )
+
+     def aggregate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
+         node_feature = tensor.aggregate('message', mode='sum')
+         node_feature = keras.ops.reshape(
+             node_feature, (tensor.num_nodes, -1, self.units)
+         )
+         node_feature = self._output_dense(node_feature)
+         node_feature = keras.ops.sum(node_feature, axis=1)
+         return tensor.update(
+             {
+                 'node': {
+                     'feature': node_feature,
+                     'residual': tensor.node['feature']
+                 },
+                 'edge': {
+                     'message': None,
+                     'weight': None,
+                 }
+             }
+         )
+
+
  @keras.saving.register_keras_serializable(package='molcraft')
  class MPConv3D(MPConv):

-     """3D Message passing neural network layer.
+     """Message passing neural network layer 3D.
      """

      def message(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
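`MPConv` updates nodes with a `keras.layers.GRUCell`: the aggregated message acts as the cell input and the (possibly projected) previous node feature as the recurrent state, one step per layer call. A standalone sketch of that single GRU step (illustrative values only; not the package's exact `update`):

    import numpy as np
    import keras

    units = 8
    cell = keras.layers.GRUCell(units)
    aggregate = keras.ops.convert_to_tensor(np.random.rand(5, units).astype('float32'))
    previous = keras.ops.convert_to_tensor(np.random.rand(5, units).astype('float32'))
    # One recurrent step per node: message as input, previous feature as state.
    feature, _ = cell(aggregate, [previous])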
@@ -1076,7 +1196,7 @@ class MPConv3D(MPConv):
  @keras.saving.register_keras_serializable(package='molcraft')
  class EGConv3D(GraphConv):

-     """Equivariant graph neural network layer.
+     """Equivariant graph neural network layer 3D.
      """

      def __init__(
@@ -1084,49 +1204,33 @@ class EGConv3D(GraphConv):
          units: int = 128,
          activation: keras.layers.Activation | str | None = None,
          use_bias: bool = True,
-         normalize: bool = True,
-         dropout: float = 0.0,
+         normalization: bool = False,
          **kwargs
      ) -> None:
          super().__init__(
              units=units,
-             normalize=normalize,
+             activation=activation,
              use_bias=use_bias,
+             normalization=normalization,
              **kwargs
          )
-         self._activation = keras.activations.get(activation)
-         self._dropout = dropout or 0.0

-     def build_from_spec(self, spec: tensors.GraphTensor.Spec) -> None:
+     def build(self, spec: tensors.GraphTensor.Spec) -> None:
          if 'coordinate' not in spec.node:
              raise ValueError(
                  'Could not find `coordinate`s in node, '
                  'which is required for Conv3D layers.'
              )
-         node_feature_dim = spec.node['feature'].shape[-1]
-         feature_dim = node_feature_dim + node_feature_dim + 1
-         if 'feature' in spec.edge:
-             self._has_edge_feature = True
-             edge_feature_dim = spec.edge['feature'].shape[-1]
-             feature_dim += edge_feature_dim
-         else:
-             self._has_edge_feature = False
-
+         self._has_edge_feature = 'feature' in spec.edge
          self.message_fn = self.get_dense(self.units, activation=self._activation)
-         self.message_fn.build([None, feature_dim])
          self.dense_position = self.get_dense(1)
-         self.dense_position.build([None, self.units])

          has_overridden_update = self.__class__.update != EGConv3D.update
          if not has_overridden_update:
              self.update_fn = self.get_dense(self.units, activation=self._activation)
-             self.update_fn.build([None, node_feature_dim + self.units])
-             self._dropout_layer = keras.layers.Dropout(self._dropout)
-         self.built = True
+             self.output_dense = self.get_dense(self.units)

      def message(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
-         """Computes messages.
-         """
          relative_node_coordinate = keras.ops.subtract(
              tensor.gather('coordinate', 'target'),
              tensor.gather('coordinate', 'source')
@@ -1169,8 +1273,6 @@ class EGConv3D(GraphConv):
          )

      def aggregate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
-         """Aggregates messages.
-         """
          coefficient = keras.ops.bincount(
              tensor.edge['source'],
              minlength=tensor.num_nodes
@@ -1185,7 +1287,7 @@ class EGConv3D(GraphConv):
          updated_coordinate = tensor.aggregate('relative_node_coordinate') * coefficient
          updated_coordinate += tensor.node['coordinate']

-         aggregate = tensor.aggregate('message')
+         aggregate = tensor.aggregate('message', mode='mean')
          return tensor.update(
              {
                  'node': {
@@ -1201,8 +1303,6 @@ class EGConv3D(GraphConv):
          )

      def update(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
-         """Updates nodes.
-         """
          updated_node_feature = self.update_fn(
              keras.ops.concatenate(
                  [
@@ -1212,7 +1312,7 @@ class EGConv3D(GraphConv):
                  axis=-1
              )
          )
-         updated_node_feature = self._dropout_layer(updated_node_feature)
+         updated_node_feature = self.output_dense(updated_node_feature)
          return tensor.update(
              {
                  'node': {
@@ -1224,66 +1324,46 @@ class EGConv3D(GraphConv):

      def get_config(self) -> dict:
          config = super().get_config()
-         config.update({
-             'activation': keras.activations.serialize(self._activation),
-             'dropout': self._dropout,
-         })
+         config.update({})
          return config


  @keras.saving.register_keras_serializable(package='molcraft')
- class Projection(GraphLayer):
-     """Base graph projection layer.
+ class Readout(GraphLayer):
+
+     """Readout layer.
      """
-     def __init__(
-         self,
-         units: int = None,
-         activation: str = None,
-         field: str = 'node',
-         **kwargs
-     ) -> None:
-         super().__init__(**kwargs)
-         self.units = units
-         self._activation = keras.activations.get(activation)
-         self.field = field

-     def build_from_spec(self, spec: tensors.GraphTensor.Spec) -> None:
-         """Builds the layer.
-         """
-         data = getattr(spec, self.field, None)
-         if data is None:
-             raise ValueError('Could not access field {self.field!r}.')
-         feature_dim = data['feature'].shape[-1]
-         if not self.units:
-             self.units = feature_dim
-         self._dense = self.get_dense(self.units)
-         self._dense.build([None, feature_dim])
-         self.built = True
+     def __init__(self, mode: str | None = None, **kwargs):
+         kwargs['kernel_initializer'] = None
+         kwargs['bias_initializer'] = None
+         super().__init__(**kwargs)
+         self.mode = mode
+         if str(self.mode).lower().startswith('sum'):
+             self._reduce_fn = keras.ops.segment_sum
+         elif str(self.mode).lower().startswith('max'):
+             self._reduce_fn = keras.ops.segment_max
+         elif str(self.mode).lower().startswith('super'):
+             self._reduce_fn = keras.ops.segment_sum
+         else:
+             self._reduce_fn = ops.segment_mean

-     def propagate(self, tensor: tensors.GraphTensor):
-         """Calls the layer.
-         """
-         feature = getattr(tensor, self.field)['feature']
-         feature = self._dense(feature)
-         feature = self._activation(feature)
-         return tensor.update(
-             {
-                 self.field: {
-                     'feature': feature
-                 }
-             }
-         )
+     def propagate(self, tensor: tensors.GraphTensor) -> tf.Tensor:
+         node_feature = tensor.node['feature']
+         if str(self.mode).lower().startswith('super'):
+             node_feature = keras.ops.where(
+                 tensor.node['super'][:, None], node_feature, 0.0
+             )
+         return self._reduce_fn(
+             node_feature, tensor.graph_indicator, tensor.num_subgraphs
+         )

      def get_config(self) -> dict:
          config = super().get_config()
-         config.update({
-             'units': self.units,
-             'activation': keras.activations.serialize(self._activation),
-             'field': self.field,
-         })
-         return config
+         config['mode'] = self.mode
+         return config
+

-
  @keras.saving.register_keras_serializable(package='molcraft')
  class GraphNetwork(GraphLayer):

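The new `Readout` layer collapses node features into one vector per subgraph with a segment reduction chosen by prefix-matching `mode`: `sum`, `max`, `super` (a sum restricted to super nodes), and anything else falls back to a segment mean. A brief usage sketch (assuming a batched `GraphTensor` as elsewhere in this diff):

    from molcraft import layers

    readout = layers.Readout(mode='mean')
    # graph_embedding = readout(graph)  # shape: [num_subgraphs, feature_dim]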
@@ -1291,7 +1371,7 @@ class GraphNetwork(GraphLayer):

      Sequentially calls graph layers (`GraphLayer`) and concatenates their outputs.

-     Args:
+     Arguments:
          layers (list):
              A list of graph layers.
      """
@@ -1301,37 +1381,32 @@ class GraphNetwork(GraphLayer):
          self.layers = layers
          self._update_edge_feature = False

-     def build_from_spec(self, spec: tensors.GraphTensor.Spec) -> None:
-         """Builds the layer.
-         """
+     def build(self, spec: tensors.GraphTensor.Spec) -> None:
          units = self.layers[0].units
          node_feature_dim = spec.node['feature'].shape[-1]
-         if node_feature_dim != units:
+         self._update_node_feature = node_feature_dim != units
+         if self._update_node_feature:
              warn(
                  'Node feature dim does not match `units` of the first layer. '
                  'Automatically adding a node projection layer to match `units`.'
              )
              self._node_dense = self.get_dense(units)
-             self._update_node_feature = True
-         has_edge_feature = 'feature' in spec.edge
-         if has_edge_feature:
+         self._has_edge_feature = 'feature' in spec.edge
+         if self._has_edge_feature:
              edge_feature_dim = spec.edge['feature'].shape[-1]
-             if edge_feature_dim != units:
+             self._update_edge_feature = edge_feature_dim != units
+             if self._update_edge_feature:
                  warn(
                      'Edge feature dim does not match `units` of the first layer. '
                      'Automatically adding an edge projection layer to match `units`.'
                  )
                  self._edge_dense = self.get_dense(units)
-                 self._update_edge_feature = True
-         self.built = True

      def propagate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
-         """Calls the layer.
-         """
          x = tensors.to_dict(tensor)
          if self._update_node_feature:
              x['node']['feature'] = self._node_dense(tensor.node['feature'])
-         if self._update_edge_feature:
+         if self._has_edge_feature and self._update_edge_feature:
              x['edge']['feature'] = self._edge_dense(tensor.edge['feature'])
          outputs = [x['node']['feature']]
          for layer in self.layers:
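`GraphNetwork` runs its layers in sequence and, per its docstring, concatenates each layer's node feature output (plus the projected input) into the final node feature. A composition sketch (hypothetical hyperparameters):

    from molcraft import layers

    net = layers.GraphNetwork([
        layers.GIConv(units=128),
        layers.GIConv(units=128),
    ])
    # graph = net(graph)  # node 'feature' is the concatenation of layer outputs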
@@ -1356,7 +1431,7 @@ class GraphNetwork(GraphLayer):
1356
1431
  Performs the same forward pass as `propagate` but with a `GradientTape`
1357
1432
  watching intermediate node features.
1358
1433
 
1359
- Args:
1434
+ Arguments:
1360
1435
  tensor (tensors.GraphTensor):
1361
1436
  The graph input.
1362
1437
  """
@@ -1414,26 +1489,25 @@ class NodeEmbedding(GraphLayer):
      def __init__(
          self,
          dim: int = None,
-         normalize: bool = True,
+         normalization: bool | str = False,
          embed_context: bool = True,
+         allow_reconstruction: bool = False,
          allow_masking: bool = True,
          **kwargs
      ) -> None:
          super().__init__(**kwargs)
          self.dim = dim
-         self._normalize = normalize
+         self._normalization = normalization
          self._embed_context = embed_context
          self._masking_rate = None
          self._allow_masking = allow_masking
+         self._allow_reconstruction = allow_reconstruction
 
-     def build_from_spec(self, spec: tensors.GraphTensor.Spec) -> None:
-         """Builds the layer.
-         """
+     def build(self, spec: tensors.GraphTensor.Spec) -> None:
          feature_dim = spec.node['feature'].shape[-1]
          if not self.dim:
              self.dim = feature_dim
          self._node_dense = self.get_dense(self.dim)
-         self._node_dense.build([None, feature_dim])
 
          self._has_super = 'super' in spec.node
          has_context_feature = 'feature' in spec.context
@@ -1447,17 +1521,18 @@ class NodeEmbedding(GraphLayer):
          if self._embed_context:
              context_feature_dim = spec.context['feature'].shape[-1]
              self._context_dense = self.get_dense(self.dim)
-             self._context_dense.build([None, context_feature_dim])
 
-         if self._normalize:
-             self._norm = keras.layers.LayerNormalization()
-             self._norm.build([None, self.dim])
-
-         self.built = True
+         if self._normalization:
+             if str(self._normalization).lower().startswith('batch'):
+                 self._norm = keras.layers.BatchNormalization(
+                     name='output_batch_norm'
+                 )
+             else:
+                 self._norm = keras.layers.LayerNormalization(
+                     name='output_layer_norm'
+                 )
 
      def propagate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
-         """Calls the layer.
-         """
          feature = self._node_dense(tensor.node['feature'])
 
          if self._has_super:
@@ -1487,11 +1562,13 @@ class NodeEmbedding(GraphLayer):
              # Silence warning of 'no gradients for variables'
              feature = feature + (self._mask_feature * 0.0)
 
-         if self._normalize:
+         if self._normalization:
              feature = self._norm(feature)
 
-         return tensor.update({'node': {'feature': feature}})
-
+         if not self._allow_reconstruction:
+             return tensor.update({'node': {'feature': feature}})
+         return tensor.update({'node': {'feature': feature, 'target_feature': feature}})
+
      @property
      def masking_rate(self):
          return self._masking_rate
@@ -1509,9 +1586,10 @@ class NodeEmbedding(GraphLayer):
          config = super().get_config()
          config.update({
              'dim': self.dim,
-             'normalize': self._normalize,
+             'normalization': self._normalization,
              'embed_context': self._embed_context,
-             'allow_masking': self._allow_masking
+             'allow_masking': self._allow_masking,
+             'allow_reconstruction': self._allow_reconstruction,
          })
          return config
 
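Taken together, the `NodeEmbedding` changes swap the boolean `normalize` flag for a `normalization` argument (falsy to disable, a string starting with 'batch' for `BatchNormalization`, anything else truthy for `LayerNormalization`) and add an opt-in `allow_reconstruction` switch that mirrors the embedded features into `node['target_feature']`. A hedged sketch, assuming a `GraphTensor` named `graph`:

    from molcraft import layers

    embedding = layers.NodeEmbedding(
        dim=128,
        normalization='batch',       # or True/'layer' for LayerNormalization
        allow_reconstruction=True,   # also emit node['target_feature']
    )
    graph = embedding(graph)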
@@ -1527,39 +1605,39 @@ class EdgeEmbedding(GraphLayer):
      def __init__(
          self,
          dim: int = None,
-         normalize: bool = True,
+         normalization: bool | str = False,
          allow_masking: bool = True,
          **kwargs
      ) -> None:
          super().__init__(**kwargs)
          self.dim = dim
-         self._normalize = normalize
+         self._normalization = normalization
          self._masking_rate = None
          self._allow_masking = allow_masking
 
-     def build_from_spec(self, spec: tensors.GraphTensor.Spec) -> None:
-         """Builds the layer.
-         """
+     def build(self, spec: tensors.GraphTensor.Spec) -> None:
          feature_dim = spec.edge['feature'].shape[-1]
          if not self.dim:
              self.dim = feature_dim
          self._edge_dense = self.get_dense(self.dim)
-         self._edge_dense.build([None, feature_dim])
 
          self._has_super = 'super' in spec.edge
          if self._has_super:
              self._super_feature = self.get_weight(shape=[self.dim], name='super_edge_feature')
          if self._allow_masking:
              self._mask_feature = self.get_weight(shape=[self.dim], name='mask_edge_feature')
-         if self._normalize:
-             self._norm = keras.layers.LayerNormalization()
-             self._norm.build([None, self.dim])
 
-         self.built = True
+         if self._normalization:
+             if str(self._normalization).lower().startswith('batch'):
+                 self._norm = keras.layers.BatchNormalization(
+                     name='output_batch_norm'
+                 )
+             else:
+                 self._norm = keras.layers.LayerNormalization(
+                     name='output_layer_norm'
+                 )
 
      def propagate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
-         """Calls the layer.
-         """
          feature = self._edge_dense(tensor.edge['feature'])
 
          if self._has_super:
@@ -1584,10 +1662,10 @@ class EdgeEmbedding(GraphLayer):
              # Silence warning of 'no gradients for variables'
              feature = feature + (self._mask_feature * 0.0)
 
-         if self._normalize:
+         if self._normalization:
              feature = self._norm(feature)
 
-         return tensor.update({'edge': {'feature': feature}})
+         return tensor.update({'edge': {'feature': feature, 'embedding': feature}})
 
      @property
      def masking_rate(self):
@@ -1606,18 +1684,67 @@ class EdgeEmbedding(GraphLayer):
          config = super().get_config()
          config.update({
              'dim': self.dim,
-             'normalize': self._normalize,
+             'normalization': self._normalization,
              'allow_masking': self._allow_masking
          })
          return config
 
 
+ @keras.saving.register_keras_serializable(package='molcraft')
+ class Projection(GraphLayer):
+     """Base graph projection layer.
+     """
+     def __init__(
+         self,
+         units: int = None,
+         activation: str | keras.layers.Activation | None = None,
+         use_bias: bool = True,
+         field: str = 'node',
+         **kwargs
+     ) -> None:
+         super().__init__(use_bias=use_bias, **kwargs)
+         self.units = units
+         self._activation = keras.activations.get(activation)
+         self.field = field
+
+     def build(self, spec: tensors.GraphTensor.Spec) -> None:
+         data = getattr(spec, self.field, None)
+         if data is None:
+             raise ValueError(f'Could not access field {self.field!r}.')
+         feature_dim = data['feature'].shape[-1]
+         if not self.units:
+             self.units = feature_dim
+         self._dense = self.get_dense(self.units)
+
+     def propagate(self, tensor: tensors.GraphTensor):
+         feature = getattr(tensor, self.field)['feature']
+         feature = self._dense(feature)
+         feature = self._activation(feature)
+         return tensor.update(
+             {
+                 self.field: {
+                     'feature': feature
+                 }
+             }
+         )
+
+     def get_config(self) -> dict:
+         config = super().get_config()
+         config.update({
+             'units': self.units,
+             'activation': keras.activations.serialize(self._activation),
+             'field': self.field,
+         })
+         return config
+
+
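The new `Projection` base class above generalizes the per-field projection layers: `field` selects which component of the graph ('node', 'edge' or 'context') receives the dense transform, and the subclasses below merely pin that argument. A minimal sketch (assuming a `GraphTensor` named `graph`):

    from molcraft import layers

    # Equivalent to layers.NodeProjection(units=64, activation='relu')
    projection = layers.Projection(units=64, activation='relu', field='node')
    graph = projection(graph)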
  @keras.saving.register_keras_serializable(package='molcraft')
  class ContextProjection(Projection):
      """Context projection layer.
      """
      def __init__(self, units: int = None, activation: str = None, **kwargs):
-         super().__init__(units=units, activation=activation, field='context', **kwargs)
+         kwargs['field'] = 'context'
+         super().__init__(units=units, activation=activation, **kwargs)
 
 
  @keras.saving.register_keras_serializable(package='molcraft')
@@ -1625,7 +1752,8 @@ class NodeProjection(Projection):
      """Node projection layer.
      """
      def __init__(self, units: int = None, activation: str = None, **kwargs):
-         super().__init__(units=units, activation=activation, field='node', **kwargs)
+         kwargs['field'] = 'node'
+         super().__init__(units=units, activation=activation, **kwargs)
 
 
  @keras.saving.register_keras_serializable(package='molcraft')
@@ -1633,9 +1761,55 @@ class EdgeProjection(Projection):
      """Edge projection layer.
      """
      def __init__(self, units: int = None, activation: str = None, **kwargs):
-         super().__init__(units=units, activation=activation, field='edge', **kwargs)
+         kwargs['field'] = 'edge'
+         super().__init__(units=units, activation=activation, **kwargs)
 
 
+ @keras.saving.register_keras_serializable(package='molcraft')
+ class Reconstruction(GraphLayer):
+
+     def __init__(
+         self,
+         loss: keras.losses.Loss | str = 'mse',
+         loss_weight: float = 0.5,
+         **kwargs
+     ):
+         super().__init__(**kwargs)
+         self._loss_fn = keras.losses.get(loss)
+         self._loss_weight = loss_weight
+
+     def build(self, spec: tensors.GraphTensor.Spec) -> None:
+         has_target_node_feature = 'target_feature' in spec.node
+         if not has_target_node_feature:
+             raise ValueError(
+                 'Could not find `target_feature` in `spec.node`. '
+                 'Add a `target_feature` via `NodeEmbedding` by setting '
+                 '`allow_reconstruction` to `True`.'
+             )
+         output_dim = spec.node['target_feature'].shape[-1]
+         self._dense = self.get_dense(output_dim)
+
+     def propagate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
+         target_node_feature = tensor.node['target_feature']
+         transformed_node_feature = tensor.node['feature']
+
+         reconstructed_node_feature = self._dense(
+             transformed_node_feature
+         )
+
+         loss = self._loss_fn(
+             target_node_feature, reconstructed_node_feature
+         )
+         self.add_loss(keras.ops.sum(loss) * self._loss_weight)
+         return tensor.update({'node': {'feature': transformed_node_feature}})
+
+     def get_config(self):
+         config = super().get_config()
+         config['loss'] = keras.losses.serialize(self._loss_fn)
+         config['loss_weight'] = self._loss_weight
+         return config
+
+
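The new `Reconstruction` layer pairs with `NodeEmbedding(allow_reconstruction=True)`: it regresses the current node features onto the stored `target_feature`, registers the weighted loss via `add_loss`, and passes the graph through otherwise unchanged. A hedged end-to-end sketch (hypothetical composition; `graph` is an assumed `GraphTensor`):

    from molcraft import layers

    blocks = [
        layers.NodeEmbedding(dim=128, allow_reconstruction=True),
        # ... message-passing layers would go here ...
        layers.Reconstruction(loss='mse', loss_weight=0.5),
    ]
    for block in blocks:
        graph = block(graph)  # auxiliary loss accumulates via add_loss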
  @keras.saving.register_keras_serializable(package='molcraft')
  class EdgeBias(GraphLayer):
 
@@ -1643,18 +1817,15 @@ class EdgeBias(GraphLayer):
          super().__init__(**kwargs)
          self.biases = biases
 
-     def build_from_spec(self, spec: tensors.GraphTensor.Spec) -> None:
+     def build(self, spec: tensors.GraphTensor.Spec) -> None:
          self._has_edge_length = 'length' in spec.edge
          self._has_edge_feature = 'feature' in spec.edge
          if self._has_edge_feature:
              self._edge_feature_dense = self.get_dense(self.biases)
-             self._edge_feature_dense.build([None, spec.edge['feature'].shape[-1]])
          if self._has_edge_length:
              self._edge_length_dense = self.get_dense(
                  self.biases, kernel_initializer='zeros'
              )
-             self._edge_length_dense.build([None, spec.edge['length'].shape[-1]])
-         self.built = True
 
      def propagate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
          bias = keras.ops.zeros(
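For orientation, `EdgeBias` maps edge features and/or edge lengths to per-edge bias terms, with the length branch zero-initialized so it starts as a no-op. A brief sketch (hypothetical bias count; `graph` is an assumed `GraphTensor` with edge data):

    from molcraft import layers

    edge_bias = layers.EdgeBias(biases=8)  # e.g. one bias per attention head
    graph = edge_bias(graph)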
@@ -1680,7 +1851,7 @@ class GaussianDistance(GraphLayer):
          super().__init__(**kwargs)
          self.kernels = kernels
 
-     def build_from_spec(self, spec: tensors.GraphTensor.Spec) -> None:
+     def build(self, spec: tensors.GraphTensor.Spec) -> None:
          self._loc = self.add_weight(
              shape=[self.kernels],
              initializer='zeros',
@@ -1693,8 +1864,7 @@ class GaussianDistance(GraphLayer):
              dtype='float32',
              trainable=True
          )
-         self.built = True
-
+
      def propagate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
          euclidean_distance = ops.euclidean_distance(
              tensor.gather('coordinate', 'source'),
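`GaussianDistance` expands each source-to-target Euclidean distance into `kernels` radial-basis features using trainable parameters (the zero-initialized `loc` centers plus a second trainable vector of the same shape). A brief sketch, assuming node coordinates are stored under `node['coordinate']`:

    from molcraft import layers

    rbf = layers.GaussianDistance(kernels=32)
    graph = rbf(graph)  # 'graph' is an assumed GraphTensor with coordinates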
@@ -1711,49 +1881,6 @@ class GaussianDistance(GraphLayer):
              'kernels': self.kernels,
          })
          return config
-
-
- @keras.saving.register_keras_serializable(package='molcraft')
- class Readout(GraphLayer):
-
-     """Readout layer.
-     """
-
-     def __init__(self, mode: str | None = None, **kwargs):
-         kwargs['kernel_initializer'] = None
-         kwargs['bias_initializer'] = None
-         super().__init__(**kwargs)
-         self.mode = mode
-         if str(self.mode).lower().startswith('sum'):
-             self._reduce_fn = keras.ops.segment_sum
-         elif str(self.mode).lower().startswith('max'):
-             self._reduce_fn = keras.ops.segment_max
-         elif str(self.mode).lower().startswith('super'):
-             self._reduce_fn = keras.ops.segment_sum
-         else:
-             self._reduce_fn = ops.segment_mean
-
-     def build_from_spec(self, spec: tensors.GraphTensor.Spec) -> None:
-         """Builds the layer.
-         """
-         self.built = True
-
-     def propagate(self, tensor: tensors.GraphTensor) -> tf.Tensor:
-         """Calls the layer.
-         """
-         node_feature = tensor.node['feature']
-         if str(self.mode).lower().startswith('super'):
-             node_feature = keras.ops.where(
-                 tensor.node['super'][:, None], node_feature, 0.0
-             )
-         return self._reduce_fn(
-             node_feature, tensor.graph_indicator, tensor.num_subgraphs
-         )
-
-     def get_config(self) -> dict:
-         config = super().get_config()
-         config['mode'] = self.mode
-         return config
 
 
  def Input(spec: tensors.GraphTensor.Spec) -> dict: