alberta_framework-0.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- alberta_framework/__init__.py +196 -0
- alberta_framework/core/__init__.py +27 -0
- alberta_framework/core/learners.py +530 -0
- alberta_framework/core/normalizers.py +192 -0
- alberta_framework/core/optimizers.py +422 -0
- alberta_framework/core/types.py +198 -0
- alberta_framework/py.typed +0 -0
- alberta_framework/streams/__init__.py +83 -0
- alberta_framework/streams/base.py +70 -0
- alberta_framework/streams/gymnasium.py +655 -0
- alberta_framework/streams/synthetic.py +995 -0
- alberta_framework/utils/__init__.py +113 -0
- alberta_framework/utils/experiments.py +334 -0
- alberta_framework/utils/export.py +509 -0
- alberta_framework/utils/metrics.py +112 -0
- alberta_framework/utils/statistics.py +527 -0
- alberta_framework/utils/timing.py +138 -0
- alberta_framework/utils/visualization.py +571 -0
- alberta_framework-0.1.0.dist-info/METADATA +198 -0
- alberta_framework-0.1.0.dist-info/RECORD +22 -0
- alberta_framework-0.1.0.dist-info/WHEEL +4 -0
- alberta_framework-0.1.0.dist-info/licenses/LICENSE +190 -0
alberta_framework/core/normalizers.py
@@ -0,0 +1,192 @@
"""Online feature normalization for continual learning.

Implements online (streaming) normalization that updates estimates of mean
and variance at every time step, following the principle of temporal uniformity.

Reference: Welford's online algorithm for numerical stability.
"""

from typing import NamedTuple

import jax.numpy as jnp
from jax import Array


class NormalizerState(NamedTuple):
    """State for online feature normalization.

    Uses Welford's online algorithm for numerically stable estimation
    of running mean and variance.

    Attributes:
        mean: Running mean estimate per feature
        var: Running variance estimate per feature
        sample_count: Number of samples seen
        decay: Exponential decay factor for estimates (1.0 = no decay, pure online)
    """

    mean: Array  # Shape: (feature_dim,)
    var: Array  # Shape: (feature_dim,)
    sample_count: Array  # Scalar
    decay: Array  # Scalar


class OnlineNormalizer:
    """Online feature normalizer for continual learning.

    Normalizes features using running estimates of mean and standard deviation:
        x_normalized = (x - mean) / (std + epsilon)

    The normalizer updates its estimates at every time step, following
    temporal uniformity. Uses exponential moving average for non-stationary
    environments.

    Attributes:
        epsilon: Small constant for numerical stability
        decay: Exponential decay for running estimates (0.99 = slower adaptation)
    """

    def __init__(
        self,
        epsilon: float = 1e-8,
        decay: float = 0.99,
    ):
        """Initialize the online normalizer.

        Args:
            epsilon: Small constant added to std for numerical stability
            decay: Exponential decay factor for running estimates.
                Lower values adapt faster to changes.
                1.0 means pure online average (no decay).
        """
        self._epsilon = epsilon
        self._decay = decay

    def init(self, feature_dim: int) -> NormalizerState:
        """Initialize normalizer state.

        Args:
            feature_dim: Dimension of feature vectors

        Returns:
            Initial normalizer state with zero mean and unit variance
        """
        return NormalizerState(
            mean=jnp.zeros(feature_dim, dtype=jnp.float32),
            var=jnp.ones(feature_dim, dtype=jnp.float32),
            sample_count=jnp.array(0.0, dtype=jnp.float32),
            decay=jnp.array(self._decay, dtype=jnp.float32),
        )

    def normalize(
        self,
        state: NormalizerState,
        observation: Array,
    ) -> tuple[Array, NormalizerState]:
        """Normalize observation and update running statistics.

        This method both normalizes the current observation AND updates
        the running statistics, maintaining temporal uniformity.

        Args:
            state: Current normalizer state
            observation: Raw feature vector

        Returns:
            Tuple of (normalized_observation, new_state)
        """
        # Update count
        new_count = state.sample_count + 1.0

        # Compute effective decay (ramp up from 0 to target decay)
        # This prevents instability in early steps
        effective_decay = jnp.minimum(
            state.decay,
            1.0 - 1.0 / (new_count + 1.0)
        )

        # Update mean using exponential moving average
        delta = observation - state.mean
        new_mean = state.mean + (1.0 - effective_decay) * delta

        # Update variance using exponential moving average of squared deviations
        # This is a simplified Welford's algorithm adapted for EMA
        delta2 = observation - new_mean
        new_var = effective_decay * state.var + (1.0 - effective_decay) * delta * delta2

        # Ensure variance is positive
        new_var = jnp.maximum(new_var, self._epsilon)

        # Normalize using updated statistics
        std = jnp.sqrt(new_var)
        normalized = (observation - new_mean) / (std + self._epsilon)

        new_state = NormalizerState(
            mean=new_mean,
            var=new_var,
            sample_count=new_count,
            decay=state.decay,
        )

        return normalized, new_state

    def normalize_only(
        self,
        state: NormalizerState,
        observation: Array,
    ) -> Array:
        """Normalize observation without updating statistics.

        Useful for inference or when you want to normalize multiple
        observations with the same statistics.

        Args:
            state: Current normalizer state
            observation: Raw feature vector

        Returns:
            Normalized observation
        """
        std = jnp.sqrt(state.var)
        return (observation - state.mean) / (std + self._epsilon)

    def update_only(
        self,
        state: NormalizerState,
        observation: Array,
    ) -> NormalizerState:
        """Update statistics without returning normalized observation.

        Args:
            state: Current normalizer state
            observation: Raw feature vector

        Returns:
            Updated normalizer state
        """
        _, new_state = self.normalize(state, observation)
        return new_state
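
For orientation, a minimal usage sketch of the class above (illustrative only, not part of the packaged files; it assumes importing from the module path shown in the file listing):

# Illustrative usage only -- this loop is not part of the packaged module.
import jax
import jax.numpy as jnp

from alberta_framework.core.normalizers import OnlineNormalizer

normalizer = OnlineNormalizer(epsilon=1e-8, decay=0.99)
state = normalizer.init(feature_dim=4)

key = jax.random.PRNGKey(0)
for t in range(1000):
    key, subkey = jax.random.split(key)
    # Synthetic observation with a different scale per feature.
    obs = jax.random.normal(subkey, (4,)) * jnp.array([1.0, 10.0, 0.1, 100.0])
    x, state = normalizer.normalize(state, obs)  # normalizes AND updates statistics

# Evaluation-style pass: reuse the current statistics without updating them.
frozen = normalizer.normalize_only(state, jnp.ones(4))
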
def create_normalizer_state(
    feature_dim: int,
    decay: float = 0.99,
) -> NormalizerState:
    """Create initial normalizer state.

    Convenience function for creating normalizer state without
    instantiating the OnlineNormalizer class.

    Args:
        feature_dim: Dimension of feature vectors
        decay: Exponential decay factor

    Returns:
        Initial normalizer state
    """
    return NormalizerState(
        mean=jnp.zeros(feature_dim, dtype=jnp.float32),
        var=jnp.ones(feature_dim, dtype=jnp.float32),
        sample_count=jnp.array(0.0, dtype=jnp.float32),
        decay=jnp.array(decay, dtype=jnp.float32),
    )
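
One detail of normalize() worth spelling out: the effective decay ramps up over the first samples before settling at the configured value, so early estimates lean on the observed data rather than the zero-mean, unit-variance initialization. A tiny illustrative calculation (not from the package):

# Illustrative only: how effective_decay in normalize() approaches the target decay of 0.99.
import jax.numpy as jnp

decay = jnp.array(0.99)
for new_count in (1.0, 2.0, 10.0, 100.0):
    effective = jnp.minimum(decay, 1.0 - 1.0 / (new_count + 1.0))
    print(int(new_count), float(effective))
# 1 -> 0.5, 2 -> ~0.667, 10 -> ~0.909, 100 -> capped at the target 0.99 from here on.
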
alberta_framework/core/optimizers.py
@@ -0,0 +1,422 @@
"""Optimizers for continual learning.

Implements LMS (fixed step-size baseline), IDBD (meta-learned step-sizes),
and Autostep (tuning-free step-size adaptation) for Step 1 of the Alberta Plan.

References:
    - Sutton 1992, "Adapting Bias by Gradient Descent: An Incremental
      Version of Delta-Bar-Delta"
    - Mahmood et al. 2012, "Tuning-free step-size adaptation"
"""

from abc import ABC, abstractmethod
from typing import NamedTuple

import jax.numpy as jnp
from jax import Array

from alberta_framework.core.types import AutostepState, IDBDState, LMSState


class OptimizerUpdate(NamedTuple):
    """Result of an optimizer update step.

    Attributes:
        weight_delta: Change to apply to weights
        bias_delta: Change to apply to bias
        new_state: Updated optimizer state
        metrics: Dictionary of metrics for logging (values are JAX arrays for scan compatibility)
    """

    weight_delta: Array
    bias_delta: Array
    new_state: LMSState | IDBDState | AutostepState
    metrics: dict[str, Array]


class Optimizer[StateT: (LMSState, IDBDState, AutostepState)](ABC):
    """Base class for optimizers."""

    @abstractmethod
    def init(self, feature_dim: int) -> StateT:
        """Initialize optimizer state.

        Args:
            feature_dim: Dimension of weight vector

        Returns:
            Initial optimizer state
        """
        ...

    @abstractmethod
    def update(
        self,
        state: StateT,
        error: Array,
        observation: Array,
    ) -> OptimizerUpdate:
        """Compute weight updates given prediction error.

        Args:
            state: Current optimizer state
            error: Prediction error (target - prediction)
            observation: Current observation/feature vector

        Returns:
            OptimizerUpdate with deltas and new state
        """
        ...
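
This hunk does not show how these updates are consumed (that lives in learners.py, which is listed above but not included here). As a rough sketch of the intended flow, assuming a simple linear predictor; the helper name and loop below are illustrative assumptions, not the package's learner API:

# Illustrative sketch (not part of the module): driving any Optimizer subclass.
# The linear predictor and the linear_step name are assumptions for illustration only.
import jax.numpy as jnp

def linear_step(optimizer, opt_state, weights, bias, observation, target):
    """One predict/update step with an arbitrary Optimizer (LMS, IDBD, Autostep below)."""
    prediction = jnp.dot(weights, observation) + bias
    error = target - prediction  # matches the (target - prediction) convention above
    upd = optimizer.update(opt_state, error, observation)
    # Deltas are additive corrections; metrics are JAX arrays suitable for logging or scan.
    return weights + upd.weight_delta, bias + upd.bias_delta, upd.new_state, upd.metrics
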
class LMS(Optimizer[LMSState]):
    """Least Mean Square optimizer with fixed step-size.

    The simplest gradient-based optimizer: w_{t+1} = w_t + alpha * delta * x_t

    This serves as a baseline. The challenge is that the optimal step-size
    depends on the problem and changes as the task becomes non-stationary.

    Attributes:
        step_size: Fixed learning rate alpha
    """

    def __init__(self, step_size: float = 0.01):
        """Initialize LMS optimizer.

        Args:
            step_size: Fixed learning rate
        """
        self._step_size = step_size

    def init(self, feature_dim: int) -> LMSState:
        """Initialize LMS state.

        Args:
            feature_dim: Dimension of weight vector (unused for LMS)

        Returns:
            LMS state containing the step-size
        """
        return LMSState(step_size=jnp.array(self._step_size, dtype=jnp.float32))

    def update(
        self,
        state: LMSState,
        error: Array,
        observation: Array,
    ) -> OptimizerUpdate:
        """Compute LMS weight update.

        Update rule: delta_w = alpha * error * x

        Args:
            state: Current LMS state
            error: Prediction error (scalar)
            observation: Feature vector

        Returns:
            OptimizerUpdate with weight and bias deltas
        """
        alpha = state.step_size
        error_scalar = jnp.squeeze(error)

        # Weight update: alpha * error * x
        weight_delta = alpha * error_scalar * observation

        # Bias update: alpha * error
        bias_delta = alpha * error_scalar

        return OptimizerUpdate(
            weight_delta=weight_delta,
            bias_delta=bias_delta,
            new_state=state,  # LMS state doesn't change
            metrics={"step_size": alpha},
        )
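
A quick illustrative sanity check of the rule delta_w = alpha * error * x (not taken from the package's tests):

# Illustrative numeric check (not part of the module).
import jax.numpy as jnp

from alberta_framework.core.optimizers import LMS

lms = LMS(step_size=0.1)
state = lms.init(feature_dim=3)
upd = lms.update(state, error=jnp.array(2.0), observation=jnp.array([1.0, 0.0, -1.0]))
# weight_delta == [0.2, 0.0, -0.2], bias_delta == 0.2, and new_state is the unchanged LMSState.
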
class IDBD(Optimizer[IDBDState]):
    """Incremental Delta-Bar-Delta optimizer.

    IDBD maintains per-weight adaptive step-sizes that are meta-learned
    based on gradient correlation. When successive gradients agree in sign,
    the step-size for that weight increases. When they disagree, it decreases.

    This implements Sutton's 1992 algorithm for adapting step-sizes online
    without requiring manual tuning.

    Reference: Sutton, R.S. (1992). "Adapting Bias by Gradient Descent:
    An Incremental Version of Delta-Bar-Delta"

    Attributes:
        initial_step_size: Initial per-weight step-size
        meta_step_size: Meta learning rate beta for adapting step-sizes
    """

    def __init__(
        self,
        initial_step_size: float = 0.01,
        meta_step_size: float = 0.01,
    ):
        """Initialize IDBD optimizer.

        Args:
            initial_step_size: Initial value for per-weight step-sizes
            meta_step_size: Meta learning rate beta for adapting step-sizes
        """
        self._initial_step_size = initial_step_size
        self._meta_step_size = meta_step_size

    def init(self, feature_dim: int) -> IDBDState:
        """Initialize IDBD state.

        Args:
            feature_dim: Dimension of weight vector

        Returns:
            IDBD state with per-weight step-sizes and traces
        """
        return IDBDState(
            log_step_sizes=jnp.full(
                feature_dim, jnp.log(self._initial_step_size), dtype=jnp.float32
            ),
            traces=jnp.zeros(feature_dim, dtype=jnp.float32),
            meta_step_size=jnp.array(self._meta_step_size, dtype=jnp.float32),
            bias_step_size=jnp.array(self._initial_step_size, dtype=jnp.float32),
            bias_trace=jnp.array(0.0, dtype=jnp.float32),
        )

    def update(
        self,
        state: IDBDState,
        error: Array,
        observation: Array,
    ) -> OptimizerUpdate:
        """Compute IDBD weight update with adaptive step-sizes.

        The IDBD algorithm:
        1. Compute step-sizes: alpha_i = exp(log_alpha_i)
        2. Update weights: w_i += alpha_i * error * x_i
        3. Update log step-sizes: log_alpha_i += beta * error * x_i * h_i
        4. Update traces: h_i = h_i * max(0, 1 - alpha_i * x_i^2) + alpha_i * error * x_i

        The trace h_i tracks the correlation between current and past gradients.
        When gradients consistently point the same direction, h_i grows,
        leading to larger step-sizes.

        Args:
            state: Current IDBD state
            error: Prediction error (scalar)
            observation: Feature vector

        Returns:
            OptimizerUpdate with weight deltas and updated state
        """
        error_scalar = jnp.squeeze(error)
        beta = state.meta_step_size

        # Current step-sizes (exponentiate log values)
        alphas = jnp.exp(state.log_step_sizes)

        # Weight updates: alpha_i * error * x_i
        weight_delta = alphas * error_scalar * observation

        # Meta-update: adapt step-sizes based on gradient correlation
        # log_alpha_i += beta * error * x_i * h_i
        gradient_correlation = error_scalar * observation * state.traces
        new_log_step_sizes = state.log_step_sizes + beta * gradient_correlation

        # Clip log step-sizes to prevent numerical issues
        new_log_step_sizes = jnp.clip(new_log_step_sizes, -10.0, 2.0)

        # Update traces: h_i = h_i * decay + alpha_i * error * x_i
        # decay = max(0, 1 - alpha_i * x_i^2)
        decay = jnp.maximum(0.0, 1.0 - alphas * observation**2)
        new_traces = state.traces * decay + alphas * error_scalar * observation

        # Bias updates (similar logic but scalar)
        bias_alpha = state.bias_step_size
        bias_delta = bias_alpha * error_scalar

        # Update bias step-size
        bias_gradient_correlation = error_scalar * state.bias_trace
        new_bias_step_size = bias_alpha * jnp.exp(beta * bias_gradient_correlation)
        new_bias_step_size = jnp.clip(new_bias_step_size, 1e-6, 1.0)

        # Update bias trace
        bias_decay = jnp.maximum(0.0, 1.0 - bias_alpha)
        new_bias_trace = state.bias_trace * bias_decay + bias_alpha * error_scalar

        new_state = IDBDState(
            log_step_sizes=new_log_step_sizes,
            traces=new_traces,
            meta_step_size=beta,
            bias_step_size=new_bias_step_size,
            bias_trace=new_bias_trace,
        )

        return OptimizerUpdate(
            weight_delta=weight_delta,
            bias_delta=bias_delta,
            new_state=new_state,
            metrics={
                "mean_step_size": jnp.mean(alphas),
                "min_step_size": jnp.min(alphas),
                "max_step_size": jnp.max(alphas),
            },
        )
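
To see the meta-update (step 3) at work: under a consistently positive error on a constant feature, the trace and the log step-size for that weight grow, while inactive features are untouched. An illustrative toy loop (assumed data, not a package example):

# Illustrative only: step-sizes grow where gradients stay correlated.
import jax.numpy as jnp

from alberta_framework.core.optimizers import IDBD

idbd = IDBD(initial_step_size=0.01, meta_step_size=0.05)
state = idbd.init(feature_dim=2)
x = jnp.array([1.0, 0.0])  # only the first feature is active
for _ in range(50):
    upd = idbd.update(state, error=jnp.array(1.0), observation=x)
    state = upd.new_state
# exp(state.log_step_sizes)[0] has grown above 0.01; index 1 stays at 0.01.
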
class Autostep(Optimizer[AutostepState]):
    """Autostep optimizer with tuning-free step-size adaptation.

    Autostep normalizes gradients to prevent large updates and adapts
    per-weight step-sizes based on gradient correlation. The key innovation
    is automatic normalization that makes the algorithm robust to different
    feature scales.

    The algorithm maintains:
    - Per-weight step-sizes that adapt based on gradient correlation
    - Running max of absolute gradients for normalization
    - Traces for detecting consistent gradient directions

    Reference: Mahmood, A.R., Sutton, R.S., Degris, T., & Pilarski, P.M. (2012).
    "Tuning-free step-size adaptation"

    Attributes:
        initial_step_size: Initial per-weight step-size
        meta_step_size: Meta learning rate mu for adapting step-sizes
        normalizer_decay: Decay factor tau for gradient normalizers
    """

    def __init__(
        self,
        initial_step_size: float = 0.01,
        meta_step_size: float = 0.01,
        normalizer_decay: float = 0.99,
    ):
        """Initialize Autostep optimizer.

        Args:
            initial_step_size: Initial value for per-weight step-sizes
            meta_step_size: Meta learning rate for adapting step-sizes
            normalizer_decay: Decay factor for gradient normalizers (higher = slower decay)
        """
        self._initial_step_size = initial_step_size
        self._meta_step_size = meta_step_size
        self._normalizer_decay = normalizer_decay

    def init(self, feature_dim: int) -> AutostepState:
        """Initialize Autostep state.

        Args:
            feature_dim: Dimension of weight vector

        Returns:
            Autostep state with per-weight step-sizes, traces, and normalizers
        """
        return AutostepState(
            step_sizes=jnp.full(feature_dim, self._initial_step_size, dtype=jnp.float32),
            traces=jnp.zeros(feature_dim, dtype=jnp.float32),
            normalizers=jnp.ones(feature_dim, dtype=jnp.float32),
            meta_step_size=jnp.array(self._meta_step_size, dtype=jnp.float32),
            normalizer_decay=jnp.array(self._normalizer_decay, dtype=jnp.float32),
            bias_step_size=jnp.array(self._initial_step_size, dtype=jnp.float32),
            bias_trace=jnp.array(0.0, dtype=jnp.float32),
            bias_normalizer=jnp.array(1.0, dtype=jnp.float32),
        )

    def update(
        self,
        state: AutostepState,
        error: Array,
        observation: Array,
    ) -> OptimizerUpdate:
        """Compute Autostep weight update with normalized gradients.

        The Autostep algorithm:
        1. Compute gradient: g_i = error * x_i
        2. Normalize gradient: g_i' = g_i / max(|g_i|, v_i)
        3. Update weights: w_i += alpha_i * g_i'
        4. Update step-sizes: alpha_i *= exp(mu * g_i' * h_i)
        5. Update traces: h_i = h_i * (1 - alpha_i) + alpha_i * g_i'
        6. Update normalizers: v_i = max(|g_i|, v_i * tau)

        Args:
            state: Current Autostep state
            error: Prediction error (scalar)
            observation: Feature vector

        Returns:
            OptimizerUpdate with weight deltas and updated state
        """
        error_scalar = jnp.squeeze(error)
        mu = state.meta_step_size
        tau = state.normalizer_decay

        # Compute raw gradient
        gradient = error_scalar * observation

        # Normalize gradient using running max
        abs_gradient = jnp.abs(gradient)
        normalizer = jnp.maximum(abs_gradient, state.normalizers)
        normalized_gradient = gradient / (normalizer + 1e-8)

        # Compute weight delta using normalized gradient
        weight_delta = state.step_sizes * normalized_gradient

        # Update step-sizes based on gradient correlation
        gradient_correlation = normalized_gradient * state.traces
        new_step_sizes = state.step_sizes * jnp.exp(mu * gradient_correlation)

        # Clip step-sizes to prevent instability
        new_step_sizes = jnp.clip(new_step_sizes, 1e-8, 1.0)

        # Update traces with decay based on step-size
        trace_decay = 1.0 - state.step_sizes
        new_traces = state.traces * trace_decay + state.step_sizes * normalized_gradient

        # Update normalizers with decay
        new_normalizers = jnp.maximum(abs_gradient, state.normalizers * tau)

        # Bias updates (similar logic)
        bias_gradient = error_scalar
        abs_bias_gradient = jnp.abs(bias_gradient)
        bias_normalizer = jnp.maximum(abs_bias_gradient, state.bias_normalizer)
        normalized_bias_gradient = bias_gradient / (bias_normalizer + 1e-8)

        bias_delta = state.bias_step_size * normalized_bias_gradient

        bias_correlation = normalized_bias_gradient * state.bias_trace
        new_bias_step_size = state.bias_step_size * jnp.exp(mu * bias_correlation)
        new_bias_step_size = jnp.clip(new_bias_step_size, 1e-8, 1.0)

        bias_trace_decay = 1.0 - state.bias_step_size
        new_bias_trace = (
            state.bias_trace * bias_trace_decay + state.bias_step_size * normalized_bias_gradient
        )

        new_bias_normalizer = jnp.maximum(abs_bias_gradient, state.bias_normalizer * tau)

        new_state = AutostepState(
            step_sizes=new_step_sizes,
            traces=new_traces,
            normalizers=new_normalizers,
            meta_step_size=mu,
            normalizer_decay=tau,
            bias_step_size=new_bias_step_size,
            bias_trace=new_bias_trace,
            bias_normalizer=new_bias_normalizer,
        )

        return OptimizerUpdate(
            weight_delta=weight_delta,
            bias_delta=bias_delta,
            new_state=new_state,
            metrics={
                "mean_step_size": jnp.mean(state.step_sizes),
                "min_step_size": jnp.min(state.step_sizes),
                "max_step_size": jnp.max(state.step_sizes),
                "mean_normalizer": jnp.mean(state.normalizers),
            },
        )
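
Since LMS, IDBD, and Autostep all implement the same Optimizer interface, swapping them in an experiment is a one-line change. An illustrative comparison sketch that reuses the hypothetical linear_step helper from the earlier sketch; the data and loop are assumptions, not the package's experiment utilities:

# Illustrative only: swapping optimizers behind the shared Optimizer interface.
# Reuses the hypothetical linear_step helper sketched after the Optimizer base class.
import jax
import jax.numpy as jnp

from alberta_framework.core.optimizers import LMS, IDBD, Autostep

key = jax.random.PRNGKey(0)
true_w = jnp.array([0.5, -2.0, 3.0])

for optimizer in (LMS(0.01), IDBD(0.01, 0.01), Autostep(0.01, 0.01, 0.99)):
    opt_state = optimizer.init(feature_dim=3)
    weights, bias = jnp.zeros(3), jnp.array(0.0)
    for _ in range(500):
        key, subkey = jax.random.split(key)
        x = jax.random.normal(subkey, (3,))
        target = true_w @ x  # noiseless linear target, purely for illustration
        weights, bias, opt_state, _ = linear_step(optimizer, opt_state, weights, bias, x, target)
    print(type(optimizer).__name__, weights)  # each run should move toward true_w; rates differ
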