jax-envelope 0.1.1__tar.gz → 0.3.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (77) hide show
  1. {jax_envelope-0.1.1 → jax_envelope-0.3.0}/.gitignore +2 -1
  2. {jax_envelope-0.1.1 → jax_envelope-0.3.0}/PKG-INFO +34 -23
  3. {jax_envelope-0.1.1 → jax_envelope-0.3.0}/README.md +33 -22
  4. {jax_envelope-0.1.1 → jax_envelope-0.3.0}/pyproject.toml +13 -6
  5. {jax_envelope-0.1.1 → jax_envelope-0.3.0}/src/envelope/__init__.py +16 -4
  6. {jax_envelope-0.1.1 → jax_envelope-0.3.0}/src/envelope/compat/brax_envelope.py +5 -3
  7. {jax_envelope-0.1.1 → jax_envelope-0.3.0}/src/envelope/compat/craftax_envelope.py +17 -2
  8. {jax_envelope-0.1.1 → jax_envelope-0.3.0}/src/envelope/compat/gymnax_envelope.py +34 -7
  9. {jax_envelope-0.1.1 → jax_envelope-0.3.0}/src/envelope/compat/jumanji_envelope.py +3 -2
  10. {jax_envelope-0.1.1 → jax_envelope-0.3.0}/src/envelope/compat/kinetix_envelope.py +3 -2
  11. {jax_envelope-0.1.1 → jax_envelope-0.3.0}/src/envelope/compat/mujoco_playground_envelope.py +1 -1
  12. {jax_envelope-0.1.1 → jax_envelope-0.3.0}/src/envelope/compat/navix_envelope.py +1 -1
  13. {jax_envelope-0.1.1 → jax_envelope-0.3.0}/src/envelope/environment.py +16 -9
  14. {jax_envelope-0.1.1 → jax_envelope-0.3.0}/src/envelope/spaces.py +40 -20
  15. {jax_envelope-0.1.1 → jax_envelope-0.3.0}/src/envelope/struct.py +10 -1
  16. jax_envelope-0.3.0/src/envelope/wrappers/__init__.py +36 -0
  17. jax_envelope-0.3.0/src/envelope/wrappers/autoreset_wrapper.py +80 -0
  18. jax_envelope-0.3.0/src/envelope/wrappers/clip_action_wrapper.py +27 -0
  19. jax_envelope-0.3.0/src/envelope/wrappers/continuous_observation_wrapper.py +61 -0
  20. jax_envelope-0.3.0/src/envelope/wrappers/episode_statistics_wrapper.py +40 -0
  21. jax_envelope-0.3.0/src/envelope/wrappers/flatten_action_wrapper.py +75 -0
  22. jax_envelope-0.3.0/src/envelope/wrappers/flatten_observation_wrapper.py +80 -0
  23. {jax_envelope-0.1.1 → jax_envelope-0.3.0}/src/envelope/wrappers/normalization.py +1 -1
  24. {jax_envelope-0.1.1 → jax_envelope-0.3.0}/src/envelope/wrappers/observation_normalization_wrapper.py +28 -16
  25. jax_envelope-0.3.0/src/envelope/wrappers/pooled_init_vmap_wrapper.py +122 -0
  26. {jax_envelope-0.1.1 → jax_envelope-0.3.0}/src/envelope/wrappers/state_injection_wrapper.py +18 -22
  27. jax_envelope-0.3.0/src/envelope/wrappers/truncation_wrapper.py +35 -0
  28. {jax_envelope-0.1.1 → jax_envelope-0.3.0}/src/envelope/wrappers/vmap_envs_wrapper.py +26 -21
  29. jax_envelope-0.3.0/src/envelope/wrappers/vmap_wrapper.py +66 -0
  30. {jax_envelope-0.1.1 → jax_envelope-0.3.0}/src/envelope/wrappers/wrapper.py +8 -8
  31. {jax_envelope-0.1.1 → jax_envelope-0.3.0}/tests/compat/conftest.py +1 -1
  32. {jax_envelope-0.1.1 → jax_envelope-0.3.0}/tests/compat/contract.py +5 -2
  33. {jax_envelope-0.1.1 → jax_envelope-0.3.0}/tests/compat/test_brax_compat.py +6 -6
  34. {jax_envelope-0.1.1 → jax_envelope-0.3.0}/tests/compat/test_craftax_compat.py +4 -2
  35. {jax_envelope-0.1.1 → jax_envelope-0.3.0}/tests/compat/test_create.py +12 -0
  36. {jax_envelope-0.1.1 → jax_envelope-0.3.0}/tests/compat/test_create_integration.py +7 -7
  37. {jax_envelope-0.1.1 → jax_envelope-0.3.0}/tests/compat/test_gymnax_compat.py +3 -3
  38. {jax_envelope-0.1.1 → jax_envelope-0.3.0}/tests/compat/test_jumanji_compat.py +1 -1
  39. {jax_envelope-0.1.1 → jax_envelope-0.3.0}/tests/compat/test_kinetix_compat.py +7 -9
  40. {jax_envelope-0.1.1 → jax_envelope-0.3.0}/tests/compat/test_mujoco_playground_compat.py +6 -6
  41. {jax_envelope-0.1.1 → jax_envelope-0.3.0}/tests/compat/test_navix_compat.py +5 -5
  42. {jax_envelope-0.1.1 → jax_envelope-0.3.0}/tests/spaces/test_batched_space.py +74 -50
  43. {jax_envelope-0.1.1 → jax_envelope-0.3.0}/tests/spaces/test_pytree_space.py +21 -1
  44. {jax_envelope-0.1.1 → jax_envelope-0.3.0}/tests/wrappers/helpers.py +112 -36
  45. {jax_envelope-0.1.1 → jax_envelope-0.3.0}/tests/wrappers/test_autoreset_wrapper.py +164 -67
  46. jax_envelope-0.3.0/tests/wrappers/test_clip_action_wrapper.py +174 -0
  47. jax_envelope-0.3.0/tests/wrappers/test_continuous_observation_wrapper.py +153 -0
  48. {jax_envelope-0.1.1 → jax_envelope-0.3.0}/tests/wrappers/test_environment_wrapper.py +12 -12
  49. jax_envelope-0.3.0/tests/wrappers/test_episode_statistics_wrapper.py +183 -0
  50. jax_envelope-0.3.0/tests/wrappers/test_flatten_action_wrapper.py +215 -0
  51. jax_envelope-0.3.0/tests/wrappers/test_flatten_observation_wrapper.py +178 -0
  52. {jax_envelope-0.1.1 → jax_envelope-0.3.0}/tests/wrappers/test_observation_normalization_wrapper.py +9 -9
  53. jax_envelope-0.3.0/tests/wrappers/test_pooled_init_vmap_wrapper.py +292 -0
  54. {jax_envelope-0.1.1 → jax_envelope-0.3.0}/tests/wrappers/test_state_injection_wrapper.py +19 -19
  55. {jax_envelope-0.1.1 → jax_envelope-0.3.0}/tests/wrappers/test_truncation_wrapper.py +39 -14
  56. {jax_envelope-0.1.1 → jax_envelope-0.3.0}/tests/wrappers/test_vmap_envs_wrapper.py +7 -7
  57. {jax_envelope-0.1.1 → jax_envelope-0.3.0}/tests/wrappers/test_vmap_wrapper.py +18 -18
  58. {jax_envelope-0.1.1 → jax_envelope-0.3.0}/uv.lock +818 -711
  59. jax_envelope-0.1.1/src/envelope/wrappers/__init__.py +0 -20
  60. jax_envelope-0.1.1/src/envelope/wrappers/autoreset_wrapper.py +0 -36
  61. jax_envelope-0.1.1/src/envelope/wrappers/episode_statistics_wrapper.py +0 -47
  62. jax_envelope-0.1.1/src/envelope/wrappers/truncation_wrapper.py +0 -31
  63. jax_envelope-0.1.1/src/envelope/wrappers/vmap_wrapper.py +0 -51
  64. {jax_envelope-0.1.1 → jax_envelope-0.3.0}/.github/workflows/publish.yml +0 -0
  65. {jax_envelope-0.1.1 → jax_envelope-0.3.0}/LICENSE +0 -0
  66. {jax_envelope-0.1.1 → jax_envelope-0.3.0}/src/envelope/compat/__init__.py +0 -0
  67. {jax_envelope-0.1.1 → jax_envelope-0.3.0}/src/envelope/typing.py +0 -0
  68. {jax_envelope-0.1.1 → jax_envelope-0.3.0}/tests/__init__.py +0 -0
  69. {jax_envelope-0.1.1 → jax_envelope-0.3.0}/tests/compat/__init__.py +0 -0
  70. {jax_envelope-0.1.1 → jax_envelope-0.3.0}/tests/spaces/__init__.py +0 -0
  71. {jax_envelope-0.1.1 → jax_envelope-0.3.0}/tests/spaces/test_continuous.py +0 -0
  72. {jax_envelope-0.1.1 → jax_envelope-0.3.0}/tests/spaces/test_discrete.py +0 -0
  73. {jax_envelope-0.1.1 → jax_envelope-0.3.0}/tests/spaces/test_serialization.py +0 -0
  74. {jax_envelope-0.1.1 → jax_envelope-0.3.0}/tests/test_container.py +0 -0
  75. {jax_envelope-0.1.1 → jax_envelope-0.3.0}/tests/test_struct.py +0 -0
  76. {jax_envelope-0.1.1 → jax_envelope-0.3.0}/tests/wrappers/__init__.py +0 -0
  77. {jax_envelope-0.1.1 → jax_envelope-0.3.0}/tests/wrappers/test_normalization.py +0 -0
