molcraft 0.1.0a1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of molcraft might be problematic. Click here for more details.
- molcraft/__init__.py +16 -0
- molcraft/callbacks.py +21 -0
- molcraft/chem.py +600 -0
- molcraft/conformers.py +155 -0
- molcraft/descriptors.py +90 -0
- molcraft/experimental/__init__.py +1 -0
- molcraft/experimental/peptides.py +303 -0
- molcraft/features.py +387 -0
- molcraft/featurizers.py +693 -0
- molcraft/layers.py +1224 -0
- molcraft/models.py +441 -0
- molcraft/ops.py +129 -0
- molcraft/records.py +169 -0
- molcraft/tensors.py +527 -0
- molcraft-0.1.0a1.dist-info/METADATA +58 -0
- molcraft-0.1.0a1.dist-info/RECORD +19 -0
- molcraft-0.1.0a1.dist-info/WHEEL +5 -0
- molcraft-0.1.0a1.dist-info/licenses/LICENSE +21 -0
- molcraft-0.1.0a1.dist-info/top_level.txt +1 -0
molcraft/models.py
ADDED
|
@@ -0,0 +1,441 @@
|
|
|
1
|
+
import typing
|
|
2
|
+
import keras
|
|
3
|
+
import numpy as np
|
|
4
|
+
import tensorflow as tf
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
from keras.src.models import functional
|
|
7
|
+
|
|
8
|
+
from molcraft import layers
|
|
9
|
+
from molcraft import tensors
|
|
10
|
+
from molcraft import ops
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
@keras.saving.register_keras_serializable(package="molcraft")
class GraphModel(layers.GraphLayer, keras.models.Model):
    """A graph model.

    Currently, the `GraphModel` only supports `GraphTensor` input. Labels and
    sample weights are not passed as separate arguments; they are read from
    the `GraphTensor` context as `label` and `weight` respectively.

    Example (using `from_layers`):

    >>> import molcraft
    >>> import keras
    >>>
    >>> featurizer = molcraft.featurizers.MolGraphFeaturizer()
    >>> graph = featurizer([('N[C@@H](C)C(=O)O', 1.0), ('N[C@@H](CS)C(=O)O', 2.0)])
    >>>
    >>> model = molcraft.models.GraphModel.from_layers([
    ...     molcraft.layers.Input(graph.spec),
    ...     molcraft.layers.NodeEmbedding(128),
    ...     molcraft.layers.EdgeEmbedding(128),
    ...     molcraft.layers.GraphTransformer(128),
    ...     molcraft.layers.GraphTransformer(128),
    ...     molcraft.layers.Readout('mean'),
    ...     molcraft.layers.Dense(1)
    ... ])
    >>> model.compile(
    ...     optimizer=keras.optimizers.Adam(1e-3),
    ...     loss=keras.losses.MeanSquaredError(),
    ...     metrics=[keras.metrics.MeanAbsolutePercentageError(name='mape')]
    ... )
    >>> model.fit(graph, epochs=10)
    >>> mse, mape = model.evaluate(graph)
    >>> preds = model.predict(graph)
    """

    def __new__(cls, *args, **kwargs):
        # Functional-style construction (inputs/outputs) of the base class is
        # redirected to `FunctionalGraphModel`; subclasses are constructed normally.
        if _functional_init_arguments(args, kwargs) and cls == GraphModel:
            return FunctionalGraphModel(*args, **kwargs)
        return typing.cast(GraphModel, super().__new__(cls))

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # XLA compilation is off by default. NOTE(review): presumably because
        # graph inputs have dynamic (per-batch) shapes — confirm.
        self.jit_compile = False

    @classmethod
    def from_layers(cls, graph_layers: list, **kwargs):
        """Creates a graph model from a list of graph layers.

        Currently requires `molcraft.layers.Input(spec)`.

        If `molcraft.layers.Input(spec)` is supplied, it both
        creates and builds the layer, as a functional model.
        `molcraft.layers.Input` is a function which returns
        a nested structure of graph components based on `spec`.

        Args:
            graph_layers:
                A list of `GraphLayer` instances, except the initial element
                which is a dictionary of Keras tensors produced by
                `molcraft.layers.Input(spec)`. A nested list of layers is
                wrapped in a `layers.GraphNetwork`. The list is not modified.
            kwargs:
                Forwarded to the model constructor.

        Raises:
            ValueError: If `graph_layers` is empty or its first element is
                not a graph input.
        """
        if not graph_layers or not tensors.is_graph(graph_layers[0]):
            # TODO: Allow this. E.g.: return cls(layers=graph_layers)
            raise ValueError(
                'Graph input not found. Make sure to add `Input`.'
            )
        # Index instead of `pop(0)` so the caller's list is left untouched.
        inputs: dict = graph_layers[0]
        x = inputs
        for layer in graph_layers[1:]:
            if isinstance(layer, list):
                layer = layers.GraphNetwork(layer)
            x = layer(x)
        outputs = x
        return cls(inputs=inputs, outputs=outputs, **kwargs)

    def compile(
        self,
        optimizer: keras.optimizers.Optimizer | str | None = 'rmsprop',
        loss: keras.losses.Loss | str | None = None,
        loss_weights: dict[str, float] | None = None,
        metrics: list[keras.metrics.Metric] | None = None,
        weighted_metrics: list[keras.metrics.Metric] | None = None,
        run_eagerly: bool = False,
        steps_per_execution: int = 1,
        jit_compile: str | bool = False,
        auto_scale_loss: bool = True,
        **kwargs
    ) -> None:
        """Compiles the model.

        Args:
            optimizer:
                The optimizer to be used (a `keras.optimizers.Optimizer` subclass).
            loss:
                The loss function to be used (a `keras.losses.Loss` subclass).
            metrics:
                A list of metrics to be used during training (`fit`) and evaluation
                (`evaluate`). Should be `keras.metrics.Metric` subclasses.
            weighted_metrics:
                Metrics that receive the context `weight` as `sample_weight`.
            kwargs:
                See `Model.compile` in Keras documentation.
                May or may not apply here.
        """
        super().compile(
            optimizer=optimizer,
            loss=loss,
            loss_weights=loss_weights,
            metrics=metrics,
            weighted_metrics=weighted_metrics,
            run_eagerly=run_eagerly,
            steps_per_execution=steps_per_execution,
            jit_compile=jit_compile,
            auto_scale_loss=auto_scale_loss,
            **kwargs
        )

    def fit(self, x: tensors.GraphTensor | tf.data.Dataset, **kwargs):
        """Fits the model.

        Args:
            x:
                A `GraphTensor` instance or a `tf.data.Dataset` constructed from
                a `GraphTensor` instance. In comparison to a typical Keras model,
                the label (typically denoted `y`) and the sample_weight (typically
                denoted `sample_weight`) should be encoded in the context of the
                `GraphTensor` instance, as `label` and `weight` respectively.
            validation_data:
                A `GraphTensor` instance or a `tf.data.Dataset` constructed from
                a `GraphTensor` instance. In comparison to a typical Keras model,
                the label (typically denoted `y`) and the sample_weight (typically
                denoted `sample_weight`) should be encoded in the context of the
                `GraphTensor` instance, as `label` and `weight` respectively.
                Takes precedence over `validation_split` (as in Keras).
            validation_split:
                The fraction of training data to be used as validation data,
                taken from the end of `x`. Only works if a `GraphTensor`
                instance is passed as `x`; ignored otherwise.
            batch_size:
                Number of samples per batch of computation (default 32).
            epochs:
                Number of iterations over the entire dataset.
            callbacks:
                A list of callbacks to apply during training.
            kwargs:
                See `Model.fit` in Keras documentation.
                May or may not apply here.
        """
        # Treat an explicit `batch_size=None` like Keras does (default 32)
        # instead of passing None to `Dataset.batch`.
        batch_size = kwargs.get('batch_size') or 32
        x_val = kwargs.pop('validation_data', None)
        val_split = kwargs.pop('validation_split', None)
        if isinstance(x_val, tensors.GraphTensor):
            x_val = _make_dataset(x_val, batch_size)
        if isinstance(x, tensors.GraphTensor):
            if val_split and x_val is None:
                val_size = int(val_split * x.num_subgraphs)
                # With val_size == 0, `x[-0:]` would be the whole set and
                # `x[:-0]` empty, so skip the split entirely.
                if val_size > 0:
                    x_val = _make_dataset(x[-val_size:], batch_size)
                    x = x[:-val_size]
            x = _make_dataset(x, batch_size)
        return super().fit(x, validation_data=x_val, **kwargs)

    def evaluate(self, x: tensors.GraphTensor | tf.data.Dataset, **kwargs):
        """Evaluation of the model.

        Args:
            x:
                A `GraphTensor` instance or a `tf.data.Dataset` constructed from
                a `GraphTensor` instance. In comparison to a typical Keras model,
                the label (typically denoted `y`) and the sample_weight (typically
                denoted `sample_weight`) should be encoded in the context of the
                `GraphTensor` instance, as `label` and `weight` respectively.
            batch_size:
                Number of samples per batch of computation (default 32).
            kwargs:
                See `Model.evaluate` in Keras documentation.
                May or may not apply here.

        Returns:
            The Keras evaluation results with every value converted to `float`.
        """
        batch_size = kwargs.get('batch_size') or 32
        if isinstance(x, tensors.GraphTensor):
            x = _make_dataset(x, batch_size)
        metric_results = super().evaluate(x, **kwargs)
        return tf.nest.map_structure(float, metric_results)

    def predict(self, x: tensors.GraphTensor | tf.data.Dataset, **kwargs):
        """Makes predictions with the model.

        Args:
            x:
                A `GraphTensor` instance or a `tf.data.Dataset` constructed from
                a `GraphTensor` instance. Context `label`s and/or `weight`s may
                be encoded and will be ignored.
            batch_size:
                Number of samples per batch of computation (default 32).
            kwargs:
                See `Model.predict` in Keras documentation.
                May or may not apply here.
        """
        batch_size = kwargs.get('batch_size') or 32
        if isinstance(x, tensors.GraphTensor):
            x = _make_dataset(x, batch_size)
        return super().predict(x, **kwargs)

    def get_compile_config(self) -> dict | None:
        """Returns the Keras compile config (None if the model is not compiled)."""
        return super().get_compile_config()

    def compile_from_config(self, config: dict | None) -> None:
        """Re-compiles the model from a config produced by `get_compile_config`."""
        if config is None:
            return
        config = keras.utils.deserialize_keras_object(config)
        self.compile(**config)
        # Create optimizer slot variables now so saved optimizer state can be restored.
        if hasattr(self, 'optimizer') and self.built:
            self.optimizer.build(self.trainable_variables)

    def save(
        self,
        filepath: str | Path,
        *args,
        **kwargs
    ) -> None:
        """Saves an entire model.

        Args:
            filepath:
                A string with the path to the model file (requires `.keras` suffix)

        Raises:
            ValueError: If the model has not been built yet.
        """
        if not self.built:
            raise ValueError('Cannot save model as it has not been built yet.')
        super().save(filepath, *args, **kwargs)

    @staticmethod
    def load(
        filepath: str | Path,
        *args,
        **kwargs
    ) -> keras.Model:
        """A `staticmethod` loading an entire model.

        Args:
            filepath:
                A string with the path to the model file (requires `.keras` suffix)
        """
        return keras.models.load_model(filepath, *args, **kwargs)

    def save_weights(self, filepath, *args, **kwargs):
        """Saves the weights of the model.

        Missing parent directories of `filepath` are created first.

        Args:
            filepath:
                A string with the path to the file (requires `.weights.h5` suffix)
        """
        path = Path(filepath).parent
        path.mkdir(parents=True, exist_ok=True)
        return super().save_weights(filepath, *args, **kwargs)

    def load_weights(self, filepath, *args, **kwargs):
        """Loads the weights from file saved via `save_weights()`.

        Args:
            filepath:
                A string with the path to the file (requires `.weights.h5` suffix)
        """
        super().load_weights(filepath, *args, **kwargs)

    def train_step(self, tensor: tensors.GraphTensor) -> dict[str, float]:
        """One optimization step; targets/weights come from the graph context."""
        y = tensor.context.get('label')
        sample_weight = tensor.context.get('weight')
        with tf.GradientTape() as tape:
            y_pred = self(tensor, training=True)
            loss = self.compute_loss(tensor, y, y_pred, sample_weight)
            # Scale inside the tape so the scaling op is recorded for gradients.
            scaled_loss = self.optimizer.scale_loss(loss)
        trainable_weights = self.trainable_weights
        gradients = tape.gradient(scaled_loss, trainable_weights)
        self.optimizer.apply_gradients(zip(gradients, trainable_weights))
        # Reuse the (unscaled) loss rather than computing it a second time.
        return self.compute_metrics(tensor, y, y_pred, sample_weight, loss=loss)

    def test_step(self, tensor: tensors.GraphTensor) -> dict[str, float]:
        """One evaluation step; targets/weights come from the graph context."""
        y = tensor.context.get('label')
        sample_weight = tensor.context.get('weight')
        y_pred = self(tensor, training=False)
        return self.compute_metrics(tensor, y, y_pred, sample_weight)

    def predict_step(self, tensor: tensors.GraphTensor) -> np.ndarray:
        """One inference step (training mode off)."""
        return self(tensor, training=False)

    def compute_metrics(
        self, x, y, y_pred, sample_weight=None, loss=None
    ) -> dict[str, float]:
        """Updates all tracked metrics and returns their current results.

        Args:
            x: The model input (forwarded to `compute_loss`).
            y: Targets (context `label`).
            y_pred: Model predictions.
            sample_weight: Optional per-sample weights (context `weight`);
                forwarded to `compute_loss` and to every non-loss metric, so
                that metrics passed as `weighted_metrics` are actually weighted.
            loss: Optional precomputed loss. If None it is computed with
                `compute_loss`.

        Returns:
            A dict mapping metric names to their current results.
        """
        if loss is None:
            loss = self.compute_loss(x, y, y_pred, sample_weight)
        metric_results = {}
        for metric in self.metrics:
            if metric.name == "loss":
                metric.update_state(loss)
            else:
                metric.update_state(y, y_pred, sample_weight=sample_weight)
            result = metric.result()
            # Container metrics return a dict of sub-results; plain metrics
            # return a single value.
            if isinstance(result, dict):
                metric_results.update(result)
            else:
                metric_results[metric.name] = result
        return metric_results
