jax-envelope 0.1.1__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (31) hide show
  1. envelope/__init__.py +16 -4
  2. envelope/compat/brax_envelope.py +5 -3
  3. envelope/compat/craftax_envelope.py +17 -2
  4. envelope/compat/gymnax_envelope.py +34 -7
  5. envelope/compat/jumanji_envelope.py +3 -2
  6. envelope/compat/kinetix_envelope.py +3 -2
  7. envelope/compat/mujoco_playground_envelope.py +1 -1
  8. envelope/compat/navix_envelope.py +1 -1
  9. envelope/environment.py +16 -9
  10. envelope/spaces.py +41 -21
  11. envelope/struct.py +10 -1
  12. envelope/wrappers/__init__.py +18 -2
  13. envelope/wrappers/autoreset_wrapper.py +65 -21
  14. envelope/wrappers/clip_action_wrapper.py +27 -0
  15. envelope/wrappers/continuous_observation_wrapper.py +61 -0
  16. envelope/wrappers/episode_statistics_wrapper.py +29 -36
  17. envelope/wrappers/flatten_action_wrapper.py +75 -0
  18. envelope/wrappers/flatten_observation_wrapper.py +81 -0
  19. envelope/wrappers/normalization.py +1 -1
  20. envelope/wrappers/observation_normalization_wrapper.py +28 -16
  21. envelope/wrappers/pooled_init_vmap_wrapper.py +122 -0
  22. envelope/wrappers/state_injection_wrapper.py +18 -22
  23. envelope/wrappers/truncation_wrapper.py +18 -14
  24. envelope/wrappers/vmap_envs_wrapper.py +26 -21
  25. envelope/wrappers/vmap_wrapper.py +36 -21
  26. envelope/wrappers/wrapper.py +8 -8
  27. {jax_envelope-0.1.1.dist-info → jax_envelope-0.2.0.dist-info}/METADATA +2 -2
  28. jax_envelope-0.2.0.dist-info/RECORD +32 -0
  29. jax_envelope-0.1.1.dist-info/RECORD +0 -27
  30. {jax_envelope-0.1.1.dist-info → jax_envelope-0.2.0.dist-info}/WHEEL +0 -0
  31. {jax_envelope-0.1.1.dist-info → jax_envelope-0.2.0.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,61 @@
1
+ from functools import cached_property
2
+ from typing import override
3
+
4
+ import jax
5
+ import jax.numpy as jnp
6
+
7
+ from envelope.environment import Info, State
8
+ from envelope.spaces import BatchedSpace, Continuous, Discrete, Space, peel_batched
9
+ from envelope.typing import Key, PyTree
10
+ from envelope.wrappers.wrapper import Wrapper
11
+
12
+
13
def to_float(obs: PyTree) -> PyTree:
    """Return ``obs`` with every leaf cast to ``float32``.

    The pytree structure is preserved; only leaf dtypes change.
    """

    def _as_float32(leaf):
        return leaf.astype(jnp.float32)

    return jax.tree.map(_as_float32, obs)
15
+
16
+
17
def to_continuous(space: Discrete | Continuous) -> Continuous:
    """Convert a leaf space into an equivalent float32 ``Continuous`` box.

    * ``Continuous``: the same ``low``/``high`` bounds, cast to float32.
    * ``Discrete``: a box over ``[0, n - 1]`` with the shape of ``n``.

    Raises:
        TypeError: if ``space`` is neither ``Discrete`` nor ``Continuous``.
    """
    if isinstance(space, Continuous):
        bounds = (space.low, space.high)
    elif isinstance(space, Discrete):
        n = jnp.asarray(space.n)
        bounds = (jnp.zeros_like(n, dtype=jnp.float32), n - 1)
    else:
        raise TypeError(f"Expected Discrete or Continuous, got {type(space)}")

    low, high = (jnp.asarray(bound, dtype=jnp.float32) for bound in bounds)
    return Continuous(low=low, high=high)
28
+
29
+
30
class ContinuousObservationWrapper(Wrapper):
    """Present every observation leaf as a float32 ``Continuous`` value.

    Observations returned by ``init``/``reset``/``step`` are cast to float32
    with ``to_float``. The observation space maps each ``Discrete`` /
    ``Continuous`` leaf through ``to_continuous`` and re-applies the batch
    levels stripped by ``peel_batched``.
    """

    @staticmethod
    def _float_obs(transition: tuple[State, Info]) -> tuple[State, Info]:
        """Cast the observation of a ``(state, info)`` pair to float32."""
        state, info = transition
        return state, info.update(obs=to_float(info.obs))

    @override
    def init(self, key: Key) -> tuple[State, Info]:
        return self._float_obs(self.env.init(key))

    @override
    def reset(self, key: Key, state: State) -> tuple[State, Info]:
        return self._float_obs(self.env.reset(key, state))

    @override
    def step(self, state: State, action: PyTree) -> tuple[State, Info]:
        return self._float_obs(self.env.step(state, action))

    @override
    @cached_property
    def observation_space(self) -> Space:
        batch_dims, space = peel_batched(self.env.observation_space)
        space = jax.tree.map(
            to_continuous,
            space,
            is_leaf=lambda node: isinstance(node, (Discrete, Continuous)),
        )
        # Re-wrap in the batch levels, in the order peel_batched returned them.
        for batch_dim in batch_dims:
            space = BatchedSpace(space, batch_dim)
        return space
@@ -1,47 +1,40 @@
1
- from dataclasses import field
2
- from envelope.wrappers import Wrapper
3
1
  from typing import override
4
2
 
5
- from envelope.environment import Environment, Info, State
6
- from envelope.typing import Key, PyTree, Array
3
+ import jax
7
4
 
5
+ from envelope.environment import Info, State
6
+ from envelope.struct import FrozenPyTreeNode, field
7
+ from envelope.typing import Key, PyTree
8
+ from envelope.wrappers.wrapper import WrappedState, Wrapper
8
9
 
9
class EpisodeStatistics(FrozenPyTreeNode):
    """Running totals for the episode currently in progress."""

    # Sum of ``info.reward`` over the steps taken so far this episode.
    reward: jax.Array = field(default=0)
    # Number of ``step`` calls made so far this episode.
    length: jax.Array = field(default=0)


class EpisodeStatisticsWrapper(Wrapper):
    """Track per-episode return and length and report them as ``info.stats``.

    ``init`` and ``reset`` start an episode with zeroed totals; every ``step``
    adds ``info.reward`` to ``stats.reward`` and 1 to ``stats.length``.
    """

    class EpisodeStatisticsState(WrappedState):
        # Totals for the episode currently in progress.
        stats: EpisodeStatistics = field(default=EpisodeStatistics())

    @override
    def init(self, key: Key) -> tuple[State, Info]:
        inner_state, info = self.env.init(key)
        state = self.EpisodeStatisticsState(inner_state=inner_state)
        return state, info.update(stats=state.stats)

    @override
    def reset(self, key: Key, state: State) -> tuple[State, Info]:
        inner_state, info = self.env.reset(key, state.inner_state)
        # A reset begins a new episode, so the totals must restart from zero
        # instead of accumulating across episodes. ``zeros_like`` keeps each
        # leaf's dtype and shape (including any batch axes), so the state
        # pytree stays type-stable between reset and step.
        stats = jax.tree.map(jax.numpy.zeros_like, state.stats)
        state = state.replace(inner_state=inner_state, stats=stats)
        return state, info.update(stats=stats)

    @override
    def step(self, state: State, action: PyTree) -> tuple[State, Info]:
        inner_state, info = self.env.step(state.inner_state, action)
        stats = state.stats.replace(
            reward=state.stats.reward + info.reward,
            length=state.stats.length + 1,
        )
        state = state.replace(inner_state=inner_state, stats=stats)
        return state, info.update(stats=stats)
@@ -0,0 +1,75 @@
1
+ from functools import cached_property
2
+ from typing import override
3
+
4
+ import jax
5
+ import jax.numpy as jnp
6
+
7
+ from envelope.environment import Info, State
8
+ from envelope.spaces import (
9
+ BatchedSpace,
10
+ Continuous,
11
+ Discrete,
12
+ PyTreeSpace,
13
+ Space,
14
+ peel_batched,
15
+ )
16
+ from envelope.typing import PyTree
17
+ from envelope.wrappers.wrapper import Wrapper
18
+
19
+
20
def flatten_space(space: "PyTreeSpace | Continuous | Discrete"):
    """Describe how a flat vector maps onto the leaves of ``space``.

    Returns:
        ``(treedef, shapes, dims)``: the pytree structure of ``space.shape``,
        the shape tuple of each leaf (in ``jax.tree`` leaf order), and each
        leaf's element count as a plain Python ``int``.
    """

    def is_leaf(x):
        # Tuples containing only integers are shape tuples (leaves)
        # PyTreeSpace can only have tuples that contain at least a Space, so
        # tuples with only integers must be shape tuples from leaf spaces
        return isinstance(x, tuple) and all(isinstance(i, int) for i in x)

    shapes, treedef = jax.tree.flatten(space.shape, is_leaf=is_leaf)
    # Element counts are static metadata, so compute them with Python ints.
    # ``jnp.prod`` would be staged under ``jax.jit`` (yielding tracers that
    # ``jnp.split`` cannot use), and ``jnp.prod(jnp.asarray(()))`` is the
    # float ``1.0`` rather than ``1`` for scalar leaves.
    dims = []
    for shape in shapes:
        size = 1
        for extent in shape:
            size *= int(extent)
        dims.append(size)
    return treedef, shapes, dims


def unflatten_x(x: jax.Array, treedef, shapes, dims):
    """Split the last axis of ``x`` back into the pytree described by ``flatten_space``.

    ``x`` has shape ``(*lead, sum(dims))``; each output leaf has shape
    ``(*lead, *shape)``. For a 1-D ``x`` (``lead == ()``) this is a plain split.
    """
    # Split points as Python ints so they remain concrete under ``jax.jit``.
    indices = []
    offset = 0
    for dim in dims[:-1]:  # the last piece is the remainder
        offset += int(dim)
        indices.append(offset)
    parts = jnp.split(x, indices, axis=-1)
    lead = tuple(x.shape[:-1])
    leaves = [part.reshape(lead + tuple(shape)) for part, shape in zip(parts, shapes)]
    return jax.tree.unflatten(treedef, leaves)
37
+
38
+
39
class FlattenActionWrapper(Wrapper):
    """Expose the wrapped env's (possibly nested) action space as one flat vector.

    All leaf spaces must share a type (all ``Continuous`` or all ``Discrete``).
    Agents act in the concatenated space; ``step`` splits the flat action back
    into the original pytree before forwarding it. Batch levels of the wrapped
    action space are preserved.
    """

    @override
    def step(self, state: State, action: PyTree) -> tuple[State, Info]:
        batch_dims, base = peel_batched(self.env.action_space)
        treedef, shapes, dims = flatten_space(base)

        def unflatten(flat):
            return unflatten_x(flat, treedef, shapes, dims)

        # ``action_space`` re-wraps the flat space in one BatchedSpace per
        # batch level, so an incoming action carries those leading axes.
        # Map the per-vector split over each of them instead of splitting
        # along a batch axis. Assumes each batch level adds exactly one
        # leading axis — TODO confirm against peel_batched/BatchedSpace.
        for _ in batch_dims:
            unflatten = jax.vmap(unflatten)
        return self.env.step(state, unflatten(action))

    @override
    @cached_property
    def action_space(self) -> Space:
        batch_dims, base = peel_batched(self.env.action_space)

        def is_leaf(x):
            return isinstance(x, (Continuous, Discrete))

        spaces = jax.tree.leaves(base, is_leaf=is_leaf)
        act_cls = type(spaces[0])

        if not all(isinstance(space, act_cls) for space in spaces):
            raise ValueError("All spaces must be of the same type")

        # Leaves are concatenated in ``jax.tree`` leaf order, matching the
        # split order used by ``step``.
        if act_cls is Continuous:
            lows = [jnp.asarray(s.low).reshape(-1) for s in spaces]
            highs = [jnp.asarray(s.high).reshape(-1) for s in spaces]
            low = jnp.concatenate(lows, axis=0)
            high = jnp.concatenate(highs, axis=0)
            space = Continuous(low=low, high=high)
        elif act_cls is Discrete:
            ns = [jnp.asarray(s.n).reshape(-1) for s in spaces]
            n = jnp.concatenate(ns, axis=0)
            space = Discrete(n=n)
        else:
            raise ValueError(f"Unsupported space type: {act_cls}")

        for batch_dim in batch_dims:
            space = BatchedSpace(space, batch_dim)
        return space
@@ -0,0 +1,81 @@
1
+ from functools import cached_property
2
+ from typing import override
3
+
4
+ import jax
5
+ import jax.numpy as jnp
6
+
7
+ from envelope.environment import Info, State
8
+ from envelope.spaces import BatchedSpace, Continuous, Discrete, Space, peel_batched
9
+ from envelope.typing import Key, PyTree
10
+ from envelope.wrappers.wrapper import Wrapper
11
+
12
+
13
def flatten_space(space: "Space"):
    """Describe how a flat vector maps onto the leaves of ``space``.

    Returns:
        ``(treedef, shapes, dims)``: the pytree structure of ``space.shape``,
        the shape tuple of each leaf (in ``jax.tree`` leaf order), and each
        leaf's element count as a plain Python ``int``.
    """

    def is_leaf(x):
        # Tuples containing only integers are shape tuples (leaves)
        # PyTreeSpace can only have tuples that contain at least a Space, so
        # tuples with only integers must be shape tuples from leaf spaces
        return isinstance(x, tuple) and all(isinstance(i, int) for i in x)

    shapes, treedef = jax.tree.flatten(space.shape, is_leaf=is_leaf)
    # Element counts are static metadata, so compute them with Python ints:
    # ``jnp.prod`` is staged (traced) under ``jax.jit`` and returns the float
    # ``1.0`` for a scalar shape ``()``.
    dims = []
    for shape in shapes:
        size = 1
        for extent in shape:
            size *= int(extent)
        dims.append(size)
    return treedef, shapes, dims
23
+
24
+
25
def flatten_x(x: PyTree):
    """Ravel every leaf of ``x`` to 1-D and concatenate them into one vector.

    Leaves are taken in ``jax.tree.leaves`` order.
    """
    raveled = [jnp.asarray(leaf).reshape(-1) for leaf in jax.tree.leaves(x)]
    return jnp.concatenate(raveled, axis=0)
30
+
31
+
32
class FlattenObservationWrapper(Wrapper):
    """Expose the wrapped env's (possibly nested) observation as one flat vector.

    All observation leaf spaces must share a type (all ``Continuous`` or all
    ``Discrete``). Batch levels of the wrapped observation space are kept:
    a batched observation is flattened per env to ``(*batch, D)``.
    """

    def _flatten_obs(self, obs: PyTree) -> jax.Array:
        """Flatten ``obs`` while preserving its leading batch axes."""
        # One vmap per batch level of the wrapped observation space, so only
        # the per-env part of each leaf is raveled and concatenated. Without
        # this, batch axes would be raveled into one vector mixing all envs,
        # contradicting ``observation_space``. Assumes each batch level adds
        # exactly one leading axis — TODO confirm against peel_batched.
        batch_dims, _ = peel_batched(self.env.observation_space)
        flatten = flatten_x
        for _ in batch_dims:
            flatten = jax.vmap(flatten)
        return flatten(obs)

    @override
    def init(self, key: Key) -> tuple[State, Info]:
        state, info = self.env.init(key)
        info = info.update(obs=self._flatten_obs(info.obs))
        return state, info

    @override
    def reset(self, key: Key, state: State) -> tuple[State, Info]:
        state, info = self.env.reset(key, state)
        info = info.update(obs=self._flatten_obs(info.obs))
        return state, info

    @override
    def step(self, state: State, action: PyTree) -> tuple[State, Info]:
        state, info = self.env.step(state, action)
        info = info.update(obs=self._flatten_obs(info.obs))
        return state, info

    @override
    @cached_property
    def observation_space(self) -> Space:
        batch_dims, base = peel_batched(self.env.observation_space)

        def is_leaf(x):
            return isinstance(x, (Continuous, Discrete))

        spaces = jax.tree.leaves(base, is_leaf=is_leaf)
        obs_cls = type(spaces[0])

        if not all(isinstance(space, obs_cls) for space in spaces):
            raise ValueError("All spaces must be of the same type")

        # Leaves are concatenated in ``jax.tree`` leaf order, matching the
        # order ``flatten_x`` uses for observations.
        if obs_cls is Continuous:
            lows = [jnp.asarray(s.low).reshape(-1) for s in spaces]
            highs = [jnp.asarray(s.high).reshape(-1) for s in spaces]
            low = jnp.concatenate(lows, axis=0)
            high = jnp.concatenate(highs, axis=0)
            space = Continuous(low=low, high=high)
        elif obs_cls is Discrete:
            ns = [jnp.asarray(s.n).reshape(-1) for s in spaces]
            n = jnp.concatenate(ns, axis=0)
            space = Discrete(n=n)
        else:
            raise ValueError(f"Unsupported space type: {obs_cls}")

        for batch_dim in batch_dims:
            space = BatchedSpace(space, batch_dim)
        return space
@@ -16,7 +16,7 @@ class MeanVarPair(NamedTuple):
16
16
  class RunningMeanVar(FrozenPyTreeNode):
17
17
  mean: PyTree
18
18
  var: PyTree
19
- count: int
19
+ count: int | Array
20
20
 
21
21
  @cached_property
22
22
  def std(self) -> PyTree:
@@ -1,10 +1,11 @@
1
- from typing import override
1
+ from functools import cached_property
2
+ from typing import cast, override
2
3
 
3
4
  import jax
4
5
  from jax import numpy as jnp
5
6
 
6
7
  from envelope.environment import Info
7
- from envelope.spaces import BatchedSpace, PyTreeSpace, Space
8
+ from envelope.spaces import BatchedSpace, Continuous, Discrete, PyTreeSpace, Space
8
9
  from envelope.struct import field, static_field
9
10
  from envelope.typing import Key, PyTree
10
11
  from envelope.wrappers.normalization import RunningMeanVar, update_rmv
@@ -36,7 +37,7 @@ class ObservationNormalizationWrapper(Wrapper):
36
37
  mean = jax.tree.map(zeros, self.stats_spec)
37
38
  var = jax.tree.map(ones, self.stats_spec)
38
39
 
39
- return RunningMeanVar(mean=mean, var=var, count=0)
40
+ return RunningMeanVar(mean=mean, var=var, count=jnp.asarray(0))
40
41
 
41
42
  def _normalize_obs(self, obs: PyTree, rmv: RunningMeanVar) -> PyTree:
42
43
  def norm_leaf(x, mean, std, spec):
@@ -66,29 +67,40 @@ class ObservationNormalizationWrapper(Wrapper):
66
67
  return state, info
67
68
 
68
69
  @override
69
- def reset(
70
- self, key: Key, state: PyTree | None = None, **kwargs
71
- ) -> tuple[WrappedState, Info]:
72
- inner_state = None
70
+ def init(self, key: Key) -> tuple[WrappedState, Info]:
71
+ inner_state, info = self.env.init(key)
73
72
  rmv_state = self._init_rmv_state()
74
- if state:
75
- inner_state = state.inner_state
76
- rmv_state = state.rmv_state
77
-
78
- inner_state, info = self.env.reset(key, inner_state, **kwargs)
79
73
  next_state = self.ObservationNormalizationState(
80
74
  inner_state=inner_state, rmv_state=rmv_state
81
75
  )
82
76
  return self._normalize_and_update(next_state, info)
83
77
 
84
78
  @override
85
- def step(
86
- self, state: WrappedState, action: PyTree, **kwargs
87
- ) -> tuple[WrappedState, Info]:
88
- inner_state, info = self.env.step(state.inner_state, action, **kwargs)
79
+ def reset(self, key: Key, state: WrappedState) -> tuple[WrappedState, Info]:
80
+ inner_state, info = self.env.reset(key, state.inner_state)
81
+ # Preserve running statistics across resets
82
+ next_state = self.ObservationNormalizationState(
83
+ inner_state=inner_state, rmv_state=state.rmv_state
84
+ )
85
+ return self._normalize_and_update(next_state, info)
86
+
87
+ @override
88
+ def step(self, state: WrappedState, action: PyTree) -> tuple[WrappedState, Info]:
89
+ inner_state, info = self.env.step(state.inner_state, action)
89
90
  state = state.replace(inner_state=inner_state)
90
91
  return self._normalize_and_update(state, info)
91
92
 
93
+ @override
94
+ @cached_property
95
+ def observation_space(self) -> Space:
96
+ def to_continuous(space: Continuous | Discrete) -> Continuous:
97
+ return Continuous.from_shape(low=-jnp.inf, high=jnp.inf, shape=space.shape)
98
+
99
+ def is_leaf(space: Space) -> bool:
100
+ return isinstance(space, (Discrete, Continuous))
101
+
102
+ return jax.tree.map(to_continuous, self.env.observation_space, is_leaf=is_leaf)
103
+
92
104
 
93
105
  def _infer_stats_spec(space: Space) -> PyTree:
94
106
  """
@@ -0,0 +1,122 @@
1
+ from functools import cached_property
2
+ from typing import override
3
+
4
+ import jax
5
+ import jax.numpy as jnp
6
+
7
+ from envelope import spaces
8
+ from envelope.environment import Info
9
+ from envelope.struct import field
10
+ from envelope.typing import Key, PyTree
11
+ from envelope.wrappers.vmap_wrapper import _split_or_keep_key
12
+ from envelope.wrappers.wrapper import WrappedState, Wrapper
13
+
14
+
15
class PooledInitVmapWrapper(Wrapper):
    """Vectorise an env and auto-reset finished envs from a pool of init states.

    ``batch_size`` copies of the wrapped env run under ``jax.vmap``. On every
    ``step``, ``pool_size`` fresh initial states are built with
    ``self.env.init``; each env whose episode ended
    (``terminated | truncated``) continues from a randomly chosen pool member.
    The info of the step that ended an env's episode is reported as
    ``info.final`` until that env's next episode ends.
    """

    # NOTE(review): both sizes are used as Python-level shapes/counts
    # (``randint`` shape, key split count). If this wrapper is ever passed
    # through ``jit`` as a pytree argument they would need to be static
    # fields — confirm the intended use of ``field`` here.
    batch_size: int = field(kw_only=True)
    pool_size: int = field(kw_only=True)

    class PooledInitVmapState(WrappedState):
        # Key from which each ``step`` derives its init pool and pool indices.
        init_key: Key = field()
        # Per-env info of the step that most recently ended that env's episode.
        last_final: Info = field()

    @override
    def init(self, key: Key) -> tuple[WrappedState, Info]:
        keys = _split_or_keep_key(key, self.batch_size + 1)
        key_next, keys_pool = keys[0], keys[1:]
        inner_state, info = jax.vmap(self.env.init)(keys_pool)
        # Placeholder ``final`` before any episode has ended: same tree
        # structure as ``info``, every leaf a float32 NaN array of that shape.
        pholder_info = jax.tree.map(
            lambda x: jnp.full_like(x, jnp.nan, dtype=jnp.float32), info
        )
        state = self.PooledInitVmapState(
            inner_state=inner_state,
            init_key=key_next,
            last_final=pholder_info,
        )
        return state, info.update(final=pholder_info)

    @override
    def reset(self, key: Key, state: WrappedState) -> tuple[WrappedState, Info]:
        """Reset every env with a plain vmapped ``self.env.reset``.

        Unlike ``step``, this does not draw states from the init pool. A pooled
        reset would need to init a pool of unwrapped envs and inject those
        states into the stack of wrapped states, which the opaque wrapped-state
        structure does not allow without hacks (e.g. recursively probing for
        ``inner_state``). There is currently no use case for carrying state
        across episodes before vmapping, so that is deferred. ``last_final`` is
        passed through unchanged as ``info.final``.
        """
        keys = _split_or_keep_key(key, self.batch_size + 1)
        key_next, keys_pool = keys[0], keys[1:]
        inner_state, info = jax.vmap(self.env.reset)(keys_pool, state.inner_state)
        state = state.replace(inner_state=inner_state, init_key=key_next)
        return state, info.update(final=state.last_final)

    @override
    def step(self, state: WrappedState, action: PyTree) -> tuple[WrappedState, Info]:
        inner_state, info = jax.vmap(self.env.step)(state.inner_state, action)
        done = info.terminated | info.truncated

        # Three independent keys derived from the carried init key.
        key_pool = jax.random.fold_in(state.init_key, 0)
        next_init_key = jax.random.fold_in(state.init_key, 1)
        key_idxs = jax.random.fold_in(state.init_key, 2)

        # Build pool_size fresh init states and assign each env a random one.
        keys_pool = jax.random.split(key_pool, self.pool_size)
        pool_states, pool_infos = jax.vmap(self.env.init)(keys_pool)
        pool_idxs = jax.random.randint(key_idxs, (self.batch_size,), 0, self.pool_size)
        init_state = jax.tree.map(lambda x: x[pool_idxs], pool_states)
        init_obs = jax.tree.map(lambda x: x[pool_idxs], pool_infos.obs)

        select = jax.vmap(jnp.where)

        def where_done(if_done, otherwise):
            """Per env: take ``if_done`` where that env finished, else ``otherwise``."""
            return jax.tree.map(lambda a, b: select(done, a, b), if_done, otherwise)

        # Finished envs restart from their pooled init state; others continue.
        next_inner_state = where_done(init_state, inner_state)
        # Finished envs record this step's info as their final info; others
        # keep their previous episode's final info. The same value is carried
        # in the state and reported as ``info.final`` (computed once).
        last_final = where_done(info, state.last_final)
        # Finished envs report the first obs of their new episode.
        obs = where_done(init_obs, info.obs)

        state = state.replace(
            inner_state=next_inner_state,
            init_key=next_init_key,
            last_final=last_final,
        )
        return state, info.update(obs=obs, final=last_final)

    @override
    @cached_property
    def observation_space(self) -> spaces.Space:
        return spaces.BatchedSpace(
            space=self.env.observation_space, batch_size=self.batch_size
        )

    @override
    @cached_property
    def action_space(self) -> spaces.Space:
        return spaces.BatchedSpace(
            space=self.env.action_space, batch_size=self.batch_size
        )
@@ -1,3 +1,5 @@
1
+ from typing import override
2
+
1
3
  from envelope.environment import Info, InfoContainer
2
4
  from envelope.struct import field
3
5
  from envelope.typing import Key, PyTree
@@ -12,7 +14,7 @@ class StateInjectionWrapper(Wrapper):
12
14
 
13
15
  Usage:
14
16
  env = AutoResetWrapper(StateInjectionWrapper(env=base_env))
15
- state, info = env.reset(key)
17
+ state, info = env.init(key)
16
18
 
17
19
  for outer_iter in range(num_outer_iters):
18
20
  # Sample a new task and set it as the reset state
@@ -60,32 +62,26 @@ class StateInjectionWrapper(Wrapper):
60
62
 
61
63
  return update_injected(state)
62
64
 
63
- def reset(
64
- self, key: Key, state: PyTree | None = None, **kwargs
65
- ) -> tuple[WrappedState, Info]:
66
- # Default state has no inner state to reset to
67
- if state is None:
68
- state = self.InjectedState(inner_state=None)
69
-
70
- # If no reset state is set, reset wrapped environment
71
- if state.reset_state is None and state.reset_obs is None:
72
- inner_state, info = self.env.reset(key, state=state.inner_state, **kwargs)
65
+ @override
66
+ def init(self, key: Key) -> tuple[WrappedState, Info]:
67
+ inner_state, info = self.env.init(key)
68
+ state = self.InjectedState(inner_state=inner_state)
69
+ return state, info
73
70
 
74
- # If reset state is set, use it
75
- elif state.reset_state is not None and state.reset_obs is not None:
71
+ @override
72
+ def reset(self, key: Key, state: WrappedState) -> tuple[WrappedState, Info]:
73
+ # If reset state is set, use it instead of resetting inner env
74
+ if state.reset_state is not None and state.reset_obs is not None:
76
75
  inner_state = state.reset_state
77
76
  info = InfoContainer(obs=state.reset_obs, reward=0.0, terminated=False)
78
-
79
- # If only one of reset_state or reset_obs is set, raise error
77
+ elif state.reset_state is None and state.reset_obs is None:
78
+ inner_state, info = self.env.reset(key, state.inner_state)
80
79
  else:
81
80
  raise ValueError("State must set both reset_state and reset_obs or neither")
82
81
 
83
- # Return new state with updated inner state
84
- state = state.replace(inner_state=inner_state)
85
- return state, info
82
+ return state.replace(inner_state=inner_state), info
86
83
 
87
- def step(
88
- self, state: WrappedState, action: PyTree, **kwargs
89
- ) -> tuple[WrappedState, Info]:
90
- inner_state, info = self.env.step(state.inner_state, action, **kwargs)
84
+ @override
85
+ def step(self, state: WrappedState, action: PyTree) -> tuple[WrappedState, Info]:
86
+ inner_state, info = self.env.step(state.inner_state, action)
91
87
  return state.replace(inner_state=inner_state), info
@@ -1,3 +1,5 @@
1
+ from typing import override
2
+
1
3
  import jax.numpy as jnp
2
4
 
3
5
  from envelope.environment import Info
@@ -12,20 +14,22 @@ class TruncationWrapper(Wrapper):
12
14
  class TruncationState(WrappedState):
13
15
  steps: jnp.ndarray | int = field(default=0)
14
16
 
15
- def reset(
16
- self, key: Key, state: PyTree | None = None, **kwargs
17
- ) -> tuple[WrappedState, Info]:
18
- inner_state, info = self.env.reset(key)
17
+ @override
18
+ def init(self, key: Key) -> tuple[WrappedState, Info]:
19
+ inner_state, info = self.env.init(key)
19
20
  state = self.TruncationState(inner_state=inner_state, steps=0)
20
21
  return state, info.update(truncated=self.max_steps <= 0)
21
22
 
22
- def step(
23
- self, state: WrappedState, action: PyTree, **kwargs
24
- ) -> tuple[WrappedState, Info]:
25
- next_inner_state, info = self.env.step(state.inner_state, action, **kwargs)
26
- next_steps = state.steps + 1
27
- next_state = self.TruncationState(
28
- inner_state=next_inner_state, steps=next_steps
29
- )
30
- truncated = jnp.asarray(next_steps) >= self.max_steps
31
- return next_state, info.update(truncated=truncated)
23
+ @override
24
+ def reset(self, key: Key, state: WrappedState) -> tuple[WrappedState, Info]:
25
+ inner_state, info = self.env.reset(key, state.inner_state)
26
+ state = state.replace(inner_state=inner_state, steps=0)
27
+ return state, info.update(truncated=self.max_steps <= 0)
28
+
29
+ @override
30
+ def step(self, state: WrappedState, action: PyTree) -> tuple[WrappedState, Info]:
31
+ next_inner_state, info = self.env.step(state.inner_state, action)
32
+ steps = state.steps + 1
33
+ state = self.TruncationState(inner_state=next_inner_state, steps=steps)
34
+ truncated = jnp.asarray(steps) >= self.max_steps
35
+ return state, info.update(truncated=truncated)