jax-envelope 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- envelope/__init__.py +0 -0
- envelope/compat/__init__.py +97 -0
- envelope/compat/brax_envelope.py +98 -0
- envelope/compat/craftax_envelope.py +86 -0
- envelope/compat/gymnax_envelope.py +91 -0
- envelope/compat/jumanji_envelope.py +127 -0
- envelope/compat/kinetix_envelope.py +194 -0
- envelope/compat/mujoco_playground_envelope.py +101 -0
- envelope/compat/navix_envelope.py +86 -0
- envelope/environment.py +64 -0
- envelope/spaces.py +205 -0
- envelope/struct.py +148 -0
- envelope/typing.py +23 -0
- envelope/wrappers/autoreset_wrapper.py +36 -0
- envelope/wrappers/episode_statistics_wrapper.py +47 -0
- envelope/wrappers/normalization.py +56 -0
- envelope/wrappers/observation_normalization_wrapper.py +114 -0
- envelope/wrappers/state_injection_wrapper.py +91 -0
- envelope/wrappers/timestep_wrapper.py +22 -0
- envelope/wrappers/truncation_wrapper.py +31 -0
- envelope/wrappers/vmap_envs_wrapper.py +77 -0
- envelope/wrappers/vmap_wrapper.py +51 -0
- envelope/wrappers/wrapper.py +57 -0
- jax_envelope-0.1.0.dist-info/METADATA +87 -0
- jax_envelope-0.1.0.dist-info/RECORD +27 -0
- jax_envelope-0.1.0.dist-info/WHEEL +4 -0
- jax_envelope-0.1.0.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,47 @@
|
|
|
1
|
+
from dataclasses import field
from typing import override

from envelope.environment import Environment, Info, State
from envelope.typing import Key, PyTree, Array
from envelope.wrappers import Wrapper
from envelope.wrappers.wrapper import WrappedState
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class EpisodeStatisticsWrapper(Wrapper):
    """Wrapper that annotates the info returned by ``step`` with episode statistics.

    ``reset`` forwards to the wrapped environment unchanged. ``step`` forwards
    to the wrapped environment and then passes the returned info through
    ``_update_episode_statistics``.
    """

    class StatisticsState(WrappedState):
        # Declared accumulators for the running episode. NOTE(review): neither
        # ``reset`` nor ``step`` below reads or writes this state yet.
        episode_reward: Array
        episode_length: Array
        _pointer: int = field(default=0)

    @override
    def reset(
        self, key: Key, state: State | None = None, **kwargs
    ) -> tuple[State, Info]:
        """Reset the wrapped environment and return its state and info as-is.

        Args:
            key: Random key forwarded to the wrapped environment.
            state: Optional state forwarded to the wrapped environment.
            **kwargs: Extra keyword arguments forwarded unchanged.

        Returns:
            The ``(state, info)`` pair produced by the wrapped environment.
        """
        # The original carried a dangling ``info =`` statement (a syntax error)
        # and a second, identical definition of ``reset``; both are removed.
        state, info = self.env.reset(key, state=state, **kwargs)
        return state, info

    @override
    def step(self, state: State, action: PyTree, **kwargs) -> tuple[State, Info]:
        """Step the wrapped environment and add episode statistics to the info.

        Args:
            state: State forwarded to the wrapped environment.
            action: Action forwarded to the wrapped environment.
            **kwargs: Extra keyword arguments forwarded unchanged.

        Returns:
            The next state and the info after ``_update_episode_statistics``.
        """
        next_state, info = self.env.step(state, action, **kwargs)
        info = self._update_episode_statistics(info)
        return next_state, info

    def _update_episode_statistics(self, info: Info) -> Info:
        """Update episode statistics in the info dictionary.

        Creates ``info["episode_statistics"]`` with zeroed ``reward`` and
        ``length`` entries when absent, then adds ``info.get("reward", 0.0)``
        to the reward entry and increments the length entry by one.

        NOTE(review): this treats ``info`` as a mutable mapping (``in``, item
        access, ``.get``) and mutates it in place; confirm that ``Info``
        supports that protocol.
        """
        if "episode_statistics" not in info:
            info["episode_statistics"] = {
                "reward": 0.0,
                "length": 0,
            }
        info["episode_statistics"]["reward"] += info.get("reward", 0.0)
        info["episode_statistics"]["length"] += 1
        return info
