libinephany 0.18.0__py3-none-any.whl → 0.19.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- libinephany/observations/observation_utils.py +19 -2
- libinephany/observations/observers/global_observers/__init__.py +19 -1
- libinephany/observations/observers/global_observers/constants.py +2 -0
- libinephany/observations/observers/global_observers/gradient_observers.py +319 -1
- libinephany/observations/observers/global_observers/model_observers.py +219 -3
- libinephany/observations/observers/local_observers.py +127 -1
- libinephany/observations/statistic_trackers.py +595 -0
- libinephany/utils/constants.py +3 -3
- {libinephany-0.18.0.dist-info → libinephany-0.19.0.dist-info}/METADATA +1 -1
- {libinephany-0.18.0.dist-info → libinephany-0.19.0.dist-info}/RECORD +13 -13
- {libinephany-0.18.0.dist-info → libinephany-0.19.0.dist-info}/WHEEL +0 -0
- {libinephany-0.18.0.dist-info → libinephany-0.19.0.dist-info}/licenses/LICENSE +0 -0
- {libinephany-0.18.0.dist-info → libinephany-0.19.0.dist-info}/top_level.txt +0 -0
@@ -4,11 +4,13 @@
|
|
4
4
|
#
|
5
5
|
# ======================================================================================================================
|
6
6
|
|
7
|
+
import math
|
7
8
|
from typing import Any
|
8
9
|
|
9
10
|
from libinephany.observations import observation_utils, statistic_trackers
|
10
|
-
from libinephany.observations.observation_utils import StatisticStorageTypes
|
11
|
+
from libinephany.observations.observation_utils import StatisticStorageTypes, compute_cdf_feature
|
11
12
|
from libinephany.observations.observers.base_observers import LocalObserver
|
13
|
+
from libinephany.observations.observers.global_observers.constants import LHOPT_CONSTANTS
|
12
14
|
from libinephany.pydantic_models.schemas.observation_models import ObservationInputs
|
13
15
|
from libinephany.pydantic_models.schemas.tensor_statistics import TensorStatistics
|
14
16
|
from libinephany.pydantic_models.states.hyperparameter_states import HyperparameterStates
|
@@ -1069,3 +1071,127 @@ class PercentageDepth(LocalObserver):
|
|
1069
1071
|
"""
|
1070
1072
|
|
1071
1073
|
return {}
|
1074
|
+
|
1075
|
+
|
1076
|
+
class LogOfNoiseScaleObserver(LocalObserver):
    """
    Local observer emitting the vector ``[log_noise_scale, cdf_feature]`` for its parameter group.

    ``log_noise_scale`` is the first value reported by the ``LogOfNoiseScaleStatistics`` tracker plus
    ``log(batch_size / learning_rate)``. ``cdf_feature`` is computed from that value and this observer's own
    (time, value) history via ``compute_cdf_feature``.
    """

    def __init__(
        self,
        *,
        decay_factor: float = LHOPT_CONSTANTS["DEFAULT_DECAY_FACTOR"],
        time_window: int = LHOPT_CONSTANTS["DEFAULT_TIME_WINDOW"],
        skip_statistics: list[str] | None = None,
        **kwargs,
    ) -> None:
        """
        :param decay_factor: Decay factor for the CDF calculation, intended to be one of [1, 2.5, 5, 10, 20].
        Negative values are clamped to 0.0.
        :param time_window: Number of time steps to consider for the CDF calculation. Values below 1 are clamped to 1.
        :param skip_statistics: Optional list of strings (presumably statistic names), forwarded unchanged as the
        ``skip_statistics`` kwarg of the LogOfNoiseScaleStatistics tracker requested in get_required_trackers. How the
        entries are interpreted is up to that tracker.
        :param kwargs: Miscellaneous keyword arguments forwarded to LocalObserver.
        """

        super().__init__(**kwargs)

        self.skip_statistics = skip_statistics
        self.decay_factor = max(0.0, decay_factor)
        self.time_window = max(1, time_window)

        # (time, value) pairs handed to compute_cdf_feature. Nothing in this class appends to the list itself, so
        # presumably compute_cdf_feature records each new value -- TODO confirm against observation_utils.
        self._time_series: list[tuple[float, float]] = []
        # Step counter passed to compute_cdf_feature as the current time; advanced by 1.0 after every observation.
        self._current_time: float = 0.0

    def _get_observation_format(self) -> StatisticStorageTypes:
        """
        :return: Format the observation returns data in. Must be one of the enum attributes in the StatisticStorageTypes
        enumeration class.
        """

        return StatisticStorageTypes.VECTOR

    @property
    def can_standardize(self) -> bool:
        """
        :return: Whether the observation can be standardized.
        """

        return False

    @property
    def can_inform(self) -> bool:
        """
        :return: Whether observations from the observer can be used in the agent info dictionary.
        """

        return False

    def _update_time(self) -> None:
        """Advance the internal time counter by one step."""
        self._current_time += 1.0

    def _compute_cdf_feature(self, value: float) -> float:
        """
        Compute the CDF feature for the given log noise scale value.

        Delegates to the shared ``compute_cdf_feature`` helper, passing this observer's time series list, decay factor,
        current time and time window.

        :param value: Log noise scale value to compute the CDF feature for.
        :return: CDF feature value.
        """
        return compute_cdf_feature(value, self._time_series, self.decay_factor, self._current_time, self.time_window)

    @property
    def vector_length(self) -> int:
        """
        :return: Length of the vector returned by this observation if it returns a vector.
        """
        return 2  # [log_noise_scale, cdf_feature]

    def _observe(
        self,
        observation_inputs: ObservationInputs,
        hyperparameter_states: HyperparameterStates,
        tracked_statistics: dict[str, dict[str, float | TensorStatistics]],
        action_taken: float | int | None,
    ) -> float | int | list[int | float] | TensorStatistics:
        """
        Build the ``[log_noise_scale, cdf_feature]`` observation and advance the internal time counter.

        :param observation_inputs: Observation input metrics not calculated with statistic trackers.
        :param hyperparameter_states: HyperparameterStates that manages the hyperparameters.
        :param tracked_statistics: Dictionary mapping statistic tracker class names to dictionaries mapping module
        names to floats or TensorStatistic models.
        :param action_taken: Action taken by the agent this class instance is assigned to.
        :return: Single float/int, list of floats/ints or TensorStatistics model to add to the observation vector.
        :raises KeyError: If tracked_statistics has no LogOfNoiseScaleStatistics entry.
        :raises ValueError: If the tracker produced no values, or if the batch size or learning rate is not positive
        (log(batch_size / learning_rate) would be undefined).
        :raises TypeError: If the tracked value is not a float.
        """

        tracker_name = statistic_trackers.LogOfNoiseScaleStatistics.__name__
        statistics = tracked_statistics[tracker_name]

        # Only the first entry is used; presumably the tracker reports a single value for this parameter group --
        # TODO confirm against LogOfNoiseScaleStatistics.
        try:
            raw_value = next(iter(statistics.values()))
        except StopIteration:
            raise ValueError(f"{tracker_name} produced no values for {type(self).__name__}.") from None

        # Explicit check rather than `assert`: it survives `python -O` and still narrows the type for mypy.
        if not isinstance(raw_value, float):
            raise TypeError(f"Expected float, got {type(raw_value)}")

        batch_size = hyperparameter_states.global_hparams.batch_size.external_value
        learning_rate = hyperparameter_states.parameter_group_hparams[
            self.parameter_group_name
        ].learning_rate.external_value

        # Fail with a descriptive error instead of ZeroDivisionError / "math domain error" from math.log below.
        if batch_size <= 0 or learning_rate <= 0:
            raise ValueError(
                f"{type(self).__name__} requires a positive batch size and learning rate to compute "
                f"log(batch_size / learning_rate); got batch_size={batch_size}, learning_rate={learning_rate}."
            )

        log_b_over_epsilon = math.log(batch_size / learning_rate)
        log_noise_scale = raw_value + log_b_over_epsilon

        cdf_feature = self._compute_cdf_feature(log_noise_scale)  # type: ignore[arg-type]
        self._update_time()

        return [log_noise_scale, cdf_feature]  # type: ignore[list-item]

    def get_required_trackers(self) -> dict[str, dict[str, Any] | None]:
        """
        :return: Dictionary mapping statistic tracker class names to kwargs for the class or None if no kwargs are
        needed.
        """

        return {
            statistic_trackers.LogOfNoiseScaleStatistics.__name__: dict(
                skip_statistics=self.skip_statistics, sample_frequency=LHOPT_CONSTANTS["DEFAULT_SAMPLE_FREQUENCY"]
            )
        }

    def reset(self) -> None:
        """Reset the observer by clearing the time series and rewinding the time counter to 0."""
        self._time_series = []
        self._current_time = 0.0
|