unienv 0.0.1b1__py3-none-any.whl → 0.0.1b3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- unienv-0.0.1b3.dist-info/METADATA +74 -0
- unienv-0.0.1b3.dist-info/RECORD +92 -0
- {unienv-0.0.1b1.dist-info → unienv-0.0.1b3.dist-info}/licenses/LICENSE +1 -3
- unienv-0.0.1b3.dist-info/top_level.txt +2 -0
- unienv_data/base/__init__.py +0 -1
- unienv_data/base/common.py +95 -45
- unienv_data/base/storage.py +1 -0
- unienv_data/batches/__init__.py +2 -1
- unienv_data/batches/backend_compat.py +47 -1
- unienv_data/batches/combined_batch.py +2 -4
- unienv_data/{base → batches}/transformations.py +3 -2
- unienv_data/replay_buffer/replay_buffer.py +4 -0
- unienv_data/samplers/__init__.py +0 -1
- unienv_data/samplers/multiprocessing_sampler.py +26 -22
- unienv_data/samplers/step_sampler.py +9 -18
- unienv_data/storages/common.py +5 -0
- unienv_data/storages/hdf5.py +291 -20
- unienv_data/storages/pytorch.py +1 -0
- unienv_data/storages/transformation.py +191 -0
- unienv_data/transformations/image_compress.py +213 -0
- unienv_interface/backends/jax.py +4 -1
- unienv_interface/backends/numpy.py +4 -1
- unienv_interface/backends/pytorch.py +4 -1
- unienv_interface/env_base/__init__.py +1 -0
- unienv_interface/env_base/env.py +5 -0
- unienv_interface/env_base/funcenv.py +32 -1
- unienv_interface/env_base/funcenv_wrapper.py +2 -2
- unienv_interface/env_base/vec_env.py +474 -0
- unienv_interface/func_wrapper/__init__.py +2 -1
- unienv_interface/func_wrapper/frame_stack.py +150 -0
- unienv_interface/space/space_utils/__init__.py +1 -0
- unienv_interface/space/space_utils/batch_utils.py +83 -0
- unienv_interface/space/space_utils/construct_utils.py +216 -0
- unienv_interface/space/space_utils/serialization_utils.py +16 -1
- unienv_interface/space/spaces/__init__.py +3 -1
- unienv_interface/space/spaces/batched.py +90 -0
- unienv_interface/space/spaces/binary.py +0 -1
- unienv_interface/space/spaces/box.py +13 -24
- unienv_interface/space/spaces/text.py +1 -3
- unienv_interface/transformations/dict_transform.py +31 -5
- unienv_interface/utils/control_util.py +68 -0
- unienv_interface/utils/data_queue.py +184 -0
- unienv_interface/utils/stateclass.py +46 -0
- unienv_interface/utils/vec_util.py +15 -0
- unienv_interface/world/__init__.py +3 -1
- unienv_interface/world/combined_funcnode.py +336 -0
- unienv_interface/world/combined_node.py +232 -0
- unienv_interface/wrapper/backend_compat.py +2 -2
- unienv_interface/wrapper/frame_stack.py +19 -114
- unienv_interface/wrapper/video_record.py +11 -2
- unienv-0.0.1b1.dist-info/METADATA +0 -20
- unienv-0.0.1b1.dist-info/RECORD +0 -85
- unienv-0.0.1b1.dist-info/top_level.txt +0 -4
- unienv_data/samplers/slice_sampler.py +0 -266
- unienv_maniskill/__init__.py +0 -1
- unienv_maniskill/wrapper/maniskill_compat.py +0 -235
- unienv_mjxplayground/__init__.py +0 -1
- unienv_mjxplayground/wrapper/playground_compat.py +0 -256
- {unienv-0.0.1b1.dist-info → unienv-0.0.1b3.dist-info}/WHEEL +0 -0
|
@@ -1,256 +0,0 @@
|
|
|
1
|
-
from typing import Any, Optional, Tuple, Dict, Union, SupportsFloat, Sequence, Callable, Mapping
|
|
2
|
-
import jax
|
|
3
|
-
import jax.numpy as jnp
|
|
4
|
-
import numpy as np
|
|
5
|
-
|
|
6
|
-
from unienv_interface.env_base import FuncEnv
|
|
7
|
-
from unienv_interface.space import Space, BoxSpace, DictSpace
|
|
8
|
-
from unienv_interface.backends import ComputeBackend, BArrayType, BDeviceType, BDtypeType, BRNGType
|
|
9
|
-
from unienv_interface.backends.numpy import NumpyComputeBackend
|
|
10
|
-
from unienv_interface.backends.jax import JaxComputeBackend, JaxArrayType, JaxDeviceType, JaxDtypeType, JaxRNGType
|
|
11
|
-
from unienv_interface.space.space_utils import batch_utils as sbu
|
|
12
|
-
from unienv_interface.wrapper import backend_compat
|
|
13
|
-
|
|
14
|
-
import mujoco.mjx as mjx
|
|
15
|
-
from mujoco_playground import MjxEnv, State as MjxState
|
|
16
|
-
from mujoco_playground._src.wrapper import Wrapper, MadronaWrapper, BraxDomainRandomizationVmapWrapper
|
|
17
|
-
from brax.envs.wrappers import training as brax_training
|
|
18
|
-
from flax import struct
|
|
19
|
-
|
|
20
|
-
# Recursive size/axis description: a mapping from string keys to further
# descriptions, a single int, or None.
AxisMapSingleT = Union[Mapping[str, "AxisMapSingleT"], int, None]
# Either one description or a tuple of them.
AxisMapT = Union[AxisMapSingleT, Tuple[AxisMapSingleT, ...]]
# A single JAX array or a string-keyed dictionary of arbitrary values.
JaxTreeOrArrayT = Union[JaxArrayType, Dict[str, Any]]
# Callable taking an mjx.Model and returning a (model, axis description) pair.
RandomizationFnT = Callable[[mjx.Model], Tuple[mjx.Model, AxisMapSingleT]]
|
|
24
|
-
|
|
25
|
-
@struct.dataclass
class MJXPlaygroundState:
    """State threaded through the functional environment calls.

    Bundles the state object produced by the wrapped playground environment
    with the random key that is carried from one call to the next.
    """
    # State returned by the wrapped environment's reset/step.
    state : MjxState
    # Random key kept for later resets.
    rng : JaxRNGType
