jax-envelope 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- envelope/__init__.py +0 -0
- envelope/compat/__init__.py +97 -0
- envelope/compat/brax_envelope.py +98 -0
- envelope/compat/craftax_envelope.py +86 -0
- envelope/compat/gymnax_envelope.py +91 -0
- envelope/compat/jumanji_envelope.py +127 -0
- envelope/compat/kinetix_envelope.py +194 -0
- envelope/compat/mujoco_playground_envelope.py +101 -0
- envelope/compat/navix_envelope.py +86 -0
- envelope/environment.py +64 -0
- envelope/spaces.py +205 -0
- envelope/struct.py +148 -0
- envelope/typing.py +23 -0
- envelope/wrappers/autoreset_wrapper.py +36 -0
- envelope/wrappers/episode_statistics_wrapper.py +47 -0
- envelope/wrappers/normalization.py +56 -0
- envelope/wrappers/observation_normalization_wrapper.py +114 -0
- envelope/wrappers/state_injection_wrapper.py +91 -0
- envelope/wrappers/timestep_wrapper.py +22 -0
- envelope/wrappers/truncation_wrapper.py +31 -0
- envelope/wrappers/vmap_envs_wrapper.py +77 -0
- envelope/wrappers/vmap_wrapper.py +51 -0
- envelope/wrappers/wrapper.py +57 -0
- jax_envelope-0.1.0.dist-info/METADATA +87 -0
- jax_envelope-0.1.0.dist-info/RECORD +27 -0
- jax_envelope-0.1.0.dist-info/WHEEL +4 -0
- jax_envelope-0.1.0.dist-info/licenses/LICENSE +21 -0
envelope/__init__.py
ADDED
|
File without changes
|
|
@@ -0,0 +1,97 @@
|
|
|
1
|
+
"""Compatibility wrappers for various RL environment libraries."""
|
|
2
|
+
|
|
3
|
+
from typing import Any, Protocol, Self
|
|
4
|
+
|
|
5
|
+
# Suite prefix -> (module path, wrapper class name). The modules are imported
# lazily by ``create`` (via importlib) so only the requested suite's optional
# dependencies need to be installed. Insertion order is the order shown in the
# "Available suites" error message.
_env_module_map = {
    suite: (f"envelope.compat.{suite}_envelope", class_name)
    for suite, class_name in (
        ("gymnax", "GymnaxEnvelope"),
        ("brax", "BraxEnvelope"),
        ("navix", "NavixEnvelope"),
        ("jumanji", "JumanjiEnvelope"),
        ("kinetix", "KinetixEnvelope"),
        ("craftax", "CraftaxEnvelope"),
        ("mujoco_playground", "MujocoPlaygroundEnvelope"),
    )
}
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class HasFromNameInit(Protocol):
|
|
21
|
+
@classmethod
|
|
22
|
+
def from_name(
|
|
23
|
+
cls, env_name: str, env_kwargs: dict[str, Any] | None = None, **kwargs
|
|
24
|
+
) -> Self: ...
|
|
25
|
+
|
|
26
|
+
"""Creates an environment from a name and keyword arguments. Unless otherwise noted,
|
|
27
|
+
the created environment will have it's default parameters, with truncation and auto
|
|
28
|
+
reset disabled.
|
|
29
|
+
|
|
30
|
+
Args:
|
|
31
|
+
env_name: Environment name
|
|
32
|
+
env_kwargs: Keyword arguments passed to the environment constructor
|
|
33
|
+
**kwargs: Additional keyword arguments passed to the environment wrapper
|
|
34
|
+
"""
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def create(env_name: str, env_kwargs: dict[str, Any] | None = None, **kwargs):
|
|
38
|
+
"""Create an environment from a prefixed environment ID.
|
|
39
|
+
|
|
40
|
+
Args:
|
|
41
|
+
env_name: Environment ID in the format "suite::env_name" (e.g., "brax::ant")
|
|
42
|
+
env_kwargs: Keyword arguments passed to the suite's environment constructor
|
|
43
|
+
**kwargs: Additional keyword arguments passed to the environment wrapper
|
|
44
|
+
|
|
45
|
+
Returns:
|
|
46
|
+
An instance of the wrapped environment
|
|
47
|
+
|
|
48
|
+
Examples:
|
|
49
|
+
>>> env = create("jumanji::snake")
|
|
50
|
+
>>> env = create("brax::ant", env_kwargs={"backend": "spring"})
|
|
51
|
+
>>> env = create("gymnax::CartPole-v1", env_params=...)
|
|
52
|
+
"""
|
|
53
|
+
original_env_id = env_name
|
|
54
|
+
if "::" not in env_name:
|
|
55
|
+
raise ValueError(
|
|
56
|
+
f"Environment ID must be in format 'suite::env_name', got: {original_env_id}"
|
|
57
|
+
)
|
|
58
|
+
|
|
59
|
+
suite, env_name = env_name.split("::", 1)
|
|
60
|
+
if not suite or not env_name:
|
|
61
|
+
raise ValueError(
|
|
62
|
+
f"Environment ID must be in format 'suite::env_name', got: {original_env_id}"
|
|
63
|
+
)
|
|
64
|
+
|
|
65
|
+
if suite not in _env_module_map:
|
|
66
|
+
raise ValueError(
|
|
67
|
+
f"Unknown environment suite: {suite}. "
|
|
68
|
+
f"Available suites: {list(_env_module_map.keys())}"
|
|
69
|
+
)
|
|
70
|
+
|
|
71
|
+
# Lazy import the wrapper class
|
|
72
|
+
module_name, class_name = _env_module_map[suite]
|
|
73
|
+
try:
|
|
74
|
+
import importlib
|
|
75
|
+
|
|
76
|
+
module = importlib.import_module(module_name)
|
|
77
|
+
env_class: HasFromNameInit = getattr(module, class_name)
|
|
78
|
+
except ImportError as e:
|
|
79
|
+
raise ImportError(
|
|
80
|
+
f"Failed to import {suite} wrapper. "
|
|
81
|
+
f"Make sure you have installed the '{suite}' dependencies. "
|
|
82
|
+
f"Original error: {e}"
|
|
83
|
+
) from e
|
|
84
|
+
|
|
85
|
+
env = env_class.from_name(env_name, env_kwargs=env_kwargs, **kwargs)
|
|
86
|
+
|
|
87
|
+
# Wrap with TruncationWrapper using adapter's default
|
|
88
|
+
default_max_steps = getattr(env, "default_max_steps", None)
|
|
89
|
+
if default_max_steps is not None:
|
|
90
|
+
from envelope.wrappers.truncation_wrapper import TruncationWrapper
|
|
91
|
+
|
|
92
|
+
env = TruncationWrapper(env=env, max_steps=int(default_max_steps))
|
|
93
|
+
|
|
94
|
+
return env
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
__all__ = ["create"]
|
|
@@ -0,0 +1,98 @@
|
|
|
1
|
+
import dataclasses
|
|
2
|
+
import warnings
|
|
3
|
+
from copy import copy
|
|
4
|
+
from functools import cached_property
|
|
5
|
+
from typing import Any, override
|
|
6
|
+
|
|
7
|
+
from brax.envs import Env as BraxEnv
|
|
8
|
+
from brax.envs import Wrapper as BraxWrapper
|
|
9
|
+
from brax.envs import create as brax_create
|
|
10
|
+
from jax import numpy as jnp
|
|
11
|
+
|
|
12
|
+
from envelope import spaces
|
|
13
|
+
from envelope.environment import Environment, Info, InfoContainer, State
|
|
14
|
+
from envelope.struct import static_field
|
|
15
|
+
from envelope.typing import Key, PyTree
|
|
16
|
+
|
|
17
|
+
# Default episode_length in brax.envs.create(); reported as this wrapper's
# ``default_max_steps`` so truncation can be re-applied by envelope.
_BRAX_DEFAULT_EPISODE_LENGTH = 1000


