molcraft 0.1.0a1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of molcraft might be problematic. Click here for more details.

molcraft/models.py ADDED
@@ -0,0 +1,441 @@
1
+ import typing
2
+ import keras
3
+ import numpy as np
4
+ import tensorflow as tf
5
+ from pathlib import Path
6
+ from keras.src.models import functional
7
+
8
+ from molcraft import layers
9
+ from molcraft import tensors
10
+ from molcraft import ops
11
+
12
+
13
@keras.saving.register_keras_serializable(package="molcraft")
class GraphModel(layers.GraphLayer, keras.models.Model):

    """A graph model.

    Currently, the `GraphModel` only supports `GraphTensor` input.

    Example (using `from_layers`):

    >>> import molcraft
    >>> import keras
    >>>
    >>> featurizer = molcraft.featurizers.MolGraphFeaturizer()
    >>> graph = featurizer([('N[C@@H](C)C(=O)O', 1.0), ('N[C@@H](CS)C(=O)O', 2.0)])
    >>>
    >>> model = molcraft.models.GraphModel.from_layers([
    ...     molcraft.layers.Input(graph.spec),
    ...     molcraft.layers.NodeEmbedding(128),
    ...     molcraft.layers.EdgeEmbedding(128),
    ...     molcraft.layers.GraphTransformer(128),
    ...     molcraft.layers.GraphTransformer(128),
    ...     molcraft.layers.Readout('mean'),
    ...     molcraft.layers.Dense(1)
    ... ])
    >>> model.compile(
    ...     optimizer=keras.optimizers.Adam(1e-3),
    ...     loss=keras.losses.MeanSquaredError(),
    ...     metrics=[keras.metrics.MeanAbsolutePercentageError(name='mape')]
    ... )
    >>> model.fit(graph, epochs=10)
    >>> mse, mape = model.evaluate(graph)
    >>> preds = model.predict(graph)
    """

    def __new__(cls, *args, **kwargs):
        # Constructing `GraphModel` itself with functional-style arguments
        # (inputs/outputs) yields a `FunctionalGraphModel` instead.
        # Subclasses are never redirected.
        if _functional_init_arguments(args, kwargs) and cls is GraphModel:
            return FunctionalGraphModel(*args, **kwargs)
        return typing.cast(GraphModel, super().__new__(cls))

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.jit_compile = False

    @classmethod
    def from_layers(cls, graph_layers: list, **kwargs):
        """Creates a graph model from a list of graph layers.

        Currently requires `molcraft.layers.Input(spec)`.

        If `molcraft.layers.Input(spec)` is supplied, it both
        creates and builds the layer, as a functional model.
        `molcraft.layers.Input` is a function which returns
        a nested structure of graph components based on `spec`.

        The list passed in is not modified.

        Args:
            graph_layers:
                A list of `GraphLayer` instances, except the initial element
                which is a dictionary of Keras tensors produced by
                `molcraft.layers.Input(spec)`. A nested list of layers is
                wrapped in a `layers.GraphNetwork`.
            kwargs:
                Forwarded to the model constructor.

        Raises:
            ValueError: If `graph_layers` is empty or its first element is
                not a graph input.
        """
        if not graph_layers or not tensors.is_graph(graph_layers[0]):
            # TODO: Allow this. E.g.: return cls(layers=graph_layers)
            raise ValueError(
                'Graph input not found. Make sure to add `Input`.'
            )
        # Unpack instead of `pop(0)` so the caller's list is left intact.
        inputs, *remaining_layers = graph_layers
        x = inputs
        for layer in remaining_layers:
            if isinstance(layer, list):
                layer = layers.GraphNetwork(layer)
            x = layer(x)
        outputs = x
        return cls(inputs=inputs, outputs=outputs, **kwargs)

    def compile(
        self,
        optimizer: keras.optimizers.Optimizer | str | None = 'rmsprop',
        loss: keras.losses.Loss | str | None = None,
        loss_weights: dict[str, float] | None = None,
        metrics: list[keras.metrics.Metric] | None = None,
        weighted_metrics: list[keras.metrics.Metric] | None = None,
        run_eagerly: bool = False,
        steps_per_execution: int = 1,
        jit_compile: str | bool = False,
        auto_scale_loss: bool = True,
        **kwargs
    ) -> None:
        """Compiles the model.

        Args:
            optimizer:
                The optimizer to be used (a `keras.optimizers.Optimizer` subclass).
            loss:
                The loss function to be used (a `keras.losses.Loss` subclass).
            metrics:
                A list of metrics to be used during training (`fit`) and evaluation
                (`evaluate`). Should be `keras.metrics.Metric` subclasses.
            kwargs:
                See `Model.compile` in Keras documentation.
                May or may not apply here.
        """
        super().compile(
            optimizer=optimizer,
            loss=loss,
            loss_weights=loss_weights,
            metrics=metrics,
            weighted_metrics=weighted_metrics,
            run_eagerly=run_eagerly,
            steps_per_execution=steps_per_execution,
            jit_compile=jit_compile,
            auto_scale_loss=auto_scale_loss,
            **kwargs
        )

    def fit(self, x: tensors.GraphTensor | tf.data.Dataset, **kwargs):
        """Fits the model.

        Args:
            x:
                A `GraphTensor` instance or a `tf.data.Dataset` constructed from
                a `GraphTensor` instance. In comparison to a typical Keras model,
                the label (typically denoted `y`) and the sample_weight (typically
                denoted `sample_weight`) should be encoded in the context of the
                `GraphTensor` instance, as `label` and `weight` respectively.
            validation_data:
                A `GraphTensor` instance or a `tf.data.Dataset` constructed from
                a `GraphTensor` instance. In comparison to a typical Keras model,
                the label (typically denoted `y`) and the sample_weight (typically
                denoted `sample_weight`) should be encoded in the context of the
                `GraphTensor` instance, as `label` and `weight` respectively.
            validation_split:
                The fraction of training data to be used as validation data.
                Only works if a `GraphTensor` instance is passed as `x`. The
                trailing samples of `x` are held out, and the resulting split
                replaces any `validation_data` that was passed. If the
                fraction rounds down to zero samples, no split is made.
            batch_size:
                Number of samples per batch of computation.
            epochs:
                Number of iterations over the entire dataset.
            callbacks:
                A list of callbacks to apply during training.
            kwargs:
                See `Model.fit` in Keras documentation.
                May or may not apply here.
        """
        batch_size = kwargs.get('batch_size', 32)
        x_val = kwargs.pop('validation_data', None)
        val_split = kwargs.pop('validation_split', None)
        if isinstance(x_val, tensors.GraphTensor):
            x_val = _make_dataset(x_val, batch_size)
        if isinstance(x, tensors.GraphTensor):
            if val_split:
                val_size = int(val_split * x.num_subgraphs)
                # Guard against `val_size == 0`: `x[-0:]` would be the whole
                # dataset and `x[:-0]` would be empty.
                if val_size > 0:
                    x_val = _make_dataset(x[-val_size:], batch_size)
                    x = x[:-val_size]
            x = _make_dataset(x, batch_size)
        return super().fit(x, validation_data=x_val, **kwargs)

    def evaluate(self, x: tensors.GraphTensor | tf.data.Dataset, **kwargs):
        """Evaluation of the model.

        Args:
            x:
                A `GraphTensor` instance or a `tf.data.Dataset` constructed from
                a `GraphTensor` instance. In comparison to a typical Keras model,
                the label (typically denoted `y`) and the sample_weight (typically
                denoted `sample_weight`) should be encoded in the context of the
                `GraphTensor` instance, as `label` and `weight` respectively.
            batch_size:
                Number of samples per batch of computation.
            kwargs:
                See `Model.evaluate` in Keras documentation.
                May or may not apply here.

        Returns:
            The metric results of `Model.evaluate`, with every value
            converted to a Python `float`.
        """
        batch_size = kwargs.get('batch_size', 32)
        if isinstance(x, tensors.GraphTensor):
            x = _make_dataset(x, batch_size)
        metric_results = super().evaluate(x, **kwargs)
        return tf.nest.map_structure(float, metric_results)

    def predict(self, x: tensors.GraphTensor | tf.data.Dataset, **kwargs):
        """Makes predictions with the model.

        Args:
            x:
                A `GraphTensor` instance or a `tf.data.Dataset` constructed from
                a `GraphTensor` instance. Context `label`s and/or `weight`s may
                be encoded and will be ignored.
            batch_size:
                Number of samples per batch of computation.
            kwargs:
                See `Model.predict` in Keras documentation.
                May or may not apply here.
        """
        batch_size = kwargs.get('batch_size', 32)
        if isinstance(x, tensors.GraphTensor):
            x = _make_dataset(x, batch_size)
        return super().predict(x, **kwargs)

    def get_compile_config(self) -> dict | None:
        """Returns the compile configuration of the parent class (may be `None`)."""
        return super().get_compile_config()

    def compile_from_config(self, config: dict | None) -> None:
        """Compiles the model from a serialized compile config.

        Does nothing if `config` is `None`. After compiling, the optimizer
        is built on the trainable variables if the model is already built.
        """
        if config is None:
            return
        config = keras.utils.deserialize_keras_object(config)
        self.compile(**config)
        if hasattr(self, 'optimizer') and self.built:
            self.optimizer.build(self.trainable_variables)

    def save(
        self,
        filepath: str | Path,
        *args,
        **kwargs
    ) -> None:
        """Saves an entire model.

        Args:
            filepath:
                A string with the path to the model file (requires `.keras` suffix)

        Raises:
            ValueError: If the model has not been built yet.
        """
        if not self.built:
            raise ValueError('Cannot save model as it has not been built yet.')
        super().save(filepath, *args, **kwargs)

    @staticmethod
    def load(
        filepath: str | Path,
        *args,
        **kwargs
    ) -> keras.Model:
        """A `staticmethod` loading an entire model.

        Args:
            filepath:
                A string with the path to the model file (requires `.keras` suffix)
        """
        return keras.models.load_model(filepath, *args, **kwargs)

    def save_weights(self, filepath, *args, **kwargs):
        """Saves the weights of the model.

        The parent directory of `filepath` is created if it does not exist.

        Args:
            filepath:
                A string with the path to the file (requires `.weights.h5` suffix)
        """
        path = Path(filepath).parent
        path.mkdir(parents=True, exist_ok=True)
        return super().save_weights(filepath, *args, **kwargs)

    def load_weights(self, filepath, *args, **kwargs):
        """Loads the weights from file saved via `save_weights()`.

        Args:
            filepath:
                A string with the path to the file (requires `.weights.h5` suffix)
        """
        super().load_weights(filepath, *args, **kwargs)

    def train_step(self, tensor: tensors.GraphTensor) -> dict[str, float]:
        """Runs one optimization step on a batch.

        Labels and sample weights are read from the graph context
        (`label` and `weight`).
        """
        y = tensor.context.get('label')
        sample_weight = tensor.context.get('weight')
        with tf.GradientTape() as tape:
            y_pred = self(tensor, training=True)
            loss = self.compute_loss(tensor, y, y_pred, sample_weight)
            loss = self.optimizer.scale_loss(loss)
        trainable_weights = self.trainable_weights
        gradients = tape.gradient(loss, trainable_weights)
        self.optimizer.apply_gradients(zip(gradients, trainable_weights))
        return self.compute_metrics(tensor, y, y_pred, sample_weight)

    def test_step(self, tensor: tensors.GraphTensor) -> dict[str, float]:
        """Runs one evaluation step on a batch and returns the metric results."""
        y = tensor.context.get('label')
        sample_weight = tensor.context.get('weight')
        y_pred = self(tensor, training=False)
        return self.compute_metrics(tensor, y, y_pred, sample_weight)

    def predict_step(self, tensor: tensors.GraphTensor) -> np.ndarray:
        """Runs the model in inference mode on a batch."""
        return self(tensor, training=False)

    def compute_metrics(self, x, y, y_pred, sample_weight=None) -> dict[str, float]:
        """Updates all metrics and returns their current results.

        The metric named "loss" is updated with the freshly computed loss.
        Every other metric is updated with `(y, y_pred)`, and its result
        (expected to be a dict) is merged into the returned mapping.
        """
        loss = self.compute_loss(x, y, y_pred, sample_weight)
        metric_results = {}
        for metric in self.metrics:
            if metric.name == "loss":
                metric.update_state(loss)
                metric_results[metric.name] = metric.result()
            else:
                # NOTE(review): `sample_weight` is not forwarded to the
                # metric update - confirm whether that is intended.
                metric.update_state(y, y_pred)
                metric_results.update(metric.result())
        return metric_results
