jax-envelope 0.1.1__tar.gz → 0.3.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {jax_envelope-0.1.1 → jax_envelope-0.3.0}/.gitignore +2 -1
- {jax_envelope-0.1.1 → jax_envelope-0.3.0}/PKG-INFO +34 -23
- {jax_envelope-0.1.1 → jax_envelope-0.3.0}/README.md +33 -22
- {jax_envelope-0.1.1 → jax_envelope-0.3.0}/pyproject.toml +13 -6
- {jax_envelope-0.1.1 → jax_envelope-0.3.0}/src/envelope/__init__.py +16 -4
- {jax_envelope-0.1.1 → jax_envelope-0.3.0}/src/envelope/compat/brax_envelope.py +5 -3
- {jax_envelope-0.1.1 → jax_envelope-0.3.0}/src/envelope/compat/craftax_envelope.py +17 -2
- {jax_envelope-0.1.1 → jax_envelope-0.3.0}/src/envelope/compat/gymnax_envelope.py +34 -7
- {jax_envelope-0.1.1 → jax_envelope-0.3.0}/src/envelope/compat/jumanji_envelope.py +3 -2
- {jax_envelope-0.1.1 → jax_envelope-0.3.0}/src/envelope/compat/kinetix_envelope.py +3 -2
- {jax_envelope-0.1.1 → jax_envelope-0.3.0}/src/envelope/compat/mujoco_playground_envelope.py +1 -1
- {jax_envelope-0.1.1 → jax_envelope-0.3.0}/src/envelope/compat/navix_envelope.py +1 -1
- {jax_envelope-0.1.1 → jax_envelope-0.3.0}/src/envelope/environment.py +16 -9
- {jax_envelope-0.1.1 → jax_envelope-0.3.0}/src/envelope/spaces.py +40 -20
- {jax_envelope-0.1.1 → jax_envelope-0.3.0}/src/envelope/struct.py +10 -1
- jax_envelope-0.3.0/src/envelope/wrappers/__init__.py +36 -0
- jax_envelope-0.3.0/src/envelope/wrappers/autoreset_wrapper.py +80 -0
- jax_envelope-0.3.0/src/envelope/wrappers/clip_action_wrapper.py +27 -0
- jax_envelope-0.3.0/src/envelope/wrappers/continuous_observation_wrapper.py +61 -0
- jax_envelope-0.3.0/src/envelope/wrappers/episode_statistics_wrapper.py +40 -0
- jax_envelope-0.3.0/src/envelope/wrappers/flatten_action_wrapper.py +75 -0
- jax_envelope-0.3.0/src/envelope/wrappers/flatten_observation_wrapper.py +80 -0
- {jax_envelope-0.1.1 → jax_envelope-0.3.0}/src/envelope/wrappers/normalization.py +1 -1
- {jax_envelope-0.1.1 → jax_envelope-0.3.0}/src/envelope/wrappers/observation_normalization_wrapper.py +28 -16
- jax_envelope-0.3.0/src/envelope/wrappers/pooled_init_vmap_wrapper.py +122 -0
- {jax_envelope-0.1.1 → jax_envelope-0.3.0}/src/envelope/wrappers/state_injection_wrapper.py +18 -22
- jax_envelope-0.3.0/src/envelope/wrappers/truncation_wrapper.py +35 -0
- {jax_envelope-0.1.1 → jax_envelope-0.3.0}/src/envelope/wrappers/vmap_envs_wrapper.py +26 -21
- jax_envelope-0.3.0/src/envelope/wrappers/vmap_wrapper.py +66 -0
- {jax_envelope-0.1.1 → jax_envelope-0.3.0}/src/envelope/wrappers/wrapper.py +8 -8
- {jax_envelope-0.1.1 → jax_envelope-0.3.0}/tests/compat/conftest.py +1 -1
- {jax_envelope-0.1.1 → jax_envelope-0.3.0}/tests/compat/contract.py +5 -2
- {jax_envelope-0.1.1 → jax_envelope-0.3.0}/tests/compat/test_brax_compat.py +6 -6
- {jax_envelope-0.1.1 → jax_envelope-0.3.0}/tests/compat/test_craftax_compat.py +4 -2
- {jax_envelope-0.1.1 → jax_envelope-0.3.0}/tests/compat/test_create.py +12 -0
- {jax_envelope-0.1.1 → jax_envelope-0.3.0}/tests/compat/test_create_integration.py +7 -7
- {jax_envelope-0.1.1 → jax_envelope-0.3.0}/tests/compat/test_gymnax_compat.py +3 -3
- {jax_envelope-0.1.1 → jax_envelope-0.3.0}/tests/compat/test_jumanji_compat.py +1 -1
- {jax_envelope-0.1.1 → jax_envelope-0.3.0}/tests/compat/test_kinetix_compat.py +7 -9
- {jax_envelope-0.1.1 → jax_envelope-0.3.0}/tests/compat/test_mujoco_playground_compat.py +6 -6
- {jax_envelope-0.1.1 → jax_envelope-0.3.0}/tests/compat/test_navix_compat.py +5 -5
- {jax_envelope-0.1.1 → jax_envelope-0.3.0}/tests/spaces/test_batched_space.py +74 -50
- {jax_envelope-0.1.1 → jax_envelope-0.3.0}/tests/spaces/test_pytree_space.py +21 -1
- {jax_envelope-0.1.1 → jax_envelope-0.3.0}/tests/wrappers/helpers.py +112 -36
- {jax_envelope-0.1.1 → jax_envelope-0.3.0}/tests/wrappers/test_autoreset_wrapper.py +164 -67
- jax_envelope-0.3.0/tests/wrappers/test_clip_action_wrapper.py +174 -0
- jax_envelope-0.3.0/tests/wrappers/test_continuous_observation_wrapper.py +153 -0
- {jax_envelope-0.1.1 → jax_envelope-0.3.0}/tests/wrappers/test_environment_wrapper.py +12 -12
- jax_envelope-0.3.0/tests/wrappers/test_episode_statistics_wrapper.py +183 -0
- jax_envelope-0.3.0/tests/wrappers/test_flatten_action_wrapper.py +215 -0
- jax_envelope-0.3.0/tests/wrappers/test_flatten_observation_wrapper.py +178 -0
- {jax_envelope-0.1.1 → jax_envelope-0.3.0}/tests/wrappers/test_observation_normalization_wrapper.py +9 -9
- jax_envelope-0.3.0/tests/wrappers/test_pooled_init_vmap_wrapper.py +292 -0
- {jax_envelope-0.1.1 → jax_envelope-0.3.0}/tests/wrappers/test_state_injection_wrapper.py +19 -19
- {jax_envelope-0.1.1 → jax_envelope-0.3.0}/tests/wrappers/test_truncation_wrapper.py +39 -14
- {jax_envelope-0.1.1 → jax_envelope-0.3.0}/tests/wrappers/test_vmap_envs_wrapper.py +7 -7
- {jax_envelope-0.1.1 → jax_envelope-0.3.0}/tests/wrappers/test_vmap_wrapper.py +18 -18
- {jax_envelope-0.1.1 → jax_envelope-0.3.0}/uv.lock +818 -711
- jax_envelope-0.1.1/src/envelope/wrappers/__init__.py +0 -20
- jax_envelope-0.1.1/src/envelope/wrappers/autoreset_wrapper.py +0 -36
- jax_envelope-0.1.1/src/envelope/wrappers/episode_statistics_wrapper.py +0 -47
- jax_envelope-0.1.1/src/envelope/wrappers/truncation_wrapper.py +0 -31
- jax_envelope-0.1.1/src/envelope/wrappers/vmap_wrapper.py +0 -51
- {jax_envelope-0.1.1 → jax_envelope-0.3.0}/.github/workflows/publish.yml +0 -0
- {jax_envelope-0.1.1 → jax_envelope-0.3.0}/LICENSE +0 -0
- {jax_envelope-0.1.1 → jax_envelope-0.3.0}/src/envelope/compat/__init__.py +0 -0
- {jax_envelope-0.1.1 → jax_envelope-0.3.0}/src/envelope/typing.py +0 -0
- {jax_envelope-0.1.1 → jax_envelope-0.3.0}/tests/__init__.py +0 -0
- {jax_envelope-0.1.1 → jax_envelope-0.3.0}/tests/compat/__init__.py +0 -0
- {jax_envelope-0.1.1 → jax_envelope-0.3.0}/tests/spaces/__init__.py +0 -0
- {jax_envelope-0.1.1 → jax_envelope-0.3.0}/tests/spaces/test_continuous.py +0 -0
- {jax_envelope-0.1.1 → jax_envelope-0.3.0}/tests/spaces/test_discrete.py +0 -0
- {jax_envelope-0.1.1 → jax_envelope-0.3.0}/tests/spaces/test_serialization.py +0 -0
- {jax_envelope-0.1.1 → jax_envelope-0.3.0}/tests/test_container.py +0 -0
- {jax_envelope-0.1.1 → jax_envelope-0.3.0}/tests/test_struct.py +0 -0
- {jax_envelope-0.1.1 → jax_envelope-0.3.0}/tests/wrappers/__init__.py +0 -0
- {jax_envelope-0.1.1 → jax_envelope-0.3.0}/tests/wrappers/test_normalization.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: jax-envelope
|
|
3
|
-
Version: 0.
|
|
3
|
+
Version: 0.3.0
|
|
4
4
|
Summary: A JAX-native environment interface with powerful wrappers and adapters for popular RL environment suites
|
|
5
5
|
Project-URL: Homepage, https://github.com/keraJLi/envelope
|
|
6
6
|
Project-URL: Repository, https://github.com/keraJLi/envelope
|
|
@@ -25,12 +25,13 @@ Requires-Dist: jax>=0.5.0
|
|
|
25
25
|
Description-Content-Type: text/markdown
|
|
26
26
|
|
|
27
27
|
# 💌 Envelope: a JAX-native environment interface
|
|
28
|
+
|
|
28
29
|
```python
|
|
29
30
|
# Create environments from JAX-native suites you have installed, ...
|
|
30
31
|
env = envelope.create("gymnax::CartPole-v1")
|
|
31
32
|
|
|
32
33
|
# ... interact with the environments using a simple interface, ...
|
|
33
|
-
state, info = env.
|
|
34
|
+
state, info = env.init(key)
|
|
34
35
|
states, infos = jax.lax.scan(env.step, state, actions)
|
|
35
36
|
plt.plot(infos.reward.cumsum())
|
|
36
37
|
|
|
@@ -41,35 +42,42 @@ env = envelope.wrappers.ObservationNormalizationWrapper(env)
|
|
|
41
42
|
```
|
|
42
43
|
|
|
43
44
|
## 🌍 Simple, expressive interaction!
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
45
|
+
- **Environments are pytrees**. Squish them through JAX transformations and trace their parameters.
|
|
46
|
+
- **Idiomatic jax-y interface** of `init(key: Key) -> State, Info` and `step(state: State, action: PyTree) -> State, Info`. You can directly `jax.scan` over a `step(...)`!
|
|
47
|
+
- **Spaces are super simple**. No `Tuple`, `Dict` nonsense! There are two spaces: `Continuous` and `Discrete`, which you can compose into a `PyTreeSpace`.
|
|
48
|
+
- **Explicit episode truncation** supports correctly handling bootstrapping for value-function targets.
|
|
49
|
+
- **No auto-reset** by default. Resetting every step can be expensive!
|
|
49
50
|
|
|
50
51
|
## 💪 Powerful, composable wrappers!
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
52
|
+
|
|
53
|
+
- **Carry state across episodes** to track running statistics, for example to normalize observations.
|
|
54
|
+
- **Composable wrappers** can be stacked in any order. For example, `ObservationNormalizationWrapper` before vs. after `VmapWrapper` gives per-env vs. global normalization.
|
|
55
|
+
|
|
56
|
+
|
|
54
57
|
|
|
55
58
|
## 🔌 Adapters for existing suites
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
|
59
|
-
|
|
|
60
|
-
| [
|
|
61
|
-
| [
|
|
62
|
-
| [
|
|
63
|
-
| [
|
|
64
|
-
| | |
|
|
65
|
-
|
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
| 📦 | # 🤖 | # 🌍 |
|
|
62
|
+
| ------------------------------------------------------------------------- | ------- | ------- |
|
|
63
|
+
| [gymnax](https://github.com/RobertTLange/gymnax) | 🕺 | 24 |
|
|
64
|
+
| [brax](https://github.com/google/brax) | 🕺 | 12 |
|
|
65
|
+
| [jumanji](https://github.com/instadeepai/jumanji) | 🕺 / 👯 | 25 / 1 |
|
|
66
|
+
| [kinetix](https://github.com/flairox/kinetix) | 🕺 | 74 |
|
|
67
|
+
| [craftax](https://github.com/MichaelTMatthews/craftax) | 🕺 | 4 |
|
|
68
|
+
| [mujoco_playground](https://github.com/google-deepmind/mujoco_playground) | 🕺 | 54 |
|
|
69
|
+
| | | |
|
|
70
|
+
| Total | 🕺 / 👯 | 193 / 1 |
|
|
71
|
+
|
|
66
72
|
|
|
67
73
|
```python
|
|
68
74
|
envelope.create("📦::🌍")
|
|
69
75
|
```
|
|
76
|
+
|
|
70
77
|
let's you create environments from any of the above!
|
|
71
78
|
|
|
72
79
|
## 📝 Testing
|
|
80
|
+
|
|
73
81
|
- **Default (no optional compat deps required)**: `uv run pytest -m "not compat"`
|
|
74
82
|
- **Compat suite (requires full compat dependency group)**:
|
|
75
83
|
- `uv sync --group compat`
|
|
@@ -77,11 +85,14 @@ let's you create environments from any of the above!
|
|
|
77
85
|
- If any compat dependency is missing/broken, the run will fail fast with an error telling you what to install.
|
|
78
86
|
|
|
79
87
|
## 🏗️ Installation
|
|
88
|
+
|
|
80
89
|
```bash
|
|
81
90
|
pip install jax-envelope
|
|
82
91
|
```
|
|
83
92
|
|
|
84
93
|
## 💞 Related projects
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
94
|
+
|
|
95
|
+
- [stoa](https://github.com/EdanToledo/Stoa) is a very similar project that provides adapters and wrappers for the jumanji-like interface.
|
|
96
|
+
- Check out all the great suites we have adapters for! [gymnax](https://github.com/RobertTLange/gymnax), [brax](https://github.com/google/brax), [jumanji](https://github.com/instadeepai/jumanji), [kinetix](https://github.com/flairox/kinetix), [craftax](https://github.com/MichaelTMatthews/craftax), [mujoco_playground](https://github.com/google-deepmind/mujoco_playground).
|
|
97
|
+
- We will be adding support for [jaxmarl](https://github.com/flairox/jaxmarl) and [pgx](https://github.com/sotetsuk/pgx) in the future, as soon as we figured out the best ever MARL interface for JAX!
|
|
98
|
+
|
|
@@ -1,10 +1,11 @@
|
|
|
1
1
|
# 💌 Envelope: a JAX-native environment interface
|
|
2
|
+
|
|
2
3
|
```python
|
|
3
4
|
# Create environments from JAX-native suites you have installed, ...
|
|
4
5
|
env = envelope.create("gymnax::CartPole-v1")
|
|
5
6
|
|
|
6
7
|
# ... interact with the environments using a simple interface, ...
|
|
7
|
-
state, info = env.
|
|
8
|
+
state, info = env.init(key)
|
|
8
9
|
states, infos = jax.lax.scan(env.step, state, actions)
|
|
9
10
|
plt.plot(infos.reward.cumsum())
|
|
10
11
|
|
|
@@ -15,35 +16,42 @@ env = envelope.wrappers.ObservationNormalizationWrapper(env)
|
|
|
15
16
|
```
|
|
16
17
|
|
|
17
18
|
## 🌍 Simple, expressive interaction!
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
19
|
+
- **Environments are pytrees**. Squish them through JAX transformations and trace their parameters.
|
|
20
|
+
- **Idiomatic jax-y interface** of `init(key: Key) -> State, Info` and `step(state: State, action: PyTree) -> State, Info`. You can directly `jax.scan` over a `step(...)`!
|
|
21
|
+
- **Spaces are super simple**. No `Tuple`, `Dict` nonsense! There are two spaces: `Continuous` and `Discrete`, which you can compose into a `PyTreeSpace`.
|
|
22
|
+
- **Explicit episode truncation** supports correctly handling bootstrapping for value-function targets.
|
|
23
|
+
- **No auto-reset** by default. Resetting every step can be expensive!
|
|
23
24
|
|
|
24
25
|
## 💪 Powerful, composable wrappers!
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
26
|
+
|
|
27
|
+
- **Carry state across episodes** to track running statistics, for example to normalize observations.
|
|
28
|
+
- **Composable wrappers** can be stacked in any order. For example, `ObservationNormalizationWrapper` before vs. after `VmapWrapper` gives per-env vs. global normalization.
|
|
29
|
+
|
|
30
|
+
|
|
28
31
|
|
|
29
32
|
## 🔌 Adapters for existing suites
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
|
33
|
-
|
|
|
34
|
-
| [
|
|
35
|
-
| [
|
|
36
|
-
| [
|
|
37
|
-
| [
|
|
38
|
-
| | |
|
|
39
|
-
|
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
| 📦 | # 🤖 | # 🌍 |
|
|
36
|
+
| ------------------------------------------------------------------------- | ------- | ------- |
|
|
37
|
+
| [gymnax](https://github.com/RobertTLange/gymnax) | 🕺 | 24 |
|
|
38
|
+
| [brax](https://github.com/google/brax) | 🕺 | 12 |
|
|
39
|
+
| [jumanji](https://github.com/instadeepai/jumanji) | 🕺 / 👯 | 25 / 1 |
|
|
40
|
+
| [kinetix](https://github.com/flairox/kinetix) | 🕺 | 74 |
|
|
41
|
+
| [craftax](https://github.com/MichaelTMatthews/craftax) | 🕺 | 4 |
|
|
42
|
+
| [mujoco_playground](https://github.com/google-deepmind/mujoco_playground) | 🕺 | 54 |
|
|
43
|
+
| | | |
|
|
44
|
+
| Total | 🕺 / 👯 | 193 / 1 |
|
|
45
|
+
|
|
40
46
|
|
|
41
47
|
```python
|
|
42
48
|
envelope.create("📦::🌍")
|
|
43
49
|
```
|
|
50
|
+
|
|
44
51
|
let's you create environments from any of the above!
|
|
45
52
|
|
|
46
53
|
## 📝 Testing
|
|
54
|
+
|
|
47
55
|
- **Default (no optional compat deps required)**: `uv run pytest -m "not compat"`
|
|
48
56
|
- **Compat suite (requires full compat dependency group)**:
|
|
49
57
|
- `uv sync --group compat`
|
|
@@ -51,11 +59,14 @@ let's you create environments from any of the above!
|
|
|
51
59
|
- If any compat dependency is missing/broken, the run will fail fast with an error telling you what to install.
|
|
52
60
|
|
|
53
61
|
## 🏗️ Installation
|
|
62
|
+
|
|
54
63
|
```bash
|
|
55
64
|
pip install jax-envelope
|
|
56
65
|
```
|
|
57
66
|
|
|
58
67
|
## 💞 Related projects
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
68
|
+
|
|
69
|
+
- [stoa](https://github.com/EdanToledo/Stoa) is a very similar project that provides adapters and wrappers for the jumanji-like interface.
|
|
70
|
+
- Check out all the great suites we have adapters for! [gymnax](https://github.com/RobertTLange/gymnax), [brax](https://github.com/google/brax), [jumanji](https://github.com/instadeepai/jumanji), [kinetix](https://github.com/flairox/kinetix), [craftax](https://github.com/MichaelTMatthews/craftax), [mujoco_playground](https://github.com/google-deepmind/mujoco_playground).
|
|
71
|
+
- We will be adding support for [jaxmarl](https://github.com/flairox/jaxmarl) and [pgx](https://github.com/sotetsuk/pgx) in the future, as soon as we figured out the best ever MARL interface for JAX!
|
|
72
|
+
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
[project]
|
|
2
2
|
name = "jax-envelope"
|
|
3
|
-
version = "0.
|
|
3
|
+
version = "0.3.0"
|
|
4
4
|
description = "A JAX-native environment interface with powerful wrappers and adapters for popular RL environment suites"
|
|
5
5
|
readme = "README.md"
|
|
6
6
|
requires-python = ">=3.12"
|
|
@@ -51,18 +51,15 @@ allow-direct-references = true
|
|
|
51
51
|
[tool.hatch.build.targets.wheel]
|
|
52
52
|
packages = ["src/envelope"]
|
|
53
53
|
|
|
54
|
-
[tool.uv.sources]
|
|
55
|
-
gymnax = { git = "https://github.com/RobertTLange/gymnax" }
|
|
56
|
-
|
|
57
54
|
[dependency-groups]
|
|
58
55
|
compat = [
|
|
59
|
-
"gymnax @ git+https://github.com/RobertTLange/gymnax@main",
|
|
60
56
|
"brax>=0.13.0",
|
|
61
57
|
"craftax>=1.4.3",
|
|
62
58
|
"navix>=0.7.0",
|
|
63
59
|
"jumanji>=1.0.1",
|
|
64
|
-
"kinetix-env>=2.0.0",
|
|
65
60
|
"playground>=0.1.0",
|
|
61
|
+
"gymnax",
|
|
62
|
+
"kinetix-env",
|
|
66
63
|
]
|
|
67
64
|
dev = [
|
|
68
65
|
"hypothesis>=6.148.1",
|
|
@@ -71,7 +68,17 @@ dev = [
|
|
|
71
68
|
"ruff>=0.14.2",
|
|
72
69
|
]
|
|
73
70
|
|
|
71
|
+
[tool.uv]
|
|
72
|
+
override-dependencies = [
|
|
73
|
+
"tensorflow-probability>=0.26.0.dev20260116",
|
|
74
|
+
]
|
|
75
|
+
|
|
76
|
+
[tool.uv.sources]
|
|
77
|
+
gymnax = { git = "https://github.com/RobertTLange/gymnax.git" }
|
|
78
|
+
kinetix-env = { git = "https://github.com/FLAIROx/Kinetix.git" }
|
|
79
|
+
|
|
74
80
|
[tool.pytest.ini_options]
|
|
81
|
+
testpaths = ["tests"]
|
|
75
82
|
markers = [
|
|
76
83
|
"compat: tests requiring optional compat dependencies",
|
|
77
84
|
]
|
|
@@ -1,16 +1,22 @@
|
|
|
1
1
|
from envelope.compat import create
|
|
2
2
|
from envelope.environment import Environment, Info, InfoContainer
|
|
3
3
|
from envelope.spaces import BatchedSpace, Continuous, Discrete, PyTreeSpace, Space
|
|
4
|
-
from envelope.struct import
|
|
4
|
+
from envelope.struct import Container, FrozenPyTreeNode, field, static_field
|
|
5
5
|
from envelope.wrappers import (
|
|
6
|
-
Wrapper,
|
|
7
|
-
WrappedState,
|
|
8
6
|
AutoResetWrapper,
|
|
7
|
+
ClipActionWrapper,
|
|
8
|
+
ContinuousObservationWrapper,
|
|
9
|
+
EpisodeStatisticsWrapper,
|
|
10
|
+
FlattenActionWrapper,
|
|
11
|
+
FlattenObservationWrapper,
|
|
9
12
|
ObservationNormalizationWrapper,
|
|
13
|
+
PooledInitVmapWrapper,
|
|
10
14
|
StateInjectionWrapper,
|
|
11
15
|
TruncationWrapper,
|
|
12
|
-
VmapWrapper,
|
|
13
16
|
VmapEnvsWrapper,
|
|
17
|
+
VmapWrapper,
|
|
18
|
+
WrappedState,
|
|
19
|
+
Wrapper,
|
|
14
20
|
)
|
|
15
21
|
|
|
16
22
|
__all__ = [
|
|
@@ -34,7 +40,13 @@ __all__ = [
|
|
|
34
40
|
"Wrapper",
|
|
35
41
|
"WrappedState",
|
|
36
42
|
"AutoResetWrapper",
|
|
43
|
+
"ClipActionWrapper",
|
|
44
|
+
"ContinuousObservationWrapper",
|
|
45
|
+
"EpisodeStatisticsWrapper",
|
|
46
|
+
"FlattenActionWrapper",
|
|
47
|
+
"FlattenObservationWrapper",
|
|
37
48
|
"ObservationNormalizationWrapper",
|
|
49
|
+
"PooledInitVmapWrapper",
|
|
38
50
|
"StateInjectionWrapper",
|
|
39
51
|
"TruncationWrapper",
|
|
40
52
|
"VmapWrapper",
|
|
@@ -48,7 +48,7 @@ class BraxEnvelope(Environment):
|
|
|
48
48
|
def default_max_steps(self) -> int:
|
|
49
49
|
return _BRAX_DEFAULT_EPISODE_LENGTH
|
|
50
50
|
|
|
51
|
-
def __post_init__(self)
|
|
51
|
+
def __post_init__(self):
|
|
52
52
|
if isinstance(self.brax_env, BraxWrapper):
|
|
53
53
|
warnings.warn(
|
|
54
54
|
"Environment wrapping should be handled by envelope. "
|
|
@@ -57,7 +57,7 @@ class BraxEnvelope(Environment):
|
|
|
57
57
|
object.__setattr__(self, "brax_env", self.brax_env.unwrapped)
|
|
58
58
|
|
|
59
59
|
@override
|
|
60
|
-
def
|
|
60
|
+
def init(self, key: Key) -> tuple[State, Info]:
|
|
61
61
|
brax_state = self.brax_env.reset(key)
|
|
62
62
|
info = InfoContainer(obs=brax_state.obs, reward=0.0, terminated=False)
|
|
63
63
|
info = info.update(**dataclasses.asdict(brax_state))
|
|
@@ -67,7 +67,9 @@ class BraxEnvelope(Environment):
|
|
|
67
67
|
def step(self, state: State, action: PyTree) -> tuple[State, Info]:
|
|
68
68
|
brax_state = self.brax_env.step(state, action)
|
|
69
69
|
info = InfoContainer(
|
|
70
|
-
obs=brax_state.obs,
|
|
70
|
+
obs=brax_state.obs,
|
|
71
|
+
reward=brax_state.reward,
|
|
72
|
+
terminated=jnp.asarray(brax_state.done, dtype=bool),
|
|
71
73
|
)
|
|
72
74
|
info = info.update(**dataclasses.asdict(brax_state))
|
|
73
75
|
return brax_state, info
|
|
@@ -22,7 +22,7 @@ class CraftaxEnvelope(Environment):
|
|
|
22
22
|
"""Wrapper to convert a Craftax environment to a envelope environment."""
|
|
23
23
|
|
|
24
24
|
craftax_env: Any = static_field()
|
|
25
|
-
env_params: PyTree
|
|
25
|
+
env_params: PyTree = static_field() # TODO: remove static marker as soon as craftax merges https://github.com/MichaelTMatthews/Craftax/pull/48
|
|
26
26
|
|
|
27
27
|
@classmethod
|
|
28
28
|
def from_name(
|
|
@@ -54,12 +54,27 @@ class CraftaxEnvelope(Environment):
|
|
|
54
54
|
def default_max_steps(self) -> int:
|
|
55
55
|
return int(self.craftax_env.default_params.max_timesteps)
|
|
56
56
|
|
|
57
|
+
@cached_property
|
|
58
|
+
def _craftax_info_placeholder(self) -> PyTree:
|
|
59
|
+
key = jax.random.PRNGKey(0)
|
|
60
|
+
_, state = self.craftax_env.reset(key, self.env_params)
|
|
61
|
+
_, _, _, _, info = self.craftax_env.step(
|
|
62
|
+
key,
|
|
63
|
+
state,
|
|
64
|
+
self.craftax_env.action_space(self.env_params).sample(key),
|
|
65
|
+
self.env_params,
|
|
66
|
+
)
|
|
67
|
+
return jax.tree.map(lambda x: jnp.full_like(x, jnp.nan), info)
|
|
68
|
+
|
|
57
69
|
@override
|
|
58
|
-
def
|
|
70
|
+
def init(self, key: Key) -> tuple[State, Info]:
|
|
71
|
+
# TODO: this function does not add env_info (or comparable) to the info
|
|
72
|
+
# container. We should add tests for this (and all other envelopes) and fix it.
|
|
59
73
|
key, subkey = jax.random.split(key)
|
|
60
74
|
obs, env_state = self.craftax_env.reset(subkey, self.env_params)
|
|
61
75
|
state = Container().update(key=key, env_state=env_state)
|
|
62
76
|
info = InfoContainer(obs=obs, reward=0.0, terminated=False)
|
|
77
|
+
info = info.update(info=self._craftax_info_placeholder)
|
|
63
78
|
return state, info
|
|
64
79
|
|
|
65
80
|
@override
|
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
from functools import cached_property
|
|
2
|
-
from typing import Any, override
|
|
2
|
+
from typing import Any, Callable, cast, override
|
|
3
3
|
|
|
4
4
|
import jax
|
|
5
5
|
import jax.numpy as jnp
|
|
@@ -10,15 +10,24 @@ from gymnax.environments.environment import EnvParams as GymnaxEnvParams
|
|
|
10
10
|
|
|
11
11
|
from envelope import spaces as envelope_spaces
|
|
12
12
|
from envelope.environment import Environment, Info, InfoContainer, State
|
|
13
|
-
from envelope.struct import Container, static_field
|
|
13
|
+
from envelope.struct import Container, field, static_field
|
|
14
14
|
from envelope.typing import Key, PyTree
|
|
15
15
|
|
|
16
|
+
_GymnaxReset = Callable[
|
|
17
|
+
[Key, GymnaxEnvParams],
|
|
18
|
+
tuple[PyTree, Any],
|
|
19
|
+
]
|
|
20
|
+
_GymnaxStep = Callable[
|
|
21
|
+
[Key, Any, PyTree, GymnaxEnvParams],
|
|
22
|
+
tuple[PyTree, Any, jnp.ndarray, jnp.ndarray, PyTree],
|
|
23
|
+
]
|
|
24
|
+
|
|
16
25
|
|
|
17
26
|
class GymnaxEnvelope(Environment):
|
|
18
27
|
"""Wrapper to convert a Gymnax environment to a envelope environment."""
|
|
19
28
|
|
|
20
29
|
gymnax_env: GymnaxEnv = static_field()
|
|
21
|
-
env_params: PyTree
|
|
30
|
+
env_params: PyTree = field()
|
|
22
31
|
|
|
23
32
|
@classmethod
|
|
24
33
|
def from_name(
|
|
@@ -43,19 +52,37 @@ class GymnaxEnvelope(Environment):
|
|
|
43
52
|
def default_max_steps(self) -> int:
|
|
44
53
|
return int(self.gymnax_env.default_params.max_steps_in_episode)
|
|
45
54
|
|
|
55
|
+
@cached_property
|
|
56
|
+
def _gymnax_info_placeholder(self) -> PyTree:
|
|
57
|
+
reset_fn = cast(_GymnaxReset, self.gymnax_env.reset)
|
|
58
|
+
step_fn = cast(_GymnaxStep, self.gymnax_env.step)
|
|
59
|
+
|
|
60
|
+
key = jax.random.PRNGKey(0)
|
|
61
|
+
_, state = reset_fn(key, self.env_params)
|
|
62
|
+
_, _, _, _, info = step_fn(
|
|
63
|
+
key,
|
|
64
|
+
state,
|
|
65
|
+
self.gymnax_env.action_space(self.env_params).sample(key),
|
|
66
|
+
self.env_params,
|
|
67
|
+
)
|
|
68
|
+
return jax.tree.map(lambda x: jnp.full_like(x, jnp.nan, dtype=float), info)
|
|
69
|
+
|
|
46
70
|
@override
|
|
47
|
-
def
|
|
71
|
+
def init(self, key: Key) -> tuple[State, Info]:
|
|
72
|
+
reset_fn = cast(_GymnaxReset, self.gymnax_env.reset)
|
|
73
|
+
|
|
48
74
|
key, subkey = jax.random.split(key)
|
|
49
|
-
obs, env_state =
|
|
75
|
+
obs, env_state = reset_fn(subkey, self.env_params)
|
|
50
76
|
state = Container().update(key=key, env_state=env_state)
|
|
51
77
|
info = InfoContainer(obs=obs, reward=0.0, terminated=False)
|
|
52
|
-
info = info.update(info=
|
|
78
|
+
info = info.update(info=self._gymnax_info_placeholder)
|
|
53
79
|
return state, info
|
|
54
80
|
|
|
55
81
|
@override
|
|
56
82
|
def step(self, state: State, action: PyTree) -> tuple[State, Info]:
|
|
57
83
|
key, subkey = jax.random.split(state.key)
|
|
58
|
-
|
|
84
|
+
step_fn = cast(_GymnaxStep, self.gymnax_env.step)
|
|
85
|
+
obs, env_state, reward, done, env_info = step_fn(
|
|
59
86
|
subkey, state.env_state, action, self.env_params
|
|
60
87
|
)
|
|
61
88
|
state = state.update(key=key, env_state=env_state)
|
|
@@ -48,7 +48,7 @@ class JumanjiEnvelope(Environment):
|
|
|
48
48
|
return self._default_time_limit
|
|
49
49
|
|
|
50
50
|
@override
|
|
51
|
-
def
|
|
51
|
+
def init(self, key: Key) -> tuple[State, Info]:
|
|
52
52
|
env_state, timestep = self.jumanji_env.reset(key)
|
|
53
53
|
info = convert_jumanji_to_envelope_info(timestep)
|
|
54
54
|
return env_state, info
|
|
@@ -81,8 +81,9 @@ class JumanjiEnvelope(Environment):
|
|
|
81
81
|
|
|
82
82
|
|
|
83
83
|
def convert_jumanji_to_envelope_info(timestep: JumanjiTimeStep) -> InfoContainer:
|
|
84
|
+
term = jnp.asarray(timestep.last(), dtype=bool)
|
|
84
85
|
info = InfoContainer(
|
|
85
|
-
obs=timestep.observation, reward=timestep.reward, terminated=
|
|
86
|
+
obs=timestep.observation, reward=timestep.reward, terminated=term
|
|
86
87
|
).update(**timestep.extras)
|
|
87
88
|
return info
|
|
88
89
|
|
|
@@ -28,6 +28,7 @@ from kinetix.environment import (
|
|
|
28
28
|
from kinetix.environment.ued.ued import make_reset_fn_sample_kinetix_level
|
|
29
29
|
from kinetix.util.saving import load_from_json_file
|
|
30
30
|
|
|
31
|
+
from envelope import field
|
|
31
32
|
from envelope import spaces as envelope_spaces
|
|
32
33
|
from envelope.compat.gymnax_envelope import _convert_space as _convert_gymnax_space
|
|
33
34
|
from envelope.environment import Environment, Info, InfoContainer, State
|
|
@@ -67,7 +68,7 @@ class KinetixEnvelope(Environment):
|
|
|
67
68
|
"""Wrapper to convert a Kinetix environment to a envelope environment."""
|
|
68
69
|
|
|
69
70
|
kinetix_env: Any = static_field()
|
|
70
|
-
env_params: Any
|
|
71
|
+
env_params: Any = field()
|
|
71
72
|
|
|
72
73
|
@property
|
|
73
74
|
def default_max_steps(self) -> int:
|
|
@@ -162,7 +163,7 @@ class KinetixEnvelope(Environment):
|
|
|
162
163
|
return cls(kinetix_env=kinetix_env, env_params=env_params)
|
|
163
164
|
|
|
164
165
|
@override
|
|
165
|
-
def
|
|
166
|
+
def init(self, key: Key) -> tuple[State, Info]:
|
|
166
167
|
key, subkey = jax.random.split(key)
|
|
167
168
|
obs, env_state = self.kinetix_env.reset(subkey, self.env_params)
|
|
168
169
|
state_out = Container().update(key=key, env_state=env_state)
|
|
@@ -56,7 +56,7 @@ class MujocoPlaygroundEnvelope(Environment):
|
|
|
56
56
|
return self._default_max_steps
|
|
57
57
|
|
|
58
58
|
@override
|
|
59
|
-
def
|
|
59
|
+
def init(self, key: Key) -> tuple[State, Info]:
|
|
60
60
|
env_state = self.mujoco_playground_env.reset(key)
|
|
61
61
|
info = InfoContainer(obs=env_state.obs, reward=0.0, terminated=False)
|
|
62
62
|
info = info.update(**dataclasses.asdict(env_state))
|
|
@@ -38,7 +38,7 @@ class NavixEnvelope(Environment):
|
|
|
38
38
|
return _NAVIX_DEFAULT_MAX_STEPS
|
|
39
39
|
|
|
40
40
|
@override
|
|
41
|
-
def
|
|
41
|
+
def init(self, key: Key) -> tuple[State, Info]:
|
|
42
42
|
timestep = self.navix_env.reset(key)
|
|
43
43
|
return timestep, convert_navix_to_envelope_info(timestep)
|
|
44
44
|
|
|
@@ -5,7 +5,7 @@ from typing import Protocol, runtime_checkable
|
|
|
5
5
|
|
|
6
6
|
from envelope import spaces
|
|
7
7
|
from envelope.struct import Container, FrozenPyTreeNode
|
|
8
|
-
from envelope.typing import Key, PyTree
|
|
8
|
+
from envelope.typing import Array, Key, PyTree
|
|
9
9
|
|
|
10
10
|
__all__ = ["Environment", "State", "Info", "InfoContainer"]
|
|
11
11
|
|
|
@@ -23,7 +23,7 @@ class Info(Protocol):
|
|
|
23
23
|
|
|
24
24
|
class InfoContainer(Container):
|
|
25
25
|
obs: PyTree
|
|
26
|
-
reward: float
|
|
26
|
+
reward: float | Array
|
|
27
27
|
terminated: bool
|
|
28
28
|
truncated: bool = field(default=False)
|
|
29
29
|
|
|
@@ -38,18 +38,25 @@ class Environment(ABC, FrozenPyTreeNode):
|
|
|
38
38
|
|
|
39
39
|
State is an opaque PyTree owned by each environment; wrappers that stack
|
|
40
40
|
environments should expose their wrapped env state as `inner_state` while
|
|
41
|
-
adding any wrapper-specific fields.
|
|
42
|
-
|
|
43
|
-
|
|
41
|
+
adding any wrapper-specific fields.
|
|
42
|
+
|
|
43
|
+
Two distinct lifecycle methods:
|
|
44
|
+
init(key) - Initialize environment and all state from scratch.
|
|
45
|
+
reset(state, key) - Reset the inner environment while preserving
|
|
46
|
+
episode-persistent state.
|
|
44
47
|
"""
|
|
45
48
|
|
|
46
49
|
@abstractmethod
|
|
47
|
-
def
|
|
48
|
-
|
|
49
|
-
|
|
50
|
+
def init(self, key: Key) -> tuple[State, Info]:
|
|
51
|
+
"""Initialize environment and all state from scratch."""
|
|
52
|
+
...
|
|
53
|
+
|
|
54
|
+
def reset(self, state: State, key: Key) -> tuple[State, Info]:
|
|
55
|
+
"""Reset the inner environment while preserving episode-persistent state."""
|
|
56
|
+
return self.init(key)
|
|
50
57
|
|
|
51
58
|
@abstractmethod
|
|
52
|
-
def step(self, state: State, action: PyTree
|
|
59
|
+
def step(self, state: State, action: PyTree) -> tuple[State, Info]: ...
|
|
53
60
|
|
|
54
61
|
@abstractmethod
|
|
55
62
|
@cached_property
|