jax-envelope 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- envelope/__init__.py +0 -0
- envelope/compat/__init__.py +97 -0
- envelope/compat/brax_envelope.py +98 -0
- envelope/compat/craftax_envelope.py +86 -0
- envelope/compat/gymnax_envelope.py +91 -0
- envelope/compat/jumanji_envelope.py +127 -0
- envelope/compat/kinetix_envelope.py +194 -0
- envelope/compat/mujoco_playground_envelope.py +101 -0
- envelope/compat/navix_envelope.py +86 -0
- envelope/environment.py +64 -0
- envelope/spaces.py +205 -0
- envelope/struct.py +148 -0
- envelope/typing.py +23 -0
- envelope/wrappers/autoreset_wrapper.py +36 -0
- envelope/wrappers/episode_statistics_wrapper.py +47 -0
- envelope/wrappers/normalization.py +56 -0
- envelope/wrappers/observation_normalization_wrapper.py +114 -0
- envelope/wrappers/state_injection_wrapper.py +91 -0
- envelope/wrappers/timestep_wrapper.py +22 -0
- envelope/wrappers/truncation_wrapper.py +31 -0
- envelope/wrappers/vmap_envs_wrapper.py +77 -0
- envelope/wrappers/vmap_wrapper.py +51 -0
- envelope/wrappers/wrapper.py +57 -0
- jax_envelope-0.1.0.dist-info/METADATA +87 -0
- jax_envelope-0.1.0.dist-info/RECORD +27 -0
- jax_envelope-0.1.0.dist-info/WHEEL +4 -0
- jax_envelope-0.1.0.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,101 @@
|
|
|
1
|
+
import dataclasses
|
|
2
|
+
from functools import cached_property
|
|
3
|
+
from typing import Any, override
|
|
4
|
+
|
|
5
|
+
from jax import numpy as jnp
|
|
6
|
+
from mujoco_playground import registry
|
|
7
|
+
|
|
8
|
+
from envelope import spaces as envelope_spaces
|
|
9
|
+
from envelope.environment import Environment, Info, InfoContainer, State
|
|
10
|
+
from envelope.struct import static_field
|
|
11
|
+
from envelope.typing import Key, PyTree
|
|
12
|
+
|
|
13
|
+
_MAX_INT = int(jnp.iinfo(jnp.int32).max)
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
_MUJOCO_PLAYGROUND_DEFAULT_EPISODE_LENGTH = 1000
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class MujocoPlaygroundEnvelope(Environment):
    """Wrapper to convert a mujoco_playground environment to an envelope environment.

    Environments created via `from_name` have mujoco_playground's own episode limit
    pushed to the largest int32. The registry's original episode length is still
    reported by `default_max_steps`.
    """

    mujoco_playground_env: Any = static_field()
    _default_max_steps: int = static_field(
        default=_MUJOCO_PLAYGROUND_DEFAULT_EPISODE_LENGTH
    )

    @classmethod
    def from_name(
        cls, env_name: str, env_kwargs: dict[str, Any] | None = None
    ) -> "MujocoPlaygroundEnvelope":
        """Creates a MujocoPlaygroundEnvelope from a name and keyword arguments.

        Args:
            env_name: Name of the environment in the mujoco_playground registry.
            env_kwargs: Optional config overrides, passed to `config_overrides` of
                `mujoco_playground.registry.load`. The given mapping is not modified.

        Raises:
            ValueError: If `env_kwargs` contains `episode_length`; use
                TruncationWrapper for episode length control instead.
        """
        # Work on a copy so injecting `episode_length` below never mutates the
        # dictionary owned by the caller.
        env_kwargs = dict(env_kwargs or {})
        if "episode_length" in env_kwargs:
            raise ValueError(
                "Cannot override 'episode_length' directly. "
                "Use TruncationWrapper for episode length control."
            )

        # Remember the registry's episode length so it can be exposed through
        # `default_max_steps`.
        default_config = registry.get_default_config(env_name)
        default_max_steps = default_config.episode_length

        # Set episode_length to a very large value
        # (mujoco_playground uses int for episode_length, so we use max int instead of inf)
        env_kwargs["episode_length"] = _MAX_INT

        # Pass all env_kwargs as config_overrides (never empty at this point).
        env = registry.load(env_name, config_overrides=env_kwargs)
        return cls(mujoco_playground_env=env, _default_max_steps=default_max_steps)

    @property
    def default_max_steps(self) -> int:
        """Episode length captured from the registry config by `from_name`.

        Falls back to `_MUJOCO_PLAYGROUND_DEFAULT_EPISODE_LENGTH` for instances
        constructed directly.
        """
        return self._default_max_steps

    @staticmethod
    def _to_info(env_state: Any) -> InfoContainer:
        """Builds the envelope info for a mujoco_playground state.

        `terminated` is cast to bool for two reasons. Reset and step infos then have
        identical types, which `jax.lax.cond`-based wrappers require. And it can be
        combined with boolean `truncated` flags (`done` may be a float array).
        Every field of the state is also attached via `dataclasses.asdict`.
        """
        info = InfoContainer(
            obs=env_state.obs,
            reward=env_state.reward,
            terminated=jnp.asarray(env_state.done).astype(bool),
        )
        return info.update(**dataclasses.asdict(env_state))

    @override
    def reset(
        self, key: Key, state: State | None = None, **kwargs
    ) -> tuple[State, Info]:
        """Resets the wrapped environment.

        `state` and `kwargs` are accepted for compatibility with the `Environment`
        interface (wrappers such as AutoResetWrapper pass them) and are ignored.
        """
        env_state = self.mujoco_playground_env.reset(key)
        return env_state, self._to_info(env_state)

    @override
    def step(self, state: State, action: PyTree, **kwargs) -> tuple[State, Info]:
        """Steps the wrapped environment; extra `kwargs` are accepted and ignored."""
        next_state = self.mujoco_playground_env.step(state, action)
        return next_state, self._to_info(next_state)

    @override
    @cached_property
    def action_space(self) -> envelope_spaces.Space:
        # MuJoCo Playground actions are typically bounded [-1, 1]
        return envelope_spaces.Continuous.from_shape(
            low=-1.0, high=1.0, shape=(self.mujoco_playground_env.action_size,)
        )

    @override
    @cached_property
    def observation_space(self) -> envelope_spaces.Space:
        """Unbounded Continuous space(s) matching `observation_size`.

        `observation_size` may be an int, a shape tuple, or a PyTree of those; a
        PyTree yields a PyTreeSpace.
        """
        import jax

        def to_space(size):
            shape = (size,) if isinstance(size, int) else size
            return envelope_spaces.Continuous.from_shape(
                low=-jnp.inf, high=jnp.inf, shape=shape
            )

        def is_leaf(x):
            # An int or a tuple of ints describes one observation leaf.
            return isinstance(x, int) or (
                isinstance(x, tuple) and all(isinstance(i, int) for i in x)
            )

        space_tree = jax.tree.map(
            to_space, self.mujoco_playground_env.observation_size, is_leaf=is_leaf
        )
        if isinstance(space_tree, envelope_spaces.Space):
            return space_tree
        return envelope_spaces.PyTreeSpace(space_tree)
