unienv 0.0.1b1__py3-none-any.whl → 0.0.1b3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (59) hide show
  1. unienv-0.0.1b3.dist-info/METADATA +74 -0
  2. unienv-0.0.1b3.dist-info/RECORD +92 -0
  3. {unienv-0.0.1b1.dist-info → unienv-0.0.1b3.dist-info}/licenses/LICENSE +1 -3
  4. unienv-0.0.1b3.dist-info/top_level.txt +2 -0
  5. unienv_data/base/__init__.py +0 -1
  6. unienv_data/base/common.py +95 -45
  7. unienv_data/base/storage.py +1 -0
  8. unienv_data/batches/__init__.py +2 -1
  9. unienv_data/batches/backend_compat.py +47 -1
  10. unienv_data/batches/combined_batch.py +2 -4
  11. unienv_data/{base → batches}/transformations.py +3 -2
  12. unienv_data/replay_buffer/replay_buffer.py +4 -0
  13. unienv_data/samplers/__init__.py +0 -1
  14. unienv_data/samplers/multiprocessing_sampler.py +26 -22
  15. unienv_data/samplers/step_sampler.py +9 -18
  16. unienv_data/storages/common.py +5 -0
  17. unienv_data/storages/hdf5.py +291 -20
  18. unienv_data/storages/pytorch.py +1 -0
  19. unienv_data/storages/transformation.py +191 -0
  20. unienv_data/transformations/image_compress.py +213 -0
  21. unienv_interface/backends/jax.py +4 -1
  22. unienv_interface/backends/numpy.py +4 -1
  23. unienv_interface/backends/pytorch.py +4 -1
  24. unienv_interface/env_base/__init__.py +1 -0
  25. unienv_interface/env_base/env.py +5 -0
  26. unienv_interface/env_base/funcenv.py +32 -1
  27. unienv_interface/env_base/funcenv_wrapper.py +2 -2
  28. unienv_interface/env_base/vec_env.py +474 -0
  29. unienv_interface/func_wrapper/__init__.py +2 -1
  30. unienv_interface/func_wrapper/frame_stack.py +150 -0
  31. unienv_interface/space/space_utils/__init__.py +1 -0
  32. unienv_interface/space/space_utils/batch_utils.py +83 -0
  33. unienv_interface/space/space_utils/construct_utils.py +216 -0
  34. unienv_interface/space/space_utils/serialization_utils.py +16 -1
  35. unienv_interface/space/spaces/__init__.py +3 -1
  36. unienv_interface/space/spaces/batched.py +90 -0
  37. unienv_interface/space/spaces/binary.py +0 -1
  38. unienv_interface/space/spaces/box.py +13 -24
  39. unienv_interface/space/spaces/text.py +1 -3
  40. unienv_interface/transformations/dict_transform.py +31 -5
  41. unienv_interface/utils/control_util.py +68 -0
  42. unienv_interface/utils/data_queue.py +184 -0
  43. unienv_interface/utils/stateclass.py +46 -0
  44. unienv_interface/utils/vec_util.py +15 -0
  45. unienv_interface/world/__init__.py +3 -1
  46. unienv_interface/world/combined_funcnode.py +336 -0
  47. unienv_interface/world/combined_node.py +232 -0
  48. unienv_interface/wrapper/backend_compat.py +2 -2
  49. unienv_interface/wrapper/frame_stack.py +19 -114
  50. unienv_interface/wrapper/video_record.py +11 -2
  51. unienv-0.0.1b1.dist-info/METADATA +0 -20
  52. unienv-0.0.1b1.dist-info/RECORD +0 -85
  53. unienv-0.0.1b1.dist-info/top_level.txt +0 -4
  54. unienv_data/samplers/slice_sampler.py +0 -266
  55. unienv_maniskill/__init__.py +0 -1
  56. unienv_maniskill/wrapper/maniskill_compat.py +0 -235
  57. unienv_mjxplayground/__init__.py +0 -1
  58. unienv_mjxplayground/wrapper/playground_compat.py +0 -256
  59. {unienv-0.0.1b1.dist-info → unienv-0.0.1b3.dist-info}/WHEEL +0 -0
