alberta_framework-0.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- alberta_framework/__init__.py +196 -0
- alberta_framework/core/__init__.py +27 -0
- alberta_framework/core/learners.py +530 -0
- alberta_framework/core/normalizers.py +192 -0
- alberta_framework/core/optimizers.py +422 -0
- alberta_framework/core/types.py +198 -0
- alberta_framework/py.typed +0 -0
- alberta_framework/streams/__init__.py +83 -0
- alberta_framework/streams/base.py +70 -0
- alberta_framework/streams/gymnasium.py +655 -0
- alberta_framework/streams/synthetic.py +995 -0
- alberta_framework/utils/__init__.py +113 -0
- alberta_framework/utils/experiments.py +334 -0
- alberta_framework/utils/export.py +509 -0
- alberta_framework/utils/metrics.py +112 -0
- alberta_framework/utils/statistics.py +527 -0
- alberta_framework/utils/timing.py +138 -0
- alberta_framework/utils/visualization.py +571 -0
- alberta_framework-0.1.0.dist-info/METADATA +198 -0
- alberta_framework-0.1.0.dist-info/RECORD +22 -0
- alberta_framework-0.1.0.dist-info/WHEEL +4 -0
- alberta_framework-0.1.0.dist-info/licenses/LICENSE +190 -0
alberta_framework/streams/gymnasium.py
@@ -0,0 +1,655 @@
"""Gymnasium environment wrappers as experience streams.

This module wraps Gymnasium environments to provide temporally-uniform experience
streams compatible with the Alberta Framework's learners.

Gymnasium environments cannot be JIT-compiled, so this module provides:
1. Trajectory collection: Collect data using Python loop, then learn with scan
2. Online learning: Python loop for cases requiring real-time env interaction

Supports multiple prediction modes:
- REWARD: Predict immediate reward from (state, action)
- NEXT_STATE: Predict next state from (state, action)
- VALUE: Predict cumulative return via TD learning
"""

from __future__ import annotations

from collections.abc import Callable, Iterator
from enum import Enum
from typing import TYPE_CHECKING, Any

import jax
import jax.numpy as jnp
import jax.random as jr
from jax import Array

from alberta_framework.core.learners import LinearLearner, NormalizedLinearLearner
from alberta_framework.core.types import LearnerState, TimeStep

if TYPE_CHECKING:
    import gymnasium

    from alberta_framework.core.learners import NormalizedLearnerState


class PredictionMode(Enum):
    """Mode for what the stream predicts.

    REWARD: Predict immediate reward from (state, action)
    NEXT_STATE: Predict next state from (state, action)
    VALUE: Predict cumulative return (TD learning with bootstrap)
    """

    REWARD = "reward"
    NEXT_STATE = "next_state"
    VALUE = "value"


def _flatten_space(space: gymnasium.spaces.Space[Any]) -> int:
    """Get the flattened dimension of a Gymnasium space.

    Args:
        space: A Gymnasium space (Box, Discrete, MultiDiscrete)

    Returns:
        Integer dimension of the flattened space

    Raises:
        ValueError: If space type is not supported
    """
    import gymnasium

    if isinstance(space, gymnasium.spaces.Box):
        return int(jnp.prod(jnp.array(space.shape)))
    elif isinstance(space, gymnasium.spaces.Discrete):
        return 1
    elif isinstance(space, gymnasium.spaces.MultiDiscrete):
        return len(space.nvec)
    else:
        raise ValueError(
            f"Unsupported space type: {type(space).__name__}. "
            "Supported types: Box, Discrete, MultiDiscrete"
        )


def _flatten_observation(obs: Any, space: gymnasium.spaces.Space[Any]) -> Array:
    """Flatten an observation to a 1D JAX array.

    Args:
        obs: Observation from the environment
        space: The observation space

    Returns:
        Flattened observation as a 1D JAX array
    """
    import gymnasium

    if isinstance(space, gymnasium.spaces.Box):
        return jnp.asarray(obs, dtype=jnp.float32).flatten()
    elif isinstance(space, gymnasium.spaces.Discrete):
        return jnp.array([float(obs)], dtype=jnp.float32)
    elif isinstance(space, gymnasium.spaces.MultiDiscrete):
        return jnp.asarray(obs, dtype=jnp.float32)
    else:
        raise ValueError(f"Unsupported space type: {type(space).__name__}")


def _flatten_action(action: Any, space: gymnasium.spaces.Space[Any]) -> Array:
    """Flatten an action to a 1D JAX array.

    Args:
        action: Action for the environment
        space: The action space

    Returns:
        Flattened action as a 1D JAX array
    """
    import gymnasium

    if isinstance(space, gymnasium.spaces.Box):
        return jnp.asarray(action, dtype=jnp.float32).flatten()
    elif isinstance(space, gymnasium.spaces.Discrete):
        return jnp.array([float(action)], dtype=jnp.float32)
    elif isinstance(space, gymnasium.spaces.MultiDiscrete):
        return jnp.asarray(action, dtype=jnp.float32)
    else:
        raise ValueError(f"Unsupported space type: {type(space).__name__}")


def make_random_policy(
    env: gymnasium.Env[Any, Any], seed: int = 0
) -> Callable[[Array], Any]:
    """Create a random action policy for an environment.

    Args:
        env: Gymnasium environment
        seed: Random seed

    Returns:
        A callable that takes an observation and returns a random action
    """
    import gymnasium

    rng = jr.key(seed)
    action_space = env.action_space

    def policy(_obs: Array) -> Any:
        nonlocal rng
        rng, key = jr.split(rng)

        if isinstance(action_space, gymnasium.spaces.Discrete):
            return int(jr.randint(key, (), 0, int(action_space.n)))
        elif isinstance(action_space, gymnasium.spaces.Box):
            # Sample uniformly between low and high
            low = jnp.asarray(action_space.low, dtype=jnp.float32)
            high = jnp.asarray(action_space.high, dtype=jnp.float32)
            return jr.uniform(key, action_space.shape, minval=low, maxval=high)
        elif isinstance(action_space, gymnasium.spaces.MultiDiscrete):
            nvec = action_space.nvec
            return [
                int(jr.randint(jr.fold_in(key, i), (), 0, n))
                for i, n in enumerate(nvec)
            ]
        else:
            raise ValueError(f"Unsupported action space: {type(action_space).__name__}")

    return policy


def make_epsilon_greedy_policy(
    base_policy: Callable[[Array], Any],
    env: gymnasium.Env[Any, Any],
    epsilon: float = 0.1,
    seed: int = 0,
) -> Callable[[Array], Any]:
    """Wrap a policy with epsilon-greedy exploration.

    Args:
        base_policy: The greedy policy to wrap
        env: Gymnasium environment (for random action sampling)
        epsilon: Probability of taking a random action
        seed: Random seed

    Returns:
        Epsilon-greedy policy
    """
    random_policy = make_random_policy(env, seed + 1)
    rng = jr.key(seed)

    def policy(obs: Array) -> Any:
        nonlocal rng
        rng, key = jr.split(rng)

        if jr.uniform(key) < epsilon:
            return random_policy(obs)
        return base_policy(obs)

    return policy


def collect_trajectory(
    env: gymnasium.Env[Any, Any],
    policy: Callable[[Array], Any] | None,
    num_steps: int,
    mode: PredictionMode = PredictionMode.REWARD,
    include_action_in_features: bool = True,
    seed: int = 0,
) -> tuple[Array, Array]:
    """Collect a trajectory from a Gymnasium environment.

    This uses a Python loop to interact with the environment and collects
    observations and targets into JAX arrays that can be used with scan-based
    learning.

    Args:
        env: Gymnasium environment instance
        policy: Action selection function. If None, uses random policy
        num_steps: Number of steps to collect
        mode: What to predict (REWARD, NEXT_STATE, VALUE)
        include_action_in_features: If True, features = concat(obs, action)
        seed: Random seed for environment resets and random policy

    Returns:
        Tuple of (observations, targets) as JAX arrays with shape
        (num_steps, feature_dim) and (num_steps, target_dim)
    """
    if policy is None:
        policy = make_random_policy(env, seed)

    observations = []
    targets = []

    reset_count = 0
    raw_obs, _ = env.reset(seed=seed + reset_count)
    reset_count += 1
    current_obs = _flatten_observation(raw_obs, env.observation_space)

    for _ in range(num_steps):
        action = policy(current_obs)
        flat_action = _flatten_action(action, env.action_space)

        raw_next_obs, reward, terminated, truncated, _ = env.step(action)
        next_obs = _flatten_observation(raw_next_obs, env.observation_space)

        # Construct features
        if include_action_in_features:
            features = jnp.concatenate([current_obs, flat_action])
        else:
            features = current_obs

        # Construct target based on mode
        if mode == PredictionMode.REWARD:
            target = jnp.atleast_1d(jnp.array(reward, dtype=jnp.float32))
        elif mode == PredictionMode.NEXT_STATE:
            target = next_obs
        else:  # VALUE mode
            # TD target with 0 bootstrap (simple version)
            target = jnp.atleast_1d(jnp.array(reward, dtype=jnp.float32))

        observations.append(features)
        targets.append(target)

        if terminated or truncated:
            raw_obs, _ = env.reset(seed=seed + reset_count)
            reset_count += 1
            current_obs = _flatten_observation(raw_obs, env.observation_space)
        else:
            current_obs = next_obs

    return jnp.stack(observations), jnp.stack(targets)


def learn_from_trajectory(
    learner: LinearLearner,
    observations: Array,
    targets: Array,
    learner_state: LearnerState | None = None,
) -> tuple[LearnerState, Array]:
    """Learn from a pre-collected trajectory using jax.lax.scan.

    This is a JIT-compiled learning function that processes a trajectory
    collected from a Gymnasium environment.

    Args:
        learner: The learner to train
        observations: Array of observations with shape (num_steps, feature_dim)
        targets: Array of targets with shape (num_steps, target_dim)
        learner_state: Initial state (if None, will be initialized)

    Returns:
        Tuple of (final_state, metrics_array) where metrics_array has shape
        (num_steps, 3) with columns [squared_error, error, mean_step_size]
    """
    if learner_state is None:
        learner_state = learner.init(observations.shape[1])

    def step_fn(
        state: LearnerState, inputs: tuple[Array, Array]
    ) -> tuple[LearnerState, Array]:
        obs, target = inputs
        result = learner.update(state, obs, target)
        return result.state, result.metrics

    final_state, metrics = jax.lax.scan(step_fn, learner_state, (observations, targets))

    return final_state, metrics


def learn_from_trajectory_normalized(
    learner: NormalizedLinearLearner,
    observations: Array,
    targets: Array,
    learner_state: NormalizedLearnerState | None = None,
) -> tuple[NormalizedLearnerState, Array]:
    """Learn from a pre-collected trajectory with normalization using jax.lax.scan.

    Args:
        learner: The normalized learner to train
        observations: Array of observations with shape (num_steps, feature_dim)
        targets: Array of targets with shape (num_steps, target_dim)
        learner_state: Initial state (if None, will be initialized)

    Returns:
        Tuple of (final_state, metrics_array) where metrics_array has shape
        (num_steps, 4) with columns [squared_error, error, mean_step_size, normalizer_mean_var]
    """
    if learner_state is None:
        learner_state = learner.init(observations.shape[1])

    def step_fn(
        state: NormalizedLearnerState, inputs: tuple[Array, Array]
    ) -> tuple[NormalizedLearnerState, Array]:
        obs, target = inputs
        result = learner.update(state, obs, target)
        return result.state, result.metrics

    final_state, metrics = jax.lax.scan(step_fn, learner_state, (observations, targets))

    return final_state, metrics
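
For orientation, the two-phase pattern these functions implement (collect with a Python loop, then learn under jax.lax.scan) looks roughly like the sketch below. This is illustrative, not package code; the import locations of LinearLearner and IDBD and the bare IDBD() constructor are assumptions taken from the package layout above and the TDStream docstring further down.

```python
import gymnasium
import jax.numpy as jnp

from alberta_framework.core.learners import LinearLearner
from alberta_framework.core.optimizers import IDBD  # assumed import location
from alberta_framework.streams.gymnasium import (
    PredictionMode,
    collect_trajectory,
    learn_from_trajectory,
)

env = gymnasium.make("CartPole-v1")

# Phase 1: a Python loop interacts with the (non-JIT-able) environment.
observations, targets = collect_trajectory(
    env,
    policy=None,                  # None -> random policy
    num_steps=1_000,
    mode=PredictionMode.REWARD,   # predict immediate reward from (obs, action)
    seed=0,
)

# Phase 2: learning is pure JAX and runs under jax.lax.scan.
learner = LinearLearner(optimizer=IDBD())
state, metrics = learn_from_trajectory(learner, observations, targets)

# metrics has shape (num_steps, 3): [squared_error, error, mean_step_size]
print(float(jnp.mean(metrics[:, 0])))
```
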


class GymnasiumStream:
    """Experience stream from a Gymnasium environment using Python loop.

    This class maintains iterator-based access for online learning scenarios
    where you need to interact with the environment in real-time.

    For batch learning, use collect_trajectory() followed by learn_from_trajectory().

    Attributes:
        mode: Prediction mode (REWARD, NEXT_STATE, VALUE)
        gamma: Discount factor for VALUE mode
        include_action_in_features: Whether to include action in features
        episode_count: Number of completed episodes
    """

    def __init__(
        self,
        env: gymnasium.Env[Any, Any],
        mode: PredictionMode = PredictionMode.REWARD,
        policy: Callable[[Array], Any] | None = None,
        gamma: float = 0.99,
        include_action_in_features: bool = True,
        seed: int = 0,
    ):
        """Initialize the Gymnasium stream.

        Args:
            env: Gymnasium environment instance
            mode: What to predict (REWARD, NEXT_STATE, VALUE)
            policy: Action selection function. If None, uses random policy
            gamma: Discount factor for VALUE mode
            include_action_in_features: If True, features = concat(obs, action).
                If False, features = obs only
            seed: Random seed for environment resets and random policy
        """
        self._env = env
        self._mode = mode
        self._gamma = gamma
        self._include_action_in_features = include_action_in_features
        self._seed = seed
        self._reset_count = 0

        if policy is None:
            self._policy = make_random_policy(env, seed)
        else:
            self._policy = policy

        self._obs_dim = _flatten_space(env.observation_space)
        self._action_dim = _flatten_space(env.action_space)

        if include_action_in_features:
            self._feature_dim = self._obs_dim + self._action_dim
        else:
            self._feature_dim = self._obs_dim

        if mode == PredictionMode.NEXT_STATE:
            self._target_dim = self._obs_dim
        else:
            self._target_dim = 1

        self._current_obs: Array | None = None
        self._episode_count = 0
        self._step_count = 0
        self._value_estimator: Callable[[Array], float] | None = None

    @property
    def feature_dim(self) -> int:
        """Return the dimension of feature vectors."""
        return self._feature_dim

    @property
    def target_dim(self) -> int:
        """Return the dimension of target vectors."""
        return self._target_dim

    @property
    def episode_count(self) -> int:
        """Return the number of completed episodes."""
        return self._episode_count

    @property
    def step_count(self) -> int:
        """Return the total number of steps taken."""
        return self._step_count

    @property
    def mode(self) -> PredictionMode:
        """Return the prediction mode."""
        return self._mode

    def set_value_estimator(self, estimator: Callable[[Array], float]) -> None:
        """Set the value estimator for proper TD learning in VALUE mode."""
        self._value_estimator = estimator

    def _get_reset_seed(self) -> int:
        """Get the seed for the next environment reset."""
        seed = self._seed + self._reset_count
        self._reset_count += 1
        return seed

    def _construct_features(self, obs: Array, action: Array) -> Array:
        """Construct feature vector from observation and action."""
        if self._include_action_in_features:
            return jnp.concatenate([obs, action])
        return obs

    def _construct_target(
        self,
        reward: float,
        next_obs: Array,
        terminated: bool,
    ) -> Array:
        """Construct target based on prediction mode."""
        if self._mode == PredictionMode.REWARD:
            return jnp.atleast_1d(jnp.array(reward, dtype=jnp.float32))

        elif self._mode == PredictionMode.NEXT_STATE:
            return next_obs

        elif self._mode == PredictionMode.VALUE:
            if terminated:
                return jnp.atleast_1d(jnp.array(reward, dtype=jnp.float32))

            if self._value_estimator is not None:
                next_value = self._value_estimator(next_obs)
            else:
                next_value = 0.0

            target = reward + self._gamma * next_value
            return jnp.atleast_1d(jnp.array(target, dtype=jnp.float32))

        else:
            raise ValueError(f"Unknown mode: {self._mode}")

    def __iter__(self) -> Iterator[TimeStep]:
        """Return self as iterator."""
        return self

    def __next__(self) -> TimeStep:
        """Generate the next time step."""
        if self._current_obs is None:
            raw_obs, _ = self._env.reset(seed=self._get_reset_seed())
            self._current_obs = _flatten_observation(raw_obs, self._env.observation_space)

        action = self._policy(self._current_obs)
        flat_action = _flatten_action(action, self._env.action_space)

        raw_next_obs, reward, terminated, truncated, _ = self._env.step(action)
        next_obs = _flatten_observation(raw_next_obs, self._env.observation_space)

        features = self._construct_features(self._current_obs, flat_action)
        target = self._construct_target(float(reward), next_obs, terminated)

        self._step_count += 1

        if terminated or truncated:
            self._episode_count += 1
            self._current_obs = None
        else:
            self._current_obs = next_obs

        return TimeStep(observation=features, target=target)
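
For real-time interaction, the iterator above can drive a plain online loop. Another illustrative sketch under the same assumptions about LinearLearner and IDBD; itertools.islice only bounds the otherwise endless stream.

```python
import itertools

import gymnasium

from alberta_framework.core.learners import LinearLearner
from alberta_framework.core.optimizers import IDBD  # assumed import location
from alberta_framework.streams.gymnasium import GymnasiumStream, PredictionMode

env = gymnasium.make("CartPole-v1")
stream = GymnasiumStream(env, mode=PredictionMode.REWARD, seed=0)

learner = LinearLearner(optimizer=IDBD())
state = learner.init(stream.feature_dim)

# The stream never raises StopIteration, so bound the loop explicitly.
for timestep in itertools.islice(stream, 10_000):
    result = learner.update(state, timestep.observation, timestep.target)
    state = result.state

print(stream.episode_count, stream.step_count)
```
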


class TDStream:
    """Experience stream for proper TD learning with value function bootstrap.

    This stream integrates with a learner to use its predictions for
    bootstrapping in TD targets.

    Usage:
        stream = TDStream(env)
        learner = LinearLearner(optimizer=IDBD())
        state = learner.init(stream.feature_dim)

        for step, timestep in enumerate(stream):
            result = learner.update(state, timestep.observation, timestep.target)
            state = result.state
            stream.update_value_function(lambda x: learner.predict(state, x))
    """

    def __init__(
        self,
        env: gymnasium.Env[Any, Any],
        policy: Callable[[Array], Any] | None = None,
        gamma: float = 0.99,
        include_action_in_features: bool = False,
        seed: int = 0,
    ):
        """Initialize the TD stream.

        Args:
            env: Gymnasium environment instance
            policy: Action selection function. If None, uses random policy
            gamma: Discount factor
            include_action_in_features: If True, learn Q(s,a). If False, learn V(s)
            seed: Random seed
        """
        self._env = env
        self._gamma = gamma
        self._include_action_in_features = include_action_in_features
        self._seed = seed
        self._reset_count = 0

        if policy is None:
            self._policy = make_random_policy(env, seed)
        else:
            self._policy = policy

        self._obs_dim = _flatten_space(env.observation_space)
        self._action_dim = _flatten_space(env.action_space)

        if include_action_in_features:
            self._feature_dim = self._obs_dim + self._action_dim
        else:
            self._feature_dim = self._obs_dim

        self._current_obs: Array | None = None
        self._episode_count = 0
        self._step_count = 0
        self._value_fn: Callable[[Array], float] = lambda x: 0.0

    @property
    def feature_dim(self) -> int:
        """Return the dimension of feature vectors."""
        return self._feature_dim

    @property
    def episode_count(self) -> int:
        """Return the number of completed episodes."""
        return self._episode_count

    @property
    def step_count(self) -> int:
        """Return the total number of steps taken."""
        return self._step_count

    def update_value_function(self, value_fn: Callable[[Array], float]) -> None:
        """Update the value function used for TD bootstrapping."""
        self._value_fn = value_fn

    def _get_reset_seed(self) -> int:
        """Get the seed for the next environment reset."""
        seed = self._seed + self._reset_count
        self._reset_count += 1
        return seed

    def _construct_features(self, obs: Array, action: Array) -> Array:
        """Construct feature vector from observation and action."""
        if self._include_action_in_features:
            return jnp.concatenate([obs, action])
        return obs

    def __iter__(self) -> Iterator[TimeStep]:
        """Return self as iterator."""
        return self

    def __next__(self) -> TimeStep:
        """Generate the next time step with TD target."""
        if self._current_obs is None:
            raw_obs, _ = self._env.reset(seed=self._get_reset_seed())
            self._current_obs = _flatten_observation(raw_obs, self._env.observation_space)

        action = self._policy(self._current_obs)
        flat_action = _flatten_action(action, self._env.action_space)

        raw_next_obs, reward, terminated, truncated, _ = self._env.step(action)
        next_obs = _flatten_observation(raw_next_obs, self._env.observation_space)

        features = self._construct_features(self._current_obs, flat_action)
        next_features = self._construct_features(next_obs, flat_action)

        if terminated:
            target = jnp.atleast_1d(jnp.array(reward, dtype=jnp.float32))
        else:
            bootstrap = self._value_fn(next_features)
            target_val = float(reward) + self._gamma * float(bootstrap)
            target = jnp.atleast_1d(jnp.array(target_val, dtype=jnp.float32))

        self._step_count += 1

        if terminated or truncated:
            self._episode_count += 1
            self._current_obs = None
        else:
            self._current_obs = next_obs

        return TimeStep(observation=features, target=target)
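
The usage snippet in the class docstring expands to a runnable loop along these lines; a sketch under the same assumptions about LinearLearner, IDBD, and learner.predict as above.

```python
import itertools

import gymnasium

from alberta_framework.core.learners import LinearLearner
from alberta_framework.core.optimizers import IDBD  # assumed import location
from alberta_framework.streams.gymnasium import TDStream

env = gymnasium.make("CartPole-v1")
stream = TDStream(env, gamma=0.99, seed=0)  # default: learn V(s) from obs-only features

learner = LinearLearner(optimizer=IDBD())
state = learner.init(stream.feature_dim)

for timestep in itertools.islice(stream, 10_000):
    result = learner.update(state, timestep.observation, timestep.target)
    state = result.state
    # Refresh the bootstrap so later TD targets use the latest weights,
    # mirroring the pattern in the class docstring.
    stream.update_value_function(
        lambda obs, state=state: float(learner.predict(state, obs))
    )

print(stream.episode_count)
```
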


def make_gymnasium_stream(
    env_id: str,
    mode: PredictionMode = PredictionMode.REWARD,
    policy: Callable[[Array], Any] | None = None,
    gamma: float = 0.99,
    include_action_in_features: bool = True,
    seed: int = 0,
    **env_kwargs: Any,
) -> GymnasiumStream:
    """Factory function to create a GymnasiumStream from an environment ID.

    Args:
        env_id: Gymnasium environment ID (e.g., "CartPole-v1")
        mode: What to predict (REWARD, NEXT_STATE, VALUE)
        policy: Action selection function. If None, uses random policy
        gamma: Discount factor for VALUE mode
        include_action_in_features: If True, features = concat(obs, action)
        seed: Random seed
        **env_kwargs: Additional arguments passed to gymnasium.make()

    Returns:
        GymnasiumStream wrapping the environment
    """
    import gymnasium

    env = gymnasium.make(env_id, **env_kwargs)
    return GymnasiumStream(
        env=env,
        mode=mode,
        policy=policy,
        gamma=gamma,
        include_action_in_features=include_action_in_features,
        seed=seed,
    )
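
Finally, the factory is a thin convenience over gymnasium.make() plus stream construction, forwarding extra keyword arguments to gymnasium.make(). A short sketch; the printed dimensions assume CartPole-v1 (a 4-dimensional Box observation and a scalar Discrete action).

```python
from alberta_framework.streams.gymnasium import PredictionMode, make_gymnasium_stream

stream = make_gymnasium_stream(
    "CartPole-v1",
    mode=PredictionMode.REWARD,
    seed=0,
    max_episode_steps=200,  # forwarded to gymnasium.make()
)
print(stream.feature_dim, stream.target_dim)  # 5 (4 obs dims + 1 action dim), 1
```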