libinephany 0.16.2__py3-none-any.whl → 0.16.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- libinephany/observations/observation_utils.py +167 -13
- libinephany/observations/observers/global_observers/__init__.py +58 -0
- libinephany/observations/observers/global_observers/base_classes.py +183 -0
- libinephany/observations/observers/global_observers/constants.py +33 -0
- libinephany/observations/observers/global_observers/gradient_observers.py +112 -0
- libinephany/observations/observers/global_observers/hyperparameter_observers.py +286 -0
- libinephany/observations/observers/global_observers/loss_observers.py +464 -0
- libinephany/observations/observers/global_observers/model_observers.py +327 -0
- libinephany/observations/observers/global_observers/progress_observers.py +142 -0
- libinephany/observations/observers/observer_containers.py +3 -1
- libinephany/utils/constants.py +1 -0
- libinephany/utils/enums.py +18 -0
- libinephany/utils/samplers.py +36 -1
- {libinephany-0.16.2.dist-info → libinephany-0.16.4.dist-info}/METADATA +2 -1
- {libinephany-0.16.2.dist-info → libinephany-0.16.4.dist-info}/RECORD +18 -11
- libinephany/observations/observers/global_observers.py +0 -991
- {libinephany-0.16.2.dist-info → libinephany-0.16.4.dist-info}/WHEEL +0 -0
- {libinephany-0.16.2.dist-info → libinephany-0.16.4.dist-info}/licenses/LICENSE +0 -0
- {libinephany-0.16.2.dist-info → libinephany-0.16.4.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,327 @@
|
|
1
|
+
# ======================================================================================================================
|
2
|
+
#
|
3
|
+
# IMPORTS
|
4
|
+
#
|
5
|
+
# ======================================================================================================================
|
6
|
+
|
7
|
+
import math
|
8
|
+
from typing import Any
|
9
|
+
|
10
|
+
from libinephany.observations import observation_utils, statistic_trackers
|
11
|
+
from libinephany.observations.observation_utils import StatisticStorageTypes
|
12
|
+
from libinephany.observations.observers.base_observers import GlobalObserver
|
13
|
+
from libinephany.pydantic_models.schemas.observation_models import ObservationInputs
|
14
|
+
from libinephany.pydantic_models.schemas.tensor_statistics import TensorStatistics
|
15
|
+
from libinephany.pydantic_models.states.hyperparameter_states import HyperparameterStates
|
16
|
+
|
17
|
+
# ======================================================================================================================
|
18
|
+
#
|
19
|
+
# CLASSES
|
20
|
+
#
|
21
|
+
# ======================================================================================================================
|
22
|
+
|
23
|
+
|
24
|
+
class GlobalActivations(GlobalObserver):
    """
    Global observer that collapses the per-module activation statistics into a single averaged
    TensorStatistics model.
    """

    def _get_observation_format(self) -> StatisticStorageTypes:
        """
        :return: StatisticStorageTypes.TENSOR_STATISTICS, because the observation is one aggregated
        TensorStatistics model.
        """

        return StatisticStorageTypes.TENSOR_STATISTICS

    def _observe(
        self,
        observation_inputs: ObservationInputs,
        hyperparameter_states: HyperparameterStates,
        tracked_statistics: dict[str, dict[str, float | TensorStatistics]],
        action_taken: float | int | None,
    ) -> float | int | list[int | float] | TensorStatistics:
        """
        Averages the activation statistics of every tracked module.

        :param observation_inputs: Observation input metrics not calculated with statistic trackers (not used here).
        :param hyperparameter_states: HyperparameterStates that manages the hyperparameters (not used here).
        :param tracked_statistics: Dictionary mapping statistic tracker class names to dictionaries mapping module
        names to floats or TensorStatistic models.
        :param action_taken: Action taken by the agent this class instance is assigned to (not used here).
        :return: Output of observation_utils.average_tensor_statistics over all per-module activation statistics.
        """

        tracker_name = statistic_trackers.ActivationStatistics.__name__
        per_module_statistics = tracked_statistics[tracker_name]

        return observation_utils.average_tensor_statistics(  # type: ignore
            tensor_statistics=[*per_module_statistics.values()]
        )

    def get_required_trackers(self) -> dict[str, dict[str, Any] | None]:
        """
        :return: Mapping from the ActivationStatistics tracker class name to its constructor kwargs; the
        observer's skip_statistics setting is forwarded to the tracker.
        """

        tracker_kwargs: dict[str, Any] = {"skip_statistics": self.skip_statistics}

        return {statistic_trackers.ActivationStatistics.__name__: tracker_kwargs}
|
61
|
+
|
62
|
+
|
63
|
+
class GlobalParameterUpdates(GlobalObserver):
    """
    Global observer reporting the average of the per-module parameter-update statistics.
    """

    def _get_observation_format(self) -> StatisticStorageTypes:
        """
        :return: StatisticStorageTypes.TENSOR_STATISTICS; a single averaged TensorStatistics model is emitted.
        """

        return StatisticStorageTypes.TENSOR_STATISTICS

    def _observe(
        self,
        observation_inputs: ObservationInputs,
        hyperparameter_states: HyperparameterStates,
        tracked_statistics: dict[str, dict[str, float | TensorStatistics]],
        action_taken: float | int | None,
    ) -> float | int | list[int | float] | TensorStatistics:
        """
        Combines the ParameterUpdateStatistics of all tracked modules into one model.

        :param observation_inputs: Observation input metrics not calculated with statistic trackers (not used here).
        :param hyperparameter_states: HyperparameterStates that manages the hyperparameters (not used here).
        :param tracked_statistics: Dictionary mapping statistic tracker class names to dictionaries mapping module
        names to floats or TensorStatistic models.
        :param action_taken: Action taken by the agent this class instance is assigned to (not used here).
        :return: Output of observation_utils.average_tensor_statistics over the per-module update statistics.
        """

        update_statistics = tracked_statistics[statistic_trackers.ParameterUpdateStatistics.__name__]
        module_summaries = list(update_statistics.values())
        averaged = observation_utils.average_tensor_statistics(tensor_statistics=module_summaries)

        return averaged  # type: ignore

    def get_required_trackers(self) -> dict[str, dict[str, Any] | None]:
        """
        :return: Mapping from the ParameterUpdateStatistics tracker class name to its kwargs, forwarding this
        observer's skip_statistics setting.
        """

        tracker_name = statistic_trackers.ParameterUpdateStatistics.__name__

        return {tracker_name: {"skip_statistics": self.skip_statistics}}
|
100
|
+
|
101
|
+
|
102
|
+
class GlobalParameters(GlobalObserver):
    """
    Global observer producing a model-wide average of the per-module parameter statistics.
    """

    def _get_observation_format(self) -> StatisticStorageTypes:
        """
        :return: StatisticStorageTypes.TENSOR_STATISTICS, matching the averaged TensorStatistics returned by
        _observe.
        """

        return StatisticStorageTypes.TENSOR_STATISTICS

    def _observe(
        self,
        observation_inputs: ObservationInputs,
        hyperparameter_states: HyperparameterStates,
        tracked_statistics: dict[str, dict[str, float | TensorStatistics]],
        action_taken: float | int | None,
    ) -> float | int | list[int | float] | TensorStatistics:
        """
        Averages the ParameterStatistics gathered for each tracked module.

        :param observation_inputs: Observation input metrics not calculated with statistic trackers (not used here).
        :param hyperparameter_states: HyperparameterStates that manages the hyperparameters (not used here).
        :param tracked_statistics: Dictionary mapping statistic tracker class names to dictionaries mapping module
        names to floats or TensorStatistic models.
        :param action_taken: Action taken by the agent this class instance is assigned to (not used here).
        :return: Output of observation_utils.average_tensor_statistics over the per-module parameter statistics.
        """

        tracker_key = statistic_trackers.ParameterStatistics.__name__

        return observation_utils.average_tensor_statistics(  # type: ignore
            tensor_statistics=list(tracked_statistics[tracker_key].values())
        )

    def get_required_trackers(self) -> dict[str, dict[str, Any] | None]:
        """
        :return: Mapping from the ParameterStatistics tracker class name to its kwargs (this observer's
        skip_statistics setting).
        """

        required: dict[str, dict[str, Any] | None] = {}
        required[statistic_trackers.ParameterStatistics.__name__] = {"skip_statistics": self.skip_statistics}

        return required
|
139
|
+
|
140
|
+
|
141
|
+
class GlobalLAMBTrustRatio(GlobalObserver):
    """
    Global observer returning the mean LAMB trust ratio across all tracked modules.
    """

    def __init__(
        self,
        *,
        use_log_transform: bool = False,
        **kwargs,
    ) -> None:
        """
        :param use_log_transform: Whether to transform the LAMB trust ratio by taking ln(1 + R). The flag is
        forwarded to the LAMBTrustRatioStatistics tracker, which performs the transform.
        :param kwargs: Other observation keyword arguments, passed to the base observer.
        """

        # The base class is initialised first, then the flag is stored (same order as before).
        super().__init__(**kwargs)

        self.use_log_transform = use_log_transform

    def _get_observation_format(self) -> StatisticStorageTypes:
        """
        :return: StatisticStorageTypes.FLOAT, since the observation is one averaged scalar.
        """

        return StatisticStorageTypes.FLOAT

    def _observe(
        self,
        observation_inputs: ObservationInputs,
        hyperparameter_states: HyperparameterStates,
        tracked_statistics: dict[str, dict[str, float | TensorStatistics]],
        action_taken: float | int | None,
    ) -> float | int | list[int | float] | TensorStatistics:
        """
        Computes the arithmetic mean of the per-module trust ratios.

        :param observation_inputs: Observation input metrics not calculated with statistic trackers (not used here).
        :param hyperparameter_states: HyperparameterStates that manages the hyperparameters (not used here).
        :param tracked_statistics: Dictionary mapping statistic tracker class names to dictionaries mapping module
        names to floats or TensorStatistic models.
        :param action_taken: Action taken by the agent this class instance is assigned to (not used here).
        :return: Mean of the values reported by the LAMBTrustRatioStatistics tracker.
        :raises ZeroDivisionError: If the tracker reported no modules.
        """

        trust_ratios = tracked_statistics[statistic_trackers.LAMBTrustRatioStatistics.__name__]
        total = sum(trust_ratios.values())  # type: ignore

        return total / len(trust_ratios)  # type: ignore

    def get_required_trackers(self) -> dict[str, dict[str, Any] | None]:
        """
        :return: Mapping from the LAMBTrustRatioStatistics tracker class name to its kwargs, which carry the
        use_log_transform flag.
        """

        tracker_kwargs: dict[str, Any] = {"use_log_transform": self.use_log_transform}

        return {statistic_trackers.LAMBTrustRatioStatistics.__name__: tracker_kwargs}
|
193
|
+
|
194
|
+
|
195
|
+
class NumberOfParameters(GlobalObserver):
    """
    Global observer reporting the model's parameter count, optionally as ln(1 + N).
    """

    def __init__(
        self,
        *,
        use_log_transform: bool = True,
        **kwargs,
    ) -> None:
        """
        :param use_log_transform: Whether to transform the return of the Observer by ln(1 + N).
        :param kwargs: Miscellaneous keyword arguments, passed to the base observer.
        """

        super().__init__(**kwargs)

        self.use_log_transform = use_log_transform

    @property
    def can_standardize(self) -> bool:
        """
        :return: False; the parameter count is never standardized.
        """

        return False

    def _get_observation_format(self) -> StatisticStorageTypes:
        """
        :return: StatisticStorageTypes.FLOAT, as a single scalar count is emitted.
        """

        return StatisticStorageTypes.FLOAT

    def _observe(
        self,
        observation_inputs: ObservationInputs,
        hyperparameter_states: HyperparameterStates,
        tracked_statistics: dict[str, dict[str, float | TensorStatistics]],
        action_taken: float | int | None,
    ) -> float | int | list[int | float] | TensorStatistics:
        """
        Reads the parameter count from the NumberOfParameters tracker.

        :param observation_inputs: Observation input metrics not calculated with statistic trackers (not used here).
        :param hyperparameter_states: HyperparameterStates that manages the hyperparameters (not used here).
        :param tracked_statistics: Dictionary mapping statistic tracker class names to dictionaries mapping module
        names to floats or TensorStatistic models.
        :param action_taken: Action taken by the agent this class instance is assigned to (not used here).
        :return: The first value reported by the tracker, or ln(1 + value) when use_log_transform is set.
        """

        parameter_counts = tracked_statistics[statistic_trackers.NumberOfParameters.__name__]
        # Only the first entry is used; presumably the tracker reports a single model-wide value — TODO confirm.
        count = [*parameter_counts.values()][0]

        return math.log(1 + count) if self.use_log_transform else count  # type: ignore

    def get_required_trackers(self) -> dict[str, dict[str, Any] | None]:
        """
        :return: Mapping from the NumberOfParameters tracker class name to None (the tracker needs no kwargs).
        """

        tracker_name = statistic_trackers.NumberOfParameters.__name__

        return {tracker_name: None}
|
259
|
+
|
260
|
+
|
261
|
+
class NumberOfLayers(GlobalObserver):
    """
    Global observer reporting how many layers the model has, optionally counting trainable layers only and
    optionally log-transformed as ln(1 + N).
    """

    def __init__(
        self,
        *,
        use_log_transform: bool = True,
        trainable_only: bool = False,
        **kwargs,
    ) -> None:
        """
        :param use_log_transform: Whether to transform the return of the Observer by ln(1 + N).
        :param trainable_only: Whether to only count trainable layers; forwarded to the NumberOfLayers tracker.
        :param kwargs: Miscellaneous keyword arguments, passed to the base observer.
        """

        super().__init__(**kwargs)

        self.use_log_transform = use_log_transform
        self.trainable_only = trainable_only

    @property
    def can_standardize(self) -> bool:
        """
        :return: False; the layer count is never standardized.
        """

        return False

    def _get_observation_format(self) -> StatisticStorageTypes:
        """
        :return: StatisticStorageTypes.FLOAT, as a single scalar count is emitted.
        """

        return StatisticStorageTypes.FLOAT

    def _observe(
        self,
        observation_inputs: ObservationInputs,
        hyperparameter_states: HyperparameterStates,
        tracked_statistics: dict[str, dict[str, float | TensorStatistics]],
        action_taken: float | int | None,
    ) -> float | int | list[int | float] | TensorStatistics:
        """
        Reads the layer count from the NumberOfLayers tracker.

        :param observation_inputs: Observation input metrics not calculated with statistic trackers (not used here).
        :param hyperparameter_states: HyperparameterStates that manages the hyperparameters (not used here).
        :param tracked_statistics: Dictionary mapping statistic tracker class names to dictionaries mapping module
        names to floats or TensorStatistic models.
        :param action_taken: Action taken by the agent this class instance is assigned to (not used here).
        :return: The first value reported by the tracker, or ln(1 + value) when use_log_transform is set.
        """

        layer_counts = tracked_statistics[statistic_trackers.NumberOfLayers.__name__]
        count = list(layer_counts.values())[0]

        if not self.use_log_transform:
            return count

        return math.log(1 + count)  # type: ignore

    def get_required_trackers(self) -> dict[str, dict[str, Any] | None]:
        """
        :return: Mapping from the NumberOfLayers tracker class name to its kwargs (the trainable_only flag).
        """

        tracker_kwargs: dict[str, Any] = {"trainable_only": self.trainable_only}

        return {statistic_trackers.NumberOfLayers.__name__: tracker_kwargs}
|
@@ -0,0 +1,142 @@
|
|
1
|
+
# ======================================================================================================================
|
2
|
+
#
|
3
|
+
# imports
|
4
|
+
#
|
5
|
+
# ======================================================================================================================
|
6
|
+
|
7
|
+
from typing import Any
|
8
|
+
|
9
|
+
from libinephany.observations.observation_utils import StatisticStorageTypes
|
10
|
+
from libinephany.observations.observers.base_observers import GlobalObserver
|
11
|
+
from libinephany.observations.observers.global_observers.base_classes import LHOPTCheckpointBaseObserver
|
12
|
+
from libinephany.pydantic_models.schemas.observation_models import ObservationInputs
|
13
|
+
from libinephany.pydantic_models.schemas.tensor_statistics import TensorStatistics
|
14
|
+
from libinephany.pydantic_models.states.hyperparameter_states import HyperparameterStates
|
15
|
+
|
16
|
+
|
17
|
+
class TrainingProgress(GlobalObserver):
    """
    Global observer that passes through the training progress value supplied in the observation inputs.
    """

    @property
    def can_standardize(self) -> bool:
        """
        :return: False; this observation is never standardized.
        """

        return False

    def _get_observation_format(self) -> StatisticStorageTypes:
        """
        :return: StatisticStorageTypes.FLOAT, since a single scalar is emitted.
        """

        return StatisticStorageTypes.FLOAT

    def _observe(
        self,
        observation_inputs: ObservationInputs,
        hyperparameter_states: HyperparameterStates,
        tracked_statistics: dict[str, dict[str, float | TensorStatistics]],
        action_taken: float | int | None,
    ) -> float | int | list[int | float] | TensorStatistics:
        """
        :param observation_inputs: Observation input metrics not calculated with statistic trackers; supplies
        training_progress.
        :param hyperparameter_states: HyperparameterStates that manages the hyperparameters (not used here).
        :param tracked_statistics: Dictionary mapping statistic tracker class names to dictionaries mapping module
        names to floats or TensorStatistic models (not used here).
        :param action_taken: Action taken by the agent this class instance is assigned to (not used here).
        :return: observation_inputs.training_progress, unchanged.
        """

        progress = observation_inputs.training_progress

        return progress

    def get_required_trackers(self) -> dict[str, dict[str, Any] | None]:
        """
        :return: An empty mapping; no statistic trackers are needed.
        """

        no_trackers: dict[str, dict[str, Any] | None] = {}

        return no_trackers
|
60
|
+
|
61
|
+
|
62
|
+
class EpochsCompleted(GlobalObserver):
    """
    Global observer that reports the number of completed epochs from the observation inputs.
    """

    @property
    def can_standardize(self) -> bool:
        """
        :return: False; the epoch count is never standardized.
        """

        return False

    def _get_observation_format(self) -> StatisticStorageTypes:
        """
        :return: StatisticStorageTypes.FLOAT, since a single scalar is emitted.
        """

        return StatisticStorageTypes.FLOAT

    def _observe(
        self,
        observation_inputs: ObservationInputs,
        hyperparameter_states: HyperparameterStates,
        tracked_statistics: dict[str, dict[str, float | TensorStatistics]],
        action_taken: float | int | None,
    ) -> float | int | list[int | float] | TensorStatistics:
        """
        :param observation_inputs: Observation input metrics not calculated with statistic trackers; supplies
        epochs_completed.
        :param hyperparameter_states: HyperparameterStates that manages the hyperparameters (not used here).
        :param tracked_statistics: Dictionary mapping statistic tracker class names to dictionaries mapping module
        names to floats or TensorStatistic models (not used here).
        :param action_taken: Action taken by the agent this class instance is assigned to (not used here).
        :return: observation_inputs.epochs_completed, unchanged.
        """

        completed = observation_inputs.epochs_completed

        return completed

    def get_required_trackers(self) -> dict[str, dict[str, Any] | None]:
        """
        :return: An empty mapping; no statistic trackers are needed.
        """

        return dict()
|
105
|
+
|
106
|
+
|
107
|
+
class ProgressAtEachCheckpoint(LHOPTCheckpointBaseObserver):
    """
    This is a global observer from the paper "Learning to Optimize with Reinforcement Learning"
    https://arxiv.org/abs/2305.18291.

    It returns a single float: the training progress recorded at the most recent checkpoint. The value is only
    refreshed when a checkpoint is reached and is held constant between checkpoints.
    """

    def _observe(
        self,
        observation_inputs: ObservationInputs,
        hyperparameter_states: HyperparameterStates,
        tracked_statistics: dict[str, dict[str, float | TensorStatistics]],
        action_taken: float | int | None,
    ) -> float:
        """
        Returns the training progress sampled at the most recent checkpoint.

        At a checkpoint (as decided by the base class's _should_create_checkpoint): stores the current training
        progress in self.last_value and returns it.
        Between checkpoints: returns self.last_value unchanged, i.e. the progress stored at the previous
        checkpoint. Before the first checkpoint this is whatever the base class initialised last_value to
        (presumably via _cold_start; confirm in base_classes.py).

        :param observation_inputs: Observation input metrics not calculated with statistic trackers; supplies
        training_progress.
        :param hyperparameter_states: HyperparameterStates that manages the hyperparameters (not used here).
        :param tracked_statistics: Dictionary mapping statistic tracker class names to dictionaries mapping module
        names to floats or TensorStatistic models (not used here).
        :param action_taken: Action taken by the agent this class instance is assigned to (not used here).
        :return: Training progress at the most recent checkpoint.
        """

        current_progress = observation_inputs.training_progress

        # Cold start: If the last progress is not set, set it to the first progress record
        self._cold_start(current_progress)

        self._update_history(current_progress)

        if not self._should_create_checkpoint():
            # Between checkpoints the observation is held at the last checkpointed value.
            return self.last_value

        # Checkpoint reached: record the current progress and emit it.
        self.last_value = current_progress

        return current_progress
|
@@ -84,7 +84,9 @@ class ObserverContainer(ABC):
|
|
84
84
|
return sum(observer.observation_size for observer in self._observers)
|
85
85
|
|
86
86
|
@final
|
87
|
-
def _build_observers(
|
87
|
+
def _build_observers(
|
88
|
+
self,
|
89
|
+
) -> list[LocalObserver | GlobalObserver]:
|
88
90
|
"""
|
89
91
|
:return: List of instantiated observers.
|
90
92
|
"""
|
libinephany/utils/constants.py
CHANGED
libinephany/utils/enums.py
CHANGED
@@ -78,6 +78,24 @@ class AgentTypes(EnumWithIndices):
|
|
78
78
|
Tokens = TOKENS
|
79
79
|
Samples = SAMPLES
|
80
80
|
|
81
|
+
@classmethod
def get_possible_active_agents(cls) -> list["AgentTypes"]:
    """
    :return: The agent types that may be active, in this fixed order: learning rate, weight decay, dropout,
    gradient clipping, Adam beta one, Adam beta two, Adam epsilon, SGD momentum, gradient accumulation.
    A new list is built on every call.
    """

    active_agents = (
        cls.LearningRateAgent,
        cls.WeightDecayAgent,
        cls.DropoutAgent,
        cls.GradientClippingAgent,
        cls.AdamBetaOneAgent,
        cls.AdamBetaTwoAgent,
        cls.AdamEpsAgent,
        cls.SGDMomentumAgent,
        cls.GradientAccumulationAgent,
    )

    return list(active_agents)
|
98
|
+
|
81
99
|
|
82
100
|
class ModelFamilies(EnumWithIndices):
|
83
101
|
|
libinephany/utils/samplers.py
CHANGED
@@ -68,6 +68,13 @@ class Sampler:
|
|
68
68
|
|
69
69
|
raise NotImplementedError
|
70
70
|
|
71
|
+
@classmethod
def get_subclasses(cls):
    """
    Lazily yields every transitive subclass of this class.

    Traversal is depth-first and post-order: each direct subclass is yielded after all of its own descendants,
    and direct subclasses are visited in the order returned by ``__subclasses__()``. A class reachable through
    several parents (multiple inheritance) may be yielded more than once.
    """

    for child in cls.__subclasses__():
        for descendant in child.get_subclasses():
            yield descendant
        yield child
|
77
|
+
|
71
78
|
|
72
79
|
class LogUniformSampler(Sampler):
|
73
80
|
|
@@ -228,6 +235,34 @@ class DiscreteValueSampler(Sampler):
|
|
228
235
|
).astype(self.sample_dtype)
|
229
236
|
|
230
237
|
|
238
|
+
class DiscreteValueListSampler(DiscreteValueSampler):
    """
    Sampler that produces fixed-length lists of values drawn from a discrete set, delegating the sampling of
    each list to DiscreteValueSampler.sample.
    """

    def __init__(
        self,
        length: int,
        discrete_values: list[float | int | str],
        sample_dtype: type[np.generic | float | int | str] = np.float64,
        **kwargs,
    ) -> None:
        """
        :param length: Number of values in each sampled list.
        :param discrete_values: List of discrete values to sample from.
        :param sample_dtype: Dtype forwarded to DiscreteValueSampler for the sampled values.
        :param kwargs: Miscellaneous keyword arguments. Accepted for interface compatibility (e.g. the bounds
        passed by build_sampler) but not forwarded to DiscreteValueSampler.
        """

        super().__init__(discrete_values=discrete_values, sample_dtype=sample_dtype)
        self.list_length = length

    def sample(self, number_of_samples: int = 1, **kwargs) -> list[np.ndarray | list[Any]]:
        """
        :param number_of_samples: Number of lists to sample.
        :param kwargs: Miscellaneous keyword arguments (ignored).
        :return: List with number_of_samples entries, each the result of DiscreteValueSampler.sample called with
        number_of_samples=self.list_length.
        """

        # Bind the parent method here, in the method's own frame. A zero-argument super() inside the
        # comprehension below raises TypeError before Python 3.12: the comprehension runs in its own frame,
        # whose first argument is the iterator, not self (fixed by PEP 709 inlining).
        parent_sample = super().sample

        return [parent_sample(number_of_samples=self.list_length) for _ in range(number_of_samples)]
|
264
|
+
|
265
|
+
|
231
266
|
class RoundRobinDiscreteValueSampler(Sampler):
|
232
267
|
|
233
268
|
def __init__(
|
@@ -287,7 +322,7 @@ def build_sampler(sampler_name: str, lower_bound: float | int, upper_bound: floa
|
|
287
322
|
:return: Constructed sampler.
|
288
323
|
"""
|
289
324
|
|
290
|
-
possible_samplers = {sampler_type.__name__: sampler_type for sampler_type in Sampler.
|
325
|
+
possible_samplers = {sampler_type.__name__: sampler_type for sampler_type in Sampler.get_subclasses()}
|
291
326
|
|
292
327
|
try:
|
293
328
|
return possible_samplers[sampler_name](lower_bound=lower_bound, upper_bound=upper_bound, **kwargs) # type: ignore
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: libinephany
|
3
|
-
Version: 0.16.2
|
3
|
+
Version: 0.16.4
|
4
4
|
Summary: Inephany library containing code commonly used by multiple subpackages.
|
5
5
|
Author-email: Inephany <info@inephany.com>
|
6
6
|
License: Apache 2.0
|
@@ -18,6 +18,7 @@ Requires-Dist: pydantic<3.0.0,>=2.5.0
|
|
18
18
|
Requires-Dist: loguru<0.8.0,>=0.7.0
|
19
19
|
Requires-Dist: requests<3.0.0,>=2.28.0
|
20
20
|
Requires-Dist: numpy<2.0.0,>=1.24.0
|
21
|
+
Requires-Dist: scipy<2.0.0,>=1.10.0
|
21
22
|
Requires-Dist: slack-sdk<4.0.0,>=3.20.0
|
22
23
|
Requires-Dist: boto3<2.0.0,>=1.26.0
|
23
24
|
Requires-Dist: fastapi<0.116.0,>=0.100.0
|
@@ -2,16 +2,23 @@ libinephany/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
2
2
|
libinephany/aws/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
3
3
|
libinephany/aws/s3_functions.py,sha256=W8u85A6tDloo4FlJvydJbVHCUq_m9i8KDGdnKzy-Xpg,1745
|
4
4
|
libinephany/observations/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
5
|
-
libinephany/observations/observation_utils.py,sha256=
|
5
|
+
libinephany/observations/observation_utils.py,sha256=CXep6CVvzgh8I_SSmZvyupQLiwDzKDEiVLYqidhnL-A,16264
|
6
6
|
libinephany/observations/observer_pipeline.py,sha256=RvMH-TTDTu1Nk4S_KSHDkII1YuIRMSOXkPhn6g4B9ow,12815
|
7
7
|
libinephany/observations/pipeline_coordinator.py,sha256=mw3c5jy_BWvNigUKNjIWMpReOjxFDblzOcWtsIkcls4,7907
|
8
8
|
libinephany/observations/statistic_manager.py,sha256=LLg1zSxnJr2oQQepYla3qoUuRy10rsthr9jta4wEbnc,8956
|
9
9
|
libinephany/observations/statistic_trackers.py,sha256=PUBqGgMRi51SmiNh5HAH5kpYxsaflRepmM-uKyMiQZg,30326
|
10
10
|
libinephany/observations/observers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
11
11
|
libinephany/observations/observers/base_observers.py,sha256=RkG5SW0b6Ooy0_oscRHxyB_YFNP7k8fxu37jBZElxIM,15418
|
12
|
-
libinephany/observations/observers/global_observers.py,sha256=3TaiV2AxMOXfDq-kXMU3ZSo-rQENNCFhdWCJtpY99ok,38684
|
13
12
|
libinephany/observations/observers/local_observers.py,sha256=EdivrylOcmxRsu4xiMwZqwmPX8Ru9-IRwoPk6En7qvw,37050
|
14
|
-
libinephany/observations/observers/observer_containers.py,sha256=
|
13
|
+
libinephany/observations/observers/observer_containers.py,sha256=Inz4vkqaXsyj09EJYaWWx1eOIuzmOjbiul5wRL5Gm5Y,9274
|
14
|
+
libinephany/observations/observers/global_observers/__init__.py,sha256=hhJIOCGUKR7ai3pkHm9p42ZnyTWGOayg0jyLo7yQsZ0,1897
|
15
|
+
libinephany/observations/observers/global_observers/base_classes.py,sha256=WFaadwQHqN1G5-repyzvG3gq1Yii-NmjXmEN5g1eDK8,7184
|
16
|
+
libinephany/observations/observers/global_observers/constants.py,sha256=3Mc5rXtg9kBLjRn-uOjp-shgUNZAbd_K1S1ut-2DoBc,844
|
17
|
+
libinephany/observations/observers/global_observers/gradient_observers.py,sha256=KzaFR0jCPnJwVr1Gr7vK5tDlou0dU7iZ0qTvY19xzqg,5000
|
18
|
+
libinephany/observations/observers/global_observers/hyperparameter_observers.py,sha256=G3ic1Y1A-UGvXiXPeLrxn2efZrvOe4QY4DhIDbaOV-0,10483
|
19
|
+
libinephany/observations/observers/global_observers/loss_observers.py,sha256=q0fn_WyeqrB28n231dZMqoDp3x6WUPcvvq9gWKJS0wM,18544
|
20
|
+
libinephany/observations/observers/global_observers/model_observers.py,sha256=Fa0jQEBtiKmq0jk5x91ha_k-tURKHJo_wRhOhuEtxVk,13337
|
21
|
+
libinephany/observations/observers/global_observers/progress_observers.py,sha256=m62jUiwPaOUzYG1h7Vg6znj_jK9699lDhg4AhK212s8,5615
|
15
22
|
libinephany/observations/post_processors/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
16
23
|
libinephany/observations/post_processors/postprocessors.py,sha256=43_e5UaDPr2KbAvqc_w3wLqnlm7bgRjqgCtyQ95-8cM,5913
|
17
24
|
libinephany/pydantic_models/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
@@ -32,16 +39,16 @@ libinephany/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU
|
|
32
39
|
libinephany/utils/agent_utils.py,sha256=_2w1AY5Y4mQ5hes_Rq014VhZXOtIOn-W92mZgeixv3g,2658
|
33
40
|
libinephany/utils/asyncio_worker.py,sha256=Ew23zKIbG1zwyCudcyiObMrw4G0f3p2QXzZfM4mePqI,2751
|
34
41
|
libinephany/utils/backend_statuses.py,sha256=ZbpBPbz0qKmeqxyGGN_ePTrQ7Wrxh7KM6W26UDbPXtQ,644
|
35
|
-
libinephany/utils/constants.py,sha256=
|
42
|
+
libinephany/utils/constants.py,sha256=Qh8iz5o1R4UDVVCB69jOQPX2SLWRCncpb_2yTHpFSbY,2259
|
36
43
|
libinephany/utils/directory_utils.py,sha256=408unVeE_5_Hm-ZYZuxc9sdvfuU0CgYELX7EzPlPieo,1217
|
37
44
|
libinephany/utils/dropout_utils.py,sha256=X43yCW7Dh1cC5sNnivgS5j1fn871K_RCvxCBTT0YHKg,3392
|
38
|
-
libinephany/utils/enums.py,sha256=
|
45
|
+
libinephany/utils/enums.py,sha256=6_6k_1I2BwYTIfquUOsoaQT5fkhMXUWtwCxLoTYuFyU,2906
|
39
46
|
libinephany/utils/error_severities.py,sha256=B9oidqOVaYOe0W6P6GwjpmuDsrkyTX30v1xdiUStCFk,1427
|
40
47
|
libinephany/utils/exceptions.py,sha256=kgwLpHOgy3kciUz_I18xnYsWRtzdonfadUtwG2uDYk8,1823
|
41
48
|
libinephany/utils/import_utils.py,sha256=WzC6V6UIa0nCiU2MekROwG82fWBh9RuVzichtby5EvM,1495
|
42
49
|
libinephany/utils/optim_utils.py,sha256=-PLqsyuq4ZH3spBy_olNB3yuLwvhnLrCF0384elCmXc,8777
|
43
50
|
libinephany/utils/random_seeds.py,sha256=eF-ErrMShu8mp9V_gXrB_iUxR-Lb-OtHypEEUQAGn2Y,1565
|
44
|
-
libinephany/utils/samplers.py,sha256=
|
51
|
+
libinephany/utils/samplers.py,sha256=7h_el2dLJi2J97f_zpvc4BrEzoM_EJgZk1-ZjRkOhZ8,13357
|
45
52
|
libinephany/utils/standardizers.py,sha256=pG1K_XL4OR_NjVtT6Hjbln1dk1BtQdDuSK1PQTkA17Y,8014
|
46
53
|
libinephany/utils/torch_distributed_utils.py,sha256=UPMfhdZZwyHX_r3h55AAK4PcB-zFtjK37Z5aawAKNmE,2968
|
47
54
|
libinephany/utils/torch_utils.py,sha256=o5TsqrXe6Id04P6SqB_avGBRZutbu6IBB61llAHQ_PY,2696
|
@@ -50,8 +57,8 @@ libinephany/utils/typing.py,sha256=rGbaPO3MaUndsWiC_wHzReD_TOLYqb43i01pKN-j7Xs,6
|
|
50
57
|
libinephany/web_apps/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
51
58
|
libinephany/web_apps/error_logger.py,sha256=gAQIaqerqP4ornXZwFF1cghjnd2mMZEt3aVrTuUCr34,16653
|
52
59
|
libinephany/web_apps/web_app_utils.py,sha256=qiq_lasPipgN1RgRudPJc342kYci8O_4RqppxmIX8NY,4095
|
53
|
-
libinephany-0.16.
|
54
|
-
libinephany-0.16.
|
55
|
-
libinephany-0.16.
|
56
|
-
libinephany-0.16.
|
57
|
-
libinephany-0.16.
|
60
|
+
libinephany-0.16.4.dist-info/licenses/LICENSE,sha256=pogfDoMBP07ehIOvWymuWIar8pg2YLUhqOHsJQU3wdc,9250
|
61
|
+
libinephany-0.16.4.dist-info/METADATA,sha256=HsVgYNptrATUp_vFZziVcK60xJM9whyMbbf0ZZPnj4s,8390
|
62
|
+
libinephany-0.16.4.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
63
|
+
libinephany-0.16.4.dist-info/top_level.txt,sha256=bYAOXQdJgIoLkO2Ui0kxe7pSYegS_e38u0dMscd7COQ,12
|
64
|
+
libinephany-0.16.4.dist-info/RECORD,,
|