jax-envelope 0.1.1__tar.gz → 0.2.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (77)
  1. {jax_envelope-0.1.1 → jax_envelope-0.2.0}/.gitignore +2 -1
  2. {jax_envelope-0.1.1 → jax_envelope-0.2.0}/PKG-INFO +2 -2
  3. {jax_envelope-0.1.1 → jax_envelope-0.2.0}/README.md +1 -1
  4. {jax_envelope-0.1.1 → jax_envelope-0.2.0}/pyproject.toml +13 -6
  5. {jax_envelope-0.1.1 → jax_envelope-0.2.0}/src/envelope/__init__.py +16 -4
  6. {jax_envelope-0.1.1 → jax_envelope-0.2.0}/src/envelope/compat/brax_envelope.py +5 -3
  7. {jax_envelope-0.1.1 → jax_envelope-0.2.0}/src/envelope/compat/craftax_envelope.py +17 -2
  8. {jax_envelope-0.1.1 → jax_envelope-0.2.0}/src/envelope/compat/gymnax_envelope.py +34 -7
  9. {jax_envelope-0.1.1 → jax_envelope-0.2.0}/src/envelope/compat/jumanji_envelope.py +3 -2
  10. {jax_envelope-0.1.1 → jax_envelope-0.2.0}/src/envelope/compat/kinetix_envelope.py +3 -2
  11. {jax_envelope-0.1.1 → jax_envelope-0.2.0}/src/envelope/compat/mujoco_playground_envelope.py +1 -1
  12. {jax_envelope-0.1.1 → jax_envelope-0.2.0}/src/envelope/compat/navix_envelope.py +1 -1
  13. {jax_envelope-0.1.1 → jax_envelope-0.2.0}/src/envelope/environment.py +16 -9
  14. {jax_envelope-0.1.1 → jax_envelope-0.2.0}/src/envelope/spaces.py +41 -21
  15. {jax_envelope-0.1.1 → jax_envelope-0.2.0}/src/envelope/struct.py +10 -1
  16. jax_envelope-0.2.0/src/envelope/wrappers/__init__.py +36 -0
  17. jax_envelope-0.2.0/src/envelope/wrappers/autoreset_wrapper.py +80 -0
  18. jax_envelope-0.2.0/src/envelope/wrappers/clip_action_wrapper.py +27 -0
  19. jax_envelope-0.2.0/src/envelope/wrappers/continuous_observation_wrapper.py +61 -0
  20. jax_envelope-0.2.0/src/envelope/wrappers/episode_statistics_wrapper.py +40 -0
  21. jax_envelope-0.2.0/src/envelope/wrappers/flatten_action_wrapper.py +75 -0
  22. jax_envelope-0.2.0/src/envelope/wrappers/flatten_observation_wrapper.py +81 -0
  23. {jax_envelope-0.1.1 → jax_envelope-0.2.0}/src/envelope/wrappers/normalization.py +1 -1
  24. {jax_envelope-0.1.1 → jax_envelope-0.2.0}/src/envelope/wrappers/observation_normalization_wrapper.py +28 -16
  25. jax_envelope-0.2.0/src/envelope/wrappers/pooled_init_vmap_wrapper.py +122 -0
  26. {jax_envelope-0.1.1 → jax_envelope-0.2.0}/src/envelope/wrappers/state_injection_wrapper.py +18 -22
  27. jax_envelope-0.2.0/src/envelope/wrappers/truncation_wrapper.py +35 -0
  28. {jax_envelope-0.1.1 → jax_envelope-0.2.0}/src/envelope/wrappers/vmap_envs_wrapper.py +26 -21
  29. jax_envelope-0.2.0/src/envelope/wrappers/vmap_wrapper.py +66 -0
  30. {jax_envelope-0.1.1 → jax_envelope-0.2.0}/src/envelope/wrappers/wrapper.py +8 -8
  31. {jax_envelope-0.1.1 → jax_envelope-0.2.0}/tests/compat/conftest.py +1 -1
  32. {jax_envelope-0.1.1 → jax_envelope-0.2.0}/tests/compat/contract.py +5 -2
  33. {jax_envelope-0.1.1 → jax_envelope-0.2.0}/tests/compat/test_brax_compat.py +4 -4
  34. {jax_envelope-0.1.1 → jax_envelope-0.2.0}/tests/compat/test_craftax_compat.py +4 -2
  35. {jax_envelope-0.1.1 → jax_envelope-0.2.0}/tests/compat/test_create.py +12 -0
  36. {jax_envelope-0.1.1 → jax_envelope-0.2.0}/tests/compat/test_create_integration.py +7 -7
  37. {jax_envelope-0.1.1 → jax_envelope-0.2.0}/tests/compat/test_gymnax_compat.py +3 -3
  38. {jax_envelope-0.1.1 → jax_envelope-0.2.0}/tests/compat/test_jumanji_compat.py +1 -1
  39. {jax_envelope-0.1.1 → jax_envelope-0.2.0}/tests/compat/test_kinetix_compat.py +7 -9
  40. {jax_envelope-0.1.1 → jax_envelope-0.2.0}/tests/compat/test_mujoco_playground_compat.py +6 -6
  41. {jax_envelope-0.1.1 → jax_envelope-0.2.0}/tests/compat/test_navix_compat.py +5 -5
  42. {jax_envelope-0.1.1 → jax_envelope-0.2.0}/tests/spaces/test_batched_space.py +74 -50
  43. {jax_envelope-0.1.1 → jax_envelope-0.2.0}/tests/spaces/test_pytree_space.py +21 -1
  44. {jax_envelope-0.1.1 → jax_envelope-0.2.0}/tests/wrappers/helpers.py +122 -36
  45. {jax_envelope-0.1.1 → jax_envelope-0.2.0}/tests/wrappers/test_autoreset_wrapper.py +164 -67
  46. jax_envelope-0.2.0/tests/wrappers/test_clip_action_wrapper.py +174 -0
  47. jax_envelope-0.2.0/tests/wrappers/test_continuous_observation_wrapper.py +153 -0
  48. {jax_envelope-0.1.1 → jax_envelope-0.2.0}/tests/wrappers/test_environment_wrapper.py +12 -12
  49. jax_envelope-0.2.0/tests/wrappers/test_episode_statistics_wrapper.py +183 -0
  50. jax_envelope-0.2.0/tests/wrappers/test_flatten_action_wrapper.py +215 -0
  51. jax_envelope-0.2.0/tests/wrappers/test_flatten_observation_wrapper.py +178 -0
  52. {jax_envelope-0.1.1 → jax_envelope-0.2.0}/tests/wrappers/test_observation_normalization_wrapper.py +9 -9
  53. jax_envelope-0.2.0/tests/wrappers/test_pooled_init_vmap_wrapper.py +292 -0
  54. {jax_envelope-0.1.1 → jax_envelope-0.2.0}/tests/wrappers/test_state_injection_wrapper.py +15 -15
  55. {jax_envelope-0.1.1 → jax_envelope-0.2.0}/tests/wrappers/test_truncation_wrapper.py +39 -14
  56. {jax_envelope-0.1.1 → jax_envelope-0.2.0}/tests/wrappers/test_vmap_envs_wrapper.py +7 -7
  57. {jax_envelope-0.1.1 → jax_envelope-0.2.0}/tests/wrappers/test_vmap_wrapper.py +18 -18
  58. {jax_envelope-0.1.1 → jax_envelope-0.2.0}/uv.lock +818 -711
  59. jax_envelope-0.1.1/src/envelope/wrappers/__init__.py +0 -20
  60. jax_envelope-0.1.1/src/envelope/wrappers/autoreset_wrapper.py +0 -36
  61. jax_envelope-0.1.1/src/envelope/wrappers/episode_statistics_wrapper.py +0 -47
  62. jax_envelope-0.1.1/src/envelope/wrappers/truncation_wrapper.py +0 -31
  63. jax_envelope-0.1.1/src/envelope/wrappers/vmap_wrapper.py +0 -51
  64. {jax_envelope-0.1.1 → jax_envelope-0.2.0}/.github/workflows/publish.yml +0 -0
  65. {jax_envelope-0.1.1 → jax_envelope-0.2.0}/LICENSE +0 -0
  66. {jax_envelope-0.1.1 → jax_envelope-0.2.0}/src/envelope/compat/__init__.py +0 -0
  67. {jax_envelope-0.1.1 → jax_envelope-0.2.0}/src/envelope/typing.py +0 -0
  68. {jax_envelope-0.1.1 → jax_envelope-0.2.0}/tests/__init__.py +0 -0
  69. {jax_envelope-0.1.1 → jax_envelope-0.2.0}/tests/compat/__init__.py +0 -0
  70. {jax_envelope-0.1.1 → jax_envelope-0.2.0}/tests/spaces/__init__.py +0 -0
  71. {jax_envelope-0.1.1 → jax_envelope-0.2.0}/tests/spaces/test_continuous.py +0 -0
  72. {jax_envelope-0.1.1 → jax_envelope-0.2.0}/tests/spaces/test_discrete.py +0 -0
  73. {jax_envelope-0.1.1 → jax_envelope-0.2.0}/tests/spaces/test_serialization.py +0 -0
  74. {jax_envelope-0.1.1 → jax_envelope-0.2.0}/tests/test_container.py +0 -0
  75. {jax_envelope-0.1.1 → jax_envelope-0.2.0}/tests/test_struct.py +0 -0
  76. {jax_envelope-0.1.1 → jax_envelope-0.2.0}/tests/wrappers/__init__.py +0 -0
  77. {jax_envelope-0.1.1 → jax_envelope-0.2.0}/tests/wrappers/test_normalization.py +0 -0
