molcraft-0.1.0rc9-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- molcraft/__init__.py +18 -0
- molcraft/callbacks.py +100 -0
- molcraft/chem.py +714 -0
- molcraft/datasets.py +132 -0
- molcraft/descriptors.py +149 -0
- molcraft/features.py +379 -0
- molcraft/featurizers.py +624 -0
- molcraft/layers.py +1910 -0
- molcraft/losses.py +37 -0
- molcraft/models.py +623 -0
- molcraft/ops.py +195 -0
- molcraft/records.py +187 -0
- molcraft/tensors.py +561 -0
- molcraft/trainers.py +212 -0
- molcraft-0.1.0rc9.dist-info/METADATA +118 -0
- molcraft-0.1.0rc9.dist-info/RECORD +19 -0
- molcraft-0.1.0rc9.dist-info/WHEEL +5 -0
- molcraft-0.1.0rc9.dist-info/licenses/LICENSE +21 -0
- molcraft-0.1.0rc9.dist-info/top_level.txt +1 -0
molcraft/losses.py
ADDED
@@ -0,0 +1,37 @@

```python
import warnings
import keras
import numpy as np


@keras.saving.register_keras_serializable(package='molcraft')
class GaussianNegativeLogLikelihood(keras.losses.Loss):

    def __init__(
        self,
        events: int = 1,
        name="gaussian_nll",
        **kwargs
    ):
        super().__init__(name=name, **kwargs)
        self.events = events

    def call(self, y_true, y_pred):
        mean = y_pred[..., :self.events]
        scale = y_pred[..., self.events:]
        variance = keras.ops.square(scale)
        expected_rank = len(keras.ops.shape(mean))
        current_rank = len(keras.ops.shape(y_true))
        for _ in range(expected_rank - current_rank):
            y_true = keras.ops.expand_dims(y_true, axis=-1)
        return keras.ops.mean(
            0.5 * keras.ops.log(2.0 * np.pi * variance) +
            0.5 * keras.ops.square(y_true - mean) / variance
        )

    def get_config(self):
        config = super().get_config()
        config['events'] = self.events
        return config


GaussianNLL = GaussianNegativeLogLikelihood
```
molcraft/models.py
ADDED
@@ -0,0 +1,623 @@

