jax-envelope 0.3.0__tar.gz → 0.4.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {jax_envelope-0.3.0 → jax_envelope-0.4.0}/.github/workflows/publish.yml +1 -1
- {jax_envelope-0.3.0 → jax_envelope-0.4.0}/PKG-INFO +7 -8
- {jax_envelope-0.3.0 → jax_envelope-0.4.0}/README.md +6 -7
- {jax_envelope-0.3.0 → jax_envelope-0.4.0}/pyproject.toml +3 -3
- {jax_envelope-0.3.0 → jax_envelope-0.4.0}/src/envelope/__init__.py +1 -1
- {jax_envelope-0.3.0/src/envelope/compat → jax_envelope-0.4.0/src/envelope/adapters}/__init__.py +7 -7
- {jax_envelope-0.3.0/src/envelope/compat → jax_envelope-0.4.0/src/envelope/adapters}/craftax_envelope.py +1 -1
- {jax_envelope-0.3.0/src/envelope/compat → jax_envelope-0.4.0/src/envelope/adapters}/kinetix_envelope.py +2 -2
- {jax_envelope-0.3.0/tests/compat → jax_envelope-0.4.0/tests/adapters}/conftest.py +1 -1
- {jax_envelope-0.3.0/tests/compat → jax_envelope-0.4.0/tests/adapters}/contract.py +2 -2
- {jax_envelope-0.3.0/tests/compat → jax_envelope-0.4.0/tests/adapters}/test_brax_compat.py +4 -4
- {jax_envelope-0.3.0/tests/compat → jax_envelope-0.4.0/tests/adapters}/test_craftax_compat.py +5 -5
- {jax_envelope-0.3.0/tests/compat → jax_envelope-0.4.0/tests/adapters}/test_create.py +7 -7
- {jax_envelope-0.3.0/tests/compat → jax_envelope-0.4.0/tests/adapters}/test_create_integration.py +12 -12
- {jax_envelope-0.3.0/tests/compat → jax_envelope-0.4.0/tests/adapters}/test_gymnax_compat.py +2 -2
- {jax_envelope-0.3.0/tests/compat → jax_envelope-0.4.0/tests/adapters}/test_jumanji_compat.py +5 -5
- {jax_envelope-0.3.0/tests/compat → jax_envelope-0.4.0/tests/adapters}/test_kinetix_compat.py +7 -8
- {jax_envelope-0.3.0/tests/compat → jax_envelope-0.4.0/tests/adapters}/test_mujoco_playground_compat.py +4 -4
- {jax_envelope-0.3.0/tests/compat → jax_envelope-0.4.0/tests/adapters}/test_navix_compat.py +8 -8
- {jax_envelope-0.3.0 → jax_envelope-0.4.0}/uv.lock +3 -3
- {jax_envelope-0.3.0 → jax_envelope-0.4.0}/.gitignore +0 -0
- {jax_envelope-0.3.0 → jax_envelope-0.4.0}/LICENSE +0 -0
- {jax_envelope-0.3.0/src/envelope/compat → jax_envelope-0.4.0/src/envelope/adapters}/brax_envelope.py +0 -0
- {jax_envelope-0.3.0/src/envelope/compat → jax_envelope-0.4.0/src/envelope/adapters}/gymnax_envelope.py +0 -0
- {jax_envelope-0.3.0/src/envelope/compat → jax_envelope-0.4.0/src/envelope/adapters}/jumanji_envelope.py +0 -0
- {jax_envelope-0.3.0/src/envelope/compat → jax_envelope-0.4.0/src/envelope/adapters}/mujoco_playground_envelope.py +0 -0
- {jax_envelope-0.3.0/src/envelope/compat → jax_envelope-0.4.0/src/envelope/adapters}/navix_envelope.py +0 -0
- {jax_envelope-0.3.0 → jax_envelope-0.4.0}/src/envelope/environment.py +0 -0
- {jax_envelope-0.3.0 → jax_envelope-0.4.0}/src/envelope/spaces.py +0 -0
- {jax_envelope-0.3.0 → jax_envelope-0.4.0}/src/envelope/struct.py +0 -0
- {jax_envelope-0.3.0 → jax_envelope-0.4.0}/src/envelope/typing.py +0 -0
- {jax_envelope-0.3.0 → jax_envelope-0.4.0}/src/envelope/wrappers/__init__.py +0 -0
- {jax_envelope-0.3.0 → jax_envelope-0.4.0}/src/envelope/wrappers/autoreset_wrapper.py +0 -0
- {jax_envelope-0.3.0 → jax_envelope-0.4.0}/src/envelope/wrappers/clip_action_wrapper.py +0 -0
- {jax_envelope-0.3.0 → jax_envelope-0.4.0}/src/envelope/wrappers/continuous_observation_wrapper.py +0 -0
- {jax_envelope-0.3.0 → jax_envelope-0.4.0}/src/envelope/wrappers/episode_statistics_wrapper.py +0 -0
- {jax_envelope-0.3.0 → jax_envelope-0.4.0}/src/envelope/wrappers/flatten_action_wrapper.py +0 -0
- {jax_envelope-0.3.0 → jax_envelope-0.4.0}/src/envelope/wrappers/flatten_observation_wrapper.py +0 -0
- {jax_envelope-0.3.0 → jax_envelope-0.4.0}/src/envelope/wrappers/normalization.py +0 -0
- {jax_envelope-0.3.0 → jax_envelope-0.4.0}/src/envelope/wrappers/observation_normalization_wrapper.py +0 -0
- {jax_envelope-0.3.0 → jax_envelope-0.4.0}/src/envelope/wrappers/pooled_init_vmap_wrapper.py +0 -0
- {jax_envelope-0.3.0 → jax_envelope-0.4.0}/src/envelope/wrappers/state_injection_wrapper.py +0 -0
- {jax_envelope-0.3.0 → jax_envelope-0.4.0}/src/envelope/wrappers/truncation_wrapper.py +0 -0
- {jax_envelope-0.3.0 → jax_envelope-0.4.0}/src/envelope/wrappers/vmap_envs_wrapper.py +0 -0
- {jax_envelope-0.3.0 → jax_envelope-0.4.0}/src/envelope/wrappers/vmap_wrapper.py +0 -0
- {jax_envelope-0.3.0 → jax_envelope-0.4.0}/src/envelope/wrappers/wrapper.py +0 -0
- {jax_envelope-0.3.0 → jax_envelope-0.4.0}/tests/__init__.py +0 -0
- {jax_envelope-0.3.0/tests/compat → jax_envelope-0.4.0/tests/adapters}/__init__.py +0 -0
- {jax_envelope-0.3.0 → jax_envelope-0.4.0}/tests/spaces/__init__.py +0 -0
- {jax_envelope-0.3.0 → jax_envelope-0.4.0}/tests/spaces/test_batched_space.py +0 -0
- {jax_envelope-0.3.0 → jax_envelope-0.4.0}/tests/spaces/test_continuous.py +0 -0
- {jax_envelope-0.3.0 → jax_envelope-0.4.0}/tests/spaces/test_discrete.py +0 -0
- {jax_envelope-0.3.0 → jax_envelope-0.4.0}/tests/spaces/test_pytree_space.py +0 -0
- {jax_envelope-0.3.0 → jax_envelope-0.4.0}/tests/spaces/test_serialization.py +0 -0
- {jax_envelope-0.3.0 → jax_envelope-0.4.0}/tests/test_container.py +0 -0
- {jax_envelope-0.3.0 → jax_envelope-0.4.0}/tests/test_struct.py +0 -0
- {jax_envelope-0.3.0 → jax_envelope-0.4.0}/tests/wrappers/__init__.py +0 -0
- {jax_envelope-0.3.0 → jax_envelope-0.4.0}/tests/wrappers/helpers.py +0 -0
- {jax_envelope-0.3.0 → jax_envelope-0.4.0}/tests/wrappers/test_autoreset_wrapper.py +0 -0
- {jax_envelope-0.3.0 → jax_envelope-0.4.0}/tests/wrappers/test_clip_action_wrapper.py +0 -0
- {jax_envelope-0.3.0 → jax_envelope-0.4.0}/tests/wrappers/test_continuous_observation_wrapper.py +0 -0
- {jax_envelope-0.3.0 → jax_envelope-0.4.0}/tests/wrappers/test_environment_wrapper.py +0 -0
- {jax_envelope-0.3.0 → jax_envelope-0.4.0}/tests/wrappers/test_episode_statistics_wrapper.py +0 -0
- {jax_envelope-0.3.0 → jax_envelope-0.4.0}/tests/wrappers/test_flatten_action_wrapper.py +0 -0
- {jax_envelope-0.3.0 → jax_envelope-0.4.0}/tests/wrappers/test_flatten_observation_wrapper.py +0 -0
- {jax_envelope-0.3.0 → jax_envelope-0.4.0}/tests/wrappers/test_normalization.py +0 -0
- {jax_envelope-0.3.0 → jax_envelope-0.4.0}/tests/wrappers/test_observation_normalization_wrapper.py +0 -0
- {jax_envelope-0.3.0 → jax_envelope-0.4.0}/tests/wrappers/test_pooled_init_vmap_wrapper.py +0 -0
- {jax_envelope-0.3.0 → jax_envelope-0.4.0}/tests/wrappers/test_state_injection_wrapper.py +0 -0
- {jax_envelope-0.3.0 → jax_envelope-0.4.0}/tests/wrappers/test_truncation_wrapper.py +0 -0
- {jax_envelope-0.3.0 → jax_envelope-0.4.0}/tests/wrappers/test_vmap_envs_wrapper.py +0 -0
- {jax_envelope-0.3.0 → jax_envelope-0.4.0}/tests/wrappers/test_vmap_wrapper.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: jax-envelope
|
|
3
|
-
Version: 0.3.0
|
|
3
|
+
Version: 0.4.0
|
|
4
4
|
Summary: A JAX-native environment interface with powerful wrappers and adapters for popular RL environment suites
|
|
5
5
|
Project-URL: Homepage, https://github.com/keraJLi/envelope
|
|
6
6
|
Project-URL: Repository, https://github.com/keraJLi/envelope
|
|
@@ -42,6 +42,7 @@ env = envelope.wrappers.ObservationNormalizationWrapper(env)
|
|
|
42
42
|
```
|
|
43
43
|
|
|
44
44
|
## 🌍 Simple, expressive interaction!
|
|
45
|
+
|
|
45
46
|
- **Environments are pytrees**. Squish them through JAX transformations and trace their parameters.
|
|
46
47
|
- **Idiomatic jax-y interface** of `init(key: Key) -> State, Info` and `step(state: State, action: PyTree) -> State, Info`. You can directly `jax.scan` over a `step(...)`!
|
|
47
48
|
- **Spaces are super simple**. No `Tuple`, `Dict` nonsense! There are two spaces: `Continuous` and `Discrete`, which you can compose into a `PyTreeSpace`.
|
|
@@ -53,8 +54,6 @@ env = envelope.wrappers.ObservationNormalizationWrapper(env)
|
|
|
53
54
|
- **Carry state across episodes** to track running statistics, for example to normalize observations.
|
|
54
55
|
- **Composable wrappers** can be stacked in any order. For example, `ObservationNormalizationWrapper` before vs. after `VmapWrapper` gives per-env vs. global normalization.
|
|
55
56
|
|
|
56
|
-
|
|
57
|
-
|
|
58
57
|
## 🔌 Adapters for existing suites
|
|
59
58
|
|
|
60
59
|
|
|
@@ -78,11 +77,11 @@ let's you create environments from any of the above!
|
|
|
78
77
|
|
|
79
78
|
## 📝 Testing
|
|
80
79
|
|
|
81
|
-
- **Default (no optional
|
|
82
|
-
- **
|
|
83
|
-
- `uv sync --group
|
|
84
|
-
- `uv run pytest -m
|
|
85
|
-
- If any
|
|
80
|
+
- **Default (no optional adapters deps required)**: `uv run pytest -m "not adapters"`
|
|
81
|
+
- **Adapters suite (requires full adapters dependency group)**:
|
|
82
|
+
- `uv sync --group adapters`
|
|
83
|
+
- `uv run pytest -m adapters`
|
|
84
|
+
- If any adapter dependency is missing/broken, the run will fail fast with an error telling you what to install.
|
|
86
85
|
|
|
87
86
|
## 🏗️ Installation
|
|
88
87
|
|
|
@@ -16,6 +16,7 @@ env = envelope.wrappers.ObservationNormalizationWrapper(env)
|
|
|
16
16
|
```
|
|
17
17
|
|
|
18
18
|
## 🌍 Simple, expressive interaction!
|
|
19
|
+
|
|
19
20
|
- **Environments are pytrees**. Squish them through JAX transformations and trace their parameters.
|
|
20
21
|
- **Idiomatic jax-y interface** of `init(key: Key) -> State, Info` and `step(state: State, action: PyTree) -> State, Info`. You can directly `jax.scan` over a `step(...)`!
|
|
21
22
|
- **Spaces are super simple**. No `Tuple`, `Dict` nonsense! There are two spaces: `Continuous` and `Discrete`, which you can compose into a `PyTreeSpace`.
|
|
@@ -27,8 +28,6 @@ env = envelope.wrappers.ObservationNormalizationWrapper(env)
|
|
|
27
28
|
- **Carry state across episodes** to track running statistics, for example to normalize observations.
|
|
28
29
|
- **Composable wrappers** can be stacked in any order. For example, `ObservationNormalizationWrapper` before vs. after `VmapWrapper` gives per-env vs. global normalization.
|
|
29
30
|
|
|
30
|
-
|
|
31
|
-
|
|
32
31
|
## 🔌 Adapters for existing suites
|
|
33
32
|
|
|
34
33
|
|
|
@@ -52,11 +51,11 @@ let's you create environments from any of the above!
|
|
|
52
51
|
|
|
53
52
|
## 📝 Testing
|
|
54
53
|
|
|
55
|
-
- **Default (no optional
|
|
56
|
-
- **
|
|
57
|
-
- `uv sync --group
|
|
58
|
-
- `uv run pytest -m
|
|
59
|
-
- If any
|
|
54
|
+
- **Default (no optional adapters deps required)**: `uv run pytest -m "not adapters"`
|
|
55
|
+
- **Adapters suite (requires full adapters dependency group)**:
|
|
56
|
+
- `uv sync --group adapters`
|
|
57
|
+
- `uv run pytest -m adapters`
|
|
58
|
+
- If any adapter dependency is missing/broken, the run will fail fast with an error telling you what to install.
|
|
60
59
|
|
|
61
60
|
## 🏗️ Installation
|
|
62
61
|
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
[project]
|
|
2
2
|
name = "jax-envelope"
|
|
3
|
-
version = "0.3.0"
|
|
3
|
+
version = "0.4.0"
|
|
4
4
|
description = "A JAX-native environment interface with powerful wrappers and adapters for popular RL environment suites"
|
|
5
5
|
readme = "README.md"
|
|
6
6
|
requires-python = ">=3.12"
|
|
@@ -52,7 +52,7 @@ allow-direct-references = true
|
|
|
52
52
|
packages = ["src/envelope"]
|
|
53
53
|
|
|
54
54
|
[dependency-groups]
|
|
55
|
-
|
|
55
|
+
adapters = [
|
|
56
56
|
"brax>=0.13.0",
|
|
57
57
|
"craftax>=1.4.3",
|
|
58
58
|
"navix>=0.7.0",
|
|
@@ -80,5 +80,5 @@ kinetix-env = { git = "https://github.com/FLAIROx/Kinetix.git" }
|
|
|
80
80
|
[tool.pytest.ini_options]
|
|
81
81
|
testpaths = ["tests"]
|
|
82
82
|
markers = [
|
|
83
|
-
"
|
|
83
|
+
"adapters: tests requiring optional adapters dependencies",
|
|
84
84
|
]
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
from envelope.compat import create
|
|
1
|
+
from envelope.adapters import create
|
|
2
2
|
from envelope.environment import Environment, Info, InfoContainer
|
|
3
3
|
from envelope.spaces import BatchedSpace, Continuous, Discrete, PyTreeSpace, Space
|
|
4
4
|
from envelope.struct import Container, FrozenPyTreeNode, field, static_field
|
{jax_envelope-0.3.0/src/envelope/compat → jax_envelope-0.4.0/src/envelope/adapters}/__init__.py
RENAMED
|
@@ -4,14 +4,14 @@ from typing import Any, Protocol, Self
|
|
|
4
4
|
|
|
5
5
|
# Lazy imports to avoid requiring all dependencies at once
|
|
6
6
|
_env_module_map = {
|
|
7
|
-
"gymnax": ("envelope.
|
|
8
|
-
"brax": ("envelope.
|
|
9
|
-
"navix": ("envelope.
|
|
10
|
-
"jumanji": ("envelope.
|
|
11
|
-
"kinetix": ("envelope.
|
|
12
|
-
"craftax": ("envelope.
|
|
7
|
+
"gymnax": ("envelope.adapters.gymnax_envelope", "GymnaxEnvelope"),
|
|
8
|
+
"brax": ("envelope.adapters.brax_envelope", "BraxEnvelope"),
|
|
9
|
+
"navix": ("envelope.adapters.navix_envelope", "NavixEnvelope"),
|
|
10
|
+
"jumanji": ("envelope.adapters.jumanji_envelope", "JumanjiEnvelope"),
|
|
11
|
+
"kinetix": ("envelope.adapters.kinetix_envelope", "KinetixEnvelope"),
|
|
12
|
+
"craftax": ("envelope.adapters.craftax_envelope", "CraftaxEnvelope"),
|
|
13
13
|
"mujoco_playground": (
|
|
14
|
-
"envelope.
|
|
14
|
+
"envelope.adapters.mujoco_playground_envelope",
|
|
15
15
|
"MujocoPlaygroundEnvelope",
|
|
16
16
|
),
|
|
17
17
|
}
|
|
@@ -10,7 +10,7 @@ from craftax.craftax_classic.envs.craftax_state import (
|
|
|
10
10
|
from craftax.craftax_env import make_craftax_env_from_name
|
|
11
11
|
|
|
12
12
|
from envelope import spaces as envelope_spaces
|
|
13
|
-
from envelope.compat.gymnax_envelope import _convert_space as _convert_gymnax_space
|
|
13
|
+
from envelope.adapters.gymnax_envelope import _convert_space as _convert_gymnax_space
|
|
14
14
|
from envelope.environment import Environment, Info, InfoContainer, State
|
|
15
15
|
from envelope.struct import Container, static_field
|
|
16
16
|
from envelope.typing import Key, PyTree, TypeAlias
|
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
"""Kinetix compatibility wrapper.
|
|
2
2
|
|
|
3
3
|
This module exposes Kinetix environments through the `envelope.environment.Environment`
|
|
4
|
-
API. It mirrors envelope's
|
|
4
|
+
API. It mirrors envelope's adapters philosophy:
|
|
5
5
|
- prefer *no* environment-side auto-reset (use `AutoResetWrapper` in envelope)
|
|
6
6
|
- prefer *no* fixed episode time-limits (use `TruncationWrapper` in envelope)
|
|
7
7
|
|
|
@@ -30,7 +30,7 @@ from kinetix.util.saving import load_from_json_file
|
|
|
30
30
|
|
|
31
31
|
from envelope import field
|
|
32
32
|
from envelope import spaces as envelope_spaces
|
|
33
|
-
from envelope.compat.gymnax_envelope import _convert_space as _convert_gymnax_space
|
|
33
|
+
from envelope.adapters.gymnax_envelope import _convert_space as _convert_gymnax_space
|
|
34
34
|
from envelope.environment import Environment, Info, InfoContainer, State
|
|
35
35
|
from envelope.struct import Container, static_field
|
|
36
36
|
from envelope.typing import Key, PyTree
|
|
@@ -1,6 +1,6 @@
|
|
|
1
|
-
"""Shared contract helpers for
|
|
1
|
+
"""Shared contract helpers for adapters.
|
|
2
2
|
|
|
3
|
-
These functions enforce a consistent baseline across all
|
|
3
|
+
These functions enforce a consistent baseline across all adapters:
|
|
4
4
|
- reset/step return (state, info) with Info fields present
|
|
5
5
|
- reward is scalar-ish and finite
|
|
6
6
|
- action sampling is valid for action_space
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
"""Tests for envelope.
|
|
1
|
+
"""Tests for envelope.adapters.brax_envelope module."""
|
|
2
2
|
|
|
3
3
|
# ruff: noqa: E402
|
|
4
4
|
|
|
@@ -7,14 +7,14 @@ from copy import deepcopy
|
|
|
7
7
|
import jax
|
|
8
8
|
import pytest
|
|
9
9
|
|
|
10
|
-
pytestmark = pytest.mark.
|
|
10
|
+
pytestmark = pytest.mark.adapters
|
|
11
11
|
|
|
12
12
|
pytest.importorskip("brax")
|
|
13
13
|
|
|
14
14
|
from brax.envs import Wrapper as BraxWrapper
|
|
15
15
|
|
|
16
|
-
from envelope.
|
|
17
|
-
from tests.
|
|
16
|
+
from envelope.adapters.brax_envelope import BraxEnvelope
|
|
17
|
+
from tests.adapters.contract import (
|
|
18
18
|
assert_jitted_rollout_contract,
|
|
19
19
|
assert_reset_step_contract,
|
|
20
20
|
)
|
{jax_envelope-0.3.0/tests/compat → jax_envelope-0.4.0/tests/adapters}/test_craftax_compat.py
RENAMED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
"""Tests for envelope.
|
|
1
|
+
"""Tests for envelope.adapters.craftax_envelope module."""
|
|
2
2
|
|
|
3
3
|
# ruff: noqa: E402
|
|
4
4
|
|
|
@@ -8,12 +8,12 @@ import jax
|
|
|
8
8
|
import jax.numpy as jnp
|
|
9
9
|
import pytest
|
|
10
10
|
|
|
11
|
-
pytestmark = pytest.mark.
|
|
11
|
+
pytestmark = pytest.mark.adapters
|
|
12
12
|
|
|
13
13
|
pytest.importorskip("craftax")
|
|
14
14
|
|
|
15
15
|
from envelope.spaces import Continuous, Discrete
|
|
16
|
-
from tests.
|
|
16
|
+
from tests.adapters.contract import (
|
|
17
17
|
assert_jitted_rollout_contract,
|
|
18
18
|
assert_reset_step_contract,
|
|
19
19
|
)
|
|
@@ -35,7 +35,7 @@ def craftax_env_id(request: pytest.FixtureRequest) -> str:
|
|
|
35
35
|
|
|
36
36
|
@pytest.fixture(scope="module")
|
|
37
37
|
def craftax_env(craftax_env_id: str):
|
|
38
|
-
from envelope.
|
|
38
|
+
from envelope.adapters.craftax_envelope import CraftaxEnvelope
|
|
39
39
|
|
|
40
40
|
return CraftaxEnvelope.from_name(craftax_env_id)
|
|
41
41
|
|
|
@@ -114,7 +114,7 @@ class _DummyEnv:
|
|
|
114
114
|
|
|
115
115
|
|
|
116
116
|
def test_from_name_errors_on_auto_reset():
|
|
117
|
-
from envelope.
|
|
117
|
+
from envelope.adapters.craftax_envelope import CraftaxEnvelope
|
|
118
118
|
|
|
119
119
|
with pytest.raises(ValueError, match="Cannot override 'auto_reset' directly"):
|
|
120
120
|
CraftaxEnvelope.from_name("AnyEnv", env_kwargs={"auto_reset": True})
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
"""Unit tests for envelope.
|
|
1
|
+
"""Unit tests for envelope.adapters.create() factory function.
|
|
2
2
|
|
|
3
3
|
These tests are dependency-free (no brax/gymnax/navix imports) and focus on:
|
|
4
4
|
- parsing/validation of the env id
|
|
@@ -12,8 +12,8 @@ import types
|
|
|
12
12
|
|
|
13
13
|
import pytest
|
|
14
14
|
|
|
15
|
-
import envelope.
|
|
16
|
-
from envelope.
|
|
15
|
+
import envelope.adapters as adapters
|
|
16
|
+
from envelope.adapters import create
|
|
17
17
|
|
|
18
18
|
|
|
19
19
|
def _install_dummy_suite(
|
|
@@ -44,7 +44,7 @@ def _install_dummy_suite(
|
|
|
44
44
|
raise AssertionError(f"Unexpected import: {name}")
|
|
45
45
|
return dummy_module
|
|
46
46
|
|
|
47
|
-
monkeypatch.setattr(
|
|
47
|
+
monkeypatch.setattr(adapters, "_env_module_map", {suite: (module_name, class_name)})
|
|
48
48
|
monkeypatch.setattr(importlib, "import_module", fake_import_module)
|
|
49
49
|
|
|
50
50
|
return import_calls, from_name_calls
|
|
@@ -77,7 +77,7 @@ def test_create_unknown_suite_mentions_available_suites(
|
|
|
77
77
|
):
|
|
78
78
|
# Keep the map deterministic so we can assert it appears in the message.
|
|
79
79
|
monkeypatch.setattr(
|
|
80
|
-
|
|
80
|
+
adapters, "_env_module_map", {"dummy": ("dummy_mod", "DummyWrapper")}
|
|
81
81
|
)
|
|
82
82
|
|
|
83
83
|
with pytest.raises(ValueError) as excinfo:
|
|
@@ -91,7 +91,7 @@ def test_create_unknown_suite_mentions_available_suites(
|
|
|
91
91
|
|
|
92
92
|
def test_create_wraps_import_error_and_chains_cause(monkeypatch):
|
|
93
93
|
monkeypatch.setattr(
|
|
94
|
-
|
|
94
|
+
adapters, "_env_module_map", {"dummy": ("dummy_mod", "DummyWrapper")}
|
|
95
95
|
)
|
|
96
96
|
|
|
97
97
|
def fake_import_module(name: str):
|
|
@@ -164,7 +164,7 @@ def test_create_imports_only_the_requested_suite(monkeypatch):
|
|
|
164
164
|
module_b = types.SimpleNamespace(WrapperB=WrapperB)
|
|
165
165
|
|
|
166
166
|
monkeypatch.setattr(
|
|
167
|
-
|
|
167
|
+
adapters,
|
|
168
168
|
"_env_module_map",
|
|
169
169
|
{"a": ("a_mod", "WrapperA"), "b": ("b_mod", "WrapperB")},
|
|
170
170
|
)
|
{jax_envelope-0.3.0/tests/compat → jax_envelope-0.4.0/tests/adapters}/test_create_integration.py
RENAMED
|
@@ -1,22 +1,22 @@
|
|
|
1
|
-
"""Integration tests for envelope.
|
|
1
|
+
"""Integration tests for envelope.adapters.create().
|
|
2
2
|
|
|
3
|
-
These require optional
|
|
4
|
-
|
|
3
|
+
These require optional adapter dependencies (brax/gymnax/navix). They are kept separate
|
|
4
|
+
from the unit tests so a minimal install can still run the suite.
|
|
5
5
|
"""
|
|
6
6
|
|
|
7
7
|
import pytest
|
|
8
8
|
|
|
9
|
-
from envelope.
|
|
9
|
+
from envelope.adapters import create
|
|
10
10
|
from envelope.environment import Environment
|
|
11
11
|
from envelope.wrappers.truncation_wrapper import TruncationWrapper
|
|
12
12
|
|
|
13
|
-
pytestmark = pytest.mark.
|
|
13
|
+
pytestmark = pytest.mark.adapters
|
|
14
14
|
|
|
15
15
|
|
|
16
16
|
def test_create_brax_smoke(prng_key):
|
|
17
17
|
pytest.importorskip("brax")
|
|
18
18
|
|
|
19
|
-
from envelope.
|
|
19
|
+
from envelope.adapters.brax_envelope import BraxEnvelope
|
|
20
20
|
|
|
21
21
|
env = create("brax::fast")
|
|
22
22
|
assert isinstance(env, TruncationWrapper)
|
|
@@ -31,7 +31,7 @@ def test_create_brax_smoke(prng_key):
|
|
|
31
31
|
def test_create_gymnax_smoke(prng_key):
|
|
32
32
|
pytest.importorskip("gymnax")
|
|
33
33
|
|
|
34
|
-
from envelope.
|
|
34
|
+
from envelope.adapters.gymnax_envelope import GymnaxEnvelope
|
|
35
35
|
|
|
36
36
|
env = create("gymnax::CartPole-v1")
|
|
37
37
|
assert isinstance(env, TruncationWrapper)
|
|
@@ -46,7 +46,7 @@ def test_create_gymnax_smoke(prng_key):
|
|
|
46
46
|
def test_create_navix_smoke(prng_key):
|
|
47
47
|
pytest.importorskip("navix")
|
|
48
48
|
|
|
49
|
-
from envelope.
|
|
49
|
+
from envelope.adapters.navix_envelope import NavixEnvelope
|
|
50
50
|
|
|
51
51
|
env = create("navix::Navix-Empty-5x5-v0")
|
|
52
52
|
assert isinstance(env, TruncationWrapper)
|
|
@@ -61,7 +61,7 @@ def test_create_navix_smoke(prng_key):
|
|
|
61
61
|
def test_create_jumanji_smoke(prng_key):
|
|
62
62
|
pytest.importorskip("jumanji")
|
|
63
63
|
|
|
64
|
-
from envelope.
|
|
64
|
+
from envelope.adapters.jumanji_envelope import JumanjiEnvelope
|
|
65
65
|
|
|
66
66
|
env = create("jumanji::Snake-v1")
|
|
67
67
|
assert isinstance(env, TruncationWrapper)
|
|
@@ -76,7 +76,7 @@ def test_create_jumanji_smoke(prng_key):
|
|
|
76
76
|
def test_create_craftax_smoke(prng_key):
|
|
77
77
|
pytest.importorskip("craftax")
|
|
78
78
|
|
|
79
|
-
from envelope.
|
|
79
|
+
from envelope.adapters.craftax_envelope import CraftaxEnvelope
|
|
80
80
|
|
|
81
81
|
env = create("craftax::Craftax-Symbolic-v1")
|
|
82
82
|
assert isinstance(env, TruncationWrapper)
|
|
@@ -91,7 +91,7 @@ def test_create_craftax_smoke(prng_key):
|
|
|
91
91
|
def test_create_mujoco_playground_smoke(prng_key):
|
|
92
92
|
pytest.importorskip("mujoco_playground")
|
|
93
93
|
|
|
94
|
-
from envelope.
|
|
94
|
+
from envelope.adapters.mujoco_playground_envelope import MujocoPlaygroundEnvelope
|
|
95
95
|
|
|
96
96
|
env = create("mujoco_playground::CartpoleBalance")
|
|
97
97
|
assert isinstance(env, TruncationWrapper)
|
|
@@ -106,7 +106,7 @@ def test_create_mujoco_playground_smoke(prng_key):
|
|
|
106
106
|
def test_create_kinetix_smoke(prng_key):
|
|
107
107
|
pytest.importorskip("kinetix")
|
|
108
108
|
|
|
109
|
-
from envelope.
|
|
109
|
+
from envelope.adapters.kinetix_envelope import KinetixEnvelope
|
|
110
110
|
|
|
111
111
|
env = create("kinetix::random")
|
|
112
112
|
assert isinstance(env, TruncationWrapper)
|
|
@@ -6,13 +6,13 @@ import jax
|
|
|
6
6
|
import jax.numpy as jnp
|
|
7
7
|
import pytest
|
|
8
8
|
|
|
9
|
-
pytestmark = pytest.mark.
|
|
9
|
+
pytestmark = pytest.mark.adapters
|
|
10
10
|
|
|
11
11
|
pytest.importorskip("gymnax")
|
|
12
12
|
|
|
13
|
+
from envelope.compat.gymnax_envelope import GymnaxEnvelope, _convert_space
|
|
13
14
|
from gymnax.environments import spaces as gymnax_spaces
|
|
14
15
|
|
|
15
|
-
from envelope.compat.gymnax_envelope import GymnaxEnvelope, _convert_space
|
|
16
16
|
from envelope.spaces import Continuous, Discrete, PyTreeSpace
|
|
17
17
|
from tests.compat.contract import (
|
|
18
18
|
assert_jitted_rollout_contract,
|
{jax_envelope-0.3.0/tests/compat → jax_envelope-0.4.0/tests/adapters}/test_jumanji_compat.py
RENAMED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
"""Tests for envelope.
|
|
1
|
+
"""Tests for envelope.adapters.jumanji_envelope module."""
|
|
2
2
|
|
|
3
3
|
# ruff: noqa: E402
|
|
4
4
|
|
|
@@ -11,19 +11,19 @@ import jax.numpy as jnp
|
|
|
11
11
|
import numpy as np
|
|
12
12
|
import pytest
|
|
13
13
|
|
|
14
|
-
pytestmark = pytest.mark.
|
|
14
|
+
pytestmark = pytest.mark.adapters
|
|
15
15
|
|
|
16
16
|
pytest.importorskip("jumanji")
|
|
17
17
|
|
|
18
18
|
from jumanji import specs
|
|
19
19
|
|
|
20
|
-
import envelope.
|
|
21
|
-
from envelope.
|
|
20
|
+
import envelope.adapters.jumanji_envelope as jumanji_envelope
|
|
21
|
+
from envelope.adapters.jumanji_envelope import (
|
|
22
22
|
JumanjiEnvelope,
|
|
23
23
|
convert_jumanji_spec_to_envelope_space,
|
|
24
24
|
)
|
|
25
25
|
from envelope.spaces import Continuous, Discrete, PyTreeSpace
|
|
26
|
-
from tests.
|
|
26
|
+
from tests.adapters.contract import (
|
|
27
27
|
assert_jitted_rollout_contract,
|
|
28
28
|
assert_reset_step_contract,
|
|
29
29
|
)
|
{jax_envelope-0.3.0/tests/compat → jax_envelope-0.4.0/tests/adapters}/test_kinetix_compat.py
RENAMED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
"""Tests for envelope.
|
|
1
|
+
"""Tests for envelope.adapters.kinetix_envelope module."""
|
|
2
2
|
|
|
3
3
|
# ruff: noqa: E402
|
|
4
4
|
|
|
@@ -10,17 +10,17 @@ import jax
|
|
|
10
10
|
import jax.numpy as jnp
|
|
11
11
|
import pytest
|
|
12
12
|
|
|
13
|
-
pytestmark = pytest.mark.
|
|
13
|
+
pytestmark = pytest.mark.adapters
|
|
14
14
|
|
|
15
15
|
pytest.importorskip("kinetix")
|
|
16
16
|
|
|
17
17
|
import kinetix
|
|
18
18
|
from kinetix.environment import EnvParams, StaticEnvParams
|
|
19
19
|
|
|
20
|
-
from envelope.
|
|
20
|
+
from envelope.adapters.kinetix_envelope import KinetixEnvelope, _normalize_level_id
|
|
21
21
|
from envelope.environment import Info
|
|
22
22
|
from envelope.spaces import Continuous
|
|
23
|
-
from tests.
|
|
23
|
+
from tests.adapters.contract import (
|
|
24
24
|
assert_jitted_rollout_contract,
|
|
25
25
|
assert_reset_step_contract,
|
|
26
26
|
)
|
|
@@ -189,7 +189,7 @@ def test_from_name_rejects_unknown_env_kwargs():
|
|
|
189
189
|
|
|
190
190
|
def test_from_name_allows_premade_state_none(monkeypatch: pytest.MonkeyPatch):
|
|
191
191
|
"""Current implementation does not guard against missing premade state."""
|
|
192
|
-
from envelope.
|
|
192
|
+
from envelope.adapters import kinetix_envelope
|
|
193
193
|
|
|
194
194
|
def mock_load(_level_id: str):
|
|
195
195
|
return None, StaticEnvParams(), EnvParams()
|
|
@@ -210,7 +210,7 @@ def test_create_premade_replace_failure_raises(monkeypatch: pytest.MonkeyPatch):
|
|
|
210
210
|
|
|
211
211
|
monkeypatch.setattr(EnvParams, "replace", failing_replace)
|
|
212
212
|
monkeypatch.setattr(
|
|
213
|
-
"envelope.
|
|
213
|
+
"envelope.adapters.kinetix_envelope.load_from_json_file",
|
|
214
214
|
lambda _level_id: (object(), StaticEnvParams(), ep),
|
|
215
215
|
)
|
|
216
216
|
|
|
@@ -218,6 +218,5 @@ def test_create_premade_replace_failure_raises(monkeypatch: pytest.MonkeyPatch):
|
|
|
218
218
|
KinetixEnvelope.from_name("s/h4_thrust_aim")
|
|
219
219
|
|
|
220
220
|
|
|
221
|
-
#
|
|
222
|
-
# NOTE: Level-id normalization tests live in tests/compat/test_kinetix_level_id.py
|
|
221
|
+
# NOTE: Level-id normalization tests live in tests/adapters/test_kinetix_level_id.py
|
|
223
222
|
# to keep this module focused on runtime wrapper behavior.
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
"""Tests for envelope.
|
|
1
|
+
"""Tests for envelope.adapters.mujoco_playground_envelope module."""
|
|
2
2
|
|
|
3
3
|
# ruff: noqa: E402
|
|
4
4
|
|
|
@@ -6,14 +6,14 @@ import jax
|
|
|
6
6
|
import jax.numpy as jnp
|
|
7
7
|
import pytest
|
|
8
8
|
|
|
9
|
-
pytestmark = pytest.mark.
|
|
9
|
+
pytestmark = pytest.mark.adapters
|
|
10
10
|
|
|
11
11
|
pytest.importorskip("mujoco_playground")
|
|
12
12
|
|
|
13
|
-
from envelope.
|
|
13
|
+
from envelope.adapters.mujoco_playground_envelope import MujocoPlaygroundEnvelope
|
|
14
14
|
from envelope.environment import Info
|
|
15
15
|
from envelope.spaces import Continuous, PyTreeSpace
|
|
16
|
-
from tests.
|
|
16
|
+
from tests.adapters.contract import (
|
|
17
17
|
assert_jitted_rollout_contract,
|
|
18
18
|
assert_reset_step_contract,
|
|
19
19
|
)
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
"""Tests for envelope.
|
|
1
|
+
"""Tests for envelope.adapters.navix_envelope module."""
|
|
2
2
|
|
|
3
3
|
# ruff: noqa: E402
|
|
4
4
|
|
|
@@ -6,15 +6,15 @@ import jax
|
|
|
6
6
|
import jax.numpy as jnp
|
|
7
7
|
import pytest
|
|
8
8
|
|
|
9
|
-
pytestmark = pytest.mark.
|
|
9
|
+
pytestmark = pytest.mark.adapters
|
|
10
10
|
|
|
11
11
|
pytest.importorskip("navix")
|
|
12
12
|
|
|
13
13
|
import navix
|
|
14
14
|
|
|
15
|
-
from envelope.
|
|
15
|
+
from envelope.adapters.navix_envelope import NavixEnvelope
|
|
16
16
|
from envelope.spaces import Continuous, Discrete
|
|
17
|
-
from tests.
|
|
17
|
+
from tests.adapters.contract import (
|
|
18
18
|
assert_jitted_rollout_contract,
|
|
19
19
|
assert_reset_step_contract,
|
|
20
20
|
)
|
|
@@ -160,7 +160,7 @@ def test_unsupported_space_type():
|
|
|
160
160
|
import jax.numpy as jnp
|
|
161
161
|
from navix import spaces as navix_spaces
|
|
162
162
|
|
|
163
|
-
from envelope.
|
|
163
|
+
from envelope.adapters.navix_envelope import convert_navix_to_envelope_space
|
|
164
164
|
|
|
165
165
|
# Create a mock space that's neither Discrete nor Continuous
|
|
166
166
|
class MockSpace(navix_spaces.Space):
|
|
@@ -181,7 +181,7 @@ def test_step_type_conversion(navix_env, prng_key):
|
|
|
181
181
|
"""Test all navix StepType values are correctly converted."""
|
|
182
182
|
import navix
|
|
183
183
|
|
|
184
|
-
from envelope.
|
|
184
|
+
from envelope.adapters.navix_envelope import convert_navix_to_envelope_info
|
|
185
185
|
|
|
186
186
|
env = navix_env
|
|
187
187
|
key = prng_key
|
|
@@ -224,7 +224,7 @@ def test_discrete_space_conversion():
|
|
|
224
224
|
"""Test conversion of discrete spaces from navix to envelope."""
|
|
225
225
|
from navix import spaces as navix_spaces
|
|
226
226
|
|
|
227
|
-
from envelope.
|
|
227
|
+
from envelope.adapters.navix_envelope import convert_navix_to_envelope_space
|
|
228
228
|
|
|
229
229
|
# Create a navix Discrete space
|
|
230
230
|
navix_discrete = navix_spaces.Discrete.create(10, shape=(3,), dtype=jnp.int32)
|
|
@@ -245,7 +245,7 @@ def test_continuous_space_conversion():
|
|
|
245
245
|
"""Test conversion of continuous spaces from navix to envelope."""
|
|
246
246
|
from navix import spaces as navix_spaces
|
|
247
247
|
|
|
248
|
-
from envelope.
|
|
248
|
+
from envelope.adapters.navix_envelope import convert_navix_to_envelope_space
|
|
249
249
|
|
|
250
250
|
# Create a navix Continuous space
|
|
251
251
|
navix_continuous = navix_spaces.Continuous.create(
|
|
@@ -1056,14 +1056,14 @@ wheels = [
|
|
|
1056
1056
|
|
|
1057
1057
|
[[package]]
|
|
1058
1058
|
name = "jax-envelope"
|
|
1059
|
-
version = "0.3.0"
|
|
1059
|
+
version = "0.4.0"
|
|
1060
1060
|
source = { editable = "." }
|
|
1061
1061
|
dependencies = [
|
|
1062
1062
|
{ name = "jax" },
|
|
1063
1063
|
]
|
|
1064
1064
|
|
|
1065
1065
|
[package.dev-dependencies]
|
|
1066
|
-
|
|
1066
|
+
adapters = [
|
|
1067
1067
|
{ name = "brax" },
|
|
1068
1068
|
{ name = "craftax" },
|
|
1069
1069
|
{ name = "gymnax" },
|
|
@@ -1083,7 +1083,7 @@ dev = [
|
|
|
1083
1083
|
requires-dist = [{ name = "jax", specifier = ">=0.5.0" }]
|
|
1084
1084
|
|
|
1085
1085
|
[package.metadata.requires-dev]
|
|
1086
|
-
|
|
1086
|
+
adapters = [
|
|
1087
1087
|
{ name = "brax", specifier = ">=0.13.0" },
|
|
1088
1088
|
{ name = "craftax", specifier = ">=1.4.3" },
|
|
1089
1089
|
{ name = "gymnax", git = "https://github.com/RobertTLange/gymnax.git" },
|
|
File without changes
|
|
File without changes
|
{jax_envelope-0.3.0/src/envelope/compat → jax_envelope-0.4.0/src/envelope/adapters}/brax_envelope.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{jax_envelope-0.3.0 → jax_envelope-0.4.0}/src/envelope/wrappers/continuous_observation_wrapper.py
RENAMED
|
File without changes
|
{jax_envelope-0.3.0 → jax_envelope-0.4.0}/src/envelope/wrappers/episode_statistics_wrapper.py
RENAMED
|
File without changes
|
|
File without changes
|
{jax_envelope-0.3.0 → jax_envelope-0.4.0}/src/envelope/wrappers/flatten_observation_wrapper.py
RENAMED
|
File without changes
|
|
File without changes
|
{jax_envelope-0.3.0 → jax_envelope-0.4.0}/src/envelope/wrappers/observation_normalization_wrapper.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{jax_envelope-0.3.0 → jax_envelope-0.4.0}/tests/wrappers/test_continuous_observation_wrapper.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{jax_envelope-0.3.0 → jax_envelope-0.4.0}/tests/wrappers/test_flatten_observation_wrapper.py
RENAMED
|
File without changes
|
|
File without changes
|
{jax_envelope-0.3.0 → jax_envelope-0.4.0}/tests/wrappers/test_observation_normalization_wrapper.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|