class BraxEnvelope(Environment):
    """Wrapper to convert a Brax environment to an envelope environment."""

    brax_env: BraxEnv = static_field()

    @classmethod
    def from_name(
        cls, env_name: str, env_kwargs: dict[str, Any] | None = None
    ) -> "BraxEnvelope":
        """Create a wrapped Brax environment by name.

        Args:
            env_name: Environment name passed to ``brax.envs.create``.
            env_kwargs: Extra keyword arguments for ``brax.envs.create``. Must not
                contain ``episode_length`` or ``auto_reset``. The caller's dict
                is not modified.

        Raises:
            ValueError: if ``env_kwargs`` sets ``episode_length`` or ``auto_reset``.
        """
        # Copy so the caller's dict is never mutated. Otherwise the keys
        # injected below would make a second call with the same dict fail the
        # checks that follow.
        env_kwargs = dict(env_kwargs or {})
        if "episode_length" in env_kwargs:
            raise ValueError(
                "Cannot override 'episode_length' directly. "
                "Use TruncationWrapper for episode length control."
            )
        if "auto_reset" in env_kwargs:
            raise ValueError(
                "Cannot override 'auto_reset' directly. "
                "Use AutoResetWrapper for auto-reset behavior."
            )

        # ``episode_length=None`` makes brax skip its EpisodeWrapper. Any such
        # wrapper would only be stripped again (with a warning) in
        # ``__post_init__``; episode length is controlled by TruncationWrapper.
        env_kwargs["episode_length"] = None
        env_kwargs["auto_reset"] = False
        env = brax_create(env_name, **env_kwargs)
        return cls(brax_env=env)

    @property
    def default_max_steps(self) -> int:
        """Episode length brax would use by default (for TruncationWrapper)."""
        return _BRAX_DEFAULT_EPISODE_LENGTH

    def __post_init__(self) -> None:
        """Store the unwrapped brax env if a brax ``Wrapper`` was passed in."""
        if isinstance(self.brax_env, BraxWrapper):
            warnings.warn(
                "Environment wrapping should be handled by envelope. "
                "Unwrapping brax environment before converting..."
            )
            # Bypass frozen-dataclass assignment restrictions.
            object.__setattr__(self, "brax_env", self.brax_env.unwrapped)

    @override
    def reset(self, key: Key) -> tuple[State, Info]:
        """Reset the brax env.

        The brax state is returned as-is. All of its fields are also merged
        into ``info`` via ``dataclasses.asdict``.
        """
        brax_state = self.brax_env.reset(key)
        info = InfoContainer(obs=brax_state.obs, reward=0.0, terminated=False)
        info = info.update(**dataclasses.asdict(brax_state))
        return brax_state, info

    @override
    def step(self, state: State, action: PyTree) -> tuple[State, Info]:
        """Advance the brax env one step.

        ``terminated`` mirrors brax's ``done`` flag.
        """
        brax_state = self.brax_env.step(state, action)
        info = InfoContainer(
            obs=brax_state.obs, reward=brax_state.reward, terminated=brax_state.done
        )
        info = info.update(**dataclasses.asdict(brax_state))
        return brax_state, info

    @override
    @cached_property
    def action_space(self) -> spaces.Space:
        # All brax environments have action limit of -1 to 1
        return spaces.Continuous.from_shape(
            low=-1.0, high=1.0, shape=(self.brax_env.action_size,)
        )

    @override
    @cached_property
    def observation_space(self) -> spaces.Space:
        # All brax environments have observation limit of -inf to inf
        return spaces.Continuous.from_shape(
            low=-jnp.inf, high=jnp.inf, shape=(self.brax_env.observation_size,)
        )

    def __deepcopy__(self, memo):
        """Return a shallow copy (with a warning) instead of a deep copy."""
        warnings.warn(
            f"Trying to deepcopy {type(self).__name__}, which contains a brax env. "
            "Brax envs throw an error when deepcopying, so a shallow copy is returned.",
            category=RuntimeWarning,
            stacklevel=2,
        )
        return copy(self)