@@ -166,7 +166,8 @@ wandb/
166
166
  # ruff
167
167
  .ruff_cache/
168
168
 
169
- # Cursor
169
+ # Vibecoding
170
170
  .cursor
171
171
  .cursorignore
172
172
  AGENTS.md
173
+ CLAUDE.md
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: jax-envelope
3
- Version: 0.1.1
3
+ Version: 0.3.0
4
4
  Summary: A JAX-native environment interface with powerful wrappers and adapters for popular RL environment suites
5
5
  Project-URL: Homepage, https://github.com/keraJLi/envelope
6
6
  Project-URL: Repository, https://github.com/keraJLi/envelope
@@ -25,12 +25,13 @@ Requires-Dist: jax>=0.5.0
25
25
  Description-Content-Type: text/markdown
26
26
 
27
27
  # 💌 Envelope: a JAX-native environment interface
28
+
28
29
  ```python
29
30
  # Create environments from JAX-native suites you have installed, ...
30
31
  env = envelope.create("gymnax::CartPole-v1")
31
32
 
32
33
  # ... interact with the environments using a simple interface, ...
33
- state, info = env.reset(key)
34
+ state, info = env.init(key)
34
35
  states, infos = jax.lax.scan(env.step, state, actions)
35
36
  plt.plot(infos.reward.cumsum())
36
37
 
@@ -41,35 +42,42 @@ env = envelope.wrappers.ObservationNormalizationWrapper(env)
41
42
  ```