|
|
306
|
+
|
|
307
|
+
|
|
308
|
+
@keras.saving.register_keras_serializable(package="molcraft")
class FunctionalGraphModel(functional.Functional, GraphModel):
    """Functional-API flavour of `GraphModel`.

    `GraphModel.__new__` returns an instance of this class when a
    `GraphModel` is constructed with functional (inputs/outputs) arguments.
    """

    @property
    def layers(self):
        """The model's layers with Keras `InputLayer` placeholders filtered out."""
        every_layer = super().layers
        return [
            candidate
            for candidate in every_layer
            if not isinstance(candidate, keras.layers.InputLayer)
        ]
|
|
317
|
+
|
|
318
|
+
|
|
319
|
+
def save_model(model: keras.Model, filepath: str | Path, *args, **kwargs) -> None:
    """Persist `model` to `filepath`.

    Thin wrapper over `keras.models.save_model`; any extra positional or
    keyword arguments are passed straight through.
    """
    keras_saver = keras.models.save_model
    keras_saver(model, filepath, *args, **kwargs)
|
|
321
|
+
|
|
322
|
+
def load_model(filepath: str | Path, inputs=None, *args, **kwargs) -> keras.Model:
    """Load a model previously saved with `save_model` / `GraphModel.save`.

    Args:
        filepath: Path to the saved model file.
        inputs: Currently unused; kept for backward compatibility. Because it
            is the second positional parameter, a value passed positionally in
            this slot is NOT forwarded to `keras.models.load_model` — pass
            options such as `custom_objects` by keyword.
        *args, **kwargs: Forwarded to `keras.models.load_model`.

    Returns:
        The deserialized Keras model.
    """
    return keras.models.load_model(filepath, *args, **kwargs)