|
|
@@ -0,0 +1,86 @@
|
|
|
1
|
+
from functools import cached_property
|
|
2
|
+
from typing import Any, override
|
|
3
|
+
|
|
4
|
+
import jax
|
|
5
|
+
import jax.numpy as jnp
|
|
6
|
+
from craftax.craftax.craftax_state import EnvParams as CraftaxEnvParams
|
|
7
|
+
from craftax.craftax_classic.envs.craftax_state import (
|
|
8
|
+
EnvParams as CraftaxClassicEnvParams,
|
|
9
|
+
)
|
|
10
|
+
from craftax.craftax_env import make_craftax_env_from_name
|
|
11
|
+
|
|
12
|
+
from envelope import spaces as envelope_spaces
|
|
13
|
+
from envelope.compat.gymnax_envelope import _convert_space as _convert_gymnax_space
|
|
14
|
+
from envelope.environment import Environment, Info, InfoContainer, State
|
|
15
|
+
from envelope.struct import Container, static_field
|
|
16
|
+
from envelope.typing import Key, PyTree, TypeAlias
|
|
17
|
+
|
|
18
|
+
# Either Craftax flavour's parameter dataclass is accepted by ``from_name``.
EnvParams: TypeAlias = CraftaxEnvParams | CraftaxClassicEnvParams


