molcraft 0.1.0rc10__py3-none-any.whl

molcraft/layers.py ADDED
@@ -0,0 +1,2034 @@
+ import warnings
+ import keras
+ import tensorflow as tf
+ import functools
+ import inspect
+ from keras.src.models import functional
+
+ from molcraft import tensors
+ from molcraft import ops
+
+
+ @keras.saving.register_keras_serializable(package='molcraft')
+ class GraphLayer(keras.layers.Layer):
+     """Base graph layer.
+
+     Subclasses must implement a forward pass via **propagate(graph)**.
+
+     Subclasses may create dense layers and weights in **build(graph_spec)**.
+
+     Note: `GraphLayer` currently only supports `GraphTensor` input.
+
+     The arguments listed below are only relevant if the derived layer
+     invokes `get_dense_kwargs`, `get_dense` or `get_einsum_dense`.
+
+     Arguments:
+         use_bias (bool):
+             Whether bias should be used in dense layers. Defaults to `True`.
+         kernel_initializer (keras.initializers.Initializer, str):
+             Initializer for the kernel weight matrix of the dense layers.
+             Defaults to `glorot_uniform`.
+         bias_initializer (keras.initializers.Initializer, str):
+             Initializer for the bias weight vector of the dense layers.
+             Defaults to `zeros`.
+         kernel_regularizer (keras.regularizers.Regularizer, None):
+             Regularizer function applied to the kernel weight matrix.
+             Defaults to `None`.
+         bias_regularizer (keras.regularizers.Regularizer, None):
+             Regularizer function applied to the bias weight vector.
+             Defaults to `None`.
+         activity_regularizer (keras.regularizers.Regularizer, None):
+             Regularizer function applied to the output of the dense layers.
+             Defaults to `None`.
+         kernel_constraint (keras.constraints.Constraint, None):
+             Constraint function applied to the kernel weight matrix.
+             Defaults to `None`.
+         bias_constraint (keras.constraints.Constraint, None):
+             Constraint function applied to the bias weight vector.
+             Defaults to `None`.
+     """
+     def __init__(
+         self,
+         use_bias: bool = True,
+         kernel_initializer: keras.initializers.Initializer | str = "glorot_uniform",
+         bias_initializer: keras.initializers.Initializer | str = "zeros",
+         kernel_regularizer: keras.regularizers.Regularizer | None = None,
+         bias_regularizer: keras.regularizers.Regularizer | None = None,
+         activity_regularizer: keras.regularizers.Regularizer | None = None,
+         kernel_constraint: keras.constraints.Constraint | None = None,
+         bias_constraint: keras.constraints.Constraint | None = None,
+         **kwargs,
+     ) -> None:
+         super().__init__(**kwargs)
+         self._use_bias = use_bias
+         self._kernel_initializer = keras.initializers.get(kernel_initializer)
+         self._bias_initializer = keras.initializers.get(bias_initializer)
+         self._kernel_regularizer = keras.regularizers.get(kernel_regularizer)
+         self._bias_regularizer = keras.regularizers.get(bias_regularizer)
+         self._activity_regularizer = keras.regularizers.get(activity_regularizer)
+         self._kernel_constraint = keras.constraints.get(kernel_constraint)
+         self._bias_constraint = keras.constraints.get(bias_constraint)
+         self._custom_build_config = {}
+         self._propagate_kwargs = _propagate_kwargs(self.propagate)
+         self.built = False
+
+     def __init_subclass__(cls, **kwargs):
+         super().__init_subclass__(**kwargs)
+         subclass_build = cls.build
+
+         @functools.wraps(subclass_build)
+         def build_wrapper(self: GraphLayer, spec: tensors.GraphTensor.Spec | None):
+             GraphLayer.build(self, spec)
+             subclass_build(self, spec)
+             if not self.built and isinstance(self, keras.Model):
+                 symbolic_inputs = Input(spec)
+                 self.built = True
+                 self(symbolic_inputs)
+
+         cls.build = build_wrapper
+
+     def propagate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
+         """Forward pass.
+
+         Must be implemented by subclass.
+
+         Arguments:
+             tensor:
+                 A `GraphTensor` instance.
+         """
+         raise NotImplementedError(
+             'The forward pass of the layer is not implemented. '
+             'Please implement `propagate`.'
+         )
+
+     def build(self, tensor_spec: tensors.GraphTensor.Spec) -> None:
+         """Builds the layer.
+
+         May use built-in methods such as `get_weight`, `get_dense` and `get_einsum_dense`.
+
+         Optionally implemented by subclass.
+
+         Arguments:
+             tensor_spec:
+                 A `GraphTensor.Spec` instance corresponding to the `GraphTensor`
+                 passed to `propagate`.
+         """
+         if isinstance(tensor_spec, tensors.GraphTensor.Spec):
+             self._custom_build_config['spec'] = _serialize_spec(tensor_spec)
+
+     def call(
+         self,
+         graph: dict[str, dict[str, tf.Tensor]],
+         training: bool | None = None,
+     ) -> dict[str, dict[str, tf.Tensor]]:
+         graph_tensor = tensors.from_dict(graph)
+         if not self._propagate_kwargs:
+             outputs = self.propagate(graph_tensor)
+         else:
+             outputs = self.propagate(graph_tensor, training=training)
+         if isinstance(outputs, tensors.GraphTensor):
+             return tensors.to_dict(outputs)
+         return outputs
+
+     def __call__(
+         self,
+         graph: dict[str, dict[str, tf.Tensor]] | tensors.GraphTensor,
+         **kwargs
+     ) -> tf.Tensor | dict[str, dict[str, tf.Tensor]] | tensors.GraphTensor:
+         if not self.built:
+             spec = _spec_from_inputs(graph)
+             self.build(spec)
+
+         is_graph_tensor = isinstance(graph, tensors.GraphTensor)
+         if is_graph_tensor:
+             graph = tensors.to_dict(graph)
+         else:
+             graph = {field: dict(data) for (field, data) in graph.items()}
+
+         if isinstance(self, functional.Functional):
+             # A functional model is strict about what input may be passed
+             # to it, so temporarily pop some of the input and add it back
+             # afterwards.
+             symbolic_input = self._symbolic_input
+             excluded = {}
+             for outer_field in ['context', 'node', 'edge']:
+                 excluded[outer_field] = {}
+                 for inner_field in list(graph[outer_field]):
+                     if inner_field not in symbolic_input[outer_field]:
+                         excluded[outer_field][inner_field] = (
+                             graph[outer_field].pop(inner_field)
+                         )
+
+             tf.nest.assert_same_structure(self.input, graph)
+
+         outputs = super().__call__(graph, **kwargs)
+
+         if not tensors.is_graph(outputs):
+             return outputs
+
+         graph = outputs
+         if isinstance(self, functional.Functional):
+             for outer_field in ['context', 'node', 'edge']:
+                 for inner_field in list(excluded[outer_field]):
+                     graph[outer_field][inner_field] = (
+                         excluded[outer_field].pop(inner_field)
+                     )
+
+         if is_graph_tensor:
+             return tensors.from_dict(graph)
+
+         return graph
+
+     def get_build_config(self) -> dict:
+         if self._custom_build_config:
+             return self._custom_build_config
+         return super().get_build_config()
+
+     def build_from_config(self, config: dict) -> None:
+         serialized_spec = config.get('spec')
+         if serialized_spec is not None:
+             spec = _deserialize_spec(serialized_spec)
+             self.build(spec)
+         else:
+             super().build_from_config(config)
+
+     def get_weight(
+         self,
+         shape: tf.TensorShape,
+         **kwargs,
+     ) -> tf.Variable:
+         common_kwargs = self.get_dense_kwargs()
+         weight_kwargs = {
+             'initializer': common_kwargs['kernel_initializer'],
+             'regularizer': common_kwargs['kernel_regularizer'],
+             'constraint': common_kwargs['kernel_constraint']
+         }
+         weight_kwargs.update(kwargs)
+         return self.add_weight(shape=shape, **weight_kwargs)
+
+     def get_dense(
+         self,
+         units: int,
+         **kwargs
+     ) -> keras.layers.Dense:
+         common_kwargs = self.get_dense_kwargs()
+         common_kwargs.update(kwargs)
+         return keras.layers.Dense(units, **common_kwargs)
+
+     def get_einsum_dense(
+         self,
+         equation: str,
+         output_shape: tf.TensorShape,
+         **kwargs
+     ) -> keras.layers.EinsumDense:
+         common_kwargs = self.get_dense_kwargs()
+         common_kwargs.update(kwargs)
+         use_bias = common_kwargs.pop('use_bias', False)
+         if use_bias and 'bias_axes' not in common_kwargs:
+             # Apply bias to all output axes except the leading (batch) axis,
+             # e.g. 'kh' for the equation 'ij,jkh->ikh'.
+             common_kwargs['bias_axes'] = equation.split('->')[-1][1:] or None
+         return keras.layers.EinsumDense(equation, output_shape, **common_kwargs)
+
+     def get_dense_kwargs(self) -> dict:
+         common_kwargs = dict(
+             use_bias=self._use_bias,
+             kernel_regularizer=self._kernel_regularizer,
+             bias_regularizer=self._bias_regularizer,
+             activity_regularizer=self._activity_regularizer,
+             kernel_constraint=self._kernel_constraint,
+             bias_constraint=self._bias_constraint,
+         )
+         # Create fresh initializer instances so that sibling dense layers
+         # do not share the same initializer object.
+         kernel_initializer = self._kernel_initializer.__class__.from_config(
+             self._kernel_initializer.get_config()
+         )
+         bias_initializer = self._bias_initializer.__class__.from_config(
+             self._bias_initializer.get_config()
+         )
+         common_kwargs["kernel_initializer"] = kernel_initializer
+         common_kwargs["bias_initializer"] = bias_initializer
+         return common_kwargs
+
+     def get_config(self) -> dict:
+         config = super().get_config()
+         config.update({
+             "use_bias": self._use_bias,
+             "kernel_initializer":
+                 keras.initializers.serialize(self._kernel_initializer),
+             "bias_initializer":
+                 keras.initializers.serialize(self._bias_initializer),
+             "kernel_regularizer":
+                 keras.regularizers.serialize(self._kernel_regularizer),
+             "bias_regularizer":
+                 keras.regularizers.serialize(self._bias_regularizer),
+             "activity_regularizer":
+                 keras.regularizers.serialize(self._activity_regularizer),
+             "kernel_constraint":
+                 keras.constraints.serialize(self._kernel_constraint),
+             "bias_constraint":
+                 keras.constraints.serialize(self._bias_constraint),
+         })
+         return config
+
+     @property
+     def _symbolic_output(self) -> dict[str, dict[str, keras.KerasTensor]]:
+         output = self.output
+         if tensors.is_graph(output):
+             # GraphModel
+             return output
+         symbolic_input = self._symbolic_input
+         symbolic_output = self.compute_output_spec(symbolic_input)
+         return tf.nest.pack_sequence_as(symbolic_output, output)
+
+     @property
+     def _symbolic_input(self) -> dict[str, dict[str, keras.KerasTensor]]:
+         input = self.input
+         if tensors.is_graph(input):
+             # GraphModel or initial GraphLayer of GraphModel
+             return input
+         spec = _deserialize_spec(self.get_build_config()['spec'])
+         spec_dict = {k: dict(v) for k, v in spec.__dict__.items()}
+         return tf.nest.pack_sequence_as(spec_dict, input)
+
+     @property
+     def output_spec(self) -> tensors.GraphTensor.Spec | tf.TensorSpec | None:
+         if not self.built:
+             return None
+         if isinstance(self, functional.Functional):
+             output_spec = self.output
+         else:
+             serialized_spec = self.get_build_config()['spec']
+             deserialized_spec = _deserialize_spec(serialized_spec)
+             input_spec = Input(deserialized_spec)
+             output_spec = self.compute_output_spec(input_spec)
+         if not tensors.is_graph(output_spec):
+             return tf.TensorSpec(output_spec.shape, output_spec.dtype)
+         spec_dict = tf.nest.map_structure(
+             lambda t: tf.TensorSpec(t.shape, t.dtype), output_spec
+         )
+         return tensors.GraphTensor.Spec(**spec_dict)
+
+
+ @keras.saving.register_keras_serializable(package='molcraft')
+ class GraphConv(GraphLayer):
+
+     """Base graph neural network layer.
+
+     This layer implements the three basic steps of a graph neural network layer, each of which
+     can (optionally) be overridden by the `GraphConv` subclass:
+
+     1. **message(graph)**, which computes the *messages* to be passed to target nodes;
+     2. **aggregate(graph)**, which *aggregates* messages at target nodes;
+     3. **update(graph)**, which further *updates* the (target) nodes.
+
+     Note: for the skip connection to work, the `GraphConv` subclass requires the final
+     node feature output dimension to equal `units`.
+
+     Arguments:
+         units (int):
+             Dimensionality of the output space.
+         activation (keras.layers.Activation, str, None):
+             Activation function, accessible via `self.activation` and used by the
+             default `message()` and `update()` methods, if not overridden.
+             Defaults to `relu`.
+         use_bias (bool):
+             Whether bias should be used in the dense layers. Defaults to `True`.
+         normalize (bool, str):
+             Whether a normalization layer should be obtained via `get_norm()`.
+             Defaults to `False`.
+         skip_connect (bool):
+             Whether the node feature input should be added to the node feature
+             output. Defaults to `True`.
+         kernel_initializer (keras.initializers.Initializer, str):
+             Initializer for the kernel weight matrix of the dense layers.
+             Defaults to `glorot_uniform`.
+         bias_initializer (keras.initializers.Initializer, str):
+             Initializer for the bias weight vector of the dense layers.
+             Defaults to `zeros`.
+         kernel_regularizer (keras.regularizers.Regularizer, None):
+             Regularizer function applied to the kernel weight matrix.
+             Defaults to `None`.
+         bias_regularizer (keras.regularizers.Regularizer, None):
+             Regularizer function applied to the bias weight vector.
+             Defaults to `None`.
+         activity_regularizer (keras.regularizers.Regularizer, None):
+             Regularizer function applied to the output of the dense layers.
+             Defaults to `None`.
+         kernel_constraint (keras.constraints.Constraint, None):
+             Constraint function applied to the kernel weight matrix.
+             Defaults to `None`.
+         bias_constraint (keras.constraints.Constraint, None):
+             Constraint function applied to the bias weight vector.
+             Defaults to `None`.
+     """
+
+     def __init__(
+         self,
+         units: int | None = None,
+         activation: str | keras.layers.Activation | None = 'relu',
+         use_bias: bool = True,
+         normalize: bool = False,
+         skip_connect: bool = True,
+         **kwargs
+     ) -> None:
+         super().__init__(use_bias=use_bias, **kwargs)
+         self._units = units
+         self._normalize = normalize
+         self._skip_connect = skip_connect
+         self._activation = keras.activations.get(activation)
+
+     def __init_subclass__(cls, **kwargs):
+         super().__init_subclass__(**kwargs)
+         subclass_build = cls.build
+
+         @functools.wraps(subclass_build)
+         def build_wrapper(self, spec):
+             GraphConv.build(self, spec)
+             return subclass_build(self, spec)
+
+         cls.build = build_wrapper
+
+     @property
+     def units(self):
+         return self._units
+
+     @property
+     def activation(self):
+         return self._activation
+
+     def build(self, spec: tensors.GraphTensor.Spec) -> None:
+         if not self.units:
+             raise ValueError(
+                 f'`self.units` needs to be a positive integer. Found: {self.units}.'
+             )
+         node_feature_dim = spec.node['feature'].shape[-1]
+         self._project_residual = (
+             self._skip_connect and (node_feature_dim != self.units)
+         )
+         if self._project_residual:
+             warnings.warn(
+                 'Found incompatible dims between input and output. Applying '
+                 'a projection layer to the residual to match input and output dims.',
+             )
+             self._residual_dense = self.get_dense(
+                 self.units, name='residual_dense'
+             )
+
+         self.has_edge_feature = 'feature' in spec.edge
+         self.has_node_coordinate = 'coordinate' in spec.node
+
+         has_overridden_message = self.__class__.message != GraphConv.message
+         if not has_overridden_message:
+             self._message_intermediate_dense = self.get_dense(self.units)
+             self._message_norm = self.get_norm()
+             self._message_intermediate_activation = self.activation
+             self._message_final_dense = self.get_dense(self.units)
+
+         has_overridden_aggregate = self.__class__.aggregate != GraphConv.aggregate
+         if not has_overridden_aggregate:
+             pass
+
+         has_overridden_update = self.__class__.update != GraphConv.update
+         if not has_overridden_update:
+             self._update_intermediate_dense = self.get_dense(self.units)
+             self._update_norm = self.get_norm()
+             self._update_intermediate_activation = self.activation
+             self._update_final_dense = self.get_dense(self.units)
+
+     def propagate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
+         """Forward pass.
+
+         Invokes `message(graph)`, `aggregate(graph)` and `update(graph)` in sequence.
+
+         Arguments:
+             tensor:
+                 A `GraphTensor` instance.
+         """
+         if self._skip_connect:
+             residual = tensor.node['feature']
+             if self._project_residual:
+                 residual = self._residual_dense(residual)
+
+         message = self.message(tensor)
+         add_message = not isinstance(message, tensors.GraphTensor)
+         if add_message:
+             message = tensor.update({'edge': {'message': message}})
+         elif 'message' not in message.edge:
+             raise ValueError('Could not find `message` in `edge` output.')
+
+         aggregate = self.aggregate(message)
+         add_aggregate = not isinstance(aggregate, tensors.GraphTensor)
+         if add_aggregate:
+             aggregate = tensor.update({'node': {'aggregate': aggregate}})
+         elif 'aggregate' not in aggregate.node:
+             raise ValueError('Could not find `aggregate` in `node` output.')
+
+         update = self.update(aggregate)
+         if not isinstance(update, tensors.GraphTensor):
+             update = tensor.update({'node': {'feature': update}})
+         elif 'feature' not in update.node:
+             raise ValueError('Could not find `feature` in `node` output.')
+
+         if update.node['feature'].shape[-1] != self.units:
+             raise ValueError('Updated node `feature` dim is not equal to `self.units`.')
+
+         if add_message and add_aggregate:
+             update = update.update({'node': {'aggregate': None}, 'edge': {'message': None}})
+         elif add_message:
+             update = update.update({'edge': {'message': None}})
+         elif add_aggregate:
+             update = update.update({'node': {'aggregate': None}})
+
+         if not self._skip_connect:
+             return update
+
+         feature = update.node['feature'] + residual
+         return update.update({'node': {'feature': feature}})
+
+     def message(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
+         """Computes messages.
+
+         This method may be overridden by subclass.
+
+         Arguments:
+             tensor:
+                 The inputted `GraphTensor` instance.
+         """
+         message = keras.ops.concatenate(
+             [
+                 tensor.gather('feature', 'source'),
+                 tensor.gather('feature', 'target'),
+             ],
+             axis=-1
+         )
+         if self.has_edge_feature:
+             message = keras.ops.concatenate(
+                 [
+                     message,
+                     tensor.edge['feature']
+                 ],
+                 axis=-1
+             )
+         if self.has_node_coordinate:
+             euclidean_distance = ops.euclidean_distance(
+                 tensor.gather('coordinate', 'target'),
+                 tensor.gather('coordinate', 'source'),
+                 axis=-1,
+                 keepdims=True
+             )
+             message = keras.ops.concatenate(
+                 [
+                     message,
+                     euclidean_distance
+                 ],
+                 axis=-1
+             )
+         message = self._message_intermediate_dense(message)
+         message = self._message_norm(message)
+         message = self._message_intermediate_activation(message)
+         message = self._message_final_dense(message)
+         return tensor.update({'edge': {'message': message}})
+
+     def aggregate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
+         """Aggregates messages.
+
+         This method may be overridden by subclass.
+
+         Arguments:
+             tensor:
+                 A `GraphTensor` instance containing a message.
+         """
+         previous = tensor.node['feature']
+         aggregate = tensor.aggregate('message', mode='mean')
+         aggregate = keras.ops.concatenate([aggregate, previous], axis=-1)
+         return tensor.update(
+             {
+                 'node': {
+                     'aggregate': aggregate,
+                 },
+                 'edge': {
+                     'message': None,
+                 }
+             }
+         )
+
+     def update(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
+         """Updates nodes.
+
+         This method may be overridden by subclass.
+
+         Arguments:
+             tensor:
+                 A `GraphTensor` instance containing aggregated messages
+                 (updated node features).
+         """
+         aggregate = tensor.node['aggregate']
+         node_feature = self._update_intermediate_dense(aggregate)
+         node_feature = self._update_norm(node_feature)
+         node_feature = self._update_intermediate_activation(node_feature)
+         node_feature = self._update_final_dense(node_feature)
+         return tensor.update(
+             {
+                 'node': {
+                     'feature': node_feature,
+                     'aggregate': None,
+                 },
+             }
+         )
+
+     def get_norm(self, **kwargs):
+         if not self._normalize:
+             return keras.layers.Identity()
+         elif str(self._normalize).lower().startswith('layer'):
+             return keras.layers.LayerNormalization(**kwargs)
+         else:
+             return keras.layers.BatchNormalization(**kwargs)
+
+     def get_config(self) -> dict:
+         config = super().get_config()
+         config.update({
+             'units': self.units,
+             'activation': keras.activations.serialize(self._activation),
+             'normalize': self._normalize,
+             'skip_connect': self._skip_connect,
+         })
+         return config
+
+
+ @keras.saving.register_keras_serializable(package='molcraft')
+ class GIConv(GraphConv):
+
+     """Graph isomorphism network layer.
+
+     >>> graph = molcraft.tensors.GraphTensor(
+     ...     context={
+     ...         'size': [2]
+     ...     },
+     ...     node={
+     ...         'feature': [[1.], [2.]]
+     ...     },
+     ...     edge={
+     ...         'source': [0, 1],
+     ...         'target': [1, 0],
+     ...     }
+     ... )
+     >>> conv = molcraft.layers.GIConv(units=4)
+     >>> conv(graph)
+     GraphTensor(
+       context={
+         'size': <tf.Tensor: shape=[1], dtype=int32>
+       },
+       node={
+         'feature': <tf.Tensor: shape=[2, 4], dtype=float32>
+       },
+       edge={
+         'source': <tf.Tensor: shape=[2], dtype=int32>,
+         'target': <tf.Tensor: shape=[2], dtype=int32>
+       }
+     )
+     """
+
+     def __init__(
+         self,
+         units: int,
+         activation: keras.layers.Activation | str | None = 'relu',
+         use_bias: bool = True,
+         normalize: bool = False,
+         skip_connect: bool = True,
+         update_edge_feature: bool = True,
+         **kwargs,
+     ):
+         super().__init__(
+             units=units,
+             activation=activation,
+             use_bias=use_bias,
+             normalize=normalize,
+             skip_connect=skip_connect,
+             **kwargs
+         )
+         self._update_edge_feature = update_edge_feature
+
+     def build(self, spec: tensors.GraphTensor.Spec) -> None:
+         node_feature_dim = spec.node['feature'].shape[-1]
+
+         self.epsilon = self.add_weight(
+             name='epsilon',
+             shape=(),
+             initializer='zeros',
+             trainable=True,
+         )
+
+         if self.has_edge_feature:
+             edge_feature_dim = spec.edge['feature'].shape[-1]
+
+             if not self._update_edge_feature:
+                 if edge_feature_dim != node_feature_dim:
+                     warnings.warn(
+                         'Found edge and node feature dims to be incompatible. Applying a '
+                         'projection layer to the edge features to match the dim of the node features.',
+                     )
+                     self._update_edge_feature = True
+
+             if self._update_edge_feature:
+                 self._edge_dense = self.get_dense(node_feature_dim)
+         else:
+             self._update_edge_feature = False
+
+     def message(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
+         message = tensor.gather('feature', 'source')
+         edge_feature = tensor.edge.get('feature')
+         if self._update_edge_feature:
+             edge_feature = self._edge_dense(edge_feature)
+         if self.has_edge_feature:
+             message += edge_feature
+         message = keras.ops.relu(message)
+         return tensor.update(
+             {
+                 'edge': {
+                     'message': message,
+                     'feature': edge_feature
+                 }
+             }
+         )
+
+     def aggregate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
+         node_feature = tensor.aggregate('message', mode='mean')
+         node_feature += (1 + self.epsilon) * tensor.node['feature']
+         return tensor.update(
+             {
+                 'node': {
+                     'aggregate': node_feature,
+                 },
+                 'edge': {
+                     'message': None,
+                 }
+             }
+         )
+
+     def get_config(self) -> dict:
+         config = super().get_config()
+         config.update({
+             'update_edge_feature': self._update_edge_feature
+         })
+         return config
+
+
+ @keras.saving.register_keras_serializable(package='molcraft')
+ class GAConv(GraphConv):
+
+     """Graph attention network layer.
+
+     >>> graph = molcraft.tensors.GraphTensor(
+     ...     context={
+     ...         'size': [2]
+     ...     },
+     ...     node={
+     ...         'feature': [[1.], [2.]]
+     ...     },
+     ...     edge={
+     ...         'source': [0, 1],
+     ...         'target': [1, 0],
+     ...     }
+     ... )
+     >>> conv = molcraft.layers.GAConv(units=4, heads=2)
+     >>> conv(graph)
+     GraphTensor(
+       context={
+         'size': <tf.Tensor: shape=[1], dtype=int32>
+       },
+       node={
+         'feature': <tf.Tensor: shape=[2, 4], dtype=float32>
+       },
+       edge={
+         'source': <tf.Tensor: shape=[2], dtype=int32>,
+         'target': <tf.Tensor: shape=[2], dtype=int32>
+       }
+     )
+     """
+
+     def __init__(
+         self,
+         units: int,
+         heads: int = 8,
+         activation: keras.layers.Activation | str | None = "relu",
+         use_bias: bool = True,
+         normalize: bool = False,
+         skip_connect: bool = True,
+         update_edge_feature: bool = True,
+         **kwargs,
+     ) -> None:
+         super().__init__(
+             units=units,
+             activation=activation,
+             normalize=normalize,
+             use_bias=use_bias,
+             skip_connect=skip_connect,
+             **kwargs
+         )
+         self._heads = heads
+         if self.units % self.heads != 0:
+             raise ValueError("`units` needs to be divisible by `heads`.")
+         self._head_units = self.units // self.heads
+         self._update_edge_feature = update_edge_feature
+
+     @property
+     def heads(self):
+         return self._heads
+
+     @property
+     def head_units(self):
+         return self._head_units
+
+     def build(self, spec: tensors.GraphTensor.Spec) -> None:
+         self._update_edge_feature = self.has_edge_feature and self._update_edge_feature
+         if self._update_edge_feature:
+             self._edge_dense = self.get_einsum_dense(
+                 'ijh,jkh->ikh', (self.head_units, self.heads)
+             )
+         self._node_dense = self.get_einsum_dense(
+             'ij,jkh->ikh', (self.head_units, self.heads)
+         )
+         self._feature_dense = self.get_einsum_dense(
+             'ij,jkh->ikh', (self.head_units, self.heads)
+         )
+         self._attention_dense = self.get_einsum_dense(
+             'ijh,jkh->ikh', (1, self.heads)
+         )
+
+     def message(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
+         attention_feature = keras.ops.concatenate(
+             [
+                 tensor.gather('feature', 'source'),
+                 tensor.gather('feature', 'target')
+             ],
+             axis=-1
+         )
+         if self.has_edge_feature:
+             attention_feature = keras.ops.concatenate(
+                 [
+                     attention_feature,
+                     tensor.edge['feature']
+                 ],
+                 axis=-1
+             )
+
+         attention_feature = self._feature_dense(attention_feature)
+
+         edge_feature = tensor.edge.get('feature')
+
+         if self._update_edge_feature:
+             edge_feature = self._edge_dense(attention_feature)
+             edge_feature = keras.ops.reshape(edge_feature, (-1, self.units))
+
+         attention_feature = keras.ops.leaky_relu(attention_feature)
+         attention_score = self._attention_dense(attention_feature)
+         attention_score = ops.edge_softmax(
+             score=attention_score, edge_target=tensor.edge['target']
+         )
+         node_feature = self._node_dense(tensor.node['feature'])
+         message = ops.gather(node_feature, tensor.edge['source'])
+         message = ops.edge_weight(message, attention_score)
+         return tensor.update(
+             {
+                 'edge': {
+                     'message': message,
+                     'feature': edge_feature,
+                 }
+             }
+         )
+
+     def aggregate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
+         node_feature = tensor.aggregate('message', mode='sum')
+         node_feature = keras.ops.reshape(node_feature, (-1, self.units))
+         return tensor.update(
+             {
+                 'node': {
+                     'aggregate': node_feature
+                 },
+                 'edge': {
+                     'message': None,
+                 }
+             }
+         )
+
+     def get_config(self) -> dict:
+         config = super().get_config()
+         config.update({
+             "heads": self._heads,
+             'update_edge_feature': self._update_edge_feature,
+         })
+         return config
+
+
+ @keras.saving.register_keras_serializable(package='molcraft')
+ class MPConv(GraphConv):
+
+     """Message passing neural network layer.
+
+     Also supports 3D molecular graphs.
+
+     >>> graph = molcraft.tensors.GraphTensor(
+     ...     context={
+     ...         'size': [2]
+     ...     },
+     ...     node={
+     ...         'feature': [[1.], [2.]]
+     ...     },
+     ...     edge={
+     ...         'source': [0, 1],
+     ...         'target': [1, 0],
+     ...     }
+     ... )
+     >>> conv = molcraft.layers.MPConv(units=4)
+     >>> conv(graph)
+     GraphTensor(
+       context={
+         'size': <tf.Tensor: shape=[1], dtype=int32>
+       },
+       node={
+         'feature': <tf.Tensor: shape=[2, 4], dtype=float32>
+       },
+       edge={
+         'source': <tf.Tensor: shape=[2], dtype=int32>,
+         'target': <tf.Tensor: shape=[2], dtype=int32>
+       }
+     )
+     """
+
+     def __init__(
+         self,
+         units: int = 128,
+         activation: keras.layers.Activation | str | None = 'relu',
+         use_bias: bool = True,
+         normalize: bool = False,
+         skip_connect: bool = True,
+         **kwargs
+     ) -> None:
+         super().__init__(
+             units=units,
+             activation=activation,
+             use_bias=use_bias,
+             normalize=normalize,
+             skip_connect=skip_connect,
+             **kwargs
+         )
+
+     def build(self, spec: tensors.GraphTensor.Spec) -> None:
+         node_feature_dim = spec.node['feature'].shape[-1]
+         self.update_fn = keras.layers.GRUCell(self.units)
+         self._project_previous_node_feature = node_feature_dim != self.units
+         if self._project_previous_node_feature:
+             warnings.warn(
+                 'Input node feature dim does not match the updated node feature dim, '
+                 'which is required for the GRU update. Applying a projection layer to '
+                 'the input node features prior to the GRU update, to match the dim '
+                 'of the updated node features.'
+             )
+             self._previous_node_dense = self.get_dense(self.units)
+
+     def aggregate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
+         """Aggregates messages.
+
+         This method may be overridden by subclass.
+
+         Arguments:
+             tensor:
+                 A `GraphTensor` instance containing a message.
+         """
+         aggregate = tensor.aggregate('message', mode='mean')
+         return tensor.update(
+             {
+                 'node': {
+                     'aggregate': aggregate,
+                 },
+                 'edge': {
+                     'message': None,
+                 }
+             }
+         )
+
+     def update(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
+         previous = tensor.node['feature']
+         aggregate = tensor.node['aggregate']
+         if self._project_previous_node_feature:
+             previous = self._previous_node_dense(previous)
+         updated_node_feature, _ = self.update_fn(
+             inputs=aggregate, states=previous
+         )
+         return tensor.update(
+             {
+                 'node': {
+                     'feature': updated_node_feature,
+                     'aggregate': None,
+                 }
+             }
+         )
+
+     def get_config(self) -> dict:
+         config = super().get_config()
+         config.update({})
+         return config
+
+
+ @keras.saving.register_keras_serializable(package='molcraft')
+ class GTConv(GraphConv):
+
+     """Graph transformer layer.
+
+     Also supports 3D molecular graphs.
+
+     >>> graph = molcraft.tensors.GraphTensor(
+     ...     context={
+     ...         'size': [2]
+     ...     },
+     ...     node={
+     ...         'feature': [[1.], [2.]]
+     ...     },
+     ...     edge={
+     ...         'source': [0, 1],
+     ...         'target': [1, 0],
+     ...     }
+     ... )
+     >>> conv = molcraft.layers.GTConv(units=4, heads=2)
+     >>> conv(graph)
+     GraphTensor(
+       context={
+         'size': <tf.Tensor: shape=[1], dtype=int32>
+       },
+       node={
+         'feature': <tf.Tensor: shape=[2, 4], dtype=float32>
+       },
+       edge={
+         'source': <tf.Tensor: shape=[2], dtype=int32>,
+         'target': <tf.Tensor: shape=[2], dtype=int32>,
+       }
+     )
+     """
+
+     def __init__(
+         self,
+         units: int,
+         heads: int = 8,
+         activation: keras.layers.Activation | str | None = "relu",
+         use_bias: bool = True,
+         normalize: bool = False,
+         skip_connect: bool = True,
+         attention_dropout: float = 0.0,
+         **kwargs,
+     ) -> None:
+         super().__init__(
+             units=units,
+             activation=activation,
+             normalize=normalize,
+             use_bias=use_bias,
+             skip_connect=skip_connect,
+             **kwargs
+         )
+         self._heads = heads
+         if self.units % self.heads != 0:
+             raise ValueError("`units` needs to be divisible by `heads`.")
+         self._head_units = self.units // self.heads
+         self._attention_dropout = attention_dropout
+
+     @property
+     def heads(self):
+         return self._heads
+
+     @property
+     def head_units(self):
+         return self._head_units
+
+     def build(self, spec: tensors.GraphTensor.Spec) -> None:
+         """Builds the layer."""
+         self._query_dense = self.get_einsum_dense(
+             'ij,jkh->ikh', (self.head_units, self.heads)
+         )
+         self._key_dense = self.get_einsum_dense(
+             'ij,jkh->ikh', (self.head_units, self.heads)
+         )
+         self._value_dense = self.get_einsum_dense(
+             'ij,jkh->ikh', (self.head_units, self.heads)
+         )
+         self._output_dense = self.get_dense(self.units)
+         self._softmax_dropout = keras.layers.Dropout(self._attention_dropout)
+
+         if self.has_edge_feature:
+             self._attention_bias_dense_1 = self.get_einsum_dense('ij,jkh->ikh', (1, self.heads))
+
+         if self.has_node_coordinate:
+             node_feature_dim = spec.node['feature'].shape[-1]
+             num_kernels = self.units
+             self._gaussian_loc = self.add_weight(
+                 shape=[num_kernels], initializer='zeros', dtype='float32', trainable=True
+             )
+             self._gaussian_scale = self.add_weight(
+                 shape=[num_kernels], initializer='ones', dtype='float32', trainable=True
+             )
+             self._centrality_dense = self.get_dense(units=node_feature_dim)
+             self._attention_bias_dense_2 = self.get_einsum_dense('ij,jkh->ikh', (1, self.heads))
+
+     def message(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
+         node_feature = tensor.node['feature']
+
+         if self.has_node_coordinate:
+             euclidean_distance = ops.euclidean_distance(
+                 tensor.gather('coordinate', 'target'),
+                 tensor.gather('coordinate', 'source'),
+                 axis=-1,
+                 keepdims=True
+             )
+             gaussian = ops.gaussian(
+                 euclidean_distance, self._gaussian_loc, self._gaussian_scale
+             )
+             centrality = keras.ops.segment_sum(gaussian, tensor.edge['target'], tensor.num_nodes)
+             node_feature += self._centrality_dense(centrality)
+
+         query = self._query_dense(node_feature)
+         key = self._key_dense(node_feature)
+         value = self._value_dense(node_feature)
+
+         query = ops.gather(query, tensor.edge['source'])
+         key = ops.gather(key, tensor.edge['target'])
+         value = ops.gather(value, tensor.edge['source'])
+
+         attention_score = keras.ops.sum(query * key, axis=1, keepdims=True)
+         attention_score /= keras.ops.sqrt(float(self.head_units))
+
+         if self.has_edge_feature:
+             attention_score += self._attention_bias_dense_1(tensor.edge['feature'])
+
+         if self.has_node_coordinate:
+             attention_score += self._attention_bias_dense_2(gaussian)
+
+         attention = ops.edge_softmax(attention_score, tensor.edge['target'])
+         attention = self._softmax_dropout(attention)
+
+         if self.has_node_coordinate:
+             displacement = ops.displacement(
+                 tensor.gather('coordinate', 'target'),
+                 tensor.gather('coordinate', 'source'),
+                 normalize=True,
+                 axis=-1,
+                 keepdims=True,
+             )
+             attention *= keras.ops.expand_dims(displacement, axis=-1)
+             attention = keras.ops.expand_dims(attention, axis=2)
+             value = keras.ops.expand_dims(value, axis=1)
+
+         message = ops.edge_weight(value, attention)
+
+         return tensor.update(
+             {
+                 'edge': {
+                     'message': message,
+                 },
+             }
+         )
+
+     def aggregate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
+         node_feature = tensor.aggregate('message', mode='sum')
+         if self.has_node_coordinate:
+             shape = (tensor.num_nodes, -1, self.units)
+         else:
+             shape = (tensor.num_nodes, self.units)
+         node_feature = keras.ops.reshape(node_feature, shape)
+         node_feature = self._output_dense(node_feature)
+         if self.has_node_coordinate:
+             node_feature = keras.ops.sum(node_feature, axis=1)
+         return tensor.update(
+             {
+                 'node': {
+                     'aggregate': node_feature,
+                 },
+                 'edge': {
+                     'message': None,
+                 }
+             }
+         )
+
+     def get_config(self) -> dict:
+         config = super().get_config()
+         config.update({
+             "heads": self._heads,
+             'attention_dropout': self._attention_dropout,
+         })
+         return config
+
+
+ @keras.saving.register_keras_serializable(package='molcraft')
+ class EGConv(GraphConv):
+
+     """Equivariant graph neural network layer (3D).
+
+     Only supports 3D molecular graphs.
+
+     >>> graph = molcraft.tensors.GraphTensor(
+     ...     context={
+     ...         'size': [2]
+     ...     },
+     ...     node={
+     ...         'feature': [[1.], [2.]],
+     ...         'coordinate': [[0.1, -0.1, 0.5], [1.2, -0.5, 2.1]],
+     ...     },
+     ...     edge={
+     ...         'source': [0, 1],
+     ...         'target': [1, 0],
+     ...     }
+     ... )
+     >>> conv = molcraft.layers.EGConv(units=4)
+     >>> conv(graph)
+     GraphTensor(
+       context={
+         'size': <tf.Tensor: shape=[1], dtype=int32>
+       },
+       node={
+         'feature': <tf.Tensor: shape=[2, 4], dtype=float32>,
+         'coordinate': <tf.Tensor: shape=[2, 3], dtype=float32>
+       },
+       edge={
+         'source': <tf.Tensor: shape=[2], dtype=int32>,
+         'target': <tf.Tensor: shape=[2], dtype=int32>
+       }
+     )
+     """
+
+     def __init__(
+         self,
+         units: int = 128,
+         activation: keras.layers.Activation | str | None = 'silu',
+         use_bias: bool = True,
+         normalize: bool = False,
+         skip_connect: bool = True,
+         **kwargs
+     ) -> None:
+         super().__init__(
+             units=units,
+             activation=activation,
+             use_bias=use_bias,
+             normalize=normalize,
+             skip_connect=skip_connect,
+             **kwargs
+         )
+
+     def build(self, spec: tensors.GraphTensor.Spec) -> None:
+         if not self.has_node_coordinate:
+             raise ValueError(
+                 'Could not find `coordinate` in node, '
+                 'which is required for `EGConv` layers.'
+             )
+         self._message_feedforward_intermediate = self.get_dense(
+             self.units, activation=self.activation
+         )
+         self._message_feedforward_final = self.get_dense(
+             self.units, activation=self.activation
+         )
+
+         self._coord_feedforward_intermediate = self.get_dense(
+             self.units, activation=self.activation
+         )
+         self._coord_feedforward_final = self.get_dense(
+             1, use_bias=False, activation='tanh'
+         )
+
+     def message(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
+         relative_node_coordinate = keras.ops.subtract(
+             tensor.gather('coordinate', 'target'),
+             tensor.gather('coordinate', 'source')
+         )
+         squared_distance = keras.ops.sum(
+             keras.ops.square(relative_node_coordinate),
+             axis=-1,
+             keepdims=True
+         )
+
+         # For numerical stability (i.e., to prevent NaN losses), this implementation of `EGConv3D`
+         # either needs to apply a `tanh` activation to the output of `self._coord_feedforward_final`,
+         # or normalize `relative_node_coordinate` as follows:
+         #
+         #     norm = keras.ops.sqrt(squared_distance) + keras.backend.epsilon()
+         #     relative_node_coordinate /= norm
+         #
+         # For now, this implementation does the former.
+
+         feature = keras.ops.concatenate(
+             [
+                 tensor.gather('feature', 'target'),
+                 tensor.gather('feature', 'source'),
+                 squared_distance,
+             ],
+             axis=-1
+         )
+         if self.has_edge_feature:
+             feature = keras.ops.concatenate(
+                 [
+                     feature,
+                     tensor.edge['feature']
+                 ],
+                 axis=-1
+             )
+         message = self._message_feedforward_final(
+             self._message_feedforward_intermediate(feature)
+         )
+
+         relative_node_coordinate = keras.ops.multiply(
+             relative_node_coordinate,
+             self._coord_feedforward_final(
+                 self._coord_feedforward_intermediate(message)
+             )
+         )
+         return tensor.update(
+             {
+                 'edge': {
+                     'message': message,
+                     'relative_node_coordinate': relative_node_coordinate
+                 }
+             }
+         )
+
+     def aggregate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
+         coordinate = tensor.node['coordinate']
+         coordinate += tensor.aggregate('relative_node_coordinate', mode='mean')
+
+         # The original implementation seems to apply a sum aggregation, which
+         # does not seem to work well for this implementation of `EGConv3D`, as
+         # it causes large output values and large initial losses. The magnitude
+         # of the aggregated values of a sum aggregation depends on the number
+         # of neighbors, which may be many and may differ from node to node (or
+         # graph to graph). Therefore, a mean aggregation is performed instead:
+         aggregate = tensor.aggregate('message', mode='mean')
+         aggregate = keras.ops.concatenate([aggregate, tensor.node['feature']], axis=-1)
+         # Simply added to silence a warning ('no gradients for variables ...'):
+         aggregate += (0.0 * keras.ops.sum(coordinate))
+
+         return tensor.update(
+             {
+                 'node': {
+                     'aggregate': aggregate,
+                     'coordinate': coordinate,
+                 },
+                 'edge': {
+                     'message': None,
+                     'relative_node_coordinate': None
+                 }
+             }
+         )
+
+     def get_config(self) -> dict:
+         config = super().get_config()
+         config.update({})
+         return config
+
+
+ @keras.saving.register_keras_serializable(package='molcraft')
+ class Readout(GraphLayer):
+
+     """Readout layer.
+
+     Reduces (pools) node features into a single feature vector per subgraph,
+     using the reduction specified by `mode` ('sum', 'max', or 'mean' (default)).
+     """
+
+     def __init__(self, mode: str | None = None, **kwargs):
+         kwargs['kernel_initializer'] = None
+         kwargs['bias_initializer'] = None
+         super().__init__(**kwargs)
+         self.mode = mode
+         mode = str(self.mode).lower()
+         if mode.startswith('sum'):
+             self._reduce_fn = keras.ops.segment_sum
+         elif mode.startswith('max'):
+             self._reduce_fn = keras.ops.segment_max
+         else:
+             self._reduce_fn = ops.segment_mean
+
+     def propagate(self, tensor: tensors.GraphTensor) -> tf.Tensor:
+         return self._reduce_fn(
+             tensor.node['feature'], tensor.graph_indicator, tensor.num_subgraphs
+         )
+
+     def get_config(self) -> dict:
+         config = super().get_config()
+         config['mode'] = self.mode
+         return config
+
+
+ @keras.saving.register_keras_serializable(package='molcraft')
+ class SuperReadout(GraphLayer):
+
+     """Super-node readout layer.
+
+     Reads out the feature vector of the super node of each subgraph.
+     """
+
+     def build(self, spec: tensors.GraphTensor.Spec) -> None:
+         if 'super' not in spec.node:
+             raise ValueError(
+                 'Could not find `super` field in input.'
+             )
+
+     def propagate(self, tensor: tensors.GraphTensor) -> tf.Tensor:
+         node_feature = tensor.node['feature']
+         node_feature = keras.ops.where(
+             tensor.node['super'][:, None], node_feature, 0.0
+         )
+         return keras.ops.segment_sum(
+             node_feature, tensor.graph_indicator, tensor.num_subgraphs
+         )
+
+
+ @keras.saving.register_keras_serializable(package='molcraft')
+ class SubgraphReadout(GraphLayer):
+
+     """Subgraph readout layer.
+
+     Reads out one feature vector per subgraph, based on the
+     `subgraph_indicator` field of the nodes.
+     """
+
+     def __init__(
+         self,
+         pad: bool = True,
+         add_mask: bool | None = None,
+         ignore_super_node: bool = True,
+         **kwargs
+     ) -> None:
+         super().__init__(**kwargs)
+         self._pad = pad
+         self._ignore_super_node = ignore_super_node
+         self._add_mask = (
+             add_mask if add_mask is not None else self._pad
+         )
+         if self._add_mask:
+             self._readout_mask = _ReadoutMask()
+
+     def build(self, spec: tensors.GraphTensor.Spec) -> None:
+         if 'subgraph_indicator' not in spec.node:
+             raise ValueError(
+                 'Could not find `subgraph_indicator` field in input.'
+             )
+         self._has_super = 'super' in spec.node
+
+     def propagate(self, tensor: tensors.GraphTensor) -> tf.Tensor:
+
+         size = tensor.context['size']
+         graph_indicator = tensor.graph_indicator
+         subgraph_indicator = tensor.node['subgraph_indicator']
+         node_feature = tensor.node['feature']
+
+         if self._has_super:
+             if not self._ignore_super_node:
+                 super_node_feature = tf.boolean_mask(
+                     node_feature, tensor.node['super']
+                 )
+             keep = keras.ops.logical_not(tensor.node['super'])
+             graph_indicator = tf.boolean_mask(graph_indicator, keep)
+             subgraph_indicator = tf.boolean_mask(subgraph_indicator, keep)
+             node_feature = tf.boolean_mask(node_feature, keep)
+             size -= 1
+
+         num_subgraphs = keras.ops.segment_max(
+             subgraph_indicator, graph_indicator
+         )
+         num_subgraphs += 1
+
+         def global_subgraph_indicator():
+             # Offset each graph's subgraph indices so that they are unique
+             # across the whole batch.
+             incr = keras.ops.cumsum(num_subgraphs[:-1])
+             incr = keras.ops.pad(incr, [(1, 0)])
+             incr = keras.ops.repeat(incr, size)
+             return subgraph_indicator + incr
+
+         readout = ops.segment_mean(node_feature, global_subgraph_indicator())
+         readout = tf.RaggedTensor.from_row_lengths(readout, num_subgraphs)
+
+         if self._has_super and not self._ignore_super_node:
+             readout += super_node_feature[:, None, :]
+
+         if not self._pad:
+             return readout
+
+         return readout.to_tensor()
+
+     def compute_mask(self, inputs, previous_mask=None):
+         if not self._add_mask:
+             return None
+         return self._readout_mask(inputs)
+
+     def get_config(self) -> dict:
+         config = super().get_config()
+         config['pad'] = self._pad
+         config['add_mask'] = self._add_mask
+         config['ignore_super_node'] = self._ignore_super_node
+         return config
+
+
+ @keras.saving.register_keras_serializable(package='molcraft')
+ class NodeEmbedding(GraphLayer):
+
+     """Node embedding layer.
+
+     Embeds nodes based on their initial features.
+     """
+
+     def __init__(
+         self,
+         dim: int | None = None,
+         intermediate_dim: int | None = None,
+         intermediate_activation: str | keras.layers.Activation | None = 'relu',
+         normalize: bool = False,
+         embed_context: bool = False,
+         num_wildcards: int | None = None,
+         **kwargs
+     ) -> None:
+         super().__init__(**kwargs)
+         self.dim = dim
+         self._intermediate_dim = intermediate_dim
+         self._intermediate_activation = keras.activations.get(
+             intermediate_activation
+         )
+         self._normalize = normalize
+         self._embed_context = embed_context
+         self._num_wildcards = num_wildcards
+
+     def build(self, spec: tensors.GraphTensor.Spec) -> None:
+         feature_dim = spec.node['feature'].shape[-1]
+         if not self.dim:
+             self.dim = feature_dim
+         if not self._intermediate_dim:
+             self._intermediate_dim = self.dim * 2
+         self._node_dense = self.get_dense(
+             self._intermediate_dim, activation=self._intermediate_activation
+         )
+         self._has_wildcard = 'wildcard' in spec.node
+         self._has_super = 'super' in spec.node
+         has_context_feature = 'feature' in spec.context
+         if not has_context_feature:
+             self._embed_context = False
+         if self._has_super and not self._embed_context:
+             self._super_feature = self.get_weight(
+                 shape=[self._intermediate_dim], name='super_node_feature'
+             )
+         if self._embed_context:
+             self._context_dense = self.get_dense(
+                 self._intermediate_dim, activation=self._intermediate_activation
+             )
+         if self._has_wildcard:
+             if self._num_wildcards is None:
+                 warnings.warn(
+                     "Found wildcards in input, but `num_wildcards` is set to `None`. "
+                     "Automatically setting `num_wildcards` to 1. Please specify `num_wildcards>1` "
+                     "if the layer should distinguish between different types of wildcards."
+                 )
+                 self._num_wildcards = 1
+             self._wildcard_features = self.get_weight(
+                 shape=[self._num_wildcards, self._intermediate_dim],
+                 name='wildcard_node_features'
+             )
+         if not self._normalize:
+             self._norm = keras.layers.Identity()
+         elif str(self._normalize).lower().startswith('layer'):
+             self._norm = keras.layers.LayerNormalization()
+         else:
+             self._norm = keras.layers.BatchNormalization()
+         self._dense = self.get_dense(self.dim)
+
+     def propagate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
+         feature = self._node_dense(tensor.node['feature'])
+
+         if self._has_super and not self._embed_context:
+             super_mask = keras.ops.expand_dims(tensor.node['super'], 1)
+             super_feature = self._intermediate_activation(self._super_feature)
+             feature = keras.ops.where(super_mask, super_feature, feature)
+
+         if self._embed_context:
+             context_feature = self._context_dense(tensor.context['feature'])
+             feature = ops.scatter_update(feature, tensor.node['super'], context_feature)
+             tensor = tensor.update({'context': {'feature': None}})
+
+         if self._has_wildcard:
+             wildcard = tensor.node['wildcard']
+             wildcard_indices = keras.ops.where(wildcard > 0)[0]
+             wildcard = ops.gather(wildcard, wildcard_indices)
+             wildcard_feature = ops.gather(
+                 self._wildcard_features, keras.ops.mod(wildcard - 1, self._num_wildcards)
+             )
+             wildcard_feature = self._intermediate_activation(wildcard_feature)
+             feature = ops.scatter_update(feature, wildcard_indices, wildcard_feature)
+
+         feature = self._norm(feature)
+         feature = self._dense(feature)
+
+         return tensor.update({'node': {'feature': feature}})
+
+     def get_config(self) -> dict:
+         config = super().get_config()
+         config.update({
+             'dim': self.dim,
+             'intermediate_dim': self._intermediate_dim,
+             'intermediate_activation': keras.activations.serialize(
+                 self._intermediate_activation
+             ),
+             'normalize': self._normalize,
+             'embed_context': self._embed_context,
+             'num_wildcards': self._num_wildcards,
+         })
+         return config
+
+
+ @keras.saving.register_keras_serializable(package='molcraft')
+ class EdgeEmbedding(GraphLayer):
+
+     """Edge embedding layer.
+
+     Embeds edges based on their initial features.
+     """
+
+     def __init__(
+         self,
+         dim: int | None = None,
+         intermediate_dim: int | None = None,
+         intermediate_activation: str | keras.layers.Activation | None = 'relu',
+         normalize: bool = False,
+         **kwargs
+     ) -> None:
+         super().__init__(**kwargs)
+         self.dim = dim
+         self._intermediate_dim = intermediate_dim
+         self._intermediate_activation = keras.activations.get(
+             intermediate_activation
+         )
+         self._normalize = normalize
+
+     def build(self, spec: tensors.GraphTensor.Spec) -> None:
+         feature_dim = spec.edge['feature'].shape[-1]
+         if not self.dim:
+             self.dim = feature_dim
+         if not self._intermediate_dim:
+             self._intermediate_dim = self.dim * 2
+         self._edge_dense = self.get_dense(
+             self._intermediate_dim, activation=self._intermediate_activation
+         )
+         self._self_loop_feature = self.get_weight(
+             shape=[self._intermediate_dim], name='self_loop_edge_feature'
+         )
+         self._has_super = 'super' in spec.edge
+         if self._has_super:
+             self._super_feature = self.get_weight(
+                 shape=[self._intermediate_dim], name='super_edge_feature'
+             )
+         if not self._normalize:
+             self._norm = keras.layers.Identity()
+         elif str(self._normalize).lower().startswith('layer'):
+             self._norm = keras.layers.LayerNormalization()
+         else:
+             self._norm = keras.layers.BatchNormalization()
+         self._dense = self.get_dense(self.dim)
+
+     def propagate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
+         feature = self._edge_dense(tensor.edge['feature'])
+
+         if self._has_super:
+             super_mask = keras.ops.expand_dims(tensor.edge['super'], 1)
+             super_feature = self._intermediate_activation(self._super_feature)
+             feature = keras.ops.where(super_mask, super_feature, feature)
+
+         self_loop_mask = keras.ops.expand_dims(tensor.edge['source'] == tensor.edge['target'], 1)
+         self_loop_feature = self._intermediate_activation(self._self_loop_feature)
+         feature = keras.ops.where(self_loop_mask, self_loop_feature, feature)
+         feature = self._norm(feature)
+         feature = self._dense(feature)
+         return tensor.update({'edge': {'feature': feature}})
+
+     def get_config(self) -> dict:
+         config = super().get_config()
+         config.update({
+             'dim': self.dim,
+             'intermediate_dim': self._intermediate_dim,
+             'intermediate_activation': keras.activations.serialize(
+                 self._intermediate_activation
+             ),
+             'normalize': self._normalize,
+         })
+         return config
+
+
+ @keras.saving.register_keras_serializable(package='molcraft')
+ class AddContext(GraphLayer):
+
+     """Context adding layer.
+
+     Adds context to the super nodes, or to all nodes if no super nodes exist.
+     """
+
+     def __init__(
+         self,
+         field: str = 'feature',
+         intermediate_dim: int | None = None,
+         intermediate_activation: str | keras.layers.Activation | None = 'relu',
+         drop: bool = False,
+         normalize: bool = False,
+         num_categories: int | None = None,
+         **kwargs
+     ) -> None:
+         super().__init__(**kwargs)
+         self._field = field
+         self._drop = drop
+         self._intermediate_dim = intermediate_dim
+         self._intermediate_activation = keras.activations.get(
+             intermediate_activation
+         )
+         self._normalize = normalize
+         self._num_categories = num_categories
+
+     def build(self, spec: tensors.GraphTensor.Spec) -> None:
+         is_categorical = spec.context[self._field].dtype.is_integer
+         if is_categorical and not self._num_categories:
+             raise ValueError(
+                 f'Found context ({self._field}) to be categorical (`int` dtype), but '
+                 '`num_categories` to be `None`. Please specify `num_categories` for '
+                 'categorical context.'
+             )
+         elif not is_categorical and self._num_categories:
+             warnings.warn(
+                 f'`num_categories` is set to {self._num_categories}, but found context to be '
+                 'continuous (`float` dtype). Layer will cast context from `float` to `int` '
+                 'before one-hot encoding it.'
+             )
+         feature_dim = spec.node['feature'].shape[-1]
+         self._has_super_node = 'super' in spec.node
+         if self._intermediate_dim is None:
+             self._intermediate_dim = feature_dim * 2
+         self._intermediate_dense = self.get_dense(
+             self._intermediate_dim, activation=self._intermediate_activation
+         )
+         self._final_dense = self.get_dense(feature_dim)
+         if not self._normalize:
+             self._intermediate_norm = keras.layers.Identity()
+         elif str(self._normalize).lower().startswith('layer'):
+             self._intermediate_norm = keras.layers.LayerNormalization()
+         else:
+             self._intermediate_norm = keras.layers.BatchNormalization()
+
+     def propagate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
+         context = tensor.context[self._field]
+         if self._num_categories:
+             if context.dtype.is_floating:
+                 context = keras.ops.cast(context, dtype=tensor.edge['source'].dtype)
+             context = keras.utils.to_categorical(context, self._num_categories)
+         elif len(keras.ops.shape(context)) == 1:
+             context = keras.ops.expand_dims(context, axis=1)
+         context = self._intermediate_dense(context)
+         context = self._intermediate_norm(context)
+         context = self._final_dense(context)
+         if self._has_super_node:
+             node_feature = ops.scatter_add(
+                 tensor.node['feature'], tensor.node['super'], context
+             )
+         else:
+             node_feature = (
+                 tensor.node['feature'] + ops.gather(context, tensor.graph_indicator)
+             )
+         data = {'node': {'feature': node_feature}}
+         if self._drop:
+             data['context'] = {self._field: None}
+         return tensor.update(data)
+
+     def get_config(self) -> dict:
+         config = super().get_config()
+         config.update({
+             'field': self._field,
+             'intermediate_dim': self._intermediate_dim,
+             'intermediate_activation': keras.activations.serialize(
+                 self._intermediate_activation
+             ),
+             'drop': self._drop,
+             'normalize': self._normalize,
+             'num_categories': self._num_categories,
+         })
+         return config
1735
+
1736
+
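For intuition, a self-contained sketch (TF backend, toy shapes, no super nodes) of the `else` branch above: the per-graph context row is gathered once per node via the graph indicator and added to the node features.

import tensorflow as tf

# Hypothetical batch of two graphs: nodes 0-2 belong to graph 0, nodes 3-4 to graph 1.
graph_indicator = tf.constant([0, 0, 0, 1, 1])
node_feature = tf.zeros((5, 4))

# One (already projected) context row per graph.
context = tf.constant([[1., 1., 1., 1.],
                       [2., 2., 2., 2.]])

# Every node receives the context row of the graph it belongs to.
node_feature += tf.gather(context, graph_indicator)
print(node_feature)  # first three rows are ones, last two rows are twos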
1737
+ @keras.saving.register_keras_serializable(package='molcraft')
1738
+ class GraphNetwork(GraphLayer):
1739
+
1740
+ """Graph neural network.
1741
+
1742
+ Sequentially calls graph layers (`GraphLayer`) and concatenates their node feature outputs.
1743
+
1744
+ Arguments:
1745
+ layers (list):
1746
+ A list of graph layers.
1747
+ """
1748
+
1749
+ def __init__(self, layers: list[GraphLayer], **kwargs) -> None:
1750
+ super().__init__(**kwargs)
1751
+ self.layers = layers
1752
+ self._update_edge_feature = False
1753
+
1754
+ def build(self, spec: tensors.GraphTensor.Spec) -> None:
1755
+ units = self.layers[0].units
1756
+ node_feature_dim = spec.node['feature'].shape[-1]
1757
+ self._update_node_feature = node_feature_dim != units
1758
+ if self._update_node_feature:
1759
+ warnings.warn(
1760
+ 'Node feature dim does not match `units` of the first layer. '
1761
+ 'Applying a projection layer to node features to match `units`.',
1762
+ )
1763
+ self._node_dense = self.get_dense(units)
1764
+ self._has_edge_feature = 'feature' in spec.edge
1765
+ if self._has_edge_feature:
1766
+ edge_feature_dim = spec.edge['feature'].shape[-1]
1767
+ self._update_edge_feature = edge_feature_dim != units
1768
+ if self._update_edge_feature:
1769
+ warnings.warn(
1770
+ 'Edge feature dim does not match `units` of the first layer. '
1771
+ 'Applying projection layer to edge features to match `units`.'
1772
+ )
1773
+ self._edge_dense = self.get_dense(units)
1774
+
1775
+ def propagate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
1776
+ x = tensors.to_dict(tensor)
1777
+ if self._update_node_feature:
1778
+ x['node']['feature'] = self._node_dense(tensor.node['feature'])
1779
+ if self._has_edge_feature and self._update_edge_feature:
1780
+ x['edge']['feature'] = self._edge_dense(tensor.edge['feature'])
1781
+ outputs = [x['node']['feature']]
1782
+ for layer in self.layers:
1783
+ x = layer(x)
1784
+ outputs.append(x['node']['feature'])
1785
+ return tensor.update(
1786
+ {
1787
+ 'node': {
1788
+ 'feature': keras.ops.concatenate(outputs, axis=-1)
1789
+ }
1790
+ }
1791
+ )
1792
+
1793
+ def tape_propagate(
1794
+ self,
1795
+ tensor: tensors.GraphTensor,
1796
+ tape: tf.GradientTape,
1797
+ training: bool | None = None,
1798
+ ) -> tuple[tensors.GraphTensor, list[tf.Tensor]]:
1799
+ """Performs the propagation with a `GradientTape`.
1800
+
1801
+ Performs the same forward pass as `propagate` but with a `GradientTape`
1802
+ watching intermediate node features.
1803
+
1804
+ Arguments:
1805
+ tensor (tensors.GraphTensor):
1806
+ The graph input.
1807
+ """
1808
+ if isinstance(tensor, tensors.GraphTensor):
1809
+ x = tensors.to_dict(tensor)
1810
+ else:
1811
+ x = tensor
1812
+ if self._update_node_feature:
1813
+ x['node']['feature'] = self._node_dense(x['node']['feature'])
1814
+ if self._has_edge_feature and self._update_edge_feature:
1815
+ x['edge']['feature'] = self._edge_dense(x['edge']['feature'])
1816
+ tape.watch(x['node']['feature'])
1817
+ outputs = [x['node']['feature']]
1818
+ for layer in self.layers:
1819
+ x = layer(x, training=training)
1820
+ tape.watch(x['node']['feature'])
1821
+ outputs.append(x['node']['feature'])
1822
+
1823
+ tensor = tensor.update(
1824
+ {
1825
+ 'node': {
1826
+ 'feature': keras.ops.concatenate(outputs, axis=-1)
1827
+ }
1828
+ }
1829
+ )
1830
+ return tensor, outputs
1831
+
1832
+ def get_config(self) -> dict:
1833
+ config = super().get_config()
1834
+ config.update(
1835
+ {
1836
+ 'layers': [
1837
+ keras.layers.serialize(layer) for layer in self.layers
1838
+ ]
1839
+ }
1840
+ )
1841
+ return config
1842
+
1843
+ @classmethod
1844
+ def from_config(cls, config: dict) -> 'GraphNetwork':
1845
+ config['layers'] = [
1846
+ keras.layers.deserialize(layer) for layer in config['layers']
1847
+ ]
1848
+ return super().from_config(config)
1849
+
1850
+
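The concatenation in `propagate` is a jumping-knowledge-style readout: the projected input and every layer's node features are kept and joined along the feature axis. A reduced sketch with plain dense layers standing in for graph layers:

import keras

units, num_layers = 8, 3
# Stand-ins for graph layers; each maps node features to `units` dims.
layers = [keras.layers.Dense(units, activation='relu') for _ in range(num_layers)]

x = keras.random.normal((5, units))  # 5 nodes, already projected to `units`
outputs = [x]
for layer in layers:
    x = layer(x)
    outputs.append(x)

# The final node representation concatenates the input and all layer outputs.
final = keras.ops.concatenate(outputs, axis=-1)
print(final.shape)  # (5, units * (num_layers + 1)) == (5, 32)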
1851
+ @keras.saving.register_keras_serializable(package='molcraft')
1852
+ class GaussianParams(keras.layers.Dense):
1853
+ """Gaussian parameters.
1854
+
1855
+ Computes loc and scale via a dense layer. Should be used
1856
+ as the last layer in a model and paired with `losses.GaussianNLL`.
1857
+
1858
+ The resulting loc and scale parameters are concatenated
1859
+ along the last axis into a single output tensor.
1860
+
1861
+ Arguments:
1862
+ events (int):
1863
+ The number of events. If the model makes a single prediction per example,
1864
+ then the number of events should be 1. If the model makes multiple predictions
1865
+ per example, then the number of events should be greater than 1.
1866
+ Default to 1.
1867
+ kwargs:
1868
+ See `keras.layers.Dense` documentation. `activation` will be applied
1869
+ to `loc` only. `scale` is automatically softplus activated.
1870
+ """
1871
+ def __init__(self, events: int = 1, **kwargs):
1872
+ units = kwargs.pop('units', None)
1873
+ activation = kwargs.pop('activation', None)
1874
+ if units:
1875
+ if units % 2 != 0:
1876
+ raise ValueError(
1877
+ '`units` needs to be divisible by 2, as `units` = 2 x `events`.'
1878
+ )
1879
+ else:
1880
+ units = int(events * 2)
1881
+ super().__init__(units=units, **kwargs)
1882
+ self.events = events
1883
+ self.loc_activation = keras.activations.get(activation)
1884
+
1885
+ def call(self, inputs, **kwargs):
1886
+ loc_and_scale = super().call(inputs, **kwargs)
1887
+ loc = loc_and_scale[..., :self.events]
1888
+ scale = loc_and_scale[..., self.events:]
1889
+ scale = keras.ops.softplus(scale) + keras.backend.epsilon()
1890
+ loc = self.loc_activation(loc)
1891
+ return keras.ops.concatenate([loc, scale], axis=-1)
1892
+
1893
+ def get_config(self):
1894
+ config = super().get_config()
1895
+ config['events'] = self.events
1896
+ config['units'] = None
1897
+ config['activation'] = keras.activations.serialize(self.loc_activation)
1898
+ return config
1899
+
1900
+
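A quick numeric sketch (toy values, not the layer itself) of the split performed in `call`: the dense output is halved into `loc` and `scale`, and `scale` is made strictly positive with a softplus plus a small epsilon.

import keras

events = 2
# Pretend this is the raw dense output for one example: 2 * events units.
loc_and_scale = keras.ops.convert_to_tensor([[1.0, -1.0, 0.0, -3.0]])

loc = loc_and_scale[..., :events]
scale = loc_and_scale[..., events:]
# Softplus keeps the scale positive; epsilon guards against exact zeros.
scale = keras.ops.softplus(scale) + keras.backend.epsilon()

print(keras.ops.concatenate([loc, scale], axis=-1))
# loc passes through unchanged (no activation here); scale is softplus-activated.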
1901
+ def Input(spec: tensors.GraphTensor.Spec) -> dict:
1902
+ """Used to specify inputs to a functional model.
1903
+
1904
+ Example:
1905
+
1906
+ >>> import molcraft
1907
+ >>> import keras
1908
+ >>>
1909
+ >>> featurizer = molcraft.featurizers.MolGraphFeaturizer()
1910
+ >>> graph = featurizer([('N[C@@H](C)C(=O)O', 1.0), ('N[C@@H](CS)C(=O)O', 2.0)])
1911
+ >>>
1912
+ >>> model = molcraft.models.GraphModel.from_layers([
1913
+ ... molcraft.layers.Input(graph.spec),
1914
+ ... molcraft.layers.NodeEmbedding(128),
1915
+ ... molcraft.layers.EdgeEmbedding(128),
1916
+ ... molcraft.layers.GraphTransformer(128),
1917
+ ... molcraft.layers.GraphTransformer(128),
1918
+ ... molcraft.layers.Readout('mean'),
1919
+ ... molcraft.layers.Dense(1)
1920
+ ... ])
1921
+ """
1922
+
1923
+ # Currently, Keras (3.8.0) does not support extension types.
1924
+ # So for now, this function will unpack the `GraphTensor.Spec` and
1925
+ # return a dictionary of nested tensor specs. However, the corresponding
1926
+ # nest of tensors will temporarily be converted to a `GraphTensor` by the
1927
+ # `GraphLayer`, to leverage the utility of a `GraphTensor` object.
1928
+ inputs = {}
1929
+ for outer_field, data in spec.__dict__.items():
1930
+ inputs[outer_field] = {}
1931
+ for inner_field, nested_spec in data.items():
1932
+ if inner_field in ['label', 'sample_weight']:
1933
+ # Remove label and sample_weight from the symbolic input as
1934
+ # a functional model is strict about what input can be passed.
1935
+ continue
1936
+ kwargs = {
1937
+ 'shape': nested_spec.shape[1:],
1938
+ 'dtype': nested_spec.dtype,
1939
+ 'name': f'{outer_field}_{inner_field}'
1940
+ }
1941
+ if isinstance(nested_spec, tf.RaggedTensorSpec):
1942
+ # kwargs['ragged'] = True
1943
+ raise ValueError(
1944
+ 'Graph layers only support graph input with nested `tf.Tensor` values.'
1945
+ )
1946
+ try:
1947
+ inputs[outer_field][inner_field] = keras.Input(**kwargs)
1948
+ except TypeError:
1949
+ raise ValueError(
1950
+ "`keras.Input` does not currently support ragged tensors. For now, "
1951
+ "pass the `Spec` of a 'flat' `GraphTensor` to `GNNInput`."
1952
+ )
1953
+ return inputs
1954
+
1955
+
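The returned structure is a plain nested dict of symbolic `keras.Input` tensors, one per (outer, inner) field. A reduced sketch of the unpacking, assuming a hypothetical nest of `tf.TensorSpec`s in place of a real `GraphTensor.Spec`:

import keras
import tensorflow as tf

# Hypothetical flattened spec: outer field -> inner field -> TensorSpec.
nested_specs = {
    'node': {'feature': tf.TensorSpec([None, 16], tf.float32)},
    'edge': {'source': tf.TensorSpec([None], tf.int64)},
}

inputs = {}
for outer_field, data in nested_specs.items():
    inputs[outer_field] = {}
    for inner_field, nested_spec in data.items():
        inputs[outer_field][inner_field] = keras.Input(
            shape=tuple(nested_spec.shape[1:]),  # drop the leading (batch-like) axis
            dtype=nested_spec.dtype.name,
            name=f'{outer_field}_{inner_field}',
        )

print(inputs['node']['feature'].shape)  # (None, 16)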
1956
+ class _ReadoutMask(GraphLayer):
1957
+
1958
+ def build(self, spec: tensors.GraphTensor.Spec) -> None:
1959
+ if 'subgraph_indicator' not in spec.node:
1960
+ raise ValueError(
1961
+ 'Could not find `subgraph_indicator` field in input.'
1962
+ )
1963
+ self._has_super = 'super' in spec.node
1964
+
1965
+ def propagate(self, tensor: tensors.GraphTensor) -> tf.Tensor:
1966
+
1967
+ size = tensor.context['size']
1968
+ graph_indicator = tensor.graph_indicator
1969
+ subgraph_indicator = tensor.node['subgraph_indicator']
1970
+ node_feature = tensor.node['feature']
1971
+
1972
+ if self._has_super:
1973
+ keep = keras.ops.logical_not(tensor.node['super'])
1974
+ graph_indicator = tf.boolean_mask(graph_indicator, keep)
1975
+ subgraph_indicator = tf.boolean_mask(subgraph_indicator, keep)
1976
+ node_feature = tf.boolean_mask(node_feature, keep)
1977
+ size -= 1
1978
+
1979
+ num_subgraphs = keras.ops.segment_max(
1980
+ subgraph_indicator, graph_indicator
1981
+ )
1982
+ num_subgraphs += 1
1983
+ max_len = keras.ops.max(num_subgraphs)
1984
+ mask = tf.sequence_mask(num_subgraphs, maxlen=max_len)
1985
+ return mask
1986
+
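A concrete sketch of the mask construction (TF backend assumed, since `tf.sequence_mask` and `tf.boolean_mask` are used above): per-graph subgraph counts come from a segment max over subgraph indicators, and `tf.sequence_mask` turns the counts into a padded boolean mask. `num_segments` is passed explicitly here to keep the toy example unambiguous.

import keras
import tensorflow as tf

# Hypothetical: graph 0 has subgraphs {0, 1, 2}; graph 1 has subgraphs {0, 1}.
graph_indicator = tf.constant([0, 0, 0, 0, 1, 1, 1])
subgraph_indicator = tf.constant([0, 0, 1, 2, 0, 1, 1])

# The highest subgraph id per graph; +1 turns the max id into a count.
num_subgraphs = keras.ops.segment_max(
    subgraph_indicator, graph_indicator, num_segments=2
) + 1
max_len = keras.ops.max(num_subgraphs)

mask = tf.sequence_mask(num_subgraphs, maxlen=max_len)
print(mask.numpy())
# [[ True  True  True]
#  [ True  True False]]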
1987
+
1988
+ def _serialize_spec(spec: tensors.GraphTensor.Spec) -> dict:
1989
+ serialized_spec = {}
1990
+ for outer_field, data in spec.__dict__.items():
1991
+ serialized_spec[outer_field] = {}
1992
+ for inner_field, inner_spec in data.items():
1993
+ if inner_field in ['label', 'sample_weight']:
1994
+ continue
1995
+ serialized_spec[outer_field][inner_field] = {
1996
+ 'shape': inner_spec.shape.as_list(),
1997
+ 'dtype': inner_spec.dtype.name,
1998
+ 'name': inner_spec.name,
1999
+ }
2000
+ return serialized_spec
2001
+
2002
+ def _deserialize_spec(serialized_spec: dict) -> tensors.GraphTensor.Spec:
2003
+ deserialized_spec = {}
2004
+ for outer_field, data in serialized_spec.items():
2005
+ deserialized_spec[outer_field] = {}
2006
+ for inner_field, inner_spec in data.items():
2007
+ deserialized_spec[outer_field][inner_field] = tf.TensorSpec(
2008
+ inner_spec['shape'], inner_spec['dtype'], inner_spec['name']
2009
+ )
2010
+ return tensors.GraphTensor.Spec(**deserialized_spec)
2011
+
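These two helpers are inverses of each other. A round-trip sketch of the per-field logic (minus the nesting) for a single `tf.TensorSpec`:

import tensorflow as tf

spec = tf.TensorSpec([None, 16], tf.float32, name='feature')

# Serialize to plain, JSON-friendly values...
serialized = {
    'shape': spec.shape.as_list(),
    'dtype': spec.dtype.name,
    'name': spec.name,
}
# ...and reconstruct an equivalent spec from them.
restored = tf.TensorSpec(serialized['shape'], serialized['dtype'], serialized['name'])
print(restored == spec)  # True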
2012
+ def _spec_from_inputs(inputs):
2013
+ symbolic_inputs = keras.backend.is_keras_tensor(
2014
+ tf.nest.flatten(inputs)[0]
2015
+ )
2016
+ if not symbolic_inputs:
2017
+ nested_specs = tf.nest.map_structure(
2018
+ tf.type_spec_from_value, inputs
2019
+ )
2020
+ else:
2021
+ nested_specs = tf.nest.map_structure(
2022
+ lambda t: tf.TensorSpec(t.shape, t.dtype), inputs
2023
+ )
2024
+ if isinstance(nested_specs, tensors.GraphTensor.Spec):
2025
+ spec = nested_specs
2026
+ return spec
2027
+ return tensors.GraphTensor.Spec(**nested_specs)
2028
+
2029
+ def _propagate_kwargs(func) -> bool:
2030
+ signature = inspect.signature(func)
2031
+ return any(
2032
+ (param.kind == inspect.Parameter.VAR_KEYWORD) or (param.name == 'training')
2033
+ for param in signature.parameters.values()
2034
+ )
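A tiny illustration of the signature check above, using a hypothetical free-standing copy of the helper: it returns True when a `propagate` signature accepts `**kwargs` or an explicit `training` argument, which tells the caller whether `training` can be forwarded.

import inspect

def has_propagate_kwargs(func) -> bool:
    # Hypothetical mirror of the private helper above.
    signature = inspect.signature(func)
    return any(
        (param.kind == inspect.Parameter.VAR_KEYWORD) or (param.name == 'training')
        for param in signature.parameters.values()
    )

def propagate_a(tensor): ...
def propagate_b(tensor, training=None): ...
def propagate_c(tensor, **kwargs): ...

print(has_propagate_kwargs(propagate_a))  # False
print(has_propagate_kwargs(propagate_b))  # True
print(has_propagate_kwargs(propagate_c))  # True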