jax-envelope 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,101 @@
1
+ import dataclasses
2
+ from functools import cached_property
3
+ from typing import Any, override
4
+
5
+ from jax import numpy as jnp
6
+ from mujoco_playground import registry
7
+
8
+ from envelope import spaces as envelope_spaces
9
+ from envelope.environment import Environment, Info, InfoContainer, State
10
+ from envelope.struct import static_field
11
+ from envelope.typing import Key, PyTree
12
+
13
# Largest int32 value. mujoco_playground stores ``episode_length`` as an int,
# so this stands in for "no time limit" (``inf`` cannot be used).
_MAX_INT = int(jnp.iinfo(jnp.int32).max)


# Fallback for ``_default_max_steps`` when the wrapper is constructed directly
# instead of through ``from_name``.
_MUJOCO_PLAYGROUND_DEFAULT_EPISODE_LENGTH = 1000


class MujocoPlaygroundEnvelope(Environment):
    """Wrapper to convert a mujoco_playground environment to an envelope environment.

    The wrapped environment is loaded with an effectively unlimited episode
    length. The registry's configured ``episode_length`` is exposed as
    ``default_max_steps`` so that truncation can be handled by a
    ``TruncationWrapper`` instead.
    """

    mujoco_playground_env: Any = static_field()
    _default_max_steps: int = static_field(
        default=_MUJOCO_PLAYGROUND_DEFAULT_EPISODE_LENGTH
    )

    @classmethod
    def from_name(
        cls, env_name: str, env_kwargs: dict[str, Any] | None = None
    ) -> "MujocoPlaygroundEnvelope":
        """Creates a MujocoPlaygroundEnvelope from a name and keyword arguments.

        Args:
            env_name: Name of the environment in the mujoco_playground registry.
            env_kwargs: Passed as ``config_overrides`` to
                ``mujoco_playground.registry.load``. The caller's dict is not
                modified.

        Returns:
            The wrapped environment. ``default_max_steps`` is set to the
            registry's default ``episode_length``.

        Raises:
            ValueError: If ``env_kwargs`` contains ``episode_length``.
        """
        # Copy before injecting ``episode_length``. Mutating the caller's dict
        # would make a second call with the same dict trip the check below.
        env_kwargs = dict(env_kwargs or {})
        if "episode_length" in env_kwargs:
            raise ValueError(
                "Cannot override 'episode_length' directly. "
                "Use TruncationWrapper for episode length control."
            )

        # Remember the registry's episode length before replacing it.
        default_config = registry.get_default_config(env_name)
        default_max_steps = default_config.episode_length

        # Effectively disable the built-in limit (an int is required, so no inf).
        env_kwargs["episode_length"] = _MAX_INT

        # env_kwargs always holds at least episode_length here.
        env = registry.load(env_name, config_overrides=env_kwargs)
        return cls(mujoco_playground_env=env, _default_max_steps=default_max_steps)

    @property
    def default_max_steps(self) -> int:
        """Episode length configured in the mujoco_playground registry."""
        return self._default_max_steps

    @override
    def reset(
        self, key: Key, state: State | None = None, **kwargs
    ) -> tuple[State, Info]:
        """Reset the wrapped environment.

        ``state`` and ``kwargs`` are accepted to match ``Environment.reset``
        (``AutoResetWrapper`` passes the previous inner state positionally).
        They are ignored here.
        """
        env_state = self.mujoco_playground_env.reset(key)
        info = InfoContainer(obs=env_state.obs, reward=0.0, terminated=False)
        # Expose every field of the playground state (obs, reward, done, ...).
        info = info.update(**dataclasses.asdict(env_state))
        return env_state, info

    @override
    def step(self, state: State, action: PyTree, **kwargs) -> tuple[State, Info]:
        """Advance the wrapped environment by one step.

        ``kwargs`` is accepted to match ``Environment.step`` and is ignored.
        """
        state = self.mujoco_playground_env.step(state, action)
        # NOTE(review): mujoco_playground appears to report ``done`` as a float
        # (0.0 / 1.0). Casting to bool keeps ``terminated | truncated`` valid
        # downstream, and is a no-op if ``done`` is already bool.
        terminated = jnp.asarray(state.done).astype(bool)
        info = InfoContainer(obs=state.obs, reward=state.reward, terminated=terminated)
        info = info.update(**dataclasses.asdict(state))
        return state, info

    @override
    @cached_property
    def action_space(self) -> envelope_spaces.Space:
        """Box ``[-1, 1]`` with ``action_size`` dimensions."""
        # MuJoCo Playground actions are typically bounded [-1, 1]
        return envelope_spaces.Continuous.from_shape(
            low=-1.0, high=1.0, shape=(self.mujoco_playground_env.action_size,)
        )

    @override
    @cached_property
    def observation_space(self) -> envelope_spaces.Space:
        """Unbounded continuous space(s) matching ``observation_size``.

        ``observation_size`` may be an int, a shape tuple, or a pytree of
        those. A single leaf yields a plain space; otherwise a ``PyTreeSpace``.
        """
        import jax

        def to_space(size):
            shape = (size,) if isinstance(size, int) else size
            return envelope_spaces.Continuous.from_shape(
                low=-jnp.inf, high=jnp.inf, shape=shape
            )

        def is_leaf(x):
            # Treat ints and tuples of ints as shapes, not as pytree nodes.
            return isinstance(x, int) or (
                isinstance(x, tuple) and all(isinstance(i, int) for i in x)
            )

        space_tree = jax.tree.map(
            to_space, self.mujoco_playground_env.observation_size, is_leaf=is_leaf
        )
        if isinstance(space_tree, envelope_spaces.Space):
            return space_tree
        return envelope_spaces.PyTreeSpace(space_tree)