@@ -166,7 +166,8 @@ wandb/
166
166
  # ruff
167
167
  .ruff_cache/
168
168
 
169
- # Cursor
169
+ # Vibecoding
170
170
  .cursor
171
171
  .cursorignore
172
172
  AGENTS.md
173
+ CLAUDE.md
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: jax-envelope
3
- Version: 0.1.1
3
+ Version: 0.2.0
4
4
  Summary: A JAX-native environment interface with powerful wrappers and adapters for popular RL environment suites
5
5
  Project-URL: Homepage, https://github.com/keraJLi/envelope
6
6
  Project-URL: Repository, https://github.com/keraJLi/envelope
@@ -82,6 +82,6 @@ pip install jax-envelope
82
82
  ```
83
83
 
84
84
  ## 💞 Related projects
85
- * [stoax](https://github.com/EdanToledo/Stoa) is a very similar project that provides adapters and wrappers for the jumanji-like interface.
85
+ * [stoa](https://github.com/EdanToledo/Stoa) is a very similar project that provides adapters and wrappers for the jumanji-like interface.
86
86
  * Check out all the great suites we have adapters for! [gymnax](https://github.com/RobertTLange/gymnax), [brax](https://github.com/google/brax), [jumanji](https://github.com/instadeepai/jumanji), [kinetix](https://github.com/flairox/kinetix), [craftax](https://github.com/MichaelTMatthews/craftax), [mujoco_playground](https://github.com/google-deepmind/mujoco_playground).
87
87
  * We will be adding support for [jaxmarl](https://github.com/flairox/jaxmarl) and [pgx](https://github.com/sotetsuk/pgx) in the future, as soon as we figured out the best ever MARL interface for JAX!
@@ -56,6 +56,6 @@ pip install jax-envelope
56
56
  ```