|
|
@@ -0,0 +1,86 @@
|
|
|
1
|
+
import dataclasses
|
|
2
|
+
from functools import cached_property
|
|
3
|
+
from typing import Any, override
|
|
4
|
+
|
|
5
|
+
import jax.numpy as jnp
|
|
6
|
+
import navix
|
|
7
|
+
from navix import spaces as navix_spaces
|
|
8
|
+
from navix.environments.environment import Environment as NavixEnv
|
|
9
|
+
|
|
10
|
+
from envelope import spaces as envelope_spaces
|
|
11
|
+
from envelope.environment import Environment, Info, InfoContainer, State
|
|
12
|
+
from envelope.typing import Key, PyTree
|
|
13
|
+
|
|
14
|
+
_NAVIX_DEFAULT_MAX_STEPS = 100
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class NavixEnvelope(Environment):
    """Wrapper to convert a Navix environment to an envelope environment."""

    navix_env: NavixEnv

    @classmethod
    def from_name(
        cls, env_name: str, env_kwargs: dict[str, Any] | None = None
    ) -> "NavixEnvelope":
        """Creates a NavixEnvelope via `navix.make`.

        Args:
            env_name: Navix environment id.
            env_kwargs: Extra keyword arguments for `navix.make`. The given mapping
                is not modified.

        Raises:
            ValueError: If `env_kwargs` contains `max_steps`; use
                TruncationWrapper for episode length control instead.
        """
        # Work on a copy so injecting `max_steps` below never mutates the
        # dictionary owned by the caller.
        env_kwargs = dict(env_kwargs or {})
        if "max_steps" in env_kwargs:
            raise ValueError(
                "Cannot override 'max_steps' directly. "
                "Use TruncationWrapper for episode length control."
            )
        # Effectively disable navix's own step limit; truncation is handled by
        # envelope wrappers instead.
        env_kwargs["max_steps"] = jnp.inf
        navix_env = navix.make(env_name, **env_kwargs)
        return cls(navix_env=navix_env)

    @property
    def default_max_steps(self) -> int:
        """Episode length suggested for this environment."""
        return _NAVIX_DEFAULT_MAX_STEPS

    @override
    def reset(
        self, key: Key, state: State | None = None, **kwargs
    ) -> tuple[State, Info]:
        """Resets the navix environment.

        `state` and `kwargs` are accepted for compatibility with the `Environment`
        interface (wrappers such as AutoResetWrapper pass them) and are ignored.
        """
        timestep = self.navix_env.reset(key)
        return timestep, convert_navix_to_envelope_info(timestep)

    @override
    def step(self, state: State, action: PyTree, **kwargs) -> tuple[State, Info]:
        """Steps the navix environment; extra `kwargs` are accepted and ignored."""
        timestep = self.navix_env.step(state, action)
        return timestep, convert_navix_to_envelope_info(timestep)

    @override
    @cached_property
    def action_space(self) -> envelope_spaces.Space:
        return convert_navix_to_envelope_space(self.navix_env.action_space)

    @override
    @cached_property
    def observation_space(self) -> envelope_spaces.Space:
        return convert_navix_to_envelope_space(self.navix_env.observation_space)
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
def convert_navix_to_envelope_info(nvx_timestep: navix.Timestep) -> InfoContainer:
    """Translates a navix Timestep into an envelope InfoContainer.

    `observation` and `reward` map to `obs` and `reward`. `step_type` is decoded
    into the `terminated` / `truncated` flags. Every other Timestep field (as
    produced by `dataclasses.asdict`) is attached as an extra entry.
    """
    fields = dataclasses.asdict(nvx_timestep)
    step_type = fields.pop("step_type")
    observation = fields.pop("observation")
    reward = fields.pop("reward")

    is_terminal = step_type == navix.StepType.TERMINATION
    is_truncated = step_type == navix.StepType.TRUNCATION

    core = InfoContainer(
        obs=observation,
        reward=reward,
        terminated=is_terminal,
        truncated=is_truncated,
    )
    return core.update(**fields)
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
def convert_navix_to_envelope_space(
    nvx_space: navix_spaces.Space,
) -> envelope_spaces.Space:
    """Maps a navix Discrete/Continuous space onto the matching envelope space.

    Bounds (`n`, or `minimum`/`maximum`) are cast to the navix space's dtype and
    broadcast to its shape.

    Raises:
        ValueError: If `nvx_space` is neither a navix Discrete nor Continuous space.
    """
    if isinstance(nvx_space, navix_spaces.Discrete):
        # NOTE(review): casting `n` to the space dtype could overflow for narrow
        # dtypes (e.g. n=256 with uint8) — confirm navix's convention for `n`.
        count = jnp.asarray(nvx_space.n).astype(nvx_space.dtype)
        return envelope_spaces.Discrete.from_shape(count, shape=nvx_space.shape)

    if isinstance(nvx_space, navix_spaces.Continuous):
        lower, upper = (
            jnp.asarray(bound).astype(nvx_space.dtype)
            for bound in (nvx_space.minimum, nvx_space.maximum)
        )
        return envelope_spaces.Continuous.from_shape(
            lower, upper, shape=nvx_space.shape
        )

    raise ValueError(f"Unsupported space type: {type(nvx_space)}")
