jax-envelope 0.1.0__py3-none-any.whl → 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
envelope/__init__.py CHANGED
@@ -0,0 +1,42 @@
1
+ from envelope.compat import create
2
+ from envelope.environment import Environment, Info, InfoContainer
3
+ from envelope.spaces import BatchedSpace, Continuous, Discrete, PyTreeSpace, Space
4
+ from envelope.struct import field, static_field, FrozenPyTreeNode, Container
5
+ from envelope.wrappers import (
6
+ Wrapper,
7
+ WrappedState,
8
+ AutoResetWrapper,
9
+ ObservationNormalizationWrapper,
10
+ StateInjectionWrapper,
11
+ TruncationWrapper,
12
+ VmapWrapper,
13
+ VmapEnvsWrapper,
14
+ )
15
+
16
+ __all__ = [
17
+ # Basic functionality
18
+ "create",
19
+ "Environment",
20
+ "Info",
21
+ "InfoContainer",
22
+ # Spaces
23
+ "Space",
24
+ "BatchedSpace",
25
+ "Continuous",
26
+ "Discrete",
27
+ "PyTreeSpace",
28
+ # Struct
29
+ "field",
30
+ "static_field",
31
+ "FrozenPyTreeNode",
32
+ "Container",
33
+ # Wrappers
34
+ "Wrapper",
35
+ "WrappedState",
36
+ "AutoResetWrapper",
37
+ "ObservationNormalizationWrapper",
38
+ "StateInjectionWrapper",
39
+ "TruncationWrapper",
40
+ "VmapWrapper",
41
+ "VmapEnvsWrapper",
42
+ ]
envelope/typing.py CHANGED
@@ -1,4 +1,3 @@
1
- from enum import Enum
2
1
  from typing import Any, TypeAlias
3
2
 
4
3
  import jax
@@ -6,18 +5,3 @@ import jax
6
5
  PyTree: TypeAlias = Any
7
6
  Key: TypeAlias = jax.Array
8
7
  Array: TypeAlias = jax.Array
9
-
10
-
11
- class BatchKind(Enum):
12
- """
13
- Batch semantics for environments.
14
- - VMAP: environment represents instances compatible with `jax.vmap`, for example by
15
- wrapping it in a `VmapWrapper`.
16
- - NATIVE_POOL: environment represents a batch of instances via a native pool. This
17
- is the case when it is wrapping a non-jax-based environment that supports native
18
- batching, for example those provided by envpool. Environments in this mode cannot
19
- be vmapped, as it would break the native batching semantics.
20
- """
21
-
22
- VMAP = "vmap"
23
- NATIVE_POOL = "native_pool"
@@ -0,0 +1,20 @@
1
+ from envelope.wrappers.autoreset_wrapper import AutoResetWrapper
2
+ from envelope.wrappers.observation_normalization_wrapper import (
3
+ ObservationNormalizationWrapper,
4
+ )
5
+ from envelope.wrappers.state_injection_wrapper import StateInjectionWrapper
6
+ from envelope.wrappers.truncation_wrapper import TruncationWrapper
7
+ from envelope.wrappers.vmap_wrapper import VmapWrapper
8
+ from envelope.wrappers.vmap_envs_wrapper import VmapEnvsWrapper
9
+ from envelope.wrappers.wrapper import Wrapper, WrappedState
10
+
11
+ __all__ = [
12
+ "Wrapper",
13
+ "WrappedState",
14
+ "AutoResetWrapper",
15
+ "ObservationNormalizationWrapper",
16
+ "StateInjectionWrapper",
17
+ "TruncationWrapper",
18
+ "VmapWrapper",
19
+ "VmapEnvsWrapper",
20
+ ]
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: jax-envelope
3
- Version: 0.1.0
3
+ Version: 0.1.1
4
4
  Summary: A JAX-native environment interface with powerful wrappers and adapters for popular RL environment suites
5
5
  Project-URL: Homepage, https://github.com/keraJLi/envelope
6
6
  Project-URL: Repository, https://github.com/keraJLi/envelope
@@ -81,7 +81,7 @@ let's you create environments from any of the above!
81
81
  pip install jax-envelope
