alberta-framework 0.1.1__py3-none-any.whl → 0.2.1__py3-none-any.whl

This diff compares publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between versions as they appear in their respective public registries.
@@ -39,7 +39,7 @@ References
  - Tuning-free Step-size Adaptation (Mahmood et al., 2012)
  """

- __version__ = "0.1.0"
+ __version__ = "0.2.0"

  # Core types
  # Learners
@@ -50,7 +50,9 @@ from alberta_framework.core.learners import (
      UpdateResult,
      metrics_to_dicts,
      run_learning_loop,
+     run_learning_loop_batched,
      run_normalized_learning_loop,
+     run_normalized_learning_loop_batched,
  )

  # Normalizers
@@ -64,6 +66,8 @@ from alberta_framework.core.normalizers import (
  from alberta_framework.core.optimizers import IDBD, LMS, Autostep, Optimizer
  from alberta_framework.core.types import (
      AutostepState,
+     BatchedLearningResult,
+     BatchedNormalizedResult,
      IDBDState,
      LearnerState,
      LMSState,
@@ -138,6 +142,8 @@ __all__ = [
      "__version__",
      # Types
      "AutostepState",
+     "BatchedLearningResult",
+     "BatchedNormalizedResult",
      "IDBDState",
      "LMSState",
      "LearnerState",
@@ -164,7 +170,9 @@ __all__ = [
      "NormalizedLearnerState",
      "NormalizedLinearLearner",
      "run_learning_loop",
+     "run_learning_loop_batched",
      "run_normalized_learning_loop",
+     "run_normalized_learning_loop_batched",
      "metrics_to_dicts",
      # Streams - protocol
      "ScanStream",
@@ -15,6 +15,8 @@ from alberta_framework.core.normalizers import NormalizerState, OnlineNormalizer
  from alberta_framework.core.optimizers import LMS, Optimizer
  from alberta_framework.core.types import (
      AutostepState,
+     BatchedLearningResult,
+     BatchedNormalizedResult,
      IDBDState,
      LearnerState,
      LMSState,
@@ -846,6 +848,196 @@ def run_normalized_learning_loop[StreamStateT](
      return final_learner, metrics


+ def run_learning_loop_batched[StreamStateT](
+     learner: LinearLearner,
+     stream: ScanStream[StreamStateT],
+     num_steps: int,
+     keys: Array,
+     learner_state: LearnerState | None = None,
+     step_size_tracking: StepSizeTrackingConfig | None = None,
+ ) -> BatchedLearningResult:
+     """Run learning loop across multiple seeds in parallel using jax.vmap.
+
+     This function provides GPU parallelization for multi-seed experiments,
+     typically achieving 2-5x speedup over sequential execution.
+
+     Args:
+         learner: The learner to train
+         stream: Experience stream providing (observation, target) pairs
+         num_steps: Number of learning steps to run per seed
+         keys: JAX random keys with shape (num_seeds,) or (num_seeds, 2)
+         learner_state: Initial state (if None, will be initialized from stream).
+             The same initial state is used for all seeds.
+         step_size_tracking: Optional config for recording per-weight step-sizes.
+             When provided, history arrays have shape (num_seeds, num_recordings, ...)
+
+     Returns:
+         BatchedLearningResult containing:
+         - states: Batched final states with shape (num_seeds, ...) for each array
+         - metrics: Array of shape (num_seeds, num_steps, 3)
+         - step_size_history: Batched history or None if tracking disabled
+
+     Examples:
+         ```python
+         import jax.random as jr
+         from alberta_framework import LinearLearner, IDBD, RandomWalkStream
+         from alberta_framework import run_learning_loop_batched
+
+         stream = RandomWalkStream(feature_dim=10)
+         learner = LinearLearner(optimizer=IDBD())
+
+         # Run 30 seeds in parallel
+         keys = jr.split(jr.key(42), 30)
+         result = run_learning_loop_batched(learner, stream, num_steps=10000, keys=keys)
+
+         # result.metrics has shape (30, 10000, 3)
+         mean_error = result.metrics[:, :, 0].mean(axis=0)  # Average over seeds
+         ```
+     """
+     # Define single-seed function that returns consistent structure
+     def single_seed_run(key: Array) -> tuple[LearnerState, Array, StepSizeHistory | None]:
+         result = run_learning_loop(
+             learner, stream, num_steps, key, learner_state, step_size_tracking
+         )
+         if step_size_tracking is not None:
+             state, metrics, history = result
+             return state, metrics, history
+         else:
+             state, metrics = result
+             # Return None for history to maintain consistent output structure
+             return state, metrics, None
+
+     # vmap over the keys dimension
+     batched_states, batched_metrics, batched_history = jax.vmap(single_seed_run)(keys)
+
+     # Reconstruct batched history if tracking was enabled
+     if step_size_tracking is not None:
+         batched_step_size_history = StepSizeHistory(
+             step_sizes=batched_history.step_sizes,
+             bias_step_sizes=batched_history.bias_step_sizes,
+             recording_indices=batched_history.recording_indices,
+             normalizers=batched_history.normalizers,
+         )
+     else:
+         batched_step_size_history = None
+
+     return BatchedLearningResult(
+         states=batched_states,
+         metrics=batched_metrics,
+         step_size_history=batched_step_size_history,
+     )
+
+
+ def run_normalized_learning_loop_batched[StreamStateT](
+     learner: NormalizedLinearLearner,
+     stream: ScanStream[StreamStateT],
+     num_steps: int,
+     keys: Array,
+     learner_state: NormalizedLearnerState | None = None,
+     step_size_tracking: StepSizeTrackingConfig | None = None,
+     normalizer_tracking: NormalizerTrackingConfig | None = None,
+ ) -> BatchedNormalizedResult:
+     """Run normalized learning loop across multiple seeds in parallel using jax.vmap.
+
+     This function provides GPU parallelization for multi-seed experiments with
+     normalized learners, typically achieving 2-5x speedup over sequential execution.
+
+     Args:
+         learner: The normalized learner to train
+         stream: Experience stream providing (observation, target) pairs
+         num_steps: Number of learning steps to run per seed
+         keys: JAX random keys with shape (num_seeds,) or (num_seeds, 2)
+         learner_state: Initial state (if None, will be initialized from stream).
+             The same initial state is used for all seeds.
+         step_size_tracking: Optional config for recording per-weight step-sizes.
+             When provided, history arrays have shape (num_seeds, num_recordings, ...)
+         normalizer_tracking: Optional config for recording normalizer state.
+             When provided, history arrays have shape (num_seeds, num_recordings, ...)
+
+     Returns:
+         BatchedNormalizedResult containing:
+         - states: Batched final states with shape (num_seeds, ...) for each array
+         - metrics: Array of shape (num_seeds, num_steps, 4)
+         - step_size_history: Batched history or None if tracking disabled
+         - normalizer_history: Batched history or None if tracking disabled
+
+     Examples:
+         ```python
+         import jax.random as jr
+         from alberta_framework import NormalizedLinearLearner, IDBD, RandomWalkStream
+         from alberta_framework import run_normalized_learning_loop_batched
+
+         stream = RandomWalkStream(feature_dim=10)
+         learner = NormalizedLinearLearner(optimizer=IDBD())
+
+         # Run 30 seeds in parallel
+         keys = jr.split(jr.key(42), 30)
+         result = run_normalized_learning_loop_batched(
+             learner, stream, num_steps=10000, keys=keys
+         )
+
+         # result.metrics has shape (30, 10000, 4)
+         mean_error = result.metrics[:, :, 0].mean(axis=0)  # Average over seeds
+         ```
+     """
+     # Define single-seed function that returns consistent structure
+     def single_seed_run(
+         key: Array,
+     ) -> tuple[
+         NormalizedLearnerState, Array, StepSizeHistory | None, NormalizerHistory | None
+     ]:
+         result = run_normalized_learning_loop(
+             learner, stream, num_steps, key, learner_state,
+             step_size_tracking, normalizer_tracking
+         )
+
+         # Unpack based on what tracking was enabled
+         if step_size_tracking is not None and normalizer_tracking is not None:
+             state, metrics, ss_history, norm_history = result
+             return state, metrics, ss_history, norm_history
+         elif step_size_tracking is not None:
+             state, metrics, ss_history = result
+             return state, metrics, ss_history, None
+         elif normalizer_tracking is not None:
+             state, metrics, norm_history = result
+             return state, metrics, None, norm_history
+         else:
+             state, metrics = result
+             return state, metrics, None, None
+
+     # vmap over the keys dimension
+     batched_states, batched_metrics, batched_ss_history, batched_norm_history = (
+         jax.vmap(single_seed_run)(keys)
+     )
+
+     # Reconstruct batched histories if tracking was enabled
+     if step_size_tracking is not None:
+         batched_step_size_history = StepSizeHistory(
+             step_sizes=batched_ss_history.step_sizes,
+             bias_step_sizes=batched_ss_history.bias_step_sizes,
+             recording_indices=batched_ss_history.recording_indices,
+             normalizers=batched_ss_history.normalizers,
+         )
+     else:
+         batched_step_size_history = None
+
+     if normalizer_tracking is not None:
+         batched_normalizer_history = NormalizerHistory(
+             means=batched_norm_history.means,
+             variances=batched_norm_history.variances,
+             recording_indices=batched_norm_history.recording_indices,
+         )
+     else:
+         batched_normalizer_history = None
+
+     return BatchedNormalizedResult(
+         states=batched_states,
+         metrics=batched_metrics,
+         step_size_history=batched_step_size_history,
+         normalizer_history=batched_normalizer_history,
+     )
+
+
  def metrics_to_dicts(metrics: Array, normalized: bool = False) -> list[dict[str, float]]:
      """Convert metrics array to list of dicts for backward compatibility.

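
The docstrings above accept `keys` of shape `(num_seeds,)` or `(num_seeds, 2)`; both shapes fall out of `jax.random`, depending on whether new-style typed keys or legacy `uint32` keys are in use. A short usage sketch under those assumptions; the aggregation at the end is illustrative analysis code, not part of the package:

```python
import jax.numpy as jnp
import jax.random as jr

from alberta_framework import IDBD, LinearLearner, RandomWalkStream
from alberta_framework import run_learning_loop_batched

stream = RandomWalkStream(feature_dim=10)
learner = LinearLearner(optimizer=IDBD())

# Both key conventions satisfy the documented shapes:
typed_keys = jr.split(jr.key(0), 30)       # shape (30,), new-style typed keys
legacy_keys = jr.split(jr.PRNGKey(0), 30)  # shape (30, 2), legacy uint32 keys

result = run_learning_loop_batched(learner, stream, num_steps=10_000, keys=typed_keys)

# Aggregate the squared-error column (index 0) across the seed axis.
sq_err = result.metrics[:, :, 0]           # (num_seeds, num_steps)
mean_curve = sq_err.mean(axis=0)
std_err = sq_err.std(axis=0) / jnp.sqrt(sq_err.shape[0])
```
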
@@ -4,11 +4,14 @@ This module defines the core data types used throughout the framework,
  following JAX conventions with immutable NamedTuples for state management.
  """

- from typing import NamedTuple
+ from typing import TYPE_CHECKING, NamedTuple

  import jax.numpy as jnp
  from jax import Array

+ if TYPE_CHECKING:
+     from alberta_framework.core.learners import NormalizedLearnerState
+
  # Type aliases for clarity
  Observation = Array  # x_t: feature vector
  Target = Array  # y*_t: desired output
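
One note on the `TYPE_CHECKING` guard added above: `learners.py` already imports names from `types.py` at runtime, so an unconditional reverse import would be circular. The guard makes `NormalizedLearnerState` visible to type checkers without executing the import at runtime, which is why the annotation later in this diff is the string `"NormalizedLearnerState"`. A minimal sketch of the pattern, with hypothetical module and class names:

```python
# a.py (hypothetical): imports b at runtime
from b import Thing

# b.py (hypothetical): needs a's type only for annotations
from typing import TYPE_CHECKING, NamedTuple

if TYPE_CHECKING:
    from a import Result  # seen by type checkers, never executed at runtime

class Holder(NamedTuple):
    value: "Result"  # string annotation defers evaluation
```
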
@@ -164,6 +167,45 @@ class NormalizerHistory(NamedTuple):
      recording_indices: Array  # (num_recordings,)


+ class BatchedLearningResult(NamedTuple):
+     """Result from batched learning loop across multiple seeds.
+
+     Used with `run_learning_loop_batched` for vmap-based GPU parallelization.
+
+     Attributes:
+         states: Batched learner states - each array has shape (num_seeds, ...)
+         metrics: Metrics array with shape (num_seeds, num_steps, 3)
+             where columns are [squared_error, error, mean_step_size]
+         step_size_history: Optional step-size history with batched shapes,
+             or None if tracking was disabled
+     """
+
+     states: "LearnerState"  # Batched: each array has shape (num_seeds, ...)
+     metrics: Array  # Shape: (num_seeds, num_steps, 3)
+     step_size_history: StepSizeHistory | None
+
+
+ class BatchedNormalizedResult(NamedTuple):
+     """Result from batched normalized learning loop across multiple seeds.
+
+     Used with `run_normalized_learning_loop_batched` for vmap-based GPU parallelization.
+
+     Attributes:
+         states: Batched normalized learner states - each array has shape (num_seeds, ...)
+         metrics: Metrics array with shape (num_seeds, num_steps, 4)
+             where columns are [squared_error, error, mean_step_size, normalizer_mean_var]
+         step_size_history: Optional step-size history with batched shapes,
+             or None if tracking was disabled
+         normalizer_history: Optional normalizer history with batched shapes,
+             or None if tracking was disabled
+     """
+
+     states: "NormalizedLearnerState"  # Batched: each array has shape (num_seeds, ...)
+     metrics: Array  # Shape: (num_seeds, num_steps, 4)
+     step_size_history: StepSizeHistory | None
+     normalizer_history: NormalizerHistory | None
+
+
  def create_lms_state(step_size: float = 0.01) -> LMSState:
      """Create initial LMS optimizer state.

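
Since both result types are NamedTuples whose array fields carry a leading seed axis, one seed's results can be sliced out of every field at once with a pytree map. A sketch, assuming the result types above; `select_seed` is a hypothetical helper, while `jax.tree_util.tree_map` is standard JAX:

```python
import jax

def select_seed(batched_result, i: int):
    """Slice seed i out of every array leaf of a batched result.

    None fields (disabled tracking histories) are empty pytree nodes,
    so they pass through untouched.
    """
    return jax.tree_util.tree_map(lambda leaf: leaf[i], batched_result)

# e.g. the final LearnerState and (num_steps, 3) metrics of seed 7:
# single = select_seed(result, 7)
```
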
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: alberta-framework
- Version: 0.1.1
+ Version: 0.2.1
  Summary: Implementation of the Alberta Plan for AI Research - continual learning with meta-learned step-sizes
  Project-URL: Homepage, https://github.com/j-klawson/alberta-framework
  Project-URL: Repository, https://github.com/j-klawson/alberta-framework
@@ -1,10 +1,10 @@
- alberta_framework/__init__.py,sha256=LUrsm6WFh5-Mxg78d1G-Qe015nkGgcCDhSw5lf3UkFo,5460
+ alberta_framework/__init__.py,sha256=gAafDDmkivDdfnvDVff9zbVY9ilzqqfJ9KvpbRegKqs,5726
  alberta_framework/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  alberta_framework/core/__init__.py,sha256=PSrC4zSxgm_6YXWEQ80aZaunpbQ58QexxKmDDU-jp6c,522
- alberta_framework/core/learners.py,sha256=dnRQ5B16oGYpamDJIRYzR54ED9bvW0lpa8c_suC6YBA,29879
+ alberta_framework/core/learners.py,sha256=khZYkae5rlIyV13BW3-hrtPSjGFXPj2IUTM1z74xTTA,37724
  alberta_framework/core/normalizers.py,sha256=Z_d3H17qoXh87DE7k41imvWzkVJQ2xQgDUP7GYSNzAY,5903
  alberta_framework/core/optimizers.py,sha256=OefVuDDG1phh1QQIUyVPsQckl41VrpWFG7hY2eqyc64,14585
- alberta_framework/core/types.py,sha256=mtpVEr2qJ0XzZyjOsUdChmS7T7mrXBDMHb-jfkrT9JY,7503
+ alberta_framework/core/types.py,sha256=svV2Q5-0bj7reQ_hh-pRGp2wYfde5VWgTiRm4hUCDKI,9297
  alberta_framework/streams/__init__.py,sha256=bsDgWjWjotDQHMI2lno3dgk8N14pd-2mYAQpXAtCPx4,2035
  alberta_framework/streams/base.py,sha256=9rJxvUgmzd5u2bRV4vi5PxhUvj39EZTD4bZHo-Ptn-U,2168
  alberta_framework/streams/gymnasium.py,sha256=s733X7aEgy05hcSazjZEhBiJChtEL7uVpxwh0fXBQZA,21980
@@ -16,7 +16,7 @@ alberta_framework/utils/metrics.py,sha256=1cryNJoboO67vvRhausaucbYZFgdL_06vaf08U
  alberta_framework/utils/statistics.py,sha256=4fbzNlmsdUaM5lLW1BhL5B5MUpnqimQlwJklZ4x0y0U,15416
  alberta_framework/utils/timing.py,sha256=JOLq8CpCAV7LWOWkftxefduSFjaXnVwal1MFBKEMdJI,4049
  alberta_framework/utils/visualization.py,sha256=PmKBD3KGabNhgDizcNiGJEbVCyDL1YMUE5yTwgJHu2o,17924
- alberta_framework-0.1.1.dist-info/METADATA,sha256=Ny-LxHiqZVNXZbu5f8ZyBSLCEZd2KsBhA9iROV7tNiU,7763
- alberta_framework-0.1.1.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
- alberta_framework-0.1.1.dist-info/licenses/LICENSE,sha256=TI1avodt5mvxz7sunyxIa0HlNgLQcmKNLeRjCVcgKmE,10754
- alberta_framework-0.1.1.dist-info/RECORD,,
+ alberta_framework-0.2.1.dist-info/METADATA,sha256=pJ7SujFDZXrWxqjpSK8NFSBoYJpN5VfLneFnAmYG3hw,7763
+ alberta_framework-0.2.1.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
+ alberta_framework-0.2.1.dist-info/licenses/LICENSE,sha256=TI1avodt5mvxz7sunyxIa0HlNgLQcmKNLeRjCVcgKmE,10754
+ alberta_framework-0.2.1.dist-info/RECORD,,