libinephany 0.16.2__py3-none-any.whl → 0.16.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,991 +0,0 @@
1
- # ======================================================================================================================
2
- #
3
- # IMPORTS
4
- #
5
- # ======================================================================================================================
6
-
7
- import math
8
- import random
9
- from typing import Any
10
-
11
- from torch.optim import SGD, Adam, AdamW
12
-
13
- from libinephany.observations import observation_utils, statistic_trackers
14
- from libinephany.observations.observation_utils import StatisticStorageTypes
15
- from libinephany.observations.observers.base_observers import GlobalObserver
16
- from libinephany.pydantic_models.schemas.observation_models import ObservationInputs
17
- from libinephany.pydantic_models.schemas.tensor_statistics import TensorStatistics
18
- from libinephany.pydantic_models.states.hyperparameter_states import HyperparameterStates
19
- from libinephany.utils.enums import ModelFamilies
20
-
21
- # ======================================================================================================================
22
- #
23
- # CLASSES
24
- #
25
- # ======================================================================================================================
26
-
27
-
28
class InitialHyperparameters(GlobalObserver):
    """
    Global observer reporting the initial internal value of every hyperparameter held by the HyperparameterStates,
    minus the hyperparameters this observer has been configured to skip.
    """

    def __init__(self, skip_hparams: list[str] | None = None, pad_with: float = 0.0, **kwargs) -> None:
        """
        :param skip_hparams: Names of the hyperparameters to not include in the initial values vector returned by
        this observation. "samples" and "gradient_accumulation" are always skipped in addition to these.
        :param pad_with: Value substituted for any hyperparameter whose initial internal value is None.
        :param kwargs: Miscellaneous keyword arguments.
        """

        super().__init__(**kwargs)

        force_skip = ["samples", "gradient_accumulation"]
        # Concatenation builds a new list, so the caller's list is never mutated. The result is never None.
        self.skip_hparams = force_skip if skip_hparams is None else skip_hparams + force_skip
        self.pad_with = pad_with

    def _is_skipped_hparam(self, hparam_name: str) -> bool:
        """
        :param hparam_name: Name of a hyperparameter.
        :return: Whether the hyperparameter is excluded from the observation, i.e. whether any entry of
        skip_hparams is a substring of its name.
        """

        return any(skipped in hparam_name for skipped in self.skip_hparams)

    @property
    def vector_length(self) -> int:
        """
        :return: Length of the vector returned by this observation if it returns a vector.
        """

        available_hparams = HyperparameterStates.get_all_hyperparameters()

        return len([hparam for hparam in available_hparams if not self._is_skipped_hparam(hparam)])

    @property
    def can_standardize(self) -> bool:
        """
        :return: Whether the observation can be standardized.
        """

        return False

    @property
    def can_inform(self) -> bool:
        """
        :return: Whether observations from the observer can be used in the agent info dictionary.
        """

        return False

    def _get_observation_format(self) -> StatisticStorageTypes:
        """
        :return: Format the observation returns data in. Must be one of the enum attributes in the StatisticStorageTypes
        enumeration class.
        """

        return StatisticStorageTypes.VECTOR

    def _observe(
        self,
        observation_inputs: ObservationInputs,
        hyperparameter_states: HyperparameterStates,
        tracked_statistics: dict[str, dict[str, float | TensorStatistics]],
        action_taken: float | int | None,
    ) -> float | int | list[int | float] | TensorStatistics:
        """
        :param observation_inputs: Observation input metrics not calculated with statistic trackers.
        :param hyperparameter_states: HyperparameterStates that manages the hyperparameters.
        :param tracked_statistics: Dictionary mapping statistic tracker class names to dictionaries mapping module
        names to floats or TensorStatistic models.
        :param action_taken: Action taken by the agent this class instance is assigned to.
        :return: List of initial internal values, with None values replaced by pad_with.
        """

        initial_internal_values = hyperparameter_states.get_initial_internal_values(self.skip_hparams)
        self._cached_observation = initial_internal_values

        # Filter with the same substring rule as vector_length so that the returned vector always has the
        # advertised length, whatever filtering get_initial_internal_values applies itself.
        return [
            self.pad_with if initial_internal_value is None else initial_internal_value
            for hparam_name, initial_internal_value in initial_internal_values.items()
            if not self._is_skipped_hparam(hparam_name)
        ]

    def get_required_trackers(self) -> dict[str, dict[str, Any] | None]:
        """
        :return: Dictionary mapping statistic tracker class names to kwargs for the class or None if no kwargs are
        needed.
        """

        return {}
