jax-envelope 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,47 @@
1
from dataclasses import field
from typing import override

import jax.numpy as jnp

from envelope.environment import Environment, Info, State
from envelope.typing import Key, PyTree, Array
from envelope.wrappers import Wrapper
from envelope.wrappers.wrapper import WrappedState
7
+
8
+
9
class EpisodeStatisticsWrapper(Wrapper):
    """Accumulate the return and length of the current episode.

    The running totals are stored in ``StatisticsState`` and exposed on every info
    as ``episode_reward`` (sum of ``info.reward`` since the last reset) and
    ``episode_length`` (number of ``step`` calls since the last reset). ``reset``
    sets both back to zero.
    """

    class StatisticsState(WrappedState):
        # Sum of rewards received since the last reset.
        episode_reward: Array
        # Number of steps taken since the last reset.
        episode_length: Array

    @override
    def reset(
        self, key: Key, state: State | None = None, **kwargs
    ) -> tuple[State, Info]:
        # ``state`` (if given) is this wrapper's own state; the inner environment
        # only understands its own state, so unwrap it first.
        inner_state = None if state is None else state.inner_state
        inner_state, info = self.env.reset(key, state=inner_state, **kwargs)

        # Shape the totals like the reward so they also work when the wrapped
        # environment is batched. NOTE(review): assumes reset infos expose
        # ``reward`` just like step infos do.
        reward_shape = jnp.shape(info.reward)
        episode_reward = jnp.zeros(reward_shape)
        episode_length = jnp.zeros(reward_shape, dtype=jnp.int32)

        next_state = self.StatisticsState(
            inner_state=inner_state,
            episode_reward=episode_reward,
            episode_length=episode_length,
        )
        info = info.update(
            episode_reward=episode_reward, episode_length=episode_length
        )
        return next_state, info

    @override
    def step(self, state: State, action: PyTree, **kwargs) -> tuple[State, Info]:
        inner_state, info = self.env.step(state.inner_state, action, **kwargs)

        episode_reward = state.episode_reward + info.reward
        episode_length = state.episode_length + 1

        next_state = self.StatisticsState(
            inner_state=inner_state,
            episode_reward=episode_reward,
            episode_length=episode_length,
        )
        info = info.update(
            episode_reward=episode_reward, episode_length=episode_length
        )
        return next_state, info
47
+
@@ -0,0 +1,56 @@
1
+ from functools import cached_property
2
+ from typing import NamedTuple
3
+
4
+ import jax
5
+ from jax import numpy as jnp
6
+
7
+ from envelope.struct import FrozenPyTreeNode
8
+ from envelope.typing import Array, PyTree
9
+
10
+
11
# Per-leaf result of one running-statistics update: the new mean and the new
# variance of a single array leaf. Being a NamedTuple, it is a pytree node.
MeanVarPair = NamedTuple("MeanVarPair", [("mean", Array), ("var", Array)])
14
+
15
+
16
class RunningMeanVar(FrozenPyTreeNode):
    """Immutable running statistics (mean, variance, sample count) of a pytree.

    ``mean`` and ``var`` mirror the tree structure of the tracked data, one array
    per leaf. New batches are folded in by ``update_rmv``, which returns a fresh
    instance rather than mutating this one.
    """

    # Per-leaf running mean.
    mean: PyTree
    # Per-leaf running variance.
    var: PyTree
    # Number of samples folded into the statistics so far.
    count: int

    @cached_property
    def std(self) -> PyTree:
        """Per-leaf standard deviation: the element-wise square root of ``var``.

        Computed on first access and cached on the instance.
        """

        def _leaf_std(leaf_var):
            return jnp.sqrt(leaf_var)

        return jax.tree.map(_leaf_std, self.var)