|
|
47
|
+
|
|
@@ -0,0 +1,56 @@
|
|
|
1
|
+
from functools import cached_property
|
|
2
|
+
from typing import NamedTuple
|
|
3
|
+
|
|
4
|
+
import jax
|
|
5
|
+
from jax import numpy as jnp
|
|
6
|
+
|
|
7
|
+
from envelope.struct import FrozenPyTreeNode
|
|
8
|
+
from envelope.typing import Array, PyTree
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class MeanVarPair(NamedTuple):
    """Pair of an updated mean and its matching updated variance for one leaf."""

    # Updated mean of a single array leaf.
    mean: Array
    # Updated variance of the same array leaf.
    var: Array
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class RunningMeanVar(FrozenPyTreeNode):
    """Running mean and variance PyTrees together with a sample count."""

    mean: PyTree
    var: PyTree
    count: int

    @cached_property
    def std(self) -> PyTree:
        """Return the square root of every leaf of ``var``."""
        variance = self.var
        return jax.tree.map(jnp.sqrt, variance)
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def update_rmv(rmv_state: RunningMeanVar, x: PyTree) -> RunningMeanVar:
    """Fold a new batch of observations into running mean/variance statistics.

    Uses the parallel (pairwise) variance-combination formula, applied
    independently to every leaf.

    Args:
        rmv_state: Current running statistics. ``mean`` and ``var`` must have
            the same tree structure as ``x``.
        x: PyTree of arrays, each with a leading batch dimension (aligned
            sizes). The batch size is read from the first leaf.

    Returns:
        A new ``RunningMeanVar`` with updated ``mean``, ``var`` and ``count``.
        If the batch is empty, ``rmv_state`` is returned unchanged.
    """
    batch_count = jax.tree.leaves(x)[0].shape[0]
    if batch_count == 0:
        # An empty batch carries no information; reducing over it would
        # produce NaN statistics. Shapes are static, so this is a Python check.
        return rmv_state

    global_count = rmv_state.count
    tot_count = global_count + batch_count
    # Form the weight batch_count / tot_count as a float up front. Multiplying
    # the two integer counts first (as in count * batch / total) can overflow
    # when ``count`` is a fixed-width integer array, e.g. under jit.
    batch_weight = batch_count / tot_count

    def _update_arr(mean: Array, var: Array, x_arr: Array) -> MeanVarPair:
        batch_mean = x_arr.mean(axis=0)
        batch_var = x_arr.var(axis=0)

        # Combine variances using parallel algorithm
        m_a = var * global_count
        m_b = batch_var * batch_count
        delta = batch_mean - mean
        m2 = m_a + m_b + (delta**2) * global_count * batch_weight

        new_mean = mean + delta * batch_weight
        new_var = m2 / tot_count
        return MeanVarPair(mean=new_mean, var=new_var)

    def is_pair(z):
        return isinstance(z, MeanVarPair)

    # jax.tree.map returns a PyTree whose leaves are MeanVarPairs
    mean_var_pairs = jax.tree.map(_update_arr, rmv_state.mean, rmv_state.var, x)
    new_mean = jax.tree.map(lambda mv: mv.mean, mean_var_pairs, is_leaf=is_pair)
    new_var = jax.tree.map(lambda mv: mv.var, mean_var_pairs, is_leaf=is_pair)
    return RunningMeanVar(mean=new_mean, var=new_var, count=tot_count)
