jax-envelope 0.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61) hide show
  1. jax_envelope-0.1.0/.github/workflows/publish.yml +69 -0
  2. jax_envelope-0.1.0/.gitignore +172 -0
  3. jax_envelope-0.1.0/LICENSE +21 -0
  4. jax_envelope-0.1.0/PKG-INFO +87 -0
  5. jax_envelope-0.1.0/README.md +61 -0
  6. jax_envelope-0.1.0/pyproject.toml +77 -0
  7. jax_envelope-0.1.0/src/envelope/__init__.py +0 -0
  8. jax_envelope-0.1.0/src/envelope/compat/__init__.py +97 -0
  9. jax_envelope-0.1.0/src/envelope/compat/brax_envelope.py +98 -0
  10. jax_envelope-0.1.0/src/envelope/compat/craftax_envelope.py +86 -0
  11. jax_envelope-0.1.0/src/envelope/compat/gymnax_envelope.py +91 -0
  12. jax_envelope-0.1.0/src/envelope/compat/jumanji_envelope.py +127 -0
  13. jax_envelope-0.1.0/src/envelope/compat/kinetix_envelope.py +194 -0
  14. jax_envelope-0.1.0/src/envelope/compat/mujoco_playground_envelope.py +101 -0
  15. jax_envelope-0.1.0/src/envelope/compat/navix_envelope.py +86 -0
  16. jax_envelope-0.1.0/src/envelope/environment.py +64 -0
  17. jax_envelope-0.1.0/src/envelope/spaces.py +205 -0
  18. jax_envelope-0.1.0/src/envelope/struct.py +148 -0
  19. jax_envelope-0.1.0/src/envelope/typing.py +23 -0
  20. jax_envelope-0.1.0/src/envelope/wrappers/autoreset_wrapper.py +36 -0
  21. jax_envelope-0.1.0/src/envelope/wrappers/episode_statistics_wrapper.py +47 -0
  22. jax_envelope-0.1.0/src/envelope/wrappers/normalization.py +56 -0
  23. jax_envelope-0.1.0/src/envelope/wrappers/observation_normalization_wrapper.py +114 -0
  24. jax_envelope-0.1.0/src/envelope/wrappers/state_injection_wrapper.py +91 -0
  25. jax_envelope-0.1.0/src/envelope/wrappers/timestep_wrapper.py +22 -0
  26. jax_envelope-0.1.0/src/envelope/wrappers/truncation_wrapper.py +31 -0
  27. jax_envelope-0.1.0/src/envelope/wrappers/vmap_envs_wrapper.py +77 -0
  28. jax_envelope-0.1.0/src/envelope/wrappers/vmap_wrapper.py +51 -0
  29. jax_envelope-0.1.0/src/envelope/wrappers/wrapper.py +57 -0
  30. jax_envelope-0.1.0/tests/__init__.py +1 -0
  31. jax_envelope-0.1.0/tests/compat/__init__.py +0 -0
  32. jax_envelope-0.1.0/tests/compat/conftest.py +36 -0
  33. jax_envelope-0.1.0/tests/compat/contract.py +80 -0
  34. jax_envelope-0.1.0/tests/compat/test_brax_compat.py +131 -0
  35. jax_envelope-0.1.0/tests/compat/test_craftax_compat.py +118 -0
  36. jax_envelope-0.1.0/tests/compat/test_create.py +183 -0
  37. jax_envelope-0.1.0/tests/compat/test_create_integration.py +126 -0
  38. jax_envelope-0.1.0/tests/compat/test_gymnax_compat.py +220 -0
  39. jax_envelope-0.1.0/tests/compat/test_jumanji_compat.py +205 -0
  40. jax_envelope-0.1.0/tests/compat/test_kinetix_compat.py +225 -0
  41. jax_envelope-0.1.0/tests/compat/test_mujoco_playground_compat.py +211 -0
  42. jax_envelope-0.1.0/tests/compat/test_navix_compat.py +266 -0
  43. jax_envelope-0.1.0/tests/spaces/__init__.py +1 -0
  44. jax_envelope-0.1.0/tests/spaces/test_batched_space.py +263 -0
  45. jax_envelope-0.1.0/tests/spaces/test_continuous.py +281 -0
  46. jax_envelope-0.1.0/tests/spaces/test_discrete.py +299 -0
  47. jax_envelope-0.1.0/tests/spaces/test_pytree_space.py +477 -0
  48. jax_envelope-0.1.0/tests/spaces/test_serialization.py +249 -0
  49. jax_envelope-0.1.0/tests/test_container.py +236 -0
  50. jax_envelope-0.1.0/tests/test_struct.py +700 -0
  51. jax_envelope-0.1.0/tests/wrappers/__init__.py +0 -0
  52. jax_envelope-0.1.0/tests/wrappers/helpers.py +703 -0
  53. jax_envelope-0.1.0/tests/wrappers/test_autoreset_wrapper.py +549 -0
  54. jax_envelope-0.1.0/tests/wrappers/test_environment_wrapper.py +478 -0
  55. jax_envelope-0.1.0/tests/wrappers/test_normalization.py +128 -0
  56. jax_envelope-0.1.0/tests/wrappers/test_observation_normalization_wrapper.py +234 -0
  57. jax_envelope-0.1.0/tests/wrappers/test_state_injection_wrapper.py +363 -0
  58. jax_envelope-0.1.0/tests/wrappers/test_truncation_wrapper.py +146 -0
  59. jax_envelope-0.1.0/tests/wrappers/test_vmap_envs_wrapper.py +147 -0
  60. jax_envelope-0.1.0/tests/wrappers/test_vmap_wrapper.py +254 -0
  61. jax_envelope-0.1.0/uv.lock +3084 -0
