alberta-framework 0.2.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- alberta_framework/__init__.py +225 -0
- alberta_framework/core/__init__.py +27 -0
- alberta_framework/core/learners.py +1070 -0
- alberta_framework/core/normalizers.py +192 -0
- alberta_framework/core/optimizers.py +424 -0
- alberta_framework/core/types.py +271 -0
- alberta_framework/py.typed +0 -0
- alberta_framework/streams/__init__.py +83 -0
- alberta_framework/streams/base.py +73 -0
- alberta_framework/streams/gymnasium.py +655 -0
- alberta_framework/streams/synthetic.py +1001 -0
- alberta_framework/utils/__init__.py +113 -0
- alberta_framework/utils/experiments.py +335 -0
- alberta_framework/utils/export.py +509 -0
- alberta_framework/utils/metrics.py +112 -0
- alberta_framework/utils/statistics.py +527 -0
- alberta_framework/utils/timing.py +144 -0
- alberta_framework/utils/visualization.py +571 -0
- alberta_framework-0.2.2.dist-info/METADATA +206 -0
- alberta_framework-0.2.2.dist-info/RECORD +22 -0
- alberta_framework-0.2.2.dist-info/WHEEL +4 -0
- alberta_framework-0.2.2.dist-info/licenses/LICENSE +190 -0
|
@@ -0,0 +1,225 @@
|
|
|
1
|
+
"""Alberta Framework: A JAX-based research framework for continual AI.

The Alberta Framework provides foundational components for continual reinforcement
learning research. Built on JAX for hardware acceleration, the framework emphasizes
temporal uniformity — every component updates at every time step, with no special
training phases or batch processing.

Roadmap
-------
| Step | Focus | Status |
|------|-------|--------|
| 1 | Meta-learned step-sizes (IDBD, Autostep) | **Complete** |
| 2 | Feature generation and testing | Planned |
| 3 | GVF predictions, Horde architecture | Planned |
| 4 | Actor-critic with eligibility traces | Planned |
| 5-6 | Off-policy learning, average reward | Planned |
| 7-12 | Hierarchical, multi-agent, world models | Future |

Examples
--------
```python
import jax.random as jr
from alberta_framework import LinearLearner, IDBD, RandomWalkStream, run_learning_loop

# Non-stationary stream where target weights drift over time
stream = RandomWalkStream(feature_dim=10, drift_rate=0.001)

# Learner with IDBD meta-learned step-sizes
learner = LinearLearner(optimizer=IDBD())

# JIT-compiled training via jax.lax.scan
state, metrics = run_learning_loop(learner, stream, num_steps=10000, key=jr.key(42))
```

References
----------
- The Alberta Plan for AI Research (Sutton et al., 2022): https://arxiv.org/abs/2208.11173
- Adapting Bias by Gradient Descent (Sutton, 1992)
- Tuning-free Step-size Adaptation (Mahmood et al., 2012)
"""

# Package version string. It must match the distribution version declared in
# the project metadata (this wheel is alberta_framework-0.2.2); update both
# together on every release so `alberta_framework.__version__` is accurate.
__version__ = "0.2.2"

# Learners
from alberta_framework.core.learners import (
    LinearLearner,
    NormalizedLearnerState,
    NormalizedLinearLearner,
    UpdateResult,
    metrics_to_dicts,
    run_learning_loop,
    run_learning_loop_batched,
    run_normalized_learning_loop,
    run_normalized_learning_loop_batched,
)

# Normalizers
from alberta_framework.core.normalizers import (
    NormalizerState,
    OnlineNormalizer,
    create_normalizer_state,
)

# Optimizers
from alberta_framework.core.optimizers import IDBD, LMS, Autostep, Optimizer

# Core types
from alberta_framework.core.types import (
    AutostepState,
    BatchedLearningResult,
    BatchedNormalizedResult,
    IDBDState,
    LearnerState,
    LMSState,
    NormalizerHistory,
    NormalizerTrackingConfig,
    Observation,
    Prediction,
    StepSizeHistory,
    StepSizeTrackingConfig,
    Target,
    TimeStep,
)

# Streams - base
from alberta_framework.streams.base import ScanStream

# Streams - synthetic
from alberta_framework.streams.synthetic import (
    AbruptChangeState,
    AbruptChangeStream,
    AbruptChangeTarget,
    CyclicState,
    CyclicStream,
    CyclicTarget,
    DynamicScaleShiftState,
    DynamicScaleShiftStream,
    PeriodicChangeState,
    PeriodicChangeStream,
    PeriodicChangeTarget,
    RandomWalkState,
    RandomWalkStream,
    RandomWalkTarget,
    ScaleDriftState,
    ScaleDriftStream,
    ScaledStreamState,
    ScaledStreamWrapper,
    SuttonExperiment1State,
    SuttonExperiment1Stream,
    make_scale_range,
)

