unienv 0.0.1b3__py3-none-any.whl → 0.0.1b5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {unienv-0.0.1b3.dist-info → unienv-0.0.1b5.dist-info}/METADATA +1 -1
- {unienv-0.0.1b3.dist-info → unienv-0.0.1b5.dist-info}/RECORD +26 -23
- unienv_data/base/common.py +16 -6
- unienv_data/base/storage.py +13 -3
- unienv_data/batches/slicestack_batch.py +1 -0
- unienv_data/replay_buffer/replay_buffer.py +136 -65
- unienv_data/replay_buffer/trajectory_replay_buffer.py +230 -163
- unienv_data/storages/dict_storage.py +373 -0
- unienv_data/storages/{common.py → flattened.py} +27 -6
- unienv_data/storages/hdf5.py +48 -3
- unienv_data/storages/pytorch.py +26 -5
- unienv_data/storages/transformation.py +16 -3
- unienv_data/transformations/image_compress.py +22 -9
- unienv_interface/func_wrapper/frame_stack.py +1 -1
- unienv_interface/space/space_utils/flatten_utils.py +8 -2
- unienv_interface/space/spaces/tuple.py +4 -4
- unienv_interface/transformations/image_resize.py +106 -0
- unienv_interface/transformations/iter_transform.py +92 -0
- unienv_interface/utils/symbol_util.py +7 -1
- unienv_interface/world/funcnode.py +1 -1
- unienv_interface/world/node.py +2 -2
- unienv_interface/wrapper/frame_stack.py +1 -1
- {unienv-0.0.1b3.dist-info → unienv-0.0.1b5.dist-info}/WHEEL +0 -0
- {unienv-0.0.1b3.dist-info → unienv-0.0.1b5.dist-info}/licenses/LICENSE +0 -0
- {unienv-0.0.1b3.dist-info → unienv-0.0.1b5.dist-info}/top_level.txt +0 -0
- /unienv_interface/utils/{data_queue.py → framestack_queue.py} +0 -0
|
@@ -3,9 +3,7 @@ from typing import Generic, TypeVar, Generic, Optional, Any, Dict, Tuple, Sequen
|
|
|
3
3
|
|
|
4
4
|
from unienv_interface.space import Space, BoxSpace
|
|
5
5
|
from unienv_interface.space.space_utils import batch_utils as sbu, flatten_utils as sfu
|
|
6
|
-
from unienv_interface.env_base.env import ContextType, ObsType, ActType
|
|
7
6
|
from unienv_interface.backends import ComputeBackend, BArrayType, BDeviceType, BDtypeType, BRNGType
|
|
8
|
-
from unienv_interface.backends.numpy import NumpyComputeBackend
|
|
9
7
|
from unienv_interface.utils.symbol_util import *
|
|
10
8
|
from unienv_interface.transformations import DataTransformation
|
|
11
9
|
|
|
@@ -33,6 +31,8 @@ class TransformedStorage(SpaceStorage[
|
|
|
33
31
|
data_transformation : DataTransformation,
|
|
34
32
|
capacity : Optional[int] = None,
|
|
35
33
|
cache_path : Optional[str] = None,
|
|
34
|
+
multiprocessing : bool = False,
|
|
35
|
+
inner_storage_kwargs : Dict[str, Any] = {},
|
|
36
36
|
**kwargs
|
|
37
37
|
) -> "TransformedStorage[BArrayType, BDeviceType, BDtypeType, BRNGType]":
|
|
38
38
|
assert data_transformation.has_inverse, "To transform storages (potentially to save space), you need to use inversible data transformations"
|
|
@@ -42,12 +42,15 @@ class TransformedStorage(SpaceStorage[
|
|
|
42
42
|
if cache_path is not None:
|
|
43
43
|
os.makedirs(cache_path, exist_ok=True)
|
|
44
44
|
|
|
45
|
+
_inner_storage_kwargs = kwargs.copy()
|
|
46
|
+
_inner_storage_kwargs.update(inner_storage_kwargs)
|
|
45
47
|
inner_storage = inner_storage_cls.create(
|
|
46
48
|
transformed_space,
|
|
47
49
|
*args,
|
|
48
50
|
cache_path=None if cache_path is None else os.path.join(cache_path, inner_storage_path),
|
|
49
51
|
capacity=capacity,
|
|
50
|
-
|
|
52
|
+
multiprocessing=multiprocessing,
|
|
53
|
+
**_inner_storage_kwargs
|
|
51
54
|
)
|
|
52
55
|
return TransformedStorage(
|
|
53
56
|
single_instance_space,
|
|
@@ -64,6 +67,7 @@ class TransformedStorage(SpaceStorage[
|
|
|
64
67
|
*,
|
|
65
68
|
capacity : Optional[int] = None,
|
|
66
69
|
read_only : bool = True,
|
|
70
|
+
multiprocessing : bool = False,
|
|
67
71
|
**kwargs
|
|
68
72
|
) -> "TransformedStorage[BArrayType, BDeviceType, BDtypeType, BRNGType]":
|
|
69
73
|
metadata_path = os.path.join(path, "transformed_metadata.json")
|
|
@@ -87,6 +91,7 @@ class TransformedStorage(SpaceStorage[
|
|
|
87
91
|
transformed_space,
|
|
88
92
|
capacity=capacity,
|
|
89
93
|
read_only=read_only,
|
|
94
|
+
multiprocessing=multiprocessing,
|
|
90
95
|
**kwargs
|
|
91
96
|
)
|
|
92
97
|
return TransformedStorage(
|
|
@@ -142,6 +147,14 @@ class TransformedStorage(SpaceStorage[
|
|
|
142
147
|
def __len__(self):
|
|
143
148
|
return len(self.inner_storage)
|
|
144
149
|
|
|
150
|
+
@property
|
|
151
|
+
def is_mutable(self) -> bool:
|
|
152
|
+
return self.inner_storage.is_mutable
|
|
153
|
+
|
|
154
|
+
@property
|
|
155
|
+
def is_multiprocessing_safe(self) -> bool:
|
|
156
|
+
return self.inner_storage.is_multiprocessing_safe
|
|
157
|
+
|
|
145
158
|
def get_flattened(self, index):
|
|
146
159
|
dat = self.get(index)
|
|
147
160
|
if isinstance(index, int):
|
|
@@ -7,13 +7,17 @@ from PIL import Image
|
|
|
7
7
|
import numpy as np
|
|
8
8
|
import io
|
|
9
9
|
|
|
10
|
+
CONSERVATIVE_COMPRESSION_RATIOS = {
|
|
11
|
+
"JPEG": 10, # https://stackoverflow.com/questions/3471663/jpeg-compression-ratio
|
|
12
|
+
}
|
|
13
|
+
|
|
10
14
|
class ImageCompressTransformation(DataTransformation):
|
|
11
15
|
has_inverse = True
|
|
12
16
|
|
|
13
17
|
def __init__(
|
|
14
18
|
self,
|
|
15
|
-
init_quality : int =
|
|
16
|
-
max_size_bytes : int =
|
|
19
|
+
init_quality : int = 70,
|
|
20
|
+
max_size_bytes : Optional[int] = None,
|
|
17
21
|
mode : Optional[str] = None,
|
|
18
22
|
format : str = "JPEG",
|
|
19
23
|
) -> None:
|
|
@@ -25,9 +29,11 @@ class ImageCompressTransformation(DataTransformation):
|
|
|
25
29
|
mode: Optional mode for PIL Image (e.g., "RGB", "L"). If None, inferred from input.
|
|
26
30
|
format: Image format to use for compression (default "JPEG"). See https://pillow.readthedocs.io/en/stable/handbook/image-file-formats.html for options.
|
|
27
31
|
"""
|
|
32
|
+
assert max_size_bytes is not None or format in CONSERVATIVE_COMPRESSION_RATIOS, "Either max_size_bytes must be specified or format must have a conservative compression ratio defined."
|
|
28
33
|
|
|
29
34
|
self.init_quality = init_quality
|
|
30
35
|
self.max_size_bytes = max_size_bytes
|
|
36
|
+
self.compression_ratio = CONSERVATIVE_COMPRESSION_RATIOS.get(format, None) if max_size_bytes is None else None
|
|
31
37
|
self.mode = mode
|
|
32
38
|
self.format = format
|
|
33
39
|
|
|
@@ -45,9 +51,15 @@ class ImageCompressTransformation(DataTransformation):
|
|
|
45
51
|
) -> BDtypeType:
|
|
46
52
|
return backend.__array_namespace_info__().dtypes()['uint8']
|
|
47
53
|
|
|
54
|
+
def _get_max_compressed_size(self, source_space : BoxSpace):
|
|
55
|
+
H, W, C = source_space.shape[-3], source_space.shape[-2], source_space.shape[-1]
|
|
56
|
+
return self.max_size_bytes if self.max_size_bytes is not None else (H * W * C // self.compression_ratio) + 1
|
|
57
|
+
|
|
48
58
|
def get_target_space_from_source(self, source_space):
|
|
49
59
|
self.validate_source_space(source_space)
|
|
50
|
-
|
|
60
|
+
|
|
61
|
+
max_compressed_size = self._get_max_compressed_size(source_space)
|
|
62
|
+
new_shape = source_space.shape[:-3] + (max_compressed_size,)
|
|
51
63
|
|
|
52
64
|
return BoxSpace(
|
|
53
65
|
source_space.backend,
|
|
@@ -78,14 +90,14 @@ class ImageCompressTransformation(DataTransformation):
|
|
|
78
90
|
# Create PIL Image (mode inferred automatically)
|
|
79
91
|
img = Image.fromarray(img_array, mode=mode)
|
|
80
92
|
|
|
81
|
-
quality =
|
|
93
|
+
quality = self.init_quality
|
|
82
94
|
while quality >= min_quality:
|
|
83
95
|
buf = io.BytesIO()
|
|
84
96
|
img.save(buf, format=self.format, quality=quality)
|
|
85
97
|
image_bytes = buf.getvalue()
|
|
86
98
|
if len(image_bytes) <= max_bytes:
|
|
87
99
|
return image_bytes, quality
|
|
88
|
-
quality -=
|
|
100
|
+
quality -= 10
|
|
89
101
|
|
|
90
102
|
img.close()
|
|
91
103
|
# Return lowest quality attempt if still too large
|
|
@@ -93,19 +105,21 @@ class ImageCompressTransformation(DataTransformation):
|
|
|
93
105
|
|
|
94
106
|
def transform(self, source_space, data):
|
|
95
107
|
self.validate_source_space(source_space)
|
|
108
|
+
|
|
109
|
+
max_compressed_size = self._get_max_compressed_size(source_space)
|
|
96
110
|
data_numpy = source_space.backend.to_numpy(data)
|
|
97
111
|
flat_data_numpy = data_numpy.reshape(-1, *data_numpy.shape[-3:])
|
|
98
|
-
flat_compressed_data = np.zeros((flat_data_numpy.shape[0],
|
|
112
|
+
flat_compressed_data = np.zeros((flat_data_numpy.shape[0], max_compressed_size), dtype=np.uint8)
|
|
99
113
|
for i in range(flat_data_numpy.shape[0]):
|
|
100
114
|
img_array = flat_data_numpy[i]
|
|
101
115
|
image_bytes, _ = self.encode_to_size(
|
|
102
116
|
img_array,
|
|
103
|
-
|
|
117
|
+
max_compressed_size,
|
|
104
118
|
mode=self.mode
|
|
105
119
|
)
|
|
106
120
|
byte_array = np.frombuffer(image_bytes, dtype=np.uint8)
|
|
107
121
|
flat_compressed_data[i, :len(byte_array)] = byte_array
|
|
108
|
-
compressed_data = flat_compressed_data.reshape(data_numpy.shape[:-3] + (
|
|
122
|
+
compressed_data = flat_compressed_data.reshape(data_numpy.shape[:-3] + (max_compressed_size, ))
|
|
109
123
|
compressed_data_backend = source_space.backend.from_numpy(compressed_data, dtype=self.get_uint8_dtype(source_space.backend), device=source_space.device)
|
|
110
124
|
return compressed_data_backend
|
|
111
125
|
|
|
@@ -206,7 +220,6 @@ class ImageDecompressTransformation(DataTransformation):
|
|
|
206
220
|
assert source_space is not None, "Source space must be provided to get inverse transformation."
|
|
207
221
|
self.validate_source_space(source_space)
|
|
208
222
|
return ImageCompressTransformation(
|
|
209
|
-
init_quality=75,
|
|
210
223
|
max_size_bytes=source_space.shape[-1],
|
|
211
224
|
mode=self.mode,
|
|
212
225
|
format=self.format if self.format is not None else "JPEG",
|
|
@@ -8,7 +8,7 @@ from unienv_interface.utils import seed_util
|
|
|
8
8
|
from unienv_interface.env_base.funcenv import FuncEnv, ContextType, ObsType, ActType, RenderFrame, StateType, RenderStateType
|
|
9
9
|
from unienv_interface.env_base.funcenv_wrapper import *
|
|
10
10
|
from unienv_interface.space import Space
|
|
11
|
-
from unienv_interface.utils.
|
|
11
|
+
from unienv_interface.utils.framestack_queue import FuncSpaceDataQueue, SpaceDataQueueState
|
|
12
12
|
from unienv_interface.utils.stateclass import StateClass, field
|
|
13
13
|
|
|
14
14
|
class FuncFrameStackWrapperState(
|
|
@@ -192,14 +192,20 @@ def unflatten_data(space : Space, data : BArrayType, start_dim : int = 0) -> Any
|
|
|
192
192
|
@flatten_data.register(BinarySpace)
|
|
193
193
|
def _flatten_data_common(space: typing.Union[BoxSpace, BinarySpace], data: BArrayType, start_dim : int = 0) -> BArrayType:
|
|
194
194
|
assert -len(space.shape) <= start_dim <= len(space.shape)
|
|
195
|
-
|
|
195
|
+
dat = space.backend.reshape(data, data.shape[:start_dim] + (-1,))
|
|
196
|
+
if isinstance(space, BinarySpace):
|
|
197
|
+
dat = space.backend.astype(dat, space.backend.default_integer_dtype)
|
|
198
|
+
return dat
|
|
196
199
|
|
|
197
200
|
@unflatten_data.register(BoxSpace)
|
|
198
201
|
@unflatten_data.register(BinarySpace)
|
|
199
202
|
def _unflatten_data_common(space: typing.Union[BoxSpace, BinarySpace], data: Any, start_dim : int = 0) -> BArrayType:
|
|
200
203
|
assert -len(space.shape) <= start_dim <= len(space.shape)
|
|
201
204
|
unflat_dat = space.backend.reshape(data, data.shape[:start_dim] + space.shape[start_dim:])
|
|
202
|
-
|
|
205
|
+
if isinstance(space, BinarySpace):
|
|
206
|
+
unflat_dat = space.backend.astype(unflat_dat, space.dtype if space.dtype is not None else space.backend.default_boolean_dtype)
|
|
207
|
+
else:
|
|
208
|
+
unflat_dat = space.backend.astype(unflat_dat, space.dtype)
|
|
203
209
|
return unflat_dat
|
|
204
210
|
|
|
205
211
|
@flatten_data.register(DynamicBoxSpace)
|
|
@@ -40,7 +40,7 @@ class TupleSpace(Space[Tuple[Any, ...], BDeviceType, BDtypeType, BRNGType]):
|
|
|
40
40
|
return self
|
|
41
41
|
|
|
42
42
|
new_device = device if backend is not None else (device or self.device)
|
|
43
|
-
return
|
|
43
|
+
return TupleSpace(
|
|
44
44
|
backend=backend or self.backend,
|
|
45
45
|
spaces=[space.to(backend, new_device) for space in self.spaces],
|
|
46
46
|
device=new_device
|
|
@@ -93,11 +93,11 @@ class TupleSpace(Space[Tuple[Any, ...], BDeviceType, BDtypeType, BRNGType]):
|
|
|
93
93
|
|
|
94
94
|
def __eq__(self, other: Any) -> bool:
|
|
95
95
|
"""Check whether ``other`` is equivalent to this instance."""
|
|
96
|
-
return isinstance(other,
|
|
96
|
+
return isinstance(other, TupleSpace) and self.spaces == other.spaces
|
|
97
97
|
|
|
98
|
-
def __copy__(self) -> "
|
|
98
|
+
def __copy__(self) -> "TupleSpace[BDeviceType, BDtypeType, BRNGType]":
|
|
99
99
|
"""Create a shallow copy of the Dict space."""
|
|
100
|
-
return
|
|
100
|
+
return TupleSpace(
|
|
101
101
|
backend=self.backend,
|
|
102
102
|
spaces=copy.copy(self.spaces),
|
|
103
103
|
device=self.device
|
|
@@ -0,0 +1,106 @@
|
|
|
1
|
+
from unienv_interface.space.space_utils import batch_utils as sbu
|
|
2
|
+
from .transformation import DataTransformation, TargetDataT
|
|
3
|
+
from unienv_interface.space import Space, BoxSpace
|
|
4
|
+
from typing import Union, Any, Optional
|
|
5
|
+
from unienv_interface.backends import ComputeBackend, BArrayType, BDeviceType, BDtypeType, BRNGType
|
|
6
|
+
|
|
7
|
+
class ImageResizeTransformation(DataTransformation):
|
|
8
|
+
has_inverse = True
|
|
9
|
+
|
|
10
|
+
def __init__(
|
|
11
|
+
self,
|
|
12
|
+
new_height: int,
|
|
13
|
+
new_width: int
|
|
14
|
+
):
|
|
15
|
+
self.new_height = new_height
|
|
16
|
+
self.new_width = new_width
|
|
17
|
+
|
|
18
|
+
def _validate_source_space(self, source_space : Space[Any, BDeviceType, BDtypeType, BRNGType]) -> BoxSpace[BArrayType, BDeviceType, BDtypeType, BRNGType]:
|
|
19
|
+
assert isinstance(source_space, BoxSpace), \
|
|
20
|
+
f"ImageResizeTransformation only supports BoxSpace, got {type(source_space)}"
|
|
21
|
+
assert len(source_space.shape) >= 3, \
|
|
22
|
+
f"ImageResizeTransformation only supports spaces with at least 3 dimensions (H, W, C), got shape {source_space.shape}"
|
|
23
|
+
assert source_space.shape[-3] > 0 and source_space.shape[-2] > 0, \
|
|
24
|
+
f"ImageResizeTransformation requires positive height and width, got shape {source_space.shape}"
|
|
25
|
+
return source_space
|
|
26
|
+
|
|
27
|
+
def get_target_space_from_source(self, source_space):
|
|
28
|
+
source_space = self._validate_source_space(source_space)
|
|
29
|
+
|
|
30
|
+
backend = source_space.backend
|
|
31
|
+
new_shape = (
|
|
32
|
+
*source_space.shape[:-3],
|
|
33
|
+
self.new_height,
|
|
34
|
+
self.new_width,
|
|
35
|
+
source_space.shape[-1]
|
|
36
|
+
)
|
|
37
|
+
new_low = backend.min(source_space.low, axis=(-3, -2), keepdims=True)
|
|
38
|
+
new_high = backend.max(source_space.high, axis=(-3, -2), keepdims=True)
|
|
39
|
+
|
|
40
|
+
return BoxSpace(
|
|
41
|
+
source_space.backend,
|
|
42
|
+
new_low,
|
|
43
|
+
new_high,
|
|
44
|
+
dtype=source_space.dtype,
|
|
45
|
+
device=source_space.device,
|
|
46
|
+
shape=new_shape
|
|
47
|
+
)
|
|
48
|
+
|
|
49
|
+
def transform(self, source_space, data):
|
|
50
|
+
source_space = self._validate_source_space(source_space)
|
|
51
|
+
backend = source_space.backend
|
|
52
|
+
if backend.simplified_name == "jax":
|
|
53
|
+
target_shape = (
|
|
54
|
+
*data.shape[:-3],
|
|
55
|
+
self.new_height,
|
|
56
|
+
self.new_width,
|
|
57
|
+
source_space.shape[-1]
|
|
58
|
+
)
|
|
59
|
+
import jax.image
|
|
60
|
+
resized_data = jax.image.resize(
|
|
61
|
+
data,
|
|
62
|
+
shape=target_shape,
|
|
63
|
+
method='bilinear',
|
|
64
|
+
antialias=True
|
|
65
|
+
)
|
|
66
|
+
elif backend.simplified_name == "pytorch":
|
|
67
|
+
import torch.nn.functional as F
|
|
68
|
+
# PyTorch expects (B, C, H, W)
|
|
69
|
+
data_permuted = backend.permute_dims(data, (*range(len(data.shape[:-3])), -1, -3, -2))
|
|
70
|
+
resized_data_permuted = F.interpolate(
|
|
71
|
+
data_permuted,
|
|
72
|
+
size=(self.new_height, self.new_width),
|
|
73
|
+
mode='bilinear',
|
|
74
|
+
align_corners=False,
|
|
75
|
+
antialias=True
|
|
76
|
+
)
|
|
77
|
+
# Permute back to original shape
|
|
78
|
+
resized_data = backend.permute_dims(resized_data_permuted, (*range(len(resized_data_permuted.shape[:-3])), -2, -1, -3))
|
|
79
|
+
elif backend.simplified_name == "numpy":
|
|
80
|
+
import cv2
|
|
81
|
+
flat_data = backend.reshape(data, (-1, *source_space.shape[-3:]))
|
|
82
|
+
resized_flat_data = []
|
|
83
|
+
for i in range(flat_data.shape[0]):
|
|
84
|
+
img = flat_data[i]
|
|
85
|
+
resized_img = cv2.resize(
|
|
86
|
+
img,
|
|
87
|
+
(self.new_width, self.new_height),
|
|
88
|
+
interpolation=cv2.INTER_LINEAR
|
|
89
|
+
)
|
|
90
|
+
resized_flat_data.append(resized_img)
|
|
91
|
+
resized_flat_data = backend.stack(resized_flat_data, axis=0)
|
|
92
|
+
resized_data = backend.reshape(
|
|
93
|
+
resized_flat_data,
|
|
94
|
+
(*data.shape[:-3], self.new_height, self.new_width, source_space.shape[-1])
|
|
95
|
+
)
|
|
96
|
+
else:
|
|
97
|
+
raise ValueError(f"Unsupported backend: {backend.simplified_name}")
|
|
98
|
+
return resized_data
|
|
99
|
+
|
|
100
|
+
def direction_inverse(self, source_space = None):
|
|
101
|
+
assert source_space is not None, "Inverse transformation requires source_space"
|
|
102
|
+
source_space = self._validate_source_space(source_space)
|
|
103
|
+
return ImageResizeTransformation(
|
|
104
|
+
new_height=source_space.shape[-3],
|
|
105
|
+
new_width=source_space.shape[-2]
|
|
106
|
+
)
|
|
@@ -0,0 +1,92 @@
|
|
|
1
|
+
from typing import Union, Any, Optional, Mapping, List, Callable, Dict
|
|
2
|
+
|
|
3
|
+
from unienv_interface.space.space_utils import batch_utils as sbu
|
|
4
|
+
from unienv_interface.space import Space, DictSpace, TupleSpace
|
|
5
|
+
from unienv_interface.backends import ComputeBackend, BArrayType, BDeviceType, BDtypeType, BRNGType
|
|
6
|
+
|
|
7
|
+
import copy
|
|
8
|
+
from .transformation import DataTransformation, TargetDataT
|
|
9
|
+
|
|
10
|
+
def default_is_leaf_fn(space : Space[Any, BDeviceType, BDtypeType, BRNGType]):
|
|
11
|
+
return not isinstance(space, (DictSpace, TupleSpace))
|
|
12
|
+
|
|
13
|
+
class IterativeTransformation(DataTransformation):
|
|
14
|
+
def __init__(
|
|
15
|
+
self,
|
|
16
|
+
transformation: DataTransformation,
|
|
17
|
+
is_leaf_node_fn: Callable[[Space[Any, BDeviceType, BDtypeType, BRNGType]], bool] = default_is_leaf_fn,
|
|
18
|
+
inv_is_leaf_node_fn: Callable[[Space[Any, BDeviceType, BDtypeType, BRNGType]], bool] = default_is_leaf_fn
|
|
19
|
+
):
|
|
20
|
+
self.transformation = transformation
|
|
21
|
+
self.is_leaf_node_fn = is_leaf_node_fn
|
|
22
|
+
self.inv_is_leaf_node_fn = inv_is_leaf_node_fn
|
|
23
|
+
self.has_inverse = transformation.has_inverse
|
|
24
|
+
|
|
25
|
+
def get_target_space_from_source(
|
|
26
|
+
self,
|
|
27
|
+
source_space : Space[Any, BDeviceType, BDtypeType, BRNGType]
|
|
28
|
+
):
|
|
29
|
+
if self.is_leaf_node_fn(source_space):
|
|
30
|
+
return self.transformation.get_target_space_from_source(source_space)
|
|
31
|
+
elif isinstance(source_space, DictSpace):
|
|
32
|
+
rsts = {
|
|
33
|
+
key: self.get_target_space_from_source(subspace)
|
|
34
|
+
for key, subspace in source_space.spaces.items()
|
|
35
|
+
}
|
|
36
|
+
backend = source_space.backend if len(rsts) == 0 else next(iter(rsts.values())).backend
|
|
37
|
+
device = source_space.device if len(rsts) == 0 else next(iter(rsts.values())).device
|
|
38
|
+
return DictSpace(
|
|
39
|
+
backend,
|
|
40
|
+
rsts,
|
|
41
|
+
device=device
|
|
42
|
+
)
|
|
43
|
+
elif isinstance(source_space, TupleSpace):
|
|
44
|
+
rsts = tuple(
|
|
45
|
+
self.get_target_space_from_source(subspace)
|
|
46
|
+
for subspace in source_space.spaces
|
|
47
|
+
)
|
|
48
|
+
backend = source_space.backend if len(rsts) == 0 else next(iter(rsts)).backend
|
|
49
|
+
device = source_space.device if len(rsts) == 0 else next(iter(rsts)).device
|
|
50
|
+
return TupleSpace(
|
|
51
|
+
backend,
|
|
52
|
+
rsts,
|
|
53
|
+
device=device
|
|
54
|
+
)
|
|
55
|
+
else:
|
|
56
|
+
raise ValueError(f"Unsupported space type: {type(source_space)}")
|
|
57
|
+
|
|
58
|
+
def transform(
|
|
59
|
+
self,
|
|
60
|
+
source_space: Space,
|
|
61
|
+
data: Union[Mapping[str, Any], BArrayType]
|
|
62
|
+
) -> Union[Mapping[str, Any], BArrayType]:
|
|
63
|
+
if self.is_leaf_node_fn(source_space):
|
|
64
|
+
return self.transformation.transform(source_space, data)
|
|
65
|
+
elif isinstance(source_space, DictSpace):
|
|
66
|
+
return {
|
|
67
|
+
key: self.transform(subspace, data[key])
|
|
68
|
+
for key, subspace in source_space.spaces.items()
|
|
69
|
+
}
|
|
70
|
+
elif isinstance(source_space, TupleSpace):
|
|
71
|
+
return tuple(
|
|
72
|
+
self.transform(subspace, data[i])
|
|
73
|
+
for i, subspace in enumerate(source_space.spaces)
|
|
74
|
+
)
|
|
75
|
+
else:
|
|
76
|
+
raise ValueError(f"Unsupported space type: {type(source_space)}")
|
|
77
|
+
|
|
78
|
+
def direction_inverse(
|
|
79
|
+
self,
|
|
80
|
+
source_space = None,
|
|
81
|
+
) -> Optional["IterativeTransformation"]:
|
|
82
|
+
if not self.has_inverse:
|
|
83
|
+
return None
|
|
84
|
+
|
|
85
|
+
return IterativeTransformation(
|
|
86
|
+
self.transformation.direction_inverse(),
|
|
87
|
+
is_leaf_node_fn=self.inv_is_leaf_node_fn,
|
|
88
|
+
inv_is_leaf_node_fn=self.is_leaf_node_fn
|
|
89
|
+
)
|
|
90
|
+
|
|
91
|
+
def close(self):
|
|
92
|
+
self.transformation.close()
|
|
@@ -5,9 +5,15 @@ __all__ = [
|
|
|
5
5
|
"get_class_from_full_name",
|
|
6
6
|
]
|
|
7
7
|
|
|
8
|
+
REMAP = {
|
|
9
|
+
"unienv_data.storages.common.FlattenedStorage": "unienv_data.storages.flattened.FlattenedStorage",
|
|
10
|
+
}
|
|
11
|
+
|
|
8
12
|
def get_full_class_name(cls : Type) -> str:
|
|
9
13
|
return f"{cls.__module__}.{cls.__qualname__}"
|
|
10
14
|
|
|
11
15
|
def get_class_from_full_name(full_name : str) -> Type:
|
|
16
|
+
if full_name in REMAP:
|
|
17
|
+
full_name = REMAP[full_name]
|
|
12
18
|
module_name, class_name = full_name.rsplit(".", 1)
|
|
13
|
-
return getattr(__import__(module_name, fromlist=[class_name]), class_name)
|
|
19
|
+
return getattr(__import__(module_name, fromlist=[class_name]), class_name)
|
|
@@ -21,7 +21,6 @@ class FuncWorldNode(ABC, Generic[
|
|
|
21
21
|
"""
|
|
22
22
|
|
|
23
23
|
name : str
|
|
24
|
-
world : FuncWorld[WorldStateT, BArrayType, BDeviceType, BDtypeType, BRNGType]
|
|
25
24
|
control_timestep : Optional[float] = None
|
|
26
25
|
context_space : Optional[Space[ContextType, BDeviceType, BDtypeType, BRNGType]] = None
|
|
27
26
|
observation_space : Optional[Space[ObsType, BDeviceType, BDtypeType, BRNGType]] = None
|
|
@@ -29,6 +28,7 @@ class FuncWorldNode(ABC, Generic[
|
|
|
29
28
|
has_reward : bool = False
|
|
30
29
|
has_termination_signal : bool = False
|
|
31
30
|
has_truncation_signal : bool = False
|
|
31
|
+
world : Optional[FuncWorld[WorldStateT, BArrayType, BDeviceType, BDtypeType, BRNGType]] = None
|
|
32
32
|
|
|
33
33
|
@property
|
|
34
34
|
def backend(self) -> ComputeBackend[BArrayType, BDeviceType, BDtypeType, BRNGType]:
|
unienv_interface/world/node.py
CHANGED
|
@@ -8,7 +8,7 @@ from .world import World
|
|
|
8
8
|
|
|
9
9
|
class WorldNode(ABC, Generic[ContextType, ObsType, ActType, BArrayType, BDeviceType, BDtypeType, BRNGType]):
|
|
10
10
|
"""
|
|
11
|
-
Each `WorldNode` in the simulated / real world will manage some aspect of the environment.
|
|
11
|
+
Each `WorldNode` in the simulated / real world will manage some aspect of the environment. This can include sensors, robots, or other entities that interact with the world.
|
|
12
12
|
How the methods in this class will be called once environment resets:
|
|
13
13
|
`World.reset(...)` -> `WorldNode.reset(...)` -> `WorldNode.after_reset(...)` -> `WorldNode.get_observation(...)` -> World can start stepping normally
|
|
14
14
|
How the methods in this class will be called during a environment step:
|
|
@@ -16,7 +16,6 @@ class WorldNode(ABC, Generic[ContextType, ObsType, ActType, BArrayType, BDeviceT
|
|
|
16
16
|
"""
|
|
17
17
|
|
|
18
18
|
name : str
|
|
19
|
-
world : World[BArrayType, BDeviceType, BDtypeType, BRNGType]
|
|
20
19
|
control_timestep : Optional[float] = None
|
|
21
20
|
context_space : Optional[Space[ContextType, BDeviceType, BDtypeType, BRNGType]] = None
|
|
22
21
|
observation_space : Optional[Space[ObsType, BDeviceType, BDtypeType, BRNGType]] = None
|
|
@@ -24,6 +23,7 @@ class WorldNode(ABC, Generic[ContextType, ObsType, ActType, BArrayType, BDeviceT
|
|
|
24
23
|
has_reward : bool = False
|
|
25
24
|
has_termination_signal : bool = False
|
|
26
25
|
has_truncation_signal : bool = False
|
|
26
|
+
world : Optional[World[BArrayType, BDeviceType, BDtypeType, BRNGType]] = None
|
|
27
27
|
|
|
28
28
|
@property
|
|
29
29
|
def backend(self) -> ComputeBackend[BArrayType, BDeviceType, BDtypeType, BRNGType]:
|
|
@@ -7,7 +7,7 @@ from unienv_interface.space.space_utils import batch_utils as sbu
|
|
|
7
7
|
from unienv_interface.env_base.env import Env, ContextType, ObsType, ActType, RenderFrame, BArrayType, BDeviceType, BDtypeType, BRNGType
|
|
8
8
|
from unienv_interface.env_base.wrapper import ContextObservationWrapper, ActionWrapper, WrapperContextT, WrapperObsT, WrapperActT
|
|
9
9
|
from unienv_interface.space import Space, DictSpace
|
|
10
|
-
from unienv_interface.utils.
|
|
10
|
+
from unienv_interface.utils.framestack_queue import SpaceDataQueue
|
|
11
11
|
|
|
12
12
|
class FrameStackWrapper(
|
|
13
13
|
ContextObservationWrapper[
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|