|
|
@@ -0,0 +1,114 @@
|
|
|
1
|
+
from typing import override
|
|
2
|
+
|
|
3
|
+
import jax
|
|
4
|
+
from jax import numpy as jnp
|
|
5
|
+
|
|
6
|
+
from envelope.environment import Info
|
|
7
|
+
from envelope.spaces import BatchedSpace, PyTreeSpace, Space
|
|
8
|
+
from envelope.struct import field, static_field
|
|
9
|
+
from envelope.typing import Key, PyTree
|
|
10
|
+
from envelope.wrappers.normalization import RunningMeanVar, update_rmv
|
|
11
|
+
from envelope.wrappers.wrapper import WrappedState, Wrapper
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class ObservationNormalizationWrapper(Wrapper):
    """Normalizes observations using running mean/variance carried in the state.

    On every ``reset`` and ``step`` the running statistics are first updated
    with the new observation and the observation is then normalized with the
    updated statistics. The raw observation is kept in the info under
    ``unnormalized_obs``.
    """

    class ObservationNormalizationState(WrappedState):
        rmv_state: RunningMeanVar = field()

    stats_spec: PyTree | None = static_field(default=None)
    """Per-leaf normalization statistics spec as a pytree of jax.ShapeDtypeStruct.
    Shapes must be broadcastable to the observation leaves. If None, inferred from
    the observation_space with BatchedSpace ignored; each leaf must have a floating
    dtype."""

    def __post_init__(self):
        # Fill in the default spec lazily; object.__setattr__ is used to
        # assign the field directly on the instance.
        if self.stats_spec is None:
            stats_spec = _infer_stats_spec(self.env.observation_space)
            object.__setattr__(self, "stats_spec", stats_spec)

    def _init_rmv_state(self) -> RunningMeanVar:
        """Return fresh statistics: zero mean, unit variance, zero count."""

        def zeros(sd: jax.ShapeDtypeStruct) -> jax.Array:
            return jnp.zeros(sd.shape, dtype=sd.dtype)

        def ones(sd: jax.ShapeDtypeStruct) -> jax.Array:
            return jnp.ones(sd.shape, dtype=sd.dtype)

        mean = jax.tree.map(zeros, self.stats_spec)
        var = jax.tree.map(ones, self.stats_spec)

        return RunningMeanVar(mean=mean, var=var, count=0)

    def _normalize_obs(self, obs: PyTree, rmv: RunningMeanVar) -> PyTree:
        """Return ``(obs - mean) / (std + 1e-8)`` per leaf, cast to the spec dtype."""

        def norm_leaf(x, mean, std, spec):
            mean = jnp.broadcast_to(mean, x.shape)
            std = jnp.broadcast_to(std, x.shape)
            # The small epsilon keeps the division finite when std is zero.
            normalized = (x - mean) / (std + 1e-8)
            return normalized.astype(spec.dtype)

        return jax.tree.map(norm_leaf, obs, rmv.mean, rmv.std, self.stats_spec)

    def _normalize_and_update(
        self, state: WrappedState, info: Info
    ) -> tuple[WrappedState, Info]:
        """Update the running statistics with ``info.obs`` and normalize it.

        Returns:
            A new ``ObservationNormalizationState`` holding the updated
            statistics, and the info with ``obs`` replaced by the normalized
            observation and the raw one stored as ``unnormalized_obs``.
        """
        # Ensure each observation leaf is shaped as (-1, *spec.shape)
        reshaped_obs = jax.tree.map(
            lambda x, spec: x.reshape((-1,) + tuple(spec.shape)),
            info.obs,
            self.stats_spec,
        )
        rmv_state = update_rmv(state.rmv_state, reshaped_obs)
        norm_obs = self._normalize_obs(info.obs, rmv_state)

        state = self.ObservationNormalizationState(
            inner_state=state.inner_state, rmv_state=rmv_state
        )
        info = info.update(obs=norm_obs, unnormalized_obs=info.obs)
        return state, info

    @override
    def reset(
        self, key: Key, state: PyTree | None = None, **kwargs
    ) -> tuple[WrappedState, Info]:
        """Reset the wrapped environment, keeping statistics from ``state`` if given.

        Args:
            key: Random key forwarded to the wrapped environment.
            state: Optional previous wrapper state. When provided, its inner
                state is forwarded to the wrapped environment and its running
                statistics are carried over instead of being re-initialized.
            **kwargs: Extra keyword arguments forwarded unchanged.
        """
        inner_state = None
        rmv_state = self._init_rmv_state()
        # Explicit None check: the truthiness of a state object is not a
        # reliable signal for "a state was passed".
        if state is not None:
            inner_state = state.inner_state
            rmv_state = state.rmv_state

        inner_state, info = self.env.reset(key, inner_state, **kwargs)
        next_state = self.ObservationNormalizationState(
            inner_state=inner_state, rmv_state=rmv_state
        )
        return self._normalize_and_update(next_state, info)

    @override
    def step(
        self, state: WrappedState, action: PyTree, **kwargs
    ) -> tuple[WrappedState, Info]:
        """Step the wrapped environment, then update statistics and normalize."""
        inner_state, info = self.env.step(state.inner_state, action, **kwargs)
        state = state.replace(inner_state=inner_state)
        return self._normalize_and_update(state, info)
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
def _infer_stats_spec(space: Space) -> PyTree:
    """Derive a PyTree of ``jax.ShapeDtypeStruct`` describing the statistics.

    ``BatchedSpace`` layers are peeled off, ``PyTreeSpace`` nodes are mapped
    leaf by leaf, and every remaining leaf space becomes a struct with the
    space's shape and dtype.

    Raises:
        ValueError: If a leaf space does not have a floating point dtype.
    """

    def is_space(node) -> bool:
        return isinstance(node, Space)

    def to_spec(sp: Space):
        # Strip any number of batching layers before inspecting the space.
        while isinstance(sp, BatchedSpace):
            sp = sp.space

        if isinstance(sp, PyTreeSpace):
            return jax.tree.map(to_spec, sp.tree, is_leaf=is_space)

        if jnp.issubdtype(sp.dtype, jnp.floating):
            return jax.ShapeDtypeStruct(tuple(sp.shape), sp.dtype)

        raise ValueError(
            f"Space {sp} has dtype {sp.dtype} which is not a floating point dtype"
        )

    return to_spec(space)
|
|
@@ -0,0 +1,91 @@
|
|
|
1
|
+
from envelope.environment import Info, InfoContainer
|
|
2
|
+
from envelope.struct import field
|
|
3
|
+
from envelope.typing import Key, PyTree
|
|
4
|
+
from envelope.wrappers.wrapper import WrappedState, Wrapper
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class StateInjectionWrapper(Wrapper):
    """Stores a state that all resets return to.

    For UED: use set_reset_state() to update the injected state, then all resets
    (including auto-reset) return to that state until it's changed again.

    Usage:
        env = AutoResetWrapper(StateInjectionWrapper(env=base_env))
        state, info = env.reset(key)

        for outer_iter in range(num_outer_iters):
            # Sample a new task and set it as the reset state
            task_state, task_obs = sample_task(key)
            state = env.set_reset_state(state, task_state, task_obs)

            # Run episode - auto-resets return to task_state
            for inner_step in range(num_inner_steps):
                state, info = env.step(state, policy(info.obs))
    """

    class InjectedState(WrappedState):
        reset_state: PyTree | None = field(default=None)
        reset_obs: PyTree | None = field(default=None)

    def set_reset_state(
        self, state: WrappedState, reset_state: PyTree, reset_obs: PyTree
    ) -> WrappedState:
        """Replace the stored reset target inside a (possibly nested) state.

        Walks down the ``inner_state`` chain until the ``InjectedState`` is
        found, swaps it for a new one, and rebuilds the outer layers around it.

        Args:
            state: Current state (can be from any outer wrapper)
            reset_state: The state to reset to (inner environment state)
            reset_obs: The observation to return on reset

        Returns:
            New state with updated reset fields at the appropriate level

        Raises:
            ValueError: If no ``InjectedState`` is found in the chain.
        """
        # Collect the outer layers that sit above the InjectedState.
        outer_layers = []
        node = state
        while not isinstance(node, self.InjectedState):
            if not hasattr(node, "inner_state"):
                raise ValueError("Could not find InjectedState in given state")
            outer_layers.append(node)
            node = node.inner_state

        rebuilt = self.InjectedState(
            inner_state=reset_state,
            reset_state=reset_state,
            reset_obs=reset_obs,
        )
        # Re-wrap from the innermost collected layer outwards.
        for layer in reversed(outer_layers):
            rebuilt = layer.replace(inner_state=rebuilt)
        return rebuilt

    def reset(
        self, key: Key, state: PyTree | None = None, **kwargs
    ) -> tuple[WrappedState, Info]:
        """Reset to the injected state if one is set, else reset the wrapped env.

        Raises:
            ValueError: If exactly one of ``reset_state`` / ``reset_obs`` is set.
        """
        # Without a state there is nothing injected and no inner state either.
        if state is None:
            state = self.InjectedState(inner_state=None)

        has_reset_state = state.reset_state is not None
        has_reset_obs = state.reset_obs is not None

        # The two fields only make sense together.
        if has_reset_state != has_reset_obs:
            raise ValueError("State must set both reset_state and reset_obs or neither")

        if has_reset_state:
            # Injected target: return it directly with a neutral info.
            inner_state = state.reset_state
            info = InfoContainer(obs=state.reset_obs, reward=0.0, terminated=False)
        else:
            # Nothing injected: defer to the wrapped environment.
            inner_state, info = self.env.reset(key, state=state.inner_state, **kwargs)

        return state.replace(inner_state=inner_state), info

    def step(
        self, state: WrappedState, action: PyTree, **kwargs
    ) -> tuple[WrappedState, Info]:
        """Step the wrapped environment on the inner state, keeping reset fields."""
        stepped_inner, info = self.env.step(state.inner_state, action, **kwargs)
        next_state = state.replace(inner_state=stepped_inner)
        return next_state, info
