jax-envelope 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
envelope/__init__.py ADDED
File without changes
@@ -0,0 +1,97 @@
1
+ """Compatibility wrappers for various RL environment libraries."""
2
+
3
+ from typing import Any, Protocol, Self
4
+
5
# Lazy imports to avoid requiring all dependencies at once.
# Maps a suite prefix (the part before "::" in an environment ID passed to
# `create`) to the (module path, class name) of its envelope wrapper. `create`
# only imports the module for the suite that is actually requested, so the
# third-party dependencies of other suites do not need to be installed.
_env_module_map = {
    "gymnax": ("envelope.compat.gymnax_envelope", "GymnaxEnvelope"),
    "brax": ("envelope.compat.brax_envelope", "BraxEnvelope"),
    "navix": ("envelope.compat.navix_envelope", "NavixEnvelope"),
    "jumanji": ("envelope.compat.jumanji_envelope", "JumanjiEnvelope"),
    "kinetix": ("envelope.compat.kinetix_envelope", "KinetixEnvelope"),
    "craftax": ("envelope.compat.craftax_envelope", "CraftaxEnvelope"),
    "mujoco_playground": (
        "envelope.compat.mujoco_playground_envelope",
        "MujocoPlaygroundEnvelope",
    ),
}
18
+
19
+
20
+ class HasFromNameInit(Protocol):
21
+ @classmethod
22
+ def from_name(
23
+ cls, env_name: str, env_kwargs: dict[str, Any] | None = None, **kwargs
24
+ ) -> Self: ...
25
+
26
+ """Creates an environment from a name and keyword arguments. Unless otherwise noted,
27
+ the created environment will have it's default parameters, with truncation and auto
28
+ reset disabled.
29
+
30
+ Args:
31
+ env_name: Environment name
32
+ env_kwargs: Keyword arguments passed to the environment constructor
33
+ **kwargs: Additional keyword arguments passed to the environment wrapper
34
+ """
35
+
36
+
37
+ def create(env_name: str, env_kwargs: dict[str, Any] | None = None, **kwargs):
38
+ """Create an environment from a prefixed environment ID.
39
+
40
+ Args:
41
+ env_name: Environment ID in the format "suite::env_name" (e.g., "brax::ant")
42
+ env_kwargs: Keyword arguments passed to the suite's environment constructor
43
+ **kwargs: Additional keyword arguments passed to the environment wrapper
44
+
45
+ Returns:
46
+ An instance of the wrapped environment
47
+
48
+ Examples:
49
+ >>> env = create("jumanji::snake")
50
+ >>> env = create("brax::ant", env_kwargs={"backend": "spring"})
51
+ >>> env = create("gymnax::CartPole-v1", env_params=...)
52
+ """
53
+ original_env_id = env_name
54
+ if "::" not in env_name:
55
+ raise ValueError(
56
+ f"Environment ID must be in format 'suite::env_name', got: {original_env_id}"
57
+ )
58
+
59
+ suite, env_name = env_name.split("::", 1)
60
+ if not suite or not env_name:
61
+ raise ValueError(
62
+ f"Environment ID must be in format 'suite::env_name', got: {original_env_id}"
63
+ )
64
+
65
+ if suite not in _env_module_map:
66
+ raise ValueError(
67
+ f"Unknown environment suite: {suite}. "
68
+ f"Available suites: {list(_env_module_map.keys())}"
69
+ )
70
+
71
+ # Lazy import the wrapper class
72
+ module_name, class_name = _env_module_map[suite]
73
+ try:
74
+ import importlib
75
+
76
+ module = importlib.import_module(module_name)
77
+ env_class: HasFromNameInit = getattr(module, class_name)
78
+ except ImportError as e:
79
+ raise ImportError(
80
+ f"Failed to import {suite} wrapper. "
81
+ f"Make sure you have installed the '{suite}' dependencies. "
82
+ f"Original error: {e}"
83
+ ) from e
84
+
85
+ env = env_class.from_name(env_name, env_kwargs=env_kwargs, **kwargs)
86
+
87
+ # Wrap with TruncationWrapper using adapter's default
88
+ default_max_steps = getattr(env, "default_max_steps", None)
89
+ if default_max_steps is not None:
90
+ from envelope.wrappers.truncation_wrapper import TruncationWrapper
91
+
92
+ env = TruncationWrapper(env=env, max_steps=int(default_max_steps))
93
+
94
+ return env
95
+
96
+
97
+ __all__ = ["create"]
@@ -0,0 +1,98 @@
1
+ import dataclasses
2
+ import warnings
3
+ from copy import copy
4
+ from functools import cached_property
5
+ from typing import Any, override
6
+
7
+ from brax.envs import Env as BraxEnv
8
+ from brax.envs import Wrapper as BraxWrapper
9
+ from brax.envs import create as brax_create
10
+ from jax import numpy as jnp
11
+
12
+ from envelope import spaces
13
+ from envelope.environment import Environment, Info, InfoContainer, State
14
+ from envelope.struct import static_field
15
+ from envelope.typing import Key, PyTree
16
+
17
# Default episode_length in brax.envs.create(). BraxEnvelope.from_name builds the
# brax env with episode_length=jnp.inf, so this value is exposed through
# `default_max_steps` for external truncation instead.
_BRAX_DEFAULT_EPISODE_LENGTH = 1000
19
+
20
+
21
class BraxEnvelope(Environment):
    """Wrapper to convert a Brax environment to an envelope environment."""

    brax_env: BraxEnv = static_field()

    @classmethod
    def from_name(
        cls, env_name: str, env_kwargs: dict[str, Any] | None = None
    ) -> "BraxEnvelope":
        """Create a ``BraxEnvelope`` from a registered Brax environment name.

        The brax env is created with ``episode_length=jnp.inf`` and
        ``auto_reset=False``. Use ``TruncationWrapper`` / ``AutoResetWrapper`` for
        those behaviors instead.

        Args:
            env_name: Name passed to ``brax.envs.create``.
            env_kwargs: Extra keyword arguments for ``brax.envs.create``. The dict
                is copied, so the caller's mapping is never modified.

        Raises:
            ValueError: If ``env_kwargs`` sets ``episode_length`` or ``auto_reset``.
        """
        # Copy so the overrides below never leak into the caller's dict.
        env_kwargs = dict(env_kwargs or {})
        if "episode_length" in env_kwargs:
            raise ValueError(
                "Cannot override 'episode_length' directly. "
                "Use TruncationWrapper for episode length control."
            )
        if "auto_reset" in env_kwargs:
            raise ValueError(
                "Cannot override 'auto_reset' directly. "
                "Use AutoResetWrapper for auto-reset behavior."
            )

        env_kwargs["episode_length"] = jnp.inf
        env_kwargs["auto_reset"] = False
        env = brax_create(env_name, **env_kwargs)
        # brax_create may hand back a wrapped env. Unwrap it here so that
        # __post_init__ does not warn about wrapping the caller never asked for.
        if isinstance(env, BraxWrapper):
            env = env.unwrapped
        return cls(brax_env=env)

    @property
    def default_max_steps(self) -> int:
        """Brax's own default episode length, used for external truncation."""
        return _BRAX_DEFAULT_EPISODE_LENGTH

    def __post_init__(self) -> None:
        # User-supplied wrapped envs are stripped down to the raw env, because
        # wrapping (episode length, auto-reset, ...) is handled by envelope.
        if isinstance(self.brax_env, BraxWrapper):
            warnings.warn(
                "Environment wrapping should be handled by envelope. "
                "Unwrapping brax environment before converting..."
            )
            # Dataclass may be frozen, so bypass __setattr__.
            object.__setattr__(self, "brax_env", self.brax_env.unwrapped)

    @override
    def reset(self, key: Key) -> tuple[State, Info]:
        """Reset the brax env; the brax state itself is used as envelope state."""
        brax_state = self.brax_env.reset(key)
        info = InfoContainer(obs=brax_state.obs, reward=0.0, terminated=False)
        # Merge every dataclass field of the brax state into the info container.
        info = info.update(**dataclasses.asdict(brax_state))
        return brax_state, info

    @override
    def step(self, state: State, action: PyTree) -> tuple[State, Info]:
        """Advance the brax env one step; brax's ``done`` maps to ``terminated``."""
        brax_state = self.brax_env.step(state, action)
        info = InfoContainer(
            obs=brax_state.obs, reward=brax_state.reward, terminated=brax_state.done
        )
        info = info.update(**dataclasses.asdict(brax_state))
        return brax_state, info

    @override
    @cached_property
    def action_space(self) -> spaces.Space:
        # All brax environments have action limit of -1 to 1
        return spaces.Continuous.from_shape(
            low=-1.0, high=1.0, shape=(self.brax_env.action_size,)
        )

    @override
    @cached_property
    def observation_space(self) -> spaces.Space:
        # All brax environments have observation limit of -inf to inf
        return spaces.Continuous.from_shape(
            low=-jnp.inf, high=jnp.inf, shape=(self.brax_env.observation_size,)
        )

    def __deepcopy__(self, memo):
        warnings.warn(
            f"Trying to deepcopy {type(self).__name__}, which contains a brax env. "
            "Brax envs throw an error when deepcopying, so a shallow copy is returned.",
            category=RuntimeWarning,
            stacklevel=2,
        )
        return copy(self)
