unienv 0.0.1b1__py3-none-any.whl → 0.0.1b3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- unienv-0.0.1b3.dist-info/METADATA +74 -0
- unienv-0.0.1b3.dist-info/RECORD +92 -0
- {unienv-0.0.1b1.dist-info → unienv-0.0.1b3.dist-info}/licenses/LICENSE +1 -3
- unienv-0.0.1b3.dist-info/top_level.txt +2 -0
- unienv_data/base/__init__.py +0 -1
- unienv_data/base/common.py +95 -45
- unienv_data/base/storage.py +1 -0
- unienv_data/batches/__init__.py +2 -1
- unienv_data/batches/backend_compat.py +47 -1
- unienv_data/batches/combined_batch.py +2 -4
- unienv_data/{base → batches}/transformations.py +3 -2
- unienv_data/replay_buffer/replay_buffer.py +4 -0
- unienv_data/samplers/__init__.py +0 -1
- unienv_data/samplers/multiprocessing_sampler.py +26 -22
- unienv_data/samplers/step_sampler.py +9 -18
- unienv_data/storages/common.py +5 -0
- unienv_data/storages/hdf5.py +291 -20
- unienv_data/storages/pytorch.py +1 -0
- unienv_data/storages/transformation.py +191 -0
- unienv_data/transformations/image_compress.py +213 -0
- unienv_interface/backends/jax.py +4 -1
- unienv_interface/backends/numpy.py +4 -1
- unienv_interface/backends/pytorch.py +4 -1
- unienv_interface/env_base/__init__.py +1 -0
- unienv_interface/env_base/env.py +5 -0
- unienv_interface/env_base/funcenv.py +32 -1
- unienv_interface/env_base/funcenv_wrapper.py +2 -2
- unienv_interface/env_base/vec_env.py +474 -0
- unienv_interface/func_wrapper/__init__.py +2 -1
- unienv_interface/func_wrapper/frame_stack.py +150 -0
- unienv_interface/space/space_utils/__init__.py +1 -0
- unienv_interface/space/space_utils/batch_utils.py +83 -0
- unienv_interface/space/space_utils/construct_utils.py +216 -0
- unienv_interface/space/space_utils/serialization_utils.py +16 -1
- unienv_interface/space/spaces/__init__.py +3 -1
- unienv_interface/space/spaces/batched.py +90 -0
- unienv_interface/space/spaces/binary.py +0 -1
- unienv_interface/space/spaces/box.py +13 -24
- unienv_interface/space/spaces/text.py +1 -3
- unienv_interface/transformations/dict_transform.py +31 -5
- unienv_interface/utils/control_util.py +68 -0
- unienv_interface/utils/data_queue.py +184 -0
- unienv_interface/utils/stateclass.py +46 -0
- unienv_interface/utils/vec_util.py +15 -0
- unienv_interface/world/__init__.py +3 -1
- unienv_interface/world/combined_funcnode.py +336 -0
- unienv_interface/world/combined_node.py +232 -0
- unienv_interface/wrapper/backend_compat.py +2 -2
- unienv_interface/wrapper/frame_stack.py +19 -114
- unienv_interface/wrapper/video_record.py +11 -2
- unienv-0.0.1b1.dist-info/METADATA +0 -20
- unienv-0.0.1b1.dist-info/RECORD +0 -85
- unienv-0.0.1b1.dist-info/top_level.txt +0 -4
- unienv_data/samplers/slice_sampler.py +0 -266
- unienv_maniskill/__init__.py +0 -1
- unienv_maniskill/wrapper/maniskill_compat.py +0 -235
- unienv_mjxplayground/__init__.py +0 -1
- unienv_mjxplayground/wrapper/playground_compat.py +0 -256
- {unienv-0.0.1b1.dist-info → unienv-0.0.1b3.dist-info}/WHEEL +0 -0
|
@@ -1,266 +0,0 @@
|
|
|
1
|
-
from typing import Any, Tuple, Union, Optional, List, Dict, Type, TypeVar, Generic, Callable, Iterator
|
|
2
|
-
from unienv_data.base import BatchBase, BatchT, SamplerBatchT, SamplerArrayType, SamplerDeviceType, SamplerDtypeType, SamplerRNGType, BatchSampler
|
|
3
|
-
from unienv_interface.backends import ComputeBackend, BArrayType, BDeviceType, BDtypeType, BRNGType
|
|
4
|
-
from unienv_interface.space import Space, BoxSpace, BinarySpace, DictSpace
|
|
5
|
-
from unienv_interface.space.space_utils import batch_utils as sbu, flatten_utils as sfu
|
|
6
|
-
|
|
7
|
-
class SliceSampler(
    BatchSampler[
        BatchT, BArrayType, BDeviceType, BDtypeType, BRNGType,
        BatchT, BArrayType, BDeviceType, BDtypeType, BRNGType
    ]
):
    """
    Sampler that returns fixed-length temporal slices around each sampled index.

    For every sampled index ``i`` the window
    ``[i - prefetch_horizon, i + postfetch_horizon]`` is fetched, i.e.
    ``T = prefetch_horizon + postfetch_horizon + 1`` steps. Indices are clipped
    to ``[0, len(data) - 1]``, and sampled data has shape ``(B, T, ...)``.

    It is recommended to use SliceSampler as the final-layer sampler. It reshapes
    the data and adds a time dimension ``T`` next to the batch dimension, which
    makes many wrappers incompatible with its output.

    When ``get_episode_id_fn`` is given:
    - Steps in the window whose episode id differs from the center step are
      replaced by the first/last in-episode step of the window. This assumes
      in-episode steps are contiguous within the window.
    - The metadata gains ``slice_valid_mask`` (B, T) and ``episode_id`` (B,).
    """

    def __init__(
        self,
        data: BatchBase[BatchT, BArrayType, BDeviceType, BDtypeType, BRNGType],
        batch_size: int,
        prefetch_horizon: int = 0,
        postfetch_horizon: int = 0,
        get_episode_id_fn: Optional[Callable[[BatchT], BArrayType]] = None,
        seed: Optional[int] = None,
        device: Optional[BDeviceType] = None,
    ):
        """
        Args:
            data: Underlying batch to sample slices from.
            batch_size: Number of slices ``B`` returned per sample; must be positive.
            prefetch_horizon: Number of steps before each index included in a slice.
            postfetch_horizon: Number of steps after each index included in a slice.
            get_episode_id_fn: Optional function mapping an unflattened ``(B, T, ...)``
                batch to a ``(B, T)`` array of episode ids. It enables
                episode-boundary filtering and the extra metadata entries.
            seed: Seed for ``self.data_rng``.
            device: Optional target device for sampled data. Defaults to
                ``data.device``.
        """
        assert batch_size > 0, "Batch size must be a positive integer"
        assert prefetch_horizon >= 0, "Prefetch horizon must be a non-negative integer"
        assert postfetch_horizon >= 0, "Postfetch horizon must be a non-negative integer"
        assert prefetch_horizon > 0 or postfetch_horizon > 0, "At least one of prefetch_horizon and postfetch_horizon must be greater than 0, otherwise you can use `StepSampler`"
        self.data = data
        self.batch_size = batch_size
        self.prefetch_horizon = prefetch_horizon
        self.postfetch_horizon = postfetch_horizon
        self._device = device

        # Number of time steps T in every slice.
        horizon = self.prefetch_horizon + self.postfetch_horizon + 1

        self.single_slice_space = sbu.batch_space(self.data.single_space, horizon)
        self.sampled_space = sbu.batch_space(self.single_slice_space, batch_size)
        # Flattened view keeps the (B, T) leading dims and flattens the rest to D.
        self.sampled_space_flat = sfu.flatten_space(self.sampled_space, start_dim=2)

        if self.data.single_metadata_space is not None:
            self.sampled_metadata_space = sbu.batch_space(
                self.data.single_metadata_space,
                horizon
            )
            self.sampled_metadata_space = sbu.batch_space(
                self.sampled_metadata_space,
                batch_size
            )
        else:
            self.sampled_metadata_space = None

        if get_episode_id_fn is not None:
            if self.sampled_metadata_space is None:
                self.sampled_metadata_space = DictSpace(
                    self.backend,
                    {},
                    device=self.device
                )

            self.sampled_metadata_space['slice_valid_mask'] = BinarySpace(
                self.backend,
                shape=(self.batch_size, horizon),
                dtype=self.backend.default_boolean_dtype,
                device=self.device
            )
            self.sampled_metadata_space['episode_id'] = BoxSpace(
                self.backend,
                low=-2_147_483_647,
                high=2_147_483_647,
                shape=(self.batch_size, ),
                dtype=self.backend.default_integer_dtype,
                device=self.device
            )

        if device is not None:
            self.single_slice_space = self.single_slice_space.to(device=device)
            self.sampled_space = self.sampled_space.to(device=device)

        self.data_rng = self.backend.random.random_number_generator(
            seed,
            device=data.device
        )

        self.get_episode_id_fn = get_episode_id_fn
        self._build_epid_cache()

    @property
    def backend(self) -> ComputeBackend[BArrayType, BDeviceType, BDtypeType, BRNGType]:
        """Compute backend of the underlying data."""
        return self.data.backend

    @property
    def device(self) -> Optional[BDeviceType]:
        """Target device: the explicit ``device`` argument, else the data's device."""
        return self._device or self.data.device

    def _build_epid_cache(self):
        """
        Cache the flat feature column that holds the episode id, if there is one.

        A fake flat batch is built in which every feature column holds its own
        column index. It is unflattened and passed through ``get_episode_id_fn``.
        If every returned id is the same value, that value is the flat column of
        the episode id. It is stored in ``self._epid_flatidx`` so later calls can
        slice it directly instead of unflattening the whole batch. Otherwise, and
        when no episode-id function is set, the cache is ``None``.
        """
        if self.get_episode_id_fn is None:
            self._epid_flatidx = None
            # Nothing to cache; falling through would call ``None`` below.
            return

        flat_data = self.backend.broadcast_to(
            self.backend.arange(
                self.sampled_space_flat.shape[-1], device=self.sampled_space_flat.device
            )[None, None, :],  # (1, 1, D)
            self.sampled_space_flat.shape
        )

        dat = sfu.unflatten_data(self.sampled_space, flat_data, start_dim=2)
        episode_ids = self.get_episode_id_fn(dat)
        del dat

        epid_flatidx = int(episode_ids[0, 0])
        if self.backend.all(episode_ids == epid_flatidx):
            self._epid_flatidx = epid_flatidx
        else:
            self._epid_flatidx = None

    def expand_index(self, index: BArrayType) -> BArrayType:
        """
        Expand (B,) sample indices into (B, T) slice indices on the data's device.

        Indices falling outside the data are clipped to ``[0, len(data) - 1]``.
        """
        index_shifts = self.backend.arange(  # (T, )
            -self.prefetch_horizon, self.postfetch_horizon + 1, dtype=index.dtype, device=self.data.device
        )
        index = index[:, None] + index_shifts[None, :]  # (B, T)
        index = self.backend.clip(index, 0, len(self.data) - 1)
        return index

    def _get_unfiltered_flat_with_metadata(self, idx: BArrayType) -> Tuple[BArrayType, Optional[Dict[str, Any]]]:
        """Fetch flattened (B, T, D) slices and (B, T, ...) metadata without episode filtering."""
        B = idx.shape[0]
        indices = self.expand_index(idx)  # (B, T)
        flat_idx = self.backend.reshape(indices, (-1,))  # (B * T, )
        dat_flat, metadata = self.data.get_flattened_at_with_metadata(flat_idx)  # (B * T, D)
        metadata_reshaped = self.backend.map_fn_over_arrays(
            metadata,
            lambda x: self.backend.reshape(x, (*indices.shape, *x.shape[1:]))
        ) if metadata is not None else None
        assert dat_flat.shape[0] == (self.prefetch_horizon + self.postfetch_horizon + 1) * B

        dat = self.backend.reshape(dat_flat, (*indices.shape, -1))  # (B, T, D)
        return dat, metadata_reshaped

    def unfiltered_to_filtered_flat(self, flat_dat: BArrayType) -> Tuple[
        BArrayType,  # Data (B, T, D)
        BArrayType,  # validity mask (B, T)
        Optional[BArrayType]  # episode id (B)
    ]:
        """
        Apply episode-boundary filtering to flat (B, T, D) slices.

        Returns ``(data, valid_mask, episode_id)``. Without ``get_episode_id_fn``
        the data is returned unfiltered, moved to the sampler's device if one was
        given, and both mask and episode id are ``None``.
        """
        B = flat_dat.shape[0]
        device = self._device or self.backend.device(flat_dat)
        if self.get_episode_id_fn is not None:
            # fetch episode ids
            if self._epid_flatidx is None:
                if self._device is not None:
                    new_flat_dat = self.backend.to_device(flat_dat, device)
                dat = sfu.unflatten_data(self.sampled_space, flat_dat, start_dim=2)  # (B, T, D)
                episode_ids = self.get_episode_id_fn(dat)
                if self._device is not None:
                    episode_ids = self.backend.to_device(episode_ids, device)
                    flat_dat = new_flat_dat
                del dat
            else:
                # Fast path: the episode id is a single flat feature column.
                episode_ids = flat_dat[:, :, self._epid_flatidx]
                if self._device is not None:
                    flat_dat = self.backend.to_device(flat_dat, device)
                    episode_ids = self.backend.to_device(episode_ids, device)

            assert self.backend.is_backendarray(episode_ids)
            assert episode_ids.shape == (B, self.prefetch_horizon + self.postfetch_horizon + 1)
            episode_id_at_step = episode_ids[:, self.prefetch_horizon]
            episode_id_eq = episode_ids == episode_id_at_step[:, None]

            zero_to_B = self.backend.arange(
                B,
                device=device
            )
            if self.prefetch_horizon > 0:
                # Assuming same-episode steps are contiguous, the first in-episode
                # prefetch step sits at index (prefetch_horizon - count).
                num_eq_prefetch = self.backend.sum(episode_id_eq[:, :self.prefetch_horizon], axis=1)
                fill_idx_prefetch = self.prefetch_horizon - num_eq_prefetch
                fill_value_prefetch = flat_dat[
                    zero_to_B,
                    fill_idx_prefetch
                ]  # (B, D)
                fill_value_prefetch = fill_value_prefetch[:, None, :]  # (B, 1, D)
                flat_dat_prefetch = self.backend.where(
                    episode_id_eq[:, :self.prefetch_horizon, None],
                    flat_dat[:, :self.prefetch_horizon],
                    fill_value_prefetch
                )
            else:
                flat_dat_prefetch = None

            if self.postfetch_horizon > 0:
                # Likewise, the last in-episode step sits at (prefetch_horizon + count).
                num_eq_postfetch = self.backend.sum(episode_id_eq[:, -self.postfetch_horizon:], axis=1)
                fill_idx_postfetch = self.prefetch_horizon + num_eq_postfetch
                fill_value_postfetch = flat_dat[
                    zero_to_B,
                    fill_idx_postfetch
                ]
                fill_value_postfetch = fill_value_postfetch[:, None, :]  # (B, 1, D)
                flat_dat_postfetch = self.backend.where(
                    episode_id_eq[:, self.prefetch_horizon:, None],
                    flat_dat[:, self.prefetch_horizon:],
                    fill_value_postfetch
                )
            else:
                flat_dat_postfetch = flat_dat[:, self.prefetch_horizon:]

            if flat_dat_prefetch is None:
                flat_dat = flat_dat_postfetch
            else:
                flat_dat = self.backend.concatenate([
                    flat_dat_prefetch,
                    flat_dat_postfetch
                ], axis=1)  # (B, T, D)
        else:
            # Honor the requested device even without episode filtering, matching
            # the filtered branch and the device ``sampled_space`` was moved to.
            if self._device is not None:
                flat_dat = self.backend.to_device(flat_dat, device)
            episode_id_eq = None
            episode_id_at_step = None
        return flat_dat, episode_id_eq, episode_id_at_step

    def get_flat_at(self, idx: BArrayType):
        """Return filtered flat (B, T, D) slices for the given (B,) indices."""
        return self.get_flat_at_with_metadata(idx)[0]

    def get_flat_at_with_metadata(self, idx: BArrayType) -> Tuple[
        BArrayType,
        Optional[Dict[str, Any]]
    ]:
        """
        Return filtered flat slices plus metadata.

        With episode filtering enabled, the metadata dict (created if absent)
        gains ``slice_valid_mask`` and ``episode_id``.
        """
        unfilt_flat_dat, metadata = self._get_unfiltered_flat_with_metadata(idx)
        dat, episode_id_eq, episode_id_at_step = self.unfiltered_to_filtered_flat(unfilt_flat_dat)
        if episode_id_at_step is not None:
            if metadata is None:
                metadata = {}
            metadata.update({
                "slice_valid_mask": episode_id_eq,
                "episode_id": episode_id_at_step
            })

        return dat, metadata

    def get_at(self, idx: BArrayType) -> BatchT:
        """Return unflattened slices (structured like ``sampled_space``) for (B,) indices."""
        return self.get_at_with_metadata(idx)[0]

    def get_at_with_metadata(self, idx: BArrayType) -> Tuple[
        BatchT,
        Optional[Dict[str, Any]]
    ]:
        """Return unflattened slices plus metadata for the given (B,) indices."""
        flat_dat, metadata = self.get_flat_at_with_metadata(idx)
        dat = sfu.unflatten_data(self.sampled_space, flat_dat, start_dim=2)
        return dat, metadata
