libinephany 0.16.2__py3-none-any.whl → 0.16.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,286 @@
1
+ # ======================================================================================================================
2
+ #
3
+ # HYPERPARAMETER OBSERVERS
4
+ #
5
+ # ======================================================================================================================
6
+
7
+ import random
8
+ from typing import Any
9
+
10
+ from torch.optim import SGD, Adam, AdamW
11
+
12
+ from libinephany.observations import observation_utils
13
+ from libinephany.observations.observation_utils import StatisticStorageTypes
14
+ from libinephany.observations.observers.base_observers import GlobalObserver
15
+ from libinephany.pydantic_models.schemas.observation_models import ObservationInputs
16
+ from libinephany.pydantic_models.schemas.tensor_statistics import TensorStatistics
17
+ from libinephany.pydantic_models.states.hyperparameter_states import HyperparameterStates
18
+ from libinephany.utils.enums import ModelFamilies
19
+
20
+
21
class InitialHyperparameters(GlobalObserver):
    """
    Global observer whose observation is the vector of initial internal hyperparameter values, excluding any
    hyperparameters that have been marked as skipped.
    """

    def __init__(self, skip_hparams: list[str] | None = None, pad_with: float = 0.0, **kwargs) -> None:
        """
        :param skip_hparams: Names of the hyperparameters to not include in the initial values vector returned by
        this observation. "samples" and "gradient_accumulation" are always skipped in addition to these.
        :param pad_with: Value placed in the vector for any hyperparameter whose initial internal value is None.
        :param kwargs: Miscellaneous keyword arguments.
        """

        super().__init__(**kwargs)

        force_skip = ["samples", "gradient_accumulation"]

        # The forced entries are always merged in, so the result can never be None and needs no further fallback.
        self.skip_hparams = force_skip if skip_hparams is None else skip_hparams + force_skip
        self.pad_with = pad_with

    @property
    def vector_length(self) -> int:
        """
        :return: Length of the vector returned by this observation if it returns a vector.
        """

        available_hparams = HyperparameterStates.get_all_hyperparameters()

        # NOTE(review): a hyperparameter is excluded here when any skipped name is a substring of it, whereas
        # _observe filters on exact name equality - confirm the two always agree on the resulting length.
        return sum(
            1 for hparam in available_hparams if not any(skipped in hparam for skipped in self.skip_hparams)
        )

    @property
    def can_standardize(self) -> bool:
        """
        :return: Whether the observation can be standardized.
        """

        return False

    @property
    def can_inform(self) -> bool:
        """
        :return: Whether observations from the observer can be used in the agent info dictionary.
        """

        return False

    def _get_observation_format(self) -> StatisticStorageTypes:
        """
        :return: Format the observation returns data in. Must be one of the enum attributes in the StatisticStorageTypes
        enumeration class.
        """

        return StatisticStorageTypes.VECTOR

    def _observe(
        self,
        observation_inputs: ObservationInputs,
        hyperparameter_states: HyperparameterStates,
        tracked_statistics: dict[str, dict[str, float | TensorStatistics]],
        action_taken: float | int | None,
    ) -> float | int | list[int | float] | TensorStatistics:
        """
        :param observation_inputs: Observation input metrics not calculated with statistic trackers.
        :param hyperparameter_states: HyperparameterStates that manages the hyperparameters.
        :param tracked_statistics: Dictionary mapping statistic tracker class names to dictionaries mapping module
        names to floats or TensorStatistic models.
        :param action_taken: Action taken by the agent this class instance is assigned to.
        :return: List of initial internal hyperparameter values, with None entries replaced by pad_with.
        """

        initial_internal_values = hyperparameter_states.get_initial_internal_values(self.skip_hparams)

        # The raw name -> value mapping is cached, not the padded list that is returned.
        self._cached_observation = initial_internal_values

        return [
            self.pad_with if initial_value is None else initial_value
            for hparam_name, initial_value in initial_internal_values.items()
            if hparam_name not in self.skip_hparams
        ]

    def get_required_trackers(self) -> dict[str, dict[str, Any] | None]:
        """
        :return: Dictionary mapping statistic tracker class names to kwargs for the class or None if no kwargs are
        needed.
        """

        return {}
105
+
106
+
107
class OptimizerTypeOneHot(GlobalObserver):
    """
    Global observer that encodes the configured optimizer name as a one-hot vector over the entries of OPTIMS.
    """

    # Position in this list is the one-hot index used for the corresponding optimizer class name.
    OPTIMS = [Adam.__name__, AdamW.__name__, SGD.__name__]

    @property
    def vector_length(self) -> int:
        """
        :return: Number of entries in the one-hot vector, one per name in OPTIMS.
        """

        return len(self.OPTIMS)

    @property
    def can_inform(self) -> bool:
        """
        :return: Whether observations from the observer can be used in the agent info dictionary. Always False here.
        """

        return False

    def _get_observation_format(self) -> StatisticStorageTypes:
        """
        :return: Storage format of the observation, which for this observer is StatisticStorageTypes.VECTOR.
        """

        return StatisticStorageTypes.VECTOR

    def _observe(
        self,
        observation_inputs: ObservationInputs,
        hyperparameter_states: HyperparameterStates,
        tracked_statistics: dict[str, dict[str, float | TensorStatistics]],
        action_taken: float | int | None,
    ) -> float | int | list[int | float] | TensorStatistics:
        """
        :param observation_inputs: Observation input metrics not calculated with statistic trackers.
        :param hyperparameter_states: HyperparameterStates that manages the hyperparameters.
        :param tracked_statistics: Dictionary mapping statistic tracker class names to dictionaries mapping module
        names to floats or TensorStatistic models.
        :param action_taken: Action taken by the agent this class instance is assigned to.
        :return: Result of observation_utils.create_one_hot_observation for the configured optimizer name.
        """

        configured_name = self.observer_config.optimizer_name

        # Names missing from OPTIMS are passed on as a None index instead of raising.
        hot_position = self.OPTIMS.index(configured_name) if configured_name in self.OPTIMS else None

        return observation_utils.create_one_hot_observation(
            vector_length=self.vector_length, one_hot_index=hot_position
        )

    def get_required_trackers(self) -> dict[str, dict[str, Any] | None]:
        """
        :return: Dictionary mapping statistic tracker class names to kwargs for the class or None if no kwargs are
        needed. Empty for this observer.
        """

        return {}
