jax-envelope 0.1.0__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- envelope/__init__.py +54 -0
- envelope/compat/brax_envelope.py +5 -3
- envelope/compat/craftax_envelope.py +17 -2
- envelope/compat/gymnax_envelope.py +34 -7
- envelope/compat/jumanji_envelope.py +3 -2
- envelope/compat/kinetix_envelope.py +3 -2
- envelope/compat/mujoco_playground_envelope.py +1 -1
- envelope/compat/navix_envelope.py +1 -1
- envelope/environment.py +16 -9
- envelope/spaces.py +41 -21
- envelope/struct.py +10 -1
- envelope/typing.py +0 -16
- envelope/wrappers/__init__.py +36 -0
- envelope/wrappers/autoreset_wrapper.py +65 -21
- envelope/wrappers/clip_action_wrapper.py +27 -0
- envelope/wrappers/continuous_observation_wrapper.py +61 -0
- envelope/wrappers/episode_statistics_wrapper.py +29 -36
- envelope/wrappers/flatten_action_wrapper.py +75 -0
- envelope/wrappers/flatten_observation_wrapper.py +81 -0
- envelope/wrappers/normalization.py +1 -1
- envelope/wrappers/observation_normalization_wrapper.py +28 -16
- envelope/wrappers/pooled_init_vmap_wrapper.py +122 -0
- envelope/wrappers/state_injection_wrapper.py +18 -22
- envelope/wrappers/truncation_wrapper.py +18 -14
- envelope/wrappers/vmap_envs_wrapper.py +26 -21
- envelope/wrappers/vmap_wrapper.py +36 -21
- envelope/wrappers/wrapper.py +8 -8
- {jax_envelope-0.1.0.dist-info → jax_envelope-0.2.0.dist-info}/METADATA +3 -3
- jax_envelope-0.2.0.dist-info/RECORD +32 -0
- envelope/wrappers/timestep_wrapper.py +0 -22
- jax_envelope-0.1.0.dist-info/RECORD +0 -27
- {jax_envelope-0.1.0.dist-info → jax_envelope-0.2.0.dist-info}/WHEEL +0 -0
- {jax_envelope-0.1.0.dist-info → jax_envelope-0.2.0.dist-info}/licenses/LICENSE +0 -0
envelope/__init__.py
CHANGED
|
@@ -0,0 +1,54 @@
|
|
|
1
|
+
from envelope.compat import create
|
|
2
|
+
from envelope.environment import Environment, Info, InfoContainer
|
|
3
|
+
from envelope.spaces import BatchedSpace, Continuous, Discrete, PyTreeSpace, Space
|
|
4
|
+
from envelope.struct import Container, FrozenPyTreeNode, field, static_field
|
|
5
|
+
from envelope.wrappers import (
|
|
6
|
+
AutoResetWrapper,
|
|
7
|
+
ClipActionWrapper,
|
|
8
|
+
ContinuousObservationWrapper,
|
|
9
|
+
EpisodeStatisticsWrapper,
|
|
10
|
+
FlattenActionWrapper,
|
|
11
|
+
FlattenObservationWrapper,
|
|
12
|
+
ObservationNormalizationWrapper,
|
|
13
|
+
PooledInitVmapWrapper,
|
|
14
|
+
StateInjectionWrapper,
|
|
15
|
+
TruncationWrapper,
|
|
16
|
+
VmapEnvsWrapper,
|
|
17
|
+
VmapWrapper,
|
|
18
|
+
WrappedState,
|
|
19
|
+
Wrapper,
|
|
20
|
+
)
|
|
21
|
+
|
|
22
|
+
__all__ = [
|
|
23
|
+
# Basic functionality
|
|
24
|
+
"create",
|
|
25
|
+
"Environment",
|
|
26
|
+
"Info",
|
|
27
|
+
"InfoContainer",
|
|
28
|
+
# Spaces
|
|
29
|
+
"Space",
|
|
30
|
+
"BatchedSpace",
|
|
31
|
+
"Continuous",
|
|
32
|
+
"Discrete",
|
|
33
|
+
"PyTreeSpace",
|
|
34
|
+
# Struct
|
|
35
|
+
"field",
|
|
36
|
+
"static_field",
|
|
37
|
+
"FrozenPyTreeNode",
|
|
38
|
+
"Container",
|
|
39
|
+
# Wrappers
|
|
40
|
+
"Wrapper",
|
|
41
|
+
"WrappedState",
|
|
42
|
+
"AutoResetWrapper",
|
|
43
|
+
"ClipActionWrapper",
|
|
44
|
+
"ContinuousObservationWrapper",
|
|
45
|
+
"EpisodeStatisticsWrapper",
|
|
46
|
+
"FlattenActionWrapper",
|
|
47
|
+
"FlattenObservationWrapper",
|
|
48
|
+
"ObservationNormalizationWrapper",
|
|
49
|
+
"PooledInitVmapWrapper",
|
|
50
|
+
"StateInjectionWrapper",
|
|
51
|
+
"TruncationWrapper",
|
|
52
|
+
"VmapWrapper",
|
|
53
|
+
"VmapEnvsWrapper",
|
|
54
|
+
]
|
envelope/compat/brax_envelope.py
CHANGED
|
@@ -48,7 +48,7 @@ class BraxEnvelope(Environment):
|
|
|
48
48
|
def default_max_steps(self) -> int:
|
|
49
49
|
return _BRAX_DEFAULT_EPISODE_LENGTH
|
|
50
50
|
|
|
51
|
-
def __post_init__(self)
|
|
51
|
+
def __post_init__(self):
|
|
52
52
|
if isinstance(self.brax_env, BraxWrapper):
|
|
53
53
|
warnings.warn(
|
|
54
54
|
"Environment wrapping should be handled by envelope. "
|
|
@@ -57,7 +57,7 @@ class BraxEnvelope(Environment):
|
|
|
57
57
|
object.__setattr__(self, "brax_env", self.brax_env.unwrapped)
|
|
58
58
|
|
|
59
59
|
@override
|
|
60
|
-
def
|
|
60
|
+
def init(self, key: Key) -> tuple[State, Info]:
|
|
61
61
|
brax_state = self.brax_env.reset(key)
|
|
62
62
|
info = InfoContainer(obs=brax_state.obs, reward=0.0, terminated=False)
|
|
63
63
|
info = info.update(**dataclasses.asdict(brax_state))
|
|
@@ -67,7 +67,9 @@ class BraxEnvelope(Environment):
|
|
|
67
67
|
def step(self, state: State, action: PyTree) -> tuple[State, Info]:
|
|
68
68
|
brax_state = self.brax_env.step(state, action)
|
|
69
69
|
info = InfoContainer(
|
|
70
|
-
obs=brax_state.obs,
|
|
70
|
+
obs=brax_state.obs,
|
|
71
|
+
reward=brax_state.reward,
|
|
72
|
+
terminated=jnp.asarray(brax_state.done, dtype=bool).item(),
|
|
71
73
|
)
|
|
72
74
|
info = info.update(**dataclasses.asdict(brax_state))
|
|
73
75
|
return brax_state, info
|
|
@@ -22,7 +22,7 @@ class CraftaxEnvelope(Environment):
|
|
|
22
22
|
"""Wrapper to convert a Craftax environment to an envelope environment."""
|
|
23
23
|
|
|
24
24
|
craftax_env: Any = static_field()
|
|
25
|
-
env_params: PyTree
|
|
25
|
+
env_params: PyTree = static_field() # TODO: remove static marker as soon as craftax merges https://github.com/MichaelTMatthews/Craftax/pull/48
|
|
26
26
|
|
|
27
27
|
@classmethod
|
|
28
28
|
def from_name(
|
|
@@ -54,12 +54,27 @@ class CraftaxEnvelope(Environment):
|
|
|
54
54
|
def default_max_steps(self) -> int:
|
|
55
55
|
return int(self.craftax_env.default_params.max_timesteps)
|
|
56
56
|
|
|
57
|
+
@cached_property
|
|
58
|
+
def _craftax_info_placeholder(self) -> PyTree:
|
|
59
|
+
key = jax.random.PRNGKey(0)
|
|
60
|
+
_, state = self.craftax_env.reset(key, self.env_params)
|
|
61
|
+
_, _, _, _, info = self.craftax_env.step(
|
|
62
|
+
key,
|
|
63
|
+
state,
|
|
64
|
+
self.craftax_env.action_space(self.env_params).sample(key),
|
|
65
|
+
self.env_params,
|
|
66
|
+
)
|
|
67
|
+
return jax.tree.map(lambda x: jnp.full_like(x, jnp.nan), info)
|
|
68
|
+
|
|
57
69
|
@override
|
|
58
|
-
def
|
|
70
|
+
def init(self, key: Key) -> tuple[State, Info]:
|
|
71
|
+
# TODO: this function does not add env_info (or comparable) to the info
|
|
72
|
+
# container. We should add tests for this (and all other envelopes) and fix it.
|
|
59
73
|
key, subkey = jax.random.split(key)
|
|
60
74
|
obs, env_state = self.craftax_env.reset(subkey, self.env_params)
|
|
61
75
|
state = Container().update(key=key, env_state=env_state)
|
|
62
76
|
info = InfoContainer(obs=obs, reward=0.0, terminated=False)
|
|
77
|
+
info = info.update(info=self._craftax_info_placeholder)
|
|
63
78
|
return state, info
|
|
64
79
|
|
|
65
80
|
@override
|
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
from functools import cached_property
|
|
2
|
-
from typing import Any, override
|
|
2
|
+
from typing import Any, Callable, cast, override
|
|
3
3
|
|
|
4
4
|
import jax
|
|
5
5
|
import jax.numpy as jnp
|
|
@@ -10,15 +10,24 @@ from gymnax.environments.environment import EnvParams as GymnaxEnvParams
|
|
|
10
10
|
|
|
11
11
|
from envelope import spaces as envelope_spaces
|
|
12
12
|
from envelope.environment import Environment, Info, InfoContainer, State
|
|
13
|
-
from envelope.struct import Container, static_field
|
|
13
|
+
from envelope.struct import Container, field, static_field
|
|
14
14
|
from envelope.typing import Key, PyTree
|
|
15
15
|
|
|
16
|
+
_GymnaxReset = Callable[
|
|
17
|
+
[Key, GymnaxEnvParams],
|
|
18
|
+
tuple[PyTree, Any],
|
|
19
|
+
]
|
|
20
|
+
_GymnaxStep = Callable[
|
|
21
|
+
[Key, Any, PyTree, GymnaxEnvParams],
|
|
22
|
+
tuple[PyTree, Any, jnp.ndarray, jnp.ndarray, PyTree],
|
|
23
|
+
]
|
|
24
|
+
|
|
16
25
|
|
|
17
26
|
class GymnaxEnvelope(Environment):
|
|
18
27
|
"""Wrapper to convert a Gymnax environment to an envelope environment."""
|
|
19
28
|
|
|
20
29
|
gymnax_env: GymnaxEnv = static_field()
|
|
21
|
-
env_params: PyTree
|
|
30
|
+
env_params: PyTree = field()
|
|
22
31
|
|
|
23
32
|
@classmethod
|
|
24
33
|
def from_name(
|
|
@@ -43,19 +52,37 @@ class GymnaxEnvelope(Environment):
|
|
|
43
52
|
def default_max_steps(self) -> int:
|
|
44
53
|
return int(self.gymnax_env.default_params.max_steps_in_episode)
|
|
45
54
|
|
|
55
|
+
@cached_property
|
|
56
|
+
def _gymnax_info_placeholder(self) -> PyTree:
|
|
57
|
+
reset_fn = cast(_GymnaxReset, self.gymnax_env.reset)
|
|
58
|
+
step_fn = cast(_GymnaxStep, self.gymnax_env.step)
|
|
59
|
+
|
|
60
|
+
key = jax.random.PRNGKey(0)
|
|
61
|
+
_, state = reset_fn(key, self.env_params)
|
|
62
|
+
_, _, _, _, info = step_fn(
|
|
63
|
+
key,
|
|
64
|
+
state,
|
|
65
|
+
self.gymnax_env.action_space(self.env_params).sample(key),
|
|
66
|
+
self.env_params,
|
|
67
|
+
)
|
|
68
|
+
return jax.tree.map(lambda x: jnp.full_like(x, jnp.nan, dtype=float), info)
|
|
69
|
+
|
|
46
70
|
@override
|
|
47
|
-
def
|
|
71
|
+
def init(self, key: Key) -> tuple[State, Info]:
|
|
72
|
+
reset_fn = cast(_GymnaxReset, self.gymnax_env.reset)
|
|
73
|
+
|
|
48
74
|
key, subkey = jax.random.split(key)
|
|
49
|
-
obs, env_state =
|
|
75
|
+
obs, env_state = reset_fn(subkey, self.env_params)
|
|
50
76
|
state = Container().update(key=key, env_state=env_state)
|
|
51
77
|
info = InfoContainer(obs=obs, reward=0.0, terminated=False)
|
|
52
|
-
info = info.update(info=
|
|
78
|
+
info = info.update(info=self._gymnax_info_placeholder)
|
|
53
79
|
return state, info
|
|
54
80
|
|
|
55
81
|
@override
|
|
56
82
|
def step(self, state: State, action: PyTree) -> tuple[State, Info]:
|
|
57
83
|
key, subkey = jax.random.split(state.key)
|
|
58
|
-
|
|
84
|
+
step_fn = cast(_GymnaxStep, self.gymnax_env.step)
|
|
85
|
+
obs, env_state, reward, done, env_info = step_fn(
|
|
59
86
|
subkey, state.env_state, action, self.env_params
|
|
60
87
|
)
|
|
61
88
|
state = state.update(key=key, env_state=env_state)
|
|
@@ -48,7 +48,7 @@ class JumanjiEnvelope(Environment):
|
|
|
48
48
|
return self._default_time_limit
|
|
49
49
|
|
|
50
50
|
@override
|
|
51
|
-
def
|
|
51
|
+
def init(self, key: Key) -> tuple[State, Info]:
|
|
52
52
|
env_state, timestep = self.jumanji_env.reset(key)
|
|
53
53
|
info = convert_jumanji_to_envelope_info(timestep)
|
|
54
54
|
return env_state, info
|
|
@@ -81,8 +81,9 @@ class JumanjiEnvelope(Environment):
|
|
|
81
81
|
|
|
82
82
|
|
|
83
83
|
def convert_jumanji_to_envelope_info(timestep: JumanjiTimeStep) -> InfoContainer:
|
|
84
|
+
term = jnp.asarray(timestep.last(), dtype=bool).item()
|
|
84
85
|
info = InfoContainer(
|
|
85
|
-
obs=timestep.observation, reward=timestep.reward, terminated=
|
|
86
|
+
obs=timestep.observation, reward=timestep.reward, terminated=term
|
|
86
87
|
).update(**timestep.extras)
|
|
87
88
|
return info
|
|
88
89
|
|
|
@@ -28,6 +28,7 @@ from kinetix.environment import (
|
|
|
28
28
|
from kinetix.environment.ued.ued import make_reset_fn_sample_kinetix_level
|
|
29
29
|
from kinetix.util.saving import load_from_json_file
|
|
30
30
|
|
|
31
|
+
from envelope import field
|
|
31
32
|
from envelope import spaces as envelope_spaces
|
|
32
33
|
from envelope.compat.gymnax_envelope import _convert_space as _convert_gymnax_space
|
|
33
34
|
from envelope.environment import Environment, Info, InfoContainer, State
|
|
@@ -67,7 +68,7 @@ class KinetixEnvelope(Environment):
|
|
|
67
68
|
"""Wrapper to convert a Kinetix environment to an envelope environment."""
|
|
68
69
|
|
|
69
70
|
kinetix_env: Any = static_field()
|
|
70
|
-
env_params: Any
|
|
71
|
+
env_params: Any = field()
|
|
71
72
|
|
|
72
73
|
@property
|
|
73
74
|
def default_max_steps(self) -> int:
|
|
@@ -162,7 +163,7 @@ class KinetixEnvelope(Environment):
|
|
|
162
163
|
return cls(kinetix_env=kinetix_env, env_params=env_params)
|
|
163
164
|
|
|
164
165
|
@override
|
|
165
|
-
def
|
|
166
|
+
def init(self, key: Key) -> tuple[State, Info]:
|
|
166
167
|
key, subkey = jax.random.split(key)
|
|
167
168
|
obs, env_state = self.kinetix_env.reset(subkey, self.env_params)
|
|
168
169
|
state_out = Container().update(key=key, env_state=env_state)
|
|
@@ -56,7 +56,7 @@ class MujocoPlaygroundEnvelope(Environment):
|
|
|
56
56
|
return self._default_max_steps
|
|
57
57
|
|
|
58
58
|
@override
|
|
59
|
-
def
|
|
59
|
+
def init(self, key: Key) -> tuple[State, Info]:
|
|
60
60
|
env_state = self.mujoco_playground_env.reset(key)
|
|
61
61
|
info = InfoContainer(obs=env_state.obs, reward=0.0, terminated=False)
|
|
62
62
|
info = info.update(**dataclasses.asdict(env_state))
|
|
@@ -38,7 +38,7 @@ class NavixEnvelope(Environment):
|
|
|
38
38
|
return _NAVIX_DEFAULT_MAX_STEPS
|
|
39
39
|
|
|
40
40
|
@override
|
|
41
|
-
def
|
|
41
|
+
def init(self, key: Key) -> tuple[State, Info]:
|
|
42
42
|
timestep = self.navix_env.reset(key)
|
|
43
43
|
return timestep, convert_navix_to_envelope_info(timestep)
|
|
44
44
|
|
envelope/environment.py
CHANGED
|
@@ -5,7 +5,7 @@ from typing import Protocol, runtime_checkable
|
|
|
5
5
|
|
|
6
6
|
from envelope import spaces
|
|
7
7
|
from envelope.struct import Container, FrozenPyTreeNode
|
|
8
|
-
from envelope.typing import Key, PyTree
|
|
8
|
+
from envelope.typing import Array, Key, PyTree
|
|
9
9
|
|
|
10
10
|
__all__ = ["Environment", "State", "Info", "InfoContainer"]
|
|
11
11
|
|
|
@@ -23,7 +23,7 @@ class Info(Protocol):
|
|
|
23
23
|
|
|
24
24
|
class InfoContainer(Container):
|
|
25
25
|
obs: PyTree
|
|
26
|
-
reward: float
|
|
26
|
+
reward: float | Array
|
|
27
27
|
terminated: bool
|
|
28
28
|
truncated: bool = field(default=False)
|
|
29
29
|
|
|
@@ -38,18 +38,25 @@ class Environment(ABC, FrozenPyTreeNode):
|
|
|
38
38
|
|
|
39
39
|
State is an opaque PyTree owned by each environment; wrappers that stack
|
|
40
40
|
environments should expose their wrapped env state as `inner_state` while
|
|
41
|
-
adding any wrapper-specific fields.
|
|
42
|
-
|
|
43
|
-
|
|
41
|
+
adding any wrapper-specific fields.
|
|
42
|
+
|
|
43
|
+
Two distinct lifecycle methods:
|
|
44
|
+
init(key) - Initialize environment and all state from scratch.
|
|
45
|
+
reset(key, state) - Reset the inner environment while preserving
|
|
46
|
+
episode-persistent state.
|
|
44
47
|
"""
|
|
45
48
|
|
|
46
49
|
@abstractmethod
|
|
47
|
-
def
|
|
48
|
-
|
|
49
|
-
|
|
50
|
+
def init(self, key: Key) -> tuple[State, Info]:
|
|
51
|
+
"""Initialize environment and all state from scratch."""
|
|
52
|
+
...
|
|
53
|
+
|
|
54
|
+
def reset(self, key: Key, state: State) -> tuple[State, Info]:
|
|
55
|
+
"""Reset the inner environment while preserving episode-persistent state."""
|
|
56
|
+
return self.init(key)
|
|
50
57
|
|
|
51
58
|
@abstractmethod
|
|
52
|
-
def step(self, state: State, action: PyTree
|
|
59
|
+
def step(self, state: State, action: PyTree) -> tuple[State, Info]: ...
|
|
53
60
|
|
|
54
61
|
@abstractmethod
|
|
55
62
|
@cached_property
|
envelope/spaces.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
from abc import ABC, abstractmethod
|
|
2
2
|
from functools import cached_property
|
|
3
|
-
from typing import override
|
|
3
|
+
from typing import cast, override
|
|
4
4
|
|
|
5
5
|
import jax
|
|
6
6
|
from jax import numpy as jnp
|
|
@@ -65,7 +65,9 @@ class Continuous(Space):
|
|
|
65
65
|
high: float | jax.Array
|
|
66
66
|
|
|
67
67
|
@classmethod
|
|
68
|
-
def from_shape(
|
|
68
|
+
def from_shape(
|
|
69
|
+
cls, low: float, high: float, shape: tuple[int, ...]
|
|
70
|
+
) -> "Continuous":
|
|
69
71
|
return cls(
|
|
70
72
|
low=jnp.full(shape, low, dtype=jnp.asarray(low).dtype),
|
|
71
73
|
high=jnp.full(shape, high, dtype=jnp.asarray(high).dtype),
|
|
@@ -106,17 +108,25 @@ class PyTreeSpace(Space):
|
|
|
106
108
|
"""A Space defined by a PyTree structure of other Spaces.
|
|
107
109
|
|
|
108
110
|
Args:
|
|
109
|
-
tree: A PyTree with
|
|
111
|
+
tree: A PyTree with Discrete or Continuous leaves.
|
|
110
112
|
|
|
111
113
|
Usage:
|
|
112
114
|
space = PyTreeSpace({
|
|
113
|
-
"action": Discrete(n=4
|
|
114
|
-
"obs": Continuous(low=0.0, high=1.0, shape=(2,)
|
|
115
|
+
"action": Discrete(n=4),
|
|
116
|
+
"obs": Continuous(low=0.0, high=1.0, shape=(2,))
|
|
115
117
|
})
|
|
116
118
|
"""
|
|
117
119
|
|
|
118
120
|
tree: PyTree
|
|
119
121
|
|
|
122
|
+
def __post_init__(self):
|
|
123
|
+
leaves = jax.tree.leaves(self.tree, is_leaf=lambda x: isinstance(x, Space))
|
|
124
|
+
for leaf in leaves:
|
|
125
|
+
if not isinstance(leaf, (Discrete, Continuous)):
|
|
126
|
+
raise TypeError(
|
|
127
|
+
f"PyTreeSpace leaves must be Discrete or Continuous, got {type(leaf).__name__}"
|
|
128
|
+
)
|
|
129
|
+
|
|
120
130
|
@override
|
|
121
131
|
def sample(self, key: Key) -> PyTree:
|
|
122
132
|
leaves, treedef = jax.tree.flatten(
|
|
@@ -149,16 +159,23 @@ class PyTreeSpace(Space):
|
|
|
149
159
|
is_leaf=lambda node: isinstance(node, Space),
|
|
150
160
|
)
|
|
151
161
|
|
|
152
|
-
|
|
153
|
-
def
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
space.tree,
|
|
162
|
+
@property
|
|
163
|
+
def dtype(self) -> PyTree:
|
|
164
|
+
return jax.tree.map(
|
|
165
|
+
lambda space: space.dtype,
|
|
166
|
+
self.tree,
|
|
158
167
|
is_leaf=lambda node: isinstance(node, Space),
|
|
159
168
|
)
|
|
160
|
-
|
|
161
|
-
|
|
169
|
+
|
|
170
|
+
|
|
171
|
+
def peel_batched(space: Space) -> tuple[tuple[int, ...], Space]:
|
|
172
|
+
"""Collect batch dimensions and return (batch_dims_tuple, base_space)."""
|
|
173
|
+
dims: list[int] = []
|
|
174
|
+
s: Space = space
|
|
175
|
+
while isinstance(s, BatchedSpace):
|
|
176
|
+
dims.append(s.batch_size)
|
|
177
|
+
s = s.space
|
|
178
|
+
return tuple(dims), s
|
|
162
179
|
|
|
163
180
|
|
|
164
181
|
class BatchedSpace(Space):
|
|
@@ -190,16 +207,19 @@ class BatchedSpace(Space):
|
|
|
190
207
|
|
|
191
208
|
@cached_property
|
|
192
209
|
def shape(self) -> PyTree:
|
|
193
|
-
|
|
194
|
-
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
|
|
210
|
+
batch_dims, base = peel_batched(self)
|
|
211
|
+
if isinstance(base, PyTreeSpace):
|
|
212
|
+
return jax.tree.map(
|
|
213
|
+
lambda space: batch_dims + space.shape,
|
|
214
|
+
base.tree,
|
|
215
|
+
is_leaf=lambda node: isinstance(node, Space),
|
|
216
|
+
)
|
|
217
|
+
return batch_dims + base.shape
|
|
199
218
|
|
|
200
219
|
@property
|
|
201
|
-
def dtype(self):
|
|
202
|
-
|
|
220
|
+
def dtype(self) -> PyTree:
|
|
221
|
+
_, base = peel_batched(self)
|
|
222
|
+
return base.dtype
|
|
203
223
|
|
|
204
224
|
def __repr__(self) -> str:
|
|
205
225
|
return f"BatchedSpace(space={self.space!r}, batch_size={self.batch_size})"
|
envelope/struct.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
import dataclasses
|
|
2
2
|
from dataclasses import KW_ONLY
|
|
3
|
-
from typing import Any, Iterable, Iterator, Mapping, Self, Tuple
|
|
3
|
+
from typing import Any, Iterable, Iterator, Mapping, Self, Tuple, dataclass_transform
|
|
4
4
|
|
|
5
5
|
import jax
|
|
6
6
|
|
|
@@ -24,6 +24,7 @@ def static_field(**kwargs):
|
|
|
24
24
|
return field(pytree_node=False, **kwargs)
|
|
25
25
|
|
|
26
26
|
|
|
27
|
+
@dataclass_transform()
|
|
27
28
|
class FrozenPyTreeNode:
|
|
28
29
|
"""
|
|
29
30
|
Frozen dataclass base that is a JAX pytree node.
|
|
@@ -64,6 +65,7 @@ class FrozenPyTreeNode:
|
|
|
64
65
|
return dataclasses.replace(self, **changes)
|
|
65
66
|
|
|
66
67
|
|
|
68
|
+
@dataclass_transform()
|
|
67
69
|
@jax.tree_util.register_pytree_node_class
|
|
68
70
|
@dataclasses.dataclass(frozen=True, eq=True, repr=True, slots=False)
|
|
69
71
|
class Container:
|
|
@@ -104,6 +106,13 @@ class Container:
|
|
|
104
106
|
for k, v in self._extras.items():
|
|
105
107
|
yield (k, v)
|
|
106
108
|
|
|
109
|
+
def __str__(self) -> str:
|
|
110
|
+
core_str = super().__str__()
|
|
111
|
+
if not self._extras:
|
|
112
|
+
return core_str
|
|
113
|
+
extras_str = f", {', '.join(f'{k}={v!r}' for k, v in self._extras.items())}"
|
|
114
|
+
return f"{core_str[:-1]}{extras_str})" # remove closing parenthesis from core
|
|
115
|
+
|
|
107
116
|
def update(self, **changes: PyTree) -> Self:
|
|
108
117
|
core_names = {f.name for f in dataclasses.fields(self) if f.name != "_extras"}
|
|
109
118
|
core_updates: dict[str, PyTree] = {}
|
envelope/typing.py
CHANGED
|
@@ -1,4 +1,3 @@
|
|
|
1
|
-
from enum import Enum
|
|
2
1
|
from typing import Any, TypeAlias
|
|
3
2
|
|
|
4
3
|
import jax
|
|
@@ -6,18 +5,3 @@ import jax
|
|
|
6
5
|
PyTree: TypeAlias = Any
|
|
7
6
|
Key: TypeAlias = jax.Array
|
|
8
7
|
Array: TypeAlias = jax.Array
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
class BatchKind(Enum):
|
|
12
|
-
"""
|
|
13
|
-
Batch semantics for environments.
|
|
14
|
-
- VMAP: environment represents instances compatible with `jax.vmap`, for example by
|
|
15
|
-
wrapping it in a `VmapWrapper`.
|
|
16
|
-
- NATIVE_POOL: environment represents a batch of instances via a native pool. This
|
|
17
|
-
is the case when it is wrapping a non-jax-based environment that supports native
|
|
18
|
-
batching, for example those provided by envpool. Environments in this mode cannot
|
|
19
|
-
be vmapped, as it would break the native batching semantics.
|
|
20
|
-
"""
|
|
21
|
-
|
|
22
|
-
VMAP = "vmap"
|
|
23
|
-
NATIVE_POOL = "native_pool"
|
|
@@ -0,0 +1,36 @@
|
|
|
1
|
+
from envelope.wrappers.autoreset_wrapper import AutoResetWrapper
|
|
2
|
+
from envelope.wrappers.clip_action_wrapper import ClipActionWrapper
|
|
3
|
+
from envelope.wrappers.continuous_observation_wrapper import (
|
|
4
|
+
ContinuousObservationWrapper,
|
|
5
|
+
)
|
|
6
|
+
from envelope.wrappers.episode_statistics_wrapper import EpisodeStatisticsWrapper
|
|
7
|
+
from envelope.wrappers.flatten_action_wrapper import FlattenActionWrapper
|
|
8
|
+
from envelope.wrappers.flatten_observation_wrapper import FlattenObservationWrapper
|
|
9
|
+
from envelope.wrappers.observation_normalization_wrapper import (
|
|
10
|
+
ObservationNormalizationWrapper,
|
|
11
|
+
)
|
|
12
|
+
from envelope.wrappers.pooled_init_vmap_wrapper import PooledInitVmapWrapper
|
|
13
|
+
from envelope.wrappers.state_injection_wrapper import StateInjectionWrapper
|
|
14
|
+
from envelope.wrappers.truncation_wrapper import TruncationWrapper
|
|
15
|
+
from envelope.wrappers.vmap_envs_wrapper import VmapEnvsWrapper
|
|
16
|
+
from envelope.wrappers.vmap_wrapper import VmapWrapper
|
|
17
|
+
from envelope.wrappers.wrapper import WrappedState, Wrapper
|
|
18
|
+
|
|
19
|
+
__all__ = [
|
|
20
|
+
# Basic functionality
|
|
21
|
+
"Wrapper",
|
|
22
|
+
"WrappedState",
|
|
23
|
+
# Wrappers
|
|
24
|
+
"AutoResetWrapper",
|
|
25
|
+
"ClipActionWrapper",
|
|
26
|
+
"ContinuousObservationWrapper",
|
|
27
|
+
"EpisodeStatisticsWrapper",
|
|
28
|
+
"FlattenActionWrapper",
|
|
29
|
+
"FlattenObservationWrapper",
|
|
30
|
+
"ObservationNormalizationWrapper",
|
|
31
|
+
"PooledInitVmapWrapper",
|
|
32
|
+
"StateInjectionWrapper",
|
|
33
|
+
"TruncationWrapper",
|
|
34
|
+
"VmapWrapper",
|
|
35
|
+
"VmapEnvsWrapper",
|
|
36
|
+
]
|
|
@@ -1,4 +1,7 @@
|
|
|
1
|
+
from typing import override
|
|
2
|
+
|
|
1
3
|
import jax
|
|
4
|
+
import jax.numpy as jnp
|
|
2
5
|
|
|
3
6
|
from envelope.environment import Info
|
|
4
7
|
from envelope.struct import field
|
|
@@ -7,30 +10,71 @@ from envelope.wrappers.wrapper import WrappedState, Wrapper
|
|
|
7
10
|
|
|
8
11
|
|
|
9
12
|
class AutoResetWrapper(Wrapper):
|
|
13
|
+
"""Wrapper that automatically resets the environment when an episode ends.
|
|
14
|
+
|
|
15
|
+
When a step results in termination or truncation, this wrapper immediately
|
|
16
|
+
resets the environment. The returned info preserves critical information
|
|
17
|
+
from the terminal step while providing the new episode's initial observation.
|
|
18
|
+
|
|
19
|
+
Info fields after a terminal step (terminated=True or truncated=True):
|
|
20
|
+
obs: Initial observation from the new episode (after reset).
|
|
21
|
+
final: Full info snapshot from the terminal step (before reset).
|
|
22
|
+
terminated: True if the episode ended due to termination.
|
|
23
|
+
truncated: True if the episode ended due to truncation.
|
|
24
|
+
reward: Reward from the terminal step.
|
|
25
|
+
|
|
26
|
+
Info fields during normal steps (terminated=False and truncated=False):
|
|
27
|
+
obs: Current observation.
|
|
28
|
+
final: Info snapshot from the last completed episode (persisted).
|
|
29
|
+
terminated: False.
|
|
30
|
+
truncated: False.
|
|
31
|
+
reward: Reward from the step.
|
|
32
|
+
|
|
33
|
+
This design enables correct value bootstrapping:
|
|
34
|
+
- Use final.obs for value estimation of the true next state
|
|
35
|
+
- On termination: V(s_final) = 0 (episode truly ended)
|
|
36
|
+
- On truncation: bootstrap from V(final.obs) (episode cut off artificially)
|
|
37
|
+
- final persists until the next episode completes, giving easy access
|
|
38
|
+
to last episode's aggregated stats (e.g., final.episode_return)
|
|
39
|
+
"""
|
|
40
|
+
|
|
10
41
|
class AutoResetState(WrappedState):
|
|
11
42
|
reset_key: jax.Array = field()
|
|
43
|
+
last_final: Info = field()
|
|
12
44
|
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
) -> tuple[WrappedState, Info]:
|
|
45
|
+
@override
|
|
46
|
+
def init(self, key: Key) -> tuple[WrappedState, Info]:
|
|
16
47
|
key, subkey = jax.random.split(key)
|
|
17
|
-
inner_state =
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
48
|
+
inner_state, info = self.env.init(key)
|
|
49
|
+
# Initialize last_final with a placeholder that mirrors the structure of the init info (no previous episode yet)
|
|
50
|
+
last_final = jax.tree.map(lambda x: jnp.full_like(x, jnp.nan), info)
|
|
51
|
+
state = self.AutoResetState(
|
|
52
|
+
inner_state=inner_state, reset_key=subkey, last_final=last_final
|
|
53
|
+
)
|
|
54
|
+
return state, info.update(final=state.last_final)
|
|
55
|
+
|
|
56
|
+
@override
|
|
57
|
+
def reset(self, key: Key, state: WrappedState) -> tuple[WrappedState, Info]:
|
|
58
|
+
raise NotImplementedError("Reset is not implemented for AutoResetWrapper")
|
|
59
|
+
|
|
60
|
+
@override
|
|
61
|
+
def step(self, state: WrappedState, action: PyTree) -> tuple[WrappedState, Info]:
|
|
62
|
+
key, key_reset = jax.random.split(state.reset_key)
|
|
63
|
+
state = state.replace(reset_key=key)
|
|
64
|
+
|
|
65
|
+
inner_state, info = self.env.step(state.inner_state, action)
|
|
66
|
+
reset_inner_state, reset_info = self.env.reset(key_reset, inner_state)
|
|
67
|
+
|
|
68
|
+
# Select next state and info based on done
|
|
69
|
+
done = info.terminated | info.truncated
|
|
70
|
+
state = jax.tree.map(
|
|
71
|
+
lambda reset, next: jax.lax.select(done, reset, next),
|
|
72
|
+
state.replace(inner_state=reset_inner_state),
|
|
73
|
+
state.replace(inner_state=inner_state),
|
|
74
|
+
)
|
|
75
|
+
info = jax.tree.map(
|
|
76
|
+
lambda reset, next: jax.lax.select(done, reset, next),
|
|
77
|
+
reset_info.update(final=info),
|
|
78
|
+
info.update(final=state.last_final),
|
|
35
79
|
)
|
|
36
80
|
return state, info
|
|
@@ -0,0 +1,27 @@
|
|
|
1
|
+
from typing import override
|
|
2
|
+
|
|
3
|
+
import jax
|
|
4
|
+
import jax.numpy as jnp
|
|
5
|
+
|
|
6
|
+
from envelope.environment import Info, State
|
|
7
|
+
from envelope.spaces import BatchedSpace, Continuous, Discrete, PyTreeSpace, Space
|
|
8
|
+
from envelope.typing import PyTree
|
|
9
|
+
from envelope.wrappers.wrapper import Wrapper
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def clip_action(action: PyTree, space: Space) -> PyTree:
|
|
13
|
+
if isinstance(space, BatchedSpace):
|
|
14
|
+
return jax.vmap(clip_action, in_axes=(0, None))(action, space.space)
|
|
15
|
+
elif isinstance(space, PyTreeSpace):
|
|
16
|
+
return jax.tree.map(clip_action, action, space.tree)
|
|
17
|
+
elif isinstance(space, Continuous):
|
|
18
|
+
return jnp.clip(action, space.low, space.high)
|
|
19
|
+
elif isinstance(space, Discrete):
|
|
20
|
+
return jnp.clip(action, 0, space.n - 1)
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class ClipActionWrapper(Wrapper):
|
|
24
|
+
@override
|
|
25
|
+
def step(self, state: State, action: PyTree) -> tuple[State, Info]:
|
|
26
|
+
action = clip_action(action, self.action_space)
|
|
27
|
+
return self.env.step(state, action)
|