jax-envelope 0.1.1__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- envelope/__init__.py +16 -4
- envelope/compat/brax_envelope.py +5 -3
- envelope/compat/craftax_envelope.py +17 -2
- envelope/compat/gymnax_envelope.py +34 -7
- envelope/compat/jumanji_envelope.py +3 -2
- envelope/compat/kinetix_envelope.py +3 -2
- envelope/compat/mujoco_playground_envelope.py +1 -1
- envelope/compat/navix_envelope.py +1 -1
- envelope/environment.py +16 -9
- envelope/spaces.py +41 -21
- envelope/struct.py +10 -1
- envelope/wrappers/__init__.py +18 -2
- envelope/wrappers/autoreset_wrapper.py +65 -21
- envelope/wrappers/clip_action_wrapper.py +27 -0
- envelope/wrappers/continuous_observation_wrapper.py +61 -0
- envelope/wrappers/episode_statistics_wrapper.py +29 -36
- envelope/wrappers/flatten_action_wrapper.py +75 -0
- envelope/wrappers/flatten_observation_wrapper.py +81 -0
- envelope/wrappers/normalization.py +1 -1
- envelope/wrappers/observation_normalization_wrapper.py +28 -16
- envelope/wrappers/pooled_init_vmap_wrapper.py +122 -0
- envelope/wrappers/state_injection_wrapper.py +18 -22
- envelope/wrappers/truncation_wrapper.py +18 -14
- envelope/wrappers/vmap_envs_wrapper.py +26 -21
- envelope/wrappers/vmap_wrapper.py +36 -21
- envelope/wrappers/wrapper.py +8 -8
- {jax_envelope-0.1.1.dist-info → jax_envelope-0.2.0.dist-info}/METADATA +2 -2
- jax_envelope-0.2.0.dist-info/RECORD +32 -0
- jax_envelope-0.1.1.dist-info/RECORD +0 -27
- {jax_envelope-0.1.1.dist-info → jax_envelope-0.2.0.dist-info}/WHEEL +0 -0
- {jax_envelope-0.1.1.dist-info → jax_envelope-0.2.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -17,34 +17,39 @@ class VmapEnvsWrapper(Wrapper):
|
|
|
17
17
|
Usage:
|
|
18
18
|
envs = jax.vmap(make_env)(params_batch) # env pytree batched on leading axis
|
|
19
19
|
wrapped = VmapEnvsWrapper(env=envs, batch_size=B)
|
|
20
|
-
state, info = wrapped.
|
|
20
|
+
state, info = wrapped.init(keys) # keys shape (B, 2) or single key
|
|
21
21
|
next_state, info = wrapped.step(state, action)
|
|
22
22
|
"""
|
|
23
23
|
|
|
24
24
|
batch_size: int = field(kw_only=True)
|
|
25
25
|
|
|
26
|
-
|
|
27
|
-
def reset(
|
|
28
|
-
self, key: Key, state: PyTree | None = None, **kwargs
|
|
29
|
-
) -> tuple[WrappedState, Info]:
|
|
26
|
+
def _split_keys(self, key: Key) -> Key:
|
|
30
27
|
if key.shape == (2,):
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
28
|
+
return jax.random.split(key, self.batch_size)
|
|
29
|
+
if key.shape[0] != self.batch_size:
|
|
30
|
+
raise ValueError(
|
|
31
|
+
f"reset key's leading dimension ({key.shape[0]}) must match "
|
|
32
|
+
f"batch_size ({self.batch_size})."
|
|
33
|
+
)
|
|
34
|
+
return key
|
|
35
|
+
|
|
36
|
+
@override
|
|
37
|
+
def init(self, key: Key) -> tuple[WrappedState, Info]:
|
|
38
|
+
keys = self._split_keys(key)
|
|
39
|
+
state, info = jax.vmap(lambda e, k: e.init(k))(self.env, keys)
|
|
40
|
+
return state, info
|
|
41
|
+
|
|
42
|
+
@override
|
|
43
|
+
def reset(self, key: Key, state: PyTree) -> tuple[WrappedState, Info]:
|
|
44
|
+
keys = self._split_keys(key)
|
|
45
|
+
state, info = jax.vmap(lambda e, k, s: e.reset(k, s))(
|
|
46
|
+
self.env, keys, state
|
|
47
|
+
)
|
|
41
48
|
return state, info
|
|
42
49
|
|
|
43
50
|
@override
|
|
44
|
-
def step(
|
|
45
|
-
|
|
46
|
-
) -> tuple[WrappedState, Info]:
|
|
47
|
-
next_state, info = jax.vmap(lambda e, s, a: e.step(s, a, **kwargs))(
|
|
51
|
+
def step(self, state: WrappedState, action: PyTree) -> tuple[WrappedState, Info]:
|
|
52
|
+
next_state, info = jax.vmap(lambda e, s, a: e.step(s, a))(
|
|
48
53
|
self.env, state, action
|
|
49
54
|
)
|
|
50
55
|
return next_state, info
|
|
@@ -53,13 +58,13 @@ class VmapEnvsWrapper(Wrapper):
|
|
|
53
58
|
@property
|
|
54
59
|
def observation_space(self) -> spaces.Space:
|
|
55
60
|
env0 = _index_env(self.env, 0, self.batch_size)
|
|
56
|
-
return spaces.
|
|
61
|
+
return spaces.BatchedSpace(space=env0.observation_space, batch_size=self.batch_size)
|
|
57
62
|
|
|
58
63
|
@override
|
|
59
64
|
@cached_property
|
|
60
65
|
def action_space(self) -> spaces.Space:
|
|
61
66
|
env0 = _index_env(self.env, 0, self.batch_size)
|
|
62
|
-
return spaces.
|
|
67
|
+
return spaces.BatchedSpace(space=env0.action_space, batch_size=self.batch_size)
|
|
63
68
|
|
|
64
69
|
@override
|
|
65
70
|
@property
|
|
@@ -2,50 +2,65 @@ from functools import cached_property
|
|
|
2
2
|
from typing import override
|
|
3
3
|
|
|
4
4
|
import jax
|
|
5
|
+
import jax.numpy as jnp
|
|
5
6
|
|
|
6
7
|
from envelope import spaces
|
|
7
8
|
from envelope.environment import Info
|
|
8
|
-
from envelope.struct import
|
|
9
|
+
from envelope.struct import static_field
|
|
9
10
|
from envelope.typing import Key, PyTree
|
|
10
11
|
from envelope.wrappers.wrapper import WrappedState, Wrapper
|
|
11
12
|
|
|
12
13
|
|
|
14
|
+
def is_single_key(key):
|
|
15
|
+
# New-style typed keys have dtype like key<fry>
|
|
16
|
+
if jnp.issubdtype(key.dtype, jax.dtypes.prng_key):
|
|
17
|
+
return key.ndim == 0
|
|
18
|
+
return key.shape == (2,)
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def _split_or_keep_key(key: Key, batch_size: int) -> Key:
|
|
22
|
+
if is_single_key(key):
|
|
23
|
+
return jax.random.split(key, batch_size)
|
|
24
|
+
elif key.shape[0] == batch_size:
|
|
25
|
+
return key
|
|
26
|
+
raise ValueError(
|
|
27
|
+
f"reset key's leading dimension ({key.shape[0]}) must match "
|
|
28
|
+
f"batch_size ({batch_size})."
|
|
29
|
+
)
|
|
30
|
+
|
|
31
|
+
|
|
13
32
|
class VmapWrapper(Wrapper):
|
|
14
|
-
"""Does not
|
|
33
|
+
"""Does not wrap the state."""
|
|
15
34
|
|
|
16
|
-
batch_size: int =
|
|
35
|
+
batch_size: int = static_field(kw_only=True)
|
|
17
36
|
|
|
18
37
|
@override
|
|
19
|
-
def
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
if key.shape == (2,):
|
|
24
|
-
keys = jax.random.split(key, self.batch_size)
|
|
25
|
-
else:
|
|
26
|
-
if key.shape[0] != self.batch_size:
|
|
27
|
-
raise ValueError(
|
|
28
|
-
f"reset key's leading dimension ({key.shape[0]}) must match "
|
|
29
|
-
f"batch_size ({self.batch_size})."
|
|
30
|
-
)
|
|
31
|
-
keys = key
|
|
38
|
+
def init(self, key: Key) -> tuple[WrappedState, Info]:
|
|
39
|
+
keys = _split_or_keep_key(key, self.batch_size)
|
|
40
|
+
state, info = jax.vmap(self.env.init)(keys)
|
|
41
|
+
return state, info
|
|
32
42
|
|
|
43
|
+
@override
|
|
44
|
+
def reset(self, key: Key, state: PyTree) -> tuple[WrappedState, Info]:
|
|
45
|
+
keys = _split_or_keep_key(key, self.batch_size)
|
|
33
46
|
state, info = jax.vmap(self.env.reset)(keys, state)
|
|
34
47
|
return state, info
|
|
35
48
|
|
|
36
49
|
@override
|
|
37
|
-
def step(
|
|
38
|
-
self, state: WrappedState, action: PyTree, **kwargs
|
|
39
|
-
) -> tuple[WrappedState, Info]:
|
|
50
|
+
def step(self, state: WrappedState, action: PyTree) -> tuple[WrappedState, Info]:
|
|
40
51
|
state, info = jax.vmap(self.env.step)(state, action)
|
|
41
52
|
return state, info
|
|
42
53
|
|
|
43
54
|
@override
|
|
44
55
|
@cached_property
|
|
45
56
|
def observation_space(self) -> spaces.Space:
|
|
46
|
-
return spaces.
|
|
57
|
+
return spaces.BatchedSpace(
|
|
58
|
+
space=self.env.observation_space, batch_size=self.batch_size
|
|
59
|
+
)
|
|
47
60
|
|
|
48
61
|
@override
|
|
49
62
|
@cached_property
|
|
50
63
|
def action_space(self) -> spaces.Space:
|
|
51
|
-
return spaces.
|
|
64
|
+
return spaces.BatchedSpace(
|
|
65
|
+
space=self.env.action_space, batch_size=self.batch_size
|
|
66
|
+
)
|
envelope/wrappers/wrapper.py
CHANGED
|
@@ -25,16 +25,16 @@ class Wrapper(Environment):
|
|
|
25
25
|
env: Environment = field(kw_only=True)
|
|
26
26
|
|
|
27
27
|
@override
|
|
28
|
-
def
|
|
29
|
-
self
|
|
30
|
-
) -> tuple[State, Info]:
|
|
31
|
-
return self.env.reset(key, state=state, **kwargs)
|
|
28
|
+
def init(self, key: Key) -> tuple[State, Info]:
|
|
29
|
+
return self.env.init(key)
|
|
32
30
|
|
|
33
31
|
@override
|
|
34
|
-
def
|
|
35
|
-
self, state
|
|
36
|
-
|
|
37
|
-
|
|
32
|
+
def reset(self, key: Key, state: State) -> tuple[State, Info]:
|
|
33
|
+
return self.env.reset(key, state)
|
|
34
|
+
|
|
35
|
+
@override
|
|
36
|
+
def step(self, state: WrappedState, action: PyTree) -> tuple[WrappedState, Info]:
|
|
37
|
+
return self.env.step(state, action)
|
|
38
38
|
|
|
39
39
|
@override
|
|
40
40
|
@cached_property
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: jax-envelope
|
|
3
|
-
Version: 0.
|
|
3
|
+
Version: 0.2.0
|
|
4
4
|
Summary: A JAX-native environment interface with powerful wrappers and adapters for popular RL environment suites
|
|
5
5
|
Project-URL: Homepage, https://github.com/keraJLi/envelope
|
|
6
6
|
Project-URL: Repository, https://github.com/keraJLi/envelope
|
|
@@ -82,6 +82,6 @@ pip install jax-envelope
|
|
|
82
82
|
```
|
|
83
83
|
|
|
84
84
|
## 💞 Related projects
|
|
85
|
-
* [
|
|
85
|
+
* [stoa](https://github.com/EdanToledo/Stoa) is a very similar project that provides adapters and wrappers for the jumanji-like interface.
|
|
86
86
|
* Check out all the great suites we have adapters for! [gymnax](https://github.com/RobertTLange/gymnax), [brax](https://github.com/google/brax), [jumanji](https://github.com/instadeepai/jumanji), [kinetix](https://github.com/flairox/kinetix), [craftax](https://github.com/MichaelTMatthews/craftax), [mujoco_playground](https://github.com/google-deepmind/mujoco_playground).
|
|
87
87
|
* We will be adding support for [jaxmarl](https://github.com/flairox/jaxmarl) and [pgx](https://github.com/sotetsuk/pgx) in the future, as soon as we figured out the best ever MARL interface for JAX!
|
|
@@ -0,0 +1,32 @@
|
|
|
1
|
+
envelope/__init__.py,sha256=AAk0b9UHGUTC4EzERXYPOrHuECgO0fx_UN4_gpKvgT0,1329
|
|
2
|
+
envelope/environment.py,sha256=OC15ooivLvD8EpmgRQGL-UjRpF0UXO0fOGu61I8HmRs,2015
|
|
3
|
+
envelope/spaces.py,sha256=Nb8T11ilCtqKPSLZ98oqhnuw_8JNLOGCZdsy2X_pSwU,6944
|
|
4
|
+
envelope/struct.py,sha256=tO1LLxk0bYc4kBCXraDv93T5d-xFRzpv4QCjASB0Lv4,5522
|
|
5
|
+
envelope/typing.py,sha256=2rWKiZnKXaUU6JIl51Wcj_txRwe8wQNRSLSiRuT3OP4,127
|
|
6
|
+
envelope/compat/__init__.py,sha256=6q2B2bfTIu7MlnIXc702ysqEnyURz-MEjTBMBCo8rdQ,3458
|
|
7
|
+
envelope/compat/brax_envelope.py,sha256=XV07CYApXmf5g-Bin4pKWqExTZ7gd6AO2BCx3bk1Z1c,3482
|
|
8
|
+
envelope/compat/craftax_envelope.py,sha256=ey_mhJh5KCDRvjpCAJbFexpWiFjzQr7BZuRYLM9g-yA,3959
|
|
9
|
+
envelope/compat/gymnax_envelope.py,sha256=X0pfFKaFKzvBl2Zih4utiXSEI-lc5f9miJGMxeT-Xfw,4690
|
|
10
|
+
envelope/compat/jumanji_envelope.py,sha256=WhU4WJLW2Fn4BG1_69EJO9ZiBDWEaY387dhkRZjqcGc,4916
|
|
11
|
+
envelope/compat/kinetix_envelope.py,sha256=99RNN6nBN9CQEEe2fX2T1n8lL9ki0uKaz5BK052IU1o,6853
|
|
12
|
+
envelope/compat/mujoco_playground_envelope.py,sha256=7iPaRuwXUHbP8hHNgSFipW8rJVruootMQPYBWNpPiHA,3705
|
|
13
|
+
envelope/compat/navix_envelope.py,sha256=2n8vgr6gVeAYo12D1kDPxUrA9Hm-If3E7rV7TF_ksSA,3026
|
|
14
|
+
envelope/wrappers/__init__.py,sha256=GFs54kY1aWbf9xymzm22x6M2AhZCwFLx1f5LF23jX94,1409
|
|
15
|
+
envelope/wrappers/autoreset_wrapper.py,sha256=MJlbOj94_6I1KQ4oBNsKEB416YWp9ZMiqn85b1PdhKQ,3382
|
|
16
|
+
envelope/wrappers/clip_action_wrapper.py,sha256=YKJbKBUbf-LADxIfCE6mMOKo_NEa2xPfDncNKg9iEHM,955
|
|
17
|
+
envelope/wrappers/continuous_observation_wrapper.py,sha256=geDOeQDuASiRDEv5U3fF7GxblfB0Ku99GvTUGZSteJs,2041
|
|
18
|
+
envelope/wrappers/episode_statistics_wrapper.py,sha256=-LzypX4aYS9mCkl33M2mf3aT3OhXWtDc_kVU6VlSJ0o,1433
|
|
19
|
+
envelope/wrappers/flatten_action_wrapper.py,sha256=znR7WSebLy2wZnEWHi9SshxgV9b29khg-6ogQlmXq9U,2609
|
|
20
|
+
envelope/wrappers/flatten_observation_wrapper.py,sha256=xgoToXdJ_6M7cwfL7pu91ZnnFh3LHPR_DSe7XJwkh2Q,2839
|
|
21
|
+
envelope/wrappers/normalization.py,sha256=YkNEZqvFSaPurBA-WfXDynV51552ucNkY0JqNSAKNUE,1819
|
|
22
|
+
envelope/wrappers/observation_normalization_wrapper.py,sha256=pAwyd01lAxvEu1Z4KZfhSRoVNeLpeohtw1t1wsE53w4,4921
|
|
23
|
+
envelope/wrappers/pooled_init_vmap_wrapper.py,sha256=LAgrEjLXlq-gJiCGzI357a41tx1mVgb1ZQwtgN8mTuw,4985
|
|
24
|
+
envelope/wrappers/state_injection_wrapper.py,sha256=L8f9yrlSigk0jNx6D6wSu0cR0bgssq2uJNd9OYPAY0E,3553
|
|
25
|
+
envelope/wrappers/truncation_wrapper.py,sha256=-CSoXs-mDy1AiJVKLUc2zBHQxalhYjjyAAu1NmiIUsE,1345
|
|
26
|
+
envelope/wrappers/vmap_envs_wrapper.py,sha256=gE_TAmfsu2PJlDGWTz5xENJBJ-6nSJpTUlKhOnmZPi0,2757
|
|
27
|
+
envelope/wrappers/vmap_wrapper.py,sha256=B43lFNdAN_d8LgN6Rz6Ty85r9zcXLB3Wrn8R-Z2qHps,1973
|
|
28
|
+
envelope/wrappers/wrapper.py,sha256=tWluyIg4eGl60R2v3qaahU5dlmKKkBsJiNj7GnWR06E,1517
|
|
29
|
+
jax_envelope-0.2.0.dist-info/METADATA,sha256=jCWoAuoFsrLJakIhecRdz1zOIGCdrD4dBS0JSA2LVvw,4651
|
|
30
|
+
jax_envelope-0.2.0.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
|
|
31
|
+
jax_envelope-0.2.0.dist-info/licenses/LICENSE,sha256=VyF-MK-gY2_fZlhf8uEnE2y8ziIXK-w55GM12eOgXrQ,1069
|
|
32
|
+
jax_envelope-0.2.0.dist-info/RECORD,,
|
|
@@ -1,27 +0,0 @@
|
|
|
1
|
-
envelope/__init__.py,sha256=x4DtJ3WWPsPde1wSrXAXCdJVuXV3pioAojrrOvpJNqw,975
|
|
2
|
-
envelope/environment.py,sha256=4T0IosoRX-6gVvZCKBUEitlM_5cAT43hrKQ0Hr53XGk,1735
|
|
3
|
-
envelope/spaces.py,sha256=fvJ3-ZI3iyFlpfEN1barTd9cbWvQ11nl-7Y45KgFizs,6433
|
|
4
|
-
envelope/struct.py,sha256=Sb8GsxZ7rFF5A5128oZWEQVnFK79xUFe6EaU2bghBOw,5158
|
|
5
|
-
envelope/typing.py,sha256=2rWKiZnKXaUU6JIl51Wcj_txRwe8wQNRSLSiRuT3OP4,127
|
|
6
|
-
envelope/compat/__init__.py,sha256=6q2B2bfTIu7MlnIXc702ysqEnyURz-MEjTBMBCo8rdQ,3458
|
|
7
|
-
envelope/compat/brax_envelope.py,sha256=T0Xefrpgft16SdRJKMIjzdJq5tbocQ4jeI763PDzZ_4,3445
|
|
8
|
-
envelope/compat/craftax_envelope.py,sha256=t151fDCrgPWfynERXHYXiWBwCGAXbFcreuIyH7GCApg,3162
|
|
9
|
-
envelope/compat/gymnax_envelope.py,sha256=iau4CKFLVIJkRA9WILgOqkp3mXIAzkilKcrKF8vkld4,3790
|
|
10
|
-
envelope/compat/jumanji_envelope.py,sha256=_Tox7DfiyE-Z4drszPUe-arPurz6hMeFtRNfy8o9Ipc,4869
|
|
11
|
-
envelope/compat/kinetix_envelope.py,sha256=kPLWwSoQ49kNXuvfbStOtoGby9EkV7bKAetxUa1_YPA,6817
|
|
12
|
-
envelope/compat/mujoco_playground_envelope.py,sha256=UcYZwOPD27473wmDDb5n2eK46uPwOeh0LZOgyov4rww,3706
|
|
13
|
-
envelope/compat/navix_envelope.py,sha256=K7AGae9_kDETFAUuWH9sCCLDo5fuEFbc0eKTHcvSxek,3027
|
|
14
|
-
envelope/wrappers/__init__.py,sha256=r0Qz7gNkgJQF848Iv01uMGbeXTKHUdK2tjyom4nPCGw,701
|
|
15
|
-
envelope/wrappers/autoreset_wrapper.py,sha256=3OUUNb4L6P4Ncn57Uy2ldoLFr1IKnPSWuzyjf20N0Y8,1299
|
|
16
|
-
envelope/wrappers/episode_statistics_wrapper.py,sha256=Mj5Ua7cLtBQYtFn3oQB9eJ7TclKwAFUvsQHIRkcQJvs,1509
|
|
17
|
-
envelope/wrappers/normalization.py,sha256=xHezXsb1J09D3IZACC1xxMy9NuUSBWqCyYH2wFAItPs,1811
|
|
18
|
-
envelope/wrappers/observation_normalization_wrapper.py,sha256=NNHz_THFe2eQUY0Pued9_hJztIGU6SU7kae_AKFbe4k,4248
|
|
19
|
-
envelope/wrappers/state_injection_wrapper.py,sha256=yk7he1zPaEdzjks_NDnYDC91e5bXTlCFPKFMwm-A7jk,3687
|
|
20
|
-
envelope/wrappers/truncation_wrapper.py,sha256=TzJqxjAjipk7pTk-drXRvQ41uR-jWsHlAgYJpbHI32M,1132
|
|
21
|
-
envelope/wrappers/vmap_envs_wrapper.py,sha256=TezSCeOf3oJKj8w7i_UaAA2LsoWkhhp835W545Q-waI,2561
|
|
22
|
-
envelope/wrappers/vmap_wrapper.py,sha256=zYMYVvuVQfRM7d2ygxxcavRNtkdzkqhArakE507tcU4,1587
|
|
23
|
-
envelope/wrappers/wrapper.py,sha256=ydolcZR9ogW3z4ui07845mbMue2XtEoEJ64tHCzAAtw,1504
|
|
24
|
-
jax_envelope-0.1.1.dist-info/METADATA,sha256=JdPA7sTzarKPGRTiCckPxwPpy2MU_Wm-5vaKW6dIelg,4652
|
|
25
|
-
jax_envelope-0.1.1.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
|
|
26
|
-
jax_envelope-0.1.1.dist-info/licenses/LICENSE,sha256=VyF-MK-gY2_fZlhf8uEnE2y8ziIXK-w55GM12eOgXrQ,1069
|
|
27
|
-
jax_envelope-0.1.1.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|