alberta-framework 0.1.0__py3-none-any.whl

@@ -0,0 +1,198 @@
+ """Type definitions for the Alberta Framework.
+
+ This module defines the core data types used throughout the framework,
+ following JAX conventions with immutable NamedTuples for state management.
+ """
+
+ from typing import NamedTuple
+
+ import jax.numpy as jnp
+ from jax import Array
+
+ # Type aliases for clarity
+ Observation = Array  # x_t: feature vector
+ Target = Array  # y*_t: desired output
+ Prediction = Array  # y_t: model output
+ Reward = float  # r_t: scalar reward
+
+
+ class TimeStep(NamedTuple):
+     """Single experience from an experience stream.
+
+     Attributes:
+         observation: Feature vector x_t
+         target: Desired output y*_t (for supervised learning)
+     """
+
+     observation: Observation
+     target: Target
+
+
+ class LearnerState(NamedTuple):
+     """State for a linear learner.
+
+     Attributes:
+         weights: Weight vector for linear prediction
+         bias: Bias term
+         optimizer_state: State maintained by the optimizer
+     """
+
+     weights: Array
+     bias: Array
+     optimizer_state: "LMSState | IDBDState | AutostepState"
+
+
45
+ class LMSState(NamedTuple):
+     """State for the LMS (Least Mean Square) optimizer.
+
+     LMS uses a fixed step-size, so the state only tracks the step-size parameter.
+
+     Attributes:
+         step_size: Fixed learning rate alpha
+     """
+
+     step_size: Array
+
+
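For orientation, this is the classic delta rule that a fixed step-size optimizer of this kind performs. The sketch below is illustrative only: `lms_step` and its signature are hypothetical, and the framework's actual update function lives elsewhere.

    import jax.numpy as jnp

    def lms_step(weights, bias, state, x, target):
        # Linear prediction: y_t = w . x_t + b
        error = target - (weights @ x + bias)
        # Delta rule: move parameters along the error-scaled features
        new_weights = weights + state.step_size * error * x
        new_bias = bias + state.step_size * error
        return new_weights, new_bias
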
57
+ class IDBDState(NamedTuple):
+     """State for the IDBD (Incremental Delta-Bar-Delta) optimizer.
+
+     IDBD maintains per-weight adaptive step-sizes that are meta-learned
+     based on the correlation of successive gradients.
+
+     Reference: Sutton 1992, "Adapting Bias by Gradient Descent"
+
+     Attributes:
+         log_step_sizes: Log of per-weight step-sizes (log alpha_i)
+         traces: Per-weight traces h_i for gradient correlation
+         meta_step_size: Meta learning rate beta for adapting step-sizes
+         bias_step_size: Step-size for the bias term
+         bias_trace: Trace for the bias term
+     """
+
+     log_step_sizes: Array  # log(alpha_i) for numerical stability
+     traces: Array  # h_i: decaying trace of recent weight updates
+     meta_step_size: Array  # beta: step-size for the step-sizes
+     bias_step_size: Array  # Step-size for bias
+     bias_trace: Array  # Trace for bias
+
+
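The update this state supports, sketched from Sutton (1992): when successive gradients for weight i agree in sign, the trace h_i and the current gradient correlate, and log alpha_i grows. `idbd_step` is a hypothetical helper; the framework's actual update may differ in details, and bias handling is omitted here.

    import jax.numpy as jnp

    def idbd_step(weights, state, x, error):
        # Meta update: raise log alpha_i when successive gradients correlate
        log_alpha = state.log_step_sizes + state.meta_step_size * error * x * state.traces
        alpha = jnp.exp(log_alpha)
        # Per-weight delta rule with the adapted step-sizes
        new_weights = weights + alpha * error * x
        # h_i: decayed memory of recent updates to w_i (clipped at zero)
        new_traces = state.traces * jnp.maximum(1.0 - alpha * x * x, 0.0) + alpha * error * x
        return new_weights, state._replace(log_step_sizes=log_alpha, traces=new_traces)
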
80
+ class AutostepState(NamedTuple):
+     """State for the Autostep optimizer.
+
+     Autostep is a tuning-free step-size adaptation algorithm that normalizes
+     gradients to prevent large updates and adapts step-sizes based on
+     gradient correlation.
+
+     Reference: Mahmood et al. 2012, "Tuning-free step-size adaptation"
+
+     Attributes:
+         step_sizes: Per-weight step-sizes (alpha_i)
+         traces: Per-weight traces for gradient correlation (h_i)
+         normalizers: Running max absolute gradient per weight (v_i)
+         meta_step_size: Meta learning rate mu for adapting step-sizes
+         normalizer_decay: Decay factor for the normalizer (tau)
+         bias_step_size: Step-size for the bias term
+         bias_trace: Trace for the bias term
+         bias_normalizer: Normalizer for the bias gradient
+     """
+
+     step_sizes: Array  # alpha_i
+     traces: Array  # h_i
+     normalizers: Array  # v_i: running max of |gradient|
+     meta_step_size: Array  # mu
+     normalizer_decay: Array  # tau
+     bias_step_size: Array
+     bias_trace: Array
+     bias_normalizer: Array
+
+
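A sketch of one Autostep update following Mahmood et al. (2012). `autostep_step` is illustrative only, the normalizer decay is shown in one common form, and the framework's exact parametrization of tau may differ; bias terms are omitted.

    import jax.numpy as jnp

    def autostep_step(weights, state, x, error):
        g = error * x  # per-weight gradient
        # v_i: running maximum of |gradient|, slowly decayed by tau
        v = jnp.maximum(jnp.abs(g), state.normalizer_decay * state.normalizers)
        # Multiplicative adaptation on the normalized gradient correlation
        alpha = state.step_sizes * jnp.exp(state.meta_step_size * g * state.traces / v)
        # Bound the effective update so that sum_i alpha_i x_i^2 <= 1
        alpha = alpha / jnp.maximum(jnp.sum(alpha * x * x), 1.0)
        new_weights = weights + alpha * g
        new_traces = state.traces * (1.0 - alpha * x * x) + alpha * g
        return new_weights, state._replace(
            step_sizes=alpha, traces=new_traces, normalizers=v
        )
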
110
+ class StepSizeTrackingConfig(NamedTuple):
+     """Configuration for recording per-weight step-sizes during training.
+
+     Attributes:
+         interval: Record step-sizes every N steps
+         include_bias: Whether to also record the bias step-size
+     """
+
+     interval: int
+     include_bias: bool = True
+
+
+ class StepSizeHistory(NamedTuple):
+     """History of per-weight step-sizes recorded during training.
+
+     Attributes:
+         step_sizes: Per-weight step-sizes at each recording, shape (num_recordings, num_weights)
+         bias_step_sizes: Bias step-sizes at each recording, shape (num_recordings,) or None
+         recording_indices: Step indices where recordings were made, shape (num_recordings,)
+     """
+
+     step_sizes: Array  # (num_recordings, num_weights)
+     bias_step_sizes: Array | None  # (num_recordings,) or None
+     recording_indices: Array  # (num_recordings,)
+
+
136
+ def create_lms_state(step_size: float = 0.01) -> LMSState:
+     """Create initial LMS optimizer state.
+
+     Args:
+         step_size: Fixed learning rate
+
+     Returns:
+         Initial LMS state
+     """
+     return LMSState(step_size=jnp.array(step_size, dtype=jnp.float32))
+
+
+ def create_idbd_state(
+     feature_dim: int,
+     initial_step_size: float = 0.01,
+     meta_step_size: float = 0.01,
+ ) -> IDBDState:
+     """Create initial IDBD optimizer state.
+
+     Args:
+         feature_dim: Dimension of the feature vector
+         initial_step_size: Initial per-weight step-size
+         meta_step_size: Meta learning rate for adapting step-sizes
+
+     Returns:
+         Initial IDBD state
+     """
+     return IDBDState(
+         log_step_sizes=jnp.full(feature_dim, jnp.log(initial_step_size), dtype=jnp.float32),
+         traces=jnp.zeros(feature_dim, dtype=jnp.float32),
+         meta_step_size=jnp.array(meta_step_size, dtype=jnp.float32),
+         bias_step_size=jnp.array(initial_step_size, dtype=jnp.float32),
+         bias_trace=jnp.array(0.0, dtype=jnp.float32),
+     )
+
+
+ def create_autostep_state(
+     feature_dim: int,
+     initial_step_size: float = 0.01,
+     meta_step_size: float = 0.01,
+     normalizer_decay: float = 0.99,
+ ) -> AutostepState:
+     """Create initial Autostep optimizer state.
+
+     Args:
+         feature_dim: Dimension of the feature vector
+         initial_step_size: Initial per-weight step-size
+         meta_step_size: Meta learning rate for adapting step-sizes
+         normalizer_decay: Decay factor for gradient normalizers
+
+     Returns:
+         Initial Autostep state
+     """
+     return AutostepState(
+         step_sizes=jnp.full(feature_dim, initial_step_size, dtype=jnp.float32),
+         traces=jnp.zeros(feature_dim, dtype=jnp.float32),
+         normalizers=jnp.ones(feature_dim, dtype=jnp.float32),
+         meta_step_size=jnp.array(meta_step_size, dtype=jnp.float32),
+         normalizer_decay=jnp.array(normalizer_decay, dtype=jnp.float32),
+         bias_step_size=jnp.array(initial_step_size, dtype=jnp.float32),
+         bias_trace=jnp.array(0.0, dtype=jnp.float32),
+         bias_normalizer=jnp.array(1.0, dtype=jnp.float32),
+     )
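
Taken together, a learner might be assembled from these factories along the following lines. The LearnerState construction here is a sketch; the framework may provide its own learner constructor.

    import jax.numpy as jnp
    from alberta_framework.core.types import LearnerState, create_idbd_state

    feature_dim = 10
    learner = LearnerState(
        weights=jnp.zeros(feature_dim, dtype=jnp.float32),
        bias=jnp.zeros((), dtype=jnp.float32),
        optimizer_state=create_idbd_state(feature_dim, initial_step_size=0.05),
    )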
@@ -0,0 +1,83 @@
+ """Experience streams for continual learning."""
+
+ from alberta_framework.streams.base import ScanStream
+ from alberta_framework.streams.synthetic import (
+     AbruptChangeState,
+     AbruptChangeStream,
+     AbruptChangeTarget,
+     CyclicState,
+     CyclicStream,
+     CyclicTarget,
+     DynamicScaleShiftState,
+     DynamicScaleShiftStream,
+     PeriodicChangeState,
+     PeriodicChangeStream,
+     PeriodicChangeTarget,
+     RandomWalkState,
+     RandomWalkStream,
+     RandomWalkTarget,
+     ScaleDriftState,
+     ScaleDriftStream,
+     ScaledStreamState,
+     ScaledStreamWrapper,
+     SuttonExperiment1State,
+     SuttonExperiment1Stream,
+     make_scale_range,
+ )
+
+ __all__ = [
+     # Protocol
+     "ScanStream",
+     # Stream classes
+     "AbruptChangeState",
+     "AbruptChangeStream",
+     "AbruptChangeTarget",
+     "CyclicState",
+     "CyclicStream",
+     "CyclicTarget",
+     "DynamicScaleShiftState",
+     "DynamicScaleShiftStream",
+     "PeriodicChangeState",
+     "PeriodicChangeStream",
+     "PeriodicChangeTarget",
+     "RandomWalkState",
+     "RandomWalkStream",
+     "RandomWalkTarget",
+     "ScaleDriftState",
+     "ScaleDriftStream",
+     "ScaledStreamState",
+     "ScaledStreamWrapper",
+     "SuttonExperiment1State",
+     "SuttonExperiment1Stream",
+     # Utilities
+     "make_scale_range",
+ ]
+
+ # Gymnasium streams are optional - only export if gymnasium is installed
+ try:
+     from alberta_framework.streams.gymnasium import (
+         GymnasiumStream,
+         PredictionMode,
+         TDStream,
+         collect_trajectory,
+         learn_from_trajectory,
+         learn_from_trajectory_normalized,
+         make_epsilon_greedy_policy,
+         make_gymnasium_stream,
+         make_random_policy,
+     )
+
+     __all__ += [
+         "GymnasiumStream",
+         "PredictionMode",
+         "TDStream",
+         "collect_trajectory",
+         "learn_from_trajectory",
+         "learn_from_trajectory_normalized",
+         "make_epsilon_greedy_policy",
+         "make_gymnasium_stream",
+         "make_random_policy",
+     ]
+ except ImportError:
+     # gymnasium not installed
+     pass
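
Because the gymnasium exports are conditional, downstream code can feature-detect them with the same guard; a minimal sketch:

    try:
        from alberta_framework.streams import make_gymnasium_stream
    except ImportError:  # gymnasium extra not installed
        make_gymnasium_stream = None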
@@ -0,0 +1,70 @@
+ """Base protocol for experience streams.
+
+ Experience streams generate temporally-uniform experience for continual learning.
+ Every time step produces a new observation-target pair.
+
+ This module defines the ScanStream protocol for JAX scan-compatible streams.
+ All streams implement pure functions that can be JIT-compiled.
+ """
+
+ from typing import Protocol, TypeVar
+
+ from jax import Array
+
+ from alberta_framework.core.types import TimeStep
+
+ # Type variable for stream state
+ StateT = TypeVar("StateT")
+
+
+ class ScanStream(Protocol[StateT]):
+     """Protocol for JAX scan-compatible experience streams.
+
+     Streams generate temporally-uniform experience for continual learning.
+     Unlike iterator-based streams, ScanStream uses pure functions that
+     can be compiled with JAX's JIT and used with jax.lax.scan.
+
+     The stream should be non-stationary to test continual learning
+     capabilities - the underlying target function changes over time.
+
+     Type Parameters:
+         StateT: The state type maintained by this stream
+
+     Example:
+         >>> stream = RandomWalkStream(feature_dim=10, drift_rate=0.001)
+         >>> key = jax.random.key(42)
+         >>> state = stream.init(key)
+         >>> timestep, new_state = stream.step(state, jnp.array(0))
+     """
+
+     @property
+     def feature_dim(self) -> int:
+         """Return the dimension of observation vectors."""
+         ...
+
+     def init(self, key: Array) -> StateT:
+         """Initialize stream state.
+
+         Args:
+             key: JAX random key for initialization
+
+         Returns:
+             Initial stream state
+         """
+         ...
+
+     def step(self, state: StateT, idx: Array) -> tuple[TimeStep, StateT]:
+         """Generate one time step. Must be JIT-compatible.
+
+         This is a pure function that takes the current state and step index,
+         and returns a TimeStep along with the updated state. The step index
+         can be used for time-dependent behavior but is often ignored.
+
+         Args:
+             state: Current stream state
+             idx: Current step index (can be ignored for most streams)
+
+         Returns:
+             Tuple of (timestep, new_state)
+         """
+         ...
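
Because step() is pure, an entire experience stream can be unrolled with jax.lax.scan, which is the point of this protocol. A sketch of such a driver; `rollout` is a hypothetical helper, not part of this module:

    import jax
    import jax.numpy as jnp

    def rollout(stream, key, num_steps):
        def body(state, idx):
            timestep, new_state = stream.step(state, idx)
            return new_state, timestep  # scan stacks the emitted TimeSteps
        init_state = stream.init(key)
        _, timesteps = jax.lax.scan(body, init_state, jnp.arange(num_steps))
        return timesteps  # TimeStep with a leading time axis of length num_steps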