@@ -0,0 +1,86 @@
1
+ import dataclasses
2
+ from functools import cached_property
3
+ from typing import Any, override
4
+
5
+ import jax.numpy as jnp
6
+ import navix
7
+ from navix import spaces as navix_spaces
8
+ from navix.environments.environment import Environment as NavixEnv
9
+
10
+ from envelope import spaces as envelope_spaces
11
+ from envelope.environment import Environment, Info, InfoContainer, State
12
+ from envelope.typing import Key, PyTree
13
+
14
# Value reported by ``NavixEnvelope.default_max_steps``. The wrapped navix env
# itself is created with ``max_steps=inf``.
_NAVIX_DEFAULT_MAX_STEPS = 100


class NavixEnvelope(Environment):
    """Wrapper to convert a Navix environment to an envelope environment.

    The wrapped navix environment never truncates on its own; episode length is
    meant to be enforced by a ``TruncationWrapper`` (see ``default_max_steps``).
    """

    navix_env: NavixEnv

    @classmethod
    def from_name(
        cls, env_name: str, env_kwargs: dict[str, Any] | None = None
    ) -> "NavixEnvelope":
        """Create a NavixEnvelope via ``navix.make(env_name, **env_kwargs)``.

        The caller's ``env_kwargs`` dict is not modified.

        Raises:
            ValueError: If ``env_kwargs`` contains ``max_steps``.
        """
        # Copy so injecting ``max_steps`` does not leak into the caller's dict.
        env_kwargs = dict(env_kwargs or {})
        if "max_steps" in env_kwargs:
            raise ValueError(
                "Cannot override 'max_steps' directly. "
                "Use TruncationWrapper for episode length control."
            )
        env_kwargs["max_steps"] = jnp.inf
        navix_env = navix.make(env_name, **env_kwargs)
        return cls(navix_env=navix_env)

    @property
    def default_max_steps(self) -> int:
        """Suggested episode length for a ``TruncationWrapper``."""
        return _NAVIX_DEFAULT_MAX_STEPS

    @override
    def reset(
        self, key: Key, state: State | None = None, **kwargs
    ) -> tuple[State, Info]:
        """Reset the navix env; returns its ``Timestep`` as the state.

        ``state`` and ``kwargs`` are accepted to match ``Environment.reset``
        (``AutoResetWrapper`` passes the previous inner state positionally).
        They are ignored here.
        """
        timestep = self.navix_env.reset(key)
        return timestep, convert_navix_to_envelope_info(timestep)

    @override
    def step(self, state: State, action: PyTree, **kwargs) -> tuple[State, Info]:
        """Step the navix env; ``kwargs`` is accepted for API compatibility only."""
        timestep = self.navix_env.step(state, action)
        return timestep, convert_navix_to_envelope_info(timestep)

    @override
    @cached_property
    def action_space(self) -> envelope_spaces.Space:
        """The navix action space converted to an envelope space."""
        return convert_navix_to_envelope_space(self.navix_env.action_space)

    @override
    @cached_property
    def observation_space(self) -> envelope_spaces.Space:
        """The navix observation space converted to an envelope space."""
        return convert_navix_to_envelope_space(self.navix_env.observation_space)
