alberta-framework 0.2.1__tar.gz → 0.2.2__tar.gz

This diff shows the changes between two publicly available versions of a package that have been released to one of the supported registries. It is provided for informational purposes only and reflects the package contents as they appear in their respective public registries.
Files changed (59)
  1. {alberta_framework-0.2.1 → alberta_framework-0.2.2}/CLAUDE.md +6 -0
  2. {alberta_framework-0.2.1 → alberta_framework-0.2.2}/PKG-INFO +1 -1
  3. {alberta_framework-0.2.1 → alberta_framework-0.2.2}/pyproject.toml +1 -1
  4. {alberta_framework-0.2.1 → alberta_framework-0.2.2}/src/alberta_framework/core/learners.py +19 -10
  5. {alberta_framework-0.2.1 → alberta_framework-0.2.2}/.github/workflows/ci.yml +0 -0
  6. {alberta_framework-0.2.1 → alberta_framework-0.2.2}/.github/workflows/docs.yml +0 -0
  7. {alberta_framework-0.2.1 → alberta_framework-0.2.2}/.github/workflows/publish.yml +0 -0
  8. {alberta_framework-0.2.1 → alberta_framework-0.2.2}/.gitignore +0 -0
  9. {alberta_framework-0.2.1 → alberta_framework-0.2.2}/ALBERTA_PLAN.md +0 -0
  10. {alberta_framework-0.2.1 → alberta_framework-0.2.2}/CHANGELOG.md +0 -0
  11. {alberta_framework-0.2.1 → alberta_framework-0.2.2}/LICENSE +0 -0
  12. {alberta_framework-0.2.1 → alberta_framework-0.2.2}/README.md +0 -0
  13. {alberta_framework-0.2.1 → alberta_framework-0.2.2}/ROADMAP.md +0 -0
  14. {alberta_framework-0.2.1 → alberta_framework-0.2.2}/docs/contributing.md +0 -0
  15. {alberta_framework-0.2.1 → alberta_framework-0.2.2}/docs/gen_ref_pages.py +0 -0
  16. {alberta_framework-0.2.1 → alberta_framework-0.2.2}/docs/getting-started/installation.md +0 -0
  17. {alberta_framework-0.2.1 → alberta_framework-0.2.2}/docs/getting-started/quickstart.md +0 -0
  18. {alberta_framework-0.2.1 → alberta_framework-0.2.2}/docs/guide/concepts.md +0 -0
  19. {alberta_framework-0.2.1 → alberta_framework-0.2.2}/docs/guide/experiments.md +0 -0
  20. {alberta_framework-0.2.1 → alberta_framework-0.2.2}/docs/guide/gymnasium.md +0 -0
  21. {alberta_framework-0.2.1 → alberta_framework-0.2.2}/docs/guide/optimizers.md +0 -0
  22. {alberta_framework-0.2.1 → alberta_framework-0.2.2}/docs/guide/streams.md +0 -0
  23. {alberta_framework-0.2.1 → alberta_framework-0.2.2}/docs/index.md +0 -0
  24. {alberta_framework-0.2.1 → alberta_framework-0.2.2}/docs/javascripts/mathjax.js +0 -0
  25. {alberta_framework-0.2.1 → alberta_framework-0.2.2}/examples/The Alberta Plan/Step1/README.md +0 -0
  26. {alberta_framework-0.2.1 → alberta_framework-0.2.2}/examples/The Alberta Plan/Step1/autostep_comparison.py +0 -0
  27. {alberta_framework-0.2.1 → alberta_framework-0.2.2}/examples/The Alberta Plan/Step1/external_normalization_study.py +0 -0
  28. {alberta_framework-0.2.1 → alberta_framework-0.2.2}/examples/The Alberta Plan/Step1/idbd_lms_autostep_comparison.py +0 -0
  29. {alberta_framework-0.2.1 → alberta_framework-0.2.2}/examples/The Alberta Plan/Step1/normalization_study.py +0 -0
  30. {alberta_framework-0.2.1 → alberta_framework-0.2.2}/examples/The Alberta Plan/Step1/sutton1992_experiment1.py +0 -0
  31. {alberta_framework-0.2.1 → alberta_framework-0.2.2}/examples/The Alberta Plan/Step1/sutton1992_experiment2.py +0 -0
  32. {alberta_framework-0.2.1 → alberta_framework-0.2.2}/examples/gymnasium_reward_prediction.py +0 -0
  33. {alberta_framework-0.2.1 → alberta_framework-0.2.2}/examples/publication_experiment.py +0 -0
  34. {alberta_framework-0.2.1 → alberta_framework-0.2.2}/examples/td_cartpole_lms.py +0 -0
  35. {alberta_framework-0.2.1 → alberta_framework-0.2.2}/mkdocs.yml +0 -0
  36. {alberta_framework-0.2.1 → alberta_framework-0.2.2}/papers/mahmood-msc-thesis-summary.md +0 -0
  37. {alberta_framework-0.2.1 → alberta_framework-0.2.2}/src/alberta_framework/__init__.py +0 -0
  38. {alberta_framework-0.2.1 → alberta_framework-0.2.2}/src/alberta_framework/core/__init__.py +0 -0
  39. {alberta_framework-0.2.1 → alberta_framework-0.2.2}/src/alberta_framework/core/normalizers.py +0 -0
  40. {alberta_framework-0.2.1 → alberta_framework-0.2.2}/src/alberta_framework/core/optimizers.py +0 -0
  41. {alberta_framework-0.2.1 → alberta_framework-0.2.2}/src/alberta_framework/core/types.py +0 -0
  42. {alberta_framework-0.2.1 → alberta_framework-0.2.2}/src/alberta_framework/py.typed +0 -0
  43. {alberta_framework-0.2.1 → alberta_framework-0.2.2}/src/alberta_framework/streams/__init__.py +0 -0
  44. {alberta_framework-0.2.1 → alberta_framework-0.2.2}/src/alberta_framework/streams/base.py +0 -0
  45. {alberta_framework-0.2.1 → alberta_framework-0.2.2}/src/alberta_framework/streams/gymnasium.py +0 -0
  46. {alberta_framework-0.2.1 → alberta_framework-0.2.2}/src/alberta_framework/streams/synthetic.py +0 -0
  47. {alberta_framework-0.2.1 → alberta_framework-0.2.2}/src/alberta_framework/utils/__init__.py +0 -0
  48. {alberta_framework-0.2.1 → alberta_framework-0.2.2}/src/alberta_framework/utils/experiments.py +0 -0
  49. {alberta_framework-0.2.1 → alberta_framework-0.2.2}/src/alberta_framework/utils/export.py +0 -0
  50. {alberta_framework-0.2.1 → alberta_framework-0.2.2}/src/alberta_framework/utils/metrics.py +0 -0
  51. {alberta_framework-0.2.1 → alberta_framework-0.2.2}/src/alberta_framework/utils/statistics.py +0 -0
  52. {alberta_framework-0.2.1 → alberta_framework-0.2.2}/src/alberta_framework/utils/timing.py +0 -0
  53. {alberta_framework-0.2.1 → alberta_framework-0.2.2}/src/alberta_framework/utils/visualization.py +0 -0
  54. {alberta_framework-0.2.1 → alberta_framework-0.2.2}/tests/conftest.py +0 -0
  55. {alberta_framework-0.2.1 → alberta_framework-0.2.2}/tests/test_gymnasium_streams.py +0 -0
  56. {alberta_framework-0.2.1 → alberta_framework-0.2.2}/tests/test_learners.py +0 -0
  57. {alberta_framework-0.2.1 → alberta_framework-0.2.2}/tests/test_normalizers.py +0 -0
  58. {alberta_framework-0.2.1 → alberta_framework-0.2.2}/tests/test_optimizers.py +0 -0
  59. {alberta_framework-0.2.1 → alberta_framework-0.2.2}/tests/test_streams.py +0 -0
