alberta_framework-0.2.2-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,1070 @@
1
+ """Learning units for continual learning.
2
+
3
+ Implements learners that combine function approximation with optimizers
4
+ for temporally-uniform learning. Uses JAX's scan for efficient JIT-compiled
5
+ training loops.
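+
+ Example (illustrative sketch; `run_learning_loop` is defined in this module,
+ and the other names mirror the imports used in the batched examples below):
+
+ ```python
+ import jax.random as jr
+ from alberta_framework import LinearLearner, IDBD, RandomWalkStream
+
+ stream = RandomWalkStream(feature_dim=10)
+ learner = LinearLearner(optimizer=IDBD())
+
+ state, metrics = run_learning_loop(learner, stream, num_steps=1000, key=jr.key(0))
+ # metrics has shape (1000, 3): [squared_error, error, mean_step_size]
+ ```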
6
+ """
7
+
8
+ from typing import NamedTuple, cast
9
+
10
+ import jax
11
+ import jax.numpy as jnp
12
+ from jax import Array
13
+
14
+ from alberta_framework.core.normalizers import NormalizerState, OnlineNormalizer
15
+ from alberta_framework.core.optimizers import LMS, Optimizer
16
+ from alberta_framework.core.types import (
17
+ AutostepState,
18
+ BatchedLearningResult,
19
+ BatchedNormalizedResult,
20
+ IDBDState,
21
+ LearnerState,
22
+ LMSState,
23
+ NormalizerHistory,
24
+ NormalizerTrackingConfig,
25
+ Observation,
26
+ Prediction,
27
+ StepSizeHistory,
28
+ StepSizeTrackingConfig,
29
+ Target,
30
+ )
31
+ from alberta_framework.streams.base import ScanStream
32
+
33
+ # Type alias for any optimizer type
34
+ AnyOptimizer = Optimizer[LMSState] | Optimizer[IDBDState] | Optimizer[AutostepState]
35
+
36
+ class UpdateResult(NamedTuple):
37
+ """Result of a learner update step.
38
+
39
+ Attributes:
40
+ state: Updated learner state
41
+ prediction: Prediction made before update
42
+ error: Prediction error
43
+ metrics: Array of metrics [squared_error, error, mean_step_size]
44
+ """
45
+
46
+ state: LearnerState
47
+ prediction: Prediction
48
+ error: Array
49
+ metrics: Array
50
+
51
+
52
+ class LinearLearner:
53
+ """Linear function approximator with pluggable optimizer.
54
+
55
+ Computes predictions as: `y = w @ x + b`
56
+
57
+ The learner maintains weights and bias, delegating the adaptation
58
+ of learning rates to the optimizer (e.g., LMS or IDBD).
59
+
60
+ This follows the Alberta Plan philosophy of temporal uniformity:
61
+ every component updates at every time step.
62
+
63
+ Attributes:
64
+ optimizer: The optimizer to use for weight updates
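+
+ Example (illustrative sketch of manual use; the toy arrays are placeholders):
+
+ ```python
+ import jax.numpy as jnp
+ from alberta_framework import LinearLearner, IDBD
+
+ learner = LinearLearner(optimizer=IDBD())
+ state = learner.init(feature_dim=3)
+
+ x = jnp.array([1.0, 0.5, -0.2])
+ prediction = learner.predict(state, x)  # 0.0 with zero-initialized weights
+
+ result = learner.update(state, x, jnp.array(2.0))
+ state = result.state  # carry the new state into the next time step
+ ```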
65
+ """
66
+
67
+ def __init__(self, optimizer: AnyOptimizer | None = None):
68
+ """Initialize the linear learner.
69
+
70
+ Args:
71
+ optimizer: Optimizer for weight updates. Defaults to LMS(0.01)
72
+ """
73
+ self._optimizer: AnyOptimizer = optimizer or LMS(step_size=0.01)
74
+
75
+ def init(self, feature_dim: int) -> LearnerState:
76
+ """Initialize learner state.
77
+
78
+ Args:
79
+ feature_dim: Dimension of the input feature vector
80
+
81
+ Returns:
82
+ Initial learner state with zero weights and bias
83
+ """
84
+ optimizer_state = self._optimizer.init(feature_dim)
85
+
86
+ return LearnerState(
87
+ weights=jnp.zeros(feature_dim, dtype=jnp.float32),
88
+ bias=jnp.array(0.0, dtype=jnp.float32),
89
+ optimizer_state=optimizer_state,
90
+ )
91
+
92
+ def predict(self, state: LearnerState, observation: Observation) -> Prediction:
93
+ """Compute prediction for an observation.
94
+
95
+ Args:
96
+ state: Current learner state
97
+ observation: Input feature vector
98
+
99
+ Returns:
100
+ Scalar prediction `y = w @ x + b`
101
+ """
102
+ return jnp.atleast_1d(jnp.dot(state.weights, observation) + state.bias)
103
+
104
+ def update(
105
+ self,
106
+ state: LearnerState,
107
+ observation: Observation,
108
+ target: Target,
109
+ ) -> UpdateResult:
110
+ """Update learner given observation and target.
111
+
112
+ Performs one step of the learning algorithm:
113
+ 1. Compute prediction
114
+ 2. Compute error
115
+ 3. Get weight updates from optimizer
116
+ 4. Apply updates to weights and bias
117
+
118
+ Args:
119
+ state: Current learner state
120
+ observation: Input feature vector
121
+ target: Desired output
122
+
123
+ Returns:
124
+ UpdateResult with new state, prediction, error, and metrics
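+
+ Example (sketch of a single update with a fixed step-size; LMS is imported
+ from the same module this file already uses):
+
+ ```python
+ import jax.numpy as jnp
+ from alberta_framework import LinearLearner
+ from alberta_framework.core.optimizers import LMS
+
+ learner = LinearLearner(optimizer=LMS(step_size=0.1))
+ state = learner.init(feature_dim=2)
+
+ result = learner.update(state, jnp.array([1.0, 0.0]), jnp.array(1.0))
+ # error == target - prediction == 1.0 with zero-initialized weights
+ # result.metrics packs [squared_error, error, mean_step_size]
+ ```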
125
+ """
126
+ # Make prediction
127
+ prediction = self.predict(state, observation)
128
+
129
+ # Compute error (target - prediction)
130
+ error = jnp.squeeze(target) - jnp.squeeze(prediction)
131
+
132
+ # Get update from optimizer
133
+ # Note: type ignore needed because we can't statically prove optimizer_state
134
+ # matches the optimizer's expected state type (though they do match at runtime)
135
+ opt_update = self._optimizer.update(
136
+ state.optimizer_state, # type: ignore[arg-type]
137
+ error,
138
+ observation,
139
+ )
140
+
141
+ # Apply updates
142
+ new_weights = state.weights + opt_update.weight_delta
143
+ new_bias = state.bias + opt_update.bias_delta
144
+
145
+ new_state = LearnerState(
146
+ weights=new_weights,
147
+ bias=new_bias,
148
+ optimizer_state=opt_update.new_state,
149
+ )
150
+
151
+ # Pack metrics as array for scan compatibility
152
+ # Format: [squared_error, error, mean_step_size (0.0 if the optimizer does not report one)]
153
+ squared_error = error**2
154
+ mean_step_size = opt_update.metrics.get("mean_step_size", 0.0)
155
+ metrics = jnp.array([squared_error, error, mean_step_size], dtype=jnp.float32)
156
+
157
+ return UpdateResult(
158
+ state=new_state,
159
+ prediction=prediction,
160
+ error=jnp.atleast_1d(error),
161
+ metrics=metrics,
162
+ )
163
+
164
+
165
+ def run_learning_loop[StreamStateT](
166
+ learner: LinearLearner,
167
+ stream: ScanStream[StreamStateT],
168
+ num_steps: int,
169
+ key: Array,
170
+ learner_state: LearnerState | None = None,
171
+ step_size_tracking: StepSizeTrackingConfig | None = None,
172
+ ) -> tuple[LearnerState, Array] | tuple[LearnerState, Array, StepSizeHistory]:
173
+ """Run the learning loop using jax.lax.scan.
174
+
175
+ This is a JIT-compiled learning loop that uses scan for efficiency.
176
+ It returns metrics as a fixed-size array rather than a list of dicts.
177
+
178
+ Args:
179
+ learner: The learner to train
180
+ stream: Experience stream providing (observation, target) pairs
181
+ num_steps: Number of learning steps to run
182
+ key: JAX random key for stream initialization
183
+ learner_state: Initial state (if None, will be initialized from stream)
184
+ step_size_tracking: Optional config for recording per-weight step-sizes.
185
+ When provided, returns a 3-tuple including StepSizeHistory.
186
+
187
+ Returns:
188
+ If step_size_tracking is None:
189
+ Tuple of (final_state, metrics_array) where metrics_array has shape
190
+ (num_steps, 3) with columns [squared_error, error, mean_step_size]
191
+ If step_size_tracking is provided:
192
+ Tuple of (final_state, metrics_array, step_size_history)
193
+
194
+ Raises:
195
+ ValueError: If step_size_tracking.interval is less than 1 or greater than num_steps
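+
+ Example (sketch with step-size tracking; the StepSizeTrackingConfig keyword
+ arguments shown here are assumptions about that config type):
+
+ ```python
+ import jax.random as jr
+ from alberta_framework import LinearLearner, IDBD, RandomWalkStream
+ from alberta_framework.core.types import StepSizeTrackingConfig
+
+ stream = RandomWalkStream(feature_dim=10)
+ learner = LinearLearner(optimizer=IDBD())
+
+ tracking = StepSizeTrackingConfig(interval=100, include_bias=True)
+ state, metrics, history = run_learning_loop(
+ learner, stream, num_steps=10000, key=jr.key(0), step_size_tracking=tracking
+ )
+ # metrics has shape (10000, 3); history.step_sizes has shape (100, 10)
+ ```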
196
+ """
197
+ # Validate tracking config
198
+ if step_size_tracking is not None:
199
+ if step_size_tracking.interval < 1:
200
+ raise ValueError(
201
+ f"step_size_tracking.interval must be >= 1, got {step_size_tracking.interval}"
202
+ )
203
+ if step_size_tracking.interval > num_steps:
204
+ raise ValueError(
205
+ f"step_size_tracking.interval ({step_size_tracking.interval}) "
206
+ f"must be <= num_steps ({num_steps})"
207
+ )
208
+
209
+ # Initialize states
210
+ if learner_state is None:
211
+ learner_state = learner.init(stream.feature_dim)
212
+ stream_state = stream.init(key)
213
+
214
+ feature_dim = stream.feature_dim
215
+
216
+ if step_size_tracking is None:
217
+ # Original behavior without tracking
218
+ def step_fn(
219
+ carry: tuple[LearnerState, StreamStateT], idx: Array
220
+ ) -> tuple[tuple[LearnerState, StreamStateT], Array]:
221
+ l_state, s_state = carry
222
+ timestep, new_s_state = stream.step(s_state, idx)
223
+ result = learner.update(l_state, timestep.observation, timestep.target)
224
+ return (result.state, new_s_state), result.metrics
225
+
226
+ (final_learner, _), metrics = jax.lax.scan(
227
+ step_fn, (learner_state, stream_state), jnp.arange(num_steps)
228
+ )
229
+
230
+ return final_learner, metrics
231
+
232
+ else:
233
+ # Step-size tracking enabled
234
+ interval = step_size_tracking.interval
235
+ include_bias = step_size_tracking.include_bias
236
+ num_recordings = num_steps // interval
237
+
238
+ # Pre-allocate history arrays
239
+ step_size_history = jnp.zeros((num_recordings, feature_dim), dtype=jnp.float32)
240
+ bias_history = (
241
+ jnp.zeros(num_recordings, dtype=jnp.float32) if include_bias else None
242
+ )
243
+ recording_indices = jnp.zeros(num_recordings, dtype=jnp.int32)
244
+
245
+ # Check if we need to track Autostep normalizers
246
+ # We detect this at trace time by checking the initial optimizer state
247
+ track_normalizers = hasattr(learner_state.optimizer_state, "normalizers")
248
+ normalizer_history = (
249
+ jnp.zeros((num_recordings, feature_dim), dtype=jnp.float32)
250
+ if track_normalizers
251
+ else None
252
+ )
253
+
254
+ def step_fn_with_tracking(
255
+ carry: tuple[
256
+ LearnerState, StreamStateT, Array, Array | None, Array, Array | None
257
+ ],
258
+ idx: Array,
259
+ ) -> tuple[
260
+ tuple[LearnerState, StreamStateT, Array, Array | None, Array, Array | None],
261
+ Array,
262
+ ]:
263
+ l_state, s_state, ss_history, b_history, rec_indices, norm_history = carry
264
+
265
+ # Perform learning step
266
+ timestep, new_s_state = stream.step(s_state, idx)
267
+ result = learner.update(l_state, timestep.observation, timestep.target)
268
+
269
+ # Check if we should record at this step (idx % interval == 0)
270
+ should_record = (idx % interval) == 0
271
+ recording_idx = idx // interval
272
+
273
+ # Extract current step-sizes
274
+ # Use hasattr checks at trace time (this works because the type is fixed)
275
+ opt_state = result.state.optimizer_state
276
+ if hasattr(opt_state, "log_step_sizes"):
277
+ # IDBD stores log step-sizes
278
+ weight_ss = jnp.exp(opt_state.log_step_sizes) # type: ignore[union-attr]
279
+ bias_ss = opt_state.bias_step_size # type: ignore[union-attr]
280
+ elif hasattr(opt_state, "step_sizes"):
281
+ # Autostep stores step-sizes directly
282
+ weight_ss = opt_state.step_sizes # type: ignore[union-attr]
283
+ bias_ss = opt_state.bias_step_size # type: ignore[union-attr]
284
+ else:
285
+ # LMS has a single fixed step-size
286
+ weight_ss = jnp.full(feature_dim, opt_state.step_size)
287
+ bias_ss = opt_state.step_size
288
+
289
+ # Conditionally update history arrays
290
+ new_ss_history = jax.lax.cond(
291
+ should_record,
292
+ lambda _: ss_history.at[recording_idx].set(weight_ss),
293
+ lambda _: ss_history,
294
+ None,
295
+ )
296
+
297
+ new_b_history = b_history
298
+ if b_history is not None:
299
+ new_b_history = jax.lax.cond(
300
+ should_record,
301
+ lambda _: b_history.at[recording_idx].set(bias_ss),
302
+ lambda _: b_history,
303
+ None,
304
+ )
305
+
306
+ new_rec_indices = jax.lax.cond(
307
+ should_record,
308
+ lambda _: rec_indices.at[recording_idx].set(idx),
309
+ lambda _: rec_indices,
310
+ None,
311
+ )
312
+
313
+ # Track Autostep normalizers (v_i) if applicable
314
+ new_norm_history = norm_history
315
+ if norm_history is not None and hasattr(opt_state, "normalizers"):
316
+ new_norm_history = jax.lax.cond(
317
+ should_record,
318
+ lambda _: norm_history.at[recording_idx].set(
319
+ opt_state.normalizers # type: ignore[union-attr]
320
+ ),
321
+ lambda _: norm_history,
322
+ None,
323
+ )
324
+
325
+ return (
326
+ result.state,
327
+ new_s_state,
328
+ new_ss_history,
329
+ new_b_history,
330
+ new_rec_indices,
331
+ new_norm_history,
332
+ ), result.metrics
333
+
334
+ initial_carry = (
335
+ learner_state,
336
+ stream_state,
337
+ step_size_history,
338
+ bias_history,
339
+ recording_indices,
340
+ normalizer_history,
341
+ )
342
+
343
+ (
344
+ final_learner,
345
+ _,
346
+ final_ss_history,
347
+ final_b_history,
348
+ final_rec_indices,
349
+ final_norm_history,
350
+ ), metrics = jax.lax.scan(
351
+ step_fn_with_tracking, initial_carry, jnp.arange(num_steps)
352
+ )
353
+
354
+ history = StepSizeHistory(
355
+ step_sizes=final_ss_history,
356
+ bias_step_sizes=final_b_history,
357
+ recording_indices=final_rec_indices,
358
+ normalizers=final_norm_history,
359
+ )
360
+
361
+ return final_learner, metrics, history
362
+
363
+
364
+ class NormalizedLearnerState(NamedTuple):
365
+ """State for a learner with online feature normalization.
366
+
367
+ Attributes:
368
+ learner_state: Underlying learner state (weights, bias, optimizer)
369
+ normalizer_state: Online normalizer state (mean, var estimates)
370
+ """
371
+
372
+ learner_state: LearnerState
373
+ normalizer_state: NormalizerState
374
+
375
+
376
+ class NormalizedUpdateResult(NamedTuple):
377
+ """Result of a normalized learner update step.
378
+
379
+ Attributes:
380
+ state: Updated normalized learner state
381
+ prediction: Prediction made before update
382
+ error: Prediction error
383
+ metrics: Array of metrics [squared_error, error, mean_step_size, normalizer_mean_var]
384
+ """
385
+
386
+ state: NormalizedLearnerState
387
+ prediction: Prediction
388
+ error: Array
389
+ metrics: Array
390
+
391
+
392
+ class NormalizedLinearLearner:
393
+ """Linear learner with online feature normalization.
394
+
395
+ Wraps a LinearLearner with online feature normalization, following
396
+ the Alberta Plan's approach to handling varying feature scales.
397
+
398
+ Normalization is applied to features before prediction and learning:
399
+ x_normalized = (x - mean) / (std + epsilon)
400
+
401
+ The normalizer statistics update at every time step, maintaining
402
+ temporal uniformity.
403
+
404
+ Attributes:
405
+ learner: Underlying linear learner
406
+ normalizer: Online feature normalizer
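+
+ Example (illustrative sketch; feature values are arbitrary placeholders):
+
+ ```python
+ import jax.numpy as jnp
+ from alberta_framework import NormalizedLinearLearner, IDBD
+ from alberta_framework.core.normalizers import OnlineNormalizer
+
+ learner = NormalizedLinearLearner(optimizer=IDBD(), normalizer=OnlineNormalizer())
+ state = learner.init(feature_dim=3)
+
+ x = jnp.array([100.0, 0.01, -5.0])  # features on very different scales
+ result = learner.update(state, x, jnp.array(1.0))
+ state = result.state  # weights, bias, and normalizer statistics all advanced
+ ```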
407
+ """
408
+
409
+ def __init__(
410
+ self,
411
+ optimizer: AnyOptimizer | None = None,
412
+ normalizer: OnlineNormalizer | None = None,
413
+ ):
414
+ """Initialize the normalized linear learner.
415
+
416
+ Args:
417
+ optimizer: Optimizer for weight updates. Defaults to LMS(0.01)
418
+ normalizer: Feature normalizer. Defaults to OnlineNormalizer()
419
+ """
420
+ self._learner = LinearLearner(optimizer=optimizer or LMS(step_size=0.01))
421
+ self._normalizer = normalizer or OnlineNormalizer()
422
+
423
+ def init(self, feature_dim: int) -> NormalizedLearnerState:
424
+ """Initialize normalized learner state.
425
+
426
+ Args:
427
+ feature_dim: Dimension of the input feature vector
428
+
429
+ Returns:
430
+ Initial state with zero weights and unit variance estimates
431
+ """
432
+ return NormalizedLearnerState(
433
+ learner_state=self._learner.init(feature_dim),
434
+ normalizer_state=self._normalizer.init(feature_dim),
435
+ )
436
+
437
+ def predict(
438
+ self,
439
+ state: NormalizedLearnerState,
440
+ observation: Observation,
441
+ ) -> Prediction:
442
+ """Compute prediction for an observation.
443
+
444
+ Normalizes the observation using current statistics before prediction.
445
+
446
+ Args:
447
+ state: Current normalized learner state
448
+ observation: Raw (unnormalized) input feature vector
449
+
450
+ Returns:
451
+ Scalar prediction y = w @ normalize(x) + b
452
+ """
453
+ normalized_obs = self._normalizer.normalize_only(
454
+ state.normalizer_state, observation
455
+ )
456
+ return self._learner.predict(state.learner_state, normalized_obs)
457
+
458
+ def update(
459
+ self,
460
+ state: NormalizedLearnerState,
461
+ observation: Observation,
462
+ target: Target,
463
+ ) -> NormalizedUpdateResult:
464
+ """Update learner given observation and target.
465
+
466
+ Performs one step of the learning algorithm:
467
+ 1. Normalize observation (and update normalizer statistics)
468
+ 2. Compute prediction using normalized features
469
+ 3. Compute error
470
+ 4. Get weight updates from optimizer
471
+ 5. Apply updates
472
+
473
+ Args:
474
+ state: Current normalized learner state
475
+ observation: Raw (unnormalized) input feature vector
476
+ target: Desired output
477
+
478
+ Returns:
479
+ NormalizedUpdateResult with new state, prediction, error, and metrics
480
+ """
481
+ # Normalize observation and update normalizer state
482
+ normalized_obs, new_normalizer_state = self._normalizer.normalize(
483
+ state.normalizer_state, observation
484
+ )
485
+
486
+ # Delegate to underlying learner
487
+ result = self._learner.update(
488
+ state.learner_state,
489
+ normalized_obs,
490
+ target,
491
+ )
492
+
493
+ # Build combined state
494
+ new_state = NormalizedLearnerState(
495
+ learner_state=result.state,
496
+ normalizer_state=new_normalizer_state,
497
+ )
498
+
499
+ # Add normalizer metrics to the metrics array
500
+ normalizer_mean_var = jnp.mean(new_normalizer_state.var)
501
+ metrics = jnp.concatenate([result.metrics, jnp.array([normalizer_mean_var])])
502
+
503
+ return NormalizedUpdateResult(
504
+ state=new_state,
505
+ prediction=result.prediction,
506
+ error=result.error,
507
+ metrics=metrics,
508
+ )
509
+
510
+
511
+ def run_normalized_learning_loop[StreamStateT](
512
+ learner: NormalizedLinearLearner,
513
+ stream: ScanStream[StreamStateT],
514
+ num_steps: int,
515
+ key: Array,
516
+ learner_state: NormalizedLearnerState | None = None,
517
+ step_size_tracking: StepSizeTrackingConfig | None = None,
518
+ normalizer_tracking: NormalizerTrackingConfig | None = None,
519
+ ) -> (
520
+ tuple[NormalizedLearnerState, Array]
521
+ | tuple[NormalizedLearnerState, Array, StepSizeHistory]
522
+ | tuple[NormalizedLearnerState, Array, NormalizerHistory]
523
+ | tuple[NormalizedLearnerState, Array, StepSizeHistory, NormalizerHistory]
524
+ ):
525
+ """Run the learning loop with normalization using jax.lax.scan.
526
+
527
+ Args:
528
+ learner: The normalized learner to train
529
+ stream: Experience stream providing (observation, target) pairs
530
+ num_steps: Number of learning steps to run
531
+ key: JAX random key for stream initialization
532
+ learner_state: Initial state (if None, will be initialized from stream)
533
+ step_size_tracking: Optional config for recording per-weight step-sizes.
534
+ When provided, returns StepSizeHistory including Autostep normalizers if applicable.
535
+ normalizer_tracking: Optional config for recording per-feature normalizer state.
536
+ When provided, returns NormalizerHistory with means and variances over time.
537
+
538
+ Returns:
539
+ If no tracking:
540
+ Tuple of (final_state, metrics_array) where metrics_array has shape
541
+ (num_steps, 4) with columns [squared_error, error, mean_step_size, normalizer_mean_var]
542
+ If step_size_tracking only:
543
+ Tuple of (final_state, metrics_array, step_size_history)
544
+ If normalizer_tracking only:
545
+ Tuple of (final_state, metrics_array, normalizer_history)
546
+ If both:
547
+ Tuple of (final_state, metrics_array, step_size_history, normalizer_history)
548
+
549
+ Raises:
550
+ ValueError: If tracking interval is invalid
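+
+ Example (sketch with both kinds of tracking; the config constructor arguments
+ are assumptions about those types):
+
+ ```python
+ import jax.random as jr
+ from alberta_framework import NormalizedLinearLearner, IDBD, RandomWalkStream
+ from alberta_framework.core.types import (
+ NormalizerTrackingConfig,
+ StepSizeTrackingConfig,
+ )
+
+ stream = RandomWalkStream(feature_dim=10)
+ learner = NormalizedLinearLearner(optimizer=IDBD())
+
+ state, metrics, ss_history, norm_history = run_normalized_learning_loop(
+ learner,
+ stream,
+ num_steps=10000,
+ key=jr.key(0),
+ step_size_tracking=StepSizeTrackingConfig(interval=100),
+ normalizer_tracking=NormalizerTrackingConfig(interval=100),
+ )
+ # norm_history.means and norm_history.variances both have shape (100, 10)
+ ```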
551
+ """
552
+ # Validate tracking configs
553
+ if step_size_tracking is not None:
554
+ if step_size_tracking.interval < 1:
555
+ raise ValueError(
556
+ f"step_size_tracking.interval must be >= 1, got {step_size_tracking.interval}"
557
+ )
558
+ if step_size_tracking.interval > num_steps:
559
+ raise ValueError(
560
+ f"step_size_tracking.interval ({step_size_tracking.interval}) "
561
+ f"must be <= num_steps ({num_steps})"
562
+ )
563
+
564
+ if normalizer_tracking is not None:
565
+ if normalizer_tracking.interval < 1:
566
+ raise ValueError(
567
+ f"normalizer_tracking.interval must be >= 1, got {normalizer_tracking.interval}"
568
+ )
569
+ if normalizer_tracking.interval > num_steps:
570
+ raise ValueError(
571
+ f"normalizer_tracking.interval ({normalizer_tracking.interval}) "
572
+ f"must be <= num_steps ({num_steps})"
573
+ )
574
+
575
+ # Initialize states
576
+ if learner_state is None:
577
+ learner_state = learner.init(stream.feature_dim)
578
+ stream_state = stream.init(key)
579
+
580
+ feature_dim = stream.feature_dim
581
+
582
+ # No tracking - simple case
583
+ if step_size_tracking is None and normalizer_tracking is None:
584
+
585
+ def step_fn(
586
+ carry: tuple[NormalizedLearnerState, StreamStateT], idx: Array
587
+ ) -> tuple[tuple[NormalizedLearnerState, StreamStateT], Array]:
588
+ l_state, s_state = carry
589
+ timestep, new_s_state = stream.step(s_state, idx)
590
+ result = learner.update(l_state, timestep.observation, timestep.target)
591
+ return (result.state, new_s_state), result.metrics
592
+
593
+ (final_learner, _), metrics = jax.lax.scan(
594
+ step_fn, (learner_state, stream_state), jnp.arange(num_steps)
595
+ )
596
+
597
+ return final_learner, metrics
598
+
599
+ # Tracking enabled - need to set up history arrays
600
+ ss_interval = step_size_tracking.interval if step_size_tracking else num_steps + 1
601
+ norm_interval = (
602
+ normalizer_tracking.interval if normalizer_tracking else num_steps + 1
603
+ )
604
+
605
+ ss_num_recordings = num_steps // ss_interval if step_size_tracking else 0
606
+ norm_num_recordings = num_steps // norm_interval if normalizer_tracking else 0
607
+
608
+ # Pre-allocate step-size history arrays
609
+ ss_history = (
610
+ jnp.zeros((ss_num_recordings, feature_dim), dtype=jnp.float32)
611
+ if step_size_tracking
612
+ else None
613
+ )
614
+ ss_bias_history = (
615
+ jnp.zeros(ss_num_recordings, dtype=jnp.float32)
616
+ if step_size_tracking and step_size_tracking.include_bias
617
+ else None
618
+ )
619
+ ss_rec_indices = (
620
+ jnp.zeros(ss_num_recordings, dtype=jnp.int32) if step_size_tracking else None
621
+ )
622
+
623
+ # Check if we need to track Autostep normalizers
624
+ track_autostep_normalizers = hasattr(
625
+ learner_state.learner_state.optimizer_state, "normalizers"
626
+ )
627
+ ss_normalizers = (
628
+ jnp.zeros((ss_num_recordings, feature_dim), dtype=jnp.float32)
629
+ if step_size_tracking and track_autostep_normalizers
630
+ else None
631
+ )
632
+
633
+ # Pre-allocate normalizer state history arrays
634
+ norm_means = (
635
+ jnp.zeros((norm_num_recordings, feature_dim), dtype=jnp.float32)
636
+ if normalizer_tracking
637
+ else None
638
+ )
639
+ norm_vars = (
640
+ jnp.zeros((norm_num_recordings, feature_dim), dtype=jnp.float32)
641
+ if normalizer_tracking
642
+ else None
643
+ )
644
+ norm_rec_indices = (
645
+ jnp.zeros(norm_num_recordings, dtype=jnp.int32) if normalizer_tracking else None
646
+ )
647
+
648
+ def step_fn_with_tracking(
649
+ carry: tuple[
650
+ NormalizedLearnerState,
651
+ StreamStateT,
652
+ Array | None,
653
+ Array | None,
654
+ Array | None,
655
+ Array | None,
656
+ Array | None,
657
+ Array | None,
658
+ Array | None,
659
+ ],
660
+ idx: Array,
661
+ ) -> tuple[
662
+ tuple[
663
+ NormalizedLearnerState,
664
+ StreamStateT,
665
+ Array | None,
666
+ Array | None,
667
+ Array | None,
668
+ Array | None,
669
+ Array | None,
670
+ Array | None,
671
+ Array | None,
672
+ ],
673
+ Array,
674
+ ]:
675
+ (
676
+ l_state,
677
+ s_state,
678
+ ss_hist,
679
+ ss_bias_hist,
680
+ ss_rec,
681
+ ss_norm,
682
+ n_means,
683
+ n_vars,
684
+ n_rec,
685
+ ) = carry
686
+
687
+ # Perform learning step
688
+ timestep, new_s_state = stream.step(s_state, idx)
689
+ result = learner.update(l_state, timestep.observation, timestep.target)
690
+
691
+ # Step-size tracking
692
+ new_ss_hist = ss_hist
693
+ new_ss_bias_hist = ss_bias_hist
694
+ new_ss_rec = ss_rec
695
+ new_ss_norm = ss_norm
696
+
697
+ if ss_hist is not None:
698
+ should_record_ss = (idx % ss_interval) == 0
699
+ recording_idx = idx // ss_interval
700
+
701
+ # Extract current step-sizes from the inner learner state
702
+ opt_state = result.state.learner_state.optimizer_state
703
+ if hasattr(opt_state, "log_step_sizes"):
704
+ # IDBD stores log step-sizes
705
+ weight_ss = jnp.exp(opt_state.log_step_sizes) # type: ignore[union-attr]
706
+ bias_ss = opt_state.bias_step_size # type: ignore[union-attr]
707
+ elif hasattr(opt_state, "step_sizes"):
708
+ # Autostep stores step-sizes directly
709
+ weight_ss = opt_state.step_sizes # type: ignore[union-attr]
710
+ bias_ss = opt_state.bias_step_size # type: ignore[union-attr]
711
+ else:
712
+ # LMS has a single fixed step-size
713
+ weight_ss = jnp.full(feature_dim, opt_state.step_size)
714
+ bias_ss = opt_state.step_size
715
+
716
+ new_ss_hist = jax.lax.cond(
717
+ should_record_ss,
718
+ lambda _: ss_hist.at[recording_idx].set(weight_ss),
719
+ lambda _: ss_hist,
720
+ None,
721
+ )
722
+
723
+ if ss_bias_hist is not None:
724
+ new_ss_bias_hist = jax.lax.cond(
725
+ should_record_ss,
726
+ lambda _: ss_bias_hist.at[recording_idx].set(bias_ss),
727
+ lambda _: ss_bias_hist,
728
+ None,
729
+ )
730
+
731
+ if ss_rec is not None:
732
+ new_ss_rec = jax.lax.cond(
733
+ should_record_ss,
734
+ lambda _: ss_rec.at[recording_idx].set(idx),
735
+ lambda _: ss_rec,
736
+ None,
737
+ )
738
+
739
+ # Track Autostep normalizers (v_i) if applicable
740
+ if ss_norm is not None and hasattr(opt_state, "normalizers"):
741
+ new_ss_norm = jax.lax.cond(
742
+ should_record_ss,
743
+ lambda _: ss_norm.at[recording_idx].set(
744
+ opt_state.normalizers # type: ignore[union-attr]
745
+ ),
746
+ lambda _: ss_norm,
747
+ None,
748
+ )
749
+
750
+ # Normalizer state tracking
751
+ new_n_means = n_means
752
+ new_n_vars = n_vars
753
+ new_n_rec = n_rec
754
+
755
+ if n_means is not None:
756
+ should_record_norm = (idx % norm_interval) == 0
757
+ norm_recording_idx = idx // norm_interval
758
+
759
+ norm_state = result.state.normalizer_state
760
+
761
+ new_n_means = jax.lax.cond(
762
+ should_record_norm,
763
+ lambda _: n_means.at[norm_recording_idx].set(norm_state.mean),
764
+ lambda _: n_means,
765
+ None,
766
+ )
767
+
768
+ if n_vars is not None:
769
+ new_n_vars = jax.lax.cond(
770
+ should_record_norm,
771
+ lambda _: n_vars.at[norm_recording_idx].set(norm_state.var),
772
+ lambda _: n_vars,
773
+ None,
774
+ )
775
+
776
+ if n_rec is not None:
777
+ new_n_rec = jax.lax.cond(
778
+ should_record_norm,
779
+ lambda _: n_rec.at[norm_recording_idx].set(idx),
780
+ lambda _: n_rec,
781
+ None,
782
+ )
783
+
784
+ return (
785
+ result.state,
786
+ new_s_state,
787
+ new_ss_hist,
788
+ new_ss_bias_hist,
789
+ new_ss_rec,
790
+ new_ss_norm,
791
+ new_n_means,
792
+ new_n_vars,
793
+ new_n_rec,
794
+ ), result.metrics
795
+
796
+ initial_carry = (
797
+ learner_state,
798
+ stream_state,
799
+ ss_history,
800
+ ss_bias_history,
801
+ ss_rec_indices,
802
+ ss_normalizers,
803
+ norm_means,
804
+ norm_vars,
805
+ norm_rec_indices,
806
+ )
807
+
808
+ (
809
+ final_learner,
810
+ _,
811
+ final_ss_hist,
812
+ final_ss_bias_hist,
813
+ final_ss_rec,
814
+ final_ss_norm,
815
+ final_n_means,
816
+ final_n_vars,
817
+ final_n_rec,
818
+ ), metrics = jax.lax.scan(
819
+ step_fn_with_tracking, initial_carry, jnp.arange(num_steps)
820
+ )
821
+
822
+ # Build return values based on what was tracked
823
+ ss_history_result = None
824
+ if step_size_tracking is not None and final_ss_hist is not None:
825
+ ss_history_result = StepSizeHistory(
826
+ step_sizes=final_ss_hist,
827
+ bias_step_sizes=final_ss_bias_hist,
828
+ recording_indices=final_ss_rec, # type: ignore[arg-type]
829
+ normalizers=final_ss_norm,
830
+ )
831
+
832
+ norm_history_result = None
833
+ if normalizer_tracking is not None and final_n_means is not None:
834
+ norm_history_result = NormalizerHistory(
835
+ means=final_n_means,
836
+ variances=final_n_vars, # type: ignore[arg-type]
837
+ recording_indices=final_n_rec, # type: ignore[arg-type]
838
+ )
839
+
840
+ # Return appropriate tuple based on what was tracked
841
+ if ss_history_result is not None and norm_history_result is not None:
842
+ return final_learner, metrics, ss_history_result, norm_history_result
843
+ elif ss_history_result is not None:
844
+ return final_learner, metrics, ss_history_result
845
+ elif norm_history_result is not None:
846
+ return final_learner, metrics, norm_history_result
847
+ else:
848
+ return final_learner, metrics
849
+
850
+
851
+ def run_learning_loop_batched[StreamStateT](
852
+ learner: LinearLearner,
853
+ stream: ScanStream[StreamStateT],
854
+ num_steps: int,
855
+ keys: Array,
856
+ learner_state: LearnerState | None = None,
857
+ step_size_tracking: StepSizeTrackingConfig | None = None,
858
+ ) -> BatchedLearningResult:
859
+ """Run learning loop across multiple seeds in parallel using jax.vmap.
860
+
861
+ This function provides GPU parallelization for multi-seed experiments,
862
+ typically achieving 2-5x speedup over sequential execution.
863
+
864
+ Args:
865
+ learner: The learner to train
866
+ stream: Experience stream providing (observation, target) pairs
867
+ num_steps: Number of learning steps to run per seed
868
+ keys: JAX random keys with shape (num_seeds,) or (num_seeds, 2)
869
+ learner_state: Initial state (if None, will be initialized from stream).
870
+ The same initial state is used for all seeds.
871
+ step_size_tracking: Optional config for recording per-weight step-sizes.
872
+ When provided, history arrays have shape (num_seeds, num_recordings, ...)
873
+
874
+ Returns:
875
+ BatchedLearningResult containing:
876
+ - states: Batched final states with shape (num_seeds, ...) for each array
877
+ - metrics: Array of shape (num_seeds, num_steps, 3)
878
+ - step_size_history: Batched history or None if tracking disabled
879
+
880
+ Examples:
881
+ ```python
882
+ import jax.random as jr
883
+ from alberta_framework import LinearLearner, IDBD, RandomWalkStream
884
+ from alberta_framework import run_learning_loop_batched
885
+
886
+ stream = RandomWalkStream(feature_dim=10)
887
+ learner = LinearLearner(optimizer=IDBD())
888
+
889
+ # Run 30 seeds in parallel
890
+ keys = jr.split(jr.key(42), 30)
891
+ result = run_learning_loop_batched(learner, stream, num_steps=10000, keys=keys)
892
+
893
+ # result.metrics has shape (30, 10000, 3)
894
+ mean_squared_error = result.metrics[:, :, 0].mean(axis=0)  # average squared error over seeds
895
+ ```
896
+ """
897
+ # Define single-seed function that returns consistent structure
898
+ def single_seed_run(key: Array) -> tuple[LearnerState, Array, StepSizeHistory | None]:
899
+ result = run_learning_loop(
900
+ learner, stream, num_steps, key, learner_state, step_size_tracking
901
+ )
902
+ if step_size_tracking is not None:
903
+ state, metrics, history = cast(
904
+ tuple[LearnerState, Array, StepSizeHistory], result
905
+ )
906
+ return state, metrics, history
907
+ else:
908
+ state, metrics = cast(tuple[LearnerState, Array], result)
909
+ # Return None for history to maintain consistent output structure
910
+ return state, metrics, None
911
+
912
+ # vmap over the keys dimension
913
+ batched_states, batched_metrics, batched_history = jax.vmap(single_seed_run)(keys)
914
+
915
+ # Reconstruct batched history if tracking was enabled
916
+ if step_size_tracking is not None and batched_history is not None:
917
+ batched_step_size_history = StepSizeHistory(
918
+ step_sizes=batched_history.step_sizes,
919
+ bias_step_sizes=batched_history.bias_step_sizes,
920
+ recording_indices=batched_history.recording_indices,
921
+ normalizers=batched_history.normalizers,
922
+ )
923
+ else:
924
+ batched_step_size_history = None
925
+
926
+ return BatchedLearningResult(
927
+ states=batched_states,
928
+ metrics=batched_metrics,
929
+ step_size_history=batched_step_size_history,
930
+ )
931
+
932
+
933
+ def run_normalized_learning_loop_batched[StreamStateT](
934
+ learner: NormalizedLinearLearner,
935
+ stream: ScanStream[StreamStateT],
936
+ num_steps: int,
937
+ keys: Array,
938
+ learner_state: NormalizedLearnerState | None = None,
939
+ step_size_tracking: StepSizeTrackingConfig | None = None,
940
+ normalizer_tracking: NormalizerTrackingConfig | None = None,
941
+ ) -> BatchedNormalizedResult:
942
+ """Run normalized learning loop across multiple seeds in parallel using jax.vmap.
943
+
944
+ This function provides GPU parallelization for multi-seed experiments with
945
+ normalized learners, typically achieving 2-5x speedup over sequential execution.
946
+
947
+ Args:
948
+ learner: The normalized learner to train
949
+ stream: Experience stream providing (observation, target) pairs
950
+ num_steps: Number of learning steps to run per seed
951
+ keys: JAX random keys with shape (num_seeds,) or (num_seeds, 2)
952
+ learner_state: Initial state (if None, will be initialized from stream).
953
+ The same initial state is used for all seeds.
954
+ step_size_tracking: Optional config for recording per-weight step-sizes.
955
+ When provided, history arrays have shape (num_seeds, num_recordings, ...)
956
+ normalizer_tracking: Optional config for recording normalizer state.
957
+ When provided, history arrays have shape (num_seeds, num_recordings, ...)
958
+
959
+ Returns:
960
+ BatchedNormalizedResult containing:
961
+ - states: Batched final states with shape (num_seeds, ...) for each array
962
+ - metrics: Array of shape (num_seeds, num_steps, 4)
963
+ - step_size_history: Batched history or None if tracking disabled
964
+ - normalizer_history: Batched history or None if tracking disabled
965
+
966
+ Examples:
967
+ ```python
968
+ import jax.random as jr
969
+ from alberta_framework import NormalizedLinearLearner, IDBD, RandomWalkStream
970
+ from alberta_framework import run_normalized_learning_loop_batched
971
+
972
+ stream = RandomWalkStream(feature_dim=10)
973
+ learner = NormalizedLinearLearner(optimizer=IDBD())
974
+
975
+ # Run 30 seeds in parallel
976
+ keys = jr.split(jr.key(42), 30)
977
+ result = run_normalized_learning_loop_batched(
978
+ learner, stream, num_steps=10000, keys=keys
979
+ )
980
+
981
+ # result.metrics has shape (30, 10000, 4)
982
+ mean_squared_error = result.metrics[:, :, 0].mean(axis=0)  # average squared error over seeds
983
+ ```
984
+ """
985
+ # Define single-seed function that returns consistent structure
986
+ def single_seed_run(
987
+ key: Array,
988
+ ) -> tuple[
989
+ NormalizedLearnerState, Array, StepSizeHistory | None, NormalizerHistory | None
990
+ ]:
991
+ result = run_normalized_learning_loop(
992
+ learner, stream, num_steps, key, learner_state,
993
+ step_size_tracking, normalizer_tracking
994
+ )
995
+
996
+ # Unpack based on what tracking was enabled
997
+ if step_size_tracking is not None and normalizer_tracking is not None:
998
+ state, metrics, ss_history, norm_history = cast(
999
+ tuple[NormalizedLearnerState, Array, StepSizeHistory, NormalizerHistory],
1000
+ result,
1001
+ )
1002
+ return state, metrics, ss_history, norm_history
1003
+ elif step_size_tracking is not None:
1004
+ state, metrics, ss_history = cast(
1005
+ tuple[NormalizedLearnerState, Array, StepSizeHistory], result
1006
+ )
1007
+ return state, metrics, ss_history, None
1008
+ elif normalizer_tracking is not None:
1009
+ state, metrics, norm_history = cast(
1010
+ tuple[NormalizedLearnerState, Array, NormalizerHistory], result
1011
+ )
1012
+ return state, metrics, None, norm_history
1013
+ else:
1014
+ state, metrics = cast(tuple[NormalizedLearnerState, Array], result)
1015
+ return state, metrics, None, None
1016
+
1017
+ # vmap over the keys dimension
1018
+ batched_states, batched_metrics, batched_ss_history, batched_norm_history = (
1019
+ jax.vmap(single_seed_run)(keys)
1020
+ )
1021
+
1022
+ # Reconstruct batched histories if tracking was enabled
1023
+ if step_size_tracking is not None and batched_ss_history is not None:
1024
+ batched_step_size_history = StepSizeHistory(
1025
+ step_sizes=batched_ss_history.step_sizes,
1026
+ bias_step_sizes=batched_ss_history.bias_step_sizes,
1027
+ recording_indices=batched_ss_history.recording_indices,
1028
+ normalizers=batched_ss_history.normalizers,
1029
+ )
1030
+ else:
1031
+ batched_step_size_history = None
1032
+
1033
+ if normalizer_tracking is not None and batched_norm_history is not None:
1034
+ batched_normalizer_history = NormalizerHistory(
1035
+ means=batched_norm_history.means,
1036
+ variances=batched_norm_history.variances,
1037
+ recording_indices=batched_norm_history.recording_indices,
1038
+ )
1039
+ else:
1040
+ batched_normalizer_history = None
1041
+
1042
+ return BatchedNormalizedResult(
1043
+ states=batched_states,
1044
+ metrics=batched_metrics,
1045
+ step_size_history=batched_step_size_history,
1046
+ normalizer_history=batched_normalizer_history,
1047
+ )
1048
+
1049
+
1050
+ def metrics_to_dicts(metrics: Array, normalized: bool = False) -> list[dict[str, float]]:
1051
+ """Convert metrics array to list of dicts for backward compatibility.
1052
+
1053
+ Args:
1054
+ metrics: Array of shape (num_steps, 3) or (num_steps, 4)
1055
+ normalized: If True, expects 4 columns including normalizer_mean_var
1056
+
1057
+ Returns:
1058
+ List of metric dictionaries
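+
+ Example (sketch with a hand-built metrics array):
+
+ ```python
+ import jax.numpy as jnp
+
+ metrics = jnp.array([[1.0, -1.0, 0.01], [0.25, 0.5, 0.01]])  # (num_steps, 3)
+ records = metrics_to_dicts(metrics)
+ # each row becomes {"squared_error": ..., "error": ..., "mean_step_size": ...}
+ ```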
1059
+ """
1060
+ result = []
1061
+ for row in metrics:
1062
+ d = {
1063
+ "squared_error": float(row[0]),
1064
+ "error": float(row[1]),
1065
+ "mean_step_size": float(row[2]),
1066
+ }
1067
+ if normalized and len(row) > 3:
1068
+ d["normalizer_mean_var"] = float(row[3])
1069
+ result.append(d)
1070
+ return result