42
43
 
43
44
  ## 🌍 Simple, expressive interaction!
44
- * **Environments are pytrees**. Squish them through JAX transformations and trace their parameters.
45
- * **Idiomatic jax-y interface** of `reset(key: Key) -> State, Info` and `step(state: State, action: PyTree) -> State, Info`. You can directly `jax.scan` over a `step(...)`!
46
- * **Spaces are super simple**. No `Tuple`, `Dict` nonsense! There are two spaces: `Continuous` and `Discrete`, which you can compose into a `PyTreeSpace`.
47
- * **Explicit episode truncation** supports correctly handling bootstrapping for value-function targets.
48
- * **No auto-reset** by default. Resetting every step can be expensive!
45
+ - **Environments are pytrees**. Squish them through JAX transformations and trace their parameters.
46
+ - **Idiomatic jax-y interface** of `init(key: Key) -> State, Info` and `step(state: State, action: PyTree) -> State, Info`. You can directly `jax.scan` over a `step(...)`!
47
+ - **Spaces are super simple**. No `Tuple`, `Dict` nonsense! There are two spaces: `Continuous` and `Discrete`, which you can compose into a `PyTreeSpace`.
48
+ - **Explicit episode truncation** supports correctly handling bootstrapping for value-function targets.
49
+ - **No auto-reset** by default. Resetting every step can be expensive!
49
50
 
50
51
  ## 💪 Powerful, composable wrappers!
51
- * **Carry state across episodes** to track running statistics, for example to normalize observations.
52
- * **Composable wrappers** can be stacked in any order. For example, `ObservationNormalizationWrapper` before vs. after `VmapWrapper` gives per-env vs. global normalization.
53
- <!-- TODO: Add auto-reset behavior (including state injection) and optimistic resets once I implement them. -->
52
+
53
+ - **Carry state across episodes** to track running statistics, for example to normalize observations.
54
+ - **Composable wrappers** can be stacked in any order. For example, `ObservationNormalizationWrapper` before vs. after `VmapWrapper` gives per-env vs. global normalization.
55
+
56
+
54
57
 
55
58
  ## 🔌 Adapters for existing suites