|
|
324
|
+
|
|
325
|
+
def create(
    *layers: list[keras.layers.Layer]
) -> GraphModel:
    """Build a `GraphModel` from layers.

    Accepts either the layers as separate positional arguments or a single
    list holding them (detected by the first argument being a list); the
    sequence is handed to `GraphModel.from_layers` as a fresh list.
    """
    sequence = layers[0] if isinstance(layers[0], list) else layers
    return GraphModel.from_layers(list(sequence))
|
|
333
|
+
|
|
334
|
+
def interpret(
    model: GraphModel,
    graph_tensor: tensors.GraphTensor,
) -> tensors.GraphTensor:
    """Compute gradient-based node saliency for `graph_tensor`.

    The input is propagated through `model.layers` one layer at a time under a
    `tf.GradientTape`. The node features entering each `GraphConv` layer, and
    the features returned by `GraphNetwork.tape_propagate`, are recorded.
    Gradients of the prediction w.r.t. those features are averaged per graph
    (`alpha`). `alpha` is multiplied into the features wherever the gradient
    is non-zero, and the result is summed over the last axis to give one
    score per node. This resembles Grad-CAM without a final ReLU.

    If the context `label` has rank > 1, only the predicted entries at
    positions where the label is non-zero are differentiated.

    Args:
        model: The model to interpret. NOTE(review): its `layers` are applied
            strictly in order, so this assumes a sequential layer stack.
        graph_tensor: The input graph(s). Ragged input is flattened first.

    Returns:
        `graph_tensor` updated with a node field `'saliency'`.
    """
    x = graph_tensor
    if tensors.is_ragged(x):
        x = x.flatten()
    # Per-node graph index, used below to average gradients per graph.
    graph_indicator = x.graph_indicator
    y_true = x.context.get('label')
    features = []
    with tf.GradientTape(watch_accessed_variables=False) as tape:
        for layer in model.layers:
            if isinstance(layer, layers.GraphNetwork):
                # The network watches and returns its own intermediate features.
                x, taped_features = layer.tape_propagate(x, tape, training=False)
                features.extend(taped_features)
            else:
                if (
                    isinstance(layer, layers.GraphConv) and
                    isinstance(x, tensors.GraphTensor)
                ):
                    # Record the node features feeding this conv layer.
                    tape.watch(x.node['feature'])
                    features.append(x.node['feature'])
                x = layer(x, training=False)
        y_pred = x
        if y_true is not None and len(y_true.shape) > 1:
            # Differentiate only the entries selected by non-zero labels.
            target = tf.gather_nd(y_pred, tf.where(y_true != 0))
        else:
            target = y_pred
    gradients = tape.gradient(target, features)
    # NOTE(review): if no GraphConv/GraphNetwork layer was found, `features`
    # is empty and the concatenations below will fail.
    features = keras.ops.concatenate(features, axis=-1)
    gradients = keras.ops.concatenate(gradients, axis=-1)
    # Per-graph mean gradient, broadcast back to each node of that graph.
    alpha = ops.segment_mean(gradients, graph_indicator)
    alpha = ops.gather(alpha, graph_indicator)
    maps = keras.ops.where(gradients != 0, alpha * features, gradients)
    maps = keras.ops.sum(maps, axis=-1)
    # NOTE(review): for ragged input, `maps` was computed on the flattened
    # graph but is written into the original (ragged) `graph_tensor` — confirm
    # that `update` accepts flat node values in that case.
    return graph_tensor.update(
        {
            'node': {
                'saliency': maps
            }
        }
    )
