molcraft 0.1.0a5__py3-none-any.whl → 0.1.0a6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of molcraft might be problematic. Click here for more details.
- molcraft/__init__.py +3 -2
- molcraft/chem.py +70 -4
- molcraft/conformers.py +1 -1
- molcraft/featurizers.py +20 -14
- molcraft/layers.py +258 -185
- molcraft/losses.py +36 -0
- molcraft/models.py +119 -8
- molcraft/ops.py +10 -0
- molcraft/records.py +32 -31
- molcraft/tensors.py +1 -1
- {molcraft-0.1.0a5.dist-info → molcraft-0.1.0a6.dist-info}/METADATA +4 -17
- molcraft-0.1.0a6.dist-info/RECORD +19 -0
- {molcraft-0.1.0a5.dist-info → molcraft-0.1.0a6.dist-info}/WHEEL +1 -1
- molcraft-0.1.0a5.dist-info/RECORD +0 -18
- {molcraft-0.1.0a5.dist-info → molcraft-0.1.0a6.dist-info}/licenses/LICENSE +0 -0
- {molcraft-0.1.0a5.dist-info → molcraft-0.1.0a6.dist-info}/top_level.txt +0 -0
molcraft/layers.py
CHANGED
|
@@ -125,23 +125,45 @@ class GraphLayer(keras.layers.Layer):
|
|
|
125
125
|
return tensors.to_dict(outputs)
|
|
126
126
|
return outputs
|
|
127
127
|
|
|
128
|
-
def __call__(
|
|
128
|
+
def __call__(
|
|
129
|
+
self,
|
|
130
|
+
graph: dict[str, dict[str, tf.Tensor]] | tensors.GraphTensor,
|
|
131
|
+
**kwargs
|
|
132
|
+
) -> tf.Tensor | dict[str, dict[str, tf.Tensor]] | tensors.GraphTensor:
|
|
129
133
|
if not self.built:
|
|
130
|
-
spec = _spec_from_inputs(
|
|
134
|
+
spec = _spec_from_inputs(graph)
|
|
131
135
|
self.build(spec)
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
136
|
+
|
|
137
|
+
is_graph_tensor = isinstance(graph, tensors.GraphTensor)
|
|
138
|
+
if is_graph_tensor:
|
|
139
|
+
graph = tensors.to_dict(graph)
|
|
140
|
+
else:
|
|
141
|
+
graph = {field: dict(data) for (field, data) in graph.items()}
|
|
142
|
+
|
|
135
143
|
if isinstance(self, functional.Functional):
|
|
136
|
-
|
|
137
|
-
|
|
144
|
+
# As a functional model is strict for what input can
|
|
145
|
+
# be passed to it, we need to temporarily pop some of the
|
|
146
|
+
# input and add it afterwards.
|
|
147
|
+
label = graph['context'].pop('label', None)
|
|
148
|
+
weight = graph['context'].pop('weight', None)
|
|
149
|
+
tf.nest.assert_same_structure(self.input, graph)
|
|
150
|
+
|
|
151
|
+
outputs = super().__call__(graph, **kwargs)
|
|
152
|
+
|
|
138
153
|
if not tensors.is_graph(outputs):
|
|
139
154
|
return outputs
|
|
155
|
+
|
|
156
|
+
graph = outputs
|
|
140
157
|
if isinstance(self, functional.Functional):
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
158
|
+
if label is not None:
|
|
159
|
+
graph['context']['label'] = label
|
|
160
|
+
if weight is not None:
|
|
161
|
+
graph['context']['weight'] = weight
|
|
162
|
+
|
|
163
|
+
if is_graph_tensor:
|
|
164
|
+
return tensors.from_dict(graph)
|
|
165
|
+
|
|
166
|
+
return graph
|
|
145
167
|
|
|
146
168
|
def get_build_config(self) -> dict:
|
|
147
169
|
if self._custom_build_config:
|
|
@@ -256,10 +278,10 @@ class GraphConv(GraphLayer):
|
|
|
256
278
|
Default to `None`.
|
|
257
279
|
use_bias (bool):
|
|
258
280
|
Whether bias should be used in dense layers. Default to `True`.
|
|
259
|
-
|
|
281
|
+
normalize (bool, str):
|
|
260
282
|
Whether `LayerNormalization` should be applied to the final node feature output.
|
|
261
283
|
To use `BatchNormalization`, specify `batch_norm`. Default to `False`.
|
|
262
|
-
|
|
284
|
+
skip_connect (bool, str):
|
|
263
285
|
Whether node feature input should be added to the node feature output.
|
|
264
286
|
If node feature input dim is not equal to `units` (node feature output dim),
|
|
265
287
|
a projection layer will automatically project the residual before adding it
|
|
@@ -294,14 +316,14 @@ class GraphConv(GraphLayer):
|
|
|
294
316
|
units: int = None,
|
|
295
317
|
activation: str | keras.layers.Activation | None = None,
|
|
296
318
|
use_bias: bool = True,
|
|
297
|
-
|
|
298
|
-
|
|
319
|
+
normalize: bool | str = False,
|
|
320
|
+
skip_connect: bool | str = True,
|
|
299
321
|
**kwargs
|
|
300
322
|
) -> None:
|
|
301
323
|
super().__init__(use_bias=use_bias, **kwargs)
|
|
302
324
|
self._units = units
|
|
303
|
-
self.
|
|
304
|
-
self.
|
|
325
|
+
self._normalize = normalize
|
|
326
|
+
self._skip_connect = skip_connect
|
|
305
327
|
self._activation = keras.activations.get(activation)
|
|
306
328
|
|
|
307
329
|
def __init_subclass__(cls, **kwargs):
|
|
@@ -319,36 +341,6 @@ class GraphConv(GraphLayer):
|
|
|
319
341
|
def units(self):
|
|
320
342
|
return self._units
|
|
321
343
|
|
|
322
|
-
def propagate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
|
|
323
|
-
"""Forward pass.
|
|
324
|
-
|
|
325
|
-
Invokes `message(graph)`, `aggregate(graph)` and `update(graph)` in sequence.
|
|
326
|
-
|
|
327
|
-
Arguments:
|
|
328
|
-
tensor:
|
|
329
|
-
A `GraphTensor` instance.
|
|
330
|
-
"""
|
|
331
|
-
if self._skip_connection:
|
|
332
|
-
input_node_feature = tensor.node['feature']
|
|
333
|
-
if self._project_input_node_feature:
|
|
334
|
-
input_node_feature = self._residual_projection(input_node_feature)
|
|
335
|
-
|
|
336
|
-
tensor = self.message(tensor)
|
|
337
|
-
tensor = self.aggregate(tensor)
|
|
338
|
-
tensor = self.update(tensor)
|
|
339
|
-
|
|
340
|
-
updated_node_feature = tensor.node['feature']
|
|
341
|
-
|
|
342
|
-
if self._skip_connection:
|
|
343
|
-
if self._use_weighted_skip_connection:
|
|
344
|
-
input_node_feature *= self._skip_connection_weight
|
|
345
|
-
updated_node_feature += input_node_feature
|
|
346
|
-
|
|
347
|
-
if self._normalization:
|
|
348
|
-
updated_node_feature = self._output_norm(updated_node_feature)
|
|
349
|
-
|
|
350
|
-
return tensor.update({'node': {'feature': updated_node_feature}})
|
|
351
|
-
|
|
352
344
|
def build(self, spec: tensors.GraphTensor.Spec) -> None:
|
|
353
345
|
if not self.units:
|
|
354
346
|
raise ValueError(
|
|
@@ -356,11 +348,11 @@ class GraphConv(GraphLayer):
|
|
|
356
348
|
)
|
|
357
349
|
node_feature_dim = spec.node['feature'].shape[-1]
|
|
358
350
|
self._project_input_node_feature = (
|
|
359
|
-
self.
|
|
351
|
+
self._skip_connect and (node_feature_dim != self.units)
|
|
360
352
|
)
|
|
361
353
|
if self._project_input_node_feature:
|
|
362
354
|
warn(
|
|
363
|
-
'`
|
|
355
|
+
'`skip_connect` is set to `True`, but found incompatible dim '
|
|
364
356
|
'between input (node feature dim) and output (`self.units`). '
|
|
365
357
|
'Automatically applying a projection layer to residual to '
|
|
366
358
|
'match input and output. '
|
|
@@ -369,8 +361,8 @@ class GraphConv(GraphLayer):
|
|
|
369
361
|
self.units, name='residual_projection'
|
|
370
362
|
)
|
|
371
363
|
|
|
372
|
-
|
|
373
|
-
self._use_weighted_skip_connection =
|
|
364
|
+
skip_connect = str(self._skip_connect).lower()
|
|
365
|
+
self._use_weighted_skip_connection = skip_connect.startswith('weight')
|
|
374
366
|
if self._use_weighted_skip_connection:
|
|
375
367
|
self._skip_connection_weight = self.add_weight(
|
|
376
368
|
name='skip_connection_weight',
|
|
@@ -379,8 +371,8 @@ class GraphConv(GraphLayer):
|
|
|
379
371
|
trainable=True,
|
|
380
372
|
)
|
|
381
373
|
|
|
382
|
-
if self.
|
|
383
|
-
if str(self.
|
|
374
|
+
if self._normalize:
|
|
375
|
+
if str(self._normalize).lower().startswith('batch'):
|
|
384
376
|
self._output_norm = keras.layers.BatchNormalization(
|
|
385
377
|
name='output_batch_norm'
|
|
386
378
|
)
|
|
@@ -389,7 +381,7 @@ class GraphConv(GraphLayer):
|
|
|
389
381
|
name='output_layer_norm'
|
|
390
382
|
)
|
|
391
383
|
|
|
392
|
-
self._has_edge_feature = '
|
|
384
|
+
self._has_edge_feature = 'feature' in spec.edge
|
|
393
385
|
|
|
394
386
|
has_overridden_message = self.__class__.message != GraphConv.message
|
|
395
387
|
if not has_overridden_message:
|
|
@@ -400,6 +392,50 @@ class GraphConv(GraphLayer):
|
|
|
400
392
|
self._output_dense = self.get_dense(self.units)
|
|
401
393
|
self._output_activation = self._activation
|
|
402
394
|
|
|
395
|
+
def propagate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
|
|
396
|
+
"""Forward pass.
|
|
397
|
+
|
|
398
|
+
Invokes `message(graph)`, `aggregate(graph)` and `update(graph)` in sequence.
|
|
399
|
+
|
|
400
|
+
Arguments:
|
|
401
|
+
tensor:
|
|
402
|
+
A `GraphTensor` instance.
|
|
403
|
+
"""
|
|
404
|
+
if self._skip_connect:
|
|
405
|
+
input_node_feature = tensor.node['feature']
|
|
406
|
+
if self._project_input_node_feature:
|
|
407
|
+
input_node_feature = self._residual_projection(input_node_feature)
|
|
408
|
+
|
|
409
|
+
message = self.message(tensor)
|
|
410
|
+
if not isinstance(message, tensors.GraphTensor):
|
|
411
|
+
message = tensor.update({'edge': {'message': message}})
|
|
412
|
+
elif not 'message' in message.edge:
|
|
413
|
+
raise ValueError('Could not find `message` in `edge` output.')
|
|
414
|
+
|
|
415
|
+
aggregate = self.aggregate(message)
|
|
416
|
+
if not isinstance(aggregate, tensors.GraphTensor):
|
|
417
|
+
aggregate = tensor.update({'node': {'aggregate': aggregate}})
|
|
418
|
+
elif not 'aggregate' in aggregate.node:
|
|
419
|
+
raise ValueError('Could not find `aggregate` in `node` output.')
|
|
420
|
+
|
|
421
|
+
update = self.update(aggregate)
|
|
422
|
+
if not isinstance(update, tensors.GraphTensor):
|
|
423
|
+
update = tensor.update({'node': {'feature': update}})
|
|
424
|
+
elif not 'feature' in update.node:
|
|
425
|
+
raise ValueError('Could not find `feature` in `node` output.')
|
|
426
|
+
|
|
427
|
+
updated_node_feature = update.node['feature']
|
|
428
|
+
|
|
429
|
+
if self._skip_connect:
|
|
430
|
+
if self._use_weighted_skip_connection:
|
|
431
|
+
input_node_feature *= self._skip_connection_weight
|
|
432
|
+
updated_node_feature += input_node_feature
|
|
433
|
+
|
|
434
|
+
if self._normalize:
|
|
435
|
+
updated_node_feature = self._output_norm(updated_node_feature)
|
|
436
|
+
|
|
437
|
+
return update.update({'node': {'feature': updated_node_feature}})
|
|
438
|
+
|
|
403
439
|
def message(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
|
|
404
440
|
"""Compute messages.
|
|
405
441
|
|
|
@@ -441,8 +477,7 @@ class GraphConv(GraphLayer):
|
|
|
441
477
|
return tensor.update(
|
|
442
478
|
{
|
|
443
479
|
'node': {
|
|
444
|
-
'
|
|
445
|
-
'previous_feature': tensor.node['feature']
|
|
480
|
+
'aggregate': aggregate,
|
|
446
481
|
},
|
|
447
482
|
'edge': {
|
|
448
483
|
'message': None
|
|
@@ -460,23 +495,20 @@ class GraphConv(GraphLayer):
|
|
|
460
495
|
A `GraphTensor` instance containing aggregated messages
|
|
461
496
|
(updated node features).
|
|
462
497
|
"""
|
|
463
|
-
|
|
464
|
-
|
|
465
|
-
|
|
466
|
-
|
|
467
|
-
|
|
468
|
-
|
|
469
|
-
|
|
470
|
-
],
|
|
471
|
-
axis=-1
|
|
472
|
-
)
|
|
498
|
+
feature = keras.ops.concatenate(
|
|
499
|
+
[
|
|
500
|
+
tensor.node['aggregate'],
|
|
501
|
+
tensor.node['feature']
|
|
502
|
+
],
|
|
503
|
+
axis=-1
|
|
504
|
+
)
|
|
473
505
|
update = self._output_dense(feature)
|
|
474
506
|
update = self._output_activation(update)
|
|
475
507
|
return tensor.update(
|
|
476
508
|
{
|
|
477
509
|
'node': {
|
|
478
510
|
'feature': update,
|
|
479
|
-
'
|
|
511
|
+
'aggregate': None,
|
|
480
512
|
}
|
|
481
513
|
}
|
|
482
514
|
)
|
|
@@ -486,8 +518,8 @@ class GraphConv(GraphLayer):
|
|
|
486
518
|
config.update({
|
|
487
519
|
'units': self.units,
|
|
488
520
|
'activation': keras.activations.serialize(self._activation),
|
|
489
|
-
'
|
|
490
|
-
'
|
|
521
|
+
'normalize': self._normalize,
|
|
522
|
+
'skip_connect': self._skip_connect,
|
|
491
523
|
})
|
|
492
524
|
return config
|
|
493
525
|
|
|
@@ -530,14 +562,14 @@ class GIConv(GraphConv):
|
|
|
530
562
|
units: int,
|
|
531
563
|
activation: keras.layers.Activation | str | None = 'relu',
|
|
532
564
|
use_bias: bool = True,
|
|
533
|
-
|
|
565
|
+
normalize: bool = False,
|
|
534
566
|
update_edge_feature: bool = True,
|
|
535
567
|
**kwargs,
|
|
536
568
|
):
|
|
537
569
|
super().__init__(
|
|
538
570
|
units=units,
|
|
539
571
|
activation=activation,
|
|
540
|
-
|
|
572
|
+
normalize=normalize,
|
|
541
573
|
use_bias=use_bias,
|
|
542
574
|
**kwargs
|
|
543
575
|
)
|
|
@@ -599,7 +631,7 @@ class GIConv(GraphConv):
|
|
|
599
631
|
return tensor.update(
|
|
600
632
|
{
|
|
601
633
|
'node': {
|
|
602
|
-
'
|
|
634
|
+
'aggregate': node_feature,
|
|
603
635
|
},
|
|
604
636
|
'edge': {
|
|
605
637
|
'message': None,
|
|
@@ -608,7 +640,7 @@ class GIConv(GraphConv):
|
|
|
608
640
|
)
|
|
609
641
|
|
|
610
642
|
def update(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
|
|
611
|
-
node_feature = tensor.node['
|
|
643
|
+
node_feature = tensor.node['aggregate']
|
|
612
644
|
node_feature = self._feedforward_intermediate_dense(node_feature)
|
|
613
645
|
node_feature = self._feedforward_activation(node_feature)
|
|
614
646
|
node_feature = self._feedforward_output_dense(node_feature)
|
|
@@ -616,6 +648,7 @@ class GIConv(GraphConv):
|
|
|
616
648
|
{
|
|
617
649
|
'node': {
|
|
618
650
|
'feature': node_feature,
|
|
651
|
+
'aggregate': None,
|
|
619
652
|
}
|
|
620
653
|
}
|
|
621
654
|
)
|
|
@@ -667,17 +700,16 @@ class GAConv(GraphConv):
|
|
|
667
700
|
heads: int = 8,
|
|
668
701
|
activation: keras.layers.Activation | str | None = "relu",
|
|
669
702
|
use_bias: bool = True,
|
|
670
|
-
|
|
703
|
+
normalize: bool = False,
|
|
671
704
|
update_edge_feature: bool = True,
|
|
672
705
|
attention_activation: keras.layers.Activation | str | None = "leaky_relu",
|
|
673
706
|
**kwargs,
|
|
674
707
|
) -> None:
|
|
675
|
-
kwargs['skip_connection'] = False
|
|
676
708
|
super().__init__(
|
|
677
709
|
units=units,
|
|
678
710
|
activation=activation,
|
|
679
711
|
use_bias=use_bias,
|
|
680
|
-
|
|
712
|
+
normalize=normalize,
|
|
681
713
|
**kwargs
|
|
682
714
|
)
|
|
683
715
|
self._heads = heads
|
|
@@ -753,11 +785,11 @@ class GAConv(GraphConv):
|
|
|
753
785
|
)
|
|
754
786
|
node_feature = self._node_dense(tensor.node['feature'])
|
|
755
787
|
message = ops.gather(node_feature, tensor.edge['source'])
|
|
788
|
+
message = ops.edge_weight(message, attention_score)
|
|
756
789
|
return tensor.update(
|
|
757
790
|
{
|
|
758
791
|
'edge': {
|
|
759
792
|
'message': message,
|
|
760
|
-
'weight': attention_score,
|
|
761
793
|
'feature': edge_feature,
|
|
762
794
|
}
|
|
763
795
|
}
|
|
@@ -770,24 +802,24 @@ class GAConv(GraphConv):
|
|
|
770
802
|
return tensor.update(
|
|
771
803
|
{
|
|
772
804
|
'node': {
|
|
773
|
-
'
|
|
805
|
+
'aggregate': node_feature
|
|
774
806
|
},
|
|
775
807
|
'edge': {
|
|
776
808
|
'message': None,
|
|
777
|
-
'weight': None,
|
|
778
809
|
}
|
|
779
810
|
}
|
|
780
811
|
)
|
|
781
812
|
|
|
782
813
|
def update(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
|
|
783
|
-
node_feature = tensor.node['
|
|
814
|
+
node_feature = tensor.node['aggregate']
|
|
784
815
|
node_feature = self._feedforward_intermediate_dense(node_feature)
|
|
785
816
|
node_feature = self._feedforward_activation(node_feature)
|
|
786
817
|
node_feature = self._feedforward_output_dense(node_feature)
|
|
787
818
|
return tensor.update(
|
|
788
819
|
{
|
|
789
820
|
'node': {
|
|
790
|
-
'feature': node_feature
|
|
821
|
+
'feature': node_feature,
|
|
822
|
+
'aggregate': None,
|
|
791
823
|
}
|
|
792
824
|
}
|
|
793
825
|
)
|
|
@@ -842,7 +874,7 @@ class GTConv(GraphConv):
|
|
|
842
874
|
heads: int = 8,
|
|
843
875
|
activation: keras.layers.Activation | str | None = "relu",
|
|
844
876
|
use_bias: bool = True,
|
|
845
|
-
|
|
877
|
+
normalize: bool = False,
|
|
846
878
|
attention_dropout: float = 0.0,
|
|
847
879
|
**kwargs,
|
|
848
880
|
) -> None:
|
|
@@ -850,7 +882,7 @@ class GTConv(GraphConv):
|
|
|
850
882
|
units=units,
|
|
851
883
|
activation=activation,
|
|
852
884
|
use_bias=use_bias,
|
|
853
|
-
|
|
885
|
+
normalize=normalize,
|
|
854
886
|
**kwargs
|
|
855
887
|
)
|
|
856
888
|
self._heads = heads
|
|
@@ -901,7 +933,6 @@ class GTConv(GraphConv):
|
|
|
901
933
|
}
|
|
902
934
|
}
|
|
903
935
|
)
|
|
904
|
-
|
|
905
936
|
node_feature = tensor.node['feature']
|
|
906
937
|
|
|
907
938
|
query = self._query_dense(node_feature)
|
|
@@ -918,12 +949,12 @@ class GTConv(GraphConv):
|
|
|
918
949
|
attention_score += keras.ops.expand_dims(tensor.edge['bias'], axis=1)
|
|
919
950
|
attention = ops.edge_softmax(attention_score, tensor.edge['target'])
|
|
920
951
|
attention = self._softmax_dropout(attention)
|
|
952
|
+
message = ops.edge_weight(value, attention)
|
|
921
953
|
|
|
922
954
|
return tensor.update(
|
|
923
955
|
{
|
|
924
956
|
'edge': {
|
|
925
|
-
'message':
|
|
926
|
-
'weight': attention,
|
|
957
|
+
'message': message
|
|
927
958
|
},
|
|
928
959
|
}
|
|
929
960
|
)
|
|
@@ -935,18 +966,16 @@ class GTConv(GraphConv):
|
|
|
935
966
|
return tensor.update(
|
|
936
967
|
{
|
|
937
968
|
'node': {
|
|
938
|
-
'
|
|
939
|
-
'residual': tensor.node['feature']
|
|
969
|
+
'aggregate': node_feature,
|
|
940
970
|
},
|
|
941
971
|
'edge': {
|
|
942
972
|
'message': None,
|
|
943
|
-
'weight': None,
|
|
944
973
|
}
|
|
945
974
|
}
|
|
946
975
|
)
|
|
947
976
|
|
|
948
977
|
def update(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
|
|
949
|
-
node_feature = tensor.node['
|
|
978
|
+
node_feature = tensor.node['aggregate']
|
|
950
979
|
node_feature = self._feedforward_intermediate_dense(node_feature)
|
|
951
980
|
node_feature = self._feedforward_activation(node_feature)
|
|
952
981
|
node_feature = self._feedforward_output_dense(node_feature)
|
|
@@ -954,6 +983,7 @@ class GTConv(GraphConv):
|
|
|
954
983
|
{
|
|
955
984
|
'node': {
|
|
956
985
|
'feature': node_feature,
|
|
986
|
+
'aggregate': None,
|
|
957
987
|
},
|
|
958
988
|
}
|
|
959
989
|
)
|
|
@@ -978,14 +1008,14 @@ class MPConv(GraphConv):
|
|
|
978
1008
|
units: int = 128,
|
|
979
1009
|
activation: keras.layers.Activation | str | None = None,
|
|
980
1010
|
use_bias: bool = True,
|
|
981
|
-
|
|
1011
|
+
normalize: bool = False,
|
|
982
1012
|
**kwargs
|
|
983
1013
|
) -> None:
|
|
984
1014
|
super().__init__(
|
|
985
1015
|
units=units,
|
|
986
1016
|
activation=activation,
|
|
987
1017
|
use_bias=use_bias,
|
|
988
|
-
|
|
1018
|
+
normalize=normalize,
|
|
989
1019
|
**kwargs
|
|
990
1020
|
)
|
|
991
1021
|
|
|
@@ -1032,28 +1062,28 @@ class MPConv(GraphConv):
|
|
|
1032
1062
|
|
|
1033
1063
|
def aggregate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
|
|
1034
1064
|
aggregate = tensor.aggregate('message', mode='mean')
|
|
1035
|
-
|
|
1065
|
+
feature = tensor.node['feature']
|
|
1036
1066
|
if self.project_input_node_feature:
|
|
1037
|
-
|
|
1067
|
+
feature = self._previous_node_dense(feature)
|
|
1038
1068
|
return tensor.update(
|
|
1039
1069
|
{
|
|
1040
1070
|
'node': {
|
|
1041
|
-
'
|
|
1042
|
-
'
|
|
1071
|
+
'aggregate': aggregate,
|
|
1072
|
+
'feature': feature,
|
|
1043
1073
|
}
|
|
1044
1074
|
}
|
|
1045
1075
|
)
|
|
1046
1076
|
|
|
1047
1077
|
def update(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
|
|
1048
1078
|
updated_node_feature, _ = self.update_fn(
|
|
1049
|
-
inputs=tensor.node['
|
|
1050
|
-
states=tensor.node['
|
|
1079
|
+
inputs=tensor.node['aggregate'],
|
|
1080
|
+
states=tensor.node['feature']
|
|
1051
1081
|
)
|
|
1052
1082
|
return tensor.update(
|
|
1053
1083
|
{
|
|
1054
1084
|
'node': {
|
|
1055
1085
|
'feature': updated_node_feature,
|
|
1056
|
-
'
|
|
1086
|
+
'aggregate': None,
|
|
1057
1087
|
}
|
|
1058
1088
|
}
|
|
1059
1089
|
)
|
|
@@ -1124,12 +1154,13 @@ class GTConv3D(GTConv):
|
|
|
1124
1154
|
attention *= keras.ops.expand_dims(distance, axis=-1)
|
|
1125
1155
|
attention = keras.ops.expand_dims(attention, axis=2)
|
|
1126
1156
|
value = keras.ops.expand_dims(value, axis=1)
|
|
1157
|
+
|
|
1158
|
+
message = ops.edge_weight(value, attention)
|
|
1127
1159
|
|
|
1128
1160
|
return tensor.update(
|
|
1129
1161
|
{
|
|
1130
1162
|
'edge': {
|
|
1131
|
-
'message':
|
|
1132
|
-
'weight': attention,
|
|
1163
|
+
'message': message,
|
|
1133
1164
|
},
|
|
1134
1165
|
}
|
|
1135
1166
|
)
|
|
@@ -1144,12 +1175,10 @@ class GTConv3D(GTConv):
|
|
|
1144
1175
|
return tensor.update(
|
|
1145
1176
|
{
|
|
1146
1177
|
'node': {
|
|
1147
|
-
'
|
|
1148
|
-
'residual': tensor.node['feature']
|
|
1178
|
+
'aggregate': node_feature,
|
|
1149
1179
|
},
|
|
1150
1180
|
'edge': {
|
|
1151
1181
|
'message': None,
|
|
1152
|
-
'weight': None,
|
|
1153
1182
|
}
|
|
1154
1183
|
}
|
|
1155
1184
|
)
|
|
@@ -1204,14 +1233,14 @@ class EGConv3D(GraphConv):
|
|
|
1204
1233
|
units: int = 128,
|
|
1205
1234
|
activation: keras.layers.Activation | str | None = None,
|
|
1206
1235
|
use_bias: bool = True,
|
|
1207
|
-
|
|
1236
|
+
normalize: bool = False,
|
|
1208
1237
|
**kwargs
|
|
1209
1238
|
) -> None:
|
|
1210
1239
|
super().__init__(
|
|
1211
1240
|
units=units,
|
|
1212
1241
|
activation=activation,
|
|
1213
1242
|
use_bias=use_bias,
|
|
1214
|
-
|
|
1243
|
+
normalize=normalize,
|
|
1215
1244
|
**kwargs
|
|
1216
1245
|
)
|
|
1217
1246
|
|
|
@@ -1223,7 +1252,7 @@ class EGConv3D(GraphConv):
|
|
|
1223
1252
|
)
|
|
1224
1253
|
self._has_edge_feature = 'feature' in spec.edge
|
|
1225
1254
|
self.message_fn = self.get_dense(self.units, activation=self._activation)
|
|
1226
|
-
self.dense_position = self.get_dense(1)
|
|
1255
|
+
self.dense_position = self.get_dense(1, use_bias=False, kernel_initializer='zeros')
|
|
1227
1256
|
|
|
1228
1257
|
has_overridden_update = self.__class__.update != EGConv3D.update
|
|
1229
1258
|
if not has_overridden_update:
|
|
@@ -1273,27 +1302,26 @@ class EGConv3D(GraphConv):
|
|
|
1273
1302
|
)
|
|
1274
1303
|
|
|
1275
1304
|
def aggregate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
|
|
1276
|
-
coefficient = keras.ops.bincount(
|
|
1277
|
-
|
|
1278
|
-
|
|
1279
|
-
)
|
|
1280
|
-
coefficient = keras.ops.cast(
|
|
1281
|
-
|
|
1282
|
-
)
|
|
1283
|
-
coefficient = keras.ops.expand_dims(
|
|
1284
|
-
|
|
1285
|
-
)
|
|
1286
|
-
|
|
1287
|
-
updated_coordinate = tensor.aggregate('relative_node_coordinate') * coefficient
|
|
1305
|
+
# coefficient = keras.ops.bincount(
|
|
1306
|
+
# tensor.edge['source'],
|
|
1307
|
+
# minlength=tensor.num_nodes
|
|
1308
|
+
# )
|
|
1309
|
+
# coefficient = keras.ops.cast(
|
|
1310
|
+
# coefficient, tensor.node['coordinate'].dtype
|
|
1311
|
+
# )
|
|
1312
|
+
# coefficient = keras.ops.expand_dims(
|
|
1313
|
+
# keras.ops.divide_no_nan(1, coefficient), axis=1
|
|
1314
|
+
# )
|
|
1315
|
+
|
|
1316
|
+
updated_coordinate = tensor.aggregate('relative_node_coordinate', mode='mean')# * coefficient
|
|
1288
1317
|
updated_coordinate += tensor.node['coordinate']
|
|
1289
1318
|
|
|
1290
1319
|
aggregate = tensor.aggregate('message', mode='mean')
|
|
1291
1320
|
return tensor.update(
|
|
1292
1321
|
{
|
|
1293
1322
|
'node': {
|
|
1294
|
-
'
|
|
1323
|
+
'aggregate': aggregate,
|
|
1295
1324
|
'coordinate': updated_coordinate,
|
|
1296
|
-
'previous_feature': tensor.node['feature'],
|
|
1297
1325
|
},
|
|
1298
1326
|
'edge': {
|
|
1299
1327
|
'message': None,
|
|
@@ -1306,8 +1334,8 @@ class EGConv3D(GraphConv):
|
|
|
1306
1334
|
updated_node_feature = self.update_fn(
|
|
1307
1335
|
keras.ops.concatenate(
|
|
1308
1336
|
[
|
|
1309
|
-
tensor.node['
|
|
1310
|
-
tensor.node['
|
|
1337
|
+
tensor.node['aggregate'],
|
|
1338
|
+
tensor.node['feature']
|
|
1311
1339
|
],
|
|
1312
1340
|
axis=-1
|
|
1313
1341
|
)
|
|
@@ -1317,7 +1345,7 @@ class EGConv3D(GraphConv):
|
|
|
1317
1345
|
{
|
|
1318
1346
|
'node': {
|
|
1319
1347
|
'feature': updated_node_feature,
|
|
1320
|
-
'
|
|
1348
|
+
'aggregate': None,
|
|
1321
1349
|
},
|
|
1322
1350
|
}
|
|
1323
1351
|
)
|
|
@@ -1478,6 +1506,32 @@ class GraphNetwork(GraphLayer):
|
|
|
1478
1506
|
return super().from_config(config)
|
|
1479
1507
|
|
|
1480
1508
|
|
|
1509
|
+
@keras.saving.register_keras_serializable(package='molcraft')
|
|
1510
|
+
class Extraction(GraphLayer):
|
|
1511
|
+
|
|
1512
|
+
def __init__(
|
|
1513
|
+
self,
|
|
1514
|
+
field: str,
|
|
1515
|
+
inner_field: str | None = None,
|
|
1516
|
+
**kwargs
|
|
1517
|
+
) -> None:
|
|
1518
|
+
super().__init__(**kwargs)
|
|
1519
|
+
self.field = field
|
|
1520
|
+
self.inner_field = inner_field
|
|
1521
|
+
|
|
1522
|
+
def propagate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
|
|
1523
|
+
data = dict(getattr(tensor, self.field))
|
|
1524
|
+
if not self.inner_field:
|
|
1525
|
+
return data
|
|
1526
|
+
return data[self.inner_field]
|
|
1527
|
+
|
|
1528
|
+
def get_config(self):
|
|
1529
|
+
config = super().get_config()
|
|
1530
|
+
config['field'] = self.field
|
|
1531
|
+
config['inner_field'] = self.inner_field
|
|
1532
|
+
return config
|
|
1533
|
+
|
|
1534
|
+
|
|
1481
1535
|
@keras.saving.register_keras_serializable(package='molcraft')
|
|
1482
1536
|
class NodeEmbedding(GraphLayer):
|
|
1483
1537
|
|
|
@@ -1489,15 +1543,15 @@ class NodeEmbedding(GraphLayer):
|
|
|
1489
1543
|
def __init__(
|
|
1490
1544
|
self,
|
|
1491
1545
|
dim: int = None,
|
|
1492
|
-
|
|
1493
|
-
embed_context: bool =
|
|
1546
|
+
normalize: bool = False,
|
|
1547
|
+
embed_context: bool = False,
|
|
1494
1548
|
allow_reconstruction: bool = False,
|
|
1495
|
-
allow_masking: bool =
|
|
1549
|
+
allow_masking: bool = False,
|
|
1496
1550
|
**kwargs
|
|
1497
1551
|
) -> None:
|
|
1498
1552
|
super().__init__(**kwargs)
|
|
1499
1553
|
self.dim = dim
|
|
1500
|
-
self.
|
|
1554
|
+
self._normalize = normalize
|
|
1501
1555
|
self._embed_context = embed_context
|
|
1502
1556
|
self._masking_rate = None
|
|
1503
1557
|
self._allow_masking = allow_masking
|
|
@@ -1517,13 +1571,11 @@ class NodeEmbedding(GraphLayer):
|
|
|
1517
1571
|
self._super_feature = self.get_weight(shape=[self.dim], name='super_node_feature')
|
|
1518
1572
|
if self._allow_masking:
|
|
1519
1573
|
self._mask_feature = self.get_weight(shape=[self.dim], name='mask_node_feature')
|
|
1520
|
-
|
|
1521
1574
|
if self._embed_context:
|
|
1522
|
-
context_feature_dim = spec.context['feature'].shape[-1]
|
|
1523
1575
|
self._context_dense = self.get_dense(self.dim)
|
|
1524
1576
|
|
|
1525
|
-
if self.
|
|
1526
|
-
if str(self.
|
|
1577
|
+
if self._normalize:
|
|
1578
|
+
if str(self._normalize).lower().startswith('batch'):
|
|
1527
1579
|
self._norm = keras.layers.BatchNormalization(
|
|
1528
1580
|
name='output_batch_norm'
|
|
1529
1581
|
)
|
|
@@ -1545,48 +1597,25 @@ class NodeEmbedding(GraphLayer):
|
|
|
1545
1597
|
feature = ops.scatter_update(feature, tensor.node['super'], context_feature)
|
|
1546
1598
|
tensor = tensor.update({'context': {'feature': None}})
|
|
1547
1599
|
|
|
1548
|
-
|
|
1549
|
-
|
|
1550
|
-
|
|
1551
|
-
self._masking_rate > 0
|
|
1552
|
-
):
|
|
1553
|
-
random = keras.random.uniform(shape=[tensor.num_nodes])
|
|
1554
|
-
mask = random <= self._masking_rate
|
|
1555
|
-
if self._has_super:
|
|
1556
|
-
mask = keras.ops.logical_and(
|
|
1557
|
-
mask, keras.ops.logical_not(tensor.node['super'])
|
|
1558
|
-
)
|
|
1559
|
-
mask = keras.ops.expand_dims(mask, -1)
|
|
1600
|
+
apply_mask = (self._allow_masking and 'mask' in tensor.node)
|
|
1601
|
+
if apply_mask:
|
|
1602
|
+
mask = keras.ops.expand_dims(tensor.node['mask'], -1)
|
|
1560
1603
|
feature = keras.ops.where(mask, self._mask_feature, feature)
|
|
1561
1604
|
elif self._allow_masking:
|
|
1562
|
-
# Slience warning of 'no gradients for variables'
|
|
1563
1605
|
feature = feature + (self._mask_feature * 0.0)
|
|
1564
1606
|
|
|
1565
|
-
if self.
|
|
1607
|
+
if self._normalize:
|
|
1566
1608
|
feature = self._norm(feature)
|
|
1567
1609
|
|
|
1568
1610
|
if not self._allow_reconstruction:
|
|
1569
1611
|
return tensor.update({'node': {'feature': feature}})
|
|
1570
1612
|
return tensor.update({'node': {'feature': feature, 'target_feature': feature}})
|
|
1571
|
-
|
|
1572
|
-
@property
|
|
1573
|
-
def masking_rate(self):
|
|
1574
|
-
return self._masking_rate
|
|
1575
|
-
|
|
1576
|
-
@masking_rate.setter
|
|
1577
|
-
def masking_rate(self, rate: float):
|
|
1578
|
-
if not self._allow_masking and rate is not None:
|
|
1579
|
-
raise ValueError(
|
|
1580
|
-
f'Cannot set `masking_rate` for layer {self} '
|
|
1581
|
-
'as `allow_masking` was set to `False`.'
|
|
1582
|
-
)
|
|
1583
|
-
self._masking_rate = float(rate)
|
|
1584
1613
|
|
|
1585
1614
|
def get_config(self) -> dict:
|
|
1586
1615
|
config = super().get_config()
|
|
1587
1616
|
config.update({
|
|
1588
1617
|
'dim': self.dim,
|
|
1589
|
-
'
|
|
1618
|
+
'normalize': self._normalize,
|
|
1590
1619
|
'embed_context': self._embed_context,
|
|
1591
1620
|
'allow_masking': self._allow_masking,
|
|
1592
1621
|
'allow_reconstruction': self._allow_reconstruction,
|
|
@@ -1605,13 +1634,13 @@ class EdgeEmbedding(GraphLayer):
|
|
|
1605
1634
|
def __init__(
|
|
1606
1635
|
self,
|
|
1607
1636
|
dim: int = None,
|
|
1608
|
-
|
|
1637
|
+
normalize: bool = False,
|
|
1609
1638
|
allow_masking: bool = True,
|
|
1610
1639
|
**kwargs
|
|
1611
1640
|
) -> None:
|
|
1612
1641
|
super().__init__(**kwargs)
|
|
1613
1642
|
self.dim = dim
|
|
1614
|
-
self.
|
|
1643
|
+
self._normalize = normalize
|
|
1615
1644
|
self._masking_rate = None
|
|
1616
1645
|
self._allow_masking = allow_masking
|
|
1617
1646
|
|
|
@@ -1622,13 +1651,16 @@ class EdgeEmbedding(GraphLayer):
|
|
|
1622
1651
|
self._edge_dense = self.get_dense(self.dim)
|
|
1623
1652
|
|
|
1624
1653
|
self._has_super = 'super' in spec.edge
|
|
1654
|
+
self._has_self_loop = 'self_loop' in spec.edge
|
|
1625
1655
|
if self._has_super:
|
|
1626
1656
|
self._super_feature = self.get_weight(shape=[self.dim], name='super_edge_feature')
|
|
1657
|
+
if self._has_self_loop:
|
|
1658
|
+
self._self_loop_feature = self.get_weight(shape=[self.dim], name='self_loop_edge_feature')
|
|
1627
1659
|
if self._allow_masking:
|
|
1628
1660
|
self._mask_feature = self.get_weight(shape=[self.dim], name='mask_edge_feature')
|
|
1629
1661
|
|
|
1630
|
-
if self.
|
|
1631
|
-
if str(self.
|
|
1662
|
+
if self._normalize:
|
|
1663
|
+
if str(self._normalize).lower().startswith('batch'):
|
|
1632
1664
|
self._norm = keras.layers.BatchNormalization(
|
|
1633
1665
|
name='output_batch_norm'
|
|
1634
1666
|
)
|
|
@@ -1641,10 +1673,13 @@ class EdgeEmbedding(GraphLayer):
|
|
|
1641
1673
|
feature = self._edge_dense(tensor.edge['feature'])
|
|
1642
1674
|
|
|
1643
1675
|
if self._has_super:
|
|
1644
|
-
super_feature = self._super_feature
|
|
1645
1676
|
super_mask = keras.ops.expand_dims(tensor.edge['super'], 1)
|
|
1646
|
-
feature = keras.ops.where(super_mask,
|
|
1677
|
+
feature = keras.ops.where(super_mask, self._super_feature, feature)
|
|
1647
1678
|
|
|
1679
|
+
if self._has_self_loop:
|
|
1680
|
+
self_loop_mask = keras.ops.expand_dims(tensor.edge['self_loop'], 1)
|
|
1681
|
+
feature = keras.ops.where(self_loop_mask, self._self_loop_feature, feature)
|
|
1682
|
+
|
|
1648
1683
|
if (
|
|
1649
1684
|
self._allow_masking and
|
|
1650
1685
|
self._masking_rate is not None and
|
|
@@ -1662,7 +1697,7 @@ class EdgeEmbedding(GraphLayer):
|
|
|
1662
1697
|
# Slience warning of 'no gradients for variables'
|
|
1663
1698
|
feature = feature + (self._mask_feature * 0.0)
|
|
1664
1699
|
|
|
1665
|
-
if self.
|
|
1700
|
+
if self._normalize:
|
|
1666
1701
|
feature = self._norm(feature)
|
|
1667
1702
|
|
|
1668
1703
|
return tensor.update({'edge': {'feature': feature, 'embedding': feature}})
|
|
@@ -1684,7 +1719,7 @@ class EdgeEmbedding(GraphLayer):
|
|
|
1684
1719
|
config = super().get_config()
|
|
1685
1720
|
config.update({
|
|
1686
1721
|
'dim': self.dim,
|
|
1687
|
-
'
|
|
1722
|
+
'normalize': self._normalize,
|
|
1688
1723
|
'allow_masking': self._allow_masking
|
|
1689
1724
|
})
|
|
1690
1725
|
return config
|
|
@@ -1883,6 +1918,56 @@ class GaussianDistance(GraphLayer):
|
|
|
1883
1918
|
return config
|
|
1884
1919
|
|
|
1885
1920
|
|
|
1921
|
+
@keras.saving.register_keras_serializable(package='molcraft')
|
|
1922
|
+
class GaussianParams(keras.layers.Dense):
|
|
1923
|
+
'''Gaussian parameters.
|
|
1924
|
+
|
|
1925
|
+
Computes loc and scale via a dense layer. Should be implemented
|
|
1926
|
+
as the last layer in a model and paired with `losses.GaussianNLL`.
|
|
1927
|
+
|
|
1928
|
+
The loc and scale parameters (resulting from this layer) are concatenated
|
|
1929
|
+
together along the last axis, resulting in a single output tensor.
|
|
1930
|
+
|
|
1931
|
+
Args:
|
|
1932
|
+
events (int):
|
|
1933
|
+
The number of events. If the model makes a single prediction per example,
|
|
1934
|
+
then the number of events should be 1. If the model makes multiple predictions
|
|
1935
|
+
per example, then the number of events should be greater than 1.
|
|
1936
|
+
Default to 1.
|
|
1937
|
+
kwargs:
|
|
1938
|
+
See `keras.layers.Dense` documentation. `activation` will be applied
|
|
1939
|
+
to `loc` only. `scale` is automatically softplus activated.
|
|
1940
|
+
'''
|
|
1941
|
+
def __init__(self, events: int = 1, **kwargs):
|
|
1942
|
+
units = kwargs.pop('units', None)
|
|
1943
|
+
activation = kwargs.pop('activation', None)
|
|
1944
|
+
if units:
|
|
1945
|
+
if units % 2 != 0:
|
|
1946
|
+
raise ValueError(
|
|
1947
|
+
'`units` needs to be divisble by 2 as `units` = 2 x `events`.'
|
|
1948
|
+
)
|
|
1949
|
+
else:
|
|
1950
|
+
units = int(events * 2)
|
|
1951
|
+
super().__init__(units=units, **kwargs)
|
|
1952
|
+
self.events = events
|
|
1953
|
+
self.loc_activation = keras.activations.get(activation)
|
|
1954
|
+
|
|
1955
|
+
def call(self, inputs, **kwargs):
|
|
1956
|
+
loc_and_scale = super().call(inputs, **kwargs)
|
|
1957
|
+
loc = loc_and_scale[..., :self.events]
|
|
1958
|
+
scale = loc_and_scale[..., self.events:]
|
|
1959
|
+
scale = keras.ops.softplus(scale) + keras.backend.epsilon()
|
|
1960
|
+
loc = self.loc_activation(loc)
|
|
1961
|
+
return keras.ops.concatenate([loc, scale], axis=-1)
|
|
1962
|
+
|
|
1963
|
+
def get_config(self):
|
|
1964
|
+
config = super().get_config()
|
|
1965
|
+
config['events'] = self.events
|
|
1966
|
+
config['units'] = None
|
|
1967
|
+
config['activation'] = keras.activations.serialize(self.loc_activation)
|
|
1968
|
+
return config
|
|
1969
|
+
|
|
1970
|
+
|
|
1886
1971
|
def Input(spec: tensors.GraphTensor.Spec) -> dict:
|
|
1887
1972
|
"""Used to specify inputs to model.
|
|
1888
1973
|
|
|
@@ -1915,6 +2000,11 @@ def Input(spec: tensors.GraphTensor.Spec) -> dict:
|
|
|
1915
2000
|
inputs[outer_field] = {}
|
|
1916
2001
|
for inner_field, nested_spec in data.items():
|
|
1917
2002
|
if inner_field in ['label', 'weight']:
|
|
2003
|
+
# Remove context label and weight from the symbolic input
|
|
2004
|
+
# as a functional model is strict for what input can be passed.
|
|
2005
|
+
# We want to be able to pass a graph with or without labels
|
|
2006
|
+
# and sample weights. The __call__ method of the `GraphModel`
|
|
2007
|
+
# temporarily pops label and weight to avoid errors.
|
|
1918
2008
|
if outer_field == 'context':
|
|
1919
2009
|
continue
|
|
1920
2010
|
kwargs = {
|
|
@@ -1941,23 +2031,6 @@ def warn(message: str) -> None:
|
|
|
1941
2031
|
stacklevel=1
|
|
1942
2032
|
)
|
|
1943
2033
|
|
|
1944
|
-
def _match_functional_input(functional_input, inputs):
|
|
1945
|
-
matching_inputs = {}
|
|
1946
|
-
for outer_field, data in functional_input.items():
|
|
1947
|
-
matching_inputs[outer_field] = {}
|
|
1948
|
-
for inner_field, _ in data.items():
|
|
1949
|
-
call_input = inputs[outer_field].pop(inner_field)
|
|
1950
|
-
matching_inputs[outer_field][inner_field] = call_input
|
|
1951
|
-
unmatching_inputs = inputs
|
|
1952
|
-
return matching_inputs, unmatching_inputs
|
|
1953
|
-
|
|
1954
|
-
def _add_left_out_inputs(outputs, inputs):
|
|
1955
|
-
for outer_field, data in inputs.items():
|
|
1956
|
-
for inner_field, value in data.items():
|
|
1957
|
-
if inner_field in ['label', 'weight']:
|
|
1958
|
-
outputs[outer_field][inner_field] = value
|
|
1959
|
-
return outputs
|
|
1960
|
-
|
|
1961
2034
|
def _serialize_spec(spec: tensors.GraphTensor.Spec) -> dict:
|
|
1962
2035
|
serialized_spec = {}
|
|
1963
2036
|
for outer_field, data in spec.__dict__.items():
|