jax-envelope 0.2.0__tar.gz → 0.3.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (72) hide show
  1. {jax_envelope-0.2.0 → jax_envelope-0.3.0}/PKG-INFO +34 -23
  2. {jax_envelope-0.2.0 → jax_envelope-0.3.0}/README.md +33 -22
  3. {jax_envelope-0.2.0 → jax_envelope-0.3.0}/pyproject.toml +1 -1
  4. {jax_envelope-0.2.0 → jax_envelope-0.3.0}/src/envelope/compat/brax_envelope.py +1 -1
  5. {jax_envelope-0.2.0 → jax_envelope-0.3.0}/src/envelope/compat/jumanji_envelope.py +1 -1
  6. {jax_envelope-0.2.0 → jax_envelope-0.3.0}/src/envelope/environment.py +2 -2
  7. {jax_envelope-0.2.0 → jax_envelope-0.3.0}/src/envelope/spaces.py +1 -1
  8. {jax_envelope-0.2.0 → jax_envelope-0.3.0}/src/envelope/wrappers/autoreset_wrapper.py +2 -2
  9. {jax_envelope-0.2.0 → jax_envelope-0.3.0}/src/envelope/wrappers/continuous_observation_wrapper.py +2 -2
  10. {jax_envelope-0.2.0 → jax_envelope-0.3.0}/src/envelope/wrappers/episode_statistics_wrapper.py +2 -2
  11. {jax_envelope-0.2.0 → jax_envelope-0.3.0}/src/envelope/wrappers/flatten_observation_wrapper.py +3 -4
  12. {jax_envelope-0.2.0 → jax_envelope-0.3.0}/src/envelope/wrappers/observation_normalization_wrapper.py +2 -2
  13. {jax_envelope-0.2.0 → jax_envelope-0.3.0}/src/envelope/wrappers/pooled_init_vmap_wrapper.py +2 -2
  14. {jax_envelope-0.2.0 → jax_envelope-0.3.0}/src/envelope/wrappers/state_injection_wrapper.py +2 -2
  15. {jax_envelope-0.2.0 → jax_envelope-0.3.0}/src/envelope/wrappers/truncation_wrapper.py +2 -2
  16. {jax_envelope-0.2.0 → jax_envelope-0.3.0}/src/envelope/wrappers/vmap_envs_wrapper.py +5 -5
  17. {jax_envelope-0.2.0 → jax_envelope-0.3.0}/src/envelope/wrappers/vmap_wrapper.py +2 -2
  18. {jax_envelope-0.2.0 → jax_envelope-0.3.0}/src/envelope/wrappers/wrapper.py +3 -3
  19. {jax_envelope-0.2.0 → jax_envelope-0.3.0}/tests/compat/test_brax_compat.py +2 -2
  20. {jax_envelope-0.2.0 → jax_envelope-0.3.0}/tests/wrappers/helpers.py +25 -35
  21. {jax_envelope-0.2.0 → jax_envelope-0.3.0}/tests/wrappers/test_autoreset_wrapper.py +2 -2
  22. {jax_envelope-0.2.0 → jax_envelope-0.3.0}/tests/wrappers/test_clip_action_wrapper.py +2 -2
  23. {jax_envelope-0.2.0 → jax_envelope-0.3.0}/tests/wrappers/test_continuous_observation_wrapper.py +1 -1
  24. {jax_envelope-0.2.0 → jax_envelope-0.3.0}/tests/wrappers/test_episode_statistics_wrapper.py +2 -2
  25. {jax_envelope-0.2.0 → jax_envelope-0.3.0}/tests/wrappers/test_flatten_action_wrapper.py +4 -4
  26. {jax_envelope-0.2.0 → jax_envelope-0.3.0}/tests/wrappers/test_flatten_observation_wrapper.py +2 -2
  27. {jax_envelope-0.2.0 → jax_envelope-0.3.0}/tests/wrappers/test_pooled_init_vmap_wrapper.py +1 -1
  28. {jax_envelope-0.2.0 → jax_envelope-0.3.0}/tests/wrappers/test_state_injection_wrapper.py +4 -4
  29. {jax_envelope-0.2.0 → jax_envelope-0.3.0}/tests/wrappers/test_truncation_wrapper.py +2 -2
  30. {jax_envelope-0.2.0 → jax_envelope-0.3.0}/uv.lock +1 -1
  31. {jax_envelope-0.2.0 → jax_envelope-0.3.0}/.github/workflows/publish.yml +0 -0
  32. {jax_envelope-0.2.0 → jax_envelope-0.3.0}/.gitignore +0 -0
  33. {jax_envelope-0.2.0 → jax_envelope-0.3.0}/LICENSE +0 -0
  34. {jax_envelope-0.2.0 → jax_envelope-0.3.0}/src/envelope/__init__.py +0 -0
  35. {jax_envelope-0.2.0 → jax_envelope-0.3.0}/src/envelope/compat/__init__.py +0 -0
  36. {jax_envelope-0.2.0 → jax_envelope-0.3.0}/src/envelope/compat/craftax_envelope.py +0 -0
  37. {jax_envelope-0.2.0 → jax_envelope-0.3.0}/src/envelope/compat/gymnax_envelope.py +0 -0
  38. {jax_envelope-0.2.0 → jax_envelope-0.3.0}/src/envelope/compat/kinetix_envelope.py +0 -0
  39. {jax_envelope-0.2.0 → jax_envelope-0.3.0}/src/envelope/compat/mujoco_playground_envelope.py +0 -0
  40. {jax_envelope-0.2.0 → jax_envelope-0.3.0}/src/envelope/compat/navix_envelope.py +0 -0
  41. {jax_envelope-0.2.0 → jax_envelope-0.3.0}/src/envelope/struct.py +0 -0
  42. {jax_envelope-0.2.0 → jax_envelope-0.3.0}/src/envelope/typing.py +0 -0
  43. {jax_envelope-0.2.0 → jax_envelope-0.3.0}/src/envelope/wrappers/__init__.py +0 -0
  44. {jax_envelope-0.2.0 → jax_envelope-0.3.0}/src/envelope/wrappers/clip_action_wrapper.py +0 -0
  45. {jax_envelope-0.2.0 → jax_envelope-0.3.0}/src/envelope/wrappers/flatten_action_wrapper.py +0 -0
  46. {jax_envelope-0.2.0 → jax_envelope-0.3.0}/src/envelope/wrappers/normalization.py +0 -0
  47. {jax_envelope-0.2.0 → jax_envelope-0.3.0}/tests/__init__.py +0 -0
  48. {jax_envelope-0.2.0 → jax_envelope-0.3.0}/tests/compat/__init__.py +0 -0
  49. {jax_envelope-0.2.0 → jax_envelope-0.3.0}/tests/compat/conftest.py +0 -0
  50. {jax_envelope-0.2.0 → jax_envelope-0.3.0}/tests/compat/contract.py +0 -0
  51. {jax_envelope-0.2.0 → jax_envelope-0.3.0}/tests/compat/test_craftax_compat.py +0 -0
  52. {jax_envelope-0.2.0 → jax_envelope-0.3.0}/tests/compat/test_create.py +0 -0
  53. {jax_envelope-0.2.0 → jax_envelope-0.3.0}/tests/compat/test_create_integration.py +0 -0
  54. {jax_envelope-0.2.0 → jax_envelope-0.3.0}/tests/compat/test_gymnax_compat.py +0 -0
  55. {jax_envelope-0.2.0 → jax_envelope-0.3.0}/tests/compat/test_jumanji_compat.py +0 -0
  56. {jax_envelope-0.2.0 → jax_envelope-0.3.0}/tests/compat/test_kinetix_compat.py +0 -0
  57. {jax_envelope-0.2.0 → jax_envelope-0.3.0}/tests/compat/test_mujoco_playground_compat.py +0 -0
  58. {jax_envelope-0.2.0 → jax_envelope-0.3.0}/tests/compat/test_navix_compat.py +0 -0
  59. {jax_envelope-0.2.0 → jax_envelope-0.3.0}/tests/spaces/__init__.py +0 -0
  60. {jax_envelope-0.2.0 → jax_envelope-0.3.0}/tests/spaces/test_batched_space.py +0 -0
  61. {jax_envelope-0.2.0 → jax_envelope-0.3.0}/tests/spaces/test_continuous.py +0 -0
  62. {jax_envelope-0.2.0 → jax_envelope-0.3.0}/tests/spaces/test_discrete.py +0 -0
  63. {jax_envelope-0.2.0 → jax_envelope-0.3.0}/tests/spaces/test_pytree_space.py +0 -0
  64. {jax_envelope-0.2.0 → jax_envelope-0.3.0}/tests/spaces/test_serialization.py +0 -0
  65. {jax_envelope-0.2.0 → jax_envelope-0.3.0}/tests/test_container.py +0 -0
  66. {jax_envelope-0.2.0 → jax_envelope-0.3.0}/tests/test_struct.py +0 -0
  67. {jax_envelope-0.2.0 → jax_envelope-0.3.0}/tests/wrappers/__init__.py +0 -0
  68. {jax_envelope-0.2.0 → jax_envelope-0.3.0}/tests/wrappers/test_environment_wrapper.py +0 -0
  69. {jax_envelope-0.2.0 → jax_envelope-0.3.0}/tests/wrappers/test_normalization.py +0 -0
  70. {jax_envelope-0.2.0 → jax_envelope-0.3.0}/tests/wrappers/test_observation_normalization_wrapper.py +0 -0
  71. {jax_envelope-0.2.0 → jax_envelope-0.3.0}/tests/wrappers/test_vmap_envs_wrapper.py +0 -0
  72. {jax_envelope-0.2.0 → jax_envelope-0.3.0}/tests/wrappers/test_vmap_wrapper.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: jax-envelope
