alberta-framework 0.2.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,1001 @@
1
+ """Synthetic non-stationary experience streams for testing continual learning.
2
+
3
+ These streams generate non-stationary supervised learning problems where
4
+ the target function changes over time, testing the learner's ability to
5
+ track and adapt.
6
+
7
+ All streams use JAX-compatible pure functions that work with jax.lax.scan.
8
+ """
9
+
10
+ from typing import Any, NamedTuple
11
+
12
+ import jax.numpy as jnp
13
+ import jax.random as jr
14
+ from jax import Array
15
+
16
+ from alberta_framework.core.types import TimeStep
17
+ from alberta_framework.streams.base import ScanStream
18
+
19
+
20
class RandomWalkState(NamedTuple):
    """Carried state for :class:`RandomWalkStream`.

    Attributes:
        key: PRNG key consumed to produce the next step's randomness
        true_weights: Target weights as of the most recent step
    """

    key: Array
    true_weights: Array
30
+
31
+
32
class RandomWalkStream:
    """Non-stationary stream whose target weights drift via random walk.

    The underlying target function is linear, ``y* = w_true @ x + noise``,
    and ``w_true`` takes an independent Gaussian step of scale ``drift_rate``
    at every time step, so the learner must continuously track a target that
    never settles.

    Attributes:
        feature_dim: Dimension of observation vectors
        drift_rate: Standard deviation of weight drift per step
        noise_std: Standard deviation of observation noise
        feature_std: Standard deviation of features
    """

    def __init__(
        self,
        feature_dim: int,
        drift_rate: float = 0.001,
        noise_std: float = 0.1,
        feature_std: float = 1.0,
    ):
        """Set up the random walk target stream.

        Args:
            feature_dim: Dimension of the feature/observation vectors
            drift_rate: Std dev of weight changes per step (controls non-stationarity)
            noise_std: Std dev of target noise
            feature_std: Std dev of feature values
        """
        self._feature_dim = feature_dim
        self._drift_rate = drift_rate
        self._noise_std = noise_std
        self._feature_std = feature_std

    @property
    def feature_dim(self) -> int:
        """Dimension of observation vectors."""
        return self._feature_dim

    def init(self, key: Array) -> RandomWalkState:
        """Create the initial stream state.

        Args:
            key: JAX random key

        Returns:
            State holding a fresh random weight vector
        """
        key, w_key = jr.split(key)
        w0 = jr.normal(w_key, (self._feature_dim,), dtype=jnp.float32)
        return RandomWalkState(key=key, true_weights=w0)

    def step(self, state: RandomWalkState, idx: Array) -> tuple[TimeStep, RandomWalkState]:
        """Produce one (observation, target) pair and the successor state.

        Args:
            state: Current stream state
            idx: Step index (ignored; present for scan compatibility)

        Returns:
            Tuple of (timestep, new_state)
        """
        del idx  # unused
        next_key, k_drift, k_obs, k_eps = jr.split(state.key, 4)

        # Random-walk the target weights.
        step_noise = jr.normal(k_drift, state.true_weights.shape, dtype=jnp.float32)
        drifted = state.true_weights + self._drift_rate * step_noise

        # Sample features and the noisy scalar target.
        obs = self._feature_std * jr.normal(k_obs, (self._feature_dim,), dtype=jnp.float32)
        eps = self._noise_std * jr.normal(k_eps, (), dtype=jnp.float32)
        y = jnp.dot(drifted, obs) + eps

        return (
            TimeStep(observation=obs, target=jnp.atleast_1d(y)),
            RandomWalkState(key=next_key, true_weights=drifted),
        )
111
+
112
+
113
class AbruptChangeState(NamedTuple):
    """Carried state for :class:`AbruptChangeStream`.

    Attributes:
        key: PRNG key consumed to produce the next step's randomness
        true_weights: Target weights currently in effect
        step_count: Number of steps generated so far
    """

    key: Array
    true_weights: Array
    step_count: Array
