molcraft 0.1.0a1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

molcraft/layers.py ADDED
@@ -0,0 +1,1224 @@
+ import abc
+ import keras
+ import tensorflow as tf
+ import warnings
+ from keras.src.models import functional
+
+ from molcraft import tensors
+ from molcraft import ops
+
+
+ @keras.saving.register_keras_serializable(package='molcraft')
+ class GraphLayer(keras.layers.Layer):
+     """Base graph layer.
+
+     Currently, the `GraphLayer` only supports `GraphTensor` input.
+
+     The list of arguments is only relevant if the derived layer
+     invokes `get_dense_kwargs`, `get_dense` or `get_einsum_dense`.
+
+     """
+
+     def __init__(
+         self,
+         use_bias: bool = True,
+         kernel_initializer: keras.initializers.Initializer | str = "glorot_uniform",
+         bias_initializer: keras.initializers.Initializer | str = "zeros",
+         kernel_regularizer: keras.regularizers.Regularizer | None = None,
+         bias_regularizer: keras.regularizers.Regularizer | None = None,
+         activity_regularizer: keras.regularizers.Regularizer | None = None,
+         kernel_constraint: keras.constraints.Constraint | None = None,
+         bias_constraint: keras.constraints.Constraint | None = None,
+         **kwargs,
+     ) -> None:
+         super().__init__(activity_regularizer=activity_regularizer, **kwargs)
+         self._use_bias = use_bias
+         self._kernel_initializer = keras.initializers.get(kernel_initializer)
+         self._bias_initializer = keras.initializers.get(bias_initializer)
+         self._kernel_regularizer = keras.regularizers.get(kernel_regularizer)
+         self._bias_regularizer = keras.regularizers.get(bias_regularizer)
+         self._kernel_constraint = keras.constraints.get(kernel_constraint)
+         self._bias_constraint = keras.constraints.get(bias_constraint)
+         self.built = False
+         # TODO: Add warning if build is implemented in subclass
+         # TODO: Add warning if call is implemented in subclass
+
+     def propagate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
+         """Calls the layer.
+
+         Needs to be implemented by subclass.
+
+         Args:
+             tensor:
+                 A `GraphTensor` instance.
+         """
+         raise NotImplementedError('`propagate` needs to be implemented.')
+
+     def build_from_spec(self, spec: tensors.GraphTensor.Spec) -> None:
+         """Builds the layer.
+
+         May use built-in methods such as `get_weight`, `get_dense` and `get_einsum_dense`.
+
+         Optionally implemented by subclass. If implemented, it is recommended to
+         build the sub-layers via `build([None, input_dim])`. If sub-layers are not
+         built, symbolic input will be passed through the layer to build it.
+
+         Args:
+             spec:
+                 A `GraphTensor.Spec` instance, corresponding to the input `GraphTensor`
+                 of the `propagate` method.
+         """
+
+     def build(self, spec: tensors.GraphTensor.Spec) -> None:
+
+         self._custom_build_config = {'spec': _serialize_spec(spec)}
+
+         invoke_build_from_spec = (
+             GraphLayer.build_from_spec != self.__class__.build_from_spec
+         )
+         if invoke_build_from_spec:
+             self.build_from_spec(spec)
+             self.built = True
+
+         if not self.built:
+             # Automatically build layer or model by calling it on symbolic inputs
+             self.built = True
+             symbolic_inputs = Input(spec)
+             self(symbolic_inputs)
+
+     def get_build_config(self) -> dict:
+         if not hasattr(self, '_custom_build_config'):
+             return super().get_build_config()
+         return self._custom_build_config
+
+     def build_from_config(self, config: dict) -> None:
+         use_custom_build_from_config = ('spec' in config)
+         if not use_custom_build_from_config:
+             super().build_from_config(config)
+         else:
+             spec = _deserialize_spec(config['spec'])
+             self.build(spec)
+
+     def call(
+         self,
+         graph: dict[str, dict[str, tf.Tensor]]
+     ) -> dict[str, dict[str, tf.Tensor]]:
+         graph_tensor = tensors.from_dict(graph)
+         outputs = self.propagate(graph_tensor)
+         if isinstance(outputs, tensors.GraphTensor):
+             return tensors.to_dict(outputs)
+         return outputs
+
+     def __call__(self, inputs, **kwargs):
+         if not self.built:
+             spec = _spec_from_inputs(inputs)
+             self.build(spec)
+         convert = isinstance(inputs, tensors.GraphTensor)
+         if convert:
+             inputs = tensors.to_dict(inputs)
+         if isinstance(self, functional.Functional):
+             inputs, left_out_inputs = _match_functional_input(self.input, inputs)
+         outputs = super().__call__(inputs, **kwargs)
+         if not tensors.is_graph(outputs):
+             return outputs
+         if isinstance(self, functional.Functional):
+             outputs = _add_left_out_inputs(outputs, left_out_inputs)
+         if convert:
+             outputs = tensors.from_dict(outputs)
+         return outputs
+
+     def get_weight(
+         self,
+         shape: tf.TensorShape,
+         **kwargs,
+     ) -> tf.Variable:
+         common_kwargs = self.get_dense_kwargs()
+         weight_kwargs = {
+             'initializer': common_kwargs['kernel_initializer'],
+             'regularizer': common_kwargs['kernel_regularizer'],
+             'constraint': common_kwargs['kernel_constraint']
+         }
+         weight_kwargs.update(kwargs)
+         return self.add_weight(shape=shape, **weight_kwargs)
+
+     def get_dense(
+         self,
+         units: int,
+         **kwargs
+     ) -> keras.layers.Dense:
+         common_kwargs = self.get_dense_kwargs()
+         common_kwargs.update(kwargs)
+         return keras.layers.Dense(units, **common_kwargs)
+
+     def get_einsum_dense(
+         self,
+         equation: str,
+         output_shape: tf.TensorShape,
+         **kwargs
+     ) -> keras.layers.EinsumDense:
+         common_kwargs = self.get_dense_kwargs()
+         common_kwargs.update(kwargs)
+         use_bias = common_kwargs.pop('use_bias', False)
+         if use_bias and 'bias_axes' not in common_kwargs:
+             common_kwargs['bias_axes'] = equation.split('->')[-1][1:] or None
+         return keras.layers.EinsumDense(equation, output_shape, **common_kwargs)
+
+     def get_dense_kwargs(self) -> dict:
+         common_kwargs = dict(
+             use_bias=self._use_bias,
+             kernel_regularizer=self._kernel_regularizer,
+             bias_regularizer=self._bias_regularizer,
+             activity_regularizer=self.activity_regularizer,
+             kernel_constraint=self._kernel_constraint,
+             bias_constraint=self._bias_constraint,
+         )
+         kernel_initializer = self._kernel_initializer.__class__.from_config(
+             self._kernel_initializer.get_config()
+         )
+         bias_initializer = self._bias_initializer.__class__.from_config(
+             self._bias_initializer.get_config()
+         )
+         common_kwargs["kernel_initializer"] = kernel_initializer
+         common_kwargs["bias_initializer"] = bias_initializer
+         return common_kwargs
+
+     def get_config(self) -> dict:
+         config = super().get_config()
+         config.update({
+             "use_bias": self._use_bias,
+             "kernel_initializer":
+                 keras.initializers.serialize(self._kernel_initializer),
+             "bias_initializer":
+                 keras.initializers.serialize(self._bias_initializer),
+             "kernel_regularizer":
+                 keras.regularizers.serialize(self._kernel_regularizer),
+             "bias_regularizer":
+                 keras.regularizers.serialize(self._bias_regularizer),
+             "kernel_constraint":
+                 keras.constraints.serialize(self._kernel_constraint),
+             "bias_constraint":
+                 keras.constraints.serialize(self._bias_constraint),
+         })
+         return config
+
+
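# Illustrative sketch (not part of this release): a minimal `GraphLayer`
# subclass following the contract above and the `Projection` pattern defined
# later in this file. `build_from_spec` creates and builds sub-layers from the
# input spec; `propagate` maps a `GraphTensor` to a new `GraphTensor`.

@keras.saving.register_keras_serializable(package='my_package')
class NodeGate(GraphLayer):
    """Gates node features with a learned sigmoid gate."""

    def build_from_spec(self, spec: tensors.GraphTensor.Spec) -> None:
        feature_dim = spec.node['feature'].shape[-1]
        self._gate_dense = self.get_dense(feature_dim, activation='sigmoid')
        self._gate_dense.build([None, feature_dim])

    def propagate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
        feature = tensor.node['feature']
        feature = feature * self._gate_dense(feature)
        return tensor.update({'node': {'feature': feature}})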
+ @keras.saving.register_keras_serializable(package='molcraft')
+ class GraphConv(GraphLayer):
+
+     """Base graph neural network layer.
+     """
+
+     def __init__(self, units: int, **kwargs) -> None:
+         super().__init__(**kwargs)
+         self.units = units
+
+     @abc.abstractmethod
+     def message(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
+         """Compute messages.
+
+         This method needs to be implemented by subclass.
+
+         Args:
+             tensor:
+                 The input `GraphTensor` instance.
+         """
+
+     @abc.abstractmethod
+     def aggregate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
+         """Aggregates messages.
+
+         This method needs to be implemented by subclass.
+
+         Args:
+             tensor:
+                 A `GraphTensor` instance containing a message.
+         """
+
+     @abc.abstractmethod
+     def update(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
+         """Updates nodes.
+
+         This method needs to be implemented by subclass.
+
+         Args:
+             tensor:
+                 A `GraphTensor` instance containing aggregated messages
+                 (updated node features).
+         """
+
+     def propagate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
+         """Calls the layer.
+
+         The `GraphConv` layer invokes `message`, `aggregate` and `update`
+         in sequence.
+
+         Args:
+             tensor:
+                 A `GraphTensor` instance.
+         """
+         tensor = self.message(tensor)
+         tensor = self.aggregate(tensor)
+         tensor = self.update(tensor)
+         return tensor
+
+     def get_config(self) -> dict:
+         config = super().get_config()
+         config.update({
+             'units': self.units
+         })
+         return config
+
+
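# Illustrative sketch (not part of this release) of the `message`/`aggregate`/
# `update` contract, using only `GraphTensor` methods that the built-in
# convolutions below rely on (`gather`, `aggregate`, `update`):

@keras.saving.register_keras_serializable(package='my_package')
class SumConv(GraphConv):
    """Sums neighbor features and projects the result to `units`."""

    def build_from_spec(self, spec: tensors.GraphTensor.Spec) -> None:
        node_feature_dim = spec.node['feature'].shape[-1]
        self._dense = self.get_dense(self.units, activation='relu')
        self._dense.build([None, node_feature_dim])

    def message(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
        # Each edge carries its source node's feature as the message.
        message = tensor.gather('feature', 'source')
        return tensor.update({'edge': {'message': message}})

    def aggregate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
        # Sum incoming messages per target node, then drop the message field.
        node_feature = tensor.aggregate('message')
        return tensor.update(
            {'node': {'feature': node_feature}, 'edge': {'message': None}}
        )

    def update(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
        node_feature = self._dense(tensor.node['feature'])
        return tensor.update({'node': {'feature': node_feature}})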
+ @keras.saving.register_keras_serializable(package='molcraft')
+ class Projection(GraphLayer):
+     """Base graph projection layer.
+     """
+     def __init__(
+         self,
+         units: int = None,
+         activation: str = None,
+         field: str = 'node',
+         **kwargs
+     ) -> None:
+         super().__init__(**kwargs)
+         self.units = units
+         self._activation = keras.activations.get(activation)
+         self.field = field
+
+     def build_from_spec(self, spec: tensors.GraphTensor.Spec) -> None:
+         """Builds the layer.
+         """
+         data = getattr(spec, self.field, None)
+         if data is None:
+             raise ValueError(f'Could not access field {self.field!r}.')
+         feature_dim = data['feature'].shape[-1]
+         if not self.units:
+             self.units = feature_dim
+         self._dense = self.get_dense(self.units)
+         self._dense.build([None, feature_dim])
+
+     def propagate(self, tensor: tensors.GraphTensor):
+         """Calls the layer.
+         """
+         feature = getattr(tensor, self.field)['feature']
+         feature = self._dense(feature)
+         feature = self._activation(feature)
+         return tensor.update(
+             {
+                 self.field: {
+                     'feature': feature
+                 }
+             }
+         )
+
+     def get_config(self) -> dict:
+         config = super().get_config()
+         config.update({
+             'units': self.units,
+             'activation': keras.activations.serialize(self._activation),
+             'field': self.field,
+         })
+         return config
+
+
+ @keras.saving.register_keras_serializable(package='molcraft')
+ class GraphNetwork(GraphLayer):
+
+     """Graph neural network.
+
+     Sequentially calls graph layers (`GraphLayer`) and concatenates their outputs.
+
+     Args:
+         layers (list):
+             A list of graph layers.
+     """
+
+     def __init__(self, layers: list[GraphLayer], **kwargs) -> None:
+         super().__init__(**kwargs)
+         self.layers = layers
+         self._update_node_feature = False
+         self._update_edge_feature = False
+
+     def build_from_spec(self, spec: tensors.GraphTensor.Spec) -> None:
+         """Builds the layer.
+         """
+         units = self.layers[0].units
+         node_feature_dim = spec.node['feature'].shape[-1]
+         if node_feature_dim != units:
+             warn(
+                 'Node feature dim does not match `units` of the first layer. '
+                 'Automatically adding a node projection layer to match `units`.'
+             )
+             self._node_dense = self.get_dense(units)
+             self._update_node_feature = True
+         has_edge_feature = 'feature' in spec.edge
+         if has_edge_feature:
+             edge_feature_dim = spec.edge['feature'].shape[-1]
+             if edge_feature_dim != units:
+                 warn(
+                     'Edge feature dim does not match `units` of the first layer. '
+                     'Automatically adding an edge projection layer to match `units`.'
+                 )
+                 self._edge_dense = self.get_dense(units)
+                 self._update_edge_feature = True
+
+     def propagate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
+         """Calls the layer.
+         """
+         x = tensors.to_dict(tensor)
+         if self._update_node_feature:
+             x['node']['feature'] = self._node_dense(tensor.node['feature'])
+         if self._update_edge_feature:
+             x['edge']['feature'] = self._edge_dense(tensor.edge['feature'])
+         outputs = [x['node']['feature']]
+         for layer in self.layers:
+             x = layer(x)
+             outputs.append(x['node']['feature'])
+         return tensor.update(
+             {
+                 'node': {
+                     'feature': keras.ops.concatenate(outputs, axis=-1)
+                 }
+             }
+         )
+
+     def tape_propagate(
+         self,
+         tensor: tensors.GraphTensor,
+         tape: tf.GradientTape,
+         training: bool | None = None,
+     ) -> tuple[tensors.GraphTensor, list[tf.Tensor]]:
+         """Performs the propagation with a `GradientTape`.
+
+         Performs the same forward pass as `propagate` but with a `GradientTape`
+         watching intermediate node features.
+
+         Args:
+             tensor (tensors.GraphTensor):
+                 The graph input.
+             tape (tf.GradientTape):
+                 The gradient tape watching the intermediate node features.
+             training (bool, optional):
+                 Whether the layers are called in training mode.
+         """
+         if isinstance(tensor, tensors.GraphTensor):
+             x = tensors.to_dict(tensor)
+         else:
+             x = tensor
+         if self._update_node_feature:
+             x['node']['feature'] = self._node_dense(tensor.node['feature'])
+         if self._update_edge_feature:
+             x['edge']['feature'] = self._edge_dense(tensor.edge['feature'])
+         tape.watch(x['node']['feature'])
+         outputs = [x['node']['feature']]
+         for layer in self.layers:
+             x = layer(x, training=training)
+             tape.watch(x['node']['feature'])
+             outputs.append(x['node']['feature'])
+
+         tensor = tensor.update(
+             {
+                 'node': {
+                     'feature': keras.ops.concatenate(outputs, axis=-1)
+                 }
+             }
+         )
+         return tensor, outputs
+
+     def get_config(self) -> dict:
+         config = super().get_config()
+         config.update(
+             {
+                 'layers': [
+                     keras.layers.serialize(layer) for layer in self.layers
+                 ]
+             }
+         )
+         return config
+
+     @classmethod
+     def from_config(cls, config: dict) -> 'GraphNetwork':
+         config['layers'] = [
+             keras.layers.deserialize(layer) for layer in config['layers']
+         ]
+         return super().from_config(config)
+
+
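# Usage sketch (illustrative): since `propagate` concatenates the node features
# produced by every layer with the initial (possibly projected) features, two
# 128-unit layers yield a final node feature dim of 3 * 128.
#
#     gnn = GraphNetwork([GINConv(128), GINConv(128)])
#     output_graph = gnn(graph)  # `graph`: a `GraphTensor`, e.g. from a featurizer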
+ @keras.saving.register_keras_serializable(package='molcraft')
+ class NodeEmbedding(GraphLayer):
+
+     """Node embedding layer.
+
+     Embeds nodes based on their initial features.
+     """
+
+     def __init__(
+         self,
+         dim: int = None,
+         embed_context: bool = True,
+         allow_masking: bool = True,
+         **kwargs
+     ) -> None:
+         super().__init__(**kwargs)
+         self.dim = dim
+         self._embed_context = embed_context
+         self._masking_rate = None
+         self._allow_masking = allow_masking
+
+     def build_from_spec(self, spec: tensors.GraphTensor.Spec) -> None:
+         """Builds the layer.
+         """
+         feature_dim = spec.node['feature'].shape[-1]
+         if not self.dim:
+             self.dim = feature_dim
+         self._node_dense = self.get_dense(self.dim)
+         self._node_dense.build([None, feature_dim])
+
+         self._has_super = 'super' in spec.node
+         has_context_feature = 'feature' in spec.context
+         if not has_context_feature:
+             self._embed_context = False
+         if self._has_super and not self._embed_context:
+             self._super_feature = self.get_weight(shape=[self.dim], name='super_node_feature')
+         if self._allow_masking:
+             self._mask_feature = self.get_weight(shape=[self.dim], name='mask_node_feature')
+
+         if self._embed_context:
+             context_feature_dim = spec.context['feature'].shape[-1]
+             self._context_dense = self.get_dense(self.dim)
+             self._context_dense.build([None, context_feature_dim])
+
+     def propagate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
+         """Calls the layer.
+         """
+         feature = self._node_dense(tensor.node['feature'])
+
+         if self._has_super:
+             super_feature = (0 if self._embed_context else self._super_feature)
+             super_mask = keras.ops.expand_dims(tensor.node['super'], 1)
+             feature = keras.ops.where(super_mask, super_feature, feature)
+
+         if self._embed_context:
+             context_feature = self._context_dense(tensor.context['feature'])
+             feature = ops.scatter_update(feature, tensor.node['super'], context_feature)
+             tensor = tensor.update({'context': {'feature': None}})
+
+         if (
+             self._allow_masking and
+             self._masking_rate is not None and
+             self._masking_rate > 0
+         ):
+             random = keras.random.uniform(shape=[tensor.num_nodes])
+             mask = random <= self._masking_rate
+             if self._has_super:
+                 mask = keras.ops.logical_and(
+                     mask, keras.ops.logical_not(tensor.node['super'])
+                 )
+             mask = keras.ops.expand_dims(mask, -1)
+             feature = keras.ops.where(mask, self._mask_feature, feature)
+         elif self._allow_masking:
+             # Silence 'no gradients for variables' warning
+             feature = feature + (self._mask_feature * 0.0)
+
+         return tensor.update({'node': {'feature': feature}})
+
+     @property
+     def masking_rate(self):
+         return self._masking_rate
+
+     @masking_rate.setter
+     def masking_rate(self, rate: float):
+         if not self._allow_masking and rate is not None:
+             raise ValueError(
+                 f'Cannot set `masking_rate` for layer {self} '
+                 'as `allow_masking` was set to `False`.'
+             )
+         self._masking_rate = None if rate is None else float(rate)
+
+     def get_config(self) -> dict:
+         config = super().get_config()
+         config.update({
+             'dim': self.dim,
+             'embed_context': self._embed_context,
+             'allow_masking': self._allow_masking
+         })
+         return config
+
+
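# Usage sketch (illustrative): the masking branch above supports masked-node
# pretraining. The rate is set on the layer instance, e.g. between epochs:
#
#     embedding = NodeEmbedding(dim=128)
#     embedding.masking_rate = 0.15  # replace ~15% of (non-super) node embeddings
#                                    # with the learned `mask_node_feature` weight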
+ @keras.saving.register_keras_serializable(package='molcraft')
+ class EdgeEmbedding(GraphLayer):
+
+     """Edge embedding layer.
+
+     Embeds edges based on their initial features.
+     """
+
+     def __init__(
+         self,
+         dim: int = None,
+         allow_masking: bool = True,
+         **kwargs
+     ) -> None:
+         super().__init__(**kwargs)
+         self.dim = dim
+         self._masking_rate = None
+         self._allow_masking = allow_masking
+
+     def build_from_spec(self, spec: tensors.GraphTensor.Spec) -> None:
+         """Builds the layer.
+         """
+         feature_dim = spec.edge['feature'].shape[-1]
+         if not self.dim:
+             self.dim = feature_dim
+         self._edge_dense = self.get_dense(self.dim)
+         self._edge_dense.build([None, feature_dim])
+
+         self._has_super = 'super' in spec.edge
+         if self._has_super:
+             self._super_feature = self.get_weight(shape=[self.dim], name='super_edge_feature')
+         if self._allow_masking:
+             self._mask_feature = self.get_weight(shape=[self.dim], name='mask_edge_feature')
+
+     def propagate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
+         """Calls the layer.
+         """
+         feature = self._edge_dense(tensor.edge['feature'])
+
+         if self._has_super:
+             super_feature = self._super_feature
+             super_mask = keras.ops.expand_dims(tensor.edge['super'], 1)
+             feature = keras.ops.where(super_mask, super_feature, feature)
+
+         if (
+             self._allow_masking and
+             self._masking_rate is not None and
+             self._masking_rate > 0
+         ):
+             random = keras.random.uniform(shape=[tensor.num_edges])
+             mask = random <= self._masking_rate
+             if self._has_super:
+                 mask = keras.ops.logical_and(
+                     mask, keras.ops.logical_not(tensor.edge['super'])
+                 )
+             mask = keras.ops.expand_dims(mask, -1)
+             feature = keras.ops.where(mask, self._mask_feature, feature)
+         elif self._allow_masking:
+             # Silence 'no gradients for variables' warning
+             feature = feature + (self._mask_feature * 0.0)
+
+         return tensor.update({'edge': {'feature': feature}})
+
+     @property
+     def masking_rate(self):
+         return self._masking_rate
+
+     @masking_rate.setter
+     def masking_rate(self, rate: float):
+         if not self._allow_masking and rate is not None:
+             raise ValueError(
+                 f'Cannot set `masking_rate` for layer {self} '
+                 'as `allow_masking` was set to `False`.'
+             )
+         self._masking_rate = None if rate is None else float(rate)
+
+     def get_config(self) -> dict:
+         config = super().get_config()
+         config.update({
+             'dim': self.dim,
+             'allow_masking': self._allow_masking
+         })
+         return config
+
+
+ @keras.saving.register_keras_serializable(package='molcraft')
+ class ContextProjection(Projection):
+     """Context projection layer.
+     """
+     def __init__(self, units: int = None, activation: str = None, **kwargs):
+         super().__init__(units=units, activation=activation, field='context', **kwargs)
+
+
+ @keras.saving.register_keras_serializable(package='molcraft')
+ class NodeProjection(Projection):
+     """Node projection layer.
+     """
+     def __init__(self, units: int = None, activation: str = None, **kwargs):
+         super().__init__(units=units, activation=activation, field='node', **kwargs)
+
+
+ @keras.saving.register_keras_serializable(package='molcraft')
+ class EdgeProjection(Projection):
+     """Edge projection layer.
+     """
+     def __init__(self, units: int = None, activation: str = None, **kwargs):
+         super().__init__(units=units, activation=activation, field='edge', **kwargs)
+
+
+ @keras.saving.register_keras_serializable(package='molcraft')
+ class GINConv(GraphConv):
+
+     """Graph isomorphism network layer.
+     """
+
+     def __init__(
+         self,
+         units: int,
+         activation: keras.layers.Activation | str | None = 'relu',
+         dropout: float = 0.0,
+         normalize: bool = True,
+         update_edge_feature: bool = True,
+         **kwargs,
+     ):
+         super().__init__(units=units, **kwargs)
+         self._activation = keras.activations.get(activation)
+         self._normalize = normalize
+         self._dropout = dropout
+         self._update_edge_feature = update_edge_feature
+
+     def build_from_spec(self, spec: tensors.GraphTensor.Spec) -> None:
+         """Builds the layer.
+         """
+         node_feature_dim = spec.node['feature'].shape[-1]
+
+         self.epsilon = self.add_weight(
+             name='epsilon',
+             shape=(),
+             initializer='zeros',
+             trainable=True,
+         )
+
+         if 'feature' in spec.edge:
+             edge_feature_dim = spec.edge['feature'].shape[-1]
+
+             if not self._update_edge_feature:
+                 if edge_feature_dim != node_feature_dim:
+                     warn(
+                         'Found edge feature dim to be incompatible with node feature dim. '
+                         'Automatically adding an edge feature projection layer to match '
+                         'the dim of node features.'
+                     )
+                     self._update_edge_feature = True
+
+             if self._update_edge_feature:
+                 self._edge_dense = self.get_dense(node_feature_dim)
+                 self._edge_dense.build([None, edge_feature_dim])
+         else:
+             self._update_edge_feature = False
+
+         has_overridden_update = self.__class__.update != GINConv.update
+         if not has_overridden_update:
+             # Use default feedforward network
+             self._feedforward_intermediate_dense = self.get_dense(self.units)
+             self._feedforward_intermediate_dense.build([None, node_feature_dim])
+
+             if self._normalize:
+                 self._feedforward_intermediate_norm = keras.layers.BatchNormalization()
+                 self._feedforward_intermediate_norm.build([None, self.units])
+
+             self._feedforward_dropout = keras.layers.Dropout(self._dropout)
+             self._feedforward_activation = self._activation
+
+             self._feedforward_output_dense = self.get_dense(self.units)
+             self._feedforward_output_dense.build([None, self.units])
+
+     def message(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
+         """Compute messages.
+         """
+         message = tensor.gather('feature', 'source')
+         edge_feature = tensor.edge.get('feature')
+         if self._update_edge_feature:
+             edge_feature = self._edge_dense(edge_feature)
+         if edge_feature is not None:
+             message += edge_feature
+         return tensor.update(
+             {
+                 'edge': {
+                     'message': message,
+                     'feature': edge_feature
+                 }
+             }
+         )
+
+     def aggregate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
+         """Aggregates messages.
+         """
+         node_feature = tensor.aggregate('message')
+         node_feature += (1 + self.epsilon) * tensor.node['feature']
+         return tensor.update(
+             {
+                 'node': {
+                     'feature': node_feature,
+                 },
+                 'edge': {
+                     'message': None,
+                 }
+             }
+         )
+
+     def update(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
+         """Updates nodes.
+         """
+         node_feature = tensor.node['feature']
+         node_feature = self._feedforward_intermediate_dense(node_feature)
+         node_feature = self._feedforward_activation(node_feature)
+         if self._normalize:
+             node_feature = self._feedforward_intermediate_norm(node_feature)
+         node_feature = self._feedforward_dropout(node_feature)
+         node_feature = self._feedforward_output_dense(node_feature)
+         return tensor.update(
+             {
+                 'node': {
+                     'feature': node_feature,
+                 }
+             }
+         )
+
+     def get_config(self) -> dict:
+         config = super().get_config()
+         config.update({
+             'activation': keras.activations.serialize(self._activation),
+             'dropout': self._dropout,
+             'normalize': self._normalize,
+             'update_edge_feature': self._update_edge_feature,
+         })
+         return config
+
+
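# Summary of the GIN update implemented by `message`/`aggregate`/`update` above,
# with h_v a node feature, e_uv the (optionally projected) edge feature and
# `epsilon` the learned scalar:
#
#     m_uv = h_u + e_uv
#     h_v' = MLP((1 + epsilon) * h_v + sum_{u in N(v)} m_uv)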
+ @keras.saving.register_keras_serializable(package='molcraft')
+ class GTConv(GraphConv):
+
+     """Graph transformer layer.
+     """
+
+     def __init__(
+         self,
+         units: int,
+         heads: int = 8,
+         activation: keras.layers.Activation | str | None = "relu",
+         dropout: float = 0.0,
+         attention_dropout: float = 0.0,
+         normalize: bool = True,
+         normalize_first: bool = True,
+         **kwargs,
+     ) -> None:
+         super().__init__(units=units, **kwargs)
+         self._heads = heads
+         if self.units % self.heads != 0:
+             raise ValueError('`units` needs to be divisible by `heads`.')
+         self._head_units = self.units // self.heads
+         self._activation = keras.activations.get(activation)
+         self._dropout = dropout
+         self._attention_dropout = attention_dropout
+         self._normalize = normalize
+         self._normalize_first = normalize_first
+
+     @property
+     def heads(self):
+         return self._heads
+
+     @property
+     def head_units(self):
+         return self._head_units
+
+     def build_from_spec(self, spec):
+         """Builds the layer.
+         """
+         node_feature_dim = spec.node['feature'].shape[-1]
+         incompatible_dim = node_feature_dim != self.units
+         if incompatible_dim:
+             warnings.warn(
+                 message=(
+                     '`GTConv` uses residual connections, but input node feature dim '
+                     'is incompatible with intermediate dim (`units`). '
+                     'Automatically projecting first residual to match its dim with intermediate dim.'
+                 ),
+                 category=UserWarning,
+                 stacklevel=1
+             )
+             self._residual_dense = self.get_dense(self.units)
+             self._residual_dense.build([None, node_feature_dim])
+             self._project_residual = True
+         else:
+             self._project_residual = False
+
+         self._query_dense = self.get_einsum_dense(
+             'ij,jkh->ikh', (self.head_units, self.heads)
+         )
+         self._query_dense.build([None, node_feature_dim])
+
+         self._key_dense = self.get_einsum_dense(
+             'ij,jkh->ikh', (self.head_units, self.heads)
+         )
+         self._key_dense.build([None, node_feature_dim])
+
+         self._value_dense = self.get_einsum_dense(
+             'ij,jkh->ikh', (self.head_units, self.heads)
+         )
+         self._value_dense.build([None, node_feature_dim])
+
+         self._output_dense = self.get_dense(self.units)
+         self._output_dense.build([None, self.units])
+
+         self._softmax_dropout = keras.layers.Dropout(self._attention_dropout)
+
+         self._self_attention_norm = keras.layers.LayerNormalization()
+         if self._normalize_first:
+             self._self_attention_norm.build([None, node_feature_dim])
+         else:
+             self._self_attention_norm.build([None, self.units])
+
+         self._self_attention_dropout = keras.layers.Dropout(self._dropout)
+
+         has_overridden_edge_bias = (
+             self.__class__.add_edge_bias != GTConv.add_edge_bias
+         )
+         if not has_overridden_edge_bias:
+             self._has_edge_length = 'length' in spec.edge
+             if self._has_edge_length and 'bias' not in spec.edge:
+                 edge_length_dim = spec.edge['length'].shape[-1]
+                 self._spatial_encoding_dense = self.get_einsum_dense(
+                     'ij,jkh->ikh', (1, self.heads), kernel_initializer='zeros'
+                 )
+                 self._spatial_encoding_dense.build([None, edge_length_dim])
+
+             self._has_edge_feature = 'feature' in spec.edge
+             if self._has_edge_feature and 'bias' not in spec.edge:
+                 edge_feature_dim = spec.edge['feature'].shape[-1]
+                 self._edge_feature_dense = self.get_einsum_dense(
+                     'ij,jkh->ikh', (1, self.heads),
+                 )
+                 self._edge_feature_dense.build([None, edge_feature_dim])
+
+         has_overridden_update = self.__class__.update != GTConv.update
+         if not has_overridden_update:
+
+             self._feedforward_norm = keras.layers.LayerNormalization()
+             self._feedforward_norm.build([None, self.units])
+
+             self._feedforward_dropout = keras.layers.Dropout(self._dropout)
+
+             self._feedforward_intermediate_dense = self.get_dense(self.units)
+             self._feedforward_intermediate_dense.build([None, self.units])
+
+             self._feedforward_output_dense = self.get_dense(self.units)
+             self._feedforward_output_dense.build([None, self.units])
+
+     def add_node_bias(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
+         return tensor
+
+     def add_edge_bias(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
+         if 'bias' in tensor.edge:
+             return tensor
+         elif not self._has_edge_feature and not self._has_edge_length:
+             return tensor
+
+         if self._has_edge_feature and not self._has_edge_length:
+             edge_bias = self._edge_feature_dense(tensor.edge['feature'])
+         elif not self._has_edge_feature and self._has_edge_length:
+             edge_bias = self._spatial_encoding_dense(tensor.edge['length'])
+         else:
+             edge_bias = (
+                 self._edge_feature_dense(tensor.edge['feature']) +
+                 self._spatial_encoding_dense(tensor.edge['length'])
+             )
+
+         return tensor.update(
+             {
+                 'edge': {
+                     'bias': edge_bias
+                 }
+             }
+         )
+
+     def message(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
+         """Compute messages.
+         """
+         tensor = self.add_edge_bias(tensor)
+         tensor = self.add_node_bias(tensor)
+
+         node_feature = tensor.node['feature']
+
+         if 'bias' in tensor.node:
+             node_feature += tensor.node['bias']
+
+         if self._normalize_first:
+             node_feature = self._self_attention_norm(node_feature)
+
+         query = self._query_dense(node_feature)
+         key = self._key_dense(node_feature)
+         value = self._value_dense(node_feature)
+
+         query = ops.gather(query, tensor.edge['source'])
+         key = ops.gather(key, tensor.edge['target'])
+         value = ops.gather(value, tensor.edge['source'])
+
+         attention_score = keras.ops.sum(query * key, axis=1, keepdims=True)
+         attention_score /= keras.ops.sqrt(float(self.units))
+
+         if 'bias' in tensor.edge:
+             attention_score += tensor.edge['bias']
+
+         attention = ops.edge_softmax(attention_score, tensor.edge['target'])
+         attention = self._softmax_dropout(attention)
+
+         return tensor.update(
+             {
+                 'edge': {
+                     'message': value,
+                     'weight': attention,
+                 },
+             }
+         )
+
+     def aggregate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
+         """Aggregates messages.
+         """
+         node_feature = tensor.aggregate('message')
+
+         node_feature = keras.ops.reshape(node_feature, (-1, self.units))
+         node_feature = self._output_dense(node_feature)
+         node_feature = self._self_attention_dropout(node_feature)
+
+         residual = tensor.node['feature']
+         if self._project_residual:
+             residual = self._residual_dense(residual)
+         node_feature += residual
+
+         if not self._normalize_first:
+             node_feature = self._self_attention_norm(node_feature)
+
+         return tensor.update(
+             {
+                 'node': {
+                     'feature': node_feature,
+                 },
+                 'edge': {
+                     'message': None,
+                     'weight': None,
+                 }
+             }
+         )
+
+     def update(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
+         """Updates nodes.
+         """
+         node_feature = tensor.node['feature']
+
+         if self._normalize_first:
+             node_feature = self._feedforward_norm(node_feature)
+
+         node_feature = self._feedforward_intermediate_dense(node_feature)
+         node_feature = self._activation(node_feature)
+         node_feature = self._feedforward_output_dense(node_feature)
+
+         node_feature = self._feedforward_dropout(node_feature)
+         node_feature += tensor.node['feature']
+
+         if not self._normalize_first:
+             node_feature = self._feedforward_norm(node_feature)
+
+         return tensor.update(
+             {
+                 'node': {
+                     'feature': node_feature,
+                 },
+             }
+         )
+
+     def get_config(self) -> dict:
+         config = super().get_config()
+         config.update({
+             'heads': self._heads,
+             'activation': keras.activations.serialize(self._activation),
+             'dropout': self._dropout,
+             'attention_dropout': self._attention_dropout,
+             'normalize': self._normalize,
+             'normalize_first': self._normalize_first,
+         })
+         return config
+
+
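# Summary of the attention implemented by `message`/`aggregate` above, for an
# edge u -> v with per-head queries/keys/values q, k, v and the optional edge
# bias b_uv from `add_edge_bias`:
#
#     score_uv = (q_u . k_v) / sqrt(units) + b_uv
#     alpha_uv = edge_softmax(score_uv)            # normalized over edges into v
#     h_v'     = W_o concat_heads(sum_u alpha_uv * v_u) + residual_v
#
# followed by the (pre- or post-normalized) feedforward block in `update`.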
+ @keras.saving.register_keras_serializable(package='molcraft')
+ class Readout(keras.layers.Layer):
+
+     """Readout layer.
+
+     Reduces node features into a single feature vector per subgraph.
+     """
+
+     def __init__(self, mode: str | None = None, **kwargs):
+         super().__init__(**kwargs)
+         self.mode = mode
+         if not self.mode:
+             self._reduce_fn = None
+         elif str(self.mode).lower().startswith('sum'):
+             self._reduce_fn = keras.ops.segment_sum
+         elif str(self.mode).lower().startswith('max'):
+             self._reduce_fn = keras.ops.segment_max
+         elif str(self.mode).lower().startswith('super'):
+             self._reduce_fn = keras.ops.segment_sum
+         else:
+             self._reduce_fn = ops.segment_mean
+
+     def build_from_spec(self, spec: tensors.GraphTensor.Spec) -> None:
+         """Builds the layer.
+         """
+
+     def reduce(self, tensor: tensors.GraphTensor) -> tf.Tensor:
+         if self._reduce_fn is None:
+             raise NotImplementedError("Need to define a reduce method.")
+         if str(self.mode).lower().startswith('super'):
+             node_feature = keras.ops.where(
+                 tensor.node['super'][:, None], tensor.node['feature'], 0.0
+             )
+             return self._reduce_fn(
+                 node_feature, tensor.graph_indicator, tensor.num_subgraphs
+             )
+         return self._reduce_fn(
+             tensor.node['feature'], tensor.graph_indicator, tensor.num_subgraphs
+         )
+
+     def build(self, input_shapes) -> None:
+         spec = tensors.GraphTensor.Spec.from_input_shape_dict(input_shapes)
+         self.build_from_spec(spec)
+         self.built = True
+
+     def call(self, graph) -> tf.Tensor:
+         graph_tensor = tensors.from_dict(graph)
+         if tensors.is_ragged(graph_tensor):
+             graph_tensor = graph_tensor.flatten()
+         return self.reduce(graph_tensor)
+
+     def __call__(
+         self,
+         graph: tensors.GraphTensor,
+         *args,
+         **kwargs
+     ) -> tf.Tensor:
+         is_tensor = isinstance(graph, tensors.GraphTensor)
+         if is_tensor:
+             graph = tensors.to_dict(graph)
+         tensor = super().__call__(graph, *args, **kwargs)
+         return tensor
+
+     def get_config(self) -> dict:
+         config = super().get_config()
+         config['mode'] = self.mode
+         return config
+
+
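# Usage sketch (illustrative): `Readout` pools [num_nodes, dim] node features
# into [num_subgraphs, dim]. Modes starting with 'sum', 'max' or 'super' select
# the corresponding reduction; any other non-empty mode (e.g. 'mean') averages.
#
#     pooled = Readout('mean')(graph)  # `graph`: a `GraphTensor` or its dict form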
+ def Input(spec: tensors.GraphTensor.Spec) -> dict:
+     """Specifies inputs to a model.
+
+     Example:
+
+     >>> import molcraft
+     >>> import keras
+     >>>
+     >>> featurizer = molcraft.featurizers.MolGraphFeaturizer()
+     >>> graph = featurizer([('N[C@@H](C)C(=O)O', 1.0), ('N[C@@H](CS)C(=O)O', 2.0)])
+     >>>
+     >>> model = molcraft.models.GraphModel.from_layers([
+     ...     molcraft.layers.Input(graph.spec),
+     ...     molcraft.layers.NodeEmbedding(128),
+     ...     molcraft.layers.EdgeEmbedding(128),
+     ...     molcraft.layers.GraphTransformer(128),
+     ...     molcraft.layers.GraphTransformer(128),
+     ...     molcraft.layers.Readout('mean'),
+     ...     molcraft.layers.Dense(1)
+     ... ])
+     """
+     # Currently, Keras (3.8.0) does not support extension types.
+     # So for now, this function will unpack the `GraphTensor.Spec` and
+     # return a dictionary of nested tensor specs. However, the corresponding
+     # nest of tensors will temporarily be converted to a `GraphTensor` by the
+     # `GraphLayer`, to leverage the utility of a `GraphTensor` object.
+     inputs = {}
+     for outer_field, data in spec.__dict__.items():
+         inputs[outer_field] = {}
+         for inner_field, nested_spec in data.items():
+             if inner_field in ['label', 'weight']:
+                 if outer_field == 'context':
+                     continue
+             kwargs = {
+                 'shape': nested_spec.shape[1:],
+                 'dtype': nested_spec.dtype,
+                 'name': f'{outer_field}_{inner_field}'
+             }
+             if isinstance(nested_spec, tf.RaggedTensorSpec):
+                 kwargs['ragged'] = True
+             try:
+                 inputs[outer_field][inner_field] = keras.Input(**kwargs)
+             except TypeError:
+                 raise ValueError(
+                     "`keras.Input` does not currently support ragged tensors. For now, "
+                     "pass the `Spec` of a 'flat' `GraphTensor` to `Input`."
+                 )
+     return inputs
+
+
+ def warn(message: str) -> None:
+     warnings.warn(
+         message=message,
+         category=UserWarning,
+         stacklevel=1
+     )
+
+ def _match_functional_input(functional_input, inputs):
+     matching_inputs = {}
+     for outer_field, data in functional_input.items():
+         matching_inputs[outer_field] = {}
+         for inner_field, _ in data.items():
+             call_input = inputs[outer_field].pop(inner_field)
+             matching_inputs[outer_field][inner_field] = call_input
+     unmatching_inputs = inputs
+     return matching_inputs, unmatching_inputs
+
+ def _add_left_out_inputs(outputs, inputs):
+     for outer_field, data in inputs.items():
+         for inner_field, value in data.items():
+             if inner_field in ['label', 'weight']:
+                 outputs[outer_field][inner_field] = value
+     return outputs
+
+ def _serialize_spec(spec: tensors.GraphTensor.Spec) -> dict:
+     serialized_spec = {}
+     for outer_field, data in spec.__dict__.items():
+         serialized_spec[outer_field] = {}
+         for inner_field, inner_spec in data.items():
+             serialized_spec[outer_field][inner_field] = {
+                 'shape': inner_spec.shape.as_list(),
+                 'dtype': inner_spec.dtype.name,
+                 'name': inner_spec.name,
+             }
+     return serialized_spec
+
+ def _deserialize_spec(serialized_spec: dict) -> tensors.GraphTensor.Spec:
+     deserialized_spec = {}
+     for outer_field, data in serialized_spec.items():
+         deserialized_spec[outer_field] = {}
+         for inner_field, inner_spec in data.items():
+             deserialized_spec[outer_field][inner_field] = tf.TensorSpec(
+                 inner_spec['shape'], inner_spec['dtype'], inner_spec['name']
+             )
+     return tensors.GraphTensor.Spec(**deserialized_spec)
+
+ def _spec_from_inputs(inputs):
+     symbolic_inputs = keras.backend.is_keras_tensor(
+         tf.nest.flatten(inputs)[0]
+     )
+     if not symbolic_inputs:
+         nested_specs = tf.nest.map_structure(
+             tf.type_spec_from_value, inputs
+         )
+     else:
+         nested_specs = tf.nest.map_structure(
+             lambda t: tf.TensorSpec(t.shape, t.dtype), inputs
+         )
+     if isinstance(nested_specs, tensors.GraphTensor.Spec):
+         spec = nested_specs
+         return spec
+     return tensors.GraphTensor.Spec(**nested_specs)
+
+
+ GraphTransformer = GTConvolution = GTConv
+ GINConvolution = GINConv
+
+ EdgeEmbed = EdgeEmbedding
+ NodeEmbed = NodeEmbedding
+
+ ContextDense = ContextProjection
+ EdgeDense = EdgeProjection
+ NodeDense = NodeProjection
+