class CraftaxEnvelope(Environment):
    """Wrapper to convert a Craftax environment to an envelope environment."""

    craftax_env: Any = static_field()
    env_params: PyTree

    @classmethod
    def from_name(
        cls,
        env_name: str,
        env_params: EnvParams | None = None,
        env_kwargs: dict[str, Any] | None = None,
    ) -> "CraftaxEnvelope":
        """Create a wrapped Craftax environment by name.

        Args:
            env_name: Name understood by ``make_craftax_env_from_name``.
            env_params: Optional params. When omitted, the env's default params
                are used with ``max_timesteps`` set to ``inf`` (truncation is
                left to ``TruncationWrapper``).
            env_kwargs: Extra keyword arguments for ``make_craftax_env_from_name``.
                Must not contain ``max_timesteps`` or ``auto_reset``. The
                caller's dict is not modified.

        Raises:
            ValueError: if ``env_kwargs`` sets ``max_timesteps`` or ``auto_reset``.
        """
        # Copy so the injected "auto_reset" key never leaks into the caller's
        # dict (which would make reusing it fail the check below).
        env_kwargs = dict(env_kwargs or {})
        if "max_timesteps" in env_kwargs:
            raise ValueError(
                "Cannot override 'max_timesteps' directly. "
                "Use TruncationWrapper for episode length control."
            )
        if "auto_reset" in env_kwargs:
            raise ValueError(
                "Cannot override 'auto_reset' directly. "
                "Use AutoResetWrapper for auto-reset behavior."
            )

        env_kwargs["auto_reset"] = False
        env = make_craftax_env_from_name(env_name, **env_kwargs)
        if env_params is None:
            env_params = env.default_params.replace(max_timesteps=jnp.inf)
        return cls(craftax_env=env, env_params=env_params)

    @property
    def default_max_steps(self) -> int:
        """Craftax's own default ``max_timesteps`` (for TruncationWrapper)."""
        return int(self.craftax_env.default_params.max_timesteps)

    @override
    def reset(self, key: Key) -> tuple[State, Info]:
        key, subkey = jax.random.split(key)
        obs, env_state = self.craftax_env.reset(subkey, self.env_params)
        state = Container().update(key=key, env_state=env_state)
        info = InfoContainer(obs=obs, reward=0.0, terminated=False)
        # ``step`` always sets an ``info`` field. Set it here as well so that
        # reset and step infos expose the same keys, matching the gymnax and
        # kinetix wrappers.
        info = info.update(info=None)
        return state, info

    @override
    def step(self, state: State, action: PyTree) -> tuple[State, Info]:
        key, subkey = jax.random.split(state.key)
        obs, env_state, reward, done, env_info = self.craftax_env.step(
            subkey, state.env_state, action, self.env_params
        )
        state = state.update(key=key, env_state=env_state)
        info = InfoContainer(obs=obs, reward=reward, terminated=done)
        info = info.update(info=env_info)
        return state, info

    @override
    @cached_property
    def action_space(self) -> envelope_spaces.Space:
        return _convert_gymnax_space(self.craftax_env.action_space(self.env_params))

    @override
    @cached_property
    def observation_space(self) -> envelope_spaces.Space:
        return _convert_gymnax_space(
            self.craftax_env.observation_space(self.env_params)
        )