3
- Version: 0.2.0
3
+ Version: 0.3.0
4
4
  Summary: A JAX-native environment interface with powerful wrappers and adapters for popular RL environment suites
5
5
  Project-URL: Homepage, https://github.com/keraJLi/envelope
6
6
  Project-URL: Repository, https://github.com/keraJLi/envelope
@@ -25,12 +25,13 @@ Requires-Dist: jax>=0.5.0
25
25
  Description-Content-Type: text/markdown
26
26
 
27
27
  # 💌 Envelope: a JAX-native environment interface
28
+
28
29
  ```python
29
30
  # Create environments from JAX-native suites you have installed, ...
30
31
  env = envelope.create("gymnax::CartPole-v1")
31
32
 
32
33
  # ... interact with the environments using a simple interface, ...
33
- state, info = env.reset(key)
34
+ state, info = env.init(key)
34
35
  states, infos = jax.lax.scan(env.step, state, actions)
35
36
  plt.plot(infos.reward.cumsum())
36
37
 
@@ -41,35 +42,42 @@ env = envelope.wrappers.ObservationNormalizationWrapper(env)
41
42
  ```
42
43
 
43
44
  ## 🌍 Simple, expressive interaction!
44
- * **Environments are pytrees**. Squish them through JAX transformations and trace their parameters.
45
- * **Idiomatic jax-y interface** of `reset(key: Key) -> State, Info` and `step(state: State, action: PyTree) -> State, Info`. You can directly `jax.scan` over a `step(...)`!
46
- * **Spaces are super simple**. No `Tuple`, `Dict` nonsense! There are two spaces: `Continuous` and `Discrete`, which you can compose into a `PyTreeSpace`.
47
- * **Explicit episode truncation** supports correctly handling bootstrapping for value-function targets.
48
- * **No auto-reset** by default. Resetting every step can be expensive!
45
+ - **Environments are pytrees**. Squish them through JAX transformations and trace their parameters.
46
+ - **Idiomatic jax-y interface** of `init(key: Key) -> State, Info` and `step(state: State, action: PyTree) -> State, Info`. You can directly `jax.scan` over a `step(...)`!
47
+ - **Spaces are super simple**. No `Tuple`, `Dict` nonsense! There are two spaces: `Continuous` and `Discrete`, which you can compose into a `PyTreeSpace`.
48
+ - **Explicit episode truncation** supports correctly handling bootstrapping for value-function targets.
49
+ - **No auto-reset** by default. Resetting every step can be expensive!
49
50
 
50
51
  ## 💪 Powerful, composable wrappers!
51
- * **Carry state across episodes** to track running statistics, for example to normalize observations.
52
- * **Composable wrappers** can be stacked in any order. For example, `ObservationNormalizationWrapper` before vs. after `VmapWrapper` gives per-env vs. global normalization.
53
- <!-- TODO: Add auto-reset behavior (including state injection) and optimistic resets once I implement them. -->
52
+
53
+ - **Carry state across episodes** to track running statistics, for example to normalize observations.
54
+ - **Composable wrappers** can be stacked in any order. For example, `ObservationNormalizationWrapper` before vs. after `VmapWrapper` gives per-env vs. global normalization.
55
+
56
+
54
57
 
55
58
  ## 🔌 Adapters for existing suites
