libinephany 1.1.4__py3-none-any.whl → 1.1.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -261,7 +261,8 @@ class Observer(ABC):
261
261
  f"{self.invalid_observation_threshold})! Ratio: {inf_ratio:.2%} ({inf_count}/{len(observations)})"
262
262
  )
263
263
  validate = False
264
- logger.debug(
264
+
265
+ logger.trace(
265
266
  f'Observer: {self.__class__.__name__} observation validation {"passed" if validate else "failed"}!'
266
267
  )
267
268
 
@@ -47,7 +47,7 @@ from .model_observers import (
47
47
  NumberOfLayers,
48
48
  NumberOfParameters,
49
49
  )
50
- from .progress_observers import EpochsCompleted, ProgressAtEachCheckpoint, TrainingProgress
50
+ from .progress_observers import EpochsCompleted, ProgressAtEachCheckpoint, StagnationObserver, TrainingProgress
51
51
 
52
52
  __all__ = [
53
53
  InitialHyperparameters.__name__,
@@ -84,4 +84,5 @@ __all__ = [
84
84
  CosineSimilarityObserverOfGradientAndMomentum.__name__,
85
85
  CosineSimilarityObserverOfGradientAndUpdate.__name__,
86
86
  CosineSimilarityOfGradientAndParameter.__name__,
87
+ StagnationObserver.__name__,
87
88
  ]
@@ -3,7 +3,7 @@
3
3
  # IMPORTS
4
4
  #
5
5
  # ======================================================================================================================
6
-
6
+ import math
7
7
  from typing import Any
8
8
 
9
9
  from libinephany.observations.observation_utils import StatisticStorageTypes
@@ -146,3 +146,76 @@ class ProgressAtEachCheckpoint(LHOPTCheckpointBaseObserver):
146
146
  return current_progress
147
147
  else:
148
148
  return self.last_value
149
+
150
+
151
+ class StagnationObserver(GlobalObserver):
152
+
153
+ def __init__(self, stagnation_threshold: float = 0.0, log_transform: bool = True, **kwargs):
154
+ """
155
+ :param stagnation_threshold: The loss improvement threshold. Default is 0.0. This is the amount the validation
156
+ loss must improve by over the best validation loss to reset the stagnation counter.
157
+ We assume that improvement means loss decreases regardless of the sign of the loss, which means the best loss is
158
+ negative infinity.
159
+ :param log_transform: Whether to log transform the stagnation counter. Default is True.
160
+ :param kwargs: Miscellaneous keyword arguments.
161
+ """
162
+ super().__init__(**kwargs)
163
+ self.best_validation_loss: float | None = None
164
+ self.stagnation_counter: int = 0
165
+ self.stagnation_threshold = stagnation_threshold
166
+ self.log_transform = log_transform
167
+
168
+ @property
169
+ def can_standardize(self) -> bool:
170
+ """
171
+ :return: Whether the observation can be standardized.
172
+ """
173
+
174
+ return False
175
+
176
+ def _get_observation_format(self) -> StatisticStorageTypes:
177
+ """
178
+ :return: Format the observation returns data in. Must be one of the enum attributes in the StatisticStorageTypes
179
+ enumeration class.
180
+ """
181
+
182
+ return StatisticStorageTypes.FLOAT
183
+
184
+ def _observe(
185
+ self,
186
+ observation_inputs: ObservationInputs,
187
+ hyperparameter_states: HyperparameterStates,
188
+ tracked_statistics: dict[str, dict[str, float | TensorStatistics]],
189
+ action_taken: float | int | None,
190
+ ) -> float | int | list[int | float] | TensorStatistics:
191
+ """
192
+ :param observation_inputs: Observation input metrics not calculated with statistic trackers.
193
+ :param hyperparameter_states: HyperparameterStates that manages the hyperparameters.
194
+ :param tracked_statistics: Dictionary mapping statistic tracker class names to dictionaries mapping module
195
+ names to floats or TensorStatistic models.
196
+ :param action_taken: Action taken by the agent this class instance is assigned to.
197
+ :return: Single float/int, list of floats/ints or TensorStatistics model to add to the observation vector.
198
+ """
199
+ # Validate input values, if loss is invalid, increment stagnation counter
200
+ if math.isnan(observation_inputs.validation_loss) or math.isinf(observation_inputs.validation_loss):
201
+ self.stagnation_counter += 1
202
+ else:
203
+ if self.best_validation_loss is None:
204
+ self.best_validation_loss = observation_inputs.validation_loss
205
+ self.stagnation_counter = 0
206
+ else:
207
+ improvement = self.best_validation_loss - observation_inputs.validation_loss
208
+ if improvement > self.stagnation_threshold:
209
+ self.best_validation_loss = observation_inputs.validation_loss
210
+ self.stagnation_counter = 0
211
+ else:
212
+ self.stagnation_counter += 1
213
+ return self.stagnation_counter if not self.log_transform else math.log(self.stagnation_counter + 1)
214
+
215
+ def get_required_trackers(self) -> dict[str, dict[str, Any] | None]:
216
+ """
217
+ :return: Dictionary mapping statistic tracker class names to kwargs for the class or None if no kwargs are
218
+ needed.
219
+ """
220
+
221
+ return {}
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: libinephany
3
- Version: 1.1.4
3
+ Version: 1.1.6
4
4
  Summary: Inephany library containing code commonly used by multiple subpackages.
5
5
  Author-email: Inephany <info@inephany.com>
6
6
  License: Apache 2.0
@@ -8,17 +8,17 @@ libinephany/observations/pipeline_coordinator.py,sha256=mLfaHhkXVhMp9w5jWIAL3jPy
8
8
  libinephany/observations/statistic_manager.py,sha256=LLg1zSxnJr2oQQepYla3qoUuRy10rsthr9jta4wEbnc,8956
9
9
  libinephany/observations/statistic_trackers.py,sha256=nBxOxB9NHHr6OZylmgPpY3AXOr7kAJ0MF899YbAP_oc,48201
10
10
  libinephany/observations/observers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
11
- libinephany/observations/observers/base_observers.py,sha256=nUgpcKGwKagH9zqdztkjcycxmLToVrq8ju0WB8zJtdk,19885
11
+ libinephany/observations/observers/base_observers.py,sha256=THfOE4ozP_LG1UfsiuJkH6dhqp4VOpXY4NmpVlAa-kY,19886
12
12
  libinephany/observations/observers/local_observers.py,sha256=HvOQhVEANQJjAl6Ln_YuP0-fNt_Xkh8ixlkb-A5JqDM,46318
13
13
  libinephany/observations/observers/observer_containers.py,sha256=VNyqGgxYJ4r49Msp_kk-POgicb-_5w54twuT1qfNdxw,9562
14
- libinephany/observations/observers/global_observers/__init__.py,sha256=87WHRPYmL0tVsaTKUd91pwEpCZtHPSKRQoba2VQjswA,3018
14
+ libinephany/observations/observers/global_observers/__init__.py,sha256=H8WHtWyBmjBQkaoy6fxkgYCkr-FEKJXgIRpVnw_ThZU,3071
15
15
  libinephany/observations/observers/global_observers/base_classes.py,sha256=Q7OblhmKscypTs9JBepSQwo6ljjOdPKTU9kbpuhq_W4,7800
16
16
  libinephany/observations/observers/global_observers/constants.py,sha256=TDQM_sGU8Swze794oB4TBaXFjSddt0OBhYPVhrXQ9Ko,1654
17
17
  libinephany/observations/observers/global_observers/gradient_observers.py,sha256=j9uX7043ic06W6vb6vDt_PH2a5WtLFCdKQZ1JQGT24Q,19531
18
18
  libinephany/observations/observers/global_observers/hyperparameter_observers.py,sha256=5Av8FgwWBJtcn4gpDzPdOTlKOOYy2lVEHS_gt5Sz7xo,15334
19
19
  libinephany/observations/observers/global_observers/loss_observers.py,sha256=Kf943FiuYWuWvjmhgmp3TGyIQoZ27ZKJcxfTBwXS-gA,17761
20
20
  libinephany/observations/observers/global_observers/model_observers.py,sha256=SGWXrmTdgp0kHvEvDSF7d3v1FEcK1sQDMPLQ8Wy3qv4,29306
21
- libinephany/observations/observers/global_observers/progress_observers.py,sha256=ypLk1_POAjA8V8rAaQ0B6Qh8m_04s9PAoXsw1KxVrLg,5872
21
+ libinephany/observations/observers/global_observers/progress_observers.py,sha256=vJwue9tY6XjQKU_79ZAZMnC4HZjR8Y-AGyZpgEKiDbU,9284
22
22
  libinephany/observations/post_processors/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
23
23
  libinephany/observations/post_processors/postprocessors.py,sha256=43_e5UaDPr2KbAvqc_w3wLqnlm7bgRjqgCtyQ95-8cM,5913
24
24
  libinephany/pydantic_models/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -57,8 +57,8 @@ libinephany/utils/typing.py,sha256=rGbaPO3MaUndsWiC_wHzReD_TOLYqb43i01pKN-j7Xs,6
57
57
  libinephany/web_apps/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
58
58
  libinephany/web_apps/error_logger.py,sha256=QpspO726Uoyyr6lBEEb3Q9XqhVOXUM4AaYE7vbnk31c,18153
59
59
  libinephany/web_apps/web_app_utils.py,sha256=qiq_lasPipgN1RgRudPJc342kYci8O_4RqppxmIX8NY,4095
60
- libinephany-1.1.4.dist-info/licenses/LICENSE,sha256=pogfDoMBP07ehIOvWymuWIar8pg2YLUhqOHsJQU3wdc,9250
61
- libinephany-1.1.4.dist-info/METADATA,sha256=jRCMmu7yAURPjD2ZRJnKxCMpZD9son6Xx4mQ1HBbWzc,8389
62
- libinephany-1.1.4.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
63
- libinephany-1.1.4.dist-info/top_level.txt,sha256=bYAOXQdJgIoLkO2Ui0kxe7pSYegS_e38u0dMscd7COQ,12
64
- libinephany-1.1.4.dist-info/RECORD,,
60
+ libinephany-1.1.6.dist-info/licenses/LICENSE,sha256=pogfDoMBP07ehIOvWymuWIar8pg2YLUhqOHsJQU3wdc,9250
61
+ libinephany-1.1.6.dist-info/METADATA,sha256=js8dpqAYAwo1UTFR7sOReZMQ25IWrX4EucCdDfEQFgc,8389
62
+ libinephany-1.1.6.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
63
+ libinephany-1.1.6.dist-info/top_level.txt,sha256=bYAOXQdJgIoLkO2Ui0kxe7pSYegS_e38u0dMscd7COQ,12
64
+ libinephany-1.1.6.dist-info/RECORD,,