@@ -0,0 +1,86 @@
1
+ from functools import cached_property
2
+ from typing import Any, override
3
+
4
+ import jax
5
+ import jax.numpy as jnp
6
+ from craftax.craftax.craftax_state import EnvParams as CraftaxEnvParams
7
+ from craftax.craftax_classic.envs.craftax_state import (
8
+ EnvParams as CraftaxClassicEnvParams,
9
+ )
10
+ from craftax.craftax_env import make_craftax_env_from_name
11
+
12
+ from envelope import spaces as envelope_spaces
13
+ from envelope.compat.gymnax_envelope import _convert_space as _convert_gymnax_space
14
+ from envelope.environment import Environment, Info, InfoContainer, State
15
+ from envelope.struct import Container, static_field
16
+ from envelope.typing import Key, PyTree, TypeAlias
17
+
18
# Parameters accepted by CraftaxEnvelope: either full Craftax or Craftax-Classic
# environment params.
EnvParams: TypeAlias = CraftaxEnvParams | CraftaxClassicEnvParams
19
+
20
+
21
class CraftaxEnvelope(Environment):
    """Wrapper to convert a Craftax environment to an envelope environment."""

    craftax_env: Any = static_field()
    env_params: PyTree

    @classmethod
    def from_name(
        cls,
        env_name: str,
        env_params: EnvParams | None = None,
        env_kwargs: dict[str, Any] | None = None,
    ) -> "CraftaxEnvelope":
        """Create a ``CraftaxEnvelope`` from a Craftax environment name.

        The env is created with ``auto_reset=False``. When ``env_params`` is not
        given, the env's default params are used with ``max_timesteps=jnp.inf`` so
        episode length is controlled by ``TruncationWrapper``.

        Args:
            env_name: Name passed to ``make_craftax_env_from_name``.
            env_params: Optional explicit params, used as-is.
            env_kwargs: Extra keyword arguments for ``make_craftax_env_from_name``.
                The dict is copied, so the caller's mapping is never modified.

        Raises:
            ValueError: If ``env_kwargs`` sets ``max_timesteps`` or ``auto_reset``.
        """
        # Copy so the auto_reset override never leaks into the caller's dict.
        env_kwargs = dict(env_kwargs or {})
        if "max_timesteps" in env_kwargs:
            raise ValueError(
                "Cannot override 'max_timesteps' directly. "
                "Use TruncationWrapper for episode length control."
            )
        if "auto_reset" in env_kwargs:
            raise ValueError(
                "Cannot override 'auto_reset' directly. "
                "Use AutoResetWrapper for auto-reset behavior."
            )

        env_kwargs["auto_reset"] = False
        env = make_craftax_env_from_name(env_name, **env_kwargs)
        if env_params is None:
            env_params = env.default_params.replace(max_timesteps=jnp.inf)
        return cls(craftax_env=env, env_params=env_params)

    @property
    def default_max_steps(self) -> int:
        """Craftax's default ``max_timesteps``, used for external truncation."""
        return int(self.craftax_env.default_params.max_timesteps)

    @override
    def reset(self, key: Key) -> tuple[State, Info]:
        """Reset the env; the PRNG key is carried in the state for later steps."""
        key, subkey = jax.random.split(key)
        obs, env_state = self.craftax_env.reset(subkey, self.env_params)
        state = Container().update(key=key, env_state=env_state)
        info = InfoContainer(obs=obs, reward=0.0, terminated=False)
        # Expose an ``info`` entry on reset as well, matching ``step`` and the
        # gymnax/kinetix wrappers.
        info = info.update(info=None)
        return state, info

    @override
    def step(self, state: State, action: PyTree) -> tuple[State, Info]:
        """Advance one step, splitting a fresh subkey from the stored key."""
        key, subkey = jax.random.split(state.key)
        obs, env_state, reward, done, env_info = self.craftax_env.step(
            subkey, state.env_state, action, self.env_params
        )
        state = state.update(key=key, env_state=env_state)
        info = InfoContainer(obs=obs, reward=reward, terminated=done)
        info = info.update(info=env_info)
        return state, info

    @override
    @cached_property
    def action_space(self) -> envelope_spaces.Space:
        return _convert_gymnax_space(self.craftax_env.action_space(self.env_params))

    @override
    @cached_property
    def observation_space(self) -> envelope_spaces.Space:
        return _convert_gymnax_space(
            self.craftax_env.observation_space(self.env_params)
        )
