unienv 0.0.1b1__py3-none-any.whl → 0.0.1b3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (59)
  1. unienv-0.0.1b3.dist-info/METADATA +74 -0
  2. unienv-0.0.1b3.dist-info/RECORD +92 -0
  3. {unienv-0.0.1b1.dist-info → unienv-0.0.1b3.dist-info}/licenses/LICENSE +1 -3
  4. unienv-0.0.1b3.dist-info/top_level.txt +2 -0
  5. unienv_data/base/__init__.py +0 -1
  6. unienv_data/base/common.py +95 -45
  7. unienv_data/base/storage.py +1 -0
  8. unienv_data/batches/__init__.py +2 -1
  9. unienv_data/batches/backend_compat.py +47 -1
  10. unienv_data/batches/combined_batch.py +2 -4
  11. unienv_data/{base → batches}/transformations.py +3 -2
  12. unienv_data/replay_buffer/replay_buffer.py +4 -0
  13. unienv_data/samplers/__init__.py +0 -1
  14. unienv_data/samplers/multiprocessing_sampler.py +26 -22
  15. unienv_data/samplers/step_sampler.py +9 -18
  16. unienv_data/storages/common.py +5 -0
  17. unienv_data/storages/hdf5.py +291 -20
  18. unienv_data/storages/pytorch.py +1 -0
  19. unienv_data/storages/transformation.py +191 -0
  20. unienv_data/transformations/image_compress.py +213 -0
  21. unienv_interface/backends/jax.py +4 -1
  22. unienv_interface/backends/numpy.py +4 -1
  23. unienv_interface/backends/pytorch.py +4 -1
  24. unienv_interface/env_base/__init__.py +1 -0
  25. unienv_interface/env_base/env.py +5 -0
  26. unienv_interface/env_base/funcenv.py +32 -1
  27. unienv_interface/env_base/funcenv_wrapper.py +2 -2
  28. unienv_interface/env_base/vec_env.py +474 -0
  29. unienv_interface/func_wrapper/__init__.py +2 -1
  30. unienv_interface/func_wrapper/frame_stack.py +150 -0
  31. unienv_interface/space/space_utils/__init__.py +1 -0
  32. unienv_interface/space/space_utils/batch_utils.py +83 -0
  33. unienv_interface/space/space_utils/construct_utils.py +216 -0
  34. unienv_interface/space/space_utils/serialization_utils.py +16 -1
  35. unienv_interface/space/spaces/__init__.py +3 -1
  36. unienv_interface/space/spaces/batched.py +90 -0
  37. unienv_interface/space/spaces/binary.py +0 -1
  38. unienv_interface/space/spaces/box.py +13 -24
  39. unienv_interface/space/spaces/text.py +1 -3
  40. unienv_interface/transformations/dict_transform.py +31 -5
  41. unienv_interface/utils/control_util.py +68 -0
  42. unienv_interface/utils/data_queue.py +184 -0
  43. unienv_interface/utils/stateclass.py +46 -0
  44. unienv_interface/utils/vec_util.py +15 -0
  45. unienv_interface/world/__init__.py +3 -1
  46. unienv_interface/world/combined_funcnode.py +336 -0
  47. unienv_interface/world/combined_node.py +232 -0
  48. unienv_interface/wrapper/backend_compat.py +2 -2
  49. unienv_interface/wrapper/frame_stack.py +19 -114
  50. unienv_interface/wrapper/video_record.py +11 -2
  51. unienv-0.0.1b1.dist-info/METADATA +0 -20
  52. unienv-0.0.1b1.dist-info/RECORD +0 -85
  53. unienv-0.0.1b1.dist-info/top_level.txt +0 -4
  54. unienv_data/samplers/slice_sampler.py +0 -266
  55. unienv_maniskill/__init__.py +0 -1
  56. unienv_maniskill/wrapper/maniskill_compat.py +0 -235
  57. unienv_mjxplayground/__init__.py +0 -1
  58. unienv_mjxplayground/wrapper/playground_compat.py +0 -256
  59. {unienv-0.0.1b1.dist-info → unienv-0.0.1b3.dist-info}/WHEEL +0 -0
