alberta-framework 0.1.0__py3-none-any.whl → 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- alberta_framework/__init__.py +43 -22
- alberta_framework/core/learners.py +357 -18
- alberta_framework/core/normalizers.py +1 -1
- alberta_framework/core/optimizers.py +14 -12
- alberta_framework/core/types.py +31 -0
- alberta_framework/streams/base.py +8 -5
- alberta_framework/streams/synthetic.py +16 -10
- alberta_framework/utils/experiments.py +4 -3
- alberta_framework/utils/timing.py +42 -36
- {alberta_framework-0.1.0.dist-info → alberta_framework-0.1.1.dist-info}/METADATA +10 -2
- alberta_framework-0.1.1.dist-info/RECORD +22 -0
- alberta_framework-0.1.0.dist-info/RECORD +0 -22
- {alberta_framework-0.1.0.dist-info → alberta_framework-0.1.1.dist-info}/WHEEL +0 -0
- {alberta_framework-0.1.0.dist-info → alberta_framework-0.1.1.dist-info}/licenses/LICENSE +0 -0
alberta_framework/__init__.py
CHANGED
@@ -1,25 +1,42 @@
-"""Alberta Framework:
-learning
+"""Alberta Framework: A JAX-based research framework for continual AI.
+
+The Alberta Framework provides foundational components for continual reinforcement
+learning research. Built on JAX for hardware acceleration, the framework emphasizes
+temporal uniformity — every component updates at every time step, with no special
+training phases or batch processing.
+
+Roadmap
+-------
+| Step | Focus | Status |
+|------|-------|--------|
+| 1 | Meta-learned step-sizes (IDBD, Autostep) | **Complete** |
+| 2 | Feature generation and testing | Planned |
+| 3 | GVF predictions, Horde architecture | Planned |
+| 4 | Actor-critic with eligibility traces | Planned |
+| 5-6 | Off-policy learning, average reward | Planned |
+| 7-12 | Hierarchical, multi-agent, world models | Future |
+
+Examples
+--------
+```python
+import jax.random as jr
+from alberta_framework import LinearLearner, IDBD, RandomWalkStream, run_learning_loop
+
+# Non-stationary stream where target weights drift over time
+stream = RandomWalkStream(feature_dim=10, drift_rate=0.001)
+
+# Learner with IDBD meta-learned step-sizes
+learner = LinearLearner(optimizer=IDBD())
+
+# JIT-compiled training via jax.lax.scan
+state, metrics = run_learning_loop(learner, stream, num_steps=10000, key=jr.key(42))
+```
+
+References
+----------
+- The Alberta Plan for AI Research (Sutton et al., 2022): https://arxiv.org/abs/2208.11173
+- Adapting Bias by Gradient Descent (Sutton, 1992)
+- Tuning-free Step-size Adaptation (Mahmood et al., 2012)
 """

 __version__ = "0.1.0"
@@ -50,6 +67,8 @@ from alberta_framework.core.types import (
     IDBDState,
     LearnerState,
     LMSState,
+    NormalizerHistory,
+    NormalizerTrackingConfig,
     Observation,
     Prediction,
     StepSizeHistory,
@@ -122,7 +141,9 @@ __all__ = [
     "IDBDState",
     "LMSState",
     "LearnerState",
+    "NormalizerHistory",
     "NormalizerState",
+    "NormalizerTrackingConfig",
     "Observation",
     "Prediction",
     "StepSizeHistory",
alberta_framework/core/learners.py
CHANGED

@@ -18,6 +18,8 @@ from alberta_framework.core.types import (
     IDBDState,
     LearnerState,
     LMSState,
+    NormalizerHistory,
+    NormalizerTrackingConfig,
     Observation,
     Prediction,
     StepSizeHistory,
@@ -48,7 +50,7 @@ class UpdateResult(NamedTuple):
 class LinearLearner:
     """Linear function approximator with pluggable optimizer.

-    Computes predictions as: y = w @ x + b
+    Computes predictions as: `y = w @ x + b`

     The learner maintains weights and bias, delegating the adaptation
     of learning rates to the optimizer (e.g., LMS or IDBD).
@@ -93,7 +95,7 @@ class LinearLearner:
             observation: Input feature vector

         Returns:
-            Scalar prediction y = w @ x + b
+            Scalar prediction `y = w @ x + b`
         """
         return jnp.atleast_1d(jnp.dot(state.weights, observation) + state.bias)

@@ -238,10 +240,25 @@ def run_learning_loop[StreamStateT](
     )
     recording_indices = jnp.zeros(num_recordings, dtype=jnp.int32)

+    # Check if we need to track Autostep normalizers
+    # We detect this at trace time by checking the initial optimizer state
+    track_normalizers = hasattr(learner_state.optimizer_state, "normalizers")
+    normalizer_history = (
+        jnp.zeros((num_recordings, feature_dim), dtype=jnp.float32)
+        if track_normalizers
+        else None
+    )
+
     def step_fn_with_tracking(
-        carry: tuple[
+        carry: tuple[
+            LearnerState, StreamStateT, Array, Array | None, Array, Array | None
+        ],
+        idx: Array,
+    ) -> tuple[
+        tuple[LearnerState, StreamStateT, Array, Array | None, Array, Array | None],
+        Array,
+    ]:
+        l_state, s_state, ss_history, b_history, rec_indices, norm_history = carry

         # Perform learning step
         timestep, new_s_state = stream.step(s_state, idx)
@@ -291,12 +308,25 @@ def run_learning_loop[StreamStateT](
             None,
         )

+        # Track Autostep normalizers (v_i) if applicable
+        new_norm_history = norm_history
+        if norm_history is not None and hasattr(opt_state, "normalizers"):
+            new_norm_history = jax.lax.cond(
+                should_record,
+                lambda _: norm_history.at[recording_idx].set(
+                    opt_state.normalizers  # type: ignore[union-attr]
+                ),
+                lambda _: norm_history,
+                None,
+            )
+
         return (
             result.state,
             new_s_state,
             new_ss_history,
             new_b_history,
             new_rec_indices,
+            new_norm_history,
         ), result.metrics

     initial_carry = (
@@ -305,16 +335,25 @@ def run_learning_loop[StreamStateT](
         step_size_history,
         bias_history,
         recording_indices,
+        normalizer_history,
     )

-    (
+    (
+        final_learner,
+        _,
+        final_ss_history,
+        final_b_history,
+        final_rec_indices,
+        final_norm_history,
+    ), metrics = jax.lax.scan(
+        step_fn_with_tracking, initial_carry, jnp.arange(num_steps)
     )

     history = StepSizeHistory(
         step_sizes=final_ss_history,
         bias_step_sizes=final_b_history,
         recording_indices=final_rec_indices,
+        normalizers=final_norm_history,
     )

     return final_learner, metrics, history
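The normalizer tracking above hinges on a trace-time `hasattr` check against the initial optimizer state: because that state is a concrete NamedTuple rather than a traced value, the check specializes the `jax.lax.scan` body once per optimizer type, and a `None` history simply drops out of the carry. A minimal standalone sketch of that pattern, with toy state types that only stand in for the framework's real `LMSState`/`AutostepState`:

```python
from typing import NamedTuple
import jax
import jax.numpy as jnp

class ToyLMSState(NamedTuple):       # stand-in for a fixed step-size optimizer state
    step_size: jnp.ndarray

class ToyAutostepState(NamedTuple):  # stand-in for a state that carries per-weight normalizers
    step_sizes: jnp.ndarray
    normalizers: jnp.ndarray

def run(opt_state, xs):
    # Trace-time branch: plain Python hasattr on the NamedTuple, not on traced arrays,
    # so the scan body is specialized once per optimizer type.
    track = hasattr(opt_state, "normalizers")
    history = jnp.zeros((xs.shape[0], 3)) if track else None

    def body(carry, i):
        state, hist = carry
        # Only touched when a history array was allocated at trace time
        new_hist = hist if hist is None else hist.at[i].set(state.normalizers)
        return (state, new_hist), i

    (_, hist), _ = jax.lax.scan(body, (opt_state, history), jnp.arange(xs.shape[0]))
    return hist

print(run(ToyLMSState(jnp.array(0.1)), jnp.zeros(4)))                       # None
print(run(ToyAutostepState(jnp.ones(3), jnp.ones(3)), jnp.zeros(4)).shape)  # (4, 3)
```

The same idea is what lets the loop allocate `normalizer_history` only when an Autostep-style state is present, at no cost for LMS or IDBD.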
@@ -473,7 +512,14 @@ def run_normalized_learning_loop[StreamStateT](
     num_steps: int,
     key: Array,
     learner_state: NormalizedLearnerState | None = None,
+    step_size_tracking: StepSizeTrackingConfig | None = None,
+    normalizer_tracking: NormalizerTrackingConfig | None = None,
+) -> (
+    tuple[NormalizedLearnerState, Array]
+    | tuple[NormalizedLearnerState, Array, StepSizeHistory]
+    | tuple[NormalizedLearnerState, Array, NormalizerHistory]
+    | tuple[NormalizedLearnerState, Array, StepSizeHistory, NormalizerHistory]
+):
     """Run the learning loop with normalization using jax.lax.scan.

     Args:
@@ -482,29 +528,322 @@
         num_steps: Number of learning steps to run
         key: JAX random key for stream initialization
         learner_state: Initial state (if None, will be initialized from stream)
+        step_size_tracking: Optional config for recording per-weight step-sizes.
+            When provided, returns StepSizeHistory including Autostep normalizers if applicable.
+        normalizer_tracking: Optional config for recording per-feature normalizer state.
+            When provided, returns NormalizerHistory with means and variances over time.

     Returns:
+        If no tracking:
+            Tuple of (final_state, metrics_array) where metrics_array has shape
+            (num_steps, 4) with columns [squared_error, error, mean_step_size, normalizer_mean_var]
+        If step_size_tracking only:
+            Tuple of (final_state, metrics_array, step_size_history)
+        If normalizer_tracking only:
+            Tuple of (final_state, metrics_array, normalizer_history)
+        If both:
+            Tuple of (final_state, metrics_array, step_size_history, normalizer_history)
+
+    Raises:
+        ValueError: If tracking interval is invalid
     """
+    # Validate tracking configs
+    if step_size_tracking is not None:
+        if step_size_tracking.interval < 1:
+            raise ValueError(
+                f"step_size_tracking.interval must be >= 1, got {step_size_tracking.interval}"
+            )
+        if step_size_tracking.interval > num_steps:
+            raise ValueError(
+                f"step_size_tracking.interval ({step_size_tracking.interval}) "
+                f"must be <= num_steps ({num_steps})"
+            )
+
+    if normalizer_tracking is not None:
+        if normalizer_tracking.interval < 1:
+            raise ValueError(
+                f"normalizer_tracking.interval must be >= 1, got {normalizer_tracking.interval}"
+            )
+        if normalizer_tracking.interval > num_steps:
+            raise ValueError(
+                f"normalizer_tracking.interval ({normalizer_tracking.interval}) "
+                f"must be <= num_steps ({num_steps})"
+            )
+
     # Initialize states
     if learner_state is None:
         learner_state = learner.init(stream.feature_dim)
     stream_state = stream.init(key)

+    feature_dim = stream.feature_dim
+
+    # No tracking - simple case
+    if step_size_tracking is None and normalizer_tracking is None:
+
+        def step_fn(
+            carry: tuple[NormalizedLearnerState, StreamStateT], idx: Array
+        ) -> tuple[tuple[NormalizedLearnerState, StreamStateT], Array]:
+            l_state, s_state = carry
+            timestep, new_s_state = stream.step(s_state, idx)
+            result = learner.update(l_state, timestep.observation, timestep.target)
+            return (result.state, new_s_state), result.metrics
+
+        (final_learner, _), metrics = jax.lax.scan(
+            step_fn, (learner_state, stream_state), jnp.arange(num_steps)
+        )
+
+        return final_learner, metrics
+
+    # Tracking enabled - need to set up history arrays
+    ss_interval = step_size_tracking.interval if step_size_tracking else num_steps + 1
+    norm_interval = (
+        normalizer_tracking.interval if normalizer_tracking else num_steps + 1
+    )
+
+    ss_num_recordings = num_steps // ss_interval if step_size_tracking else 0
+    norm_num_recordings = num_steps // norm_interval if normalizer_tracking else 0
+
+    # Pre-allocate step-size history arrays
+    ss_history = (
+        jnp.zeros((ss_num_recordings, feature_dim), dtype=jnp.float32)
+        if step_size_tracking
+        else None
+    )
+    ss_bias_history = (
+        jnp.zeros(ss_num_recordings, dtype=jnp.float32)
+        if step_size_tracking and step_size_tracking.include_bias
+        else None
+    )
+    ss_rec_indices = (
+        jnp.zeros(ss_num_recordings, dtype=jnp.int32) if step_size_tracking else None
+    )
+
+    # Check if we need to track Autostep normalizers
+    track_autostep_normalizers = hasattr(
+        learner_state.learner_state.optimizer_state, "normalizers"
+    )
+    ss_normalizers = (
+        jnp.zeros((ss_num_recordings, feature_dim), dtype=jnp.float32)
+        if step_size_tracking and track_autostep_normalizers
+        else None
+    )
+
+    # Pre-allocate normalizer state history arrays
+    norm_means = (
+        jnp.zeros((norm_num_recordings, feature_dim), dtype=jnp.float32)
+        if normalizer_tracking
+        else None
+    )
+    norm_vars = (
+        jnp.zeros((norm_num_recordings, feature_dim), dtype=jnp.float32)
+        if normalizer_tracking
+        else None
+    )
+    norm_rec_indices = (
+        jnp.zeros(norm_num_recordings, dtype=jnp.int32) if normalizer_tracking else None
+    )
+
+    def step_fn_with_tracking(
+        carry: tuple[
+            NormalizedLearnerState,
+            StreamStateT,
+            Array | None,
+            Array | None,
+            Array | None,
+            Array | None,
+            Array | None,
+            Array | None,
+            Array | None,
+        ],
+        idx: Array,
+    ) -> tuple[
+        tuple[
+            NormalizedLearnerState,
+            StreamStateT,
+            Array | None,
+            Array | None,
+            Array | None,
+            Array | None,
+            Array | None,
+            Array | None,
+            Array | None,
+        ],
+        Array,
+    ]:
+        (
+            l_state,
+            s_state,
+            ss_hist,
+            ss_bias_hist,
+            ss_rec,
+            ss_norm,
+            n_means,
+            n_vars,
+            n_rec,
+        ) = carry
+
+        # Perform learning step
|
|
500
687
|
result = learner.update(l_state, timestep.observation, timestep.target)
|
|
501
|
-
return (result.state, new_s_state), result.metrics
|
|
502
688
|
|
|
503
|
-
|
|
504
|
-
|
|
689
|
+
# Step-size tracking
|
|
690
|
+
new_ss_hist = ss_hist
|
|
691
|
+
new_ss_bias_hist = ss_bias_hist
|
|
692
|
+
new_ss_rec = ss_rec
|
|
693
|
+
new_ss_norm = ss_norm
|
|
694
|
+
|
|
695
|
+
if ss_hist is not None:
|
|
696
|
+
should_record_ss = (idx % ss_interval) == 0
|
|
697
|
+
recording_idx = idx // ss_interval
|
|
698
|
+
|
|
699
|
+
# Extract current step-sizes from the inner learner state
|
|
700
|
+
opt_state = result.state.learner_state.optimizer_state
|
|
701
|
+
if hasattr(opt_state, "log_step_sizes"):
|
|
702
|
+
# IDBD stores log step-sizes
|
|
703
|
+
weight_ss = jnp.exp(opt_state.log_step_sizes) # type: ignore[union-attr]
|
|
704
|
+
bias_ss = opt_state.bias_step_size # type: ignore[union-attr]
|
|
705
|
+
elif hasattr(opt_state, "step_sizes"):
|
|
706
|
+
# Autostep stores step-sizes directly
|
|
707
|
+
weight_ss = opt_state.step_sizes # type: ignore[union-attr]
|
|
708
|
+
bias_ss = opt_state.bias_step_size # type: ignore[union-attr]
|
|
709
|
+
else:
|
|
710
|
+
# LMS has a single fixed step-size
|
|
711
|
+
weight_ss = jnp.full(feature_dim, opt_state.step_size)
|
|
712
|
+
bias_ss = opt_state.step_size
|
|
713
|
+
|
|
714
|
+
new_ss_hist = jax.lax.cond(
|
|
715
|
+
should_record_ss,
|
|
716
|
+
lambda _: ss_hist.at[recording_idx].set(weight_ss),
|
|
717
|
+
lambda _: ss_hist,
|
|
718
|
+
None,
|
|
719
|
+
)
|
|
720
|
+
|
|
721
|
+
if ss_bias_hist is not None:
|
|
722
|
+
new_ss_bias_hist = jax.lax.cond(
|
|
723
|
+
should_record_ss,
|
|
724
|
+
lambda _: ss_bias_hist.at[recording_idx].set(bias_ss),
|
|
725
|
+
lambda _: ss_bias_hist,
|
|
726
|
+
None,
|
|
727
|
+
)
|
|
728
|
+
|
|
729
|
+
if ss_rec is not None:
|
|
730
|
+
new_ss_rec = jax.lax.cond(
|
|
731
|
+
should_record_ss,
|
|
732
|
+
lambda _: ss_rec.at[recording_idx].set(idx),
|
|
733
|
+
lambda _: ss_rec,
|
|
734
|
+
None,
|
|
735
|
+
)
|
|
736
|
+
|
|
737
|
+
# Track Autostep normalizers (v_i) if applicable
|
|
738
|
+
if ss_norm is not None and hasattr(opt_state, "normalizers"):
|
|
739
|
+
new_ss_norm = jax.lax.cond(
|
|
740
|
+
should_record_ss,
|
|
741
|
+
lambda _: ss_norm.at[recording_idx].set(
|
|
742
|
+
opt_state.normalizers # type: ignore[union-attr]
|
|
743
|
+
),
|
|
744
|
+
lambda _: ss_norm,
|
|
745
|
+
None,
|
|
746
|
+
)
|
|
747
|
+
|
|
748
|
+
# Normalizer state tracking
|
|
749
|
+
new_n_means = n_means
|
|
750
|
+
new_n_vars = n_vars
|
|
751
|
+
new_n_rec = n_rec
|
|
752
|
+
|
|
753
|
+
if n_means is not None:
|
|
754
|
+
should_record_norm = (idx % norm_interval) == 0
|
|
755
|
+
norm_recording_idx = idx // norm_interval
|
|
756
|
+
|
|
757
|
+
norm_state = result.state.normalizer_state
|
|
758
|
+
|
|
759
|
+
new_n_means = jax.lax.cond(
|
|
760
|
+
should_record_norm,
|
|
761
|
+
lambda _: n_means.at[norm_recording_idx].set(norm_state.mean),
|
|
762
|
+
lambda _: n_means,
|
|
763
|
+
None,
|
|
764
|
+
)
|
|
765
|
+
|
|
766
|
+
if n_vars is not None:
|
|
767
|
+
new_n_vars = jax.lax.cond(
|
|
768
|
+
should_record_norm,
|
|
769
|
+
lambda _: n_vars.at[norm_recording_idx].set(norm_state.var),
|
|
770
|
+
lambda _: n_vars,
|
|
771
|
+
None,
|
|
772
|
+
)
|
|
773
|
+
|
|
774
|
+
if n_rec is not None:
|
|
775
|
+
new_n_rec = jax.lax.cond(
|
|
776
|
+
should_record_norm,
|
|
777
|
+
lambda _: n_rec.at[norm_recording_idx].set(idx),
|
|
778
|
+
lambda _: n_rec,
|
|
779
|
+
None,
|
|
780
|
+
)
|
|
781
|
+
|
|
782
|
+
return (
|
|
783
|
+
result.state,
|
|
784
|
+
new_s_state,
|
|
785
|
+
new_ss_hist,
|
|
786
|
+
new_ss_bias_hist,
|
|
787
|
+
new_ss_rec,
|
|
788
|
+
new_ss_norm,
|
|
789
|
+
new_n_means,
|
|
790
|
+
new_n_vars,
|
|
791
|
+
new_n_rec,
|
|
792
|
+
), result.metrics
|
|
793
|
+
|
|
794
|
+
initial_carry = (
|
|
795
|
+
learner_state,
|
|
796
|
+
stream_state,
|
|
797
|
+
ss_history,
|
|
798
|
+
ss_bias_history,
|
|
799
|
+
ss_rec_indices,
|
|
800
|
+
ss_normalizers,
|
|
801
|
+
norm_means,
|
|
802
|
+
norm_vars,
|
|
803
|
+
norm_rec_indices,
|
|
804
|
+
)
|
|
805
|
+
|
|
806
|
+
(
|
|
807
|
+
final_learner,
|
|
808
|
+
_,
|
|
809
|
+
final_ss_hist,
|
|
810
|
+
final_ss_bias_hist,
|
|
811
|
+
final_ss_rec,
|
|
812
|
+
final_ss_norm,
|
|
813
|
+
final_n_means,
|
|
814
|
+
final_n_vars,
|
|
815
|
+
final_n_rec,
|
|
816
|
+
), metrics = jax.lax.scan(
|
|
817
|
+
step_fn_with_tracking, initial_carry, jnp.arange(num_steps)
|
|
505
818
|
)
|
|
506
819
|
|
|
507
|
-
return
|
|
820
|
+
# Build return values based on what was tracked
|
|
821
|
+
ss_history_result = None
|
|
822
|
+
if step_size_tracking is not None and final_ss_hist is not None:
|
|
823
|
+
ss_history_result = StepSizeHistory(
|
|
824
|
+
step_sizes=final_ss_hist,
|
|
825
|
+
bias_step_sizes=final_ss_bias_hist,
|
|
826
|
+
recording_indices=final_ss_rec, # type: ignore[arg-type]
|
|
827
|
+
normalizers=final_ss_norm,
|
|
828
|
+
)
|
|
829
|
+
|
|
830
|
+
norm_history_result = None
|
|
831
|
+
if normalizer_tracking is not None and final_n_means is not None:
|
|
832
|
+
norm_history_result = NormalizerHistory(
|
|
833
|
+
means=final_n_means,
|
|
834
|
+
variances=final_n_vars, # type: ignore[arg-type]
|
|
835
|
+
recording_indices=final_n_rec, # type: ignore[arg-type]
|
|
836
|
+
)
|
|
837
|
+
|
|
838
|
+
# Return appropriate tuple based on what was tracked
|
|
839
|
+
if ss_history_result is not None and norm_history_result is not None:
|
|
840
|
+
return final_learner, metrics, ss_history_result, norm_history_result
|
|
841
|
+
elif ss_history_result is not None:
|
|
842
|
+
return final_learner, metrics, ss_history_result
|
|
843
|
+
elif norm_history_result is not None:
|
|
844
|
+
return final_learner, metrics, norm_history_result
|
|
845
|
+
else:
|
|
846
|
+
return final_learner, metrics
|
|
508
847
|
|
|
509
848
|
|
|
510
849
|
def metrics_to_dicts(metrics: Array, normalized: bool = False) -> list[dict[str, float]]:
|
|
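Since the return arity of `run_normalized_learning_loop` now depends on which tracking configs are passed, a call that requests both histories unpacks four values. A sketch of the intended usage; it assumes these names are importable from the package top level and that `NormalizedLinearLearner` accepts an `optimizer` argument like `LinearLearner` does (only `StepSizeTrackingConfig.interval`/`include_bias` and `NormalizerTrackingConfig.interval` are confirmed by this diff):

```python
import jax.random as jr
from alberta_framework import (
    IDBD,
    NormalizedLinearLearner,
    NormalizerTrackingConfig,
    RandomWalkStream,
    StepSizeTrackingConfig,
    run_normalized_learning_loop,
)

stream = RandomWalkStream(feature_dim=10, drift_rate=0.001)
learner = NormalizedLinearLearner(optimizer=IDBD())  # assumed constructor shape

# Both tracking configs supplied -> four-element result
state, metrics, ss_history, norm_history = run_normalized_learning_loop(
    learner,
    stream,
    num_steps=10_000,
    key=jr.key(0),
    step_size_tracking=StepSizeTrackingConfig(interval=100, include_bias=True),
    normalizer_tracking=NormalizerTrackingConfig(interval=100),
)

ss_history.step_sizes.shape   # (100, 10): one row of per-weight step-sizes per recording
norm_history.means.shape      # (100, 10): per-feature mean estimates over time
```

With `interval=100` over 10,000 steps each history holds 100 recordings; the `recording_indices` field gives the step index of each row.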
alberta_framework/core/normalizers.py
CHANGED

@@ -35,7 +35,7 @@ class OnlineNormalizer:
     """Online feature normalizer for continual learning.

     Normalizes features using running estimates of mean and standard deviation:
+    `x_normalized = (x - mean) / (std + epsilon)`

     The normalizer updates its estimates at every time step, following
     temporal uniformity. Uses exponential moving average for non-stationary
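The normalization rule quoted in that docstring is a per-feature exponential-moving-average update. A small self-contained sketch of the computation (the `decay` and `epsilon` names here are illustrative and are not necessarily the constructor parameters of `OnlineNormalizer`):

```python
import jax.numpy as jnp

def ema_normalize(x, mean, var, decay=0.99, epsilon=1e-8):
    """One temporally uniform normalizer step: update the estimates, then normalize."""
    new_mean = decay * mean + (1.0 - decay) * x            # running mean estimate
    new_var = decay * var + (1.0 - decay) * (x - new_mean) ** 2  # running variance estimate
    x_normalized = (x - new_mean) / (jnp.sqrt(new_var) + epsilon)
    return x_normalized, new_mean, new_var

x = jnp.array([1.0, 100.0, 0.01])
x_n, mean, var = ema_normalize(x, jnp.zeros(3), jnp.ones(3))
```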
alberta_framework/core/optimizers.py
CHANGED

@@ -72,7 +72,7 @@ class Optimizer[StateT: (LMSState, IDBDState, AutostepState)](ABC):
 class LMS(Optimizer[LMSState]):
     """Least Mean Square optimizer with fixed step-size.

-    The simplest gradient-based optimizer: w_{t+1} = w_t + alpha * delta * x_t
+    The simplest gradient-based optimizer: `w_{t+1} = w_t + alpha * delta * x_t`

     This serves as a baseline. The challenge is that the optimal step-size
     depends on the problem and changes as the task becomes non-stationary.
@@ -108,7 +108,7 @@ class LMS(Optimizer[LMSState]):
     ) -> OptimizerUpdate:
         """Compute LMS weight update.

-        Update rule: delta_w = alpha * error * x
+        Update rule: `delta_w = alpha * error * x`

         Args:
             state: Current LMS state
@@ -195,10 +195,11 @@ class IDBD(Optimizer[IDBDState]):
         """Compute IDBD weight update with adaptive step-sizes.

         The IDBD algorithm:
+
+        1. Compute step-sizes: `alpha_i = exp(log_alpha_i)`
+        2. Update weights: `w_i += alpha_i * error * x_i`
+        3. Update log step-sizes: `log_alpha_i += beta * error * x_i * h_i`
+        4. Update traces: `h_i = h_i * max(0, 1 - alpha_i * x_i^2) + alpha_i * error * x_i`

         The trace h_i tracks the correlation between current and past gradients.
         When gradients consistently point the same direction, h_i grows,
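Written out as code, the four documented IDBD steps correspond to the following per-feature update. This is a plain sketch of the rule exactly as listed in the docstring, not the framework's internal `update` method; `beta` is the meta step-size:

```python
import jax.numpy as jnp

def idbd_step(w, log_alpha, h, x, target, beta=0.01):
    """One IDBD update for a linear predictor y = w @ x (bias term omitted for brevity)."""
    error = target - jnp.dot(w, x)                      # prediction error (delta)
    alpha = jnp.exp(log_alpha)                          # 1. per-weight step-sizes
    new_w = w + alpha * error * x                       # 2. weight update
    new_log_alpha = log_alpha + beta * error * x * h    # 3. meta-learned step-size update
    new_h = h * jnp.maximum(0.0, 1.0 - alpha * x**2) + alpha * error * x  # 4. trace update
    return new_w, new_log_alpha, new_h

w, log_alpha, h = idbd_step(
    jnp.zeros(3), jnp.log(jnp.full(3, 0.01)), jnp.zeros(3),
    x=jnp.array([1.0, 0.5, -0.2]), target=1.0,
)
```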
@@ -335,12 +336,13 @@ class Autostep(Optimizer[AutostepState]):
         """Compute Autostep weight update with normalized gradients.

         The Autostep algorithm:
+
+        1. Compute gradient: `g_i = error * x_i`
+        2. Normalize gradient: `g_i' = g_i / max(|g_i|, v_i)`
+        3. Update weights: `w_i += alpha_i * g_i'`
+        4. Update step-sizes: `alpha_i *= exp(mu * g_i' * h_i)`
+        5. Update traces: `h_i = h_i * (1 - alpha_i) + alpha_i * g_i'`
+        6. Update normalizers: `v_i = max(|g_i|, v_i * tau)`

         Args:
             state: Current Autostep state
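The six Autostep steps read the same way; `mu` is the meta step-size and `tau` the normalizer decay. Again this is a sketch of the rule as documented above rather than the framework's implementation, which also carries the per-weight normalizers `v_i` that the new `StepSizeHistory.normalizers` field records:

```python
import jax.numpy as jnp

def autostep_step(w, alpha, h, v, x, target, mu=0.01, tau=0.9999):
    """One Autostep update as documented above (bias term omitted; v should start > 0)."""
    error = target - jnp.dot(w, x)
    g = error * x                                    # 1. per-weight gradient
    g_norm = g / jnp.maximum(jnp.abs(g), v)          # 2. normalized gradient
    new_w = w + alpha * g_norm                       # 3. weight update
    new_alpha = alpha * jnp.exp(mu * g_norm * h)     # 4. multiplicative step-size update
    new_h = h * (1.0 - new_alpha) + new_alpha * g_norm   # 5. trace update
    new_v = jnp.maximum(jnp.abs(g), v * tau)         # 6. normalizer update (the tracked v_i)
    return new_w, new_alpha, new_h, new_v
```

The `new_v` value here is exactly the quantity the tracking code snapshots into `normalizers` at each recording interval.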
alberta_framework/core/types.py
CHANGED
@@ -126,11 +126,42 @@ class StepSizeHistory(NamedTuple):
         step_sizes: Per-weight step-sizes at each recording, shape (num_recordings, num_weights)
         bias_step_sizes: Bias step-sizes at each recording, shape (num_recordings,) or None
         recording_indices: Step indices where recordings were made, shape (num_recordings,)
+        normalizers: Autostep's per-weight normalizers (v_i) at each recording,
+            shape (num_recordings, num_weights) or None. Only populated for Autostep optimizer.
     """

     step_sizes: Array  # (num_recordings, num_weights)
     bias_step_sizes: Array | None  # (num_recordings,) or None
     recording_indices: Array  # (num_recordings,)
+    normalizers: Array | None = None  # (num_recordings, num_weights) - Autostep v_i
+
+
+class NormalizerTrackingConfig(NamedTuple):
+    """Configuration for recording per-feature normalizer state during training.
+
+    Attributes:
+        interval: Record normalizer state every N steps
+    """
+
+    interval: int
+
+
+class NormalizerHistory(NamedTuple):
+    """History of per-feature normalizer state recorded during training.
+
+    Used for analyzing how the OnlineNormalizer adapts to distribution shifts
+    (reactive lag diagnostic).
+
+    Attributes:
+        means: Per-feature mean estimates at each recording, shape (num_recordings, feature_dim)
+        variances: Per-feature variance estimates at each recording,
+            shape (num_recordings, feature_dim)
+        recording_indices: Step indices where recordings were made, shape (num_recordings,)
+    """
+
+    means: Array  # (num_recordings, feature_dim)
+    variances: Array  # (num_recordings, feature_dim)
+    recording_indices: Array  # (num_recordings,)


 def create_lms_state(step_size: float = 0.01) -> LMSState:
alberta_framework/streams/base.py
CHANGED

@@ -30,11 +30,14 @@ class ScanStream(Protocol[StateT]):
     Type Parameters:
         StateT: The state type maintained by this stream

+    Examples
+    --------
+    ```python
+    stream = RandomWalkStream(feature_dim=10, drift_rate=0.001)
+    key = jax.random.key(42)
+    state = stream.init(key)
+    timestep, new_state = stream.step(state, jnp.array(0))
+    ```
     """

     @property
alberta_framework/streams/synthetic.py
CHANGED

@@ -32,7 +32,7 @@ class RandomWalkState(NamedTuple):
 class RandomWalkStream:
     """Non-stationary stream where target weights drift via random walk.

-    The true target function is linear: y* = w_true @ x + noise
+    The true target function is linear: `y* = w_true @ x + noise`
     where w_true evolves via random walk at each time step.

     This tests the learner's ability to continuously track a moving target.
@@ -590,12 +590,15 @@ class ScaledStreamWrapper:
     scale factor. Useful for testing how learners handle features at different
     scales, which is important for understanding normalization benefits.

+    Examples
+    --------
+    ```python
+    stream = ScaledStreamWrapper(
+        AbruptChangeStream(feature_dim=10, change_interval=1000),
+        feature_scales=jnp.array([0.001, 0.01, 0.1, 1.0, 10.0,
+                                  100.0, 1000.0, 0.001, 0.01, 0.1])
+    )
+    ```

     Attributes:
         inner_stream: The wrapped stream instance
@@ -693,9 +696,12 @@ def make_scale_range(
     Returns:
         Array of shape (feature_dim,) with scale factors

+    Examples
+    --------
+    ```python
+    scales = make_scale_range(10, min_scale=0.01, max_scale=100.0)
+    stream = ScaledStreamWrapper(RandomWalkStream(10), scales)
+    ```
     """
     if log_spaced:
         return jnp.logspace(
alberta_framework/utils/experiments.py
CHANGED

@@ -110,13 +110,14 @@ def run_single_experiment(

     final_state: LearnerState | NormalizedLearnerState
     if isinstance(learner, NormalizedLinearLearner):
+        norm_result = run_normalized_learning_loop(
             learner, stream, config.num_steps, key
         )
+        final_state, metrics = cast(tuple[NormalizedLearnerState, Any], norm_result)
         metrics_history = metrics_to_dicts(metrics, normalized=True)
     else:
-        final_state, metrics = cast(tuple[LearnerState, Any],
+        linear_result = run_learning_loop(learner, stream, config.num_steps, key)
+        final_state, metrics = cast(tuple[LearnerState, Any], linear_result)
         metrics_history = metrics_to_dicts(metrics)

     return SingleRunResult(
alberta_framework/utils/timing.py
CHANGED

@@ -3,19 +3,22 @@
 This module provides a simple Timer context manager for measuring execution time
 and formatting durations in a human-readable format.

+Examples
+--------
+```python
+from alberta_framework.utils.timing import Timer
+
+with Timer("Training"):
+    # run training code
+    pass
+# Output: Training completed in 1.23s
+
+# Or capture the duration:
+with Timer("Experiment") as t:
+    # run experiment
+    pass
+print(f"Took {t.duration:.2f} seconds")
+```
 """

 import time
@@ -32,13 +35,13 @@ def format_duration(seconds: float) -> str:
     Returns:
         Formatted string like "1.23s", "2m 30.5s", or "1h 5m 30s"

-    Examples
+    Examples
+    --------
+    ```python
+    format_duration(0.5)  # Returns: '0.50s'
+    format_duration(90.5)  # Returns: '1m 30.50s'
+    format_duration(3665)  # Returns: '1h 1m 5.00s'
+    ```
     """
     if seconds < 60:
         return f"{seconds:.2f}s"
@@ -66,22 +69,25 @@ class Timer:
         start_time: Timestamp when timing started
         end_time: Timestamp when timing ended

+    Examples
+    --------
+    ```python
+    with Timer("Training loop"):
+        for i in range(1000):
+            pass
+    # Output: Training loop completed in 0.01s
+
+    # Silent timing (no print):
+    with Timer("Silent", verbose=False) as t:
+        time.sleep(0.1)
+    print(f"Elapsed: {t.duration:.2f}s")
+    # Output: Elapsed: 0.10s
+
+    # Custom print function:
+    with Timer("Custom", print_fn=lambda msg: print(f">> {msg}")):
+        pass
+    # Output: >> Custom completed in 0.00s
+    ```
     """

     def __init__(
{alberta_framework-0.1.0.dist-info → alberta_framework-0.1.1.dist-info}/METADATA
CHANGED

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: alberta-framework
-Version: 0.1.0
+Version: 0.1.1
 Summary: Implementation of the Alberta Plan for AI Research - continual learning with meta-learned step-sizes
 Project-URL: Homepage, https://github.com/j-klawson/alberta-framework
 Project-URL: Repository, https://github.com/j-klawson/alberta-framework
@@ -49,7 +49,7 @@ Description-Content-Type: text/markdown
 [](https://opensource.org/licenses/Apache-2.0)
 [](https://www.python.org/downloads/)

-A JAX-based research framework implementing components of [The Alberta Plan](https://arxiv.org/abs/2208.11173) in the pursuit of building the foundations of Continual AI.
+A JAX-based research framework implementing components of [The Alberta Plan for AI Research](https://arxiv.org/abs/2208.11173) in the pursuit of building the foundations of Continual AI.

 > "The agents are complex only because they interact with a complex world... their initial design is as simple, general, and scalable as possible." — *Sutton et al., 2022*

@@ -57,6 +57,14 @@

 The Alberta Framework provides foundational components for continual reinforcement learning research. Built on JAX for hardware acceleration, the framework emphasizes temporal uniformity — every component updates at every time step, with no special training phases or batch processing.

+## Project Context
+
+This framework is developed as part of my D.Eng. work focusing on the foundations of Continual AI. For more background and context see:
+
+* **Research Blog**: [blog.9600baud.net](https://blog.9600baud.net)
+* **Replicating Sutton '92**: [The Foundation of Step-size Adaptation](https://blog.9600baud.net/sutton92.html)
+* **About the Author**: [Keith Lawson](https://blog.9600baud.net/about.html)
+
 ### Roadmap

 Depending on my research trajectory I may or may not implement components required for the plan. The current focus of this framework is the Step 1 Baseline Study, investigating the interaction between adaptive optimizers and online normalization.
alberta_framework-0.1.1.dist-info/RECORD
ADDED

@@ -0,0 +1,22 @@
+alberta_framework/__init__.py,sha256=LUrsm6WFh5-Mxg78d1G-Qe015nkGgcCDhSw5lf3UkFo,5460
+alberta_framework/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+alberta_framework/core/__init__.py,sha256=PSrC4zSxgm_6YXWEQ80aZaunpbQ58QexxKmDDU-jp6c,522
+alberta_framework/core/learners.py,sha256=dnRQ5B16oGYpamDJIRYzR54ED9bvW0lpa8c_suC6YBA,29879
+alberta_framework/core/normalizers.py,sha256=Z_d3H17qoXh87DE7k41imvWzkVJQ2xQgDUP7GYSNzAY,5903
+alberta_framework/core/optimizers.py,sha256=OefVuDDG1phh1QQIUyVPsQckl41VrpWFG7hY2eqyc64,14585
+alberta_framework/core/types.py,sha256=mtpVEr2qJ0XzZyjOsUdChmS7T7mrXBDMHb-jfkrT9JY,7503
+alberta_framework/streams/__init__.py,sha256=bsDgWjWjotDQHMI2lno3dgk8N14pd-2mYAQpXAtCPx4,2035
+alberta_framework/streams/base.py,sha256=9rJxvUgmzd5u2bRV4vi5PxhUvj39EZTD4bZHo-Ptn-U,2168
+alberta_framework/streams/gymnasium.py,sha256=s733X7aEgy05hcSazjZEhBiJChtEL7uVpxwh0fXBQZA,21980
+alberta_framework/streams/synthetic.py,sha256=4R9GR7Kh0LT7GmGtPhzMJGr8HbhrAMUOjvPwLZk6nDg,32979
+alberta_framework/utils/__init__.py,sha256=zfKfnbikhLp0J6UgVa8HeRo59gZHwqOc8jf03s7AaT4,2845
+alberta_framework/utils/experiments.py,sha256=ekGAzveCRgv9YZ5mfAD5Uf7h_PvQnxsNw2KeZN2eu00,10644
+alberta_framework/utils/export.py,sha256=W9RKfeTiyZcLColOGNjBfZU0N6QMXrfPn4pdYcm-OSk,15832
+alberta_framework/utils/metrics.py,sha256=1cryNJoboO67vvRhausaucbYZFgdL_06vaf08UXbojg,3349
+alberta_framework/utils/statistics.py,sha256=4fbzNlmsdUaM5lLW1BhL5B5MUpnqimQlwJklZ4x0y0U,15416
+alberta_framework/utils/timing.py,sha256=JOLq8CpCAV7LWOWkftxefduSFjaXnVwal1MFBKEMdJI,4049
+alberta_framework/utils/visualization.py,sha256=PmKBD3KGabNhgDizcNiGJEbVCyDL1YMUE5yTwgJHu2o,17924
+alberta_framework-0.1.1.dist-info/METADATA,sha256=Ny-LxHiqZVNXZbu5f8ZyBSLCEZd2KsBhA9iROV7tNiU,7763
+alberta_framework-0.1.1.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
+alberta_framework-0.1.1.dist-info/licenses/LICENSE,sha256=TI1avodt5mvxz7sunyxIa0HlNgLQcmKNLeRjCVcgKmE,10754
+alberta_framework-0.1.1.dist-info/RECORD,,
alberta_framework-0.1.0.dist-info/RECORD
REMOVED

@@ -1,22 +0,0 @@
-alberta_framework/__init__.py,sha256=gPLBA2EiPcElsYp_U_Rs7C6wlrGHr8w5IL6C0F90zec,4739
-alberta_framework/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-alberta_framework/core/__init__.py,sha256=PSrC4zSxgm_6YXWEQ80aZaunpbQ58QexxKmDDU-jp6c,522
-alberta_framework/core/learners.py,sha256=Abq_iOb9CWy9DMWeIX_7PXxo59gXEmlvNnf8zrDKQpo,18157
-alberta_framework/core/normalizers.py,sha256=-OFkdKcfx4VTwm4WLXu1hxrh8DTdwFlIG6CgVYgaZBk,5905
-alberta_framework/core/optimizers.py,sha256=uHEOhE1ThLcXJ18zagW9SQAEiNpsWapNWSMelbhkdNY,14559
-alberta_framework/core/types.py,sha256=Op9EHIIoEZGKbbr3b7xijaOurlQ-mxohBRv7rnVybro,6307
-alberta_framework/streams/__init__.py,sha256=bsDgWjWjotDQHMI2lno3dgk8N14pd-2mYAQpXAtCPx4,2035
-alberta_framework/streams/base.py,sha256=81zqXTF30Orj0N2BXSLYVHF9wUYZSqthqQi1MG5Kzxs,2165
-alberta_framework/streams/gymnasium.py,sha256=s733X7aEgy05hcSazjZEhBiJChtEL7uVpxwh0fXBQZA,21980
-alberta_framework/streams/synthetic.py,sha256=kRQktC4NNlFvoF_FmY_WG9VkiASGOudz8qdI5VjoRq8,32963
-alberta_framework/utils/__init__.py,sha256=zfKfnbikhLp0J6UgVa8HeRo59gZHwqOc8jf03s7AaT4,2845
-alberta_framework/utils/experiments.py,sha256=8N_JrffUa1S_lIZQIqKDuBxyv4UYt9QXzLlo-YnMAEU,10554
-alberta_framework/utils/export.py,sha256=W9RKfeTiyZcLColOGNjBfZU0N6QMXrfPn4pdYcm-OSk,15832
-alberta_framework/utils/metrics.py,sha256=1cryNJoboO67vvRhausaucbYZFgdL_06vaf08UXbojg,3349
-alberta_framework/utils/statistics.py,sha256=4fbzNlmsdUaM5lLW1BhL5B5MUpnqimQlwJklZ4x0y0U,15416
-alberta_framework/utils/timing.py,sha256=05NwXrIc9nS2p2MCHjdOgglPbE1CHZsLLdSB6em7YNY,4110
-alberta_framework/utils/visualization.py,sha256=PmKBD3KGabNhgDizcNiGJEbVCyDL1YMUE5yTwgJHu2o,17924
-alberta_framework-0.1.0.dist-info/METADATA,sha256=G4lIPB7NJlkJGOlmbo7BqeIsHNUl2M-L4f8Gd1T2Ro8,7332
-alberta_framework-0.1.0.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
-alberta_framework-0.1.0.dist-info/licenses/LICENSE,sha256=TI1avodt5mvxz7sunyxIa0HlNgLQcmKNLeRjCVcgKmE,10754
-alberta_framework-0.1.0.dist-info/RECORD,,
{alberta_framework-0.1.0.dist-info → alberta_framework-0.1.1.dist-info}/WHEEL
File without changes

{alberta_framework-0.1.0.dist-info → alberta_framework-0.1.1.dist-info}/licenses/LICENSE
File without changes