alberta-framework 0.3.0__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -16,7 +16,13 @@ import jax.numpy as jnp
16
16
  from jax import Array
17
17
  from jaxtyping import Float
18
18
 
19
- from alberta_framework.core.types import AutostepState, IDBDState, LMSState
19
+ from alberta_framework.core.types import (
20
+ AutostepState,
21
+ AutoTDIDBDState,
22
+ IDBDState,
23
+ LMSState,
24
+ TDIDBDState,
25
+ )
20
26
 
21
27
 
22
28
  @chex.dataclass(frozen=True)
@@ -424,3 +430,494 @@ class Autostep(Optimizer[AutostepState]):
424
430
  "mean_normalizer": jnp.mean(state.normalizers),
425
431
  },
426
432
  )
433
+
434
+
435
+ # =============================================================================
436
+ # TD Optimizers (for Step 3+ of Alberta Plan)
437
+ # =============================================================================
438
+
439
+
440
@chex.dataclass(frozen=True)
class TDOptimizerUpdate:
    """Output bundle produced by one TD optimizer step.

    Attributes:
        weight_delta: Additive update to apply to the weight vector
        bias_delta: Additive update to apply to the bias term
        new_state: Optimizer state after the step
        metrics: Scalar diagnostics for logging
    """

    weight_delta: Float[Array, " feature_dim"]
    bias_delta: Float[Array, ""]
    new_state: TDIDBDState | AutoTDIDBDState
    metrics: dict[str, Array]
