jax-envelope 0.2.0__tar.gz → 0.4.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (73) hide show
  1. {jax_envelope-0.2.0 → jax_envelope-0.4.0}/.github/workflows/publish.yml +1 -1
  2. {jax_envelope-0.2.0 → jax_envelope-0.4.0}/PKG-INFO +38 -28
  3. jax_envelope-0.4.0/README.md +71 -0
  4. {jax_envelope-0.2.0 → jax_envelope-0.4.0}/pyproject.toml +3 -3
  5. {jax_envelope-0.2.0 → jax_envelope-0.4.0}/src/envelope/__init__.py +1 -1
  6. {jax_envelope-0.2.0/src/envelope/compat → jax_envelope-0.4.0/src/envelope/adapters}/__init__.py +7 -7
  7. {jax_envelope-0.2.0/src/envelope/compat → jax_envelope-0.4.0/src/envelope/adapters}/brax_envelope.py +1 -1
  8. {jax_envelope-0.2.0/src/envelope/compat → jax_envelope-0.4.0/src/envelope/adapters}/craftax_envelope.py +1 -1
  9. {jax_envelope-0.2.0/src/envelope/compat → jax_envelope-0.4.0/src/envelope/adapters}/jumanji_envelope.py +1 -1
  10. {jax_envelope-0.2.0/src/envelope/compat → jax_envelope-0.4.0/src/envelope/adapters}/kinetix_envelope.py +2 -2
  11. {jax_envelope-0.2.0 → jax_envelope-0.4.0}/src/envelope/environment.py +2 -2
  12. {jax_envelope-0.2.0 → jax_envelope-0.4.0}/src/envelope/spaces.py +1 -1
  13. {jax_envelope-0.2.0 → jax_envelope-0.4.0}/src/envelope/wrappers/autoreset_wrapper.py +2 -2
  14. {jax_envelope-0.2.0 → jax_envelope-0.4.0}/src/envelope/wrappers/continuous_observation_wrapper.py +2 -2
  15. {jax_envelope-0.2.0 → jax_envelope-0.4.0}/src/envelope/wrappers/episode_statistics_wrapper.py +2 -2
  16. {jax_envelope-0.2.0 → jax_envelope-0.4.0}/src/envelope/wrappers/flatten_observation_wrapper.py +3 -4
  17. {jax_envelope-0.2.0 → jax_envelope-0.4.0}/src/envelope/wrappers/observation_normalization_wrapper.py +2 -2
  18. {jax_envelope-0.2.0 → jax_envelope-0.4.0}/src/envelope/wrappers/pooled_init_vmap_wrapper.py +2 -2
  19. {jax_envelope-0.2.0 → jax_envelope-0.4.0}/src/envelope/wrappers/state_injection_wrapper.py +2 -2
  20. {jax_envelope-0.2.0 → jax_envelope-0.4.0}/src/envelope/wrappers/truncation_wrapper.py +2 -2
  21. {jax_envelope-0.2.0 → jax_envelope-0.4.0}/src/envelope/wrappers/vmap_envs_wrapper.py +5 -5
  22. {jax_envelope-0.2.0 → jax_envelope-0.4.0}/src/envelope/wrappers/vmap_wrapper.py +2 -2
  23. {jax_envelope-0.2.0 → jax_envelope-0.4.0}/src/envelope/wrappers/wrapper.py +3 -3
  24. {jax_envelope-0.2.0/tests/compat → jax_envelope-0.4.0/tests/adapters}/conftest.py +1 -1
  25. {jax_envelope-0.2.0/tests/compat → jax_envelope-0.4.0/tests/adapters}/contract.py +2 -2
  26. {jax_envelope-0.2.0/tests/compat → jax_envelope-0.4.0/tests/adapters}/test_brax_compat.py +6 -6
  27. {jax_envelope-0.2.0/tests/compat → jax_envelope-0.4.0/tests/adapters}/test_craftax_compat.py +5 -5
  28. {jax_envelope-0.2.0/tests/compat → jax_envelope-0.4.0/tests/adapters}/test_create.py +7 -7
  29. {jax_envelope-0.2.0/tests/compat → jax_envelope-0.4.0/tests/adapters}/test_create_integration.py +12 -12
  30. {jax_envelope-0.2.0/tests/compat → jax_envelope-0.4.0/tests/adapters}/test_gymnax_compat.py +2 -2
  31. {jax_envelope-0.2.0/tests/compat → jax_envelope-0.4.0/tests/adapters}/test_jumanji_compat.py +5 -5
  32. {jax_envelope-0.2.0/tests/compat → jax_envelope-0.4.0/tests/adapters}/test_kinetix_compat.py +7 -8
  33. {jax_envelope-0.2.0/tests/compat → jax_envelope-0.4.0/tests/adapters}/test_mujoco_playground_compat.py +4 -4
  34. {jax_envelope-0.2.0/tests/compat → jax_envelope-0.4.0/tests/adapters}/test_navix_compat.py +8 -8
  35. {jax_envelope-0.2.0 → jax_envelope-0.4.0}/tests/wrappers/helpers.py +25 -35
  36. {jax_envelope-0.2.0 → jax_envelope-0.4.0}/tests/wrappers/test_autoreset_wrapper.py +2 -2
  37. {jax_envelope-0.2.0 → jax_envelope-0.4.0}/tests/wrappers/test_clip_action_wrapper.py +2 -2
  38. {jax_envelope-0.2.0 → jax_envelope-0.4.0}/tests/wrappers/test_continuous_observation_wrapper.py +1 -1
  39. {jax_envelope-0.2.0 → jax_envelope-0.4.0}/tests/wrappers/test_episode_statistics_wrapper.py +2 -2
  40. {jax_envelope-0.2.0 → jax_envelope-0.4.0}/tests/wrappers/test_flatten_action_wrapper.py +4 -4
  41. {jax_envelope-0.2.0 → jax_envelope-0.4.0}/tests/wrappers/test_flatten_observation_wrapper.py +2 -2
  42. {jax_envelope-0.2.0 → jax_envelope-0.4.0}/tests/wrappers/test_pooled_init_vmap_wrapper.py +1 -1
  43. {jax_envelope-0.2.0 → jax_envelope-0.4.0}/tests/wrappers/test_state_injection_wrapper.py +4 -4
  44. {jax_envelope-0.2.0 → jax_envelope-0.4.0}/tests/wrappers/test_truncation_wrapper.py +2 -2
  45. {jax_envelope-0.2.0 → jax_envelope-0.4.0}/uv.lock +3 -3
  46. jax_envelope-0.2.0/README.md +0 -61
  47. {jax_envelope-0.2.0 → jax_envelope-0.4.0}/.gitignore +0 -0
  48. {jax_envelope-0.2.0 → jax_envelope-0.4.0}/LICENSE +0 -0
  49. {jax_envelope-0.2.0/src/envelope/compat → jax_envelope-0.4.0/src/envelope/adapters}/gymnax_envelope.py +0 -0
  50. {jax_envelope-0.2.0/src/envelope/compat → jax_envelope-0.4.0/src/envelope/adapters}/mujoco_playground_envelope.py +0 -0
  51. {jax_envelope-0.2.0/src/envelope/compat → jax_envelope-0.4.0/src/envelope/adapters}/navix_envelope.py +0 -0
  52. {jax_envelope-0.2.0 → jax_envelope-0.4.0}/src/envelope/struct.py +0 -0
  53. {jax_envelope-0.2.0 → jax_envelope-0.4.0}/src/envelope/typing.py +0 -0
  54. {jax_envelope-0.2.0 → jax_envelope-0.4.0}/src/envelope/wrappers/__init__.py +0 -0
  55. {jax_envelope-0.2.0 → jax_envelope-0.4.0}/src/envelope/wrappers/clip_action_wrapper.py +0 -0
  56. {jax_envelope-0.2.0 → jax_envelope-0.4.0}/src/envelope/wrappers/flatten_action_wrapper.py +0 -0
  57. {jax_envelope-0.2.0 → jax_envelope-0.4.0}/src/envelope/wrappers/normalization.py +0 -0
  58. {jax_envelope-0.2.0 → jax_envelope-0.4.0}/tests/__init__.py +0 -0
  59. {jax_envelope-0.2.0/tests/compat → jax_envelope-0.4.0/tests/adapters}/__init__.py +0 -0
  60. {jax_envelope-0.2.0 → jax_envelope-0.4.0}/tests/spaces/__init__.py +0 -0
  61. {jax_envelope-0.2.0 → jax_envelope-0.4.0}/tests/spaces/test_batched_space.py +0 -0
  62. {jax_envelope-0.2.0 → jax_envelope-0.4.0}/tests/spaces/test_continuous.py +0 -0
  63. {jax_envelope-0.2.0 → jax_envelope-0.4.0}/tests/spaces/test_discrete.py +0 -0
  64. {jax_envelope-0.2.0 → jax_envelope-0.4.0}/tests/spaces/test_pytree_space.py +0 -0
  65. {jax_envelope-0.2.0 → jax_envelope-0.4.0}/tests/spaces/test_serialization.py +0 -0
  66. {jax_envelope-0.2.0 → jax_envelope-0.4.0}/tests/test_container.py +0 -0
  67. {jax_envelope-0.2.0 → jax_envelope-0.4.0}/tests/test_struct.py +0 -0
  68. {jax_envelope-0.2.0 → jax_envelope-0.4.0}/tests/wrappers/__init__.py +0 -0
  69. {jax_envelope-0.2.0 → jax_envelope-0.4.0}/tests/wrappers/test_environment_wrapper.py +0 -0
  70. {jax_envelope-0.2.0 → jax_envelope-0.4.0}/tests/wrappers/test_normalization.py +0 -0
  71. {jax_envelope-0.2.0 → jax_envelope-0.4.0}/tests/wrappers/test_observation_normalization_wrapper.py +0 -0
  72. {jax_envelope-0.2.0 → jax_envelope-0.4.0}/tests/wrappers/test_vmap_envs_wrapper.py +0 -0
  73. {jax_envelope-0.2.0 → jax_envelope-0.4.0}/tests/wrappers/test_vmap_wrapper.py +0 -0
