alberta_framework-0.1.0-py3-none-any.whl
- alberta_framework/__init__.py +196 -0
- alberta_framework/core/__init__.py +27 -0
- alberta_framework/core/learners.py +530 -0
- alberta_framework/core/normalizers.py +192 -0
- alberta_framework/core/optimizers.py +422 -0
- alberta_framework/core/types.py +198 -0
- alberta_framework/py.typed +0 -0
- alberta_framework/streams/__init__.py +83 -0
- alberta_framework/streams/base.py +70 -0
- alberta_framework/streams/gymnasium.py +655 -0
- alberta_framework/streams/synthetic.py +995 -0
- alberta_framework/utils/__init__.py +113 -0
- alberta_framework/utils/experiments.py +334 -0
- alberta_framework/utils/export.py +509 -0
- alberta_framework/utils/metrics.py +112 -0
- alberta_framework/utils/statistics.py +527 -0
- alberta_framework/utils/timing.py +138 -0
- alberta_framework/utils/visualization.py +571 -0
- alberta_framework-0.1.0.dist-info/METADATA +198 -0
- alberta_framework-0.1.0.dist-info/RECORD +22 -0
- alberta_framework-0.1.0.dist-info/WHEEL +4 -0
- alberta_framework-0.1.0.dist-info/licenses/LICENSE +190 -0
alberta_framework/core/learners.py
@@ -0,0 +1,530 @@
"""Learning units for continual learning.

Implements learners that combine function approximation with optimizers
for temporally-uniform learning. Uses JAX's scan for efficient JIT-compiled
training loops.
"""

from typing import NamedTuple

import jax
import jax.numpy as jnp
from jax import Array

from alberta_framework.core.normalizers import NormalizerState, OnlineNormalizer
from alberta_framework.core.optimizers import LMS, Optimizer
from alberta_framework.core.types import (
    AutostepState,
    IDBDState,
    LearnerState,
    LMSState,
    Observation,
    Prediction,
    StepSizeHistory,
    StepSizeTrackingConfig,
    Target,
)
from alberta_framework.streams.base import ScanStream

# Type alias for any optimizer type
AnyOptimizer = Optimizer[LMSState] | Optimizer[IDBDState] | Optimizer[AutostepState]

class UpdateResult(NamedTuple):
    """Result of a learner update step.

    Attributes:
        state: Updated learner state
        prediction: Prediction made before update
        error: Prediction error
        metrics: Array of metrics [squared_error, error, mean_step_size]
    """

    state: LearnerState
    prediction: Prediction
    error: Array
    metrics: Array


class LinearLearner:
    """Linear function approximator with pluggable optimizer.

    Computes predictions as: y = w @ x + b

    The learner maintains weights and bias, delegating the adaptation
    of learning rates to the optimizer (e.g., LMS or IDBD).

    This follows the Alberta Plan philosophy of temporal uniformity:
    every component updates at every time step.

    Attributes:
        optimizer: The optimizer to use for weight updates
    """

    def __init__(self, optimizer: AnyOptimizer | None = None):
        """Initialize the linear learner.

        Args:
            optimizer: Optimizer for weight updates. Defaults to LMS(0.01)
        """
        self._optimizer: AnyOptimizer = optimizer or LMS(step_size=0.01)

    def init(self, feature_dim: int) -> LearnerState:
        """Initialize learner state.

        Args:
            feature_dim: Dimension of the input feature vector

        Returns:
            Initial learner state with zero weights and bias
        """
        optimizer_state = self._optimizer.init(feature_dim)

        return LearnerState(
            weights=jnp.zeros(feature_dim, dtype=jnp.float32),
            bias=jnp.array(0.0, dtype=jnp.float32),
            optimizer_state=optimizer_state,
        )

    def predict(self, state: LearnerState, observation: Observation) -> Prediction:
        """Compute prediction for an observation.

        Args:
            state: Current learner state
            observation: Input feature vector

        Returns:
            Scalar prediction y = w @ x + b
        """
        return jnp.atleast_1d(jnp.dot(state.weights, observation) + state.bias)

    def update(
        self,
        state: LearnerState,
        observation: Observation,
        target: Target,
    ) -> UpdateResult:
        """Update learner given observation and target.

        Performs one step of the learning algorithm:
        1. Compute prediction
        2. Compute error
        3. Get weight updates from optimizer
        4. Apply updates to weights and bias

        Args:
            state: Current learner state
            observation: Input feature vector
            target: Desired output

        Returns:
            UpdateResult with new state, prediction, error, and metrics
        """
        # Make prediction
        prediction = self.predict(state, observation)

        # Compute error (target - prediction)
        error = jnp.squeeze(target) - jnp.squeeze(prediction)

        # Get update from optimizer
        # Note: type ignore needed because we can't statically prove optimizer_state
        # matches the optimizer's expected state type (though they will at runtime)
        opt_update = self._optimizer.update(
            state.optimizer_state,  # type: ignore[arg-type]
            error,
            observation,
        )

        # Apply updates
        new_weights = state.weights + opt_update.weight_delta
        new_bias = state.bias + opt_update.bias_delta

        new_state = LearnerState(
            weights=new_weights,
            bias=new_bias,
            optimizer_state=opt_update.new_state,
        )

        # Pack metrics as array for scan compatibility
        # Format: [squared_error, error, mean_step_size (if adaptive)]
        squared_error = error**2
        mean_step_size = opt_update.metrics.get("mean_step_size", 0.0)
        metrics = jnp.array([squared_error, error, mean_step_size], dtype=jnp.float32)

        return UpdateResult(
            state=new_state,
            prediction=prediction,
            error=jnp.atleast_1d(error),
            metrics=metrics,
        )


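# Usage sketch: one manual prediction/update cycle. The observation and
# target values are illustrative; in normal use they come from a ScanStream
# (see run_learning_loop below).
def _example_linear_learner_step() -> None:
    learner = LinearLearner(optimizer=LMS(step_size=0.1))
    state = learner.init(feature_dim=4)
    x = jnp.array([1.0, 0.5, -0.25, 2.0], dtype=jnp.float32)
    result = learner.update(state, x, jnp.array(1.0))
    # Weights start at zero, so the prediction is [0.0] and the error is 1.0.
    # States are immutable NamedTuples: carry result.state into the next step.
    assert result.metrics.shape == (3,)  # [squared_error, error, mean_step_size]

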
def run_learning_loop[StreamStateT](
    learner: LinearLearner,
    stream: ScanStream[StreamStateT],
    num_steps: int,
    key: Array,
    learner_state: LearnerState | None = None,
    step_size_tracking: StepSizeTrackingConfig | None = None,
) -> tuple[LearnerState, Array] | tuple[LearnerState, Array, StepSizeHistory]:
    """Run the learning loop using jax.lax.scan.

    This is a JIT-compiled learning loop that uses scan for efficiency.
    It returns metrics as a fixed-size array rather than a list of dicts.

    Args:
        learner: The learner to train
        stream: Experience stream providing (observation, target) pairs
        num_steps: Number of learning steps to run
        key: JAX random key for stream initialization
        learner_state: Initial state (if None, will be initialized from stream)
        step_size_tracking: Optional config for recording per-weight step-sizes.
            When provided, returns a 3-tuple including StepSizeHistory.

    Returns:
        If step_size_tracking is None:
            Tuple of (final_state, metrics_array) where metrics_array has shape
            (num_steps, 3) with columns [squared_error, error, mean_step_size]
        If step_size_tracking is provided:
            Tuple of (final_state, metrics_array, step_size_history)

    Raises:
        ValueError: If step_size_tracking.interval is less than 1 or greater than num_steps
    """
    # Validate tracking config
    if step_size_tracking is not None:
        if step_size_tracking.interval < 1:
            raise ValueError(
                f"step_size_tracking.interval must be >= 1, got {step_size_tracking.interval}"
            )
        if step_size_tracking.interval > num_steps:
            raise ValueError(
                f"step_size_tracking.interval ({step_size_tracking.interval}) "
                f"must be <= num_steps ({num_steps})"
            )

    # Initialize states
    if learner_state is None:
        learner_state = learner.init(stream.feature_dim)
    stream_state = stream.init(key)

    feature_dim = stream.feature_dim

    if step_size_tracking is None:
        # Original behavior without tracking
        def step_fn(
            carry: tuple[LearnerState, StreamStateT], idx: Array
        ) -> tuple[tuple[LearnerState, StreamStateT], Array]:
            l_state, s_state = carry
            timestep, new_s_state = stream.step(s_state, idx)
            result = learner.update(l_state, timestep.observation, timestep.target)
            return (result.state, new_s_state), result.metrics

        (final_learner, _), metrics = jax.lax.scan(
            step_fn, (learner_state, stream_state), jnp.arange(num_steps)
        )

        return final_learner, metrics

    else:
        # Step-size tracking enabled
        interval = step_size_tracking.interval
        include_bias = step_size_tracking.include_bias
        num_recordings = num_steps // interval

        # Pre-allocate history arrays
        step_size_history = jnp.zeros((num_recordings, feature_dim), dtype=jnp.float32)
        bias_history = (
            jnp.zeros(num_recordings, dtype=jnp.float32) if include_bias else None
        )
        recording_indices = jnp.zeros(num_recordings, dtype=jnp.int32)

        def step_fn_with_tracking(
            carry: tuple[LearnerState, StreamStateT, Array, Array | None, Array], idx: Array
        ) -> tuple[tuple[LearnerState, StreamStateT, Array, Array | None, Array], Array]:
            l_state, s_state, ss_history, b_history, rec_indices = carry

            # Perform learning step
            timestep, new_s_state = stream.step(s_state, idx)
            result = learner.update(l_state, timestep.observation, timestep.target)

            # Check if we should record at this step (idx % interval == 0)
            should_record = (idx % interval) == 0
            recording_idx = idx // interval

            # Extract current step-sizes
            # Use hasattr checks at trace time (this works because the type is fixed)
            opt_state = result.state.optimizer_state
            if hasattr(opt_state, "log_step_sizes"):
                # IDBD stores log step-sizes
                weight_ss = jnp.exp(opt_state.log_step_sizes)  # type: ignore[union-attr]
                bias_ss = opt_state.bias_step_size  # type: ignore[union-attr]
            elif hasattr(opt_state, "step_sizes"):
                # Autostep stores step-sizes directly
                weight_ss = opt_state.step_sizes  # type: ignore[union-attr]
                bias_ss = opt_state.bias_step_size  # type: ignore[union-attr]
            else:
                # LMS has a single fixed step-size
                weight_ss = jnp.full(feature_dim, opt_state.step_size)
                bias_ss = opt_state.step_size

            # Conditionally update history arrays
            new_ss_history = jax.lax.cond(
                should_record,
                lambda _: ss_history.at[recording_idx].set(weight_ss),
                lambda _: ss_history,
                None,
            )

            new_b_history = b_history
            if b_history is not None:
                new_b_history = jax.lax.cond(
                    should_record,
                    lambda _: b_history.at[recording_idx].set(bias_ss),
                    lambda _: b_history,
                    None,
                )

            new_rec_indices = jax.lax.cond(
                should_record,
                lambda _: rec_indices.at[recording_idx].set(idx),
                lambda _: rec_indices,
                None,
            )

            return (
                result.state,
                new_s_state,
                new_ss_history,
                new_b_history,
                new_rec_indices,
            ), result.metrics

        initial_carry = (
            learner_state,
            stream_state,
            step_size_history,
            bias_history,
            recording_indices,
        )

        (final_learner, _, final_ss_history, final_b_history, final_rec_indices), metrics = (
            jax.lax.scan(step_fn_with_tracking, initial_carry, jnp.arange(num_steps))
        )

        history = StepSizeHistory(
            step_sizes=final_ss_history,
            bias_step_sizes=final_b_history,
            recording_indices=final_rec_indices,
        )

        return final_learner, metrics, history


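# Usage sketch for the tracked variant. `stream` stands in for any ScanStream
# implementation (e.g. one of the streams in alberta_framework.streams); the
# StepSizeTrackingConfig field names match the `interval` and `include_bias`
# attributes read above.
def _example_tracked_run[StreamStateT](stream: ScanStream[StreamStateT]) -> None:
    tracking = StepSizeTrackingConfig(interval=100, include_bias=True)
    final_state, metrics, history = run_learning_loop(
        LinearLearner(LMS(step_size=0.01)),
        stream,
        num_steps=10_000,
        key=jax.random.PRNGKey(0),
        step_size_tracking=tracking,
    )
    # metrics: shape (10_000, 3). history.step_sizes: shape (100, feature_dim),
    # one row per recording interval. With fixed-step LMS every recorded value
    # is 0.01; an adaptive optimizer (IDBD, Autostep) gives moving trajectories.

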
class NormalizedLearnerState(NamedTuple):
    """State for a learner with online feature normalization.

    Attributes:
        learner_state: Underlying learner state (weights, bias, optimizer)
        normalizer_state: Online normalizer state (mean, var estimates)
    """

    learner_state: LearnerState
    normalizer_state: NormalizerState


class NormalizedUpdateResult(NamedTuple):
    """Result of a normalized learner update step.

    Attributes:
        state: Updated normalized learner state
        prediction: Prediction made before update
        error: Prediction error
        metrics: Array of metrics [squared_error, error, mean_step_size, normalizer_mean_var]
    """

    state: NormalizedLearnerState
    prediction: Prediction
    error: Array
    metrics: Array


class NormalizedLinearLearner:
    """Linear learner with online feature normalization.

    Wraps a LinearLearner with online feature normalization, following
    the Alberta Plan's approach to handling varying feature scales.

    Normalization is applied to features before prediction and learning:
        x_normalized = (x - mean) / (std + epsilon)

    The normalizer statistics update at every time step, maintaining
    temporal uniformity.

    Attributes:
        learner: Underlying linear learner
        normalizer: Online feature normalizer
    """

    def __init__(
        self,
        optimizer: AnyOptimizer | None = None,
        normalizer: OnlineNormalizer | None = None,
    ):
        """Initialize the normalized linear learner.

        Args:
            optimizer: Optimizer for weight updates. Defaults to LMS(0.01)
            normalizer: Feature normalizer. Defaults to OnlineNormalizer()
        """
        self._learner = LinearLearner(optimizer=optimizer or LMS(step_size=0.01))
        self._normalizer = normalizer or OnlineNormalizer()

    def init(self, feature_dim: int) -> NormalizedLearnerState:
        """Initialize normalized learner state.

        Args:
            feature_dim: Dimension of the input feature vector

        Returns:
            Initial state with zero weights and unit variance estimates
        """
        return NormalizedLearnerState(
            learner_state=self._learner.init(feature_dim),
            normalizer_state=self._normalizer.init(feature_dim),
        )

    def predict(
        self,
        state: NormalizedLearnerState,
        observation: Observation,
    ) -> Prediction:
        """Compute prediction for an observation.

        Normalizes the observation using current statistics before prediction.

        Args:
            state: Current normalized learner state
            observation: Raw (unnormalized) input feature vector

        Returns:
            Scalar prediction y = w @ normalize(x) + b
        """
        normalized_obs = self._normalizer.normalize_only(
            state.normalizer_state, observation
        )
        return self._learner.predict(state.learner_state, normalized_obs)

    def update(
        self,
        state: NormalizedLearnerState,
        observation: Observation,
        target: Target,
    ) -> NormalizedUpdateResult:
        """Update learner given observation and target.

        Performs one step of the learning algorithm:
        1. Normalize observation (and update normalizer statistics)
        2. Compute prediction using normalized features
        3. Compute error
        4. Get weight updates from optimizer
        5. Apply updates

        Args:
            state: Current normalized learner state
            observation: Raw (unnormalized) input feature vector
            target: Desired output

        Returns:
            NormalizedUpdateResult with new state, prediction, error, and metrics
        """
        # Normalize observation and update normalizer state
        normalized_obs, new_normalizer_state = self._normalizer.normalize(
            state.normalizer_state, observation
        )

        # Delegate to underlying learner
        result = self._learner.update(
            state.learner_state,
            normalized_obs,
            target,
        )

        # Build combined state
        new_state = NormalizedLearnerState(
            learner_state=result.state,
            normalizer_state=new_normalizer_state,
        )

        # Add normalizer metrics to the metrics array
        normalizer_mean_var = jnp.mean(new_normalizer_state.var)
        metrics = jnp.concatenate([result.metrics, jnp.array([normalizer_mean_var])])

        return NormalizedUpdateResult(
            state=new_state,
            prediction=result.prediction,
            error=result.error,
            metrics=metrics,
        )


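# Usage sketch: predict() normalizes with the current statistics via
# normalize_only() and leaves state untouched, while update() also advances
# the statistics, so only update() returns a new normalizer_state.
def _example_normalized_step() -> None:
    learner = NormalizedLinearLearner()  # LMS(0.01) and OnlineNormalizer() defaults
    state = learner.init(feature_dim=3)
    x = jnp.array([10.0, -5.0, 0.5], dtype=jnp.float32)  # deliberately unscaled
    y_hat = learner.predict(state, x)  # read-only: state is unchanged
    result = learner.update(state, x, jnp.array(1.0))
    # result.state.normalizer_state now reflects x, and result.metrics carries
    # a fourth column: the mean of the normalizer's variance estimates.
    assert result.metrics.shape == (4,)

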
def run_normalized_learning_loop[StreamStateT](
    learner: NormalizedLinearLearner,
    stream: ScanStream[StreamStateT],
    num_steps: int,
    key: Array,
    learner_state: NormalizedLearnerState | None = None,
) -> tuple[NormalizedLearnerState, Array]:
    """Run the learning loop with normalization using jax.lax.scan.

    Args:
        learner: The normalized learner to train
        stream: Experience stream providing (observation, target) pairs
        num_steps: Number of learning steps to run
        key: JAX random key for stream initialization
        learner_state: Initial state (if None, will be initialized from stream)

    Returns:
        Tuple of (final_state, metrics_array) where metrics_array has shape
        (num_steps, 4) with columns [squared_error, error, mean_step_size, normalizer_mean_var]
    """
    # Initialize states
    if learner_state is None:
        learner_state = learner.init(stream.feature_dim)
    stream_state = stream.init(key)

    def step_fn(
        carry: tuple[NormalizedLearnerState, StreamStateT], idx: Array
    ) -> tuple[tuple[NormalizedLearnerState, StreamStateT], Array]:
        l_state, s_state = carry
        timestep, new_s_state = stream.step(s_state, idx)
        result = learner.update(l_state, timestep.observation, timestep.target)
        return (result.state, new_s_state), result.metrics

    (final_learner, _), metrics = jax.lax.scan(
        step_fn, (learner_state, stream_state), jnp.arange(num_steps)
    )

    return final_learner, metrics


def metrics_to_dicts(metrics: Array, normalized: bool = False) -> list[dict[str, float]]:
    """Convert metrics array to list of dicts for backward compatibility.

    Args:
        metrics: Array of shape (num_steps, 3) or (num_steps, 4)
        normalized: If True, expects 4 columns including normalizer_mean_var

    Returns:
        List of metric dictionaries
    """
    result = []
    for row in metrics:
        d = {
            "squared_error": float(row[0]),
            "error": float(row[1]),
            "mean_step_size": float(row[2]),
        }
        if normalized and len(row) > 3:
            d["normalizer_mean_var"] = float(row[3])
        result.append(d)
    return result
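

# Usage sketch tying the pieces together, again with `stream` standing in for
# any ScanStream implementation: run the normalized loop, then flatten the
# metrics array into per-step dicts.
def _example_normalized_run[StreamStateT](stream: ScanStream[StreamStateT]) -> None:
    final_state, metrics = run_normalized_learning_loop(
        NormalizedLinearLearner(),
        stream,
        num_steps=1_000,
        key=jax.random.PRNGKey(0),
    )
    records = metrics_to_dicts(metrics, normalized=True)
    # records[0] has keys "squared_error", "error", "mean_step_size",
    # and "normalizer_mean_var".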