molcraft 0.1.0a2__py3-none-any.whl → 0.1.0a3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of molcraft might be problematic. Click here for more details.
- molcraft/__init__.py +2 -1
- molcraft/features.py +5 -3
- molcraft/featurizers.py +2 -1
- molcraft/layers.py +560 -108
- molcraft/models.py +34 -3
- {molcraft-0.1.0a2.dist-info → molcraft-0.1.0a3.dist-info}/METADATA +1 -1
- {molcraft-0.1.0a2.dist-info → molcraft-0.1.0a3.dist-info}/RECORD +10 -10
- {molcraft-0.1.0a2.dist-info → molcraft-0.1.0a3.dist-info}/WHEEL +0 -0
- {molcraft-0.1.0a2.dist-info → molcraft-0.1.0a3.dist-info}/licenses/LICENSE +0 -0
- {molcraft-0.1.0a2.dist-info → molcraft-0.1.0a3.dist-info}/top_level.txt +0 -0
molcraft/__init__.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
__version__ = '0.1.
|
|
1
|
+
__version__ = '0.1.0a3'
|
|
2
2
|
|
|
3
3
|
import os
|
|
4
4
|
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
|
|
@@ -14,3 +14,4 @@ from molcraft import ops
|
|
|
14
14
|
from molcraft import records
|
|
15
15
|
from molcraft import tensors
|
|
16
16
|
from molcraft import callbacks
|
|
17
|
+
from molcraft import datasets
|
molcraft/features.py
CHANGED
|
@@ -155,9 +155,11 @@ class Distance(EdgeFeature):
|
|
|
155
155
|
encode_oov: bool = True,
|
|
156
156
|
**kwargs,
|
|
157
157
|
) -> None:
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
158
|
+
vocab = kwargs.pop('vocab', None)
|
|
159
|
+
if not vocab:
|
|
160
|
+
if max_distance is None:
|
|
161
|
+
max_distance = 20
|
|
162
|
+
vocab = list(range(max_distance + 1))
|
|
161
163
|
super().__init__(
|
|
162
164
|
vocab=vocab,
|
|
163
165
|
allow_oov=allow_oov,
|
molcraft/featurizers.py
CHANGED
|
@@ -402,6 +402,7 @@ class MolGraphFeaturizer(Featurizer):
|
|
|
402
402
|
return cls(**config)
|
|
403
403
|
|
|
404
404
|
|
|
405
|
+
@keras.saving.register_keras_serializable(package='molcraft')
|
|
405
406
|
class MolGraphFeaturizer3D(MolGraphFeaturizer):
|
|
406
407
|
|
|
407
408
|
"""Molecular 3d-graph featurizer.
|
|
@@ -627,7 +628,7 @@ class MolGraphFeaturizer3D(MolGraphFeaturizer):
|
|
|
627
628
|
config['conformer_generator'] = keras.saving.deserialize_keras_object(
|
|
628
629
|
config['conformer_generator']
|
|
629
630
|
)
|
|
630
|
-
return super().from_config(
|
|
631
|
+
return super().from_config(config)
|
|
631
632
|
|
|
632
633
|
|
|
633
634
|
def save_featurizer(
|
molcraft/layers.py
CHANGED
|
@@ -60,25 +60,20 @@ class GraphLayer(keras.layers.Layer):
|
|
|
60
60
|
May use built-in methods such as `get_weight`, `get_dense` and `get_einsum_dense`.
|
|
61
61
|
|
|
62
62
|
Optionally implemented by subclass. If implemented, it is recommended to
|
|
63
|
-
|
|
64
|
-
|
|
63
|
+
If sub-layers are built (via `build` or `build_from_spec`), set `built`
|
|
64
|
+
to True. If not, symbolic input will be passed through the layer to build them.
|
|
65
65
|
|
|
66
66
|
Args:
|
|
67
67
|
spec:
|
|
68
|
-
A `GraphTensor.Spec` instance, corresponding to the
|
|
69
|
-
|
|
68
|
+
A `GraphTensor.Spec` instance, corresponding to the `GraphTensor`
|
|
69
|
+
passed to `propagate`.
|
|
70
70
|
"""
|
|
71
|
-
|
|
71
|
+
|
|
72
72
|
def build(self, spec: tensors.GraphTensor.Spec) -> None:
|
|
73
73
|
|
|
74
74
|
self._custom_build_config = {'spec': _serialize_spec(spec)}
|
|
75
75
|
|
|
76
|
-
|
|
77
|
-
GraphLayer.build_from_spec != self.__class__.build_from_spec
|
|
78
|
-
)
|
|
79
|
-
if invoke_build_from_spec:
|
|
80
|
-
self.build_from_spec(spec)
|
|
81
|
-
self.built = True
|
|
76
|
+
self.build_from_spec(spec)
|
|
82
77
|
|
|
83
78
|
if not self.built:
|
|
84
79
|
# Automatically build layer or model by calling it on symbolic inputs
|
|
@@ -229,7 +224,7 @@ class GraphConv(GraphLayer):
|
|
|
229
224
|
|
|
230
225
|
def __init__(
|
|
231
226
|
self,
|
|
232
|
-
units: int,
|
|
227
|
+
units: int = None,
|
|
233
228
|
normalize: bool = False,
|
|
234
229
|
skip_connection: bool = False,
|
|
235
230
|
**kwargs
|
|
@@ -240,6 +235,10 @@ class GraphConv(GraphLayer):
|
|
|
240
235
|
self._skip_connection = skip_connection
|
|
241
236
|
|
|
242
237
|
def build(self, spec: tensors.GraphTensor.Spec) -> None:
|
|
238
|
+
if not self.units:
|
|
239
|
+
raise ValueError(
|
|
240
|
+
f'`self.units` needs to be a positive integer. ound: {self.units}.'
|
|
241
|
+
)
|
|
243
242
|
node_feature_dim = spec.node['feature'].shape[-1]
|
|
244
243
|
self._project_input_node_feature = (
|
|
245
244
|
self._skip_connection and (node_feature_dim != self.units)
|
|
@@ -256,7 +255,7 @@ class GraphConv(GraphLayer):
|
|
|
256
255
|
)
|
|
257
256
|
if self._normalize_aggregate:
|
|
258
257
|
self._aggregation_norm = keras.layers.LayerNormalization(
|
|
259
|
-
name='
|
|
258
|
+
name='aggregation_normalization'
|
|
260
259
|
)
|
|
261
260
|
self._aggregation_norm.build([None, self.units])
|
|
262
261
|
|
|
@@ -344,7 +343,7 @@ class GraphConv(GraphLayer):
|
|
|
344
343
|
|
|
345
344
|
|
|
346
345
|
@keras.saving.register_keras_serializable(package='molcraft')
|
|
347
|
-
class
|
|
346
|
+
class GIConv(GraphConv):
|
|
348
347
|
|
|
349
348
|
"""Graph isomorphism network layer.
|
|
350
349
|
"""
|
|
@@ -381,7 +380,8 @@ class GINConv(GraphConv):
|
|
|
381
380
|
trainable=True,
|
|
382
381
|
)
|
|
383
382
|
|
|
384
|
-
|
|
383
|
+
self._has_edge_feature = 'feature' in spec.edge
|
|
384
|
+
if self._has_edge_feature:
|
|
385
385
|
edge_feature_dim = spec.edge['feature'].shape[-1]
|
|
386
386
|
|
|
387
387
|
if not self._update_edge_feature:
|
|
@@ -402,16 +402,15 @@ class GINConv(GraphConv):
|
|
|
402
402
|
self._feedforward_intermediate_dense = self.get_dense(self.units)
|
|
403
403
|
self._feedforward_intermediate_dense.build([None, node_feature_dim])
|
|
404
404
|
|
|
405
|
-
has_overridden_update = self.__class__.update !=
|
|
405
|
+
has_overridden_update = self.__class__.update != GIConv.update
|
|
406
406
|
if not has_overridden_update:
|
|
407
|
-
# Use default feedforward network
|
|
408
|
-
|
|
409
|
-
self._feedforward_dropout = keras.layers.Dropout(self._dropout)
|
|
410
407
|
self._feedforward_activation = self._activation
|
|
411
|
-
|
|
408
|
+
self._feedforward_dropout = keras.layers.Dropout(self._dropout)
|
|
412
409
|
self._feedforward_output_dense = self.get_dense(self.units)
|
|
413
410
|
self._feedforward_output_dense.build([None, self.units])
|
|
414
|
-
|
|
411
|
+
|
|
412
|
+
self.built = True
|
|
413
|
+
|
|
415
414
|
def message(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
|
|
416
415
|
"""Computes messages.
|
|
417
416
|
"""
|
|
@@ -419,7 +418,7 @@ class GINConv(GraphConv):
|
|
|
419
418
|
edge_feature = tensor.edge.get('feature')
|
|
420
419
|
if self._update_edge_feature:
|
|
421
420
|
edge_feature = self._edge_dense(edge_feature)
|
|
422
|
-
if
|
|
421
|
+
if self._has_edge_feature:
|
|
423
422
|
message += edge_feature
|
|
424
423
|
return tensor.update(
|
|
425
424
|
{
|
|
@@ -436,7 +435,6 @@ class GINConv(GraphConv):
|
|
|
436
435
|
node_feature = tensor.aggregate('message')
|
|
437
436
|
node_feature += (1 + self.epsilon) * tensor.node['feature']
|
|
438
437
|
node_feature = self._feedforward_intermediate_dense(node_feature)
|
|
439
|
-
node_feature = self._feedforward_activation(node_feature)
|
|
440
438
|
return tensor.update(
|
|
441
439
|
{
|
|
442
440
|
'node': {
|
|
@@ -452,6 +450,7 @@ class GINConv(GraphConv):
|
|
|
452
450
|
"""Updates nodes.
|
|
453
451
|
"""
|
|
454
452
|
node_feature = tensor.node['feature']
|
|
453
|
+
node_feature = self._feedforward_activation(node_feature)
|
|
455
454
|
node_feature = self._feedforward_dropout(node_feature)
|
|
456
455
|
node_feature = self._feedforward_output_dense(node_feature)
|
|
457
456
|
return tensor.update(
|
|
@@ -472,6 +471,171 @@ class GINConv(GraphConv):
|
|
|
472
471
|
return config
|
|
473
472
|
|
|
474
473
|
|
|
474
|
+
@keras.saving.register_keras_serializable(package='molgraphx')
|
|
475
|
+
class GAConv(GraphConv):
|
|
476
|
+
|
|
477
|
+
"""Graph attention network layer.
|
|
478
|
+
"""
|
|
479
|
+
|
|
480
|
+
def __init__(
|
|
481
|
+
self,
|
|
482
|
+
units: int,
|
|
483
|
+
heads: int = 8,
|
|
484
|
+
activation: keras.layers.Activation | str | None = "relu",
|
|
485
|
+
use_bias: bool = True,
|
|
486
|
+
normalize: bool = True,
|
|
487
|
+
dropout: float = 0.0,
|
|
488
|
+
update_edge_feature: bool = True,
|
|
489
|
+
attention_activation: keras.layers.Activation | str | None = "leaky_relu",
|
|
490
|
+
**kwargs,
|
|
491
|
+
) -> None:
|
|
492
|
+
kwargs['skip_connection'] = False
|
|
493
|
+
super().__init__(
|
|
494
|
+
units=units,
|
|
495
|
+
normalize=normalize,
|
|
496
|
+
use_bias=use_bias,
|
|
497
|
+
**kwargs
|
|
498
|
+
)
|
|
499
|
+
self._heads = heads
|
|
500
|
+
if self.units % self.heads != 0:
|
|
501
|
+
raise ValueError(f"units need to be divisible by heads.")
|
|
502
|
+
self._head_units = self.units // self.heads
|
|
503
|
+
self._activation = keras.activations.get(activation)
|
|
504
|
+
self._dropout = dropout
|
|
505
|
+
self._normalize = normalize
|
|
506
|
+
self._update_edge_feature = update_edge_feature
|
|
507
|
+
self._attention_activation = keras.activations.get(attention_activation)
|
|
508
|
+
|
|
509
|
+
@property
|
|
510
|
+
def heads(self):
|
|
511
|
+
return self._heads
|
|
512
|
+
|
|
513
|
+
@property
|
|
514
|
+
def head_units(self):
|
|
515
|
+
return self._head_units
|
|
516
|
+
|
|
517
|
+
def build_from_spec(self, spec: tensors.GraphTensor.Spec) -> None:
|
|
518
|
+
|
|
519
|
+
node_feature_dim = spec.node['feature'].shape[-1]
|
|
520
|
+
attn_feature_dim = node_feature_dim + node_feature_dim
|
|
521
|
+
|
|
522
|
+
self._has_edge_feature = 'feature' in spec.edge
|
|
523
|
+
if self._has_edge_feature:
|
|
524
|
+
edge_feature_dim = spec.edge['feature'].shape[-1]
|
|
525
|
+
attn_feature_dim += edge_feature_dim
|
|
526
|
+
if self._update_edge_feature:
|
|
527
|
+
self._edge_dense = self.get_einsum_dense(
|
|
528
|
+
'ijh,jkh->ikh', (self.head_units, self.heads)
|
|
529
|
+
)
|
|
530
|
+
self._edge_dense.build([None, self.head_units, self.heads])
|
|
531
|
+
else:
|
|
532
|
+
self._update_edge_feature = False
|
|
533
|
+
|
|
534
|
+
self._node_dense = self.get_einsum_dense(
|
|
535
|
+
'ij,jkh->ikh', (self.head_units, self.heads)
|
|
536
|
+
)
|
|
537
|
+
self._node_dense.build([None, node_feature_dim])
|
|
538
|
+
|
|
539
|
+
self._feature_dense = self.get_einsum_dense(
|
|
540
|
+
'ij,jkh->ikh', (self.head_units, self.heads)
|
|
541
|
+
)
|
|
542
|
+
self._feature_dense.build([None, attn_feature_dim])
|
|
543
|
+
|
|
544
|
+
self._attention_dense = self.get_einsum_dense(
|
|
545
|
+
'ijh,jkh->ikh', (1, self.heads)
|
|
546
|
+
)
|
|
547
|
+
self._attention_dense.build([None, self.head_units, self.heads])
|
|
548
|
+
|
|
549
|
+
self._node_self_dense = self.get_einsum_dense(
|
|
550
|
+
'ij,jkh->ikh', (self.head_units, self.heads)
|
|
551
|
+
)
|
|
552
|
+
self._node_self_dense.build([None, node_feature_dim])
|
|
553
|
+
self._dropout_layer = keras.layers.Dropout(self._dropout)
|
|
554
|
+
|
|
555
|
+
self.built = True
|
|
556
|
+
|
|
557
|
+
def message(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
|
|
558
|
+
|
|
559
|
+
attention_feature = keras.ops.concatenate(
|
|
560
|
+
[
|
|
561
|
+
tensor.gather('feature', 'source'),
|
|
562
|
+
tensor.gather('feature', 'target')
|
|
563
|
+
],
|
|
564
|
+
axis=-1
|
|
565
|
+
)
|
|
566
|
+
if self._has_edge_feature:
|
|
567
|
+
attention_feature = keras.ops.concatenate(
|
|
568
|
+
[
|
|
569
|
+
attention_feature,
|
|
570
|
+
tensor.edge['feature']
|
|
571
|
+
],
|
|
572
|
+
axis=-1
|
|
573
|
+
)
|
|
574
|
+
|
|
575
|
+
attention_feature = self._feature_dense(attention_feature)
|
|
576
|
+
|
|
577
|
+
edge_feature = tensor.edge.get('feature')
|
|
578
|
+
|
|
579
|
+
if self._update_edge_feature:
|
|
580
|
+
edge_feature = self._edge_dense(attention_feature)
|
|
581
|
+
edge_feature = keras.ops.reshape(edge_feature, (-1, self.units))
|
|
582
|
+
|
|
583
|
+
attention_feature = self._attention_activation(attention_feature)
|
|
584
|
+
attention_score = self._attention_dense(attention_feature)
|
|
585
|
+
attention_score = ops.edge_softmax(
|
|
586
|
+
score=attention_score, edge_target=tensor.edge['target']
|
|
587
|
+
)
|
|
588
|
+
node_feature = self._node_dense(tensor.node['feature'])
|
|
589
|
+
message = ops.gather(node_feature, tensor.edge['source'])
|
|
590
|
+
return tensor.update(
|
|
591
|
+
{
|
|
592
|
+
'edge': {
|
|
593
|
+
'message': message,
|
|
594
|
+
'weight': attention_score,
|
|
595
|
+
'feature': edge_feature,
|
|
596
|
+
}
|
|
597
|
+
}
|
|
598
|
+
)
|
|
599
|
+
|
|
600
|
+
def aggregate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
|
|
601
|
+
node_feature = tensor.aggregate('message')
|
|
602
|
+
node_feature += self._node_self_dense(tensor.node['feature'])
|
|
603
|
+
node_feature = self._dropout_layer(node_feature)
|
|
604
|
+
node_feature = keras.ops.reshape(node_feature, (-1, self.units))
|
|
605
|
+
return tensor.update(
|
|
606
|
+
{
|
|
607
|
+
'node': {
|
|
608
|
+
'feature': node_feature
|
|
609
|
+
},
|
|
610
|
+
'edge': {
|
|
611
|
+
'message': None,
|
|
612
|
+
'weight': None,
|
|
613
|
+
}
|
|
614
|
+
}
|
|
615
|
+
)
|
|
616
|
+
|
|
617
|
+
def update(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
|
|
618
|
+
node_feature = self._activation(tensor.node['feature'])
|
|
619
|
+
return tensor.update(
|
|
620
|
+
{
|
|
621
|
+
'node': {
|
|
622
|
+
'feature': node_feature
|
|
623
|
+
}
|
|
624
|
+
}
|
|
625
|
+
)
|
|
626
|
+
|
|
627
|
+
def get_config(self) -> dict:
|
|
628
|
+
config = super().get_config()
|
|
629
|
+
config.update({
|
|
630
|
+
"heads": self._heads,
|
|
631
|
+
'activation': keras.activations.serialize(self._activation),
|
|
632
|
+
'dropout': self._dropout,
|
|
633
|
+
'update_edge_feature': self._update_edge_feature,
|
|
634
|
+
'attention_activation': keras.activations.serialize(self._attention_activation),
|
|
635
|
+
})
|
|
636
|
+
return config
|
|
637
|
+
|
|
638
|
+
|
|
475
639
|
@keras.saving.register_keras_serializable(package='molcraft')
|
|
476
640
|
class GTConv(GraphConv):
|
|
477
641
|
|
|
@@ -550,10 +714,11 @@ class GTConv(GraphConv):
|
|
|
550
714
|
|
|
551
715
|
self._self_attention_dropout = keras.layers.Dropout(self._dropout)
|
|
552
716
|
|
|
553
|
-
self.
|
|
554
|
-
|
|
555
|
-
|
|
556
|
-
self.
|
|
717
|
+
self._add_bias = not 'bias' in spec.edge
|
|
718
|
+
|
|
719
|
+
if self._add_bias:
|
|
720
|
+
self._edge_bias = EdgeBias(biases=self.heads)
|
|
721
|
+
self._edge_bias.build_from_spec(spec)
|
|
557
722
|
|
|
558
723
|
has_overridden_update = self.__class__.update != GTConv.update
|
|
559
724
|
if not has_overridden_update:
|
|
@@ -570,11 +735,21 @@ class GTConv(GraphConv):
|
|
|
570
735
|
self._feedforward_output_dense = self.get_dense(self.units)
|
|
571
736
|
self._feedforward_output_dense.build([None, self.units])
|
|
572
737
|
|
|
738
|
+
self.built = True
|
|
573
739
|
|
|
574
740
|
def message(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
|
|
575
741
|
"""Computes messages.
|
|
576
742
|
"""
|
|
577
|
-
|
|
743
|
+
if self._add_bias:
|
|
744
|
+
edge_bias = self._edge_bias(tensor)
|
|
745
|
+
tensor = tensor.update(
|
|
746
|
+
{
|
|
747
|
+
'edge': {
|
|
748
|
+
'bias': edge_bias
|
|
749
|
+
}
|
|
750
|
+
}
|
|
751
|
+
)
|
|
752
|
+
|
|
578
753
|
node_feature = tensor.node['feature']
|
|
579
754
|
|
|
580
755
|
query = self._query_dense(node_feature)
|
|
@@ -587,11 +762,8 @@ class GTConv(GraphConv):
|
|
|
587
762
|
|
|
588
763
|
attention_score = keras.ops.sum(query * key, axis=1, keepdims=True)
|
|
589
764
|
attention_score /= keras.ops.sqrt(float(self.head_units))
|
|
590
|
-
|
|
591
|
-
if self._add_edge_bias:
|
|
592
|
-
tensor = self._add_edge_bias(tensor)
|
|
593
765
|
|
|
594
|
-
attention_score += keras.ops.expand_dims(tensor.edge['bias'],
|
|
766
|
+
attention_score += keras.ops.expand_dims(tensor.edge['bias'], axis=1)
|
|
595
767
|
|
|
596
768
|
attention = ops.edge_softmax(attention_score, tensor.edge['target'])
|
|
597
769
|
attention = self._softmax_dropout(attention)
|
|
@@ -665,6 +837,242 @@ class GTConv(GraphConv):
|
|
|
665
837
|
return config
|
|
666
838
|
|
|
667
839
|
|
|
840
|
+
@keras.saving.register_keras_serializable(package='molcraft')
|
|
841
|
+
class GTConv3D(GTConv):
|
|
842
|
+
|
|
843
|
+
"""Graph transformer 3D layer.
|
|
844
|
+
"""
|
|
845
|
+
|
|
846
|
+
def build_from_spec(self, spec: tensors.GraphTensor.Spec) -> None:
|
|
847
|
+
super().build_from_spec(spec)
|
|
848
|
+
if self._add_bias:
|
|
849
|
+
node_feature_dim = spec.node['feature'].shape[-1]
|
|
850
|
+
kernels = self.units
|
|
851
|
+
self._gaussian_basis = GaussianDistance(kernels)
|
|
852
|
+
self._gaussian_basis.build_from_spec(spec)
|
|
853
|
+
self._centrality_dense = self.get_dense(units=node_feature_dim)
|
|
854
|
+
self._centrality_dense.build([None, kernels])
|
|
855
|
+
self._gaussian_edge_bias = self.get_dense(self.heads)
|
|
856
|
+
self._gaussian_edge_bias.build([None, kernels])
|
|
857
|
+
|
|
858
|
+
def message(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
|
|
859
|
+
"""Computes messages.
|
|
860
|
+
"""
|
|
861
|
+
node_feature = tensor.node['feature']
|
|
862
|
+
|
|
863
|
+
if self._add_bias:
|
|
864
|
+
gaussian = self._gaussian_basis(tensor)
|
|
865
|
+
centrality = keras.ops.segment_sum(
|
|
866
|
+
gaussian, tensor.edge['target'], tensor.num_nodes
|
|
867
|
+
)
|
|
868
|
+
node_feature += self._centrality_dense(centrality)
|
|
869
|
+
|
|
870
|
+
edge_bias = self._edge_bias(tensor) + self._gaussian_edge_bias(gaussian)
|
|
871
|
+
tensor = tensor.update({'edge': {'bias': edge_bias}})
|
|
872
|
+
|
|
873
|
+
query = self._query_dense(node_feature)
|
|
874
|
+
key = self._key_dense(node_feature)
|
|
875
|
+
value = self._value_dense(node_feature)
|
|
876
|
+
|
|
877
|
+
query = ops.gather(query, tensor.edge['source'])
|
|
878
|
+
key = ops.gather(key, tensor.edge['target'])
|
|
879
|
+
value = ops.gather(value, tensor.edge['source'])
|
|
880
|
+
|
|
881
|
+
attention_score = keras.ops.sum(query * key, axis=1, keepdims=True)
|
|
882
|
+
attention_score /= keras.ops.sqrt(float(self.head_units))
|
|
883
|
+
|
|
884
|
+
attention_score += keras.ops.expand_dims(tensor.edge['bias'], axis=1)
|
|
885
|
+
|
|
886
|
+
attention = ops.edge_softmax(attention_score, tensor.edge['target'])
|
|
887
|
+
attention = self._softmax_dropout(attention)
|
|
888
|
+
|
|
889
|
+
distance = keras.ops.subtract(
|
|
890
|
+
tensor.gather('coordinate', 'source'),
|
|
891
|
+
tensor.gather('coordinate', 'target')
|
|
892
|
+
)
|
|
893
|
+
euclidean_distance = ops.euclidean_distance(
|
|
894
|
+
tensor.gather('coordinate', 'source'),
|
|
895
|
+
tensor.gather('coordinate', 'target'),
|
|
896
|
+
axis=-1
|
|
897
|
+
)
|
|
898
|
+
distance /= euclidean_distance
|
|
899
|
+
|
|
900
|
+
attention *= keras.ops.expand_dims(distance, axis=-1)
|
|
901
|
+
attention = keras.ops.expand_dims(attention, axis=2)
|
|
902
|
+
value = keras.ops.expand_dims(value, axis=1)
|
|
903
|
+
|
|
904
|
+
return tensor.update(
|
|
905
|
+
{
|
|
906
|
+
'edge': {
|
|
907
|
+
'message': value,
|
|
908
|
+
'weight': attention,
|
|
909
|
+
},
|
|
910
|
+
}
|
|
911
|
+
)
|
|
912
|
+
|
|
913
|
+
def aggregate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
|
|
914
|
+
"""Aggregates messages.
|
|
915
|
+
"""
|
|
916
|
+
node_feature = tensor.aggregate('message')
|
|
917
|
+
node_feature = keras.ops.reshape(
|
|
918
|
+
node_feature, (tensor.num_nodes, -1, self.units)
|
|
919
|
+
)
|
|
920
|
+
node_feature = self._output_dense(node_feature)
|
|
921
|
+
node_feature = keras.ops.sum(node_feature, axis=1)
|
|
922
|
+
node_feature = self._self_attention_dropout(node_feature)
|
|
923
|
+
return tensor.update(
|
|
924
|
+
{
|
|
925
|
+
'node': {
|
|
926
|
+
'feature': node_feature,
|
|
927
|
+
'residual': tensor.node['feature']
|
|
928
|
+
},
|
|
929
|
+
'edge': {
|
|
930
|
+
'message': None,
|
|
931
|
+
'weight': None,
|
|
932
|
+
}
|
|
933
|
+
}
|
|
934
|
+
)
|
|
935
|
+
|
|
936
|
+
|
|
937
|
+
@keras.saving.register_keras_serializable(package='molcraft')
|
|
938
|
+
class MPConv(GraphConv):
|
|
939
|
+
|
|
940
|
+
"""Message passing neural network layer.
|
|
941
|
+
"""
|
|
942
|
+
|
|
943
|
+
def __init__(
|
|
944
|
+
self,
|
|
945
|
+
units: int = 128,
|
|
946
|
+
activation: keras.layers.Activation | str | None = None,
|
|
947
|
+
use_bias: bool = True,
|
|
948
|
+
normalize: bool = True,
|
|
949
|
+
dropout: float = 0.0,
|
|
950
|
+
**kwargs
|
|
951
|
+
) -> None:
|
|
952
|
+
super().__init__(
|
|
953
|
+
units=units,
|
|
954
|
+
normalize=normalize,
|
|
955
|
+
use_bias=use_bias,
|
|
956
|
+
**kwargs
|
|
957
|
+
)
|
|
958
|
+
self._activation = keras.activations.get(activation)
|
|
959
|
+
self._dropout = dropout or 0.0
|
|
960
|
+
|
|
961
|
+
def build_from_spec(self, spec: tensors.GraphTensor.Spec) -> None:
|
|
962
|
+
node_feature_dim = spec.node['feature'].shape[-1]
|
|
963
|
+
self.message_fn = self.get_dense(self.units, activation=self._activation)
|
|
964
|
+
self.update_fn = keras.layers.GRUCell(self.units)
|
|
965
|
+
self._has_edge_feature = 'feature' in spec.edge
|
|
966
|
+
self.project_input_node_feature = node_feature_dim != self.units
|
|
967
|
+
if self.project_input_node_feature:
|
|
968
|
+
warn(
|
|
969
|
+
'Input node feature dim does not match updated node feature dim. '
|
|
970
|
+
'To make sure input node feature can be passed as `states` to the '
|
|
971
|
+
'GRU cell, it will automatically be projected prior to it.'
|
|
972
|
+
)
|
|
973
|
+
self._previous_node_dense = self.get_dense(self.units, activation=self._activation)
|
|
974
|
+
self.built = True
|
|
975
|
+
|
|
976
|
+
def message(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
|
|
977
|
+
feature = keras.ops.concatenate(
|
|
978
|
+
[
|
|
979
|
+
tensor.gather('feature', 'source'),
|
|
980
|
+
tensor.gather('feature', 'target'),
|
|
981
|
+
],
|
|
982
|
+
axis=-1
|
|
983
|
+
)
|
|
984
|
+
if self._has_edge_feature:
|
|
985
|
+
feature = keras.ops.concatenate(
|
|
986
|
+
[
|
|
987
|
+
feature,
|
|
988
|
+
tensor.edge['feature']
|
|
989
|
+
],
|
|
990
|
+
axis=-1
|
|
991
|
+
)
|
|
992
|
+
message = self.message_fn(feature)
|
|
993
|
+
return tensor.update(
|
|
994
|
+
{
|
|
995
|
+
'edge': {
|
|
996
|
+
'message': message,
|
|
997
|
+
}
|
|
998
|
+
}
|
|
999
|
+
)
|
|
1000
|
+
|
|
1001
|
+
def aggregate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
|
|
1002
|
+
aggregate = tensor.aggregate('message')
|
|
1003
|
+
previous = tensor.node['feature']
|
|
1004
|
+
if self.project_input_node_feature:
|
|
1005
|
+
previous = self._previous_node_dense(previous)
|
|
1006
|
+
return tensor.update(
|
|
1007
|
+
{
|
|
1008
|
+
'node': {
|
|
1009
|
+
'feature': aggregate,
|
|
1010
|
+
'previous_feature': previous,
|
|
1011
|
+
}
|
|
1012
|
+
}
|
|
1013
|
+
)
|
|
1014
|
+
|
|
1015
|
+
def update(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
|
|
1016
|
+
updated_node_feature, _ = self.update_fn(
|
|
1017
|
+
inputs=tensor.node['feature'],
|
|
1018
|
+
states=tensor.node['previous_feature']
|
|
1019
|
+
)
|
|
1020
|
+
return tensor.update(
|
|
1021
|
+
{
|
|
1022
|
+
'node': {
|
|
1023
|
+
'feature': updated_node_feature,
|
|
1024
|
+
'previous_feature': None,
|
|
1025
|
+
}
|
|
1026
|
+
}
|
|
1027
|
+
)
|
|
1028
|
+
|
|
1029
|
+
def get_config(self) -> dict:
|
|
1030
|
+
config = super().get_config()
|
|
1031
|
+
config.update({
|
|
1032
|
+
'activation': keras.activations.serialize(self._activation),
|
|
1033
|
+
'dropout': self._dropout,
|
|
1034
|
+
})
|
|
1035
|
+
return config
|
|
1036
|
+
|
|
1037
|
+
|
|
1038
|
+
@keras.saving.register_keras_serializable(package='molcraft')
|
|
1039
|
+
class MPConv3D(MPConv):
|
|
1040
|
+
|
|
1041
|
+
"""3D Message passing neural network layer.
|
|
1042
|
+
"""
|
|
1043
|
+
|
|
1044
|
+
def message(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
|
|
1045
|
+
euclidean_distance = ops.euclidean_distance(
|
|
1046
|
+
tensor.gather('coordinate', 'target'),
|
|
1047
|
+
tensor.gather('coordinate', 'source'),
|
|
1048
|
+
axis=-1
|
|
1049
|
+
)
|
|
1050
|
+
feature = keras.ops.concatenate(
|
|
1051
|
+
[
|
|
1052
|
+
tensor.gather('feature', 'source'),
|
|
1053
|
+
tensor.gather('feature', 'target'),
|
|
1054
|
+
euclidean_distance,
|
|
1055
|
+
],
|
|
1056
|
+
axis=-1
|
|
1057
|
+
)
|
|
1058
|
+
if self._has_edge_feature:
|
|
1059
|
+
feature = keras.ops.concatenate(
|
|
1060
|
+
[
|
|
1061
|
+
feature,
|
|
1062
|
+
tensor.edge['feature']
|
|
1063
|
+
],
|
|
1064
|
+
axis=-1
|
|
1065
|
+
)
|
|
1066
|
+
message = self.message_fn(feature)
|
|
1067
|
+
return tensor.update(
|
|
1068
|
+
{
|
|
1069
|
+
'edge': {
|
|
1070
|
+
'message': message,
|
|
1071
|
+
}
|
|
1072
|
+
}
|
|
1073
|
+
)
|
|
1074
|
+
|
|
1075
|
+
|
|
668
1076
|
@keras.saving.register_keras_serializable(package='molcraft')
|
|
669
1077
|
class EGConv3D(GraphConv):
|
|
670
1078
|
|
|
@@ -714,6 +1122,7 @@ class EGConv3D(GraphConv):
|
|
|
714
1122
|
self.update_fn = self.get_dense(self.units, activation=self._activation)
|
|
715
1123
|
self.update_fn.build([None, node_feature_dim + self.units])
|
|
716
1124
|
self._dropout_layer = keras.layers.Dropout(self._dropout)
|
|
1125
|
+
self.built = True
|
|
717
1126
|
|
|
718
1127
|
def message(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
|
|
719
1128
|
"""Computes messages.
|
|
@@ -849,6 +1258,7 @@ class Projection(GraphLayer):
|
|
|
849
1258
|
self.units = feature_dim
|
|
850
1259
|
self._dense = self.get_dense(self.units)
|
|
851
1260
|
self._dense.build([None, feature_dim])
|
|
1261
|
+
self.built = True
|
|
852
1262
|
|
|
853
1263
|
def propagate(self, tensor: tensors.GraphTensor):
|
|
854
1264
|
"""Calls the layer.
|
|
@@ -913,6 +1323,7 @@ class GraphNetwork(GraphLayer):
|
|
|
913
1323
|
)
|
|
914
1324
|
self._edge_dense = self.get_dense(units)
|
|
915
1325
|
self._update_edge_feature = True
|
|
1326
|
+
self.built = True
|
|
916
1327
|
|
|
917
1328
|
def propagate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
|
|
918
1329
|
"""Calls the layer.
|
|
@@ -1003,12 +1414,14 @@ class NodeEmbedding(GraphLayer):
|
|
|
1003
1414
|
def __init__(
|
|
1004
1415
|
self,
|
|
1005
1416
|
dim: int = None,
|
|
1417
|
+
normalize: bool = True,
|
|
1006
1418
|
embed_context: bool = True,
|
|
1007
1419
|
allow_masking: bool = True,
|
|
1008
1420
|
**kwargs
|
|
1009
1421
|
) -> None:
|
|
1010
1422
|
super().__init__(**kwargs)
|
|
1011
1423
|
self.dim = dim
|
|
1424
|
+
self._normalize = normalize
|
|
1012
1425
|
self._embed_context = embed_context
|
|
1013
1426
|
self._masking_rate = None
|
|
1014
1427
|
self._allow_masking = allow_masking
|
|
@@ -1035,6 +1448,12 @@ class NodeEmbedding(GraphLayer):
|
|
|
1035
1448
|
context_feature_dim = spec.context['feature'].shape[-1]
|
|
1036
1449
|
self._context_dense = self.get_dense(self.dim)
|
|
1037
1450
|
self._context_dense.build([None, context_feature_dim])
|
|
1451
|
+
|
|
1452
|
+
if self._normalize:
|
|
1453
|
+
self._norm = keras.layers.LayerNormalization()
|
|
1454
|
+
self._norm.build([None, self.dim])
|
|
1455
|
+
|
|
1456
|
+
self.built = True
|
|
1038
1457
|
|
|
1039
1458
|
def propagate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
|
|
1040
1459
|
"""Calls the layer.
|
|
@@ -1068,6 +1487,9 @@ class NodeEmbedding(GraphLayer):
|
|
|
1068
1487
|
# Silence warning of 'no gradients for variables'
|
|
1069
1488
|
feature = feature + (self._mask_feature * 0.0)
|
|
1070
1489
|
|
|
1490
|
+
if self._normalize:
|
|
1491
|
+
feature = self._norm(feature)
|
|
1492
|
+
|
|
1071
1493
|
return tensor.update({'node': {'feature': feature}})
|
|
1072
1494
|
|
|
1073
1495
|
@property
|
|
@@ -1087,6 +1509,8 @@ class NodeEmbedding(GraphLayer):
|
|
|
1087
1509
|
config = super().get_config()
|
|
1088
1510
|
config.update({
|
|
1089
1511
|
'dim': self.dim,
|
|
1512
|
+
'normalize': self._normalize,
|
|
1513
|
+
'embed_context': self._embed_context,
|
|
1090
1514
|
'allow_masking': self._allow_masking
|
|
1091
1515
|
})
|
|
1092
1516
|
return config
|
|
@@ -1103,11 +1527,13 @@ class EdgeEmbedding(GraphLayer):
|
|
|
1103
1527
|
def __init__(
|
|
1104
1528
|
self,
|
|
1105
1529
|
dim: int = None,
|
|
1530
|
+
normalize: bool = True,
|
|
1106
1531
|
allow_masking: bool = True,
|
|
1107
1532
|
**kwargs
|
|
1108
1533
|
) -> None:
|
|
1109
1534
|
super().__init__(**kwargs)
|
|
1110
1535
|
self.dim = dim
|
|
1536
|
+
self._normalize = normalize
|
|
1111
1537
|
self._masking_rate = None
|
|
1112
1538
|
self._allow_masking = allow_masking
|
|
1113
1539
|
|
|
@@ -1125,6 +1551,11 @@ class EdgeEmbedding(GraphLayer):
|
|
|
1125
1551
|
self._super_feature = self.get_weight(shape=[self.dim], name='super_edge_feature')
|
|
1126
1552
|
if self._allow_masking:
|
|
1127
1553
|
self._mask_feature = self.get_weight(shape=[self.dim], name='mask_edge_feature')
|
|
1554
|
+
if self._normalize:
|
|
1555
|
+
self._norm = keras.layers.LayerNormalization()
|
|
1556
|
+
self._norm.build([None, self.dim])
|
|
1557
|
+
|
|
1558
|
+
self.built = True
|
|
1128
1559
|
|
|
1129
1560
|
def propagate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
|
|
1130
1561
|
"""Calls the layer.
|
|
@@ -1153,6 +1584,9 @@ class EdgeEmbedding(GraphLayer):
|
|
|
1153
1584
|
# Silence warning of 'no gradients for variables'
|
|
1154
1585
|
feature = feature + (self._mask_feature * 0.0)
|
|
1155
1586
|
|
|
1587
|
+
if self._normalize:
|
|
1588
|
+
feature = self._norm(feature)
|
|
1589
|
+
|
|
1156
1590
|
return tensor.update({'edge': {'feature': feature}})
|
|
1157
1591
|
|
|
1158
1592
|
@property
|
|
@@ -1172,6 +1606,7 @@ class EdgeEmbedding(GraphLayer):
|
|
|
1172
1606
|
config = super().get_config()
|
|
1173
1607
|
config.update({
|
|
1174
1608
|
'dim': self.dim,
|
|
1609
|
+
'normalize': self._normalize,
|
|
1175
1610
|
'allow_masking': self._allow_masking
|
|
1176
1611
|
})
|
|
1177
1612
|
return config
|
|
@@ -1199,17 +1634,97 @@ class EdgeProjection(Projection):
|
|
|
1199
1634
|
"""
|
|
1200
1635
|
def __init__(self, units: int = None, activation: str = None, **kwargs):
|
|
1201
1636
|
super().__init__(units=units, activation=activation, field='edge', **kwargs)
|
|
1637
|
+
|
|
1638
|
+
|
|
1639
|
+
@keras.saving.register_keras_serializable(package='molcraft')
|
|
1640
|
+
class EdgeBias(GraphLayer):
|
|
1641
|
+
|
|
1642
|
+
def __init__(self, biases: int, **kwargs):
|
|
1643
|
+
super().__init__(**kwargs)
|
|
1644
|
+
self.biases = biases
|
|
1645
|
+
|
|
1646
|
+
def build_from_spec(self, spec: tensors.GraphTensor.Spec) -> None:
|
|
1647
|
+
self._has_edge_length = 'length' in spec.edge
|
|
1648
|
+
self._has_edge_feature = 'feature' in spec.edge
|
|
1649
|
+
if self._has_edge_feature:
|
|
1650
|
+
self._edge_feature_dense = self.get_dense(self.biases)
|
|
1651
|
+
self._edge_feature_dense.build([None, spec.edge['feature'].shape[-1]])
|
|
1652
|
+
if self._has_edge_length:
|
|
1653
|
+
self._edge_length_dense = self.get_dense(
|
|
1654
|
+
self.biases, kernel_initializer='zeros'
|
|
1655
|
+
)
|
|
1656
|
+
self._edge_length_dense.build([None, spec.edge['length'].shape[-1]])
|
|
1657
|
+
self.built = True
|
|
1658
|
+
|
|
1659
|
+
def propagate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
|
|
1660
|
+
bias = keras.ops.zeros(
|
|
1661
|
+
shape=(tensor.num_edges, self.biases),
|
|
1662
|
+
dtype=tensor.node['feature'].dtype
|
|
1663
|
+
)
|
|
1664
|
+
if self._has_edge_feature:
|
|
1665
|
+
bias += self._edge_feature_dense(tensor.edge['feature'])
|
|
1666
|
+
if self._has_edge_length:
|
|
1667
|
+
bias += self._edge_length_dense(tensor.edge['length'])
|
|
1668
|
+
return bias
|
|
1669
|
+
|
|
1670
|
+
def get_config(self) -> dict:
|
|
1671
|
+
config = super().get_config()
|
|
1672
|
+
config.update({'biases': self.biases})
|
|
1673
|
+
return config
|
|
1202
1674
|
|
|
1203
1675
|
|
|
1204
1676
|
@keras.saving.register_keras_serializable(package='molcraft')
|
|
1205
|
-
class
|
|
1677
|
+
class GaussianDistance(GraphLayer):
|
|
1678
|
+
|
|
1679
|
+
def __init__(self, kernels: int, **kwargs):
|
|
1680
|
+
super().__init__(**kwargs)
|
|
1681
|
+
self.kernels = kernels
|
|
1682
|
+
|
|
1683
|
+
def build_from_spec(self, spec: tensors.GraphTensor.Spec) -> None:
|
|
1684
|
+
self._loc = self.add_weight(
|
|
1685
|
+
shape=[self.kernels],
|
|
1686
|
+
initializer='zeros',
|
|
1687
|
+
dtype='float32',
|
|
1688
|
+
trainable=True
|
|
1689
|
+
)
|
|
1690
|
+
self._scale = self.add_weight(
|
|
1691
|
+
shape=[self.kernels],
|
|
1692
|
+
initializer='ones',
|
|
1693
|
+
dtype='float32',
|
|
1694
|
+
trainable=True
|
|
1695
|
+
)
|
|
1696
|
+
self.built = True
|
|
1697
|
+
|
|
1698
|
+
def propagate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
|
|
1699
|
+
euclidean_distance = ops.euclidean_distance(
|
|
1700
|
+
tensor.gather('coordinate', 'source'),
|
|
1701
|
+
tensor.gather('coordinate', 'target'),
|
|
1702
|
+
axis=-1
|
|
1703
|
+
)
|
|
1704
|
+
return ops.gaussian(
|
|
1705
|
+
euclidean_distance, self._loc, self._scale
|
|
1706
|
+
)
|
|
1707
|
+
|
|
1708
|
+
def get_config(self) -> dict:
|
|
1709
|
+
config = super().get_config()
|
|
1710
|
+
config.update({
|
|
1711
|
+
'kernels': self.kernels,
|
|
1712
|
+
})
|
|
1713
|
+
return config
|
|
1714
|
+
|
|
1715
|
+
|
|
1716
|
+
@keras.saving.register_keras_serializable(package='molcraft')
|
|
1717
|
+
class Readout(GraphLayer):
|
|
1718
|
+
|
|
1719
|
+
"""Readout layer.
|
|
1720
|
+
"""
|
|
1206
1721
|
|
|
1207
1722
|
def __init__(self, mode: str | None = None, **kwargs):
|
|
1723
|
+
kwargs['kernel_initializer'] = None
|
|
1724
|
+
kwargs['bias_initializer'] = None
|
|
1208
1725
|
super().__init__(**kwargs)
|
|
1209
1726
|
self.mode = mode
|
|
1210
|
-
if
|
|
1211
|
-
self._reduce_fn = None
|
|
1212
|
-
elif str(self.mode).lower().startswith('sum'):
|
|
1727
|
+
if str(self.mode).lower().startswith('sum'):
|
|
1213
1728
|
self._reduce_fn = keras.ops.segment_sum
|
|
1214
1729
|
elif str(self.mode).lower().startswith('max'):
|
|
1215
1730
|
self._reduce_fn = keras.ops.segment_max
|
|
@@ -1221,80 +1736,24 @@ class Readout(keras.layers.Layer):
|
|
|
1221
1736
|
def build_from_spec(self, spec: tensors.GraphTensor.Spec) -> None:
|
|
1222
1737
|
"""Builds the layer.
|
|
1223
1738
|
"""
|
|
1224
|
-
|
|
1739
|
+
self.built = True
|
|
1225
1740
|
|
|
1226
|
-
def
|
|
1227
|
-
|
|
1228
|
-
|
|
1741
|
+
def propagate(self, tensor: tensors.GraphTensor) -> tf.Tensor:
|
|
1742
|
+
"""Calls the layer.
|
|
1743
|
+
"""
|
|
1744
|
+
node_feature = tensor.node['feature']
|
|
1229
1745
|
if str(self.mode).lower().startswith('super'):
|
|
1230
1746
|
node_feature = keras.ops.where(
|
|
1231
|
-
tensor.node['super'][:, None],
|
|
1232
|
-
)
|
|
1233
|
-
return self._reduce_fn(
|
|
1234
|
-
node_feature, tensor.graph_indicator, tensor.num_subgraphs
|
|
1747
|
+
tensor.node['super'][:, None], node_feature, 0.0
|
|
1235
1748
|
)
|
|
1236
1749
|
return self._reduce_fn(
|
|
1237
|
-
|
|
1750
|
+
node_feature, tensor.graph_indicator, tensor.num_subgraphs
|
|
1238
1751
|
)
|
|
1239
1752
|
|
|
1240
|
-
def build(self, input_shapes) -> None:
|
|
1241
|
-
spec = tensors.GraphTensor.Spec.from_input_shape_dict(input_shapes)
|
|
1242
|
-
self.build_from_spec(spec)
|
|
1243
|
-
self.built = True
|
|
1244
|
-
|
|
1245
|
-
def call(self, graph) -> tf.Tensor:
|
|
1246
|
-
graph_tensor = tensors.from_dict(graph)
|
|
1247
|
-
if tensors.is_ragged(graph_tensor):
|
|
1248
|
-
graph_tensor = graph_tensor.flatten()
|
|
1249
|
-
return self.reduce(graph_tensor)
|
|
1250
|
-
|
|
1251
|
-
def __call__(
|
|
1252
|
-
self,
|
|
1253
|
-
graph: tensors.GraphTensor,
|
|
1254
|
-
*args,
|
|
1255
|
-
**kwargs
|
|
1256
|
-
) -> tensors.GraphTensor:
|
|
1257
|
-
is_tensor = isinstance(graph, tensors.GraphTensor)
|
|
1258
|
-
if is_tensor:
|
|
1259
|
-
graph = tensors.to_dict(graph)
|
|
1260
|
-
tensor = super().__call__(graph, *args, **kwargs)
|
|
1261
|
-
return tensor
|
|
1262
|
-
|
|
1263
1753
|
def get_config(self) -> dict:
|
|
1264
1754
|
config = super().get_config()
|
|
1265
1755
|
config['mode'] = self.mode
|
|
1266
1756
|
return config
|
|
1267
|
-
|
|
1268
|
-
|
|
1269
|
-
@keras.saving.register_keras_serializable(package='molcraft')
|
|
1270
|
-
class AddEdgeBias(GraphLayer):
|
|
1271
|
-
|
|
1272
|
-
def build_from_spec(self, spec: tensors.GraphTensor.Spec) -> None:
|
|
1273
|
-
self._has_edge_length = 'length' in spec.edge
|
|
1274
|
-
self._has_edge_feature = 'feature' in spec.edge
|
|
1275
|
-
if self._has_edge_feature:
|
|
1276
|
-
self._edge_feature_dense = self.get_dense(units=1)
|
|
1277
|
-
if self._has_edge_length:
|
|
1278
|
-
self._edge_length_dense = self.get_dense(
|
|
1279
|
-
units=1, kernel_initializer='zeros'
|
|
1280
|
-
)
|
|
1281
|
-
|
|
1282
|
-
def propagate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
|
|
1283
|
-
bias = keras.ops.zeros(
|
|
1284
|
-
shape=(tensor.num_edges, 1),
|
|
1285
|
-
dtype=tensor.node['feature'].dtype
|
|
1286
|
-
)
|
|
1287
|
-
if self._has_edge_feature:
|
|
1288
|
-
bias += self._edge_feature_dense(tensor.edge['feature'])
|
|
1289
|
-
if self._has_edge_length:
|
|
1290
|
-
bias += self._edge_length_dense(tensor.edge['length'])
|
|
1291
|
-
return tensor.update(
|
|
1292
|
-
{
|
|
1293
|
-
'edge': {
|
|
1294
|
-
'bias': bias
|
|
1295
|
-
}
|
|
1296
|
-
}
|
|
1297
|
-
)
|
|
1298
1757
|
|
|
1299
1758
|
|
|
1300
1759
|
def Input(spec: tensors.GraphTensor.Spec) -> dict:
|
|
@@ -1412,13 +1871,6 @@ def _spec_from_inputs(inputs):
|
|
|
1412
1871
|
return tensors.GraphTensor.Spec(**nested_specs)
|
|
1413
1872
|
|
|
1414
1873
|
|
|
1415
|
-
GraphTransformer =
|
|
1416
|
-
|
|
1417
|
-
|
|
1418
|
-
EdgeEmbed = EdgeEmbedding
|
|
1419
|
-
NodeEmbed = NodeEmbedding
|
|
1420
|
-
|
|
1421
|
-
ContextDense = ContextProjection
|
|
1422
|
-
EdgeDense = EdgeProjection
|
|
1423
|
-
NodeDense = NodeProjection
|
|
1874
|
+
GraphTransformer = GTConv
|
|
1875
|
+
GraphTransformer3D = GTConv3D
|
|
1424
1876
|
|
molcraft/models.py
CHANGED
|
@@ -315,10 +315,16 @@ class FunctionalGraphModel(functional.Functional, GraphModel):
|
|
|
315
315
|
]
|
|
316
316
|
|
|
317
317
|
|
|
318
|
-
def save_model(model:
|
|
318
|
+
def save_model(model: GraphModel, filepath: str | Path, *args, **kwargs) -> None:
|
|
319
|
+
if not model.built:
|
|
320
|
+
raise ValueError(
|
|
321
|
+
'Model and its layers have not yet been (fully) built. '
|
|
322
|
+
'Build the model before saving it: `model.build(graph_spec)` '
|
|
323
|
+
'or `model(graph)`.'
|
|
324
|
+
)
|
|
319
325
|
keras.models.save_model(model, filepath, *args, **kwargs)
|
|
320
326
|
|
|
321
|
-
def load_model(filepath: str | Path, inputs=None, *args, **kwargs) ->
|
|
327
|
+
def load_model(filepath: str | Path, inputs=None, *args, **kwargs) -> GraphModel:
|
|
322
328
|
return keras.models.load_model(filepath, *args, **kwargs)
|
|
323
329
|
|
|
324
330
|
def create(
|
|
@@ -333,7 +339,7 @@ def create(
|
|
|
333
339
|
def interpret(
|
|
334
340
|
model: GraphModel,
|
|
335
341
|
graph_tensor: tensors.GraphTensor,
|
|
336
|
-
) ->
|
|
342
|
+
) -> tensors.GraphTensor:
|
|
337
343
|
x = graph_tensor
|
|
338
344
|
if tensors.is_ragged(x):
|
|
339
345
|
x = x.flatten()
|
|
@@ -373,6 +379,31 @@ def interpret(
|
|
|
373
379
|
}
|
|
374
380
|
)
|
|
375
381
|
|
|
382
|
+
def saliency(
|
|
383
|
+
model: GraphModel,
|
|
384
|
+
graph_tensor: tensors.GraphTensor
|
|
385
|
+
) -> tensors.GraphTensor:
|
|
386
|
+
x = graph_tensor
|
|
387
|
+
if tensors.is_ragged(x):
|
|
388
|
+
x = x.flatten()
|
|
389
|
+
y_true = x.context.get('label')
|
|
390
|
+
with tf.GradientTape(watch_accessed_variables=False) as tape:
|
|
391
|
+
tape.watch(x.node['feature'])
|
|
392
|
+
y_pred = model(x, training=False)
|
|
393
|
+
if y_true is not None and len(y_true.shape) > 1:
|
|
394
|
+
target = tf.gather_nd(y_pred, tf.where(y_true != 0))
|
|
395
|
+
else:
|
|
396
|
+
target = y_pred
|
|
397
|
+
gradients = tape.gradient(target, x.node['feature'])
|
|
398
|
+
gradients = keras.ops.absolute(gradients)
|
|
399
|
+
return graph_tensor.update(
|
|
400
|
+
{
|
|
401
|
+
'node': {
|
|
402
|
+
'feature_saliency': gradients
|
|
403
|
+
}
|
|
404
|
+
}
|
|
405
|
+
)
|
|
406
|
+
|
|
376
407
|
def predict(
|
|
377
408
|
model: GraphModel,
|
|
378
409
|
x: tensors.GraphTensor | tf.data.Dataset,
|
|
@@ -1,20 +1,20 @@
|
|
|
1
|
-
molcraft/__init__.py,sha256=
|
|
1
|
+
molcraft/__init__.py,sha256=2ZNfWBjGl8DscOwjdDiRkgIsuPnKit29Q3MhZyP336Q,435
|
|
2
2
|
molcraft/callbacks.py,sha256=6gwCwdsHGb-fVB4m1QGmtBwQwZ9mFq9QUkmPKSMn05U,849
|
|
3
3
|
molcraft/chem.py,sha256=_UO5O-I7KUtGf3vRrFEYoAUGlW5xi2x8ylu5f-Ybumo,18696
|
|
4
4
|
molcraft/conformers.py,sha256=p09gOQOdxLSj3yohZOMkxxLriHsZ1ZqOoiWLi73OpIg,4325
|
|
5
5
|
molcraft/datasets.py,sha256=rFgXTC1ZheLhfgQgcCspP_wEE54a33PIneH7OplbS-8,4047
|
|
6
6
|
molcraft/descriptors.py,sha256=x6RfZ-gK7D_WSvmK6sh6yHyEjQqovPnRU0xwC3dAKfg,2880
|
|
7
|
-
molcraft/features.py,sha256=
|
|
8
|
-
molcraft/featurizers.py,sha256=
|
|
9
|
-
molcraft/layers.py,sha256=
|
|
10
|
-
molcraft/models.py,sha256=
|
|
7
|
+
molcraft/features.py,sha256=69oV_GHNdBKPA4sp6Tpo6brvNmaauk_IVIzNjX7VDmg,13648
|
|
8
|
+
molcraft/featurizers.py,sha256=Yu8I6I_zkzB__WYSiqz-FDGjvKFOmyWFxojRBr39Aw8,26236
|
|
9
|
+
molcraft/layers.py,sha256=HjnAtqhuP0uZ5yP4L33k3xT4IUdLavWBrjd3wO9_Rmw,64915
|
|
10
|
+
molcraft/models.py,sha256=DXqWR_XnMVXQseVR91XnDLXvmHa1hv-6_Y_wvpQZBFI,17476
|
|
11
11
|
molcraft/ops.py,sha256=iiE6zgA2P7cmjKO1RHmL9GE_Tv7Tyuo_xDoxB_ELZQM,3824
|
|
12
12
|
molcraft/records.py,sha256=w4-bcWZEC0oVInrE1e0kQBroIaSCA0PN1JBPOtO6VUY,5251
|
|
13
13
|
molcraft/tensors.py,sha256=b7PO-YOvV72s9g057ILJACKS2n2fn10VkO35gHXpssI,22312
|
|
14
14
|
molcraft/experimental/__init__.py,sha256=x5h6LOO8bo3NPjkKKM9M1H-Kz6R3yxYhRSePoxHCdRE,42
|
|
15
15
|
molcraft/experimental/peptides.py,sha256=RCuOTSwoYHGSdeYi6TWHdPIv2WC3avCZjKLdhEZQeXw,8997
|
|
16
|
-
molcraft-0.1.
|
|
17
|
-
molcraft-0.1.
|
|
18
|
-
molcraft-0.1.
|
|
19
|
-
molcraft-0.1.
|
|
20
|
-
molcraft-0.1.
|
|
16
|
+
molcraft-0.1.0a3.dist-info/licenses/LICENSE,sha256=sbVeqlrtZ0V63uYhZGL5dCxUm8rBAOqe2avyA1zIQNk,1074
|
|
17
|
+
molcraft-0.1.0a3.dist-info/METADATA,sha256=f_5sBinpFcGSqKLaSqGkZJ83gGQtZw1Pb3fkgq9aCBM,4088
|
|
18
|
+
molcraft-0.1.0a3.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
|
|
19
|
+
molcraft-0.1.0a3.dist-info/top_level.txt,sha256=dENV6MfOceshM6MQCgJlcN1ojZkiCL9B4F7XyUge3QM,9
|
|
20
|
+
molcraft-0.1.0a3.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|