molcraft 0.1.0a1__py3-none-any.whl → 0.1.0a2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


molcraft/layers.py CHANGED
@@ -206,12 +206,62 @@ class GraphLayer(keras.layers.Layer):
206
206
  class GraphConv(GraphLayer):
207
207
 
208
208
  """Base graph neural network layer.
209
+
210
+ For normalization and the skip connection to work, the `GraphConv` subclass
211
+ requires the (node feature) outputs of `aggregate` and `update` to each have
212
+ a dimension of `self.units`.
213
+
214
+ Args:
215
+ units:
216
+ The number of units.
217
+ normalize:
218
+ Whether `LayerNormalization` should be applied to the (node feature) output
219
+ of the `aggregate` step. While normalization is recommended, it is not used
220
+ by default.
221
+ skip_connection:
222
+ Whether (node feature) input should be added to the (node feature) output.
223
+ If (node feature) input dim is not equal to `units`, a projection layer will
224
+ automatically project the residual before adding it to the output. While a skip
225
+ connection is recommended, it is not used by default.
226
+ kwargs:
227
+ See arguments of `GraphLayer`.
209
228
  """
210
229
 
211
- def __init__(self, units: int, **kwargs) -> None:
230
+ def __init__(
231
+ self,
232
+ units: int,
233
+ normalize: bool = False,
234
+ skip_connection: bool = False,
235
+ **kwargs
236
+ ) -> None:
212
237
  super().__init__(**kwargs)
213
238
  self.units = units
214
-
239
+ self._normalize_aggregate = normalize
240
+ self._skip_connection = skip_connection
241
+
242
+ def build(self, spec: tensors.GraphTensor.Spec) -> None:
243
+ node_feature_dim = spec.node['feature'].shape[-1]
244
+ self._project_input_node_feature = (
245
+ self._skip_connection and (node_feature_dim != self.units)
246
+ )
247
+ if self._project_input_node_feature:
248
+ warn(
249
+ '`skip_connection` is set to `True`, but found incompatible dim '
250
+ 'between input (node feature dim) and output (`self.units`). '
251
+ 'Automatically applying a projection layer to the residual to '
252
+ 'match the input and output dims.'
253
+ )
254
+ self._residual_projection = self.get_dense(
255
+ self.units, name='residual_projection'
256
+ )
257
+ if self._normalize_aggregate:
258
+ self._aggregation_norm = keras.layers.LayerNormalization(
259
+ name='aggregation_normalizer'
260
+ )
261
+ self._aggregation_norm.build([None, self.units])
262
+
263
+ super().build(spec)
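
The docstring above states the contract for subclasses: the node-feature outputs of `aggregate` and `update` must have dimension `self.units` for the new `normalize` and `skip_connection` options to apply cleanly. A minimal hypothetical subclass (a sketch, not part of molcraft) that satisfies this contract, using the `GraphTensor` methods seen elsewhere in this file:

import keras
from molcraft import layers, tensors

# Hypothetical `GraphConv` subclass: `aggregate` projects node features to
# `self.units`, so the base class can normalize them and add the residual.
@keras.saving.register_keras_serializable(package='example')
class SumConv(layers.GraphConv):

    def build_from_spec(self, spec: tensors.GraphTensor.Spec) -> None:
        node_feature_dim = spec.node['feature'].shape[-1]
        self._dense = self.get_dense(self.units)
        self._dense.build([None, node_feature_dim])

    def message(self, tensor):
        # Each edge carries the source node's feature as its message.
        return tensor.update(
            {'edge': {'message': tensor.gather('feature', 'source')}}
        )

    def aggregate(self, tensor):
        # Sum incoming messages per node and project to `self.units`.
        node_feature = self._dense(tensor.aggregate('message'))
        return tensor.update(
            {'node': {'feature': node_feature}, 'edge': {'message': None}}
        )

    def update(self, tensor):
        return tensor

conv = SumConv(units=128, normalize=True, skip_connection=True)
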
264
+
215
265
  @abc.abstractmethod
216
266
  def message(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
217
267
  """Compute messages.
@@ -256,479 +306,455 @@ class GraphConv(GraphLayer):
256
306
  tensor:
257
307
  A `GraphTensor` instance.
258
308
  """
309
+
310
+ if self._skip_connection:
311
+ input_node_feature = tensor.node['feature']
312
+ if self._project_input_node_feature:
313
+ input_node_feature = self._residual_projection(input_node_feature)
314
+
259
315
  tensor = self.message(tensor)
260
316
  tensor = self.aggregate(tensor)
261
- tensor = self.update(tensor)
262
- return tensor
263
317
 
264
- def get_config(self) -> dict:
265
- config = super().get_config()
266
- config.update({
267
- 'units': self.units
268
- })
269
- return config
270
-
271
-
272
- @keras.saving.register_keras_serializable(package='molcraft')
273
- class Projection(GraphLayer):
274
- """Base graph projection layer.
275
- """
276
- def __init__(
277
- self,
278
- units: int = None,
279
- activation: str = None,
280
- field: str = 'node',
281
- **kwargs
282
- ) -> None:
283
- super().__init__(**kwargs)
284
- self.units = units
285
- self._activation = keras.activations.get(activation)
286
- self.field = field
318
+ if self._normalize_aggregate:
319
+ normalized_node_feature = self._aggregation_norm(tensor.node['feature'])
320
+ tensor = tensor.update({'node': {'feature': normalized_node_feature}})
287
321
 
288
- def build_from_spec(self, spec: tensors.GraphTensor.Spec) -> None:
289
- """Builds the layer.
290
- """
291
- data = getattr(spec, self.field, None)
292
- if data is None:
293
- raise ValueError('Could not access field {self.field!r}.')
294
- feature_dim = data['feature'].shape[-1]
295
- if not self.units:
296
- self.units = feature_dim
297
- self._dense = self.get_dense(self.units)
298
- self._dense.build([None, feature_dim])
322
+ tensor = self.update(tensor)
299
323
 
300
- def propagate(self, tensor: tensors.GraphTensor):
301
- """Calls the layer.
302
- """
303
- feature = getattr(tensor, self.field)['feature']
304
- feature = self._dense(feature)
305
- feature = self._activation(feature)
324
+ if not self._skip_connection:
325
+ return tensor
326
+
327
+ updated_node_feature = tensor.node['feature']
306
328
  return tensor.update(
307
329
  {
308
- self.field: {
309
- 'feature': feature
330
+ 'node': {
331
+ 'feature': updated_node_feature + input_node_feature
310
332
  }
311
333
  }
312
- )
334
+ )
313
335
 
314
336
  def get_config(self) -> dict:
315
337
  config = super().get_config()
316
338
  config.update({
317
339
  'units': self.units,
318
- 'activation': keras.activations.serialize(self._activation),
319
- 'field': self.field,
340
+ 'normalize': self._normalize_aggregate,
341
+ 'skip_connection': self._skip_connection,
320
342
  })
321
343
  return config
322
344
 
323
345
 
324
346
  @keras.saving.register_keras_serializable(package='molcraft')
325
- class GraphNetwork(GraphLayer):
326
-
327
- """Graph neural network.
328
-
329
- Sequentially calls graph layers (`GraphLayer`) and concatenates its output.
347
+ class GINConv(GraphConv):
330
348
 
331
- Args:
332
- layers (list):
333
- A list of graph layers.
349
+ """Graph isomorphism network layer.
334
350
  """
335
351
 
336
- def __init__(self, layers: list[GraphLayer], **kwargs) -> None:
337
- super().__init__(**kwargs)
338
- self.layers = layers
339
- self._update_edge_feature = False
352
+ def __init__(
353
+ self,
354
+ units: int,
355
+ activation: keras.layers.Activation | str | None = 'relu',
356
+ use_bias: bool = True,
357
+ normalize: bool = True,
358
+ dropout: float = 0.0,
359
+ update_edge_feature: bool = True,
360
+ **kwargs,
361
+ ):
362
+ super().__init__(
363
+ units=units,
364
+ normalize=normalize,
365
+ use_bias=use_bias,
366
+ **kwargs
367
+ )
368
+ self._activation = keras.activations.get(activation)
369
+ self._dropout = dropout
370
+ self._update_edge_feature = update_edge_feature
340
371
 
341
372
  def build_from_spec(self, spec: tensors.GraphTensor.Spec) -> None:
342
373
  """Builds the layer.
343
374
  """
344
- units = self.layers[0].units
345
375
  node_feature_dim = spec.node['feature'].shape[-1]
346
- if node_feature_dim != units:
347
- warn(
348
- 'Node feature dim does not match `units` of the first layer. '
349
- 'Automatically adding a node projection layer to match `units`.'
350
- )
351
- self._node_dense = self.get_dense(units)
352
- self._update_node_feature = True
353
- has_edge_feature = 'feature' in spec.edge
354
- if has_edge_feature:
376
+
377
+ self.epsilon = self.add_weight(
378
+ name='epsilon',
379
+ shape=(),
380
+ initializer='zeros',
381
+ trainable=True,
382
+ )
383
+
384
+ if 'feature' in spec.edge:
355
385
  edge_feature_dim = spec.edge['feature'].shape[-1]
356
- if edge_feature_dim != units:
357
- warn(
358
- 'Edge feature dim does not match `units` of the first layer. '
359
- 'Automatically adding a edge projection layer to match `units`.'
360
- )
361
- self._edge_dense = self.get_dense(units)
362
- self._update_edge_feature = True
363
386
 
364
- def propagate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
365
- """Calls the layer.
387
+ if not self._update_edge_feature:
388
+ if (edge_feature_dim != node_feature_dim):
389
+ warn(
390
+ 'Found edge feature dim to be incompatible with node feature dim. '
391
+ 'Automatically adding an edge feature projection layer to match '
392
+ 'the dim of node features.'
393
+ )
394
+ self._update_edge_feature = True
395
+
396
+ if self._update_edge_feature:
397
+ self._edge_dense = self.get_dense(node_feature_dim)
398
+ self._edge_dense.build([None, edge_feature_dim])
399
+ else:
400
+ self._update_edge_feature = False
401
+
402
+ self._feedforward_intermediate_dense = self.get_dense(self.units)
403
+ self._feedforward_intermediate_dense.build([None, node_feature_dim])
404
+
405
+ has_overridden_update = self.__class__.update != GINConv.update
406
+ if not has_overridden_update:
407
+ # Use default feedforward network
408
+
409
+ self._feedforward_dropout = keras.layers.Dropout(self._dropout)
410
+ self._feedforward_activation = self._activation
411
+
412
+ self._feedforward_output_dense = self.get_dense(self.units)
413
+ self._feedforward_output_dense.build([None, self.units])
414
+
415
+ def message(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
416
+ """Computes messages.
366
417
  """
367
- x = tensors.to_dict(tensor)
368
- if self._update_node_feature:
369
- x['node']['feature'] = self._node_dense(tensor.node['feature'])
418
+ message = tensor.gather('feature', 'source')
419
+ edge_feature = tensor.edge.get('feature')
370
420
  if self._update_edge_feature:
371
- x['edge']['feature'] = self._edge_dense(tensor.edge['feature'])
372
- outputs = [x['node']['feature']]
373
- for layer in self.layers:
374
- x = layer(x)
375
- outputs.append(x['node']['feature'])
421
+ edge_feature = self._edge_dense(edge_feature)
422
+ if edge_feature is not None:
423
+ message += edge_feature
376
424
  return tensor.update(
377
425
  {
378
- 'node': {
379
- 'feature': keras.ops.concatenate(outputs, axis=-1)
380
- }
426
+ 'edge': {
427
+ 'message': message,
428
+ 'feature': edge_feature
429
+ }
381
430
  }
382
431
  )
383
-
384
- def tape_propagate(
385
- self,
386
- tensor: tensors.GraphTensor,
387
- tape: tf.GradientTape,
388
- training: bool | None = None,
389
- ) -> tuple[tensors.GraphTensor, list[tf.Tensor]]:
390
- """Performs the propagation with a `GradientTape`.
391
-
392
- Performs the same forward pass as `propagate` but with a `GradientTape`
393
- watching intermediate node features.
394
432
 
395
- Args:
396
- tensor (tensors.GraphTensor):
397
- The graph input.
433
+ def aggregate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
434
+ """Aggregates messages.
398
435
  """
399
- if isinstance(tensor, tensors.GraphTensor):
400
- x = tensors.to_dict(tensor)
401
- else:
402
- x = tensor
403
- if self._update_node_feature:
404
- x['node']['feature'] = self._node_dense(tensor.node['feature'])
405
- if self._update_edge_feature:
406
- x['edge']['feature'] = self._edge_dense(tensor.edge['feature'])
407
- tape.watch(x['node']['feature'])
408
- outputs = [x['node']['feature']]
409
- for layer in self.layers:
410
- x = layer(x, training=training)
411
- tape.watch(x['node']['feature'])
412
- outputs.append(x['node']['feature'])
413
-
414
- tensor = tensor.update(
436
+ node_feature = tensor.aggregate('message')
437
+ node_feature += (1 + self.epsilon) * tensor.node['feature']
438
+ node_feature = self._feedforward_intermediate_dense(node_feature)
439
+ node_feature = self._feedforward_activation(node_feature)
440
+ return tensor.update(
415
441
  {
416
442
  'node': {
417
- 'feature': keras.ops.concatenate(outputs, axis=-1)
443
+ 'feature': node_feature,
444
+ },
445
+ 'edge': {
446
+ 'message': None,
418
447
  }
419
448
  }
420
449
  )
421
- return tensor, outputs
422
450
 
423
- def get_config(self) -> dict:
424
- config = super().get_config()
425
- config.update(
451
+ def update(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
452
+ """Updates nodes.
453
+ """
454
+ node_feature = tensor.node['feature']
455
+ node_feature = self._feedforward_dropout(node_feature)
456
+ node_feature = self._feedforward_output_dense(node_feature)
457
+ return tensor.update(
426
458
  {
427
- 'layers': [
428
- keras.layers.serialize(layer) for layer in self.layers
429
- ]
459
+ 'node': {
460
+ 'feature': node_feature,
461
+ }
430
462
  }
431
463
  )
464
+
465
+ def get_config(self) -> dict:
466
+ config = super().get_config()
467
+ config.update({
468
+ 'activation': keras.activations.serialize(self._activation),
469
+ 'dropout': self._dropout,
470
+ 'update_edge_feature': self._update_edge_feature
471
+ })
432
472
  return config
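
A hypothetical instantiation of the `GINConv` defined above (argument values are illustrative; `skip_connection` is forwarded to the `GraphConv` base class via `**kwargs`):

from molcraft import layers

gin = layers.GINConv(
    units=128,
    activation='relu',
    normalize=True,          # LayerNormalization after aggregation (base class)
    skip_connection=True,    # residual connection handled by the base class
    dropout=0.1,
)
# graph = gin(graph)  # where `graph` is a GraphTensor with node (and edge) features
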
433
473
 
434
- @classmethod
435
- def from_config(cls, config: dict) -> 'GraphNetwork':
436
- config['layers'] = [
437
- keras.layers.deserialize(layer) for layer in config['layers']
438
- ]
439
- return super().from_config(config)
440
-
441
474
 
442
475
  @keras.saving.register_keras_serializable(package='molcraft')
443
- class NodeEmbedding(GraphLayer):
444
-
445
- """Node embedding layer.
476
+ class GTConv(GraphConv):
446
477
 
447
- Embeds nodes based on its initial features.
478
+ """Graph transformer layer.
448
479
  """
449
480
 
450
481
  def __init__(
451
- self,
452
- dim: int = None,
453
- embed_context: bool = True,
454
- allow_masking: bool = True,
455
- **kwargs
482
+ self,
483
+ units: int,
484
+ heads: int = 8,
485
+ activation: keras.layers.Activation | str | None = "relu",
486
+ use_bias: bool = True,
487
+ normalize: bool = True,
488
+ dropout: float = 0.0,
489
+ attention_dropout: float = 0.0,
490
+ **kwargs,
456
491
  ) -> None:
457
- super().__init__(**kwargs)
458
- self.dim = dim
459
- self._embed_context = embed_context
460
- self._masking_rate = None
461
- self._allow_masking = allow_masking
492
+ kwargs['skip_connection'] = False
493
+ super().__init__(
494
+ units=units,
495
+ normalize=normalize,
496
+ use_bias=use_bias,
497
+ **kwargs
498
+ )
499
+ self._heads = heads
500
+ if self.units % self.heads != 0:
501
+ raise ValueError('`units` must be divisible by `heads`.')
502
+ self._head_units = self.units // self.heads
503
+ self._activation = keras.activations.get(activation)
504
+ self._dropout = dropout
505
+ self._attention_dropout = attention_dropout
506
+ self._normalize = normalize
462
507
 
463
- def build_from_spec(self, spec: tensors.GraphTensor.Spec) -> None:
508
+ @property
509
+ def heads(self):
510
+ return self._heads
511
+
512
+ @property
513
+ def head_units(self):
514
+ return self._head_units
515
+
516
+ def build_from_spec(self, spec):
464
517
  """Builds the layer.
465
518
  """
466
- feature_dim = spec.node['feature'].shape[-1]
467
- if not self.dim:
468
- self.dim = feature_dim
469
- self._node_dense = self.get_dense(self.dim)
470
- self._node_dense.build([None, feature_dim])
519
+ node_feature_dim = spec.node['feature'].shape[-1]
520
+ self.project_residual = node_feature_dim != self.units
521
+ if self.project_residual:
522
+ warn(
523
+ '`GTConv` uses residual connections, but found incompatible dim '
524
+ 'between input (node feature dim) and output (`self.units`). '
525
+ 'Automatically applying a projection layer to the residual to '
526
+ 'match the input and output dims.'
527
+ )
528
+ self._residual_dense = self.get_dense(self.units)
529
+ self._residual_dense.build([None, node_feature_dim])
530
+
531
+ self._query_dense = self.get_einsum_dense(
532
+ 'ij,jkh->ikh', (self.head_units, self.heads)
533
+ )
534
+ self._query_dense.build([None, node_feature_dim])
471
535
 
472
- self._has_super = 'super' in spec.node
473
- has_context_feature = 'feature' in spec.context
474
- if not has_context_feature:
475
- self._embed_context = False
476
- if self._has_super and not self._embed_context:
477
- self._super_feature = self.get_weight(shape=[self.dim], name='super_node_feature')
478
- if self._allow_masking:
479
- self._mask_feature = self.get_weight(shape=[self.dim], name='mask_node_feature')
536
+ self._key_dense = self.get_einsum_dense(
537
+ 'ij,jkh->ikh', (self.head_units, self.heads)
538
+ )
539
+ self._key_dense.build([None, node_feature_dim])
480
540
 
481
- if self._embed_context:
482
- context_feature_dim = spec.context['feature'].shape[-1]
483
- self._context_dense = self.get_dense(self.dim)
484
- self._context_dense.build([None, context_feature_dim])
541
+ self._value_dense = self.get_einsum_dense(
542
+ 'ij,jkh->ikh', (self.head_units, self.heads)
543
+ )
544
+ self._value_dense.build([None, node_feature_dim])
485
545
 
486
- def propagate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
487
- """Calls the layer.
488
- """
489
- feature = self._node_dense(tensor.node['feature'])
546
+ self._output_dense = self.get_dense(self.units)
547
+ self._output_dense.build([None, self.units])
490
548
 
491
- if self._has_super:
492
- super_feature = (0 if self._embed_context else self._super_feature)
493
- super_mask = keras.ops.expand_dims(tensor.node['super'], 1)
494
- feature = keras.ops.where(super_mask, super_feature, feature)
549
+ self._softmax_dropout = keras.layers.Dropout(self._attention_dropout)
495
550
 
496
- if self._embed_context:
497
- context_feature = self._context_dense(tensor.context['feature'])
498
- feature = ops.scatter_update(feature, tensor.node['super'], context_feature)
499
- tensor = tensor.update({'context': {'feature': None}})
551
+ self._self_attention_dropout = keras.layers.Dropout(self._dropout)
500
552
 
501
- if (
502
- self._allow_masking and
503
- self._masking_rate is not None and
504
- self._masking_rate > 0
505
- ):
506
- random = keras.random.uniform(shape=[tensor.num_nodes])
507
- mask = random <= self._masking_rate
508
- if self._has_super:
509
- mask = keras.ops.logical_and(
510
- mask, keras.ops.logical_not(tensor.node['super'])
511
- )
512
- mask = keras.ops.expand_dims(mask, -1)
513
- feature = keras.ops.where(mask, self._mask_feature, feature)
514
- elif self._allow_masking:
515
- # Slience warning of 'no gradients for variables'
516
- feature = feature + (self._mask_feature * 0.0)
553
+ self._add_edge_bias = 'bias' not in spec.edge
554
+ if self._add_edge_bias:
555
+ self._add_edge_bias = AddEdgeBias()
556
+ self._add_edge_bias.build_from_spec(spec)
517
557
 
518
- return tensor.update({'node': {'feature': feature}})
558
+ has_overridden_update = self.__class__.update != GTConv.update
559
+ if not has_overridden_update:
560
+
561
+ if self._normalize:
562
+ self._feedforward_output_norm = keras.layers.LayerNormalization()
563
+ self._feedforward_output_norm.build([None, self.units])
519
564
 
520
- @property
521
- def masking_rate(self):
522
- return self._masking_rate
523
-
524
- @masking_rate.setter
525
- def masking_rate(self, rate: float):
526
- if not self._allow_masking and rate is not None:
527
- raise ValueError(
528
- f'Cannot set `masking_rate` for layer {self} '
529
- 'as `allow_masking` was set to `False`.'
530
- )
531
- self._masking_rate = float(rate)
565
+ self._feedforward_dropout = keras.layers.Dropout(self._dropout)
532
566
 
533
- def get_config(self) -> dict:
534
- config = super().get_config()
535
- config.update({
536
- 'dim': self.dim,
537
- 'allow_masking': self._allow_masking
538
- })
539
- return config
540
-
567
+ self._feedforward_intermediate_dense = self.get_dense(self.units)
568
+ self._feedforward_intermediate_dense.build([None, self.units])
541
569
 
542
- @keras.saving.register_keras_serializable(package='molcraft')
543
- class EdgeEmbedding(GraphLayer):
570
+ self._feedforward_output_dense = self.get_dense(self.units)
571
+ self._feedforward_output_dense.build([None, self.units])
544
572
 
545
- """Edge embedding layer.
546
573
 
547
- Embeds edges based on its initial features.
548
- """
574
+ def message(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
575
+ """Computes messages.
576
+ """
549
577
 
550
- def __init__(
551
- self,
552
- dim: int = None,
553
- allow_masking: bool = True,
554
- **kwargs
555
- ) -> None:
556
- super().__init__(**kwargs)
557
- self.dim = dim
558
- self._masking_rate = None
559
- self._allow_masking = allow_masking
578
+ node_feature = tensor.node['feature']
579
+
580
+ query = self._query_dense(node_feature)
581
+ key = self._key_dense(node_feature)
582
+ value = self._value_dense(node_feature)
560
583
 
561
- def build_from_spec(self, spec: tensors.GraphTensor.Spec) -> None:
562
- """Builds the layer.
563
- """
564
- feature_dim = spec.edge['feature'].shape[-1]
565
- if not self.dim:
566
- self.dim = feature_dim
567
- self._edge_dense = self.get_dense(self.dim)
568
- self._edge_dense.build([None, feature_dim])
584
+ query = ops.gather(query, tensor.edge['source'])
585
+ key = ops.gather(key, tensor.edge['target'])
586
+ value = ops.gather(value, tensor.edge['source'])
569
587
 
570
- self._has_super = 'super' in spec.edge
571
- if self._has_super:
572
- self._super_feature = self.get_weight(shape=[self.dim], name='super_edge_feature')
573
- if self._allow_masking:
574
- self._mask_feature = self.get_weight(shape=[self.dim], name='mask_edge_feature')
588
+ attention_score = keras.ops.sum(query * key, axis=1, keepdims=True)
589
+ attention_score /= keras.ops.sqrt(float(self.head_units))
575
590
 
576
- def propagate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
577
- """Calls the layer.
591
+ if self._add_edge_bias:
592
+ tensor = self._add_edge_bias(tensor)
593
+
594
+ attention_score += keras.ops.expand_dims(tensor.edge['bias'], -1)
595
+
596
+ attention = ops.edge_softmax(attention_score, tensor.edge['target'])
597
+ attention = self._softmax_dropout(attention)
598
+
599
+ return tensor.update(
600
+ {
601
+ 'edge': {
602
+ 'message': value,
603
+ 'weight': attention,
604
+ },
605
+ }
606
+ )
607
+
608
+ def aggregate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
609
+ """Aggregates messages.
578
610
  """
579
- feature = self._edge_dense(tensor.edge['feature'])
611
+ node_feature = tensor.aggregate('message')
612
+ node_feature = keras.ops.reshape(node_feature, (-1, self.units))
613
+ node_feature = self._output_dense(node_feature)
614
+ node_feature = self._self_attention_dropout(node_feature)
615
+ return tensor.update(
616
+ {
617
+ 'node': {
618
+ 'feature': node_feature,
619
+ 'residual': tensor.node['feature']
620
+ },
621
+ 'edge': {
622
+ 'message': None,
623
+ 'weight': None,
624
+ }
625
+ }
626
+ )
580
627
 
581
- if self._has_super:
582
- super_feature = self._super_feature
583
- super_mask = keras.ops.expand_dims(tensor.edge['super'], 1)
584
- feature = keras.ops.where(super_mask, super_feature, feature)
628
+ def update(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
629
+ """Updates nodes.
630
+ """
631
+ node_feature = tensor.node['feature']
585
632
 
586
- if (
587
- self._allow_masking and
588
- self._masking_rate is not None and
589
- self._masking_rate > 0
590
- ):
591
- random = keras.random.uniform(shape=[tensor.num_edges])
592
- mask = random <= self._masking_rate
593
- if self._has_super:
594
- mask = keras.ops.logical_and(
595
- mask, keras.ops.logical_not(tensor.edge['super'])
596
- )
597
- mask = keras.ops.expand_dims(mask, -1)
598
- feature = keras.ops.where(mask, self._mask_feature, feature)
599
- elif self._allow_masking:
600
- # Slience warning of 'no gradients for variables'
601
- feature = feature + (self._mask_feature * 0.0)
633
+ residual = tensor.node['residual']
634
+ if self.project_residual:
635
+ residual = self._residual_dense(residual)
602
636
 
603
- return tensor.update({'edge': {'feature': feature}})
637
+ node_feature += residual
638
+ residual = node_feature
604
639
 
605
- @property
606
- def masking_rate(self):
607
- return self._masking_rate
608
-
609
- @masking_rate.setter
610
- def masking_rate(self, rate: float):
611
- if not self._allow_masking and rate is not None:
612
- raise ValueError(
613
- f'Cannot set `masking_rate` for layer {self} '
614
- 'as `allow_masking` was set to `False`.'
615
- )
616
- self._masking_rate = float(rate)
640
+ node_feature = self._feedforward_intermediate_dense(node_feature)
641
+ node_feature = self._activation(node_feature)
642
+ node_feature = self._feedforward_output_dense(node_feature)
643
+ node_feature = self._feedforward_dropout(node_feature)
644
+ if self._normalize:
645
+ node_feature = self._feedforward_output_norm(node_feature)
617
646
 
647
+ node_feature += residual
648
+
649
+ return tensor.update(
650
+ {
651
+ 'node': {
652
+ 'feature': node_feature,
653
+ },
654
+ }
655
+ )
656
+
618
657
  def get_config(self) -> dict:
619
658
  config = super().get_config()
620
659
  config.update({
621
- 'dim': self.dim,
622
- 'allow_masking': self._allow_masking
660
+ "heads": self._heads,
661
+ 'activation': keras.activations.serialize(self._activation),
662
+ 'dropout': self._dropout,
663
+ 'attention_dropout': self._attention_dropout,
623
664
  })
624
665
  return config
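
In `GTConv.message` above, the scaled dot-product scores are shifted by the per-edge bias and normalized with `ops.edge_softmax` over all edges sharing a target node. A standalone sketch of that normalization (an assumption about what `edge_softmax` computes, not molcraft's implementation), written with `keras.ops` segment reductions:

import keras
import numpy as np

def edge_softmax(score, target, num_nodes):
    # Normalize edge scores over all edges that point to the same target node,
    # subtracting the per-node maximum for numerical stability.
    max_per_node = keras.ops.segment_max(score, target, num_segments=num_nodes)
    score = keras.ops.exp(score - keras.ops.take(max_per_node, target, axis=0))
    denom = keras.ops.segment_sum(score, target, num_segments=num_nodes)
    return score / keras.ops.take(denom, target, axis=0)

score = np.array([[0.5], [1.0], [0.2]], dtype='float32')   # one score per edge
target = np.array([0, 0, 1])                                # edge target indices
print(edge_softmax(score, target, num_nodes=2))             # weights for edges into node 0 sum to 1
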
625
666
 
626
667
 
627
668
  @keras.saving.register_keras_serializable(package='molcraft')
628
- class ContextProjection(Projection):
629
- """Context projection layer.
630
- """
631
- def __init__(self, units: int = None, activation: str = None, **kwargs):
632
- super().__init__(units=units, activation=activation, field='context', **kwargs)
633
-
634
-
635
- @keras.saving.register_keras_serializable(package='molcraft')
636
- class NodeProjection(Projection):
637
- """Node projection layer.
638
- """
639
- def __init__(self, units: int = None, activation: str = None, **kwargs):
640
- super().__init__(units=units, activation=activation, field='node', **kwargs)
641
-
642
-
643
- @keras.saving.register_keras_serializable(package='molcraft')
644
- class EdgeProjection(Projection):
645
- """Edge projection layer.
646
- """
647
- def __init__(self, units: int = None, activation: str = None, **kwargs):
648
- super().__init__(units=units, activation=activation, field='edge', **kwargs)
649
-
650
-
651
- @keras.saving.register_keras_serializable(package='molcraft')
652
- class GINConv(GraphConv):
669
+ class EGConv3D(GraphConv):
653
670
 
654
- """Graph isomorphism network layer.
671
+ """Equivariant graph neural network layer.
655
672
  """
656
673
 
657
674
  def __init__(
658
- self,
659
- units: int,
660
- activation: keras.layers.Activation | str | None = 'relu',
661
- dropout: float = 0.0,
675
+ self,
676
+ units: int = 128,
677
+ activation: keras.layers.Activation | str | None = None,
678
+ use_bias: bool = True,
662
679
  normalize: bool = True,
663
- update_edge_feature: bool = True,
664
- **kwargs,
665
- ):
666
- super().__init__(units=units, **kwargs)
680
+ dropout: float = 0.0,
681
+ **kwargs
682
+ ) -> None:
683
+ super().__init__(
684
+ units=units,
685
+ normalize=normalize,
686
+ use_bias=use_bias,
687
+ **kwargs
688
+ )
667
689
  self._activation = keras.activations.get(activation)
668
- self._normalize = normalize
669
- self._dropout = dropout
670
- self._update_edge_feature = update_edge_feature
690
+ self._dropout = dropout or 0.0
671
691
 
672
692
  def build_from_spec(self, spec: tensors.GraphTensor.Spec) -> None:
673
- """Builds the layer.
674
- """
693
+ if 'coordinate' not in spec.node:
694
+ raise ValueError(
695
+ 'Could not find `coordinate`s in node, '
696
+ 'which is required for Conv3D layers.'
697
+ )
675
698
  node_feature_dim = spec.node['feature'].shape[-1]
676
-
677
- self.epsilon = self.add_weight(
678
- name='epsilon',
679
- shape=(),
680
- initializer='zeros',
681
- trainable=True,
682
- )
683
-
699
+ feature_dim = node_feature_dim + node_feature_dim + 1
684
700
  if 'feature' in spec.edge:
701
+ self._has_edge_feature = True
685
702
  edge_feature_dim = spec.edge['feature'].shape[-1]
703
+ feature_dim += edge_feature_dim
704
+ else:
705
+ self._has_edge_feature = False
686
706
 
687
- if not self._update_edge_feature:
688
- if (edge_feature_dim != node_feature_dim):
689
- warn(
690
- 'Found edge feature dim to be incompatible with node feature dim. '
691
- 'Automatically adding a edge feature projection layer to match '
692
- 'the dim of node features.'
693
- )
694
- self._update_edge_feature = True
707
+ self.message_fn = self.get_dense(self.units, activation=self._activation)
708
+ self.message_fn.build([None, feature_dim])
709
+ self.dense_position = self.get_dense(1)
710
+ self.dense_position.build([None, self.units])
695
711
 
696
- if self._update_edge_feature:
697
- self._edge_dense = self.get_dense(node_feature_dim)
698
- self._edge_dense.build([None, edge_feature_dim])
699
- else:
700
- self._update_edge_feature = False
701
-
702
- has_overridden_update = self.__class__.update != GINConv.update
712
+ has_overridden_update = self.__class__.update != EGConv3D.update
703
713
  if not has_overridden_update:
704
- # Use default feedforward network
705
- self._feedforward_intermediate_dense = self.get_dense(self.units)
706
- self._feedforward_intermediate_dense.build([None, node_feature_dim])
714
+ self.update_fn = self.get_dense(self.units, activation=self._activation)
715
+ self.update_fn.build([None, node_feature_dim + self.units])
716
+ self._dropout_layer = keras.layers.Dropout(self._dropout)
707
717
 
708
- if self._normalize:
709
- self._feedforward_intermediate_norm = keras.layers.BatchNormalization()
710
- self._feedforward_intermediate_norm.build([None, self.units])
711
-
712
- self._feedforward_dropout = keras.layers.Dropout(self._dropout)
713
- self._feedforward_activation = self._activation
714
-
715
- self._feedforward_output_dense = self.get_dense(self.units)
716
- self._feedforward_output_dense.build([None, self.units])
717
-
718
718
  def message(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
719
- """Compute messages.
719
+ """Computes messages.
720
720
  """
721
- message = tensor.gather('feature', 'source')
722
- edge_feature = tensor.edge.get('feature')
723
- if self._update_edge_feature:
724
- edge_feature = self._edge_dense(edge_feature)
725
- if edge_feature is not None:
726
- message += edge_feature
721
+ relative_node_coordinate = keras.ops.subtract(
722
+ tensor.gather('coordinate', 'target'),
723
+ tensor.gather('coordinate', 'source')
724
+ )
725
+ euclidean_distance = keras.ops.sum(
726
+ keras.ops.square(
727
+ relative_node_coordinate
728
+ ),
729
+ axis=-1,
730
+ keepdims=True
731
+ )
732
+ feature = keras.ops.concatenate(
733
+ [
734
+ tensor.gather('feature', 'target'),
735
+ tensor.gather('feature', 'source'),
736
+ euclidean_distance,
737
+ ],
738
+ axis=-1
739
+ )
740
+ if self._has_edge_feature:
741
+ feature = keras.ops.concatenate(
742
+ [
743
+ feature,
744
+ tensor.edge['feature']
745
+ ],
746
+ axis=-1
747
+ )
748
+ message = self.message_fn(feature)
749
+ relative_node_coordinate = keras.ops.multiply(
750
+ relative_node_coordinate,
751
+ self.dense_position(message)
752
+ )
727
753
  return tensor.update(
728
754
  {
729
755
  'edge': {
730
756
  'message': message,
731
- 'feature': edge_feature
757
+ 'relative_node_coordinate': relative_node_coordinate
732
758
  }
733
759
  }
734
760
  )
@@ -736,34 +762,54 @@ class GINConv(GraphConv):
736
762
  def aggregate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
737
763
  """Aggregates messages.
738
764
  """
739
- node_feature = tensor.aggregate('message')
740
- node_feature += (1 + self.epsilon) * tensor.node['feature']
765
+ coefficient = keras.ops.bincount(
766
+ tensor.edge['source'],
767
+ minlength=tensor.num_nodes
768
+ )
769
+ coefficient = keras.ops.cast(
770
+ coefficient, tensor.node['coordinate'].dtype
771
+ )
772
+ coefficient = keras.ops.expand_dims(
773
+ keras.ops.divide_no_nan(1, coefficient), axis=1
774
+ )
775
+
776
+ updated_coordinate = tensor.aggregate('relative_node_coordinate') * coefficient
777
+ updated_coordinate += tensor.node['coordinate']
778
+
779
+ aggregate = tensor.aggregate('message')
741
780
  return tensor.update(
742
781
  {
743
782
  'node': {
744
- 'feature': node_feature,
783
+ 'feature': aggregate,
784
+ 'coordinate': updated_coordinate,
785
+ 'previous_feature': tensor.node['feature'],
745
786
  },
746
787
  'edge': {
747
788
  'message': None,
789
+ 'relative_node_coordinate': None
748
790
  }
749
791
  }
750
- )
751
-
792
+ )
793
+
752
794
  def update(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
753
- """Updates nodes.
795
+ """Updates nodes.
754
796
  """
755
- node_feature = tensor.node['feature']
756
- node_feature = self._feedforward_intermediate_dense(node_feature)
757
- node_feature = self._feedforward_activation(node_feature)
758
- if self._normalize:
759
- node_feature = self._feedforward_intermediate_norm(node_feature)
760
- node_feature = self._feedforward_dropout(node_feature)
761
- node_feature = self._feedforward_output_dense(node_feature)
797
+ updated_node_feature = self.update_fn(
798
+ keras.ops.concatenate(
799
+ [
800
+ tensor.node['feature'],
801
+ tensor.node['previous_feature']
802
+ ],
803
+ axis=-1
804
+ )
805
+ )
806
+ updated_node_feature = self._dropout_layer(updated_node_feature)
762
807
  return tensor.update(
763
808
  {
764
809
  'node': {
765
- 'feature': node_feature,
766
- }
810
+ 'feature': updated_node_feature,
811
+ 'previous_feature': None,
812
+ },
767
813
  }
768
814
  )
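
The coordinate update in `EGConv3D.aggregate` above scales the summed relative coordinates by an inverse-degree coefficient built from `bincount`, so a node's coordinate moves by the mean of its weighted edge vectors. A standalone illustration of that coefficient (the index values are made up for this sketch):

import keras
import numpy as np

# Inverse-degree coefficient as computed in `EGConv3D.aggregate` above:
# nodes with no edges get a coefficient of zero via `divide_no_nan`.
edge_source = np.array([0, 0, 1, 2, 2, 2])                 # edge source indices
degree = keras.ops.bincount(edge_source, minlength=4)      # [2, 1, 3, 0]
degree = keras.ops.cast(degree, 'float32')
coefficient = keras.ops.expand_dims(
    keras.ops.divide_no_nan(1.0, degree), axis=1
)                                                          # shape (4, 1)
print(coefficient)   # [[0.5], [1.0], [0.333...], [0.0]]
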
769
815
 
@@ -771,267 +817,390 @@ class GINConv(GraphConv):
771
817
  config = super().get_config()
772
818
  config.update({
773
819
  'activation': keras.activations.serialize(self._activation),
774
- 'dropout': self._dropout,
775
- 'normalize': self._normalize,
820
+ 'dropout': self._dropout,
776
821
  })
777
- return config
778
-
822
+ return config
823
+
779
824
 
780
825
  @keras.saving.register_keras_serializable(package='molcraft')
781
- class GTConv(GraphConv):
782
-
783
- """Graph transformer layer.
826
+ class Projection(GraphLayer):
827
+ """Base graph projection layer.
784
828
  """
785
-
786
829
  def __init__(
787
- self,
788
- units: int,
789
- heads: int = 8,
790
- activation: keras.layers.Activation | str | None = "relu",
791
- dropout: float = 0.0,
792
- attention_dropout: float = 0.0,
793
- normalize: bool = True,
794
- normalize_first: bool = True,
795
- **kwargs,
830
+ self,
831
+ units: int = None,
832
+ activation: str = None,
833
+ field: str = 'node',
834
+ **kwargs
796
835
  ) -> None:
797
- super().__init__(units=units, **kwargs)
798
- self._heads = heads
799
- if self.units % self.heads != 0:
800
- raise ValueError(f"units need to be divisible by heads.")
801
- self._head_units = self.units // self.heads
836
+ super().__init__(**kwargs)
837
+ self.units = units
802
838
  self._activation = keras.activations.get(activation)
803
- self._dropout = dropout
804
- self._attention_dropout = attention_dropout
805
- self._normalize = normalize
806
- self._normalize_first = normalize_first
839
+ self.field = field
807
840
 
808
- @property
809
- def heads(self):
810
- return self._heads
811
-
812
- @property
813
- def head_units(self):
814
- return self._head_units
841
+ def build_from_spec(self, spec: tensors.GraphTensor.Spec) -> None:
842
+ """Builds the layer.
843
+ """
844
+ data = getattr(spec, self.field, None)
845
+ if data is None:
846
+ raise ValueError(f'Could not access field {self.field!r}.')
847
+ feature_dim = data['feature'].shape[-1]
848
+ if not self.units:
849
+ self.units = feature_dim
850
+ self._dense = self.get_dense(self.units)
851
+ self._dense.build([None, feature_dim])
852
+
853
+ def propagate(self, tensor: tensors.GraphTensor):
854
+ """Calls the layer.
855
+ """
856
+ feature = getattr(tensor, self.field)['feature']
857
+ feature = self._dense(feature)
858
+ feature = self._activation(feature)
859
+ return tensor.update(
860
+ {
861
+ self.field: {
862
+ 'feature': feature
863
+ }
864
+ }
865
+ )
866
+
867
+ def get_config(self) -> dict:
868
+ config = super().get_config()
869
+ config.update({
870
+ 'units': self.units,
871
+ 'activation': keras.activations.serialize(self._activation),
872
+ 'field': self.field,
873
+ })
874
+ return config
815
875
 
816
- def build_from_spec(self, spec):
876
+
877
+ @keras.saving.register_keras_serializable(package='molcraft')
878
+ class GraphNetwork(GraphLayer):
879
+
880
+ """Graph neural network.
881
+
882
+ Sequentially calls graph layers (`GraphLayer`) and concatenates their outputs.
883
+
884
+ Args:
885
+ layers (list):
886
+ A list of graph layers.
887
+ """
888
+
889
+ def __init__(self, layers: list[GraphLayer], **kwargs) -> None:
890
+ super().__init__(**kwargs)
891
+ self.layers = layers
892
+ self._update_edge_feature = False
893
+
894
+ def build_from_spec(self, spec: tensors.GraphTensor.Spec) -> None:
817
895
  """Builds the layer.
818
896
  """
897
+ units = self.layers[0].units
819
898
  node_feature_dim = spec.node['feature'].shape[-1]
820
- incompatible_dim = node_feature_dim != self.units
821
- if incompatible_dim:
822
- warnings.warn(
823
- message=(
824
- '`GTConv` uses residual connections, but input node feature dim '
825
- 'is incompatible with intermediate dim (`units`). '
826
- 'Automatically projecting first residual to match its dim with intermediate dim.'
827
- ),
828
- category=UserWarning,
829
- stacklevel=1
899
+ if node_feature_dim != units:
900
+ warn(
901
+ 'Node feature dim does not match `units` of the first layer. '
902
+ 'Automatically adding a node projection layer to match `units`.'
830
903
  )
831
- self._residual_dense = self.get_dense(self.units)
832
- self._residual_dense.build([None, node_feature_dim])
833
- self._project_residual = True
834
- else:
835
- self._project_residual = False
836
-
837
- self._query_dense = self.get_einsum_dense(
838
- 'ij,jkh->ikh', (self.head_units, self.heads)
839
- )
840
- self._query_dense.build([None, node_feature_dim])
904
+ self._node_dense = self.get_dense(units)
905
+ self._update_node_feature = True
906
+ has_edge_feature = 'feature' in spec.edge
907
+ if has_edge_feature:
908
+ edge_feature_dim = spec.edge['feature'].shape[-1]
909
+ if edge_feature_dim != units:
910
+ warn(
911
+ 'Edge feature dim does not match `units` of the first layer. '
912
+ 'Automatically adding an edge projection layer to match `units`.'
913
+ )
914
+ self._edge_dense = self.get_dense(units)
915
+ self._update_edge_feature = True
841
916
 
842
- self._key_dense = self.get_einsum_dense(
843
- 'ij,jkh->ikh', (self.head_units, self.heads)
917
+ def propagate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
918
+ """Calls the layer.
919
+ """
920
+ x = tensors.to_dict(tensor)
921
+ if self._update_node_feature:
922
+ x['node']['feature'] = self._node_dense(tensor.node['feature'])
923
+ if self._update_edge_feature:
924
+ x['edge']['feature'] = self._edge_dense(tensor.edge['feature'])
925
+ outputs = [x['node']['feature']]
926
+ for layer in self.layers:
927
+ x = layer(x)
928
+ outputs.append(x['node']['feature'])
929
+ return tensor.update(
930
+ {
931
+ 'node': {
932
+ 'feature': keras.ops.concatenate(outputs, axis=-1)
933
+ }
934
+ }
844
935
  )
845
- self._key_dense.build([None, node_feature_dim])
936
+
937
+ def tape_propagate(
938
+ self,
939
+ tensor: tensors.GraphTensor,
940
+ tape: tf.GradientTape,
941
+ training: bool | None = None,
942
+ ) -> tuple[tensors.GraphTensor, list[tf.Tensor]]:
943
+ """Performs the propagation with a `GradientTape`.
846
944
 
847
- self._value_dense = self.get_einsum_dense(
848
- 'ij,jkh->ikh', (self.head_units, self.heads)
945
+ Performs the same forward pass as `propagate` but with a `GradientTape`
946
+ watching intermediate node features.
947
+
948
+ Args:
949
+ tensor (tensors.GraphTensor):
950
+ The graph input.
951
+ """
952
+ if isinstance(tensor, tensors.GraphTensor):
953
+ x = tensors.to_dict(tensor)
954
+ else:
955
+ x = tensor
956
+ if self._update_node_feature:
957
+ x['node']['feature'] = self._node_dense(tensor.node['feature'])
958
+ if self._update_edge_feature:
959
+ x['edge']['feature'] = self._edge_dense(tensor.edge['feature'])
960
+ tape.watch(x['node']['feature'])
961
+ outputs = [x['node']['feature']]
962
+ for layer in self.layers:
963
+ x = layer(x, training=training)
964
+ tape.watch(x['node']['feature'])
965
+ outputs.append(x['node']['feature'])
966
+
967
+ tensor = tensor.update(
968
+ {
969
+ 'node': {
970
+ 'feature': keras.ops.concatenate(outputs, axis=-1)
971
+ }
972
+ }
849
973
  )
850
- self._value_dense.build([None, node_feature_dim])
974
+ return tensor, outputs
975
+
976
+ def get_config(self) -> dict:
977
+ config = super().get_config()
978
+ config.update(
979
+ {
980
+ 'layers': [
981
+ keras.layers.serialize(layer) for layer in self.layers
982
+ ]
983
+ }
984
+ )
985
+ return config
986
+
987
+ @classmethod
988
+ def from_config(cls, config: dict) -> 'GraphNetwork':
989
+ config['layers'] = [
990
+ keras.layers.deserialize(layer) for layer in config['layers']
991
+ ]
992
+ return super().from_config(config)
993
+
994
+
995
+ @keras.saving.register_keras_serializable(package='molcraft')
996
+ class NodeEmbedding(GraphLayer):
997
+
998
+ """Node embedding layer.
999
+
1000
+ Embeds nodes based on their initial features.
1001
+ """
1002
+
1003
+ def __init__(
1004
+ self,
1005
+ dim: int = None,
1006
+ embed_context: bool = True,
1007
+ allow_masking: bool = True,
1008
+ **kwargs
1009
+ ) -> None:
1010
+ super().__init__(**kwargs)
1011
+ self.dim = dim
1012
+ self._embed_context = embed_context
1013
+ self._masking_rate = None
1014
+ self._allow_masking = allow_masking
1015
+
1016
+ def build_from_spec(self, spec: tensors.GraphTensor.Spec) -> None:
1017
+ """Builds the layer.
1018
+ """
1019
+ feature_dim = spec.node['feature'].shape[-1]
1020
+ if not self.dim:
1021
+ self.dim = feature_dim
1022
+ self._node_dense = self.get_dense(self.dim)
1023
+ self._node_dense.build([None, feature_dim])
851
1024
 
852
- self._output_dense = self.get_dense(self.units)
853
- self._output_dense.build([None, self.units])
1025
+ self._has_super = 'super' in spec.node
1026
+ has_context_feature = 'feature' in spec.context
1027
+ if not has_context_feature:
1028
+ self._embed_context = False
1029
+ if self._has_super and not self._embed_context:
1030
+ self._super_feature = self.get_weight(shape=[self.dim], name='super_node_feature')
1031
+ if self._allow_masking:
1032
+ self._mask_feature = self.get_weight(shape=[self.dim], name='mask_node_feature')
854
1033
 
855
- self._softmax_dropout = keras.layers.Dropout(self._attention_dropout)
1034
+ if self._embed_context:
1035
+ context_feature_dim = spec.context['feature'].shape[-1]
1036
+ self._context_dense = self.get_dense(self.dim)
1037
+ self._context_dense.build([None, context_feature_dim])
856
1038
 
857
- self._self_attention_norm = keras.layers.LayerNormalization()
858
- if self._normalize_first:
859
- self._self_attention_norm.build([None, node_feature_dim])
860
- else:
861
- self._self_attention_norm.build([None, self.units])
1039
+ def propagate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
1040
+ """Calls the layer.
1041
+ """
1042
+ feature = self._node_dense(tensor.node['feature'])
862
1043
 
863
- self._self_attention_dropout = keras.layers.Dropout(self._dropout)
1044
+ if self._has_super:
1045
+ super_feature = (0 if self._embed_context else self._super_feature)
1046
+ super_mask = keras.ops.expand_dims(tensor.node['super'], 1)
1047
+ feature = keras.ops.where(super_mask, super_feature, feature)
864
1048
 
865
- has_overriden_edge_bias = (
866
- self.__class__.add_edge_bias != GTConv.add_edge_bias
867
- )
868
- if not has_overriden_edge_bias:
869
- self._has_edge_length = 'length' in spec.edge
870
- if self._has_edge_length and 'bias' not in spec.edge:
871
- edge_length_dim = spec.edge['length'].shape[-1]
872
- self._spatial_encoding_dense = self.get_einsum_dense(
873
- 'ij,jkh->ikh', (1, self.heads), kernel_initializer='zeros'
874
- )
875
- self._spatial_encoding_dense.build([None, edge_length_dim])
1049
+ if self._embed_context:
1050
+ context_feature = self._context_dense(tensor.context['feature'])
1051
+ feature = ops.scatter_update(feature, tensor.node['super'], context_feature)
1052
+ tensor = tensor.update({'context': {'feature': None}})
876
1053
 
877
- self._has_edge_feature = 'feature' in spec.edge
878
- if self._has_edge_feature and 'bias' not in spec.edge:
879
- edge_feature_dim = spec.edge['feature'].shape[-1]
880
- self._edge_feature_dense = self.get_einsum_dense(
881
- 'ij,jkh->ikh', (1, self.heads),
1054
+ if (
1055
+ self._allow_masking and
1056
+ self._masking_rate is not None and
1057
+ self._masking_rate > 0
1058
+ ):
1059
+ random = keras.random.uniform(shape=[tensor.num_nodes])
1060
+ mask = random <= self._masking_rate
1061
+ if self._has_super:
1062
+ mask = keras.ops.logical_and(
1063
+ mask, keras.ops.logical_not(tensor.node['super'])
882
1064
  )
883
- self._edge_feature_dense.build([None, edge_feature_dim])
884
-
885
- has_overridden_update = self.__class__.update != GTConv.update
886
- if not has_overridden_update:
887
-
888
- self._feedforward_norm = keras.layers.LayerNormalization()
889
- self._feedforward_norm.build([None, self.units])
890
-
891
- self._feedforward_dropout = keras.layers.Dropout(self._dropout)
892
-
893
- self._feedforward_intermediate_dense = self.get_dense(self.units)
894
- self._feedforward_intermediate_dense.build([None, self.units])
1065
+ mask = keras.ops.expand_dims(mask, -1)
1066
+ feature = keras.ops.where(mask, self._mask_feature, feature)
1067
+ elif self._allow_masking:
1068
+ # Silence 'no gradients for variables' warning
1069
+ feature = feature + (self._mask_feature * 0.0)
895
1070
 
896
- self._feedforward_output_dense = self.get_dense(self.units)
897
- self._feedforward_output_dense.build([None, self.units])
1071
+ return tensor.update({'node': {'feature': feature}})
898
1072
 
899
- def add_node_bias(self, tensor: tensors.GraphTensor) -> tf.Tensor:
900
- return tensor
1073
+ @property
1074
+ def masking_rate(self):
1075
+ return self._masking_rate
901
1076
 
902
- def add_edge_bias(self, tensor: tensors.GraphTensor) -> tf.Tensor:
903
- if 'bias' in tensor.edge:
904
- return tensor
905
- elif not self._has_edge_feature and not self._has_edge_length:
906
- return tensor
907
-
908
- if self._has_edge_feature and not self._has_edge_length:
909
- edge_bias = self._edge_feature_dense(tensor.edge['feature'])
910
- elif not self._has_edge_feature and self._has_edge_length:
911
- edge_bias = self._spatial_encoding_dense(tensor.edge['length'])
912
- else:
913
- edge_bias = (
914
- self._edge_feature_dense(tensor.edge['feature']) +
915
- self._spatial_encoding_dense(tensor.edge['length'])
1077
+ @masking_rate.setter
1078
+ def masking_rate(self, rate: float):
1079
+ if not self._allow_masking and rate is not None:
1080
+ raise ValueError(
1081
+ f'Cannot set `masking_rate` for layer {self} '
1082
+ 'as `allow_masking` was set to `False`.'
916
1083
  )
917
-
918
- return tensor.update(
919
- {
920
- 'edge': {
921
- 'bias': edge_bias
922
- }
923
- }
924
- )
925
-
926
- def message(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
927
- """Compute messages.
928
- """
929
- tensor = self.add_edge_bias(tensor)
930
- tensor = self.add_node_bias(tensor)
931
-
932
- node_feature = tensor.node['feature']
1084
+ self._masking_rate = float(rate)
933
1085
 
934
- if 'bias' in tensor.node:
935
- node_feature += tensor.node['bias']
936
-
937
- if self._normalize_first:
938
- node_feature = self._self_attention_norm(node_feature)
939
-
940
- query = self._query_dense(node_feature)
941
- key = self._key_dense(node_feature)
942
- value = self._value_dense(node_feature)
1086
+ def get_config(self) -> dict:
1087
+ config = super().get_config()
1088
+ config.update({
1089
+ 'dim': self.dim,
1090
+ 'allow_masking': self._allow_masking
1091
+ })
1092
+ return config
1093
+
943
1094
 
944
- query = ops.gather(query, tensor.edge['source'])
945
- key = ops.gather(key, tensor.edge['target'])
946
- value = ops.gather(value, tensor.edge['source'])
1095
+ @keras.saving.register_keras_serializable(package='molcraft')
1096
+ class EdgeEmbedding(GraphLayer):
947
1097
 
948
- attention_score = keras.ops.sum(query * key, axis=1, keepdims=True)
949
- attention_score /= keras.ops.sqrt(float(self.units))
1098
+ """Edge embedding layer.
950
1099
 
951
- if 'bias' in tensor.edge:
952
- attention_score += tensor.edge['bias']
953
-
954
- attention = ops.edge_softmax(attention_score, tensor.edge['target'])
955
- attention = self._softmax_dropout(attention)
1100
+ Embeds edges based on their initial features.
1101
+ """
956
1102
 
957
- return tensor.update(
958
- {
959
- 'edge': {
960
- 'message': value,
961
- 'weight': attention,
962
- },
963
- }
964
- )
1103
+ def __init__(
1104
+ self,
1105
+ dim: int = None,
1106
+ allow_masking: bool = True,
1107
+ **kwargs
1108
+ ) -> None:
1109
+ super().__init__(**kwargs)
1110
+ self.dim = dim
1111
+ self._masking_rate = None
1112
+ self._allow_masking = allow_masking
965
1113
 
966
- def aggregate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
967
- """Aggregates messages.
1114
+ def build_from_spec(self, spec: tensors.GraphTensor.Spec) -> None:
1115
+ """Builds the layer.
968
1116
  """
969
- node_feature = tensor.aggregate('message')
970
-
971
- node_feature = keras.ops.reshape(node_feature, (-1, self.units))
972
- node_feature = self._output_dense(node_feature)
973
- node_feature = self._self_attention_dropout(node_feature)
974
-
975
- residual = tensor.node['feature']
976
- if self._project_residual:
977
- residual = self._residual_dense(residual)
978
- node_feature += residual
979
-
980
- if not self._normalize_first:
981
- node_feature = self._self_attention_norm(node_feature)
1117
+ feature_dim = spec.edge['feature'].shape[-1]
1118
+ if not self.dim:
1119
+ self.dim = feature_dim
1120
+ self._edge_dense = self.get_dense(self.dim)
1121
+ self._edge_dense.build([None, feature_dim])
982
1122
 
983
- return tensor.update(
984
- {
985
- 'node': {
986
- 'feature': node_feature,
987
- },
988
- 'edge': {
989
- 'message': None,
990
- 'weight': None,
991
- }
992
- }
993
- )
994
-
1123
+ self._has_super = 'super' in spec.edge
1124
+ if self._has_super:
1125
+ self._super_feature = self.get_weight(shape=[self.dim], name='super_edge_feature')
1126
+ if self._allow_masking:
1127
+ self._mask_feature = self.get_weight(shape=[self.dim], name='mask_edge_feature')
995
1128
 
996
- def update(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
997
- """Updates nodes.
1129
+ def propagate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
1130
+ """Calls the layer.
998
1131
  """
999
- node_feature = tensor.node['feature']
1000
-
1001
- if self._normalize_first:
1002
- node_feature = self._feedforward_norm(node_feature)
1132
+ feature = self._edge_dense(tensor.edge['feature'])
1003
1133
 
1004
- node_feature = self._feedforward_intermediate_dense(node_feature)
1005
- node_feature = self._activation(node_feature)
1006
- node_feature = self._feedforward_output_dense(node_feature)
1134
+ if self._has_super:
1135
+ super_feature = self._super_feature
1136
+ super_mask = keras.ops.expand_dims(tensor.edge['super'], 1)
1137
+ feature = keras.ops.where(super_mask, super_feature, feature)
1007
1138
 
1008
- node_feature = self._feedforward_dropout(node_feature)
1009
- node_feature += tensor.node['feature']
1139
+ if (
1140
+ self._allow_masking and
1141
+ self._masking_rate is not None and
1142
+ self._masking_rate > 0
1143
+ ):
1144
+ random = keras.random.uniform(shape=[tensor.num_edges])
1145
+ mask = random <= self._masking_rate
1146
+ if self._has_super:
1147
+ mask = keras.ops.logical_and(
1148
+ mask, keras.ops.logical_not(tensor.edge['super'])
1149
+ )
1150
+ mask = keras.ops.expand_dims(mask, -1)
1151
+ feature = keras.ops.where(mask, self._mask_feature, feature)
1152
+ elif self._allow_masking:
1153
+ # Slience warning of 'no gradients for variables'
1154
+ feature = feature + (self._mask_feature * 0.0)
1010
1155
 
1011
- if not self._normalize_first:
1012
- node_feature = self._feedforward_norm(node_feature)
1156
+ return tensor.update({'edge': {'feature': feature}})
1013
1157
 
1014
- return tensor.update(
1015
- {
1016
- 'node': {
1017
- 'feature': node_feature,
1018
- },
1019
- }
1020
- )
1158
+ @property
1159
+ def masking_rate(self):
1160
+ return self._masking_rate
1021
1161
 
1162
+ @masking_rate.setter
1163
+ def masking_rate(self, rate: float):
1164
+ if not self._allow_masking and rate is not None:
1165
+ raise ValueError(
1166
+ f'Cannot set `masking_rate` for layer {self} '
1167
+ 'as `allow_masking` was set to `False`.'
1168
+ )
1169
+ self._masking_rate = float(rate)
1170
+
1022
1171
  def get_config(self) -> dict:
1023
1172
  config = super().get_config()
1024
1173
  config.update({
1025
- "heads": self._heads,
1026
- 'activation': keras.activations.serialize(self._activation),
1027
- 'dropout': self._dropout,
1028
- 'attention_dropout': self._attention_dropout,
1029
- 'normalize': self._normalize,
1030
- 'normalize_first': self._normalize_first,
1174
+ 'dim': self.dim,
1175
+ 'allow_masking': self._allow_masking
1031
1176
  })
1032
1177
  return config
1033
1178
 
1034
1179
 
1180
+ @keras.saving.register_keras_serializable(package='molcraft')
1181
+ class ContextProjection(Projection):
1182
+ """Context projection layer.
1183
+ """
1184
+ def __init__(self, units: int = None, activation: str = None, **kwargs):
1185
+ super().__init__(units=units, activation=activation, field='context', **kwargs)
1186
+
1187
+
1188
+ @keras.saving.register_keras_serializable(package='molcraft')
1189
+ class NodeProjection(Projection):
1190
+ """Node projection layer.
1191
+ """
1192
+ def __init__(self, units: int = None, activation: str = None, **kwargs):
1193
+ super().__init__(units=units, activation=activation, field='node', **kwargs)
1194
+
1195
+
1196
+ @keras.saving.register_keras_serializable(package='molcraft')
1197
+ class EdgeProjection(Projection):
1198
+ """Edge projection layer.
1199
+ """
1200
+ def __init__(self, units: int = None, activation: str = None, **kwargs):
1201
+ super().__init__(units=units, activation=activation, field='edge', **kwargs)
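
The `NodeEmbedding` and `EdgeEmbedding` layers moved above expose a `masking_rate` property that swaps a random fraction of features for a learned mask vector. A hedged usage sketch (the rate is illustrative, e.g. for masked-feature pretraining):

from molcraft import layers

# Hypothetical pretraining setup: randomly replace ~15% of (non-super) node
# and edge features with the learned mask vectors, as done in `propagate` above.
node_embedding = layers.NodeEmbedding(dim=128, allow_masking=True)
node_embedding.masking_rate = 0.15

edge_embedding = layers.EdgeEmbedding(dim=128, allow_masking=True)
edge_embedding.masking_rate = 0.15
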
1202
+
1203
+
1035
1204
  @keras.saving.register_keras_serializable(package='molcraft')
1036
1205
  class Readout(keras.layers.Layer):
1037
1206
 
@@ -1097,6 +1266,37 @@ class Readout(keras.layers.Layer):
1097
1266
  return config
1098
1267
 
1099
1268
 
1269
+ @keras.saving.register_keras_serializable(package='molcraft')
1270
+ class AddEdgeBias(GraphLayer):
1271
+
1272
+ def build_from_spec(self, spec: tensors.GraphTensor.Spec) -> None:
1273
+ self._has_edge_length = 'length' in spec.edge
1274
+ self._has_edge_feature = 'feature' in spec.edge
1275
+ if self._has_edge_feature:
1276
+ self._edge_feature_dense = self.get_dense(units=1)
1277
+ if self._has_edge_length:
1278
+ self._edge_length_dense = self.get_dense(
1279
+ units=1, kernel_initializer='zeros'
1280
+ )
1281
+
1282
+ def propagate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
1283
+ bias = keras.ops.zeros(
1284
+ shape=(tensor.num_edges, 1),
1285
+ dtype=tensor.node['feature'].dtype
1286
+ )
1287
+ if self._has_edge_feature:
1288
+ bias += self._edge_feature_dense(tensor.edge['feature'])
1289
+ if self._has_edge_length:
1290
+ bias += self._edge_length_dense(tensor.edge['length'])
1291
+ return tensor.update(
1292
+ {
1293
+ 'edge': {
1294
+ 'bias': bias
1295
+ }
1296
+ }
1297
+ )
1298
+
1299
+
1100
1300
  def Input(spec: tensors.GraphTensor.Spec) -> dict:
1101
1301
  """Used to specify inputs to model.
1102
1302