306
+
307
+
308
@keras.saving.register_keras_serializable(package="molcraft")
class FunctionalGraphModel(functional.Functional, GraphModel):

    """A `GraphModel` built through the Keras functional machinery."""

    @property
    def layers(self):
        """The model's layers, with `keras.layers.InputLayer` instances left out."""
        visible_layers = []
        for candidate in super().layers:
            if isinstance(candidate, keras.layers.InputLayer):
                continue
            visible_layers.append(candidate)
        return visible_layers
317
+
318
+
319
def save_model(model: keras.Model, filepath: str | Path, *args, **kwargs) -> None:
    """Saves `model` to `filepath` via `keras.models.save_model`.

    Any extra positional and keyword arguments are passed through unchanged.
    """
    keras.models.save_model(model, filepath, *args, **kwargs)
321
+
322
def load_model(filepath: str | Path, inputs=None, *args, **kwargs) -> keras.Model:
    """Loads a model from `filepath` via `keras.models.load_model`.

    Args:
        filepath:
            Path to the saved model file.
        inputs:
            Accepted for interface compatibility; currently unused.
        args, kwargs:
            Passed through to `keras.models.load_model`.

    Returns:
        The loaded model.
    """
    return keras.models.load_model(filepath, *args, **kwargs)
324
+
325
def create(
    *layers: typing.Any
) -> GraphModel:
    """Creates a `GraphModel` from graph layers.

    The layers can be passed either as separate positional arguments or as
    a single list. If the first argument is a list, that list is used and
    any further arguments are ignored.

    Args:
        layers:
            The graph input followed by the layers, as accepted by
            `GraphModel.from_layers`.

    Raises:
        ValueError: If no layers are passed.
    """
    # NOTE: inside this function `layers` refers to the arguments, not to
    # the module-level `layers` import.
    if not layers:
        raise ValueError('No layers found. Pass an `Input` followed by layers.')
    graph_layers = layers[0] if isinstance(layers[0], list) else layers
    # Copy so that `from_layers` never operates on the caller's own list.
    return GraphModel.from_layers(list(graph_layers))
