jax-envelope 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jax_envelope-0.1.0/.github/workflows/publish.yml +69 -0
- jax_envelope-0.1.0/.gitignore +172 -0
- jax_envelope-0.1.0/LICENSE +21 -0
- jax_envelope-0.1.0/PKG-INFO +87 -0
- jax_envelope-0.1.0/README.md +61 -0
- jax_envelope-0.1.0/pyproject.toml +77 -0
- jax_envelope-0.1.0/src/envelope/__init__.py +0 -0
- jax_envelope-0.1.0/src/envelope/compat/__init__.py +97 -0
- jax_envelope-0.1.0/src/envelope/compat/brax_envelope.py +98 -0
- jax_envelope-0.1.0/src/envelope/compat/craftax_envelope.py +86 -0
- jax_envelope-0.1.0/src/envelope/compat/gymnax_envelope.py +91 -0
- jax_envelope-0.1.0/src/envelope/compat/jumanji_envelope.py +127 -0
- jax_envelope-0.1.0/src/envelope/compat/kinetix_envelope.py +194 -0
- jax_envelope-0.1.0/src/envelope/compat/mujoco_playground_envelope.py +101 -0
- jax_envelope-0.1.0/src/envelope/compat/navix_envelope.py +86 -0
- jax_envelope-0.1.0/src/envelope/environment.py +64 -0
- jax_envelope-0.1.0/src/envelope/spaces.py +205 -0
- jax_envelope-0.1.0/src/envelope/struct.py +148 -0
- jax_envelope-0.1.0/src/envelope/typing.py +23 -0
- jax_envelope-0.1.0/src/envelope/wrappers/autoreset_wrapper.py +36 -0
- jax_envelope-0.1.0/src/envelope/wrappers/episode_statistics_wrapper.py +47 -0
- jax_envelope-0.1.0/src/envelope/wrappers/normalization.py +56 -0
- jax_envelope-0.1.0/src/envelope/wrappers/observation_normalization_wrapper.py +114 -0
- jax_envelope-0.1.0/src/envelope/wrappers/state_injection_wrapper.py +91 -0
- jax_envelope-0.1.0/src/envelope/wrappers/timestep_wrapper.py +22 -0
- jax_envelope-0.1.0/src/envelope/wrappers/truncation_wrapper.py +31 -0
- jax_envelope-0.1.0/src/envelope/wrappers/vmap_envs_wrapper.py +77 -0
- jax_envelope-0.1.0/src/envelope/wrappers/vmap_wrapper.py +51 -0
- jax_envelope-0.1.0/src/envelope/wrappers/wrapper.py +57 -0
- jax_envelope-0.1.0/tests/__init__.py +1 -0
- jax_envelope-0.1.0/tests/compat/__init__.py +0 -0
- jax_envelope-0.1.0/tests/compat/conftest.py +36 -0
- jax_envelope-0.1.0/tests/compat/contract.py +80 -0
- jax_envelope-0.1.0/tests/compat/test_brax_compat.py +131 -0
- jax_envelope-0.1.0/tests/compat/test_craftax_compat.py +118 -0
- jax_envelope-0.1.0/tests/compat/test_create.py +183 -0
- jax_envelope-0.1.0/tests/compat/test_create_integration.py +126 -0
- jax_envelope-0.1.0/tests/compat/test_gymnax_compat.py +220 -0
- jax_envelope-0.1.0/tests/compat/test_jumanji_compat.py +205 -0
- jax_envelope-0.1.0/tests/compat/test_kinetix_compat.py +225 -0
- jax_envelope-0.1.0/tests/compat/test_mujoco_playground_compat.py +211 -0
- jax_envelope-0.1.0/tests/compat/test_navix_compat.py +266 -0
- jax_envelope-0.1.0/tests/spaces/__init__.py +1 -0
- jax_envelope-0.1.0/tests/spaces/test_batched_space.py +263 -0
- jax_envelope-0.1.0/tests/spaces/test_continuous.py +281 -0
- jax_envelope-0.1.0/tests/spaces/test_discrete.py +299 -0
- jax_envelope-0.1.0/tests/spaces/test_pytree_space.py +477 -0
- jax_envelope-0.1.0/tests/spaces/test_serialization.py +249 -0
- jax_envelope-0.1.0/tests/test_container.py +236 -0
- jax_envelope-0.1.0/tests/test_struct.py +700 -0
- jax_envelope-0.1.0/tests/wrappers/__init__.py +0 -0
- jax_envelope-0.1.0/tests/wrappers/helpers.py +703 -0
- jax_envelope-0.1.0/tests/wrappers/test_autoreset_wrapper.py +549 -0
- jax_envelope-0.1.0/tests/wrappers/test_environment_wrapper.py +478 -0
- jax_envelope-0.1.0/tests/wrappers/test_normalization.py +128 -0
- jax_envelope-0.1.0/tests/wrappers/test_observation_normalization_wrapper.py +234 -0
- jax_envelope-0.1.0/tests/wrappers/test_state_injection_wrapper.py +363 -0
- jax_envelope-0.1.0/tests/wrappers/test_truncation_wrapper.py +146 -0
- jax_envelope-0.1.0/tests/wrappers/test_vmap_envs_wrapper.py +147 -0
- jax_envelope-0.1.0/tests/wrappers/test_vmap_wrapper.py +254 -0
- jax_envelope-0.1.0/uv.lock +3084 -0
|
@@ -0,0 +1,69 @@
|
|
|
1
|
+
name: Publish to PyPI
|
|
2
|
+
|
|
3
|
+
on:
|
|
4
|
+
release:
|
|
5
|
+
types: [published, edited]
|
|
6
|
+
|
|
7
|
+
jobs:
|
|
8
|
+
test:
|
|
9
|
+
name: Test Python ${{ matrix.python-version }}
|
|
10
|
+
runs-on: ubuntu-latest
|
|
11
|
+
strategy:
|
|
12
|
+
matrix:
|
|
13
|
+
python-version: ["3.12", "3.13"]
|
|
14
|
+
steps:
|
|
15
|
+
- uses: actions/checkout@v4
|
|
16
|
+
|
|
17
|
+
- name: Install uv
|
|
18
|
+
uses: astral-sh/setup-uv@v4
|
|
19
|
+
with:
|
|
20
|
+
enable-cache: true
|
|
21
|
+
|
|
22
|
+
- name: Set up Python ${{ matrix.python-version }}
|
|
23
|
+
run: uv python install ${{ matrix.python-version }}
|
|
24
|
+
|
|
25
|
+
- name: Install dependencies
|
|
26
|
+
run: |
|
|
27
|
+
test -f uv.lock
|
|
28
|
+
uv sync --group dev --locked
|
|
29
|
+
|
|
30
|
+
- name: Run tests
|
|
31
|
+
run: uv run pytest -m "not compat"
|
|
32
|
+
|
|
33
|
+
build:
|
|
34
|
+
name: Build distribution
|
|
35
|
+
needs: test
|
|
36
|
+
runs-on: ubuntu-latest
|
|
37
|
+
steps:
|
|
38
|
+
- uses: actions/checkout@v4
|
|
39
|
+
|
|
40
|
+
- name: Install uv
|
|
41
|
+
uses: astral-sh/setup-uv@v4
|
|
42
|
+
|
|
43
|
+
- name: Build package
|
|
44
|
+
run: uv build
|
|
45
|
+
|
|
46
|
+
- name: Store distribution packages
|
|
47
|
+
uses: actions/upload-artifact@v4
|
|
48
|
+
with:
|
|
49
|
+
name: python-package-distributions
|
|
50
|
+
path: dist/
|
|
51
|
+
|
|
52
|
+
publish:
|
|
53
|
+
name: Publish to PyPI
|
|
54
|
+
needs: build
|
|
55
|
+
runs-on: ubuntu-latest
|
|
56
|
+
environment:
|
|
57
|
+
name: pypi
|
|
58
|
+
url: https://pypi.org/p/jax-envelope
|
|
59
|
+
permissions:
|
|
60
|
+
id-token: write
|
|
61
|
+
steps:
|
|
62
|
+
- name: Download distribution packages
|
|
63
|
+
uses: actions/download-artifact@v4
|
|
64
|
+
with:
|
|
65
|
+
name: python-package-distributions
|
|
66
|
+
path: dist/
|
|
67
|
+
|
|
68
|
+
- name: Publish to PyPI
|
|
69
|
+
uses: pypa/gh-action-pypi-publish@release/v1
|
|
@@ -0,0 +1,172 @@
|
|
|
1
|
+
# Byte-compiled / optimized / DLL files
|
|
2
|
+
__pycache__/
|
|
3
|
+
*.py[cod]
|
|
4
|
+
*$py.class
|
|
5
|
+
|
|
6
|
+
# C extensions
|
|
7
|
+
*.so
|
|
8
|
+
|
|
9
|
+
# Distribution / packaging
|
|
10
|
+
.Python
|
|
11
|
+
build/
|
|
12
|
+
develop-eggs/
|
|
13
|
+
dist/
|
|
14
|
+
downloads/
|
|
15
|
+
eggs/
|
|
16
|
+
.eggs/
|
|
17
|
+
lib/
|
|
18
|
+
lib64/
|
|
19
|
+
parts/
|
|
20
|
+
sdist/
|
|
21
|
+
var/
|
|
22
|
+
wheels/
|
|
23
|
+
share/python-wheels/
|
|
24
|
+
*.egg-info/
|
|
25
|
+
.installed.cfg
|
|
26
|
+
*.egg
|
|
27
|
+
MANIFEST
|
|
28
|
+
|
|
29
|
+
# PyInstaller
|
|
30
|
+
# Usually these files are written by a python script from a template
|
|
31
|
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
|
32
|
+
*.manifest
|
|
33
|
+
*.spec
|
|
34
|
+
|
|
35
|
+
# Installer logs
|
|
36
|
+
pip-log.txt
|
|
37
|
+
pip-delete-this-directory.txt
|
|
38
|
+
|
|
39
|
+
# Unit test / coverage reports
|
|
40
|
+
htmlcov/
|
|
41
|
+
.tox/
|
|
42
|
+
.nox/
|
|
43
|
+
.coverage
|
|
44
|
+
.coverage.*
|
|
45
|
+
.cache
|
|
46
|
+
nosetests.xml
|
|
47
|
+
coverage.xml
|
|
48
|
+
*.cover
|
|
49
|
+
*.py,cover
|
|
50
|
+
.hypothesis/
|
|
51
|
+
.pytest_cache/
|
|
52
|
+
cover/
|
|
53
|
+
|
|
54
|
+
# Translations
|
|
55
|
+
*.mo
|
|
56
|
+
*.pot
|
|
57
|
+
|
|
58
|
+
# Django stuff:
|
|
59
|
+
*.log
|
|
60
|
+
local_settings.py
|
|
61
|
+
db.sqlite3
|
|
62
|
+
db.sqlite3-journal
|
|
63
|
+
|
|
64
|
+
# Flask stuff:
|
|
65
|
+
instance/
|
|
66
|
+
.webassets-cache
|
|
67
|
+
|
|
68
|
+
# Scrapy stuff:
|
|
69
|
+
.scrapy
|
|
70
|
+
|
|
71
|
+
# Sphinx documentation
|
|
72
|
+
docs/_build/
|
|
73
|
+
|
|
74
|
+
# PyBuilder
|
|
75
|
+
.pybuilder/
|
|
76
|
+
target/
|
|
77
|
+
|
|
78
|
+
# Jupyter Notebook
|
|
79
|
+
.ipynb_checkpoints
|
|
80
|
+
|
|
81
|
+
# IPython
|
|
82
|
+
profile_default/
|
|
83
|
+
ipython_config.py
|
|
84
|
+
|
|
85
|
+
# pyenv
|
|
86
|
+
# For a library or package, you might want to ignore these files since the code is
|
|
87
|
+
# intended to run in multiple environments; otherwise, check them in:
|
|
88
|
+
# .python-version
|
|
89
|
+
|
|
90
|
+
# pipenv
|
|
91
|
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
|
92
|
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
|
93
|
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
|
94
|
+
# install all needed dependencies.
|
|
95
|
+
#Pipfile.lock
|
|
96
|
+
|
|
97
|
+
# poetry
|
|
98
|
+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
|
99
|
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
|
100
|
+
# commonly ignored for libraries.
|
|
101
|
+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
|
102
|
+
#poetry.lock
|
|
103
|
+
|
|
104
|
+
# pdm
|
|
105
|
+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
|
106
|
+
#pdm.lock
|
|
107
|
+
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
|
108
|
+
# in version control.
|
|
109
|
+
# https://pdm.fming.dev/#use-with-ide
|
|
110
|
+
.pdm.toml
|
|
111
|
+
.pdm-python
|
|
112
|
+
|
|
113
|
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
|
114
|
+
__pypackages__/
|
|
115
|
+
|
|
116
|
+
# Celery stuff
|
|
117
|
+
celerybeat-schedule
|
|
118
|
+
celerybeat.pid
|
|
119
|
+
|
|
120
|
+
# SageMath parsed files
|
|
121
|
+
*.sage.py
|
|
122
|
+
|
|
123
|
+
# Environments
|
|
124
|
+
.env
|
|
125
|
+
.venv
|
|
126
|
+
env/
|
|
127
|
+
venv/
|
|
128
|
+
ENV/
|
|
129
|
+
env.bak/
|
|
130
|
+
venv.bak/
|
|
131
|
+
|
|
132
|
+
# Spyder project settings
|
|
133
|
+
.spyderproject
|
|
134
|
+
.spyproject
|
|
135
|
+
|
|
136
|
+
# Rope project settings
|
|
137
|
+
.ropeproject
|
|
138
|
+
|
|
139
|
+
# mkdocs documentation
|
|
140
|
+
/site
|
|
141
|
+
|
|
142
|
+
# mypy
|
|
143
|
+
.mypy_cache/
|
|
144
|
+
.dmypy.json
|
|
145
|
+
dmypy.json
|
|
146
|
+
|
|
147
|
+
# Pyre type checker
|
|
148
|
+
.pyre/
|
|
149
|
+
|
|
150
|
+
# pytype static type analyzer
|
|
151
|
+
.pytype/
|
|
152
|
+
|
|
153
|
+
# Cython debug symbols
|
|
154
|
+
cython_debug/
|
|
155
|
+
|
|
156
|
+
# PyCharm
|
|
157
|
+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
|
158
|
+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
|
159
|
+
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
|
160
|
+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
|
161
|
+
#.idea/
|
|
162
|
+
|
|
163
|
+
# Weights and Biases
|
|
164
|
+
wandb/
|
|
165
|
+
|
|
166
|
+
# ruff
|
|
167
|
+
.ruff_cache/
|
|
168
|
+
|
|
169
|
+
# Cursor
|
|
170
|
+
.cursor
|
|
171
|
+
.cursorignore
|
|
172
|
+
AGENTS.md
|
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2025 Jarek Liesen
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
SOFTWARE.
|
|
@@ -0,0 +1,87 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: jax-envelope
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: A JAX-native environment interface with powerful wrappers and adapters for popular RL environment suites
|
|
5
|
+
Project-URL: Homepage, https://github.com/keraJLi/envelope
|
|
6
|
+
Project-URL: Repository, https://github.com/keraJLi/envelope
|
|
7
|
+
Project-URL: Documentation, https://github.com/keraJLi/envelope#readme
|
|
8
|
+
Project-URL: Issues, https://github.com/keraJLi/envelope/issues
|
|
9
|
+
Project-URL: Changelog, https://github.com/keraJLi/envelope/releases
|
|
10
|
+
Author-email: Jarek Liesen <jarek.liesen@reuben.ox.ac.uk>
|
|
11
|
+
License: MIT
|
|
12
|
+
License-File: LICENSE
|
|
13
|
+
Keywords: deep-learning,environments,gymnasium,hardware-acceleration,jax,machine-learning,reinforcement-learning,vectorization
|
|
14
|
+
Classifier: Development Status :: 4 - Beta
|
|
15
|
+
Classifier: Intended Audience :: Developers
|
|
16
|
+
Classifier: Intended Audience :: Science/Research
|
|
17
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
18
|
+
Classifier: Operating System :: OS Independent
|
|
19
|
+
Classifier: Programming Language :: Python :: 3
|
|
20
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
21
|
+
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
22
|
+
Classifier: Typing :: Typed
|
|
23
|
+
Requires-Python: >=3.12
|
|
24
|
+
Requires-Dist: jax>=0.5.0
|
|
25
|
+
Description-Content-Type: text/markdown
|
|
26
|
+
|
|
27
|
+
# 💌 Envelope: a JAX-native environment interface
|
|
28
|
+
```python
|
|
29
|
+
# Create environments from JAX-native suites you have installed, ...
|
|
30
|
+
env = envelope.create("gymnax::CartPole-v1")
|
|
31
|
+
|
|
32
|
+
# ... interact with the environments using a simple interface, ...
|
|
33
|
+
state, info = env.reset(key)
|
|
34
|
+
states, infos = jax.lax.scan(env.step, state, actions)
|
|
35
|
+
plt.plot(infos.reward.cumsum())
|
|
36
|
+
|
|
37
|
+
# ... and enjoy a powerful ecosystem of wrappers.
|
|
38
|
+
env = envelope.wrappers.AutoResetWrapper(env)
|
|
39
|
+
env = envelope.wrappers.VmapWrapper(env)
|
|
40
|
+
env = envelope.wrappers.ObservationNormalizationWrapper(env)
|
|
41
|
+
```
|
|
42
|
+
|
|
43
|
+
## 🌍 Simple, expressive interaction!
|
|
44
|
+
* **Environments are pytrees**. Squish them through JAX transformations and trace their parameters.
|
|
45
|
+
* **Idiomatic jax-y interface** of `reset(key: Key) -> State, Info` and `step(state: State, action: PyTree) -> State, Info`. You can directly `jax.scan` over a `step(...)`!
|
|
46
|
+
* **Spaces are super simple**. No `Tuple`, `Dict` nonsense! There are two spaces: `Continuous` and `Discrete`, which you can compose into a `PyTreeSpace`.
|
|
47
|
+
* **Explicit episode truncation** supports correctly handling bootstrapping for value-function targets.
|
|
48
|
+
* **No auto-reset** by default. Resetting every step can be expensive!
|
|
49
|
+
|
|
50
|
+
## 💪 Powerful, composable wrappers!
|
|
51
|
+
* **Carry state across episodes** to track running statistics, for example to normalize observations.
|
|
52
|
+
* **Composable wrappers** can be stacked in any order. For example, `ObservationNormalizationWrapper` before vs. after `VmapWrapper` gives per-env vs. global normalization.
|
|
53
|
+
<!-- TODO: Add auto-reset behavior (including state injection) and optimistic resets once I implement them. -->
|
|
54
|
+
|
|
55
|
+
## 🔌 Adapters for existing suites
|
|
56
|
+
| 📦 | # 🤖 | # 🌍 |
|
|
57
|
+
|------|------|------|
|
|
58
|
+
| [gymnax](https://github.com/RobertTLange/gymnax) | 🕺 | 24 |
|
|
59
|
+
| [brax](https://github.com/google/brax) | 🕺 | 12 |
|
|
60
|
+
| [jumanji](https://github.com/instadeepai/jumanji) | 🕺 / 👯 | 25 / 1 |
|
|
61
|
+
| [kinetix](https://github.com/flairox/kinetix) | 🕺 | 74 |
|
|
62
|
+
| [craftax](https://github.com/MichaelTMatthews/craftax) | 🕺 | 4 |
|
|
63
|
+
| [mujoco_playground](https://github.com/google-deepmind/mujoco_playground) | 🕺 | 54 |
|
|
64
|
+
| | |
|
|
65
|
+
| Total | 🕺 / 👯 | 193 / 1 |
|
|
66
|
+
|
|
67
|
+
```python
|
|
68
|
+
envelope.create("📦::🌍")
|
|
69
|
+
```
|
|
70
|
+
lets you create environments from any of the above!
|
|
71
|
+
|
|
72
|
+
## 📝 Testing
|
|
73
|
+
- **Default (no optional compat deps required)**: `uv run pytest -m "not compat"`
|
|
74
|
+
- **Compat suite (requires full compat dependency group)**:
|
|
75
|
+
- `uv sync --group compat`
|
|
76
|
+
- `uv run pytest -m compat`
|
|
77
|
+
- If any compat dependency is missing/broken, the run will fail fast with an error telling you what to install.
|
|
78
|
+
|
|
79
|
+
## 🏗️ Installation
|
|
80
|
+
```bash
|
|
81
|
+
pip install jax-envelope
|
|
82
|
+
```
|
|
83
|
+
|
|
84
|
+
## 💞 Related projects
|
|
85
|
+
* [stoax](https://github.com/EdanToledo/Stoa) is a very similar project that provides adapters and wrappers for the jumanji-like interface.
|
|
86
|
+
* Check out all the great suites we have adapters for! [gymnax](https://github.com/RobertTLange/gymnax), [brax](https://github.com/google/brax), [jumanji](https://github.com/instadeepai/jumanji), [kinetix](https://github.com/flairox/kinetix), [craftax](https://github.com/MichaelTMatthews/craftax), [mujoco_playground](https://github.com/google-deepmind/mujoco_playground).
|
|
87
|
+
* We will be adding support for [jaxmarl](https://github.com/flairox/jaxmarl) and [pgx](https://github.com/sotetsuk/pgx) in the future, as soon as we have figured out the best ever MARL interface for JAX!
|
|
@@ -0,0 +1,61 @@
|
|
|
1
|
+
# 💌 Envelope: a JAX-native environment interface
|
|
2
|
+
```python
|
|
3
|
+
# Create environments from JAX-native suites you have installed, ...
|
|
4
|
+
env = envelope.create("gymnax::CartPole-v1")
|
|
5
|
+
|
|
6
|
+
# ... interact with the environments using a simple interface, ...
|
|
7
|
+
state, info = env.reset(key)
|
|
8
|
+
states, infos = jax.lax.scan(env.step, state, actions)
|
|
9
|
+
plt.plot(infos.reward.cumsum())
|
|
10
|
+
|
|
11
|
+
# ... and enjoy a powerful ecosystem of wrappers.
|
|
12
|
+
env = envelope.wrappers.AutoResetWrapper(env)
|
|
13
|
+
env = envelope.wrappers.VmapWrapper(env)
|
|
14
|
+
env = envelope.wrappers.ObservationNormalizationWrapper(env)
|
|
15
|
+
```
|
|
16
|
+
|
|
17
|
+
## 🌍 Simple, expressive interaction!
|
|
18
|
+
* **Environments are pytrees**. Squish them through JAX transformations and trace their parameters.
|
|
19
|
+
* **Idiomatic jax-y interface** of `reset(key: Key) -> State, Info` and `step(state: State, action: PyTree) -> State, Info`. You can directly `jax.scan` over a `step(...)`!
|
|
20
|
+
* **Spaces are super simple**. No `Tuple`, `Dict` nonsense! There are two spaces: `Continuous` and `Discrete`, which you can compose into a `PyTreeSpace`.
|
|
21
|
+
* **Explicit episode truncation** supports correctly handling bootstrapping for value-function targets.
|
|
22
|
+
* **No auto-reset** by default. Resetting every step can be expensive!
|
|
23
|
+
|
|
24
|
+
## 💪 Powerful, composable wrappers!
|
|
25
|
+
* **Carry state across episodes** to track running statistics, for example to normalize observations.
|
|
26
|
+
* **Composable wrappers** can be stacked in any order. For example, `ObservationNormalizationWrapper` before vs. after `VmapWrapper` gives per-env vs. global normalization.
|
|
27
|
+
<!-- TODO: Add auto-reset behavior (including state injection) and optimistic resets once I implement them. -->
|
|
28
|
+
|
|
29
|
+
## 🔌 Adapters for existing suites
|
|
30
|
+
| 📦 | # 🤖 | # 🌍 |
|
|
31
|
+
|------|------|------|
|
|
32
|
+
| [gymnax](https://github.com/RobertTLange/gymnax) | 🕺 | 24 |
|
|
33
|
+
| [brax](https://github.com/google/brax) | 🕺 | 12 |
|
|
34
|
+
| [jumanji](https://github.com/instadeepai/jumanji) | 🕺 / 👯 | 25 / 1 |
|
|
35
|
+
| [kinetix](https://github.com/flairox/kinetix) | 🕺 | 74 |
|
|
36
|
+
| [craftax](https://github.com/MichaelTMatthews/craftax) | 🕺 | 4 |
|
|
37
|
+
| [mujoco_playground](https://github.com/google-deepmind/mujoco_playground) | 🕺 | 54 |
|
|
38
|
+
| | |
|
|
39
|
+
| Total | 🕺 / 👯 | 193 / 1 |
|
|
40
|
+
|
|
41
|
+
```python
|
|
42
|
+
envelope.create("📦::🌍")
|
|
43
|
+
```
|
|
44
|
+
lets you create environments from any of the above!
|
|
45
|
+
|
|
46
|
+
## 📝 Testing
|
|
47
|
+
- **Default (no optional compat deps required)**: `uv run pytest -m "not compat"`
|
|
48
|
+
- **Compat suite (requires full compat dependency group)**:
|
|
49
|
+
- `uv sync --group compat`
|
|
50
|
+
- `uv run pytest -m compat`
|
|
51
|
+
- If any compat dependency is missing/broken, the run will fail fast with an error telling you what to install.
|
|
52
|
+
|
|
53
|
+
## 🏗️ Installation
|
|
54
|
+
```bash
|
|
55
|
+
pip install jax-envelope
|
|
56
|
+
```
|
|
57
|
+
|
|
58
|
+
## 💞 Related projects
|
|
59
|
+
* [stoax](https://github.com/EdanToledo/Stoa) is a very similar project that provides adapters and wrappers for the jumanji-like interface.
|
|
60
|
+
* Check out all the great suites we have adapters for! [gymnax](https://github.com/RobertTLange/gymnax), [brax](https://github.com/google/brax), [jumanji](https://github.com/instadeepai/jumanji), [kinetix](https://github.com/flairox/kinetix), [craftax](https://github.com/MichaelTMatthews/craftax), [mujoco_playground](https://github.com/google-deepmind/mujoco_playground).
|
|
61
|
+
* We will be adding support for [jaxmarl](https://github.com/flairox/jaxmarl) and [pgx](https://github.com/sotetsuk/pgx) in the future, as soon as we have figured out the best ever MARL interface for JAX!
|
|
@@ -0,0 +1,77 @@
|
|
|
1
|
+
[project]
|
|
2
|
+
name = "jax-envelope"
|
|
3
|
+
version = "0.1.0"
|
|
4
|
+
description = "A JAX-native environment interface with powerful wrappers and adapters for popular RL environment suites"
|
|
5
|
+
readme = "README.md"
|
|
6
|
+
requires-python = ">=3.12"
|
|
7
|
+
authors = [
|
|
8
|
+
{name = "Jarek Liesen", email = "jarek.liesen@reuben.ox.ac.uk"},
|
|
9
|
+
]
|
|
10
|
+
dependencies = [
|
|
11
|
+
"jax>=0.5.0",
|
|
12
|
+
]
|
|
13
|
+
|
|
14
|
+
license = {text = "MIT"}
|
|
15
|
+
keywords = [
|
|
16
|
+
"machine-learning",
|
|
17
|
+
"reinforcement-learning",
|
|
18
|
+
"jax",
|
|
19
|
+
"vectorization",
|
|
20
|
+
"hardware-acceleration",
|
|
21
|
+
"gymnasium",
|
|
22
|
+
"environments",
|
|
23
|
+
"deep-learning",
|
|
24
|
+
]
|
|
25
|
+
classifiers = [
|
|
26
|
+
"Development Status :: 4 - Beta",
|
|
27
|
+
"Intended Audience :: Developers",
|
|
28
|
+
"Intended Audience :: Science/Research",
|
|
29
|
+
"License :: OSI Approved :: MIT License",
|
|
30
|
+
"Operating System :: OS Independent",
|
|
31
|
+
"Programming Language :: Python :: 3",
|
|
32
|
+
"Programming Language :: Python :: 3.12",
|
|
33
|
+
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
|
34
|
+
"Typing :: Typed",
|
|
35
|
+
]
|
|
36
|
+
|
|
37
|
+
[project.urls]
|
|
38
|
+
Homepage = "https://github.com/keraJLi/envelope"
|
|
39
|
+
Repository = "https://github.com/keraJLi/envelope"
|
|
40
|
+
Documentation = "https://github.com/keraJLi/envelope#readme"
|
|
41
|
+
Issues = "https://github.com/keraJLi/envelope/issues"
|
|
42
|
+
Changelog = "https://github.com/keraJLi/envelope/releases"
|
|
43
|
+
|
|
44
|
+
[build-system]
|
|
45
|
+
requires = ["hatchling"]
|
|
46
|
+
build-backend = "hatchling.build"
|
|
47
|
+
|
|
48
|
+
[tool.hatch.metadata]
|
|
49
|
+
allow-direct-references = true
|
|
50
|
+
|
|
51
|
+
[tool.hatch.build.targets.wheel]
|
|
52
|
+
packages = ["src/envelope"]
|
|
53
|
+
|
|
54
|
+
[tool.uv.sources]
|
|
55
|
+
gymnax = { git = "https://github.com/RobertTLange/gymnax" }
|
|
56
|
+
|
|
57
|
+
[dependency-groups]
|
|
58
|
+
compat = [
|
|
59
|
+
"gymnax @ git+https://github.com/RobertTLange/gymnax@main",
|
|
60
|
+
"brax>=0.13.0",
|
|
61
|
+
"craftax>=1.4.3",
|
|
62
|
+
"navix>=0.7.0",
|
|
63
|
+
"jumanji>=1.0.1",
|
|
64
|
+
"kinetix-env>=2.0.0",
|
|
65
|
+
"playground>=0.1.0",
|
|
66
|
+
]
|
|
67
|
+
dev = [
|
|
68
|
+
"hypothesis>=6.148.1",
|
|
69
|
+
"pytest>=9.0.0",
|
|
70
|
+
"pytest-cov>=7.0.0",
|
|
71
|
+
"ruff>=0.14.2",
|
|
72
|
+
]
|
|
73
|
+
|
|
74
|
+
[tool.pytest.ini_options]
|
|
75
|
+
markers = [
|
|
76
|
+
"compat: tests requiring optional compat dependencies",
|
|
77
|
+
]
|
|
File without changes
|
|
@@ -0,0 +1,97 @@
|
|
|
1
|
+
"""Compatibility wrappers for various RL environment libraries."""
|
|
2
|
+
|
|
3
|
+
from typing import Any, Protocol, Self
|
|
4
|
+
|
|
5
|
+
# Lazy imports to avoid requiring all dependencies at once
|
|
6
|
+
_env_module_map = {
|
|
7
|
+
"gymnax": ("envelope.compat.gymnax_envelope", "GymnaxEnvelope"),
|
|
8
|
+
"brax": ("envelope.compat.brax_envelope", "BraxEnvelope"),
|
|
9
|
+
"navix": ("envelope.compat.navix_envelope", "NavixEnvelope"),
|
|
10
|
+
"jumanji": ("envelope.compat.jumanji_envelope", "JumanjiEnvelope"),
|
|
11
|
+
"kinetix": ("envelope.compat.kinetix_envelope", "KinetixEnvelope"),
|
|
12
|
+
"craftax": ("envelope.compat.craftax_envelope", "CraftaxEnvelope"),
|
|
13
|
+
"mujoco_playground": (
|
|
14
|
+
"envelope.compat.mujoco_playground_envelope",
|
|
15
|
+
"MujocoPlaygroundEnvelope",
|
|
16
|
+
),
|
|
17
|
+
}
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class HasFromNameInit(Protocol):
|
|
21
|
+
@classmethod
|
|
22
|
+
def from_name(
|
|
23
|
+
cls, env_name: str, env_kwargs: dict[str, Any] | None = None, **kwargs
|
|
24
|
+
) -> Self: ...
|
|
25
|
+
|
|
26
|
+
"""Creates an environment from a name and keyword arguments. Unless otherwise noted,
|
|
27
|
+
the created environment will have it's default parameters, with truncation and auto
|
|
28
|
+
reset disabled.
|
|
29
|
+
|
|
30
|
+
Args:
|
|
31
|
+
env_name: Environment name
|
|
32
|
+
env_kwargs: Keyword arguments passed to the environment constructor
|
|
33
|
+
**kwargs: Additional keyword arguments passed to the environment wrapper
|
|
34
|
+
"""
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def create(env_name: str, env_kwargs: dict[str, Any] | None = None, **kwargs):
|
|
38
|
+
"""Create an environment from a prefixed environment ID.
|
|
39
|
+
|
|
40
|
+
Args:
|
|
41
|
+
env_name: Environment ID in the format "suite::env_name" (e.g., "brax::ant")
|
|
42
|
+
env_kwargs: Keyword arguments passed to the suite's environment constructor
|
|
43
|
+
**kwargs: Additional keyword arguments passed to the environment wrapper
|
|
44
|
+
|
|
45
|
+
Returns:
|
|
46
|
+
An instance of the wrapped environment
|
|
47
|
+
|
|
48
|
+
Examples:
|
|
49
|
+
>>> env = create("jumanji::snake")
|
|
50
|
+
>>> env = create("brax::ant", env_kwargs={"backend": "spring"})
|
|
51
|
+
>>> env = create("gymnax::CartPole-v1", env_params=...)
|
|
52
|
+
"""
|
|
53
|
+
original_env_id = env_name
|
|
54
|
+
if "::" not in env_name:
|
|
55
|
+
raise ValueError(
|
|
56
|
+
f"Environment ID must be in format 'suite::env_name', got: {original_env_id}"
|
|
57
|
+
)
|
|
58
|
+
|
|
59
|
+
suite, env_name = env_name.split("::", 1)
|
|
60
|
+
if not suite or not env_name:
|
|
61
|
+
raise ValueError(
|
|
62
|
+
f"Environment ID must be in format 'suite::env_name', got: {original_env_id}"
|
|
63
|
+
)
|
|
64
|
+
|
|
65
|
+
if suite not in _env_module_map:
|
|
66
|
+
raise ValueError(
|
|
67
|
+
f"Unknown environment suite: {suite}. "
|
|
68
|
+
f"Available suites: {list(_env_module_map.keys())}"
|
|
69
|
+
)
|
|
70
|
+
|
|
71
|
+
# Lazy import the wrapper class
|
|
72
|
+
module_name, class_name = _env_module_map[suite]
|
|
73
|
+
try:
|
|
74
|
+
import importlib
|
|
75
|
+
|
|
76
|
+
module = importlib.import_module(module_name)
|
|
77
|
+
env_class: HasFromNameInit = getattr(module, class_name)
|
|
78
|
+
except ImportError as e:
|
|
79
|
+
raise ImportError(
|
|
80
|
+
f"Failed to import {suite} wrapper. "
|
|
81
|
+
f"Make sure you have installed the '{suite}' dependencies. "
|
|
82
|
+
f"Original error: {e}"
|
|
83
|
+
) from e
|
|
84
|
+
|
|
85
|
+
env = env_class.from_name(env_name, env_kwargs=env_kwargs, **kwargs)
|
|
86
|
+
|
|
87
|
+
# Wrap with TruncationWrapper using adapter's default
|
|
88
|
+
default_max_steps = getattr(env, "default_max_steps", None)
|
|
89
|
+
if default_max_steps is not None:
|
|
90
|
+
from envelope.wrappers.truncation_wrapper import TruncationWrapper
|
|
91
|
+
|
|
92
|
+
env = TruncationWrapper(env=env, max_steps=int(default_max_steps))
|
|
93
|
+
|
|
94
|
+
return env
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
__all__ = ["create"]
|
|
@@ -0,0 +1,98 @@
|
|
|
1
|
+
import dataclasses
|
|
2
|
+
import warnings
|
|
3
|
+
from copy import copy
|
|
4
|
+
from functools import cached_property
|
|
5
|
+
from typing import Any, override
|
|
6
|
+
|
|
7
|
+
from brax.envs import Env as BraxEnv
|
|
8
|
+
from brax.envs import Wrapper as BraxWrapper
|
|
9
|
+
from brax.envs import create as brax_create
|
|
10
|
+
from jax import numpy as jnp
|
|
11
|
+
|
|
12
|
+
from envelope import spaces
|
|
13
|
+
from envelope.environment import Environment, Info, InfoContainer, State
|
|
14
|
+
from envelope.struct import static_field
|
|
15
|
+
from envelope.typing import Key, PyTree
|
|
16
|
+
|
|
17
|
+
# Default episode_length in brax.envs.create()
|
|
18
|
+
_BRAX_DEFAULT_EPISODE_LENGTH = 1000
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class BraxEnvelope(Environment):
    """Wrapper to convert a Brax environment to an envelope environment.

    The Brax state returned by the wrapped environment is passed through
    unchanged as the envelope state.
    """

    # The underlying Brax environment this adapter delegates to.
    brax_env: BraxEnv = static_field()

    @classmethod
    def from_name(
        cls, env_name: str, env_kwargs: dict[str, Any] | None = None
    ) -> "BraxEnvelope":
        """Build the adapter from a Brax environment name.

        Args:
            env_name: Name passed to ``brax.envs.create``.
            env_kwargs: Extra keyword arguments for ``brax.envs.create``. The
                mapping is copied and never modified.

        Returns:
            A ``BraxEnvelope`` around the created environment, created with
            ``episode_length=jnp.inf`` and ``auto_reset=False``.

        Raises:
            ValueError: If ``env_kwargs`` contains ``episode_length`` or
                ``auto_reset``.
        """
        # Work on a copy: the keys inserted below must not leak into the
        # caller's dict, otherwise reusing it for a second call would trip the
        # validation checks.
        env_kwargs = dict(env_kwargs or {})
        if "episode_length" in env_kwargs:
            raise ValueError(
                "Cannot override 'episode_length' directly. "
                "Use TruncationWrapper for episode length control."
            )
        if "auto_reset" in env_kwargs:
            raise ValueError(
                "Cannot override 'auto_reset' directly. "
                "Use AutoResetWrapper for auto-reset behavior."
            )

        env_kwargs["episode_length"] = jnp.inf
        env_kwargs["auto_reset"] = False
        env = brax_create(env_name, **env_kwargs)
        return cls(brax_env=env)

    @property
    def default_max_steps(self) -> int:
        """Step limit matching the default episode_length in brax.envs.create()."""
        return _BRAX_DEFAULT_EPISODE_LENGTH

    def __post_init__(self) -> None:
        """Replace a wrapped Brax env by its unwrapped base, with a warning."""
        if isinstance(self.brax_env, BraxWrapper):
            warnings.warn(
                "Environment wrapping should be handled by envelope. "
                "Unwrapping brax environment before converting..."
            )
            # NOTE(review): object.__setattr__ presumably bypasses a frozen
            # dataclass __setattr__ on Environment — confirm.
            object.__setattr__(self, "brax_env", self.brax_env.unwrapped)

    @override
    def reset(self, key: Key) -> tuple[State, Info]:
        """Reset the Brax env with ``key``; return its state and an info container."""
        brax_state = self.brax_env.reset(key)
        info = InfoContainer(obs=brax_state.obs, reward=0.0, terminated=False)
        # Add every field of the Brax state to the info container.
        info = info.update(**dataclasses.asdict(brax_state))
        return brax_state, info

    @override
    def step(self, state: State, action: PyTree) -> tuple[State, Info]:
        """Advance the Brax env by one step; return its state and an info container."""
        brax_state = self.brax_env.step(state, action)
        info = InfoContainer(
            obs=brax_state.obs, reward=brax_state.reward, terminated=brax_state.done
        )
        # Add every field of the Brax state to the info container.
        info = info.update(**dataclasses.asdict(brax_state))
        return brax_state, info

    @override
    @cached_property
    def action_space(self) -> spaces.Space:
        """Continuous space of shape ``(action_size,)`` bounded to [-1, 1]."""
        # All brax environments have action limit of -1 to 1
        return spaces.Continuous.from_shape(
            low=-1.0, high=1.0, shape=(self.brax_env.action_size,)
        )

    @override
    @cached_property
    def observation_space(self) -> spaces.Space:
        """Unbounded continuous space of shape ``(observation_size,)``."""
        # All brax environments have observation limit of -inf to inf
        return spaces.Continuous.from_shape(
            low=-jnp.inf, high=jnp.inf, shape=(self.brax_env.observation_size,)
        )

    def __deepcopy__(self, memo):
        """Return a shallow copy and emit a RuntimeWarning instead of deep-copying."""
        warnings.warn(
            f"Trying to deepcopy {type(self).__name__}, which contains a brax env. "
            "Brax envs throw an error when deepcopying, so a shallow copy is returned.",
            category=RuntimeWarning,
            stacklevel=2,
        )
        return copy(self)
|