|
envelope/environment.py
ADDED
|
@@ -0,0 +1,64 @@
|
|
|
1
|
+
from abc import ABC, abstractmethod
|
|
2
|
+
from dataclasses import field
|
|
3
|
+
from functools import cached_property
|
|
4
|
+
from typing import Protocol, runtime_checkable
|
|
5
|
+
|
|
6
|
+
from envelope import spaces
|
|
7
|
+
from envelope.struct import Container, FrozenPyTreeNode
|
|
8
|
+
from envelope.typing import Key, PyTree
|
|
9
|
+
|
|
10
|
+
# Names exported by `from envelope.environment import *`.
__all__ = [
    "Environment",
    "State",
    "Info",
    "InfoContainer",
]
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
@runtime_checkable
class Info(Protocol):
    """Structural interface for the information returned by `reset` and `step`.

    An object conforms if it exposes the core attributes `obs`, `reward`,
    `terminated` and `truncated`, plus an `update` method and `__getattr__` for
    additional fields. The protocol is runtime-checkable, so `isinstance` works.
    """

    # Core attributes shared by every info object.
    obs: PyTree
    reward: float
    terminated: bool
    truncated: bool

    def update(self, **changes: PyTree) -> "Info":
        """Return an Info with `changes` applied."""
        ...

    def __getattr__(self, name: str) -> PyTree:
        """Look up a field that is not one of the core attributes."""
        ...
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class InfoContainer(Container):
    """Concrete `Info` built on `Container`.

    Stores the four core fields. Any other keyword given to `update` is kept as an
    extra and is readable as an attribute. `truncated` defaults to False.
    """

    obs: PyTree
    reward: float
    terminated: bool
    truncated: bool = False
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
# Environment state is intentionally untyped: any PyTree may serve as a state,
# and environments are not required to use WrappedState.
State = PyTree
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class Environment(ABC, FrozenPyTreeNode):
    """
    Abstract base class for every envelope environment.

    Each environment owns its `State`, an opaque PyTree. Wrappers that stack
    environments should expose the wrapped environment's state as `inner_state`,
    alongside any fields of their own. `reset` may optionally be given a prior
    state (to persist information across episodes) and arbitrary `**kwargs` that
    wrappers or environments can consume.
    """

    @abstractmethod
    def reset(
        self, key: Key, state: State | None = None, **kwargs
    ) -> tuple[State, Info]:
        """Begin a new episode and return its initial state and info."""

    @abstractmethod
    def step(self, state: State, action: PyTree, **kwargs) -> tuple[State, Info]:
        """Apply `action` to `state` and return the resulting state and info."""

    @abstractmethod
    @cached_property
    def observation_space(self) -> spaces.Space:
        """Space of the observations this environment produces."""

    @abstractmethod
    @cached_property
    def action_space(self) -> spaces.Space:
        """Space of the actions this environment accepts."""

    @property
    def unwrapped(self) -> "Environment":
        """The innermost environment; for a base environment this is `self`."""
        return self
