alberta_framework-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,530 @@
"""Learning units for continual learning.

Implements learners that combine function approximation with optimizers
for temporally-uniform learning. Uses JAX's scan for efficient JIT-compiled
training loops.
"""

from typing import NamedTuple

import jax
import jax.numpy as jnp
from jax import Array

from alberta_framework.core.normalizers import NormalizerState, OnlineNormalizer
from alberta_framework.core.optimizers import LMS, Optimizer
from alberta_framework.core.types import (
    AutostepState,
    IDBDState,
    LearnerState,
    LMSState,
    Observation,
    Prediction,
    StepSizeHistory,
    StepSizeTrackingConfig,
    Target,
)
from alberta_framework.streams.base import ScanStream

# Type alias for any optimizer type
AnyOptimizer = Optimizer[LMSState] | Optimizer[IDBDState] | Optimizer[AutostepState]

class UpdateResult(NamedTuple):
    """Result of a learner update step.

    Attributes:
        state: Updated learner state
        prediction: Prediction made before update
        error: Prediction error
        metrics: Array of metrics [squared_error, error, mean_step_size]
    """

    state: LearnerState
    prediction: Prediction
    error: Array
    metrics: Array


class LinearLearner:
    """Linear function approximator with pluggable optimizer.

    Computes predictions as: y = w @ x + b

    The learner maintains weights and bias, delegating the adaptation
    of learning rates to the optimizer (e.g., LMS or IDBD).

    This follows the Alberta Plan philosophy of temporal uniformity:
    every component updates at every time step.

    Attributes:
        optimizer: The optimizer to use for weight updates
    """

    def __init__(self, optimizer: AnyOptimizer | None = None):
        """Initialize the linear learner.

        Args:
            optimizer: Optimizer for weight updates. Defaults to LMS(0.01)
        """
        self._optimizer: AnyOptimizer = optimizer or LMS(step_size=0.01)

    def init(self, feature_dim: int) -> LearnerState:
        """Initialize learner state.

        Args:
            feature_dim: Dimension of the input feature vector

        Returns:
            Initial learner state with zero weights and bias
        """
        optimizer_state = self._optimizer.init(feature_dim)

        return LearnerState(
            weights=jnp.zeros(feature_dim, dtype=jnp.float32),
            bias=jnp.array(0.0, dtype=jnp.float32),
            optimizer_state=optimizer_state,
        )

    def predict(self, state: LearnerState, observation: Observation) -> Prediction:
        """Compute prediction for an observation.

        Args:
            state: Current learner state
            observation: Input feature vector

        Returns:
            Shape-(1,) prediction y = w @ x + b
        """
        return jnp.atleast_1d(jnp.dot(state.weights, observation) + state.bias)
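
    # For intuition (illustrative numbers, not from the package): with
    # weights [1., 2.], bias 0.5 and observation [3., 4.], this returns
    # jnp.atleast_1d(1.0 * 3.0 + 2.0 * 4.0 + 0.5) == [11.5].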

    def update(
        self,
        state: LearnerState,
        observation: Observation,
        target: Target,
    ) -> UpdateResult:
        """Update learner given observation and target.

        Performs one step of the learning algorithm:
        1. Compute prediction
        2. Compute error
        3. Get weight updates from optimizer
        4. Apply updates to weights and bias

        Args:
            state: Current learner state
            observation: Input feature vector
            target: Desired output

        Returns:
            UpdateResult with new state, prediction, error, and metrics
        """
        # Make prediction
        prediction = self.predict(state, observation)

        # Compute error (target - prediction)
        error = jnp.squeeze(target) - jnp.squeeze(prediction)

        # Get update from optimizer
        # Note: type ignore needed because we can't statically prove optimizer_state
        # matches the optimizer's expected state type (though they will at runtime)
        opt_update = self._optimizer.update(
            state.optimizer_state,  # type: ignore[arg-type]
            error,
            observation,
        )

        # Apply updates
        new_weights = state.weights + opt_update.weight_delta
        new_bias = state.bias + opt_update.bias_delta

        new_state = LearnerState(
            weights=new_weights,
            bias=new_bias,
            optimizer_state=opt_update.new_state,
        )

        # Pack metrics as array for scan compatibility
        # Format: [squared_error, error, mean_step_size (if adaptive)]
        squared_error = error**2
        mean_step_size = opt_update.metrics.get("mean_step_size", 0.0)
        metrics = jnp.array([squared_error, error, mean_step_size], dtype=jnp.float32)

        return UpdateResult(
            state=new_state,
            prediction=prediction,
            error=jnp.atleast_1d(error),
            metrics=metrics,
        )


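# A minimal sketch of one update cycle (editor's example, not part of the
# published module; LinearLearner and LMS are the classes defined above,
# the numbers are illustrative). update() predicts, measures the error,
# and applies the optimizer's deltas in a single temporally-uniform step.
def _example_single_update() -> None:
    learner = LinearLearner(optimizer=LMS(step_size=0.1))
    state = learner.init(feature_dim=2)

    x = jnp.array([1.0, 2.0], dtype=jnp.float32)
    result = learner.update(state, x, jnp.array(3.0, dtype=jnp.float32))

    # Zero-initialized weights predict 0.0, so error == 3.0 and
    # metrics == [9.0, 3.0, mean_step_size].
    assert float(result.error[0]) == 3.0

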
def run_learning_loop[StreamStateT](
    learner: LinearLearner,
    stream: ScanStream[StreamStateT],
    num_steps: int,
    key: Array,
    learner_state: LearnerState | None = None,
    step_size_tracking: StepSizeTrackingConfig | None = None,
) -> tuple[LearnerState, Array] | tuple[LearnerState, Array, StepSizeHistory]:
    """Run the learning loop using jax.lax.scan.

    This is a JIT-compiled learning loop that uses scan for efficiency.
    It returns metrics as a fixed-size array rather than a list of dicts.

    Args:
        learner: The learner to train
        stream: Experience stream providing (observation, target) pairs
        num_steps: Number of learning steps to run
        key: JAX random key for stream initialization
        learner_state: Initial state (if None, will be initialized from stream)
        step_size_tracking: Optional config for recording per-weight step-sizes.
            When provided, returns a 3-tuple including StepSizeHistory.

    Returns:
        If step_size_tracking is None:
            Tuple of (final_state, metrics_array) where metrics_array has shape
            (num_steps, 3) with columns [squared_error, error, mean_step_size]
        If step_size_tracking is provided:
            Tuple of (final_state, metrics_array, step_size_history)

    Raises:
        ValueError: If step_size_tracking.interval is less than 1 or greater than num_steps
    """
    # Validate tracking config
    if step_size_tracking is not None:
        if step_size_tracking.interval < 1:
            raise ValueError(
                f"step_size_tracking.interval must be >= 1, got {step_size_tracking.interval}"
            )
        if step_size_tracking.interval > num_steps:
            raise ValueError(
                f"step_size_tracking.interval ({step_size_tracking.interval}) "
                f"must be <= num_steps ({num_steps})"
            )

    # Initialize states
    if learner_state is None:
        learner_state = learner.init(stream.feature_dim)
    stream_state = stream.init(key)

    feature_dim = stream.feature_dim

    if step_size_tracking is None:
        # Original behavior without tracking
        def step_fn(
            carry: tuple[LearnerState, StreamStateT], idx: Array
        ) -> tuple[tuple[LearnerState, StreamStateT], Array]:
            l_state, s_state = carry
            timestep, new_s_state = stream.step(s_state, idx)
            result = learner.update(l_state, timestep.observation, timestep.target)
            return (result.state, new_s_state), result.metrics

        (final_learner, _), metrics = jax.lax.scan(
            step_fn, (learner_state, stream_state), jnp.arange(num_steps)
        )

        return final_learner, metrics

    else:
        # Step-size tracking enabled
        interval = step_size_tracking.interval
        include_bias = step_size_tracking.include_bias
        # Ceil division: a recording fires at every idx in [0, num_steps)
        # with idx % interval == 0, so a final partial interval still needs a slot
        num_recordings = (num_steps + interval - 1) // interval
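
        # Worked example of the interval arithmetic (illustrative numbers):
        # num_steps=10, interval=3 records at idx 0, 3, 6 and 9, so
        # (10 + 3 - 1) // 3 == 4 slots are needed, indexed by idx // interval.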

        # Pre-allocate history arrays
        step_size_history = jnp.zeros((num_recordings, feature_dim), dtype=jnp.float32)
        bias_history = (
            jnp.zeros(num_recordings, dtype=jnp.float32) if include_bias else None
        )
        recording_indices = jnp.zeros(num_recordings, dtype=jnp.int32)

        def step_fn_with_tracking(
            carry: tuple[LearnerState, StreamStateT, Array, Array | None, Array], idx: Array
        ) -> tuple[tuple[LearnerState, StreamStateT, Array, Array | None, Array], Array]:
            l_state, s_state, ss_history, b_history, rec_indices = carry

            # Perform learning step
            timestep, new_s_state = stream.step(s_state, idx)
            result = learner.update(l_state, timestep.observation, timestep.target)

            # Check if we should record at this step (idx % interval == 0)
            should_record = (idx % interval) == 0
            recording_idx = idx // interval

            # Extract current step-sizes
            # Use hasattr checks at trace time (this works because the type is fixed)
            opt_state = result.state.optimizer_state
            if hasattr(opt_state, "log_step_sizes"):
                # IDBD stores log step-sizes
                weight_ss = jnp.exp(opt_state.log_step_sizes)  # type: ignore[union-attr]
                bias_ss = opt_state.bias_step_size  # type: ignore[union-attr]
            elif hasattr(opt_state, "step_sizes"):
                # Autostep stores step-sizes directly
                weight_ss = opt_state.step_sizes  # type: ignore[union-attr]
                bias_ss = opt_state.bias_step_size  # type: ignore[union-attr]
            else:
                # LMS has a single fixed step-size
                weight_ss = jnp.full(feature_dim, opt_state.step_size)
                bias_ss = opt_state.step_size

            # Conditionally update history arrays
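            # (jax.lax.cond traces both branches, and the arrays keep the same
            # fixed shape on either path, which is what keeps the scan body
            # JIT-compatible.)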
            new_ss_history = jax.lax.cond(
                should_record,
                lambda _: ss_history.at[recording_idx].set(weight_ss),
                lambda _: ss_history,
                None,
            )

            new_b_history = b_history
            if b_history is not None:
                new_b_history = jax.lax.cond(
                    should_record,
                    lambda _: b_history.at[recording_idx].set(bias_ss),
                    lambda _: b_history,
                    None,
                )

            new_rec_indices = jax.lax.cond(
                should_record,
                lambda _: rec_indices.at[recording_idx].set(idx),
                lambda _: rec_indices,
                None,
            )

            return (
                result.state,
                new_s_state,
                new_ss_history,
                new_b_history,
                new_rec_indices,
            ), result.metrics

        initial_carry = (
            learner_state,
            stream_state,
            step_size_history,
            bias_history,
            recording_indices,
        )

        (final_learner, _, final_ss_history, final_b_history, final_rec_indices), metrics = (
            jax.lax.scan(step_fn_with_tracking, initial_carry, jnp.arange(num_steps))
        )

        history = StepSizeHistory(
            step_sizes=final_ss_history,
            bias_step_sizes=final_b_history,
            recording_indices=final_rec_indices,
        )

        return final_learner, metrics, history


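# Sketch of driving the loop above (editor's example, not part of the
# published module). It assumes `stream` is some concrete ScanStream
# implementation from alberta_framework.streams, and that
# StepSizeTrackingConfig accepts the two field names this module reads
# from it (interval, include_bias).
def _example_learning_loop(stream: ScanStream) -> None:
    learner = LinearLearner()
    key = jax.random.PRNGKey(0)

    # Without tracking: metrics has shape (num_steps, 3).
    _, metrics = run_learning_loop(learner, stream, num_steps=100, key=key)

    # With tracking: a 3-tuple, recording step-sizes every 10 steps.
    tracking = StepSizeTrackingConfig(interval=10, include_bias=True)
    _, metrics, history = run_learning_loop(
        learner, stream, num_steps=100, key=key, step_size_tracking=tracking
    )

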
class NormalizedLearnerState(NamedTuple):
    """State for a learner with online feature normalization.

    Attributes:
        learner_state: Underlying learner state (weights, bias, optimizer)
        normalizer_state: Online normalizer state (mean, var estimates)
    """

    learner_state: LearnerState
    normalizer_state: NormalizerState


class NormalizedUpdateResult(NamedTuple):
    """Result of a normalized learner update step.

    Attributes:
        state: Updated normalized learner state
        prediction: Prediction made before update
        error: Prediction error
        metrics: Array of metrics [squared_error, error, mean_step_size, normalizer_mean_var]
    """

    state: NormalizedLearnerState
    prediction: Prediction
    error: Array
    metrics: Array


class NormalizedLinearLearner:
    """Linear learner with online feature normalization.

    Wraps a LinearLearner with online feature normalization, following
    the Alberta Plan's approach to handling varying feature scales.

    Normalization is applied to features before prediction and learning:
        x_normalized = (x - mean) / (std + epsilon)

    The normalizer statistics update at every time step, maintaining
    temporal uniformity.

    Attributes:
        learner: Underlying linear learner
        normalizer: Online feature normalizer
    """

    def __init__(
        self,
        optimizer: AnyOptimizer | None = None,
        normalizer: OnlineNormalizer | None = None,
    ):
        """Initialize the normalized linear learner.

        Args:
            optimizer: Optimizer for weight updates. Defaults to LMS(0.01)
            normalizer: Feature normalizer. Defaults to OnlineNormalizer()
        """
        self._learner = LinearLearner(optimizer=optimizer or LMS(step_size=0.01))
        self._normalizer = normalizer or OnlineNormalizer()

    def init(self, feature_dim: int) -> NormalizedLearnerState:
        """Initialize normalized learner state.

        Args:
            feature_dim: Dimension of the input feature vector

        Returns:
            Initial state with zero weights and unit variance estimates
        """
        return NormalizedLearnerState(
            learner_state=self._learner.init(feature_dim),
            normalizer_state=self._normalizer.init(feature_dim),
        )

    def predict(
        self,
        state: NormalizedLearnerState,
        observation: Observation,
    ) -> Prediction:
        """Compute prediction for an observation.

        Normalizes the observation using current statistics before prediction.

        Args:
            state: Current normalized learner state
            observation: Raw (unnormalized) input feature vector

        Returns:
            Shape-(1,) prediction y = w @ normalize(x) + b
        """
        normalized_obs = self._normalizer.normalize_only(
            state.normalizer_state, observation
        )
        return self._learner.predict(state.learner_state, normalized_obs)

    def update(
        self,
        state: NormalizedLearnerState,
        observation: Observation,
        target: Target,
    ) -> NormalizedUpdateResult:
        """Update learner given observation and target.

        Performs one step of the learning algorithm:
        1. Normalize observation (and update normalizer statistics)
        2. Compute prediction using normalized features
        3. Compute error
        4. Get weight updates from optimizer
        5. Apply updates

        Args:
            state: Current normalized learner state
            observation: Raw (unnormalized) input feature vector
            target: Desired output

        Returns:
            NormalizedUpdateResult with new state, prediction, error, and metrics
        """
        # Normalize observation and update normalizer state
        normalized_obs, new_normalizer_state = self._normalizer.normalize(
            state.normalizer_state, observation
        )

        # Delegate to underlying learner
        result = self._learner.update(
            state.learner_state,
            normalized_obs,
            target,
        )

        # Build combined state
        new_state = NormalizedLearnerState(
            learner_state=result.state,
            normalizer_state=new_normalizer_state,
        )

        # Add normalizer metrics to the metrics array
        normalizer_mean_var = jnp.mean(new_normalizer_state.var)
        metrics = jnp.concatenate([result.metrics, jnp.array([normalizer_mean_var])])

        return NormalizedUpdateResult(
            state=new_state,
            prediction=result.prediction,
            error=result.error,
            metrics=metrics,
        )


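# How the wrapper composes on one step (editor's sketch; it uses only the
# APIs defined above, with illustrative numbers). Raw features pass through
# the normalizer before the linear learner ever sees them.
def _example_normalized_update() -> None:
    learner = NormalizedLinearLearner()
    state = learner.init(feature_dim=3)

    x_raw = jnp.array([100.0, -5.0, 0.2], dtype=jnp.float32)
    result = learner.update(state, x_raw, jnp.array(1.0, dtype=jnp.float32))

    # metrics has 4 entries: the learner's three plus normalizer_mean_var.
    assert result.metrics.shape == (4,)

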
def run_normalized_learning_loop[StreamStateT](
    learner: NormalizedLinearLearner,
    stream: ScanStream[StreamStateT],
    num_steps: int,
    key: Array,
    learner_state: NormalizedLearnerState | None = None,
) -> tuple[NormalizedLearnerState, Array]:
    """Run the learning loop with normalization using jax.lax.scan.

    Args:
        learner: The normalized learner to train
        stream: Experience stream providing (observation, target) pairs
        num_steps: Number of learning steps to run
        key: JAX random key for stream initialization
        learner_state: Initial state (if None, will be initialized from stream)

    Returns:
        Tuple of (final_state, metrics_array) where metrics_array has shape
        (num_steps, 4) with columns [squared_error, error, mean_step_size, normalizer_mean_var]
    """
    # Initialize states
    if learner_state is None:
        learner_state = learner.init(stream.feature_dim)
    stream_state = stream.init(key)

    def step_fn(
        carry: tuple[NormalizedLearnerState, StreamStateT], idx: Array
    ) -> tuple[tuple[NormalizedLearnerState, StreamStateT], Array]:
        l_state, s_state = carry
        timestep, new_s_state = stream.step(s_state, idx)
        result = learner.update(l_state, timestep.observation, timestep.target)
        return (result.state, new_s_state), result.metrics

    (final_learner, _), metrics = jax.lax.scan(
        step_fn, (learner_state, stream_state), jnp.arange(num_steps)
    )

    return final_learner, metrics


def metrics_to_dicts(metrics: Array, normalized: bool = False) -> list[dict[str, float]]:
    """Convert metrics array to list of dicts for backward compatibility.

    Args:
        metrics: Array of shape (num_steps, 3) or (num_steps, 4)
        normalized: If True, expects 4 columns including normalizer_mean_var

    Returns:
        List of metric dictionaries
    """
    result = []
    for row in metrics:
        d = {
            "squared_error": float(row[0]),
            "error": float(row[1]),
            "mean_step_size": float(row[2]),
        }
        if normalized and len(row) > 3:
            d["normalizer_mean_var"] = float(row[3])
        result.append(d)
    return result
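

# Converting the packed metrics back to dicts (editor's sketch; `metrics`
# would come from one of the loops above). The array form is what scan
# requires; the dict form is friendlier for logging and plotting.
def _example_metrics_to_dicts(metrics: Array) -> None:
    # e.g. metrics from run_normalized_learning_loop, shape (num_steps, 4)
    per_step = metrics_to_dicts(metrics, normalized=True)
    first = per_step[0]
    print(first["squared_error"], first["normalizer_mean_var"])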