@@ -1,256 +0,0 @@
1
- from typing import Any, Optional, Tuple, Dict, Union, SupportsFloat, Sequence, Callable, Mapping
2
- import jax
3
- import jax.numpy as jnp
4
- import numpy as np
5
-
6
- from unienv_interface.env_base import FuncEnv
7
- from unienv_interface.space import Space, BoxSpace, DictSpace
8
- from unienv_interface.backends import ComputeBackend, BArrayType, BDeviceType, BDtypeType, BRNGType
9
- from unienv_interface.backends.numpy import NumpyComputeBackend
10
- from unienv_interface.backends.jax import JaxComputeBackend, JaxArrayType, JaxDeviceType, JaxDtypeType, JaxRNGType
11
- from unienv_interface.space.space_utils import batch_utils as sbu
12
- from unienv_interface.wrapper import backend_compat
13
-
14
- import mujoco.mjx as mjx
15
- from mujoco_playground import MjxEnv, State as MjxState
16
- from mujoco_playground._src.wrapper import Wrapper, MadronaWrapper, BraxDomainRandomizationVmapWrapper
17
- from brax.envs.wrappers import training as brax_training
18
- from flax import struct
19
-
20
# Recursive size/axis specification: a string-keyed mapping of further
# specifications, a single int, or None.
AxisMapSingleT = Union[Mapping[str, "AxisMapSingleT"], int, None]
# Either one specification or a tuple of them.
AxisMapT = Union[AxisMapSingleT, Tuple[AxisMapSingleT, ...]]
# A single JAX array or a string-keyed dict (contents untyped).
JaxTreeOrArrayT = Union[JaxArrayType, Dict[str, Any]]
# Domain-randomization callback: takes an mjx.Model and returns a
# (model, axis specification) pair.
RandomizationFnT = Callable[[mjx.Model], Tuple[mjx.Model, AxisMapSingleT]]
24
-
25
@struct.dataclass
class MJXPlaygroundState:
    """Functional state threaded through ``FromMJXPlaygroundEnv``.

    Pairs the state object produced by the wrapped environment's
    ``reset``/``step`` with the PRNG key kept for later resets.
    """
    # State returned by the wrapped env's reset/step functions.
    state : MjxState
    # Carry-over PRNG key; ``reset`` falls back to it when no seed is given.
    rng : JaxRNGType
29
-
30
def is_mjx_env_vision(
    env: MjxEnv
) -> bool:
    """Report whether ``env`` exposes a Madrona ``BatchRenderer``.

    Returns ``False`` when ``madrona_mjx`` cannot be imported, and otherwise
    ``True`` only if ``env`` has a ``renderer`` attribute that is a
    ``BatchRenderer`` instance.
    """
    try:
        from madrona_mjx.renderer import BatchRenderer  # pytype: disable=import-error
    except ImportError:
        # The optional renderer package is missing, so nothing can match.
        return False
    # A missing attribute maps to None, which never passes the isinstance test.
    return isinstance(getattr(env, "renderer", None), BatchRenderer)
40
-
41
def wrap_mjx_env(
    env: MjxEnv,
    batch_size: int,
    is_vision : bool = False,
    randomization_fn : Optional[
        RandomizationFnT
    ] = None,
) -> Wrapper:
    """Wrap a single ``MjxEnv`` in the batching wrapper matching its setup.

    Args:
        env: Environment to wrap.
        batch_size: Forwarded to ``MadronaWrapper``; unused by the other
            two wrappers.
        is_vision: When true, ``MadronaWrapper`` is selected regardless of
            ``randomization_fn``.
        randomization_fn: Optional randomization callback passed to the
            chosen wrapper when it accepts one.

    Returns:
        ``MadronaWrapper`` for vision envs, ``brax_training.VmapWrapper`` when
        no randomization function is given, and
        ``BraxDomainRandomizationVmapWrapper`` otherwise.
    """
    if is_vision:
        return MadronaWrapper(env, batch_size, randomization_fn)
    if randomization_fn is None:
        return brax_training.VmapWrapper(env)
    return BraxDomainRandomizationVmapWrapper(env, randomization_fn)