56
- | 📦 | # 🤖 | # 🌍 |
57
- |------|------|------|
58
- | [gymnax](https://github.com/RobertTLange/gymnax) | 🕺 | 24 |
59
- | [brax](https://github.com/google/brax) | 🕺 | 12 |
60
- | [jumanji](https://github.com/instadeepai/jumanji) | 🕺 / 👯 | 25 / 1 |
61
- | [kinetix](https://github.com/flairox/kinetix) | 🕺 | 74 |
62
- | [craftax](https://github.com/MichaelTMatthews/craftax) | 🕺 | 4 |
63
- | [mujoco_playground](https://github.com/google-deepmind/mujoco_playground) | 🕺 | 54 |
64
- | | |
65
- | Total | 🕺 / 👯 | 193 / 1 |
59
+
60
+
61
+ | 📦 | # 🤖 | # 🌍 |
62
+ | ------------------------------------------------------------------------- | ------- | ------- |
63
+ | [gymnax](https://github.com/RobertTLange/gymnax) | 🕺 | 24 |
64
+ | [brax](https://github.com/google/brax) | 🕺 | 12 |
65
+ | [jumanji](https://github.com/instadeepai/jumanji) | 🕺 / 👯 | 25 / 1 |
66
+ | [kinetix](https://github.com/flairox/kinetix) | 🕺 | 74 |
67
+ | [craftax](https://github.com/MichaelTMatthews/craftax) | 🕺 | 4 |
68
+ | [mujoco_playground](https://github.com/google-deepmind/mujoco_playground) | 🕺 | 54 |
69
+ | | | |
70
+ | Total | 🕺 / 👯 | 193 / 1 |
71
+
66
72
 
67
73
  ```python
68
74
  envelope.create("📦::🌍")
69
75
  ```
76
+
70
77
  lets you create environments from any of the above!
71
78
 
72
79
  ## 📝 Testing
80
+
73
81
  - **Default (no optional compat deps required)**: `uv run pytest -m "not compat"`
74
82
  - **Compat suite (requires full compat dependency group)**:
75
83
  - `uv sync --group compat`
@@ -77,11 +85,14 @@ lets you create environments from any of the above!
77
85
  - If any compat dependency is missing/broken, the run will fail fast with an error telling you what to install.
78
86
 
79
87
  ## 🏗️ Installation
88
+
80
89
  ```bash
81
90
  pip install jax-envelope
82
91
  ```
83
92
 
84
93
  ## 💞 Related projects
85
- * [stoax](https://github.com/EdanToledo/Stoa) is a very similar project that provides adapters and wrappers for the jumanji-like interface.
86
- * Check out all the great suites we have adapters for! [gymnax](https://github.com/RobertTLange/gymnax), [brax](https://github.com/google/brax), [jumanji](https://github.com/instadeepai/jumanji), [kinetix](https://github.com/flairox/kinetix), [craftax](https://github.com/MichaelTMatthews/craftax), [mujoco_playground](https://github.com/google-deepmind/mujoco_playground).
87
- * We will be adding support for [jaxmarl](https://github.com/flairox/jaxmarl) and [pgx](https://github.com/sotetsuk/pgx) in the future, as soon as we figured out the best ever MARL interface for JAX!
94
+
95
+ - [stoa](https://github.com/EdanToledo/Stoa) is a very similar project that provides adapters and wrappers for the jumanji-like interface.
96
+ - Check out all the great suites we have adapters for! [gymnax](https://github.com/RobertTLange/gymnax), [brax](https://github.com/google/brax), [jumanji](https://github.com/instadeepai/jumanji), [kinetix](https://github.com/flairox/kinetix), [craftax](https://github.com/MichaelTMatthews/craftax), [mujoco_playground](https://github.com/google-deepmind/mujoco_playground).
97
+ - We will be adding support for [jaxmarl](https://github.com/flairox/jaxmarl) and [pgx](https://github.com/sotetsuk/pgx) in the future, as soon as we have figured out the best ever MARL interface for JAX!
98
+
@@ -1,10 +1,11 @@
1
1
  # 💌 Envelope: a JAX-native environment interface
2
+
2
3
  ```python
3
4
  # Create environments from JAX-native suites you have installed, ...
4
5
  env = envelope.create("gymnax::CartPole-v1")
5
6
 
6
7
  # ... interact with the environments using a simple interface, ...
7
- state, info = env.reset(key)
8
+ state, info = env.init(key)
8
9
  states, infos = jax.lax.scan(env.step, state, actions)
9
10
  plt.plot(infos.reward.cumsum())
10
11
 
@@ -15,35 +16,42 @@ env = envelope.wrappers.ObservationNormalizationWrapper(env)
15
16
  ```
16
17
 
17
18
  ## 🌍 Simple, expressive interaction!
18
- * **Environments are pytrees**. Squish them through JAX transformations and trace their parameters.
19
- * **Idiomatic jax-y interface** of `reset(key: Key) -> State, Info` and `step(state: State, action: PyTree) -> State, Info`. You can directly `jax.scan` over a `step(...)`!
20
- * **Spaces are super simple**. No `Tuple`, `Dict` nonsense! There are two spaces: `Continuous` and `Discrete`, which you can compose into a `PyTreeSpace`.
21
- * **Explicit episode truncation** supports correctly handling bootstrapping for value-function targets.
22
- * **No auto-reset** by default. Resetting every step can be expensive!
19
+ - **Environments are pytrees**. Squish them through JAX transformations and trace their parameters.
20
+ - **Idiomatic jax-y interface** of `init(key: Key) -> State, Info` and `step(state: State, action: PyTree) -> State, Info`. You can directly `jax.scan` over a `step(...)`!
21
+ - **Spaces are super simple**. No `Tuple`, `Dict` nonsense! There are two spaces: `Continuous` and `Discrete`, which you can compose into a `PyTreeSpace`.
22
+ - **Explicit episode truncation** supports correctly handling bootstrapping for value-function targets.
23
+ - **No auto-reset** by default. Resetting every step can be expensive!
23
24
 
24
25
  ## 💪 Powerful, composable wrappers!
25
- * **Carry state across episodes** to track running statistics, for example to normalize observations.
26
- * **Composable wrappers** can be stacked in any order. For example, `ObservationNormalizationWrapper` before vs. after `VmapWrapper` gives per-env vs. global normalization.
27
- <!-- TODO: Add auto-reset behavior (including state injection) and optimistic resets once I implement them. -->
26
+
27
+ - **Carry state across episodes** to track running statistics, for example to normalize observations.
28
+ - **Composable wrappers** can be stacked in any order. For example, `ObservationNormalizationWrapper` before vs. after `VmapWrapper` gives per-env vs. global normalization.
29
+
30
+
28
31
 
29
32
  ## 🔌 Adapters for existing suites
30
- | 📦 | # 🤖 | # 🌍 |
31
- |------|------|------|
32
- | [gymnax](https://github.com/RobertTLange/gymnax) | 🕺 | 24 |
33
- | [brax](https://github.com/google/brax) | 🕺 | 12 |
34
- | [jumanji](https://github.com/instadeepai/jumanji) | 🕺 / 👯 | 25 / 1 |
35
- | [kinetix](https://github.com/flairox/kinetix) | 🕺 | 74 |
36
- | [craftax](https://github.com/MichaelTMatthews/craftax) | 🕺 | 4 |
37
- | [mujoco_playground](https://github.com/google-deepmind/mujoco_playground) | 🕺 | 54 |
38
- | | |
39
- | Total | 🕺 / 👯 | 193 / 1 |
33
+
34
+
35
+ | 📦 | # 🤖 | # 🌍 |
36
+ | ------------------------------------------------------------------------- | ------- | ------- |
37
+ | [gymnax](https://github.com/RobertTLange/gymnax) | 🕺 | 24 |
38
+ | [brax](https://github.com/google/brax) | 🕺 | 12 |
39
+ | [jumanji](https://github.com/instadeepai/jumanji) | 🕺 / 👯 | 25 / 1 |
40
+ | [kinetix](https://github.com/flairox/kinetix) | 🕺 | 74 |
41
+ | [craftax](https://github.com/MichaelTMatthews/craftax) | 🕺 | 4 |
42
+ | [mujoco_playground](https://github.com/google-deepmind/mujoco_playground) | 🕺 | 54 |
43
+ | | | |
44
+ | Total | 🕺 / 👯 | 193 / 1 |
45
+
40
46
 
41
47
  ```python
42
48
  envelope.create("📦::🌍")
43
49
  ```
50
+
44
51
  lets you create environments from any of the above!
45
52
 
46
53
  ## 📝 Testing
54
+
47
55
  - **Default (no optional compat deps required)**: `uv run pytest -m "not compat"`
48
56
  - **Compat suite (requires full compat dependency group)**:
49
57
  - `uv sync --group compat`
@@ -51,11 +59,14 @@ lets you create environments from any of the above!
51
59
  - If any compat dependency is missing/broken, the run will fail fast with an error telling you what to install.
52
60
 
53
61
  ## 🏗️ Installation
62
+
54
63
  ```bash
55
64
  pip install jax-envelope
56
65
  ```
57
66
 
58
67
  ## 💞 Related projects
59
- * [stoax](https://github.com/EdanToledo/Stoa) is a very similar project that provides adapters and wrappers for the jumanji-like interface.
60
- * Check out all the great suites we have adapters for! [gymnax](https://github.com/RobertTLange/gymnax), [brax](https://github.com/google/brax), [jumanji](https://github.com/instadeepai/jumanji), [kinetix](https://github.com/flairox/kinetix), [craftax](https://github.com/MichaelTMatthews/craftax), [mujoco_playground](https://github.com/google-deepmind/mujoco_playground).
61
- * We will be adding support for [jaxmarl](https://github.com/flairox/jaxmarl) and [pgx](https://github.com/sotetsuk/pgx) in the future, as soon as we figured out the best ever MARL interface for JAX!
68
+
69
+ - [stoa](https://github.com/EdanToledo/Stoa) is a very similar project that provides adapters and wrappers for the jumanji-like interface.
70
+ - Check out all the great suites we have adapters for! [gymnax](https://github.com/RobertTLange/gymnax), [brax](https://github.com/google/brax), [jumanji](https://github.com/instadeepai/jumanji), [kinetix](https://github.com/flairox/kinetix), [craftax](https://github.com/MichaelTMatthews/craftax), [mujoco_playground](https://github.com/google-deepmind/mujoco_playground).
71
+ - We will be adding support for [jaxmarl](https://github.com/flairox/jaxmarl) and [pgx](https://github.com/sotetsuk/pgx) in the future, as soon as we have figured out the best ever MARL interface for JAX!
72
+
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "jax-envelope"
3
- version = "0.1.1"
3
+ version = "0.3.0"
4
4
  description = "A JAX-native environment interface with powerful wrappers and adapters for popular RL environment suites"
5
5
  readme = "README.md"
6
6
  requires-python = ">=3.12"
@@ -51,18 +51,15 @@ allow-direct-references = true
51
51
  [tool.hatch.build.targets.wheel]
52
52
  packages = ["src/envelope"]
53
53
 
54
- [tool.uv.sources]
55
- gymnax = { git = "https://github.com/RobertTLange/gymnax" }
56
-
57
54
  [dependency-groups]
58
55
  compat = [
59
- "gymnax @ git+https://github.com/RobertTLange/gymnax@main",
60
56
  "brax>=0.13.0",
61
57
  "craftax>=1.4.3",
62
58
  "navix>=0.7.0",
63
59
  "jumanji>=1.0.1",
64
- "kinetix-env>=2.0.0",
65
60
  "playground>=0.1.0",
61
+ "gymnax",
62
+ "kinetix-env",
66
63
  ]
67
64
  dev = [
68
65
  "hypothesis>=6.148.1",
@@ -71,7 +68,17 @@ dev = [
71
68
  "ruff>=0.14.2",
72
69
  ]
73
70
 
71
+ [tool.uv]
72
+ override-dependencies = [
73
+ "tensorflow-probability>=0.26.0.dev20260116",
74
+ ]
75
+
76
+ [tool.uv.sources]
77
+ gymnax = { git = "https://github.com/RobertTLange/gymnax.git" }
78
+ kinetix-env = { git = "https://github.com/FLAIROx/Kinetix.git" }
79
+
74
80
  [tool.pytest.ini_options]
81
+ testpaths = ["tests"]
75
82
  markers = [
76
83
  "compat: tests requiring optional compat dependencies",
77
84
  ]
@@ -1,16 +1,22 @@
1
1
  from envelope.compat import create
2
2
  from envelope.environment import Environment, Info, InfoContainer
3
3
  from envelope.spaces import BatchedSpace, Continuous, Discrete, PyTreeSpace, Space
4
- from envelope.struct import field, static_field, FrozenPyTreeNode, Container
4
+ from envelope.struct import Container, FrozenPyTreeNode, field, static_field
5
5
  from envelope.wrappers import (
6
- Wrapper,
7
- WrappedState,
8
6
  AutoResetWrapper,
7
+ ClipActionWrapper,
8
+ ContinuousObservationWrapper,
9
+ EpisodeStatisticsWrapper,
10
+ FlattenActionWrapper,
11
+ FlattenObservationWrapper,
9
12
  ObservationNormalizationWrapper,
13
+ PooledInitVmapWrapper,
10
14
  StateInjectionWrapper,
11
15
  TruncationWrapper,
12
- VmapWrapper,
13
16
  VmapEnvsWrapper,
17
+ VmapWrapper,
18
+ WrappedState,
19
+ Wrapper,
14
20
  )
15
21
 
16
22
  __all__ = [
@@ -34,7 +40,13 @@ __all__ = [
34
40
  "Wrapper",
35
41
  "WrappedState",
36
42
  "AutoResetWrapper",
43
+ "ClipActionWrapper",
44
+ "ContinuousObservationWrapper",
45
+ "EpisodeStatisticsWrapper",
46
+ "FlattenActionWrapper",
47
+ "FlattenObservationWrapper",
37
48
  "ObservationNormalizationWrapper",
49
+ "PooledInitVmapWrapper",
38
50
  "StateInjectionWrapper",
39
51
  "TruncationWrapper",
40
52
  "VmapWrapper",
@@ -48,7 +48,7 @@ class BraxEnvelope(Environment):
48
48
  def default_max_steps(self) -> int:
49
49
  return _BRAX_DEFAULT_EPISODE_LENGTH
50
50
 
51
- def __post_init__(self) -> "BraxEnvelope":
51
+ def __post_init__(self):
52
52
  if isinstance(self.brax_env, BraxWrapper):
53
53
  warnings.warn(
54
54
  "Environment wrapping should be handled by envelope. "
@@ -57,7 +57,7 @@ class BraxEnvelope(Environment):
57
57
  object.__setattr__(self, "brax_env", self.brax_env.unwrapped)
58
58
 
59
59
  @override
60
- def reset(self, key: Key) -> tuple[State, Info]:
60
+ def init(self, key: Key) -> tuple[State, Info]:
61
61
  brax_state = self.brax_env.reset(key)
62
62
  info = InfoContainer(obs=brax_state.obs, reward=0.0, terminated=False)
63
63
  info = info.update(**dataclasses.asdict(brax_state))
@@ -67,7 +67,9 @@ class BraxEnvelope(Environment):
67
67
  def step(self, state: State, action: PyTree) -> tuple[State, Info]:
68
68
  brax_state = self.brax_env.step(state, action)
69
69
  info = InfoContainer(
70
- obs=brax_state.obs, reward=brax_state.reward, terminated=brax_state.done
70
+ obs=brax_state.obs,
71
+ reward=brax_state.reward,
72
+ terminated=jnp.asarray(brax_state.done, dtype=bool),
71
73
  )
72
74
  info = info.update(**dataclasses.asdict(brax_state))
73
75
  return brax_state, info
@@ -22,7 +22,7 @@ class CraftaxEnvelope(Environment):
22
22
  """Wrapper to convert a Craftax environment to a envelope environment."""
23
23
 
24
24
  craftax_env: Any = static_field()
25
- env_params: PyTree
25
+ env_params: PyTree = static_field() # TODO: remove static marker as soon as craftax merges https://github.com/MichaelTMatthews/Craftax/pull/48
26
26
 
27
27
  @classmethod
28
28
  def from_name(
@@ -54,12 +54,27 @@ class CraftaxEnvelope(Environment):
54
54
  def default_max_steps(self) -> int:
55
55
  return int(self.craftax_env.default_params.max_timesteps)
56
56
 
57
+ @cached_property
58
+ def _craftax_info_placeholder(self) -> PyTree:
59
+ key = jax.random.PRNGKey(0)
60
+ _, state = self.craftax_env.reset(key, self.env_params)
61
+ _, _, _, _, info = self.craftax_env.step(
62
+ key,
63
+ state,
64
+ self.craftax_env.action_space(self.env_params).sample(key),
65
+ self.env_params,
66
+ )
67
+ return jax.tree.map(lambda x: jnp.full_like(x, jnp.nan), info)
68
+
57
69
  @override
58
- def reset(self, key: Key) -> tuple[State, Info]:
70
+ def init(self, key: Key) -> tuple[State, Info]:
71
+ # TODO: this function does not add env_info (or comparable) to the info
72
+ # container. We should add tests for this (and all other envelopes) and fix it.
59
73
  key, subkey = jax.random.split(key)
60
74
  obs, env_state = self.craftax_env.reset(subkey, self.env_params)
61
75
  state = Container().update(key=key, env_state=env_state)
62
76
  info = InfoContainer(obs=obs, reward=0.0, terminated=False)
77
+ info = info.update(info=self._craftax_info_placeholder)
63
78
  return state, info
64
79
 
65
80
  @override
@@ -1,5 +1,5 @@
1
1
  from functools import cached_property
2
- from typing import Any, override
2
+ from typing import Any, Callable, cast, override
3
3
 
4
4
  import jax
5
5
  import jax.numpy as jnp
@@ -10,15 +10,24 @@ from gymnax.environments.environment import EnvParams as GymnaxEnvParams
10
10
 
11
11
  from envelope import spaces as envelope_spaces
12
12
  from envelope.environment import Environment, Info, InfoContainer, State
13
- from envelope.struct import Container, static_field
13
+ from envelope.struct import Container, field, static_field
14
14
  from envelope.typing import Key, PyTree
15
15
 
16
+ _GymnaxReset = Callable[
17
+ [Key, GymnaxEnvParams],
18
+ tuple[PyTree, Any],
19
+ ]
20
+ _GymnaxStep = Callable[
21
+ [Key, Any, PyTree, GymnaxEnvParams],
22
+ tuple[PyTree, Any, jnp.ndarray, jnp.ndarray, PyTree],
23
+ ]
24
+
16
25
 
17
26
  class GymnaxEnvelope(Environment):
18
27
  """Wrapper to convert a Gymnax environment to a envelope environment."""
19
28
 
20
29
  gymnax_env: GymnaxEnv = static_field()
21
- env_params: PyTree
30
+ env_params: PyTree = field()
22
31
 
23
32
  @classmethod
24
33
  def from_name(
@@ -43,19 +52,37 @@ class GymnaxEnvelope(Environment):
43
52
  def default_max_steps(self) -> int:
44
53
  return int(self.gymnax_env.default_params.max_steps_in_episode)
45
54
 
55
+ @cached_property
56
+ def _gymnax_info_placeholder(self) -> PyTree:
57
+ reset_fn = cast(_GymnaxReset, self.gymnax_env.reset)
58
+ step_fn = cast(_GymnaxStep, self.gymnax_env.step)
59
+
60
+ key = jax.random.PRNGKey(0)
61
+ _, state = reset_fn(key, self.env_params)
62
+ _, _, _, _, info = step_fn(
63
+ key,
64
+ state,
65
+ self.gymnax_env.action_space(self.env_params).sample(key),
66
+ self.env_params,
67
+ )
68
+ return jax.tree.map(lambda x: jnp.full_like(x, jnp.nan, dtype=float), info)
69
+
46
70
  @override
47
- def reset(self, key: Key) -> tuple[State, Info]:
71
+ def init(self, key: Key) -> tuple[State, Info]:
72
+ reset_fn = cast(_GymnaxReset, self.gymnax_env.reset)
73
+
48
74
  key, subkey = jax.random.split(key)
49
- obs, env_state = self.gymnax_env.reset(subkey, self.env_params)
75
+ obs, env_state = reset_fn(subkey, self.env_params)
50
76
  state = Container().update(key=key, env_state=env_state)
51
77
  info = InfoContainer(obs=obs, reward=0.0, terminated=False)
52
- info = info.update(info=None)
78
+ info = info.update(info=self._gymnax_info_placeholder)
53
79
  return state, info
54
80
 
55
81
  @override
56
82
  def step(self, state: State, action: PyTree) -> tuple[State, Info]:
57
83
  key, subkey = jax.random.split(state.key)
58
- obs, env_state, reward, done, env_info = self.gymnax_env.step(
84
+ step_fn = cast(_GymnaxStep, self.gymnax_env.step)
85
+ obs, env_state, reward, done, env_info = step_fn(
59
86
  subkey, state.env_state, action, self.env_params
60
87
  )
61
88
  state = state.update(key=key, env_state=env_state)
@@ -48,7 +48,7 @@ class JumanjiEnvelope(Environment):
48
48
  return self._default_time_limit
49
49
 
50
50
  @override
51
- def reset(self, key: Key) -> tuple[State, Info]:
51
+ def init(self, key: Key) -> tuple[State, Info]:
52
52
  env_state, timestep = self.jumanji_env.reset(key)
53
53
  info = convert_jumanji_to_envelope_info(timestep)
54
54
  return env_state, info
@@ -81,8 +81,9 @@ class JumanjiEnvelope(Environment):
81
81
 
82
82
 
83
83
  def convert_jumanji_to_envelope_info(timestep: JumanjiTimeStep) -> InfoContainer:
84
+ term = jnp.asarray(timestep.last(), dtype=bool)
84
85
  info = InfoContainer(
85
- obs=timestep.observation, reward=timestep.reward, terminated=timestep.last()
86
+ obs=timestep.observation, reward=timestep.reward, terminated=term
86
87
  ).update(**timestep.extras)
87
88
  return info
88
89
 
@@ -28,6 +28,7 @@ from kinetix.environment import (
28
28
  from kinetix.environment.ued.ued import make_reset_fn_sample_kinetix_level
29
29
  from kinetix.util.saving import load_from_json_file
30
30
 
31
+ from envelope import field
31
32
  from envelope import spaces as envelope_spaces
32
33
  from envelope.compat.gymnax_envelope import _convert_space as _convert_gymnax_space
33
34
  from envelope.environment import Environment, Info, InfoContainer, State
@@ -67,7 +68,7 @@ class KinetixEnvelope(Environment):
67
68
  """Wrapper to convert a Kinetix environment to a envelope environment."""
68
69
 
69
70
  kinetix_env: Any = static_field()
70
- env_params: Any
71
+ env_params: Any = field()
71
72
 
72
73
  @property
73
74
  def default_max_steps(self) -> int:
@@ -162,7 +163,7 @@ class KinetixEnvelope(Environment):
162
163
  return cls(kinetix_env=kinetix_env, env_params=env_params)
163
164
 
164
165
  @override
165
- def reset(self, key: Key) -> tuple[State, Info]:
166
+ def init(self, key: Key) -> tuple[State, Info]:
166
167
  key, subkey = jax.random.split(key)
167
168
  obs, env_state = self.kinetix_env.reset(subkey, self.env_params)
168
169
  state_out = Container().update(key=key, env_state=env_state)
@@ -56,7 +56,7 @@ class MujocoPlaygroundEnvelope(Environment):
56
56
  return self._default_max_steps
57
57
 
58
58
  @override
59
- def reset(self, key: Key) -> tuple[State, Info]:
59
+ def init(self, key: Key) -> tuple[State, Info]:
60
60
  env_state = self.mujoco_playground_env.reset(key)
61
61
  info = InfoContainer(obs=env_state.obs, reward=0.0, terminated=False)
62
62
  info = info.update(**dataclasses.asdict(env_state))
@@ -38,7 +38,7 @@ class NavixEnvelope(Environment):
38
38
  return _NAVIX_DEFAULT_MAX_STEPS
39
39
 
40
40
  @override
41
- def reset(self, key: Key) -> tuple[State, Info]:
41
+ def init(self, key: Key) -> tuple[State, Info]:
42
42
  timestep = self.navix_env.reset(key)
43
43
  return timestep, convert_navix_to_envelope_info(timestep)
44
44
 
@@ -5,7 +5,7 @@ from typing import Protocol, runtime_checkable
5
5
 
6
6
  from envelope import spaces
7
7
  from envelope.struct import Container, FrozenPyTreeNode
8
- from envelope.typing import Key, PyTree
8
+ from envelope.typing import Array, Key, PyTree
9
9
 
10
10
  __all__ = ["Environment", "State", "Info", "InfoContainer"]
11
11
 
@@ -23,7 +23,7 @@ class Info(Protocol):
23
23
 
24
24
  class InfoContainer(Container):
25
25
  obs: PyTree
26
- reward: float
26
+ reward: float | Array
27
27
  terminated: bool
28
28
  truncated: bool = field(default=False)
29
29
 
@@ -38,18 +38,25 @@ class Environment(ABC, FrozenPyTreeNode):
38
38
 
39
39
  State is an opaque PyTree owned by each environment; wrappers that stack
40
40
  environments should expose their wrapped env state as `inner_state` while
41
- adding any wrapper-specific fields. `reset` may optionally receive a prior
42
- state (for cross-episode persistence) and arbitrary **kwargs that wrappers
43
- or environments can use.
41
+ adding any wrapper-specific fields.
42
+
43
+ Two distinct lifecycle methods:
44
+ init(key) - Initialize environment and all state from scratch.
45
+ reset(state, key) - Reset the inner environment while preserving
46
+ episode-persistent state.
44
47
  """
45
48
 
46
49
  @abstractmethod
47
- def reset(
48
- self, key: Key, state: State | None = None, **kwargs
49
- ) -> tuple[State, Info]: ...
50
+ def init(self, key: Key) -> tuple[State, Info]:
51
+ """Initialize environment and all state from scratch."""
52
+ ...
53
+
54
+ def reset(self, state: State, key: Key) -> tuple[State, Info]:
55
+ """Reset the inner environment while preserving episode-persistent state."""
56
+ return self.init(key)
50
57
 
51
58
  @abstractmethod
52
- def step(self, state: State, action: PyTree, **kwargs) -> tuple[State, Info]: ...
59
+ def step(self, state: State, action: PyTree) -> tuple[State, Info]: ...
53
60
 
54
61
  @abstractmethod
55
62
  @cached_property