59
+
60
+
61
def convert_navix_to_envelope_info(nvx_timestep: navix.Timestep) -> InfoContainer:
    """Translate a navix ``Timestep`` into an envelope ``InfoContainer``.

    ``observation`` becomes ``obs``, ``reward`` is kept, and ``step_type`` is
    split into ``terminated`` / ``truncated`` flags. Every remaining timestep
    field is attached as an extra entry.
    """
    remaining = dataclasses.asdict(nvx_timestep)
    kind = remaining.pop("step_type")
    core = {
        "obs": remaining.pop("observation"),
        "reward": remaining.pop("reward"),
        "terminated": kind == navix.StepType.TERMINATION,
        "truncated": kind == navix.StepType.TRUNCATION,
    }
    return InfoContainer(**core).update(**remaining)
72
+
73
+
74
+ def convert_navix_to_envelope_space(
75
+ nvx_space: navix_spaces.Space,
76
+ ) -> envelope_spaces.Space:
77
+ if isinstance(nvx_space, navix_spaces.Discrete):
78
+ n = jnp.asarray(nvx_space.n).astype(nvx_space.dtype)
79
+ return envelope_spaces.Discrete.from_shape(n, shape=nvx_space.shape)
80
+
81
+ elif isinstance(nvx_space, navix_spaces.Continuous):
82
+ low = jnp.asarray(nvx_space.minimum).astype(nvx_space.dtype)
83
+ high = jnp.asarray(nvx_space.maximum).astype(nvx_space.dtype)
84
+ return envelope_spaces.Continuous.from_shape(low, high, shape=nvx_space.shape)
85
+
86
+ raise ValueError(f"Unsupported space type: {type(nvx_space)}")
@@ -0,0 +1,64 @@
1
+ from abc import ABC, abstractmethod
2
+ from dataclasses import field
3
+ from functools import cached_property
4
+ from typing import Protocol, runtime_checkable
5
+
6
+ from envelope import spaces
7
+ from envelope.struct import Container, FrozenPyTreeNode
8
+ from envelope.typing import Key, PyTree
9
+
10
__all__ = ["Environment", "State", "Info", "InfoContainer"]


@runtime_checkable
class Info(Protocol):
    """Structural interface for the information returned by ``reset``/``step``.

    Any object exposing the four attributes below, an ``update`` method, and
    attribute-style access to additional entries satisfies this protocol.
    ``InfoContainer`` is the concrete implementation used in this package.
    """

    # Core entries every info object carries.
    obs: PyTree
    reward: float
    terminated: bool
    truncated: bool

    def update(self, **changes: PyTree) -> "Info":
        """Return a new info object with ``changes`` applied."""
        ...

    def __getattr__(self, name: str) -> PyTree:
        """Look up an additional (non-core) entry by attribute name."""
        ...
22
+
23
+
24
class InfoContainer(Container):
    """Concrete ``Info`` built on ``Container``.

    ``obs``, ``reward`` and ``terminated`` must be supplied; ``truncated``
    defaults to ``False``. Keys passed to ``update`` that are not one of these
    fields are stored as extra entries.
    """

    obs: PyTree
    reward: float
    terminated: bool
    truncated: bool = field(default=False)


# Environment state is deliberately a plain PyTree alias: each environment owns
# its state layout and is not required to use WrappedState.
State = PyTree
33
+
34
+
35
class Environment(ABC, FrozenPyTreeNode):
    """
    Abstract base class shared by every envelope environment.

    The environment state is an opaque PyTree whose layout each concrete
    environment chooses. Wrappers that stack on another environment should keep
    the wrapped environment's state under ``inner_state`` and may add fields of
    their own. ``reset`` can receive a previous state (to persist information
    across episodes) plus arbitrary keyword arguments for wrappers or
    environments to interpret.
    """

    @abstractmethod
    def reset(
        self, key: Key, state: State | None = None, **kwargs
    ) -> tuple[State, Info]:
        """Begin a new episode using ``key``; return ``(state, info)``."""

    @abstractmethod
    def step(self, state: State, action: PyTree, **kwargs) -> tuple[State, Info]:
        """Apply ``action`` to ``state``; return the new ``(state, info)``."""

    @abstractmethod
    @cached_property
    def observation_space(self) -> spaces.Space:
        """Space describing the observations this environment emits."""

    @abstractmethod
    @cached_property
    def action_space(self) -> spaces.Space:
        """Space describing the actions ``step`` accepts."""

    @property
    def unwrapped(self) -> "Environment":
        """Underlying environment; for a non-wrapper this is ``self``."""
        return self