455
+
456
+
457
+ class TDOptimizer[StateT: (TDIDBDState, AutoTDIDBDState)](ABC):
458
+ """Base class for TD optimizers.
459
+
460
+ TD optimizers handle temporal-difference learning with eligibility traces.
461
+ They take TD error and both current and next observations as input.
462
+ """
463
+
464
+ @abstractmethod
465
+ def init(self, feature_dim: int) -> StateT:
466
+ """Initialize optimizer state.
467
+
468
+ Args:
469
+ feature_dim: Dimension of weight vector
470
+
471
+ Returns:
472
+ Initial optimizer state
473
+ """
474
+ ...
475
+
476
+ @abstractmethod
477
+ def update(
478
+ self,
479
+ state: StateT,
480
+ td_error: Array,
481
+ observation: Array,
482
+ next_observation: Array,
483
+ gamma: Array,
484
+ ) -> TDOptimizerUpdate:
485
+ """Compute weight updates given TD error.
486
+
487
+ Args:
488
+ state: Current optimizer state
489
+ td_error: TD error δ = R + γV(s') - V(s)
490
+ observation: Current observation φ(s)
491
+ next_observation: Next observation φ(s')
492
+ gamma: Discount factor γ (0 at terminal)
493
+
494
+ Returns:
495
+ TDOptimizerUpdate with deltas and new state
496
+ """
497
+ ...
498
+
499
+
500
class TDIDBD(TDOptimizer[TDIDBDState]):
    """Temporal-difference IDBD with per-weight adaptive step-sizes.

    Extends IDBD to TD learning with eligibility traces: every weight keeps
    its own step-size, which is meta-learned from the correlation of recent
    updates.

    Two variants are supported:
    - Semi-gradient (default): meta-update uses only φ(s); more stable
    - Ordinary gradient: meta-update uses γφ(s') - φ(s); more accurate but
      more sensitive

    Reference: Kearney et al. 2019, "Learning Feature Relevance Through Step Size
    Adaptation in Temporal-Difference Learning"

    Semi-gradient TD-IDBD (Algorithm 3 in the paper):
    1. TD error: `δ = R + γ*w^T*φ(s') - w^T*φ(s)`
    2. Meta-weights: `β_i += θ*δ*φ_i(s)*h_i`
    3. Step-sizes: `α_i = exp(β_i)`
    4. Eligibility traces: `z_i = γ*λ*z_i + φ_i(s)`
    5. Weights: `w_i += α_i*δ*z_i`
    6. h traces: `h_i = h_i*[1 - α_i*φ_i(s)*z_i]^+ + α_i*δ*z_i`

    Attributes:
        initial_step_size: Initial per-weight step-size
        meta_step_size: Meta learning rate theta
        trace_decay: Eligibility trace decay lambda
        use_semi_gradient: If True, use semi-gradient variant (default)
    """

    def __init__(
        self,
        initial_step_size: float = 0.01,
        meta_step_size: float = 0.01,
        trace_decay: float = 0.0,
        use_semi_gradient: bool = True,
    ):
        """Initialize TD-IDBD optimizer.

        Args:
            initial_step_size: Initial value for per-weight step-sizes
            meta_step_size: Meta learning rate theta for adapting step-sizes
            trace_decay: Eligibility trace decay lambda (0 = TD(0))
            use_semi_gradient: If True, use semi-gradient variant (recommended)
        """
        self._initial_step_size = initial_step_size
        self._meta_step_size = meta_step_size
        self._trace_decay = trace_decay
        self._use_semi_gradient = use_semi_gradient

    def init(self, feature_dim: int) -> TDIDBDState:
        """Build the initial TD-IDBD state.

        Args:
            feature_dim: Dimension of weight vector

        Returns:
            TD-IDBD state with per-weight step-sizes, traces, and h traces
        """
        log_alpha = jnp.log(self._initial_step_size)
        zeros = jnp.zeros(feature_dim, dtype=jnp.float32)
        return TDIDBDState(
            log_step_sizes=jnp.full(feature_dim, log_alpha, dtype=jnp.float32),
            eligibility_traces=zeros,
            h_traces=zeros,
            meta_step_size=jnp.array(self._meta_step_size, dtype=jnp.float32),
            trace_decay=jnp.array(self._trace_decay, dtype=jnp.float32),
            bias_log_step_size=jnp.array(log_alpha, dtype=jnp.float32),
            bias_eligibility_trace=jnp.array(0.0, dtype=jnp.float32),
            bias_h_trace=jnp.array(0.0, dtype=jnp.float32),
        )

    def update(
        self,
        state: TDIDBDState,
        td_error: Array,
        observation: Array,
        next_observation: Array,
        gamma: Array,
    ) -> TDOptimizerUpdate:
        """Compute a TD-IDBD weight update with adaptive step-sizes.

        Implements Algorithm 3 (semi-gradient) or Algorithm 4 (ordinary
        gradient) from Kearney et al. 2019.

        Args:
            state: Current TD-IDBD state
            td_error: TD error δ = R + γV(s') - V(s)
            observation: Current observation φ(s)
            next_observation: Next observation φ(s')
            gamma: Discount factor γ (0 at terminal)

        Returns:
            TDOptimizerUpdate with weight deltas and updated state
        """
        td = jnp.squeeze(td_error)
        discount = jnp.squeeze(gamma)
        theta = state.meta_step_size
        lam = state.trace_decay

        # γ*φ(s') - φ(s); only the ordinary-gradient variant reads this.
        phi_diff = discount * next_observation - observation

        # Meta-update of the log step-sizes.
        if self._use_semi_gradient:
            # Algorithm 3: β_i += θ*δ*φ_i(s)*h_i
            correlation = td * observation * state.h_traces
            log_alphas = state.log_step_sizes + theta * correlation
        else:
            # Algorithm 4: β_i -= θ*δ*[γφ_i(s') - φ_i(s)]*h_i
            # (sign flips because the gradient direction is reversed)
            correlation = td * phi_diff * state.h_traces
            log_alphas = state.log_step_sizes - theta * correlation

        # Keep exp(β) in a numerically safe range.
        log_alphas = jnp.clip(log_alphas, -10.0, 2.0)
        alphas = jnp.exp(log_alphas)

        # Eligibility traces: z_i = γ*λ*z_i + φ_i(s)
        traces = discount * lam * state.eligibility_traces + observation

        # Weight delta: α_i*δ*z_i
        weight_delta = alphas * td * traces

        # h trace update (line 9 of the respective algorithm).
        if self._use_semi_gradient:
            # h_i = h_i*[1 - α_i*φ_i(s)*z_i]^+ + α_i*δ*z_i
            h_decay = jnp.maximum(0.0, 1.0 - alphas * observation * traces)
        else:
            # h_i = h_i*[1 + α_i*z_i*(γφ_i(s') - φ_i(s))]^+ + α_i*δ*z_i
            h_decay = jnp.maximum(0.0, 1.0 + alphas * traces * phi_diff)
        h_traces = state.h_traces * h_decay + alphas * td * traces

        # --- Bias term: implicit constant feature of 1, so φ-diff is γ-1 ---
        bias_phi_diff = discount - 1.0
        if self._use_semi_gradient:
            bias_correlation = td * state.bias_h_trace
            bias_log_alpha = state.bias_log_step_size + theta * bias_correlation
        else:
            bias_correlation = td * bias_phi_diff * state.bias_h_trace
            bias_log_alpha = state.bias_log_step_size - theta * bias_correlation

        bias_log_alpha = jnp.clip(bias_log_alpha, -10.0, 2.0)
        bias_alpha = jnp.exp(bias_log_alpha)

        # Bias eligibility trace (feature is 1).
        bias_trace = discount * lam * state.bias_eligibility_trace + 1.0
        bias_delta = bias_alpha * td * bias_trace

        if self._use_semi_gradient:
            bias_h_decay = jnp.maximum(0.0, 1.0 - bias_alpha * bias_trace)
        else:
            bias_h_decay = jnp.maximum(
                0.0, 1.0 + bias_alpha * bias_trace * bias_phi_diff
            )
        bias_h_trace = state.bias_h_trace * bias_h_decay + bias_alpha * td * bias_trace

        new_state = TDIDBDState(
            log_step_sizes=log_alphas,
            eligibility_traces=traces,
            h_traces=h_traces,
            meta_step_size=theta,
            trace_decay=lam,
            bias_log_step_size=bias_log_alpha,
            bias_eligibility_trace=bias_trace,
            bias_h_trace=bias_h_trace,
        )

        return TDOptimizerUpdate(
            weight_delta=weight_delta,
            bias_delta=bias_delta,
            new_state=new_state,
            metrics={
                "mean_step_size": jnp.mean(alphas),
                "min_step_size": jnp.min(alphas),
                "max_step_size": jnp.max(alphas),
                "mean_eligibility_trace": jnp.mean(jnp.abs(traces)),
            },
        )