@@ -0,0 +1,91 @@
1
+ from functools import cached_property
2
+ from typing import Any, override
3
+
4
+ import jax
5
+ import jax.numpy as jnp
6
+ from gymnax import make as gymnax_create
7
+ from gymnax.environments import spaces as gymnax_spaces
8
+ from gymnax.environments.environment import Environment as GymnaxEnv
9
+ from gymnax.environments.environment import EnvParams as GymnaxEnvParams
10
+
11
+ from envelope import spaces as envelope_spaces
12
+ from envelope.environment import Environment, Info, InfoContainer, State
13
+ from envelope.struct import Container, static_field
14
+ from envelope.typing import Key, PyTree
15
+
16
+
17
class GymnaxEnvelope(Environment):
    """Wrapper to convert a Gymnax environment to an envelope environment.

    Environment-side auto-reset is bypassed by stepping through ``step_env``. Use
    ``AutoResetWrapper`` and ``TruncationWrapper`` for reset/time-limit behavior.
    """

    gymnax_env: GymnaxEnv = static_field()
    env_params: PyTree

    @classmethod
    def from_name(
        cls,
        env_name: str,
        env_params: GymnaxEnvParams | None = None,
        env_kwargs: dict[str, Any] | None = None,
    ) -> "GymnaxEnvelope":
        """Create a ``GymnaxEnvelope`` from a gymnax environment id.

        Args:
            env_name: Id passed to ``gymnax.make``.
            env_params: Optional explicit params, used as-is. When omitted, the
                env's default params are used with ``max_steps_in_episode=jnp.inf``.
            env_kwargs: Extra keyword arguments for ``gymnax.make``.

        Raises:
            ValueError: If ``env_kwargs`` sets ``max_steps_in_episode``.
        """
        env_kwargs = env_kwargs or {}
        if "max_steps_in_episode" in env_kwargs:
            raise ValueError(
                "Cannot override 'max_steps_in_episode' directly. "
                "Use TruncationWrapper for episode length control."
            )
        gymnax_env, default_params = gymnax_create(env_name, **env_kwargs)
        if env_params is None:
            env_params = default_params.replace(max_steps_in_episode=jnp.inf)
        return cls(gymnax_env=gymnax_env, env_params=env_params)

    @property
    def default_max_steps(self) -> int:
        """Gymnax's default ``max_steps_in_episode``, used for external truncation."""
        return int(self.gymnax_env.default_params.max_steps_in_episode)

    @override
    def reset(self, key: Key) -> tuple[State, Info]:
        """Reset the env; the PRNG key is carried in the state for later steps."""
        key, subkey = jax.random.split(key)
        obs, env_state = self.gymnax_env.reset(subkey, self.env_params)
        state = Container().update(key=key, env_state=env_state)
        info = InfoContainer(obs=obs, reward=0.0, terminated=False)
        info = info.update(info=None)
        return state, info

    @override
    def step(self, state: State, action: PyTree) -> tuple[State, Info]:
        """Advance one step without any environment-side auto-reset."""
        key, subkey = jax.random.split(state.key)
        # Use ``step_env`` rather than gymnax's public ``step``: the latter also
        # resets the env and selects the reset obs/state whenever ``done`` is True.
        # That silently enables auto-reset and hides the terminal observation.
        obs, env_state, reward, done, env_info = self.gymnax_env.step_env(
            subkey, state.env_state, action, self.env_params
        )
        state = state.update(key=key, env_state=env_state)
        info = InfoContainer(obs=obs, reward=reward, terminated=done)
        info = info.update(info=env_info)
        return state, info

    @override
    @cached_property
    def action_space(self) -> envelope_spaces.Space:
        return _convert_space(self.gymnax_env.action_space(self.env_params))

    @override
    @cached_property
    def observation_space(self) -> envelope_spaces.Space:
        return _convert_space(self.gymnax_env.observation_space(self.env_params))
