unienv 0.0.1b4__py3-none-any.whl → 0.0.1b5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {unienv-0.0.1b4.dist-info → unienv-0.0.1b5.dist-info}/METADATA +1 -1
- {unienv-0.0.1b4.dist-info → unienv-0.0.1b5.dist-info}/RECORD +23 -21
- unienv_data/base/storage.py +2 -0
- unienv_data/batches/slicestack_batch.py +1 -0
- unienv_data/replay_buffer/replay_buffer.py +136 -65
- unienv_data/replay_buffer/trajectory_replay_buffer.py +230 -163
- unienv_data/storages/dict_storage.py +39 -7
- unienv_data/storages/flattened.py +8 -1
- unienv_data/storages/hdf5.py +6 -0
- unienv_data/storages/pytorch.py +1 -1
- unienv_data/storages/transformation.py +16 -1
- unienv_data/transformations/image_compress.py +22 -9
- unienv_interface/func_wrapper/frame_stack.py +1 -1
- unienv_interface/space/space_utils/flatten_utils.py +8 -2
- unienv_interface/space/spaces/tuple.py +4 -4
- unienv_interface/transformations/image_resize.py +106 -0
- unienv_interface/transformations/iter_transform.py +92 -0
- unienv_interface/utils/symbol_util.py +7 -1
- unienv_interface/wrapper/frame_stack.py +1 -1
- {unienv-0.0.1b4.dist-info → unienv-0.0.1b5.dist-info}/WHEEL +0 -0
- {unienv-0.0.1b4.dist-info → unienv-0.0.1b5.dist-info}/licenses/LICENSE +0 -0
- {unienv-0.0.1b4.dist-info → unienv-0.0.1b5.dist-info}/top_level.txt +0 -0
- /unienv_interface/utils/{data_queue.py → framestack_queue.py} +0 -0
|
@@ -12,6 +12,24 @@ import numpy as np
|
|
|
12
12
|
import os
|
|
13
13
|
import json
|
|
14
14
|
|
|
15
|
+
|
|
16
|
+
def _merge_nested_mappings(
|
|
17
|
+
primary: Mapping[str, Any],
|
|
18
|
+
secondary: Mapping[str, Any],
|
|
19
|
+
) -> Mapping[str, Any]:
|
|
20
|
+
"""Merge secondary into primary without clobbering explicitly matched keys."""
|
|
21
|
+
merged: Dict[str, Any] = dict(primary)
|
|
22
|
+
for merge_key, merge_value in secondary.items():
|
|
23
|
+
if (
|
|
24
|
+
merge_key in merged
|
|
25
|
+
and isinstance(merged[merge_key], Mapping)
|
|
26
|
+
and isinstance(merge_value, Mapping)
|
|
27
|
+
):
|
|
28
|
+
merged[merge_key] = _merge_nested_mappings(merged[merge_key], merge_value)
|
|
29
|
+
elif merge_key not in merged:
|
|
30
|
+
merged[merge_key] = merge_value
|
|
31
|
+
return merged
|
|
32
|
+
|
|
15
33
|
def map_transform(
|
|
16
34
|
data : Dict[str, Any],
|
|
17
35
|
value_map : Dict[str, Any],
|
|
@@ -44,7 +62,10 @@ def map_transform(
|
|
|
44
62
|
residual_transformed = fn(prefix + "*", residual_data, value_map[prefix + "*"])
|
|
45
63
|
if isinstance(residual_transformed, Mapping) or isinstance(residual_transformed, DictSpace):
|
|
46
64
|
for key, value in residual_transformed.items():
|
|
47
|
-
transformed_data[key]
|
|
65
|
+
if key in transformed_data and isinstance(transformed_data[key], Mapping) and isinstance(value, Mapping):
|
|
66
|
+
transformed_data[key] = _merge_nested_mappings(transformed_data[key], value)
|
|
67
|
+
elif key not in transformed_data:
|
|
68
|
+
transformed_data[key] = value
|
|
48
69
|
residual_data = {}
|
|
49
70
|
return transformed_data, residual_data
|
|
50
71
|
|
|
@@ -52,7 +73,7 @@ def get_chained_residual_space(
|
|
|
52
73
|
space : DictSpace[BDeviceType, BDtypeType, BRNGType],
|
|
53
74
|
all_keys : List[str],
|
|
54
75
|
prefix : str = "",
|
|
55
|
-
) -> DictSpace[BDeviceType, BDtypeType, BRNGType]:
|
|
76
|
+
) -> Optional[DictSpace[BDeviceType, BDtypeType, BRNGType]]:
|
|
56
77
|
residual_spaces = {}
|
|
57
78
|
|
|
58
79
|
if len(residual_spaces) > 0 and (prefix + "*") in all_keys:
|
|
@@ -72,10 +93,13 @@ def get_chained_residual_space(
|
|
|
72
93
|
all_keys,
|
|
73
94
|
prefix=full_key + "/",
|
|
74
95
|
)
|
|
75
|
-
if len(sub_residual.spaces) > 0:
|
|
96
|
+
if sub_residual is not None and len(sub_residual.spaces) > 0:
|
|
76
97
|
residual_spaces[key] = sub_residual
|
|
77
98
|
else:
|
|
78
99
|
residual_spaces[key] = subspace
|
|
100
|
+
|
|
101
|
+
if len(residual_spaces) == 0:
|
|
102
|
+
return None
|
|
79
103
|
|
|
80
104
|
return DictSpace(
|
|
81
105
|
space.backend,
|
|
@@ -87,7 +111,7 @@ def get_chained_space(
|
|
|
87
111
|
space : DictSpace[BDeviceType, BDtypeType, BRNGType],
|
|
88
112
|
key_chain : str,
|
|
89
113
|
all_keys : List[str],
|
|
90
|
-
) -> Space[Any, BDeviceType, BDtypeType, BRNGType]:
|
|
114
|
+
) -> Optional[Space[Any, BDeviceType, BDtypeType, BRNGType]]:
|
|
91
115
|
if key_chain.endswith("*"):
|
|
92
116
|
prefix = key_chain[:-1]
|
|
93
117
|
subspace = get_chained_residual_space(
|
|
@@ -106,8 +130,8 @@ def get_chained_space(
|
|
|
106
130
|
for key in key_chain:
|
|
107
131
|
if len(key) == 0:
|
|
108
132
|
continue
|
|
109
|
-
|
|
110
|
-
|
|
133
|
+
if not isinstance(current_space, DictSpace) or key not in current_space.spaces:
|
|
134
|
+
return None
|
|
111
135
|
current_space = current_space.spaces[key]
|
|
112
136
|
return current_space
|
|
113
137
|
|
|
@@ -130,6 +154,7 @@ class DictStorage(SpaceStorage[
|
|
|
130
154
|
*args,
|
|
131
155
|
capacity : Optional[int] = None,
|
|
132
156
|
cache_path : Optional[str] = None,
|
|
157
|
+
multiprocessing : bool = False,
|
|
133
158
|
key_kwargs : Dict[str, Any] = {},
|
|
134
159
|
type_kwargs : Dict[Type[SpaceStorage[Any, BArrayType, BDeviceType, BDtypeType, BRNGType]], Dict[str, Any]] = {},
|
|
135
160
|
**kwargs
|
|
@@ -142,6 +167,8 @@ class DictStorage(SpaceStorage[
|
|
|
142
167
|
for key, sub_storage_cls in storage_cls_map.items():
|
|
143
168
|
sub_storage_path = key.replace("/", ".").replace("*", "_default") + (sub_storage_cls.single_file_ext or "")
|
|
144
169
|
subspace = get_chained_space(single_instance_space, key, all_keys)
|
|
170
|
+
if subspace is None:
|
|
171
|
+
continue
|
|
145
172
|
sub_kwargs = kwargs.copy()
|
|
146
173
|
if sub_storage_cls in type_kwargs:
|
|
147
174
|
sub_kwargs.update(type_kwargs[sub_storage_cls])
|
|
@@ -152,6 +179,7 @@ class DictStorage(SpaceStorage[
|
|
|
152
179
|
*args,
|
|
153
180
|
cache_path=None if cache_path is None else os.path.join(cache_path, sub_storage_path),
|
|
154
181
|
capacity=capacity,
|
|
182
|
+
multiprocessing=multiprocessing,
|
|
155
183
|
**sub_kwargs
|
|
156
184
|
)
|
|
157
185
|
|
|
@@ -169,6 +197,7 @@ class DictStorage(SpaceStorage[
|
|
|
169
197
|
*,
|
|
170
198
|
capacity : Optional[int] = None,
|
|
171
199
|
read_only : bool = True,
|
|
200
|
+
multiprocessing : bool = False,
|
|
172
201
|
key_kwargs : Dict[str, Any] = {},
|
|
173
202
|
type_kwargs : Dict[Type[SpaceStorage[Any, BArrayType, BDeviceType, BDtypeType, BRNGType]], Dict[str, Any]] = {},
|
|
174
203
|
**kwargs
|
|
@@ -189,7 +218,9 @@ class DictStorage(SpaceStorage[
|
|
|
189
218
|
storage_path = storage_meta["path"]
|
|
190
219
|
|
|
191
220
|
subspace = get_chained_space(single_instance_space, key, all_keys)
|
|
192
|
-
|
|
221
|
+
if subspace is None:
|
|
222
|
+
continue
|
|
223
|
+
|
|
193
224
|
sub_kwargs = kwargs.copy()
|
|
194
225
|
if storage_cls in type_kwargs:
|
|
195
226
|
sub_kwargs.update(type_kwargs[storage_cls])
|
|
@@ -200,6 +231,7 @@ class DictStorage(SpaceStorage[
|
|
|
200
231
|
subspace,
|
|
201
232
|
capacity=capacity,
|
|
202
233
|
read_only=read_only,
|
|
234
|
+
multiprocessing=multiprocessing,
|
|
203
235
|
**sub_kwargs
|
|
204
236
|
)
|
|
205
237
|
|
|
@@ -28,6 +28,8 @@ class FlattenedStorage(SpaceStorage[
|
|
|
28
28
|
*args,
|
|
29
29
|
capacity : Optional[int] = None,
|
|
30
30
|
cache_path : Optional[str] = None,
|
|
31
|
+
multiprocessing : bool = False,
|
|
32
|
+
inner_storage_kwargs : Dict[str, Any] = {},
|
|
31
33
|
**kwargs
|
|
32
34
|
) -> "FlattenedStorage[BatchT, BArrayType, BDeviceType, BDtypeType, BRNGType]":
|
|
33
35
|
flattened_space = sfu.flatten_space(single_instance_space)
|
|
@@ -36,12 +38,15 @@ class FlattenedStorage(SpaceStorage[
|
|
|
36
38
|
if cache_path is not None:
|
|
37
39
|
os.makedirs(cache_path, exist_ok=True)
|
|
38
40
|
|
|
41
|
+
_inner_storage_kwargs = kwargs.copy()
|
|
42
|
+
_inner_storage_kwargs.update(inner_storage_kwargs)
|
|
39
43
|
inner_storage = inner_storage_cls.create(
|
|
40
44
|
flattened_space,
|
|
41
45
|
*args,
|
|
42
46
|
cache_path=None if cache_path is None else os.path.join(cache_path, inner_storage_path),
|
|
43
47
|
capacity=capacity,
|
|
44
|
-
|
|
48
|
+
multiprocessing=multiprocessing,
|
|
49
|
+
**_inner_storage_kwargs
|
|
45
50
|
)
|
|
46
51
|
return FlattenedStorage(
|
|
47
52
|
single_instance_space,
|
|
@@ -58,6 +63,7 @@ class FlattenedStorage(SpaceStorage[
|
|
|
58
63
|
*,
|
|
59
64
|
capacity : Optional[int] = None,
|
|
60
65
|
read_only : bool = True,
|
|
66
|
+
multiprocessing : bool = False,
|
|
61
67
|
**kwargs
|
|
62
68
|
) -> "FlattenedStorage[BatchT, BArrayType, BDeviceType, BDtypeType, BRNGType]":
|
|
63
69
|
metadata_path = os.path.join(path, "flattened_metadata.json")
|
|
@@ -74,6 +80,7 @@ class FlattenedStorage(SpaceStorage[
|
|
|
74
80
|
flattened_space,
|
|
75
81
|
capacity=capacity,
|
|
76
82
|
read_only=read_only,
|
|
83
|
+
multiprocessing=multiprocessing,
|
|
77
84
|
**kwargs
|
|
78
85
|
)
|
|
79
86
|
return FlattenedStorage(
|
unienv_data/storages/hdf5.py
CHANGED
|
@@ -458,6 +458,7 @@ class HDF5Storage(SpaceStorage[
|
|
|
458
458
|
single_instance_space,
|
|
459
459
|
capacity,
|
|
460
460
|
cache_path = None,
|
|
461
|
+
multiprocessing : bool = False,
|
|
461
462
|
initial_capacity : Optional[int] = None,
|
|
462
463
|
compression : Union[
|
|
463
464
|
Dict[str, Any],
|
|
@@ -476,6 +477,8 @@ class HDF5Storage(SpaceStorage[
|
|
|
476
477
|
) -> "HDF5Storage":
|
|
477
478
|
assert cache_path is not None, \
|
|
478
479
|
"cache_path must be provided for HDF5Storage"
|
|
480
|
+
assert not multiprocessing, \
|
|
481
|
+
"HDF5Storage does not support multiprocessing safe creation. Please create the storage in the main process and then load it in child processes."
|
|
479
482
|
root = h5py.File(
|
|
480
483
|
cache_path,
|
|
481
484
|
"w",
|
|
@@ -506,9 +509,12 @@ class HDF5Storage(SpaceStorage[
|
|
|
506
509
|
*,
|
|
507
510
|
capacity = None,
|
|
508
511
|
read_only = True,
|
|
512
|
+
multiprocessing : bool = False,
|
|
509
513
|
reduce_io : bool = True,
|
|
510
514
|
**kwargs
|
|
511
515
|
) -> "HDF5Storage":
|
|
516
|
+
assert not multiprocessing, \
|
|
517
|
+
"HDF5Storage does not support multiprocessing safe loading. Please load the storage in the main process and then share it with child processes."
|
|
512
518
|
assert os.path.exists(path), \
|
|
513
519
|
f"Path {path} does not exist"
|
|
514
520
|
|
unienv_data/storages/pytorch.py
CHANGED
|
@@ -22,8 +22,8 @@ class PytorchTensorStorage(SpaceStorage[
|
|
|
22
22
|
capacity : Optional[int],
|
|
23
23
|
is_memmap : bool = False,
|
|
24
24
|
cache_path : Optional[str] = None,
|
|
25
|
-
memmap_existok : bool = True,
|
|
26
25
|
multiprocessing : bool = False,
|
|
26
|
+
memmap_existok : bool = True,
|
|
27
27
|
) -> "PytorchTensorStorage":
|
|
28
28
|
assert single_instance_space.backend is PyTorchComputeBackend, \
|
|
29
29
|
f"Single instance space must be of type PyTorchComputeBackend, got {single_instance_space.backend}"
|
|
@@ -31,6 +31,8 @@ class TransformedStorage(SpaceStorage[
|
|
|
31
31
|
data_transformation : DataTransformation,
|
|
32
32
|
capacity : Optional[int] = None,
|
|
33
33
|
cache_path : Optional[str] = None,
|
|
34
|
+
multiprocessing : bool = False,
|
|
35
|
+
inner_storage_kwargs : Dict[str, Any] = {},
|
|
34
36
|
**kwargs
|
|
35
37
|
) -> "TransformedStorage[BArrayType, BDeviceType, BDtypeType, BRNGType]":
|
|
36
38
|
assert data_transformation.has_inverse, "To transform storages (potentially to save space), you need to use inversible data transformations"
|
|
@@ -40,12 +42,15 @@ class TransformedStorage(SpaceStorage[
|
|
|
40
42
|
if cache_path is not None:
|
|
41
43
|
os.makedirs(cache_path, exist_ok=True)
|
|
42
44
|
|
|
45
|
+
_inner_storage_kwargs = kwargs.copy()
|
|
46
|
+
_inner_storage_kwargs.update(inner_storage_kwargs)
|
|
43
47
|
inner_storage = inner_storage_cls.create(
|
|
44
48
|
transformed_space,
|
|
45
49
|
*args,
|
|
46
50
|
cache_path=None if cache_path is None else os.path.join(cache_path, inner_storage_path),
|
|
47
51
|
capacity=capacity,
|
|
48
|
-
|
|
52
|
+
multiprocessing=multiprocessing,
|
|
53
|
+
**_inner_storage_kwargs
|
|
49
54
|
)
|
|
50
55
|
return TransformedStorage(
|
|
51
56
|
single_instance_space,
|
|
@@ -62,6 +67,7 @@ class TransformedStorage(SpaceStorage[
|
|
|
62
67
|
*,
|
|
63
68
|
capacity : Optional[int] = None,
|
|
64
69
|
read_only : bool = True,
|
|
70
|
+
multiprocessing : bool = False,
|
|
65
71
|
**kwargs
|
|
66
72
|
) -> "TransformedStorage[BArrayType, BDeviceType, BDtypeType, BRNGType]":
|
|
67
73
|
metadata_path = os.path.join(path, "transformed_metadata.json")
|
|
@@ -85,6 +91,7 @@ class TransformedStorage(SpaceStorage[
|
|
|
85
91
|
transformed_space,
|
|
86
92
|
capacity=capacity,
|
|
87
93
|
read_only=read_only,
|
|
94
|
+
multiprocessing=multiprocessing,
|
|
88
95
|
**kwargs
|
|
89
96
|
)
|
|
90
97
|
return TransformedStorage(
|
|
@@ -140,6 +147,14 @@ class TransformedStorage(SpaceStorage[
|
|
|
140
147
|
def __len__(self):
|
|
141
148
|
return len(self.inner_storage)
|
|
142
149
|
|
|
150
|
+
@property
|
|
151
|
+
def is_mutable(self) -> bool:
|
|
152
|
+
return self.inner_storage.is_mutable
|
|
153
|
+
|
|
154
|
+
@property
|
|
155
|
+
def is_multiprocessing_safe(self) -> bool:
|
|
156
|
+
return self.inner_storage.is_multiprocessing_safe
|
|
157
|
+
|
|
143
158
|
def get_flattened(self, index):
|
|
144
159
|
dat = self.get(index)
|
|
145
160
|
if isinstance(index, int):
|
|
@@ -7,13 +7,17 @@ from PIL import Image
|
|
|
7
7
|
import numpy as np
|
|
8
8
|
import io
|
|
9
9
|
|
|
10
|
+
CONSERVATIVE_COMPRESSION_RATIOS = {
|
|
11
|
+
"JPEG": 10, # https://stackoverflow.com/questions/3471663/jpeg-compression-ratio
|
|
12
|
+
}
|
|
13
|
+
|
|
10
14
|
class ImageCompressTransformation(DataTransformation):
|
|
11
15
|
has_inverse = True
|
|
12
16
|
|
|
13
17
|
def __init__(
|
|
14
18
|
self,
|
|
15
|
-
init_quality : int =
|
|
16
|
-
max_size_bytes : int =
|
|
19
|
+
init_quality : int = 70,
|
|
20
|
+
max_size_bytes : Optional[int] = None,
|
|
17
21
|
mode : Optional[str] = None,
|
|
18
22
|
format : str = "JPEG",
|
|
19
23
|
) -> None:
|
|
@@ -25,9 +29,11 @@ class ImageCompressTransformation(DataTransformation):
|
|
|
25
29
|
mode: Optional mode for PIL Image (e.g., "RGB", "L"). If None, inferred from input.
|
|
26
30
|
format: Image format to use for compression (default "JPEG"). See https://pillow.readthedocs.io/en/stable/handbook/image-file-formats.html for options.
|
|
27
31
|
"""
|
|
32
|
+
assert max_size_bytes is not None or format in CONSERVATIVE_COMPRESSION_RATIOS, "Either max_size_bytes must be specified or format must have a conservative compression ratio defined."
|
|
28
33
|
|
|
29
34
|
self.init_quality = init_quality
|
|
30
35
|
self.max_size_bytes = max_size_bytes
|
|
36
|
+
self.compression_ratio = CONSERVATIVE_COMPRESSION_RATIOS.get(format, None) if max_size_bytes is None else None
|
|
31
37
|
self.mode = mode
|
|
32
38
|
self.format = format
|
|
33
39
|
|
|
@@ -45,9 +51,15 @@ class ImageCompressTransformation(DataTransformation):
|
|
|
45
51
|
) -> BDtypeType:
|
|
46
52
|
return backend.__array_namespace_info__().dtypes()['uint8']
|
|
47
53
|
|
|
54
|
+
def _get_max_compressed_size(self, source_space : BoxSpace):
|
|
55
|
+
H, W, C = source_space.shape[-3], source_space.shape[-2], source_space.shape[-1]
|
|
56
|
+
return self.max_size_bytes if self.max_size_bytes is not None else (H * W * C // self.compression_ratio) + 1
|
|
57
|
+
|
|
48
58
|
def get_target_space_from_source(self, source_space):
|
|
49
59
|
self.validate_source_space(source_space)
|
|
50
|
-
|
|
60
|
+
|
|
61
|
+
max_compressed_size = self._get_max_compressed_size(source_space)
|
|
62
|
+
new_shape = source_space.shape[:-3] + (max_compressed_size,)
|
|
51
63
|
|
|
52
64
|
return BoxSpace(
|
|
53
65
|
source_space.backend,
|
|
@@ -78,14 +90,14 @@ class ImageCompressTransformation(DataTransformation):
|
|
|
78
90
|
# Create PIL Image (mode inferred automatically)
|
|
79
91
|
img = Image.fromarray(img_array, mode=mode)
|
|
80
92
|
|
|
81
|
-
quality =
|
|
93
|
+
quality = self.init_quality
|
|
82
94
|
while quality >= min_quality:
|
|
83
95
|
buf = io.BytesIO()
|
|
84
96
|
img.save(buf, format=self.format, quality=quality)
|
|
85
97
|
image_bytes = buf.getvalue()
|
|
86
98
|
if len(image_bytes) <= max_bytes:
|
|
87
99
|
return image_bytes, quality
|
|
88
|
-
quality -=
|
|
100
|
+
quality -= 10
|
|
89
101
|
|
|
90
102
|
img.close()
|
|
91
103
|
# Return lowest quality attempt if still too large
|
|
@@ -93,19 +105,21 @@ class ImageCompressTransformation(DataTransformation):
|
|
|
93
105
|
|
|
94
106
|
def transform(self, source_space, data):
    """Compress every image in ``data`` into a fixed-width, zero-padded uint8 row.

    The trailing three axes of ``data`` are treated as one image; all leading
    axes are kept as batch axes. Each image is encoded via
    ``self.encode_to_size`` and written at the start of a row of
    ``max_compressed_size`` bytes; the unused tail of the row stays zero.

    Returns:
        A backend array of shape ``data.shape[:-3] + (max_compressed_size,)``
        on ``source_space.device``.

    Raises:
        ValueError: if an encoded image is larger than ``max_compressed_size``.
    """
    self.validate_source_space(source_space)

    max_compressed_size = self._get_max_compressed_size(source_space)
    data_numpy = source_space.backend.to_numpy(data)
    batch_shape = data_numpy.shape[:-3]
    # Collapse all batch axes so images can be encoded one by one.
    flat_data_numpy = data_numpy.reshape(-1, *data_numpy.shape[-3:])
    flat_compressed_data = np.zeros((flat_data_numpy.shape[0], max_compressed_size), dtype=np.uint8)
    for i, img_array in enumerate(flat_data_numpy):
        image_bytes, _ = self.encode_to_size(
            img_array,
            max_compressed_size,
            mode=self.mode
        )
        # Per its own comment, encode_to_size hands back its lowest-quality
        # attempt even when that is still too large. Writing such a payload
        # into the fixed-width row would fail with an opaque NumPy broadcast
        # error, so report the real problem instead (same exception type).
        if len(image_bytes) > max_compressed_size:
            raise ValueError(
                f"Compressed image {i} needs {len(image_bytes)} bytes, which exceeds "
                f"the {max_compressed_size}-byte slot; increase max_size_bytes."
            )
        byte_array = np.frombuffer(image_bytes, dtype=np.uint8)
        flat_compressed_data[i, :len(byte_array)] = byte_array
    compressed_data = flat_compressed_data.reshape(batch_shape + (max_compressed_size, ))
    compressed_data_backend = source_space.backend.from_numpy(compressed_data, dtype=self.get_uint8_dtype(source_space.backend), device=source_space.device)
    return compressed_data_backend
|
|
111
125
|
|
|
@@ -206,7 +220,6 @@ class ImageDecompressTransformation(DataTransformation):
|
|
|
206
220
|
assert source_space is not None, "Source space must be provided to get inverse transformation."
|
|
207
221
|
self.validate_source_space(source_space)
|
|
208
222
|
return ImageCompressTransformation(
|
|
209
|
-
init_quality=75,
|
|
210
223
|
max_size_bytes=source_space.shape[-1],
|
|
211
224
|
mode=self.mode,
|
|
212
225
|
format=self.format if self.format is not None else "JPEG",
|
|
@@ -8,7 +8,7 @@ from unienv_interface.utils import seed_util
|
|
|
8
8
|
from unienv_interface.env_base.funcenv import FuncEnv, ContextType, ObsType, ActType, RenderFrame, StateType, RenderStateType
|
|
9
9
|
from unienv_interface.env_base.funcenv_wrapper import *
|
|
10
10
|
from unienv_interface.space import Space
|
|
11
|
-
from unienv_interface.utils.
|
|
11
|
+
from unienv_interface.utils.framestack_queue import FuncSpaceDataQueue, SpaceDataQueueState
|
|
12
12
|
from unienv_interface.utils.stateclass import StateClass, field
|
|
13
13
|
|
|
14
14
|
class FuncFrameStackWrapperState(
|
|
@@ -192,14 +192,20 @@ def unflatten_data(space : Space, data : BArrayType, start_dim : int = 0) -> Any
|
|
|
192
192
|
@flatten_data.register(BinarySpace)
|
|
193
193
|
def _flatten_data_common(space: typing.Union[BoxSpace, BinarySpace], data: BArrayType, start_dim : int = 0) -> BArrayType:
|
|
194
194
|
assert -len(space.shape) <= start_dim <= len(space.shape)
|
|
195
|
-
|
|
195
|
+
dat = space.backend.reshape(data, data.shape[:start_dim] + (-1,))
|
|
196
|
+
if isinstance(space, BinarySpace):
|
|
197
|
+
dat = space.backend.astype(dat, space.backend.default_integer_dtype)
|
|
198
|
+
return dat
|
|
196
199
|
|
|
197
200
|
@unflatten_data.register(BoxSpace)
@unflatten_data.register(BinarySpace)
def _unflatten_data_common(space: typing.Union[BoxSpace, BinarySpace], data: Any, start_dim : int = 0) -> BArrayType:
    """Reshape flattened ``data`` back to the space's shape and cast it to the space's dtype.

    Axes before ``start_dim`` are taken from ``data``; the remaining axes are
    taken from ``space.shape[start_dim:]``. For a ``BinarySpace`` without an
    explicit dtype, the backend's default boolean dtype is used for the cast.
    """
    assert -len(space.shape) <= start_dim <= len(space.shape)
    restored_shape = data.shape[:start_dim] + space.shape[start_dim:]
    reshaped = space.backend.reshape(data, restored_shape)

    target_dtype = space.dtype
    if isinstance(space, BinarySpace) and target_dtype is None:
        # Binary spaces may leave dtype unset; fall back to the backend default.
        target_dtype = space.backend.default_boolean_dtype
    return space.backend.astype(reshaped, target_dtype)
|
|
204
210
|
|
|
205
211
|
@flatten_data.register(DynamicBoxSpace)
|
|
@@ -40,7 +40,7 @@ class TupleSpace(Space[Tuple[Any, ...], BDeviceType, BDtypeType, BRNGType]):
|
|
|
40
40
|
return self
|
|
41
41
|
|
|
42
42
|
new_device = device if backend is not None else (device or self.device)
|
|
43
|
-
return
|
|
43
|
+
return TupleSpace(
|
|
44
44
|
backend=backend or self.backend,
|
|
45
45
|
spaces=[space.to(backend, new_device) for space in self.spaces],
|
|
46
46
|
device=new_device
|
|
@@ -93,11 +93,11 @@ class TupleSpace(Space[Tuple[Any, ...], BDeviceType, BDtypeType, BRNGType]):
|
|
|
93
93
|
|
|
94
94
|
def __eq__(self, other: Any) -> bool:
|
|
95
95
|
"""Check whether ``other`` is equivalent to this instance."""
|
|
96
|
-
return isinstance(other,
|
|
96
|
+
return isinstance(other, TupleSpace) and self.spaces == other.spaces
|
|
97
97
|
|
|
98
|
-
def __copy__(self) -> "
|
|
98
|
+
def __copy__(self) -> "TupleSpace[BDeviceType, BDtypeType, BRNGType]":
|
|
99
99
|
"""Create a shallow copy of the Dict space."""
|
|
100
|
-
return
|
|
100
|
+
return TupleSpace(
|
|
101
101
|
backend=self.backend,
|
|
102
102
|
spaces=copy.copy(self.spaces),
|
|
103
103
|
device=self.device
|
|
@@ -0,0 +1,106 @@
|
|
|
1
|
+
from unienv_interface.space.space_utils import batch_utils as sbu
|
|
2
|
+
from .transformation import DataTransformation, TargetDataT
|
|
3
|
+
from unienv_interface.space import Space, BoxSpace
|
|
4
|
+
from typing import Union, Any, Optional
|
|
5
|
+
from unienv_interface.backends import ComputeBackend, BArrayType, BDeviceType, BDtypeType, BRNGType
|
|
6
|
+
|
|
7
|
+
class ImageResizeTransformation(DataTransformation):
    """Resize channel-last images of shape ``(..., H, W, C)`` to a fixed height and width.

    The trailing three axes are interpreted as height, width and channels;
    any leading axes are treated as batch axes and preserved. The resize
    implementation is selected from ``backend.simplified_name`` ("jax",
    "pytorch" or "numpy").
    """

    has_inverse = True

    def __init__(
        self,
        new_height: int,
        new_width: int
    ):
        self.new_height = new_height
        self.new_width = new_width

    def _validate_source_space(self, source_space : Space[Any, BDeviceType, BDtypeType, BRNGType]) -> BoxSpace[BArrayType, BDeviceType, BDtypeType, BRNGType]:
        """Assert that ``source_space`` is a BoxSpace with >= 3 dims and positive H/W, then return it."""
        assert isinstance(source_space, BoxSpace), \
            f"ImageResizeTransformation only supports BoxSpace, got {type(source_space)}"
        assert len(source_space.shape) >= 3, \
            f"ImageResizeTransformation only supports spaces with at least 3 dimensions (H, W, C), got shape {source_space.shape}"
        assert source_space.shape[-3] > 0 and source_space.shape[-2] > 0, \
            f"ImageResizeTransformation requires positive height and width, got shape {source_space.shape}"
        return source_space

    def get_target_space_from_source(self, source_space):
        """Return the BoxSpace describing the resized images.

        Leading axes and the channel axis are copied from ``source_space``;
        height and width become ``new_height`` / ``new_width``. The bounds are
        the min of ``low`` and max of ``high`` over the two spatial axes.
        """
        source_space = self._validate_source_space(source_space)

        backend = source_space.backend
        new_shape = (
            *source_space.shape[:-3],
            self.new_height,
            self.new_width,
            source_space.shape[-1]
        )
        # keepdims=True leaves size-1 spatial axes in the reduced bounds.
        new_low = backend.min(source_space.low, axis=(-3, -2), keepdims=True)
        new_high = backend.max(source_space.high, axis=(-3, -2), keepdims=True)

        return BoxSpace(
            source_space.backend,
            new_low,
            new_high,
            dtype=source_space.dtype,
            device=source_space.device,
            shape=new_shape
        )

    def transform(self, source_space, data):
        """Resize ``data`` to ``(*data.shape[:-3], new_height, new_width, C)``.

        Raises:
            ValueError: if the backend is not one of jax / pytorch / numpy.
        """
        source_space = self._validate_source_space(source_space)
        backend = source_space.backend
        target_shape = (
            *data.shape[:-3],
            self.new_height,
            self.new_width,
            source_space.shape[-1]
        )
        if backend.simplified_name == "jax":
            import jax.image
            resized_data = jax.image.resize(
                data,
                shape=target_shape,
                method='bilinear',
                antialias=True
            )
        elif backend.simplified_name == "pytorch":
            import torch.nn.functional as F
            # Bilinear F.interpolate only accepts 4-D (N, C, H, W) input, so
            # collapse every leading axis into a single batch axis first. This
            # also covers unbatched (H, W, C) data and data with several batch
            # axes, which would otherwise be rejected.
            flat_data = backend.reshape(data, (-1, *data.shape[-3:]))
            data_permuted = backend.permute_dims(flat_data, (0, 3, 1, 2))
            resized_data_permuted = F.interpolate(
                data_permuted,
                size=(self.new_height, self.new_width),
                mode='bilinear',
                align_corners=False,
                antialias=True
            )
            # Back to channel-last, then restore the original leading axes.
            resized_flat_data = backend.permute_dims(resized_data_permuted, (0, 2, 3, 1))
            resized_data = backend.reshape(resized_flat_data, target_shape)
        elif backend.simplified_name == "numpy":
            import cv2
            flat_data = backend.reshape(data, (-1, *source_space.shape[-3:]))
            resized_flat_data = []
            for i in range(flat_data.shape[0]):
                img = flat_data[i]
                # cv2.resize takes the target size as (width, height).
                resized_img = cv2.resize(
                    img,
                    (self.new_width, self.new_height),
                    interpolation=cv2.INTER_LINEAR
                )
                resized_flat_data.append(resized_img)
            resized_flat_data = backend.stack(resized_flat_data, axis=0)
            # The final reshape re-establishes the full (..., H, W, C) layout.
            resized_data = backend.reshape(
                resized_flat_data,
                target_shape
            )
        else:
            raise ValueError(f"Unsupported backend: {backend.simplified_name}")
        return resized_data

    def direction_inverse(self, source_space = None):
        """Return a resize back to the height and width of ``source_space``.

        ``source_space`` is required (asserted), because the original image
        size is read from its shape.
        """
        assert source_space is not None, "Inverse transformation requires source_space"
        source_space = self._validate_source_space(source_space)
        return ImageResizeTransformation(
            new_height=source_space.shape[-3],
            new_width=source_space.shape[-2]
        )
|
|
@@ -0,0 +1,92 @@
|
|
|
1
|
+
from typing import Union, Any, Optional, Mapping, List, Callable, Dict
|
|
2
|
+
|
|
3
|
+
from unienv_interface.space.space_utils import batch_utils as sbu
|
|
4
|
+
from unienv_interface.space import Space, DictSpace, TupleSpace
|
|
5
|
+
from unienv_interface.backends import ComputeBackend, BArrayType, BDeviceType, BDtypeType, BRNGType
|
|
6
|
+
|
|
7
|
+
import copy
|
|
8
|
+
from .transformation import DataTransformation, TargetDataT
|
|
9
|
+
|
|
10
|
+
def default_is_leaf_fn(space : Space[Any, BDeviceType, BDtypeType, BRNGType]):
    """Report whether ``space`` is a leaf, i.e. neither a DictSpace nor a TupleSpace."""
    is_container = isinstance(space, (DictSpace, TupleSpace))
    return not is_container
|
|
12
|
+
|
|
13
|
+
class IterativeTransformation(DataTransformation):
    """Apply one transformation to every leaf of a nested DictSpace / TupleSpace.

    ``is_leaf_node_fn`` decides where recursion stops in the forward
    direction; ``inv_is_leaf_node_fn`` takes that role in the transformation
    returned by ``direction_inverse`` (the two predicates are swapped there).
    """

    def __init__(
        self,
        transformation: DataTransformation,
        is_leaf_node_fn: Callable[[Space[Any, BDeviceType, BDtypeType, BRNGType]], bool] = default_is_leaf_fn,
        inv_is_leaf_node_fn: Callable[[Space[Any, BDeviceType, BDtypeType, BRNGType]], bool] = default_is_leaf_fn
    ):
        self.transformation = transformation
        self.is_leaf_node_fn = is_leaf_node_fn
        self.inv_is_leaf_node_fn = inv_is_leaf_node_fn
        # Invertibility is inherited from the wrapped transformation.
        self.has_inverse = transformation.has_inverse

    def get_target_space_from_source(
        self,
        source_space : Space[Any, BDeviceType, BDtypeType, BRNGType]
    ):
        """Build the target space by mapping every leaf space through the wrapped transformation.

        Raises:
            ValueError: if a non-leaf space is neither a DictSpace nor a TupleSpace.
        """
        if self.is_leaf_node_fn(source_space):
            return self.transformation.get_target_space_from_source(source_space)
        elif isinstance(source_space, DictSpace):
            rsts = {
                key: self.get_target_space_from_source(subspace)
                for key, subspace in source_space.spaces.items()
            }
            # Backend/device follow the first transformed child; an empty
            # container keeps the backend/device of the source space.
            backend = source_space.backend if len(rsts) == 0 else next(iter(rsts.values())).backend
            device = source_space.device if len(rsts) == 0 else next(iter(rsts.values())).device
            return DictSpace(
                backend,
                rsts,
                device=device
            )
        elif isinstance(source_space, TupleSpace):
            rsts = tuple(
                self.get_target_space_from_source(subspace)
                for subspace in source_space.spaces
            )
            backend = source_space.backend if len(rsts) == 0 else next(iter(rsts)).backend
            device = source_space.device if len(rsts) == 0 else next(iter(rsts)).device
            return TupleSpace(
                backend,
                rsts,
                device=device
            )
        else:
            raise ValueError(f"Unsupported space type: {type(source_space)}")

    def transform(
        self,
        source_space: Space,
        data: Union[Mapping[str, Any], BArrayType]
    ) -> Union[Mapping[str, Any], BArrayType]:
        """Transform ``data`` leaf by leaf, mirroring the nesting of ``source_space``.

        Dict spaces produce a dict with the same keys, tuple spaces a tuple
        in the same order.

        Raises:
            ValueError: if a non-leaf space is neither a DictSpace nor a TupleSpace.
        """
        if self.is_leaf_node_fn(source_space):
            return self.transformation.transform(source_space, data)
        elif isinstance(source_space, DictSpace):
            return {
                key: self.transform(subspace, data[key])
                for key, subspace in source_space.spaces.items()
            }
        elif isinstance(source_space, TupleSpace):
            return tuple(
                self.transform(subspace, data[i])
                for i, subspace in enumerate(source_space.spaces)
            )
        else:
            raise ValueError(f"Unsupported space type: {type(source_space)}")

    def direction_inverse(
        self,
        source_space = None,
    ) -> Optional["IterativeTransformation"]:
        """Return the inverse transformation, or ``None`` when there is none.

        When ``source_space`` is itself a leaf, it is forwarded to the wrapped
        transformation so that inverses which need the source space can be
        built. For nested spaces a single leaf space cannot be chosen, so the
        wrapped inverse is requested without one, as before.
        """
        if not self.has_inverse:
            return None

        if source_space is not None and self.is_leaf_node_fn(source_space):
            inner_inverse = self.transformation.direction_inverse(source_space)
        else:
            inner_inverse = self.transformation.direction_inverse()

        return IterativeTransformation(
            inner_inverse,
            is_leaf_node_fn=self.inv_is_leaf_node_fn,
            inv_is_leaf_node_fn=self.is_leaf_node_fn
        )

    def close(self):
        """Close the wrapped transformation."""
        self.transformation.close()
|
|
@@ -5,9 +5,15 @@ __all__ = [
|
|
|
5
5
|
"get_class_from_full_name",
|
|
6
6
|
]
|
|
7
7
|
|
|
8
|
+
REMAP = {
|
|
9
|
+
"unienv_data.storages.common.FlattenedStorage": "unienv_data.storages.flattened.FlattenedStorage",
|
|
10
|
+
}
|
|
11
|
+
|
|
8
12
|
def get_full_class_name(cls : Type) -> str:
    """Return ``cls`` as a dotted path: its ``__module__`` followed by its ``__qualname__``."""
    return ".".join((cls.__module__, cls.__qualname__))
|
|
10
14
|
|
|
11
15
|
def get_class_from_full_name(full_name : str) -> Type:
    """Import and return the object named by the dotted path ``full_name``.

    ``full_name`` is first looked up in ``REMAP`` so that a path listed there
    is replaced by its mapped path before importing. The text after the last
    dot is read as an attribute of the module named by the text before it.

    Raises:
        ValueError: if ``full_name`` has no module part (no dot, or nothing before it).
    """
    full_name = REMAP.get(full_name, full_name)
    module_name, _, class_name = full_name.rpartition(".")
    if not module_name:
        # Without this check the failure is an opaque unpacking or
        # empty-module-name error.
        raise ValueError(f"Expected a dotted 'module.ClassName' path, got {full_name!r}")
    # A non-empty fromlist makes __import__ return the leaf module rather than
    # the top-level package.
    return getattr(__import__(module_name, fromlist=[class_name]), class_name)
|
|
@@ -7,7 +7,7 @@ from unienv_interface.space.space_utils import batch_utils as sbu
|
|
|
7
7
|
from unienv_interface.env_base.env import Env, ContextType, ObsType, ActType, RenderFrame, BArrayType, BDeviceType, BDtypeType, BRNGType
|
|
8
8
|
from unienv_interface.env_base.wrapper import ContextObservationWrapper, ActionWrapper, WrapperContextT, WrapperObsT, WrapperActT
|
|
9
9
|
from unienv_interface.space import Space, DictSpace
|
|
10
|
-
from unienv_interface.utils.
|
|
10
|
+
from unienv_interface.utils.framestack_queue import SpaceDataQueue
|
|
11
11
|
|
|
12
12
|
class FrameStackWrapper(
|
|
13
13
|
ContextObservationWrapper[
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|