alberta-framework 0.2.1__tar.gz → 0.3.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (59)
  1. {alberta_framework-0.2.1 → alberta_framework-0.3.0}/.github/workflows/publish.yml +1 -0
  2. {alberta_framework-0.2.1 → alberta_framework-0.3.0}/CLAUDE.md +16 -1
  3. {alberta_framework-0.2.1 → alberta_framework-0.3.0}/PKG-INFO +4 -1
  4. {alberta_framework-0.2.1 → alberta_framework-0.3.0}/pyproject.toml +37 -1
  5. {alberta_framework-0.2.1 → alberta_framework-0.3.0}/src/alberta_framework/core/learners.py +69 -60
  6. {alberta_framework-0.2.1 → alberta_framework-0.3.0}/src/alberta_framework/core/normalizers.py +8 -7
  7. {alberta_framework-0.2.1 → alberta_framework-0.3.0}/src/alberta_framework/core/optimizers.py +6 -4
  8. {alberta_framework-0.2.1 → alberta_framework-0.3.0}/src/alberta_framework/core/types.py +65 -52
  9. {alberta_framework-0.2.1 → alberta_framework-0.3.0}/src/alberta_framework/streams/synthetic.py +42 -32
  10. {alberta_framework-0.2.1 → alberta_framework-0.3.0}/tests/conftest.py +2 -1
  11. {alberta_framework-0.2.1 → alberta_framework-0.3.0}/tests/test_learners.py +52 -46
  12. {alberta_framework-0.2.1 → alberta_framework-0.3.0}/tests/test_normalizers.py +12 -10
  13. {alberta_framework-0.2.1 → alberta_framework-0.3.0}/tests/test_optimizers.py +30 -24
  14. {alberta_framework-0.2.1 → alberta_framework-0.3.0}/tests/test_streams.py +70 -61
  15. {alberta_framework-0.2.1 → alberta_framework-0.3.0}/.github/workflows/ci.yml +0 -0
  16. {alberta_framework-0.2.1 → alberta_framework-0.3.0}/.github/workflows/docs.yml +0 -0
  17. {alberta_framework-0.2.1 → alberta_framework-0.3.0}/.gitignore +0 -0
  18. {alberta_framework-0.2.1 → alberta_framework-0.3.0}/ALBERTA_PLAN.md +0 -0
  19. {alberta_framework-0.2.1 → alberta_framework-0.3.0}/CHANGELOG.md +0 -0
  20. {alberta_framework-0.2.1 → alberta_framework-0.3.0}/LICENSE +0 -0
  21. {alberta_framework-0.2.1 → alberta_framework-0.3.0}/README.md +0 -0
  22. {alberta_framework-0.2.1 → alberta_framework-0.3.0}/ROADMAP.md +0 -0
  23. {alberta_framework-0.2.1 → alberta_framework-0.3.0}/docs/contributing.md +0 -0
  24. {alberta_framework-0.2.1 → alberta_framework-0.3.0}/docs/gen_ref_pages.py +0 -0
  25. {alberta_framework-0.2.1 → alberta_framework-0.3.0}/docs/getting-started/installation.md +0 -0
  26. {alberta_framework-0.2.1 → alberta_framework-0.3.0}/docs/getting-started/quickstart.md +0 -0
  27. {alberta_framework-0.2.1 → alberta_framework-0.3.0}/docs/guide/concepts.md +0 -0
  28. {alberta_framework-0.2.1 → alberta_framework-0.3.0}/docs/guide/experiments.md +0 -0
  29. {alberta_framework-0.2.1 → alberta_framework-0.3.0}/docs/guide/gymnasium.md +0 -0
  30. {alberta_framework-0.2.1 → alberta_framework-0.3.0}/docs/guide/optimizers.md +0 -0
  31. {alberta_framework-0.2.1 → alberta_framework-0.3.0}/docs/guide/streams.md +0 -0
  32. {alberta_framework-0.2.1 → alberta_framework-0.3.0}/docs/index.md +0 -0
  33. {alberta_framework-0.2.1 → alberta_framework-0.3.0}/docs/javascripts/mathjax.js +0 -0
  34. {alberta_framework-0.2.1 → alberta_framework-0.3.0}/examples/The Alberta Plan/Step1/README.md +0 -0
  35. {alberta_framework-0.2.1 → alberta_framework-0.3.0}/examples/The Alberta Plan/Step1/autostep_comparison.py +0 -0
  36. {alberta_framework-0.2.1 → alberta_framework-0.3.0}/examples/The Alberta Plan/Step1/external_normalization_study.py +0 -0
  37. {alberta_framework-0.2.1 → alberta_framework-0.3.0}/examples/The Alberta Plan/Step1/idbd_lms_autostep_comparison.py +0 -0
  38. {alberta_framework-0.2.1 → alberta_framework-0.3.0}/examples/The Alberta Plan/Step1/normalization_study.py +0 -0
  39. {alberta_framework-0.2.1 → alberta_framework-0.3.0}/examples/The Alberta Plan/Step1/sutton1992_experiment1.py +0 -0
  40. {alberta_framework-0.2.1 → alberta_framework-0.3.0}/examples/The Alberta Plan/Step1/sutton1992_experiment2.py +0 -0
  41. {alberta_framework-0.2.1 → alberta_framework-0.3.0}/examples/gymnasium_reward_prediction.py +0 -0
  42. {alberta_framework-0.2.1 → alberta_framework-0.3.0}/examples/publication_experiment.py +0 -0
  43. {alberta_framework-0.2.1 → alberta_framework-0.3.0}/examples/td_cartpole_lms.py +0 -0
  44. {alberta_framework-0.2.1 → alberta_framework-0.3.0}/mkdocs.yml +0 -0
  45. {alberta_framework-0.2.1 → alberta_framework-0.3.0}/papers/mahmood-msc-thesis-summary.md +0 -0
  46. {alberta_framework-0.2.1 → alberta_framework-0.3.0}/src/alberta_framework/__init__.py +0 -0
  47. {alberta_framework-0.2.1 → alberta_framework-0.3.0}/src/alberta_framework/core/__init__.py +0 -0
  48. {alberta_framework-0.2.1 → alberta_framework-0.3.0}/src/alberta_framework/py.typed +0 -0
  49. {alberta_framework-0.2.1 → alberta_framework-0.3.0}/src/alberta_framework/streams/__init__.py +0 -0
  50. {alberta_framework-0.2.1 → alberta_framework-0.3.0}/src/alberta_framework/streams/base.py +0 -0
  51. {alberta_framework-0.2.1 → alberta_framework-0.3.0}/src/alberta_framework/streams/gymnasium.py +0 -0
  52. {alberta_framework-0.2.1 → alberta_framework-0.3.0}/src/alberta_framework/utils/__init__.py +0 -0
  53. {alberta_framework-0.2.1 → alberta_framework-0.3.0}/src/alberta_framework/utils/experiments.py +0 -0
  54. {alberta_framework-0.2.1 → alberta_framework-0.3.0}/src/alberta_framework/utils/export.py +0 -0
  55. {alberta_framework-0.2.1 → alberta_framework-0.3.0}/src/alberta_framework/utils/metrics.py +0 -0
  56. {alberta_framework-0.2.1 → alberta_framework-0.3.0}/src/alberta_framework/utils/statistics.py +0 -0
  57. {alberta_framework-0.2.1 → alberta_framework-0.3.0}/src/alberta_framework/utils/timing.py +0 -0
  58. {alberta_framework-0.2.1 → alberta_framework-0.3.0}/src/alberta_framework/utils/visualization.py +0 -0
  59. {alberta_framework-0.2.1 → alberta_framework-0.3.0}/tests/test_gymnasium_streams.py +0 -0
.github/workflows/publish.yml
@@ -4,6 +4,7 @@ on:
   push:
     tags:
       - "v*"
+  workflow_dispatch:
 
 permissions:
   id-token: write
CLAUDE.md
@@ -70,7 +70,8 @@ mkdocs build # Build static site to site/
 ## Development Guidelines
 
 ### Design Principles
-- **Immutable State**: All state uses NamedTuples for JAX compatibility
+- **Immutable State**: All state uses `@chex.dataclass(frozen=True)` for JAX PyTree compatibility
+- **Type Safety**: jaxtyping annotations for shape checking (`Float[Array, " feature_dim"]`)
 - **Functional Style**: Pure functions enable `jit`, `vmap`, `jax.lax.scan`
 - **Scan-Based Learning**: Learning loops use `jax.lax.scan` for JIT-compiled training
 - **Composition**: Learners accept optimizers as parameters
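For context, a minimal sketch of how these conventions fit together: a frozen chex dataclass with jaxtyping shape annotations, updated by a pure function and carried through `jax.lax.scan`. The `ToyState` name and fields below are invented for illustration and are not part of the package.

```python
# Illustrative only: the ToyState name and fields are invented for this sketch.
import chex
import jax
import jax.numpy as jnp
from jax import Array
from jaxtyping import Float


@chex.dataclass(frozen=True)
class ToyState:
    weights: Float[Array, " feature_dim"]
    bias: Float[Array, ""]


def toy_update(state: ToyState, x: Float[Array, " feature_dim"]) -> ToyState:
    # Pure function: returns a new state rather than mutating the old one.
    return ToyState(weights=state.weights + 0.01 * x, bias=state.bias)


state = ToyState(weights=jnp.zeros(4), bias=jnp.asarray(0.0))
xs = jnp.ones((100, 4))
# The frozen dataclass is a PyTree, so it can be carried through jax.lax.scan.
final_state, _ = jax.lax.scan(lambda s, x: (toy_update(s, x), None), state, xs)
```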
@@ -85,6 +86,7 @@ mkdocs build # Build static site to site/
 ### Testing
 - Tests are in `tests/` directory
 - Use pytest fixtures from `conftest.py`
+- Use chex assertions: `chex.assert_shape()`, `chex.assert_trees_all_close()`, `chex.assert_tree_all_finite()`
 - All tests should pass before committing
 
 ## Key Algorithms
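A short illustration of the chex assertions named above; the arrays here are invented for the example and are not taken from the test suite.

```python
# Illustrative only: how the chex assertions named above are typically used in a test.
import chex
import jax.numpy as jnp


def test_shapes_and_values() -> None:
    weights = jnp.zeros((8, 4))
    chex.assert_shape(weights, (8, 4))                    # exact shape check
    chex.assert_tree_all_finite({"w": weights})           # no NaN/Inf anywhere in the tree
    chex.assert_trees_all_close(weights, weights + 0.0)   # elementwise closeness
```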
@@ -465,3 +467,16 @@ The publish workflow uses OpenID Connect (no API tokens). Configure on PyPI:
 1. PyPI project → Settings → Publishing → Add GitHub publisher
 2. Repository: `j-klawson/alberta-framework`, Workflow: `publish.yml`, Environment: `pypi`
 3. Repeat on TestPyPI with environment: `testpypi`
+
+## Changelog
+
+### v0.3.0 (2026-02-03)
+- **FEATURE**: Migrated all state types from NamedTuple to `@chex.dataclass(frozen=True)` for DeepMind-style JAX compatibility
+- **FEATURE**: Added jaxtyping shape annotations for compile-time type safety (`Float[Array, " feature_dim"]`, `PRNGKeyArray`, etc.)
+- **FEATURE**: Updated test suite to use chex assertions (`chex.assert_shape`, `chex.assert_tree_all_finite`, `chex.assert_trees_all_close`)
+- **DEPS**: Added `chex>=0.1.86` and `jaxtyping>=0.2.28` as required dependencies
+- **DEPS**: Added `beartype>=0.18.0` as optional dev dependency for runtime type checking
+
+### v0.2.2 (2026-02-02)
+- Fixed mypy type errors in `run_learning_loop_batched` and `run_normalized_learning_loop_batched` functions
+- Added `typing.cast` to properly handle conditional return type unpacking in batched learning loops
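The beartype dev dependency is described as enabling runtime type checking. One common way to wire jaxtyping annotations to beartype is jaxtyping's `jaxtyped` decorator; this is an assumption for illustration, the diff does not show how the project actually uses beartype.

```python
# Assumption for illustration: wiring jaxtyping annotations to beartype at runtime
# via jaxtyping's `jaxtyped` decorator (how the project uses beartype is not shown here).
import jax.numpy as jnp
from beartype import beartype
from jax import Array
from jaxtyping import Float, jaxtyped


@jaxtyped(typechecker=beartype)
def dot(x: Float[Array, " feature_dim"], w: Float[Array, " feature_dim"]) -> Float[Array, ""]:
    return jnp.dot(x, w)


dot(jnp.ones(3), jnp.ones(3))     # OK: shapes match the annotations
# dot(jnp.ones(3), jnp.ones(4))   # raises a type-check error at call time
```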
PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: alberta-framework
-Version: 0.2.1
+Version: 0.3.0
 Summary: Implementation of the Alberta Plan for AI Research - continual learning with meta-learned step-sizes
 Project-URL: Homepage, https://github.com/j-klawson/alberta-framework
 Project-URL: Repository, https://github.com/j-klawson/alberta-framework
@@ -17,14 +17,17 @@ Classifier: Programming Language :: Python :: 3
 Classifier: Programming Language :: Python :: 3.13
 Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
 Requires-Python: >=3.13
+Requires-Dist: chex>=0.1.86
 Requires-Dist: jax>=0.4
 Requires-Dist: jaxlib>=0.4
+Requires-Dist: jaxtyping>=0.2.28
 Requires-Dist: joblib>=1.3
 Requires-Dist: matplotlib>=3.8
 Requires-Dist: numpy>=2.0
 Requires-Dist: scipy>=1.11
 Requires-Dist: tqdm>=4.66
 Provides-Extra: dev
+Requires-Dist: beartype>=0.18.0; extra == 'dev'
 Requires-Dist: mypy; extra == 'dev'
 Requires-Dist: pytest-cov; extra == 'dev'
 Requires-Dist: pytest>=8.0; extra == 'dev'
pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
 
 [project]
 name = "alberta-framework"
-version = "0.2.1"
+version = "0.3.0"
 description = "Implementation of the Alberta Plan for AI Research - continual learning with meta-learned step-sizes"
 readme = "README.md"
 license = "Apache-2.0"
@@ -29,6 +29,8 @@ dependencies = [
     "scipy>=1.11",
     "joblib>=1.3",
     "tqdm>=4.66",
+    "chex>=0.1.86",
+    "jaxtyping>=0.2.28",
 ]
 
 [project.optional-dependencies]
@@ -37,6 +39,7 @@ dev = [
     "pytest-cov",
     "ruff",
     "mypy",
+    "beartype>=0.18.0",
 ]
 docs = [
     "mkdocs>=1.6",
@@ -99,6 +102,39 @@ ignore_missing_imports = true
 module = "pandas.*"
 ignore_missing_imports = true
 
+[[tool.mypy.overrides]]
+module = "chex.*"
+ignore_missing_imports = true
+
+[[tool.mypy.overrides]]
+module = "jaxtyping.*"
+ignore_missing_imports = true
+
+# chex dataclasses don't have full mypy support - disable strict checks for core modules
+[[tool.mypy.overrides]]
+module = "alberta_framework.core.types"
+disable_error_code = ["call-arg"]
+
+[[tool.mypy.overrides]]
+module = "alberta_framework.core.normalizers"
+disable_error_code = ["call-arg"]
+
+[[tool.mypy.overrides]]
+module = "alberta_framework.core.optimizers"
+disable_error_code = ["call-arg"]
+
+[[tool.mypy.overrides]]
+module = "alberta_framework.core.learners"
+disable_error_code = ["call-arg", "arg-type", "union-attr"]
+
+[[tool.mypy.overrides]]
+module = "alberta_framework.streams.synthetic"
+disable_error_code = ["call-arg"]
+
+[[tool.mypy.overrides]]
+module = "alberta_framework.streams.gymnasium"
+disable_error_code = ["call-arg"]
+
 [tool.pytest.ini_options]
 testpaths = ["tests"]
 addopts = "-v"
src/alberta_framework/core/learners.py
@@ -5,11 +5,13 @@ for temporally-uniform learning. Uses JAX's scan for efficient JIT-compiled
 training loops.
 """
 
-from typing import NamedTuple
+from typing import cast
 
+import chex
 import jax
 import jax.numpy as jnp
 from jax import Array
+from jaxtyping import Float
 
 from alberta_framework.core.normalizers import NormalizerState, OnlineNormalizer
 from alberta_framework.core.optimizers import LMS, Optimizer
@@ -33,7 +35,9 @@ from alberta_framework.streams.base import ScanStream
 # Type alias for any optimizer type
 AnyOptimizer = Optimizer[LMSState] | Optimizer[IDBDState] | Optimizer[AutostepState]
 
-class UpdateResult(NamedTuple):
+
+@chex.dataclass(frozen=True)
+class UpdateResult:
     """Result of a learner update step.
 
     Attributes:
@@ -45,8 +49,38 @@ class UpdateResult(NamedTuple):
 
     state: LearnerState
     prediction: Prediction
-    error: Array
-    metrics: Array
+    error: Float[Array, ""]
+    metrics: Float[Array, " 3"]
+
+
+@chex.dataclass(frozen=True)
+class NormalizedLearnerState:
+    """State for a learner with online feature normalization.
+
+    Attributes:
+        learner_state: Underlying learner state (weights, bias, optimizer)
+        normalizer_state: Online normalizer state (mean, var estimates)
+    """
+
+    learner_state: LearnerState
+    normalizer_state: NormalizerState
+
+
+@chex.dataclass(frozen=True)
+class NormalizedUpdateResult:
+    """Result of a normalized learner update step.
+
+    Attributes:
+        state: Updated normalized learner state
+        prediction: Prediction made before update
+        error: Prediction error
+        metrics: Array of metrics [squared_error, error, mean_step_size, normalizer_mean_var]
+    """
+
+    state: NormalizedLearnerState
+    prediction: Prediction
+    error: Float[Array, ""]
+    metrics: Float[Array, " 4"]
 
 
 class LinearLearner:
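Because these result and state types are now frozen chex dataclasses rather than NamedTuples, they are updated functionally and still traverse as JAX PyTrees. A generic sketch with invented `InnerState`/`OuterState` classes (not the framework's types):

```python
# Illustrative only: frozen chex dataclasses cannot be mutated in place, so an
# updated state is built as a new object; nested states still traverse as one PyTree.
import chex
import jax
import jax.numpy as jnp
from jax import Array
from jaxtyping import Float


@chex.dataclass(frozen=True)
class InnerState:
    mean: Float[Array, " feature_dim"]


@chex.dataclass(frozen=True)
class OuterState:
    inner: InnerState
    count: Float[Array, ""]


outer = OuterState(inner=InnerState(mean=jnp.zeros(4)), count=jnp.asarray(0.0))
# Functional update: construct a new object instead of assigning to a field.
outer2 = OuterState(inner=outer.inner, count=outer.count + 1)
# The nested dataclasses form a single PyTree, so tree_map reaches every array leaf.
doubled = jax.tree_util.tree_map(lambda leaf: leaf * 2, outer2)
```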
@@ -133,8 +167,7 @@ class LinearLearner:
         # Note: type ignore needed because we can't statically prove optimizer_state
         # matches the optimizer's expected state type (though they will at runtime)
         opt_update = self._optimizer.update(
-            state.optimizer_state,  # type: ignore[arg-type]
-            error,
+            state.optimizer_state, error,
             observation,
         )
 
@@ -275,12 +308,12 @@ def run_learning_loop[StreamStateT](
         opt_state = result.state.optimizer_state
         if hasattr(opt_state, "log_step_sizes"):
             # IDBD stores log step-sizes
-            weight_ss = jnp.exp(opt_state.log_step_sizes)  # type: ignore[union-attr]
-            bias_ss = opt_state.bias_step_size  # type: ignore[union-attr]
+            weight_ss = jnp.exp(opt_state.log_step_sizes)
+            bias_ss = opt_state.bias_step_size
         elif hasattr(opt_state, "step_sizes"):
             # Autostep stores step-sizes directly
-            weight_ss = opt_state.step_sizes  # type: ignore[union-attr]
-            bias_ss = opt_state.bias_step_size  # type: ignore[union-attr]
+            weight_ss = opt_state.step_sizes
+            bias_ss = opt_state.bias_step_size
         else:
             # LMS has a single fixed step-size
             weight_ss = jnp.full(feature_dim, opt_state.step_size)
@@ -316,8 +349,7 @@ def run_learning_loop[StreamStateT](
         new_norm_history = jax.lax.cond(
             should_record,
             lambda _: norm_history.at[recording_idx].set(
-                opt_state.normalizers  # type: ignore[union-attr]
-            ),
+                opt_state.normalizers),
             lambda _: norm_history,
             None,
         )
@@ -361,34 +393,6 @@ def run_learning_loop[StreamStateT](
     return final_learner, metrics, history
 
 
-class NormalizedLearnerState(NamedTuple):
-    """State for a learner with online feature normalization.
-
-    Attributes:
-        learner_state: Underlying learner state (weights, bias, optimizer)
-        normalizer_state: Online normalizer state (mean, var estimates)
-    """
-
-    learner_state: LearnerState
-    normalizer_state: NormalizerState
-
-
-class NormalizedUpdateResult(NamedTuple):
-    """Result of a normalized learner update step.
-
-    Attributes:
-        state: Updated normalized learner state
-        prediction: Prediction made before update
-        error: Prediction error
-        metrics: Array of metrics [squared_error, error, mean_step_size, normalizer_mean_var]
-    """
-
-    state: NormalizedLearnerState
-    prediction: Prediction
-    error: Array
-    metrics: Array
-
-
 class NormalizedLinearLearner:
     """Linear learner with online feature normalization.
 
@@ -702,12 +706,12 @@ def run_normalized_learning_loop[StreamStateT](
         opt_state = result.state.learner_state.optimizer_state
         if hasattr(opt_state, "log_step_sizes"):
             # IDBD stores log step-sizes
-            weight_ss = jnp.exp(opt_state.log_step_sizes)  # type: ignore[union-attr]
-            bias_ss = opt_state.bias_step_size  # type: ignore[union-attr]
+            weight_ss = jnp.exp(opt_state.log_step_sizes)
+            bias_ss = opt_state.bias_step_size
         elif hasattr(opt_state, "step_sizes"):
             # Autostep stores step-sizes directly
-            weight_ss = opt_state.step_sizes  # type: ignore[union-attr]
-            bias_ss = opt_state.bias_step_size  # type: ignore[union-attr]
+            weight_ss = opt_state.step_sizes
+            bias_ss = opt_state.bias_step_size
         else:
             # LMS has a single fixed step-size
             weight_ss = jnp.full(feature_dim, opt_state.step_size)
@@ -741,8 +745,7 @@ def run_normalized_learning_loop[StreamStateT](
         new_ss_norm = jax.lax.cond(
             should_record_ss,
             lambda _: ss_norm.at[recording_idx].set(
-                opt_state.normalizers  # type: ignore[union-attr]
-            ),
+                opt_state.normalizers),
             lambda _: ss_norm,
             None,
         )
@@ -825,17 +828,14 @@ def run_normalized_learning_loop[StreamStateT](
     ss_history_result = StepSizeHistory(
         step_sizes=final_ss_hist,
         bias_step_sizes=final_ss_bias_hist,
-        recording_indices=final_ss_rec,  # type: ignore[arg-type]
-        normalizers=final_ss_norm,
+        recording_indices=final_ss_rec, normalizers=final_ss_norm,
     )
 
     norm_history_result = None
     if normalizer_tracking is not None and final_n_means is not None:
         norm_history_result = NormalizerHistory(
             means=final_n_means,
-            variances=final_n_vars,  # type: ignore[arg-type]
-            recording_indices=final_n_rec,  # type: ignore[arg-type]
-        )
+            variances=final_n_vars, recording_indices=final_n_rec, )
 
     # Return appropriate tuple based on what was tracked
     if ss_history_result is not None and norm_history_result is not None:
@@ -900,10 +900,12 @@ def run_learning_loop_batched[StreamStateT](
             learner, stream, num_steps, key, learner_state, step_size_tracking
         )
         if step_size_tracking is not None:
-            state, metrics, history = result
+            state, metrics, history = cast(
+                tuple[LearnerState, Array, StepSizeHistory], result
+            )
             return state, metrics, history
         else:
-            state, metrics = result
+            state, metrics = cast(tuple[LearnerState, Array], result)
             # Return None for history to maintain consistent output structure
             return state, metrics, None
 
@@ -911,7 +913,7 @@ def run_learning_loop_batched[StreamStateT](
     batched_states, batched_metrics, batched_history = jax.vmap(single_seed_run)(keys)
 
     # Reconstruct batched history if tracking was enabled
-    if step_size_tracking is not None:
+    if step_size_tracking is not None and batched_history is not None:
         batched_step_size_history = StepSizeHistory(
             step_sizes=batched_history.step_sizes,
             bias_step_sizes=batched_history.bias_step_sizes,
@@ -993,16 +995,23 @@ def run_normalized_learning_loop_batched[StreamStateT](
 
         # Unpack based on what tracking was enabled
        if step_size_tracking is not None and normalizer_tracking is not None:
-            state, metrics, ss_history, norm_history = result
+            state, metrics, ss_history, norm_history = cast(
+                tuple[NormalizedLearnerState, Array, StepSizeHistory, NormalizerHistory],
+                result,
+            )
             return state, metrics, ss_history, norm_history
         elif step_size_tracking is not None:
-            state, metrics, ss_history = result
+            state, metrics, ss_history = cast(
+                tuple[NormalizedLearnerState, Array, StepSizeHistory], result
+            )
             return state, metrics, ss_history, None
         elif normalizer_tracking is not None:
-            state, metrics, norm_history = result
+            state, metrics, norm_history = cast(
+                tuple[NormalizedLearnerState, Array, NormalizerHistory], result
+            )
             return state, metrics, None, norm_history
         else:
-            state, metrics = result
+            state, metrics = cast(tuple[NormalizedLearnerState, Array], result)
             return state, metrics, None, None
 
     # vmap over the keys dimension
@@ -1011,7 +1020,7 @@ def run_normalized_learning_loop_batched[StreamStateT](
     )
 
     # Reconstruct batched histories if tracking was enabled
-    if step_size_tracking is not None:
+    if step_size_tracking is not None and batched_ss_history is not None:
         batched_step_size_history = StepSizeHistory(
             step_sizes=batched_ss_history.step_sizes,
             bias_step_sizes=batched_ss_history.bias_step_sizes,
@@ -1021,7 +1030,7 @@ def run_normalized_learning_loop_batched[StreamStateT](
     else:
         batched_step_size_history = None
 
-    if normalizer_tracking is not None:
+    if normalizer_tracking is not None and batched_norm_history is not None:
         batched_normalizer_history = NormalizerHistory(
             means=batched_norm_history.means,
             variances=batched_norm_history.variances,
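The batched loops above combine `jax.vmap` over a stack of PRNG keys with `typing.cast` to narrow the union-typed return for mypy. A standalone sketch of that pattern; the toy `single_seed_run` below is a simplified stand-in, not the framework's implementation.

```python
# Simplified sketch of the pattern above: vmap over per-seed PRNG keys, with
# typing.cast used only to narrow a union-typed result for mypy (no runtime effect).
from typing import cast

import jax
import jax.numpy as jnp
from jax import Array


def single_seed_run(key: Array) -> tuple[Array, Array]:
    # Stand-in for one training run; returns (final_weights, per-step losses).
    weights = jax.random.normal(key, (4,))
    losses = jnp.cumsum(jnp.abs(weights))
    return weights, losses


keys = jax.random.split(jax.random.PRNGKey(0), 8)   # one key per seed
result = jax.vmap(single_seed_run)(keys)             # leading axis = num_seeds
weights_batch, losses_batch = cast(tuple[Array, Array], result)
print(weights_batch.shape, losses_batch.shape)       # (8, 4) (8, 4)
```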
src/alberta_framework/core/normalizers.py
@@ -6,13 +6,14 @@ and variance at every time step, following the principle of temporal uniformity.
 Reference: Welford's online algorithm for numerical stability.
 """
 
-from typing import NamedTuple
-
+import chex
 import jax.numpy as jnp
 from jax import Array
+from jaxtyping import Float
 
 
-class NormalizerState(NamedTuple):
+@chex.dataclass(frozen=True)
+class NormalizerState:
     """State for online feature normalization.
 
     Uses Welford's online algorithm for numerically stable estimation
@@ -25,10 +26,10 @@ class NormalizerState(NamedTuple):
         decay: Exponential decay factor for estimates (1.0 = no decay, pure online)
     """
 
-    mean: Array  # Shape: (feature_dim,)
-    var: Array  # Shape: (feature_dim,)
-    sample_count: Array  # Scalar
-    decay: Array  # Scalar
+    mean: Float[Array, " feature_dim"]
+    var: Float[Array, " feature_dim"]
+    sample_count: Float[Array, ""]
+    decay: Float[Array, ""]
 
 
 class OnlineNormalizer:
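The docstring above refers to Welford's online algorithm. For reference, a generic Welford-style incremental mean/variance update; the framework's `OnlineNormalizer` additionally applies the `decay` factor, which this sketch omits.

```python
# Generic Welford-style incremental mean/variance update, for reference; the
# framework's OnlineNormalizer also applies the `decay` factor, omitted here.
import jax.numpy as jnp
from jax import Array


def welford_step(mean: Array, m2: Array, count: Array, x: Array) -> tuple[Array, Array, Array]:
    count = count + 1.0
    delta = x - mean
    mean = mean + delta / count
    m2 = m2 + delta * (x - mean)   # uses the already-updated mean
    return mean, m2, count


mean, m2, count = jnp.zeros(3), jnp.zeros(3), jnp.asarray(0.0)
for x in jnp.array([[1.0, 2.0, 3.0], [3.0, 2.0, 1.0]]):
    mean, m2, count = welford_step(mean, m2, count, x)
variance = m2 / count              # population variance estimate per feature
```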
src/alberta_framework/core/optimizers.py
@@ -10,15 +10,17 @@ References:
 """
 
 from abc import ABC, abstractmethod
-from typing import NamedTuple
 
+import chex
 import jax.numpy as jnp
 from jax import Array
+from jaxtyping import Float
 
 from alberta_framework.core.types import AutostepState, IDBDState, LMSState
 
 
-class OptimizerUpdate(NamedTuple):
+@chex.dataclass(frozen=True)
+class OptimizerUpdate:
     """Result of an optimizer update step.
 
     Attributes:
@@ -28,8 +30,8 @@ class OptimizerUpdate(NamedTuple):
         metrics: Dictionary of metrics for logging (values are JAX arrays for scan compatibility)
     """
 
-    weight_delta: Array
-    bias_delta: Array
+    weight_delta: Float[Array, " feature_dim"]
+    bias_delta: Float[Array, ""]
     new_state: LMSState | IDBDState | AutostepState
     metrics: dict[str, Array]
 
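For reference, the `weight_delta`/`bias_delta` split maps directly onto the classic LMS rule mentioned elsewhere in this diff ("LMS has a single fixed step-size"). A generic sketch, not the package's implementation:

```python
# Generic LMS sketch (not the package's implementation): a single fixed step-size
# alpha turns the prediction error into the weight/bias deltas carried by OptimizerUpdate.
import jax.numpy as jnp
from jax import Array


def lms_deltas(alpha: Array, error: Array, x: Array) -> tuple[Array, Array]:
    weight_delta = alpha * error * x   # shape (feature_dim,)
    bias_delta = alpha * error         # scalar
    return weight_delta, bias_delta


alpha = jnp.asarray(0.1)
x = jnp.array([1.0, 0.5, -2.0])
error = jnp.asarray(0.3)               # target minus prediction
dw, db = lms_deltas(alpha, error, x)   # applied as weights + dw, bias + db
```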
src/alberta_framework/core/types.py
@@ -1,13 +1,15 @@
 """Type definitions for the Alberta Framework.
 
 This module defines the core data types used throughout the framework,
-following JAX conventions with immutable NamedTuples for state management.
+using chex dataclasses for JAX compatibility and jaxtyping for shape annotations.
 """
 
-from typing import TYPE_CHECKING, NamedTuple
+from typing import TYPE_CHECKING
 
+import chex
 import jax.numpy as jnp
 from jax import Array
+from jaxtyping import Float, Int, PRNGKeyArray
 
 if TYPE_CHECKING:
     from alberta_framework.core.learners import NormalizedLearnerState
@@ -19,7 +21,8 @@ Prediction = Array  # y_t: model output
 Reward = float  # r_t: scalar reward
 
 
-class TimeStep(NamedTuple):
+@chex.dataclass(frozen=True)
+class TimeStep:
     """Single experience from an experience stream.
 
     Attributes:
@@ -27,25 +30,12 @@ class TimeStep(NamedTuple):
         target: Desired output y*_t (for supervised learning)
     """
 
-    observation: Observation
-    target: Target
+    observation: Float[Array, " feature_dim"]
+    target: Float[Array, " 1"]
 
 
-class LearnerState(NamedTuple):
-    """State for a linear learner.
-
-    Attributes:
-        weights: Weight vector for linear prediction
-        bias: Bias term
-        optimizer_state: State maintained by the optimizer
-    """
-
-    weights: Array
-    bias: Array
-    optimizer_state: "LMSState | IDBDState | AutostepState"
-
-
-class LMSState(NamedTuple):
+@chex.dataclass(frozen=True)
+class LMSState:
     """State for the LMS (Least Mean Square) optimizer.
 
     LMS uses a fixed step-size, so state only tracks the step-size parameter.
@@ -54,10 +44,11 @@ class LMSState(NamedTuple):
         step_size: Fixed learning rate alpha
     """
 
-    step_size: Array
+    step_size: Float[Array, ""]
 
 
-class IDBDState(NamedTuple):
+@chex.dataclass(frozen=True)
+class IDBDState:
     """State for the IDBD (Incremental Delta-Bar-Delta) optimizer.
 
     IDBD maintains per-weight adaptive step-sizes that are meta-learned
@@ -73,14 +64,15 @@ class IDBDState(NamedTuple):
         bias_trace: Trace for the bias term
     """
 
-    log_step_sizes: Array  # log(alpha_i) for numerical stability
-    traces: Array  # h_i: trace of weight-feature products
-    meta_step_size: Array  # beta: step-size for the step-sizes
-    bias_step_size: Array  # Step-size for bias
-    bias_trace: Array  # Trace for bias
+    log_step_sizes: Float[Array, " feature_dim"]  # log(alpha_i) for numerical stability
+    traces: Float[Array, " feature_dim"]  # h_i: trace of weight-feature products
+    meta_step_size: Float[Array, ""]  # beta: step-size for the step-sizes
+    bias_step_size: Float[Array, ""]  # Step-size for bias
+    bias_trace: Float[Array, ""]  # Trace for bias
 
 
-class AutostepState(NamedTuple):
+@chex.dataclass(frozen=True)
+class AutostepState:
     """State for the Autostep optimizer.
 
     Autostep is a tuning-free step-size adaptation algorithm that normalizes
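The `IDBDState` fields above (`log_step_sizes`, `traces`, `meta_step_size`) correspond to Sutton's (1992) IDBD rule. A textbook-style sketch of that update; the framework's exact implementation may differ in detail.

```python
# Textbook-style sketch of Sutton's (1992) IDBD update that these fields correspond to;
# the framework's exact implementation may differ in detail.
import jax.numpy as jnp
from jax import Array


def idbd_step(
    w: Array, log_alpha: Array, h: Array, theta: Array, x: Array, target: Array
) -> tuple[Array, Array, Array]:
    delta = target - jnp.dot(w, x)                  # prediction error
    log_alpha = log_alpha + theta * delta * x * h   # meta-learn the (log) step-sizes
    alpha = jnp.exp(log_alpha)                      # per-weight step-sizes
    w = w + alpha * delta * x                       # LMS-style weight update
    h = h * jnp.maximum(0.0, 1.0 - alpha * x * x) + alpha * delta * x  # decaying trace
    return w, log_alpha, h
```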
@@ -100,17 +92,33 @@ class AutostepState(NamedTuple):
         bias_normalizer: Normalizer for the bias gradient
     """
 
-    step_sizes: Array  # alpha_i
-    traces: Array  # h_i
-    normalizers: Array  # v_i: running max of |gradient|
-    meta_step_size: Array  # mu
-    normalizer_decay: Array  # tau
-    bias_step_size: Array
-    bias_trace: Array
-    bias_normalizer: Array
+    step_sizes: Float[Array, " feature_dim"]  # alpha_i
+    traces: Float[Array, " feature_dim"]  # h_i
+    normalizers: Float[Array, " feature_dim"]  # v_i: running max of |gradient|
+    meta_step_size: Float[Array, ""]  # mu
+    normalizer_decay: Float[Array, ""]  # tau
+    bias_step_size: Float[Array, ""]
+    bias_trace: Float[Array, ""]
+    bias_normalizer: Float[Array, ""]
+
+
+@chex.dataclass(frozen=True)
+class LearnerState:
+    """State for a linear learner.
+
+    Attributes:
+        weights: Weight vector for linear prediction
+        bias: Bias term
+        optimizer_state: State maintained by the optimizer
+    """
+
+    weights: Float[Array, " feature_dim"]
+    bias: Float[Array, ""]
+    optimizer_state: LMSState | IDBDState | AutostepState
 
 
-class StepSizeTrackingConfig(NamedTuple):
+@chex.dataclass(frozen=True)
+class StepSizeTrackingConfig:
     """Configuration for recording per-weight step-sizes during training.
 
     Attributes:
@@ -122,7 +130,8 @@ class StepSizeTrackingConfig(NamedTuple):
     include_bias: bool = True
 
 
-class StepSizeHistory(NamedTuple):
+@chex.dataclass(frozen=True)
+class StepSizeHistory:
     """History of per-weight step-sizes recorded during training.
 
     Attributes:
@@ -133,13 +142,14 @@ class StepSizeHistory(NamedTuple):
         shape (num_recordings, num_weights) or None. Only populated for Autostep optimizer.
     """
 
-    step_sizes: Array  # (num_recordings, num_weights)
-    bias_step_sizes: Array | None  # (num_recordings,) or None
-    recording_indices: Array  # (num_recordings,)
-    normalizers: Array | None = None  # (num_recordings, num_weights) - Autostep v_i
+    step_sizes: Float[Array, "num_recordings feature_dim"]
+    bias_step_sizes: Float[Array, " num_recordings"] | None
+    recording_indices: Int[Array, " num_recordings"]
+    normalizers: Float[Array, "num_recordings feature_dim"] | None = None
 
 
-class NormalizerTrackingConfig(NamedTuple):
+@chex.dataclass(frozen=True)
+class NormalizerTrackingConfig:
     """Configuration for recording per-feature normalizer state during training.
 
     Attributes:
@@ -149,7 +159,8 @@
     interval: int
 
 
-class NormalizerHistory(NamedTuple):
+@chex.dataclass(frozen=True)
+class NormalizerHistory:
     """History of per-feature normalizer state recorded during training.
 
     Used for analyzing how the OnlineNormalizer adapts to distribution shifts
@@ -162,12 +173,13 @@
         recording_indices: Step indices where recordings were made, shape (num_recordings,)
     """
 
-    means: Array  # (num_recordings, feature_dim)
-    variances: Array  # (num_recordings, feature_dim)
-    recording_indices: Array  # (num_recordings,)
+    means: Float[Array, "num_recordings feature_dim"]
+    variances: Float[Array, "num_recordings feature_dim"]
+    recording_indices: Int[Array, " num_recordings"]
 
 
-class BatchedLearningResult(NamedTuple):
+@chex.dataclass(frozen=True)
+class BatchedLearningResult:
     """Result from batched learning loop across multiple seeds.
 
     Used with `run_learning_loop_batched` for vmap-based GPU parallelization.
@@ -180,12 +192,13 @@ class BatchedLearningResult(NamedTuple):
         or None if tracking was disabled
     """
 
-    states: "LearnerState"  # Batched: each array has shape (num_seeds, ...)
-    metrics: Array  # Shape: (num_seeds, num_steps, 3)
+    states: LearnerState  # Batched: each array has shape (num_seeds, ...)
+    metrics: Float[Array, "num_seeds num_steps 3"]
     step_size_history: StepSizeHistory | None
 
 
-class BatchedNormalizedResult(NamedTuple):
+@chex.dataclass(frozen=True)
+class BatchedNormalizedResult:
     """Result from batched normalized learning loop across multiple seeds.
 
     Used with `run_normalized_learning_loop_batched` for vmap-based GPU parallelization.
@@ -201,7 +214,7 @@ class BatchedNormalizedResult(NamedTuple):
     """
 
     states: "NormalizedLearnerState"  # Batched: each array has shape (num_seeds, ...)
-    metrics: Array  # Shape: (num_seeds, num_steps, 4)
+    metrics: Float[Array, "num_seeds num_steps 4"]
     step_size_history: StepSizeHistory | None
     normalizer_history: NormalizerHistory | None