333
+
334
def interpret(
    model: GraphModel,
    graph_tensor: tensors.GraphTensor,
) -> tensors.GraphTensor:
    """Computes gradient-based node saliency for `graph_tensor`.

    The model's layers are run one by one under a gradient tape. The node
    features entering each `layers.GraphConv` (or taped by a
    `layers.GraphNetwork`) are watched, and the gradients of the prediction
    with respect to those features are taken. The gradients are averaged
    per graph, multiplied with the features and summed over the feature
    axis.

    Args:
        model:
            The model to interpret.
        graph_tensor:
            The input graph. If it is ragged, it is flattened before the
            forward pass. If its context has a `label` of rank > 1, only
            the predictions at non-zero label positions are differentiated.

    Returns:
        `graph_tensor` updated with a `saliency` entry under `node`.

    Raises:
        ValueError: If no node features were taped, i.e. the model has no
            layer whose input could be watched.
    """
    x = graph_tensor
    if tensors.is_ragged(x):
        x = x.flatten()
    graph_indicator = x.graph_indicator
    y_true = x.context.get('label')
    features = []
    with tf.GradientTape(watch_accessed_variables=False) as tape:
        for layer in model.layers:
            if isinstance(layer, layers.GraphNetwork):
                x, taped_features = layer.tape_propagate(x, tape, training=False)
                features.extend(taped_features)
            else:
                if (
                    isinstance(layer, layers.GraphConv) and
                    isinstance(x, tensors.GraphTensor)
                ):
                    tape.watch(x.node['feature'])
                    features.append(x.node['feature'])
                x = layer(x, training=False)
        y_pred = x
        # The target is selected inside the tape so that the gather is
        # recorded and gradients can flow through it.
        if y_true is not None and len(y_true.shape) > 1:
            target = tf.gather_nd(y_pred, tf.where(y_true != 0))
        else:
            target = y_pred
    if not features:
        raise ValueError(
            'No node features were taped. The model needs at least one '
            '`GraphConv` or `GraphNetwork` layer to be interpreted.'
        )
    gradients = tape.gradient(target, features)
    features = keras.ops.concatenate(features, axis=-1)
    gradients = keras.ops.concatenate(gradients, axis=-1)
    # Per-graph mean gradient, broadcast back to the nodes of each graph.
    alpha = ops.segment_mean(gradients, graph_indicator)
    alpha = ops.gather(alpha, graph_indicator)
    maps = keras.ops.where(gradients != 0, alpha * features, gradients)
    maps = keras.ops.sum(maps, axis=-1)
    # NOTE(review): `maps` is computed on the flattened graph but written
    # back to the original `graph_tensor` - confirm `update` handles a
    # ragged input.
    return graph_tensor.update(
        {
            'node': {
                'saliency': maps
            }
        }
    )
