jax-envelope 0.2.0__tar.gz → 0.3.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {jax_envelope-0.2.0 → jax_envelope-0.3.0}/PKG-INFO +34 -23
- {jax_envelope-0.2.0 → jax_envelope-0.3.0}/README.md +33 -22
- {jax_envelope-0.2.0 → jax_envelope-0.3.0}/pyproject.toml +1 -1
- {jax_envelope-0.2.0 → jax_envelope-0.3.0}/src/envelope/compat/brax_envelope.py +1 -1
- {jax_envelope-0.2.0 → jax_envelope-0.3.0}/src/envelope/compat/jumanji_envelope.py +1 -1
- {jax_envelope-0.2.0 → jax_envelope-0.3.0}/src/envelope/environment.py +2 -2
- {jax_envelope-0.2.0 → jax_envelope-0.3.0}/src/envelope/spaces.py +1 -1
- {jax_envelope-0.2.0 → jax_envelope-0.3.0}/src/envelope/wrappers/autoreset_wrapper.py +2 -2
- {jax_envelope-0.2.0 → jax_envelope-0.3.0}/src/envelope/wrappers/continuous_observation_wrapper.py +2 -2
- {jax_envelope-0.2.0 → jax_envelope-0.3.0}/src/envelope/wrappers/episode_statistics_wrapper.py +2 -2
- {jax_envelope-0.2.0 → jax_envelope-0.3.0}/src/envelope/wrappers/flatten_observation_wrapper.py +3 -4
- {jax_envelope-0.2.0 → jax_envelope-0.3.0}/src/envelope/wrappers/observation_normalization_wrapper.py +2 -2
- {jax_envelope-0.2.0 → jax_envelope-0.3.0}/src/envelope/wrappers/pooled_init_vmap_wrapper.py +2 -2
- {jax_envelope-0.2.0 → jax_envelope-0.3.0}/src/envelope/wrappers/state_injection_wrapper.py +2 -2
- {jax_envelope-0.2.0 → jax_envelope-0.3.0}/src/envelope/wrappers/truncation_wrapper.py +2 -2
- {jax_envelope-0.2.0 → jax_envelope-0.3.0}/src/envelope/wrappers/vmap_envs_wrapper.py +5 -5
- {jax_envelope-0.2.0 → jax_envelope-0.3.0}/src/envelope/wrappers/vmap_wrapper.py +2 -2
- {jax_envelope-0.2.0 → jax_envelope-0.3.0}/src/envelope/wrappers/wrapper.py +3 -3
- {jax_envelope-0.2.0 → jax_envelope-0.3.0}/tests/compat/test_brax_compat.py +2 -2
- {jax_envelope-0.2.0 → jax_envelope-0.3.0}/tests/wrappers/helpers.py +25 -35
- {jax_envelope-0.2.0 → jax_envelope-0.3.0}/tests/wrappers/test_autoreset_wrapper.py +2 -2
- {jax_envelope-0.2.0 → jax_envelope-0.3.0}/tests/wrappers/test_clip_action_wrapper.py +2 -2
- {jax_envelope-0.2.0 → jax_envelope-0.3.0}/tests/wrappers/test_continuous_observation_wrapper.py +1 -1
- {jax_envelope-0.2.0 → jax_envelope-0.3.0}/tests/wrappers/test_episode_statistics_wrapper.py +2 -2
- {jax_envelope-0.2.0 → jax_envelope-0.3.0}/tests/wrappers/test_flatten_action_wrapper.py +4 -4
- {jax_envelope-0.2.0 → jax_envelope-0.3.0}/tests/wrappers/test_flatten_observation_wrapper.py +2 -2
- {jax_envelope-0.2.0 → jax_envelope-0.3.0}/tests/wrappers/test_pooled_init_vmap_wrapper.py +1 -1
- {jax_envelope-0.2.0 → jax_envelope-0.3.0}/tests/wrappers/test_state_injection_wrapper.py +4 -4
- {jax_envelope-0.2.0 → jax_envelope-0.3.0}/tests/wrappers/test_truncation_wrapper.py +2 -2
- {jax_envelope-0.2.0 → jax_envelope-0.3.0}/uv.lock +1 -1
- {jax_envelope-0.2.0 → jax_envelope-0.3.0}/.github/workflows/publish.yml +0 -0
- {jax_envelope-0.2.0 → jax_envelope-0.3.0}/.gitignore +0 -0
- {jax_envelope-0.2.0 → jax_envelope-0.3.0}/LICENSE +0 -0
- {jax_envelope-0.2.0 → jax_envelope-0.3.0}/src/envelope/__init__.py +0 -0
- {jax_envelope-0.2.0 → jax_envelope-0.3.0}/src/envelope/compat/__init__.py +0 -0
- {jax_envelope-0.2.0 → jax_envelope-0.3.0}/src/envelope/compat/craftax_envelope.py +0 -0
- {jax_envelope-0.2.0 → jax_envelope-0.3.0}/src/envelope/compat/gymnax_envelope.py +0 -0
- {jax_envelope-0.2.0 → jax_envelope-0.3.0}/src/envelope/compat/kinetix_envelope.py +0 -0
- {jax_envelope-0.2.0 → jax_envelope-0.3.0}/src/envelope/compat/mujoco_playground_envelope.py +0 -0
- {jax_envelope-0.2.0 → jax_envelope-0.3.0}/src/envelope/compat/navix_envelope.py +0 -0
- {jax_envelope-0.2.0 → jax_envelope-0.3.0}/src/envelope/struct.py +0 -0
- {jax_envelope-0.2.0 → jax_envelope-0.3.0}/src/envelope/typing.py +0 -0
- {jax_envelope-0.2.0 → jax_envelope-0.3.0}/src/envelope/wrappers/__init__.py +0 -0
- {jax_envelope-0.2.0 → jax_envelope-0.3.0}/src/envelope/wrappers/clip_action_wrapper.py +0 -0
- {jax_envelope-0.2.0 → jax_envelope-0.3.0}/src/envelope/wrappers/flatten_action_wrapper.py +0 -0
- {jax_envelope-0.2.0 → jax_envelope-0.3.0}/src/envelope/wrappers/normalization.py +0 -0
- {jax_envelope-0.2.0 → jax_envelope-0.3.0}/tests/__init__.py +0 -0
- {jax_envelope-0.2.0 → jax_envelope-0.3.0}/tests/compat/__init__.py +0 -0
- {jax_envelope-0.2.0 → jax_envelope-0.3.0}/tests/compat/conftest.py +0 -0
- {jax_envelope-0.2.0 → jax_envelope-0.3.0}/tests/compat/contract.py +0 -0
- {jax_envelope-0.2.0 → jax_envelope-0.3.0}/tests/compat/test_craftax_compat.py +0 -0
- {jax_envelope-0.2.0 → jax_envelope-0.3.0}/tests/compat/test_create.py +0 -0
- {jax_envelope-0.2.0 → jax_envelope-0.3.0}/tests/compat/test_create_integration.py +0 -0
- {jax_envelope-0.2.0 → jax_envelope-0.3.0}/tests/compat/test_gymnax_compat.py +0 -0
- {jax_envelope-0.2.0 → jax_envelope-0.3.0}/tests/compat/test_jumanji_compat.py +0 -0
- {jax_envelope-0.2.0 → jax_envelope-0.3.0}/tests/compat/test_kinetix_compat.py +0 -0
- {jax_envelope-0.2.0 → jax_envelope-0.3.0}/tests/compat/test_mujoco_playground_compat.py +0 -0
- {jax_envelope-0.2.0 → jax_envelope-0.3.0}/tests/compat/test_navix_compat.py +0 -0
- {jax_envelope-0.2.0 → jax_envelope-0.3.0}/tests/spaces/__init__.py +0 -0
- {jax_envelope-0.2.0 → jax_envelope-0.3.0}/tests/spaces/test_batched_space.py +0 -0
- {jax_envelope-0.2.0 → jax_envelope-0.3.0}/tests/spaces/test_continuous.py +0 -0
- {jax_envelope-0.2.0 → jax_envelope-0.3.0}/tests/spaces/test_discrete.py +0 -0
- {jax_envelope-0.2.0 → jax_envelope-0.3.0}/tests/spaces/test_pytree_space.py +0 -0
- {jax_envelope-0.2.0 → jax_envelope-0.3.0}/tests/spaces/test_serialization.py +0 -0
- {jax_envelope-0.2.0 → jax_envelope-0.3.0}/tests/test_container.py +0 -0
- {jax_envelope-0.2.0 → jax_envelope-0.3.0}/tests/test_struct.py +0 -0
- {jax_envelope-0.2.0 → jax_envelope-0.3.0}/tests/wrappers/__init__.py +0 -0
- {jax_envelope-0.2.0 → jax_envelope-0.3.0}/tests/wrappers/test_environment_wrapper.py +0 -0
- {jax_envelope-0.2.0 → jax_envelope-0.3.0}/tests/wrappers/test_normalization.py +0 -0
- {jax_envelope-0.2.0 → jax_envelope-0.3.0}/tests/wrappers/test_observation_normalization_wrapper.py +0 -0
- {jax_envelope-0.2.0 → jax_envelope-0.3.0}/tests/wrappers/test_vmap_envs_wrapper.py +0 -0
- {jax_envelope-0.2.0 → jax_envelope-0.3.0}/tests/wrappers/test_vmap_wrapper.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: jax-envelope
|
|
3
|
-
Version: 0.2.0
|
|
3
|
+
Version: 0.3.0
|
|
4
4
|
Summary: A JAX-native environment interface with powerful wrappers and adapters for popular RL environment suites
|
|
5
5
|
Project-URL: Homepage, https://github.com/keraJLi/envelope
|
|
6
6
|
Project-URL: Repository, https://github.com/keraJLi/envelope
|
|
@@ -25,12 +25,13 @@ Requires-Dist: jax>=0.5.0
|
|
|
25
25
|
Description-Content-Type: text/markdown
|
|
26
26
|
|
|
27
27
|
# 💌 Envelope: a JAX-native environment interface
|
|
28
|
+
|
|
28
29
|
```python
|
|
29
30
|
# Create environments from JAX-native suites you have installed, ...
|
|
30
31
|
env = envelope.create("gymnax::CartPole-v1")
|
|
31
32
|
|
|
32
33
|
# ... interact with the environments using a simple interface, ...
|
|
33
|
-
state, info = env.
|
|
34
|
+
state, info = env.init(key)
|
|
34
35
|
states, infos = jax.lax.scan(env.step, state, actions)
|
|
35
36
|
plt.plot(infos.reward.cumsum())
|
|
36
37
|
|
|
@@ -41,35 +42,42 @@ env = envelope.wrappers.ObservationNormalizationWrapper(env)
|
|
|
41
42
|
```
|
|
42
43
|
|
|
43
44
|
## 🌍 Simple, expressive interaction!
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
45
|
+
- **Environments are pytrees**. Squish them through JAX transformations and trace their parameters.
|
|
46
|
+
- **Idiomatic jax-y interface** of `init(key: Key) -> State, Info` and `step(state: State, action: PyTree) -> State, Info`. You can directly `jax.lax.scan` over a `step(...)`!
|
|
47
|
+
- **Spaces are super simple**. No `Tuple`, `Dict` nonsense! There are two spaces: `Continuous` and `Discrete`, which you can compose into a `PyTreeSpace`.
|
|
48
|
+
- **Explicit episode truncation** supports correctly handling bootstrapping for value-function targets.
|
|
49
|
+
- **No auto-reset** by default. Resetting every step can be expensive!
|
|
49
50
|
|
|
50
51
|
## 💪 Powerful, composable wrappers!
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
52
|
+
|
|
53
|
+
- **Carry state across episodes** to track running statistics, for example to normalize observations.
|
|
54
|
+
- **Composable wrappers** can be stacked in any order. For example, `ObservationNormalizationWrapper` before vs. after `VmapWrapper` gives per-env vs. global normalization.
|
|
55
|
+
|
|
56
|
+
|
|
54
57
|
|
|
55
58
|
## 🔌 Adapters for existing suites
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
|
59
|
-
|
|
|
60
|
-
| [
|
|
61
|
-
| [
|
|
62
|
-
| [
|
|
63
|
-
| [
|
|
64
|
-
| | |
|
|
65
|
-
|
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
| 📦 | # 🤖 | # 🌍 |
|
|
62
|
+
| ------------------------------------------------------------------------- | ------- | ------- |
|
|
63
|
+
| [gymnax](https://github.com/RobertTLange/gymnax) | 🕺 | 24 |
|
|
64
|
+
| [brax](https://github.com/google/brax) | 🕺 | 12 |
|
|
65
|
+
| [jumanji](https://github.com/instadeepai/jumanji) | 🕺 / 👯 | 25 / 1 |
|
|
66
|
+
| [kinetix](https://github.com/flairox/kinetix) | 🕺 | 74 |
|
|
67
|
+
| [craftax](https://github.com/MichaelTMatthews/craftax) | 🕺 | 4 |
|
|
68
|
+
| [mujoco_playground](https://github.com/google-deepmind/mujoco_playground) | 🕺 | 54 |
|
|
69
|
+
| | | |
|
|
70
|
+
| Total | 🕺 / 👯 | 193 / 1 |
|
|
71
|
+
|
|
66
72
|
|
|
67
73
|
```python
|
|
68
74
|
envelope.create("📦::🌍")
|
|
69
75
|
```
|
|
76
|
+
|
|
70
77
|
lets you create environments from any of the above!
|
|
71
78
|
|
|
72
79
|
## 📝 Testing
|
|
80
|
+
|
|
73
81
|
- **Default (no optional compat deps required)**: `uv run pytest -m "not compat"`
|
|
74
82
|
- **Compat suite (requires full compat dependency group)**:
|
|
75
83
|
- `uv sync --group compat`
|
|
@@ -77,11 +85,14 @@ let's you create environments from any of the above!
|
|
|
77
85
|
- If any compat dependency is missing/broken, the run will fail fast with an error telling you what to install.
|
|
78
86
|
|
|
79
87
|
## 🏗️ Installation
|
|
88
|
+
|
|
80
89
|
```bash
|
|
81
90
|
pip install jax-envelope
|
|
82
91
|
```
|
|
83
92
|
|
|
84
93
|
## 💞 Related projects
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
94
|
+
|
|
95
|
+
- [stoa](https://github.com/EdanToledo/Stoa) is a very similar project that provides adapters and wrappers for the jumanji-like interface.
|
|
96
|
+
- Check out all the great suites we have adapters for! [gymnax](https://github.com/RobertTLange/gymnax), [brax](https://github.com/google/brax), [jumanji](https://github.com/instadeepai/jumanji), [kinetix](https://github.com/flairox/kinetix), [craftax](https://github.com/MichaelTMatthews/craftax), [mujoco_playground](https://github.com/google-deepmind/mujoco_playground).
|
|
97
|
+
- We will be adding support for [jaxmarl](https://github.com/flairox/jaxmarl) and [pgx](https://github.com/sotetsuk/pgx) in the future, as soon as we have figured out the best ever MARL interface for JAX!
|
|
98
|
+
|
|
@@ -1,10 +1,11 @@
|
|
|
1
1
|
# 💌 Envelope: a JAX-native environment interface
|
|
2
|
+
|
|
2
3
|
```python
|
|
3
4
|
# Create environments from JAX-native suites you have installed, ...
|
|
4
5
|
env = envelope.create("gymnax::CartPole-v1")
|
|
5
6
|
|
|
6
7
|
# ... interact with the environments using a simple interface, ...
|
|
7
|
-
state, info = env.
|
|
8
|
+
state, info = env.init(key)
|
|
8
9
|
states, infos = jax.lax.scan(env.step, state, actions)
|
|
9
10
|
plt.plot(infos.reward.cumsum())
|
|
10
11
|
|
|
@@ -15,35 +16,42 @@ env = envelope.wrappers.ObservationNormalizationWrapper(env)
|
|
|
15
16
|
```
|
|
16
17
|
|
|
17
18
|
## 🌍 Simple, expressive interaction!
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
19
|
+
- **Environments are pytrees**. Squish them through JAX transformations and trace their parameters.
|
|
20
|
+
- **Idiomatic jax-y interface** of `init(key: Key) -> State, Info` and `step(state: State, action: PyTree) -> State, Info`. You can directly `jax.lax.scan` over a `step(...)`!
|
|
21
|
+
- **Spaces are super simple**. No `Tuple`, `Dict` nonsense! There are two spaces: `Continuous` and `Discrete`, which you can compose into a `PyTreeSpace`.
|
|
22
|
+
- **Explicit episode truncation** supports correctly handling bootstrapping for value-function targets.
|
|
23
|
+
- **No auto-reset** by default. Resetting every step can be expensive!
|
|
23
24
|
|
|
24
25
|
## 💪 Powerful, composable wrappers!
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
26
|
+
|
|
27
|
+
- **Carry state across episodes** to track running statistics, for example to normalize observations.
|
|
28
|
+
- **Composable wrappers** can be stacked in any order. For example, `ObservationNormalizationWrapper` before vs. after `VmapWrapper` gives per-env vs. global normalization.
|
|
29
|
+
|
|
30
|
+
|
|
28
31
|
|
|
29
32
|
## 🔌 Adapters for existing suites
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
|
33
|
-
|
|
|
34
|
-
| [
|
|
35
|
-
| [
|
|
36
|
-
| [
|
|
37
|
-
| [
|
|
38
|
-
| | |
|
|
39
|
-
|
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
| 📦 | # 🤖 | # 🌍 |
|
|
36
|
+
| ------------------------------------------------------------------------- | ------- | ------- |
|
|
37
|
+
| [gymnax](https://github.com/RobertTLange/gymnax) | 🕺 | 24 |
|
|
38
|
+
| [brax](https://github.com/google/brax) | 🕺 | 12 |
|
|
39
|
+
| [jumanji](https://github.com/instadeepai/jumanji) | 🕺 / 👯 | 25 / 1 |
|
|
40
|
+
| [kinetix](https://github.com/flairox/kinetix) | 🕺 | 74 |
|
|
41
|
+
| [craftax](https://github.com/MichaelTMatthews/craftax) | 🕺 | 4 |
|
|
42
|
+
| [mujoco_playground](https://github.com/google-deepmind/mujoco_playground) | 🕺 | 54 |
|
|
43
|
+
| | | |
|
|
44
|
+
| Total | 🕺 / 👯 | 193 / 1 |
|
|
45
|
+
|
|
40
46
|
|
|
41
47
|
```python
|
|
42
48
|
envelope.create("📦::🌍")
|
|
43
49
|
```
|
|
50
|
+
|
|
44
51
|
lets you create environments from any of the above!
|
|
45
52
|
|
|
46
53
|
## 📝 Testing
|
|
54
|
+
|
|
47
55
|
- **Default (no optional compat deps required)**: `uv run pytest -m "not compat"`
|
|
48
56
|
- **Compat suite (requires full compat dependency group)**:
|
|
49
57
|
- `uv sync --group compat`
|
|
@@ -51,11 +59,14 @@ let's you create environments from any of the above!
|
|
|
51
59
|
- If any compat dependency is missing/broken, the run will fail fast with an error telling you what to install.
|
|
52
60
|
|
|
53
61
|
## 🏗️ Installation
|
|
62
|
+
|
|
54
63
|
```bash
|
|
55
64
|
pip install jax-envelope
|
|
56
65
|
```
|
|
57
66
|
|
|
58
67
|
## 💞 Related projects
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
68
|
+
|
|
69
|
+
- [stoa](https://github.com/EdanToledo/Stoa) is a very similar project that provides adapters and wrappers for the jumanji-like interface.
|
|
70
|
+
- Check out all the great suites we have adapters for! [gymnax](https://github.com/RobertTLange/gymnax), [brax](https://github.com/google/brax), [jumanji](https://github.com/instadeepai/jumanji), [kinetix](https://github.com/flairox/kinetix), [craftax](https://github.com/MichaelTMatthews/craftax), [mujoco_playground](https://github.com/google-deepmind/mujoco_playground).
|
|
71
|
+
- We will be adding support for [jaxmarl](https://github.com/flairox/jaxmarl) and [pgx](https://github.com/sotetsuk/pgx) in the future, as soon as we have figured out the best ever MARL interface for JAX!
|
|
72
|
+
|
|
@@ -69,7 +69,7 @@ class BraxEnvelope(Environment):
|
|
|
69
69
|
info = InfoContainer(
|
|
70
70
|
obs=brax_state.obs,
|
|
71
71
|
reward=brax_state.reward,
|
|
72
|
-
terminated=jnp.
|
|
72
|
+
terminated=jnp.asarray(brax_state.done, dtype=bool),
|
|
73
73
|
)
|
|
74
74
|
info = info.update(**dataclasses.asdict(brax_state))
|
|
75
75
|
return brax_state, info
|
|
@@ -81,7 +81,7 @@ class JumanjiEnvelope(Environment):
|
|
|
81
81
|
|
|
82
82
|
|
|
83
83
|
def convert_jumanji_to_envelope_info(timestep: JumanjiTimeStep) -> InfoContainer:
|
|
84
|
-
term = jnp.asarray(timestep.last(), dtype=bool)
|
|
84
|
+
term = jnp.asarray(timestep.last(), dtype=bool)
|
|
85
85
|
info = InfoContainer(
|
|
86
86
|
obs=timestep.observation, reward=timestep.reward, terminated=term
|
|
87
87
|
).update(**timestep.extras)
|
|
@@ -42,7 +42,7 @@ class Environment(ABC, FrozenPyTreeNode):
|
|
|
42
42
|
|
|
43
43
|
Two distinct lifecycle methods:
|
|
44
44
|
init(key) - Initialize environment and all state from scratch.
|
|
45
|
-
reset(
|
|
45
|
+
reset(state, key) - Reset the inner environment while preserving
|
|
46
46
|
episode-persistent state.
|
|
47
47
|
"""
|
|
48
48
|
|
|
@@ -51,7 +51,7 @@ class Environment(ABC, FrozenPyTreeNode):
|
|
|
51
51
|
"""Initialize environment and all state from scratch."""
|
|
52
52
|
...
|
|
53
53
|
|
|
54
|
-
def reset(self,
|
|
54
|
+
def reset(self, state: State, key: Key) -> tuple[State, Info]:
|
|
55
55
|
"""Reset the inner environment while preserving episode-persistent state."""
|
|
56
56
|
return self.init(key)
|
|
57
57
|
|
|
@@ -54,7 +54,7 @@ class AutoResetWrapper(Wrapper):
|
|
|
54
54
|
return state, info.update(final=state.last_final)
|
|
55
55
|
|
|
56
56
|
@override
|
|
57
|
-
def reset(self,
|
|
57
|
+
def reset(self, state: WrappedState, key: Key) -> tuple[WrappedState, Info]:
|
|
58
58
|
raise NotImplementedError("Reset is not implemented for AutoResetWrapper")
|
|
59
59
|
|
|
60
60
|
@override
|
|
@@ -63,7 +63,7 @@ class AutoResetWrapper(Wrapper):
|
|
|
63
63
|
state = state.replace(reset_key=key)
|
|
64
64
|
|
|
65
65
|
inner_state, info = self.env.step(state.inner_state, action)
|
|
66
|
-
reset_inner_state, reset_info = self.env.reset(
|
|
66
|
+
reset_inner_state, reset_info = self.env.reset(inner_state, key_reset)
|
|
67
67
|
|
|
68
68
|
# Select next state and info based on done
|
|
69
69
|
done = info.terminated | info.truncated
|
{jax_envelope-0.2.0 → jax_envelope-0.3.0}/src/envelope/wrappers/continuous_observation_wrapper.py
RENAMED
|
@@ -35,8 +35,8 @@ class ContinuousObservationWrapper(Wrapper):
|
|
|
35
35
|
return state, info
|
|
36
36
|
|
|
37
37
|
@override
|
|
38
|
-
def reset(self,
|
|
39
|
-
state, info = self.env.reset(
|
|
38
|
+
def reset(self, state: State, key: Key) -> tuple[State, Info]:
|
|
39
|
+
state, info = self.env.reset(state, key)
|
|
40
40
|
info = info.update(obs=to_float(info.obs))
|
|
41
41
|
return state, info
|
|
42
42
|
|
{jax_envelope-0.2.0 → jax_envelope-0.3.0}/src/envelope/wrappers/episode_statistics_wrapper.py
RENAMED
|
@@ -24,8 +24,8 @@ class EpisodeStatisticsWrapper(Wrapper):
|
|
|
24
24
|
return state, info.update(stats=state.stats)
|
|
25
25
|
|
|
26
26
|
@override
|
|
27
|
-
def reset(self,
|
|
28
|
-
inner_state, info = self.env.reset(
|
|
27
|
+
def reset(self, state: State, key: Key) -> tuple[State, Info]:
|
|
28
|
+
inner_state, info = self.env.reset(state.inner_state, key)
|
|
29
29
|
state = state.replace(inner_state=inner_state)
|
|
30
30
|
return state, info.update(stats=state.stats)
|
|
31
31
|
|
{jax_envelope-0.2.0 → jax_envelope-0.3.0}/src/envelope/wrappers/flatten_observation_wrapper.py
RENAMED
|
@@ -37,10 +37,9 @@ class FlattenObservationWrapper(Wrapper):
|
|
|
37
37
|
return state, info
|
|
38
38
|
|
|
39
39
|
@override
|
|
40
|
-
def reset(self,
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
return state, info
|
|
40
|
+
def reset(self, state: State, key: Key) -> tuple[State, Info]:
|
|
41
|
+
next_state, info = self.env.reset(state, key)
|
|
42
|
+
return next_state, info.update(obs=flatten_x(info.obs))
|
|
44
43
|
|
|
45
44
|
@override
|
|
46
45
|
def step(self, state: State, action: PyTree) -> tuple[State, Info]:
|
{jax_envelope-0.2.0 → jax_envelope-0.3.0}/src/envelope/wrappers/observation_normalization_wrapper.py
RENAMED
|
@@ -76,8 +76,8 @@ class ObservationNormalizationWrapper(Wrapper):
|
|
|
76
76
|
return self._normalize_and_update(next_state, info)
|
|
77
77
|
|
|
78
78
|
@override
|
|
79
|
-
def reset(self,
|
|
80
|
-
inner_state, info = self.env.reset(
|
|
79
|
+
def reset(self, state: WrappedState, key: Key) -> tuple[WrappedState, Info]:
|
|
80
|
+
inner_state, info = self.env.reset(state.inner_state, key)
|
|
81
81
|
# Preserve running statistics across resets
|
|
82
82
|
next_state = self.ObservationNormalizationState(
|
|
83
83
|
inner_state=inner_state, rmv_state=state.rmv_state
|
|
@@ -36,7 +36,7 @@ class PooledInitVmapWrapper(Wrapper):
|
|
|
36
36
|
return state, info.update(final=pholder_info)
|
|
37
37
|
|
|
38
38
|
@override
|
|
39
|
-
def reset(self,
|
|
39
|
+
def reset(self, state: WrappedState, key: Key) -> tuple[WrappedState, Info]:
|
|
40
40
|
# It's hard to support reset for this wrapper.
|
|
41
41
|
# We would have to init the state of a pool of unwrapped environments, and then
|
|
42
42
|
# somehow inject this into the stack of wrapped states. The current data
|
|
@@ -48,7 +48,7 @@ class PooledInitVmapWrapper(Wrapper):
|
|
|
48
48
|
# episodes before vmapping, we will implement this later.
|
|
49
49
|
keys = _split_or_keep_key(key, self.batch_size + 1)
|
|
50
50
|
key_next, keys_pool = keys[0], keys[1:]
|
|
51
|
-
inner_state, info = jax.vmap(self.env.reset)(
|
|
51
|
+
inner_state, info = jax.vmap(self.env.reset)(state.inner_state, keys_pool)
|
|
52
52
|
state = state.replace(inner_state=inner_state, init_key=key_next)
|
|
53
53
|
return state, info.update(final=state.last_final)
|
|
54
54
|
|
|
@@ -69,13 +69,13 @@ class StateInjectionWrapper(Wrapper):
|
|
|
69
69
|
return state, info
|
|
70
70
|
|
|
71
71
|
@override
|
|
72
|
-
def reset(self,
|
|
72
|
+
def reset(self, state: WrappedState, key: Key) -> tuple[WrappedState, Info]:
|
|
73
73
|
# If reset state is set, use it instead of resetting inner env
|
|
74
74
|
if state.reset_state is not None and state.reset_obs is not None:
|
|
75
75
|
inner_state = state.reset_state
|
|
76
76
|
info = InfoContainer(obs=state.reset_obs, reward=0.0, terminated=False)
|
|
77
77
|
elif state.reset_state is None and state.reset_obs is None:
|
|
78
|
-
inner_state, info = self.env.reset(
|
|
78
|
+
inner_state, info = self.env.reset(state.inner_state, key)
|
|
79
79
|
else:
|
|
80
80
|
raise ValueError("State must set both reset_state and reset_obs or neither")
|
|
81
81
|
|
|
@@ -21,8 +21,8 @@ class TruncationWrapper(Wrapper):
|
|
|
21
21
|
return state, info.update(truncated=self.max_steps <= 0)
|
|
22
22
|
|
|
23
23
|
@override
|
|
24
|
-
def reset(self,
|
|
25
|
-
inner_state, info = self.env.reset(
|
|
24
|
+
def reset(self, state: WrappedState, key: Key) -> tuple[WrappedState, Info]:
|
|
25
|
+
inner_state, info = self.env.reset(state.inner_state, key)
|
|
26
26
|
state = state.replace(inner_state=inner_state, steps=0)
|
|
27
27
|
return state, info.update(truncated=self.max_steps <= 0)
|
|
28
28
|
|
|
@@ -40,11 +40,9 @@ class VmapEnvsWrapper(Wrapper):
|
|
|
40
40
|
return state, info
|
|
41
41
|
|
|
42
42
|
@override
|
|
43
|
-
def reset(self,
|
|
43
|
+
def reset(self, state: PyTree, key: Key) -> tuple[WrappedState, Info]:
|
|
44
44
|
keys = self._split_keys(key)
|
|
45
|
-
state, info = jax.vmap(lambda e,
|
|
46
|
-
self.env, keys, state
|
|
47
|
-
)
|
|
45
|
+
state, info = jax.vmap(lambda e, s, k: e.reset(s, k))(self.env, state, keys)
|
|
48
46
|
return state, info
|
|
49
47
|
|
|
50
48
|
@override
|
|
@@ -58,7 +56,9 @@ class VmapEnvsWrapper(Wrapper):
|
|
|
58
56
|
@property
|
|
59
57
|
def observation_space(self) -> spaces.Space:
|
|
60
58
|
env0 = _index_env(self.env, 0, self.batch_size)
|
|
61
|
-
return spaces.BatchedSpace(
|
|
59
|
+
return spaces.BatchedSpace(
|
|
60
|
+
space=env0.observation_space, batch_size=self.batch_size
|
|
61
|
+
)
|
|
62
62
|
|
|
63
63
|
@override
|
|
64
64
|
@cached_property
|
|
@@ -41,9 +41,9 @@ class VmapWrapper(Wrapper):
|
|
|
41
41
|
return state, info
|
|
42
42
|
|
|
43
43
|
@override
|
|
44
|
-
def reset(self,
|
|
44
|
+
def reset(self, state: PyTree, key: Key) -> tuple[WrappedState, Info]:
|
|
45
45
|
keys = _split_or_keep_key(key, self.batch_size)
|
|
46
|
-
state, info = jax.vmap(self.env.reset)(
|
|
46
|
+
state, info = jax.vmap(self.env.reset)(state, keys)
|
|
47
47
|
return state, info
|
|
48
48
|
|
|
49
49
|
@override
|
|
@@ -29,11 +29,11 @@ class Wrapper(Environment):
|
|
|
29
29
|
return self.env.init(key)
|
|
30
30
|
|
|
31
31
|
@override
|
|
32
|
-
def reset(self,
|
|
33
|
-
return self.env.reset(
|
|
32
|
+
def reset(self, state: State, key: Key) -> tuple[State, Info]:
|
|
33
|
+
return self.env.reset(state, key)
|
|
34
34
|
|
|
35
35
|
@override
|
|
36
|
-
def step(self, state:
|
|
36
|
+
def step(self, state: State, action: PyTree) -> tuple[State, Info]:
|
|
37
37
|
return self.env.step(state, action)
|
|
38
38
|
|
|
39
39
|
@override
|
|
@@ -93,8 +93,8 @@ def test_wrapper_unwrapping():
|
|
|
93
93
|
|
|
94
94
|
# Create a simple wrapper
|
|
95
95
|
class SimpleWrapper(BraxWrapper):
|
|
96
|
-
def
|
|
97
|
-
return self.env.
|
|
96
|
+
def init(self, rng):
|
|
97
|
+
return self.env.init(rng)
|
|
98
98
|
|
|
99
99
|
def step(self, state, action):
|
|
100
100
|
return self.env.step(state, action)
|
|
@@ -98,7 +98,7 @@ class StepCounterEnv(Environment):
|
|
|
98
98
|
truncated=truncated,
|
|
99
99
|
)
|
|
100
100
|
|
|
101
|
-
def reset(self,
|
|
101
|
+
def reset(self, state: State, key: Key) -> tuple[StepState, InfoContainer]:
|
|
102
102
|
return self.init(key)
|
|
103
103
|
|
|
104
104
|
def step(
|
|
@@ -198,9 +198,7 @@ class NoStepsEnv(Environment):
|
|
|
198
198
|
obs=s.env_state, reward=0.0, terminated=False, truncated=False
|
|
199
199
|
)
|
|
200
200
|
|
|
201
|
-
def reset(
|
|
202
|
-
self, key: Key, state: State
|
|
203
|
-
) -> tuple[NoStepsState, InfoContainer]:
|
|
201
|
+
def reset(self, state: State, key: Key) -> tuple[NoStepsState, InfoContainer]:
|
|
204
202
|
return self.init(key)
|
|
205
203
|
|
|
206
204
|
def step(
|
|
@@ -230,9 +228,7 @@ class AlternatingTerminationEnv(Environment):
|
|
|
230
228
|
obs=s.env_state, reward=0.0, terminated=False, truncated=False
|
|
231
229
|
)
|
|
232
230
|
|
|
233
|
-
def reset(
|
|
234
|
-
self, key: Key, state: State
|
|
235
|
-
) -> tuple[StepState, InfoContainer]:
|
|
231
|
+
def reset(self, state: State, key: Key) -> tuple[StepState, InfoContainer]:
|
|
236
232
|
return self.init(key)
|
|
237
233
|
|
|
238
234
|
def step(
|
|
@@ -266,7 +262,7 @@ class ScalarToyEnv(Environment):
|
|
|
266
262
|
s = jnp.asarray(0.0, dtype=jnp.float32)
|
|
267
263
|
return s, InfoContainer(obs=s, reward=0.0, terminated=False, truncated=False)
|
|
268
264
|
|
|
269
|
-
def reset(self,
|
|
265
|
+
def reset(self, state: State, key: Key) -> tuple[State, Info]:
|
|
270
266
|
return self.init(key)
|
|
271
267
|
|
|
272
268
|
def step(self, state: State, action: jax.Array) -> tuple[State, Info]:
|
|
@@ -300,7 +296,7 @@ class VectorToyEnv(Environment):
|
|
|
300
296
|
s = jnp.zeros((self.dim,), dtype=jnp.float32)
|
|
301
297
|
return s, InfoContainer(obs=s, reward=0.0, terminated=False, truncated=False)
|
|
302
298
|
|
|
303
|
-
def reset(self,
|
|
299
|
+
def reset(self, state: State, key: Key) -> tuple[State, Info]:
|
|
304
300
|
return self.init(key)
|
|
305
301
|
|
|
306
302
|
def step(self, state: State, action: jax.Array) -> tuple[State, Info]:
|
|
@@ -330,7 +326,7 @@ class FlagDoneEnv(Environment):
|
|
|
330
326
|
z = jnp.array(0.0)
|
|
331
327
|
return z, InfoContainer(obs=z, reward=0.0, terminated=False, truncated=False)
|
|
332
328
|
|
|
333
|
-
def reset(self,
|
|
329
|
+
def reset(self, state: State, key: Key):
|
|
334
330
|
return self.init(key)
|
|
335
331
|
|
|
336
332
|
def step(self, state: State, action: jax.Array):
|
|
@@ -366,7 +362,7 @@ class ParamEnv(Environment):
|
|
|
366
362
|
s = jnp.asarray([self.offset, -self.offset], dtype=jnp.float32)
|
|
367
363
|
return s, InfoContainer(obs=s, reward=0.0, terminated=False, truncated=False)
|
|
368
364
|
|
|
369
|
-
def reset(self,
|
|
365
|
+
def reset(self, state: State, key: Key) -> tuple[State, Info]:
|
|
370
366
|
return self.init(key)
|
|
371
367
|
|
|
372
368
|
def step(self, state: State, action: jax.Array) -> tuple[State, Info]:
|
|
@@ -401,7 +397,7 @@ class VectorObsEnv(Environment):
|
|
|
401
397
|
s = jnp.linspace(0.0, 1.0, self.dim, dtype=jnp.float32)
|
|
402
398
|
return s, InfoContainer(obs=s, reward=0.0, terminated=False, truncated=False)
|
|
403
399
|
|
|
404
|
-
def reset(self,
|
|
400
|
+
def reset(self, state: State, key: Key):
|
|
405
401
|
return self.init(key)
|
|
406
402
|
|
|
407
403
|
def step(self, state: State, action: jax.Array):
|
|
@@ -444,7 +440,7 @@ class PyTreeObsEnv(Environment):
|
|
|
444
440
|
s = obs
|
|
445
441
|
return s, InfoContainer(obs=obs, reward=0.0, terminated=False, truncated=False)
|
|
446
442
|
|
|
447
|
-
def reset(self,
|
|
443
|
+
def reset(self, state: State, key: Key):
|
|
448
444
|
return self.init(key)
|
|
449
445
|
|
|
450
446
|
def step(self, state: State, action: jax.Array):
|
|
@@ -476,7 +472,7 @@ class ConstantObsEnv(Environment):
|
|
|
476
472
|
obs = jnp.asarray(self.value, self.dtype) * jnp.ones(self.shape, self.dtype)
|
|
477
473
|
return 0, InfoContainer(obs=obs, reward=0.0, terminated=False, truncated=False)
|
|
478
474
|
|
|
479
|
-
def reset(self,
|
|
475
|
+
def reset(self, state: State, key: Key):
|
|
480
476
|
return self.init(key)
|
|
481
477
|
|
|
482
478
|
def step(self, state: State, action: jax.Array):
|
|
@@ -500,10 +496,12 @@ class PyTreeActionEnv(Environment):
|
|
|
500
496
|
|
|
501
497
|
@cached_property
|
|
502
498
|
def action_space(self) -> PyTreeSpace:
|
|
503
|
-
return PyTreeSpace(
|
|
504
|
-
|
|
505
|
-
|
|
506
|
-
|
|
499
|
+
return PyTreeSpace(
|
|
500
|
+
{
|
|
501
|
+
"a": Continuous.from_shape(low=-1.0, high=1.0, shape=(2,)),
|
|
502
|
+
"b": Continuous.from_shape(low=-1.0, high=1.0, shape=(3,)),
|
|
503
|
+
}
|
|
504
|
+
)
|
|
507
505
|
|
|
508
506
|
def _action_to_vec(self, action: PyTree) -> jax.Array:
|
|
509
507
|
leaves = jax.tree.leaves(action)
|
|
@@ -513,12 +511,10 @@ class PyTreeActionEnv(Environment):
|
|
|
513
511
|
s = jnp.zeros(5, dtype=jnp.float32)
|
|
514
512
|
return s, InfoContainer(obs=s, reward=0.0, terminated=False, truncated=False)
|
|
515
513
|
|
|
516
|
-
def reset(self,
|
|
514
|
+
def reset(self, state: State, key: Key) -> tuple[jax.Array, InfoContainer]:
|
|
517
515
|
return self.init(key)
|
|
518
516
|
|
|
519
|
-
def step(
|
|
520
|
-
self, state: jax.Array, action: PyTree
|
|
521
|
-
) -> tuple[jax.Array, InfoContainer]:
|
|
517
|
+
def step(self, state: jax.Array, action: PyTree) -> tuple[jax.Array, InfoContainer]:
|
|
522
518
|
vec = self._action_to_vec(action)
|
|
523
519
|
ns = state + jnp.asarray(vec, dtype=jnp.float32)
|
|
524
520
|
reward = jnp.sum(vec)
|
|
@@ -545,7 +541,7 @@ class IntObsEnv(Environment):
|
|
|
545
541
|
s = jnp.array(0, dtype=jnp.int32)
|
|
546
542
|
return s, InfoContainer(obs=s, reward=0.0, terminated=False, truncated=False)
|
|
547
543
|
|
|
548
|
-
def reset(self,
|
|
544
|
+
def reset(self, state: State, key: Key):
|
|
549
545
|
return self.init(key)
|
|
550
546
|
|
|
551
547
|
def step(self, state: State, action: jax.Array):
|
|
@@ -581,7 +577,7 @@ class RandomImageEnv(Environment):
|
|
|
581
577
|
obs=obs.astype(self.dtype), reward=0.0, terminated=False, truncated=False
|
|
582
578
|
)
|
|
583
579
|
|
|
584
|
-
def reset(self,
|
|
580
|
+
def reset(self, state: State, key: Key):
|
|
585
581
|
return self.init(key)
|
|
586
582
|
|
|
587
583
|
def step(self, state: State, action: jax.Array):
|
|
@@ -618,9 +614,7 @@ class WrapperSimpleEnv(Environment):
|
|
|
618
614
|
info = TestInfo(obs=state, reward=0.0, terminated=False, truncated=False)
|
|
619
615
|
return state, info
|
|
620
616
|
|
|
621
|
-
def reset(
|
|
622
|
-
self, key: Key, state: State
|
|
623
|
-
) -> tuple[jax.Array, TestInfo]:
|
|
617
|
+
def reset(self, state: State, key: Key) -> tuple[jax.Array, TestInfo]:
|
|
624
618
|
return self.init(key)
|
|
625
619
|
|
|
626
620
|
def step(self, state: jax.Array, action: jax.Array) -> tuple[jax.Array, TestInfo]:
|
|
@@ -650,9 +644,7 @@ class WrapperEnvWithFields(Environment):
|
|
|
650
644
|
info = TestInfo(obs=state, reward=0.0, terminated=False, truncated=False)
|
|
651
645
|
return state, info
|
|
652
646
|
|
|
653
|
-
def reset(
|
|
654
|
-
self, key: Key, state: State
|
|
655
|
-
) -> tuple[jax.Array, TestInfo]:
|
|
647
|
+
def reset(self, state: State, key: Key) -> tuple[jax.Array, TestInfo]:
|
|
656
648
|
return self.init(key)
|
|
657
649
|
|
|
658
650
|
def step(self, state: jax.Array, action: jax.Array) -> tuple[jax.Array, TestInfo]:
|
|
@@ -679,9 +671,7 @@ class WrapperEnvWithMethods(Environment):
|
|
|
679
671
|
info = TestInfo(obs=state, reward=0.0, terminated=False, truncated=False)
|
|
680
672
|
return state, info
|
|
681
673
|
|
|
682
|
-
def reset(
|
|
683
|
-
self, key: Key, state: State
|
|
684
|
-
) -> tuple[jax.Array, TestInfo]:
|
|
674
|
+
def reset(self, state: State, key: Key) -> tuple[jax.Array, TestInfo]:
|
|
685
675
|
return self.init(key)
|
|
686
676
|
|
|
687
677
|
def step(self, state: jax.Array, action: jax.Array) -> tuple[jax.Array, TestInfo]:
|
|
@@ -727,7 +717,7 @@ def make_wrapper_discrete_env() -> Environment:
|
|
|
727
717
|
info = TestInfo(obs=state, reward=0.0, terminated=False, truncated=False)
|
|
728
718
|
return state, info
|
|
729
719
|
|
|
730
|
-
def reset(self,
|
|
720
|
+
def reset(self, state: State, key: Key):
|
|
731
721
|
return self.init(key)
|
|
732
722
|
|
|
733
723
|
def step(self, state: jax.Array, action: jax.Array):
|
|
@@ -762,7 +752,7 @@ def make_wrapper_complex_state_env() -> Environment:
|
|
|
762
752
|
)
|
|
763
753
|
return st, info
|
|
764
754
|
|
|
765
|
-
def reset(self,
|
|
755
|
+
def reset(self, state: State, key: Key):
|
|
766
756
|
return self.init(key)
|
|
767
757
|
|
|
768
758
|
def step(self, state: dict, action: jax.Array):
|
|
@@ -544,8 +544,8 @@ def test_auto_reset_passes_state_to_inner_wrapper():
|
|
|
544
544
|
received_state_on_reset=False,
|
|
545
545
|
), info
|
|
546
546
|
|
|
547
|
-
def reset(self,
|
|
548
|
-
inner_state, info = self.env.reset(
|
|
547
|
+
def reset(self, state, key):
|
|
548
|
+
inner_state, info = self.env.reset(state.inner_state, key)
|
|
549
549
|
return self.TrackingState(
|
|
550
550
|
inner_state=inner_state,
|
|
551
551
|
received_state_on_reset=True,
|
|
@@ -21,8 +21,8 @@ def test_init_reset_delegate_unchanged():
|
|
|
21
21
|
assert jnp.allclose(state_w, state_e)
|
|
22
22
|
assert jnp.allclose(info_w.obs, info_e.obs)
|
|
23
23
|
|
|
24
|
-
state_w, info_w = w.reset(
|
|
25
|
-
state_e, info_e = env.reset(
|
|
24
|
+
state_w, info_w = w.reset(state_w, key)
|
|
25
|
+
state_e, info_e = env.reset(state_e, key)
|
|
26
26
|
assert jnp.allclose(state_w, state_e)
|
|
27
27
|
assert jnp.allclose(info_w.obs, info_e.obs)
|
|
28
28
|
assert w.observation_space.contains(info_w.obs)
|
{jax_envelope-0.2.0 → jax_envelope-0.3.0}/tests/wrappers/test_continuous_observation_wrapper.py
RENAMED
|
@@ -17,7 +17,7 @@ def test_init_reset_step_cast_discrete_obs_to_float32():
|
|
|
17
17
|
key = jax.random.PRNGKey(0)
|
|
18
18
|
state, info = w.init(key)
|
|
19
19
|
assert info.obs.dtype == jnp.float32
|
|
20
|
-
state, info = w.reset(
|
|
20
|
+
state, info = w.reset(state, key)
|
|
21
21
|
assert info.obs.dtype == jnp.float32
|
|
22
22
|
assert w.observation_space.contains(info.obs)
|
|
23
23
|
state, info = w.step(state, jnp.array(0, dtype=jnp.int32))
|
|
@@ -66,7 +66,7 @@ def test_reset_preserves_stats():
|
|
|
66
66
|
state, _ = w.step(state, jnp.asarray(0.2))
|
|
67
67
|
reward_before = state.stats.reward
|
|
68
68
|
length_before = state.stats.length
|
|
69
|
-
state, info = w.reset(
|
|
69
|
+
state, info = w.reset(state, key)
|
|
70
70
|
assert jnp.allclose(state.stats.reward, reward_before)
|
|
71
71
|
assert jnp.allclose(state.stats.length, length_before)
|
|
72
72
|
assert jnp.allclose(info.stats.reward, reward_before)
|
|
@@ -81,7 +81,7 @@ def test_stats_persist_and_continue_after_reset():
|
|
|
81
81
|
state, _ = w.init(key)
|
|
82
82
|
for _ in range(3):
|
|
83
83
|
state, _ = w.step(state, jnp.asarray(0.1))
|
|
84
|
-
state, _ = w.reset(
|
|
84
|
+
state, _ = w.reset(state, key)
|
|
85
85
|
for _ in range(2):
|
|
86
86
|
state, _ = w.step(state, jnp.asarray(0.1))
|
|
87
87
|
# Total length = 3 + 2 = 5, reward = 0.1*5 = 0.5
|
|
@@ -60,8 +60,8 @@ def test_init_reset_delegate_unchanged():
|
|
|
60
60
|
state_e, info_e = env.init(key)
|
|
61
61
|
assert jnp.allclose(state_w, state_e)
|
|
62
62
|
assert jnp.allclose(info_w.obs, info_e.obs)
|
|
63
|
-
state_w, info_w = w.reset(
|
|
64
|
-
state_e, info_e = env.reset(
|
|
63
|
+
state_w, info_w = w.reset(state_w, key)
|
|
64
|
+
state_e, info_e = env.reset(state_e, key)
|
|
65
65
|
assert jnp.allclose(state_w, state_e)
|
|
66
66
|
assert jnp.allclose(info_w.obs, info_e.obs)
|
|
67
67
|
|
|
@@ -101,7 +101,7 @@ def test_action_space_flattened_discrete():
|
|
|
101
101
|
obs=s, reward=0.0, terminated=False, truncated=False
|
|
102
102
|
)
|
|
103
103
|
|
|
104
|
-
def reset(self,
|
|
104
|
+
def reset(self, state, key):
|
|
105
105
|
return self.init(key)
|
|
106
106
|
|
|
107
107
|
def step(self, state, action):
|
|
@@ -161,7 +161,7 @@ def test_mixed_space_types_raises_value_error():
|
|
|
161
161
|
obs=s, reward=0.0, terminated=False, truncated=False
|
|
162
162
|
)
|
|
163
163
|
|
|
164
|
-
def reset(self,
|
|
164
|
+
def reset(self, state, key):
|
|
165
165
|
return self.init(key)
|
|
166
166
|
|
|
167
167
|
def step(self, state, action):
|
{jax_envelope-0.2.0 → jax_envelope-0.3.0}/tests/wrappers/test_flatten_observation_wrapper.py
RENAMED
|
@@ -29,7 +29,7 @@ def test_reset_step_flatten_pytree_obs():
|
|
|
29
29
|
key = jax.random.PRNGKey(0)
|
|
30
30
|
state, info = w.init(key)
|
|
31
31
|
assert info.obs.shape == (5,)
|
|
32
|
-
state, info = w.reset(
|
|
32
|
+
state, info = w.reset(state, key)
|
|
33
33
|
assert info.obs.shape == (5,)
|
|
34
34
|
assert w.observation_space.contains(info.obs)
|
|
35
35
|
state, info = w.step(state, jnp.array(0.0))
|
|
@@ -126,7 +126,7 @@ def test_mixed_space_types_raises_value_error():
|
|
|
126
126
|
obs=obs, reward=0.0, terminated=False, truncated=False
|
|
127
127
|
)
|
|
128
128
|
|
|
129
|
-
def reset(self,
|
|
129
|
+
def reset(self, state, key):
|
|
130
130
|
return self.init(key)
|
|
131
131
|
|
|
132
132
|
def step(self, state, action):
|
|
@@ -108,7 +108,7 @@ def test_reset_vmaps_inner_reset():
|
|
|
108
108
|
w = PooledInitVmapWrapper(env=env, batch_size=batch_size, pool_size=3)
|
|
109
109
|
key = jax.random.PRNGKey(0)
|
|
110
110
|
state, info = w.init(key)
|
|
111
|
-
state, info = w.reset(
|
|
111
|
+
state, info = w.reset(state, key)
|
|
112
112
|
assert info.obs.shape == (batch_size,)
|
|
113
113
|
assert w.observation_space.contains(info.obs)
|
|
114
114
|
|
|
@@ -66,7 +66,7 @@ class TestStateInjectionCoreFunctionality:
|
|
|
66
66
|
|
|
67
67
|
# Reset again, passing the current state (simulates auto-reset)
|
|
68
68
|
key2 = jax.random.PRNGKey(1)
|
|
69
|
-
state2, info2 = w.reset(
|
|
69
|
+
state2, info2 = w.reset(state, key2)
|
|
70
70
|
|
|
71
71
|
# Should preserve the injected state
|
|
72
72
|
assert jnp.allclose(state2.reset_state.env_state, jnp.array(42.0))
|
|
@@ -132,7 +132,7 @@ class TestStateInjectionCoreFunctionality:
|
|
|
132
132
|
|
|
133
133
|
# Reset with this state (no reset_state set) - should delegate to inner env
|
|
134
134
|
key2 = jax.random.PRNGKey(1)
|
|
135
|
-
state2, info2 = w.reset(
|
|
135
|
+
state2, info2 = w.reset(state, key2)
|
|
136
136
|
|
|
137
137
|
# Should have done a normal reset - inner_state is fresh from env
|
|
138
138
|
assert jnp.allclose(state2.inner_state.env_state, jnp.array(0.0))
|
|
@@ -166,7 +166,7 @@ class TestStateInjectionCoreFunctionality:
|
|
|
166
166
|
)
|
|
167
167
|
|
|
168
168
|
with pytest.raises(ValueError, match="must set both"):
|
|
169
|
-
w.reset(
|
|
169
|
+
w.reset(state_with_only_reset_state, key)
|
|
170
170
|
|
|
171
171
|
# Create state with only reset_obs set (not reset_state)
|
|
172
172
|
state_with_only_reset_obs = w.InjectedState(
|
|
@@ -176,7 +176,7 @@ class TestStateInjectionCoreFunctionality:
|
|
|
176
176
|
)
|
|
177
177
|
|
|
178
178
|
with pytest.raises(ValueError, match="must set both"):
|
|
179
|
-
w.reset(
|
|
179
|
+
w.reset(state_with_only_reset_obs, key)
|
|
180
180
|
|
|
181
181
|
|
|
182
182
|
# ============================================================================
|
|
@@ -116,7 +116,7 @@ def test_steps_as_jax_scalar_array_behaves_correctly():
|
|
|
116
116
|
|
|
117
117
|
|
|
118
118
|
def test_reset_with_state_passes_inner_state_down():
|
|
119
|
-
"""reset(
|
|
119
|
+
"""reset(state, key) should pass state.inner_state to the inner env's reset."""
|
|
120
120
|
env = StepCounterEnv()
|
|
121
121
|
w = TruncationWrapper(env=env, max_steps=10)
|
|
122
122
|
key = jax.random.PRNGKey(0)
|
|
@@ -126,7 +126,7 @@ def test_reset_with_state_passes_inner_state_down():
|
|
|
126
126
|
state, _ = w.step(state, jnp.asarray(0.1))
|
|
127
127
|
assert state.steps == 5
|
|
128
128
|
|
|
129
|
-
new_state, _ = w.reset(jax.random.PRNGKey(1)
|
|
129
|
+
new_state, _ = w.reset(state, jax.random.PRNGKey(1))
|
|
130
130
|
|
|
131
131
|
# Inner env should be reset
|
|
132
132
|
assert jnp.allclose(new_state.inner_state.env_state, 0.0)
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{jax_envelope-0.2.0 → jax_envelope-0.3.0}/tests/wrappers/test_observation_normalization_wrapper.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|