alberta-framework 0.1.1__tar.gz → 0.2.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (59)
  1. {alberta_framework-0.1.1 → alberta_framework-0.2.0}/CLAUDE.md +55 -2
  2. {alberta_framework-0.1.1 → alberta_framework-0.2.0}/PKG-INFO +1 -1
  3. {alberta_framework-0.1.1 → alberta_framework-0.2.0}/pyproject.toml +1 -1
  4. {alberta_framework-0.1.1 → alberta_framework-0.2.0}/src/alberta_framework/__init__.py +9 -1
  5. {alberta_framework-0.1.1 → alberta_framework-0.2.0}/src/alberta_framework/core/learners.py +192 -0
  6. {alberta_framework-0.1.1 → alberta_framework-0.2.0}/src/alberta_framework/core/types.py +39 -0
  7. {alberta_framework-0.1.1 → alberta_framework-0.2.0}/tests/test_learners.py +249 -0
  8. {alberta_framework-0.1.1 → alberta_framework-0.2.0}/.github/workflows/ci.yml +0 -0
  9. {alberta_framework-0.1.1 → alberta_framework-0.2.0}/.github/workflows/docs.yml +0 -0
  10. {alberta_framework-0.1.1 → alberta_framework-0.2.0}/.github/workflows/publish.yml +0 -0
  11. {alberta_framework-0.1.1 → alberta_framework-0.2.0}/.gitignore +0 -0
  12. {alberta_framework-0.1.1 → alberta_framework-0.2.0}/ALBERTA_PLAN.md +0 -0
  13. {alberta_framework-0.1.1 → alberta_framework-0.2.0}/CHANGELOG.md +0 -0
  14. {alberta_framework-0.1.1 → alberta_framework-0.2.0}/LICENSE +0 -0
  15. {alberta_framework-0.1.1 → alberta_framework-0.2.0}/README.md +0 -0
  16. {alberta_framework-0.1.1 → alberta_framework-0.2.0}/ROADMAP.md +0 -0
  17. {alberta_framework-0.1.1 → alberta_framework-0.2.0}/docs/contributing.md +0 -0
  18. {alberta_framework-0.1.1 → alberta_framework-0.2.0}/docs/gen_ref_pages.py +0 -0
  19. {alberta_framework-0.1.1 → alberta_framework-0.2.0}/docs/getting-started/installation.md +0 -0
  20. {alberta_framework-0.1.1 → alberta_framework-0.2.0}/docs/getting-started/quickstart.md +0 -0
  21. {alberta_framework-0.1.1 → alberta_framework-0.2.0}/docs/guide/concepts.md +0 -0
  22. {alberta_framework-0.1.1 → alberta_framework-0.2.0}/docs/guide/experiments.md +0 -0
  23. {alberta_framework-0.1.1 → alberta_framework-0.2.0}/docs/guide/gymnasium.md +0 -0
  24. {alberta_framework-0.1.1 → alberta_framework-0.2.0}/docs/guide/optimizers.md +0 -0
  25. {alberta_framework-0.1.1 → alberta_framework-0.2.0}/docs/guide/streams.md +0 -0
  26. {alberta_framework-0.1.1 → alberta_framework-0.2.0}/docs/index.md +0 -0
  27. {alberta_framework-0.1.1 → alberta_framework-0.2.0}/docs/javascripts/mathjax.js +0 -0
  28. {alberta_framework-0.1.1 → alberta_framework-0.2.0}/examples/The Alberta Plan/Step1/README.md +0 -0
  29. {alberta_framework-0.1.1 → alberta_framework-0.2.0}/examples/The Alberta Plan/Step1/autostep_comparison.py +0 -0
  30. {alberta_framework-0.1.1 → alberta_framework-0.2.0}/examples/The Alberta Plan/Step1/external_normalization_study.py +0 -0
  31. {alberta_framework-0.1.1 → alberta_framework-0.2.0}/examples/The Alberta Plan/Step1/idbd_lms_autostep_comparison.py +0 -0
  32. {alberta_framework-0.1.1 → alberta_framework-0.2.0}/examples/The Alberta Plan/Step1/normalization_study.py +0 -0
  33. {alberta_framework-0.1.1 → alberta_framework-0.2.0}/examples/The Alberta Plan/Step1/sutton1992_experiment1.py +0 -0
  34. {alberta_framework-0.1.1 → alberta_framework-0.2.0}/examples/The Alberta Plan/Step1/sutton1992_experiment2.py +0 -0
  35. {alberta_framework-0.1.1 → alberta_framework-0.2.0}/examples/gymnasium_reward_prediction.py +0 -0
  36. {alberta_framework-0.1.1 → alberta_framework-0.2.0}/examples/publication_experiment.py +0 -0
  37. {alberta_framework-0.1.1 → alberta_framework-0.2.0}/examples/td_cartpole_lms.py +0 -0
  38. {alberta_framework-0.1.1 → alberta_framework-0.2.0}/mkdocs.yml +0 -0
  39. {alberta_framework-0.1.1 → alberta_framework-0.2.0}/papers/mahmood-msc-thesis-summary.md +0 -0
  40. {alberta_framework-0.1.1 → alberta_framework-0.2.0}/src/alberta_framework/core/__init__.py +0 -0
  41. {alberta_framework-0.1.1 → alberta_framework-0.2.0}/src/alberta_framework/core/normalizers.py +0 -0
  42. {alberta_framework-0.1.1 → alberta_framework-0.2.0}/src/alberta_framework/core/optimizers.py +0 -0
  43. {alberta_framework-0.1.1 → alberta_framework-0.2.0}/src/alberta_framework/py.typed +0 -0
  44. {alberta_framework-0.1.1 → alberta_framework-0.2.0}/src/alberta_framework/streams/__init__.py +0 -0
  45. {alberta_framework-0.1.1 → alberta_framework-0.2.0}/src/alberta_framework/streams/base.py +0 -0
  46. {alberta_framework-0.1.1 → alberta_framework-0.2.0}/src/alberta_framework/streams/gymnasium.py +0 -0
  47. {alberta_framework-0.1.1 → alberta_framework-0.2.0}/src/alberta_framework/streams/synthetic.py +0 -0
  48. {alberta_framework-0.1.1 → alberta_framework-0.2.0}/src/alberta_framework/utils/__init__.py +0 -0
  49. {alberta_framework-0.1.1 → alberta_framework-0.2.0}/src/alberta_framework/utils/experiments.py +0 -0
  50. {alberta_framework-0.1.1 → alberta_framework-0.2.0}/src/alberta_framework/utils/export.py +0 -0
  51. {alberta_framework-0.1.1 → alberta_framework-0.2.0}/src/alberta_framework/utils/metrics.py +0 -0
  52. {alberta_framework-0.1.1 → alberta_framework-0.2.0}/src/alberta_framework/utils/statistics.py +0 -0
  53. {alberta_framework-0.1.1 → alberta_framework-0.2.0}/src/alberta_framework/utils/timing.py +0 -0
  54. {alberta_framework-0.1.1 → alberta_framework-0.2.0}/src/alberta_framework/utils/visualization.py +0 -0
  55. {alberta_framework-0.1.1 → alberta_framework-0.2.0}/tests/conftest.py +0 -0
  56. {alberta_framework-0.1.1 → alberta_framework-0.2.0}/tests/test_gymnasium_streams.py +0 -0
  57. {alberta_framework-0.1.1 → alberta_framework-0.2.0}/tests/test_normalizers.py +0 -0
  58. {alberta_framework-0.1.1 → alberta_framework-0.2.0}/tests/test_optimizers.py +0 -0
  59. {alberta_framework-0.1.1 → alberta_framework-0.2.0}/tests/test_streams.py +0 -0
{alberta_framework-0.1.1 → alberta_framework-0.2.0}/CLAUDE.md
@@ -14,10 +14,10 @@ This framework implements Step 1 of the Alberta Plan: demonstrating that IDBD (I
  ```
  src/alberta_framework/
  ├── core/
- │   ├── types.py          # TimeStep, LearnerState, LMSState, IDBDState, AutostepState, StepSizeTrackingConfig, StepSizeHistory, NormalizerTrackingConfig, NormalizerHistory
+ │   ├── types.py          # TimeStep, LearnerState, LMSState, IDBDState, AutostepState, StepSizeTrackingConfig, StepSizeHistory, NormalizerTrackingConfig, NormalizerHistory, BatchedLearningResult, BatchedNormalizedResult
  │   ├── optimizers.py     # LMS, IDBD, Autostep optimizers
  │   ├── normalizers.py    # OnlineNormalizer, NormalizerState
- │   └── learners.py       # LinearLearner, NormalizedLinearLearner, run_learning_loop, run_normalized_learning_loop, metrics_to_dicts
+ │   └── learners.py       # LinearLearner, NormalizedLinearLearner, run_learning_loop, run_learning_loop_batched, run_normalized_learning_loop, run_normalized_learning_loop_batched, metrics_to_dicts
  ├── streams/
  │   ├── base.py           # ScanStream protocol (pure function interface for jax.lax.scan)
  │   ├── synthetic.py      # RandomWalkStream, AbruptChangeStream, CyclicStream, PeriodicChangeStream, ScaledStreamWrapper, DynamicScaleShiftStream, ScaleDriftStream
@@ -187,6 +187,59 @@ Return value depends on tracking options:
  - normalizer_tracking only: `(state, metrics, norm_history)` — 3-tuple
  - Both: `(state, metrics, ss_history, norm_history)` — 4-tuple

+ ### Batched Learning Loops (vmap-based GPU Parallelization)
+ The `run_learning_loop_batched` and `run_normalized_learning_loop_batched` functions use `jax.vmap` to run multiple seeds in parallel, typically achieving 2-5x speedup over sequential execution:
+
+ ```python
+ import jax.random as jr
+ from alberta_framework import (
+     LinearLearner, IDBD, RandomWalkStream,
+     run_learning_loop_batched, StepSizeTrackingConfig
+ )
+
+ stream = RandomWalkStream(feature_dim=10)
+ learner = LinearLearner(optimizer=IDBD())
+
+ # Run 30 seeds in parallel
+ keys = jr.split(jr.key(42), 30)
+ result = run_learning_loop_batched(learner, stream, num_steps=10000, keys=keys)
+
+ # result.metrics has shape (30, 10000, 3)
+ # result.states.weights has shape (30, 10)
+ mean_error = result.metrics[:, :, 0].mean(axis=0)  # Average squared error over seeds
+
+ # With step-size tracking
+ config = StepSizeTrackingConfig(interval=100)
+ result = run_learning_loop_batched(
+     learner, stream, num_steps=10000, keys=keys, step_size_tracking=config
+ )
+ # result.step_size_history.step_sizes has shape (30, 100, 10)
+ ```
+
+ Key features:
+ - `jax.vmap` parallelizes over seeds, not steps — memory scales with num_seeds
+ - `jax.lax.scan` processes steps sequentially within each seed
+ - Returns `BatchedLearningResult` or `BatchedNormalizedResult` NamedTuples
+ - Tracking histories get batched shapes: `(num_seeds, num_recordings, ...)`
+ - Same initial state used for all seeds (controlled variation via different keys)
+
+ For normalized learners:
+ ```python
+ from alberta_framework import (
+     NormalizedLinearLearner, run_normalized_learning_loop_batched,
+     NormalizerTrackingConfig
+ )
+
+ learner = NormalizedLinearLearner(optimizer=IDBD())
+ result = run_normalized_learning_loop_batched(
+     learner, stream, num_steps=10000, keys=keys,
+     step_size_tracking=StepSizeTrackingConfig(interval=100),
+     normalizer_tracking=NormalizerTrackingConfig(interval=100)
+ )
+ # result.metrics has shape (30, 10000, 4)
+ # result.step_size_history and result.normalizer_history both batched
+ ```
+
  ## Gymnasium Integration

  Wrap Gymnasium RL environments as experience streams for the framework.
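The "Key features" list added to CLAUDE.md above describes the composition these loops rely on: `jax.vmap` adds a seed axis while `jax.lax.scan` keeps steps sequential within each seed. A minimal, self-contained sketch of that pattern follows; it is illustrative only (a toy LMS update with made-up names such as `run_one_seed`), not code from the package:

```python
import jax
import jax.numpy as jnp
import jax.random as jr


def run_one_seed(key: jax.Array, num_steps: int = 1000, dim: int = 10):
    """Sequential learning for a single seed: lax.scan over steps."""
    w_key, data_key = jr.split(key)
    true_w = jr.normal(w_key, (dim,))

    def one_step(w, step_key):
        x = jr.normal(step_key, (dim,))
        error = x @ true_w - x @ w
        w = w + 0.01 * error * x    # plain LMS update
        return w, error**2          # carry weights, emit squared error

    step_keys = jr.split(data_key, num_steps)
    final_w, sq_errors = jax.lax.scan(one_step, jnp.zeros(dim), step_keys)
    return final_w, sq_errors


# vmap adds a leading seed axis; steps remain sequential inside each seed.
keys = jr.split(jr.key(42), 30)
final_ws, sq_errors = jax.vmap(run_one_seed)(keys)  # shapes (30, 10) and (30, 1000)
mean_curve = sq_errors.mean(axis=0)                 # averaged learning curve
```

As the bullets note, memory grows with the number of seeds, since every seed's scan runs concurrently under vmap.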
{alberta_framework-0.1.1 → alberta_framework-0.2.0}/PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: alberta-framework
- Version: 0.1.1
+ Version: 0.2.0
  Summary: Implementation of the Alberta Plan for AI Research - continual learning with meta-learned step-sizes
  Project-URL: Homepage, https://github.com/j-klawson/alberta-framework
  Project-URL: Repository, https://github.com/j-klawson/alberta-framework
{alberta_framework-0.1.1 → alberta_framework-0.2.0}/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"

  [project]
  name = "alberta-framework"
- version = "0.1.1"
+ version = "0.2.0"
  description = "Implementation of the Alberta Plan for AI Research - continual learning with meta-learned step-sizes"
  readme = "README.md"
  license = "Apache-2.0"
{alberta_framework-0.1.1 → alberta_framework-0.2.0}/src/alberta_framework/__init__.py
@@ -39,7 +39,7 @@ References
  - Tuning-free Step-size Adaptation (Mahmood et al., 2012)
  """

- __version__ = "0.1.0"
+ __version__ = "0.2.0"

  # Core types
  # Learners
@@ -50,7 +50,9 @@ from alberta_framework.core.learners import (
      UpdateResult,
      metrics_to_dicts,
      run_learning_loop,
+     run_learning_loop_batched,
      run_normalized_learning_loop,
+     run_normalized_learning_loop_batched,
  )

  # Normalizers
@@ -64,6 +66,8 @@ from alberta_framework.core.normalizers import (
  from alberta_framework.core.optimizers import IDBD, LMS, Autostep, Optimizer
  from alberta_framework.core.types import (
      AutostepState,
+     BatchedLearningResult,
+     BatchedNormalizedResult,
      IDBDState,
      LearnerState,
      LMSState,
@@ -138,6 +142,8 @@ __all__ = [
      "__version__",
      # Types
      "AutostepState",
+     "BatchedLearningResult",
+     "BatchedNormalizedResult",
      "IDBDState",
      "LMSState",
      "LearnerState",
@@ -164,7 +170,9 @@ __all__ = [
      "NormalizedLearnerState",
      "NormalizedLinearLearner",
      "run_learning_loop",
+     "run_learning_loop_batched",
      "run_normalized_learning_loop",
+     "run_normalized_learning_loop_batched",
      "metrics_to_dicts",
      # Streams - protocol
      "ScanStream",
{alberta_framework-0.1.1 → alberta_framework-0.2.0}/src/alberta_framework/core/learners.py
@@ -15,6 +15,8 @@ from alberta_framework.core.normalizers import NormalizerState, OnlineNormalizer
  from alberta_framework.core.optimizers import LMS, Optimizer
  from alberta_framework.core.types import (
      AutostepState,
+     BatchedLearningResult,
+     BatchedNormalizedResult,
      IDBDState,
      LearnerState,
      LMSState,
@@ -846,6 +848,196 @@ def run_normalized_learning_loop[StreamStateT](
      return final_learner, metrics


+ def run_learning_loop_batched[StreamStateT](
+     learner: LinearLearner,
+     stream: ScanStream[StreamStateT],
+     num_steps: int,
+     keys: Array,
+     learner_state: LearnerState | None = None,
+     step_size_tracking: StepSizeTrackingConfig | None = None,
+ ) -> BatchedLearningResult:
+     """Run learning loop across multiple seeds in parallel using jax.vmap.
+
+     This function provides GPU parallelization for multi-seed experiments,
+     typically achieving 2-5x speedup over sequential execution.
+
+     Args:
+         learner: The learner to train
+         stream: Experience stream providing (observation, target) pairs
+         num_steps: Number of learning steps to run per seed
+         keys: JAX random keys with shape (num_seeds,) or (num_seeds, 2)
+         learner_state: Initial state (if None, will be initialized from stream).
+             The same initial state is used for all seeds.
+         step_size_tracking: Optional config for recording per-weight step-sizes.
+             When provided, history arrays have shape (num_seeds, num_recordings, ...)
+
+     Returns:
+         BatchedLearningResult containing:
+         - states: Batched final states with shape (num_seeds, ...) for each array
+         - metrics: Array of shape (num_seeds, num_steps, 3)
+         - step_size_history: Batched history or None if tracking disabled
+
+     Examples:
+         ```python
+         import jax.random as jr
+         from alberta_framework import LinearLearner, IDBD, RandomWalkStream
+         from alberta_framework import run_learning_loop_batched
+
+         stream = RandomWalkStream(feature_dim=10)
+         learner = LinearLearner(optimizer=IDBD())
+
+         # Run 30 seeds in parallel
+         keys = jr.split(jr.key(42), 30)
+         result = run_learning_loop_batched(learner, stream, num_steps=10000, keys=keys)
+
+         # result.metrics has shape (30, 10000, 3)
+         mean_error = result.metrics[:, :, 0].mean(axis=0)  # Average over seeds
+         ```
+     """
+     # Define single-seed function that returns consistent structure
+     def single_seed_run(key: Array) -> tuple[LearnerState, Array, StepSizeHistory | None]:
+         result = run_learning_loop(
+             learner, stream, num_steps, key, learner_state, step_size_tracking
+         )
+         if step_size_tracking is not None:
+             state, metrics, history = result
+             return state, metrics, history
+         else:
+             state, metrics = result
+             # Return None for history to maintain consistent output structure
+             return state, metrics, None
+
+     # vmap over the keys dimension
+     batched_states, batched_metrics, batched_history = jax.vmap(single_seed_run)(keys)
+
+     # Reconstruct batched history if tracking was enabled
+     if step_size_tracking is not None:
+         batched_step_size_history = StepSizeHistory(
+             step_sizes=batched_history.step_sizes,
+             bias_step_sizes=batched_history.bias_step_sizes,
+             recording_indices=batched_history.recording_indices,
+             normalizers=batched_history.normalizers,
+         )
+     else:
+         batched_step_size_history = None
+
+     return BatchedLearningResult(
+         states=batched_states,
+         metrics=batched_metrics,
+         step_size_history=batched_step_size_history,
+     )
+
+
+ def run_normalized_learning_loop_batched[StreamStateT](
+     learner: NormalizedLinearLearner,
+     stream: ScanStream[StreamStateT],
+     num_steps: int,
+     keys: Array,
+     learner_state: NormalizedLearnerState | None = None,
+     step_size_tracking: StepSizeTrackingConfig | None = None,
+     normalizer_tracking: NormalizerTrackingConfig | None = None,
+ ) -> BatchedNormalizedResult:
+     """Run normalized learning loop across multiple seeds in parallel using jax.vmap.
+
+     This function provides GPU parallelization for multi-seed experiments with
+     normalized learners, typically achieving 2-5x speedup over sequential execution.
+
+     Args:
+         learner: The normalized learner to train
+         stream: Experience stream providing (observation, target) pairs
+         num_steps: Number of learning steps to run per seed
+         keys: JAX random keys with shape (num_seeds,) or (num_seeds, 2)
+         learner_state: Initial state (if None, will be initialized from stream).
+             The same initial state is used for all seeds.
+         step_size_tracking: Optional config for recording per-weight step-sizes.
+             When provided, history arrays have shape (num_seeds, num_recordings, ...)
+         normalizer_tracking: Optional config for recording normalizer state.
+             When provided, history arrays have shape (num_seeds, num_recordings, ...)
+
+     Returns:
+         BatchedNormalizedResult containing:
+         - states: Batched final states with shape (num_seeds, ...) for each array
+         - metrics: Array of shape (num_seeds, num_steps, 4)
+         - step_size_history: Batched history or None if tracking disabled
+         - normalizer_history: Batched history or None if tracking disabled
+
+     Examples:
+         ```python
+         import jax.random as jr
+         from alberta_framework import NormalizedLinearLearner, IDBD, RandomWalkStream
+         from alberta_framework import run_normalized_learning_loop_batched
+
+         stream = RandomWalkStream(feature_dim=10)
+         learner = NormalizedLinearLearner(optimizer=IDBD())
+
+         # Run 30 seeds in parallel
+         keys = jr.split(jr.key(42), 30)
+         result = run_normalized_learning_loop_batched(
+             learner, stream, num_steps=10000, keys=keys
+         )
+
+         # result.metrics has shape (30, 10000, 4)
+         mean_error = result.metrics[:, :, 0].mean(axis=0)  # Average over seeds
+         ```
+     """
+     # Define single-seed function that returns consistent structure
+     def single_seed_run(
+         key: Array,
+     ) -> tuple[
+         NormalizedLearnerState, Array, StepSizeHistory | None, NormalizerHistory | None
+     ]:
+         result = run_normalized_learning_loop(
+             learner, stream, num_steps, key, learner_state,
+             step_size_tracking, normalizer_tracking
+         )
+
+         # Unpack based on what tracking was enabled
+         if step_size_tracking is not None and normalizer_tracking is not None:
+             state, metrics, ss_history, norm_history = result
+             return state, metrics, ss_history, norm_history
+         elif step_size_tracking is not None:
+             state, metrics, ss_history = result
+             return state, metrics, ss_history, None
+         elif normalizer_tracking is not None:
+             state, metrics, norm_history = result
+             return state, metrics, None, norm_history
+         else:
+             state, metrics = result
+             return state, metrics, None, None
+
+     # vmap over the keys dimension
+     batched_states, batched_metrics, batched_ss_history, batched_norm_history = (
+         jax.vmap(single_seed_run)(keys)
+     )
+
+     # Reconstruct batched histories if tracking was enabled
+     if step_size_tracking is not None:
+         batched_step_size_history = StepSizeHistory(
+             step_sizes=batched_ss_history.step_sizes,
+             bias_step_sizes=batched_ss_history.bias_step_sizes,
+             recording_indices=batched_ss_history.recording_indices,
+             normalizers=batched_ss_history.normalizers,
+         )
+     else:
+         batched_step_size_history = None
+
+     if normalizer_tracking is not None:
+         batched_normalizer_history = NormalizerHistory(
+             means=batched_norm_history.means,
+             variances=batched_norm_history.variances,
+             recording_indices=batched_norm_history.recording_indices,
+         )
+     else:
+         batched_normalizer_history = None
+
+     return BatchedNormalizedResult(
+         states=batched_states,
+         metrics=batched_metrics,
+         step_size_history=batched_step_size_history,
+         normalizer_history=batched_normalizer_history,
+     )
+
+
  def metrics_to_dicts(metrics: Array, normalized: bool = False) -> list[dict[str, float]]:
      """Convert metrics array to list of dicts for backward compatibility.
 
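The docstrings above say `keys` may have shape `(num_seeds,)` or `(num_seeds, 2)`. Assuming that refers to JAX's two PRNG key representations, both forms come directly from `jax.random` (a small illustration, not package code):

```python
import jax.random as jr

# New-style typed key arrays: splitting gives shape (num_seeds,)
typed_keys = jr.split(jr.key(42), 30)
print(typed_keys.shape)    # (30,)

# Legacy uint32 keys: splitting gives shape (num_seeds, 2)
legacy_keys = jr.split(jr.PRNGKey(42), 30)
print(legacy_keys.shape)   # (30, 2)
```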
{alberta_framework-0.1.1 → alberta_framework-0.2.0}/src/alberta_framework/core/types.py
@@ -164,6 +164,45 @@ class NormalizerHistory(NamedTuple):
      recording_indices: Array  # (num_recordings,)


+ class BatchedLearningResult(NamedTuple):
+     """Result from batched learning loop across multiple seeds.
+
+     Used with `run_learning_loop_batched` for vmap-based GPU parallelization.
+
+     Attributes:
+         states: Batched learner states - each array has shape (num_seeds, ...)
+         metrics: Metrics array with shape (num_seeds, num_steps, 3)
+             where columns are [squared_error, error, mean_step_size]
+         step_size_history: Optional step-size history with batched shapes,
+             or None if tracking was disabled
+     """
+
+     states: "LearnerState"  # Batched: each array has shape (num_seeds, ...)
+     metrics: Array  # Shape: (num_seeds, num_steps, 3)
+     step_size_history: StepSizeHistory | None
+
+
+ class BatchedNormalizedResult(NamedTuple):
+     """Result from batched normalized learning loop across multiple seeds.
+
+     Used with `run_normalized_learning_loop_batched` for vmap-based GPU parallelization.
+
+     Attributes:
+         states: Batched normalized learner states - each array has shape (num_seeds, ...)
+         metrics: Metrics array with shape (num_seeds, num_steps, 4)
+             where columns are [squared_error, error, mean_step_size, normalizer_mean_var]
+         step_size_history: Optional step-size history with batched shapes,
+             or None if tracking was disabled
+         normalizer_history: Optional normalizer history with batched shapes,
+             or None if tracking was disabled
+     """
+
+     states: "NormalizedLearnerState"  # Batched: each array has shape (num_seeds, ...)
+     metrics: Array  # Shape: (num_seeds, num_steps, 4)
+     step_size_history: StepSizeHistory | None
+     normalizer_history: NormalizerHistory | None
+
+
  def create_lms_state(step_size: float = 0.01) -> LMSState:
      """Create initial LMS optimizer state.
 
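Given the shapes documented in `BatchedLearningResult`, a run is usually summarized by reducing over the leading seed axis. A sketch using only the API shown in this diff (the specific stream, optimizer, and seed count are arbitrary choices):

```python
import jax.numpy as jnp
import jax.random as jr

from alberta_framework import (
    IDBD,
    LinearLearner,
    RandomWalkStream,
    run_learning_loop_batched,
)

keys = jr.split(jr.key(0), 30)
result = run_learning_loop_batched(
    LinearLearner(optimizer=IDBD()),
    RandomWalkStream(feature_dim=10),
    num_steps=10_000,
    keys=keys,
)

sq_err = result.metrics[:, :, 0]               # column 0: squared_error, shape (30, 10000)
mean_curve = sq_err.mean(axis=0)               # mean learning curve over seeds
sem_curve = sq_err.std(axis=0) / jnp.sqrt(30)  # standard error across seeds
print(result.states.weights.shape)             # (30, 10): per-seed final weights
```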
{alberta_framework-0.1.1 → alberta_framework-0.2.0}/tests/test_learners.py
@@ -6,6 +6,8 @@ import pytest

  from alberta_framework import (
      Autostep,
+     BatchedLearningResult,
+     BatchedNormalizedResult,
      IDBD,
      LMS,
      LinearLearner,
@@ -17,7 +19,9 @@ from alberta_framework import (
      StepSizeTrackingConfig,
      metrics_to_dicts,
      run_learning_loop,
+     run_learning_loop_batched,
      run_normalized_learning_loop,
+     run_normalized_learning_loop_batched,
  )


@@ -526,3 +530,248 @@ class TestNormalizedLearningLoopTracking:
          # Should record at steps 0, 25, 50, 75
          expected_indices = jnp.array([0, 25, 50, 75])
          assert jnp.allclose(norm_history.recording_indices, expected_indices)
+
+
+ class TestBatchedLearningLoop:
+     """Tests for run_learning_loop_batched."""
+
+     def test_batched_returns_correct_shapes(self, rng_key):
+         """Batched loop should return metrics with shape (num_seeds, num_steps, 3)."""
+         num_seeds = 5
+         num_steps = 100
+         feature_dim = 10
+
+         stream = RandomWalkStream(feature_dim=feature_dim)
+         learner = LinearLearner(optimizer=IDBD())
+         keys = jr.split(rng_key, num_seeds)
+
+         result = run_learning_loop_batched(learner, stream, num_steps, keys)
+
+         assert isinstance(result, BatchedLearningResult)
+         assert result.metrics.shape == (num_seeds, num_steps, 3)
+         assert result.states.weights.shape == (num_seeds, feature_dim)
+         assert result.states.bias.shape == (num_seeds,)
+         assert result.step_size_history is None
+
+     def test_batched_matches_sequential(self, rng_key):
+         """Batched results should match sequential execution."""
+         num_seeds = 3
+         num_steps = 50
+         feature_dim = 5
+
+         stream = RandomWalkStream(feature_dim=feature_dim)
+         learner = LinearLearner(optimizer=IDBD())
+         keys = jr.split(rng_key, num_seeds)
+
+         # Run batched
+         batched_result = run_learning_loop_batched(learner, stream, num_steps, keys)
+
+         # Run sequential
+         sequential_metrics = []
+         for i in range(num_seeds):
+             _, metrics = run_learning_loop(learner, stream, num_steps, keys[i])
+             sequential_metrics.append(metrics)
+         sequential_metrics = jnp.stack(sequential_metrics)
+
+         # Should match
+         assert jnp.allclose(batched_result.metrics, sequential_metrics)
+
+     def test_batched_with_step_size_tracking(self, rng_key):
+         """Batched loop should support step-size tracking."""
+         num_seeds = 4
+         num_steps = 100
+         feature_dim = 8
+         interval = 10
+         expected_recordings = num_steps // interval
+
+         stream = RandomWalkStream(feature_dim=feature_dim)
+         learner = LinearLearner(optimizer=Autostep())
+         keys = jr.split(rng_key, num_seeds)
+         config = StepSizeTrackingConfig(interval=interval)
+
+         result = run_learning_loop_batched(
+             learner, stream, num_steps, keys, step_size_tracking=config
+         )
+
+         assert result.step_size_history is not None
+         assert result.step_size_history.step_sizes.shape == (
+             num_seeds, expected_recordings, feature_dim
+         )
+         assert result.step_size_history.bias_step_sizes.shape == (
+             num_seeds, expected_recordings
+         )
+         assert result.step_size_history.recording_indices.shape == (
+             num_seeds, expected_recordings
+         )
+         # Autostep should have normalizers tracked
+         assert result.step_size_history.normalizers is not None
+         assert result.step_size_history.normalizers.shape == (
+             num_seeds, expected_recordings, feature_dim
+         )
+
+     def test_batched_without_tracking_has_none_history(self, rng_key):
+         """When tracking disabled, step_size_history should be None."""
+         stream = RandomWalkStream(feature_dim=5)
+         learner = LinearLearner(optimizer=IDBD())
+         keys = jr.split(rng_key, 3)
+
+         result = run_learning_loop_batched(learner, stream, num_steps=50, keys=keys)
+
+         assert result.step_size_history is None
+
+     def test_batched_deterministic_with_same_keys(self, rng_key):
+         """Same keys should produce same results."""
+         stream = RandomWalkStream(feature_dim=5)
+         learner = LinearLearner(optimizer=IDBD())
+         keys = jr.split(rng_key, 4)
+
+         result1 = run_learning_loop_batched(learner, stream, num_steps=50, keys=keys)
+         result2 = run_learning_loop_batched(learner, stream, num_steps=50, keys=keys)
+
+         assert jnp.allclose(result1.metrics, result2.metrics)
+         assert jnp.allclose(result1.states.weights, result2.states.weights)
+
+     def test_batched_different_keys_different_results(self, rng_key):
+         """Different keys should produce different results."""
+         stream = RandomWalkStream(feature_dim=5)
+         learner = LinearLearner(optimizer=IDBD())
+
+         keys1 = jr.split(jr.key(42), 3)
+         keys2 = jr.split(jr.key(123), 3)
+
+         result1 = run_learning_loop_batched(learner, stream, num_steps=50, keys=keys1)
+         result2 = run_learning_loop_batched(learner, stream, num_steps=50, keys=keys2)
+
+         assert not jnp.allclose(result1.metrics, result2.metrics)
+
+     def test_batched_with_lms_optimizer(self, rng_key):
+         """Batched loop should work with LMS optimizer."""
+         stream = RandomWalkStream(feature_dim=5)
+         learner = LinearLearner(optimizer=LMS(step_size=0.01))
+         keys = jr.split(rng_key, 3)
+
+         result = run_learning_loop_batched(learner, stream, num_steps=50, keys=keys)
+
+         assert result.metrics.shape == (3, 50, 3)
+         # LMS doesn't report mean_step_size in metrics (defaults to 0.0)
+         assert jnp.allclose(result.metrics[:, :, 2], 0.0)
+
+
+ class TestBatchedNormalizedLearningLoop:
+     """Tests for run_normalized_learning_loop_batched."""
+
+     def test_normalized_batched_returns_correct_shapes(self, rng_key):
+         """Batched normalized loop should return metrics with shape (num_seeds, num_steps, 4)."""
+         num_seeds = 5
+         num_steps = 100
+         feature_dim = 10
+
+         stream = RandomWalkStream(feature_dim=feature_dim)
+         learner = NormalizedLinearLearner(optimizer=IDBD())
+         keys = jr.split(rng_key, num_seeds)
+
+         result = run_normalized_learning_loop_batched(learner, stream, num_steps, keys)
+
+         assert isinstance(result, BatchedNormalizedResult)
+         assert result.metrics.shape == (num_seeds, num_steps, 4)
+         assert result.states.learner_state.weights.shape == (num_seeds, feature_dim)
+         assert result.states.normalizer_state.mean.shape == (num_seeds, feature_dim)
+         assert result.step_size_history is None
+         assert result.normalizer_history is None
+
+     def test_normalized_batched_matches_sequential(self, rng_key):
+         """Batched normalized results should match sequential execution."""
+         num_seeds = 3
+         num_steps = 50
+         feature_dim = 5
+
+         stream = RandomWalkStream(feature_dim=feature_dim)
+         learner = NormalizedLinearLearner(optimizer=IDBD())
+         keys = jr.split(rng_key, num_seeds)
+
+         # Run batched
+         batched_result = run_normalized_learning_loop_batched(
+             learner, stream, num_steps, keys
+         )
+
+         # Run sequential
+         sequential_metrics = []
+         for i in range(num_seeds):
+             _, metrics = run_normalized_learning_loop(learner, stream, num_steps, keys[i])
+             sequential_metrics.append(metrics)
+         sequential_metrics = jnp.stack(sequential_metrics)
+
+         # Should match
+         assert jnp.allclose(batched_result.metrics, sequential_metrics)
+
+     def test_normalized_batched_with_both_tracking(self, rng_key):
+         """Batched normalized loop should support both tracking options."""
+         num_seeds = 4
+         num_steps = 100
+         feature_dim = 8
+         ss_interval = 10
+         norm_interval = 20
+         ss_recordings = num_steps // ss_interval
+         norm_recordings = num_steps // norm_interval
+
+         stream = RandomWalkStream(feature_dim=feature_dim)
+         learner = NormalizedLinearLearner(optimizer=Autostep())
+         keys = jr.split(rng_key, num_seeds)
+         ss_config = StepSizeTrackingConfig(interval=ss_interval)
+         norm_config = NormalizerTrackingConfig(interval=norm_interval)
+
+         result = run_normalized_learning_loop_batched(
+             learner, stream, num_steps, keys,
+             step_size_tracking=ss_config, normalizer_tracking=norm_config
+         )
+
+         # Step-size history
+         assert result.step_size_history is not None
+         assert result.step_size_history.step_sizes.shape == (
+             num_seeds, ss_recordings, feature_dim
+         )
+         # Autostep normalizers
+         assert result.step_size_history.normalizers is not None
+
+         # Normalizer history
+         assert result.normalizer_history is not None
+         assert result.normalizer_history.means.shape == (
+             num_seeds, norm_recordings, feature_dim
+         )
+         assert result.normalizer_history.variances.shape == (
+             num_seeds, norm_recordings, feature_dim
+         )
+
+     def test_normalized_batched_step_size_only(self, rng_key):
+         """Batched normalized loop with only step-size tracking."""
+         num_seeds = 3
+         num_steps = 50
+
+         stream = RandomWalkStream(feature_dim=5)
+         learner = NormalizedLinearLearner(optimizer=IDBD())
+         keys = jr.split(rng_key, num_seeds)
+         ss_config = StepSizeTrackingConfig(interval=10)
+
+         result = run_normalized_learning_loop_batched(
+             learner, stream, num_steps, keys, step_size_tracking=ss_config
+         )
+
+         assert result.step_size_history is not None
+         assert result.normalizer_history is None
+
+     def test_normalized_batched_normalizer_only(self, rng_key):
+         """Batched normalized loop with only normalizer tracking."""
+         num_seeds = 3
+         num_steps = 50
+
+         stream = RandomWalkStream(feature_dim=5)
+         learner = NormalizedLinearLearner(optimizer=IDBD())
+         keys = jr.split(rng_key, num_seeds)
+         norm_config = NormalizerTrackingConfig(interval=10)
+
+         result = run_normalized_learning_loop_batched(
+             learner, stream, num_steps, keys, normalizer_tracking=norm_config
+         )
+
+         assert result.step_size_history is None
+         assert result.normalizer_history is not None