@@ -0,0 +1,474 @@
1
+ from .env import Env, ContextType, ObsType, ActType, RenderFrame
2
+ from unienv_interface.backends import ComputeBackend, BArrayType, BDeviceType, BDtypeType, BRNGType
3
+ from unienv_interface.space import Space, batch_utils as sbu
4
+ from unienv_interface.utils.vec_util import MultiProcessFn
5
+ from typing import Any, Dict, Generic, Literal, Optional, SupportsFloat, Tuple, TypeVar, Callable, Iterable, Mapping, Sequence, List
6
+ import numpy as np
7
+ import multiprocessing as mp
8
+ from multiprocessing.connection import Connection as MPConnection
9
+ from multiprocessing.context import BaseContext as MPContext
10
+ from queue import Empty as QueueEmpty
11
+
12
def data_stack(
    data : Any,
    backend : ComputeBackend[BArrayType, BDeviceType, BDtypeType, BRNGType],
    device : Optional[BDeviceType] = None,
):
    """Recursively collate per-environment data into batched arrays.

    - Mappings are processed value-by-value and returned as a new ``dict``.
    - Non-empty sequences are stacked along a new leading axis:
      - backend arrays via ``backend.stack``;
      - numpy arrays via ``np.stack``;
      - Python ``bool``/``float``/``int`` via ``backend.asarray`` using the
        backend's default boolean/floating/integer dtype;
      - other ``SupportsFloat`` values are converted with ``float()`` first.
    - Any other sequence is processed element-by-element into a ``list``.
    - Strings, empty sequences and non-sequence leaves are returned unchanged.

    The element kind (and dtype) is decided by the FIRST element only.

    Args:
        data: Arbitrarily nested mappings/sequences of per-env values.
        backend: Compute backend used to build arrays.
        device: Device for arrays created from Python scalars. It is
            propagated to nested values.

    Returns:
        The collated structure.
    """
    if isinstance(data, Mapping):
        # Forward ``device`` so nested scalars land on the same device.
        return {
            key: data_stack(value, backend, device)
            for key, value in data.items()
        }
    # ``str`` is a Sequence whose elements are themselves ``str`` (and
    # ``"a"[0] == "a"``), so recursing into it never terminates.
    if isinstance(data, str) or not isinstance(data, Sequence):
        return data
    if len(data) == 0:
        return data

    first = data[0]
    if backend.is_backendarray(first):
        return backend.stack(data, axis=0)
    if isinstance(first, np.ndarray):
        return np.stack(data, axis=0)
    if isinstance(first, (int, float, bool)):
        # ``bool`` is checked first because it is a subclass of ``int``.
        dtype = (
            backend.default_boolean_dtype if isinstance(first, bool) else
            backend.default_floating_dtype if isinstance(first, float) else
            backend.default_integer_dtype
        )
        return backend.asarray(data, dtype=dtype, device=device)
    if isinstance(first, SupportsFloat):
        return backend.asarray(
            [float(d) for d in data],
            dtype=backend.default_floating_dtype,
            device=device,
        )
    return [data_stack(value, backend, device) for value in data]