|
envelope/spaces.py
ADDED
|
@@ -0,0 +1,205 @@
|
|
|
1
|
+
from abc import ABC, abstractmethod
|
|
2
|
+
from functools import cached_property
|
|
3
|
+
from typing import override
|
|
4
|
+
|
|
5
|
+
import jax
|
|
6
|
+
from jax import numpy as jnp
|
|
7
|
+
|
|
8
|
+
from envelope.struct import FrozenPyTreeNode, static_field
|
|
9
|
+
from envelope.typing import Key, PyTree
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class Space(ABC, FrozenPyTreeNode):
    """Abstract base for spaces: sets you can sample from and test membership of."""

    @abstractmethod
    def sample(self, key: Key) -> PyTree:
        """Draw a random element of the space using the PRNG `key`."""

    @abstractmethod
    def contains(self, x: PyTree) -> bool:
        """Report whether `x` is an element of the space."""
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class Discrete(Space):
    """
    A discrete space with a given number of elements. `n` can be a scalar or an array.
    The shape and dtype of the space are inferred from `n`. Entry `i` of an element
    is an integer in `[0, n[i])`.

    Args:
        n: The number of elements in the space.
    """

    n: int | jax.Array

    @classmethod
    def from_shape(cls, n: int, shape: tuple[int, ...]) -> "Discrete":
        """Creates a space of the given `shape` where every entry takes `n` values."""
        return cls(n=jnp.full(shape, n, dtype=jnp.asarray(n).dtype))

    @property
    def shape(self) -> tuple[int, ...]:
        return jnp.asarray(self.n).shape

    @property
    def dtype(self):
        return jnp.asarray(self.n).dtype

    def sample(self, key: Key) -> jax.Array:
        """Samples integers uniformly from `[0, n)` elementwise."""
        return jax.random.randint(key, self.shape, 0, self.n, dtype=self.dtype)

    def contains(self, x: int | jax.Array) -> bool:
        """Returns whether every entry of `x` is an integer value in `[0, n)`.

        Floating-point inputs are accepted only when all entries are integral
        (e.g. `1.0`); a discrete space does not contain `1.5`.
        """
        x = jnp.asarray(x)
        in_space = jnp.all(x >= 0) & jnp.all(x < self.n)
        if jnp.issubdtype(x.dtype, jnp.floating):
            in_space = in_space & jnp.all(x == jnp.floor(x))
        return in_space

    def __repr__(self) -> str:
        return f"Discrete(shape={self.shape}, dtype={self.dtype}, n={self.n})"
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
class Continuous(Space):
    """
    A continuous space with a given lower and upper bound. `low` and `high` can be
    scalars or arrays. The shape and dtype of the space are inferred from `low` and
    `high`. Bounds may be infinite (e.g. `-inf`/`inf` for unbounded observations).

    Args:
        low: The lower bound of the space.
        high: The upper bound of the space.
    """

    low: float | jax.Array
    high: float | jax.Array

    @classmethod
    def from_shape(
        cls, low: float, high: float, shape: tuple[int, ...]
    ) -> "Continuous":
        """Creates a space of the given `shape` with every entry in `[low, high]`."""
        return cls(
            low=jnp.full(shape, low, dtype=jnp.asarray(low).dtype),
            high=jnp.full(shape, high, dtype=jnp.asarray(high).dtype),
        )

    @property
    def dtype(self):
        if jnp.asarray(self.low).dtype != jnp.asarray(self.high).dtype:
            raise ValueError("low and high must have the same dtype")

        return jnp.asarray(self.low).dtype

    @property
    def shape(self) -> tuple[int, ...]:
        if jnp.asarray(self.low).shape != jnp.asarray(self.high).shape:
            raise ValueError("low and high must have the same shape")

        return jnp.asarray(self.low).shape

    @override
    def sample(self, key: Key) -> jax.Array:
        """Draws a random element of the space.

        Entries with two finite bounds are sampled uniformly in `[low, high)`.
        Naively applying that formula to an infinite bound gives NaN
        (`-inf + u * inf`). So entries with an infinite bound are drawn instead
        from a standard normal (both bounds infinite), `low + Exponential(1)`
        (only `low` finite), or `high - Exponential(1)` (only `high` finite).
        """
        shape, dtype = self.shape, self.dtype
        low = jnp.asarray(self.low)
        high = jnp.asarray(self.high)

        # `key` itself drives the uniform draw, so fully bounded spaces sample
        # exactly `low + U * (high - low)`; the other draws use derived keys.
        uniform_sample = jax.random.uniform(key, shape, dtype)
        bounded = low + uniform_sample * (high - low)
        normal = jax.random.normal(jax.random.fold_in(key, 1), shape, dtype)
        exponential = jax.random.exponential(jax.random.fold_in(key, 2), shape, dtype)

        low_finite = jnp.isfinite(low)
        high_finite = jnp.isfinite(high)
        return jnp.where(
            low_finite & high_finite,
            bounded,
            jnp.where(
                low_finite,
                low + exponential,
                jnp.where(high_finite, high - exponential, normal),
            ),
        )

    @override
    def contains(self, x: jax.Array) -> bool:
        """Returns whether every entry of `x` lies in `[low, high]`."""
        return jnp.all((x >= jnp.asarray(self.low)) & (x <= jnp.asarray(self.high)))

    def __repr__(self) -> str:
        dtype_str = getattr(self.dtype, "__name__", str(self.dtype))
        return (
            f"Continuous(shape={self.shape}, dtype={dtype_str}, "
            f"low={self.low}, high={self.high})"
        )
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
class PyTreeSpace(Space):
    """A Space defined by a PyTree structure of other Spaces.

    Args:
        tree: A PyTree whose leaves are Space objects.

    Usage:
        space = PyTreeSpace({
            "action": Discrete(n=4),
            "obs": Continuous.from_shape(low=0.0, high=1.0, shape=(2,)),
        })
    """

    tree: PyTree

    @override
    def sample(self, key: Key) -> PyTree:
        """Samples every leaf space with its own subkey and rebuilds the tree."""
        leaves, treedef = jax.tree.flatten(
            self.tree, is_leaf=lambda x: isinstance(x, Space)
        )
        keys = jax.random.split(key, len(leaves))
        samples = [space.sample(key) for key, space in zip(keys, leaves)]
        return jax.tree.unflatten(treedef, samples)

    @override
    def contains(self, x: PyTree) -> bool:
        """Returns whether each leaf of `x` lies in its corresponding leaf space.

        `x` must have the same tree structure as `self.tree`, with values in place
        of the Space leaves.
        """
        # Pair each leaf space with the matching value and test containment.
        per_leaf = jax.tree.map(
            lambda space, xi: space.contains(xi),
            self.tree,
            x,
            is_leaf=lambda node: isinstance(node, Space),
        )
        return jnp.all(jnp.array(jax.tree.leaves(per_leaf)))

    def __repr__(self) -> str:
        """Return a string representation showing the tree structure."""
        return f"{self.__class__.__name__}({self.tree!r})"

    @property
    def shape(self) -> PyTree:
        """The tree of leaf-space shapes, mirroring `self.tree`."""
        return jax.tree.map(
            lambda space: space.shape,
            self.tree,
            is_leaf=lambda node: isinstance(node, Space),
        )
