libinephany 0.16.3__tar.gz → 0.16.5__tar.gz
This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
- libinephany-0.16.5/CODE_VERSION.cfg +1 -0
- {libinephany-0.16.3/libinephany.egg-info → libinephany-0.16.5}/PKG-INFO +1 -1
- {libinephany-0.16.3 → libinephany-0.16.5}/libinephany/observations/observation_utils.py +2 -14
- libinephany-0.16.5/libinephany/observations/observers/global_observers/__init__.py +69 -0
- libinephany-0.16.5/libinephany/observations/observers/global_observers/base_classes.py +224 -0
- libinephany-0.16.5/libinephany/observations/observers/global_observers/constants.py +39 -0
- libinephany-0.16.5/libinephany/observations/observers/global_observers/gradient_observers.py +193 -0
- libinephany-0.16.5/libinephany/observations/observers/global_observers/hyperparameter_observers.py +397 -0
- libinephany-0.16.5/libinephany/observations/observers/global_observers/loss_observers.py +464 -0
- libinephany-0.16.5/libinephany/observations/observers/global_observers/model_observers.py +469 -0
- libinephany-0.16.5/libinephany/observations/observers/global_observers/progress_observers.py +142 -0
- {libinephany-0.16.3 → libinephany-0.16.5}/libinephany/observations/observers/local_observers.py +88 -0
- {libinephany-0.16.3 → libinephany-0.16.5}/libinephany/observations/observers/observer_containers.py +3 -1
- {libinephany-0.16.3 → libinephany-0.16.5}/libinephany/observations/statistic_trackers.py +75 -0
- {libinephany-0.16.3 → libinephany-0.16.5}/libinephany/pydantic_models/states/hyperparameter_states.py +17 -0
- {libinephany-0.16.3 → libinephany-0.16.5/libinephany.egg-info}/PKG-INFO +1 -1
- {libinephany-0.16.3 → libinephany-0.16.5}/libinephany.egg-info/SOURCES.txt +8 -1
- libinephany-0.16.3/CODE_VERSION.cfg +0 -1
- libinephany-0.16.3/libinephany/observations/observers/global_observers.py +0 -991
- {libinephany-0.16.3 → libinephany-0.16.5}/LICENSE +0 -0
- {libinephany-0.16.3 → libinephany-0.16.5}/MANIFEST.in +0 -0
- {libinephany-0.16.3 → libinephany-0.16.5}/README.md +0 -0
- {libinephany-0.16.3 → libinephany-0.16.5}/libinephany/__init__.py +0 -0
- {libinephany-0.16.3 → libinephany-0.16.5}/libinephany/aws/__init__.py +0 -0
- {libinephany-0.16.3 → libinephany-0.16.5}/libinephany/aws/s3_functions.py +0 -0
- {libinephany-0.16.3 → libinephany-0.16.5}/libinephany/observations/__init__.py +0 -0
- {libinephany-0.16.3 → libinephany-0.16.5}/libinephany/observations/observer_pipeline.py +0 -0
- {libinephany-0.16.3 → libinephany-0.16.5}/libinephany/observations/observers/__init__.py +0 -0
- {libinephany-0.16.3 → libinephany-0.16.5}/libinephany/observations/observers/base_observers.py +0 -0
- {libinephany-0.16.3 → libinephany-0.16.5}/libinephany/observations/pipeline_coordinator.py +0 -0
- {libinephany-0.16.3 → libinephany-0.16.5}/libinephany/observations/post_processors/__init__.py +0 -0
- {libinephany-0.16.3 → libinephany-0.16.5}/libinephany/observations/post_processors/postprocessors.py +0 -0
- {libinephany-0.16.3 → libinephany-0.16.5}/libinephany/observations/statistic_manager.py +0 -0
- {libinephany-0.16.3 → libinephany-0.16.5}/libinephany/pydantic_models/__init__.py +0 -0
- {libinephany-0.16.3 → libinephany-0.16.5}/libinephany/pydantic_models/configs/__init__.py +0 -0
- {libinephany-0.16.3 → libinephany-0.16.5}/libinephany/pydantic_models/configs/hyperparameter_configs.py +0 -0
- {libinephany-0.16.3 → libinephany-0.16.5}/libinephany/pydantic_models/configs/observer_config.py +0 -0
- {libinephany-0.16.3 → libinephany-0.16.5}/libinephany/pydantic_models/configs/outer_model_config.py +0 -0
- {libinephany-0.16.3 → libinephany-0.16.5}/libinephany/pydantic_models/schemas/__init__.py +0 -0
- {libinephany-0.16.3 → libinephany-0.16.5}/libinephany/pydantic_models/schemas/agent_info.py +0 -0
- {libinephany-0.16.3 → libinephany-0.16.5}/libinephany/pydantic_models/schemas/inner_task_profile.py +0 -0
- {libinephany-0.16.3 → libinephany-0.16.5}/libinephany/pydantic_models/schemas/observation_models.py +0 -0
- {libinephany-0.16.3 → libinephany-0.16.5}/libinephany/pydantic_models/schemas/request_schemas.py +0 -0
- {libinephany-0.16.3 → libinephany-0.16.5}/libinephany/pydantic_models/schemas/response_schemas.py +0 -0
- {libinephany-0.16.3 → libinephany-0.16.5}/libinephany/pydantic_models/schemas/tensor_statistics.py +0 -0
- {libinephany-0.16.3 → libinephany-0.16.5}/libinephany/pydantic_models/states/__init__.py +0 -0
- {libinephany-0.16.3 → libinephany-0.16.5}/libinephany/utils/__init__.py +0 -0
- {libinephany-0.16.3 → libinephany-0.16.5}/libinephany/utils/agent_utils.py +0 -0
- {libinephany-0.16.3 → libinephany-0.16.5}/libinephany/utils/asyncio_worker.py +0 -0
- {libinephany-0.16.3 → libinephany-0.16.5}/libinephany/utils/backend_statuses.py +0 -0
- {libinephany-0.16.3 → libinephany-0.16.5}/libinephany/utils/constants.py +0 -0
- {libinephany-0.16.3 → libinephany-0.16.5}/libinephany/utils/directory_utils.py +0 -0
- {libinephany-0.16.3 → libinephany-0.16.5}/libinephany/utils/dropout_utils.py +0 -0
- {libinephany-0.16.3 → libinephany-0.16.5}/libinephany/utils/enums.py +0 -0
- {libinephany-0.16.3 → libinephany-0.16.5}/libinephany/utils/error_severities.py +0 -0
- {libinephany-0.16.3 → libinephany-0.16.5}/libinephany/utils/exceptions.py +0 -0
- {libinephany-0.16.3 → libinephany-0.16.5}/libinephany/utils/import_utils.py +0 -0
- {libinephany-0.16.3 → libinephany-0.16.5}/libinephany/utils/optim_utils.py +0 -0
- {libinephany-0.16.3 → libinephany-0.16.5}/libinephany/utils/random_seeds.py +0 -0
- {libinephany-0.16.3 → libinephany-0.16.5}/libinephany/utils/samplers.py +0 -0
- {libinephany-0.16.3 → libinephany-0.16.5}/libinephany/utils/standardizers.py +0 -0
- {libinephany-0.16.3 → libinephany-0.16.5}/libinephany/utils/torch_distributed_utils.py +0 -0
- {libinephany-0.16.3 → libinephany-0.16.5}/libinephany/utils/torch_utils.py +0 -0
- {libinephany-0.16.3 → libinephany-0.16.5}/libinephany/utils/transforms.py +0 -0
- {libinephany-0.16.3 → libinephany-0.16.5}/libinephany/utils/typing.py +0 -0
- {libinephany-0.16.3 → libinephany-0.16.5}/libinephany/web_apps/__init__.py +0 -0
- {libinephany-0.16.3 → libinephany-0.16.5}/libinephany/web_apps/error_logger.py +0 -0
- {libinephany-0.16.3 → libinephany-0.16.5}/libinephany/web_apps/web_app_utils.py +0 -0
- {libinephany-0.16.3 → libinephany-0.16.5}/libinephany.egg-info/dependency_links.txt +0 -0
- {libinephany-0.16.3 → libinephany-0.16.5}/libinephany.egg-info/requires.txt +0 -0
- {libinephany-0.16.3 → libinephany-0.16.5}/libinephany.egg-info/top_level.txt +0 -0
- {libinephany-0.16.3 → libinephany-0.16.5}/pyproject.toml +0 -0
- {libinephany-0.16.3 → libinephany-0.16.5}/setup.cfg +0 -0
libinephany-0.16.5/CODE_VERSION.cfg (new file)

```diff
@@ -0,0 +1 @@
+0.16.5
```
libinephany/observations/observation_utils.py

```diff
@@ -64,20 +64,8 @@ def get_exponential_weighted_average(values: list[int | float]) -> float:
     :param values: List of values to average via EWA.
     :return: EWA of the given values.
     """
-
-    # Check for NaN and infinite values in input
-    valid_values = [float(val) for val in values if not math.isnan(float(val)) and not math.isinf(float(val))]
-
-    if not valid_values:
-        raise ValueError("Cannot compute exponential weighted average on empty list")
-
-    if len(valid_values) == 1:
-        return valid_values[0]
-
-    exp_weighted_average = pd.Series(valid_values).ewm(alpha=0.1).mean().iloc[-1]
+    exp_weighted_average = pd.Series(values).ewm(alpha=0.1).mean().iloc[-1]
     assert isinstance(exp_weighted_average, float)
-    assert not math.isnan(exp_weighted_average)
-    assert not math.isinf(exp_weighted_average)
     return exp_weighted_average
 
 
```
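For context, with pandas' default `adjust=True`, `Series.ewm(alpha=0.1).mean().iloc[-1]` weights a sample `k` steps old by `(1 - alpha) ** k` and normalises by the sum of those weights. A pure-Python reference sketch of that computation (illustrative only, not part of the package):

```python
def ewa_reference(values: list[float], alpha: float = 0.1) -> float:
    """Exponentially weighted average matching ewm(adjust=True) semantics."""
    # The newest sample gets weight 1; a sample k steps older gets (1 - alpha) ** k.
    weights = [(1 - alpha) ** k for k in range(len(values) - 1, -1, -1)]
    return sum(w * v for w, v in zip(weights, values)) / sum(weights)
```

Note that the simplified function now assumes its callers pass non-empty, finite inputs; the remaining `assert` only guards the return type.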
libinephany/observations/observation_utils.py (continued)

```diff
@@ -395,7 +383,7 @@ def _weighted_interval_expectation(
     interval_gradient = (end_value - start_value) / (end_time_point - start_time_point)
     start_exp_time = math.exp(log_decay_factor * start_time_point)
     end_exp_time = math.exp(log_decay_factor * end_time_point)
-    return (1 / log_decay_factor) * (end_value * end_exp_time - start_value * start_exp_time)
+    return (1 / log_decay_factor) * (end_value * end_exp_time - start_value * start_exp_time) - (
         1 / log_decay_factor**2
     ) * interval_gradient * (end_exp_time - start_exp_time)
 
```
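For reference, the added term is the closed form obtained by integrating the linearly interpolated segment against the exponential weight. Writing $\lambda$ for `log_decay_factor`, $g$ for `interval_gradient`, and $v(t) = v_0 + g\,(t - t_0)$ on $[t_0, t_1]$, integration by parts gives

$$\int_{t_0}^{t_1} v(t)\,e^{\lambda t}\,dt = \frac{1}{\lambda}\bigl(v_1 e^{\lambda t_1} - v_0 e^{\lambda t_0}\bigr) - \frac{g}{\lambda^{2}}\bigl(e^{\lambda t_1} - e^{\lambda t_0}\bigr),$$

which matches the new return expression; the old one dropped the $g/\lambda^{2}$ correction term.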
libinephany/observations/observers/global_observers/__init__.py (new file)

```diff
@@ -0,0 +1,69 @@
+# ======================================================================================================================
+#
+# GLOBAL OBSERVERS PACKAGE
+#
+# This package contains all global observer classes used for collecting observations
+# across the entire training process, not specific to individual agents.
+#
+# ======================================================================================================================
+
+
+from .gradient_observers import GlobalFirstOrderGradients, GlobalSecondOrderGradients, LHOPTGradientVarianceFraction
+from .hyperparameter_observers import (
+    InitialHyperparameters,
+    LHOPTHyperparameterRatio,
+    ModelFamilyOneHot,
+    OptimizerTypeOneHot,
+)
+from .loss_observers import (
+    LHOPTLossRatio,
+    LHOPTTrainingLoss,
+    LHOPTValidationLoss,
+    LossRatio,
+    PercentileOfLossAtEachCheckpoint,
+    TrainingLoss,
+    TrainingScore,
+    ValidationLoss,
+    ValidationScore,
+)
+from .model_observers import (
+    GlobalActivations,
+    GlobalLAMBTrustRatio,
+    GlobalParameters,
+    GlobalParameterUpdates,
+    LogRatioOfPreviousAndCurrentParamNormEnvStepObserver,
+    LogRatioOfUpdateAndPreviousParamNormEnvStepObserver,
+    NumberOfLayers,
+    NumberOfParameters,
+)
+from .progress_observers import EpochsCompleted, ProgressAtEachCheckpoint, TrainingProgress
+
+__all__ = [
+    InitialHyperparameters.__name__,
+    LHOPTHyperparameterRatio.__name__,
+    OptimizerTypeOneHot.__name__,
+    ModelFamilyOneHot.__name__,
+    TrainingLoss.__name__,
+    ValidationLoss.__name__,
+    LossRatio.__name__,
+    TrainingScore.__name__,
+    ValidationScore.__name__,
+    GlobalFirstOrderGradients.__name__,
+    GlobalSecondOrderGradients.__name__,
+    LHOPTGradientVarianceFraction.__name__,
+    GlobalActivations.__name__,
+    GlobalParameterUpdates.__name__,
+    GlobalParameters.__name__,
+    GlobalLAMBTrustRatio.__name__,
+    NumberOfParameters.__name__,
+    NumberOfLayers.__name__,
+    LogRatioOfPreviousAndCurrentParamNormEnvStepObserver.__name__,
+    LogRatioOfUpdateAndPreviousParamNormEnvStepObserver.__name__,
+    TrainingProgress.__name__,
+    EpochsCompleted.__name__,
+    ProgressAtEachCheckpoint.__name__,
+    LHOPTTrainingLoss.__name__,
+    LHOPTValidationLoss.__name__,
+    LHOPTLossRatio.__name__,
+    PercentileOfLossAtEachCheckpoint.__name__,
+]
```
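Since the new `__init__.py` re-exports every observer through `__all__`, downstream code can keep importing from the `global_observers` path whether it is the old module or the new package. A quick sketch (names taken from the `__all__` list above; unchanged behaviour of the split classes is assumed):

```python
# Both names resolve at the same import path as before the module split.
from libinephany.observations.observers.global_observers import (
    GlobalFirstOrderGradients,
    LHOPTLossRatio,
)
```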
libinephany/observations/observers/global_observers/base_classes.py (new file)

```diff
@@ -0,0 +1,224 @@
+# ======================================================================================================================
+#
+# BASE CLASSES
+#
+# ======================================================================================================================
+
+import math
+from abc import ABC, abstractmethod
+from typing import Any
+
+from libinephany.observations.observation_utils import StatisticStorageTypes, compute_cdf_feature
+from libinephany.observations.observers.base_observers import GlobalObserver
+from libinephany.observations.observers.global_observers.constants import LHOPT_CONSTANTS
+from libinephany.pydantic_models.schemas.observation_models import ObservationInputs
+from libinephany.pydantic_models.schemas.tensor_statistics import TensorStatistics
+from libinephany.pydantic_models.states.hyperparameter_states import HyperparameterStates
+
+
+class LHOPTBaseObserver(GlobalObserver, ABC):
+    """
+    Base class for LHOPT outer step observers to eliminate duplicate code.
+    """
+
+    def __init__(
+        self,
+        decay_factor: float = LHOPT_CONSTANTS["DEFAULT_DECAY_FACTOR"],
+        time_window: int = LHOPT_CONSTANTS["DEFAULT_TIME_WINDOW"],
+        **kwargs,
+    ) -> None:
+        """
+        :param decay_factor: Decay factor for CDF calculation in [1, 2.5, 5, 10, 20]
+        :param time_window: Number of time steps to consider for CDF calculation
+        :param kwargs: Other observation keyword arguments.
+        """
+        super().__init__(**kwargs)
+        self.decay_factor = max(0.0, decay_factor)
+        self.time_window = max(1, time_window)
+
+        # Store time series data for CDF calculation
+        self._time_series: list[tuple[float, float]] = []  # (time, value) pairs
+        self._current_time: float = 0.0
+
+    @property
+    def can_standardize(self) -> bool:
+        """
+        This observer has its own CDF calculation, no need to standardize.
+        :return: Whether the observation can be standardized.
+        """
+        return False
+
+    def _get_observation_format(self) -> StatisticStorageTypes:
+        """
+        :return: Format the observation returns data in. Must be one of the StatisticStorageTypes
+            enumeration class.
+        """
+        return StatisticStorageTypes.VECTOR
+
+    def _compute_cdf_feature(self, value: float) -> float:
+        """
+        Compute CDF feature for the given value.
+        training loss will be added to the time series after this call.
+        :param value: The value to compute CDF feature for
+        :return: CDF feature value
+        """
+        return compute_cdf_feature(value, self._time_series, self.decay_factor, self._current_time, self.time_window)
+
+    def _update_time(self) -> None:
+        """Update the current time counter."""
+        self._current_time += 1.0
+
+    def get_required_trackers(self) -> dict[str, dict[str, Any] | None]:
+        """
+        :return: Dictionary mapping statistic tracker class names to kwargs for the class or None if no kwargs are
+            needed.
+        """
+        return {}
+
+    def reset(self) -> None:
+        """Reset the observer by clearing the time series."""
+        self._time_series = []
+        self._current_time = 0.0
+
+    @abstractmethod
+    def _observe(
+        self,
+        observation_inputs: ObservationInputs,
+        hyperparameter_states: HyperparameterStates,
+        tracked_statistics: dict[str, dict[str, float | TensorStatistics]],
+        action_taken: float | int | None,
+    ) -> float | int | list[int | float] | TensorStatistics:
+        """
+        :param observation_inputs: Observation input metrics not calculated with statistic trackers.
+        :param hyperparameter_states: HyperparameterStates that manages the hyperparameters.
+        :param tracked_statistics: Dictionary mapping statistic tracker class names to dictionaries mapping module
+        :param action_taken: Action taken by the agent this class instance is assigned to.
+        """
+        raise NotImplementedError
+
+    def _compute_log_ratio(self, numerator: float, denominator: float) -> float:
+        """
+        Compute the log ratio.
+
+        :param numerator: Numerator value
+        :param denominator: Denominator value
+        :return: Log ratio value
+        """
+        # Calculate the ratio of numerator to denominator
+
+        if denominator <= LHOPT_CONSTANTS["ZERO_DIVISION_TOLERANCE"]:
+            return 0.0
+
+        ratio = numerator / denominator
+
+        if ratio <= 0:
+            return 0.0
+
+        return math.log(ratio)
+
+
+class LHOPTCheckpointBaseObserver(GlobalObserver, ABC):
+    """
+    Base class for checkpoint-based observers to eliminate duplicate code.
+    """
+
+    def __init__(self, checkpoint_interval: int = LHOPT_CONSTANTS["DEFAULT_CHECKPOINT_INTERVAL"], **kwargs) -> None:
+        """
+        :param checkpoint_interval: How often to create checkpoints (in outer model steps).
+        :param kwargs: Miscellaneous keyword arguments.
+        """
+        super().__init__(**kwargs)
+        self.checkpoint_interval = checkpoint_interval
+        self._history: list[float] = []
+        self.last_value: float | None = None
+
+    @property
+    def can_standardize(self) -> bool:
+        """
+        This observer has its own CDF calculation, no need to standardize.
+        :return: Whether the observation can be standardized.
+        """
+        return False
+
+    def _get_observation_format(self) -> StatisticStorageTypes:
+        """
+        :return: Format the observation returns data in.
+        """
+        return StatisticStorageTypes.FLOAT
+
+    def _update_history(self, value: float) -> None:
+        """
+        Update the history with a new value and maintain sliding window.
+
+        :param value: The new value to add to history
+        """
+        self._history.append(value)
+
+        # Keep only the last checkpoint_interval values for sliding window
+        if len(self._history) > self.checkpoint_interval:
+            self._history = self._history[-self.checkpoint_interval :]
+
+    def _should_create_checkpoint(self) -> bool:
+        """
+        Check if we should create a checkpoint.
+
+        :return: True if checkpoint should be created, False otherwise
+        """
+        return len(self._history) >= self.checkpoint_interval
+
+    def _cold_start(self, value: float) -> None:
+        """
+        Handle cold start by setting the last value if not already set.
+
+        :param value: The value to set as last value if cold start
+        """
+        if self.last_value is None:
+            self.last_value = value
+
+    def get_required_trackers(self) -> dict[str, dict[str, Any] | None]:
+        """
+        :return: Dictionary mapping statistic tracker class names to kwargs for the class or None if no kwargs are
+            needed.
+        """
+        return {}
+
+    def reset(self) -> None:
+        """Reset the observer by clearing history."""
+        self._history = []
+        self.last_value = None
+
+    @abstractmethod
+    def _observe(
+        self,
+        observation_inputs: ObservationInputs,
+        hyperparameter_states: HyperparameterStates,
+        tracked_statistics: dict[str, dict[str, float | TensorStatistics]],
+        action_taken: float | int | None,
+    ) -> float | int | list[int | float] | TensorStatistics:
+        """
+        :param observation_inputs: Observation input metrics not calculated with statistic trackers.
+        :param hyperparameter_states: HyperparameterStates that manages the hyperparameters.
+        :param tracked_statistics: Dictionary mapping statistic tracker class names to dictionaries mapping module
+        :param action_taken: Action taken by the agent this class instance is assigned to.
+        """
+        raise NotImplementedError
+
+    def _compute_log_ratio(self, numerator: float, denominator: float) -> float:
+        """
+        Compute the log ratio.
+
+        :param numerator: Numerator value
+        :param denominator: Denominator value
+        :return: Log ratio value
+        """
+        # Calculate the ratio of numerator to denominator
+
+        if denominator <= LHOPT_CONSTANTS["ZERO_DIVISION_TOLERANCE"]:
+            return 0.0
+
+        ratio = numerator / denominator
+
+        if ratio <= 0:
+            return 0.0
+
+        return math.log(ratio)
```
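Both base classes leave only `_observe` abstract, so a concrete observer stays small. A hedged sketch of a subclass (the class name and the `training_loss` field on `ObservationInputs` are hypothetical, used purely for illustration; imports are those at the top of `base_classes.py`):

```python
class ExampleLossCDFObserver(LHOPTBaseObserver):
    """Hypothetical observer emitting [raw_value, cdf_feature] for a scalar metric."""

    @property
    def vector_length(self) -> int:
        return 2  # [raw_value, cdf_feature]

    def _observe(
        self,
        observation_inputs: ObservationInputs,
        hyperparameter_states: HyperparameterStates,
        tracked_statistics: dict[str, dict[str, float | TensorStatistics]],
        action_taken: float | int | None,
    ) -> list[int | float]:
        raw_value = observation_inputs.training_loss  # assumed field name
        cdf_feature = self._compute_cdf_feature(raw_value)
        self._update_time()  # advance the observer's internal clock
        return [raw_value, cdf_feature]
```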
libinephany/observations/observers/global_observers/constants.py (new file)

```diff
@@ -0,0 +1,39 @@
+# ======================================================================================================================
+#
+# CONSTANTS
+#
+# ======================================================================================================================
+
+from typing import TypedDict
+
+
+class LHOPTConstants(TypedDict):
+    IS_NAN: float
+    NOT_NAN: float
+    IS_INF: float
+    NOT_INF: float
+    TANH_BOUND: float
+    DEFAULT_DECAY_FACTOR: float
+    DEFAULT_TIME_WINDOW: int
+    DEFAULT_CHECKPOINT_INTERVAL: int
+    DEFAULT_PERCENTILE: float
+    ZERO_DIVISION_TOLERANCE: float
+    DEFAULT_SAMPLE_FREQUENCY: int
+    DEFAULT_VARIANCE_THRESHOLD: float
+
+
+# Create the constants instance
+LHOPT_CONSTANTS: LHOPTConstants = LHOPTConstants(
+    IS_NAN=1.0,
+    NOT_NAN=0.0,
+    IS_INF=1.0,
+    NOT_INF=0.0,
+    TANH_BOUND=10.0,
+    DEFAULT_DECAY_FACTOR=1.25,
+    DEFAULT_TIME_WINDOW=32,
+    DEFAULT_CHECKPOINT_INTERVAL=100,
+    DEFAULT_PERCENTILE=0.6,
+    ZERO_DIVISION_TOLERANCE=1e-8,
+    DEFAULT_SAMPLE_FREQUENCY=4,
+    DEFAULT_VARIANCE_THRESHOLD=1e-6,
+)
```
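`LHOPT_CONSTANTS` is an ordinary dict at runtime; the `TypedDict` declaration exists so that literal-key lookups are statically typed. A small sketch:

```python
from libinephany.observations.observers.global_observers.constants import LHOPT_CONSTANTS

tolerance: float = LHOPT_CONSTANTS["ZERO_DIVISION_TOLERANCE"]  # typed as float
window: int = LHOPT_CONSTANTS["DEFAULT_TIME_WINDOW"]  # typed as int
# LHOPT_CONSTANTS["MISSPELLED_KEY"] would be rejected by a type checker such as mypy.
```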
libinephany/observations/observers/global_observers/gradient_observers.py (new file)

```diff
@@ -0,0 +1,193 @@
+# ======================================================================================================================
+#
+# GRADIENT OBSERVERS
+#
+# ======================================================================================================================
+
+from typing import Any
+
+from libinephany.observations import observation_utils, statistic_trackers
+from libinephany.observations.observation_utils import StatisticStorageTypes
+from libinephany.observations.observers.base_observers import GlobalObserver
+from libinephany.observations.observers.global_observers.base_classes import LHOPTBaseObserver
+from libinephany.observations.observers.global_observers.constants import LHOPT_CONSTANTS
+from libinephany.pydantic_models.schemas.observation_models import ObservationInputs
+from libinephany.pydantic_models.schemas.tensor_statistics import TensorStatistics
+from libinephany.pydantic_models.states.hyperparameter_states import HyperparameterStates
+
+
+class GlobalFirstOrderGradients(GlobalObserver):
+
+    def _get_observation_format(self) -> StatisticStorageTypes:
+        """
+        :return: Format the observation returns data in. Must be one of the enum attributes in the StatisticStorageTypes
+            enumeration class.
+        """
+
+        return StatisticStorageTypes.TENSOR_STATISTICS
+
+    def _observe(
+        self,
+        observation_inputs: ObservationInputs,
+        hyperparameter_states: HyperparameterStates,
+        tracked_statistics: dict[str, dict[str, float | TensorStatistics]],
+        action_taken: float | int | None,
+    ) -> float | int | list[int | float] | TensorStatistics:
+        """
+        :param observation_inputs: Observation input metrics not calculated with statistic trackers.
+        :param hyperparameter_states: HyperparameterStates that manages the hyperparameters.
+        :param tracked_statistics: Dictionary mapping statistic tracker class names to dictionaries mapping module
+            names to floats or TensorStatistic models.
+        :param action_taken: Action taken by the agent this class instance is assigned to.
+        :return: Single float/int, list of floats/ints or TensorStatistics model to add to the observation vector.
+        """
+
+        statistics = tracked_statistics[statistic_trackers.FirstOrderGradients.__name__]
+
+        return observation_utils.average_tensor_statistics(tensor_statistics=list(statistics.values()))  # type: ignore
+
+    def get_required_trackers(self) -> dict[str, dict[str, Any] | None]:
+        """
+        :return: Dictionary mapping statistic tracker class names to kwargs for the class or None if no kwargs are
+            needed.
+        """
+
+        return {statistic_trackers.FirstOrderGradients.__name__: dict(skip_statistics=self.skip_statistics)}
+
+
+class GlobalSecondOrderGradients(GlobalObserver):
+
+    def __init__(
+        self,
+        *,
+        compute_hessian_diagonal: bool = False,
+        **kwargs,
+    ) -> None:
+        """
+        :param compute_hessian_diagonal: Whether to compute the Hessian diagonal to determine second order gradients
+            or use the squared first order gradients as approximations in the same way Adam does.
+        :param kwargs: Miscellaneous keyword arguments.
+        """
+
+        super().__init__(**kwargs)
+
+        self.compute_hessian_diagonal = compute_hessian_diagonal
+
+    def _get_observation_format(self) -> StatisticStorageTypes:
+        """
+        :return: Format the observation returns data in. Must be one of the enum attributes in the StatisticStorageTypes
+            enumeration class.
+        """
+
+        return StatisticStorageTypes.TENSOR_STATISTICS
+
+    def _observe(
+        self,
+        observation_inputs: ObservationInputs,
+        hyperparameter_states: HyperparameterStates,
+        tracked_statistics: dict[str, dict[str, float | TensorStatistics]],
+        action_taken: float | int | None,
+    ) -> float | int | list[int | float] | TensorStatistics:
+        """
+        :param observation_inputs: Observation input metrics not calculated with statistic trackers.
+        :param hyperparameter_states: HyperparameterStates that manages the hyperparameters.
+        :param tracked_statistics: Dictionary mapping statistic tracker class names to dictionaries mapping module
+            names to floats or TensorStatistic models.
+        :param action_taken: Action taken by the agent this class instance is assigned to.
+        :return: Single float/int, list of floats/ints or TensorStatistics model to add to the observation vector.
+        """
+
+        statistics = tracked_statistics[statistic_trackers.SecondOrderGradients.__name__]
+
+        return observation_utils.average_tensor_statistics(tensor_statistics=list(statistics.values()))  # type: ignore
+
+    def get_required_trackers(self) -> dict[str, dict[str, Any] | None]:
+        """
+        :return: Dictionary mapping statistic tracker class names to kwargs for the class or None if no kwargs are
+            needed.
+        """
+
+        return {
+            statistic_trackers.SecondOrderGradients.__name__: dict(
+                skip_statistics=self.skip_statistics, compute_hessian_diagonal=self.compute_hessian_diagonal
+            )
+        }
+
+
+class LHOPTGradientVarianceFraction(LHOPTBaseObserver):
+    """
+    This is a global observer from the OpenAI paper "Learning to Optimize with Reinforcement Learning"
+    https://arxiv.org/abs/2305.18291.
+
+    It returns two-dimensional observations: [raw_value, cdf_feature] for gradient variance fraction values.
+    """
+
+    def __init__(
+        self,
+        *,
+        variance_threshold: float = LHOPT_CONSTANTS["DEFAULT_VARIANCE_THRESHOLD"],
+        **kwargs,
+    ) -> None:
+        """
+        :param variance_threshold: Threshold for variance comparison in gradient variance fraction calculation
+        :param kwargs: Other observation keyword arguments.
+        """
+        super().__init__(**kwargs)
+        self.variance_threshold = variance_threshold
+
+    @property
+    def can_standardize(self) -> bool:
+        """
+        This observer has its own CDF calculation, no need to standardize.
+        :return: Whether the observation can be standardized.
+        """
+        return False
+
+    def _get_observation_format(self) -> StatisticStorageTypes:
+        """
+        :return: Format the observation returns data in. Must be one of the enum attributes in the StatisticStorageTypes
+            enumeration class.
+        """
+        return StatisticStorageTypes.VECTOR
+
+    @property
+    def vector_length(self) -> int:
+        """
+        :return: Length of the vector returned by this observation if it returns a vector.
+        """
+        return 2  # [raw_value, cdf_feature]
+
+    def _observe(
+        self,
+        observation_inputs: ObservationInputs,
+        hyperparameter_states: HyperparameterStates,
+        tracked_statistics: dict[str, dict[str, float | TensorStatistics]],
+        action_taken: float | int | None,
+    ) -> float | int | list[int | float] | TensorStatistics:
+        """
+        :param observation_inputs: Observation input metrics not calculated with statistic trackers.
+        :param hyperparameter_states: HyperparameterStates that manages the hyperparameters.
+        :param tracked_statistics: Dictionary mapping statistic tracker class names to dictionaries mapping module
+            names to floats or TensorStatistic models.
+        :param action_taken: Action taken by the agent this class instance is assigned to.
+        :return: Single float/int, list of floats/ints or TensorStatistics model to add to the observation vector.
+        """
+        if statistic_trackers.GradientVarianceFraction.__name__ not in tracked_statistics:
+            return [0.0, 0.0]
+
+        raw_value = list(tracked_statistics[statistic_trackers.GradientVarianceFraction.__name__].values())[0]  # type: ignore[list-item]
+
+        cdf_feature = self._compute_cdf_feature(raw_value)  # type: ignore[arg-type]
+        self._update_time()
+
+        return [raw_value, cdf_feature]  # type: ignore[list-item]
+
+    def get_required_trackers(self) -> dict[str, dict[str, Any] | None]:
+        """
+        :return: Dictionary mapping statistic tracker class names to kwargs for the class or None if no kwargs are
+            needed.
+        """
+
+        return {
+            statistic_trackers.GradientVarianceFraction.__name__: dict(variance_threshold=self.variance_threshold),
+        }
```
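Taken together, `get_required_trackers` and `_observe` form a two-phase contract with the observation pipeline: an observer first declares which statistic trackers it needs, keyed by tracker class name, and later receives the tracked statistics back under the same keys. A sketch of the first phase (assuming the base `GlobalObserver` needs no extra constructor arguments; the commented dict shapes are inferred from the code above, not captured output):

```python
observer = LHOPTGradientVarianceFraction(variance_threshold=1e-6)

required = observer.get_required_trackers()
# -> {"GradientVarianceFraction": {"variance_threshold": 1e-06}}

# The pipeline later passes statistics back under the same key, e.g.
#   tracked_statistics = {"GradientVarianceFraction": {"<module name>": 0.42}}
# and _observe returns [raw_value, cdf_feature], falling back to [0.0, 0.0]
# when the tracker's key is absent.
```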