alberta-framework 0.2.1__tar.gz → 0.3.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {alberta_framework-0.2.1 → alberta_framework-0.3.0}/.github/workflows/publish.yml +1 -0
- {alberta_framework-0.2.1 → alberta_framework-0.3.0}/CLAUDE.md +16 -1
- {alberta_framework-0.2.1 → alberta_framework-0.3.0}/PKG-INFO +4 -1
- {alberta_framework-0.2.1 → alberta_framework-0.3.0}/pyproject.toml +37 -1
- {alberta_framework-0.2.1 → alberta_framework-0.3.0}/src/alberta_framework/core/learners.py +69 -60
- {alberta_framework-0.2.1 → alberta_framework-0.3.0}/src/alberta_framework/core/normalizers.py +8 -7
- {alberta_framework-0.2.1 → alberta_framework-0.3.0}/src/alberta_framework/core/optimizers.py +6 -4
- {alberta_framework-0.2.1 → alberta_framework-0.3.0}/src/alberta_framework/core/types.py +65 -52
- {alberta_framework-0.2.1 → alberta_framework-0.3.0}/src/alberta_framework/streams/synthetic.py +42 -32
- {alberta_framework-0.2.1 → alberta_framework-0.3.0}/tests/conftest.py +2 -1
- {alberta_framework-0.2.1 → alberta_framework-0.3.0}/tests/test_learners.py +52 -46
- {alberta_framework-0.2.1 → alberta_framework-0.3.0}/tests/test_normalizers.py +12 -10
- {alberta_framework-0.2.1 → alberta_framework-0.3.0}/tests/test_optimizers.py +30 -24
- {alberta_framework-0.2.1 → alberta_framework-0.3.0}/tests/test_streams.py +70 -61
- {alberta_framework-0.2.1 → alberta_framework-0.3.0}/.github/workflows/ci.yml +0 -0
- {alberta_framework-0.2.1 → alberta_framework-0.3.0}/.github/workflows/docs.yml +0 -0
- {alberta_framework-0.2.1 → alberta_framework-0.3.0}/.gitignore +0 -0
- {alberta_framework-0.2.1 → alberta_framework-0.3.0}/ALBERTA_PLAN.md +0 -0
- {alberta_framework-0.2.1 → alberta_framework-0.3.0}/CHANGELOG.md +0 -0
- {alberta_framework-0.2.1 → alberta_framework-0.3.0}/LICENSE +0 -0
- {alberta_framework-0.2.1 → alberta_framework-0.3.0}/README.md +0 -0
- {alberta_framework-0.2.1 → alberta_framework-0.3.0}/ROADMAP.md +0 -0
- {alberta_framework-0.2.1 → alberta_framework-0.3.0}/docs/contributing.md +0 -0
- {alberta_framework-0.2.1 → alberta_framework-0.3.0}/docs/gen_ref_pages.py +0 -0
- {alberta_framework-0.2.1 → alberta_framework-0.3.0}/docs/getting-started/installation.md +0 -0
- {alberta_framework-0.2.1 → alberta_framework-0.3.0}/docs/getting-started/quickstart.md +0 -0
- {alberta_framework-0.2.1 → alberta_framework-0.3.0}/docs/guide/concepts.md +0 -0
- {alberta_framework-0.2.1 → alberta_framework-0.3.0}/docs/guide/experiments.md +0 -0
- {alberta_framework-0.2.1 → alberta_framework-0.3.0}/docs/guide/gymnasium.md +0 -0
- {alberta_framework-0.2.1 → alberta_framework-0.3.0}/docs/guide/optimizers.md +0 -0
- {alberta_framework-0.2.1 → alberta_framework-0.3.0}/docs/guide/streams.md +0 -0
- {alberta_framework-0.2.1 → alberta_framework-0.3.0}/docs/index.md +0 -0
- {alberta_framework-0.2.1 → alberta_framework-0.3.0}/docs/javascripts/mathjax.js +0 -0
- {alberta_framework-0.2.1 → alberta_framework-0.3.0}/examples/The Alberta Plan/Step1/README.md +0 -0
- {alberta_framework-0.2.1 → alberta_framework-0.3.0}/examples/The Alberta Plan/Step1/autostep_comparison.py +0 -0
- {alberta_framework-0.2.1 → alberta_framework-0.3.0}/examples/The Alberta Plan/Step1/external_normalization_study.py +0 -0
- {alberta_framework-0.2.1 → alberta_framework-0.3.0}/examples/The Alberta Plan/Step1/idbd_lms_autostep_comparison.py +0 -0
- {alberta_framework-0.2.1 → alberta_framework-0.3.0}/examples/The Alberta Plan/Step1/normalization_study.py +0 -0
- {alberta_framework-0.2.1 → alberta_framework-0.3.0}/examples/The Alberta Plan/Step1/sutton1992_experiment1.py +0 -0
- {alberta_framework-0.2.1 → alberta_framework-0.3.0}/examples/The Alberta Plan/Step1/sutton1992_experiment2.py +0 -0
- {alberta_framework-0.2.1 → alberta_framework-0.3.0}/examples/gymnasium_reward_prediction.py +0 -0
- {alberta_framework-0.2.1 → alberta_framework-0.3.0}/examples/publication_experiment.py +0 -0
- {alberta_framework-0.2.1 → alberta_framework-0.3.0}/examples/td_cartpole_lms.py +0 -0
- {alberta_framework-0.2.1 → alberta_framework-0.3.0}/mkdocs.yml +0 -0
- {alberta_framework-0.2.1 → alberta_framework-0.3.0}/papers/mahmood-msc-thesis-summary.md +0 -0
- {alberta_framework-0.2.1 → alberta_framework-0.3.0}/src/alberta_framework/__init__.py +0 -0
- {alberta_framework-0.2.1 → alberta_framework-0.3.0}/src/alberta_framework/core/__init__.py +0 -0
- {alberta_framework-0.2.1 → alberta_framework-0.3.0}/src/alberta_framework/py.typed +0 -0
- {alberta_framework-0.2.1 → alberta_framework-0.3.0}/src/alberta_framework/streams/__init__.py +0 -0
- {alberta_framework-0.2.1 → alberta_framework-0.3.0}/src/alberta_framework/streams/base.py +0 -0
- {alberta_framework-0.2.1 → alberta_framework-0.3.0}/src/alberta_framework/streams/gymnasium.py +0 -0
- {alberta_framework-0.2.1 → alberta_framework-0.3.0}/src/alberta_framework/utils/__init__.py +0 -0
- {alberta_framework-0.2.1 → alberta_framework-0.3.0}/src/alberta_framework/utils/experiments.py +0 -0
- {alberta_framework-0.2.1 → alberta_framework-0.3.0}/src/alberta_framework/utils/export.py +0 -0
- {alberta_framework-0.2.1 → alberta_framework-0.3.0}/src/alberta_framework/utils/metrics.py +0 -0
- {alberta_framework-0.2.1 → alberta_framework-0.3.0}/src/alberta_framework/utils/statistics.py +0 -0
- {alberta_framework-0.2.1 → alberta_framework-0.3.0}/src/alberta_framework/utils/timing.py +0 -0
- {alberta_framework-0.2.1 → alberta_framework-0.3.0}/src/alberta_framework/utils/visualization.py +0 -0
- {alberta_framework-0.2.1 → alberta_framework-0.3.0}/tests/test_gymnasium_streams.py +0 -0
{alberta_framework-0.2.1 → alberta_framework-0.3.0}/CLAUDE.md

```diff
@@ -70,7 +70,8 @@ mkdocs build # Build static site to site/
 ## Development Guidelines
 
 ### Design Principles
-- **Immutable State**: All state uses
+- **Immutable State**: All state uses `@chex.dataclass(frozen=True)` for JAX PyTree compatibility
+- **Type Safety**: jaxtyping annotations for shape checking (`Float[Array, " feature_dim"]`)
 - **Functional Style**: Pure functions enable `jit`, `vmap`, `jax.lax.scan`
 - **Scan-Based Learning**: Learning loops use `jax.lax.scan` for JIT-compiled training
 - **Composition**: Learners accept optimizers as parameters
```
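The two new design principles are easiest to see together. The following is an illustrative sketch only (the class and field names are invented, not framework types): a frozen chex dataclass annotated with jaxtyping shapes, used as an immutable PyTree inside a jitted pure function.

```python
# Illustrative sketch, not code from the package: a frozen chex dataclass with
# jaxtyping annotations behaving as an immutable JAX PyTree.
import chex
import jax
import jax.numpy as jnp
from jax import Array
from jaxtyping import Float


@chex.dataclass(frozen=True)
class ToyState:
    """Hypothetical state container; field names are invented for illustration."""
    weights: Float[Array, " feature_dim"]
    bias: Float[Array, ""]


@jax.jit
def predict(state: ToyState, x: Float[Array, " feature_dim"]) -> Float[Array, ""]:
    # Pure function: reads the frozen state, never mutates it.
    return jnp.dot(state.weights, x) + state.bias


state = ToyState(weights=jnp.zeros(4), bias=jnp.asarray(0.0))
new_state = state.replace(bias=jnp.asarray(1.0))  # functional update, original untouched
print(predict(new_state, jnp.ones(4)))  # 1.0
```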
{alberta_framework-0.2.1 → alberta_framework-0.3.0}/CLAUDE.md

```diff
@@ -85,6 +86,7 @@ mkdocs build # Build static site to site/
 ### Testing
 - Tests are in `tests/` directory
 - Use pytest fixtures from `conftest.py`
+- Use chex assertions: `chex.assert_shape()`, `chex.assert_trees_all_close()`, `chex.assert_tree_all_finite()`
 - All tests should pass before committing
 
 ## Key Algorithms
```
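For reference, the three chex assertions named in the testing guideline look like this in an ordinary pytest-style test (example arrays only, not a test from the repository):

```python
# Illustrative sketch: the chex assertions the guideline refers to,
# applied to arbitrary example arrays.
import chex
import jax.numpy as jnp


def test_chex_assertions_example():
    params = {"weights": jnp.zeros(8), "bias": jnp.asarray(0.0)}
    updated = {"weights": jnp.zeros(8) + 0.1, "bias": jnp.asarray(0.1)}

    chex.assert_shape(params["weights"], (8,))   # shape check
    chex.assert_tree_all_finite(updated)         # no NaN/inf anywhere in the tree
    chex.assert_trees_all_close(                 # numerical closeness of two trees
        updated,
        {"weights": jnp.full(8, 0.1), "bias": jnp.asarray(0.1)},
        atol=1e-6,
    )
```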
{alberta_framework-0.2.1 → alberta_framework-0.3.0}/CLAUDE.md

```diff
@@ -465,3 +467,16 @@ The publish workflow uses OpenID Connect (no API tokens). Configure on PyPI:
 1. PyPI project → Settings → Publishing → Add GitHub publisher
 2. Repository: `j-klawson/alberta-framework`, Workflow: `publish.yml`, Environment: `pypi`
 3. Repeat on TestPyPI with environment: `testpypi`
+
+## Changelog
+
+### v0.3.0 (2026-02-03)
+- **FEATURE**: Migrated all state types from NamedTuple to `@chex.dataclass(frozen=True)` for DeepMind-style JAX compatibility
+- **FEATURE**: Added jaxtyping shape annotations for compile-time type safety (`Float[Array, " feature_dim"]`, `PRNGKeyArray`, etc.)
+- **FEATURE**: Updated test suite to use chex assertions (`chex.assert_shape`, `chex.assert_tree_all_finite`, `chex.assert_trees_all_close`)
+- **DEPS**: Added `chex>=0.1.86` and `jaxtyping>=0.2.28` as required dependencies
+- **DEPS**: Added `beartype>=0.18.0` as optional dev dependency for runtime type checking
+
+### v0.2.2 (2026-02-02)
+- Fixed mypy type errors in `run_learning_loop_batched` and `run_normalized_learning_loop_batched` functions
+- Added `typing.cast` to properly handle conditional return type unpacking in batched learning loops
```
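A minimal before/after sketch of the NamedTuple to `@chex.dataclass(frozen=True)` migration described in the v0.3.0 entry; `CounterState` is an invented example, not one of the framework's state types:

```python
# Hypothetical before/after sketch of the NamedTuple -> chex.dataclass migration.
from typing import NamedTuple

import chex
import jax.numpy as jnp
from jax import Array
from jaxtyping import Float


class CounterStateOld(NamedTuple):      # 0.2.x style
    count: Array


@chex.dataclass(frozen=True)
class CounterStateNew:                  # 0.3.0 style
    count: Float[Array, ""]


old = CounterStateOld(count=jnp.asarray(0.0))
old2 = old._replace(count=old.count + 1)   # NamedTuple functional update

new = CounterStateNew(count=jnp.asarray(0.0))
new2 = new.replace(count=new.count + 1)    # chex dataclass equivalent
```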
{alberta_framework-0.2.1 → alberta_framework-0.3.0}/PKG-INFO

```diff
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: alberta-framework
-Version: 0.2.1
+Version: 0.3.0
 Summary: Implementation of the Alberta Plan for AI Research - continual learning with meta-learned step-sizes
 Project-URL: Homepage, https://github.com/j-klawson/alberta-framework
 Project-URL: Repository, https://github.com/j-klawson/alberta-framework
@@ -17,14 +17,17 @@ Classifier: Programming Language :: Python :: 3
 Classifier: Programming Language :: Python :: 3.13
 Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
 Requires-Python: >=3.13
+Requires-Dist: chex>=0.1.86
 Requires-Dist: jax>=0.4
 Requires-Dist: jaxlib>=0.4
+Requires-Dist: jaxtyping>=0.2.28
 Requires-Dist: joblib>=1.3
 Requires-Dist: matplotlib>=3.8
 Requires-Dist: numpy>=2.0
 Requires-Dist: scipy>=1.11
 Requires-Dist: tqdm>=4.66
 Provides-Extra: dev
+Requires-Dist: beartype>=0.18.0; extra == 'dev'
 Requires-Dist: mypy; extra == 'dev'
 Requires-Dist: pytest-cov; extra == 'dev'
 Requires-Dist: pytest>=8.0; extra == 'dev'
```
{alberta_framework-0.2.1 → alberta_framework-0.3.0}/pyproject.toml

```diff
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
 
 [project]
 name = "alberta-framework"
-version = "0.2.1"
+version = "0.3.0"
 description = "Implementation of the Alberta Plan for AI Research - continual learning with meta-learned step-sizes"
 readme = "README.md"
 license = "Apache-2.0"
@@ -29,6 +29,8 @@ dependencies = [
     "scipy>=1.11",
     "joblib>=1.3",
     "tqdm>=4.66",
+    "chex>=0.1.86",
+    "jaxtyping>=0.2.28",
 ]
 
 [project.optional-dependencies]
@@ -37,6 +39,7 @@ dev = [
     "pytest-cov",
     "ruff",
     "mypy",
+    "beartype>=0.18.0",
 ]
 docs = [
     "mkdocs>=1.6",
@@ -99,6 +102,39 @@ ignore_missing_imports = true
 module = "pandas.*"
 ignore_missing_imports = true
 
+[[tool.mypy.overrides]]
+module = "chex.*"
+ignore_missing_imports = true
+
+[[tool.mypy.overrides]]
+module = "jaxtyping.*"
+ignore_missing_imports = true
+
+# chex dataclasses don't have full mypy support - disable strict checks for core modules
+[[tool.mypy.overrides]]
+module = "alberta_framework.core.types"
+disable_error_code = ["call-arg"]
+
+[[tool.mypy.overrides]]
+module = "alberta_framework.core.normalizers"
+disable_error_code = ["call-arg"]
+
+[[tool.mypy.overrides]]
+module = "alberta_framework.core.optimizers"
+disable_error_code = ["call-arg"]
+
+[[tool.mypy.overrides]]
+module = "alberta_framework.core.learners"
+disable_error_code = ["call-arg", "arg-type", "union-attr"]
+
+[[tool.mypy.overrides]]
+module = "alberta_framework.streams.synthetic"
+disable_error_code = ["call-arg"]
+
+[[tool.mypy.overrides]]
+module = "alberta_framework.streams.gymnasium"
+disable_error_code = ["call-arg"]
+
 [tool.pytest.ini_options]
 testpaths = ["tests"]
 addopts = "-v"
```
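The new `beartype` dev extra pairs naturally with the `jaxtyping` annotations. One possible wiring is sketched below; this is an assumption about usage, not necessarily how the framework enables runtime checking:

```python
# Sketch only: one way the optional beartype dev dependency can enforce the
# jaxtyping annotations at runtime. The framework itself may wire this differently.
import jax.numpy as jnp
from beartype import beartype
from jaxtyping import Array, Float, jaxtyped


@jaxtyped(typechecker=beartype)
def scale(x: Float[Array, " feature_dim"], factor: float) -> Float[Array, " feature_dim"]:
    # Shape and dtype of `x` are checked against the annotation when called.
    return x * factor


scale(jnp.ones(3), 2.0)          # passes
# scale(jnp.ones((2, 3)), 2.0)   # would raise a type-checking error: wrong rank
```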
{alberta_framework-0.2.1 → alberta_framework-0.3.0}/src/alberta_framework/core/learners.py

```diff
@@ -5,11 +5,13 @@ for temporally-uniform learning. Uses JAX's scan for efficient JIT-compiled
 training loops.
 """
 
-from typing import
+from typing import cast
 
+import chex
 import jax
 import jax.numpy as jnp
 from jax import Array
+from jaxtyping import Float
 
 from alberta_framework.core.normalizers import NormalizerState, OnlineNormalizer
 from alberta_framework.core.optimizers import LMS, Optimizer
@@ -33,7 +35,9 @@ from alberta_framework.streams.base import ScanStream
 # Type alias for any optimizer type
 AnyOptimizer = Optimizer[LMSState] | Optimizer[IDBDState] | Optimizer[AutostepState]
 
-class UpdateResult(NamedTuple):
+
+@chex.dataclass(frozen=True)
+class UpdateResult:
     """Result of a learner update step.
 
     Attributes:
@@ -45,8 +49,38 @@ class UpdateResult(NamedTuple):
 
     state: LearnerState
     prediction: Prediction
-    error: Array
-    metrics: Array
+    error: Float[Array, ""]
+    metrics: Float[Array, " 3"]
+
+
+@chex.dataclass(frozen=True)
+class NormalizedLearnerState:
+    """State for a learner with online feature normalization.
+
+    Attributes:
+        learner_state: Underlying learner state (weights, bias, optimizer)
+        normalizer_state: Online normalizer state (mean, var estimates)
+    """
+
+    learner_state: LearnerState
+    normalizer_state: NormalizerState
+
+
+@chex.dataclass(frozen=True)
+class NormalizedUpdateResult:
+    """Result of a normalized learner update step.
+
+    Attributes:
+        state: Updated normalized learner state
+        prediction: Prediction made before update
+        error: Prediction error
+        metrics: Array of metrics [squared_error, error, mean_step_size, normalizer_mean_var]
+    """
+
+    state: NormalizedLearnerState
+    prediction: Prediction
+    error: Float[Array, ""]
+    metrics: Float[Array, " 4"]
 
 
 class LinearLearner:
@@ -133,8 +167,7 @@ class LinearLearner:
         # Note: type ignore needed because we can't statically prove optimizer_state
         # matches the optimizer's expected state type (though they will at runtime)
         opt_update = self._optimizer.update(
-            state.optimizer_state,
-            error,
+            state.optimizer_state, error,
             observation,
         )
 
@@ -275,12 +308,12 @@ def run_learning_loop[StreamStateT](
         opt_state = result.state.optimizer_state
         if hasattr(opt_state, "log_step_sizes"):
             # IDBD stores log step-sizes
-            weight_ss = jnp.exp(opt_state.log_step_sizes)
-            bias_ss = opt_state.bias_step_size
+            weight_ss = jnp.exp(opt_state.log_step_sizes)
+            bias_ss = opt_state.bias_step_size
         elif hasattr(opt_state, "step_sizes"):
             # Autostep stores step-sizes directly
-            weight_ss = opt_state.step_sizes
-            bias_ss = opt_state.bias_step_size
+            weight_ss = opt_state.step_sizes
+            bias_ss = opt_state.bias_step_size
         else:
             # LMS has a single fixed step-size
             weight_ss = jnp.full(feature_dim, opt_state.step_size)
@@ -316,8 +349,7 @@ def run_learning_loop[StreamStateT](
         new_norm_history = jax.lax.cond(
             should_record,
             lambda _: norm_history.at[recording_idx].set(
-                opt_state.normalizers
-            ),
+                opt_state.normalizers ),
             lambda _: norm_history,
             None,
         )
@@ -361,34 +393,6 @@ def run_learning_loop[StreamStateT](
     return final_learner, metrics, history
 
 
-class NormalizedLearnerState(NamedTuple):
-    """State for a learner with online feature normalization.
-
-    Attributes:
-        learner_state: Underlying learner state (weights, bias, optimizer)
-        normalizer_state: Online normalizer state (mean, var estimates)
-    """
-
-    learner_state: LearnerState
-    normalizer_state: NormalizerState
-
-
-class NormalizedUpdateResult(NamedTuple):
-    """Result of a normalized learner update step.
-
-    Attributes:
-        state: Updated normalized learner state
-        prediction: Prediction made before update
-        error: Prediction error
-        metrics: Array of metrics [squared_error, error, mean_step_size, normalizer_mean_var]
-    """
-
-    state: NormalizedLearnerState
-    prediction: Prediction
-    error: Array
-    metrics: Array
-
-
 class NormalizedLinearLearner:
     """Linear learner with online feature normalization.
 
@@ -702,12 +706,12 @@ def run_normalized_learning_loop[StreamStateT](
         opt_state = result.state.learner_state.optimizer_state
         if hasattr(opt_state, "log_step_sizes"):
             # IDBD stores log step-sizes
-            weight_ss = jnp.exp(opt_state.log_step_sizes)
-            bias_ss = opt_state.bias_step_size
+            weight_ss = jnp.exp(opt_state.log_step_sizes)
+            bias_ss = opt_state.bias_step_size
         elif hasattr(opt_state, "step_sizes"):
             # Autostep stores step-sizes directly
-            weight_ss = opt_state.step_sizes
-            bias_ss = opt_state.bias_step_size
+            weight_ss = opt_state.step_sizes
+            bias_ss = opt_state.bias_step_size
         else:
             # LMS has a single fixed step-size
             weight_ss = jnp.full(feature_dim, opt_state.step_size)
@@ -741,8 +745,7 @@ def run_normalized_learning_loop[StreamStateT](
         new_ss_norm = jax.lax.cond(
             should_record_ss,
            lambda _: ss_norm.at[recording_idx].set(
-                opt_state.normalizers
-            ),
+                opt_state.normalizers ),
            lambda _: ss_norm,
            None,
        )
@@ -825,17 +828,14 @@ def run_normalized_learning_loop[StreamStateT](
        ss_history_result = StepSizeHistory(
            step_sizes=final_ss_hist,
            bias_step_sizes=final_ss_bias_hist,
-            recording_indices=final_ss_rec,
-            normalizers=final_ss_norm,
+            recording_indices=final_ss_rec, normalizers=final_ss_norm,
        )
 
    norm_history_result = None
    if normalizer_tracking is not None and final_n_means is not None:
        norm_history_result = NormalizerHistory(
            means=final_n_means,
-            variances=final_n_vars,
-            recording_indices=final_n_rec,  # type: ignore[arg-type]
-        )
+            variances=final_n_vars, recording_indices=final_n_rec, )
 
    # Return appropriate tuple based on what was tracked
    if ss_history_result is not None and norm_history_result is not None:
@@ -900,10 +900,12 @@ def run_learning_loop_batched[StreamStateT](
            learner, stream, num_steps, key, learner_state, step_size_tracking
        )
        if step_size_tracking is not None:
-            state, metrics, history =
+            state, metrics, history = cast(
+                tuple[LearnerState, Array, StepSizeHistory], result
+            )
            return state, metrics, history
        else:
-            state, metrics = result
+            state, metrics = cast(tuple[LearnerState, Array], result)
            # Return None for history to maintain consistent output structure
            return state, metrics, None
 
@@ -911,7 +913,7 @@ def run_learning_loop_batched[StreamStateT](
    batched_states, batched_metrics, batched_history = jax.vmap(single_seed_run)(keys)
 
    # Reconstruct batched history if tracking was enabled
-    if step_size_tracking is not None:
+    if step_size_tracking is not None and batched_history is not None:
        batched_step_size_history = StepSizeHistory(
            step_sizes=batched_history.step_sizes,
            bias_step_sizes=batched_history.bias_step_sizes,
@@ -993,16 +995,23 @@ def run_normalized_learning_loop_batched[StreamStateT](
 
        # Unpack based on what tracking was enabled
        if step_size_tracking is not None and normalizer_tracking is not None:
-            state, metrics, ss_history, norm_history =
+            state, metrics, ss_history, norm_history = cast(
+                tuple[NormalizedLearnerState, Array, StepSizeHistory, NormalizerHistory],
+                result,
+            )
            return state, metrics, ss_history, norm_history
        elif step_size_tracking is not None:
-            state, metrics, ss_history =
+            state, metrics, ss_history = cast(
+                tuple[NormalizedLearnerState, Array, StepSizeHistory], result
+            )
            return state, metrics, ss_history, None
        elif normalizer_tracking is not None:
-            state, metrics, norm_history =
+            state, metrics, norm_history = cast(
+                tuple[NormalizedLearnerState, Array, NormalizerHistory], result
+            )
            return state, metrics, None, norm_history
        else:
-            state, metrics = result
+            state, metrics = cast(tuple[NormalizedLearnerState, Array], result)
            return state, metrics, None, None
 
    # vmap over the keys dimension
@@ -1011,7 +1020,7 @@ def run_normalized_learning_loop_batched[StreamStateT](
    )
 
    # Reconstruct batched histories if tracking was enabled
-    if step_size_tracking is not None:
+    if step_size_tracking is not None and batched_ss_history is not None:
        batched_step_size_history = StepSizeHistory(
            step_sizes=batched_ss_history.step_sizes,
            bias_step_sizes=batched_ss_history.bias_step_sizes,
@@ -1021,7 +1030,7 @@ def run_normalized_learning_loop_batched[StreamStateT](
    else:
        batched_step_size_history = None
 
-    if normalizer_tracking is not None:
+    if normalizer_tracking is not None and batched_norm_history is not None:
        batched_normalizer_history = NormalizerHistory(
            means=batched_norm_history.means,
            variances=batched_norm_history.variances,
```
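The `typing.cast` pattern introduced above (and noted in the v0.2.2 changelog entry) exists because the batched loops return tuples of different lengths depending on the tracking flags. A standalone sketch with an invented function:

```python
# Sketch (invented function, not framework code): using typing.cast to unpack a
# return value whose tuple length depends on a flag.
from typing import cast

import jax.numpy as jnp
from jax import Array


def run(track: bool) -> tuple[Array, Array] | tuple[Array, Array, Array]:
    state = jnp.zeros(4)
    metrics = jnp.zeros(3)
    history = jnp.zeros((2, 4))
    return (state, metrics, history) if track else (state, metrics)


result = run(track=True)
# mypy only knows `result` is one of two tuple types; cast() pins the branch we
# know we are in so the three-way unpacking type-checks.
state, metrics, history = cast(tuple[Array, Array, Array], result)
```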
{alberta_framework-0.2.1 → alberta_framework-0.3.0}/src/alberta_framework/core/normalizers.py
RENAMED

```diff
@@ -6,13 +6,14 @@ and variance at every time step, following the principle of temporal uniformity.
 Reference: Welford's online algorithm for numerical stability.
 """
 
-
-
+import chex
 import jax.numpy as jnp
 from jax import Array
+from jaxtyping import Float
 
 
-class NormalizerState(NamedTuple):
+@chex.dataclass(frozen=True)
+class NormalizerState:
     """State for online feature normalization.
 
     Uses Welford's online algorithm for numerically stable estimation
@@ -25,10 +26,10 @@ class NormalizerState(NamedTuple):
         decay: Exponential decay factor for estimates (1.0 = no decay, pure online)
     """
 
-    mean: Array
-    var: Array
-    sample_count: Array
-    decay: Array
+    mean: Float[Array, " feature_dim"]
+    var: Float[Array, " feature_dim"]
+    sample_count: Float[Array, ""]
+    decay: Float[Array, ""]
 
 
 class OnlineNormalizer:
```
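For context on the Welford reference in the module docstring, a generic Welford-style update over the fields `NormalizerState` exposes (mean, var, sample_count) might look like the sketch below. The framework's exact update rule, including how `decay` is applied, is not shown in this diff, so treat this as an assumption-based illustration:

```python
# Generic sketch of a Welford-style online mean/variance update; not the
# framework's implementation.
import jax.numpy as jnp
from jax import Array


def welford_update(mean: Array, var: Array, count: Array, x: Array) -> tuple[Array, Array, Array]:
    count = count + 1.0
    delta = x - mean
    mean = mean + delta / count                  # new running mean
    delta2 = x - mean
    var = var + (delta * delta2 - var) / count   # running (population) variance
    return mean, var, count


mean, var, count = jnp.zeros(3), jnp.zeros(3), jnp.asarray(0.0)
for x in (jnp.array([1.0, 2.0, 3.0]), jnp.array([3.0, 2.0, 1.0])):
    mean, var, count = welford_update(mean, var, count, x)
```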
{alberta_framework-0.2.1 → alberta_framework-0.3.0}/src/alberta_framework/core/optimizers.py
RENAMED

```diff
@@ -10,15 +10,17 @@ References:
 """
 
 from abc import ABC, abstractmethod
-from typing import NamedTuple
 
+import chex
 import jax.numpy as jnp
 from jax import Array
+from jaxtyping import Float
 
 from alberta_framework.core.types import AutostepState, IDBDState, LMSState
 
 
-class OptimizerUpdate(NamedTuple):
+@chex.dataclass(frozen=True)
+class OptimizerUpdate:
     """Result of an optimizer update step.
 
     Attributes:
@@ -28,8 +30,8 @@ class OptimizerUpdate(NamedTuple):
         metrics: Dictionary of metrics for logging (values are JAX arrays for scan compatibility)
     """
 
-    weight_delta: Array
-    bias_delta: Array
+    weight_delta: Float[Array, " feature_dim"]
+    bias_delta: Float[Array, ""]
     new_state: LMSState | IDBDState | AutostepState
     metrics: dict[str, Array]
 
```
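To make the `OptimizerUpdate` fields concrete, here is a textbook LMS step shaped like that result (weight_delta, bias_delta, unchanged state, metrics). The framework's own `LMS` class may differ in details such as metric names:

```python
# Sketch of a textbook LMS update; field shapes follow OptimizerUpdate above,
# but this is not the package's LMS implementation.
import chex
import jax.numpy as jnp
from jax import Array
from jaxtyping import Float


@chex.dataclass(frozen=True)
class ToyLMSState:
    step_size: Float[Array, ""]


def lms_update(state: ToyLMSState, error: Array, observation: Array):
    weight_delta = state.step_size * error * observation  # alpha * delta * x_i
    bias_delta = state.step_size * error                   # bias treated as a feature of 1
    metrics = {"squared_error": error**2}
    return weight_delta, bias_delta, state, metrics         # state unchanged: fixed alpha


state = ToyLMSState(step_size=jnp.asarray(0.1))
wd, bd, new_state, m = lms_update(state, jnp.asarray(0.5), jnp.array([1.0, 0.0, 2.0]))
```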
{alberta_framework-0.2.1 → alberta_framework-0.3.0}/src/alberta_framework/core/types.py

```diff
@@ -1,13 +1,15 @@
 """Type definitions for the Alberta Framework.
 
 This module defines the core data types used throughout the framework,
-
+using chex dataclasses for JAX compatibility and jaxtyping for shape annotations.
 """
 
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING
 
+import chex
 import jax.numpy as jnp
 from jax import Array
+from jaxtyping import Float, Int, PRNGKeyArray
 
 if TYPE_CHECKING:
     from alberta_framework.core.learners import NormalizedLearnerState
@@ -19,7 +21,8 @@ Prediction = Array # y_t: model output
 Reward = float # r_t: scalar reward
 
 
-class TimeStep(NamedTuple):
+@chex.dataclass(frozen=True)
+class TimeStep:
     """Single experience from an experience stream.
 
     Attributes:
@@ -27,25 +30,12 @@ class TimeStep(NamedTuple):
         target: Desired output y*_t (for supervised learning)
     """
 
-    observation:
-    target:
+    observation: Float[Array, " feature_dim"]
+    target: Float[Array, " 1"]
 
 
-class LearnerState(NamedTuple):
-    """State for a linear learner.
-
-    Attributes:
-        weights: Weight vector for linear prediction
-        bias: Bias term
-        optimizer_state: State maintained by the optimizer
-    """
-
-    weights: Array
-    bias: Array
-    optimizer_state: "LMSState | IDBDState | AutostepState"
-
-
-class LMSState(NamedTuple):
+@chex.dataclass(frozen=True)
+class LMSState:
     """State for the LMS (Least Mean Square) optimizer.
 
     LMS uses a fixed step-size, so state only tracks the step-size parameter.
@@ -54,10 +44,11 @@ class LMSState(NamedTuple):
         step_size: Fixed learning rate alpha
     """
 
-    step_size: Array
+    step_size: Float[Array, ""]
 
 
-class IDBDState(NamedTuple):
+@chex.dataclass(frozen=True)
+class IDBDState:
     """State for the IDBD (Incremental Delta-Bar-Delta) optimizer.
 
     IDBD maintains per-weight adaptive step-sizes that are meta-learned
@@ -73,14 +64,15 @@ class IDBDState(NamedTuple):
         bias_trace: Trace for the bias term
     """
 
-    log_step_sizes: Array  # log(alpha_i) for numerical stability
-    traces: Array  # h_i: trace of weight-feature products
-    meta_step_size: Array  # beta: step-size for the step-sizes
-    bias_step_size: Array  # Step-size for bias
-    bias_trace: Array  # Trace for bias
+    log_step_sizes: Float[Array, " feature_dim"]  # log(alpha_i) for numerical stability
+    traces: Float[Array, " feature_dim"]  # h_i: trace of weight-feature products
+    meta_step_size: Float[Array, ""]  # beta: step-size for the step-sizes
+    bias_step_size: Float[Array, ""]  # Step-size for bias
+    bias_trace: Float[Array, ""]  # Trace for bias
 
 
-class AutostepState(NamedTuple):
+@chex.dataclass(frozen=True)
+class AutostepState:
     """State for the Autostep optimizer.
 
     Autostep is a tuning-free step-size adaptation algorithm that normalizes
@@ -100,17 +92,33 @@ class AutostepState(NamedTuple):
         bias_normalizer: Normalizer for the bias gradient
     """
 
-    step_sizes: Array  # alpha_i
-    traces: Array  # h_i
-    normalizers: Array  # v_i: running max of |gradient|
-    meta_step_size: Array  # mu
-    normalizer_decay: Array  # tau
-    bias_step_size: Array
-    bias_trace: Array
-    bias_normalizer: Array
+    step_sizes: Float[Array, " feature_dim"]  # alpha_i
+    traces: Float[Array, " feature_dim"]  # h_i
+    normalizers: Float[Array, " feature_dim"]  # v_i: running max of |gradient|
+    meta_step_size: Float[Array, ""]  # mu
+    normalizer_decay: Float[Array, ""]  # tau
+    bias_step_size: Float[Array, ""]
+    bias_trace: Float[Array, ""]
+    bias_normalizer: Float[Array, ""]
+
+
+@chex.dataclass(frozen=True)
+class LearnerState:
+    """State for a linear learner.
+
+    Attributes:
+        weights: Weight vector for linear prediction
+        bias: Bias term
+        optimizer_state: State maintained by the optimizer
+    """
+
+    weights: Float[Array, " feature_dim"]
+    bias: Float[Array, ""]
+    optimizer_state: LMSState | IDBDState | AutostepState
 
 
-class StepSizeTrackingConfig(NamedTuple):
+@chex.dataclass(frozen=True)
+class StepSizeTrackingConfig:
     """Configuration for recording per-weight step-sizes during training.
 
     Attributes:
@@ -122,7 +130,8 @@ class StepSizeTrackingConfig(NamedTuple):
     include_bias: bool = True
 
 
-class StepSizeHistory(NamedTuple):
+@chex.dataclass(frozen=True)
+class StepSizeHistory:
     """History of per-weight step-sizes recorded during training.
 
     Attributes:
@@ -133,13 +142,14 @@ class StepSizeHistory(NamedTuple):
             shape (num_recordings, num_weights) or None. Only populated for Autostep optimizer.
     """
 
-    step_sizes: Array
-    bias_step_sizes: Array
-    recording_indices: Array
-    normalizers: Array | None = None
+    step_sizes: Float[Array, "num_recordings feature_dim"]
+    bias_step_sizes: Float[Array, " num_recordings"] | None
+    recording_indices: Int[Array, " num_recordings"]
+    normalizers: Float[Array, "num_recordings feature_dim"] | None = None
 
 
-class NormalizerTrackingConfig(NamedTuple):
+@chex.dataclass(frozen=True)
+class NormalizerTrackingConfig:
     """Configuration for recording per-feature normalizer state during training.
 
     Attributes:
@@ -149,7 +159,8 @@ class NormalizerTrackingConfig(NamedTuple):
     interval: int
 
 
-class NormalizerHistory(NamedTuple):
+@chex.dataclass(frozen=True)
+class NormalizerHistory:
     """History of per-feature normalizer state recorded during training.
 
     Used for analyzing how the OnlineNormalizer adapts to distribution shifts
@@ -162,12 +173,13 @@ class NormalizerHistory(NamedTuple):
         recording_indices: Step indices where recordings were made, shape (num_recordings,)
     """
 
-    means: Array
-    variances: Array
-    recording_indices: Array
+    means: Float[Array, "num_recordings feature_dim"]
+    variances: Float[Array, "num_recordings feature_dim"]
+    recording_indices: Int[Array, " num_recordings"]
 
 
-class BatchedLearningResult(NamedTuple):
+@chex.dataclass(frozen=True)
+class BatchedLearningResult:
     """Result from batched learning loop across multiple seeds.
 
     Used with `run_learning_loop_batched` for vmap-based GPU parallelization.
@@ -180,12 +192,13 @@ class BatchedLearningResult(NamedTuple):
            or None if tracking was disabled
     """
 
-    states:
-    metrics: Array
+    states: LearnerState  # Batched: each array has shape (num_seeds, ...)
+    metrics: Float[Array, "num_seeds num_steps 3"]
     step_size_history: StepSizeHistory | None
 
 
-class BatchedNormalizedResult(NamedTuple):
+@chex.dataclass(frozen=True)
+class BatchedNormalizedResult:
     """Result from batched normalized learning loop across multiple seeds.
 
     Used with `run_normalized_learning_loop_batched` for vmap-based GPU parallelization.
@@ -201,7 +214,7 @@ class BatchedNormalizedResult(NamedTuple):
     """
 
     states: "NormalizedLearnerState"  # Batched: each array has shape (num_seeds, ...)
-    metrics: Array
+    metrics: Float[Array, "num_seeds num_steps 4"]
     step_size_history: StepSizeHistory | None
     normalizer_history: NormalizerHistory | None
```
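Because every state type is now a frozen chex dataclass, each one is a registered PyTree, which is what lets `jax.vmap` produce the batched, leading `num_seeds` axis described for `BatchedLearningResult` and `BatchedNormalizedResult`. A small sketch with a hypothetical mini-state:

```python
# Sketch (hypothetical mini-state, not a framework type): frozen chex dataclasses
# are PyTrees, so vmap stacks their leaves along a leading num_seeds axis.
import chex
import jax
import jax.numpy as jnp
from jax import Array
from jaxtyping import Float


@chex.dataclass(frozen=True)
class MiniState:
    weights: Float[Array, " feature_dim"]
    bias: Float[Array, ""]


def init(key: Array) -> MiniState:
    return MiniState(weights=jax.random.normal(key, (4,)), bias=jnp.asarray(0.0))


keys = jax.random.split(jax.random.PRNGKey(0), 8)
batched = jax.vmap(init)(keys)                 # one MiniState whose leaves are batched
print(batched.weights.shape)                   # (8, 4): leading num_seeds axis
print(jax.tree_util.tree_map(jnp.shape, batched))
```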