alberta_framework-0.2.2-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,192 @@
+ """Online feature normalization for continual learning.
+ 
+ Implements online (streaming) normalization that updates estimates of mean
+ and variance at every time step, following the principle of temporal uniformity.
+ 
+ Reference: Welford's online algorithm for numerical stability.
+ """
+ 
+ from typing import NamedTuple
+ 
+ import jax.numpy as jnp
+ from jax import Array
+ 
+ 
+ class NormalizerState(NamedTuple):
+     """State for online feature normalization.
+ 
+     Uses Welford's online algorithm for numerically stable estimation
+     of running mean and variance.
+ 
+     Attributes:
+         mean: Running mean estimate per feature
+         var: Running variance estimate per feature
+         sample_count: Number of samples seen
+         decay: Exponential decay factor for estimates (1.0 = no decay, pure online)
+     """
+ 
+     mean: Array  # Shape: (feature_dim,)
+     var: Array  # Shape: (feature_dim,)
+     sample_count: Array  # Scalar
+     decay: Array  # Scalar
+ 
+ 
+ class OnlineNormalizer:
+     """Online feature normalizer for continual learning.
+ 
+     Normalizes features using running estimates of mean and standard deviation:
+     `x_normalized = (x - mean) / (std + epsilon)`
+ 
+     The normalizer updates its estimates at every time step, following
+     temporal uniformity. Uses exponential moving average for non-stationary
+     environments.
+ 
+     Attributes:
+         epsilon: Small constant for numerical stability
+         decay: Exponential decay for running estimates (0.99 = slower adaptation)
+     """
+ 
+     def __init__(
+         self,
+         epsilon: float = 1e-8,
+         decay: float = 0.99,
+     ):
+         """Initialize the online normalizer.
+ 
+         Args:
+             epsilon: Small constant added to std for numerical stability
+             decay: Exponential decay factor for running estimates.
+                 Lower values adapt faster to changes.
+                 1.0 means pure online average (no decay).
+         """
+         self._epsilon = epsilon
+         self._decay = decay
+ 
+     def init(self, feature_dim: int) -> NormalizerState:
+         """Initialize normalizer state.
+ 
+         Args:
+             feature_dim: Dimension of feature vectors
+ 
+         Returns:
+             Initial normalizer state with zero mean and unit variance
+         """
+         return NormalizerState(
+             mean=jnp.zeros(feature_dim, dtype=jnp.float32),
+             var=jnp.ones(feature_dim, dtype=jnp.float32),
+             sample_count=jnp.array(0.0, dtype=jnp.float32),
+             decay=jnp.array(self._decay, dtype=jnp.float32),
+         )
+ 
+     def normalize(
+         self,
+         state: NormalizerState,
+         observation: Array,
+     ) -> tuple[Array, NormalizerState]:
+         """Normalize observation and update running statistics.
+ 
+         This method both normalizes the current observation AND updates
+         the running statistics, maintaining temporal uniformity.
+ 
+         Args:
+             state: Current normalizer state
+             observation: Raw feature vector
+ 
+         Returns:
+             Tuple of (normalized_observation, new_state)
+         """
+         # Update count
+         new_count = state.sample_count + 1.0
+ 
+         # Compute effective decay (ramp up from 0 to target decay)
+         # This prevents instability in early steps
+         effective_decay = jnp.minimum(
+             state.decay,
+             1.0 - 1.0 / (new_count + 1.0)
+         )
+ 
+         # Update mean using exponential moving average
+         delta = observation - state.mean
+         new_mean = state.mean + (1.0 - effective_decay) * delta
+ 
+         # Update variance using exponential moving average of squared deviations
+         # This is a simplified Welford's algorithm adapted for EMA
+         delta2 = observation - new_mean
+         new_var = effective_decay * state.var + (1.0 - effective_decay) * delta * delta2
+ 
+         # Ensure variance is positive
+         new_var = jnp.maximum(new_var, self._epsilon)
+ 
+         # Normalize using updated statistics
+         std = jnp.sqrt(new_var)
+         normalized = (observation - new_mean) / (std + self._epsilon)
+ 
+         new_state = NormalizerState(
+             mean=new_mean,
+             var=new_var,
+             sample_count=new_count,
+             decay=state.decay,
+         )
+ 
+         return normalized, new_state
+ 
+     def normalize_only(
+         self,
+         state: NormalizerState,
+         observation: Array,
+     ) -> Array:
+         """Normalize observation without updating statistics.
+ 
+         Useful for inference or when you want to normalize multiple
+         observations with the same statistics.
+ 
+         Args:
+             state: Current normalizer state
+             observation: Raw feature vector
+ 
+         Returns:
+             Normalized observation
+         """
+         std = jnp.sqrt(state.var)
+         return (observation - state.mean) / (std + self._epsilon)
+ 
+     def update_only(
+         self,
+         state: NormalizerState,
+         observation: Array,
+     ) -> NormalizerState:
+         """Update statistics without returning normalized observation.
+ 
+         Args:
+             state: Current normalizer state
+             observation: Raw feature vector
+ 
+         Returns:
+             Updated normalizer state
+         """
+         _, new_state = self.normalize(state, observation)
+         return new_state
+ 
+ 
+ def create_normalizer_state(
+     feature_dim: int,
+     decay: float = 0.99,
+ ) -> NormalizerState:
+     """Create initial normalizer state.
+ 
+     Convenience function for creating normalizer state without
+     instantiating the OnlineNormalizer class.
+ 
+     Args:
+         feature_dim: Dimension of feature vectors
+         decay: Exponential decay factor
+ 
+     Returns:
+         Initial normalizer state
+     """
+     return NormalizerState(
+         mean=jnp.zeros(feature_dim, dtype=jnp.float32),
+         var=jnp.ones(feature_dim, dtype=jnp.float32),
+         sample_count=jnp.array(0.0, dtype=jnp.float32),
+         decay=jnp.array(decay, dtype=jnp.float32),
+     )
@@ -0,0 +1,424 @@
+ """Optimizers for continual learning.
+ 
+ Implements LMS (fixed step-size baseline), IDBD (meta-learned step-sizes),
+ and Autostep (tuning-free step-size adaptation) for Step 1 of the Alberta Plan.
+ 
+ References:
+     - Sutton 1992, "Adapting Bias by Gradient Descent: An Incremental
+       Version of Delta-Bar-Delta"
+     - Mahmood et al. 2012, "Tuning-free step-size adaptation"
+ """
+ 
+ from abc import ABC, abstractmethod
+ from typing import NamedTuple
+ 
+ import jax.numpy as jnp
+ from jax import Array
+ 
+ from alberta_framework.core.types import AutostepState, IDBDState, LMSState
+ 
+ 
+ class OptimizerUpdate(NamedTuple):
+     """Result of an optimizer update step.
+ 
+     Attributes:
+         weight_delta: Change to apply to weights
+         bias_delta: Change to apply to bias
+         new_state: Updated optimizer state
+         metrics: Dictionary of metrics for logging (values are JAX arrays for scan compatibility)
+     """
+ 
+     weight_delta: Array
+     bias_delta: Array
+     new_state: LMSState | IDBDState | AutostepState
+     metrics: dict[str, Array]
+ 
+ 
+ class Optimizer[StateT: (LMSState, IDBDState, AutostepState)](ABC):
+     """Base class for optimizers."""
+ 
+     @abstractmethod
+     def init(self, feature_dim: int) -> StateT:
+         """Initialize optimizer state.
+ 
+         Args:
+             feature_dim: Dimension of weight vector
+ 
+         Returns:
+             Initial optimizer state
+         """
+         ...
+ 
+     @abstractmethod
+     def update(
+         self,
+         state: StateT,
+         error: Array,
+         observation: Array,
+     ) -> OptimizerUpdate:
+         """Compute weight updates given prediction error.
+ 
+         Args:
+             state: Current optimizer state
+             error: Prediction error (target - prediction)
+             observation: Current observation/feature vector
+ 
+         Returns:
+             OptimizerUpdate with deltas and new state
+         """
+         ...
+ 
+ 
+ class LMS(Optimizer[LMSState]):
+     """Least Mean Square optimizer with fixed step-size.
+ 
+     The simplest gradient-based optimizer: `w_{t+1} = w_t + alpha * delta * x_t`
+ 
+     This serves as a baseline. The challenge is that the optimal step-size
+     depends on the problem and changes as the task becomes non-stationary.
+ 
+     Attributes:
+         step_size: Fixed learning rate alpha
+     """
+ 
+     def __init__(self, step_size: float = 0.01):
+         """Initialize LMS optimizer.
+ 
+         Args:
+             step_size: Fixed learning rate
+         """
+         self._step_size = step_size
+ 
+     def init(self, feature_dim: int) -> LMSState:
+         """Initialize LMS state.
+ 
+         Args:
+             feature_dim: Dimension of weight vector (unused for LMS)
+ 
+         Returns:
+             LMS state containing the step-size
+         """
+         return LMSState(step_size=jnp.array(self._step_size, dtype=jnp.float32))
+ 
+     def update(
+         self,
+         state: LMSState,
+         error: Array,
+         observation: Array,
+     ) -> OptimizerUpdate:
+         """Compute LMS weight update.
+ 
+         Update rule: `delta_w = alpha * error * x`
+ 
+         Args:
+             state: Current LMS state
+             error: Prediction error (scalar)
+             observation: Feature vector
+ 
+         Returns:
+             OptimizerUpdate with weight and bias deltas
+         """
+         alpha = state.step_size
+         error_scalar = jnp.squeeze(error)
+ 
+         # Weight update: alpha * error * x
+         weight_delta = alpha * error_scalar * observation
+ 
+         # Bias update: alpha * error
+         bias_delta = alpha * error_scalar
+ 
+         return OptimizerUpdate(
+             weight_delta=weight_delta,
+             bias_delta=bias_delta,
+             new_state=state,  # LMS state doesn't change
+             metrics={"step_size": alpha},
+         )
+ 
+ 
+ class IDBD(Optimizer[IDBDState]):
+     """Incremental Delta-Bar-Delta optimizer.
+ 
+     IDBD maintains per-weight adaptive step-sizes that are meta-learned
+     based on gradient correlation. When successive gradients agree in sign,
+     the step-size for that weight increases. When they disagree, it decreases.
+ 
+     This implements Sutton's 1992 algorithm for adapting step-sizes online
+     without requiring manual tuning.
+ 
+     Reference: Sutton, R.S. (1992). "Adapting Bias by Gradient Descent:
+     An Incremental Version of Delta-Bar-Delta"
+ 
+     Attributes:
+         initial_step_size: Initial per-weight step-size
+         meta_step_size: Meta learning rate beta for adapting step-sizes
+     """
+ 
+     def __init__(
+         self,
+         initial_step_size: float = 0.01,
+         meta_step_size: float = 0.01,
+     ):
+         """Initialize IDBD optimizer.
+ 
+         Args:
+             initial_step_size: Initial value for per-weight step-sizes
+             meta_step_size: Meta learning rate beta for adapting step-sizes
+         """
+         self._initial_step_size = initial_step_size
+         self._meta_step_size = meta_step_size
+ 
+     def init(self, feature_dim: int) -> IDBDState:
+         """Initialize IDBD state.
+ 
+         Args:
+             feature_dim: Dimension of weight vector
+ 
+         Returns:
+             IDBD state with per-weight step-sizes and traces
+         """
+         return IDBDState(
+             log_step_sizes=jnp.full(
+                 feature_dim, jnp.log(self._initial_step_size), dtype=jnp.float32
+             ),
+             traces=jnp.zeros(feature_dim, dtype=jnp.float32),
+             meta_step_size=jnp.array(self._meta_step_size, dtype=jnp.float32),
+             bias_step_size=jnp.array(self._initial_step_size, dtype=jnp.float32),
+             bias_trace=jnp.array(0.0, dtype=jnp.float32),
+         )
+ 
+     def update(
+         self,
+         state: IDBDState,
+         error: Array,
+         observation: Array,
+     ) -> OptimizerUpdate:
+         """Compute IDBD weight update with adaptive step-sizes.
+ 
+         The IDBD algorithm:
+ 
+         1. Compute step-sizes: `alpha_i = exp(log_alpha_i)`
+         2. Update weights: `w_i += alpha_i * error * x_i`
+         3. Update log step-sizes: `log_alpha_i += beta * error * x_i * h_i`
+         4. Update traces: `h_i = h_i * max(0, 1 - alpha_i * x_i^2) + alpha_i * error * x_i`
+ 
+         The trace h_i tracks the correlation between current and past gradients.
+         When gradients consistently point the same direction, h_i grows,
+         leading to larger step-sizes.
+ 
+         Args:
+             state: Current IDBD state
+             error: Prediction error (scalar)
+             observation: Feature vector
+ 
+         Returns:
+             OptimizerUpdate with weight deltas and updated state
+         """
+         error_scalar = jnp.squeeze(error)
+         beta = state.meta_step_size
+ 
+         # Current step-sizes (exponentiate log values)
+         alphas = jnp.exp(state.log_step_sizes)
+ 
+         # Weight updates: alpha_i * error * x_i
+         weight_delta = alphas * error_scalar * observation
+ 
+         # Meta-update: adapt step-sizes based on gradient correlation
+         # log_alpha_i += beta * error * x_i * h_i
+         gradient_correlation = error_scalar * observation * state.traces
+         new_log_step_sizes = state.log_step_sizes + beta * gradient_correlation
+ 
+         # Clip log step-sizes to prevent numerical issues
+         new_log_step_sizes = jnp.clip(new_log_step_sizes, -10.0, 2.0)
+ 
+         # Update traces: h_i = h_i * decay + alpha_i * error * x_i
+         # decay = max(0, 1 - alpha_i * x_i^2)
+         decay = jnp.maximum(0.0, 1.0 - alphas * observation**2)
+         new_traces = state.traces * decay + alphas * error_scalar * observation
+ 
+         # Bias updates (similar logic but scalar)
+         bias_alpha = state.bias_step_size
+         bias_delta = bias_alpha * error_scalar
+ 
+         # Update bias step-size
+         bias_gradient_correlation = error_scalar * state.bias_trace
+         new_bias_step_size = bias_alpha * jnp.exp(beta * bias_gradient_correlation)
+         new_bias_step_size = jnp.clip(new_bias_step_size, 1e-6, 1.0)
+ 
+         # Update bias trace
+         bias_decay = jnp.maximum(0.0, 1.0 - bias_alpha)
+         new_bias_trace = state.bias_trace * bias_decay + bias_alpha * error_scalar
+ 
+         new_state = IDBDState(
+             log_step_sizes=new_log_step_sizes,
+             traces=new_traces,
+             meta_step_size=beta,
+             bias_step_size=new_bias_step_size,
+             bias_trace=new_bias_trace,
+         )
+ 
+         return OptimizerUpdate(
+             weight_delta=weight_delta,
+             bias_delta=bias_delta,
+             new_state=new_state,
+             metrics={
+                 "mean_step_size": jnp.mean(alphas),
+                 "min_step_size": jnp.min(alphas),
+                 "max_step_size": jnp.max(alphas),
+             },
+         )
+ 
+ 
+ class Autostep(Optimizer[AutostepState]):
+     """Autostep optimizer with tuning-free step-size adaptation.
+ 
+     Autostep normalizes gradients to prevent large updates and adapts
+     per-weight step-sizes based on gradient correlation. The key innovation
+     is automatic normalization that makes the algorithm robust to different
+     feature scales.
+ 
+     The algorithm maintains:
+     - Per-weight step-sizes that adapt based on gradient correlation
+     - Running max of absolute gradients for normalization
+     - Traces for detecting consistent gradient directions
+ 
+     Reference: Mahmood, A.R., Sutton, R.S., Degris, T., & Pilarski, P.M. (2012).
+     "Tuning-free step-size adaptation"
+ 
+     Attributes:
+         initial_step_size: Initial per-weight step-size
+         meta_step_size: Meta learning rate mu for adapting step-sizes
+         normalizer_decay: Decay factor tau for gradient normalizers
+     """
+ 
+     def __init__(
+         self,
+         initial_step_size: float = 0.01,
+         meta_step_size: float = 0.01,
+         normalizer_decay: float = 0.99,
+     ):
+         """Initialize Autostep optimizer.
+ 
+         Args:
+             initial_step_size: Initial value for per-weight step-sizes
+             meta_step_size: Meta learning rate for adapting step-sizes
+             normalizer_decay: Decay factor for gradient normalizers (higher = slower decay)
+         """
+         self._initial_step_size = initial_step_size
+         self._meta_step_size = meta_step_size
+         self._normalizer_decay = normalizer_decay
+ 
+     def init(self, feature_dim: int) -> AutostepState:
+         """Initialize Autostep state.
+ 
+         Args:
+             feature_dim: Dimension of weight vector
+ 
+         Returns:
+             Autostep state with per-weight step-sizes, traces, and normalizers
+         """
+         return AutostepState(
+             step_sizes=jnp.full(feature_dim, self._initial_step_size, dtype=jnp.float32),
+             traces=jnp.zeros(feature_dim, dtype=jnp.float32),
+             normalizers=jnp.ones(feature_dim, dtype=jnp.float32),
+             meta_step_size=jnp.array(self._meta_step_size, dtype=jnp.float32),
+             normalizer_decay=jnp.array(self._normalizer_decay, dtype=jnp.float32),
+             bias_step_size=jnp.array(self._initial_step_size, dtype=jnp.float32),
+             bias_trace=jnp.array(0.0, dtype=jnp.float32),
+             bias_normalizer=jnp.array(1.0, dtype=jnp.float32),
+         )
+ 
+     def update(
+         self,
+         state: AutostepState,
+         error: Array,
+         observation: Array,
+     ) -> OptimizerUpdate:
+         """Compute Autostep weight update with normalized gradients.
+ 
+         The Autostep algorithm:
+ 
+         1. Compute gradient: `g_i = error * x_i`
+         2. Normalize gradient: `g_i' = g_i / max(|g_i|, v_i)`
+         3. Update weights: `w_i += alpha_i * g_i'`
+         4. Update step-sizes: `alpha_i *= exp(mu * g_i' * h_i)`
+         5. Update traces: `h_i = h_i * (1 - alpha_i) + alpha_i * g_i'`
+         6. Update normalizers: `v_i = max(|g_i|, v_i * tau)`
+ 
+         Args:
+             state: Current Autostep state
+             error: Prediction error (scalar)
+             observation: Feature vector
+ 
+         Returns:
+             OptimizerUpdate with weight deltas and updated state
+         """
+         error_scalar = jnp.squeeze(error)
+         mu = state.meta_step_size
+         tau = state.normalizer_decay
+ 
+         # Compute raw gradient
+         gradient = error_scalar * observation
+ 
+         # Normalize gradient using running max
+         abs_gradient = jnp.abs(gradient)
+         normalizer = jnp.maximum(abs_gradient, state.normalizers)
+         normalized_gradient = gradient / (normalizer + 1e-8)
+ 
+         # Compute weight delta using normalized gradient
+         weight_delta = state.step_sizes * normalized_gradient
+ 
+         # Update step-sizes based on gradient correlation
+         gradient_correlation = normalized_gradient * state.traces
+         new_step_sizes = state.step_sizes * jnp.exp(mu * gradient_correlation)
+ 
+         # Clip step-sizes to prevent instability
+         new_step_sizes = jnp.clip(new_step_sizes, 1e-8, 1.0)
+ 
+         # Update traces with decay based on step-size
+         trace_decay = 1.0 - state.step_sizes
+         new_traces = state.traces * trace_decay + state.step_sizes * normalized_gradient
+ 
+         # Update normalizers with decay
+         new_normalizers = jnp.maximum(abs_gradient, state.normalizers * tau)
+ 
+         # Bias updates (similar logic)
+         bias_gradient = error_scalar
+         abs_bias_gradient = jnp.abs(bias_gradient)
+         bias_normalizer = jnp.maximum(abs_bias_gradient, state.bias_normalizer)
+         normalized_bias_gradient = bias_gradient / (bias_normalizer + 1e-8)
+ 
+         bias_delta = state.bias_step_size * normalized_bias_gradient
+ 
+         bias_correlation = normalized_bias_gradient * state.bias_trace
+         new_bias_step_size = state.bias_step_size * jnp.exp(mu * bias_correlation)
+         new_bias_step_size = jnp.clip(new_bias_step_size, 1e-8, 1.0)
+ 
+         bias_trace_decay = 1.0 - state.bias_step_size
+         new_bias_trace = (
+             state.bias_trace * bias_trace_decay + state.bias_step_size * normalized_bias_gradient
+         )
+ 
+         new_bias_normalizer = jnp.maximum(abs_bias_gradient, state.bias_normalizer * tau)
+ 
+         new_state = AutostepState(
+             step_sizes=new_step_sizes,
+             traces=new_traces,
+             normalizers=new_normalizers,
+             meta_step_size=mu,
+             normalizer_decay=tau,
+             bias_step_size=new_bias_step_size,
+             bias_trace=new_bias_trace,
+             bias_normalizer=new_bias_normalizer,
+         )
+ 
+         return OptimizerUpdate(
+             weight_delta=weight_delta,
+             bias_delta=bias_delta,
+             new_state=new_state,
+             metrics={
+                 "mean_step_size": jnp.mean(state.step_sizes),
+                 "min_step_size": jnp.min(state.step_sizes),
+                 "max_step_size": jnp.max(state.step_sizes),
+                 "mean_normalizer": jnp.mean(state.normalizers),
+             },
+         )
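
As with the normalizer, a short usage sketch follows; it is not part of the released package. It wires the IDBD optimizer into a hand-rolled linear predictor to show how OptimizerUpdate deltas are meant to be applied. The import path alberta_framework.core.optimizers and the synthetic regression target are assumptions.

import jax
import jax.numpy as jnp

# Assumed import path -- the diff does not name this file.
from alberta_framework.core.optimizers import IDBD

feature_dim = 8
optimizer = IDBD(initial_step_size=0.01, meta_step_size=0.01)
opt_state = optimizer.init(feature_dim)

# A simple linear predictor: y_hat = w . x + b
weights = jnp.zeros(feature_dim)
bias = jnp.zeros(())

key = jax.random.PRNGKey(0)
true_weights = jnp.linspace(-1.0, 1.0, feature_dim)  # synthetic target weights

for _ in range(1000):
    key, subkey = jax.random.split(key)
    x = jax.random.normal(subkey, (feature_dim,))
    target = true_weights @ x
    prediction = weights @ x + bias
    error = target - prediction  # matches the (target - prediction) convention above

    update = optimizer.update(opt_state, error, x)
    weights = weights + update.weight_delta
    bias = bias + update.bias_delta
    opt_state = update.new_state

print(update.metrics["mean_step_size"])  # per-weight step-sizes adapt over time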