```python
import warnings
import typing
import keras
import numpy as np
import tensorflow as tf
from pathlib import Path
from keras.src.models import functional

from molcraft import layers
from molcraft import tensors
from molcraft import ops


@keras.saving.register_keras_serializable(package="molcraft")
class GraphModel(layers.GraphLayer, keras.models.Model):

    """A graph model.

    Currently, the `GraphModel` only supports `GraphTensor` input.

    Build a subclassed GraphModel:

    >>> import molcraft
    >>> import keras
    >>>
    >>> featurizer = molcraft.featurizers.MolGraphFeaturizer()
    >>> graph = featurizer([('N[C@@H](C)C(=O)O', 1.0), ('N[C@@H](CS)C(=O)O', 2.0)])
    >>>
    >>> @keras.saving.register_keras_serializable()
    ... class GraphNeuralNetwork(molcraft.models.GraphModel):
    ...     def __init__(self, units, **kwargs):
    ...         super().__init__(**kwargs)
    ...         self.units = units
    ...         self.node_embedding = molcraft.layers.NodeEmbedding(self.units)
    ...         self.edge_embedding = molcraft.layers.EdgeEmbedding(self.units)
    ...         self.conv_1 = molcraft.layers.GraphConv(self.units)
    ...         self.conv_2 = molcraft.layers.GraphConv(self.units)
    ...         self.readout = molcraft.layers.Readout('mean')
    ...         self.dense = keras.layers.Dense(1)
    ...     def propagate(self, graph):
    ...         x = self.edge_embedding(self.node_embedding(graph))
    ...         x = self.conv_2(self.conv_1(x))
    ...         return self.dense(self.readout(x))
    ...     def get_config(self):
    ...         config = super().get_config()
    ...         config['units'] = self.units
    ...         return config
    >>>
    >>> model = GraphNeuralNetwork(128)
    >>> model.compile(
    ...     optimizer=keras.optimizers.Adam(1e-3),
    ...     loss=keras.losses.MeanSquaredError(),
    ...     metrics=[keras.metrics.MeanAbsolutePercentageError(name='mape')]
    ... )
    >>> model.fit(graph, epochs=10)
    >>> mse, mape = model.evaluate(graph)
    >>> preds = model.predict(graph)

    Build a functional GraphModel:

    >>> import molcraft
    >>> import keras
    >>>
    >>> featurizer = molcraft.featurizers.MolGraphFeaturizer()
    >>> graph = featurizer([('N[C@@H](C)C(=O)O', 1.0), ('N[C@@H](CS)C(=O)O', 2.0)])
    >>>
    >>> inputs = molcraft.layers.Input(graph.spec)
    >>> x = molcraft.layers.NodeEmbedding(128)(inputs)
    >>> x = molcraft.layers.EdgeEmbedding(128)(x)
    >>> x = molcraft.layers.GraphConv(128)(x)
    >>> x = molcraft.layers.GraphConv(128)(x)
    >>> x = molcraft.layers.Readout('mean')(x)
    >>> outputs = keras.layers.Dense(1)(x)
    >>> model = molcraft.models.GraphModel(inputs, outputs)
    >>> model.compile(
    ...     optimizer=keras.optimizers.Adam(1e-3),
    ...     loss=keras.losses.MeanSquaredError(),
    ...     metrics=[keras.metrics.MeanAbsolutePercentageError(name='mape')]
    ... )
    >>> model.fit(graph, epochs=10)
    >>> mse, mape = model.evaluate(graph)
    >>> preds = model.predict(graph)

    Build a GraphModel using `from_layers`:

    >>> import molcraft
    >>> import keras
    >>>
    >>> featurizer = molcraft.featurizers.MolGraphFeaturizer()
    >>> graph = featurizer([('N[C@@H](C)C(=O)O', 1.0), ('N[C@@H](CS)C(=O)O', 2.0)])
    >>>
    >>> model = molcraft.models.GraphModel.from_layers([
    ...     molcraft.layers.Input(graph.spec),
    ...     molcraft.layers.NodeEmbedding(128),
    ...     molcraft.layers.EdgeEmbedding(128),
    ...     molcraft.layers.GraphConv(128),
    ...     molcraft.layers.GraphConv(128),
    ...     molcraft.layers.Readout('mean'),
    ...     keras.layers.Dense(1)
    ... ])
    >>> model.compile(
    ...     optimizer=keras.optimizers.Adam(1e-3),
    ...     loss=keras.losses.MeanSquaredError(),
    ...     metrics=[keras.metrics.MeanAbsolutePercentageError(name='mape')]
    ... )
    >>> model.fit(graph, epochs=10)
    >>> mse, mape = model.evaluate(graph)
    >>> preds = model.predict(graph)

    """

    def __new__(cls, *args, **kwargs):
        if _functional_init_arguments(args, kwargs) and cls == GraphModel:
            return FunctionalGraphModel(*args, **kwargs)
        return super().__new__(cls)

    def __init__(self, *args, **kwargs):
        self._model_layers = kwargs.pop('model_layers', None)
        super().__init__(*args, **kwargs)
        self.jit_compile = False

    @classmethod
    def from_layers(cls, graph_layers: list, **kwargs):
        """Creates a graph model from a list of graph layers.

        Currently requires `molcraft.layers.Input(spec)`.

        If `molcraft.layers.Input(spec)` is supplied, it both
        creates and builds the layer, as a functional model.
        `molcraft.layers.Input` is a function which returns
        a nested structure of graph components based on `spec`.

        Args:
            graph_layers:
                A list of `GraphLayer` instances, except the initial element
                which is a dictionary of Keras tensors produced by
                `molcraft.layers.Input(spec)`.
        """
        if not tensors.is_graph(graph_layers[0]):
            return cls(model_layers=graph_layers, **kwargs)
        elif cls != GraphModel:
            return cls(model_layers=graph_layers[1:], **kwargs)
        inputs: dict = graph_layers.pop(0)
        x = inputs
        for layer in graph_layers:
            if isinstance(layer, list):
                layer = layers.GraphNetwork(layer)
            x = layer(x)
        outputs = x
        return cls(inputs=inputs, outputs=outputs, **kwargs)

    def propagate(self, graph: tensors.GraphTensor) -> tensors.GraphTensor:
        if self._model_layers is None:
            return super().propagate(graph)
        for layer in self._model_layers:
            graph = layer(graph)
        return graph

    def get_config(self):
        """Obtain model config."""
        config = super().get_config()
        if hasattr(self, '_model_layers') and self._model_layers is not None:
            config['model_layers'] = [
                keras.saving.serialize_keras_object(l)
                for l in self._model_layers
            ]
        return config

    @classmethod
    def from_config(cls, config: dict):
        """Obtain model from model config."""
        if 'model_layers' in config:
            config['model_layers'] = [
                keras.saving.deserialize_keras_object(l)
                for l in config['model_layers']
            ]
        return super().from_config(config)

    def compile(
        self,
        optimizer: keras.optimizers.Optimizer | str | None = 'rmsprop',
        loss: keras.losses.Loss | str | None = None,
        loss_weights: dict[str, float] = None,
        metrics: list[keras.metrics.Metric] = None,
        weighted_metrics: list[keras.metrics.Metric] | None = None,
        run_eagerly: bool = False,
        steps_per_execution: int = 1,
        jit_compile: str | bool = False,
        auto_scale_loss: bool = True,
        **kwargs
    ) -> None:
        """Compiles the model.

        Args:
            optimizer:
                The optimizer to be used (a `keras.optimizers.Optimizer` subclass).
            loss:
                The loss function to be used (a `keras.losses.Loss` subclass).
            metrics:
                A list of metrics to be used during training (`fit`) and evaluation
                (`evaluate`). Should be `keras.metrics.Metric` subclasses.
            kwargs:
                See `Model.compile` in the Keras documentation.
                May or may not apply here.
        """
        super().compile(
            optimizer=optimizer,
            loss=loss,
            loss_weights=loss_weights,
            metrics=metrics,
            weighted_metrics=weighted_metrics,
            run_eagerly=run_eagerly,
            steps_per_execution=steps_per_execution,
            jit_compile=jit_compile,
            auto_scale_loss=auto_scale_loss,
            **kwargs
        )

    def fit(self, x: tensors.GraphTensor | tf.data.Dataset, **kwargs):
        """Fits the model.

        Args:
            x:
                A `GraphTensor` instance or a `tf.data.Dataset` constructed from
                a `GraphTensor` instance. In contrast to a typical Keras model,
                the label (typically denoted `y`) and the sample weight (typically
                denoted `sample_weight`) should be encoded in the context of the
                `GraphTensor` instance, as `label` and `weight` respectively.
            validation_data:
                A `GraphTensor` instance or a `tf.data.Dataset` constructed from
                a `GraphTensor` instance. In contrast to a typical Keras model,
                the label (typically denoted `y`) and the sample weight (typically
                denoted `sample_weight`) should be encoded in the context of the
                `GraphTensor` instance, as `label` and `weight` respectively.
            validation_split:
                The fraction of training data to be used as validation data.
                Only works if a `GraphTensor` instance is passed as `x`.
            batch_size:
                Number of samples per batch of computation.
            epochs:
                Number of iterations over the entire dataset.
            callbacks:
                A list of callbacks to apply during training.
            kwargs:
                See `Model.fit` in the Keras documentation.
                May or may not apply here.
        """
        batch_size = kwargs.get('batch_size', 32)
        x_val = kwargs.pop('validation_data', None)
        val_split = kwargs.pop('validation_split', None)
        if x_val is not None and isinstance(x_val, tensors.GraphTensor):
            x_val = _make_dataset(x_val, batch_size)
        if isinstance(x, tensors.GraphTensor):
            if val_split:
                val_size = int(val_split * x.num_subgraphs)
                x_val = _make_dataset(x[-val_size:], batch_size)
                x = x[:-val_size]
            x = _make_dataset(x, batch_size, shuffle=kwargs.get('shuffle', True))
        return super().fit(x, validation_data=x_val, **kwargs)

    def evaluate(self, x: tensors.GraphTensor | tf.data.Dataset, **kwargs):
        """Evaluates the model.

        Args:
            x:
                A `GraphTensor` instance or a `tf.data.Dataset` constructed from
                a `GraphTensor` instance. In contrast to a typical Keras model,
                the label (typically denoted `y`) and the sample weight (typically
                denoted `sample_weight`) should be encoded in the context of the
                `GraphTensor` instance, as `label` and `weight` respectively.
            batch_size:
                Number of samples per batch of computation.
            kwargs:
                See `Model.evaluate` in the Keras documentation.
                May or may not apply here.
        """
        batch_size = kwargs.get('batch_size', 32)
        if isinstance(x, tensors.GraphTensor):
            x = _make_dataset(x, batch_size)
        metric_results = super().evaluate(x, **kwargs)
        return tf.nest.map_structure(lambda value: float(value), metric_results)

    def predict(self, x: tensors.GraphTensor | tf.data.Dataset, **kwargs):
        """Makes predictions with the model.

        Args:
            x:
                A `GraphTensor` instance or a `tf.data.Dataset` constructed from
                a `GraphTensor` instance.
            batch_size:
                Number of samples per batch of computation.
            kwargs:
                See `Model.predict` in the Keras documentation.
                May or may not apply here.
        """
        batch_size = kwargs.get('batch_size', 32)
        if isinstance(x, tensors.GraphTensor):
            x = _make_dataset(x, batch_size)
        output = super().predict(x, **kwargs)
        if tensors.is_graph(output):
            return tensors.from_dict(output).flatten()
        return output

    def get_compile_config(self) -> dict | None:
        config = super().get_compile_config()
        if config is None:
            return
        return config

    def compile_from_config(self, config: dict | None) -> None:
        if config is None:
            return
        config = keras.utils.deserialize_keras_object(config)
        self.compile(**config)
        if hasattr(self, 'optimizer') and self.built:
            self.optimizer.build(self.trainable_variables)

    def save(
        self,
        filepath: str | Path,
        *args,
        **kwargs
    ) -> None:
        """Saves an entire model.

        Args:
            filepath:
                A string with the path to the model file (requires `.keras` suffix).
        """
        if not self.built:
            raise ValueError('Cannot save model as it has not been built yet.')
        super().save(filepath, *args, **kwargs)

    @staticmethod
    def load(
        filepath: str | Path,
        *args,
        **kwargs
    ) -> keras.Model:
        """A `staticmethod` loading an entire model.

        Args:
            filepath:
                A string with the path to the model file (requires `.keras` suffix).
        """
        return keras.models.load_model(filepath, *args, **kwargs)

    def save_weights(self, filepath, *args, **kwargs):
        """Saves the weights of the model.

        Args:
            filepath:
                A string with the path to the file (requires `.weights.h5` suffix).
        """
        path = Path(filepath).parent
        path.mkdir(parents=True, exist_ok=True)
        return super().save_weights(filepath, *args, **kwargs)

    def load_weights(self, filepath, *args, **kwargs):
        """Loads the weights from a file saved via `save_weights()`.

        Args:
            filepath:
                A string with the path to the file (requires `.weights.h5` suffix).
        """
        super().load_weights(filepath, *args, **kwargs)

    def embedding(self, layer_name: str = None) -> 'FunctionalGraphModel':
        model = self
        if not isinstance(model, FunctionalGraphModel):
            raise ValueError(
                'Currently, to extract the embedding part of the model, '
                'it needs to be a `FunctionalGraphModel`.'
            )
        inputs = model.input
        outputs = None
        if not layer_name:
            for layer in model.layers:
                if isinstance(layer, layers.Readout):
                    outputs = layer.output
        else:
            layer = model.get_layer(layer_name)
            outputs = (
                layer.output if isinstance(layer, keras.layers.Layer) else None
            )
        if outputs is None:
            raise ValueError(
                f'Could not find `{layer_name}` or '
                f'`{layer_name}` is not a `keras.layers.Layer`.'
            )
        return self.__class__(inputs, outputs, name=f'{self.name}_embedding')

    def backbone(self) -> 'FunctionalGraphModel':
        if not isinstance(self, FunctionalGraphModel):
            raise ValueError(
                'Currently, to extract the backbone part of the model, '
                'it needs to be a `FunctionalGraphModel`, with a `Readout` '
                'layer dividing the backbone and the head part of the model.'
            )
        inputs = self.input
        outputs = None
        for layer in self.layers:
            if isinstance(layer, layers.Readout):
                outputs = layer.output
        if outputs is None:
            raise ValueError(
                'Could not extract output. `Readout` layer not found.'
            )
        return self.__class__(inputs, outputs, name=f'{self.name}_backbone')

    def head(self) -> functional.Functional:
        if not isinstance(self, FunctionalGraphModel):
            raise ValueError(
                'Currently, to extract the head part of the model, '
                'it needs to be a `FunctionalGraphModel`, with a `Readout` '
                'layer dividing the backbone and the head part of the model.'
            )
        inputs = None
        for layer in self.layers:
            if isinstance(layer, layers.Readout):
                inputs = layer.output
        if inputs is None:
            raise ValueError(
                'Could not extract input. `Readout` layer not found.'
            )
        outputs = layer.output
        return keras.models.Model(inputs, outputs, name=f'{self.name}_head')

    def train_step(self, tensor: tensors.GraphTensor) -> dict[str, float]:
        with tf.GradientTape() as tape:
            output = self(tensor, training=True)
            y, y_pred, sample_weight = _get_loss_args(tensor, output)
            loss = self.compute_loss(tensor, y, y_pred, sample_weight)
            loss = self.optimizer.scale_loss(loss)
        trainable_weights = self.trainable_weights
        gradients = tape.gradient(loss, trainable_weights)
        self.optimizer.apply_gradients(zip(gradients, trainable_weights))
        return self.compute_metrics(tensor, y, y_pred, sample_weight)

    def test_step(self, tensor: tensors.GraphTensor) -> dict[str, float]:
        output = self(tensor, training=False)
        y, y_pred, sample_weight = _get_loss_args(tensor, output)
        return self.compute_metrics(tensor, y, y_pred, sample_weight)

    def predict_step(self, tensor: tensors.GraphTensor) -> np.ndarray:
        output = self(tensor, training=False)
        if tensors.is_graph(output):
            if not isinstance(output, tensors.GraphTensor):
                output = tensors.from_dict(output)
            output = tensors.to_dict(output.unflatten())
        return output

    def compute_loss(self, x, y, y_pred, sample_weight=None):
        return super().compute_loss(x, y, y_pred, sample_weight)

    def compute_metrics(self, x, y, y_pred, sample_weight=None) -> dict[str, float]:
        loss = self.compute_loss(x, y, y_pred, sample_weight)
        metric_results = {}
        for metric in self.metrics:
            if metric.name == "loss":
                metric.update_state(loss)
                metric_results[metric.name] = metric.result()
            else:
                metric.update_state(y, y_pred, sample_weight=sample_weight)
                metric_results.update(metric.result())
        return metric_results


@keras.saving.register_keras_serializable(package="molcraft")
class FunctionalGraphModel(functional.Functional, GraphModel):

    @property
    def layers(self):
        return [
            l for l in super().layers if not isinstance(l, keras.layers.InputLayer)
        ]


def save_model(model: GraphModel, filepath: str | Path, *args, **kwargs) -> None:
    if not model.built:
        raise ValueError(
            'Model and its layers have not yet been (fully) built. '
            'Build the model before saving it: `model.build(graph_spec)` '
            'or `model(graph)`.'
        )
    keras.models.save_model(model, filepath, *args, **kwargs)


def load_model(filepath: str | Path, inputs=None, *args, **kwargs) -> GraphModel:
    return keras.models.load_model(filepath, *args, **kwargs)


def create(
    *layers: list[keras.layers.Layer],
    **kwargs
) -> GraphModel:
    if isinstance(layers[0], list):
        layers = layers[0]
    return GraphModel.from_layers(
        list(layers), **kwargs
    )


def interpret(
    model: GraphModel,
    graph_tensor: tensors.GraphTensor,
) -> tensors.GraphTensor:
    x = graph_tensor
    if tensors.is_ragged(x):
        x = x.flatten()
    graph_indicator = x.graph_indicator
    y_true = x.context.get('label')
    features = []
    with tf.GradientTape(watch_accessed_variables=False) as tape:
        for layer in model.layers:
            if isinstance(layer, keras.layers.InputLayer):
                continue
            if isinstance(layer, layers.GraphNetwork):
                x, taped_features = layer.tape_propagate(x, tape, training=False)
                features.extend(taped_features)
            else:
                if (
                    isinstance(layer, layers.GraphConv) and
                    isinstance(x, tensors.GraphTensor)
                ):
                    tape.watch(x.node['feature'])
                    features.append(x.node['feature'])
                x = layer(x, training=False)
        y_pred = x
        if y_true is not None and len(y_true.shape) > 1:
            target = tf.gather_nd(y_pred, tf.where(y_true != 0))
        else:
            target = y_pred
    gradients = tape.gradient(target, features)
    features = keras.ops.concatenate(features, axis=-1)
    gradients = keras.ops.concatenate(gradients, axis=-1)
    alpha = ops.segment_mean(gradients, graph_indicator)
    alpha = ops.gather(alpha, graph_indicator)
    maps = keras.ops.where(gradients != 0, alpha * features, gradients)
    maps = keras.ops.sum(maps, axis=-1)
    return graph_tensor.update(
        {
            'node': {
                'saliency': maps
            }
        }
    )


def saliency(
    model: GraphModel,
    graph_tensor: tensors.GraphTensor,
) -> tensors.GraphTensor:
    x = graph_tensor
    if tensors.is_ragged(x):
        x = x.flatten()
    y_true = x.context.get('label')
    with tf.GradientTape(watch_accessed_variables=False) as tape:
        tape.watch(x.node['feature'])
        y_pred = model(x, training=False)
        if y_true is not None and len(y_true.shape) > 1:
            target = tf.gather_nd(y_pred, tf.where(y_true != 0))
        else:
            target = y_pred
    gradients = tape.gradient(target, x.node['feature'])
    gradients = keras.ops.absolute(gradients)
    return graph_tensor.update(
        {
            'node': {
                'feature_saliency': gradients
            }
        }
    )


def _functional_init_arguments(args, kwargs):
    return (
        (len(args) == 2)
        or (len(args) == 1 and "outputs" in kwargs)
        or ("inputs" in kwargs and "outputs" in kwargs)
    )


def _make_dataset(x: tensors.GraphTensor, batch_size: int, shuffle: bool = False):
    ds = tf.data.Dataset.from_tensor_slices(x)
    if shuffle:
        ds = ds.shuffle(buffer_size=ds.cardinality())
    return ds.batch(batch_size).prefetch(-1)


def _get_loss_args(
    inputs: tensors.GraphTensor,
    outputs: tensors.GraphTensor | tf.Tensor,
) -> tuple[tf.Tensor, tf.Tensor, tf.Tensor | None]:
    if (
        not isinstance(inputs, tensors.GraphTensor) and
        tensors.is_graph(inputs)
    ):
        inputs = tensors.from_dict(inputs)
    if (
        not isinstance(outputs, tensors.GraphTensor) and
        tensors.is_graph(outputs)
    ):
        outputs = tensors.from_dict(outputs)

    if not isinstance(outputs, tensors.GraphTensor):
        tensor, prediction = inputs, outputs
    else:
        tensor, prediction = outputs, None

    if 'label' in tensor.context:
        data = tensor.context
    elif 'label' in tensor.node:
        data = tensor.node
    elif 'label' in tensor.edge:
        data = tensor.edge
    else:
        raise ValueError(
            'Could not find a `label` in the `GraphTensor`. Make sure a '
            '`label` exists in either the `context`, `node` or `edge`.'
        )

    prediction = (
        prediction if prediction is not None else data.get('prediction')
    )
    if prediction is None:
        raise ValueError(
            'Could not find a `prediction` in the `GraphTensor`. Make sure a '
            '`prediction` exists in either the `context`, `node` or `edge`.'
        )
    return data['label'], prediction, data.get('sample_weight')
```