molcraft 0.1.0rc10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- molcraft/__init__.py +18 -0
- molcraft/callbacks.py +100 -0
- molcraft/chem.py +714 -0
- molcraft/datasets.py +132 -0
- molcraft/descriptors.py +149 -0
- molcraft/features.py +379 -0
- molcraft/featurizers.py +727 -0
- molcraft/layers.py +2034 -0
- molcraft/losses.py +37 -0
- molcraft/models.py +627 -0
- molcraft/ops.py +195 -0
- molcraft/records.py +187 -0
- molcraft/tensors.py +561 -0
- molcraft/trainers.py +212 -0
- molcraft-0.1.0rc10.dist-info/METADATA +118 -0
- molcraft-0.1.0rc10.dist-info/RECORD +19 -0
- molcraft-0.1.0rc10.dist-info/WHEEL +5 -0
- molcraft-0.1.0rc10.dist-info/licenses/LICENSE +21 -0
- molcraft-0.1.0rc10.dist-info/top_level.txt +1 -0
molcraft/layers.py
ADDED
|
@@ -0,0 +1,2034 @@
|
|
|
1
|
+
import warnings
|
|
2
|
+
import keras
|
|
3
|
+
import tensorflow as tf
|
|
4
|
+
import functools
|
|
5
|
+
import inspect
|
|
6
|
+
from keras.src.models import functional
|
|
7
|
+
|
|
8
|
+
from molcraft import tensors
|
|
9
|
+
from molcraft import ops
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
@keras.saving.register_keras_serializable(package='molcraft')
class GraphLayer(keras.layers.Layer):
    """Base graph layer.

    Subclasses must implement a forward pass via **propagate(graph)**.

    Subclasses may create dense layers and weights in **build(graph_spec)**.

    Note: `GraphLayer` currently only supports `GraphTensor` input.

    The arguments listed below are only relevant if the derived layer
    invokes `get_dense_kwargs`, `get_dense`, `get_einsum_dense` or
    `get_weight`.

    Arguments:
        use_bias (bool):
            Whether bias should be used in dense layers. Defaults to `True`.
        kernel_initializer (keras.initializers.Initializer, str):
            Initializer for the kernel weight matrix of the dense layers.
            Defaults to `glorot_uniform`.
        bias_initializer (keras.initializers.Initializer, str):
            Initializer for the bias weight vector of the dense layers.
            Defaults to `zeros`.
        kernel_regularizer (keras.regularizers.Regularizer, None):
            Regularizer function applied to the kernel weight matrix.
            Defaults to `None`.
        bias_regularizer (keras.regularizers.Regularizer, None):
            Regularizer function applied to the bias weight vector.
            Defaults to `None`.
        activity_regularizer (keras.regularizers.Regularizer, None):
            Regularizer function applied to the output of the dense layers.
            Defaults to `None`.
        kernel_constraint (keras.constraints.Constraint, None):
            Constraint function applied to the kernel weight matrix.
            Defaults to `None`.
        bias_constraint (keras.constraints.Constraint, None):
            Constraint function applied to the bias weight vector.
            Defaults to `None`.
    """

    def __init__(
        self,
        use_bias: bool = True,
        kernel_initializer: keras.initializers.Initializer | str = "glorot_uniform",
        bias_initializer: keras.initializers.Initializer | str = "zeros",
        kernel_regularizer: keras.regularizers.Regularizer | None = None,
        bias_regularizer: keras.regularizers.Regularizer | None = None,
        activity_regularizer: keras.regularizers.Regularizer | None = None,
        kernel_constraint: keras.constraints.Constraint | None = None,
        bias_constraint: keras.constraints.Constraint | None = None,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        self._use_bias = use_bias
        self._kernel_initializer = keras.initializers.get(kernel_initializer)
        self._bias_initializer = keras.initializers.get(bias_initializer)
        self._kernel_regularizer = keras.regularizers.get(kernel_regularizer)
        self._bias_regularizer = keras.regularizers.get(bias_regularizer)
        self._activity_regularizer = keras.regularizers.get(activity_regularizer)
        self._kernel_constraint = keras.constraints.get(kernel_constraint)
        self._bias_constraint = keras.constraints.get(bias_constraint)
        # Filled by `build` with the serialized input spec (see `get_build_config`).
        self._custom_build_config = {}
        # Truthy when `propagate` should receive `training=...` (see `call`);
        # `_propagate_kwargs` is defined elsewhere in this module.
        self._propagate_kwargs = _propagate_kwargs(self.propagate)
        self.built = False

    def __init_subclass__(cls, **kwargs):
        """Wraps each subclass's `build` so `GraphLayer.build` always runs first.

        `GraphLayer.build` records the input spec used for serialization. For
        subclasses that are also `keras.Model`s and are still unbuilt after
        their own `build`, the wrapper marks them built and calls them once on
        a symbolic `Input(spec)` (`Input` is defined elsewhere in this module).
        """
        super().__init_subclass__(**kwargs)
        subclass_build = cls.build

        @functools.wraps(subclass_build)
        def build_wrapper(self: GraphLayer, spec: tensors.GraphTensor.Spec | None):
            GraphLayer.build(self, spec)
            subclass_build(self, spec)
            if not self.built and isinstance(self, keras.Model):
                symbolic_inputs = Input(spec)
                # Mark as built before the call so `__call__` does not try
                # to build again.
                self.built = True
                self(symbolic_inputs)

        cls.build = build_wrapper

    def propagate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
        """Forward pass.

        Must be implemented by subclass.

        Arguments:
            tensor:
                A `GraphTensor` instance.

        Raises:
            NotImplementedError: Always, unless overridden.
        """
        raise NotImplementedError(
            'The forward pass of the layer is not implemented. '
            'Please implement `propagate`.'
        )

    def build(self, tensor_spec: tensors.GraphTensor.Spec) -> None:
        """Builds the layer.

        May use built-in methods such as `get_weight`, `get_dense` and `get_einsum_dense`.

        Optionally implemented by subclass.

        Arguments:
            tensor_spec:
                A `GraphTensor.Spec` instance corresponding to the `GraphTensor`
                passed to `propagate`.
        """
        if isinstance(tensor_spec, tensors.GraphTensor.Spec):
            self._custom_build_config['spec'] = _serialize_spec(tensor_spec)

    def call(
        self,
        graph: dict[str, dict[str, tf.Tensor]],
        training: bool | None = None,
    ) -> dict[str, dict[str, tf.Tensor]]:
        """Converts `graph` to a `GraphTensor`, runs `propagate`, converts back.

        `training` is only forwarded when `_propagate_kwargs` is truthy
        (presumably when `propagate` accepts it).

        Returns:
            The dict form of the output when `propagate` returns a
            `GraphTensor`; otherwise its output unchanged.
        """
        graph_tensor = tensors.from_dict(graph)
        if not self._propagate_kwargs:
            outputs = self.propagate(graph_tensor)
        else:
            outputs = self.propagate(graph_tensor, training=training)
        if isinstance(outputs, tensors.GraphTensor):
            return tensors.to_dict(outputs)
        return outputs

    def __call__(
        self,
        graph: dict[str, dict[str, tf.Tensor]] | tensors.GraphTensor,
        **kwargs
    ) -> tf.Tensor | dict[str, dict[str, tf.Tensor]] | tensors.GraphTensor:
        """Calls the layer on a graph.

        Builds the layer from the input on first call. A `GraphTensor` input
        is converted to its nested-dict form before the Keras call, and graph
        output is converted back to a `GraphTensor` in that case.

        For functional models, input fields that are absent from the model's
        symbolic input are set aside before the call and re-attached to the
        graph output afterwards.

        Arguments:
            graph:
                A `GraphTensor` or its nested-dict representation.
            **kwargs:
                Forwarded to `keras.layers.Layer.__call__`.

        Returns:
            A non-graph output unchanged; otherwise the output graph in the
            same container type (`GraphTensor` or dict) as `graph`.
        """
        if not self.built:
            spec = _spec_from_inputs(graph)
            self.build(spec)

        is_graph_tensor = isinstance(graph, tensors.GraphTensor)
        if is_graph_tensor:
            graph = tensors.to_dict(graph)
        else:
            # Copy the inner dicts so the popping below does not mutate the
            # caller's input.
            graph = {field: dict(data) for (field, data) in graph.items()}

        if isinstance(self, functional.Functional):
            # As a functional model is strict for what input can be
            # passed to it, we need to temporarily pop some of the
            # input and add it back afterwards.
            symbolic_input = self._symbolic_input
            excluded = {}
            for outer_field in ['context', 'node', 'edge']:
                excluded[outer_field] = {}
                for inner_field in list(graph[outer_field]):
                    if inner_field not in symbolic_input[outer_field]:
                        excluded[outer_field][inner_field] = (
                            graph[outer_field].pop(inner_field)
                        )

            tf.nest.assert_same_structure(self.input, graph)

        outputs = super().__call__(graph, **kwargs)

        if not tensors.is_graph(outputs):
            return outputs

        graph = outputs
        if isinstance(self, functional.Functional):
            for outer_field in ['context', 'node', 'edge']:
                for inner_field in list(excluded[outer_field]):
                    graph[outer_field][inner_field] = (
                        excluded[outer_field].pop(inner_field)
                    )

        if is_graph_tensor:
            return tensors.from_dict(graph)

        return graph

    def get_build_config(self) -> dict:
        """Returns the recorded serialized spec if any, else Keras's default."""
        if self._custom_build_config:
            return self._custom_build_config
        return super().get_build_config()

    def build_from_config(self, config: dict) -> None:
        """Rebuilds from a serialized spec, falling back to Keras's default."""
        serialized_spec = config.get('spec')
        if serialized_spec is not None:
            spec = _deserialize_spec(serialized_spec)
            self.build(spec)
        else:
            super().build_from_config(config)

    def get_weight(
        self,
        shape: tf.TensorShape,
        **kwargs,
    ) -> tf.Variable:
        """Adds a weight using the layer's kernel initializer/regularizer/constraint.

        Arguments:
            shape:
                Shape of the weight.
            **kwargs:
                Forwarded to `add_weight`; override the kernel defaults.
        """
        common_kwargs = self.get_dense_kwargs()
        weight_kwargs = {
            'initializer': common_kwargs['kernel_initializer'],
            'regularizer': common_kwargs['kernel_regularizer'],
            'constraint': common_kwargs['kernel_constraint']
        }
        weight_kwargs.update(kwargs)
        return self.add_weight(shape=shape, **weight_kwargs)

    def get_dense(
        self,
        units: int,
        **kwargs
    ) -> keras.layers.Dense:
        """Creates a `keras.layers.Dense` with the layer's common dense kwargs.

        `kwargs` override the common kwargs.
        """
        common_kwargs = self.get_dense_kwargs()
        common_kwargs.update(kwargs)
        return keras.layers.Dense(units, **common_kwargs)

    def get_einsum_dense(
        self,
        equation: str,
        output_shape: tf.TensorShape,
        **kwargs
    ) -> keras.layers.EinsumDense:
        """Creates a `keras.layers.EinsumDense` with the common dense kwargs.

        `use_bias` is translated into `bias_axes`: when bias is enabled and no
        explicit `bias_axes` is given, bias spans every output axis except the
        leading one (e.g. `'ij,jkh->ikh'` gives `'kh'`).
        """
        common_kwargs = self.get_dense_kwargs()
        common_kwargs.update(kwargs)
        # `EinsumDense` expresses bias through `bias_axes`, not `use_bias`.
        use_bias = common_kwargs.pop('use_bias', False)
        if use_bias and 'bias_axes' not in common_kwargs:
            # An empty string (single-axis output) means no bias.
            common_kwargs['bias_axes'] = equation.split('->')[-1][1:] or None
        return keras.layers.EinsumDense(equation, output_shape, **common_kwargs)

    def get_dense_kwargs(self) -> dict:
        """Returns constructor kwargs shared by all dense layers of this layer.

        Initializers are re-created from their config on every call, so that
        dense layers do not share one initializer instance.
        """
        common_kwargs = dict(
            use_bias=self._use_bias,
            kernel_regularizer=self._kernel_regularizer,
            bias_regularizer=self._bias_regularizer,
            activity_regularizer=self._activity_regularizer,
            kernel_constraint=self._kernel_constraint,
            bias_constraint=self._bias_constraint,
        )
        kernel_initializer = self._kernel_initializer.__class__.from_config(
            self._kernel_initializer.get_config()
        )
        bias_initializer = self._bias_initializer.__class__.from_config(
            self._bias_initializer.get_config()
        )
        common_kwargs["kernel_initializer"] = kernel_initializer
        common_kwargs["bias_initializer"] = bias_initializer
        return common_kwargs

    def get_config(self) -> dict:
        """Returns the serializable config, including dense-layer settings."""
        config = super().get_config()
        config.update({
            "use_bias": self._use_bias,
            "kernel_initializer":
                keras.initializers.serialize(self._kernel_initializer),
            "bias_initializer":
                keras.initializers.serialize(self._bias_initializer),
            "kernel_regularizer":
                keras.regularizers.serialize(self._kernel_regularizer),
            "bias_regularizer":
                keras.regularizers.serialize(self._bias_regularizer),
            "activity_regularizer":
                keras.regularizers.serialize(self._activity_regularizer),
            "kernel_constraint":
                keras.constraints.serialize(self._kernel_constraint),
            "bias_constraint":
                keras.constraints.serialize(self._bias_constraint),
        })
        return config

    @property
    def _symbolic_output(self) -> dict[str, dict[str, keras.KerasTensor]]:
        """Symbolic output of the layer, packed into its nested graph structure."""
        output = self.output
        if tensors.is_graph(output):
            # GraphModel
            return output
        symbolic_input = self._symbolic_input
        symbolic_output = self.compute_output_spec(symbolic_input)
        return tf.nest.pack_sequence_as(symbolic_output, output)

    @property
    def _symbolic_input(self) -> dict[str, dict[str, keras.KerasTensor]]:
        """Symbolic input of the layer, packed into its nested graph structure."""
        layer_input = self.input
        if tensors.is_graph(layer_input):
            # GraphModel or initial GraphLayer of GraphModel
            return layer_input
        spec = _deserialize_spec(self.get_build_config()['spec'])
        spec_dict = {k: dict(v) for k, v in spec.__dict__.items()}
        return tf.nest.pack_sequence_as(spec_dict, layer_input)

    @property
    def output_spec(self) -> tensors.GraphTensor.Spec | tf.TensorSpec | None:
        """Spec of the layer's output, or `None` if the layer is not built."""
        if not self.built:
            return None
        if isinstance(self, functional.Functional):
            output_spec = self.output
        else:
            serialized_spec = self.get_build_config()['spec']
            deserialized_spec = _deserialize_spec(serialized_spec)
            input_spec = Input(deserialized_spec)
            output_spec = self.compute_output_spec(input_spec)
        if not tensors.is_graph(output_spec):
            return tf.TensorSpec(output_spec.shape, output_spec.dtype)
        spec_dict = tf.nest.map_structure(
            lambda t: tf.TensorSpec(t.shape, t.dtype), output_spec
        )
        return tensors.GraphTensor.Spec(**spec_dict)
|
|
309
|
+
|
|
310
|
+
|
|
311
|
+
@keras.saving.register_keras_serializable(package='molcraft')
class GraphConv(GraphLayer):

    """Base graph neural network layer.

    This layer implements the three basic steps of a graph neural network layer, each of which
    can (optionally) be overridden by the `GraphConv` subclass:

    1. **message(graph)**, which computes the *messages* to be passed to target nodes;
    2. **aggregate(graph)**, which *aggregates* messages to target nodes;
    3. **update(graph)**, which further *updates* (target) nodes.

    Note: the updated node features must have dimension `units`; this is
    checked in `propagate` and is required for the skip connection to work.

    Arguments:
        units (int):
            Dimensionality of the output space.
        activation (keras.layers.Activation, str, None):
            Activation function to be accessed via `self.activation`, and used for the
            `message()` and `update()` methods, if not overridden. Defaults to `relu`.
        use_bias (bool):
            Whether bias should be used in the dense layers. Defaults to `True`.
        normalize (bool, str):
            Whether a normalization layer should be obtained via `get_norm()`.
            A string starting with `'layer'` selects layer normalization; any
            other truthy value selects batch normalization. Defaults to `False`.
        skip_connect (bool):
            Whether node feature input should be added to the node feature output.
            Defaults to `True`.
        kernel_initializer (keras.initializers.Initializer, str):
            Initializer for the kernel weight matrix of the dense layers.
            Defaults to `glorot_uniform`.
        bias_initializer (keras.initializers.Initializer, str):
            Initializer for the bias weight vector of the dense layers.
            Defaults to `zeros`.
        kernel_regularizer (keras.regularizers.Regularizer, None):
            Regularizer function applied to the kernel weight matrix.
            Defaults to `None`.
        bias_regularizer (keras.regularizers.Regularizer, None):
            Regularizer function applied to the bias weight vector.
            Defaults to `None`.
        activity_regularizer (keras.regularizers.Regularizer, None):
            Regularizer function applied to the output of the dense layers.
            Defaults to `None`.
        kernel_constraint (keras.constraints.Constraint, None):
            Constraint function applied to the kernel weight matrix.
            Defaults to `None`.
        bias_constraint (keras.constraints.Constraint, None):
            Constraint function applied to the bias weight vector.
            Defaults to `None`.
    """

    def __init__(
        self,
        units: int | None = None,
        activation: str | keras.layers.Activation | None = 'relu',
        use_bias: bool = True,
        normalize: bool | str = False,
        skip_connect: bool = True,
        **kwargs
    ) -> None:
        super().__init__(use_bias=use_bias, **kwargs)
        self._units = units
        self._normalize = normalize
        self._skip_connect = skip_connect
        self._activation = keras.activations.get(activation)

    def __init_subclass__(cls, **kwargs):
        """Wraps each subclass's `build` so `GraphConv.build` runs before it."""
        super().__init_subclass__(**kwargs)
        subclass_build = cls.build

        @functools.wraps(subclass_build)
        def build_wrapper(self, spec):
            GraphConv.build(self, spec)
            return subclass_build(self, spec)

        cls.build = build_wrapper

    @property
    def units(self):
        """Dimensionality of the output node features."""
        return self._units

    @property
    def activation(self):
        """Activation function resolved from the `activation` argument."""
        return self._activation

    def build(self, spec: tensors.GraphTensor.Spec) -> None:
        """Creates the residual projection and the default sub-layers.

        Dense layers for the default `message` and `update` are only created
        when the subclass does not override the corresponding method.

        Arguments:
            spec:
                A `GraphTensor.Spec` describing the input graph.

        Raises:
            ValueError: If `units` is not set (falsy).
        """
        if not self.units:
            raise ValueError(
                f'`self.units` needs to be a positive integer. Found: {self.units}.'
            )
        node_feature_dim = spec.node['feature'].shape[-1]
        # The residual can only be added as-is when its width equals `units`.
        self._project_residual = (
            self._skip_connect and (node_feature_dim != self.units)
        )
        if self._project_residual:
            warnings.warn(
                'Found incompatible dim between input and output. Applying '
                'a projection layer to residual to match input and output dim.',
            )
            self._residual_dense = self.get_dense(
                self.units, name='residual_dense'
            )

        self.has_edge_feature = 'feature' in spec.edge
        self.has_node_coordinate = 'coordinate' in spec.node

        has_overridden_message = self.__class__.message != GraphConv.message
        if not has_overridden_message:
            self._message_intermediate_dense = self.get_dense(self.units)
            self._message_norm = self.get_norm()
            self._message_intermediate_activation = self.activation
            self._message_final_dense = self.get_dense(self.units)

        # The default `aggregate` has no trainable weights, so nothing needs
        # to be built for it regardless of whether it is overridden.

        has_overridden_update = self.__class__.update != GraphConv.update
        if not has_overridden_update:
            self._update_intermediate_dense = self.get_dense(self.units)
            self._update_norm = self.get_norm()
            self._update_intermediate_activation = self.activation
            self._update_final_dense = self.get_dense(self.units)

    def propagate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
        """Forward pass.

        Invokes `message(graph)`, `aggregate(graph)` and `update(graph)` in
        sequence. Each step may return either a `GraphTensor` (which must
        contain its expected field) or a raw tensor, which is then attached
        under `edge['message']`, `node['aggregate']` or `node['feature']`.

        Arguments:
            tensor:
                A `GraphTensor` instance.

        Returns:
            A `GraphTensor` with updated node features; the (possibly
            projected) input node features are added when `skip_connect`.

        Raises:
            ValueError: If a step returns a `GraphTensor` missing its expected
                field, or the updated node features are not `units` wide.
        """
        if self._skip_connect:
            residual = tensor.node['feature']
            if self._project_residual:
                residual = self._residual_dense(residual)

        message = self.message(tensor)
        add_message = not isinstance(message, tensors.GraphTensor)
        if add_message:
            message = tensor.update({'edge': {'message': message}})
        elif 'message' not in message.edge:
            raise ValueError('Could not find `message` in `edge` output.')

        aggregate = self.aggregate(message)
        add_aggregate = not isinstance(aggregate, tensors.GraphTensor)
        if add_aggregate:
            # NOTE(review): a raw aggregate is attached to the *input* tensor,
            # so edge changes made by a GraphTensor-returning `message` are
            # not carried forward here -- confirm this is intended.
            aggregate = tensor.update({'node': {'aggregate': aggregate}})
        elif 'aggregate' not in aggregate.node:
            raise ValueError('Could not find `aggregate` in `node` output.')

        update = self.update(aggregate)
        if not isinstance(update, tensors.GraphTensor):
            # NOTE(review): likewise attached to the input tensor.
            update = tensor.update({'node': {'feature': update}})
        elif 'feature' not in update.node:
            raise ValueError('Could not find `feature` in `node` output.')

        if update.node['feature'].shape[-1] != self.units:
            raise ValueError('Updated node `feature` is not equal to `self.units`.')

        # Clear the intermediate fields this method attached itself
        # (presumably a `None` value removes the field in `GraphTensor.update`).
        if add_message and add_aggregate:
            update = update.update({'node': {'aggregate': None}, 'edge': {'message': None}})
        elif add_message:
            update = update.update({'edge': {'message': None}})
        elif add_aggregate:
            update = update.update({'node': {'aggregate': None}})

        if not self._skip_connect:
            return update

        feature = update.node['feature'] + residual
        return update.update({'node': {'feature': feature}})

    def message(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
        """Compute messages.

        The default concatenates source and target node features, plus edge
        features and the source-target Euclidean distance when present, and
        passes the result through dense -> norm -> activation -> dense.

        This method may be overridden by subclass.

        Arguments:
            tensor:
                The inputted `GraphTensor` instance.
        """
        message = keras.ops.concatenate(
            [
                tensor.gather('feature', 'source'),
                tensor.gather('feature', 'target'),
            ],
            axis=-1
        )
        if self.has_edge_feature:
            message = keras.ops.concatenate(
                [
                    message,
                    tensor.edge['feature']
                ],
                axis=-1
            )
        if self.has_node_coordinate:
            euclidean_distance = ops.euclidean_distance(
                tensor.gather('coordinate', 'target'),
                tensor.gather('coordinate', 'source'),
                axis=-1,
                keepdims=True
            )
            message = keras.ops.concatenate(
                [
                    message,
                    euclidean_distance
                ],
                axis=-1
            )
        message = self._message_intermediate_dense(message)
        message = self._message_norm(message)
        message = self._message_intermediate_activation(message)
        message = self._message_final_dense(message)
        return tensor.update({'edge': {'message': message}})

    def aggregate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
        """Aggregates messages.

        The default mean-aggregates `edge['message']` and concatenates the
        result with the current node features, then clears the message.

        This method may be overridden by subclass.

        Arguments:
            tensor:
                A `GraphTensor` instance containing a message.
        """
        previous = tensor.node['feature']
        aggregate = tensor.aggregate('message', mode='mean')
        aggregate = keras.ops.concatenate([aggregate, previous], axis=-1)
        return tensor.update(
            {
                'node': {
                    'aggregate': aggregate,
                },
                'edge': {
                    'message': None,
                }
            }
        )

    def update(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
        """Updates nodes.

        The default passes `node['aggregate']` through
        dense -> norm -> activation -> dense to produce the new node features.

        This method may be overridden by subclass.

        Arguments:
            tensor:
                A `GraphTensor` instance containing aggregated messages
                (updated node features).
        """
        aggregate = tensor.node['aggregate']
        node_feature = self._update_intermediate_dense(aggregate)
        node_feature = self._update_norm(node_feature)
        node_feature = self._update_intermediate_activation(node_feature)
        node_feature = self._update_final_dense(node_feature)
        return tensor.update(
            {
                'node': {
                    'feature': node_feature,
                    'aggregate': None,
                },
            }
        )

    def get_norm(self, **kwargs):
        """Returns the normalization layer selected by `normalize`.

        `Identity` when `normalize` is falsy, `LayerNormalization` when its
        string form starts with `'layer'` (case-insensitive), otherwise
        `BatchNormalization`. `kwargs` go to the normalization layer.
        """
        if not self._normalize:
            return keras.layers.Identity()
        elif str(self._normalize).lower().startswith('layer'):
            return keras.layers.LayerNormalization(**kwargs)
        else:
            return keras.layers.BatchNormalization(**kwargs)

    def get_config(self) -> dict:
        """Returns the serializable config of the layer."""
        config = super().get_config()
        config.update({
            'units': self.units,
            'activation': keras.activations.serialize(self._activation),
            'normalize': self._normalize,
            'skip_connect': self._skip_connect,
        })
        return config
|
|
596
|
+
|
|
597
|
+
|
|
598
|
+
@keras.saving.register_keras_serializable(package='molcraft')
class GIConv(GraphConv):

    """Graph isomorphism network layer.

    Messages are `relu(h_source + e)` (edge term only when edge features
    exist, optionally projected to the node feature dim). They are
    mean-aggregated (presumably per target node) and combined with the node's
    own features as `aggregate + (1 + epsilon) * h`, where `epsilon` is a
    trainable scalar. The inherited `GraphConv.update` then maps the result
    to `units` dimensions.

    Arguments:
        units (int):
            Dimensionality of the output space.
        activation (keras.layers.Activation, str, None):
            Activation used by the inherited `update`. Defaults to `relu`.
        use_bias (bool):
            Whether bias should be used in the dense layers. Defaults to `True`.
        normalize (bool, str):
            Normalization used by the inherited `update`. Defaults to `False`.
        skip_connect (bool):
            Whether node feature input is added to the output. Defaults to `True`.
        update_edge_feature (bool):
            Whether edge features are projected by a dense layer (to the node
            feature dim) before being added to messages. Forced to `True` when
            edge and node feature dims differ, and set to `False` when the
            input has no edge features. Defaults to `True`.

    >>> graph = molcraft.tensors.GraphTensor(
    ...     context={
    ...         'size': [2]
    ...     },
    ...     node={
    ...         'feature': [[1.], [2.]]
    ...     },
    ...     edge={
    ...         'source': [0, 1],
    ...         'target': [1, 0],
    ...     }
    ... )
    >>> conv = molcraft.layers.GIConv(units=4)
    >>> conv(graph)
    GraphTensor(
      context={
        'size': <tf.Tensor: shape=[1], dtype=int32>
      },
      node={
        'feature': <tf.Tensor: shape=[2, 4], dtype=float32>
      },
      edge={
        'source': <tf.Tensor: shape=[2], dtype=int32>,
        'target': <tf.Tensor: shape=[2], dtype=int32>
      }
    )
    """

    def __init__(
        self,
        units: int,
        activation: keras.layers.Activation | str | None = 'relu',
        use_bias: bool = True,
        normalize: bool = False,
        skip_connect: bool = True,
        update_edge_feature: bool = True,
        **kwargs,
    ) -> None:
        super().__init__(
            units=units,
            activation=activation,
            use_bias=use_bias,
            normalize=normalize,
            skip_connect=skip_connect,
            **kwargs
        )
        # May be rewritten in `build` once the input spec is known.
        self._update_edge_feature = update_edge_feature

    def build(self, spec: tensors.GraphTensor.Spec) -> None:
        """Creates `epsilon` and, if needed, the edge projection layer."""
        node_dim = spec.node['feature'].shape[-1]

        # Trainable weight on the node's own contribution; starts at zero.
        self.epsilon = self.add_weight(
            name='epsilon', shape=(), initializer='zeros', trainable=True,
        )

        if not self.has_edge_feature:
            # Nothing to project without edge features.
            self._update_edge_feature = False
            return

        edge_dim = spec.edge['feature'].shape[-1]
        needs_projection = edge_dim != node_dim
        if not self._update_edge_feature and needs_projection:
            # Edge features are added to node features, so the dims must match.
            warnings.warn(
                'Found edge and node feature dim to be incompatible. Applying a '
                'projection layer to edge features to match the dim of the node features.',
            )
            self._update_edge_feature = True

        if self._update_edge_feature:
            self._edge_dense = self.get_dense(node_dim)

    def message(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
        """Computes `relu(h_source [+ e])` and stores the (projected) edge features."""
        source_feature = tensor.gather('feature', 'source')
        edge_feature = tensor.edge.get('feature')
        if self._update_edge_feature:
            edge_feature = self._edge_dense(edge_feature)

        if self.has_edge_feature:
            combined = source_feature + edge_feature
        else:
            combined = source_feature

        return tensor.update({
            'edge': {
                'message': keras.ops.relu(combined),
                'feature': edge_feature,
            },
        })

    def aggregate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
        """Returns `mean(messages) + (1 + epsilon) * h` and clears the messages."""
        neighbour_mean = tensor.aggregate('message', mode='mean')
        self_term = (1 + self.epsilon) * tensor.node['feature']
        return tensor.update({
            'node': {'aggregate': neighbour_mean + self_term},
            'edge': {'message': None},
        })

    def get_config(self) -> dict:
        """Adds `update_edge_feature` to the inherited config."""
        return {
            **super().get_config(),
            'update_edge_feature': self._update_edge_feature,
        }
|
|
714
|
+
|
|
715
|
+
|
|
716
|
+
@keras.saving.register_keras_serializable(package='molcraft')
|
|
717
|
+
class GAConv(GraphConv):
|
|
718
|
+
|
|
719
|
+
"""Graph attention network layer.
|
|
720
|
+
|
|
721
|
+
>>> graph = molcraft.tensors.GraphTensor(
|
|
722
|
+
... context={
|
|
723
|
+
... 'size': [2]
|
|
724
|
+
... },
|
|
725
|
+
... node={
|
|
726
|
+
... 'feature': [[1.], [2.]]
|
|
727
|
+
... },
|
|
728
|
+
... edge={
|
|
729
|
+
... 'source': [0, 1],
|
|
730
|
+
... 'target': [1, 0],
|
|
731
|
+
... }
|
|
732
|
+
... )
|
|
733
|
+
>>> conv = molcraft.layers.GAConv(units=4, heads=2)
|
|
734
|
+
>>> conv(graph)
|
|
735
|
+
GraphTensor(
|
|
736
|
+
context={
|
|
737
|
+
'size': <tf.Tensor: shape=[1], dtype=int32>
|
|
738
|
+
},
|
|
739
|
+
node={
|
|
740
|
+
'feature': <tf.Tensor: shape=[2, 4], dtype=float32>
|
|
741
|
+
},
|
|
742
|
+
edge={
|
|
743
|
+
'source': <tf.Tensor: shape=[2], dtype=int32>,
|
|
744
|
+
'target': <tf.Tensor: shape=[2], dtype=int32>
|
|
745
|
+
}
|
|
746
|
+
)
|
|
747
|
+
"""
|
|
748
|
+
|
|
749
|
+
def __init__(
    self,
    units: int,
    heads: int = 8,
    activation: keras.layers.Activation | str | None = "relu",
    use_bias: bool = True,
    normalize: bool = False,
    skip_connect: bool = True,
    update_edge_feature: bool = True,
    **kwargs,
) -> None:
    """Initializes the graph attention layer.

    Arguments:
        units (int):
            Dimensionality of the output space; split evenly across heads.
        heads (int):
            Number of attention heads. Must be a positive integer that
            divides `units`. Defaults to `8`.
        activation (keras.layers.Activation, str, None):
            Activation function. Defaults to `relu`.
        use_bias (bool):
            Whether bias should be used in the dense layers. Defaults to `True`.
        normalize (bool, str):
            Whether to use normalization. Defaults to `False`.
        skip_connect (bool):
            Whether node feature input is added to the output. Defaults to `True`.
        update_edge_feature (bool):
            Whether edge features are recomputed from the attention
            features. Only effective when the input has edge features.
            Defaults to `True`.

    Raises:
        ValueError: If `heads` is not positive or does not divide `units`.
    """
    super().__init__(
        units=units,
        activation=activation,
        normalize=normalize,
        use_bias=use_bias,
        skip_connect=skip_connect,
        **kwargs
    )
    self._heads = heads
    # Validate before the modulo below, which would otherwise fail with
    # an opaque ZeroDivisionError for `heads=0`.
    if self.heads <= 0:
        raise ValueError(
            f"`heads` needs to be a positive integer. Found: {self.heads}."
        )
    if self.units % self.heads != 0:
        raise ValueError(
            f"units need to be divisible by heads. "
            f"Found units={self.units} and heads={self.heads}."
        )
    self._head_units = self.units // self.heads
    self._update_edge_feature = update_edge_feature
|
|
773
|
+
|
|
774
|
+
@property
def heads(self) -> int:
    """Number of attention heads, as passed to the constructor."""
    value = self._heads
    return value
|
|
777
|
+
|
|
778
|
+
@property
|
|
779
|
+
def head_units(self):
|
|
780
|
+
return self._head_units
|
|
781
|
+
|
|
782
|
+
def build(self, spec: tensors.GraphTensor.Spec) -> None:
|
|
783
|
+
self._update_edge_feature = self.has_edge_feature and self._update_edge_feature
|
|
784
|
+
if self._update_edge_feature:
|
|
785
|
+
self._edge_dense = self.get_einsum_dense(
|
|
786
|
+
'ijh,jkh->ikh', (self.head_units, self.heads)
|
|
787
|
+
)
|
|
788
|
+
self._node_dense = self.get_einsum_dense(
|
|
789
|
+
'ij,jkh->ikh', (self.head_units, self.heads)
|
|
790
|
+
)
|
|
791
|
+
self._feature_dense = self.get_einsum_dense(
|
|
792
|
+
'ij,jkh->ikh', (self.head_units, self.heads)
|
|
793
|
+
)
|
|
794
|
+
self._attention_dense = self.get_einsum_dense(
|
|
795
|
+
'ijh,jkh->ikh', (1, self.heads)
|
|
796
|
+
)
|
|
797
|
+
|
|
798
|
+
def message(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
|
|
799
|
+
attention_feature = keras.ops.concatenate(
|
|
800
|
+
[
|
|
801
|
+
tensor.gather('feature', 'source'),
|
|
802
|
+
tensor.gather('feature', 'target')
|
|
803
|
+
],
|
|
804
|
+
axis=-1
|
|
805
|
+
)
|
|
806
|
+
if self.has_edge_feature:
|
|
807
|
+
attention_feature = keras.ops.concatenate(
|
|
808
|
+
[
|
|
809
|
+
attention_feature,
|
|
810
|
+
tensor.edge['feature']
|
|
811
|
+
],
|
|
812
|
+
axis=-1
|
|
813
|
+
)
|
|
814
|
+
|
|
815
|
+
attention_feature = self._feature_dense(attention_feature)
|
|
816
|
+
|
|
817
|
+
edge_feature = tensor.edge.get('feature')
|
|
818
|
+
|
|
819
|
+
if self._update_edge_feature:
|
|
820
|
+
edge_feature = self._edge_dense(attention_feature)
|
|
821
|
+
edge_feature = keras.ops.reshape(edge_feature, (-1, self.units))
|
|
822
|
+
|
|
823
|
+
attention_feature = keras.ops.leaky_relu(attention_feature)
|
|
824
|
+
attention_score = self._attention_dense(attention_feature)
|
|
825
|
+
attention_score = ops.edge_softmax(
|
|
826
|
+
score=attention_score, edge_target=tensor.edge['target']
|
|
827
|
+
)
|
|
828
|
+
node_feature = self._node_dense(tensor.node['feature'])
|
|
829
|
+
message = ops.gather(node_feature, tensor.edge['source'])
|
|
830
|
+
message = ops.edge_weight(message, attention_score)
|
|
831
|
+
return tensor.update(
|
|
832
|
+
{
|
|
833
|
+
'edge': {
|
|
834
|
+
'message': message,
|
|
835
|
+
'feature': edge_feature,
|
|
836
|
+
}
|
|
837
|
+
}
|
|
838
|
+
)
|
|
839
|
+
|
|
840
|
+
def aggregate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
|
|
841
|
+
node_feature = tensor.aggregate('message', mode='sum')
|
|
842
|
+
node_feature = keras.ops.reshape(node_feature, (-1, self.units))
|
|
843
|
+
return tensor.update(
|
|
844
|
+
{
|
|
845
|
+
'node': {
|
|
846
|
+
'aggregate': node_feature
|
|
847
|
+
},
|
|
848
|
+
'edge': {
|
|
849
|
+
'message': None,
|
|
850
|
+
}
|
|
851
|
+
}
|
|
852
|
+
)
|
|
853
|
+
|
|
854
|
+
def get_config(self) -> dict:
|
|
855
|
+
config = super().get_config()
|
|
856
|
+
config.update({
|
|
857
|
+
"heads": self._heads,
|
|
858
|
+
'update_edge_feature': self._update_edge_feature,
|
|
859
|
+
})
|
|
860
|
+
return config
|
|
861
|
+
|
|
862
|
+
|
|
863
|
+
@keras.saving.register_keras_serializable(package='molcraft')
class MPConv(GraphConv):

    """Message passing neural network layer.

    Messages are produced by the message step inherited from `GraphConv`.
    They are mean-aggregated per node, and the node features are then
    updated with a GRU cell. The aggregated message is the GRU input and the
    previous node feature is the GRU state. If the previous node feature
    dimension differs from `units`, it is first projected to `units`.

    Also supports 3D molecular graphs.

    >>> graph = molcraft.tensors.GraphTensor(
    ...     context={
    ...         'size': [2]
    ...     },
    ...     node={
    ...         'feature': [[1.], [2.]]
    ...     },
    ...     edge={
    ...         'source': [0, 1],
    ...         'target': [1, 0],
    ...     }
    ... )
    >>> conv = molcraft.layers.MPConv(units=4)
    >>> conv(graph)
    GraphTensor(
      context={
        'size': <tf.Tensor: shape=[1], dtype=int32>
      },
      node={
        'feature': <tf.Tensor: shape=[2, 4], dtype=float32>
      },
      edge={
        'source': <tf.Tensor: shape=[2], dtype=int32>,
        'target': <tf.Tensor: shape=[2], dtype=int32>
      }
    )
    """

    def __init__(
        self,
        units: int = 128,
        activation: keras.layers.Activation | str | None = 'relu',
        use_bias: bool = True,
        normalize: bool = False,
        skip_connect: bool = True,
        **kwargs
    ) -> None:
        super().__init__(
            units=units,
            activation=activation,
            use_bias=use_bias,
            normalize=normalize,
            skip_connect=skip_connect,
            **kwargs
        )

    def build(self, spec: tensors.GraphTensor.Spec) -> None:
        """Creates the GRU cell and, if needed, the state projection."""
        input_dim = spec.node['feature'].shape[-1]
        self.update_fn = keras.layers.GRUCell(self.units)
        # The GRU state must have `units` features.
        self._project_previous_node_feature = input_dim != self.units
        if not self._project_previous_node_feature:
            return
        warnings.warn(
            'Inputted node feature dim does not match updated node feature dim, '
            'which is required for the GRU update. Applying a projection layer to '
            'the inputted node features prior to the GRU update, to match dim '
            'of the updated node feature dim.'
        )
        self._previous_node_dense = self.get_dense(self.units)

    def aggregate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
        """Mean-aggregates the edge messages into each node.

        Subclasses may override this method.

        Arguments:
            tensor:
                A `GraphTensor` whose edges carry a `message` field.

        Returns:
            The `GraphTensor` with a node `aggregate` field and the edge
            `message` field removed.
        """
        node_changes = {'aggregate': tensor.aggregate('message', mode='mean')}
        edge_changes = {'message': None}
        return tensor.update({'node': node_changes, 'edge': edge_changes})

    def update(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
        """Runs one GRU step per node (input: aggregate, state: feature)."""
        state = tensor.node['feature']
        if self._project_previous_node_feature:
            state = self._previous_node_dense(state)
        new_feature, _ = self.update_fn(
            inputs=tensor.node['aggregate'], states=state
        )
        return tensor.update(
            {'node': {'feature': new_feature, 'aggregate': None}}
        )

    def get_config(self) -> dict:
        # No constructor arguments beyond those handled by `GraphConv`.
        return super().get_config()
|
|
971
|
+
|
|
972
|
+
|
|
973
|
+
@keras.saving.register_keras_serializable(package='molcraft')
class GTConv(GraphConv):

    """Graph transformer layer.

    Multi-head scaled dot-product attention over edges. Queries and values
    are taken from edge sources and keys from edge targets. Scores are
    normalized with `ops.edge_softmax` over edges grouped by `target`.
    Optional additive attention biases come from edge features and, for 3D
    graphs, from learned Gaussian expansions of edge lengths.

    Also supports 3D molecular graphs. When node coordinates are present:
    the Gaussian edge features are summed per target node and added to the
    node features as a centrality term; attention weights are additionally
    scaled by the normalized displacement between edge endpoints; and
    `aggregate` sums out the resulting extra axis.

    Arguments:
        units (int):
            Dimension of the updated node features. Must be divisible by
            `heads`.
        heads (int):
            Number of attention heads. Must be a positive integer.
        activation (keras.layers.Activation, str, None):
            Forwarded to `GraphConv`.
        use_bias (bool):
            Forwarded to `GraphConv`.
        normalize (bool):
            Forwarded to `GraphConv`.
        skip_connect (bool):
            Forwarded to `GraphConv`.
        attention_dropout (float):
            Dropout rate applied to the attention weights after the softmax.
        **kwargs:
            Forwarded to `GraphConv`.

    Raises:
        ValueError: if `heads` is not positive or `units` is not divisible
            by `heads`.

    >>> graph = molcraft.tensors.GraphTensor(
    ...     context={
    ...         'size': [2]
    ...     },
    ...     node={
    ...         'feature': [[1.], [2.]]
    ...     },
    ...     edge={
    ...         'source': [0, 1],
    ...         'target': [1, 0],
    ...     }
    ... )
    >>> conv = molcraft.layers.GTConv(units=4, heads=2)
    >>> conv(graph)
    GraphTensor(
      context={
        'size': <tf.Tensor: shape=[1], dtype=int32>
      },
      node={
        'feature': <tf.Tensor: shape=[2, 4], dtype=float32>
      },
      edge={
        'source': <tf.Tensor: shape=[2], dtype=int32>,
        'target': <tf.Tensor: shape=[2], dtype=int32>,
      }
    )
    """

    def __init__(
        self,
        units: int,
        heads: int = 8,
        activation: keras.layers.Activation | str | None = "relu",
        use_bias: bool = True,
        normalize: bool = False,
        skip_connect: bool = True,
        attention_dropout: float = 0.0,
        **kwargs,
    ) -> None:
        super().__init__(
            units=units,
            activation=activation,
            normalize=normalize,
            use_bias=use_bias,
            skip_connect=skip_connect,
            **kwargs
        )
        # Validate before the modulo below; heads=0 would otherwise surface
        # as an opaque ZeroDivisionError.
        if heads < 1:
            raise ValueError(
                f"heads need to be a positive integer, got heads={heads}."
            )
        self._heads = heads
        if self.units % self.heads != 0:
            raise ValueError(
                f"units need to be divisible by heads, got units={self.units} "
                f"and heads={self.heads}."
            )
        self._head_units = self.units // self.heads
        self._attention_dropout = attention_dropout

    @property
    def heads(self):
        """Number of attention heads."""
        return self._heads

    @property
    def head_units(self):
        """Number of features per head (`units // heads`)."""
        return self._head_units

    def build(self, spec: tensors.GraphTensor.Spec) -> None:
        """Builds the layer.

        Creates the query/key/value projections, the output projection, the
        attention dropout and, depending on the input, the edge-feature
        attention bias and the Gaussian distance-kernel weights.
        """
        self._query_dense = self.get_einsum_dense(
            'ij,jkh->ikh', (self.head_units, self.heads)
        )
        self._key_dense = self.get_einsum_dense(
            'ij,jkh->ikh', (self.head_units, self.heads)
        )
        self._value_dense = self.get_einsum_dense(
            'ij,jkh->ikh', (self.head_units, self.heads)
        )
        self._output_dense = self.get_dense(self.units)
        self._softmax_dropout = keras.layers.Dropout(self._attention_dropout)

        if self.has_edge_feature:
            # One additive attention bias per head from edge features.
            self._attention_bias_dense_1 = self.get_einsum_dense('ij,jkh->ikh', (1, self.heads))

        if self.has_node_coordinate:
            node_feature_dim = spec.node['feature'].shape[-1]
            num_kernels = self.units
            # Learnable locations and scales of the Gaussian distance kernels.
            self._gaussian_loc = self.add_weight(
                shape=[num_kernels], initializer='zeros', dtype='float32', trainable=True
            )
            self._gaussian_scale = self.add_weight(
                shape=[num_kernels], initializer='ones', dtype='float32', trainable=True
            )
            # Projects the summed Gaussian features back to the node feature dim
            # so they can be added to the node features.
            self._centrality_dense = self.get_dense(units=node_feature_dim)
            self._attention_bias_dense_2 = self.get_einsum_dense('ij,jkh->ikh', (1, self.heads))

    def message(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
        """Computes attention-weighted value messages along every edge."""
        node_feature = tensor.node['feature']

        if self.has_node_coordinate:
            euclidean_distance = ops.euclidean_distance(
                tensor.gather('coordinate', 'target'),
                tensor.gather('coordinate', 'source'),
                axis=-1,
                keepdims=True
            )
            gaussian = ops.gaussian(
                euclidean_distance, self._gaussian_loc, self._gaussian_scale
            )
            # Centrality term: Gaussian edge features summed per target node.
            centrality = keras.ops.segment_sum(gaussian, tensor.edge['target'], tensor.num_nodes)
            node_feature += self._centrality_dense(centrality)

        query = self._query_dense(node_feature)
        key = self._key_dense(node_feature)
        value = self._value_dense(node_feature)

        query = ops.gather(query, tensor.edge['source'])
        key = ops.gather(key, tensor.edge['target'])
        value = ops.gather(value, tensor.edge['source'])

        # Scaled dot product per head (reduced over the per-head feature axis).
        attention_score = keras.ops.sum(query * key, axis=1, keepdims=True)
        attention_score /= keras.ops.sqrt(float(self.head_units))

        if self.has_edge_feature:
            attention_score += self._attention_bias_dense_1(tensor.edge['feature'])

        if self.has_node_coordinate:
            attention_score += self._attention_bias_dense_2(gaussian)

        attention = ops.edge_softmax(attention_score, tensor.edge['target'])
        attention = self._softmax_dropout(attention)

        if self.has_node_coordinate:
            displacement = ops.displacement(
                tensor.gather('coordinate', 'target'),
                tensor.gather('coordinate', 'source'),
                normalize=True,
                axis=-1,
                keepdims=True,
            )
            # Introduces an extra displacement axis that `aggregate` later
            # sums out after the output projection.
            attention *= keras.ops.expand_dims(displacement, axis=-1)
            attention = keras.ops.expand_dims(attention, axis=2)
            value = keras.ops.expand_dims(value, axis=1)

        message = ops.edge_weight(value, attention)

        return tensor.update(
            {
                'edge': {
                    'message': message,
                },
            }
        )

    def aggregate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
        """Sums messages per node, flattens heads and applies the output projection."""
        node_feature = tensor.aggregate('message', mode='sum')
        if self.has_node_coordinate:
            # Keep the extra displacement axis until after the projection.
            shape = (tensor.num_nodes, -1, self.units)
        else:
            shape = (tensor.num_nodes, self.units)
        node_feature = keras.ops.reshape(node_feature, shape)
        node_feature = self._output_dense(node_feature)
        if self.has_node_coordinate:
            node_feature = keras.ops.sum(node_feature, axis=1)
        return tensor.update(
            {
                'node': {
                    'aggregate': node_feature,
                },
                'edge': {
                    'message': None,
                }
            }
        )

    def get_config(self) -> dict:
        config = super().get_config()
        config.update({
            "heads": self._heads,
            'attention_dropout': self._attention_dropout,
        })
        return config
|
|
1157
|
+
|
|
1158
|
+
|
|
1159
|
+
@keras.saving.register_keras_serializable(package='molcraft')
class EGConv(GraphConv):

    """Equivariant graph neural network layer (3D).

    Only supports 3D molecular graphs (requires node `coordinate`s).

    For every edge, a message is computed by a two-layer MLP from the target
    and source node features, their squared distance and (if present) the
    edge feature. A second MLP maps the message to a `tanh`-bounded scalar
    that scales the relative coordinate (target minus source). In
    `aggregate`, node coordinates are shifted by the mean of these scaled
    relative coordinates, and the mean message is concatenated with the node
    feature to form the node `aggregate`.

    Arguments:
        units (int):
            Width of the message and coordinate MLP layers.
        activation (keras.layers.Activation, str, None):
            Activation of the message and coordinate MLP layers (except the
            final coordinate layer, which uses `tanh`). Also forwarded to
            `GraphConv`.
        use_bias (bool):
            Forwarded to `GraphConv`.
        normalize (bool):
            Forwarded to `GraphConv`.
        skip_connect (bool):
            Forwarded to `GraphConv`.
        **kwargs:
            Forwarded to `GraphConv`.

    >>> graph = molcraft.tensors.GraphTensor(
    ...     context={
    ...         'size': [2]
    ...     },
    ...     node={
    ...         'feature': [[1.], [2.]],
    ...         'coordinate': [[0.1, -0.1, 0.5], [1.2, -0.5, 2.1]],
    ...     },
    ...     edge={
    ...         'source': [0, 1],
    ...         'target': [1, 0],
    ...     }
    ... )
    >>> conv = molcraft.layers.EGConv(units=4)
    >>> conv(graph)
    GraphTensor(
      context={
        'size': <tf.Tensor: shape=[1], dtype=int32>
      },
      node={
        'feature': <tf.Tensor: shape=[2, 4], dtype=float32>,
        'coordinate': <tf.Tensor: shape=[2, 3], dtype=float32>
      },
      edge={
        'source': <tf.Tensor: shape=[2], dtype=int32>,
        'target': <tf.Tensor: shape=[2], dtype=int32>
      }
    )
    """

    def __init__(
        self,
        units: int = 128,
        activation: keras.layers.Activation | str | None = 'silu',
        use_bias: bool = True,
        normalize: bool = False,
        skip_connect: bool = True,
        **kwargs
    ) -> None:
        super().__init__(
            units=units,
            activation=activation,
            use_bias=use_bias,
            normalize=normalize,
            skip_connect=skip_connect,
            **kwargs
        )

    def build(self, spec: tensors.GraphTensor.Spec) -> None:
        """Creates the message and coordinate MLPs.

        Raises:
            ValueError: if the input graph has no node coordinates.
        """
        if not self.has_node_coordinate:
            raise ValueError(
                'Could not find `coordinate`s in node, '
                'which is required for Conv3D layers.'
            )
        # Two-layer MLP producing the edge message.
        self._message_feedforward_intermediate = self.get_dense(
            self.units, activation=self.activation
        )
        self._message_feedforward_final = self.get_dense(
            self.units, activation=self.activation
        )

        # Two-layer MLP mapping a message to one scalar weight per edge; the
        # final `tanh` keeps coordinate updates bounded (see `message`).
        self._coord_feedforward_intermediate = self.get_dense(
            self.units, activation=self.activation
        )
        self._coord_feedforward_final = self.get_dense(
            1, use_bias=False, activation='tanh'
        )

    def message(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
        """Computes edge messages and scaled relative coordinates.

        Adds the edge fields `message` and `relative_node_coordinate`.
        """
        # Relative coordinate of each edge: target minus source.
        relative_node_coordinate = keras.ops.subtract(
            tensor.gather('coordinate', 'target'),
            tensor.gather('coordinate', 'source')
        )
        squared_distance = keras.ops.sum(
            keras.ops.square(relative_node_coordinate),
            axis=-1,
            keepdims=True
        )

        # For numerical stability (i.e., to prevent NaN losses), this implementation of `EGConv`
        # either needs to apply a `tanh` activation to the output of `self._coord_feedforward_final`,
        # or normalize `relative_node_coordinate` as follows:
        #
        #   norm = keras.ops.sqrt(squared_distance) + keras.backend.epsilon()
        #   relative_node_coordinate /= norm
        #
        # For now, this implementation does the former.

        feature = keras.ops.concatenate(
            [
                tensor.gather('feature', 'target'),
                tensor.gather('feature', 'source'),
                squared_distance,
            ],
            axis=-1
        )
        if self.has_edge_feature:
            feature = keras.ops.concatenate(
                [
                    feature,
                    tensor.edge['feature']
                ],
                axis=-1
            )
        message = self._message_feedforward_final(
            self._message_feedforward_intermediate(feature)
        )

        # Scale each relative coordinate by a learned, tanh-bounded scalar
        # derived from the edge message.
        relative_node_coordinate = keras.ops.multiply(
            relative_node_coordinate,
            self._coord_feedforward_final(
                self._coord_feedforward_intermediate(message)
            )
        )
        return tensor.update(
            {
                'edge': {
                    'message': message,
                    'relative_node_coordinate': relative_node_coordinate
                }
            }
        )

    def aggregate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
        """Updates node coordinates and builds the node `aggregate`.

        Removes the temporary edge fields created by `message`.
        """
        coordinate = tensor.node['coordinate']
        coordinate += tensor.aggregate('relative_node_coordinate', mode='mean')

        # The original EGNN formulation applies a sum aggregation, which does
        # not seem to work well for this implementation of `EGConv`: it causes
        # large output values and large initial losses. The magnitude of a sum
        # aggregation depends on the number of neighbors, which may be large
        # and may differ from node to node (or graph to graph). Therefore, a
        # mean aggregation is performed instead:
        aggregate = tensor.aggregate('message', mode='mean')
        aggregate = keras.ops.concatenate([aggregate, tensor.node['feature']], axis=-1)
        # Simply added to silence warning ('no gradients for variables ...').
        # The zero-weighted term presumably ties the coordinate MLP into the
        # feature path so its variables appear in the gradient graph.
        aggregate += (0.0 * keras.ops.sum(coordinate))

        return tensor.update(
            {
                'node': {
                    'aggregate': aggregate,
                    'coordinate': coordinate,
                },
                'edge': {
                    'message': None,
                    'relative_node_coordinate': None
                }
            }
        )

    def get_config(self) -> dict:
        config = super().get_config()
        config.update({})
        return config
|
|
1322
|
+
|
|
1323
|
+
|
|
1324
|
+
@keras.saving.register_keras_serializable(package='molcraft')
class Readout(GraphLayer):

    """Readout layer.

    Segment-reduces the node features into one vector per graph, using
    `tensor.graph_indicator` as segment ids and `tensor.num_subgraphs` as the
    number of segments.

    Arguments:
        mode (str, None):
            Reduction to apply, matched case-insensitively by prefix:
            'sum...' -> segment sum, 'max...' -> segment max, anything else
            (including None) -> segment mean.
    """

    def __init__(self, mode: str | None = None, **kwargs):
        # The layer has no kernels or biases, so the initializers are forced
        # to None regardless of what the caller passed.
        kwargs.update(kernel_initializer=None, bias_initializer=None)
        super().__init__(**kwargs)
        self.mode = mode
        mode_key = str(self.mode).lower()
        if mode_key.startswith('sum'):
            reducer = keras.ops.segment_sum
        elif mode_key.startswith('max'):
            reducer = keras.ops.segment_max
        else:
            reducer = ops.segment_mean
        self._reduce_fn = reducer

    def propagate(self, tensor: tensors.GraphTensor) -> tf.Tensor:
        """Returns the pooled node features, one row per graph."""
        node_feature = tensor.node['feature']
        segment_ids = tensor.graph_indicator
        return self._reduce_fn(node_feature, segment_ids, tensor.num_subgraphs)

    def get_config(self) -> dict:
        return {**super().get_config(), 'mode': self.mode}
|
|
1352
|
+
|
|
1353
|
+
|
|
1354
|
+
@keras.saving.register_keras_serializable(package='molcraft')
class SuperReadout(GraphLayer):

    """Readout layer that pools only super-node features.

    Node features of non-super nodes are zeroed out. The result is
    segment-summed per graph (segment ids `tensor.graph_indicator`,
    `tensor.num_subgraphs` segments).
    """

    def build(self, spec: tensors.GraphTensor.Spec) -> None:
        """Checks that the input marks super nodes via a node `super` field."""
        if 'super' in spec.node:
            return
        raise ValueError(
            'Could not find `super` field in input.'
        )

    def propagate(self, tensor: tensors.GraphTensor) -> tf.Tensor:
        """Returns the summed super-node features, one row per graph."""
        is_super = tensor.node['super'][:, None]
        super_only = keras.ops.where(is_super, tensor.node['feature'], 0.0)
        return keras.ops.segment_sum(
            super_only, tensor.graph_indicator, tensor.num_subgraphs
        )
|
|
1371
|
+
|
|
1372
|
+
|
|
1373
|
+
@keras.saving.register_keras_serializable(package='molcraft')
class SubgraphReadout(GraphLayer):

    """Readout layer that mean-pools node features per subgraph.

    Nodes are grouped by the node field `subgraph_indicator`, a subgraph
    index local to each graph. Each subgraph is mean-pooled, producing a
    variable number of pooled vectors per graph.

    Arguments:
        pad (bool):
            If True, the ragged per-graph result is converted with
            `RaggedTensor.to_tensor()`. Otherwise the `tf.RaggedTensor`
            (one row per graph) is returned.
        add_mask (bool, None):
            Whether `compute_mask` returns a mask (computed by
            `_ReadoutMask`). Defaults to the value of `pad`.
        ignore_super_node (bool):
            Super nodes (node field `super`) are always excluded from the
            subgraph means. If this is False, each graph's super-node feature
            is added to every pooled subgraph vector of that graph.
    """

    def __init__(
        self,
        pad: bool = True,
        add_mask: bool | None = None,
        ignore_super_node: bool = True,
        **kwargs
    ) -> None:
        super().__init__(**kwargs)
        self._pad = pad
        self._ignore_super_node = ignore_super_node
        self._add_mask = (
            add_mask if add_mask is not None else self._pad
        )
        if self._add_mask:
            self._readout_mask = _ReadoutMask()

    def build(self, spec: tensors.GraphTensor.Spec) -> None:
        """Checks for `subgraph_indicator` and records whether super nodes exist."""
        if 'subgraph_indicator' not in spec.node:
            raise ValueError(
                'Could not find `subgraph_indicator` field in input.'
            )
        self._has_super = 'super' in spec.node

    def propagate(self, tensor: tensors.GraphTensor) -> tf.Tensor:

        size = tensor.context['size']
        graph_indicator = tensor.graph_indicator
        subgraph_indicator = tensor.node['subgraph_indicator']
        node_feature = tensor.node['feature']

        if self._has_super:
            if not self._ignore_super_node:
                # Set the super-node features aside; they are added back to
                # the pooled subgraph vectors below.
                super_node_feature = tf.boolean_mask(
                    node_feature, tensor.node['super']
                )
            # Drop super nodes so they never contribute to subgraph means.
            keep = (tensor.node['super'] == False)
            graph_indicator = tf.boolean_mask(graph_indicator, keep)
            subgraph_indicator = tf.boolean_mask(subgraph_indicator, keep)
            node_feature = tf.boolean_mask(node_feature, keep)
            # NOTE(review): assumes exactly one super node per graph.
            size -= 1

        # Subgraphs per graph: largest local subgraph index + 1.
        num_subgraphs = keras.ops.segment_max(
            subgraph_indicator, graph_indicator
        )
        num_subgraphs += 1

        def global_subgraph_indicator():
            # Offset each graph's local subgraph indices by the number of
            # subgraphs in all preceding graphs, so indices are batch-unique.
            incr = keras.ops.cumsum(num_subgraphs[:-1])
            incr = keras.ops.pad(incr, [(1, 0)])
            incr = keras.ops.repeat(incr, size)
            return subgraph_indicator + incr

        readout = ops.segment_mean(node_feature, global_subgraph_indicator())
        # Regroup the flat per-subgraph vectors into one row per graph.
        readout = tf.RaggedTensor.from_row_lengths(readout, num_subgraphs)

        if self._has_super and not self._ignore_super_node:
            # Broadcast each graph's super-node feature over its subgraphs.
            readout += super_node_feature[:, None, :]

        if not self._pad:
            return readout

        return readout.to_tensor()

    def compute_mask(self, inputs, previous_mask=None):
        """Returns the readout mask, or None if `add_mask` is False."""
        if not self._add_mask:
            return None
        return self._readout_mask(inputs)

    def get_config(self) -> dict:
        config = super().get_config()
        config['pad'] = self._pad
        config['add_mask'] = self._add_mask
        config['ignore_super_node'] = self._ignore_super_node
        return config
|
|
1450
|
+
|
|
1451
|
+
|
|
1452
|
+
@keras.saving.register_keras_serializable(package='molcraft')
class NodeEmbedding(GraphLayer):

    """Node embedding layer.

    Embeds nodes based on their initial features. The node features are
    passed through an intermediate dense layer. Super nodes and wildcard
    nodes then have their intermediate embedding replaced by a learned one,
    and the result is normalized and projected to `dim`.

    Arguments:
        dim (int, None):
            Output dimension. Defaults to the input node feature dimension.
        intermediate_dim (int, None):
            Width of the intermediate embedding. Defaults to `2 * dim`.
        intermediate_activation (str, keras.layers.Activation, None):
            Activation of the intermediate dense layer(s). Also applied to
            the learned super-node and wildcard embeddings.
        normalize (bool, str):
            False: no normalization. A string starting with 'layer': layer
            normalization. Any other truthy value: batch normalization.
        embed_context (bool):
            If True, the context `feature` is embedded and written to the super
            node positions, and the context `feature` field is removed from
            the output. Ignored if the input has no context `feature`.
            Requires super nodes (node field `super`).
        num_wildcards (int, None):
            Number of distinct learned wildcard embeddings. A node with
            wildcard value `w > 0` uses embedding `(w - 1) mod num_wildcards`.
            Set to 1 (with a warning) if wildcards are present but this is None.

    Raises:
        ValueError: in `build`, if `embed_context` is in effect but the input
            has no super nodes.
    """

    def __init__(
        self,
        dim: int | None = None,
        intermediate_dim: int | None = None,
        intermediate_activation: str | keras.layers.Activation | None = 'relu',
        normalize: bool = False,
        embed_context: bool = False,
        num_wildcards: int | None = None,
        **kwargs
    ) -> None:
        super().__init__(**kwargs)
        self.dim = dim
        self._intermediate_dim = intermediate_dim
        self._intermediate_activation = keras.activations.get(
            intermediate_activation
        )
        self._normalize = normalize
        self._embed_context = embed_context
        self._num_wildcards = num_wildcards

    def build(self, spec: tensors.GraphTensor.Spec) -> None:
        """Resolves defaults from the input spec and creates the weights."""
        feature_dim = spec.node['feature'].shape[-1]
        if not self.dim:
            self.dim = feature_dim
        if not self._intermediate_dim:
            self._intermediate_dim = self.dim * 2
        self._node_dense = self.get_dense(
            self._intermediate_dim, activation=self._intermediate_activation
        )
        self._has_wildcard = 'wildcard' in spec.node
        self._has_super = 'super' in spec.node
        has_context_feature = 'feature' in spec.context
        if not has_context_feature:
            self._embed_context = False
        # The embedded context is scattered into the super-node positions, so
        # super nodes are required; fail here rather than on a missing `super`
        # field at call time.
        if self._embed_context and not self._has_super:
            raise ValueError(
                '`embed_context=True` requires super nodes (a `super` node '
                'field) in the input, as the embedded context is written to '
                'the super nodes.'
            )
        if self._has_super and not self._embed_context:
            self._super_feature = self.get_weight(
                shape=[self._intermediate_dim], name='super_node_feature'
            )
        if self._embed_context:
            self._context_dense = self.get_dense(
                self._intermediate_dim, activation=self._intermediate_activation
            )
        if self._has_wildcard:
            if self._num_wildcards is None:
                warnings.warn(
                    "Found wildcards in input, but `num_wildcards` is set to `None`. "
                    "Automatically setting `num_wildcards` to 1. Please specify `num_wildcards>1` "
                    "if the layer should distinguish between different types of wildcards."
                )
                self._num_wildcards = 1
            self._wildcard_features = self.get_weight(
                shape=[self._num_wildcards, self._intermediate_dim],
                name='wildcard_node_features'
            )
        if not self._normalize:
            self._norm = keras.layers.Identity()
        elif str(self._normalize).lower().startswith('layer'):
            self._norm = keras.layers.LayerNormalization()
        else:
            self._norm = keras.layers.BatchNormalization()
        self._dense = self.get_dense(self.dim)

    def propagate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
        """Returns the graph with embedded node features."""
        feature = self._node_dense(tensor.node['feature'])

        if self._has_super and not self._embed_context:
            # Super nodes get a learned embedding instead of their projection.
            super_mask = keras.ops.expand_dims(tensor.node['super'], 1)
            super_feature = self._intermediate_activation(self._super_feature)
            feature = keras.ops.where(super_mask, super_feature, feature)

        if self._embed_context:
            context_feature = self._context_dense(tensor.context['feature'])
            feature = ops.scatter_update(feature, tensor.node['super'], context_feature)
            tensor = tensor.update({'context': {'feature': None}})

        if self._has_wildcard:
            wildcard = tensor.node['wildcard']
            # Only nodes with a positive wildcard value are replaced.
            wildcard_indices = keras.ops.where(wildcard > 0)[0]
            wildcard = ops.gather(wildcard, wildcard_indices)
            wildcard_feature = ops.gather(
                self._wildcard_features, keras.ops.mod(wildcard-1, self._num_wildcards)
            )
            wildcard_feature = self._intermediate_activation(wildcard_feature)
            feature = ops.scatter_update(feature, wildcard_indices, wildcard_feature)

        feature = self._norm(feature)
        feature = self._dense(feature)

        return tensor.update({'node': {'feature': feature}})

    def get_config(self) -> dict:
        config = super().get_config()
        config.update({
            'dim': self.dim,
            'intermediate_dim': self._intermediate_dim,
            'intermediate_activation': keras.activations.serialize(
                self._intermediate_activation
            ),
            'normalize': self._normalize,
            'embed_context': self._embed_context,
            'num_wildcards': self._num_wildcards,
        })
        return config
|
|
1563
|
+
|
|
1564
|
+
|
|
1565
|
+
@keras.saving.register_keras_serializable(package='molcraft')
class EdgeEmbedding(GraphLayer):

    """Edge embedding layer.

    Embeds edges based on their initial features. The edge features are
    passed through an intermediate dense layer. Super edges (edge field
    `super`) and self-loops (`source == target`) then have their
    intermediate embedding replaced by a learned one, and the result is
    normalized and projected to `dim`.

    Arguments:
        dim (int, None):
            Output dimension. Defaults to the input edge feature dimension.
        intermediate_dim (int, None):
            Width of the intermediate embedding. Defaults to `2 * dim`.
        intermediate_activation (str, keras.layers.Activation, None):
            Activation of the intermediate dense layer. Also applied to the
            learned super-edge and self-loop embeddings.
        normalize (bool, str):
            False: no normalization. A string starting with 'layer': layer
            normalization. Any other truthy value: batch normalization.

    Raises:
        ValueError: in `build`, if the input has no edge `feature` field.
    """

    def __init__(
        self,
        dim: int | None = None,
        intermediate_dim: int | None = None,
        intermediate_activation: str | keras.layers.Activation | None = 'relu',
        normalize: bool = False,
        **kwargs
    ) -> None:
        super().__init__(**kwargs)
        self.dim = dim
        self._intermediate_dim = intermediate_dim
        self._intermediate_activation = keras.activations.get(
            intermediate_activation
        )
        self._normalize = normalize

    def build(self, spec: tensors.GraphTensor.Spec) -> None:
        """Resolves defaults from the input spec and creates the weights."""
        if 'feature' not in spec.edge:
            raise ValueError(
                'Could not find `feature` field in edge.'
            )
        feature_dim = spec.edge['feature'].shape[-1]
        if not self.dim:
            self.dim = feature_dim
        if not self._intermediate_dim:
            self._intermediate_dim = self.dim * 2
        self._edge_dense = self.get_dense(
            self._intermediate_dim, activation=self._intermediate_activation
        )
        self._self_loop_feature = self.get_weight(
            shape=[self._intermediate_dim], name='self_loop_edge_feature'
        )
        self._has_super = 'super' in spec.edge
        if self._has_super:
            self._super_feature = self.get_weight(
                shape=[self._intermediate_dim], name='super_edge_feature'
            )
        if not self._normalize:
            self._norm = keras.layers.Identity()
        elif str(self._normalize).lower().startswith('layer'):
            self._norm = keras.layers.LayerNormalization()
        else:
            self._norm = keras.layers.BatchNormalization()
        self._dense = self.get_dense(self.dim)

    def propagate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
        """Returns the graph with embedded edge features."""
        feature = self._edge_dense(tensor.edge['feature'])

        if self._has_super:
            super_mask = keras.ops.expand_dims(tensor.edge['super'], 1)
            super_feature = self._intermediate_activation(self._super_feature)
            feature = keras.ops.where(super_mask, super_feature, feature)

        # Self-loops are replaced last, so they take precedence over super edges.
        self_loop_mask = keras.ops.expand_dims(tensor.edge['source'] == tensor.edge['target'], 1)
        self_loop_feature = self._intermediate_activation(self._self_loop_feature)
        feature = keras.ops.where(self_loop_mask, self_loop_feature, feature)
        feature = self._norm(feature)
        feature = self._dense(feature)
        return tensor.update({'edge': {'feature': feature}})

    def get_config(self) -> dict:
        config = super().get_config()
        config.update({
            'dim': self.dim,
            'intermediate_dim': self._intermediate_dim,
            'intermediate_activation': keras.activations.serialize(
                self._intermediate_activation
            ),
            'normalize': self._normalize,
        })
        return config
|
|
1640
|
+
|
|
1641
|
+
|
|
1642
|
+
@keras.saving.register_keras_serializable(package='molcraft')
class AddContext(GraphLayer):
    """Context adding layer.

    Adds context to super nodes or nodes, depending on whether super nodes
    exist.

    First, the context field is encoded: it is one-hot encoded when
    `num_categories` is set, and a rank-1 context gets a trailing feature
    axis. The encoding is projected by an intermediate dense layer,
    optionally normalized, and projected to the node feature dimension.

    - If the graph has a `super` node field, the projected context is passed
      to `ops.scatter_add` together with that field. Presumably this adds
      each graph's context to its super node(s); confirm against `ops`.
    - Otherwise, the context row of each node's graph (via
      `graph_indicator`) is added to every node feature.

    Arguments:
        field (str):
            The context field to add. Default to 'feature'.
        intermediate_dim (int, None):
            Dimension of the intermediate projection. If `None`, defaults
            to `2 * node_feature_dim`.
        intermediate_activation (str, keras.layers.Activation, None):
            Activation of the intermediate projection. Default to 'relu'.
        drop (bool):
            Whether to drop the context field from the output, by updating
            it to `None`. Default to False.
        normalize (bool, str):
            Whether to normalize the intermediate context. A string starting
            with 'layer' selects `LayerNormalization`; any other truthy value
            selects `BatchNormalization`. Default to False.
        num_categories (int, None):
            Number of categories used to one-hot encode the context.
            Required when the context field has an integer dtype.
    """

    def __init__(
        self,
        field: str = 'feature',
        intermediate_dim: int | None = None,
        intermediate_activation: str | keras.layers.Activation | None = 'relu',
        drop: bool = False,
        normalize: bool | str = False,
        num_categories: int | None = None,
        **kwargs
    ) -> None:
        super().__init__(**kwargs)
        self._field = field
        self._drop = drop
        self._intermediate_dim = intermediate_dim
        self._intermediate_activation = keras.activations.get(
            intermediate_activation
        )
        self._normalize = normalize
        self._num_categories = num_categories

    def build(self, spec: tensors.GraphTensor.Spec) -> None:
        is_categorical = spec.context[self._field].dtype.is_integer
        if is_categorical and not self._num_categories:
            raise ValueError(
                f'Found context ({self._field}) to be categorical (`int` dtype), but `num_categories` '
                'is `None`. Please specify `num_categories` for categorical context.'
            )
        elif not is_categorical and self._num_categories:
            warnings.warn(
                f'`num_categories` is set to {self._num_categories}, but found context to be '
                'continuous (`float` dtype). Layer will cast context from `float` to `int` '
                'before one-hot encoding it.'
            )
        feature_dim = spec.node['feature'].shape[-1]
        self._has_super_node = 'super' in spec.node
        if self._intermediate_dim is None:
            self._intermediate_dim = feature_dim * 2
        # NOTE: attribute names and creation order are kept stable, since
        # saved weights are matched against them.
        self._intermediate_dense = self.get_dense(
            self._intermediate_dim, activation=self._intermediate_activation
        )
        # The final projection maps the context to the node feature dim, so
        # it can be added to node features.
        self._final_dense = self.get_dense(feature_dim)
        if not self._normalize:
            self._intermediate_norm = keras.layers.Identity()
        elif str(self._normalize).lower().startswith('layer'):
            self._intermediate_norm = keras.layers.LayerNormalization()
        else:
            self._intermediate_norm = keras.layers.BatchNormalization()

    def propagate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
        context = tensor.context[self._field]
        if self._num_categories:
            if context.dtype.is_floating:
                # Cast float context to the integer dtype used for edge
                # indices, so that it can be one-hot encoded.
                context = keras.ops.cast(context, dtype=tensor.edge['source'].dtype)
            context = keras.utils.to_categorical(context, self._num_categories)
        elif len(keras.ops.shape(context)) == 1:
            # A rank-1 continuous context gets a trailing feature axis.
            context = keras.ops.expand_dims(context, axis=1)
        context = self._intermediate_dense(context)
        context = self._intermediate_norm(context)
        context = self._final_dense(context)
        if self._has_super_node:
            node_feature = ops.scatter_add(
                tensor.node['feature'], tensor.node['super'], context
            )
        else:
            # Broadcast each graph's context to all of its nodes.
            node_feature = (
                tensor.node['feature'] + ops.gather(context, tensor.graph_indicator)
            )
        data = {'node': {'feature': node_feature}}
        if self._drop:
            data['context'] = {self._field: None}
        return tensor.update(data)

    def get_config(self) -> dict:
        config = super().get_config()
        config.update({
            'field': self._field,
            'intermediate_dim': self._intermediate_dim,
            'intermediate_activation': keras.activations.serialize(
                self._intermediate_activation
            ),
            'drop': self._drop,
            'normalize': self._normalize,
            'num_categories': self._num_categories,
        })
        return config
|
|
1735
|
+
|
|
1736
|
+
|
|
1737
|
+
@keras.saving.register_keras_serializable(package='molcraft')
class GraphNetwork(GraphLayer):
    """Graph neural network.

    Sequentially calls graph layers (`GraphLayer`) and concatenates their
    outputs.

    The output node features are the concatenation, along the last axis,
    of the node features entering the first layer and the node features
    produced by each layer. If the input node or edge feature dimension
    differs from `units` of the first layer, those features are first
    projected to `units` by a dense layer.

    Arguments:
        layers (list):
            A list of graph layers. The first layer must expose a `units`
            attribute.
    """

    def __init__(self, layers: list[GraphLayer], **kwargs) -> None:
        super().__init__(**kwargs)
        self.layers = layers
        # Stays False when the input graph has no edge features.
        self._update_edge_feature = False

    def build(self, spec: tensors.GraphTensor.Spec) -> None:
        units = self.layers[0].units
        node_feature_dim = spec.node['feature'].shape[-1]
        self._update_node_feature = node_feature_dim != units
        if self._update_node_feature:
            warnings.warn(
                'Node feature dim does not match `units` of the first layer. '
                'Applying a projection layer to node features to match `units`.',
            )
            self._node_dense = self.get_dense(units)
        self._has_edge_feature = 'feature' in spec.edge
        if self._has_edge_feature:
            edge_feature_dim = spec.edge['feature'].shape[-1]
            self._update_edge_feature = edge_feature_dim != units
            if self._update_edge_feature:
                warnings.warn(
                    'Edge feature dim does not match `units` of the first layer. '
                    'Applying projection layer to edge features to match `units`.'
                )
                self._edge_dense = self.get_dense(units)

    def propagate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
        x = tensors.to_dict(tensor)
        if self._update_node_feature:
            x['node']['feature'] = self._node_dense(tensor.node['feature'])
        if self._has_edge_feature and self._update_edge_feature:
            x['edge']['feature'] = self._edge_dense(tensor.edge['feature'])
        outputs = [x['node']['feature']]
        for layer in self.layers:
            x = layer(x)
            outputs.append(x['node']['feature'])
        return tensor.update(
            {
                'node': {
                    'feature': keras.ops.concatenate(outputs, axis=-1)
                }
            }
        )

    def tape_propagate(
        self,
        tensor: tensors.GraphTensor,
        tape: tf.GradientTape,
        training: bool | None = None,
    ) -> tuple[tensors.GraphTensor, list[tf.Tensor]]:
        """Performs the propagation with a `GradientTape`.

        Performs the same forward pass as `propagate`, but with a
        `GradientTape` watching the intermediate node features.

        Arguments:
            tensor (tensors.GraphTensor, dict):
                The graph input, either a `GraphTensor` or its nested dict
                representation (`{outer_field: {inner_field: value}}`).
            tape (tf.GradientTape):
                The tape that watches the intermediate node features.
            training (bool, None):
                Passed to each graph layer.

        Returns:
            A tuple `(output, outputs)`. `output` has the same kind as the
            input (`GraphTensor` or dict), with its node features replaced
            by the concatenated node features. `outputs` is the list of
            watched node features: the (possibly projected) input node
            features, followed by one entry per layer.
        """
        is_graph_tensor = isinstance(tensor, tensors.GraphTensor)
        if is_graph_tensor:
            x = tensors.to_dict(tensor)
        else:
            # Shallow-copy the nested dicts, so the projections below do not
            # mutate the caller's input.
            x = {
                field: dict(data) if isinstance(data, dict) else data
                for field, data in tensor.items()
            }
        # Read the projection inputs from `x` rather than `tensor`, since
        # `tensor` may be a plain dict without `.node` / `.edge` attributes.
        if self._update_node_feature:
            x['node']['feature'] = self._node_dense(x['node']['feature'])
        if self._update_edge_feature:
            x['edge']['feature'] = self._edge_dense(x['edge']['feature'])
        tape.watch(x['node']['feature'])
        outputs = [x['node']['feature']]
        for layer in self.layers:
            x = layer(x, training=training)
            tape.watch(x['node']['feature'])
            outputs.append(x['node']['feature'])

        node_feature = keras.ops.concatenate(outputs, axis=-1)
        if is_graph_tensor:
            tensor = tensor.update({'node': {'feature': node_feature}})
        else:
            # `dict.update` returns None and would replace the whole 'node'
            # entry, so build the output dict explicitly instead.
            tensor = {
                **tensor,
                'node': {**tensor['node'], 'feature': node_feature},
            }
        return tensor, outputs

    def get_config(self) -> dict:
        config = super().get_config()
        config.update(
            {
                'layers': [
                    keras.layers.serialize(layer) for layer in self.layers
                ]
            }
        )
        return config

    @classmethod
    def from_config(cls, config: dict) -> 'GraphNetwork':
        # Copy so that deserializing does not mutate the caller's config.
        config = dict(config)
        config['layers'] = [
            keras.layers.deserialize(layer) for layer in config['layers']
        ]
        return super().from_config(config)
|
|
1849
|
+
|
|
1850
|
+
|
|
1851
|
+
@keras.saving.register_keras_serializable(package='molcraft')
class GaussianParams(keras.layers.Dense):
    '''Gaussian parameters.

    Computes loc and scale via a dense layer. Should be implemented
    as the last layer in a model and paired with `losses.GaussianNLL`.

    The loc and scale parameters (resulting from this layer) are concatenated
    together along the last axis, resulting in a single output tensor. The
    first `events` outputs are loc; the last `events` outputs are scale.

    Args:
        events (int):
            The number of events. If the model makes a single prediction per example,
            then the number of events should be 1. If the model makes multiple predictions
            per example, then the number of events should be greater than 1.
            Default to 1. If `units` is passed, `events` is derived as
            `units // 2`. A non-default `events` that disagrees with `units`
            raises a `ValueError`.
        kwargs:
            See `keras.layers.Dense` documentation. `activation` will be applied
            to `loc` only. `scale` is automatically softplus activated. If
            `units` is passed, it must be even (`units` = 2 x `events`).
    '''
    def __init__(self, events: int = 1, **kwargs):
        units = kwargs.pop('units', None)
        activation = kwargs.pop('activation', None)
        if units:
            if units % 2 != 0:
                raise ValueError(
                    '`units` needs to be divisible by 2 as `units` = 2 x `events`.'
                )
            implied_events = units // 2
            # `events` must agree with `units`, otherwise `call` would split
            # the output into loc/scale parts of different sizes.
            if events != 1 and events != implied_events:
                raise ValueError(
                    f'`events` ({events}) is inconsistent with `units` ({units}); '
                    f'expected `events` = `units` / 2 = {implied_events}.'
                )
            events = implied_events
        else:
            units = int(events * 2)
        super().__init__(units=units, **kwargs)
        self.events = events
        self.loc_activation = keras.activations.get(activation)

    def call(self, inputs, **kwargs):
        # The Dense layer itself is linear; activations are applied per half.
        loc_and_scale = super().call(inputs, **kwargs)
        loc = loc_and_scale[..., :self.events]
        scale = loc_and_scale[..., self.events:]
        # Softplus plus epsilon keeps the scale strictly positive.
        scale = keras.ops.softplus(scale) + keras.backend.epsilon()
        loc = self.loc_activation(loc)
        return keras.ops.concatenate([loc, scale], axis=-1)

    def get_config(self):
        config = super().get_config()
        config['events'] = self.events
        # `units` is derived from `events` on reconstruction.
        config['units'] = None
        config['activation'] = keras.activations.serialize(self.loc_activation)
        return config
|
|
1899
|
+
|
|
1900
|
+
|
|
1901
|
+
def Input(spec: tensors.GraphTensor.Spec) -> dict:
    """Used to specify inputs to a functional model.

    Unpacks a `GraphTensor.Spec` into a nested dictionary of symbolic
    `keras.Input` tensors, keyed as `{outer_field: {inner_field: input}}`.
    The `label` and `sample_weight` fields are excluded. The leading
    (batch) dimension of each nested spec is dropped from the input shape.

    Arguments:
        spec (tensors.GraphTensor.Spec):
            The spec of a 'flat' graph tensor; nested specs must not be
            ragged.

    Returns:
        A nested dictionary of symbolic inputs.

    Raises:
        ValueError: If a nested spec is a `tf.RaggedTensorSpec`, or if
            `keras.Input` rejects a nested spec with a `TypeError`.

    Example:

    >>> import molcraft
    >>> import keras
    >>>
    >>> featurizer = molcraft.featurizers.MolGraphFeaturizer()
    >>> graph = featurizer([('N[C@@H](C)C(=O)O', 1.0), ('N[C@@H](CS)C(=O)O', 2.0)])
    >>>
    >>> model = molcraft.models.GraphModel.from_layers([
    ...     molcraft.layers.Input(graph.spec),
    ...     molcraft.layers.NodeEmbedding(128),
    ...     molcraft.layers.EdgeEmbedding(128),
    ...     molcraft.layers.GraphTransformer(128),
    ...     molcraft.layers.GraphTransformer(128),
    ...     molcraft.layers.Readout('mean'),
    ...     molcraft.layers.Dense(1)
    ... ])
    """

    # Currently, Keras (3.8.0) does not support extension types.
    # So for now, this function will unpack the `GraphTensor.Spec` and
    # return a dictionary of nested tensor specs. However, the corresponding
    # nest of tensors will temporarily be converted to a `GraphTensor` by the
    # `GraphLayer`, to leverage the utility of a `GraphTensor` object.
    inputs = {}
    for outer_field, data in spec.__dict__.items():
        inputs[outer_field] = {}
        for inner_field, nested_spec in data.items():
            if inner_field in ['label', 'sample_weight']:
                # Remove label and sample_weight from the symbolic input as
                # a functional model is strict for what input can be passed.
                continue
            if isinstance(nested_spec, tf.RaggedTensorSpec):
                raise ValueError(
                    'Graph layers only supports graph input with nested `tf.Tensor` values.'
                )
            kwargs = {
                'shape': nested_spec.shape[1:],
                'dtype': nested_spec.dtype,
                'name': f'{outer_field}_{inner_field}'
            }
            try:
                inputs[outer_field][inner_field] = keras.Input(**kwargs)
            except TypeError as err:
                raise ValueError(
                    "`keras.Input` does not currently support ragged tensors. For now, "
                    "pass the `Spec` of a 'flat' `GraphTensor` to `Input`."
                ) from err
    return inputs
|
|
1954
|
+
|
|
1955
|
+
|
|
1956
|
+
class _ReadoutMask(GraphLayer):
    """Computes a per-graph boolean mask over subgraph slots.

    The output is built with `tf.sequence_mask`. It has one row per graph
    segment (as defined by `graph_indicator`) and `max_num_subgraphs`
    columns. Entry `[g, s]` is True when `s` is smaller than the number of
    subgraphs of graph `g`.

    The number of subgraphs of a graph is taken as the maximum
    `subgraph_indicator` among its nodes, plus one. This assumes 0-based,
    contiguous subgraph indices per graph (TODO confirm). Super nodes, if
    present, are excluded before counting.
    """

    def build(self, spec: tensors.GraphTensor.Spec) -> None:
        if 'subgraph_indicator' not in spec.node:
            raise ValueError(
                'Could not find `subgraph_indicator` field in input.'
            )
        self._has_super = 'super' in spec.node

    def propagate(self, tensor: tensors.GraphTensor) -> tf.Tensor:
        graph_indicator = tensor.graph_indicator
        subgraph_indicator = tensor.node['subgraph_indicator']

        if self._has_super:
            # Keep only regular (non-super) nodes. Casting first makes this
            # work for both boolean and 0/1 integer `super` fields.
            keep = keras.ops.logical_not(
                keras.ops.cast(tensor.node['super'], 'bool')
            )
            graph_indicator = tf.boolean_mask(graph_indicator, keep)
            subgraph_indicator = tf.boolean_mask(subgraph_indicator, keep)

        # Highest subgraph index per graph, plus one, gives the subgraph count.
        num_subgraphs = keras.ops.segment_max(
            subgraph_indicator, graph_indicator
        )
        num_subgraphs += 1
        max_len = keras.ops.max(num_subgraphs)
        return tf.sequence_mask(num_subgraphs, maxlen=max_len)
|
|
1986
|
+
|
|
1987
|
+
|
|
1988
|
+
def _serialize_spec(spec: tensors.GraphTensor.Spec) -> dict:
    """Converts a `GraphTensor.Spec` into a plain nested dictionary.

    The result maps `outer_field -> inner_field -> {'shape', 'dtype', 'name'}`:
    - `shape` is a list;
    - `dtype` is the dtype's name;
    - `name` is the spec's name.

    The `label` and `sample_weight` fields are omitted.
    """
    excluded_fields = ('label', 'sample_weight')

    def _describe(tensor_spec) -> dict:
        # Plain-Python description of a single nested tensor spec.
        return {
            'shape': tensor_spec.shape.as_list(),
            'dtype': tensor_spec.dtype.name,
            'name': tensor_spec.name,
        }

    return {
        outer_field: {
            inner_field: _describe(inner_spec)
            for inner_field, inner_spec in data.items()
            if inner_field not in excluded_fields
        }
        for outer_field, data in spec.__dict__.items()
    }
|
|
2001
|
+
|
|
2002
|
+
def _deserialize_spec(serialized_spec: dict) -> tensors.GraphTensor.Spec:
    """Rebuilds a `GraphTensor.Spec` from a nested dictionary.

    Expects the layout produced by `_serialize_spec`:
    `outer_field -> inner_field -> {'shape', 'dtype', 'name'}`. Each leaf
    becomes a `tf.TensorSpec`.
    """
    nested_specs = {
        outer_field: {
            inner_field: tf.TensorSpec(
                entry['shape'], entry['dtype'], entry['name']
            )
            for inner_field, entry in data.items()
        }
        for outer_field, data in serialized_spec.items()
    }
    return tensors.GraphTensor.Spec(**nested_specs)
|
|
2011
|
+
|
|
2012
|
+
def _spec_from_inputs(inputs):
    """Derives a `GraphTensor.Spec` from concrete or symbolic inputs.

    Whether the inputs are symbolic (Keras tensors) is decided by their
    first flattened leaf:
    - symbolic leaves are mapped to `tf.TensorSpec(shape, dtype)`;
    - concrete values go through `tf.type_spec_from_value`.

    If the mapped structure is already a `GraphTensor.Spec`, it is returned
    as-is. Otherwise it is treated as a nested dict of field specs.
    """
    first_leaf = tf.nest.flatten(inputs)[0]
    if keras.backend.is_keras_tensor(first_leaf):
        nested_specs = tf.nest.map_structure(
            lambda t: tf.TensorSpec(t.shape, t.dtype), inputs
        )
    else:
        nested_specs = tf.nest.map_structure(tf.type_spec_from_value, inputs)

    if isinstance(nested_specs, tensors.GraphTensor.Spec):
        return nested_specs
    return tensors.GraphTensor.Spec(**nested_specs)
|
|
2028
|
+
|
|
2029
|
+
def _propagate_kwargs(func) -> bool:
    """Returns True if `func` can be given a `training` argument.

    True when `func` declares a parameter named `training`, or when it
    accepts arbitrary keyword arguments (`**kwargs`). False otherwise.
    """
    for parameter in inspect.signature(func).parameters.values():
        if parameter.name == 'training':
            return True
        if parameter.kind == inspect.Parameter.VAR_KEYWORD:
            return True
    return False
|