|
|
29
|
-
|
|
30
|
-
def is_mjx_env_vision(
    env: MjxEnv
) -> bool:
    """Report whether ``env`` renders through a Madrona ``BatchRenderer``.

    Returns ``False`` when ``madrona_mjx`` cannot be imported, when ``env``
    has no ``renderer`` attribute, or when that attribute is not a
    ``BatchRenderer`` instance.
    """
    try:
        from madrona_mjx.renderer import BatchRenderer  # pytype: disable=import-error
    except ImportError:
        # Without the optional renderer package no env can be a vision env.
        return False
    # A missing attribute falls back to None, which is never a BatchRenderer.
    return isinstance(getattr(env, "renderer", None), BatchRenderer)
|
|
40
|
-
|
|
41
|
-
def wrap_mjx_env(
    env: MjxEnv,
    batch_size: int,
    is_vision : bool = False,
    randomization_fn : Optional[
        RandomizationFnT
    ] = None,
) -> Wrapper:
    """Wrap a single playground environment in the matching batching wrapper.

    Args:
        env: The environment to wrap.
        batch_size: Number of parallel environments; only forwarded to
            ``MadronaWrapper``.
        is_vision: Select ``MadronaWrapper`` when true.
        randomization_fn: Optional domain-randomization callable.

    Returns:
        ``MadronaWrapper`` for vision environments, otherwise
        ``BraxDomainRandomizationVmapWrapper`` when a randomization function
        is supplied, otherwise a plain ``VmapWrapper``.
    """
    if is_vision:
        return MadronaWrapper(env, batch_size, randomization_fn)
    if randomization_fn is None:
        return brax_training.VmapWrapper(env)
    return BraxDomainRandomizationVmapWrapper(env, randomization_fn)