75
+
76
+
77
+ def _convert_space(gmx_space: gymnax_spaces.Space) -> envelope_spaces.Space:
78
+ if isinstance(gmx_space, gymnax_spaces.Box):
79
+ low = jnp.broadcast_to(gmx_space.low, gmx_space.shape).astype(gmx_space.dtype)
80
+ high = jnp.broadcast_to(gmx_space.high, gmx_space.shape).astype(gmx_space.dtype)
81
+ return envelope_spaces.Continuous(low=low, high=high)
82
+ elif isinstance(gmx_space, gymnax_spaces.Discrete):
83
+ n = jnp.broadcast_to(gmx_space.n, gmx_space.shape).astype(gmx_space.dtype)
84
+ return envelope_spaces.Discrete(n=n)
85
+ elif isinstance(gmx_space, gymnax_spaces.Tuple):
86
+ spaces = tuple(_convert_space(space) for space in gmx_space.spaces)
87
+ return envelope_spaces.PyTreeSpace(spaces)
88
+ elif isinstance(gmx_space, gymnax_spaces.Dict):
89
+ spaces = {k: _convert_space(space) for k, space in gmx_space.spaces.items()}
90
+ return envelope_spaces.PyTreeSpace(spaces)
91
+ raise ValueError(f"Unsupported space type: {type(gmx_space)}")
@@ -0,0 +1,127 @@
1
+ import warnings
2
+ from copy import copy
3
+ from functools import cached_property
4
+ from typing import Any, override
5
+
6
+ import jax.numpy as jnp
7
+ import jumanji
8
+ from jumanji.specs import Array, BoundedArray, DiscreteArray, MultiDiscreteArray
9
+ from jumanji.types import TimeStep as JumanjiTimeStep
10
+
11
+ from envelope import spaces as envelope_spaces
12
+ from envelope.environment import Environment, Info, InfoContainer, State
13
+ from envelope.struct import static_field
14
+ from envelope.typing import Key, PyTree
15
+
16
# Largest int32 value. JumanjiEnvelope.from_name passes it as `time_limit` to envs
# that expose one, so episode length is effectively controlled externally
# (e.g. by TruncationWrapper) rather than by the Jumanji env itself.
_MAX_INT = jnp.iinfo(jnp.int32).max
17
+
18
+
19
class JumanjiEnvelope(Environment):
    """Wrapper to convert a Jumanji environment to an envelope environment."""

    jumanji_env: Any = static_field()
    # The env's original `time_limit` (None if it has none), captured before it
    # is overridden with _MAX_INT.
    _default_time_limit: int | None = static_field(default=None)

    @classmethod
    def from_name(
        cls, env_name: str, env_kwargs: dict[str, Any] | None = None
    ) -> "JumanjiEnvelope":
        """Create a ``JumanjiEnvelope`` from a registered Jumanji environment name.

        If the env exposes a ``time_limit``, its default is recorded (reported via
        ``default_max_steps``) and the env is rebuilt with ``time_limit=_MAX_INT``.

        Args:
            env_name: Name passed to ``jumanji.make``.
            env_kwargs: Extra keyword arguments for ``jumanji.make``. The dict is
                copied, so the caller's mapping is never modified.

        Raises:
            ValueError: If ``env_kwargs`` sets ``time_limit``.
        """
        # Copy so the time_limit override never leaks into the caller's dict.
        env_kwargs = dict(env_kwargs or {})
        if "time_limit" in env_kwargs:
            raise ValueError(
                "Cannot override 'time_limit' directly. "
                "Use TruncationWrapper for episode length control."
            )

        # Create env first with defaults to capture default time_limit.
        env = jumanji.make(env_name, **env_kwargs)
        default_time_limit = getattr(env, "time_limit", None)

        # Only rebuild when the env actually has a time limit to disable;
        # otherwise the first instance is already what we want.
        if default_time_limit is not None:
            env_kwargs["time_limit"] = _MAX_INT
            env = jumanji.make(env_name, **env_kwargs)
        return cls(jumanji_env=env, _default_time_limit=default_time_limit)

    @property
    def default_max_steps(self) -> int | None:
        """The env's original ``time_limit``, or None if it had none."""
        return self._default_time_limit

    @override
    def reset(self, key: Key) -> tuple[State, Info]:
        """Reset the env; Jumanji's env state is used directly as envelope state."""
        env_state, timestep = self.jumanji_env.reset(key)
        info = convert_jumanji_to_envelope_info(timestep)
        return env_state, info

    @override
    def step(self, state: State, action: PyTree) -> tuple[State, Info]:
        """Advance one step and convert the resulting Jumanji ``TimeStep``."""
        env_state, timestep = self.jumanji_env.step(state, action)
        info = convert_jumanji_to_envelope_info(timestep)
        return env_state, info

    @override
    @cached_property
    def action_space(self) -> envelope_spaces.Space:
        return convert_jumanji_spec_to_envelope_space(self.jumanji_env.action_spec)

    @override
    @cached_property
    def observation_space(self) -> envelope_spaces.Space:
        return convert_jumanji_spec_to_envelope_space(self.jumanji_env.observation_spec)

    def __deepcopy__(self, memo):
        warnings.warn(
            f"Trying to deepcopy {type(self).__name__}, which contains a jumanji env. "
            "Jumanji envs may throw an error when deepcopying, so a shallow copy is "
            "returned.",
            category=RuntimeWarning,
            stacklevel=2,
        )
        return copy(self)
