libinephany 0.18.1__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- libinephany/observations/observation_utils.py +19 -2
- libinephany/observations/observers/base_observers.py +20 -8
- libinephany/observations/observers/global_observers/__init__.py +19 -1
- libinephany/observations/observers/global_observers/constants.py +2 -0
- libinephany/observations/observers/global_observers/gradient_observers.py +320 -3
- libinephany/observations/observers/global_observers/hyperparameter_observers.py +26 -18
- libinephany/observations/observers/global_observers/model_observers.py +220 -6
- libinephany/observations/observers/global_observers/progress_observers.py +7 -1
- libinephany/observations/observers/local_observers.py +158 -25
- libinephany/observations/statistic_trackers.py +435 -23
- libinephany/pydantic_models/schemas/tensor_statistics.py +33 -32
- libinephany/pydantic_models/states/hyperparameter_states.py +32 -30
- {libinephany-0.18.1.dist-info → libinephany-1.0.0.dist-info}/METADATA +1 -1
- {libinephany-0.18.1.dist-info → libinephany-1.0.0.dist-info}/RECORD +17 -17
- {libinephany-0.18.1.dist-info → libinephany-1.0.0.dist-info}/WHEEL +0 -0
- {libinephany-0.18.1.dist-info → libinephany-1.0.0.dist-info}/licenses/LICENSE +0 -0
- {libinephany-0.18.1.dist-info → libinephany-1.0.0.dist-info}/top_level.txt +0 -0
@@ -59,7 +59,7 @@ class GlobalActivations(GlobalObserver):
|
|
59
59
|
needed.
|
60
60
|
"""
|
61
61
|
|
62
|
-
return {statistic_trackers.ActivationStatistics.__name__: dict(
|
62
|
+
return {statistic_trackers.ActivationStatistics.__name__: dict(include_statistics=self.include_statistics)}
|
63
63
|
|
64
64
|
|
65
65
|
class GlobalParameterUpdates(GlobalObserver):
|
@@ -98,7 +98,7 @@ class GlobalParameterUpdates(GlobalObserver):
|
|
98
98
|
needed.
|
99
99
|
"""
|
100
100
|
|
101
|
-
return {statistic_trackers.ParameterUpdateStatistics.__name__: dict(
|
101
|
+
return {statistic_trackers.ParameterUpdateStatistics.__name__: dict(include_statistics=self.include_statistics)}
|
102
102
|
|
103
103
|
|
104
104
|
class GlobalParameters(GlobalObserver):
|
@@ -137,7 +137,7 @@ class GlobalParameters(GlobalObserver):
|
|
137
137
|
needed.
|
138
138
|
"""
|
139
139
|
|
140
|
-
return {statistic_trackers.ParameterStatistics.__name__: dict(
|
140
|
+
return {statistic_trackers.ParameterStatistics.__name__: dict(include_statistics=self.include_statistics)}
|
141
141
|
|
142
142
|
|
143
143
|
class GlobalLAMBTrustRatio(GlobalObserver):
|
@@ -385,7 +385,7 @@ class LogRatioOfPreviousAndCurrentParamNormEnvStepObserver(LHOPTBaseObserver):
|
|
385
385
|
"""
|
386
386
|
|
387
387
|
return {
|
388
|
-
statistic_trackers.ParameterStatistics.__name__: dict(
|
388
|
+
statistic_trackers.ParameterStatistics.__name__: dict(include_statistics=self.include_statistics),
|
389
389
|
}
|
390
390
|
|
391
391
|
def reset(self) -> None:
|
@@ -436,6 +436,146 @@ class LogRatioOfUpdateAndPreviousParamNormEnvStepObserver(LHOPTBaseObserver):
|
|
436
436
|
tensor_statistics=[stats for stats in param_statistics.values() if isinstance(stats, TensorStatistics)]
|
437
437
|
).norm_
|
438
438
|
|
439
|
+
if self._previous_param_norm is None:
|
440
|
+
self._previous_param_norm = current_param_norm
|
441
|
+
self._compute_cdf_feature(0.0) # default value since we can't compute log ratio yet
|
442
|
+
self._update_time()
|
443
|
+
return [0.0, 0.0]
|
444
|
+
|
445
|
+
log_ratio = self._compute_log_ratio(update_norm, self._previous_param_norm)
|
446
|
+
tanh_feature = math.tanh(max(-LHOPT_CONSTANTS["TANH_BOUND"], min(LHOPT_CONSTANTS["TANH_BOUND"], log_ratio)))
|
447
|
+
cdf_feature = self._compute_cdf_feature(log_ratio)
|
448
|
+
|
449
|
+
self._update_time()
|
450
|
+
self._previous_param_norm = current_param_norm
|
451
|
+
|
452
|
+
return [tanh_feature, cdf_feature]
|
453
|
+
|
454
|
+
def get_required_trackers(self) -> dict[str, dict[str, Any] | None]:
|
455
|
+
"""
|
456
|
+
:return: Dictionary mapping statistic tracker class names to kwargs for the class or None if no kwargs are
|
457
|
+
needed.
|
458
|
+
"""
|
459
|
+
|
460
|
+
return {
|
461
|
+
statistic_trackers.ParameterUpdateStatistics.__name__: dict(include_statistics=self.include_statistics),
|
462
|
+
statistic_trackers.ParameterStatistics.__name__: dict(include_statistics=self.include_statistics),
|
463
|
+
}
|
464
|
+
|
465
|
+
def reset(self) -> None:
|
466
|
+
"""
|
467
|
+
Reset the observer by clearing the previous parameter norm and time series.
|
468
|
+
"""
|
469
|
+
|
470
|
+
super().reset()
|
471
|
+
self._previous_param_norm = None
|
472
|
+
|
473
|
+
|
474
|
+
class LHOPTAverageParameterUpdateMagnitudeObserver(LHOPTBaseObserver):
|
475
|
+
|
476
|
+
def _get_observation_format(self) -> StatisticStorageTypes:
|
477
|
+
"""
|
478
|
+
:return: Format the observation returns data in. Must be one of the enum attributes in the StatisticStorageTypes
|
479
|
+
enumeration class.
|
480
|
+
"""
|
481
|
+
|
482
|
+
return StatisticStorageTypes.VECTOR
|
483
|
+
|
484
|
+
@property
|
485
|
+
def can_standardize(self) -> bool:
|
486
|
+
"""
|
487
|
+
:return: Whether the observation can be standardized.
|
488
|
+
"""
|
489
|
+
|
490
|
+
return False
|
491
|
+
|
492
|
+
@property
|
493
|
+
def vector_length(self) -> int:
|
494
|
+
"""
|
495
|
+
:return: Length of the vector returned by this observation if it returns a vector.
|
496
|
+
"""
|
497
|
+
return 2 # [raw_feature, cdf_feature]
|
498
|
+
|
499
|
+
def _observe(
|
500
|
+
self,
|
501
|
+
observation_inputs: ObservationInputs,
|
502
|
+
hyperparameter_states: HyperparameterStates,
|
503
|
+
tracked_statistics: dict[str, dict[str, float | TensorStatistics]],
|
504
|
+
action_taken: float | int | None,
|
505
|
+
) -> float | int | list[int | float] | TensorStatistics:
|
506
|
+
"""
|
507
|
+
:param observation_inputs: Observation input metrics not calculated with statistic trackers.
|
508
|
+
:param hyperparameter_states: HyperparameterStates that manages the hyperparameters.
|
509
|
+
:param tracked_statistics: Dictionary mapping statistic tracker class names to dictionaries mapping module
|
510
|
+
names to floats or TensorStatistic models.
|
511
|
+
:param action_taken: Action taken by the agent this class instance is assigned to.
|
512
|
+
:return: Single float/int, list of floats/ints or TensorStatistics model to add to the observation vector.
|
513
|
+
"""
|
514
|
+
|
515
|
+
statistics = tracked_statistics[statistic_trackers.AverageParameterUpdateMagnitudeStatistics.__name__]
|
516
|
+
|
517
|
+
raw_feature = list(statistics.values())[0] # type: ignore[list-item]
|
518
|
+
|
519
|
+
cdf_feature = self._compute_cdf_feature(raw_feature) # type: ignore[arg-type]
|
520
|
+
self._update_time()
|
521
|
+
|
522
|
+
return [raw_feature, cdf_feature] # type: ignore[list-item]
|
523
|
+
|
524
|
+
def get_required_trackers(self) -> dict[str, dict[str, Any] | None]:
|
525
|
+
"""
|
526
|
+
:return: Dictionary mapping statistic tracker class names to kwargs for the class or None if no kwargs are
|
527
|
+
needed.
|
528
|
+
"""
|
529
|
+
|
530
|
+
return {
|
531
|
+
statistic_trackers.AverageParameterUpdateMagnitudeStatistics.__name__: dict(
|
532
|
+
include_statistics=self.include_statistics, sample_frequency=LHOPT_CONSTANTS["DEFAULT_SAMPLE_FREQUENCY"]
|
533
|
+
)
|
534
|
+
}
|
535
|
+
|
536
|
+
|
537
|
+
class LogRatioOfUpdateAndPreviousParamNormInnerStepObserver(LHOPTBaseObserver):
|
538
|
+
def __init__(self, **kwargs):
|
539
|
+
"""
|
540
|
+
This observer is used to compute the log ratio of the update and previous parameter norm for the inner step. The sample frequency of the statistics needs to be set to 4 (according to the OpenAI paper).
|
541
|
+
|
542
|
+
"""
|
543
|
+
super().__init__(**kwargs)
|
544
|
+
self._previous_param_norm = None
|
545
|
+
|
546
|
+
@property
|
547
|
+
def vector_length(self) -> int:
|
548
|
+
"""
|
549
|
+
:return: Length of the vector returned by this observation if it returns a vector.
|
550
|
+
"""
|
551
|
+
return 2 # [tanh_feature, cdf_feature]
|
552
|
+
|
553
|
+
def _observe(
|
554
|
+
self,
|
555
|
+
observation_inputs: ObservationInputs,
|
556
|
+
hyperparameter_states: HyperparameterStates,
|
557
|
+
tracked_statistics: dict[str, dict[str, float | TensorStatistics]],
|
558
|
+
action_taken: float | int | None,
|
559
|
+
) -> float | int | list[int | float] | TensorStatistics:
|
560
|
+
"""
|
561
|
+
:param observation_inputs: Observation input metrics not calculated with statistic trackers.
|
562
|
+
:param hyperparameter_states: HyperparameterStates that manages the hyperparameters.
|
563
|
+
:param tracked_statistics: Dictionary mapping statistic tracker class names to dictionaries mapping module
|
564
|
+
names to floats or TensorStatistics models.
|
565
|
+
:param action_taken: Action taken by the agent this class instance is assigned to.
|
566
|
+
:return: List containing [raw_log_ratio, cdf_feature].
|
567
|
+
"""
|
568
|
+
|
569
|
+
update_statistics = tracked_statistics[statistic_trackers.LHOPTParameterUpdateStatistics.__name__]
|
570
|
+
param_statistics = tracked_statistics[statistic_trackers.LHOPTParameterStatistics.__name__]
|
571
|
+
update_norm = observation_utils.average_tensor_statistics(
|
572
|
+
tensor_statistics=[stats for stats in update_statistics.values() if isinstance(stats, TensorStatistics)]
|
573
|
+
).norm_
|
574
|
+
|
575
|
+
current_param_norm = observation_utils.average_tensor_statistics(
|
576
|
+
tensor_statistics=[stats for stats in param_statistics.values() if isinstance(stats, TensorStatistics)]
|
577
|
+
).norm_
|
578
|
+
|
439
579
|
if self._previous_param_norm is None:
|
440
580
|
self._previous_param_norm = current_param_norm
|
441
581
|
self._compute_cdf_feature(0.0) # default value since we can't compute log ratio yet
|
@@ -456,8 +596,12 @@ class LogRatioOfUpdateAndPreviousParamNormEnvStepObserver(LHOPTBaseObserver):
|
|
456
596
|
"""
|
457
597
|
|
458
598
|
return {
|
459
|
-
statistic_trackers.
|
460
|
-
|
599
|
+
statistic_trackers.LHOPTParameterUpdateStatistics.__name__: dict(
|
600
|
+
include_statistics=self.include_statistics, sample_frequency=LHOPT_CONSTANTS["DEFAULT_SAMPLE_FREQUENCY"]
|
601
|
+
),
|
602
|
+
statistic_trackers.LHOPTParameterStatistics.__name__: dict(
|
603
|
+
include_statistics=self.include_statistics, sample_frequency=LHOPT_CONSTANTS["DEFAULT_SAMPLE_FREQUENCY"]
|
604
|
+
),
|
461
605
|
}
|
462
606
|
|
463
607
|
def reset(self) -> None:
|
@@ -467,3 +611,73 @@ class LogRatioOfUpdateAndPreviousParamNormEnvStepObserver(LHOPTBaseObserver):
|
|
467
611
|
|
468
612
|
super().reset()
|
469
613
|
self._previous_param_norm = None
|
614
|
+
|
615
|
+
|
616
|
+
class LHOPTGlobalLAMBTrustRatio(LHOPTBaseObserver):
|
617
|
+
|
618
|
+
def __init__(
|
619
|
+
self,
|
620
|
+
*,
|
621
|
+
use_log_transform: bool = False,
|
622
|
+
**kwargs,
|
623
|
+
) -> None:
|
624
|
+
"""
|
625
|
+
:param use_log_transform: Whether to transform the LAMB trust ratio by taking ln(1 + R).
|
626
|
+
:param kwargs: Other observation keyword arguments.
|
627
|
+
"""
|
628
|
+
|
629
|
+
super().__init__(**kwargs)
|
630
|
+
|
631
|
+
self.use_log_transform = use_log_transform
|
632
|
+
|
633
|
+
def _get_observation_format(self) -> StatisticStorageTypes:
|
634
|
+
"""
|
635
|
+
:return: Format the observation returns data in. Must be one of the enum attributes in the StatisticStorageTypes
|
636
|
+
enumeration class.
|
637
|
+
"""
|
638
|
+
|
639
|
+
return StatisticStorageTypes.VECTOR
|
640
|
+
|
641
|
+
@property
|
642
|
+
def vector_length(self) -> int:
|
643
|
+
"""
|
644
|
+
:return: Length of the vector returned by this observation if it returns a vector.
|
645
|
+
"""
|
646
|
+
return 2 # [raw_value, cdf_feature]
|
647
|
+
|
648
|
+
def _observe(
|
649
|
+
self,
|
650
|
+
observation_inputs: ObservationInputs,
|
651
|
+
hyperparameter_states: HyperparameterStates,
|
652
|
+
tracked_statistics: dict[str, dict[str, float | TensorStatistics]],
|
653
|
+
action_taken: float | int | None,
|
654
|
+
) -> float | int | list[int | float] | TensorStatistics:
|
655
|
+
"""
|
656
|
+
:param observation_inputs: Observation input metrics not calculated with statistic trackers.
|
657
|
+
:param hyperparameter_states: HyperparameterStates that manages the hyperparameters.
|
658
|
+
:param tracked_statistics: Dictionary mapping statistic tracker class names to dictionaries mapping module
|
659
|
+
names to floats or TensorStatistic models.
|
660
|
+
:param action_taken: Action taken by the agent this class instance is assigned to.
|
661
|
+
:return: Single float/int, list of floats/ints or TensorStatistics model to add to the observation vector.
|
662
|
+
"""
|
663
|
+
|
664
|
+
statistics = tracked_statistics[statistic_trackers.LHOPTLAMBTrustRatioStatistics.__name__]
|
665
|
+
|
666
|
+
raw_value = sum(statistics.values()) / len(statistics) # type: ignore[arg-type]
|
667
|
+
cdf_feature = self._compute_cdf_feature(raw_value) # type: ignore[arg-type]
|
668
|
+
self._update_time()
|
669
|
+
return [raw_value, cdf_feature] # type: ignore[list-item]
|
670
|
+
|
671
|
+
def get_required_trackers(self) -> dict[str, dict[str, Any] | None]:
|
672
|
+
"""
|
673
|
+
:return: Dictionary mapping statistic tracker class names to kwargs for the class or None if no kwargs are
|
674
|
+
needed.
|
675
|
+
"""
|
676
|
+
|
677
|
+
return {
|
678
|
+
statistic_trackers.LHOPTLAMBTrustRatioStatistics.__name__: dict(
|
679
|
+
include_statistics=self.include_statistics,
|
680
|
+
use_log_transform=self.use_log_transform,
|
681
|
+
sample_frequency=LHOPT_CONSTANTS["DEFAULT_SAMPLE_FREQUENCY"],
|
682
|
+
)
|
683
|
+
}
|
@@ -1,6 +1,6 @@
|
|
1
1
|
# ======================================================================================================================
|
2
2
|
#
|
3
|
-
#
|
3
|
+
# IMPORTS
|
4
4
|
#
|
5
5
|
# ======================================================================================================================
|
6
6
|
|
@@ -13,6 +13,12 @@ from libinephany.pydantic_models.schemas.observation_models import ObservationIn
|
|
13
13
|
from libinephany.pydantic_models.schemas.tensor_statistics import TensorStatistics
|
14
14
|
from libinephany.pydantic_models.states.hyperparameter_states import HyperparameterStates
|
15
15
|
|
16
|
+
# ======================================================================================================================
|
17
|
+
#
|
18
|
+
# CLASSES
|
19
|
+
#
|
20
|
+
# ======================================================================================================================
|
21
|
+
|
16
22
|
|
17
23
|
class TrainingProgress(GlobalObserver):
|
18
24
|
|
@@ -4,11 +4,13 @@
|
|
4
4
|
#
|
5
5
|
# ======================================================================================================================
|
6
6
|
|
7
|
+
import math
|
7
8
|
from typing import Any
|
8
9
|
|
9
10
|
from libinephany.observations import observation_utils, statistic_trackers
|
10
|
-
from libinephany.observations.observation_utils import StatisticStorageTypes
|
11
|
+
from libinephany.observations.observation_utils import StatisticStorageTypes, compute_cdf_feature
|
11
12
|
from libinephany.observations.observers.base_observers import LocalObserver
|
13
|
+
from libinephany.observations.observers.global_observers.constants import LHOPT_CONSTANTS
|
12
14
|
from libinephany.pydantic_models.schemas.observation_models import ObservationInputs
|
13
15
|
from libinephany.pydantic_models.schemas.tensor_statistics import TensorStatistics
|
14
16
|
from libinephany.pydantic_models.states.hyperparameter_states import HyperparameterStates
|
@@ -66,7 +68,7 @@ class FirstOrderGradients(LocalObserver):
|
|
66
68
|
needed.
|
67
69
|
"""
|
68
70
|
|
69
|
-
return {statistic_trackers.FirstOrderGradients.__name__: dict(
|
71
|
+
return {statistic_trackers.FirstOrderGradients.__name__: dict(include_statistics=self.include_statistics)}
|
70
72
|
|
71
73
|
|
72
74
|
class SecondOrderGradients(LocalObserver):
|
@@ -130,7 +132,7 @@ class SecondOrderGradients(LocalObserver):
|
|
130
132
|
|
131
133
|
return {
|
132
134
|
statistic_trackers.SecondOrderGradients.__name__: dict(
|
133
|
-
|
135
|
+
include_statistics=self.include_statistics, compute_hessian_diagonal=self.compute_hessian_diagonal
|
134
136
|
)
|
135
137
|
}
|
136
138
|
|
@@ -178,7 +180,7 @@ class Activations(LocalObserver):
|
|
178
180
|
needed.
|
179
181
|
"""
|
180
182
|
|
181
|
-
return {statistic_trackers.ActivationStatistics.__name__: dict(
|
183
|
+
return {statistic_trackers.ActivationStatistics.__name__: dict(include_statistics=self.include_statistics)}
|
182
184
|
|
183
185
|
|
184
186
|
class ParameterUpdates(LocalObserver):
|
@@ -224,7 +226,7 @@ class ParameterUpdates(LocalObserver):
|
|
224
226
|
needed.
|
225
227
|
"""
|
226
228
|
|
227
|
-
return {statistic_trackers.ParameterUpdateStatistics.__name__: dict(
|
229
|
+
return {statistic_trackers.ParameterUpdateStatistics.__name__: dict(include_statistics=self.include_statistics)}
|
228
230
|
|
229
231
|
|
230
232
|
class Parameters(LocalObserver):
|
@@ -270,7 +272,7 @@ class Parameters(LocalObserver):
|
|
270
272
|
needed.
|
271
273
|
"""
|
272
274
|
|
273
|
-
return {statistic_trackers.ParameterStatistics.__name__: dict(
|
275
|
+
return {statistic_trackers.ParameterStatistics.__name__: dict(include_statistics=self.include_statistics)}
|
274
276
|
|
275
277
|
|
276
278
|
class LAMBTrustRatio(LocalObserver):
|
@@ -690,16 +692,16 @@ class ModuleTypeOneHot(LocalObserver):
|
|
690
692
|
|
691
693
|
class CurrentHyperparameters(LocalObserver):
|
692
694
|
|
693
|
-
def __init__(self,
|
695
|
+
def __init__(self, include_hparams: list[str] | None = None, **kwargs) -> None:
|
694
696
|
"""
|
695
|
-
:param
|
697
|
+
:param include_hparams: Names of the hyperparameters to include in the initial values vector returned by
|
696
698
|
this observation.
|
697
699
|
:param kwargs: Miscellaneous keyword arguments.
|
698
700
|
"""
|
699
701
|
|
700
702
|
super().__init__(**kwargs)
|
701
703
|
|
702
|
-
self.
|
704
|
+
self.include_hparams = include_hparams
|
703
705
|
|
704
706
|
@property
|
705
707
|
def can_standardize(self) -> bool:
|
@@ -723,11 +725,12 @@ class CurrentHyperparameters(LocalObserver):
|
|
723
725
|
:return: Length of the vector returned by this observation if it returns a vector.
|
724
726
|
"""
|
725
727
|
|
728
|
+
if self.include_hparams is None:
|
729
|
+
raise ValueError(f"{self.__class__.__name__} must be provided with include_hparams.")
|
730
|
+
|
726
731
|
available_hparams = HyperparameterStates.get_layerwise_hyperparameters()
|
727
732
|
|
728
|
-
return len(
|
729
|
-
[hparam for hparam in available_hparams if not any(skipped in hparam for skipped in self.skip_hparams)]
|
730
|
-
)
|
733
|
+
return len([hparam for hparam in available_hparams if hparam in self.include_hparams])
|
731
734
|
|
732
735
|
def _get_observation_format(self) -> StatisticStorageTypes:
|
733
736
|
"""
|
@@ -756,7 +759,7 @@ class CurrentHyperparameters(LocalObserver):
|
|
756
759
|
assert self.parameter_group_name is not None
|
757
760
|
|
758
761
|
current_internal_values = hyperparameter_states[self.parameter_group_name].get_current_internal_values(
|
759
|
-
|
762
|
+
include_hparams=self.include_hparams
|
760
763
|
)
|
761
764
|
|
762
765
|
self._cached_observation = current_internal_values
|
@@ -774,16 +777,16 @@ class CurrentHyperparameters(LocalObserver):
|
|
774
777
|
|
775
778
|
class CurrentHyperparameterDeltas(LocalObserver):
|
776
779
|
|
777
|
-
def __init__(self,
|
780
|
+
def __init__(self, include_hparams: list[str] | None = None, **kwargs) -> None:
|
778
781
|
"""
|
779
|
-
:param
|
782
|
+
:param include_hparams: Names of the hyperparameters to include in the initial deltas vector returned by
|
780
783
|
this observation.
|
781
784
|
:param kwargs: Miscellaneous keyword arguments.
|
782
785
|
"""
|
783
786
|
|
784
787
|
super().__init__(**kwargs)
|
785
788
|
|
786
|
-
self.
|
789
|
+
self.include_hparams = include_hparams
|
787
790
|
|
788
791
|
@property
|
789
792
|
def can_standardize(self) -> bool:
|
@@ -807,11 +810,12 @@ class CurrentHyperparameterDeltas(LocalObserver):
|
|
807
810
|
:return: Length of the vector returned by this observation if it returns a vector.
|
808
811
|
"""
|
809
812
|
|
813
|
+
if self.include_hparams is None:
|
814
|
+
raise ValueError(f"{self.__class__.__name__} must be provided with include_hparams.")
|
815
|
+
|
810
816
|
available_hparams = HyperparameterStates.get_layerwise_hyperparameters()
|
811
817
|
|
812
|
-
return len(
|
813
|
-
[hparam for hparam in available_hparams if not any(skipped in hparam for skipped in self.skip_hparams)]
|
814
|
-
)
|
818
|
+
return len([hparam for hparam in available_hparams if hparam in self.include_hparams])
|
815
819
|
|
816
820
|
def _get_observation_format(self) -> StatisticStorageTypes:
|
817
821
|
"""
|
@@ -838,9 +842,10 @@ class CurrentHyperparameterDeltas(LocalObserver):
|
|
838
842
|
"""
|
839
843
|
|
840
844
|
assert self.parameter_group_name is not None
|
845
|
+
assert self.include_hparams is not None
|
841
846
|
|
842
847
|
current_deltas = hyperparameter_states[self.parameter_group_name].get_current_deltas(
|
843
|
-
|
848
|
+
include_hparams=self.include_hparams
|
844
849
|
)
|
845
850
|
|
846
851
|
self._cached_observation = current_deltas
|
@@ -860,16 +865,16 @@ class HyperparameterTransformTypes(LocalObserver):
|
|
860
865
|
|
861
866
|
TRANSFORM_TYPE_TO_IDX = dict(((s, i) for i, s in enumerate(HyperparameterTransformType)))
|
862
867
|
|
863
|
-
def __init__(self,
|
868
|
+
def __init__(self, include_hparams: list[str] | None = None, **kwargs) -> None:
|
864
869
|
"""
|
865
|
-
:param
|
870
|
+
:param include_hparams: Names of the hyperparameters to include in the transforms vector returned by
|
866
871
|
this observation.
|
867
872
|
:param kwargs: Miscellaneous keyword arguments.
|
868
873
|
"""
|
869
874
|
|
870
875
|
super().__init__(**kwargs)
|
871
876
|
|
872
|
-
self.
|
877
|
+
self.include_hparams = include_hparams
|
873
878
|
|
874
879
|
@property
|
875
880
|
def can_standardize(self) -> bool:
|
@@ -893,10 +898,13 @@ class HyperparameterTransformTypes(LocalObserver):
|
|
893
898
|
:return: Length of the vector returned by this observation if it returns a vector.
|
894
899
|
"""
|
895
900
|
|
901
|
+
if self.include_hparams is None:
|
902
|
+
raise ValueError(f"{self.__class__.__name__} must be provided with include_hparams.")
|
903
|
+
|
896
904
|
available_hparams = HyperparameterStates.get_layerwise_hyperparameters()
|
897
905
|
|
898
906
|
return len(HyperparameterTransformType) * len(
|
899
|
-
[hparam for hparam in available_hparams if
|
907
|
+
[hparam for hparam in available_hparams if hparam in self.include_hparams]
|
900
908
|
)
|
901
909
|
|
902
910
|
def _get_observation_format(self) -> StatisticStorageTypes:
|
@@ -924,10 +932,11 @@ class HyperparameterTransformTypes(LocalObserver):
|
|
924
932
|
"""
|
925
933
|
|
926
934
|
assert self.parameter_group_name is not None
|
935
|
+
assert self.include_hparams is not None
|
927
936
|
|
928
937
|
parameter_group_hparams = hyperparameter_states[self.parameter_group_name]
|
929
938
|
hyperparameter_transform_types = parameter_group_hparams.get_hyperparameter_transform_types(
|
930
|
-
|
939
|
+
include_hparams=self.include_hparams
|
931
940
|
)
|
932
941
|
hyperparameter_transform_types_onehot_list = [
|
933
942
|
observation_utils.create_one_hot_observation(
|
@@ -1069,3 +1078,127 @@ class PercentageDepth(LocalObserver):
|
|
1069
1078
|
"""
|
1070
1079
|
|
1071
1080
|
return {}
|
1081
|
+
|
1082
|
+
|
1083
|
+
class LogOfNoiseScaleObserver(LocalObserver):
|
1084
|
+
|
1085
|
+
def __init__(
|
1086
|
+
self,
|
1087
|
+
*,
|
1088
|
+
decay_factor: float = LHOPT_CONSTANTS["DEFAULT_DECAY_FACTOR"],
|
1089
|
+
time_window: int = LHOPT_CONSTANTS["DEFAULT_TIME_WINDOW"],
|
1090
|
+
include_statistics: list[str] | None = None,
|
1091
|
+
**kwargs,
|
1092
|
+
) -> None:
|
1093
|
+
"""
|
1094
|
+
:param decay_factor: Decay factor for CDF calculation in [1, 2.5, 5, 10, 20]
|
1095
|
+
:param time_window: Number of time steps to consider for CDF calculation
|
1096
|
+
:param include_statistics: List of statistics to include.
|
1097
|
+
or use the squared first order gradients as approximations in the same way Adam does.
|
1098
|
+
:param kwargs: Miscellaneous keyword arguments.
|
1099
|
+
"""
|
1100
|
+
|
1101
|
+
super().__init__(**kwargs)
|
1102
|
+
|
1103
|
+
self.include_statistics = include_statistics
|
1104
|
+
self.decay_factor = max(0.0, decay_factor)
|
1105
|
+
self.time_window = max(1, time_window)
|
1106
|
+
|
1107
|
+
# Store time series data for CDF calculation
|
1108
|
+
self._time_series: list[tuple[float, float]] = [] # (time, value) pairs
|
1109
|
+
self._current_time: float = 0.0
|
1110
|
+
|
1111
|
+
def _get_observation_format(self) -> StatisticStorageTypes:
|
1112
|
+
"""
|
1113
|
+
:return: Format the observation returns data in. Must be one of the enum attributes in the StatisticStorageTypes
|
1114
|
+
enumeration class.
|
1115
|
+
"""
|
1116
|
+
|
1117
|
+
return StatisticStorageTypes.VECTOR
|
1118
|
+
|
1119
|
+
@property
|
1120
|
+
def can_standardize(self) -> bool:
|
1121
|
+
"""
|
1122
|
+
:return: Whether the observation can be standardized.
|
1123
|
+
"""
|
1124
|
+
|
1125
|
+
return False
|
1126
|
+
|
1127
|
+
@property
|
1128
|
+
def can_inform(self) -> bool:
|
1129
|
+
"""
|
1130
|
+
:return: Whether observations from the observer can be used in the agent info dictionary.
|
1131
|
+
"""
|
1132
|
+
|
1133
|
+
return False
|
1134
|
+
|
1135
|
+
def _update_time(self) -> None:
|
1136
|
+
"""Update the current time counter."""
|
1137
|
+
self._current_time += 1.0
|
1138
|
+
|
1139
|
+
def _compute_cdf_feature(self, value: float) -> float:
|
1140
|
+
"""
|
1141
|
+
Compute CDF feature for the given value.
|
1142
|
+
training loss will be added to the time series after this call.
|
1143
|
+
:param value: The value to compute CDF feature for
|
1144
|
+
:return: CDF feature value
|
1145
|
+
"""
|
1146
|
+
return compute_cdf_feature(value, self._time_series, self.decay_factor, self._current_time, self.time_window)
|
1147
|
+
|
1148
|
+
@property
|
1149
|
+
def vector_length(self) -> int:
|
1150
|
+
"""
|
1151
|
+
:return: Length of the vector returned by this observation if it returns a vector.
|
1152
|
+
"""
|
1153
|
+
return 2 # [log_noise_scale, cdf_feature]
|
1154
|
+
|
1155
|
+
def _observe(
|
1156
|
+
self,
|
1157
|
+
observation_inputs: ObservationInputs,
|
1158
|
+
hyperparameter_states: HyperparameterStates,
|
1159
|
+
tracked_statistics: dict[str, dict[str, float | TensorStatistics]],
|
1160
|
+
action_taken: float | int | None,
|
1161
|
+
) -> float | int | list[int | float] | TensorStatistics:
|
1162
|
+
"""
|
1163
|
+
:param observation_inputs: Observation input metrics not calculated with statistic trackers.
|
1164
|
+
:param hyperparameter_states: HyperparameterStates that manages the hyperparameters.
|
1165
|
+
:param tracked_statistics: Dictionary mapping statistic tracker class names to dictionaries mapping module
|
1166
|
+
names to floats or TensorStatistic models.
|
1167
|
+
:param action_taken: Action taken by the agent this class instance is assigned to.
|
1168
|
+
:return: Single float/int, list of floats/ints or TensorStatistics model to add to the observation vector.
|
1169
|
+
"""
|
1170
|
+
|
1171
|
+
statistics = tracked_statistics[statistic_trackers.LogOfNoiseScaleStatistics.__name__]
|
1172
|
+
|
1173
|
+
raw_value = list(statistics.values())[0] # type: ignore[list-item]
|
1174
|
+
assert isinstance(raw_value, float), f"Expected float, got {type(raw_value)}" # to avoid type errors with mypy
|
1175
|
+
batch_size = hyperparameter_states.global_hparams.batch_size.external_value
|
1176
|
+
learning_rate = hyperparameter_states.parameter_group_hparams[
|
1177
|
+
self.parameter_group_name
|
1178
|
+
].learning_rate.external_value
|
1179
|
+
|
1180
|
+
log_b_over_epsilon = math.log(batch_size / learning_rate)
|
1181
|
+
|
1182
|
+
log_noise_scale = raw_value + log_b_over_epsilon
|
1183
|
+
|
1184
|
+
cdf_feature = self._compute_cdf_feature(log_noise_scale) # type: ignore[arg-type]
|
1185
|
+
self._update_time()
|
1186
|
+
|
1187
|
+
return [log_noise_scale, cdf_feature] # type: ignore[list-item]
|
1188
|
+
|
1189
|
+
def get_required_trackers(self) -> dict[str, dict[str, Any] | None]:
|
1190
|
+
"""
|
1191
|
+
:return: Dictionary mapping statistic tracker class names to kwargs for the class or None if no kwargs are
|
1192
|
+
needed.
|
1193
|
+
"""
|
1194
|
+
|
1195
|
+
return {
|
1196
|
+
statistic_trackers.LogOfNoiseScaleStatistics.__name__: dict(
|
1197
|
+
include_statistics=self.include_statistics, sample_frequency=LHOPT_CONSTANTS["DEFAULT_SAMPLE_FREQUENCY"]
|
1198
|
+
)
|
1199
|
+
}
|
1200
|
+
|
1201
|
+
def reset(self) -> None:
|
1202
|
+
"""Reset the observer by clearing the time series."""
|
1203
|
+
self._time_series = []
|
1204
|
+
self._current_time = 0.0
|