@@ -465,3 +465,9 @@ The publish workflow uses OpenID Connect (no API tokens). Configure on PyPI:
465
465
  1. PyPI project → Settings → Publishing → Add GitHub publisher
466
466
  2. Repository: `j-klawson/alberta-framework`, Workflow: `publish.yml`, Environment: `pypi`
467
467
  3. Repeat on TestPyPI with environment: `testpypi`
468
+
469
+ ## Changelog
470
+
471
+ ### v0.2.2 (2026-02-02)
472
+ - Fixed mypy type errors in `run_learning_loop_batched` and `run_normalized_learning_loop_batched` functions
473
+ - Added `typing.cast` to properly handle conditional return type unpacking in batched learning loops
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: alberta-framework
3
- Version: 0.2.1
3
+ Version: 0.2.2
4
4
  Summary: Implementation of the Alberta Plan for AI Research - continual learning with meta-learned step-sizes
5
5
  Project-URL: Homepage, https://github.com/j-klawson/alberta-framework
6
6
  Project-URL: Repository, https://github.com/j-klawson/alberta-framework
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
4
4
 
5
5
  [project]
6
6
  name = "alberta-framework"
7
- version = "0.2.1"
7
+ version = "0.2.2"
8
8
  description = "Implementation of the Alberta Plan for AI Research - continual learning with meta-learned step-sizes"
