molcraft 0.1.0a1__py3-none-any.whl → 0.1.0a3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


molcraft/layers.py CHANGED
@@ -60,25 +60,20 @@ class GraphLayer(keras.layers.Layer):
         May use built-in methods such as `get_weight`, `get_dense` and `get_einsum_dense`.
 
         Optionally implemented by subclass. If implemented, it is recommended to
-        build the sub-layers via `build([None, input_dim])`. If sub-layers are not
-        built, symbolic input will be passed through the layer to build it.
+        If sub-layers are built (via `build` or `build_from_spec`), set `built`
+        to True. If not, symbolic input will be passed through the layer to build them.
 
         Args:
             spec:
-                A `GraphTensor.Spec` instance, corresponding to the input `GraphTensor`
-                of the `propagate` method.
+                A `GraphTensor.Spec` instance, corresponding to the `GraphTensor`
+                passed to `propagate`.
         """
-
+
     def build(self, spec: tensors.GraphTensor.Spec) -> None:
 
         self._custom_build_config = {'spec': _serialize_spec(spec)}
 
-        invoke_build_from_spec = (
-            GraphLayer.build_from_spec != self.__class__.build_from_spec
-        )
-        if invoke_build_from_spec:
-            self.build_from_spec(spec)
-            self.built = True
+        self.build_from_spec(spec)
 
         if not self.built:
             # Automatically build layer or model by calling it on symbolic inputs
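Editor's note: a minimal sketch of the revised build contract (hypothetical subclass; `MyProjection` and its fields are illustrative and assume the `molcraft.layers` / `molcraft.tensors` module layout):

    from molcraft import layers, tensors

    class MyProjection(layers.GraphLayer):

        def build_from_spec(self, spec: tensors.GraphTensor.Spec) -> None:
            # Build sub-layers from the spec, then mark the layer as built;
            # if `built` is left False, symbolic input is passed through
            # the layer to build the sub-layers instead.
            feature_dim = spec.node['feature'].shape[-1]
            self._dense = self.get_dense(feature_dim)
            self._dense.build([None, feature_dim])
            self.built = True

        def propagate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
            feature = self._dense(tensor.node['feature'])
            return tensor.update({'node': {'feature': feature}})
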
@@ -206,12 +201,66 @@ class GraphLayer(keras.layers.Layer):
 class GraphConv(GraphLayer):
 
     """Base graph neural network layer.
+
+    For normalization and skip connection to work, the `GraphConv` subclass
+    requires the (node feature) output of `aggregate` and `update` to have a
+    dimension of `self.units`, respectively.
+
+    Args:
+        units:
+            The number of units.
+        normalize:
+            Whether `LayerNormalization` should be applied to the (node feature) output
+            of the `aggregate` step. While normalization is recommended, it is not used
+            by default.
+        skip_connection:
+            Whether (node feature) input should be added to the (node feature) output.
+            If (node feature) input dim is not equal to `units`, a projection layer will
+            automatically project the residual before adding it to the output. While skip
+            connection is recommended, it is not used by default.
+        kwargs:
+            See arguments of `GraphLayer`.
     """
 
-    def __init__(self, units: int, **kwargs) -> None:
+    def __init__(
+        self,
+        units: int = None,
+        normalize: bool = False,
+        skip_connection: bool = False,
+        **kwargs
+    ) -> None:
         super().__init__(**kwargs)
         self.units = units
-
+        self._normalize_aggregate = normalize
+        self._skip_connection = skip_connection
+
+    def build(self, spec: tensors.GraphTensor.Spec) -> None:
+        if not self.units:
+            raise ValueError(
+                f'`self.units` needs to be a positive integer. Found: {self.units}.'
+            )
+        node_feature_dim = spec.node['feature'].shape[-1]
+        self._project_input_node_feature = (
+            self._skip_connection and (node_feature_dim != self.units)
+        )
+        if self._project_input_node_feature:
+            warn(
+                '`skip_connection` is set to `True`, but found incompatible dim '
+                'between input (node feature dim) and output (`self.units`). '
+                'Automatically applying a projection layer to residual to '
+                'match input and output. '
+            )
+            self._residual_projection = self.get_dense(
+                self.units, name='residual_projection'
+            )
+        if self._normalize_aggregate:
+            self._aggregation_norm = keras.layers.LayerNormalization(
+                name='aggregation_normalization'
+            )
+            self._aggregation_norm.build([None, self.units])
+
+        super().build(spec)
+
     @abc.abstractmethod
     def message(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
         """Compute messages.
@@ -256,206 +305,1123 @@ class GraphConv(GraphLayer):
             tensor:
                 A `GraphTensor` instance.
         """
+
+        if self._skip_connection:
+            input_node_feature = tensor.node['feature']
+            if self._project_input_node_feature:
+                input_node_feature = self._residual_projection(input_node_feature)
+
         tensor = self.message(tensor)
         tensor = self.aggregate(tensor)
-        tensor = self.update(tensor)
-        return tensor
 
-    def get_config(self) -> dict:
-        config = super().get_config()
-        config.update({
-            'units': self.units
-        })
-        return config
-
-
-@keras.saving.register_keras_serializable(package='molcraft')
-class Projection(GraphLayer):
-    """Base graph projection layer.
-    """
-    def __init__(
-        self,
-        units: int = None,
-        activation: str = None,
-        field: str = 'node',
-        **kwargs
-    ) -> None:
-        super().__init__(**kwargs)
-        self.units = units
-        self._activation = keras.activations.get(activation)
-        self.field = field
+        if self._normalize_aggregate:
+            normalized_node_feature = self._aggregation_norm(tensor.node['feature'])
+            tensor = tensor.update({'node': {'feature': normalized_node_feature}})
 
-    def build_from_spec(self, spec: tensors.GraphTensor.Spec) -> None:
-        """Builds the layer.
-        """
-        data = getattr(spec, self.field, None)
-        if data is None:
-            raise ValueError('Could not access field {self.field!r}.')
-        feature_dim = data['feature'].shape[-1]
-        if not self.units:
-            self.units = feature_dim
-        self._dense = self.get_dense(self.units)
-        self._dense.build([None, feature_dim])
+        tensor = self.update(tensor)
 
-    def propagate(self, tensor: tensors.GraphTensor):
-        """Calls the layer.
-        """
-        feature = getattr(tensor, self.field)['feature']
-        feature = self._dense(feature)
-        feature = self._activation(feature)
+        if not self._skip_connection:
+            return tensor
+
+        updated_node_feature = tensor.node['feature']
         return tensor.update(
             {
-                self.field: {
-                    'feature': feature
+                'node': {
+                    'feature': updated_node_feature + input_node_feature
                 }
             }
-        )
+        )
 
     def get_config(self) -> dict:
         config = super().get_config()
         config.update({
             'units': self.units,
-            'activation': keras.activations.serialize(self._activation),
-            'field': self.field,
+            'normalize': self._normalize_aggregate,
+            'skip_connection': self._skip_connection,
         })
         return config
 
 
 @keras.saving.register_keras_serializable(package='molcraft')
-class GraphNetwork(GraphLayer):
-
-    """Graph neural network.
+class GIConv(GraphConv):
 
-    Sequentially calls graph layers (`GraphLayer`) and concatenates its output.
-
-    Args:
-        layers (list):
-            A list of graph layers.
+    """Graph isomorphism network layer.
     """
 
-    def __init__(self, layers: list[GraphLayer], **kwargs) -> None:
-        super().__init__(**kwargs)
-        self.layers = layers
-        self._update_edge_feature = False
+    def __init__(
+        self,
+        units: int,
+        activation: keras.layers.Activation | str | None = 'relu',
+        use_bias: bool = True,
+        normalize: bool = True,
+        dropout: float = 0.0,
+        update_edge_feature: bool = True,
+        **kwargs,
+    ):
+        super().__init__(
+            units=units,
+            normalize=normalize,
+            use_bias=use_bias,
+            **kwargs
+        )
+        self._activation = keras.activations.get(activation)
+        self._dropout = dropout
+        self._update_edge_feature = update_edge_feature
 
     def build_from_spec(self, spec: tensors.GraphTensor.Spec) -> None:
         """Builds the layer.
         """
-        units = self.layers[0].units
         node_feature_dim = spec.node['feature'].shape[-1]
-        if node_feature_dim != units:
-            warn(
-                'Node feature dim does not match `units` of the first layer. '
-                'Automatically adding a node projection layer to match `units`.'
-            )
-            self._node_dense = self.get_dense(units)
-            self._update_node_feature = True
-        has_edge_feature = 'feature' in spec.edge
-        if has_edge_feature:
+
+        self.epsilon = self.add_weight(
+            name='epsilon',
+            shape=(),
+            initializer='zeros',
+            trainable=True,
+        )
+
+        self._has_edge_feature = 'feature' in spec.edge
+        if self._has_edge_feature:
             edge_feature_dim = spec.edge['feature'].shape[-1]
-            if edge_feature_dim != units:
-                warn(
-                    'Edge feature dim does not match `units` of the first layer. '
-                    'Automatically adding a edge projection layer to match `units`.'
-                )
-                self._edge_dense = self.get_dense(units)
-                self._update_edge_feature = True
 
-    def propagate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
-        """Calls the layer.
+            if not self._update_edge_feature:
+                if (edge_feature_dim != node_feature_dim):
+                    warn(
+                        'Found edge feature dim to be incompatible with node feature dim. '
+                        'Automatically adding a edge feature projection layer to match '
+                        'the dim of node features.'
+                    )
+                    self._update_edge_feature = True
+
+            if self._update_edge_feature:
+                self._edge_dense = self.get_dense(node_feature_dim)
+                self._edge_dense.build([None, edge_feature_dim])
+        else:
+            self._update_edge_feature = False
+
+        self._feedforward_intermediate_dense = self.get_dense(self.units)
+        self._feedforward_intermediate_dense.build([None, node_feature_dim])
+
+        has_overridden_update = self.__class__.update != GIConv.update
+        if not has_overridden_update:
+            self._feedforward_activation = self._activation
+            self._feedforward_dropout = keras.layers.Dropout(self._dropout)
+            self._feedforward_output_dense = self.get_dense(self.units)
+            self._feedforward_output_dense.build([None, self.units])
+
+        self.built = True
+
+    def message(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
+        """Computes messages.
         """
-        x = tensors.to_dict(tensor)
-        if self._update_node_feature:
-            x['node']['feature'] = self._node_dense(tensor.node['feature'])
+        message = tensor.gather('feature', 'source')
+        edge_feature = tensor.edge.get('feature')
         if self._update_edge_feature:
-            x['edge']['feature'] = self._edge_dense(tensor.edge['feature'])
-        outputs = [x['node']['feature']]
-        for layer in self.layers:
-            x = layer(x)
-            outputs.append(x['node']['feature'])
+            edge_feature = self._edge_dense(edge_feature)
+        if self._has_edge_feature:
+            message += edge_feature
         return tensor.update(
             {
-                'node': {
-                    'feature': keras.ops.concatenate(outputs, axis=-1)
-                }
+                'edge': {
+                    'message': message,
+                    'feature': edge_feature
+                }
             }
         )
-
-    def tape_propagate(
-        self,
-        tensor: tensors.GraphTensor,
-        tape: tf.GradientTape,
-        training: bool | None = None,
-    ) -> tuple[tensors.GraphTensor, list[tf.Tensor]]:
-        """Performs the propagation with a `GradientTape`.
-
-        Performs the same forward pass as `propagate` but with a `GradientTape`
-        watching intermediate node features.
 
-        Args:
-            tensor (tensors.GraphTensor):
-                The graph input.
+    def aggregate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
+        """Aggregates messages.
         """
-        if isinstance(tensor, tensors.GraphTensor):
-            x = tensors.to_dict(tensor)
-        else:
-            x = tensor
-        if self._update_node_feature:
-            x['node']['feature'] = self._node_dense(tensor.node['feature'])
-        if self._update_edge_feature:
-            x['edge']['feature'] = self._edge_dense(tensor.edge['feature'])
-        tape.watch(x['node']['feature'])
-        outputs = [x['node']['feature']]
-        for layer in self.layers:
-            x = layer(x, training=training)
-            tape.watch(x['node']['feature'])
-            outputs.append(x['node']['feature'])
-
-        tensor = tensor.update(
+        node_feature = tensor.aggregate('message')
+        node_feature += (1 + self.epsilon) * tensor.node['feature']
+        node_feature = self._feedforward_intermediate_dense(node_feature)
+        return tensor.update(
             {
                 'node': {
-                    'feature': keras.ops.concatenate(outputs, axis=-1)
+                    'feature': node_feature,
+                },
+                'edge': {
+                    'message': None,
                 }
             }
         )
-        return tensor, outputs
 
-    def get_config(self) -> dict:
-        config = super().get_config()
-        config.update(
+    def update(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
+        """Updates nodes.
+        """
+        node_feature = tensor.node['feature']
+        node_feature = self._feedforward_activation(node_feature)
+        node_feature = self._feedforward_dropout(node_feature)
+        node_feature = self._feedforward_output_dense(node_feature)
+        return tensor.update(
             {
-                'layers': [
-                    keras.layers.serialize(layer) for layer in self.layers
-                ]
+                'node': {
+                    'feature': node_feature,
+                }
             }
         )
-        return config
-
-    @classmethod
-    def from_config(cls, config: dict) -> 'GraphNetwork':
-        config['layers'] = [
-            keras.layers.deserialize(layer) for layer in config['layers']
-        ]
-        return super().from_config(config)
 
+    def get_config(self) -> dict:
+        config = super().get_config()
+        config.update({
+            'activation': keras.activations.serialize(self._activation),
+            'dropout': self._dropout,
+            'update_edge_feature': self._update_edge_feature
+        })
+        return config
 
-@keras.saving.register_keras_serializable(package='molcraft')
-class NodeEmbedding(GraphLayer):
 
-    """Node embedding layer.
+@keras.saving.register_keras_serializable(package='molgraphx')
+class GAConv(GraphConv):
 
-    Embeds nodes based on its initial features.
+    """Graph attention network layer.
     """
 
     def __init__(
-        self,
-        dim: int = None,
-        embed_context: bool = True,
-        allow_masking: bool = True,
+        self,
+        units: int,
+        heads: int = 8,
+        activation: keras.layers.Activation | str | None = "relu",
+        use_bias: bool = True,
+        normalize: bool = True,
+        dropout: float = 0.0,
+        update_edge_feature: bool = True,
+        attention_activation: keras.layers.Activation | str | None = "leaky_relu",
+        **kwargs,
+    ) -> None:
+        kwargs['skip_connection'] = False
+        super().__init__(
+            units=units,
+            normalize=normalize,
+            use_bias=use_bias,
+            **kwargs
+        )
+        self._heads = heads
+        if self.units % self.heads != 0:
+            raise ValueError(f"units need to be divisible by heads.")
+        self._head_units = self.units // self.heads
+        self._activation = keras.activations.get(activation)
+        self._dropout = dropout
+        self._normalize = normalize
+        self._update_edge_feature = update_edge_feature
+        self._attention_activation = keras.activations.get(attention_activation)
+
+    @property
+    def heads(self):
+        return self._heads
+
+    @property
+    def head_units(self):
+        return self._head_units
+
+    def build_from_spec(self, spec: tensors.GraphTensor.Spec) -> None:
+
+        node_feature_dim = spec.node['feature'].shape[-1]
+        attn_feature_dim = node_feature_dim + node_feature_dim
+
+        self._has_edge_feature = 'feature' in spec.edge
+        if self._has_edge_feature:
+            edge_feature_dim = spec.edge['feature'].shape[-1]
+            attn_feature_dim += edge_feature_dim
+            if self._update_edge_feature:
+                self._edge_dense = self.get_einsum_dense(
+                    'ijh,jkh->ikh', (self.head_units, self.heads)
+                )
+                self._edge_dense.build([None, self.head_units, self.heads])
+        else:
+            self._update_edge_feature = False
+
+        self._node_dense = self.get_einsum_dense(
+            'ij,jkh->ikh', (self.head_units, self.heads)
+        )
+        self._node_dense.build([None, node_feature_dim])
+
+        self._feature_dense = self.get_einsum_dense(
+            'ij,jkh->ikh', (self.head_units, self.heads)
+        )
+        self._feature_dense.build([None, attn_feature_dim])
+
+        self._attention_dense = self.get_einsum_dense(
+            'ijh,jkh->ikh', (1, self.heads)
+        )
+        self._attention_dense.build([None, self.head_units, self.heads])
+
+        self._node_self_dense = self.get_einsum_dense(
+            'ij,jkh->ikh', (self.head_units, self.heads)
+        )
+        self._node_self_dense.build([None, node_feature_dim])
+        self._dropout_layer = keras.layers.Dropout(self._dropout)
+
+        self.built = True
+
+    def message(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
+
+        attention_feature = keras.ops.concatenate(
+            [
+                tensor.gather('feature', 'source'),
+                tensor.gather('feature', 'target')
+            ],
+            axis=-1
+        )
+        if self._has_edge_feature:
+            attention_feature = keras.ops.concatenate(
+                [
+                    attention_feature,
+                    tensor.edge['feature']
+                ],
+                axis=-1
+            )
+
+        attention_feature = self._feature_dense(attention_feature)
+
+        edge_feature = tensor.edge.get('feature')
+
+        if self._update_edge_feature:
+            edge_feature = self._edge_dense(attention_feature)
+            edge_feature = keras.ops.reshape(edge_feature, (-1, self.units))
+
+        attention_feature = self._attention_activation(attention_feature)
+        attention_score = self._attention_dense(attention_feature)
+        attention_score = ops.edge_softmax(
+            score=attention_score, edge_target=tensor.edge['target']
+        )
+        node_feature = self._node_dense(tensor.node['feature'])
+        message = ops.gather(node_feature, tensor.edge['source'])
+        return tensor.update(
+            {
+                'edge': {
+                    'message': message,
+                    'weight': attention_score,
+                    'feature': edge_feature,
+                }
+            }
+        )
+
+    def aggregate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
+        node_feature = tensor.aggregate('message')
+        node_feature += self._node_self_dense(tensor.node['feature'])
+        node_feature = self._dropout_layer(node_feature)
+        node_feature = keras.ops.reshape(node_feature, (-1, self.units))
+        return tensor.update(
+            {
+                'node': {
+                    'feature': node_feature
+                },
+                'edge': {
+                    'message': None,
+                    'weight': None,
+                }
+            }
+        )
+
+    def update(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
+        node_feature = self._activation(tensor.node['feature'])
+        return tensor.update(
+            {
+                'node': {
+                    'feature': node_feature
+                }
+            }
+        )
+
+    def get_config(self) -> dict:
+        config = super().get_config()
+        config.update({
+            "heads": self._heads,
+            'activation': keras.activations.serialize(self._activation),
+            'dropout': self._dropout,
+            'update_edge_feature': self._update_edge_feature,
+            'attention_activation': keras.activations.serialize(self._attention_activation),
+        })
+        return config
+
+
+@keras.saving.register_keras_serializable(package='molcraft')
+class GTConv(GraphConv):
+
+    """Graph transformer layer.
+    """
+
+    def __init__(
+        self,
+        units: int,
+        heads: int = 8,
+        activation: keras.layers.Activation | str | None = "relu",
+        use_bias: bool = True,
+        normalize: bool = True,
+        dropout: float = 0.0,
+        attention_dropout: float = 0.0,
+        **kwargs,
+    ) -> None:
+        kwargs['skip_connection'] = False
+        super().__init__(
+            units=units,
+            normalize=normalize,
+            use_bias=use_bias,
+            **kwargs
+        )
+        self._heads = heads
+        if self.units % self.heads != 0:
+            raise ValueError(f"units need to be divisible by heads.")
+        self._head_units = self.units // self.heads
+        self._activation = keras.activations.get(activation)
+        self._dropout = dropout
+        self._attention_dropout = attention_dropout
+        self._normalize = normalize
+
+    @property
+    def heads(self):
+        return self._heads
+
+    @property
+    def head_units(self):
+        return self._head_units
+
+    def build_from_spec(self, spec):
+        """Builds the layer.
+        """
+        node_feature_dim = spec.node['feature'].shape[-1]
+        self.project_residual = node_feature_dim != self.units
+        if self.project_residual:
+            warn(
+                '`GTConv` uses residual connections, but found incompatible dim '
+                'between input (node feature dim) and output (`self.units`). '
+                'Automatically applying a projection layer to residual to '
+                'match input and output. '
+            )
+            self._residual_dense = self.get_dense(self.units)
+            self._residual_dense.build([None, node_feature_dim])
+
+        self._query_dense = self.get_einsum_dense(
+            'ij,jkh->ikh', (self.head_units, self.heads)
+        )
+        self._query_dense.build([None, node_feature_dim])
+
+        self._key_dense = self.get_einsum_dense(
+            'ij,jkh->ikh', (self.head_units, self.heads)
+        )
+        self._key_dense.build([None, node_feature_dim])
+
+        self._value_dense = self.get_einsum_dense(
+            'ij,jkh->ikh', (self.head_units, self.heads)
+        )
+        self._value_dense.build([None, node_feature_dim])
+
+        self._output_dense = self.get_dense(self.units)
+        self._output_dense.build([None, self.units])
+
+        self._softmax_dropout = keras.layers.Dropout(self._attention_dropout)
+
+        self._self_attention_dropout = keras.layers.Dropout(self._dropout)
+
+        self._add_bias = not 'bias' in spec.edge
+
+        if self._add_bias:
+            self._edge_bias = EdgeBias(biases=self.heads)
+            self._edge_bias.build_from_spec(spec)
+
+        has_overridden_update = self.__class__.update != GTConv.update
+        if not has_overridden_update:
+
+            if self._normalize:
+                self._feedforward_output_norm = keras.layers.LayerNormalization()
+                self._feedforward_output_norm.build([None, self.units])
+
+            self._feedforward_dropout = keras.layers.Dropout(self._dropout)
+
+            self._feedforward_intermediate_dense = self.get_dense(self.units)
+            self._feedforward_intermediate_dense.build([None, self.units])
+
+            self._feedforward_output_dense = self.get_dense(self.units)
+            self._feedforward_output_dense.build([None, self.units])
+
+        self.built = True
+
+    def message(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
+        """Computes messages.
+        """
+        if self._add_bias:
+            edge_bias = self._edge_bias(tensor)
+            tensor = tensor.update(
+                {
+                    'edge': {
+                        'bias': edge_bias
+                    }
+                }
+            )
+
+        node_feature = tensor.node['feature']
+
+        query = self._query_dense(node_feature)
+        key = self._key_dense(node_feature)
+        value = self._value_dense(node_feature)
+
+        query = ops.gather(query, tensor.edge['source'])
+        key = ops.gather(key, tensor.edge['target'])
+        value = ops.gather(value, tensor.edge['source'])
+
+        attention_score = keras.ops.sum(query * key, axis=1, keepdims=True)
+        attention_score /= keras.ops.sqrt(float(self.head_units))
+
+        attention_score += keras.ops.expand_dims(tensor.edge['bias'], axis=1)
+
+        attention = ops.edge_softmax(attention_score, tensor.edge['target'])
+        attention = self._softmax_dropout(attention)
+
+        return tensor.update(
+            {
+                'edge': {
+                    'message': value,
+                    'weight': attention,
+                },
+            }
+        )
+
+    def aggregate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
+        """Aggregates messages.
+        """
+        node_feature = tensor.aggregate('message')
+        node_feature = keras.ops.reshape(node_feature, (-1, self.units))
+        node_feature = self._output_dense(node_feature)
+        node_feature = self._self_attention_dropout(node_feature)
+        return tensor.update(
+            {
+                'node': {
+                    'feature': node_feature,
+                    'residual': tensor.node['feature']
+                },
+                'edge': {
+                    'message': None,
+                    'weight': None,
+                }
+            }
+        )
+
+    def update(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
+        """Updates nodes.
+        """
+        node_feature = tensor.node['feature']
+
+        residual = tensor.node['residual']
+        if self.project_residual:
+            residual = self._residual_dense(residual)
+
+        node_feature += residual
+        residual = node_feature
+
+        node_feature = self._feedforward_intermediate_dense(node_feature)
+        node_feature = self._activation(node_feature)
+        node_feature = self._feedforward_output_dense(node_feature)
+        node_feature = self._feedforward_dropout(node_feature)
+        if self._normalize:
+            node_feature = self._feedforward_output_norm(node_feature)
+
+        node_feature += residual
+
+        return tensor.update(
+            {
+                'node': {
+                    'feature': node_feature,
+                },
+            }
+        )
+
+    def get_config(self) -> dict:
+        config = super().get_config()
+        config.update({
+            "heads": self._heads,
+            'activation': keras.activations.serialize(self._activation),
+            'dropout': self._dropout,
+            'attention_dropout': self._attention_dropout,
+        })
+        return config
+
+
+@keras.saving.register_keras_serializable(package='molcraft')
+class GTConv3D(GTConv):
+
+    """Graph transformer 3D layer.
+    """
+
+    def build_from_spec(self, spec: tensors.GraphTensor.Spec) -> None:
+        super().build_from_spec(spec)
+        if self._add_bias:
+            node_feature_dim = spec.node['feature'].shape[-1]
+            kernels = self.units
+            self._gaussian_basis = GaussianDistance(kernels)
+            self._gaussian_basis.build_from_spec(spec)
+            self._centrality_dense = self.get_dense(units=node_feature_dim)
+            self._centrality_dense.build([None, kernels])
+            self._gaussian_edge_bias = self.get_dense(self.heads)
+            self._gaussian_edge_bias.build([None, kernels])
+
+    def message(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
+        """Computes messages.
+        """
+        node_feature = tensor.node['feature']
+
+        if self._add_bias:
+            gaussian = self._gaussian_basis(tensor)
+            centrality = keras.ops.segment_sum(
+                gaussian, tensor.edge['target'], tensor.num_nodes
+            )
+            node_feature += self._centrality_dense(centrality)
+
+            edge_bias = self._edge_bias(tensor) + self._gaussian_edge_bias(gaussian)
+            tensor = tensor.update({'edge': {'bias': edge_bias}})
+
+        query = self._query_dense(node_feature)
+        key = self._key_dense(node_feature)
+        value = self._value_dense(node_feature)
+
+        query = ops.gather(query, tensor.edge['source'])
+        key = ops.gather(key, tensor.edge['target'])
+        value = ops.gather(value, tensor.edge['source'])
+
+        attention_score = keras.ops.sum(query * key, axis=1, keepdims=True)
+        attention_score /= keras.ops.sqrt(float(self.head_units))
+
+        attention_score += keras.ops.expand_dims(tensor.edge['bias'], axis=1)
+
+        attention = ops.edge_softmax(attention_score, tensor.edge['target'])
+        attention = self._softmax_dropout(attention)
+
+        distance = keras.ops.subtract(
+            tensor.gather('coordinate', 'source'),
+            tensor.gather('coordinate', 'target')
+        )
+        euclidean_distance = ops.euclidean_distance(
+            tensor.gather('coordinate', 'source'),
+            tensor.gather('coordinate', 'target'),
+            axis=-1
+        )
+        distance /= euclidean_distance
+
+        attention *= keras.ops.expand_dims(distance, axis=-1)
+        attention = keras.ops.expand_dims(attention, axis=2)
+        value = keras.ops.expand_dims(value, axis=1)
+
+        return tensor.update(
+            {
+                'edge': {
+                    'message': value,
+                    'weight': attention,
+                },
+            }
+        )
+
+    def aggregate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
+        """Aggregates messages.
+        """
+        node_feature = tensor.aggregate('message')
+        node_feature = keras.ops.reshape(
+            node_feature, (tensor.num_nodes, -1, self.units)
+        )
+        node_feature = self._output_dense(node_feature)
+        node_feature = keras.ops.sum(node_feature, axis=1)
+        node_feature = self._self_attention_dropout(node_feature)
+        return tensor.update(
+            {
+                'node': {
+                    'feature': node_feature,
+                    'residual': tensor.node['feature']
+                },
+                'edge': {
+                    'message': None,
+                    'weight': None,
+                }
+            }
+        )
+
+
+@keras.saving.register_keras_serializable(package='molcraft')
+class MPConv(GraphConv):
+
+    """Message passing neural network layer.
+    """
+
+    def __init__(
+        self,
+        units: int = 128,
+        activation: keras.layers.Activation | str | None = None,
+        use_bias: bool = True,
+        normalize: bool = True,
+        dropout: float = 0.0,
+        **kwargs
+    ) -> None:
+        super().__init__(
+            units=units,
+            normalize=normalize,
+            use_bias=use_bias,
+            **kwargs
+        )
+        self._activation = keras.activations.get(activation)
+        self._dropout = dropout or 0.0
+
+    def build_from_spec(self, spec: tensors.GraphTensor.Spec) -> None:
+        node_feature_dim = spec.node['feature'].shape[-1]
+        self.message_fn = self.get_dense(self.units, activation=self._activation)
+        self.update_fn = keras.layers.GRUCell(self.units)
+        self._has_edge_feature = 'feature' in spec.edge
+        self.project_input_node_feature = node_feature_dim != self.units
+        if self.project_input_node_feature:
+            warn(
+                'Input node feature dim does not match updated node feature dim. '
+                'To make sure input node feature can be passed as `states` to the '
+                'GRU cell, it will automatically be projected prior to it.'
+            )
+            self._previous_node_dense = self.get_dense(self.units, activation=self._activation)
+        self.built = True
+
+    def message(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
+        feature = keras.ops.concatenate(
+            [
+                tensor.gather('feature', 'source'),
+                tensor.gather('feature', 'target'),
+            ],
+            axis=-1
+        )
+        if self._has_edge_feature:
+            feature = keras.ops.concatenate(
+                [
+                    feature,
+                    tensor.edge['feature']
+                ],
+                axis=-1
+            )
+        message = self.message_fn(feature)
+        return tensor.update(
+            {
+                'edge': {
+                    'message': message,
+                }
+            }
+        )
+
+    def aggregate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
+        aggregate = tensor.aggregate('message')
+        previous = tensor.node['feature']
+        if self.project_input_node_feature:
+            previous = self._previous_node_dense(previous)
+        return tensor.update(
+            {
+                'node': {
+                    'feature': aggregate,
+                    'previous_feature': previous,
+                }
+            }
+        )
+
+    def update(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
+        updated_node_feature, _ = self.update_fn(
+            inputs=tensor.node['feature'],
+            states=tensor.node['previous_feature']
+        )
+        return tensor.update(
+            {
+                'node': {
+                    'feature': updated_node_feature,
+                    'previous_feature': None,
+                }
+            }
+        )
+
+    def get_config(self) -> dict:
+        config = super().get_config()
+        config.update({
+            'activation': keras.activations.serialize(self._activation),
+            'dropout': self._dropout,
+        })
+        return config
+
+
+@keras.saving.register_keras_serializable(package='molcraft')
+class MPConv3D(MPConv):
+
+    """3D Message passing neural network layer.
+    """
+
+    def message(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
+        euclidean_distance = ops.euclidean_distance(
+            tensor.gather('coordinate', 'target'),
+            tensor.gather('coordinate', 'source'),
+            axis=-1
+        )
+        feature = keras.ops.concatenate(
+            [
+                tensor.gather('feature', 'source'),
+                tensor.gather('feature', 'target'),
+                euclidean_distance,
+            ],
+            axis=-1
+        )
+        if self._has_edge_feature:
+            feature = keras.ops.concatenate(
+                [
+                    feature,
+                    tensor.edge['feature']
+                ],
+                axis=-1
+            )
+        message = self.message_fn(feature)
+        return tensor.update(
+            {
+                'edge': {
+                    'message': message,
+                }
+            }
+        )
+
+
+@keras.saving.register_keras_serializable(package='molcraft')
+class EGConv3D(GraphConv):
+
+    """Equivariant graph neural network layer.
+    """
+
+    def __init__(
+        self,
+        units: int = 128,
+        activation: keras.layers.Activation | str | None = None,
+        use_bias: bool = True,
+        normalize: bool = True,
+        dropout: float = 0.0,
+        **kwargs
+    ) -> None:
+        super().__init__(
+            units=units,
+            normalize=normalize,
+            use_bias=use_bias,
+            **kwargs
+        )
+        self._activation = keras.activations.get(activation)
+        self._dropout = dropout or 0.0
+
+    def build_from_spec(self, spec: tensors.GraphTensor.Spec) -> None:
+        if 'coordinate' not in spec.node:
+            raise ValueError(
+                'Could not find `coordinate`s in node, '
+                'which is required for Conv3D layers.'
+            )
+        node_feature_dim = spec.node['feature'].shape[-1]
+        feature_dim = node_feature_dim + node_feature_dim + 1
+        if 'feature' in spec.edge:
+            self._has_edge_feature = True
+            edge_feature_dim = spec.edge['feature'].shape[-1]
+            feature_dim += edge_feature_dim
+        else:
+            self._has_edge_feature = False
+
+        self.message_fn = self.get_dense(self.units, activation=self._activation)
+        self.message_fn.build([None, feature_dim])
+        self.dense_position = self.get_dense(1)
+        self.dense_position.build([None, self.units])
+
+        has_overridden_update = self.__class__.update != EGConv3D.update
+        if not has_overridden_update:
+            self.update_fn = self.get_dense(self.units, activation=self._activation)
+            self.update_fn.build([None, node_feature_dim + self.units])
+            self._dropout_layer = keras.layers.Dropout(self._dropout)
+        self.built = True
+
+    def message(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
+        """Computes messages.
+        """
+        relative_node_coordinate = keras.ops.subtract(
+            tensor.gather('coordinate', 'target'),
+            tensor.gather('coordinate', 'source')
+        )
+        euclidean_distance = keras.ops.sum(
+            keras.ops.square(
+                relative_node_coordinate
+            ),
+            axis=-1,
+            keepdims=True
+        )
+        feature = keras.ops.concatenate(
+            [
+                tensor.gather('feature', 'target'),
+                tensor.gather('feature', 'source'),
+                euclidean_distance,
+            ],
+            axis=-1
+        )
+        if self._has_edge_feature:
+            feature = keras.ops.concatenate(
+                [
+                    feature,
+                    tensor.edge['feature']
+                ],
+                axis=-1
+            )
+        message = self.message_fn(feature)
+        relative_node_coordinate = keras.ops.multiply(
+            relative_node_coordinate,
+            self.dense_position(message)
+        )
+        return tensor.update(
+            {
+                'edge': {
+                    'message': message,
+                    'relative_node_coordinate': relative_node_coordinate
+                }
+            }
+        )
+
+    def aggregate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
+        """Aggregates messages.
+        """
+        coefficient = keras.ops.bincount(
+            tensor.edge['source'],
+            minlength=tensor.num_nodes
+        )
+        coefficient = keras.ops.cast(
+            coefficient, tensor.node['coordinate'].dtype
+        )
+        coefficient = keras.ops.expand_dims(
+            keras.ops.divide_no_nan(1, coefficient), axis=1
+        )
+
+        updated_coordinate = tensor.aggregate('relative_node_coordinate') * coefficient
+        updated_coordinate += tensor.node['coordinate']
+
+        aggregate = tensor.aggregate('message')
+        return tensor.update(
+            {
+                'node': {
+                    'feature': aggregate,
+                    'coordinate': updated_coordinate,
+                    'previous_feature': tensor.node['feature'],
+                },
+                'edge': {
+                    'message': None,
+                    'relative_node_coordinate': None
+                }
+            }
+        )
+
+    def update(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
+        """Updates nodes.
+        """
+        updated_node_feature = self.update_fn(
+            keras.ops.concatenate(
+                [
+                    tensor.node['feature'],
+                    tensor.node['previous_feature']
+                ],
+                axis=-1
+            )
+        )
+        updated_node_feature = self._dropout_layer(updated_node_feature)
+        return tensor.update(
+            {
+                'node': {
+                    'feature': updated_node_feature,
+                    'previous_feature': None,
+                },
+            }
+        )
+
+    def get_config(self) -> dict:
+        config = super().get_config()
+        config.update({
+            'activation': keras.activations.serialize(self._activation),
+            'dropout': self._dropout,
+        })
+        return config
+
+
+@keras.saving.register_keras_serializable(package='molcraft')
+class Projection(GraphLayer):
+    """Base graph projection layer.
+    """
+    def __init__(
+        self,
+        units: int = None,
+        activation: str = None,
+        field: str = 'node',
+        **kwargs
+    ) -> None:
+        super().__init__(**kwargs)
+        self.units = units
+        self._activation = keras.activations.get(activation)
+        self.field = field
+
+    def build_from_spec(self, spec: tensors.GraphTensor.Spec) -> None:
+        """Builds the layer.
+        """
+        data = getattr(spec, self.field, None)
+        if data is None:
+            raise ValueError('Could not access field {self.field!r}.')
+        feature_dim = data['feature'].shape[-1]
+        if not self.units:
+            self.units = feature_dim
+        self._dense = self.get_dense(self.units)
+        self._dense.build([None, feature_dim])
+        self.built = True
+
+    def propagate(self, tensor: tensors.GraphTensor):
+        """Calls the layer.
+        """
+        feature = getattr(tensor, self.field)['feature']
+        feature = self._dense(feature)
+        feature = self._activation(feature)
+        return tensor.update(
+            {
+                self.field: {
+                    'feature': feature
+                }
+            }
+        )
+
+    def get_config(self) -> dict:
+        config = super().get_config()
+        config.update({
+            'units': self.units,
+            'activation': keras.activations.serialize(self._activation),
+            'field': self.field,
+        })
+        return config
+
+
+@keras.saving.register_keras_serializable(package='molcraft')
+class GraphNetwork(GraphLayer):
+
+    """Graph neural network.
+
+    Sequentially calls graph layers (`GraphLayer`) and concatenates its output.
+
+    Args:
+        layers (list):
+            A list of graph layers.
+    """
+
+    def __init__(self, layers: list[GraphLayer], **kwargs) -> None:
+        super().__init__(**kwargs)
+        self.layers = layers
+        self._update_edge_feature = False
+
+    def build_from_spec(self, spec: tensors.GraphTensor.Spec) -> None:
+        """Builds the layer.
+        """
+        units = self.layers[0].units
+        node_feature_dim = spec.node['feature'].shape[-1]
+        if node_feature_dim != units:
+            warn(
+                'Node feature dim does not match `units` of the first layer. '
+                'Automatically adding a node projection layer to match `units`.'
+            )
+            self._node_dense = self.get_dense(units)
+            self._update_node_feature = True
+        has_edge_feature = 'feature' in spec.edge
+        if has_edge_feature:
+            edge_feature_dim = spec.edge['feature'].shape[-1]
+            if edge_feature_dim != units:
+                warn(
+                    'Edge feature dim does not match `units` of the first layer. '
+                    'Automatically adding a edge projection layer to match `units`.'
+                )
+                self._edge_dense = self.get_dense(units)
+                self._update_edge_feature = True
+        self.built = True
+
+    def propagate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
+        """Calls the layer.
+        """
+        x = tensors.to_dict(tensor)
+        if self._update_node_feature:
+            x['node']['feature'] = self._node_dense(tensor.node['feature'])
+        if self._update_edge_feature:
+            x['edge']['feature'] = self._edge_dense(tensor.edge['feature'])
+        outputs = [x['node']['feature']]
+        for layer in self.layers:
+            x = layer(x)
+            outputs.append(x['node']['feature'])
+        return tensor.update(
+            {
+                'node': {
+                    'feature': keras.ops.concatenate(outputs, axis=-1)
+                }
+            }
+        )
+
+    def tape_propagate(
+        self,
+        tensor: tensors.GraphTensor,
+        tape: tf.GradientTape,
+        training: bool | None = None,
+    ) -> tuple[tensors.GraphTensor, list[tf.Tensor]]:
+        """Performs the propagation with a `GradientTape`.
+
+        Performs the same forward pass as `propagate` but with a `GradientTape`
+        watching intermediate node features.
+
+        Args:
+            tensor (tensors.GraphTensor):
+                The graph input.
+        """
+        if isinstance(tensor, tensors.GraphTensor):
+            x = tensors.to_dict(tensor)
+        else:
+            x = tensor
+        if self._update_node_feature:
+            x['node']['feature'] = self._node_dense(tensor.node['feature'])
+        if self._update_edge_feature:
+            x['edge']['feature'] = self._edge_dense(tensor.edge['feature'])
+        tape.watch(x['node']['feature'])
+        outputs = [x['node']['feature']]
+        for layer in self.layers:
+            x = layer(x, training=training)
+            tape.watch(x['node']['feature'])
+            outputs.append(x['node']['feature'])
+
+        tensor = tensor.update(
+            {
+                'node': {
+                    'feature': keras.ops.concatenate(outputs, axis=-1)
+                }
+            }
+        )
+        return tensor, outputs
+
+    def get_config(self) -> dict:
+        config = super().get_config()
+        config.update(
+            {
+                'layers': [
+                    keras.layers.serialize(layer) for layer in self.layers
+                ]
+            }
+        )
+        return config
+
+    @classmethod
+    def from_config(cls, config: dict) -> 'GraphNetwork':
+        config['layers'] = [
+            keras.layers.deserialize(layer) for layer in config['layers']
+        ]
+        return super().from_config(config)
+
+
+@keras.saving.register_keras_serializable(package='molcraft')
+class NodeEmbedding(GraphLayer):
+
+    """Node embedding layer.
+
+    Embeds nodes based on its initial features.
+    """
+
+    def __init__(
+        self,
+        dim: int = None,
+        normalize: bool = True,
+        embed_context: bool = True,
+        allow_masking: bool = True,
         **kwargs
     ) -> None:
         super().__init__(**kwargs)
         self.dim = dim
+        self._normalize = normalize
         self._embed_context = embed_context
         self._masking_rate = None
         self._allow_masking = allow_masking
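Editor's note: a sketch of stacking the conv layers added in this hunk inside the relocated `GraphNetwork` (hypothetical usage; not part of the diff):

    from molcraft import layers

    # GraphNetwork calls each layer sequentially and concatenates the
    # node-feature output of every layer; it projects node (and edge)
    # features up front when their dims differ from the first layer's units.
    gnn = layers.GraphNetwork([
        layers.GIConv(units=128),
        layers.GAConv(units=128, heads=8),
        layers.GTConv(units=128, heads=8, dropout=0.1),
    ])
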
@@ -482,6 +1448,12 @@ class NodeEmbedding(GraphLayer):
             context_feature_dim = spec.context['feature'].shape[-1]
             self._context_dense = self.get_dense(self.dim)
             self._context_dense.build([None, context_feature_dim])
+
+        if self._normalize:
+            self._norm = keras.layers.LayerNormalization()
+            self._norm.build([None, self.dim])
+
+        self.built = True
 
     def propagate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
         """Calls the layer.
@@ -515,6 +1487,9 @@ class NodeEmbedding(GraphLayer):
             # Slience warning of 'no gradients for variables'
             feature = feature + (self._mask_feature * 0.0)
 
+        if self._normalize:
+            feature = self._norm(feature)
+
         return tensor.update({'node': {'feature': feature}})
 
     @property
@@ -534,6 +1509,8 @@ class NodeEmbedding(GraphLayer):
         config = super().get_config()
         config.update({
             'dim': self.dim,
+            'normalize': self._normalize,
+            'embed_context': self._embed_context,
             'allow_masking': self._allow_masking
         })
         return config
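Editor's note: a sketch of the new `NodeEmbedding` options recorded in this config (hypothetical usage; not part of the diff):

    embedding = layers.NodeEmbedding(dim=128, normalize=True, allow_masking=True)
    embedding.masking_rate = 0.15  # requires allow_masking=True
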
@@ -544,503 +1521,210 @@ class EdgeEmbedding(GraphLayer):
544
1521
 
545
1522
  """Edge embedding layer.
546
1523
 
547
- Embeds edges based on its initial features.
548
- """
549
-
550
- def __init__(
551
- self,
552
- dim: int = None,
553
- allow_masking: bool = True,
554
- **kwargs
555
- ) -> None:
556
- super().__init__(**kwargs)
557
- self.dim = dim
558
- self._masking_rate = None
559
- self._allow_masking = allow_masking
560
-
561
- def build_from_spec(self, spec: tensors.GraphTensor.Spec) -> None:
562
- """Builds the layer.
563
- """
564
- feature_dim = spec.edge['feature'].shape[-1]
565
- if not self.dim:
566
- self.dim = feature_dim
567
- self._edge_dense = self.get_dense(self.dim)
568
- self._edge_dense.build([None, feature_dim])
569
-
570
- self._has_super = 'super' in spec.edge
571
- if self._has_super:
572
- self._super_feature = self.get_weight(shape=[self.dim], name='super_edge_feature')
573
- if self._allow_masking:
574
- self._mask_feature = self.get_weight(shape=[self.dim], name='mask_edge_feature')
575
-
576
- def propagate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
577
- """Calls the layer.
578
- """
579
- feature = self._edge_dense(tensor.edge['feature'])
580
-
581
- if self._has_super:
582
- super_feature = self._super_feature
583
- super_mask = keras.ops.expand_dims(tensor.edge['super'], 1)
584
- feature = keras.ops.where(super_mask, super_feature, feature)
585
-
586
- if (
587
- self._allow_masking and
588
- self._masking_rate is not None and
589
- self._masking_rate > 0
590
- ):
591
- random = keras.random.uniform(shape=[tensor.num_edges])
592
- mask = random <= self._masking_rate
593
- if self._has_super:
594
- mask = keras.ops.logical_and(
595
- mask, keras.ops.logical_not(tensor.edge['super'])
596
- )
597
- mask = keras.ops.expand_dims(mask, -1)
598
- feature = keras.ops.where(mask, self._mask_feature, feature)
599
- elif self._allow_masking:
600
- # Slience warning of 'no gradients for variables'
601
- feature = feature + (self._mask_feature * 0.0)
602
-
603
- return tensor.update({'edge': {'feature': feature}})
604
-
605
- @property
606
- def masking_rate(self):
607
- return self._masking_rate
608
-
609
- @masking_rate.setter
610
- def masking_rate(self, rate: float):
611
- if not self._allow_masking and rate is not None:
612
- raise ValueError(
613
- f'Cannot set `masking_rate` for layer {self} '
614
- 'as `allow_masking` was set to `False`.'
615
- )
616
- self._masking_rate = float(rate)
617
-
618
- def get_config(self) -> dict:
619
- config = super().get_config()
620
- config.update({
621
- 'dim': self.dim,
622
- 'allow_masking': self._allow_masking
623
- })
624
- return config
625
-
626
-
627
- @keras.saving.register_keras_serializable(package='molcraft')
628
- class ContextProjection(Projection):
629
- """Context projection layer.
630
- """
631
- def __init__(self, units: int = None, activation: str = None, **kwargs):
632
- super().__init__(units=units, activation=activation, field='context', **kwargs)
633
-
634
-
635
- @keras.saving.register_keras_serializable(package='molcraft')
636
- class NodeProjection(Projection):
637
- """Node projection layer.
638
- """
639
- def __init__(self, units: int = None, activation: str = None, **kwargs):
640
- super().__init__(units=units, activation=activation, field='node', **kwargs)
641
-
642
-
643
- @keras.saving.register_keras_serializable(package='molcraft')
644
- class EdgeProjection(Projection):
645
- """Edge projection layer.
646
- """
647
- def __init__(self, units: int = None, activation: str = None, **kwargs):
648
- super().__init__(units=units, activation=activation, field='edge', **kwargs)
649
-
650
-
651
- @keras.saving.register_keras_serializable(package='molcraft')
652
- class GINConv(GraphConv):
653
-
654
- """Graph isomorphism network layer.
655
- """
656
-
657
- def __init__(
658
- self,
659
- units: int,
660
- activation: keras.layers.Activation | str | None = 'relu',
661
- dropout: float = 0.0,
662
- normalize: bool = True,
663
- update_edge_feature: bool = True,
664
- **kwargs,
665
- ):
666
- super().__init__(units=units, **kwargs)
667
- self._activation = keras.activations.get(activation)
668
- self._normalize = normalize
669
- self._dropout = dropout
670
- self._update_edge_feature = update_edge_feature
671
-
672
- def build_from_spec(self, spec: tensors.GraphTensor.Spec) -> None:
673
- """Builds the layer.
674
- """
675
- node_feature_dim = spec.node['feature'].shape[-1]
676
-
677
- self.epsilon = self.add_weight(
678
- name='epsilon',
679
- shape=(),
680
- initializer='zeros',
681
- trainable=True,
682
- )
683
-
684
- if 'feature' in spec.edge:
685
- edge_feature_dim = spec.edge['feature'].shape[-1]
686
-
687
- if not self._update_edge_feature:
688
- if (edge_feature_dim != node_feature_dim):
689
- warn(
690
- 'Found edge feature dim to be incompatible with node feature dim. '
691
- 'Automatically adding a edge feature projection layer to match '
692
- 'the dim of node features.'
693
- )
694
- self._update_edge_feature = True
695
-
696
- if self._update_edge_feature:
697
- self._edge_dense = self.get_dense(node_feature_dim)
698
- self._edge_dense.build([None, edge_feature_dim])
699
- else:
700
- self._update_edge_feature = False
701
-
702
- has_overridden_update = self.__class__.update != GINConv.update
703
- if not has_overridden_update:
704
- # Use default feedforward network
705
- self._feedforward_intermediate_dense = self.get_dense(self.units)
706
- self._feedforward_intermediate_dense.build([None, node_feature_dim])
707
-
708
- if self._normalize:
709
- self._feedforward_intermediate_norm = keras.layers.BatchNormalization()
710
- self._feedforward_intermediate_norm.build([None, self.units])
711
-
712
- self._feedforward_dropout = keras.layers.Dropout(self._dropout)
713
- self._feedforward_activation = self._activation
714
-
715
- self._feedforward_output_dense = self.get_dense(self.units)
716
- self._feedforward_output_dense.build([None, self.units])
717
-
718
- def message(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
719
- """Compute messages.
720
- """
721
- message = tensor.gather('feature', 'source')
722
- edge_feature = tensor.edge.get('feature')
723
- if self._update_edge_feature:
724
- edge_feature = self._edge_dense(edge_feature)
725
- if edge_feature is not None:
726
- message += edge_feature
727
- return tensor.update(
728
- {
729
- 'edge': {
730
- 'message': message,
731
- 'feature': edge_feature
732
- }
733
- }
734
- )
735
-
736
- def aggregate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
737
- """Aggregates messages.
738
- """
739
- node_feature = tensor.aggregate('message')
740
- node_feature += (1 + self.epsilon) * tensor.node['feature']
741
- return tensor.update(
742
- {
743
- 'node': {
744
- 'feature': node_feature,
745
- },
746
- 'edge': {
747
- 'message': None,
748
- }
749
- }
750
- )
751
-
752
- def update(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
753
- """Updates nodes.
754
- """
755
- node_feature = tensor.node['feature']
756
- node_feature = self._feedforward_intermediate_dense(node_feature)
757
- node_feature = self._feedforward_activation(node_feature)
758
- if self._normalize:
759
- node_feature = self._feedforward_intermediate_norm(node_feature)
760
- node_feature = self._feedforward_dropout(node_feature)
761
- node_feature = self._feedforward_output_dense(node_feature)
762
- return tensor.update(
763
- {
764
- 'node': {
765
- 'feature': node_feature,
766
- }
767
- }
768
- )
769
-
770
- def get_config(self) -> dict:
771
- config = super().get_config()
772
- config.update({
773
- 'activation': keras.activations.serialize(self._activation),
774
- 'dropout': self._dropout,
775
- 'normalize': self._normalize,
776
- })
777
- return config
778
-
779
-
780
- @keras.saving.register_keras_serializable(package='molcraft')
781
- class GTConv(GraphConv):
782
-
783
- """Graph transformer layer.
1524
+ Embeds edges based on its initial features.
784
1525
  """
785
1526
 
786
1527
  def __init__(
787
- self,
788
- units: int,
789
- heads: int = 8,
790
- activation: keras.layers.Activation | str | None = "relu",
791
- dropout: float = 0.0,
792
- attention_dropout: float = 0.0,
1528
+ self,
1529
+ dim: int = None,
793
1530
  normalize: bool = True,
794
- normalize_first: bool = True,
795
- **kwargs,
1531
+ allow_masking: bool = True,
1532
+ **kwargs
796
1533
  ) -> None:
797
- super().__init__(units=units, **kwargs)
798
- self._heads = heads
799
- if self.units % self.heads != 0:
800
- raise ValueError(f"units need to be divisible by heads.")
801
- self._head_units = self.units // self.heads
802
- self._activation = keras.activations.get(activation)
803
- self._dropout = dropout
804
- self._attention_dropout = attention_dropout
1534
+ super().__init__(**kwargs)
1535
+ self.dim = dim
805
1536
  self._normalize = normalize
806
- self._normalize_first = normalize_first
1537
+ self._masking_rate = None
1538
+ self._allow_masking = allow_masking
807
1539
 
808
- @property
809
- def heads(self):
810
- return self._heads
811
-
812
- @property
813
- def head_units(self):
814
- return self._head_units
815
-
816
- def build_from_spec(self, spec):
1540
+ def build_from_spec(self, spec: tensors.GraphTensor.Spec) -> None:
817
1541
  """Builds the layer.
818
1542
  """
819
- node_feature_dim = spec.node['feature'].shape[-1]
820
- incompatible_dim = node_feature_dim != self.units
821
- if incompatible_dim:
822
- warnings.warn(
823
- message=(
824
- '`GTConv` uses residual connections, but input node feature dim '
825
- 'is incompatible with intermediate dim (`units`). '
826
- 'Automatically projecting first residual to match its dim with intermediate dim.'
827
- ),
828
- category=UserWarning,
829
- stacklevel=1
830
- )
831
- self._residual_dense = self.get_dense(self.units)
832
- self._residual_dense.build([None, node_feature_dim])
833
- self._project_residual = True
834
- else:
835
- self._project_residual = False
836
-
837
- self._query_dense = self.get_einsum_dense(
838
- 'ij,jkh->ikh', (self.head_units, self.heads)
839
- )
840
- self._query_dense.build([None, node_feature_dim])
841
-
842
- self._key_dense = self.get_einsum_dense(
843
- 'ij,jkh->ikh', (self.head_units, self.heads)
844
- )
845
- self._key_dense.build([None, node_feature_dim])
846
-
847
- self._value_dense = self.get_einsum_dense(
848
- 'ij,jkh->ikh', (self.head_units, self.heads)
849
- )
850
- self._value_dense.build([None, node_feature_dim])
851
-
852
- self._output_dense = self.get_dense(self.units)
853
- self._output_dense.build([None, self.units])
1543
+ feature_dim = spec.edge['feature'].shape[-1]
1544
+ if not self.dim:
1545
+ self.dim = feature_dim
1546
+ self._edge_dense = self.get_dense(self.dim)
1547
+ self._edge_dense.build([None, feature_dim])
854
1548
 
855
- self._softmax_dropout = keras.layers.Dropout(self._attention_dropout)
1549
+ self._has_super = 'super' in spec.edge
1550
+ if self._has_super:
1551
+ self._super_feature = self.get_weight(shape=[self.dim], name='super_edge_feature')
1552
+ if self._allow_masking:
1553
+ self._mask_feature = self.get_weight(shape=[self.dim], name='mask_edge_feature')
1554
+ if self._normalize:
1555
+ self._norm = keras.layers.LayerNormalization()
1556
+ self._norm.build([None, self.dim])
856
1557
 
857
- self._self_attention_norm = keras.layers.LayerNormalization()
858
- if self._normalize_first:
859
- self._self_attention_norm.build([None, node_feature_dim])
860
- else:
861
- self._self_attention_norm.build([None, self.units])
1558
+ self.built = True
862
1559
 
863
-        self._self_attention_dropout = keras.layers.Dropout(self._dropout)
+    def propagate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
+        """Calls the layer.
+        """
+        feature = self._edge_dense(tensor.edge['feature'])
 
-        has_overriden_edge_bias = (
-            self.__class__.add_edge_bias != GTConv.add_edge_bias
-        )
-        if not has_overriden_edge_bias:
-            self._has_edge_length = 'length' in spec.edge
-            if self._has_edge_length and 'bias' not in spec.edge:
-                edge_length_dim = spec.edge['length'].shape[-1]
-                self._spatial_encoding_dense = self.get_einsum_dense(
-                    'ij,jkh->ikh', (1, self.heads), kernel_initializer='zeros'
-                )
-                self._spatial_encoding_dense.build([None, edge_length_dim])
+        if self._has_super:
+            super_feature = self._super_feature
+            super_mask = keras.ops.expand_dims(tensor.edge['super'], 1)
+            feature = keras.ops.where(super_mask, super_feature, feature)
 
-            self._has_edge_feature = 'feature' in spec.edge
-            if self._has_edge_feature and 'bias' not in spec.edge:
-                edge_feature_dim = spec.edge['feature'].shape[-1]
-                self._edge_feature_dense = self.get_einsum_dense(
-                    'ij,jkh->ikh', (1, self.heads),
+        if (
+            self._allow_masking and
+            self._masking_rate is not None and
+            self._masking_rate > 0
+        ):
+            random = keras.random.uniform(shape=[tensor.num_edges])
+            mask = random <= self._masking_rate
+            if self._has_super:
+                mask = keras.ops.logical_and(
+                    mask, keras.ops.logical_not(tensor.edge['super'])
                 )
-                self._edge_feature_dense.build([None, edge_feature_dim])
-
-        has_overridden_update = self.__class__.update != GTConv.update
-        if not has_overridden_update:
-
-            self._feedforward_norm = keras.layers.LayerNormalization()
-            self._feedforward_norm.build([None, self.units])
-
-            self._feedforward_dropout = keras.layers.Dropout(self._dropout)
+            mask = keras.ops.expand_dims(mask, -1)
+            feature = keras.ops.where(mask, self._mask_feature, feature)
+        elif self._allow_masking:
+            # Silence warning of 'no gradients for variables'
+            feature = feature + (self._mask_feature * 0.0)
 
-            self._feedforward_intermediate_dense = self.get_dense(self.units)
-            self._feedforward_intermediate_dense.build([None, self.units])
+        if self._normalize:
+            feature = self._norm(feature)
 
-            self._feedforward_output_dense = self.get_dense(self.units)
-            self._feedforward_output_dense.build([None, self.units])
+        return tensor.update({'edge': {'feature': feature}})
 
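The `elif self._allow_masking:` branch above relies on a zero-multiplication idiom: when no edges are masked, adding `self._mask_feature * 0.0` keeps the mask weight connected to the output, so the backend does not warn that the variable received no gradients. A minimal standalone sketch of the idiom (illustrative only, not molcraft API):

import keras

x = keras.random.normal((4, 8))
w = keras.Variable(keras.ops.zeros((8,)), name='mask_feature')

# `w * 0.0` leaves `y` numerically unchanged but keeps `w` in the
# computation graph, so it receives a (zero) gradient instead of None.
y = x + w * 0.0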
-    def add_node_bias(self, tensor: tensors.GraphTensor) -> tf.Tensor:
-        return tensor
+    @property
+    def masking_rate(self):
+        return self._masking_rate
 
-    def add_edge_bias(self, tensor: tensors.GraphTensor) -> tf.Tensor:
-        if 'bias' in tensor.edge:
-            return tensor
-        elif not self._has_edge_feature and not self._has_edge_length:
-            return tensor
-
-        if self._has_edge_feature and not self._has_edge_length:
-            edge_bias = self._edge_feature_dense(tensor.edge['feature'])
-        elif not self._has_edge_feature and self._has_edge_length:
-            edge_bias = self._spatial_encoding_dense(tensor.edge['length'])
-        else:
-            edge_bias = (
-                self._edge_feature_dense(tensor.edge['feature']) +
-                self._spatial_encoding_dense(tensor.edge['length'])
+    @masking_rate.setter
+    def masking_rate(self, rate: float):
+        if not self._allow_masking and rate is not None:
+            raise ValueError(
+                f'Cannot set `masking_rate` for layer {self} '
+                'as `allow_masking` was set to `False`.'
             )
-
-        return tensor.update(
-            {
-                'edge': {
-                    'bias': edge_bias
-                }
-            }
-        )
-
-    def message(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
-        """Compute messages.
-        """
-        tensor = self.add_edge_bias(tensor)
-        tensor = self.add_node_bias(tensor)
+        self._masking_rate = float(rate) if rate is not None else None
 
-        node_feature = tensor.node['feature']
+    def get_config(self) -> dict:
+        config = super().get_config()
+        config.update({
+            'dim': self.dim,
+            'normalize': self._normalize,
+            'allow_masking': self._allow_masking
+        })
+        return config
+
 
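The class name is cut off by this hunk; judging by the `mask_edge_feature` and `super_edge_feature` weights and the removed `EdgeEmbed = EdgeEmbedding` alias at the bottom of this file, it appears to be the edge-embedding layer. A hedged usage sketch, assuming the hunk belongs to `molcraft.layers.EdgeEmbedding` and that `graph` is a `GraphTensor` with an edge `'feature'` field:

# Hypothetical usage; layer name and call pattern are assumptions.
embedding = layers.EdgeEmbedding(dim=128, normalize=True, allow_masking=True)

# During masked pretraining, randomly replace ~15% of edge features
# with the learned mask embedding; set back to None for fine-tuning.
embedding.masking_rate = 0.15
graph = embedding(graph)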
-        if 'bias' in tensor.node:
-            node_feature += tensor.node['bias']
-
-        if self._normalize_first:
-            node_feature = self._self_attention_norm(node_feature)
-
-        query = self._query_dense(node_feature)
-        key = self._key_dense(node_feature)
-        value = self._value_dense(node_feature)
+@keras.saving.register_keras_serializable(package='molcraft')
+class ContextProjection(Projection):
+    """Context projection layer.
+    """
+    def __init__(self, units: int = None, activation: str = None, **kwargs):
+        super().__init__(units=units, activation=activation, field='context', **kwargs)
 
-        query = ops.gather(query, tensor.edge['source'])
-        key = ops.gather(key, tensor.edge['target'])
-        value = ops.gather(value, tensor.edge['source'])
 
-        attention_score = keras.ops.sum(query * key, axis=1, keepdims=True)
-        attention_score /= keras.ops.sqrt(float(self.units))
+@keras.saving.register_keras_serializable(package='molcraft')
+class NodeProjection(Projection):
+    """Node projection layer.
+    """
+    def __init__(self, units: int = None, activation: str = None, **kwargs):
+        super().__init__(units=units, activation=activation, field='node', **kwargs)
 
-        if 'bias' in tensor.edge:
-            attention_score += tensor.edge['bias']
-
-        attention = ops.edge_softmax(attention_score, tensor.edge['target'])
-        attention = self._softmax_dropout(attention)
 
-        return tensor.update(
-            {
-                'edge': {
-                    'message': value,
-                    'weight': attention,
-                },
-            }
-        )
+@keras.saving.register_keras_serializable(package='molcraft')
+class EdgeProjection(Projection):
+    """Edge projection layer.
+    """
+    def __init__(self, units: int = None, activation: str = None, **kwargs):
+        super().__init__(units=units, activation=activation, field='edge', **kwargs)
 
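The three new classes are thin wrappers that only pin the `field` argument of the shared `Projection` base; units and activation pass straight through. A usage sketch (assuming projections are called on a `GraphTensor` like other layers in this module, and that the argument values shown are illustrative):

# Each projection applies a dense transform to one field of the graph:
# 'context', 'node', or 'edge'. Only the pinned field differs.
node_proj = layers.NodeProjection(units=128, activation='relu')
edge_proj = layers.EdgeProjection(units=128, activation='relu')

graph = node_proj(graph)  # updates graph.node['feature']
graph = edge_proj(graph)  # updates graph.edge['feature']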
-    def aggregate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
-        """Aggregates messages.
-        """
-        node_feature = tensor.aggregate('message')
 
-        node_feature = keras.ops.reshape(node_feature, (-1, self.units))
-        node_feature = self._output_dense(node_feature)
-        node_feature = self._self_attention_dropout(node_feature)
+@keras.saving.register_keras_serializable(package='molcraft')
+class EdgeBias(GraphLayer):
 
-        residual = tensor.node['feature']
-        if self._project_residual:
-            residual = self._residual_dense(residual)
-        node_feature += residual
+    def __init__(self, biases: int, **kwargs):
+        super().__init__(**kwargs)
+        self.biases = biases
 
-        if not self._normalize_first:
-            node_feature = self._self_attention_norm(node_feature)
+    def build_from_spec(self, spec: tensors.GraphTensor.Spec) -> None:
+        self._has_edge_length = 'length' in spec.edge
+        self._has_edge_feature = 'feature' in spec.edge
+        if self._has_edge_feature:
+            self._edge_feature_dense = self.get_dense(self.biases)
+            self._edge_feature_dense.build([None, spec.edge['feature'].shape[-1]])
+        if self._has_edge_length:
+            self._edge_length_dense = self.get_dense(
+                self.biases, kernel_initializer='zeros'
+            )
+            self._edge_length_dense.build([None, spec.edge['length'].shape[-1]])
+        self.built = True
 
-        return tensor.update(
-            {
-                'node': {
-                    'feature': node_feature,
-                },
-                'edge': {
-                    'message': None,
-                    'weight': None,
-                }
-            }
+    def propagate(self, tensor: tensors.GraphTensor) -> tf.Tensor:
+        bias = keras.ops.zeros(
+            shape=(tensor.num_edges, self.biases),
+            dtype=tensor.node['feature'].dtype
         )
-
-
-    def update(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
-        """Updates nodes.
-        """
-        node_feature = tensor.node['feature']
-
-        if self._normalize_first:
-            node_feature = self._feedforward_norm(node_feature)
+        if self._has_edge_feature:
+            bias += self._edge_feature_dense(tensor.edge['feature'])
+        if self._has_edge_length:
+            bias += self._edge_length_dense(tensor.edge['length'])
+        return bias
 
-        node_feature = self._feedforward_intermediate_dense(node_feature)
-        node_feature = self._activation(node_feature)
-        node_feature = self._feedforward_output_dense(node_feature)
+    def get_config(self) -> dict:
+        config = super().get_config()
+        config.update({'biases': self.biases})
+        return config
+
 
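`EdgeBias` factors out the per-head attention-bias logic that was previously inlined in the removed `GTConv.add_edge_bias`: it returns a `(num_edges, biases)` tensor, summing a projection of edge features and a zero-initialized projection of edge lengths when those fields exist. A sketch of how it might be wired into attention logits (the head count and the `bias` addition site are assumptions, mirroring the removed `tensor.edge['bias']` handling above):

# Hypothetical wiring: one bias column per attention head.
num_heads = 8
edge_bias = layers.EdgeBias(biases=num_heads)

bias = edge_bias(graph)  # shape: (num_edges, num_heads)
# attention_score = attention_score + bias   # inside the conv layer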
-        node_feature = self._feedforward_dropout(node_feature)
-        node_feature += tensor.node['feature']
+@keras.saving.register_keras_serializable(package='molcraft')
+class GaussianDistance(GraphLayer):
 
-        if not self._normalize_first:
-            node_feature = self._feedforward_norm(node_feature)
+    def __init__(self, kernels: int, **kwargs):
+        super().__init__(**kwargs)
+        self.kernels = kernels
 
-        return tensor.update(
-            {
-                'node': {
-                    'feature': node_feature,
-                },
-            }
+    def build_from_spec(self, spec: tensors.GraphTensor.Spec) -> None:
+        self._loc = self.add_weight(
+            shape=[self.kernels],
+            initializer='zeros',
+            dtype='float32',
+            trainable=True
+        )
+        self._scale = self.add_weight(
+            shape=[self.kernels],
+            initializer='ones',
+            dtype='float32',
+            trainable=True
         )
-
+        self.built = True
+
+    def propagate(self, tensor: tensors.GraphTensor) -> tf.Tensor:
+        euclidean_distance = ops.euclidean_distance(
+            tensor.gather('coordinate', 'source'),
+            tensor.gather('coordinate', 'target'),
+            axis=-1
+        )
+        return ops.gaussian(
+            euclidean_distance, self._loc, self._scale
+        )
+
     def get_config(self) -> dict:
         config = super().get_config()
         config.update({
-            "heads": self._heads,
-            'activation': keras.activations.serialize(self._activation),
-            'dropout': self._dropout,
-            'attention_dropout': self._attention_dropout,
-            'normalize': self._normalize,
-            'normalize_first': self._normalize_first,
+            'kernels': self.kernels,
         })
         return config
-
+
 
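`GaussianDistance` expands each source-target distance into `kernels` radial-basis values with trainable centers (`_loc`, zero-initialized) and widths (`_scale`, one-initialized). Assuming `ops.gaussian` implements the usual kernel exp(-(d - loc)^2 / (2 * scale^2)), which this hunk does not show, a NumPy sketch of the expansion:

import numpy as np

def gaussian_expand(distance, loc, scale):
    # distance: (num_edges, 1); loc, scale: (kernels,)
    # Returns (num_edges, kernels) radial-basis features.
    return np.exp(-((distance - loc) ** 2) / (2.0 * scale ** 2))

d = np.array([[1.2], [2.5]])  # two edge distances
loc = np.zeros(4)             # kernel centers
scale = np.ones(4)            # kernel widths
print(gaussian_expand(d, loc, scale).shape)  # (2, 4)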
 @keras.saving.register_keras_serializable(package='molcraft')
-class Readout(keras.layers.Layer):
+class Readout(GraphLayer):
+
+    """Readout layer.
+    """
 
     def __init__(self, mode: str | None = None, **kwargs):
+        kwargs['kernel_initializer'] = None
+        kwargs['bias_initializer'] = None
         super().__init__(**kwargs)
         self.mode = mode
-        if not self.mode:
-            self._reduce_fn = None
-        elif str(self.mode).lower().startswith('sum'):
+        if str(self.mode).lower().startswith('sum'):
             self._reduce_fn = keras.ops.segment_sum
         elif str(self.mode).lower().startswith('max'):
             self._reduce_fn = keras.ops.segment_max
@@ -1052,50 +1736,25 @@ class Readout(keras.layers.Layer):
     def build_from_spec(self, spec: tensors.GraphTensor.Spec) -> None:
         """Builds the layer.
         """
-        pass
+        self.built = True
 
-    def reduce(self, tensor: tensors.GraphTensor) -> tf.Tensor:
-        if self._reduce_fn is None:
-            raise NotImplementedError("Need to define a reduce method.")
+    def propagate(self, tensor: tensors.GraphTensor) -> tf.Tensor:
+        """Calls the layer.
+        """
+        node_feature = tensor.node['feature']
         if str(self.mode).lower().startswith('super'):
             node_feature = keras.ops.where(
-                tensor.node['super'][:, None], tensor.node['feature'], 0.0
-            )
-            return self._reduce_fn(
-                node_feature, tensor.graph_indicator, tensor.num_subgraphs
+                tensor.node['super'][:, None], node_feature, 0.0
             )
         return self._reduce_fn(
-            tensor.node['feature'], tensor.graph_indicator, tensor.num_subgraphs
+            node_feature, tensor.graph_indicator, tensor.num_subgraphs
         )
 
-    def build(self, input_shapes) -> None:
-        spec = tensors.GraphTensor.Spec.from_input_shape_dict(input_shapes)
-        self.build_from_spec(spec)
-        self.built = True
-
-    def call(self, graph) -> tf.Tensor:
-        graph_tensor = tensors.from_dict(graph)
-        if tensors.is_ragged(graph_tensor):
-            graph_tensor = graph_tensor.flatten()
-        return self.reduce(graph_tensor)
-
-    def __call__(
-        self,
-        graph: tensors.GraphTensor,
-        *args,
-        **kwargs
-    ) -> tensors.GraphTensor:
-        is_tensor = isinstance(graph, tensors.GraphTensor)
-        if is_tensor:
-            graph = tensors.to_dict(graph)
-        tensor = super().__call__(graph, *args, **kwargs)
-        return tensor
-
     def get_config(self) -> dict:
         config = super().get_config()
         config['mode'] = self.mode
         return config
-
+
 
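With `Readout` now deriving from `GraphLayer`, the hand-rolled `build`/`call`/`__call__` plumbing removed above is no longer needed; `propagate` segment-reduces node features per subgraph, and modes starting with 'super' zero out non-super nodes before reducing. A usage sketch (only the 'sum' and 'max' reduce branches are visible in this hunk; other mode strings are not confirmed here):

# Pools node features into one embedding per subgraph (molecule).
readout = layers.Readout(mode='sum')
graph_embedding = readout(graph)  # shape: (num_subgraphs, node_feature_dim)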
 def Input(spec: tensors.GraphTensor.Spec) -> dict:
     """Used to specify inputs to model.
@@ -1212,13 +1871,6 @@ def _spec_from_inputs(inputs):
     return tensors.GraphTensor.Spec(**nested_specs)
 
 
-GraphTransformer = GTConvolution = GTConv
-GINConvolution = GINConv
-
-EdgeEmbed = EdgeEmbedding
-NodeEmbed = NodeEmbedding
-
-ContextDense = ContextProjection
-EdgeDense = EdgeProjection
-NodeDense = NodeProjection
+GraphTransformer = GTConv
+GraphTransformer3D = GTConv3D
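Only the two `GraphTransformer` aliases survive this release. Code written against 0.1.0a1 that used the removed names can be restored downstream with a small shim (alias names taken from the removed lines; the targets are assumed to still exist in `molcraft.layers`):

# Backwards-compatibility shim for aliases removed in 0.1.0a3.
from molcraft.layers import (
    GTConv, GINConv, EdgeEmbedding, NodeEmbedding,
    ContextProjection, NodeProjection, EdgeProjection,
)

GTConvolution = GTConv
GINConvolution = GINConv
EdgeEmbed = EdgeEmbedding
NodeEmbed = NodeEmbedding
ContextDense = ContextProjection
EdgeDense = EdgeProjection
NodeDense = NodeProjection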