alberta-framework 0.2.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,225 @@
1
+ """Alberta Framework: A JAX-based research framework for continual AI.
2
+
3
+ The Alberta Framework provides foundational components for continual reinforcement
4
+ learning research. Built on JAX for hardware acceleration, the framework emphasizes
5
+ temporal uniformity — every component updates at every time step, with no special
6
+ training phases or batch processing.
7
+
8
+ Roadmap
9
+ -------
10
+ | Step | Focus | Status |
11
+ |------|-------|--------|
12
+ | 1 | Meta-learned step-sizes (IDBD, Autostep) | **Complete** |
13
+ | 2 | Feature generation and testing | Planned |
14
+ | 3 | GVF predictions, Horde architecture | Planned |
15
+ | 4 | Actor-critic with eligibility traces | Planned |
16
+ | 5-6 | Off-policy learning, average reward | Planned |
17
+ | 7-12 | Hierarchical, multi-agent, world models | Future |
18
+
19
+ Examples
20
+ --------
21
+ ```python
22
+ import jax.random as jr
23
+ from alberta_framework import LinearLearner, IDBD, RandomWalkStream, run_learning_loop
24
+
25
+ # Non-stationary stream where target weights drift over time
26
+ stream = RandomWalkStream(feature_dim=10, drift_rate=0.001)
27
+
28
+ # Learner with IDBD meta-learned step-sizes
29
+ learner = LinearLearner(optimizer=IDBD())
30
+
31
+ # JIT-compiled training via jax.lax.scan
32
+ state, metrics = run_learning_loop(learner, stream, num_steps=10000, key=jr.key(42))
33
+ ```
34
+
35
+ References
36
+ ----------
37
+ - The Alberta Plan for AI Research (Sutton et al., 2022): https://arxiv.org/abs/2208.11173
38
+ - Adapting Bias by Gradient Descent (Sutton, 1992)
39
+ - Tuning-free Step-size Adaptation (Mahmood et al., 2012)
40
+ """
41
+
42
+ __version__ = "0.2.0"
43
+
44
+ # Core API (the core types themselves are imported below, after the optimizers)
45
+ # Learners
46
+ from alberta_framework.core.learners import (
47
+ LinearLearner,
48
+ NormalizedLearnerState,
49
+ NormalizedLinearLearner,
50
+ UpdateResult,
51
+ metrics_to_dicts,
52
+ run_learning_loop,
53
+ run_learning_loop_batched,
54
+ run_normalized_learning_loop,
55
+ run_normalized_learning_loop_batched,
56
+ )
57
+
58
+ # Normalizers
59
+ from alberta_framework.core.normalizers import (
60
+ NormalizerState,
61
+ OnlineNormalizer,
62
+ create_normalizer_state,
63
+ )
64
+
65
+ # Optimizers and core types
66
+ from alberta_framework.core.optimizers import IDBD, LMS, Autostep, Optimizer
67
+ from alberta_framework.core.types import (
68
+ AutostepState,
69
+ BatchedLearningResult,
70
+ BatchedNormalizedResult,
71
+ IDBDState,
72
+ LearnerState,
73
+ LMSState,
74
+ NormalizerHistory,
75
+ NormalizerTrackingConfig,
76
+ Observation,
77
+ Prediction,
78
+ StepSizeHistory,
79
+ StepSizeTrackingConfig,
80
+ Target,
81
+ TimeStep,
82
+ )
83
+
84
+ # Streams - base
85
+ from alberta_framework.streams.base import ScanStream
86
+
87
+ # Streams - synthetic
88
+ from alberta_framework.streams.synthetic import (
89
+ AbruptChangeState,
90
+ AbruptChangeStream,
91
+ AbruptChangeTarget,
92
+ CyclicState,
93
+ CyclicStream,
94
+ CyclicTarget,
95
+ DynamicScaleShiftState,
96
+ DynamicScaleShiftStream,
97
+ PeriodicChangeState,
98
+ PeriodicChangeStream,
99
+ PeriodicChangeTarget,
100
+ RandomWalkState,
101
+ RandomWalkStream,
102
+ RandomWalkTarget,
103
+ ScaleDriftState,
104
+ ScaleDriftStream,
105
+ ScaledStreamState,
106
+ ScaledStreamWrapper,
107
+ SuttonExperiment1State,
108
+ SuttonExperiment1Stream,
109
+ make_scale_range,
110
+ )
111
+
112
+ # Utilities
113
+ from alberta_framework.utils.metrics import (
114
+ compare_learners,
115
+ compute_cumulative_error,
116
+ compute_running_mean,
117
+ compute_tracking_error,
118
+ extract_metric,
119
+ )
120
+ from alberta_framework.utils.timing import Timer, format_duration
121
+
122
+ # Gymnasium streams (optional): exported only when this import succeeds
123
+ try:
124
+ from alberta_framework.streams.gymnasium import (
125
+ GymnasiumStream,
126
+ PredictionMode,
127
+ TDStream,
128
+ collect_trajectory,
129
+ learn_from_trajectory,
130
+ learn_from_trajectory_normalized,
131
+ make_epsilon_greedy_policy,
132
+ make_gymnasium_stream,
133
+ make_random_policy,
134
+ )
135
+
136
+ _gymnasium_available = True  # gates the extra __all__ entries added below
137
+ except ImportError:  # any ImportError raised by that import disables the optional exports
138
+ _gymnasium_available = False
139
+
140
+ __all__ = [
141
+ # Version
142
+ "__version__",
143
+ # Types
144
+ "AutostepState",
145
+ "BatchedLearningResult",
146
+ "BatchedNormalizedResult",
147
+ "IDBDState",
148
+ "LMSState",
149
+ "LearnerState",
150
+ "NormalizerHistory",
151
+ "NormalizerState",
152
+ "NormalizerTrackingConfig",
153
+ "Observation",
154
+ "Prediction",
155
+ "StepSizeHistory",
156
+ "StepSizeTrackingConfig",
157
+ "Target",
158
+ "TimeStep",
159
+ "UpdateResult",
160
+ # Optimizers
161
+ "Autostep",
162
+ "IDBD",
163
+ "LMS",
164
+ "Optimizer",
165
+ # Normalizers
166
+ "OnlineNormalizer",
167
+ "create_normalizer_state",
168
+ # Learners
169
+ "LinearLearner",
170
+ "NormalizedLearnerState",
171
+ "NormalizedLinearLearner",
172
+ "run_learning_loop",
173
+ "run_learning_loop_batched",
174
+ "run_normalized_learning_loop",
175
+ "run_normalized_learning_loop_batched",
176
+ "metrics_to_dicts",
177
+ # Streams - protocol
178
+ "ScanStream",
179
+ # Streams - synthetic
180
+ "AbruptChangeState",
181
+ "AbruptChangeStream",
182
+ "AbruptChangeTarget",
183
+ "CyclicState",
184
+ "CyclicStream",
185
+ "CyclicTarget",
186
+ "DynamicScaleShiftState",
187
+ "DynamicScaleShiftStream",
188
+ "PeriodicChangeState",
189
+ "PeriodicChangeStream",
190
+ "PeriodicChangeTarget",
191
+ "RandomWalkState",
192
+ "RandomWalkStream",
193
+ "RandomWalkTarget",
194
+ "ScaleDriftState",
195
+ "ScaleDriftStream",
196
+ "ScaledStreamState",
197
+ "ScaledStreamWrapper",
198
+ "SuttonExperiment1State",
199
+ "SuttonExperiment1Stream",
200
+ # Stream utilities
201
+ "make_scale_range",
202
+ # Utilities
203
+ "compare_learners",
204
+ "compute_cumulative_error",
205
+ "compute_running_mean",
206
+ "compute_tracking_error",
207
+ "extract_metric",
208
+ # Timing
209
+ "Timer",
210
+ "format_duration",
211
+ ]
212
+
213
+ # Advertise the Gymnasium names in __all__ only when their import above succeeded
214
+ if _gymnasium_available:
215
+ __all__ += [
216
+ "GymnasiumStream",
217
+ "PredictionMode",
218
+ "TDStream",
219
+ "collect_trajectory",
220
+ "learn_from_trajectory",
221
+ "learn_from_trajectory_normalized",
222
+ "make_epsilon_greedy_policy",
223
+ "make_gymnasium_stream",
224
+ "make_random_policy",
225
+ ]
@@ -0,0 +1,27 @@
1
+ """Core components for the Alberta Framework."""
2
+
3
+ from alberta_framework.core.learners import LinearLearner
4
+ from alberta_framework.core.optimizers import IDBD, LMS, Optimizer
5
+ from alberta_framework.core.types import (
6
+ IDBDState,
7
+ LearnerState,
8
+ LMSState,
9
+ Observation,
10
+ Prediction,
11
+ Target,
12
+ TimeStep,
13
+ )
14
+
15
+ __all__ = [
16
+ "IDBD",
17
+ "IDBDState",
18
+ "LMS",
19
+ "LMSState",
20
+ "LearnerState",
21
+ "LinearLearner",
22
+ "Observation",
23
+ "Optimizer",
24
+ "Prediction",
25
+ "Target",
26
+ "TimeStep",
27
+ ]