112
-
113
-
114
class TrainingLoss(GlobalObserver):
    """
    Global observer that forwards the current training loss as a single, unstandardized float.
    """

    def get_required_trackers(self) -> dict[str, dict[str, Any] | None]:
        """
        :return: An empty mapping; the training loss comes straight from the observation inputs, so no statistic
        trackers are required.
        """

        return dict()

    @property
    def can_standardize(self) -> bool:
        """
        :return: False; this observation is never standardized.
        """

        return False

    def _get_observation_format(self) -> StatisticStorageTypes:
        """
        :return: StatisticStorageTypes.FLOAT, because a single scalar is produced.
        """

        return StatisticStorageTypes.FLOAT

    def _observe(
        self,
        observation_inputs: ObservationInputs,
        hyperparameter_states: HyperparameterStates,
        tracked_statistics: dict[str, dict[str, float | TensorStatistics]],
        action_taken: float | int | None,
    ) -> float | int | list[int | float] | TensorStatistics:
        """
        :param observation_inputs: Observation input metrics not calculated with statistic trackers.
        :param hyperparameter_states: HyperparameterStates that manages the hyperparameters (unused).
        :param tracked_statistics: Statistics gathered by the statistic trackers (unused).
        :param action_taken: Action taken by the agent this class instance is assigned to (unused).
        :return: The training loss held in the observation inputs.
        """

        current_loss = observation_inputs.training_loss
        return current_loss
157
-
158
-
159
class ValidationLoss(GlobalObserver):
    """
    Global observer that forwards the current validation loss as a single, unstandardized float.
    """

    def get_required_trackers(self) -> dict[str, dict[str, Any] | None]:
        """
        :return: An empty mapping; the validation loss comes straight from the observation inputs, so no statistic
        trackers are required.
        """

        return dict()

    @property
    def can_standardize(self) -> bool:
        """
        :return: False; this observation is never standardized.
        """

        return False

    def _get_observation_format(self) -> StatisticStorageTypes:
        """
        :return: StatisticStorageTypes.FLOAT, because a single scalar is produced.
        """

        return StatisticStorageTypes.FLOAT

    def _observe(
        self,
        observation_inputs: ObservationInputs,
        hyperparameter_states: HyperparameterStates,
        tracked_statistics: dict[str, dict[str, float | TensorStatistics]],
        action_taken: float | int | None,
    ) -> float | int | list[int | float] | TensorStatistics:
        """
        :param observation_inputs: Observation input metrics not calculated with statistic trackers.
        :param hyperparameter_states: HyperparameterStates that manages the hyperparameters (unused).
        :param tracked_statistics: Statistics gathered by the statistic trackers (unused).
        :param action_taken: Action taken by the agent this class instance is assigned to (unused).
        :return: The validation loss held in the observation inputs.
        """

        current_loss = observation_inputs.validation_loss
        return current_loss
202
-
203
-
204
class LossRatio(GlobalObserver):
    """
    Global observer reporting the ratio of training loss to validation loss.
    """

    def get_required_trackers(self) -> dict[str, dict[str, Any] | None]:
        """
        :return: An empty mapping; both losses come from the observation inputs, so no statistic trackers are
        required.
        """

        return dict()

    def _get_observation_format(self) -> StatisticStorageTypes:
        """
        :return: StatisticStorageTypes.FLOAT, because a single scalar is produced.
        """

        return StatisticStorageTypes.FLOAT

    def _observe(
        self,
        observation_inputs: ObservationInputs,
        hyperparameter_states: HyperparameterStates,
        tracked_statistics: dict[str, dict[str, float | TensorStatistics]],
        action_taken: float | int | None,
    ) -> float | int | list[int | float] | TensorStatistics:
        """
        :param observation_inputs: Observation input metrics not calculated with statistic trackers.
        :param hyperparameter_states: HyperparameterStates that manages the hyperparameters (unused).
        :param tracked_statistics: Statistics gathered by the statistic trackers (unused).
        :param action_taken: Action taken by the agent this class instance is assigned to (unused).
        :return: training_loss / validation_loss, or 0 when the validation loss is zero (avoids division by zero).
        """

        denominator = observation_inputs.validation_loss
        if denominator != 0:
            return observation_inputs.training_loss / denominator

        return 0
242
-
243
-
244
class TrainingScore(GlobalObserver):
    """
    Global observer that forwards the training score held in the observation inputs.
    """

    def get_required_trackers(self) -> dict[str, dict[str, Any] | None]:
        """
        :return: An empty mapping; no statistic trackers are required.
        """

        return dict()

    def _get_observation_format(self) -> StatisticStorageTypes:
        """
        :return: StatisticStorageTypes.FLOAT, because a single scalar is produced.
        """

        return StatisticStorageTypes.FLOAT

    def _observe(
        self,
        observation_inputs: ObservationInputs,
        hyperparameter_states: HyperparameterStates,
        tracked_statistics: dict[str, dict[str, float | TensorStatistics]],
        action_taken: float | int | None,
    ) -> float | int | list[int | float] | TensorStatistics:
        """
        :param observation_inputs: Observation input metrics not calculated with statistic trackers.
        :param hyperparameter_states: HyperparameterStates that manages the hyperparameters (unused).
        :param tracked_statistics: Statistics gathered by the statistic trackers (unused).
        :param action_taken: Action taken by the agent this class instance is assigned to (unused).
        :return: The training score held in the observation inputs.
        """

        score = observation_inputs.training_score
        return score
