alberta-framework 0.2.1__py3-none-any.whl → 0.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- alberta_framework/core/learners.py +69 -60
- alberta_framework/core/normalizers.py +8 -7
- alberta_framework/core/optimizers.py +6 -4
- alberta_framework/core/types.py +65 -52
- alberta_framework/streams/synthetic.py +42 -32
- {alberta_framework-0.2.1.dist-info → alberta_framework-0.3.0.dist-info}/METADATA +4 -1
- {alberta_framework-0.2.1.dist-info → alberta_framework-0.3.0.dist-info}/RECORD +9 -9
- {alberta_framework-0.2.1.dist-info → alberta_framework-0.3.0.dist-info}/WHEEL +0 -0
- {alberta_framework-0.2.1.dist-info → alberta_framework-0.3.0.dist-info}/licenses/LICENSE +0 -0
alberta_framework/core/learners.py
CHANGED

@@ -5,11 +5,13 @@ for temporally-uniform learning. Uses JAX's scan for efficient JIT-compiled
 training loops.
 """
 
-from typing import NamedTuple
+from typing import cast
 
+import chex
 import jax
 import jax.numpy as jnp
 from jax import Array
+from jaxtyping import Float
 
 from alberta_framework.core.normalizers import NormalizerState, OnlineNormalizer
 from alberta_framework.core.optimizers import LMS, Optimizer
@@ -33,7 +35,9 @@ from alberta_framework.streams.base import ScanStream
 # Type alias for any optimizer type
 AnyOptimizer = Optimizer[LMSState] | Optimizer[IDBDState] | Optimizer[AutostepState]
 
-class UpdateResult(NamedTuple):
+
+@chex.dataclass(frozen=True)
+class UpdateResult:
     """Result of a learner update step.
 
     Attributes:

@@ -45,8 +49,38 @@ class UpdateResult(NamedTuple):
 
     state: LearnerState
     prediction: Prediction
-    error: Array
-    metrics: Array
+    error: Float[Array, ""]
+    metrics: Float[Array, " 3"]
+
+
+@chex.dataclass(frozen=True)
+class NormalizedLearnerState:
+    """State for a learner with online feature normalization.
+
+    Attributes:
+        learner_state: Underlying learner state (weights, bias, optimizer)
+        normalizer_state: Online normalizer state (mean, var estimates)
+    """
+
+    learner_state: LearnerState
+    normalizer_state: NormalizerState
+
+
+@chex.dataclass(frozen=True)
+class NormalizedUpdateResult:
+    """Result of a normalized learner update step.
+
+    Attributes:
+        state: Updated normalized learner state
+        prediction: Prediction made before update
+        error: Prediction error
+        metrics: Array of metrics [squared_error, error, mean_step_size, normalizer_mean_var]
+    """
+
+    state: NormalizedLearnerState
+    prediction: Prediction
+    error: Float[Array, ""]
+    metrics: Float[Array, " 4"]
 
 
 class LinearLearner:
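The switch from NamedTuple to chex.dataclass changes how calling code builds and updates these records. A minimal sketch of the new pattern, assuming the usual chex semantics (keyword-only construction, functional .replace() updates, automatic pytree registration); ToyResult is a made-up stand-in, not a framework class:

import chex
import jax.numpy as jnp
from jax import Array

@chex.dataclass(frozen=True)
class ToyResult:
    error: Array
    metrics: Array

# Fields are passed by keyword; the instance is an immutable pytree.
r = ToyResult(error=jnp.asarray(0.5), metrics=jnp.zeros(3))
r2 = r.replace(error=jnp.asarray(0.1))  # functional update; r itself is unchanged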
@@ -133,8 +167,7 @@ class LinearLearner:
         # Note: type ignore needed because we can't statically prove optimizer_state
         # matches the optimizer's expected state type (though they will at runtime)
         opt_update = self._optimizer.update(
-            state.optimizer_state,
-            error,
+            state.optimizer_state, error,
             observation,
         )
 
@@ -275,12 +308,12 @@ def run_learning_loop[StreamStateT](
         opt_state = result.state.optimizer_state
         if hasattr(opt_state, "log_step_sizes"):
             # IDBD stores log step-sizes
-            weight_ss = jnp.exp(opt_state.log_step_sizes)
-            bias_ss = opt_state.bias_step_size
+            weight_ss = jnp.exp(opt_state.log_step_sizes)
+            bias_ss = opt_state.bias_step_size
         elif hasattr(opt_state, "step_sizes"):
             # Autostep stores step-sizes directly
-            weight_ss = opt_state.step_sizes
-            bias_ss = opt_state.bias_step_size
+            weight_ss = opt_state.step_sizes
+            bias_ss = opt_state.bias_step_size
         else:
             # LMS has a single fixed step-size
             weight_ss = jnp.full(feature_dim, opt_state.step_size)

@@ -316,8 +349,7 @@ def run_learning_loop[StreamStateT](
         new_norm_history = jax.lax.cond(
             should_record,
             lambda _: norm_history.at[recording_idx].set(
-                opt_state.normalizers
-            ),
+                opt_state.normalizers),
             lambda _: norm_history,
             None,
         )
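For reference, the duck-typed dispatch above can be read as a standalone helper. This is a paraphrase of the diff's logic, not framework API, and the helper name is invented:

import jax.numpy as jnp

def effective_step_sizes(opt_state, feature_dim):
    """Recover per-weight step-sizes from any of the three optimizer states."""
    if hasattr(opt_state, "log_step_sizes"):  # IDBD stores log step-sizes
        return jnp.exp(opt_state.log_step_sizes)
    if hasattr(opt_state, "step_sizes"):      # Autostep stores them directly
        return opt_state.step_sizes
    return jnp.full(feature_dim, opt_state.step_size)  # LMS: one fixed step-size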
@@ -361,34 +393,6 @@ def run_learning_loop[StreamStateT](
     return final_learner, metrics, history
 
 
-class NormalizedLearnerState(NamedTuple):
-    """State for a learner with online feature normalization.
-
-    Attributes:
-        learner_state: Underlying learner state (weights, bias, optimizer)
-        normalizer_state: Online normalizer state (mean, var estimates)
-    """
-
-    learner_state: LearnerState
-    normalizer_state: NormalizerState
-
-
-class NormalizedUpdateResult(NamedTuple):
-    """Result of a normalized learner update step.
-
-    Attributes:
-        state: Updated normalized learner state
-        prediction: Prediction made before update
-        error: Prediction error
-        metrics: Array of metrics [squared_error, error, mean_step_size, normalizer_mean_var]
-    """
-
-    state: NormalizedLearnerState
-    prediction: Prediction
-    error: Array
-    metrics: Array
-
-
 class NormalizedLinearLearner:
     """Linear learner with online feature normalization.
 
@@ -702,12 +706,12 @@ def run_normalized_learning_loop[StreamStateT](
         opt_state = result.state.learner_state.optimizer_state
         if hasattr(opt_state, "log_step_sizes"):
             # IDBD stores log step-sizes
-            weight_ss = jnp.exp(opt_state.log_step_sizes)
-            bias_ss = opt_state.bias_step_size
+            weight_ss = jnp.exp(opt_state.log_step_sizes)
+            bias_ss = opt_state.bias_step_size
         elif hasattr(opt_state, "step_sizes"):
             # Autostep stores step-sizes directly
-            weight_ss = opt_state.step_sizes
-            bias_ss = opt_state.bias_step_size
+            weight_ss = opt_state.step_sizes
+            bias_ss = opt_state.bias_step_size
         else:
             # LMS has a single fixed step-size
             weight_ss = jnp.full(feature_dim, opt_state.step_size)

@@ -741,8 +745,7 @@ def run_normalized_learning_loop[StreamStateT](
         new_ss_norm = jax.lax.cond(
             should_record_ss,
             lambda _: ss_norm.at[recording_idx].set(
-                opt_state.normalizers
-            ),
+                opt_state.normalizers),
             lambda _: ss_norm,
             None,
         )

@@ -825,17 +828,14 @@ def run_normalized_learning_loop[StreamStateT](
         ss_history_result = StepSizeHistory(
             step_sizes=final_ss_hist,
             bias_step_sizes=final_ss_bias_hist,
-            recording_indices=final_ss_rec,
-            normalizers=final_ss_norm,
+            recording_indices=final_ss_rec, normalizers=final_ss_norm,
         )
 
     norm_history_result = None
     if normalizer_tracking is not None and final_n_means is not None:
         norm_history_result = NormalizerHistory(
             means=final_n_means,
-            variances=final_n_vars,
-            recording_indices=final_n_rec,  # type: ignore[arg-type]
-        )
+            variances=final_n_vars, recording_indices=final_n_rec)
 
     # Return appropriate tuple based on what was tracked
     if ss_history_result is not None and norm_history_result is not None:
@@ -900,10 +900,12 @@ def run_learning_loop_batched[StreamStateT](
             learner, stream, num_steps, key, learner_state, step_size_tracking
         )
         if step_size_tracking is not None:
-            state, metrics, history = result
+            state, metrics, history = cast(
+                tuple[LearnerState, Array, StepSizeHistory], result
+            )
             return state, metrics, history
         else:
-            state, metrics = result
+            state, metrics = cast(tuple[LearnerState, Array], result)
             # Return None for history to maintain consistent output structure
             return state, metrics, None
 

@@ -911,7 +913,7 @@ def run_learning_loop_batched[StreamStateT](
     batched_states, batched_metrics, batched_history = jax.vmap(single_seed_run)(keys)
 
     # Reconstruct batched history if tracking was enabled
-    if step_size_tracking is not None:
+    if step_size_tracking is not None and batched_history is not None:
         batched_step_size_history = StepSizeHistory(
             step_sizes=batched_history.step_sizes,
             bias_step_sizes=batched_history.bias_step_sizes,
@@ -993,16 +995,23 @@ def run_normalized_learning_loop_batched[StreamStateT](
 
         # Unpack based on what tracking was enabled
        if step_size_tracking is not None and normalizer_tracking is not None:
-            state, metrics, ss_history, norm_history = result
+            state, metrics, ss_history, norm_history = cast(
+                tuple[NormalizedLearnerState, Array, StepSizeHistory, NormalizerHistory],
+                result,
+            )
             return state, metrics, ss_history, norm_history
         elif step_size_tracking is not None:
-            state, metrics, ss_history = result
+            state, metrics, ss_history = cast(
+                tuple[NormalizedLearnerState, Array, StepSizeHistory], result
+            )
             return state, metrics, ss_history, None
         elif normalizer_tracking is not None:
-            state, metrics, norm_history = result
+            state, metrics, norm_history = cast(
+                tuple[NormalizedLearnerState, Array, NormalizerHistory], result
+            )
             return state, metrics, None, norm_history
         else:
-            state, metrics = result
+            state, metrics = cast(tuple[NormalizedLearnerState, Array], result)
             return state, metrics, None, None
 
     # vmap over the keys dimension

@@ -1011,7 +1020,7 @@ def run_normalized_learning_loop_batched[StreamStateT](
     )
 
     # Reconstruct batched histories if tracking was enabled
-    if step_size_tracking is not None:
+    if step_size_tracking is not None and batched_ss_history is not None:
         batched_step_size_history = StepSizeHistory(
             step_sizes=batched_ss_history.step_sizes,
             bias_step_sizes=batched_ss_history.bias_step_sizes,

@@ -1021,7 +1030,7 @@ def run_normalized_learning_loop_batched[StreamStateT](
     else:
         batched_step_size_history = None
 
-    if normalizer_tracking is not None:
+    if normalizer_tracking is not None and batched_norm_history is not None:
         batched_normalizer_history = NormalizerHistory(
             means=batched_norm_history.means,
             variances=batched_norm_history.variances,
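The new cast() calls replace bare tuple unpacking. typing.cast does nothing at runtime; it only narrows the union-typed scan result for the static type checker, which is also why the companion `is not None` guards were added. A schematic example with placeholder types:

from typing import cast

def unpack(result: object, tracking: bool):
    if tracking:
        # cast() is a no-op at runtime; it only informs the type checker.
        state, metrics, history = cast(tuple[object, object, object], result)
        return state, metrics, history
    state, metrics = cast(tuple[object, object], result)
    return state, metrics, None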
alberta_framework/core/normalizers.py
CHANGED

@@ -6,13 +6,14 @@ and variance at every time step, following the principle of temporal uniformity.
 Reference: Welford's online algorithm for numerical stability.
 """
 
-from typing import NamedTuple
-
+import chex
 import jax.numpy as jnp
 from jax import Array
+from jaxtyping import Float
 
 
-class NormalizerState(NamedTuple):
+@chex.dataclass(frozen=True)
+class NormalizerState:
     """State for online feature normalization.
 
     Uses Welford's online algorithm for numerically stable estimation

@@ -25,10 +26,10 @@ class NormalizerState(NamedTuple):
         decay: Exponential decay factor for estimates (1.0 = no decay, pure online)
     """
 
-    mean: Array
-    var: Array
-    sample_count: Array
-    decay: Array
+    mean: Float[Array, " feature_dim"]
+    var: Float[Array, " feature_dim"]
+    sample_count: Float[Array, ""]
+    decay: Float[Array, ""]
 
 
 class OnlineNormalizer:
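For orientation, the textbook Welford step the docstring refers to looks like the sketch below. This is the standard algorithm written for reference, not the package's exact update (which also supports the decay field above):

import jax.numpy as jnp

def welford_step(mean, m2, count, x):
    # One numerically stable online update of running mean and variance.
    count = count + 1.0
    delta = x - mean
    mean = mean + delta / count
    m2 = m2 + delta * (x - mean)  # running sum of squared deviations
    return mean, m2, count         # variance estimate: m2 / count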
alberta_framework/core/optimizers.py
CHANGED

@@ -10,15 +10,17 @@ References:
 """
 
 from abc import ABC, abstractmethod
-from typing import NamedTuple
 
+import chex
 import jax.numpy as jnp
 from jax import Array
+from jaxtyping import Float
 
 from alberta_framework.core.types import AutostepState, IDBDState, LMSState
 
 
-class OptimizerUpdate(NamedTuple):
+@chex.dataclass(frozen=True)
+class OptimizerUpdate:
     """Result of an optimizer update step.
 
     Attributes:

@@ -28,8 +30,8 @@ class OptimizerUpdate(NamedTuple):
         metrics: Dictionary of metrics for logging (values are JAX arrays for scan compatibility)
     """
 
-    weight_delta: Array
-    bias_delta: Array
+    weight_delta: Float[Array, " feature_dim"]
+    bias_delta: Float[Array, ""]
     new_state: LMSState | IDBDState | AutostepState
     metrics: dict[str, Array]
 
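The shape annotations make the contract explicit: weight_delta matches the weight vector and bias_delta is a scalar. A hedged sketch of how a learner might consume an update; the surrounding state handling is abbreviated, and .replace() assumes the chex-dataclass LearnerState introduced in types.py:

def apply_update(state, opt_update):
    # Hypothetical consumption of an OptimizerUpdate inside a learner step.
    return state.replace(
        weights=state.weights + opt_update.weight_delta,  # shape: (feature_dim,)
        bias=state.bias + opt_update.bias_delta,          # scalar
        optimizer_state=opt_update.new_state,
    )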
alberta_framework/core/types.py
CHANGED
@@ -1,13 +1,15 @@
 """Type definitions for the Alberta Framework.
 
 This module defines the core data types used throughout the framework,
-using NamedTuples for JAX compatibility.
+using chex dataclasses for JAX compatibility and jaxtyping for shape annotations.
 """
 
-from typing import TYPE_CHECKING, NamedTuple
+from typing import TYPE_CHECKING
 
+import chex
 import jax.numpy as jnp
 from jax import Array
+from jaxtyping import Float, Int, PRNGKeyArray
 
 if TYPE_CHECKING:
     from alberta_framework.core.learners import NormalizedLearnerState
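A quick reading guide for the jaxtyping annotations that appear throughout this release; whitespace inside the shape string is not significant, numbers fix an axis size, and names are symbolic documentation unless a runtime checker is hooked in:

from jax import Array
from jaxtyping import Float, Int

x: Float[Array, ""]                            # 0-D (scalar) float array
m: Float[Array, " 3"]                          # 1-D float array of fixed length 3
w: Float[Array, " feature_dim"]                # 1-D float array, named symbolic axis
h: Float[Array, "num_recordings feature_dim"]  # 2-D float array
i: Int[Array, ""]                              # 0-D integer array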
@@ -19,7 +21,8 @@ Prediction = Array  # y_t: model output
 Reward = float  # r_t: scalar reward
 
 
-class TimeStep(NamedTuple):
+@chex.dataclass(frozen=True)
+class TimeStep:
     """Single experience from an experience stream.
 
     Attributes:

@@ -27,25 +30,12 @@ class TimeStep(NamedTuple):
         target: Desired output y*_t (for supervised learning)
     """
 
-    observation: Array
-    target: Array
+    observation: Float[Array, " feature_dim"]
+    target: Float[Array, " 1"]
 
 
-class LearnerState(NamedTuple):
-    """State for a linear learner.
-
-    Attributes:
-        weights: Weight vector for linear prediction
-        bias: Bias term
-        optimizer_state: State maintained by the optimizer
-    """
-
-    weights: Array
-    bias: Array
-    optimizer_state: "LMSState | IDBDState | AutostepState"
-
-
-class LMSState(NamedTuple):
+@chex.dataclass(frozen=True)
+class LMSState:
     """State for the LMS (Least Mean Square) optimizer.
 
     LMS uses a fixed step-size, so state only tracks the step-size parameter.
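Construction under the new dataclass form, as a sketch; keyword arguments are assumed (the usual chex behavior), and the shapes come from the annotations above:

import jax.numpy as jnp
from alberta_framework.core.types import TimeStep

step = TimeStep(
    observation=jnp.ones(5),  # Float[Array, " feature_dim"] with feature_dim = 5
    target=jnp.array([0.7]),  # Float[Array, " 1"]
)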
@@ -54,10 +44,11 @@ class LMSState(NamedTuple):
         step_size: Fixed learning rate alpha
     """
 
-    step_size: Array
+    step_size: Float[Array, ""]
 
 
-class IDBDState(NamedTuple):
+@chex.dataclass(frozen=True)
+class IDBDState:
     """State for the IDBD (Incremental Delta-Bar-Delta) optimizer.
 
     IDBD maintains per-weight adaptive step-sizes that are meta-learned

@@ -73,14 +64,15 @@ class IDBDState(NamedTuple):
         bias_trace: Trace for the bias term
     """
 
-    log_step_sizes: Array  # log(alpha_i) for numerical stability
-    traces: Array  # h_i: trace of weight-feature products
-    meta_step_size: Array  # beta: step-size for the step-sizes
-    bias_step_size: Array  # Step-size for bias
-    bias_trace: Array  # Trace for bias
+    log_step_sizes: Float[Array, " feature_dim"]  # log(alpha_i) for numerical stability
+    traces: Float[Array, " feature_dim"]  # h_i: trace of weight-feature products
+    meta_step_size: Float[Array, ""]  # beta: step-size for the step-sizes
+    bias_step_size: Float[Array, ""]  # Step-size for bias
+    bias_trace: Float[Array, ""]  # Trace for bias
 
 
-class AutostepState(NamedTuple):
+@chex.dataclass(frozen=True)
+class AutostepState:
     """State for the Autostep optimizer.
 
     Autostep is a tuning-free step-size adaptation algorithm that normalizes

@@ -100,17 +92,33 @@ class AutostepState(NamedTuple):
         bias_normalizer: Normalizer for the bias gradient
     """
 
-    step_sizes: Array  # alpha_i
-    traces: Array  # h_i
-    normalizers: Array  # v_i: running max of |gradient|
-    meta_step_size: Array  # mu
-    normalizer_decay: Array  # tau
-    bias_step_size: Array
-    bias_trace: Array
-    bias_normalizer: Array
+    step_sizes: Float[Array, " feature_dim"]  # alpha_i
+    traces: Float[Array, " feature_dim"]  # h_i
+    normalizers: Float[Array, " feature_dim"]  # v_i: running max of |gradient|
+    meta_step_size: Float[Array, ""]  # mu
+    normalizer_decay: Float[Array, ""]  # tau
+    bias_step_size: Float[Array, ""]
+    bias_trace: Float[Array, ""]
+    bias_normalizer: Float[Array, ""]
+
+
+@chex.dataclass(frozen=True)
+class LearnerState:
+    """State for a linear learner.
+
+    Attributes:
+        weights: Weight vector for linear prediction
+        bias: Bias term
+        optimizer_state: State maintained by the optimizer
+    """
+
+    weights: Float[Array, " feature_dim"]
+    bias: Float[Array, ""]
+    optimizer_state: LMSState | IDBDState | AutostepState
 
 
-class StepSizeTrackingConfig(NamedTuple):
+@chex.dataclass(frozen=True)
+class StepSizeTrackingConfig:
     """Configuration for recording per-weight step-sizes during training.
 
     Attributes:
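Note the asymmetry the annotations expose: IDBD keeps its step-sizes in log space while Autostep stores them directly, which is exactly what the hasattr branches in learners.py handle. Recovering IDBD's effective step-sizes, per the learning-loop code above (given some idbd_state):

import jax.numpy as jnp

alphas = jnp.exp(idbd_state.log_step_sizes)  # effective per-weight step-sizes
bias_alpha = idbd_state.bias_step_size       # the bias step-size is not log-scaled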
@@ -122,7 +130,8 @@ class StepSizeTrackingConfig(NamedTuple):
     include_bias: bool = True
 
 
-class StepSizeHistory(NamedTuple):
+@chex.dataclass(frozen=True)
+class StepSizeHistory:
     """History of per-weight step-sizes recorded during training.
 
     Attributes:

@@ -133,13 +142,14 @@ class StepSizeHistory(NamedTuple):
         shape (num_recordings, num_weights) or None. Only populated for Autostep optimizer.
     """
 
-    step_sizes: Array
-    bias_step_sizes: Array
-    recording_indices: Array
-    normalizers: Array | None = None
+    step_sizes: Float[Array, "num_recordings feature_dim"]
+    bias_step_sizes: Float[Array, " num_recordings"] | None
+    recording_indices: Int[Array, " num_recordings"]
+    normalizers: Float[Array, "num_recordings feature_dim"] | None = None
 
 
-class NormalizerTrackingConfig(NamedTuple):
+@chex.dataclass(frozen=True)
+class NormalizerTrackingConfig:
     """Configuration for recording per-feature normalizer state during training.
 
     Attributes:

@@ -149,7 +159,8 @@ class NormalizerTrackingConfig(NamedTuple):
     interval: int
 
 
-class NormalizerHistory(NamedTuple):
+@chex.dataclass(frozen=True)
+class NormalizerHistory:
     """History of per-feature normalizer state recorded during training.
 
     Used for analyzing how the OnlineNormalizer adapts to distribution shifts

@@ -162,12 +173,13 @@ class NormalizerHistory(NamedTuple):
         recording_indices: Step indices where recordings were made, shape (num_recordings,)
     """
 
-    means: Array
-    variances: Array
-    recording_indices: Array
+    means: Float[Array, "num_recordings feature_dim"]
+    variances: Float[Array, "num_recordings feature_dim"]
+    recording_indices: Int[Array, " num_recordings"]
 
 
-class BatchedLearningResult(NamedTuple):
+@chex.dataclass(frozen=True)
+class BatchedLearningResult:
     """Result from batched learning loop across multiple seeds.
 
     Used with `run_learning_loop_batched` for vmap-based GPU parallelization.

@@ -180,12 +192,13 @@ class BatchedLearningResult(NamedTuple):
         or None if tracking was disabled
     """
 
-    states: LearnerState
-    metrics: Array
+    states: LearnerState  # Batched: each array has shape (num_seeds, ...)
+    metrics: Float[Array, "num_seeds num_steps 3"]
     step_size_history: StepSizeHistory | None
 
 
-class BatchedNormalizedResult(NamedTuple):
+@chex.dataclass(frozen=True)
+class BatchedNormalizedResult:
     """Result from batched normalized learning loop across multiple seeds.
 
     Used with `run_normalized_learning_loop_batched` for vmap-based GPU parallelization.

@@ -201,7 +214,7 @@ class BatchedNormalizedResult(NamedTuple):
     """
 
     states: "NormalizedLearnerState"  # Batched: each array has shape (num_seeds, ...)
-    metrics: Array
+    metrics: Float[Array, "num_seeds num_steps 4"]
     step_size_history: StepSizeHistory | None
     normalizer_history: NormalizerHistory | None
 
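Given the annotated metrics shape, per-seed learning curves can be sliced directly. Sketch only; the column order [squared_error, error, mean_step_size] is inferred from the NormalizedUpdateResult docstring in learners.py and is an assumption for the unnormalized case:

sq_err = result.metrics[:, :, 0]  # (num_seeds, num_steps): squared errors
mean_curve = sq_err.mean(axis=0)  # averaged learning curve across seeds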
alberta_framework/streams/synthetic.py
CHANGED

@@ -7,17 +7,20 @@ track and adapt.
 All streams use JAX-compatible pure functions that work with jax.lax.scan.
 """
 
-from typing import Any, NamedTuple
+from typing import Any
 
+import chex
 import jax.numpy as jnp
 import jax.random as jr
 from jax import Array
+from jaxtyping import Float, Int, PRNGKeyArray
 
 from alberta_framework.core.types import TimeStep
 from alberta_framework.streams.base import ScanStream
 
 
-class RandomWalkState(NamedTuple):
+@chex.dataclass(frozen=True)
+class RandomWalkState:
     """State for RandomWalkStream.
 
     Attributes:

@@ -25,8 +28,8 @@ class RandomWalkState(NamedTuple):
         true_weights: Current true target weights
     """
 
-    key: Array
-    true_weights: Array
+    key: PRNGKeyArray
+    true_weights: Float[Array, " feature_dim"]
 
 
 class RandomWalkStream:
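The key field's new PRNGKeyArray annotation reflects how stream steps thread JAX randomness. An illustrative (not package-exact) update for a drifting-weights stream; the function and the 0.01 drift scale are made up:

import jax.random as jr

def drift_weights(state: "RandomWalkState") -> "RandomWalkState":
    key, subkey = jr.split(state.key)          # split so the carried key stays fresh
    noise = jr.normal(subkey, state.true_weights.shape)
    return state.replace(key=key, true_weights=state.true_weights + 0.01 * noise)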
@@ -110,7 +113,8 @@ class RandomWalkStream:
         return timestep, new_state
 
 
-class AbruptChangeState(NamedTuple):
+@chex.dataclass(frozen=True)
+class AbruptChangeState:
     """State for AbruptChangeStream.
 
     Attributes:

@@ -119,9 +123,9 @@ class AbruptChangeState(NamedTuple):
         step_count: Number of steps taken
     """
 
-    key: Array
-    true_weights: Array
-    step_count: Array
+    key: PRNGKeyArray
+    true_weights: Float[Array, " feature_dim"]
+    step_count: Int[Array, ""]
 
 
 class AbruptChangeStream:

@@ -219,7 +223,8 @@ class AbruptChangeStream:
         return timestep, new_state
 
 
-class SuttonExperiment1State(NamedTuple):
+@chex.dataclass(frozen=True)
+class SuttonExperiment1State:
     """State for SuttonExperiment1Stream.
 
     Attributes:

@@ -228,9 +233,9 @@ class SuttonExperiment1State(NamedTuple):
         step_count: Number of steps taken
     """
 
-    key: Array
-    signs: Array
-    step_count: Array
+    key: PRNGKeyArray
+    signs: Float[Array, " num_relevant"]
+    step_count: Int[Array, ""]
 
 
 class SuttonExperiment1Stream:

@@ -341,7 +346,8 @@ class SuttonExperiment1Stream:
         return timestep, new_state
 
 
-class CyclicState(NamedTuple):
+@chex.dataclass(frozen=True)
+class CyclicState:
     """State for CyclicStream.
 
     Attributes:

@@ -350,9 +356,9 @@ class CyclicState(NamedTuple):
         step_count: Number of steps taken
     """
 
-    key: Array
-    configurations: Array
-    step_count: Array
+    key: PRNGKeyArray
+    configurations: Float[Array, "num_configs feature_dim"]
+    step_count: Int[Array, ""]
 
 
 class CyclicStream:

@@ -452,7 +458,8 @@ class CyclicStream:
         return timestep, new_state
 
 
-class PeriodicChangeState(NamedTuple):
+@chex.dataclass(frozen=True)
+class PeriodicChangeState:
     """State for PeriodicChangeStream.
 
     Attributes:

@@ -462,10 +469,10 @@ class PeriodicChangeState(NamedTuple):
         step_count: Number of steps taken
     """
 
-    key: Array
-    base_weights: Array
-    phases: Array
-    step_count: Array
+    key: PRNGKeyArray
+    base_weights: Float[Array, " feature_dim"]
+    phases: Float[Array, " feature_dim"]
+    step_count: Int[Array, ""]
 
 
 class PeriodicChangeStream:

@@ -573,7 +580,8 @@ class PeriodicChangeStream:
         return timestep, new_state
 
 
-class ScaledStreamState(NamedTuple):
+@chex.dataclass(frozen=True)
+class ScaledStreamState:
     """State for ScaledStreamWrapper.
 
     Attributes:

@@ -714,7 +722,8 @@ def make_scale_range(
     return jnp.linspace(min_scale, max_scale, feature_dim, dtype=jnp.float32)
 
 
-class DynamicScaleShiftState(NamedTuple):
+@chex.dataclass(frozen=True)
+class DynamicScaleShiftState:
     """State for DynamicScaleShiftStream.
 
     Attributes:

@@ -724,10 +733,10 @@ class DynamicScaleShiftState(NamedTuple):
         step_count: Number of steps taken
     """
 
-    key: Array
-    true_weights: Array
-    current_scales: Array
-    step_count: Array
+    key: PRNGKeyArray
+    true_weights: Float[Array, " feature_dim"]
+    current_scales: Float[Array, " feature_dim"]
+    step_count: Int[Array, ""]
 
 
 class DynamicScaleShiftStream:

@@ -858,7 +867,8 @@ class DynamicScaleShiftStream:
         return timestep, new_state
 
 
-class ScaleDriftState(NamedTuple):
+@chex.dataclass(frozen=True)
+class ScaleDriftState:
     """State for ScaleDriftStream.
 
     Attributes:

@@ -868,10 +878,10 @@ class ScaleDriftState(NamedTuple):
         step_count: Number of steps taken
     """
 
-    key: Array
-    true_weights: Array
-    log_scales: Array
-    step_count: Array
+    key: PRNGKeyArray
+    true_weights: Float[Array, " feature_dim"]
+    log_scales: Float[Array, " feature_dim"]
+    step_count: Int[Array, ""]
 
 
 class ScaleDriftStream:
{alberta_framework-0.2.1.dist-info → alberta_framework-0.3.0.dist-info}/METADATA
CHANGED

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: alberta-framework
-Version: 0.2.1
+Version: 0.3.0
 Summary: Implementation of the Alberta Plan for AI Research - continual learning with meta-learned step-sizes
 Project-URL: Homepage, https://github.com/j-klawson/alberta-framework
 Project-URL: Repository, https://github.com/j-klawson/alberta-framework

@@ -17,14 +17,17 @@ Classifier: Programming Language :: Python :: 3
 Classifier: Programming Language :: Python :: 3.13
 Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
 Requires-Python: >=3.13
+Requires-Dist: chex>=0.1.86
 Requires-Dist: jax>=0.4
 Requires-Dist: jaxlib>=0.4
+Requires-Dist: jaxtyping>=0.2.28
 Requires-Dist: joblib>=1.3
 Requires-Dist: matplotlib>=3.8
 Requires-Dist: numpy>=2.0
 Requires-Dist: scipy>=1.11
 Requires-Dist: tqdm>=4.66
 Provides-Extra: dev
+Requires-Dist: beartype>=0.18.0; extra == 'dev'
 Requires-Dist: mypy; extra == 'dev'
 Requires-Dist: pytest-cov; extra == 'dev'
 Requires-Dist: pytest>=8.0; extra == 'dev'
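The new chex and jaxtyping requirements back the dataclass and shape annotations above, and beartype in the dev extra suggests runtime shape checking during tests. One hedged sketch using jaxtyping's documented import hook; whether the framework's own test suite actually does this is not shown in the diff:

from jaxtyping import install_import_hook

# Enforce the Float[...]/Int[...] annotations at call time while importing:
with install_import_hook("alberta_framework", "beartype.beartype"):
    import alberta_framework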
{alberta_framework-0.2.1.dist-info → alberta_framework-0.3.0.dist-info}/RECORD
CHANGED

@@ -1,14 +1,14 @@
 alberta_framework/__init__.py,sha256=gAafDDmkivDdfnvDVff9zbVY9ilzqqfJ9KvpbRegKqs,5726
 alberta_framework/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 alberta_framework/core/__init__.py,sha256=PSrC4zSxgm_6YXWEQ80aZaunpbQ58QexxKmDDU-jp6c,522
-alberta_framework/core/learners.py,sha256=
-alberta_framework/core/normalizers.py,sha256=
-alberta_framework/core/optimizers.py,sha256=
-alberta_framework/core/types.py,sha256=
+alberta_framework/core/learners.py,sha256=gUhX7caXBfpWYgnvYTp5YKXfP6wbzB2T2gkSMMtrHDQ,38042
+alberta_framework/core/normalizers.py,sha256=QmKmha-mFgKi1KD-f8xuB2U175yQL6Ll0D4c8OONIl0,5927
+alberta_framework/core/optimizers.py,sha256=a4gYac5DyXReir9ycudRg8uQ9b53uLWTIldZ1A3Ae5c,14646
+alberta_framework/core/types.py,sha256=eqiPMVD8_QYJNg83rnE3XO9Z5BPczsg_LKh5dhmIgt4,9807
 alberta_framework/streams/__init__.py,sha256=bsDgWjWjotDQHMI2lno3dgk8N14pd-2mYAQpXAtCPx4,2035
 alberta_framework/streams/base.py,sha256=9rJxvUgmzd5u2bRV4vi5PxhUvj39EZTD4bZHo-Ptn-U,2168
 alberta_framework/streams/gymnasium.py,sha256=s733X7aEgy05hcSazjZEhBiJChtEL7uVpxwh0fXBQZA,21980
-alberta_framework/streams/synthetic.py,sha256=
+alberta_framework/streams/synthetic.py,sha256=8njzQCFRi_iVgdPA3slyn46vFIHHkIwaZsABZyPwqnU,33507
 alberta_framework/utils/__init__.py,sha256=zfKfnbikhLp0J6UgVa8HeRo59gZHwqOc8jf03s7AaT4,2845
 alberta_framework/utils/experiments.py,sha256=ekGAzveCRgv9YZ5mfAD5Uf7h_PvQnxsNw2KeZN2eu00,10644
 alberta_framework/utils/export.py,sha256=W9RKfeTiyZcLColOGNjBfZU0N6QMXrfPn4pdYcm-OSk,15832

@@ -16,7 +16,7 @@ alberta_framework/utils/metrics.py,sha256=1cryNJoboO67vvRhausaucbYZFgdL_06vaf08U
 alberta_framework/utils/statistics.py,sha256=4fbzNlmsdUaM5lLW1BhL5B5MUpnqimQlwJklZ4x0y0U,15416
 alberta_framework/utils/timing.py,sha256=JOLq8CpCAV7LWOWkftxefduSFjaXnVwal1MFBKEMdJI,4049
 alberta_framework/utils/visualization.py,sha256=PmKBD3KGabNhgDizcNiGJEbVCyDL1YMUE5yTwgJHu2o,17924
-alberta_framework-0.2.1.dist-info/METADATA,sha256=
-alberta_framework-0.2.1.dist-info/WHEEL,sha256=
-alberta_framework-0.2.1.dist-info/licenses/LICENSE,sha256=
-alberta_framework-0.2.1.dist-info/RECORD,,
+alberta_framework-0.3.0.dist-info/METADATA,sha256=OMn-jZdVSJ0vdLTLnpxzeTWx0UXlPFbGngA38wr_xDI,7872
+alberta_framework-0.3.0.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
+alberta_framework-0.3.0.dist-info/licenses/LICENSE,sha256=TI1avodt5mvxz7sunyxIa0HlNgLQcmKNLeRjCVcgKmE,10754
+alberta_framework-0.3.0.dist-info/RECORD,,
{alberta_framework-0.2.1.dist-info → alberta_framework-0.3.0.dist-info}/WHEEL
File without changes

{alberta_framework-0.2.1.dist-info → alberta_framework-0.3.0.dist-info}/licenses/LICENSE
File without changes