81
+
82
+
83
def convert_jumanji_to_envelope_info(timestep: JumanjiTimeStep) -> InfoContainer:
    """Translate a Jumanji ``TimeStep`` into an envelope ``InfoContainer``.

    ``terminated`` is taken from ``timestep.last()``, and every entry of
    ``timestep.extras`` is merged into the container as an extra field.
    """
    core = InfoContainer(
        obs=timestep.observation,
        reward=timestep.reward,
        terminated=timestep.last(),
    )
    return core.update(**timestep.extras)
88
+
89
+
90
def convert_jumanji_spec_to_envelope_space(spec: Any) -> envelope_spaces.Space:
    """Convert a Jumanji Spec to an envelope Space.

    Discrete specs map to Discrete, bounded arrays to Continuous with their
    bounds, and unbounded floating arrays to Continuous over (-inf, inf).
    Structured specs (with a ``_specs`` dict) and plain tuples/lists are converted
    recursively into a PyTreeSpace.

    Raises:
        NotImplementedError: For unbounded non-floating Array specs.
        ValueError: For any other unsupported spec type.
    """
    # Discrete kinds are tested before the generic bounded/unbounded array cases.
    if isinstance(spec, (DiscreteArray, MultiDiscreteArray)):
        num_values = jnp.asarray(spec.num_values, dtype=spec.dtype)
        spec_shape = getattr(spec, "shape", ())
        if spec_shape != () and spec_shape != num_values.shape:
            num_values = jnp.broadcast_to(num_values, spec.shape)
        return envelope_spaces.Discrete(n=num_values)

    if isinstance(spec, BoundedArray):
        bound_dtype = spec.dtype
        lower = jnp.broadcast_to(jnp.asarray(spec.minimum, dtype=bound_dtype), spec.shape)
        upper = jnp.broadcast_to(jnp.asarray(spec.maximum, dtype=bound_dtype), spec.shape)
        return envelope_spaces.Continuous(low=lower, high=upper)

    if isinstance(spec, Array):
        float_dtype = jnp.dtype(spec.dtype)
        if not jnp.issubdtype(float_dtype, jnp.floating):
            dtype = float_dtype
            raise NotImplementedError(
                "Unbounded jumanji Array specs are only supported for floating dtypes. "
                f"Got dtype={dtype} for spec={spec!r}."
            )
        return envelope_spaces.Continuous(
            low=jnp.full(spec.shape, -jnp.inf, dtype=float_dtype),
            high=jnp.full(spec.shape, jnp.inf, dtype=float_dtype),
        )

    # Structured specs (most Jumanji envs): access private mapping when available.
    # NOTE(review): this yields a plain dict tree even if the spec's values are
    # built by a constructor (e.g. a NamedTuple); confirm PyTreeSpace tolerates it.
    child_specs = getattr(spec, "_specs", None)
    if isinstance(child_specs, dict):
        return envelope_spaces.PyTreeSpace(
            {
                name: convert_jumanji_spec_to_envelope_space(child)
                for name, child in child_specs.items()
            }
        )

    if isinstance(spec, (tuple, list)):
        return envelope_spaces.PyTreeSpace(
            tuple(map(convert_jumanji_spec_to_envelope_space, spec))
        )

    raise ValueError(f"Unsupported spec type: {type(spec)}")