envelope/spaces.py ADDED
@@ -0,0 +1,205 @@
1
+ from abc import ABC, abstractmethod
2
+ from functools import cached_property
3
+ from typing import override
4
+
5
+ import jax
6
+ from jax import numpy as jnp
7
+
8
+ from envelope.struct import FrozenPyTreeNode, static_field
9
+ from envelope.typing import Key, PyTree
10
+
11
+
12
class Space(ABC, FrozenPyTreeNode):
    """Abstract base for spaces: sets that can be sampled and membership-tested."""

    @abstractmethod
    def sample(self, key: Key) -> PyTree:
        """Draw a random element of this space using PRNG ``key``."""

    @abstractmethod
    def contains(self, x: PyTree) -> bool:
        """Report whether ``x`` belongs to this space."""
18
+
19
+
20
class Discrete(Space):
    """
    A discrete space of integers in ``[0, n)``. `n` can be a scalar or an array
    (one bound per element). The shape and dtype of the space are inferred from `n`.

    Args:
        n: The number of elements in the space.
    """

    n: int | jax.Array

    @classmethod
    def from_shape(cls, n: int, shape: tuple[int, ...]) -> "Discrete":
        """Create a space of the given ``shape`` whose every element has ``n`` values."""
        return cls(n=jnp.full(shape, n, dtype=jnp.asarray(n).dtype))

    @property
    def shape(self) -> tuple[int, ...]:
        return jnp.asarray(self.n).shape

    @property
    def dtype(self):
        return jnp.asarray(self.n).dtype

    @override
    def sample(self, key: Key) -> jax.Array:
        """Draw uniform integers in ``[0, n)`` with this space's shape and dtype."""
        return jax.random.randint(key, self.shape, 0, self.n, dtype=self.dtype)

    @override
    def contains(self, x: int | jax.Array) -> bool:
        """Return whether every element of ``x`` is an integer in ``[0, n)``.

        Floating-point inputs are accepted only when they hold whole numbers
        (e.g. ``2.0``); values such as ``1.5`` are rejected.
        """
        x = jnp.asarray(x)
        in_bounds = jnp.all(x >= 0) & jnp.all(x < self.n)
        # dtype is static, so this branch is resolved at trace time.
        if jnp.issubdtype(x.dtype, jnp.inexact):
            in_bounds = in_bounds & jnp.all(x == jnp.round(x))
        return in_bounds

    def __repr__(self) -> str:
        return f"Discrete(shape={self.shape}, dtype={self.dtype}, n={self.n})"
51
+
52
+
53
class Continuous(Space):
    """
    A continuous space with a given lower and upper bound. `low` and `high` can be
    scalars or arrays. The shape and dtype of the space are inferred from `low` and
    `high`.

    Args:
        low: The lower bound of the space.
        high: The upper bound of the space.
    """

    low: float | jax.Array
    high: float | jax.Array

    @classmethod
    def from_shape(
        cls, low: float, high: float, shape: tuple[int, ...]
    ) -> "Continuous":
        """Create a space of ``shape`` with both bounds broadcast to that shape.

        ``low`` and ``high`` are promoted to one common dtype. Without this,
        e.g. ``low=0, high=1.0`` would give int/float bounds, making ``dtype``
        raise and ``sample`` fail.
        """
        dtype = jnp.result_type(low, high)
        return cls(
            low=jnp.full(shape, low, dtype=dtype),
            high=jnp.full(shape, high, dtype=dtype),
        )

    @property
    def dtype(self):
        if jnp.asarray(self.low).dtype != jnp.asarray(self.high).dtype:
            raise ValueError("low and high must have the same dtype")

        return jnp.asarray(self.low).dtype

    @property
    def shape(self) -> tuple[int, ...]:
        if jnp.asarray(self.low).shape != jnp.asarray(self.high).shape:
            raise ValueError("low and high must have the same shape")

        return jnp.asarray(self.low).shape

    @override
    def sample(self, key: Key) -> jax.Array:
        """Draw uniformly from ``[low, high)``.

        With infinite bounds the affine map below produces non-finite values.
        """
        uniform_sample = jax.random.uniform(key, self.shape, self.dtype)
        return self.low + uniform_sample * (self.high - self.low)

    @override
    def contains(self, x: jax.Array) -> bool:
        """Return whether every element of ``x`` lies in ``[low, high]``."""
        return jnp.all((x >= jnp.asarray(self.low)) & (x <= jnp.asarray(self.high)))

    def __repr__(self) -> str:
        dtype_str = getattr(self.dtype, "__name__", str(self.dtype))
        return (
            f"Continuous(shape={self.shape}, dtype={dtype_str}, "
            f"low={self.low}, high={self.high})"
        )
