alberta-framework 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,995 @@
1
+ """Synthetic non-stationary experience streams for testing continual learning.
2
+
3
+ These streams generate non-stationary supervised learning problems where
4
+ the target function changes over time, testing the learner's ability to
5
+ track and adapt.
6
+
7
+ All streams use JAX-compatible pure functions that work with jax.lax.scan.
8
+ """
9
+
10
+ from typing import Any, NamedTuple
11
+
12
+ import jax.numpy as jnp
13
+ import jax.random as jr
14
+ from jax import Array
15
+
16
+ from alberta_framework.core.types import TimeStep
17
+ from alberta_framework.streams.base import ScanStream
18
+
19
+
20
class RandomWalkState(NamedTuple):
    """Carries the mutable state of a :class:`RandomWalkStream`.

    Attributes:
        key: PRNG key consumed to draw the next step's randomness
        true_weights: Latest drifted true target weight vector
    """

    key: Array
    true_weights: Array
30
+
31
+
32
class RandomWalkStream:
    """Supervised stream whose linear target drifts via a Gaussian random walk.

    Each step the true weight vector receives an independent zero-mean
    Gaussian increment, and the emitted target is y* = w_true @ x + noise.
    A learner therefore has to track a solution that never stops moving.

    Attributes:
        feature_dim: Dimension of observation vectors
        drift_rate: Standard deviation of weight drift per step
        noise_std: Standard deviation of observation noise
        feature_std: Standard deviation of features
    """

    def __init__(
        self,
        feature_dim: int,
        drift_rate: float = 0.001,
        noise_std: float = 0.1,
        feature_std: float = 1.0,
    ):
        """Set up the random walk target stream.

        Args:
            feature_dim: Dimension of the feature/observation vectors
            drift_rate: Std dev of per-step weight increments (degree of
                non-stationarity)
            noise_std: Std dev of target noise
            feature_std: Std dev of feature values
        """
        self._feature_dim = feature_dim
        self._drift_rate = drift_rate
        self._noise_std = noise_std
        self._feature_std = feature_std

    @property
    def feature_dim(self) -> int:
        """Dimension of the observation vectors."""
        return self._feature_dim

    def init(self, key: Array) -> RandomWalkState:
        """Create the initial stream state.

        Args:
            key: JAX random key

        Returns:
            State holding a freshly sampled standard-normal weight vector
        """
        key, weights_key = jr.split(key)
        initial_weights = jr.normal(weights_key, (self._feature_dim,), dtype=jnp.float32)
        return RandomWalkState(key=key, true_weights=initial_weights)

    def step(self, state: RandomWalkState, idx: Array) -> tuple[TimeStep, RandomWalkState]:
        """Advance the stream by one step.

        Args:
            state: Current stream state
            idx: Step index (ignored; all dynamics live in the state)

        Returns:
            Tuple of (timestep, new_state)
        """
        del idx  # unused
        key, drift_key, obs_key, noise_key = jr.split(state.key, 4)

        # Random-walk update of the true weights.
        increment = jr.normal(drift_key, state.true_weights.shape, dtype=jnp.float32)
        drifted_weights = state.true_weights + self._drift_rate * increment

        # Draw the observation and its noisy linear target.
        obs = self._feature_std * jr.normal(obs_key, (self._feature_dim,), dtype=jnp.float32)
        eps = self._noise_std * jr.normal(noise_key, (), dtype=jnp.float32)
        y_star = jnp.dot(drifted_weights, obs) + eps

        next_state = RandomWalkState(key=key, true_weights=drifted_weights)
        return TimeStep(observation=obs, target=jnp.atleast_1d(y_star)), next_state
111
+
112
+
113
class AbruptChangeState(NamedTuple):
    """Carries the mutable state of an :class:`AbruptChangeStream`.

    Attributes:
        key: PRNG key consumed to draw the next step's randomness
        true_weights: Weight vector currently generating targets
        step_count: Number of steps emitted so far
    """

    key: Array
    true_weights: Array
    step_count: Array