376
+
377
def predict(
    model: GraphModel,
    x: tensors.GraphTensor | tf.data.Dataset,
    repeats: int | None = 16,
    batch_size: int = 256,
    verbose: int = 0,
    **kwargs,
) -> tuple[tf.Tensor | np.ndarray, tf.Tensor | np.ndarray]:
    """Predict with model.

    With a truthy `repeats`, every example is predicted `repeats` times
    with `training = True` (monte-carlo predictions), and the mean and
    standard deviation over the repeats are returned. With a falsy
    `repeats`, the call is forwarded to `model.predict`.

    Args:
        model:
            The model to predict with.
        x:
            A `GraphTensor` instance, or a `tf.data.Dataset` constructed
            from one.
        repeats:
            Number of predictions per example.
        batch_size:
            Number of samples per batch of computation. Only applied when
            `x` is a `GraphTensor` or when `repeats` is falsy.
        verbose:
            Forwarded to `model.predict` when `repeats` is falsy.
        kwargs:
            See `Model.predict` in Keras documentation.
            May or may not apply here.

    Returns:
        A `(mean, standard deviation)` tuple for monte-carlo predictions,
        otherwise the result of `model.predict`.

    Raises:
        ValueError: If `x` is neither a `GraphTensor` nor a `tf.data.Dataset`.
    """
    if not repeats:
        return model.predict(
            x, batch_size=batch_size, verbose=verbose, **kwargs
        )

    if isinstance(x, tensors.GraphTensor):
        dataset = (
            tf.data.Dataset.from_tensor_slices(x)
            .repeat(repeats)
            .batch(batch_size)
        )
    elif isinstance(x, tf.data.Dataset):
        dataset = x.repeat(repeats)
    else:
        raise ValueError(
            'Input `x` needs to be a `tensors.GraphTensor` instance '
            'or a `tf.data.Dataset` instance constructed from `tensors.GraphTensor`.'
        )
    dataset = dataset.prefetch(-1)

    batch_predictions = []
    for batch in dataset:
        batch_predictions.append(model(batch, training=True))
    stacked = keras.ops.concatenate(batch_predictions)

    # The dataset was repeated as a whole, so predictions are ordered
    # repeat-major: (repeats, examples, outputs).
    num_examples = len(stacked) // repeats
    stacked = np.reshape(stacked, (repeats, num_examples, -1))
    loc = keras.ops.mean(stacked, axis=0)
    scale = keras.ops.std(stacked, axis=0)
    if tf.executing_eagerly():
        return (loc.numpy(), scale.numpy())
    return (loc, scale)