|
|
376
|
+
|
|
377
|
+
def predict(
    model: GraphModel,
    x: tensors.GraphTensor | tf.data.Dataset,
    repeats: int | None = 16,
    batch_size: int = 256,
    verbose: int = 0,
    **kwargs,
) -> tuple[tf.Tensor | np.ndarray, tf.Tensor | np.ndarray]:
    """Predict with model.

    By default performs Monte-Carlo predictions. Every example is predicted
    `repeats` times with `training=True`, and the per-example mean and
    standard deviation over those repeats are returned.

    NOTE(review): `training=True` switches every layer to training mode. That
    is what enables stochastic layers such as dropout, but it also affects any
    layer whose training mode has side effects (e.g. batch normalization
    statistics), if the model contains such layers.

    Args:
        model:
            The model to predict with.
        x:
            A `GraphTensor` instance, or a `tf.data.Dataset` constructed from
            one (the dataset is repeated as-is, so it should already be batched).
        repeats:
            Number of predictions per example. If falsy, a single ordinary
            `model.predict` call is made and its result returned directly.
        batch_size:
            Number of samples per batch of computation.
        verbose:
            Passed to `model.predict` (only used when `repeats` is falsy).
        kwargs:
            See `Model.predict` in Keras documentation (only used when
            `repeats` is falsy). May or may not apply here.

    Returns:
        `(mean, std)` over the repeats, each with shape
        `(num_examples, -1)`; converted to NumPy when executing eagerly.

    Raises:
        ValueError: If `x` is neither a `GraphTensor` nor a `tf.data.Dataset`.
    """
    if not repeats:
        return model.predict(
            x, batch_size=batch_size, verbose=verbose, **kwargs
        )

    if isinstance(x, tensors.GraphTensor):
        dataset = (
            tf.data.Dataset.from_tensor_slices(x)
            .repeat(repeats)
            .batch(batch_size)
        )
    elif isinstance(x, tf.data.Dataset):
        dataset = x.repeat(repeats)
    else:
        raise ValueError(
            'Input `x` needs to be a `tensors.GraphTensor` instance '
            'or a `tf.data.Dataset` instance constructed from `tensors.GraphTensor`.'
        )
    dataset = dataset.prefetch(-1)

    # The data is repeated back-to-back, so the stacked predictions are
    # `repeats` consecutive copies of the full example sequence.
    batch_outputs = [model(batch, training=True) for batch in dataset]
    stacked = keras.ops.concatenate(batch_outputs)
    num_examples = len(stacked) // repeats
    samples = np.reshape(stacked, (repeats, num_examples, -1))

    mean = keras.ops.mean(samples, axis=0)
    std = keras.ops.std(samples, axis=0)
    if tf.executing_eagerly():
        mean, std = mean.numpy(), std.numpy()
    return (mean, std)
