molcraft 0.1.0a5__py3-none-any.whl → 0.1.0a7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- molcraft/__init__.py +3 -2
- molcraft/callbacks.py +60 -0
- molcraft/chem.py +103 -21
- molcraft/conformers.py +1 -5
- molcraft/featurizers.py +20 -14
- molcraft/layers.py +307 -211
- molcraft/losses.py +36 -0
- molcraft/models.py +135 -9
- molcraft/ops.py +12 -2
- molcraft/records.py +32 -31
- molcraft/tensors.py +1 -1
- {molcraft-0.1.0a5.dist-info → molcraft-0.1.0a7.dist-info}/METADATA +4 -17
- molcraft-0.1.0a7.dist-info/RECORD +19 -0
- {molcraft-0.1.0a5.dist-info → molcraft-0.1.0a7.dist-info}/WHEEL +1 -1
- molcraft-0.1.0a5.dist-info/RECORD +0 -18
- {molcraft-0.1.0a5.dist-info → molcraft-0.1.0a7.dist-info}/licenses/LICENSE +0 -0
- {molcraft-0.1.0a5.dist-info → molcraft-0.1.0a7.dist-info}/top_level.txt +0 -0
molcraft/layers.py
CHANGED
@@ -125,23 +125,45 @@ class GraphLayer(keras.layers.Layer):
             return tensors.to_dict(outputs)
         return outputs

-    def __call__(
+    def __call__(
+        self,
+        graph: dict[str, dict[str, tf.Tensor]] | tensors.GraphTensor,
+        **kwargs
+    ) -> tf.Tensor | dict[str, dict[str, tf.Tensor]] | tensors.GraphTensor:
         if not self.built:
-            spec = _spec_from_inputs(
+            spec = _spec_from_inputs(graph)
             self.build(spec)
-
-
-
+
+        is_graph_tensor = isinstance(graph, tensors.GraphTensor)
+        if is_graph_tensor:
+            graph = tensors.to_dict(graph)
+        else:
+            graph = {field: dict(data) for (field, data) in graph.items()}
+
         if isinstance(self, functional.Functional):
-
-
+            # As a functional model is strict for what input can
+            # be passed to it, we need to temporarily pop some of the
+            # input and add it afterwards.
+            label = graph['context'].pop('label', None)
+            weight = graph['context'].pop('weight', None)
+            tf.nest.assert_same_structure(self.input, graph)
+
+        outputs = super().__call__(graph, **kwargs)
+
         if not tensors.is_graph(outputs):
             return outputs
+
+        graph = outputs
         if isinstance(self, functional.Functional):
-
-
-
-
+            if label is not None:
+                graph['context']['label'] = label
+            if weight is not None:
+                graph['context']['weight'] = weight
+
+        if is_graph_tensor:
+            return tensors.from_dict(graph)
+
+        return graph

     def get_build_config(self) -> dict:
         if self._custom_build_config:
@@ -256,10 +278,10 @@ class GraphConv(GraphLayer):
             Default to `None`.
         use_bias (bool):
             Whether bias should be used in dense layers. Default to `True`.
-
+        normalize (bool, str):
             Whether `LayerNormalization` should be applied to the final node feature output.
             To use `BatchNormalization`, specify `batch_norm`. Default to `False`.
-
+        skip_connect (bool, str):
             Whether node feature input should be added to the node feature output.
             If node feature input dim is not equal to `units` (node feature output dim),
             a projection layer will automatically project the residual before adding it
@@ -294,14 +316,14 @@ class GraphConv(GraphLayer):
         units: int = None,
         activation: str | keras.layers.Activation | None = None,
         use_bias: bool = True,
-
-
+        normalize: bool | str = False,
+        skip_connect: bool | str = True,
         **kwargs
     ) -> None:
         super().__init__(use_bias=use_bias, **kwargs)
         self._units = units
-        self.
-        self.
+        self._normalize = normalize
+        self._skip_connect = skip_connect
         self._activation = keras.activations.get(activation)

     def __init_subclass__(cls, **kwargs):
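Note: `GraphConv` and its subclasses now take `normalize` and `skip_connect` keyword arguments (see the constructor diff above). A minimal usage sketch; the layer choice and argument values are illustrative and not taken from this diff:

    from molcraft import layers

    conv = layers.GIConv(units=128, normalize='batch_norm', skip_connect='weighted')
    graph = conv(graph)  # `graph` may be a GraphTensor or a nested dict of tensors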
@@ -319,36 +341,6 @@ class GraphConv(GraphLayer):
     def units(self):
         return self._units

-    def propagate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
-        """Forward pass.
-
-        Invokes `message(graph)`, `aggregate(graph)` and `update(graph)` in sequence.
-
-        Arguments:
-            tensor:
-                A `GraphTensor` instance.
-        """
-        if self._skip_connection:
-            input_node_feature = tensor.node['feature']
-            if self._project_input_node_feature:
-                input_node_feature = self._residual_projection(input_node_feature)
-
-        tensor = self.message(tensor)
-        tensor = self.aggregate(tensor)
-        tensor = self.update(tensor)
-
-        updated_node_feature = tensor.node['feature']
-
-        if self._skip_connection:
-            if self._use_weighted_skip_connection:
-                input_node_feature *= self._skip_connection_weight
-            updated_node_feature += input_node_feature
-
-        if self._normalization:
-            updated_node_feature = self._output_norm(updated_node_feature)
-
-        return tensor.update({'node': {'feature': updated_node_feature}})
-
     def build(self, spec: tensors.GraphTensor.Spec) -> None:
         if not self.units:
             raise ValueError(
@@ -356,11 +348,11 @@ class GraphConv(GraphLayer):
             )
         node_feature_dim = spec.node['feature'].shape[-1]
         self._project_input_node_feature = (
-            self.
+            self._skip_connect and (node_feature_dim != self.units)
         )
         if self._project_input_node_feature:
             warn(
-                '`
+                '`skip_connect` is set to `True`, but found incompatible dim '
                 'between input (node feature dim) and output (`self.units`). '
                 'Automatically applying a projection layer to residual to '
                 'match input and output. '
@@ -369,8 +361,8 @@ class GraphConv(GraphLayer):
                 self.units, name='residual_projection'
             )

-
-        self._use_weighted_skip_connection =
+        skip_connect = str(self._skip_connect).lower()
+        self._use_weighted_skip_connection = skip_connect.startswith('weight')
         if self._use_weighted_skip_connection:
             self._skip_connection_weight = self.add_weight(
                 name='skip_connection_weight',
@@ -379,8 +371,8 @@ class GraphConv(GraphLayer):
                 trainable=True,
             )

-        if self.
-            if str(self.
+        if self._normalize:
+            if str(self._normalize).lower().startswith('batch'):
                 self._output_norm = keras.layers.BatchNormalization(
                     name='output_batch_norm'
                 )
@@ -389,7 +381,7 @@ class GraphConv(GraphLayer):
                     name='output_layer_norm'
                 )

-        self._has_edge_feature = '
+        self._has_edge_feature = 'feature' in spec.edge

         has_overridden_message = self.__class__.message != GraphConv.message
         if not has_overridden_message:
@@ -400,6 +392,50 @@ class GraphConv(GraphLayer):
             self._output_dense = self.get_dense(self.units)
             self._output_activation = self._activation

+    def propagate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
+        """Forward pass.
+
+        Invokes `message(graph)`, `aggregate(graph)` and `update(graph)` in sequence.
+
+        Arguments:
+            tensor:
+                A `GraphTensor` instance.
+        """
+        if self._skip_connect:
+            input_node_feature = tensor.node['feature']
+            if self._project_input_node_feature:
+                input_node_feature = self._residual_projection(input_node_feature)
+
+        message = self.message(tensor)
+        if not isinstance(message, tensors.GraphTensor):
+            message = tensor.update({'edge': {'message': message}})
+        elif not 'message' in message.edge:
+            raise ValueError('Could not find `message` in `edge` output.')
+
+        aggregate = self.aggregate(message)
+        if not isinstance(aggregate, tensors.GraphTensor):
+            aggregate = tensor.update({'node': {'aggregate': aggregate}})
+        elif not 'aggregate' in aggregate.node:
+            raise ValueError('Could not find `aggregate` in `node` output.')
+
+        update = self.update(aggregate)
+        if not isinstance(update, tensors.GraphTensor):
+            update = tensor.update({'node': {'feature': update}})
+        elif not 'feature' in update.node:
+            raise ValueError('Could not find `feature` in `node` output.')
+
+        updated_node_feature = update.node['feature']
+
+        if self._skip_connect:
+            if self._use_weighted_skip_connection:
+                input_node_feature *= self._skip_connection_weight
+            updated_node_feature += input_node_feature
+
+        if self._normalize:
+            updated_node_feature = self._output_norm(updated_node_feature)
+
+        return update.update({'node': {'feature': updated_node_feature}})
+
     def message(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
         """Compute messages.
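Note: the new `propagate` above accepts `message`, `aggregate` and `update` outputs that are either `GraphTensor` instances or plain tensors; plain tensors are wrapped into the 'message', 'aggregate' and 'feature' fields. A minimal sketch of a custom convolution written against that contract; the class, its weights and its method bodies are illustrative and not part of this diff:

    from molcraft import layers, tensors

    class MyConv(layers.GraphConv):

        def build(self, spec: tensors.GraphTensor.Spec) -> None:
            super().build(spec)
            self._dense = self.get_dense(self.units)

        def message(self, tensor):
            # Plain tensor; propagate stores it as tensor.edge['message'].
            return tensor.gather('feature', 'source')

        def aggregate(self, tensor):
            # Plain tensor; propagate stores it as tensor.node['aggregate'].
            return tensor.aggregate('message', mode='mean')

        def update(self, tensor):
            # Plain tensor; propagate stores it as tensor.node['feature'].
            return self._dense(tensor.node['aggregate'])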
@@ -441,8 +477,7 @@ class GraphConv(GraphLayer):
         return tensor.update(
             {
                 'node': {
-                    '
-                    'previous_feature': tensor.node['feature']
+                    'aggregate': aggregate,
                 },
                 'edge': {
                     'message': None
@@ -460,23 +495,20 @@ class GraphConv(GraphLayer):
             A `GraphTensor` instance containing aggregated messages
             (updated node features).
         """
-
-
-
-
-
-
-
-            ],
-            axis=-1
-        )
+        feature = keras.ops.concatenate(
+            [
+                tensor.node['aggregate'],
+                tensor.node['feature']
+            ],
+            axis=-1
+        )
         update = self._output_dense(feature)
         update = self._output_activation(update)
         return tensor.update(
             {
                 'node': {
                     'feature': update,
-                    '
+                    'aggregate': None,
                 }
             }
         )
@@ -486,8 +518,8 @@ class GraphConv(GraphLayer):
         config.update({
             'units': self.units,
             'activation': keras.activations.serialize(self._activation),
-            '
-            '
+            'normalize': self._normalize,
+            'skip_connect': self._skip_connect,
         })
         return config
@@ -530,14 +562,14 @@ class GIConv(GraphConv):
         units: int,
         activation: keras.layers.Activation | str | None = 'relu',
         use_bias: bool = True,
-
+        normalize: bool = False,
         update_edge_feature: bool = True,
         **kwargs,
     ):
         super().__init__(
             units=units,
             activation=activation,
-
+            normalize=normalize,
             use_bias=use_bias,
             **kwargs
         )
@@ -599,7 +631,7 @@ class GIConv(GraphConv):
         return tensor.update(
             {
                 'node': {
-                    '
+                    'aggregate': node_feature,
                 },
                 'edge': {
                     'message': None,
@@ -608,7 +640,7 @@ class GIConv(GraphConv):
         )

     def update(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
-        node_feature = tensor.node['
+        node_feature = tensor.node['aggregate']
         node_feature = self._feedforward_intermediate_dense(node_feature)
         node_feature = self._feedforward_activation(node_feature)
         node_feature = self._feedforward_output_dense(node_feature)
@@ -616,6 +648,7 @@ class GIConv(GraphConv):
             {
                 'node': {
                     'feature': node_feature,
+                    'aggregate': None,
                 }
             }
         )
@@ -667,17 +700,16 @@ class GAConv(GraphConv):
         heads: int = 8,
         activation: keras.layers.Activation | str | None = "relu",
         use_bias: bool = True,
-
+        normalize: bool = False,
         update_edge_feature: bool = True,
         attention_activation: keras.layers.Activation | str | None = "leaky_relu",
         **kwargs,
     ) -> None:
-        kwargs['skip_connection'] = False
         super().__init__(
             units=units,
             activation=activation,
             use_bias=use_bias,
-
+            normalize=normalize,
             **kwargs
         )
         self._heads = heads
@@ -753,11 +785,11 @@ class GAConv(GraphConv):
         )
         node_feature = self._node_dense(tensor.node['feature'])
         message = ops.gather(node_feature, tensor.edge['source'])
+        message = ops.edge_weight(message, attention_score)
         return tensor.update(
             {
                 'edge': {
                     'message': message,
-                    'weight': attention_score,
                     'feature': edge_feature,
                 }
             }
@@ -770,24 +802,24 @@ class GAConv(GraphConv):
         return tensor.update(
             {
                 'node': {
-                    '
+                    'aggregate': node_feature
                 },
                 'edge': {
                     'message': None,
-                    'weight': None,
                 }
             }
         )

     def update(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
-        node_feature = tensor.node['
+        node_feature = tensor.node['aggregate']
         node_feature = self._feedforward_intermediate_dense(node_feature)
         node_feature = self._feedforward_activation(node_feature)
         node_feature = self._feedforward_output_dense(node_feature)
         return tensor.update(
             {
                 'node': {
-                    'feature': node_feature
+                    'feature': node_feature,
+                    'aggregate': None,
                 }
             }
         )
@@ -842,7 +874,7 @@ class GTConv(GraphConv):
         heads: int = 8,
         activation: keras.layers.Activation | str | None = "relu",
         use_bias: bool = True,
-
+        normalize: bool = False,
         attention_dropout: float = 0.0,
         **kwargs,
     ) -> None:
@@ -850,7 +882,7 @@ class GTConv(GraphConv):
             units=units,
             activation=activation,
             use_bias=use_bias,
-
+            normalize=normalize,
             **kwargs
         )
         self._heads = heads
@@ -901,7 +933,6 @@ class GTConv(GraphConv):
                 }
             }
         )
-
         node_feature = tensor.node['feature']

         query = self._query_dense(node_feature)
@@ -918,12 +949,12 @@ class GTConv(GraphConv):
         attention_score += keras.ops.expand_dims(tensor.edge['bias'], axis=1)
         attention = ops.edge_softmax(attention_score, tensor.edge['target'])
         attention = self._softmax_dropout(attention)
+        message = ops.edge_weight(value, attention)

         return tensor.update(
             {
                 'edge': {
-                    'message':
-                    'weight': attention,
+                    'message': message
                 },
             }
         )
@@ -935,18 +966,16 @@ class GTConv(GraphConv):
         return tensor.update(
             {
                 'node': {
-                    '
-                    'residual': tensor.node['feature']
+                    'aggregate': node_feature,
                 },
                 'edge': {
                     'message': None,
-                    'weight': None,
                 }
             }
         )

     def update(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
-        node_feature = tensor.node['
+        node_feature = tensor.node['aggregate']
         node_feature = self._feedforward_intermediate_dense(node_feature)
         node_feature = self._feedforward_activation(node_feature)
         node_feature = self._feedforward_output_dense(node_feature)
@@ -954,6 +983,7 @@ class GTConv(GraphConv):
             {
                 'node': {
                     'feature': node_feature,
+                    'aggregate': None,
                 },
             }
         )
@@ -978,14 +1008,14 @@ class MPConv(GraphConv):
         units: int = 128,
         activation: keras.layers.Activation | str | None = None,
         use_bias: bool = True,
-
+        normalize: bool = False,
         **kwargs
     ) -> None:
         super().__init__(
             units=units,
             activation=activation,
             use_bias=use_bias,
-
+            normalize=normalize,
             **kwargs
         )

@@ -1032,28 +1062,28 @@ class MPConv(GraphConv):

     def aggregate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
         aggregate = tensor.aggregate('message', mode='mean')
-
+        feature = tensor.node['feature']
         if self.project_input_node_feature:
-
+            feature = self._previous_node_dense(feature)
         return tensor.update(
             {
                 'node': {
-                    '
-                    '
+                    'aggregate': aggregate,
+                    'feature': feature,
                 }
             }
         )

     def update(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
         updated_node_feature, _ = self.update_fn(
-            inputs=tensor.node['
-            states=tensor.node['
+            inputs=tensor.node['aggregate'],
+            states=tensor.node['feature']
         )
         return tensor.update(
             {
                 'node': {
                     'feature': updated_node_feature,
-                    '
+                    'aggregate': None,
                 }
             }
         )
@@ -1124,12 +1154,13 @@ class GTConv3D(GTConv):
         attention *= keras.ops.expand_dims(distance, axis=-1)
         attention = keras.ops.expand_dims(attention, axis=2)
         value = keras.ops.expand_dims(value, axis=1)
+
+        message = ops.edge_weight(value, attention)

         return tensor.update(
             {
                 'edge': {
-                    'message':
-                    'weight': attention,
+                    'message': message,
                 },
             }
         )
@@ -1144,12 +1175,10 @@ class GTConv3D(GTConv):
         return tensor.update(
             {
                 'node': {
-                    '
-                    'residual': tensor.node['feature']
+                    'aggregate': node_feature,
                 },
                 'edge': {
                     'message': None,
-                    'weight': None,
                 }
             }
         )
@@ -1202,16 +1231,16 @@ class EGConv3D(GraphConv):
     def __init__(
         self,
         units: int = 128,
-        activation: keras.layers.Activation | str | None =
+        activation: keras.layers.Activation | str | None = 'silu',
         use_bias: bool = True,
-
+        normalize: bool = False,
         **kwargs
     ) -> None:
         super().__init__(
             units=units,
             activation=activation,
             use_bias=use_bias,
-
+            normalize=normalize,
             **kwargs
         )

@@ -1222,31 +1251,52 @@ class EGConv3D(GraphConv):
                 'which is required for Conv3D layers.'
             )
         self._has_edge_feature = 'feature' in spec.edge
-        self.
-
+        self._message_feedforward_intermediate = self.get_dense(
+            self.units, activation=self._activation
+        )
+        self._message_feedforward_final = self.get_dense(
+            self.units, activation=self._activation
+        )
+
+        self._coord_feedforward_intermediate = self.get_dense(
+            self.units, activation=self._activation
+        )
+        self._coord_feedforward_final = self.get_dense(
+            1, use_bias=False, activation='tanh'
+        )

         has_overridden_update = self.__class__.update != EGConv3D.update
         if not has_overridden_update:
-            self.
-
+            self._feedforward_intermediate = self.get_dense(
+                self.units, activation=self._activation
+            )
+            self._feedforward_output = self.get_dense(self.units)

     def message(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
         relative_node_coordinate = keras.ops.subtract(
             tensor.gather('coordinate', 'target'),
             tensor.gather('coordinate', 'source')
         )
-
-            keras.ops.square(
-                relative_node_coordinate
-            ),
+        squared_distance = keras.ops.sum(
+            keras.ops.square(relative_node_coordinate),
             axis=-1,
             keepdims=True
         )
+
+        # For numerical stability (i.e., to prevent NaN losses), this implementation of `EGConv3D`
+        # either needs to apply a `tanh` activation to the output of `self._coord_feedforward_final`,
+        # or normalize `relative_node_cordinate` as follows:
+        #
+        # norm = keras.ops.sqrt(squared_distance) + keras.backend.epsilon()
+        # relative_node_coordinate /= norm
+        #
+        # For now, this implementation does the former.
+
         feature = keras.ops.concatenate(
             [
                 tensor.gather('feature', 'target'),
                 tensor.gather('feature', 'source'),
-
+                squared_distance,
             ],
             axis=-1
         )
@@ -1258,10 +1308,15 @@ class EGConv3D(GraphConv):
             ],
             axis=-1
         )
-        message = self.
+        message = self._message_feedforward_final(
+            self._message_feedforward_intermediate(feature)
+        )
+
         relative_node_coordinate = keras.ops.multiply(
-            relative_node_coordinate,
-            self.
+            relative_node_coordinate,
+            self._coord_feedforward_final(
+                self._coord_feedforward_intermediate(message)
+            )
         )
         return tensor.update(
             {
@@ -1273,27 +1328,26 @@ class EGConv3D(GraphConv):
         )

     def aggregate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
-
-
-
-
-
-
-
-
-
-
+        coordinate = tensor.node['coordinate']
+        coordinate += tensor.aggregate('relative_node_coordinate', mode='mean')
+
+        # Original implementation seems to apply sum aggregation, which does not
+        # seem work well for this implementation of `EGConv3D`, as it causes
+        # large output values and large initial losses. The magnitude of the
+        # aggregated values of a sum aggregation depends on the number of
+        # neighbors, which may be many and may differ from node to node (or
+        # graph to graph). Therefore, a mean mean aggregation is performed
+        # instead:
+        aggregate = tensor.aggregate('message', mode='mean')

-
-
+        # Simply added to silence warning ('no gradients for variables ...')
+        aggregate += (0.0 * keras.ops.sum(coordinate))

-        aggregate = tensor.aggregate('message', mode='mean')
         return tensor.update(
             {
                 'node': {
-                    '
-                    'coordinate':
-                    'previous_feature': tensor.node['feature'],
+                    'aggregate': aggregate,
+                    'coordinate': coordinate,
                 },
                 'edge': {
                     'message': None,
@@ -1303,21 +1357,21 @@ class EGConv3D(GraphConv):
         )

     def update(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
-
-
-            [
-
-
-
-
-
+        feature = keras.ops.concatenate(
+            [
+                tensor.node['aggregate'],
+                tensor.node['feature']
+            ],
+            axis=-1
+        )
+        updated_node_feature = self._feedforward_output(
+            self._feedforward_intermediate(feature)
         )
-        updated_node_feature = self.output_dense(updated_node_feature)
         return tensor.update(
             {
                 'node': {
                     'feature': updated_node_feature,
-                    '
+                    'aggregate': None,
                 },
             }
         )
@@ -1478,6 +1532,32 @@ class GraphNetwork(GraphLayer):
         return super().from_config(config)


+@keras.saving.register_keras_serializable(package='molcraft')
+class Extraction(GraphLayer):
+
+    def __init__(
+        self,
+        field: str,
+        inner_field: str | None = None,
+        **kwargs
+    ) -> None:
+        super().__init__(**kwargs)
+        self.field = field
+        self.inner_field = inner_field
+
+    def propagate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
+        data = dict(getattr(tensor, self.field))
+        if not self.inner_field:
+            return data
+        return data[self.inner_field]
+
+    def get_config(self):
+        config = super().get_config()
+        config['field'] = self.field
+        config['inner_field'] = self.inner_field
+        return config
+
+
 @keras.saving.register_keras_serializable(package='molcraft')
 class NodeEmbedding(GraphLayer):
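Note: `Extraction` is a small new layer that returns a field of the graph rather than an updated graph. A minimal usage sketch; variable names are illustrative:

    node_features = layers.Extraction(field='node', inner_field='feature')(graph)
    node_data = layers.Extraction(field='node')(graph)  # the whole inner dict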
@@ -1489,15 +1569,15 @@ class NodeEmbedding(GraphLayer):
     def __init__(
         self,
         dim: int = None,
-
-        embed_context: bool =
+        normalize: bool = False,
+        embed_context: bool = False,
         allow_reconstruction: bool = False,
-        allow_masking: bool =
+        allow_masking: bool = False,
         **kwargs
     ) -> None:
         super().__init__(**kwargs)
         self.dim = dim
-        self.
+        self._normalize = normalize
         self._embed_context = embed_context
         self._masking_rate = None
         self._allow_masking = allow_masking
@@ -1517,13 +1597,11 @@ class NodeEmbedding(GraphLayer):
             self._super_feature = self.get_weight(shape=[self.dim], name='super_node_feature')
         if self._allow_masking:
             self._mask_feature = self.get_weight(shape=[self.dim], name='mask_node_feature')
-
         if self._embed_context:
-            context_feature_dim = spec.context['feature'].shape[-1]
             self._context_dense = self.get_dense(self.dim)

-        if self.
-            if str(self.
+        if self._normalize:
+            if str(self._normalize).lower().startswith('batch'):
                 self._norm = keras.layers.BatchNormalization(
                     name='output_batch_norm'
                 )
@@ -1545,48 +1623,25 @@ class NodeEmbedding(GraphLayer):
             feature = ops.scatter_update(feature, tensor.node['super'], context_feature)
             tensor = tensor.update({'context': {'feature': None}})

-
-
-
-            self._masking_rate > 0
-        ):
-            random = keras.random.uniform(shape=[tensor.num_nodes])
-            mask = random <= self._masking_rate
-            if self._has_super:
-                mask = keras.ops.logical_and(
-                    mask, keras.ops.logical_not(tensor.node['super'])
-                )
-            mask = keras.ops.expand_dims(mask, -1)
+        apply_mask = (self._allow_masking and 'mask' in tensor.node)
+        if apply_mask:
+            mask = keras.ops.expand_dims(tensor.node['mask'], -1)
             feature = keras.ops.where(mask, self._mask_feature, feature)
         elif self._allow_masking:
-            # Slience warning of 'no gradients for variables'
             feature = feature + (self._mask_feature * 0.0)

-        if self.
+        if self._normalize:
             feature = self._norm(feature)

         if not self._allow_reconstruction:
             return tensor.update({'node': {'feature': feature}})
         return tensor.update({'node': {'feature': feature, 'target_feature': feature}})
-
-    @property
-    def masking_rate(self):
-        return self._masking_rate
-
-    @masking_rate.setter
-    def masking_rate(self, rate: float):
-        if not self._allow_masking and rate is not None:
-            raise ValueError(
-                f'Cannot set `masking_rate` for layer {self} '
-                'as `allow_masking` was set to `False`.'
-            )
-        self._masking_rate = float(rate)

     def get_config(self) -> dict:
         config = super().get_config()
         config.update({
             'dim': self.dim,
-            '
+            'normalize': self._normalize,
             'embed_context': self._embed_context,
             'allow_masking': self._allow_masking,
             'allow_reconstruction': self._allow_reconstruction,
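Note: node masking in `NodeEmbedding.propagate` is now driven by a boolean `mask` field on the nodes (the mask feature is applied where `tensor.node['mask']` is True) instead of an internally sampled `masking_rate`. A minimal sketch of the new behaviour, assuming the incoming graph already carries such a mask (how the mask is produced, e.g. by a featurizer or callback, is not shown in this diff):

    embedding = layers.NodeEmbedding(dim=128, allow_masking=True)
    graph = embedding(graph)  # masked nodes receive the learned mask feature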
@@ -1605,13 +1660,13 @@ class EdgeEmbedding(GraphLayer):
     def __init__(
         self,
         dim: int = None,
-
+        normalize: bool = False,
         allow_masking: bool = True,
         **kwargs
     ) -> None:
         super().__init__(**kwargs)
         self.dim = dim
-        self.
+        self._normalize = normalize
         self._masking_rate = None
         self._allow_masking = allow_masking

@@ -1622,13 +1677,16 @@ class EdgeEmbedding(GraphLayer):
         self._edge_dense = self.get_dense(self.dim)

         self._has_super = 'super' in spec.edge
+        self._has_self_loop = 'self_loop' in spec.edge
         if self._has_super:
             self._super_feature = self.get_weight(shape=[self.dim], name='super_edge_feature')
+        if self._has_self_loop:
+            self._self_loop_feature = self.get_weight(shape=[self.dim], name='self_loop_edge_feature')
         if self._allow_masking:
             self._mask_feature = self.get_weight(shape=[self.dim], name='mask_edge_feature')

-        if self.
-            if str(self.
+        if self._normalize:
+            if str(self._normalize).lower().startswith('batch'):
                 self._norm = keras.layers.BatchNormalization(
                     name='output_batch_norm'
                 )
@@ -1641,10 +1699,13 @@ class EdgeEmbedding(GraphLayer):
         feature = self._edge_dense(tensor.edge['feature'])

         if self._has_super:
-            super_feature = self._super_feature
             super_mask = keras.ops.expand_dims(tensor.edge['super'], 1)
-            feature = keras.ops.where(super_mask,
+            feature = keras.ops.where(super_mask, self._super_feature, feature)

+        if self._has_self_loop:
+            self_loop_mask = keras.ops.expand_dims(tensor.edge['self_loop'], 1)
+            feature = keras.ops.where(self_loop_mask, self._self_loop_feature, feature)
+
         if (
             self._allow_masking and
             self._masking_rate is not None and
@@ -1659,10 +1720,10 @@ class EdgeEmbedding(GraphLayer):
             mask = keras.ops.expand_dims(mask, -1)
             feature = keras.ops.where(mask, self._mask_feature, feature)
         elif self._allow_masking:
-            #
-            feature
+            # Simply added to silence warning ('no gradients for variables ...')
+            feature += (0.0 * self._mask_feature)

-        if self.
+        if self._normalize:
             feature = self._norm(feature)

         return tensor.update({'edge': {'feature': feature, 'embedding': feature}})
@@ -1684,7 +1745,7 @@ class EdgeEmbedding(GraphLayer):
         config = super().get_config()
         config.update({
             'dim': self.dim,
-            '
+            'normalize': self._normalize,
             'allow_masking': self._allow_masking
         })
         return config
@@ -1883,6 +1944,56 @@ class GaussianDistance(GraphLayer):
         return config


+@keras.saving.register_keras_serializable(package='molcraft')
+class GaussianParams(keras.layers.Dense):
+    '''Gaussian parameters.
+
+    Computes loc and scale via a dense layer. Should be implemented
+    as the last layer in a model and paired with `losses.GaussianNLL`.
+
+    The loc and scale parameters (resulting from this layer) are concatenated
+    together along the last axis, resulting in a single output tensor.
+
+    Args:
+        events (int):
+            The number of events. If the model makes a single prediction per example,
+            then the number of events should be 1. If the model makes multiple predictions
+            per example, then the number of events should be greater than 1.
+            Default to 1.
+        kwargs:
+            See `keras.layers.Dense` documentation. `activation` will be applied
+            to `loc` only. `scale` is automatically softplus activated.
+    '''
+    def __init__(self, events: int = 1, **kwargs):
+        units = kwargs.pop('units', None)
+        activation = kwargs.pop('activation', None)
+        if units:
+            if units % 2 != 0:
+                raise ValueError(
+                    '`units` needs to be divisble by 2 as `units` = 2 x `events`.'
+                )
+        else:
+            units = int(events * 2)
+        super().__init__(units=units, **kwargs)
+        self.events = events
+        self.loc_activation = keras.activations.get(activation)
+
+    def call(self, inputs, **kwargs):
+        loc_and_scale = super().call(inputs, **kwargs)
+        loc = loc_and_scale[..., :self.events]
+        scale = loc_and_scale[..., self.events:]
+        scale = keras.ops.softplus(scale) + keras.backend.epsilon()
+        loc = self.loc_activation(loc)
+        return keras.ops.concatenate([loc, scale], axis=-1)
+
+    def get_config(self):
+        config = super().get_config()
+        config['events'] = self.events
+        config['units'] = None
+        config['activation'] = keras.activations.serialize(self.loc_activation)
+        return config
+
+
 def Input(spec: tensors.GraphTensor.Spec) -> dict:
     """Used to specify inputs to model.
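Note: `GaussianParams` is meant to be the final layer of a model and paired with the `GaussianNLL` loss added in molcraft/losses.py. A minimal sketch of that pairing; the surrounding model wiring and compile arguments are illustrative, not taken from this diff:

    from molcraft import layers, losses

    outputs = layers.GaussianParams(events=1)(hidden)  # concatenated [loc, scale]
    model = keras.Model(inputs, outputs)
    model.compile(optimizer='adam', loss=losses.GaussianNLL())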
@@ -1914,9 +2025,11 @@ def Input(spec: tensors.GraphTensor.Spec) -> dict:
     for outer_field, data in spec.__dict__.items():
         inputs[outer_field] = {}
         for inner_field, nested_spec in data.items():
-            if inner_field in ['label', 'weight']:
-
-
+            if outer_field == 'context' and inner_field in ['label', 'weight']:
+                # Remove context label and weight from the symbolic input
+                # as a functional model is strict for what input can be passed.
+                # (We want to train and predict with the model.)
+                continue
             kwargs = {
                 'shape': nested_spec.shape[1:],
                 'dtype': nested_spec.dtype,
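Note: `Input(spec)` returns a nested dict of symbolic `keras.Input` tensors keyed like the spec, now omitting the context `label` and `weight` entries so the resulting functional model can be used for both training and prediction. A minimal sketch, assuming `spec` is obtained elsewhere (for example from a featurized dataset):

    inputs = layers.Input(spec)
    # e.g. inputs['node']['feature'], inputs['edge']['source'], ...
    # inputs['context'] no longer contains 'label' or 'weight'.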
@@ -1941,23 +2054,6 @@ def warn(message: str) -> None:
         stacklevel=1
     )

-def _match_functional_input(functional_input, inputs):
-    matching_inputs = {}
-    for outer_field, data in functional_input.items():
-        matching_inputs[outer_field] = {}
-        for inner_field, _ in data.items():
-            call_input = inputs[outer_field].pop(inner_field)
-            matching_inputs[outer_field][inner_field] = call_input
-    unmatching_inputs = inputs
-    return matching_inputs, unmatching_inputs
-
-def _add_left_out_inputs(outputs, inputs):
-    for outer_field, data in inputs.items():
-        for inner_field, value in data.items():
-            if inner_field in ['label', 'weight']:
-                outputs[outer_field][inner_field] = value
-    return outputs
-
 def _serialize_spec(spec: tensors.GraphTensor.Spec) -> dict:
     serialized_spec = {}
     for outer_field, data in spec.__dict__.items():
|