168
+
169
+
170
class ModelFamilyOneHot(GlobalObserver):
    """
    Global observer that encodes the configured model family as a one-hot vector. While in training mode the
    vector is replaced by all zeros with probability zero_vector_chance.
    """

    # Accepted values for zero_vector_frequency_unit.
    UNIT_EPISODE = "episode"
    UNIT_TIMESTEP = "timestep"

    def __init__(
        self,
        *,
        zero_vector_chance: float = 0.2,
        zero_vector_frequency_unit: str = "episode",
        **kwargs,
    ) -> None:
        """
        :param zero_vector_chance: Probability, in the range [0.0, 1.0), that the zero vector is selected each time
        the masking decision is sampled.
        :param zero_vector_frequency_unit: Either "episode" or "timestep". With "timestep" the masking decision is
        resampled on every training-mode observation; otherwise it is only resampled at construction and on reset.
        :param kwargs: Miscellaneous keyword arguments.
        :raises ValueError: If zero_vector_chance is outside [0.0, 1.0) or zero_vector_frequency_unit is unknown.
        """

        super().__init__(**kwargs)
        self.should_zero = False

        # Raised explicitly rather than asserted so that the check still runs under `python -O`.
        if not (0.0 <= zero_vector_chance < 1.0):
            raise ValueError(f"zero_vector_chance must be in the range [0.0, 1.0), got: {zero_vector_chance}")

        self.zero_vector_chance = zero_vector_chance
        self._sample_zero_vector()

        if zero_vector_frequency_unit not in [self.UNIT_EPISODE, self.UNIT_TIMESTEP]:
            raise ValueError(f"Unknown zero_vector_frequency_unit: {zero_vector_frequency_unit}")

        self.zero_vector_frequency_unit = zero_vector_frequency_unit

        # The family does not change after construction, so the one-hot vector is built once and reused.
        self.family_vector = self._create_family_vector()

    @property
    def vector_length(self) -> int:
        """
        :return: Length of the vector returned by this observation if it returns a vector.
        """

        return len(ModelFamilies)

    @property
    def can_inform(self) -> bool:
        """
        :return: Whether observations from the observer can be used in the agent info dictionary.
        """

        return False

    def _get_observation_format(self) -> StatisticStorageTypes:
        """
        :return: Format the observation returns data in. Must be one of the enum attributes in the StatisticStorageTypes
        enumeration class.
        """

        return StatisticStorageTypes.VECTOR

    def _create_family_vector(self) -> list[float]:
        """
        :return: Creates and returns the model family one-hot vector.
        """

        family_name = self.observer_config.nn_family_name
        known_name = family_name in (family.value for family in ModelFamilies)

        # Family names that are not ModelFamilies values are passed on as a None index.
        family_idx = ModelFamilies.get_index(family_name) if known_name else None

        return observation_utils.create_one_hot_observation(vector_length=self.vector_length, one_hot_index=family_idx)

    def _observe(
        self,
        observation_inputs: ObservationInputs,
        hyperparameter_states: HyperparameterStates,
        tracked_statistics: dict[str, dict[str, float | TensorStatistics]],
        action_taken: float | int | None,
    ) -> float | int | list[int | float] | TensorStatistics:
        """
        :param observation_inputs: Observation input metrics not calculated with statistic trackers.
        :param hyperparameter_states: HyperparameterStates that manages the hyperparameters.
        :param tracked_statistics: Dictionary mapping statistic tracker class names to dictionaries mapping module
        names to floats or TensorStatistic models.
        :param action_taken: Action taken by the agent this class instance is assigned to.
        :return: The model family one-hot vector, or a vector of zeros of the same length when masking is active.
        """

        # Masking is only applied in training mode.
        if not self.in_training_mode:
            return self.family_vector

        # In timestep mode the masking decision is redrawn on every observation.
        if self.zero_vector_frequency_unit == self.UNIT_TIMESTEP:
            self._sample_zero_vector()

        if self.should_zero:
            return [0.0 for _ in range(self.vector_length)]

        return self.family_vector

    def _sample_zero_vector(self) -> None:
        """
        Determines whether the output vector of this observer should be masked with zeros.
        """

        self.should_zero = random.choices([True, False], [self.zero_vector_chance, (1 - self.zero_vector_chance)])[0]

    def get_required_trackers(self) -> dict[str, dict[str, Any] | None]:
        """
        :return: Dictionary mapping statistic tracker class names to kwargs for the class or None if no kwargs are
        needed.
        """

        return {}

    def reset(self) -> None:
        """
        Resets the observer by resampling whether the output vector should be masked with zeros.
        """

        self._sample_zero_vector()