@@ -28,7 +28,7 @@ jobs:
28
28
  uv sync --group dev --locked
29
29
 
30
30
  - name: Run tests
31
- run: uv run pytest -m "not compat"
31
+ run: uv run pytest -m "not adapters"
32
32
 
33
33
  build:
34
34
  name: Build distribution
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: jax-envelope
3
- Version: 0.2.0
3
+ Version: 0.4.0
4
4
  Summary: A JAX-native environment interface with powerful wrappers and adapters for popular RL environment suites
5
5
  Project-URL: Homepage, https://github.com/keraJLi/envelope
6
6
  Project-URL: Repository, https://github.com/keraJLi/envelope
@@ -25,12 +25,13 @@ Requires-Dist: jax>=0.5.0
25
25
  Description-Content-Type: text/markdown
26
26
 
27
27
  # 💌 Envelope: a JAX-native environment interface
28
+
28
29
  ```python
29
30
  # Create environments from JAX-native suites you have installed, ...
30
31
  env = envelope.create("gymnax::CartPole-v1")
31
32
 
32
33
  # ... interact with the environments using a simple interface, ...
33
- state, info = env.reset(key)
34
+ state, info = env.init(key)
34
35
  states, infos = jax.lax.scan(env.step, state, actions)
35
36
  plt.plot(infos.reward.cumsum())
36
37
 
@@ -41,47 +42,56 @@ env = envelope.wrappers.ObservationNormalizationWrapper(env)
41
42
  ```