103
+
104
+
105
class PyTreeSpace(Space):
    """A Space defined by a PyTree structure of other Spaces.

    Args:
        tree: A PyTree whose leaves are Space objects.

    Usage:
        space = PyTreeSpace({
            "action": Discrete(n=jnp.int32(4)),
            "obs": Continuous.from_shape(low=0.0, high=1.0, shape=(2,)),
        })
    """

    tree: PyTree

    @override
    def sample(self, key: Key) -> PyTree:
        """Sample every leaf space with an independent subkey."""
        leaves, treedef = jax.tree.flatten(
            self.tree, is_leaf=lambda x: isinstance(x, Space)
        )
        keys = jax.random.split(key, len(leaves))
        samples = [space.sample(key) for key, space in zip(keys, leaves)]
        return jax.tree.unflatten(treedef, samples)

    @override
    def contains(self, x: PyTree) -> bool:
        """Return whether each leaf of ``x`` lies in its matching leaf space.

        ``x`` must share the structure of ``tree``.
        """
        # Use tree.map to check containment for each space-value pair
        contains = jax.tree.map(
            lambda space, xi: space.contains(xi),
            self.tree,
            x,
            is_leaf=lambda node: isinstance(node, Space),
        )
        return jnp.all(jnp.array(jax.tree.leaves(contains)))

    def __repr__(self) -> str:
        """Return a string representation showing the tree structure."""
        return f"{self.__class__.__name__}({self.tree!r})"

    @property
    def shape(self) -> PyTree:
        """PyTree of the leaf spaces' shapes."""
        return jax.tree.map(
            lambda space: space.shape,
            self.tree,
            is_leaf=lambda node: isinstance(node, Space),
        )

    @property
    def dtype(self) -> PyTree:
        """PyTree of the leaf spaces' dtypes (``None`` for leaves without one).

        Mirrors ``shape`` so PyTreeSpace exposes the same attributes as its
        leaf spaces.
        """
        return jax.tree.map(
            lambda space: getattr(space, "dtype", None),
            self.tree,
            is_leaf=lambda node: isinstance(node, Space),
        )
151
+
152
+
153
def batch_space(space: Space, batch_size: int) -> Space:
    """Add a leading ``batch_size`` dimension to ``space``.

    A ``PyTreeSpace`` is batched leaf by leaf and stays a ``PyTreeSpace``. Any
    other space is wrapped in a single ``BatchedSpace``.
    """
    if not isinstance(space, PyTreeSpace):
        return BatchedSpace(space=space, batch_size=batch_size)

    def _batch_leaf(leaf: Space) -> Space:
        return batch_space(leaf, batch_size)

    per_leaf = jax.tree.map(
        _batch_leaf, space.tree, is_leaf=lambda node: isinstance(node, Space)
    )
    return PyTreeSpace(per_leaf)