|
unienv_maniskill/__init__.py
DELETED
|
@@ -1 +0,0 @@
|
|
|
1
|
-
from .wrapper.maniskill_compat import FromManiSkillEnv
|
|
@@ -1,235 +0,0 @@
|
|
|
1
|
-
from typing import Any, Optional, Tuple, Dict, Union, SupportsFloat, Sequence
|
|
2
|
-
from mani_skill.envs.sapien_env import BaseEnv as ManiSkillBaseEnv
|
|
3
|
-
import gymnasium as gym
|
|
4
|
-
import torch
|
|
5
|
-
import numpy as np
|
|
6
|
-
|
|
7
|
-
from unienv_interface.env_base.env import Env
|
|
8
|
-
from unienv_interface.space import Space
|
|
9
|
-
from unienv_interface.backends import ComputeBackend, BArrayType, BDeviceType, BDtypeType, BRNGType
|
|
10
|
-
from unienv_interface.backends.numpy import NumpyComputeBackend
|
|
11
|
-
from unienv_interface.backends.pytorch import PyTorchComputeBackend
|
|
12
|
-
from unienv_interface.space.space_utils import batch_utils as space_batch_utils, gym_utils as space_gym_utils
|
|
13
|
-
from unienv_interface.wrapper import backend_compat, gym_compat
|
|
14
|
-
|
|
15
|
-
# ManiSkill outputs are handled as either PyTorch or NumPy values, so each of
# these aliases is the union of the corresponding PyTorch and NumPy backend types.
MANISKILL_ENV_ARRAYTYPE = Union[
    PyTorchComputeBackend.ARRAY_TYPE,
    NumpyComputeBackend.ARRAY_TYPE,
]
MANISKILL_ENV_DEVICETYPE = Union[
    PyTorchComputeBackend.DEVICE_TYPE,
    NumpyComputeBackend.DEVICE_TYPE,
]
MANISKILL_ENV_DTYPET = Union[
    PyTorchComputeBackend.DTYPE_TYPE,
    NumpyComputeBackend.DTYPE_TYPE,
]
MANISKILL_ENV_RNGTYPE = Union[
    PyTorchComputeBackend.RNG_TYPE,
    NumpyComputeBackend.RNG_TYPE,
]
|
|
19
|
-
|
|
20
|
-
def convert_maniskill_dict_to_backend(
    dict: Dict[str, MANISKILL_ENV_ARRAYTYPE],
    backend: ComputeBackend[BArrayType, BDeviceType, BDtypeType, BRNGType],
    device: Optional[BDeviceType] = None,
):
    """
    Convert every value of a (possibly nested) ManiSkill dict to ``backend``.

    Nested dicts (e.g. nested observation or info dicts) are converted
    recursively, so their leaf arrays also end up on ``backend``/``device``.
    Values that are neither torch tensors nor numpy arrays are kept unchanged.

    The first parameter is named ``dict`` (shadowing the builtin) for backward
    compatibility with keyword callers; the builtin is not used in this body.

    Args:
        dict: Mapping of names to ManiSkill arrays or nested mappings.
        backend: Target compute backend.
        device: Optional target device.

    Returns:
        A new dict with the same keys and converted values.
    """
    # Dispatch through ``convert_maniskill_to_backend`` so nested dicts recurse
    # instead of being passed through unconverted.
    return {
        key: convert_maniskill_to_backend(value, backend, device)
        for key, value in dict.items()
    }