24
+
25
+
26
def update_rmv(rmv_state: RunningMeanVar, x: PyTree) -> RunningMeanVar:
    """
    Update running mean/variance with a new batch of observations x. We assume x is a
    PyTree of arrays, each with a leading batch dimension (aligned sizes).

    The batch moments are merged with the running moments using the parallel
    (pairwise) variance combination, so a whole batch is folded in at once.

    Args:
        rmv_state: Current statistics. ``mean`` and ``var`` have the same tree
            structure as ``x``, without the leading batch axis.
        x: Batch of samples; the batch size is read from the first leaf.

    Returns:
        New statistics whose ``count`` is increased by the batch size.
    """
    global_count = rmv_state.count
    batch_count = jax.tree.leaves(x)[0].shape[0]
    tot_count = global_count + batch_count
    # NOTE(review): when ``count`` is carried as an int32 array (x64 disabled),
    # ``tot_count`` itself wraps around after 2**31 samples.

    # Relative weights of the running statistics and of the new batch. True
    # division always yields floats, so the integer product
    # ``global_count * batch_count`` is never formed. That product overflows
    # int32 long before the counts themselves do when ``count`` is a traced
    # array (e.g. carried through jit/scan), which corrupted the variance.
    old_weight = global_count / tot_count
    new_weight = batch_count / tot_count

    def _update_arr(mean: Array, var: Array, x_arr: Array) -> MeanVarPair:
        batch_mean = x_arr.mean(axis=0)
        batch_var = x_arr.var(axis=0)
        delta = batch_mean - mean

        new_mean = mean + delta * new_weight
        # Same as (M2_a + M2_b + delta**2 * n_a * n_b / n) / n, expressed with
        # the weights n_a / n and n_b / n.
        new_var = (
            var * old_weight
            + batch_var * new_weight
            + (delta**2) * (old_weight * new_weight)
        )
        return MeanVarPair(mean=new_mean, var=new_var)

    def is_pair(z):
        return isinstance(z, MeanVarPair)

    # jax.tree.map returns a PyTree whose leaves are MeanVarPairs; split it back
    # into a mean tree and a variance tree.
    mean_var_pairs = jax.tree.map(_update_arr, rmv_state.mean, rmv_state.var, x)
    new_mean = jax.tree.map(lambda mv: mv.mean, mean_var_pairs, is_leaf=is_pair)
    new_var = jax.tree.map(lambda mv: mv.var, mean_var_pairs, is_leaf=is_pair)
    return RunningMeanVar(mean=new_mean, var=new_var, count=tot_count)
@@ -0,0 +1,114 @@
1
+ from typing import override
2
+
3
+ import jax
4
+ from jax import numpy as jnp
5
+
6
+ from envelope.environment import Info
7
+ from envelope.spaces import BatchedSpace, PyTreeSpace, Space
8
+ from envelope.struct import field, static_field
9
+ from envelope.typing import Key, PyTree
10
+ from envelope.wrappers.normalization import RunningMeanVar, update_rmv
11
+ from envelope.wrappers.wrapper import WrappedState, Wrapper
12
+
13
+
14
class ObservationNormalizationWrapper(Wrapper):
    """Normalize observations with running mean/variance statistics.

    The statistics live in the wrapper state (``rmv_state``). They are carried
    from step to step and, when the previous state is passed to ``reset``, across
    episodes. On every ``reset`` and ``step`` the new observation is first folded
    into the statistics and then normalized with the updated values. The info
    returned carries the normalized observation as ``obs`` and the raw one as
    ``unnormalized_obs``.
    """

    class ObservationNormalizationState(WrappedState):
        rmv_state: RunningMeanVar = field()

    stats_spec: PyTree | None = static_field(default=None)
    """Per-leaf normalization statistics spec as a pytree of jax.ShapeDtypeStruct.
    Shapes must be broadcastable to the observation leaves. If None, inferred from
    the observation_space with BatchedSpace ignored; each leaf must have a floating
    dtype."""

    def __post_init__(self):
        # Infer the spec from the wrapped observation space when none is given.
        # The instance is immutable, hence object.__setattr__.
        if self.stats_spec is None:
            stats_spec = _infer_stats_spec(self.env.observation_space)
            object.__setattr__(self, "stats_spec", stats_spec)

    def _init_rmv_state(self) -> RunningMeanVar:
        """Fresh statistics: zero mean, unit variance, zero sample count."""

        def zeros(sd: jax.ShapeDtypeStruct) -> jax.Array:
            return jnp.zeros(sd.shape, dtype=sd.dtype)

        def ones(sd: jax.ShapeDtypeStruct) -> jax.Array:
            return jnp.ones(sd.shape, dtype=sd.dtype)

        mean = jax.tree.map(zeros, self.stats_spec)
        var = jax.tree.map(ones, self.stats_spec)

        return RunningMeanVar(mean=mean, var=var, count=0)

    def _normalize_obs(self, obs: PyTree, rmv: RunningMeanVar) -> PyTree:
        """Return ``(obs - mean) / (std + 1e-8)`` per leaf, cast to the spec dtype."""

        def norm_leaf(x, mean, std, spec):
            # Statistics have the unbatched spec shape; broadcast them over any
            # leading batch axes of the observation.
            mean = jnp.broadcast_to(mean, x.shape)
            std = jnp.broadcast_to(std, x.shape)
            # The epsilon guards against division by a zero standard deviation.
            obs = (x - mean) / (std + 1e-8)
            return obs.astype(spec.dtype)

        return jax.tree.map(norm_leaf, obs, rmv.mean, rmv.std, self.stats_spec)

    def _normalize_and_update(
        self, state: WrappedState, info: Info
    ) -> tuple[WrappedState, Info]:
        """Fold ``info.obs`` into the statistics, then normalize it with them."""
        # Flatten all leading axes so that every row is one sample of shape
        # spec.shape, as update_rmv expects a single leading batch axis.
        reshaped_obs = jax.tree.map(
            lambda x, spec: x.reshape((-1,) + tuple(spec.shape)),
            info.obs,
            self.stats_spec,
        )
        rmv_state = update_rmv(state.rmv_state, reshaped_obs)
        norm_obs = self._normalize_obs(info.obs, rmv_state)

        state = self.ObservationNormalizationState(
            inner_state=state.inner_state, rmv_state=rmv_state
        )
        info = info.update(obs=norm_obs, unnormalized_obs=info.obs)
        return state, info

    @override
    def reset(
        self, key: Key, state: PyTree | None = None, **kwargs
    ) -> tuple[WrappedState, Info]:
        if state is None:
            inner_state = None
            rmv_state = self._init_rmv_state()
        else:
            # Keep the statistics gathered in previous episodes.
            inner_state = state.inner_state
            rmv_state = state.rmv_state

        inner_state, info = self.env.reset(key, inner_state, **kwargs)
        next_state = self.ObservationNormalizationState(
            inner_state=inner_state, rmv_state=rmv_state
        )
        return self._normalize_and_update(next_state, info)

    @override
    def step(
        self, state: WrappedState, action: PyTree, **kwargs
    ) -> tuple[WrappedState, Info]:
        inner_state, info = self.env.step(state.inner_state, action, **kwargs)
        state = state.replace(inner_state=inner_state)
        return self._normalize_and_update(state, info)