56
-
57
def space_from_size(
    size : AxisMapSingleT,
    device : Optional[JaxDeviceType] = None,
) -> Union[
    BoxSpace[JaxArrayType, JaxDeviceType, JaxDtypeType, JaxRNGType],
    DictSpace[JaxDeviceType, JaxDtypeType, JaxRNGType]
]:
    """Build a JAX-backed space from a (possibly nested) size specification.

    Args:
        size: A mapping yields a ``DictSpace`` whose entries are built
            recursively. Anything else yields an unbounded float32
            ``BoxSpace``; an int ``n`` becomes the shape ``(n,)`` and other
            values are passed through as the shape unchanged.
        device: Device forwarded to every space that is created.

    Returns:
        The constructed ``BoxSpace`` or ``DictSpace``.
    """
    if not isinstance(size, Mapping):
        # Leaf: promote a bare int to a one-dimensional shape.
        box_shape = (size,) if isinstance(size, int) else size
        return BoxSpace(
            JaxComputeBackend,
            -jnp.inf,
            jnp.inf,
            shape=box_shape,
            dtype=jnp.float32,
            device=device
        )

    sub_spaces = {}
    for key, size_inner in size.items():
        sub_spaces[key] = space_from_size(size_inner, device=device)
    return DictSpace(
        JaxComputeBackend,
        sub_spaces,
        device=device
    )
