alberta-framework 0.2.2__py3-none-any.whl

@@ -0,0 +1,655 @@
1
+ """Gymnasium environment wrappers as experience streams.
2
+
3
+ This module wraps Gymnasium environments to provide temporally uniform experience
4
+ streams compatible with the Alberta Framework's learners.
5
+
6
+ Gymnasium environments cannot be JIT-compiled, so this module provides:
7
+ 1. Trajectory collection: collect data with a Python loop, then learn with jax.lax.scan
8
+ 2. Online learning: a Python loop for cases that require real-time environment interaction
9
+
10
+ Supports multiple prediction modes:
11
+ - REWARD: Predict immediate reward from (state, action)
12
+ - NEXT_STATE: Predict next state from (state, action)
13
+ - VALUE: Predict cumulative return via TD learning
14
+ """
15
+
16
+ from __future__ import annotations
17
+
18
+ from collections.abc import Callable, Iterator
19
+ from enum import Enum
20
+ from typing import TYPE_CHECKING, Any
21
+
22
+ import jax
23
+ import jax.numpy as jnp
24
+ import jax.random as jr
25
+ from jax import Array
26
+
27
+ from alberta_framework.core.learners import LinearLearner, NormalizedLinearLearner
28
+ from alberta_framework.core.types import LearnerState, TimeStep
29
+
30
+ if TYPE_CHECKING:
31
+ import gymnasium
32
+
33
+ from alberta_framework.core.learners import NormalizedLearnerState
34
+
35
+
36
+ class PredictionMode(Enum):
37
+ """Mode for what the stream predicts.
38
+
39
+ REWARD: Predict immediate reward from (state, action)
40
+ NEXT_STATE: Predict next state from (state, action)
41
+ VALUE: Predict cumulative return (TD learning with bootstrap)
42
+ """
43
+
44
+ REWARD = "reward"
45
+ NEXT_STATE = "next_state"
46
+ VALUE = "value"
47
+
48
+
49
+ def _flatten_space(space: gymnasium.spaces.Space[Any]) -> int:
50
+ """Get the flattened dimension of a Gymnasium space.
51
+
52
+ Args:
53
+ space: A Gymnasium space (Box, Discrete, MultiDiscrete)
54
+
55
+ Returns:
56
+ Integer dimension of the flattened space
57
+
58
+ Raises:
59
+ ValueError: If space type is not supported
60
+ """
61
+ import gymnasium
62
+
63
+ if isinstance(space, gymnasium.spaces.Box):
64
+ return int(jnp.prod(jnp.array(space.shape)))
65
+ elif isinstance(space, gymnasium.spaces.Discrete):
66
+ return 1
67
+ elif isinstance(space, gymnasium.spaces.MultiDiscrete):
68
+ return len(space.nvec)
69
+ else:
70
+ raise ValueError(
71
+ f"Unsupported space type: {type(space).__name__}. "
72
+ "Supported types: Box, Discrete, MultiDiscrete"
73
+ )
74
+
75
+
76
+ def _flatten_observation(obs: Any, space: gymnasium.spaces.Space[Any]) -> Array:
77
+ """Flatten an observation to a 1D JAX array.
78
+
79
+ Args:
80
+ obs: Observation from the environment
81
+ space: The observation space
82
+
83
+ Returns:
84
+ Flattened observation as a 1D JAX array
85
+ """
86
+ import gymnasium
87
+
88
+ if isinstance(space, gymnasium.spaces.Box):
89
+ return jnp.asarray(obs, dtype=jnp.float32).flatten()
90
+ elif isinstance(space, gymnasium.spaces.Discrete):
91
+ return jnp.array([float(obs)], dtype=jnp.float32)
92
+ elif isinstance(space, gymnasium.spaces.MultiDiscrete):
93
+ return jnp.asarray(obs, dtype=jnp.float32)
94
+ else:
95
+ raise ValueError(f"Unsupported space type: {type(space).__name__}")
96
+
97
+
98
+ def _flatten_action(action: Any, space: gymnasium.spaces.Space[Any]) -> Array:
99
+ """Flatten an action to a 1D JAX array.
100
+
101
+ Args:
102
+ action: Action for the environment
103
+ space: The action space
104
+
105
+ Returns:
106
+ Flattened action as a 1D JAX array
107
+ """
108
+ import gymnasium
109
+
110
+ if isinstance(space, gymnasium.spaces.Box):
111
+ return jnp.asarray(action, dtype=jnp.float32).flatten()
112
+ elif isinstance(space, gymnasium.spaces.Discrete):
113
+ return jnp.array([float(action)], dtype=jnp.float32)
114
+ elif isinstance(space, gymnasium.spaces.MultiDiscrete):
115
+ return jnp.asarray(action, dtype=jnp.float32)
116
+ else:
117
+ raise ValueError(f"Unsupported space type: {type(space).__name__}")
118
+
119
+
120
+ def make_random_policy(
121
+ env: gymnasium.Env[Any, Any], seed: int = 0
122
+ ) -> Callable[[Array], Any]:
123
+ """Create a random action policy for an environment.
124
+
125
+ Args:
126
+ env: Gymnasium environment
127
+ seed: Random seed
128
+
129
+ Returns:
130
+ A callable that takes an observation and returns a random action
131
+ """
132
+ import gymnasium
133
+
134
+ rng = jr.key(seed)
135
+ action_space = env.action_space
136
+
137
+ def policy(_obs: Array) -> Any:
138
+ nonlocal rng
139
+ rng, key = jr.split(rng)
140
+
141
+ if isinstance(action_space, gymnasium.spaces.Discrete):
142
+ return int(jr.randint(key, (), 0, int(action_space.n)))
143
+ elif isinstance(action_space, gymnasium.spaces.Box):
144
+ # Sample uniformly between low and high
145
+ low = jnp.asarray(action_space.low, dtype=jnp.float32)
146
+ high = jnp.asarray(action_space.high, dtype=jnp.float32)
147
+ return jr.uniform(key, action_space.shape, minval=low, maxval=high)
148
+ elif isinstance(action_space, gymnasium.spaces.MultiDiscrete):
149
+ nvec = action_space.nvec
150
+ return [
151
+ int(jr.randint(jr.fold_in(key, i), (), 0, n))
152
+ for i, n in enumerate(nvec)
153
+ ]
154
+ else:
155
+ raise ValueError(f"Unsupported action space: {type(action_space).__name__}")
156
+
157
+ return policy
158
+
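# Illustrative usage sketch (editor's addition, not part of the released module):
# drives CartPole-v1 with make_random_policy for a few steps. Assumes only that
# gymnasium is installed and provides the standard "CartPole-v1" environment.
def _example_random_policy() -> None:
    import gymnasium

    env = gymnasium.make("CartPole-v1")
    policy = make_random_policy(env, seed=0)
    raw_obs, _ = env.reset(seed=0)
    obs = _flatten_observation(raw_obs, env.observation_space)
    for _ in range(5):
        action = policy(obs)  # an int, since CartPole uses a Discrete action space
        raw_obs, _reward, terminated, truncated, _ = env.step(action)
        if terminated or truncated:
            raw_obs, _ = env.reset()
        obs = _flatten_observation(raw_obs, env.observation_space)
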
159
+
160
+ def make_epsilon_greedy_policy(
161
+ base_policy: Callable[[Array], Any],
162
+ env: gymnasium.Env[Any, Any],
163
+ epsilon: float = 0.1,
164
+ seed: int = 0,
165
+ ) -> Callable[[Array], Any]:
166
+ """Wrap a policy with epsilon-greedy exploration.
167
+
168
+ Args:
169
+ base_policy: The greedy policy to wrap
170
+ env: Gymnasium environment (for random action sampling)
171
+ epsilon: Probability of taking a random action
172
+ seed: Random seed
173
+
174
+ Returns:
175
+ Epsilon-greedy policy
176
+ """
177
+ random_policy = make_random_policy(env, seed + 1)
178
+ rng = jr.key(seed)
179
+
180
+ def policy(obs: Array) -> Any:
181
+ nonlocal rng
182
+ rng, key = jr.split(rng)
183
+
184
+ if jr.uniform(key) < epsilon:
185
+ return random_policy(obs)
186
+ return base_policy(obs)
187
+
188
+ return policy
189
+
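# Illustrative sketch (editor's addition): wrapping a hand-written greedy policy
# with epsilon-greedy exploration. The "always push right" base policy is a toy
# stand-in for a learned controller; CartPole-v1 is assumed to be available.
def _example_epsilon_greedy() -> None:
    import gymnasium

    env = gymnasium.make("CartPole-v1")

    def greedy(_obs: Array) -> int:
        return 1  # toy greedy policy: always push right

    policy = make_epsilon_greedy_policy(greedy, env, epsilon=0.1, seed=0)
    raw_obs, _ = env.reset(seed=0)
    obs = _flatten_observation(raw_obs, env.observation_space)
    action = policy(obs)  # action 1 with probability ~0.9, a random action otherwise
    env.step(action)
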
190
+
191
+ def collect_trajectory(
192
+ env: gymnasium.Env[Any, Any],
193
+ policy: Callable[[Array], Any] | None,
194
+ num_steps: int,
195
+ mode: PredictionMode = PredictionMode.REWARD,
196
+ include_action_in_features: bool = True,
197
+ seed: int = 0,
198
+ ) -> tuple[Array, Array]:
199
+ """Collect a trajectory from a Gymnasium environment.
200
+
201
+ This uses a Python loop to interact with the environment and collects
202
+ observations and targets into JAX arrays that can be used with scan-based
203
+ learning.
204
+
205
+ Args:
206
+ env: Gymnasium environment instance
207
+ policy: Action selection function. If None, uses random policy
208
+ num_steps: Number of steps to collect
209
+ mode: What to predict (REWARD, NEXT_STATE, VALUE)
210
+ include_action_in_features: If True, features = concat(obs, action)
211
+ seed: Random seed for environment resets and random policy
212
+
213
+ Returns:
214
+ Tuple of (observations, targets) as JAX arrays with shape
215
+ (num_steps, feature_dim) and (num_steps, target_dim)
216
+ """
217
+ if policy is None:
218
+ policy = make_random_policy(env, seed)
219
+
220
+ observations = []
221
+ targets = []
222
+
223
+ reset_count = 0
224
+ raw_obs, _ = env.reset(seed=seed + reset_count)
225
+ reset_count += 1
226
+ current_obs = _flatten_observation(raw_obs, env.observation_space)
227
+
228
+ for _ in range(num_steps):
229
+ action = policy(current_obs)
230
+ flat_action = _flatten_action(action, env.action_space)
231
+
232
+ raw_next_obs, reward, terminated, truncated, _ = env.step(action)
233
+ next_obs = _flatten_observation(raw_next_obs, env.observation_space)
234
+
235
+ # Construct features
236
+ if include_action_in_features:
237
+ features = jnp.concatenate([current_obs, flat_action])
238
+ else:
239
+ features = current_obs
240
+
241
+ # Construct target based on mode
242
+ if mode == PredictionMode.REWARD:
243
+ target = jnp.atleast_1d(jnp.array(reward, dtype=jnp.float32))
244
+ elif mode == PredictionMode.NEXT_STATE:
245
+ target = next_obs
246
+ else: # VALUE mode
247
+ # TD target with a zero bootstrap, so this reduces to the immediate reward; use TDStream for proper bootstrapping
248
+ target = jnp.atleast_1d(jnp.array(reward, dtype=jnp.float32))
249
+
250
+ observations.append(features)
251
+ targets.append(target)
252
+
253
+ if terminated or truncated:
254
+ raw_obs, _ = env.reset(seed=seed + reset_count)
255
+ reset_count += 1
256
+ current_obs = _flatten_observation(raw_obs, env.observation_space)
257
+ else:
258
+ current_obs = next_obs
259
+
260
+ return jnp.stack(observations), jnp.stack(targets)
261
+
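# Illustrative sketch (editor's addition): collecting a reward-prediction
# trajectory from CartPole-v1 with the default random policy. The expected
# shapes follow from CartPole's 4-dimensional Box observation and its Discrete
# action, which flattens to a single feature.
def _example_collect_trajectory() -> None:
    import gymnasium

    env = gymnasium.make("CartPole-v1")
    observations, targets = collect_trajectory(
        env,
        policy=None,  # falls back to make_random_policy
        num_steps=200,
        mode=PredictionMode.REWARD,
        include_action_in_features=True,
        seed=0,
    )
    assert observations.shape == (200, 5)  # 4 observation features + 1 action feature
    assert targets.shape == (200, 1)       # immediate reward
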
262
+
263
+ def learn_from_trajectory(
264
+ learner: LinearLearner,
265
+ observations: Array,
266
+ targets: Array,
267
+ learner_state: LearnerState | None = None,
268
+ ) -> tuple[LearnerState, Array]:
269
+ """Learn from a pre-collected trajectory using jax.lax.scan.
270
+
271
+ The per-step updates run inside jax.lax.scan, so a trajectory collected from a
272
+ Gymnasium environment is processed in compiled code rather than a Python loop.
273
+
274
+ Args:
275
+ learner: The learner to train
276
+ observations: Array of observations with shape (num_steps, feature_dim)
277
+ targets: Array of targets with shape (num_steps, target_dim)
278
+ learner_state: Initial state (if None, will be initialized)
279
+
280
+ Returns:
281
+ Tuple of (final_state, metrics_array) where metrics_array has shape
282
+ (num_steps, 3) with columns [squared_error, error, mean_step_size]
283
+ """
284
+ if learner_state is None:
285
+ learner_state = learner.init(observations.shape[1])
286
+
287
+ def step_fn(
288
+ state: LearnerState, inputs: tuple[Array, Array]
289
+ ) -> tuple[LearnerState, Array]:
290
+ obs, target = inputs
291
+ result = learner.update(state, obs, target)
292
+ return result.state, result.metrics
293
+
294
+ final_state, metrics = jax.lax.scan(step_fn, learner_state, (observations, targets))
295
+
296
+ return final_state, metrics
297
+
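# Illustrative sketch (editor's addition): the batch workflow from the module
# docstring, collect with a Python loop, then learn with jax.lax.scan. The
# LinearLearner(optimizer=IDBD()) construction mirrors the TDStream docstring
# below; the IDBD import path is an assumption and may differ in the package.
def _example_batch_learning() -> None:
    import gymnasium

    from alberta_framework.core.optimizers import IDBD  # assumed import path

    env = gymnasium.make("CartPole-v1")
    observations, targets = collect_trajectory(env, policy=None, num_steps=500)
    learner = LinearLearner(optimizer=IDBD())
    final_state, metrics = learn_from_trajectory(learner, observations, targets)
    # metrics has one row per step: [squared_error, error, mean_step_size]
    print(float(metrics[:, 0].mean()))  # mean squared error over the trajectory
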
298
+
299
+ def learn_from_trajectory_normalized(
300
+ learner: NormalizedLinearLearner,
301
+ observations: Array,
302
+ targets: Array,
303
+ learner_state: NormalizedLearnerState | None = None,
304
+ ) -> tuple[NormalizedLearnerState, Array]:
305
+ """Learn from a pre-collected trajectory with normalization using jax.lax.scan.
306
+
307
+ Args:
308
+ learner: The normalized learner to train
309
+ observations: Array of observations with shape (num_steps, feature_dim)
310
+ targets: Array of targets with shape (num_steps, target_dim)
311
+ learner_state: Initial state (if None, will be initialized)
312
+
313
+ Returns:
314
+ Tuple of (final_state, metrics_array) where metrics_array has shape
315
+ (num_steps, 4) with columns [squared_error, error, mean_step_size, normalizer_mean_var]
316
+ """
317
+ if learner_state is None:
318
+ learner_state = learner.init(observations.shape[1])
319
+
320
+ def step_fn(
321
+ state: NormalizedLearnerState, inputs: tuple[Array, Array]
322
+ ) -> tuple[NormalizedLearnerState, Array]:
323
+ obs, target = inputs
324
+ result = learner.update(state, obs, target)
325
+ return result.state, result.metrics
326
+
327
+ final_state, metrics = jax.lax.scan(step_fn, learner_state, (observations, targets))
328
+
329
+ return final_state, metrics
330
+
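# Illustrative sketch (editor's addition): the same batch workflow with the
# normalized learner. NormalizedLinearLearner's constructor arguments are not
# shown in this module, so the no-argument call below is an assumption.
def _example_batch_learning_normalized() -> None:
    import gymnasium

    env = gymnasium.make("CartPole-v1")
    observations, targets = collect_trajectory(env, policy=None, num_steps=500)
    learner = NormalizedLinearLearner()  # assumed default construction
    final_state, metrics = learn_from_trajectory_normalized(learner, observations, targets)
    # metrics columns: [squared_error, error, mean_step_size, normalizer_mean_var]
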
331
+
332
+ class GymnasiumStream:
333
+ """Experience stream from a Gymnasium environment using Python loop.
334
+
335
+ This class maintains iterator-based access for online learning scenarios
336
+ where you need to interact with the environment in real-time.
337
+
338
+ For batch learning, use collect_trajectory() followed by learn_from_trajectory().
339
+
340
+ Attributes:
341
+ mode: Prediction mode (REWARD, NEXT_STATE, VALUE)
342
+ gamma: Discount factor for VALUE mode
343
+ include_action_in_features: Whether to include action in features
344
+ episode_count: Number of completed episodes
345
+ """
346
+
347
+ def __init__(
348
+ self,
349
+ env: gymnasium.Env[Any, Any],
350
+ mode: PredictionMode = PredictionMode.REWARD,
351
+ policy: Callable[[Array], Any] | None = None,
352
+ gamma: float = 0.99,
353
+ include_action_in_features: bool = True,
354
+ seed: int = 0,
355
+ ):
356
+ """Initialize the Gymnasium stream.
357
+
358
+ Args:
359
+ env: Gymnasium environment instance
360
+ mode: What to predict (REWARD, NEXT_STATE, VALUE)
361
+ policy: Action selection function. If None, uses random policy
362
+ gamma: Discount factor for VALUE mode
363
+ include_action_in_features: If True, features = concat(obs, action).
364
+ If False, features = obs only
365
+ seed: Random seed for environment resets and random policy
366
+ """
367
+ self._env = env
368
+ self._mode = mode
369
+ self._gamma = gamma
370
+ self._include_action_in_features = include_action_in_features
371
+ self._seed = seed
372
+ self._reset_count = 0
373
+
374
+ if policy is None:
375
+ self._policy = make_random_policy(env, seed)
376
+ else:
377
+ self._policy = policy
378
+
379
+ self._obs_dim = _flatten_space(env.observation_space)
380
+ self._action_dim = _flatten_space(env.action_space)
381
+
382
+ if include_action_in_features:
383
+ self._feature_dim = self._obs_dim + self._action_dim
384
+ else:
385
+ self._feature_dim = self._obs_dim
386
+
387
+ if mode == PredictionMode.NEXT_STATE:
388
+ self._target_dim = self._obs_dim
389
+ else:
390
+ self._target_dim = 1
391
+
392
+ self._current_obs: Array | None = None
393
+ self._episode_count = 0
394
+ self._step_count = 0
395
+ self._value_estimator: Callable[[Array], float] | None = None
396
+
397
+ @property
398
+ def feature_dim(self) -> int:
399
+ """Return the dimension of feature vectors."""
400
+ return self._feature_dim
401
+
402
+ @property
403
+ def target_dim(self) -> int:
404
+ """Return the dimension of target vectors."""
405
+ return self._target_dim
406
+
407
+ @property
408
+ def episode_count(self) -> int:
409
+ """Return the number of completed episodes."""
410
+ return self._episode_count
411
+
412
+ @property
413
+ def step_count(self) -> int:
414
+ """Return the total number of steps taken."""
415
+ return self._step_count
416
+
417
+ @property
418
+ def mode(self) -> PredictionMode:
419
+ """Return the prediction mode."""
420
+ return self._mode
421
+
422
+ def set_value_estimator(self, estimator: Callable[[Array], float]) -> None:
423
+ """Set the value estimator for proper TD learning in VALUE mode."""
424
+ self._value_estimator = estimator
425
+
426
+ def _get_reset_seed(self) -> int:
427
+ """Get the seed for the next environment reset."""
428
+ seed = self._seed + self._reset_count
429
+ self._reset_count += 1
430
+ return seed
431
+
432
+ def _construct_features(self, obs: Array, action: Array) -> Array:
433
+ """Construct feature vector from observation and action."""
434
+ if self._include_action_in_features:
435
+ return jnp.concatenate([obs, action])
436
+ return obs
437
+
438
+ def _construct_target(
439
+ self,
440
+ reward: float,
441
+ next_obs: Array,
442
+ terminated: bool,
443
+ ) -> Array:
444
+ """Construct target based on prediction mode."""
445
+ if self._mode == PredictionMode.REWARD:
446
+ return jnp.atleast_1d(jnp.array(reward, dtype=jnp.float32))
447
+
448
+ elif self._mode == PredictionMode.NEXT_STATE:
449
+ return next_obs
450
+
451
+ elif self._mode == PredictionMode.VALUE:
452
+ if terminated:
453
+ return jnp.atleast_1d(jnp.array(reward, dtype=jnp.float32))
454
+
455
+ if self._value_estimator is not None:
456
+ next_value = self._value_estimator(next_obs)
457
+ else:
458
+ next_value = 0.0
459
+
460
+ target = reward + self._gamma * next_value
461
+ return jnp.atleast_1d(jnp.array(target, dtype=jnp.float32))
462
+
463
+ else:
464
+ raise ValueError(f"Unknown mode: {self._mode}")
465
+
466
+ def __iter__(self) -> Iterator[TimeStep]:
467
+ """Return self as iterator."""
468
+ return self
469
+
470
+ def __next__(self) -> TimeStep:
471
+ """Generate the next time step."""
472
+ if self._current_obs is None:
473
+ raw_obs, _ = self._env.reset(seed=self._get_reset_seed())
474
+ self._current_obs = _flatten_observation(raw_obs, self._env.observation_space)
475
+
476
+ action = self._policy(self._current_obs)
477
+ flat_action = _flatten_action(action, self._env.action_space)
478
+
479
+ raw_next_obs, reward, terminated, truncated, _ = self._env.step(action)
480
+ next_obs = _flatten_observation(raw_next_obs, self._env.observation_space)
481
+
482
+ features = self._construct_features(self._current_obs, flat_action)
483
+ target = self._construct_target(float(reward), next_obs, terminated)
484
+
485
+ self._step_count += 1
486
+
487
+ if terminated or truncated:
488
+ self._episode_count += 1
489
+ self._current_obs = None
490
+ else:
491
+ self._current_obs = next_obs
492
+
493
+ return TimeStep(observation=features, target=target)
494
+
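# Illustrative sketch (editor's addition): online learning from a GymnasiumStream,
# one update per environment step. The learner API (init/update, result.state)
# mirrors its use in learn_from_trajectory above; the IDBD import path is an
# assumption and may differ in the released package.
def _example_online_learning() -> None:
    import itertools

    import gymnasium

    from alberta_framework.core.optimizers import IDBD  # assumed import path

    env = gymnasium.make("CartPole-v1")
    stream = GymnasiumStream(env, mode=PredictionMode.REWARD, seed=0)
    learner = LinearLearner(optimizer=IDBD())
    state = learner.init(stream.feature_dim)

    for timestep in itertools.islice(stream, 1000):
        result = learner.update(state, timestep.observation, timestep.target)
        state = result.state
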
495
+
496
+ class TDStream:
497
+ """Experience stream for proper TD learning with value function bootstrap.
498
+
499
+ This stream integrates with a learner to use its predictions for
500
+ bootstrapping in TD targets.
501
+
502
+ Usage:
503
+ stream = TDStream(env)
504
+ learner = LinearLearner(optimizer=IDBD())
505
+ state = learner.init(stream.feature_dim)
506
+
507
+ for step, timestep in enumerate(stream):
508
+ result = learner.update(state, timestep.observation, timestep.target)
509
+ state = result.state
510
+ stream.update_value_function(lambda x: learner.predict(state, x))
511
+ """
512
+
513
+ def __init__(
514
+ self,
515
+ env: gymnasium.Env[Any, Any],
516
+ policy: Callable[[Array], Any] | None = None,
517
+ gamma: float = 0.99,
518
+ include_action_in_features: bool = False,
519
+ seed: int = 0,
520
+ ):
521
+ """Initialize the TD stream.
522
+
523
+ Args:
524
+ env: Gymnasium environment instance
525
+ policy: Action selection function. If None, uses random policy
526
+ gamma: Discount factor
527
+ include_action_in_features: If True, learn Q(s,a). If False, learn V(s)
528
+ seed: Random seed
529
+ """
530
+ self._env = env
531
+ self._gamma = gamma
532
+ self._include_action_in_features = include_action_in_features
533
+ self._seed = seed
534
+ self._reset_count = 0
535
+
536
+ if policy is None:
537
+ self._policy = make_random_policy(env, seed)
538
+ else:
539
+ self._policy = policy
540
+
541
+ self._obs_dim = _flatten_space(env.observation_space)
542
+ self._action_dim = _flatten_space(env.action_space)
543
+
544
+ if include_action_in_features:
545
+ self._feature_dim = self._obs_dim + self._action_dim
546
+ else:
547
+ self._feature_dim = self._obs_dim
548
+
549
+ self._current_obs: Array | None = None
550
+ self._episode_count = 0
551
+ self._step_count = 0
552
+ self._value_fn: Callable[[Array], float] = lambda _x: 0.0  # zero bootstrap until update_value_function() is called
553
+
554
+ @property
555
+ def feature_dim(self) -> int:
556
+ """Return the dimension of feature vectors."""
557
+ return self._feature_dim
558
+
559
+ @property
560
+ def episode_count(self) -> int:
561
+ """Return the number of completed episodes."""
562
+ return self._episode_count
563
+
564
+ @property
565
+ def step_count(self) -> int:
566
+ """Return the total number of steps taken."""
567
+ return self._step_count
568
+
569
+ def update_value_function(self, value_fn: Callable[[Array], float]) -> None:
570
+ """Update the value function used for TD bootstrapping."""
571
+ self._value_fn = value_fn
572
+
573
+ def _get_reset_seed(self) -> int:
574
+ """Get the seed for the next environment reset."""
575
+ seed = self._seed + self._reset_count
576
+ self._reset_count += 1
577
+ return seed
578
+
579
+ def _construct_features(self, obs: Array, action: Array) -> Array:
580
+ """Construct feature vector from observation and action."""
581
+ if self._include_action_in_features:
582
+ return jnp.concatenate([obs, action])
583
+ return obs
584
+
585
+ def __iter__(self) -> Iterator[TimeStep]:
586
+ """Return self as iterator."""
587
+ return self
588
+
589
+ def __next__(self) -> TimeStep:
590
+ """Generate the next time step with TD target."""
591
+ if self._current_obs is None:
592
+ raw_obs, _ = self._env.reset(seed=self._get_reset_seed())
593
+ self._current_obs = _flatten_observation(raw_obs, self._env.observation_space)
594
+
595
+ action = self._policy(self._current_obs)
596
+ flat_action = _flatten_action(action, self._env.action_space)
597
+
598
+ raw_next_obs, reward, terminated, truncated, _ = self._env.step(action)
599
+ next_obs = _flatten_observation(raw_next_obs, self._env.observation_space)
600
+
601
+ features = self._construct_features(self._current_obs, flat_action)
602
+ next_features = self._construct_features(next_obs, flat_action)  # bootstrap pairs the next state with the current action when learning Q(s, a)
603
+
604
+ if terminated:
605
+ target = jnp.atleast_1d(jnp.array(reward, dtype=jnp.float32))
606
+ else:
607
+ bootstrap = self._value_fn(next_features)
608
+ target_val = float(reward) + self._gamma * float(bootstrap)
609
+ target = jnp.atleast_1d(jnp.array(target_val, dtype=jnp.float32))
610
+
611
+ self._step_count += 1
612
+
613
+ if terminated or truncated:
614
+ self._episode_count += 1
615
+ self._current_obs = None
616
+ else:
617
+ self._current_obs = next_obs
618
+
619
+ return TimeStep(observation=features, target=target)
620
+
621
+
622
+ def make_gymnasium_stream(
623
+ env_id: str,
624
+ mode: PredictionMode = PredictionMode.REWARD,
625
+ policy: Callable[[Array], Any] | None = None,
626
+ gamma: float = 0.99,
627
+ include_action_in_features: bool = True,
628
+ seed: int = 0,
629
+ **env_kwargs: Any,
630
+ ) -> GymnasiumStream:
631
+ """Factory function to create a GymnasiumStream from an environment ID.
632
+
633
+ Args:
634
+ env_id: Gymnasium environment ID (e.g., "CartPole-v1")
635
+ mode: What to predict (REWARD, NEXT_STATE, VALUE)
636
+ policy: Action selection function. If None, uses random policy
637
+ gamma: Discount factor for VALUE mode
638
+ include_action_in_features: If True, features = concat(obs, action)
639
+ seed: Random seed
640
+ **env_kwargs: Additional arguments passed to gymnasium.make()
641
+
642
+ Returns:
643
+ GymnasiumStream wrapping the environment
644
+ """
645
+ import gymnasium
646
+
647
+ env = gymnasium.make(env_id, **env_kwargs)
648
+ return GymnasiumStream(
649
+ env=env,
650
+ mode=mode,
651
+ policy=policy,
652
+ gamma=gamma,
653
+ include_action_in_features=include_action_in_features,
654
+ seed=seed,
655
+ )
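
# Illustrative sketch (editor's addition): building a stream through the factory
# and checking its dimensions for CartPole-v1. Extra keyword arguments such as
# max_episode_steps are forwarded to gymnasium.make().
def _example_factory() -> None:
    stream = make_gymnasium_stream(
        "CartPole-v1",
        mode=PredictionMode.NEXT_STATE,
        include_action_in_features=True,
        seed=0,
        max_episode_steps=200,
    )
    assert stream.feature_dim == 5  # 4 observation features + 1 action feature
    assert stream.target_dim == 4   # NEXT_STATE targets are the flattened next observation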