molcraft 0.1.0rc10__py3-none-any.whl

molcraft/__init__.py ADDED
@@ -0,0 +1,18 @@
+ __version__ = '0.1.0rc10'
+
+ import os
+ os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
+
+ from molcraft import chem
+ from molcraft import features
+ from molcraft import descriptors
+ from molcraft import featurizers
+ from molcraft import layers
+ from molcraft import models
+ from molcraft import ops
+ from molcraft import records
+ from molcraft import tensors
+ from molcraft import callbacks
+ from molcraft import datasets
+ from molcraft import losses
+ from molcraft import trainers
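
Note: `__init__.py` sets `TF_CPP_MIN_LOG_LEVEL` to `"3"` before any submodule (and therefore TensorFlow) is imported, which suppresses TensorFlow's C++ log messages, and then eagerly imports every submodule. A minimal usage sketch, assuming nothing beyond the module names listed above:

```python
import molcraft

print(molcraft.__version__)  # '0.1.0rc10'

# Because __init__.py imports every submodule eagerly, each one is
# reachable as an attribute of the top-level package:
print(molcraft.callbacks)  # the module added in molcraft/callbacks.py below
```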
molcraft/callbacks.py ADDED
@@ -0,0 +1,100 @@
+ import warnings
+ import keras
+ import numpy as np
+
+
6
+ class TensorBoard(keras.callbacks.TensorBoard):
+
+     def _log_weights(self, epoch):
+         with self._train_writer.as_default():
+             for layer in self.model.layers:
+                 for weight in layer.weights:
+                     # Use weight.path instead of weight.name to distinguish
+                     # weights of different layers.
+                     histogram_weight_name = weight.path + "/histogram"
+                     self.summary.histogram(
+                         histogram_weight_name, weight, step=epoch
+                     )
+                     if self.write_images:
+                         image_weight_name = weight.path + "/image"
+                         self._log_weight_as_image(
+                             weight, image_weight_name, epoch
+                         )
+             self._train_writer.flush()
24
+
+
+ class LearningRateDecay(keras.callbacks.LearningRateScheduler):
+
+     def __init__(self, rate: float, delay: int = 0, **kwargs):
+
+         def lr_schedule(epoch: int, lr: float):
+             if epoch < delay:
+                 return float(lr)
+             return float(lr * keras.ops.exp(-rate))
+
+         super().__init__(schedule=lr_schedule, **kwargs)
+
+
38
+ class Rollback(keras.callbacks.Callback):
+     """Rollback callback.
+
+     Currently, this callback simply restores the model and (optionally) the optimizer
+     variables if the current loss deviates too much from the best observed loss.
+
+     This callback might be useful in situations where the loss tends to spike and put
+     the model in an undesired/problematic high-loss parameter space.
+
+     Args:
+         tolerance (float):
+             The threshold for when the restoration is triggered. The deviation is
+             calculated as follows: (current_loss - best_loss) / best_loss.
+         rollback_optimizer (bool):
+             Whether to also save and restore the optimizer variables.
+     """
52
+
+     def __init__(
+         self,
+         tolerance: float = 0.5,
+         rollback_optimizer: bool = True,
+     ):
+         super().__init__()
+         self.tolerance = tolerance
+         self.rollback_optimizer = rollback_optimizer
+
+     def on_train_begin(self, logs=None):
+         self._rollback_weights = self._get_model_vars()
+         if self.rollback_optimizer:
+             self._rollback_optimizer_vars = self._get_optimizer_vars()
+         self._rollback_loss = float('inf')
+
+     def on_epoch_end(self, epoch: int, logs: dict = None):
+         current_loss = logs.get('val_loss', logs.get('loss'))
+
+         if np.isnan(current_loss) or np.isinf(current_loss):
+             # Roll back the model because of a nan or inf loss.
+             self._rollback()
+             return
+
+         deviation = (current_loss - self._rollback_loss) / self._rollback_loss
+         if deviation > self.tolerance:
+             # Roll back the model because of a large loss deviation.
+             self._rollback()
+             return
+
+         if current_loss < self._rollback_loss:
+             self._save_state(current_loss)
+
+     def _save_state(self, current_loss: float) -> None:
+         self._rollback_loss = current_loss
+         self._rollback_weights = self._get_model_vars()
+         if self.rollback_optimizer:
+             self._rollback_optimizer_vars = self._get_optimizer_vars()
+
+     def _rollback(self) -> None:
+         self.model.set_weights(self._rollback_weights)
+         if self.rollback_optimizer:
+             self.model.optimizer.set_weights(self._rollback_optimizer_vars)
+
+     def _get_optimizer_vars(self):
+         return [v.numpy() for v in self.model.optimizer.variables]
+
+     def _get_model_vars(self):
+         return self.model.get_weights()
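
For illustration, a hedged usage sketch of the three callbacks with a plain Keras model; the model, data, and log directory are placeholders and not part of molcraft:

```python
import keras
import numpy as np
from molcraft import callbacks

# Placeholder model and data; any compiled Keras model works the same way.
model = keras.Sequential([keras.Input(shape=(8,)), keras.layers.Dense(1)])
model.compile(optimizer="adam", loss="mse")
x, y = np.random.rand(64, 8), np.random.rand(64, 1)

model.fit(
    x,
    y,
    epochs=20,
    callbacks=[
        # Same interface as keras.callbacks.TensorBoard; weight histograms
        # are tagged by weight.path instead of weight.name.
        callbacks.TensorBoard(log_dir="/tmp/tb", histogram_freq=1),
        # After `delay` epochs the learning rate is multiplied by exp(-rate)
        # once per epoch: lr(n) = lr0 * exp(-rate * (n - delay)) for n >= delay.
        callbacks.LearningRateDecay(rate=0.05, delay=5),
        # Restore the best weights (and optimizer variables) whenever the
        # epoch loss exceeds the best observed loss by more than 50%,
        # or becomes nan/inf.
        callbacks.Rollback(tolerance=0.5, rollback_optimizer=True),
    ],
)
```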