9
9
  readme = "README.md"
10
10
  license = "Apache-2.0"
@@ -5,7 +5,7 @@ for temporally-uniform learning. Uses JAX's scan for efficient JIT-compiled
5
5
  training loops.
6
6
  """
7
7
 
8
- from typing import NamedTuple
8
+ from typing import NamedTuple, cast
9
9
 
10
10
  import jax
11
11
  import jax.numpy as jnp
@@ -900,10 +900,12 @@ def run_learning_loop_batched[StreamStateT](
900
900
  learner, stream, num_steps, key, learner_state, step_size_tracking
901
901
  )
902
902
  if step_size_tracking is not None:
903
- state, metrics, history = result
903
+ state, metrics, history = cast(
904
+ tuple[LearnerState, Array, StepSizeHistory], result
905
+ )
904
906
  return state, metrics, history
905
907
  else:
906
- state, metrics = result
908
+ state, metrics = cast(tuple[LearnerState, Array], result)
907
909
  # Return None for history to maintain consistent output structure
908
910
  return state, metrics, None
909
911
 
@@ -911,7 +913,7 @@ def run_learning_loop_batched[StreamStateT](
911
913
  batched_states, batched_metrics, batched_history = jax.vmap(single_seed_run)(keys)
912
914
 
913
915
  # Reconstruct batched history if tracking was enabled
914
- if step_size_tracking is not None:
916
+ if step_size_tracking is not None and batched_history is not None:
915
917
  batched_step_size_history = StepSizeHistory(
916
918
  step_sizes=batched_history.step_sizes,
917
919
  bias_step_sizes=batched_history.bias_step_sizes,
@@ -993,16 +995,23 @@ def run_normalized_learning_loop_batched[StreamStateT](
993
995
 
994
996
  # Unpack based on what tracking was enabled
995
997
  if step_size_tracking is not None and normalizer_tracking is not None:
996
- state, metrics, ss_history, norm_history = result
998
+ state, metrics, ss_history, norm_history = cast(
999
+ tuple[NormalizedLearnerState, Array, StepSizeHistory, NormalizerHistory],
1000
+ result,
1001
+ )
997
1002
  return state, metrics, ss_history, norm_history
998
1003
  elif step_size_tracking is not None:
999
- state, metrics, ss_history = result
1004
+ state, metrics, ss_history = cast(
1005
+ tuple[NormalizedLearnerState, Array, StepSizeHistory], result
1006
+ )
1000
1007
  return state, metrics, ss_history, None
1001
1008
  elif normalizer_tracking is not None:
1002
- state, metrics, norm_history = result
1009
+ state, metrics, norm_history = cast(
1010
+ tuple[NormalizedLearnerState, Array, NormalizerHistory], result
1011
+ )
1003
1012
  return state, metrics, None, norm_history
1004
1013
  else:
1005
- state, metrics = result
1014
+ state, metrics = cast(tuple[NormalizedLearnerState, Array], result)
1006
1015
  return state, metrics, None, None
1007
1016
 
1008
1017
  # vmap over the keys dimension
@@ -1011,7 +1020,7 @@ def run_normalized_learning_loop_batched[StreamStateT](
1011
1020
  )
1012
1021
 
1013
1022
  # Reconstruct batched histories if tracking was enabled
1014
- if step_size_tracking is not None:
1023
+ if step_size_tracking is not None and batched_ss_history is not None:
1015
1024
  batched_step_size_history = StepSizeHistory(
1016
1025
  step_sizes=batched_ss_history.step_sizes,
1017
1026
  bias_step_sizes=batched_ss_history.bias_step_sizes,
@@ -1021,7 +1030,7 @@ def run_normalized_learning_loop_batched[StreamStateT](
1021
1030
  else:
1022
1031
  batched_step_size_history = None
1023
1032
 
1024
- if normalizer_tracking is not None:
1033
+ if normalizer_tracking is not None and batched_norm_history is not None:
1025
1034
  batched_normalizer_history = NormalizerHistory(
1026
1035
  means=batched_norm_history.means,
1027
1036
  variances=batched_norm_history.variances,