162
+
163
+
164
class BatchedSpace(Space):
    """
    A view that adds a leading batch dimension to a base Space without
    materializing or broadcasting its parameters.
    """

    space: Space
    batch_size: int = static_field()

    @override
    def sample(self, key: Key) -> PyTree:
        """Draw ``batch_size`` independent samples from the base space.

        ``key`` may be a single PRNG key, which is split into ``batch_size``
        keys. It may also be a batch of keys whose leading dimension equals
        ``batch_size``. Both legacy raw ``uint32`` keys (single key has ndim 1)
        and typed keys from ``jax.random.key`` (single key has ndim 0) are
        supported.

        Raises:
            ValueError: If a batch of keys has the wrong leading dimension.
        """
        if jax.dtypes.issubdtype(key.dtype, jax.dtypes.prng_key):
            is_single_key = key.ndim == 0
        else:
            is_single_key = key.ndim == 1

        if is_single_key:
            keys = jax.random.split(key, self.batch_size)
        else:
            if key.shape[0] != self.batch_size:
                raise ValueError(
                    f"sample key's leading dimension ({key.shape[0]}) must match "
                    f"batch_size ({self.batch_size})."
                )
            keys = key
        return jax.vmap(self.space.sample)(keys)

    @override
    def contains(self, x: PyTree) -> bool:
        """Return whether every batch element of ``x`` lies in the base space."""
        # x is expected to be batched on the leading dimension
        result = jax.vmap(self.space.contains)(x)
        return jnp.all(jnp.asarray(result))

    @cached_property
    def shape(self) -> PyTree:
        inner_shape = self.space.shape
        # For tuple shapes (leaf spaces), prepend batch dimension.
        # PyTree shapes are handled by wrapping leaves with BatchedSpace via batch_space.
        if isinstance(inner_shape, tuple):
            return (self.batch_size,) + inner_shape
        return inner_shape

    @property
    def dtype(self):
        return getattr(self.space, "dtype", None)

    def __repr__(self) -> str:
        return f"BatchedSpace(space={self.space!r}, batch_size={self.batch_size})"
envelope/struct.py ADDED
@@ -0,0 +1,148 @@
1
+ import dataclasses
2
+ from dataclasses import KW_ONLY
3
+ from typing import Any, Iterable, Iterator, Mapping, Self, Tuple
4
+
5
+ import jax
6
+
7
+ from envelope.typing import PyTree
8
+
9
# Public API of this module.
__all__ = [
    "FrozenPyTreeNode",
    "field",
    "static_field",
    "Container",
]
10
+
11
+
12
def field(*, pytree_node: bool = True, **kwargs):
    """
    Build a ``dataclasses.field`` tagged with a ``pytree_node`` metadata flag.

    Pass ``pytree_node=False`` for static (non-transformed) fields. Any
    user-supplied ``metadata`` is copied and kept, but its ``pytree_node`` key
    is overwritten by the argument.
    """
    merged = dict(kwargs.pop("metadata", None) or ())
    merged.update(pytree_node=pytree_node)
    return dataclasses.field(metadata=merged, **kwargs)
20
+
21
+
22
def static_field(**kwargs):
    """Declare a static (non-pytree) field; same as ``field(pytree_node=False, ...)``."""
    return field(**kwargs, pytree_node=False)
