molcraft 0.1.0rc9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


molcraft/layers.py ADDED
@@ -0,0 +1,1910 @@
+ import warnings
+ import keras
+ import tensorflow as tf
+ import functools
+ import inspect
+ from keras.src.models import functional
+
+ from molcraft import tensors
+ from molcraft import ops
+
+
+ @keras.saving.register_keras_serializable(package='molcraft')
+ class GraphLayer(keras.layers.Layer):
+     """Base graph layer.
+
+     Subclasses must implement a forward pass via **propagate(graph)**.
+
+     Subclasses may create dense layers and weights in **build(graph_spec)**.
+
+     Note: `GraphLayer` currently only supports `GraphTensor` input.
+
+     The arguments below are only relevant if the derived layer
+     invokes `get_dense_kwargs`, `get_dense` or `get_einsum_dense`.
+
+     Arguments:
+         use_bias (bool):
+             Whether bias should be used in dense layers. Defaults to `True`.
+         kernel_initializer (keras.initializers.Initializer, str):
+             Initializer for the kernel weight matrix of the dense layers.
+             Defaults to `glorot_uniform`.
+         bias_initializer (keras.initializers.Initializer, str):
+             Initializer for the bias weight vector of the dense layers.
+             Defaults to `zeros`.
+         kernel_regularizer (keras.regularizers.Regularizer, None):
+             Regularizer function applied to the kernel weight matrix.
+             Defaults to `None`.
+         bias_regularizer (keras.regularizers.Regularizer, None):
+             Regularizer function applied to the bias weight vector.
+             Defaults to `None`.
+         activity_regularizer (keras.regularizers.Regularizer, None):
+             Regularizer function applied to the output of the dense layers.
+             Defaults to `None`.
+         kernel_constraint (keras.constraints.Constraint, None):
+             Constraint function applied to the kernel weight matrix.
+             Defaults to `None`.
+         bias_constraint (keras.constraints.Constraint, None):
+             Constraint function applied to the bias weight vector.
+             Defaults to `None`.
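+
+     A minimal subclass sketch (illustrative only; `ScaleNodeFeatures` is a
+     hypothetical layer, not part of molcraft, and assumes node features are
+     stored under `node['feature']` as in the examples further below):
+
+     >>> class ScaleNodeFeatures(GraphLayer):
+     ...     def build(self, spec):
+     ...         # `get_weight` forwards kernel kwargs; override the initializer.
+     ...         self.scale = self.get_weight(shape=(), initializer='ones', name='scale')
+     ...     def propagate(self, tensor):
+     ...         feature = tensor.node['feature'] * self.scale
+     ...         return tensor.update({'node': {'feature': feature}})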
+     """
+
+     def __init__(
+         self,
+         use_bias: bool = True,
+         kernel_initializer: keras.initializers.Initializer | str = "glorot_uniform",
+         bias_initializer: keras.initializers.Initializer | str = "zeros",
+         kernel_regularizer: keras.regularizers.Regularizer | None = None,
+         bias_regularizer: keras.regularizers.Regularizer | None = None,
+         activity_regularizer: keras.regularizers.Regularizer | None = None,
+         kernel_constraint: keras.constraints.Constraint | None = None,
+         bias_constraint: keras.constraints.Constraint | None = None,
+         **kwargs,
+     ) -> None:
+         super().__init__(**kwargs)
+         self._use_bias = use_bias
+         self._kernel_initializer = keras.initializers.get(kernel_initializer)
+         self._bias_initializer = keras.initializers.get(bias_initializer)
+         self._kernel_regularizer = keras.regularizers.get(kernel_regularizer)
+         self._bias_regularizer = keras.regularizers.get(bias_regularizer)
+         self._activity_regularizer = keras.regularizers.get(activity_regularizer)
+         self._kernel_constraint = keras.constraints.get(kernel_constraint)
+         self._bias_constraint = keras.constraints.get(bias_constraint)
+         self._custom_build_config = {}
+         self._propagate_kwargs = _propagate_kwargs(self.propagate)
+         self.built = False
+
+     def __init_subclass__(cls, **kwargs):
+         super().__init_subclass__(**kwargs)
+         subclass_build = cls.build
+
+         @functools.wraps(subclass_build)
+         def build_wrapper(self: GraphLayer, spec: tensors.GraphTensor.Spec | None):
+             GraphLayer.build(self, spec)
+             subclass_build(self, spec)
+             if not self.built and isinstance(self, keras.Model):
+                 symbolic_inputs = Input(spec)
+                 self.built = True
+                 self(symbolic_inputs)
+
+         cls.build = build_wrapper
+
+     def propagate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
+         """Forward pass.
+
+         Must be implemented by subclass.
+
+         Arguments:
+             tensor:
+                 A `GraphTensor` instance.
+         """
+         raise NotImplementedError(
+             'The forward pass of the layer is not implemented. '
+             'Please implement `propagate`.'
+         )
+
+     def build(self, tensor_spec: tensors.GraphTensor.Spec) -> None:
+         """Builds the layer.
+
+         May use built-in methods such as `get_weight`, `get_dense` and
+         `get_einsum_dense`.
+
+         Optionally implemented by subclass.
+
+         Arguments:
+             tensor_spec:
+                 A `GraphTensor.Spec` instance corresponding to the `GraphTensor`
+                 passed to `propagate`.
+         """
+         if isinstance(tensor_spec, tensors.GraphTensor.Spec):
+             self._custom_build_config['spec'] = _serialize_spec(tensor_spec)
+
+     def call(
+         self,
+         graph: dict[str, dict[str, tf.Tensor]],
+         training: bool | None = None,
+     ) -> dict[str, dict[str, tf.Tensor]]:
+         graph_tensor = tensors.from_dict(graph)
+         if not self._propagate_kwargs:
+             outputs = self.propagate(graph_tensor)
+         else:
+             outputs = self.propagate(graph_tensor, training=training)
+         if isinstance(outputs, tensors.GraphTensor):
+             return tensors.to_dict(outputs)
+         return outputs
+
+     def __call__(
+         self,
+         graph: dict[str, dict[str, tf.Tensor]] | tensors.GraphTensor,
+         **kwargs
+     ) -> tf.Tensor | dict[str, dict[str, tf.Tensor]] | tensors.GraphTensor:
+         if not self.built:
+             spec = _spec_from_inputs(graph)
+             self.build(spec)
+
+         is_graph_tensor = isinstance(graph, tensors.GraphTensor)
+         if is_graph_tensor:
+             graph = tensors.to_dict(graph)
+         else:
+             graph = {field: dict(data) for (field, data) in graph.items()}
+
+         if isinstance(self, functional.Functional):
+             # As a functional model is strict about what input can be
+             # passed to it, we need to temporarily pop some of the
+             # input and add it back afterwards.
+             symbolic_input = self._symbolic_input
+             excluded = {}
+             for outer_field in ['context', 'node', 'edge']:
+                 excluded[outer_field] = {}
+                 for inner_field in list(graph[outer_field]):
+                     if inner_field not in symbolic_input[outer_field]:
+                         excluded[outer_field][inner_field] = (
+                             graph[outer_field].pop(inner_field)
+                         )
+
+             tf.nest.assert_same_structure(self.input, graph)
+
+         outputs = super().__call__(graph, **kwargs)
+
+         if not tensors.is_graph(outputs):
+             return outputs
+
+         graph = outputs
+         if isinstance(self, functional.Functional):
+             for outer_field in ['context', 'node', 'edge']:
+                 for inner_field in list(excluded[outer_field]):
+                     graph[outer_field][inner_field] = (
+                         excluded[outer_field].pop(inner_field)
+                     )
+
+         if is_graph_tensor:
+             return tensors.from_dict(graph)
+
+         return graph
+
+     def get_build_config(self) -> dict:
+         if self._custom_build_config:
+             return self._custom_build_config
+         return super().get_build_config()
+
+     def build_from_config(self, config: dict) -> None:
+         serialized_spec = config.get('spec')
+         if serialized_spec is not None:
+             spec = _deserialize_spec(serialized_spec)
+             self.build(spec)
+         else:
+             super().build_from_config(config)
+
+     def get_weight(
+         self,
+         shape: tf.TensorShape,
+         **kwargs,
+     ) -> tf.Variable:
+         common_kwargs = self.get_dense_kwargs()
+         weight_kwargs = {
+             'initializer': common_kwargs['kernel_initializer'],
+             'regularizer': common_kwargs['kernel_regularizer'],
+             'constraint': common_kwargs['kernel_constraint']
+         }
+         weight_kwargs.update(kwargs)
+         return self.add_weight(shape=shape, **weight_kwargs)
+
+     def get_dense(
+         self,
+         units: int,
+         **kwargs
+     ) -> keras.layers.Dense:
+         common_kwargs = self.get_dense_kwargs()
+         common_kwargs.update(kwargs)
+         return keras.layers.Dense(units, **common_kwargs)
+
+     def get_einsum_dense(
+         self,
+         equation: str,
+         output_shape: tf.TensorShape,
+         **kwargs
+     ) -> keras.layers.EinsumDense:
+         common_kwargs = self.get_dense_kwargs()
+         common_kwargs.update(kwargs)
+         use_bias = common_kwargs.pop('use_bias', False)
+         if use_bias and 'bias_axes' not in common_kwargs:
+             common_kwargs['bias_axes'] = equation.split('->')[-1][1:] or None
+         return keras.layers.EinsumDense(equation, output_shape, **common_kwargs)
+
+     def get_dense_kwargs(self) -> dict:
+         common_kwargs = dict(
+             use_bias=self._use_bias,
+             kernel_regularizer=self._kernel_regularizer,
+             bias_regularizer=self._bias_regularizer,
+             activity_regularizer=self._activity_regularizer,
+             kernel_constraint=self._kernel_constraint,
+             bias_constraint=self._bias_constraint,
+         )
+         kernel_initializer = self._kernel_initializer.__class__.from_config(
+             self._kernel_initializer.get_config()
+         )
+         bias_initializer = self._bias_initializer.__class__.from_config(
+             self._bias_initializer.get_config()
+         )
+         common_kwargs["kernel_initializer"] = kernel_initializer
+         common_kwargs["bias_initializer"] = bias_initializer
+         return common_kwargs
+
+     def get_config(self) -> dict:
+         config = super().get_config()
+         config.update({
+             "use_bias": self._use_bias,
+             "kernel_initializer":
+                 keras.initializers.serialize(self._kernel_initializer),
+             "bias_initializer":
+                 keras.initializers.serialize(self._bias_initializer),
+             "kernel_regularizer":
+                 keras.regularizers.serialize(self._kernel_regularizer),
+             "bias_regularizer":
+                 keras.regularizers.serialize(self._bias_regularizer),
+             "activity_regularizer":
+                 keras.regularizers.serialize(self._activity_regularizer),
+             "kernel_constraint":
+                 keras.constraints.serialize(self._kernel_constraint),
+             "bias_constraint":
+                 keras.constraints.serialize(self._bias_constraint),
+         })
+         return config
+
+     @property
+     def _symbolic_output(self) -> dict[str, dict[str, keras.KerasTensor]]:
+         output = self.output
+         if tensors.is_graph(output):
+             # GraphModel
+             return output
+         symbolic_input = self._symbolic_input
+         symbolic_output = self.compute_output_spec(symbolic_input)
+         return tf.nest.pack_sequence_as(symbolic_output, output)
+
+     @property
+     def _symbolic_input(self) -> dict[str, dict[str, keras.KerasTensor]]:
+         input = self.input
+         if tensors.is_graph(input):
+             # GraphModel or initial GraphLayer of GraphModel
+             return input
+         spec = _deserialize_spec(self.get_build_config()['spec'])
+         spec_dict = {k: dict(v) for k, v in spec.__dict__.items()}
+         return tf.nest.pack_sequence_as(spec_dict, input)
+
+     @property
+     def output_spec(self) -> tensors.GraphTensor.Spec | tf.TensorSpec:
+         if not self.built:
+             return None
+         if isinstance(self, functional.Functional):
+             output_spec = self.output
+         else:
+             serialized_spec = self.get_build_config()['spec']
+             deserialized_spec = _deserialize_spec(serialized_spec)
+             input_spec = Input(deserialized_spec)
+             output_spec = self.compute_output_spec(input_spec)
+         if not tensors.is_graph(output_spec):
+             return tf.TensorSpec(output_spec.shape, output_spec.dtype)
+         spec_dict = tf.nest.map_structure(
+             lambda t: tf.TensorSpec(t.shape, t.dtype), output_spec
+         )
+         return tensors.GraphTensor.Spec(**spec_dict)
+
+
+ @keras.saving.register_keras_serializable(package='molcraft')
+ class GraphConv(GraphLayer):
+     """Base graph neural network layer.
+
+     This layer implements the three basic steps of a graph neural network
+     layer, each of which can (optionally) be overridden by the `GraphConv`
+     subclass:
+
+     1. **message(graph)**, which computes the *messages* to be passed to target nodes;
+     2. **aggregate(graph)**, which *aggregates* messages to target nodes;
+     3. **update(graph)**, which further *updates* (target) nodes.
+
+     Note: for the skip connection to work, the `GraphConv` subclass requires
+     the final node feature dimension to be equal to `units`.
+
+     Arguments:
+         units (int):
+             Dimensionality of the output space.
+         activation (keras.layers.Activation, str, None):
+             Activation function, accessed via `self.activation` and used by the
+             `message()` and `update()` methods, if not overridden. Defaults to `relu`.
+         use_bias (bool):
+             Whether bias should be used in the dense layers. Defaults to `True`.
+         normalize (bool, str):
+             Whether a normalization layer should be obtained by `get_norm()`.
+             Defaults to `False`.
+         skip_connect (bool):
+             Whether the node feature input should be added to the node feature
+             output. Defaults to `True`.
+         kernel_initializer (keras.initializers.Initializer, str):
+             Initializer for the kernel weight matrix of the dense layers.
+             Defaults to `glorot_uniform`.
+         bias_initializer (keras.initializers.Initializer, str):
+             Initializer for the bias weight vector of the dense layers.
+             Defaults to `zeros`.
+         kernel_regularizer (keras.regularizers.Regularizer, None):
+             Regularizer function applied to the kernel weight matrix.
+             Defaults to `None`.
+         bias_regularizer (keras.regularizers.Regularizer, None):
+             Regularizer function applied to the bias weight vector.
+             Defaults to `None`.
+         activity_regularizer (keras.regularizers.Regularizer, None):
+             Regularizer function applied to the output of the dense layers.
+             Defaults to `None`.
+         kernel_constraint (keras.constraints.Constraint, None):
+             Constraint function applied to the kernel weight matrix.
+             Defaults to `None`.
+         bias_constraint (keras.constraints.Constraint, None):
+             Constraint function applied to the bias weight vector.
+             Defaults to `None`.
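+
+     A minimal sketch of a custom convolution (illustrative only; `SourceConv`
+     is a hypothetical layer, not part of molcraft). It overrides just
+     `message`; `aggregate` and `update` fall back to the defaults:
+
+     >>> class SourceConv(GraphConv):
+     ...     def message(self, tensor):
+     ...         # Pass the source node features along each edge unchanged.
+     ...         message = tensor.gather('feature', 'source')
+     ...         return tensor.update({'edge': {'message': message}})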
+     """
+
+     def __init__(
+         self,
+         units: int = None,
+         activation: str | keras.layers.Activation | None = 'relu',
+         use_bias: bool = True,
+         normalize: bool = False,
+         skip_connect: bool = True,
+         **kwargs
+     ) -> None:
+         super().__init__(use_bias=use_bias, **kwargs)
+         self._units = units
+         self._normalize = normalize
+         self._skip_connect = skip_connect
+         self._activation = keras.activations.get(activation)
+
+     def __init_subclass__(cls, **kwargs):
+         super().__init_subclass__(**kwargs)
+         subclass_build = cls.build
+
+         @functools.wraps(subclass_build)
+         def build_wrapper(self, spec):
+             GraphConv.build(self, spec)
+             return subclass_build(self, spec)
+
+         cls.build = build_wrapper
+
+     @property
+     def units(self):
+         return self._units
+
+     @property
+     def activation(self):
+         return self._activation
+
+     def build(self, spec: tensors.GraphTensor.Spec) -> None:
+         if not self.units:
+             raise ValueError(
+                 f'`self.units` needs to be a positive integer. Found: {self.units}.'
+             )
+         node_feature_dim = spec.node['feature'].shape[-1]
+         self._project_residual = (
+             self._skip_connect and (node_feature_dim != self.units)
+         )
+         if self._project_residual:
+             warnings.warn(
+                 'Found incompatible dims between input and output. Applying '
+                 'a projection layer to the residual to match input and output dims.',
+             )
+             self._residual_dense = self.get_dense(
+                 self.units, name='residual_dense'
+             )
+
+         self.has_edge_feature = 'feature' in spec.edge
+         self.has_node_coordinate = 'coordinate' in spec.node
+
+         has_overridden_message = self.__class__.message != GraphConv.message
+         if not has_overridden_message:
+             self._message_intermediate_dense = self.get_dense(self.units)
+             self._message_norm = self.get_norm()
+             self._message_intermediate_activation = self.activation
+             self._message_final_dense = self.get_dense(self.units)
+
+         has_overridden_aggregate = self.__class__.aggregate != GraphConv.aggregate
+         if not has_overridden_aggregate:
+             pass
+
+         has_overridden_update = self.__class__.update != GraphConv.update
+         if not has_overridden_update:
+             self._update_intermediate_dense = self.get_dense(self.units)
+             self._update_norm = self.get_norm()
+             self._update_intermediate_activation = self.activation
+             self._update_final_dense = self.get_dense(self.units)
+
+     def propagate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
+         """Forward pass.
+
+         Invokes `message(graph)`, `aggregate(graph)` and `update(graph)` in sequence.
+
+         Arguments:
+             tensor:
+                 A `GraphTensor` instance.
+         """
+         if self._skip_connect:
+             residual = tensor.node['feature']
+             if self._project_residual:
+                 residual = self._residual_dense(residual)
+
+         message = self.message(tensor)
+         add_message = not isinstance(message, tensors.GraphTensor)
+         if add_message:
+             message = tensor.update({'edge': {'message': message}})
+         elif 'message' not in message.edge:
+             raise ValueError('Could not find `message` in `edge` output.')
+
+         aggregate = self.aggregate(message)
+         add_aggregate = not isinstance(aggregate, tensors.GraphTensor)
+         if add_aggregate:
+             aggregate = tensor.update({'node': {'aggregate': aggregate}})
+         elif 'aggregate' not in aggregate.node:
+             raise ValueError('Could not find `aggregate` in `node` output.')
+
+         update = self.update(aggregate)
+         if not isinstance(update, tensors.GraphTensor):
+             update = tensor.update({'node': {'feature': update}})
+         elif 'feature' not in update.node:
+             raise ValueError('Could not find `feature` in `node` output.')
+
+         if update.node['feature'].shape[-1] != self.units:
+             raise ValueError('Updated node `feature` dim is not equal to `self.units`.')
+
+         if add_message and add_aggregate:
+             update = update.update({'node': {'aggregate': None}, 'edge': {'message': None}})
+         elif add_message:
+             update = update.update({'edge': {'message': None}})
+         elif add_aggregate:
+             update = update.update({'node': {'aggregate': None}})
+
+         if not self._skip_connect:
+             return update
+
+         feature = update.node['feature'] + residual
+         return update.update({'node': {'feature': feature}})
+
+     def message(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
+         """Computes messages.
+
+         This method may be overridden by subclass.
+
+         Arguments:
+             tensor:
+                 The input `GraphTensor` instance.
+         """
+         message = keras.ops.concatenate(
+             [
+                 tensor.gather('feature', 'source'),
+                 tensor.gather('feature', 'target'),
+             ],
+             axis=-1
+         )
+         if self.has_edge_feature:
+             message = keras.ops.concatenate(
+                 [
+                     message,
+                     tensor.edge['feature']
+                 ],
+                 axis=-1
+             )
+         if self.has_node_coordinate:
+             euclidean_distance = ops.euclidean_distance(
+                 tensor.gather('coordinate', 'target'),
+                 tensor.gather('coordinate', 'source'),
+                 axis=-1,
+                 keepdims=True
+             )
+             message = keras.ops.concatenate(
+                 [
+                     message,
+                     euclidean_distance
+                 ],
+                 axis=-1
+             )
+         message = self._message_intermediate_dense(message)
+         message = self._message_norm(message)
+         message = self._message_intermediate_activation(message)
+         message = self._message_final_dense(message)
+         return tensor.update({'edge': {'message': message}})
+
+     def aggregate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
+         """Aggregates messages.
+
+         This method may be overridden by subclass.
+
+         Arguments:
+             tensor:
+                 A `GraphTensor` instance containing a message.
+         """
+         previous = tensor.node['feature']
+         aggregate = tensor.aggregate('message', mode='mean')
+         aggregate = keras.ops.concatenate([aggregate, previous], axis=-1)
+         return tensor.update(
+             {
+                 'node': {
+                     'aggregate': aggregate,
+                 },
+                 'edge': {
+                     'message': None,
+                 }
+             }
+         )
+
+     def update(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
+         """Updates nodes.
+
+         This method may be overridden by subclass.
+
+         Arguments:
+             tensor:
+                 A `GraphTensor` instance containing aggregated messages
+                 (updated node features).
+         """
+         aggregate = tensor.node['aggregate']
+         node_feature = self._update_intermediate_dense(aggregate)
+         node_feature = self._update_norm(node_feature)
+         node_feature = self._update_intermediate_activation(node_feature)
+         node_feature = self._update_final_dense(node_feature)
+         return tensor.update(
+             {
+                 'node': {
+                     'feature': node_feature,
+                     'aggregate': None,
+                 },
+             }
+         )
+
+     def get_norm(self, **kwargs):
+         if not self._normalize:
+             return keras.layers.Identity()
+         elif str(self._normalize).lower().startswith('layer'):
+             return keras.layers.LayerNormalization(**kwargs)
+         else:
+             return keras.layers.BatchNormalization(**kwargs)
+
+     def get_config(self) -> dict:
+         config = super().get_config()
+         config.update({
+             'units': self.units,
+             'activation': keras.activations.serialize(self._activation),
+             'normalize': self._normalize,
+             'skip_connect': self._skip_connect,
+         })
+         return config
+
+
+ @keras.saving.register_keras_serializable(package='molcraft')
+ class GIConv(GraphConv):
+     """Graph isomorphism network layer.
+
+     >>> graph = molcraft.tensors.GraphTensor(
+     ...     context={
+     ...         'size': [2]
+     ...     },
+     ...     node={
+     ...         'feature': [[1.], [2.]]
+     ...     },
+     ...     edge={
+     ...         'source': [0, 1],
+     ...         'target': [1, 0],
+     ...     }
+     ... )
+     >>> conv = molcraft.layers.GIConv(units=4)
+     >>> conv(graph)
+     GraphTensor(
+         context={
+             'size': <tf.Tensor: shape=[1], dtype=int32>
+         },
+         node={
+             'feature': <tf.Tensor: shape=[2, 4], dtype=float32>
+         },
+         edge={
+             'source': <tf.Tensor: shape=[2], dtype=int32>,
+             'target': <tf.Tensor: shape=[2], dtype=int32>
+         }
+     )
+     """
+
+     def __init__(
+         self,
+         units: int,
+         activation: keras.layers.Activation | str | None = 'relu',
+         use_bias: bool = True,
+         normalize: bool = False,
+         skip_connect: bool = True,
+         update_edge_feature: bool = True,
+         **kwargs,
+     ):
+         super().__init__(
+             units=units,
+             activation=activation,
+             use_bias=use_bias,
+             normalize=normalize,
+             skip_connect=skip_connect,
+             **kwargs
+         )
+         self._update_edge_feature = update_edge_feature
+
+     def build(self, spec: tensors.GraphTensor.Spec) -> None:
+         node_feature_dim = spec.node['feature'].shape[-1]
+
+         self.epsilon = self.add_weight(
+             name='epsilon',
+             shape=(),
+             initializer='zeros',
+             trainable=True,
+         )
+
+         if self.has_edge_feature:
+             edge_feature_dim = spec.edge['feature'].shape[-1]
+
+             if not self._update_edge_feature:
+                 if edge_feature_dim != node_feature_dim:
+                     warnings.warn(
+                         'Found edge and node feature dims to be incompatible. Applying a '
+                         'projection layer to edge features to match the dim of the node features.',
+                     )
+                     self._update_edge_feature = True
+
+             if self._update_edge_feature:
+                 self._edge_dense = self.get_dense(node_feature_dim)
+         else:
+             self._update_edge_feature = False
+
+     def message(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
+         message = tensor.gather('feature', 'source')
+         edge_feature = tensor.edge.get('feature')
+         if self._update_edge_feature:
+             edge_feature = self._edge_dense(edge_feature)
+         if self.has_edge_feature:
+             message += edge_feature
+         message = keras.ops.relu(message)
+         return tensor.update(
+             {
+                 'edge': {
+                     'message': message,
+                     'feature': edge_feature
+                 }
+             }
+         )
+
+     def aggregate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
+         node_feature = tensor.aggregate('message', mode='mean')
+         node_feature += (1 + self.epsilon) * tensor.node['feature']
+         return tensor.update(
+             {
+                 'node': {
+                     'aggregate': node_feature,
+                 },
+                 'edge': {
+                     'message': None,
+                 }
+             }
+         )
+
+     def get_config(self) -> dict:
+         config = super().get_config()
+         config.update({
+             'update_edge_feature': self._update_edge_feature
+         })
+         return config
+
+
+ @keras.saving.register_keras_serializable(package='molcraft')
+ class GAConv(GraphConv):
+     """Graph attention network layer.
+
+     >>> graph = molcraft.tensors.GraphTensor(
+     ...     context={
+     ...         'size': [2]
+     ...     },
+     ...     node={
+     ...         'feature': [[1.], [2.]]
+     ...     },
+     ...     edge={
+     ...         'source': [0, 1],
+     ...         'target': [1, 0],
+     ...     }
+     ... )
+     >>> conv = molcraft.layers.GAConv(units=4, heads=2)
+     >>> conv(graph)
+     GraphTensor(
+         context={
+             'size': <tf.Tensor: shape=[1], dtype=int32>
+         },
+         node={
+             'feature': <tf.Tensor: shape=[2, 4], dtype=float32>
+         },
+         edge={
+             'source': <tf.Tensor: shape=[2], dtype=int32>,
+             'target': <tf.Tensor: shape=[2], dtype=int32>
+         }
+     )
+     """
+
+     def __init__(
+         self,
+         units: int,
+         heads: int = 8,
+         activation: keras.layers.Activation | str | None = "relu",
+         use_bias: bool = True,
+         normalize: bool = False,
+         skip_connect: bool = True,
+         update_edge_feature: bool = True,
+         **kwargs,
+     ) -> None:
+         super().__init__(
+             units=units,
+             activation=activation,
+             normalize=normalize,
+             use_bias=use_bias,
+             skip_connect=skip_connect,
+             **kwargs
+         )
+         self._heads = heads
+         if self.units % self.heads != 0:
+             raise ValueError('`units` needs to be divisible by `heads`.')
+         self._head_units = self.units // self.heads
+         self._update_edge_feature = update_edge_feature
+
+     @property
+     def heads(self):
+         return self._heads
+
+     @property
+     def head_units(self):
+         return self._head_units
+
+     def build(self, spec: tensors.GraphTensor.Spec) -> None:
+         self._update_edge_feature = self.has_edge_feature and self._update_edge_feature
+         if self._update_edge_feature:
+             self._edge_dense = self.get_einsum_dense(
+                 'ijh,jkh->ikh', (self.head_units, self.heads)
+             )
+         self._node_dense = self.get_einsum_dense(
+             'ij,jkh->ikh', (self.head_units, self.heads)
+         )
+         self._feature_dense = self.get_einsum_dense(
+             'ij,jkh->ikh', (self.head_units, self.heads)
+         )
+         self._attention_dense = self.get_einsum_dense(
+             'ijh,jkh->ikh', (1, self.heads)
+         )
+
+     def message(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
+         attention_feature = keras.ops.concatenate(
+             [
+                 tensor.gather('feature', 'source'),
+                 tensor.gather('feature', 'target')
+             ],
+             axis=-1
+         )
+         if self.has_edge_feature:
+             attention_feature = keras.ops.concatenate(
+                 [
+                     attention_feature,
+                     tensor.edge['feature']
+                 ],
+                 axis=-1
+             )
+
+         attention_feature = self._feature_dense(attention_feature)
+
+         edge_feature = tensor.edge.get('feature')
+
+         if self._update_edge_feature:
+             edge_feature = self._edge_dense(attention_feature)
+             edge_feature = keras.ops.reshape(edge_feature, (-1, self.units))
+
+         attention_feature = keras.ops.leaky_relu(attention_feature)
+         attention_score = self._attention_dense(attention_feature)
+         attention_score = ops.edge_softmax(
+             score=attention_score, edge_target=tensor.edge['target']
+         )
+         node_feature = self._node_dense(tensor.node['feature'])
+         message = ops.gather(node_feature, tensor.edge['source'])
+         message = ops.edge_weight(message, attention_score)
+         return tensor.update(
+             {
+                 'edge': {
+                     'message': message,
+                     'feature': edge_feature,
+                 }
+             }
+         )
+
+     def aggregate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
+         node_feature = tensor.aggregate('message', mode='sum')
+         node_feature = keras.ops.reshape(node_feature, (-1, self.units))
+         return tensor.update(
+             {
+                 'node': {
+                     'aggregate': node_feature
+                 },
+                 'edge': {
+                     'message': None,
+                 }
+             }
+         )
+
+     def get_config(self) -> dict:
+         config = super().get_config()
+         config.update({
+             'heads': self._heads,
+             'update_edge_feature': self._update_edge_feature,
+         })
+         return config
+
+
+ @keras.saving.register_keras_serializable(package='molcraft')
+ class MPConv(GraphConv):
+     """Message passing neural network layer.
+
+     Also supports 3D molecular graphs.
+
+     >>> graph = molcraft.tensors.GraphTensor(
+     ...     context={
+     ...         'size': [2]
+     ...     },
+     ...     node={
+     ...         'feature': [[1.], [2.]]
+     ...     },
+     ...     edge={
+     ...         'source': [0, 1],
+     ...         'target': [1, 0],
+     ...     }
+     ... )
+     >>> conv = molcraft.layers.MPConv(units=4)
+     >>> conv(graph)
+     GraphTensor(
+         context={
+             'size': <tf.Tensor: shape=[1], dtype=int32>
+         },
+         node={
+             'feature': <tf.Tensor: shape=[2, 4], dtype=float32>
+         },
+         edge={
+             'source': <tf.Tensor: shape=[2], dtype=int32>,
+             'target': <tf.Tensor: shape=[2], dtype=int32>
+         }
+     )
+     """
+
+     def __init__(
+         self,
+         units: int = 128,
+         activation: keras.layers.Activation | str | None = 'relu',
+         use_bias: bool = True,
+         normalize: bool = False,
+         skip_connect: bool = True,
+         **kwargs
+     ) -> None:
+         super().__init__(
+             units=units,
+             activation=activation,
+             use_bias=use_bias,
+             normalize=normalize,
+             skip_connect=skip_connect,
+             **kwargs
+         )
+
+     def build(self, spec: tensors.GraphTensor.Spec) -> None:
+         node_feature_dim = spec.node['feature'].shape[-1]
+         self.update_fn = keras.layers.GRUCell(self.units)
+         self._project_previous_node_feature = node_feature_dim != self.units
+         if self._project_previous_node_feature:
+             warnings.warn(
+                 'Input node feature dim does not match the updated node feature dim, '
+                 'which is required for the GRU update. Applying a projection layer to '
+                 'the input node features prior to the GRU update, to match the dim '
+                 'of the updated node features.'
+             )
+             self._previous_node_dense = self.get_dense(self.units)
+
+     def aggregate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
+         """Aggregates messages.
+
+         This method may be overridden by subclass.
+
+         Arguments:
+             tensor:
+                 A `GraphTensor` instance containing a message.
+         """
+         aggregate = tensor.aggregate('message', mode='mean')
+         return tensor.update(
+             {
+                 'node': {
+                     'aggregate': aggregate,
+                 },
+                 'edge': {
+                     'message': None,
+                 }
+             }
+         )
+
+     def update(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
+         previous = tensor.node['feature']
+         aggregate = tensor.node['aggregate']
+         if self._project_previous_node_feature:
+             previous = self._previous_node_dense(previous)
+         updated_node_feature, _ = self.update_fn(
+             inputs=aggregate, states=previous
+         )
+         return tensor.update(
+             {
+                 'node': {
+                     'feature': updated_node_feature,
+                     'aggregate': None,
+                 }
+             }
+         )
+
+     def get_config(self) -> dict:
+         return super().get_config()
+
+
+ @keras.saving.register_keras_serializable(package='molcraft')
+ class GTConv(GraphConv):
+     """Graph transformer layer.
+
+     Also supports 3D molecular graphs.
+
+     >>> graph = molcraft.tensors.GraphTensor(
+     ...     context={
+     ...         'size': [2]
+     ...     },
+     ...     node={
+     ...         'feature': [[1.], [2.]]
+     ...     },
+     ...     edge={
+     ...         'source': [0, 1],
+     ...         'target': [1, 0],
+     ...     }
+     ... )
+     >>> conv = molcraft.layers.GTConv(units=4, heads=2)
+     >>> conv(graph)
+     GraphTensor(
+         context={
+             'size': <tf.Tensor: shape=[1], dtype=int32>
+         },
+         node={
+             'feature': <tf.Tensor: shape=[2, 4], dtype=float32>
+         },
+         edge={
+             'source': <tf.Tensor: shape=[2], dtype=int32>,
+             'target': <tf.Tensor: shape=[2], dtype=int32>
+         }
+     )
+     """
+
+     def __init__(
+         self,
+         units: int,
+         heads: int = 8,
+         activation: keras.layers.Activation | str | None = "relu",
+         use_bias: bool = True,
+         normalize: bool = False,
+         skip_connect: bool = True,
+         attention_dropout: float = 0.0,
+         **kwargs,
+     ) -> None:
+         super().__init__(
+             units=units,
+             activation=activation,
+             normalize=normalize,
+             use_bias=use_bias,
+             skip_connect=skip_connect,
+             **kwargs
+         )
+         self._heads = heads
+         if self.units % self.heads != 0:
+             raise ValueError('`units` needs to be divisible by `heads`.')
+         self._head_units = self.units // self.heads
+         self._attention_dropout = attention_dropout
+
+     @property
+     def heads(self):
+         return self._heads
+
+     @property
+     def head_units(self):
+         return self._head_units
+
+     def build(self, spec: tensors.GraphTensor.Spec) -> None:
+         """Builds the layer."""
+         self._query_dense = self.get_einsum_dense(
+             'ij,jkh->ikh', (self.head_units, self.heads)
+         )
+         self._key_dense = self.get_einsum_dense(
+             'ij,jkh->ikh', (self.head_units, self.heads)
+         )
+         self._value_dense = self.get_einsum_dense(
+             'ij,jkh->ikh', (self.head_units, self.heads)
+         )
+         self._output_dense = self.get_dense(self.units)
+         self._softmax_dropout = keras.layers.Dropout(self._attention_dropout)
+
+         if self.has_edge_feature:
+             self._attention_bias_dense_1 = self.get_einsum_dense(
+                 'ij,jkh->ikh', (1, self.heads)
+             )
+
+         if self.has_node_coordinate:
+             node_feature_dim = spec.node['feature'].shape[-1]
+             num_kernels = self.units
+             self._gaussian_loc = self.add_weight(
+                 shape=[num_kernels], initializer='zeros', dtype='float32', trainable=True
+             )
+             self._gaussian_scale = self.add_weight(
+                 shape=[num_kernels], initializer='ones', dtype='float32', trainable=True
+             )
+             self._centrality_dense = self.get_dense(units=node_feature_dim)
+             self._attention_bias_dense_2 = self.get_einsum_dense(
+                 'ij,jkh->ikh', (1, self.heads)
+             )
+
+     def message(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
+         node_feature = tensor.node['feature']
+
+         if self.has_node_coordinate:
+             euclidean_distance = ops.euclidean_distance(
+                 tensor.gather('coordinate', 'target'),
+                 tensor.gather('coordinate', 'source'),
+                 axis=-1,
+                 keepdims=True
+             )
+             gaussian = ops.gaussian(
+                 euclidean_distance, self._gaussian_loc, self._gaussian_scale
+             )
+             centrality = keras.ops.segment_sum(
+                 gaussian, tensor.edge['target'], tensor.num_nodes
+             )
+             node_feature += self._centrality_dense(centrality)
+
+         query = self._query_dense(node_feature)
+         key = self._key_dense(node_feature)
+         value = self._value_dense(node_feature)
+
+         query = ops.gather(query, tensor.edge['source'])
+         key = ops.gather(key, tensor.edge['target'])
+         value = ops.gather(value, tensor.edge['source'])
+
+         attention_score = keras.ops.sum(query * key, axis=1, keepdims=True)
+         attention_score /= keras.ops.sqrt(float(self.head_units))
+
+         if self.has_edge_feature:
+             attention_score += self._attention_bias_dense_1(tensor.edge['feature'])
+
+         if self.has_node_coordinate:
+             attention_score += self._attention_bias_dense_2(gaussian)
+
+         attention = ops.edge_softmax(attention_score, tensor.edge['target'])
+         attention = self._softmax_dropout(attention)
+
+         if self.has_node_coordinate:
+             displacement = ops.displacement(
+                 tensor.gather('coordinate', 'target'),
+                 tensor.gather('coordinate', 'source'),
+                 normalize=True,
+                 axis=-1,
+                 keepdims=True,
+             )
+             attention *= keras.ops.expand_dims(displacement, axis=-1)
+             attention = keras.ops.expand_dims(attention, axis=2)
+             value = keras.ops.expand_dims(value, axis=1)
+
+         message = ops.edge_weight(value, attention)
+
+         return tensor.update(
+             {
+                 'edge': {
+                     'message': message,
+                 },
+             }
+         )
+
+     def aggregate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
+         node_feature = tensor.aggregate('message', mode='sum')
+         if self.has_node_coordinate:
+             shape = (tensor.num_nodes, -1, self.units)
+         else:
+             shape = (tensor.num_nodes, self.units)
+         node_feature = keras.ops.reshape(node_feature, shape)
+         node_feature = self._output_dense(node_feature)
+         if self.has_node_coordinate:
+             node_feature = keras.ops.sum(node_feature, axis=1)
+         return tensor.update(
+             {
+                 'node': {
+                     'aggregate': node_feature,
+                 },
+                 'edge': {
+                     'message': None,
+                 }
+             }
+         )
+
+     def get_config(self) -> dict:
+         config = super().get_config()
+         config.update({
+             'heads': self._heads,
+             'attention_dropout': self._attention_dropout,
+         })
+         return config
+
+
+ @keras.saving.register_keras_serializable(package='molcraft')
+ class EGConv(GraphConv):
+     """Equivariant graph neural network layer (3D).
+
+     Only supports 3D molecular graphs.
+
+     >>> graph = molcraft.tensors.GraphTensor(
+     ...     context={
+     ...         'size': [2]
+     ...     },
+     ...     node={
+     ...         'feature': [[1.], [2.]],
+     ...         'coordinate': [[0.1, -0.1, 0.5], [1.2, -0.5, 2.1]],
+     ...     },
+     ...     edge={
+     ...         'source': [0, 1],
+     ...         'target': [1, 0],
+     ...     }
+     ... )
+     >>> conv = molcraft.layers.EGConv(units=4)
+     >>> conv(graph)
+     GraphTensor(
+         context={
+             'size': <tf.Tensor: shape=[1], dtype=int32>
+         },
+         node={
+             'feature': <tf.Tensor: shape=[2, 4], dtype=float32>,
+             'coordinate': <tf.Tensor: shape=[2, 3], dtype=float32>
+         },
+         edge={
+             'source': <tf.Tensor: shape=[2], dtype=int32>,
+             'target': <tf.Tensor: shape=[2], dtype=int32>
+         }
+     )
+     """
+
+     def __init__(
+         self,
+         units: int = 128,
+         activation: keras.layers.Activation | str | None = 'silu',
+         use_bias: bool = True,
+         normalize: bool = False,
+         skip_connect: bool = True,
+         **kwargs
+     ) -> None:
+         super().__init__(
+             units=units,
+             activation=activation,
+             use_bias=use_bias,
+             normalize=normalize,
+             skip_connect=skip_connect,
+             **kwargs
+         )
+
+     def build(self, spec: tensors.GraphTensor.Spec) -> None:
+         if not self.has_node_coordinate:
+             raise ValueError(
+                 'Could not find `coordinate` in node, '
+                 'which is required for `EGConv`.'
+             )
+         self._message_feedforward_intermediate = self.get_dense(
+             self.units, activation=self.activation
+         )
+         self._message_feedforward_final = self.get_dense(
+             self.units, activation=self.activation
+         )
+
+         self._coord_feedforward_intermediate = self.get_dense(
+             self.units, activation=self.activation
+         )
+         self._coord_feedforward_final = self.get_dense(
+             1, use_bias=False, activation='tanh'
+         )
+
+     def message(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
+         relative_node_coordinate = keras.ops.subtract(
+             tensor.gather('coordinate', 'target'),
+             tensor.gather('coordinate', 'source')
+         )
+         squared_distance = keras.ops.sum(
+             keras.ops.square(relative_node_coordinate),
+             axis=-1,
+             keepdims=True
+         )
+
+         # For numerical stability (i.e., to prevent NaN losses), this
+         # implementation of `EGConv` either needs to apply a `tanh` activation
+         # to the output of `self._coord_feedforward_final`, or normalize
+         # `relative_node_coordinate` as follows:
+         #
+         #     norm = keras.ops.sqrt(squared_distance) + keras.backend.epsilon()
+         #     relative_node_coordinate /= norm
+         #
+         # For now, this implementation does the former.
+
+         feature = keras.ops.concatenate(
+             [
+                 tensor.gather('feature', 'target'),
+                 tensor.gather('feature', 'source'),
+                 squared_distance,
+             ],
+             axis=-1
+         )
+         if self.has_edge_feature:
+             feature = keras.ops.concatenate(
+                 [
+                     feature,
+                     tensor.edge['feature']
+                 ],
+                 axis=-1
+             )
+         message = self._message_feedforward_final(
+             self._message_feedforward_intermediate(feature)
+         )
+
+         relative_node_coordinate = keras.ops.multiply(
+             relative_node_coordinate,
+             self._coord_feedforward_final(
+                 self._coord_feedforward_intermediate(message)
+             )
+         )
+         return tensor.update(
+             {
+                 'edge': {
+                     'message': message,
+                     'relative_node_coordinate': relative_node_coordinate
+                 }
+             }
+         )
+
+     def aggregate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
+         coordinate = tensor.node['coordinate']
+         coordinate += tensor.aggregate('relative_node_coordinate', mode='mean')
+
+         # The original implementation seems to apply sum aggregation, which
+         # does not seem to work well for this implementation of `EGConv`, as
+         # it causes large output values and large initial losses. The
+         # magnitude of the aggregated values of a sum aggregation depends on
+         # the number of neighbors, which may be many and may differ from node
+         # to node (or graph to graph). Therefore, a mean aggregation is
+         # performed instead:
+         aggregate = tensor.aggregate('message', mode='mean')
+         aggregate = keras.ops.concatenate([aggregate, tensor.node['feature']], axis=-1)
+         # Simply added to silence a warning ('no gradients for variables ...'):
+         aggregate += (0.0 * keras.ops.sum(coordinate))
+
+         return tensor.update(
+             {
+                 'node': {
+                     'aggregate': aggregate,
+                     'coordinate': coordinate,
+                 },
+                 'edge': {
+                     'message': None,
+                     'relative_node_coordinate': None
+                 }
+             }
+         )
+
+     def get_config(self) -> dict:
+         return super().get_config()
+
+
+ @keras.saving.register_keras_serializable(package='molcraft')
+ class Readout(GraphLayer):
+     """Readout layer.
+     """
+
+     def __init__(self, mode: str | None = None, **kwargs):
+         kwargs['kernel_initializer'] = None
+         kwargs['bias_initializer'] = None
+         super().__init__(**kwargs)
+         self.mode = mode
+         if str(self.mode).lower().startswith('sum'):
+             self._reduce_fn = keras.ops.segment_sum
+         elif str(self.mode).lower().startswith('max'):
+             self._reduce_fn = keras.ops.segment_max
+         elif str(self.mode).lower().startswith('super'):
+             self._reduce_fn = keras.ops.segment_sum
+         else:
+             self._reduce_fn = ops.segment_mean
+
+     def propagate(self, tensor: tensors.GraphTensor) -> tf.Tensor:
+         node_feature = tensor.node['feature']
+         if str(self.mode).lower().startswith('super'):
+             node_feature = keras.ops.where(
+                 tensor.node['super'][:, None], node_feature, 0.0
+             )
+         return self._reduce_fn(
+             node_feature, tensor.graph_indicator, tensor.num_subgraphs
+         )
+
+     def get_config(self) -> dict:
+         config = super().get_config()
+         config['mode'] = self.mode
+         return config
+
+
+ @keras.saving.register_keras_serializable(package='molcraft')
+ class NodeEmbedding(GraphLayer):
+     """Node embedding layer.
+
+     Embeds nodes based on their initial features.
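+
+     A minimal usage sketch (illustrative only; assumes `graph` holds raw
+     node features under `node['feature']`):
+
+     >>> embedding = molcraft.layers.NodeEmbedding(dim=16)
+     >>> graph = embedding(graph)  # node features are now 16-dimensional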
+     """
+
+     def __init__(
+         self,
+         dim: int | None = None,
+         intermediate_dim: int | None = None,
+         intermediate_activation: str | keras.layers.Activation | None = 'relu',
+         normalize: bool = False,
+         embed_context: bool = False,
+         num_wildcards: int | None = None,
+         **kwargs
+     ) -> None:
+         super().__init__(**kwargs)
+         self.dim = dim
+         self._intermediate_dim = intermediate_dim
+         self._intermediate_activation = keras.activations.get(
+             intermediate_activation
+         )
+         self._normalize = normalize
+         self._embed_context = embed_context
+         self._num_wildcards = num_wildcards
+
+     def build(self, spec: tensors.GraphTensor.Spec) -> None:
+         feature_dim = spec.node['feature'].shape[-1]
+         if not self.dim:
+             self.dim = feature_dim
+         if not self._intermediate_dim:
+             self._intermediate_dim = self.dim * 2
+         self._node_dense = self.get_dense(
+             self._intermediate_dim, activation=self._intermediate_activation
+         )
+         self._has_wildcard = 'wildcard' in spec.node
+         self._has_super = 'super' in spec.node
+         has_context_feature = 'feature' in spec.context
+         if not has_context_feature:
+             self._embed_context = False
+         if self._has_super and not self._embed_context:
+             self._super_feature = self.get_weight(
+                 shape=[self._intermediate_dim], name='super_node_feature'
+             )
+         if self._embed_context:
+             self._context_dense = self.get_dense(
+                 self._intermediate_dim, activation=self._intermediate_activation
+             )
+         if self._has_wildcard:
+             if self._num_wildcards is None:
+                 warnings.warn(
+                     "Found wildcards in input, but `num_wildcards` is set to `None`. "
+                     "Automatically setting `num_wildcards` to 1. Please specify "
+                     "`num_wildcards > 1` if the layer should distinguish between "
+                     "different types of wildcards."
+                 )
+                 self._num_wildcards = 1
+             self._wildcard_features = self.get_weight(
+                 shape=[self._num_wildcards, self._intermediate_dim],
+                 name='wildcard_node_features'
+             )
+         if not self._normalize:
+             self._norm = keras.layers.Identity()
+         elif str(self._normalize).lower().startswith('layer'):
+             self._norm = keras.layers.LayerNormalization()
+         else:
+             self._norm = keras.layers.BatchNormalization()
+         self._dense = self.get_dense(self.dim)
+
+     def propagate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
+         feature = self._node_dense(tensor.node['feature'])
+
+         if self._has_super and not self._embed_context:
+             super_mask = keras.ops.expand_dims(tensor.node['super'], 1)
+             super_feature = self._intermediate_activation(self._super_feature)
+             feature = keras.ops.where(super_mask, super_feature, feature)
+
+         if self._embed_context:
+             context_feature = self._context_dense(tensor.context['feature'])
+             feature = ops.scatter_update(feature, tensor.node['super'], context_feature)
+             tensor = tensor.update({'context': {'feature': None}})
+
+         if self._has_wildcard:
+             wildcard = tensor.node['wildcard']
+             wildcard_indices = keras.ops.where(wildcard > 0)[0]
+             wildcard = ops.gather(wildcard, wildcard_indices)
+             wildcard_feature = ops.gather(
+                 self._wildcard_features, keras.ops.mod(wildcard - 1, self._num_wildcards)
+             )
+             wildcard_feature = self._intermediate_activation(wildcard_feature)
+             feature = ops.scatter_update(feature, wildcard_indices, wildcard_feature)
+
+         feature = self._norm(feature)
+         feature = self._dense(feature)
+
+         return tensor.update({'node': {'feature': feature}})
+
+     def get_config(self) -> dict:
+         config = super().get_config()
+         config.update({
+             'dim': self.dim,
+             'intermediate_dim': self._intermediate_dim,
+             'intermediate_activation': keras.activations.serialize(
+                 self._intermediate_activation
+             ),
+             'normalize': self._normalize,
+             'embed_context': self._embed_context,
+             'num_wildcards': self._num_wildcards,
+         })
+         return config
+
+
+ @keras.saving.register_keras_serializable(package='molcraft')
+ class EdgeEmbedding(GraphLayer):
+     """Edge embedding layer.
+
+     Embeds edges based on their initial features.
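+
+     A minimal usage sketch (illustrative only; assumes `graph` holds raw
+     edge features under `edge['feature']`):
+
+     >>> embedding = molcraft.layers.EdgeEmbedding(dim=16)
+     >>> graph = embedding(graph)  # edge features are now 16-dimensional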
+     """
+
+     def __init__(
+         self,
+         dim: int | None = None,
+         intermediate_dim: int | None = None,
+         intermediate_activation: str | keras.layers.Activation | None = 'relu',
+         normalize: bool = False,
+         **kwargs
+     ) -> None:
+         super().__init__(**kwargs)
+         self.dim = dim
+         self._intermediate_dim = intermediate_dim
+         self._intermediate_activation = keras.activations.get(
+             intermediate_activation
+         )
+         self._normalize = normalize
+
+     def build(self, spec: tensors.GraphTensor.Spec) -> None:
+         feature_dim = spec.edge['feature'].shape[-1]
+         if not self.dim:
+             self.dim = feature_dim
+         if not self._intermediate_dim:
+             self._intermediate_dim = self.dim * 2
+         self._edge_dense = self.get_dense(
+             self._intermediate_dim, activation=self._intermediate_activation
+         )
+         self._self_loop_feature = self.get_weight(
+             shape=[self._intermediate_dim], name='self_loop_edge_feature'
+         )
+         self._has_super = 'super' in spec.edge
+         if self._has_super:
+             self._super_feature = self.get_weight(
+                 shape=[self._intermediate_dim], name='super_edge_feature'
+             )
+         if not self._normalize:
+             self._norm = keras.layers.Identity()
+         elif str(self._normalize).lower().startswith('layer'):
+             self._norm = keras.layers.LayerNormalization()
+         else:
+             self._norm = keras.layers.BatchNormalization()
+         self._dense = self.get_dense(self.dim)
+
+     def propagate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
+         feature = self._edge_dense(tensor.edge['feature'])
+
+         if self._has_super:
+             super_mask = keras.ops.expand_dims(tensor.edge['super'], 1)
+             super_feature = self._intermediate_activation(self._super_feature)
+             feature = keras.ops.where(super_mask, super_feature, feature)
+
+         self_loop_mask = keras.ops.expand_dims(
+             tensor.edge['source'] == tensor.edge['target'], 1
+         )
+         self_loop_feature = self._intermediate_activation(self._self_loop_feature)
+         feature = keras.ops.where(self_loop_mask, self_loop_feature, feature)
+         feature = self._norm(feature)
+         feature = self._dense(feature)
+         return tensor.update({'edge': {'feature': feature}})
+
+     def get_config(self) -> dict:
+         config = super().get_config()
+         config.update({
+             'dim': self.dim,
+             'intermediate_dim': self._intermediate_dim,
+             'intermediate_activation': keras.activations.serialize(
+                 self._intermediate_activation
+             ),
+             'normalize': self._normalize,
+         })
+         return config
+
+
+ @keras.saving.register_keras_serializable(package='molcraft')
+ class AddContext(GraphLayer):
+     """Context adding layer.
+
+     Adds context to super nodes, or to all nodes if no super nodes exist.
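+
+     A minimal usage sketch (illustrative only; assumes `graph` has a context
+     field named 'feature'):
+
+     >>> add_context = molcraft.layers.AddContext(field='feature')
+     >>> graph = add_context(graph)  # context added to (super) node features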
+     """
+
+     def __init__(
+         self,
+         field: str = 'feature',
+         intermediate_dim: int | None = None,
+         intermediate_activation: str | keras.layers.Activation | None = 'relu',
+         drop: bool = False,
+         normalize: bool = False,
+         num_categories: int | None = None,
+         **kwargs
+     ) -> None:
+         super().__init__(**kwargs)
+         self._field = field
+         self._drop = drop
+         self._intermediate_dim = intermediate_dim
+         self._intermediate_activation = keras.activations.get(
+             intermediate_activation
+         )
+         self._normalize = normalize
+         self._num_categories = num_categories
+
+     def build(self, spec: tensors.GraphTensor.Spec) -> None:
+         is_categorical = spec.context[self._field].dtype.is_integer
+         if is_categorical and not self._num_categories:
+             raise ValueError(
+                 f'Found context ({self._field}) to be categorical (`int` dtype), '
+                 'but `num_categories` to be `None`. Please specify `num_categories` '
+                 'for categorical context.'
+             )
+         elif not is_categorical and self._num_categories:
+             warnings.warn(
+                 f'`num_categories` is set to {self._num_categories}, but found context to be '
+                 'continuous (`float` dtype). Layer will cast context from `float` to `int` '
+                 'before one-hot encoding it.'
+             )
+         feature_dim = spec.node['feature'].shape[-1]
+         self._has_super_node = 'super' in spec.node
+         if self._intermediate_dim is None:
+             self._intermediate_dim = feature_dim * 2
+         self._intermediate_dense = self.get_dense(
+             self._intermediate_dim, activation=self._intermediate_activation
+         )
+         self._final_dense = self.get_dense(feature_dim)
+         if not self._normalize:
+             self._intermediate_norm = keras.layers.Identity()
+         elif str(self._normalize).lower().startswith('layer'):
+             self._intermediate_norm = keras.layers.LayerNormalization()
+         else:
+             self._intermediate_norm = keras.layers.BatchNormalization()
+
+     def propagate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
+         context = tensor.context[self._field]
+         if self._num_categories:
+             if context.dtype.is_floating:
+                 context = keras.ops.cast(context, dtype=tensor.edge['source'].dtype)
+             context = keras.utils.to_categorical(context, self._num_categories)
+         elif len(keras.ops.shape(context)) == 1:
+             context = keras.ops.expand_dims(context, axis=1)
+         context = self._intermediate_dense(context)
+         context = self._intermediate_norm(context)
+         context = self._final_dense(context)
+         if self._has_super_node:
+             node_feature = ops.scatter_add(
+                 tensor.node['feature'], tensor.node['super'], context
+             )
+         else:
+             node_feature = (
+                 tensor.node['feature'] + ops.gather(context, tensor.graph_indicator)
+             )
+         data = {'node': {'feature': node_feature}}
+         if self._drop:
+             data['context'] = {self._field: None}
+         return tensor.update(data)
+
+     def get_config(self) -> dict:
+         config = super().get_config()
+         config.update({
+             'field': self._field,
+             'intermediate_dim': self._intermediate_dim,
+             'intermediate_activation': keras.activations.serialize(
+                 self._intermediate_activation
+             ),
+             'drop': self._drop,
+             'normalize': self._normalize,
+             'num_categories': self._num_categories,
+         })
+         return config
+
+
+ @keras.saving.register_keras_serializable(package='molcraft')
+ class GraphNetwork(GraphLayer):
+     """Graph neural network.
+
+     Sequentially calls graph layers (`GraphLayer`) and concatenates their outputs.
+
+     Arguments:
+         layers (list):
+             A list of graph layers.
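+
+     A minimal usage sketch (illustrative only):
+
+     >>> network = molcraft.layers.GraphNetwork([
+     ...     molcraft.layers.GIConv(units=32),
+     ...     molcraft.layers.GIConv(units=32),
+     ... ])
+     >>> graph = network(graph)  # node features: concatenation of all layer outputs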
1655
+ """
1656
+
1657
+ def __init__(self, layers: list[GraphLayer], **kwargs) -> None:
1658
+ super().__init__(**kwargs)
1659
+ self.layers = layers
1660
+ self._update_edge_feature = False
1661
+
1662
+ def build(self, spec: tensors.GraphTensor.Spec) -> None:
1663
+ units = self.layers[0].units
1664
+ node_feature_dim = spec.node['feature'].shape[-1]
1665
+ self._update_node_feature = node_feature_dim != units
1666
+ if self._update_node_feature:
1667
+ warnings.warn(
1668
+ 'Node feature dim does not match `units` of the first layer. '
1669
+ 'Applying a projection layer to node features to match `units`.',
1670
+ )
1671
+ self._node_dense = self.get_dense(units)
1672
+ self._has_edge_feature = 'feature' in spec.edge
1673
+ if self._has_edge_feature:
1674
+ edge_feature_dim = spec.edge['feature'].shape[-1]
1675
+ self._update_edge_feature = edge_feature_dim != units
1676
+ if self._update_edge_feature:
1677
+ warnings.warn(
1678
+ 'Edge feature dim does not match `units` of the first layer. '
1679
+ 'Applying projection layer to edge features to match `units`.'
1680
+ )
1681
+ self._edge_dense = self.get_dense(units)
1682
+
1683
+ def propagate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
1684
+ x = tensors.to_dict(tensor)
1685
+ if self._update_node_feature:
1686
+ x['node']['feature'] = self._node_dense(tensor.node['feature'])
1687
+ if self._has_edge_feature and self._update_edge_feature:
1688
+ x['edge']['feature'] = self._edge_dense(tensor.edge['feature'])
1689
+ outputs = [x['node']['feature']]
1690
+ for layer in self.layers:
1691
+ x = layer(x)
1692
+ outputs.append(x['node']['feature'])
1693
+ return tensor.update(
1694
+ {
1695
+ 'node': {
1696
+ 'feature': keras.ops.concatenate(outputs, axis=-1)
1697
+ }
1698
+ }
1699
+ )
1700
+
1701
+ def tape_propagate(
1702
+ self,
1703
+ tensor: tensors.GraphTensor,
1704
+ tape: tf.GradientTape,
1705
+ training: bool | None = None,
1706
+ ) -> tuple[tensors.GraphTensor, list[tf.Tensor]]:
1707
+ """Performs the propagation with a `GradientTape`.
1708
+
1709
+ Performs the same forward pass as `propagate` but with a `GradientTape`
1710
+ watching intermediate node features.
1711
+
1712
+ Arguments:
1714
+ tensor (tensors.GraphTensor):
1715
+ The graph input.
+ tape (tf.GradientTape):
+ The gradient tape that will watch the intermediate node features.
+ training (bool, None):
+ Whether the layers should run in training mode.
+
+ Returns:
+ A tuple of the updated graph tensor and the list of watched
+ intermediate node features.
1716
+ """
1716
+ if isinstance(tensor, tensors.GraphTensor):
1717
+ x = tensors.to_dict(tensor)
1718
+ else:
1719
+ x = tensor
1720
+ if self._update_node_feature:
1721
+ x['node']['feature'] = self._node_dense(x['node']['feature'])
1722
+ if self._update_edge_feature:
1723
+ x['edge']['feature'] = self._edge_dense(x['edge']['feature'])
1724
+ tape.watch(x['node']['feature'])
1725
+ outputs = [x['node']['feature']]
1726
+ for layer in self.layers:
1727
+ x = layer(x, training=training)
1728
+ tape.watch(x['node']['feature'])
1729
+ outputs.append(x['node']['feature'])
1730
+
1731
+ node_feature = keras.ops.concatenate(outputs, axis=-1)
1732
+ if isinstance(tensor, tensors.GraphTensor):
1733
+ tensor = tensor.update(
1734
+ {'node': {'feature': node_feature}}
1735
+ )
1736
+ else:
1737
+ # A plain nest of tensors was passed in (e.g. during functional
+ # tracing), so update the nested dict rather than a `GraphTensor`.
+ x['node']['feature'] = node_feature
+ tensor = x
1738
+ return tensor, outputs
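+ # A sketch of gradient-based attribution with `tape_propagate` (assumes a
+ # built `GraphNetwork` `net` and a featurized `graph`):
+ #
+ # with tf.GradientTape() as tape:
+ #     updated_graph, intermediates = net.tape_propagate(graph, tape)
+ #     target = keras.ops.sum(updated_graph.node['feature'])
+ # grads = tape.gradient(target, intermediates)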
1739
+
1740
+ def get_config(self) -> dict:
1741
+ config = super().get_config()
1742
+ config.update(
1743
+ {
1744
+ 'layers': [
1745
+ keras.layers.serialize(layer) for layer in self.layers
1746
+ ]
1747
+ }
1748
+ )
1749
+ return config
1750
+
1751
+ @classmethod
1752
+ def from_config(cls, config: dict) -> 'GraphNetwork':
1753
+ config['layers'] = [
1754
+ keras.layers.deserialize(layer) for layer in config['layers']
1755
+ ]
1756
+ return super().from_config(config)
1757
+
1758
+
1759
+ @keras.saving.register_keras_serializable(package='molcraft')
1760
+ class GaussianParams(keras.layers.Dense):
1761
+ """Gaussian parameters.
1762
+
1763
+ Computes loc and scale via a dense layer. Should be used as
1764
+ the last layer of a model and paired with `losses.GaussianNLL`.
1765
+
1766
+ The loc and scale parameters (resulting from this layer) are concatenated
1767
+ along the last axis, resulting in a single output tensor.
1768
+
1769
+ Arguments:
1770
+ events (int):
1771
+ The number of events. If the model makes a single prediction per example,
1772
+ then the number of events should be 1. If the model makes multiple predictions
1773
+ per example, then the number of events should be greater than 1.
1774
+ Default to 1.
1775
+ kwargs:
1776
+ See `keras.layers.Dense` documentation. `activation` will be applied
1777
+ to `loc` only. `scale` is automatically softplus activated.
1778
+ """
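+ # Shape sketch: with `events` = k, the output has shape (..., 2 * k);
+ # downstream code can split it as `loc, scale = out[..., :k], out[..., k:]`.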
1779
+ def __init__(self, events: int = 1, **kwargs):
1780
+ units = kwargs.pop('units', None)
1781
+ activation = kwargs.pop('activation', None)
1782
+ if units:
1783
+ if units % 2 != 0:
1784
+ raise ValueError(
1785
+ '`units` needs to be divisible by 2 as `units` = 2 x `events`.'
1786
+ )
+ # Keep `events` consistent with an explicitly passed `units`.
+ events = units // 2
1787
+ else:
1788
+ units = int(events * 2)
1789
+ super().__init__(units=units, **kwargs)
1790
+ self.events = events
1791
+ self.loc_activation = keras.activations.get(activation)
1792
+
1793
+ def call(self, inputs, **kwargs):
1794
+ loc_and_scale = super().call(inputs, **kwargs)
1795
+ loc = loc_and_scale[..., :self.events]
1796
+ scale = loc_and_scale[..., self.events:]
1797
+ scale = keras.ops.softplus(scale) + keras.backend.epsilon()
1798
+ loc = self.loc_activation(loc)
1799
+ return keras.ops.concatenate([loc, scale], axis=-1)
1800
+
1801
+ def get_config(self):
1802
+ config = super().get_config()
1803
+ config['events'] = self.events
1804
+ config['units'] = None
1805
+ config['activation'] = keras.activations.serialize(self.loc_activation)
1806
+ return config
1807
+
1808
+
1809
+ def Input(spec: tensors.GraphTensor.Spec) -> dict:
1810
+ """Used to specify inputs to a functional model.
1811
+
1812
+ Example:
1813
+
1814
+ >>> import molcraft
1815
+ >>> import keras
1816
+ >>>
1817
+ >>> featurizer = molcraft.featurizers.MolGraphFeaturizer()
1818
+ >>> graph = featurizer([('N[C@@H](C)C(=O)O', 1.0), ('N[C@@H](CS)C(=O)O', 2.0)])
1819
+ >>>
1820
+ >>> model = molcraft.models.GraphModel.from_layers([
1821
+ ... molcraft.layers.Input(graph.spec),
1822
+ ... molcraft.layers.NodeEmbedding(128),
1823
+ ... molcraft.layers.EdgeEmbedding(128),
1824
+ ... molcraft.layers.GraphTransformer(128),
1825
+ ... molcraft.layers.GraphTransformer(128),
1826
+ ... molcraft.layers.Readout('mean'),
1827
+ ... molcraft.layers.Dense(1)
1828
+ ... ])
1829
+ """
1830
+
1831
+ # Currently, Keras (3.8.0) does not support extension types.
1832
+ # So for now, this function will unpack the `GraphTensor.Spec` and
1833
+ # return a dictionary of nested tensor specs. However, the corresponding
1834
+ # nest of tensors will temporarily be converted to a `GraphTensor` by the
1835
+ # `GraphLayer`, to leverage the utility of a `GraphTensor` object.
1836
+ inputs = {}
1837
+ for outer_field, data in spec.__dict__.items():
1838
+ inputs[outer_field] = {}
1839
+ for inner_field, nested_spec in data.items():
1840
+ if inner_field in ['label', 'sample_weight']:
1841
+ # Remove label and sample_weight from the symbolic input as
1842
+ # a functional model is strict about what inputs can be passed.
1843
+ continue
1844
+ kwargs = {
1845
+ 'shape': nested_spec.shape[1:],
1846
+ 'dtype': nested_spec.dtype,
1847
+ 'name': f'{outer_field}_{inner_field}'
1848
+ }
1849
+ if isinstance(nested_spec, tf.RaggedTensorSpec):
1850
+ # kwargs['ragged'] = True
1851
+ raise ValueError(
1852
+ 'Graph layers only support graph input with nested `tf.Tensor` values.'
1853
+ )
1854
+ try:
1855
+ inputs[outer_field][inner_field] = keras.Input(**kwargs)
1856
+ except TypeError:
1857
+ raise ValueError(
1858
+ "`keras.Input` does not currently support ragged tensors. For now, "
1859
+ "pass the `Spec` of a 'flat' `GraphTensor` to `GNNInput`."
1860
+ )
1861
+ return inputs
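+ # The returned nest mirrors the spec's outer/inner fields, e.g. (a sketch;
+ # the exact fields depend on the featurizer):
+ # {'node': {'feature': <KerasTensor>, ...},
+ #  'edge': {'source': <KerasTensor>, 'target': <KerasTensor>, ...}, ...}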
1862
+
1863
+
1864
+ def _serialize_spec(spec: tensors.GraphTensor.Spec) -> dict:
1865
+ serialized_spec = {}
1866
+ for outer_field, data in spec.__dict__.items():
1867
+ serialized_spec[outer_field] = {}
1868
+ for inner_field, inner_spec in data.items():
1869
+ if inner_field in ['label', 'sample_weight']:
1870
+ continue
1871
+ serialized_spec[outer_field][inner_field] = {
1872
+ 'shape': inner_spec.shape.as_list(),
1873
+ 'dtype': inner_spec.dtype.name,
1874
+ 'name': inner_spec.name,
1875
+ }
1876
+ return serialized_spec
1877
+
1878
+ def _deserialize_spec(serialized_spec: dict) -> tensors.GraphTensor.Spec:
1879
+ deserialized_spec = {}
1880
+ for outer_field, data in serialized_spec.items():
1881
+ deserialized_spec[outer_field] = {}
1882
+ for inner_field, inner_spec in data.items():
1883
+ deserialized_spec[outer_field][inner_field] = tf.TensorSpec(
1884
+ inner_spec['shape'], inner_spec['dtype'], inner_spec['name']
1885
+ )
1886
+ return tensors.GraphTensor.Spec(**deserialized_spec)
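+ # Round-trip sketch: `_deserialize_spec(_serialize_spec(spec))` rebuilds an
+ # equivalent spec, minus any 'label'/'sample_weight' entries, which
+ # `_serialize_spec` deliberately drops.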
1887
+
1888
+ def _spec_from_inputs(inputs):
1889
+ symbolic_inputs = keras.backend.is_keras_tensor(
1890
+ tf.nest.flatten(inputs)[0]
1891
+ )
1892
+ if not symbolic_inputs:
1893
+ nested_specs = tf.nest.map_structure(
1894
+ tf.type_spec_from_value, inputs
1895
+ )
1896
+ else:
1897
+ nested_specs = tf.nest.map_structure(
1898
+ lambda t: tf.TensorSpec(t.shape, t.dtype), inputs
1899
+ )
1900
+ if isinstance(nested_specs, tensors.GraphTensor.Spec):
1901
+ return nested_specs
1903
+ return tensors.GraphTensor.Spec(**nested_specs)
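+ # Concrete tensors yield their exact type specs; symbolic (Keras) inputs
+ # are mapped to dense `tf.TensorSpec`s before building the composite spec.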
1904
+
1905
+ def _propagate_kwargs(func) -> bool:
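+ # Returns True if `func` accepts `**kwargs` or an explicit `training`
+ # parameter, i.e. whether call-time kwargs can be forwarded to it.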
1906
+ signature = inspect.signature(func)
1907
+ return any(
1908
+ (param.kind == inspect.Parameter.VAR_KEYWORD) or (param.name == 'training')
1909
+ for param in signature.parameters.values()
1910
+ )