libinephany 0.16.3__py3-none-any.whl → 0.16.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- libinephany/observations/observation_utils.py +2 -14
- libinephany/observations/observers/global_observers/__init__.py +58 -0
- libinephany/observations/observers/global_observers/base_classes.py +183 -0
- libinephany/observations/observers/global_observers/constants.py +33 -0
- libinephany/observations/observers/global_observers/gradient_observers.py +112 -0
- libinephany/observations/observers/global_observers/hyperparameter_observers.py +286 -0
- libinephany/observations/observers/global_observers/loss_observers.py +464 -0
- libinephany/observations/observers/global_observers/model_observers.py +327 -0
- libinephany/observations/observers/global_observers/progress_observers.py +142 -0
- libinephany/observations/observers/observer_containers.py +3 -1
- {libinephany-0.16.3.dist-info → libinephany-0.16.4.dist-info}/METADATA +1 -1
- {libinephany-0.16.3.dist-info → libinephany-0.16.4.dist-info}/RECORD +15 -8
- libinephany/observations/observers/global_observers.py +0 -991
- {libinephany-0.16.3.dist-info → libinephany-0.16.4.dist-info}/WHEEL +0 -0
- {libinephany-0.16.3.dist-info → libinephany-0.16.4.dist-info}/licenses/LICENSE +0 -0
- {libinephany-0.16.3.dist-info → libinephany-0.16.4.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,286 @@
|
|
1
|
+
# ======================================================================================================================
|
2
|
+
#
|
3
|
+
# HYPERPARAMETER OBSERVERS
|
4
|
+
#
|
5
|
+
# ======================================================================================================================
|
6
|
+
|
7
|
+
import random
|
8
|
+
from typing import Any
|
9
|
+
|
10
|
+
from torch.optim import SGD, Adam, AdamW
|
11
|
+
|
12
|
+
from libinephany.observations import observation_utils
|
13
|
+
from libinephany.observations.observation_utils import StatisticStorageTypes
|
14
|
+
from libinephany.observations.observers.base_observers import GlobalObserver
|
15
|
+
from libinephany.pydantic_models.schemas.observation_models import ObservationInputs
|
16
|
+
from libinephany.pydantic_models.schemas.tensor_statistics import TensorStatistics
|
17
|
+
from libinephany.pydantic_models.states.hyperparameter_states import HyperparameterStates
|
18
|
+
from libinephany.utils.enums import ModelFamilies
|
19
|
+
|
20
|
+
|
21
|
+
class InitialHyperparameters(GlobalObserver):
    """
    Observer returning the initial internal values of the hyperparameters managed by HyperparameterStates as a
    vector. Hyperparameters listed in ``skip_hparams`` (plus the always-skipped ``samples`` and
    ``gradient_accumulation``) are excluded, and any initial value that is None is replaced by ``pad_with``.
    """

    def __init__(self, skip_hparams: list[str] | None = None, pad_with: float = 0.0, **kwargs) -> None:
        """
        :param skip_hparams: Names of the hyperparameters to not include in the initial values vector returned by
        this observation. ``samples`` and ``gradient_accumulation`` are always skipped in addition to these.
        :param pad_with: Value substituted for any hyperparameter whose initial internal value is None.
        :param kwargs: Miscellaneous keyword arguments.
        """

        super().__init__(**kwargs)

        force_skip = ["samples", "gradient_accumulation"]
        # Both branches produce a fresh list, so the caller's ``skip_hparams`` is never mutated, and the result can
        # never be None (the previous extra None-check here was dead code).
        self.skip_hparams: list[str] = force_skip if skip_hparams is None else skip_hparams + force_skip
        self.pad_with = pad_with

    @property
    def vector_length(self) -> int:
        """
        :return: Length of the vector returned by this observation if it returns a vector.
        """

        available_hparams = HyperparameterStates.get_all_hyperparameters()

        # NOTE(review): a hyperparameter is excluded here if any skip entry is a *substring* of its name, whereas
        # ``_observe`` filters by exact name. If the two rules ever disagree, the reported length will not match
        # the observation's length. Confirm against HyperparameterStates.get_initial_internal_values.
        return len(
            [hparam for hparam in available_hparams if not any(skipped in hparam for skipped in self.skip_hparams)]
        )

    @property
    def can_standardize(self) -> bool:
        """
        :return: Whether the observation can be standardized. Always False for this observer.
        """

        return False

    @property
    def can_inform(self) -> bool:
        """
        :return: Whether observations from the observer can be used in the agent info dictionary. Always False.
        """

        return False

    def _get_observation_format(self) -> StatisticStorageTypes:
        """
        :return: Format the observation returns data in. Must be one of the enum attributes in the StatisticStorageTypes
        enumeration class.
        """

        return StatisticStorageTypes.VECTOR

    def _observe(
        self,
        observation_inputs: ObservationInputs,
        hyperparameter_states: HyperparameterStates,
        tracked_statistics: dict[str, dict[str, float | TensorStatistics]],
        action_taken: float | int | None,
    ) -> float | int | list[int | float] | TensorStatistics:
        """
        :param observation_inputs: Observation input metrics not calculated with statistic trackers.
        :param hyperparameter_states: HyperparameterStates that manages the hyperparameters.
        :param tracked_statistics: Dictionary mapping statistic tracker class names to dictionaries mapping module
        names to floats or TensorStatistic models.
        :param action_taken: Action taken by the agent this class instance is assigned to.
        :return: List of initial internal hyperparameter values, in the iteration order of the mapping returned by
        ``get_initial_internal_values``, with None values replaced by ``pad_with``.
        """

        initial_internal_values = hyperparameter_states.get_initial_internal_values(self.skip_hparams)
        # The raw mapping (before padding and filtering) is what gets cached, not the returned list.
        self._cached_observation = initial_internal_values

        # Substitute ``pad_with`` for missing values and drop any skipped names that were still returned.
        return [
            self.pad_with if initial_internal_value is None else initial_internal_value
            for hparam_name, initial_internal_value in initial_internal_values.items()
            if hparam_name not in self.skip_hparams
        ]

    def get_required_trackers(self) -> dict[str, dict[str, Any] | None]:
        """
        :return: Dictionary mapping statistic tracker class names to kwargs for the class or None if no kwargs are
        needed. This observer needs no trackers, so the dictionary is empty.
        """

        return {}
|
105
|
+
|
106
|
+
|
107
|
+
class OptimizerTypeOneHot(GlobalObserver):
    """
    Observer that one-hot encodes ``observer_config.optimizer_name`` over the optimizers listed in ``OPTIMS``.
    An optimizer name not in ``OPTIMS`` is passed to the one-hot helper with an index of None.
    """

    # Supported optimizer class names; a name's position in this list is its one-hot index.
    OPTIMS = [Adam.__name__, AdamW.__name__, SGD.__name__]

    @property
    def vector_length(self) -> int:
        """
        :return: Number of supported optimizers, i.e. the length of the one-hot vector.
        """

        return len(self.OPTIMS)

    @property
    def can_inform(self) -> bool:
        """
        :return: Always False; observations from this observer are not used in the agent info dictionary.
        """

        return False

    def _get_observation_format(self) -> StatisticStorageTypes:
        """
        :return: The storage format of this observation, which is always StatisticStorageTypes.VECTOR.
        """

        return StatisticStorageTypes.VECTOR

    def _observe(
        self,
        observation_inputs: ObservationInputs,
        hyperparameter_states: HyperparameterStates,
        tracked_statistics: dict[str, dict[str, float | TensorStatistics]],
        action_taken: float | int | None,
    ) -> float | int | list[int | float] | TensorStatistics:
        """
        :param observation_inputs: Unused by this observer.
        :param hyperparameter_states: Unused by this observer.
        :param tracked_statistics: Unused by this observer.
        :param action_taken: Unused by this observer.
        :return: One-hot vector built by observation_utils.create_one_hot_observation. For an unknown optimizer the
        index passed is None (presumably yielding an all-zero vector; confirm in observation_utils).
        """

        name = self.observer_config.optimizer_name
        hot_position = self.OPTIMS.index(name) if name in self.OPTIMS else None

        return observation_utils.create_one_hot_observation(
            vector_length=self.vector_length,
            one_hot_index=hot_position,
        )

    def get_required_trackers(self) -> dict[str, dict[str, Any] | None]:
        """
        :return: An empty dictionary, since this observer relies on no statistic trackers.
        """

        return {}
|
168
|
+
|
169
|
+
|
170
|
+
class ModelFamilyOneHot(GlobalObserver):
    """
    Observer that one-hot encodes ``observer_config.nn_family_name`` over the ModelFamilies enum.

    While in training mode, the vector may instead be returned as all zeros. The chance of this is
    ``zero_vector_chance``, and the decision is resampled either on every observation (``"timestep"``) or only when
    ``reset`` is called (``"episode"``). Outside training mode the family vector is always returned.
    """

    UNIT_EPISODE = "episode"
    UNIT_TIMESTEP = "timestep"

    def __init__(
        self,
        *,
        zero_vector_chance: float = 0.2,
        zero_vector_frequency_unit: str = "episode",
        **kwargs,
    ) -> None:
        """
        :param zero_vector_chance: Probability in [0.0, 1.0) that the observation is replaced by an all-zero vector
        while in training mode.
        :param zero_vector_frequency_unit: How often the zeroing decision is resampled. ``"episode"`` resamples only
        when ``reset`` is called; ``"timestep"`` resamples on every observation made in training mode.
        :param kwargs: Miscellaneous keyword arguments.
        :raises ValueError: If zero_vector_chance is outside [0.0, 1.0) or zero_vector_frequency_unit is unknown.
        """
        super().__init__(**kwargs)
        self.should_zero = False

        # Raise instead of assert so the check is not stripped when Python runs with -O.
        if not 0.0 <= zero_vector_chance < 1.0:
            raise ValueError(f"zero_vector_chance must be in [0.0, 1.0), got {zero_vector_chance}")

        self.zero_vector_chance = zero_vector_chance
        self._sample_zero_vector()

        if zero_vector_frequency_unit not in [self.UNIT_EPISODE, self.UNIT_TIMESTEP]:
            raise ValueError(f"Unknown zero_vector_frequency_unit: {zero_vector_frequency_unit}")

        self.zero_vector_frequency_unit = zero_vector_frequency_unit
        self.family_vector = self._create_family_vector()

    @property
    def vector_length(self) -> int:
        """
        :return: Length of the vector returned by this observation, i.e. the number of ModelFamilies members.
        """

        return len(ModelFamilies)

    @property
    def can_inform(self) -> bool:
        """
        :return: Whether observations from the observer can be used in the agent info dictionary. Always False.
        """

        return False

    def _get_observation_format(self) -> StatisticStorageTypes:
        """
        :return: Format the observation returns data in. Must be one of the enum attributes in the StatisticStorageTypes
        enumeration class.
        """

        return StatisticStorageTypes.VECTOR

    def _create_family_vector(self) -> list[float]:
        """
        :return: Creates and returns the model family one-hot vector. A family name that is not a ModelFamilies value
        is passed to the one-hot helper with an index of None.
        """

        family_name = self.observer_config.nn_family_name
        known_name = family_name in (family.value for family in ModelFamilies)

        if known_name:
            family_idx = ModelFamilies.get_index(family_name)

        else:
            family_idx = None

        return observation_utils.create_one_hot_observation(vector_length=self.vector_length, one_hot_index=family_idx)

    def _observe(
        self,
        observation_inputs: ObservationInputs,
        hyperparameter_states: HyperparameterStates,
        tracked_statistics: dict[str, dict[str, float | TensorStatistics]],
        action_taken: float | int | None,
    ) -> float | int | list[int | float] | TensorStatistics:
        """
        :param observation_inputs: Observation input metrics not calculated with statistic trackers.
        :param hyperparameter_states: HyperparameterStates that manages the hyperparameters.
        :param tracked_statistics: Dictionary mapping statistic tracker class names to dictionaries mapping module
        names to floats or TensorStatistic models.
        :param action_taken: Action taken by the agent this class instance is assigned to.
        :return: The model family one-hot vector, or an all-zero vector of the same length when zeroing applies.
        """

        # Zero-masking is a training-time augmentation only.
        if not self.in_training_mode:
            return self.family_vector

        if self.zero_vector_frequency_unit == self.UNIT_TIMESTEP:
            self._sample_zero_vector()

        if self.should_zero:
            return [0.0] * self.vector_length

        else:
            return self.family_vector

    def _sample_zero_vector(self) -> None:
        """
        Determines whether the output vector of this observer should be masked with zeros, setting ``should_zero`` to
        True with probability ``zero_vector_chance``.
        """
        self.should_zero = random.choices([True, False], [self.zero_vector_chance, (1 - self.zero_vector_chance)])[0]

    def get_required_trackers(self) -> dict[str, dict[str, Any] | None]:
        """
        :return: Dictionary mapping statistic tracker class names to kwargs for the class or None if no kwargs are
        needed. This observer needs no trackers, so the dictionary is empty.
        """

        return {}

    def reset(self) -> None:
        """
        Resets the observer by resampling whether the output should be zero-masked.
        """

        # NOTE(review): this override does not call super().reset(). If GlobalObserver.reset clears shared state,
        # that clearing is skipped here. Confirm against the base class.
        self._sample_zero_vector()
|