alberta-framework 0.3.2__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -117,9 +117,7 @@ def _flatten_action(action: Any, space: gymnasium.spaces.Space[Any]) -> Array:
117
117
  raise ValueError(f"Unsupported space type: {type(space).__name__}")
118
118
 
119
119
 
120
- def make_random_policy(
121
- env: gymnasium.Env[Any, Any], seed: int = 0
122
- ) -> Callable[[Array], Any]:
120
+ def make_random_policy(env: gymnasium.Env[Any, Any], seed: int = 0) -> Callable[[Array], Any]:
123
121
  """Create a random action policy for an environment.
124
122
 
125
123
  Args:
@@ -147,10 +145,7 @@ def make_random_policy(
147
145
  return jr.uniform(key, action_space.shape, minval=low, maxval=high)
148
146
  elif isinstance(action_space, gymnasium.spaces.MultiDiscrete):
149
147
  nvec = action_space.nvec
150
- return [
151
- int(jr.randint(jr.fold_in(key, i), (), 0, n))
152
- for i, n in enumerate(nvec)
153
- ]
148
+ return [int(jr.randint(jr.fold_in(key, i), (), 0, n)) for i, n in enumerate(nvec)]
154
149
  else:
155
150
  raise ValueError(f"Unsupported action space: {type(action_space).__name__}")
156
151
 
@@ -284,9 +279,7 @@ def learn_from_trajectory(
284
279
  if learner_state is None:
285
280
  learner_state = learner.init(observations.shape[1])
286
281
 
287
- def step_fn(
288
- state: LearnerState, inputs: tuple[Array, Array]
289
- ) -> tuple[LearnerState, Array]:
282
+ def step_fn(state: LearnerState, inputs: tuple[Array, Array]) -> tuple[LearnerState, Array]:
290
283
  obs, target = inputs
291
284
  result = learner.update(state, obs, target)
292
285
  return result.state, result.metrics
@@ -540,9 +540,7 @@ class PeriodicChangeStream:
540
540
  step_count=jnp.array(0, dtype=jnp.int32),
541
541
  )
542
542
 
543
- def step(
544
- self, state: PeriodicChangeState, idx: Array
545
- ) -> tuple[TimeStep, PeriodicChangeState]:
543
+ def step(self, state: PeriodicChangeState, idx: Array) -> tuple[TimeStep, PeriodicChangeState]:
546
544
  """Generate one time step.
547
545
 
548
546
  Args:
@@ -557,9 +555,7 @@ class PeriodicChangeStream:
557
555
 
558
556
  # Compute oscillating weights: w(t) = base + amplitude * sin(2π * t / period + phase)
559
557
  t = state.step_count.astype(jnp.float32)
560
- oscillation = self._amplitude * jnp.sin(
561
- 2.0 * jnp.pi * t / self._period + state.phases
562
- )
558
+ oscillation = self._amplitude * jnp.sin(2.0 * jnp.pi * t / self._period + state.phases)
563
559
  true_weights = state.base_weights + oscillation
564
560
 
565
561
  # Generate observation
@@ -955,9 +951,7 @@ class ScaleDriftStream:
955
951
  step_count=jnp.array(0, dtype=jnp.int32),
956
952
  )
957
953
 
