jax-envelope 0.1.0__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (33)
  1. envelope/__init__.py +54 -0
  2. envelope/compat/brax_envelope.py +5 -3
  3. envelope/compat/craftax_envelope.py +17 -2
  4. envelope/compat/gymnax_envelope.py +34 -7
  5. envelope/compat/jumanji_envelope.py +3 -2
  6. envelope/compat/kinetix_envelope.py +3 -2
  7. envelope/compat/mujoco_playground_envelope.py +1 -1
  8. envelope/compat/navix_envelope.py +1 -1
  9. envelope/environment.py +16 -9
  10. envelope/spaces.py +41 -21
  11. envelope/struct.py +10 -1
  12. envelope/typing.py +0 -16
  13. envelope/wrappers/__init__.py +36 -0
  14. envelope/wrappers/autoreset_wrapper.py +65 -21
  15. envelope/wrappers/clip_action_wrapper.py +27 -0
  16. envelope/wrappers/continuous_observation_wrapper.py +61 -0
  17. envelope/wrappers/episode_statistics_wrapper.py +29 -36
  18. envelope/wrappers/flatten_action_wrapper.py +75 -0
  19. envelope/wrappers/flatten_observation_wrapper.py +81 -0
  20. envelope/wrappers/normalization.py +1 -1
  21. envelope/wrappers/observation_normalization_wrapper.py +28 -16
  22. envelope/wrappers/pooled_init_vmap_wrapper.py +122 -0
  23. envelope/wrappers/state_injection_wrapper.py +18 -22
  24. envelope/wrappers/truncation_wrapper.py +18 -14
  25. envelope/wrappers/vmap_envs_wrapper.py +26 -21
  26. envelope/wrappers/vmap_wrapper.py +36 -21
  27. envelope/wrappers/wrapper.py +8 -8
  28. {jax_envelope-0.1.0.dist-info → jax_envelope-0.2.0.dist-info}/METADATA +3 -3
  29. jax_envelope-0.2.0.dist-info/RECORD +32 -0
  30. envelope/wrappers/timestep_wrapper.py +0 -22
  31. jax_envelope-0.1.0.dist-info/RECORD +0 -27
  32. {jax_envelope-0.1.0.dist-info → jax_envelope-0.2.0.dist-info}/WHEEL +0 -0
  33. {jax_envelope-0.1.0.dist-info → jax_envelope-0.2.0.dist-info}/licenses/LICENSE +0 -0
@@ -17,34 +17,39 @@ class VmapEnvsWrapper(Wrapper):
17
17
  Usage:
18
18
  envs = jax.vmap(make_env)(params_batch) # env pytree batched on leading axis
19
19
  wrapped = VmapEnvsWrapper(env=envs, batch_size=B)
20
- state, info = wrapped.reset(keys) # keys shape (B, 2) or single key
20
+ state, info = wrapped.init(keys) # keys shape (B, 2) or single key
21
21
  next_state, info = wrapped.step(state, action)