125
+
126
+
127
class AbruptChangeStream:
    """Non-stationary stream with sudden target weight changes.

    Target weights stay fixed for ``change_interval`` steps and are then
    resampled from a standard normal, producing abrupt distribution shifts.
    Tests the learner's ability to detect and rapidly adapt to such shifts.

    Attributes:
        feature_dim: Dimension of observation vectors
        change_interval: Number of steps between weight changes
        noise_std: Standard deviation of observation noise
        feature_std: Standard deviation of features
    """

    def __init__(
        self,
        feature_dim: int,
        change_interval: int = 1000,
        noise_std: float = 0.1,
        feature_std: float = 1.0,
    ):
        """Initialize the abrupt change stream.

        Args:
            feature_dim: Dimension of feature vectors
            change_interval: Steps between abrupt weight changes
            noise_std: Std dev of target noise
            feature_std: Std dev of feature values
        """
        self._feature_dim = feature_dim
        self._change_interval = change_interval
        self._noise_std = noise_std
        self._feature_std = feature_std

    @property
    def feature_dim(self) -> int:
        """Return the dimension of observation vectors."""
        return self._feature_dim

    def init(self, key: Array) -> AbruptChangeState:
        """Initialize stream state.

        Args:
            key: JAX random key

        Returns:
            Initial stream state with a freshly sampled weight vector
        """
        key, subkey = jr.split(key)
        weights = jr.normal(subkey, (self._feature_dim,), dtype=jnp.float32)
        return AbruptChangeState(
            key=key,
            true_weights=weights,
            step_count=jnp.array(0, dtype=jnp.int32),
        )

    def step(self, state: AbruptChangeState, idx: Array) -> tuple[TimeStep, AbruptChangeState]:
        """Generate one time step.

        Args:
            state: Current stream state
            idx: Current step index (unused)

        Returns:
            Tuple of (timestep, new_state)
        """
        del idx  # unused
        key, key_weights, key_x, key_noise = jr.split(state.key, 4)

        # Resample weights only at interval boundaries AFTER step 0.
        # Without the step_count > 0 guard, the weights drawn in init()
        # were discarded immediately on the first step; the guard matches
        # the one used by SuttonExperiment1Stream.
        should_change = (state.step_count > 0) & (
            state.step_count % self._change_interval == 0
        )

        # Candidate weights are always sampled but only applied when
        # should_change is true, keeping the step JIT/scan-compatible.
        new_random_weights = jr.normal(key_weights, (self._feature_dim,), dtype=jnp.float32)
        new_weights = jnp.where(should_change, new_random_weights, state.true_weights)

        # Observation and noisy linear target.
        x = self._feature_std * jr.normal(key_x, (self._feature_dim,), dtype=jnp.float32)
        noise = self._noise_std * jr.normal(key_noise, (), dtype=jnp.float32)
        target = jnp.dot(new_weights, x) + noise

        timestep = TimeStep(observation=x, target=jnp.atleast_1d(target))
        new_state = AbruptChangeState(
            key=key,
            true_weights=new_weights,
            step_count=state.step_count + 1,
        )

        return timestep, new_state
220
+
221
+
222
class SuttonExperiment1State(NamedTuple):
    """Carries the mutable state of a :class:`SuttonExperiment1Stream`.

    Attributes:
        key: PRNG key consumed to draw the next step's randomness
        signs: Current +1/-1 signs attached to the relevant inputs
        step_count: Number of steps emitted so far
    """

    key: Array
    signs: Array
    step_count: Array