|
|
428
|
+
|
|
429
|
+
def _functional_init_arguments(args, kwargs):
|
|
430
|
+
return (
|
|
431
|
+
(len(args) == 2)
|
|
432
|
+
or (len(args) == 1 and "outputs" in kwargs)
|
|
433
|
+
or ("inputs" in kwargs and "outputs" in kwargs)
|
|
434
|
+
)
|
|
435
|
+
|
|
436
|
+
def _make_dataset(x: tensors.GraphTensor, batch_size: int):
    """Slice `x` into examples, batch them, and prefetch.

    The prefetch buffer size -1 is `tf.data.AUTOTUNE`.
    """
    dataset = tf.data.Dataset.from_tensor_slices(x)
    dataset = dataset.batch(batch_size)
    return dataset.prefetch(-1)
|
molcraft/ops.py
ADDED
|
@@ -0,0 +1,129 @@
|
|
|
1
|
+
import keras
|
|
2
|
+
import numpy as np
|
|
3
|
+
import tensorflow as tf
|
|
4
|
+
from keras import backend
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def gather(
    node_feature: tf.Tensor,
    edge: tf.Tensor
) -> tf.Tensor:
    """Select rows of `node_feature` at the indices given by `edge`.

    The TensorFlow backend uses `tf.gather` directly. Other backends pad
    `edge` with trailing singleton axes up to the rank of `node_feature` and
    use `keras.ops.take_along_axis` along axis 0.
    """
    if backend.backend() != 'tensorflow':
        missing_axes = len(keras.ops.shape(node_feature)) - len(keras.ops.shape(edge))
        for _ in range(missing_axes):
            edge = keras.ops.expand_dims(edge, axis=-1)
        return keras.ops.take_along_axis(node_feature, edge, axis=0)
    return tf.gather(node_feature, edge)
