molcraft 0.1.0a6__tar.gz → 0.1.0a7__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (31)
  1. {molcraft-0.1.0a6 → molcraft-0.1.0a7}/PKG-INFO +1 -1
  2. {molcraft-0.1.0a6 → molcraft-0.1.0a7}/molcraft/__init__.py +1 -1
  3. molcraft-0.1.0a7/molcraft/callbacks.py +93 -0
  4. {molcraft-0.1.0a6 → molcraft-0.1.0a7}/molcraft/chem.py +33 -17
  5. {molcraft-0.1.0a6 → molcraft-0.1.0a7}/molcraft/conformers.py +0 -4
  6. {molcraft-0.1.0a6 → molcraft-0.1.0a7}/molcraft/layers.py +68 -45
  7. {molcraft-0.1.0a6 → molcraft-0.1.0a7}/molcraft/models.py +16 -1
  8. {molcraft-0.1.0a6 → molcraft-0.1.0a7}/molcraft/ops.py +2 -2
  9. {molcraft-0.1.0a6 → molcraft-0.1.0a7}/molcraft.egg-info/PKG-INFO +1 -1
  10. {molcraft-0.1.0a6 → molcraft-0.1.0a7}/tests/test_models.py +60 -3
  11. molcraft-0.1.0a6/molcraft/callbacks.py +0 -33
  12. {molcraft-0.1.0a6 → molcraft-0.1.0a7}/LICENSE +0 -0
  13. {molcraft-0.1.0a6 → molcraft-0.1.0a7}/README.md +0 -0
  14. {molcraft-0.1.0a6 → molcraft-0.1.0a7}/molcraft/datasets.py +0 -0
  15. {molcraft-0.1.0a6 → molcraft-0.1.0a7}/molcraft/descriptors.py +0 -0
  16. {molcraft-0.1.0a6 → molcraft-0.1.0a7}/molcraft/features.py +0 -0
  17. {molcraft-0.1.0a6 → molcraft-0.1.0a7}/molcraft/featurizers.py +0 -0
  18. {molcraft-0.1.0a6 → molcraft-0.1.0a7}/molcraft/losses.py +0 -0
  19. {molcraft-0.1.0a6 → molcraft-0.1.0a7}/molcraft/records.py +0 -0
  20. {molcraft-0.1.0a6 → molcraft-0.1.0a7}/molcraft/tensors.py +0 -0
  21. {molcraft-0.1.0a6 → molcraft-0.1.0a7}/molcraft.egg-info/SOURCES.txt +0 -0
  22. {molcraft-0.1.0a6 → molcraft-0.1.0a7}/molcraft.egg-info/dependency_links.txt +0 -0
  23. {molcraft-0.1.0a6 → molcraft-0.1.0a7}/molcraft.egg-info/requires.txt +0 -0
  24. {molcraft-0.1.0a6 → molcraft-0.1.0a7}/molcraft.egg-info/top_level.txt +0 -0
  25. {molcraft-0.1.0a6 → molcraft-0.1.0a7}/pyproject.toml +0 -0
  26. {molcraft-0.1.0a6 → molcraft-0.1.0a7}/setup.cfg +0 -0
  27. {molcraft-0.1.0a6 → molcraft-0.1.0a7}/tests/test_chem.py +0 -0
  28. {molcraft-0.1.0a6 → molcraft-0.1.0a7}/tests/test_featurizers.py +0 -0
  29. {molcraft-0.1.0a6 → molcraft-0.1.0a7}/tests/test_layers.py +0 -0
  30. {molcraft-0.1.0a6 → molcraft-0.1.0a7}/tests/test_losses.py +0 -0
  31. {molcraft-0.1.0a6 → molcraft-0.1.0a7}/tests/test_tensors.py +0 -0
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: molcraft
- Version: 0.1.0a6
+ Version: 0.1.0a7
  Summary: Graph Neural Networks for Molecular Machine Learning
  Author-email: Alexander Kensert <alexander.kensert@gmail.com>
  License: MIT License
@@ -1,4 +1,4 @@
- __version__ = '0.1.0a6'
+ __version__ = '0.1.0a7'

  import os
  os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
@@ -0,0 +1,93 @@
+ import keras
+ import warnings
+ import numpy as np
+
+
+ class TensorBoard(keras.callbacks.TensorBoard):
+
+     def _log_weights(self, epoch):
+         with self._train_writer.as_default():
+             for layer in self.model.layers:
+                 for weight in layer.weights:
+                     # Use weight.path instead of weight.name to distinguish
+                     # weights of different layers.
+                     histogram_weight_name = weight.path + "/histogram"
+                     self.summary.histogram(
+                         histogram_weight_name, weight, step=epoch
+                     )
+                     if self.write_images:
+                         image_weight_name = weight.path + "/image"
+                         self._log_weight_as_image(
+                             weight, image_weight_name, epoch
+                         )
+             self._train_writer.flush()
+
+
+ class LearningRateDecay(keras.callbacks.LearningRateScheduler):
+
+     def __init__(self, rate: float, delay: int = 0, **kwargs):
+
+         def lr_schedule(epoch: int, lr: float):
+             if epoch < delay:
+                 return float(lr)
+             return float(lr * keras.ops.exp(-rate))
+
+         super().__init__(schedule=lr_schedule, **kwargs)
+
+
+ class Rollback(keras.callbacks.Callback):
+
+     def __init__(
+         self,
+         frequency: int = None,
+         tolerance: float = 0.5,
+         rollback_optimizer: bool = True,
+     ):
+         super().__init__()
+         self.frequency = frequency or 1_000_000_000
+         self.tolerance = tolerance
+         self.rollback_optimizer = rollback_optimizer
+
+     def on_train_begin(self, logs=None):
+         self.rollback_weights = self._get_model_vars()
+         self.rollback_optimizer_vars = self._get_optimizer_vars()
+         self.rollback_loss = float('inf')
+
+     def on_epoch_end(self, epoch: int, logs: dict = None):
+         current_loss = logs.get('val_loss', logs.get('loss'))
+         deviation = (current_loss - self.rollback_loss) / self.rollback_loss
+
+         if np.isnan(current_loss) or np.isinf(current_loss):
+             self._rollback()
+             print("\nRolling back model, found nan or inf loss.\n")
+             return
+
+         if deviation > self.tolerance:
+             self._rollback()
+             print(f"\nRolling back model, found too large deviation: {deviation:.3f}\n")
+
+         if epoch and epoch % self.frequency == 0:
+             self._rollback()
+             print(f"\nRolling back model, {epoch} % {self.frequency} == 0\n")
+             return
+
+         if current_loss < self.rollback_loss:
+             self._save_state(current_loss)
+
+     def _save_state(self, current_loss: float) -> None:
+         self.rollback_loss = current_loss
+         self.rollback_weights = self._get_model_vars()
+         if self.rollback_optimizer:
+             self.rollback_optimizer_vars = self._get_optimizer_vars()
+
+     def _rollback(self) -> None:
+         self.model.set_weights(self.rollback_weights)
+         if self.rollback_optimizer:
+             self.model.optimizer.set_weights(self.rollback_optimizer_vars)
+
+     def _get_optimizer_vars(self):
+         return [v.numpy() for v in self.model.optimizer.variables]
+
+     def _get_model_vars(self):
+         return self.model.get_weights()
+
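Note: `Rollback` restores the last saved weights (and, optionally, optimizer variables) whenever the monitored loss becomes nan/inf, deviates from the best loss by more than `tolerance`, or at every `frequency` epochs. A minimal usage sketch of the new callbacks, assuming a compiled `GraphModel` named `model` and a training dataset `train_ds` (both hypothetical names):

    from molcraft import callbacks

    # Roll back if the loss turns nan/inf or deviates more than 50% from the
    # best loss so far; additionally force a rollback every 10th epoch.
    rollback = callbacks.Rollback(frequency=10, tolerance=0.5)

    # Hold the initial learning rate for 5 epochs, then decay it by a factor
    # of exp(-0.05) (roughly 5%) per epoch.
    decay = callbacks.LearningRateDecay(rate=0.05, delay=5)

    model.fit(train_ds, epochs=50, callbacks=[rollback, decay])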
@@ -400,7 +400,6 @@ def embed_conformers(
      mol: Mol,
      num_conformers: int,
      method: str = 'ETKDGv3',
-     force: bool = True,
      **kwargs
  ) -> None:
      available_embedding_methods = {
@@ -411,27 +410,39 @@ def embed_conformers(
          'srETKDGv3': rdDistGeom.srETKDGv3(),
          'KDG': rdDistGeom.KDG()
      }
-     default_embedding_method = 'ETKDGv3'
      mol = Mol(mol)
-     params = available_embedding_methods.get(method)
-     if params is None:
-         warn(
-             f"Could not find `method` {method}. "
-             f"Automatically setting method to {default_embedding_method}."
+     embedding_method = available_embedding_methods.get(method)
+     if embedding_method is None:
+         raise ValueError(
+             f'Could not find `method` {method!r}. Specify either of: '
+             '`ETDG`, `ETKDG`, `ETKDGv2`, `ETKDGv3`, `srETKDGv3` or `KDG`.'
          )
-         params = available_embedding_methods[default_embedding_method]
+
      for key, value in kwargs.items():
-         setattr(params, key, value)
+         setattr(embedding_method, key, value)

-     success = rdDistGeom.EmbedMultipleConfs(mol, numConfs=num_conformers, params=params)
+     success = rdDistGeom.EmbedMultipleConfs(
+         mol, numConfs=num_conformers, params=embedding_method
+     )
      if not len(success):
-         warning = 'Could not embed conformer(s).'
-         if not force:
-             warn(warning)
+         warn(
+             f'Could not embed conformer(s) for {mol.canonical_smiles!r} using the '
+             'specified method. Giving it another try with more permissive methods.'
+         )
+         max_attempts = (20 * mol.num_atoms)  # increasing it from 10xN to 20xN
+         for fallback_method in [method, 'ETDG', 'KDG']:
+             fallback_embedding_method = available_embedding_methods[fallback_method]
+             fallback_embedding_method.useRandomCoords = True
+             fallback_embedding_method.maxAttempts = max_attempts
+             success = rdDistGeom.EmbedMultipleConfs(
+                 mol, numConfs=num_conformers, params=fallback_embedding_method
+             )
+             if len(success):
+                 break
          else:
-             solution = ' Embedding a conformer (in 3D space) using (x, y) coordinates.'
-             warn(warning + solution)
-             rdDepictor.Compute2DCoords(mol)
+             raise RuntimeError(
+                 f'Could not embed conformer(s) for {mol.canonical_smiles!r}. '
+             )
      return mol


  def optimize_conformers(
@@ -445,6 +456,11 @@ def optimize_conformers(
      available_force_field_methods = [
          'MMFF', 'MMFF94', 'MMFF94s', 'UFF'
      ]
+     if method not in available_force_field_methods:
+         raise ValueError(
+             f'Could not find `method` {method!r}. Specify either of: '
+             '`UFF`, `MMFF`, `MMFF94` or `MMFF94s`.'
+         )
      mol = Mol(mol)
      try:
          if method.startswith('MMFF'):
@@ -469,7 +485,7 @@ def optimize_conformers(
      except RuntimeError as e:
          warn(
              f'{method} force field minimization raised {e}. '
-             '\nProceeding without force field minimization...'
+             '\nProceeding without force field minimization.'
          )
      return mol

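Note: with the `force` flag removed, both functions now validate `method` up front and raise a `ValueError` for unknown methods, and a failed embedding retries with `useRandomCoords` and a raised `maxAttempts` before raising a `RuntimeError`. A hedged sketch of the resulting behavior (constructing a `Mol` from a SMILES string is an assumption here, not shown in this diff):

    from molcraft import chem

    mol = chem.Mol('CCO')  # hypothetical construction from SMILES
    mol = chem.embed_conformers(mol, num_conformers=5, method='ETKDGv3')
    mol = chem.optimize_conformers(mol, method='MMFF94s')

    # Both of these now raise ValueError instead of warning and falling back:
    # chem.embed_conformers(mol, num_conformers=5, method='ETKDGv4')
    # chem.optimize_conformers(mol, method='GAFF')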
@@ -24,19 +24,16 @@ class ConformerEmbedder(ConformerProcessor):
          self,
          method: str = 'ETKDGv3',
          num_conformers: int = 5,
-         force: bool = True,
          **kwargs,
      ) -> None:
          self.method = method
          self.num_conformers = num_conformers
-         self.force = force
          self.kwargs = kwargs

      def get_config(self) -> dict:
          config = {
              'method': self.method,
              'num_conformers': self.num_conformers,
-             'force': self.force,
          }
          config.update({
              k: v for (k, v) in self.kwargs.items()
@@ -48,7 +45,6 @@ class ConformerEmbedder(ConformerProcessor):
              mol,
              method=self.method,
              num_conformers=self.num_conformers,
-             force=self.force,
              **self.kwargs,
          )

@@ -1231,7 +1231,7 @@ class EGConv3D(GraphConv):
      def __init__(
          self,
          units: int = 128,
-         activation: keras.layers.Activation | str | None = None,
+         activation: keras.layers.Activation | str | None = 'silu',
          use_bias: bool = True,
          normalize: bool = False,
          **kwargs
@@ -1251,31 +1251,52 @@ class EGConv3D(GraphConv):
              'which is required for Conv3D layers.'
          )
          self._has_edge_feature = 'feature' in spec.edge
-         self.message_fn = self.get_dense(self.units, activation=self._activation)
-         self.dense_position = self.get_dense(1, use_bias=False, kernel_initializer='zeros')
+         self._message_feedforward_intermediate = self.get_dense(
+             self.units, activation=self._activation
+         )
+         self._message_feedforward_final = self.get_dense(
+             self.units, activation=self._activation
+         )
+
+         self._coord_feedforward_intermediate = self.get_dense(
+             self.units, activation=self._activation
+         )
+         self._coord_feedforward_final = self.get_dense(
+             1, use_bias=False, activation='tanh'
+         )

          has_overridden_update = self.__class__.update != EGConv3D.update
          if not has_overridden_update:
-             self.update_fn = self.get_dense(self.units, activation=self._activation)
-             self.output_dense = self.get_dense(self.units)
+             self._feedforward_intermediate = self.get_dense(
+                 self.units, activation=self._activation
+             )
+             self._feedforward_output = self.get_dense(self.units)

      def message(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
          relative_node_coordinate = keras.ops.subtract(
              tensor.gather('coordinate', 'target'),
              tensor.gather('coordinate', 'source')
          )
-         euclidean_distance = keras.ops.sum(
-             keras.ops.square(
-                 relative_node_coordinate
-             ),
+         squared_distance = keras.ops.sum(
+             keras.ops.square(relative_node_coordinate),
              axis=-1,
              keepdims=True
          )
+
+         # For numerical stability (i.e., to prevent NaN losses), this implementation of `EGConv3D`
+         # either needs to apply a `tanh` activation to the output of `self._coord_feedforward_final`,
+         # or normalize `relative_node_coordinate` as follows:
+         #
+         #     norm = keras.ops.sqrt(squared_distance) + keras.backend.epsilon()
+         #     relative_node_coordinate /= norm
+         #
+         # For now, this implementation does the former.
+
          feature = keras.ops.concatenate(
              [
                  tensor.gather('feature', 'target'),
                  tensor.gather('feature', 'source'),
-                 euclidean_distance,
+                 squared_distance,
              ],
              axis=-1
          )
@@ -1287,10 +1308,15 @@ class EGConv3D(GraphConv):
              ],
              axis=-1
          )
-         message = self.message_fn(feature)
+         message = self._message_feedforward_final(
+             self._message_feedforward_intermediate(feature)
+         )
+
          relative_node_coordinate = keras.ops.multiply(
-             relative_node_coordinate,
-             self.dense_position(message)
+             relative_node_coordinate,
+             self._coord_feedforward_final(
+                 self._coord_feedforward_intermediate(message)
+             )
          )
          return tensor.update(
              {
@@ -1302,26 +1328,26 @@ class EGConv3D(GraphConv):
          )

      def aggregate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
-         # coefficient = keras.ops.bincount(
-         #     tensor.edge['source'],
-         #     minlength=tensor.num_nodes
-         # )
-         # coefficient = keras.ops.cast(
-         #     coefficient, tensor.node['coordinate'].dtype
-         # )
-         # coefficient = keras.ops.expand_dims(
-         #     keras.ops.divide_no_nan(1, coefficient), axis=1
-         # )
-
-         updated_coordinate = tensor.aggregate('relative_node_coordinate', mode='mean')  # * coefficient
-         updated_coordinate += tensor.node['coordinate']
-
+         coordinate = tensor.node['coordinate']
+         coordinate += tensor.aggregate('relative_node_coordinate', mode='mean')
+
+         # The original implementation seems to apply sum aggregation, which
+         # does not seem to work well for this implementation of `EGConv3D`,
+         # as it causes large output values and large initial losses. The
+         # magnitude of the aggregated values of a sum aggregation depends on
+         # the number of neighbors, which may be many and may differ from node
+         # to node (or graph to graph). Therefore, a mean aggregation is
+         # performed instead:
          aggregate = tensor.aggregate('message', mode='mean')
+
+         # Simply added to silence warning ('no gradients for variables ...')
+         aggregate += (0.0 * keras.ops.sum(coordinate))
+
          return tensor.update(
              {
                  'node': {
                      'aggregate': aggregate,
-                     'coordinate': updated_coordinate,
+                     'coordinate': coordinate,
                  },
                  'edge': {
                      'message': None,
@@ -1331,16 +1357,16 @@ class EGConv3D(GraphConv):
          )

      def update(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
-         updated_node_feature = self.update_fn(
-             keras.ops.concatenate(
-                 [
-                     tensor.node['aggregate'],
-                     tensor.node['feature']
-                 ],
-                 axis=-1
-             )
+         feature = keras.ops.concatenate(
+             [
+                 tensor.node['aggregate'],
+                 tensor.node['feature']
+             ],
+             axis=-1
+         )
+         updated_node_feature = self._feedforward_output(
+             self._feedforward_intermediate(feature)
          )
-         updated_node_feature = self.output_dense(updated_node_feature)
          return tensor.update(
              {
                  'node': {
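Note: the inline comment in `aggregate` above motivates the switch from sum- to mean-style message aggregation. A small self-contained NumPy illustration of why summed messages scale with node degree while averaged messages do not:

    import numpy as np

    messages = np.ones((1000, 128))   # 1000 incoming messages of magnitude ~1
    summed = messages.sum(axis=0)     # magnitude ~1000, grows with node degree
    averaged = messages.mean(axis=0)  # magnitude ~1, independent of degree
    assert summed[0] == 1000.0 and averaged[0] == 1.0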
@@ -1694,8 +1720,8 @@ class EdgeEmbedding(GraphLayer):
              mask = keras.ops.expand_dims(mask, -1)
              feature = keras.ops.where(mask, self._mask_feature, feature)
          elif self._allow_masking:
-             # Silence warning of 'no gradients for variables'
-             feature = feature + (self._mask_feature * 0.0)
+             # Simply added to silence warning ('no gradients for variables ...')
+             feature += (0.0 * self._mask_feature)

          if self._normalize:
              feature = self._norm(feature)
@@ -1999,14 +2025,11 @@ def Input(spec: tensors.GraphTensor.Spec) -> dict:
  for outer_field, data in spec.__dict__.items():
      inputs[outer_field] = {}
      for inner_field, nested_spec in data.items():
-         if inner_field in ['label', 'weight']:
+         if outer_field == 'context' and inner_field in ['label', 'weight']:
              # Remove context label and weight from the symbolic input
              # as a functional model is strict for what input can be passed.
-             # We want to be able to pass a graph with or without labels
-             # and sample weights. The __call__ method of the `GraphModel`
-             # temporarily pops label and weight to avoid errors.
-             if outer_field == 'context':
-                 continue
+             # (We want to train and predict with the model.)
+             continue
          kwargs = {
              'shape': nested_spec.shape[1:],
              'dtype': nested_spec.dtype,
@@ -415,6 +415,10 @@ class GraphModel(layers.GraphLayer, keras.models.Model):
      def predict_step(self, tensor: tensors.GraphTensor) -> np.ndarray:
          return self(tensor, training=False)

+     def compute_loss(self, x, y, y_pred, sample_weight=None):
+         y, y_pred, sample_weight = _maybe_reshape(y, y_pred, sample_weight)
+         return super().compute_loss(x, y, y_pred, sample_weight)
+
      def compute_metrics(self, x, y, y_pred, sample_weight=None) -> dict[str, float]:
          loss = self.compute_loss(x, y, y_pred, sample_weight)
          metric_results = {}
@@ -423,7 +427,7 @@ class GraphModel(layers.GraphLayer, keras.models.Model):
                  metric.update_state(loss)
                  metric_results[metric.name] = metric.result()
              else:
-                 metric.update_state(y, y_pred)
+                 metric.update_state(y, y_pred, sample_weight=sample_weight)
                  metric_results.update(metric.result())
          return metric_results

@@ -593,3 +597,14 @@ def _make_dataset(x: tensors.GraphTensor, batch_size: int):
          .batch(batch_size)
          .prefetch(-1)
      )
+
+ def _maybe_reshape(y, y_pred, sample_weight):
+     if (
+         sample_weight is not None and
+         len(keras.ops.shape(sample_weight)) == 2 and
+         sample_weight.shape == y_pred.shape
+     ):
+         y = keras.ops.reshape(y, [-1])
+         y_pred = keras.ops.reshape(y_pred, [-1])
+         sample_weight = keras.ops.reshape(sample_weight, [-1])
+     return y, y_pred, sample_weight
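Note: `_maybe_reshape` flattens `y`, `y_pred`, and `sample_weight` only when the sample weights are 2-D and shaped like the predictions, i.e. when one weight is supplied per output rather than per sample, presumably so each output is weighted individually by the built-in loss reduction. An illustrative sketch with invented shapes:

    import keras

    y = keras.ops.ones((32, 4))              # 4 targets per graph
    y_pred = keras.ops.ones((32, 4))
    sample_weight = keras.ops.ones((32, 4))  # one weight per target

    # The condition in _maybe_reshape holds, so all three tensors would be
    # flattened to shape (128,) before keras' built-in compute_loss runs.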
@@ -67,7 +67,7 @@ def edge_softmax(
      edge_target: tf.Tensor
  ) -> tf.Tensor:
      num_segments = keras.ops.cond(
-         keras.ops.shape(edge_target)[0] > 0,
+         keras.ops.greater(keras.ops.shape(edge_target)[0], 0),
          lambda: keras.ops.maximum(keras.ops.max(edge_target) + 1, 1),
          lambda: 0
      )
@@ -100,7 +100,7 @@ def segment_mean(
  ) -> tf.Tensor:
      if num_segments is None:
          num_segments = keras.ops.cond(
-             keras.ops.shape(segment_ids)[0] > 0,
+             keras.ops.greater(keras.ops.shape(segment_ids)[0], 0),
              lambda: keras.ops.max(segment_ids) + 1,
              lambda: 0
          )
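Note: both predicates now use `keras.ops.greater` instead of Python's `>`, which keeps the comparison a backend op; this is presumably safer for predicates consumed by `keras.ops.cond` under graph tracing, and is otherwise behavior-preserving. A tiny sketch:

    import keras

    n = keras.ops.shape(keras.ops.ones((3, 2)))[0]
    num = keras.ops.cond(
        keras.ops.greater(n, 0),  # backend boolean, traceable
        lambda: n,
        lambda: 0,
    )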
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: molcraft
- Version: 0.1.0a6
+ Version: 0.1.0a7
  Summary: Graph Neural Networks for Molecular Machine Learning
  Author-email: Alexander Kensert <alexander.kensert@gmail.com>
  License: MIT License
@@ -1,6 +1,8 @@
  import unittest
-
+ import tempfile
+ import shutil
  import keras
+ import numpy as np

  from molcraft import tensors
  from molcraft import layers
@@ -138,9 +140,64 @@ class TestModel(unittest.TestCase):
          self.assertTrue(isinstance(metrics, list))
          del model

-     # TODO: Write test for saving and loading model: model(tensor) == loaded_model(tensor)
      def test_saved_model(self):
-         pass
+
+         def get_model(tensor):
+             inputs = layers.Input(tensor.spec)
+             x = layers.NodeEmbedding(32)(inputs)
+             x = layers.EdgeEmbedding(32)(x)
+             x = layers.GTConv(32)(x)
+             x = layers.GTConv(32)(x)
+             x = layers.Readout('sum')(x)
+             outputs = keras.layers.Dense(1)(x)
+             return models.GraphModel(inputs, outputs)
+
+         @keras.saving.register_keras_serializable()
+         class Model(models.GraphModel):
+             def __init__(self, **kwargs):
+                 super().__init__(**kwargs)
+                 self.e1 = layers.NodeEmbedding(32)
+                 self.e2 = layers.EdgeEmbedding(32)
+                 self.c1 = layers.GTConv(32)
+                 self.c2 = layers.GTConv(32)
+                 self.r = layers.Readout('sum')
+                 self.d = keras.layers.Dense(1)
+             def propagate(self, tensor):
+                 return self.d(self.r(self.c2(self.c1(self.e2(self.e1(tensor))))))
+
+         example = self.tensors[-1]
+
+         tmp_dir = tempfile.mkdtemp()
+         tmp_file = tmp_dir + '/model.keras'
+         with self.subTest(functional_model=True):
+             model = get_model(example)
+             model.compile('adam', 'mse')
+             model.fit(example, verbose=0)
+             model.save(tmp_file)
+             loaded_model = models.load_model(tmp_file)
+             pred_1 = model.predict(example, verbose=0)
+             pred_2 = loaded_model.predict(example, verbose=0)
+             test_preds = np.all(pred_1.round(4) == pred_2.round(4))
+             self.assertTrue(test_preds)
+             test_vars = np.all([np.all((v1 == v2).numpy()) for (v1, v2) in zip(model.variables, loaded_model.variables)])
+             self.assertTrue(test_vars)
+             shutil.rmtree(tmp_dir)
+
+         tmp_dir = tempfile.mkdtemp()
+         tmp_file = tmp_dir + '/model.keras'
+         with self.subTest(functional_model=False):
+             model = Model()
+             model.compile('adam', 'mse')
+             model.fit(example, verbose=0)
+             model.save(tmp_file)
+             loaded_model = models.load_model(tmp_file)
+             pred_1 = model.predict(example, verbose=0)
+             pred_2 = loaded_model.predict(example, verbose=0)
+             test_preds = np.all(pred_1.round(4) == pred_2.round(4))
+             self.assertTrue(test_preds)
+             test_vars = np.all([np.all((v1 == v2).numpy()) for (v1, v2) in zip(model.variables, loaded_model.variables)])
+             self.assertTrue(test_vars)
+             shutil.rmtree(tmp_dir)

      def test_subclassed_model(self):

@@ -1,33 +0,0 @@
- import keras
-
-
- class TensorBoard(keras.callbacks.TensorBoard):
-
-     def _log_weights(self, epoch):
-         with self._train_writer.as_default():
-             for layer in self.model.layers:
-                 for weight in layer.weights:
-                     # Use weight.path instead of weight.name to distinguish
-                     # weights of different layers.
-                     histogram_weight_name = weight.path + "/histogram"
-                     self.summary.histogram(
-                         histogram_weight_name, weight, step=epoch
-                     )
-                     if self.write_images:
-                         image_weight_name = weight.path + "/image"
-                         self._log_weight_as_image(
-                             weight, image_weight_name, epoch
-                         )
-             self._train_writer.flush()
-
-
- class LearningRateDecay(keras.callbacks.LearningRateScheduler):
-
-     def __init__(self, rate: float, delay: int = 0, **kwargs):
-
-         def lr_schedule(epoch: int, lr: float):
-             if epoch < delay:
-                 return float(lr)
-             return float(lr * keras.ops.exp(-rate))
-
-         super().__init__(schedule=lr_schedule, **kwargs)