alberta-framework 0.1.0__tar.gz → 0.2.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {alberta_framework-0.1.0 → alberta_framework-0.2.0}/CLAUDE.md +98 -5
- {alberta_framework-0.1.0 → alberta_framework-0.2.0}/PKG-INFO +10 -2
- {alberta_framework-0.1.0 → alberta_framework-0.2.0}/README.md +9 -1
- {alberta_framework-0.1.0 → alberta_framework-0.2.0}/pyproject.toml +1 -1
- {alberta_framework-0.1.0 → alberta_framework-0.2.0}/src/alberta_framework/__init__.py +52 -23
- alberta_framework-0.2.0/src/alberta_framework/core/learners.py +1061 -0
- {alberta_framework-0.1.0 → alberta_framework-0.2.0}/src/alberta_framework/core/normalizers.py +1 -1
- {alberta_framework-0.1.0 → alberta_framework-0.2.0}/src/alberta_framework/core/optimizers.py +14 -12
- {alberta_framework-0.1.0 → alberta_framework-0.2.0}/src/alberta_framework/core/types.py +70 -0
- {alberta_framework-0.1.0 → alberta_framework-0.2.0}/src/alberta_framework/streams/base.py +8 -5
- {alberta_framework-0.1.0 → alberta_framework-0.2.0}/src/alberta_framework/streams/synthetic.py +16 -10
- {alberta_framework-0.1.0 → alberta_framework-0.2.0}/src/alberta_framework/utils/experiments.py +4 -3
- {alberta_framework-0.1.0 → alberta_framework-0.2.0}/src/alberta_framework/utils/timing.py +42 -36
- alberta_framework-0.2.0/tests/test_learners.py +777 -0
- alberta_framework-0.1.0/src/alberta_framework/core/learners.py +0 -530
- alberta_framework-0.1.0/tests/test_learners.py +0 -339
- {alberta_framework-0.1.0 → alberta_framework-0.2.0}/.github/workflows/ci.yml +0 -0
- {alberta_framework-0.1.0 → alberta_framework-0.2.0}/.github/workflows/docs.yml +0 -0
- {alberta_framework-0.1.0 → alberta_framework-0.2.0}/.github/workflows/publish.yml +0 -0
- {alberta_framework-0.1.0 → alberta_framework-0.2.0}/.gitignore +0 -0
- {alberta_framework-0.1.0 → alberta_framework-0.2.0}/ALBERTA_PLAN.md +0 -0
- {alberta_framework-0.1.0 → alberta_framework-0.2.0}/CHANGELOG.md +0 -0
- {alberta_framework-0.1.0 → alberta_framework-0.2.0}/LICENSE +0 -0
- {alberta_framework-0.1.0 → alberta_framework-0.2.0}/ROADMAP.md +0 -0
- {alberta_framework-0.1.0 → alberta_framework-0.2.0}/docs/contributing.md +0 -0
- {alberta_framework-0.1.0 → alberta_framework-0.2.0}/docs/gen_ref_pages.py +0 -0
- {alberta_framework-0.1.0 → alberta_framework-0.2.0}/docs/getting-started/installation.md +0 -0
- {alberta_framework-0.1.0 → alberta_framework-0.2.0}/docs/getting-started/quickstart.md +0 -0
- {alberta_framework-0.1.0 → alberta_framework-0.2.0}/docs/guide/concepts.md +0 -0
- {alberta_framework-0.1.0 → alberta_framework-0.2.0}/docs/guide/experiments.md +0 -0
- {alberta_framework-0.1.0 → alberta_framework-0.2.0}/docs/guide/gymnasium.md +0 -0
- {alberta_framework-0.1.0 → alberta_framework-0.2.0}/docs/guide/optimizers.md +0 -0
- {alberta_framework-0.1.0 → alberta_framework-0.2.0}/docs/guide/streams.md +0 -0
- {alberta_framework-0.1.0 → alberta_framework-0.2.0}/docs/index.md +0 -0
- {alberta_framework-0.1.0 → alberta_framework-0.2.0}/docs/javascripts/mathjax.js +0 -0
- {alberta_framework-0.1.0 → alberta_framework-0.2.0}/examples/The Alberta Plan/Step1/README.md +0 -0
- {alberta_framework-0.1.0 → alberta_framework-0.2.0}/examples/The Alberta Plan/Step1/autostep_comparison.py +0 -0
- {alberta_framework-0.1.0 → alberta_framework-0.2.0}/examples/The Alberta Plan/Step1/external_normalization_study.py +0 -0
- {alberta_framework-0.1.0 → alberta_framework-0.2.0}/examples/The Alberta Plan/Step1/idbd_lms_autostep_comparison.py +0 -0
- {alberta_framework-0.1.0 → alberta_framework-0.2.0}/examples/The Alberta Plan/Step1/normalization_study.py +0 -0
- {alberta_framework-0.1.0 → alberta_framework-0.2.0}/examples/The Alberta Plan/Step1/sutton1992_experiment1.py +0 -0
- {alberta_framework-0.1.0 → alberta_framework-0.2.0}/examples/The Alberta Plan/Step1/sutton1992_experiment2.py +0 -0
- {alberta_framework-0.1.0 → alberta_framework-0.2.0}/examples/gymnasium_reward_prediction.py +0 -0
- {alberta_framework-0.1.0 → alberta_framework-0.2.0}/examples/publication_experiment.py +0 -0
- {alberta_framework-0.1.0 → alberta_framework-0.2.0}/examples/td_cartpole_lms.py +0 -0
- {alberta_framework-0.1.0 → alberta_framework-0.2.0}/mkdocs.yml +0 -0
- {alberta_framework-0.1.0 → alberta_framework-0.2.0}/papers/mahmood-msc-thesis-summary.md +0 -0
- {alberta_framework-0.1.0 → alberta_framework-0.2.0}/src/alberta_framework/core/__init__.py +0 -0
- {alberta_framework-0.1.0 → alberta_framework-0.2.0}/src/alberta_framework/py.typed +0 -0
- {alberta_framework-0.1.0 → alberta_framework-0.2.0}/src/alberta_framework/streams/__init__.py +0 -0
- {alberta_framework-0.1.0 → alberta_framework-0.2.0}/src/alberta_framework/streams/gymnasium.py +0 -0
- {alberta_framework-0.1.0 → alberta_framework-0.2.0}/src/alberta_framework/utils/__init__.py +0 -0
- {alberta_framework-0.1.0 → alberta_framework-0.2.0}/src/alberta_framework/utils/export.py +0 -0
- {alberta_framework-0.1.0 → alberta_framework-0.2.0}/src/alberta_framework/utils/metrics.py +0 -0
- {alberta_framework-0.1.0 → alberta_framework-0.2.0}/src/alberta_framework/utils/statistics.py +0 -0
- {alberta_framework-0.1.0 → alberta_framework-0.2.0}/src/alberta_framework/utils/visualization.py +0 -0
- {alberta_framework-0.1.0 → alberta_framework-0.2.0}/tests/conftest.py +0 -0
- {alberta_framework-0.1.0 → alberta_framework-0.2.0}/tests/test_gymnasium_streams.py +0 -0
- {alberta_framework-0.1.0 → alberta_framework-0.2.0}/tests/test_normalizers.py +0 -0
- {alberta_framework-0.1.0 → alberta_framework-0.2.0}/tests/test_optimizers.py +0 -0
- {alberta_framework-0.1.0 → alberta_framework-0.2.0}/tests/test_streams.py +0 -0
{alberta_framework-0.1.0 → alberta_framework-0.2.0}/CLAUDE.md

@@ -14,10 +14,10 @@ This framework implements Step 1 of the Alberta Plan: demonstrating that IDBD (I
 ```
 src/alberta_framework/
 ├── core/
-│   ├── types.py        # TimeStep, LearnerState, LMSState, IDBDState, AutostepState, StepSizeTrackingConfig, StepSizeHistory
+│   ├── types.py        # TimeStep, LearnerState, LMSState, IDBDState, AutostepState, StepSizeTrackingConfig, StepSizeHistory, NormalizerTrackingConfig, NormalizerHistory, BatchedLearningResult, BatchedNormalizedResult
 │   ├── optimizers.py   # LMS, IDBD, Autostep optimizers
 │   ├── normalizers.py  # OnlineNormalizer, NormalizerState
-│   └── learners.py     # LinearLearner, NormalizedLinearLearner, run_learning_loop, metrics_to_dicts
+│   └── learners.py     # LinearLearner, NormalizedLinearLearner, run_learning_loop, run_learning_loop_batched, run_normalized_learning_loop, run_normalized_learning_loop_batched, metrics_to_dicts
 ├── streams/
 │   ├── base.py         # ScanStream protocol (pure function interface for jax.lax.scan)
 │   ├── synthetic.py    # RandomWalkStream, AbruptChangeStream, CyclicStream, PeriodicChangeStream, ScaledStreamWrapper, DynamicScaleShiftStream, ScaleDriftStream
@@ -125,15 +125,15 @@ IDBD/Autostep should beat LMS when starting from the same step-size (demonstrate
 With optimal parameters, adaptive methods should match best grid-searched LMS.
 
 ### Step-Size Tracking for Meta-Adaptation Analysis
-The `run_learning_loop`
+The `run_learning_loop` and `run_normalized_learning_loop` functions support optional per-weight step-size tracking for analyzing how adaptive optimizers evolve their step-sizes during training:
 
 ```python
-from alberta_framework import LinearLearner, IDBD, StepSizeTrackingConfig, run_learning_loop
+from alberta_framework import LinearLearner, IDBD, Autostep, StepSizeTrackingConfig, run_learning_loop
 from alberta_framework.streams import RandomWalkStream
 import jax.random as jr
 
 stream = RandomWalkStream(feature_dim=10)
-learner = LinearLearner(optimizer=
+learner = LinearLearner(optimizer=Autostep())
 config = StepSizeTrackingConfig(interval=100)  # Record every 100 steps
 
 state, metrics, history = run_learning_loop(
@@ -143,6 +143,7 @@ state, metrics, history = run_learning_loop(
 # history.step_sizes: shape (100, 10) - per-weight step-sizes at each recording
 # history.bias_step_sizes: shape (100,) - bias step-size at each recording
 # history.recording_indices: shape (100,) - step indices where recordings were made
+# history.normalizers: shape (100, 10) - Autostep's v_i normalizers (None for IDBD/LMS)
 ```
 
 Key features:
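The recorded history is plain arrays, so downstream analysis needs nothing from the framework. A minimal sketch of summarizing the per-weight step-sizes, assuming only the `history` fields and shapes documented above plus NumPy (not part of the package):

```python
import numpy as np

# history fields are JAX arrays; np.asarray converts them for NumPy-side analysis
steps = np.asarray(history.recording_indices)  # (num_recordings,)
alphas = np.asarray(history.step_sizes)        # (num_recordings, feature_dim)

# Geometric mean tracks the overall step-size level (step-sizes are positive);
# the max/min ratio shows how unevenly the optimizer treats individual weights
mean_alpha = np.exp(np.log(alphas).mean(axis=1))
spread = alphas.max(axis=1) / alphas.min(axis=1)

for t, a, s in zip(steps[::10], mean_alpha[::10], spread[::10]):
    print(f"step {int(t):6d}  mean step-size {a:.4g}  max/min ratio {s:.2f}")
```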
@@ -150,6 +151,94 @@ Key features:
 - Configurable interval to control memory usage
 - Optional `include_bias=False` to skip bias tracking
 - Works with LMS (constant), IDBD, and Autostep optimizers
+- **Autostep's normalizers (v_i)** are tracked automatically when using Autostep
+
+### Normalizer State Tracking for Reactive Lag Analysis
+The `run_normalized_learning_loop` function supports tracking the OnlineNormalizer's per-feature mean and variance estimates over time. This is essential for analyzing reactive lag — how quickly the normalizer adapts to distribution shifts:
+
+```python
+from alberta_framework import (
+    NormalizedLinearLearner, IDBD,
+    StepSizeTrackingConfig, NormalizerTrackingConfig,
+    run_normalized_learning_loop
+)
+from alberta_framework.streams import RandomWalkStream
+import jax.random as jr
+
+stream = RandomWalkStream(feature_dim=10)
+learner = NormalizedLinearLearner(optimizer=IDBD())
+ss_config = StepSizeTrackingConfig(interval=100)
+norm_config = NormalizerTrackingConfig(interval=100)
+
+# Track both step-sizes and normalizer state
+state, metrics, ss_history, norm_history = run_normalized_learning_loop(
+    learner, stream, num_steps=10000, key=jr.key(42),
+    step_size_tracking=ss_config, normalizer_tracking=norm_config
+)
+
+# norm_history.means: shape (100, 10) - per-feature mean estimates at each recording
+# norm_history.variances: shape (100, 10) - per-feature variance estimates at each recording
+# norm_history.recording_indices: shape (100,) - step indices where recordings were made
+```
+
+Return value depends on tracking options:
+- No tracking: `(state, metrics)` — 2-tuple
+- step_size_tracking only: `(state, metrics, ss_history)` — 3-tuple
+- normalizer_tracking only: `(state, metrics, norm_history)` — 3-tuple
+- Both: `(state, metrics, ss_history, norm_history)` — 4-tuple
+
+### Batched Learning Loops (vmap-based GPU Parallelization)
+The `run_learning_loop_batched` and `run_normalized_learning_loop_batched` functions use `jax.vmap` to run multiple seeds in parallel, typically achieving 2-5x speedup over sequential execution:
+
+```python
+import jax.random as jr
+from alberta_framework import (
+    LinearLearner, IDBD, RandomWalkStream,
+    run_learning_loop_batched, StepSizeTrackingConfig
+)
+
+stream = RandomWalkStream(feature_dim=10)
+learner = LinearLearner(optimizer=IDBD())
+
+# Run 30 seeds in parallel
+keys = jr.split(jr.key(42), 30)
+result = run_learning_loop_batched(learner, stream, num_steps=10000, keys=keys)
+
+# result.metrics has shape (30, 10000, 3)
+# result.states.weights has shape (30, 10)
+mean_error = result.metrics[:, :, 0].mean(axis=0)  # Average squared error over seeds
+
+# With step-size tracking
+config = StepSizeTrackingConfig(interval=100)
+result = run_learning_loop_batched(
+    learner, stream, num_steps=10000, keys=keys, step_size_tracking=config
+)
+# result.step_size_history.step_sizes has shape (30, 100, 10)
+```
+
+Key features:
+- `jax.vmap` parallelizes over seeds, not steps — memory scales with num_seeds
+- `jax.lax.scan` processes steps sequentially within each seed
+- Returns `BatchedLearningResult` or `BatchedNormalizedResult` NamedTuples
+- Tracking histories get batched shapes: `(num_seeds, num_recordings, ...)`
+- Same initial state used for all seeds (controlled variation via different keys)
+
+For normalized learners:
+```python
+from alberta_framework import (
+    NormalizedLinearLearner, run_normalized_learning_loop_batched,
+    NormalizerTrackingConfig
+)
+
+learner = NormalizedLinearLearner(optimizer=IDBD())
+result = run_normalized_learning_loop_batched(
+    learner, stream, num_steps=10000, keys=keys,
+    step_size_tracking=StepSizeTrackingConfig(interval=100),
+    normalizer_tracking=NormalizerTrackingConfig(interval=100)
+)
+# result.metrics has shape (30, 10000, 4)
+# result.step_size_history and result.normalizer_history both batched
+```
 
 ## Gymnasium Integration
 
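Since `result.metrics` stacks per-seed learning curves, seed-averaged curves and uncertainty bands follow directly. A minimal sketch, assuming only the batched `result.metrics` layout documented above (squared error in column 0) plus NumPy (not part of the package):

```python
import numpy as np

err = np.asarray(result.metrics[:, :, 0])  # (num_seeds, num_steps) squared error
mean_curve = err.mean(axis=0)              # seed-averaged learning curve
sem = err.std(axis=0, ddof=1) / np.sqrt(err.shape[0])  # standard error over seeds

# Smooth with a trailing window before reporting a final error figure
window = 100
smoothed = np.convolve(mean_curve, np.ones(window) / window, mode="valid")
print(f"final error: {smoothed[-1]:.4f} +/- {sem[-window:].mean():.4f}")
```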
@@ -267,6 +356,10 @@ The API Reference section is auto-generated from docstrings in the source code.
 ### Docstring Style
 Use NumPy-style docstrings for all public functions and classes. See `core/optimizers.py` for examples.
 
+**Code examples**: Use fenced markdown code blocks (triple backticks with `python`) inside an `Examples` section, not doctest `>>>` format. This ensures proper syntax highlighting in mkdocstrings. See `streams/base.py` or `utils/timing.py` for examples.
+
+**Math formulas**: Wrap inline math expressions in backticks for monospace rendering, e.g., `y = w @ x + b` or `alpha_i = exp(log_alpha_i)`. See `core/optimizers.py` for examples.
+
 ## Streams for Factorial Studies
 
 The framework supports factorial experiment designs with multiple non-stationarity types and scale ranges:
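For illustration, a docstring following both conventions might look like the sketch below; `predict` is a hypothetical helper (the type names appear in `core/types.py`, but this exact function and its fields are assumptions):

````python
def predict(state, x):
    """Compute the linear prediction `y = w @ x + b`.

    Parameters
    ----------
    state : LearnerState
        Current weights and bias (field names assumed for illustration).
    x : Observation
        Feature vector of shape `(feature_dim,)`.

    Returns
    -------
    Prediction
        The scalar prediction.

    Examples
    --------
    ```python
    y = predict(state, x)
    ```
    """
    return state.weights @ x + state.bias
````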
{alberta_framework-0.1.0 → alberta_framework-0.2.0}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: alberta-framework
-Version: 0.1.0
+Version: 0.2.0
 Summary: Implementation of the Alberta Plan for AI Research - continual learning with meta-learned step-sizes
 Project-URL: Homepage, https://github.com/j-klawson/alberta-framework
 Project-URL: Repository, https://github.com/j-klawson/alberta-framework
@@ -49,7 +49,7 @@ Description-Content-Type: text/markdown
 [](https://opensource.org/licenses/Apache-2.0)
 [](https://www.python.org/downloads/)
 
-A JAX-based research framework implementing components of [The Alberta Plan](https://arxiv.org/abs/2208.11173) in the pursuit of building the foundations of Continual AI.
+A JAX-based research framework implementing components of [The Alberta Plan for AI Research](https://arxiv.org/abs/2208.11173) in the pursuit of building the foundations of Continual AI.
 
 > "The agents are complex only because they interact with a complex world... their initial design is as simple, general, and scalable as possible." — *Sutton et al., 2022*
 
@@ -57,6 +57,14 @@ A JAX-based research framework implementing components of [The Alberta Plan](htt
 
 The Alberta Framework provides foundational components for continual reinforcement learning research. Built on JAX for hardware acceleration, the framework emphasizes temporal uniformity — every component updates at every time step, with no special training phases or batch processing.
 
+## Project Context
+
+This framework is developed as part of my D.Eng. work on the foundations of Continual AI. For more background and context, see:
+
+* **Research Blog**: [blog.9600baud.net](https://blog.9600baud.net)
+* **Replicating Sutton '92**: [The Foundation of Step-size Adaptation](https://blog.9600baud.net/sutton92.html)
+* **About the Author**: [Keith Lawson](https://blog.9600baud.net/about.html)
+
 ### Roadmap
 
 Depending on my research trajectory, I may or may not implement components required for the plan. The current focus of this framework is the Step 1 Baseline Study, investigating the interaction between adaptive optimizers and online normalization.
{alberta_framework-0.1.0 → alberta_framework-0.2.0}/README.md

@@ -5,7 +5,7 @@
 [](https://opensource.org/licenses/Apache-2.0)
 [](https://www.python.org/downloads/)
 
-A JAX-based research framework implementing components of [The Alberta Plan](https://arxiv.org/abs/2208.11173) in the pursuit of building the foundations of Continual AI.
+A JAX-based research framework implementing components of [The Alberta Plan for AI Research](https://arxiv.org/abs/2208.11173) in the pursuit of building the foundations of Continual AI.
 
 > "The agents are complex only because they interact with a complex world... their initial design is as simple, general, and scalable as possible." — *Sutton et al., 2022*
 
@@ -13,6 +13,14 @@ A JAX-based research framework implementing components of [The Alberta Plan](htt
 
 The Alberta Framework provides foundational components for continual reinforcement learning research. Built on JAX for hardware acceleration, the framework emphasizes temporal uniformity — every component updates at every time step, with no special training phases or batch processing.
 
+## Project Context
+
+This framework is developed as part of my D.Eng. work on the foundations of Continual AI. For more background and context, see:
+
+* **Research Blog**: [blog.9600baud.net](https://blog.9600baud.net)
+* **Replicating Sutton '92**: [The Foundation of Step-size Adaptation](https://blog.9600baud.net/sutton92.html)
+* **About the Author**: [Keith Lawson](https://blog.9600baud.net/about.html)
+
 ### Roadmap
 
 Depending on my research trajectory, I may or may not implement components required for the plan. The current focus of this framework is the Step 1 Baseline Study, investigating the interaction between adaptive optimizers and online normalization.
{alberta_framework-0.1.0 → alberta_framework-0.2.0}/pyproject.toml

@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
 
 [project]
 name = "alberta-framework"
-version = "0.1.0"
+version = "0.2.0"
 description = "Implementation of the Alberta Plan for AI Research - continual learning with meta-learned step-sizes"
 readme = "README.md"
 license = "Apache-2.0"
{alberta_framework-0.1.0 → alberta_framework-0.2.0}/src/alberta_framework/__init__.py

@@ -1,28 +1,45 @@
-"""Alberta Framework:
-
-
-learning
-
+"""Alberta Framework: A JAX-based research framework for continual AI.
+
+The Alberta Framework provides foundational components for continual reinforcement
+learning research. Built on JAX for hardware acceleration, the framework emphasizes
+temporal uniformity — every component updates at every time step, with no special
+training phases or batch processing.
+
+Roadmap
+-------
+| Step | Focus | Status |
+|------|-------|--------|
+| 1 | Meta-learned step-sizes (IDBD, Autostep) | **Complete** |
+| 2 | Feature generation and testing | Planned |
+| 3 | GVF predictions, Horde architecture | Planned |
+| 4 | Actor-critic with eligibility traces | Planned |
+| 5-6 | Off-policy learning, average reward | Planned |
+| 7-12 | Hierarchical, multi-agent, world models | Future |
+
+Examples
+--------
+```python
+import jax.random as jr
+from alberta_framework import LinearLearner, IDBD, RandomWalkStream, run_learning_loop
+
+# Non-stationary stream where target weights drift over time
+stream = RandomWalkStream(feature_dim=10, drift_rate=0.001)
+
+# Learner with IDBD meta-learned step-sizes
+learner = LinearLearner(optimizer=IDBD())
+
+# JIT-compiled training via jax.lax.scan
+state, metrics = run_learning_loop(learner, stream, num_steps=10000, key=jr.key(42))
+```
+
+References
+----------
+- The Alberta Plan for AI Research (Sutton et al., 2022): https://arxiv.org/abs/2208.11173
+- Adapting Bias by Gradient Descent (Sutton, 1992)
+- Tuning-free Step-size Adaptation (Mahmood et al., 2012)
 """
 
-__version__ = "0.1.0"
+__version__ = "0.2.0"
 
 # Core types
 # Learners
@@ -33,7 +50,9 @@ from alberta_framework.core.learners import (
     UpdateResult,
     metrics_to_dicts,
     run_learning_loop,
+    run_learning_loop_batched,
     run_normalized_learning_loop,
+    run_normalized_learning_loop_batched,
 )
 
 # Normalizers
@@ -47,9 +66,13 @@ from alberta_framework.core.normalizers import (
 from alberta_framework.core.optimizers import IDBD, LMS, Autostep, Optimizer
 from alberta_framework.core.types import (
     AutostepState,
+    BatchedLearningResult,
+    BatchedNormalizedResult,
     IDBDState,
     LearnerState,
     LMSState,
+    NormalizerHistory,
+    NormalizerTrackingConfig,
     Observation,
     Prediction,
     StepSizeHistory,
@@ -119,10 +142,14 @@ __all__ = [
     "__version__",
     # Types
     "AutostepState",
+    "BatchedLearningResult",
+    "BatchedNormalizedResult",
     "IDBDState",
     "LMSState",
     "LearnerState",
+    "NormalizerHistory",
     "NormalizerState",
+    "NormalizerTrackingConfig",
     "Observation",
     "Prediction",
     "StepSizeHistory",
@@ -143,7 +170,9 @@ __all__ = [
     "NormalizedLearnerState",
     "NormalizedLinearLearner",
     "run_learning_loop",
+    "run_learning_loop_batched",
     "run_normalized_learning_loop",
+    "run_normalized_learning_loop_batched",
     "metrics_to_dicts",
     # Streams - protocol
     "ScanStream",
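Taken together, the import and `__all__` additions mean every new 0.2.0 name is importable from the package top level. A quick smoke-test sketch against an installed 0.2.0 (not part of the package):

```python
import alberta_framework as af

assert af.__version__ == "0.2.0"

# New in 0.2.0, per the import and __all__ additions above
from alberta_framework import (
    BatchedLearningResult,
    BatchedNormalizedResult,
    NormalizerHistory,
    NormalizerTrackingConfig,
    run_learning_loop_batched,
    run_normalized_learning_loop_batched,
)
```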