82
82
  ```
83
83
 
84
- ## 💞 Related projects
84
+ ## 💞 Related projects
85
85
  * [stoax](https://github.com/EdanToledo/Stoa) is a very similar project that provides adapters and wrappers for the jumanji-like interface.
86
86
  * Check out all the great suites we have adapters for! [gymnax](https://github.com/RobertTLange/gymnax), [brax](https://github.com/google/brax), [jumanji](https://github.com/instadeepai/jumanji), [kinetix](https://github.com/flairox/kinetix), [craftax](https://github.com/MichaelTMatthews/craftax), [mujoco_playground](https://github.com/google-deepmind/mujoco_playground).
87
87
  * We will be adding support for [jaxmarl](https://github.com/flairox/jaxmarl) and [pgx](https://github.com/sotetsuk/pgx) in the future, as soon as we figured out the best ever MARL interface for JAX!
@@ -1,8 +1,8 @@
1
- envelope/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
1
+ envelope/__init__.py,sha256=x4DtJ3WWPsPde1wSrXAXCdJVuXV3pioAojrrOvpJNqw,975
2
2
  envelope/environment.py,sha256=4T0IosoRX-6gVvZCKBUEitlM_5cAT43hrKQ0Hr53XGk,1735
3
3
  envelope/spaces.py,sha256=fvJ3-ZI3iyFlpfEN1barTd9cbWvQ11nl-7Y45KgFizs,6433
4
4
  envelope/struct.py,sha256=Sb8GsxZ7rFF5A5128oZWEQVnFK79xUFe6EaU2bghBOw,5158
5
- envelope/typing.py,sha256=dcftDRNM0luCqHDEJ_qVmZnjUOEwn-HiFvwuV1z1Bn0,734
5
+ envelope/typing.py,sha256=2rWKiZnKXaUU6JIl51Wcj_txRwe8wQNRSLSiRuT3OP4,127
6
6
  envelope/compat/__init__.py,sha256=6q2B2bfTIu7MlnIXc702ysqEnyURz-MEjTBMBCo8rdQ,3458
7
7
  envelope/compat/brax_envelope.py,sha256=T0Xefrpgft16SdRJKMIjzdJq5tbocQ4jeI763PDzZ_4,3445
8
8
  envelope/compat/craftax_envelope.py,sha256=t151fDCrgPWfynERXHYXiWBwCGAXbFcreuIyH7GCApg,3162
@@ -11,17 +11,17 @@ envelope/compat/jumanji_envelope.py,sha256=_Tox7DfiyE-Z4drszPUe-arPurz6hMeFtRNfy
11
11
  envelope/compat/kinetix_envelope.py,sha256=kPLWwSoQ49kNXuvfbStOtoGby9EkV7bKAetxUa1_YPA,6817
12
12
  envelope/compat/mujoco_playground_envelope.py,sha256=UcYZwOPD27473wmDDb5n2eK46uPwOeh0LZOgyov4rww,3706
13
13
  envelope/compat/navix_envelope.py,sha256=K7AGae9_kDETFAUuWH9sCCLDo5fuEFbc0eKTHcvSxek,3027
14
+ envelope/wrappers/__init__.py,sha256=r0Qz7gNkgJQF848Iv01uMGbeXTKHUdK2tjyom4nPCGw,701
14
15
  envelope/wrappers/autoreset_wrapper.py,sha256=3OUUNb4L6P4Ncn57Uy2ldoLFr1IKnPSWuzyjf20N0Y8,1299
15
16
  envelope/wrappers/episode_statistics_wrapper.py,sha256=Mj5Ua7cLtBQYtFn3oQB9eJ7TclKwAFUvsQHIRkcQJvs,1509
16
17
  envelope/wrappers/normalization.py,sha256=xHezXsb1J09D3IZACC1xxMy9NuUSBWqCyYH2wFAItPs,1811
17
18
  envelope/wrappers/observation_normalization_wrapper.py,sha256=NNHz_THFe2eQUY0Pued9_hJztIGU6SU7kae_AKFbe4k,4248
18
19
  envelope/wrappers/state_injection_wrapper.py,sha256=yk7he1zPaEdzjks_NDnYDC91e5bXTlCFPKFMwm-A7jk,3687
19
- envelope/wrappers/timestep_wrapper.py,sha256=6jS-80AwnIIct3Ool9zq_iGOyfodNbs89NfuyYMfYIE,874
20
20
  envelope/wrappers/truncation_wrapper.py,sha256=TzJqxjAjipk7pTk-drXRvQ41uR-jWsHlAgYJpbHI32M,1132
21
21
  envelope/wrappers/vmap_envs_wrapper.py,sha256=TezSCeOf3oJKj8w7i_UaAA2LsoWkhhp835W545Q-waI,2561
22
22
  envelope/wrappers/vmap_wrapper.py,sha256=zYMYVvuVQfRM7d2ygxxcavRNtkdzkqhArakE507tcU4,1587
23
23
  envelope/wrappers/wrapper.py,sha256=ydolcZR9ogW3z4ui07845mbMue2XtEoEJ64tHCzAAtw,1504
24
- jax_envelope-0.1.0.dist-info/METADATA,sha256=-ASxsWLgF9CwcQ-mp1NJAmdFDS4qzs8bekNJH291R00,4653
25
- jax_envelope-0.1.0.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
26
- jax_envelope-0.1.0.dist-info/licenses/LICENSE,sha256=VyF-MK-gY2_fZlhf8uEnE2y8ziIXK-w55GM12eOgXrQ,1069
27
- jax_envelope-0.1.0.dist-info/RECORD,,
24
+ jax_envelope-0.1.1.dist-info/METADATA,sha256=JdPA7sTzarKPGRTiCckPxwPpy2MU_Wm-5vaKW6dIelg,4652
25
+ jax_envelope-0.1.1.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
26
+ jax_envelope-0.1.1.dist-info/licenses/LICENSE,sha256=VyF-MK-gY2_fZlhf8uEnE2y8ziIXK-w55GM12eOgXrQ,1069
27
+ jax_envelope-0.1.1.dist-info/RECORD,,
@@ -1,22 +0,0 @@
1
- from envelope.environment import Info
2
- from envelope.struct import field
3
- from envelope.typing import Key, PyTree
4
- from envelope.wrappers.wrapper import WrappedState, Wrapper
5
-
6
-
7
- class TimeStepWrapper(Wrapper):
8
- class TimeStepState(WrappedState):
9
- steps: PyTree = field(default=0)
10
-
11
- def reset(
12
- self, key: Key, state: PyTree | None = None, **kwargs
13
- ) -> tuple[WrappedState, Info]:
14
- inner_state, info = self.env.reset(key, state, **kwargs)
15
- return self.TimeStepState(inner_state=inner_state, steps=0), info
16
-
17
- def step(
18
- self, state: WrappedState, action: PyTree, **kwargs
19
- ) -> tuple[WrappedState, Info]:
20
- next_inner_state, info = self.env.step(state.inner_state, action, **kwargs)
21
- next_steps = getattr(state, "steps", 0) + 1
22
- return self.TimeStepState(inner_state=next_inner_state, steps=next_steps), info