molcraft 0.1.0a1__py3-none-any.whl → 0.1.0a2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of molcraft might be problematic. Click here for more details.
- molcraft/__init__.py +1 -1
- molcraft/datasets.py +123 -0
- molcraft/experimental/peptides.py +28 -67
- molcraft/featurizers.py +66 -26
- molcraft/layers.py +792 -592
- molcraft/models.py +1 -2
- molcraft/tensors.py +33 -12
- {molcraft-0.1.0a1.dist-info → molcraft-0.1.0a2.dist-info}/METADATA +68 -1
- molcraft-0.1.0a2.dist-info/RECORD +20 -0
- molcraft-0.1.0a1.dist-info/RECORD +0 -19
- {molcraft-0.1.0a1.dist-info → molcraft-0.1.0a2.dist-info}/WHEEL +0 -0
- {molcraft-0.1.0a1.dist-info → molcraft-0.1.0a2.dist-info}/licenses/LICENSE +0 -0
- {molcraft-0.1.0a1.dist-info → molcraft-0.1.0a2.dist-info}/top_level.txt +0 -0
molcraft/layers.py
CHANGED
|
@@ -206,12 +206,62 @@ class GraphLayer(keras.layers.Layer):
|
|
|
206
206
|
class GraphConv(GraphLayer):
|
|
207
207
|
|
|
208
208
|
"""Base graph neural network layer.
|
|
209
|
+
|
|
210
|
+
For normalization and skip connection to work, the `GraphConv` subclass
|
|
211
|
+
requires the (node feature) output of `aggregate` and `update` to have a
|
|
212
|
+
dimension of `self.units`, respectively.
|
|
213
|
+
|
|
214
|
+
Args:
|
|
215
|
+
units:
|
|
216
|
+
The number of units.
|
|
217
|
+
normalize:
|
|
218
|
+
Whether `LayerNormalization` should be applied to the (node feature) output
|
|
219
|
+
of the `aggregate` step. While normalization is recommended, it is not used
|
|
220
|
+
by default.
|
|
221
|
+
skip_connection:
|
|
222
|
+
Whether (node feature) input should be added to the (node feature) output.
|
|
223
|
+
If (node feature) input dim is not equal to `units`, a projection layer will
|
|
224
|
+
automatically project the residual before adding it to the output. While skip
|
|
225
|
+
connection is recommended, it is not used by default.
|
|
226
|
+
kwargs:
|
|
227
|
+
See arguments of `GraphLayer`.
|
|
209
228
|
"""
|
|
210
229
|
|
|
211
|
-
def __init__(
|
|
230
|
+
def __init__(
|
|
231
|
+
self,
|
|
232
|
+
units: int,
|
|
233
|
+
normalize: bool = False,
|
|
234
|
+
skip_connection: bool = False,
|
|
235
|
+
**kwargs
|
|
236
|
+
) -> None:
|
|
212
237
|
super().__init__(**kwargs)
|
|
213
238
|
self.units = units
|
|
214
|
-
|
|
239
|
+
self._normalize_aggregate = normalize
|
|
240
|
+
self._skip_connection = skip_connection
|
|
241
|
+
|
|
242
|
+
def build(self, spec: tensors.GraphTensor.Spec) -> None:
|
|
243
|
+
node_feature_dim = spec.node['feature'].shape[-1]
|
|
244
|
+
self._project_input_node_feature = (
|
|
245
|
+
self._skip_connection and (node_feature_dim != self.units)
|
|
246
|
+
)
|
|
247
|
+
if self._project_input_node_feature:
|
|
248
|
+
warn(
|
|
249
|
+
'`skip_connection` is set to `True`, but found incompatible dim '
|
|
250
|
+
'between input (node feature dim) and output (`self.units`). '
|
|
251
|
+
'Automatically applying a projection layer to residual to '
|
|
252
|
+
'match input and output. '
|
|
253
|
+
)
|
|
254
|
+
self._residual_projection = self.get_dense(
|
|
255
|
+
self.units, name='residual_projection'
|
|
256
|
+
)
|
|
257
|
+
if self._normalize_aggregate:
|
|
258
|
+
self._aggregation_norm = keras.layers.LayerNormalization(
|
|
259
|
+
name='aggregation_normalizer'
|
|
260
|
+
)
|
|
261
|
+
self._aggregation_norm.build([None, self.units])
|
|
262
|
+
|
|
263
|
+
super().build(spec)
|
|
264
|
+
|
|
215
265
|
@abc.abstractmethod
|
|
216
266
|
def message(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
|
|
217
267
|
"""Compute messages.
|
|
@@ -256,479 +306,455 @@ class GraphConv(GraphLayer):
|
|
|
256
306
|
tensor:
|
|
257
307
|
A `GraphTensor` instance.
|
|
258
308
|
"""
|
|
309
|
+
|
|
310
|
+
if self._skip_connection:
|
|
311
|
+
input_node_feature = tensor.node['feature']
|
|
312
|
+
if self._project_input_node_feature:
|
|
313
|
+
input_node_feature = self._residual_projection(input_node_feature)
|
|
314
|
+
|
|
259
315
|
tensor = self.message(tensor)
|
|
260
316
|
tensor = self.aggregate(tensor)
|
|
261
|
-
tensor = self.update(tensor)
|
|
262
|
-
return tensor
|
|
263
317
|
|
|
264
|
-
|
|
265
|
-
|
|
266
|
-
|
|
267
|
-
'units': self.units
|
|
268
|
-
})
|
|
269
|
-
return config
|
|
270
|
-
|
|
271
|
-
|
|
272
|
-
@keras.saving.register_keras_serializable(package='molcraft')
|
|
273
|
-
class Projection(GraphLayer):
|
|
274
|
-
"""Base graph projection layer.
|
|
275
|
-
"""
|
|
276
|
-
def __init__(
|
|
277
|
-
self,
|
|
278
|
-
units: int = None,
|
|
279
|
-
activation: str = None,
|
|
280
|
-
field: str = 'node',
|
|
281
|
-
**kwargs
|
|
282
|
-
) -> None:
|
|
283
|
-
super().__init__(**kwargs)
|
|
284
|
-
self.units = units
|
|
285
|
-
self._activation = keras.activations.get(activation)
|
|
286
|
-
self.field = field
|
|
318
|
+
if self._normalize_aggregate:
|
|
319
|
+
normalized_node_feature = self._aggregation_norm(tensor.node['feature'])
|
|
320
|
+
tensor = tensor.update({'node': {'feature': normalized_node_feature}})
|
|
287
321
|
|
|
288
|
-
|
|
289
|
-
"""Builds the layer.
|
|
290
|
-
"""
|
|
291
|
-
data = getattr(spec, self.field, None)
|
|
292
|
-
if data is None:
|
|
293
|
-
raise ValueError('Could not access field {self.field!r}.')
|
|
294
|
-
feature_dim = data['feature'].shape[-1]
|
|
295
|
-
if not self.units:
|
|
296
|
-
self.units = feature_dim
|
|
297
|
-
self._dense = self.get_dense(self.units)
|
|
298
|
-
self._dense.build([None, feature_dim])
|
|
322
|
+
tensor = self.update(tensor)
|
|
299
323
|
|
|
300
|
-
|
|
301
|
-
|
|
302
|
-
|
|
303
|
-
|
|
304
|
-
feature = self._dense(feature)
|
|
305
|
-
feature = self._activation(feature)
|
|
324
|
+
if not self._skip_connection:
|
|
325
|
+
return tensor
|
|
326
|
+
|
|
327
|
+
updated_node_feature = tensor.node['feature']
|
|
306
328
|
return tensor.update(
|
|
307
329
|
{
|
|
308
|
-
|
|
309
|
-
'feature':
|
|
330
|
+
'node': {
|
|
331
|
+
'feature': updated_node_feature + input_node_feature
|
|
310
332
|
}
|
|
311
333
|
}
|
|
312
|
-
)
|
|
334
|
+
)
|
|
313
335
|
|
|
314
336
|
def get_config(self) -> dict:
|
|
315
337
|
config = super().get_config()
|
|
316
338
|
config.update({
|
|
317
339
|
'units': self.units,
|
|
318
|
-
'
|
|
319
|
-
'
|
|
340
|
+
'normalize': self._normalize_aggregate,
|
|
341
|
+
'skip_connection': self._skip_connection,
|
|
320
342
|
})
|
|
321
343
|
return config
|
|
322
344
|
|
|
323
345
|
|
|
324
346
|
@keras.saving.register_keras_serializable(package='molcraft')
|
|
325
|
-
class
|
|
326
|
-
|
|
327
|
-
"""Graph neural network.
|
|
328
|
-
|
|
329
|
-
Sequentially calls graph layers (`GraphLayer`) and concatenates its output.
|
|
347
|
+
class GINConv(GraphConv):
|
|
330
348
|
|
|
331
|
-
|
|
332
|
-
layers (list):
|
|
333
|
-
A list of graph layers.
|
|
349
|
+
"""Graph isomorphism network layer.
|
|
334
350
|
"""
|
|
335
351
|
|
|
336
|
-
def __init__(
|
|
337
|
-
|
|
338
|
-
|
|
339
|
-
|
|
352
|
+
def __init__(
|
|
353
|
+
self,
|
|
354
|
+
units: int,
|
|
355
|
+
activation: keras.layers.Activation | str | None = 'relu',
|
|
356
|
+
use_bias: bool = True,
|
|
357
|
+
normalize: bool = True,
|
|
358
|
+
dropout: float = 0.0,
|
|
359
|
+
update_edge_feature: bool = True,
|
|
360
|
+
**kwargs,
|
|
361
|
+
):
|
|
362
|
+
super().__init__(
|
|
363
|
+
units=units,
|
|
364
|
+
normalize=normalize,
|
|
365
|
+
use_bias=use_bias,
|
|
366
|
+
**kwargs
|
|
367
|
+
)
|
|
368
|
+
self._activation = keras.activations.get(activation)
|
|
369
|
+
self._dropout = dropout
|
|
370
|
+
self._update_edge_feature = update_edge_feature
|
|
340
371
|
|
|
341
372
|
def build_from_spec(self, spec: tensors.GraphTensor.Spec) -> None:
|
|
342
373
|
"""Builds the layer.
|
|
343
374
|
"""
|
|
344
|
-
units = self.layers[0].units
|
|
345
375
|
node_feature_dim = spec.node['feature'].shape[-1]
|
|
346
|
-
|
|
347
|
-
|
|
348
|
-
|
|
349
|
-
|
|
350
|
-
|
|
351
|
-
|
|
352
|
-
|
|
353
|
-
|
|
354
|
-
if
|
|
376
|
+
|
|
377
|
+
self.epsilon = self.add_weight(
|
|
378
|
+
name='epsilon',
|
|
379
|
+
shape=(),
|
|
380
|
+
initializer='zeros',
|
|
381
|
+
trainable=True,
|
|
382
|
+
)
|
|
383
|
+
|
|
384
|
+
if 'feature' in spec.edge:
|
|
355
385
|
edge_feature_dim = spec.edge['feature'].shape[-1]
|
|
356
|
-
if edge_feature_dim != units:
|
|
357
|
-
warn(
|
|
358
|
-
'Edge feature dim does not match `units` of the first layer. '
|
|
359
|
-
'Automatically adding a edge projection layer to match `units`.'
|
|
360
|
-
)
|
|
361
|
-
self._edge_dense = self.get_dense(units)
|
|
362
|
-
self._update_edge_feature = True
|
|
363
386
|
|
|
364
|
-
|
|
365
|
-
|
|
387
|
+
if not self._update_edge_feature:
|
|
388
|
+
if (edge_feature_dim != node_feature_dim):
|
|
389
|
+
warn(
|
|
390
|
+
'Found edge feature dim to be incompatible with node feature dim. '
|
|
391
|
+
'Automatically adding a edge feature projection layer to match '
|
|
392
|
+
'the dim of node features.'
|
|
393
|
+
)
|
|
394
|
+
self._update_edge_feature = True
|
|
395
|
+
|
|
396
|
+
if self._update_edge_feature:
|
|
397
|
+
self._edge_dense = self.get_dense(node_feature_dim)
|
|
398
|
+
self._edge_dense.build([None, edge_feature_dim])
|
|
399
|
+
else:
|
|
400
|
+
self._update_edge_feature = False
|
|
401
|
+
|
|
402
|
+
self._feedforward_intermediate_dense = self.get_dense(self.units)
|
|
403
|
+
self._feedforward_intermediate_dense.build([None, node_feature_dim])
|
|
404
|
+
|
|
405
|
+
has_overridden_update = self.__class__.update != GINConv.update
|
|
406
|
+
if not has_overridden_update:
|
|
407
|
+
# Use default feedforward network
|
|
408
|
+
|
|
409
|
+
self._feedforward_dropout = keras.layers.Dropout(self._dropout)
|
|
410
|
+
self._feedforward_activation = self._activation
|
|
411
|
+
|
|
412
|
+
self._feedforward_output_dense = self.get_dense(self.units)
|
|
413
|
+
self._feedforward_output_dense.build([None, self.units])
|
|
414
|
+
|
|
415
|
+
def message(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
|
|
416
|
+
"""Computes messages.
|
|
366
417
|
"""
|
|
367
|
-
|
|
368
|
-
|
|
369
|
-
x['node']['feature'] = self._node_dense(tensor.node['feature'])
|
|
418
|
+
message = tensor.gather('feature', 'source')
|
|
419
|
+
edge_feature = tensor.edge.get('feature')
|
|
370
420
|
if self._update_edge_feature:
|
|
371
|
-
|
|
372
|
-
|
|
373
|
-
|
|
374
|
-
x = layer(x)
|
|
375
|
-
outputs.append(x['node']['feature'])
|
|
421
|
+
edge_feature = self._edge_dense(edge_feature)
|
|
422
|
+
if edge_feature is not None:
|
|
423
|
+
message += edge_feature
|
|
376
424
|
return tensor.update(
|
|
377
425
|
{
|
|
378
|
-
'
|
|
379
|
-
'
|
|
380
|
-
|
|
426
|
+
'edge': {
|
|
427
|
+
'message': message,
|
|
428
|
+
'feature': edge_feature
|
|
429
|
+
}
|
|
381
430
|
}
|
|
382
431
|
)
|
|
383
|
-
|
|
384
|
-
def tape_propagate(
|
|
385
|
-
self,
|
|
386
|
-
tensor: tensors.GraphTensor,
|
|
387
|
-
tape: tf.GradientTape,
|
|
388
|
-
training: bool | None = None,
|
|
389
|
-
) -> tuple[tensors.GraphTensor, list[tf.Tensor]]:
|
|
390
|
-
"""Performs the propagation with a `GradientTape`.
|
|
391
|
-
|
|
392
|
-
Performs the same forward pass as `propagate` but with a `GradientTape`
|
|
393
|
-
watching intermediate node features.
|
|
394
432
|
|
|
395
|
-
|
|
396
|
-
|
|
397
|
-
The graph input.
|
|
433
|
+
def aggregate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
|
|
434
|
+
"""Aggregates messages.
|
|
398
435
|
"""
|
|
399
|
-
|
|
400
|
-
|
|
401
|
-
|
|
402
|
-
|
|
403
|
-
|
|
404
|
-
x['node']['feature'] = self._node_dense(tensor.node['feature'])
|
|
405
|
-
if self._update_edge_feature:
|
|
406
|
-
x['edge']['feature'] = self._edge_dense(tensor.edge['feature'])
|
|
407
|
-
tape.watch(x['node']['feature'])
|
|
408
|
-
outputs = [x['node']['feature']]
|
|
409
|
-
for layer in self.layers:
|
|
410
|
-
x = layer(x, training=training)
|
|
411
|
-
tape.watch(x['node']['feature'])
|
|
412
|
-
outputs.append(x['node']['feature'])
|
|
413
|
-
|
|
414
|
-
tensor = tensor.update(
|
|
436
|
+
node_feature = tensor.aggregate('message')
|
|
437
|
+
node_feature += (1 + self.epsilon) * tensor.node['feature']
|
|
438
|
+
node_feature = self._feedforward_intermediate_dense(node_feature)
|
|
439
|
+
node_feature = self._feedforward_activation(node_feature)
|
|
440
|
+
return tensor.update(
|
|
415
441
|
{
|
|
416
442
|
'node': {
|
|
417
|
-
'feature':
|
|
443
|
+
'feature': node_feature,
|
|
444
|
+
},
|
|
445
|
+
'edge': {
|
|
446
|
+
'message': None,
|
|
418
447
|
}
|
|
419
448
|
}
|
|
420
449
|
)
|
|
421
|
-
return tensor, outputs
|
|
422
450
|
|
|
423
|
-
def
|
|
424
|
-
|
|
425
|
-
|
|
451
|
+
def update(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
|
|
452
|
+
"""Updates nodes.
|
|
453
|
+
"""
|
|
454
|
+
node_feature = tensor.node['feature']
|
|
455
|
+
node_feature = self._feedforward_dropout(node_feature)
|
|
456
|
+
node_feature = self._feedforward_output_dense(node_feature)
|
|
457
|
+
return tensor.update(
|
|
426
458
|
{
|
|
427
|
-
'
|
|
428
|
-
|
|
429
|
-
|
|
459
|
+
'node': {
|
|
460
|
+
'feature': node_feature,
|
|
461
|
+
}
|
|
430
462
|
}
|
|
431
463
|
)
|
|
464
|
+
|
|
465
|
+
def get_config(self) -> dict:
|
|
466
|
+
config = super().get_config()
|
|
467
|
+
config.update({
|
|
468
|
+
'activation': keras.activations.serialize(self._activation),
|
|
469
|
+
'dropout': self._dropout,
|
|
470
|
+
'update_edge_feature': self._update_edge_feature
|
|
471
|
+
})
|
|
432
472
|
return config
|
|
433
473
|
|
|
434
|
-
@classmethod
|
|
435
|
-
def from_config(cls, config: dict) -> 'GraphNetwork':
|
|
436
|
-
config['layers'] = [
|
|
437
|
-
keras.layers.deserialize(layer) for layer in config['layers']
|
|
438
|
-
]
|
|
439
|
-
return super().from_config(config)
|
|
440
|
-
|
|
441
474
|
|
|
442
475
|
@keras.saving.register_keras_serializable(package='molcraft')
|
|
443
|
-
class
|
|
444
|
-
|
|
445
|
-
"""Node embedding layer.
|
|
476
|
+
class GTConv(GraphConv):
|
|
446
477
|
|
|
447
|
-
|
|
478
|
+
"""Graph transformer layer.
|
|
448
479
|
"""
|
|
449
480
|
|
|
450
481
|
def __init__(
|
|
451
|
-
self,
|
|
452
|
-
|
|
453
|
-
|
|
454
|
-
|
|
455
|
-
|
|
482
|
+
self,
|
|
483
|
+
units: int,
|
|
484
|
+
heads: int = 8,
|
|
485
|
+
activation: keras.layers.Activation | str | None = "relu",
|
|
486
|
+
use_bias: bool = True,
|
|
487
|
+
normalize: bool = True,
|
|
488
|
+
dropout: float = 0.0,
|
|
489
|
+
attention_dropout: float = 0.0,
|
|
490
|
+
**kwargs,
|
|
456
491
|
) -> None:
|
|
457
|
-
|
|
458
|
-
|
|
459
|
-
|
|
460
|
-
|
|
461
|
-
|
|
492
|
+
kwargs['skip_connection'] = False
|
|
493
|
+
super().__init__(
|
|
494
|
+
units=units,
|
|
495
|
+
normalize=normalize,
|
|
496
|
+
use_bias=use_bias,
|
|
497
|
+
**kwargs
|
|
498
|
+
)
|
|
499
|
+
self._heads = heads
|
|
500
|
+
if self.units % self.heads != 0:
|
|
501
|
+
raise ValueError(f"units need to be divisible by heads.")
|
|
502
|
+
self._head_units = self.units // self.heads
|
|
503
|
+
self._activation = keras.activations.get(activation)
|
|
504
|
+
self._dropout = dropout
|
|
505
|
+
self._attention_dropout = attention_dropout
|
|
506
|
+
self._normalize = normalize
|
|
462
507
|
|
|
463
|
-
|
|
508
|
+
@property
|
|
509
|
+
def heads(self):
|
|
510
|
+
return self._heads
|
|
511
|
+
|
|
512
|
+
@property
|
|
513
|
+
def head_units(self):
|
|
514
|
+
return self._head_units
|
|
515
|
+
|
|
516
|
+
def build_from_spec(self, spec):
|
|
464
517
|
"""Builds the layer.
|
|
465
518
|
"""
|
|
466
|
-
|
|
467
|
-
|
|
468
|
-
|
|
469
|
-
|
|
470
|
-
|
|
519
|
+
node_feature_dim = spec.node['feature'].shape[-1]
|
|
520
|
+
self.project_residual = node_feature_dim != self.units
|
|
521
|
+
if self.project_residual:
|
|
522
|
+
warn(
|
|
523
|
+
'`GTConv` uses residual connections, but found incompatible dim '
|
|
524
|
+
'between input (node feature dim) and output (`self.units`). '
|
|
525
|
+
'Automatically applying a projection layer to residual to '
|
|
526
|
+
'match input and output. '
|
|
527
|
+
)
|
|
528
|
+
self._residual_dense = self.get_dense(self.units)
|
|
529
|
+
self._residual_dense.build([None, node_feature_dim])
|
|
530
|
+
|
|
531
|
+
self._query_dense = self.get_einsum_dense(
|
|
532
|
+
'ij,jkh->ikh', (self.head_units, self.heads)
|
|
533
|
+
)
|
|
534
|
+
self._query_dense.build([None, node_feature_dim])
|
|
471
535
|
|
|
472
|
-
self.
|
|
473
|
-
|
|
474
|
-
|
|
475
|
-
|
|
476
|
-
if self._has_super and not self._embed_context:
|
|
477
|
-
self._super_feature = self.get_weight(shape=[self.dim], name='super_node_feature')
|
|
478
|
-
if self._allow_masking:
|
|
479
|
-
self._mask_feature = self.get_weight(shape=[self.dim], name='mask_node_feature')
|
|
536
|
+
self._key_dense = self.get_einsum_dense(
|
|
537
|
+
'ij,jkh->ikh', (self.head_units, self.heads)
|
|
538
|
+
)
|
|
539
|
+
self._key_dense.build([None, node_feature_dim])
|
|
480
540
|
|
|
481
|
-
|
|
482
|
-
|
|
483
|
-
|
|
484
|
-
|
|
541
|
+
self._value_dense = self.get_einsum_dense(
|
|
542
|
+
'ij,jkh->ikh', (self.head_units, self.heads)
|
|
543
|
+
)
|
|
544
|
+
self._value_dense.build([None, node_feature_dim])
|
|
485
545
|
|
|
486
|
-
|
|
487
|
-
|
|
488
|
-
"""
|
|
489
|
-
feature = self._node_dense(tensor.node['feature'])
|
|
546
|
+
self._output_dense = self.get_dense(self.units)
|
|
547
|
+
self._output_dense.build([None, self.units])
|
|
490
548
|
|
|
491
|
-
|
|
492
|
-
super_feature = (0 if self._embed_context else self._super_feature)
|
|
493
|
-
super_mask = keras.ops.expand_dims(tensor.node['super'], 1)
|
|
494
|
-
feature = keras.ops.where(super_mask, super_feature, feature)
|
|
549
|
+
self._softmax_dropout = keras.layers.Dropout(self._attention_dropout)
|
|
495
550
|
|
|
496
|
-
|
|
497
|
-
context_feature = self._context_dense(tensor.context['feature'])
|
|
498
|
-
feature = ops.scatter_update(feature, tensor.node['super'], context_feature)
|
|
499
|
-
tensor = tensor.update({'context': {'feature': None}})
|
|
551
|
+
self._self_attention_dropout = keras.layers.Dropout(self._dropout)
|
|
500
552
|
|
|
501
|
-
|
|
502
|
-
|
|
503
|
-
self.
|
|
504
|
-
self.
|
|
505
|
-
):
|
|
506
|
-
random = keras.random.uniform(shape=[tensor.num_nodes])
|
|
507
|
-
mask = random <= self._masking_rate
|
|
508
|
-
if self._has_super:
|
|
509
|
-
mask = keras.ops.logical_and(
|
|
510
|
-
mask, keras.ops.logical_not(tensor.node['super'])
|
|
511
|
-
)
|
|
512
|
-
mask = keras.ops.expand_dims(mask, -1)
|
|
513
|
-
feature = keras.ops.where(mask, self._mask_feature, feature)
|
|
514
|
-
elif self._allow_masking:
|
|
515
|
-
# Slience warning of 'no gradients for variables'
|
|
516
|
-
feature = feature + (self._mask_feature * 0.0)
|
|
553
|
+
self._add_edge_bias = not 'bias' in spec.edge
|
|
554
|
+
if self._add_edge_bias:
|
|
555
|
+
self._add_edge_bias = AddEdgeBias()
|
|
556
|
+
self._add_edge_bias.build_from_spec(spec)
|
|
517
557
|
|
|
518
|
-
|
|
558
|
+
has_overridden_update = self.__class__.update != GTConv.update
|
|
559
|
+
if not has_overridden_update:
|
|
560
|
+
|
|
561
|
+
if self._normalize:
|
|
562
|
+
self._feedforward_output_norm = keras.layers.LayerNormalization()
|
|
563
|
+
self._feedforward_output_norm.build([None, self.units])
|
|
519
564
|
|
|
520
|
-
|
|
521
|
-
def masking_rate(self):
|
|
522
|
-
return self._masking_rate
|
|
523
|
-
|
|
524
|
-
@masking_rate.setter
|
|
525
|
-
def masking_rate(self, rate: float):
|
|
526
|
-
if not self._allow_masking and rate is not None:
|
|
527
|
-
raise ValueError(
|
|
528
|
-
f'Cannot set `masking_rate` for layer {self} '
|
|
529
|
-
'as `allow_masking` was set to `False`.'
|
|
530
|
-
)
|
|
531
|
-
self._masking_rate = float(rate)
|
|
565
|
+
self._feedforward_dropout = keras.layers.Dropout(self._dropout)
|
|
532
566
|
|
|
533
|
-
|
|
534
|
-
|
|
535
|
-
config.update({
|
|
536
|
-
'dim': self.dim,
|
|
537
|
-
'allow_masking': self._allow_masking
|
|
538
|
-
})
|
|
539
|
-
return config
|
|
540
|
-
|
|
567
|
+
self._feedforward_intermediate_dense = self.get_dense(self.units)
|
|
568
|
+
self._feedforward_intermediate_dense.build([None, self.units])
|
|
541
569
|
|
|
542
|
-
|
|
543
|
-
|
|
570
|
+
self._feedforward_output_dense = self.get_dense(self.units)
|
|
571
|
+
self._feedforward_output_dense.build([None, self.units])
|
|
544
572
|
|
|
545
|
-
"""Edge embedding layer.
|
|
546
573
|
|
|
547
|
-
|
|
548
|
-
|
|
574
|
+
def message(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
|
|
575
|
+
"""Computes messages.
|
|
576
|
+
"""
|
|
549
577
|
|
|
550
|
-
|
|
551
|
-
|
|
552
|
-
|
|
553
|
-
|
|
554
|
-
|
|
555
|
-
) -> None:
|
|
556
|
-
super().__init__(**kwargs)
|
|
557
|
-
self.dim = dim
|
|
558
|
-
self._masking_rate = None
|
|
559
|
-
self._allow_masking = allow_masking
|
|
578
|
+
node_feature = tensor.node['feature']
|
|
579
|
+
|
|
580
|
+
query = self._query_dense(node_feature)
|
|
581
|
+
key = self._key_dense(node_feature)
|
|
582
|
+
value = self._value_dense(node_feature)
|
|
560
583
|
|
|
561
|
-
|
|
562
|
-
|
|
563
|
-
|
|
564
|
-
feature_dim = spec.edge['feature'].shape[-1]
|
|
565
|
-
if not self.dim:
|
|
566
|
-
self.dim = feature_dim
|
|
567
|
-
self._edge_dense = self.get_dense(self.dim)
|
|
568
|
-
self._edge_dense.build([None, feature_dim])
|
|
584
|
+
query = ops.gather(query, tensor.edge['source'])
|
|
585
|
+
key = ops.gather(key, tensor.edge['target'])
|
|
586
|
+
value = ops.gather(value, tensor.edge['source'])
|
|
569
587
|
|
|
570
|
-
|
|
571
|
-
|
|
572
|
-
self._super_feature = self.get_weight(shape=[self.dim], name='super_edge_feature')
|
|
573
|
-
if self._allow_masking:
|
|
574
|
-
self._mask_feature = self.get_weight(shape=[self.dim], name='mask_edge_feature')
|
|
588
|
+
attention_score = keras.ops.sum(query * key, axis=1, keepdims=True)
|
|
589
|
+
attention_score /= keras.ops.sqrt(float(self.head_units))
|
|
575
590
|
|
|
576
|
-
|
|
577
|
-
|
|
591
|
+
if self._add_edge_bias:
|
|
592
|
+
tensor = self._add_edge_bias(tensor)
|
|
593
|
+
|
|
594
|
+
attention_score += keras.ops.expand_dims(tensor.edge['bias'], -1)
|
|
595
|
+
|
|
596
|
+
attention = ops.edge_softmax(attention_score, tensor.edge['target'])
|
|
597
|
+
attention = self._softmax_dropout(attention)
|
|
598
|
+
|
|
599
|
+
return tensor.update(
|
|
600
|
+
{
|
|
601
|
+
'edge': {
|
|
602
|
+
'message': value,
|
|
603
|
+
'weight': attention,
|
|
604
|
+
},
|
|
605
|
+
}
|
|
606
|
+
)
|
|
607
|
+
|
|
608
|
+
def aggregate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
|
|
609
|
+
"""Aggregates messages.
|
|
578
610
|
"""
|
|
579
|
-
|
|
611
|
+
node_feature = tensor.aggregate('message')
|
|
612
|
+
node_feature = keras.ops.reshape(node_feature, (-1, self.units))
|
|
613
|
+
node_feature = self._output_dense(node_feature)
|
|
614
|
+
node_feature = self._self_attention_dropout(node_feature)
|
|
615
|
+
return tensor.update(
|
|
616
|
+
{
|
|
617
|
+
'node': {
|
|
618
|
+
'feature': node_feature,
|
|
619
|
+
'residual': tensor.node['feature']
|
|
620
|
+
},
|
|
621
|
+
'edge': {
|
|
622
|
+
'message': None,
|
|
623
|
+
'weight': None,
|
|
624
|
+
}
|
|
625
|
+
}
|
|
626
|
+
)
|
|
580
627
|
|
|
581
|
-
|
|
582
|
-
|
|
583
|
-
|
|
584
|
-
|
|
628
|
+
def update(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
|
|
629
|
+
"""Updates nodes.
|
|
630
|
+
"""
|
|
631
|
+
node_feature = tensor.node['feature']
|
|
585
632
|
|
|
586
|
-
|
|
587
|
-
|
|
588
|
-
self.
|
|
589
|
-
self._masking_rate > 0
|
|
590
|
-
):
|
|
591
|
-
random = keras.random.uniform(shape=[tensor.num_edges])
|
|
592
|
-
mask = random <= self._masking_rate
|
|
593
|
-
if self._has_super:
|
|
594
|
-
mask = keras.ops.logical_and(
|
|
595
|
-
mask, keras.ops.logical_not(tensor.edge['super'])
|
|
596
|
-
)
|
|
597
|
-
mask = keras.ops.expand_dims(mask, -1)
|
|
598
|
-
feature = keras.ops.where(mask, self._mask_feature, feature)
|
|
599
|
-
elif self._allow_masking:
|
|
600
|
-
# Slience warning of 'no gradients for variables'
|
|
601
|
-
feature = feature + (self._mask_feature * 0.0)
|
|
633
|
+
residual = tensor.node['residual']
|
|
634
|
+
if self.project_residual:
|
|
635
|
+
residual = self._residual_dense(residual)
|
|
602
636
|
|
|
603
|
-
|
|
637
|
+
node_feature += residual
|
|
638
|
+
residual = node_feature
|
|
604
639
|
|
|
605
|
-
|
|
606
|
-
|
|
607
|
-
|
|
608
|
-
|
|
609
|
-
|
|
610
|
-
|
|
611
|
-
if not self._allow_masking and rate is not None:
|
|
612
|
-
raise ValueError(
|
|
613
|
-
f'Cannot set `masking_rate` for layer {self} '
|
|
614
|
-
'as `allow_masking` was set to `False`.'
|
|
615
|
-
)
|
|
616
|
-
self._masking_rate = float(rate)
|
|
640
|
+
node_feature = self._feedforward_intermediate_dense(node_feature)
|
|
641
|
+
node_feature = self._activation(node_feature)
|
|
642
|
+
node_feature = self._feedforward_output_dense(node_feature)
|
|
643
|
+
node_feature = self._feedforward_dropout(node_feature)
|
|
644
|
+
if self._normalize:
|
|
645
|
+
node_feature = self._feedforward_output_norm(node_feature)
|
|
617
646
|
|
|
647
|
+
node_feature += residual
|
|
648
|
+
|
|
649
|
+
return tensor.update(
|
|
650
|
+
{
|
|
651
|
+
'node': {
|
|
652
|
+
'feature': node_feature,
|
|
653
|
+
},
|
|
654
|
+
}
|
|
655
|
+
)
|
|
656
|
+
|
|
618
657
|
def get_config(self) -> dict:
|
|
619
658
|
config = super().get_config()
|
|
620
659
|
config.update({
|
|
621
|
-
|
|
622
|
-
'
|
|
660
|
+
"heads": self._heads,
|
|
661
|
+
'activation': keras.activations.serialize(self._activation),
|
|
662
|
+
'dropout': self._dropout,
|
|
663
|
+
'attention_dropout': self._attention_dropout,
|
|
623
664
|
})
|
|
624
665
|
return config
|
|
625
666
|
|
|
626
667
|
|
|
627
668
|
@keras.saving.register_keras_serializable(package='molcraft')
|
|
628
|
-
class
|
|
629
|
-
"""Context projection layer.
|
|
630
|
-
"""
|
|
631
|
-
def __init__(self, units: int = None, activation: str = None, **kwargs):
|
|
632
|
-
super().__init__(units=units, activation=activation, field='context', **kwargs)
|
|
633
|
-
|
|
634
|
-
|
|
635
|
-
@keras.saving.register_keras_serializable(package='molcraft')
|
|
636
|
-
class NodeProjection(Projection):
|
|
637
|
-
"""Node projection layer.
|
|
638
|
-
"""
|
|
639
|
-
def __init__(self, units: int = None, activation: str = None, **kwargs):
|
|
640
|
-
super().__init__(units=units, activation=activation, field='node', **kwargs)
|
|
641
|
-
|
|
642
|
-
|
|
643
|
-
@keras.saving.register_keras_serializable(package='molcraft')
|
|
644
|
-
class EdgeProjection(Projection):
|
|
645
|
-
"""Edge projection layer.
|
|
646
|
-
"""
|
|
647
|
-
def __init__(self, units: int = None, activation: str = None, **kwargs):
|
|
648
|
-
super().__init__(units=units, activation=activation, field='edge', **kwargs)
|
|
649
|
-
|
|
650
|
-
|
|
651
|
-
@keras.saving.register_keras_serializable(package='molcraft')
|
|
652
|
-
class GINConv(GraphConv):
|
|
669
|
+
class EGConv3D(GraphConv):
|
|
653
670
|
|
|
654
|
-
"""
|
|
671
|
+
"""Equivariant graph neural network layer.
|
|
655
672
|
"""
|
|
656
673
|
|
|
657
674
|
def __init__(
|
|
658
|
-
self,
|
|
659
|
-
units: int,
|
|
660
|
-
activation: keras.layers.Activation | str | None =
|
|
661
|
-
|
|
675
|
+
self,
|
|
676
|
+
units: int = 128,
|
|
677
|
+
activation: keras.layers.Activation | str | None = None,
|
|
678
|
+
use_bias: bool = True,
|
|
662
679
|
normalize: bool = True,
|
|
663
|
-
|
|
664
|
-
**kwargs
|
|
665
|
-
):
|
|
666
|
-
super().__init__(
|
|
680
|
+
dropout: float = 0.0,
|
|
681
|
+
**kwargs
|
|
682
|
+
) -> None:
|
|
683
|
+
super().__init__(
|
|
684
|
+
units=units,
|
|
685
|
+
normalize=normalize,
|
|
686
|
+
use_bias=use_bias,
|
|
687
|
+
**kwargs
|
|
688
|
+
)
|
|
667
689
|
self._activation = keras.activations.get(activation)
|
|
668
|
-
self.
|
|
669
|
-
self._dropout = dropout
|
|
670
|
-
self._update_edge_feature = update_edge_feature
|
|
690
|
+
self._dropout = dropout or 0.0
|
|
671
691
|
|
|
672
692
|
def build_from_spec(self, spec: tensors.GraphTensor.Spec) -> None:
|
|
673
|
-
|
|
674
|
-
|
|
693
|
+
if 'coordinate' not in spec.node:
|
|
694
|
+
raise ValueError(
|
|
695
|
+
'Could not find `coordinate`s in node, '
|
|
696
|
+
'which is required for Conv3D layers.'
|
|
697
|
+
)
|
|
675
698
|
node_feature_dim = spec.node['feature'].shape[-1]
|
|
676
|
-
|
|
677
|
-
self.epsilon = self.add_weight(
|
|
678
|
-
name='epsilon',
|
|
679
|
-
shape=(),
|
|
680
|
-
initializer='zeros',
|
|
681
|
-
trainable=True,
|
|
682
|
-
)
|
|
683
|
-
|
|
699
|
+
feature_dim = node_feature_dim + node_feature_dim + 1
|
|
684
700
|
if 'feature' in spec.edge:
|
|
701
|
+
self._has_edge_feature = True
|
|
685
702
|
edge_feature_dim = spec.edge['feature'].shape[-1]
|
|
703
|
+
feature_dim += edge_feature_dim
|
|
704
|
+
else:
|
|
705
|
+
self._has_edge_feature = False
|
|
686
706
|
|
|
687
|
-
|
|
688
|
-
|
|
689
|
-
|
|
690
|
-
|
|
691
|
-
'Automatically adding a edge feature projection layer to match '
|
|
692
|
-
'the dim of node features.'
|
|
693
|
-
)
|
|
694
|
-
self._update_edge_feature = True
|
|
707
|
+
self.message_fn = self.get_dense(self.units, activation=self._activation)
|
|
708
|
+
self.message_fn.build([None, feature_dim])
|
|
709
|
+
self.dense_position = self.get_dense(1)
|
|
710
|
+
self.dense_position.build([None, self.units])
|
|
695
711
|
|
|
696
|
-
|
|
697
|
-
self._edge_dense = self.get_dense(node_feature_dim)
|
|
698
|
-
self._edge_dense.build([None, edge_feature_dim])
|
|
699
|
-
else:
|
|
700
|
-
self._update_edge_feature = False
|
|
701
|
-
|
|
702
|
-
has_overridden_update = self.__class__.update != GINConv.update
|
|
712
|
+
has_overridden_update = self.__class__.update != EGConv3D.update
|
|
703
713
|
if not has_overridden_update:
|
|
704
|
-
|
|
705
|
-
self.
|
|
706
|
-
self.
|
|
714
|
+
self.update_fn = self.get_dense(self.units, activation=self._activation)
|
|
715
|
+
self.update_fn.build([None, node_feature_dim + self.units])
|
|
716
|
+
self._dropout_layer = keras.layers.Dropout(self._dropout)
|
|
707
717
|
|
|
708
|
-
if self._normalize:
|
|
709
|
-
self._feedforward_intermediate_norm = keras.layers.BatchNormalization()
|
|
710
|
-
self._feedforward_intermediate_norm.build([None, self.units])
|
|
711
|
-
|
|
712
|
-
self._feedforward_dropout = keras.layers.Dropout(self._dropout)
|
|
713
|
-
self._feedforward_activation = self._activation
|
|
714
|
-
|
|
715
|
-
self._feedforward_output_dense = self.get_dense(self.units)
|
|
716
|
-
self._feedforward_output_dense.build([None, self.units])
|
|
717
|
-
|
|
718
718
|
def message(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
|
|
719
|
-
"""
|
|
719
|
+
"""Computes messages.
|
|
720
720
|
"""
|
|
721
|
-
|
|
722
|
-
|
|
723
|
-
|
|
724
|
-
|
|
725
|
-
|
|
726
|
-
|
|
721
|
+
relative_node_coordinate = keras.ops.subtract(
|
|
722
|
+
tensor.gather('coordinate', 'target'),
|
|
723
|
+
tensor.gather('coordinate', 'source')
|
|
724
|
+
)
|
|
725
|
+
euclidean_distance = keras.ops.sum(
|
|
726
|
+
keras.ops.square(
|
|
727
|
+
relative_node_coordinate
|
|
728
|
+
),
|
|
729
|
+
axis=-1,
|
|
730
|
+
keepdims=True
|
|
731
|
+
)
|
|
732
|
+
feature = keras.ops.concatenate(
|
|
733
|
+
[
|
|
734
|
+
tensor.gather('feature', 'target'),
|
|
735
|
+
tensor.gather('feature', 'source'),
|
|
736
|
+
euclidean_distance,
|
|
737
|
+
],
|
|
738
|
+
axis=-1
|
|
739
|
+
)
|
|
740
|
+
if self._has_edge_feature:
|
|
741
|
+
feature = keras.ops.concatenate(
|
|
742
|
+
[
|
|
743
|
+
feature,
|
|
744
|
+
tensor.edge['feature']
|
|
745
|
+
],
|
|
746
|
+
axis=-1
|
|
747
|
+
)
|
|
748
|
+
message = self.message_fn(feature)
|
|
749
|
+
relative_node_coordinate = keras.ops.multiply(
|
|
750
|
+
relative_node_coordinate,
|
|
751
|
+
self.dense_position(message)
|
|
752
|
+
)
|
|
727
753
|
return tensor.update(
|
|
728
754
|
{
|
|
729
755
|
'edge': {
|
|
730
756
|
'message': message,
|
|
731
|
-
'
|
|
757
|
+
'relative_node_coordinate': relative_node_coordinate
|
|
732
758
|
}
|
|
733
759
|
}
|
|
734
760
|
)
|
|
@@ -736,34 +762,54 @@ class GINConv(GraphConv):
|
|
|
736
762
|
def aggregate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
|
|
737
763
|
"""Aggregates messages.
|
|
738
764
|
"""
|
|
739
|
-
|
|
740
|
-
|
|
765
|
+
coefficient = keras.ops.bincount(
|
|
766
|
+
tensor.edge['source'],
|
|
767
|
+
minlength=tensor.num_nodes
|
|
768
|
+
)
|
|
769
|
+
coefficient = keras.ops.cast(
|
|
770
|
+
coefficient, tensor.node['coordinate'].dtype
|
|
771
|
+
)
|
|
772
|
+
coefficient = keras.ops.expand_dims(
|
|
773
|
+
keras.ops.divide_no_nan(1, coefficient), axis=1
|
|
774
|
+
)
|
|
775
|
+
|
|
776
|
+
updated_coordinate = tensor.aggregate('relative_node_coordinate') * coefficient
|
|
777
|
+
updated_coordinate += tensor.node['coordinate']
|
|
778
|
+
|
|
779
|
+
aggregate = tensor.aggregate('message')
|
|
741
780
|
return tensor.update(
|
|
742
781
|
{
|
|
743
782
|
'node': {
|
|
744
|
-
'feature':
|
|
783
|
+
'feature': aggregate,
|
|
784
|
+
'coordinate': updated_coordinate,
|
|
785
|
+
'previous_feature': tensor.node['feature'],
|
|
745
786
|
},
|
|
746
787
|
'edge': {
|
|
747
788
|
'message': None,
|
|
789
|
+
'relative_node_coordinate': None
|
|
748
790
|
}
|
|
749
791
|
}
|
|
750
|
-
)
|
|
751
|
-
|
|
792
|
+
)
|
|
793
|
+
|
|
752
794
|
def update(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
|
|
753
|
-
"""Updates nodes.
|
|
795
|
+
"""Updates nodes.
|
|
754
796
|
"""
|
|
755
|
-
|
|
756
|
-
|
|
757
|
-
|
|
758
|
-
|
|
759
|
-
|
|
760
|
-
|
|
761
|
-
|
|
797
|
+
updated_node_feature = self.update_fn(
|
|
798
|
+
keras.ops.concatenate(
|
|
799
|
+
[
|
|
800
|
+
tensor.node['feature'],
|
|
801
|
+
tensor.node['previous_feature']
|
|
802
|
+
],
|
|
803
|
+
axis=-1
|
|
804
|
+
)
|
|
805
|
+
)
|
|
806
|
+
updated_node_feature = self._dropout_layer(updated_node_feature)
|
|
762
807
|
return tensor.update(
|
|
763
808
|
{
|
|
764
809
|
'node': {
|
|
765
|
-
'feature':
|
|
766
|
-
|
|
810
|
+
'feature': updated_node_feature,
|
|
811
|
+
'previous_feature': None,
|
|
812
|
+
},
|
|
767
813
|
}
|
|
768
814
|
)
|
|
769
815
|
|
|
@@ -771,267 +817,390 @@ class GINConv(GraphConv):
|
|
|
771
817
|
config = super().get_config()
|
|
772
818
|
config.update({
|
|
773
819
|
'activation': keras.activations.serialize(self._activation),
|
|
774
|
-
'dropout': self._dropout,
|
|
775
|
-
'normalize': self._normalize,
|
|
820
|
+
'dropout': self._dropout,
|
|
776
821
|
})
|
|
777
|
-
return config
|
|
778
|
-
|
|
822
|
+
return config
|
|
823
|
+
|
|
779
824
|
|
|
780
825
|
@keras.saving.register_keras_serializable(package='molcraft')
|
|
781
|
-
class
|
|
782
|
-
|
|
783
|
-
"""Graph transformer layer.
|
|
826
|
+
class Projection(GraphLayer):
|
|
827
|
+
"""Base graph projection layer.
|
|
784
828
|
"""
|
|
785
|
-
|
|
786
829
|
def __init__(
|
|
787
|
-
self,
|
|
788
|
-
units: int,
|
|
789
|
-
|
|
790
|
-
|
|
791
|
-
|
|
792
|
-
attention_dropout: float = 0.0,
|
|
793
|
-
normalize: bool = True,
|
|
794
|
-
normalize_first: bool = True,
|
|
795
|
-
**kwargs,
|
|
830
|
+
self,
|
|
831
|
+
units: int = None,
|
|
832
|
+
activation: str = None,
|
|
833
|
+
field: str = 'node',
|
|
834
|
+
**kwargs
|
|
796
835
|
) -> None:
|
|
797
|
-
super().__init__(
|
|
798
|
-
self.
|
|
799
|
-
if self.units % self.heads != 0:
|
|
800
|
-
raise ValueError(f"units need to be divisible by heads.")
|
|
801
|
-
self._head_units = self.units // self.heads
|
|
836
|
+
super().__init__(**kwargs)
|
|
837
|
+
self.units = units
|
|
802
838
|
self._activation = keras.activations.get(activation)
|
|
803
|
-
self.
|
|
804
|
-
self._attention_dropout = attention_dropout
|
|
805
|
-
self._normalize = normalize
|
|
806
|
-
self._normalize_first = normalize_first
|
|
839
|
+
self.field = field
|
|
807
840
|
|
|
808
|
-
|
|
809
|
-
|
|
810
|
-
|
|
811
|
-
|
|
812
|
-
|
|
813
|
-
|
|
814
|
-
|
|
841
|
+
def build_from_spec(self, spec: tensors.GraphTensor.Spec) -> None:
|
|
842
|
+
"""Builds the layer.
|
|
843
|
+
"""
|
|
844
|
+
data = getattr(spec, self.field, None)
|
|
845
|
+
if data is None:
|
|
846
|
+
raise ValueError('Could not access field {self.field!r}.')
|
|
847
|
+
feature_dim = data['feature'].shape[-1]
|
|
848
|
+
if not self.units:
|
|
849
|
+
self.units = feature_dim
|
|
850
|
+
self._dense = self.get_dense(self.units)
|
|
851
|
+
self._dense.build([None, feature_dim])
|
|
852
|
+
|
|
853
|
+
def propagate(self, tensor: tensors.GraphTensor):
|
|
854
|
+
"""Calls the layer.
|
|
855
|
+
"""
|
|
856
|
+
feature = getattr(tensor, self.field)['feature']
|
|
857
|
+
feature = self._dense(feature)
|
|
858
|
+
feature = self._activation(feature)
|
|
859
|
+
return tensor.update(
|
|
860
|
+
{
|
|
861
|
+
self.field: {
|
|
862
|
+
'feature': feature
|
|
863
|
+
}
|
|
864
|
+
}
|
|
865
|
+
)
|
|
866
|
+
|
|
867
|
+
def get_config(self) -> dict:
|
|
868
|
+
config = super().get_config()
|
|
869
|
+
config.update({
|
|
870
|
+
'units': self.units,
|
|
871
|
+
'activation': keras.activations.serialize(self._activation),
|
|
872
|
+
'field': self.field,
|
|
873
|
+
})
|
|
874
|
+
return config
|
|
815
875
|
|
|
816
|
-
|
|
876
|
+
|
|
877
|
+
@keras.saving.register_keras_serializable(package='molcraft')
|
|
878
|
+
class GraphNetwork(GraphLayer):
|
|
879
|
+
|
|
880
|
+
"""Graph neural network.
|
|
881
|
+
|
|
882
|
+
Sequentially calls graph layers (`GraphLayer`) and concatenates its output.
|
|
883
|
+
|
|
884
|
+
Args:
|
|
885
|
+
layers (list):
|
|
886
|
+
A list of graph layers.
|
|
887
|
+
"""
|
|
888
|
+
|
|
889
|
+
def __init__(self, layers: list[GraphLayer], **kwargs) -> None:
|
|
890
|
+
super().__init__(**kwargs)
|
|
891
|
+
self.layers = layers
|
|
892
|
+
self._update_edge_feature = False
|
|
893
|
+
|
|
894
|
+
def build_from_spec(self, spec: tensors.GraphTensor.Spec) -> None:
|
|
817
895
|
"""Builds the layer.
|
|
818
896
|
"""
|
|
897
|
+
units = self.layers[0].units
|
|
819
898
|
node_feature_dim = spec.node['feature'].shape[-1]
|
|
820
|
-
|
|
821
|
-
|
|
822
|
-
|
|
823
|
-
|
|
824
|
-
'`GTConv` uses residual connections, but input node feature dim '
|
|
825
|
-
'is incompatible with intermediate dim (`units`). '
|
|
826
|
-
'Automatically projecting first residual to match its dim with intermediate dim.'
|
|
827
|
-
),
|
|
828
|
-
category=UserWarning,
|
|
829
|
-
stacklevel=1
|
|
899
|
+
if node_feature_dim != units:
|
|
900
|
+
warn(
|
|
901
|
+
'Node feature dim does not match `units` of the first layer. '
|
|
902
|
+
'Automatically adding a node projection layer to match `units`.'
|
|
830
903
|
)
|
|
831
|
-
self.
|
|
832
|
-
self.
|
|
833
|
-
|
|
834
|
-
|
|
835
|
-
|
|
836
|
-
|
|
837
|
-
|
|
838
|
-
|
|
839
|
-
|
|
840
|
-
|
|
904
|
+
self._node_dense = self.get_dense(units)
|
|
905
|
+
self._update_node_feature = True
|
|
906
|
+
has_edge_feature = 'feature' in spec.edge
|
|
907
|
+
if has_edge_feature:
|
|
908
|
+
edge_feature_dim = spec.edge['feature'].shape[-1]
|
|
909
|
+
if edge_feature_dim != units:
|
|
910
|
+
warn(
|
|
911
|
+
'Edge feature dim does not match `units` of the first layer. '
|
|
912
|
+
'Automatically adding a edge projection layer to match `units`.'
|
|
913
|
+
)
|
|
914
|
+
self._edge_dense = self.get_dense(units)
|
|
915
|
+
self._update_edge_feature = True
|
|
841
916
|
|
|
842
|
-
|
|
843
|
-
|
|
917
|
+
def propagate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
|
|
918
|
+
"""Calls the layer.
|
|
919
|
+
"""
|
|
920
|
+
x = tensors.to_dict(tensor)
|
|
921
|
+
if self._update_node_feature:
|
|
922
|
+
x['node']['feature'] = self._node_dense(tensor.node['feature'])
|
|
923
|
+
if self._update_edge_feature:
|
|
924
|
+
x['edge']['feature'] = self._edge_dense(tensor.edge['feature'])
|
|
925
|
+
outputs = [x['node']['feature']]
|
|
926
|
+
for layer in self.layers:
|
|
927
|
+
x = layer(x)
|
|
928
|
+
outputs.append(x['node']['feature'])
|
|
929
|
+
return tensor.update(
|
|
930
|
+
{
|
|
931
|
+
'node': {
|
|
932
|
+
'feature': keras.ops.concatenate(outputs, axis=-1)
|
|
933
|
+
}
|
|
934
|
+
}
|
|
844
935
|
)
|
|
845
|
-
|
|
936
|
+
|
|
937
|
+
def tape_propagate(
|
|
938
|
+
self,
|
|
939
|
+
tensor: tensors.GraphTensor,
|
|
940
|
+
tape: tf.GradientTape,
|
|
941
|
+
training: bool | None = None,
|
|
942
|
+
) -> tuple[tensors.GraphTensor, list[tf.Tensor]]:
|
|
943
|
+
"""Performs the propagation with a `GradientTape`.
|
|
846
944
|
|
|
847
|
-
|
|
848
|
-
|
|
945
|
+
Performs the same forward pass as `propagate` but with a `GradientTape`
|
|
946
|
+
watching intermediate node features.
|
|
947
|
+
|
|
948
|
+
Args:
|
|
949
|
+
tensor (tensors.GraphTensor):
|
|
950
|
+
The graph input.
|
|
951
|
+
"""
|
|
952
|
+
if isinstance(tensor, tensors.GraphTensor):
|
|
953
|
+
x = tensors.to_dict(tensor)
|
|
954
|
+
else:
|
|
955
|
+
x = tensor
|
|
956
|
+
if self._update_node_feature:
|
|
957
|
+
x['node']['feature'] = self._node_dense(tensor.node['feature'])
|
|
958
|
+
if self._update_edge_feature:
|
|
959
|
+
x['edge']['feature'] = self._edge_dense(tensor.edge['feature'])
|
|
960
|
+
tape.watch(x['node']['feature'])
|
|
961
|
+
outputs = [x['node']['feature']]
|
|
962
|
+
for layer in self.layers:
|
|
963
|
+
x = layer(x, training=training)
|
|
964
|
+
tape.watch(x['node']['feature'])
|
|
965
|
+
outputs.append(x['node']['feature'])
|
|
966
|
+
|
|
967
|
+
tensor = tensor.update(
|
|
968
|
+
{
|
|
969
|
+
'node': {
|
|
970
|
+
'feature': keras.ops.concatenate(outputs, axis=-1)
|
|
971
|
+
}
|
|
972
|
+
}
|
|
849
973
|
)
|
|
850
|
-
|
|
974
|
+
return tensor, outputs
|
|
975
|
+
|
|
976
|
+
def get_config(self) -> dict:
|
|
977
|
+
config = super().get_config()
|
|
978
|
+
config.update(
|
|
979
|
+
{
|
|
980
|
+
'layers': [
|
|
981
|
+
keras.layers.serialize(layer) for layer in self.layers
|
|
982
|
+
]
|
|
983
|
+
}
|
|
984
|
+
)
|
|
985
|
+
return config
|
|
986
|
+
|
|
987
|
+
@classmethod
|
|
988
|
+
def from_config(cls, config: dict) -> 'GraphNetwork':
|
|
989
|
+
config['layers'] = [
|
|
990
|
+
keras.layers.deserialize(layer) for layer in config['layers']
|
|
991
|
+
]
|
|
992
|
+
return super().from_config(config)
|
|
993
|
+
|
|
994
|
+
|
|
995
|
+
@keras.saving.register_keras_serializable(package='molcraft')
|
|
996
|
+
class NodeEmbedding(GraphLayer):
|
|
997
|
+
|
|
998
|
+
"""Node embedding layer.
|
|
999
|
+
|
|
1000
|
+
Embeds nodes based on its initial features.
|
|
1001
|
+
"""
|
|
1002
|
+
|
|
1003
|
+
def __init__(
|
|
1004
|
+
self,
|
|
1005
|
+
dim: int = None,
|
|
1006
|
+
embed_context: bool = True,
|
|
1007
|
+
allow_masking: bool = True,
|
|
1008
|
+
**kwargs
|
|
1009
|
+
) -> None:
|
|
1010
|
+
super().__init__(**kwargs)
|
|
1011
|
+
self.dim = dim
|
|
1012
|
+
self._embed_context = embed_context
|
|
1013
|
+
self._masking_rate = None
|
|
1014
|
+
self._allow_masking = allow_masking
|
|
1015
|
+
|
|
1016
|
+
def build_from_spec(self, spec: tensors.GraphTensor.Spec) -> None:
|
|
1017
|
+
"""Builds the layer.
|
|
1018
|
+
"""
|
|
1019
|
+
feature_dim = spec.node['feature'].shape[-1]
|
|
1020
|
+
if not self.dim:
|
|
1021
|
+
self.dim = feature_dim
|
|
1022
|
+
self._node_dense = self.get_dense(self.dim)
|
|
1023
|
+
self._node_dense.build([None, feature_dim])
|
|
851
1024
|
|
|
852
|
-
self.
|
|
853
|
-
|
|
1025
|
+
self._has_super = 'super' in spec.node
|
|
1026
|
+
has_context_feature = 'feature' in spec.context
|
|
1027
|
+
if not has_context_feature:
|
|
1028
|
+
self._embed_context = False
|
|
1029
|
+
if self._has_super and not self._embed_context:
|
|
1030
|
+
self._super_feature = self.get_weight(shape=[self.dim], name='super_node_feature')
|
|
1031
|
+
if self._allow_masking:
|
|
1032
|
+
self._mask_feature = self.get_weight(shape=[self.dim], name='mask_node_feature')
|
|
854
1033
|
|
|
855
|
-
|
|
1034
|
+
if self._embed_context:
|
|
1035
|
+
context_feature_dim = spec.context['feature'].shape[-1]
|
|
1036
|
+
self._context_dense = self.get_dense(self.dim)
|
|
1037
|
+
self._context_dense.build([None, context_feature_dim])
|
|
856
1038
|
|
|
857
|
-
|
|
858
|
-
|
|
859
|
-
|
|
860
|
-
|
|
861
|
-
self._self_attention_norm.build([None, self.units])
|
|
1039
|
+
def propagate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
|
|
1040
|
+
"""Calls the layer.
|
|
1041
|
+
"""
|
|
1042
|
+
feature = self._node_dense(tensor.node['feature'])
|
|
862
1043
|
|
|
863
|
-
|
|
1044
|
+
if self._has_super:
|
|
1045
|
+
super_feature = (0 if self._embed_context else self._super_feature)
|
|
1046
|
+
super_mask = keras.ops.expand_dims(tensor.node['super'], 1)
|
|
1047
|
+
feature = keras.ops.where(super_mask, super_feature, feature)
|
|
864
1048
|
|
|
865
|
-
|
|
866
|
-
self.
|
|
867
|
-
|
|
868
|
-
|
|
869
|
-
self._has_edge_length = 'length' in spec.edge
|
|
870
|
-
if self._has_edge_length and 'bias' not in spec.edge:
|
|
871
|
-
edge_length_dim = spec.edge['length'].shape[-1]
|
|
872
|
-
self._spatial_encoding_dense = self.get_einsum_dense(
|
|
873
|
-
'ij,jkh->ikh', (1, self.heads), kernel_initializer='zeros'
|
|
874
|
-
)
|
|
875
|
-
self._spatial_encoding_dense.build([None, edge_length_dim])
|
|
1049
|
+
if self._embed_context:
|
|
1050
|
+
context_feature = self._context_dense(tensor.context['feature'])
|
|
1051
|
+
feature = ops.scatter_update(feature, tensor.node['super'], context_feature)
|
|
1052
|
+
tensor = tensor.update({'context': {'feature': None}})
|
|
876
1053
|
|
|
877
|
-
|
|
878
|
-
|
|
879
|
-
|
|
880
|
-
|
|
881
|
-
|
|
1054
|
+
if (
|
|
1055
|
+
self._allow_masking and
|
|
1056
|
+
self._masking_rate is not None and
|
|
1057
|
+
self._masking_rate > 0
|
|
1058
|
+
):
|
|
1059
|
+
random = keras.random.uniform(shape=[tensor.num_nodes])
|
|
1060
|
+
mask = random <= self._masking_rate
|
|
1061
|
+
if self._has_super:
|
|
1062
|
+
mask = keras.ops.logical_and(
|
|
1063
|
+
mask, keras.ops.logical_not(tensor.node['super'])
|
|
882
1064
|
)
|
|
883
|
-
|
|
884
|
-
|
|
885
|
-
|
|
886
|
-
|
|
887
|
-
|
|
888
|
-
self._feedforward_norm = keras.layers.LayerNormalization()
|
|
889
|
-
self._feedforward_norm.build([None, self.units])
|
|
890
|
-
|
|
891
|
-
self._feedforward_dropout = keras.layers.Dropout(self._dropout)
|
|
892
|
-
|
|
893
|
-
self._feedforward_intermediate_dense = self.get_dense(self.units)
|
|
894
|
-
self._feedforward_intermediate_dense.build([None, self.units])
|
|
1065
|
+
mask = keras.ops.expand_dims(mask, -1)
|
|
1066
|
+
feature = keras.ops.where(mask, self._mask_feature, feature)
|
|
1067
|
+
elif self._allow_masking:
|
|
1068
|
+
# Slience warning of 'no gradients for variables'
|
|
1069
|
+
feature = feature + (self._mask_feature * 0.0)
|
|
895
1070
|
|
|
896
|
-
|
|
897
|
-
self._feedforward_output_dense.build([None, self.units])
|
|
1071
|
+
return tensor.update({'node': {'feature': feature}})
|
|
898
1072
|
|
|
899
|
-
|
|
900
|
-
|
|
1073
|
+
@property
|
|
1074
|
+
def masking_rate(self):
|
|
1075
|
+
return self._masking_rate
|
|
901
1076
|
|
|
902
|
-
|
|
903
|
-
|
|
904
|
-
|
|
905
|
-
|
|
906
|
-
|
|
907
|
-
|
|
908
|
-
if self._has_edge_feature and not self._has_edge_length:
|
|
909
|
-
edge_bias = self._edge_feature_dense(tensor.edge['feature'])
|
|
910
|
-
elif not self._has_edge_feature and self._has_edge_length:
|
|
911
|
-
edge_bias = self._spatial_encoding_dense(tensor.edge['length'])
|
|
912
|
-
else:
|
|
913
|
-
edge_bias = (
|
|
914
|
-
self._edge_feature_dense(tensor.edge['feature']) +
|
|
915
|
-
self._spatial_encoding_dense(tensor.edge['length'])
|
|
1077
|
+
@masking_rate.setter
|
|
1078
|
+
def masking_rate(self, rate: float):
|
|
1079
|
+
if not self._allow_masking and rate is not None:
|
|
1080
|
+
raise ValueError(
|
|
1081
|
+
f'Cannot set `masking_rate` for layer {self} '
|
|
1082
|
+
'as `allow_masking` was set to `False`.'
|
|
916
1083
|
)
|
|
917
|
-
|
|
918
|
-
return tensor.update(
|
|
919
|
-
{
|
|
920
|
-
'edge': {
|
|
921
|
-
'bias': edge_bias
|
|
922
|
-
}
|
|
923
|
-
}
|
|
924
|
-
)
|
|
925
|
-
|
|
926
|
-
def message(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
|
|
927
|
-
"""Compute messages.
|
|
928
|
-
"""
|
|
929
|
-
tensor = self.add_edge_bias(tensor)
|
|
930
|
-
tensor = self.add_node_bias(tensor)
|
|
931
|
-
|
|
932
|
-
node_feature = tensor.node['feature']
|
|
1084
|
+
self._masking_rate = float(rate)
|
|
933
1085
|
|
|
934
|
-
|
|
935
|
-
|
|
936
|
-
|
|
937
|
-
|
|
938
|
-
|
|
939
|
-
|
|
940
|
-
|
|
941
|
-
|
|
942
|
-
value = self._value_dense(node_feature)
|
|
1086
|
+
def get_config(self) -> dict:
|
|
1087
|
+
config = super().get_config()
|
|
1088
|
+
config.update({
|
|
1089
|
+
'dim': self.dim,
|
|
1090
|
+
'allow_masking': self._allow_masking
|
|
1091
|
+
})
|
|
1092
|
+
return config
|
|
1093
|
+
|
|
943
1094
|
|
|
944
|
-
|
|
945
|
-
|
|
946
|
-
value = ops.gather(value, tensor.edge['source'])
|
|
1095
|
+
@keras.saving.register_keras_serializable(package='molcraft')
|
|
1096
|
+
class EdgeEmbedding(GraphLayer):
|
|
947
1097
|
|
|
948
|
-
|
|
949
|
-
attention_score /= keras.ops.sqrt(float(self.units))
|
|
1098
|
+
"""Edge embedding layer.
|
|
950
1099
|
|
|
951
|
-
|
|
952
|
-
|
|
953
|
-
|
|
954
|
-
attention = ops.edge_softmax(attention_score, tensor.edge['target'])
|
|
955
|
-
attention = self._softmax_dropout(attention)
|
|
1100
|
+
Embeds edges based on its initial features.
|
|
1101
|
+
"""
|
|
956
1102
|
|
|
957
|
-
|
|
958
|
-
|
|
959
|
-
|
|
960
|
-
|
|
961
|
-
|
|
962
|
-
|
|
963
|
-
|
|
964
|
-
|
|
1103
|
+
def __init__(
|
|
1104
|
+
self,
|
|
1105
|
+
dim: int = None,
|
|
1106
|
+
allow_masking: bool = True,
|
|
1107
|
+
**kwargs
|
|
1108
|
+
) -> None:
|
|
1109
|
+
super().__init__(**kwargs)
|
|
1110
|
+
self.dim = dim
|
|
1111
|
+
self._masking_rate = None
|
|
1112
|
+
self._allow_masking = allow_masking
|
|
965
1113
|
|
|
966
|
-
def
|
|
967
|
-
"""
|
|
1114
|
+
def build_from_spec(self, spec: tensors.GraphTensor.Spec) -> None:
|
|
1115
|
+
"""Builds the layer.
|
|
968
1116
|
"""
|
|
969
|
-
|
|
970
|
-
|
|
971
|
-
|
|
972
|
-
|
|
973
|
-
|
|
974
|
-
|
|
975
|
-
residual = tensor.node['feature']
|
|
976
|
-
if self._project_residual:
|
|
977
|
-
residual = self._residual_dense(residual)
|
|
978
|
-
node_feature += residual
|
|
979
|
-
|
|
980
|
-
if not self._normalize_first:
|
|
981
|
-
node_feature = self._self_attention_norm(node_feature)
|
|
1117
|
+
feature_dim = spec.edge['feature'].shape[-1]
|
|
1118
|
+
if not self.dim:
|
|
1119
|
+
self.dim = feature_dim
|
|
1120
|
+
self._edge_dense = self.get_dense(self.dim)
|
|
1121
|
+
self._edge_dense.build([None, feature_dim])
|
|
982
1122
|
|
|
983
|
-
|
|
984
|
-
|
|
985
|
-
|
|
986
|
-
|
|
987
|
-
|
|
988
|
-
'edge': {
|
|
989
|
-
'message': None,
|
|
990
|
-
'weight': None,
|
|
991
|
-
}
|
|
992
|
-
}
|
|
993
|
-
)
|
|
994
|
-
|
|
1123
|
+
self._has_super = 'super' in spec.edge
|
|
1124
|
+
if self._has_super:
|
|
1125
|
+
self._super_feature = self.get_weight(shape=[self.dim], name='super_edge_feature')
|
|
1126
|
+
if self._allow_masking:
|
|
1127
|
+
self._mask_feature = self.get_weight(shape=[self.dim], name='mask_edge_feature')
|
|
995
1128
|
|
|
996
|
-
def
|
|
997
|
-
"""
|
|
1129
|
+
def propagate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
|
|
1130
|
+
"""Calls the layer.
|
|
998
1131
|
"""
|
|
999
|
-
|
|
1000
|
-
|
|
1001
|
-
if self._normalize_first:
|
|
1002
|
-
node_feature = self._feedforward_norm(node_feature)
|
|
1132
|
+
feature = self._edge_dense(tensor.edge['feature'])
|
|
1003
1133
|
|
|
1004
|
-
|
|
1005
|
-
|
|
1006
|
-
|
|
1134
|
+
if self._has_super:
|
|
1135
|
+
super_feature = self._super_feature
|
|
1136
|
+
super_mask = keras.ops.expand_dims(tensor.edge['super'], 1)
|
|
1137
|
+
feature = keras.ops.where(super_mask, super_feature, feature)
|
|
1007
1138
|
|
|
1008
|
-
|
|
1009
|
-
|
|
1139
|
+
if (
|
|
1140
|
+
self._allow_masking and
|
|
1141
|
+
self._masking_rate is not None and
|
|
1142
|
+
self._masking_rate > 0
|
|
1143
|
+
):
|
|
1144
|
+
random = keras.random.uniform(shape=[tensor.num_edges])
|
|
1145
|
+
mask = random <= self._masking_rate
|
|
1146
|
+
if self._has_super:
|
|
1147
|
+
mask = keras.ops.logical_and(
|
|
1148
|
+
mask, keras.ops.logical_not(tensor.edge['super'])
|
|
1149
|
+
)
|
|
1150
|
+
mask = keras.ops.expand_dims(mask, -1)
|
|
1151
|
+
feature = keras.ops.where(mask, self._mask_feature, feature)
|
|
1152
|
+
elif self._allow_masking:
|
|
1153
|
+
# Slience warning of 'no gradients for variables'
|
|
1154
|
+
feature = feature + (self._mask_feature * 0.0)
|
|
1010
1155
|
|
|
1011
|
-
|
|
1012
|
-
node_feature = self._feedforward_norm(node_feature)
|
|
1156
|
+
return tensor.update({'edge': {'feature': feature}})
|
|
1013
1157
|
|
|
1014
|
-
|
|
1015
|
-
|
|
1016
|
-
|
|
1017
|
-
'feature': node_feature,
|
|
1018
|
-
},
|
|
1019
|
-
}
|
|
1020
|
-
)
|
|
1158
|
+
@property
|
|
1159
|
+
def masking_rate(self):
|
|
1160
|
+
return self._masking_rate
|
|
1021
1161
|
|
|
1162
|
+
@masking_rate.setter
|
|
1163
|
+
def masking_rate(self, rate: float):
|
|
1164
|
+
if not self._allow_masking and rate is not None:
|
|
1165
|
+
raise ValueError(
|
|
1166
|
+
f'Cannot set `masking_rate` for layer {self} '
|
|
1167
|
+
'as `allow_masking` was set to `False`.'
|
|
1168
|
+
)
|
|
1169
|
+
self._masking_rate = float(rate)
|
|
1170
|
+
|
|
1022
1171
|
def get_config(self) -> dict:
|
|
1023
1172
|
config = super().get_config()
|
|
1024
1173
|
config.update({
|
|
1025
|
-
|
|
1026
|
-
'
|
|
1027
|
-
'dropout': self._dropout,
|
|
1028
|
-
'attention_dropout': self._attention_dropout,
|
|
1029
|
-
'normalize': self._normalize,
|
|
1030
|
-
'normalize_first': self._normalize_first,
|
|
1174
|
+
'dim': self.dim,
|
|
1175
|
+
'allow_masking': self._allow_masking
|
|
1031
1176
|
})
|
|
1032
1177
|
return config
|
|
1033
1178
|
|
|
1034
1179
|
|
|
1180
|
+
@keras.saving.register_keras_serializable(package='molcraft')
|
|
1181
|
+
class ContextProjection(Projection):
|
|
1182
|
+
"""Context projection layer.
|
|
1183
|
+
"""
|
|
1184
|
+
def __init__(self, units: int = None, activation: str = None, **kwargs):
|
|
1185
|
+
super().__init__(units=units, activation=activation, field='context', **kwargs)
|
|
1186
|
+
|
|
1187
|
+
|
|
1188
|
+
@keras.saving.register_keras_serializable(package='molcraft')
|
|
1189
|
+
class NodeProjection(Projection):
|
|
1190
|
+
"""Node projection layer.
|
|
1191
|
+
"""
|
|
1192
|
+
def __init__(self, units: int = None, activation: str = None, **kwargs):
|
|
1193
|
+
super().__init__(units=units, activation=activation, field='node', **kwargs)
|
|
1194
|
+
|
|
1195
|
+
|
|
1196
|
+
@keras.saving.register_keras_serializable(package='molcraft')
|
|
1197
|
+
class EdgeProjection(Projection):
|
|
1198
|
+
"""Edge projection layer.
|
|
1199
|
+
"""
|
|
1200
|
+
def __init__(self, units: int = None, activation: str = None, **kwargs):
|
|
1201
|
+
super().__init__(units=units, activation=activation, field='edge', **kwargs)
|
|
1202
|
+
|
|
1203
|
+
|
|
1035
1204
|
@keras.saving.register_keras_serializable(package='molcraft')
|
|
1036
1205
|
class Readout(keras.layers.Layer):
|
|
1037
1206
|
|
|
@@ -1097,6 +1266,37 @@ class Readout(keras.layers.Layer):
|
|
|
1097
1266
|
return config
|
|
1098
1267
|
|
|
1099
1268
|
|
|
1269
|
+
@keras.saving.register_keras_serializable(package='molcraft')
|
|
1270
|
+
class AddEdgeBias(GraphLayer):
|
|
1271
|
+
|
|
1272
|
+
def build_from_spec(self, spec: tensors.GraphTensor.Spec) -> None:
|
|
1273
|
+
self._has_edge_length = 'length' in spec.edge
|
|
1274
|
+
self._has_edge_feature = 'feature' in spec.edge
|
|
1275
|
+
if self._has_edge_feature:
|
|
1276
|
+
self._edge_feature_dense = self.get_dense(units=1)
|
|
1277
|
+
if self._has_edge_length:
|
|
1278
|
+
self._edge_length_dense = self.get_dense(
|
|
1279
|
+
units=1, kernel_initializer='zeros'
|
|
1280
|
+
)
|
|
1281
|
+
|
|
1282
|
+
def propagate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
|
|
1283
|
+
bias = keras.ops.zeros(
|
|
1284
|
+
shape=(tensor.num_edges, 1),
|
|
1285
|
+
dtype=tensor.node['feature'].dtype
|
|
1286
|
+
)
|
|
1287
|
+
if self._has_edge_feature:
|
|
1288
|
+
bias += self._edge_feature_dense(tensor.edge['feature'])
|
|
1289
|
+
if self._has_edge_length:
|
|
1290
|
+
bias += self._edge_length_dense(tensor.edge['length'])
|
|
1291
|
+
return tensor.update(
|
|
1292
|
+
{
|
|
1293
|
+
'edge': {
|
|
1294
|
+
'bias': bias
|
|
1295
|
+
}
|
|
1296
|
+
}
|
|
1297
|
+
)
|
|
1298
|
+
|
|
1299
|
+
|
|
1100
1300
|
def Input(spec: tensors.GraphTensor.Spec) -> dict:
|
|
1101
1301
|
"""Used to specify inputs to model.
|
|
1102
1302
|
|