jax-envelope 0.1.0__tar.gz → 0.1.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {jax_envelope-0.1.0 → jax_envelope-0.1.1}/PKG-INFO +2 -2
- {jax_envelope-0.1.0 → jax_envelope-0.1.1}/README.md +1 -1
- {jax_envelope-0.1.0 → jax_envelope-0.1.1}/pyproject.toml +1 -1
- jax_envelope-0.1.1/src/envelope/__init__.py +42 -0
- jax_envelope-0.1.1/src/envelope/typing.py +7 -0
- jax_envelope-0.1.1/src/envelope/wrappers/__init__.py +20 -0
- {jax_envelope-0.1.0 → jax_envelope-0.1.1}/uv.lock +1 -1
- jax_envelope-0.1.0/src/envelope/typing.py +0 -23
- jax_envelope-0.1.0/src/envelope/wrappers/timestep_wrapper.py +0 -22
- jax_envelope-0.1.0/tests/wrappers/__init__.py +0 -0
- {jax_envelope-0.1.0 → jax_envelope-0.1.1}/.github/workflows/publish.yml +0 -0
- {jax_envelope-0.1.0 → jax_envelope-0.1.1}/.gitignore +0 -0
- {jax_envelope-0.1.0 → jax_envelope-0.1.1}/LICENSE +0 -0
- {jax_envelope-0.1.0 → jax_envelope-0.1.1}/src/envelope/compat/__init__.py +0 -0
- {jax_envelope-0.1.0 → jax_envelope-0.1.1}/src/envelope/compat/brax_envelope.py +0 -0
- {jax_envelope-0.1.0 → jax_envelope-0.1.1}/src/envelope/compat/craftax_envelope.py +0 -0
- {jax_envelope-0.1.0 → jax_envelope-0.1.1}/src/envelope/compat/gymnax_envelope.py +0 -0
- {jax_envelope-0.1.0 → jax_envelope-0.1.1}/src/envelope/compat/jumanji_envelope.py +0 -0
- {jax_envelope-0.1.0 → jax_envelope-0.1.1}/src/envelope/compat/kinetix_envelope.py +0 -0
- {jax_envelope-0.1.0 → jax_envelope-0.1.1}/src/envelope/compat/mujoco_playground_envelope.py +0 -0
- {jax_envelope-0.1.0 → jax_envelope-0.1.1}/src/envelope/compat/navix_envelope.py +0 -0
- {jax_envelope-0.1.0 → jax_envelope-0.1.1}/src/envelope/environment.py +0 -0
- {jax_envelope-0.1.0 → jax_envelope-0.1.1}/src/envelope/spaces.py +0 -0
- {jax_envelope-0.1.0 → jax_envelope-0.1.1}/src/envelope/struct.py +0 -0
- {jax_envelope-0.1.0 → jax_envelope-0.1.1}/src/envelope/wrappers/autoreset_wrapper.py +0 -0
- {jax_envelope-0.1.0 → jax_envelope-0.1.1}/src/envelope/wrappers/episode_statistics_wrapper.py +0 -0
- {jax_envelope-0.1.0 → jax_envelope-0.1.1}/src/envelope/wrappers/normalization.py +0 -0
- {jax_envelope-0.1.0 → jax_envelope-0.1.1}/src/envelope/wrappers/observation_normalization_wrapper.py +0 -0
- {jax_envelope-0.1.0 → jax_envelope-0.1.1}/src/envelope/wrappers/state_injection_wrapper.py +0 -0
- {jax_envelope-0.1.0 → jax_envelope-0.1.1}/src/envelope/wrappers/truncation_wrapper.py +0 -0
- {jax_envelope-0.1.0 → jax_envelope-0.1.1}/src/envelope/wrappers/vmap_envs_wrapper.py +0 -0
- {jax_envelope-0.1.0 → jax_envelope-0.1.1}/src/envelope/wrappers/vmap_wrapper.py +0 -0
- {jax_envelope-0.1.0 → jax_envelope-0.1.1}/src/envelope/wrappers/wrapper.py +0 -0
- {jax_envelope-0.1.0 → jax_envelope-0.1.1}/tests/__init__.py +0 -0
- {jax_envelope-0.1.0/src/envelope → jax_envelope-0.1.1/tests/compat}/__init__.py +0 -0
- {jax_envelope-0.1.0 → jax_envelope-0.1.1}/tests/compat/conftest.py +0 -0
- {jax_envelope-0.1.0 → jax_envelope-0.1.1}/tests/compat/contract.py +0 -0
- {jax_envelope-0.1.0 → jax_envelope-0.1.1}/tests/compat/test_brax_compat.py +0 -0
- {jax_envelope-0.1.0 → jax_envelope-0.1.1}/tests/compat/test_craftax_compat.py +0 -0
- {jax_envelope-0.1.0 → jax_envelope-0.1.1}/tests/compat/test_create.py +0 -0
- {jax_envelope-0.1.0 → jax_envelope-0.1.1}/tests/compat/test_create_integration.py +0 -0
- {jax_envelope-0.1.0 → jax_envelope-0.1.1}/tests/compat/test_gymnax_compat.py +0 -0
- {jax_envelope-0.1.0 → jax_envelope-0.1.1}/tests/compat/test_jumanji_compat.py +0 -0
- {jax_envelope-0.1.0 → jax_envelope-0.1.1}/tests/compat/test_kinetix_compat.py +0 -0
- {jax_envelope-0.1.0 → jax_envelope-0.1.1}/tests/compat/test_mujoco_playground_compat.py +0 -0
- {jax_envelope-0.1.0 → jax_envelope-0.1.1}/tests/compat/test_navix_compat.py +0 -0
- {jax_envelope-0.1.0 → jax_envelope-0.1.1}/tests/spaces/__init__.py +0 -0
- {jax_envelope-0.1.0 → jax_envelope-0.1.1}/tests/spaces/test_batched_space.py +0 -0
- {jax_envelope-0.1.0 → jax_envelope-0.1.1}/tests/spaces/test_continuous.py +0 -0
- {jax_envelope-0.1.0 → jax_envelope-0.1.1}/tests/spaces/test_discrete.py +0 -0
- {jax_envelope-0.1.0 → jax_envelope-0.1.1}/tests/spaces/test_pytree_space.py +0 -0
- {jax_envelope-0.1.0 → jax_envelope-0.1.1}/tests/spaces/test_serialization.py +0 -0
- {jax_envelope-0.1.0 → jax_envelope-0.1.1}/tests/test_container.py +0 -0
- {jax_envelope-0.1.0 → jax_envelope-0.1.1}/tests/test_struct.py +0 -0
- {jax_envelope-0.1.0/tests/compat → jax_envelope-0.1.1/tests/wrappers}/__init__.py +0 -0
- {jax_envelope-0.1.0 → jax_envelope-0.1.1}/tests/wrappers/helpers.py +0 -0
- {jax_envelope-0.1.0 → jax_envelope-0.1.1}/tests/wrappers/test_autoreset_wrapper.py +0 -0
- {jax_envelope-0.1.0 → jax_envelope-0.1.1}/tests/wrappers/test_environment_wrapper.py +0 -0
- {jax_envelope-0.1.0 → jax_envelope-0.1.1}/tests/wrappers/test_normalization.py +0 -0
- {jax_envelope-0.1.0 → jax_envelope-0.1.1}/tests/wrappers/test_observation_normalization_wrapper.py +0 -0
- {jax_envelope-0.1.0 → jax_envelope-0.1.1}/tests/wrappers/test_state_injection_wrapper.py +0 -0
- {jax_envelope-0.1.0 → jax_envelope-0.1.1}/tests/wrappers/test_truncation_wrapper.py +0 -0
- {jax_envelope-0.1.0 → jax_envelope-0.1.1}/tests/wrappers/test_vmap_envs_wrapper.py +0 -0
- {jax_envelope-0.1.0 → jax_envelope-0.1.1}/tests/wrappers/test_vmap_wrapper.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: jax-envelope
|
|
3
|
-
Version: 0.1.0
|
|
3
|
+
Version: 0.1.1
|
|
4
4
|
Summary: A JAX-native environment interface with powerful wrappers and adapters for popular RL environment suites
|
|
5
5
|
Project-URL: Homepage, https://github.com/keraJLi/envelope
|
|
6
6
|
Project-URL: Repository, https://github.com/keraJLi/envelope
|
|
@@ -81,7 +81,7 @@ let's you create environments from any of the above!
|
|
|
81
81
|
pip install jax-envelope
|
|
82
82
|
```
|
|
83
83
|
|
|
84
|
-
##
|
|
84
|
+
## 💞 Related projects
|
|
85
85
|
* [stoax](https://github.com/EdanToledo/Stoa) is a very similar project that provides adapters and wrappers for the jumanji-like interface.
|
|
86
86
|
* Check out all the great suites we have adapters for! [gymnax](https://github.com/RobertTLange/gymnax), [brax](https://github.com/google/brax), [jumanji](https://github.com/instadeepai/jumanji), [kinetix](https://github.com/flairox/kinetix), [craftax](https://github.com/MichaelTMatthews/craftax), [mujoco_playground](https://github.com/google-deepmind/mujoco_playground).
|
|
87
87
|
* We will be adding support for [jaxmarl](https://github.com/flairox/jaxmarl) and [pgx](https://github.com/sotetsuk/pgx) in the future, as soon as we figured out the best ever MARL interface for JAX!
|
|
@@ -55,7 +55,7 @@ let's you create environments from any of the above!
|
|
|
55
55
|
pip install jax-envelope
|
|
56
56
|
```
|
|
57
57
|
|
|
58
|
-
##
|
|
58
|
+
## 💞 Related projects
|
|
59
59
|
* [stoax](https://github.com/EdanToledo/Stoa) is a very similar project that provides adapters and wrappers for the jumanji-like interface.
|
|
60
60
|
* Check out all the great suites we have adapters for! [gymnax](https://github.com/RobertTLange/gymnax), [brax](https://github.com/google/brax), [jumanji](https://github.com/instadeepai/jumanji), [kinetix](https://github.com/flairox/kinetix), [craftax](https://github.com/MichaelTMatthews/craftax), [mujoco_playground](https://github.com/google-deepmind/mujoco_playground).
|
|
61
61
|
* We will be adding support for [jaxmarl](https://github.com/flairox/jaxmarl) and [pgx](https://github.com/sotetsuk/pgx) in the future, as soon as we figured out the best ever MARL interface for JAX!
|
|
@@ -0,0 +1,42 @@
|
|
|
1
|
+
from envelope.compat import create
|
|
2
|
+
from envelope.environment import Environment, Info, InfoContainer
|
|
3
|
+
from envelope.spaces import BatchedSpace, Continuous, Discrete, PyTreeSpace, Space
|
|
4
|
+
from envelope.struct import field, static_field, FrozenPyTreeNode, Container
|
|
5
|
+
from envelope.wrappers import (
|
|
6
|
+
Wrapper,
|
|
7
|
+
WrappedState,
|
|
8
|
+
AutoResetWrapper,
|
|
9
|
+
ObservationNormalizationWrapper,
|
|
10
|
+
StateInjectionWrapper,
|
|
11
|
+
TruncationWrapper,
|
|
12
|
+
VmapWrapper,
|
|
13
|
+
VmapEnvsWrapper,
|
|
14
|
+
)
|
|
15
|
+
|
|
16
|
+
__all__ = [
|
|
17
|
+
# Basic functionality
|
|
18
|
+
"create",
|
|
19
|
+
"Environment",
|
|
20
|
+
"Info",
|
|
21
|
+
"InfoContainer",
|
|
22
|
+
# Spaces
|
|
23
|
+
"Space",
|
|
24
|
+
"BatchedSpace",
|
|
25
|
+
"Continuous",
|
|
26
|
+
"Discrete",
|
|
27
|
+
"PyTreeSpace",
|
|
28
|
+
# Struct
|
|
29
|
+
"field",
|
|
30
|
+
"static_field",
|
|
31
|
+
"FrozenPyTreeNode",
|
|
32
|
+
"Container",
|
|
33
|
+
# Wrappers
|
|
34
|
+
"Wrapper",
|
|
35
|
+
"WrappedState",
|
|
36
|
+
"AutoResetWrapper",
|
|
37
|
+
"ObservationNormalizationWrapper",
|
|
38
|
+
"StateInjectionWrapper",
|
|
39
|
+
"TruncationWrapper",
|
|
40
|
+
"VmapWrapper",
|
|
41
|
+
"VmapEnvsWrapper",
|
|
42
|
+
]
|
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
from envelope.wrappers.autoreset_wrapper import AutoResetWrapper
|
|
2
|
+
from envelope.wrappers.observation_normalization_wrapper import (
|
|
3
|
+
ObservationNormalizationWrapper,
|
|
4
|
+
)
|
|
5
|
+
from envelope.wrappers.state_injection_wrapper import StateInjectionWrapper
|
|
6
|
+
from envelope.wrappers.truncation_wrapper import TruncationWrapper
|
|
7
|
+
from envelope.wrappers.vmap_wrapper import VmapWrapper
|
|
8
|
+
from envelope.wrappers.vmap_envs_wrapper import VmapEnvsWrapper
|
|
9
|
+
from envelope.wrappers.wrapper import Wrapper, WrappedState
|
|
10
|
+
|
|
11
|
+
__all__ = [
|
|
12
|
+
"Wrapper",
|
|
13
|
+
"WrappedState",
|
|
14
|
+
"AutoResetWrapper",
|
|
15
|
+
"ObservationNormalizationWrapper",
|
|
16
|
+
"StateInjectionWrapper",
|
|
17
|
+
"TruncationWrapper",
|
|
18
|
+
"VmapWrapper",
|
|
19
|
+
"VmapEnvsWrapper",
|
|
20
|
+
]
|
|
@@ -1,23 +0,0 @@
|
|
|
1
|
-
from enum import Enum
|
|
2
|
-
from typing import Any, TypeAlias
|
|
3
|
-
|
|
4
|
-
import jax
|
|
5
|
-
|
|
6
|
-
PyTree: TypeAlias = Any
|
|
7
|
-
Key: TypeAlias = jax.Array
|
|
8
|
-
Array: TypeAlias = jax.Array
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
class BatchKind(Enum):
|
|
12
|
-
"""
|
|
13
|
-
Batch semantics for environments.
|
|
14
|
-
- VMAP: environment represents instances compatible with `jax.vmap`, for example by
|
|
15
|
-
wrapping it in a `VmapWrapper`.
|
|
16
|
-
- NATIVE_POOL: environment represents a batch of instances via a native pool. This
|
|
17
|
-
is the case when it is wrapping a non-jax-based environment that supports native
|
|
18
|
-
batching, for example those provided by envpool. Environments in this mode cannot
|
|
19
|
-
be vmapped, as it would break the native batching semantics.
|
|
20
|
-
"""
|
|
21
|
-
|
|
22
|
-
VMAP = "vmap"
|
|
23
|
-
NATIVE_POOL = "native_pool"
|
|
@@ -1,22 +0,0 @@
|
|
|
1
|
-
from envelope.environment import Info
|
|
2
|
-
from envelope.struct import field
|
|
3
|
-
from envelope.typing import Key, PyTree
|
|
4
|
-
from envelope.wrappers.wrapper import WrappedState, Wrapper
|
|
5
|
-
|
|
6
|
-
|
|
7
|
-
class TimeStepWrapper(Wrapper):
|
|
8
|
-
class TimeStepState(WrappedState):
|
|
9
|
-
steps: PyTree = field(default=0)
|
|
10
|
-
|
|
11
|
-
def reset(
|
|
12
|
-
self, key: Key, state: PyTree | None = None, **kwargs
|
|
13
|
-
) -> tuple[WrappedState, Info]:
|
|
14
|
-
inner_state, info = self.env.reset(key, state, **kwargs)
|
|
15
|
-
return self.TimeStepState(inner_state=inner_state, steps=0), info
|
|
16
|
-
|
|
17
|
-
def step(
|
|
18
|
-
self, state: WrappedState, action: PyTree, **kwargs
|
|
19
|
-
) -> tuple[WrappedState, Info]:
|
|
20
|
-
next_inner_state, info = self.env.step(state.inner_state, action, **kwargs)
|
|
21
|
-
next_steps = getattr(state, "steps", 0) + 1
|
|
22
|
-
return self.TimeStepState(inner_state=next_inner_state, steps=next_steps), info
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{jax_envelope-0.1.0 → jax_envelope-0.1.1}/src/envelope/wrappers/episode_statistics_wrapper.py
RENAMED
|
File without changes
|
|
File without changes
|
{jax_envelope-0.1.0 → jax_envelope-0.1.1}/src/envelope/wrappers/observation_normalization_wrapper.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{jax_envelope-0.1.0 → jax_envelope-0.1.1}/tests/wrappers/test_observation_normalization_wrapper.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|