697
+
698
+
699
class AutoTDIDBD(TDOptimizer[AutoTDIDBDState]):
    """AutoStep-style normalized TD-IDBD optimizer.

    Adds AutoStep-style normalization on top of TDIDBD to improve stability
    and reduce sensitivity to the meta step-size theta:
    1. The meta-weight update is normalized by a running trace of recent updates
    2. Effective step-sizes are normalized to prevent overshooting

    Reference: Kearney et al. 2019, Algorithm 6 "AutoStep Style Normalized TIDBD(λ)"

    The AutoTDIDBD algorithm:
    1. TD error: `δ = R + γ*w^T*φ(s') - w^T*φ(s)`
    2. Normalizers: `η_i = max(|δ*[γφ_i(s')-φ_i(s)]*h_i|,
       η_i - (1/τ)*α_i*[γφ_i(s')-φ_i(s)]*z_i*(|δ*φ_i(s)*h_i| - η_i))`
    3. Normalized meta-update: `β_i -= θ*(1/η_i)*δ*[γφ_i(s')-φ_i(s)]*h_i`
    4. Effective step-size normalization: `M = max(-exp(β)*[γφ(s')-φ(s)]^T*z, 1)`
       then `β_i -= log(M)`
    5. Weights and traces update as in TIDBD

    Attributes:
        initial_step_size: Initial per-weight step-size
        meta_step_size: Meta learning rate theta
        trace_decay: Eligibility trace decay lambda
        normalizer_decay: Decay parameter tau for normalizers
    """

    def __init__(
        self,
        initial_step_size: float = 0.01,
        meta_step_size: float = 0.01,
        trace_decay: float = 0.0,
        normalizer_decay: float = 10000.0,
    ):
        """Initialize AutoTDIDBD optimizer.

        Args:
            initial_step_size: Initial value for per-weight step-sizes
            meta_step_size: Meta learning rate theta for adapting step-sizes
            trace_decay: Eligibility trace decay lambda (0 = TD(0))
            normalizer_decay: Decay parameter tau for normalizers (default: 10000)
        """
        self._initial_step_size = initial_step_size
        self._meta_step_size = meta_step_size
        self._trace_decay = trace_decay
        self._normalizer_decay = normalizer_decay

    def init(self, feature_dim: int) -> AutoTDIDBDState:
        """Build the initial AutoTDIDBD state.

        Args:
            feature_dim: Dimension of weight vector

        Returns:
            AutoTDIDBD state with step-sizes, traces, h traces, and normalizers
        """
        log_alpha = jnp.log(self._initial_step_size)
        zeros = jnp.zeros(feature_dim, dtype=jnp.float32)
        return AutoTDIDBDState(
            log_step_sizes=jnp.full(feature_dim, log_alpha, dtype=jnp.float32),
            eligibility_traces=zeros,
            h_traces=zeros,
            normalizers=jnp.ones(feature_dim, dtype=jnp.float32),
            meta_step_size=jnp.array(self._meta_step_size, dtype=jnp.float32),
            trace_decay=jnp.array(self._trace_decay, dtype=jnp.float32),
            normalizer_decay=jnp.array(self._normalizer_decay, dtype=jnp.float32),
            bias_log_step_size=jnp.array(log_alpha, dtype=jnp.float32),
            bias_eligibility_trace=jnp.array(0.0, dtype=jnp.float32),
            bias_h_trace=jnp.array(0.0, dtype=jnp.float32),
            bias_normalizer=jnp.array(1.0, dtype=jnp.float32),
        )

    def update(
        self,
        state: AutoTDIDBDState,
        td_error: Array,
        observation: Array,
        next_observation: Array,
        gamma: Array,
    ) -> TDOptimizerUpdate:
        """Compute an AutoTDIDBD update with normalized adaptive step-sizes.

        Implements Algorithm 6 from Kearney et al. 2019.

        Args:
            state: Current AutoTDIDBD state
            td_error: TD error δ = R + γV(s') - V(s)
            observation: Current observation φ(s)
            next_observation: Next observation φ(s')
            gamma: Discount factor γ (0 at terminal)

        Returns:
            TDOptimizerUpdate with weight deltas and updated state
        """
        td = jnp.squeeze(td_error)
        discount = jnp.squeeze(gamma)
        theta = state.meta_step_size
        lam = state.trace_decay
        tau = state.normalizer_decay

        # Feature difference: γ*φ(s') - φ(s)
        phi_diff = discount * next_observation - observation

        # Step-sizes before this update (the normalizer decay needs them).
        old_alphas = jnp.exp(state.log_step_sizes)

        # Normalizer update (Algorithm 6, lines 5-7):
        # η_i = max(|δ*[γφ_i(s')-φ_i(s)]*h_i|,
        #           η_i - (1/τ)*α_i*[γφ_i(s')-φ_i(s)]*z_i*(|δ*φ_i(s)*h_i| - η_i))
        meta_grad = td * phi_diff * state.h_traces
        decay_term = (
            (1.0 / tau)
            * old_alphas
            * phi_diff
            * state.eligibility_traces
            * (jnp.abs(td * observation * state.h_traces) - state.normalizers)
        )
        normalizers = jnp.maximum(jnp.abs(meta_grad), state.normalizers - decay_term)
        # Floor keeps the division below well-defined.
        normalizers = jnp.maximum(normalizers, 1e-8)

        # Normalized meta-update (line 9):
        # β_i -= θ*(1/η_i)*δ*[γφ_i(s')-φ_i(s)]*h_i
        log_alphas = state.log_step_sizes - theta * (meta_grad / normalizers)

        # Effective step-size normalization (lines 10-11): if the combined
        # update would overshoot (M > 1), shrink every step-size by M.
        overshoot = -jnp.sum(
            jnp.exp(log_alphas) * phi_diff * state.eligibility_traces
        )
        log_alphas = log_alphas - jnp.log(jnp.maximum(overshoot, 1.0))

        # Keep exp(β) in a numerically safe range.
        log_alphas = jnp.clip(log_alphas, -10.0, 2.0)
        alphas = jnp.exp(log_alphas)

        # Eligibility traces: z_i = γ*λ*z_i + φ_i(s)
        traces = discount * lam * state.eligibility_traces + observation

        # Weight delta: α_i*δ*z_i
        weight_delta = alphas * td * traces

        # h traces (ordinary-gradient form, Algorithm 6 line 15):
        # h_i = h_i*[1 + α_i*[γφ_i(s')-φ_i(s)]*z_i]^+ + α_i*δ*z_i
        h_decay = jnp.maximum(0.0, 1.0 + alphas * phi_diff * traces)
        h_traces = state.h_traces * h_decay + alphas * td * traces

        # --- Bias term: implicit constant feature of 1, so φ-diff is γ-1 ---
        old_bias_alpha = jnp.exp(state.bias_log_step_size)
        bias_phi_diff = discount - 1.0

        # Bias normalizer update.
        bias_meta_grad = td * bias_phi_diff * state.bias_h_trace
        bias_decay_term = (
            (1.0 / tau)
            * old_bias_alpha
            * bias_phi_diff
            * state.bias_eligibility_trace
            * (jnp.abs(td * state.bias_h_trace) - state.bias_normalizer)
        )
        bias_normalizer = jnp.maximum(
            jnp.abs(bias_meta_grad), state.bias_normalizer - bias_decay_term
        )
        bias_normalizer = jnp.maximum(bias_normalizer, 1e-8)

        # Normalized bias meta-update.
        bias_log_alpha = state.bias_log_step_size - theta * (
            bias_meta_grad / bias_normalizer
        )

        # Effective step-size normalization for the bias.
        bias_overshoot = (
            -jnp.exp(bias_log_alpha) * bias_phi_diff * state.bias_eligibility_trace
        )
        bias_log_alpha = bias_log_alpha - jnp.log(jnp.maximum(bias_overshoot, 1.0))

        bias_log_alpha = jnp.clip(bias_log_alpha, -10.0, 2.0)
        bias_alpha = jnp.exp(bias_log_alpha)

        # Bias eligibility trace (feature is 1).
        bias_trace = discount * lam * state.bias_eligibility_trace + 1.0
        bias_delta = bias_alpha * td * bias_trace

        # Bias h trace.
        bias_h_decay = jnp.maximum(
            0.0, 1.0 + bias_alpha * bias_phi_diff * bias_trace
        )
        bias_h_trace = (
            state.bias_h_trace * bias_h_decay + bias_alpha * td * bias_trace
        )

        new_state = AutoTDIDBDState(
            log_step_sizes=log_alphas,
            eligibility_traces=traces,
            h_traces=h_traces,
            normalizers=normalizers,
            meta_step_size=theta,
            trace_decay=lam,
            normalizer_decay=tau,
            bias_log_step_size=bias_log_alpha,
            bias_eligibility_trace=bias_trace,
            bias_h_trace=bias_h_trace,
            bias_normalizer=bias_normalizer,
        )

        return TDOptimizerUpdate(
            weight_delta=weight_delta,
            bias_delta=bias_delta,
            new_state=new_state,
            metrics={
                "mean_step_size": jnp.mean(alphas),
                "min_step_size": jnp.min(alphas),
                "max_step_size": jnp.max(alphas),
                "mean_eligibility_trace": jnp.mean(jnp.abs(traces)),
                "mean_normalizer": jnp.mean(normalizers),
            },
        )
