libinephany 1.0.0__py3-none-any.whl → 1.0.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- libinephany/observations/observers/base_observers.py +26 -0
- libinephany/observations/observers/global_observers/base_classes.py +50 -64
- libinephany/observations/observers/global_observers/constants.py +14 -1
- libinephany/observations/observers/global_observers/gradient_observers.py +55 -94
- libinephany/observations/observers/global_observers/hyperparameter_observers.py +26 -8
- libinephany/observations/observers/global_observers/loss_observers.py +26 -47
- libinephany/observations/observers/global_observers/model_observers.py +65 -25
- libinephany/observations/observers/local_observers.py +89 -52
- libinephany/observations/statistic_trackers.py +8 -10
- libinephany/utils/enums.py +6 -0
- {libinephany-1.0.0.dist-info → libinephany-1.0.2.dist-info}/METADATA +1 -1
- {libinephany-1.0.0.dist-info → libinephany-1.0.2.dist-info}/RECORD +15 -15
- {libinephany-1.0.0.dist-info → libinephany-1.0.2.dist-info}/WHEEL +0 -0
- {libinephany-1.0.0.dist-info → libinephany-1.0.2.dist-info}/licenses/LICENSE +0 -0
- {libinephany-1.0.0.dist-info → libinephany-1.0.2.dist-info}/top_level.txt +0 -0
@@ -44,6 +44,7 @@ class Observer(ABC):
|
|
44
44
|
observer_config: ObserverConfig,
|
45
45
|
should_standardize: bool = True,
|
46
46
|
include_statistics: list[str] | None = None,
|
47
|
+
include_hparams: list[str] | None = None,
|
47
48
|
**kwargs,
|
48
49
|
) -> None:
|
49
50
|
"""
|
@@ -52,6 +53,8 @@ class Observer(ABC):
|
|
52
53
|
:param should_standardize: Whether standardization should be applied to returned values.
|
53
54
|
:param include_statistics: If the observation uses the TensorStatistic model to return observations, names of the
|
54
55
|
fields in the model to include in returned observations.
|
56
|
+
:param include_hparams: If the observation uses the HyperparameterStates model to return observations, names of the
|
57
|
+
hyperparameters to include in returned observations.
|
55
58
|
:param kwargs: Miscellaneous keyword arguments.
|
56
59
|
"""
|
57
60
|
|
@@ -64,10 +67,17 @@ class Observer(ABC):
|
|
64
67
|
self.should_standardize = should_standardize and self.can_standardize
|
65
68
|
|
66
69
|
self.include_statistics: list[str] | None = None
|
70
|
+
self.include_hparams = include_hparams
|
67
71
|
|
68
72
|
if include_statistics is not None:
|
69
73
|
self.include_statistics = TensorStatistics.filter_include_statistics(include_statistics=include_statistics)
|
70
74
|
|
75
|
+
if self.requires_include_statistics and not self.include_statistics:
|
76
|
+
raise ValueError(f"{self.__class__.__name__} must be provided with include_statistics.")
|
77
|
+
|
78
|
+
if self.requires_include_hparams and not self.include_hparams:
|
79
|
+
raise ValueError(f"{self.__class__.__name__} must be provided with include_hparams.")
|
80
|
+
|
71
81
|
@final
|
72
82
|
@property
|
73
83
|
def in_training_mode(self) -> bool:
|
@@ -143,6 +153,22 @@ class Observer(ABC):
|
|
143
153
|
|
144
154
|
return True
|
145
155
|
|
156
|
+
@property
|
157
|
+
def requires_include_statistics(self) -> bool:
|
158
|
+
"""
|
159
|
+
:return: Whether the observation requires include_statistics to be provided.
|
160
|
+
"""
|
161
|
+
|
162
|
+
return False
|
163
|
+
|
164
|
+
@property
|
165
|
+
def requires_include_hparams(self) -> bool:
|
166
|
+
"""
|
167
|
+
:return: Whether the observation requires include_hparams to be provided.
|
168
|
+
"""
|
169
|
+
|
170
|
+
return False
|
171
|
+
|
146
172
|
@property
|
147
173
|
@abstractmethod
|
148
174
|
def standardizer_key_infix(self) -> str:
|
@@ -1,6 +1,6 @@
|
|
1
1
|
# ======================================================================================================================
|
2
2
|
#
|
3
|
-
#
|
3
|
+
# IMPORTS
|
4
4
|
#
|
5
5
|
# ======================================================================================================================
|
6
6
|
|
@@ -15,6 +15,12 @@ from libinephany.pydantic_models.schemas.observation_models import ObservationIn
|
|
15
15
|
from libinephany.pydantic_models.schemas.tensor_statistics import TensorStatistics
|
16
16
|
from libinephany.pydantic_models.states.hyperparameter_states import HyperparameterStates
|
17
17
|
|
18
|
+
# ======================================================================================================================
|
19
|
+
#
|
20
|
+
# CLASSES
|
21
|
+
#
|
22
|
+
# ======================================================================================================================
|
23
|
+
|
18
24
|
|
19
25
|
class LHOPTBaseObserver(GlobalObserver, ABC):
|
20
26
|
"""
|
@@ -33,13 +39,14 @@ class LHOPTBaseObserver(GlobalObserver, ABC):
|
|
33
39
|
:param kwargs: Other observation keyword arguments.
|
34
40
|
"""
|
35
41
|
super().__init__(**kwargs)
|
36
|
-
self.decay_factor = max(0.0, decay_factor)
|
37
|
-
self.time_window = max(1, time_window)
|
38
42
|
|
39
43
|
# Store time series data for CDF calculation
|
40
44
|
self._time_series: list[tuple[float, float]] = [] # (time, value) pairs
|
41
45
|
self._current_time: float = 0.0
|
42
46
|
|
47
|
+
self.decay_factor = max(0.0, decay_factor)
|
48
|
+
self.time_window = max(1, time_window)
|
49
|
+
|
43
50
|
@property
|
44
51
|
def can_standardize(self) -> bool:
|
45
52
|
"""
|
@@ -48,6 +55,28 @@ class LHOPTBaseObserver(GlobalObserver, ABC):
|
|
48
55
|
"""
|
49
56
|
return False
|
50
57
|
|
58
|
+
@staticmethod
|
59
|
+
def _compute_log_ratio(numerator: float, denominator: float) -> float:
|
60
|
+
"""
|
61
|
+
Compute the log ratio.
|
62
|
+
|
63
|
+
:param numerator: Numerator value
|
64
|
+
:param denominator: Denominator value
|
65
|
+
:return: Log ratio value
|
66
|
+
"""
|
67
|
+
# Calculate the ratio of numerator to denominator
|
68
|
+
invalid_denominator = math.isinf(denominator) or math.isnan(denominator)
|
69
|
+
|
70
|
+
if denominator <= LHOPT_CONSTANTS["ZERO_DIVISION_TOLERANCE"] or invalid_denominator:
|
71
|
+
return 0.0
|
72
|
+
|
73
|
+
ratio = numerator / denominator
|
74
|
+
|
75
|
+
if ratio <= 0:
|
76
|
+
return 0.0
|
77
|
+
|
78
|
+
return math.log(ratio)
|
79
|
+
|
51
80
|
def _get_observation_format(self) -> StatisticStorageTypes:
|
52
81
|
"""
|
53
82
|
:return: Format the observation returns data in. Must be one of the StatisticStorageTypes
|
@@ -68,18 +97,6 @@ class LHOPTBaseObserver(GlobalObserver, ABC):
|
|
68
97
|
"""Update the current time counter."""
|
69
98
|
self._current_time += 1.0
|
70
99
|
|
71
|
-
def get_required_trackers(self) -> dict[str, dict[str, Any] | None]:
|
72
|
-
"""
|
73
|
-
:return: Dictionary mapping statistic tracker class names to kwargs for the class or None if no kwargs are
|
74
|
-
needed.
|
75
|
-
"""
|
76
|
-
return {}
|
77
|
-
|
78
|
-
def reset(self) -> None:
|
79
|
-
"""Reset the observer by clearing the time series."""
|
80
|
-
self._time_series = []
|
81
|
-
self._current_time = 0.0
|
82
|
-
|
83
100
|
@abstractmethod
|
84
101
|
def _observe(
|
85
102
|
self,
|
@@ -94,27 +111,13 @@ class LHOPTBaseObserver(GlobalObserver, ABC):
|
|
94
111
|
:param tracked_statistics: Dictionary mapping statistic tracker class names to dictionaries mapping module
|
95
112
|
:param action_taken: Action taken by the agent this class instance is assigned to.
|
96
113
|
"""
|
97
|
-
raise NotImplementedError
|
98
|
-
|
99
|
-
def _compute_log_ratio(self, numerator: float, denominator: float) -> float:
|
100
|
-
"""
|
101
|
-
Compute the log ratio.
|
102
|
-
|
103
|
-
:param numerator: Numerator value
|
104
|
-
:param denominator: Denominator value
|
105
|
-
:return: Log ratio value
|
106
|
-
"""
|
107
|
-
# Calculate the ratio of numerator to denominator
|
108
114
|
|
109
|
-
|
110
|
-
return 0.0
|
115
|
+
...
|
111
116
|
|
112
|
-
|
113
|
-
|
114
|
-
|
115
|
-
|
116
|
-
|
117
|
-
return math.log(ratio)
|
117
|
+
def reset(self) -> None:
|
118
|
+
"""Reset the observer by clearing the time series."""
|
119
|
+
self._time_series = []
|
120
|
+
self._current_time = 0.0
|
118
121
|
|
119
122
|
|
120
123
|
class LHOPTCheckpointBaseObserver(GlobalObserver, ABC):
|
@@ -128,8 +131,10 @@ class LHOPTCheckpointBaseObserver(GlobalObserver, ABC):
|
|
128
131
|
:param kwargs: Miscellaneous keyword arguments.
|
129
132
|
"""
|
130
133
|
super().__init__(**kwargs)
|
131
|
-
|
134
|
+
|
132
135
|
self._history: list[float] = []
|
136
|
+
|
137
|
+
self.checkpoint_interval = checkpoint_interval
|
133
138
|
self.last_value: float | None = None
|
134
139
|
|
135
140
|
@property
|
@@ -175,18 +180,6 @@ class LHOPTCheckpointBaseObserver(GlobalObserver, ABC):
|
|
175
180
|
if self.last_value is None:
|
176
181
|
self.last_value = value
|
177
182
|
|
178
|
-
def get_required_trackers(self) -> dict[str, dict[str, Any] | None]:
|
179
|
-
"""
|
180
|
-
:return: Dictionary mapping statistic tracker class names to kwargs for the class or None if no kwargs are
|
181
|
-
needed.
|
182
|
-
"""
|
183
|
-
return {}
|
184
|
-
|
185
|
-
def reset(self) -> None:
|
186
|
-
"""Reset the observer by clearing history."""
|
187
|
-
self._history = []
|
188
|
-
self.last_value = None
|
189
|
-
|
190
183
|
@abstractmethod
|
191
184
|
def _observe(
|
192
185
|
self,
|
@@ -201,24 +194,17 @@ class LHOPTCheckpointBaseObserver(GlobalObserver, ABC):
|
|
201
194
|
:param tracked_statistics: Dictionary mapping statistic tracker class names to dictionaries mapping module
|
202
195
|
:param action_taken: Action taken by the agent this class instance is assigned to.
|
203
196
|
"""
|
204
|
-
raise NotImplementedError
|
205
197
|
|
206
|
-
|
207
|
-
"""
|
208
|
-
Compute the log ratio.
|
198
|
+
...
|
209
199
|
|
210
|
-
|
211
|
-
:param denominator: Denominator value
|
212
|
-
:return: Log ratio value
|
200
|
+
def get_required_trackers(self) -> dict[str, dict[str, Any] | None]:
|
213
201
|
"""
|
214
|
-
|
215
|
-
|
216
|
-
|
217
|
-
|
218
|
-
|
219
|
-
ratio = numerator / denominator
|
220
|
-
|
221
|
-
if ratio <= 0:
|
222
|
-
return 0.0
|
202
|
+
:return: Dictionary mapping statistic tracker class names to kwargs for the class or None if no kwargs are
|
203
|
+
needed.
|
204
|
+
"""
|
205
|
+
return {}
|
223
206
|
|
224
|
-
|
207
|
+
def reset(self) -> None:
|
208
|
+
"""Reset the observer by clearing history."""
|
209
|
+
self._history = []
|
210
|
+
self.last_value = None
|
@@ -1,11 +1,17 @@
|
|
1
1
|
# ======================================================================================================================
|
2
2
|
#
|
3
|
-
#
|
3
|
+
# IMPORTS
|
4
4
|
#
|
5
5
|
# ======================================================================================================================
|
6
6
|
|
7
7
|
from typing import TypedDict
|
8
8
|
|
9
|
+
# ======================================================================================================================
|
10
|
+
#
|
11
|
+
# CLASSES
|
12
|
+
#
|
13
|
+
# ======================================================================================================================
|
14
|
+
|
9
15
|
|
10
16
|
class LHOPTConstants(TypedDict):
|
11
17
|
IS_NAN: float
|
@@ -23,6 +29,13 @@ class LHOPTConstants(TypedDict):
|
|
23
29
|
DEFAULT_ENV_STEP_SAMPLE_FREQUENCY: int
|
24
30
|
|
25
31
|
|
32
|
+
# ======================================================================================================================
|
33
|
+
#
|
34
|
+
# CONSTANTS
|
35
|
+
#
|
36
|
+
# ======================================================================================================================
|
37
|
+
|
38
|
+
|
26
39
|
# Create the constants instance
|
27
40
|
LHOPT_CONSTANTS: LHOPTConstants = LHOPTConstants(
|
28
41
|
IS_NAN=1.0,
|
@@ -1,6 +1,6 @@
|
|
1
1
|
# ======================================================================================================================
|
2
2
|
#
|
3
|
-
#
|
3
|
+
# IMPORTS
|
4
4
|
#
|
5
5
|
# ======================================================================================================================
|
6
6
|
|
@@ -16,9 +16,23 @@ from libinephany.pydantic_models.schemas.observation_models import ObservationIn
|
|
16
16
|
from libinephany.pydantic_models.schemas.tensor_statistics import TensorStatistics
|
17
17
|
from libinephany.pydantic_models.states.hyperparameter_states import HyperparameterStates
|
18
18
|
|
19
|
+
# ======================================================================================================================
|
20
|
+
#
|
21
|
+
# CLASSES
|
22
|
+
#
|
23
|
+
# ======================================================================================================================
|
24
|
+
|
19
25
|
|
20
26
|
class GlobalFirstOrderGradients(GlobalObserver):
|
21
27
|
|
28
|
+
@property
|
29
|
+
def requires_include_statistics(self) -> bool:
|
30
|
+
"""
|
31
|
+
:return: Whether the observation requires include_statistics to be provided.
|
32
|
+
"""
|
33
|
+
|
34
|
+
return True
|
35
|
+
|
22
36
|
def _get_observation_format(self) -> StatisticStorageTypes:
|
23
37
|
"""
|
24
38
|
:return: Format the observation returns data in. Must be one of the enum attributes in the StatisticStorageTypes
|
@@ -74,6 +88,14 @@ class GlobalSecondOrderGradients(GlobalObserver):
|
|
74
88
|
|
75
89
|
self.compute_hessian_diagonal = compute_hessian_diagonal
|
76
90
|
|
91
|
+
@property
|
92
|
+
def requires_include_statistics(self) -> bool:
|
93
|
+
"""
|
94
|
+
:return: Whether the observation requires include_statistics to be provided.
|
95
|
+
"""
|
96
|
+
|
97
|
+
return True
|
98
|
+
|
77
99
|
def _get_observation_format(self) -> StatisticStorageTypes:
|
78
100
|
"""
|
79
101
|
:return: Format the observation returns data in. Must be one of the enum attributes in the StatisticStorageTypes
|
@@ -136,21 +158,6 @@ class LHOPTGradientVarianceFraction(LHOPTBaseObserver):
|
|
136
158
|
super().__init__(**kwargs)
|
137
159
|
self.variance_threshold = variance_threshold
|
138
160
|
|
139
|
-
@property
|
140
|
-
def can_standardize(self) -> bool:
|
141
|
-
"""
|
142
|
-
This observer has its own CDF calculation, no need to standardize.
|
143
|
-
:return: Whether the observation can be standardized.
|
144
|
-
"""
|
145
|
-
return False
|
146
|
-
|
147
|
-
def _get_observation_format(self) -> StatisticStorageTypes:
|
148
|
-
"""
|
149
|
-
:return: Format the observation returns data in. Must be one of the enum attributes in the StatisticStorageTypes
|
150
|
-
enumeration class.
|
151
|
-
"""
|
152
|
-
return StatisticStorageTypes.VECTOR
|
153
|
-
|
154
161
|
@property
|
155
162
|
def vector_length(self) -> int:
|
156
163
|
"""
|
@@ -173,8 +180,6 @@ class LHOPTGradientVarianceFraction(LHOPTBaseObserver):
|
|
173
180
|
:param action_taken: Action taken by the agent this class instance is assigned to.
|
174
181
|
:return: Single float/int, list of floats/ints or TensorStatistics model to add to the observation vector.
|
175
182
|
"""
|
176
|
-
if statistic_trackers.GradientVarianceFraction.__name__ not in tracked_statistics:
|
177
|
-
return [0.0, 0.0]
|
178
183
|
|
179
184
|
raw_value = list(tracked_statistics[statistic_trackers.GradientVarianceFraction.__name__].values())[0] # type: ignore[list-item]
|
180
185
|
|
@@ -204,14 +209,6 @@ class LHOPTMomentumGradientRatio(LHOPTBaseObserver):
|
|
204
209
|
It returns two-dimensional observations: [raw_value, cdf_feature] for momentum gradient ratio values.
|
205
210
|
"""
|
206
211
|
|
207
|
-
def _get_observation_format(self) -> StatisticStorageTypes:
|
208
|
-
"""
|
209
|
-
:return: Format the observation returns data in. Must be one of the enum attributes in the StatisticStorageTypes
|
210
|
-
enumeration class.
|
211
|
-
"""
|
212
|
-
|
213
|
-
return StatisticStorageTypes.VECTOR
|
214
|
-
|
215
212
|
@property
|
216
213
|
def vector_length(self) -> int:
|
217
214
|
"""
|
@@ -219,6 +216,14 @@ class LHOPTMomentumGradientRatio(LHOPTBaseObserver):
|
|
219
216
|
"""
|
220
217
|
return 2 # [raw_value, cdf_feature]
|
221
218
|
|
219
|
+
@property
|
220
|
+
def requires_include_statistics(self) -> bool:
|
221
|
+
"""
|
222
|
+
:return: Whether the observation requires include_statistics to be provided.
|
223
|
+
"""
|
224
|
+
|
225
|
+
return True
|
226
|
+
|
222
227
|
def _observe(
|
223
228
|
self,
|
224
229
|
observation_inputs: ObservationInputs,
|
@@ -266,29 +271,6 @@ class CosineSimilarityObserverOfGradientAndMomentum(LHOPTBaseObserver):
|
|
266
271
|
It returns two-dimensional observations: [raw_value, cdf_feature] for cosine similarity of gradient and momentum values.
|
267
272
|
"""
|
268
273
|
|
269
|
-
def __init__(
|
270
|
-
self,
|
271
|
-
*,
|
272
|
-
include_statistics: list[str] | None = None,
|
273
|
-
**kwargs,
|
274
|
-
) -> None:
|
275
|
-
"""
|
276
|
-
:param include_statistics: List of statistics to include.
|
277
|
-
:param kwargs: Miscellaneous keyword arguments.
|
278
|
-
"""
|
279
|
-
|
280
|
-
super().__init__(**kwargs)
|
281
|
-
|
282
|
-
self.include_statistics = include_statistics
|
283
|
-
|
284
|
-
def _get_observation_format(self) -> StatisticStorageTypes:
|
285
|
-
"""
|
286
|
-
:return: Format the observation returns data in. Must be one of the enum attributes in the StatisticStorageTypes
|
287
|
-
enumeration class.
|
288
|
-
"""
|
289
|
-
|
290
|
-
return StatisticStorageTypes.VECTOR
|
291
|
-
|
292
274
|
@property
|
293
275
|
def vector_length(self) -> int:
|
294
276
|
"""
|
@@ -296,6 +278,14 @@ class CosineSimilarityObserverOfGradientAndMomentum(LHOPTBaseObserver):
|
|
296
278
|
"""
|
297
279
|
return 3 # [raw_value, cdf_feature, logit_of_cdf_feature]
|
298
280
|
|
281
|
+
@property
|
282
|
+
def requires_include_statistics(self) -> bool:
|
283
|
+
"""
|
284
|
+
:return: Whether the observation requires include_statistics to be provided.
|
285
|
+
"""
|
286
|
+
|
287
|
+
return True
|
288
|
+
|
299
289
|
def _observe(
|
300
290
|
self,
|
301
291
|
observation_inputs: ObservationInputs,
|
@@ -351,29 +341,6 @@ class CosineSimilarityObserverOfGradientAndUpdate(LHOPTBaseObserver):
|
|
351
341
|
It returns two-dimensional observations: [raw_value, cdf_feature] for cosine similarity of gradient and update values.
|
352
342
|
"""
|
353
343
|
|
354
|
-
def __init__(
|
355
|
-
self,
|
356
|
-
*,
|
357
|
-
include_statistics: list[str] | None = None,
|
358
|
-
**kwargs,
|
359
|
-
) -> None:
|
360
|
-
"""
|
361
|
-
:param include_statistics: List of statistics to include.
|
362
|
-
:param kwargs: Miscellaneous keyword arguments.
|
363
|
-
"""
|
364
|
-
|
365
|
-
super().__init__(**kwargs)
|
366
|
-
|
367
|
-
self.include_statistics = include_statistics
|
368
|
-
|
369
|
-
def _get_observation_format(self) -> StatisticStorageTypes:
|
370
|
-
"""
|
371
|
-
:return: Format the observation returns data in. Must be one of the enum attributes in the StatisticStorageTypes
|
372
|
-
enumeration class.
|
373
|
-
"""
|
374
|
-
|
375
|
-
return StatisticStorageTypes.VECTOR
|
376
|
-
|
377
344
|
@property
|
378
345
|
def vector_length(self) -> int:
|
379
346
|
"""
|
@@ -381,6 +348,14 @@ class CosineSimilarityObserverOfGradientAndUpdate(LHOPTBaseObserver):
|
|
381
348
|
"""
|
382
349
|
return 3 # [raw_value, cdf_feature, logit_of_cdf_feature]
|
383
350
|
|
351
|
+
@property
|
352
|
+
def requires_include_statistics(self) -> bool:
|
353
|
+
"""
|
354
|
+
:return: Whether the observation requires include_statistics to be provided.
|
355
|
+
"""
|
356
|
+
|
357
|
+
return True
|
358
|
+
|
384
359
|
def _observe(
|
385
360
|
self,
|
386
361
|
observation_inputs: ObservationInputs,
|
@@ -436,28 +411,6 @@ class CosineSimilarityOfGradientAndParameter(LHOPTBaseObserver):
|
|
436
411
|
It returns two-dimensional observations: [raw_value, cdf_feature] for cosine similarity of gradient and parameter values.
|
437
412
|
"""
|
438
413
|
|
439
|
-
def __init__(
|
440
|
-
self,
|
441
|
-
*,
|
442
|
-
include_statistics: list[str] | None = None,
|
443
|
-
**kwargs,
|
444
|
-
) -> None:
|
445
|
-
"""
|
446
|
-
:param include_statistics: List of statistics to include.
|
447
|
-
:param kwargs: Miscellaneous keyword arguments.
|
448
|
-
"""
|
449
|
-
super().__init__(**kwargs)
|
450
|
-
|
451
|
-
self.include_statistics = include_statistics
|
452
|
-
|
453
|
-
def _get_observation_format(self) -> StatisticStorageTypes:
|
454
|
-
"""
|
455
|
-
:return: Format the observation returns data in. Must be one of the enum attributes in the StatisticStorageTypes
|
456
|
-
enumeration class.
|
457
|
-
"""
|
458
|
-
|
459
|
-
return StatisticStorageTypes.VECTOR
|
460
|
-
|
461
414
|
@property
|
462
415
|
def vector_length(self) -> int:
|
463
416
|
"""
|
@@ -465,6 +418,14 @@ class CosineSimilarityOfGradientAndParameter(LHOPTBaseObserver):
|
|
465
418
|
"""
|
466
419
|
return 3 # [raw_value, cdf_feature, logit_of_cdf_feature]
|
467
420
|
|
421
|
+
@property
|
422
|
+
def requires_include_statistics(self) -> bool:
|
423
|
+
"""
|
424
|
+
:return: Whether the observation requires include_statistics to be provided.
|
425
|
+
"""
|
426
|
+
|
427
|
+
return True
|
428
|
+
|
468
429
|
def _observe(
|
469
430
|
self,
|
470
431
|
observation_inputs: ObservationInputs,
|
@@ -1,6 +1,6 @@
|
|
1
1
|
# ======================================================================================================================
|
2
2
|
#
|
3
|
-
#
|
3
|
+
# IMPORTS
|
4
4
|
#
|
5
5
|
# ======================================================================================================================
|
6
6
|
|
@@ -12,25 +12,28 @@ from torch.optim import SGD, Adam, AdamW
|
|
12
12
|
from libinephany.observations import observation_utils
|
13
13
|
from libinephany.observations.observation_utils import StatisticStorageTypes
|
14
14
|
from libinephany.observations.observers.base_observers import GlobalObserver
|
15
|
-
from libinephany.observations.observers.global_observers.
|
15
|
+
from libinephany.observations.observers.global_observers.constants import LHOPT_CONSTANTS
|
16
16
|
from libinephany.pydantic_models.schemas.observation_models import ObservationInputs
|
17
17
|
from libinephany.pydantic_models.schemas.tensor_statistics import TensorStatistics
|
18
18
|
from libinephany.pydantic_models.states.hyperparameter_states import HyperparameterStates
|
19
19
|
from libinephany.utils.enums import ModelFamilies
|
20
20
|
|
21
|
+
# ======================================================================================================================
|
22
|
+
#
|
23
|
+
# CLASSES
|
24
|
+
#
|
25
|
+
# ======================================================================================================================
|
26
|
+
|
21
27
|
|
22
28
|
class InitialHyperparameters(GlobalObserver):
|
23
29
|
|
24
|
-
def __init__(self,
|
30
|
+
def __init__(self, pad_with: float = 0.0, **kwargs) -> None:
|
25
31
|
"""
|
26
|
-
:param include_hparams: Names of the hyperparameters to include in the initial values vector returned by
|
27
|
-
this observation.
|
28
32
|
:param kwargs: Miscellaneous keyword arguments.
|
29
33
|
"""
|
30
34
|
|
31
35
|
super().__init__(**kwargs)
|
32
36
|
|
33
|
-
self.include_hparams = include_hparams
|
34
37
|
self.pad_with = pad_with
|
35
38
|
|
36
39
|
@property
|
@@ -62,6 +65,14 @@ class InitialHyperparameters(GlobalObserver):
|
|
62
65
|
|
63
66
|
return False
|
64
67
|
|
68
|
+
@property
|
69
|
+
def requires_include_hparams(self) -> bool:
|
70
|
+
"""
|
71
|
+
:return: Whether the observation requires include_hparams to be provided.
|
72
|
+
"""
|
73
|
+
|
74
|
+
return True
|
75
|
+
|
65
76
|
def _get_observation_format(self) -> StatisticStorageTypes:
|
66
77
|
"""
|
67
78
|
:return: Format the observation returns data in. Must be one of the enum attributes in the StatisticStorageTypes
|
@@ -298,7 +309,7 @@ class LHOPTHyperparameterRatio(GlobalObserver):
|
|
298
309
|
providing insights into how much hyperparameters have changed from their starting values.
|
299
310
|
"""
|
300
311
|
|
301
|
-
def __init__(self,
|
312
|
+
def __init__(self, pad_with: float = 0.0, **kwargs) -> None:
|
302
313
|
"""
|
303
314
|
:param include_hparams: Names of the hyperparameters to include in the initial values vector returned by
|
304
315
|
this observation.
|
@@ -307,7 +318,6 @@ class LHOPTHyperparameterRatio(GlobalObserver):
|
|
307
318
|
|
308
319
|
super().__init__(**kwargs)
|
309
320
|
|
310
|
-
self.include_hparams = include_hparams
|
311
321
|
self.pad_with = pad_with
|
312
322
|
|
313
323
|
@property
|
@@ -339,6 +349,14 @@ class LHOPTHyperparameterRatio(GlobalObserver):
|
|
339
349
|
|
340
350
|
return False
|
341
351
|
|
352
|
+
@property
|
353
|
+
def requires_include_hparams(self) -> bool:
|
354
|
+
"""
|
355
|
+
:return: Whether the observation requires include_hparams to be provided.
|
356
|
+
"""
|
357
|
+
|
358
|
+
return True
|
359
|
+
|
342
360
|
def _get_observation_format(self) -> StatisticStorageTypes:
|
343
361
|
"""
|
344
362
|
:return: Format the observation returns data in. Must be one of the enum attributes in the StatisticStorageTypes
|
@@ -239,6 +239,13 @@ class LHOPTTrainingLoss(LHOPTBaseObserver):
|
|
239
239
|
This observer use the CDF calculation from the paper and applies CDF transformation using the CDF mean and std.
|
240
240
|
"""
|
241
241
|
|
242
|
+
@property
|
243
|
+
def vector_length(self) -> int:
|
244
|
+
"""
|
245
|
+
:return: Length of the vector returned by this observation if it returns a vector.
|
246
|
+
"""
|
247
|
+
return 3 # [is_nan, is_inf, cdf_feature]
|
248
|
+
|
242
249
|
def _observe(
|
243
250
|
self,
|
244
251
|
observation_inputs: ObservationInputs,
|
@@ -271,18 +278,6 @@ class LHOPTTrainingLoss(LHOPTBaseObserver):
|
|
271
278
|
|
272
279
|
return {}
|
273
280
|
|
274
|
-
def reset(self) -> None:
|
275
|
-
"""Reset the observer by clearing the time series."""
|
276
|
-
self._time_series: list[tuple[float, float]] = []
|
277
|
-
self._current_time: float = 0.0
|
278
|
-
|
279
|
-
@property
|
280
|
-
def vector_length(self) -> int:
|
281
|
-
"""
|
282
|
-
:return: Length of the vector returned by this observation if it returns a vector.
|
283
|
-
"""
|
284
|
-
return 3 # [is_nan, is_inf, cdf_feature]
|
285
|
-
|
286
281
|
|
287
282
|
class LHOPTValidationLoss(LHOPTBaseObserver):
|
288
283
|
"""
|
@@ -294,6 +289,13 @@ class LHOPTValidationLoss(LHOPTBaseObserver):
|
|
294
289
|
This observer use the CDF calculation from the paper and applies CDF transformation using the CDF mean and std.
|
295
290
|
"""
|
296
291
|
|
292
|
+
@property
|
293
|
+
def vector_length(self) -> int:
|
294
|
+
"""
|
295
|
+
:return: Length of the vector returned by this observation if it returns a vector.
|
296
|
+
"""
|
297
|
+
return 3 # [is_nan, is_inf, cdf_feature]
|
298
|
+
|
297
299
|
def _observe(
|
298
300
|
self,
|
299
301
|
observation_inputs: ObservationInputs,
|
@@ -326,18 +328,6 @@ class LHOPTValidationLoss(LHOPTBaseObserver):
|
|
326
328
|
|
327
329
|
return {}
|
328
330
|
|
329
|
-
def reset(self) -> None:
|
330
|
-
"""Reset the observer by clearing the time series."""
|
331
|
-
self._time_series: list[tuple[float, float]] = []
|
332
|
-
self._current_time: float = 0.0
|
333
|
-
|
334
|
-
@property
|
335
|
-
def vector_length(self) -> int:
|
336
|
-
"""
|
337
|
-
:return: Length of the vector returned by this observation if it returns a vector.
|
338
|
-
"""
|
339
|
-
return 3 # [is_nan, is_inf, cdf_feature]
|
340
|
-
|
341
331
|
|
342
332
|
class LHOPTLossRatio(LHOPTBaseObserver):
|
343
333
|
"""
|
@@ -353,6 +343,13 @@ class LHOPTLossRatio(LHOPTBaseObserver):
|
|
353
343
|
3. cdf_feature - CDF transformed feature using CDF mean and std
|
354
344
|
"""
|
355
345
|
|
346
|
+
@property
|
347
|
+
def vector_length(self) -> int:
|
348
|
+
"""
|
349
|
+
:return: Length of the vector returned by this observation if it returns a vector.
|
350
|
+
"""
|
351
|
+
return 3 # [is_nan, tanh, cdf_feature]
|
352
|
+
|
356
353
|
def _observe(
|
357
354
|
self,
|
358
355
|
observation_inputs: ObservationInputs,
|
@@ -370,7 +367,7 @@ class LHOPTLossRatio(LHOPTBaseObserver):
|
|
370
367
|
"""
|
371
368
|
|
372
369
|
log_ratio = self._compute_log_ratio(
|
373
|
-
|
370
|
+
numerator=observation_inputs.training_loss, denominator=observation_inputs.validation_loss
|
374
371
|
)
|
375
372
|
|
376
373
|
tanh_feature = math.tanh(max(-LHOPT_CONSTANTS["TANH_BOUND"], min(LHOPT_CONSTANTS["TANH_BOUND"], log_ratio)))
|
@@ -381,31 +378,13 @@ class LHOPTLossRatio(LHOPTBaseObserver):
|
|
381
378
|
|
382
379
|
return [int(math.isnan(log_ratio)), tanh_feature, cdf_feature]
|
383
380
|
|
384
|
-
def
|
381
|
+
def get_required_trackers(self) -> dict[str, dict[str, Any] | None]:
|
385
382
|
"""
|
386
|
-
|
387
|
-
|
388
|
-
:param training_score: Training score value
|
389
|
-
:param validation_score: Validation score value
|
390
|
-
:return: Log ratio value
|
383
|
+
:return: Dictionary mapping statistic tracker class names to kwargs for the class or None if no kwargs are
|
384
|
+
needed.
|
391
385
|
"""
|
392
|
-
if training_score <= 0:
|
393
|
-
return 0.0
|
394
|
-
|
395
|
-
if validation_score <= 0:
|
396
|
-
return 0.0
|
397
386
|
|
398
|
-
|
399
|
-
score_ratio = validation_score / training_score
|
400
|
-
|
401
|
-
return math.log(score_ratio)
|
402
|
-
|
403
|
-
@property
|
404
|
-
def vector_length(self) -> int:
|
405
|
-
"""
|
406
|
-
:return: Length of the vector returned by this observation if it returns a vector.
|
407
|
-
"""
|
408
|
-
return 3 # [is_nan, tanh, cdf_feature]
|
387
|
+
return {}
|
409
388
|
|
410
389
|
|
411
390
|
class PercentileOfLossAtEachCheckpoint(LHOPTCheckpointBaseObserver):
|
@@ -25,6 +25,14 @@ from libinephany.pydantic_models.states.hyperparameter_states import Hyperparame
|
|
25
25
|
|
26
26
|
class GlobalActivations(GlobalObserver):
|
27
27
|
|
28
|
+
@property
|
29
|
+
def requires_include_statistics(self) -> bool:
|
30
|
+
"""
|
31
|
+
:return: Whether the observation requires include_statistics to be provided.
|
32
|
+
"""
|
33
|
+
|
34
|
+
return True
|
35
|
+
|
28
36
|
def _get_observation_format(self) -> StatisticStorageTypes:
|
29
37
|
"""
|
30
38
|
:return: Format the observation returns data in. Must be one of the enum attributes in the StatisticStorageTypes
|
@@ -64,6 +72,14 @@ class GlobalActivations(GlobalObserver):
|
|
64
72
|
|
65
73
|
class GlobalParameterUpdates(GlobalObserver):
|
66
74
|
|
75
|
+
@property
|
76
|
+
def requires_include_statistics(self) -> bool:
|
77
|
+
"""
|
78
|
+
:return: Whether the observation requires include_statistics to be provided.
|
79
|
+
"""
|
80
|
+
|
81
|
+
return True
|
82
|
+
|
67
83
|
def _get_observation_format(self) -> StatisticStorageTypes:
|
68
84
|
"""
|
69
85
|
:return: Format the observation returns data in. Must be one of the enum attributes in the StatisticStorageTypes
|
@@ -103,6 +119,14 @@ class GlobalParameterUpdates(GlobalObserver):
|
|
103
119
|
|
104
120
|
class GlobalParameters(GlobalObserver):
|
105
121
|
|
122
|
+
@property
|
123
|
+
def requires_include_statistics(self) -> bool:
|
124
|
+
"""
|
125
|
+
:return: Whether the observation requires include_statistics to be provided.
|
126
|
+
"""
|
127
|
+
|
128
|
+
return True
|
129
|
+
|
106
130
|
def _get_observation_format(self) -> StatisticStorageTypes:
|
107
131
|
"""
|
108
132
|
:return: Format the observation returns data in. Must be one of the enum attributes in the StatisticStorageTypes
|
@@ -342,6 +366,14 @@ class LogRatioOfPreviousAndCurrentParamNormEnvStepObserver(LHOPTBaseObserver):
|
|
342
366
|
"""
|
343
367
|
return 2 # [tanh_feature, cdf_feature]
|
344
368
|
|
369
|
+
@property
|
370
|
+
def requires_include_statistics(self) -> bool:
|
371
|
+
"""
|
372
|
+
:return: Whether the observation requires include_statistics to be provided.
|
373
|
+
"""
|
374
|
+
|
375
|
+
return True
|
376
|
+
|
345
377
|
def _observe(
|
346
378
|
self,
|
347
379
|
observation_inputs: ObservationInputs,
|
@@ -410,6 +442,14 @@ class LogRatioOfUpdateAndPreviousParamNormEnvStepObserver(LHOPTBaseObserver):
|
|
410
442
|
"""
|
411
443
|
return 2 # [tanh_feature, cdf_feature]
|
412
444
|
|
445
|
+
@property
|
446
|
+
def requires_include_statistics(self) -> bool:
|
447
|
+
"""
|
448
|
+
:return: Whether the observation requires include_statistics to be provided.
|
449
|
+
"""
|
450
|
+
|
451
|
+
return True
|
452
|
+
|
413
453
|
def _observe(
|
414
454
|
self,
|
415
455
|
observation_inputs: ObservationInputs,
|
@@ -473,28 +513,20 @@ class LogRatioOfUpdateAndPreviousParamNormEnvStepObserver(LHOPTBaseObserver):
|
|
473
513
|
|
474
514
|
class LHOPTAverageParameterUpdateMagnitudeObserver(LHOPTBaseObserver):
|
475
515
|
|
476
|
-
def _get_observation_format(self) -> StatisticStorageTypes:
|
477
|
-
"""
|
478
|
-
:return: Format the observation returns data in. Must be one of the enum attributes in the StatisticStorageTypes
|
479
|
-
enumeration class.
|
480
|
-
"""
|
481
|
-
|
482
|
-
return StatisticStorageTypes.VECTOR
|
483
|
-
|
484
516
|
@property
|
485
|
-
def
|
517
|
+
def vector_length(self) -> int:
|
486
518
|
"""
|
487
|
-
:return:
|
519
|
+
:return: Length of the vector returned by this observation if it returns a vector.
|
488
520
|
"""
|
489
|
-
|
490
|
-
return False
|
521
|
+
return 2 # [raw_feature, cdf_feature]
|
491
522
|
|
492
523
|
@property
|
493
|
-
def
|
524
|
+
def requires_include_statistics(self) -> bool:
|
494
525
|
"""
|
495
|
-
:return:
|
526
|
+
:return: Whether the observation requires include_statistics to be provided.
|
496
527
|
"""
|
497
|
-
|
528
|
+
|
529
|
+
return True
|
498
530
|
|
499
531
|
def _observe(
|
500
532
|
self,
|
@@ -537,8 +569,8 @@ class LHOPTAverageParameterUpdateMagnitudeObserver(LHOPTBaseObserver):
|
|
537
569
|
class LogRatioOfUpdateAndPreviousParamNormInnerStepObserver(LHOPTBaseObserver):
|
538
570
|
def __init__(self, **kwargs):
|
539
571
|
"""
|
540
|
-
This observer is used to compute the log ratio of the update and previous parameter norm for the inner step.
|
541
|
-
|
572
|
+
This observer is used to compute the log ratio of the update and previous parameter norm for the inner step.
|
573
|
+
The sample frequency of the statistics needs to be set to 4 (according to the OpenAI paper).
|
542
574
|
"""
|
543
575
|
super().__init__(**kwargs)
|
544
576
|
self._previous_param_norm = None
|
@@ -550,6 +582,14 @@ class LogRatioOfUpdateAndPreviousParamNormInnerStepObserver(LHOPTBaseObserver):
|
|
550
582
|
"""
|
551
583
|
return 2 # [tanh_feature, cdf_feature]
|
552
584
|
|
585
|
+
@property
|
586
|
+
def requires_include_statistics(self) -> bool:
|
587
|
+
"""
|
588
|
+
:return: Whether the observation requires include_statistics to be provided.
|
589
|
+
"""
|
590
|
+
|
591
|
+
return True
|
592
|
+
|
553
593
|
def _observe(
|
554
594
|
self,
|
555
595
|
observation_inputs: ObservationInputs,
|
@@ -630,14 +670,6 @@ class LHOPTGlobalLAMBTrustRatio(LHOPTBaseObserver):
|
|
630
670
|
|
631
671
|
self.use_log_transform = use_log_transform
|
632
672
|
|
633
|
-
def _get_observation_format(self) -> StatisticStorageTypes:
|
634
|
-
"""
|
635
|
-
:return: Format the observation returns data in. Must be one of the enum attributes in the StatisticStorageTypes
|
636
|
-
enumeration class.
|
637
|
-
"""
|
638
|
-
|
639
|
-
return StatisticStorageTypes.VECTOR
|
640
|
-
|
641
673
|
@property
|
642
674
|
def vector_length(self) -> int:
|
643
675
|
"""
|
@@ -645,6 +677,14 @@ class LHOPTGlobalLAMBTrustRatio(LHOPTBaseObserver):
|
|
645
677
|
"""
|
646
678
|
return 2 # [raw_value, cdf_feature]
|
647
679
|
|
680
|
+
@property
|
681
|
+
def requires_include_statistics(self) -> bool:
|
682
|
+
"""
|
683
|
+
:return: Whether the observation requires include_statistics to be provided.
|
684
|
+
"""
|
685
|
+
|
686
|
+
return True
|
687
|
+
|
648
688
|
def _observe(
|
649
689
|
self,
|
650
690
|
observation_inputs: ObservationInputs,
|
@@ -27,6 +27,14 @@ from libinephany.utils.transforms import HyperparameterTransformType
|
|
27
27
|
|
28
28
|
class FirstOrderGradients(LocalObserver):
|
29
29
|
|
30
|
+
@property
|
31
|
+
def requires_include_statistics(self) -> bool:
|
32
|
+
"""
|
33
|
+
:return: Whether the observation requires include_statistics to be provided.
|
34
|
+
"""
|
35
|
+
|
36
|
+
return True
|
37
|
+
|
30
38
|
def _get_observation_format(self) -> StatisticStorageTypes:
|
31
39
|
"""
|
32
40
|
:return: Format the observation returns data in. Must be one of the enum attributes in the StatisticStorageTypes
|
@@ -89,6 +97,14 @@ class SecondOrderGradients(LocalObserver):
|
|
89
97
|
|
90
98
|
self.compute_hessian_diagonal = compute_hessian_diagonal
|
91
99
|
|
100
|
+
@property
|
101
|
+
def requires_include_statistics(self) -> bool:
|
102
|
+
"""
|
103
|
+
:return: Whether the observation requires include_statistics to be provided.
|
104
|
+
"""
|
105
|
+
|
106
|
+
return True
|
107
|
+
|
92
108
|
def _get_observation_format(self) -> StatisticStorageTypes:
|
93
109
|
"""
|
94
110
|
:return: Format the observation returns data in. Must be one of the enum attributes in the StatisticStorageTypes
|
@@ -139,6 +155,14 @@ class SecondOrderGradients(LocalObserver):
|
|
139
155
|
|
140
156
|
class Activations(LocalObserver):
|
141
157
|
|
158
|
+
@property
|
159
|
+
def requires_include_statistics(self) -> bool:
|
160
|
+
"""
|
161
|
+
:return: Whether the observation requires include_statistics to be provided.
|
162
|
+
"""
|
163
|
+
|
164
|
+
return True
|
165
|
+
|
142
166
|
def _get_observation_format(self) -> StatisticStorageTypes:
|
143
167
|
"""
|
144
168
|
:return: Format the observation returns data in. Must be one of the enum attributes in the StatisticStorageTypes
|
@@ -185,6 +209,14 @@ class Activations(LocalObserver):
|
|
185
209
|
|
186
210
|
class ParameterUpdates(LocalObserver):
|
187
211
|
|
212
|
+
@property
|
213
|
+
def requires_include_statistics(self) -> bool:
|
214
|
+
"""
|
215
|
+
:return: Whether the observation requires include_statistics to be provided.
|
216
|
+
"""
|
217
|
+
|
218
|
+
return True
|
219
|
+
|
188
220
|
def _get_observation_format(self) -> StatisticStorageTypes:
|
189
221
|
"""
|
190
222
|
:return: Format the observation returns data in. Must be one of the enum attributes in the StatisticStorageTypes
|
@@ -231,6 +263,14 @@ class ParameterUpdates(LocalObserver):
|
|
231
263
|
|
232
264
|
class Parameters(LocalObserver):
|
233
265
|
|
266
|
+
@property
|
267
|
+
def requires_include_statistics(self) -> bool:
|
268
|
+
"""
|
269
|
+
:return: Whether the observation requires include_statistics to be provided.
|
270
|
+
"""
|
271
|
+
|
272
|
+
return True
|
273
|
+
|
234
274
|
def _get_observation_format(self) -> StatisticStorageTypes:
|
235
275
|
"""
|
236
276
|
:return: Format the observation returns data in. Must be one of the enum attributes in the StatisticStorageTypes
|
@@ -692,17 +732,6 @@ class ModuleTypeOneHot(LocalObserver):
|
|
692
732
|
|
693
733
|
class CurrentHyperparameters(LocalObserver):
|
694
734
|
|
695
|
-
def __init__(self, include_hparams: list[str] | None = None, **kwargs) -> None:
|
696
|
-
"""
|
697
|
-
:param include_hparams: Names of the hyperparameters to include in the initial values vector returned by
|
698
|
-
this observation.
|
699
|
-
:param kwargs: Miscellaneous keyword arguments.
|
700
|
-
"""
|
701
|
-
|
702
|
-
super().__init__(**kwargs)
|
703
|
-
|
704
|
-
self.include_hparams = include_hparams
|
705
|
-
|
706
735
|
@property
|
707
736
|
def can_standardize(self) -> bool:
|
708
737
|
"""
|
@@ -732,6 +761,14 @@ class CurrentHyperparameters(LocalObserver):
|
|
732
761
|
|
733
762
|
return len([hparam for hparam in available_hparams if hparam in self.include_hparams])
|
734
763
|
|
764
|
+
@property
|
765
|
+
def requires_include_hparams(self) -> bool:
|
766
|
+
"""
|
767
|
+
:return: Whether the observation requires include_hparams to be provided.
|
768
|
+
"""
|
769
|
+
|
770
|
+
return True
|
771
|
+
|
735
772
|
def _get_observation_format(self) -> StatisticStorageTypes:
|
736
773
|
"""
|
737
774
|
:return: Format the observation returns data in. Must be one of the enum attributes in the StatisticStorageTypes
|
@@ -777,17 +814,6 @@ class CurrentHyperparameters(LocalObserver):
|
|
777
814
|
|
778
815
|
class CurrentHyperparameterDeltas(LocalObserver):
|
779
816
|
|
780
|
-
def __init__(self, include_hparams: list[str] | None = None, **kwargs) -> None:
|
781
|
-
"""
|
782
|
-
:param include_hparams: Names of the hyperparameters to include in the initial deltas vector returned by
|
783
|
-
this observation.
|
784
|
-
:param kwargs: Miscellaneous keyword arguments.
|
785
|
-
"""
|
786
|
-
|
787
|
-
super().__init__(**kwargs)
|
788
|
-
|
789
|
-
self.include_hparams = include_hparams
|
790
|
-
|
791
817
|
@property
|
792
818
|
def can_standardize(self) -> bool:
|
793
819
|
"""
|
@@ -817,6 +843,14 @@ class CurrentHyperparameterDeltas(LocalObserver):
|
|
817
843
|
|
818
844
|
return len([hparam for hparam in available_hparams if hparam in self.include_hparams])
|
819
845
|
|
846
|
+
@property
|
847
|
+
def requires_include_hparams(self) -> bool:
|
848
|
+
"""
|
849
|
+
:return: Whether the observation requires include_hparams to be provided.
|
850
|
+
"""
|
851
|
+
|
852
|
+
return True
|
853
|
+
|
820
854
|
def _get_observation_format(self) -> StatisticStorageTypes:
|
821
855
|
"""
|
822
856
|
:return: Format the observation returns data in. Must be one of the enum attributes in the StatisticStorageTypes
|
@@ -865,17 +899,6 @@ class HyperparameterTransformTypes(LocalObserver):
|
|
865
899
|
|
866
900
|
TRANSFORM_TYPE_TO_IDX = dict(((s, i) for i, s in enumerate(HyperparameterTransformType)))
|
867
901
|
|
868
|
-
def __init__(self, include_hparams: list[str] | None = None, **kwargs) -> None:
|
869
|
-
"""
|
870
|
-
:param include_hparams: Names of the hyperparameters to include in the transforms vector returned by
|
871
|
-
this observation.
|
872
|
-
:param kwargs: Miscellaneous keyword arguments.
|
873
|
-
"""
|
874
|
-
|
875
|
-
super().__init__(**kwargs)
|
876
|
-
|
877
|
-
self.include_hparams = include_hparams
|
878
|
-
|
879
902
|
@property
|
880
903
|
def can_standardize(self) -> bool:
|
881
904
|
"""
|
@@ -907,6 +930,14 @@ class HyperparameterTransformTypes(LocalObserver):
|
|
907
930
|
[hparam for hparam in available_hparams if hparam in self.include_hparams]
|
908
931
|
)
|
909
932
|
|
933
|
+
@property
|
934
|
+
def requires_include_hparams(self) -> bool:
|
935
|
+
"""
|
936
|
+
:return: Whether the observation requires include_hparams to be provided.
|
937
|
+
"""
|
938
|
+
|
939
|
+
return True
|
940
|
+
|
910
941
|
def _get_observation_format(self) -> StatisticStorageTypes:
|
911
942
|
"""
|
912
943
|
:return: Format the observation returns data in. Must be one of the enum attributes in the StatisticStorageTypes
|
@@ -1087,7 +1118,6 @@ class LogOfNoiseScaleObserver(LocalObserver):
|
|
1087
1118
|
*,
|
1088
1119
|
decay_factor: float = LHOPT_CONSTANTS["DEFAULT_DECAY_FACTOR"],
|
1089
1120
|
time_window: int = LHOPT_CONSTANTS["DEFAULT_TIME_WINDOW"],
|
1090
|
-
include_statistics: list[str] | None = None,
|
1091
1121
|
**kwargs,
|
1092
1122
|
) -> None:
|
1093
1123
|
"""
|
@@ -1100,7 +1130,6 @@ class LogOfNoiseScaleObserver(LocalObserver):
|
|
1100
1130
|
|
1101
1131
|
super().__init__(**kwargs)
|
1102
1132
|
|
1103
|
-
self.include_statistics = include_statistics
|
1104
1133
|
self.decay_factor = max(0.0, decay_factor)
|
1105
1134
|
self.time_window = max(1, time_window)
|
1106
1135
|
|
@@ -1108,14 +1137,6 @@ class LogOfNoiseScaleObserver(LocalObserver):
|
|
1108
1137
|
self._time_series: list[tuple[float, float]] = [] # (time, value) pairs
|
1109
1138
|
self._current_time: float = 0.0
|
1110
1139
|
|
1111
|
-
def _get_observation_format(self) -> StatisticStorageTypes:
|
1112
|
-
"""
|
1113
|
-
:return: Format the observation returns data in. Must be one of the enum attributes in the StatisticStorageTypes
|
1114
|
-
enumeration class.
|
1115
|
-
"""
|
1116
|
-
|
1117
|
-
return StatisticStorageTypes.VECTOR
|
1118
|
-
|
1119
1140
|
@property
|
1120
1141
|
def can_standardize(self) -> bool:
|
1121
1142
|
"""
|
@@ -1132,6 +1153,29 @@ class LogOfNoiseScaleObserver(LocalObserver):
|
|
1132
1153
|
|
1133
1154
|
return False
|
1134
1155
|
|
1156
|
+
@property
|
1157
|
+
def vector_length(self) -> int:
|
1158
|
+
"""
|
1159
|
+
:return: Length of the vector returned by this observation if it returns a vector.
|
1160
|
+
"""
|
1161
|
+
return 2 # [log_noise_scale, cdf_feature]
|
1162
|
+
|
1163
|
+
@property
|
1164
|
+
def requires_include_statistics(self) -> bool:
|
1165
|
+
"""
|
1166
|
+
:return: Whether the observation requires include_statistics to be provided.
|
1167
|
+
"""
|
1168
|
+
|
1169
|
+
return True
|
1170
|
+
|
1171
|
+
def _get_observation_format(self) -> StatisticStorageTypes:
|
1172
|
+
"""
|
1173
|
+
:return: Format the observation returns data in. Must be one of the enum attributes in the StatisticStorageTypes
|
1174
|
+
enumeration class.
|
1175
|
+
"""
|
1176
|
+
|
1177
|
+
return StatisticStorageTypes.VECTOR
|
1178
|
+
|
1135
1179
|
def _update_time(self) -> None:
|
1136
1180
|
"""Update the current time counter."""
|
1137
1181
|
self._current_time += 1.0
|
@@ -1145,13 +1189,6 @@ class LogOfNoiseScaleObserver(LocalObserver):
|
|
1145
1189
|
"""
|
1146
1190
|
return compute_cdf_feature(value, self._time_series, self.decay_factor, self._current_time, self.time_window)
|
1147
1191
|
|
1148
|
-
@property
|
1149
|
-
def vector_length(self) -> int:
|
1150
|
-
"""
|
1151
|
-
:return: Length of the vector returned by this observation if it returns a vector.
|
1152
|
-
"""
|
1153
|
-
return 2 # [log_noise_scale, cdf_feature]
|
1154
|
-
|
1155
1192
|
def _observe(
|
1156
1193
|
self,
|
1157
1194
|
observation_inputs: ObservationInputs,
|
@@ -1169,16 +1206,16 @@ class LogOfNoiseScaleObserver(LocalObserver):
|
|
1169
1206
|
"""
|
1170
1207
|
|
1171
1208
|
statistics = tracked_statistics[statistic_trackers.LogOfNoiseScaleStatistics.__name__]
|
1172
|
-
|
1173
1209
|
raw_value = list(statistics.values())[0] # type: ignore[list-item]
|
1210
|
+
|
1174
1211
|
assert isinstance(raw_value, float), f"Expected float, got {type(raw_value)}" # to avoid type errors with mypy
|
1212
|
+
|
1175
1213
|
batch_size = hyperparameter_states.global_hparams.batch_size.external_value
|
1176
1214
|
learning_rate = hyperparameter_states.parameter_group_hparams[
|
1177
1215
|
self.parameter_group_name
|
1178
1216
|
].learning_rate.external_value
|
1179
1217
|
|
1180
1218
|
log_b_over_epsilon = math.log(batch_size / learning_rate)
|
1181
|
-
|
1182
1219
|
log_noise_scale = raw_value + log_b_over_epsilon
|
1183
1220
|
|
1184
1221
|
cdf_feature = self._compute_cdf_feature(log_noise_scale) # type: ignore[arg-type]
|
@@ -1187,19 +1187,17 @@ class LogOfNoiseScaleStatistics(Statistic):
|
|
1187
1187
|
# This is a common assumption when the exact noise structure is unknown
|
1188
1188
|
noise_covariance = torch.ones_like(hessian_diagonals)
|
1189
1189
|
|
1190
|
-
# Compute tr(HΣ)
|
1191
|
-
trace_hessian_noise_covariance =
|
1192
|
-
|
1193
|
-
|
1194
|
-
if trace_hessian_noise_covariance <= 0:
|
1195
|
-
return None
|
1190
|
+
# Compute tr(HΣ), add zero division tolerance to avoid log of zero when gradient is too small
|
1191
|
+
trace_hessian_noise_covariance = (
|
1192
|
+
torch.sum(hessian_diagonals * noise_covariance) + LHOPT_CONSTANTS["ZERO_DIVISION_TOLERANCE"]
|
1193
|
+
)
|
1196
1194
|
|
1197
1195
|
log_trace_hessian_noise_covariance = torch.log(trace_hessian_noise_covariance).item()
|
1198
1196
|
|
1199
|
-
# Compute tr(H^3 Σ)
|
1200
|
-
trace_hessian_cubed_noise_covariance =
|
1201
|
-
|
1202
|
-
|
1197
|
+
# Compute tr(H^3 Σ), add zero division tolerance to avoid log of zero when gradient is too small
|
1198
|
+
trace_hessian_cubed_noise_covariance = (
|
1199
|
+
torch.sum(hessian_diagonals**3 * noise_covariance) + LHOPT_CONSTANTS["ZERO_DIVISION_TOLERANCE"]
|
1200
|
+
)
|
1203
1201
|
|
1204
1202
|
log_trace_hessian_cubed_noise_covariance = torch.log(trace_hessian_cubed_noise_covariance).item()
|
1205
1203
|
|
libinephany/utils/enums.py
CHANGED
@@ -6,18 +6,18 @@ libinephany/observations/observation_utils.py,sha256=JSNJYEi2d-VQ0ZovfHrn28RDv41
|
|
6
6
|
libinephany/observations/observer_pipeline.py,sha256=_xA4vrijhG8-9MCtGXnKAEmpd6q0nKVpJgY_qSbypIA,12979
|
7
7
|
libinephany/observations/pipeline_coordinator.py,sha256=mLfaHhkXVhMp9w5jWIAL3jPyauCM-795qOzyqwGOSdw,7932
|
8
8
|
libinephany/observations/statistic_manager.py,sha256=LLg1zSxnJr2oQQepYla3qoUuRy10rsthr9jta4wEbnc,8956
|
9
|
-
libinephany/observations/statistic_trackers.py,sha256=
|
9
|
+
libinephany/observations/statistic_trackers.py,sha256=F98V-H2Ljx0v2YnppYCCJLJojL6pzYBdbBh8Lb4lasA,47666
|
10
10
|
libinephany/observations/observers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
11
|
-
libinephany/observations/observers/base_observers.py,sha256=
|
12
|
-
libinephany/observations/observers/local_observers.py,sha256=
|
11
|
+
libinephany/observations/observers/base_observers.py,sha256=V8PIysq2wT6K-w_CqeM5benyif-xK1hPT3M6a4ic1So,17535
|
12
|
+
libinephany/observations/observers/local_observers.py,sha256=PJDsZ-DO2rptW87wrZsclrx3GKuSMdFfxT_FF7ov0Is,46137
|
13
13
|
libinephany/observations/observers/observer_containers.py,sha256=VNyqGgxYJ4r49Msp_kk-POgicb-_5w54twuT1qfNdxw,9562
|
14
14
|
libinephany/observations/observers/global_observers/__init__.py,sha256=87WHRPYmL0tVsaTKUd91pwEpCZtHPSKRQoba2VQjswA,3018
|
15
|
-
libinephany/observations/observers/global_observers/base_classes.py,sha256=
|
16
|
-
libinephany/observations/observers/global_observers/constants.py,sha256=
|
17
|
-
libinephany/observations/observers/global_observers/gradient_observers.py,sha256=
|
18
|
-
libinephany/observations/observers/global_observers/hyperparameter_observers.py,sha256=
|
19
|
-
libinephany/observations/observers/global_observers/loss_observers.py,sha256=
|
20
|
-
libinephany/observations/observers/global_observers/model_observers.py,sha256=
|
15
|
+
libinephany/observations/observers/global_observers/base_classes.py,sha256=Q7OblhmKscypTs9JBepSQwo6ljjOdPKTU9kbpuhq_W4,7800
|
16
|
+
libinephany/observations/observers/global_observers/constants.py,sha256=TDQM_sGU8Swze794oB4TBaXFjSddt0OBhYPVhrXQ9Ko,1654
|
17
|
+
libinephany/observations/observers/global_observers/gradient_observers.py,sha256=j9uX7043ic06W6vb6vDt_PH2a5WtLFCdKQZ1JQGT24Q,19531
|
18
|
+
libinephany/observations/observers/global_observers/hyperparameter_observers.py,sha256=5Av8FgwWBJtcn4gpDzPdOTlKOOYy2lVEHS_gt5Sz7xo,15334
|
19
|
+
libinephany/observations/observers/global_observers/loss_observers.py,sha256=Kf943FiuYWuWvjmhgmp3TGyIQoZ27ZKJcxfTBwXS-gA,17761
|
20
|
+
libinephany/observations/observers/global_observers/model_observers.py,sha256=SGWXrmTdgp0kHvEvDSF7d3v1FEcK1sQDMPLQ8Wy3qv4,29306
|
21
21
|
libinephany/observations/observers/global_observers/progress_observers.py,sha256=ypLk1_POAjA8V8rAaQ0B6Qh8m_04s9PAoXsw1KxVrLg,5872
|
22
22
|
libinephany/observations/post_processors/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
23
23
|
libinephany/observations/post_processors/postprocessors.py,sha256=43_e5UaDPr2KbAvqc_w3wLqnlm7bgRjqgCtyQ95-8cM,5913
|
@@ -42,7 +42,7 @@ libinephany/utils/backend_statuses.py,sha256=ZbpBPbz0qKmeqxyGGN_ePTrQ7Wrxh7KM6W2
|
|
42
42
|
libinephany/utils/constants.py,sha256=XAOuPowvM4FDSbfvNsubKTAqSB84AANX4CoHb7LwgEI,2330
|
43
43
|
libinephany/utils/directory_utils.py,sha256=408unVeE_5_Hm-ZYZuxc9sdvfuU0CgYELX7EzPlPieo,1217
|
44
44
|
libinephany/utils/dropout_utils.py,sha256=X43yCW7Dh1cC5sNnivgS5j1fn871K_RCvxCBTT0YHKg,3392
|
45
|
-
libinephany/utils/enums.py,sha256=
|
45
|
+
libinephany/utils/enums.py,sha256=6fTgUd4EiFh4TzNXjvWX-zx1UKb90emgDaGB5gyAbdo,2977
|
46
46
|
libinephany/utils/error_severities.py,sha256=B9oidqOVaYOe0W6P6GwjpmuDsrkyTX30v1xdiUStCFk,1427
|
47
47
|
libinephany/utils/exceptions.py,sha256=kgwLpHOgy3kciUz_I18xnYsWRtzdonfadUtwG2uDYk8,1823
|
48
48
|
libinephany/utils/import_utils.py,sha256=WzC6V6UIa0nCiU2MekROwG82fWBh9RuVzichtby5EvM,1495
|
@@ -57,8 +57,8 @@ libinephany/utils/typing.py,sha256=rGbaPO3MaUndsWiC_wHzReD_TOLYqb43i01pKN-j7Xs,6
|
|
57
57
|
libinephany/web_apps/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
58
58
|
libinephany/web_apps/error_logger.py,sha256=gAQIaqerqP4ornXZwFF1cghjnd2mMZEt3aVrTuUCr34,16653
|
59
59
|
libinephany/web_apps/web_app_utils.py,sha256=qiq_lasPipgN1RgRudPJc342kYci8O_4RqppxmIX8NY,4095
|
60
|
-
libinephany-1.0.
|
61
|
-
libinephany-1.0.
|
62
|
-
libinephany-1.0.
|
63
|
-
libinephany-1.0.
|
64
|
-
libinephany-1.0.
|
60
|
+
libinephany-1.0.2.dist-info/licenses/LICENSE,sha256=pogfDoMBP07ehIOvWymuWIar8pg2YLUhqOHsJQU3wdc,9250
|
61
|
+
libinephany-1.0.2.dist-info/METADATA,sha256=pATrPbN9k--PfUM5JDsn6jPj-77eUnHD94rJrVUs8JI,8389
|
62
|
+
libinephany-1.0.2.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
63
|
+
libinephany-1.0.2.dist-info/top_level.txt,sha256=bYAOXQdJgIoLkO2Ui0kxe7pSYegS_e38u0dMscd7COQ,12
|
64
|
+
libinephany-1.0.2.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|