|
|
18
|
+
|
|
19
|
+
def aggregate(
    node_feature: tf.Tensor,
    edge: tf.Tensor,
    num_nodes: tf.Tensor
) -> tf.Tensor:
    """Sum the rows of `node_feature` into `num_nodes` buckets keyed by `edge`."""
    summed = keras.ops.segment_sum(
        data=node_feature,
        segment_ids=edge,
        num_segments=num_nodes,
    )
    return summed
|
|
25
|
+
|
|
26
|
+
def propagate(
    node_feature: tf.Tensor,
    edge_source: tf.Tensor,
    edge_target: tf.Tensor,
    edge_feature: tf.Tensor | None = None,
    edge_weight: tf.Tensor | None = None,
) -> tf.Tensor:
    """Pass messages from source nodes to target nodes and sum them.

    For every edge, the message is the source node's feature, optionally
    multiplied by `edge_weight` and then offset by `edge_feature`. Messages
    are summed into their target nodes.

    Args:
        node_feature: Node features; the first axis indexes nodes.
        edge_source: Source-node index of every edge.
        edge_target: Target-node index of every edge.
        edge_feature: Optional per-edge features added to each message.
        edge_weight: Optional per-edge weights multiplied into each message.

    Returns:
        Aggregated messages with one row per node (same first-axis size as
        `node_feature`).
    """
    num_nodes = keras.ops.shape(node_feature)[0]

    messages = gather(node_feature, edge_source)

    if edge_weight is not None:
        messages *= edge_weight

    if edge_feature is not None:
        messages += edge_feature

    # Aggregate the per-edge messages (not the raw node features, which would
    # discard the messages and mismatch `edge_target` in length).
    return aggregate(messages, edge_target, num_nodes)