125
+
126
+
127
class AbruptChangeStream:
    """Non-stationary stream with sudden target weight changes.

    Target weights remain constant for a period, then abruptly change
    to new random values. Tests the learner's ability to detect and
    rapidly adapt to distribution shifts.

    Attributes:
        feature_dim: Dimension of observation vectors
        change_interval: Number of steps between weight changes
        noise_std: Standard deviation of observation noise
        feature_std: Standard deviation of features
    """

    def __init__(
        self,
        feature_dim: int,
        change_interval: int = 1000,
        noise_std: float = 0.1,
        feature_std: float = 1.0,
    ):
        """Initialize the abrupt change stream.

        Args:
            feature_dim: Dimension of feature vectors
            change_interval: Steps between abrupt weight changes
            noise_std: Std dev of target noise
            feature_std: Std dev of feature values
        """
        self._feature_dim = feature_dim
        self._change_interval = change_interval
        self._noise_std = noise_std
        self._feature_std = feature_std

    @property
    def feature_dim(self) -> int:
        """Return the dimension of observation vectors."""
        return self._feature_dim

    def init(self, key: Array) -> AbruptChangeState:
        """Initialize stream state.

        Args:
            key: JAX random key

        Returns:
            Initial stream state with random weights
        """
        key, subkey = jr.split(key)
        weights = jr.normal(subkey, (self._feature_dim,), dtype=jnp.float32)
        return AbruptChangeState(
            key=key,
            true_weights=weights,
            step_count=jnp.array(0, dtype=jnp.int32),
        )

    def step(self, state: AbruptChangeState, idx: Array) -> tuple[TimeStep, AbruptChangeState]:
        """Generate one time step.

        Args:
            state: Current stream state
            idx: Current step index (unused)

        Returns:
            Tuple of (timestep, new_state)
        """
        del idx  # unused
        key, key_weights, key_x, key_noise = jr.split(state.key, 4)

        # Change weights only at interval boundaries AFTER the first step.
        # FIX: the original condition (step_count % interval == 0) was also
        # true at step 0, which discarded the weights sampled in init()
        # immediately; the guard matches SuttonExperiment1Stream's behavior.
        should_change = (state.step_count > 0) & (
            state.step_count % self._change_interval == 0
        )

        # New weights are always generated (fixed RNG consumption keeps the
        # step JIT/scan-friendly) but only adopted when should_change is true.
        new_random_weights = jr.normal(key_weights, (self._feature_dim,), dtype=jnp.float32)
        new_weights = jnp.where(should_change, new_random_weights, state.true_weights)

        # Generate observation
        x = self._feature_std * jr.normal(key_x, (self._feature_dim,), dtype=jnp.float32)

        # Compute noisy linear target
        noise = self._noise_std * jr.normal(key_noise, (), dtype=jnp.float32)
        target = jnp.dot(new_weights, x) + noise

        timestep = TimeStep(observation=x, target=jnp.atleast_1d(target))
        new_state = AbruptChangeState(
            key=key,
            true_weights=new_weights,
            step_count=state.step_count + 1,
        )

        return timestep, new_state
220
+
221
+
222
class SuttonExperiment1State(NamedTuple):
    """Carried state for :class:`SuttonExperiment1Stream`.

    Attributes:
        key: PRNG key consumed to produce the next step's randomness
        signs: Current +1/-1 signs of the relevant inputs
        step_count: Number of steps generated so far
    """

    key: Array
    signs: Array
    step_count: Array
