alberta-framework 0.2.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- alberta_framework/__init__.py +225 -0
- alberta_framework/core/__init__.py +27 -0
- alberta_framework/core/learners.py +1070 -0
- alberta_framework/core/normalizers.py +192 -0
- alberta_framework/core/optimizers.py +424 -0
- alberta_framework/core/types.py +271 -0
- alberta_framework/py.typed +0 -0
- alberta_framework/streams/__init__.py +83 -0
- alberta_framework/streams/base.py +73 -0
- alberta_framework/streams/gymnasium.py +655 -0
- alberta_framework/streams/synthetic.py +1001 -0
- alberta_framework/utils/__init__.py +113 -0
- alberta_framework/utils/experiments.py +335 -0
- alberta_framework/utils/export.py +509 -0
- alberta_framework/utils/metrics.py +112 -0
- alberta_framework/utils/statistics.py +527 -0
- alberta_framework/utils/timing.py +144 -0
- alberta_framework/utils/visualization.py +571 -0
- alberta_framework-0.2.2.dist-info/METADATA +206 -0
- alberta_framework-0.2.2.dist-info/RECORD +22 -0
- alberta_framework-0.2.2.dist-info/WHEEL +4 -0
- alberta_framework-0.2.2.dist-info/licenses/LICENSE +190 -0
|
@@ -0,0 +1,1001 @@
|
|
|
1
|
+
"""Synthetic non-stationary experience streams for testing continual learning.
|
|
2
|
+
|
|
3
|
+
These streams generate non-stationary supervised learning problems where
|
|
4
|
+
the target function changes over time, testing the learner's ability to
|
|
5
|
+
track and adapt.
|
|
6
|
+
|
|
7
|
+
All streams use JAX-compatible pure functions that work with jax.lax.scan.
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
from typing import Any, NamedTuple
|
|
11
|
+
|
|
12
|
+
import jax.numpy as jnp
|
|
13
|
+
import jax.random as jr
|
|
14
|
+
from jax import Array
|
|
15
|
+
|
|
16
|
+
from alberta_framework.core.types import TimeStep
|
|
17
|
+
from alberta_framework.streams.base import ScanStream
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class RandomWalkState(NamedTuple):
    """Carried state of a :class:`RandomWalkStream`.

    Attributes:
        key: JAX PRNG key consumed to produce the next step's randomness
        true_weights: Current weights of the drifting linear target
    """

    key: Array
    true_weights: Array
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class RandomWalkStream:
    """Supervised stream whose linear target drifts by a Gaussian random walk.

    At every step the true weight vector receives an independent Gaussian
    perturbation, after which the target is emitted as
    ``y* = w_true @ x + noise``. A learner facing this stream must track a
    target that never stops moving.

    Attributes:
        feature_dim: Dimension of observation vectors
        drift_rate: Standard deviation of weight drift per step
        noise_std: Standard deviation of observation noise
        feature_std: Standard deviation of features
    """

    def __init__(
        self,
        feature_dim: int,
        drift_rate: float = 0.001,
        noise_std: float = 0.1,
        feature_std: float = 1.0,
    ):
        """Set up the drifting-target stream.

        Args:
            feature_dim: Dimension of the feature/observation vectors
            drift_rate: Std dev of weight changes per step (controls non-stationarity)
            noise_std: Std dev of target noise
            feature_std: Std dev of feature values
        """
        self._feature_dim = feature_dim
        self._drift_rate = drift_rate
        self._noise_std = noise_std
        self._feature_std = feature_std

    @property
    def feature_dim(self) -> int:
        """Dimension of the observation vectors this stream emits."""
        return self._feature_dim

    def init(self, key: Array) -> RandomWalkState:
        """Create the initial stream state.

        Args:
            key: JAX random key

        Returns:
            State holding a fresh key and normally-distributed true weights
        """
        key, weight_key = jr.split(key)
        initial_weights = jr.normal(weight_key, (self._feature_dim,), dtype=jnp.float32)
        return RandomWalkState(key=key, true_weights=initial_weights)

    def step(self, state: RandomWalkState, idx: Array) -> tuple[TimeStep, RandomWalkState]:
        """Advance the stream by one time step.

        Args:
            state: Current stream state
            idx: Current step index (ignored; present for scan compatibility)

        Returns:
            Tuple of (timestep, new_state)
        """
        del idx  # unused
        key, drift_key, obs_key, noise_key = jr.split(state.key, 4)

        # Perturb the true weights with an isotropic Gaussian step.
        perturbation = jr.normal(drift_key, state.true_weights.shape, dtype=jnp.float32)
        drifted_weights = state.true_weights + self._drift_rate * perturbation

        # Draw the observation and the noisy linear target.
        observation = self._feature_std * jr.normal(
            obs_key, (self._feature_dim,), dtype=jnp.float32
        )
        target_noise = self._noise_std * jr.normal(noise_key, (), dtype=jnp.float32)
        label = jnp.dot(drifted_weights, observation) + target_noise

        return (
            TimeStep(observation=observation, target=jnp.atleast_1d(label)),
            RandomWalkState(key=key, true_weights=drifted_weights),
        )
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
class AbruptChangeState(NamedTuple):
    """Carried state of an :class:`AbruptChangeStream`.

    Attributes:
        key: JAX PRNG key consumed to produce the next step's randomness
        true_weights: Weights of the currently active linear target
        step_count: Number of steps generated so far
    """

    key: Array
    true_weights: Array
    step_count: Array
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
class AbruptChangeStream:
    """Non-stationary stream with sudden target weight changes.

    The target is linear, ``y = w_true @ x + noise``. The weight vector is
    held fixed for ``change_interval`` steps and then redrawn from a standard
    normal, producing abrupt distribution shifts that the learner must detect
    and rapidly adapt to.

    Attributes:
        feature_dim: Dimension of observation vectors
        change_interval: Number of steps between weight changes
        noise_std: Standard deviation of observation noise
        feature_std: Standard deviation of features
    """

    def __init__(
        self,
        feature_dim: int,
        change_interval: int = 1000,
        noise_std: float = 0.1,
        feature_std: float = 1.0,
    ):
        """Initialize the abrupt change stream.

        Args:
            feature_dim: Dimension of feature vectors
            change_interval: Steps between abrupt weight changes
            noise_std: Std dev of target noise
            feature_std: Std dev of feature values
        """
        self._feature_dim = feature_dim
        self._change_interval = change_interval
        self._noise_std = noise_std
        self._feature_std = feature_std

    @property
    def feature_dim(self) -> int:
        """Return the dimension of observation vectors."""
        return self._feature_dim

    def init(self, key: Array) -> AbruptChangeState:
        """Initialize stream state.

        Args:
            key: JAX random key

        Returns:
            Initial stream state with normally-distributed weights
        """
        key, subkey = jr.split(key)
        weights = jr.normal(subkey, (self._feature_dim,), dtype=jnp.float32)
        return AbruptChangeState(
            key=key,
            true_weights=weights,
            step_count=jnp.array(0, dtype=jnp.int32),
        )

    def step(self, state: AbruptChangeState, idx: Array) -> tuple[TimeStep, AbruptChangeState]:
        """Generate one time step.

        Args:
            state: Current stream state
            idx: Current step index (unused)

        Returns:
            Tuple of (timestep, new_state)
        """
        del idx  # unused
        key, key_weights, key_x, key_noise = jr.split(state.key, 4)

        # Resample weights only at interval boundaries, and never at step 0:
        # without the `step_count > 0` guard, the weights drawn in init() were
        # discarded before ever producing a sample. SuttonExperiment1Stream
        # already applies the same guard for its sign flips.
        should_change = (state.step_count > 0) & (
            state.step_count % self._change_interval == 0
        )

        # Candidate weights are always drawn (keeps RNG usage static for
        # jit/scan); jnp.where decides whether they take effect.
        new_random_weights = jr.normal(key_weights, (self._feature_dim,), dtype=jnp.float32)
        new_weights = jnp.where(should_change, new_random_weights, state.true_weights)

        # Generate observation and noisy linear target.
        x = self._feature_std * jr.normal(key_x, (self._feature_dim,), dtype=jnp.float32)
        noise = self._noise_std * jr.normal(key_noise, (), dtype=jnp.float32)
        target = jnp.dot(new_weights, x) + noise

        timestep = TimeStep(observation=x, target=jnp.atleast_1d(target))
        new_state = AbruptChangeState(
            key=key,
            true_weights=new_weights,
            step_count=state.step_count + 1,
        )

        return timestep, new_state
|
|
220
|
+
|
|
221
|
+
|
|
222
|
+
class SuttonExperiment1State(NamedTuple):
    """Carried state of a :class:`SuttonExperiment1Stream`.

    Attributes:
        key: JAX PRNG key consumed to produce the next step's randomness
        signs: Current +1/-1 weights of the relevant inputs
        step_count: Number of steps generated so far
    """

    key: Array
    signs: Array
    step_count: Array
|
|
234
|
+
|
|
235
|
+
|
|
236
|
+
class SuttonExperiment1Stream:
    """Replication of Experiment 1 from Sutton's 1992 IDBD paper.

    The task as specified in the paper:
    - 20 real-valued inputs drawn i.i.d. from N(0, 1)
    - only the first ``num_relevant`` inputs matter (weights are +1 or -1)
    - the remaining ``num_irrelevant`` inputs carry weight 0
    - every ``change_interval`` steps one relevant weight's sign is flipped
      uniformly at random

    Reference: Sutton, R.S. (1992). "Adapting Bias by Gradient Descent:
    An Incremental Version of Delta-Bar-Delta"

    Attributes:
        num_relevant: Number of relevant inputs (default 5)
        num_irrelevant: Number of irrelevant inputs (default 15)
        change_interval: Steps between sign changes (default 20)
    """

    def __init__(
        self,
        num_relevant: int = 5,
        num_irrelevant: int = 15,
        change_interval: int = 20,
    ):
        """Configure the task dimensions and change schedule.

        Args:
            num_relevant: Number of relevant inputs with ±1 weights
            num_irrelevant: Number of irrelevant inputs with 0 weights
            change_interval: Number of steps between sign flips
        """
        self._num_relevant = num_relevant
        self._num_irrelevant = num_irrelevant
        self._change_interval = change_interval

    @property
    def feature_dim(self) -> int:
        """Total input dimension (relevant plus irrelevant)."""
        return self._num_relevant + self._num_irrelevant

    def init(self, key: Array) -> SuttonExperiment1State:
        """Create the initial stream state.

        Args:
            key: JAX random key

        Returns:
            Initial state with every relevant sign set to +1
        """
        return SuttonExperiment1State(
            key=key,
            signs=jnp.ones(self._num_relevant, dtype=jnp.float32),
            step_count=jnp.array(0, dtype=jnp.int32),
        )

    def step(
        self, state: SuttonExperiment1State, idx: Array
    ) -> tuple[TimeStep, SuttonExperiment1State]:
        """Advance the stream by one time step.

        Per step: possibly flip one random relevant sign (only at interval
        boundaries after step 0), draw N(0, 1) inputs, and emit the target
        as the sign-weighted sum of the relevant inputs.

        Args:
            state: Current stream state
            idx: Current step index (ignored)

        Returns:
            Tuple of (timestep, new_state)
        """
        del idx  # unused
        key, obs_key, flip_key = jr.split(state.key, 3)

        # Flip only at interval boundaries, never on the very first step.
        should_flip = (state.step_count > 0) & (
            state.step_count % self._change_interval == 0
        )

        # Choose the candidate sign to flip; applied only if should_flip.
        chosen = jr.randint(flip_key, (), 0, self._num_relevant)
        sign_multiplier = jnp.where(
            jnp.arange(self._num_relevant) == chosen,
            jnp.array(-1.0, dtype=jnp.float32),
            jnp.array(1.0, dtype=jnp.float32),
        )
        updated_signs = jnp.where(should_flip, state.signs * sign_multiplier, state.signs)

        # All inputs are standard normal; only the leading block is relevant.
        inputs = jr.normal(obs_key, (self.feature_dim,), dtype=jnp.float32)
        label = jnp.dot(updated_signs, inputs[: self._num_relevant])

        return (
            TimeStep(observation=inputs, target=jnp.atleast_1d(label)),
            SuttonExperiment1State(
                key=key,
                signs=updated_signs,
                step_count=state.step_count + 1,
            ),
        )
|
|
342
|
+
|
|
343
|
+
|
|
344
|
+
class CyclicState(NamedTuple):
    """Carried state of a :class:`CyclicStream`.

    Attributes:
        key: JAX PRNG key consumed to produce the next step's randomness
        configurations: Fixed bank of weight vectors cycled through
        step_count: Number of steps generated so far
    """

    key: Array
    configurations: Array
    step_count: Array
|
|
356
|
+
|
|
357
|
+
|
|
358
|
+
class CyclicStream:
    """Stream that rotates through a fixed bank of target weight vectors.

    A set of weight configurations is drawn once at init time; the active
    configuration advances every ``cycle_length`` steps and wraps around, so
    every target reappears periodically. This probes how quickly a learner
    re-adapts to targets it has already encountered.

    Attributes:
        feature_dim: Dimension of observation vectors
        cycle_length: Number of steps per configuration before switching
        num_configurations: Number of weight configurations to cycle through
        noise_std: Standard deviation of observation noise
        feature_std: Standard deviation of features
    """

    def __init__(
        self,
        feature_dim: int,
        cycle_length: int = 500,
        num_configurations: int = 4,
        noise_std: float = 0.1,
        feature_std: float = 1.0,
    ):
        """Configure the cyclic target stream.

        Args:
            feature_dim: Dimension of feature vectors
            cycle_length: Steps spent in each configuration
            num_configurations: Number of configurations to cycle through
            noise_std: Std dev of target noise
            feature_std: Std dev of feature values
        """
        self._feature_dim = feature_dim
        self._cycle_length = cycle_length
        self._num_configurations = num_configurations
        self._noise_std = noise_std
        self._feature_std = feature_std

    @property
    def feature_dim(self) -> int:
        """Dimension of the observation vectors this stream emits."""
        return self._feature_dim

    def init(self, key: Array) -> CyclicState:
        """Create the initial stream state.

        Args:
            key: JAX random key

        Returns:
            Initial state holding the pre-generated configuration bank
        """
        key, bank_key = jr.split(key)
        bank = jr.normal(
            bank_key,
            (self._num_configurations, self._feature_dim),
            dtype=jnp.float32,
        )
        return CyclicState(
            key=key,
            configurations=bank,
            step_count=jnp.array(0, dtype=jnp.int32),
        )

    def step(self, state: CyclicState, idx: Array) -> tuple[TimeStep, CyclicState]:
        """Advance the stream by one time step.

        Args:
            state: Current stream state
            idx: Current step index (ignored)

        Returns:
            Tuple of (timestep, new_state)
        """
        del idx  # unused
        key, obs_key, noise_key = jr.split(state.key, 3)

        # Select the configuration active for this step.
        active = (state.step_count // self._cycle_length) % self._num_configurations
        active_weights = state.configurations[active]

        # Observation and noisy linear target under the active configuration.
        obs = self._feature_std * jr.normal(obs_key, (self._feature_dim,), dtype=jnp.float32)
        label = jnp.dot(active_weights, obs) + self._noise_std * jr.normal(
            noise_key, (), dtype=jnp.float32
        )

        return (
            TimeStep(observation=obs, target=jnp.atleast_1d(label)),
            CyclicState(
                key=key,
                configurations=state.configurations,
                step_count=state.step_count + 1,
            ),
        )
|
|
453
|
+
|
|
454
|
+
|
|
455
|
+
class PeriodicChangeState(NamedTuple):
    """Carried state of a :class:`PeriodicChangeStream`.

    Attributes:
        key: JAX PRNG key consumed to produce the next step's randomness
        base_weights: Center of the sinusoidal weight oscillation
        phases: Per-weight phase offsets in radians
        step_count: Number of steps generated so far
    """

    key: Array
    base_weights: Array
    phases: Array
    step_count: Array
|
|
469
|
+
|
|
470
|
+
|
|
471
|
+
class PeriodicChangeStream:
    """Stream whose target weights oscillate sinusoidally over time.

    Weights follow ``w(t) = base + amplitude * sin(2π * t / period + phase)``,
    with an independent random phase per weight. Unlike random drift or
    abrupt jumps, the non-stationarity here is smooth and predictable, which
    probes a qualitatively different tracking ability.

    Attributes:
        feature_dim: Dimension of observation vectors
        period: Number of steps for one complete oscillation
        amplitude: Magnitude of weight oscillation
        noise_std: Standard deviation of observation noise
        feature_std: Standard deviation of features
    """

    def __init__(
        self,
        feature_dim: int,
        period: int = 1000,
        amplitude: float = 1.0,
        noise_std: float = 0.1,
        feature_std: float = 1.0,
    ):
        """Configure the periodic change stream.

        Args:
            feature_dim: Dimension of feature vectors
            period: Steps for one complete oscillation cycle
            amplitude: Magnitude of weight oscillations around base
            noise_std: Std dev of target noise
            feature_std: Std dev of feature values
        """
        self._feature_dim = feature_dim
        self._period = period
        self._amplitude = amplitude
        self._noise_std = noise_std
        self._feature_std = feature_std

    @property
    def feature_dim(self) -> int:
        """Dimension of the observation vectors this stream emits."""
        return self._feature_dim

    def init(self, key: Array) -> PeriodicChangeState:
        """Create the initial stream state.

        Args:
            key: JAX random key

        Returns:
            Initial state with random base weights and per-weight phases
        """
        key, base_key, phase_key = jr.split(key, 3)
        base = jr.normal(base_key, (self._feature_dim,), dtype=jnp.float32)
        # Each weight gets its own phase offset in [0, 2π).
        phase_offsets = jr.uniform(
            phase_key, (self._feature_dim,), minval=0.0, maxval=2.0 * jnp.pi
        )
        return PeriodicChangeState(
            key=key,
            base_weights=base,
            phases=phase_offsets,
            step_count=jnp.array(0, dtype=jnp.int32),
        )

    def step(
        self, state: PeriodicChangeState, idx: Array
    ) -> tuple[TimeStep, PeriodicChangeState]:
        """Advance the stream by one time step.

        Args:
            state: Current stream state
            idx: Current step index (ignored)

        Returns:
            Tuple of (timestep, new_state)
        """
        del idx  # unused
        key, obs_key, noise_key = jr.split(state.key, 3)

        # Current weights: w(t) = base + amplitude * sin(2π * t / period + phase)
        t = state.step_count.astype(jnp.float32)
        wave = self._amplitude * jnp.sin(2.0 * jnp.pi * t / self._period + state.phases)
        current_weights = state.base_weights + wave

        # Observation and noisy linear target.
        obs = self._feature_std * jr.normal(obs_key, (self._feature_dim,), dtype=jnp.float32)
        label = jnp.dot(current_weights, obs) + self._noise_std * jr.normal(
            noise_key, (), dtype=jnp.float32
        )

        return (
            TimeStep(observation=obs, target=jnp.atleast_1d(label)),
            PeriodicChangeState(
                key=key,
                base_weights=state.base_weights,
                phases=state.phases,
                step_count=state.step_count + 1,
            ),
        )
|
|
574
|
+
|
|
575
|
+
|
|
576
|
+
class ScaledStreamState(NamedTuple):
    """Carried state of a :class:`ScaledStreamWrapper`.

    Attributes:
        inner_state: Opaque state of the wrapped stream, passed through
            unchanged to the inner stream's ``step``
    """

    inner_state: tuple[Any, ...]  # Generic state from wrapped stream
|
|
584
|
+
|
|
585
|
+
|
|
586
|
+
class ScaledStreamWrapper:
    """Wraps any stream and multiplies each observation feature by a scale.

    Each component of the inner stream's observation is multiplied by the
    corresponding entry of ``feature_scales``; targets pass through
    untouched. This makes it easy to study how learners cope with features
    living at very different scales, e.g. to quantify normalization benefits.

    Examples
    --------
    ```python
    stream = ScaledStreamWrapper(
        AbruptChangeStream(feature_dim=10, change_interval=1000),
        feature_scales=jnp.array([0.001, 0.01, 0.1, 1.0, 10.0,
                                  100.0, 1000.0, 0.001, 0.01, 0.1])
    )
    ```

    Attributes:
        inner_stream: The wrapped stream instance
        feature_scales: Per-feature scale factors (must match feature_dim)
    """

    def __init__(self, inner_stream: ScanStream[Any], feature_scales: Array):
        """Wrap a stream with per-feature scaling.

        Args:
            inner_stream: Stream to wrap (must implement ScanStream protocol)
            feature_scales: Array of scale factors, one per feature. Must have
                shape (feature_dim,) matching the inner stream's feature_dim.

        Raises:
            ValueError: If feature_scales length doesn't match inner stream's feature_dim
        """
        self._inner_stream: ScanStream[Any] = inner_stream
        self._feature_scales = jnp.asarray(feature_scales, dtype=jnp.float32)

        # Validate eagerly so a shape mismatch fails at construction time.
        if self._feature_scales.shape[0] != inner_stream.feature_dim:
            raise ValueError(
                f"feature_scales length ({self._feature_scales.shape[0]}) "
                f"must match inner stream's feature_dim ({inner_stream.feature_dim})"
            )

    @property
    def feature_dim(self) -> int:
        """Dimension of the observation vectors this stream emits."""
        return int(self._inner_stream.feature_dim)

    @property
    def inner_stream(self) -> ScanStream[Any]:
        """The wrapped stream instance."""
        return self._inner_stream

    @property
    def feature_scales(self) -> Array:
        """Per-feature scale factors applied to observations."""
        return self._feature_scales

    def init(self, key: Array) -> ScaledStreamState:
        """Initialize by delegating to the wrapped stream.

        Args:
            key: JAX random key

        Returns:
            State wrapping the inner stream's initial state
        """
        return ScaledStreamState(inner_state=self._inner_stream.init(key))

    def step(self, state: ScaledStreamState, idx: Array) -> tuple[TimeStep, ScaledStreamState]:
        """Generate one time step with the observation rescaled.

        Args:
            state: Current stream state
            idx: Current step index

        Returns:
            Tuple of (timestep with scaled observation, new_state)
        """
        inner_timestep, next_inner = self._inner_stream.step(state.inner_state, idx)

        # Rescale the observation; the target is forwarded as-is.
        rescaled = TimeStep(
            observation=inner_timestep.observation * self._feature_scales,
            target=inner_timestep.target,
        )

        return rescaled, ScaledStreamState(inner_state=next_inner)
|
|
677
|
+
|
|
678
|
+
|
|
679
|
+
def make_scale_range(
    feature_dim: int,
    min_scale: float = 0.001,
    max_scale: float = 1000.0,
    log_spaced: bool = True,
) -> Array:
    """Build a per-feature scale array spanning [min_scale, max_scale].

    Convenience helper for constructing the ``feature_scales`` argument of
    ``ScaledStreamWrapper``.

    Args:
        feature_dim: Number of features
        min_scale: Minimum scale factor
        max_scale: Maximum scale factor
        log_spaced: If True, scales are logarithmically spaced (default).
            If False, scales are linearly spaced.

    Returns:
        Array of shape (feature_dim,) with scale factors

    Examples
    --------
    ```python
    scales = make_scale_range(10, min_scale=0.01, max_scale=100.0)
    stream = ScaledStreamWrapper(RandomWalkStream(10), scales)
    ```
    """
    # Linear spacing is the simple case; handle it first.
    if not log_spaced:
        return jnp.linspace(min_scale, max_scale, feature_dim, dtype=jnp.float32)
    # Log spacing: interpolate the exponents between log10(min) and log10(max).
    return jnp.logspace(
        jnp.log10(min_scale),
        jnp.log10(max_scale),
        feature_dim,
        dtype=jnp.float32,
    )
|
|
715
|
+
|
|
716
|
+
|
|
717
|
+
class DynamicScaleShiftState(NamedTuple):
    """Carried state of a :class:`DynamicScaleShiftStream`.

    Attributes:
        key: JAX PRNG key consumed to produce the next step's randomness
        true_weights: Weights of the currently active linear target
        current_scales: Per-feature scale factors currently applied
        step_count: Number of steps generated so far
    """

    key: Array
    true_weights: Array
    current_scales: Array
    step_count: Array
|
|
731
|
+
|
|
732
|
+
|
|
733
|
+
class DynamicScaleShiftStream:
|
|
734
|
+
"""Non-stationary stream with abruptly changing feature scales.
|
|
735
|
+
|
|
736
|
+
Both target weights AND feature scales change at specified intervals.
|
|
737
|
+
This tests whether OnlineNormalizer can track scale shifts faster
|
|
738
|
+
than Autostep's internal v_i adaptation.
|
|
739
|
+
|
|
740
|
+
The target is computed from unscaled features to maintain consistent
|
|
741
|
+
difficulty across scale changes (only the feature representation changes,
|
|
742
|
+
not the underlying prediction task).
|
|
743
|
+
|
|
744
|
+
Attributes:
|
|
745
|
+
feature_dim: Dimension of observation vectors
|
|
746
|
+
scale_change_interval: Steps between scale changes
|
|
747
|
+
weight_change_interval: Steps between weight changes
|
|
748
|
+
min_scale: Minimum scale factor
|
|
749
|
+
max_scale: Maximum scale factor
|
|
750
|
+
noise_std: Standard deviation of observation noise
|
|
751
|
+
"""
|
|
752
|
+
|
|
753
|
+
def __init__(
|
|
754
|
+
self,
|
|
755
|
+
feature_dim: int,
|
|
756
|
+
scale_change_interval: int = 2000,
|
|
757
|
+
weight_change_interval: int = 1000,
|
|
758
|
+
min_scale: float = 0.01,
|
|
759
|
+
max_scale: float = 100.0,
|
|
760
|
+
noise_std: float = 0.1,
|
|
761
|
+
):
|
|
762
|
+
"""Initialize the dynamic scale shift stream.
|
|
763
|
+
|
|
764
|
+
Args:
|
|
765
|
+
feature_dim: Dimension of feature vectors
|
|
766
|
+
scale_change_interval: Steps between abrupt scale changes
|
|
767
|
+
weight_change_interval: Steps between abrupt weight changes
|
|
768
|
+
min_scale: Minimum scale factor (log-uniform sampling)
|
|
769
|
+
max_scale: Maximum scale factor (log-uniform sampling)
|
|
770
|
+
noise_std: Std dev of target noise
|
|
771
|
+
"""
|
|
772
|
+
self._feature_dim = feature_dim
|
|
773
|
+
self._scale_change_interval = scale_change_interval
|
|
774
|
+
self._weight_change_interval = weight_change_interval
|
|
775
|
+
self._min_scale = min_scale
|
|
776
|
+
self._max_scale = max_scale
|
|
777
|
+
self._noise_std = noise_std
|
|
778
|
+
|
|
779
|
+
@property
|
|
780
|
+
def feature_dim(self) -> int:
|
|
781
|
+
"""Return the dimension of observation vectors."""
|
|
782
|
+
return self._feature_dim
|
|
783
|
+
|
|
784
|
+
def init(self, key: Array) -> DynamicScaleShiftState:
|
|
785
|
+
"""Initialize stream state.
|
|
786
|
+
|
|
787
|
+
Args:
|
|
788
|
+
key: JAX random key
|
|
789
|
+
|
|
790
|
+
Returns:
|
|
791
|
+
Initial stream state with random weights and scales
|
|
792
|
+
"""
|
|
793
|
+
key, k_weights, k_scales = jr.split(key, 3)
|
|
794
|
+
weights = jr.normal(k_weights, (self._feature_dim,), dtype=jnp.float32)
|
|
795
|
+
# Initial scales: log-uniform between min and max
|
|
796
|
+
log_scales = jr.uniform(
|
|
797
|
+
k_scales,
|
|
798
|
+
(self._feature_dim,),
|
|
799
|
+
minval=jnp.log(self._min_scale),
|
|
800
|
+
maxval=jnp.log(self._max_scale),
|
|
801
|
+
)
|
|
802
|
+
scales = jnp.exp(log_scales).astype(jnp.float32)
|
|
803
|
+
return DynamicScaleShiftState(
|
|
804
|
+
key=key,
|
|
805
|
+
true_weights=weights,
|
|
806
|
+
current_scales=scales,
|
|
807
|
+
step_count=jnp.array(0, dtype=jnp.int32),
|
|
808
|
+
)
|
|
809
|
+
|
|
810
|
+
def step(
    self, state: DynamicScaleShiftState, idx: Array
) -> tuple[TimeStep, DynamicScaleShiftState]:
    """Advance the stream by one step.

    Args:
        state: Current stream state
        idx: Step index required by the scan interface (ignored)

    Returns:
        Tuple of (timestep, new_state)
    """
    del idx  # required by the scan interface, not used here
    carry_key, weight_key, scale_key, feat_key, noise_key = jr.split(state.key, 5)

    shape = (self._feature_dim,)

    # Resample scales whenever the counter sits on a scale-change boundary.
    at_scale_boundary = state.step_count % self._scale_change_interval == 0
    candidate_log_scales = jr.uniform(
        scale_key,
        shape,
        minval=jnp.log(self._min_scale),
        maxval=jnp.log(self._max_scale),
    )
    candidate_scales = jnp.exp(candidate_log_scales).astype(jnp.float32)
    scales = jnp.where(at_scale_boundary, candidate_scales, state.current_scales)

    # Resample the target weights at weight-change boundaries.
    at_weight_boundary = state.step_count % self._weight_change_interval == 0
    candidate_weights = jr.normal(weight_key, shape, dtype=jnp.float32)
    weights = jnp.where(at_weight_boundary, candidate_weights, state.true_weights)

    # Raw (unit-scale) features; the emitted observation is the scaled copy.
    raw_features = jr.normal(feat_key, shape, dtype=jnp.float32)
    observation = raw_features * scales

    # The target is computed from the RAW features so that changing the
    # observation scale does not alter the intrinsic problem difficulty.
    target_noise = self._noise_std * jr.normal(noise_key, (), dtype=jnp.float32)
    target = jnp.dot(weights, raw_features) + target_noise

    timestep = TimeStep(observation=observation, target=jnp.atleast_1d(target))
    next_state = DynamicScaleShiftState(
        key=carry_key,
        true_weights=weights,
        current_scales=scales,
        step_count=state.step_count + 1,
    )
    return timestep, next_state
|
|
859
|
+
|
|
860
|
+
|
|
861
|
+
class ScaleDriftState(NamedTuple):
    """Carry state for ScaleDriftStream.

    Attributes:
        key: PRNG key consumed to generate the next step's randomness
        true_weights: Current ground-truth target weights
        log_scales: Per-feature log-scale factors (random walk in log space)
        step_count: Number of steps generated so far
    """

    key: Array
    true_weights: Array
    log_scales: Array
    step_count: Array
|
|
875
|
+
|
|
876
|
+
|
|
877
|
+
class ScaleDriftStream:
    """Non-stationary stream whose feature scales follow a random walk.

    Both the target weights and the per-feature scales drift on every
    step: weights drift additively in linear space, while scales drift in
    log space as a bounded random walk. This probes continuous scale
    tracking, where OnlineNormalizer's EMA may adapt differently than
    Autostep's v_i.

    Targets are computed from the UNSCALED features so that problem
    difficulty stays constant while the observation scales wander.

    Attributes:
        feature_dim: Dimension of observation vectors
        weight_drift_rate: Std dev of weight drift per step
        scale_drift_rate: Std dev of log-scale drift per step
        min_log_scale: Minimum log-scale (clips random walk)
        max_log_scale: Maximum log-scale (clips random walk)
        noise_std: Standard deviation of the additive target noise
    """

    def __init__(
        self,
        feature_dim: int,
        weight_drift_rate: float = 0.001,
        scale_drift_rate: float = 0.01,
        min_log_scale: float = -4.0,  # exp(-4) ~ 0.018
        max_log_scale: float = 4.0,  # exp(4) ~ 54.6
        noise_std: float = 0.1,
    ):
        """Initialize the scale drift stream.

        Args:
            feature_dim: Dimension of feature vectors
            weight_drift_rate: Std dev of weight drift per step
            scale_drift_rate: Std dev of log-scale drift per step
            min_log_scale: Minimum log-scale (clips drift)
            max_log_scale: Maximum log-scale (clips drift)
            noise_std: Std dev of target noise
        """
        self._feature_dim = feature_dim
        self._weight_drift_rate = weight_drift_rate
        self._scale_drift_rate = scale_drift_rate
        self._min_log_scale = min_log_scale
        self._max_log_scale = max_log_scale
        self._noise_std = noise_std

    @property
    def feature_dim(self) -> int:
        """Return the dimension of observation vectors."""
        return self._feature_dim

    def init(self, key: Array) -> ScaleDriftState:
        """Initialize stream state.

        Args:
            key: JAX random key

        Returns:
            State with random target weights and unit scales (log-scale 0)
        """
        carry_key, weight_key = jr.split(key)
        initial_weights = jr.normal(weight_key, (self._feature_dim,), dtype=jnp.float32)
        # log-scale 0 everywhere <=> all scales start at exactly 1.
        unit_log_scales = jnp.zeros(self._feature_dim, dtype=jnp.float32)
        return ScaleDriftState(
            key=carry_key,
            true_weights=initial_weights,
            log_scales=unit_log_scales,
            step_count=jnp.array(0, dtype=jnp.int32),
        )

    def step(
        self, state: ScaleDriftState, idx: Array
    ) -> tuple[TimeStep, ScaleDriftState]:
        """Generate one time step.

        Args:
            state: Current stream state
            idx: Step index required by the scan interface (ignored)

        Returns:
            Tuple of (timestep, new_state)
        """
        del idx  # required by the scan interface, not used here
        carry_key, wdrift_key, sdrift_key, feat_key, noise_key = jr.split(state.key, 5)

        shape = (self._feature_dim,)

        # Additive Gaussian drift on the target weights (unbounded walk).
        drifted_weights = state.true_weights + self._weight_drift_rate * jr.normal(
            wdrift_key, shape, dtype=jnp.float32
        )

        # Bounded Gaussian random walk on the log-scales.
        drifted_log_scales = state.log_scales + self._scale_drift_rate * jr.normal(
            sdrift_key, shape, dtype=jnp.float32
        )
        drifted_log_scales = jnp.clip(
            drifted_log_scales, self._min_log_scale, self._max_log_scale
        )

        # The observation is the scaled copy of the raw features.
        raw_features = jr.normal(feat_key, shape, dtype=jnp.float32)
        observation = raw_features * jnp.exp(drifted_log_scales)

        # The target uses the RAW features so that the drifting scales do
        # not alter the intrinsic difficulty of the regression problem.
        target = jnp.dot(drifted_weights, raw_features) + self._noise_std * jr.normal(
            noise_key, (), dtype=jnp.float32
        )

        timestep = TimeStep(observation=observation, target=jnp.atleast_1d(target))
        next_state = ScaleDriftState(
            key=carry_key,
            true_weights=drifted_weights,
            log_scales=drifted_log_scales,
            step_count=state.step_count + 1,
        )
        return timestep, next_state
|
|
995
|
+
|
|
996
|
+
|
|
997
|
+
# Backward-compatible aliases: earlier releases exported these streams under
# "*Target" names; keep them importable so existing user code does not break.
RandomWalkTarget = RandomWalkStream
AbruptChangeTarget = AbruptChangeStream
CyclicTarget = CyclicStream
PeriodicChangeTarget = PeriodicChangeStream
|