libinephany 0.18.1__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -59,7 +59,7 @@ class GlobalActivations(GlobalObserver):
             needed.
         """
 
-        return {statistic_trackers.ActivationStatistics.__name__: dict(skip_statistics=self.skip_statistics)}
+        return {statistic_trackers.ActivationStatistics.__name__: dict(include_statistics=self.include_statistics)}
 
 
 class GlobalParameterUpdates(GlobalObserver):
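
The recurring change in these hunks is a flip from exclusion to inclusion semantics in the tracker kwargs: `skip_statistics` listed statistics to drop, while `include_statistics` lists the statistics to keep. A minimal sketch of the difference, using hypothetical statistic names and plain string filtering rather than the library's tracker classes:

    # Illustrative only: the statistic names and filter helpers below are made up,
    # not libinephany APIs; they just contrast the 0.18.1 and 1.0.0 semantics.
    available = ["norm", "mean", "variance", "kurtosis"]

    def old_filter(skip_statistics: list[str]) -> list[str]:
        # 0.18.1 style: keep everything except the explicitly skipped names.
        return [name for name in available if name not in skip_statistics]

    def new_filter(include_statistics: list[str] | None) -> list[str]:
        # 1.0.0 style: keep only the explicitly requested names (None keeps all).
        if include_statistics is None:
            return list(available)
        return [name for name in available if name in include_statistics]

    assert old_filter(["kurtosis"]) == ["norm", "mean", "variance"]
    assert new_filter(["norm", "mean"]) == ["norm", "mean"]
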
@@ -98,7 +98,7 @@ class GlobalParameterUpdates(GlobalObserver):
             needed.
         """
 
-        return {statistic_trackers.ParameterUpdateStatistics.__name__: dict(skip_statistics=self.skip_statistics)}
+        return {statistic_trackers.ParameterUpdateStatistics.__name__: dict(include_statistics=self.include_statistics)}
 
 
 class GlobalParameters(GlobalObserver):
@@ -137,7 +137,7 @@ class GlobalParameters(GlobalObserver):
             needed.
         """
 
-        return {statistic_trackers.ParameterStatistics.__name__: dict(skip_statistics=self.skip_statistics)}
+        return {statistic_trackers.ParameterStatistics.__name__: dict(include_statistics=self.include_statistics)}
 
 
 class GlobalLAMBTrustRatio(GlobalObserver):
@@ -385,7 +385,7 @@ class LogRatioOfPreviousAndCurrentParamNormEnvStepObserver(LHOPTBaseObserver):
         """
 
         return {
-            statistic_trackers.ParameterStatistics.__name__: dict(skip_statistics=self.skip_statistics),
+            statistic_trackers.ParameterStatistics.__name__: dict(include_statistics=self.include_statistics),
         }
 
     def reset(self) -> None:
@@ -436,6 +436,146 @@ class LogRatioOfUpdateAndPreviousParamNormEnvStepObserver(LHOPTBaseObserver):
             tensor_statistics=[stats for stats in param_statistics.values() if isinstance(stats, TensorStatistics)]
         ).norm_
 
+        if self._previous_param_norm is None:
+            self._previous_param_norm = current_param_norm
+            self._compute_cdf_feature(0.0) # default value since we can't compute log ratio yet
+            self._update_time()
+            return [0.0, 0.0]
+
+        log_ratio = self._compute_log_ratio(update_norm, self._previous_param_norm)
+        tanh_feature = math.tanh(max(-LHOPT_CONSTANTS["TANH_BOUND"], min(LHOPT_CONSTANTS["TANH_BOUND"], log_ratio)))
+        cdf_feature = self._compute_cdf_feature(log_ratio)
+
+        self._update_time()
+        self._previous_param_norm = current_param_norm
+
+        return [tanh_feature, cdf_feature]
+
+    def get_required_trackers(self) -> dict[str, dict[str, Any] | None]:
+        """
+        :return: Dictionary mapping statistic tracker class names to kwargs for the class or None if no kwargs are
+            needed.
+        """
+
+        return {
+            statistic_trackers.ParameterUpdateStatistics.__name__: dict(include_statistics=self.include_statistics),
+            statistic_trackers.ParameterStatistics.__name__: dict(include_statistics=self.include_statistics),
+        }
+
+    def reset(self) -> None:
+        """
+        Reset the observer by clearing the previous parameter norm and time series.
+        """
+
+        super().reset()
+        self._previous_param_norm = None
+
+
+class LHOPTAverageParameterUpdateMagnitudeObserver(LHOPTBaseObserver):
+
+    def _get_observation_format(self) -> StatisticStorageTypes:
+        """
+        :return: Format the observation returns data in. Must be one of the enum attributes in the StatisticStorageTypes
+            enumeration class.
+        """
+
+        return StatisticStorageTypes.VECTOR
+
+    @property
+    def can_standardize(self) -> bool:
+        """
+        :return: Whether the observation can be standardized.
+        """
+
+        return False
+
+    @property
+    def vector_length(self) -> int:
+        """
+        :return: Length of the vector returned by this observation if it returns a vector.
+        """
+        return 2 # [raw_feature, cdf_feature]
+
+    def _observe(
+        self,
+        observation_inputs: ObservationInputs,
+        hyperparameter_states: HyperparameterStates,
+        tracked_statistics: dict[str, dict[str, float | TensorStatistics]],
+        action_taken: float | int | None,
+    ) -> float | int | list[int | float] | TensorStatistics:
+        """
+        :param observation_inputs: Observation input metrics not calculated with statistic trackers.
+        :param hyperparameter_states: HyperparameterStates that manages the hyperparameters.
+        :param tracked_statistics: Dictionary mapping statistic tracker class names to dictionaries mapping module
+            names to floats or TensorStatistic models.
+        :param action_taken: Action taken by the agent this class instance is assigned to.
+        :return: Single float/int, list of floats/ints or TensorStatistics model to add to the observation vector.
+        """
+
+        statistics = tracked_statistics[statistic_trackers.AverageParameterUpdateMagnitudeStatistics.__name__]
+
+        raw_feature = list(statistics.values())[0] # type: ignore[list-item]
+
+        cdf_feature = self._compute_cdf_feature(raw_feature) # type: ignore[arg-type]
+        self._update_time()
+
+        return [raw_feature, cdf_feature] # type: ignore[list-item]
+
+    def get_required_trackers(self) -> dict[str, dict[str, Any] | None]:
+        """
+        :return: Dictionary mapping statistic tracker class names to kwargs for the class or None if no kwargs are
+            needed.
+        """
+
+        return {
+            statistic_trackers.AverageParameterUpdateMagnitudeStatistics.__name__: dict(
+                include_statistics=self.include_statistics, sample_frequency=LHOPT_CONSTANTS["DEFAULT_SAMPLE_FREQUENCY"]
+            )
+        }
+
+
+class LogRatioOfUpdateAndPreviousParamNormInnerStepObserver(LHOPTBaseObserver):
+    def __init__(self, **kwargs):
+        """
+        This observer is used to compute the log ratio of the update and previous parameter norm for the inner step. The sample frequency of the statistics needs to be set to 4 (according to the OpenAI paper).
+
+        """
+        super().__init__(**kwargs)
+        self._previous_param_norm = None
+
+    @property
+    def vector_length(self) -> int:
+        """
+        :return: Length of the vector returned by this observation if it returns a vector.
+        """
+        return 2 # [tanh_feature, cdf_feature]
+
+    def _observe(
+        self,
+        observation_inputs: ObservationInputs,
+        hyperparameter_states: HyperparameterStates,
+        tracked_statistics: dict[str, dict[str, float | TensorStatistics]],
+        action_taken: float | int | None,
+    ) -> float | int | list[int | float] | TensorStatistics:
+        """
+        :param observation_inputs: Observation input metrics not calculated with statistic trackers.
+        :param hyperparameter_states: HyperparameterStates that manages the hyperparameters.
+        :param tracked_statistics: Dictionary mapping statistic tracker class names to dictionaries mapping module
+            names to floats or TensorStatistics models.
+        :param action_taken: Action taken by the agent this class instance is assigned to.
+        :return: List containing [raw_log_ratio, cdf_feature].
+        """
+
+        update_statistics = tracked_statistics[statistic_trackers.LHOPTParameterUpdateStatistics.__name__]
+        param_statistics = tracked_statistics[statistic_trackers.LHOPTParameterStatistics.__name__]
+        update_norm = observation_utils.average_tensor_statistics(
+            tensor_statistics=[stats for stats in update_statistics.values() if isinstance(stats, TensorStatistics)]
+        ).norm_
+
+        current_param_norm = observation_utils.average_tensor_statistics(
+            tensor_statistics=[stats for stats in param_statistics.values() if isinstance(stats, TensorStatistics)]
+        ).norm_
+
         if self._previous_param_norm is None:
             self._previous_param_norm = current_param_norm
             self._compute_cdf_feature(0.0) # default value since we can't compute log ratio yet
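
The added env-step and inner-step observers share the same feature recipe: compute a log ratio of the update norm to the previous parameter norm, clamp it to ±LHOPT_CONSTANTS["TANH_BOUND"], squash it with tanh, and pair the result with a CDF feature. A self-contained sketch of the clamp-and-squash step, assuming the log ratio is ln(update_norm / previous_param_norm) and a bound of 10.0; the real bound and `_compute_log_ratio` live outside this diff:

    import math

    TANH_BOUND = 10.0  # assumed stand-in for LHOPT_CONSTANTS["TANH_BOUND"]

    def tanh_feature(update_norm: float, previous_param_norm: float) -> float:
        # log ratio of the update norm to the previous parameter norm (assumed definition)
        log_ratio = math.log(update_norm / previous_param_norm)
        # clamp before tanh so extreme ratios saturate instead of overflowing
        clamped = max(-TANH_BOUND, min(TANH_BOUND, log_ratio))
        return math.tanh(clamped)

    print(tanh_feature(1e-3, 1.0))  # tiny update relative to the weights -> feature near -1
    print(tanh_feature(1.0, 1.0))   # equal norms -> feature 0.0
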
@@ -456,8 +596,12 @@ class LogRatioOfUpdateAndPreviousParamNormEnvStepObserver(LHOPTBaseObserver):
         """
 
         return {
-            statistic_trackers.ParameterUpdateStatistics.__name__: dict(skip_statistics=self.skip_statistics),
-            statistic_trackers.ParameterStatistics.__name__: dict(skip_statistics=self.skip_statistics),
+            statistic_trackers.LHOPTParameterUpdateStatistics.__name__: dict(
+                include_statistics=self.include_statistics, sample_frequency=LHOPT_CONSTANTS["DEFAULT_SAMPLE_FREQUENCY"]
+            ),
+            statistic_trackers.LHOPTParameterStatistics.__name__: dict(
+                include_statistics=self.include_statistics, sample_frequency=LHOPT_CONSTANTS["DEFAULT_SAMPLE_FREQUENCY"]
+            ),
         }
 
     def reset(self) -> None:
@@ -467,3 +611,73 @@ class LogRatioOfUpdateAndPreviousParamNormEnvStepObserver(LHOPTBaseObserver):
 
         super().reset()
         self._previous_param_norm = None
+
+
+class LHOPTGlobalLAMBTrustRatio(LHOPTBaseObserver):
+
+    def __init__(
+        self,
+        *,
+        use_log_transform: bool = False,
+        **kwargs,
+    ) -> None:
+        """
+        :param use_log_transform: Whether to transform the LAMB trust ratio by taking ln(1 + R).
+        :param kwargs: Other observation keyword arguments.
+        """
+
+        super().__init__(**kwargs)
+
+        self.use_log_transform = use_log_transform
+
+    def _get_observation_format(self) -> StatisticStorageTypes:
+        """
+        :return: Format the observation returns data in. Must be one of the enum attributes in the StatisticStorageTypes
+            enumeration class.
+        """
+
+        return StatisticStorageTypes.VECTOR
+
+    @property
+    def vector_length(self) -> int:
+        """
+        :return: Length of the vector returned by this observation if it returns a vector.
+        """
+        return 2 # [raw_value, cdf_feature]
+
+    def _observe(
+        self,
+        observation_inputs: ObservationInputs,
+        hyperparameter_states: HyperparameterStates,
+        tracked_statistics: dict[str, dict[str, float | TensorStatistics]],
+        action_taken: float | int | None,
+    ) -> float | int | list[int | float] | TensorStatistics:
+        """
+        :param observation_inputs: Observation input metrics not calculated with statistic trackers.
+        :param hyperparameter_states: HyperparameterStates that manages the hyperparameters.
+        :param tracked_statistics: Dictionary mapping statistic tracker class names to dictionaries mapping module
+            names to floats or TensorStatistic models.
+        :param action_taken: Action taken by the agent this class instance is assigned to.
+        :return: Single float/int, list of floats/ints or TensorStatistics model to add to the observation vector.
+        """
+
+        statistics = tracked_statistics[statistic_trackers.LHOPTLAMBTrustRatioStatistics.__name__]
+
+        raw_value = sum(statistics.values()) / len(statistics) # type: ignore[arg-type]
+        cdf_feature = self._compute_cdf_feature(raw_value) # type: ignore[arg-type]
+        self._update_time()
+        return [raw_value, cdf_feature] # type: ignore[list-item]
+
+    def get_required_trackers(self) -> dict[str, dict[str, Any] | None]:
+        """
+        :return: Dictionary mapping statistic tracker class names to kwargs for the class or None if no kwargs are
+            needed.
+        """
+
+        return {
+            statistic_trackers.LHOPTLAMBTrustRatioStatistics.__name__: dict(
+                include_statistics=self.include_statistics,
+                use_log_transform=self.use_log_transform,
+                sample_frequency=LHOPT_CONSTANTS["DEFAULT_SAMPLE_FREQUENCY"],
+            )
+        }
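
The new LHOPTGlobalLAMBTrustRatio observer averages the tracked per-module trust ratios into one value and pairs it with a CDF feature; the ln(1 + R) transform itself is delegated to the tracker via the use_log_transform kwarg. A rough sketch of the aggregation, folding the optional transform into the same function for illustration and assuming the tracker hands back a plain module-name-to-float mapping:

    import math

    def lamb_trust_ratio_feature(per_module_ratios: dict[str, float], use_log_transform: bool) -> float:
        # optional ln(1 + R) transform, mirroring the use_log_transform flag
        values = [math.log1p(r) if use_log_transform else r for r in per_module_ratios.values()]
        # the global observation is the mean over modules
        return sum(values) / len(values)

    ratios = {"layer1": 0.8, "layer2": 1.2}  # hypothetical per-module trust ratios
    print(lamb_trust_ratio_feature(ratios, use_log_transform=False))  # 1.0
    print(lamb_trust_ratio_feature(ratios, use_log_transform=True))   # ~0.69
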
@@ -1,6 +1,6 @@
 # ======================================================================================================================
 #
-# imports
+# IMPORTS
 #
 # ======================================================================================================================
 
@@ -13,6 +13,12 @@ from libinephany.pydantic_models.schemas.observation_models import ObservationIn
 from libinephany.pydantic_models.schemas.tensor_statistics import TensorStatistics
 from libinephany.pydantic_models.states.hyperparameter_states import HyperparameterStates
 
+# ======================================================================================================================
+#
+# CLASSES
+#
+# ======================================================================================================================
+
 
 class TrainingProgress(GlobalObserver):
 
@@ -4,11 +4,13 @@
 #
 # ======================================================================================================================
 
+import math
 from typing import Any
 
 from libinephany.observations import observation_utils, statistic_trackers
-from libinephany.observations.observation_utils import StatisticStorageTypes
+from libinephany.observations.observation_utils import StatisticStorageTypes, compute_cdf_feature
 from libinephany.observations.observers.base_observers import LocalObserver
+from libinephany.observations.observers.global_observers.constants import LHOPT_CONSTANTS
 from libinephany.pydantic_models.schemas.observation_models import ObservationInputs
 from libinephany.pydantic_models.schemas.tensor_statistics import TensorStatistics
 from libinephany.pydantic_models.states.hyperparameter_states import HyperparameterStates
@@ -66,7 +68,7 @@ class FirstOrderGradients(LocalObserver):
             needed.
         """
 
-        return {statistic_trackers.FirstOrderGradients.__name__: dict(skip_statistics=self.skip_statistics)}
+        return {statistic_trackers.FirstOrderGradients.__name__: dict(include_statistics=self.include_statistics)}
 
 
 class SecondOrderGradients(LocalObserver):
@@ -130,7 +132,7 @@ class SecondOrderGradients(LocalObserver):
 
         return {
             statistic_trackers.SecondOrderGradients.__name__: dict(
-                skip_statistics=self.skip_statistics, compute_hessian_diagonal=self.compute_hessian_diagonal
+                include_statistics=self.include_statistics, compute_hessian_diagonal=self.compute_hessian_diagonal
             )
         }
 
@@ -178,7 +180,7 @@ class Activations(LocalObserver):
             needed.
         """
 
-        return {statistic_trackers.ActivationStatistics.__name__: dict(skip_statistics=self.skip_statistics)}
+        return {statistic_trackers.ActivationStatistics.__name__: dict(include_statistics=self.include_statistics)}
 
 
 class ParameterUpdates(LocalObserver):
@@ -224,7 +226,7 @@ class ParameterUpdates(LocalObserver):
             needed.
         """
 
-        return {statistic_trackers.ParameterUpdateStatistics.__name__: dict(skip_statistics=self.skip_statistics)}
+        return {statistic_trackers.ParameterUpdateStatistics.__name__: dict(include_statistics=self.include_statistics)}
 
 
 class Parameters(LocalObserver):
@@ -270,7 +272,7 @@ class Parameters(LocalObserver):
             needed.
         """
 
-        return {statistic_trackers.ParameterStatistics.__name__: dict(skip_statistics=self.skip_statistics)}
+        return {statistic_trackers.ParameterStatistics.__name__: dict(include_statistics=self.include_statistics)}
 
 
 class LAMBTrustRatio(LocalObserver):
@@ -690,16 +692,16 @@ class ModuleTypeOneHot(LocalObserver):
 
 class CurrentHyperparameters(LocalObserver):
 
-    def __init__(self, skip_hparams: list[str] | None = None, **kwargs) -> None:
+    def __init__(self, include_hparams: list[str] | None = None, **kwargs) -> None:
         """
-        :param skip_hparams: Names of the hyperparameters to not include in the initial values vector returned by
+        :param include_hparams: Names of the hyperparameters to include in the initial values vector returned by
             this observation.
         :param kwargs: Miscellaneous keyword arguments.
         """
 
         super().__init__(**kwargs)
 
-        self.skip_hparams = skip_hparams if skip_hparams is not None else []
+        self.include_hparams = include_hparams
 
     @property
     def can_standardize(self) -> bool:
@@ -723,11 +725,12 @@ class CurrentHyperparameters(LocalObserver):
         :return: Length of the vector returned by this observation if it returns a vector.
         """
 
+        if self.include_hparams is None:
+            raise ValueError(f"{self.__class__.__name__} must be provided with include_hparams.")
+
         available_hparams = HyperparameterStates.get_layerwise_hyperparameters()
 
-        return len(
-            [hparam for hparam in available_hparams if not any(skipped in hparam for skipped in self.skip_hparams)]
-        )
+        return len([hparam for hparam in available_hparams if hparam in self.include_hparams])
 
     def _get_observation_format(self) -> StatisticStorageTypes:
         """
@@ -756,7 +759,7 @@ class CurrentHyperparameters(LocalObserver):
         assert self.parameter_group_name is not None
 
         current_internal_values = hyperparameter_states[self.parameter_group_name].get_current_internal_values(
-            skip_hparams=self.skip_hparams
+            include_hparams=self.include_hparams
         )
 
         self._cached_observation = current_internal_values
@@ -774,16 +777,16 @@ class CurrentHyperparameters(LocalObserver):
 
 class CurrentHyperparameterDeltas(LocalObserver):
 
-    def __init__(self, skip_hparams: list[str] | None = None, **kwargs) -> None:
+    def __init__(self, include_hparams: list[str] | None = None, **kwargs) -> None:
         """
-        :param skip_hparams: Names of the hyperparameters to not include in the initial deltas vector returned by
+        :param include_hparams: Names of the hyperparameters to include in the initial deltas vector returned by
             this observation.
         :param kwargs: Miscellaneous keyword arguments.
         """
 
         super().__init__(**kwargs)
 
-        self.skip_hparams = skip_hparams if skip_hparams is not None else []
+        self.include_hparams = include_hparams
 
     @property
     def can_standardize(self) -> bool:
@@ -807,11 +810,12 @@ class CurrentHyperparameterDeltas(LocalObserver):
         :return: Length of the vector returned by this observation if it returns a vector.
         """
 
+        if self.include_hparams is None:
+            raise ValueError(f"{self.__class__.__name__} must be provided with include_hparams.")
+
         available_hparams = HyperparameterStates.get_layerwise_hyperparameters()
 
-        return len(
-            [hparam for hparam in available_hparams if not any(skipped in hparam for skipped in self.skip_hparams)]
-        )
+        return len([hparam for hparam in available_hparams if hparam in self.include_hparams])
 
     def _get_observation_format(self) -> StatisticStorageTypes:
         """
@@ -838,9 +842,10 @@ class CurrentHyperparameterDeltas(LocalObserver):
         """
 
         assert self.parameter_group_name is not None
+        assert self.include_hparams is not None
 
         current_deltas = hyperparameter_states[self.parameter_group_name].get_current_deltas(
-            skip_hparams=self.skip_hparams
+            include_hparams=self.include_hparams
         )
 
         self._cached_observation = current_deltas
@@ -860,16 +865,16 @@ class HyperparameterTransformTypes(LocalObserver):
 
     TRANSFORM_TYPE_TO_IDX = dict(((s, i) for i, s in enumerate(HyperparameterTransformType)))
 
-    def __init__(self, skip_hparams: list[str] | None = None, **kwargs) -> None:
+    def __init__(self, include_hparams: list[str] | None = None, **kwargs) -> None:
         """
-        :param skip_hparams: Names of the hyperparameters to not include in the transforms vector returned by
+        :param include_hparams: Names of the hyperparameters to include in the transforms vector returned by
             this observation.
         :param kwargs: Miscellaneous keyword arguments.
         """
 
         super().__init__(**kwargs)
 
-        self.skip_hparams = skip_hparams if skip_hparams is not None else []
+        self.include_hparams = include_hparams
 
     @property
     def can_standardize(self) -> bool:
@@ -893,10 +898,13 @@ class HyperparameterTransformTypes(LocalObserver):
         :return: Length of the vector returned by this observation if it returns a vector.
         """
 
+        if self.include_hparams is None:
+            raise ValueError(f"{self.__class__.__name__} must be provided with include_hparams.")
+
         available_hparams = HyperparameterStates.get_layerwise_hyperparameters()
 
         return len(HyperparameterTransformType) * len(
-            [hparam for hparam in available_hparams if not any(skipped in hparam for skipped in self.skip_hparams)]
+            [hparam for hparam in available_hparams if hparam in self.include_hparams]
         )
 
     def _get_observation_format(self) -> StatisticStorageTypes:
@@ -924,10 +932,11 @@ class HyperparameterTransformTypes(LocalObserver):
         """
 
         assert self.parameter_group_name is not None
+        assert self.include_hparams is not None
 
         parameter_group_hparams = hyperparameter_states[self.parameter_group_name]
         hyperparameter_transform_types = parameter_group_hparams.get_hyperparameter_transform_types(
-            skip_hparams=self.skip_hparams
+            include_hparams=self.include_hparams
         )
         hyperparameter_transform_types_onehot_list = [
             observation_utils.create_one_hot_observation(
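
HyperparameterTransformTypes emits one one-hot block per included hyperparameter, so its vector length is len(HyperparameterTransformType) multiplied by the number of included names. A worked sketch with an assumed three-member enum and made-up hyperparameter names, neither taken from the library:

    from enum import Enum

    class HyperparameterTransformType(Enum):  # assumed members, for illustration only
        LINEAR = 0
        LOG = 1
        EXPONENTIAL = 2

    available_hparams = ["learning_rate", "weight_decay", "dropout"]  # hypothetical names
    include_hparams = ["learning_rate", "dropout"]

    # one one-hot vector of width len(enum) per included hyperparameter
    vector_length = len(HyperparameterTransformType) * len(
        [hparam for hparam in available_hparams if hparam in include_hparams]
    )
    print(vector_length)  # 3 transform types * 2 included hyperparameters = 6
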
@@ -1069,3 +1078,127 @@ class PercentageDepth(LocalObserver):
         """
 
         return {}
+
+
+class LogOfNoiseScaleObserver(LocalObserver):
+
+    def __init__(
+        self,
+        *,
+        decay_factor: float = LHOPT_CONSTANTS["DEFAULT_DECAY_FACTOR"],
+        time_window: int = LHOPT_CONSTANTS["DEFAULT_TIME_WINDOW"],
+        include_statistics: list[str] | None = None,
+        **kwargs,
+    ) -> None:
+        """
+        :param decay_factor: Decay factor for CDF calculation in [1, 2.5, 5, 10, 20]
+        :param time_window: Number of time steps to consider for CDF calculation
+        :param include_statistics: List of statistics to include.
+            or use the squared first order gradients as approximations in the same way Adam does.
+        :param kwargs: Miscellaneous keyword arguments.
+        """
+
+        super().__init__(**kwargs)
+
+        self.include_statistics = include_statistics
+        self.decay_factor = max(0.0, decay_factor)
+        self.time_window = max(1, time_window)
+
+        # Store time series data for CDF calculation
+        self._time_series: list[tuple[float, float]] = [] # (time, value) pairs
+        self._current_time: float = 0.0
+
+    def _get_observation_format(self) -> StatisticStorageTypes:
+        """
+        :return: Format the observation returns data in. Must be one of the enum attributes in the StatisticStorageTypes
+            enumeration class.
+        """
+
+        return StatisticStorageTypes.VECTOR
+
+    @property
+    def can_standardize(self) -> bool:
+        """
+        :return: Whether the observation can be standardized.
+        """
+
+        return False
+
+    @property
+    def can_inform(self) -> bool:
+        """
+        :return: Whether observations from the observer can be used in the agent info dictionary.
+        """
+
+        return False
+
+    def _update_time(self) -> None:
+        """Update the current time counter."""
+        self._current_time += 1.0
+
+    def _compute_cdf_feature(self, value: float) -> float:
+        """
+        Compute CDF feature for the given value.
+        training loss will be added to the time series after this call.
+        :param value: The value to compute CDF feature for
+        :return: CDF feature value
+        """
+        return compute_cdf_feature(value, self._time_series, self.decay_factor, self._current_time, self.time_window)
+
+    @property
+    def vector_length(self) -> int:
+        """
+        :return: Length of the vector returned by this observation if it returns a vector.
+        """
+        return 2 # [log_noise_scale, cdf_feature]
+
+    def _observe(
+        self,
+        observation_inputs: ObservationInputs,
+        hyperparameter_states: HyperparameterStates,
+        tracked_statistics: dict[str, dict[str, float | TensorStatistics]],
+        action_taken: float | int | None,
+    ) -> float | int | list[int | float] | TensorStatistics:
+        """
+        :param observation_inputs: Observation input metrics not calculated with statistic trackers.
+        :param hyperparameter_states: HyperparameterStates that manages the hyperparameters.
+        :param tracked_statistics: Dictionary mapping statistic tracker class names to dictionaries mapping module
+            names to floats or TensorStatistic models.
+        :param action_taken: Action taken by the agent this class instance is assigned to.
+        :return: Single float/int, list of floats/ints or TensorStatistics model to add to the observation vector.
+        """
+
+        statistics = tracked_statistics[statistic_trackers.LogOfNoiseScaleStatistics.__name__]
+
+        raw_value = list(statistics.values())[0] # type: ignore[list-item]
+        assert isinstance(raw_value, float), f"Expected float, got {type(raw_value)}" # to avoid type errors with mypy
+        batch_size = hyperparameter_states.global_hparams.batch_size.external_value
+        learning_rate = hyperparameter_states.parameter_group_hparams[
+            self.parameter_group_name
+        ].learning_rate.external_value
+
+        log_b_over_epsilon = math.log(batch_size / learning_rate)
+
+        log_noise_scale = raw_value + log_b_over_epsilon
+
+        cdf_feature = self._compute_cdf_feature(log_noise_scale) # type: ignore[arg-type]
+        self._update_time()
+
+        return [log_noise_scale, cdf_feature] # type: ignore[list-item]
+
+    def get_required_trackers(self) -> dict[str, dict[str, Any] | None]:
+        """
+        :return: Dictionary mapping statistic tracker class names to kwargs for the class or None if no kwargs are
+            needed.
+        """
+
+        return {
+            statistic_trackers.LogOfNoiseScaleStatistics.__name__: dict(
+                include_statistics=self.include_statistics, sample_frequency=LHOPT_CONSTANTS["DEFAULT_SAMPLE_FREQUENCY"]
+            )
+        }
+
+    def reset(self) -> None:
+        """Reset the observer by clearing the time series."""
+        self._time_series = []
+        self._current_time = 0.0
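
LogOfNoiseScaleObserver combines the tracked statistic with the current batch size B and learning rate ε as log_noise_scale = raw_value + ln(B / ε). A worked example of just that arithmetic, with illustrative numbers standing in for the tracker output and the hyperparameter state lookups:

    import math

    # Assumed inputs: raw_value would come from LogOfNoiseScaleStatistics, batch_size and
    # learning_rate from the hyperparameter states; the numbers here are purely illustrative.
    raw_value = -2.0        # log-domain gradient-noise statistic reported by the tracker
    batch_size = 256
    learning_rate = 1e-3

    log_b_over_epsilon = math.log(batch_size / learning_rate)  # ln(256 / 0.001) ~= 12.45
    log_noise_scale = raw_value + log_b_over_epsilon            # ~= 10.45

    print(round(log_noise_scale, 2))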