428
+
429
+ def _functional_init_arguments(args, kwargs):
430
+ return (
431
+ (len(args) == 2)
432
+ or (len(args) == 1 and "outputs" in kwargs)
433
+ or ("inputs" in kwargs and "outputs" in kwargs)
434
+ )
435
+
436
def _make_dataset(x: tensors.GraphTensor, batch_size: int):
    """Turns a graph tensor into a batched, prefetching `tf.data.Dataset`."""
    dataset = tf.data.Dataset.from_tensor_slices(x)
    dataset = dataset.batch(batch_size)
    return dataset.prefetch(-1)
molcraft/ops.py ADDED
@@ -0,0 +1,129 @@
1
+ import keras
2
+ import numpy as np
3
+ import tensorflow as tf
4
+ from keras import backend
5
+
6
+
7
def gather(
    node_feature: tf.Tensor,
    edge: tf.Tensor
) -> tf.Tensor:
    """Gathers rows of `node_feature` at the indices given by `edge`.

    On the TensorFlow backend this is `tf.gather`. On other backends
    `edge` is padded with trailing axes up to the rank of `node_feature`
    and `keras.ops.take_along_axis` is applied along axis 0.
    """
    if backend.backend() == 'tensorflow':
        return tf.gather(node_feature, edge)
    feature_rank = len(keras.ops.shape(node_feature))
    index_rank = len(keras.ops.shape(edge))
    missing_axes = feature_rank - index_rank
    for _ in range(missing_axes):
        edge = keras.ops.expand_dims(edge, axis=-1)
    return keras.ops.take_along_axis(node_feature, edge, axis=0)
