libinephany 1.0.1__py3-none-any.whl → 1.0.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- libinephany/observations/observers/base_observers.py +26 -0
- libinephany/observations/observers/global_observers/base_classes.py +50 -64
- libinephany/observations/observers/global_observers/constants.py +14 -1
- libinephany/observations/observers/global_observers/gradient_observers.py +55 -94
- libinephany/observations/observers/global_observers/hyperparameter_observers.py +26 -8
- libinephany/observations/observers/global_observers/loss_observers.py +26 -47
- libinephany/observations/observers/global_observers/model_observers.py +65 -25
- libinephany/observations/observers/local_observers.py +90 -52
- libinephany/utils/enums.py +1 -0
- {libinephany-1.0.1.dist-info → libinephany-1.0.3.dist-info}/METADATA +1 -1
- {libinephany-1.0.1.dist-info → libinephany-1.0.3.dist-info}/RECORD +14 -14
- {libinephany-1.0.1.dist-info → libinephany-1.0.3.dist-info}/WHEEL +0 -0
- {libinephany-1.0.1.dist-info → libinephany-1.0.3.dist-info}/licenses/LICENSE +0 -0
- {libinephany-1.0.1.dist-info → libinephany-1.0.3.dist-info}/top_level.txt +0 -0
@@ -44,6 +44,7 @@ class Observer(ABC):
         observer_config: ObserverConfig,
         should_standardize: bool = True,
         include_statistics: list[str] | None = None,
+        include_hparams: list[str] | None = None,
         **kwargs,
     ) -> None:
         """
@@ -52,6 +53,8 @@ class Observer(ABC):
         :param should_standardize: Whether standardization should be applied to returned values.
         :param include_statistics: If the observation uses the TensorStatistic model to return observations, names of the
            fields in the model to include in returned observations.
+        :param include_hparams: If the observation uses the HyperparameterStates model to return observations, names of the
+            hyperparameters to include in returned observations.
         :param kwargs: Miscellaneous keyword arguments.
         """

@@ -64,10 +67,17 @@ class Observer(ABC):
         self.should_standardize = should_standardize and self.can_standardize

         self.include_statistics: list[str] | None = None
+        self.include_hparams = include_hparams

         if include_statistics is not None:
             self.include_statistics = TensorStatistics.filter_include_statistics(include_statistics=include_statistics)

+        if self.requires_include_statistics and not self.include_statistics:
+            raise ValueError(f"{self.__class__.__name__} must be provided with include_statistics.")
+
+        if self.requires_include_hparams and not self.include_hparams:
+            raise ValueError(f"{self.__class__.__name__} must be provided with include_hparams.")
+
     @final
     @property
     def in_training_mode(self) -> bool:
@@ -143,6 +153,22 @@ class Observer(ABC):

         return True

+    @property
+    def requires_include_statistics(self) -> bool:
+        """
+        :return: Whether the observation requires include_statistics to be provided.
+        """
+
+        return False
+
+    @property
+    def requires_include_hparams(self) -> bool:
+        """
+        :return: Whether the observation requires include_hparams to be provided.
+        """
+
+        return False
+
     @property
     @abstractmethod
     def standardizer_key_infix(self) -> str:
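The net effect of the `base_observers.py` changes: `include_hparams` is now handled by the shared `Observer` constructor alongside `include_statistics`, and two new overridable properties let subclasses declare either argument as mandatory, with a `ValueError` raised at construction time when the list is missing or empty. A standalone sketch of the pattern, not the library class itself; the class names here are illustrative only:

```python
# Standalone sketch of the validation pattern added to Observer.__init__ (simplified,
# not the actual library class).
class ObserverSketch:
    def __init__(self, include_statistics: list[str] | None = None,
                 include_hparams: list[str] | None = None) -> None:
        self.include_statistics = include_statistics
        self.include_hparams = include_hparams

        if self.requires_include_statistics and not self.include_statistics:
            raise ValueError(f"{self.__class__.__name__} must be provided with include_statistics.")
        if self.requires_include_hparams and not self.include_hparams:
            raise ValueError(f"{self.__class__.__name__} must be provided with include_hparams.")

    @property
    def requires_include_statistics(self) -> bool:
        return False  # default: optional, mirroring the base Observer

    @property
    def requires_include_hparams(self) -> bool:
        return False


class RequiresStatistics(ObserverSketch):
    @property
    def requires_include_statistics(self) -> bool:
        return True  # opting in makes include_statistics mandatory


RequiresStatistics(include_statistics=["mean"])  # constructs fine
# RequiresStatistics()                           # would raise ValueError
```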
@@ -1,6 +1,6 @@
 # ======================================================================================================================
 #
-#
+# IMPORTS
 #
 # ======================================================================================================================

@@ -15,6 +15,12 @@ from libinephany.pydantic_models.schemas.observation_models import ObservationIn
 from libinephany.pydantic_models.schemas.tensor_statistics import TensorStatistics
 from libinephany.pydantic_models.states.hyperparameter_states import HyperparameterStates

+# ======================================================================================================================
+#
+# CLASSES
+#
+# ======================================================================================================================
+

 class LHOPTBaseObserver(GlobalObserver, ABC):
     """
@@ -33,13 +39,14 @@ class LHOPTBaseObserver(GlobalObserver, ABC):
         :param kwargs: Other observation keyword arguments.
         """
         super().__init__(**kwargs)
-        self.decay_factor = max(0.0, decay_factor)
-        self.time_window = max(1, time_window)

         # Store time series data for CDF calculation
         self._time_series: list[tuple[float, float]] = []  # (time, value) pairs
         self._current_time: float = 0.0

+        self.decay_factor = max(0.0, decay_factor)
+        self.time_window = max(1, time_window)
+
     @property
     def can_standardize(self) -> bool:
         """
@@ -48,6 +55,28 @@ class LHOPTBaseObserver(GlobalObserver, ABC):
         """
         return False

+    @staticmethod
+    def _compute_log_ratio(numerator: float, denominator: float) -> float:
+        """
+        Compute the log ratio.
+
+        :param numerator: Numerator value
+        :param denominator: Denominator value
+        :return: Log ratio value
+        """
+        # Calculate the ratio of numerator to denominator
+        invalid_denominator = math.isinf(denominator) or math.isnan(denominator)
+
+        if denominator <= LHOPT_CONSTANTS["ZERO_DIVISION_TOLERANCE"] or invalid_denominator:
+            return 0.0
+
+        ratio = numerator / denominator
+
+        if ratio <= 0:
+            return 0.0
+
+        return math.log(ratio)
+
     def _get_observation_format(self) -> StatisticStorageTypes:
         """
         :return: Format the observation returns data in. Must be one of the StatisticStorageTypes
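`_compute_log_ratio` is now a `@staticmethod` on `LHOPTBaseObserver`, so every subclass shares one guarded implementation. A runnable sketch of the same logic and the behaviour it implies; the tolerance value here is an assumed placeholder for `LHOPT_CONSTANTS["ZERO_DIVISION_TOLERANCE"]`:

```python
import math

ZERO_DIVISION_TOLERANCE = 1e-12  # assumed placeholder value


def compute_log_ratio(numerator: float, denominator: float) -> float:
    # Mirrors the guarded log ratio above: zero/inf/NaN denominators and
    # non-positive ratios fall back to 0.0 instead of raising.
    if denominator <= ZERO_DIVISION_TOLERANCE or math.isinf(denominator) or math.isnan(denominator):
        return 0.0
    ratio = numerator / denominator
    return math.log(ratio) if ratio > 0 else 0.0


print(compute_log_ratio(2.0, 1.0))   # ~0.693
print(compute_log_ratio(1.0, 0.0))   # 0.0 (guarded division)
print(compute_log_ratio(-1.0, 2.0))  # 0.0 (non-positive ratio)
```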
@@ -68,18 +97,6 @@ class LHOPTBaseObserver(GlobalObserver, ABC):
         """Update the current time counter."""
         self._current_time += 1.0

-    def get_required_trackers(self) -> dict[str, dict[str, Any] | None]:
-        """
-        :return: Dictionary mapping statistic tracker class names to kwargs for the class or None if no kwargs are
-            needed.
-        """
-        return {}
-
-    def reset(self) -> None:
-        """Reset the observer by clearing the time series."""
-        self._time_series = []
-        self._current_time = 0.0
-
     @abstractmethod
     def _observe(
         self,
@@ -94,27 +111,13 @@ class LHOPTBaseObserver(GlobalObserver, ABC):
         :param tracked_statistics: Dictionary mapping statistic tracker class names to dictionaries mapping module
         :param action_taken: Action taken by the agent this class instance is assigned to.
         """
-        raise NotImplementedError
-
-    def _compute_log_ratio(self, numerator: float, denominator: float) -> float:
-        """
-        Compute the log ratio.
-
-        :param numerator: Numerator value
-        :param denominator: Denominator value
-        :return: Log ratio value
-        """
-        # Calculate the ratio of numerator to denominator

-
-            return 0.0
+        ...

-
-
-
-
-
-        return math.log(ratio)
+    def reset(self) -> None:
+        """Reset the observer by clearing the time series."""
+        self._time_series = []
+        self._current_time = 0.0


 class LHOPTCheckpointBaseObserver(GlobalObserver, ABC):
@@ -128,8 +131,10 @@ class LHOPTCheckpointBaseObserver(GlobalObserver, ABC):
         :param kwargs: Miscellaneous keyword arguments.
         """
         super().__init__(**kwargs)
-
+
         self._history: list[float] = []
+
+        self.checkpoint_interval = checkpoint_interval
         self.last_value: float | None = None

     @property
@@ -175,18 +180,6 @@ class LHOPTCheckpointBaseObserver(GlobalObserver, ABC):
         if self.last_value is None:
             self.last_value = value

-    def get_required_trackers(self) -> dict[str, dict[str, Any] | None]:
-        """
-        :return: Dictionary mapping statistic tracker class names to kwargs for the class or None if no kwargs are
-            needed.
-        """
-        return {}
-
-    def reset(self) -> None:
-        """Reset the observer by clearing history."""
-        self._history = []
-        self.last_value = None
-
     @abstractmethod
     def _observe(
         self,
@@ -201,24 +194,17 @@ class LHOPTCheckpointBaseObserver(GlobalObserver, ABC):
         :param tracked_statistics: Dictionary mapping statistic tracker class names to dictionaries mapping module
         :param action_taken: Action taken by the agent this class instance is assigned to.
         """
-        raise NotImplementedError

-
-        """
-        Compute the log ratio.
+        ...

-
-        :param denominator: Denominator value
-        :return: Log ratio value
+    def get_required_trackers(self) -> dict[str, dict[str, Any] | None]:
         """
-
-
-
-
-
-        ratio = numerator / denominator
-
-        if ratio <= 0:
-            return 0.0
+        :return: Dictionary mapping statistic tracker class names to kwargs for the class or None if no kwargs are
+            needed.
+        """
+        return {}

-
+    def reset(self) -> None:
+        """Reset the observer by clearing history."""
+        self._history = []
+        self.last_value = None
@@ -1,11 +1,17 @@
 # ======================================================================================================================
 #
-#
+# IMPORTS
 #
 # ======================================================================================================================

 from typing import TypedDict

+# ======================================================================================================================
+#
+# CLASSES
+#
+# ======================================================================================================================
+

 class LHOPTConstants(TypedDict):
     IS_NAN: float
@@ -23,6 +29,13 @@ class LHOPTConstants(TypedDict):
     DEFAULT_ENV_STEP_SAMPLE_FREQUENCY: int


+# ======================================================================================================================
+#
+# CONSTANTS
+#
+# ======================================================================================================================
+
+
 # Create the constants instance
 LHOPT_CONSTANTS: LHOPTConstants = LHOPTConstants(
     IS_NAN=1.0,
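`constants.py` keeps exposing the LHOPT constants as a `TypedDict` instance, which leaves runtime access as a plain dictionary lookup while giving static type checkers the key and value types. A minimal sketch of the pattern; only keys visible in this diff are shown, and the tolerance value is a placeholder:

```python
from typing import TypedDict


class ConstantsSketch(TypedDict):
    IS_NAN: float
    ZERO_DIVISION_TOLERANCE: float


# Behaves like a normal dict at runtime; mypy checks the keys and value types.
CONSTANTS = ConstantsSketch(IS_NAN=1.0, ZERO_DIVISION_TOLERANCE=1e-12)
print(CONSTANTS["IS_NAN"])  # 1.0
```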
@@ -1,6 +1,6 @@
 # ======================================================================================================================
 #
-#
+# IMPORTS
 #
 # ======================================================================================================================

@@ -16,9 +16,23 @@ from libinephany.pydantic_models.schemas.observation_models import ObservationIn
 from libinephany.pydantic_models.schemas.tensor_statistics import TensorStatistics
 from libinephany.pydantic_models.states.hyperparameter_states import HyperparameterStates

+# ======================================================================================================================
+#
+# CLASSES
+#
+# ======================================================================================================================
+

 class GlobalFirstOrderGradients(GlobalObserver):

+    @property
+    def requires_include_statistics(self) -> bool:
+        """
+        :return: Whether the observation requires include_statistics to be provided.
+        """
+
+        return True
+
     def _get_observation_format(self) -> StatisticStorageTypes:
         """
         :return: Format the observation returns data in. Must be one of the enum attributes in the StatisticStorageTypes
@@ -74,6 +88,14 @@ class GlobalSecondOrderGradients(GlobalObserver):

         self.compute_hessian_diagonal = compute_hessian_diagonal

+    @property
+    def requires_include_statistics(self) -> bool:
+        """
+        :return: Whether the observation requires include_statistics to be provided.
+        """
+
+        return True
+
     def _get_observation_format(self) -> StatisticStorageTypes:
         """
         :return: Format the observation returns data in. Must be one of the enum attributes in the StatisticStorageTypes
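`GlobalSecondOrderGradients` keeps its `compute_hessian_diagonal` switch and now also requires `include_statistics`. For orientation only, a generic Hutchinson-style estimate of a Hessian diagonal is sketched below; this is a standard, well-known technique shown purely as illustration, and nothing in this diff shows that the library computes its second-order statistics this way:

```python
import torch


def hessian_diagonal_estimate(loss: torch.Tensor, params: list[torch.Tensor], samples: int = 10) -> list[torch.Tensor]:
    # Hutchinson estimator: E[z * (Hz)] with Rademacher z approximates diag(H).
    grads = torch.autograd.grad(loss, params, create_graph=True)
    estimates = [torch.zeros_like(p) for p in params]
    for _ in range(samples):
        zs = [torch.randint(0, 2, p.shape, dtype=p.dtype) * 2 - 1 for p in params]
        hvps = torch.autograd.grad(grads, params, grad_outputs=zs, retain_graph=True)
        for estimate, z, hvp in zip(estimates, zs, hvps):
            estimate += z * hvp / samples
    return estimates


w = torch.randn(3, requires_grad=True)
loss = (w ** 2).sum()                        # Hessian is 2 * I
print(hessian_diagonal_estimate(loss, [w]))  # each entry ~= 2.0
```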
@@ -136,21 +158,6 @@ class LHOPTGradientVarianceFraction(LHOPTBaseObserver):
         super().__init__(**kwargs)
         self.variance_threshold = variance_threshold

-    @property
-    def can_standardize(self) -> bool:
-        """
-        This observer has its own CDF calculation, no need to standardize.
-        :return: Whether the observation can be standardized.
-        """
-        return False
-
-    def _get_observation_format(self) -> StatisticStorageTypes:
-        """
-        :return: Format the observation returns data in. Must be one of the enum attributes in the StatisticStorageTypes
-            enumeration class.
-        """
-        return StatisticStorageTypes.VECTOR
-
     @property
     def vector_length(self) -> int:
         """
@@ -173,8 +180,6 @@ class LHOPTGradientVarianceFraction(LHOPTBaseObserver):
         :param action_taken: Action taken by the agent this class instance is assigned to.
         :return: Single float/int, list of floats/ints or TensorStatistics model to add to the observation vector.
         """
-        if statistic_trackers.GradientVarianceFraction.__name__ not in tracked_statistics:
-            return [0.0, 0.0]

         raw_value = list(tracked_statistics[statistic_trackers.GradientVarianceFraction.__name__].values())[0]  # type: ignore[list-item]

@@ -204,14 +209,6 @@ class LHOPTMomentumGradientRatio(LHOPTBaseObserver):
     It returns two-dimensional observations: [raw_value, cdf_feature] for momentum gradient ratio values.
     """

-    def _get_observation_format(self) -> StatisticStorageTypes:
-        """
-        :return: Format the observation returns data in. Must be one of the enum attributes in the StatisticStorageTypes
-            enumeration class.
-        """
-
-        return StatisticStorageTypes.VECTOR
-
     @property
     def vector_length(self) -> int:
         """
@@ -219,6 +216,14 @@ class LHOPTMomentumGradientRatio(LHOPTBaseObserver):
         """
         return 2  # [raw_value, cdf_feature]

+    @property
+    def requires_include_statistics(self) -> bool:
+        """
+        :return: Whether the observation requires include_statistics to be provided.
+        """
+
+        return True
+
     def _observe(
         self,
         observation_inputs: ObservationInputs,
@@ -266,29 +271,6 @@ class CosineSimilarityObserverOfGradientAndMomentum(LHOPTBaseObserver):
     It returns two-dimensional observations: [raw_value, cdf_feature] for cosine similarity of gradient and momentum values.
     """

-    def __init__(
-        self,
-        *,
-        include_statistics: list[str] | None = None,
-        **kwargs,
-    ) -> None:
-        """
-        :param include_statistics: List of statistics to include.
-        :param kwargs: Miscellaneous keyword arguments.
-        """
-
-        super().__init__(**kwargs)
-
-        self.include_statistics = include_statistics
-
-    def _get_observation_format(self) -> StatisticStorageTypes:
-        """
-        :return: Format the observation returns data in. Must be one of the enum attributes in the StatisticStorageTypes
-            enumeration class.
-        """
-
-        return StatisticStorageTypes.VECTOR
-
     @property
     def vector_length(self) -> int:
         """
@@ -296,6 +278,14 @@ class CosineSimilarityObserverOfGradientAndMomentum(LHOPTBaseObserver):
         """
         return 3  # [raw_value, cdf_feature, logit_of_cdf_feature]

+    @property
+    def requires_include_statistics(self) -> bool:
+        """
+        :return: Whether the observation requires include_statistics to be provided.
+        """
+
+        return True
+
     def _observe(
         self,
         observation_inputs: ObservationInputs,
@@ -351,29 +341,6 @@ class CosineSimilarityObserverOfGradientAndUpdate(LHOPTBaseObserver):
     It returns two-dimensional observations: [raw_value, cdf_feature] for cosine similarity of gradient and update values.
     """

-    def __init__(
-        self,
-        *,
-        include_statistics: list[str] | None = None,
-        **kwargs,
-    ) -> None:
-        """
-        :param include_statistics: List of statistics to include.
-        :param kwargs: Miscellaneous keyword arguments.
-        """
-
-        super().__init__(**kwargs)
-
-        self.include_statistics = include_statistics
-
-    def _get_observation_format(self) -> StatisticStorageTypes:
-        """
-        :return: Format the observation returns data in. Must be one of the enum attributes in the StatisticStorageTypes
-            enumeration class.
-        """
-
-        return StatisticStorageTypes.VECTOR
-
     @property
     def vector_length(self) -> int:
         """
@@ -381,6 +348,14 @@ class CosineSimilarityObserverOfGradientAndUpdate(LHOPTBaseObserver):
         """
         return 3  # [raw_value, cdf_feature, logit_of_cdf_feature]

+    @property
+    def requires_include_statistics(self) -> bool:
+        """
+        :return: Whether the observation requires include_statistics to be provided.
+        """
+
+        return True
+
     def _observe(
         self,
         observation_inputs: ObservationInputs,
@@ -436,28 +411,6 @@ class CosineSimilarityOfGradientAndParameter(LHOPTBaseObserver):
     It returns two-dimensional observations: [raw_value, cdf_feature] for cosine similarity of gradient and parameter values.
     """

-    def __init__(
-        self,
-        *,
-        include_statistics: list[str] | None = None,
-        **kwargs,
-    ) -> None:
-        """
-        :param include_statistics: List of statistics to include.
-        :param kwargs: Miscellaneous keyword arguments.
-        """
-        super().__init__(**kwargs)
-
-        self.include_statistics = include_statistics
-
-    def _get_observation_format(self) -> StatisticStorageTypes:
-        """
-        :return: Format the observation returns data in. Must be one of the enum attributes in the StatisticStorageTypes
-            enumeration class.
-        """
-
-        return StatisticStorageTypes.VECTOR
-
     @property
     def vector_length(self) -> int:
         """
@@ -465,6 +418,14 @@ class CosineSimilarityOfGradientAndParameter(LHOPTBaseObserver):
         """
         return 3  # [raw_value, cdf_feature, logit_of_cdf_feature]

+    @property
+    def requires_include_statistics(self) -> bool:
+        """
+        :return: Whether the observation requires include_statistics to be provided.
+        """
+
+        return True
+
     def _observe(
         self,
         observation_inputs: ObservationInputs,
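All three cosine-similarity observers now lean on the shared base class for their `include_statistics` handling and flag it as required; their three-dimensional output `[raw_value, cdf_feature, logit_of_cdf_feature]` is unchanged. A hedged sketch of what the third component could look like; the clamping epsilon and the exact transform are assumptions, not code from this diff:

```python
import math


def logit(p: float, eps: float = 1e-6) -> float:
    # Maps a CDF value in (0, 1) onto the real line; the eps clamping is an assumption.
    p = min(1.0 - eps, max(eps, p))
    return math.log(p / (1.0 - p))


raw_value = 0.42      # e.g. cosine similarity of gradient and momentum, in [-1, 1]
cdf_feature = 0.80    # CDF-transformed value in (0, 1)
observation = [raw_value, cdf_feature, logit(cdf_feature)]
print(observation)    # [0.42, 0.8, ~1.386]
```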
@@ -1,6 +1,6 @@
 # ======================================================================================================================
 #
-#
+# IMPORTS
 #
 # ======================================================================================================================

@@ -12,25 +12,28 @@ from torch.optim import SGD, Adam, AdamW
 from libinephany.observations import observation_utils
 from libinephany.observations.observation_utils import StatisticStorageTypes
 from libinephany.observations.observers.base_observers import GlobalObserver
-from libinephany.observations.observers.global_observers.
+from libinephany.observations.observers.global_observers.constants import LHOPT_CONSTANTS
 from libinephany.pydantic_models.schemas.observation_models import ObservationInputs
 from libinephany.pydantic_models.schemas.tensor_statistics import TensorStatistics
 from libinephany.pydantic_models.states.hyperparameter_states import HyperparameterStates
 from libinephany.utils.enums import ModelFamilies

+# ======================================================================================================================
+#
+# CLASSES
+#
+# ======================================================================================================================
+

 class InitialHyperparameters(GlobalObserver):

-    def __init__(self,
+    def __init__(self, pad_with: float = 0.0, **kwargs) -> None:
         """
-        :param include_hparams: Names of the hyperparameters to include in the initial values vector returned by
-            this observation.
         :param kwargs: Miscellaneous keyword arguments.
         """

         super().__init__(**kwargs)

-        self.include_hparams = include_hparams
         self.pad_with = pad_with

     @property
@@ -62,6 +65,14 @@ class InitialHyperparameters(GlobalObserver):

         return False

+    @property
+    def requires_include_hparams(self) -> bool:
+        """
+        :return: Whether the observation requires include_hparams to be provided.
+        """
+
+        return True
+
     def _get_observation_format(self) -> StatisticStorageTypes:
         """
         :return: Format the observation returns data in. Must be one of the enum attributes in the StatisticStorageTypes
@@ -298,7 +309,7 @@ class LHOPTHyperparameterRatio(GlobalObserver):
     providing insights into how much hyperparameters have changed from their starting values.
     """

-    def __init__(self,
+    def __init__(self, pad_with: float = 0.0, **kwargs) -> None:
         """
         :param include_hparams: Names of the hyperparameters to include in the initial values vector returned by
             this observation.
@@ -307,7 +318,6 @@ class LHOPTHyperparameterRatio(GlobalObserver):

         super().__init__(**kwargs)

-        self.include_hparams = include_hparams
         self.pad_with = pad_with

     @property
@@ -339,6 +349,14 @@ class LHOPTHyperparameterRatio(GlobalObserver):

         return False

+    @property
+    def requires_include_hparams(self) -> bool:
+        """
+        :return: Whether the observation requires include_hparams to be provided.
+        """
+
+        return True
+
     def _get_observation_format(self) -> StatisticStorageTypes:
         """
         :return: Format the observation returns data in. Must be one of the enum attributes in the StatisticStorageTypes
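With the duplicated `include_hparams` plumbing removed, `InitialHyperparameters` and `LHOPTHyperparameterRatio` rely on the base `Observer` to store the list and now declare it required, keeping only `pad_with` locally. A standalone sketch of the include/pad idea; the padding behaviour shown is an assumption inferred from the `pad_with` argument, not code in this diff:

```python
def build_hparam_vector(values: dict[str, float], include_hparams: list[str],
                        length: int, pad_with: float = 0.0) -> list[float]:
    # Keep only the requested hyperparameters, then pad to a fixed vector length.
    vector = [values[name] for name in include_hparams if name in values]
    return vector + [pad_with] * (length - len(vector))


print(build_hparam_vector({"learning_rate": 1e-3, "weight_decay": 0.01},
                          include_hparams=["learning_rate"], length=3))
# [0.001, 0.0, 0.0]
```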
@@ -239,6 +239,13 @@ class LHOPTTrainingLoss(LHOPTBaseObserver):
     This observer use the CDF calculation from the paper and applies CDF transformation using the CDF mean and std.
     """

+    @property
+    def vector_length(self) -> int:
+        """
+        :return: Length of the vector returned by this observation if it returns a vector.
+        """
+        return 3  # [is_nan, is_inf, cdf_feature]
+
     def _observe(
         self,
         observation_inputs: ObservationInputs,
@@ -271,18 +278,6 @@ class LHOPTTrainingLoss(LHOPTBaseObserver):

         return {}

-    def reset(self) -> None:
-        """Reset the observer by clearing the time series."""
-        self._time_series: list[tuple[float, float]] = []
-        self._current_time: float = 0.0
-
-    @property
-    def vector_length(self) -> int:
-        """
-        :return: Length of the vector returned by this observation if it returns a vector.
-        """
-        return 3  # [is_nan, is_inf, cdf_feature]
-

 class LHOPTValidationLoss(LHOPTBaseObserver):
     """
@@ -294,6 +289,13 @@ class LHOPTValidationLoss(LHOPTBaseObserver):
     This observer use the CDF calculation from the paper and applies CDF transformation using the CDF mean and std.
     """

+    @property
+    def vector_length(self) -> int:
+        """
+        :return: Length of the vector returned by this observation if it returns a vector.
+        """
+        return 3  # [is_nan, is_inf, cdf_feature]
+
     def _observe(
         self,
         observation_inputs: ObservationInputs,
@@ -326,18 +328,6 @@ class LHOPTValidationLoss(LHOPTBaseObserver):

         return {}

-    def reset(self) -> None:
-        """Reset the observer by clearing the time series."""
-        self._time_series: list[tuple[float, float]] = []
-        self._current_time: float = 0.0
-
-    @property
-    def vector_length(self) -> int:
-        """
-        :return: Length of the vector returned by this observation if it returns a vector.
-        """
-        return 3  # [is_nan, is_inf, cdf_feature]
-

 class LHOPTLossRatio(LHOPTBaseObserver):
     """
@@ -353,6 +343,13 @@ class LHOPTLossRatio(LHOPTBaseObserver):
     3. cdf_feature - CDF transformed feature using CDF mean and std
     """

+    @property
+    def vector_length(self) -> int:
+        """
+        :return: Length of the vector returned by this observation if it returns a vector.
+        """
+        return 3  # [is_nan, tanh, cdf_feature]
+
     def _observe(
         self,
         observation_inputs: ObservationInputs,
@@ -370,7 +367,7 @@ class LHOPTLossRatio(LHOPTBaseObserver):
         """

         log_ratio = self._compute_log_ratio(
-
+            numerator=observation_inputs.training_loss, denominator=observation_inputs.validation_loss
         )

         tanh_feature = math.tanh(max(-LHOPT_CONSTANTS["TANH_BOUND"], min(LHOPT_CONSTANTS["TANH_BOUND"], log_ratio)))
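`LHOPTLossRatio` now feeds the training and validation losses into the shared `_compute_log_ratio` helper (the removed private helper below computed validation over training; the new call is training over validation) and bounds the result before the tanh. A small worked example; the bound value here is a placeholder for `LHOPT_CONSTANTS["TANH_BOUND"]`:

```python
import math

TANH_BOUND = 10.0  # placeholder; the real bound lives in LHOPT_CONSTANTS["TANH_BOUND"]

training_loss, validation_loss = 2.0, 1.0
log_ratio = math.log(training_loss / validation_loss)                   # ~0.693
tanh_feature = math.tanh(max(-TANH_BOUND, min(TANH_BOUND, log_ratio)))  # ~0.600
# The observer emits [is_nan, tanh_feature, cdf_feature]; the CDF component depends on
# the observer's running time series and is not reproduced here.
print([0, round(tanh_feature, 3)])
```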
@@ -381,31 +378,13 @@ class LHOPTLossRatio(LHOPTBaseObserver):

         return [int(math.isnan(log_ratio)), tanh_feature, cdf_feature]

-    def
+    def get_required_trackers(self) -> dict[str, dict[str, Any] | None]:
         """
-
-
-        :param training_score: Training score value
-        :param validation_score: Validation score value
-        :return: Log ratio value
+        :return: Dictionary mapping statistic tracker class names to kwargs for the class or None if no kwargs are
+            needed.
         """
-        if training_score <= 0:
-            return 0.0
-
-        if validation_score <= 0:
-            return 0.0

-
-        score_ratio = validation_score / training_score
-
-        return math.log(score_ratio)
-
-    @property
-    def vector_length(self) -> int:
-        """
-        :return: Length of the vector returned by this observation if it returns a vector.
-        """
-        return 3  # [is_nan, tanh, cdf_feature]
+        return {}


 class PercentileOfLossAtEachCheckpoint(LHOPTCheckpointBaseObserver):
@@ -25,6 +25,14 @@ from libinephany.pydantic_models.states.hyperparameter_states import Hyperparame

 class GlobalActivations(GlobalObserver):

+    @property
+    def requires_include_statistics(self) -> bool:
+        """
+        :return: Whether the observation requires include_statistics to be provided.
+        """
+
+        return True
+
     def _get_observation_format(self) -> StatisticStorageTypes:
         """
         :return: Format the observation returns data in. Must be one of the enum attributes in the StatisticStorageTypes
@@ -64,6 +72,14 @@ class GlobalActivations(GlobalObserver):

 class GlobalParameterUpdates(GlobalObserver):

+    @property
+    def requires_include_statistics(self) -> bool:
+        """
+        :return: Whether the observation requires include_statistics to be provided.
+        """
+
+        return True
+
     def _get_observation_format(self) -> StatisticStorageTypes:
         """
         :return: Format the observation returns data in. Must be one of the enum attributes in the StatisticStorageTypes
@@ -103,6 +119,14 @@ class GlobalParameterUpdates(GlobalObserver):

 class GlobalParameters(GlobalObserver):

+    @property
+    def requires_include_statistics(self) -> bool:
+        """
+        :return: Whether the observation requires include_statistics to be provided.
+        """
+
+        return True
+
     def _get_observation_format(self) -> StatisticStorageTypes:
         """
         :return: Format the observation returns data in. Must be one of the enum attributes in the StatisticStorageTypes
@@ -342,6 +366,14 @@ class LogRatioOfPreviousAndCurrentParamNormEnvStepObserver(LHOPTBaseObserver):
         """
         return 2  # [tanh_feature, cdf_feature]

+    @property
+    def requires_include_statistics(self) -> bool:
+        """
+        :return: Whether the observation requires include_statistics to be provided.
+        """
+
+        return True
+
     def _observe(
         self,
         observation_inputs: ObservationInputs,
@@ -410,6 +442,14 @@ class LogRatioOfUpdateAndPreviousParamNormEnvStepObserver(LHOPTBaseObserver):
         """
         return 2  # [tanh_feature, cdf_feature]

+    @property
+    def requires_include_statistics(self) -> bool:
+        """
+        :return: Whether the observation requires include_statistics to be provided.
+        """
+
+        return True
+
     def _observe(
         self,
         observation_inputs: ObservationInputs,
@@ -473,28 +513,20 @@ class LogRatioOfUpdateAndPreviousParamNormEnvStepObserver(LHOPTBaseObserver):

 class LHOPTAverageParameterUpdateMagnitudeObserver(LHOPTBaseObserver):

-    def _get_observation_format(self) -> StatisticStorageTypes:
-        """
-        :return: Format the observation returns data in. Must be one of the enum attributes in the StatisticStorageTypes
-            enumeration class.
-        """
-
-        return StatisticStorageTypes.VECTOR
-
     @property
-    def
+    def vector_length(self) -> int:
         """
-        :return:
+        :return: Length of the vector returned by this observation if it returns a vector.
         """
-
-        return False
+        return 2  # [raw_feature, cdf_feature]

     @property
-    def
+    def requires_include_statistics(self) -> bool:
         """
-        :return:
+        :return: Whether the observation requires include_statistics to be provided.
         """
-
+
+        return True

     def _observe(
         self,
@@ -537,8 +569,8 @@ class LHOPTAverageParameterUpdateMagnitudeObserver(LHOPTBaseObserver):
 class LogRatioOfUpdateAndPreviousParamNormInnerStepObserver(LHOPTBaseObserver):
     def __init__(self, **kwargs):
         """
-        This observer is used to compute the log ratio of the update and previous parameter norm for the inner step.
-
+        This observer is used to compute the log ratio of the update and previous parameter norm for the inner step.
+        The sample frequency of the statistics needs to be set to 4 (according to the OpenAI paper).
         """
         super().__init__(**kwargs)
         self._previous_param_norm = None
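The expanded docstring pins down when `LogRatioOfUpdateAndPreviousParamNormInnerStepObserver` expects to be fed: its statistics should be sampled every 4 inner optimiser steps. Purely for orientation, the quantity suggested by the class name and the shared log-ratio helper would look like the sketch below; the exact tensors the tracker uses are not shown in this diff:

```python
import math

# Illustrative numbers only.
previous_param_norm = 10.0   # ||theta|| before the sampled inner step
update_norm = 0.05           # ||delta theta|| for that step

log_ratio = math.log(update_norm / previous_param_norm)  # ~ -5.3; larger updates push this toward 0
print(round(log_ratio, 2))
```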
@@ -550,6 +582,14 @@ class LogRatioOfUpdateAndPreviousParamNormInnerStepObserver(LHOPTBaseObserver):
         """
         return 2  # [tanh_feature, cdf_feature]

+    @property
+    def requires_include_statistics(self) -> bool:
+        """
+        :return: Whether the observation requires include_statistics to be provided.
+        """
+
+        return True
+
     def _observe(
         self,
         observation_inputs: ObservationInputs,
@@ -630,14 +670,6 @@ class LHOPTGlobalLAMBTrustRatio(LHOPTBaseObserver):

         self.use_log_transform = use_log_transform

-    def _get_observation_format(self) -> StatisticStorageTypes:
-        """
-        :return: Format the observation returns data in. Must be one of the enum attributes in the StatisticStorageTypes
-            enumeration class.
-        """
-
-        return StatisticStorageTypes.VECTOR
-
     @property
     def vector_length(self) -> int:
         """
@@ -645,6 +677,14 @@ class LHOPTGlobalLAMBTrustRatio(LHOPTBaseObserver):
         """
         return 2  # [raw_value, cdf_feature]

+    @property
+    def requires_include_statistics(self) -> bool:
+        """
+        :return: Whether the observation requires include_statistics to be provided.
+        """
+
+        return True
+
     def _observe(
         self,
         observation_inputs: ObservationInputs,
@@ -27,6 +27,14 @@ from libinephany.utils.transforms import HyperparameterTransformType

 class FirstOrderGradients(LocalObserver):

+    @property
+    def requires_include_statistics(self) -> bool:
+        """
+        :return: Whether the observation requires include_statistics to be provided.
+        """
+
+        return True
+
     def _get_observation_format(self) -> StatisticStorageTypes:
         """
         :return: Format the observation returns data in. Must be one of the enum attributes in the StatisticStorageTypes
@@ -89,6 +97,14 @@ class SecondOrderGradients(LocalObserver):

         self.compute_hessian_diagonal = compute_hessian_diagonal

+    @property
+    def requires_include_statistics(self) -> bool:
+        """
+        :return: Whether the observation requires include_statistics to be provided.
+        """
+
+        return True
+
     def _get_observation_format(self) -> StatisticStorageTypes:
         """
         :return: Format the observation returns data in. Must be one of the enum attributes in the StatisticStorageTypes
@@ -139,6 +155,14 @@ class SecondOrderGradients(LocalObserver):

 class Activations(LocalObserver):

+    @property
+    def requires_include_statistics(self) -> bool:
+        """
+        :return: Whether the observation requires include_statistics to be provided.
+        """
+
+        return True
+
     def _get_observation_format(self) -> StatisticStorageTypes:
         """
         :return: Format the observation returns data in. Must be one of the enum attributes in the StatisticStorageTypes
@@ -185,6 +209,14 @@ class Activations(LocalObserver):

 class ParameterUpdates(LocalObserver):

+    @property
+    def requires_include_statistics(self) -> bool:
+        """
+        :return: Whether the observation requires include_statistics to be provided.
+        """
+
+        return True
+
     def _get_observation_format(self) -> StatisticStorageTypes:
         """
         :return: Format the observation returns data in. Must be one of the enum attributes in the StatisticStorageTypes
@@ -231,6 +263,14 @@ class ParameterUpdates(LocalObserver):

 class Parameters(LocalObserver):

+    @property
+    def requires_include_statistics(self) -> bool:
+        """
+        :return: Whether the observation requires include_statistics to be provided.
+        """
+
+        return True
+
     def _get_observation_format(self) -> StatisticStorageTypes:
         """
         :return: Format the observation returns data in. Must be one of the enum attributes in the StatisticStorageTypes
@@ -624,6 +664,7 @@ class ModuleTypeOneHot(LocalObserver):
         "attention": 1,
         "linear": 2,
         "embedding": 3,
+        "lstm": 4,
     }

     @property
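`ModuleTypeOneHot` now recognises LSTM modules alongside attention, linear and embedding layers. A standalone sketch of the resulting one-hot encoding; index 0 belongs to whatever bucket the (unshown) first mapping entry covers, so the fallback used here is an assumption:

```python
MODULE_TYPE_TO_IDX = {"attention": 1, "linear": 2, "embedding": 3, "lstm": 4}


def one_hot(module_type: str, num_types: int = 5) -> list[float]:
    vector = [0.0] * num_types
    vector[MODULE_TYPE_TO_IDX.get(module_type, 0)] = 1.0  # assumed fallback to index 0
    return vector


print(one_hot("lstm"))  # [0.0, 0.0, 0.0, 0.0, 1.0]
print(one_hot("conv"))  # [1.0, 0.0, 0.0, 0.0, 0.0]  (fallback bucket, assumed)
```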
@@ -692,17 +733,6 @@ class ModuleTypeOneHot(LocalObserver):

 class CurrentHyperparameters(LocalObserver):

-    def __init__(self, include_hparams: list[str] | None = None, **kwargs) -> None:
-        """
-        :param include_hparams: Names of the hyperparameters to include in the initial values vector returned by
-            this observation.
-        :param kwargs: Miscellaneous keyword arguments.
-        """
-
-        super().__init__(**kwargs)
-
-        self.include_hparams = include_hparams
-
     @property
     def can_standardize(self) -> bool:
         """
@@ -732,6 +762,14 @@ class CurrentHyperparameters(LocalObserver):

         return len([hparam for hparam in available_hparams if hparam in self.include_hparams])

+    @property
+    def requires_include_hparams(self) -> bool:
+        """
+        :return: Whether the observation requires include_hparams to be provided.
+        """
+
+        return True
+
     def _get_observation_format(self) -> StatisticStorageTypes:
         """
         :return: Format the observation returns data in. Must be one of the enum attributes in the StatisticStorageTypes
@@ -777,17 +815,6 @@ class CurrentHyperparameters(LocalObserver):

 class CurrentHyperparameterDeltas(LocalObserver):

-    def __init__(self, include_hparams: list[str] | None = None, **kwargs) -> None:
-        """
-        :param include_hparams: Names of the hyperparameters to include in the initial deltas vector returned by
-            this observation.
-        :param kwargs: Miscellaneous keyword arguments.
-        """
-
-        super().__init__(**kwargs)
-
-        self.include_hparams = include_hparams
-
     @property
     def can_standardize(self) -> bool:
         """
@@ -817,6 +844,14 @@ class CurrentHyperparameterDeltas(LocalObserver):

         return len([hparam for hparam in available_hparams if hparam in self.include_hparams])

+    @property
+    def requires_include_hparams(self) -> bool:
+        """
+        :return: Whether the observation requires include_hparams to be provided.
+        """
+
+        return True
+
     def _get_observation_format(self) -> StatisticStorageTypes:
         """
         :return: Format the observation returns data in. Must be one of the enum attributes in the StatisticStorageTypes
@@ -865,17 +900,6 @@ class HyperparameterTransformTypes(LocalObserver):

     TRANSFORM_TYPE_TO_IDX = dict(((s, i) for i, s in enumerate(HyperparameterTransformType)))

-    def __init__(self, include_hparams: list[str] | None = None, **kwargs) -> None:
-        """
-        :param include_hparams: Names of the hyperparameters to include in the transforms vector returned by
-            this observation.
-        :param kwargs: Miscellaneous keyword arguments.
-        """
-
-        super().__init__(**kwargs)
-
-        self.include_hparams = include_hparams
-
     @property
     def can_standardize(self) -> bool:
         """
@@ -907,6 +931,14 @@ class HyperparameterTransformTypes(LocalObserver):
             [hparam for hparam in available_hparams if hparam in self.include_hparams]
         )

+    @property
+    def requires_include_hparams(self) -> bool:
+        """
+        :return: Whether the observation requires include_hparams to be provided.
+        """
+
+        return True
+
     def _get_observation_format(self) -> StatisticStorageTypes:
         """
         :return: Format the observation returns data in. Must be one of the enum attributes in the StatisticStorageTypes
@@ -1087,7 +1119,6 @@ class LogOfNoiseScaleObserver(LocalObserver):
         *,
         decay_factor: float = LHOPT_CONSTANTS["DEFAULT_DECAY_FACTOR"],
         time_window: int = LHOPT_CONSTANTS["DEFAULT_TIME_WINDOW"],
-        include_statistics: list[str] | None = None,
         **kwargs,
     ) -> None:
         """
@@ -1100,7 +1131,6 @@ class LogOfNoiseScaleObserver(LocalObserver):

         super().__init__(**kwargs)

-        self.include_statistics = include_statistics
         self.decay_factor = max(0.0, decay_factor)
         self.time_window = max(1, time_window)

@@ -1108,14 +1138,6 @@ class LogOfNoiseScaleObserver(LocalObserver):
         self._time_series: list[tuple[float, float]] = []  # (time, value) pairs
         self._current_time: float = 0.0

-    def _get_observation_format(self) -> StatisticStorageTypes:
-        """
-        :return: Format the observation returns data in. Must be one of the enum attributes in the StatisticStorageTypes
-            enumeration class.
-        """
-
-        return StatisticStorageTypes.VECTOR
-
     @property
     def can_standardize(self) -> bool:
         """
@@ -1132,6 +1154,29 @@ class LogOfNoiseScaleObserver(LocalObserver):

         return False

+    @property
+    def vector_length(self) -> int:
+        """
+        :return: Length of the vector returned by this observation if it returns a vector.
+        """
+        return 2  # [log_noise_scale, cdf_feature]
+
+    @property
+    def requires_include_statistics(self) -> bool:
+        """
+        :return: Whether the observation requires include_statistics to be provided.
+        """
+
+        return True
+
+    def _get_observation_format(self) -> StatisticStorageTypes:
+        """
+        :return: Format the observation returns data in. Must be one of the enum attributes in the StatisticStorageTypes
+            enumeration class.
+        """
+
+        return StatisticStorageTypes.VECTOR
+
     def _update_time(self) -> None:
         """Update the current time counter."""
         self._current_time += 1.0
@@ -1145,13 +1190,6 @@ class LogOfNoiseScaleObserver(LocalObserver):
         """
         return compute_cdf_feature(value, self._time_series, self.decay_factor, self._current_time, self.time_window)

-    @property
-    def vector_length(self) -> int:
-        """
-        :return: Length of the vector returned by this observation if it returns a vector.
-        """
-        return 2  # [log_noise_scale, cdf_feature]
-
     def _observe(
         self,
         observation_inputs: ObservationInputs,
@@ -1169,16 +1207,16 @@ class LogOfNoiseScaleObserver(LocalObserver):
         """

         statistics = tracked_statistics[statistic_trackers.LogOfNoiseScaleStatistics.__name__]
-
         raw_value = list(statistics.values())[0]  # type: ignore[list-item]
+
         assert isinstance(raw_value, float), f"Expected float, got {type(raw_value)}"  # to avoid type errors with mypy
+
         batch_size = hyperparameter_states.global_hparams.batch_size.external_value
         learning_rate = hyperparameter_states.parameter_group_hparams[
             self.parameter_group_name
         ].learning_rate.external_value

         log_b_over_epsilon = math.log(batch_size / learning_rate)
-
         log_noise_scale = raw_value + log_b_over_epsilon

         cdf_feature = self._compute_cdf_feature(log_noise_scale)  # type: ignore[arg-type]
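The reordered `_observe` body makes the noise-scale combination easier to follow: the tracked statistic supplies a raw log term, and the observer adds `log(batch_size / learning_rate)` before the CDF transform. A worked example with illustrative numbers:

```python
import math

raw_value = -2.0                      # from LogOfNoiseScaleStatistics (illustrative value)
batch_size, learning_rate = 256, 1e-3

log_b_over_epsilon = math.log(batch_size / learning_rate)  # ~12.45
log_noise_scale = raw_value + log_b_over_epsilon           # ~10.45
# The observer then returns [log_noise_scale, cdf_feature], the second component coming
# from its decayed time-series CDF.
print(round(log_noise_scale, 2))
```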
libinephany/utils/enums.py
CHANGED
@@ -8,16 +8,16 @@ libinephany/observations/pipeline_coordinator.py,sha256=mLfaHhkXVhMp9w5jWIAL3jPy
 libinephany/observations/statistic_manager.py,sha256=LLg1zSxnJr2oQQepYla3qoUuRy10rsthr9jta4wEbnc,8956
 libinephany/observations/statistic_trackers.py,sha256=F98V-H2Ljx0v2YnppYCCJLJojL6pzYBdbBh8Lb4lasA,47666
 libinephany/observations/observers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-libinephany/observations/observers/base_observers.py,sha256=
-libinephany/observations/observers/local_observers.py,sha256=
+libinephany/observations/observers/base_observers.py,sha256=V8PIysq2wT6K-w_CqeM5benyif-xK1hPT3M6a4ic1So,17535
+libinephany/observations/observers/local_observers.py,sha256=NVH7fIaV2rIIvXvyntlwOHaKhGyxj_zenqG-4SOAeNM,46156
 libinephany/observations/observers/observer_containers.py,sha256=VNyqGgxYJ4r49Msp_kk-POgicb-_5w54twuT1qfNdxw,9562
 libinephany/observations/observers/global_observers/__init__.py,sha256=87WHRPYmL0tVsaTKUd91pwEpCZtHPSKRQoba2VQjswA,3018
-libinephany/observations/observers/global_observers/base_classes.py,sha256=
-libinephany/observations/observers/global_observers/constants.py,sha256=
-libinephany/observations/observers/global_observers/gradient_observers.py,sha256=
-libinephany/observations/observers/global_observers/hyperparameter_observers.py,sha256=
-libinephany/observations/observers/global_observers/loss_observers.py,sha256=
-libinephany/observations/observers/global_observers/model_observers.py,sha256=
+libinephany/observations/observers/global_observers/base_classes.py,sha256=Q7OblhmKscypTs9JBepSQwo6ljjOdPKTU9kbpuhq_W4,7800
+libinephany/observations/observers/global_observers/constants.py,sha256=TDQM_sGU8Swze794oB4TBaXFjSddt0OBhYPVhrXQ9Ko,1654
+libinephany/observations/observers/global_observers/gradient_observers.py,sha256=j9uX7043ic06W6vb6vDt_PH2a5WtLFCdKQZ1JQGT24Q,19531
+libinephany/observations/observers/global_observers/hyperparameter_observers.py,sha256=5Av8FgwWBJtcn4gpDzPdOTlKOOYy2lVEHS_gt5Sz7xo,15334
+libinephany/observations/observers/global_observers/loss_observers.py,sha256=Kf943FiuYWuWvjmhgmp3TGyIQoZ27ZKJcxfTBwXS-gA,17761
+libinephany/observations/observers/global_observers/model_observers.py,sha256=SGWXrmTdgp0kHvEvDSF7d3v1FEcK1sQDMPLQ8Wy3qv4,29306
 libinephany/observations/observers/global_observers/progress_observers.py,sha256=ypLk1_POAjA8V8rAaQ0B6Qh8m_04s9PAoXsw1KxVrLg,5872
 libinephany/observations/post_processors/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 libinephany/observations/post_processors/postprocessors.py,sha256=43_e5UaDPr2KbAvqc_w3wLqnlm7bgRjqgCtyQ95-8cM,5913
@@ -42,7 +42,7 @@ libinephany/utils/backend_statuses.py,sha256=ZbpBPbz0qKmeqxyGGN_ePTrQ7Wrxh7KM6W2
 libinephany/utils/constants.py,sha256=XAOuPowvM4FDSbfvNsubKTAqSB84AANX4CoHb7LwgEI,2330
 libinephany/utils/directory_utils.py,sha256=408unVeE_5_Hm-ZYZuxc9sdvfuU0CgYELX7EzPlPieo,1217
 libinephany/utils/dropout_utils.py,sha256=X43yCW7Dh1cC5sNnivgS5j1fn871K_RCvxCBTT0YHKg,3392
-libinephany/utils/enums.py,sha256=
+libinephany/utils/enums.py,sha256=XWC98mOny7wRvnUdBPh4gwpfz7BuDi-RHsHlpFyr-3A,2995
 libinephany/utils/error_severities.py,sha256=B9oidqOVaYOe0W6P6GwjpmuDsrkyTX30v1xdiUStCFk,1427
 libinephany/utils/exceptions.py,sha256=kgwLpHOgy3kciUz_I18xnYsWRtzdonfadUtwG2uDYk8,1823
 libinephany/utils/import_utils.py,sha256=WzC6V6UIa0nCiU2MekROwG82fWBh9RuVzichtby5EvM,1495
@@ -57,8 +57,8 @@ libinephany/utils/typing.py,sha256=rGbaPO3MaUndsWiC_wHzReD_TOLYqb43i01pKN-j7Xs,6
 libinephany/web_apps/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 libinephany/web_apps/error_logger.py,sha256=gAQIaqerqP4ornXZwFF1cghjnd2mMZEt3aVrTuUCr34,16653
 libinephany/web_apps/web_app_utils.py,sha256=qiq_lasPipgN1RgRudPJc342kYci8O_4RqppxmIX8NY,4095
-libinephany-1.0.
-libinephany-1.0.
-libinephany-1.0.
-libinephany-1.0.
-libinephany-1.0.
+libinephany-1.0.3.dist-info/licenses/LICENSE,sha256=pogfDoMBP07ehIOvWymuWIar8pg2YLUhqOHsJQU3wdc,9250
+libinephany-1.0.3.dist-info/METADATA,sha256=AJHvThkrO8rYniWmXEWyyH8hnKNMIllOu7VGvCbeTo0,8389
+libinephany-1.0.3.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+libinephany-1.0.3.dist-info/top_level.txt,sha256=bYAOXQdJgIoLkO2Ui0kxe7pSYegS_e38u0dMscd7COQ,12
+libinephany-1.0.3.dist-info/RECORD,,
File without changes
File without changes
File without changes