25
+
26
+
27
+ class FrozenPyTreeNode:
28
+ """
29
+ Frozen dataclass base that is a JAX pytree node.
30
+
31
+ Usage:
32
+ class Foo(FrozenPyTreeNode):
33
+ a: Any # pytree leaf
34
+ b: int = static_field() # static, not a leaf
35
+
36
+ x = Foo(a={"w": 1.0}, b=0)
37
+ y = x.replace(b=1)
38
+ """
39
+
40
+ # Turn subclasses into frozen dataclasses and register with JAX.
41
+ def __init_subclass__(cls, *, dataclass_kwargs: dict[str, Any] | None = None, **kw):
42
+ super().__init_subclass__(**kw)
43
+ # Check if this specific class (not parent) has already been processed
44
+ if "__is_envelope_pytreenode__" in cls.__dict__:
45
+ return
46
+ opts = dict(frozen=True, eq=True, repr=True, slots=False)
47
+ if dataclass_kwargs:
48
+ opts.update(dataclass_kwargs)
49
+ dataclasses.dataclass(cls, **opts) # modify in place
50
+ cls.__is_envelope_pytreenode__ = True
51
+
52
+ data = []
53
+ static = []
54
+ for f in dataclasses.fields(cls):
55
+ if f.metadata.get("pytree_node", True):
56
+ data.append(f.name)
57
+ else:
58
+ static.append(f.name)
59
+
60
+ jax.tree_util.register_dataclass(cls, data, static)
61
+
62
+ # convenience
63
+ def replace(self, **changes):
64
+ return dataclasses.replace(self, **changes)
65
+
66
+
67
@jax.tree_util.register_pytree_node_class
@dataclasses.dataclass(frozen=True, eq=True, repr=True, slots=False)
class Container:
    """Frozen dataclass pytree with declared core fields plus free-form extras.

    Subclasses declare core fields as annotations. ``__init_subclass__`` turns
    each subclass into a frozen dataclass and registers it as a pytree node.
    Keys passed to ``update`` that are not core fields are stored in
    ``_extras`` and can be read back as attributes. Core fields and extras are
    both pytree children.
    """

    _: KW_ONLY
    # Entries beyond the declared fields. Keyword-only with a default so that
    # subclasses can still declare required (no-default) fields; hidden from repr.
    _extras: Mapping[str, PyTree] = dataclasses.field(default_factory=dict, repr=False)

    def __init_subclass__(cls, *, dataclass_kwargs: dict[str, Any] | None = None, **kw):
        """Make each subclass a frozen dataclass and register it as a pytree.

        Args:
            dataclass_kwargs: Optional overrides merged into the default
                ``dataclasses.dataclass`` options.
        """
        super().__init_subclass__(**kw)
        # Skip a class whose own namespace is already marked as processed.
        if "__is_container_dataclass__" in cls.__dict__:
            return

        opts = dict(frozen=True, eq=True, repr=True, slots=False)
        if dataclass_kwargs:
            opts.update(dataclass_kwargs)

        dataclasses.dataclass(cls, **opts)
        cls.__is_container_dataclass__ = True
        # Registration uses tree_flatten / tree_unflatten defined below.
        jax.tree_util.register_pytree_node_class(cls)

    def __getattr__(self, name: str) -> PyTree:
        """Fallback attribute lookup that serves entries stored in ``_extras``.

        Raises:
            AttributeError: If ``name`` is neither a real attribute nor an extra.
        """
        # bypass __getattr__ when accessing _extras to avoid recursion
        extras = object.__getattribute__(self, "_extras")
        if name in extras:
            return extras[name]
        self_name = type(self).__name__
        raise AttributeError(f"'{self_name}' object has no attribute '{name}'")

    def __dir__(self) -> Iterable[str]:
        """List normal attributes plus core field names and extra keys, sorted."""
        core_names = {f.name for f in dataclasses.fields(self) if f.name != "_extras"}
        return sorted(set(super().__dir__()) | core_names | set(self._extras.keys()))

    def __iter__(self) -> Iterator[Tuple[str, PyTree]]:
        """Yield ``(name, value)`` pairs: core fields first, then extras."""
        for f in dataclasses.fields(self):
            if f.name == "_extras":
                continue
            yield (f.name, getattr(self, f.name))
        # extras
        for k, v in self._extras.items():
            yield (k, v)

    def update(self, **changes: PyTree) -> Self:
        """Return a new container with ``changes`` applied.

        Keys naming a core field replace that field. Any other key is added to
        (or overwrites within) a copy of ``_extras``. ``self`` is not modified.
        """
        core_names = {f.name for f in dataclasses.fields(self) if f.name != "_extras"}
        core_updates: dict[str, PyTree] = {}
        extras_updates: dict[str, PyTree] = {}

        for k, v in changes.items():
            if k in core_names:
                core_updates[k] = v
            else:
                extras_updates[k] = v

        new = dataclasses.replace(self, **core_updates)
        new_extras = {**self._extras, **extras_updates}
        # The instance is frozen, so bypass the dataclass __setattr__ guard.
        object.__setattr__(new, "_extras", new_extras)
        return new

    def tree_flatten(self) -> Tuple[Tuple[PyTree, ...], Tuple[Any, ...]]:
        """Flatten to ``(core values + extra values, (class, core keys, extra keys))``.

        Extra keys live in the aux data, so containers with different extra
        keys (or key order) have different tree structures.
        """
        core_fields = [f for f in dataclasses.fields(self) if f.name != "_extras"]
        core_keys = tuple(f.name for f in core_fields)
        core_vals = tuple(getattr(self, name) for name in core_keys)

        extras_keys = tuple(self._extras.keys())
        extras_vals = tuple(self._extras[k] for k in extras_keys)

        children = core_vals + extras_vals
        aux_data = (self.__class__, core_keys, extras_keys)
        return children, aux_data

    @classmethod
    def tree_unflatten(cls, aux_data, children: Tuple[PyTree, ...]) -> Self:
        """Rebuild a container from the output of ``tree_flatten``."""
        actual_cls, core_keys, extras_keys = aux_data
        n_core = len(core_keys)

        # Children are laid out as core values followed by extra values.
        core_vals = children[:n_core]
        extras_vals = children[n_core:]

        core_kwargs = dict(zip(core_keys, core_vals))
        extras = dict(zip(extras_keys, extras_vals))

        obj = actual_cls(**core_kwargs)
        object.__setattr__(obj, "_extras", extras)
        return obj