42
43
 
43
44
  ## 🌍 Simple, expressive interaction!
44
- * **Environments are pytrees**. Squish them through JAX transformations and trace their parameters.
45
- * **Idiomatic jax-y interface** of `reset(key: Key) -> State, Info` and `step(state: State, action: PyTree) -> State, Info`. You can directly `jax.scan` over a `step(...)`!
46
- * **Spaces are super simple**. No `Tuple`, `Dict` nonsense! There are two spaces: `Continuous` and `Discrete`, which you can compose into a `PyTreeSpace`.
47
- * **Explicit episode truncation** supports correctly handling bootstrapping for value-function targets.
48
- * **No auto-reset** by default. Resetting every step can be expensive!
45
+
46
+ - **Environments are pytrees**. Squish them through JAX transformations and trace their parameters.
47
+ - **Idiomatic jax-y interface** of `init(key: Key) -> State, Info` and `step(state: State, action: PyTree) -> State, Info`. You can directly `jax.scan` over a `step(...)`!
48
+ - **Spaces are super simple**. No `Tuple`, `Dict` nonsense! There are two spaces: `Continuous` and `Discrete`, which you can compose into a `PyTreeSpace`.
49
+ - **Explicit episode truncation** supports correctly handling bootstrapping for value-function targets.
50
+ - **No auto-reset** by default. Resetting every step can be expensive!
49
51
 
50
52
  ## 💪 Powerful, composable wrappers!
51
- * **Carry state across episodes** to track running statistics, for example to normalize observations.
52
- * **Composable wrappers** can be stacked in any order. For example, `ObservationNormalizationWrapper` before vs. after `VmapWrapper` gives per-env vs. global normalization.
53
- <!-- TODO: Add auto-reset behavior (including state injection) and optimistic resets once I implement them. -->
53
+
54
+ - **Carry state across episodes** to track running statistics, for example to normalize observations.
55
+ - **Composable wrappers** can be stacked in any order. For example, `ObservationNormalizationWrapper` before vs. after `VmapWrapper` gives per-env vs. global normalization.
54
56
 
55
57
  ## 🔌 Adapters for existing suites