|
|
151
|
+
|
|
152
|
+
|
|
153
|
+
def batch_space(space: Space, batch_size: int) -> Space:
    """Returns `space` with a leading batch dimension of size `batch_size`.

    For a PyTreeSpace, batching is pushed down to every leaf space (recursively),
    so the result is again a PyTreeSpace. Any other space is wrapped in a
    BatchedSpace view.
    """
    if not isinstance(space, PyTreeSpace):
        return BatchedSpace(space=space, batch_size=batch_size)

    def _batch_leaf(leaf: Space) -> Space:
        return batch_space(leaf, batch_size)

    def _is_space(node) -> bool:
        return isinstance(node, Space)

    return PyTreeSpace(jax.tree.map(_batch_leaf, space.tree, is_leaf=_is_space))
|
|
162
|
+
|
|
163
|
+
|
|
164
|
+
class BatchedSpace(Space):
    """
    A view that adds a leading batch dimension to a base Space without
    materializing or broadcasting its parameters.

    Args:
        space: The per-element base space.
        batch_size: Size of the leading batch dimension (static).
    """

    space: Space
    batch_size: int = static_field()

    @staticmethod
    def _is_single_key(key: Key) -> bool:
        """Returns True if `key` is one PRNG key rather than a batch of keys.

        New-style typed keys (`jax.random.key`) are 0-d arrays. Legacy raw uint32
        keys (`jax.random.PRNGKey`) are 1-d, e.g. shape `(2,)` for the default
        threefry implementation.
        """
        if jax.dtypes.issubdtype(key.dtype, jax.dtypes.prng_key):
            return key.ndim == 0
        return key.ndim == 1

    def sample(self, key: Key) -> PyTree:
        """Samples `batch_size` elements of the base space.

        Accepts either a single PRNG key (split into `batch_size` subkeys) or a
        batch of keys whose leading dimension equals `batch_size`.

        Raises:
            ValueError: If a batch of keys has the wrong leading dimension.
        """
        if self._is_single_key(key):
            keys = jax.random.split(key, self.batch_size)
        else:
            if key.shape[0] != self.batch_size:
                raise ValueError(
                    f"sample key's leading dimension ({key.shape[0]}) must match "
                    f"batch_size ({self.batch_size})."
                )
            keys = key
        return jax.vmap(self.space.sample)(keys)

    def contains(self, x: PyTree) -> bool:
        """Returns whether every element along the leading axis of `x` is contained."""
        # x is expected to be batched on the leading dimension
        result = jax.vmap(self.space.contains)(x)
        return jnp.all(jnp.asarray(result))

    @cached_property
    def shape(self) -> PyTree:
        """Shape of the base space with `batch_size` prepended.

        If the base space reports a PyTree of shapes (e.g. a PyTreeSpace wrapped
        directly), the batch dimension is prepended to every shape tuple in it.
        """

        def is_shape(node) -> bool:
            # A shape is a tuple of integer-like dims (no nested containers).
            return isinstance(node, tuple) and all(
                hasattr(dim, "__index__") for dim in node
            )

        return jax.tree.map(
            lambda leaf_shape: (self.batch_size, *leaf_shape),
            self.space.shape,
            is_leaf=is_shape,
        )

    @property
    def dtype(self):
        return getattr(self.space, "dtype", None)

    def __repr__(self) -> str:
        return f"BatchedSpace(space={self.space!r}, batch_size={self.batch_size})"
