alberta-framework 0.2.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- alberta_framework/__init__.py +225 -0
- alberta_framework/core/__init__.py +27 -0
- alberta_framework/core/learners.py +1070 -0
- alberta_framework/core/normalizers.py +192 -0
- alberta_framework/core/optimizers.py +424 -0
- alberta_framework/core/types.py +271 -0
- alberta_framework/py.typed +0 -0
- alberta_framework/streams/__init__.py +83 -0
- alberta_framework/streams/base.py +73 -0
- alberta_framework/streams/gymnasium.py +655 -0
- alberta_framework/streams/synthetic.py +1001 -0
- alberta_framework/utils/__init__.py +113 -0
- alberta_framework/utils/experiments.py +335 -0
- alberta_framework/utils/export.py +509 -0
- alberta_framework/utils/metrics.py +112 -0
- alberta_framework/utils/statistics.py +527 -0
- alberta_framework/utils/timing.py +144 -0
- alberta_framework/utils/visualization.py +571 -0
- alberta_framework-0.2.2.dist-info/METADATA +206 -0
- alberta_framework-0.2.2.dist-info/RECORD +22 -0
- alberta_framework-0.2.2.dist-info/WHEEL +4 -0
- alberta_framework-0.2.2.dist-info/licenses/LICENSE +190 -0
@@ -0,0 +1,1070 @@

"""Learning units for continual learning.

Implements learners that combine function approximation with optimizers
for temporally-uniform learning. Uses JAX's scan for efficient JIT-compiled
training loops.
"""

from typing import NamedTuple, cast

import jax
import jax.numpy as jnp
from jax import Array

from alberta_framework.core.normalizers import NormalizerState, OnlineNormalizer
from alberta_framework.core.optimizers import LMS, Optimizer
from alberta_framework.core.types import (
    AutostepState,
    BatchedLearningResult,
    BatchedNormalizedResult,
    IDBDState,
    LearnerState,
    LMSState,
    NormalizerHistory,
    NormalizerTrackingConfig,
    Observation,
    Prediction,
    StepSizeHistory,
    StepSizeTrackingConfig,
    Target,
)
from alberta_framework.streams.base import ScanStream

# Type alias for any optimizer type
AnyOptimizer = Optimizer[LMSState] | Optimizer[IDBDState] | Optimizer[AutostepState]

class UpdateResult(NamedTuple):
    """Result of a learner update step.

    Attributes:
        state: Updated learner state
        prediction: Prediction made before update
        error: Prediction error
        metrics: Array of metrics [squared_error, error, ...]
    """

    state: LearnerState
    prediction: Prediction
    error: Array
    metrics: Array


class LinearLearner:
    """Linear function approximator with pluggable optimizer.

    Computes predictions as: `y = w @ x + b`

    The learner maintains weights and bias, delegating the adaptation
    of learning rates to the optimizer (e.g., LMS or IDBD).

    This follows the Alberta Plan philosophy of temporal uniformity:
    every component updates at every time step.

    Attributes:
        optimizer: The optimizer to use for weight updates
    """

    def __init__(self, optimizer: AnyOptimizer | None = None):
        """Initialize the linear learner.

        Args:
            optimizer: Optimizer for weight updates. Defaults to LMS(0.01)
        """
        self._optimizer: AnyOptimizer = optimizer or LMS(step_size=0.01)

    def init(self, feature_dim: int) -> LearnerState:
        """Initialize learner state.

        Args:
            feature_dim: Dimension of the input feature vector

        Returns:
            Initial learner state with zero weights and bias
        """
        optimizer_state = self._optimizer.init(feature_dim)

        return LearnerState(
            weights=jnp.zeros(feature_dim, dtype=jnp.float32),
            bias=jnp.array(0.0, dtype=jnp.float32),
            optimizer_state=optimizer_state,
        )

    def predict(self, state: LearnerState, observation: Observation) -> Prediction:
        """Compute prediction for an observation.

        Args:
            state: Current learner state
            observation: Input feature vector

        Returns:
            Scalar prediction `y = w @ x + b`
        """
        return jnp.atleast_1d(jnp.dot(state.weights, observation) + state.bias)

    def update(
        self,
        state: LearnerState,
        observation: Observation,
        target: Target,
    ) -> UpdateResult:
        """Update learner given observation and target.

        Performs one step of the learning algorithm:
        1. Compute prediction
        2. Compute error
        3. Get weight updates from optimizer
        4. Apply updates to weights and bias

        Args:
            state: Current learner state
            observation: Input feature vector
            target: Desired output

        Returns:
            UpdateResult with new state, prediction, error, and metrics
        """
        # Make prediction
        prediction = self.predict(state, observation)

        # Compute error (target - prediction)
        error = jnp.squeeze(target) - jnp.squeeze(prediction)

        # Get update from optimizer
        # Note: type ignore needed because we can't statically prove optimizer_state
        # matches the optimizer's expected state type (though they will at runtime)
        opt_update = self._optimizer.update(
            state.optimizer_state,  # type: ignore[arg-type]
            error,
            observation,
        )

        # Apply updates
        new_weights = state.weights + opt_update.weight_delta
        new_bias = state.bias + opt_update.bias_delta

        new_state = LearnerState(
            weights=new_weights,
            bias=new_bias,
            optimizer_state=opt_update.new_state,
        )

        # Pack metrics as array for scan compatibility
        # Format: [squared_error, error, mean_step_size (if adaptive)]
        squared_error = error**2
        mean_step_size = opt_update.metrics.get("mean_step_size", 0.0)
        metrics = jnp.array([squared_error, error, mean_step_size], dtype=jnp.float32)

        return UpdateResult(
            state=new_state,
            prediction=prediction,
            error=jnp.atleast_1d(error),
            metrics=metrics,
        )

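The class above is driven one step at a time through `init`, `predict`, and `update`. A minimal usage sketch follows; it assumes `LinearLearner` and `LMS` are re-exported at the package top level, as the docstring examples further down suggest, otherwise they can be imported from `alberta_framework.core.learners` and `alberta_framework.core.optimizers`.

```python
import jax.numpy as jnp

from alberta_framework import LMS, LinearLearner  # assumed top-level re-exports

# One-dimensional toy problem: the learner should drive w toward 2.0.
learner = LinearLearner(optimizer=LMS(step_size=0.1))
state = learner.init(feature_dim=1)

observation = jnp.array([1.0])
target = jnp.array(2.0)

for _ in range(100):
    result = learner.update(state, observation, target)
    state = result.state  # temporally uniform: state is replaced every step

print(learner.predict(state, observation))  # approaches 2.0
print(result.metrics)  # [squared_error, error, mean_step_size]
```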
def run_learning_loop[StreamStateT](
    learner: LinearLearner,
    stream: ScanStream[StreamStateT],
    num_steps: int,
    key: Array,
    learner_state: LearnerState | None = None,
    step_size_tracking: StepSizeTrackingConfig | None = None,
) -> tuple[LearnerState, Array] | tuple[LearnerState, Array, StepSizeHistory]:
    """Run the learning loop using jax.lax.scan.

    This is a JIT-compiled learning loop that uses scan for efficiency.
    It returns metrics as a fixed-size array rather than a list of dicts.

    Args:
        learner: The learner to train
        stream: Experience stream providing (observation, target) pairs
        num_steps: Number of learning steps to run
        key: JAX random key for stream initialization
        learner_state: Initial state (if None, will be initialized from stream)
        step_size_tracking: Optional config for recording per-weight step-sizes.
            When provided, returns a 3-tuple including StepSizeHistory.

    Returns:
        If step_size_tracking is None:
            Tuple of (final_state, metrics_array) where metrics_array has shape
            (num_steps, 3) with columns [squared_error, error, mean_step_size]
        If step_size_tracking is provided:
            Tuple of (final_state, metrics_array, step_size_history)

    Raises:
        ValueError: If step_size_tracking.interval is less than 1 or greater than num_steps
    """
    # Validate tracking config
    if step_size_tracking is not None:
        if step_size_tracking.interval < 1:
            raise ValueError(
                f"step_size_tracking.interval must be >= 1, got {step_size_tracking.interval}"
            )
        if step_size_tracking.interval > num_steps:
            raise ValueError(
                f"step_size_tracking.interval ({step_size_tracking.interval}) "
                f"must be <= num_steps ({num_steps})"
            )

    # Initialize states
    if learner_state is None:
        learner_state = learner.init(stream.feature_dim)
    stream_state = stream.init(key)

    feature_dim = stream.feature_dim

    if step_size_tracking is None:
        # Original behavior without tracking
        def step_fn(
            carry: tuple[LearnerState, StreamStateT], idx: Array
        ) -> tuple[tuple[LearnerState, StreamStateT], Array]:
            l_state, s_state = carry
            timestep, new_s_state = stream.step(s_state, idx)
            result = learner.update(l_state, timestep.observation, timestep.target)
            return (result.state, new_s_state), result.metrics

        (final_learner, _), metrics = jax.lax.scan(
            step_fn, (learner_state, stream_state), jnp.arange(num_steps)
        )

        return final_learner, metrics

    else:
        # Step-size tracking enabled
        interval = step_size_tracking.interval
        include_bias = step_size_tracking.include_bias
        num_recordings = num_steps // interval

        # Pre-allocate history arrays
        step_size_history = jnp.zeros((num_recordings, feature_dim), dtype=jnp.float32)
        bias_history = (
            jnp.zeros(num_recordings, dtype=jnp.float32) if include_bias else None
        )
        recording_indices = jnp.zeros(num_recordings, dtype=jnp.int32)

        # Check if we need to track Autostep normalizers
        # We detect this at trace time by checking the initial optimizer state
        track_normalizers = hasattr(learner_state.optimizer_state, "normalizers")
        normalizer_history = (
            jnp.zeros((num_recordings, feature_dim), dtype=jnp.float32)
            if track_normalizers
            else None
        )

        def step_fn_with_tracking(
            carry: tuple[
                LearnerState, StreamStateT, Array, Array | None, Array, Array | None
            ],
            idx: Array,
        ) -> tuple[
            tuple[LearnerState, StreamStateT, Array, Array | None, Array, Array | None],
            Array,
        ]:
            l_state, s_state, ss_history, b_history, rec_indices, norm_history = carry

            # Perform learning step
            timestep, new_s_state = stream.step(s_state, idx)
            result = learner.update(l_state, timestep.observation, timestep.target)

            # Check if we should record at this step (idx % interval == 0)
            should_record = (idx % interval) == 0
            recording_idx = idx // interval

            # Extract current step-sizes
            # Use hasattr checks at trace time (this works because the type is fixed)
            opt_state = result.state.optimizer_state
            if hasattr(opt_state, "log_step_sizes"):
                # IDBD stores log step-sizes
                weight_ss = jnp.exp(opt_state.log_step_sizes)  # type: ignore[union-attr]
                bias_ss = opt_state.bias_step_size  # type: ignore[union-attr]
            elif hasattr(opt_state, "step_sizes"):
                # Autostep stores step-sizes directly
                weight_ss = opt_state.step_sizes  # type: ignore[union-attr]
                bias_ss = opt_state.bias_step_size  # type: ignore[union-attr]
            else:
                # LMS has a single fixed step-size
                weight_ss = jnp.full(feature_dim, opt_state.step_size)
                bias_ss = opt_state.step_size

            # Conditionally update history arrays
            new_ss_history = jax.lax.cond(
                should_record,
                lambda _: ss_history.at[recording_idx].set(weight_ss),
                lambda _: ss_history,
                None,
            )

            new_b_history = b_history
            if b_history is not None:
                new_b_history = jax.lax.cond(
                    should_record,
                    lambda _: b_history.at[recording_idx].set(bias_ss),
                    lambda _: b_history,
                    None,
                )

            new_rec_indices = jax.lax.cond(
                should_record,
                lambda _: rec_indices.at[recording_idx].set(idx),
                lambda _: rec_indices,
                None,
            )

            # Track Autostep normalizers (v_i) if applicable
            new_norm_history = norm_history
            if norm_history is not None and hasattr(opt_state, "normalizers"):
                new_norm_history = jax.lax.cond(
                    should_record,
                    lambda _: norm_history.at[recording_idx].set(
                        opt_state.normalizers  # type: ignore[union-attr]
                    ),
                    lambda _: norm_history,
                    None,
                )

            return (
                result.state,
                new_s_state,
                new_ss_history,
                new_b_history,
                new_rec_indices,
                new_norm_history,
            ), result.metrics

        initial_carry = (
            learner_state,
            stream_state,
            step_size_history,
            bias_history,
            recording_indices,
            normalizer_history,
        )

        (
            final_learner,
            _,
            final_ss_history,
            final_b_history,
            final_rec_indices,
            final_norm_history,
        ), metrics = jax.lax.scan(
            step_fn_with_tracking, initial_carry, jnp.arange(num_steps)
        )

        history = StepSizeHistory(
            step_sizes=final_ss_history,
            bias_step_sizes=final_b_history,
            recording_indices=final_rec_indices,
            normalizers=final_norm_history,
        )

        return final_learner, metrics, history

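A short sketch of calling `run_learning_loop` with step-size tracking enabled. It assumes `run_learning_loop`, `IDBD`, and `RandomWalkStream` are exported at the package top level (as in the batched examples later in this file), and that `StepSizeTrackingConfig` is built from the `interval` and `include_bias` fields the loop reads above; the exact constructor signature is an assumption.

```python
import jax.random as jr

from alberta_framework import IDBD, LinearLearner, RandomWalkStream, run_learning_loop
from alberta_framework.core.types import StepSizeTrackingConfig

stream = RandomWalkStream(feature_dim=10)
learner = LinearLearner(optimizer=IDBD())

# Snapshot the per-weight step-sizes every 100 steps (constructor fields assumed).
tracking = StepSizeTrackingConfig(interval=100, include_bias=True)

final_state, metrics, history = run_learning_loop(
    learner, stream, num_steps=10_000, key=jr.key(0), step_size_tracking=tracking
)

# metrics has shape (10000, 3): [squared_error, error, mean_step_size]
# history.step_sizes has shape (100, 10): one row per recorded snapshot
print(metrics[:, 0].mean(), history.step_sizes[-1])
```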
class NormalizedLearnerState(NamedTuple):
    """State for a learner with online feature normalization.

    Attributes:
        learner_state: Underlying learner state (weights, bias, optimizer)
        normalizer_state: Online normalizer state (mean, var estimates)
    """

    learner_state: LearnerState
    normalizer_state: NormalizerState


class NormalizedUpdateResult(NamedTuple):
    """Result of a normalized learner update step.

    Attributes:
        state: Updated normalized learner state
        prediction: Prediction made before update
        error: Prediction error
        metrics: Array of metrics [squared_error, error, mean_step_size, normalizer_mean_var]
    """

    state: NormalizedLearnerState
    prediction: Prediction
    error: Array
    metrics: Array


class NormalizedLinearLearner:
    """Linear learner with online feature normalization.

    Wraps a LinearLearner with online feature normalization, following
    the Alberta Plan's approach to handling varying feature scales.

    Normalization is applied to features before prediction and learning:
        x_normalized = (x - mean) / (std + epsilon)

    The normalizer statistics update at every time step, maintaining
    temporal uniformity.

    Attributes:
        learner: Underlying linear learner
        normalizer: Online feature normalizer
    """

    def __init__(
        self,
        optimizer: AnyOptimizer | None = None,
        normalizer: OnlineNormalizer | None = None,
    ):
        """Initialize the normalized linear learner.

        Args:
            optimizer: Optimizer for weight updates. Defaults to LMS(0.01)
            normalizer: Feature normalizer. Defaults to OnlineNormalizer()
        """
        self._learner = LinearLearner(optimizer=optimizer or LMS(step_size=0.01))
        self._normalizer = normalizer or OnlineNormalizer()

    def init(self, feature_dim: int) -> NormalizedLearnerState:
        """Initialize normalized learner state.

        Args:
            feature_dim: Dimension of the input feature vector

        Returns:
            Initial state with zero weights and unit variance estimates
        """
        return NormalizedLearnerState(
            learner_state=self._learner.init(feature_dim),
            normalizer_state=self._normalizer.init(feature_dim),
        )

    def predict(
        self,
        state: NormalizedLearnerState,
        observation: Observation,
    ) -> Prediction:
        """Compute prediction for an observation.

        Normalizes the observation using current statistics before prediction.

        Args:
            state: Current normalized learner state
            observation: Raw (unnormalized) input feature vector

        Returns:
            Scalar prediction y = w @ normalize(x) + b
        """
        normalized_obs = self._normalizer.normalize_only(
            state.normalizer_state, observation
        )
        return self._learner.predict(state.learner_state, normalized_obs)

    def update(
        self,
        state: NormalizedLearnerState,
        observation: Observation,
        target: Target,
    ) -> NormalizedUpdateResult:
        """Update learner given observation and target.

        Performs one step of the learning algorithm:
        1. Normalize observation (and update normalizer statistics)
        2. Compute prediction using normalized features
        3. Compute error
        4. Get weight updates from optimizer
        5. Apply updates

        Args:
            state: Current normalized learner state
            observation: Raw (unnormalized) input feature vector
            target: Desired output

        Returns:
            NormalizedUpdateResult with new state, prediction, error, and metrics
        """
        # Normalize observation and update normalizer state
        normalized_obs, new_normalizer_state = self._normalizer.normalize(
            state.normalizer_state, observation
        )

        # Delegate to underlying learner
        result = self._learner.update(
            state.learner_state,
            normalized_obs,
            target,
        )

        # Build combined state
        new_state = NormalizedLearnerState(
            learner_state=result.state,
            normalizer_state=new_normalizer_state,
        )

        # Add normalizer metrics to the metrics array
        normalizer_mean_var = jnp.mean(new_normalizer_state.var)
        metrics = jnp.concatenate([result.metrics, jnp.array([normalizer_mean_var])])

        return NormalizedUpdateResult(
            state=new_state,
            prediction=result.prediction,
            error=result.error,
            metrics=metrics,
        )

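A minimal sketch of the normalized learner on raw features with very different scales, using only the `init`/`update` API defined above. The top-level import path mirrors the docstring examples below and is assumed.

```python
import jax.numpy as jnp

from alberta_framework import LMS, NormalizedLinearLearner  # assumed top-level re-exports

learner = NormalizedLinearLearner(optimizer=LMS(step_size=0.05))
state = learner.init(feature_dim=3)

# Raw features on wildly different scales; the online normalizer rescales them
# before the linear update, so one fixed step-size can serve all three weights.
observation = jnp.array([1.0, 100.0, 0.01])
target = jnp.array(5.0)

result = learner.update(state, observation, target)
state = result.state

# Metrics layout per NormalizedUpdateResult:
# [squared_error, error, mean_step_size, normalizer_mean_var]
print(result.metrics)
```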
def run_normalized_learning_loop[StreamStateT](
    learner: NormalizedLinearLearner,
    stream: ScanStream[StreamStateT],
    num_steps: int,
    key: Array,
    learner_state: NormalizedLearnerState | None = None,
    step_size_tracking: StepSizeTrackingConfig | None = None,
    normalizer_tracking: NormalizerTrackingConfig | None = None,
) -> (
    tuple[NormalizedLearnerState, Array]
    | tuple[NormalizedLearnerState, Array, StepSizeHistory]
    | tuple[NormalizedLearnerState, Array, NormalizerHistory]
    | tuple[NormalizedLearnerState, Array, StepSizeHistory, NormalizerHistory]
):
    """Run the learning loop with normalization using jax.lax.scan.

    Args:
        learner: The normalized learner to train
        stream: Experience stream providing (observation, target) pairs
        num_steps: Number of learning steps to run
        key: JAX random key for stream initialization
        learner_state: Initial state (if None, will be initialized from stream)
        step_size_tracking: Optional config for recording per-weight step-sizes.
            When provided, returns StepSizeHistory including Autostep normalizers if applicable.
        normalizer_tracking: Optional config for recording per-feature normalizer state.
            When provided, returns NormalizerHistory with means and variances over time.

    Returns:
        If no tracking:
            Tuple of (final_state, metrics_array) where metrics_array has shape
            (num_steps, 4) with columns [squared_error, error, mean_step_size, normalizer_mean_var]
        If step_size_tracking only:
            Tuple of (final_state, metrics_array, step_size_history)
        If normalizer_tracking only:
            Tuple of (final_state, metrics_array, normalizer_history)
        If both:
            Tuple of (final_state, metrics_array, step_size_history, normalizer_history)

    Raises:
        ValueError: If tracking interval is invalid
    """
    # Validate tracking configs
    if step_size_tracking is not None:
        if step_size_tracking.interval < 1:
            raise ValueError(
                f"step_size_tracking.interval must be >= 1, got {step_size_tracking.interval}"
            )
        if step_size_tracking.interval > num_steps:
            raise ValueError(
                f"step_size_tracking.interval ({step_size_tracking.interval}) "
                f"must be <= num_steps ({num_steps})"
            )

    if normalizer_tracking is not None:
        if normalizer_tracking.interval < 1:
            raise ValueError(
                f"normalizer_tracking.interval must be >= 1, got {normalizer_tracking.interval}"
            )
        if normalizer_tracking.interval > num_steps:
            raise ValueError(
                f"normalizer_tracking.interval ({normalizer_tracking.interval}) "
                f"must be <= num_steps ({num_steps})"
            )

    # Initialize states
    if learner_state is None:
        learner_state = learner.init(stream.feature_dim)
    stream_state = stream.init(key)

    feature_dim = stream.feature_dim

    # No tracking - simple case
    if step_size_tracking is None and normalizer_tracking is None:

        def step_fn(
            carry: tuple[NormalizedLearnerState, StreamStateT], idx: Array
        ) -> tuple[tuple[NormalizedLearnerState, StreamStateT], Array]:
            l_state, s_state = carry
            timestep, new_s_state = stream.step(s_state, idx)
            result = learner.update(l_state, timestep.observation, timestep.target)
            return (result.state, new_s_state), result.metrics

        (final_learner, _), metrics = jax.lax.scan(
            step_fn, (learner_state, stream_state), jnp.arange(num_steps)
        )

        return final_learner, metrics

    # Tracking enabled - need to set up history arrays
    ss_interval = step_size_tracking.interval if step_size_tracking else num_steps + 1
    norm_interval = (
        normalizer_tracking.interval if normalizer_tracking else num_steps + 1
    )

    ss_num_recordings = num_steps // ss_interval if step_size_tracking else 0
    norm_num_recordings = num_steps // norm_interval if normalizer_tracking else 0

    # Pre-allocate step-size history arrays
    ss_history = (
        jnp.zeros((ss_num_recordings, feature_dim), dtype=jnp.float32)
        if step_size_tracking
        else None
    )
    ss_bias_history = (
        jnp.zeros(ss_num_recordings, dtype=jnp.float32)
        if step_size_tracking and step_size_tracking.include_bias
        else None
    )
    ss_rec_indices = (
        jnp.zeros(ss_num_recordings, dtype=jnp.int32) if step_size_tracking else None
    )

    # Check if we need to track Autostep normalizers
    track_autostep_normalizers = hasattr(
        learner_state.learner_state.optimizer_state, "normalizers"
    )
    ss_normalizers = (
        jnp.zeros((ss_num_recordings, feature_dim), dtype=jnp.float32)
        if step_size_tracking and track_autostep_normalizers
        else None
    )

    # Pre-allocate normalizer state history arrays
    norm_means = (
        jnp.zeros((norm_num_recordings, feature_dim), dtype=jnp.float32)
        if normalizer_tracking
        else None
    )
    norm_vars = (
        jnp.zeros((norm_num_recordings, feature_dim), dtype=jnp.float32)
        if normalizer_tracking
        else None
    )
    norm_rec_indices = (
        jnp.zeros(norm_num_recordings, dtype=jnp.int32) if normalizer_tracking else None
    )

    def step_fn_with_tracking(
        carry: tuple[
            NormalizedLearnerState,
            StreamStateT,
            Array | None,
            Array | None,
            Array | None,
            Array | None,
            Array | None,
            Array | None,
            Array | None,
        ],
        idx: Array,
    ) -> tuple[
        tuple[
            NormalizedLearnerState,
            StreamStateT,
            Array | None,
            Array | None,
            Array | None,
            Array | None,
            Array | None,
            Array | None,
            Array | None,
        ],
        Array,
    ]:
        (
            l_state,
            s_state,
            ss_hist,
            ss_bias_hist,
            ss_rec,
            ss_norm,
            n_means,
            n_vars,
            n_rec,
        ) = carry

        # Perform learning step
        timestep, new_s_state = stream.step(s_state, idx)
        result = learner.update(l_state, timestep.observation, timestep.target)

        # Step-size tracking
        new_ss_hist = ss_hist
        new_ss_bias_hist = ss_bias_hist
        new_ss_rec = ss_rec
        new_ss_norm = ss_norm

        if ss_hist is not None:
            should_record_ss = (idx % ss_interval) == 0
            recording_idx = idx // ss_interval

            # Extract current step-sizes from the inner learner state
            opt_state = result.state.learner_state.optimizer_state
            if hasattr(opt_state, "log_step_sizes"):
                # IDBD stores log step-sizes
                weight_ss = jnp.exp(opt_state.log_step_sizes)  # type: ignore[union-attr]
                bias_ss = opt_state.bias_step_size  # type: ignore[union-attr]
            elif hasattr(opt_state, "step_sizes"):
                # Autostep stores step-sizes directly
                weight_ss = opt_state.step_sizes  # type: ignore[union-attr]
                bias_ss = opt_state.bias_step_size  # type: ignore[union-attr]
            else:
                # LMS has a single fixed step-size
                weight_ss = jnp.full(feature_dim, opt_state.step_size)
                bias_ss = opt_state.step_size

            new_ss_hist = jax.lax.cond(
                should_record_ss,
                lambda _: ss_hist.at[recording_idx].set(weight_ss),
                lambda _: ss_hist,
                None,
            )

            if ss_bias_hist is not None:
                new_ss_bias_hist = jax.lax.cond(
                    should_record_ss,
                    lambda _: ss_bias_hist.at[recording_idx].set(bias_ss),
                    lambda _: ss_bias_hist,
                    None,
                )

            if ss_rec is not None:
                new_ss_rec = jax.lax.cond(
                    should_record_ss,
                    lambda _: ss_rec.at[recording_idx].set(idx),
                    lambda _: ss_rec,
                    None,
                )

            # Track Autostep normalizers (v_i) if applicable
            if ss_norm is not None and hasattr(opt_state, "normalizers"):
                new_ss_norm = jax.lax.cond(
                    should_record_ss,
                    lambda _: ss_norm.at[recording_idx].set(
                        opt_state.normalizers  # type: ignore[union-attr]
                    ),
                    lambda _: ss_norm,
                    None,
                )

        # Normalizer state tracking
        new_n_means = n_means
        new_n_vars = n_vars
        new_n_rec = n_rec

        if n_means is not None:
            should_record_norm = (idx % norm_interval) == 0
            norm_recording_idx = idx // norm_interval

            norm_state = result.state.normalizer_state

            new_n_means = jax.lax.cond(
                should_record_norm,
                lambda _: n_means.at[norm_recording_idx].set(norm_state.mean),
                lambda _: n_means,
                None,
            )

            if n_vars is not None:
                new_n_vars = jax.lax.cond(
                    should_record_norm,
                    lambda _: n_vars.at[norm_recording_idx].set(norm_state.var),
                    lambda _: n_vars,
                    None,
                )

            if n_rec is not None:
                new_n_rec = jax.lax.cond(
                    should_record_norm,
                    lambda _: n_rec.at[norm_recording_idx].set(idx),
                    lambda _: n_rec,
                    None,
                )

        return (
            result.state,
            new_s_state,
            new_ss_hist,
            new_ss_bias_hist,
            new_ss_rec,
            new_ss_norm,
            new_n_means,
            new_n_vars,
            new_n_rec,
        ), result.metrics

    initial_carry = (
        learner_state,
        stream_state,
        ss_history,
        ss_bias_history,
        ss_rec_indices,
        ss_normalizers,
        norm_means,
        norm_vars,
        norm_rec_indices,
    )

    (
        final_learner,
        _,
        final_ss_hist,
        final_ss_bias_hist,
        final_ss_rec,
        final_ss_norm,
        final_n_means,
        final_n_vars,
        final_n_rec,
    ), metrics = jax.lax.scan(
        step_fn_with_tracking, initial_carry, jnp.arange(num_steps)
    )

    # Build return values based on what was tracked
    ss_history_result = None
    if step_size_tracking is not None and final_ss_hist is not None:
        ss_history_result = StepSizeHistory(
            step_sizes=final_ss_hist,
            bias_step_sizes=final_ss_bias_hist,
            recording_indices=final_ss_rec,  # type: ignore[arg-type]
            normalizers=final_ss_norm,
        )

    norm_history_result = None
    if normalizer_tracking is not None and final_n_means is not None:
        norm_history_result = NormalizerHistory(
            means=final_n_means,
            variances=final_n_vars,  # type: ignore[arg-type]
            recording_indices=final_n_rec,  # type: ignore[arg-type]
        )

    # Return appropriate tuple based on what was tracked
    if ss_history_result is not None and norm_history_result is not None:
        return final_learner, metrics, ss_history_result, norm_history_result
    elif ss_history_result is not None:
        return final_learner, metrics, ss_history_result
    elif norm_history_result is not None:
        return final_learner, metrics, norm_history_result
    else:
        return final_learner, metrics

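A sketch of the four-tuple return when both tracking configs are supplied. The `StepSizeTrackingConfig` and `NormalizerTrackingConfig` constructors are assumed to take the fields this function reads (`interval`, `include_bias`), and the top-level exports mirror the docstring examples below; both are assumptions, not confirmed signatures.

```python
import jax.random as jr

from alberta_framework import (
    IDBD,
    NormalizedLinearLearner,
    RandomWalkStream,
    run_normalized_learning_loop,
)
from alberta_framework.core.types import NormalizerTrackingConfig, StepSizeTrackingConfig

stream = RandomWalkStream(feature_dim=10)
learner = NormalizedLinearLearner(optimizer=IDBD())

final_state, metrics, ss_history, norm_history = run_normalized_learning_loop(
    learner,
    stream,
    num_steps=10_000,
    key=jr.key(7),
    step_size_tracking=StepSizeTrackingConfig(interval=100, include_bias=True),
    normalizer_tracking=NormalizerTrackingConfig(interval=100),
)

# metrics: (10000, 4); ss_history.step_sizes: (100, 10);
# norm_history.means / norm_history.variances: (100, 10) per-feature snapshots.
```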
def run_learning_loop_batched[StreamStateT](
    learner: LinearLearner,
    stream: ScanStream[StreamStateT],
    num_steps: int,
    keys: Array,
    learner_state: LearnerState | None = None,
    step_size_tracking: StepSizeTrackingConfig | None = None,
) -> BatchedLearningResult:
    """Run learning loop across multiple seeds in parallel using jax.vmap.

    This function provides GPU parallelization for multi-seed experiments,
    typically achieving 2-5x speedup over sequential execution.

    Args:
        learner: The learner to train
        stream: Experience stream providing (observation, target) pairs
        num_steps: Number of learning steps to run per seed
        keys: JAX random keys with shape (num_seeds,) or (num_seeds, 2)
        learner_state: Initial state (if None, will be initialized from stream).
            The same initial state is used for all seeds.
        step_size_tracking: Optional config for recording per-weight step-sizes.
            When provided, history arrays have shape (num_seeds, num_recordings, ...)

    Returns:
        BatchedLearningResult containing:
        - states: Batched final states with shape (num_seeds, ...) for each array
        - metrics: Array of shape (num_seeds, num_steps, 3)
        - step_size_history: Batched history or None if tracking disabled

    Examples:
        ```python
        import jax.random as jr
        from alberta_framework import LinearLearner, IDBD, RandomWalkStream
        from alberta_framework import run_learning_loop_batched

        stream = RandomWalkStream(feature_dim=10)
        learner = LinearLearner(optimizer=IDBD())

        # Run 30 seeds in parallel
        keys = jr.split(jr.key(42), 30)
        result = run_learning_loop_batched(learner, stream, num_steps=10000, keys=keys)

        # result.metrics has shape (30, 10000, 3)
        mean_error = result.metrics[:, :, 0].mean(axis=0)  # Average over seeds
        ```
    """
    # Define single-seed function that returns consistent structure
    def single_seed_run(key: Array) -> tuple[LearnerState, Array, StepSizeHistory | None]:
        result = run_learning_loop(
            learner, stream, num_steps, key, learner_state, step_size_tracking
        )
        if step_size_tracking is not None:
            state, metrics, history = cast(
                tuple[LearnerState, Array, StepSizeHistory], result
            )
            return state, metrics, history
        else:
            state, metrics = cast(tuple[LearnerState, Array], result)
            # Return None for history to maintain consistent output structure
            return state, metrics, None

    # vmap over the keys dimension
    batched_states, batched_metrics, batched_history = jax.vmap(single_seed_run)(keys)

    # Reconstruct batched history if tracking was enabled
    if step_size_tracking is not None and batched_history is not None:
        batched_step_size_history = StepSizeHistory(
            step_sizes=batched_history.step_sizes,
            bias_step_sizes=batched_history.bias_step_sizes,
            recording_indices=batched_history.recording_indices,
            normalizers=batched_history.normalizers,
        )
    else:
        batched_step_size_history = None

    return BatchedLearningResult(
        states=batched_states,
        metrics=batched_metrics,
        step_size_history=batched_step_size_history,
    )


def run_normalized_learning_loop_batched[StreamStateT](
    learner: NormalizedLinearLearner,
    stream: ScanStream[StreamStateT],
    num_steps: int,
    keys: Array,
    learner_state: NormalizedLearnerState | None = None,
    step_size_tracking: StepSizeTrackingConfig | None = None,
    normalizer_tracking: NormalizerTrackingConfig | None = None,
) -> BatchedNormalizedResult:
    """Run normalized learning loop across multiple seeds in parallel using jax.vmap.

    This function provides GPU parallelization for multi-seed experiments with
    normalized learners, typically achieving 2-5x speedup over sequential execution.

    Args:
        learner: The normalized learner to train
        stream: Experience stream providing (observation, target) pairs
        num_steps: Number of learning steps to run per seed
        keys: JAX random keys with shape (num_seeds,) or (num_seeds, 2)
        learner_state: Initial state (if None, will be initialized from stream).
            The same initial state is used for all seeds.
        step_size_tracking: Optional config for recording per-weight step-sizes.
            When provided, history arrays have shape (num_seeds, num_recordings, ...)
        normalizer_tracking: Optional config for recording normalizer state.
            When provided, history arrays have shape (num_seeds, num_recordings, ...)

    Returns:
        BatchedNormalizedResult containing:
        - states: Batched final states with shape (num_seeds, ...) for each array
        - metrics: Array of shape (num_seeds, num_steps, 4)
        - step_size_history: Batched history or None if tracking disabled
        - normalizer_history: Batched history or None if tracking disabled

    Examples:
        ```python
        import jax.random as jr
        from alberta_framework import NormalizedLinearLearner, IDBD, RandomWalkStream
        from alberta_framework import run_normalized_learning_loop_batched

        stream = RandomWalkStream(feature_dim=10)
        learner = NormalizedLinearLearner(optimizer=IDBD())

        # Run 30 seeds in parallel
        keys = jr.split(jr.key(42), 30)
        result = run_normalized_learning_loop_batched(
            learner, stream, num_steps=10000, keys=keys
        )

        # result.metrics has shape (30, 10000, 4)
        mean_error = result.metrics[:, :, 0].mean(axis=0)  # Average over seeds
        ```
    """
    # Define single-seed function that returns consistent structure
    def single_seed_run(
        key: Array,
    ) -> tuple[
        NormalizedLearnerState, Array, StepSizeHistory | None, NormalizerHistory | None
    ]:
        result = run_normalized_learning_loop(
            learner, stream, num_steps, key, learner_state,
            step_size_tracking, normalizer_tracking
        )

        # Unpack based on what tracking was enabled
        if step_size_tracking is not None and normalizer_tracking is not None:
            state, metrics, ss_history, norm_history = cast(
                tuple[NormalizedLearnerState, Array, StepSizeHistory, NormalizerHistory],
                result,
            )
            return state, metrics, ss_history, norm_history
        elif step_size_tracking is not None:
            state, metrics, ss_history = cast(
                tuple[NormalizedLearnerState, Array, StepSizeHistory], result
            )
            return state, metrics, ss_history, None
        elif normalizer_tracking is not None:
            state, metrics, norm_history = cast(
                tuple[NormalizedLearnerState, Array, NormalizerHistory], result
            )
            return state, metrics, None, norm_history
        else:
            state, metrics = cast(tuple[NormalizedLearnerState, Array], result)
            return state, metrics, None, None

    # vmap over the keys dimension
    batched_states, batched_metrics, batched_ss_history, batched_norm_history = (
        jax.vmap(single_seed_run)(keys)
    )

    # Reconstruct batched histories if tracking was enabled
    if step_size_tracking is not None and batched_ss_history is not None:
        batched_step_size_history = StepSizeHistory(
            step_sizes=batched_ss_history.step_sizes,
            bias_step_sizes=batched_ss_history.bias_step_sizes,
            recording_indices=batched_ss_history.recording_indices,
            normalizers=batched_ss_history.normalizers,
        )
    else:
        batched_step_size_history = None

    if normalizer_tracking is not None and batched_norm_history is not None:
        batched_normalizer_history = NormalizerHistory(
            means=batched_norm_history.means,
            variances=batched_norm_history.variances,
            recording_indices=batched_norm_history.recording_indices,
        )
    else:
        batched_normalizer_history = None

    return BatchedNormalizedResult(
        states=batched_states,
        metrics=batched_metrics,
        step_size_history=batched_step_size_history,
        normalizer_history=batched_normalizer_history,
    )


def metrics_to_dicts(metrics: Array, normalized: bool = False) -> list[dict[str, float]]:
    """Convert metrics array to list of dicts for backward compatibility.

    Args:
        metrics: Array of shape (num_steps, 3) or (num_steps, 4)
        normalized: If True, expects 4 columns including normalizer_mean_var

    Returns:
        List of metric dictionaries
    """
    result = []
    for row in metrics:
        d = {
            "squared_error": float(row[0]),
            "error": float(row[1]),
            "mean_step_size": float(row[2]),
        }
        if normalized and len(row) > 3:
            d["normalizer_mean_var"] = float(row[3])
        result.append(d)
    return result
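For example, the metrics array returned by the earlier `run_learning_loop` sketch can be converted back to per-step dictionaries:

```python
# `metrics` is the (num_steps, 3) array from run_learning_loop above.
records = metrics_to_dicts(metrics)
print(records[0])  # {'squared_error': ..., 'error': ..., 'mean_step_size': ...}

# For a normalized learner's (num_steps, 4) metrics, pass normalized=True
# to also keep the normalizer_mean_var column.
```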