56
- | 📦 | # 🤖 | # 🌍 |
57
- |------|------|------|
58
- | [gymnax](https://github.com/RobertTLange/gymnax) | 🕺 | 24 |
59
- | [brax](https://github.com/google/brax) | 🕺 | 12 |
60
- | [jumanji](https://github.com/instadeepai/jumanji) | 🕺 / 👯 | 25 / 1 |
61
- | [kinetix](https://github.com/flairox/kinetix) | 🕺 | 74 |
62
- | [craftax](https://github.com/MichaelTMatthews/craftax) | 🕺 | 4 |
63
- | [mujoco_playground](https://github.com/google-deepmind/mujoco_playground) | 🕺 | 54 |
64
- | | |
65
- | Total | 🕺 / 👯 | 193 / 1 |
58
+
59
+
60
+ | 📦 | # 🤖 | # 🌍 |
61
+ | ------------------------------------------------------------------------- | ------- | ------- |
62
+ | [gymnax](https://github.com/RobertTLange/gymnax) | 🕺 | 24 |
63
+ | [brax](https://github.com/google/brax) | 🕺 | 12 |
64
+ | [jumanji](https://github.com/instadeepai/jumanji) | 🕺 / 👯 | 25 / 1 |
65
+ | [kinetix](https://github.com/flairox/kinetix) | 🕺 | 74 |
66
+ | [craftax](https://github.com/MichaelTMatthews/craftax) | 🕺 | 4 |
67
+ | [mujoco_playground](https://github.com/google-deepmind/mujoco_playground) | 🕺 | 54 |
68
+ | | | |
69
+ | Total | 🕺 / 👯 | 193 / 1 |
70
+
66
71
 
67
72
  ```python
68
73
  envelope.create("📦::🌍")
69
74
  ```
75
+
70
76
  lets you create environments from any of the above!
71
77
 
72
78
  ## 📝 Testing
73
- - **Default (no optional compat deps required)**: `uv run pytest -m "not compat"`
74
- - **Compat suite (requires full compat dependency group)**:
75
- - `uv sync --group compat`
76
- - `uv run pytest -m compat`
77
- - If any compat dependency is missing/broken, the run will fail fast with an error telling you what to install.
79
+
80
+ - **Default (no optional adapters deps required)**: `uv run pytest -m "not adapters"`
81
+ - **Adapters suite (requires full adapters dependency group)**:
82
+ - `uv sync --group adapters`
83
+ - `uv run pytest -m adapters`
84
+ - If any adapter dependency is missing/broken, the run will fail fast with an error telling you what to install.
78
85
 
79
86
  ## 🏗️ Installation
87
+
80
88
  ```bash
81
89
  pip install jax-envelope
82
90
  ```
83
91
 
84
92
  ## 💞 Related projects
85
- * [stoa](https://github.com/EdanToledo/Stoa) is a very similar project that provides adapters and wrappers for the jumanji-like interface.
86
- * Check out all the great suites we have adapters for! [gymnax](https://github.com/RobertTLange/gymnax), [brax](https://github.com/google/brax), [jumanji](https://github.com/instadeepai/jumanji), [kinetix](https://github.com/flairox/kinetix), [craftax](https://github.com/MichaelTMatthews/craftax), [mujoco_playground](https://github.com/google-deepmind/mujoco_playground).
87
- * We will be adding support for [jaxmarl](https://github.com/flairox/jaxmarl) and [pgx](https://github.com/sotetsuk/pgx) in the future, as soon as we figured out the best ever MARL interface for JAX!
93
+
94
+ - [stoa](https://github.com/EdanToledo/Stoa) is a very similar project that provides adapters and wrappers for the jumanji-like interface.
95
+ - Check out all the great suites we have adapters for! [gymnax](https://github.com/RobertTLange/gymnax), [brax](https://github.com/google/brax), [jumanji](https://github.com/instadeepai/jumanji), [kinetix](https://github.com/flairox/kinetix), [craftax](https://github.com/MichaelTMatthews/craftax), [mujoco_playground](https://github.com/google-deepmind/mujoco_playground).
96
+ - We will be adding support for [jaxmarl](https://github.com/flairox/jaxmarl) and [pgx](https://github.com/sotetsuk/pgx) in the future, as soon as we have figured out the best ever MARL interface for JAX!
97
+
@@ -0,0 +1,71 @@
1
+ # 💌 Envelope: a JAX-native environment interface
2
+
3
+ ```python
4
+ # Create environments from JAX-native suites you have installed, ...
5
+ env = envelope.create("gymnax::CartPole-v1")
6
+
7
+ # ... interact with the environments using a simple interface, ...
8
+ state, info = env.init(key)
9
+ states, infos = jax.lax.scan(env.step, state, actions)
10
+ plt.plot(infos.reward.cumsum())
11
+
12
+ # ... and enjoy a powerful ecosystem of wrappers.
13
+ env = envelope.wrappers.AutoResetWrapper(env)
14
+ env = envelope.wrappers.VmapWrapper(env)
15
+ env = envelope.wrappers.ObservationNormalizationWrapper(env)
16
+ ```
17
+
18
+ ## 🌍 Simple, expressive interaction!
19
+
20
+ - **Environments are pytrees**. Squish them through JAX transformations and trace their parameters.
21
+ - **Idiomatic jax-y interface** of `init(key: Key) -> State, Info` and `step(state: State, action: PyTree) -> State, Info`. You can directly `jax.scan` over a `step(...)`!
22
+ - **Spaces are super simple**. No `Tuple`, `Dict` nonsense! There are two spaces: `Continuous` and `Discrete`, which you can compose into a `PyTreeSpace`.
23
+ - **Explicit episode truncation** supports correctly handling bootstrapping for value-function targets.
24
+ - **No auto-reset** by default. Resetting every step can be expensive!
25
+
26
+ ## 💪 Powerful, composable wrappers!
27
+
28
+ - **Carry state across episodes** to track running statistics, for example to normalize observations.
29
+ - **Composable wrappers** can be stacked in any order. For example, `ObservationNormalizationWrapper` before vs. after `VmapWrapper` gives per-env vs. global normalization.
30
+
31
+ ## 🔌 Adapters for existing suites
32
+
33
+
34
+ | 📦 | # 🤖 | # 🌍 |
35
+ | ------------------------------------------------------------------------- | ------- | ------- |
36
+ | [gymnax](https://github.com/RobertTLange/gymnax) | 🕺 | 24 |
37
+ | [brax](https://github.com/google/brax) | 🕺 | 12 |
38
+ | [jumanji](https://github.com/instadeepai/jumanji) | 🕺 / 👯 | 25 / 1 |
39
+ | [kinetix](https://github.com/flairox/kinetix) | 🕺 | 74 |
40
+ | [craftax](https://github.com/MichaelTMatthews/craftax) | 🕺 | 4 |
41
+ | [mujoco_playground](https://github.com/google-deepmind/mujoco_playground) | 🕺 | 54 |
42
+ | | | |
43
+ | Total | 🕺 / 👯 | 193 / 1 |
44
+
45
+
46
+ ```python
47
+ envelope.create("📦::🌍")
48
+ ```
49
+
50
+ lets you create environments from any of the above!
51
+
52
+ ## 📝 Testing
53
+
54
+ - **Default (no optional adapters deps required)**: `uv run pytest -m "not adapters"`
55
+ - **Adapters suite (requires full adapters dependency group)**:
56
+ - `uv sync --group adapters`
57
+ - `uv run pytest -m adapters`
58
+ - If any adapter dependency is missing/broken, the run will fail fast with an error telling you what to install.
59
+
60
+ ## 🏗️ Installation
61
+
62
+ ```bash
63
+ pip install jax-envelope
64
+ ```
65
+
66
+ ## 💞 Related projects
67
+
68
+ - [stoa](https://github.com/EdanToledo/Stoa) is a very similar project that provides adapters and wrappers for the jumanji-like interface.
69
+ - Check out all the great suites we have adapters for! [gymnax](https://github.com/RobertTLange/gymnax), [brax](https://github.com/google/brax), [jumanji](https://github.com/instadeepai/jumanji), [kinetix](https://github.com/flairox/kinetix), [craftax](https://github.com/MichaelTMatthews/craftax), [mujoco_playground](https://github.com/google-deepmind/mujoco_playground).
70
+ - We will be adding support for [jaxmarl](https://github.com/flairox/jaxmarl) and [pgx](https://github.com/sotetsuk/pgx) in the future, as soon as we have figured out the best ever MARL interface for JAX!
71
+
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "jax-envelope"
3
- version = "0.2.0"
3
+ version = "0.4.0"
4
4
  description = "A JAX-native environment interface with powerful wrappers and adapters for popular RL environment suites"
5
5
  readme = "README.md"
6
6
  requires-python = ">=3.12"
@@ -52,7 +52,7 @@ allow-direct-references = true
52
52
  packages = ["src/envelope"]
53
53
 
54
54
  [dependency-groups]
55
- compat = [
55
+ adapters = [
56
56
  "brax>=0.13.0",
57
57
  "craftax>=1.4.3",
58
58
  "navix>=0.7.0",
@@ -80,5 +80,5 @@ kinetix-env = { git = "https://github.com/FLAIROx/Kinetix.git" }
80
80
  [tool.pytest.ini_options]
81
81
  testpaths = ["tests"]
82
82
  markers = [
83
- "compat: tests requiring optional compat dependencies",
83
+ "adapters: tests requiring optional adapters dependencies",
84
84
  ]
@@ -1,4 +1,4 @@
1
- from envelope.compat import create
1
+ from envelope.adapters import create
2
2
  from envelope.environment import Environment, Info, InfoContainer
3
3
  from envelope.spaces import BatchedSpace, Continuous, Discrete, PyTreeSpace, Space
4
4
  from envelope.struct import Container, FrozenPyTreeNode, field, static_field
@@ -4,14 +4,14 @@ from typing import Any, Protocol, Self
4
4
 
5
5
  # Lazy imports to avoid requiring all dependencies at once
6
6
  _env_module_map = {
7
- "gymnax": ("envelope.compat.gymnax_envelope", "GymnaxEnvelope"),
8
- "brax": ("envelope.compat.brax_envelope", "BraxEnvelope"),
9
- "navix": ("envelope.compat.navix_envelope", "NavixEnvelope"),
10
- "jumanji": ("envelope.compat.jumanji_envelope", "JumanjiEnvelope"),
11
- "kinetix": ("envelope.compat.kinetix_envelope", "KinetixEnvelope"),
12
- "craftax": ("envelope.compat.craftax_envelope", "CraftaxEnvelope"),
7
+ "gymnax": ("envelope.adapters.gymnax_envelope", "GymnaxEnvelope"),
8
+ "brax": ("envelope.adapters.brax_envelope", "BraxEnvelope"),
9
+ "navix": ("envelope.adapters.navix_envelope", "NavixEnvelope"),
10
+ "jumanji": ("envelope.adapters.jumanji_envelope", "JumanjiEnvelope"),
11
+ "kinetix": ("envelope.adapters.kinetix_envelope", "KinetixEnvelope"),
12
+ "craftax": ("envelope.adapters.craftax_envelope", "CraftaxEnvelope"),
13
13
  "mujoco_playground": (
14
- "envelope.compat.mujoco_playground_envelope",
14
+ "envelope.adapters.mujoco_playground_envelope",
15
15
  "MujocoPlaygroundEnvelope",
16
16
  ),
17
17
  }
@@ -69,7 +69,7 @@ class BraxEnvelope(Environment):
69
69
  info = InfoContainer(
70
70
  obs=brax_state.obs,
71
71
  reward=brax_state.reward,
72
- terminated=jnp.asarry(brax_state.done, dtype=bool).item(),
72
+ terminated=jnp.asarray(brax_state.done, dtype=bool),
73
73
  )
74
74
  info = info.update(**dataclasses.asdict(brax_state))
75
75
  return brax_state, info
@@ -10,7 +10,7 @@ from craftax.craftax_classic.envs.craftax_state import (
10
10
  from craftax.craftax_env import make_craftax_env_from_name
11
11
 
12
12
  from envelope import spaces as envelope_spaces
13
- from envelope.compat.gymnax_envelope import _convert_space as _convert_gymnax_space
13
+ from envelope.adapters.gymnax_envelope import _convert_space as _convert_gymnax_space
14
14
  from envelope.environment import Environment, Info, InfoContainer, State
15
15
  from envelope.struct import Container, static_field
16
16
  from envelope.typing import Key, PyTree, TypeAlias
@@ -81,7 +81,7 @@ class JumanjiEnvelope(Environment):
81
81
 
82
82
 
83
83
  def convert_jumanji_to_envelope_info(timestep: JumanjiTimeStep) -> InfoContainer:
84
- term = jnp.asarray(timestep.last(), dtype=bool).item()
84
+ term = jnp.asarray(timestep.last(), dtype=bool)
85
85
  info = InfoContainer(
86
86
  obs=timestep.observation, reward=timestep.reward, terminated=term
87
87
  ).update(**timestep.extras)
@@ -1,7 +1,7 @@
1
1
  """Kinetix compatibility wrapper.
2
2
 
3
3
  This module exposes Kinetix environments through the `envelope.environment.Environment`
4
- API. It mirrors envelope's compat philosophy:
4
+ API. It mirrors envelope's adapters philosophy:
5
5
  - prefer *no* environment-side auto-reset (use `AutoResetWrapper` in envelope)
6
6
  - prefer *no* fixed episode time-limits (use `TruncationWrapper` in envelope)
7
7
 
@@ -30,7 +30,7 @@ from kinetix.util.saving import load_from_json_file
30
30
 
31
31
  from envelope import field
32
32
  from envelope import spaces as envelope_spaces
33
- from envelope.compat.gymnax_envelope import _convert_space as _convert_gymnax_space
33
+ from envelope.adapters.gymnax_envelope import _convert_space as _convert_gymnax_space
34
34
  from envelope.environment import Environment, Info, InfoContainer, State
35
35
  from envelope.struct import Container, static_field
36
36
  from envelope.typing import Key, PyTree
@@ -42,7 +42,7 @@ class Environment(ABC, FrozenPyTreeNode):
42
42
 
43
43
  Two distinct lifecycle methods:
44
44
  init(key) - Initialize environment and all state from scratch.
45
- reset(key, state) - Reset the inner environment while preserving
45
+ reset(state, key) - Reset the inner environment while preserving
46
46
  episode-persistent state.
47
47
  """
48
48
 
@@ -51,7 +51,7 @@ class Environment(ABC, FrozenPyTreeNode):
51
51
  """Initialize environment and all state from scratch."""
52
52
  ...
53
53
 
54
- def reset(self, key: Key, state: State) -> tuple[State, Info]:
54
+ def reset(self, state: State, key: Key) -> tuple[State, Info]:
55
55
  """Reset the inner environment while preserving episode-persistent state."""
56
56
  return self.init(key)
57
57
 
@@ -1,6 +1,6 @@
1
1
  from abc import ABC, abstractmethod
2
2
  from functools import cached_property
3
- from typing import cast, override
3
+ from typing import override
4
4
 
5
5
  import jax
6
6
  from jax import numpy as jnp
@@ -54,7 +54,7 @@ class AutoResetWrapper(Wrapper):
54
54
  return state, info.update(final=state.last_final)
55
55
 
56
56
  @override
57
- def reset(self, key: Key, state: WrappedState) -> tuple[WrappedState, Info]:
57
+ def reset(self, state: WrappedState, key: Key) -> tuple[WrappedState, Info]:
58
58
  raise NotImplementedError("Reset is not implemented for AutoResetWrapper")
59
59
 
60
60
  @override
@@ -63,7 +63,7 @@ class AutoResetWrapper(Wrapper):
63
63
  state = state.replace(reset_key=key)
64
64
 
65
65
  inner_state, info = self.env.step(state.inner_state, action)
66
- reset_inner_state, reset_info = self.env.reset(key_reset, inner_state)
66
+ reset_inner_state, reset_info = self.env.reset(inner_state, key_reset)
67
67
 
68
68
  # Select next state and info based on done
69
69
  done = info.terminated | info.truncated
@@ -35,8 +35,8 @@ class ContinuousObservationWrapper(Wrapper):
35
35
  return state, info
36
36
 
37
37
  @override
38
- def reset(self, key: Key, state: State) -> tuple[State, Info]:
39
- state, info = self.env.reset(key, state)
38
+ def reset(self, state: State, key: Key) -> tuple[State, Info]:
39
+ state, info = self.env.reset(state, key)
40
40
  info = info.update(obs=to_float(info.obs))
41
41
  return state, info
42
42
 
@@ -24,8 +24,8 @@ class EpisodeStatisticsWrapper(Wrapper):
24
24
  return state, info.update(stats=state.stats)
25
25
 
26
26
  @override
27
- def reset(self, key: Key, state: State) -> tuple[State, Info]:
28
- inner_state, info = self.env.reset(key, state.inner_state)
27
+ def reset(self, state: State, key: Key) -> tuple[State, Info]:
28
+ inner_state, info = self.env.reset(state.inner_state, key)
29
29
  state = state.replace(inner_state=inner_state)
30
30
  return state, info.update(stats=state.stats)
31
31
 
@@ -37,10 +37,9 @@ class FlattenObservationWrapper(Wrapper):
37
37
  return state, info
38
38
 
39
39
  @override
40
- def reset(self, key: Key, state: State) -> tuple[State, Info]:
41
- state, info = self.env.reset(key, state)
42
- info = info.update(obs=flatten_x(info.obs))
43
- return state, info
40
+ def reset(self, state: State, key: Key) -> tuple[State, Info]:
41
+ next_state, info = self.env.reset(state, key)
42
+ return next_state, info.update(obs=flatten_x(info.obs))
44
43
 
45
44
  @override
46
45
  def step(self, state: State, action: PyTree) -> tuple[State, Info]:
@@ -76,8 +76,8 @@ class ObservationNormalizationWrapper(Wrapper):
76
76
  return self._normalize_and_update(next_state, info)
77
77
 
78
78
  @override
79
- def reset(self, key: Key, state: WrappedState) -> tuple[WrappedState, Info]:
80
- inner_state, info = self.env.reset(key, state.inner_state)
79
+ def reset(self, state: WrappedState, key: Key) -> tuple[WrappedState, Info]:
80
+ inner_state, info = self.env.reset(state.inner_state, key)
81
81
  # Preserve running statistics across resets
82
82
  next_state = self.ObservationNormalizationState(
83
83
  inner_state=inner_state, rmv_state=state.rmv_state
@@ -36,7 +36,7 @@ class PooledInitVmapWrapper(Wrapper):
36
36
  return state, info.update(final=pholder_info)
37
37
 
38
38
  @override
39
- def reset(self, key: Key, state: WrappedState) -> tuple[WrappedState, Info]:
39
+ def reset(self, state: WrappedState, key: Key) -> tuple[WrappedState, Info]:
40
40
  # It's hard to support reset for this wrapper.
41
41
  # We would have to init the state of a pool of unwrapped environments, and then
42
42
  # somehow inject this into the stack of wrapped states. The current data
@@ -48,7 +48,7 @@ class PooledInitVmapWrapper(Wrapper):
48
48
  # episodes before vmapping, we will implement this later.
49
49
  keys = _split_or_keep_key(key, self.batch_size + 1)
50
50
  key_next, keys_pool = keys[0], keys[1:]
51
- inner_state, info = jax.vmap(self.env.reset)(keys_pool, state.inner_state)
51
+ inner_state, info = jax.vmap(self.env.reset)(state.inner_state, keys_pool)
52
52
  state = state.replace(inner_state=inner_state, init_key=key_next)
53
53
  return state, info.update(final=state.last_final)
54
54
 
@@ -69,13 +69,13 @@ class StateInjectionWrapper(Wrapper):
69
69
  return state, info
70
70
 
71
71
  @override
72
- def reset(self, key: Key, state: WrappedState) -> tuple[WrappedState, Info]:
72
+ def reset(self, state: WrappedState, key: Key) -> tuple[WrappedState, Info]:
73
73
  # If reset state is set, use it instead of resetting inner env
74
74
  if state.reset_state is not None and state.reset_obs is not None:
75
75
  inner_state = state.reset_state
76
76
  info = InfoContainer(obs=state.reset_obs, reward=0.0, terminated=False)
77
77
  elif state.reset_state is None and state.reset_obs is None:
78
- inner_state, info = self.env.reset(key, state.inner_state)
78
+ inner_state, info = self.env.reset(state.inner_state, key)
79
79
  else:
80
80
  raise ValueError("State must set both reset_state and reset_obs or neither")
81
81
 
@@ -21,8 +21,8 @@ class TruncationWrapper(Wrapper):
21
21
  return state, info.update(truncated=self.max_steps <= 0)
22
22
 
23
23
  @override
24
- def reset(self, key: Key, state: WrappedState) -> tuple[WrappedState, Info]:
25
- inner_state, info = self.env.reset(key, state.inner_state)
24
+ def reset(self, state: WrappedState, key: Key) -> tuple[WrappedState, Info]:
25
+ inner_state, info = self.env.reset(state.inner_state, key)
26
26
  state = state.replace(inner_state=inner_state, steps=0)
27
27
  return state, info.update(truncated=self.max_steps <= 0)
28
28
 
@@ -40,11 +40,9 @@ class VmapEnvsWrapper(Wrapper):
40
40
  return state, info
41
41
 
42
42
  @override
43
- def reset(self, key: Key, state: PyTree) -> tuple[WrappedState, Info]:
43
+ def reset(self, state: PyTree, key: Key) -> tuple[WrappedState, Info]:
44
44
  keys = self._split_keys(key)
45
- state, info = jax.vmap(lambda e, k, s: e.reset(k, s))(
46
- self.env, keys, state
47
- )
45
+ state, info = jax.vmap(lambda e, s, k: e.reset(s, k))(self.env, state, keys)
48
46
  return state, info
49
47
 
50
48
  @override
@@ -58,7 +56,9 @@ class VmapEnvsWrapper(Wrapper):
58
56
  @property
59
57
  def observation_space(self) -> spaces.Space:
60
58
  env0 = _index_env(self.env, 0, self.batch_size)
61
- return spaces.BatchedSpace(space=env0.observation_space, batch_size=self.batch_size)
59
+ return spaces.BatchedSpace(
60
+ space=env0.observation_space, batch_size=self.batch_size
61
+ )
62
62
 
63
63
  @override
64
64
  @cached_property
@@ -41,9 +41,9 @@ class VmapWrapper(Wrapper):
41
41
  return state, info
42
42
 
43
43
  @override
44
- def reset(self, key: Key, state: PyTree) -> tuple[WrappedState, Info]:
44
+ def reset(self, state: PyTree, key: Key) -> tuple[WrappedState, Info]:
45
45
  keys = _split_or_keep_key(key, self.batch_size)
46
- state, info = jax.vmap(self.env.reset)(keys, state)
46
+ state, info = jax.vmap(self.env.reset)(state, keys)
47
47
  return state, info
48
48
 
49
49
  @override
@@ -29,11 +29,11 @@ class Wrapper(Environment):
29
29
  return self.env.init(key)
30
30
 
31
31
  @override
32
- def reset(self, key: Key, state: State) -> tuple[State, Info]:
33
- return self.env.reset(key, state)
32
+ def reset(self, state: State, key: Key) -> tuple[State, Info]:
33
+ return self.env.reset(state, key)
34
34
 
35
35
  @override
36
- def step(self, state: WrappedState, action: PyTree) -> tuple[WrappedState, Info]:
36
+ def step(self, state: State, action: PyTree) -> tuple[State, Info]:
37
37
  return self.env.step(state, action)
38
38
 
39
39
  @override
@@ -1,6 +1,6 @@
1
1
  import pytest
2
2
 
3
- pytestmark = pytest.mark.compat
3
+ pytestmark = pytest.mark.adapters
4
4
 
5
5
 
6
6
  @pytest.fixture(scope="module")
@@ -1,6 +1,6 @@
1
- """Shared contract helpers for compat wrappers.
1
+ """Shared contract helpers for adapters.
2
2
 
3
- These functions enforce a consistent baseline across all compat wrappers:
3
+ These functions enforce a consistent baseline across all adapters:
4
4
  - reset/step return (state, info) with Info fields present
5
5
  - reward is scalar-ish and finite
6
6
  - action sampling is valid for action_space
@@ -1,4 +1,4 @@
1
- """Tests for envelope.compat.brax_envelope module."""
1
+ """Tests for envelope.adapters.brax_envelope module."""
2
2
 
3
3
  # ruff: noqa: E402
4
4
 
@@ -7,14 +7,14 @@ from copy import deepcopy
7
7
  import jax
8
8
  import pytest
9
9
 
10
- pytestmark = pytest.mark.compat
10
+ pytestmark = pytest.mark.adapters
11
11
 
12
12
  pytest.importorskip("brax")
13
13
 
14
14
  from brax.envs import Wrapper as BraxWrapper
15
15
 
16
- from envelope.compat.brax_envelope import BraxEnvelope
17
- from tests.compat.contract import (
16
+ from envelope.adapters.brax_envelope import BraxEnvelope
17
+ from tests.adapters.contract import (
18
18
  assert_jitted_rollout_contract,
19
19
  assert_reset_step_contract,
20
20
  )
@@ -93,8 +93,8 @@ def test_wrapper_unwrapping():
93
93
 
94
94
  # Create a simple wrapper
95
95
  class SimpleWrapper(BraxWrapper):
96
- def reset(self, rng):
97
- return self.env.reset(rng)
96
+ def init(self, rng):
97
+ return self.env.init(rng)
98
98
 
99
99
  def step(self, state, action):
100
100
  return self.env.step(state, action)
@@ -1,4 +1,4 @@
1
- """Tests for envelope.compat.craftax_envelope module."""
1
+ """Tests for envelope.adapters.craftax_envelope module."""
2
2
 
3
3
  # ruff: noqa: E402
4
4
 
@@ -8,12 +8,12 @@ import jax
8
8
  import jax.numpy as jnp
9
9
  import pytest
10
10
 
11
- pytestmark = pytest.mark.compat
11
+ pytestmark = pytest.mark.adapters
12
12
 
13
13
  pytest.importorskip("craftax")
14
14
 
15
15
  from envelope.spaces import Continuous, Discrete
16
- from tests.compat.contract import (
16
+ from tests.adapters.contract import (
17
17
  assert_jitted_rollout_contract,
18
18
  assert_reset_step_contract,
19
19
  )
@@ -35,7 +35,7 @@ def craftax_env_id(request: pytest.FixtureRequest) -> str:
35
35
 
36
36
  @pytest.fixture(scope="module")
37
37
  def craftax_env(craftax_env_id: str):
38
- from envelope.compat.craftax_envelope import CraftaxEnvelope
38
+ from envelope.adapters.craftax_envelope import CraftaxEnvelope
39
39
 
40
40
  return CraftaxEnvelope.from_name(craftax_env_id)
41
41
 
@@ -114,7 +114,7 @@ class _DummyEnv:
114
114
 
115
115
 
116
116
  def test_from_name_errors_on_auto_reset():
117
- from envelope.compat.craftax_envelope import CraftaxEnvelope
117
+ from envelope.adapters.craftax_envelope import CraftaxEnvelope
118
118
 
119
119
  with pytest.raises(ValueError, match="Cannot override 'auto_reset' directly"):
120
120
  CraftaxEnvelope.from_name("AnyEnv", env_kwargs={"auto_reset": True})