279
-
280
-
281
class ValidationScore(GlobalObserver):
    """
    Global observer that forwards the validation score held in the observation inputs.
    """

    def get_required_trackers(self) -> dict[str, dict[str, Any] | None]:
        """
        :return: An empty mapping; no statistic trackers are required.
        """

        return dict()

    def _get_observation_format(self) -> StatisticStorageTypes:
        """
        :return: StatisticStorageTypes.FLOAT, because a single scalar is produced.
        """

        return StatisticStorageTypes.FLOAT

    def _observe(
        self,
        observation_inputs: ObservationInputs,
        hyperparameter_states: HyperparameterStates,
        tracked_statistics: dict[str, dict[str, float | TensorStatistics]],
        action_taken: float | int | None,
    ) -> float | int | list[int | float] | TensorStatistics:
        """
        :param observation_inputs: Observation input metrics not calculated with statistic trackers.
        :param hyperparameter_states: HyperparameterStates that manages the hyperparameters (unused).
        :param tracked_statistics: Statistics gathered by the statistic trackers (unused).
        :param action_taken: Action taken by the agent this class instance is assigned to (unused).
        :return: The validation score held in the observation inputs.
        """

        score = observation_inputs.validation_score
        return score
316
-
317
-
318
class TrainingProgress(GlobalObserver):
    """
    Global observer that forwards the training progress value as a single, unstandardized float.
    """

    def get_required_trackers(self) -> dict[str, dict[str, Any] | None]:
        """
        :return: An empty mapping; no statistic trackers are required.
        """

        return dict()

    @property
    def can_standardize(self) -> bool:
        """
        :return: False; this observation is never standardized.
        """

        return False

    def _get_observation_format(self) -> StatisticStorageTypes:
        """
        :return: StatisticStorageTypes.FLOAT, because a single scalar is produced.
        """

        return StatisticStorageTypes.FLOAT

    def _observe(
        self,
        observation_inputs: ObservationInputs,
        hyperparameter_states: HyperparameterStates,
        tracked_statistics: dict[str, dict[str, float | TensorStatistics]],
        action_taken: float | int | None,
    ) -> float | int | list[int | float] | TensorStatistics:
        """
        :param observation_inputs: Observation input metrics not calculated with statistic trackers.
        :param hyperparameter_states: HyperparameterStates that manages the hyperparameters (unused).
        :param tracked_statistics: Statistics gathered by the statistic trackers (unused).
        :param action_taken: Action taken by the agent this class instance is assigned to (unused).
        :return: The training progress held in the observation inputs.
        """

        progress = observation_inputs.training_progress
        return progress
361
-
362
-
363
class EpochsCompleted(GlobalObserver):
    """
    Global observer that forwards the number of completed epochs as a single, unstandardized value.
    """

    def get_required_trackers(self) -> dict[str, dict[str, Any] | None]:
        """
        :return: An empty mapping; no statistic trackers are required.
        """

        return dict()

    @property
    def can_standardize(self) -> bool:
        """
        :return: False; this observation is never standardized.
        """

        return False

    def _get_observation_format(self) -> StatisticStorageTypes:
        """
        :return: StatisticStorageTypes.FLOAT, because a single scalar is produced.
        """

        return StatisticStorageTypes.FLOAT

    def _observe(
        self,
        observation_inputs: ObservationInputs,
        hyperparameter_states: HyperparameterStates,
        tracked_statistics: dict[str, dict[str, float | TensorStatistics]],
        action_taken: float | int | None,
    ) -> float | int | list[int | float] | TensorStatistics:
        """
        :param observation_inputs: Observation input metrics not calculated with statistic trackers.
        :param hyperparameter_states: HyperparameterStates that manages the hyperparameters (unused).
        :param tracked_statistics: Statistics gathered by the statistic trackers (unused).
        :param action_taken: Action taken by the agent this class instance is assigned to (unused).
        :return: The epochs-completed value held in the observation inputs.
        """

        epochs = observation_inputs.epochs_completed
        return epochs