48
+
49
class SyncVecEnv(Env[
    BArrayType, ContextType, ObsType, ActType, RenderFrame, BDeviceType, BDtypeType, BRNGType
]):
    """Vectorized environment that drives its sub-environments one after another in-process.

    Every factory in ``env_fn`` is called eagerly. The resulting
    environments must:
    - number at least two;
    - each be unbatched (``batch_size is None``);
    - share one backend and device.

    Their spaces are merged with ``sbu.batch_differing_spaces``.

    ``seed`` only (re)creates ``self.rng``; it is not forwarded to the
    sub-environments.
    """

    def __init__(
        self,
        env_fn : Iterable[Callable[[], Env[
            BArrayType, ContextType, ObsType, ActType, RenderFrame, BDeviceType, BDtypeType, BRNGType
        ]]],
        seed : Optional[int] = None,
    ):
        self.envs = [make_env() for make_env in env_fn]
        assert len(self.envs) > 1, "env_fns should have more than 1 env function"
        assert all(env.batch_size is None for env in self.envs), "All envs must be non-batched envs"

        # Backend and device must match exactly. Spaces may differ and are
        # reconciled by batch_differing_spaces below.
        reference = self.envs[0]
        for other in self.envs[1:]:
            assert other.backend == reference.backend, "All envs must have the same backend"
            assert other.device == reference.device, "All envs must have the same device"

        self.action_space = sbu.batch_differing_spaces(
            [env.action_space for env in self.envs],
            device=reference.device,
        )
        self.observation_space = sbu.batch_differing_spaces(
            [env.observation_space for env in self.envs],
            device=reference.device,
        )
        if reference.context_space is None:
            self.context_space = None
        else:
            self.context_space = sbu.batch_differing_spaces(
                [env.context_space for env in self.envs],
                device=reference.device,
            )
        self.rng = self.backend.random.random_number_generator(
            seed,
            device=reference.device,
        )

    # Metadata and backend information are taken from the first sub-environment.
    @property
    def metadata(self) -> Dict[str, Any]:
        return self.envs[0].metadata

    @property
    def render_mode(self) -> Optional[str]:
        return self.envs[0].render_mode

    @property
    def render_fps(self) -> Optional[int]:
        return self.envs[0].render_fps

    @property
    def backend(self) -> ComputeBackend[BArrayType, BDeviceType, BDtypeType, BRNGType]:
        return self.envs[0].backend

    @property
    def device(self) -> Optional[BDeviceType]:
        return self.envs[0].device

    @property
    def batch_size(self) -> Optional[int]:
        return len(self.envs)

    def reset(
        self,
        *args,
        mask : Optional[BArrayType] = None,
        seed : Optional[int] = None,
        **kwargs
    ) -> Tuple[ContextType, ObsType, Dict[str, Any]]:
        """Reset the sub-environments selected by ``mask`` (all when ``mask`` is None).

        Only the environments that were reset contribute to the returned
        context/observation/info batches.
        """
        if seed is not None:
            self.rng = self.backend.random.random_number_generator(
                seed,
                device=self.device,
            )

        contexts, observations, infos = [], [], []
        for idx, env in enumerate(self.envs):
            if mask is not None and not bool(mask[idx]):
                continue
            ctx, obs, info = env.reset(*args, **kwargs)
            contexts.append(ctx)
            observations.append(obs)
            infos.append(info)

        if self.context_space is None:
            batched_context = None
        else:
            batched_context = sbu.concatenate(self.context_space, contexts, axis=0)
        batched_obs = sbu.concatenate(self.observation_space, observations, axis=0)
        batched_info = data_stack(infos, self.backend, self.device)
        return batched_context, batched_obs, batched_info

    def _as_vector(self, values, cast, dtype):
        # Convert a list of per-env scalars to one backend array on self.device.
        return self.backend.asarray(
            [cast(v) for v in values],
            dtype=dtype,
            device=self.device,
        )

    def step(
        self,
        action : ActType
    ) -> Tuple[ObsType, BArrayType, BArrayType, BArrayType, Dict[str, Any]]:
        """Split ``action`` per environment, step each one and collate the results."""
        action_iter = sbu.iterate(self.action_space, action)
        observations, rewards, terminations, truncations, infos = [], [], [], [], []
        for env in self.envs:
            obs, rew, term, trunc, info = env.step(next(action_iter))
            observations.append(obs)
            rewards.append(rew)
            terminations.append(term)
            truncations.append(trunc)
            infos.append(info)

        batched_obs = sbu.concatenate(self.observation_space, observations, axis=0)
        reward_vec = self._as_vector(rewards, float, self.backend.default_floating_dtype)
        terminated_vec = self._as_vector(terminations, bool, self.backend.default_boolean_dtype)
        truncated_vec = self._as_vector(truncations, bool, self.backend.default_boolean_dtype)
        batched_info = data_stack(infos, self.backend, self.device)
        return batched_obs, reward_vec, terminated_vec, truncated_vec, batched_info

    def render(self) -> Sequence[RenderFrame] | None:
        """Render every sub-environment and return the non-None frames, or None if there are none."""
        rendered = [env.render() for env in self.envs]
        frames = [frame for frame in rendered if frame is not None]
        return frames or None

    def close(self):
        """Close every sub-environment and forget them."""
        for env in self.envs:
            env.close()
        self.envs = []
