libinephany 0.16.3__tar.gz → 0.16.5__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (73)
  1. libinephany-0.16.5/CODE_VERSION.cfg +1 -0
  2. {libinephany-0.16.3/libinephany.egg-info → libinephany-0.16.5}/PKG-INFO +1 -1
  3. {libinephany-0.16.3 → libinephany-0.16.5}/libinephany/observations/observation_utils.py +2 -14
  4. libinephany-0.16.5/libinephany/observations/observers/global_observers/__init__.py +69 -0
  5. libinephany-0.16.5/libinephany/observations/observers/global_observers/base_classes.py +224 -0
  6. libinephany-0.16.5/libinephany/observations/observers/global_observers/constants.py +39 -0
  7. libinephany-0.16.5/libinephany/observations/observers/global_observers/gradient_observers.py +193 -0
  8. libinephany-0.16.5/libinephany/observations/observers/global_observers/hyperparameter_observers.py +397 -0
  9. libinephany-0.16.5/libinephany/observations/observers/global_observers/loss_observers.py +464 -0
  10. libinephany-0.16.5/libinephany/observations/observers/global_observers/model_observers.py +469 -0
  11. libinephany-0.16.5/libinephany/observations/observers/global_observers/progress_observers.py +142 -0
  12. {libinephany-0.16.3 → libinephany-0.16.5}/libinephany/observations/observers/local_observers.py +88 -0
  13. {libinephany-0.16.3 → libinephany-0.16.5}/libinephany/observations/observers/observer_containers.py +3 -1
  14. {libinephany-0.16.3 → libinephany-0.16.5}/libinephany/observations/statistic_trackers.py +75 -0
  15. {libinephany-0.16.3 → libinephany-0.16.5}/libinephany/pydantic_models/states/hyperparameter_states.py +17 -0
  16. {libinephany-0.16.3 → libinephany-0.16.5/libinephany.egg-info}/PKG-INFO +1 -1
  17. {libinephany-0.16.3 → libinephany-0.16.5}/libinephany.egg-info/SOURCES.txt +8 -1
  18. libinephany-0.16.3/CODE_VERSION.cfg +0 -1
  19. libinephany-0.16.3/libinephany/observations/observers/global_observers.py +0 -991
  20. {libinephany-0.16.3 → libinephany-0.16.5}/LICENSE +0 -0
  21. {libinephany-0.16.3 → libinephany-0.16.5}/MANIFEST.in +0 -0
  22. {libinephany-0.16.3 → libinephany-0.16.5}/README.md +0 -0
  23. {libinephany-0.16.3 → libinephany-0.16.5}/libinephany/__init__.py +0 -0
  24. {libinephany-0.16.3 → libinephany-0.16.5}/libinephany/aws/__init__.py +0 -0
  25. {libinephany-0.16.3 → libinephany-0.16.5}/libinephany/aws/s3_functions.py +0 -0
  26. {libinephany-0.16.3 → libinephany-0.16.5}/libinephany/observations/__init__.py +0 -0
  27. {libinephany-0.16.3 → libinephany-0.16.5}/libinephany/observations/observer_pipeline.py +0 -0
  28. {libinephany-0.16.3 → libinephany-0.16.5}/libinephany/observations/observers/__init__.py +0 -0
  29. {libinephany-0.16.3 → libinephany-0.16.5}/libinephany/observations/observers/base_observers.py +0 -0
  30. {libinephany-0.16.3 → libinephany-0.16.5}/libinephany/observations/pipeline_coordinator.py +0 -0
  31. {libinephany-0.16.3 → libinephany-0.16.5}/libinephany/observations/post_processors/__init__.py +0 -0
  32. {libinephany-0.16.3 → libinephany-0.16.5}/libinephany/observations/post_processors/postprocessors.py +0 -0
  33. {libinephany-0.16.3 → libinephany-0.16.5}/libinephany/observations/statistic_manager.py +0 -0
  34. {libinephany-0.16.3 → libinephany-0.16.5}/libinephany/pydantic_models/__init__.py +0 -0
  35. {libinephany-0.16.3 → libinephany-0.16.5}/libinephany/pydantic_models/configs/__init__.py +0 -0
  36. {libinephany-0.16.3 → libinephany-0.16.5}/libinephany/pydantic_models/configs/hyperparameter_configs.py +0 -0
  37. {libinephany-0.16.3 → libinephany-0.16.5}/libinephany/pydantic_models/configs/observer_config.py +0 -0
  38. {libinephany-0.16.3 → libinephany-0.16.5}/libinephany/pydantic_models/configs/outer_model_config.py +0 -0
  39. {libinephany-0.16.3 → libinephany-0.16.5}/libinephany/pydantic_models/schemas/__init__.py +0 -0
  40. {libinephany-0.16.3 → libinephany-0.16.5}/libinephany/pydantic_models/schemas/agent_info.py +0 -0
  41. {libinephany-0.16.3 → libinephany-0.16.5}/libinephany/pydantic_models/schemas/inner_task_profile.py +0 -0
  42. {libinephany-0.16.3 → libinephany-0.16.5}/libinephany/pydantic_models/schemas/observation_models.py +0 -0
  43. {libinephany-0.16.3 → libinephany-0.16.5}/libinephany/pydantic_models/schemas/request_schemas.py +0 -0
  44. {libinephany-0.16.3 → libinephany-0.16.5}/libinephany/pydantic_models/schemas/response_schemas.py +0 -0
  45. {libinephany-0.16.3 → libinephany-0.16.5}/libinephany/pydantic_models/schemas/tensor_statistics.py +0 -0
  46. {libinephany-0.16.3 → libinephany-0.16.5}/libinephany/pydantic_models/states/__init__.py +0 -0
  47. {libinephany-0.16.3 → libinephany-0.16.5}/libinephany/utils/__init__.py +0 -0
  48. {libinephany-0.16.3 → libinephany-0.16.5}/libinephany/utils/agent_utils.py +0 -0
  49. {libinephany-0.16.3 → libinephany-0.16.5}/libinephany/utils/asyncio_worker.py +0 -0
  50. {libinephany-0.16.3 → libinephany-0.16.5}/libinephany/utils/backend_statuses.py +0 -0
  51. {libinephany-0.16.3 → libinephany-0.16.5}/libinephany/utils/constants.py +0 -0
  52. {libinephany-0.16.3 → libinephany-0.16.5}/libinephany/utils/directory_utils.py +0 -0
  53. {libinephany-0.16.3 → libinephany-0.16.5}/libinephany/utils/dropout_utils.py +0 -0
  54. {libinephany-0.16.3 → libinephany-0.16.5}/libinephany/utils/enums.py +0 -0
  55. {libinephany-0.16.3 → libinephany-0.16.5}/libinephany/utils/error_severities.py +0 -0
  56. {libinephany-0.16.3 → libinephany-0.16.5}/libinephany/utils/exceptions.py +0 -0
  57. {libinephany-0.16.3 → libinephany-0.16.5}/libinephany/utils/import_utils.py +0 -0
  58. {libinephany-0.16.3 → libinephany-0.16.5}/libinephany/utils/optim_utils.py +0 -0
  59. {libinephany-0.16.3 → libinephany-0.16.5}/libinephany/utils/random_seeds.py +0 -0
  60. {libinephany-0.16.3 → libinephany-0.16.5}/libinephany/utils/samplers.py +0 -0
  61. {libinephany-0.16.3 → libinephany-0.16.5}/libinephany/utils/standardizers.py +0 -0
  62. {libinephany-0.16.3 → libinephany-0.16.5}/libinephany/utils/torch_distributed_utils.py +0 -0
  63. {libinephany-0.16.3 → libinephany-0.16.5}/libinephany/utils/torch_utils.py +0 -0
  64. {libinephany-0.16.3 → libinephany-0.16.5}/libinephany/utils/transforms.py +0 -0
  65. {libinephany-0.16.3 → libinephany-0.16.5}/libinephany/utils/typing.py +0 -0
  66. {libinephany-0.16.3 → libinephany-0.16.5}/libinephany/web_apps/__init__.py +0 -0
  67. {libinephany-0.16.3 → libinephany-0.16.5}/libinephany/web_apps/error_logger.py +0 -0
  68. {libinephany-0.16.3 → libinephany-0.16.5}/libinephany/web_apps/web_app_utils.py +0 -0
  69. {libinephany-0.16.3 → libinephany-0.16.5}/libinephany.egg-info/dependency_links.txt +0 -0
  70. {libinephany-0.16.3 → libinephany-0.16.5}/libinephany.egg-info/requires.txt +0 -0
  71. {libinephany-0.16.3 → libinephany-0.16.5}/libinephany.egg-info/top_level.txt +0 -0
  72. {libinephany-0.16.3 → libinephany-0.16.5}/pyproject.toml +0 -0
  73. {libinephany-0.16.3 → libinephany-0.16.5}/setup.cfg +0 -0
--- /dev/null
+++ libinephany-0.16.5/CODE_VERSION.cfg
@@ -0,0 +1 @@
+ 0.16.5
--- libinephany-0.16.3/libinephany.egg-info/PKG-INFO
+++ libinephany-0.16.5/PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: libinephany
- Version: 0.16.3
+ Version: 0.16.5
  Summary: Inephany library containing code commonly used by multiple subpackages.
  Author-email: Inephany <info@inephany.com>
  License: Apache 2.0
--- libinephany-0.16.3/libinephany/observations/observation_utils.py
+++ libinephany-0.16.5/libinephany/observations/observation_utils.py
@@ -64,20 +64,8 @@ def get_exponential_weighted_average(values: list[int | float]) -> float:
      :param values: List of values to average via EWA.
      :return: EWA of the given values.
      """
-
-     # Check for NaN and infinite values in input
-     valid_values = [float(val) for val in values if not math.isnan(float(val)) and not math.isinf(float(val))]
-
-     if not valid_values:
-         raise ValueError("Cannot compute exponential weighted average on empty list")
-
-     if len(valid_values) == 1:
-         return valid_values[0]
-
-     exp_weighted_average = pd.Series(valid_values).ewm(alpha=0.1).mean().iloc[-1]
+     exp_weighted_average = pd.Series(values).ewm(alpha=0.1).mean().iloc[-1]
      assert isinstance(exp_weighted_average, float)
-     assert not math.isnan(exp_weighted_average)
-     assert not math.isinf(exp_weighted_average)
      return exp_weighted_average

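The 0.16.5 implementation delegates all smoothing to pandas. A minimal standalone sketch of the new behaviour, based only on the hunk above; note that the NaN/inf filtering and the empty-list/single-value guards from 0.16.3 are gone, so callers are now responsible for passing finite, non-empty input:

    import pandas as pd

    def ewa(values: list[int | float]) -> float:
        # Mirrors the 0.16.5 body: exponential weighting with alpha=0.1,
        # no input validation and no finiteness assertions on the result.
        result = pd.Series(values).ewm(alpha=0.1).mean().iloc[-1]
        assert isinstance(result, float)
        return result

    print(ewa([1.0, 2.0, 3.0]))  # ~2.07 with pandas' default adjust=True weighting
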
@@ -395,7 +383,7 @@ def _weighted_interval_expectation(
      interval_gradient = (end_value - start_value) / (end_time_point - start_time_point)
      start_exp_time = math.exp(log_decay_factor * start_time_point)
      end_exp_time = math.exp(log_decay_factor * end_time_point)
-     return (1 / log_decay_factor) * (end_value * end_exp_time - start_value * start_exp_time) + (
+     return (1 / log_decay_factor) * (end_value * end_exp_time - start_value * start_exp_time) - (
          1 / log_decay_factor**2
      ) * interval_gradient * (end_exp_time - start_exp_time)

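The sign flip matches the closed form of an exponentially weighted integral over a linearly interpolated segment, which is what `_weighted_interval_expectation` appears to evaluate (only the tail of the function is visible in this hunk). With $v(t) = v_0 + g\,(t - t_0)$, where $g$ is the interval gradient and $\lambda$ the log decay factor, integration by parts gives:

$$
\int_{t_0}^{t_1} v(t)\,e^{\lambda t}\,dt
= \frac{1}{\lambda}\left(v_1 e^{\lambda t_1} - v_0 e^{\lambda t_0}\right)
- \frac{g}{\lambda^2}\left(e^{\lambda t_1} - e^{\lambda t_0}\right).
$$

The interval-gradient term carries a negative sign, which is exactly what the change from `+` to `-` restores.
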
--- /dev/null
+++ libinephany-0.16.5/libinephany/observations/observers/global_observers/__init__.py
@@ -0,0 +1,69 @@
+ # ======================================================================================================================
+ #
+ # GLOBAL OBSERVERS PACKAGE
+ #
+ # This package contains all global observer classes used for collecting observations
+ # across the entire training process, not specific to individual agents.
+ #
+ # ======================================================================================================================
+
+
+ from .gradient_observers import GlobalFirstOrderGradients, GlobalSecondOrderGradients, LHOPTGradientVarianceFraction
+ from .hyperparameter_observers import (
+     InitialHyperparameters,
+     LHOPTHyperparameterRatio,
+     ModelFamilyOneHot,
+     OptimizerTypeOneHot,
+ )
+ from .loss_observers import (
+     LHOPTLossRatio,
+     LHOPTTrainingLoss,
+     LHOPTValidationLoss,
+     LossRatio,
+     PercentileOfLossAtEachCheckpoint,
+     TrainingLoss,
+     TrainingScore,
+     ValidationLoss,
+     ValidationScore,
+ )
+ from .model_observers import (
+     GlobalActivations,
+     GlobalLAMBTrustRatio,
+     GlobalParameters,
+     GlobalParameterUpdates,
+     LogRatioOfPreviousAndCurrentParamNormEnvStepObserver,
+     LogRatioOfUpdateAndPreviousParamNormEnvStepObserver,
+     NumberOfLayers,
+     NumberOfParameters,
+ )
+ from .progress_observers import EpochsCompleted, ProgressAtEachCheckpoint, TrainingProgress
+
+ __all__ = [
+     InitialHyperparameters.__name__,
+     LHOPTHyperparameterRatio.__name__,
+     OptimizerTypeOneHot.__name__,
+     ModelFamilyOneHot.__name__,
+     TrainingLoss.__name__,
+     ValidationLoss.__name__,
+     LossRatio.__name__,
+     TrainingScore.__name__,
+     ValidationScore.__name__,
+     GlobalFirstOrderGradients.__name__,
+     GlobalSecondOrderGradients.__name__,
+     LHOPTGradientVarianceFraction.__name__,
+     GlobalActivations.__name__,
+     GlobalParameterUpdates.__name__,
+     GlobalParameters.__name__,
+     GlobalLAMBTrustRatio.__name__,
+     NumberOfParameters.__name__,
+     NumberOfLayers.__name__,
+     LogRatioOfPreviousAndCurrentParamNormEnvStepObserver.__name__,
+     LogRatioOfUpdateAndPreviousParamNormEnvStepObserver.__name__,
+     TrainingProgress.__name__,
+     EpochsCompleted.__name__,
+     ProgressAtEachCheckpoint.__name__,
+     LHOPTTrainingLoss.__name__,
+     LHOPTValidationLoss.__name__,
+     LHOPTLossRatio.__name__,
+     PercentileOfLossAtEachCheckpoint.__name__,
+ ]
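
A short usage sketch (hypothetical caller code): the package `__init__` re-exports every observer class, so imports that previously targeted the flat `global_observers` module resolve unchanged, and `__all__` is built from the classes' own `__name__` attributes, keeping the public surface in lockstep with the re-exported symbols:

    from libinephany.observations.observers import global_observers
    from libinephany.observations.observers.global_observers import TrainingLoss

    # The public names are derived from the class objects themselves, so a
    # renamed class cannot silently fall out of sync with __all__.
    assert "TrainingLoss" in global_observers.__all__
    assert TrainingLoss.__name__ == "TrainingLoss"
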
--- /dev/null
+++ libinephany-0.16.5/libinephany/observations/observers/global_observers/base_classes.py
@@ -0,0 +1,224 @@
+ # ======================================================================================================================
+ #
+ # BASE CLASSES
+ #
+ # ======================================================================================================================
+
+ import math
+ from abc import ABC, abstractmethod
+ from typing import Any
+
+ from libinephany.observations.observation_utils import StatisticStorageTypes, compute_cdf_feature
+ from libinephany.observations.observers.base_observers import GlobalObserver
+ from libinephany.observations.observers.global_observers.constants import LHOPT_CONSTANTS
+ from libinephany.pydantic_models.schemas.observation_models import ObservationInputs
+ from libinephany.pydantic_models.schemas.tensor_statistics import TensorStatistics
+ from libinephany.pydantic_models.states.hyperparameter_states import HyperparameterStates
+
+
+ class LHOPTBaseObserver(GlobalObserver, ABC):
+     """
+     Base class for LHOPT outer step observers to eliminate duplicate code.
+     """
+
+     def __init__(
+         self,
+         decay_factor: float = LHOPT_CONSTANTS["DEFAULT_DECAY_FACTOR"],
+         time_window: int = LHOPT_CONSTANTS["DEFAULT_TIME_WINDOW"],
+         **kwargs,
+     ) -> None:
+         """
+         :param decay_factor: Decay factor for the CDF calculation, typically one of [1, 2.5, 5, 10, 20].
+         :param time_window: Number of time steps to consider for the CDF calculation.
+         :param kwargs: Other observation keyword arguments.
+         """
+         super().__init__(**kwargs)
+         self.decay_factor = max(0.0, decay_factor)
+         self.time_window = max(1, time_window)
+
+         # Store time series data for CDF calculation
+         self._time_series: list[tuple[float, float]] = []  # (time, value) pairs
+         self._current_time: float = 0.0
+
+     @property
+     def can_standardize(self) -> bool:
+         """
+         This observer performs its own CDF normalization, so there is no need to standardize.
+         :return: Whether the observation can be standardized.
+         """
+         return False
+
+     def _get_observation_format(self) -> StatisticStorageTypes:
+         """
+         :return: Format the observation returns data in. Must be one of the StatisticStorageTypes
+             enumeration class.
+         """
+         return StatisticStorageTypes.VECTOR
+
+     def _compute_cdf_feature(self, value: float) -> float:
+         """
+         Compute the CDF feature for the given value. The observed value is added to the time series
+         as part of this call.
+         :param value: The value to compute the CDF feature for.
+         :return: CDF feature value.
+         """
+         return compute_cdf_feature(value, self._time_series, self.decay_factor, self._current_time, self.time_window)
+
+     def _update_time(self) -> None:
+         """Update the current time counter."""
+         self._current_time += 1.0
+
+     def get_required_trackers(self) -> dict[str, dict[str, Any] | None]:
+         """
+         :return: Dictionary mapping statistic tracker class names to kwargs for the class or None if no kwargs are
+             needed.
+         """
+         return {}
+
+     def reset(self) -> None:
+         """Reset the observer by clearing the time series."""
+         self._time_series = []
+         self._current_time = 0.0
+
+     @abstractmethod
+     def _observe(
+         self,
+         observation_inputs: ObservationInputs,
+         hyperparameter_states: HyperparameterStates,
+         tracked_statistics: dict[str, dict[str, float | TensorStatistics]],
+         action_taken: float | int | None,
+     ) -> float | int | list[int | float] | TensorStatistics:
+         """
+         :param observation_inputs: Observation input metrics not calculated with statistic trackers.
+         :param hyperparameter_states: HyperparameterStates that manages the hyperparameters.
+         :param tracked_statistics: Dictionary mapping statistic tracker class names to dictionaries mapping module
+             names to floats or TensorStatistic models.
+         :param action_taken: Action taken by the agent this class instance is assigned to.
+         """
+         raise NotImplementedError
+
+     def _compute_log_ratio(self, numerator: float, denominator: float) -> float:
+         """
+         Compute the log ratio.
+
+         :param numerator: Numerator value.
+         :param denominator: Denominator value.
+         :return: Log ratio value.
+         """
+         # Calculate the ratio of numerator to denominator
+
+         if denominator <= LHOPT_CONSTANTS["ZERO_DIVISION_TOLERANCE"]:
+             return 0.0
+
+         ratio = numerator / denominator
+
+         if ratio <= 0:
+             return 0.0
+
+         return math.log(ratio)
+
+
+ class LHOPTCheckpointBaseObserver(GlobalObserver, ABC):
+     """
+     Base class for checkpoint-based observers to eliminate duplicate code.
+     """
+
+     def __init__(self, checkpoint_interval: int = LHOPT_CONSTANTS["DEFAULT_CHECKPOINT_INTERVAL"], **kwargs) -> None:
+         """
+         :param checkpoint_interval: How often to create checkpoints (in outer model steps).
+         :param kwargs: Miscellaneous keyword arguments.
+         """
+         super().__init__(**kwargs)
+         self.checkpoint_interval = checkpoint_interval
+         self._history: list[float] = []
+         self.last_value: float | None = None
+
+     @property
+     def can_standardize(self) -> bool:
+         """
+         This observer performs its own CDF normalization, so there is no need to standardize.
+         :return: Whether the observation can be standardized.
+         """
+         return False
+
+     def _get_observation_format(self) -> StatisticStorageTypes:
+         """
+         :return: Format the observation returns data in.
+         """
+         return StatisticStorageTypes.FLOAT
+
+     def _update_history(self, value: float) -> None:
+         """
+         Update the history with a new value and maintain the sliding window.
+
+         :param value: The new value to add to the history.
+         """
+         self._history.append(value)
+
+         # Keep only the last checkpoint_interval values for the sliding window
+         if len(self._history) > self.checkpoint_interval:
+             self._history = self._history[-self.checkpoint_interval :]
+
+     def _should_create_checkpoint(self) -> bool:
+         """
+         Check whether a checkpoint should be created.
+
+         :return: True if a checkpoint should be created, False otherwise.
+         """
+         return len(self._history) >= self.checkpoint_interval
+
+     def _cold_start(self, value: float) -> None:
+         """
+         Handle a cold start by setting the last value if it is not already set.
+
+         :param value: The value to use as the last value on a cold start.
+         """
+         if self.last_value is None:
+             self.last_value = value
+
+     def get_required_trackers(self) -> dict[str, dict[str, Any] | None]:
+         """
+         :return: Dictionary mapping statistic tracker class names to kwargs for the class or None if no kwargs are
+             needed.
+         """
+         return {}
+
+     def reset(self) -> None:
+         """Reset the observer by clearing the history."""
+         self._history = []
+         self.last_value = None
+
+     @abstractmethod
+     def _observe(
+         self,
+         observation_inputs: ObservationInputs,
+         hyperparameter_states: HyperparameterStates,
+         tracked_statistics: dict[str, dict[str, float | TensorStatistics]],
+         action_taken: float | int | None,
+     ) -> float | int | list[int | float] | TensorStatistics:
+         """
+         :param observation_inputs: Observation input metrics not calculated with statistic trackers.
+         :param hyperparameter_states: HyperparameterStates that manages the hyperparameters.
+         :param tracked_statistics: Dictionary mapping statistic tracker class names to dictionaries mapping module
+             names to floats or TensorStatistic models.
+         :param action_taken: Action taken by the agent this class instance is assigned to.
+         """
+         raise NotImplementedError
+
+     def _compute_log_ratio(self, numerator: float, denominator: float) -> float:
+         """
+         Compute the log ratio.
+
+         :param numerator: Numerator value.
+         :param denominator: Denominator value.
+         :return: Log ratio value.
+         """
+         # Calculate the ratio of numerator to denominator
+
+         if denominator <= LHOPT_CONSTANTS["ZERO_DIVISION_TOLERANCE"]:
+             return 0.0
+
+         ratio = numerator / denominator
+
+         if ratio <= 0:
+             return 0.0
+
+         return math.log(ratio)
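
Both base classes carry an identical `_compute_log_ratio` helper. A self-contained restatement of its contract (the standalone names here are invented for illustration; the behaviour is copied from the hunk above):

    import math

    ZERO_DIVISION_TOLERANCE = 1e-8  # value of LHOPT_CONSTANTS["ZERO_DIVISION_TOLERANCE"]

    def compute_log_ratio(numerator: float, denominator: float) -> float:
        # Near-zero denominators and non-positive ratios collapse to 0.0
        # instead of raising, keeping the observation vector finite.
        if denominator <= ZERO_DIVISION_TOLERANCE:
            return 0.0
        ratio = numerator / denominator
        if ratio <= 0:
            return 0.0
        return math.log(ratio)

    assert compute_log_ratio(1.0, 0.0) == 0.0                  # guarded division by zero
    assert abs(compute_log_ratio(math.e, 1.0) - 1.0) < 1e-12   # ordinary log ratio
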
--- /dev/null
+++ libinephany-0.16.5/libinephany/observations/observers/global_observers/constants.py
@@ -0,0 +1,39 @@
+ # ======================================================================================================================
+ #
+ # CONSTANTS
+ #
+ # ======================================================================================================================
+
+ from typing import TypedDict
+
+
+ class LHOPTConstants(TypedDict):
+     IS_NAN: float
+     NOT_NAN: float
+     IS_INF: float
+     NOT_INF: float
+     TANH_BOUND: float
+     DEFAULT_DECAY_FACTOR: float
+     DEFAULT_TIME_WINDOW: int
+     DEFAULT_CHECKPOINT_INTERVAL: int
+     DEFAULT_PERCENTILE: float
+     ZERO_DIVISION_TOLERANCE: float
+     DEFAULT_SAMPLE_FREQUENCY: int
+     DEFAULT_VARIANCE_THRESHOLD: float
+
+
+ # Create the constants instance
+ LHOPT_CONSTANTS: LHOPTConstants = LHOPTConstants(
+     IS_NAN=1.0,
+     NOT_NAN=0.0,
+     IS_INF=1.0,
+     NOT_INF=0.0,
+     TANH_BOUND=10.0,
+     DEFAULT_DECAY_FACTOR=1.25,
+     DEFAULT_TIME_WINDOW=32,
+     DEFAULT_CHECKPOINT_INTERVAL=100,
+     DEFAULT_PERCENTILE=0.6,
+     ZERO_DIVISION_TOLERANCE=1e-8,
+     DEFAULT_SAMPLE_FREQUENCY=4,
+     DEFAULT_VARIANCE_THRESHOLD=1e-6,
+ )
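
Because `LHOPTConstants` is a `TypedDict`, `LHOPT_CONSTANTS` is a plain dict at runtime while key access stays type-checked. A small sketch of how the sibling modules in this package consume it:

    from libinephany.observations.observers.global_observers.constants import LHOPT_CONSTANTS

    decay_factor = LHOPT_CONSTANTS["DEFAULT_DECAY_FACTOR"]  # 1.25, typed as float
    time_window = LHOPT_CONSTANTS["DEFAULT_TIME_WINDOW"]    # 32, typed as int

    # TypedDict adds no runtime wrapper; type checkers flag unknown keys statically.
    assert isinstance(LHOPT_CONSTANTS, dict)
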
--- /dev/null
+++ libinephany-0.16.5/libinephany/observations/observers/global_observers/gradient_observers.py
@@ -0,0 +1,193 @@
+ # ======================================================================================================================
+ #
+ # GRADIENT OBSERVERS
+ #
+ # ======================================================================================================================
+
+ from typing import Any
+
+ from libinephany.observations import observation_utils, statistic_trackers
+ from libinephany.observations.observation_utils import StatisticStorageTypes
+ from libinephany.observations.observers.base_observers import GlobalObserver
+ from libinephany.observations.observers.global_observers.base_classes import LHOPTBaseObserver
+ from libinephany.observations.observers.global_observers.constants import LHOPT_CONSTANTS
+ from libinephany.pydantic_models.schemas.observation_models import ObservationInputs
+ from libinephany.pydantic_models.schemas.tensor_statistics import TensorStatistics
+ from libinephany.pydantic_models.states.hyperparameter_states import HyperparameterStates
+
+
+ class GlobalFirstOrderGradients(GlobalObserver):
+
+     def _get_observation_format(self) -> StatisticStorageTypes:
+         """
+         :return: Format the observation returns data in. Must be one of the enum attributes in the
+             StatisticStorageTypes enumeration class.
+         """
+
+         return StatisticStorageTypes.TENSOR_STATISTICS
+
+     def _observe(
+         self,
+         observation_inputs: ObservationInputs,
+         hyperparameter_states: HyperparameterStates,
+         tracked_statistics: dict[str, dict[str, float | TensorStatistics]],
+         action_taken: float | int | None,
+     ) -> float | int | list[int | float] | TensorStatistics:
+         """
+         :param observation_inputs: Observation input metrics not calculated with statistic trackers.
+         :param hyperparameter_states: HyperparameterStates that manages the hyperparameters.
+         :param tracked_statistics: Dictionary mapping statistic tracker class names to dictionaries mapping module
+             names to floats or TensorStatistic models.
+         :param action_taken: Action taken by the agent this class instance is assigned to.
+         :return: Single float/int, list of floats/ints or TensorStatistics model to add to the observation vector.
+         """
+
+         statistics = tracked_statistics[statistic_trackers.FirstOrderGradients.__name__]
+
+         return observation_utils.average_tensor_statistics(tensor_statistics=list(statistics.values()))  # type: ignore
+
+     def get_required_trackers(self) -> dict[str, dict[str, Any] | None]:
+         """
+         :return: Dictionary mapping statistic tracker class names to kwargs for the class or None if no kwargs are
+             needed.
+         """
+
+         return {statistic_trackers.FirstOrderGradients.__name__: dict(skip_statistics=self.skip_statistics)}
+
+
+ class GlobalSecondOrderGradients(GlobalObserver):
+
+     def __init__(
+         self,
+         *,
+         compute_hessian_diagonal: bool = False,
+         **kwargs,
+     ) -> None:
+         """
+         :param compute_hessian_diagonal: Whether to compute the Hessian diagonal to determine second order gradients
+             or use the squared first order gradients as approximations in the same way Adam does.
+         :param kwargs: Miscellaneous keyword arguments.
+         """
+
+         super().__init__(**kwargs)
+
+         self.compute_hessian_diagonal = compute_hessian_diagonal
+
+     def _get_observation_format(self) -> StatisticStorageTypes:
+         """
+         :return: Format the observation returns data in. Must be one of the enum attributes in the
+             StatisticStorageTypes enumeration class.
+         """
+
+         return StatisticStorageTypes.TENSOR_STATISTICS
+
+     def _observe(
+         self,
+         observation_inputs: ObservationInputs,
+         hyperparameter_states: HyperparameterStates,
+         tracked_statistics: dict[str, dict[str, float | TensorStatistics]],
+         action_taken: float | int | None,
+     ) -> float | int | list[int | float] | TensorStatistics:
+         """
+         :param observation_inputs: Observation input metrics not calculated with statistic trackers.
+         :param hyperparameter_states: HyperparameterStates that manages the hyperparameters.
+         :param tracked_statistics: Dictionary mapping statistic tracker class names to dictionaries mapping module
+             names to floats or TensorStatistic models.
+         :param action_taken: Action taken by the agent this class instance is assigned to.
+         :return: Single float/int, list of floats/ints or TensorStatistics model to add to the observation vector.
+         """
+
+         statistics = tracked_statistics[statistic_trackers.SecondOrderGradients.__name__]
+
+         return observation_utils.average_tensor_statistics(tensor_statistics=list(statistics.values()))  # type: ignore
+
+     def get_required_trackers(self) -> dict[str, dict[str, Any] | None]:
+         """
+         :return: Dictionary mapping statistic tracker class names to kwargs for the class or None if no kwargs are
+             needed.
+         """
+
+         return {
+             statistic_trackers.SecondOrderGradients.__name__: dict(
+                 skip_statistics=self.skip_statistics, compute_hessian_diagonal=self.compute_hessian_diagonal
+             )
+         }
+
+
+ class LHOPTGradientVarianceFraction(LHOPTBaseObserver):
+     """
+     This is a global observer from the OpenAI paper "Learning to Optimize with Reinforcement Learning"
+     https://arxiv.org/abs/2305.18291.
+
+     It returns two-dimensional observations: [raw_value, cdf_feature] for gradient variance fraction values.
+     """
+
+     def __init__(
+         self,
+         *,
+         variance_threshold: float = LHOPT_CONSTANTS["DEFAULT_VARIANCE_THRESHOLD"],
+         **kwargs,
+     ) -> None:
+         """
+         :param variance_threshold: Threshold for the variance comparison in the gradient variance fraction calculation.
+         :param kwargs: Other observation keyword arguments.
+         """
+         super().__init__(**kwargs)
+         self.variance_threshold = variance_threshold
+
+     @property
+     def can_standardize(self) -> bool:
+         """
+         This observer performs its own CDF normalization, so there is no need to standardize.
+         :return: Whether the observation can be standardized.
+         """
+         return False
+
+     def _get_observation_format(self) -> StatisticStorageTypes:
+         """
+         :return: Format the observation returns data in. Must be one of the enum attributes in the
+             StatisticStorageTypes enumeration class.
+         """
+         return StatisticStorageTypes.VECTOR
+
+     @property
+     def vector_length(self) -> int:
+         """
+         :return: Length of the vector returned by this observation if it returns a vector.
+         """
+         return 2  # [raw_value, cdf_feature]
+
+     def _observe(
+         self,
+         observation_inputs: ObservationInputs,
+         hyperparameter_states: HyperparameterStates,
+         tracked_statistics: dict[str, dict[str, float | TensorStatistics]],
+         action_taken: float | int | None,
+     ) -> float | int | list[int | float] | TensorStatistics:
+         """
+         :param observation_inputs: Observation input metrics not calculated with statistic trackers.
+         :param hyperparameter_states: HyperparameterStates that manages the hyperparameters.
+         :param tracked_statistics: Dictionary mapping statistic tracker class names to dictionaries mapping module
+             names to floats or TensorStatistic models.
+         :param action_taken: Action taken by the agent this class instance is assigned to.
+         :return: Single float/int, list of floats/ints or TensorStatistics model to add to the observation vector.
+         """
+         if statistic_trackers.GradientVarianceFraction.__name__ not in tracked_statistics:
+             return [0.0, 0.0]
+
+         raw_value = list(tracked_statistics[statistic_trackers.GradientVarianceFraction.__name__].values())[0]  # type: ignore[list-item]
+
+         cdf_feature = self._compute_cdf_feature(raw_value)  # type: ignore[arg-type]
+         self._update_time()
+
+         return [raw_value, cdf_feature]  # type: ignore[list-item]
+
+     def get_required_trackers(self) -> dict[str, dict[str, Any] | None]:
+         """
+         :return: Dictionary mapping statistic tracker class names to kwargs for the class or None if no kwargs are
+             needed.
+         """
+
+         return {
+             statistic_trackers.GradientVarianceFraction.__name__: dict(variance_threshold=self.variance_threshold),
+         }
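
A hedged sketch of the `LHOPTGradientVarianceFraction` contract, assuming the inherited `LHOPTBaseObserver`/`GlobalObserver` constructors need no keyword arguments beyond those shown above:

    from libinephany.observations.observers.global_observers import LHOPTGradientVarianceFraction

    observer = LHOPTGradientVarianceFraction(variance_threshold=1e-6)

    assert observer.vector_length == 2        # emits [raw_value, cdf_feature]
    assert observer.can_standardize is False  # the CDF feature is already normalized
    assert observer.get_required_trackers() == {
        "GradientVarianceFraction": dict(variance_threshold=1e-6)
    }

    # When the GradientVarianceFraction tracker is absent from tracked_statistics,
    # _observe falls back to the neutral observation [0.0, 0.0].
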