alberta-framework 0.3.0-py3-none-any.whl → 0.4.0-py3-none-any.whl
- alberta_framework/__init__.py +39 -5
- alberta_framework/core/__init__.py +26 -2
- alberta_framework/core/learners.py +277 -59
- alberta_framework/core/normalizers.py +1 -4
- alberta_framework/core/optimizers.py +498 -1
- alberta_framework/core/types.py +176 -1
- alberta_framework/streams/gymnasium.py +3 -10
- alberta_framework/streams/synthetic.py +3 -9
- alberta_framework/utils/experiments.py +1 -3
- alberta_framework/utils/export.py +20 -16
- alberta_framework/utils/statistics.py +17 -9
- alberta_framework/utils/visualization.py +31 -25
- {alberta_framework-0.3.0.dist-info → alberta_framework-0.4.0.dist-info}/METADATA +24 -1
- alberta_framework-0.4.0.dist-info/RECORD +22 -0
- alberta_framework-0.3.0.dist-info/RECORD +0 -22
- {alberta_framework-0.3.0.dist-info → alberta_framework-0.4.0.dist-info}/WHEEL +0 -0
- {alberta_framework-0.3.0.dist-info → alberta_framework-0.4.0.dist-info}/licenses/LICENSE +0 -0
alberta_framework/__init__.py
CHANGED
@@ -39,7 +39,7 @@ References
 - Tuning-free Step-size Adaptation (Mahmood et al., 2012)
 """

-__version__ = "0.3.0"
+__version__ = "0.4.0"

 # Core types
 # Learners
@@ -47,12 +47,15 @@ from alberta_framework.core.learners import (
     LinearLearner,
     NormalizedLearnerState,
     NormalizedLinearLearner,
+    TDLinearLearner,
+    TDUpdateResult,
     UpdateResult,
     metrics_to_dicts,
     run_learning_loop,
     run_learning_loop_batched,
     run_normalized_learning_loop,
     run_normalized_learning_loop_batched,
+    run_td_learning_loop,
 )

 # Normalizers
@@ -63,9 +66,19 @@ from alberta_framework.core.normalizers import (
 )

 # Optimizers
-from alberta_framework.core.optimizers import IDBD, LMS, Autostep, Optimizer
+from alberta_framework.core.optimizers import (
+    IDBD,
+    LMS,
+    TDIDBD,
+    Autostep,
+    AutoTDIDBD,
+    Optimizer,
+    TDOptimizer,
+    TDOptimizerUpdate,
+)
 from alberta_framework.core.types import (
     AutostepState,
+    AutoTDIDBDState,
     BatchedLearningResult,
     BatchedNormalizedResult,
     IDBDState,
@@ -78,7 +91,12 @@ from alberta_framework.core.types import (
     StepSizeHistory,
     StepSizeTrackingConfig,
     Target,
+    TDIDBDState,
+    TDLearnerState,
+    TDTimeStep,
     TimeStep,
+    create_autotdidbd_state,
+    create_tdidbd_state,
 )

 # Streams - base
@@ -140,7 +158,7 @@ except ImportError:
 __all__ = [
     # Version
     "__version__",
-    # Types
+    # Types - Supervised Learning
     "AutostepState",
     "BatchedLearningResult",
     "BatchedNormalizedResult",
@@ -157,15 +175,28 @@ __all__ = [
     "Target",
     "TimeStep",
     "UpdateResult",
-    # Optimizers
+    # Types - TD Learning
+    "AutoTDIDBDState",
+    "TDIDBDState",
+    "TDLearnerState",
+    "TDTimeStep",
+    "TDUpdateResult",
+    "create_tdidbd_state",
+    "create_autotdidbd_state",
+    # Optimizers - Supervised Learning
     "Autostep",
     "IDBD",
     "LMS",
     "Optimizer",
+    # Optimizers - TD Learning
+    "AutoTDIDBD",
+    "TDIDBD",
+    "TDOptimizer",
+    "TDOptimizerUpdate",
     # Normalizers
     "OnlineNormalizer",
     "create_normalizer_state",
-    # Learners
+    # Learners - Supervised Learning
     "LinearLearner",
     "NormalizedLearnerState",
     "NormalizedLinearLearner",
@@ -174,6 +205,9 @@ __all__ = [
     "run_normalized_learning_loop",
     "run_normalized_learning_loop_batched",
     "metrics_to_dicts",
+    # Learners - TD Learning
+    "TDLinearLearner",
+    "run_td_learning_loop",
     # Streams - protocol
     "ScanStream",
     # Streams - synthetic
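Taken together, these `__init__.py` changes expose the TD learning API at the package root. A minimal sketch of the new surface, using only names exported in the diff above (the `feature_dim` value and observation are illustrative):

```python
import jax.numpy as jnp

from alberta_framework import TDIDBD, TDLinearLearner

learner = TDLinearLearner(optimizer=TDIDBD())  # TDIDBD() is also the default
state = learner.init(feature_dim=8)            # zero weights, zero bias
value = learner.predict(state, jnp.ones(8))    # V(s) = w @ phi(s) + b, so 0.0 here
```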
alberta_framework/core/__init__.py
CHANGED
@@ -1,18 +1,31 @@
 """Core components for the Alberta Framework."""

-from alberta_framework.core.learners import LinearLearner
-from alberta_framework.core.optimizers import IDBD, LMS, Optimizer
+from alberta_framework.core.learners import LinearLearner, TDLinearLearner, TDUpdateResult
+from alberta_framework.core.optimizers import (
+    IDBD,
+    LMS,
+    TDIDBD,
+    AutoTDIDBD,
+    Optimizer,
+    TDOptimizer,
+    TDOptimizerUpdate,
+)
 from alberta_framework.core.types import (
+    AutoTDIDBDState,
     IDBDState,
     LearnerState,
     LMSState,
     Observation,
     Prediction,
     Target,
+    TDIDBDState,
+    TDLearnerState,
+    TDTimeStep,
     TimeStep,
 )

 __all__ = [
+    # Supervised learning
     "IDBD",
     "IDBDState",
     "LMS",
@@ -24,4 +37,15 @@ __all__ = [
     "Prediction",
     "Target",
     "TimeStep",
+    # TD learning
+    "AutoTDIDBD",
+    "AutoTDIDBDState",
+    "TDIDBD",
+    "TDIDBDState",
+    "TDLearnerState",
+    "TDLinearLearner",
+    "TDOptimizer",
+    "TDOptimizerUpdate",
+    "TDTimeStep",
+    "TDUpdateResult",
 ]
alberta_framework/core/learners.py
CHANGED
@@ -5,7 +5,7 @@ for temporally-uniform learning. Uses JAX's scan for efficient JIT-compiled
 training loops.
 """

-from typing import cast
+from typing import Protocol, TypeVar, cast

 import chex
 import jax
@@ -14,9 +14,10 @@ from jax import Array
 from jaxtyping import Float

 from alberta_framework.core.normalizers import NormalizerState, OnlineNormalizer
-from alberta_framework.core.optimizers import LMS, Optimizer
+from alberta_framework.core.optimizers import LMS, TDIDBD, Optimizer, TDOptimizer
 from alberta_framework.core.types import (
     AutostepState,
+    AutoTDIDBDState,
     BatchedLearningResult,
     BatchedNormalizedResult,
     IDBDState,
@@ -29,12 +30,21 @@ from alberta_framework.core.types import (
     StepSizeHistory,
     StepSizeTrackingConfig,
     Target,
+    TDIDBDState,
+    TDLearnerState,
+    TDTimeStep,
 )
 from alberta_framework.streams.base import ScanStream

+# Type variable for TD stream state
+StateT = TypeVar("StateT")
+
 # Type alias for any optimizer type
 AnyOptimizer = Optimizer[LMSState] | Optimizer[IDBDState] | Optimizer[AutostepState]

+# Type alias for any TD optimizer type
+AnyTDOptimizer = TDOptimizer[TDIDBDState] | TDOptimizer[AutoTDIDBDState]
+

 @chex.dataclass(frozen=True)
 class UpdateResult:
@@ -167,7 +177,8 @@ class LinearLearner:
         # Note: type ignore needed because we can't statically prove optimizer_state
         # matches the optimizer's expected state type (though they will at runtime)
         opt_update = self._optimizer.update(
-            state.optimizer_state, error,
+            state.optimizer_state,
+            error,
             observation,
         )

@@ -270,9 +281,7 @@ def run_learning_loop[StreamStateT](

     # Pre-allocate history arrays
     step_size_history = jnp.zeros((num_recordings, feature_dim), dtype=jnp.float32)
-    bias_history = (
-        jnp.zeros(num_recordings, dtype=jnp.float32) if include_bias else None
-    )
+    bias_history = jnp.zeros(num_recordings, dtype=jnp.float32) if include_bias else None
     recording_indices = jnp.zeros(num_recordings, dtype=jnp.int32)

     # Check if we need to track Autostep normalizers
@@ -285,9 +294,7 @@ def run_learning_loop[StreamStateT](
     )

     def step_fn_with_tracking(
-        carry: tuple[
-            LearnerState, StreamStateT, Array, Array | None, Array, Array | None
-        ],
+        carry: tuple[LearnerState, StreamStateT, Array, Array | None, Array, Array | None],
         idx: Array,
     ) -> tuple[
         tuple[LearnerState, StreamStateT, Array, Array | None, Array, Array | None],
@@ -348,8 +355,7 @@
         if norm_history is not None and hasattr(opt_state, "normalizers"):
             new_norm_history = jax.lax.cond(
                 should_record,
-                lambda _: norm_history.at[recording_idx].set(
-                    opt_state.normalizers),
+                lambda _: norm_history.at[recording_idx].set(opt_state.normalizers),
                 lambda _: norm_history,
                 None,
             )
@@ -373,15 +379,16 @@
     )

     (
-        final_learner,
-        _,
-        final_ss_history,
-        final_b_history,
-        final_rec_indices,
-        final_norm_history,
-    ), metrics = jax.lax.scan(
-        step_fn_with_tracking, initial_carry, jnp.arange(num_steps)
-    )
+        (
+            final_learner,
+            _,
+            final_ss_history,
+            final_b_history,
+            final_rec_indices,
+            final_norm_history,
+        ),
+        metrics,
+    ) = jax.lax.scan(step_fn_with_tracking, initial_carry, jnp.arange(num_steps))

     history = StepSizeHistory(
         step_sizes=final_ss_history,
@@ -454,9 +461,7 @@ class NormalizedLinearLearner:
         Returns:
             Scalar prediction y = w @ normalize(x) + b
         """
-        normalized_obs = self._normalizer.normalize_only(
-            state.normalizer_state, observation
-        )
+        normalized_obs = self._normalizer.normalize_only(state.normalizer_state, observation)
         return self._learner.predict(state.learner_state, normalized_obs)

     def update(
@@ -602,9 +607,7 @@ def run_normalized_learning_loop[StreamStateT](

     # Tracking enabled - need to set up history arrays
     ss_interval = step_size_tracking.interval if step_size_tracking else num_steps + 1
-    norm_interval = (
-        normalizer_tracking.interval if normalizer_tracking else num_steps + 1
-    )
+    norm_interval = normalizer_tracking.interval if normalizer_tracking else num_steps + 1

     ss_num_recordings = num_steps // ss_interval if step_size_tracking else 0
     norm_num_recordings = num_steps // norm_interval if normalizer_tracking else 0
@@ -620,14 +623,10 @@
         if step_size_tracking and step_size_tracking.include_bias
         else None
     )
-    ss_rec_indices = (
-        jnp.zeros(ss_num_recordings, dtype=jnp.int32) if step_size_tracking else None
-    )
+    ss_rec_indices = jnp.zeros(ss_num_recordings, dtype=jnp.int32) if step_size_tracking else None

     # Check if we need to track Autostep normalizers
-    track_autostep_normalizers = hasattr(
-        learner_state.learner_state.optimizer_state, "normalizers"
-    )
+    track_autostep_normalizers = hasattr(learner_state.learner_state.optimizer_state, "normalizers")
     ss_normalizers = (
         jnp.zeros((ss_num_recordings, feature_dim), dtype=jnp.float32)
         if step_size_tracking and track_autostep_normalizers
@@ -744,8 +743,7 @@
         if ss_norm is not None and hasattr(opt_state, "normalizers"):
             new_ss_norm = jax.lax.cond(
                 should_record_ss,
-                lambda _: ss_norm.at[recording_idx].set(
-                    opt_state.normalizers),
+                lambda _: ss_norm.at[recording_idx].set(opt_state.normalizers),
                 lambda _: ss_norm,
                 None,
             )
@@ -809,18 +807,19 @@
     )

     (
-        final_learner,
-        _,
-        final_ss_hist,
-        final_ss_bias_hist,
-        final_ss_rec,
-        final_ss_norm,
-        final_n_means,
-        final_n_vars,
-        final_n_rec,
-    ), metrics = jax.lax.scan(
-        step_fn_with_tracking, initial_carry, jnp.arange(num_steps)
-    )
+        (
+            final_learner,
+            _,
+            final_ss_hist,
+            final_ss_bias_hist,
+            final_ss_rec,
+            final_ss_norm,
+            final_n_means,
+            final_n_vars,
+            final_n_rec,
+        ),
+        metrics,
+    ) = jax.lax.scan(step_fn_with_tracking, initial_carry, jnp.arange(num_steps))

     # Build return values based on what was tracked
     ss_history_result = None
@@ -828,14 +827,17 @@
         ss_history_result = StepSizeHistory(
             step_sizes=final_ss_hist,
             bias_step_sizes=final_ss_bias_hist,
-            recording_indices=final_ss_rec, normalizers=final_ss_norm
+            recording_indices=final_ss_rec,
+            normalizers=final_ss_norm,
         )

     norm_history_result = None
     if normalizer_tracking is not None and final_n_means is not None:
         norm_history_result = NormalizerHistory(
             means=final_n_means,
-            variances=final_n_vars, recording_indices=final_n_rec)
+            variances=final_n_vars,
+            recording_indices=final_n_rec,
+        )

     # Return appropriate tuple based on what was tracked
     if ss_history_result is not None and norm_history_result is not None:
@@ -894,15 +896,14 @@ def run_learning_loop_batched[StreamStateT](
         mean_error = result.metrics[:, :, 0].mean(axis=0)  # Average over seeds
         ```
     """
+
     # Define single-seed function that returns consistent structure
     def single_seed_run(key: Array) -> tuple[LearnerState, Array, StepSizeHistory | None]:
         result = run_learning_loop(
             learner, stream, num_steps, key, learner_state, step_size_tracking
         )
         if step_size_tracking is not None:
-            state, metrics, history = cast(
-                tuple[LearnerState, Array, StepSizeHistory], result
-            )
+            state, metrics, history = cast(tuple[LearnerState, Array, StepSizeHistory], result)
             return state, metrics, history
         else:
             state, metrics = cast(tuple[LearnerState, Array], result)
@@ -982,15 +983,13 @@ def run_normalized_learning_loop_batched[StreamStateT](
         mean_error = result.metrics[:, :, 0].mean(axis=0)  # Average over seeds
         ```
     """
+
    # Define single-seed function that returns consistent structure
     def single_seed_run(
         key: Array,
-    ) -> tuple[
-        NormalizedLearnerState, Array, StepSizeHistory | None, NormalizerHistory | None
-    ]:
+    ) -> tuple[NormalizedLearnerState, Array, StepSizeHistory | None, NormalizerHistory | None]:
         result = run_normalized_learning_loop(
-            learner, stream, num_steps, key, learner_state,
-            step_size_tracking, normalizer_tracking
+            learner, stream, num_steps, key, learner_state, step_size_tracking, normalizer_tracking
         )

         # Unpack based on what tracking was enabled
@@ -1015,9 +1014,9 @@
         return state, metrics, None, None

     # vmap over the keys dimension
-    batched_states, batched_metrics, batched_ss_history, batched_norm_history = (
-        jax.vmap(single_seed_run)(keys)
-    )
+    batched_states, batched_metrics, batched_ss_history, batched_norm_history = jax.vmap(
+        single_seed_run
+    )(keys)

     # Reconstruct batched histories if tracking was enabled
     if step_size_tracking is not None and batched_ss_history is not None:
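For context, the batched runners vmap a single-seed closure over a stack of PRNG keys, which is what gives `result.metrics` its leading seed axis in the docstring example. A self-contained sketch of the pattern (`single_seed_run` here is a stand-in, not the closure from this diff):

```python
import jax


def single_seed_run(key: jax.Array) -> jax.Array:
    # Stand-in for one full learning run: per-step metrics of shape (num_steps, 1)
    return jax.random.normal(key, (100, 1)) ** 2


keys = jax.random.split(jax.random.PRNGKey(0), num=30)  # 30 independent seeds
metrics = jax.vmap(single_seed_run)(keys)               # shape (30, 100, 1)
mean_error = metrics[:, :, 0].mean(axis=0)              # average over seeds
```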
@@ -1068,3 +1067,222 @@ def metrics_to_dicts(metrics: Array, normalized: bool = False) -> list[dict[str,
             d["normalizer_mean_var"] = float(row[3])
         result.append(d)
     return result
+
+
+# =============================================================================
+# TD Learning (for Step 3+ of Alberta Plan)
+# =============================================================================
+
+
+@chex.dataclass(frozen=True)
+class TDUpdateResult:
+    """Result of a TD learner update step.
+
+    Attributes:
+        state: Updated TD learner state
+        prediction: Value prediction V(s) before update
+        td_error: TD error δ = R + γV(s') - V(s)
+        metrics: Array of metrics [squared_td_error, td_error, mean_step_size, ...]
+    """
+
+    state: TDLearnerState
+    prediction: Prediction
+    td_error: Float[Array, ""]
+    metrics: Float[Array, " 4"]
+
+
+class TDLinearLearner:
+    """Linear function approximator for TD learning.
+
+    Computes value predictions as: `V(s) = w @ φ(s) + b`
+
+    The learner maintains weights, bias, and eligibility traces, delegating
+    the adaptation of learning rates to the TD optimizer (e.g., TDIDBD).
+
+    This follows the Alberta Plan philosophy of temporal uniformity:
+    every component updates at every time step.
+
+    Reference: Kearney et al. 2019, "Learning Feature Relevance Through Step Size
+    Adaptation in Temporal-Difference Learning"
+
+    Attributes:
+        optimizer: The TD optimizer to use for weight updates
+    """
+
+    def __init__(self, optimizer: AnyTDOptimizer | None = None):
+        """Initialize the TD linear learner.
+
+        Args:
+            optimizer: TD optimizer for weight updates. Defaults to TDIDBD()
+        """
+        self._optimizer: AnyTDOptimizer = optimizer or TDIDBD()
+
+    def init(self, feature_dim: int) -> TDLearnerState:
+        """Initialize TD learner state.
+
+        Args:
+            feature_dim: Dimension of the input feature vector
+
+        Returns:
+            Initial TD learner state with zero weights and bias
+        """
+        optimizer_state = self._optimizer.init(feature_dim)
+
+        return TDLearnerState(
+            weights=jnp.zeros(feature_dim, dtype=jnp.float32),
+            bias=jnp.array(0.0, dtype=jnp.float32),
+            optimizer_state=optimizer_state,
+        )
+
+    def predict(self, state: TDLearnerState, observation: Observation) -> Prediction:
+        """Compute value prediction for an observation.
+
+        Args:
+            state: Current TD learner state
+            observation: Input feature vector φ(s)
+
+        Returns:
+            Scalar value prediction `V(s) = w @ φ(s) + b`
+        """
+        return jnp.atleast_1d(jnp.dot(state.weights, observation) + state.bias)
+
+    def update(
+        self,
+        state: TDLearnerState,
+        observation: Observation,
+        reward: Array,
+        next_observation: Observation,
+        gamma: Array,
+    ) -> TDUpdateResult:
+        """Update learner given a TD transition.
+
+        Performs one step of TD learning:
+        1. Compute V(s) and V(s')
+        2. Compute TD error δ = R + γV(s') - V(s)
+        3. Get weight updates from TD optimizer
+        4. Apply updates to weights and bias
+
+        Args:
+            state: Current TD learner state
+            observation: Current observation φ(s)
+            reward: Reward R received
+            next_observation: Next observation φ(s')
+            gamma: Discount factor γ (0 at terminal states)
+
+        Returns:
+            TDUpdateResult with new state, prediction, TD error, and metrics
+        """
+        # Compute predictions
+        prediction = self.predict(state, observation)
+        next_prediction = self.predict(state, next_observation)
+
+        # Compute TD error: δ = R + γV(s') - V(s)
+        gamma_scalar = jnp.squeeze(gamma)
+        td_error = (
+            jnp.squeeze(reward)
+            + gamma_scalar * jnp.squeeze(next_prediction)
+            - jnp.squeeze(prediction)
+        )
+
+        # Get update from TD optimizer
+        opt_update = self._optimizer.update(
+            state.optimizer_state,
+            td_error,
+            observation,
+            next_observation,
+            gamma,
+        )
+
+        # Apply updates
+        new_weights = state.weights + opt_update.weight_delta
+        new_bias = state.bias + opt_update.bias_delta
+
+        new_state = TDLearnerState(
+            weights=new_weights,
+            bias=new_bias,
+            optimizer_state=opt_update.new_state,
+        )
+
+        # Pack metrics as array for scan compatibility
+        # Format: [squared_td_error, td_error, mean_step_size, mean_eligibility_trace]
+        squared_td_error = td_error**2
+        mean_step_size = opt_update.metrics.get("mean_step_size", 0.0)
+        mean_elig_trace = opt_update.metrics.get("mean_eligibility_trace", 0.0)
+        metrics = jnp.array(
+            [squared_td_error, td_error, mean_step_size, mean_elig_trace],
+            dtype=jnp.float32,
+        )
+
+        return TDUpdateResult(
+            state=new_state,
+            prediction=prediction,
+            td_error=jnp.atleast_1d(td_error),
+            metrics=metrics,
+        )
+
+
+class TDStream(Protocol[StateT]):
+    """Protocol for TD experience streams.
+
+    TD streams produce (s, r, s', γ) tuples for temporal-difference learning.
+    """
+
+    feature_dim: int
+
+    def init(self, key: Array) -> StateT:
+        """Initialize stream state."""
+        ...
+
+    def step(self, state: StateT, idx: Array) -> tuple[TDTimeStep, StateT]:
+        """Generate next TD transition."""
+        ...
+
+
+def run_td_learning_loop[StreamStateT](
+    learner: TDLinearLearner,
+    stream: TDStream[StreamStateT],
+    num_steps: int,
+    key: Array,
+    learner_state: TDLearnerState | None = None,
+) -> tuple[TDLearnerState, Array]:
+    """Run the TD learning loop using jax.lax.scan.
+
+    This is a JIT-compiled learning loop that uses scan for efficiency.
+    It returns metrics as a fixed-size array rather than a list of dicts.
+
+    Args:
+        learner: The TD learner to train
+        stream: TD experience stream providing (s, r, s', γ) tuples
+        num_steps: Number of learning steps to run
+        key: JAX random key for stream initialization
+        learner_state: Initial state (if None, will be initialized from stream)
+
+    Returns:
+        Tuple of (final_state, metrics_array) where metrics_array has shape
+        (num_steps, 4) with columns [squared_td_error, td_error, mean_step_size,
+        mean_eligibility_trace]
+    """
+    # Initialize states
+    if learner_state is None:
+        learner_state = learner.init(stream.feature_dim)
+    stream_state = stream.init(key)
+
+    def step_fn(
+        carry: tuple[TDLearnerState, StreamStateT], idx: Array
+    ) -> tuple[tuple[TDLearnerState, StreamStateT], Array]:
+        l_state, s_state = carry
+        timestep, new_s_state = stream.step(s_state, idx)
+        result = learner.update(
+            l_state,
+            timestep.observation,
+            timestep.reward,
+            timestep.next_observation,
+            timestep.gamma,
+        )
+        return (result.state, new_s_state), result.metrics
+
+    (final_learner, _), metrics = jax.lax.scan(
+        step_fn, (learner_state, stream_state), jnp.arange(num_steps)
+    )
+
+    return final_learner, metrics
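A runnable sketch of how the new `TDStream` protocol, `TDLinearLearner`, and `run_td_learning_loop` compose. The toy stream is hypothetical (not part of the package), and keyword construction of `TDTimeStep` is assumed from the four fields the loop reads:

```python
import jax
import jax.numpy as jnp

from alberta_framework import TDLinearLearner, TDTimeStep, run_td_learning_loop


class ConstantRewardStream:
    """Toy TD stream: fixed features, reward 1.0, discount 0.9."""

    feature_dim = 4

    def init(self, key):
        return key  # stream state is just the PRNG key

    def step(self, state, idx):
        obs = jnp.ones(self.feature_dim, dtype=jnp.float32)
        timestep = TDTimeStep(  # field names assumed from the loop body above
            observation=obs,
            reward=jnp.array(1.0),
            next_observation=obs,
            gamma=jnp.array(0.9),
        )
        return timestep, state


learner = TDLinearLearner()  # defaults to TDIDBD()
final_state, metrics = run_td_learning_loop(
    learner, ConstantRewardStream(), num_steps=1_000, key=jax.random.PRNGKey(0)
)
print(metrics[:, 0].mean())  # mean squared TD error; V should approach 1/(1-0.9) = 10
```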
alberta_framework/core/normalizers.py
CHANGED
@@ -101,10 +101,7 @@ class OnlineNormalizer:

         # Compute effective decay (ramp up from 0 to target decay)
         # This prevents instability in early steps
-        effective_decay = jnp.minimum(
-            state.decay,
-            1.0 - 1.0 / (new_count + 1.0)
-        )
+        effective_decay = jnp.minimum(state.decay, 1.0 - 1.0 / (new_count + 1.0))

         # Update mean using exponential moving average
         delta = observation - state.mean
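The collapsed `effective_decay` expression is behavior-preserving; a quick numeric check of the ramp it implements, `min(decay, 1 - 1/(n + 1))` for a target decay of 0.99:

```python
# Early on the count term dominates, giving a plain running average;
# once it crosses the target decay, the normalizer becomes a fixed-rate EMA.
decay = 0.99
for n in [1, 2, 10, 99, 1000]:
    print(n, min(decay, 1.0 - 1.0 / (n + 1.0)))
# 1 0.5
# 2 0.6666666666666667
# 10 0.9090909090909091
# 99 0.99
# 1000 0.99
```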
|