|
|
@@ -0,0 +1,91 @@
|
|
|
1
|
+
from functools import cached_property
|
|
2
|
+
from typing import Any, override
|
|
3
|
+
|
|
4
|
+
import jax
|
|
5
|
+
import jax.numpy as jnp
|
|
6
|
+
from gymnax import make as gymnax_create
|
|
7
|
+
from gymnax.environments import spaces as gymnax_spaces
|
|
8
|
+
from gymnax.environments.environment import Environment as GymnaxEnv
|
|
9
|
+
from gymnax.environments.environment import EnvParams as GymnaxEnvParams
|
|
10
|
+
|
|
11
|
+
from envelope import spaces as envelope_spaces
|
|
12
|
+
from envelope.environment import Environment, Info, InfoContainer, State
|
|
13
|
+
from envelope.struct import Container, static_field
|
|
14
|
+
from envelope.typing import Key, PyTree
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class GymnaxEnvelope(Environment):
    """Wrapper to convert a Gymnax environment to an envelope environment."""

    gymnax_env: GymnaxEnv = static_field()
    env_params: PyTree

    @classmethod
    def from_name(
        cls,
        env_name: str,
        env_params: GymnaxEnvParams | None = None,
        env_kwargs: dict[str, Any] | None = None,
    ) -> "GymnaxEnvelope":
        """Create a wrapped Gymnax environment by name.

        Args:
            env_name: Name understood by ``gymnax.make``.
            env_params: Optional params. When omitted, the env's default params
                are used with ``max_steps_in_episode`` set to ``inf`` (truncation
                is left to ``TruncationWrapper``).
            env_kwargs: Extra keyword arguments for ``gymnax.make``. Must not
                contain ``max_steps_in_episode``.

        Raises:
            ValueError: if ``env_kwargs`` sets ``max_steps_in_episode``.
        """
        env_kwargs = env_kwargs or {}
        if "max_steps_in_episode" in env_kwargs:
            raise ValueError(
                "Cannot override 'max_steps_in_episode' directly. "
                "Use TruncationWrapper for episode length control."
            )
        gymnax_env, default_params = gymnax_create(env_name, **env_kwargs)
        if env_params is None:
            env_params = default_params.replace(max_steps_in_episode=jnp.inf)
        return cls(gymnax_env=gymnax_env, env_params=env_params)

    @property
    def default_max_steps(self) -> int:
        """Gymnax's own default ``max_steps_in_episode`` (for TruncationWrapper)."""
        return int(self.gymnax_env.default_params.max_steps_in_episode)

    @override
    def reset(self, key: Key) -> tuple[State, Info]:
        key, subkey = jax.random.split(key)
        obs, env_state = self.gymnax_env.reset(subkey, self.env_params)
        state = Container().update(key=key, env_state=env_state)
        info = InfoContainer(obs=obs, reward=0.0, terminated=False)
        info = info.update(info=None)
        return state, info

    @override
    def step(self, state: State, action: PyTree) -> tuple[State, Info]:
        key, subkey = jax.random.split(state.key)
        # Call ``step_env`` rather than ``step``. Gymnax's public ``step``
        # auto-resets when ``done`` is true, replacing the terminal obs/state
        # with a fresh episode. Auto-reset is meant to be disabled here and
        # handled by envelope's AutoResetWrapper.
        obs, env_state, reward, done, env_info = self.gymnax_env.step_env(
            subkey, state.env_state, action, self.env_params
        )
        state = state.update(key=key, env_state=env_state)
        info = InfoContainer(obs=obs, reward=reward, terminated=done)
        info = info.update(info=env_info)
        return state, info

    @override
    @cached_property
    def action_space(self) -> envelope_spaces.Space:
        return _convert_space(self.gymnax_env.action_space(self.env_params))

    @override
    @cached_property
    def observation_space(self) -> envelope_spaces.Space:
        return _convert_space(self.gymnax_env.observation_space(self.env_params))
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
def _convert_space(gmx_space: gymnax_spaces.Space) -> envelope_spaces.Space:
|
|
78
|
+
if isinstance(gmx_space, gymnax_spaces.Box):
|
|
79
|
+
low = jnp.broadcast_to(gmx_space.low, gmx_space.shape).astype(gmx_space.dtype)
|
|
80
|
+
high = jnp.broadcast_to(gmx_space.high, gmx_space.shape).astype(gmx_space.dtype)
|
|
81
|
+
return envelope_spaces.Continuous(low=low, high=high)
|
|
82
|
+
elif isinstance(gmx_space, gymnax_spaces.Discrete):
|
|
83
|
+
n = jnp.broadcast_to(gmx_space.n, gmx_space.shape).astype(gmx_space.dtype)
|
|
84
|
+
return envelope_spaces.Discrete(n=n)
|
|
85
|
+
elif isinstance(gmx_space, gymnax_spaces.Tuple):
|
|
86
|
+
spaces = tuple(_convert_space(space) for space in gmx_space.spaces)
|
|
87
|
+
return envelope_spaces.PyTreeSpace(spaces)
|
|
88
|
+
elif isinstance(gmx_space, gymnax_spaces.Dict):
|
|
89
|
+
spaces = {k: _convert_space(space) for k, space in gmx_space.spaces.items()}
|
|
90
|
+
return envelope_spaces.PyTreeSpace(spaces)
|
|
91
|
+
raise ValueError(f"Unsupported space type: {type(gmx_space)}")
|
|
@@ -0,0 +1,127 @@
|
|
|
1
|
+
import warnings
|
|
2
|
+
from copy import copy
|
|
3
|
+
from functools import cached_property
|
|
4
|
+
from typing import Any, override
|
|
5
|
+
|
|
6
|
+
import jax.numpy as jnp
|
|
7
|
+
import jumanji
|
|
8
|
+
from jumanji.specs import Array, BoundedArray, DiscreteArray, MultiDiscreteArray
|
|
9
|
+
from jumanji.types import TimeStep as JumanjiTimeStep
|
|
10
|
+
|
|
11
|
+
from envelope import spaces as envelope_spaces
|
|
12
|
+
from envelope.environment import Environment, Info, InfoContainer, State
|
|
13
|
+
from envelope.struct import static_field
|
|
14
|
+
from envelope.typing import Key, PyTree
|
|
15
|
+
|
|
16
|
+
# Stand-in for "no time limit" when constructing Jumanji envs; truncation is
# handled by envelope's TruncationWrapper instead.
_MAX_INT = jnp.iinfo(jnp.int32).max


