molcraft 0.1.0rc9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of molcraft might be problematic.

molcraft/losses.py ADDED
@@ -0,0 +1,37 @@
+ import warnings
+ import keras
+ import numpy as np
+
+
+ @keras.saving.register_keras_serializable(package='molcraft')
+ class GaussianNegativeLogLikelihood(keras.losses.Loss):
+
+     def __init__(
+         self,
+         events: int = 1,
+         name="gaussian_nll",
+         **kwargs
+     ):
+         super().__init__(name=name, **kwargs)
+         self.events = events
+
+     def call(self, y_true, y_pred):
+         mean = y_pred[..., :self.events]   # first `events` columns: means
+         scale = y_pred[..., self.events:]  # remaining columns: scales
+         variance = keras.ops.square(scale)
+         expected_rank = len(keras.ops.shape(mean))
+         current_rank = len(keras.ops.shape(y_true))
+         for _ in range(expected_rank - current_rank):
+             y_true = keras.ops.expand_dims(y_true, axis=-1)  # match ranks
+         return keras.ops.mean(
+             0.5 * keras.ops.log(2.0 * np.pi * variance) +
+             0.5 * keras.ops.square(y_true - mean) / variance
+         )
+
+     def get_config(self):
+         config = super().get_config()
+         config['events'] = self.events
+         return config
+
+
+ GaussianNLL = GaussianNegativeLogLikelihood
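
A minimal usage sketch for this loss, assuming a regression head that emits 2 * `events` units (the first `events` read as means, the rest as scales); constraining the scale columns to be positive (e.g. with a softplus) avoids a diverging log-variance term:

>>> import keras
>>> from molcraft import losses
>>>
>>> model = keras.Sequential([
...     keras.layers.Dense(32, activation='relu'),
...     keras.layers.Dense(2),  # events=1 -> one mean, one scale
... ])
>>> model.compile(optimizer='adam', loss=losses.GaussianNLL(events=1))
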
molcraft/models.py ADDED
@@ -0,0 +1,623 @@
+ import warnings
+ import typing
+ import keras
+ import numpy as np
+ import tensorflow as tf
+ from pathlib import Path
+ from keras.src.models import functional
+
+ from molcraft import layers
+ from molcraft import tensors
+ from molcraft import ops
+
+
+ @keras.saving.register_keras_serializable(package="molcraft")
+ class GraphModel(layers.GraphLayer, keras.models.Model):
+
+     """A graph model.
+
+     Currently, the `GraphModel` only supports `GraphTensor` input.
+
+     Build a subclassed GraphModel:
+
+     >>> import molcraft
+     >>> import keras
+     >>>
+     >>> featurizer = molcraft.featurizers.MolGraphFeaturizer()
+     >>> graph = featurizer([('N[C@@H](C)C(=O)O', 1.0), ('N[C@@H](CS)C(=O)O', 2.0)])
+     >>>
+     >>> @keras.saving.register_keras_serializable()
+     ... class GraphNeuralNetwork(molcraft.models.GraphModel):
+     ...     def __init__(self, units, **kwargs):
+     ...         super().__init__(**kwargs)
+     ...         self.units = units
+     ...         self.node_embedding = molcraft.layers.NodeEmbedding(self.units)
+     ...         self.edge_embedding = molcraft.layers.EdgeEmbedding(self.units)
+     ...         self.conv_1 = molcraft.layers.GraphConv(self.units)
+     ...         self.conv_2 = molcraft.layers.GraphConv(self.units)
+     ...         self.readout = molcraft.layers.Readout('mean')
+     ...         self.dense = keras.layers.Dense(1)
+     ...     def propagate(self, graph):
+     ...         x = self.edge_embedding(self.node_embedding(graph))
+     ...         x = self.conv_2(self.conv_1(x))
+     ...         return self.dense(self.readout(x))
+     ...     def get_config(self):
+     ...         config = super().get_config()
+     ...         config['units'] = self.units
+     ...         return config
+     >>>
+     >>> model = GraphNeuralNetwork(128)
+     >>> model.compile(
+     ...     optimizer=keras.optimizers.Adam(1e-3),
+     ...     loss=keras.losses.MeanSquaredError(),
+     ...     metrics=[keras.metrics.MeanAbsolutePercentageError(name='mape')]
+     ... )
+     >>> model.fit(graph, epochs=10)
+     >>> mse, mape = model.evaluate(graph)
+     >>> preds = model.predict(graph)
+
+     Build a functional GraphModel:
+
+     >>> import molcraft
+     >>> import keras
+     >>>
+     >>> featurizer = molcraft.featurizers.MolGraphFeaturizer()
+     >>> graph = featurizer([('N[C@@H](C)C(=O)O', 1.0), ('N[C@@H](CS)C(=O)O', 2.0)])
+     >>>
+     >>> inputs = molcraft.layers.Input(graph.spec)
+     >>> x = molcraft.layers.NodeEmbedding(128)(inputs)
+     >>> x = molcraft.layers.EdgeEmbedding(128)(x)
+     >>> x = molcraft.layers.GraphConv(128)(x)
+     >>> x = molcraft.layers.GraphConv(128)(x)
+     >>> x = molcraft.layers.Readout('mean')(x)
+     >>> outputs = keras.layers.Dense(1)(x)
+     >>> model = molcraft.models.GraphModel(inputs, outputs)
+     >>> model.compile(
+     ...     optimizer=keras.optimizers.Adam(1e-3),
+     ...     loss=keras.losses.MeanSquaredError(),
+     ...     metrics=[keras.metrics.MeanAbsolutePercentageError(name='mape')]
+     ... )
+     >>> model.fit(graph, epochs=10)
+     >>> mse, mape = model.evaluate(graph)
+     >>> preds = model.predict(graph)
+
+     Build a GraphModel using `from_layers`:
+
+     >>> import molcraft
+     >>> import keras
+     >>>
+     >>> featurizer = molcraft.featurizers.MolGraphFeaturizer()
+     >>> graph = featurizer([('N[C@@H](C)C(=O)O', 1.0), ('N[C@@H](CS)C(=O)O', 2.0)])
+     >>>
+     >>> model = molcraft.models.GraphModel.from_layers([
+     ...     molcraft.layers.Input(graph.spec),
+     ...     molcraft.layers.NodeEmbedding(128),
+     ...     molcraft.layers.EdgeEmbedding(128),
+     ...     molcraft.layers.GraphConv(128),
+     ...     molcraft.layers.GraphConv(128),
+     ...     molcraft.layers.Readout('mean'),
+     ...     keras.layers.Dense(1)
+     ... ])
+     >>> model.compile(
+     ...     optimizer=keras.optimizers.Adam(1e-3),
+     ...     loss=keras.losses.MeanSquaredError(),
+     ...     metrics=[keras.metrics.MeanAbsolutePercentageError(name='mape')]
+     ... )
+     >>> model.fit(graph, epochs=10)
+     >>> mse, mape = model.evaluate(graph)
+     >>> preds = model.predict(graph)
+
+     """
+
+     def __new__(cls, *args, **kwargs):
+         if _functional_init_arguments(args, kwargs) and cls == GraphModel:
+             return FunctionalGraphModel(*args, **kwargs)
+         return super().__new__(cls)
+
+     def __init__(self, *args, **kwargs):
+         self._model_layers = kwargs.pop('model_layers', None)
+         super().__init__(*args, **kwargs)
+         self.jit_compile = False
+
+     @classmethod
+     def from_layers(cls, graph_layers: list, **kwargs):
+         """Creates a graph model from a list of graph layers.
+
+         If `molcraft.layers.Input(spec)` is supplied as the first element,
+         the model is both created and built, as a functional model;
+         otherwise, the layers are stored and applied sequentially when
+         the model is first called. `molcraft.layers.Input` is a function
+         which returns a nested structure of graph components (Keras
+         tensors) based on `spec`.
+
+         Args:
+             graph_layers:
+                 A list of `GraphLayer` instances (a nested list is wrapped
+                 in a `GraphNetwork`), optionally preceded by the dictionary
+                 of Keras tensors produced by `molcraft.layers.Input(spec)`.
+         """
+         if not tensors.is_graph(graph_layers[0]):
+             return cls(model_layers=graph_layers, **kwargs)
+         elif cls != GraphModel:
+             return cls(model_layers=graph_layers[1:], **kwargs)
+         inputs: dict = graph_layers.pop(0)
+         x = inputs
+         for layer in graph_layers:
+             if isinstance(layer, list):
+                 layer = layers.GraphNetwork(layer)
+             x = layer(x)
+         outputs = x
+         return cls(inputs=inputs, outputs=outputs, **kwargs)
+
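
A sketch of the nested-list behavior of `from_layers` (layer sizes are arbitrary, with `graph` as in the class docstring above): an inner list is wrapped in a `molcraft.layers.GraphNetwork` before being applied.

>>> model = molcraft.models.GraphModel.from_layers([
...     molcraft.layers.Input(graph.spec),
...     molcraft.layers.NodeEmbedding(128),
...     molcraft.layers.EdgeEmbedding(128),
...     [molcraft.layers.GraphConv(128), molcraft.layers.GraphConv(128)],  # -> GraphNetwork
...     molcraft.layers.Readout('mean'),
...     keras.layers.Dense(1),
... ])
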
+     def propagate(self, graph: tensors.GraphTensor) -> tensors.GraphTensor:
+         if self._model_layers is None:
+             return super().propagate(graph)
+         for layer in self._model_layers:
+             graph = layer(graph)
+         return graph
+
+     def get_config(self):
+         """Obtain model config."""
+         config = super().get_config()
+         if hasattr(self, '_model_layers') and self._model_layers is not None:
+             config['model_layers'] = [
+                 keras.saving.serialize_keras_object(l)
+                 for l in self._model_layers
+             ]
+         return config
+
+     @classmethod
+     def from_config(cls, config: dict):
+         """Obtain model from model config."""
+         if 'model_layers' in config:
+             config['model_layers'] = [
+                 keras.saving.deserialize_keras_object(l)
+                 for l in config['model_layers']
+             ]
+         return super().from_config(config)
+
+     def compile(
+         self,
+         optimizer: keras.optimizers.Optimizer | str | None = 'rmsprop',
+         loss: keras.losses.Loss | str | None = None,
+         loss_weights: dict[str, float] = None,
+         metrics: list[keras.metrics.Metric] = None,
+         weighted_metrics: list[keras.metrics.Metric] | None = None,
+         run_eagerly: bool = False,
+         steps_per_execution: int = 1,
+         jit_compile: str | bool = False,
+         auto_scale_loss: bool = True,
+         **kwargs
+     ) -> None:
+         """Compiles the model.
+
+         Args:
+             optimizer:
+                 The optimizer to be used (a `keras.optimizers.Optimizer` subclass).
+             loss:
+                 The loss function to be used (a `keras.losses.Loss` subclass).
+             metrics:
+                 A list of metrics to be used during training (`fit`) and evaluation
+                 (`evaluate`). Should be `keras.metrics.Metric` subclasses.
+             kwargs:
+                 See `Model.compile` in the Keras documentation;
+                 not all arguments apply here.
+         """
+         super().compile(
+             optimizer=optimizer,
+             loss=loss,
+             loss_weights=loss_weights,
+             metrics=metrics,
+             weighted_metrics=weighted_metrics,
+             run_eagerly=run_eagerly,
+             steps_per_execution=steps_per_execution,
+             jit_compile=jit_compile,
+             auto_scale_loss=auto_scale_loss,
+             **kwargs
+         )
+
+     def fit(self, x: tensors.GraphTensor | tf.data.Dataset, **kwargs):
+         """Fits the model.
+
+         Args:
+             x:
+                 A `GraphTensor` instance or a `tf.data.Dataset` constructed from
+                 a `GraphTensor` instance. Unlike for a typical Keras model,
+                 the label (typically denoted `y`) and the sample weight
+                 (typically denoted `sample_weight`) should be encoded in the
+                 context of the `GraphTensor`, as `label` and `sample_weight`.
+             validation_data:
+                 A `GraphTensor` instance or a `tf.data.Dataset` constructed from
+                 a `GraphTensor` instance. As for `x`, the label and the
+                 sample weight should be encoded in the context of the
+                 `GraphTensor` instance, as `label` and `sample_weight`
+                 respectively.
+             validation_split:
+                 The fraction of training data to be used as validation data.
+                 Only works if a `GraphTensor` instance is passed as `x`.
+             batch_size:
+                 Number of samples per batch of computation.
+             epochs:
+                 Number of iterations over the entire dataset.
+             callbacks:
+                 A list of callbacks to apply during training.
+             kwargs:
+                 See `Model.fit` in the Keras documentation;
+                 not all arguments apply here.
+         """
+         batch_size = kwargs.pop('batch_size', 32)
+         x_val = kwargs.pop('validation_data', None)
+         val_split = kwargs.pop('validation_split', None)
+         if x_val is not None and isinstance(x_val, tensors.GraphTensor):
+             x_val = _make_dataset(x_val, batch_size)
+         if isinstance(x, tensors.GraphTensor):
+             if val_split:
+                 val_size = int(val_split * x.num_subgraphs)
+                 x_val = _make_dataset(x[-val_size:], batch_size)
+                 x = x[:-val_size]
+             x = _make_dataset(x, batch_size, shuffle=kwargs.get('shuffle', True))
+         return super().fit(x, validation_data=x_val, **kwargs)
+
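
A sketch of the `GraphTensor` path through `fit`: no `y` is passed, since labels (and optionally sample weights) travel inside the tensor, and `validation_split` holds out the trailing fraction of subgraphs before batching.

>>> history = model.fit(
...     graph,                 # labels live in graph.context['label']
...     validation_split=0.1,  # last 10% of subgraphs held out
...     batch_size=64,
...     epochs=10,
... )
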
+     def evaluate(self, x: tensors.GraphTensor | tf.data.Dataset, **kwargs):
+         """Evaluates the model.
+
+         Args:
+             x:
+                 A `GraphTensor` instance or a `tf.data.Dataset` constructed from
+                 a `GraphTensor` instance. Unlike for a typical Keras model,
+                 the label (typically denoted `y`) and the sample weight
+                 (typically denoted `sample_weight`) should be encoded in the
+                 context of the `GraphTensor`, as `label` and `sample_weight`.
+             batch_size:
+                 Number of samples per batch of computation.
+             kwargs:
+                 See `Model.evaluate` in the Keras documentation;
+                 not all arguments apply here.
+         """
+         batch_size = kwargs.pop('batch_size', 32)
+         if isinstance(x, tensors.GraphTensor):
+             x = _make_dataset(x, batch_size)
+         metric_results = super().evaluate(x, **kwargs)
+         return tf.nest.map_structure(lambda value: float(value), metric_results)
+
+     def predict(self, x: tensors.GraphTensor | tf.data.Dataset, **kwargs):
+         """Makes predictions with the model.
+
+         Args:
+             x:
+                 A `GraphTensor` instance or a `tf.data.Dataset` constructed from
+                 a `GraphTensor` instance.
+             batch_size:
+                 Number of samples per batch of computation.
+             kwargs:
+                 See `Model.predict` in the Keras documentation;
+                 not all arguments apply here.
+         """
+         batch_size = kwargs.pop('batch_size', 32)
+         if isinstance(x, tensors.GraphTensor):
+             x = _make_dataset(x, batch_size)
+         output = super().predict(x, **kwargs)
+         if tensors.is_graph(output):
+             return tensors.from_dict(output).flatten()
+         return output
+
+     def get_compile_config(self) -> dict | None:
+         config = super().get_compile_config()
+         if config is None:
+             return None
+         return config
+
+     def compile_from_config(self, config: dict | None) -> None:
+         if config is None:
+             return
+         config = keras.utils.deserialize_keras_object(config)
+         self.compile(**config)
+         if hasattr(self, 'optimizer') and self.built:
+             self.optimizer.build(self.trainable_variables)
+
+     def save(
+         self,
+         filepath: str | Path,
+         *args,
+         **kwargs
+     ) -> None:
+         """Saves an entire model.
+
+         Args:
+             filepath:
+                 A string with the path to the model file (requires `.keras` suffix).
+         """
+         if not self.built:
+             raise ValueError('Cannot save model as it has not been built yet.')
+         super().save(filepath, *args, **kwargs)
+
+     @staticmethod
+     def load(
+         filepath: str | Path,
+         *args,
+         **kwargs
+     ) -> keras.Model:
+         """Loads an entire model (a `staticmethod`).
+
+         Args:
+             filepath:
+                 A string with the path to the model file (requires `.keras` suffix).
+         """
+         return keras.models.load_model(filepath, *args, **kwargs)
+
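
A sketch of the save/load round trip (file names are arbitrary): `save` requires a built model and a `.keras` suffix, while the weight files below require a `.weights.h5` suffix.

>>> model.save('gnn.keras')
>>> restored = molcraft.models.GraphModel.load('gnn.keras')
>>> model.save_weights('checkpoints/gnn.weights.h5')
>>> restored.load_weights('checkpoints/gnn.weights.h5')
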
+     def save_weights(self, filepath, *args, **kwargs):
+         """Saves the weights of the model.
+
+         Args:
+             filepath:
+                 A string with the path to the file (requires `.weights.h5` suffix).
+         """
+         path = Path(filepath).parent
+         path.mkdir(parents=True, exist_ok=True)
+         return super().save_weights(filepath, *args, **kwargs)
+
+     def load_weights(self, filepath, *args, **kwargs):
+         """Loads the weights from a file saved via `save_weights()`.
+
+         Args:
+             filepath:
+                 A string with the path to the file (requires `.weights.h5` suffix).
+         """
+         super().load_weights(filepath, *args, **kwargs)
+
+     def embedding(self, layer_name: str = None) -> 'FunctionalGraphModel':
+         model = self
+         if not isinstance(model, FunctionalGraphModel):
+             raise ValueError(
+                 'Currently, to extract the embedding part of the model, '
+                 'it needs to be a `FunctionalGraphModel`. '
+             )
+         inputs, outputs = model.input, None
+         if not layer_name:
+             for layer in model.layers:
+                 if isinstance(layer, layers.Readout):
+                     outputs = layer.output  # the last `Readout` layer wins
+         else:
+             layer = model.get_layer(layer_name)
+             outputs = (
+                 layer.output if isinstance(layer, keras.layers.Layer) else None
+             )
+         if outputs is None:
+             raise ValueError(
+                 f'Could not find layer `{layer_name}`, or '
+                 f'`{layer_name}` is not a `keras.layers.Layer`.'
+             )
+         return self.__class__(inputs, outputs, name=f'{self.name}_embedding')
+
+     def backbone(self) -> 'FunctionalGraphModel':
+         if not isinstance(self, FunctionalGraphModel):
+             raise ValueError(
+                 'Currently, to extract the backbone part of the model, '
+                 'it needs to be a `FunctionalGraphModel`, with a `Readout` '
+                 'layer dividing the backbone and the head part of the model.'
+             )
+         inputs = self.input
+         outputs = None
+         for layer in self.layers:
+             if isinstance(layer, layers.Readout):
+                 outputs = layer.output
+         if outputs is None:
+             raise ValueError(
+                 'Could not extract output. `Readout` layer not found.'
+             )
+         return self.__class__(inputs, outputs, name=f'{self.name}_backbone')
+
+     def head(self) -> functional.Functional:
+         if not isinstance(self, FunctionalGraphModel):
+             raise ValueError(
+                 'Currently, to extract the head part of the model, '
+                 'it needs to be a `FunctionalGraphModel`, with a `Readout` '
+                 'layer dividing the backbone and the head part of the model.'
+             )
+         inputs = None
+         for layer in self.layers:
+             if isinstance(layer, layers.Readout):
+                 inputs = layer.output
+         if inputs is None:
+             raise ValueError(
+                 'Could not extract input. `Readout` layer not found.'
+             )
+         outputs = layer.output  # `layer` is the model's final layer here
+         return keras.models.Model(inputs, outputs, name=f'{self.name}_head')
+
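
A sketch of the model-surgery helpers, assuming a functional model with a `Readout` layer (as built in the class docstring): `backbone` ends at the `Readout` output, `head` covers everything after it, and `embedding` defaults to the same cut as `backbone`.

>>> backbone = model.backbone()    # inputs -> Readout output
>>> head = model.head()            # Readout output -> final output
>>> embedding = model.embedding()  # Readout output by default
>>> graph_embeddings = embedding.predict(graph)
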
+     def train_step(self, tensor: tensors.GraphTensor) -> dict[str, float]:
+         with tf.GradientTape() as tape:
+             output = self(tensor, training=True)
+             y, y_pred, sample_weight = _get_loss_args(tensor, output)
+             loss = self.compute_loss(tensor, y, y_pred, sample_weight)
+             loss = self.optimizer.scale_loss(loss)
+         trainable_weights = self.trainable_weights
+         gradients = tape.gradient(loss, trainable_weights)
+         self.optimizer.apply_gradients(zip(gradients, trainable_weights))
+         return self.compute_metrics(tensor, y, y_pred, sample_weight)
+
+     def test_step(self, tensor: tensors.GraphTensor) -> dict[str, float]:
+         output = self(tensor, training=False)
+         y, y_pred, sample_weight = _get_loss_args(tensor, output)
+         return self.compute_metrics(tensor, y, y_pred, sample_weight)
+
+     def predict_step(self, tensor: tensors.GraphTensor) -> np.ndarray:
+         output = self(tensor, training=False)
+         if tensors.is_graph(output):
+             if not isinstance(output, tensors.GraphTensor):
+                 output = tensors.from_dict(output)
+             output = tensors.to_dict(output.unflatten())
+         return output
+
+     def compute_loss(self, x, y, y_pred, sample_weight=None):
+         return super().compute_loss(x, y, y_pred, sample_weight)
+
+     def compute_metrics(self, x, y, y_pred, sample_weight=None) -> dict[str, float]:
+         loss = self.compute_loss(x, y, y_pred, sample_weight)
+         metric_results = {}
+         for metric in self.metrics:
+             if metric.name == "loss":
+                 metric.update_state(loss)
+                 metric_results[metric.name] = metric.result()
+             else:
+                 metric.update_state(y, y_pred, sample_weight=sample_weight)
+                 metric_results.update(metric.result())  # compiled metrics return a dict
+         return metric_results
+
+
+ @keras.saving.register_keras_serializable(package="molcraft")
+ class FunctionalGraphModel(functional.Functional, GraphModel):
+
+     @property
+     def layers(self):
+         return [
+             l for l in super().layers if not isinstance(l, keras.layers.InputLayer)
+         ]
+
+
+ def save_model(model: GraphModel, filepath: str | Path, *args, **kwargs) -> None:
+     if not model.built:
+         raise ValueError(
+             'Model and its layers have not yet been (fully) built. '
+             'Build the model before saving it: `model.build(graph_spec)` '
+             'or `model(graph)`.'
+         )
+     keras.models.save_model(model, filepath, *args, **kwargs)
+
+ def load_model(filepath: str | Path, inputs=None, *args, **kwargs) -> GraphModel:
+     return keras.models.load_model(filepath, *args, **kwargs)  # `inputs` is currently unused
+
+ def create(
+     *layers: keras.layers.Layer | list,
+     **kwargs
+ ) -> GraphModel:
+     if isinstance(layers[0], list):
+         layers = layers[0]
+     return GraphModel.from_layers(
+         list(layers), **kwargs
+     )
+
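
`create` is a thin convenience wrapper around `GraphModel.from_layers`, accepting the layers either as positional arguments or as a single list:

>>> model = molcraft.models.create(
...     molcraft.layers.Input(graph.spec),
...     molcraft.layers.NodeEmbedding(128),
...     molcraft.layers.EdgeEmbedding(128),
...     molcraft.layers.GraphConv(128),
...     molcraft.layers.Readout('mean'),
...     keras.layers.Dense(1),
... )
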
+ def interpret(
+     model: GraphModel,
+     graph_tensor: tensors.GraphTensor,
+ ) -> tensors.GraphTensor:
+     x = graph_tensor
+     if tensors.is_ragged(x):
+         x = x.flatten()
+     graph_indicator = x.graph_indicator
+     y_true = x.context.get('label')
+     features = []
+     with tf.GradientTape(watch_accessed_variables=False) as tape:
+         for layer in model.layers:
+             if isinstance(layer, keras.layers.InputLayer):
+                 continue
+             if isinstance(layer, layers.GraphNetwork):
+                 x, taped_features = layer.tape_propagate(x, tape, training=False)
+                 features.extend(taped_features)
+             else:
+                 if (
+                     isinstance(layer, layers.GraphConv) and
+                     isinstance(x, tensors.GraphTensor)
+                 ):
+                     tape.watch(x.node['feature'])
+                     features.append(x.node['feature'])
+                 x = layer(x, training=False)
+         y_pred = x
+         if y_true is not None and len(y_true.shape) > 1:
+             target = tf.gather_nd(y_pred, tf.where(y_true != 0))
+         else:
+             target = y_pred
+     gradients = tape.gradient(target, features)
+     features = keras.ops.concatenate(features, axis=-1)
+     gradients = keras.ops.concatenate(gradients, axis=-1)
+     alpha = ops.segment_mean(gradients, graph_indicator)  # mean gradient per graph
+     alpha = ops.gather(alpha, graph_indicator)            # broadcast back to nodes
+     maps = keras.ops.where(gradients != 0, alpha * features, gradients)
+     maps = keras.ops.sum(maps, axis=-1)                   # one score per node
+     return graph_tensor.update(
+         {
+             'node': {
+                 'saliency': maps
+             }
+         }
+     )
+
+ def saliency(
+     model: GraphModel,
+     graph_tensor: tensors.GraphTensor,
+ ) -> tensors.GraphTensor:
+     x = graph_tensor
+     if tensors.is_ragged(x):
+         x = x.flatten()
+     y_true = x.context.get('label')
+     with tf.GradientTape(watch_accessed_variables=False) as tape:
+         tape.watch(x.node['feature'])
+         y_pred = model(x, training=False)
+         if y_true is not None and len(y_true.shape) > 1:
+             target = tf.gather_nd(y_pred, tf.where(y_true != 0))
+         else:
+             target = y_pred
+     gradients = tape.gradient(target, x.node['feature'])
+     gradients = keras.ops.absolute(gradients)
+     return graph_tensor.update(
+         {
+             'node': {
+                 'feature_saliency': gradients
+             }
+         }
+     )
+
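
A sketch of the two attribution helpers above: both return the input `GraphTensor` with an extra node field, `saliency` (a per-node score from `interpret`) or `feature_saliency` (absolute per-feature gradients from `saliency`).

>>> attributed = molcraft.models.interpret(model, graph)
>>> node_scores = attributed.node['saliency']
>>> attributed = molcraft.models.saliency(model, graph)
>>> feature_scores = attributed.node['feature_saliency']
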
+ def _functional_init_arguments(args, kwargs):
+     return (
+         (len(args) == 2)
+         or (len(args) == 1 and "outputs" in kwargs)
+         or ("inputs" in kwargs and "outputs" in kwargs)
+     )
+
+ def _make_dataset(x: tensors.GraphTensor, batch_size: int, shuffle: bool = False):
+     ds = tf.data.Dataset.from_tensor_slices(x)
+     if shuffle:
+         ds = ds.shuffle(buffer_size=ds.cardinality())
+     return ds.batch(batch_size).prefetch(-1)
+
+ def _get_loss_args(
+     inputs: tensors.GraphTensor,
+     outputs: tensors.GraphTensor | tf.Tensor,
+ ) -> tuple[tf.Tensor, tf.Tensor, tf.Tensor | None]:
+     if (
+         not isinstance(inputs, tensors.GraphTensor) and
+         tensors.is_graph(inputs)
+     ):
+         inputs = tensors.from_dict(inputs)
+     if (
+         not isinstance(outputs, tensors.GraphTensor) and
+         tensors.is_graph(outputs)
+     ):
+         outputs = tensors.from_dict(outputs)
+
+     if not isinstance(outputs, tensors.GraphTensor):
+         tensor, prediction = inputs, outputs
+     else:
+         tensor, prediction = outputs, None
+
+     if 'label' in tensor.context:
+         data = tensor.context
+     elif 'label' in tensor.node:
+         data = tensor.node
+     elif 'label' in tensor.edge:
+         data = tensor.edge
+     else:
+         raise ValueError(
+             'Could not find a `label` in the `GraphTensor`. Make sure a '
+             '`label` exists in either the `context`, `node` or `edge`.'
+         )
+
+     prediction = (
+         prediction if prediction is not None else data.get('prediction')
+     )
+     if prediction is None:
+         raise ValueError(
+             'Could not find a `prediction` in the `GraphTensor`. Make sure a '
+             '`prediction` exists in either the `context`, `node` or `edge`.'
+         )
+     return data['label'], prediction, data.get('sample_weight')