91
+
92
+
93
+ def _infer_stats_spec(space: Space) -> PyTree:
94
+ """
95
+ Build a PyTree of jax.ShapeDtypeStruct for stats. Strip BatchedSpace layers,
96
+ and for leaf spaces return (shape=space.shape, dtype=inferred).
97
+ """
98
+
99
+ def descend(sp: Space):
100
+ if isinstance(sp, BatchedSpace):
101
+ return descend(sp.space)
102
+ if isinstance(sp, PyTreeSpace):
103
+ return jax.tree.map(
104
+ lambda s: descend(s),
105
+ sp.tree,
106
+ is_leaf=lambda n: isinstance(n, Space),
107
+ )
108
+ if not jnp.issubdtype(sp.dtype, jnp.floating):
109
+ raise ValueError(
110
+ f"Space {sp} has dtype {sp.dtype} which is not a floating point dtype"
111
+ )
112
+ return jax.ShapeDtypeStruct(tuple(sp.shape), sp.dtype)
113
+
114
+ return descend(space)
@@ -0,0 +1,91 @@
1
+ from envelope.environment import Info, InfoContainer
2
+ from envelope.struct import field
3
+ from envelope.typing import Key, PyTree
4
+ from envelope.wrappers.wrapper import WrappedState, Wrapper
5
+
6
+
7
class StateInjectionWrapper(Wrapper):
    """Make every reset return to an injected state.

    Intended for UED: call ``set_reset_state`` to install a task state. From then
    on every reset, including auto-resets issued by an outer wrapper, returns to
    that state until a new one is installed.

    Example::

        env = AutoResetWrapper(env=StateInjectionWrapper(env=base_env))
        state, info = env.reset(key)

        for outer_iter in range(num_outer_iters):
            # Sample a new task and make it the reset target.
            task_state, task_obs = sample_task(key)
            state = env.set_reset_state(state, task_state, task_obs)

            # Auto-resets during these steps go back to task_state.
            for inner_step in range(num_inner_steps):
                state, info = env.step(state, policy(info.obs))
    """

    class InjectedState(WrappedState):
        reset_state: PyTree | None = field(default=None)
        reset_obs: PyTree | None = field(default=None)

    def set_reset_state(
        self, state: WrappedState, reset_state: PyTree, reset_obs: PyTree
    ) -> WrappedState:
        """Install a new reset target and jump to it immediately.

        Walks down the ``inner_state`` chain of ``state`` until this wrapper's
        ``InjectedState`` is found. That node is replaced by one whose current
        inner state, reset state and reset observation are the given values, and
        the outer states are rebuilt around it.

        Args:
            state: Current state (can be from any outer wrapper)
            reset_state: The state to reset to (inner environment state)
            reset_obs: The observation to return on reset

        Returns:
            New state with updated reset fields at the appropriate level

        Raises:
            ValueError: If no ``InjectedState`` is found in the chain.
        """
        outer_states = []
        node = state
        while not isinstance(node, self.InjectedState):
            if not hasattr(node, "inner_state"):
                raise ValueError("Could not find InjectedState in given state")
            outer_states.append(node)
            node = node.inner_state

        rebuilt = self.InjectedState(
            inner_state=reset_state,
            reset_state=reset_state,
            reset_obs=reset_obs,
        )
        # Re-wrap from the innermost outer state to the outermost one.
        for outer in reversed(outer_states):
            rebuilt = outer.replace(inner_state=rebuilt)
        return rebuilt

    def reset(
        self, key: Key, state: PyTree | None = None, **kwargs
    ) -> tuple[WrappedState, Info]:
        """Reset to the injected state if one is set, else reset the inner env."""
        if state is None:
            # No injected target yet and nothing to reset from.
            state = self.InjectedState(inner_state=None)

        has_reset_state = state.reset_state is not None
        has_reset_obs = state.reset_obs is not None
        if has_reset_state != has_reset_obs:
            raise ValueError("State must set both reset_state and reset_obs or neither")

        if has_reset_state:
            inner_state = state.reset_state
            info = InfoContainer(obs=state.reset_obs, reward=0.0, terminated=False)
        else:
            inner_state, info = self.env.reset(key, state=state.inner_state, **kwargs)

        return state.replace(inner_state=inner_state), info

    def step(
        self, state: WrappedState, action: PyTree, **kwargs
    ) -> tuple[WrappedState, Info]:
        next_inner, info = self.env.step(state.inner_state, action, **kwargs)
        return state.replace(inner_state=next_inner), info