217
+
218
def _async_worker_fn(
    index : int,
    env_fn : Callable[[], Env[
        BArrayType, ContextType, ObsType, ActType, RenderFrame, BDeviceType, BDtypeType, BRNGType
    ]],
    pipe : MPConnection,
    parent_pipe : MPConnection,
    error_queue : mp.Queue,
) -> None:
    """Serve commands for a single sub-environment inside a worker process.

    Protocol:
    - The parent sends ``(cmd, args, kwargs)`` tuples.
    - Each reply is one ``(payload, success)`` tuple.
    - ``"close"`` ends the loop without a reply.
    - On any failure (including an error while constructing the
      environment or an unknown command), the worker does three things:
      puts ``(index, exception)`` on ``error_queue``, replies
      ``(None, False)``, and exits.

    Args:
        index: Position of this worker; reported together with any error.
        env_fn: Zero-argument factory that builds the environment in this process.
        pipe: The worker's end of the command pipe.
        parent_pipe: The parent's end of the pipe, passed in by the parent;
            closed immediately.
        error_queue: Queue that receives ``(index, exception)`` on failure.
    """
    # Only the parent should hold its end of the pipe.
    parent_pipe.close()
    del parent_pipe
    env = None
    try:
        # Construct inside the try block. Otherwise a failing factory would
        # leave the parent blocked on recv() forever.
        env = env_fn()
        while True:
            cmd, args, kwargs = pipe.recv()
            if cmd == "reset":
                context, observation, info = env.reset(*args, **kwargs)
                pipe.send(((context, observation, info), True))
            elif cmd == "step":
                observation, reward, terminated, truncated, info = env.step(*args, **kwargs)
                pipe.send(((observation, reward, terminated, truncated, info), True))
            elif cmd == "render":
                frame = env.render(*args, **kwargs)
                pipe.send((frame, True))
            elif cmd == "close":
                break
            else:
                # Replying is mandatory: a silently ignored command hangs the parent.
                raise ValueError("Unknown AsyncVecEnv command: {!r}".format(cmd))
    except (KeyboardInterrupt, Exception) as e:
        # Enqueue the error BEFORE sending the failure reply. The parent looks
        # the error up as soon as it sees the failure flag.
        error_queue.put((index, e))
        # Connection.send takes exactly one object, so the reply is a single tuple.
        pipe.send((None, False))
    finally:
        try:
            if env is not None:
                env.close()
        finally:
            pipe.close()
