alberta-framework 0.3.2__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- alberta_framework/__init__.py +39 -5
- alberta_framework/core/__init__.py +26 -2
- alberta_framework/core/learners.py +277 -59
- alberta_framework/core/normalizers.py +1 -4
- alberta_framework/core/optimizers.py +498 -1
- alberta_framework/core/types.py +175 -0
- alberta_framework/streams/gymnasium.py +3 -10
- alberta_framework/streams/synthetic.py +3 -9
- alberta_framework/utils/experiments.py +1 -3
- alberta_framework/utils/export.py +20 -16
- alberta_framework/utils/statistics.py +17 -9
- alberta_framework/utils/visualization.py +31 -25
- {alberta_framework-0.3.2.dist-info → alberta_framework-0.4.0.dist-info}/METADATA +24 -1
- alberta_framework-0.4.0.dist-info/RECORD +22 -0
- alberta_framework-0.3.2.dist-info/RECORD +0 -22
- {alberta_framework-0.3.2.dist-info → alberta_framework-0.4.0.dist-info}/WHEEL +0 -0
- {alberta_framework-0.3.2.dist-info → alberta_framework-0.4.0.dist-info}/licenses/LICENSE +0 -0
alberta_framework/core/optimizers.py CHANGED

@@ -16,7 +16,13 @@ import jax.numpy as jnp
 from jax import Array
 from jaxtyping import Float
 
-from alberta_framework.core.types import
+from alberta_framework.core.types import (
+    AutostepState,
+    AutoTDIDBDState,
+    IDBDState,
+    LMSState,
+    TDIDBDState,
+)
 
 
 @chex.dataclass(frozen=True)
@@ -424,3 +430,494 @@ class Autostep(Optimizer[AutostepState]):
                 "mean_normalizer": jnp.mean(state.normalizers),
             },
         )
+
+
+# =============================================================================
+# TD Optimizers (for Step 3+ of Alberta Plan)
+# =============================================================================
+
+
+@chex.dataclass(frozen=True)
+class TDOptimizerUpdate:
+    """Result of a TD optimizer update step.
+
+    Attributes:
+        weight_delta: Change to apply to weights
+        bias_delta: Change to apply to bias
+        new_state: Updated optimizer state
+        metrics: Dictionary of metrics for logging
+    """
+
+    weight_delta: Float[Array, " feature_dim"]
+    bias_delta: Float[Array, ""]
+    new_state: TDIDBDState | AutoTDIDBDState
+    metrics: dict[str, Array]
+
+
+class TDOptimizer[StateT: (TDIDBDState, AutoTDIDBDState)](ABC):
+    """Base class for TD optimizers.
+
+    TD optimizers handle temporal-difference learning with eligibility traces.
+    They take TD error and both current and next observations as input.
+    """
+
+    @abstractmethod
+    def init(self, feature_dim: int) -> StateT:
+        """Initialize optimizer state.
+
+        Args:
+            feature_dim: Dimension of weight vector
+
+        Returns:
+            Initial optimizer state
+        """
+        ...
+
+    @abstractmethod
+    def update(
+        self,
+        state: StateT,
+        td_error: Array,
+        observation: Array,
+        next_observation: Array,
+        gamma: Array,
+    ) -> TDOptimizerUpdate:
+        """Compute weight updates given TD error.
+
+        Args:
+            state: Current optimizer state
+            td_error: TD error δ = R + γV(s') - V(s)
+            observation: Current observation φ(s)
+            next_observation: Next observation φ(s')
+            gamma: Discount factor γ (0 at terminal)
+
+        Returns:
+            TDOptimizerUpdate with deltas and new state
+        """
+        ...
+
+
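The `update` contract above is purely functional: the optimizer returns deltas and a new state, and the caller owns the weight vector. A minimal sketch (not part of the diff) of how a caller might drive any `TDOptimizer` for a linear value function; the helper name `td_step` is illustrative:

```python
import jax.numpy as jnp


def td_step(optimizer, opt_state, weights, bias, obs, next_obs, reward, gamma):
    """One TD step: compute the TD error for a linear value function,
    then apply the deltas returned by the optimizer."""
    v_s = jnp.dot(weights, obs) + bias
    v_next = jnp.dot(weights, next_obs) + bias
    td_error = reward + gamma * v_next - v_s  # δ = R + γV(s') - V(s)
    upd = optimizer.update(opt_state, td_error, obs, next_obs, gamma)
    return weights + upd.weight_delta, bias + upd.bias_delta, upd.new_state
```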
+class TDIDBD(TDOptimizer[TDIDBDState]):
+    """TD-IDBD optimizer for temporal-difference learning.
+
+    Extends IDBD to TD learning with eligibility traces. Maintains per-weight
+    adaptive step-sizes that are meta-learned based on gradient correlation
+    in the TD setting.
+
+    Two variants are supported:
+    - Semi-gradient (default): Uses only φ(s) in meta-update, more stable
+    - Ordinary gradient: Uses both φ(s) and φ(s'), more accurate but sensitive
+
+    Reference: Kearney et al. 2019, "Learning Feature Relevance Through Step Size
+    Adaptation in Temporal-Difference Learning"
+
+    The semi-gradient TD-IDBD algorithm (Algorithm 3 in paper):
+    1. Compute TD error: `δ = R + γ*w^T*φ(s') - w^T*φ(s)`
+    2. Update meta-weights: `β_i += θ*δ*φ_i(s)*h_i`
+    3. Compute step-sizes: `α_i = exp(β_i)`
+    4. Update eligibility traces: `z_i = γ*λ*z_i + φ_i(s)`
+    5. Update weights: `w_i += α_i*δ*z_i`
+    6. Update h traces: `h_i = h_i*[1 - α_i*φ_i(s)*z_i]^+ + α_i*δ*z_i`
+
+    Attributes:
+        initial_step_size: Initial per-weight step-size
+        meta_step_size: Meta learning rate theta
+        trace_decay: Eligibility trace decay lambda
+        use_semi_gradient: If True, use semi-gradient variant (default)
+    """
+
+    def __init__(
+        self,
+        initial_step_size: float = 0.01,
+        meta_step_size: float = 0.01,
+        trace_decay: float = 0.0,
+        use_semi_gradient: bool = True,
+    ):
+        """Initialize TD-IDBD optimizer.
+
+        Args:
+            initial_step_size: Initial value for per-weight step-sizes
+            meta_step_size: Meta learning rate theta for adapting step-sizes
+            trace_decay: Eligibility trace decay lambda (0 = TD(0))
+            use_semi_gradient: If True, use semi-gradient variant (recommended)
+        """
+        self._initial_step_size = initial_step_size
+        self._meta_step_size = meta_step_size
+        self._trace_decay = trace_decay
+        self._use_semi_gradient = use_semi_gradient
+
+    def init(self, feature_dim: int) -> TDIDBDState:
+        """Initialize TD-IDBD state.
+
+        Args:
+            feature_dim: Dimension of weight vector
+
+        Returns:
+            TD-IDBD state with per-weight step-sizes, traces, and h traces
+        """
+        return TDIDBDState(
+            log_step_sizes=jnp.full(
+                feature_dim, jnp.log(self._initial_step_size), dtype=jnp.float32
+            ),
+            eligibility_traces=jnp.zeros(feature_dim, dtype=jnp.float32),
+            h_traces=jnp.zeros(feature_dim, dtype=jnp.float32),
+            meta_step_size=jnp.array(self._meta_step_size, dtype=jnp.float32),
+            trace_decay=jnp.array(self._trace_decay, dtype=jnp.float32),
+            bias_log_step_size=jnp.array(jnp.log(self._initial_step_size), dtype=jnp.float32),
+            bias_eligibility_trace=jnp.array(0.0, dtype=jnp.float32),
+            bias_h_trace=jnp.array(0.0, dtype=jnp.float32),
+        )
+
+    def update(
+        self,
+        state: TDIDBDState,
+        td_error: Array,
+        observation: Array,
+        next_observation: Array,
+        gamma: Array,
+    ) -> TDOptimizerUpdate:
+        """Compute TD-IDBD weight update with adaptive step-sizes.
+
+        Implements Algorithm 3 (semi-gradient) or Algorithm 4 (ordinary gradient)
+        from Kearney et al. 2019.
+
+        Args:
+            state: Current TD-IDBD state
+            td_error: TD error δ = R + γV(s') - V(s)
+            observation: Current observation φ(s)
+            next_observation: Next observation φ(s')
+            gamma: Discount factor γ (0 at terminal)
+
+        Returns:
+            TDOptimizerUpdate with weight deltas and updated state
+        """
+        delta = jnp.squeeze(td_error)
+        theta = state.meta_step_size
+        lam = state.trace_decay
+        gamma_scalar = jnp.squeeze(gamma)
+
+        if self._use_semi_gradient:
+            # Semi-gradient TD-IDBD (Algorithm 3)
+            # β_i += θ*δ*φ_i(s)*h_i
+            gradient_correlation = delta * observation * state.h_traces
+            new_log_step_sizes = state.log_step_sizes + theta * gradient_correlation
+        else:
+            # Ordinary gradient TD-IDBD (Algorithm 4)
+            # β_i -= θ*δ*[γ*φ_i(s') - φ_i(s)]*h_i
+            # Note: negative sign because gradient direction is reversed
+            feature_diff = gamma_scalar * next_observation - observation
+            gradient_correlation = delta * feature_diff * state.h_traces
+            new_log_step_sizes = state.log_step_sizes - theta * gradient_correlation
+
+        # Clip log step-sizes to prevent numerical issues
+        new_log_step_sizes = jnp.clip(new_log_step_sizes, -10.0, 2.0)
+
+        # Get updated step-sizes for weight update
+        new_alphas = jnp.exp(new_log_step_sizes)
+
+        # Update eligibility traces: z_i = γ*λ*z_i + φ_i(s)
+        new_eligibility_traces = gamma_scalar * lam * state.eligibility_traces + observation
+
+        # Compute weight delta: α_i*δ*z_i
+        weight_delta = new_alphas * delta * new_eligibility_traces
+
+        if self._use_semi_gradient:
+            # Semi-gradient h update (Algorithm 3, line 9)
+            # h_i = h_i*[1 - α_i*φ_i(s)*z_i]^+ + α_i*δ*z_i
+            h_decay = jnp.maximum(0.0, 1.0 - new_alphas * observation * new_eligibility_traces)
+            new_h_traces = state.h_traces * h_decay + new_alphas * delta * new_eligibility_traces
+        else:
+            # Ordinary gradient h update (Algorithm 4, line 9)
+            # h_i = h_i*[1 + α_i*z_i*(γ*φ_i(s') - φ_i(s))]^+ + α_i*δ*z_i
+            feature_diff = gamma_scalar * next_observation - observation
+            h_decay = jnp.maximum(0.0, 1.0 + new_alphas * new_eligibility_traces * feature_diff)
+            new_h_traces = state.h_traces * h_decay + new_alphas * delta * new_eligibility_traces
+
+        # Bias updates (similar logic but scalar)
+        if self._use_semi_gradient:
+            # Semi-gradient bias meta-update
+            bias_gradient_correlation = delta * state.bias_h_trace
+            new_bias_log_step_size = state.bias_log_step_size + theta * bias_gradient_correlation
+        else:
+            # Ordinary gradient bias meta-update
+            # For bias, φ(s) = 1, so feature_diff = γ - 1
+            bias_feature_diff = gamma_scalar - 1.0
+            bias_gradient_correlation = delta * bias_feature_diff * state.bias_h_trace
+            new_bias_log_step_size = state.bias_log_step_size - theta * bias_gradient_correlation
+
+        new_bias_log_step_size = jnp.clip(new_bias_log_step_size, -10.0, 2.0)
+        new_bias_alpha = jnp.exp(new_bias_log_step_size)
+
+        # Update bias eligibility trace
+        new_bias_eligibility_trace = gamma_scalar * lam * state.bias_eligibility_trace + 1.0
+
+        # Bias weight delta
+        bias_delta = new_bias_alpha * delta * new_bias_eligibility_trace
+
+        if self._use_semi_gradient:
+            # Semi-gradient bias h update
+            bias_h_decay = jnp.maximum(0.0, 1.0 - new_bias_alpha * new_bias_eligibility_trace)
+            new_bias_h_trace = (
+                state.bias_h_trace * bias_h_decay
+                + new_bias_alpha * delta * new_bias_eligibility_trace
+            )
+        else:
+            # Ordinary gradient bias h update
+            bias_feature_diff = gamma_scalar - 1.0
+            bias_h_decay = jnp.maximum(
+                0.0, 1.0 + new_bias_alpha * new_bias_eligibility_trace * bias_feature_diff
+            )
+            new_bias_h_trace = (
+                state.bias_h_trace * bias_h_decay
+                + new_bias_alpha * delta * new_bias_eligibility_trace
+            )
+
+        new_state = TDIDBDState(
+            log_step_sizes=new_log_step_sizes,
+            eligibility_traces=new_eligibility_traces,
+            h_traces=new_h_traces,
+            meta_step_size=theta,
+            trace_decay=lam,
+            bias_log_step_size=new_bias_log_step_size,
+            bias_eligibility_trace=new_bias_eligibility_trace,
+            bias_h_trace=new_bias_h_trace,
+        )
+
+        return TDOptimizerUpdate(
+            weight_delta=weight_delta,
+            bias_delta=bias_delta,
+            new_state=new_state,
+            metrics={
+                "mean_step_size": jnp.mean(new_alphas),
+                "min_step_size": jnp.min(new_alphas),
+                "max_step_size": jnp.max(new_alphas),
+                "mean_eligibility_trace": jnp.mean(jnp.abs(new_eligibility_traces)),
+            },
+        )
+
+
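A hedged usage sketch for `TDIDBD`, assuming it can be imported from `alberta_framework.core.optimizers` (the module this hunk patches); all values are illustrative:

```python
import jax.numpy as jnp
from alberta_framework.core.optimizers import TDIDBD

opt = TDIDBD(initial_step_size=0.01, meta_step_size=0.01, trace_decay=0.9)
state = opt.init(feature_dim=4)

weights, bias = jnp.zeros(4), jnp.array(0.0)
obs = jnp.array([1.0, 0.0, 0.5, 0.0])
next_obs = jnp.array([0.0, 1.0, 0.5, 0.0])
reward, gamma = jnp.array(1.0), jnp.array(0.99)

# The caller computes δ; with zero weights it reduces to the reward.
td_error = reward + gamma * (weights @ next_obs + bias) - (weights @ obs + bias)

upd = opt.update(state, td_error, obs, next_obs, gamma)
weights = weights + upd.weight_delta  # w_i += α_i*δ*z_i
bias = bias + upd.bias_delta
state = upd.new_state

# Step-sizes are unchanged on the first step because h traces start at zero.
print(float(upd.metrics["mean_step_size"]))  # 0.01
```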
+class AutoTDIDBD(TDOptimizer[AutoTDIDBDState]):
+    """AutoStep-style normalized TD-IDBD optimizer.
+
+    Adds AutoStep-style normalization to TDIDBD for improved stability and
+    reduced sensitivity to the meta step-size theta. Includes:
+    1. Normalization of the meta-weight update by a running trace of recent updates
+    2. Effective step-size normalization to prevent overshooting
+
+    Reference: Kearney et al. 2019, Algorithm 6 "AutoStep Style Normalized TIDBD(λ)"
+
+    The AutoTDIDBD algorithm:
+    1. Compute TD error: `δ = R + γ*w^T*φ(s') - w^T*φ(s)`
+    2. Update normalizers: `η_i = max(|δ*[γφ_i(s')-φ_i(s)]*h_i|,
+       η_i - (1/τ)*α_i*[γφ_i(s')-φ_i(s)]*z_i*(|δ*φ_i(s)*h_i| - η_i))`
+    3. Normalized meta-update: `β_i -= θ*(1/η_i)*δ*[γφ_i(s')-φ_i(s)]*h_i`
+    4. Effective step-size normalization: `M = max(-exp(β)*[γφ(s')-φ(s)]^T*z, 1)`
+       then `β_i -= log(M)`
+    5. Update weights and traces as in TIDBD
+
+    Attributes:
+        initial_step_size: Initial per-weight step-size
+        meta_step_size: Meta learning rate theta
+        trace_decay: Eligibility trace decay lambda
+        normalizer_decay: Decay parameter tau for normalizers
+    """
+
+    def __init__(
+        self,
+        initial_step_size: float = 0.01,
+        meta_step_size: float = 0.01,
+        trace_decay: float = 0.0,
+        normalizer_decay: float = 10000.0,
+    ):
+        """Initialize AutoTDIDBD optimizer.
+
+        Args:
+            initial_step_size: Initial value for per-weight step-sizes
+            meta_step_size: Meta learning rate theta for adapting step-sizes
+            trace_decay: Eligibility trace decay lambda (0 = TD(0))
+            normalizer_decay: Decay parameter tau for normalizers (default: 10000)
+        """
+        self._initial_step_size = initial_step_size
+        self._meta_step_size = meta_step_size
+        self._trace_decay = trace_decay
+        self._normalizer_decay = normalizer_decay
+
+    def init(self, feature_dim: int) -> AutoTDIDBDState:
+        """Initialize AutoTDIDBD state.
+
+        Args:
+            feature_dim: Dimension of weight vector
+
+        Returns:
+            AutoTDIDBD state with per-weight step-sizes, traces, h traces, and normalizers
+        """
+        return AutoTDIDBDState(
+            log_step_sizes=jnp.full(
+                feature_dim, jnp.log(self._initial_step_size), dtype=jnp.float32
+            ),
+            eligibility_traces=jnp.zeros(feature_dim, dtype=jnp.float32),
+            h_traces=jnp.zeros(feature_dim, dtype=jnp.float32),
+            normalizers=jnp.ones(feature_dim, dtype=jnp.float32),
+            meta_step_size=jnp.array(self._meta_step_size, dtype=jnp.float32),
+            trace_decay=jnp.array(self._trace_decay, dtype=jnp.float32),
+            normalizer_decay=jnp.array(self._normalizer_decay, dtype=jnp.float32),
+            bias_log_step_size=jnp.array(jnp.log(self._initial_step_size), dtype=jnp.float32),
+            bias_eligibility_trace=jnp.array(0.0, dtype=jnp.float32),
+            bias_h_trace=jnp.array(0.0, dtype=jnp.float32),
+            bias_normalizer=jnp.array(1.0, dtype=jnp.float32),
+        )
+
+    def update(
+        self,
+        state: AutoTDIDBDState,
+        td_error: Array,
+        observation: Array,
+        next_observation: Array,
+        gamma: Array,
+    ) -> TDOptimizerUpdate:
+        """Compute AutoTDIDBD weight update with normalized adaptive step-sizes.
+
+        Implements Algorithm 6 from Kearney et al. 2019.
+
+        Args:
+            state: Current AutoTDIDBD state
+            td_error: TD error δ = R + γV(s') - V(s)
+            observation: Current observation φ(s)
+            next_observation: Next observation φ(s')
+            gamma: Discount factor γ (0 at terminal)
+
+        Returns:
+            TDOptimizerUpdate with weight deltas and updated state
+        """
+        delta = jnp.squeeze(td_error)
+        theta = state.meta_step_size
+        lam = state.trace_decay
+        tau = state.normalizer_decay
+        gamma_scalar = jnp.squeeze(gamma)
+
+        # Feature difference: γ*φ(s') - φ(s)
+        feature_diff = gamma_scalar * next_observation - observation
+
+        # Current step-sizes
+        alphas = jnp.exp(state.log_step_sizes)
+
+        # Update normalizers (Algorithm 6, lines 5-7)
+        # η_i = max(|δ*[γφ_i(s')-φ_i(s)]*h_i|,
+        #           η_i - (1/τ)*α_i*[γφ_i(s')-φ_i(s)]*z_i*(|δ*φ_i(s)*h_i| - η_i))
+        abs_weight_update = jnp.abs(delta * feature_diff * state.h_traces)
+        normalizer_decay_term = (
+            (1.0 / tau)
+            * alphas
+            * feature_diff
+            * state.eligibility_traces
+            * (jnp.abs(delta * observation * state.h_traces) - state.normalizers)
+        )
+        new_normalizers = jnp.maximum(abs_weight_update, state.normalizers - normalizer_decay_term)
+        # Ensure normalizers don't go to zero
+        new_normalizers = jnp.maximum(new_normalizers, 1e-8)
+
+        # Normalized meta-update (Algorithm 6, line 9)
+        # β_i -= θ*(1/η_i)*δ*[γφ_i(s')-φ_i(s)]*h_i
+        normalized_gradient = delta * feature_diff * state.h_traces / new_normalizers
+        new_log_step_sizes = state.log_step_sizes - theta * normalized_gradient
+
+        # Effective step-size normalization (Algorithm 6, lines 10-11)
+        # M = max(-exp(β_i)*[γφ_i(s')-φ_i(s)]^T*z_i, 1)
+        # β_i -= log(M)
+        effective_step_size = -jnp.sum(
+            jnp.exp(new_log_step_sizes) * feature_diff * state.eligibility_traces
+        )
+        normalization_factor = jnp.maximum(effective_step_size, 1.0)
+        new_log_step_sizes = new_log_step_sizes - jnp.log(normalization_factor)
+
+        # Clip log step-sizes to prevent numerical issues
+        new_log_step_sizes = jnp.clip(new_log_step_sizes, -10.0, 2.0)
+
+        # Get updated step-sizes
+        new_alphas = jnp.exp(new_log_step_sizes)
+
+        # Update eligibility traces: z_i = γ*λ*z_i + φ_i(s)
+        new_eligibility_traces = gamma_scalar * lam * state.eligibility_traces + observation
+
+        # Compute weight delta: α_i*δ*z_i
+        weight_delta = new_alphas * delta * new_eligibility_traces
+
+        # Update h traces (ordinary gradient variant, Algorithm 6 line 15)
+        # h_i = h_i*[1 + α_i*[γφ_i(s')-φ_i(s)]*z_i]^+ + α_i*δ*z_i
+        h_decay = jnp.maximum(0.0, 1.0 + new_alphas * feature_diff * new_eligibility_traces)
+        new_h_traces = state.h_traces * h_decay + new_alphas * delta * new_eligibility_traces
+
+        # Bias updates
+        bias_alpha = jnp.exp(state.bias_log_step_size)
+        bias_feature_diff = gamma_scalar - 1.0  # For bias, φ(s) = 1
+
+        # Bias normalizer update
+        abs_bias_weight_update = jnp.abs(delta * bias_feature_diff * state.bias_h_trace)
+        bias_normalizer_decay_term = (
+            (1.0 / tau)
+            * bias_alpha
+            * bias_feature_diff
+            * state.bias_eligibility_trace
+            * (jnp.abs(delta * state.bias_h_trace) - state.bias_normalizer)
+        )
+        new_bias_normalizer = jnp.maximum(
+            abs_bias_weight_update, state.bias_normalizer - bias_normalizer_decay_term
+        )
+        new_bias_normalizer = jnp.maximum(new_bias_normalizer, 1e-8)
+
+        # Normalized bias meta-update
+        normalized_bias_gradient = (
+            delta * bias_feature_diff * state.bias_h_trace / new_bias_normalizer
+        )
+        new_bias_log_step_size = state.bias_log_step_size - theta * normalized_bias_gradient
+
+        # Effective step-size normalization for bias
+        bias_effective_step_size = (
+            -jnp.exp(new_bias_log_step_size) * bias_feature_diff * state.bias_eligibility_trace
+        )
+        bias_norm_factor = jnp.maximum(bias_effective_step_size, 1.0)
+        new_bias_log_step_size = new_bias_log_step_size - jnp.log(bias_norm_factor)
+
+        new_bias_log_step_size = jnp.clip(new_bias_log_step_size, -10.0, 2.0)
+        new_bias_alpha = jnp.exp(new_bias_log_step_size)
+
+        # Update bias eligibility trace
+        new_bias_eligibility_trace = gamma_scalar * lam * state.bias_eligibility_trace + 1.0
+
+        # Bias weight delta
+        bias_delta = new_bias_alpha * delta * new_bias_eligibility_trace
+
+        # Bias h trace update
+        bias_h_decay = jnp.maximum(
+            0.0, 1.0 + new_bias_alpha * bias_feature_diff * new_bias_eligibility_trace
+        )
+        new_bias_h_trace = (
+            state.bias_h_trace * bias_h_decay + new_bias_alpha * delta * new_bias_eligibility_trace
+        )
+
+        new_state = AutoTDIDBDState(
+            log_step_sizes=new_log_step_sizes,
+            eligibility_traces=new_eligibility_traces,
+            h_traces=new_h_traces,
+            normalizers=new_normalizers,
+            meta_step_size=theta,
+            trace_decay=lam,
+            normalizer_decay=tau,
+            bias_log_step_size=new_bias_log_step_size,
+            bias_eligibility_trace=new_bias_eligibility_trace,
+            bias_h_trace=new_bias_h_trace,
+            bias_normalizer=new_bias_normalizer,
+        )
+
+        return TDOptimizerUpdate(
+            weight_delta=weight_delta,
+            bias_delta=bias_delta,
+            new_state=new_state,
+            metrics={
+                "mean_step_size": jnp.mean(new_alphas),
+                "min_step_size": jnp.min(new_alphas),
+                "max_step_size": jnp.max(new_alphas),
+                "mean_eligibility_trace": jnp.mean(jnp.abs(new_eligibility_traces)),
+                "mean_normalizer": jnp.mean(new_normalizers),
+            },
+        )
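`AutoTDIDBD` has the same call contract as `TDIDBD`; the extra knob is `normalizer_decay` (τ), and its metrics additionally report `mean_normalizer`. A sketch under the same import assumption:

```python
import jax.numpy as jnp
from alberta_framework.core.optimizers import AutoTDIDBD

opt = AutoTDIDBD(meta_step_size=0.1, trace_decay=0.9, normalizer_decay=10_000.0)
state = opt.init(feature_dim=4)

upd = opt.update(
    state,
    td_error=jnp.array(0.5),
    observation=jnp.array([1.0, 0.0, 0.0, 0.0]),
    next_observation=jnp.array([0.0, 1.0, 0.0, 0.0]),
    gamma=jnp.array(0.99),
)
# Normalizers start at 1.0 and only move once the h and eligibility traces
# are non-zero, so this first update reports mean_normalizer == 1.0.
print(float(upd.metrics["mean_normalizer"]))
```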
alberta_framework/core/types.py CHANGED

@@ -255,6 +255,118 @@ def create_idbd_state(
     )
 
 
+# =============================================================================
+# TD Learning Types (for Step 3+ of Alberta Plan)
+# =============================================================================
+
+
+@chex.dataclass(frozen=True)
+class TDTimeStep:
+    """Single experience from a TD stream.
+
+    Represents a transition (s, r, s', gamma) for temporal-difference learning.
+
+    Attributes:
+        observation: Feature vector φ(s)
+        reward: Reward R received
+        next_observation: Feature vector φ(s')
+        gamma: Discount factor γ_t (0 at terminal states)
+    """
+
+    observation: Float[Array, " feature_dim"]
+    reward: Float[Array, ""]
+    next_observation: Float[Array, " feature_dim"]
+    gamma: Float[Array, ""]
+
+
+@chex.dataclass(frozen=True)
+class TDIDBDState:
+    """State for the TD-IDBD (Temporal-Difference IDBD) optimizer.
+
+    TD-IDBD extends IDBD to temporal-difference learning with eligibility traces.
+    Maintains per-weight adaptive step-sizes that are meta-learned based on
+    gradient correlation in the TD setting.
+
+    Reference: Kearney et al. 2019, "Learning Feature Relevance Through Step Size
+    Adaptation in Temporal-Difference Learning"
+
+    Attributes:
+        log_step_sizes: Log of per-weight step-sizes (log alpha_i)
+        eligibility_traces: Eligibility traces z_i for temporal credit assignment
+        h_traces: Per-weight h traces for gradient correlation
+        meta_step_size: Meta learning rate theta for adapting step-sizes
+        trace_decay: Eligibility trace decay parameter lambda
+        bias_log_step_size: Log step-size for the bias term
+        bias_eligibility_trace: Eligibility trace for the bias
+        bias_h_trace: h trace for the bias term
+    """
+
+    log_step_sizes: Float[Array, " feature_dim"]
+    eligibility_traces: Float[Array, " feature_dim"]
+    h_traces: Float[Array, " feature_dim"]
+    meta_step_size: Float[Array, ""]
+    trace_decay: Float[Array, ""]
+    bias_log_step_size: Float[Array, ""]
+    bias_eligibility_trace: Float[Array, ""]
+    bias_h_trace: Float[Array, ""]
+
+
+@chex.dataclass(frozen=True)
+class AutoTDIDBDState:
+    """State for the AutoTDIDBD optimizer.
+
+    AutoTDIDBD adds AutoStep-style normalization to TDIDBD for improved stability.
+    Includes normalizers for the meta-weight updates and effective step-size
+    normalization to prevent overshooting.
+
+    Reference: Kearney et al. 2019, Algorithm 6
+
+    Attributes:
+        log_step_sizes: Log of per-weight step-sizes (log alpha_i)
+        eligibility_traces: Eligibility traces z_i
+        h_traces: Per-weight h traces for gradient correlation
+        normalizers: Running max of absolute gradient correlations (eta_i)
+        meta_step_size: Meta learning rate theta
+        trace_decay: Eligibility trace decay parameter lambda
+        normalizer_decay: Decay parameter tau for normalizers
+        bias_log_step_size: Log step-size for the bias term
+        bias_eligibility_trace: Eligibility trace for the bias
+        bias_h_trace: h trace for the bias term
+        bias_normalizer: Normalizer for the bias gradient correlation
+    """
+
+    log_step_sizes: Float[Array, " feature_dim"]
+    eligibility_traces: Float[Array, " feature_dim"]
+    h_traces: Float[Array, " feature_dim"]
+    normalizers: Float[Array, " feature_dim"]
+    meta_step_size: Float[Array, ""]
+    trace_decay: Float[Array, ""]
+    normalizer_decay: Float[Array, ""]
+    bias_log_step_size: Float[Array, ""]
+    bias_eligibility_trace: Float[Array, ""]
+    bias_h_trace: Float[Array, ""]
+    bias_normalizer: Float[Array, ""]
+
+
+# Union type for TD optimizer states
+TDOptimizerState = TDIDBDState | AutoTDIDBDState
+
+
+@chex.dataclass(frozen=True)
+class TDLearnerState:
+    """State for a TD linear learner.
+
+    Attributes:
+        weights: Weight vector for linear value function approximation
+        bias: Bias term
+        optimizer_state: State maintained by the TD optimizer
+    """
+
+    weights: Float[Array, " feature_dim"]
+    bias: Float[Array, ""]
+    optimizer_state: TDOptimizerState
+
+
 def create_autostep_state(
     feature_dim: int,
     initial_step_size: float = 0.01,
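A short sketch of the new types in use, assuming `TDTimeStep` and `TDLearnerState` are importable from `alberta_framework.core.types`; `create_tdidbd_state` is defined in the next hunk:

```python
import jax.numpy as jnp
from alberta_framework.core.types import TDLearnerState, TDTimeStep, create_tdidbd_state

# One transition (s, r, s', gamma); gamma=0.0 would mark a terminal transition.
step = TDTimeStep(
    observation=jnp.array([1.0, 0.0]),
    reward=jnp.array(0.5),
    next_observation=jnp.array([0.0, 1.0]),
    gamma=jnp.array(0.99),
)

# A TD learner bundles weights, bias, and any TDOptimizerState.
learner = TDLearnerState(
    weights=jnp.zeros(2),
    bias=jnp.array(0.0),
    optimizer_state=create_tdidbd_state(feature_dim=2),
)
```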
@@ -282,3 +394,66 @@ def create_autostep_state(
         bias_trace=jnp.array(0.0, dtype=jnp.float32),
         bias_normalizer=jnp.array(1.0, dtype=jnp.float32),
     )
+
+
+def create_tdidbd_state(
+    feature_dim: int,
+    initial_step_size: float = 0.01,
+    meta_step_size: float = 0.01,
+    trace_decay: float = 0.0,
+) -> TDIDBDState:
+    """Create initial TD-IDBD optimizer state.
+
+    Args:
+        feature_dim: Dimension of the feature vector
+        initial_step_size: Initial per-weight step-size
+        meta_step_size: Meta learning rate theta for adapting step-sizes
+        trace_decay: Eligibility trace decay parameter lambda (0 = TD(0))
+
+    Returns:
+        Initial TD-IDBD state
+    """
+    return TDIDBDState(
+        log_step_sizes=jnp.full(feature_dim, jnp.log(initial_step_size), dtype=jnp.float32),
+        eligibility_traces=jnp.zeros(feature_dim, dtype=jnp.float32),
+        h_traces=jnp.zeros(feature_dim, dtype=jnp.float32),
+        meta_step_size=jnp.array(meta_step_size, dtype=jnp.float32),
+        trace_decay=jnp.array(trace_decay, dtype=jnp.float32),
+        bias_log_step_size=jnp.array(jnp.log(initial_step_size), dtype=jnp.float32),
+        bias_eligibility_trace=jnp.array(0.0, dtype=jnp.float32),
+        bias_h_trace=jnp.array(0.0, dtype=jnp.float32),
+    )
+
+
+def create_autotdidbd_state(
+    feature_dim: int,
+    initial_step_size: float = 0.01,
+    meta_step_size: float = 0.01,
+    trace_decay: float = 0.0,
+    normalizer_decay: float = 10000.0,
+) -> AutoTDIDBDState:
+    """Create initial AutoTDIDBD optimizer state.
+
+    Args:
+        feature_dim: Dimension of the feature vector
+        initial_step_size: Initial per-weight step-size
+        meta_step_size: Meta learning rate theta for adapting step-sizes
+        trace_decay: Eligibility trace decay parameter lambda (0 = TD(0))
+        normalizer_decay: Decay parameter tau for normalizers (default: 10000)
+
+    Returns:
+        Initial AutoTDIDBD state
+    """
+    return AutoTDIDBDState(
+        log_step_sizes=jnp.full(feature_dim, jnp.log(initial_step_size), dtype=jnp.float32),
+        eligibility_traces=jnp.zeros(feature_dim, dtype=jnp.float32),
+        h_traces=jnp.zeros(feature_dim, dtype=jnp.float32),
+        normalizers=jnp.ones(feature_dim, dtype=jnp.float32),
+        meta_step_size=jnp.array(meta_step_size, dtype=jnp.float32),
+        trace_decay=jnp.array(trace_decay, dtype=jnp.float32),
+        normalizer_decay=jnp.array(normalizer_decay, dtype=jnp.float32),
+        bias_log_step_size=jnp.array(jnp.log(initial_step_size), dtype=jnp.float32),
+        bias_eligibility_trace=jnp.array(0.0, dtype=jnp.float32),
+        bias_h_trace=jnp.array(0.0, dtype=jnp.float32),
+        bias_normalizer=jnp.array(1.0, dtype=jnp.float32),
+    )
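Putting the pieces together: a minimal, illustrative TD(λ) loop on a toy two-state chain, driving `TDIDBD` by hand (import path assumed as above; the chain itself is scratch code, not framework API):

```python
import jax.numpy as jnp
from alberta_framework.core.optimizers import TDIDBD

opt = TDIDBD(trace_decay=0.9)
state = opt.init(feature_dim=2)
weights, bias = jnp.zeros(2), jnp.array(0.0)

phi = [jnp.array([1.0, 0.0]), jnp.array([0.0, 1.0])]  # one-hot state features
gamma = jnp.array(0.9)

for t in range(1000):
    s, s_next = t % 2, (t + 1) % 2
    reward = jnp.array(1.0 if s_next == 0 else 0.0)
    td_error = (
        reward
        + gamma * (weights @ phi[s_next] + bias)
        - (weights @ phi[s] + bias)
    )
    upd = opt.update(state, td_error, phi[s], phi[s_next], gamma)
    weights = weights + upd.weight_delta
    bias = bias + upd.bias_delta
    state = upd.new_state

# The learned values should approach the fixed point of the alternating rewards.
print(float(weights @ phi[0] + bias), float(weights @ phi[1] + bias))
```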