406
-
407
-
408
class GlobalFirstOrderGradients(GlobalObserver):
    """
    Global observer that averages the per-module statistics from the FirstOrderGradients tracker into a single
    TensorStatistics model.
    """

    def get_required_trackers(self) -> dict[str, dict[str, Any] | None]:
        """
        :return: Mapping that requests the FirstOrderGradients tracker and forwards this observer's skip_statistics.
        """

        tracker_name = statistic_trackers.FirstOrderGradients.__name__
        return {tracker_name: {"skip_statistics": self.skip_statistics}}

    def _get_observation_format(self) -> StatisticStorageTypes:
        """
        :return: StatisticStorageTypes.TENSOR_STATISTICS, because a TensorStatistics model is produced.
        """

        return StatisticStorageTypes.TENSOR_STATISTICS

    def _observe(
        self,
        observation_inputs: ObservationInputs,
        hyperparameter_states: HyperparameterStates,
        tracked_statistics: dict[str, dict[str, float | TensorStatistics]],
        action_taken: float | int | None,
    ) -> float | int | list[int | float] | TensorStatistics:
        """
        :param observation_inputs: Observation input metrics not calculated with statistic trackers (unused).
        :param hyperparameter_states: HyperparameterStates that manages the hyperparameters (unused).
        :param tracked_statistics: Dictionary mapping statistic tracker class names to dictionaries mapping module
        names to floats or TensorStatistic models.
        :param action_taken: Action taken by the agent this class instance is assigned to (unused).
        :return: Average of the FirstOrderGradients statistics across all tracked modules.
        """

        per_module = tracked_statistics[statistic_trackers.FirstOrderGradients.__name__]
        module_statistics = [*per_module.values()]

        return observation_utils.average_tensor_statistics(tensor_statistics=module_statistics)  # type: ignore
445
-
446
-
447
class GlobalSecondOrderGradients(GlobalObserver):
    """
    Global observer that averages the per-module statistics from the SecondOrderGradients tracker into a single
    TensorStatistics model.
    """

    def __init__(
        self,
        *,
        compute_hessian_diagonal: bool = False,
        **kwargs,
    ) -> None:
        """
        :param compute_hessian_diagonal: Whether to compute the Hessian diagonal to determine second order gradients
        or use the squared first order gradients as approximations in the same way Adam does. Forwarded to the
        SecondOrderGradients tracker.
        :param kwargs: Miscellaneous keyword arguments passed to the base observer.
        """

        super().__init__(**kwargs)

        self.compute_hessian_diagonal = compute_hessian_diagonal

    def get_required_trackers(self) -> dict[str, dict[str, Any] | None]:
        """
        :return: Mapping that requests the SecondOrderGradients tracker and forwards skip_statistics and
        compute_hessian_diagonal.
        """

        tracker_name = statistic_trackers.SecondOrderGradients.__name__
        tracker_kwargs = {
            "skip_statistics": self.skip_statistics,
            "compute_hessian_diagonal": self.compute_hessian_diagonal,
        }
        return {tracker_name: tracker_kwargs}

    def _get_observation_format(self) -> StatisticStorageTypes:
        """
        :return: StatisticStorageTypes.TENSOR_STATISTICS, because a TensorStatistics model is produced.
        """

        return StatisticStorageTypes.TENSOR_STATISTICS

    def _observe(
        self,
        observation_inputs: ObservationInputs,
        hyperparameter_states: HyperparameterStates,
        tracked_statistics: dict[str, dict[str, float | TensorStatistics]],
        action_taken: float | int | None,
    ) -> float | int | list[int | float] | TensorStatistics:
        """
        :param observation_inputs: Observation input metrics not calculated with statistic trackers (unused).
        :param hyperparameter_states: HyperparameterStates that manages the hyperparameters (unused).
        :param tracked_statistics: Dictionary mapping statistic tracker class names to dictionaries mapping module
        names to floats or TensorStatistic models.
        :param action_taken: Action taken by the agent this class instance is assigned to (unused).
        :return: Average of the SecondOrderGradients statistics across all tracked modules.
        """

        per_module = tracked_statistics[statistic_trackers.SecondOrderGradients.__name__]
        module_statistics = [*per_module.values()]

        return observation_utils.average_tensor_statistics(tensor_statistics=module_statistics)  # type: ignore