@@ -0,0 +1,22 @@
1
+ from envelope.environment import Info
2
+ from envelope.struct import field
3
+ from envelope.typing import Key, PyTree
4
+ from envelope.wrappers.wrapper import WrappedState, Wrapper
5
+
6
+
7
class TimeStepWrapper(Wrapper):
    """Count the steps taken since the last reset.

    The counter is ``TimeStepState.steps``. ``reset`` sets it to 0 and every
    ``step`` increments it by one.
    """

    class TimeStepState(WrappedState):
        steps: PyTree = field(default=0)

    def reset(
        self, key: Key, state: PyTree | None = None, **kwargs
    ) -> tuple[WrappedState, Info]:
        # ``state`` (if given) is this wrapper's own TimeStepState. The inner
        # environment must receive its own state, not ours, so unwrap it.
        inner_state = None if state is None else state.inner_state
        inner_state, info = self.env.reset(key, inner_state, **kwargs)
        return self.TimeStepState(inner_state=inner_state, steps=0), info

    def step(
        self, state: WrappedState, action: PyTree, **kwargs
    ) -> tuple[WrappedState, Info]:
        next_inner_state, info = self.env.step(state.inner_state, action, **kwargs)
        # Tolerate states without a counter by starting from 0.
        next_steps = getattr(state, "steps", 0) + 1
        return self.TimeStepState(inner_state=next_inner_state, steps=next_steps), info
@@ -0,0 +1,31 @@
1
+ import jax.numpy as jnp
2
+
3
+ from envelope.environment import Info
4
+ from envelope.struct import field
5
+ from envelope.typing import Key, PyTree
6
+ from envelope.wrappers.wrapper import WrappedState, Wrapper
7
+
8
+
9
class TruncationWrapper(Wrapper):
    """Truncate episodes after ``max_steps`` steps.

    ``info.truncated`` is set on every reset and step. It is true once the number
    of steps since the last reset reaches ``max_steps``, or immediately at reset
    when ``max_steps <= 0``.
    """

    max_steps: int = field(kw_only=True)

    class TruncationState(WrappedState):
        steps: jnp.ndarray | int = field(default=0)

    def reset(
        self, key: Key, state: PyTree | None = None, **kwargs
    ) -> tuple[WrappedState, Info]:
        # Forward the inner state and kwargs so that inner wrappers which carry
        # data across resets (e.g. running statistics) keep it.
        inner_state = None if state is None else state.inner_state
        inner_state, info = self.env.reset(key, inner_state, **kwargs)
        next_state = self.TruncationState(inner_state=inner_state, steps=0)
        # A non-positive step budget truncates the episode right away.
        return next_state, info.update(truncated=self.max_steps <= 0)

    def step(
        self, state: WrappedState, action: PyTree, **kwargs
    ) -> tuple[WrappedState, Info]:
        next_inner_state, info = self.env.step(state.inner_state, action, **kwargs)
        next_steps = state.steps + 1
        next_state = self.TruncationState(
            inner_state=next_inner_state, steps=next_steps
        )
        truncated = jnp.asarray(next_steps) >= self.max_steps
        return next_state, info.update(truncated=truncated)