@@ -0,0 +1,69 @@
1
name: Publish to PyPI

on:
  release:
    # Only run when a release is first published. Also triggering on "edited"
    # would rebuild and re-upload an already-published version, which PyPI
    # rejects (files for an existing version cannot be re-uploaded).
    types: [published]

jobs:
  test:
    name: Test Python ${{ matrix.python-version }}
    runs-on: ubuntu-latest
    strategy:
      matrix:
        python-version: ["3.12", "3.13"]
    env:
      # Pin every uv command in this job (sync, run) to the matrix interpreter
      # instead of whichever Python uv would otherwise discover first.
      UV_PYTHON: ${{ matrix.python-version }}
    steps:
      - uses: actions/checkout@v4

      - name: Install uv
        uses: astral-sh/setup-uv@v4
        with:
          enable-cache: true

      - name: Set up Python ${{ matrix.python-version }}
        run: uv python install ${{ matrix.python-version }}

      - name: Install dependencies
        run: |
          test -f uv.lock
          uv sync --group dev --locked

      - name: Run tests
        run: uv run pytest -m "not compat"

  build:
    name: Build distribution
    needs: test
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4

      - name: Install uv
        uses: astral-sh/setup-uv@v4

      - name: Build package
        run: uv build

      - name: Store distribution packages
        uses: actions/upload-artifact@v4
        with:
          name: python-package-distributions
          path: dist/

  publish:
    name: Publish to PyPI
    needs: build
    runs-on: ubuntu-latest
    environment:
      name: pypi
      url: https://pypi.org/p/jax-envelope
    permissions:
      # OIDC token for PyPI trusted publishing (no API token is configured).
      id-token: write
    steps:
      - name: Download distribution packages
        uses: actions/download-artifact@v4
        with:
          name: python-package-distributions
          path: dist/

      - name: Publish to PyPI
        uses: pypa/gh-action-pypi-publish@release/v1