57
57
 
58
58
  ## 💞 Related projects
59
- * [stoax](https://github.com/EdanToledo/Stoa) is a very similar project that provides adapters and wrappers for the jumanji-like interface.
59
+ * [stoa](https://github.com/EdanToledo/Stoa) is a very similar project that provides adapters and wrappers for the jumanji-like interface.
60
60
  * Check out all the great suites we have adapters for! [gymnax](https://github.com/RobertTLange/gymnax), [brax](https://github.com/google/brax), [jumanji](https://github.com/instadeepai/jumanji), [kinetix](https://github.com/flairox/kinetix), [craftax](https://github.com/MichaelTMatthews/craftax), [mujoco_playground](https://github.com/google-deepmind/mujoco_playground).
61
61
  * We will be adding support for [jaxmarl](https://github.com/flairox/jaxmarl) and [pgx](https://github.com/sotetsuk/pgx) in the future, as soon as we figured out the best ever MARL interface for JAX!
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "jax-envelope"
3
- version = "0.1.1"
3
+ version = "0.2.0"
4
4
  description = "A JAX-native environment interface with powerful wrappers and adapters for popular RL environment suites"
5
5
  readme = "README.md"
6
6
  requires-python = ">=3.12"
@@ -51,18 +51,15 @@ allow-direct-references = true
51
51
  [tool.hatch.build.targets.wheel]
52
52
  packages = ["src/envelope"]
53
53
 
54
- [tool.uv.sources]
55
- gymnax = { git = "https://github.com/RobertTLange/gymnax" }
56
-
57
54
  [dependency-groups]
58
55
  compat = [
59
- "gymnax @ git+https://github.com/RobertTLange/gymnax@main",
60
56
  "brax>=0.13.0",
61
57
  "craftax>=1.4.3",
62
58
  "navix>=0.7.0",
63
59
  "jumanji>=1.0.1",
64
- "kinetix-env>=2.0.0",
65
60
  "playground>=0.1.0",
61
+ "gymnax",
62
+ "kinetix-env",
66
63
  ]
67
64
  dev = [
68
65
  "hypothesis>=6.148.1",
@@ -71,7 +68,17 @@ dev = [
71
68
  "ruff>=0.14.2",
72
69
  ]
73
70
 
71
+ [tool.uv]
72
+ override-dependencies = [
73
+ "tensorflow-probability>=0.26.0.dev20260116",
74
+ ]
75
+
76
+ [tool.uv.sources]
77
+ gymnax = { git = "https://github.com/RobertTLange/gymnax.git" }
78
+ kinetix-env = { git = "https://github.com/FLAIROx/Kinetix.git" }
79
+
74
80
  [tool.pytest.ini_options]
81
+ testpaths = ["tests"]
75
82
  markers = [
76
83
  "compat: tests requiring optional compat dependencies",
77
84
  ]
@@ -1,16 +1,22 @@
1
1
  from envelope.compat import create
2
2
  from envelope.environment import Environment, Info, InfoContainer
3
3
  from envelope.spaces import BatchedSpace, Continuous, Discrete, PyTreeSpace, Space
4
- from envelope.struct import field, static_field, FrozenPyTreeNode, Container
4
+ from envelope.struct import Container, FrozenPyTreeNode, field, static_field
5
5
  from envelope.wrappers import (
6
- Wrapper,
7
- WrappedState,
8
6
  AutoResetWrapper,
7
+ ClipActionWrapper,
8
+ ContinuousObservationWrapper,
9
+ EpisodeStatisticsWrapper,
10
+ FlattenActionWrapper,
11
+ FlattenObservationWrapper,
9
12
  ObservationNormalizationWrapper,
13
+ PooledInitVmapWrapper,
10
14
  StateInjectionWrapper,
11
15
  TruncationWrapper,
12
- VmapWrapper,
13
16
  VmapEnvsWrapper,
17
+ VmapWrapper,
18
+ WrappedState,
19
+ Wrapper,
14
20
  )
15
21
 
16
22
  __all__ = [
@@ -34,7 +40,13 @@ __all__ = [
34
40
  "Wrapper",
35
41
  "WrappedState",
36
42
  "AutoResetWrapper",
43
+ "ClipActionWrapper",
44
+ "ContinuousObservationWrapper",
45
+ "EpisodeStatisticsWrapper",
46
+ "FlattenActionWrapper",
47
+ "FlattenObservationWrapper",
37
48
  "ObservationNormalizationWrapper",
49
+ "PooledInitVmapWrapper",
38
50
  "StateInjectionWrapper",
39
51
  "TruncationWrapper",
40
52
  "VmapWrapper",
@@ -48,7 +48,7 @@ class BraxEnvelope(Environment):
48
48
  def default_max_steps(self) -> int:
49
49
  return _BRAX_DEFAULT_EPISODE_LENGTH
50
50
 
51
- def __post_init__(self) -> "BraxEnvelope":
51
+ def __post_init__(self):
52
52
  if isinstance(self.brax_env, BraxWrapper):
53
53
  warnings.warn(
54
54
  "Environment wrapping should be handled by envelope. "
@@ -57,7 +57,7 @@ class BraxEnvelope(Environment):
57
57
  object.__setattr__(self, "brax_env", self.brax_env.unwrapped)
58
58
 
59
59
  @override
60
- def reset(self, key: Key) -> tuple[State, Info]:
60
+ def init(self, key: Key) -> tuple[State, Info]:
61
61
  brax_state = self.brax_env.reset(key)
62
62
  info = InfoContainer(obs=brax_state.obs, reward=0.0, terminated=False)
63
63
  info = info.update(**dataclasses.asdict(brax_state))
@@ -67,7 +67,9 @@ class BraxEnvelope(Environment):
67
67
  def step(self, state: State, action: PyTree) -> tuple[State, Info]:
68
68
  brax_state = self.brax_env.step(state, action)
69
69
  info = InfoContainer(
70
- obs=brax_state.obs, reward=brax_state.reward, terminated=brax_state.done
70
+ obs=brax_state.obs,
71
+ reward=brax_state.reward,
72
+ terminated=jnp.asarry(brax_state.done, dtype=bool).item(),
71
73
  )
72
74
  info = info.update(**dataclasses.asdict(brax_state))
73
75
  return brax_state, info
@@ -22,7 +22,7 @@ class CraftaxEnvelope(Environment):
22
22
  """Wrapper to convert a Craftax environment to a envelope environment."""
23
23
 
24
24
  craftax_env: Any = static_field()
25
- env_params: PyTree
25
+ env_params: PyTree = static_field() # TODO: remove static marker as soon as craftax merges https://github.com/MichaelTMatthews/Craftax/pull/48
26
26
 
27
27
  @classmethod
28
28
  def from_name(
@@ -54,12 +54,27 @@ class CraftaxEnvelope(Environment):
54
54
  def default_max_steps(self) -> int:
55
55
  return int(self.craftax_env.default_params.max_timesteps)
56
56
 
57
+ @cached_property
58
+ def _craftax_info_placeholder(self) -> PyTree:
59
+ key = jax.random.PRNGKey(0)
60
+ _, state = self.craftax_env.reset(key, self.env_params)
61
+ _, _, _, _, info = self.craftax_env.step(
62
+ key,
63
+ state,
64
+ self.craftax_env.action_space(self.env_params).sample(key),
65
+ self.env_params,
66
+ )
67
+ return jax.tree.map(lambda x: jnp.full_like(x, jnp.nan), info)
68
+
57
69
  @override
58
- def reset(self, key: Key) -> tuple[State, Info]:
70
+ def init(self, key: Key) -> tuple[State, Info]:
71
+ # TODO: this function does not add env_info (or comparable) to the info
72
+ # container. We should add tests for this (and all other envelopes) and fix it.
59
73
  key, subkey = jax.random.split(key)
60
74
  obs, env_state = self.craftax_env.reset(subkey, self.env_params)
61
75
  state = Container().update(key=key, env_state=env_state)
62
76
  info = InfoContainer(obs=obs, reward=0.0, terminated=False)
77
+ info = info.update(info=self._craftax_info_placeholder)
63
78
  return state, info
64
79
 
65
80
  @override
@@ -1,5 +1,5 @@
1
1
  from functools import cached_property
2
- from typing import Any, override
2
+ from typing import Any, Callable, cast, override
3
3
 
4
4
  import jax
5
5
  import jax.numpy as jnp
@@ -10,15 +10,24 @@ from gymnax.environments.environment import EnvParams as GymnaxEnvParams
10
10
 
11
11
  from envelope import spaces as envelope_spaces
12
12
  from envelope.environment import Environment, Info, InfoContainer, State
13
- from envelope.struct import Container, static_field
13
+ from envelope.struct import Container, field, static_field
14
14
  from envelope.typing import Key, PyTree
15
15
 
16
+ _GymnaxReset = Callable[
17
+ [Key, GymnaxEnvParams],
18
+ tuple[PyTree, Any],
19
+ ]
20
+ _GymnaxStep = Callable[
21
+ [Key, Any, PyTree, GymnaxEnvParams],
22
+ tuple[PyTree, Any, jnp.ndarray, jnp.ndarray, PyTree],
23
+ ]
24
+
16
25
 
17
26
  class GymnaxEnvelope(Environment):
18
27
  """Wrapper to convert a Gymnax environment to a envelope environment."""
19
28
 
20
29
  gymnax_env: GymnaxEnv = static_field()
21
- env_params: PyTree
30
+ env_params: PyTree = field()
22
31
 
23
32
  @classmethod
24
33
  def from_name(
@@ -43,19 +52,37 @@ class GymnaxEnvelope(Environment):
43
52
  def default_max_steps(self) -> int:
44
53
  return int(self.gymnax_env.default_params.max_steps_in_episode)
45
54
 
55
+ @cached_property
56
+ def _gymnax_info_placeholder(self) -> PyTree:
57
+ reset_fn = cast(_GymnaxReset, self.gymnax_env.reset)
58
+ step_fn = cast(_GymnaxStep, self.gymnax_env.step)
59
+
60
+ key = jax.random.PRNGKey(0)
61
+ _, state = reset_fn(key, self.env_params)
62
+ _, _, _, _, info = step_fn(
63
+ key,
64
+ state,
65
+ self.gymnax_env.action_space(self.env_params).sample(key),
66
+ self.env_params,
67
+ )
68
+ return jax.tree.map(lambda x: jnp.full_like(x, jnp.nan, dtype=float), info)
69
+
46
70
  @override
47
- def reset(self, key: Key) -> tuple[State, Info]:
71
+ def init(self, key: Key) -> tuple[State, Info]:
72
+ reset_fn = cast(_GymnaxReset, self.gymnax_env.reset)
73
+
48
74
  key, subkey = jax.random.split(key)
49
- obs, env_state = self.gymnax_env.reset(subkey, self.env_params)
75
+ obs, env_state = reset_fn(subkey, self.env_params)
50
76
  state = Container().update(key=key, env_state=env_state)
51
77
  info = InfoContainer(obs=obs, reward=0.0, terminated=False)
52
- info = info.update(info=None)
78
+ info = info.update(info=self._gymnax_info_placeholder)
53
79
  return state, info
54
80
 
55
81
  @override
56
82
  def step(self, state: State, action: PyTree) -> tuple[State, Info]:
57
83
  key, subkey = jax.random.split(state.key)
58
- obs, env_state, reward, done, env_info = self.gymnax_env.step(
84
+ step_fn = cast(_GymnaxStep, self.gymnax_env.step)
85
+ obs, env_state, reward, done, env_info = step_fn(
59
86
  subkey, state.env_state, action, self.env_params
60
87
  )
61
88
  state = state.update(key=key, env_state=env_state)
@@ -48,7 +48,7 @@ class JumanjiEnvelope(Environment):
48
48
  return self._default_time_limit
49
49
 
50
50
  @override
51
- def reset(self, key: Key) -> tuple[State, Info]:
51
+ def init(self, key: Key) -> tuple[State, Info]:
52
52
  env_state, timestep = self.jumanji_env.reset(key)
53
53
  info = convert_jumanji_to_envelope_info(timestep)
54
54
  return env_state, info
@@ -81,8 +81,9 @@ class JumanjiEnvelope(Environment):
81
81
 
82
82
 
83
83
  def convert_jumanji_to_envelope_info(timestep: JumanjiTimeStep) -> InfoContainer:
84
+ term = jnp.asarray(timestep.last(), dtype=bool).item()
84
85
  info = InfoContainer(
85
- obs=timestep.observation, reward=timestep.reward, terminated=timestep.last()
86
+ obs=timestep.observation, reward=timestep.reward, terminated=term
86
87
  ).update(**timestep.extras)
87
88
  return info
88
89
 
@@ -28,6 +28,7 @@ from kinetix.environment import (
28
28
  from kinetix.environment.ued.ued import make_reset_fn_sample_kinetix_level
29
29
  from kinetix.util.saving import load_from_json_file
30
30
 
31
+ from envelope import field
31
32
  from envelope import spaces as envelope_spaces
32
33
  from envelope.compat.gymnax_envelope import _convert_space as _convert_gymnax_space
33
34
  from envelope.environment import Environment, Info, InfoContainer, State
@@ -67,7 +68,7 @@ class KinetixEnvelope(Environment):
67
68
  """Wrapper to convert a Kinetix environment to a envelope environment."""
68
69
 
69
70
  kinetix_env: Any = static_field()
70
- env_params: Any
71
+ env_params: Any = field()
71
72
 
72
73
  @property
73
74
  def default_max_steps(self) -> int:
@@ -162,7 +163,7 @@ class KinetixEnvelope(Environment):
162
163
  return cls(kinetix_env=kinetix_env, env_params=env_params)
163
164
 
164
165
  @override
165
- def reset(self, key: Key) -> tuple[State, Info]:
166
+ def init(self, key: Key) -> tuple[State, Info]:
166
167
  key, subkey = jax.random.split(key)
167
168
  obs, env_state = self.kinetix_env.reset(subkey, self.env_params)
168
169
  state_out = Container().update(key=key, env_state=env_state)
@@ -56,7 +56,7 @@ class MujocoPlaygroundEnvelope(Environment):
56
56
  return self._default_max_steps
57
57
 
58
58
  @override
59
- def reset(self, key: Key) -> tuple[State, Info]:
59
+ def init(self, key: Key) -> tuple[State, Info]:
60
60
  env_state = self.mujoco_playground_env.reset(key)
61
61
  info = InfoContainer(obs=env_state.obs, reward=0.0, terminated=False)
62
62
  info = info.update(**dataclasses.asdict(env_state))
@@ -38,7 +38,7 @@ class NavixEnvelope(Environment):
38
38
  return _NAVIX_DEFAULT_MAX_STEPS
39
39
 
40
40
  @override
41
- def reset(self, key: Key) -> tuple[State, Info]:
41
+ def init(self, key: Key) -> tuple[State, Info]:
42
42
  timestep = self.navix_env.reset(key)
43
43
  return timestep, convert_navix_to_envelope_info(timestep)
44
44
 
@@ -5,7 +5,7 @@ from typing import Protocol, runtime_checkable
5
5
 
6
6
  from envelope import spaces
7
7
  from envelope.struct import Container, FrozenPyTreeNode
8
- from envelope.typing import Key, PyTree
8
+ from envelope.typing import Array, Key, PyTree
9
9
 
10
10
  __all__ = ["Environment", "State", "Info", "InfoContainer"]
11
11
 
@@ -23,7 +23,7 @@ class Info(Protocol):
23
23
 
24
24
  class InfoContainer(Container):
25
25
  obs: PyTree
26
- reward: float
26
+ reward: float | Array
27
27
  terminated: bool
28
28
  truncated: bool = field(default=False)
29
29
 
@@ -38,18 +38,25 @@ class Environment(ABC, FrozenPyTreeNode):
38
38
 
39
39
  State is an opaque PyTree owned by each environment; wrappers that stack
40
40
  environments should expose their wrapped env state as `inner_state` while
41
- adding any wrapper-specific fields. `reset` may optionally receive a prior
42
- state (for cross-episode persistence) and arbitrary **kwargs that wrappers
43
- or environments can use.
41
+ adding any wrapper-specific fields.
42
+
43
+ Two distinct lifecycle methods:
44
+ init(key) - Initialize environment and all state from scratch.
45
+ reset(key, state) - Reset the inner environment while preserving
46
+ episode-persistent state.
44
47
  """
45
48
 
46
49
  @abstractmethod
47
- def reset(
48
- self, key: Key, state: State | None = None, **kwargs
49
- ) -> tuple[State, Info]: ...
50
+ def init(self, key: Key) -> tuple[State, Info]:
51
+ """Initialize environment and all state from scratch."""
52
+ ...
53
+
54
+ def reset(self, key: Key, state: State) -> tuple[State, Info]:
55
+ """Reset the inner environment while preserving episode-persistent state."""
56
+ return self.init(key)
50
57
 
51
58
  @abstractmethod
52
- def step(self, state: State, action: PyTree, **kwargs) -> tuple[State, Info]: ...
59
+ def step(self, state: State, action: PyTree) -> tuple[State, Info]: ...
53
60
 
54
61
  @abstractmethod
55
62
  @cached_property
@@ -1,6 +1,6 @@
1
1
  from abc import ABC, abstractmethod
2
2
  from functools import cached_property
3
- from typing import override
3
+ from typing import cast, override
4
4
 
5
5
  import jax
6
6
  from jax import numpy as jnp
@@ -65,7 +65,9 @@ class Continuous(Space):
65
65
  high: float | jax.Array
66
66
 
67
67
  @classmethod
68
- def from_shape(cls, low: float, high: float, shape: tuple[int]) -> "Continuous":
68
+ def from_shape(
69
+ cls, low: float, high: float, shape: tuple[int, ...]
70
+ ) -> "Continuous":
69
71
  return cls(
70
72
  low=jnp.full(shape, low, dtype=jnp.asarray(low).dtype),
71
73
  high=jnp.full(shape, high, dtype=jnp.asarray(high).dtype),
@@ -106,17 +108,25 @@ class PyTreeSpace(Space):
106
108
  """A Space defined by a PyTree structure of other Spaces.
107
109
 
108
110
  Args:
109
- tree: A PyTree with Space objects leaves.
111
+ tree: A PyTree with Discrete or Continuous leaves.
110
112
 
111
113
  Usage:
112
114
  space = PyTreeSpace({
113
- "action": Discrete(n=4, dtype=jnp.int32),
114
- "obs": Continuous(low=0.0, high=1.0, shape=(2,), dtype=jnp.float32)
115
+ "action": Discrete(n=4),
116
+ "obs": Continuous(low=0.0, high=1.0, shape=(2,))
115
117
  })
116
118
  """
117
119
 
118
120
  tree: PyTree
119
121
 
122
+ def __post_init__(self):
123
+ leaves = jax.tree.leaves(self.tree, is_leaf=lambda x: isinstance(x, Space))
124
+ for leaf in leaves:
125
+ if not isinstance(leaf, (Discrete, Continuous)):
126
+ raise TypeError(
127
+ f"PyTreeSpace leaves must be Discrete or Continuous, got {type(leaf).__name__}"
128
+ )
129
+
120
130
  @override
121
131
  def sample(self, key: Key) -> PyTree:
122
132
  leaves, treedef = jax.tree.flatten(
@@ -149,16 +159,23 @@ class PyTreeSpace(Space):
149
159
  is_leaf=lambda node: isinstance(node, Space),
150
160
  )
151
161
 
152
-
153
- def batch_space(space: Space, batch_size: int) -> Space:
154
- if isinstance(space, PyTreeSpace):
155
- batched_tree = jax.tree.map(
156
- lambda sp: batch_space(sp, batch_size),
157
- space.tree,
162
+ @property
163
+ def dtype(self) -> PyTree:
164
+ return jax.tree.map(
165
+ lambda space: space.dtype,
166
+ self.tree,
158
167
  is_leaf=lambda node: isinstance(node, Space),
159
168
  )
160
- return PyTreeSpace(batched_tree)
161
- return BatchedSpace(space=space, batch_size=batch_size)
169
+
170
+
171
+ def peel_batched(space: Space) -> tuple[tuple[int, ...], Space]:
172
+ """Collect batch dimensions and return (batch_dims_tuple, base_space)."""
173
+ dims: list[int] = []
174
+ s: Space = space
175
+ while isinstance(s, BatchedSpace):
176
+ dims.append(s.batch_size)
177
+ s = s.space
178
+ return tuple(dims), s
162
179
 
163
180
 
164
181
  class BatchedSpace(Space):
@@ -190,16 +207,19 @@ class BatchedSpace(Space):
190
207
 
191
208
  @cached_property
192
209
  def shape(self) -> PyTree:
193
- inner_shape = self.space.shape
194
- # For tuple shapes (leaf spaces), prepend batch dimension.
195
- # PyTree shapes are handled by wrapping leaves with BatchedSpace via batch_space.
196
- if isinstance(inner_shape, tuple):
197
- return (self.batch_size,) + inner_shape
198
- return inner_shape
210
+ batch_dims, base = peel_batched(self)
211
+ if isinstance(base, PyTreeSpace):
212
+ return jax.tree.map(
213
+ lambda space: batch_dims + space.shape,
214
+ base.tree,
215
+ is_leaf=lambda node: isinstance(node, Space),
216
+ )
217
+ return batch_dims + base.shape
199
218
 
200
219
  @property
201
- def dtype(self):
202
- return getattr(self.space, "dtype", None)
220
+ def dtype(self) -> PyTree:
221
+ _, base = peel_batched(self)
222
+ return base.dtype
203
223
 
204
224
  def __repr__(self) -> str:
205
225
  return f"BatchedSpace(space={self.space!r}, batch_size={self.batch_size})"
@@ -1,6 +1,6 @@
1
1
  import dataclasses
2
2
  from dataclasses import KW_ONLY
3
- from typing import Any, Iterable, Iterator, Mapping, Self, Tuple
3
+ from typing import Any, Iterable, Iterator, Mapping, Self, Tuple, dataclass_transform
4
4
 
5
5
  import jax
6
6
 
@@ -24,6 +24,7 @@ def static_field(**kwargs):
24
24
  return field(pytree_node=False, **kwargs)
25
25
 
26
26
 
27
+ @dataclass_transform()
27
28
  class FrozenPyTreeNode:
28
29
  """
29
30
  Frozen dataclass base that is a JAX pytree node.
@@ -64,6 +65,7 @@ class FrozenPyTreeNode:
64
65
  return dataclasses.replace(self, **changes)
65
66
 
66
67
 
68
+ @dataclass_transform()
67
69
  @jax.tree_util.register_pytree_node_class
68
70
  @dataclasses.dataclass(frozen=True, eq=True, repr=True, slots=False)
69
71
  class Container:
@@ -104,6 +106,13 @@ class Container:
104
106
  for k, v in self._extras.items():
105
107
  yield (k, v)
106
108
 
109
+ def __str__(self) -> str:
110
+ core_str = super().__str__()
111
+ if not self._extras:
112
+ return core_str
113
+ extras_str = f", {', '.join(f'{k}={v!r}' for k, v in self._extras.items())}"
114
+ return f"{core_str[:-1]}{extras_str})" # remove closing parenthesis from core
115
+
107
116
  def update(self, **changes: PyTree) -> Self:
108
117
  core_names = {f.name for f in dataclasses.fields(self) if f.name != "_extras"}
109
118
  core_updates: dict[str, PyTree] = {}
@@ -0,0 +1,36 @@
1
+ from envelope.wrappers.autoreset_wrapper import AutoResetWrapper
2
+ from envelope.wrappers.clip_action_wrapper import ClipActionWrapper
3
+ from envelope.wrappers.continuous_observation_wrapper import (
4
+ ContinuousObservationWrapper,
5
+ )
6
+ from envelope.wrappers.episode_statistics_wrapper import EpisodeStatisticsWrapper
7
+ from envelope.wrappers.flatten_action_wrapper import FlattenActionWrapper
8
+ from envelope.wrappers.flatten_observation_wrapper import FlattenObservationWrapper
9
+ from envelope.wrappers.observation_normalization_wrapper import (
10
+ ObservationNormalizationWrapper,
11
+ )
12
+ from envelope.wrappers.pooled_init_vmap_wrapper import PooledInitVmapWrapper
13
+ from envelope.wrappers.state_injection_wrapper import StateInjectionWrapper
14
+ from envelope.wrappers.truncation_wrapper import TruncationWrapper
15
+ from envelope.wrappers.vmap_envs_wrapper import VmapEnvsWrapper
16
+ from envelope.wrappers.vmap_wrapper import VmapWrapper
17
+ from envelope.wrappers.wrapper import WrappedState, Wrapper
18
+
19
+ __all__ = [
20
+ # Basic functionality
21
+ "Wrapper",
22
+ "WrappedState",
23
+ # Wrappers
24
+ "AutoResetWrapper",
25
+ "ClipActionWrapper",
26
+ "ContinuousObservationWrapper",
27
+ "EpisodeStatisticsWrapper",
28
+ "FlattenActionWrapper",
29
+ "FlattenObservationWrapper",
30
+ "ObservationNormalizationWrapper",
31
+ "PooledInitVmapWrapper",
32
+ "StateInjectionWrapper",
33
+ "TruncationWrapper",
34
+ "VmapWrapper",
35
+ "VmapEnvsWrapper",
36
+ ]
@@ -0,0 +1,80 @@
1
+ from typing import override
2
+
3
+ import jax
4
+ import jax.numpy as jnp
5
+
6
+ from envelope.environment import Info
7
+ from envelope.struct import field
8
+ from envelope.typing import Key, PyTree
9
+ from envelope.wrappers.wrapper import WrappedState, Wrapper
10
+
11
+
12
+ class AutoResetWrapper(Wrapper):
13
+ """Wrapper that automatically resets the environment when an episode ends.
14
+
15
+ When a step results in termination or truncation, this wrapper immediately
16
+ resets the environment. The returned info preserves critical information
17
+ from the terminal step while providing the new episode's initial observation.
18
+
19
+ Info fields after a terminal step (terminated=True or truncated=True):
20
+ obs: Initial observation from the new episode (after reset).
21
+ final: Full info snapshot from the terminal step (before reset).
22
+ terminated: True if the episode ended due to termination.
23
+ truncated: True if the episode ended due to truncation.
24
+ reward: Reward from the terminal step.
25
+
26
+ Info fields during normal steps (terminated=False and truncated=False):
27
+ obs: Current observation.
28
+ final: Info snapshot from the last completed episode (persisted).
29
+ terminated: False.
30
+ truncated: False.
31
+ reward: Reward from the step.
32
+
33
+ This design enables correct value bootstrapping:
34
+ - Use final.obs for value estimation of the true next state
35
+ - On termination: V(s_final) = 0 (episode truly ended)
36
+ - On truncation: bootstrap from V(final.obs) (episode cut off artificially)
37
+ - final persists until the next episode completes, giving easy access
38
+ to last episode's aggregated stats (e.g., final.episode_return)
39
+ """
40
+
41
+ class AutoResetState(WrappedState):
42
+ reset_key: jax.Array = field()
43
+ last_final: Info = field()
44
+
45
+ @override
46
+ def init(self, key: Key) -> tuple[WrappedState, Info]:
47
+ key, subkey = jax.random.split(key)
48
+ inner_state, info = self.env.init(key)
49
+ # Initialize last_final with the reset info (no previous episode yet)
50
+ last_final = jax.tree.map(lambda x: jnp.full_like(x, jnp.nan), info)
51
+ state = self.AutoResetState(
52
+ inner_state=inner_state, reset_key=subkey, last_final=last_final
53
+ )
54
+ return state, info.update(final=state.last_final)
55
+
56
+ @override
57
+ def reset(self, key: Key, state: WrappedState) -> tuple[WrappedState, Info]:
58
+ raise NotImplementedError("Reset is not implemented for AutoResetWrapper")
59
+
60
+ @override
61
+ def step(self, state: WrappedState, action: PyTree) -> tuple[WrappedState, Info]:
62
+ key, key_reset = jax.random.split(state.reset_key)
63
+ state = state.replace(reset_key=key)
64
+
65
+ inner_state, info = self.env.step(state.inner_state, action)
66
+ reset_inner_state, reset_info = self.env.reset(key_reset, inner_state)
67
+
68
+ # Select next state and info based on done
69
+ done = info.terminated | info.truncated
70
+ state = jax.tree.map(
71
+ lambda reset, next: jax.lax.select(done, reset, next),
72
+ state.replace(inner_state=reset_inner_state),
73
+ state.replace(inner_state=inner_state),
74
+ )
75
+ info = jax.tree.map(
76
+ lambda reset, next: jax.lax.select(done, reset, next),
77
+ reset_info.update(final=info),
78
+ info.update(final=state.last_final),
79
+ )
80
+ return state, info