alberta-framework 0.3.0__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff compares two publicly released versions of the package and reflects the changes exactly as they appear in the public registry.
--- a/alberta_framework/__init__.py
+++ b/alberta_framework/__init__.py
@@ -39,7 +39,7 @@ References
 - Tuning-free Step-size Adaptation (Mahmood et al., 2012)
 """
 
-__version__ = "0.2.0"
+__version__ = "0.4.0"
 
 # Core types
 # Learners
@@ -47,12 +47,15 @@ from alberta_framework.core.learners import (
     LinearLearner,
     NormalizedLearnerState,
     NormalizedLinearLearner,
+    TDLinearLearner,
+    TDUpdateResult,
     UpdateResult,
     metrics_to_dicts,
     run_learning_loop,
     run_learning_loop_batched,
     run_normalized_learning_loop,
     run_normalized_learning_loop_batched,
+    run_td_learning_loop,
 )
 
 # Normalizers
@@ -63,9 +66,19 @@ from alberta_framework.core.normalizers import (
 )
 
 # Optimizers
-from alberta_framework.core.optimizers import IDBD, LMS, Autostep, Optimizer
+from alberta_framework.core.optimizers import (
+    IDBD,
+    LMS,
+    TDIDBD,
+    Autostep,
+    AutoTDIDBD,
+    Optimizer,
+    TDOptimizer,
+    TDOptimizerUpdate,
+)
 from alberta_framework.core.types import (
     AutostepState,
+    AutoTDIDBDState,
     BatchedLearningResult,
     BatchedNormalizedResult,
     IDBDState,
@@ -78,7 +91,12 @@ from alberta_framework.core.types import (
     StepSizeHistory,
     StepSizeTrackingConfig,
     Target,
+    TDIDBDState,
+    TDLearnerState,
+    TDTimeStep,
     TimeStep,
+    create_autotdidbd_state,
+    create_tdidbd_state,
 )
 
 # Streams - base
@@ -140,7 +158,7 @@ except ImportError:
 __all__ = [
     # Version
     "__version__",
-    # Types
+    # Types - Supervised Learning
     "AutostepState",
     "BatchedLearningResult",
     "BatchedNormalizedResult",
@@ -157,15 +175,28 @@ __all__ = [
     "Target",
     "TimeStep",
     "UpdateResult",
-    # Optimizers
+    # Types - TD Learning
+    "AutoTDIDBDState",
+    "TDIDBDState",
+    "TDLearnerState",
+    "TDTimeStep",
+    "TDUpdateResult",
+    "create_tdidbd_state",
+    "create_autotdidbd_state",
+    # Optimizers - Supervised Learning
     "Autostep",
     "IDBD",
     "LMS",
     "Optimizer",
+    # Optimizers - TD Learning
+    "AutoTDIDBD",
+    "TDIDBD",
+    "TDOptimizer",
+    "TDOptimizerUpdate",
     # Normalizers
     "OnlineNormalizer",
     "create_normalizer_state",
-    # Learners
+    # Learners - Supervised Learning
     "LinearLearner",
     "NormalizedLearnerState",
     "NormalizedLinearLearner",
@@ -174,6 +205,9 @@ __all__ = [
     "run_normalized_learning_loop",
     "run_normalized_learning_loop_batched",
     "metrics_to_dicts",
+    # Learners - TD Learning
+    "TDLinearLearner",
+    "run_td_learning_loop",
     # Streams - protocol
     "ScanStream",
     # Streams - synthetic
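The headline change in 0.4.0 is a temporal-difference (TD) learning API, now re-exported at the package top level alongside the supervised-learning surface. A minimal sketch of how the new names fit together, assuming only the defaults visible in this diff (both `TDIDBD()` and `TDLinearLearner()` construct with no arguments; TDIDBD's own hyperparameters are not shown here):

```python
# Sketch of the new 0.4.0 top-level API; constructor defaults are taken from
# the diff below, everything else is illustrative.
from alberta_framework import TDIDBD, TDLinearLearner

# TDIDBD adapts per-weight step sizes from the TD error (Kearney et al., 2019);
# TDLinearLearner wires it into a linear value function V(s) = w @ phi(s) + b.
learner = TDLinearLearner(optimizer=TDIDBD())
state = learner.init(feature_dim=8)  # zero weights and zero bias
```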
--- a/alberta_framework/core/__init__.py
+++ b/alberta_framework/core/__init__.py
@@ -1,18 +1,31 @@
 """Core components for the Alberta Framework."""
 
-from alberta_framework.core.learners import LinearLearner
-from alberta_framework.core.optimizers import IDBD, LMS, Optimizer
+from alberta_framework.core.learners import LinearLearner, TDLinearLearner, TDUpdateResult
+from alberta_framework.core.optimizers import (
+    IDBD,
+    LMS,
+    TDIDBD,
+    AutoTDIDBD,
+    Optimizer,
+    TDOptimizer,
+    TDOptimizerUpdate,
+)
 from alberta_framework.core.types import (
+    AutoTDIDBDState,
     IDBDState,
     LearnerState,
     LMSState,
     Observation,
     Prediction,
     Target,
+    TDIDBDState,
+    TDLearnerState,
+    TDTimeStep,
     TimeStep,
 )
 
 __all__ = [
+    # Supervised learning
     "IDBD",
     "IDBDState",
     "LMS",
@@ -24,4 +37,15 @@ __all__ = [
     "Prediction",
     "Target",
     "TimeStep",
+    # TD learning
+    "AutoTDIDBD",
+    "AutoTDIDBDState",
+    "TDIDBD",
+    "TDIDBDState",
+    "TDLearnerState",
+    "TDLinearLearner",
+    "TDOptimizer",
+    "TDOptimizerUpdate",
+    "TDTimeStep",
+    "TDUpdateResult",
 ]
--- a/alberta_framework/core/learners.py
+++ b/alberta_framework/core/learners.py
@@ -5,7 +5,7 @@ for temporally-uniform learning. Uses JAX's scan for efficient JIT-compiled
 training loops.
 """
 
-from typing import cast
+from typing import Protocol, TypeVar, cast
 
 import chex
 import jax
@@ -14,9 +14,10 @@ from jax import Array
 from jaxtyping import Float
 
 from alberta_framework.core.normalizers import NormalizerState, OnlineNormalizer
-from alberta_framework.core.optimizers import LMS, Optimizer
+from alberta_framework.core.optimizers import LMS, TDIDBD, Optimizer, TDOptimizer
 from alberta_framework.core.types import (
     AutostepState,
+    AutoTDIDBDState,
     BatchedLearningResult,
     BatchedNormalizedResult,
     IDBDState,
@@ -29,12 +30,21 @@ from alberta_framework.core.types import (
     StepSizeHistory,
     StepSizeTrackingConfig,
     Target,
+    TDIDBDState,
+    TDLearnerState,
+    TDTimeStep,
 )
 from alberta_framework.streams.base import ScanStream
 
+# Type variable for TD stream state
+StateT = TypeVar("StateT")
+
 # Type alias for any optimizer type
 AnyOptimizer = Optimizer[LMSState] | Optimizer[IDBDState] | Optimizer[AutostepState]
 
+# Type alias for any TD optimizer type
+AnyTDOptimizer = TDOptimizer[TDIDBDState] | TDOptimizer[AutoTDIDBDState]
+
 
 @chex.dataclass(frozen=True)
 class UpdateResult:
@@ -167,7 +177,8 @@ class LinearLearner:
         # Note: type ignore needed because we can't statically prove optimizer_state
         # matches the optimizer's expected state type (though they will at runtime)
         opt_update = self._optimizer.update(
-            state.optimizer_state, error,
+            state.optimizer_state,
+            error,
             observation,
         )
 
@@ -270,9 +281,7 @@ def run_learning_loop[StreamStateT](
 
     # Pre-allocate history arrays
     step_size_history = jnp.zeros((num_recordings, feature_dim), dtype=jnp.float32)
-    bias_history = (
-        jnp.zeros(num_recordings, dtype=jnp.float32) if include_bias else None
-    )
+    bias_history = jnp.zeros(num_recordings, dtype=jnp.float32) if include_bias else None
     recording_indices = jnp.zeros(num_recordings, dtype=jnp.int32)
 
     # Check if we need to track Autostep normalizers
@@ -285,9 +294,7 @@
     )
 
     def step_fn_with_tracking(
-        carry: tuple[
-            LearnerState, StreamStateT, Array, Array | None, Array, Array | None
-        ],
+        carry: tuple[LearnerState, StreamStateT, Array, Array | None, Array, Array | None],
         idx: Array,
     ) -> tuple[
         tuple[LearnerState, StreamStateT, Array, Array | None, Array, Array | None],
@@ -348,8 +355,7 @@
         if norm_history is not None and hasattr(opt_state, "normalizers"):
             new_norm_history = jax.lax.cond(
                 should_record,
-                lambda _: norm_history.at[recording_idx].set(
-                    opt_state.normalizers ),
+                lambda _: norm_history.at[recording_idx].set(opt_state.normalizers),
                 lambda _: norm_history,
                 None,
             )
@@ -373,15 +379,16 @@
     )
 
     (
-        final_learner,
-        _,
-        final_ss_history,
-        final_b_history,
-        final_rec_indices,
-        final_norm_history,
-    ), metrics = jax.lax.scan(
-        step_fn_with_tracking, initial_carry, jnp.arange(num_steps)
-    )
+        (
+            final_learner,
+            _,
+            final_ss_history,
+            final_b_history,
+            final_rec_indices,
+            final_norm_history,
+        ),
+        metrics,
+    ) = jax.lax.scan(step_fn_with_tracking, initial_carry, jnp.arange(num_steps))
 
     history = StepSizeHistory(
         step_sizes=final_ss_history,
@@ -454,9 +461,7 @@ class NormalizedLinearLearner:
         Returns:
             Scalar prediction y = w @ normalize(x) + b
         """
-        normalized_obs = self._normalizer.normalize_only(
-            state.normalizer_state, observation
-        )
+        normalized_obs = self._normalizer.normalize_only(state.normalizer_state, observation)
         return self._learner.predict(state.learner_state, normalized_obs)
 
     def update(
@@ -602,9 +607,7 @@ def run_normalized_learning_loop[StreamStateT](
 
     # Tracking enabled - need to set up history arrays
     ss_interval = step_size_tracking.interval if step_size_tracking else num_steps + 1
-    norm_interval = (
-        normalizer_tracking.interval if normalizer_tracking else num_steps + 1
-    )
+    norm_interval = normalizer_tracking.interval if normalizer_tracking else num_steps + 1
 
     ss_num_recordings = num_steps // ss_interval if step_size_tracking else 0
     norm_num_recordings = num_steps // norm_interval if normalizer_tracking else 0
@@ -620,14 +623,10 @@
         if step_size_tracking and step_size_tracking.include_bias
         else None
     )
-    ss_rec_indices = (
-        jnp.zeros(ss_num_recordings, dtype=jnp.int32) if step_size_tracking else None
-    )
+    ss_rec_indices = jnp.zeros(ss_num_recordings, dtype=jnp.int32) if step_size_tracking else None
 
     # Check if we need to track Autostep normalizers
-    track_autostep_normalizers = hasattr(
-        learner_state.learner_state.optimizer_state, "normalizers"
-    )
+    track_autostep_normalizers = hasattr(learner_state.learner_state.optimizer_state, "normalizers")
     ss_normalizers = (
         jnp.zeros((ss_num_recordings, feature_dim), dtype=jnp.float32)
         if step_size_tracking and track_autostep_normalizers
@@ -744,8 +743,7 @@
         if ss_norm is not None and hasattr(opt_state, "normalizers"):
             new_ss_norm = jax.lax.cond(
                 should_record_ss,
-                lambda _: ss_norm.at[recording_idx].set(
-                    opt_state.normalizers ),
+                lambda _: ss_norm.at[recording_idx].set(opt_state.normalizers),
                 lambda _: ss_norm,
                 None,
             )
@@ -809,18 +807,19 @@
     )
 
     (
-        final_learner,
-        _,
-        final_ss_hist,
-        final_ss_bias_hist,
-        final_ss_rec,
-        final_ss_norm,
-        final_n_means,
-        final_n_vars,
-        final_n_rec,
-    ), metrics = jax.lax.scan(
-        step_fn_with_tracking, initial_carry, jnp.arange(num_steps)
-    )
+        (
+            final_learner,
+            _,
+            final_ss_hist,
+            final_ss_bias_hist,
+            final_ss_rec,
+            final_ss_norm,
+            final_n_means,
+            final_n_vars,
+            final_n_rec,
+        ),
+        metrics,
+    ) = jax.lax.scan(step_fn_with_tracking, initial_carry, jnp.arange(num_steps))
 
     # Build return values based on what was tracked
     ss_history_result = None
@@ -828,14 +827,17 @@
         ss_history_result = StepSizeHistory(
             step_sizes=final_ss_hist,
             bias_step_sizes=final_ss_bias_hist,
-            recording_indices=final_ss_rec, normalizers=final_ss_norm,
+            recording_indices=final_ss_rec,
+            normalizers=final_ss_norm,
         )
 
     norm_history_result = None
     if normalizer_tracking is not None and final_n_means is not None:
         norm_history_result = NormalizerHistory(
             means=final_n_means,
-            variances=final_n_vars, recording_indices=final_n_rec, )
+            variances=final_n_vars,
+            recording_indices=final_n_rec,
+        )
 
     # Return appropriate tuple based on what was tracked
     if ss_history_result is not None and norm_history_result is not None:
@@ -894,15 +896,14 @@ def run_learning_loop_batched[StreamStateT](
     mean_error = result.metrics[:, :, 0].mean(axis=0)  # Average over seeds
     ```
     """
+
     # Define single-seed function that returns consistent structure
     def single_seed_run(key: Array) -> tuple[LearnerState, Array, StepSizeHistory | None]:
        result = run_learning_loop(
            learner, stream, num_steps, key, learner_state, step_size_tracking
        )
        if step_size_tracking is not None:
-            state, metrics, history = cast(
-                tuple[LearnerState, Array, StepSizeHistory], result
-            )
+            state, metrics, history = cast(tuple[LearnerState, Array, StepSizeHistory], result)
            return state, metrics, history
        else:
            state, metrics = cast(tuple[LearnerState, Array], result)
@@ -982,15 +983,13 @@ def run_normalized_learning_loop_batched[StreamStateT](
     mean_error = result.metrics[:, :, 0].mean(axis=0)  # Average over seeds
     ```
     """
+
     # Define single-seed function that returns consistent structure
     def single_seed_run(
         key: Array,
-    ) -> tuple[
-        NormalizedLearnerState, Array, StepSizeHistory | None, NormalizerHistory | None
-    ]:
+    ) -> tuple[NormalizedLearnerState, Array, StepSizeHistory | None, NormalizerHistory | None]:
         result = run_normalized_learning_loop(
-            learner, stream, num_steps, key, learner_state,
-            step_size_tracking, normalizer_tracking
+            learner, stream, num_steps, key, learner_state, step_size_tracking, normalizer_tracking
         )
 
         # Unpack based on what tracking was enabled
@@ -1015,9 +1014,9 @@ def run_normalized_learning_loop_batched[StreamStateT](
         return state, metrics, None, None
 
     # vmap over the keys dimension
-    batched_states, batched_metrics, batched_ss_history, batched_norm_history = (
-        jax.vmap(single_seed_run)(keys)
-    )
+    batched_states, batched_metrics, batched_ss_history, batched_norm_history = jax.vmap(
+        single_seed_run
+    )(keys)
 
     # Reconstruct batched histories if tracking was enabled
     if step_size_tracking is not None and batched_ss_history is not None:
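Both batched runners share one pattern: define a single-seed closure with a fixed return structure, then `jax.vmap` it over a stack of PRNG keys so every seed runs inside the same compiled scan. A hedged sketch of the calling convention (the `keys` parameter name, the no-argument `LinearLearner()`, and `my_stream` are assumptions for illustration, not taken from this diff):

```python
import jax

# One key per seed; vmap maps the single-seed loop across axis 0 of `keys`.
keys = jax.random.split(jax.random.PRNGKey(0), num=8)
result = run_learning_loop_batched(
    learner=LinearLearner(),  # assumed default optimizer
    stream=my_stream,         # placeholder for any ScanStream
    num_steps=10_000,
    keys=keys,
)
# Per the docstring above: axis 0 is seeds, axis 1 is steps, axis 2 is metrics.
mean_error = result.metrics[:, :, 0].mean(axis=0)
```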
@@ -1068,3 +1067,222 @@ def metrics_to_dicts(metrics: Array, normalized: bool = False) -> list[dict[str,
             d["normalizer_mean_var"] = float(row[3])
         result.append(d)
     return result
+
+
+# =============================================================================
+# TD Learning (for Step 3+ of Alberta Plan)
+# =============================================================================
+
+
+@chex.dataclass(frozen=True)
+class TDUpdateResult:
+    """Result of a TD learner update step.
+
+    Attributes:
+        state: Updated TD learner state
+        prediction: Value prediction V(s) before update
+        td_error: TD error δ = R + γV(s') - V(s)
+        metrics: Array of metrics [squared_td_error, td_error, mean_step_size, ...]
+    """
+
+    state: TDLearnerState
+    prediction: Prediction
+    td_error: Float[Array, ""]
+    metrics: Float[Array, " 4"]
+
+
+class TDLinearLearner:
+    """Linear function approximator for TD learning.
+
+    Computes value predictions as: `V(s) = w @ φ(s) + b`
+
+    The learner maintains weights, bias, and eligibility traces, delegating
+    the adaptation of learning rates to the TD optimizer (e.g., TDIDBD).
+
+    This follows the Alberta Plan philosophy of temporal uniformity:
+    every component updates at every time step.
+
+    Reference: Kearney et al. 2019, "Learning Feature Relevance Through Step Size
+    Adaptation in Temporal-Difference Learning"
+
+    Attributes:
+        optimizer: The TD optimizer to use for weight updates
+    """
+
+    def __init__(self, optimizer: AnyTDOptimizer | None = None):
+        """Initialize the TD linear learner.
+
+        Args:
+            optimizer: TD optimizer for weight updates. Defaults to TDIDBD()
+        """
+        self._optimizer: AnyTDOptimizer = optimizer or TDIDBD()
+
+    def init(self, feature_dim: int) -> TDLearnerState:
+        """Initialize TD learner state.
+
+        Args:
+            feature_dim: Dimension of the input feature vector
+
+        Returns:
+            Initial TD learner state with zero weights and bias
+        """
+        optimizer_state = self._optimizer.init(feature_dim)
+
+        return TDLearnerState(
+            weights=jnp.zeros(feature_dim, dtype=jnp.float32),
+            bias=jnp.array(0.0, dtype=jnp.float32),
+            optimizer_state=optimizer_state,
+        )
+
+    def predict(self, state: TDLearnerState, observation: Observation) -> Prediction:
+        """Compute value prediction for an observation.
+
+        Args:
+            state: Current TD learner state
+            observation: Input feature vector φ(s)
+
+        Returns:
+            Scalar value prediction `V(s) = w @ φ(s) + b`
+        """
+        return jnp.atleast_1d(jnp.dot(state.weights, observation) + state.bias)
+
+    def update(
+        self,
+        state: TDLearnerState,
+        observation: Observation,
+        reward: Array,
+        next_observation: Observation,
+        gamma: Array,
+    ) -> TDUpdateResult:
+        """Update learner given a TD transition.
+
+        Performs one step of TD learning:
+        1. Compute V(s) and V(s')
+        2. Compute TD error δ = R + γV(s') - V(s)
+        3. Get weight updates from TD optimizer
+        4. Apply updates to weights and bias
+
+        Args:
+            state: Current TD learner state
+            observation: Current observation φ(s)
+            reward: Reward R received
+            next_observation: Next observation φ(s')
+            gamma: Discount factor γ (0 at terminal states)
+
+        Returns:
+            TDUpdateResult with new state, prediction, TD error, and metrics
+        """
+        # Compute predictions
+        prediction = self.predict(state, observation)
+        next_prediction = self.predict(state, next_observation)
+
+        # Compute TD error: δ = R + γV(s') - V(s)
+        gamma_scalar = jnp.squeeze(gamma)
+        td_error = (
+            jnp.squeeze(reward)
+            + gamma_scalar * jnp.squeeze(next_prediction)
+            - jnp.squeeze(prediction)
+        )
+
+        # Get update from TD optimizer
+        opt_update = self._optimizer.update(
+            state.optimizer_state,
+            td_error,
+            observation,
+            next_observation,
+            gamma,
+        )
+
+        # Apply updates
+        new_weights = state.weights + opt_update.weight_delta
+        new_bias = state.bias + opt_update.bias_delta
+
+        new_state = TDLearnerState(
+            weights=new_weights,
+            bias=new_bias,
+            optimizer_state=opt_update.new_state,
+        )
+
+        # Pack metrics as array for scan compatibility
+        # Format: [squared_td_error, td_error, mean_step_size, mean_eligibility_trace]
+        squared_td_error = td_error**2
+        mean_step_size = opt_update.metrics.get("mean_step_size", 0.0)
+        mean_elig_trace = opt_update.metrics.get("mean_eligibility_trace", 0.0)
+        metrics = jnp.array(
+            [squared_td_error, td_error, mean_step_size, mean_elig_trace],
+            dtype=jnp.float32,
+        )
+
+        return TDUpdateResult(
+            state=new_state,
+            prediction=prediction,
+            td_error=jnp.atleast_1d(td_error),
+            metrics=metrics,
+        )
+
+
+class TDStream(Protocol[StateT]):
+    """Protocol for TD experience streams.
+
+    TD streams produce (s, r, s', γ) tuples for temporal-difference learning.
+    """
+
+    feature_dim: int
+
+    def init(self, key: Array) -> StateT:
+        """Initialize stream state."""
+        ...
+
+    def step(self, state: StateT, idx: Array) -> tuple[TDTimeStep, StateT]:
+        """Generate next TD transition."""
+        ...
+
+
+def run_td_learning_loop[StreamStateT](
+    learner: TDLinearLearner,
+    stream: TDStream[StreamStateT],
+    num_steps: int,
+    key: Array,
+    learner_state: TDLearnerState | None = None,
+) -> tuple[TDLearnerState, Array]:
+    """Run the TD learning loop using jax.lax.scan.
+
+    This is a JIT-compiled learning loop that uses scan for efficiency.
+    It returns metrics as a fixed-size array rather than a list of dicts.
+
+    Args:
+        learner: The TD learner to train
+        stream: TD experience stream providing (s, r, s', γ) tuples
+        num_steps: Number of learning steps to run
+        key: JAX random key for stream initialization
+        learner_state: Initial state (if None, will be initialized from stream)
+
+    Returns:
+        Tuple of (final_state, metrics_array) where metrics_array has shape
+        (num_steps, 4) with columns [squared_td_error, td_error, mean_step_size,
+        mean_eligibility_trace]
+    """
+    # Initialize states
+    if learner_state is None:
+        learner_state = learner.init(stream.feature_dim)
+    stream_state = stream.init(key)
+
+    def step_fn(
+        carry: tuple[TDLearnerState, StreamStateT], idx: Array
+    ) -> tuple[tuple[TDLearnerState, StreamStateT], Array]:
+        l_state, s_state = carry
+        timestep, new_s_state = stream.step(s_state, idx)
+        result = learner.update(
+            l_state,
+            timestep.observation,
+            timestep.reward,
+            timestep.next_observation,
+            timestep.gamma,
+        )
+        return (result.state, new_s_state), result.metrics
+
+    (final_learner, _), metrics = jax.lax.scan(
+        step_fn, (learner_state, stream_state), jnp.arange(num_steps)
+    )
+
+    return final_learner, metrics
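Taken together, the additions form a complete TD prediction pipeline: a `TDStream` yields (s, r, s', γ) transitions, `TDLinearLearner.update` turns each one into a TD error δ = R + γV(s') - V(s) plus an optimizer step, and `run_td_learning_loop` scans the whole thing. An end-to-end sketch under stated assumptions: the stream class below is invented for illustration, and constructing `TDTimeStep` by keyword assumes the field names used in `step_fn` above.

```python
import jax
import jax.numpy as jnp
from jax import Array

from alberta_framework import TDLinearLearner, TDTimeStep, run_td_learning_loop


class RandomFeatureStream:
    """Illustrative TDStream: Gaussian features, constant reward, fixed gamma."""

    feature_dim: int = 4

    def init(self, key: Array) -> Array:
        return key  # the stream state is just a PRNG key

    def step(self, state: Array, idx: Array) -> tuple[TDTimeStep, Array]:
        key, k_obs, k_next = jax.random.split(state, 3)
        timestep = TDTimeStep(  # field names assumed from learner.update() usage
            observation=jax.random.normal(k_obs, (self.feature_dim,)),
            reward=jnp.array(1.0),
            next_observation=jax.random.normal(k_next, (self.feature_dim,)),
            gamma=jnp.array(0.9),
        )
        return timestep, key


final_state, metrics = run_td_learning_loop(
    learner=TDLinearLearner(),  # defaults to TDIDBD
    stream=RandomFeatureStream(),
    num_steps=1_000,
    key=jax.random.PRNGKey(0),
)
# With R = 1 and gamma = 0.9 everywhere, every state's true value is
# R / (1 - gamma) = 10, so the learned bias should drift toward 10 while
# metrics[:, 0] (the squared TD error) shrinks.
print(final_state.bias, metrics[:, 0].mean())
```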
--- a/alberta_framework/core/normalizers.py
+++ b/alberta_framework/core/normalizers.py
@@ -101,10 +101,7 @@ class OnlineNormalizer:
 
         # Compute effective decay (ramp up from 0 to target decay)
         # This prevents instability in early steps
-        effective_decay = jnp.minimum(
-            state.decay,
-            1.0 - 1.0 / (new_count + 1.0)
-        )
+        effective_decay = jnp.minimum(state.decay, 1.0 - 1.0 / (new_count + 1.0))
 
         # Update mean using exponential moving average
         delta = observation - state.mean
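The one-liner above implements a warm-up schedule: while the count is small, `1 - 1/(count + 1)` is the smaller term, so the normalizer behaves like a plain sample average; the configured decay only takes over once enough steps have accumulated. A quick standalone illustration of the values (the 0.99 target is an arbitrary example):

```python
import jax.numpy as jnp

def effective_decay(decay: float, count: int) -> jnp.ndarray:
    # Same formula as in OnlineNormalizer above
    return jnp.minimum(decay, 1.0 - 1.0 / (count + 1.0))

print(effective_decay(0.99, 1))   # 0.5  -- early: new data weighted heavily
print(effective_decay(0.99, 9))   # 0.9
print(effective_decay(0.99, 99))  # 0.99 -- ramp has reached the target decay
```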