libinephany 1.1.4__py3-none-any.whl → 1.1.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- libinephany/observations/observers/base_observers.py +2 -1
- libinephany/observations/observers/global_observers/__init__.py +2 -1
- libinephany/observations/observers/global_observers/progress_observers.py +74 -1
- {libinephany-1.1.4.dist-info → libinephany-1.1.6.dist-info}/METADATA +1 -1
- {libinephany-1.1.4.dist-info → libinephany-1.1.6.dist-info}/RECORD +8 -8
- {libinephany-1.1.4.dist-info → libinephany-1.1.6.dist-info}/WHEEL +0 -0
- {libinephany-1.1.4.dist-info → libinephany-1.1.6.dist-info}/licenses/LICENSE +0 -0
- {libinephany-1.1.4.dist-info → libinephany-1.1.6.dist-info}/top_level.txt +0 -0
@@ -261,7 +261,8 @@ class Observer(ABC):
|
|
261
261
|
f"{self.invalid_observation_threshold})! Ratio: {inf_ratio:.2%} ({inf_count}/{len(observations)})"
|
262
262
|
)
|
263
263
|
validate = False
|
264
|
-
|
264
|
+
|
265
|
+
logger.trace(
|
265
266
|
f'Observer: {self.__class__.__name__} observation validation {"passed" if validate else "failed"}!'
|
266
267
|
)
|
267
268
|
|
@@ -47,7 +47,7 @@ from .model_observers import (
|
|
47
47
|
NumberOfLayers,
|
48
48
|
NumberOfParameters,
|
49
49
|
)
|
50
|
-
from .progress_observers import EpochsCompleted, ProgressAtEachCheckpoint, TrainingProgress
|
50
|
+
from .progress_observers import EpochsCompleted, ProgressAtEachCheckpoint, StagnationObserver, TrainingProgress
|
51
51
|
|
52
52
|
__all__ = [
|
53
53
|
InitialHyperparameters.__name__,
|
@@ -84,4 +84,5 @@ __all__ = [
|
|
84
84
|
CosineSimilarityObserverOfGradientAndMomentum.__name__,
|
85
85
|
CosineSimilarityObserverOfGradientAndUpdate.__name__,
|
86
86
|
CosineSimilarityOfGradientAndParameter.__name__,
|
87
|
+
StagnationObserver.__name__,
|
87
88
|
]
|
@@ -3,7 +3,7 @@
|
|
3
3
|
# IMPORTS
|
4
4
|
#
|
5
5
|
# ======================================================================================================================
|
6
|
-
|
6
|
+
import math
|
7
7
|
from typing import Any
|
8
8
|
|
9
9
|
from libinephany.observations.observation_utils import StatisticStorageTypes
|
@@ -146,3 +146,76 @@ class ProgressAtEachCheckpoint(LHOPTCheckpointBaseObserver):
|
|
146
146
|
return current_progress
|
147
147
|
else:
|
148
148
|
return self.last_value
|
149
|
+
|
150
|
+
|
151
|
+
class StagnationObserver(GlobalObserver):
|
152
|
+
|
153
|
+
def __init__(self, stagnation_threshold: float = 0.0, log_transform: bool = True, **kwargs):
|
154
|
+
"""
|
155
|
+
:param stagnation_threshold: The loss improvement threshold. Default is 0.0. This is the amount the validation
|
156
|
+
loss must improve by over the best validation loss to reset the stagnation counter.
|
157
|
+
We assume that improvement means loss decreases regardless of the sign of the loss, which means the best loss is
|
158
|
+
negative infinity.
|
159
|
+
:param log_transform: Whether to log transform the stagnation counter. Default is True.
|
160
|
+
:param kwargs: Miscellaneous keyword arguments.
|
161
|
+
"""
|
162
|
+
super().__init__(**kwargs)
|
163
|
+
self.best_validation_loss: float | None = None
|
164
|
+
self.stagnation_counter: int = 0
|
165
|
+
self.stagnation_threshold = stagnation_threshold
|
166
|
+
self.log_transform = log_transform
|
167
|
+
|
168
|
+
@property
|
169
|
+
def can_standardize(self) -> bool:
|
170
|
+
"""
|
171
|
+
:return: Whether the observation can be standardized.
|
172
|
+
"""
|
173
|
+
|
174
|
+
return False
|
175
|
+
|
176
|
+
def _get_observation_format(self) -> StatisticStorageTypes:
|
177
|
+
"""
|
178
|
+
:return: Format the observation returns data in. Must be one of the enum attributes in the StatisticStorageTypes
|
179
|
+
enumeration class.
|
180
|
+
"""
|
181
|
+
|
182
|
+
return StatisticStorageTypes.FLOAT
|
183
|
+
|
184
|
+
def _observe(
|
185
|
+
self,
|
186
|
+
observation_inputs: ObservationInputs,
|
187
|
+
hyperparameter_states: HyperparameterStates,
|
188
|
+
tracked_statistics: dict[str, dict[str, float | TensorStatistics]],
|
189
|
+
action_taken: float | int | None,
|
190
|
+
) -> float | int | list[int | float] | TensorStatistics:
|
191
|
+
"""
|
192
|
+
:param observation_inputs: Observation input metrics not calculated with statistic trackers.
|
193
|
+
:param hyperparameter_states: HyperparameterStates that manages the hyperparameters.
|
194
|
+
:param tracked_statistics: Dictionary mapping statistic tracker class names to dictionaries mapping module
|
195
|
+
names to floats or TensorStatistic models.
|
196
|
+
:param action_taken: Action taken by the agent this class instance is assigned to.
|
197
|
+
:return: Single float/int, list of floats/ints or TensorStatistics model to add to the observation vector.
|
198
|
+
"""
|
199
|
+
# Validate input values, if loss is invalid, increment stagnation counter
|
200
|
+
if math.isnan(observation_inputs.validation_loss) or math.isinf(observation_inputs.validation_loss):
|
201
|
+
self.stagnation_counter += 1
|
202
|
+
else:
|
203
|
+
if self.best_validation_loss is None:
|
204
|
+
self.best_validation_loss = observation_inputs.validation_loss
|
205
|
+
self.stagnation_counter = 0
|
206
|
+
else:
|
207
|
+
improvement = self.best_validation_loss - observation_inputs.validation_loss
|
208
|
+
if improvement > self.stagnation_threshold:
|
209
|
+
self.best_validation_loss = observation_inputs.validation_loss
|
210
|
+
self.stagnation_counter = 0
|
211
|
+
else:
|
212
|
+
self.stagnation_counter += 1
|
213
|
+
return self.stagnation_counter if not self.log_transform else math.log(self.stagnation_counter + 1)
|
214
|
+
|
215
|
+
def get_required_trackers(self) -> dict[str, dict[str, Any] | None]:
|
216
|
+
"""
|
217
|
+
:return: Dictionary mapping statistic tracker class names to kwargs for the class or None if no kwargs are
|
218
|
+
needed.
|
219
|
+
"""
|
220
|
+
|
221
|
+
return {}
|
@@ -8,17 +8,17 @@ libinephany/observations/pipeline_coordinator.py,sha256=mLfaHhkXVhMp9w5jWIAL3jPy
|
|
8
8
|
libinephany/observations/statistic_manager.py,sha256=LLg1zSxnJr2oQQepYla3qoUuRy10rsthr9jta4wEbnc,8956
|
9
9
|
libinephany/observations/statistic_trackers.py,sha256=nBxOxB9NHHr6OZylmgPpY3AXOr7kAJ0MF899YbAP_oc,48201
|
10
10
|
libinephany/observations/observers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
11
|
-
libinephany/observations/observers/base_observers.py,sha256=
|
11
|
+
libinephany/observations/observers/base_observers.py,sha256=THfOE4ozP_LG1UfsiuJkH6dhqp4VOpXY4NmpVlAa-kY,19886
|
12
12
|
libinephany/observations/observers/local_observers.py,sha256=HvOQhVEANQJjAl6Ln_YuP0-fNt_Xkh8ixlkb-A5JqDM,46318
|
13
13
|
libinephany/observations/observers/observer_containers.py,sha256=VNyqGgxYJ4r49Msp_kk-POgicb-_5w54twuT1qfNdxw,9562
|
14
|
-
libinephany/observations/observers/global_observers/__init__.py,sha256=
|
14
|
+
libinephany/observations/observers/global_observers/__init__.py,sha256=H8WHtWyBmjBQkaoy6fxkgYCkr-FEKJXgIRpVnw_ThZU,3071
|
15
15
|
libinephany/observations/observers/global_observers/base_classes.py,sha256=Q7OblhmKscypTs9JBepSQwo6ljjOdPKTU9kbpuhq_W4,7800
|
16
16
|
libinephany/observations/observers/global_observers/constants.py,sha256=TDQM_sGU8Swze794oB4TBaXFjSddt0OBhYPVhrXQ9Ko,1654
|
17
17
|
libinephany/observations/observers/global_observers/gradient_observers.py,sha256=j9uX7043ic06W6vb6vDt_PH2a5WtLFCdKQZ1JQGT24Q,19531
|
18
18
|
libinephany/observations/observers/global_observers/hyperparameter_observers.py,sha256=5Av8FgwWBJtcn4gpDzPdOTlKOOYy2lVEHS_gt5Sz7xo,15334
|
19
19
|
libinephany/observations/observers/global_observers/loss_observers.py,sha256=Kf943FiuYWuWvjmhgmp3TGyIQoZ27ZKJcxfTBwXS-gA,17761
|
20
20
|
libinephany/observations/observers/global_observers/model_observers.py,sha256=SGWXrmTdgp0kHvEvDSF7d3v1FEcK1sQDMPLQ8Wy3qv4,29306
|
21
|
-
libinephany/observations/observers/global_observers/progress_observers.py,sha256=
|
21
|
+
libinephany/observations/observers/global_observers/progress_observers.py,sha256=vJwue9tY6XjQKU_79ZAZMnC4HZjR8Y-AGyZpgEKiDbU,9284
|
22
22
|
libinephany/observations/post_processors/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
23
23
|
libinephany/observations/post_processors/postprocessors.py,sha256=43_e5UaDPr2KbAvqc_w3wLqnlm7bgRjqgCtyQ95-8cM,5913
|
24
24
|
libinephany/pydantic_models/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
@@ -57,8 +57,8 @@ libinephany/utils/typing.py,sha256=rGbaPO3MaUndsWiC_wHzReD_TOLYqb43i01pKN-j7Xs,6
|
|
57
57
|
libinephany/web_apps/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
58
58
|
libinephany/web_apps/error_logger.py,sha256=QpspO726Uoyyr6lBEEb3Q9XqhVOXUM4AaYE7vbnk31c,18153
|
59
59
|
libinephany/web_apps/web_app_utils.py,sha256=qiq_lasPipgN1RgRudPJc342kYci8O_4RqppxmIX8NY,4095
|
60
|
-
libinephany-1.1.
|
61
|
-
libinephany-1.1.
|
62
|
-
libinephany-1.1.
|
63
|
-
libinephany-1.1.
|
64
|
-
libinephany-1.1.
|
60
|
+
libinephany-1.1.6.dist-info/licenses/LICENSE,sha256=pogfDoMBP07ehIOvWymuWIar8pg2YLUhqOHsJQU3wdc,9250
|
61
|
+
libinephany-1.1.6.dist-info/METADATA,sha256=js8dpqAYAwo1UTFR7sOReZMQ25IWrX4EucCdDfEQFgc,8389
|
62
|
+
libinephany-1.1.6.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
63
|
+
libinephany-1.1.6.dist-info/top_level.txt,sha256=bYAOXQdJgIoLkO2Ui0kxe7pSYegS_e38u0dMscd7COQ,12
|
64
|
+
libinephany-1.1.6.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|