@@ -9,7 +9,7 @@ from typing import TYPE_CHECKING
9
9
  import chex
10
10
  import jax.numpy as jnp
11
11
  from jax import Array
12
- from jaxtyping import Float, Int, PRNGKeyArray
12
+ from jaxtyping import Float, Int
13
13
 
14
14
  if TYPE_CHECKING:
15
15
  from alberta_framework.core.learners import NormalizedLearnerState
@@ -255,6 +255,118 @@ def create_idbd_state(
255
255
  )
256
256
 
257
257
 
258
+ # =============================================================================
259
+ # TD Learning Types (for Step 3+ of Alberta Plan)
260
+ # =============================================================================
261
+
262
+
263
@chex.dataclass(frozen=True)
class TDTimeStep:
    """One transition drawn from a TD experience stream.

    Packages (s, r, s', gamma) for temporal-difference learning.

    Attributes:
        observation: Feature vector φ(s)
        reward: Reward R received on the transition
        next_observation: Feature vector φ(s')
        gamma: Discount factor γ_t (0 at terminal states)
    """

    observation: Float[Array, " feature_dim"]
    reward: Float[Array, ""]
    next_observation: Float[Array, " feature_dim"]
    gamma: Float[Array, ""]
280
+
281
+
282
@chex.dataclass(frozen=True)
class TDIDBDState:
    """State carried by the TD-IDBD (Temporal-Difference IDBD) optimizer.

    TD-IDBD extends IDBD to temporal-difference learning with eligibility
    traces: every weight has an adaptive step-size that is meta-learned from
    gradient correlation in the TD setting.

    Reference: Kearney et al. 2019, "Learning Feature Relevance Through Step Size
    Adaptation in Temporal-Difference Learning"

    Attributes:
        log_step_sizes: Log of per-weight step-sizes (log alpha_i)
        eligibility_traces: Eligibility traces z_i for temporal credit assignment
        h_traces: Per-weight h traces tracking gradient correlation
        meta_step_size: Meta learning rate theta for adapting step-sizes
        trace_decay: Eligibility trace decay parameter lambda
        bias_log_step_size: Log step-size for the bias term
        bias_eligibility_trace: Eligibility trace for the bias
        bias_h_trace: h trace for the bias term
    """

    log_step_sizes: Float[Array, " feature_dim"]
    eligibility_traces: Float[Array, " feature_dim"]
    h_traces: Float[Array, " feature_dim"]
    meta_step_size: Float[Array, ""]
    trace_decay: Float[Array, ""]
    bias_log_step_size: Float[Array, ""]
    bias_eligibility_trace: Float[Array, ""]
    bias_h_trace: Float[Array, ""]