250
+
251
class AsyncVecEnv(Env[
    BArrayType, ContextType, ObsType, ActType, RenderFrame, BDeviceType, BDtypeType, BRNGType
]):
    """Vectorized environment running each sub-environment in its own worker process.

    One process per factory in ``env_fn`` runs ``_async_worker_fn``.
    Commands go over a pipe and failures are reported through
    ``error_queue``.

    Spaces and metadata come from a throwaway "dummy" environment. It is
    built in this process from the first factory and closed immediately.
    Spaces are batched with ``sbu.batch_space``.

    ``seed`` only (re)creates ``self.rng``; it is not forwarded to the workers.
    """

    # Seconds to wait for a failed worker's exception to show up on the error
    # queue. multiprocessing.Queue.put hands items to a background feeder
    # thread, so the failure reply on the pipe can arrive first.
    _error_wait_timeout : float = 5.0

    def __init__(
        self,
        env_fn : Iterable[Callable[[], Env[
            BArrayType, ContextType, ObsType, ActType, RenderFrame, BDeviceType, BDtypeType, BRNGType
        ]]],
        seed : Optional[int] = None,
        ctx : Optional[MPContext] = None,
        daemon : bool = True,
    ):
        ctx = ctx or mp.get_context()
        env_fns = list(env_fn)
        # Validate before spawning so a bad call does not leave workers running.
        assert len(env_fns) > 1, "env_fns should have more than 1 env function"

        self.command_pipes : List[MPConnection] = []
        self.processes : List[mp.Process] = []
        self.error_queue : mp.Queue = ctx.Queue()

        for i, fn in enumerate(env_fns):
            parent_pipe, child_pipe = ctx.Pipe()
            process = ctx.Process(
                target=_async_worker_fn,
                name="AsyncWorker-{}".format(i),
                args=(i, MultiProcessFn(fn), child_pipe, parent_pipe, self.error_queue),
                daemon=daemon,
            )
            process.start()
            # The worker owns the child end from now on.
            child_pipe.close()

            self.command_pipes.append(parent_pipe)
            self.processes.append(process)

        # Use a dummy environment to get spaces and metadata.
        dummy_env = env_fns[0]()
        self.backend = dummy_env.backend
        self.device = dummy_env.device
        self.metadata = dummy_env.metadata
        self.render_mode = dummy_env.render_mode
        self.render_fps = dummy_env.render_fps
        self.action_space = sbu.batch_space(
            dummy_env.action_space,
            len(self.processes),
        )
        self.observation_space = sbu.batch_space(
            dummy_env.observation_space,
            len(self.processes),
        )
        self.context_space = None if dummy_env.context_space is None else sbu.batch_space(
            dummy_env.context_space,
            len(self.processes),
        )
        self.rng = dummy_env.backend.random.random_number_generator(
            seed,
            device=dummy_env.device,
        )
        dummy_env.close()
        del dummy_env

        # Mask passed to reset_async, consumed by reset_wait.
        self._reset_mask : Optional[BArrayType] = None

    @property
    def batch_size(self) -> Optional[int]:
        return len(self.processes)

    def send_command(self, index: int, cmd: Literal["reset", "step", "render", "close"], *args, **kwargs):
        """Send ``(cmd, args, kwargs)`` to worker ``index`` without waiting for a reply."""
        self.command_pipes[index].send((cmd, args, kwargs))

    def get_command_result(self, index: int):
        """Block for worker ``index``'s reply.

        Returns:
            The reply payload.

        Raises:
            RuntimeError: If the worker reported a failure or exited without replying.
        """
        try:
            data, success = self.command_pipes[index].recv()
        except EOFError:
            # The worker's end is closed: it exited without replying
            # (e.g. it crashed). Report its recorded error if there is one.
            data, success = None, False
        if not success:
            self._raise_if_error()
        return data

    def reset(
        self,
        *args,
        mask : Optional[BArrayType] = None,
        seed : Optional[int] = None,
        **kwargs
    ) -> Tuple[ContextType, ObsType, Dict[str, Any]]:
        self.reset_async(*args, mask=mask, seed=seed, **kwargs)
        return self.reset_wait()

    def reset_async(
        self,
        *args,
        mask : Optional[BArrayType] = None,
        seed : Optional[int] = None,
        **kwargs
    ) -> None:
        """Send a reset command to every worker selected by ``mask`` (all when None)."""
        if seed is not None:
            self.rng = self.backend.random.random_number_generator(
                seed,
                device=self.device,
            )

        self._reset_mask = mask
        for i in range(len(self.processes)):
            env_reset = True if mask is None else bool(mask[i])
            if env_reset:
                self.send_command(i, "reset", *args, **kwargs)

    def reset_wait(
        self,
    ) -> Tuple[ContextType, ObsType, Dict[str, Any]]:
        """Collect the replies of the workers reset by the last ``reset_async``.

        Only the environments that were reset contribute to the returned batches.
        """
        all_contexts = []
        all_obs = []
        all_infos = []
        for i in range(len(self.processes)):
            env_reset = True if self._reset_mask is None else bool(self._reset_mask[i])
            if env_reset:
                context, obs, info = self.get_command_result(i)
                all_contexts.append(context)
                all_obs.append(obs)
                all_infos.append(info)

        if self.context_space is not None:
            all_contexts = sbu.concatenate(
                self.context_space,
                all_contexts,
                axis=0,
            )
        else:
            all_contexts = None

        all_obs = sbu.concatenate(
            self.observation_space,
            all_obs,
            axis=0,
        )
        all_infos = data_stack(
            all_infos,
            self.backend,
            self.device
        )

        self._reset_mask = None
        return all_contexts, all_obs, all_infos

    def step(
        self,
        action : ActType
    ) -> Tuple[ObsType, BArrayType, BArrayType, BArrayType, Dict[str, Any]]:
        self.step_async(action)
        return self.step_wait()

    def step_async(
        self,
        action : ActType
    ) -> None:
        """Split ``action`` per environment and send one step command to each worker."""
        actions = sbu.iterate(
            self.action_space,
            action
        )
        for i in range(len(self.processes)):
            self.send_command(i, "step", next(actions))

    def step_wait(
        self,
    ) -> Tuple[ObsType, BArrayType, BArrayType, BArrayType, Dict[str, Any]]:
        """Collect every worker's step reply and collate them into batches."""
        all_obs = []
        all_rewards = []
        all_terminated = []
        all_truncated = []
        all_infos = []
        for i in range(len(self.processes)):
            obs, reward, terminated, truncated, info = self.get_command_result(i)
            all_obs.append(obs)
            all_rewards.append(reward)
            all_terminated.append(terminated)
            all_truncated.append(truncated)
            all_infos.append(info)
        all_obs = sbu.concatenate(
            self.observation_space,
            all_obs,
            axis=0,
        )
        all_rewards = self.backend.asarray(
            [float(r) for r in all_rewards],
            dtype=self.backend.default_floating_dtype,
            device=self.device,
        )
        all_terminated = self.backend.asarray(
            [bool(t) for t in all_terminated],
            dtype=self.backend.default_boolean_dtype,
            device=self.device,
        )
        all_truncated = self.backend.asarray(
            [bool(t) for t in all_truncated],
            dtype=self.backend.default_boolean_dtype,
            device=self.device,
        )
        all_infos = data_stack(
            all_infos,
            self.backend,
            self.device
        )
        return all_obs, all_rewards, all_terminated, all_truncated, all_infos

    def render(self) -> Sequence[RenderFrame] | None:
        """Render every worker's environment.

        Returns:
            The non-None frames, or None if there are none. This mirrors
            SyncVecEnv.render.
        """
        for i in range(len(self.processes)):
            self.send_command(i, "render")
        frames = []
        for i in range(len(self.processes)):
            frame = self.get_command_result(i)
            if frame is not None:
                frames.append(frame)
        return frames if len(frames) > 0 else None

    def close(self):
        """Ask every worker to stop, then join and release processes and pipes."""
        for i in range(len(self.processes)):
            try:
                self.send_command(i, "close")
            except Exception:
                # Best effort: a worker that already died has a broken pipe.
                pass
        for process in self.processes:
            process.join()
            process.close()
        for pipe in self.command_pipes:
            pipe.close()
        self.command_pipes = []
        self.processes = []

    def _raise_if_error(self):
        """Raise a RuntimeError carrying the first error reported by any worker.

        Waits up to ``_error_wait_timeout`` seconds. A failed worker's
        exception may still be in transit through the queue's feeder thread.
        """
        try:
            index, e = self.error_queue.get(block=True, timeout=self._error_wait_timeout)
        except QueueEmpty:
            raise RuntimeError("Unknown error in AsyncVecEnv worker without any error message.")
        raise RuntimeError("Error in AsyncVecEnv worker {}".format(index)) from e