82
-
83
class FromMJXPlaygroundEnv(
    FuncEnv[
        MJXPlaygroundState,
        None,
        JaxArrayType,
        None,
        JaxTreeOrArrayT,
        JaxTreeOrArrayT,
        None,
        JaxDeviceType,
        JaxDtypeType,
        JaxRNGType
    ]
):
    """Functional, batched adapter around a MuJoCo Playground ``MjxEnv``.

    The single environment is wrapped with ``wrap_mjx_env`` and its
    ``reset``/``step`` functions are optionally jitted. All per-episode data
    lives in ``MJXPlaygroundState`` values passed in and out of the methods.
    """
    metadata = {
        "render_modes": ['rgb_array', 'human']
    }
    backend = JaxComputeBackend

    def __init__(
        self,
        single_env: MjxEnv,
        batch_size: int,
        randomization_fn : Optional[
            RandomizationFnT
        ] = None,
        device : Optional[JaxDeviceType] = None,
        jit : bool = True
    ) -> None:
        """Wrap ``single_env`` and build the batched action/observation spaces.

        Args:
            single_env: Unbatched environment to wrap.
            batch_size: Number of parallel environments.
            randomization_fn: Optional callback forwarded to ``wrap_mjx_env``.
            device: Device used for jitting and for the created spaces.
            jit: When true, the wrapped ``reset``/``step`` are passed through
                ``jax.jit``; otherwise they are used as-is.
        """
        # Assigned first: the jit calls and space construction below read
        # self.device, so it must hold the caller's value before they run.
        self.batch_size = batch_size
        self.device = device

        self.single_env = single_env
        self.env = wrap_mjx_env(
            single_env,
            batch_size,
            is_vision=is_mjx_env_vision(single_env),
            randomization_fn=randomization_fn
        )
        if jit:
            self.vanilla_reset_fn = jax.jit(self.env.reset, device=self.device)
            self.vanilla_step_fn = jax.jit(self.env.step, device=self.device)
        else:
            self.vanilla_reset_fn = self.env.reset
            self.vanilla_step_fn = self.env.step

        self.action_space = sbu.batch_space(
            space_from_size(
                single_env.action_size,
                device=self.device
            ),
            batch_size
        )
        self.observation_space = sbu.batch_space(
            space_from_size(
                single_env.observation_size,
                device=self.device
            ),
            batch_size
        )

        self.context_space = None

    def _split_reset_keys(self, rng : JaxRNGType) -> Tuple[JaxRNGType, JaxRNGType]:
        """Split ``rng`` into a carry key and ``batch_size`` reset keys."""
        rng, reset_rng = jax.random.split(rng)
        return rng, jax.random.split(reset_rng, self.batch_size)

    def initial(self, *, seed : Optional[int]) -> Tuple[
        MJXPlaygroundState,
        None,
        JaxTreeOrArrayT,
        Dict[str, Any]
    ]:
        """Create the first state by resetting every environment.

        Args:
            seed: Seed for the PRNG key. ``None`` draws a fresh 32-bit seed
                from OS entropy, since ``jax.random.PRNGKey`` needs an integer.

        Returns:
            ``(state, None, obs, info)`` where ``obs`` and ``info`` come from
            the wrapped environment's reset result.
        """
        if seed is None:
            seed = int(np.random.SeedSequence().generate_state(1)[0])
        rng, reset_rng = self._split_reset_keys(jax.random.PRNGKey(seed))
        raw_state = self.vanilla_reset_fn(rng=reset_rng)
        state = MJXPlaygroundState(
            state=raw_state,
            rng=rng
        )
        return state, None, raw_state.obs, raw_state.info

    def reset(
        self,
        state : MJXPlaygroundState,
        *,
        seed : Optional[int] = None,
        mask : Optional[JaxArrayType] = None
    ) -> Tuple[
        MJXPlaygroundState,
        None,
        JaxTreeOrArrayT,
        Dict[str, Any]
    ]:
        """Reset all environments, or only those selected by ``mask``.

        Args:
            state: Current state; its ``rng`` is reused when ``seed`` is None.
            seed: Optional seed that replaces the stored PRNG key.
            mask: Optional array indexed along the first axis. Selected
                entries take the freshly reset values; the rest keep the
                values from ``state``.

        Returns:
            ``(state, None, obs, info)``. Without a mask, ``obs``/``info``
            cover every environment. With a mask they contain only the
            entries picked by ``x[mask]``, while the returned state still
            covers the full batch.
        """
        base_rng = state.rng if seed is None else jax.random.PRNGKey(seed)
        rng, reset_rng = self._split_reset_keys(base_rng)
        # Same (possibly jitted) reset function that `initial` uses.
        reset_state = self.vanilla_reset_fn(rng=reset_rng)

        if mask is None:
            return MJXPlaygroundState(
                state=reset_state,
                rng=rng
            ), None, reset_state.obs, reset_state.info

        def where_reset(x, y) -> JaxArrayType:
            # Pad the mask with trailing singleton axes so it broadcasts
            # against leaves of any rank.
            mask_casted = jnp.reshape(
                mask, [mask.shape[0]] + [1] * (len(x.shape) - 1)
            )
            return jnp.where(mask_casted, x, y)

        def pick_reset(x):
            return x[mask]

        new_state = jax.tree.map(where_reset, reset_state, state.state)
        # tree.map treats a bare array as a single leaf, so this handles both
        # dict and plain-array observations.
        reset_obs = jax.tree.map(pick_reset, reset_state.obs)
        reset_info = jax.tree.map(pick_reset, reset_state.info)
        return MJXPlaygroundState(
            state=new_state,
            rng=rng
        ), None, reset_obs, reset_info

    def step(
        self,
        state : MJXPlaygroundState,
        action : JaxTreeOrArrayT
    ) -> Tuple[
        MJXPlaygroundState,
        JaxTreeOrArrayT,
        JaxArrayType,
        JaxArrayType,
        JaxArrayType,
        Dict[str, Any]
    ]:
        """Advance every environment by one step.

        Returns:
            ``(state, obs, reward, done, zeros, info)``. ``zeros`` is an
            all-False boolean array shaped like ``done``. The PRNG key in the
            returned state is carried over unchanged.
        """
        step_state = self.vanilla_step_fn(
            state.state,
            action
        )
        return (
            MJXPlaygroundState(
                state=step_state,
                rng=state.rng
            ),
            step_state.obs,
            step_state.reward,
            step_state.done,
            # NOTE(review): presumably the truncation flag of the FuncEnv
            # contract, which this adapter never sets — confirm.
            jnp.zeros_like(step_state.done, dtype=jnp.bool, device=self.device),
            step_state.info
        )

    # =========== Wrapper methods ==========
    def has_wrapper_attr(self, name: str) -> bool:
        """Return whether ``name`` exists on this adapter or the wrapped env."""
        return hasattr(self, name) or hasattr(self.env, name)

    def get_wrapper_attr(self, name: str) -> Any:
        """Look up ``name`` on this adapter first, then on the wrapped env.

        Raises:
            AttributeError: If neither object has the attribute.
        """
        if hasattr(self, name):
            return getattr(self, name)
        elif hasattr(self.env, name):
            return getattr(self.env, name)
        else:
            raise AttributeError(f"Attribute {name} not found in the environment.")

    def set_wrapper_attr(self, name: str, value: Any):
        """Assign ``value`` to an existing attribute on this adapter or the env.

        The adapter takes precedence; new attributes are never created.

        Raises:
            AttributeError: If neither object already has the attribute.
        """
        if hasattr(self, name):
            setattr(self, name, value)
        elif hasattr(self.env, name):
            setattr(self.env, name, value)
        else:
            raise AttributeError(f"Attribute {name} not found in the environment.")