234
+
235
+
236
class SuttonExperiment1Stream:
    """Replication of Experiment 1 from Sutton (1992), the IDBD task.

    The exact setup from the IDBD paper:
    - 20 real-valued inputs drawn from N(0, 1)
    - only the first 5 inputs matter (their weights are +1 or -1)
    - the remaining 15 inputs are irrelevant (weights of 0)
    - every ``change_interval`` steps, one of the 5 relevant signs flips

    Reference: Sutton, R.S. (1992). "Adapting Bias by Gradient Descent:
    An Incremental Version of Delta-Bar-Delta"

    Attributes:
        num_relevant: Number of relevant inputs (default 5)
        num_irrelevant: Number of irrelevant inputs (default 15)
        change_interval: Steps between sign changes (default 20)
    """

    def __init__(
        self,
        num_relevant: int = 5,
        num_irrelevant: int = 15,
        change_interval: int = 20,
    ):
        """Configure the stream.

        Args:
            num_relevant: Number of relevant inputs with ±1 weights
            num_irrelevant: Number of irrelevant inputs with 0 weights
            change_interval: Number of steps between sign flips
        """
        self._num_relevant = num_relevant
        self._num_irrelevant = num_irrelevant
        self._change_interval = change_interval

    @property
    def feature_dim(self) -> int:
        """Total input dimension (relevant + irrelevant)."""
        return self._num_relevant + self._num_irrelevant

    def init(self, key: Array) -> SuttonExperiment1State:
        """Build the initial state with every relevant sign set to +1.

        Args:
            key: JAX random key

        Returns:
            Initial stream state with all +1 signs
        """
        return SuttonExperiment1State(
            key=key,
            signs=jnp.ones(self._num_relevant, dtype=jnp.float32),
            step_count=jnp.array(0, dtype=jnp.int32),
        )

    def step(
        self, state: SuttonExperiment1State, idx: Array
    ) -> tuple[TimeStep, SuttonExperiment1State]:
        """Emit one example, flipping a random sign at interval boundaries.

        Per step:
        1. At a change boundary (but never at step 0), negate one
           randomly chosen relevant sign
        2. Draw all inputs from N(0, 1)
        3. The target is the sign-weighted sum of the relevant inputs

        Args:
            state: Current stream state
            idx: Current step index (unused)

        Returns:
            Tuple of (timestep, new_state)
        """
        del idx  # unused
        key, obs_key, pick_key = jr.split(state.key, 3)

        # Flip only at interval boundaries, and never on the very first step.
        flip_now = (state.step_count > 0) & (state.step_count % self._change_interval == 0)

        # A candidate index is drawn every step (JIT-friendly); the draw
        # is simply ignored when flip_now is false.
        candidate = jr.randint(pick_key, (), 0, self._num_relevant)
        negate_one = jnp.where(
            jnp.arange(self._num_relevant) == candidate,
            jnp.array(-1.0, dtype=jnp.float32),
            jnp.array(1.0, dtype=jnp.float32),
        )
        updated_signs = jnp.where(flip_now, state.signs * negate_one, state.signs)

        # Inputs ~ N(0, 1); only the first num_relevant affect the target.
        obs = jr.normal(obs_key, (self.feature_dim,), dtype=jnp.float32)
        y_star = jnp.dot(updated_signs, obs[: self._num_relevant])

        next_state = SuttonExperiment1State(
            key=key,
            signs=updated_signs,
            step_count=state.step_count + 1,
        )
        return TimeStep(observation=obs, target=jnp.atleast_1d(y_star)), next_state
342
+
343
+
344
class CyclicState(NamedTuple):
    """Carries the mutable state of a :class:`CyclicStream`.

    Attributes:
        key: PRNG key consumed to draw the next step's randomness
        configurations: Bank of weight configurations sampled at init time
        step_count: Number of steps emitted so far
    """

    key: Array
    configurations: Array
    step_count: Array