312
+
313
+
314
@chex.dataclass(frozen=True)
class AutoTDIDBDState:
    """State carried by the AutoTDIDBD optimizer.

    AutoTDIDBD layers AutoStep-style normalization onto TDIDBD for improved
    stability: it keeps normalizers for the meta-weight updates and applies
    effective step-size normalization to prevent overshooting.

    Reference: Kearney et al. 2019, Algorithm 6

    Attributes:
        log_step_sizes: Log of per-weight step-sizes (log alpha_i)
        eligibility_traces: Eligibility traces z_i
        h_traces: Per-weight h traces tracking gradient correlation
        normalizers: Running max of absolute gradient correlations (eta_i)
        meta_step_size: Meta learning rate theta
        trace_decay: Eligibility trace decay parameter lambda
        normalizer_decay: Decay parameter tau for normalizers
        bias_log_step_size: Log step-size for the bias term
        bias_eligibility_trace: Eligibility trace for the bias
        bias_h_trace: h trace for the bias term
        bias_normalizer: Normalizer for the bias gradient correlation
    """

    log_step_sizes: Float[Array, " feature_dim"]
    eligibility_traces: Float[Array, " feature_dim"]
    h_traces: Float[Array, " feature_dim"]
    normalizers: Float[Array, " feature_dim"]
    meta_step_size: Float[Array, ""]
    trace_decay: Float[Array, ""]
    normalizer_decay: Float[Array, ""]
    bias_log_step_size: Float[Array, ""]
    bias_eligibility_trace: Float[Array, ""]
    bias_h_trace: Float[Array, ""]
    bias_normalizer: Float[Array, ""]