|
|
56
|
-
|
|
57
|
-
def space_from_size(
    size : AxisMapSingleT,
    device : Optional[JaxDeviceType] = None,
) -> Union[
    BoxSpace[JaxArrayType, JaxDeviceType, JaxDtypeType, JaxRNGType],
    DictSpace[JaxDeviceType, JaxDtypeType, JaxRNGType]
]:
    """Build an unbounded float32 space from a size description.

    Args:
        size: A mapping (converted recursively into a ``DictSpace``), an int
            (used as a one-element shape), or any other value, which is
            passed through unchanged as the ``BoxSpace`` shape.
        device: Device forwarded to every space that is created.

    Returns:
        A ``DictSpace`` for mapping input, otherwise a ``BoxSpace`` bounded
        by ``-inf`` and ``+inf``.
    """
    if not isinstance(size, Mapping):
        # Leaf case: a bare int becomes a 1-tuple shape.
        box_shape = (size,) if isinstance(size, int) else size
        return BoxSpace(
            JaxComputeBackend,
            -jnp.inf,
            jnp.inf,
            shape=box_shape,
            dtype=jnp.float32,
            device=device
        )

    # Mapping case: recurse into every entry, keeping the keys.
    children = {}
    for name, child_size in size.items():
        children[name] = space_from_size(child_size, device=device)
    return DictSpace(
        JaxComputeBackend,
        children,
        device=device
    )