@@ -0,0 +1,77 @@
1
+ from functools import cached_property
2
+ from typing import override
3
+
4
+ import jax
5
+
6
+ from envelope import spaces
7
+ from envelope.environment import Environment, Info
8
+ from envelope.struct import field
9
+ from envelope.typing import Key, PyTree
10
+ from envelope.wrappers.wrapper import WrappedState, Wrapper
11
+
12
+
13
class VmapEnvsWrapper(Wrapper):
    """
    Vectorizes over a batched collection of environment instances (vmapping over 'self').

    ``env`` is a pytree of environments whose leaves all carry a leading axis of
    size ``batch_size``. ``reset`` and ``step`` are vmapped over that axis together
    with the keys, states and actions.

    Usage:
        envs = jax.vmap(make_env)(params_batch)  # env pytree batched on leading axis
        wrapped = VmapEnvsWrapper(env=envs, batch_size=B)
        state, info = wrapped.reset(keys)  # B keys, or a single key to split
        next_state, info = wrapped.step(state, action)
    """

    batch_size: int = field(kw_only=True)

    def _batch_keys(self, key: Key) -> Key:
        """Return ``batch_size`` keys, splitting ``key`` if it is a single key.

        Supports both legacy raw uint32 keys (single key has shape ``(2,)``) and
        typed keys from ``jax.random.key`` (single key is a scalar).

        Raises:
            ValueError: If a batch of keys does not have ``batch_size`` entries.
        """
        if jax.dtypes.issubdtype(key.dtype, jax.dtypes.prng_key):
            is_single = key.ndim == 0
        else:
            is_single = key.shape == (2,)

        if is_single:
            return jax.random.split(key, self.batch_size)
        if key.shape[0] != self.batch_size:
            raise ValueError(
                f"reset key's leading dimension ({key.shape[0]}) must match "
                f"batch_size ({self.batch_size})."
            )
        return key

    @override
    def reset(
        self, key: Key, state: PyTree | None = None, **kwargs
    ) -> tuple[WrappedState, Info]:
        keys = self._batch_keys(key)
        # vmap over env 'self', keys and the previous state. A batched state must
        # be split per instance, not broadcast whole into every instance. None is
        # an empty pytree, so a missing state passes through vmap unchanged.
        state, info = jax.vmap(lambda e, k, s: e.reset(k, s, **kwargs))(
            self.env, keys, state
        )
        return state, info

    @override
    def step(
        self, state: WrappedState, action: PyTree, **kwargs
    ) -> tuple[WrappedState, Info]:
        next_state, info = jax.vmap(lambda e, s, a: e.step(s, a, **kwargs))(
            self.env, state, action
        )
        return next_state, info

    @override
    @property
    def observation_space(self) -> spaces.Space:
        # Spaces are taken from the first instance and batched.
        env0 = _index_env(self.env, 0, self.batch_size)
        return spaces.batch_space(env0.observation_space, self.batch_size)

    @override
    @cached_property
    def action_space(self) -> spaces.Space:
        env0 = _index_env(self.env, 0, self.batch_size)
        return spaces.batch_space(env0.action_space, self.batch_size)

    @override
    @property
    def unwrapped(self) -> Environment:
        return self.env.unwrapped
68
+
69
+
70
def _index_env(env: Environment, idx: int, batch_size: int) -> Environment:
    """Select element ``idx`` from every batched leaf of ``env``.

    A leaf counts as batched when its ``shape`` is a non-empty tuple whose
    leading size equals ``batch_size``. All other leaves are returned unchanged.
    """

    def _select(leaf):
        shape = getattr(leaf, "shape", None)
        is_batched = (
            isinstance(shape, tuple) and len(shape) > 0 and shape[0] == batch_size
        )
        return leaf[idx] if is_batched else leaf

    return jax.tree.map(_select, env)
@@ -0,0 +1,51 @@
1
+ from functools import cached_property
2
+ from typing import override
3
+
4
+ import jax
5
+
6
+ from envelope import spaces
7
+ from envelope.environment import Info
8
+ from envelope.struct import field
9
+ from envelope.typing import Key, PyTree
10
+ from envelope.wrappers.wrapper import WrappedState, Wrapper
11
+
12
+
13
class VmapWrapper(Wrapper):
    """Does not forward kwargs to the underlying env. Does not wrap the state.

    Runs ``batch_size`` copies of the wrapped environment in parallel by vmapping
    its ``reset`` and ``step``. The environment itself is shared (not batched);
    keys, states and actions carry a leading batch axis.
    """

    batch_size: int = field(kw_only=True)

    def _batch_keys(self, key: Key) -> Key:
        """Return ``batch_size`` keys, splitting ``key`` if it is a single key.

        Supports both legacy raw uint32 keys (single key has shape ``(2,)``) and
        typed keys from ``jax.random.key`` (single key is a scalar).

        Raises:
            ValueError: If a batch of keys does not have ``batch_size`` entries.
        """
        if jax.dtypes.issubdtype(key.dtype, jax.dtypes.prng_key):
            is_single = key.ndim == 0
        else:
            is_single = key.shape == (2,)

        if is_single:
            return jax.random.split(key, self.batch_size)
        if key.shape[0] != self.batch_size:
            raise ValueError(
                f"reset key's leading dimension ({key.shape[0]}) must match "
                f"batch_size ({self.batch_size})."
            )
        return key

    @override
    def reset(
        self, key: Key, state: PyTree | None = None, **kwargs
    ) -> tuple[WrappedState, Info]:
        # Accept a single key (split into batch_size keys) or batched keys.
        keys = self._batch_keys(key)
        state, info = jax.vmap(self.env.reset)(keys, state)
        return state, info

    @override
    def step(
        self, state: WrappedState, action: PyTree, **kwargs
    ) -> tuple[WrappedState, Info]:
        state, info = jax.vmap(self.env.step)(state, action)
        return state, info

    @override
    @cached_property
    def observation_space(self) -> spaces.Space:
        return spaces.batch_space(self.env.observation_space, self.batch_size)

    @override
    @cached_property
    def action_space(self) -> spaces.Space:
        return spaces.batch_space(self.env.action_space, self.batch_size)