349
+
350
+
351
# Union of the state types accepted by TD optimizers.
TDOptimizerState = TDIDBDState | AutoTDIDBDState
353
+
354
+
355
@chex.dataclass(frozen=True)
class TDLearnerState:
    """State of a linear TD learner.

    Attributes:
        weights: Weight vector of the linear value-function approximation
        bias: Bias term
        optimizer_state: State maintained by the TD optimizer
    """

    weights: Float[Array, " feature_dim"]
    bias: Float[Array, ""]
    optimizer_state: TDOptimizerState
368
+
369
+
258
370
  def create_autostep_state(
259
371
  feature_dim: int,
260
372
  initial_step_size: float = 0.01,
@@ -282,3 +394,66 @@ def create_autostep_state(
282
394
  bias_trace=jnp.array(0.0, dtype=jnp.float32),
283
395
  bias_normalizer=jnp.array(1.0, dtype=jnp.float32),
284
396
  )
397
+
398
+
399
def create_tdidbd_state(
    feature_dim: int,
    initial_step_size: float = 0.01,
    meta_step_size: float = 0.01,
    trace_decay: float = 0.0,
) -> TDIDBDState:
    """Build an initial TD-IDBD optimizer state.

    Args:
        feature_dim: Dimension of the feature vector
        initial_step_size: Initial per-weight step-size
        meta_step_size: Meta learning rate theta for adapting step-sizes
        trace_decay: Eligibility trace decay parameter lambda (0 = TD(0))

    Returns:
        Initial TD-IDBD state
    """
    # Step-sizes are stored in log space; both the per-weight vector and the
    # bias start from the same value.
    log_alpha = jnp.log(initial_step_size)
    return TDIDBDState(
        log_step_sizes=jnp.full(feature_dim, log_alpha, dtype=jnp.float32),
        eligibility_traces=jnp.zeros(feature_dim, dtype=jnp.float32),
        h_traces=jnp.zeros(feature_dim, dtype=jnp.float32),
        meta_step_size=jnp.array(meta_step_size, dtype=jnp.float32),
        trace_decay=jnp.array(trace_decay, dtype=jnp.float32),
        bias_log_step_size=jnp.array(log_alpha, dtype=jnp.float32),
        bias_eligibility_trace=jnp.array(0.0, dtype=jnp.float32),
        bias_h_trace=jnp.array(0.0, dtype=jnp.float32),
    )
426
+
427
+
428
def create_autotdidbd_state(
    feature_dim: int,
    initial_step_size: float = 0.01,
    meta_step_size: float = 0.01,
    trace_decay: float = 0.0,
    normalizer_decay: float = 10000.0,
) -> AutoTDIDBDState:
    """Build an initial AutoTDIDBD optimizer state.

    Args:
        feature_dim: Dimension of the feature vector
        initial_step_size: Initial per-weight step-size
        meta_step_size: Meta learning rate theta for adapting step-sizes
        trace_decay: Eligibility trace decay parameter lambda (0 = TD(0))
        normalizer_decay: Decay parameter tau for normalizers (default: 10000)

    Returns:
        Initial AutoTDIDBD state
    """
    # Step-sizes live in log space; normalizers start at 1 so the first
    # normalized meta-update is well-defined.
    log_alpha = jnp.log(initial_step_size)
    return AutoTDIDBDState(
        log_step_sizes=jnp.full(feature_dim, log_alpha, dtype=jnp.float32),
        eligibility_traces=jnp.zeros(feature_dim, dtype=jnp.float32),
        h_traces=jnp.zeros(feature_dim, dtype=jnp.float32),
        normalizers=jnp.ones(feature_dim, dtype=jnp.float32),
        meta_step_size=jnp.array(meta_step_size, dtype=jnp.float32),
        trace_decay=jnp.array(trace_decay, dtype=jnp.float32),
        normalizer_decay=jnp.array(normalizer_decay, dtype=jnp.float32),
        bias_log_step_size=jnp.array(log_alpha, dtype=jnp.float32),
        bias_eligibility_trace=jnp.array(0.0, dtype=jnp.float32),
        bias_h_trace=jnp.array(0.0, dtype=jnp.float32),
        bias_normalizer=jnp.array(1.0, dtype=jnp.float32),
    )