18
+
19
def aggregate(
    node_feature: tf.Tensor,
    edge: tf.Tensor,
    num_nodes: tf.Tensor
) -> tf.Tensor:
    """Sums the rows of `node_feature` into `num_nodes` segments given by `edge`."""
    return keras.ops.segment_sum(
        data=node_feature,
        segment_ids=edge,
        num_segments=num_nodes,
    )
25
+
26
def propagate(
    node_feature: tf.Tensor,
    edge_source: tf.Tensor,
    edge_target: tf.Tensor,
    edge_feature: tf.Tensor | None = None,
    edge_weight: tf.Tensor | None = None,
) -> tf.Tensor:
    """Propagates node features along edges and sums them at the target nodes.

    For every edge the feature of its source node is gathered, optionally
    scaled by `edge_weight` and optionally offset by `edge_feature`. The
    resulting per-edge messages are summed per target node.

    Args:
        node_feature:
            The node features; the first axis indexes nodes.
        edge_source:
            Source node index of every edge.
        edge_target:
            Target node index of every edge.
        edge_feature:
            Optional per-edge features added to the messages.
        edge_weight:
            Optional per-edge weights multiplied with the messages.

    Returns:
        The aggregated messages, with one row per node.
    """
    num_nodes = keras.ops.shape(node_feature)[0]

    message = gather(node_feature, edge_source)

    if edge_weight is not None:
        message = message * edge_weight

    if edge_feature is not None:
        message = message + edge_feature

    # Aggregate the per-edge messages, not the raw node features.
    return aggregate(message, edge_target, num_nodes)
44
+
45
def scatter_update(
    inputs: tf.Tensor,
    indices: tf.Tensor,
    updates: tf.Tensor,
) -> tf.Tensor:
    """Writes `updates` into `inputs` at `indices` via `keras.ops.scatter_update`.

    A boolean `indices` mask is first converted to the positions of its
    true entries. `indices` is then padded with trailing axes up to the
    rank of `inputs`.
    """
    if indices.dtype == tf.bool:
        true_positions = keras.ops.where(indices)
        indices = keras.ops.stack(true_positions, axis=-1)
    input_rank = len(keras.ops.shape(inputs))
    index_rank = len(keras.ops.shape(indices))
    missing_axes = input_rank - index_rank
    for _ in range(missing_axes):
        indices = keras.ops.expand_dims(indices, axis=-1)
    return keras.ops.scatter_update(inputs, indices, updates)
