unienv 0.0.1b1__py3-none-any.whl → 0.0.1b3__py3-none-any.whl
- unienv-0.0.1b3.dist-info/METADATA +74 -0
- unienv-0.0.1b3.dist-info/RECORD +92 -0
- {unienv-0.0.1b1.dist-info → unienv-0.0.1b3.dist-info}/licenses/LICENSE +1 -3
- unienv-0.0.1b3.dist-info/top_level.txt +2 -0
- unienv_data/base/__init__.py +0 -1
- unienv_data/base/common.py +95 -45
- unienv_data/base/storage.py +1 -0
- unienv_data/batches/__init__.py +2 -1
- unienv_data/batches/backend_compat.py +47 -1
- unienv_data/batches/combined_batch.py +2 -4
- unienv_data/{base → batches}/transformations.py +3 -2
- unienv_data/replay_buffer/replay_buffer.py +4 -0
- unienv_data/samplers/__init__.py +0 -1
- unienv_data/samplers/multiprocessing_sampler.py +26 -22
- unienv_data/samplers/step_sampler.py +9 -18
- unienv_data/storages/common.py +5 -0
- unienv_data/storages/hdf5.py +291 -20
- unienv_data/storages/pytorch.py +1 -0
- unienv_data/storages/transformation.py +191 -0
- unienv_data/transformations/image_compress.py +213 -0
- unienv_interface/backends/jax.py +4 -1
- unienv_interface/backends/numpy.py +4 -1
- unienv_interface/backends/pytorch.py +4 -1
- unienv_interface/env_base/__init__.py +1 -0
- unienv_interface/env_base/env.py +5 -0
- unienv_interface/env_base/funcenv.py +32 -1
- unienv_interface/env_base/funcenv_wrapper.py +2 -2
- unienv_interface/env_base/vec_env.py +474 -0
- unienv_interface/func_wrapper/__init__.py +2 -1
- unienv_interface/func_wrapper/frame_stack.py +150 -0
- unienv_interface/space/space_utils/__init__.py +1 -0
- unienv_interface/space/space_utils/batch_utils.py +83 -0
- unienv_interface/space/space_utils/construct_utils.py +216 -0
- unienv_interface/space/space_utils/serialization_utils.py +16 -1
- unienv_interface/space/spaces/__init__.py +3 -1
- unienv_interface/space/spaces/batched.py +90 -0
- unienv_interface/space/spaces/binary.py +0 -1
- unienv_interface/space/spaces/box.py +13 -24
- unienv_interface/space/spaces/text.py +1 -3
- unienv_interface/transformations/dict_transform.py +31 -5
- unienv_interface/utils/control_util.py +68 -0
- unienv_interface/utils/data_queue.py +184 -0
- unienv_interface/utils/stateclass.py +46 -0
- unienv_interface/utils/vec_util.py +15 -0
- unienv_interface/world/__init__.py +3 -1
- unienv_interface/world/combined_funcnode.py +336 -0
- unienv_interface/world/combined_node.py +232 -0
- unienv_interface/wrapper/backend_compat.py +2 -2
- unienv_interface/wrapper/frame_stack.py +19 -114
- unienv_interface/wrapper/video_record.py +11 -2
- unienv-0.0.1b1.dist-info/METADATA +0 -20
- unienv-0.0.1b1.dist-info/RECORD +0 -85
- unienv-0.0.1b1.dist-info/top_level.txt +0 -4
- unienv_data/samplers/slice_sampler.py +0 -266
- unienv_maniskill/__init__.py +0 -1
- unienv_maniskill/wrapper/maniskill_compat.py +0 -235
- unienv_mjxplayground/__init__.py +0 -1
- unienv_mjxplayground/wrapper/playground_compat.py +0 -256
- {unienv-0.0.1b1.dist-info → unienv-0.0.1b3.dist-info}/WHEEL +0 -0
unienv_interface/env_base/vec_env.py
@@ -0,0 +1,474 @@
+from .env import Env, ContextType, ObsType, ActType, RenderFrame
+from unienv_interface.backends import ComputeBackend, BArrayType, BDeviceType, BDtypeType, BRNGType
+from unienv_interface.space import Space, batch_utils as sbu
+from unienv_interface.utils.vec_util import MultiProcessFn
+from typing import Any, Dict, Generic, Literal, Optional, SupportsFloat, Tuple, TypeVar, Callable, Iterable, Mapping, Sequence, List
+import numpy as np
+import multiprocessing as mp
+from multiprocessing.connection import Connection as MPConnection
+from multiprocessing.context import BaseContext as MPContext
+from queue import Empty as QueueEmpty
+
+def data_stack(
+    data : Any,
+    backend : ComputeBackend[BArrayType, BDeviceType, BDtypeType, BRNGType],
+    device : Optional[BDeviceType] = None,
+):
+    if isinstance(data, Mapping):
+        data = {
+            key: data_stack(value, backend)
+            for key, value in data.items()
+        }
+    if isinstance(data, Sequence):
+        if len(data) == 0:
+            return data
+        if backend.is_backendarray(data[0]):
+            return backend.stack(data, axis=0)
+        elif isinstance(data[0], np.ndarray):
+            return np.stack(data, axis=0)
+        elif isinstance(data[0], (int, float, bool)):
+            dtype = (
+                backend.default_boolean_dtype if isinstance(data[0], bool) else
+                backend.default_floating_dtype if isinstance(data[0], float) else
+                backend.default_integer_dtype
+            )
+            return backend.asarray(data, dtype=dtype, device=device)
+        elif isinstance(data[0], SupportsFloat):
+            return backend.asarray(
+                [float(d) for d in data],
+                dtype=backend.default_floating_dtype,
+                device=device
+            )
+        else:
+            data = [
+                data_stack(value, backend)
+                for value in data
+            ]
+    return data
+
+class SyncVecEnv(Env[
+    BArrayType, ContextType, ObsType, ActType, RenderFrame, BDeviceType, BDtypeType, BRNGType
+]):
+    def __init__(
+        self,
+        env_fn : Iterable[Callable[[], Env[
+            BArrayType, ContextType, ObsType, ActType, RenderFrame, BDeviceType, BDtypeType, BRNGType
+        ]]],
+        seed : Optional[int] = None,
+    ):
+        self.envs = [fn() for fn in env_fn]
+        assert len(self.envs) > 1, "env_fns should have more than 1 env function"
+        assert all(env.batch_size is None for env in self.envs), "All envs must be non-batched envs"
+
+        # check all envs have the same backend, device, action_space, observation_space, context_space
+        first_env = self.envs[0]
+        for env in self.envs[1:]:
+            assert env.backend == first_env.backend, "All envs must have the same backend"
+            assert env.device == first_env.device, "All envs must have the same device"
+            # assert env.action_space == first_env.action_space, "All envs must have the same action_space"
+            # assert env.observation_space == first_env.observation_space, "All envs must have the same observation_space"
+            # assert env.context_space == first_env.context_space, "All envs must have the same context_space"
+
+        self.action_space = sbu.batch_differing_spaces(
+            [env.action_space for env in self.envs],
+            device=first_env.device,
+        )
+        self.observation_space = sbu.batch_differing_spaces(
+            [env.observation_space for env in self.envs],
+            device=first_env.device,
+        )
+        self.context_space = None if first_env.context_space is None else sbu.batch_differing_spaces(
+            [env.context_space for env in self.envs],
+            device=first_env.device,
+        )
+        self.rng = self.backend.random.random_number_generator(
+            seed,
+            device=first_env.device,
+        )
+
+    @property
+    def metadata(self) -> Dict[str, Any]:
+        return self.envs[0].metadata
+
+    @property
+    def render_mode(self) -> Optional[str]:
+        return self.envs[0].render_mode
+
+    @property
+    def render_fps(self) -> Optional[int]:
+        return self.envs[0].render_fps
+
+    @property
+    def backend(self) -> ComputeBackend[BArrayType, BDeviceType, BDtypeType, BRNGType]:
+        return self.envs[0].backend
+
+    @property
+    def device(self) -> Optional[BDeviceType]:
+        return self.envs[0].device
+
+    @property
+    def batch_size(self) -> Optional[int]:
+        return len(self.envs)
+
+    def reset(
+        self,
+        *args,
+        mask : Optional[BArrayType] = None,
+        seed : Optional[int] = None,
+        **kwargs
+    ) -> Tuple[ContextType, ObsType, Dict[str, Any]]:
+        if seed is not None:
+            self.rng = self.backend.random.random_number_generator(
+                seed,
+                device=self.device,
+            )
+
+        all_contexts = []
+        all_obs = []
+        all_infos = []
+        for i, env in enumerate(self.envs):
+            env_reset = True if mask is None else bool(mask[i])
+            if env_reset:
+                context, obs, info = env.reset(*args, **kwargs)
+                all_contexts.append(context)
+                all_obs.append(obs)
+                all_infos.append(info)
+
+        if self.context_space is not None:
+            all_contexts = sbu.concatenate(
+                self.context_space,
+                all_contexts,
+                axis=0,
+            )
+        else:
+            all_contexts = None
+
+        all_obs = sbu.concatenate(
+            self.observation_space,
+            all_obs,
+            axis=0,
+        )
+        all_infos = data_stack(
+            all_infos,
+            self.backend,
+            self.device
+        )
+        return all_contexts, all_obs, all_infos
+
+    def step(
+        self,
+        action : ActType
+    ) -> Tuple[ObsType, BArrayType, BArrayType, BArrayType, Dict[str, Any]]:
+        actions = sbu.iterate(
+            self.action_space,
+            action
+        )
+        all_obs = []
+        all_rewards = []
+        all_terminated = []
+        all_truncated = []
+        all_infos = []
+        for i, env in enumerate(self.envs):
+            obs, reward, terminated, truncated, info = env.step(next(actions))
+            all_obs.append(obs)
+            all_rewards.append(reward)
+            all_terminated.append(terminated)
+            all_truncated.append(truncated)
+            all_infos.append(info)
+        all_obs = sbu.concatenate(
+            self.observation_space,
+            all_obs,
+            axis=0,
+        )
+        all_rewards = self.backend.asarray(
+            [float(r) for r in all_rewards],
+            dtype=self.backend.default_floating_dtype,
+            device=self.device,
+        )
+        all_terminated = self.backend.asarray(
+            [bool(t) for t in all_terminated],
+            dtype=self.backend.default_boolean_dtype,
+            device=self.device,
+        )
+        all_truncated = self.backend.asarray(
+            [bool(t) for t in all_truncated],
+            dtype=self.backend.default_boolean_dtype,
+            device=self.device,
+        )
+        all_infos = data_stack(
+            all_infos,
+            self.backend,
+            self.device
+        )
+        return all_obs, all_rewards, all_terminated, all_truncated, all_infos
+
+    def render(self) -> Sequence[RenderFrame] | None:
+        frames = []
+        for env in self.envs:
+            frame = env.render()
+            if frame is not None:
+                frames.append(frame)
+        return frames if len(frames) > 0 else None
+
+    def close(self):
+        for env in self.envs:
+            env.close()
+        self.envs = []
+
+def _async_worker_fn(
+    index : int,
+    env_fn : Callable[[], Env[
+        BArrayType, ContextType, ObsType, ActType, RenderFrame, BDeviceType, BDtypeType, BRNGType
+    ]],
+    pipe : MPConnection,
+    parent_pipe : MPConnection,
+    error_queue : mp.Queue,
+) -> None:
+    parent_pipe.close()
+    del parent_pipe
+    env = env_fn()
+    try:
+        while True:
+            cmd, args, kwargs = pipe.recv()
+            if cmd == "reset":
+                context, observation, info = env.reset(*args, **kwargs)
+                pipe.send(((context, observation, info), True))
+            elif cmd == "step":
+                observation, reward, terminated, truncated, info = env.step(*args, **kwargs)
+                pipe.send(((observation, reward, terminated, truncated, info), True))
+            elif cmd == "render":
+                frame = env.render(*args, **kwargs)
+                pipe.send((frame, True))
+            elif cmd == "close":
+                break
+    except (KeyboardInterrupt, Exception) as e:
+        pipe.send(None, False)
+        error_queue.put((index, e))
+    finally:
+        env.close()
+        pipe.close()
+
+class AsyncVecEnv(Env[
+    BArrayType, ContextType, ObsType, ActType, RenderFrame, BDeviceType, BDtypeType, BRNGType
+]):
+    def __init__(
+        self,
+        env_fn : Iterable[Callable[[], Env[
+            BArrayType, ContextType, ObsType, ActType, RenderFrame, BDeviceType, BDtypeType, BRNGType
+        ]]],
+        seed : Optional[int] = None,
+        ctx : Optional[MPContext] = None,
+        daemon : bool = True,
+    ):
+        ctx = ctx or mp.get_context()
+        self.command_pipes : List[MPConnection] = []
+        self.processes : List[mp.Process] = []
+        self.error_queue : mp.Queue = ctx.Queue()
+
+        dummy_fn = None
+        for i, fn in enumerate(env_fn):
+            if i == 0:
+                dummy_fn = fn
+            parent_pipe, child_pipe = ctx.Pipe()
+            process = ctx.Process(
+                target=_async_worker_fn,
+                name="AsyncWorker-{}".format(i),
+                args=(i, MultiProcessFn(fn), child_pipe, parent_pipe, self.error_queue),
+                daemon=daemon,
+            )
+            process.start()
+            child_pipe.close()
+
+            self.command_pipes.append(parent_pipe)
+            self.processes.append(process)
+
+        assert len(self.processes) > 1, "env_fns should have more than 1 env function"
+
+        # Use Dummy Environment to get spaces and metadata
+        dummy_env = dummy_fn()
+        self.backend = dummy_env.backend
+        self.device = dummy_env.device
+        self.metadata = dummy_env.metadata
+        self.render_mode = dummy_env.render_mode
+        self.render_fps = dummy_env.render_fps
+        self.action_space = sbu.batch_space(
+            dummy_env.action_space,
+            len(self.processes),
+        )
+        self.observation_space = sbu.batch_space(
+            dummy_env.observation_space,
+            len(self.processes),
+        )
+        self.context_space = None if dummy_env.context_space is None else sbu.batch_space(
+            dummy_env.context_space,
+            len(self.processes),
+        )
+        self.rng = dummy_env.backend.random.random_number_generator(
+            seed,
+            device=dummy_env.device,
+        )
+        dummy_env.close()
+        del dummy_env
+
+        # Temporal mask storage
+        self._reset_mask : Optional[BArrayType] = None
+
+    @property
+    def batch_size(self) -> Optional[int]:
+        return len(self.processes)
+
+    def send_command(self, index: int, cmd: Literal["reset", "step", "render", "close"], *args, **kwargs):
+        self.command_pipes[index].send((cmd, args, kwargs))
+
+    def get_command_result(self, index: int):
+        data, success = self.command_pipes[index].recv()
+        if not success:
+            self._raise_if_error()
+        return data
+
+    def reset(
+        self,
+        *args,
+        mask : Optional[BArrayType] = None,
+        seed : Optional[int] = None,
+        **kwargs
+    ) -> Tuple[ContextType, ObsType, Dict[str, Any]]:
+        self.reset_async(*args, mask=mask, seed=seed, **kwargs)
+        return self.reset_wait()
+
+    def reset_async(
+        self,
+        *args,
+        mask : Optional[BArrayType] = None,
+        seed : Optional[int] = None,
+        **kwargs
+    ) -> None:
+        if seed is not None:
+            self.rng = self.backend.random.random_number_generator(
+                seed,
+                device=self.device,
+            )
+
+        self._reset_mask = mask
+        for i in range(len(self.processes)):
+            env_reset = True if mask is None else bool(mask[i])
+            if env_reset:
+                self.send_command(i, "reset", *args, **kwargs)
+
+    def reset_wait(
+        self,
+    ) -> Tuple[ContextType, ObsType, Dict[str, Any]]:
+        all_contexts = []
+        all_obs = []
+        all_infos = []
+        for i in range(len(self.processes)):
+            env_reset = True if self._reset_mask is None else bool(self._reset_mask[i])
+            if env_reset:
+                context, obs, info = self.get_command_result(i)
+                all_contexts.append(context)
+                all_obs.append(obs)
+                all_infos.append(info)
+
+        if self.context_space is not None:
+            all_contexts = sbu.concatenate(
+                self.context_space,
+                all_contexts,
+                axis=0,
+            )
+        else:
+            all_contexts = None
+
+        all_obs = sbu.concatenate(
+            self.observation_space,
+            all_obs,
+            axis=0,
+        )
+        all_infos = data_stack(
+            all_infos,
+            self.backend,
+            self.device
+        )
+
+        self._reset_mask = None
+        return all_contexts, all_obs, all_infos
+
+    def step(
+        self,
+        action : ActType
+    ) -> Tuple[ObsType, BArrayType, BArrayType, BArrayType, Dict[str, Any]]:
+        self.step_async(action)
+        return self.step_wait()
+
+    def step_async(
+        self,
+        action : ActType
+    ) -> None:
+        actions = sbu.iterate(
+            self.action_space,
+            action
+        )
+        for i in range(len(self.processes)):
+            self.send_command(i, "step", next(actions))
+
+    def step_wait(
+        self,
+    ) -> Tuple[ObsType, BArrayType, BArrayType, BArrayType, Dict[str, Any]]:
+        all_obs = []
+        all_rewards = []
+        all_terminated = []
+        all_truncated = []
+        all_infos = []
+        for i in range(len(self.processes)):
+            obs, reward, terminated, truncated, info = self.get_command_result(i)
+            all_obs.append(obs)
+            all_rewards.append(reward)
+            all_terminated.append(terminated)
+            all_truncated.append(truncated)
+            all_infos.append(info)
+        all_obs = sbu.concatenate(
+            self.observation_space,
+            all_obs,
+            axis=0,
+        )
+        all_rewards = self.backend.asarray(
+            [float(r) for r in all_rewards],
+            dtype=self.backend.default_floating_dtype,
+            device=self.device,
+        )
+        all_terminated = self.backend.asarray(
+            [bool(t) for t in all_terminated],
+            dtype=self.backend.default_boolean_dtype,
+            device=self.device,
+        )
+        all_truncated = self.backend.asarray(
+            [bool(t) for t in all_truncated],
+            dtype=self.backend.default_boolean_dtype,
+            device=self.device,
+        )
+        all_infos = data_stack(
+            all_infos,
+            self.backend,
+            self.device
+        )
+        return all_obs, all_rewards, all_terminated, all_truncated, all_infos
+
+    def close(self):
+        for i in range(len(self.processes)):
+            try:
+                self.send_command(i, "close")
+            except Exception:
+                pass
+        for process in self.processes:
+            process.join()
+            process.close()
+        for pipe in self.command_pipes:
+            pipe.close()
+        self.command_pipes = []
+        self.processes = []
+
+    def _raise_if_error(self):
+        try:
+            index, e = self.error_queue.get(block=False)
+        except QueueEmpty:
+            raise RuntimeError("Unknown error in AsyncVecEnv worker without any error message.")
+        raise RuntimeError("Error in AsyncVecEnv worker {}".format(index)) from e
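For orientation, here is a minimal usage sketch of the two vectorized-environment classes added in vec_env.py above; it is not part of the released package and relies only on the constructor and method signatures shown in the diff. `make_env`, `num_envs`, and `batched_action` are hypothetical placeholders a user would supply.

```python
# Illustrative sketch, not from the package. `make_env` is a hypothetical factory
# that builds and returns a single non-batched unienv Env.
from unienv_interface.env_base.vec_env import SyncVecEnv, AsyncVecEnv

def make_env():
    ...  # construct and return one non-batched Env here

num_envs = 4  # both classes require more than one env factory

# Runs every sub-env sequentially in the current process.
vec_env = SyncVecEnv([make_env for _ in range(num_envs)], seed=0)
context, obs, info = vec_env.reset()  # results are batched along axis 0
# obs, reward, terminated, truncated, info = vec_env.step(batched_action)
vec_env.close()

# AsyncVecEnv exposes the same reset/step/render/close interface, plus
# reset_async/reset_wait and step_async/step_wait, and runs each sub-env in its
# own process via multiprocessing pipes:
# vec_env = AsyncVecEnv([make_env for _ in range(num_envs)], seed=0, daemon=True)
```

Both classes also accept a `mask` argument to `reset`, so only the selected subset of sub-environments is reset.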
unienv_interface/func_wrapper/__init__.py
@@ -1 +1,2 @@
-from .transformation import FuncTransformWrapper
+from .transformation import FuncTransformWrapper
+from .frame_stack import FuncFrameStackWrapper, FuncFrameStackWrapperState
unienv_interface/func_wrapper/frame_stack.py
@@ -0,0 +1,150 @@
+from typing import Dict as DictT, Any, Optional, Tuple, Union, Generic, SupportsFloat, Type, Sequence, TypeVar
+import numpy as np
+import copy
+from unienv_interface.backends import ComputeBackend, BArrayType, BDeviceType, BDtypeType, BRNGType
+
+from unienv_interface.space.space_utils import batch_utils as sbu, flatten_utils as sfu
+from unienv_interface.utils import seed_util
+from unienv_interface.env_base.funcenv import FuncEnv, ContextType, ObsType, ActType, RenderFrame, StateType, RenderStateType
+from unienv_interface.env_base.funcenv_wrapper import *
+from unienv_interface.space import Space
+from unienv_interface.utils.data_queue import FuncSpaceDataQueue, SpaceDataQueueState
+from unienv_interface.utils.stateclass import StateClass, field
+
+class FuncFrameStackWrapperState(
+    Generic[StateType], StateClass
+):
+    env_state : StateType
+    obs_queue_state : Optional[SpaceDataQueueState]
+    action_queue_state : Optional[SpaceDataQueueState]
+
+class FuncFrameStackWrapper(
+    FuncEnvWrapper[
+        FuncFrameStackWrapperState[StateType], RenderStateType,
+        BArrayType, ContextType, ObsType, ActType, RenderFrame, BDeviceType, BDtypeType, BRNGType,
+        StateType, RenderStateType,
+        BArrayType, ContextType, ObsType, ActType, RenderFrame, BDeviceType, BDtypeType, BRNGType
+    ]
+):
+    def __init__(
+        self,
+        func_env: FuncEnv[
+            StateType, RenderStateType,
+            BArrayType, ContextType, ObsType, ActType, RenderFrame, BDeviceType, BDtypeType, BRNGType
+        ],
+        obs_stack_size: int = 0,
+        action_stack_size: int = 0,
+        action_default_value: Optional[ActType] = None,
+    ):
+        assert obs_stack_size >= 0, "Observation stack size must be greater than 0"
+        assert action_stack_size >= 0, "Action stack size must be greater than 0"
+        assert action_stack_size == 0 or action_default_value is not None, "Action default value must be provided if action stack size is greater than 0"
+        assert obs_stack_size > 0 or action_stack_size > 0, "At least one of observation stack size or action stack size must be greater than 0"
+        super().__init__(func_env)
+        obs_is_dict = isinstance(func_env.observation_space, DictT)
+        assert obs_is_dict or action_stack_size == 0, "Action stack size must be 0 if observation space is not a DictSpace"
+
+        self.action_stack_size = action_stack_size
+        self.obs_stack_size = obs_stack_size
+
+        if action_stack_size > 0:
+            self.action_deque = FuncSpaceDataQueue(
+                func_env.action_space,
+                func_env.batch_size,
+                action_stack_size,
+            )
+            self.action_default_value = action_default_value
+        else:
+            self.action_deque = None
+
+        self.obs_deque = None
+        if obs_stack_size > 0:
+            self.obs_deque = FuncSpaceDataQueue(
+                func_env.observation_space,
+                func_env.batch_size,
+                obs_stack_size + 1,
+            )
+            if action_stack_size > 0:
+                new_obs_space = copy.copy(self.obs_deque.output_space)
+                new_obs_space['past_actions'] = self.action_deque.output_space
+                self.observation_space = new_obs_space
+            else:
+                self.observation_space = self.obs_deque.output_space
+        else:
+            if action_stack_size > 0:
+                self.observation_space = copy.copy(func_env.observation_space)
+                self.observation_space['past_actions'] = self.action_deque.output_space
+            else:
+                raise ValueError("At least one of observation stack size or action stack size must be greater than 0")
+
+    def map_observation(
+        self,
+        state : FuncFrameStackWrapperState[StateType],
+        observation : ObsType
+    ) -> ObsType:
+        if self.obs_deque is not None:
+            observation = self.obs_deque.get_output_data(state.obs_queue_state)
+        if self.action_deque is not None:
+            stacked_action = self.action_deque.get_output_data(state.action_queue_state)
+            observation = copy.copy(observation)
+            observation['past_actions'] = stacked_action
+        return observation
+
+    def initial(self, *, seed = None, **kwargs):
+        init_state, init_context, init_obs, init_info = self.func_env.initial(seed=seed, **kwargs)
+        obs_queue_state = None
+        action_queue_state = None
+        if self.obs_deque is not None:
+            obs_queue_state = self.obs_deque.init(init_obs)
+        if self.action_deque is not None:
+            action_queue_state = self.action_deque.init(self.action_default_value)
+        state = FuncFrameStackWrapperState(
+            env_state=init_state,
+            obs_queue_state=obs_queue_state,
+            action_queue_state=action_queue_state,
+        )
+        return state, init_context, self.map_observation(state, init_obs), init_info
+
+    def reset(self, state, *args, seed = None, mask = None, **kwargs):
+        env_state, context, observation, info = self.func_env.reset(
+            state.env_state,
+            *args,
+            seed=seed,
+            mask=mask,
+            **kwargs
+        )
+        obs_queue_state = state.obs_queue_state
+        action_queue_state = state.action_queue_state
+        if self.obs_deque is not None:
+            obs_queue_state = self.obs_deque.reset(
+                obs_queue_state,
+                initial_data=observation,
+                mask=mask
+            )
+        if self.action_deque is not None:
+            action_queue_state = self.action_deque.reset(
+                action_queue_state,
+                initial_data=self.action_default_value,
+                mask=mask
+            )
+        new_state = FuncFrameStackWrapperState(
+            env_state=env_state,
+            obs_queue_state=obs_queue_state,
+            action_queue_state=action_queue_state,
+        )
+        return new_state, context, self.map_observation(new_state, observation), info
+
+    def step(self, state, action):
+        env_state, observation, reward, terminated, truncated, info = self.func_env.step(state.env_state, action)
+        obs_queue_state = state.obs_queue_state
+        action_queue_state = state.action_queue_state
+        if self.obs_deque is not None:
+            obs_queue_state = self.obs_deque.add(obs_queue_state, observation)
+        if self.action_deque is not None:
+            action_queue_state = self.action_deque.add(action_queue_state, action)
+        new_state = FuncFrameStackWrapperState(
+            env_state=env_state,
+            obs_queue_state=obs_queue_state,
+            action_queue_state=action_queue_state,
+        )
+        return new_state, self.map_observation(new_state, observation), reward, terminated, truncated, info
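For orientation, a minimal sketch of applying the new functional frame-stacking wrapper from frame_stack.py; it is not part of the released package and follows only the constructor and method signatures shown above. `base_env`, `default_action`, and `batched_action` are hypothetical placeholders. Exposing past actions requires a dict-style observation space, since they are added under the 'past_actions' key.

```python
# Illustrative sketch, not from the package. `base_env` is a hypothetical FuncEnv
# instance; `default_action` is a hypothetical "no-op" action used to pre-fill
# the action queue before any real action has been taken.
from unienv_interface.func_wrapper import FuncFrameStackWrapper

wrapped = FuncFrameStackWrapper(
    base_env,
    obs_stack_size=4,      # keeps a rolling window of obs_stack_size + 1 observations
    action_stack_size=2,   # exposes the last 2 actions as obs['past_actions']
    action_default_value=default_action,
)

# Functional API: the wrapper threads its own state (env state plus queue states).
state, context, obs, info = wrapped.initial(seed=0)
# state, obs, reward, terminated, truncated, info = wrapped.step(state, batched_action)
# state, context, obs, info = wrapped.reset(state)
```

The stateful wrapper at unienv_interface/wrapper/frame_stack.py also shrinks substantially in this release (+19 -114 in the file list above), presumably refactored to reuse this functional implementation.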