504
-
505
-
506
class GlobalActivations(GlobalObserver):
    """
    Global observer that averages the per-module statistics from the ActivationStatistics tracker into a single
    TensorStatistics model.
    """

    def get_required_trackers(self) -> dict[str, dict[str, Any] | None]:
        """
        :return: Mapping that requests the ActivationStatistics tracker and forwards this observer's skip_statistics.
        """

        tracker_name = statistic_trackers.ActivationStatistics.__name__
        return {tracker_name: {"skip_statistics": self.skip_statistics}}

    def _get_observation_format(self) -> StatisticStorageTypes:
        """
        :return: StatisticStorageTypes.TENSOR_STATISTICS, because a TensorStatistics model is produced.
        """

        return StatisticStorageTypes.TENSOR_STATISTICS

    def _observe(
        self,
        observation_inputs: ObservationInputs,
        hyperparameter_states: HyperparameterStates,
        tracked_statistics: dict[str, dict[str, float | TensorStatistics]],
        action_taken: float | int | None,
    ) -> float | int | list[int | float] | TensorStatistics:
        """
        :param observation_inputs: Observation input metrics not calculated with statistic trackers (unused).
        :param hyperparameter_states: HyperparameterStates that manages the hyperparameters (unused).
        :param tracked_statistics: Dictionary mapping statistic tracker class names to dictionaries mapping module
        names to floats or TensorStatistic models.
        :param action_taken: Action taken by the agent this class instance is assigned to (unused).
        :return: Average of the ActivationStatistics statistics across all tracked modules.
        """

        per_module = tracked_statistics[statistic_trackers.ActivationStatistics.__name__]
        module_statistics = [*per_module.values()]

        return observation_utils.average_tensor_statistics(tensor_statistics=module_statistics)  # type: ignore
543
-
544
-
545
class GlobalParameterUpdates(GlobalObserver):
    """
    Global observer that averages the per-module statistics from the ParameterUpdateStatistics tracker into a single
    TensorStatistics model.
    """

    def get_required_trackers(self) -> dict[str, dict[str, Any] | None]:
        """
        :return: Mapping that requests the ParameterUpdateStatistics tracker and forwards this observer's
        skip_statistics.
        """

        tracker_name = statistic_trackers.ParameterUpdateStatistics.__name__
        return {tracker_name: {"skip_statistics": self.skip_statistics}}

    def _get_observation_format(self) -> StatisticStorageTypes:
        """
        :return: StatisticStorageTypes.TENSOR_STATISTICS, because a TensorStatistics model is produced.
        """

        return StatisticStorageTypes.TENSOR_STATISTICS

    def _observe(
        self,
        observation_inputs: ObservationInputs,
        hyperparameter_states: HyperparameterStates,
        tracked_statistics: dict[str, dict[str, float | TensorStatistics]],
        action_taken: float | int | None,
    ) -> float | int | list[int | float] | TensorStatistics:
        """
        :param observation_inputs: Observation input metrics not calculated with statistic trackers (unused).
        :param hyperparameter_states: HyperparameterStates that manages the hyperparameters (unused).
        :param tracked_statistics: Dictionary mapping statistic tracker class names to dictionaries mapping module
        names to floats or TensorStatistic models.
        :param action_taken: Action taken by the agent this class instance is assigned to (unused).
        :return: Average of the ParameterUpdateStatistics statistics across all tracked modules.
        """

        per_module = tracked_statistics[statistic_trackers.ParameterUpdateStatistics.__name__]
        module_statistics = [*per_module.values()]

        return observation_utils.average_tensor_statistics(tensor_statistics=module_statistics)  # type: ignore
582
-
583
-
584
class GlobalParameters(GlobalObserver):
    """
    Global observer that averages the per-module statistics from the ParameterStatistics tracker into a single
    TensorStatistics model.
    """

    def get_required_trackers(self) -> dict[str, dict[str, Any] | None]:
        """
        :return: Mapping that requests the ParameterStatistics tracker and forwards this observer's skip_statistics.
        """

        tracker_name = statistic_trackers.ParameterStatistics.__name__
        return {tracker_name: {"skip_statistics": self.skip_statistics}}

    def _get_observation_format(self) -> StatisticStorageTypes:
        """
        :return: StatisticStorageTypes.TENSOR_STATISTICS, because a TensorStatistics model is produced.
        """

        return StatisticStorageTypes.TENSOR_STATISTICS

    def _observe(
        self,
        observation_inputs: ObservationInputs,
        hyperparameter_states: HyperparameterStates,
        tracked_statistics: dict[str, dict[str, float | TensorStatistics]],
        action_taken: float | int | None,
    ) -> float | int | list[int | float] | TensorStatistics:
        """
        :param observation_inputs: Observation input metrics not calculated with statistic trackers (unused).
        :param hyperparameter_states: HyperparameterStates that manages the hyperparameters (unused).
        :param tracked_statistics: Dictionary mapping statistic tracker class names to dictionaries mapping module
        names to floats or TensorStatistic models.
        :param action_taken: Action taken by the agent this class instance is assigned to (unused).
        :return: Average of the ParameterStatistics statistics across all tracked modules.
        """

        per_module = tracked_statistics[statistic_trackers.ParameterStatistics.__name__]
        module_statistics = [*per_module.values()]

        return observation_utils.average_tensor_statistics(tensor_statistics=module_statistics)  # type: ignore