|
envelope/struct.py
ADDED
|
@@ -0,0 +1,148 @@
|
|
|
1
|
+
import dataclasses
|
|
2
|
+
from dataclasses import KW_ONLY
|
|
3
|
+
from typing import Any, Iterable, Iterator, Mapping, Self, Tuple
|
|
4
|
+
|
|
5
|
+
import jax
|
|
6
|
+
|
|
7
|
+
from envelope.typing import PyTree
|
|
8
|
+
|
|
9
|
+
# Names exported by `from envelope.struct import *`.
__all__ = [
    "FrozenPyTreeNode",
    "field",
    "static_field",
    "Container",
]
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def field(*, pytree_node: bool = True, **kwargs):
    """
    Create a dataclass field tagged with pytree metadata.

    Args:
        pytree_node: True (the default) marks the field as a pytree child; False
            marks it as static (non-transformed) data.
        **kwargs: Forwarded to `dataclasses.field`. A `metadata` mapping, if
            given, is copied and extended with the `pytree_node` entry.
    """
    user_metadata = kwargs.pop("metadata", {})
    merged = dict(user_metadata or {})
    merged.update(pytree_node=pytree_node)
    return dataclasses.field(metadata=merged, **kwargs)
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def static_field(**kwargs):
    """Create a dataclass field that JAX treats as static rather than as a child.

    All keyword arguments are forwarded to `field`; this is equivalent to
    `field(pytree_node=False, **kwargs)`.
    """
    return field(**kwargs, pytree_node=False)
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class FrozenPyTreeNode:
|
|
28
|
+
"""
|
|
29
|
+
Frozen dataclass base that is a JAX pytree node.
|
|
30
|
+
|
|
31
|
+
Usage:
|
|
32
|
+
class Foo(FrozenPyTreeNode):
|
|
33
|
+
a: Any # pytree leaf
|
|
34
|
+
b: int = static_field() # static, not a leaf
|
|
35
|
+
|
|
36
|
+
x = Foo(a={"w": 1.0}, b=0)
|
|
37
|
+
y = x.replace(b=1)
|
|
38
|
+
"""
|
|
39
|
+
|
|
40
|
+
# Turn subclasses into frozen dataclasses and register with JAX.
|
|
41
|
+
def __init_subclass__(cls, *, dataclass_kwargs: dict[str, Any] | None = None, **kw):
|
|
42
|
+
super().__init_subclass__(**kw)
|
|
43
|
+
# Check if this specific class (not parent) has already been processed
|
|
44
|
+
if "__is_envelope_pytreenode__" in cls.__dict__:
|
|
45
|
+
return
|
|
46
|
+
opts = dict(frozen=True, eq=True, repr=True, slots=False)
|
|
47
|
+
if dataclass_kwargs:
|
|
48
|
+
opts.update(dataclass_kwargs)
|
|
49
|
+
dataclasses.dataclass(cls, **opts) # modify in place
|
|
50
|
+
cls.__is_envelope_pytreenode__ = True
|
|
51
|
+
|
|
52
|
+
data = []
|
|
53
|
+
static = []
|
|
54
|
+
for f in dataclasses.fields(cls):
|
|
55
|
+
if f.metadata.get("pytree_node", True):
|
|
56
|
+
data.append(f.name)
|
|
57
|
+
else:
|
|
58
|
+
static.append(f.name)
|
|
59
|
+
|
|
60
|
+
jax.tree_util.register_dataclass(cls, data, static)
|
|
61
|
+
|
|
62
|
+
# convenience
|
|
63
|
+
def replace(self, **changes):
|
|
64
|
+
return dataclasses.replace(self, **changes)
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
@jax.tree_util.register_pytree_node_class
@dataclasses.dataclass(frozen=True, eq=True, repr=True, slots=False)
class Container:
    """
    Frozen dataclass pytree with a fixed set of core fields plus open-ended extras.

    Subclasses declare core fields as usual and are turned into frozen dataclasses
    registered with JAX automatically. Any keyword passed to `update` that is not
    a core field is stored in `_extras` and is readable as an attribute. Both core
    fields and extras are pytree children.
    """

    _: KW_ONLY
    # Non-core entries by name; populated by `update`, read via attribute access.
    _extras: Mapping[str, PyTree] = dataclasses.field(default_factory=dict, repr=False)

    def __init_subclass__(cls, *, dataclass_kwargs: dict[str, Any] | None = None, **kw):
        super().__init_subclass__(**kw)
        if "__is_container_dataclass__" in cls.__dict__:
            return

        opts = dict(frozen=True, eq=True, repr=True, slots=False)
        if dataclass_kwargs:
            opts.update(dataclass_kwargs)

        dataclasses.dataclass(cls, **opts)
        cls.__is_container_dataclass__ = True
        jax.tree_util.register_pytree_node_class(cls)

    def __getattr__(self, name: str) -> PyTree:
        # Only called when normal lookup fails; serve extras by name.
        # bypass __getattr__ when accessing _extras to avoid recursion
        extras = object.__getattribute__(self, "_extras")
        if name in extras:
            return extras[name]
        self_name = type(self).__name__
        raise AttributeError(f"'{self_name}' object has no attribute '{name}'")

    def __dir__(self) -> Iterable[str]:
        core_names = {f.name for f in dataclasses.fields(self) if f.name != "_extras"}
        return sorted(set(super().__dir__()) | core_names | set(self._extras.keys()))

    def __iter__(self) -> Iterator[Tuple[str, PyTree]]:
        """Yields `(name, value)` pairs: core fields first, then extras."""
        for f in dataclasses.fields(self):
            if f.name == "_extras":
                continue
            yield (f.name, getattr(self, f.name))
        # extras
        for k, v in self._extras.items():
            yield (k, v)

    def update(self, **changes: PyTree) -> Self:
        """Returns a copy with core fields replaced and other keys merged into extras."""
        core_names = {f.name for f in dataclasses.fields(self) if f.name != "_extras"}
        core_updates: dict[str, PyTree] = {}
        extras_updates: dict[str, PyTree] = {}

        for k, v in changes.items():
            if k in core_names:
                core_updates[k] = v
            else:
                extras_updates[k] = v

        new = dataclasses.replace(self, **core_updates)
        new_extras = {**self._extras, **extras_updates}
        object.__setattr__(new, "_extras", new_extras)
        return new

    def tree_flatten(self) -> Tuple[Tuple[PyTree, ...], Tuple[Any, ...]]:
        """Flattens into core values followed by extras values (extras sorted by name).

        Extras are ordered by key, as JAX does for dicts, so the treedef does not
        depend on the order in which extras were added. Otherwise equal containers
        built by differently ordered `update` calls would have different treedefs
        and be rejected by e.g. `jax.lax.cond`.
        """
        core_fields = [f for f in dataclasses.fields(self) if f.name != "_extras"]
        core_keys = tuple(f.name for f in core_fields)
        core_vals = tuple(getattr(self, name) for name in core_keys)

        extras_keys = tuple(sorted(self._extras))
        extras_vals = tuple(self._extras[k] for k in extras_keys)

        children = core_vals + extras_vals
        aux_data = (self.__class__, core_keys, extras_keys)
        return children, aux_data

    @classmethod
    def tree_unflatten(cls, aux_data, children: Tuple[PyTree, ...]) -> Self:
        """Rebuilds a container (of the recorded concrete class) from its children."""
        actual_cls, core_keys, extras_keys = aux_data
        n_core = len(core_keys)

        core_vals = children[:n_core]
        extras_vals = children[n_core:]

        core_kwargs = dict(zip(core_keys, core_vals))
        extras = dict(zip(extras_keys, extras_vals))

        obj = actual_cls(**core_kwargs)
        object.__setattr__(obj, "_extras", extras)
        return obj