56
- | 📦 | # 🤖 | # 🌍 |
57
- |------|------|------|
58
- | [gymnax](https://github.com/RobertTLange/gymnax) | 🕺 | 24 |
59
- | [brax](https://github.com/google/brax) | 🕺 | 12 |
60
- | [jumanji](https://github.com/instadeepai/jumanji) | 🕺 / 👯 | 25 / 1 |
61
- | [kinetix](https://github.com/flairox/kinetix) | 🕺 | 74 |
62
- | [craftax](https://github.com/MichaelTMatthews/craftax) | 🕺 | 4 |
63
- | [mujoco_playground](https://github.com/google-deepmind/mujoco_playground) | 🕺 | 54 |
64
- | | |
65
- | Total | 🕺 / 👯 | 193 / 1 |
59
+
60
+
61
+ | 📦 | # 🤖 | # 🌍 |
62
+ | ------------------------------------------------------------------------- | ------- | ------- |
63
+ | [gymnax](https://github.com/RobertTLange/gymnax) | 🕺 | 24 |
64
+ | [brax](https://github.com/google/brax) | 🕺 | 12 |
65
+ | [jumanji](https://github.com/instadeepai/jumanji) | 🕺 / 👯 | 25 / 1 |
66
+ | [kinetix](https://github.com/flairox/kinetix) | 🕺 | 74 |
67
+ | [craftax](https://github.com/MichaelTMatthews/craftax) | 🕺 | 4 |
68
+ | [mujoco_playground](https://github.com/google-deepmind/mujoco_playground) | 🕺 | 54 |
69
+ | | | |
70
+ | Total | 🕺 / 👯 | 193 / 1 |
71
+
66
72
 
67
73
  ```python
68
74
  envelope.create("📦::🌍")
69
75
  ```
76
+
70
77
  lets you create environments from any of the above!
71
78
 
72
79
  ## 📝 Testing
80
+
73
81
  - **Default (no optional compat deps required)**: `uv run pytest -m "not compat"`
74
82
  - **Compat suite (requires full compat dependency group)**:
75
83
  - `uv sync --group compat`
@@ -77,11 +85,14 @@ let's you create environments from any of the above!
77
85
  - If any compat dependency is missing/broken, the run will fail fast with an error telling you what to install.
78
86
 
79
87
  ## 🏗️ Installation
88
+
80
89
  ```bash
81
90
  pip install jax-envelope
82
91
  ```
83
92
 
84
93
  ## 💞 Related projects
85
- * [stoa](https://github.com/EdanToledo/Stoa) is a very similar project that provides adapters and wrappers for the jumanji-like interface.
86
- * Check out all the great suites we have adapters for! [gymnax](https://github.com/RobertTLange/gymnax), [brax](https://github.com/google/brax), [jumanji](https://github.com/instadeepai/jumanji), [kinetix](https://github.com/flairox/kinetix), [craftax](https://github.com/MichaelTMatthews/craftax), [mujoco_playground](https://github.com/google-deepmind/mujoco_playground).
87
- * We will be adding support for [jaxmarl](https://github.com/flairox/jaxmarl) and [pgx](https://github.com/sotetsuk/pgx) in the future, as soon as we figured out the best ever MARL interface for JAX!
94
+
95
+ - [stoa](https://github.com/EdanToledo/Stoa) is a very similar project that provides adapters and wrappers for the jumanji-like interface.
96
+ - Check out all the great suites we have adapters for! [gymnax](https://github.com/RobertTLange/gymnax), [brax](https://github.com/google/brax), [jumanji](https://github.com/instadeepai/jumanji), [kinetix](https://github.com/flairox/kinetix), [craftax](https://github.com/MichaelTMatthews/craftax), [mujoco_playground](https://github.com/google-deepmind/mujoco_playground).
97
+ - We will be adding support for [jaxmarl](https://github.com/flairox/jaxmarl) and [pgx](https://github.com/sotetsuk/pgx) in the future, as soon as we have figured out the best ever MARL interface for JAX!
98
+
@@ -1,10 +1,11 @@
1
1
  # 💌 Envelope: a JAX-native environment interface
2
+
2
3
  ```python
3
4
  # Create environments from JAX-native suites you have installed, ...
4
5
  env = envelope.create("gymnax::CartPole-v1")
5
6
 
6
7
  # ... interact with the environments using a simple interface, ...
7
- state, info = env.reset(key)
8
+ state, info = env.init(key)
8
9
  states, infos = jax.lax.scan(env.step, state, actions)
9
10
  plt.plot(infos.reward.cumsum())
10
11
 
@@ -15,35 +16,42 @@ env = envelope.wrappers.ObservationNormalizationWrapper(env)
15
16
  ```
16
17
 
17
18
  ## 🌍 Simple, expressive interaction!
18
- * **Environments are pytrees**. Squish them through JAX transformations and trace their parameters.
19
- * **Idiomatic jax-y interface** of `reset(key: Key) -> State, Info` and `step(state: State, action: PyTree) -> State, Info`. You can directly `jax.scan` over a `step(...)`!
20
- * **Spaces are super simple**. No `Tuple`, `Dict` nonsense! There are two spaces: `Continuous` and `Discrete`, which you can compose into a `PyTreeSpace`.
21
- * **Explicit episode truncation** supports correctly handling bootstrapping for value-function targets.
22
- * **No auto-reset** by default. Resetting every step can be expensive!
19
+ - **Environments are pytrees**. Squish them through JAX transformations and trace their parameters.
20
+ - **Idiomatic jax-y interface** of `init(key: Key) -> State, Info` and `step(state: State, action: PyTree) -> State, Info`. You can directly `jax.scan` over a `step(...)`!
21
+ - **Spaces are super simple**. No `Tuple`, `Dict` nonsense! There are two spaces: `Continuous` and `Discrete`, which you can compose into a `PyTreeSpace`.
22
+ - **Explicit episode truncation** supports correctly handling bootstrapping for value-function targets.
23
+ - **No auto-reset** by default. Resetting every step can be expensive!
23
24
 
24
25
  ## 💪 Powerful, composable wrappers!
25
- * **Carry state across episodes** to track running statistics, for example to normalize observations.
26
- * **Composable wrappers** can be stacked in any order. For example, `ObservationNormalizationWrapper` before vs. after `VmapWrapper` gives per-env vs. global normalization.
27
- <!-- TODO: Add auto-reset behavior (including state injection) and optimistic resets once I implement them. -->
26
+
27
+ - **Carry state across episodes** to track running statistics, for example to normalize observations.
28
+ - **Composable wrappers** can be stacked in any order. For example, `ObservationNormalizationWrapper` before vs. after `VmapWrapper` gives per-env vs. global normalization.
29
+
30
+
28
31
 
29
32
  ## 🔌 Adapters for existing suites
30
- | 📦 | # 🤖 | # 🌍 |
31
- |------|------|------|
32
- | [gymnax](https://github.com/RobertTLange/gymnax) | 🕺 | 24 |
33
- | [brax](https://github.com/google/brax) | 🕺 | 12 |
34
- | [jumanji](https://github.com/instadeepai/jumanji) | 🕺 / 👯 | 25 / 1 |
35
- | [kinetix](https://github.com/flairox/kinetix) | 🕺 | 74 |
36
- | [craftax](https://github.com/MichaelTMatthews/craftax) | 🕺 | 4 |
37
- | [mujoco_playground](https://github.com/google-deepmind/mujoco_playground) | 🕺 | 54 |
38
- | | |
39
- | Total | 🕺 / 👯 | 193 / 1 |
33
+
34
+
35
+ | 📦 | # 🤖 | # 🌍 |
36
+ | ------------------------------------------------------------------------- | ------- | ------- |
37
+ | [gymnax](https://github.com/RobertTLange/gymnax) | 🕺 | 24 |
38
+ | [brax](https://github.com/google/brax) | 🕺 | 12 |
39
+ | [jumanji](https://github.com/instadeepai/jumanji) | 🕺 / 👯 | 25 / 1 |
40
+ | [kinetix](https://github.com/flairox/kinetix) | 🕺 | 74 |
41
+ | [craftax](https://github.com/MichaelTMatthews/craftax) | 🕺 | 4 |
42
+ | [mujoco_playground](https://github.com/google-deepmind/mujoco_playground) | 🕺 | 54 |
43
+ | | | |
44
+ | Total | 🕺 / 👯 | 193 / 1 |
45
+
40
46
 
41
47
  ```python
42
48
  envelope.create("📦::🌍")
43
49
  ```
50
+
44
51
  lets you create environments from any of the above!
45
52
 
46
53
  ## 📝 Testing
54
+
47
55
  - **Default (no optional compat deps required)**: `uv run pytest -m "not compat"`
48
56
  - **Compat suite (requires full compat dependency group)**:
49
57
  - `uv sync --group compat`
@@ -51,11 +59,14 @@ let's you create environments from any of the above!
51
59
  - If any compat dependency is missing/broken, the run will fail fast with an error telling you what to install.
52
60
 
53
61
  ## 🏗️ Installation
62
+
54
63
  ```bash
55
64
  pip install jax-envelope
56
65
  ```
57
66
 
58
67
  ## 💞 Related projects
59
- * [stoa](https://github.com/EdanToledo/Stoa) is a very similar project that provides adapters and wrappers for the jumanji-like interface.
60
- * Check out all the great suites we have adapters for! [gymnax](https://github.com/RobertTLange/gymnax), [brax](https://github.com/google/brax), [jumanji](https://github.com/instadeepai/jumanji), [kinetix](https://github.com/flairox/kinetix), [craftax](https://github.com/MichaelTMatthews/craftax), [mujoco_playground](https://github.com/google-deepmind/mujoco_playground).
61
- * We will be adding support for [jaxmarl](https://github.com/flairox/jaxmarl) and [pgx](https://github.com/sotetsuk/pgx) in the future, as soon as we figured out the best ever MARL interface for JAX!
68
+
69
+ - [stoa](https://github.com/EdanToledo/Stoa) is a very similar project that provides adapters and wrappers for the jumanji-like interface.
70
+ - Check out all the great suites we have adapters for! [gymnax](https://github.com/RobertTLange/gymnax), [brax](https://github.com/google/brax), [jumanji](https://github.com/instadeepai/jumanji), [kinetix](https://github.com/flairox/kinetix), [craftax](https://github.com/MichaelTMatthews/craftax), [mujoco_playground](https://github.com/google-deepmind/mujoco_playground).
71
+ - We will be adding support for [jaxmarl](https://github.com/flairox/jaxmarl) and [pgx](https://github.com/sotetsuk/pgx) in the future, as soon as we have figured out the best ever MARL interface for JAX!
72
+
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "jax-envelope"
3
- version = "0.2.0"
3
+ version = "0.3.0"
4
4
  description = "A JAX-native environment interface with powerful wrappers and adapters for popular RL environment suites"
5
5
  readme = "README.md"
6
6
  requires-python = ">=3.12"
@@ -69,7 +69,7 @@ class BraxEnvelope(Environment):
69
69
  info = InfoContainer(
70
70
  obs=brax_state.obs,
71
71
  reward=brax_state.reward,
72
- terminated=jnp.asarry(brax_state.done, dtype=bool).item(),
72
+ terminated=jnp.asarray(brax_state.done, dtype=bool),
73
73
  )
74
74
  info = info.update(**dataclasses.asdict(brax_state))
75
75
  return brax_state, info
@@ -81,7 +81,7 @@ class JumanjiEnvelope(Environment):
81
81
 
82
82
 
83
83
  def convert_jumanji_to_envelope_info(timestep: JumanjiTimeStep) -> InfoContainer:
84
- term = jnp.asarray(timestep.last(), dtype=bool).item()
84
+ term = jnp.asarray(timestep.last(), dtype=bool)
85
85
  info = InfoContainer(
86
86
  obs=timestep.observation, reward=timestep.reward, terminated=term
87
87
  ).update(**timestep.extras)
@@ -42,7 +42,7 @@ class Environment(ABC, FrozenPyTreeNode):
42
42
 
43
43
  Two distinct lifecycle methods:
44
44
  init(key) - Initialize environment and all state from scratch.
45
- reset(key, state) - Reset the inner environment while preserving
45
+ reset(state, key) - Reset the inner environment while preserving
46
46
  episode-persistent state.
47
47
  """
48
48
 
@@ -51,7 +51,7 @@ class Environment(ABC, FrozenPyTreeNode):
51
51
  """Initialize environment and all state from scratch."""
52
52
  ...
53
53
 
54
- def reset(self, key: Key, state: State) -> tuple[State, Info]:
54
+ def reset(self, state: State, key: Key) -> tuple[State, Info]:
55
55
  """Reset the inner environment while preserving episode-persistent state."""
56
56
  return self.init(key)
57
57
 
@@ -1,6 +1,6 @@
1
1
  from abc import ABC, abstractmethod
2
2
  from functools import cached_property
3
- from typing import cast, override
3
+ from typing import override
4
4
 
5
5
  import jax
6
6
  from jax import numpy as jnp
@@ -54,7 +54,7 @@ class AutoResetWrapper(Wrapper):
54
54
  return state, info.update(final=state.last_final)
55
55
 
56
56
  @override
57
- def reset(self, key: Key, state: WrappedState) -> tuple[WrappedState, Info]:
57
+ def reset(self, state: WrappedState, key: Key) -> tuple[WrappedState, Info]:
58
58
  raise NotImplementedError("Reset is not implemented for AutoResetWrapper")
59
59
 
60
60
  @override
@@ -63,7 +63,7 @@ class AutoResetWrapper(Wrapper):
63
63
  state = state.replace(reset_key=key)
64
64
 
65
65
  inner_state, info = self.env.step(state.inner_state, action)
66
- reset_inner_state, reset_info = self.env.reset(key_reset, inner_state)
66
+ reset_inner_state, reset_info = self.env.reset(inner_state, key_reset)
67
67
 
68
68
  # Select next state and info based on done
69
69
  done = info.terminated | info.truncated
@@ -35,8 +35,8 @@ class ContinuousObservationWrapper(Wrapper):
35
35
  return state, info
36
36
 
37
37
  @override
38
- def reset(self, key: Key, state: State) -> tuple[State, Info]:
39
- state, info = self.env.reset(key, state)
38
+ def reset(self, state: State, key: Key) -> tuple[State, Info]:
39
+ state, info = self.env.reset(state, key)
40
40
  info = info.update(obs=to_float(info.obs))
41
41
  return state, info
42
42
 
@@ -24,8 +24,8 @@ class EpisodeStatisticsWrapper(Wrapper):
24
24
  return state, info.update(stats=state.stats)
25
25
 
26
26
  @override
27
- def reset(self, key: Key, state: State) -> tuple[State, Info]:
28
- inner_state, info = self.env.reset(key, state.inner_state)
27
+ def reset(self, state: State, key: Key) -> tuple[State, Info]:
28
+ inner_state, info = self.env.reset(state.inner_state, key)
29
29
  state = state.replace(inner_state=inner_state)
30
30
  return state, info.update(stats=state.stats)
31
31
 
@@ -37,10 +37,9 @@ class FlattenObservationWrapper(Wrapper):
37
37
  return state, info
38
38
 
39
39
  @override
40
- def reset(self, key: Key, state: State) -> tuple[State, Info]:
41
- state, info = self.env.reset(key, state)
42
- info = info.update(obs=flatten_x(info.obs))
43
- return state, info
40
+ def reset(self, state: State, key: Key) -> tuple[State, Info]:
41
+ next_state, info = self.env.reset(state, key)
42
+ return next_state, info.update(obs=flatten_x(info.obs))
44
43
 
45
44
  @override
46
45
  def step(self, state: State, action: PyTree) -> tuple[State, Info]:
@@ -76,8 +76,8 @@ class ObservationNormalizationWrapper(Wrapper):
76
76
  return self._normalize_and_update(next_state, info)
77
77
 
78
78
  @override
79
- def reset(self, key: Key, state: WrappedState) -> tuple[WrappedState, Info]:
80
- inner_state, info = self.env.reset(key, state.inner_state)
79
+ def reset(self, state: WrappedState, key: Key) -> tuple[WrappedState, Info]:
80
+ inner_state, info = self.env.reset(state.inner_state, key)
81
81
  # Preserve running statistics across resets
82
82
  next_state = self.ObservationNormalizationState(
83
83
  inner_state=inner_state, rmv_state=state.rmv_state
@@ -36,7 +36,7 @@ class PooledInitVmapWrapper(Wrapper):
36
36
  return state, info.update(final=pholder_info)
37
37
 
38
38
  @override
39
- def reset(self, key: Key, state: WrappedState) -> tuple[WrappedState, Info]:
39
+ def reset(self, state: WrappedState, key: Key) -> tuple[WrappedState, Info]:
40
40
  # It's hard to support reset for this wrapper.
41
41
  # We would have to init the state of a pool of unwrapped environments, and then
42
42
  # somehow inject this into the stack of wrapped states. The current data
@@ -48,7 +48,7 @@ class PooledInitVmapWrapper(Wrapper):
48
48
  # episodes before vmapping, we will implement this later.
49
49
  keys = _split_or_keep_key(key, self.batch_size + 1)
50
50
  key_next, keys_pool = keys[0], keys[1:]
51
- inner_state, info = jax.vmap(self.env.reset)(keys_pool, state.inner_state)
51
+ inner_state, info = jax.vmap(self.env.reset)(state.inner_state, keys_pool)
52
52
  state = state.replace(inner_state=inner_state, init_key=key_next)
53
53
  return state, info.update(final=state.last_final)
54
54
 
@@ -69,13 +69,13 @@ class StateInjectionWrapper(Wrapper):
69
69
  return state, info
70
70
 
71
71
  @override
72
- def reset(self, key: Key, state: WrappedState) -> tuple[WrappedState, Info]:
72
+ def reset(self, state: WrappedState, key: Key) -> tuple[WrappedState, Info]:
73
73
  # If reset state is set, use it instead of resetting inner env
74
74
  if state.reset_state is not None and state.reset_obs is not None:
75
75
  inner_state = state.reset_state
76
76
  info = InfoContainer(obs=state.reset_obs, reward=0.0, terminated=False)
77
77
  elif state.reset_state is None and state.reset_obs is None:
78
- inner_state, info = self.env.reset(key, state.inner_state)
78
+ inner_state, info = self.env.reset(state.inner_state, key)
79
79
  else:
80
80
  raise ValueError("State must set both reset_state and reset_obs or neither")
81
81
 
@@ -21,8 +21,8 @@ class TruncationWrapper(Wrapper):
21
21
  return state, info.update(truncated=self.max_steps <= 0)
22
22
 
23
23
  @override
24
- def reset(self, key: Key, state: WrappedState) -> tuple[WrappedState, Info]:
25
- inner_state, info = self.env.reset(key, state.inner_state)
24
+ def reset(self, state: WrappedState, key: Key) -> tuple[WrappedState, Info]:
25
+ inner_state, info = self.env.reset(state.inner_state, key)
26
26
  state = state.replace(inner_state=inner_state, steps=0)
27
27
  return state, info.update(truncated=self.max_steps <= 0)
28
28
 
@@ -40,11 +40,9 @@ class VmapEnvsWrapper(Wrapper):
40
40
  return state, info
41
41
 
42
42
  @override
43
- def reset(self, key: Key, state: PyTree) -> tuple[WrappedState, Info]:
43
+ def reset(self, state: PyTree, key: Key) -> tuple[WrappedState, Info]:
44
44
  keys = self._split_keys(key)
45
- state, info = jax.vmap(lambda e, k, s: e.reset(k, s))(
46
- self.env, keys, state
47
- )
45
+ state, info = jax.vmap(lambda e, s, k: e.reset(s, k))(self.env, state, keys)
48
46
  return state, info
49
47
 
50
48
  @override
@@ -58,7 +56,9 @@ class VmapEnvsWrapper(Wrapper):
58
56
  @property
59
57
  def observation_space(self) -> spaces.Space:
60
58
  env0 = _index_env(self.env, 0, self.batch_size)
61
- return spaces.BatchedSpace(space=env0.observation_space, batch_size=self.batch_size)
59
+ return spaces.BatchedSpace(
60
+ space=env0.observation_space, batch_size=self.batch_size
61
+ )
62
62
 
63
63
  @override
64
64
  @cached_property
@@ -41,9 +41,9 @@ class VmapWrapper(Wrapper):
41
41
  return state, info
42
42
 
43
43
  @override
44
- def reset(self, key: Key, state: PyTree) -> tuple[WrappedState, Info]:
44
+ def reset(self, state: PyTree, key: Key) -> tuple[WrappedState, Info]:
45
45
  keys = _split_or_keep_key(key, self.batch_size)
46
- state, info = jax.vmap(self.env.reset)(keys, state)
46
+ state, info = jax.vmap(self.env.reset)(state, keys)
47
47
  return state, info
48
48
 
49
49
  @override
@@ -29,11 +29,11 @@ class Wrapper(Environment):
29
29
  return self.env.init(key)
30
30
 
31
31
  @override
32
- def reset(self, key: Key, state: State) -> tuple[State, Info]:
33
- return self.env.reset(key, state)
32
+ def reset(self, state: State, key: Key) -> tuple[State, Info]:
33
+ return self.env.reset(state, key)
34
34
 
35
35
  @override
36
- def step(self, state: WrappedState, action: PyTree) -> tuple[WrappedState, Info]:
36
+ def step(self, state: State, action: PyTree) -> tuple[State, Info]:
37
37
  return self.env.step(state, action)
38
38
 
39
39
  @override
@@ -93,8 +93,8 @@ def test_wrapper_unwrapping():
93
93
 
94
94
  # Create a simple wrapper
95
95
  class SimpleWrapper(BraxWrapper):
96
- def reset(self, rng):
97
- return self.env.reset(rng)
96
+ def init(self, rng):
97
+ return self.env.init(rng)
98
98
 
99
99
  def step(self, state, action):
100
100
  return self.env.step(state, action)
@@ -98,7 +98,7 @@ class StepCounterEnv(Environment):
98
98
  truncated=truncated,
99
99
  )
100
100
 
101
- def reset(self, key: Key, state: State) -> tuple[StepState, InfoContainer]:
101
+ def reset(self, state: State, key: Key) -> tuple[StepState, InfoContainer]:
102
102
  return self.init(key)
103
103
 
104
104
  def step(
@@ -198,9 +198,7 @@ class NoStepsEnv(Environment):
198
198
  obs=s.env_state, reward=0.0, terminated=False, truncated=False
199
199
  )
200
200
 
201
- def reset(
202
- self, key: Key, state: State
203
- ) -> tuple[NoStepsState, InfoContainer]:
201
+ def reset(self, state: State, key: Key) -> tuple[NoStepsState, InfoContainer]:
204
202
  return self.init(key)
205
203
 
206
204
  def step(
@@ -230,9 +228,7 @@ class AlternatingTerminationEnv(Environment):
230
228
  obs=s.env_state, reward=0.0, terminated=False, truncated=False
231
229
  )
232
230
 
233
- def reset(
234
- self, key: Key, state: State
235
- ) -> tuple[StepState, InfoContainer]:
231
+ def reset(self, state: State, key: Key) -> tuple[StepState, InfoContainer]:
236
232
  return self.init(key)
237
233
 
238
234
  def step(
@@ -266,7 +262,7 @@ class ScalarToyEnv(Environment):
266
262
  s = jnp.asarray(0.0, dtype=jnp.float32)
267
263
  return s, InfoContainer(obs=s, reward=0.0, terminated=False, truncated=False)
268
264
 
269
- def reset(self, key: Key, state: State) -> tuple[State, Info]:
265
+ def reset(self, state: State, key: Key) -> tuple[State, Info]:
270
266
  return self.init(key)
271
267
 
272
268
  def step(self, state: State, action: jax.Array) -> tuple[State, Info]:
@@ -300,7 +296,7 @@ class VectorToyEnv(Environment):
300
296
  s = jnp.zeros((self.dim,), dtype=jnp.float32)
301
297
  return s, InfoContainer(obs=s, reward=0.0, terminated=False, truncated=False)
302
298
 
303
- def reset(self, key: Key, state: State) -> tuple[State, Info]:
299
+ def reset(self, state: State, key: Key) -> tuple[State, Info]:
304
300
  return self.init(key)
305
301
 
306
302
  def step(self, state: State, action: jax.Array) -> tuple[State, Info]:
@@ -330,7 +326,7 @@ class FlagDoneEnv(Environment):
330
326
  z = jnp.array(0.0)
331
327
  return z, InfoContainer(obs=z, reward=0.0, terminated=False, truncated=False)
332
328
 
333
- def reset(self, key: Key, state: State):
329
+ def reset(self, state: State, key: Key):
334
330
  return self.init(key)
335
331
 
336
332
  def step(self, state: State, action: jax.Array):
@@ -366,7 +362,7 @@ class ParamEnv(Environment):
366
362
  s = jnp.asarray([self.offset, -self.offset], dtype=jnp.float32)
367
363
  return s, InfoContainer(obs=s, reward=0.0, terminated=False, truncated=False)
368
364
 
369
- def reset(self, key: Key, state: State) -> tuple[State, Info]:
365
+ def reset(self, state: State, key: Key) -> tuple[State, Info]:
370
366
  return self.init(key)
371
367
 
372
368
  def step(self, state: State, action: jax.Array) -> tuple[State, Info]:
@@ -401,7 +397,7 @@ class VectorObsEnv(Environment):
401
397
  s = jnp.linspace(0.0, 1.0, self.dim, dtype=jnp.float32)
402
398
  return s, InfoContainer(obs=s, reward=0.0, terminated=False, truncated=False)
403
399
 
404
- def reset(self, key: Key, state: State):
400
+ def reset(self, state: State, key: Key):
405
401
  return self.init(key)
406
402
 
407
403
  def step(self, state: State, action: jax.Array):
@@ -444,7 +440,7 @@ class PyTreeObsEnv(Environment):
444
440
  s = obs
445
441
  return s, InfoContainer(obs=obs, reward=0.0, terminated=False, truncated=False)
446
442
 
447
- def reset(self, key: Key, state: State):
443
+ def reset(self, state: State, key: Key):
448
444
  return self.init(key)
449
445
 
450
446
  def step(self, state: State, action: jax.Array):
@@ -476,7 +472,7 @@ class ConstantObsEnv(Environment):
476
472
  obs = jnp.asarray(self.value, self.dtype) * jnp.ones(self.shape, self.dtype)
477
473
  return 0, InfoContainer(obs=obs, reward=0.0, terminated=False, truncated=False)
478
474
 
479
- def reset(self, key: Key, state: State):
475
+ def reset(self, state: State, key: Key):
480
476
  return self.init(key)
481
477
 
482
478
  def step(self, state: State, action: jax.Array):
@@ -500,10 +496,12 @@ class PyTreeActionEnv(Environment):
500
496
 
501
497
  @cached_property
502
498
  def action_space(self) -> PyTreeSpace:
503
- return PyTreeSpace({
504
- "a": Continuous.from_shape(low=-1.0, high=1.0, shape=(2,)),
505
- "b": Continuous.from_shape(low=-1.0, high=1.0, shape=(3,)),
506
- })
499
+ return PyTreeSpace(
500
+ {
501
+ "a": Continuous.from_shape(low=-1.0, high=1.0, shape=(2,)),
502
+ "b": Continuous.from_shape(low=-1.0, high=1.0, shape=(3,)),
503
+ }
504
+ )
507
505
 
508
506
  def _action_to_vec(self, action: PyTree) -> jax.Array:
509
507
  leaves = jax.tree.leaves(action)
@@ -513,12 +511,10 @@ class PyTreeActionEnv(Environment):
513
511
  s = jnp.zeros(5, dtype=jnp.float32)
514
512
  return s, InfoContainer(obs=s, reward=0.0, terminated=False, truncated=False)
515
513
 
516
- def reset(self, key: Key, state: State) -> tuple[jax.Array, InfoContainer]:
514
+ def reset(self, state: State, key: Key) -> tuple[jax.Array, InfoContainer]:
517
515
  return self.init(key)
518
516
 
519
- def step(
520
- self, state: jax.Array, action: PyTree
521
- ) -> tuple[jax.Array, InfoContainer]:
517
+ def step(self, state: jax.Array, action: PyTree) -> tuple[jax.Array, InfoContainer]:
522
518
  vec = self._action_to_vec(action)
523
519
  ns = state + jnp.asarray(vec, dtype=jnp.float32)
524
520
  reward = jnp.sum(vec)
@@ -545,7 +541,7 @@ class IntObsEnv(Environment):
545
541
  s = jnp.array(0, dtype=jnp.int32)
546
542
  return s, InfoContainer(obs=s, reward=0.0, terminated=False, truncated=False)
547
543
 
548
- def reset(self, key: Key, state: State):
544
+ def reset(self, state: State, key: Key):
549
545
  return self.init(key)
550
546
 
551
547
  def step(self, state: State, action: jax.Array):
@@ -581,7 +577,7 @@ class RandomImageEnv(Environment):
581
577
  obs=obs.astype(self.dtype), reward=0.0, terminated=False, truncated=False
582
578
  )
583
579
 
584
- def reset(self, key: Key, state: State):
580
+ def reset(self, state: State, key: Key):
585
581
  return self.init(key)
586
582
 
587
583
  def step(self, state: State, action: jax.Array):
@@ -618,9 +614,7 @@ class WrapperSimpleEnv(Environment):
618
614
  info = TestInfo(obs=state, reward=0.0, terminated=False, truncated=False)
619
615
  return state, info
620
616
 
621
- def reset(
622
- self, key: Key, state: State
623
- ) -> tuple[jax.Array, TestInfo]:
617
+ def reset(self, state: State, key: Key) -> tuple[jax.Array, TestInfo]:
624
618
  return self.init(key)
625
619
 
626
620
  def step(self, state: jax.Array, action: jax.Array) -> tuple[jax.Array, TestInfo]:
@@ -650,9 +644,7 @@ class WrapperEnvWithFields(Environment):
650
644
  info = TestInfo(obs=state, reward=0.0, terminated=False, truncated=False)
651
645
  return state, info
652
646
 
653
- def reset(
654
- self, key: Key, state: State
655
- ) -> tuple[jax.Array, TestInfo]:
647
+ def reset(self, state: State, key: Key) -> tuple[jax.Array, TestInfo]:
656
648
  return self.init(key)
657
649
 
658
650
  def step(self, state: jax.Array, action: jax.Array) -> tuple[jax.Array, TestInfo]:
@@ -679,9 +671,7 @@ class WrapperEnvWithMethods(Environment):
679
671
  info = TestInfo(obs=state, reward=0.0, terminated=False, truncated=False)
680
672
  return state, info
681
673
 
682
- def reset(
683
- self, key: Key, state: State
684
- ) -> tuple[jax.Array, TestInfo]:
674
+ def reset(self, state: State, key: Key) -> tuple[jax.Array, TestInfo]:
685
675
  return self.init(key)
686
676
 
687
677
  def step(self, state: jax.Array, action: jax.Array) -> tuple[jax.Array, TestInfo]:
@@ -727,7 +717,7 @@ def make_wrapper_discrete_env() -> Environment:
727
717
  info = TestInfo(obs=state, reward=0.0, terminated=False, truncated=False)
728
718
  return state, info
729
719
 
730
- def reset(self, key: Key, state: State):
720
+ def reset(self, state: State, key: Key):
731
721
  return self.init(key)
732
722
 
733
723
  def step(self, state: jax.Array, action: jax.Array):
@@ -762,7 +752,7 @@ def make_wrapper_complex_state_env() -> Environment:
762
752
  )
763
753
  return st, info
764
754
 
765
- def reset(self, key: Key, state: State):
755
+ def reset(self, state: State, key: Key):
766
756
  return self.init(key)
767
757
 
768
758
  def step(self, state: dict, action: jax.Array):
@@ -544,8 +544,8 @@ def test_auto_reset_passes_state_to_inner_wrapper():
544
544
  received_state_on_reset=False,
545
545
  ), info
546
546
 
547
- def reset(self, key, state):
548
- inner_state, info = self.env.reset(key, state.inner_state)
547
+ def reset(self, state, key):
548
+ inner_state, info = self.env.reset(state.inner_state, key)
549
549
  return self.TrackingState(
550
550
  inner_state=inner_state,
551
551
  received_state_on_reset=True,
@@ -21,8 +21,8 @@ def test_init_reset_delegate_unchanged():
21
21
  assert jnp.allclose(state_w, state_e)
22
22
  assert jnp.allclose(info_w.obs, info_e.obs)
23
23
 
24
- state_w, info_w = w.reset(key, state_w)
25
- state_e, info_e = env.reset(key, state_e)
24
+ state_w, info_w = w.reset(state_w, key)
25
+ state_e, info_e = env.reset(state_e, key)
26
26
  assert jnp.allclose(state_w, state_e)
27
27
  assert jnp.allclose(info_w.obs, info_e.obs)
28
28
  assert w.observation_space.contains(info_w.obs)
@@ -17,7 +17,7 @@ def test_init_reset_step_cast_discrete_obs_to_float32():
17
17
  key = jax.random.PRNGKey(0)
18
18
  state, info = w.init(key)
19
19
  assert info.obs.dtype == jnp.float32
20
- state, info = w.reset(key, state)
20
+ state, info = w.reset(state, key)
21
21
  assert info.obs.dtype == jnp.float32
22
22
  assert w.observation_space.contains(info.obs)
23
23
  state, info = w.step(state, jnp.array(0, dtype=jnp.int32))
@@ -66,7 +66,7 @@ def test_reset_preserves_stats():
66
66
  state, _ = w.step(state, jnp.asarray(0.2))
67
67
  reward_before = state.stats.reward
68
68
  length_before = state.stats.length
69
- state, info = w.reset(key, state)
69
+ state, info = w.reset(state, key)
70
70
  assert jnp.allclose(state.stats.reward, reward_before)
71
71
  assert jnp.allclose(state.stats.length, length_before)
72
72
  assert jnp.allclose(info.stats.reward, reward_before)
@@ -81,7 +81,7 @@ def test_stats_persist_and_continue_after_reset():
81
81
  state, _ = w.init(key)
82
82
  for _ in range(3):
83
83
  state, _ = w.step(state, jnp.asarray(0.1))
84
- state, _ = w.reset(key, state)
84
+ state, _ = w.reset(state, key)
85
85
  for _ in range(2):
86
86
  state, _ = w.step(state, jnp.asarray(0.1))
87
87
  # Total length = 3 + 2 = 5, reward = 0.1*5 = 0.5
@@ -60,8 +60,8 @@ def test_init_reset_delegate_unchanged():
60
60
  state_e, info_e = env.init(key)
61
61
  assert jnp.allclose(state_w, state_e)
62
62
  assert jnp.allclose(info_w.obs, info_e.obs)
63
- state_w, info_w = w.reset(key, state_w)
64
- state_e, info_e = env.reset(key, state_e)
63
+ state_w, info_w = w.reset(state_w, key)
64
+ state_e, info_e = env.reset(state_e, key)
65
65
  assert jnp.allclose(state_w, state_e)
66
66
  assert jnp.allclose(info_w.obs, info_e.obs)
67
67
 
@@ -101,7 +101,7 @@ def test_action_space_flattened_discrete():
101
101
  obs=s, reward=0.0, terminated=False, truncated=False
102
102
  )
103
103
 
104
- def reset(self, key, state):
104
+ def reset(self, state, key):
105
105
  return self.init(key)
106
106
 
107
107
  def step(self, state, action):
@@ -161,7 +161,7 @@ def test_mixed_space_types_raises_value_error():
161
161
  obs=s, reward=0.0, terminated=False, truncated=False
162
162
  )
163
163
 
164
- def reset(self, key, state):
164
+ def reset(self, state, key):
165
165
  return self.init(key)
166
166
 
167
167
  def step(self, state, action):
@@ -29,7 +29,7 @@ def test_reset_step_flatten_pytree_obs():
29
29
  key = jax.random.PRNGKey(0)
30
30
  state, info = w.init(key)
31
31
  assert info.obs.shape == (5,)
32
- state, info = w.reset(key, state)
32
+ state, info = w.reset(state, key)
33
33
  assert info.obs.shape == (5,)
34
34
  assert w.observation_space.contains(info.obs)
35
35
  state, info = w.step(state, jnp.array(0.0))
@@ -126,7 +126,7 @@ def test_mixed_space_types_raises_value_error():
126
126
  obs=obs, reward=0.0, terminated=False, truncated=False
127
127
  )
128
128
 
129
- def reset(self, key, state):
129
+ def reset(self, state, key):
130
130
  return self.init(key)
131
131
 
132
132
  def step(self, state, action):
@@ -108,7 +108,7 @@ def test_reset_vmaps_inner_reset():
108
108
  w = PooledInitVmapWrapper(env=env, batch_size=batch_size, pool_size=3)
109
109
  key = jax.random.PRNGKey(0)
110
110
  state, info = w.init(key)
111
- state, info = w.reset(key, state)
111
+ state, info = w.reset(state, key)
112
112
  assert info.obs.shape == (batch_size,)
113
113
  assert w.observation_space.contains(info.obs)
114
114
 
@@ -66,7 +66,7 @@ class TestStateInjectionCoreFunctionality:
66
66
 
67
67
  # Reset again, passing the current state (simulates auto-reset)
68
68
  key2 = jax.random.PRNGKey(1)
69
- state2, info2 = w.reset(key2, state)
69
+ state2, info2 = w.reset(state, key2)
70
70
 
71
71
  # Should preserve the injected state
72
72
  assert jnp.allclose(state2.reset_state.env_state, jnp.array(42.0))
@@ -132,7 +132,7 @@ class TestStateInjectionCoreFunctionality:
132
132
 
133
133
  # Reset with this state (no reset_state set) - should delegate to inner env
134
134
  key2 = jax.random.PRNGKey(1)
135
- state2, info2 = w.reset(key2, state)
135
+ state2, info2 = w.reset(state, key2)
136
136
 
137
137
  # Should have done a normal reset - inner_state is fresh from env
138
138
  assert jnp.allclose(state2.inner_state.env_state, jnp.array(0.0))
@@ -166,7 +166,7 @@ class TestStateInjectionCoreFunctionality:
166
166
  )
167
167
 
168
168
  with pytest.raises(ValueError, match="must set both"):
169
- w.reset(key, state_with_only_reset_state)
169
+ w.reset(state_with_only_reset_state, key)
170
170
 
171
171
  # Create state with only reset_obs set (not reset_state)
172
172
  state_with_only_reset_obs = w.InjectedState(
@@ -176,7 +176,7 @@ class TestStateInjectionCoreFunctionality:
176
176
  )
177
177
 
178
178
  with pytest.raises(ValueError, match="must set both"):
179
- w.reset(key, state_with_only_reset_obs)
179
+ w.reset(state_with_only_reset_obs, key)
180
180
 
181
181
 
182
182
  # ============================================================================
@@ -116,7 +116,7 @@ def test_steps_as_jax_scalar_array_behaves_correctly():
116
116
 
117
117
 
118
118
  def test_reset_with_state_passes_inner_state_down():
119
- """reset(key, state) should pass state.inner_state to the inner env's reset."""
119
+ """reset(state, key) should pass state.inner_state to the inner env's reset."""
120
120
  env = StepCounterEnv()
121
121
  w = TruncationWrapper(env=env, max_steps=10)
122
122
  key = jax.random.PRNGKey(0)
@@ -126,7 +126,7 @@ def test_reset_with_state_passes_inner_state_down():
126
126
  state, _ = w.step(state, jnp.asarray(0.1))
127
127
  assert state.steps == 5
128
128
 
129
- new_state, _ = w.reset(jax.random.PRNGKey(1), state)
129
+ new_state, _ = w.reset(state, jax.random.PRNGKey(1))
130
130
 
131
131
  # Inner env should be reset
132
132
  assert jnp.allclose(new_state.inner_state.env_state, 0.0)
@@ -1056,7 +1056,7 @@ wheels = [
1056
1056
 
1057
1057
  [[package]]
1058
1058
  name = "jax-envelope"
1059
- version = "0.2.0"
1059
+ version = "0.3.0"
1060
1060
  source = { editable = "." }
1061
1061
  dependencies = [
1062
1062
  { name = "jax" },
File without changes
File without changes