|
|
44
|
+
|
|
45
|
+
def scatter_update(
    inputs: tf.Tensor,
    indices: tf.Tensor,
    updates: tf.Tensor,
) -> tf.Tensor:
    """Write `updates` into `inputs` at `indices` via `keras.ops.scatter_update`.

    A boolean `indices` mask is first turned into a matrix of the coordinates
    of its True entries. Afterwards, `indices` gets trailing singleton axes
    until its rank reaches that of `inputs`.
    NOTE(review): an integer 1-D `indices` with 1-D `inputs` is left as-is
    (no trailing index-depth axis is added) — confirm that case is intended.
    """
    if indices.dtype == tf.bool:
        coordinates = keras.ops.where(indices)
        indices = keras.ops.stack(coordinates, axis=-1)
    rank_gap = len(keras.ops.shape(inputs)) - len(keras.ops.shape(indices))
    for _ in range(rank_gap):
        indices = keras.ops.expand_dims(indices, axis=-1)
    return keras.ops.scatter_update(inputs, indices, updates)
|
|
57
|
+
|
|
58
|
+
def edge_softmax(
    score: tf.Tensor,
    edge_target: tf.Tensor
) -> tf.Tensor:
    """Softmax of edge `score`s, normalized over edges sharing a target node."""
    # Number of segments is max target index + 1, or 0 when there are no edges.
    has_edges = keras.ops.shape(edge_target)[0] > 0
    num_segments = keras.ops.cond(
        has_edges,
        lambda: keras.ops.maximum(keras.ops.max(edge_target) + 1, 1),
        lambda: 0
    )
    # Subtract each target's maximum score before exponentiating, for stability.
    segment_peak = keras.ops.segment_max(
        score, edge_target, num_segments, sorted=False
    )
    exp_score = keras.ops.exp(score - gather(segment_peak, edge_target))
    segment_total = keras.ops.segment_sum(
        exp_score, edge_target, num_segments, sorted=False
    )
    return exp_score / gather(segment_total, edge_target)
