alberta-framework 0.1.0__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- alberta_framework/__init__.py +52 -23
- alberta_framework/core/learners.py +549 -18
- alberta_framework/core/normalizers.py +1 -1
- alberta_framework/core/optimizers.py +14 -12
- alberta_framework/core/types.py +70 -0
- alberta_framework/streams/base.py +8 -5
- alberta_framework/streams/synthetic.py +16 -10
- alberta_framework/utils/experiments.py +4 -3
- alberta_framework/utils/timing.py +42 -36
- {alberta_framework-0.1.0.dist-info → alberta_framework-0.2.0.dist-info}/METADATA +10 -2
- alberta_framework-0.2.0.dist-info/RECORD +22 -0
- alberta_framework-0.1.0.dist-info/RECORD +0 -22
- {alberta_framework-0.1.0.dist-info → alberta_framework-0.2.0.dist-info}/WHEEL +0 -0
- {alberta_framework-0.1.0.dist-info → alberta_framework-0.2.0.dist-info}/licenses/LICENSE +0 -0
alberta_framework/__init__.py
CHANGED
@@ -1,28 +1,45 @@
-"""Alberta Framework:
-learning
+"""Alberta Framework: A JAX-based research framework for continual AI.
+
+The Alberta Framework provides foundational components for continual reinforcement
+learning research. Built on JAX for hardware acceleration, the framework emphasizes
+temporal uniformity — every component updates at every time step, with no special
+training phases or batch processing.
+
+Roadmap
+-------
+| Step | Focus | Status |
+|------|-------|--------|
+| 1 | Meta-learned step-sizes (IDBD, Autostep) | **Complete** |
+| 2 | Feature generation and testing | Planned |
+| 3 | GVF predictions, Horde architecture | Planned |
+| 4 | Actor-critic with eligibility traces | Planned |
+| 5-6 | Off-policy learning, average reward | Planned |
+| 7-12 | Hierarchical, multi-agent, world models | Future |
+
+Examples
+--------
+```python
+import jax.random as jr
+from alberta_framework import LinearLearner, IDBD, RandomWalkStream, run_learning_loop
+
+# Non-stationary stream where target weights drift over time
+stream = RandomWalkStream(feature_dim=10, drift_rate=0.001)
+
+# Learner with IDBD meta-learned step-sizes
+learner = LinearLearner(optimizer=IDBD())
+
+# JIT-compiled training via jax.lax.scan
+state, metrics = run_learning_loop(learner, stream, num_steps=10000, key=jr.key(42))
+```
+
+References
+----------
+- The Alberta Plan for AI Research (Sutton et al., 2022): https://arxiv.org/abs/2208.11173
+- Adapting Bias by Gradient Descent (Sutton, 1992)
+- Tuning-free Step-size Adaptation (Mahmood et al., 2012)
 """
 
-__version__ = "0.
+__version__ = "0.2.0"
 
 # Core types
 # Learners
@@ -33,7 +50,9 @@ from alberta_framework.core.learners import (
     UpdateResult,
     metrics_to_dicts,
     run_learning_loop,
+    run_learning_loop_batched,
     run_normalized_learning_loop,
+    run_normalized_learning_loop_batched,
 )
 
 # Normalizers
@@ -47,9 +66,13 @@ from alberta_framework.core.normalizers import (
 from alberta_framework.core.optimizers import IDBD, LMS, Autostep, Optimizer
 from alberta_framework.core.types import (
     AutostepState,
+    BatchedLearningResult,
+    BatchedNormalizedResult,
     IDBDState,
     LearnerState,
     LMSState,
+    NormalizerHistory,
+    NormalizerTrackingConfig,
     Observation,
     Prediction,
     StepSizeHistory,
@@ -119,10 +142,14 @@ __all__ = [
     "__version__",
     # Types
     "AutostepState",
+    "BatchedLearningResult",
+    "BatchedNormalizedResult",
     "IDBDState",
     "LMSState",
     "LearnerState",
+    "NormalizerHistory",
     "NormalizerState",
+    "NormalizerTrackingConfig",
     "Observation",
     "Prediction",
     "StepSizeHistory",
@@ -143,7 +170,9 @@ __all__ = [
     "NormalizedLearnerState",
     "NormalizedLinearLearner",
     "run_learning_loop",
+    "run_learning_loop_batched",
     "run_normalized_learning_loop",
+    "run_normalized_learning_loop_batched",
     "metrics_to_dicts",
     # Streams - protocol
     "ScanStream",
alberta_framework/core/learners.py
CHANGED

@@ -15,9 +15,13 @@ from alberta_framework.core.normalizers import NormalizerState, OnlineNormalizer
 from alberta_framework.core.optimizers import LMS, Optimizer
 from alberta_framework.core.types import (
     AutostepState,
+    BatchedLearningResult,
+    BatchedNormalizedResult,
     IDBDState,
     LearnerState,
     LMSState,
+    NormalizerHistory,
+    NormalizerTrackingConfig,
     Observation,
     Prediction,
     StepSizeHistory,
@@ -48,7 +52,7 @@ class UpdateResult(NamedTuple):
 class LinearLearner:
     """Linear function approximator with pluggable optimizer.
 
-    Computes predictions as: y = w @ x + b
+    Computes predictions as: `y = w @ x + b`
 
     The learner maintains weights and bias, delegating the adaptation
     of learning rates to the optimizer (e.g., LMS or IDBD).
@@ -93,7 +97,7 @@ class LinearLearner:
             observation: Input feature vector
 
         Returns:
-            Scalar prediction y = w @ x + b
+            Scalar prediction `y = w @ x + b`
         """
         return jnp.atleast_1d(jnp.dot(state.weights, observation) + state.bias)
 
@@ -238,10 +242,25 @@ def run_learning_loop[StreamStateT](
     )
     recording_indices = jnp.zeros(num_recordings, dtype=jnp.int32)
 
+    # Check if we need to track Autostep normalizers
+    # We detect this at trace time by checking the initial optimizer state
+    track_normalizers = hasattr(learner_state.optimizer_state, "normalizers")
+    normalizer_history = (
+        jnp.zeros((num_recordings, feature_dim), dtype=jnp.float32)
+        if track_normalizers
+        else None
+    )
+
     def step_fn_with_tracking(
-        carry: tuple[
+        carry: tuple[
+            LearnerState, StreamStateT, Array, Array | None, Array, Array | None
+        ],
+        idx: Array,
+    ) -> tuple[
+        tuple[LearnerState, StreamStateT, Array, Array | None, Array, Array | None],
+        Array,
+    ]:
+        l_state, s_state, ss_history, b_history, rec_indices, norm_history = carry
 
         # Perform learning step
         timestep, new_s_state = stream.step(s_state, idx)
@@ -291,12 +310,25 @@ def run_learning_loop[StreamStateT](
             None,
         )
 
+        # Track Autostep normalizers (v_i) if applicable
+        new_norm_history = norm_history
+        if norm_history is not None and hasattr(opt_state, "normalizers"):
+            new_norm_history = jax.lax.cond(
+                should_record,
+                lambda _: norm_history.at[recording_idx].set(
+                    opt_state.normalizers  # type: ignore[union-attr]
+                ),
+                lambda _: norm_history,
+                None,
+            )
+
         return (
             result.state,
             new_s_state,
             new_ss_history,
             new_b_history,
             new_rec_indices,
+            new_norm_history,
         ), result.metrics
 
     initial_carry = (
@@ -305,16 +337,25 @@ def run_learning_loop[StreamStateT](
         step_size_history,
         bias_history,
         recording_indices,
+        normalizer_history,
     )
 
-    (
+    (
+        final_learner,
+        _,
+        final_ss_history,
+        final_b_history,
+        final_rec_indices,
+        final_norm_history,
+    ), metrics = jax.lax.scan(
+        step_fn_with_tracking, initial_carry, jnp.arange(num_steps)
    )
 
    history = StepSizeHistory(
        step_sizes=final_ss_history,
        bias_step_sizes=final_b_history,
        recording_indices=final_rec_indices,
+        normalizers=final_norm_history,
    )
 
    return final_learner, metrics, history
@@ -473,7 +514,14 @@ def run_normalized_learning_loop[StreamStateT](
     num_steps: int,
     key: Array,
     learner_state: NormalizedLearnerState | None = None,
+    step_size_tracking: StepSizeTrackingConfig | None = None,
+    normalizer_tracking: NormalizerTrackingConfig | None = None,
+) -> (
+    tuple[NormalizedLearnerState, Array]
+    | tuple[NormalizedLearnerState, Array, StepSizeHistory]
+    | tuple[NormalizedLearnerState, Array, NormalizerHistory]
+    | tuple[NormalizedLearnerState, Array, StepSizeHistory, NormalizerHistory]
+):
     """Run the learning loop with normalization using jax.lax.scan.
 
     Args:
@@ -482,29 +530,512 @@ def run_normalized_learning_loop[StreamStateT](
         num_steps: Number of learning steps to run
         key: JAX random key for stream initialization
         learner_state: Initial state (if None, will be initialized from stream)
+        step_size_tracking: Optional config for recording per-weight step-sizes.
+            When provided, returns StepSizeHistory including Autostep normalizers if applicable.
+        normalizer_tracking: Optional config for recording per-feature normalizer state.
+            When provided, returns NormalizerHistory with means and variances over time.
 
     Returns:
+        If no tracking:
+            Tuple of (final_state, metrics_array) where metrics_array has shape
+            (num_steps, 4) with columns [squared_error, error, mean_step_size, normalizer_mean_var]
+        If step_size_tracking only:
+            Tuple of (final_state, metrics_array, step_size_history)
+        If normalizer_tracking only:
+            Tuple of (final_state, metrics_array, normalizer_history)
+        If both:
+            Tuple of (final_state, metrics_array, step_size_history, normalizer_history)
+
+    Raises:
+        ValueError: If tracking interval is invalid
     """
+    # Validate tracking configs
+    if step_size_tracking is not None:
+        if step_size_tracking.interval < 1:
+            raise ValueError(
+                f"step_size_tracking.interval must be >= 1, got {step_size_tracking.interval}"
+            )
+        if step_size_tracking.interval > num_steps:
+            raise ValueError(
+                f"step_size_tracking.interval ({step_size_tracking.interval}) "
+                f"must be <= num_steps ({num_steps})"
+            )
+
+    if normalizer_tracking is not None:
+        if normalizer_tracking.interval < 1:
+            raise ValueError(
+                f"normalizer_tracking.interval must be >= 1, got {normalizer_tracking.interval}"
+            )
+        if normalizer_tracking.interval > num_steps:
+            raise ValueError(
+                f"normalizer_tracking.interval ({normalizer_tracking.interval}) "
+                f"must be <= num_steps ({num_steps})"
+            )
+
     # Initialize states
     if learner_state is None:
         learner_state = learner.init(stream.feature_dim)
     stream_state = stream.init(key)
 
+    feature_dim = stream.feature_dim
+
+    # No tracking - simple case
+    if step_size_tracking is None and normalizer_tracking is None:
+
+        def step_fn(
+            carry: tuple[NormalizedLearnerState, StreamStateT], idx: Array
+        ) -> tuple[tuple[NormalizedLearnerState, StreamStateT], Array]:
+            l_state, s_state = carry
+            timestep, new_s_state = stream.step(s_state, idx)
+            result = learner.update(l_state, timestep.observation, timestep.target)
+            return (result.state, new_s_state), result.metrics
+
+        (final_learner, _), metrics = jax.lax.scan(
+            step_fn, (learner_state, stream_state), jnp.arange(num_steps)
+        )
+
+        return final_learner, metrics
+
+    # Tracking enabled - need to set up history arrays
+    ss_interval = step_size_tracking.interval if step_size_tracking else num_steps + 1
+    norm_interval = (
+        normalizer_tracking.interval if normalizer_tracking else num_steps + 1
+    )
+
+    ss_num_recordings = num_steps // ss_interval if step_size_tracking else 0
+    norm_num_recordings = num_steps // norm_interval if normalizer_tracking else 0
+
+    # Pre-allocate step-size history arrays
+    ss_history = (
+        jnp.zeros((ss_num_recordings, feature_dim), dtype=jnp.float32)
+        if step_size_tracking
+        else None
+    )
+    ss_bias_history = (
+        jnp.zeros(ss_num_recordings, dtype=jnp.float32)
+        if step_size_tracking and step_size_tracking.include_bias
+        else None
+    )
+    ss_rec_indices = (
+        jnp.zeros(ss_num_recordings, dtype=jnp.int32) if step_size_tracking else None
+    )
+
+    # Check if we need to track Autostep normalizers
+    track_autostep_normalizers = hasattr(
+        learner_state.learner_state.optimizer_state, "normalizers"
+    )
+    ss_normalizers = (
+        jnp.zeros((ss_num_recordings, feature_dim), dtype=jnp.float32)
+        if step_size_tracking and track_autostep_normalizers
+        else None
+    )
+
+    # Pre-allocate normalizer state history arrays
+    norm_means = (
+        jnp.zeros((norm_num_recordings, feature_dim), dtype=jnp.float32)
+        if normalizer_tracking
+        else None
+    )
+    norm_vars = (
+        jnp.zeros((norm_num_recordings, feature_dim), dtype=jnp.float32)
+        if normalizer_tracking
+        else None
+    )
+    norm_rec_indices = (
+        jnp.zeros(norm_num_recordings, dtype=jnp.int32) if normalizer_tracking else None
+    )
+
+    def step_fn_with_tracking(
+        carry: tuple[
+            NormalizedLearnerState,
+            StreamStateT,
+            Array | None,
+            Array | None,
+            Array | None,
+            Array | None,
+            Array | None,
+            Array | None,
+            Array | None,
+        ],
+        idx: Array,
+    ) -> tuple[
+        tuple[
+            NormalizedLearnerState,
+            StreamStateT,
+            Array | None,
+            Array | None,
+            Array | None,
+            Array | None,
+            Array | None,
+            Array | None,
+            Array | None,
+        ],
+        Array,
+    ]:
+        (
+            l_state,
+            s_state,
+            ss_hist,
+            ss_bias_hist,
+            ss_rec,
+            ss_norm,
+            n_means,
+            n_vars,
+            n_rec,
+        ) = carry
+
+        # Perform learning step
         timestep, new_s_state = stream.step(s_state, idx)
         result = learner.update(l_state, timestep.observation, timestep.target)
-        return (result.state, new_s_state), result.metrics
 
+        # Step-size tracking
+        new_ss_hist = ss_hist
+        new_ss_bias_hist = ss_bias_hist
+        new_ss_rec = ss_rec
+        new_ss_norm = ss_norm
+
+        if ss_hist is not None:
+            should_record_ss = (idx % ss_interval) == 0
+            recording_idx = idx // ss_interval
+
+            # Extract current step-sizes from the inner learner state
+            opt_state = result.state.learner_state.optimizer_state
+            if hasattr(opt_state, "log_step_sizes"):
+                # IDBD stores log step-sizes
+                weight_ss = jnp.exp(opt_state.log_step_sizes)  # type: ignore[union-attr]
+                bias_ss = opt_state.bias_step_size  # type: ignore[union-attr]
+            elif hasattr(opt_state, "step_sizes"):
+                # Autostep stores step-sizes directly
+                weight_ss = opt_state.step_sizes  # type: ignore[union-attr]
+                bias_ss = opt_state.bias_step_size  # type: ignore[union-attr]
+            else:
+                # LMS has a single fixed step-size
+                weight_ss = jnp.full(feature_dim, opt_state.step_size)
+                bias_ss = opt_state.step_size
+
+            new_ss_hist = jax.lax.cond(
+                should_record_ss,
+                lambda _: ss_hist.at[recording_idx].set(weight_ss),
+                lambda _: ss_hist,
+                None,
+            )
+
+            if ss_bias_hist is not None:
+                new_ss_bias_hist = jax.lax.cond(
+                    should_record_ss,
+                    lambda _: ss_bias_hist.at[recording_idx].set(bias_ss),
+                    lambda _: ss_bias_hist,
+                    None,
+                )
+
+            if ss_rec is not None:
+                new_ss_rec = jax.lax.cond(
+                    should_record_ss,
+                    lambda _: ss_rec.at[recording_idx].set(idx),
+                    lambda _: ss_rec,
+                    None,
+                )
+
+            # Track Autostep normalizers (v_i) if applicable
+            if ss_norm is not None and hasattr(opt_state, "normalizers"):
+                new_ss_norm = jax.lax.cond(
+                    should_record_ss,
+                    lambda _: ss_norm.at[recording_idx].set(
+                        opt_state.normalizers  # type: ignore[union-attr]
+                    ),
+                    lambda _: ss_norm,
+                    None,
+                )
+
+        # Normalizer state tracking
+        new_n_means = n_means
+        new_n_vars = n_vars
+        new_n_rec = n_rec
+
+        if n_means is not None:
+            should_record_norm = (idx % norm_interval) == 0
+            norm_recording_idx = idx // norm_interval
+
+            norm_state = result.state.normalizer_state
+
+            new_n_means = jax.lax.cond(
+                should_record_norm,
+                lambda _: n_means.at[norm_recording_idx].set(norm_state.mean),
+                lambda _: n_means,
+                None,
+            )
+
+            if n_vars is not None:
+                new_n_vars = jax.lax.cond(
+                    should_record_norm,
+                    lambda _: n_vars.at[norm_recording_idx].set(norm_state.var),
+                    lambda _: n_vars,
+                    None,
+                )
+
+            if n_rec is not None:
+                new_n_rec = jax.lax.cond(
+                    should_record_norm,
+                    lambda _: n_rec.at[norm_recording_idx].set(idx),
+                    lambda _: n_rec,
+                    None,
+                )
+
+        return (
+            result.state,
+            new_s_state,
+            new_ss_hist,
+            new_ss_bias_hist,
+            new_ss_rec,
+            new_ss_norm,
+            new_n_means,
+            new_n_vars,
+            new_n_rec,
+        ), result.metrics
+
+    initial_carry = (
+        learner_state,
+        stream_state,
+        ss_history,
+        ss_bias_history,
+        ss_rec_indices,
+        ss_normalizers,
+        norm_means,
+        norm_vars,
+        norm_rec_indices,
+    )
+
+    (
+        final_learner,
+        _,
+        final_ss_hist,
+        final_ss_bias_hist,
+        final_ss_rec,
+        final_ss_norm,
+        final_n_means,
+        final_n_vars,
+        final_n_rec,
+    ), metrics = jax.lax.scan(
+        step_fn_with_tracking, initial_carry, jnp.arange(num_steps)
+    )
+
+    # Build return values based on what was tracked
+    ss_history_result = None
+    if step_size_tracking is not None and final_ss_hist is not None:
+        ss_history_result = StepSizeHistory(
+            step_sizes=final_ss_hist,
+            bias_step_sizes=final_ss_bias_hist,
+            recording_indices=final_ss_rec,  # type: ignore[arg-type]
+            normalizers=final_ss_norm,
+        )
+
+    norm_history_result = None
+    if normalizer_tracking is not None and final_n_means is not None:
+        norm_history_result = NormalizerHistory(
+            means=final_n_means,
+            variances=final_n_vars,  # type: ignore[arg-type]
+            recording_indices=final_n_rec,  # type: ignore[arg-type]
+        )
+
+    # Return appropriate tuple based on what was tracked
+    if ss_history_result is not None and norm_history_result is not None:
+        return final_learner, metrics, ss_history_result, norm_history_result
+    elif ss_history_result is not None:
+        return final_learner, metrics, ss_history_result
+    elif norm_history_result is not None:
+        return final_learner, metrics, norm_history_result
+    else:
+        return final_learner, metrics
+
+
+def run_learning_loop_batched[StreamStateT](
+    learner: LinearLearner,
+    stream: ScanStream[StreamStateT],
+    num_steps: int,
+    keys: Array,
+    learner_state: LearnerState | None = None,
+    step_size_tracking: StepSizeTrackingConfig | None = None,
+) -> BatchedLearningResult:
+    """Run learning loop across multiple seeds in parallel using jax.vmap.
+
+    This function provides GPU parallelization for multi-seed experiments,
+    typically achieving 2-5x speedup over sequential execution.
+
+    Args:
+        learner: The learner to train
+        stream: Experience stream providing (observation, target) pairs
+        num_steps: Number of learning steps to run per seed
+        keys: JAX random keys with shape (num_seeds,) or (num_seeds, 2)
+        learner_state: Initial state (if None, will be initialized from stream).
+            The same initial state is used for all seeds.
+        step_size_tracking: Optional config for recording per-weight step-sizes.
+            When provided, history arrays have shape (num_seeds, num_recordings, ...)
+
+    Returns:
+        BatchedLearningResult containing:
+        - states: Batched final states with shape (num_seeds, ...) for each array
+        - metrics: Array of shape (num_seeds, num_steps, 3)
+        - step_size_history: Batched history or None if tracking disabled
+
+    Examples:
+        ```python
+        import jax.random as jr
+        from alberta_framework import LinearLearner, IDBD, RandomWalkStream
+        from alberta_framework import run_learning_loop_batched
+
+        stream = RandomWalkStream(feature_dim=10)
+        learner = LinearLearner(optimizer=IDBD())
+
+        # Run 30 seeds in parallel
+        keys = jr.split(jr.key(42), 30)
+        result = run_learning_loop_batched(learner, stream, num_steps=10000, keys=keys)
+
+        # result.metrics has shape (30, 10000, 3)
+        mean_error = result.metrics[:, :, 0].mean(axis=0)  # Average over seeds
+        ```
+    """
+    # Define single-seed function that returns consistent structure
+    def single_seed_run(key: Array) -> tuple[LearnerState, Array, StepSizeHistory | None]:
+        result = run_learning_loop(
+            learner, stream, num_steps, key, learner_state, step_size_tracking
+        )
+        if step_size_tracking is not None:
+            state, metrics, history = result
+            return state, metrics, history
+        else:
+            state, metrics = result
+            # Return None for history to maintain consistent output structure
+            return state, metrics, None
+
+    # vmap over the keys dimension
+    batched_states, batched_metrics, batched_history = jax.vmap(single_seed_run)(keys)
+
+    # Reconstruct batched history if tracking was enabled
+    if step_size_tracking is not None:
+        batched_step_size_history = StepSizeHistory(
+            step_sizes=batched_history.step_sizes,
+            bias_step_sizes=batched_history.bias_step_sizes,
+            recording_indices=batched_history.recording_indices,
+            normalizers=batched_history.normalizers,
+        )
+    else:
+        batched_step_size_history = None
+
+    return BatchedLearningResult(
+        states=batched_states,
+        metrics=batched_metrics,
+        step_size_history=batched_step_size_history,
     )
 
+
+def run_normalized_learning_loop_batched[StreamStateT](
+    learner: NormalizedLinearLearner,
+    stream: ScanStream[StreamStateT],
+    num_steps: int,
+    keys: Array,
+    learner_state: NormalizedLearnerState | None = None,
+    step_size_tracking: StepSizeTrackingConfig | None = None,
+    normalizer_tracking: NormalizerTrackingConfig | None = None,
+) -> BatchedNormalizedResult:
+    """Run normalized learning loop across multiple seeds in parallel using jax.vmap.
+
+    This function provides GPU parallelization for multi-seed experiments with
+    normalized learners, typically achieving 2-5x speedup over sequential execution.
+
+    Args:
+        learner: The normalized learner to train
+        stream: Experience stream providing (observation, target) pairs
+        num_steps: Number of learning steps to run per seed
+        keys: JAX random keys with shape (num_seeds,) or (num_seeds, 2)
+        learner_state: Initial state (if None, will be initialized from stream).
+            The same initial state is used for all seeds.
+        step_size_tracking: Optional config for recording per-weight step-sizes.
+            When provided, history arrays have shape (num_seeds, num_recordings, ...)
+        normalizer_tracking: Optional config for recording normalizer state.
+            When provided, history arrays have shape (num_seeds, num_recordings, ...)
+
+    Returns:
+        BatchedNormalizedResult containing:
+        - states: Batched final states with shape (num_seeds, ...) for each array
+        - metrics: Array of shape (num_seeds, num_steps, 4)
+        - step_size_history: Batched history or None if tracking disabled
+        - normalizer_history: Batched history or None if tracking disabled
+
+    Examples:
+        ```python
+        import jax.random as jr
+        from alberta_framework import NormalizedLinearLearner, IDBD, RandomWalkStream
+        from alberta_framework import run_normalized_learning_loop_batched
+
+        stream = RandomWalkStream(feature_dim=10)
+        learner = NormalizedLinearLearner(optimizer=IDBD())
+
+        # Run 30 seeds in parallel
+        keys = jr.split(jr.key(42), 30)
+        result = run_normalized_learning_loop_batched(
+            learner, stream, num_steps=10000, keys=keys
+        )
+
+        # result.metrics has shape (30, 10000, 4)
+        mean_error = result.metrics[:, :, 0].mean(axis=0)  # Average over seeds
+        ```
+    """
+    # Define single-seed function that returns consistent structure
+    def single_seed_run(
+        key: Array,
+    ) -> tuple[
+        NormalizedLearnerState, Array, StepSizeHistory | None, NormalizerHistory | None
+    ]:
+        result = run_normalized_learning_loop(
+            learner, stream, num_steps, key, learner_state,
+            step_size_tracking, normalizer_tracking
+        )
+
+        # Unpack based on what tracking was enabled
+        if step_size_tracking is not None and normalizer_tracking is not None:
+            state, metrics, ss_history, norm_history = result
+            return state, metrics, ss_history, norm_history
+        elif step_size_tracking is not None:
+            state, metrics, ss_history = result
+            return state, metrics, ss_history, None
+        elif normalizer_tracking is not None:
+            state, metrics, norm_history = result
+            return state, metrics, None, norm_history
+        else:
+            state, metrics = result
+            return state, metrics, None, None
+
+    # vmap over the keys dimension
+    batched_states, batched_metrics, batched_ss_history, batched_norm_history = (
+        jax.vmap(single_seed_run)(keys)
+    )
+
+    # Reconstruct batched histories if tracking was enabled
+    if step_size_tracking is not None:
+        batched_step_size_history = StepSizeHistory(
+            step_sizes=batched_ss_history.step_sizes,
+            bias_step_sizes=batched_ss_history.bias_step_sizes,
+            recording_indices=batched_ss_history.recording_indices,
+            normalizers=batched_ss_history.normalizers,
+        )
+    else:
+        batched_step_size_history = None
+
+    if normalizer_tracking is not None:
+        batched_normalizer_history = NormalizerHistory(
+            means=batched_norm_history.means,
+            variances=batched_norm_history.variances,
+            recording_indices=batched_norm_history.recording_indices,
+        )
+    else:
+        batched_normalizer_history = None
+
+    return BatchedNormalizedResult(
+        states=batched_states,
+        metrics=batched_metrics,
+        step_size_history=batched_step_size_history,
+        normalizer_history=batched_normalizer_history,
+    )
 
 
 def metrics_to_dicts(metrics: Array, normalized: bool = False) -> list[dict[str, float]]:
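Because `run_normalized_learning_loop` now returns a different tuple arity depending on which tracking configs are supplied, a short usage sketch may help. It assumes `StepSizeTrackingConfig` is importable from the package root (only `NormalizerTrackingConfig` is visibly added to `__all__` in this diff) and that both configs can be constructed from just an `interval`:

```python
import jax.random as jr
from alberta_framework import (
    IDBD,
    NormalizedLinearLearner,
    NormalizerTrackingConfig,
    RandomWalkStream,
    run_normalized_learning_loop,
)
from alberta_framework import StepSizeTrackingConfig  # assumed export

stream = RandomWalkStream(feature_dim=10, drift_rate=0.001)
learner = NormalizedLinearLearner(optimizer=IDBD())

# With both tracking configs supplied, the loop returns a 4-tuple.
state, metrics, ss_hist, norm_hist = run_normalized_learning_loop(
    learner,
    stream,
    num_steps=10_000,
    key=jr.key(0),
    step_size_tracking=StepSizeTrackingConfig(interval=100),
    normalizer_tracking=NormalizerTrackingConfig(interval=100),
)

# One recording every 100 steps -> num_steps // interval = 100 recordings.
print(metrics.shape)             # (10000, 4)
print(ss_hist.step_sizes.shape)  # (100, 10)
print(norm_hist.means.shape)     # (100, 10)
```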
alberta_framework/core/normalizers.py
CHANGED

@@ -35,7 +35,7 @@ class OnlineNormalizer:
     """Online feature normalizer for continual learning.
 
     Normalizes features using running estimates of mean and standard deviation:
+    `x_normalized = (x - mean) / (std + epsilon)`
 
     The normalizer updates its estimates at every time step, following
     temporal uniformity. Uses exponential moving average for non-stationary
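The update the docstring describes is simple enough to sketch in a few lines. The hyperparameter names below (`decay`, `epsilon`) are illustrative and not necessarily the `OnlineNormalizer`'s actual argument names; only the `mean` and `var` statistics are confirmed by the tracking code in `learners.py`:

```python
import jax.numpy as jnp


def ema_normalize(x, mean, var, decay=0.99, epsilon=1e-8):
    """One temporally uniform normalizer step: update running stats, then scale x.

    `decay` and `epsilon` are illustrative names, not the framework's API.
    """
    new_mean = decay * mean + (1.0 - decay) * x
    new_var = decay * var + (1.0 - decay) * (x - new_mean) ** 2
    x_normalized = (x - new_mean) / (jnp.sqrt(new_var) + epsilon)
    return x_normalized, new_mean, new_var
```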
alberta_framework/core/optimizers.py
CHANGED

@@ -72,7 +72,7 @@ class Optimizer[StateT: (LMSState, IDBDState, AutostepState)](ABC):
 class LMS(Optimizer[LMSState]):
     """Least Mean Square optimizer with fixed step-size.
 
-    The simplest gradient-based optimizer: w_{t+1} = w_t + alpha * delta * x_t
+    The simplest gradient-based optimizer: `w_{t+1} = w_t + alpha * delta * x_t`
 
     This serves as a baseline. The challenge is that the optimal step-size
     depends on the problem and changes as the task becomes non-stationary.
@@ -108,7 +108,7 @@ class LMS(Optimizer[LMSState]):
     ) -> OptimizerUpdate:
         """Compute LMS weight update.
 
-        Update rule: delta_w = alpha * error * x
+        Update rule: `delta_w = alpha * error * x`
 
         Args:
             state: Current LMS state
@@ -195,10 +195,11 @@ class IDBD(Optimizer[IDBDState]):
         """Compute IDBD weight update with adaptive step-sizes.
 
         The IDBD algorithm:
+
+        1. Compute step-sizes: `alpha_i = exp(log_alpha_i)`
+        2. Update weights: `w_i += alpha_i * error * x_i`
+        3. Update log step-sizes: `log_alpha_i += beta * error * x_i * h_i`
+        4. Update traces: `h_i = h_i * max(0, 1 - alpha_i * x_i^2) + alpha_i * error * x_i`
 
         The trace h_i tracks the correlation between current and past gradients.
         When gradients consistently point the same direction, h_i grows,
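The four steps listed in the IDBD docstring can be sketched as a standalone function. The names `w`, `log_alpha`, `h`, and `meta_lr` (the `beta` meta step-size above) are illustrative rather than the framework's `IDBDState` fields, which store `log_step_sizes` per the tracking code in `learners.py`:

```python
import jax.numpy as jnp


def idbd_step(w, log_alpha, h, x, target, meta_lr=0.01):
    """One IDBD update on a linear predictor y = w @ x (illustrative names)."""
    error = target - jnp.dot(w, x)                          # prediction error
    alpha = jnp.exp(log_alpha)                              # 1. per-weight step-sizes
    new_w = w + alpha * error * x                           # 2. weight update
    new_log_alpha = log_alpha + meta_lr * error * x * h     # 3. meta-learned step-sizes
    new_h = h * jnp.maximum(0.0, 1.0 - alpha * x**2) + alpha * error * x  # 4. trace
    return new_w, new_log_alpha, new_h
```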
@@ -335,12 +336,13 @@ class Autostep(Optimizer[AutostepState]):
         """Compute Autostep weight update with normalized gradients.
 
         The Autostep algorithm:
+
+        1. Compute gradient: `g_i = error * x_i`
+        2. Normalize gradient: `g_i' = g_i / max(|g_i|, v_i)`
+        3. Update weights: `w_i += alpha_i * g_i'`
+        4. Update step-sizes: `alpha_i *= exp(mu * g_i' * h_i)`
+        5. Update traces: `h_i = h_i * (1 - alpha_i) + alpha_i * g_i'`
+        6. Update normalizers: `v_i = max(|g_i|, v_i * tau)`
 
         Args:
             state: Current Autostep state
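A matching standalone sketch of the six Autostep steps above, again with illustrative names (`mu` for the meta step-size, `tau` for the normalizer decay). Here `v` corresponds to the per-weight normalizers that the tracking code records from `AutostepState.normalizers`, and it is assumed to be initialized to a positive value:

```python
import jax.numpy as jnp


def autostep_step(w, alpha, h, v, x, target, mu=0.01, tau=0.9999):
    """One Autostep update (illustrative sketch, not the framework's Autostep class)."""
    error = target - jnp.dot(w, x)
    g = error * x                                        # 1. per-weight gradient
    g_norm = g / jnp.maximum(jnp.abs(g), v)              # 2. normalized gradient (v > 0)
    w_new = w + alpha * g_norm                           # 3. weight update
    alpha_new = alpha * jnp.exp(mu * g_norm * h)         # 4. step-size adaptation
    h_new = h * (1.0 - alpha_new) + alpha_new * g_norm   # 5. trace update
    v_new = jnp.maximum(jnp.abs(g), v * tau)             # 6. normalizer update
    return w_new, alpha_new, h_new, v_new
```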
alberta_framework/core/types.py
CHANGED
@@ -126,11 +126,81 @@ class StepSizeHistory(NamedTuple):
         step_sizes: Per-weight step-sizes at each recording, shape (num_recordings, num_weights)
         bias_step_sizes: Bias step-sizes at each recording, shape (num_recordings,) or None
         recording_indices: Step indices where recordings were made, shape (num_recordings,)
+        normalizers: Autostep's per-weight normalizers (v_i) at each recording,
+            shape (num_recordings, num_weights) or None. Only populated for Autostep optimizer.
     """
 
     step_sizes: Array  # (num_recordings, num_weights)
     bias_step_sizes: Array | None  # (num_recordings,) or None
     recording_indices: Array  # (num_recordings,)
+    normalizers: Array | None = None  # (num_recordings, num_weights) - Autostep v_i
+
+
+class NormalizerTrackingConfig(NamedTuple):
+    """Configuration for recording per-feature normalizer state during training.
+
+    Attributes:
+        interval: Record normalizer state every N steps
+    """
+
+    interval: int
+
+
+class NormalizerHistory(NamedTuple):
+    """History of per-feature normalizer state recorded during training.
+
+    Used for analyzing how the OnlineNormalizer adapts to distribution shifts
+    (reactive lag diagnostic).
+
+    Attributes:
+        means: Per-feature mean estimates at each recording, shape (num_recordings, feature_dim)
+        variances: Per-feature variance estimates at each recording,
+            shape (num_recordings, feature_dim)
+        recording_indices: Step indices where recordings were made, shape (num_recordings,)
+    """
+
+    means: Array  # (num_recordings, feature_dim)
+    variances: Array  # (num_recordings, feature_dim)
+    recording_indices: Array  # (num_recordings,)
+
+
+class BatchedLearningResult(NamedTuple):
+    """Result from batched learning loop across multiple seeds.
+
+    Used with `run_learning_loop_batched` for vmap-based GPU parallelization.
+
+    Attributes:
+        states: Batched learner states - each array has shape (num_seeds, ...)
+        metrics: Metrics array with shape (num_seeds, num_steps, 3)
+            where columns are [squared_error, error, mean_step_size]
+        step_size_history: Optional step-size history with batched shapes,
+            or None if tracking was disabled
+    """
+
+    states: "LearnerState"  # Batched: each array has shape (num_seeds, ...)
+    metrics: Array  # Shape: (num_seeds, num_steps, 3)
+    step_size_history: StepSizeHistory | None
+
+
+class BatchedNormalizedResult(NamedTuple):
+    """Result from batched normalized learning loop across multiple seeds.
+
+    Used with `run_normalized_learning_loop_batched` for vmap-based GPU parallelization.
+
+    Attributes:
+        states: Batched normalized learner states - each array has shape (num_seeds, ...)
+        metrics: Metrics array with shape (num_seeds, num_steps, 4)
+            where columns are [squared_error, error, mean_step_size, normalizer_mean_var]
+        step_size_history: Optional step-size history with batched shapes,
+            or None if tracking was disabled
+        normalizer_history: Optional normalizer history with batched shapes,
+            or None if tracking was disabled
+    """
+
+    states: "NormalizedLearnerState"  # Batched: each array has shape (num_seeds, ...)
+    metrics: Array  # Shape: (num_seeds, num_steps, 4)
+    step_size_history: StepSizeHistory | None
+    normalizer_history: NormalizerHistory | None
 
 
 def create_lms_state(step_size: float = 0.01) -> LMSState:
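The new history records are plain NamedTuples, so they can be inspected directly. A small sketch using a dummy `NormalizerHistory` (standing in for the output of a tracked run) to illustrate the reactive-lag style diagnostic mentioned in its docstring:

```python
import jax.numpy as jnp
from alberta_framework import NormalizerHistory

# Dummy history standing in for a tracked run:
# 20 recordings of a 4-feature normalizer, one every 100 steps.
history = NormalizerHistory(
    means=jnp.zeros((20, 4)),
    variances=jnp.ones((20, 4)),
    recording_indices=jnp.arange(0, 2000, 100),
)

# Reactive-lag diagnostic: how much each feature's mean estimate
# moved between consecutive recordings, averaged over features.
mean_drift = jnp.abs(jnp.diff(history.means, axis=0)).mean(axis=1)
print(mean_drift.shape)  # (19,)
```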
alberta_framework/streams/base.py
CHANGED

@@ -30,11 +30,14 @@ class ScanStream(Protocol[StateT]):
     Type Parameters:
         StateT: The state type maintained by this stream
 
+    Examples
+    --------
+    ```python
+    stream = RandomWalkStream(feature_dim=10, drift_rate=0.001)
+    key = jax.random.key(42)
+    state = stream.init(key)
+    timestep, new_state = stream.step(state, jnp.array(0))
+    ```
     """
 
     @property
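For reference, a minimal custom stream that satisfies the protocol shape used by the learning loops (a `feature_dim` property, `init(key)`, and `step(state, idx)` returning a timestep with `observation` and `target`). The `_TimeStep` NamedTuple here is an illustrative stand-in; the framework presumably has its own timestep type:

```python
from typing import NamedTuple

import jax.numpy as jnp
import jax.random as jr
from jax import Array


class _TimeStep(NamedTuple):
    """Illustrative stand-in for the framework's timestep type."""
    observation: Array
    target: Array


class ConstantStream:
    """Toy stream: y* = w @ x with a fixed weight vector w."""

    def __init__(self, feature_dim: int = 4):
        self._feature_dim = feature_dim
        self._w = jnp.linspace(-1.0, 1.0, feature_dim)

    @property
    def feature_dim(self) -> int:
        return self._feature_dim

    def init(self, key: Array) -> Array:
        # The stream state is just the PRNG key.
        return key

    def step(self, state: Array, idx: Array) -> tuple[_TimeStep, Array]:
        key, subkey = jr.split(state)
        x = jr.normal(subkey, (self._feature_dim,))
        return _TimeStep(observation=x, target=jnp.dot(self._w, x)), key
```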
alberta_framework/streams/synthetic.py
CHANGED

@@ -32,7 +32,7 @@ class RandomWalkState(NamedTuple):
 class RandomWalkStream:
     """Non-stationary stream where target weights drift via random walk.
 
-    The true target function is linear: y* = w_true @ x + noise
+    The true target function is linear: `y* = w_true @ x + noise`
     where w_true evolves via random walk at each time step.
 
     This tests the learner's ability to continuously track a moving target.
@@ -590,12 +590,15 @@ class ScaledStreamWrapper:
     scale factor. Useful for testing how learners handle features at different
     scales, which is important for understanding normalization benefits.
 
+    Examples
+    --------
+    ```python
+    stream = ScaledStreamWrapper(
+        AbruptChangeStream(feature_dim=10, change_interval=1000),
+        feature_scales=jnp.array([0.001, 0.01, 0.1, 1.0, 10.0,
+                                  100.0, 1000.0, 0.001, 0.01, 0.1])
+    )
+    ```
 
     Attributes:
         inner_stream: The wrapped stream instance
@@ -693,9 +696,12 @@ def make_scale_range(
     Returns:
         Array of shape (feature_dim,) with scale factors
 
+    Examples
+    --------
+    ```python
+    scales = make_scale_range(10, min_scale=0.01, max_scale=100.0)
+    stream = ScaledStreamWrapper(RandomWalkStream(10), scales)
+    ```
     """
     if log_spaced:
         return jnp.logspace(
alberta_framework/utils/experiments.py
CHANGED

@@ -110,13 +110,14 @@ def run_single_experiment(
 
     final_state: LearnerState | NormalizedLearnerState
     if isinstance(learner, NormalizedLinearLearner):
+        norm_result = run_normalized_learning_loop(
             learner, stream, config.num_steps, key
         )
+        final_state, metrics = cast(tuple[NormalizedLearnerState, Any], norm_result)
         metrics_history = metrics_to_dicts(metrics, normalized=True)
     else:
-        final_state, metrics = cast(tuple[LearnerState, Any],
+        linear_result = run_learning_loop(learner, stream, config.num_steps, key)
+        final_state, metrics = cast(tuple[LearnerState, Any], linear_result)
         metrics_history = metrics_to_dicts(metrics)
 
     return SingleRunResult(
alberta_framework/utils/timing.py
CHANGED

@@ -3,19 +3,22 @@
 This module provides a simple Timer context manager for measuring execution time
 and formatting durations in a human-readable format.
 
+Examples
+--------
+```python
+from alberta_framework.utils.timing import Timer
+
+with Timer("Training"):
+    # run training code
+    pass
+# Output: Training completed in 1.23s
+
+# Or capture the duration:
+with Timer("Experiment") as t:
+    # run experiment
+    pass
+print(f"Took {t.duration:.2f} seconds")
+```
 """
 
 import time
@@ -32,13 +35,13 @@ def format_duration(seconds: float) -> str:
     Returns:
         Formatted string like "1.23s", "2m 30.5s", or "1h 5m 30s"
 
-    Examples
+    Examples
+    --------
+    ```python
+    format_duration(0.5)   # Returns: '0.50s'
+    format_duration(90.5)  # Returns: '1m 30.50s'
+    format_duration(3665)  # Returns: '1h 1m 5.00s'
+    ```
     """
     if seconds < 60:
         return f"{seconds:.2f}s"
@@ -66,22 +69,25 @@ class Timer:
         start_time: Timestamp when timing started
         end_time: Timestamp when timing ended
 
+    Examples
+    --------
+    ```python
+    with Timer("Training loop"):
+        for i in range(1000):
+            pass
+    # Output: Training loop completed in 0.01s
+
+    # Silent timing (no print):
+    with Timer("Silent", verbose=False) as t:
+        time.sleep(0.1)
+    print(f"Elapsed: {t.duration:.2f}s")
+    # Output: Elapsed: 0.10s
+
+    # Custom print function:
+    with Timer("Custom", print_fn=lambda msg: print(f">> {msg}")):
+        pass
+    # Output: >> Custom completed in 0.00s
+    ```
     """
 
     def __init__(
{alberta_framework-0.1.0.dist-info → alberta_framework-0.2.0.dist-info}/METADATA
CHANGED

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: alberta-framework
-Version: 0.
+Version: 0.2.0
 Summary: Implementation of the Alberta Plan for AI Research - continual learning with meta-learned step-sizes
 Project-URL: Homepage, https://github.com/j-klawson/alberta-framework
 Project-URL: Repository, https://github.com/j-klawson/alberta-framework
@@ -49,7 +49,7 @@ Description-Content-Type: text/markdown
 [](https://opensource.org/licenses/Apache-2.0)
 [](https://www.python.org/downloads/)
 
-A JAX-based research framework implementing components of [The Alberta Plan](https://arxiv.org/abs/2208.11173) in the pursuit of building the foundations of Continual AI.
+A JAX-based research framework implementing components of [The Alberta Plan for AI Research](https://arxiv.org/abs/2208.11173) in the pursuit of building the foundations of Continual AI.
 
 > "The agents are complex only because they interact with a complex world... their initial design is as simple, general, and scalable as possible." — *Sutton et al., 2022*
 
@@ -57,6 +57,14 @@ A JAX-based research framework implementing components of [The Alberta Plan](htt
 
 The Alberta Framework provides foundational components for continual reinforcement learning research. Built on JAX for hardware acceleration, the framework emphasizes temporal uniformity — every component updates at every time step, with no special training phases or batch processing.
 
+## Project Context
+
+This framework is developed as part of my D.Eng. work focusing on the foundations of Continual AI. For more background and context see:
+
+* **Research Blog**: [blog.9600baud.net](https://blog.9600baud.net)
+* **Replicating Sutton '92**: [The Foundation of Step-size Adaptation](https://blog.9600baud.net/sutton92.html)
+* **About the Author**: [Keith Lawson](https://blog.9600baud.net/about.html)
+
 ### Roadmap
 
 Depending on my research trajectory I may or may not implement components required for the plan. The current focus of this framework is the Step 1 Baseline Study, investigating the interaction between adaptive optimizers and online normalization.
alberta_framework-0.2.0.dist-info/RECORD
ADDED

@@ -0,0 +1,22 @@
+alberta_framework/__init__.py,sha256=gAafDDmkivDdfnvDVff9zbVY9ilzqqfJ9KvpbRegKqs,5726
+alberta_framework/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+alberta_framework/core/__init__.py,sha256=PSrC4zSxgm_6YXWEQ80aZaunpbQ58QexxKmDDU-jp6c,522
+alberta_framework/core/learners.py,sha256=khZYkae5rlIyV13BW3-hrtPSjGFXPj2IUTM1z74xTTA,37724
+alberta_framework/core/normalizers.py,sha256=Z_d3H17qoXh87DE7k41imvWzkVJQ2xQgDUP7GYSNzAY,5903
+alberta_framework/core/optimizers.py,sha256=OefVuDDG1phh1QQIUyVPsQckl41VrpWFG7hY2eqyc64,14585
+alberta_framework/core/types.py,sha256=_blkTZEm3wNgbweFuqaVL2hxRQU7D6_U66nHp23Pq6Y,9192
+alberta_framework/streams/__init__.py,sha256=bsDgWjWjotDQHMI2lno3dgk8N14pd-2mYAQpXAtCPx4,2035
+alberta_framework/streams/base.py,sha256=9rJxvUgmzd5u2bRV4vi5PxhUvj39EZTD4bZHo-Ptn-U,2168
+alberta_framework/streams/gymnasium.py,sha256=s733X7aEgy05hcSazjZEhBiJChtEL7uVpxwh0fXBQZA,21980
+alberta_framework/streams/synthetic.py,sha256=4R9GR7Kh0LT7GmGtPhzMJGr8HbhrAMUOjvPwLZk6nDg,32979
+alberta_framework/utils/__init__.py,sha256=zfKfnbikhLp0J6UgVa8HeRo59gZHwqOc8jf03s7AaT4,2845
+alberta_framework/utils/experiments.py,sha256=ekGAzveCRgv9YZ5mfAD5Uf7h_PvQnxsNw2KeZN2eu00,10644
+alberta_framework/utils/export.py,sha256=W9RKfeTiyZcLColOGNjBfZU0N6QMXrfPn4pdYcm-OSk,15832
+alberta_framework/utils/metrics.py,sha256=1cryNJoboO67vvRhausaucbYZFgdL_06vaf08UXbojg,3349
+alberta_framework/utils/statistics.py,sha256=4fbzNlmsdUaM5lLW1BhL5B5MUpnqimQlwJklZ4x0y0U,15416
+alberta_framework/utils/timing.py,sha256=JOLq8CpCAV7LWOWkftxefduSFjaXnVwal1MFBKEMdJI,4049
+alberta_framework/utils/visualization.py,sha256=PmKBD3KGabNhgDizcNiGJEbVCyDL1YMUE5yTwgJHu2o,17924
+alberta_framework-0.2.0.dist-info/METADATA,sha256=zcg9hjFzz_oV3Jij8xvDnmww0p8-VPhzSBsN1VD8rvw,7763
+alberta_framework-0.2.0.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
+alberta_framework-0.2.0.dist-info/licenses/LICENSE,sha256=TI1avodt5mvxz7sunyxIa0HlNgLQcmKNLeRjCVcgKmE,10754
+alberta_framework-0.2.0.dist-info/RECORD,,
alberta_framework-0.1.0.dist-info/RECORD
REMOVED

@@ -1,22 +0,0 @@
-alberta_framework/__init__.py,sha256=gPLBA2EiPcElsYp_U_Rs7C6wlrGHr8w5IL6C0F90zec,4739
-alberta_framework/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-alberta_framework/core/__init__.py,sha256=PSrC4zSxgm_6YXWEQ80aZaunpbQ58QexxKmDDU-jp6c,522
-alberta_framework/core/learners.py,sha256=Abq_iOb9CWy9DMWeIX_7PXxo59gXEmlvNnf8zrDKQpo,18157
-alberta_framework/core/normalizers.py,sha256=-OFkdKcfx4VTwm4WLXu1hxrh8DTdwFlIG6CgVYgaZBk,5905
-alberta_framework/core/optimizers.py,sha256=uHEOhE1ThLcXJ18zagW9SQAEiNpsWapNWSMelbhkdNY,14559
-alberta_framework/core/types.py,sha256=Op9EHIIoEZGKbbr3b7xijaOurlQ-mxohBRv7rnVybro,6307
-alberta_framework/streams/__init__.py,sha256=bsDgWjWjotDQHMI2lno3dgk8N14pd-2mYAQpXAtCPx4,2035
-alberta_framework/streams/base.py,sha256=81zqXTF30Orj0N2BXSLYVHF9wUYZSqthqQi1MG5Kzxs,2165
-alberta_framework/streams/gymnasium.py,sha256=s733X7aEgy05hcSazjZEhBiJChtEL7uVpxwh0fXBQZA,21980
-alberta_framework/streams/synthetic.py,sha256=kRQktC4NNlFvoF_FmY_WG9VkiASGOudz8qdI5VjoRq8,32963
-alberta_framework/utils/__init__.py,sha256=zfKfnbikhLp0J6UgVa8HeRo59gZHwqOc8jf03s7AaT4,2845
-alberta_framework/utils/experiments.py,sha256=8N_JrffUa1S_lIZQIqKDuBxyv4UYt9QXzLlo-YnMAEU,10554
-alberta_framework/utils/export.py,sha256=W9RKfeTiyZcLColOGNjBfZU0N6QMXrfPn4pdYcm-OSk,15832
-alberta_framework/utils/metrics.py,sha256=1cryNJoboO67vvRhausaucbYZFgdL_06vaf08UXbojg,3349
-alberta_framework/utils/statistics.py,sha256=4fbzNlmsdUaM5lLW1BhL5B5MUpnqimQlwJklZ4x0y0U,15416
-alberta_framework/utils/timing.py,sha256=05NwXrIc9nS2p2MCHjdOgglPbE1CHZsLLdSB6em7YNY,4110
-alberta_framework/utils/visualization.py,sha256=PmKBD3KGabNhgDizcNiGJEbVCyDL1YMUE5yTwgJHu2o,17924
-alberta_framework-0.1.0.dist-info/METADATA,sha256=G4lIPB7NJlkJGOlmbo7BqeIsHNUl2M-L4f8Gd1T2Ro8,7332
-alberta_framework-0.1.0.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
-alberta_framework-0.1.0.dist-info/licenses/LICENSE,sha256=TI1avodt5mvxz7sunyxIa0HlNgLQcmKNLeRjCVcgKmE,10754
-alberta_framework-0.1.0.dist-info/RECORD,,

{alberta_framework-0.1.0.dist-info → alberta_framework-0.2.0.dist-info}/WHEEL
File without changes

{alberta_framework-0.1.0.dist-info → alberta_framework-0.2.0.dist-info}/licenses/LICENSE
File without changes