alberta_framework-0.2.2-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- alberta_framework/__init__.py +225 -0
- alberta_framework/core/__init__.py +27 -0
- alberta_framework/core/learners.py +1070 -0
- alberta_framework/core/normalizers.py +192 -0
- alberta_framework/core/optimizers.py +424 -0
- alberta_framework/core/types.py +271 -0
- alberta_framework/py.typed +0 -0
- alberta_framework/streams/__init__.py +83 -0
- alberta_framework/streams/base.py +73 -0
- alberta_framework/streams/gymnasium.py +655 -0
- alberta_framework/streams/synthetic.py +1001 -0
- alberta_framework/utils/__init__.py +113 -0
- alberta_framework/utils/experiments.py +335 -0
- alberta_framework/utils/export.py +509 -0
- alberta_framework/utils/metrics.py +112 -0
- alberta_framework/utils/statistics.py +527 -0
- alberta_framework/utils/timing.py +144 -0
- alberta_framework/utils/visualization.py +571 -0
- alberta_framework-0.2.2.dist-info/METADATA +206 -0
- alberta_framework-0.2.2.dist-info/RECORD +22 -0
- alberta_framework-0.2.2.dist-info/WHEEL +4 -0
- alberta_framework-0.2.2.dist-info/licenses/LICENSE +190 -0
alberta_framework/core/types.py
@@ -0,0 +1,271 @@
"""Type definitions for the Alberta Framework.

This module defines the core data types used throughout the framework,
following JAX conventions with immutable NamedTuples for state management.
"""

from typing import TYPE_CHECKING, NamedTuple

import jax.numpy as jnp
from jax import Array

if TYPE_CHECKING:
    from alberta_framework.core.learners import NormalizedLearnerState

# Type aliases for clarity
Observation = Array  # x_t: feature vector
Target = Array  # y*_t: desired output
Prediction = Array  # y_t: model output
Reward = float  # r_t: scalar reward


class TimeStep(NamedTuple):
    """Single experience from an experience stream.

    Attributes:
        observation: Feature vector x_t
        target: Desired output y*_t (for supervised learning)
    """

    observation: Observation
    target: Target


class LearnerState(NamedTuple):
    """State for a linear learner.

    Attributes:
        weights: Weight vector for linear prediction
        bias: Bias term
        optimizer_state: State maintained by the optimizer
    """

    weights: Array
    bias: Array
    optimizer_state: "LMSState | IDBDState | AutostepState"


class LMSState(NamedTuple):
    """State for the LMS (Least Mean Square) optimizer.

    LMS uses a fixed step-size, so state only tracks the step-size parameter.

    Attributes:
        step_size: Fixed learning rate alpha
    """

    step_size: Array


class IDBDState(NamedTuple):
    """State for the IDBD (Incremental Delta-Bar-Delta) optimizer.

    IDBD maintains per-weight adaptive step-sizes that are meta-learned
    based on the correlation of successive gradients.

    Reference: Sutton 1992, "Adapting Bias by Gradient Descent"

    Attributes:
        log_step_sizes: Log of per-weight step-sizes (log alpha_i)
        traces: Per-weight traces h_i for gradient correlation
        meta_step_size: Meta learning rate beta for adapting step-sizes
        bias_step_size: Step-size for the bias term
        bias_trace: Trace for the bias term
    """

    log_step_sizes: Array  # log(alpha_i) for numerical stability
    traces: Array  # h_i: trace of weight-feature products
    meta_step_size: Array  # beta: step-size for the step-sizes
    bias_step_size: Array  # Step-size for bias
    bias_trace: Array  # Trace for bias


class AutostepState(NamedTuple):
    """State for the Autostep optimizer.

    Autostep is a tuning-free step-size adaptation algorithm that normalizes
    gradients to prevent large updates and adapts step-sizes based on
    gradient correlation.

    Reference: Mahmood et al. 2012, "Tuning-free step-size adaptation"

    Attributes:
        step_sizes: Per-weight step-sizes (alpha_i)
        traces: Per-weight traces for gradient correlation (h_i)
        normalizers: Running max absolute gradient per weight (v_i)
        meta_step_size: Meta learning rate mu for adapting step-sizes
        normalizer_decay: Decay factor for the normalizer (tau)
        bias_step_size: Step-size for the bias term
        bias_trace: Trace for the bias term
        bias_normalizer: Normalizer for the bias gradient
    """

    step_sizes: Array  # alpha_i
    traces: Array  # h_i
    normalizers: Array  # v_i: running max of |gradient|
    meta_step_size: Array  # mu
    normalizer_decay: Array  # tau
    bias_step_size: Array
    bias_trace: Array
    bias_normalizer: Array


class StepSizeTrackingConfig(NamedTuple):
    """Configuration for recording per-weight step-sizes during training.

    Attributes:
        interval: Record step-sizes every N steps
        include_bias: Whether to also record the bias step-size
    """

    interval: int
    include_bias: bool = True


class StepSizeHistory(NamedTuple):
    """History of per-weight step-sizes recorded during training.

    Attributes:
        step_sizes: Per-weight step-sizes at each recording, shape (num_recordings, num_weights)
        bias_step_sizes: Bias step-sizes at each recording, shape (num_recordings,) or None
        recording_indices: Step indices where recordings were made, shape (num_recordings,)
        normalizers: Autostep's per-weight normalizers (v_i) at each recording,
            shape (num_recordings, num_weights) or None. Only populated for Autostep optimizer.
    """

    step_sizes: Array  # (num_recordings, num_weights)
    bias_step_sizes: Array | None  # (num_recordings,) or None
    recording_indices: Array  # (num_recordings,)
    normalizers: Array | None = None  # (num_recordings, num_weights) - Autostep v_i


class NormalizerTrackingConfig(NamedTuple):
    """Configuration for recording per-feature normalizer state during training.

    Attributes:
        interval: Record normalizer state every N steps
    """

    interval: int


class NormalizerHistory(NamedTuple):
    """History of per-feature normalizer state recorded during training.

    Used for analyzing how the OnlineNormalizer adapts to distribution shifts
    (reactive lag diagnostic).

    Attributes:
        means: Per-feature mean estimates at each recording, shape (num_recordings, feature_dim)
        variances: Per-feature variance estimates at each recording,
            shape (num_recordings, feature_dim)
        recording_indices: Step indices where recordings were made, shape (num_recordings,)
    """

    means: Array  # (num_recordings, feature_dim)
    variances: Array  # (num_recordings, feature_dim)
    recording_indices: Array  # (num_recordings,)


class BatchedLearningResult(NamedTuple):
    """Result from batched learning loop across multiple seeds.

    Used with `run_learning_loop_batched` for vmap-based GPU parallelization.

    Attributes:
        states: Batched learner states - each array has shape (num_seeds, ...)
        metrics: Metrics array with shape (num_seeds, num_steps, 3)
            where columns are [squared_error, error, mean_step_size]
        step_size_history: Optional step-size history with batched shapes,
            or None if tracking was disabled
    """

    states: "LearnerState"  # Batched: each array has shape (num_seeds, ...)
    metrics: Array  # Shape: (num_seeds, num_steps, 3)
    step_size_history: StepSizeHistory | None


class BatchedNormalizedResult(NamedTuple):
    """Result from batched normalized learning loop across multiple seeds.

    Used with `run_normalized_learning_loop_batched` for vmap-based GPU parallelization.

    Attributes:
        states: Batched normalized learner states - each array has shape (num_seeds, ...)
        metrics: Metrics array with shape (num_seeds, num_steps, 4)
            where columns are [squared_error, error, mean_step_size, normalizer_mean_var]
        step_size_history: Optional step-size history with batched shapes,
            or None if tracking was disabled
        normalizer_history: Optional normalizer history with batched shapes,
            or None if tracking was disabled
    """

    states: "NormalizedLearnerState"  # Batched: each array has shape (num_seeds, ...)
    metrics: Array  # Shape: (num_seeds, num_steps, 4)
    step_size_history: StepSizeHistory | None
    normalizer_history: NormalizerHistory | None


def create_lms_state(step_size: float = 0.01) -> LMSState:
    """Create initial LMS optimizer state.

    Args:
        step_size: Fixed learning rate

    Returns:
        Initial LMS state
    """
    return LMSState(step_size=jnp.array(step_size, dtype=jnp.float32))


def create_idbd_state(
    feature_dim: int,
    initial_step_size: float = 0.01,
    meta_step_size: float = 0.01,
) -> IDBDState:
    """Create initial IDBD optimizer state.

    Args:
        feature_dim: Dimension of the feature vector
        initial_step_size: Initial per-weight step-size
        meta_step_size: Meta learning rate for adapting step-sizes

    Returns:
        Initial IDBD state
    """
    return IDBDState(
        log_step_sizes=jnp.full(feature_dim, jnp.log(initial_step_size), dtype=jnp.float32),
        traces=jnp.zeros(feature_dim, dtype=jnp.float32),
        meta_step_size=jnp.array(meta_step_size, dtype=jnp.float32),
        bias_step_size=jnp.array(initial_step_size, dtype=jnp.float32),
        bias_trace=jnp.array(0.0, dtype=jnp.float32),
    )


def create_autostep_state(
    feature_dim: int,
    initial_step_size: float = 0.01,
    meta_step_size: float = 0.01,
    normalizer_decay: float = 0.99,
) -> AutostepState:
    """Create initial Autostep optimizer state.

    Args:
        feature_dim: Dimension of the feature vector
        initial_step_size: Initial per-weight step-size
        meta_step_size: Meta learning rate for adapting step-sizes
        normalizer_decay: Decay factor for gradient normalizers

    Returns:
        Initial Autostep state
    """
    return AutostepState(
        step_sizes=jnp.full(feature_dim, initial_step_size, dtype=jnp.float32),
        traces=jnp.zeros(feature_dim, dtype=jnp.float32),
        normalizers=jnp.ones(feature_dim, dtype=jnp.float32),
        meta_step_size=jnp.array(meta_step_size, dtype=jnp.float32),
        normalizer_decay=jnp.array(normalizer_decay, dtype=jnp.float32),
        bias_step_size=jnp.array(initial_step_size, dtype=jnp.float32),
        bias_trace=jnp.array(0.0, dtype=jnp.float32),
        bias_normalizer=jnp.array(1.0, dtype=jnp.float32),
    )
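The factory functions above return ready-to-use initial optimizer states. As a quick orientation, here is a minimal usage sketch based only on the definitions in this file; the hand-built `LearnerState` and the chosen `feature_dim` value are illustrative assumptions, since the learner modules elsewhere in the package presumably provide their own initializers.

```python
import jax.numpy as jnp

from alberta_framework.core.types import (
    LearnerState,
    TimeStep,
    create_autostep_state,
    create_idbd_state,
    create_lms_state,
)

feature_dim = 10  # illustrative value

# Each factory returns an immutable NamedTuple holding that optimizer's state.
lms = create_lms_state(step_size=0.05)
idbd = create_idbd_state(feature_dim, initial_step_size=0.01, meta_step_size=0.01)
autostep = create_autostep_state(feature_dim, normalizer_decay=0.99)

# A linear learner pairs weights and bias with one of the optimizer states.
learner = LearnerState(
    weights=jnp.zeros(feature_dim, dtype=jnp.float32),
    bias=jnp.zeros((), dtype=jnp.float32),
    optimizer_state=idbd,
)

# One supervised experience: feature vector x_t and desired output y*_t.
step = TimeStep(observation=jnp.ones(feature_dim), target=jnp.array(1.0))

# NamedTuples are immutable, so updates produce new states via _replace.
learner = learner._replace(bias=learner.bias + 0.1)
```

Because every state is a plain NamedTuple of arrays, updates stay functional, which is what keeps these types compatible with `jax.jit` and `jax.lax.scan`.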
alberta_framework/py.typed
File without changes
alberta_framework/streams/__init__.py
@@ -0,0 +1,83 @@
"""Experience streams for continual learning."""

from alberta_framework.streams.base import ScanStream
from alberta_framework.streams.synthetic import (
    AbruptChangeState,
    AbruptChangeStream,
    AbruptChangeTarget,
    CyclicState,
    CyclicStream,
    CyclicTarget,
    DynamicScaleShiftState,
    DynamicScaleShiftStream,
    PeriodicChangeState,
    PeriodicChangeStream,
    PeriodicChangeTarget,
    RandomWalkState,
    RandomWalkStream,
    RandomWalkTarget,
    ScaleDriftState,
    ScaleDriftStream,
    ScaledStreamState,
    ScaledStreamWrapper,
    SuttonExperiment1State,
    SuttonExperiment1Stream,
    make_scale_range,
)

__all__ = [
    # Protocol
    "ScanStream",
    # Stream classes
    "AbruptChangeState",
    "AbruptChangeStream",
    "AbruptChangeTarget",
    "CyclicState",
    "CyclicStream",
    "CyclicTarget",
    "DynamicScaleShiftState",
    "DynamicScaleShiftStream",
    "PeriodicChangeState",
    "PeriodicChangeStream",
    "PeriodicChangeTarget",
    "RandomWalkState",
    "RandomWalkStream",
    "RandomWalkTarget",
    "ScaleDriftState",
    "ScaleDriftStream",
    "ScaledStreamState",
    "ScaledStreamWrapper",
    "SuttonExperiment1State",
    "SuttonExperiment1Stream",
    # Utilities
    "make_scale_range",
]

# Gymnasium streams are optional - only export if gymnasium is installed
try:
    from alberta_framework.streams.gymnasium import (
        GymnasiumStream,
        PredictionMode,
        TDStream,
        collect_trajectory,
        learn_from_trajectory,
        learn_from_trajectory_normalized,
        make_epsilon_greedy_policy,
        make_gymnasium_stream,
        make_random_policy,
    )

    __all__ += [
        "GymnasiumStream",
        "PredictionMode",
        "TDStream",
        "collect_trajectory",
        "learn_from_trajectory",
        "learn_from_trajectory_normalized",
        "make_epsilon_greedy_policy",
        "make_gymnasium_stream",
        "make_random_policy",
    ]
except ImportError:
    # gymnasium not installed
    pass
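Because the gymnasium-backed streams are exported only when the optional dependency is present, downstream code can feature-detect them through `__all__` (or by attempting the import itself). A minimal sketch, assuming only the public names shown above:

```python
# Feature-detect the optional gymnasium exports, mirroring the
# try/except ImportError pattern in streams/__init__.py above.
import alberta_framework.streams as streams

if "GymnasiumStream" in streams.__all__:
    # gymnasium is installed, so the environment-backed streams were exported.
    from alberta_framework.streams import make_gymnasium_stream
else:
    # Only the synthetic streams (RandomWalkStream, AbruptChangeStream, ...)
    # are available; they require no extra dependencies.
    from alberta_framework.streams import RandomWalkStream
```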
alberta_framework/streams/base.py
@@ -0,0 +1,73 @@
"""Base protocol for experience streams.

Experience streams generate temporally-uniform experience for continual learning.
Every time step produces a new observation-target pair.

This module defines the ScanStream protocol for JAX scan-compatible streams.
All streams implement pure functions that can be JIT-compiled.
"""

from typing import Protocol, TypeVar

from jax import Array

from alberta_framework.core.types import TimeStep

# Type variable for stream state
StateT = TypeVar("StateT")


class ScanStream(Protocol[StateT]):
    """Protocol for JAX scan-compatible experience streams.

    Streams generate temporally-uniform experience for continual learning.
    Unlike iterator-based streams, ScanStream uses pure functions that
    can be compiled with JAX's JIT and used with jax.lax.scan.

    The stream should be non-stationary to test continual learning
    capabilities - the underlying target function changes over time.

    Type Parameters:
        StateT: The state type maintained by this stream

    Examples
    --------
    ```python
    stream = RandomWalkStream(feature_dim=10, drift_rate=0.001)
    key = jax.random.key(42)
    state = stream.init(key)
    timestep, new_state = stream.step(state, jnp.array(0))
    ```
    """

    @property
    def feature_dim(self) -> int:
        """Return the dimension of observation vectors."""
        ...

    def init(self, key: Array) -> StateT:
        """Initialize stream state.

        Args:
            key: JAX random key for initialization

        Returns:
            Initial stream state
        """
        ...

    def step(self, state: StateT, idx: Array) -> tuple[TimeStep, StateT]:
        """Generate one time step. Must be JIT-compatible.

        This is a pure function that takes the current state and step index,
        and returns a TimeStep along with the updated state. The step index
        can be used for time-dependent behavior but is often ignored.

        Args:
            state: Current stream state
            idx: Current step index (can be ignored for most streams)

        Returns:
            Tuple of (timestep, new_state)
        """
        ...