|
|
29
|
-
|
|
30
|
-
def convert_maniskill_array_to_backend(
    array: MANISKILL_ENV_ARRAYTYPE,
    backend: ComputeBackend[BArrayType, BDeviceType, BDtypeType, BRNGType],
    device: Optional[BDeviceType] = None,
):
    """
    Convert a single ManiSkill output value into ``backend``'s array type.

    Args:
        array: A ``torch.Tensor`` or ``np.ndarray``; any other value is returned
            unchanged.
        backend: Target compute backend.
        device: Optional target device. It is applied whether or not a
            cross-backend conversion was needed.

    Returns:
        The converted (and, if requested, device-moved) array, or ``array``
        itself when it is not a tensor/ndarray.
    """
    if isinstance(array, torch.Tensor):
        source_backend = PyTorchComputeBackend
    elif isinstance(array, np.ndarray):
        source_backend = NumpyComputeBackend
    else:
        return array

    if backend is not source_backend:
        array = backend.from_other_backend(
            array,
            source_backend
        )
    # ``from_other_backend`` is called without a device, so move explicitly;
    # otherwise the requested device would be silently dropped on this path.
    if device is not None:
        array = backend.to_device(array, device)
    return array
|
|
52
|
-
|
|
53
|
-
def convert_maniskill_to_backend(
    data: Any,
    backend: ComputeBackend[BArrayType, BDeviceType, BDtypeType, BRNGType],
    device: Optional[BDeviceType] = None,
):
    """
    Convert ManiSkill output (a dict or a single array) to ``backend``.

    Dicts are routed to ``convert_maniskill_dict_to_backend``; everything else
    goes to ``convert_maniskill_array_to_backend``.
    """
    converter = (
        convert_maniskill_dict_to_backend
        if isinstance(data, dict)
        else convert_maniskill_array_to_backend
    )
    return converter(data, backend, device)
