libinephany 0.18.1__tar.gz → 0.19.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (73)
  1. libinephany-0.19.0/CODE_VERSION.cfg +1 -0
  2. {libinephany-0.18.1/libinephany.egg-info → libinephany-0.19.0}/PKG-INFO +1 -1
  3. {libinephany-0.18.1 → libinephany-0.19.0}/libinephany/observations/observation_utils.py +19 -2
  4. {libinephany-0.18.1 → libinephany-0.19.0}/libinephany/observations/observers/global_observers/__init__.py +19 -1
  5. {libinephany-0.18.1 → libinephany-0.19.0}/libinephany/observations/observers/global_observers/constants.py +2 -0
  6. libinephany-0.19.0/libinephany/observations/observers/global_observers/gradient_observers.py +511 -0
  7. {libinephany-0.18.1 → libinephany-0.19.0}/libinephany/observations/observers/global_observers/model_observers.py +219 -3
  8. {libinephany-0.18.1 → libinephany-0.19.0}/libinephany/observations/observers/local_observers.py +127 -1
  9. {libinephany-0.18.1 → libinephany-0.19.0}/libinephany/observations/statistic_trackers.py +595 -0
  10. {libinephany-0.18.1 → libinephany-0.19.0/libinephany.egg-info}/PKG-INFO +1 -1
  11. libinephany-0.18.1/CODE_VERSION.cfg +0 -1
  12. libinephany-0.18.1/libinephany/observations/observers/global_observers/gradient_observers.py +0 -193
  13. {libinephany-0.18.1 → libinephany-0.19.0}/LICENSE +0 -0
  14. {libinephany-0.18.1 → libinephany-0.19.0}/MANIFEST.in +0 -0
  15. {libinephany-0.18.1 → libinephany-0.19.0}/README.md +0 -0
  16. {libinephany-0.18.1 → libinephany-0.19.0}/libinephany/__init__.py +0 -0
  17. {libinephany-0.18.1 → libinephany-0.19.0}/libinephany/aws/__init__.py +0 -0
  18. {libinephany-0.18.1 → libinephany-0.19.0}/libinephany/aws/s3_functions.py +0 -0
  19. {libinephany-0.18.1 → libinephany-0.19.0}/libinephany/observations/__init__.py +0 -0
  20. {libinephany-0.18.1 → libinephany-0.19.0}/libinephany/observations/observer_pipeline.py +0 -0
  21. {libinephany-0.18.1 → libinephany-0.19.0}/libinephany/observations/observers/__init__.py +0 -0
  22. {libinephany-0.18.1 → libinephany-0.19.0}/libinephany/observations/observers/base_observers.py +0 -0
  23. {libinephany-0.18.1 → libinephany-0.19.0}/libinephany/observations/observers/global_observers/base_classes.py +0 -0
  24. {libinephany-0.18.1 → libinephany-0.19.0}/libinephany/observations/observers/global_observers/hyperparameter_observers.py +0 -0
  25. {libinephany-0.18.1 → libinephany-0.19.0}/libinephany/observations/observers/global_observers/loss_observers.py +0 -0
  26. {libinephany-0.18.1 → libinephany-0.19.0}/libinephany/observations/observers/global_observers/progress_observers.py +0 -0
  27. {libinephany-0.18.1 → libinephany-0.19.0}/libinephany/observations/observers/observer_containers.py +0 -0
  28. {libinephany-0.18.1 → libinephany-0.19.0}/libinephany/observations/pipeline_coordinator.py +0 -0
  29. {libinephany-0.18.1 → libinephany-0.19.0}/libinephany/observations/post_processors/__init__.py +0 -0
  30. {libinephany-0.18.1 → libinephany-0.19.0}/libinephany/observations/post_processors/postprocessors.py +0 -0
  31. {libinephany-0.18.1 → libinephany-0.19.0}/libinephany/observations/statistic_manager.py +0 -0
  32. {libinephany-0.18.1 → libinephany-0.19.0}/libinephany/pydantic_models/__init__.py +0 -0
  33. {libinephany-0.18.1 → libinephany-0.19.0}/libinephany/pydantic_models/configs/__init__.py +0 -0
  34. {libinephany-0.18.1 → libinephany-0.19.0}/libinephany/pydantic_models/configs/hyperparameter_configs.py +0 -0
  35. {libinephany-0.18.1 → libinephany-0.19.0}/libinephany/pydantic_models/configs/observer_config.py +0 -0
  36. {libinephany-0.18.1 → libinephany-0.19.0}/libinephany/pydantic_models/configs/outer_model_config.py +0 -0
  37. {libinephany-0.18.1 → libinephany-0.19.0}/libinephany/pydantic_models/schemas/__init__.py +0 -0
  38. {libinephany-0.18.1 → libinephany-0.19.0}/libinephany/pydantic_models/schemas/agent_info.py +0 -0
  39. {libinephany-0.18.1 → libinephany-0.19.0}/libinephany/pydantic_models/schemas/inner_task_profile.py +0 -0
  40. {libinephany-0.18.1 → libinephany-0.19.0}/libinephany/pydantic_models/schemas/observation_models.py +0 -0
  41. {libinephany-0.18.1 → libinephany-0.19.0}/libinephany/pydantic_models/schemas/request_schemas.py +0 -0
  42. {libinephany-0.18.1 → libinephany-0.19.0}/libinephany/pydantic_models/schemas/response_schemas.py +0 -0
  43. {libinephany-0.18.1 → libinephany-0.19.0}/libinephany/pydantic_models/schemas/tensor_statistics.py +0 -0
  44. {libinephany-0.18.1 → libinephany-0.19.0}/libinephany/pydantic_models/states/__init__.py +0 -0
  45. {libinephany-0.18.1 → libinephany-0.19.0}/libinephany/pydantic_models/states/hyperparameter_states.py +0 -0
  46. {libinephany-0.18.1 → libinephany-0.19.0}/libinephany/utils/__init__.py +0 -0
  47. {libinephany-0.18.1 → libinephany-0.19.0}/libinephany/utils/agent_utils.py +0 -0
  48. {libinephany-0.18.1 → libinephany-0.19.0}/libinephany/utils/asyncio_worker.py +0 -0
  49. {libinephany-0.18.1 → libinephany-0.19.0}/libinephany/utils/backend_statuses.py +0 -0
  50. {libinephany-0.18.1 → libinephany-0.19.0}/libinephany/utils/constants.py +0 -0
  51. {libinephany-0.18.1 → libinephany-0.19.0}/libinephany/utils/directory_utils.py +0 -0
  52. {libinephany-0.18.1 → libinephany-0.19.0}/libinephany/utils/dropout_utils.py +0 -0
  53. {libinephany-0.18.1 → libinephany-0.19.0}/libinephany/utils/enums.py +0 -0
  54. {libinephany-0.18.1 → libinephany-0.19.0}/libinephany/utils/error_severities.py +0 -0
  55. {libinephany-0.18.1 → libinephany-0.19.0}/libinephany/utils/exceptions.py +0 -0
  56. {libinephany-0.18.1 → libinephany-0.19.0}/libinephany/utils/import_utils.py +0 -0
  57. {libinephany-0.18.1 → libinephany-0.19.0}/libinephany/utils/optim_utils.py +0 -0
  58. {libinephany-0.18.1 → libinephany-0.19.0}/libinephany/utils/random_seeds.py +0 -0
  59. {libinephany-0.18.1 → libinephany-0.19.0}/libinephany/utils/samplers.py +0 -0
  60. {libinephany-0.18.1 → libinephany-0.19.0}/libinephany/utils/standardizers.py +0 -0
  61. {libinephany-0.18.1 → libinephany-0.19.0}/libinephany/utils/torch_distributed_utils.py +0 -0
  62. {libinephany-0.18.1 → libinephany-0.19.0}/libinephany/utils/torch_utils.py +0 -0
  63. {libinephany-0.18.1 → libinephany-0.19.0}/libinephany/utils/transforms.py +0 -0
  64. {libinephany-0.18.1 → libinephany-0.19.0}/libinephany/utils/typing.py +0 -0
  65. {libinephany-0.18.1 → libinephany-0.19.0}/libinephany/web_apps/__init__.py +0 -0
  66. {libinephany-0.18.1 → libinephany-0.19.0}/libinephany/web_apps/error_logger.py +0 -0
  67. {libinephany-0.18.1 → libinephany-0.19.0}/libinephany/web_apps/web_app_utils.py +0 -0
  68. {libinephany-0.18.1 → libinephany-0.19.0}/libinephany.egg-info/SOURCES.txt +0 -0
  69. {libinephany-0.18.1 → libinephany-0.19.0}/libinephany.egg-info/dependency_links.txt +0 -0
  70. {libinephany-0.18.1 → libinephany-0.19.0}/libinephany.egg-info/requires.txt +0 -0
  71. {libinephany-0.18.1 → libinephany-0.19.0}/libinephany.egg-info/top_level.txt +0 -0
  72. {libinephany-0.18.1 → libinephany-0.19.0}/pyproject.toml +0 -0
  73. {libinephany-0.18.1 → libinephany-0.19.0}/setup.cfg +0 -0