|
|
77
|
+
|
|
78
|
+
def segment_mean(
    data: tf.Tensor,
    segment_ids: tf.Tensor,
    num_segments: int | None = None,
    sorted: bool = False,
) -> tf.Tensor:
    """Mean of `data` rows per segment.

    Both backends follow `tf.math.unsorted_segment_mean`: empty segments
    yield 0, and `data` of any rank is supported.

    Args:
        data: Values to average; the first axis is segmented.
        segment_ids: Segment index for each row of `data`.
        num_segments: Number of output segments. Defaults to
            `max(segment_ids) + 1`.
        sorted: Forwarded to `keras.ops.segment_sum` on non-TF backends.

    Returns:
        A tensor with `num_segments` rows holding per-segment means.
    """
    if num_segments is None:
        num_segments = keras.ops.max(segment_ids) + 1
    if backend.backend() == 'tensorflow':
        return tf.math.unsorted_segment_mean(
            data=data,
            segment_ids=segment_ids,
            num_segments=num_segments
        )
    x = keras.ops.segment_sum(
        data=data,
        segment_ids=segment_ids,
        num_segments=num_segments,
        sorted=sorted
    )
    sizes = keras.ops.cast(
        keras.ops.bincount(segment_ids, minlength=num_segments),
        dtype=x.dtype
    )
    # Clamp counts to 1 so empty segments give 0 instead of 0/0 = NaN.
    sizes = keras.ops.maximum(sizes, 1)
    # Shape counts as (num_segments, 1, ..., 1) so they broadcast against `x`
    # for any rank. `sizes[:, None]` mis-broadcast 1-D data to (S, S).
    extra_axes = len(keras.ops.shape(x)) - 1
    sizes = keras.ops.reshape(sizes, (-1,) + (1,) * extra_axes)
    return x / sizes
|
|
103
|
+
|
|
104
|
+
def gaussian(
    x: tf.Tensor,
    mean: tf.Tensor,
    std: tf.Tensor
) -> tf.Tensor:
    """Normal probability density with parameters `mean`, `std`, evaluated at `x`.

    `mean` and `std` get leading singleton axes until `mean` matches the rank
    of `x`, so they broadcast over `x`.
    """
    for _ in range(len(keras.ops.shape(x)) - len(keras.ops.shape(mean))):
        mean = keras.ops.expand_dims(mean, axis=0)
        std = keras.ops.expand_dims(std, axis=0)
    standardized = (x - mean) / std
    normalizer = (2 * np.pi) ** 0.5 * std
    return keras.ops.exp(-0.5 * standardized ** 2) / normalizer
|
|
116
|
+
|
|
117
|
+
def euclidean_distance(
    x1: tf.Tensor,
    x2: tf.Tensor,
    axis: int = -1
) -> tf.Tensor:
    """L2 distance between `x1` and `x2` along `axis` (kept with size 1)."""
    squared_diff = keras.ops.square(keras.ops.subtract(x1, x2))
    summed = keras.ops.sum(squared_diff, axis=axis, keepdims=True)
    return keras.ops.sqrt(summed)
|