|
|
82
|
-
|
|
83
|
-
class FromMJXPlaygroundEnv(
    FuncEnv[
        # NOTE(review): the methods below produce/consume MJXPlaygroundState,
        # not MjxState -- confirm which type FuncEnv expects in this slot.
        MjxState,
        None,
        JaxArrayType,
        None,
        JaxTreeOrArrayT,
        JaxTreeOrArrayT,
        None,
        JaxDeviceType,
        JaxDtypeType,
        JaxRNGType
    ]
):
    """Adapter exposing a batched MuJoCo Playground env through ``FuncEnv``.

    The single environment is wrapped with ``wrap_mjx_env`` and its
    ``reset``/``step`` functions are optionally jitted. State is passed
    explicitly as ``MJXPlaygroundState`` (wrapped env state plus a random
    key).
    """
    metadata = {
        "render_modes": ['rgb_array', 'human']
    }
    backend = JaxComputeBackend

    def __init__(
        self,
        single_env: MjxEnv,
        batch_size: int,
        randomization_fn : Optional[
            RandomizationFnT
        ] = None,
        device : Optional[JaxDeviceType] = None,
        jit : bool = True
    ) -> None:
        """Wrap ``single_env`` into a batch of ``batch_size`` environments.

        Args:
            single_env: The un-batched playground environment.
            batch_size: Number of parallel environments.
            randomization_fn: Optional domain-randomization callable
                forwarded to ``wrap_mjx_env``.
            device: Device used for jitting and for the created spaces.
            jit: When true, reset/step of the wrapped env are ``jax.jit``-ed.
        """
        self.single_env = single_env
        self.batch_size = batch_size
        # Assigned before the jax.jit calls below, which read self.device.
        self.device = device

        self.env = wrap_mjx_env(
            single_env,
            batch_size,
            is_vision=is_mjx_env_vision(single_env),
            randomization_fn=randomization_fn
        )
        if jit:
            self.vanilla_reset_fn = jax.jit(self.env.reset, device=self.device)
            self.vanilla_step_fn = jax.jit(self.env.step, device=self.device)
        else:
            self.vanilla_reset_fn = self.env.reset
            self.vanilla_step_fn = self.env.step

        self.action_space = sbu.batch_space(
            space_from_size(
                single_env.action_size,
                device=self.device
            ),
            batch_size
        )
        self.observation_space = sbu.batch_space(
            space_from_size(
                single_env.observation_size,
                device=self.device
            ),
            batch_size
        )

        self.context_space = None

    def initial(self, *, seed : Optional[int]) -> Tuple[
        MJXPlaygroundState,
        None,
        JaxTreeOrArrayT,
        Dict[str, Any]
    ]:
        """Create the first state of the batched environment.

        Returns:
            ``(state, None, observation, info)`` where observation and info
            come from the wrapped environment's reset.
        """
        # NOTE(review): seed is annotated Optional but is handed directly to
        # jax.random.PRNGKey -- confirm callers never pass None.
        rng = jax.random.PRNGKey(seed)
        rng, reset_rng = jax.random.split(rng)
        # One reset key per environment in the batch.
        reset_rng = jax.random.split(reset_rng, self.batch_size)
        raw_state = self.vanilla_reset_fn(rng=reset_rng)
        state = MJXPlaygroundState(
            state=raw_state,
            rng=rng
        )
        return state, None, raw_state.obs, raw_state.info

    def reset(
        self,
        state : MJXPlaygroundState,
        *,
        seed : Optional[int] = None,
        mask : Optional[JaxArrayType] = None
    ) -> Tuple[
        MJXPlaygroundState,
        None,
        JaxTreeOrArrayT,
        Dict[str, Any]
    ]:
        """Reset all environments, or only those selected by ``mask``.

        Args:
            state: Current state; its key is reused when ``seed`` is None.
            seed: Optional seed that replaces the carried random key.
            mask: Optional array indexed along the first (batch) axis. Where
                it is true the freshly reset state is taken, elsewhere the
                previous state is kept.

        Returns:
            ``(state, None, observation, info)``. With a mask, observation
            and info contain only the entries selected by ``mask``.
        """
        rng = state.rng if seed is None else jax.random.PRNGKey(seed)
        rng, reset_rng = jax.random.split(rng)
        reset_rng = jax.random.split(reset_rng, self.batch_size)
        # Same (possibly jitted) reset function that initial() uses.
        reset_state = self.vanilla_reset_fn(rng=reset_rng)

        if mask is None:
            return MJXPlaygroundState(
                state=reset_state,
                rng=rng
            ), None, reset_state.obs, reset_state.info

        def where_reset(x, y) -> JaxArrayType:
            # Reshape the mask to (batch, 1, ..., 1) so it broadcasts over
            # the trailing axes of the leaf.
            mask_casted = jnp.reshape(mask, [mask.shape[0]] + [1]*(len(x.shape)-1))
            return jnp.where(mask_casted, x, y)

        def pick_reset(x):
            return x[mask]

        new_state = jax.tree.map(
            where_reset, reset_state, state.state
        )
        # tree.map also handles a bare array observation (a single leaf).
        reset_obs = jax.tree.map(
            pick_reset, reset_state.obs
        )
        reset_info = jax.tree.map(
            pick_reset, reset_state.info
        )
        return MJXPlaygroundState(
            state=new_state,
            rng=rng
        ), None, reset_obs, reset_info

    def step(
        self,
        state : MJXPlaygroundState,
        action : JaxTreeOrArrayT
    ) -> Tuple[
        MJXPlaygroundState,
        JaxTreeOrArrayT,
        JaxArrayType,
        JaxArrayType,
        JaxArrayType,
        Dict[str, Any]
    ]:
        """Advance the batched environment by one step.

        Returns:
            ``(state, observation, reward, done, zeros, info)``. The fifth
            element is an all-false boolean array shaped like ``done``; the
            random key is carried over unchanged.
        """
        step_state = self.vanilla_step_fn(
            state.state,
            action
        )
        return (
            MJXPlaygroundState(
                state=step_state,
                rng=state.rng
            ),
            step_state.obs,
            step_state.reward,
            step_state.done,
            jnp.zeros_like(step_state.done, dtype=jnp.bool, device=self.device),
            step_state.info
        )

    # =========== Wrapper methods ==========
    def has_wrapper_attr(self, name: str) -> bool:
        """Return whether this adapter or the wrapped env has ``name``."""
        return hasattr(self, name) or hasattr(self.env, name)

    def get_wrapper_attr(self, name: str) -> Any:
        """Return ``name`` from this adapter, falling back to the wrapped env.

        Raises:
            AttributeError: If neither object has the attribute.
        """
        if hasattr(self, name):
            return getattr(self, name)
        elif hasattr(self.env, name):
            return getattr(self.env, name)
        else:
            raise AttributeError(f"Attribute {name} not found in the environment.")

    def set_wrapper_attr(self, name: str, value: Any):
        """Set an existing attribute on this adapter or on the wrapped env.

        Raises:
            AttributeError: If neither object already has the attribute.
        """
        if hasattr(self, name):
            setattr(self, name, value)
        elif hasattr(self.env, name):
            setattr(self.env, name, value)
        else:
            raise AttributeError(f"Attribute {name} not found in the environment.")
|
|
File without changes
|