@@ -0,0 +1,172 @@
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py,cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ # For a library or package, you might want to ignore these files since the code is
87
+ # intended to run in multiple environments; otherwise, check them in:
88
+ # .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # poetry
98
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
100
+ # commonly ignored for libraries.
101
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102
+ #poetry.lock
103
+
104
+ # pdm
105
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106
+ #pdm.lock
107
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108
+ # in version control.
109
+ # https://pdm.fming.dev/#use-with-ide
110
+ .pdm.toml
111
+ .pdm-python
112
+
113
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
114
+ __pypackages__/
115
+
116
+ # Celery stuff
117
+ celerybeat-schedule
118
+ celerybeat.pid
119
+
120
+ # SageMath parsed files
121
+ *.sage.py
122
+
123
+ # Environments
124
+ .env
125
+ .venv
126
+ env/
127
+ venv/
128
+ ENV/
129
+ env.bak/
130
+ venv.bak/
131
+
132
+ # Spyder project settings
133
+ .spyderproject
134
+ .spyproject
135
+
136
+ # Rope project settings
137
+ .ropeproject
138
+
139
+ # mkdocs documentation
140
+ /site
141
+
142
+ # mypy
143
+ .mypy_cache/
144
+ .dmypy.json
145
+ dmypy.json
146
+
147
+ # Pyre type checker
148
+ .pyre/
149
+
150
+ # pytype static type analyzer
151
+ .pytype/
152
+
153
+ # Cython debug symbols
154
+ cython_debug/
155
+
156
+ # PyCharm
157
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
158
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
159
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
160
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
161
+ #.idea/
162
+
163
+ # Weights and Biases
164
+ wandb/
165
+
166
+ # ruff
167
+ .ruff_cache/
168
+
169
+ # Cursor
170
+ .cursor
171
+ .cursorignore
172
+ AGENTS.md
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2025 Jarek Liesen
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
@@ -0,0 +1,87 @@
1
+ Metadata-Version: 2.4
2
+ Name: jax-envelope
3
+ Version: 0.1.0
4
+ Summary: A JAX-native environment interface with powerful wrappers and adapters for popular RL environment suites
5
+ Project-URL: Homepage, https://github.com/keraJLi/envelope
6
+ Project-URL: Repository, https://github.com/keraJLi/envelope
7
+ Project-URL: Documentation, https://github.com/keraJLi/envelope#readme
8
+ Project-URL: Issues, https://github.com/keraJLi/envelope/issues
9
+ Project-URL: Changelog, https://github.com/keraJLi/envelope/releases
10
+ Author-email: Jarek Liesen <jarek.liesen@reuben.ox.ac.uk>
11
+ License: MIT
12
+ License-File: LICENSE
13
+ Keywords: deep-learning,environments,gymnasium,hardware-acceleration,jax,machine-learning,reinforcement-learning,vectorization
14
+ Classifier: Development Status :: 4 - Beta
15
+ Classifier: Intended Audience :: Developers
16
+ Classifier: Intended Audience :: Science/Research
17
+ Classifier: License :: OSI Approved :: MIT License
18
+ Classifier: Operating System :: OS Independent
19
+ Classifier: Programming Language :: Python :: 3
20
+ Classifier: Programming Language :: Python :: 3.12
21
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
22
+ Classifier: Typing :: Typed
23
+ Requires-Python: >=3.12
24
+ Requires-Dist: jax>=0.5.0
25
+ Description-Content-Type: text/markdown
26
+
27
+ # 💌 Envelope: a JAX-native environment interface
28
+ ```python
29
+ # Create environments from JAX-native suites you have installed, ...
30
+ env = envelope.create("gymnax::CartPole-v1")
31
+
32
+ # ... interact with the environments using a simple interface, ...
33
+ state, info = env.reset(key)
34
+ states, infos = jax.lax.scan(env.step, state, actions)
35
+ plt.plot(infos.reward.cumsum())
36
+
37
+ # ... and enjoy a powerful ecosystem of wrappers.
38
+ env = envelope.wrappers.AutoResetWrapper(env)
39
+ env = envelope.wrappers.VmapWrapper(env)
40
+ env = envelope.wrappers.ObservationNormalizationWrapper(env)
41
+ ```
42
+
43
+ ## 🌍 Simple, expressive interaction!
44
+ * **Environments are pytrees**. Squish them through JAX transformations and trace their parameters.
45
+ * **Idiomatic jax-y interface** of `reset(key: Key) -> State, Info` and `step(state: State, action: PyTree) -> State, Info`. You can directly `jax.lax.scan` over a `step(...)`!
46
+ * **Spaces are super simple**. No `Tuple`, `Dict` nonsense! There are two spaces: `Continuous` and `Discrete`, which you can compose into a `PyTreeSpace`.
47
+ * **Explicit episode truncation** supports correctly handling bootstrapping for value-function targets.
48
+ * **No auto-reset** by default. Resetting every step can be expensive!
49
+
50
+ ## 💪 Powerful, composable wrappers!
51
+ * **Carry state across episodes** to track running statistics, for example to normalize observations.
52
+ * **Composable wrappers** can be stacked in any order. For example, `ObservationNormalizationWrapper` before vs. after `VmapWrapper` gives per-env vs. global normalization.
53
+ <!-- TODO: Add auto-reset behavior (including state injection) and optimistic resets once I implement them. -->
54
+
55
+ ## 🔌 Adapters for existing suites
56
+ | 📦 | # 🤖 | # 🌍 |
57
+ |------|------|------|
58
+ | [gymnax](https://github.com/RobertTLange/gymnax) | 🕺 | 24 |
59
+ | [brax](https://github.com/google/brax) | 🕺 | 12 |
60
+ | [jumanji](https://github.com/instadeepai/jumanji) | 🕺 / 👯 | 25 / 1 |
61
+ | [kinetix](https://github.com/flairox/kinetix) | 🕺 | 74 |
62
+ | [craftax](https://github.com/MichaelTMatthews/craftax) | 🕺 | 4 |
63
+ | [mujoco_playground](https://github.com/google-deepmind/mujoco_playground) | 🕺 | 54 |
64
+ | | | |
65
+ | Total | 🕺 / 👯 | 193 / 1 |
66
+
67
+ ```python
68
+ envelope.create("📦::🌍")
69
+ ```
70
+ lets you create environments from any of the above!
71
+
72
+ ## 📝 Testing
73
+ - **Default (no optional compat deps required)**: `uv run pytest -m "not compat"`
74
+ - **Compat suite (requires full compat dependency group)**:
75
+ - `uv sync --group compat`
76
+ - `uv run pytest -m compat`
77
+ - If any compat dependency is missing/broken, the run will fail fast with an error telling you what to install.
78
+
79
+ ## 🏗️ Installation
80
+ ```bash
81
+ pip install jax-envelope
82
+ ```
83
+
84
+ ## 💞 Related projects
85
+ * [stoax](https://github.com/EdanToledo/Stoa) is a very similar project that provides adapters and wrappers for the jumanji-like interface.
86
+ * Check out all the great suites we have adapters for! [gymnax](https://github.com/RobertTLange/gymnax), [brax](https://github.com/google/brax), [jumanji](https://github.com/instadeepai/jumanji), [kinetix](https://github.com/flairox/kinetix), [craftax](https://github.com/MichaelTMatthews/craftax), [mujoco_playground](https://github.com/google-deepmind/mujoco_playground).
87
+ * We will be adding support for [jaxmarl](https://github.com/flairox/jaxmarl) and [pgx](https://github.com/sotetsuk/pgx) in the future, as soon as we have figured out the best-ever MARL interface for JAX!
@@ -0,0 +1,61 @@
1
+ # 💌 Envelope: a JAX-native environment interface
2
+ ```python
3
+ # Create environments from JAX-native suites you have installed, ...
4
+ env = envelope.create("gymnax::CartPole-v1")
5
+
6
+ # ... interact with the environments using a simple interface, ...
7
+ state, info = env.reset(key)
8
+ states, infos = jax.lax.scan(env.step, state, actions)
9
+ plt.plot(infos.reward.cumsum())
10
+
11
+ # ... and enjoy a powerful ecosystem of wrappers.
12
+ env = envelope.wrappers.AutoResetWrapper(env)
13
+ env = envelope.wrappers.VmapWrapper(env)
14
+ env = envelope.wrappers.ObservationNormalizationWrapper(env)
15
+ ```
16
+
17
+ ## 🌍 Simple, expressive interaction!
18
+ * **Environments are pytrees**. Squish them through JAX transformations and trace their parameters.
19
+ * **Idiomatic jax-y interface** of `reset(key: Key) -> State, Info` and `step(state: State, action: PyTree) -> State, Info`. You can directly `jax.lax.scan` over a `step(...)`!
20
+ * **Spaces are super simple**. No `Tuple`, `Dict` nonsense! There are two spaces: `Continuous` and `Discrete`, which you can compose into a `PyTreeSpace`.
21
+ * **Explicit episode truncation** supports correctly handling bootstrapping for value-function targets.
22
+ * **No auto-reset** by default. Resetting every step can be expensive!
23
+
24
+ ## 💪 Powerful, composable wrappers!
25
+ * **Carry state across episodes** to track running statistics, for example to normalize observations.
26
+ * **Composable wrappers** can be stacked in any order. For example, `ObservationNormalizationWrapper` before vs. after `VmapWrapper` gives per-env vs. global normalization.
27
+ <!-- TODO: Add auto-reset behavior (including state injection) and optimistic resets once I implement them. -->
28
+
29
+ ## 🔌 Adapters for existing suites
30
+ | 📦 | # 🤖 | # 🌍 |
31
+ |------|------|------|
32
+ | [gymnax](https://github.com/RobertTLange/gymnax) | 🕺 | 24 |
33
+ | [brax](https://github.com/google/brax) | 🕺 | 12 |
34
+ | [jumanji](https://github.com/instadeepai/jumanji) | 🕺 / 👯 | 25 / 1 |
35
+ | [kinetix](https://github.com/flairox/kinetix) | 🕺 | 74 |
36
+ | [craftax](https://github.com/MichaelTMatthews/craftax) | 🕺 | 4 |
37
+ | [mujoco_playground](https://github.com/google-deepmind/mujoco_playground) | 🕺 | 54 |
38
+ | | | |
39
+ | Total | 🕺 / 👯 | 193 / 1 |
40
+
41
+ ```python
42
+ envelope.create("📦::🌍")
43
+ ```
44
+ lets you create environments from any of the above!
45
+
46
+ ## 📝 Testing
47
+ - **Default (no optional compat deps required)**: `uv run pytest -m "not compat"`
48
+ - **Compat suite (requires full compat dependency group)**:
49
+ - `uv sync --group compat`
50
+ - `uv run pytest -m compat`
51
+ - If any compat dependency is missing/broken, the run will fail fast with an error telling you what to install.
52
+
53
+ ## 🏗️ Installation
54
+ ```bash
55
+ pip install jax-envelope
56
+ ```
57
+
58
+ ## 💞 Related projects
59
+ * [stoax](https://github.com/EdanToledo/Stoa) is a very similar project that provides adapters and wrappers for the jumanji-like interface.
60
+ * Check out all the great suites we have adapters for! [gymnax](https://github.com/RobertTLange/gymnax), [brax](https://github.com/google/brax), [jumanji](https://github.com/instadeepai/jumanji), [kinetix](https://github.com/flairox/kinetix), [craftax](https://github.com/MichaelTMatthews/craftax), [mujoco_playground](https://github.com/google-deepmind/mujoco_playground).
61
+ * We will be adding support for [jaxmarl](https://github.com/flairox/jaxmarl) and [pgx](https://github.com/sotetsuk/pgx) in the future, as soon as we have figured out the best-ever MARL interface for JAX!
@@ -0,0 +1,77 @@
1
+ [project]
2
+ name = "jax-envelope"
3
+ version = "0.1.0"
4
+ description = "A JAX-native environment interface with powerful wrappers and adapters for popular RL environment suites"
5
+ readme = "README.md"
6
+ requires-python = ">=3.12"
7
+ authors = [
8
+ {name = "Jarek Liesen", email = "jarek.liesen@reuben.ox.ac.uk"},
9
+ ]
10
+ dependencies = [
11
+ "jax>=0.5.0",
12
+ ]
13
+
14
+ license = {text = "MIT"}
15
+ keywords = [
16
+ "machine-learning",
17
+ "reinforcement-learning",
18
+ "jax",
19
+ "vectorization",
20
+ "hardware-acceleration",
21
+ "gymnasium",
22
+ "environments",
23
+ "deep-learning",
24
+ ]
25
+ classifiers = [
26
+ "Development Status :: 4 - Beta",
27
+ "Intended Audience :: Developers",
28
+ "Intended Audience :: Science/Research",
29
+ "License :: OSI Approved :: MIT License",
30
+ "Operating System :: OS Independent",
31
+ "Programming Language :: Python :: 3",
32
+ "Programming Language :: Python :: 3.12",
33
+ "Topic :: Scientific/Engineering :: Artificial Intelligence",
34
+ "Typing :: Typed",
35
+ ]
36
+
37
+ [project.urls]
38
+ Homepage = "https://github.com/keraJLi/envelope"
39
+ Repository = "https://github.com/keraJLi/envelope"
40
+ Documentation = "https://github.com/keraJLi/envelope#readme"
41
+ Issues = "https://github.com/keraJLi/envelope/issues"
42
+ Changelog = "https://github.com/keraJLi/envelope/releases"
43
+
44
+ [build-system]
45
+ requires = ["hatchling"]
46
+ build-backend = "hatchling.build"
47
+
48
+ [tool.hatch.metadata]
49
+ allow-direct-references = true
50
+
51
+ [tool.hatch.build.targets.wheel]
52
+ packages = ["src/envelope"]
53
+
54
+ [tool.uv.sources]
55
+ gymnax = { git = "https://github.com/RobertTLange/gymnax" }
56
+
57
+ [dependency-groups]
58
+ compat = [
59
+ "gymnax @ git+https://github.com/RobertTLange/gymnax@main",
60
+ "brax>=0.13.0",
61
+ "craftax>=1.4.3",
62
+ "navix>=0.7.0",
63
+ "jumanji>=1.0.1",
64
+ "kinetix-env>=2.0.0",
65
+ "playground>=0.1.0",
66
+ ]
67
+ dev = [
68
+ "hypothesis>=6.148.1",
69
+ "pytest>=9.0.0",
70
+ "pytest-cov>=7.0.0",
71
+ "ruff>=0.14.2",
72
+ ]
73
+
74
+ [tool.pytest.ini_options]
75
+ markers = [
76
+ "compat: tests requiring optional compat dependencies",
77
+ ]
File without changes
@@ -0,0 +1,97 @@
1
+ """Compatibility wrappers for various RL environment libraries."""
2
+
3
+ from typing import Any, Protocol, Self
4
+
5
+ # Lazy imports to avoid requiring all dependencies at once
6
+ _env_module_map = {
7
+ "gymnax": ("envelope.compat.gymnax_envelope", "GymnaxEnvelope"),
8
+ "brax": ("envelope.compat.brax_envelope", "BraxEnvelope"),
9
+ "navix": ("envelope.compat.navix_envelope", "NavixEnvelope"),
10
+ "jumanji": ("envelope.compat.jumanji_envelope", "JumanjiEnvelope"),
11
+ "kinetix": ("envelope.compat.kinetix_envelope", "KinetixEnvelope"),
12
+ "craftax": ("envelope.compat.craftax_envelope", "CraftaxEnvelope"),
13
+ "mujoco_playground": (
14
+ "envelope.compat.mujoco_playground_envelope",
15
+ "MujocoPlaygroundEnvelope",
16
+ ),
17
+ }
18
+
19
+
20
+ class HasFromNameInit(Protocol):
21
+ @classmethod
22
+ def from_name(
23
+ cls, env_name: str, env_kwargs: dict[str, Any] | None = None, **kwargs
24
+ ) -> Self: ...
25
+
26
+ """Creates an environment from a name and keyword arguments. Unless otherwise noted,
27
+ the created environment will have it's default parameters, with truncation and auto
28
+ reset disabled.
29
+
30
+ Args:
31
+ env_name: Environment name
32
+ env_kwargs: Keyword arguments passed to the environment constructor
33
+ **kwargs: Additional keyword arguments passed to the environment wrapper
34
+ """
35
+
36
+
37
+ def create(env_name: str, env_kwargs: dict[str, Any] | None = None, **kwargs):
38
+ """Create an environment from a prefixed environment ID.
39
+
40
+ Args:
41
+ env_name: Environment ID in the format "suite::env_name" (e.g., "brax::ant")
42
+ env_kwargs: Keyword arguments passed to the suite's environment constructor
43
+ **kwargs: Additional keyword arguments passed to the environment wrapper
44
+
45
+ Returns:
46
+ An instance of the wrapped environment
47
+
48
+ Examples:
49
+ >>> env = create("jumanji::snake")
50
+ >>> env = create("brax::ant", env_kwargs={"backend": "spring"})
51
+ >>> env = create("gymnax::CartPole-v1", env_params=...)
52
+ """
53
+ original_env_id = env_name
54
+ if "::" not in env_name:
55
+ raise ValueError(
56
+ f"Environment ID must be in format 'suite::env_name', got: {original_env_id}"
57
+ )
58
+
59
+ suite, env_name = env_name.split("::", 1)
60
+ if not suite or not env_name:
61
+ raise ValueError(
62
+ f"Environment ID must be in format 'suite::env_name', got: {original_env_id}"
63
+ )
64
+
65
+ if suite not in _env_module_map:
66
+ raise ValueError(
67
+ f"Unknown environment suite: {suite}. "
68
+ f"Available suites: {list(_env_module_map.keys())}"
69
+ )
70
+
71
+ # Lazy import the wrapper class
72
+ module_name, class_name = _env_module_map[suite]
73
+ try:
74
+ import importlib
75
+
76
+ module = importlib.import_module(module_name)
77
+ env_class: HasFromNameInit = getattr(module, class_name)
78
+ except ImportError as e:
79
+ raise ImportError(
80
+ f"Failed to import {suite} wrapper. "
81
+ f"Make sure you have installed the '{suite}' dependencies. "
82
+ f"Original error: {e}"
83
+ ) from e
84
+
85
+ env = env_class.from_name(env_name, env_kwargs=env_kwargs, **kwargs)
86
+
87
+ # Wrap with TruncationWrapper using adapter's default
88
+ default_max_steps = getattr(env, "default_max_steps", None)
89
+ if default_max_steps is not None:
90
+ from envelope.wrappers.truncation_wrapper import TruncationWrapper
91
+
92
+ env = TruncationWrapper(env=env, max_steps=int(default_max_steps))
93
+
94
+ return env
95
+
96
+
97
+ __all__ = ["create"]
@@ -0,0 +1,98 @@
1
+ import dataclasses
2
+ import warnings
3
+ from copy import copy
4
+ from functools import cached_property
5
+ from typing import Any, override
6
+
7
+ from brax.envs import Env as BraxEnv
8
+ from brax.envs import Wrapper as BraxWrapper
9
+ from brax.envs import create as brax_create
10
+ from jax import numpy as jnp
11
+
12
+ from envelope import spaces
13
+ from envelope.environment import Environment, Info, InfoContainer, State
14
+ from envelope.struct import static_field
15
+ from envelope.typing import Key, PyTree
16
+
17
+ # Default episode_length in brax.envs.create()
18
+ _BRAX_DEFAULT_EPISODE_LENGTH = 1000
19
+
20
+
21
class BraxEnvelope(Environment):
    """Wrapper to convert a Brax environment to an envelope environment.

    Brax-side wrappers are stripped on construction (see ``__post_init__``);
    episode truncation and auto-reset are expected to come from envelope's own
    wrappers instead.
    """

    # Declared via ``static_field()``; presumably kept out of the traced pytree
    # leaves — TODO confirm against envelope.struct.
    brax_env: BraxEnv = static_field()

    @classmethod
    def from_name(
        cls, env_name: str, env_kwargs: dict[str, Any] | None = None, **kwargs
    ) -> "BraxEnvelope":
        """Create an adapter for the Brax environment registered as ``env_name``.

        Args:
            env_name: Name registered with ``brax.envs`` (e.g. ``"ant"``).
            env_kwargs: Extra keyword arguments for ``brax.envs.create``. Must not
                contain ``episode_length`` or ``auto_reset``. The caller's
                mapping is never modified.
            **kwargs: Additional keyword arguments forwarded to the adapter
                constructor.

        Returns:
            A ``BraxEnvelope`` around the created Brax environment.

        Raises:
            ValueError: If ``env_kwargs`` sets ``episode_length`` or ``auto_reset``.
        """
        # Work on a copy: writing into the caller's dict would leak the forced
        # options back to them and make a second call with the same dict fail
        # the checks below.
        env_kwargs = dict(env_kwargs or {})
        if "episode_length" in env_kwargs:
            raise ValueError(
                "Cannot override 'episode_length' directly. "
                "Use TruncationWrapper for episode length control."
            )
        if "auto_reset" in env_kwargs:
            raise ValueError(
                "Cannot override 'auto_reset' directly. "
                "Use AutoResetWrapper for auto-reset behavior."
            )

        # episode_length=None makes brax.envs.create skip its EpisodeWrapper, so
        # __post_init__ has nothing to strip and emits no wrapping warning.
        # (Relies on brax's `if episode_length is not None` check — verify when
        # bumping brax.)
        env = brax_create(
            env_name, episode_length=None, auto_reset=False, **env_kwargs
        )
        return cls(brax_env=env, **kwargs)

    @property
    def default_max_steps(self) -> int:
        """Brax's default episode length, used to add a ``TruncationWrapper``."""
        return _BRAX_DEFAULT_EPISODE_LENGTH

    def __post_init__(self) -> None:
        """Replace a wrapped Brax environment by its innermost ``unwrapped`` env."""
        if isinstance(self.brax_env, BraxWrapper):
            warnings.warn(
                "Environment wrapping should be handled by envelope. "
                "Unwrapping brax environment before converting..."
            )
            # The dataclass may be frozen, so bypass its __setattr__.
            object.__setattr__(self, "brax_env", self.brax_env.unwrapped)

    @override
    def reset(self, key: Key) -> tuple[State, Info]:
        """Reset the Brax env; its ``State`` is returned as the envelope state.

        The info starts with ``reward=0.0`` and ``terminated=False`` and is then
        updated with every field of the Brax state.
        """
        brax_state = self.brax_env.reset(key)
        info = InfoContainer(obs=brax_state.obs, reward=0.0, terminated=False)
        # dataclasses.asdict recursively converts nested dataclasses to dicts
        # and deep-copies other values.
        info = info.update(**dataclasses.asdict(brax_state))
        return brax_state, info

    @override
    def step(self, state: State, action: PyTree) -> tuple[State, Info]:
        """Step the Brax env; Brax's ``done`` is reported as ``terminated``."""
        brax_state = self.brax_env.step(state, action)
        info = InfoContainer(
            obs=brax_state.obs, reward=brax_state.reward, terminated=brax_state.done
        )
        info = info.update(**dataclasses.asdict(brax_state))
        return brax_state, info

    @override
    @cached_property
    def action_space(self) -> spaces.Space:
        # All brax environments have action limit of -1 to 1
        return spaces.Continuous.from_shape(
            low=-1.0, high=1.0, shape=(self.brax_env.action_size,)
        )

    @override
    @cached_property
    def observation_space(self) -> spaces.Space:
        # All brax environments have observation limit of -inf to inf
        return spaces.Continuous.from_shape(
            low=-jnp.inf, high=jnp.inf, shape=(self.brax_env.observation_size,)
        )

    def __deepcopy__(self, memo):
        """Return a shallow copy (with a RuntimeWarning) instead of a deep copy."""
        warnings.warn(
            f"Trying to deepcopy {type(self).__name__}, which contains a brax env. "
            "Brax envs throw an error when deepcopying, so a shallow copy is returned.",
            category=RuntimeWarning,
            stacklevel=2,
        )
        return copy(self)