|
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
from envelope.environment import Info
|
|
2
|
+
from envelope.struct import field
|
|
3
|
+
from envelope.typing import Key, PyTree
|
|
4
|
+
from envelope.wrappers.wrapper import WrappedState, Wrapper
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class TimeStepWrapper(Wrapper):
    """Wrapper that counts the steps taken since the last reset in its state."""

    class TimeStepState(WrappedState):
        steps: PyTree = field(default=0)

    def reset(
        self, key: Key, state: PyTree | None = None, **kwargs
    ) -> tuple[WrappedState, Info]:
        """Reset the wrapped environment and start the step counter at zero.

        Args:
            key: Random key forwarded to the wrapped environment.
            state: Optional previous state. A ``TimeStepState`` is unwrapped so
                the wrapped environment receives its own inner state; any
                other value is forwarded as given.
            **kwargs: Extra keyword arguments forwarded unchanged.
        """
        # The wrapped environment does not know about TimeStepState, so hand
        # it the inner state rather than this wrapper's own state object.
        if isinstance(state, self.TimeStepState):
            state = state.inner_state
        inner_state, info = self.env.reset(key, state, **kwargs)
        return self.TimeStepState(inner_state=inner_state, steps=0), info

    def step(
        self, state: WrappedState, action: PyTree, **kwargs
    ) -> tuple[WrappedState, Info]:
        """Step the wrapped environment and increment the step counter."""
        next_inner_state, info = self.env.step(state.inner_state, action, **kwargs)
        # Tolerate states without a ``steps`` attribute by counting from zero.
        next_steps = getattr(state, "steps", 0) + 1
        return self.TimeStepState(inner_state=next_inner_state, steps=next_steps), info
|
|
@@ -0,0 +1,31 @@
|
|
|
1
|
+
import jax.numpy as jnp
|
|
2
|
+
|
|
3
|
+
from envelope.environment import Info
|
|
4
|
+
from envelope.struct import field
|
|
5
|
+
from envelope.typing import Key, PyTree
|
|
6
|
+
from envelope.wrappers.wrapper import WrappedState, Wrapper
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class TruncationWrapper(Wrapper):
    """Wrapper that flags ``truncated`` once ``max_steps`` steps have been taken."""

    max_steps: int = field(kw_only=True)

    class TruncationState(WrappedState):
        steps: jnp.ndarray | int = field(default=0)

    def reset(
        self, key: Key, state: PyTree | None = None, **kwargs
    ) -> tuple[WrappedState, Info]:
        """Reset the wrapped environment and start the step counter at zero.

        Args:
            key: Random key forwarded to the wrapped environment.
            state: Optional previous state. A ``TruncationState`` is unwrapped
                so the wrapped environment receives its own inner state; any
                other value is forwarded as given.
            **kwargs: Extra keyword arguments forwarded unchanged.

        Returns:
            A fresh ``TruncationState`` and the info with ``truncated`` set to
            ``max_steps <= 0``.
        """
        # Forward state and kwargs instead of dropping them, so inner wrappers
        # that rely on the previous state still receive it.
        inner = state.inner_state if isinstance(state, self.TruncationState) else state
        inner_state, info = self.env.reset(key, state=inner, **kwargs)
        state = self.TruncationState(inner_state=inner_state, steps=0)
        return state, info.update(truncated=self.max_steps <= 0)

    def step(
        self, state: WrappedState, action: PyTree, **kwargs
    ) -> tuple[WrappedState, Info]:
        """Step the wrapped environment and report whether the limit is reached."""
        next_inner_state, info = self.env.step(state.inner_state, action, **kwargs)
        next_steps = state.steps + 1
        next_state = self.TruncationState(
            inner_state=next_inner_state, steps=next_steps
        )
        truncated = jnp.asarray(next_steps) >= self.max_steps
        return next_state, info.update(truncated=truncated)