621
-
622
-
623
class GlobalLAMBTrustRatio(GlobalObserver):
    """
    Global observer reporting the mean LAMB trust ratio across all modules tracked by the LAMBTrustRatioStatistics
    tracker.
    """

    def __init__(
        self,
        *,
        use_log_transform: bool = False,
        **kwargs,
    ) -> None:
        """
        :param use_log_transform: Whether to transform the LAMB trust ratio by taking ln(1 + R). Forwarded to the
        LAMBTrustRatioStatistics tracker.
        :param kwargs: Other observation keyword arguments.
        """

        super().__init__(**kwargs)

        self.use_log_transform = use_log_transform

    def _get_observation_format(self) -> StatisticStorageTypes:
        """
        :return: Format the observation returns data in. Must be one of the enum attributes in the StatisticStorageTypes
        enumeration class.
        """

        return StatisticStorageTypes.FLOAT

    def _observe(
        self,
        observation_inputs: ObservationInputs,
        hyperparameter_states: HyperparameterStates,
        tracked_statistics: dict[str, dict[str, float | TensorStatistics]],
        action_taken: float | int | None,
    ) -> float | int | list[int | float] | TensorStatistics:
        """
        :param observation_inputs: Observation input metrics not calculated with statistic trackers.
        :param hyperparameter_states: HyperparameterStates that manages the hyperparameters.
        :param tracked_statistics: Dictionary mapping statistic tracker class names to dictionaries mapping module
        names to floats or TensorStatistic models.
        :param action_taken: Action taken by the agent this class instance is assigned to.
        :return: Arithmetic mean of the per-module trust ratios, or 0.0 if the tracker reported no modules.
        """

        statistics = tracked_statistics[statistic_trackers.LAMBTrustRatioStatistics.__name__]

        # Guard against an empty tracker result, which would otherwise raise ZeroDivisionError. Returning 0 mirrors
        # the zero-denominator convention used by LossRatio.
        if not statistics:
            return 0.0

        return sum(statistics.values()) / len(statistics)  # type: ignore

    def get_required_trackers(self) -> dict[str, dict[str, Any] | None]:
        """
        :return: Dictionary mapping statistic tracker class names to kwargs for the class or None if no kwargs are
        needed.
        """

        return {statistic_trackers.LAMBTrustRatioStatistics.__name__: dict(use_log_transform=self.use_log_transform)}
675
-
676
-
677
class NumberOfParameters(GlobalObserver):
    """
    Global observer reporting the parameter count from the NumberOfParameters tracker, optionally as ln(1 + N).
    """

    def __init__(
        self,
        *,
        use_log_transform: bool = True,
        **kwargs,
    ) -> None:
        """
        :param use_log_transform: Whether to transform the return of the Observer by ln(1 + N).
        :param kwargs: Miscellaneous keyword arguments passed to the base observer.
        """

        super().__init__(**kwargs)

        self.use_log_transform = use_log_transform

    def get_required_trackers(self) -> dict[str, dict[str, Any] | None]:
        """
        :return: Mapping that requests the NumberOfParameters tracker with no extra kwargs.
        """

        tracker_name = statistic_trackers.NumberOfParameters.__name__
        return {tracker_name: None}

    @property
    def can_standardize(self) -> bool:
        """
        :return: False; this observation is never standardized.
        """

        return False

    def _get_observation_format(self) -> StatisticStorageTypes:
        """
        :return: StatisticStorageTypes.FLOAT, because a single scalar is produced.
        """

        return StatisticStorageTypes.FLOAT

    def _observe(
        self,
        observation_inputs: ObservationInputs,
        hyperparameter_states: HyperparameterStates,
        tracked_statistics: dict[str, dict[str, float | TensorStatistics]],
        action_taken: float | int | None,
    ) -> float | int | list[int | float] | TensorStatistics:
        """
        :param observation_inputs: Observation input metrics not calculated with statistic trackers (unused).
        :param hyperparameter_states: HyperparameterStates that manages the hyperparameters (unused).
        :param tracked_statistics: Dictionary mapping statistic tracker class names to dictionaries mapping module
        names to floats or TensorStatistic models.
        :param action_taken: Action taken by the agent this class instance is assigned to (unused).
        :return: The first value reported by the tracker, transformed by ln(1 + N) when use_log_transform is set.
        """

        reported = tracked_statistics[statistic_trackers.NumberOfParameters.__name__]
        # Only the first entry the tracker reports is used as the count.
        n_params = list(reported.values())[0]

        return math.log(1 + n_params) if self.use_log_transform else n_params  # type: ignore