@@ -1 +1,2 @@
1
- from .transformation import FuncTransformWrapper
1
+ from .transformation import FuncTransformWrapper
2
+ from .frame_stack import FuncFrameStackWrapper, FuncFrameStackWrapperState
@@ -0,0 +1,150 @@
1
+ from typing import Dict as DictT, Any, Optional, Tuple, Union, Generic, SupportsFloat, Type, Sequence, TypeVar
2
+ import numpy as np
3
+ import copy
4
+ from unienv_interface.backends import ComputeBackend, BArrayType, BDeviceType, BDtypeType, BRNGType
5
+
6
+ from unienv_interface.space.space_utils import batch_utils as sbu, flatten_utils as sfu
7
+ from unienv_interface.utils import seed_util
8
+ from unienv_interface.env_base.funcenv import FuncEnv, ContextType, ObsType, ActType, RenderFrame, StateType, RenderStateType
9
+ from unienv_interface.env_base.funcenv_wrapper import *
10
+ from unienv_interface.space import Space
11
+ from unienv_interface.utils.data_queue import FuncSpaceDataQueue, SpaceDataQueueState
12
+ from unienv_interface.utils.stateclass import StateClass, field
13
+
14
class FuncFrameStackWrapperState(
    Generic[StateType], StateClass
):
    """State threaded through FuncFrameStackWrapper's functional calls.

    Pairs the wrapped environment's own state with the states of the
    observation/action history queues.
    """
    # State of the wrapped functional environment.
    env_state : StateType
    # Observation-queue state; the wrapper stores None when observation stacking is disabled.
    obs_queue_state : Union[SpaceDataQueueState, None]
    # Action-queue state; the wrapper stores None when action stacking is disabled.
    action_queue_state : Union[SpaceDataQueueState, None]