356
+
357
+
358
class CyclicStream:
    """Stream that rotates through a fixed bank of target weight vectors.

    A handful of weight configurations are drawn once at init time; the
    active configuration advances every ``cycle_length`` steps and wraps
    around, so previously seen targets recur. Measures how quickly a
    learner re-adapts to targets it has encountered before.

    Attributes:
        feature_dim: Dimension of observation vectors
        cycle_length: Number of steps per configuration before switching
        num_configurations: Number of weight configurations to cycle through
        noise_std: Standard deviation of observation noise
        feature_std: Standard deviation of features
    """

    def __init__(
        self,
        feature_dim: int,
        cycle_length: int = 500,
        num_configurations: int = 4,
        noise_std: float = 0.1,
        feature_std: float = 1.0,
    ):
        """Set up the cyclic target stream.

        Args:
            feature_dim: Dimension of feature vectors
            cycle_length: Steps spent in each configuration
            num_configurations: Number of configurations to cycle through
            noise_std: Std dev of target noise
            feature_std: Std dev of feature values
        """
        self._feature_dim = feature_dim
        self._cycle_length = cycle_length
        self._num_configurations = num_configurations
        self._noise_std = noise_std
        self._feature_std = feature_std

    @property
    def feature_dim(self) -> int:
        """Dimension of the observation vectors."""
        return self._feature_dim

    def init(self, key: Array) -> CyclicState:
        """Sample the configuration bank and build the initial state.

        Args:
            key: JAX random key

        Returns:
            State holding all pre-generated weight configurations
        """
        key, bank_key = jr.split(key)
        bank = jr.normal(
            bank_key,
            (self._num_configurations, self._feature_dim),
            dtype=jnp.float32,
        )
        return CyclicState(
            key=key,
            configurations=bank,
            step_count=jnp.array(0, dtype=jnp.int32),
        )

    def step(self, state: CyclicState, idx: Array) -> tuple[TimeStep, CyclicState]:
        """Emit one observation/target pair from the active configuration.

        Args:
            state: Current stream state
            idx: Current step index (unused)

        Returns:
            Tuple of (timestep, new_state)
        """
        del idx  # unused
        key, obs_key, noise_key = jr.split(state.key, 3)

        # Which configuration is active now (wraps around the bank).
        active = (state.step_count // self._cycle_length) % self._num_configurations
        active_weights = state.configurations[active]

        obs = self._feature_std * jr.normal(obs_key, (self._feature_dim,), dtype=jnp.float32)
        eps = self._noise_std * jr.normal(noise_key, (), dtype=jnp.float32)
        y_star = jnp.dot(active_weights, obs) + eps

        next_state = CyclicState(
            key=key,
            configurations=state.configurations,
            step_count=state.step_count + 1,
        )
        return TimeStep(observation=obs, target=jnp.atleast_1d(y_star)), next_state
453
+
454
+
455
class PeriodicChangeState(NamedTuple):
    """Carries the mutable state of a :class:`PeriodicChangeStream`.

    Attributes:
        key: PRNG key consumed to draw the next step's randomness
        base_weights: Center of each weight's oscillation
        phases: Per-weight phase offsets for the sinusoid
        step_count: Number of steps emitted so far
    """

    key: Array
    base_weights: Array
    phases: Array
    step_count: Array
469
+
470
+
471
class PeriodicChangeStream:
    """Stream whose target weights oscillate sinusoidally around a base.

    Weights follow w(t) = base + amplitude * sin(2π * t / period + phase),
    with an independent random phase per weight so components do not move
    in lockstep. Unlike random drift or abrupt shifts, this form of
    non-stationarity is fully predictable and periodic.

    Attributes:
        feature_dim: Dimension of observation vectors
        period: Number of steps for one complete oscillation
        amplitude: Magnitude of weight oscillation
        noise_std: Standard deviation of observation noise
        feature_std: Standard deviation of features
    """

    def __init__(
        self,
        feature_dim: int,
        period: int = 1000,
        amplitude: float = 1.0,
        noise_std: float = 0.1,
        feature_std: float = 1.0,
    ):
        """Set up the periodic change stream.

        Args:
            feature_dim: Dimension of feature vectors
            period: Steps for one complete oscillation cycle
            amplitude: Magnitude of weight oscillations around base
            noise_std: Std dev of target noise
            feature_std: Std dev of feature values
        """
        self._feature_dim = feature_dim
        self._period = period
        self._amplitude = amplitude
        self._noise_std = noise_std
        self._feature_std = feature_std

    @property
    def feature_dim(self) -> int:
        """Dimension of the observation vectors."""
        return self._feature_dim

    def init(self, key: Array) -> PeriodicChangeState:
        """Sample base weights and per-weight phases.

        Args:
            key: JAX random key

        Returns:
            Initial stream state with random base weights and phases
        """
        key, base_key, phase_key = jr.split(key, 3)
        base_weights = jr.normal(base_key, (self._feature_dim,), dtype=jnp.float32)
        # One random phase in [0, 2π) per weight.
        phases = jr.uniform(phase_key, (self._feature_dim,), minval=0.0, maxval=2.0 * jnp.pi)
        return PeriodicChangeState(
            key=key,
            base_weights=base_weights,
            phases=phases,
            step_count=jnp.array(0, dtype=jnp.int32),
        )

    def step(
        self, state: PeriodicChangeState, idx: Array
    ) -> tuple[TimeStep, PeriodicChangeState]:
        """Emit one observation/target pair from the oscillating weights.

        Args:
            state: Current stream state
            idx: Current step index (unused)

        Returns:
            Tuple of (timestep, new_state)
        """
        del idx  # unused
        key, obs_key, noise_key = jr.split(state.key, 3)

        # w(t) = base + amplitude * sin(2π * t / period + phase)
        t = state.step_count.astype(jnp.float32)
        wave = self._amplitude * jnp.sin(
            2.0 * jnp.pi * t / self._period + state.phases
        )
        current_weights = state.base_weights + wave

        obs = self._feature_std * jr.normal(obs_key, (self._feature_dim,), dtype=jnp.float32)
        eps = self._noise_std * jr.normal(noise_key, (), dtype=jnp.float32)
        y_star = jnp.dot(current_weights, obs) + eps

        next_state = PeriodicChangeState(
            key=key,
            base_weights=state.base_weights,
            phases=state.phases,
            step_count=state.step_count + 1,
        )
        return TimeStep(observation=obs, target=jnp.atleast_1d(y_star)), next_state
574
+
575
+
576
class ScaledStreamState(NamedTuple):
    """Carries the state of a :class:`ScaledStreamWrapper`.

    Attributes:
        inner_state: Opaque state of the wrapped stream
    """

    # Generic state produced by the wrapped stream (itself a NamedTuple)
    inner_state: tuple[Any, ...]
584
+
585
+
586
class ScaledStreamWrapper:
    """Wrap any stream and multiply its observations by fixed per-feature scales.

    Useful for studying how learners cope with features living on very
    different scales (e.g. quantifying the benefit of normalization).
    Targets pass through untouched; only observations are rescaled.

    Example:
        >>> stream = ScaledStreamWrapper(
        ...     AbruptChangeStream(feature_dim=10, change_interval=1000),
        ...     feature_scales=jnp.array([0.001, 0.01, 0.1, 1.0, 10.0,
        ...                               100.0, 1000.0, 0.001, 0.01, 0.1])
        ... )

    Attributes:
        inner_stream: The wrapped stream instance
        feature_scales: Per-feature scale factors (must match feature_dim)
    """

    def __init__(self, inner_stream: ScanStream[Any], feature_scales: Array):
        """Build the wrapper around ``inner_stream``.

        Args:
            inner_stream: Stream to wrap (must implement ScanStream protocol)
            feature_scales: Scale factor per feature, shape (feature_dim,),
                matching the inner stream's feature_dim.

        Raises:
            ValueError: If feature_scales length doesn't match inner stream's feature_dim
        """
        self._inner_stream: ScanStream[Any] = inner_stream
        self._feature_scales = jnp.asarray(feature_scales, dtype=jnp.float32)

        if self._feature_scales.shape[0] != inner_stream.feature_dim:
            raise ValueError(
                f"feature_scales length ({self._feature_scales.shape[0]}) "
                f"must match inner stream's feature_dim ({inner_stream.feature_dim})"
            )

    @property
    def feature_dim(self) -> int:
        """Dimension of observation vectors (same as the wrapped stream)."""
        return int(self._inner_stream.feature_dim)

    @property
    def inner_stream(self) -> ScanStream[Any]:
        """The wrapped stream."""
        return self._inner_stream

    @property
    def feature_scales(self) -> Array:
        """The per-feature scale factors."""
        return self._feature_scales

    def init(self, key: Array) -> ScaledStreamState:
        """Initialize by delegating to the wrapped stream.

        Args:
            key: JAX random key

        Returns:
            State wrapping the inner stream's state
        """
        return ScaledStreamState(inner_state=self._inner_stream.init(key))

    def step(self, state: ScaledStreamState, idx: Array) -> tuple[TimeStep, ScaledStreamState]:
        """Delegate one step to the inner stream, then rescale the observation.

        Args:
            state: Current stream state
            idx: Current step index

        Returns:
            Tuple of (timestep with scaled observation, new_state)
        """
        inner_ts, next_inner = self._inner_stream.step(state.inner_state, idx)

        rescaled = TimeStep(
            observation=inner_ts.observation * self._feature_scales,
            target=inner_ts.target,
        )
        return rescaled, ScaledStreamState(inner_state=next_inner)
674
+
675
+
676
def make_scale_range(
    feature_dim: int,
    min_scale: float = 0.001,
    max_scale: float = 1000.0,
    log_spaced: bool = True,
) -> Array:
    """Create a per-feature scale array spanning a range.

    Utility function to generate scale factors for ScaledStreamWrapper.

    Args:
        feature_dim: Number of features
        min_scale: Minimum scale factor
        max_scale: Maximum scale factor
        log_spaced: If True, scales are logarithmically spaced (default).
            If False, scales are linearly spaced.

    Returns:
        Array of shape (feature_dim,) with scale factors

    Raises:
        ValueError: If log_spaced is True and either bound is not strictly
            positive (log10 of a non-positive bound would silently produce
            NaN scales).

    Example:
        >>> scales = make_scale_range(10, min_scale=0.01, max_scale=100.0)
        >>> stream = ScaledStreamWrapper(RandomWalkStream(10), scales)
    """
    if log_spaced:
        # Validate up front: jnp.log10 of a non-positive value returns
        # nan/-inf instead of raising, which would poison the scales.
        if min_scale <= 0.0 or max_scale <= 0.0:
            raise ValueError(
                "log-spaced scales require strictly positive bounds, got "
                f"min_scale={min_scale}, max_scale={max_scale}"
            )
        return jnp.logspace(
            jnp.log10(min_scale),
            jnp.log10(max_scale),
            feature_dim,
            dtype=jnp.float32,
        )
    return jnp.linspace(min_scale, max_scale, feature_dim, dtype=jnp.float32)
709
+
710
+
711
class DynamicScaleShiftState(NamedTuple):
    """Carries the mutable state of a :class:`DynamicScaleShiftStream`.

    Attributes:
        key: PRNG key consumed to draw the next step's randomness
        true_weights: Weight vector currently generating targets
        current_scales: Per-feature scaling factors currently applied
        step_count: Number of steps emitted so far
    """

    key: Array
    true_weights: Array
    current_scales: Array
    step_count: Array
725
+
726
+
727
class DynamicScaleShiftStream:
    """Non-stationary stream with abruptly changing feature scales.

    Both target weights AND per-feature scales are resampled at their own
    intervals. This tests whether OnlineNormalizer can track scale shifts
    faster than Autostep's internal v_i adaptation.

    The target is computed from the UNSCALED features to maintain a
    consistent prediction difficulty across scale changes (only the
    feature representation shifts, not the underlying task).

    Attributes:
        feature_dim: Dimension of observation vectors
        scale_change_interval: Steps between scale changes
        weight_change_interval: Steps between weight changes
        min_scale: Minimum scale factor
        max_scale: Maximum scale factor
        noise_std: Standard deviation of observation noise
    """

    def __init__(
        self,
        feature_dim: int,
        scale_change_interval: int = 2000,
        weight_change_interval: int = 1000,
        min_scale: float = 0.01,
        max_scale: float = 100.0,
        noise_std: float = 0.1,
    ):
        """Initialize the dynamic scale shift stream.

        Args:
            feature_dim: Dimension of feature vectors
            scale_change_interval: Steps between abrupt scale changes
            weight_change_interval: Steps between abrupt weight changes
            min_scale: Minimum scale factor (log-uniform sampling)
            max_scale: Maximum scale factor (log-uniform sampling)
            noise_std: Std dev of target noise
        """
        self._feature_dim = feature_dim
        self._scale_change_interval = scale_change_interval
        self._weight_change_interval = weight_change_interval
        self._min_scale = min_scale
        self._max_scale = max_scale
        self._noise_std = noise_std

    @property
    def feature_dim(self) -> int:
        """Return the dimension of observation vectors."""
        return self._feature_dim

    def _sample_scales(self, key: Array) -> Array:
        """Draw per-feature scales log-uniformly from [min_scale, max_scale]."""
        log_scales = jr.uniform(
            key,
            (self._feature_dim,),
            minval=jnp.log(self._min_scale),
            maxval=jnp.log(self._max_scale),
        )
        return jnp.exp(log_scales).astype(jnp.float32)

    def init(self, key: Array) -> DynamicScaleShiftState:
        """Initialize stream state.

        Args:
            key: JAX random key

        Returns:
            Initial stream state with random weights and scales
        """
        key, k_weights, k_scales = jr.split(key, 3)
        weights = jr.normal(k_weights, (self._feature_dim,), dtype=jnp.float32)
        scales = self._sample_scales(k_scales)
        return DynamicScaleShiftState(
            key=key,
            true_weights=weights,
            current_scales=scales,
            step_count=jnp.array(0, dtype=jnp.int32),
        )

    def step(
        self, state: DynamicScaleShiftState, idx: Array
    ) -> tuple[TimeStep, DynamicScaleShiftState]:
        """Generate one time step.

        Args:
            state: Current stream state
            idx: Current step index (unused)

        Returns:
            Tuple of (timestep, new_state)
        """
        del idx  # unused
        key, k_weights, k_scales, k_x, k_noise = jr.split(state.key, 5)

        # Resample only at interval boundaries AFTER step 0; without the
        # step_count > 0 guard, the scales/weights drawn in init() were
        # discarded immediately on the first step (cf. the guard in
        # SuttonExperiment1Stream).
        past_start = state.step_count > 0
        should_change_scales = past_start & (
            state.step_count % self._scale_change_interval == 0
        )
        should_change_weights = past_start & (
            state.step_count % self._weight_change_interval == 0
        )

        # Candidates are always sampled (JIT/scan friendly) but only
        # applied when the corresponding boundary is hit.
        candidate_scales = self._sample_scales(k_scales)
        new_scales = jnp.where(should_change_scales, candidate_scales, state.current_scales)

        candidate_weights = jr.normal(k_weights, (self._feature_dim,), dtype=jnp.float32)
        new_weights = jnp.where(should_change_weights, candidate_weights, state.true_weights)

        # The observation is the scaled view of the raw features...
        raw_x = jr.normal(k_x, (self._feature_dim,), dtype=jnp.float32)
        x = raw_x * new_scales

        # ...but the target uses RAW features so task difficulty is constant.
        noise = self._noise_std * jr.normal(k_noise, (), dtype=jnp.float32)
        target = jnp.dot(new_weights, raw_x) + noise

        timestep = TimeStep(observation=x, target=jnp.atleast_1d(target))
        new_state = DynamicScaleShiftState(
            key=key,
            true_weights=new_weights,
            current_scales=new_scales,
            step_count=state.step_count + 1,
        )
        return timestep, new_state
853
+
854
+
855
class ScaleDriftState(NamedTuple):
    """Carries the mutable state of a :class:`ScaleDriftStream`.

    Attributes:
        key: PRNG key consumed to draw the next step's randomness
        true_weights: Weight vector currently generating targets
        log_scales: Current log-space scale factors (bounded random walk)
        step_count: Number of steps emitted so far
    """

    key: Array
    true_weights: Array
    log_scales: Array
    step_count: Array
869
+
870
+
871
class ScaleDriftStream:
    """Stream whose feature scales and target weights both drift continuously.

    Weights perform a random walk in linear space; per-feature scales
    perform a clipped random walk in log-space. Exercises continuous
    scale tracking, where OnlineNormalizer's EMA may adapt differently
    than Autostep's v_i.

    The target is computed from the UNSCALED features so the prediction
    problem keeps a constant difficulty — only the observed
    representation changes.

    Attributes:
        feature_dim: Dimension of observation vectors
        weight_drift_rate: Std dev of weight drift per step
        scale_drift_rate: Std dev of log-scale drift per step
        min_log_scale: Minimum log-scale (clips random walk)
        max_log_scale: Maximum log-scale (clips random walk)
        noise_std: Standard deviation of observation noise
    """

    def __init__(
        self,
        feature_dim: int,
        weight_drift_rate: float = 0.001,
        scale_drift_rate: float = 0.01,
        min_log_scale: float = -4.0,  # exp(-4) ~ 0.018
        max_log_scale: float = 4.0,  # exp(4) ~ 54.6
        noise_std: float = 0.1,
    ):
        """Set up the scale drift stream.

        Args:
            feature_dim: Dimension of feature vectors
            weight_drift_rate: Std dev of weight drift per step
            scale_drift_rate: Std dev of log-scale drift per step
            min_log_scale: Minimum log-scale (clips drift)
            max_log_scale: Maximum log-scale (clips drift)
            noise_std: Std dev of target noise
        """
        self._feature_dim = feature_dim
        self._weight_drift_rate = weight_drift_rate
        self._scale_drift_rate = scale_drift_rate
        self._min_log_scale = min_log_scale
        self._max_log_scale = max_log_scale
        self._noise_std = noise_std

    @property
    def feature_dim(self) -> int:
        """Dimension of the observation vectors."""
        return self._feature_dim

    def init(self, key: Array) -> ScaleDriftState:
        """Build the initial state: random weights, all scales at 1.

        Args:
            key: JAX random key

        Returns:
            Initial stream state with random weights and unit scales
        """
        key, weights_key = jr.split(key)
        initial_weights = jr.normal(weights_key, (self._feature_dim,), dtype=jnp.float32)
        # log-scale 0 everywhere means every feature starts at scale 1.
        return ScaleDriftState(
            key=key,
            true_weights=initial_weights,
            log_scales=jnp.zeros(self._feature_dim, dtype=jnp.float32),
            step_count=jnp.array(0, dtype=jnp.int32),
        )

    def step(
        self, state: ScaleDriftState, idx: Array
    ) -> tuple[TimeStep, ScaleDriftState]:
        """Advance the stream by one step.

        Args:
            state: Current stream state
            idx: Current step index (unused)

        Returns:
            Tuple of (timestep, new_state)
        """
        del idx  # unused
        key, w_key, s_key, obs_key, noise_key = jr.split(state.key, 5)

        # Weights: unbounded random walk in linear space.
        w_step = self._weight_drift_rate * jr.normal(
            w_key, (self._feature_dim,), dtype=jnp.float32
        )
        next_weights = state.true_weights + w_step

        # Scales: random walk in log-space, clipped so they stay bounded.
        s_step = self._scale_drift_rate * jr.normal(
            s_key, (self._feature_dim,), dtype=jnp.float32
        )
        next_log_scales = jnp.clip(
            state.log_scales + s_step, self._min_log_scale, self._max_log_scale
        )

        # Observation is the scaled view of the raw features...
        raw = jr.normal(obs_key, (self._feature_dim,), dtype=jnp.float32)
        obs = raw * jnp.exp(next_log_scales)

        # ...while the target uses the RAW features.
        eps = self._noise_std * jr.normal(noise_key, (), dtype=jnp.float32)
        y_star = jnp.dot(next_weights, raw) + eps

        next_state = ScaleDriftState(
            key=key,
            true_weights=next_weights,
            log_scales=next_log_scales,
            step_count=state.step_count + 1,
        )
        return TimeStep(observation=obs, target=jnp.atleast_1d(y_star)), next_state
989
+
990
+
991
# Backward-compatible aliases: earlier releases exported these streams
# under "*Target" names; keep the old names importable for existing callers.
RandomWalkTarget = RandomWalkStream
AbruptChangeTarget = AbruptChangeStream
CyclicTarget = CyclicStream
PeriodicChangeTarget = PeriodicChangeStream