alberta_framework-0.2.2-py3-none-any.whl

@@ -0,0 +1,271 @@
+ """Type definitions for the Alberta Framework.
+
+ This module defines the core data types used throughout the framework,
+ following JAX conventions with immutable NamedTuples for state management.
+ """
+
+ from typing import TYPE_CHECKING, NamedTuple
+
+ import jax.numpy as jnp
+ from jax import Array
+
+ if TYPE_CHECKING:
+     from alberta_framework.core.learners import NormalizedLearnerState
+
+ # Type aliases for clarity
+ Observation = Array  # x_t: feature vector
+ Target = Array  # y*_t: desired output
+ Prediction = Array  # y_t: model output
+ Reward = float  # r_t: scalar reward
+
+
+ class TimeStep(NamedTuple):
+     """Single experience from an experience stream.
+
+     Attributes:
+         observation: Feature vector x_t
+         target: Desired output y*_t (for supervised learning)
+     """
+
+     observation: Observation
+     target: Target
+
+
+ class LearnerState(NamedTuple):
+     """State for a linear learner.
+
+     Attributes:
+         weights: Weight vector for linear prediction
+         bias: Bias term
+         optimizer_state: State maintained by the optimizer
+     """
+
+     weights: Array
+     bias: Array
+     optimizer_state: "LMSState | IDBDState | AutostepState"
+
+
+ class LMSState(NamedTuple):
+     """State for the LMS (Least Mean Square) optimizer.
+
+     LMS uses a fixed step-size, so state only tracks the step-size parameter.
+
+     Attributes:
+         step_size: Fixed learning rate alpha
+     """
+
+     step_size: Array
+
+
+ class IDBDState(NamedTuple):
+     """State for the IDBD (Incremental Delta-Bar-Delta) optimizer.
+
+     IDBD maintains per-weight adaptive step-sizes that are meta-learned
+     based on the correlation of successive gradients.
+
+     Reference: Sutton 1992, "Adapting Bias by Gradient Descent"
+
+     Attributes:
+         log_step_sizes: Log of per-weight step-sizes (log alpha_i)
+         traces: Per-weight traces h_i for gradient correlation
+         meta_step_size: Meta learning rate beta for adapting step-sizes
+         bias_step_size: Step-size for the bias term
+         bias_trace: Trace for the bias term
+     """
+
+     log_step_sizes: Array  # log(alpha_i) for numerical stability
+     traces: Array  # h_i: trace of weight-feature products
+     meta_step_size: Array  # beta: step-size for the step-sizes
+     bias_step_size: Array  # Step-size for bias
+     bias_trace: Array  # Trace for bias
+
+
+ class AutostepState(NamedTuple):
+     """State for the Autostep optimizer.
+
+     Autostep is a tuning-free step-size adaptation algorithm that normalizes
+     gradients to prevent large updates and adapts step-sizes based on
+     gradient correlation.
+
+     Reference: Mahmood et al. 2012, "Tuning-free step-size adaptation"
+
+     Attributes:
+         step_sizes: Per-weight step-sizes (alpha_i)
+         traces: Per-weight traces for gradient correlation (h_i)
+         normalizers: Running max absolute gradient per weight (v_i)
+         meta_step_size: Meta learning rate mu for adapting step-sizes
+         normalizer_decay: Decay factor for the normalizer (tau)
+         bias_step_size: Step-size for the bias term
+         bias_trace: Trace for the bias term
+         bias_normalizer: Normalizer for the bias gradient
+     """
+
+     step_sizes: Array  # alpha_i
+     traces: Array  # h_i
+     normalizers: Array  # v_i: running max of |gradient|
+     meta_step_size: Array  # mu
+     normalizer_decay: Array  # tau
+     bias_step_size: Array
+     bias_trace: Array
+     bias_normalizer: Array
+
+
+ class StepSizeTrackingConfig(NamedTuple):
+     """Configuration for recording per-weight step-sizes during training.
+
+     Attributes:
+         interval: Record step-sizes every N steps
+         include_bias: Whether to also record the bias step-size
+     """
+
+     interval: int
+     include_bias: bool = True
+
+
+ class StepSizeHistory(NamedTuple):
+     """History of per-weight step-sizes recorded during training.
+
+     Attributes:
+         step_sizes: Per-weight step-sizes at each recording, shape (num_recordings, num_weights)
+         bias_step_sizes: Bias step-sizes at each recording, shape (num_recordings,) or None
+         recording_indices: Step indices where recordings were made, shape (num_recordings,)
+         normalizers: Autostep's per-weight normalizers (v_i) at each recording,
+             shape (num_recordings, num_weights) or None. Only populated for Autostep optimizer.
+     """
+
+     step_sizes: Array  # (num_recordings, num_weights)
+     bias_step_sizes: Array | None  # (num_recordings,) or None
+     recording_indices: Array  # (num_recordings,)
+     normalizers: Array | None = None  # (num_recordings, num_weights) - Autostep v_i
+
+
+ class NormalizerTrackingConfig(NamedTuple):
+     """Configuration for recording per-feature normalizer state during training.
+
+     Attributes:
+         interval: Record normalizer state every N steps
+     """
+
+     interval: int
+
+
+ class NormalizerHistory(NamedTuple):
+     """History of per-feature normalizer state recorded during training.
+
+     Used for analyzing how the OnlineNormalizer adapts to distribution shifts
+     (reactive lag diagnostic).
+
+     Attributes:
+         means: Per-feature mean estimates at each recording, shape (num_recordings, feature_dim)
+         variances: Per-feature variance estimates at each recording,
+             shape (num_recordings, feature_dim)
+         recording_indices: Step indices where recordings were made, shape (num_recordings,)
+     """
+
+     means: Array  # (num_recordings, feature_dim)
+     variances: Array  # (num_recordings, feature_dim)
+     recording_indices: Array  # (num_recordings,)
+
+
+ class BatchedLearningResult(NamedTuple):
+     """Result from batched learning loop across multiple seeds.
+
+     Used with `run_learning_loop_batched` for vmap-based GPU parallelization.
+
+     Attributes:
+         states: Batched learner states - each array has shape (num_seeds, ...)
+         metrics: Metrics array with shape (num_seeds, num_steps, 3)
+             where columns are [squared_error, error, mean_step_size]
+         step_size_history: Optional step-size history with batched shapes,
+             or None if tracking was disabled
+     """
+
+     states: "LearnerState"  # Batched: each array has shape (num_seeds, ...)
+     metrics: Array  # Shape: (num_seeds, num_steps, 3)
+     step_size_history: StepSizeHistory | None
+
+
+ class BatchedNormalizedResult(NamedTuple):
+     """Result from batched normalized learning loop across multiple seeds.
+
+     Used with `run_normalized_learning_loop_batched` for vmap-based GPU parallelization.
+
+     Attributes:
+         states: Batched normalized learner states - each array has shape (num_seeds, ...)
+         metrics: Metrics array with shape (num_seeds, num_steps, 4)
+             where columns are [squared_error, error, mean_step_size, normalizer_mean_var]
+         step_size_history: Optional step-size history with batched shapes,
+             or None if tracking was disabled
+         normalizer_history: Optional normalizer history with batched shapes,
+             or None if tracking was disabled
+     """
+
+     states: "NormalizedLearnerState"  # Batched: each array has shape (num_seeds, ...)
+     metrics: Array  # Shape: (num_seeds, num_steps, 4)
+     step_size_history: StepSizeHistory | None
+     normalizer_history: NormalizerHistory | None
+
+
+ def create_lms_state(step_size: float = 0.01) -> LMSState:
+     """Create initial LMS optimizer state.
+
+     Args:
+         step_size: Fixed learning rate
+
+     Returns:
+         Initial LMS state
+     """
+     return LMSState(step_size=jnp.array(step_size, dtype=jnp.float32))
+
+
+ def create_idbd_state(
+     feature_dim: int,
+     initial_step_size: float = 0.01,
+     meta_step_size: float = 0.01,
+ ) -> IDBDState:
+     """Create initial IDBD optimizer state.
+
+     Args:
+         feature_dim: Dimension of the feature vector
+         initial_step_size: Initial per-weight step-size
+         meta_step_size: Meta learning rate for adapting step-sizes
+
+     Returns:
+         Initial IDBD state
+     """
+     return IDBDState(
+         log_step_sizes=jnp.full(feature_dim, jnp.log(initial_step_size), dtype=jnp.float32),
+         traces=jnp.zeros(feature_dim, dtype=jnp.float32),
+         meta_step_size=jnp.array(meta_step_size, dtype=jnp.float32),
+         bias_step_size=jnp.array(initial_step_size, dtype=jnp.float32),
+         bias_trace=jnp.array(0.0, dtype=jnp.float32),
+     )
+
+
+ def create_autostep_state(
+     feature_dim: int,
+     initial_step_size: float = 0.01,
+     meta_step_size: float = 0.01,
+     normalizer_decay: float = 0.99,
+ ) -> AutostepState:
+     """Create initial Autostep optimizer state.
+
+     Args:
+         feature_dim: Dimension of the feature vector
+         initial_step_size: Initial per-weight step-size
+         meta_step_size: Meta learning rate for adapting step-sizes
+         normalizer_decay: Decay factor for gradient normalizers
+
+     Returns:
+         Initial Autostep state
+     """
+     return AutostepState(
+         step_sizes=jnp.full(feature_dim, initial_step_size, dtype=jnp.float32),
+         traces=jnp.zeros(feature_dim, dtype=jnp.float32),
+         normalizers=jnp.ones(feature_dim, dtype=jnp.float32),
+         meta_step_size=jnp.array(meta_step_size, dtype=jnp.float32),
+         normalizer_decay=jnp.array(normalizer_decay, dtype=jnp.float32),
+         bias_step_size=jnp.array(initial_step_size, dtype=jnp.float32),
+         bias_trace=jnp.array(0.0, dtype=jnp.float32),
+         bias_normalizer=jnp.array(1.0, dtype=jnp.float32),
+     )
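
For orientation, here is a minimal sketch of how these factory functions and `LearnerState` might fit together. The zero-initialized weights and bias and the choice of IDBD as the optimizer are illustrative assumptions, not something this diff prescribes:

```python
import jax.numpy as jnp

from alberta_framework.core.types import (
    LearnerState,
    create_autostep_state,
    create_idbd_state,
    create_lms_state,
)

feature_dim = 10

# Each factory returns an immutable NamedTuple holding that optimizer's state.
lms = create_lms_state(step_size=0.01)
idbd = create_idbd_state(feature_dim, initial_step_size=0.01, meta_step_size=0.01)
autostep = create_autostep_state(feature_dim, normalizer_decay=0.99)

# Hypothetical assembly of a linear learner's state (zero weights/bias assumed).
state = LearnerState(
    weights=jnp.zeros(feature_dim, dtype=jnp.float32),
    bias=jnp.array(0.0, dtype=jnp.float32),
    optimizer_state=idbd,
)
```

Because every state object is an immutable NamedTuple, updates elsewhere in the framework return new state instances rather than mutating these in place.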
File without changes
@@ -0,0 +1,83 @@
+ """Experience streams for continual learning."""
+
+ from alberta_framework.streams.base import ScanStream
+ from alberta_framework.streams.synthetic import (
+     AbruptChangeState,
+     AbruptChangeStream,
+     AbruptChangeTarget,
+     CyclicState,
+     CyclicStream,
+     CyclicTarget,
+     DynamicScaleShiftState,
+     DynamicScaleShiftStream,
+     PeriodicChangeState,
+     PeriodicChangeStream,
+     PeriodicChangeTarget,
+     RandomWalkState,
+     RandomWalkStream,
+     RandomWalkTarget,
+     ScaleDriftState,
+     ScaleDriftStream,
+     ScaledStreamState,
+     ScaledStreamWrapper,
+     SuttonExperiment1State,
+     SuttonExperiment1Stream,
+     make_scale_range,
+ )
+
+ __all__ = [
+     # Protocol
+     "ScanStream",
+     # Stream classes
+     "AbruptChangeState",
+     "AbruptChangeStream",
+     "AbruptChangeTarget",
+     "CyclicState",
+     "CyclicStream",
+     "CyclicTarget",
+     "DynamicScaleShiftState",
+     "DynamicScaleShiftStream",
+     "PeriodicChangeState",
+     "PeriodicChangeStream",
+     "PeriodicChangeTarget",
+     "RandomWalkState",
+     "RandomWalkStream",
+     "RandomWalkTarget",
+     "ScaleDriftState",
+     "ScaleDriftStream",
+     "ScaledStreamState",
+     "ScaledStreamWrapper",
+     "SuttonExperiment1State",
+     "SuttonExperiment1Stream",
+     # Utilities
+     "make_scale_range",
+ ]
+
+ # Gymnasium streams are optional - only export if gymnasium is installed
+ try:
+     from alberta_framework.streams.gymnasium import (
+         GymnasiumStream,
+         PredictionMode,
+         TDStream,
+         collect_trajectory,
+         learn_from_trajectory,
+         learn_from_trajectory_normalized,
+         make_epsilon_greedy_policy,
+         make_gymnasium_stream,
+         make_random_policy,
+     )
+
+     __all__ += [
+         "GymnasiumStream",
+         "PredictionMode",
+         "TDStream",
+         "collect_trajectory",
+         "learn_from_trajectory",
+         "learn_from_trajectory_normalized",
+         "make_epsilon_greedy_policy",
+         "make_gymnasium_stream",
+         "make_random_policy",
+     ]
+ except ImportError:
+     # gymnasium not installed
+     pass
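
Since the Gymnasium re-exports are appended to `__all__` only when the optional dependency imports cleanly, downstream code can probe for them the same way. A small sketch; the flag name is purely illustrative and not part of the package:

```python
# Probe for the optional Gymnasium integration without a hard dependency on it.
try:
    from alberta_framework.streams import TDStream  # re-exported only when gymnasium is installed

    HAS_GYMNASIUM_STREAMS = True  # illustrative flag name
except ImportError:
    HAS_GYMNASIUM_STREAMS = False
```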
@@ -0,0 +1,73 @@
+ """Base protocol for experience streams.
+
+ Experience streams generate temporally-uniform experience for continual learning.
+ Every time step produces a new observation-target pair.
+
+ This module defines the ScanStream protocol for JAX scan-compatible streams.
+ All streams implement pure functions that can be JIT-compiled.
+ """
+
+ from typing import Protocol, TypeVar
+
+ from jax import Array
+
+ from alberta_framework.core.types import TimeStep
+
+ # Type variable for stream state
+ StateT = TypeVar("StateT")
+
+
+ class ScanStream(Protocol[StateT]):
+     """Protocol for JAX scan-compatible experience streams.
+
+     Streams generate temporally-uniform experience for continual learning.
+     Unlike iterator-based streams, ScanStream uses pure functions that
+     can be compiled with JAX's JIT and used with jax.lax.scan.
+
+     The stream should be non-stationary to test continual learning
+     capabilities - the underlying target function changes over time.
+
+     Type Parameters:
+         StateT: The state type maintained by this stream
+
+     Examples
+     --------
+     ```python
+     stream = RandomWalkStream(feature_dim=10, drift_rate=0.001)
+     key = jax.random.key(42)
+     state = stream.init(key)
+     timestep, new_state = stream.step(state, jnp.array(0))
+     ```
+     """
+
+     @property
+     def feature_dim(self) -> int:
+         """Return the dimension of observation vectors."""
+         ...
+
+     def init(self, key: Array) -> StateT:
+         """Initialize stream state.
+
+         Args:
+             key: JAX random key for initialization
+
+         Returns:
+             Initial stream state
+         """
+         ...
+
+     def step(self, state: StateT, idx: Array) -> tuple[TimeStep, StateT]:
+         """Generate one time step. Must be JIT-compatible.
+
+         This is a pure function that takes the current state and step index,
+         and returns a TimeStep along with the updated state. The step index
+         can be used for time-dependent behavior but is often ignored.
+
+         Args:
+             state: Current stream state
+             idx: Current step index (can be ignored for most streams)
+
+         Returns:
+             Tuple of (timestep, new_state)
+         """
+         ...
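
Because `step` is a pure function of `(state, idx)`, a stream satisfying this protocol can be unrolled with `jax.lax.scan`. A minimal sketch, reusing the `RandomWalkStream` constructor arguments from the docstring example above:

```python
import jax
import jax.numpy as jnp

from alberta_framework.streams import RandomWalkStream

stream = RandomWalkStream(feature_dim=10, drift_rate=0.001)
state = stream.init(jax.random.key(42))

def scan_fn(carry, idx):
    # carry holds the stream state; idx is the step index passed to step().
    timestep, new_state = stream.step(carry, idx)
    return new_state, timestep

final_state, timesteps = jax.lax.scan(scan_fn, state, jnp.arange(1000))
# scan stacks the TimeStep outputs along the leading axis, so
# timesteps.observation should have shape (1000, 10) here, assuming each
# observation is a 10-dimensional feature vector.
```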