20
+
21
class FuncFrameStackWrapper(
    FuncEnvWrapper[
        FuncFrameStackWrapperState[StateType], RenderStateType,
        BArrayType, ContextType, ObsType, ActType, RenderFrame, BDeviceType, BDtypeType, BRNGType,
        StateType, RenderStateType,
        BArrayType, ContextType, ObsType, ActType, RenderFrame, BDeviceType, BDtypeType, BRNGType
    ]
):
    """Functional wrapper exposing a history of observations and/or actions.

    - With ``obs_stack_size > 0``, observations returned by ``initial``,
      ``reset`` and ``step`` are replaced by the output of a
      ``FuncSpaceDataQueue`` holding ``obs_stack_size + 1`` entries.
    - With ``action_stack_size > 0``, the output of an action queue of length
      ``action_stack_size`` is added under the ``'past_actions'`` key. This
      requires a dict-like observation space.

    Args:
        func_env: Functional environment to wrap.
        obs_stack_size: Observation stacking parameter; the observation queue
            holds ``obs_stack_size + 1`` entries. 0 disables it.
        action_stack_size: Length of the action queue; 0 disables it.
        action_default_value: Action passed to the action queue's
            ``init``/``reset``. Required when ``action_stack_size > 0``.
    """
    def __init__(
        self,
        func_env: FuncEnv[
            StateType, RenderStateType,
            BArrayType, ContextType, ObsType, ActType, RenderFrame, BDeviceType, BDtypeType, BRNGType
        ],
        obs_stack_size: int = 0,
        action_stack_size: int = 0,
        action_default_value: Optional[ActType] = None,
    ):
        assert obs_stack_size >= 0, "Observation stack size must be non-negative"
        assert action_stack_size >= 0, "Action stack size must be non-negative"
        assert action_stack_size == 0 or action_default_value is not None, "Action default value must be provided if action stack size is greater than 0"
        assert obs_stack_size > 0 or action_stack_size > 0, "At least one of observation stack size or action stack size must be greater than 0"
        super().__init__(func_env)
        # typing.Dict only matches real dict instances. Observation spaces are
        # Space objects (presumably not dict subclasses), so a DictSpace must be
        # recognized explicitly or action stacking could never be enabled.
        obs_is_dict = self._is_dict_space(func_env.observation_space)
        assert obs_is_dict or action_stack_size == 0, "Action stack size must be 0 if observation space is not a DictSpace"

        self.action_stack_size = action_stack_size
        self.obs_stack_size = obs_stack_size

        if action_stack_size > 0:
            self.action_deque = FuncSpaceDataQueue(
                func_env.action_space,
                func_env.batch_size,
                action_stack_size,
            )
            self.action_default_value = action_default_value
        else:
            self.action_deque = None

        self.obs_deque = None
        if obs_stack_size > 0:
            self.obs_deque = FuncSpaceDataQueue(
                func_env.observation_space,
                func_env.batch_size,
                obs_stack_size + 1,
            )
            if action_stack_size > 0:
                # NOTE(review): copy.copy is shallow. If the space keeps its
                # sub-spaces in an inner dict, this assignment may also
                # mutate the source space — confirm against DictSpace.
                new_obs_space = copy.copy(self.obs_deque.output_space)
                new_obs_space['past_actions'] = self.action_deque.output_space
                self.observation_space = new_obs_space
            else:
                self.observation_space = self.obs_deque.output_space
        else:
            if action_stack_size > 0:
                self.observation_space = copy.copy(func_env.observation_space)
                self.observation_space['past_actions'] = self.action_deque.output_space
            else:
                # Reachable only when assertions are disabled (python -O).
                raise ValueError("At least one of observation stack size or action stack size must be greater than 0")

    @staticmethod
    def _is_dict_space(space) -> bool:
        """Return True if ``space`` is a plain dict or the library's DictSpace."""
        if isinstance(space, DictT):
            return True
        try:
            from unienv_interface.space import DictSpace
        except ImportError:
            return False
        return isinstance(space, DictSpace)

    def map_observation(
        self,
        state : FuncFrameStackWrapperState[StateType],
        observation : ObsType
    ) -> ObsType:
        """Build the wrapper's observation from the queue states in ``state``.

        When observation stacking is on, ``observation`` is ignored in favour
        of the observation queue's output. The action queue's output (if
        any) is added under ``'past_actions'`` on a shallow copy.
        """
        if self.obs_deque is not None:
            observation = self.obs_deque.get_output_data(state.obs_queue_state)
        if self.action_deque is not None:
            stacked_action = self.action_deque.get_output_data(state.action_queue_state)
            observation = copy.copy(observation)
            observation['past_actions'] = stacked_action
        return observation

    def initial(self, *, seed = None, **kwargs):
        """Start the wrapped env and initialize the queues from its first observation / the default action."""
        init_state, init_context, init_obs, init_info = self.func_env.initial(seed=seed, **kwargs)
        obs_queue_state = None
        action_queue_state = None
        if self.obs_deque is not None:
            obs_queue_state = self.obs_deque.init(init_obs)
        if self.action_deque is not None:
            action_queue_state = self.action_deque.init(self.action_default_value)
        state = FuncFrameStackWrapperState(
            env_state=init_state,
            obs_queue_state=obs_queue_state,
            action_queue_state=action_queue_state,
        )
        return state, init_context, self.map_observation(state, init_obs), init_info

    def reset(self, state, *args, seed = None, mask = None, **kwargs):
        """Reset the wrapped env and the queues, forwarding ``mask`` to each."""
        env_state, context, observation, info = self.func_env.reset(
            state.env_state,
            *args,
            seed=seed,
            mask=mask,
            **kwargs
        )
        obs_queue_state = state.obs_queue_state
        action_queue_state = state.action_queue_state
        if self.obs_deque is not None:
            obs_queue_state = self.obs_deque.reset(
                obs_queue_state,
                initial_data=observation,
                mask=mask
            )
        if self.action_deque is not None:
            action_queue_state = self.action_deque.reset(
                action_queue_state,
                initial_data=self.action_default_value,
                mask=mask
            )
        new_state = FuncFrameStackWrapperState(
            env_state=env_state,
            obs_queue_state=obs_queue_state,
            action_queue_state=action_queue_state,
        )
        return new_state, context, self.map_observation(new_state, observation), info

    def step(self, state, action):
        """Step the wrapped env and push the new observation and ``action`` onto the queues."""
        env_state, observation, reward, terminated, truncated, info = self.func_env.step(state.env_state, action)
        obs_queue_state = state.obs_queue_state
        action_queue_state = state.action_queue_state
        if self.obs_deque is not None:
            obs_queue_state = self.obs_deque.add(obs_queue_state, observation)
        if self.action_deque is not None:
            action_queue_state = self.action_deque.add(action_queue_state, action)
        new_state = FuncFrameStackWrapperState(
            env_state=env_state,
            obs_queue_state=obs_queue_state,
            action_queue_state=action_queue_state,
        )
        return new_state, self.map_observation(new_state, observation), reward, terminated, truncated, info
@@ -1,3 +1,4 @@
1
1
  from .batch_utils import *
2
+ from .construct_utils import *
2
3
  from .flatten_utils import *
3
4
  from .serialization_utils import *