234
+
235
+
236
class SuttonExperiment1Stream:
    """Non-stationary stream replicating Experiment 1 from Sutton 1992.

    Implements the exact task from Sutton's IDBD paper:
    - 20 real-valued inputs drawn from N(0, 1)
    - only the first ``num_relevant`` inputs matter (weights are ±1)
    - the remaining inputs are irrelevant (weights are 0)
    - every ``change_interval`` steps one relevant sign is flipped

    Reference: Sutton, R.S. (1992). "Adapting Bias by Gradient Descent:
    An Incremental Version of Delta-Bar-Delta"

    Attributes:
        num_relevant: Number of relevant inputs (default 5)
        num_irrelevant: Number of irrelevant inputs (default 15)
        change_interval: Steps between sign changes (default 20)
    """

    def __init__(
        self,
        num_relevant: int = 5,
        num_irrelevant: int = 15,
        change_interval: int = 20,
    ):
        """Configure the Sutton Experiment 1 stream.

        Args:
            num_relevant: Number of relevant inputs with ±1 weights
            num_irrelevant: Number of irrelevant inputs with 0 weights
            change_interval: Number of steps between sign flips
        """
        self._num_relevant = num_relevant
        self._num_irrelevant = num_irrelevant
        self._change_interval = change_interval

    @property
    def feature_dim(self) -> int:
        """Total number of inputs (relevant + irrelevant)."""
        return self._num_relevant + self._num_irrelevant

    def init(self, key: Array) -> SuttonExperiment1State:
        """Create the initial state; all relevant signs start at +1.

        Args:
            key: JAX random key

        Returns:
            Initial stream state with all +1 signs
        """
        return SuttonExperiment1State(
            key=key,
            signs=jnp.ones(self._num_relevant, dtype=jnp.float32),
            step_count=jnp.array(0, dtype=jnp.int32),
        )

    def step(
        self, state: SuttonExperiment1State, idx: Array
    ) -> tuple[TimeStep, SuttonExperiment1State]:
        """Produce one time step.

        Per step: possibly flip one randomly chosen relevant sign (only at
        change-interval boundaries after step 0), draw all inputs from
        N(0, 1), and emit the target as the sign-weighted sum of the
        relevant inputs.

        Args:
            state: Current stream state
            idx: Step index (ignored)

        Returns:
            Tuple of (timestep, new_state)
        """
        del idx  # unused
        next_key, k_inputs, k_choice = jr.split(state.key, 3)

        # Flip only at interval boundaries, never on the very first step.
        at_boundary = state.step_count % self._change_interval == 0
        do_flip = (state.step_count > 0) & at_boundary

        # Candidate mask with exactly one -1 at a uniformly chosen position.
        chosen = jr.randint(k_choice, (), 0, self._num_relevant)
        flip_mask = (
            jnp.ones(self._num_relevant, dtype=jnp.float32).at[chosen].set(-1.0)
        )

        # Adopt the flipped signs only when a flip is due (JIT-compatible).
        updated_signs = jnp.where(do_flip, state.signs * flip_mask, state.signs)

        # All inputs are standard normal; only the first block is relevant.
        inputs = jr.normal(k_inputs, (self.feature_dim,), dtype=jnp.float32)
        y = jnp.dot(updated_signs, inputs[: self._num_relevant])

        return (
            TimeStep(observation=inputs, target=jnp.atleast_1d(y)),
            SuttonExperiment1State(
                key=next_key,
                signs=updated_signs,
                step_count=state.step_count + 1,
            ),
        )
342
+
343
+
344
class CyclicState(NamedTuple):
    """Carried state for :class:`CyclicStream`.

    Attributes:
        key: PRNG key consumed to produce the next step's randomness
        configurations: Fixed bank of pre-generated weight configurations
        step_count: Number of steps generated so far
    """

    key: Array
    configurations: Array
    step_count: Array
