alberta-framework 0.1.1__py3-none-any.whl → 0.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- alberta_framework/__init__.py +9 -1
- alberta_framework/core/learners.py +192 -0
- alberta_framework/core/types.py +43 -1
- {alberta_framework-0.1.1.dist-info → alberta_framework-0.2.1.dist-info}/METADATA +1 -1
- {alberta_framework-0.1.1.dist-info → alberta_framework-0.2.1.dist-info}/RECORD +7 -7
- {alberta_framework-0.1.1.dist-info → alberta_framework-0.2.1.dist-info}/WHEEL +0 -0
- {alberta_framework-0.1.1.dist-info → alberta_framework-0.2.1.dist-info}/licenses/LICENSE +0 -0
alberta_framework/__init__.py
CHANGED
|
@@ -39,7 +39,7 @@ References
|
|
|
39
39
|
- Tuning-free Step-size Adaptation (Mahmood et al., 2012)
|
|
40
40
|
"""
|
|
41
41
|
|
|
42
|
-
__version__ = "0.
|
|
42
|
+
__version__ = "0.2.1"  # keep in sync with the distribution version in METADATA
|
|
43
43
|
|
|
44
44
|
# Core types
|
|
45
45
|
# Learners
|
|
@@ -50,7 +50,9 @@ from alberta_framework.core.learners import (
|
|
|
50
50
|
UpdateResult,
|
|
51
51
|
metrics_to_dicts,
|
|
52
52
|
run_learning_loop,
|
|
53
|
+
run_learning_loop_batched,
|
|
53
54
|
run_normalized_learning_loop,
|
|
55
|
+
run_normalized_learning_loop_batched,
|
|
54
56
|
)
|
|
55
57
|
|
|
56
58
|
# Normalizers
|
|
@@ -64,6 +66,8 @@ from alberta_framework.core.normalizers import (
|
|
|
64
66
|
from alberta_framework.core.optimizers import IDBD, LMS, Autostep, Optimizer
|
|
65
67
|
from alberta_framework.core.types import (
|
|
66
68
|
AutostepState,
|
|
69
|
+
BatchedLearningResult,
|
|
70
|
+
BatchedNormalizedResult,
|
|
67
71
|
IDBDState,
|
|
68
72
|
LearnerState,
|
|
69
73
|
LMSState,
|
|
@@ -138,6 +142,8 @@ __all__ = [
|
|
|
138
142
|
"__version__",
|
|
139
143
|
# Types
|
|
140
144
|
"AutostepState",
|
|
145
|
+
"BatchedLearningResult",
|
|
146
|
+
"BatchedNormalizedResult",
|
|
141
147
|
"IDBDState",
|
|
142
148
|
"LMSState",
|
|
143
149
|
"LearnerState",
|
|
@@ -164,7 +170,9 @@ __all__ = [
|
|
|
164
170
|
"NormalizedLearnerState",
|
|
165
171
|
"NormalizedLinearLearner",
|
|
166
172
|
"run_learning_loop",
|
|
173
|
+
"run_learning_loop_batched",
|
|
167
174
|
"run_normalized_learning_loop",
|
|
175
|
+
"run_normalized_learning_loop_batched",
|
|
168
176
|
"metrics_to_dicts",
|
|
169
177
|
# Streams - protocol
|
|
170
178
|
"ScanStream",
|
|
@@ -15,6 +15,8 @@ from alberta_framework.core.normalizers import NormalizerState, OnlineNormalizer
|
|
|
15
15
|
from alberta_framework.core.optimizers import LMS, Optimizer
|
|
16
16
|
from alberta_framework.core.types import (
|
|
17
17
|
AutostepState,
|
|
18
|
+
BatchedLearningResult,
|
|
19
|
+
BatchedNormalizedResult,
|
|
18
20
|
IDBDState,
|
|
19
21
|
LearnerState,
|
|
20
22
|
LMSState,
|
|
@@ -846,6 +848,196 @@ def run_normalized_learning_loop[StreamStateT](
|
|
|
846
848
|
return final_learner, metrics
|
|
847
849
|
|
|
848
850
|
|
|
851
|
+
def run_learning_loop_batched[StreamStateT](
|
|
852
|
+
learner: LinearLearner,
|
|
853
|
+
stream: ScanStream[StreamStateT],
|
|
854
|
+
num_steps: int,
|
|
855
|
+
keys: Array,
|
|
856
|
+
learner_state: LearnerState | None = None,
|
|
857
|
+
step_size_tracking: StepSizeTrackingConfig | None = None,
|
|
858
|
+
) -> BatchedLearningResult:
|
|
859
|
+
"""Run learning loop across multiple seeds in parallel using jax.vmap.
|
|
860
|
+
|
|
861
|
+
This function provides GPU parallelization for multi-seed experiments,
|
|
862
|
+
typically achieving 2-5x speedup over sequential execution.
|
|
863
|
+
|
|
864
|
+
Args:
|
|
865
|
+
learner: The learner to train
|
|
866
|
+
stream: Experience stream providing (observation, target) pairs
|
|
867
|
+
num_steps: Number of learning steps to run per seed
|
|
868
|
+
keys: JAX random keys with shape (num_seeds,) or (num_seeds, 2)
|
|
869
|
+
learner_state: Initial state (if None, will be initialized from stream).
|
|
870
|
+
The same initial state is used for all seeds.
|
|
871
|
+
step_size_tracking: Optional config for recording per-weight step-sizes.
|
|
872
|
+
When provided, history arrays have shape (num_seeds, num_recordings, ...)
|
|
873
|
+
|
|
874
|
+
Returns:
|
|
875
|
+
BatchedLearningResult containing:
|
|
876
|
+
- states: Batched final states with shape (num_seeds, ...) for each array
|
|
877
|
+
- metrics: Array of shape (num_seeds, num_steps, 3)
|
|
878
|
+
- step_size_history: Batched history or None if tracking disabled
|
|
879
|
+
|
|
880
|
+
Examples:
|
|
881
|
+
```python
|
|
882
|
+
import jax.random as jr
|
|
883
|
+
from alberta_framework import LinearLearner, IDBD, RandomWalkStream
|
|
884
|
+
from alberta_framework import run_learning_loop_batched
|
|
885
|
+
|
|
886
|
+
stream = RandomWalkStream(feature_dim=10)
|
|
887
|
+
learner = LinearLearner(optimizer=IDBD())
|
|
888
|
+
|
|
889
|
+
# Run 30 seeds in parallel
|
|
890
|
+
keys = jr.split(jr.key(42), 30)
|
|
891
|
+
result = run_learning_loop_batched(learner, stream, num_steps=10000, keys=keys)
|
|
892
|
+
|
|
893
|
+
# result.metrics has shape (30, 10000, 3)
|
|
894
|
+
mean_error = result.metrics[:, :, 0].mean(axis=0) # Average over seeds
|
|
895
|
+
```
|
|
896
|
+
"""
|
|
897
|
+
# Define single-seed function that returns consistent structure
|
|
898
|
+
def single_seed_run(key: Array) -> tuple[LearnerState, Array, StepSizeHistory | None]:
|
|
899
|
+
result = run_learning_loop(
|
|
900
|
+
learner, stream, num_steps, key, learner_state, step_size_tracking
|
|
901
|
+
)
|
|
902
|
+
if step_size_tracking is not None:
|
|
903
|
+
state, metrics, history = result
|
|
904
|
+
return state, metrics, history
|
|
905
|
+
else:
|
|
906
|
+
state, metrics = result
|
|
907
|
+
# Return None for history to maintain consistent output structure
|
|
908
|
+
return state, metrics, None
|
|
909
|
+
|
|
910
|
+
# vmap over the keys dimension
|
|
911
|
+
batched_states, batched_metrics, batched_history = jax.vmap(single_seed_run)(keys)
|
|
912
|
+
|
|
913
|
+
# Reconstruct batched history if tracking was enabled
|
|
914
|
+
if step_size_tracking is not None:
|
|
915
|
+
batched_step_size_history = StepSizeHistory(
|
|
916
|
+
step_sizes=batched_history.step_sizes,
|
|
917
|
+
bias_step_sizes=batched_history.bias_step_sizes,
|
|
918
|
+
recording_indices=batched_history.recording_indices,
|
|
919
|
+
normalizers=batched_history.normalizers,
|
|
920
|
+
)
|
|
921
|
+
else:
|
|
922
|
+
batched_step_size_history = None
|
|
923
|
+
|
|
924
|
+
return BatchedLearningResult(
|
|
925
|
+
states=batched_states,
|
|
926
|
+
metrics=batched_metrics,
|
|
927
|
+
step_size_history=batched_step_size_history,
|
|
928
|
+
)
|
|
929
|
+
|
|
930
|
+
|
|
931
|
+
def run_normalized_learning_loop_batched[StreamStateT](
|
|
932
|
+
learner: NormalizedLinearLearner,
|
|
933
|
+
stream: ScanStream[StreamStateT],
|
|
934
|
+
num_steps: int,
|
|
935
|
+
keys: Array,
|
|
936
|
+
learner_state: NormalizedLearnerState | None = None,
|
|
937
|
+
step_size_tracking: StepSizeTrackingConfig | None = None,
|
|
938
|
+
normalizer_tracking: NormalizerTrackingConfig | None = None,
|
|
939
|
+
) -> BatchedNormalizedResult:
|
|
940
|
+
"""Run normalized learning loop across multiple seeds in parallel using jax.vmap.
|
|
941
|
+
|
|
942
|
+
This function provides GPU parallelization for multi-seed experiments with
|
|
943
|
+
normalized learners, typically achieving 2-5x speedup over sequential execution.
|
|
944
|
+
|
|
945
|
+
Args:
|
|
946
|
+
learner: The normalized learner to train
|
|
947
|
+
stream: Experience stream providing (observation, target) pairs
|
|
948
|
+
num_steps: Number of learning steps to run per seed
|
|
949
|
+
keys: JAX random keys with shape (num_seeds,) or (num_seeds, 2)
|
|
950
|
+
learner_state: Initial state (if None, will be initialized from stream).
|
|
951
|
+
The same initial state is used for all seeds.
|
|
952
|
+
step_size_tracking: Optional config for recording per-weight step-sizes.
|
|
953
|
+
When provided, history arrays have shape (num_seeds, num_recordings, ...)
|
|
954
|
+
normalizer_tracking: Optional config for recording normalizer state.
|
|
955
|
+
When provided, history arrays have shape (num_seeds, num_recordings, ...)
|
|
956
|
+
|
|
957
|
+
Returns:
|
|
958
|
+
BatchedNormalizedResult containing:
|
|
959
|
+
- states: Batched final states with shape (num_seeds, ...) for each array
|
|
960
|
+
- metrics: Array of shape (num_seeds, num_steps, 4)
|
|
961
|
+
- step_size_history: Batched history or None if tracking disabled
|
|
962
|
+
- normalizer_history: Batched history or None if tracking disabled
|
|
963
|
+
|
|
964
|
+
Examples:
|
|
965
|
+
```python
|
|
966
|
+
import jax.random as jr
|
|
967
|
+
from alberta_framework import NormalizedLinearLearner, IDBD, RandomWalkStream
|
|
968
|
+
from alberta_framework import run_normalized_learning_loop_batched
|
|
969
|
+
|
|
970
|
+
stream = RandomWalkStream(feature_dim=10)
|
|
971
|
+
learner = NormalizedLinearLearner(optimizer=IDBD())
|
|
972
|
+
|
|
973
|
+
# Run 30 seeds in parallel
|
|
974
|
+
keys = jr.split(jr.key(42), 30)
|
|
975
|
+
result = run_normalized_learning_loop_batched(
|
|
976
|
+
learner, stream, num_steps=10000, keys=keys
|
|
977
|
+
)
|
|
978
|
+
|
|
979
|
+
# result.metrics has shape (30, 10000, 4)
|
|
980
|
+
mean_error = result.metrics[:, :, 0].mean(axis=0) # Average over seeds
|
|
981
|
+
```
|
|
982
|
+
"""
|
|
983
|
+
# Define single-seed function that returns consistent structure
|
|
984
|
+
def single_seed_run(
|
|
985
|
+
key: Array,
|
|
986
|
+
) -> tuple[
|
|
987
|
+
NormalizedLearnerState, Array, StepSizeHistory | None, NormalizerHistory | None
|
|
988
|
+
]:
|
|
989
|
+
result = run_normalized_learning_loop(
|
|
990
|
+
learner, stream, num_steps, key, learner_state,
|
|
991
|
+
step_size_tracking, normalizer_tracking
|
|
992
|
+
)
|
|
993
|
+
|
|
994
|
+
# Unpack based on what tracking was enabled
|
|
995
|
+
if step_size_tracking is not None and normalizer_tracking is not None:
|
|
996
|
+
state, metrics, ss_history, norm_history = result
|
|
997
|
+
return state, metrics, ss_history, norm_history
|
|
998
|
+
elif step_size_tracking is not None:
|
|
999
|
+
state, metrics, ss_history = result
|
|
1000
|
+
return state, metrics, ss_history, None
|
|
1001
|
+
elif normalizer_tracking is not None:
|
|
1002
|
+
state, metrics, norm_history = result
|
|
1003
|
+
return state, metrics, None, norm_history
|
|
1004
|
+
else:
|
|
1005
|
+
state, metrics = result
|
|
1006
|
+
return state, metrics, None, None
|
|
1007
|
+
|
|
1008
|
+
# vmap over the keys dimension
|
|
1009
|
+
batched_states, batched_metrics, batched_ss_history, batched_norm_history = (
|
|
1010
|
+
jax.vmap(single_seed_run)(keys)
|
|
1011
|
+
)
|
|
1012
|
+
|
|
1013
|
+
# Reconstruct batched histories if tracking was enabled
|
|
1014
|
+
if step_size_tracking is not None:
|
|
1015
|
+
batched_step_size_history = StepSizeHistory(
|
|
1016
|
+
step_sizes=batched_ss_history.step_sizes,
|
|
1017
|
+
bias_step_sizes=batched_ss_history.bias_step_sizes,
|
|
1018
|
+
recording_indices=batched_ss_history.recording_indices,
|
|
1019
|
+
normalizers=batched_ss_history.normalizers,
|
|
1020
|
+
)
|
|
1021
|
+
else:
|
|
1022
|
+
batched_step_size_history = None
|
|
1023
|
+
|
|
1024
|
+
if normalizer_tracking is not None:
|
|
1025
|
+
batched_normalizer_history = NormalizerHistory(
|
|
1026
|
+
means=batched_norm_history.means,
|
|
1027
|
+
variances=batched_norm_history.variances,
|
|
1028
|
+
recording_indices=batched_norm_history.recording_indices,
|
|
1029
|
+
)
|
|
1030
|
+
else:
|
|
1031
|
+
batched_normalizer_history = None
|
|
1032
|
+
|
|
1033
|
+
return BatchedNormalizedResult(
|
|
1034
|
+
states=batched_states,
|
|
1035
|
+
metrics=batched_metrics,
|
|
1036
|
+
step_size_history=batched_step_size_history,
|
|
1037
|
+
normalizer_history=batched_normalizer_history,
|
|
1038
|
+
)
|
|
1039
|
+
|
|
1040
|
+
|
|
849
1041
|
def metrics_to_dicts(metrics: Array, normalized: bool = False) -> list[dict[str, float]]:
|
|
850
1042
|
"""Convert metrics array to list of dicts for backward compatibility.
|
|
851
1043
|
|
alberta_framework/core/types.py
CHANGED
|
@@ -4,11 +4,14 @@ This module defines the core data types used throughout the framework,
|
|
|
4
4
|
following JAX conventions with immutable NamedTuples for state management.
|
|
5
5
|
"""
|
|
6
6
|
|
|
7
|
-
from typing import NamedTuple
|
|
7
|
+
from typing import TYPE_CHECKING, NamedTuple
|
|
8
8
|
|
|
9
9
|
import jax.numpy as jnp
|
|
10
10
|
from jax import Array
|
|
11
11
|
|
|
12
|
+
if TYPE_CHECKING:
|
|
13
|
+
from alberta_framework.core.learners import NormalizedLearnerState
|
|
14
|
+
|
|
12
15
|
# Type aliases for clarity
|
|
13
16
|
Observation = Array # x_t: feature vector
|
|
14
17
|
Target = Array # y*_t: desired output
|
|
@@ -164,6 +167,45 @@ class NormalizerHistory(NamedTuple):
|
|
|
164
167
|
recording_indices: Array # (num_recordings,)
|
|
165
168
|
|
|
166
169
|
|
|
170
|
+
class BatchedLearningResult(NamedTuple):
    """Stacked per-seed outputs of ``run_learning_loop_batched``.

    The seeds are run through ``jax.vmap``, so every array in this result
    carries an extra leading axis of length ``num_seeds``.

    Attributes:
        states: Final learner states; every array leaf is shaped (num_seeds, ...)
        metrics: Array shaped (num_seeds, num_steps, 3); the last axis holds
            [squared_error, error, mean_step_size]
        step_size_history: Step-size history whose arrays have a leading seed
            axis, or None when step-size tracking was not requested
    """

    states: "LearnerState"  # every array leaf: (num_seeds, ...)
    metrics: Array  # (num_seeds, num_steps, 3)
    step_size_history: StepSizeHistory | None  # None => tracking disabled
|
|
186
|
+
|
|
187
|
+
|
|
188
|
+
class BatchedNormalizedResult(NamedTuple):
    """Stacked per-seed outputs of ``run_normalized_learning_loop_batched``.

    The seeds are run through ``jax.vmap``, so every array in this result
    carries an extra leading axis of length ``num_seeds``.

    Attributes:
        states: Final normalized learner states; every array leaf is shaped
            (num_seeds, ...)
        metrics: Array shaped (num_seeds, num_steps, 4); the last axis holds
            [squared_error, error, mean_step_size, normalizer_mean_var]
        step_size_history: Step-size history whose arrays have a leading seed
            axis, or None when step-size tracking was not requested
        normalizer_history: Normalizer history whose arrays have a leading seed
            axis, or None when normalizer tracking was not requested
    """

    # Quoted because NormalizedLearnerState is only imported under TYPE_CHECKING
    # in this module.
    states: "NormalizedLearnerState"  # every array leaf: (num_seeds, ...)
    metrics: Array  # (num_seeds, num_steps, 4)
    step_size_history: StepSizeHistory | None  # None => tracking disabled
    normalizer_history: NormalizerHistory | None  # None => tracking disabled
|
|
207
|
+
|
|
208
|
+
|
|
167
209
|
def create_lms_state(step_size: float = 0.01) -> LMSState:
|
|
168
210
|
"""Create initial LMS optimizer state.
|
|
169
211
|
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: alberta-framework
|
|
3
|
-
Version: 0.1.1
|
|
3
|
+
Version: 0.2.1
|
|
4
4
|
Summary: Implementation of the Alberta Plan for AI Research - continual learning with meta-learned step-sizes
|
|
5
5
|
Project-URL: Homepage, https://github.com/j-klawson/alberta-framework
|
|
6
6
|
Project-URL: Repository, https://github.com/j-klawson/alberta-framework
|
|
@@ -1,10 +1,10 @@
|
|
|
1
|
-
alberta_framework/__init__.py,sha256=
|
|
1
|
+
alberta_framework/__init__.py,sha256=gAafDDmkivDdfnvDVff9zbVY9ilzqqfJ9KvpbRegKqs,5726
|
|
2
2
|
alberta_framework/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
3
3
|
alberta_framework/core/__init__.py,sha256=PSrC4zSxgm_6YXWEQ80aZaunpbQ58QexxKmDDU-jp6c,522
|
|
4
|
-
alberta_framework/core/learners.py,sha256=
|
|
4
|
+
alberta_framework/core/learners.py,sha256=khZYkae5rlIyV13BW3-hrtPSjGFXPj2IUTM1z74xTTA,37724
|
|
5
5
|
alberta_framework/core/normalizers.py,sha256=Z_d3H17qoXh87DE7k41imvWzkVJQ2xQgDUP7GYSNzAY,5903
|
|
6
6
|
alberta_framework/core/optimizers.py,sha256=OefVuDDG1phh1QQIUyVPsQckl41VrpWFG7hY2eqyc64,14585
|
|
7
|
-
alberta_framework/core/types.py,sha256=
|
|
7
|
+
alberta_framework/core/types.py,sha256=svV2Q5-0bj7reQ_hh-pRGp2wYfde5VWgTiRm4hUCDKI,9297
|
|
8
8
|
alberta_framework/streams/__init__.py,sha256=bsDgWjWjotDQHMI2lno3dgk8N14pd-2mYAQpXAtCPx4,2035
|
|
9
9
|
alberta_framework/streams/base.py,sha256=9rJxvUgmzd5u2bRV4vi5PxhUvj39EZTD4bZHo-Ptn-U,2168
|
|
10
10
|
alberta_framework/streams/gymnasium.py,sha256=s733X7aEgy05hcSazjZEhBiJChtEL7uVpxwh0fXBQZA,21980
|
|
@@ -16,7 +16,7 @@ alberta_framework/utils/metrics.py,sha256=1cryNJoboO67vvRhausaucbYZFgdL_06vaf08U
|
|
|
16
16
|
alberta_framework/utils/statistics.py,sha256=4fbzNlmsdUaM5lLW1BhL5B5MUpnqimQlwJklZ4x0y0U,15416
|
|
17
17
|
alberta_framework/utils/timing.py,sha256=JOLq8CpCAV7LWOWkftxefduSFjaXnVwal1MFBKEMdJI,4049
|
|
18
18
|
alberta_framework/utils/visualization.py,sha256=PmKBD3KGabNhgDizcNiGJEbVCyDL1YMUE5yTwgJHu2o,17924
|
|
19
|
-
alberta_framework-0.
|
|
20
|
-
alberta_framework-0.
|
|
21
|
-
alberta_framework-0.
|
|
22
|
-
alberta_framework-0.
|
|
19
|
+
alberta_framework-0.2.1.dist-info/METADATA,sha256=pJ7SujFDZXrWxqjpSK8NFSBoYJpN5VfLneFnAmYG3hw,7763
|
|
20
|
+
alberta_framework-0.2.1.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
|
|
21
|
+
alberta_framework-0.2.1.dist-info/licenses/LICENSE,sha256=TI1avodt5mvxz7sunyxIa0HlNgLQcmKNLeRjCVcgKmE,10754
|
|
22
|
+
alberta_framework-0.2.1.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|