envelope/typing.py ADDED
@@ -0,0 +1,23 @@
1
+ from enum import Enum
2
+ from typing import Any, TypeAlias
3
+
4
+ import jax
5
+
6
+ PyTree: TypeAlias = Any
7
+ Key: TypeAlias = jax.Array
8
+ Array: TypeAlias = jax.Array
9
+
10
+
11
+ class BatchKind(Enum):
12
+ """
13
+ Batch semantics for environments.
14
+ - VMAP: environment represents instances compatible with `jax.vmap`, for example by
15
+ wrapping it in a `VmapWrapper`.
16
+ - NATIVE_POOL: environment represents a batch of instances via a native pool. This
17
+ is the case when it is wrapping a non-jax-based environment that supports native
18
+ batching, for example those provided by envpool. Environments in this mode cannot
19
+ be vmapped, as it would break the native batching semantics.
20
+ """
21
+
22
+ VMAP = "vmap"
23
+ NATIVE_POOL = "native_pool"
@@ -0,0 +1,36 @@
1
+ import jax
2
+
3
+ from envelope.environment import Info
4
+ from envelope.struct import field
5
+ from envelope.typing import Key, PyTree
6
+ from envelope.wrappers.wrapper import WrappedState, Wrapper
7
+
8
+
9
class AutoResetWrapper(Wrapper):
    """Reset the wrapped environment automatically when an episode ends.

    After a step where ``terminated | truncated`` holds, the returned state and
    ``obs`` come from a fresh reset. ``reward``, ``terminated`` and
    ``truncated`` still describe the step that ended the episode, and
    ``next_obs`` holds that step's final observation (useful for
    bootstrapping). On ordinary steps ``next_obs`` equals ``obs``.
    """

    class AutoResetState(WrappedState):
        # PRNG key consumed by the next automatic reset; replaced on every reset.
        reset_key: jax.Array = field()

    def reset(
        self, key: Key, state: PyTree | None = None, **kwargs
    ) -> tuple[WrappedState, Info]:
        """Reset the inner env and reserve a fresh key for the next auto-reset.

        If ``state`` is given, its ``inner_state`` is forwarded to the inner
        environment's ``reset``.
        """
        key, subkey = jax.random.split(key)
        inner_state = state.inner_state if state is not None else None
        inner_state, info = self.env.reset(key, inner_state, **kwargs)
        state = self.AutoResetState(inner_state=inner_state, reset_key=subkey)
        return state, info.update(next_obs=info.obs)

    def step(
        self, state: WrappedState, action: PyTree, **kwargs
    ) -> tuple[WrappedState, Info]:
        """Step the inner env and reset it (inside ``lax.cond``) if it finished.

        ``done`` must be a scalar boolean, as required by ``jax.lax.cond``.
        """
        inner_state, info_step = self.env.step(state.inner_state, action, **kwargs)
        done = info_step.terminated | info_step.truncated

        state = self.AutoResetState(inner_state=inner_state, reset_key=state.reset_key)
        info = info_step.update(next_obs=info_step.obs)

        def _reset_branch():
            reset_state, reset_info = self.reset(state.reset_key, state)
            # Keep the outcome of the finishing step so callers can see that the
            # episode ended. Expose its last observation as next_obs; obs and
            # state come from the reset.
            reset_info = reset_info.update(
                reward=info_step.reward,
                terminated=info_step.terminated,
                truncated=info_step.truncated,
                next_obs=info_step.obs,
            )
            return reset_state, reset_info

        state, info = jax.lax.cond(
            done,
            _reset_branch,
            lambda: (state, info),
        )
        return state, info