jax-envelope 0.2.0__tar.gz → 0.4.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {jax_envelope-0.2.0 → jax_envelope-0.4.0}/.github/workflows/publish.yml +1 -1
- {jax_envelope-0.2.0 → jax_envelope-0.4.0}/PKG-INFO +38 -28
- jax_envelope-0.4.0/README.md +71 -0
- {jax_envelope-0.2.0 → jax_envelope-0.4.0}/pyproject.toml +3 -3
- {jax_envelope-0.2.0 → jax_envelope-0.4.0}/src/envelope/__init__.py +1 -1
- {jax_envelope-0.2.0/src/envelope/compat → jax_envelope-0.4.0/src/envelope/adapters}/__init__.py +7 -7
- {jax_envelope-0.2.0/src/envelope/compat → jax_envelope-0.4.0/src/envelope/adapters}/brax_envelope.py +1 -1
- {jax_envelope-0.2.0/src/envelope/compat → jax_envelope-0.4.0/src/envelope/adapters}/craftax_envelope.py +1 -1
- {jax_envelope-0.2.0/src/envelope/compat → jax_envelope-0.4.0/src/envelope/adapters}/jumanji_envelope.py +1 -1
- {jax_envelope-0.2.0/src/envelope/compat → jax_envelope-0.4.0/src/envelope/adapters}/kinetix_envelope.py +2 -2
- {jax_envelope-0.2.0 → jax_envelope-0.4.0}/src/envelope/environment.py +2 -2
- {jax_envelope-0.2.0 → jax_envelope-0.4.0}/src/envelope/spaces.py +1 -1
- {jax_envelope-0.2.0 → jax_envelope-0.4.0}/src/envelope/wrappers/autoreset_wrapper.py +2 -2
- {jax_envelope-0.2.0 → jax_envelope-0.4.0}/src/envelope/wrappers/continuous_observation_wrapper.py +2 -2
- {jax_envelope-0.2.0 → jax_envelope-0.4.0}/src/envelope/wrappers/episode_statistics_wrapper.py +2 -2
- {jax_envelope-0.2.0 → jax_envelope-0.4.0}/src/envelope/wrappers/flatten_observation_wrapper.py +3 -4
- {jax_envelope-0.2.0 → jax_envelope-0.4.0}/src/envelope/wrappers/observation_normalization_wrapper.py +2 -2
- {jax_envelope-0.2.0 → jax_envelope-0.4.0}/src/envelope/wrappers/pooled_init_vmap_wrapper.py +2 -2
- {jax_envelope-0.2.0 → jax_envelope-0.4.0}/src/envelope/wrappers/state_injection_wrapper.py +2 -2
- {jax_envelope-0.2.0 → jax_envelope-0.4.0}/src/envelope/wrappers/truncation_wrapper.py +2 -2
- {jax_envelope-0.2.0 → jax_envelope-0.4.0}/src/envelope/wrappers/vmap_envs_wrapper.py +5 -5
- {jax_envelope-0.2.0 → jax_envelope-0.4.0}/src/envelope/wrappers/vmap_wrapper.py +2 -2
- {jax_envelope-0.2.0 → jax_envelope-0.4.0}/src/envelope/wrappers/wrapper.py +3 -3
- {jax_envelope-0.2.0/tests/compat → jax_envelope-0.4.0/tests/adapters}/conftest.py +1 -1
- {jax_envelope-0.2.0/tests/compat → jax_envelope-0.4.0/tests/adapters}/contract.py +2 -2
- {jax_envelope-0.2.0/tests/compat → jax_envelope-0.4.0/tests/adapters}/test_brax_compat.py +6 -6
- {jax_envelope-0.2.0/tests/compat → jax_envelope-0.4.0/tests/adapters}/test_craftax_compat.py +5 -5
- {jax_envelope-0.2.0/tests/compat → jax_envelope-0.4.0/tests/adapters}/test_create.py +7 -7
- {jax_envelope-0.2.0/tests/compat → jax_envelope-0.4.0/tests/adapters}/test_create_integration.py +12 -12
- {jax_envelope-0.2.0/tests/compat → jax_envelope-0.4.0/tests/adapters}/test_gymnax_compat.py +2 -2
- {jax_envelope-0.2.0/tests/compat → jax_envelope-0.4.0/tests/adapters}/test_jumanji_compat.py +5 -5
- {jax_envelope-0.2.0/tests/compat → jax_envelope-0.4.0/tests/adapters}/test_kinetix_compat.py +7 -8
- {jax_envelope-0.2.0/tests/compat → jax_envelope-0.4.0/tests/adapters}/test_mujoco_playground_compat.py +4 -4
- {jax_envelope-0.2.0/tests/compat → jax_envelope-0.4.0/tests/adapters}/test_navix_compat.py +8 -8
- {jax_envelope-0.2.0 → jax_envelope-0.4.0}/tests/wrappers/helpers.py +25 -35
- {jax_envelope-0.2.0 → jax_envelope-0.4.0}/tests/wrappers/test_autoreset_wrapper.py +2 -2
- {jax_envelope-0.2.0 → jax_envelope-0.4.0}/tests/wrappers/test_clip_action_wrapper.py +2 -2
- {jax_envelope-0.2.0 → jax_envelope-0.4.0}/tests/wrappers/test_continuous_observation_wrapper.py +1 -1
- {jax_envelope-0.2.0 → jax_envelope-0.4.0}/tests/wrappers/test_episode_statistics_wrapper.py +2 -2
- {jax_envelope-0.2.0 → jax_envelope-0.4.0}/tests/wrappers/test_flatten_action_wrapper.py +4 -4
- {jax_envelope-0.2.0 → jax_envelope-0.4.0}/tests/wrappers/test_flatten_observation_wrapper.py +2 -2
- {jax_envelope-0.2.0 → jax_envelope-0.4.0}/tests/wrappers/test_pooled_init_vmap_wrapper.py +1 -1
- {jax_envelope-0.2.0 → jax_envelope-0.4.0}/tests/wrappers/test_state_injection_wrapper.py +4 -4
- {jax_envelope-0.2.0 → jax_envelope-0.4.0}/tests/wrappers/test_truncation_wrapper.py +2 -2
- {jax_envelope-0.2.0 → jax_envelope-0.4.0}/uv.lock +3 -3
- jax_envelope-0.2.0/README.md +0 -61
- {jax_envelope-0.2.0 → jax_envelope-0.4.0}/.gitignore +0 -0
- {jax_envelope-0.2.0 → jax_envelope-0.4.0}/LICENSE +0 -0
- {jax_envelope-0.2.0/src/envelope/compat → jax_envelope-0.4.0/src/envelope/adapters}/gymnax_envelope.py +0 -0
- {jax_envelope-0.2.0/src/envelope/compat → jax_envelope-0.4.0/src/envelope/adapters}/mujoco_playground_envelope.py +0 -0
- {jax_envelope-0.2.0/src/envelope/compat → jax_envelope-0.4.0/src/envelope/adapters}/navix_envelope.py +0 -0
- {jax_envelope-0.2.0 → jax_envelope-0.4.0}/src/envelope/struct.py +0 -0
- {jax_envelope-0.2.0 → jax_envelope-0.4.0}/src/envelope/typing.py +0 -0
- {jax_envelope-0.2.0 → jax_envelope-0.4.0}/src/envelope/wrappers/__init__.py +0 -0
- {jax_envelope-0.2.0 → jax_envelope-0.4.0}/src/envelope/wrappers/clip_action_wrapper.py +0 -0
- {jax_envelope-0.2.0 → jax_envelope-0.4.0}/src/envelope/wrappers/flatten_action_wrapper.py +0 -0
- {jax_envelope-0.2.0 → jax_envelope-0.4.0}/src/envelope/wrappers/normalization.py +0 -0
- {jax_envelope-0.2.0 → jax_envelope-0.4.0}/tests/__init__.py +0 -0
- {jax_envelope-0.2.0/tests/compat → jax_envelope-0.4.0/tests/adapters}/__init__.py +0 -0
- {jax_envelope-0.2.0 → jax_envelope-0.4.0}/tests/spaces/__init__.py +0 -0
- {jax_envelope-0.2.0 → jax_envelope-0.4.0}/tests/spaces/test_batched_space.py +0 -0
- {jax_envelope-0.2.0 → jax_envelope-0.4.0}/tests/spaces/test_continuous.py +0 -0
- {jax_envelope-0.2.0 → jax_envelope-0.4.0}/tests/spaces/test_discrete.py +0 -0
- {jax_envelope-0.2.0 → jax_envelope-0.4.0}/tests/spaces/test_pytree_space.py +0 -0
- {jax_envelope-0.2.0 → jax_envelope-0.4.0}/tests/spaces/test_serialization.py +0 -0
- {jax_envelope-0.2.0 → jax_envelope-0.4.0}/tests/test_container.py +0 -0
- {jax_envelope-0.2.0 → jax_envelope-0.4.0}/tests/test_struct.py +0 -0
- {jax_envelope-0.2.0 → jax_envelope-0.4.0}/tests/wrappers/__init__.py +0 -0
- {jax_envelope-0.2.0 → jax_envelope-0.4.0}/tests/wrappers/test_environment_wrapper.py +0 -0
- {jax_envelope-0.2.0 → jax_envelope-0.4.0}/tests/wrappers/test_normalization.py +0 -0
- {jax_envelope-0.2.0 → jax_envelope-0.4.0}/tests/wrappers/test_observation_normalization_wrapper.py +0 -0
- {jax_envelope-0.2.0 → jax_envelope-0.4.0}/tests/wrappers/test_vmap_envs_wrapper.py +0 -0
- {jax_envelope-0.2.0 → jax_envelope-0.4.0}/tests/wrappers/test_vmap_wrapper.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: jax-envelope
|
|
3
|
-
Version: 0.2.0
|
|
3
|
+
Version: 0.4.0
|
|
4
4
|
Summary: A JAX-native environment interface with powerful wrappers and adapters for popular RL environment suites
|
|
5
5
|
Project-URL: Homepage, https://github.com/keraJLi/envelope
|
|
6
6
|
Project-URL: Repository, https://github.com/keraJLi/envelope
|
|
@@ -25,12 +25,13 @@ Requires-Dist: jax>=0.5.0
|
|
|
25
25
|
Description-Content-Type: text/markdown
|
|
26
26
|
|
|
27
27
|
# 💌 Envelope: a JAX-native environment interface
|
|
28
|
+
|
|
28
29
|
```python
|
|
29
30
|
# Create environments from JAX-native suites you have installed, ...
|
|
30
31
|
env = envelope.create("gymnax::CartPole-v1")
|
|
31
32
|
|
|
32
33
|
# ... interact with the environments using a simple interface, ...
|
|
33
|
-
state, info = env.
|
|
34
|
+
state, info = env.init(key)
|
|
34
35
|
states, infos = jax.lax.scan(env.step, state, actions)
|
|
35
36
|
plt.plot(infos.reward.cumsum())
|
|
36
37
|
|
|
@@ -41,47 +42,56 @@ env = envelope.wrappers.ObservationNormalizationWrapper(env)
|
|
|
41
42
|
```
|
|
42
43
|
|
|
43
44
|
## 🌍 Simple, expressive interaction!
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
45
|
+
|
|
46
|
+
- **Environments are pytrees**. Squish them through JAX transformations and trace their parameters.
|
|
47
|
+
- **Idiomatic jax-y interface** of `init(key: Key) -> State, Info` and `step(state: State, action: PyTree) -> State, Info`. You can directly `jax.lax.scan` over a `step(...)`!
|
|
48
|
+
- **Spaces are super simple**. No `Tuple`, `Dict` nonsense! There are two spaces: `Continuous` and `Discrete`, which you can compose into a `PyTreeSpace`.
|
|
49
|
+
- **Explicit episode truncation** supports correctly handling bootstrapping for value-function targets.
|
|
50
|
+
- **No auto-reset** by default. Resetting every step can be expensive!
|
|
49
51
|
|
|
50
52
|
## 💪 Powerful, composable wrappers!
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
53
|
+
|
|
54
|
+
- **Carry state across episodes** to track running statistics, for example to normalize observations.
|
|
55
|
+
- **Composable wrappers** can be stacked in any order. For example, `ObservationNormalizationWrapper` before vs. after `VmapWrapper` gives per-env vs. global normalization.
|
|
54
56
|
|
|
55
57
|
## 🔌 Adapters for existing suites
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
|
59
|
-
|
|
|
60
|
-
| [
|
|
61
|
-
| [
|
|
62
|
-
| [
|
|
63
|
-
| [
|
|
64
|
-
| | |
|
|
65
|
-
|
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
| 📦 | # 🤖 | # 🌍 |
|
|
61
|
+
| ------------------------------------------------------------------------- | ------- | ------- |
|
|
62
|
+
| [gymnax](https://github.com/RobertTLange/gymnax) | 🕺 | 24 |
|
|
63
|
+
| [brax](https://github.com/google/brax) | 🕺 | 12 |
|
|
64
|
+
| [jumanji](https://github.com/instadeepai/jumanji) | 🕺 / 👯 | 25 / 1 |
|
|
65
|
+
| [kinetix](https://github.com/flairox/kinetix) | 🕺 | 74 |
|
|
66
|
+
| [craftax](https://github.com/MichaelTMatthews/craftax) | 🕺 | 4 |
|
|
67
|
+
| [mujoco_playground](https://github.com/google-deepmind/mujoco_playground) | 🕺 | 54 |
|
|
68
|
+
| | | |
|
|
69
|
+
| Total | 🕺 / 👯 | 193 / 1 |
|
|
70
|
+
|
|
66
71
|
|
|
67
72
|
```python
|
|
68
73
|
envelope.create("📦::🌍")
|
|
69
74
|
```
|
|
75
|
+
|
|
70
76
|
lets you create environments from any of the above!
|
|
71
77
|
|
|
72
78
|
## 📝 Testing
|
|
73
|
-
|
|
74
|
-
- **
|
|
75
|
-
|
|
76
|
-
- `uv
|
|
77
|
-
-
|
|
79
|
+
|
|
80
|
+
- **Default (no optional adapters deps required)**: `uv run pytest -m "not adapters"`
|
|
81
|
+
- **Adapters suite (requires full adapters dependency group)**:
|
|
82
|
+
- `uv sync --group adapters`
|
|
83
|
+
- `uv run pytest -m adapters`
|
|
84
|
+
- If any adapter dependency is missing/broken, the run will fail fast with an error telling you what to install.
|
|
78
85
|
|
|
79
86
|
## 🏗️ Installation
|
|
87
|
+
|
|
80
88
|
```bash
|
|
81
89
|
pip install jax-envelope
|
|
82
90
|
```
|
|
83
91
|
|
|
84
92
|
## 💞 Related projects
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
93
|
+
|
|
94
|
+
- [stoa](https://github.com/EdanToledo/Stoa) is a very similar project that provides adapters and wrappers for the jumanji-like interface.
|
|
95
|
+
- Check out all the great suites we have adapters for! [gymnax](https://github.com/RobertTLange/gymnax), [brax](https://github.com/google/brax), [jumanji](https://github.com/instadeepai/jumanji), [kinetix](https://github.com/flairox/kinetix), [craftax](https://github.com/MichaelTMatthews/craftax), [mujoco_playground](https://github.com/google-deepmind/mujoco_playground).
|
|
96
|
+
- We will be adding support for [jaxmarl](https://github.com/flairox/jaxmarl) and [pgx](https://github.com/sotetsuk/pgx) in the future, as soon as we have figured out the best ever MARL interface for JAX!
|
|
97
|
+
|
|
@@ -0,0 +1,71 @@
|
|
|
1
|
+
# 💌 Envelope: a JAX-native environment interface
|
|
2
|
+
|
|
3
|
+
```python
|
|
4
|
+
# Create environments from JAX-native suites you have installed, ...
|
|
5
|
+
env = envelope.create("gymnax::CartPole-v1")
|
|
6
|
+
|
|
7
|
+
# ... interact with the environments using a simple interface, ...
|
|
8
|
+
state, info = env.init(key)
|
|
9
|
+
states, infos = jax.lax.scan(env.step, state, actions)
|
|
10
|
+
plt.plot(infos.reward.cumsum())
|
|
11
|
+
|
|
12
|
+
# ... and enjoy a powerful ecosystem of wrappers.
|
|
13
|
+
env = envelope.wrappers.AutoResetWrapper(env)
|
|
14
|
+
env = envelope.wrappers.VmapWrapper(env)
|
|
15
|
+
env = envelope.wrappers.ObservationNormalizationWrapper(env)
|
|
16
|
+
```
|
|
17
|
+
|
|
18
|
+
## 🌍 Simple, expressive interaction!
|
|
19
|
+
|
|
20
|
+
- **Environments are pytrees**. Squish them through JAX transformations and trace their parameters.
|
|
21
|
+
- **Idiomatic jax-y interface** of `init(key: Key) -> State, Info` and `step(state: State, action: PyTree) -> State, Info`. You can directly `jax.lax.scan` over a `step(...)`!
|
|
22
|
+
- **Spaces are super simple**. No `Tuple`, `Dict` nonsense! There are two spaces: `Continuous` and `Discrete`, which you can compose into a `PyTreeSpace`.
|
|
23
|
+
- **Explicit episode truncation** supports correctly handling bootstrapping for value-function targets.
|
|
24
|
+
- **No auto-reset** by default. Resetting every step can be expensive!
|
|
25
|
+
|
|
26
|
+
## 💪 Powerful, composable wrappers!
|
|
27
|
+
|
|
28
|
+
- **Carry state across episodes** to track running statistics, for example to normalize observations.
|
|
29
|
+
- **Composable wrappers** can be stacked in any order. For example, `ObservationNormalizationWrapper` before vs. after `VmapWrapper` gives per-env vs. global normalization.
|
|
30
|
+
|
|
31
|
+
## 🔌 Adapters for existing suites
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
| 📦 | # 🤖 | # 🌍 |
|
|
35
|
+
| ------------------------------------------------------------------------- | ------- | ------- |
|
|
36
|
+
| [gymnax](https://github.com/RobertTLange/gymnax) | 🕺 | 24 |
|
|
37
|
+
| [brax](https://github.com/google/brax) | 🕺 | 12 |
|
|
38
|
+
| [jumanji](https://github.com/instadeepai/jumanji) | 🕺 / 👯 | 25 / 1 |
|
|
39
|
+
| [kinetix](https://github.com/flairox/kinetix) | 🕺 | 74 |
|
|
40
|
+
| [craftax](https://github.com/MichaelTMatthews/craftax) | 🕺 | 4 |
|
|
41
|
+
| [mujoco_playground](https://github.com/google-deepmind/mujoco_playground) | 🕺 | 54 |
|
|
42
|
+
| | | |
|
|
43
|
+
| Total | 🕺 / 👯 | 193 / 1 |
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
```python
|
|
47
|
+
envelope.create("📦::🌍")
|
|
48
|
+
```
|
|
49
|
+
|
|
50
|
+
lets you create environments from any of the above!
|
|
51
|
+
|
|
52
|
+
## 📝 Testing
|
|
53
|
+
|
|
54
|
+
- **Default (no optional adapters deps required)**: `uv run pytest -m "not adapters"`
|
|
55
|
+
- **Adapters suite (requires full adapters dependency group)**:
|
|
56
|
+
- `uv sync --group adapters`
|
|
57
|
+
- `uv run pytest -m adapters`
|
|
58
|
+
- If any adapter dependency is missing/broken, the run will fail fast with an error telling you what to install.
|
|
59
|
+
|
|
60
|
+
## 🏗️ Installation
|
|
61
|
+
|
|
62
|
+
```bash
|
|
63
|
+
pip install jax-envelope
|
|
64
|
+
```
|
|
65
|
+
|
|
66
|
+
## 💞 Related projects
|
|
67
|
+
|
|
68
|
+
- [stoa](https://github.com/EdanToledo/Stoa) is a very similar project that provides adapters and wrappers for the jumanji-like interface.
|
|
69
|
+
- Check out all the great suites we have adapters for! [gymnax](https://github.com/RobertTLange/gymnax), [brax](https://github.com/google/brax), [jumanji](https://github.com/instadeepai/jumanji), [kinetix](https://github.com/flairox/kinetix), [craftax](https://github.com/MichaelTMatthews/craftax), [mujoco_playground](https://github.com/google-deepmind/mujoco_playground).
|
|
70
|
+
- We will be adding support for [jaxmarl](https://github.com/flairox/jaxmarl) and [pgx](https://github.com/sotetsuk/pgx) in the future, as soon as we have figured out the best ever MARL interface for JAX!
|
|
71
|
+
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
[project]
|
|
2
2
|
name = "jax-envelope"
|
|
3
|
-
version = "0.2.0"
|
|
3
|
+
version = "0.4.0"
|
|
4
4
|
description = "A JAX-native environment interface with powerful wrappers and adapters for popular RL environment suites"
|
|
5
5
|
readme = "README.md"
|
|
6
6
|
requires-python = ">=3.12"
|
|
@@ -52,7 +52,7 @@ allow-direct-references = true
|
|
|
52
52
|
packages = ["src/envelope"]
|
|
53
53
|
|
|
54
54
|
[dependency-groups]
|
|
55
|
-
|
|
55
|
+
adapters = [
|
|
56
56
|
"brax>=0.13.0",
|
|
57
57
|
"craftax>=1.4.3",
|
|
58
58
|
"navix>=0.7.0",
|
|
@@ -80,5 +80,5 @@ kinetix-env = { git = "https://github.com/FLAIROx/Kinetix.git" }
|
|
|
80
80
|
[tool.pytest.ini_options]
|
|
81
81
|
testpaths = ["tests"]
|
|
82
82
|
markers = [
|
|
83
|
-
"
|
|
83
|
+
"adapters: tests requiring optional adapters dependencies",
|
|
84
84
|
]
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
from envelope.
|
|
1
|
+
from envelope.adapters import create
|
|
2
2
|
from envelope.environment import Environment, Info, InfoContainer
|
|
3
3
|
from envelope.spaces import BatchedSpace, Continuous, Discrete, PyTreeSpace, Space
|
|
4
4
|
from envelope.struct import Container, FrozenPyTreeNode, field, static_field
|
{jax_envelope-0.2.0/src/envelope/compat → jax_envelope-0.4.0/src/envelope/adapters}/__init__.py
RENAMED
|
@@ -4,14 +4,14 @@ from typing import Any, Protocol, Self
|
|
|
4
4
|
|
|
5
5
|
# Lazy imports to avoid requiring all dependencies at once
|
|
6
6
|
_env_module_map = {
|
|
7
|
-
"gymnax": ("envelope.
|
|
8
|
-
"brax": ("envelope.
|
|
9
|
-
"navix": ("envelope.
|
|
10
|
-
"jumanji": ("envelope.
|
|
11
|
-
"kinetix": ("envelope.
|
|
12
|
-
"craftax": ("envelope.
|
|
7
|
+
"gymnax": ("envelope.adapters.gymnax_envelope", "GymnaxEnvelope"),
|
|
8
|
+
"brax": ("envelope.adapters.brax_envelope", "BraxEnvelope"),
|
|
9
|
+
"navix": ("envelope.adapters.navix_envelope", "NavixEnvelope"),
|
|
10
|
+
"jumanji": ("envelope.adapters.jumanji_envelope", "JumanjiEnvelope"),
|
|
11
|
+
"kinetix": ("envelope.adapters.kinetix_envelope", "KinetixEnvelope"),
|
|
12
|
+
"craftax": ("envelope.adapters.craftax_envelope", "CraftaxEnvelope"),
|
|
13
13
|
"mujoco_playground": (
|
|
14
|
-
"envelope.
|
|
14
|
+
"envelope.adapters.mujoco_playground_envelope",
|
|
15
15
|
"MujocoPlaygroundEnvelope",
|
|
16
16
|
),
|
|
17
17
|
}
|
{jax_envelope-0.2.0/src/envelope/compat → jax_envelope-0.4.0/src/envelope/adapters}/brax_envelope.py
RENAMED
|
@@ -69,7 +69,7 @@ class BraxEnvelope(Environment):
|
|
|
69
69
|
info = InfoContainer(
|
|
70
70
|
obs=brax_state.obs,
|
|
71
71
|
reward=brax_state.reward,
|
|
72
|
-
terminated=jnp.
|
|
72
|
+
terminated=jnp.asarray(brax_state.done, dtype=bool),
|
|
73
73
|
)
|
|
74
74
|
info = info.update(**dataclasses.asdict(brax_state))
|
|
75
75
|
return brax_state, info
|
|
@@ -10,7 +10,7 @@ from craftax.craftax_classic.envs.craftax_state import (
|
|
|
10
10
|
from craftax.craftax_env import make_craftax_env_from_name
|
|
11
11
|
|
|
12
12
|
from envelope import spaces as envelope_spaces
|
|
13
|
-
from envelope.
|
|
13
|
+
from envelope.adapters.gymnax_envelope import _convert_space as _convert_gymnax_space
|
|
14
14
|
from envelope.environment import Environment, Info, InfoContainer, State
|
|
15
15
|
from envelope.struct import Container, static_field
|
|
16
16
|
from envelope.typing import Key, PyTree, TypeAlias
|
|
@@ -81,7 +81,7 @@ class JumanjiEnvelope(Environment):
|
|
|
81
81
|
|
|
82
82
|
|
|
83
83
|
def convert_jumanji_to_envelope_info(timestep: JumanjiTimeStep) -> InfoContainer:
|
|
84
|
-
term = jnp.asarray(timestep.last(), dtype=bool)
|
|
84
|
+
term = jnp.asarray(timestep.last(), dtype=bool)
|
|
85
85
|
info = InfoContainer(
|
|
86
86
|
obs=timestep.observation, reward=timestep.reward, terminated=term
|
|
87
87
|
).update(**timestep.extras)
|
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
"""Kinetix compatibility wrapper.
|
|
2
2
|
|
|
3
3
|
This module exposes Kinetix environments through the `envelope.environment.Environment`
|
|
4
|
-
API. It mirrors envelope's
|
|
4
|
+
API. It mirrors envelope's adapters philosophy:
|
|
5
5
|
- prefer *no* environment-side auto-reset (use `AutoResetWrapper` in envelope)
|
|
6
6
|
- prefer *no* fixed episode time-limits (use `TruncationWrapper` in envelope)
|
|
7
7
|
|
|
@@ -30,7 +30,7 @@ from kinetix.util.saving import load_from_json_file
|
|
|
30
30
|
|
|
31
31
|
from envelope import field
|
|
32
32
|
from envelope import spaces as envelope_spaces
|
|
33
|
-
from envelope.
|
|
33
|
+
from envelope.adapters.gymnax_envelope import _convert_space as _convert_gymnax_space
|
|
34
34
|
from envelope.environment import Environment, Info, InfoContainer, State
|
|
35
35
|
from envelope.struct import Container, static_field
|
|
36
36
|
from envelope.typing import Key, PyTree
|
|
@@ -42,7 +42,7 @@ class Environment(ABC, FrozenPyTreeNode):
|
|
|
42
42
|
|
|
43
43
|
Two distinct lifecycle methods:
|
|
44
44
|
init(key) - Initialize environment and all state from scratch.
|
|
45
|
-
reset(
|
|
45
|
+
reset(state, key) - Reset the inner environment while preserving
|
|
46
46
|
episode-persistent state.
|
|
47
47
|
"""
|
|
48
48
|
|
|
@@ -51,7 +51,7 @@ class Environment(ABC, FrozenPyTreeNode):
|
|
|
51
51
|
"""Initialize environment and all state from scratch."""
|
|
52
52
|
...
|
|
53
53
|
|
|
54
|
-
def reset(self,
|
|
54
|
+
def reset(self, state: State, key: Key) -> tuple[State, Info]:
|
|
55
55
|
"""Reset the inner environment while preserving episode-persistent state."""
|
|
56
56
|
return self.init(key)
|
|
57
57
|
|
|
@@ -54,7 +54,7 @@ class AutoResetWrapper(Wrapper):
|
|
|
54
54
|
return state, info.update(final=state.last_final)
|
|
55
55
|
|
|
56
56
|
@override
|
|
57
|
-
def reset(self,
|
|
57
|
+
def reset(self, state: WrappedState, key: Key) -> tuple[WrappedState, Info]:
|
|
58
58
|
raise NotImplementedError("Reset is not implemented for AutoResetWrapper")
|
|
59
59
|
|
|
60
60
|
@override
|
|
@@ -63,7 +63,7 @@ class AutoResetWrapper(Wrapper):
|
|
|
63
63
|
state = state.replace(reset_key=key)
|
|
64
64
|
|
|
65
65
|
inner_state, info = self.env.step(state.inner_state, action)
|
|
66
|
-
reset_inner_state, reset_info = self.env.reset(
|
|
66
|
+
reset_inner_state, reset_info = self.env.reset(inner_state, key_reset)
|
|
67
67
|
|
|
68
68
|
# Select next state and info based on done
|
|
69
69
|
done = info.terminated | info.truncated
|
{jax_envelope-0.2.0 → jax_envelope-0.4.0}/src/envelope/wrappers/continuous_observation_wrapper.py
RENAMED
|
@@ -35,8 +35,8 @@ class ContinuousObservationWrapper(Wrapper):
|
|
|
35
35
|
return state, info
|
|
36
36
|
|
|
37
37
|
@override
|
|
38
|
-
def reset(self,
|
|
39
|
-
state, info = self.env.reset(
|
|
38
|
+
def reset(self, state: State, key: Key) -> tuple[State, Info]:
|
|
39
|
+
state, info = self.env.reset(state, key)
|
|
40
40
|
info = info.update(obs=to_float(info.obs))
|
|
41
41
|
return state, info
|
|
42
42
|
|
{jax_envelope-0.2.0 → jax_envelope-0.4.0}/src/envelope/wrappers/episode_statistics_wrapper.py
RENAMED
|
@@ -24,8 +24,8 @@ class EpisodeStatisticsWrapper(Wrapper):
|
|
|
24
24
|
return state, info.update(stats=state.stats)
|
|
25
25
|
|
|
26
26
|
@override
|
|
27
|
-
def reset(self,
|
|
28
|
-
inner_state, info = self.env.reset(
|
|
27
|
+
def reset(self, state: State, key: Key) -> tuple[State, Info]:
|
|
28
|
+
inner_state, info = self.env.reset(state.inner_state, key)
|
|
29
29
|
state = state.replace(inner_state=inner_state)
|
|
30
30
|
return state, info.update(stats=state.stats)
|
|
31
31
|
|
{jax_envelope-0.2.0 → jax_envelope-0.4.0}/src/envelope/wrappers/flatten_observation_wrapper.py
RENAMED
|
@@ -37,10 +37,9 @@ class FlattenObservationWrapper(Wrapper):
|
|
|
37
37
|
return state, info
|
|
38
38
|
|
|
39
39
|
@override
|
|
40
|
-
def reset(self,
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
return state, info
|
|
40
|
+
def reset(self, state: State, key: Key) -> tuple[State, Info]:
|
|
41
|
+
next_state, info = self.env.reset(state, key)
|
|
42
|
+
return next_state, info.update(obs=flatten_x(info.obs))
|
|
44
43
|
|
|
45
44
|
@override
|
|
46
45
|
def step(self, state: State, action: PyTree) -> tuple[State, Info]:
|
{jax_envelope-0.2.0 → jax_envelope-0.4.0}/src/envelope/wrappers/observation_normalization_wrapper.py
RENAMED
|
@@ -76,8 +76,8 @@ class ObservationNormalizationWrapper(Wrapper):
|
|
|
76
76
|
return self._normalize_and_update(next_state, info)
|
|
77
77
|
|
|
78
78
|
@override
|
|
79
|
-
def reset(self,
|
|
80
|
-
inner_state, info = self.env.reset(
|
|
79
|
+
def reset(self, state: WrappedState, key: Key) -> tuple[WrappedState, Info]:
|
|
80
|
+
inner_state, info = self.env.reset(state.inner_state, key)
|
|
81
81
|
# Preserve running statistics across resets
|
|
82
82
|
next_state = self.ObservationNormalizationState(
|
|
83
83
|
inner_state=inner_state, rmv_state=state.rmv_state
|
|
@@ -36,7 +36,7 @@ class PooledInitVmapWrapper(Wrapper):
|
|
|
36
36
|
return state, info.update(final=pholder_info)
|
|
37
37
|
|
|
38
38
|
@override
|
|
39
|
-
def reset(self,
|
|
39
|
+
def reset(self, state: WrappedState, key: Key) -> tuple[WrappedState, Info]:
|
|
40
40
|
# It's hard to support reset for this wrapper.
|
|
41
41
|
# We would have to init the state of a pool of unwrapped environments, and then
|
|
42
42
|
# somehow inject this into the stack of wrapped states. The current data
|
|
@@ -48,7 +48,7 @@ class PooledInitVmapWrapper(Wrapper):
|
|
|
48
48
|
# episodes before vmapping, we will implement this later.
|
|
49
49
|
keys = _split_or_keep_key(key, self.batch_size + 1)
|
|
50
50
|
key_next, keys_pool = keys[0], keys[1:]
|
|
51
|
-
inner_state, info = jax.vmap(self.env.reset)(
|
|
51
|
+
inner_state, info = jax.vmap(self.env.reset)(state.inner_state, keys_pool)
|
|
52
52
|
state = state.replace(inner_state=inner_state, init_key=key_next)
|
|
53
53
|
return state, info.update(final=state.last_final)
|
|
54
54
|
|
|
@@ -69,13 +69,13 @@ class StateInjectionWrapper(Wrapper):
|
|
|
69
69
|
return state, info
|
|
70
70
|
|
|
71
71
|
@override
|
|
72
|
-
def reset(self,
|
|
72
|
+
def reset(self, state: WrappedState, key: Key) -> tuple[WrappedState, Info]:
|
|
73
73
|
# If reset state is set, use it instead of resetting inner env
|
|
74
74
|
if state.reset_state is not None and state.reset_obs is not None:
|
|
75
75
|
inner_state = state.reset_state
|
|
76
76
|
info = InfoContainer(obs=state.reset_obs, reward=0.0, terminated=False)
|
|
77
77
|
elif state.reset_state is None and state.reset_obs is None:
|
|
78
|
-
inner_state, info = self.env.reset(
|
|
78
|
+
inner_state, info = self.env.reset(state.inner_state, key)
|
|
79
79
|
else:
|
|
80
80
|
raise ValueError("State must set both reset_state and reset_obs or neither")
|
|
81
81
|
|
|
@@ -21,8 +21,8 @@ class TruncationWrapper(Wrapper):
|
|
|
21
21
|
return state, info.update(truncated=self.max_steps <= 0)
|
|
22
22
|
|
|
23
23
|
@override
|
|
24
|
-
def reset(self,
|
|
25
|
-
inner_state, info = self.env.reset(
|
|
24
|
+
def reset(self, state: WrappedState, key: Key) -> tuple[WrappedState, Info]:
|
|
25
|
+
inner_state, info = self.env.reset(state.inner_state, key)
|
|
26
26
|
state = state.replace(inner_state=inner_state, steps=0)
|
|
27
27
|
return state, info.update(truncated=self.max_steps <= 0)
|
|
28
28
|
|
|
@@ -40,11 +40,9 @@ class VmapEnvsWrapper(Wrapper):
|
|
|
40
40
|
return state, info
|
|
41
41
|
|
|
42
42
|
@override
|
|
43
|
-
def reset(self,
|
|
43
|
+
def reset(self, state: PyTree, key: Key) -> tuple[WrappedState, Info]:
|
|
44
44
|
keys = self._split_keys(key)
|
|
45
|
-
state, info = jax.vmap(lambda e,
|
|
46
|
-
self.env, keys, state
|
|
47
|
-
)
|
|
45
|
+
state, info = jax.vmap(lambda e, s, k: e.reset(s, k))(self.env, state, keys)
|
|
48
46
|
return state, info
|
|
49
47
|
|
|
50
48
|
@override
|
|
@@ -58,7 +56,9 @@ class VmapEnvsWrapper(Wrapper):
|
|
|
58
56
|
@property
|
|
59
57
|
def observation_space(self) -> spaces.Space:
|
|
60
58
|
env0 = _index_env(self.env, 0, self.batch_size)
|
|
61
|
-
return spaces.BatchedSpace(
|
|
59
|
+
return spaces.BatchedSpace(
|
|
60
|
+
space=env0.observation_space, batch_size=self.batch_size
|
|
61
|
+
)
|
|
62
62
|
|
|
63
63
|
@override
|
|
64
64
|
@cached_property
|
|
@@ -41,9 +41,9 @@ class VmapWrapper(Wrapper):
|
|
|
41
41
|
return state, info
|
|
42
42
|
|
|
43
43
|
@override
|
|
44
|
-
def reset(self,
|
|
44
|
+
def reset(self, state: PyTree, key: Key) -> tuple[WrappedState, Info]:
|
|
45
45
|
keys = _split_or_keep_key(key, self.batch_size)
|
|
46
|
-
state, info = jax.vmap(self.env.reset)(
|
|
46
|
+
state, info = jax.vmap(self.env.reset)(state, keys)
|
|
47
47
|
return state, info
|
|
48
48
|
|
|
49
49
|
@override
|
|
@@ -29,11 +29,11 @@ class Wrapper(Environment):
|
|
|
29
29
|
return self.env.init(key)
|
|
30
30
|
|
|
31
31
|
@override
|
|
32
|
-
def reset(self,
|
|
33
|
-
return self.env.reset(
|
|
32
|
+
def reset(self, state: State, key: Key) -> tuple[State, Info]:
|
|
33
|
+
return self.env.reset(state, key)
|
|
34
34
|
|
|
35
35
|
@override
|
|
36
|
-
def step(self, state:
|
|
36
|
+
def step(self, state: State, action: PyTree) -> tuple[State, Info]:
|
|
37
37
|
return self.env.step(state, action)
|
|
38
38
|
|
|
39
39
|
@override
|
|
@@ -1,6 +1,6 @@
|
|
|
1
|
-
"""Shared contract helpers for
|
|
1
|
+
"""Shared contract helpers for adapters.
|
|
2
2
|
|
|
3
|
-
These functions enforce a consistent baseline across all
|
|
3
|
+
These functions enforce a consistent baseline across all adapters:
|
|
4
4
|
- reset/step return (state, info) with Info fields present
|
|
5
5
|
- reward is scalar-ish and finite
|
|
6
6
|
- action sampling is valid for action_space
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
"""Tests for envelope.
|
|
1
|
+
"""Tests for envelope.adapters.brax_envelope module."""
|
|
2
2
|
|
|
3
3
|
# ruff: noqa: E402
|
|
4
4
|
|
|
@@ -7,14 +7,14 @@ from copy import deepcopy
|
|
|
7
7
|
import jax
|
|
8
8
|
import pytest
|
|
9
9
|
|
|
10
|
-
pytestmark = pytest.mark.
|
|
10
|
+
pytestmark = pytest.mark.adapters
|
|
11
11
|
|
|
12
12
|
pytest.importorskip("brax")
|
|
13
13
|
|
|
14
14
|
from brax.envs import Wrapper as BraxWrapper
|
|
15
15
|
|
|
16
|
-
from envelope.
|
|
17
|
-
from tests.
|
|
16
|
+
from envelope.adapters.brax_envelope import BraxEnvelope
|
|
17
|
+
from tests.adapters.contract import (
|
|
18
18
|
assert_jitted_rollout_contract,
|
|
19
19
|
assert_reset_step_contract,
|
|
20
20
|
)
|
|
@@ -93,8 +93,8 @@ def test_wrapper_unwrapping():
|
|
|
93
93
|
|
|
94
94
|
# Create a simple wrapper
|
|
95
95
|
class SimpleWrapper(BraxWrapper):
|
|
96
|
-
def
|
|
97
|
-
return self.env.
|
|
96
|
+
def init(self, rng):
|
|
97
|
+
return self.env.init(rng)
|
|
98
98
|
|
|
99
99
|
def step(self, state, action):
|
|
100
100
|
return self.env.step(state, action)
|
{jax_envelope-0.2.0/tests/compat → jax_envelope-0.4.0/tests/adapters}/test_craftax_compat.py
RENAMED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
"""Tests for envelope.
|
|
1
|
+
"""Tests for envelope.adapters.craftax_envelope module."""
|
|
2
2
|
|
|
3
3
|
# ruff: noqa: E402
|
|
4
4
|
|
|
@@ -8,12 +8,12 @@ import jax
|
|
|
8
8
|
import jax.numpy as jnp
|
|
9
9
|
import pytest
|
|
10
10
|
|
|
11
|
-
pytestmark = pytest.mark.
|
|
11
|
+
pytestmark = pytest.mark.adapters
|
|
12
12
|
|
|
13
13
|
pytest.importorskip("craftax")
|
|
14
14
|
|
|
15
15
|
from envelope.spaces import Continuous, Discrete
|
|
16
|
-
from tests.
|
|
16
|
+
from tests.adapters.contract import (
|
|
17
17
|
assert_jitted_rollout_contract,
|
|
18
18
|
assert_reset_step_contract,
|
|
19
19
|
)
|
|
@@ -35,7 +35,7 @@ def craftax_env_id(request: pytest.FixtureRequest) -> str:
|
|
|
35
35
|
|
|
36
36
|
@pytest.fixture(scope="module")
|
|
37
37
|
def craftax_env(craftax_env_id: str):
|
|
38
|
-
from envelope.
|
|
38
|
+
from envelope.adapters.craftax_envelope import CraftaxEnvelope
|
|
39
39
|
|
|
40
40
|
return CraftaxEnvelope.from_name(craftax_env_id)
|
|
41
41
|
|
|
@@ -114,7 +114,7 @@ class _DummyEnv:
|
|
|
114
114
|
|
|
115
115
|
|
|
116
116
|
def test_from_name_errors_on_auto_reset():
|
|
117
|
-
from envelope.
|
|
117
|
+
from envelope.adapters.craftax_envelope import CraftaxEnvelope
|
|
118
118
|
|
|
119
119
|
with pytest.raises(ValueError, match="Cannot override 'auto_reset' directly"):
|
|
120
120
|
CraftaxEnvelope.from_name("AnyEnv", env_kwargs={"auto_reset": True})
|