|
|
@@ -0,0 +1,77 @@
|
|
|
1
|
+
from functools import cached_property
|
|
2
|
+
from typing import override
|
|
3
|
+
|
|
4
|
+
import jax
|
|
5
|
+
|
|
6
|
+
from envelope import spaces
|
|
7
|
+
from envelope.environment import Environment, Info
|
|
8
|
+
from envelope.struct import field
|
|
9
|
+
from envelope.typing import Key, PyTree
|
|
10
|
+
from envelope.wrappers.wrapper import WrappedState, Wrapper
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class VmapEnvsWrapper(Wrapper):
    """
    Vectorizes over a batched collection of environment instances (vmapping over 'self').

    Usage:
        envs = jax.vmap(make_env)(params_batch)  # env pytree batched on leading axis
        wrapped = VmapEnvsWrapper(env=envs, batch_size=B)
        state, info = wrapped.reset(keys)  # batched keys (leading dim B) or single key
        next_state, info = wrapped.step(state, action)
    """

    batch_size: int = field(kw_only=True)

    @override
    def reset(
        self, key: Key, state: PyTree | None = None, **kwargs
    ) -> tuple[WrappedState, Info]:
        """Reset every environment in the batch.

        Args:
            key: A single PRNG key, which is split into ``batch_size`` keys, or
                a batch of keys whose leading dimension equals ``batch_size``.
                Both raw ``uint32`` keys and typed keys are accepted.
            state: Optional batched state; it is mapped over its leading axis
                together with the environments and keys.
            **kwargs: Extra keyword arguments passed unchanged (not mapped) to
                every environment.

        Raises:
            ValueError: If batched keys do not match ``batch_size``.
        """
        # A typed key is a scalar array; a raw key is a (2,) uint32 array.
        if jax.dtypes.issubdtype(key.dtype, jax.dtypes.prng_key):
            is_single_key = key.ndim == 0
        else:
            is_single_key = key.shape == (2,)

        if is_single_key:
            keys = jax.random.split(key, self.batch_size)
        else:
            if key.shape[0] != self.batch_size:
                raise ValueError(
                    f"reset key's leading dimension ({key.shape[0]}) must match "
                    f"batch_size ({self.batch_size})."
                )
            keys = key
        # vmap over env 'self', keys and state. The state must be a mapped
        # argument so each environment sees its own slice (a None state has no
        # leaves and passes through as None).
        state, info = jax.vmap(lambda e, k, s: e.reset(k, s, **kwargs))(
            self.env, keys, state
        )
        return state, info

    @override
    def step(
        self, state: WrappedState, action: PyTree, **kwargs
    ) -> tuple[WrappedState, Info]:
        """Step every environment with its own slice of ``state`` and ``action``."""
        next_state, info = jax.vmap(lambda e, s, a: e.step(s, a, **kwargs))(
            self.env, state, action
        )
        return next_state, info

    @override
    @property
    def observation_space(self) -> spaces.Space:
        """Observation space of the first environment, batched to ``batch_size``."""
        env0 = _index_env(self.env, 0, self.batch_size)
        return spaces.batch_space(env0.observation_space, self.batch_size)

    @override
    @cached_property
    def action_space(self) -> spaces.Space:
        """Action space of the first environment, batched to ``batch_size``."""
        env0 = _index_env(self.env, 0, self.batch_size)
        return spaces.batch_space(env0.action_space, self.batch_size)

    @override
    @property
    def unwrapped(self) -> Environment:
        """The ``unwrapped`` attribute of the wrapped (batched) environment."""
        return self.env.unwrapped
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
def _index_env(env: Environment, idx: int, batch_size: int) -> Environment:
    """Select entry ``idx`` from every leaf that is batched over ``batch_size``.

    A leaf counts as batched when it has a tuple ``shape`` with at least one
    dimension and its leading dimension equals ``batch_size``; every other
    leaf is returned untouched.
    """

    def select(leaf):
        shape = getattr(leaf, "shape", None) if hasattr(leaf, "shape") else None
        is_batched = (
            isinstance(shape, tuple) and len(shape) > 0 and shape[0] == batch_size
        )
        return leaf[idx] if is_batched else leaf

    return jax.tree.map(select, env)