@@ -0,0 +1,57 @@
1
+ from dataclasses import KW_ONLY
2
+ from functools import cached_property
3
+ from typing import override
4
+
5
+ from envelope import spaces
6
+ from envelope.environment import Environment, Info, State
7
+ from envelope.struct import FrozenPyTreeNode, field
8
+ from envelope.typing import Key, PyTree
9
+
10
+
11
class WrappedState(FrozenPyTreeNode):
    """State of a wrapper: the wrapped state plus wrapper-specific fields.

    Subclasses declare their extra fields after ``inner_state``. The ``KW_ONLY``
    marker makes all of those fields keyword-only.
    """

    # State of the wrapped environment (or of the next wrapper down).
    inner_state: State = field()
    _: KW_ONLY

    @property
    def unwrapped(self) -> State:
        """Return the innermost state by following ``unwrapped`` down the chain.

        If the inner state has no ``unwrapped`` attribute, it is returned as-is.
        """
        # One attribute lookup per level. The former hasattr()-then-access form
        # evaluated the nested property twice per level, i.e. 2**depth calls.
        return getattr(self.inner_state, "unwrapped", self.inner_state)
20
+
21
+
22
class Wrapper(Environment):
    """Wrapper for environments.

    Every method delegates to the wrapped ``env`` by default, and subclasses
    override what they need. Attribute lookups that fail on the wrapper are
    forwarded to ``env`` as well, except for dunder names and ``env`` itself.
    """

    env: Environment = field(kw_only=True)

    @override
    def reset(
        self, key: Key, state: State | None = None, **kwargs
    ) -> tuple[State, Info]:
        return self.env.reset(key, state=state, **kwargs)

    @override
    def step(
        self, state: WrappedState, action: PyTree, **kwargs
    ) -> tuple[WrappedState, Info]:
        return self.env.step(state, action, **kwargs)

    @override
    @cached_property
    def observation_space(self) -> spaces.Space:
        return self.env.observation_space

    @override
    @cached_property
    def action_space(self) -> spaces.Space:
        return self.env.action_space

    @override
    @property
    def unwrapped(self) -> Environment:
        return self.env.unwrapped

    def __getattr__(self, name):
        # __getattr__ only runs after normal lookup has failed.
        # - Dunder names (e.g. __setstate__, __deepcopy__) are protocol probes
        #   made by pickle/copy and similar machinery. Forwarding them would let
        #   the inner env answer for the wrapper, or recurse on half-built
        #   instances.
        # - "env" itself: if it is not set yet, forwarding would look up
        #   self.env again and recurse forever.
        if name == "env" or (name.startswith("__") and name.endswith("__")):
            raise AttributeError(name)
        return getattr(self.env, name)
