alberta-framework 0.1.0__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,28 +1,45 @@
1
- """Alberta Framework: Implementation of the Alberta Plan for AI Research.
2
-
3
- This framework implements Step 1 of the Alberta Plan: continual supervised
4
- learning with meta-learned step-sizes.
5
-
6
- Core Philosophy: Temporal uniformity - every component updates at every time step.
7
-
8
- Quick Start:
9
- >>> import jax.random as jr
10
- >>> from alberta_framework import LinearLearner, IDBD, RandomWalkStream, run_learning_loop
11
- >>>
12
- >>> # Create a non-stationary stream
13
- >>> stream = RandomWalkStream(feature_dim=10, drift_rate=0.001)
14
- >>>
15
- >>> # Create a learner with adaptive step-sizes
16
- >>> learner = LinearLearner(optimizer=IDBD())
17
- >>>
18
- >>> # Run learning loop with scan
19
- >>> key = jr.key(42)
20
- >>> state, metrics = run_learning_loop(learner, stream, num_steps=10000, key=key)
21
-
22
- Reference: The Alberta Plan for AI Research (Sutton et al.)
1
+ """Alberta Framework: A JAX-based research framework for continual AI.
2
+
3
+ The Alberta Framework provides foundational components for continual reinforcement
4
+ learning research. Built on JAX for hardware acceleration, the framework emphasizes
5
+ temporal uniformity — every component updates at every time step, with no special
6
+ training phases or batch processing.
7
+
8
+ Roadmap
9
+ -------
10
+ | Step | Focus | Status |
11
+ |------|-------|--------|
12
+ | 1 | Meta-learned step-sizes (IDBD, Autostep) | **Complete** |
13
+ | 2 | Feature generation and testing | Planned |
14
+ | 3 | GVF predictions, Horde architecture | Planned |
15
+ | 4 | Actor-critic with eligibility traces | Planned |
16
+ | 5-6 | Off-policy learning, average reward | Planned |
17
+ | 7-12 | Hierarchical, multi-agent, world models | Future |
18
+
19
+ Examples
20
+ --------
21
+ ```python
22
+ import jax.random as jr
23
+ from alberta_framework import LinearLearner, IDBD, RandomWalkStream, run_learning_loop
24
+
25
+ # Non-stationary stream where target weights drift over time
26
+ stream = RandomWalkStream(feature_dim=10, drift_rate=0.001)
27
+
28
+ # Learner with IDBD meta-learned step-sizes
29
+ learner = LinearLearner(optimizer=IDBD())
30
+
31
+ # JIT-compiled training via jax.lax.scan
32
+ state, metrics = run_learning_loop(learner, stream, num_steps=10000, key=jr.key(42))
33
+ ```
34
+
35
+ References
36
+ ----------
37
+ - The Alberta Plan for AI Research (Sutton et al., 2022): https://arxiv.org/abs/2208.11173
38
+ - Adapting Bias by Gradient Descent (Sutton, 1992)
39
+ - Tuning-free Step-size Adaptation (Mahmood et al., 2012)
23
40
  """
24
41
 
25
- __version__ = "0.1.0"
42
+ __version__ = "0.2.0"
26
43
 
27
44
  # Core types
28
45
  # Learners
@@ -33,7 +50,9 @@ from alberta_framework.core.learners import (
33
50
  UpdateResult,
34
51
  metrics_to_dicts,
35
52
  run_learning_loop,
53
+ run_learning_loop_batched,
36
54
  run_normalized_learning_loop,
55
+ run_normalized_learning_loop_batched,
37
56
  )
38
57
 
39
58
  # Normalizers
@@ -47,9 +66,13 @@ from alberta_framework.core.normalizers import (
47
66
  from alberta_framework.core.optimizers import IDBD, LMS, Autostep, Optimizer
48
67
  from alberta_framework.core.types import (
49
68
  AutostepState,
69
+ BatchedLearningResult,
70
+ BatchedNormalizedResult,
50
71
  IDBDState,
51
72
  LearnerState,
52
73
  LMSState,
74
+ NormalizerHistory,
75
+ NormalizerTrackingConfig,
53
76
  Observation,
54
77
  Prediction,
55
78
  StepSizeHistory,
@@ -119,10 +142,14 @@ __all__ = [
119
142
  "__version__",
120
143
  # Types
121
144
  "AutostepState",
145
+ "BatchedLearningResult",
146
+ "BatchedNormalizedResult",
122
147
  "IDBDState",
123
148
  "LMSState",
124
149
  "LearnerState",
150
+ "NormalizerHistory",
125
151
  "NormalizerState",
152
+ "NormalizerTrackingConfig",
126
153
  "Observation",
127
154
  "Prediction",
128
155
  "StepSizeHistory",
@@ -143,7 +170,9 @@ __all__ = [
143
170
  "NormalizedLearnerState",
144
171
  "NormalizedLinearLearner",
145
172
  "run_learning_loop",
173
+ "run_learning_loop_batched",
146
174
  "run_normalized_learning_loop",
175
+ "run_normalized_learning_loop_batched",
147
176
  "metrics_to_dicts",
148
177
  # Streams - protocol
149
178
  "ScanStream",
@@ -15,9 +15,13 @@ from alberta_framework.core.normalizers import NormalizerState, OnlineNormalizer
15
15
  from alberta_framework.core.optimizers import LMS, Optimizer
16
16
  from alberta_framework.core.types import (
17
17
  AutostepState,
18
+ BatchedLearningResult,
19
+ BatchedNormalizedResult,
18
20
  IDBDState,
19
21
  LearnerState,
20
22
  LMSState,
23
+ NormalizerHistory,
24
+ NormalizerTrackingConfig,
21
25
  Observation,
22
26
  Prediction,
23
27
  StepSizeHistory,
@@ -48,7 +52,7 @@ class UpdateResult(NamedTuple):
48
52
  class LinearLearner:
49
53
  """Linear function approximator with pluggable optimizer.
50
54
 
51
- Computes predictions as: y = w @ x + b
55
+ Computes predictions as: `y = w @ x + b`
52
56
 
53
57
  The learner maintains weights and bias, delegating the adaptation
54
58
  of learning rates to the optimizer (e.g., LMS or IDBD).
@@ -93,7 +97,7 @@ class LinearLearner:
93
97
  observation: Input feature vector
94
98
 
95
99
  Returns:
96
- Scalar prediction y = w @ x + b
100
+ Scalar prediction `y = w @ x + b`
97
101
  """
98
102
  return jnp.atleast_1d(jnp.dot(state.weights, observation) + state.bias)
99
103
 
@@ -238,10 +242,25 @@ def run_learning_loop[StreamStateT](
238
242
  )
239
243
  recording_indices = jnp.zeros(num_recordings, dtype=jnp.int32)
240
244
 
245
+ # Check if we need to track Autostep normalizers
246
+ # We detect this at trace time by checking the initial optimizer state
247
+ track_normalizers = hasattr(learner_state.optimizer_state, "normalizers")
248
+ normalizer_history = (
249
+ jnp.zeros((num_recordings, feature_dim), dtype=jnp.float32)
250
+ if track_normalizers
251
+ else None
252
+ )
253
+
241
254
  def step_fn_with_tracking(
242
- carry: tuple[LearnerState, StreamStateT, Array, Array | None, Array], idx: Array
243
- ) -> tuple[tuple[LearnerState, StreamStateT, Array, Array | None, Array], Array]:
244
- l_state, s_state, ss_history, b_history, rec_indices = carry
255
+ carry: tuple[
256
+ LearnerState, StreamStateT, Array, Array | None, Array, Array | None
257
+ ],
258
+ idx: Array,
259
+ ) -> tuple[
260
+ tuple[LearnerState, StreamStateT, Array, Array | None, Array, Array | None],
261
+ Array,
262
+ ]:
263
+ l_state, s_state, ss_history, b_history, rec_indices, norm_history = carry
245
264
 
246
265
  # Perform learning step
247
266
  timestep, new_s_state = stream.step(s_state, idx)
@@ -291,12 +310,25 @@ def run_learning_loop[StreamStateT](
291
310
  None,
292
311
  )
293
312
 
313
+ # Track Autostep normalizers (v_i) if applicable
314
+ new_norm_history = norm_history
315
+ if norm_history is not None and hasattr(opt_state, "normalizers"):
316
+ new_norm_history = jax.lax.cond(
317
+ should_record,
318
+ lambda _: norm_history.at[recording_idx].set(
319
+ opt_state.normalizers # type: ignore[union-attr]
320
+ ),
321
+ lambda _: norm_history,
322
+ None,
323
+ )
324
+
294
325
  return (
295
326
  result.state,
296
327
  new_s_state,
297
328
  new_ss_history,
298
329
  new_b_history,
299
330
  new_rec_indices,
331
+ new_norm_history,
300
332
  ), result.metrics
301
333
 
302
334
  initial_carry = (
@@ -305,16 +337,25 @@ def run_learning_loop[StreamStateT](
305
337
  step_size_history,
306
338
  bias_history,
307
339
  recording_indices,
340
+ normalizer_history,
308
341
  )
309
342
 
310
- (final_learner, _, final_ss_history, final_b_history, final_rec_indices), metrics = (
311
- jax.lax.scan(step_fn_with_tracking, initial_carry, jnp.arange(num_steps))
343
+ (
344
+ final_learner,
345
+ _,
346
+ final_ss_history,
347
+ final_b_history,
348
+ final_rec_indices,
349
+ final_norm_history,
350
+ ), metrics = jax.lax.scan(
351
+ step_fn_with_tracking, initial_carry, jnp.arange(num_steps)
312
352
  )
313
353
 
314
354
  history = StepSizeHistory(
315
355
  step_sizes=final_ss_history,
316
356
  bias_step_sizes=final_b_history,
317
357
  recording_indices=final_rec_indices,
358
+ normalizers=final_norm_history,
318
359
  )
319
360
 
320
361
  return final_learner, metrics, history
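For context on the new `normalizers` field threaded through the carry above, here is a hedged usage sketch. It assumes `run_learning_loop` accepts a `step_size_tracking` argument of type `StepSizeTrackingConfig` with `interval` and `include_bias` fields (names taken from the code later in this diff), that `Autostep` is exported at package level as the imports suggest, and that the import path for the config is `alberta_framework.core.types`; none of these are confirmed outside this diff.

```python
import jax.random as jr

from alberta_framework import Autostep, LinearLearner, RandomWalkStream, run_learning_loop
from alberta_framework.core.types import StepSizeTrackingConfig  # assumed import path

stream = RandomWalkStream(feature_dim=10, drift_rate=0.001)
learner = LinearLearner(optimizer=Autostep())

# Record step-sizes (and, for Autostep, the normalizers v_i) every 100 steps.
tracking = StepSizeTrackingConfig(interval=100, include_bias=True)
state, metrics, history = run_learning_loop(
    learner, stream, num_steps=10_000, key=jr.key(0), step_size_tracking=tracking
)

# history.normalizers has shape (num_recordings, feature_dim) for Autostep, else None.
print(history.step_sizes.shape, history.recording_indices.shape)
```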
@@ -473,7 +514,14 @@ def run_normalized_learning_loop[StreamStateT](
473
514
  num_steps: int,
474
515
  key: Array,
475
516
  learner_state: NormalizedLearnerState | None = None,
476
- ) -> tuple[NormalizedLearnerState, Array]:
517
+ step_size_tracking: StepSizeTrackingConfig | None = None,
518
+ normalizer_tracking: NormalizerTrackingConfig | None = None,
519
+ ) -> (
520
+ tuple[NormalizedLearnerState, Array]
521
+ | tuple[NormalizedLearnerState, Array, StepSizeHistory]
522
+ | tuple[NormalizedLearnerState, Array, NormalizerHistory]
523
+ | tuple[NormalizedLearnerState, Array, StepSizeHistory, NormalizerHistory]
524
+ ):
477
525
  """Run the learning loop with normalization using jax.lax.scan.
478
526
 
479
527
  Args:
@@ -482,29 +530,512 @@ def run_normalized_learning_loop[StreamStateT](
482
530
  num_steps: Number of learning steps to run
483
531
  key: JAX random key for stream initialization
484
532
  learner_state: Initial state (if None, will be initialized from stream)
533
+ step_size_tracking: Optional config for recording per-weight step-sizes.
534
+ When provided, returns StepSizeHistory including Autostep normalizers if applicable.
535
+ normalizer_tracking: Optional config for recording per-feature normalizer state.
536
+ When provided, returns NormalizerHistory with means and variances over time.
485
537
 
486
538
  Returns:
487
- Tuple of (final_state, metrics_array) where metrics_array has shape
488
- (num_steps, 4) with columns [squared_error, error, mean_step_size, normalizer_mean_var]
539
+ If no tracking:
540
+ Tuple of (final_state, metrics_array) where metrics_array has shape
541
+ (num_steps, 4) with columns [squared_error, error, mean_step_size, normalizer_mean_var]
542
+ If step_size_tracking only:
543
+ Tuple of (final_state, metrics_array, step_size_history)
544
+ If normalizer_tracking only:
545
+ Tuple of (final_state, metrics_array, normalizer_history)
546
+ If both:
547
+ Tuple of (final_state, metrics_array, step_size_history, normalizer_history)
548
+
549
+ Raises:
550
+ ValueError: If tracking interval is invalid
489
551
  """
552
+ # Validate tracking configs
553
+ if step_size_tracking is not None:
554
+ if step_size_tracking.interval < 1:
555
+ raise ValueError(
556
+ f"step_size_tracking.interval must be >= 1, got {step_size_tracking.interval}"
557
+ )
558
+ if step_size_tracking.interval > num_steps:
559
+ raise ValueError(
560
+ f"step_size_tracking.interval ({step_size_tracking.interval}) "
561
+ f"must be <= num_steps ({num_steps})"
562
+ )
563
+
564
+ if normalizer_tracking is not None:
565
+ if normalizer_tracking.interval < 1:
566
+ raise ValueError(
567
+ f"normalizer_tracking.interval must be >= 1, got {normalizer_tracking.interval}"
568
+ )
569
+ if normalizer_tracking.interval > num_steps:
570
+ raise ValueError(
571
+ f"normalizer_tracking.interval ({normalizer_tracking.interval}) "
572
+ f"must be <= num_steps ({num_steps})"
573
+ )
574
+
490
575
  # Initialize states
491
576
  if learner_state is None:
492
577
  learner_state = learner.init(stream.feature_dim)
493
578
  stream_state = stream.init(key)
494
579
 
495
- def step_fn(
496
- carry: tuple[NormalizedLearnerState, StreamStateT], idx: Array
497
- ) -> tuple[tuple[NormalizedLearnerState, StreamStateT], Array]:
498
- l_state, s_state = carry
580
+ feature_dim = stream.feature_dim
581
+
582
+ # No tracking - simple case
583
+ if step_size_tracking is None and normalizer_tracking is None:
584
+
585
+ def step_fn(
586
+ carry: tuple[NormalizedLearnerState, StreamStateT], idx: Array
587
+ ) -> tuple[tuple[NormalizedLearnerState, StreamStateT], Array]:
588
+ l_state, s_state = carry
589
+ timestep, new_s_state = stream.step(s_state, idx)
590
+ result = learner.update(l_state, timestep.observation, timestep.target)
591
+ return (result.state, new_s_state), result.metrics
592
+
593
+ (final_learner, _), metrics = jax.lax.scan(
594
+ step_fn, (learner_state, stream_state), jnp.arange(num_steps)
595
+ )
596
+
597
+ return final_learner, metrics
598
+
599
+ # Tracking enabled - need to set up history arrays
600
+ ss_interval = step_size_tracking.interval if step_size_tracking else num_steps + 1
601
+ norm_interval = (
602
+ normalizer_tracking.interval if normalizer_tracking else num_steps + 1
603
+ )
604
+
605
+ ss_num_recordings = num_steps // ss_interval if step_size_tracking else 0
606
+ norm_num_recordings = num_steps // norm_interval if normalizer_tracking else 0
607
+
608
+ # Pre-allocate step-size history arrays
609
+ ss_history = (
610
+ jnp.zeros((ss_num_recordings, feature_dim), dtype=jnp.float32)
611
+ if step_size_tracking
612
+ else None
613
+ )
614
+ ss_bias_history = (
615
+ jnp.zeros(ss_num_recordings, dtype=jnp.float32)
616
+ if step_size_tracking and step_size_tracking.include_bias
617
+ else None
618
+ )
619
+ ss_rec_indices = (
620
+ jnp.zeros(ss_num_recordings, dtype=jnp.int32) if step_size_tracking else None
621
+ )
622
+
623
+ # Check if we need to track Autostep normalizers
624
+ track_autostep_normalizers = hasattr(
625
+ learner_state.learner_state.optimizer_state, "normalizers"
626
+ )
627
+ ss_normalizers = (
628
+ jnp.zeros((ss_num_recordings, feature_dim), dtype=jnp.float32)
629
+ if step_size_tracking and track_autostep_normalizers
630
+ else None
631
+ )
632
+
633
+ # Pre-allocate normalizer state history arrays
634
+ norm_means = (
635
+ jnp.zeros((norm_num_recordings, feature_dim), dtype=jnp.float32)
636
+ if normalizer_tracking
637
+ else None
638
+ )
639
+ norm_vars = (
640
+ jnp.zeros((norm_num_recordings, feature_dim), dtype=jnp.float32)
641
+ if normalizer_tracking
642
+ else None
643
+ )
644
+ norm_rec_indices = (
645
+ jnp.zeros(norm_num_recordings, dtype=jnp.int32) if normalizer_tracking else None
646
+ )
647
+
648
+ def step_fn_with_tracking(
649
+ carry: tuple[
650
+ NormalizedLearnerState,
651
+ StreamStateT,
652
+ Array | None,
653
+ Array | None,
654
+ Array | None,
655
+ Array | None,
656
+ Array | None,
657
+ Array | None,
658
+ Array | None,
659
+ ],
660
+ idx: Array,
661
+ ) -> tuple[
662
+ tuple[
663
+ NormalizedLearnerState,
664
+ StreamStateT,
665
+ Array | None,
666
+ Array | None,
667
+ Array | None,
668
+ Array | None,
669
+ Array | None,
670
+ Array | None,
671
+ Array | None,
672
+ ],
673
+ Array,
674
+ ]:
675
+ (
676
+ l_state,
677
+ s_state,
678
+ ss_hist,
679
+ ss_bias_hist,
680
+ ss_rec,
681
+ ss_norm,
682
+ n_means,
683
+ n_vars,
684
+ n_rec,
685
+ ) = carry
686
+
687
+ # Perform learning step
499
688
  timestep, new_s_state = stream.step(s_state, idx)
500
689
  result = learner.update(l_state, timestep.observation, timestep.target)
501
- return (result.state, new_s_state), result.metrics
502
690
 
503
- (final_learner, _), metrics = jax.lax.scan(
504
- step_fn, (learner_state, stream_state), jnp.arange(num_steps)
691
+ # Step-size tracking
692
+ new_ss_hist = ss_hist
693
+ new_ss_bias_hist = ss_bias_hist
694
+ new_ss_rec = ss_rec
695
+ new_ss_norm = ss_norm
696
+
697
+ if ss_hist is not None:
698
+ should_record_ss = (idx % ss_interval) == 0
699
+ recording_idx = idx // ss_interval
700
+
701
+ # Extract current step-sizes from the inner learner state
702
+ opt_state = result.state.learner_state.optimizer_state
703
+ if hasattr(opt_state, "log_step_sizes"):
704
+ # IDBD stores log step-sizes
705
+ weight_ss = jnp.exp(opt_state.log_step_sizes) # type: ignore[union-attr]
706
+ bias_ss = opt_state.bias_step_size # type: ignore[union-attr]
707
+ elif hasattr(opt_state, "step_sizes"):
708
+ # Autostep stores step-sizes directly
709
+ weight_ss = opt_state.step_sizes # type: ignore[union-attr]
710
+ bias_ss = opt_state.bias_step_size # type: ignore[union-attr]
711
+ else:
712
+ # LMS has a single fixed step-size
713
+ weight_ss = jnp.full(feature_dim, opt_state.step_size)
714
+ bias_ss = opt_state.step_size
715
+
716
+ new_ss_hist = jax.lax.cond(
717
+ should_record_ss,
718
+ lambda _: ss_hist.at[recording_idx].set(weight_ss),
719
+ lambda _: ss_hist,
720
+ None,
721
+ )
722
+
723
+ if ss_bias_hist is not None:
724
+ new_ss_bias_hist = jax.lax.cond(
725
+ should_record_ss,
726
+ lambda _: ss_bias_hist.at[recording_idx].set(bias_ss),
727
+ lambda _: ss_bias_hist,
728
+ None,
729
+ )
730
+
731
+ if ss_rec is not None:
732
+ new_ss_rec = jax.lax.cond(
733
+ should_record_ss,
734
+ lambda _: ss_rec.at[recording_idx].set(idx),
735
+ lambda _: ss_rec,
736
+ None,
737
+ )
738
+
739
+ # Track Autostep normalizers (v_i) if applicable
740
+ if ss_norm is not None and hasattr(opt_state, "normalizers"):
741
+ new_ss_norm = jax.lax.cond(
742
+ should_record_ss,
743
+ lambda _: ss_norm.at[recording_idx].set(
744
+ opt_state.normalizers # type: ignore[union-attr]
745
+ ),
746
+ lambda _: ss_norm,
747
+ None,
748
+ )
749
+
750
+ # Normalizer state tracking
751
+ new_n_means = n_means
752
+ new_n_vars = n_vars
753
+ new_n_rec = n_rec
754
+
755
+ if n_means is not None:
756
+ should_record_norm = (idx % norm_interval) == 0
757
+ norm_recording_idx = idx // norm_interval
758
+
759
+ norm_state = result.state.normalizer_state
760
+
761
+ new_n_means = jax.lax.cond(
762
+ should_record_norm,
763
+ lambda _: n_means.at[norm_recording_idx].set(norm_state.mean),
764
+ lambda _: n_means,
765
+ None,
766
+ )
767
+
768
+ if n_vars is not None:
769
+ new_n_vars = jax.lax.cond(
770
+ should_record_norm,
771
+ lambda _: n_vars.at[norm_recording_idx].set(norm_state.var),
772
+ lambda _: n_vars,
773
+ None,
774
+ )
775
+
776
+ if n_rec is not None:
777
+ new_n_rec = jax.lax.cond(
778
+ should_record_norm,
779
+ lambda _: n_rec.at[norm_recording_idx].set(idx),
780
+ lambda _: n_rec,
781
+ None,
782
+ )
783
+
784
+ return (
785
+ result.state,
786
+ new_s_state,
787
+ new_ss_hist,
788
+ new_ss_bias_hist,
789
+ new_ss_rec,
790
+ new_ss_norm,
791
+ new_n_means,
792
+ new_n_vars,
793
+ new_n_rec,
794
+ ), result.metrics
795
+
796
+ initial_carry = (
797
+ learner_state,
798
+ stream_state,
799
+ ss_history,
800
+ ss_bias_history,
801
+ ss_rec_indices,
802
+ ss_normalizers,
803
+ norm_means,
804
+ norm_vars,
805
+ norm_rec_indices,
806
+ )
807
+
808
+ (
809
+ final_learner,
810
+ _,
811
+ final_ss_hist,
812
+ final_ss_bias_hist,
813
+ final_ss_rec,
814
+ final_ss_norm,
815
+ final_n_means,
816
+ final_n_vars,
817
+ final_n_rec,
818
+ ), metrics = jax.lax.scan(
819
+ step_fn_with_tracking, initial_carry, jnp.arange(num_steps)
820
+ )
821
+
822
+ # Build return values based on what was tracked
823
+ ss_history_result = None
824
+ if step_size_tracking is not None and final_ss_hist is not None:
825
+ ss_history_result = StepSizeHistory(
826
+ step_sizes=final_ss_hist,
827
+ bias_step_sizes=final_ss_bias_hist,
828
+ recording_indices=final_ss_rec, # type: ignore[arg-type]
829
+ normalizers=final_ss_norm,
830
+ )
831
+
832
+ norm_history_result = None
833
+ if normalizer_tracking is not None and final_n_means is not None:
834
+ norm_history_result = NormalizerHistory(
835
+ means=final_n_means,
836
+ variances=final_n_vars, # type: ignore[arg-type]
837
+ recording_indices=final_n_rec, # type: ignore[arg-type]
838
+ )
839
+
840
+ # Return appropriate tuple based on what was tracked
841
+ if ss_history_result is not None and norm_history_result is not None:
842
+ return final_learner, metrics, ss_history_result, norm_history_result
843
+ elif ss_history_result is not None:
844
+ return final_learner, metrics, ss_history_result
845
+ elif norm_history_result is not None:
846
+ return final_learner, metrics, norm_history_result
847
+ else:
848
+ return final_learner, metrics
849
+
850
+
851
+ def run_learning_loop_batched[StreamStateT](
852
+ learner: LinearLearner,
853
+ stream: ScanStream[StreamStateT],
854
+ num_steps: int,
855
+ keys: Array,
856
+ learner_state: LearnerState | None = None,
857
+ step_size_tracking: StepSizeTrackingConfig | None = None,
858
+ ) -> BatchedLearningResult:
859
+ """Run learning loop across multiple seeds in parallel using jax.vmap.
860
+
861
+ This function provides GPU parallelization for multi-seed experiments,
862
+ typically achieving 2-5x speedup over sequential execution.
863
+
864
+ Args:
865
+ learner: The learner to train
866
+ stream: Experience stream providing (observation, target) pairs
867
+ num_steps: Number of learning steps to run per seed
868
+ keys: JAX random keys with shape (num_seeds,) or (num_seeds, 2)
869
+ learner_state: Initial state (if None, will be initialized from stream).
870
+ The same initial state is used for all seeds.
871
+ step_size_tracking: Optional config for recording per-weight step-sizes.
872
+ When provided, history arrays have shape (num_seeds, num_recordings, ...)
873
+
874
+ Returns:
875
+ BatchedLearningResult containing:
876
+ - states: Batched final states with shape (num_seeds, ...) for each array
877
+ - metrics: Array of shape (num_seeds, num_steps, 3)
878
+ - step_size_history: Batched history or None if tracking disabled
879
+
880
+ Examples:
881
+ ```python
882
+ import jax.random as jr
883
+ from alberta_framework import LinearLearner, IDBD, RandomWalkStream
884
+ from alberta_framework import run_learning_loop_batched
885
+
886
+ stream = RandomWalkStream(feature_dim=10)
887
+ learner = LinearLearner(optimizer=IDBD())
888
+
889
+ # Run 30 seeds in parallel
890
+ keys = jr.split(jr.key(42), 30)
891
+ result = run_learning_loop_batched(learner, stream, num_steps=10000, keys=keys)
892
+
893
+ # result.metrics has shape (30, 10000, 3)
894
+ mean_sq_error = result.metrics[:, :, 0].mean(axis=0)  # mean squared error, averaged over seeds
895
+ ```
896
+ """
897
+ # Define single-seed function that returns consistent structure
898
+ def single_seed_run(key: Array) -> tuple[LearnerState, Array, StepSizeHistory | None]:
899
+ result = run_learning_loop(
900
+ learner, stream, num_steps, key, learner_state, step_size_tracking
901
+ )
902
+ if step_size_tracking is not None:
903
+ state, metrics, history = result
904
+ return state, metrics, history
905
+ else:
906
+ state, metrics = result
907
+ # Return None for history to maintain consistent output structure
908
+ return state, metrics, None
909
+
910
+ # vmap over the keys dimension
911
+ batched_states, batched_metrics, batched_history = jax.vmap(single_seed_run)(keys)
912
+
913
+ # Reconstruct batched history if tracking was enabled
914
+ if step_size_tracking is not None:
915
+ batched_step_size_history = StepSizeHistory(
916
+ step_sizes=batched_history.step_sizes,
917
+ bias_step_sizes=batched_history.bias_step_sizes,
918
+ recording_indices=batched_history.recording_indices,
919
+ normalizers=batched_history.normalizers,
920
+ )
921
+ else:
922
+ batched_step_size_history = None
923
+
924
+ return BatchedLearningResult(
925
+ states=batched_states,
926
+ metrics=batched_metrics,
927
+ step_size_history=batched_step_size_history,
505
928
  )
506
929
 
507
- return final_learner, metrics
930
+
931
+ def run_normalized_learning_loop_batched[StreamStateT](
932
+ learner: NormalizedLinearLearner,
933
+ stream: ScanStream[StreamStateT],
934
+ num_steps: int,
935
+ keys: Array,
936
+ learner_state: NormalizedLearnerState | None = None,
937
+ step_size_tracking: StepSizeTrackingConfig | None = None,
938
+ normalizer_tracking: NormalizerTrackingConfig | None = None,
939
+ ) -> BatchedNormalizedResult:
940
+ """Run normalized learning loop across multiple seeds in parallel using jax.vmap.
941
+
942
+ This function provides GPU parallelization for multi-seed experiments with
943
+ normalized learners, typically achieving 2-5x speedup over sequential execution.
944
+
945
+ Args:
946
+ learner: The normalized learner to train
947
+ stream: Experience stream providing (observation, target) pairs
948
+ num_steps: Number of learning steps to run per seed
949
+ keys: JAX random keys with shape (num_seeds,) or (num_seeds, 2)
950
+ learner_state: Initial state (if None, will be initialized from stream).
951
+ The same initial state is used for all seeds.
952
+ step_size_tracking: Optional config for recording per-weight step-sizes.
953
+ When provided, history arrays have shape (num_seeds, num_recordings, ...)
954
+ normalizer_tracking: Optional config for recording normalizer state.
955
+ When provided, history arrays have shape (num_seeds, num_recordings, ...)
956
+
957
+ Returns:
958
+ BatchedNormalizedResult containing:
959
+ - states: Batched final states with shape (num_seeds, ...) for each array
960
+ - metrics: Array of shape (num_seeds, num_steps, 4)
961
+ - step_size_history: Batched history or None if tracking disabled
962
+ - normalizer_history: Batched history or None if tracking disabled
963
+
964
+ Examples:
965
+ ```python
966
+ import jax.random as jr
967
+ from alberta_framework import NormalizedLinearLearner, IDBD, RandomWalkStream
968
+ from alberta_framework import run_normalized_learning_loop_batched
969
+
970
+ stream = RandomWalkStream(feature_dim=10)
971
+ learner = NormalizedLinearLearner(optimizer=IDBD())
972
+
973
+ # Run 30 seeds in parallel
974
+ keys = jr.split(jr.key(42), 30)
975
+ result = run_normalized_learning_loop_batched(
976
+ learner, stream, num_steps=10000, keys=keys
977
+ )
978
+
979
+ # result.metrics has shape (30, 10000, 4)
980
+ mean_sq_error = result.metrics[:, :, 0].mean(axis=0)  # mean squared error, averaged over seeds
981
+ ```
982
+ """
983
+ # Define single-seed function that returns consistent structure
984
+ def single_seed_run(
985
+ key: Array,
986
+ ) -> tuple[
987
+ NormalizedLearnerState, Array, StepSizeHistory | None, NormalizerHistory | None
988
+ ]:
989
+ result = run_normalized_learning_loop(
990
+ learner, stream, num_steps, key, learner_state,
991
+ step_size_tracking, normalizer_tracking
992
+ )
993
+
994
+ # Unpack based on what tracking was enabled
995
+ if step_size_tracking is not None and normalizer_tracking is not None:
996
+ state, metrics, ss_history, norm_history = result
997
+ return state, metrics, ss_history, norm_history
998
+ elif step_size_tracking is not None:
999
+ state, metrics, ss_history = result
1000
+ return state, metrics, ss_history, None
1001
+ elif normalizer_tracking is not None:
1002
+ state, metrics, norm_history = result
1003
+ return state, metrics, None, norm_history
1004
+ else:
1005
+ state, metrics = result
1006
+ return state, metrics, None, None
1007
+
1008
+ # vmap over the keys dimension
1009
+ batched_states, batched_metrics, batched_ss_history, batched_norm_history = (
1010
+ jax.vmap(single_seed_run)(keys)
1011
+ )
1012
+
1013
+ # Reconstruct batched histories if tracking was enabled
1014
+ if step_size_tracking is not None:
1015
+ batched_step_size_history = StepSizeHistory(
1016
+ step_sizes=batched_ss_history.step_sizes,
1017
+ bias_step_sizes=batched_ss_history.bias_step_sizes,
1018
+ recording_indices=batched_ss_history.recording_indices,
1019
+ normalizers=batched_ss_history.normalizers,
1020
+ )
1021
+ else:
1022
+ batched_step_size_history = None
1023
+
1024
+ if normalizer_tracking is not None:
1025
+ batched_normalizer_history = NormalizerHistory(
1026
+ means=batched_norm_history.means,
1027
+ variances=batched_norm_history.variances,
1028
+ recording_indices=batched_norm_history.recording_indices,
1029
+ )
1030
+ else:
1031
+ batched_normalizer_history = None
1032
+
1033
+ return BatchedNormalizedResult(
1034
+ states=batched_states,
1035
+ metrics=batched_metrics,
1036
+ step_size_history=batched_step_size_history,
1037
+ normalizer_history=batched_normalizer_history,
1038
+ )
508
1039
 
509
1040
 
510
1041
  def metrics_to_dicts(metrics: Array, normalized: bool = False) -> list[dict[str, float]]:
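To make the variable return arity of `run_normalized_learning_loop` concrete, a short usage sketch follows. It is hedged: it assumes `NormalizedLinearLearner(optimizer=IDBD())` is constructed as in the batched example above, and that `StepSizeTrackingConfig` and `NormalizerTrackingConfig` take the `interval`/`include_bias` fields referenced in this diff; the `StepSizeTrackingConfig` import path is assumed.

```python
import jax.random as jr

from alberta_framework import (
    IDBD,
    NormalizedLinearLearner,
    NormalizerTrackingConfig,
    RandomWalkStream,
    run_normalized_learning_loop,
)
from alberta_framework.core.types import StepSizeTrackingConfig  # assumed import path

stream = RandomWalkStream(feature_dim=10, drift_rate=0.001)
learner = NormalizedLinearLearner(optimizer=IDBD())

# With both tracking configs supplied, the loop returns a 4-tuple.
state, metrics, ss_history, norm_history = run_normalized_learning_loop(
    learner,
    stream,
    num_steps=10_000,
    key=jr.key(0),
    step_size_tracking=StepSizeTrackingConfig(interval=100, include_bias=False),
    normalizer_tracking=NormalizerTrackingConfig(interval=100),
)

# metrics: (10000, 4); norm_history.means and .variances: (100, feature_dim)
```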
@@ -35,7 +35,7 @@ class OnlineNormalizer:
35
35
  """Online feature normalizer for continual learning.
36
36
 
37
37
  Normalizes features using running estimates of mean and standard deviation:
38
- x_normalized = (x - mean) / (std + epsilon)
38
+ `x_normalized = (x - mean) / (std + epsilon)`
39
39
 
40
40
  The normalizer updates its estimates at every time step, following
41
41
  temporal uniformity. Uses exponential moving average for non-stationary
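A minimal sketch of the exponential-moving-average normalization described above (illustrative only: the `decay` and `epsilon` values are assumptions, and this is not the framework's `OnlineNormalizer` API):

```python
import jax.numpy as jnp

def ema_normalize(x, mean, var, decay=0.99, epsilon=1e-8):
    # Update running estimates at every step (temporal uniformity), then normalize.
    new_mean = decay * mean + (1.0 - decay) * x
    new_var = decay * var + (1.0 - decay) * (x - new_mean) ** 2
    x_norm = (x - new_mean) / (jnp.sqrt(new_var) + epsilon)
    return x_norm, new_mean, new_var
```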
@@ -72,7 +72,7 @@ class Optimizer[StateT: (LMSState, IDBDState, AutostepState)](ABC):
72
72
  class LMS(Optimizer[LMSState]):
73
73
  """Least Mean Square optimizer with fixed step-size.
74
74
 
75
- The simplest gradient-based optimizer: w_{t+1} = w_t + alpha * delta * x_t
75
+ The simplest gradient-based optimizer: `w_{t+1} = w_t + alpha * delta * x_t`
76
76
 
77
77
  This serves as a baseline. The challenge is that the optimal step-size
78
78
  depends on the problem and changes as the task becomes non-stationary.
@@ -108,7 +108,7 @@ class LMS(Optimizer[LMSState]):
108
108
  ) -> OptimizerUpdate:
109
109
  """Compute LMS weight update.
110
110
 
111
- Update rule: delta_w = alpha * error * x
111
+ Update rule: `delta_w = alpha * error * x`
112
112
 
113
113
  Args:
114
114
  state: Current LMS state
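The fixed-step-size baseline above reduces to a single line; a hedged sketch (not the framework's `LMS.update` signature, and the default `alpha` is an assumption):

```python
def lms_step(w, x, error, alpha=0.01):
    # delta_w = alpha * error * x with a fixed step-size alpha
    return w + alpha * error * x
```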
@@ -195,10 +195,11 @@ class IDBD(Optimizer[IDBDState]):
195
195
  """Compute IDBD weight update with adaptive step-sizes.
196
196
 
197
197
  The IDBD algorithm:
198
- 1. Compute step-sizes: alpha_i = exp(log_alpha_i)
199
- 2. Update weights: w_i += alpha_i * error * x_i
200
- 3. Update log step-sizes: log_alpha_i += beta * error * x_i * h_i
201
- 4. Update traces: h_i = h_i * max(0, 1 - alpha_i * x_i^2) + alpha_i * error * x_i
198
+
199
+ 1. Compute step-sizes: `alpha_i = exp(log_alpha_i)`
200
+ 2. Update weights: `w_i += alpha_i * error * x_i`
201
+ 3. Update log step-sizes: `log_alpha_i += beta * error * x_i * h_i`
202
+ 4. Update traces: `h_i = h_i * max(0, 1 - alpha_i * x_i^2) + alpha_i * error * x_i`
202
203
 
203
204
  The trace h_i tracks the correlation between current and past gradients.
204
205
  When gradients consistently point the same direction, h_i grows,
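A compact sketch of the four IDBD steps listed above (illustrative pseudo-implementation only: the real `IDBD.update` operates on `IDBDState` and returns an `OptimizerUpdate`, neither of which is reproduced here, and the meta step-size default `beta=0.01` is an assumption):

```python
import jax.numpy as jnp

def idbd_step(w, log_alpha, h, x, error, beta=0.01):
    alpha = jnp.exp(log_alpha)                        # 1. per-weight step-sizes
    new_w = w + alpha * error * x                     # 2. weight update
    new_log_alpha = log_alpha + beta * error * x * h  # 3. meta-gradient on log step-sizes
    decay = jnp.maximum(0.0, 1.0 - alpha * x**2)
    new_h = h * decay + alpha * error * x             # 4. gradient-correlation trace
    return new_w, new_log_alpha, new_h
```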
@@ -335,12 +336,13 @@ class Autostep(Optimizer[AutostepState]):
335
336
  """Compute Autostep weight update with normalized gradients.
336
337
 
337
338
  The Autostep algorithm:
338
- 1. Compute gradient: g_i = error * x_i
339
- 2. Normalize gradient: g_i' = g_i / max(|g_i|, v_i)
340
- 3. Update weights: w_i += alpha_i * g_i'
341
- 4. Update step-sizes: alpha_i *= exp(mu * g_i' * h_i)
342
- 5. Update traces: h_i = h_i * (1 - alpha_i) + alpha_i * g_i'
343
- 6. Update normalizers: v_i = max(|g_i|, v_i * tau)
339
+
340
+ 1. Compute gradient: `g_i = error * x_i`
341
+ 2. Normalize gradient: `g_i' = g_i / max(|g_i|, v_i)`
342
+ 3. Update weights: `w_i += alpha_i * g_i'`
343
+ 4. Update step-sizes: `alpha_i *= exp(mu * g_i' * h_i)`
344
+ 5. Update traces: `h_i = h_i * (1 - alpha_i) + alpha_i * g_i'`
345
+ 6. Update normalizers: `v_i = max(|g_i|, v_i * tau)`
344
346
 
345
347
  Args:
346
348
  state: Current Autostep state
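Similarly, a hedged sketch of the six Autostep steps listed above (illustrative only: `mu` and `tau` are left as arguments because their defaults are not shown in this diff, `v` is assumed to be initialized positive so the division is safe, and the updated `alpha` is used in step 5, which the listing leaves ambiguous):

```python
import jax.numpy as jnp

def autostep_step(w, alpha, h, v, x, error, mu, tau):
    g = error * x                                        # 1. raw gradient
    g_norm = g / jnp.maximum(jnp.abs(g), v)              # 2. normalized gradient
    new_w = w + alpha * g_norm                           # 3. weight update
    new_alpha = alpha * jnp.exp(mu * g_norm * h)         # 4. multiplicative step-size update
    new_h = h * (1.0 - new_alpha) + new_alpha * g_norm   # 5. trace update
    new_v = jnp.maximum(jnp.abs(g), v * tau)             # 6. normalizer update
    return new_w, new_alpha, new_h, new_v
```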
@@ -126,11 +126,81 @@ class StepSizeHistory(NamedTuple):
126
126
  step_sizes: Per-weight step-sizes at each recording, shape (num_recordings, num_weights)
127
127
  bias_step_sizes: Bias step-sizes at each recording, shape (num_recordings,) or None
128
128
  recording_indices: Step indices where recordings were made, shape (num_recordings,)
129
+ normalizers: Autostep's per-weight normalizers (v_i) at each recording,
130
+ shape (num_recordings, num_weights) or None. Only populated for the Autostep optimizer.
129
131
  """
130
132
 
131
133
  step_sizes: Array # (num_recordings, num_weights)
132
134
  bias_step_sizes: Array | None # (num_recordings,) or None
133
135
  recording_indices: Array # (num_recordings,)
136
+ normalizers: Array | None = None # (num_recordings, num_weights) - Autostep v_i
137
+
138
+
139
+ class NormalizerTrackingConfig(NamedTuple):
140
+ """Configuration for recording per-feature normalizer state during training.
141
+
142
+ Attributes:
143
+ interval: Record normalizer state every N steps
144
+ """
145
+
146
+ interval: int
147
+
148
+
149
+ class NormalizerHistory(NamedTuple):
150
+ """History of per-feature normalizer state recorded during training.
151
+
152
+ Used for analyzing how the OnlineNormalizer adapts to distribution shifts
153
+ (reactive lag diagnostic).
154
+
155
+ Attributes:
156
+ means: Per-feature mean estimates at each recording, shape (num_recordings, feature_dim)
157
+ variances: Per-feature variance estimates at each recording,
158
+ shape (num_recordings, feature_dim)
159
+ recording_indices: Step indices where recordings were made, shape (num_recordings,)
160
+ """
161
+
162
+ means: Array # (num_recordings, feature_dim)
163
+ variances: Array # (num_recordings, feature_dim)
164
+ recording_indices: Array # (num_recordings,)
165
+
166
+
167
+ class BatchedLearningResult(NamedTuple):
168
+ """Result from batched learning loop across multiple seeds.
169
+
170
+ Used with `run_learning_loop_batched` for vmap-based GPU parallelization.
171
+
172
+ Attributes:
173
+ states: Batched learner states - each array has shape (num_seeds, ...)
174
+ metrics: Metrics array with shape (num_seeds, num_steps, 3)
175
+ where columns are [squared_error, error, mean_step_size]
176
+ step_size_history: Optional step-size history with batched shapes,
177
+ or None if tracking was disabled
178
+ """
179
+
180
+ states: "LearnerState" # Batched: each array has shape (num_seeds, ...)
181
+ metrics: Array # Shape: (num_seeds, num_steps, 3)
182
+ step_size_history: StepSizeHistory | None
183
+
184
+
185
+ class BatchedNormalizedResult(NamedTuple):
186
+ """Result from batched normalized learning loop across multiple seeds.
187
+
188
+ Used with `run_normalized_learning_loop_batched` for vmap-based GPU parallelization.
189
+
190
+ Attributes:
191
+ states: Batched normalized learner states - each array has shape (num_seeds, ...)
192
+ metrics: Metrics array with shape (num_seeds, num_steps, 4)
193
+ where columns are [squared_error, error, mean_step_size, normalizer_mean_var]
194
+ step_size_history: Optional step-size history with batched shapes,
195
+ or None if tracking was disabled
196
+ normalizer_history: Optional normalizer history with batched shapes,
197
+ or None if tracking was disabled
198
+ """
199
+
200
+ states: "NormalizedLearnerState" # Batched: each array has shape (num_seeds, ...)
201
+ metrics: Array # Shape: (num_seeds, num_steps, 4)
202
+ step_size_history: StepSizeHistory | None
203
+ normalizer_history: NormalizerHistory | None
134
204
 
135
205
 
136
206
  def create_lms_state(step_size: float = 0.01) -> LMSState:
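As a hedged illustration of how the new history types above might be consumed for the "reactive lag" diagnostic (the field names follow the NamedTuple definitions in this diff; the training call that produces `history` is elided):

```python
import numpy as np

def normalizer_settling(history):
    """Mean absolute change of per-feature mean estimates between recordings."""
    means = np.asarray(history.means)                    # (num_recordings, feature_dim)
    steps = np.asarray(history.recording_indices)        # (num_recordings,)
    drift = np.abs(np.diff(means, axis=0)).mean(axis=1)  # large values = still adapting
    return steps[1:], drift
```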
@@ -30,11 +30,14 @@ class ScanStream(Protocol[StateT]):
30
30
  Type Parameters:
31
31
  StateT: The state type maintained by this stream
32
32
 
33
- Example:
34
- >>> stream = RandomWalkStream(feature_dim=10, drift_rate=0.001)
35
- >>> key = jax.random.key(42)
36
- >>> state = stream.init(key)
37
- >>> timestep, new_state = stream.step(state, jnp.array(0))
33
+ Examples
34
+ --------
35
+ ```python
36
+ stream = RandomWalkStream(feature_dim=10, drift_rate=0.001)
37
+ key = jax.random.key(42)
38
+ state = stream.init(key)
39
+ timestep, new_state = stream.step(state, jnp.array(0))
40
+ ```
38
41
  """
39
42
 
40
43
  @property
@@ -32,7 +32,7 @@ class RandomWalkState(NamedTuple):
32
32
  class RandomWalkStream:
33
33
  """Non-stationary stream where target weights drift via random walk.
34
34
 
35
- The true target function is linear: y* = w_true @ x + noise
35
+ The true target function is linear: `y* = w_true @ x + noise`
36
36
  where w_true evolves via random walk at each time step.
37
37
 
38
38
  This tests the learner's ability to continuously track a moving target.
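A hedged sketch of the drifting-target mechanics described above (hypothetical helper: the actual `RandomWalkStream.step` signature and state layout are not reproduced here, and `noise_std` is an assumption):

```python
import jax
import jax.numpy as jnp

def random_walk_sample(key, w_true, drift_rate=0.001, noise_std=0.1):
    k_drift, k_x, k_noise = jax.random.split(key, 3)
    # w_true performs a random walk at every step, so the optimal weights keep moving.
    new_w = w_true + drift_rate * jax.random.normal(k_drift, w_true.shape)
    x = jax.random.normal(k_x, w_true.shape)
    y_star = jnp.dot(new_w, x) + noise_std * jax.random.normal(k_noise)
    return new_w, x, y_star
```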
@@ -590,12 +590,15 @@ class ScaledStreamWrapper:
590
590
  scale factor. Useful for testing how learners handle features at different
591
591
  scales, which is important for understanding normalization benefits.
592
592
 
593
- Example:
594
- >>> stream = ScaledStreamWrapper(
595
- ... AbruptChangeStream(feature_dim=10, change_interval=1000),
596
- ... feature_scales=jnp.array([0.001, 0.01, 0.1, 1.0, 10.0,
597
- ... 100.0, 1000.0, 0.001, 0.01, 0.1])
598
- ... )
593
+ Examples
594
+ --------
595
+ ```python
596
+ stream = ScaledStreamWrapper(
597
+ AbruptChangeStream(feature_dim=10, change_interval=1000),
598
+ feature_scales=jnp.array([0.001, 0.01, 0.1, 1.0, 10.0,
599
+ 100.0, 1000.0, 0.001, 0.01, 0.1])
600
+ )
601
+ ```
599
602
 
600
603
  Attributes:
601
604
  inner_stream: The wrapped stream instance
@@ -693,9 +696,12 @@ def make_scale_range(
693
696
  Returns:
694
697
  Array of shape (feature_dim,) with scale factors
695
698
 
696
- Example:
697
- >>> scales = make_scale_range(10, min_scale=0.01, max_scale=100.0)
698
- >>> stream = ScaledStreamWrapper(RandomWalkStream(10), scales)
699
+ Examples
700
+ --------
701
+ ```python
702
+ scales = make_scale_range(10, min_scale=0.01, max_scale=100.0)
703
+ stream = ScaledStreamWrapper(RandomWalkStream(10), scales)
704
+ ```
699
705
  """
700
706
  if log_spaced:
701
707
  return jnp.logspace(
@@ -110,13 +110,14 @@ def run_single_experiment(
110
110
 
111
111
  final_state: LearnerState | NormalizedLearnerState
112
112
  if isinstance(learner, NormalizedLinearLearner):
113
- final_state, metrics = run_normalized_learning_loop(
113
+ norm_result = run_normalized_learning_loop(
114
114
  learner, stream, config.num_steps, key
115
115
  )
116
+ final_state, metrics = cast(tuple[NormalizedLearnerState, Any], norm_result)
116
117
  metrics_history = metrics_to_dicts(metrics, normalized=True)
117
118
  else:
118
- result = run_learning_loop(learner, stream, config.num_steps, key)
119
- final_state, metrics = cast(tuple[LearnerState, Any], result)
119
+ linear_result = run_learning_loop(learner, stream, config.num_steps, key)
120
+ final_state, metrics = cast(tuple[LearnerState, Any], linear_result)
120
121
  metrics_history = metrics_to_dicts(metrics)
121
122
 
122
123
  return SingleRunResult(
@@ -3,19 +3,22 @@
3
3
  This module provides a simple Timer context manager for measuring execution time
4
4
  and formatting durations in a human-readable format.
5
5
 
6
- Example:
7
- >>> from alberta_framework.utils.timing import Timer
8
- >>>
9
- >>> with Timer("Training"):
10
- ... # run training code
11
- ... pass
12
- Training completed in 1.23s
13
- >>>
14
- >>> # Or capture the duration:
15
- >>> with Timer("Experiment") as t:
16
- ... # run experiment
17
- ... pass
18
- >>> print(f"Took {t.duration:.2f} seconds")
6
+ Examples
7
+ --------
8
+ ```python
9
+ from alberta_framework.utils.timing import Timer
10
+
11
+ with Timer("Training"):
12
+ # run training code
13
+ pass
14
+ # Output: Training completed in 1.23s
15
+
16
+ # Or capture the duration:
17
+ with Timer("Experiment") as t:
18
+ # run experiment
19
+ pass
20
+ print(f"Took {t.duration:.2f} seconds")
21
+ ```
19
22
  """
20
23
 
21
24
  import time
@@ -32,13 +35,13 @@ def format_duration(seconds: float) -> str:
32
35
  Returns:
33
36
  Formatted string like "1.23s", "2m 30.5s", or "1h 5m 30s"
34
37
 
35
- Examples:
36
- >>> format_duration(0.5)
37
- '0.50s'
38
- >>> format_duration(90.5)
39
- '1m 30.50s'
40
- >>> format_duration(3665)
41
- '1h 1m 5.00s'
38
+ Examples
39
+ --------
40
+ ```python
41
+ format_duration(0.5) # Returns: '0.50s'
42
+ format_duration(90.5) # Returns: '1m 30.50s'
43
+ format_duration(3665) # Returns: '1h 1m 5.00s'
44
+ ```
42
45
  """
43
46
  if seconds < 60:
44
47
  return f"{seconds:.2f}s"
@@ -66,22 +69,25 @@ class Timer:
66
69
  start_time: Timestamp when timing started
67
70
  end_time: Timestamp when timing ended
68
71
 
69
- Example:
70
- >>> with Timer("Training loop"):
71
- ... for i in range(1000):
72
- ... pass
73
- Training loop completed in 0.01s
74
-
75
- >>> # Silent timing (no print):
76
- >>> with Timer("Silent", verbose=False) as t:
77
- ... time.sleep(0.1)
78
- >>> print(f"Elapsed: {t.duration:.2f}s")
79
- Elapsed: 0.10s
80
-
81
- >>> # Custom print function:
82
- >>> with Timer("Custom", print_fn=lambda msg: print(f">> {msg}")):
83
- ... pass
84
- >> Custom completed in 0.00s
72
+ Examples
73
+ --------
74
+ ```python
75
+ with Timer("Training loop"):
76
+ for i in range(1000):
77
+ pass
78
+ # Output: Training loop completed in 0.01s
79
+
80
+ # Silent timing (no print):
81
+ with Timer("Silent", verbose=False) as t:
82
+ time.sleep(0.1)
83
+ print(f"Elapsed: {t.duration:.2f}s")
84
+ # Output: Elapsed: 0.10s
85
+
86
+ # Custom print function:
87
+ with Timer("Custom", print_fn=lambda msg: print(f">> {msg}")):
88
+ pass
89
+ # Output: >> Custom completed in 0.00s
90
+ ```
85
91
  """
86
92
 
87
93
  def __init__(
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: alberta-framework
3
- Version: 0.1.0
3
+ Version: 0.2.0
4
4
  Summary: Implementation of the Alberta Plan for AI Research - continual learning with meta-learned step-sizes
5
5
  Project-URL: Homepage, https://github.com/j-klawson/alberta-framework
6
6
  Project-URL: Repository, https://github.com/j-klawson/alberta-framework
@@ -49,7 +49,7 @@ Description-Content-Type: text/markdown
49
49
  [![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://opensource.org/licenses/Apache-2.0)
50
50
  [![Python 3.13+](https://img.shields.io/badge/python-3.13+-blue.svg)](https://www.python.org/downloads/)
51
51
 
52
- A JAX-based research framework implementing components of [The Alberta Plan](https://arxiv.org/abs/2208.11173) in the pursuit of building the foundations of Continual AI.
52
+ A JAX-based research framework implementing components of [The Alberta Plan for AI Research](https://arxiv.org/abs/2208.11173) in the pursuit of building the foundations of Continual AI.
53
53
 
54
54
  > "The agents are complex only because they interact with a complex world... their initial design is as simple, general, and scalable as possible." — *Sutton et al., 2022*
55
55
 
@@ -57,6 +57,14 @@ A JAX-based research framework implementing components of [The Alberta Plan](htt
57
57
 
58
58
 The Alberta Framework provides foundational components for continual reinforcement learning research. Built on JAX for hardware acceleration, the framework emphasizes temporal uniformity: every component updates at every time step, with no special training phases or batch processing.
59
59
 
60
+ ## Project Context
61
+
62
+ This framework is developed as part of my D.Eng. work focusing on the foundations of Continual AI. For more background and context, see:
63
+
64
+ * **Research Blog**: [blog.9600baud.net](https://blog.9600baud.net)
65
+ * **Replicating Sutton '92**: [The Foundation of Step-size Adaptation](https://blog.9600baud.net/sutton92.html)
66
+ * **About the Author**: [Keith Lawson](https://blog.9600baud.net/about.html)
67
+
60
68
  ### Roadmap
61
69
 
62
70
 Depending on my research trajectory, I may or may not implement components required for the plan. The current focus of this framework is the Step 1 Baseline Study, investigating the interaction between adaptive optimizers and online normalization.
@@ -0,0 +1,22 @@
1
+ alberta_framework/__init__.py,sha256=gAafDDmkivDdfnvDVff9zbVY9ilzqqfJ9KvpbRegKqs,5726
2
+ alberta_framework/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
3
+ alberta_framework/core/__init__.py,sha256=PSrC4zSxgm_6YXWEQ80aZaunpbQ58QexxKmDDU-jp6c,522
4
+ alberta_framework/core/learners.py,sha256=khZYkae5rlIyV13BW3-hrtPSjGFXPj2IUTM1z74xTTA,37724
5
+ alberta_framework/core/normalizers.py,sha256=Z_d3H17qoXh87DE7k41imvWzkVJQ2xQgDUP7GYSNzAY,5903
6
+ alberta_framework/core/optimizers.py,sha256=OefVuDDG1phh1QQIUyVPsQckl41VrpWFG7hY2eqyc64,14585
7
+ alberta_framework/core/types.py,sha256=_blkTZEm3wNgbweFuqaVL2hxRQU7D6_U66nHp23Pq6Y,9192
8
+ alberta_framework/streams/__init__.py,sha256=bsDgWjWjotDQHMI2lno3dgk8N14pd-2mYAQpXAtCPx4,2035
9
+ alberta_framework/streams/base.py,sha256=9rJxvUgmzd5u2bRV4vi5PxhUvj39EZTD4bZHo-Ptn-U,2168
10
+ alberta_framework/streams/gymnasium.py,sha256=s733X7aEgy05hcSazjZEhBiJChtEL7uVpxwh0fXBQZA,21980
11
+ alberta_framework/streams/synthetic.py,sha256=4R9GR7Kh0LT7GmGtPhzMJGr8HbhrAMUOjvPwLZk6nDg,32979
12
+ alberta_framework/utils/__init__.py,sha256=zfKfnbikhLp0J6UgVa8HeRo59gZHwqOc8jf03s7AaT4,2845
13
+ alberta_framework/utils/experiments.py,sha256=ekGAzveCRgv9YZ5mfAD5Uf7h_PvQnxsNw2KeZN2eu00,10644
14
+ alberta_framework/utils/export.py,sha256=W9RKfeTiyZcLColOGNjBfZU0N6QMXrfPn4pdYcm-OSk,15832
15
+ alberta_framework/utils/metrics.py,sha256=1cryNJoboO67vvRhausaucbYZFgdL_06vaf08UXbojg,3349
16
+ alberta_framework/utils/statistics.py,sha256=4fbzNlmsdUaM5lLW1BhL5B5MUpnqimQlwJklZ4x0y0U,15416
17
+ alberta_framework/utils/timing.py,sha256=JOLq8CpCAV7LWOWkftxefduSFjaXnVwal1MFBKEMdJI,4049
18
+ alberta_framework/utils/visualization.py,sha256=PmKBD3KGabNhgDizcNiGJEbVCyDL1YMUE5yTwgJHu2o,17924
19
+ alberta_framework-0.2.0.dist-info/METADATA,sha256=zcg9hjFzz_oV3Jij8xvDnmww0p8-VPhzSBsN1VD8rvw,7763
20
+ alberta_framework-0.2.0.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
21
+ alberta_framework-0.2.0.dist-info/licenses/LICENSE,sha256=TI1avodt5mvxz7sunyxIa0HlNgLQcmKNLeRjCVcgKmE,10754
22
+ alberta_framework-0.2.0.dist-info/RECORD,,
@@ -1,22 +0,0 @@
1
- alberta_framework/__init__.py,sha256=gPLBA2EiPcElsYp_U_Rs7C6wlrGHr8w5IL6C0F90zec,4739
2
- alberta_framework/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
3
- alberta_framework/core/__init__.py,sha256=PSrC4zSxgm_6YXWEQ80aZaunpbQ58QexxKmDDU-jp6c,522
4
- alberta_framework/core/learners.py,sha256=Abq_iOb9CWy9DMWeIX_7PXxo59gXEmlvNnf8zrDKQpo,18157
5
- alberta_framework/core/normalizers.py,sha256=-OFkdKcfx4VTwm4WLXu1hxrh8DTdwFlIG6CgVYgaZBk,5905
6
- alberta_framework/core/optimizers.py,sha256=uHEOhE1ThLcXJ18zagW9SQAEiNpsWapNWSMelbhkdNY,14559
7
- alberta_framework/core/types.py,sha256=Op9EHIIoEZGKbbr3b7xijaOurlQ-mxohBRv7rnVybro,6307
8
- alberta_framework/streams/__init__.py,sha256=bsDgWjWjotDQHMI2lno3dgk8N14pd-2mYAQpXAtCPx4,2035
9
- alberta_framework/streams/base.py,sha256=81zqXTF30Orj0N2BXSLYVHF9wUYZSqthqQi1MG5Kzxs,2165
10
- alberta_framework/streams/gymnasium.py,sha256=s733X7aEgy05hcSazjZEhBiJChtEL7uVpxwh0fXBQZA,21980
11
- alberta_framework/streams/synthetic.py,sha256=kRQktC4NNlFvoF_FmY_WG9VkiASGOudz8qdI5VjoRq8,32963
12
- alberta_framework/utils/__init__.py,sha256=zfKfnbikhLp0J6UgVa8HeRo59gZHwqOc8jf03s7AaT4,2845
13
- alberta_framework/utils/experiments.py,sha256=8N_JrffUa1S_lIZQIqKDuBxyv4UYt9QXzLlo-YnMAEU,10554
14
- alberta_framework/utils/export.py,sha256=W9RKfeTiyZcLColOGNjBfZU0N6QMXrfPn4pdYcm-OSk,15832
15
- alberta_framework/utils/metrics.py,sha256=1cryNJoboO67vvRhausaucbYZFgdL_06vaf08UXbojg,3349
16
- alberta_framework/utils/statistics.py,sha256=4fbzNlmsdUaM5lLW1BhL5B5MUpnqimQlwJklZ4x0y0U,15416
17
- alberta_framework/utils/timing.py,sha256=05NwXrIc9nS2p2MCHjdOgglPbE1CHZsLLdSB6em7YNY,4110
18
- alberta_framework/utils/visualization.py,sha256=PmKBD3KGabNhgDizcNiGJEbVCyDL1YMUE5yTwgJHu2o,17924
19
- alberta_framework-0.1.0.dist-info/METADATA,sha256=G4lIPB7NJlkJGOlmbo7BqeIsHNUl2M-L4f8Gd1T2Ro8,7332
20
- alberta_framework-0.1.0.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
21
- alberta_framework-0.1.0.dist-info/licenses/LICENSE,sha256=TI1avodt5mvxz7sunyxIa0HlNgLQcmKNLeRjCVcgKmE,10754
22
- alberta_framework-0.1.0.dist-info/RECORD,,