356
+
357
+
358
class CyclicStream:
    """Non-stationary stream that cycles between known weight configurations.

    Target weights rotate through a fixed, pre-generated bank of
    configurations, spending ``cycle_length`` steps on each before moving to
    the next (wrapping around). Tests whether the learner can re-adapt
    quickly to previously seen targets.

    Attributes:
        feature_dim: Dimension of observation vectors
        cycle_length: Number of steps per configuration before switching
        num_configurations: Number of weight configurations to cycle through
        noise_std: Standard deviation of observation noise
        feature_std: Standard deviation of features
    """

    def __init__(
        self,
        feature_dim: int,
        cycle_length: int = 500,
        num_configurations: int = 4,
        noise_std: float = 0.1,
        feature_std: float = 1.0,
    ):
        """Configure the cyclic target stream.

        Args:
            feature_dim: Dimension of feature vectors
            cycle_length: Steps spent in each configuration
            num_configurations: Number of configurations to cycle through
            noise_std: Std dev of target noise
            feature_std: Std dev of feature values
        """
        self._feature_dim = feature_dim
        self._cycle_length = cycle_length
        self._num_configurations = num_configurations
        self._noise_std = noise_std
        self._feature_std = feature_std

    @property
    def feature_dim(self) -> int:
        """Dimension of observation vectors."""
        return self._feature_dim

    def init(self, key: Array) -> CyclicState:
        """Create the initial state with a pre-generated configuration bank.

        Args:
            key: JAX random key

        Returns:
            Initial stream state with pre-generated configurations
        """
        key, k_cfg = jr.split(key)
        bank = jr.normal(
            k_cfg,
            (self._num_configurations, self._feature_dim),
            dtype=jnp.float32,
        )
        return CyclicState(
            key=key,
            configurations=bank,
            step_count=jnp.array(0, dtype=jnp.int32),
        )

    def step(self, state: CyclicState, idx: Array) -> tuple[TimeStep, CyclicState]:
        """Produce one time step using the currently active configuration.

        Args:
            state: Current stream state
            idx: Step index (ignored)

        Returns:
            Tuple of (timestep, new_state)
        """
        del idx  # unused
        next_key, k_feat, k_eps = jr.split(state.key, 3)

        # Which configuration is active at this step (wraps around the bank).
        active = (state.step_count // self._cycle_length) % self._num_configurations
        weights = jnp.take(state.configurations, active, axis=0)

        # Features and noisy linear target under the active weights.
        obs = self._feature_std * jr.normal(k_feat, (self._feature_dim,), dtype=jnp.float32)
        eps = self._noise_std * jr.normal(k_eps, (), dtype=jnp.float32)
        y = jnp.dot(weights, obs) + eps

        return (
            TimeStep(observation=obs, target=jnp.atleast_1d(y)),
            CyclicState(
                key=next_key,
                configurations=state.configurations,
                step_count=state.step_count + 1,
            ),
        )
453
+
454
+
455
class PeriodicChangeState(NamedTuple):
    """Carried state for :class:`PeriodicChangeStream`.

    Attributes:
        key: PRNG key consumed to produce the next step's randomness
        base_weights: Center of each weight's oscillation
        phases: Per-weight phase offsets in [0, 2π)
        step_count: Number of steps generated so far
    """

    key: Array
    base_weights: Array
    phases: Array
    step_count: Array
469
+
470
+
471
class PeriodicChangeStream:
    """Non-stationary stream where target weights oscillate sinusoidally.

    Target weights follow ``w(t) = base + amplitude * sin(2π * t / period + phase)``,
    each weight carrying its own random phase offset for diversity.

    This tests the learner's ability to track predictable periodic changes,
    which is qualitatively different from random drift or abrupt changes.

    Attributes:
        feature_dim: Dimension of observation vectors
        period: Number of steps for one complete oscillation
        amplitude: Magnitude of weight oscillation
        noise_std: Standard deviation of observation noise
        feature_std: Standard deviation of features
    """

    def __init__(
        self,
        feature_dim: int,
        period: int = 1000,
        amplitude: float = 1.0,
        noise_std: float = 0.1,
        feature_std: float = 1.0,
    ):
        """Configure the periodic change stream.

        Args:
            feature_dim: Dimension of feature vectors
            period: Steps for one complete oscillation cycle
            amplitude: Magnitude of weight oscillations around base
            noise_std: Std dev of target noise
            feature_std: Std dev of feature values
        """
        self._feature_dim = feature_dim
        self._period = period
        self._amplitude = amplitude
        self._noise_std = noise_std
        self._feature_std = feature_std

    @property
    def feature_dim(self) -> int:
        """Dimension of observation vectors."""
        return self._feature_dim

    def init(self, key: Array) -> PeriodicChangeState:
        """Create the initial state with random base weights and phases.

        Args:
            key: JAX random key

        Returns:
            Initial stream state with random base weights and phases
        """
        key, k_base, k_phase = jr.split(key, 3)
        base = jr.normal(k_base, (self._feature_dim,), dtype=jnp.float32)
        # Each weight gets an independent phase offset in [0, 2π).
        offsets = jr.uniform(k_phase, (self._feature_dim,), minval=0.0, maxval=2.0 * jnp.pi)
        return PeriodicChangeState(
            key=key,
            base_weights=base,
            phases=offsets,
            step_count=jnp.array(0, dtype=jnp.int32),
        )

    def step(
        self, state: PeriodicChangeState, idx: Array
    ) -> tuple[TimeStep, PeriodicChangeState]:
        """Produce one time step with the oscillating weights at time t.

        Args:
            state: Current stream state
            idx: Step index (ignored)

        Returns:
            Tuple of (timestep, new_state)
        """
        del idx  # unused
        next_key, k_feat, k_eps = jr.split(state.key, 3)

        # w(t) = base + amplitude * sin(2π * t / period + phase)
        t = state.step_count.astype(jnp.float32)
        angle = 2.0 * jnp.pi * t / self._period + state.phases
        weights = state.base_weights + self._amplitude * jnp.sin(angle)

        # Features and noisy linear target under the oscillating weights.
        obs = self._feature_std * jr.normal(k_feat, (self._feature_dim,), dtype=jnp.float32)
        eps = self._noise_std * jr.normal(k_eps, (), dtype=jnp.float32)
        y = jnp.dot(weights, obs) + eps

        return (
            TimeStep(observation=obs, target=jnp.atleast_1d(y)),
            PeriodicChangeState(
                key=next_key,
                base_weights=state.base_weights,
                phases=state.phases,
                step_count=state.step_count + 1,
            ),
        )
574
+
575
+
576
class ScaledStreamState(NamedTuple):
    """Carried state for :class:`ScaledStreamWrapper`.

    Attributes:
        inner_state: Opaque state of the wrapped stream, passed through
            untouched between steps
    """

    inner_state: tuple[Any, ...]  # Generic state from wrapped stream
584
+
585
+
586
class ScaledStreamWrapper:
    """Wrapper that applies per-feature scaling to any stream's observations.

    Each feature of the wrapped stream's observation is multiplied by a
    corresponding scale factor; targets pass through unchanged. Useful for
    testing how learners handle features at different scales, which is
    important for understanding normalization benefits.

    Examples
    --------
    ```python
    stream = ScaledStreamWrapper(
        AbruptChangeStream(feature_dim=10, change_interval=1000),
        feature_scales=jnp.array([0.001, 0.01, 0.1, 1.0, 10.0,
                                  100.0, 1000.0, 0.001, 0.01, 0.1])
    )
    ```

    Attributes:
        inner_stream: The wrapped stream instance
        feature_scales: Per-feature scale factors (must match feature_dim)
    """

    def __init__(self, inner_stream: ScanStream[Any], feature_scales: Array):
        """Wrap a stream with per-feature observation scaling.

        Args:
            inner_stream: Stream to wrap (must implement ScanStream protocol)
            feature_scales: Array of scale factors, one per feature. Must have
                shape (feature_dim,) matching the inner stream's feature_dim.

        Raises:
            ValueError: If feature_scales length doesn't match inner stream's feature_dim
        """
        self._inner_stream: ScanStream[Any] = inner_stream
        self._feature_scales = jnp.asarray(feature_scales, dtype=jnp.float32)

        if self._feature_scales.shape[0] != inner_stream.feature_dim:
            raise ValueError(
                f"feature_scales length ({self._feature_scales.shape[0]}) "
                f"must match inner stream's feature_dim ({inner_stream.feature_dim})"
            )

    @property
    def feature_dim(self) -> int:
        """Dimension of observation vectors (same as the wrapped stream's)."""
        return int(self._inner_stream.feature_dim)

    @property
    def inner_stream(self) -> ScanStream[Any]:
        """The wrapped stream instance."""
        return self._inner_stream

    @property
    def feature_scales(self) -> Array:
        """The per-feature scale factors."""
        return self._feature_scales

    def init(self, key: Array) -> ScaledStreamState:
        """Initialize by delegating to the wrapped stream.

        Args:
            key: JAX random key

        Returns:
            Initial stream state wrapping the inner stream's state
        """
        return ScaledStreamState(inner_state=self._inner_stream.init(key))

    def step(self, state: ScaledStreamState, idx: Array) -> tuple[TimeStep, ScaledStreamState]:
        """Step the wrapped stream and scale its observation.

        Args:
            state: Current stream state
            idx: Current step index (forwarded to the inner stream)

        Returns:
            Tuple of (timestep with scaled observation, new_state)
        """
        inner_ts, advanced = self._inner_stream.step(state.inner_state, idx)

        # Only the observation is rescaled; the target passes through as-is.
        rescaled = TimeStep(
            observation=inner_ts.observation * self._feature_scales,
            target=inner_ts.target,
        )

        return rescaled, ScaledStreamState(inner_state=advanced)
677
+
678
+
679
def make_scale_range(
    feature_dim: int,
    min_scale: float = 0.001,
    max_scale: float = 1000.0,
    log_spaced: bool = True,
) -> Array:
    """Create a per-feature scale array spanning [min_scale, max_scale].

    Utility for generating scale factors to feed into ScaledStreamWrapper.

    Args:
        feature_dim: Number of features
        min_scale: Minimum scale factor
        max_scale: Maximum scale factor
        log_spaced: If True, scales are logarithmically spaced (default).
            If False, scales are linearly spaced.

    Returns:
        Array of shape (feature_dim,) with scale factors

    Examples
    --------
    ```python
    scales = make_scale_range(10, min_scale=0.01, max_scale=100.0)
    stream = ScaledStreamWrapper(RandomWalkStream(10), scales)
    ```
    """
    # Guard clause: linear spacing is the simple case.
    if not log_spaced:
        return jnp.linspace(min_scale, max_scale, feature_dim, dtype=jnp.float32)

    # Logarithmic spacing: logspace works in base-10 exponents.
    lo_exp = jnp.log10(min_scale)
    hi_exp = jnp.log10(max_scale)
    return jnp.logspace(lo_exp, hi_exp, feature_dim, dtype=jnp.float32)
715
+
716
+
717
class DynamicScaleShiftState(NamedTuple):
    """Carried state for :class:`DynamicScaleShiftStream`.

    Attributes:
        key: PRNG key consumed to produce the next step's randomness
        true_weights: Target weights currently in effect
        current_scales: Per-feature scaling factors currently in effect
        step_count: Number of steps generated so far
    """

    key: Array
    true_weights: Array
    current_scales: Array
    step_count: Array
731
+
732
+
733
class DynamicScaleShiftStream:
    """Non-stationary stream with abruptly changing feature scales.

    Both target weights AND feature scales change at specified intervals.
    This tests whether OnlineNormalizer can track scale shifts faster
    than Autostep's internal v_i adaptation.

    The target is computed from unscaled features to maintain consistent
    difficulty across scale changes (only the feature representation changes,
    not the underlying prediction task).

    Attributes:
        feature_dim: Dimension of observation vectors
        scale_change_interval: Steps between scale changes
        weight_change_interval: Steps between weight changes
        min_scale: Minimum scale factor
        max_scale: Maximum scale factor
        noise_std: Standard deviation of observation noise
    """

    def __init__(
        self,
        feature_dim: int,
        scale_change_interval: int = 2000,
        weight_change_interval: int = 1000,
        min_scale: float = 0.01,
        max_scale: float = 100.0,
        noise_std: float = 0.1,
    ):
        """Initialize the dynamic scale shift stream.

        Args:
            feature_dim: Dimension of feature vectors
            scale_change_interval: Steps between abrupt scale changes
            weight_change_interval: Steps between abrupt weight changes
            min_scale: Minimum scale factor (log-uniform sampling)
            max_scale: Maximum scale factor (log-uniform sampling)
            noise_std: Std dev of target noise
        """
        self._feature_dim = feature_dim
        self._scale_change_interval = scale_change_interval
        self._weight_change_interval = weight_change_interval
        self._min_scale = min_scale
        self._max_scale = max_scale
        self._noise_std = noise_std

    @property
    def feature_dim(self) -> int:
        """Return the dimension of observation vectors."""
        return self._feature_dim

    def init(self, key: Array) -> DynamicScaleShiftState:
        """Initialize stream state.

        Args:
            key: JAX random key

        Returns:
            Initial stream state with random weights and scales
        """
        key, k_weights, k_scales = jr.split(key, 3)
        weights = jr.normal(k_weights, (self._feature_dim,), dtype=jnp.float32)
        # Initial scales: log-uniform between min and max
        log_scales = jr.uniform(
            k_scales,
            (self._feature_dim,),
            minval=jnp.log(self._min_scale),
            maxval=jnp.log(self._max_scale),
        )
        scales = jnp.exp(log_scales).astype(jnp.float32)
        return DynamicScaleShiftState(
            key=key,
            true_weights=weights,
            current_scales=scales,
            step_count=jnp.array(0, dtype=jnp.int32),
        )

    def step(
        self, state: DynamicScaleShiftState, idx: Array
    ) -> tuple[TimeStep, DynamicScaleShiftState]:
        """Generate one time step.

        Args:
            state: Current stream state
            idx: Current step index (unused)

        Returns:
            Tuple of (timestep, new_state)
        """
        del idx  # unused
        key, k_weights, k_scales, k_x, k_noise = jr.split(state.key, 5)

        # Change scales/weights only at interval boundaries AFTER the first
        # step. FIX: the original conditions (step_count % interval == 0)
        # were also true at step 0, which discarded the scales and weights
        # sampled in init() immediately; the guard matches
        # SuttonExperiment1Stream's behavior.
        # Candidates are always sampled (fixed RNG consumption keeps the
        # step JIT/scan-friendly) but adopted only when due.
        should_change_scales = (state.step_count > 0) & (
            state.step_count % self._scale_change_interval == 0
        )
        new_log_scales = jr.uniform(
            k_scales,
            (self._feature_dim,),
            minval=jnp.log(self._min_scale),
            maxval=jnp.log(self._max_scale),
        )
        new_random_scales = jnp.exp(new_log_scales).astype(jnp.float32)
        new_scales = jnp.where(should_change_scales, new_random_scales, state.current_scales)

        should_change_weights = (state.step_count > 0) & (
            state.step_count % self._weight_change_interval == 0
        )
        new_random_weights = jr.normal(k_weights, (self._feature_dim,), dtype=jnp.float32)
        new_weights = jnp.where(should_change_weights, new_random_weights, state.true_weights)

        # Generate raw features (unscaled)
        raw_x = jr.normal(k_x, (self._feature_dim,), dtype=jnp.float32)

        # Apply scaling to observation
        x = raw_x * new_scales

        # Target from true weights using RAW features (for consistent difficulty)
        noise = self._noise_std * jr.normal(k_noise, (), dtype=jnp.float32)
        target = jnp.dot(new_weights, raw_x) + noise

        timestep = TimeStep(observation=x, target=jnp.atleast_1d(target))
        new_state = DynamicScaleShiftState(
            key=key,
            true_weights=new_weights,
            current_scales=new_scales,
            step_count=state.step_count + 1,
        )
        return timestep, new_state
859
+
860
+
861
class ScaleDriftState(NamedTuple):
    """Carried state for :class:`ScaleDriftStream`.

    Attributes:
        key: PRNG key consumed to produce the next step's randomness
        true_weights: Target weights currently in effect
        log_scales: Current log-scale factors (random walk lives in log-space)
        step_count: Number of steps generated so far
    """

    key: Array
    true_weights: Array
    log_scales: Array
    step_count: Array
875
+
876
+
877
class ScaleDriftStream:
    """Non-stationary stream where feature scales drift via random walk.

    Both target weights and feature scales drift continuously: weights in
    linear space, scales in log-space as a bounded (clipped) random walk.
    This tests continuous scale tracking where OnlineNormalizer's EMA
    may adapt differently than Autostep's v_i.

    The target is computed from unscaled features to maintain consistent
    difficulty across scale changes.

    Attributes:
        feature_dim: Dimension of observation vectors
        weight_drift_rate: Std dev of weight drift per step
        scale_drift_rate: Std dev of log-scale drift per step
        min_log_scale: Minimum log-scale (clips random walk)
        max_log_scale: Maximum log-scale (clips random walk)
        noise_std: Standard deviation of observation noise
    """

    def __init__(
        self,
        feature_dim: int,
        weight_drift_rate: float = 0.001,
        scale_drift_rate: float = 0.01,
        min_log_scale: float = -4.0,  # exp(-4) ~ 0.018
        max_log_scale: float = 4.0,  # exp(4) ~ 54.6
        noise_std: float = 0.1,
    ):
        """Configure the scale drift stream.

        Args:
            feature_dim: Dimension of feature vectors
            weight_drift_rate: Std dev of weight drift per step
            scale_drift_rate: Std dev of log-scale drift per step
            min_log_scale: Minimum log-scale (clips drift)
            max_log_scale: Maximum log-scale (clips drift)
            noise_std: Std dev of target noise
        """
        self._feature_dim = feature_dim
        self._weight_drift_rate = weight_drift_rate
        self._scale_drift_rate = scale_drift_rate
        self._min_log_scale = min_log_scale
        self._max_log_scale = max_log_scale
        self._noise_std = noise_std

    @property
    def feature_dim(self) -> int:
        """Dimension of observation vectors."""
        return self._feature_dim

    def init(self, key: Array) -> ScaleDriftState:
        """Create the initial state: random weights, unit scales.

        Args:
            key: JAX random key

        Returns:
            Initial stream state with random weights and unit scales
        """
        key, k_w = jr.split(key)
        w0 = jr.normal(k_w, (self._feature_dim,), dtype=jnp.float32)
        # log-scale 0 everywhere <=> all scales start at 1.
        return ScaleDriftState(
            key=key,
            true_weights=w0,
            log_scales=jnp.zeros(self._feature_dim, dtype=jnp.float32),
            step_count=jnp.array(0, dtype=jnp.int32),
        )

    def step(
        self, state: ScaleDriftState, idx: Array
    ) -> tuple[TimeStep, ScaleDriftState]:
        """Produce one time step with drifted weights and scales.

        Args:
            state: Current stream state
            idx: Step index (ignored)

        Returns:
            Tuple of (timestep, new_state)
        """
        del idx  # unused
        next_key, k_wd, k_sd, k_feat, k_eps = jr.split(state.key, 5)

        # Random-walk the target weights in linear space.
        drifted_w = state.true_weights + self._weight_drift_rate * jr.normal(
            k_wd, (self._feature_dim,), dtype=jnp.float32
        )

        # Random-walk the scales in log-space, clipped to stay bounded.
        walked = state.log_scales + self._scale_drift_rate * jr.normal(
            k_sd, (self._feature_dim,), dtype=jnp.float32
        )
        bounded = jnp.clip(walked, self._min_log_scale, self._max_log_scale)

        # The observation is scaled, but the target uses the RAW features so
        # task difficulty is unaffected by the scale drift.
        raw = jr.normal(k_feat, (self._feature_dim,), dtype=jnp.float32)
        obs = raw * jnp.exp(bounded)
        eps = self._noise_std * jr.normal(k_eps, (), dtype=jnp.float32)
        y = jnp.dot(drifted_w, raw) + eps

        return (
            TimeStep(observation=obs, target=jnp.atleast_1d(y)),
            ScaleDriftState(
                key=next_key,
                true_weights=drifted_w,
                log_scales=bounded,
                step_count=state.step_count + 1,
            ),
        )
995
+
996
+
997
# Backward-compatible aliases: earlier releases exported these classes under
# "*Target" names — presumably pre-0.2 API; keep both spellings importable.
RandomWalkTarget = RandomWalkStream
AbruptChangeTarget = AbruptChangeStream
CyclicTarget = CyclicStream
PeriodicChangeTarget = PeriodicChangeStream