741
-
742
-
743
class NumberOfLayers(GlobalObserver):
    """
    Global observer reporting the layer count from the NumberOfLayers tracker, optionally as ln(1 + N).
    """

    def __init__(
        self,
        *,
        use_log_transform: bool = True,
        trainable_only: bool = False,
        **kwargs,
    ) -> None:
        """
        :param use_log_transform: Whether to transform the return of the Observer by ln(1 + N).
        :param trainable_only: Whether to only count trainable layers. Forwarded to the NumberOfLayers tracker.
        :param kwargs: Miscellaneous keyword arguments passed to the base observer.
        """

        super().__init__(**kwargs)

        self.use_log_transform = use_log_transform
        self.trainable_only = trainable_only

    def get_required_trackers(self) -> dict[str, dict[str, Any] | None]:
        """
        :return: Mapping that requests the NumberOfLayers tracker and forwards trainable_only.
        """

        tracker_name = statistic_trackers.NumberOfLayers.__name__
        return {tracker_name: {"trainable_only": self.trainable_only}}

    @property
    def can_standardize(self) -> bool:
        """
        :return: False; this observation is never standardized.
        """

        return False

    def _get_observation_format(self) -> StatisticStorageTypes:
        """
        :return: StatisticStorageTypes.FLOAT, because a single scalar is produced.
        """

        return StatisticStorageTypes.FLOAT

    def _observe(
        self,
        observation_inputs: ObservationInputs,
        hyperparameter_states: HyperparameterStates,
        tracked_statistics: dict[str, dict[str, float | TensorStatistics]],
        action_taken: float | int | None,
    ) -> float | int | list[int | float] | TensorStatistics:
        """
        :param observation_inputs: Observation input metrics not calculated with statistic trackers (unused).
        :param hyperparameter_states: HyperparameterStates that manages the hyperparameters (unused).
        :param tracked_statistics: Dictionary mapping statistic tracker class names to dictionaries mapping module
        names to floats or TensorStatistic models.
        :param action_taken: Action taken by the agent this class instance is assigned to (unused).
        :return: The first value reported by the tracker, transformed by ln(1 + N) when use_log_transform is set.
        """

        reported = tracked_statistics[statistic_trackers.NumberOfLayers.__name__]
        # Only the first entry the tracker reports is used as the count.
        n_layers = list(reported.values())[0]

        return math.log(1 + n_layers) if self.use_log_transform else n_layers  # type: ignore
810
-
811
-
812
class OptimizerTypeOneHot(GlobalObserver):
    """
    Global observer that one-hot encodes the optimizer name from the observer config against OPTIMS.

    A name that is not listed in OPTIMS is passed to the one-hot helper with an index of None (presumably yielding
    an all-zeros vector; see observation_utils.create_one_hot_observation).
    """

    OPTIMS = [Adam.__name__, AdamW.__name__, SGD.__name__]

    @property
    def vector_length(self) -> int:
        """
        :return: Number of supported optimizer names, i.e. the length of the one-hot vector.
        """

        return len(self.OPTIMS)

    @property
    def can_inform(self) -> bool:
        """
        :return: Always False; this observation is not exposed in the agent info dictionary.
        """

        return False

    def _get_observation_format(self) -> StatisticStorageTypes:
        """
        :return: StatisticStorageTypes.VECTOR, because the observation is a one-hot list.
        """

        return StatisticStorageTypes.VECTOR

    def _observe(
        self,
        observation_inputs: ObservationInputs,
        hyperparameter_states: HyperparameterStates,
        tracked_statistics: dict[str, dict[str, float | TensorStatistics]],
        action_taken: float | int | None,
    ) -> float | int | list[int | float] | TensorStatistics:
        """
        :param observation_inputs: Observation input metrics not calculated with statistic trackers (unused here).
        :param hyperparameter_states: HyperparameterStates that manages the hyperparameters (unused here).
        :param tracked_statistics: Dictionary mapping statistic tracker class names to dictionaries mapping module
        names to floats or TensorStatistic models (unused here).
        :param action_taken: Action taken by the agent this class instance is assigned to (unused here).
        :return: One-hot vector of length vector_length for the configured optimizer.
        """

        configured_name = self.observer_config.optimizer_name
        hot_index = self.OPTIMS.index(configured_name) if configured_name in self.OPTIMS else None

        return observation_utils.create_one_hot_observation(
            vector_length=self.vector_length,
            one_hot_index=hot_index,
        )

    def get_required_trackers(self) -> dict[str, dict[str, Any] | None]:
        """
        :return: Empty mapping; the observation comes from the observer config, not from statistic trackers.
        """

        return {}