22
22
  """
23
23
 
24
24
  batch_size: int = field(kw_only=True)
25
25
 
26
- @override
27
- def reset(
28
- self, key: Key, state: PyTree | None = None, **kwargs
29
- ) -> tuple[WrappedState, Info]:
26
+ def _split_keys(self, key: Key) -> Key:
30
27
  if key.shape == (2,):
31
- keys = jax.random.split(key, self.batch_size)
32
- else:
33
- if key.shape[0] != self.batch_size:
34
- raise ValueError(
35
- f"reset key's leading dimension ({key.shape[0]}) must match "
36
- f"batch_size ({self.batch_size})."
37
- )
38
- keys = key
39
- # vmap over env 'self' and keys
40
- state, info = jax.vmap(lambda e, k: e.reset(k, state, **kwargs))(self.env, keys)
28
+ return jax.random.split(key, self.batch_size)
29
+ if key.shape[0] != self.batch_size:
30
+ raise ValueError(
31
+ f"reset key's leading dimension ({key.shape[0]}) must match "
32
+ f"batch_size ({self.batch_size})."
33
+ )
34
+ return key
35
+
36
+ @override
37
+ def init(self, key: Key) -> tuple[WrappedState, Info]:
38
+ keys = self._split_keys(key)
39
+ state, info = jax.vmap(lambda e, k: e.init(k))(self.env, keys)
40
+ return state, info
41
+
42
+ @override
43
+ def reset(self, key: Key, state: PyTree) -> tuple[WrappedState, Info]:
44
+ keys = self._split_keys(key)
45
+ state, info = jax.vmap(lambda e, k, s: e.reset(k, s))(
46
+ self.env, keys, state
47
+ )
41
48
  return state, info
42
49
 
43
50
  @override
44
- def step(
45
- self, state: WrappedState, action: PyTree, **kwargs
46
- ) -> tuple[WrappedState, Info]:
47
- next_state, info = jax.vmap(lambda e, s, a: e.step(s, a, **kwargs))(
51
+ def step(self, state: WrappedState, action: PyTree) -> tuple[WrappedState, Info]:
52
+ next_state, info = jax.vmap(lambda e, s, a: e.step(s, a))(
48
53
  self.env, state, action
49
54
  )
50
55
  return next_state, info
@@ -53,13 +58,13 @@ class VmapEnvsWrapper(Wrapper):
53
58
  @property
54
59
  def observation_space(self) -> spaces.Space:
55
60
  env0 = _index_env(self.env, 0, self.batch_size)
56
- return spaces.batch_space(env0.observation_space, self.batch_size)
61
+ return spaces.BatchedSpace(space=env0.observation_space, batch_size=self.batch_size)
57
62
 
58
63
  @override
59
64
  @cached_property
60
65
  def action_space(self) -> spaces.Space:
61
66
  env0 = _index_env(self.env, 0, self.batch_size)
62
- return spaces.batch_space(env0.action_space, self.batch_size)
67
+ return spaces.BatchedSpace(space=env0.action_space, batch_size=self.batch_size)
63
68
 
64
69
  @override
65
70
  @property
@@ -2,50 +2,65 @@ from functools import cached_property
2
2
  from typing import override
3
3
 
4
4
  import jax
5
+ import jax.numpy as jnp
5
6
 
6
7
  from envelope import spaces
7
8
  from envelope.environment import Info
8
- from envelope.struct import field
9
+ from envelope.struct import static_field
9
10
  from envelope.typing import Key, PyTree
10
11
  from envelope.wrappers.wrapper import WrappedState, Wrapper
11
12
 
12
13
 
14
+ def is_single_key(key):
15
+ # New-style typed keys have dtype like key<fry>
16
+ if jnp.issubdtype(key.dtype, jax.dtypes.prng_key):
17
+ return key.ndim == 0
18
+ return key.shape == (2,)
19
+
20
+
21
+ def _split_or_keep_key(key: Key, batch_size: int) -> Key:
22
+ if is_single_key(key):
23
+ return jax.random.split(key, batch_size)
24
+ elif key.shape[0] == batch_size:
25
+ return key
26
+ raise ValueError(
27
+ f"reset key's leading dimension ({key.shape[0]}) must match "
28
+ f"batch_size ({batch_size})."
29
+ )
30
+
31
+
13
32
  class VmapWrapper(Wrapper):
14
- """Does not forward kwargs to the underlying env. Does not wrap the state."""
33
+ """Does not wrap the state."""
15
34
 
16
- batch_size: int = field(kw_only=True)
35
+ batch_size: int = static_field(kw_only=True)
17
36
 
18
37
  @override
19
- def reset(
20
- self, key: Key, state: PyTree | None = None, **kwargs
21
- ) -> tuple[WrappedState, Info]:
22
- # Accept single key or batched keys
23
- if key.shape == (2,):
24
- keys = jax.random.split(key, self.batch_size)
25
- else:
26
- if key.shape[0] != self.batch_size:
27
- raise ValueError(
28
- f"reset key's leading dimension ({key.shape[0]}) must match "
29
- f"batch_size ({self.batch_size})."
30
- )
31
- keys = key
38
+ def init(self, key: Key) -> tuple[WrappedState, Info]:
39
+ keys = _split_or_keep_key(key, self.batch_size)
40
+ state, info = jax.vmap(self.env.init)(keys)
41
+ return state, info
32
42
 
43
+ @override
44
+ def reset(self, key: Key, state: PyTree) -> tuple[WrappedState, Info]:
45
+ keys = _split_or_keep_key(key, self.batch_size)
33
46
  state, info = jax.vmap(self.env.reset)(keys, state)
34
47
  return state, info
35
48
 
36
49
  @override
37
- def step(
38
- self, state: WrappedState, action: PyTree, **kwargs
39
- ) -> tuple[WrappedState, Info]:
50
+ def step(self, state: WrappedState, action: PyTree) -> tuple[WrappedState, Info]:
40
51
  state, info = jax.vmap(self.env.step)(state, action)
41
52
  return state, info
42
53
 
43
54
  @override
44
55
  @cached_property
45
56
  def observation_space(self) -> spaces.Space:
46
- return spaces.batch_space(self.env.observation_space, self.batch_size)
57
+ return spaces.BatchedSpace(
58
+ space=self.env.observation_space, batch_size=self.batch_size
59
+ )
47
60
 
48
61
  @override
49
62
  @cached_property
50
63
  def action_space(self) -> spaces.Space:
51
- return spaces.batch_space(self.env.action_space, self.batch_size)
64
+ return spaces.BatchedSpace(
65
+ space=self.env.action_space, batch_size=self.batch_size
66
+ )
@@ -25,16 +25,16 @@ class Wrapper(Environment):
25
25
  env: Environment = field(kw_only=True)
26
26
 
27
27
  @override
28
- def reset(
29
- self, key: Key, state: State | None = None, **kwargs
30
- ) -> tuple[State, Info]:
31
- return self.env.reset(key, state=state, **kwargs)
28
+ def init(self, key: Key) -> tuple[State, Info]:
29
+ return self.env.init(key)
32
30
 
33
31
  @override
34
- def step(
35
- self, state: WrappedState, action: PyTree, **kwargs
36
- ) -> tuple[WrappedState, Info]:
37
- return self.env.step(state, action, **kwargs)
32
+ def reset(self, key: Key, state: State) -> tuple[State, Info]:
33
+ return self.env.reset(key, state)
34
+
35
+ @override
36
+ def step(self, state: WrappedState, action: PyTree) -> tuple[WrappedState, Info]:
37
+ return self.env.step(state, action)
38
38
 
39
39
  @override
40
40
  @cached_property
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: jax-envelope
3
- Version: 0.1.0
3
+ Version: 0.2.0
4
4
  Summary: A JAX-native environment interface with powerful wrappers and adapters for popular RL environment suites
5
5
  Project-URL: Homepage, https://github.com/keraJLi/envelope
6
6
  Project-URL: Repository, https://github.com/keraJLi/envelope
@@ -81,7 +81,7 @@ let's you create environments from any of the above!
81
81
  pip install jax-envelope
82
82
  ```
83
83
 
84
- ## 💞 Related projects
85
- * [stoax](https://github.com/EdanToledo/Stoa) is a very similar project that provides adapters and wrappers for the jumanji-like interface.
84
+ ## 💞 Related projects
85
+ * [stoa](https://github.com/EdanToledo/Stoa) is a very similar project that provides adapters and wrappers for the jumanji-like interface.
86
86
  * Check out all the great suites we have adapters for! [gymnax](https://github.com/RobertTLange/gymnax), [brax](https://github.com/google/brax), [jumanji](https://github.com/instadeepai/jumanji), [kinetix](https://github.com/flairox/kinetix), [craftax](https://github.com/MichaelTMatthews/craftax), [mujoco_playground](https://github.com/google-deepmind/mujoco_playground).
87
87
  * We will be adding support for [jaxmarl](https://github.com/flairox/jaxmarl) and [pgx](https://github.com/sotetsuk/pgx) in the future, as soon as we figured out the best ever MARL interface for JAX!
@@ -0,0 +1,32 @@
1
+ envelope/__init__.py,sha256=AAk0b9UHGUTC4EzERXYPOrHuECgO0fx_UN4_gpKvgT0,1329
2
+ envelope/environment.py,sha256=OC15ooivLvD8EpmgRQGL-UjRpF0UXO0fOGu61I8HmRs,2015
3
+ envelope/spaces.py,sha256=Nb8T11ilCtqKPSLZ98oqhnuw_8JNLOGCZdsy2X_pSwU,6944
4
+ envelope/struct.py,sha256=tO1LLxk0bYc4kBCXraDv93T5d-xFRzpv4QCjASB0Lv4,5522
5
+ envelope/typing.py,sha256=2rWKiZnKXaUU6JIl51Wcj_txRwe8wQNRSLSiRuT3OP4,127
6
+ envelope/compat/__init__.py,sha256=6q2B2bfTIu7MlnIXc702ysqEnyURz-MEjTBMBCo8rdQ,3458
7
+ envelope/compat/brax_envelope.py,sha256=XV07CYApXmf5g-Bin4pKWqExTZ7gd6AO2BCx3bk1Z1c,3482
8
+ envelope/compat/craftax_envelope.py,sha256=ey_mhJh5KCDRvjpCAJbFexpWiFjzQr7BZuRYLM9g-yA,3959
9
+ envelope/compat/gymnax_envelope.py,sha256=X0pfFKaFKzvBl2Zih4utiXSEI-lc5f9miJGMxeT-Xfw,4690
10
+ envelope/compat/jumanji_envelope.py,sha256=WhU4WJLW2Fn4BG1_69EJO9ZiBDWEaY387dhkRZjqcGc,4916
11
+ envelope/compat/kinetix_envelope.py,sha256=99RNN6nBN9CQEEe2fX2T1n8lL9ki0uKaz5BK052IU1o,6853
12
+ envelope/compat/mujoco_playground_envelope.py,sha256=7iPaRuwXUHbP8hHNgSFipW8rJVruootMQPYBWNpPiHA,3705
13
+ envelope/compat/navix_envelope.py,sha256=2n8vgr6gVeAYo12D1kDPxUrA9Hm-If3E7rV7TF_ksSA,3026
14
+ envelope/wrappers/__init__.py,sha256=GFs54kY1aWbf9xymzm22x6M2AhZCwFLx1f5LF23jX94,1409
15
+ envelope/wrappers/autoreset_wrapper.py,sha256=MJlbOj94_6I1KQ4oBNsKEB416YWp9ZMiqn85b1PdhKQ,3382
16
+ envelope/wrappers/clip_action_wrapper.py,sha256=YKJbKBUbf-LADxIfCE6mMOKo_NEa2xPfDncNKg9iEHM,955
17
+ envelope/wrappers/continuous_observation_wrapper.py,sha256=geDOeQDuASiRDEv5U3fF7GxblfB0Ku99GvTUGZSteJs,2041
18
+ envelope/wrappers/episode_statistics_wrapper.py,sha256=-LzypX4aYS9mCkl33M2mf3aT3OhXWtDc_kVU6VlSJ0o,1433
19
+ envelope/wrappers/flatten_action_wrapper.py,sha256=znR7WSebLy2wZnEWHi9SshxgV9b29khg-6ogQlmXq9U,2609
20
+ envelope/wrappers/flatten_observation_wrapper.py,sha256=xgoToXdJ_6M7cwfL7pu91ZnnFh3LHPR_DSe7XJwkh2Q,2839
21
+ envelope/wrappers/normalization.py,sha256=YkNEZqvFSaPurBA-WfXDynV51552ucNkY0JqNSAKNUE,1819
22
+ envelope/wrappers/observation_normalization_wrapper.py,sha256=pAwyd01lAxvEu1Z4KZfhSRoVNeLpeohtw1t1wsE53w4,4921
23
+ envelope/wrappers/pooled_init_vmap_wrapper.py,sha256=LAgrEjLXlq-gJiCGzI357a41tx1mVgb1ZQwtgN8mTuw,4985
24
+ envelope/wrappers/state_injection_wrapper.py,sha256=L8f9yrlSigk0jNx6D6wSu0cR0bgssq2uJNd9OYPAY0E,3553
25
+ envelope/wrappers/truncation_wrapper.py,sha256=-CSoXs-mDy1AiJVKLUc2zBHQxalhYjjyAAu1NmiIUsE,1345
26
+ envelope/wrappers/vmap_envs_wrapper.py,sha256=gE_TAmfsu2PJlDGWTz5xENJBJ-6nSJpTUlKhOnmZPi0,2757
27
+ envelope/wrappers/vmap_wrapper.py,sha256=B43lFNdAN_d8LgN6Rz6Ty85r9zcXLB3Wrn8R-Z2qHps,1973
28
+ envelope/wrappers/wrapper.py,sha256=tWluyIg4eGl60R2v3qaahU5dlmKKkBsJiNj7GnWR06E,1517
29
+ jax_envelope-0.2.0.dist-info/METADATA,sha256=jCWoAuoFsrLJakIhecRdz1zOIGCdrD4dBS0JSA2LVvw,4651
30
+ jax_envelope-0.2.0.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
31
+ jax_envelope-0.2.0.dist-info/licenses/LICENSE,sha256=VyF-MK-gY2_fZlhf8uEnE2y8ziIXK-w55GM12eOgXrQ,1069
32
+ jax_envelope-0.2.0.dist-info/RECORD,,
@@ -1,22 +0,0 @@
1
- from envelope.environment import Info
2
- from envelope.struct import field
3
- from envelope.typing import Key, PyTree
4
- from envelope.wrappers.wrapper import WrappedState, Wrapper
5
-
6
-
7
- class TimeStepWrapper(Wrapper):
8
- class TimeStepState(WrappedState):
9
- steps: PyTree = field(default=0)
10
-
11
- def reset(
12
- self, key: Key, state: PyTree | None = None, **kwargs
13
- ) -> tuple[WrappedState, Info]:
14
- inner_state, info = self.env.reset(key, state, **kwargs)
15
- return self.TimeStepState(inner_state=inner_state, steps=0), info
16
-
17
- def step(
18
- self, state: WrappedState, action: PyTree, **kwargs
19
- ) -> tuple[WrappedState, Info]:
20
- next_inner_state, info = self.env.step(state.inner_state, action, **kwargs)
21
- next_steps = getattr(state, "steps", 0) + 1
22
- return self.TimeStepState(inner_state=next_inner_state, steps=next_steps), info
@@ -1,27 +0,0 @@
1
- envelope/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
- envelope/environment.py,sha256=4T0IosoRX-6gVvZCKBUEitlM_5cAT43hrKQ0Hr53XGk,1735
3
- envelope/spaces.py,sha256=fvJ3-ZI3iyFlpfEN1barTd9cbWvQ11nl-7Y45KgFizs,6433
4
- envelope/struct.py,sha256=Sb8GsxZ7rFF5A5128oZWEQVnFK79xUFe6EaU2bghBOw,5158
5
- envelope/typing.py,sha256=dcftDRNM0luCqHDEJ_qVmZnjUOEwn-HiFvwuV1z1Bn0,734
6
- envelope/compat/__init__.py,sha256=6q2B2bfTIu7MlnIXc702ysqEnyURz-MEjTBMBCo8rdQ,3458
7
- envelope/compat/brax_envelope.py,sha256=T0Xefrpgft16SdRJKMIjzdJq5tbocQ4jeI763PDzZ_4,3445
8
- envelope/compat/craftax_envelope.py,sha256=t151fDCrgPWfynERXHYXiWBwCGAXbFcreuIyH7GCApg,3162
9
- envelope/compat/gymnax_envelope.py,sha256=iau4CKFLVIJkRA9WILgOqkp3mXIAzkilKcrKF8vkld4,3790
10
- envelope/compat/jumanji_envelope.py,sha256=_Tox7DfiyE-Z4drszPUe-arPurz6hMeFtRNfy8o9Ipc,4869
11
- envelope/compat/kinetix_envelope.py,sha256=kPLWwSoQ49kNXuvfbStOtoGby9EkV7bKAetxUa1_YPA,6817
12
- envelope/compat/mujoco_playground_envelope.py,sha256=UcYZwOPD27473wmDDb5n2eK46uPwOeh0LZOgyov4rww,3706
13
- envelope/compat/navix_envelope.py,sha256=K7AGae9_kDETFAUuWH9sCCLDo5fuEFbc0eKTHcvSxek,3027
14
- envelope/wrappers/autoreset_wrapper.py,sha256=3OUUNb4L6P4Ncn57Uy2ldoLFr1IKnPSWuzyjf20N0Y8,1299
15
- envelope/wrappers/episode_statistics_wrapper.py,sha256=Mj5Ua7cLtBQYtFn3oQB9eJ7TclKwAFUvsQHIRkcQJvs,1509
16
- envelope/wrappers/normalization.py,sha256=xHezXsb1J09D3IZACC1xxMy9NuUSBWqCyYH2wFAItPs,1811
17
- envelope/wrappers/observation_normalization_wrapper.py,sha256=NNHz_THFe2eQUY0Pued9_hJztIGU6SU7kae_AKFbe4k,4248
18
- envelope/wrappers/state_injection_wrapper.py,sha256=yk7he1zPaEdzjks_NDnYDC91e5bXTlCFPKFMwm-A7jk,3687
19
- envelope/wrappers/timestep_wrapper.py,sha256=6jS-80AwnIIct3Ool9zq_iGOyfodNbs89NfuyYMfYIE,874
20
- envelope/wrappers/truncation_wrapper.py,sha256=TzJqxjAjipk7pTk-drXRvQ41uR-jWsHlAgYJpbHI32M,1132
21
- envelope/wrappers/vmap_envs_wrapper.py,sha256=TezSCeOf3oJKj8w7i_UaAA2LsoWkhhp835W545Q-waI,2561
22
- envelope/wrappers/vmap_wrapper.py,sha256=zYMYVvuVQfRM7d2ygxxcavRNtkdzkqhArakE507tcU4,1587
23
- envelope/wrappers/wrapper.py,sha256=ydolcZR9ogW3z4ui07845mbMue2XtEoEJ64tHCzAAtw,1504
24
- jax_envelope-0.1.0.dist-info/METADATA,sha256=-ASxsWLgF9CwcQ-mp1NJAmdFDS4qzs8bekNJH291R00,4653
25
- jax_envelope-0.1.0.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
26
- jax_envelope-0.1.0.dist-info/licenses/LICENSE,sha256=VyF-MK-gY2_fZlhf8uEnE2y8ziIXK-w55GM12eOgXrQ,1069
27
- jax_envelope-0.1.0.dist-info/RECORD,,