|
|
@@ -0,0 +1,51 @@
|
|
|
1
|
+
from functools import cached_property
|
|
2
|
+
from typing import override
|
|
3
|
+
|
|
4
|
+
import jax
|
|
5
|
+
|
|
6
|
+
from envelope import spaces
|
|
7
|
+
from envelope.environment import Info
|
|
8
|
+
from envelope.struct import field
|
|
9
|
+
from envelope.typing import Key, PyTree
|
|
10
|
+
from envelope.wrappers.wrapper import WrappedState, Wrapper
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class VmapWrapper(Wrapper):
    """Does not forward kwargs to the underlying env. Does not wrap the state."""

    batch_size: int = field(kw_only=True)

    @override
    def reset(
        self, key: Key, state: PyTree | None = None, **kwargs
    ) -> tuple[WrappedState, Info]:
        """Reset ``batch_size`` copies of the wrapped environment.

        Args:
            key: A single PRNG key, which is split into ``batch_size`` keys, or
                a batch of keys whose leading dimension equals ``batch_size``.
                Both raw ``uint32`` keys and typed keys are accepted.
            state: Optional batched state, mapped over its leading axis.
            **kwargs: Accepted for interface compatibility; not forwarded.

        Raises:
            ValueError: If batched keys do not match ``batch_size``.
        """
        # Accept single key or batched keys. A typed key is a scalar array,
        # while a raw key is a (2,) uint32 array.
        if jax.dtypes.issubdtype(key.dtype, jax.dtypes.prng_key):
            is_single_key = key.ndim == 0
        else:
            is_single_key = key.shape == (2,)

        if is_single_key:
            keys = jax.random.split(key, self.batch_size)
        else:
            if key.shape[0] != self.batch_size:
                raise ValueError(
                    f"reset key's leading dimension ({key.shape[0]}) must match "
                    f"batch_size ({self.batch_size})."
                )
            keys = key

        state, info = jax.vmap(self.env.reset)(keys, state)
        return state, info

    @override
    def step(
        self, state: WrappedState, action: PyTree, **kwargs
    ) -> tuple[WrappedState, Info]:
        """Step the wrapped environment once per batch entry (kwargs are ignored)."""
        state, info = jax.vmap(self.env.step)(state, action)
        return state, info

    @override
    @cached_property
    def observation_space(self) -> spaces.Space:
        """Observation space of the wrapped environment, batched to ``batch_size``."""
        return spaces.batch_space(self.env.observation_space, self.batch_size)

    @override
    @cached_property
    def action_space(self) -> spaces.Space:
        """Action space of the wrapped environment, batched to ``batch_size``."""
        return spaces.batch_space(self.env.action_space, self.batch_size)
|
|
@@ -0,0 +1,57 @@
|
|
|
1
|
+
from dataclasses import KW_ONLY
|
|
2
|
+
from functools import cached_property
|
|
3
|
+
from typing import override
|
|
4
|
+
|
|
5
|
+
from envelope import spaces
|
|
6
|
+
from envelope.environment import Environment, Info, State
|
|
7
|
+
from envelope.struct import FrozenPyTreeNode, field
|
|
8
|
+
from envelope.typing import Key, PyTree
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class WrappedState(FrozenPyTreeNode):
    """State of a wrapper: the wrapped environment's state plus extra fields.

    Fields declared by subclasses come after the ``KW_ONLY`` marker and are
    therefore keyword-only.
    """

    inner_state: State = field()
    _: KW_ONLY

    @property
    def unwrapped(self) -> State:
        """Return the innermost state by following ``unwrapped`` recursively."""
        inner = self.inner_state
        # Nested wrapper states expose ``unwrapped`` themselves; anything else
        # is already the innermost state.
        return inner.unwrapped if hasattr(inner, "unwrapped") else inner
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class Wrapper(Environment):
    """Wrapper for environments.

    Forwards ``reset``, ``step``, the spaces and ``unwrapped`` to the wrapped
    environment, and resolves any attribute not found on the wrapper itself by
    looking it up on the wrapped environment.
    """

    env: Environment = field(kw_only=True)

    @override
    def reset(
        self, key: Key, state: State | None = None, **kwargs
    ) -> tuple[State, Info]:
        """Reset the wrapped environment, forwarding all arguments."""
        return self.env.reset(key, state=state, **kwargs)

    @override
    def step(
        self, state: WrappedState, action: PyTree, **kwargs
    ) -> tuple[WrappedState, Info]:
        """Step the wrapped environment, forwarding all arguments."""
        return self.env.step(state, action, **kwargs)

    @override
    @cached_property
    def observation_space(self) -> spaces.Space:
        """Observation space of the wrapped environment."""
        return self.env.observation_space

    @override
    @cached_property
    def action_space(self) -> spaces.Space:
        """Action space of the wrapped environment."""
        return self.env.action_space

    @override
    @property
    def unwrapped(self) -> Environment:
        """The ``unwrapped`` attribute of the wrapped environment."""
        return self.env.unwrapped

    def __getattr__(self, name):
        """Delegate lookups of attributes missing on the wrapper to ``env``.

        Raises:
            AttributeError: For ``__setstate__`` (never delegated) and for
                ``env`` itself when it has not been set on the instance.
        """
        # __getattr__ only runs when normal lookup fails. If ``env`` is what is
        # missing, reading ``self.env`` below would re-enter this method
        # without end, so fail fast instead of recursing.
        if name in ("env", "__setstate__"):
            raise AttributeError(name)
        return getattr(self.env, name)