@@ -0,0 +1,194 @@
1
+ """Kinetix compatibility wrapper.
2
+
3
+ This module exposes Kinetix environments through the `envelope.environment.Environment`
4
+ API. It mirrors envelope's compat philosophy:
5
+ - prefer *no* environment-side auto-reset (use `AutoResetWrapper` in envelope)
6
+ - prefer *no* fixed episode time-limits (use `TruncationWrapper` in envelope)
7
+
8
+ `from_name` supports premade level ids like `s/h4_thrust_aim` (optionally with
9
+ `.json`). For maximum flexibility, users can bypass level handling entirely by
10
+ passing a custom `reset_fn`.
11
+ """
12
+
13
+ from __future__ import annotations
14
+
15
+ import warnings
16
+ from functools import cached_property
17
+ from typing import Any, Callable, Literal, override
18
+
19
+ import jax
20
+ import jax.numpy as jnp
21
+ from kinetix.environment import (
22
+ ActionType,
23
+ EnvParams,
24
+ ObservationType,
25
+ StaticEnvParams,
26
+ make_kinetix_env,
27
+ )
28
+ from kinetix.environment.ued.ued import make_reset_fn_sample_kinetix_level
29
+ from kinetix.util.saving import load_from_json_file
30
+
31
+ from envelope import spaces as envelope_spaces
32
+ from envelope.compat.gymnax_envelope import _convert_space as _convert_gymnax_space
33
+ from envelope.environment import Environment, Info, InfoContainer, State
34
+ from envelope.struct import Container, static_field
35
+ from envelope.typing import Key, PyTree
36
+
37
# Signature of a level reset function: takes a PRNG key and returns a Kinetix
# level (compare the `reset_fn` built in `KinetixEnvelope.create_premade`).
LevelResetFn = Callable[[Key], Any]
38
+
39
+
40
+ def _normalize_level_id(level_id: str) -> str:
41
+ """Normalize a path-like level id.
42
+
43
+ Examples:
44
+ - ``"s/h4_thrust_aim"`` -> ``"s/h4_thrust_aim.json"``
45
+ - ``"/s/h4_thrust_aim.json"`` -> ``"s/h4_thrust_aim.json"``
46
+ """
47
+ level_id = level_id.strip().lstrip("/")
48
+ if not level_id:
49
+ raise ValueError("level_id must be a non-empty string")
50
+ if level_id.endswith("/"):
51
+ raise ValueError("level_id must not end with '/'")
52
+ if not level_id.endswith(".json"):
53
+ level_id = f"{level_id}.json"
54
+ return level_id
55
+
56
+
57
+ def _warn_auto_reset(auto_reset: bool) -> None:
58
+ if auto_reset:
59
+ warnings.warn(
60
+ "Creating a KinetixEnvelope with auto_reset=True is not recommended, use "
61
+ "an AutoResetWrapper instead.",
62
+ stacklevel=2,
63
+ )
64
+
65
+
66
class KinetixEnvelope(Environment):
    """Wrapper to convert a Kinetix environment to an envelope environment."""

    kinetix_env: Any = static_field()
    env_params: Any

    @property
    def default_max_steps(self) -> int:
        """``max_timesteps`` of a default ``EnvParams``, used for external truncation."""
        return int(EnvParams().max_timesteps)

    @classmethod
    def from_name(
        cls,
        env_name: str | Literal["random"],
        env_params: EnvParams | None = None,
        env_kwargs: dict[str, Any] | None = None,
    ) -> "KinetixEnvelope":
        """Create a ``KinetixEnvelope`` from a premade level id or ``"random"``.

        Args:
            env_name: A premade level id (see ``_normalize_level_id``), or
                ``"random"`` to sample levels via ``create_random``.
            env_params: Only allowed for ``"random"``; forwarded to ``create_random``.
            env_kwargs: Extra keyword arguments for ``create_random`` /
                ``create_premade``. The dict is copied, so the caller's mapping is
                never modified.

        Raises:
            ValueError: If ``env_kwargs`` sets ``max_timesteps`` or ``auto_reset``,
                or if params are passed for a premade level.
        """
        # Copy so the auto_reset override never leaks into the caller's dict.
        env_kwargs = dict(env_kwargs or {})
        if "max_timesteps" in env_kwargs:
            raise ValueError(
                "Cannot override 'max_timesteps' directly. "
                "Use TruncationWrapper for episode length control."
            )
        if "auto_reset" in env_kwargs:
            raise ValueError(
                "Cannot override 'auto_reset' directly. "
                "Use AutoResetWrapper for auto-reset behavior."
            )

        env_kwargs["auto_reset"] = False
        if env_name == "random":
            return cls.create_random(env_params=env_params, **env_kwargs)

        # Premade levels ship their own params in the level file.
        if (
            env_params is not None
            or "env_params" in env_kwargs
            or "static_env_params" in env_kwargs
        ):
            raise ValueError(
                "env_params and static_env_params cannot be passed when creating a "
                "KinetixEnvelope from a premade level."
            )
        return cls.create_premade(env_name, **env_kwargs)

    @classmethod
    def create_premade(
        cls,
        env_name: str,
        action_type: ActionType = ActionType.CONTINUOUS,
        observation_type: ObservationType = ObservationType.SYMBOLIC_FLAT,
        auto_reset: bool = False,
    ) -> "KinetixEnvelope":
        """Build an env that always resets to the premade level ``env_name``.

        The level's stored params are used with ``max_timesteps=jnp.inf``.
        """
        _warn_auto_reset(auto_reset)

        # Load level.
        level_id_json = _normalize_level_id(env_name)
        level, static_env_params, env_params = load_from_json_file(level_id_json)
        # Disable Kinetix's built-in time limit; truncation is handled by envelope.
        if env_params is not None:
            env_params = env_params.replace(max_timesteps=jnp.inf)

        def reset_fn(_: Key) -> Any:
            # Every reset returns the same loaded level, ignoring the key.
            return level

        # Create environment.
        kinetix_env = make_kinetix_env(
            action_type=action_type,
            observation_type=observation_type,
            reset_fn=reset_fn,
            env_params=env_params,
            static_env_params=static_env_params,
            auto_reset=auto_reset,
        )
        return cls(kinetix_env=kinetix_env, env_params=env_params)

    @classmethod
    def create_random(
        cls,
        action_type: ActionType = ActionType.CONTINUOUS,
        observation_type: ObservationType = ObservationType.SYMBOLIC_FLAT,
        env_params: EnvParams | None = None,
        static_env_params: StaticEnvParams = StaticEnvParams(),
        auto_reset: bool = False,
    ) -> "KinetixEnvelope":
        """Build an env that samples a random Kinetix level on every reset.

        ``env_params`` defaults to ``EnvParams()``. In all cases its
        ``max_timesteps`` is replaced with ``jnp.inf``.
        """
        _warn_auto_reset(auto_reset)
        if env_params is None:
            env_params = EnvParams()
        env_params = env_params.replace(max_timesteps=jnp.inf)

        reset_fn = make_reset_fn_sample_kinetix_level(env_params, static_env_params)
        kinetix_env = make_kinetix_env(
            action_type=action_type,
            observation_type=observation_type,
            reset_fn=reset_fn,
            env_params=env_params,
            static_env_params=static_env_params,
            auto_reset=auto_reset,
        )
        return cls(kinetix_env=kinetix_env, env_params=env_params)

    @override
    def reset(self, key: Key) -> tuple[State, Info]:
        """Reset the env; the PRNG key is carried in the state for later steps."""
        key, subkey = jax.random.split(key)
        obs, env_state = self.kinetix_env.reset(subkey, self.env_params)
        state_out = Container().update(key=key, env_state=env_state)
        info = InfoContainer(obs=obs, reward=0.0, terminated=False)
        info = info.update(info=None)
        return state_out, info

    @override
    def step(self, state: State, action: PyTree) -> tuple[State, Info]:
        """Advance one step, splitting a fresh subkey from the stored key."""
        key, subkey = jax.random.split(state.key)
        obs, env_state, reward, done, env_info = self.kinetix_env.step(
            subkey, state.env_state, action, self.env_params
        )
        state_out = state.update(key=key, env_state=env_state)
        info = InfoContainer(obs=obs, reward=reward, terminated=done)
        info = info.update(info=env_info)
        return state_out, info

    @override
    @cached_property
    def action_space(self) -> envelope_spaces.Space:
        return _convert_gymnax_space(self.kinetix_env.action_space(self.env_params))

    @override
    @cached_property
    def observation_space(self) -> envelope_spaces.Space:
        return _convert_gymnax_space(
            self.kinetix_env.observation_space(self.env_params)
        )