873
-
874
-
875
class ModelFamilyOneHot(GlobalObserver):
    """
    Global observer that one-hot encodes the configured model family against the ModelFamilies enumeration.

    Outside training mode the family vector is always returned. In training mode the observation is replaced by
    an all-zeros vector whenever ``should_zero`` is set. ``should_zero`` is drawn with probability
    ``zero_vector_chance`` at construction and on ``reset``, and additionally before every observation when
    ``zero_vector_frequency_unit`` is ``"timestep"``.
    """

    UNIT_EPISODE = "episode"
    UNIT_TIMESTEP = "timestep"

    def __init__(
        self,
        *,
        zero_vector_chance: float = 0.2,
        zero_vector_frequency_unit: str = "episode",
        **kwargs,
    ) -> None:
        """
        :param zero_vector_chance: Probability, in [0, 1), that the observation is masked with zeros while in
        training mode.
        :param zero_vector_frequency_unit: How often the masking decision is redrawn in training mode: "episode"
        (only at construction and on reset) or "timestep" (before every observation).
        :param kwargs: Miscellaneous keyword arguments.
        :raises ValueError: If zero_vector_chance is outside [0, 1) or zero_vector_frequency_unit is unknown.
        """
        super().__init__(**kwargs)
        self.should_zero = False

        # Explicit check instead of ``assert`` so the validation still runs under ``python -O``.
        if not 0.0 <= zero_vector_chance < 1.0:
            raise ValueError(f"zero_vector_chance must be in [0, 1), got {zero_vector_chance}")

        self.zero_vector_chance = zero_vector_chance
        self._sample_zero_vector()

        if zero_vector_frequency_unit not in (self.UNIT_EPISODE, self.UNIT_TIMESTEP):
            raise ValueError(f"Unknown zero_vector_frequency_unit: {zero_vector_frequency_unit}")

        self.zero_vector_frequency_unit = zero_vector_frequency_unit
        self.family_vector = self._create_family_vector()

    @property
    def vector_length(self) -> int:
        """
        :return: Number of members in ModelFamilies, i.e. the length of the one-hot vector.
        """

        return len(ModelFamilies)

    @property
    def can_inform(self) -> bool:
        """
        :return: Always False; this observation is not exposed in the agent info dictionary.
        """

        return False

    def _get_observation_format(self) -> StatisticStorageTypes:
        """
        :return: StatisticStorageTypes.VECTOR, because the observation is a one-hot list.
        """

        return StatisticStorageTypes.VECTOR

    def _create_family_vector(self) -> list[float]:
        """
        :return: One-hot vector for observer_config.nn_family_name. A name that is not a ModelFamilies value is
        passed to the one-hot helper with an index of None (presumably an all-zeros vector; see observation_utils).
        """

        family_name = self.observer_config.nn_family_name
        known_name = family_name in (family.value for family in ModelFamilies)
        family_idx = ModelFamilies.get_index(family_name) if known_name else None

        return observation_utils.create_one_hot_observation(vector_length=self.vector_length, one_hot_index=family_idx)

    def _observe(
        self,
        observation_inputs: ObservationInputs,
        hyperparameter_states: HyperparameterStates,
        tracked_statistics: dict[str, dict[str, float | TensorStatistics]],
        action_taken: float | int | None,
    ) -> float | int | list[int | float] | TensorStatistics:
        """
        :param observation_inputs: Observation input metrics not calculated with statistic trackers (unused here).
        :param hyperparameter_states: HyperparameterStates that manages the hyperparameters (unused here).
        :param tracked_statistics: Dictionary mapping statistic tracker class names to dictionaries mapping module
        names to floats or TensorStatistic models (unused here).
        :param action_taken: Action taken by the agent this class instance is assigned to (unused here).
        :return: A fresh list holding either the model family one-hot vector or, when masked during training,
        zeros of the same length.
        """

        # Copies are returned so that a caller mutating the observation cannot corrupt the cached family vector.
        if not self.in_training_mode:
            return list(self.family_vector)

        if self.zero_vector_frequency_unit == self.UNIT_TIMESTEP:
            self._sample_zero_vector()

        if self.should_zero:
            return [0.0] * self.vector_length

        return list(self.family_vector)

    def _sample_zero_vector(self) -> None:
        """
        Draws whether the output vector of this observer should be masked with zeros, setting should_zero to True
        with probability zero_vector_chance. Uses the module-level ``random`` generator, so it follows random.seed.
        """
        self.should_zero = random.choices([True, False], [self.zero_vector_chance, (1 - self.zero_vector_chance)])[0]

    def get_required_trackers(self) -> dict[str, dict[str, Any] | None]:
        """
        :return: Empty mapping; the observation comes from the observer config, not from statistic trackers.
        """

        return {}

    def reset(self) -> None:
        """
        Resets the observer by redrawing the zero-masking decision.
        """

        self._sample_zero_vector()