57
+
58
def edge_softmax(
    score: tf.Tensor,
    edge_target: tf.Tensor
) -> tf.Tensor:
    """Softmax of `score` over the edges that share a target node.

    The per-target maximum is subtracted before exponentiation. The number
    of segments is `max(edge_target) + 1`, or 0 when there are no edges.
    """
    def _count_segments():
        return keras.ops.maximum(keras.ops.max(edge_target) + 1, 1)

    has_edges = keras.ops.shape(edge_target)[0] > 0
    num_segments = keras.ops.cond(has_edges, _count_segments, lambda: 0)

    segment_peak = keras.ops.segment_max(
        score, edge_target, num_segments, sorted=False
    )
    shifted = score - gather(segment_peak, edge_target)
    numerator = keras.ops.exp(shifted)

    segment_total = keras.ops.segment_sum(
        numerator, edge_target, num_segments, sorted=False
    )
    denominator = gather(segment_total, edge_target)
    return numerator / denominator
77
+
78
def segment_mean(
    data: tf.Tensor,
    segment_ids: tf.Tensor,
    num_segments: int | None = None,
    sorted: bool = False,
) -> tf.Tensor:
    """Computes the mean of `data` per segment.

    On the TensorFlow backend this is `tf.math.unsorted_segment_mean`.
    On other backends the segment sums are divided by the segment sizes.
    Empty segments yield 0 rather than a division by zero, and `data` may
    have any rank.

    Args:
        data:
            The values to average; the first axis is segmented.
        segment_ids:
            The segment index of every row of `data`.
        num_segments:
            The number of segments. Defaults to `max(segment_ids) + 1`.
        sorted:
            Forwarded to `keras.ops.segment_sum` on non-TensorFlow
            backends; unused on the TensorFlow backend.

    Returns:
        The per-segment means, with `num_segments` rows.
    """
    if num_segments is None:
        num_segments = keras.ops.max(segment_ids) + 1
    if backend.backend() == 'tensorflow':
        return tf.math.unsorted_segment_mean(
            data=data,
            segment_ids=segment_ids,
            num_segments=num_segments
        )
    totals = keras.ops.segment_sum(
        data=data,
        segment_ids=segment_ids,
        num_segments=num_segments,
        sorted=sorted
    )
    sizes = keras.ops.cast(
        keras.ops.bincount(segment_ids, minlength=num_segments),
        dtype=totals.dtype
    )
    # Clamp to 1 so that empty segments give 0 instead of NaN.
    sizes = keras.ops.maximum(sizes, 1)
    # Add trailing singleton axes so `sizes` broadcasts for any data rank.
    trailing_axes = (1,) * (len(keras.ops.shape(totals)) - 1)
    sizes = keras.ops.reshape(sizes, (-1,) + trailing_axes)
    return totals / sizes
103
+
104
def gaussian(
    x: tf.Tensor,
    mean: tf.Tensor,
    std: tf.Tensor
) -> tf.Tensor:
    """Evaluates the Gaussian density with `mean` and `std` at `x`.

    `mean` and `std` are padded with leading axes up to the rank of `x`.
    """
    missing_axes = len(keras.ops.shape(x)) - len(keras.ops.shape(mean))
    for _ in range(missing_axes):
        mean = keras.ops.expand_dims(mean, axis=0)
        std = keras.ops.expand_dims(std, axis=0)
    sqrt_two_pi = (2 * np.pi) ** 0.5
    standardized = (x - mean) / std
    return keras.ops.exp(-0.5 * (standardized ** 2)) / (sqrt_two_pi * std)
116
+
117
def euclidean_distance(
    x1: tf.Tensor,
    x2: tf.Tensor,
    axis: int = -1
) -> tf.Tensor:
    """Euclidean distance between `x1` and `x2` along `axis`.

    The reduced axis is kept with size 1.
    """
    difference = keras.ops.subtract(x1, x2)
    squared = keras.ops.square(difference)
    summed = keras.ops.sum(squared, axis=axis, keepdims=True)
    return keras.ops.sqrt(summed)