class JumanjiEnvelope(Environment):
    """Wrapper to convert a Jumanji environment to an envelope environment."""

    jumanji_env: Any = static_field()
    # The env's original ``time_limit`` (None if it has none), captured
    # before it is replaced by ``_MAX_INT``.
    _default_time_limit: int | None = static_field(default=None)

    @classmethod
    def from_name(
        cls, env_name: str, env_kwargs: dict[str, Any] | None = None
    ) -> "JumanjiEnvelope":
        """Create a wrapped Jumanji environment by name.

        Args:
            env_name: Name registered with ``jumanji.make``.
            env_kwargs: Extra keyword arguments for ``jumanji.make``. Must not
                contain ``time_limit``. The caller's dict is not modified.

        Raises:
            ValueError: if ``env_kwargs`` sets ``time_limit``.
        """
        # Copy: "time_limit" is injected below, and writing it into the
        # caller's dict would make a second call with that dict fail this check.
        env_kwargs = dict(env_kwargs or {})
        if "time_limit" in env_kwargs:
            raise ValueError(
                "Cannot override 'time_limit' directly. "
                "Use TruncationWrapper for episode length control."
            )

        # Create env first with defaults to capture default time_limit
        temp_env = jumanji.make(env_name, **env_kwargs)
        default_time_limit = getattr(temp_env, "time_limit", None)

        # Now create env with time_limit=_MAX_INT (if env supports it)
        if default_time_limit is not None:
            env_kwargs["time_limit"] = _MAX_INT
        env = jumanji.make(env_name, **env_kwargs)
        return cls(jumanji_env=env, _default_time_limit=default_time_limit)

    @property
    def default_max_steps(self) -> int | None:
        """The env's original ``time_limit``, or None if it had none."""
        return self._default_time_limit

    @override
    def reset(self, key: Key) -> tuple[State, Info]:
        env_state, timestep = self.jumanji_env.reset(key)
        info = convert_jumanji_to_envelope_info(timestep)
        return env_state, info

    @override
    def step(self, state: State, action: PyTree) -> tuple[State, Info]:
        env_state, timestep = self.jumanji_env.step(state, action)
        info = convert_jumanji_to_envelope_info(timestep)
        return env_state, info

    @override
    @cached_property
    def action_space(self) -> envelope_spaces.Space:
        return convert_jumanji_spec_to_envelope_space(self.jumanji_env.action_spec)

    @override
    @cached_property
    def observation_space(self) -> envelope_spaces.Space:
        return convert_jumanji_spec_to_envelope_space(self.jumanji_env.observation_spec)

    def __deepcopy__(self, memo):
        """Return a shallow copy (with a warning) instead of a deep copy."""
        warnings.warn(
            f"Trying to deepcopy {type(self).__name__}, which contains a jumanji env. "
            "Jumanji envs may throw an error when deepcopying, so a shallow copy is "
            "returned.",
            category=RuntimeWarning,
            stacklevel=2,
        )
        return copy(self)
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
def convert_jumanji_to_envelope_info(timestep: JumanjiTimeStep) -> InfoContainer:
    """Build an envelope ``InfoContainer`` from a Jumanji ``TimeStep``.

    ``terminated`` is taken from ``timestep.last()``, and every entry of
    ``timestep.extras`` is merged in as an additional field.
    """
    base = InfoContainer(
        obs=timestep.observation,
        reward=timestep.reward,
        terminated=timestep.last(),
    )
    return base.update(**timestep.extras)
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
def convert_jumanji_spec_to_envelope_space(spec: Any) -> envelope_spaces.Space:
    """Recursively convert a Jumanji spec into an envelope space.

    Mapping, checked in this order:
      * ``DiscreteArray`` / ``MultiDiscreteArray`` -> ``Discrete`` with
        ``n = num_values`` (broadcast to ``spec.shape`` when the shapes differ).
      * ``BoundedArray`` -> ``Continuous`` bounded by ``minimum``/``maximum``.
      * ``Array`` with a floating dtype -> ``Continuous`` over ``(-inf, inf)``.
      * objects exposing a ``_specs`` dict -> ``PyTreeSpace`` over a dict.
      * tuples / lists of specs -> ``PyTreeSpace`` over a tuple.

    Raises:
        NotImplementedError: for an unbounded ``Array`` with a non-floating dtype.
        ValueError: for any other spec type.
    """
    # isinstance checks run from most to least specific.
    # NOTE(review): presumably the discrete specs subclass BoundedArray, which
    # subclasses Array — hence the order; confirm against jumanji.specs.
    if isinstance(spec, (DiscreteArray, MultiDiscreteArray)):
        counts = jnp.asarray(spec.num_values, dtype=spec.dtype)
        spec_shape = getattr(spec, "shape", ())
        if spec_shape != () and spec_shape != counts.shape:
            counts = jnp.broadcast_to(counts, spec.shape)
        return envelope_spaces.Discrete(n=counts)

    if isinstance(spec, BoundedArray):

        def _as_bound(value):
            return jnp.broadcast_to(jnp.asarray(value, dtype=spec.dtype), spec.shape)

        return envelope_spaces.Continuous(
            low=_as_bound(spec.minimum), high=_as_bound(spec.maximum)
        )

    if isinstance(spec, Array):
        dtype = jnp.dtype(spec.dtype)
        if jnp.issubdtype(dtype, jnp.floating):
            return envelope_spaces.Continuous(
                low=jnp.full(spec.shape, -jnp.inf, dtype=dtype),
                high=jnp.full(spec.shape, jnp.inf, dtype=dtype),
            )
        raise NotImplementedError(
            "Unbounded jumanji Array specs are only supported for floating dtypes. "
            f"Got dtype={dtype} for spec={spec!r}."
        )

    # Structured specs (most Jumanji envs): access private mapping when available.
    nested = getattr(spec, "_specs", None)
    if isinstance(nested, dict):
        return envelope_spaces.PyTreeSpace(
            {
                field: convert_jumanji_spec_to_envelope_space(child)
                for field, child in nested.items()
            }
        )

    if isinstance(spec, (tuple, list)):
        return envelope_spaces.PyTreeSpace(
            tuple(map(convert_jumanji_spec_to_envelope_space, spec))
        )

    raise ValueError(f"Unsupported spec type: {type(spec)}")