@@ -0,0 +1,87 @@
1
+ Metadata-Version: 2.4
2
+ Name: jax-envelope
3
+ Version: 0.1.0
4
+ Summary: A JAX-native environment interface with powerful wrappers and adapters for popular RL environment suites
5
+ Project-URL: Homepage, https://github.com/keraJLi/envelope
6
+ Project-URL: Repository, https://github.com/keraJLi/envelope
7
+ Project-URL: Documentation, https://github.com/keraJLi/envelope#readme
8
+ Project-URL: Issues, https://github.com/keraJLi/envelope/issues
9
+ Project-URL: Changelog, https://github.com/keraJLi/envelope/releases
10
+ Author-email: Jarek Liesen <jarek.liesen@reuben.ox.ac.uk>
11
+ License: MIT
12
+ License-File: LICENSE
13
+ Keywords: deep-learning,environments,gymnasium,hardware-acceleration,jax,machine-learning,reinforcement-learning,vectorization
14
+ Classifier: Development Status :: 4 - Beta
15
+ Classifier: Intended Audience :: Developers
16
+ Classifier: Intended Audience :: Science/Research
17
+ Classifier: License :: OSI Approved :: MIT License
18
+ Classifier: Operating System :: OS Independent
19
+ Classifier: Programming Language :: Python :: 3
20
+ Classifier: Programming Language :: Python :: 3.12
21
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
22
+ Classifier: Typing :: Typed
23
+ Requires-Python: >=3.12
24
+ Requires-Dist: jax>=0.5.0
25
+ Description-Content-Type: text/markdown
26
+
27
+ # 💌 Envelope: a JAX-native environment interface
28
+ ```python
29
+ # Create environments from JAX-native suites you have installed, ...
30
+ env = envelope.create("gymnax::CartPole-v1")
31
+
32
+ # ... interact with the environments using a simple interface, ...
33
+ state, info = env.reset(key)
34
+ states, infos = jax.lax.scan(env.step, state, actions)
35
+ plt.plot(infos.reward.cumsum())
36
+
37
+ # ... and enjoy a powerful ecosystem of wrappers.
38
+ env = envelope.wrappers.AutoResetWrapper(env=env)
39
+ env = envelope.wrappers.VmapWrapper(env=env, batch_size=16)
40
+ env = envelope.wrappers.ObservationNormalizationWrapper(env=env)
41
+ ```
42
+
43
+ ## 🌍 Simple, expressive interaction!
44
+ * **Environments are pytrees**. Squish them through JAX transformations and trace their parameters.
45
+ * **Idiomatic jax-y interface** of `reset(key: Key) -> State, Info` and `step(state: State, action: PyTree) -> State, Info`. You can pass `step(...)` directly to `jax.lax.scan`!
46
+ * **Spaces are super simple**. No `Tuple`, `Dict` nonsense! There are two spaces: `Continuous` and `Discrete`, which you can compose into a `PyTreeSpace`.
47
+ * **Explicit episode truncation** lets you bootstrap value-function targets correctly.
48
+ * **No auto-reset** by default. Resetting every step can be expensive!
49
+
50
+ ## 💪 Powerful, composable wrappers!
51
+ * **Carry state across episodes** to track running statistics, for example to normalize observations.
52
+ * **Composable wrappers** can be stacked in any order. For example, `ObservationNormalizationWrapper` before vs. after `VmapWrapper` gives per-env vs. global normalization.
53
+ <!-- TODO: Add auto-reset behavior (including state injection) and optimistic resets once I implement them. -->
54
+
55
+ ## 🔌 Adapters for existing suites
56
+ | 📦 | # 🤖 | # 🌍 |
57
+ |------|------|------|
58
+ | [gymnax](https://github.com/RobertTLange/gymnax) | 🕺 | 24 |
59
+ | [brax](https://github.com/google/brax) | 🕺 | 12 |
60
+ | [jumanji](https://github.com/instadeepai/jumanji) | 🕺 / 👯 | 25 / 1 |
61
+ | [kinetix](https://github.com/flairox/kinetix) | 🕺 | 74 |
62
+ | [craftax](https://github.com/MichaelTMatthews/craftax) | 🕺 | 4 |
63
+ | [mujoco_playground](https://github.com/google-deepmind/mujoco_playground) | 🕺 | 54 |
64
+ | | | |
65
+ | Total | 🕺 / 👯 | 193 / 1 |
66
+
67
+ ```python
68
+ envelope.create("📦::🌍")
69
+ ```
70
+ lets you create environments from any of the above!
71
+
72
+ ## 📝 Testing
73
+ - **Default (no optional compat deps required)**: `uv run pytest -m "not compat"`
74
+ - **Compat suite (requires full compat dependency group)**:
75
+ - `uv sync --group compat`
76
+ - `uv run pytest -m compat`
77
+ - If any compat dependency is missing/broken, the run will fail fast with an error telling you what to install.
78
+
79
+ ## 🏗️ Installation
80
+ ```bash
81
+ pip install jax-envelope
82
+ ```
83
+
84
+ ## 💞 Related projects
85
+ * [stoax](https://github.com/EdanToledo/Stoa) is a very similar project that provides adapters and wrappers for the jumanji-like interface.
86
+ * Check out all the great suites we have adapters for! [gymnax](https://github.com/RobertTLange/gymnax), [brax](https://github.com/google/brax), [jumanji](https://github.com/instadeepai/jumanji), [kinetix](https://github.com/flairox/kinetix), [craftax](https://github.com/MichaelTMatthews/craftax), [mujoco_playground](https://github.com/google-deepmind/mujoco_playground).
87
+ * We will be adding support for [jaxmarl](https://github.com/flairox/jaxmarl) and [pgx](https://github.com/sotetsuk/pgx) in the future, as soon as we have figured out the best ever MARL interface for JAX!
@@ -0,0 +1,27 @@
1
+ envelope/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
+ envelope/environment.py,sha256=4T0IosoRX-6gVvZCKBUEitlM_5cAT43hrKQ0Hr53XGk,1735
3
+ envelope/spaces.py,sha256=fvJ3-ZI3iyFlpfEN1barTd9cbWvQ11nl-7Y45KgFizs,6433
4
+ envelope/struct.py,sha256=Sb8GsxZ7rFF5A5128oZWEQVnFK79xUFe6EaU2bghBOw,5158
5
+ envelope/typing.py,sha256=dcftDRNM0luCqHDEJ_qVmZnjUOEwn-HiFvwuV1z1Bn0,734
6
+ envelope/compat/__init__.py,sha256=6q2B2bfTIu7MlnIXc702ysqEnyURz-MEjTBMBCo8rdQ,3458
7
+ envelope/compat/brax_envelope.py,sha256=T0Xefrpgft16SdRJKMIjzdJq5tbocQ4jeI763PDzZ_4,3445
8
+ envelope/compat/craftax_envelope.py,sha256=t151fDCrgPWfynERXHYXiWBwCGAXbFcreuIyH7GCApg,3162
9
+ envelope/compat/gymnax_envelope.py,sha256=iau4CKFLVIJkRA9WILgOqkp3mXIAzkilKcrKF8vkld4,3790
10
+ envelope/compat/jumanji_envelope.py,sha256=_Tox7DfiyE-Z4drszPUe-arPurz6hMeFtRNfy8o9Ipc,4869
11
+ envelope/compat/kinetix_envelope.py,sha256=kPLWwSoQ49kNXuvfbStOtoGby9EkV7bKAetxUa1_YPA,6817
12
+ envelope/compat/mujoco_playground_envelope.py,sha256=UcYZwOPD27473wmDDb5n2eK46uPwOeh0LZOgyov4rww,3706
13
+ envelope/compat/navix_envelope.py,sha256=K7AGae9_kDETFAUuWH9sCCLDo5fuEFbc0eKTHcvSxek,3027
14
+ envelope/wrappers/autoreset_wrapper.py,sha256=3OUUNb4L6P4Ncn57Uy2ldoLFr1IKnPSWuzyjf20N0Y8,1299
15
+ envelope/wrappers/episode_statistics_wrapper.py,sha256=Mj5Ua7cLtBQYtFn3oQB9eJ7TclKwAFUvsQHIRkcQJvs,1509
16
+ envelope/wrappers/normalization.py,sha256=xHezXsb1J09D3IZACC1xxMy9NuUSBWqCyYH2wFAItPs,1811
17
+ envelope/wrappers/observation_normalization_wrapper.py,sha256=NNHz_THFe2eQUY0Pued9_hJztIGU6SU7kae_AKFbe4k,4248
18
+ envelope/wrappers/state_injection_wrapper.py,sha256=yk7he1zPaEdzjks_NDnYDC91e5bXTlCFPKFMwm-A7jk,3687
19
+ envelope/wrappers/timestep_wrapper.py,sha256=6jS-80AwnIIct3Ool9zq_iGOyfodNbs89NfuyYMfYIE,874
20
+ envelope/wrappers/truncation_wrapper.py,sha256=TzJqxjAjipk7pTk-drXRvQ41uR-jWsHlAgYJpbHI32M,1132
21
+ envelope/wrappers/vmap_envs_wrapper.py,sha256=TezSCeOf3oJKj8w7i_UaAA2LsoWkhhp835W545Q-waI,2561
22
+ envelope/wrappers/vmap_wrapper.py,sha256=zYMYVvuVQfRM7d2ygxxcavRNtkdzkqhArakE507tcU4,1587
23
+ envelope/wrappers/wrapper.py,sha256=ydolcZR9ogW3z4ui07845mbMue2XtEoEJ64tHCzAAtw,1504
24
+ jax_envelope-0.1.0.dist-info/METADATA,sha256=-ASxsWLgF9CwcQ-mp1NJAmdFDS4qzs8bekNJH291R00,4653
25
+ jax_envelope-0.1.0.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
26
+ jax_envelope-0.1.0.dist-info/licenses/LICENSE,sha256=VyF-MK-gY2_fZlhf8uEnE2y8ziIXK-w55GM12eOgXrQ,1069
27
+ jax_envelope-0.1.0.dist-info/RECORD,,
@@ -0,0 +1,4 @@
1
+ Wheel-Version: 1.0
2
+ Generator: hatchling 1.28.0
3
+ Root-Is-Purelib: true
4
+ Tag: py3-none-any