@@ -0,0 +1 @@
1
+ 0.19.0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: libinephany
3
- Version: 0.18.1
3
+ Version: 0.19.0
4
4
  Summary: Inephany library containing code commonly used by multiple subpackages.
5
5
  Author-email: Inephany <info@inephany.com>
6
6
  License: Apache 2.0
@@ -25,6 +25,7 @@ from libinephany.utils import optim_utils
25
25
  # ======================================================================================================================
26
26
 
27
27
  EXP_AVERAGE = "exp_avg"
28
+ MOMENTUM_BUFFER = "momentum_buffer"
28
29
  MIN_DECAY_FACTOR = 1e-10
29
30
 
30
31
  MIN_TOTAL_WEIGHT = 1e-15 # Minimum total weight threshold for numerical stability
@@ -64,10 +65,8 @@ def get_exponential_weighted_average(values: list[int | float]) -> float:
64
65
  :param values: List of values to average via EWA.
65
66
  :return: EWA of the given values.
66
67
  """
67
-
68
68
  exp_weighted_average = pd.Series(values).ewm(alpha=0.1).mean().iloc[-1]
69
69
  assert isinstance(exp_weighted_average, float)
70
-
71
70
  return exp_weighted_average
72
71
 
73
72
 
@@ -232,6 +231,24 @@ def form_update_tensor(
232
231
  raise NotImplementedError(f"Optimizer {type(optimizer).__name__} is not supported!")
233
232
 
234
233
 
234
def form_momentum_tensor(
    optimizer: optim.Optimizer, parameters: list[torch.Tensor], parameter_group: dict[str, Any]
) -> None | torch.Tensor:
    """
    Flattens and concatenates the momentum state held by the optimizer for the given parameters.

    Adam-family optimizers (optim_utils.ADAM_OPTIMISERS) are read from the ``exp_avg`` state entry and SGD-family
    optimizers (optim_utils.SGD_OPTIMISERS) from the ``momentum_buffer`` state entry. Parameters that are not on
    the local rank are skipped. Parameters whose optimizer state holds no momentum tensor are also skipped: the
    entry may be absent (e.g. before the first optimizer step) or ``None`` (e.g. SGD with ``momentum=0``).

    :param optimizer: Optimizer to form the momentum tensor from.
    :param parameters: Parameters to create the momentum tensor from.
    :param parameter_group: Parameter group within the optimizer the given parameters came from. Not used by this
    function.
    :return: 1-D concatenation of the flattened momentum tensors, or None if no momentum tensor was found.
    :raises NotImplementedError: If the optimizer type is neither an Adam-family nor an SGD-family optimizer.
    """
    if type(optimizer) in optim_utils.ADAM_OPTIMISERS:
        state_key = EXP_AVERAGE
    elif type(optimizer) in optim_utils.SGD_OPTIMISERS:
        state_key = MOMENTUM_BUFFER
    else:
        raise NotImplementedError(f"Optimizer {type(optimizer).__name__} is not supported!")

    momentum_list: list[torch.Tensor] = []

    for parameter in parameters:
        if not tensor_on_local_rank(parameter):
            continue

        # .get() instead of [] so a parameter with no momentum entry yet does not raise KeyError.
        momentum = optimizer.state[parameter].get(state_key)

        if momentum is None:
            continue

        # reshape(-1) rather than view(-1): state tensors may be non-contiguous (e.g. channels_last params),
        # where view(-1) raises. reshape returns a view whenever one is possible.
        momentum_list.append(momentum.reshape(-1))

    return torch.cat(momentum_list) if momentum_list else None
250
+
251
+
235
252
  def null_standardizer(value_to_standardize: float, **kwargs) -> float:
236
253
  """