|
|
@@ -0,0 +1,87 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: jax-envelope
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: A JAX-native environment interface with powerful wrappers and adapters for popular RL environment suites
|
|
5
|
+
Project-URL: Homepage, https://github.com/keraJLi/envelope
|
|
6
|
+
Project-URL: Repository, https://github.com/keraJLi/envelope
|
|
7
|
+
Project-URL: Documentation, https://github.com/keraJLi/envelope#readme
|
|
8
|
+
Project-URL: Issues, https://github.com/keraJLi/envelope/issues
|
|
9
|
+
Project-URL: Changelog, https://github.com/keraJLi/envelope/releases
|
|
10
|
+
Author-email: Jarek Liesen <jarek.liesen@reuben.ox.ac.uk>
|
|
11
|
+
License: MIT
|
|
12
|
+
License-File: LICENSE
|
|
13
|
+
Keywords: deep-learning,environments,gymnasium,hardware-acceleration,jax,machine-learning,reinforcement-learning,vectorization
|
|
14
|
+
Classifier: Development Status :: 4 - Beta
|
|
15
|
+
Classifier: Intended Audience :: Developers
|
|
16
|
+
Classifier: Intended Audience :: Science/Research
|
|
17
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
18
|
+
Classifier: Operating System :: OS Independent
|
|
19
|
+
Classifier: Programming Language :: Python :: 3
|
|
20
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
21
|
+
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
22
|
+
Classifier: Typing :: Typed
|
|
23
|
+
Requires-Python: >=3.12
|
|
24
|
+
Requires-Dist: jax>=0.5.0
|
|
25
|
+
Description-Content-Type: text/markdown
|
|
26
|
+
|
|
27
|
+
# 💌 Envelope: a JAX-native environment interface
|
|
28
|
+
```python
|
|
29
|
+
# Create environments from JAX-native suites you have installed, ...
|
|
30
|
+
env = envelope.create("gymnax::CartPole-v1")
|
|
31
|
+
|
|
32
|
+
# ... interact with the environments using a simple interface, ...
|
|
33
|
+
state, info = env.reset(key)
|
|
34
|
+
states, infos = jax.lax.scan(env.step, state, actions)
|
|
35
|
+
plt.plot(infos.reward.cumsum())
|
|
36
|
+
|
|
37
|
+
# ... and enjoy a powerful ecosystem of wrappers.
|
|
38
|
+
env = envelope.wrappers.AutoResetWrapper(env=env)
|
|
39
|
+
env = envelope.wrappers.VmapWrapper(env=env, batch_size=8)
|
|
40
|
+
env = envelope.wrappers.ObservationNormalizationWrapper(env=env)
|
|
41
|
+
```
|
|
42
|
+
|
|
43
|
+
## 🌍 Simple, expressive interaction!
|
|
44
|
+
* **Environments are pytrees**. Squish them through JAX transformations and trace their parameters.
|
|
45
|
+
* **Idiomatic jax-y interface** of `reset(key: Key) -> State, Info` and `step(state: State, action: PyTree) -> State, Info`. You can directly `jax.lax.scan` over a `step(...)`!
|
|
46
|
+
* **Spaces are super simple**. No `Tuple`, `Dict` nonsense! There are two spaces: `Continuous` and `Discrete`, which you can compose into a `PyTreeSpace`.
|
|
47
|
+
* **Explicit episode truncation** supports correctly handling bootstrapping for value-function targets.
|
|
48
|
+
* **No auto-reset** by default. Resetting every step can be expensive!
|
|
49
|
+
|
|
50
|
+
## 💪 Powerful, composable wrappers!
|
|
51
|
+
* **Carry state across episodes** to track running statistics, for example to normalize observations.
|
|
52
|
+
* **Composable wrappers** can be stacked in any order. For example, `ObservationNormalizationWrapper` before vs. after `VmapWrapper` gives per-env vs. global normalization.
|
|
53
|
+
<!-- TODO: Add auto-reset behavior (including state injection) and optimistic resets once I implement them. -->
|
|
54
|
+
|
|
55
|
+
## 🔌 Adapters for existing suites
|
|
56
|
+
| 📦 | # 🤖 | # 🌍 |
|
|
57
|
+
|------|------|------|
|
|
58
|
+
| [gymnax](https://github.com/RobertTLange/gymnax) | 🕺 | 24 |
|
|
59
|
+
| [brax](https://github.com/google/brax) | 🕺 | 12 |
|
|
60
|
+
| [jumanji](https://github.com/instadeepai/jumanji) | 🕺 / 👯 | 25 / 1 |
|
|
61
|
+
| [kinetix](https://github.com/flairox/kinetix) | 🕺 | 74 |
|
|
62
|
+
| [craftax](https://github.com/MichaelTMatthews/craftax) | 🕺 | 4 |
|
|
63
|
+
| [mujoco_playground](https://github.com/google-deepmind/mujoco_playground) | 🕺 | 54 |
|
|
64
|
+
| | | |
|
|
65
|
+
| Total | 🕺 / 👯 | 193 / 1 |
|
|
66
|
+
|
|
67
|
+
```python
|
|
68
|
+
envelope.create("📦::🌍")
|
|
69
|
+
```
|
|
70
|
+
lets you create environments from any of the above!
|
|
71
|
+
|
|
72
|
+
## 📝 Testing
|
|
73
|
+
- **Default (no optional compat deps required)**: `uv run pytest -m "not compat"`
|
|
74
|
+
- **Compat suite (requires full compat dependency group)**:
|
|
75
|
+
- `uv sync --group compat`
|
|
76
|
+
- `uv run pytest -m compat`
|
|
77
|
+
- If any compat dependency is missing/broken, the run will fail fast with an error telling you what to install.
|
|
78
|
+
|
|
79
|
+
## 🏗️ Installation
|
|
80
|
+
```bash
|
|
81
|
+
pip install jax-envelope
|
|
82
|
+
```
|
|
83
|
+
|
|
84
|
+
## 💞 Related projects
|
|
85
|
+
* [stoax](https://github.com/EdanToledo/Stoa) is a very similar project that provides adapters and wrappers for the jumanji-like interface.
|
|
86
|
+
* Check out all the great suites we have adapters for! [gymnax](https://github.com/RobertTLange/gymnax), [brax](https://github.com/google/brax), [jumanji](https://github.com/instadeepai/jumanji), [kinetix](https://github.com/flairox/kinetix), [craftax](https://github.com/MichaelTMatthews/craftax), [mujoco_playground](https://github.com/google-deepmind/mujoco_playground).
|
|
87
|
+
* We will be adding support for [jaxmarl](https://github.com/flairox/jaxmarl) and [pgx](https://github.com/sotetsuk/pgx) in the future, as soon as we have figured out the best ever MARL interface for JAX!
|
|
@@ -0,0 +1,27 @@
|
|
|
1
|
+
envelope/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
2
|
+
envelope/environment.py,sha256=4T0IosoRX-6gVvZCKBUEitlM_5cAT43hrKQ0Hr53XGk,1735
|
|
3
|
+
envelope/spaces.py,sha256=fvJ3-ZI3iyFlpfEN1barTd9cbWvQ11nl-7Y45KgFizs,6433
|
|
4
|
+
envelope/struct.py,sha256=Sb8GsxZ7rFF5A5128oZWEQVnFK79xUFe6EaU2bghBOw,5158
|
|
5
|
+
envelope/typing.py,sha256=dcftDRNM0luCqHDEJ_qVmZnjUOEwn-HiFvwuV1z1Bn0,734
|
|
6
|
+
envelope/compat/__init__.py,sha256=6q2B2bfTIu7MlnIXc702ysqEnyURz-MEjTBMBCo8rdQ,3458
|
|
7
|
+
envelope/compat/brax_envelope.py,sha256=T0Xefrpgft16SdRJKMIjzdJq5tbocQ4jeI763PDzZ_4,3445
|
|
8
|
+
envelope/compat/craftax_envelope.py,sha256=t151fDCrgPWfynERXHYXiWBwCGAXbFcreuIyH7GCApg,3162
|
|
9
|
+
envelope/compat/gymnax_envelope.py,sha256=iau4CKFLVIJkRA9WILgOqkp3mXIAzkilKcrKF8vkld4,3790
|
|
10
|
+
envelope/compat/jumanji_envelope.py,sha256=_Tox7DfiyE-Z4drszPUe-arPurz6hMeFtRNfy8o9Ipc,4869
|
|
11
|
+
envelope/compat/kinetix_envelope.py,sha256=kPLWwSoQ49kNXuvfbStOtoGby9EkV7bKAetxUa1_YPA,6817
|
|
12
|
+
envelope/compat/mujoco_playground_envelope.py,sha256=UcYZwOPD27473wmDDb5n2eK46uPwOeh0LZOgyov4rww,3706
|
|
13
|
+
envelope/compat/navix_envelope.py,sha256=K7AGae9_kDETFAUuWH9sCCLDo5fuEFbc0eKTHcvSxek,3027
|
|
14
|
+
envelope/wrappers/autoreset_wrapper.py,sha256=3OUUNb4L6P4Ncn57Uy2ldoLFr1IKnPSWuzyjf20N0Y8,1299
|
|
15
|
+
envelope/wrappers/episode_statistics_wrapper.py,sha256=Mj5Ua7cLtBQYtFn3oQB9eJ7TclKwAFUvsQHIRkcQJvs,1509
|
|
16
|
+
envelope/wrappers/normalization.py,sha256=xHezXsb1J09D3IZACC1xxMy9NuUSBWqCyYH2wFAItPs,1811
|
|
17
|
+
envelope/wrappers/observation_normalization_wrapper.py,sha256=NNHz_THFe2eQUY0Pued9_hJztIGU6SU7kae_AKFbe4k,4248
|
|
18
|
+
envelope/wrappers/state_injection_wrapper.py,sha256=yk7he1zPaEdzjks_NDnYDC91e5bXTlCFPKFMwm-A7jk,3687
|
|
19
|
+
envelope/wrappers/timestep_wrapper.py,sha256=6jS-80AwnIIct3Ool9zq_iGOyfodNbs89NfuyYMfYIE,874
|
|
20
|
+
envelope/wrappers/truncation_wrapper.py,sha256=TzJqxjAjipk7pTk-drXRvQ41uR-jWsHlAgYJpbHI32M,1132
|
|
21
|
+
envelope/wrappers/vmap_envs_wrapper.py,sha256=TezSCeOf3oJKj8w7i_UaAA2LsoWkhhp835W545Q-waI,2561
|
|
22
|
+
envelope/wrappers/vmap_wrapper.py,sha256=zYMYVvuVQfRM7d2ygxxcavRNtkdzkqhArakE507tcU4,1587
|
|
23
|
+
envelope/wrappers/wrapper.py,sha256=ydolcZR9ogW3z4ui07845mbMue2XtEoEJ64tHCzAAtw,1504
|
|
24
|
+
jax_envelope-0.1.0.dist-info/METADATA,sha256=-ASxsWLgF9CwcQ-mp1NJAmdFDS4qzs8bekNJH291R00,4653
|
|
25
|
+
jax_envelope-0.1.0.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
|
|
26
|
+
jax_envelope-0.1.0.dist-info/licenses/LICENSE,sha256=VyF-MK-gY2_fZlhf8uEnE2y8ziIXK-w55GM12eOgXrQ,1069
|
|
27
|
+
jax_envelope-0.1.0.dist-info/RECORD,,
|