|
envelope/typing.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
from enum import Enum
|
|
2
|
+
from typing import Any, TypeAlias
|
|
3
|
+
|
|
4
|
+
import jax
|
|
5
|
+
|
|
6
|
+
PyTree: TypeAlias = Any
|
|
7
|
+
Key: TypeAlias = jax.Array
|
|
8
|
+
Array: TypeAlias = jax.Array
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class BatchKind(Enum):
    """
    How an environment represents a batch of instances.

    - VMAP: the instances are compatible with `jax.vmap`, for example because the
      environment is wrapped in a `VmapWrapper`.
    - NATIVE_POOL: the batch comes from a native pool, as when wrapping a non-JAX
      environment with built-in batching (e.g. those provided by envpool). Such
      environments cannot be vmapped, since that would break the native batching
      semantics.
    """

    VMAP = "vmap"
    NATIVE_POOL = "native_pool"
|
|
@@ -0,0 +1,36 @@
|
|
|
1
|
+
import jax
|
|
2
|
+
|
|
3
|
+
from envelope.environment import Info
|
|
4
|
+
from envelope.struct import field
|
|
5
|
+
from envelope.typing import Key, PyTree
|
|
6
|
+
from envelope.wrappers.wrapper import WrappedState, Wrapper
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class AutoResetWrapper(Wrapper):
    """Automatically resets the wrapped environment when an episode ends.

    When a step terminates or truncates the episode, the returned info keeps that
    step's `reward`, `terminated`, `truncated` and other fields. `next_obs` holds
    the observation produced by the step, and `obs` holds the first observation of
    the freshly reset episode. On all other steps, and after `reset`, `next_obs`
    equals `obs`.
    """

    class AutoResetState(WrappedState):
        # PRNG key reserved for the next automatic reset.
        reset_key: jax.Array = field()

    def reset(
        self, key: Key, state: PyTree | None = None, **kwargs
    ) -> tuple[WrappedState, Info]:
        """Resets the wrapped environment and stores a fresh key for the next auto-reset.

        If a previous `state` is given, its `inner_state` is forwarded to the
        wrapped environment's `reset`.
        """
        key, subkey = jax.random.split(key)
        inner_state = state.inner_state if state is not None else None
        inner_state, info = self.env.reset(key, inner_state, **kwargs)
        state = self.AutoResetState(inner_state=inner_state, reset_key=subkey)
        return state, info.update(next_obs=info.obs)

    def step(
        self, state: WrappedState, action: PyTree, **kwargs
    ) -> tuple[WrappedState, Info]:
        """Steps the wrapped environment, resetting it if the episode has ended."""
        inner_state, info_step = self.env.step(state.inner_state, action, **kwargs)
        # logical_or (unlike `|`) also accepts non-boolean flags such as float dones.
        done = jax.numpy.logical_or(info_step.terminated, info_step.truncated)

        state = self.AutoResetState(inner_state=inner_state, reset_key=state.reset_key)
        info = info_step.update(next_obs=info_step.obs)

        def _reset_branch():
            # Start a new episode but still report the transition that just
            # happened: only `obs` is replaced by the new episode's first
            # observation, so reward/terminated/truncated/next_obs survive.
            reset_state, reset_info = self.reset(state.reset_key, state)
            return reset_state, info.update(obs=reset_info.obs)

        def _continue_branch():
            return state, info

        return jax.lax.cond(done, _reset_branch, _continue_branch)
|