# Utilities
from alberta_framework.utils.metrics import (
    compare_learners,
    compute_cumulative_error,
    compute_running_mean,
    compute_tracking_error,
    extract_metric,
)
from alberta_framework.utils.timing import Timer, format_duration

# Gymnasium streams (optional). If importing the gymnasium stream module raises
# ImportError, these names are simply not exported and the flag below records
# their absence.
try:
    from alberta_framework.streams.gymnasium import (
        GymnasiumStream,
        PredictionMode,
        TDStream,
        collect_trajectory,
        learn_from_trajectory,
        learn_from_trajectory_normalized,
        make_epsilon_greedy_policy,
        make_gymnasium_stream,
        make_random_policy,
    )

    _gymnasium_available = True
except ImportError:
    _gymnasium_available = False

__all__ = [
    # Version
    "__version__",
    # Types
    "AutostepState",
    "BatchedLearningResult",
    "BatchedNormalizedResult",
    "IDBDState",
    "LMSState",
    "LearnerState",
    "NormalizerHistory",
    "NormalizerState",
    "NormalizerTrackingConfig",
    "Observation",
    "Prediction",
    "StepSizeHistory",
    "StepSizeTrackingConfig",
    "Target",
    "TimeStep",
    "UpdateResult",
    # Optimizers
    "Autostep",
    "IDBD",
    "LMS",
    "Optimizer",
    # Normalizers
    "OnlineNormalizer",
    "create_normalizer_state",
    # Learners
    "LinearLearner",
    "NormalizedLearnerState",
    "NormalizedLinearLearner",
    "run_learning_loop",
    "run_learning_loop_batched",
    "run_normalized_learning_loop",
    "run_normalized_learning_loop_batched",
    "metrics_to_dicts",
    # Streams - protocol
    "ScanStream",
    # Streams - synthetic
    "AbruptChangeState",
    "AbruptChangeStream",
    "AbruptChangeTarget",
    "CyclicState",
    "CyclicStream",
    "CyclicTarget",
    "DynamicScaleShiftState",
    "DynamicScaleShiftStream",
    "PeriodicChangeState",
    "PeriodicChangeStream",
    "PeriodicChangeTarget",
    "RandomWalkState",
    "RandomWalkStream",
    "RandomWalkTarget",
    "ScaleDriftState",
    "ScaleDriftStream",
    "ScaledStreamState",
    "ScaledStreamWrapper",
    "SuttonExperiment1State",
    "SuttonExperiment1Stream",
    # Stream utilities
    "make_scale_range",
    # Utilities
    "compare_learners",
    "compute_cumulative_error",
    "compute_running_mean",
    "compute_tracking_error",
    "extract_metric",
    # Timing
    "Timer",
    "format_duration",
]

# Add Gymnasium exports only when the optional import above succeeded, so that
# `from alberta_framework import *` never references an undefined name.
if _gymnasium_available:
    __all__ += [
        "GymnasiumStream",
        "PredictionMode",
        "TDStream",
        "collect_trajectory",
        "learn_from_trajectory",
        "learn_from_trajectory_normalized",
        "make_epsilon_greedy_policy",
        "make_gymnasium_stream",
        "make_random_policy",
    ]
|
|
@@ -0,0 +1,27 @@
|
|
|
1
|
+
"""Core components for the Alberta Framework.

Re-exports the linear learner, the step-size optimizers (LMS, IDBD, Autostep),
and the core state/data types so they can be imported directly from
``alberta_framework.core``.
"""

from alberta_framework.core.learners import LinearLearner

# Autostep is exported alongside IDBD and LMS so that all optimizers (and their
# state types) are reachable from this subpackage, matching the top-level
# package's exports.
from alberta_framework.core.optimizers import IDBD, LMS, Autostep, Optimizer
from alberta_framework.core.types import (
    AutostepState,
    IDBDState,
    LearnerState,
    LMSState,
    Observation,
    Prediction,
    Target,
    TimeStep,
)

__all__ = [
    "Autostep",
    "AutostepState",
    "IDBD",
    "IDBDState",
    "LMS",
    "LMSState",
    "LearnerState",
    "LinearLearner",
    "Observation",
    "Optimizer",
    "Prediction",
    "Target",
    "TimeStep",
]
|