|
|
62
|
-
|
|
63
|
-
class FromManiSkillEnv(
    Env[
        PyTorchComputeBackend.ARRAY_TYPE,
        None,
        Any,
        Any,
        PyTorchComputeBackend.ARRAY_TYPE,
        PyTorchComputeBackend.DEVICE_TYPE,
        PyTorchComputeBackend.DTYPE_TYPE,
        PyTorchComputeBackend.RNG_TYPE
    ]
):
    """
    Adapter exposing a ManiSkill ``BaseEnv`` through the unienv ``Env`` interface.

    Arrays coming out of the wrapped environment are converted to the PyTorch
    compute backend on the environment's device. The context is always ``None``.
    """

    def __init__(
        self,
        env: ManiSkillBaseEnv,
    ) -> None:
        self.env = env
        self.backend = PyTorchComputeBackend
        self.device = env.get_wrapper_attr("device")
        self.batch_size = env.get_wrapper_attr("num_envs")

        self.action_space = space_gym_utils.from_gym_space(
            env.action_space, self.backend, device=self.device
        )
        if env.get_wrapper_attr("num_envs") <= 1:
            # ManiSkill leaves the action space unbatched when num_envs is 1 while
            # still batching the observation space, so add the batch dim here.
            self.action_space = space_batch_utils.batch_space(self.action_space, 1)

        self.observation_space = space_gym_utils.from_gym_space(
            env.observation_space, self.backend, device=self.device
        )
        self.context_space = None
        self.rng = torch.Generator(device=self.device)

    @property
    def metadata(self) -> Dict[str, Any]:
        """Metadata dict of the wrapped environment."""
        return self.env.metadata

    def get_render_camera_params(
        self
    ) -> Dict[str, Dict[str, torch.Tensor]]:
        """Return ``get_params()`` of each human-render camera, keyed by camera name."""
        cameras = getattr(self.env, "_human_render_cameras", {})
        return {name: cam.get_params() for name, cam in cameras.items()}

    @property
    def render_mode(self) -> Optional[str]:
        """Render mode of the wrapped environment."""
        return self.env.render_mode

    @property
    def render_fps(self) -> Optional[int]:
        """Render rate, taken from the wrapped env's ``control_freq``."""
        return self.env.get_wrapper_attr("control_freq")

    def reset(
        self,
        *args,
        mask: Optional[np.ndarray] = None,
        seed: Optional[int] = None,
        **kwargs
    ) -> Tuple[
        None,
        Any,
        Dict[str, Any]
    ]:
        """
        Reset all sub-environments, or only those selected by ``mask``.

        With a mask, the selected indices are forwarded as ``options["env_idx"]``
        and only the masked observations are returned.
        """
        # NOTE(review): annotated as np.ndarray, but the mask is passed to
        # torch.nonzero, which expects a torch.Tensor; confirm what callers pass.
        if mask is None:
            options = None
        else:
            options = {"env_idx": torch.nonzero(mask).flatten()}

        obs, info = self.env.reset(*args, seed=seed, options=options, **kwargs)

        # from_gym_data is not used because the observations may not be numpy arrays.
        obs = convert_maniskill_to_backend(obs, self.backend, self.device)
        if mask is not None:
            obs = space_batch_utils.get_at(self.observation_space, obs, mask)

        return None, obs, info

    def step(
        self,
        action: Any
    ) -> Tuple[
        Any,
        SupportsFloat,
        bool,
        bool,
        Dict[str, Any]
    ]:
        """Step the wrapped env and convert every output to the PyTorch backend."""
        obs, rew, terminated, truncated, info = self.env.step(action)
        backend, device = self.backend, self.device
        # Tuple elements are evaluated left to right: obs, rew, terminated,
        # truncated, then info.
        return (
            convert_maniskill_to_backend(obs, backend, device),
            convert_maniskill_array_to_backend(rew, backend, device),
            convert_maniskill_array_to_backend(terminated, backend, device),
            convert_maniskill_array_to_backend(truncated, backend, device),
            convert_maniskill_dict_to_backend(info, backend, device),
        )

    def render(
        self
    ) -> Optional[
        PyTorchComputeBackend.ARRAY_TYPE
    ]:
        """
        Render the wrapped env.

        Returns the frame with its leading dim squeezed when the result is a
        tensor with a leading dim of size 1; returns ``None`` otherwise.
        """
        frame = self.env.render()
        if frame is None:
            return None

        converted = convert_maniskill_to_backend(frame, self.backend, self.device)
        single = PyTorchComputeBackend.is_backendarray(converted) and converted.shape[0] == 1
        return converted.squeeze(0) if single else None

    # =========== Wrapper methods ==========
    def has_wrapper_attr(self, name: str) -> bool:
        """Whether ``name`` exists on this adapter or on the wrapped env."""
        return any(hasattr(target, name) for target in (self, self.env))

    def get_wrapper_attr(self, name: str) -> Any:
        """Look up ``name`` on this adapter first, then on the wrapped env."""
        for target in (self, self.env):
            if hasattr(target, name):
                return getattr(target, name)
        raise AttributeError(f"Attribute {name} not found in the environment.")

    def set_wrapper_attr(self, name: str, value: Any):
        """Set ``name`` on whichever of adapter / wrapped env already defines it."""
        for target in (self, self.env):
            if hasattr(target, name):
                setattr(target, name, value)
                return
        raise AttributeError(f"Attribute {name} not found in the environment.")
|
unienv_mjxplayground/__init__.py
DELETED
|
@@ -1 +0,0 @@
|
|
|
1
|
-
from .wrapper.playground_compat import FromMJXPlaygroundEnv
|