alberta-framework 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,192 @@
+ """Online feature normalization for continual learning.
+
+ Implements online (streaming) normalization that updates estimates of mean
+ and variance at every time step, following the principle of temporal uniformity.
+
+ Reference: Welford's online algorithm for numerical stability.
+ """
+
+ from typing import NamedTuple
+
+ import jax.numpy as jnp
+ from jax import Array
+
+
+ class NormalizerState(NamedTuple):
+     """State for online feature normalization.
+
+     Uses Welford's online algorithm for numerically stable estimation
+     of running mean and variance.
+
+     Attributes:
+         mean: Running mean estimate per feature
+         var: Running variance estimate per feature
+         sample_count: Number of samples seen
+         decay: Exponential decay factor for estimates (1.0 = no decay, pure online)
+     """
+
+     mean: Array  # Shape: (feature_dim,)
+     var: Array  # Shape: (feature_dim,)
+     sample_count: Array  # Scalar
+     decay: Array  # Scalar
+
+
+ class OnlineNormalizer:
+     """Online feature normalizer for continual learning.
+
+     Normalizes features using running estimates of mean and standard deviation:
+         x_normalized = (x - mean) / (std + epsilon)
+
+     The normalizer updates its estimates at every time step, following
+     temporal uniformity. Uses exponential moving average for non-stationary
+     environments.
+
+     Attributes:
+         epsilon: Small constant for numerical stability
+         decay: Exponential decay for running estimates (values closer to 1.0 adapt more slowly)
+     """
+
+     def __init__(
+         self,
+         epsilon: float = 1e-8,
+         decay: float = 0.99,
+     ):
+         """Initialize the online normalizer.
+
+         Args:
+             epsilon: Small constant added to std for numerical stability
+             decay: Exponential decay factor for running estimates.
+                 Lower values adapt faster to changes.
+                 1.0 means pure online average (no decay).
+         """
+         self._epsilon = epsilon
+         self._decay = decay
+
+     def init(self, feature_dim: int) -> NormalizerState:
+         """Initialize normalizer state.
+
+         Args:
+             feature_dim: Dimension of feature vectors
+
+         Returns:
+             Initial normalizer state with zero mean and unit variance
+         """
+         return NormalizerState(
+             mean=jnp.zeros(feature_dim, dtype=jnp.float32),
+             var=jnp.ones(feature_dim, dtype=jnp.float32),
+             sample_count=jnp.array(0.0, dtype=jnp.float32),
+             decay=jnp.array(self._decay, dtype=jnp.float32),
+         )
+
+     def normalize(
+         self,
+         state: NormalizerState,
+         observation: Array,
+     ) -> tuple[Array, NormalizerState]:
+         """Normalize observation and update running statistics.
+
+         This method both normalizes the current observation AND updates
+         the running statistics, maintaining temporal uniformity.
+
+         Args:
+             state: Current normalizer state
+             observation: Raw feature vector
+
+         Returns:
+             Tuple of (normalized_observation, new_state)
+         """
+         # Update count
+         new_count = state.sample_count + 1.0
+
+         # Compute effective decay (ramps up toward the target decay over the first steps)
+         # This prevents instability in early steps
+         effective_decay = jnp.minimum(
+             state.decay,
+             1.0 - 1.0 / (new_count + 1.0)
+         )
+
+         # Update mean using exponential moving average
+         delta = observation - state.mean
+         new_mean = state.mean + (1.0 - effective_decay) * delta
+
+         # Update variance using exponential moving average of squared deviations
+         # This is a simplified Welford's algorithm adapted for EMA
+         delta2 = observation - new_mean
+         new_var = effective_decay * state.var + (1.0 - effective_decay) * delta * delta2
+
+         # Ensure variance is positive
+         new_var = jnp.maximum(new_var, self._epsilon)
+
+         # Normalize using updated statistics
+         std = jnp.sqrt(new_var)
+         normalized = (observation - new_mean) / (std + self._epsilon)
+
+         new_state = NormalizerState(
+             mean=new_mean,
+             var=new_var,
+             sample_count=new_count,
+             decay=state.decay,
+         )
+
+         return normalized, new_state
+
+     def normalize_only(
+         self,
+         state: NormalizerState,
+         observation: Array,
+     ) -> Array:
+         """Normalize observation without updating statistics.
+
+         Useful for inference or when you want to normalize multiple
+         observations with the same statistics.
+
+         Args:
+             state: Current normalizer state
+             observation: Raw feature vector
+
+         Returns:
+             Normalized observation
+         """
+         std = jnp.sqrt(state.var)
+         return (observation - state.mean) / (std + self._epsilon)
+
+     def update_only(
+         self,
+         state: NormalizerState,
+         observation: Array,
+     ) -> NormalizerState:
+         """Update statistics without returning normalized observation.
+
+         Args:
+             state: Current normalizer state
+             observation: Raw feature vector
+
+         Returns:
+             Updated normalizer state
+         """
+         _, new_state = self.normalize(state, observation)
+         return new_state
+
+
+ def create_normalizer_state(
+     feature_dim: int,
+     decay: float = 0.99,
+ ) -> NormalizerState:
+     """Create initial normalizer state.
+
+     Convenience function for creating normalizer state without
+     instantiating the OnlineNormalizer class.
+
+     Args:
+         feature_dim: Dimension of feature vectors
+         decay: Exponential decay factor
+
+     Returns:
+         Initial normalizer state
+     """
+     return NormalizerState(
+         mean=jnp.zeros(feature_dim, dtype=jnp.float32),
+         var=jnp.ones(feature_dim, dtype=jnp.float32),
+         sample_count=jnp.array(0.0, dtype=jnp.float32),
+         decay=jnp.array(decay, dtype=jnp.float32),
+     )
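
A minimal usage sketch for the module above (not part of the package diff): it runs OnlineNormalizer over a drifting stream and then freezes the statistics for evaluation. The feature dimension, drift schedule, loop length, and PRNG seed are illustrative assumptions.

    # Sketch only: assumes OnlineNormalizer from the hunk above is in scope.
    import jax
    import jax.numpy as jnp

    normalizer = OnlineNormalizer(epsilon=1e-8, decay=0.99)
    state = normalizer.init(feature_dim=4)

    key = jax.random.PRNGKey(0)
    for t in range(1000):
        key, subkey = jax.random.split(key)
        # Non-stationary stream: the feature scale drifts over time.
        observation = jax.random.normal(subkey, (4,)) * (1.0 + 0.01 * t)
        normalized, state = normalizer.normalize(state, observation)

    # At evaluation time, normalize without touching the running statistics.
    frozen = normalizer.normalize_only(state, jnp.ones(4))
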
@@ -0,0 +1,422 @@
+ """Optimizers for continual learning.
+
+ Implements LMS (fixed step-size baseline), IDBD (meta-learned step-sizes),
+ and Autostep (tuning-free step-size adaptation) for Step 1 of the Alberta Plan.
+
+ References:
+     - Sutton 1992, "Adapting Bias by Gradient Descent: An Incremental
+       Version of Delta-Bar-Delta"
+     - Mahmood et al. 2012, "Tuning-free step-size adaptation"
+ """
+
+ from abc import ABC, abstractmethod
+ from typing import NamedTuple
+
+ import jax.numpy as jnp
+ from jax import Array
+
+ from alberta_framework.core.types import AutostepState, IDBDState, LMSState
+
+
+ class OptimizerUpdate(NamedTuple):
+     """Result of an optimizer update step.
+
+     Attributes:
+         weight_delta: Change to apply to weights
+         bias_delta: Change to apply to bias
+         new_state: Updated optimizer state
+         metrics: Dictionary of metrics for logging (values are JAX arrays for scan compatibility)
+     """
+
+     weight_delta: Array
+     bias_delta: Array
+     new_state: LMSState | IDBDState | AutostepState
+     metrics: dict[str, Array]
+
+
+ class Optimizer[StateT: (LMSState, IDBDState, AutostepState)](ABC):
+     """Base class for optimizers."""
+
+     @abstractmethod
+     def init(self, feature_dim: int) -> StateT:
+         """Initialize optimizer state.
+
+         Args:
+             feature_dim: Dimension of weight vector
+
+         Returns:
+             Initial optimizer state
+         """
+         ...
+
+     @abstractmethod
+     def update(
+         self,
+         state: StateT,
+         error: Array,
+         observation: Array,
+     ) -> OptimizerUpdate:
+         """Compute weight updates given prediction error.
+
+         Args:
+             state: Current optimizer state
+             error: Prediction error (target - prediction)
+             observation: Current observation/feature vector
+
+         Returns:
+             OptimizerUpdate with deltas and new state
+         """
+         ...
+
+
+ class LMS(Optimizer[LMSState]):
+     """Least Mean Square optimizer with fixed step-size.
+
+     The simplest gradient-based optimizer: w_{t+1} = w_t + alpha * delta * x_t
+
+     This serves as a baseline. The challenge is that the optimal step-size
+     depends on the problem and, in non-stationary settings, changes over time.
+
+     Attributes:
+         step_size: Fixed learning rate alpha
+     """
+
+     def __init__(self, step_size: float = 0.01):
+         """Initialize LMS optimizer.
+
+         Args:
+             step_size: Fixed learning rate
+         """
+         self._step_size = step_size
+
+     def init(self, feature_dim: int) -> LMSState:
+         """Initialize LMS state.
+
+         Args:
+             feature_dim: Dimension of weight vector (unused for LMS)
+
+         Returns:
+             LMS state containing the step-size
+         """
+         return LMSState(step_size=jnp.array(self._step_size, dtype=jnp.float32))
+
+     def update(
+         self,
+         state: LMSState,
+         error: Array,
+         observation: Array,
+     ) -> OptimizerUpdate:
+         """Compute LMS weight update.
+
+         Update rule: delta_w = alpha * error * x
+
+         Args:
+             state: Current LMS state
+             error: Prediction error (scalar)
+             observation: Feature vector
+
+         Returns:
+             OptimizerUpdate with weight and bias deltas
+         """
+         alpha = state.step_size
+         error_scalar = jnp.squeeze(error)
+
+         # Weight update: alpha * error * x
+         weight_delta = alpha * error_scalar * observation
+
+         # Bias update: alpha * error
+         bias_delta = alpha * error_scalar
+
+         return OptimizerUpdate(
+             weight_delta=weight_delta,
+             bias_delta=bias_delta,
+             new_state=state,  # LMS state doesn't change
+             metrics={"step_size": alpha},
+         )
+
+
+ class IDBD(Optimizer[IDBDState]):
+     """Incremental Delta-Bar-Delta optimizer.
+
+     IDBD maintains per-weight adaptive step-sizes that are meta-learned
+     based on gradient correlation. When successive gradients agree in sign,
+     the step-size for that weight increases. When they disagree, it decreases.
+
+     This implements Sutton's 1992 algorithm for adapting step-sizes online
+     without requiring manual tuning.
+
+     Reference: Sutton, R.S. (1992). "Adapting Bias by Gradient Descent:
+     An Incremental Version of Delta-Bar-Delta"
+
+     Attributes:
+         initial_step_size: Initial per-weight step-size
+         meta_step_size: Meta learning rate beta for adapting step-sizes
+     """
+
+     def __init__(
+         self,
+         initial_step_size: float = 0.01,
+         meta_step_size: float = 0.01,
+     ):
+         """Initialize IDBD optimizer.
+
+         Args:
+             initial_step_size: Initial value for per-weight step-sizes
+             meta_step_size: Meta learning rate beta for adapting step-sizes
+         """
+         self._initial_step_size = initial_step_size
+         self._meta_step_size = meta_step_size
+
+     def init(self, feature_dim: int) -> IDBDState:
+         """Initialize IDBD state.
+
+         Args:
+             feature_dim: Dimension of weight vector
+
+         Returns:
+             IDBD state with per-weight step-sizes and traces
+         """
+         return IDBDState(
+             log_step_sizes=jnp.full(
+                 feature_dim, jnp.log(self._initial_step_size), dtype=jnp.float32
+             ),
+             traces=jnp.zeros(feature_dim, dtype=jnp.float32),
+             meta_step_size=jnp.array(self._meta_step_size, dtype=jnp.float32),
+             bias_step_size=jnp.array(self._initial_step_size, dtype=jnp.float32),
+             bias_trace=jnp.array(0.0, dtype=jnp.float32),
+         )
+
+     def update(
+         self,
+         state: IDBDState,
+         error: Array,
+         observation: Array,
+     ) -> OptimizerUpdate:
+         """Compute IDBD weight update with adaptive step-sizes.
+
+         The IDBD algorithm:
+         1. Compute step-sizes: alpha_i = exp(log_alpha_i)
+         2. Update weights: w_i += alpha_i * error * x_i
+         3. Update log step-sizes: log_alpha_i += beta * error * x_i * h_i
+         4. Update traces: h_i = h_i * max(0, 1 - alpha_i * x_i^2) + alpha_i * error * x_i
+
+         The trace h_i tracks the correlation between current and past gradients.
+         When gradients consistently point in the same direction, h_i grows,
+         leading to larger step-sizes.
+
+         Args:
+             state: Current IDBD state
+             error: Prediction error (scalar)
+             observation: Feature vector
+
+         Returns:
+             OptimizerUpdate with weight deltas and updated state
+         """
+         error_scalar = jnp.squeeze(error)
+         beta = state.meta_step_size
+
+         # Current step-sizes (exponentiate log values)
+         alphas = jnp.exp(state.log_step_sizes)
+
+         # Weight updates: alpha_i * error * x_i
+         weight_delta = alphas * error_scalar * observation
+
+         # Meta-update: adapt step-sizes based on gradient correlation
+         # log_alpha_i += beta * error * x_i * h_i
+         gradient_correlation = error_scalar * observation * state.traces
+         new_log_step_sizes = state.log_step_sizes + beta * gradient_correlation
+
+         # Clip log step-sizes to prevent numerical issues
+         new_log_step_sizes = jnp.clip(new_log_step_sizes, -10.0, 2.0)
+
+         # Update traces: h_i = h_i * decay + alpha_i * error * x_i
+         # decay = max(0, 1 - alpha_i * x_i^2)
+         decay = jnp.maximum(0.0, 1.0 - alphas * observation**2)
+         new_traces = state.traces * decay + alphas * error_scalar * observation
+
+         # Bias updates (similar logic but scalar)
+         bias_alpha = state.bias_step_size
+         bias_delta = bias_alpha * error_scalar
+
+         # Update bias step-size
+         bias_gradient_correlation = error_scalar * state.bias_trace
+         new_bias_step_size = bias_alpha * jnp.exp(beta * bias_gradient_correlation)
+         new_bias_step_size = jnp.clip(new_bias_step_size, 1e-6, 1.0)
+
+         # Update bias trace
+         bias_decay = jnp.maximum(0.0, 1.0 - bias_alpha)
+         new_bias_trace = state.bias_trace * bias_decay + bias_alpha * error_scalar
+
+         new_state = IDBDState(
+             log_step_sizes=new_log_step_sizes,
+             traces=new_traces,
+             meta_step_size=beta,
+             bias_step_size=new_bias_step_size,
+             bias_trace=new_bias_trace,
+         )
+
+         return OptimizerUpdate(
+             weight_delta=weight_delta,
+             bias_delta=bias_delta,
+             new_state=new_state,
+             metrics={
+                 "mean_step_size": jnp.mean(alphas),
+                 "min_step_size": jnp.min(alphas),
+                 "max_step_size": jnp.max(alphas),
+             },
+         )
+
+
+ class Autostep(Optimizer[AutostepState]):
+     """Autostep optimizer with tuning-free step-size adaptation.
+
+     Autostep normalizes gradients to prevent large updates and adapts
+     per-weight step-sizes based on gradient correlation. The key innovation
+     is automatic normalization that makes the algorithm robust to different
+     feature scales.
+
+     The algorithm maintains:
+     - Per-weight step-sizes that adapt based on gradient correlation
+     - Running max of absolute gradients for normalization
+     - Traces for detecting consistent gradient directions
+
+     Reference: Mahmood, A.R., Sutton, R.S., Degris, T., & Pilarski, P.M. (2012).
+     "Tuning-free step-size adaptation"
+
+     Attributes:
+         initial_step_size: Initial per-weight step-size
+         meta_step_size: Meta learning rate mu for adapting step-sizes
+         normalizer_decay: Decay factor tau for gradient normalizers
+     """
+
+     def __init__(
+         self,
+         initial_step_size: float = 0.01,
+         meta_step_size: float = 0.01,
+         normalizer_decay: float = 0.99,
+     ):
+         """Initialize Autostep optimizer.
+
+         Args:
+             initial_step_size: Initial value for per-weight step-sizes
+             meta_step_size: Meta learning rate for adapting step-sizes
+             normalizer_decay: Decay factor for gradient normalizers (higher = slower decay)
+         """
+         self._initial_step_size = initial_step_size
+         self._meta_step_size = meta_step_size
+         self._normalizer_decay = normalizer_decay
+
+     def init(self, feature_dim: int) -> AutostepState:
+         """Initialize Autostep state.
+
+         Args:
+             feature_dim: Dimension of weight vector
+
+         Returns:
+             Autostep state with per-weight step-sizes, traces, and normalizers
+         """
+         return AutostepState(
+             step_sizes=jnp.full(feature_dim, self._initial_step_size, dtype=jnp.float32),
+             traces=jnp.zeros(feature_dim, dtype=jnp.float32),
+             normalizers=jnp.ones(feature_dim, dtype=jnp.float32),
+             meta_step_size=jnp.array(self._meta_step_size, dtype=jnp.float32),
+             normalizer_decay=jnp.array(self._normalizer_decay, dtype=jnp.float32),
+             bias_step_size=jnp.array(self._initial_step_size, dtype=jnp.float32),
+             bias_trace=jnp.array(0.0, dtype=jnp.float32),
+             bias_normalizer=jnp.array(1.0, dtype=jnp.float32),
+         )
+
+     def update(
+         self,
+         state: AutostepState,
+         error: Array,
+         observation: Array,
+     ) -> OptimizerUpdate:
+         """Compute Autostep weight update with normalized gradients.
+
+         The Autostep algorithm:
+         1. Compute gradient: g_i = error * x_i
+         2. Normalize gradient: g_i' = g_i / max(|g_i|, v_i)
+         3. Update weights: w_i += alpha_i * g_i'
+         4. Update step-sizes: alpha_i *= exp(mu * g_i' * h_i)
+         5. Update traces: h_i = h_i * (1 - alpha_i) + alpha_i * g_i'
+         6. Update normalizers: v_i = max(|g_i|, v_i * tau)
+
+         Args:
+             state: Current Autostep state
+             error: Prediction error (scalar)
+             observation: Feature vector
+
+         Returns:
+             OptimizerUpdate with weight deltas and updated state
+         """
+         error_scalar = jnp.squeeze(error)
+         mu = state.meta_step_size
+         tau = state.normalizer_decay
+
+         # Compute raw gradient
+         gradient = error_scalar * observation
+
+         # Normalize gradient using running max
+         abs_gradient = jnp.abs(gradient)
+         normalizer = jnp.maximum(abs_gradient, state.normalizers)
+         normalized_gradient = gradient / (normalizer + 1e-8)
+
+         # Compute weight delta using normalized gradient
+         weight_delta = state.step_sizes * normalized_gradient
+
+         # Update step-sizes based on gradient correlation
+         gradient_correlation = normalized_gradient * state.traces
+         new_step_sizes = state.step_sizes * jnp.exp(mu * gradient_correlation)
+
+         # Clip step-sizes to prevent instability
+         new_step_sizes = jnp.clip(new_step_sizes, 1e-8, 1.0)
+
+         # Update traces with decay based on step-size
+         trace_decay = 1.0 - state.step_sizes
+         new_traces = state.traces * trace_decay + state.step_sizes * normalized_gradient
+
+         # Update normalizers with decay
+         new_normalizers = jnp.maximum(abs_gradient, state.normalizers * tau)
+
+         # Bias updates (similar logic)
+         bias_gradient = error_scalar
+         abs_bias_gradient = jnp.abs(bias_gradient)
+         bias_normalizer = jnp.maximum(abs_bias_gradient, state.bias_normalizer)
+         normalized_bias_gradient = bias_gradient / (bias_normalizer + 1e-8)
+
+         bias_delta = state.bias_step_size * normalized_bias_gradient
+
+         bias_correlation = normalized_bias_gradient * state.bias_trace
+         new_bias_step_size = state.bias_step_size * jnp.exp(mu * bias_correlation)
+         new_bias_step_size = jnp.clip(new_bias_step_size, 1e-8, 1.0)
+
+         bias_trace_decay = 1.0 - state.bias_step_size
+         new_bias_trace = (
+             state.bias_trace * bias_trace_decay + state.bias_step_size * normalized_bias_gradient
+         )
+
+         new_bias_normalizer = jnp.maximum(abs_bias_gradient, state.bias_normalizer * tau)
+
+         new_state = AutostepState(
+             step_sizes=new_step_sizes,
+             traces=new_traces,
+             normalizers=new_normalizers,
+             meta_step_size=mu,
+             normalizer_decay=tau,
+             bias_step_size=new_bias_step_size,
+             bias_trace=new_bias_trace,
+             bias_normalizer=new_bias_normalizer,
+         )
+
+         return OptimizerUpdate(
+             weight_delta=weight_delta,
+             bias_delta=bias_delta,
+             new_state=new_state,
+             metrics={
+                 "mean_step_size": jnp.mean(state.step_sizes),
+                 "min_step_size": jnp.min(state.step_sizes),
+                 "max_step_size": jnp.max(state.step_sizes),
+                 "mean_normalizer": jnp.mean(state.normalizers),
+             },
+         )
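
A worked single-step sketch for LMS (not part of the package): with step_size 0.1, error 2.0, and x = [1.0, 0.5, -0.5], the weight delta is alpha * error * x = [0.2, 0.1, -0.1] and the bias delta is 0.2. The predictor variables below are illustrative assumptions; the optimizer only returns deltas, so the caller owns the weights.

    # Sketch only: assumes LMS from the hunk above is in scope.
    import jax.numpy as jnp

    lms = LMS(step_size=0.1)
    opt_state = lms.init(feature_dim=3)

    weights = jnp.zeros(3)
    bias = jnp.array(0.0)
    x = jnp.array([1.0, 0.5, -0.5])
    target = jnp.array(2.0)

    error = target - (weights @ x + bias)      # 2.0
    update = lms.update(opt_state, error, x)

    weights = weights + update.weight_delta    # [0.2, 0.1, -0.1]
    bias = bias + update.bias_delta            # 0.2
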
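A streaming sketch for IDBD (not part of the package): it fits a linear target online and reads the step-size metrics afterward. The true weights, stream length, seed, and meta step-size value are illustrative assumptions.

    # Sketch only: assumes IDBD from the hunk above is in scope.
    import jax
    import jax.numpy as jnp

    idbd = IDBD(initial_step_size=0.01, meta_step_size=0.05)
    opt_state = idbd.init(feature_dim=2)

    weights = jnp.zeros(2)
    bias = jnp.array(0.0)
    true_weights = jnp.array([1.0, -2.0])

    key = jax.random.PRNGKey(0)
    for _ in range(500):
        key, subkey = jax.random.split(key)
        x = jax.random.normal(subkey, (2,))
        error = true_weights @ x - (weights @ x + bias)
        update = idbd.update(opt_state, error, x)
        weights = weights + update.weight_delta
        bias = bias + update.bias_delta
        opt_state = update.new_state

    # Per-weight step-sizes grow while successive gradients stay correlated.
    print(update.metrics["mean_step_size"])
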
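A sketch of Autostep's scale robustness (not part of the package): the two features differ in scale by 100x, which the running-max normalizers are designed to absorb. The scales, target weights, and loop length are illustrative assumptions.

    # Sketch only: assumes Autostep from the hunk above is in scope.
    import jax
    import jax.numpy as jnp

    autostep = Autostep(initial_step_size=0.01, meta_step_size=0.01, normalizer_decay=0.99)
    opt_state = autostep.init(feature_dim=2)

    weights = jnp.zeros(2)
    bias = jnp.array(0.0)
    scales = jnp.array([1.0, 100.0])           # badly mismatched feature scales
    true_weights = jnp.array([0.5, -0.01])

    key = jax.random.PRNGKey(1)
    for _ in range(500):
        key, subkey = jax.random.split(key)
        x = jax.random.normal(subkey, (2,)) * scales
        error = true_weights @ x - (weights @ x + bias)
        update = autostep.update(opt_state, error, x)
        weights = weights + update.weight_delta
        bias = bias + update.bias_delta
        opt_state = update.new_state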