|
|
@@ -0,0 +1,194 @@
|
|
|
1
|
+
"""Kinetix compatibility wrapper.
|
|
2
|
+
|
|
3
|
+
This module exposes Kinetix environments through the `envelope.environment.Environment`
|
|
4
|
+
API. It mirrors envelope's compat philosophy:
|
|
5
|
+
- prefer *no* environment-side auto-reset (use `AutoResetWrapper` in envelope)
|
|
6
|
+
- prefer *no* fixed episode time-limits (use `TruncationWrapper` in envelope)
|
|
7
|
+
|
|
8
|
+
`from_name` supports premade level ids like `s/h4_thrust_aim` (optionally with
|
|
9
|
+
`.json`). For maximum flexibility, users can bypass level handling entirely by
|
|
10
|
+
passing a custom `reset_fn`.
|
|
11
|
+
"""
|
|
12
|
+
|
|
13
|
+
from __future__ import annotations
|
|
14
|
+
|
|
15
|
+
import warnings
|
|
16
|
+
from functools import cached_property
|
|
17
|
+
from typing import Any, Callable, Literal, override
|
|
18
|
+
|
|
19
|
+
import jax
|
|
20
|
+
import jax.numpy as jnp
|
|
21
|
+
from kinetix.environment import (
|
|
22
|
+
ActionType,
|
|
23
|
+
EnvParams,
|
|
24
|
+
ObservationType,
|
|
25
|
+
StaticEnvParams,
|
|
26
|
+
make_kinetix_env,
|
|
27
|
+
)
|
|
28
|
+
from kinetix.environment.ued.ued import make_reset_fn_sample_kinetix_level
|
|
29
|
+
from kinetix.util.saving import load_from_json_file
|
|
30
|
+
|
|
31
|
+
from envelope import spaces as envelope_spaces
|
|
32
|
+
from envelope.compat.gymnax_envelope import _convert_space as _convert_gymnax_space
|
|
33
|
+
from envelope.environment import Environment, Info, InfoContainer, State
|
|
34
|
+
from envelope.struct import Container, static_field
|
|
35
|
+
from envelope.typing import Key, PyTree
|
|
36
|
+
|
|
37
|
+
# Signature of a Kinetix level reset function: takes a ``Key`` and returns a
# level (typed ``Any`` here), like the ``reset_fn`` built in ``create_premade``.
# NOTE(review): not referenced elsewhere in this module; presumably exported for
# users supplying a custom ``reset_fn`` — confirm.
LevelResetFn = Callable[[Key], Any]
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def _normalize_level_id(level_id: str) -> str:
    """Turn a path-like Kinetix level id into a relative ``.json`` path.

    Surrounding whitespace and leading slashes are removed, and ``.json`` is
    appended when missing.

    Examples:
        - ``"s/h4_thrust_aim"`` -> ``"s/h4_thrust_aim.json"``
        - ``"/s/h4_thrust_aim.json"`` -> ``"s/h4_thrust_aim.json"``

    Raises:
        ValueError: if nothing remains after cleaning, or the id ends with ``/``.
    """
    cleaned = level_id.strip().lstrip("/")
    if cleaned == "":
        raise ValueError("level_id must be a non-empty string")
    if cleaned[-1] == "/":
        raise ValueError("level_id must not end with '/'")
    return cleaned if cleaned.endswith(".json") else cleaned + ".json"
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
def _warn_auto_reset(auto_reset: bool) -> None:
    """Emit a ``UserWarning`` when Kinetix-side auto-reset is requested.

    Envelope prefers resetting through its own ``AutoResetWrapper``. This
    helper only warns; it never raises.
    """
    if not auto_reset:
        return
    warnings.warn(
        "Creating a KinetixEnvelope with auto_reset=True is not recommended, use "
        "an AutoResetWrapper instead.",
        stacklevel=2,
    )
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
class KinetixEnvelope(Environment):
    """Wrapper to convert a Kinetix environment to an envelope environment."""

    kinetix_env: Any = static_field()
    env_params: Any

    @property
    def default_max_steps(self) -> int:
        """``max_timesteps`` of a default ``EnvParams`` (for TruncationWrapper).

        NOTE(review): for premade levels this ignores any ``max_timesteps``
        stored in the level file — confirm that is intended.
        """
        return int(EnvParams().max_timesteps)

    @classmethod
    def from_name(
        cls,
        env_name: str | Literal["random"],
        env_params: EnvParams | None = None,
        env_kwargs: dict[str, Any] | None = None,
    ) -> "KinetixEnvelope":
        """Create a Kinetix environment from a premade level id or ``"random"``.

        Args:
            env_name: ``"random"`` to sample levels via ``create_random``;
                otherwise a premade level id such as ``"s/h4_thrust_aim"``
                (``.json`` optional).
            env_params: Only allowed together with ``"random"``.
            env_kwargs: Forwarded to ``create_random`` / ``create_premade``. Must
                not contain ``max_timesteps`` or ``auto_reset``. The caller's
                dict is not modified.

        Raises:
            ValueError: on forbidden keys, or when params are supplied for a
                premade level.
        """
        # Copy so the injected "auto_reset" key never leaks into the caller's
        # dict (reusing it would otherwise fail the check below).
        env_kwargs = dict(env_kwargs or {})
        if "max_timesteps" in env_kwargs:
            raise ValueError(
                "Cannot override 'max_timesteps' directly. "
                "Use TruncationWrapper for episode length control."
            )
        if "auto_reset" in env_kwargs:
            raise ValueError(
                "Cannot override 'auto_reset' directly. "
                "Use AutoResetWrapper for auto-reset behavior."
            )

        env_kwargs["auto_reset"] = False
        if env_name == "random":
            return cls.create_random(env_params=env_params, **env_kwargs)

        if (
            env_params is not None
            or "env_params" in env_kwargs
            or "static_env_params" in env_kwargs
        ):
            raise ValueError(
                "env_params and static_env_params cannot be passed when creating a "
                "KinetixEnvelope from a premade level."
            )
        return cls.create_premade(env_name, **env_kwargs)

    @classmethod
    def create_premade(
        cls,
        env_name: str,
        action_type: ActionType = ActionType.CONTINUOUS,
        observation_type: ObservationType = ObservationType.SYMBOLIC_FLAT,
        auto_reset: bool = False,
    ) -> "KinetixEnvelope":
        """Create an environment that always resets to one premade level.

        The level, static params and env params are loaded from the level's
        JSON file. When env params are present, ``max_timesteps`` is set to
        ``inf``.
        """
        _warn_auto_reset(auto_reset)

        # Load level.
        level_id_json = _normalize_level_id(env_name)
        level, static_env_params, env_params = load_from_json_file(level_id_json)
        # Disable the env-side time limit; truncation is handled by envelope.
        if env_params is not None:
            env_params = env_params.replace(max_timesteps=jnp.inf)

        def reset_fn(_: Key) -> Any:
            # Ignores the key: every reset yields the same loaded level.
            return level

        # Create environment.
        kinetix_env = make_kinetix_env(
            action_type=action_type,
            observation_type=observation_type,
            reset_fn=reset_fn,
            env_params=env_params,
            static_env_params=static_env_params,
            auto_reset=auto_reset,
        )
        return cls(kinetix_env=kinetix_env, env_params=env_params)

    @classmethod
    def create_random(
        cls,
        action_type: ActionType = ActionType.CONTINUOUS,
        observation_type: ObservationType = ObservationType.SYMBOLIC_FLAT,
        env_params: EnvParams | None = None,
        static_env_params: StaticEnvParams = StaticEnvParams(),
        auto_reset: bool = False,
    ) -> "KinetixEnvelope":
        """Create an environment that samples a random Kinetix level on reset.

        ``env_params`` defaults to ``EnvParams()``. Its ``max_timesteps`` is
        always replaced by ``inf``.
        """
        _warn_auto_reset(auto_reset)
        if env_params is None:
            env_params = EnvParams()
        env_params = env_params.replace(max_timesteps=jnp.inf)

        reset_fn = make_reset_fn_sample_kinetix_level(env_params, static_env_params)
        kinetix_env = make_kinetix_env(
            action_type=action_type,
            observation_type=observation_type,
            reset_fn=reset_fn,
            env_params=env_params,
            static_env_params=static_env_params,
            auto_reset=auto_reset,
        )
        return cls(kinetix_env=kinetix_env, env_params=env_params)

    @override
    def reset(self, key: Key) -> tuple[State, Info]:
        key, subkey = jax.random.split(key)
        obs, env_state = self.kinetix_env.reset(subkey, self.env_params)
        state_out = Container().update(key=key, env_state=env_state)
        info = InfoContainer(obs=obs, reward=0.0, terminated=False)
        info = info.update(info=None)
        return state_out, info

    @override
    def step(self, state: State, action: PyTree) -> tuple[State, Info]:
        key, subkey = jax.random.split(state.key)
        obs, env_state, reward, done, env_info = self.kinetix_env.step(
            subkey, state.env_state, action, self.env_params
        )
        state_out = state.update(key=key, env_state=env_state)
        info = InfoContainer(obs=obs, reward=reward, terminated=done)
        info = info.update(info=env_info)
        return state_out, info

    @override
    @cached_property
    def action_space(self) -> envelope_spaces.Space:
        return _convert_gymnax_space(self.kinetix_env.action_space(self.env_params))

    @override
    @cached_property
    def observation_space(self) -> envelope_spaces.Space:
        return _convert_gymnax_space(
            self.kinetix_env.observation_space(self.env_params)
        )
|