958
- def step(
959
- self, state: ScaleDriftState, idx: Array
960
- ) -> tuple[TimeStep, ScaleDriftState]:
954
+ def step(self, state: ScaleDriftState, idx: Array) -> tuple[TimeStep, ScaleDriftState]:
961
955
  """Generate one time step.
962
956
 
963
957
  Args:
@@ -110,9 +110,7 @@ def run_single_experiment(
110
110
 
111
111
  final_state: LearnerState | NormalizedLearnerState
112
112
  if isinstance(learner, NormalizedLinearLearner):
113
- norm_result = run_normalized_learning_loop(
114
- learner, stream, config.num_steps, key
115
- )
113
+ norm_result = run_normalized_learning_loop(learner, stream, config.num_steps, key)
116
114
  final_state, metrics = cast(tuple[NormalizedLearnerState, Any], norm_result)
117
115
  metrics_history = metrics_to_dicts(metrics, normalized=True)
118
116
  else:
@@ -51,14 +51,16 @@ def _export_summary_csv(
51
51
 
52
52
  for name, agg in results.items():
53
53
  summary = agg.summary[metric]
54
- writer.writerow([
55
- name,
56
- f"{summary.mean:.6f}",
57
- f"{summary.std:.6f}",
58
- f"{summary.min:.6f}",
59
- f"{summary.max:.6f}",
60
- summary.n_seeds,
61
- ])
54
+ writer.writerow(
55
+ [
56
+ name,
57
+ f"{summary.mean:.6f}",
58
+ f"{summary.std:.6f}",
59
+ f"{summary.min:.6f}",
60
+ f"{summary.max:.6f}",
61
+ summary.n_seeds,
62
+ ]
63
+ )
62
64
 
63
65
 
64
66
  def _export_timeseries_csv(
@@ -497,13 +499,15 @@ def results_to_dataframe(
497
499
  rows = []
498
500
  for name, agg in results.items():
499
501
  summary = agg.summary[metric]
500
- rows.append({
501
- "method": name,
502
- "mean": summary.mean,
503
- "std": summary.std,
504
- "min": summary.min,
505
- "max": summary.max,
506
- "n_seeds": summary.n_seeds,
507
- })
502
+ rows.append(
503
+ {
504
+ "method": name,
505
+ "mean": summary.mean,
506
+ "std": summary.std,
507
+ "min": summary.min,
508
+ "max": summary.max,
509
+ "n_seeds": summary.n_seeds,
510
+ }
511
+ )
508
512
 
509
513
  return pd.DataFrame(rows)
@@ -313,9 +313,7 @@ def wilcoxon_comparison(
313
313
  stat_val = float(result[0])
314
314
  p_val = float(result[1])
315
315
  except ImportError:
316
- raise ImportError(
317
- "scipy is required for Wilcoxon test. Install with: pip install scipy"
318
- )
316
+ raise ImportError("scipy is required for Wilcoxon test. Install with: pip install scipy")
319
317
 
320
318
  effect = cohens_d(a, b)
321
319
 
@@ -443,18 +441,28 @@ def pairwise_comparisons(
443
441
 
444
442
  if test == "ttest":
445
443
  result = ttest_comparison(
446
- values_a, values_b, paired=True, alpha=alpha,
447
- method_a=name_a, method_b=name_b,
444
+ values_a,
445
+ values_b,
446
+ paired=True,
447
+ alpha=alpha,
448
+ method_a=name_a,
449
+ method_b=name_b,
448
450
  )
449
451
  elif test == "mann_whitney":
450
452
  result = mann_whitney_comparison(
451
- values_a, values_b, alpha=alpha,
452
- method_a=name_a, method_b=name_b,
453
+ values_a,
454
+ values_b,
455
+ alpha=alpha,
456
+ method_a=name_a,
457
+ method_b=name_b,
453
458
  )
454
459
  else: # wilcoxon
455
460
  result = wilcoxon_comparison(
456
- values_a, values_b, alpha=alpha,
457
- method_a=name_a, method_b=name_b,
461
+ values_a,
462
+ values_b,
463
+ alpha=alpha,
464
+ method_a=name_a,
465
+ method_b=name_b,
458
466
  )
459
467
 
460
468
  comparisons[(name_a, name_b)] = result
@@ -69,29 +69,33 @@ def set_publication_style(
69
69
  pass
70
70
 
71
71
  # Configure matplotlib
72
- plt.rcParams.update({
73
- "font.size": font_size,
74
- "axes.labelsize": font_size,
75
- "axes.titlesize": font_size + 1,
76
- "xtick.labelsize": font_size - 1,
77
- "ytick.labelsize": font_size - 1,
78
- "legend.fontsize": font_size - 1,
79
- "figure.figsize": (_current_style["figure_width"], _current_style["figure_height"]),
80
- "figure.dpi": _current_style["dpi"],
81
- "savefig.dpi": _current_style["dpi"],
82
- "lines.linewidth": _current_style["line_width"],
83
- "lines.markersize": _current_style["marker_size"],
84
- "axes.linewidth": 0.8,
85
- "grid.linewidth": 0.5,
86
- "grid.alpha": 0.3,
87
- })
72
+ plt.rcParams.update(
73
+ {
74
+ "font.size": font_size,
75
+ "axes.labelsize": font_size,
76
+ "axes.titlesize": font_size + 1,
77
+ "xtick.labelsize": font_size - 1,
78
+ "ytick.labelsize": font_size - 1,
79
+ "legend.fontsize": font_size - 1,
80
+ "figure.figsize": (_current_style["figure_width"], _current_style["figure_height"]),
81
+ "figure.dpi": _current_style["dpi"],
82
+ "savefig.dpi": _current_style["dpi"],
83
+ "lines.linewidth": _current_style["line_width"],
84
+ "lines.markersize": _current_style["marker_size"],
85
+ "axes.linewidth": 0.8,
86
+ "grid.linewidth": 0.5,
87
+ "grid.alpha": 0.3,
88
+ }
89
+ )
88
90
 
89
91
  if use_latex:
90
- plt.rcParams.update({
91
- "text.usetex": True,
92
- "font.family": "serif",
93
- "font.serif": ["Computer Modern Roman"],
94
- })
92
+ plt.rcParams.update(
93
+ {
94
+ "text.usetex": True,
95
+ "font.family": "serif",
96
+ "font.serif": ["Computer Modern Roman"],
97
+ }
98
+ )
95
99
 
96
100
 
97
101
  def plot_learning_curves(
@@ -142,10 +146,12 @@ def plot_learning_curves(
142
146
  metric_array = agg.metric_arrays[metric]
143
147
 
144
148
  # Smooth each seed individually, then compute statistics
145
- smoothed = np.array([
146
- compute_running_mean(metric_array[seed_idx], window_size)
147
- for seed_idx in range(metric_array.shape[0])
148
- ])
149
+ smoothed = np.array(
150
+ [
151
+ compute_running_mean(metric_array[seed_idx], window_size)
152
+ for seed_idx in range(metric_array.shape[0])
153
+ ]
154
+ )
149
155
 
150
156
  mean, ci_lower, ci_upper = compute_timeseries_statistics(smoothed)
151
157
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: alberta-framework
3
- Version: 0.3.2
3
+ Version: 0.4.0
4
4
  Summary: Implementation of the Alberta Plan for AI Research - continual learning with meta-learned step-sizes
5
5
  Project-URL: Homepage, https://github.com/j-klawson/alberta-framework
6
6
  Project-URL: Repository, https://github.com/j-klawson/alberta-framework
@@ -113,10 +113,15 @@ state, metrics = run_learning_loop(learner, stream, num_steps=10000, key=jr.key(
113
113
 
114
114
  ### Optimizers
115
115
 
116
+ **Supervised Learning:**
116
117
  - **LMS**: Fixed step-size baseline
117
118
  - **IDBD**: Per-weight adaptive step-sizes via gradient correlation (Sutton, 1992)
118
119
  - **Autostep**: Tuning-free adaptation with gradient normalization (Mahmood et al., 2012)
119
120
 
121
+ **TD Learning:**
122
+ - **TDIDBD**: TD learning with per-weight adaptive step-sizes and eligibility traces (Kearney et al., 2019)
123
+ - **AutoTDIDBD**: TD learning with AutoStep-style normalization for improved stability
124
+
120
125
  ### Streams
121
126
 
122
127
  Non-stationary experience generators implementing the `ScanStream` protocol:
@@ -126,6 +131,17 @@ Non-stationary experience generators implementing the `ScanStream` protocol:
126
131
  - `PeriodicChangeStream`: Sinusoidal oscillation
127
132
  - `DynamicScaleShiftStream`: Time-varying feature scales
128
133
 
134
+ ### TD Learning
135
+
136
+ For temporal-difference learning with value function approximation:
137
+
138
+ ```python
139
+ from alberta_framework import TDLinearLearner, TDIDBD, run_td_learning_loop
140
+
141
+ learner = TDLinearLearner(optimizer=TDIDBD(trace_decay=0.9))
142
+ state, metrics = run_td_learning_loop(learner, td_stream, num_steps=10000, key=jr.key(42))
143
+ ```
144
+
129
145
  ### Gymnasium Integration
130
146
 
131
147
  ```python
@@ -202,6 +218,13 @@ If you use this framework in your research, please cite:
202
218
  booktitle = {IEEE International Conference on Acoustics, Speech and Signal Processing},
203
219
  year = {2012}
204
220
  }
221
+
222
+ @inproceedings{kearney2019tidbd,
223
+ title = {Learning Feature Relevance Through Step Size Adaptation in Temporal-Difference Learning},
224
+ author = {Kearney, Alex and Veeriah, Vivek and Travnik, Jaden and Sutton, Richard S. and Pilarski, Patrick M.},
225
+ booktitle = {International Conference on Machine Learning},
226
+ year = {2019}
227
+ }
205
228
  ```
206
229
 
207
230
  ## License
@@ -0,0 +1,22 @@
1
+ alberta_framework/__init__.py,sha256=RB8-ly8UK6IGnDX8Qw3mW_uSJc8iEJT57CCXg6cxj4c,6451
2
+ alberta_framework/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
3
+ alberta_framework/core/__init__.py,sha256=wr7VZrbXP0SWw-_MEFn0rfn1wQrUaDkIm1tXFYhjz1k,961
4
+ alberta_framework/core/learners.py,sha256=lePBbeDReGAg73TvpLPMQNI2k5H6OQfECnvi2qTYT-I,45277
5
+ alberta_framework/core/normalizers.py,sha256=GZkmFbRI3lk7HvNMSX9ByOvXCj_3QX_5h_k6-Y35IqY,5893
6
+ alberta_framework/core/optimizers.py,sha256=9fdic6h-vxBm7BXox8QSLhtzEEkVc7VSHn3sufuIMZY,34588
7
+ alberta_framework/core/types.py,sha256=SZr16VQAemJCXIlBoOlp91rqDeWM5rd831Vya7G2ths,16430
8
+ alberta_framework/streams/__init__.py,sha256=bsDgWjWjotDQHMI2lno3dgk8N14pd-2mYAQpXAtCPx4,2035
9
+ alberta_framework/streams/base.py,sha256=9rJxvUgmzd5u2bRV4vi5PxhUvj39EZTD4bZHo-Ptn-U,2168
10
+ alberta_framework/streams/gymnasium.py,sha256=3Kg8qORzvNqRkTceQ7THfm3kp3_Skbva1XbtCDBTsT4,21914
11
+ alberta_framework/streams/synthetic.py,sha256=8e5EY3rtiJhdQbLlWyalNE3nRHhn_5T2Z_aHRS4BpG4,33457
12
+ alberta_framework/utils/__init__.py,sha256=zfKfnbikhLp0J6UgVa8HeRo59gZHwqOc8jf03s7AaT4,2845
13
+ alberta_framework/utils/experiments.py,sha256=vxcbCxUloWu2J2mKHjdkM6cLeY9EYIq1JvpR-gyfPwQ,10622
14
+ alberta_framework/utils/export.py,sha256=vGsBTFcr84Ga8Ka0IZFMVqhMUNOCQstVcfyU468V3Cs,15940
15
+ alberta_framework/utils/metrics.py,sha256=1cryNJoboO67vvRhausaucbYZFgdL_06vaf08UXbojg,3349
16
+ alberta_framework/utils/statistics.py,sha256=QZsDVVNqhiY6chGFLzYmtiUVBIHeBfr_LDTuSyBQROY,15594
17
+ alberta_framework/utils/timing.py,sha256=JOLq8CpCAV7LWOWkftxefduSFjaXnVwal1MFBKEMdJI,4049
18
+ alberta_framework/utils/visualization.py,sha256=aQc4PsWGFCycm0uPvaChFjaoWgBOsD7UHOydWF0WKFo,18070
19
+ alberta_framework-0.4.0.dist-info/METADATA,sha256=7jnUWrT9qFNC6UFZ6r9gxpXUsIO1MyeSsJQFNdBgc1g,8769
20
+ alberta_framework-0.4.0.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
21
+ alberta_framework-0.4.0.dist-info/licenses/LICENSE,sha256=TI1avodt5mvxz7sunyxIa0HlNgLQcmKNLeRjCVcgKmE,10754
22
+ alberta_framework-0.4.0.dist-info/RECORD,,
@@ -1,22 +0,0 @@
1
- alberta_framework/__init__.py,sha256=gAafDDmkivDdfnvDVff9zbVY9ilzqqfJ9KvpbRegKqs,5726
2
- alberta_framework/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
3
- alberta_framework/core/__init__.py,sha256=PSrC4zSxgm_6YXWEQ80aZaunpbQ58QexxKmDDU-jp6c,522
4
- alberta_framework/core/learners.py,sha256=gUhX7caXBfpWYgnvYTp5YKXfP6wbzB2T2gkSMMtrHDQ,38042
5
- alberta_framework/core/normalizers.py,sha256=QmKmha-mFgKi1KD-f8xuB2U175yQL6Ll0D4c8OONIl0,5927
6
- alberta_framework/core/optimizers.py,sha256=a4gYac5DyXReir9ycudRg8uQ9b53uLWTIldZ1A3Ae5c,14646
7
- alberta_framework/core/types.py,sha256=XBmT689nRKEBwwfUbpohi4IfT-d-eJRIFH_L2swYW2E,9793
8
- alberta_framework/streams/__init__.py,sha256=bsDgWjWjotDQHMI2lno3dgk8N14pd-2mYAQpXAtCPx4,2035
9
- alberta_framework/streams/base.py,sha256=9rJxvUgmzd5u2bRV4vi5PxhUvj39EZTD4bZHo-Ptn-U,2168
10
- alberta_framework/streams/gymnasium.py,sha256=s733X7aEgy05hcSazjZEhBiJChtEL7uVpxwh0fXBQZA,21980
11
- alberta_framework/streams/synthetic.py,sha256=8njzQCFRi_iVgdPA3slyn46vFIHHkIwaZsABZyPwqnU,33507
12
- alberta_framework/utils/__init__.py,sha256=zfKfnbikhLp0J6UgVa8HeRo59gZHwqOc8jf03s7AaT4,2845
13
- alberta_framework/utils/experiments.py,sha256=ekGAzveCRgv9YZ5mfAD5Uf7h_PvQnxsNw2KeZN2eu00,10644
14
- alberta_framework/utils/export.py,sha256=W9RKfeTiyZcLColOGNjBfZU0N6QMXrfPn4pdYcm-OSk,15832
15
- alberta_framework/utils/metrics.py,sha256=1cryNJoboO67vvRhausaucbYZFgdL_06vaf08UXbojg,3349
16
- alberta_framework/utils/statistics.py,sha256=4fbzNlmsdUaM5lLW1BhL5B5MUpnqimQlwJklZ4x0y0U,15416
17
- alberta_framework/utils/timing.py,sha256=JOLq8CpCAV7LWOWkftxefduSFjaXnVwal1MFBKEMdJI,4049
18
- alberta_framework/utils/visualization.py,sha256=PmKBD3KGabNhgDizcNiGJEbVCyDL1YMUE5yTwgJHu2o,17924
19
- alberta_framework-0.3.2.dist-info/METADATA,sha256=aD7q4wh1xm0pQiARtRnUrgLU83JQ8JBidzK-bXmn5_s,7872
20
- alberta_framework-0.3.2.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
21
- alberta_framework-0.3.2.dist-info/licenses/LICENSE,sha256=TI1avodt5mvxz7sunyxIa0HlNgLQcmKNLeRjCVcgKmE,10754
22
- alberta_framework-0.3.2.dist-info/RECORD,,