237
254
  :param value_to_standardize: Value to mock the standardization of.
@@ -8,7 +8,15 @@
8
8
  # ======================================================================================================================
9
9
 
10
10
 
11
- from .gradient_observers import GlobalFirstOrderGradients, GlobalSecondOrderGradients, LHOPTGradientVarianceFraction
11
+ from .gradient_observers import (
12
+ CosineSimilarityObserverOfGradientAndMomentum,
13
+ CosineSimilarityObserverOfGradientAndUpdate,
14
+ CosineSimilarityOfGradientAndParameter,
15
+ GlobalFirstOrderGradients,
16
+ GlobalSecondOrderGradients,
17
+ LHOPTGradientVarianceFraction,
18
+ LHOPTMomentumGradientRatio,
19
+ )
12
20
  from .hyperparameter_observers import (
13
21
  InitialHyperparameters,
14
22
  LHOPTHyperparameterRatio,
@@ -31,8 +39,11 @@ from .model_observers import (
31
39
  GlobalLAMBTrustRatio,
32
40
  GlobalParameters,
33
41
  GlobalParameterUpdates,
42
+ LHOPTAverageParameterUpdateMagnitudeObserver,
43
+ LHOPTGlobalLAMBTrustRatio,
34
44
  LogRatioOfPreviousAndCurrentParamNormEnvStepObserver,
35
45
  LogRatioOfUpdateAndPreviousParamNormEnvStepObserver,
46
+ LogRatioOfUpdateAndPreviousParamNormInnerStepObserver,
36
47
  NumberOfLayers,
37
48
  NumberOfParameters,
38
49
  )
@@ -51,14 +62,17 @@ __all__ = [
51
62
  GlobalFirstOrderGradients.__name__,
52
63
  GlobalSecondOrderGradients.__name__,
53
64
  LHOPTGradientVarianceFraction.__name__,
65
+ LHOPTMomentumGradientRatio.__name__,
54
66
  GlobalActivations.__name__,
55
67
  GlobalParameterUpdates.__name__,
56
68
  GlobalParameters.__name__,
57
69
  GlobalLAMBTrustRatio.__name__,
58
70
  NumberOfParameters.__name__,
59
71
  NumberOfLayers.__name__,
72
+ LHOPTAverageParameterUpdateMagnitudeObserver.__name__,
60
73
  LogRatioOfPreviousAndCurrentParamNormEnvStepObserver.__name__,
61
74
  LogRatioOfUpdateAndPreviousParamNormEnvStepObserver.__name__,
75
+ LogRatioOfUpdateAndPreviousParamNormInnerStepObserver.__name__,
62
76
  TrainingProgress.__name__,
63
77
  EpochsCompleted.__name__,
64
78
  ProgressAtEachCheckpoint.__name__,
@@ -66,4 +80,8 @@ __all__ = [
66
80
  LHOPTValidationLoss.__name__,
67
81
  LHOPTLossRatio.__name__,
68
82
  PercentileOfLossAtEachCheckpoint.__name__,
83
+ LHOPTGlobalLAMBTrustRatio.__name__,
84
+ CosineSimilarityObserverOfGradientAndMomentum.__name__,
85
+ CosineSimilarityObserverOfGradientAndUpdate.__name__,
86
+ CosineSimilarityOfGradientAndParameter.__name__,
69
87
  ]
@@ -20,6 +20,7 @@ class LHOPTConstants(TypedDict):
20
20
  ZERO_DIVISION_TOLERANCE: float
21
21
  DEFAULT_SAMPLE_FREQUENCY: int
22
22
  DEFAULT_VARIANCE_THRESHOLD: float
23
+ DEFAULT_ENV_STEP_SAMPLE_FREQUENCY: int
23
24
 
24
25
 
25
26
  # Create the constants instance
@@ -36,4 +37,5 @@ LHOPT_CONSTANTS: LHOPTConstants = LHOPTConstants(
36
37
  ZERO_DIVISION_TOLERANCE=1e-8,
37
38
  DEFAULT_SAMPLE_FREQUENCY=4,
38
39
  DEFAULT_VARIANCE_THRESHOLD=1e-6,
40
+ DEFAULT_ENV_STEP_SAMPLE_FREQUENCY=10,
39
41
  )
@@ -0,0 +1,511 @@
1
+ # ======================================================================================================================
2
+ #
3
+ # GRADIENT OBSERVERS
4
+ #
5
+ # ======================================================================================================================
6
+
7
+ import math
8
+ from typing import Any
9
+
10
+ from libinephany.observations import observation_utils, statistic_trackers
11
+ from libinephany.observations.observation_utils import StatisticStorageTypes
12
+ from libinephany.observations.observers.base_observers import GlobalObserver
13
+ from libinephany.observations.observers.global_observers.base_classes import LHOPTBaseObserver
14
+ from libinephany.observations.observers.global_observers.constants import LHOPT_CONSTANTS
15
+ from libinephany.pydantic_models.schemas.observation_models import ObservationInputs
16
+ from libinephany.pydantic_models.schemas.tensor_statistics import TensorStatistics
17
+ from libinephany.pydantic_models.states.hyperparameter_states import HyperparameterStates
18
+
19
+
20
class GlobalFirstOrderGradients(GlobalObserver):
    """
    Global observer that averages the per-module TensorStatistics produced by the FirstOrderGradients statistic
    tracker into a single TensorStatistics observation.
    """

    def _get_observation_format(self) -> StatisticStorageTypes:
        """
        :return: The StatisticStorageTypes member describing this observation's output; always TENSOR_STATISTICS.
        """
        return StatisticStorageTypes.TENSOR_STATISTICS

    def _observe(
        self,
        observation_inputs: ObservationInputs,
        hyperparameter_states: HyperparameterStates,
        tracked_statistics: dict[str, dict[str, float | TensorStatistics]],
        action_taken: float | int | None,
    ) -> float | int | list[int | float] | TensorStatistics:
        """
        :param observation_inputs: Observation input metrics not calculated with statistic trackers.
        :param hyperparameter_states: HyperparameterStates that manages the hyperparameters.
        :param tracked_statistics: Dictionary mapping statistic tracker class names to dictionaries mapping module
        names to floats or TensorStatistic models.
        :param action_taken: Action taken by the agent this class instance is assigned to.
        :return: Average, across tracked modules, of the FirstOrderGradients tracker's TensorStatistics.
        """
        tracker_name = statistic_trackers.FirstOrderGradients.__name__
        per_module_statistics = list(tracked_statistics[tracker_name].values())

        return observation_utils.average_tensor_statistics(tensor_statistics=per_module_statistics)  # type: ignore

    def get_required_trackers(self) -> dict[str, dict[str, Any] | None]:
        """
        :return: Dictionary mapping statistic tracker class names to kwargs for the class or None if no kwargs are
        needed. Only the FirstOrderGradients tracker is required, configured with this observer's skip_statistics.
        """
        tracker_kwargs = {"skip_statistics": self.skip_statistics}

        return {statistic_trackers.FirstOrderGradients.__name__: tracker_kwargs}
57
+
58
+
59
class GlobalSecondOrderGradients(GlobalObserver):
    """
    Global observer that averages the per-module TensorStatistics produced by the SecondOrderGradients statistic
    tracker into a single TensorStatistics observation.
    """

    def __init__(self, *, compute_hessian_diagonal: bool = False, **kwargs) -> None:
        """
        :param compute_hessian_diagonal: When True the tracker computes the Hessian diagonal; when False it uses
        squared first order gradients as the approximation, in the same way Adam does.
        :param kwargs: Miscellaneous keyword arguments forwarded to the parent observer.
        """
        super().__init__(**kwargs)
        self.compute_hessian_diagonal = compute_hessian_diagonal

    def _get_observation_format(self) -> StatisticStorageTypes:
        """
        :return: The StatisticStorageTypes member describing this observation's output; always TENSOR_STATISTICS.
        """
        return StatisticStorageTypes.TENSOR_STATISTICS

    def _observe(
        self,
        observation_inputs: ObservationInputs,
        hyperparameter_states: HyperparameterStates,
        tracked_statistics: dict[str, dict[str, float | TensorStatistics]],
        action_taken: float | int | None,
    ) -> float | int | list[int | float] | TensorStatistics:
        """
        :param observation_inputs: Observation input metrics not calculated with statistic trackers.
        :param hyperparameter_states: HyperparameterStates that manages the hyperparameters.
        :param tracked_statistics: Dictionary mapping statistic tracker class names to dictionaries mapping module
        names to floats or TensorStatistic models.
        :param action_taken: Action taken by the agent this class instance is assigned to.
        :return: Average, across tracked modules, of the SecondOrderGradients tracker's TensorStatistics.
        """
        tracker_name = statistic_trackers.SecondOrderGradients.__name__
        per_module_statistics = list(tracked_statistics[tracker_name].values())

        return observation_utils.average_tensor_statistics(tensor_statistics=per_module_statistics)  # type: ignore

    def get_required_trackers(self) -> dict[str, dict[str, Any] | None]:
        """
        :return: Dictionary mapping statistic tracker class names to kwargs for the class or None if no kwargs are
        needed. Only the SecondOrderGradients tracker is required, configured with this observer's skip_statistics
        and compute_hessian_diagonal settings.
        """
        tracker_kwargs = {
            "skip_statistics": self.skip_statistics,
            "compute_hessian_diagonal": self.compute_hessian_diagonal,
        }

        return {statistic_trackers.SecondOrderGradients.__name__: tracker_kwargs}
116
+
117
+
118
class LHOPTGradientVarianceFraction(LHOPTBaseObserver):
    """
    This is a global observer from the OpenAI paper "Learning to Optimize with Reinforcement Learning"
    https://arxiv.org/abs/2305.18291.

    It returns two-dimensional observations: [raw_value, cdf_feature] for gradient variance fraction values.
    """

    def __init__(
        self,
        *,
        variance_threshold: float = LHOPT_CONSTANTS["DEFAULT_VARIANCE_THRESHOLD"],
        **kwargs,
    ) -> None:
        """
        :param variance_threshold: Threshold for variance comparison in gradient variance fraction calculation.
        :param kwargs: Other observation keyword arguments.
        """
        super().__init__(**kwargs)
        self.variance_threshold = variance_threshold

    @property
    def can_standardize(self) -> bool:
        """
        This observer has its own CDF calculation, no need to standardize.
        :return: Whether the observation can be standardized.
        """
        return False

    def _get_observation_format(self) -> StatisticStorageTypes:
        """
        :return: Format the observation returns data in. Must be one of the enum attributes in the StatisticStorageTypes
        enumeration class.
        """
        return StatisticStorageTypes.VECTOR

    @property
    def vector_length(self) -> int:
        """
        :return: Length of the vector returned by this observation if it returns a vector.
        """
        return 2  # [raw_value, cdf_feature]

    def _observe(
        self,
        observation_inputs: ObservationInputs,
        hyperparameter_states: HyperparameterStates,
        tracked_statistics: dict[str, dict[str, float | TensorStatistics]],
        action_taken: float | int | None,
    ) -> float | int | list[int | float] | TensorStatistics:
        """
        :param observation_inputs: Observation input metrics not calculated with statistic trackers.
        :param hyperparameter_states: HyperparameterStates that manages the hyperparameters.
        :param tracked_statistics: Dictionary mapping statistic tracker class names to dictionaries mapping module
        names to floats or TensorStatistic models.
        :param action_taken: Action taken by the agent this class instance is assigned to.
        :return: [raw_value, cdf_feature], or [0.0, 0.0] when the GradientVarianceFraction tracker has no entry or
        its entry holds no values.
        """
        statistics = tracked_statistics.get(statistic_trackers.GradientVarianceFraction.__name__)

        # Missing tracker entry, or an entry with no per-module values: return a zero vector instead of raising
        # IndexError when taking the first value below.
        if not statistics:
            return [0.0, 0.0]

        # Only the first value in the tracker's mapping is used.
        raw_value = next(iter(statistics.values()))

        cdf_feature = self._compute_cdf_feature(raw_value)  # type: ignore[arg-type]
        self._update_time()

        return [raw_value, cdf_feature]  # type: ignore[list-item]

    def get_required_trackers(self) -> dict[str, dict[str, Any] | None]:
        """
        :return: Dictionary mapping statistic tracker class names to kwargs for the class or None if no kwargs are
        needed.
        """
        return {
            statistic_trackers.GradientVarianceFraction.__name__: dict(
                variance_threshold=self.variance_threshold, sample_frequency=LHOPT_CONSTANTS["DEFAULT_SAMPLE_FREQUENCY"]
            ),
        }
197
+
198
+
199
class LHOPTMomentumGradientRatio(LHOPTBaseObserver):
    """
    This is a global observer from the OpenAI paper "Learning to Optimize with Reinforcement Learning"
    https://arxiv.org/abs/2305.18291.

    It returns two-dimensional observations: [raw_value, cdf_feature] for momentum gradient ratio values.
    """

    def _get_observation_format(self) -> StatisticStorageTypes:
        """
        :return: Format the observation returns data in. Must be one of the enum attributes in the StatisticStorageTypes
        enumeration class.
        """
        return StatisticStorageTypes.VECTOR

    @property
    def vector_length(self) -> int:
        """
        :return: Length of the vector returned by this observation if it returns a vector.
        """
        return 2  # [raw_value, cdf_feature]

    def _observe(
        self,
        observation_inputs: ObservationInputs,
        hyperparameter_states: HyperparameterStates,
        tracked_statistics: dict[str, dict[str, float | TensorStatistics]],
        action_taken: float | int | None,
    ) -> float | int | list[int | float] | TensorStatistics:
        """
        :param observation_inputs: Observation input metrics not calculated with statistic trackers.
        :param hyperparameter_states: HyperparameterStates that manages the hyperparameters.
        :param tracked_statistics: Dictionary mapping statistic tracker class names to dictionaries mapping module
        names to floats or TensorStatistic models.
        :param action_taken: Action taken by the agent this class instance is assigned to.
        :return: [raw_value, cdf_feature], or [0.0, 0.0] when the MomentumGradientRatioStatistics tracker has no
        entry or its entry holds no values.
        """
        statistics = tracked_statistics.get(statistic_trackers.MomentumGradientRatioStatistics.__name__)

        # Missing tracker entry, or an entry with no per-module values: return a zero vector instead of raising
        # KeyError / IndexError.
        if not statistics:
            return [0.0, 0.0]

        # Only the first value in the tracker's mapping is used.
        raw_value = next(iter(statistics.values()))

        cdf_feature = self._compute_cdf_feature(raw_value)  # type: ignore[arg-type]
        self._update_time()

        return [raw_value, cdf_feature]  # type: ignore[list-item]

    def get_required_trackers(self) -> dict[str, dict[str, Any] | None]:
        """
        :return: Dictionary mapping statistic tracker class names to kwargs for the class or None if no kwargs are
        needed.
        """
        return {
            statistic_trackers.MomentumGradientRatioStatistics.__name__: dict(
                skip_statistics=self.skip_statistics,
                sample_frequency=LHOPT_CONSTANTS["DEFAULT_SAMPLE_FREQUENCY"],
            ),
        }
259
+
260
+
261
class CosineSimilarityObserverOfGradientAndMomentum(LHOPTBaseObserver):
    """
    This is a global observer from the OpenAI paper "Learning to Optimize with Reinforcement Learning"
    https://arxiv.org/abs/2305.18291.

    It returns three-dimensional observations: [raw_value, cdf_feature, logit_of_cdf_feature] for cosine similarity
    of gradient and momentum values.
    """

    def __init__(
        self,
        *,
        skip_statistics: list[str] | None = None,
        **kwargs,
    ) -> None:
        """
        :param skip_statistics: List of statistics to skip, forwarded to the statistic tracker.
        :param kwargs: Miscellaneous keyword arguments.
        """
        super().__init__(**kwargs)

        self.skip_statistics = skip_statistics

    def _get_observation_format(self) -> StatisticStorageTypes:
        """
        :return: Format the observation returns data in. Must be one of the enum attributes in the StatisticStorageTypes
        enumeration class.
        """
        return StatisticStorageTypes.VECTOR

    @property
    def vector_length(self) -> int:
        """
        :return: Length of the vector returned by this observation if it returns a vector.
        """
        return 3  # [raw_value, cdf_feature, logit_of_cdf_feature]

    def _observe(
        self,
        observation_inputs: ObservationInputs,
        hyperparameter_states: HyperparameterStates,
        tracked_statistics: dict[str, dict[str, float | TensorStatistics]],
        action_taken: float | int | None,
    ) -> float | int | list[int | float] | TensorStatistics:
        """
        :param observation_inputs: Observation input metrics not calculated with statistic trackers.
        :param hyperparameter_states: HyperparameterStates that manages the hyperparameters.
        :param tracked_statistics: Dictionary mapping statistic tracker class names to dictionaries mapping module
        names to floats or TensorStatistic models.
        :param action_taken: Action taken by the agent this class instance is assigned to.
        :return: [raw_value, cdf_feature, logit_of_cdf_feature], or [0.0, 0.0, 0.0] when the tracker has no entry or
        its entry holds no values.
        """
        statistics = tracked_statistics.get(
            statistic_trackers.CosineSimilarityObserverOfGradientAndMomentumStatistics.__name__
        )

        # Missing tracker entry, or an entry with no per-module values: return a zero vector instead of raising
        # KeyError / IndexError.
        if not statistics:
            return [0.0, 0.0, 0.0]

        # Only the first value in the tracker's mapping is used.
        raw_value = next(iter(statistics.values()))

        cdf_feature = self._compute_cdf_feature(raw_value)  # type: ignore[arg-type]
        self._update_time()

        # The logit is undefined at cdf 0 and 1, so those endpoints are mapped to 0.0.
        # NOTE(review): 0.0 is also the logit of cdf=0.5, so saturated CDFs look neutral; confirm this is intended.
        if cdf_feature <= 0.0 or cdf_feature >= 1.0:
            logit_of_cdf_feature = 0.0
        else:
            logit_of_cdf_feature = math.log(cdf_feature / (1 - cdf_feature))

        return [raw_value, cdf_feature, logit_of_cdf_feature]  # type: ignore[list-item]

    def get_required_trackers(self) -> dict[str, dict[str, Any] | None]:
        """
        :return: Dictionary mapping statistic tracker class names to kwargs for the class or None if no kwargs are
        needed.
        """
        return {
            statistic_trackers.CosineSimilarityObserverOfGradientAndMomentumStatistics.__name__: dict(
                skip_statistics=self.skip_statistics,
                sample_frequency=LHOPT_CONSTANTS["DEFAULT_SAMPLE_FREQUENCY"],
            )
        }
345
+
346
+
347
class CosineSimilarityObserverOfGradientAndUpdate(LHOPTBaseObserver):
    """
    This is a global observer from the OpenAI paper "Learning to Optimize with Reinforcement Learning"
    https://arxiv.org/abs/2305.18291.

    It returns three-dimensional observations: [raw_value, cdf_feature, logit_of_cdf_feature] for cosine similarity
    of gradient and update values.
    """

    def __init__(
        self,
        *,
        skip_statistics: list[str] | None = None,
        **kwargs,
    ) -> None:
        """
        :param skip_statistics: List of statistics to skip, forwarded to the statistic tracker.
        :param kwargs: Miscellaneous keyword arguments.
        """
        super().__init__(**kwargs)

        self.skip_statistics = skip_statistics

    def _get_observation_format(self) -> StatisticStorageTypes:
        """
        :return: Format the observation returns data in. Must be one of the enum attributes in the StatisticStorageTypes
        enumeration class.
        """
        return StatisticStorageTypes.VECTOR

    @property
    def vector_length(self) -> int:
        """
        :return: Length of the vector returned by this observation if it returns a vector.
        """
        return 3  # [raw_value, cdf_feature, logit_of_cdf_feature]

    def _observe(
        self,
        observation_inputs: ObservationInputs,
        hyperparameter_states: HyperparameterStates,
        tracked_statistics: dict[str, dict[str, float | TensorStatistics]],
        action_taken: float | int | None,
    ) -> float | int | list[int | float] | TensorStatistics:
        """
        :param observation_inputs: Observation input metrics not calculated with statistic trackers.
        :param hyperparameter_states: HyperparameterStates that manages the hyperparameters.
        :param tracked_statistics: Dictionary mapping statistic tracker class names to dictionaries mapping module
        names to floats or TensorStatistic models.
        :param action_taken: Action taken by the agent this class instance is assigned to.
        :return: [raw_value, cdf_feature, logit_of_cdf_feature], or [0.0, 0.0, 0.0] when the tracker has no entry or
        its entry holds no values.
        """
        statistics = tracked_statistics.get(
            statistic_trackers.CosineSimilarityObserverOfGradientAndUpdateStatistics.__name__
        )

        # Missing tracker entry, or an entry with no per-module values: return a zero vector instead of raising
        # KeyError / IndexError.
        if not statistics:
            return [0.0, 0.0, 0.0]

        # Only the first value in the tracker's mapping is used.
        raw_value = next(iter(statistics.values()))

        cdf_feature = self._compute_cdf_feature(raw_value)  # type: ignore[arg-type]
        self._update_time()

        # The logit is undefined at cdf 0 and 1, so those endpoints are mapped to 0.0.
        # NOTE(review): 0.0 is also the logit of cdf=0.5, so saturated CDFs look neutral; confirm this is intended.
        if cdf_feature <= 0.0 or cdf_feature >= 1.0:
            logit_of_cdf_feature = 0.0
        else:
            logit_of_cdf_feature = math.log(cdf_feature / (1 - cdf_feature))

        return [raw_value, cdf_feature, logit_of_cdf_feature]  # type: ignore[list-item]

    def get_required_trackers(self) -> dict[str, dict[str, Any] | None]:
        """
        :return: Dictionary mapping statistic tracker class names to kwargs for the class or None if no kwargs are
        needed.
        """
        return {
            statistic_trackers.CosineSimilarityObserverOfGradientAndUpdateStatistics.__name__: dict(
                skip_statistics=self.skip_statistics,
                sample_frequency=LHOPT_CONSTANTS["DEFAULT_SAMPLE_FREQUENCY"],
            )
        }
430
+
431
+
432
class CosineSimilarityOfGradientAndParameter(LHOPTBaseObserver):
    """
    This is a global observer from the OpenAI paper "Learning to Optimize with Reinforcement Learning"
    https://arxiv.org/abs/2305.18291.

    It returns three-dimensional observations: [raw_value, cdf_feature, logit_of_cdf_feature] for cosine similarity
    of gradient and parameter values.
    """

    def __init__(
        self,
        *,
        skip_statistics: list[str] | None = None,
        **kwargs,
    ) -> None:
        """
        :param skip_statistics: List of statistics to skip, forwarded to the statistic tracker.
        :param kwargs: Miscellaneous keyword arguments.
        """
        super().__init__(**kwargs)

        self.skip_statistics = skip_statistics

    def _get_observation_format(self) -> StatisticStorageTypes:
        """
        :return: Format the observation returns data in. Must be one of the enum attributes in the StatisticStorageTypes
        enumeration class.
        """
        return StatisticStorageTypes.VECTOR

    @property
    def vector_length(self) -> int:
        """
        :return: Length of the vector returned by this observation if it returns a vector.
        """
        return 3  # [raw_value, cdf_feature, logit_of_cdf_feature]

    def _observe(
        self,
        observation_inputs: ObservationInputs,
        hyperparameter_states: HyperparameterStates,
        tracked_statistics: dict[str, dict[str, float | TensorStatistics]],
        action_taken: float | int | None,
    ) -> float | int | list[int | float] | TensorStatistics:
        """
        :param observation_inputs: Observation input metrics not calculated with statistic trackers.
        :param hyperparameter_states: HyperparameterStates that manages the hyperparameters.
        :param tracked_statistics: Dictionary mapping statistic tracker class names to dictionaries mapping module
        names to floats or TensorStatistic models.
        :param action_taken: Action taken by the agent this class instance is assigned to.
        :return: [raw_value, cdf_feature, logit_of_cdf_feature], or [0.0, 0.0, 0.0] when the tracker has no entry or
        its entry holds no values.
        """
        statistics = tracked_statistics.get(statistic_trackers.CosineSimilarityOfGradientAndParameterStatistics.__name__)

        # Missing tracker entry, or an entry with no per-module values: return a zero vector instead of raising
        # KeyError / IndexError.
        if not statistics:
            return [0.0, 0.0, 0.0]

        # Only the first value in the tracker's mapping is used.
        raw_value = next(iter(statistics.values()))

        cdf_feature = self._compute_cdf_feature(raw_value)  # type: ignore[arg-type]
        self._update_time()

        # The logit is undefined at cdf 0 and 1, so those endpoints are mapped to 0.0.
        # NOTE(review): 0.0 is also the logit of cdf=0.5, so saturated CDFs look neutral; confirm this is intended.
        if cdf_feature <= 0.0 or cdf_feature >= 1.0:
            logit_of_cdf_feature = 0.0
        else:
            logit_of_cdf_feature = math.log(cdf_feature / (1 - cdf_feature))

        return [raw_value, cdf_feature, logit_of_cdf_feature]  # type: ignore[list-item]

    def get_required_trackers(self) -> dict[str, dict[str, Any] | None]:
        """
        :return: Dictionary mapping statistic tracker class names to kwargs for the class or None if no kwargs are
        needed.
        """
        return {
            statistic_trackers.CosineSimilarityOfGradientAndParameterStatistics.__name__: dict(
                skip_statistics=self.skip_statistics,
                sample_frequency=LHOPT_CONSTANTS["DEFAULT_SAMPLE_FREQUENCY"],
            )
        }