unienv 0.0.1b5__py3-none-any.whl → 0.0.1b7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {unienv-0.0.1b5.dist-info → unienv-0.0.1b7.dist-info}/METADATA +3 -2
- {unienv-0.0.1b5.dist-info → unienv-0.0.1b7.dist-info}/RECORD +30 -21
- {unienv-0.0.1b5.dist-info → unienv-0.0.1b7.dist-info}/WHEEL +1 -1
- unienv_data/base/common.py +25 -10
- unienv_data/batches/backend_compat.py +1 -1
- unienv_data/batches/combined_batch.py +1 -1
- unienv_data/replay_buffer/replay_buffer.py +51 -8
- unienv_data/storages/_episode_storage.py +438 -0
- unienv_data/storages/_list_storage.py +136 -0
- unienv_data/storages/backend_compat.py +268 -0
- unienv_data/storages/flattened.py +3 -3
- unienv_data/storages/hdf5.py +7 -2
- unienv_data/storages/image_storage.py +144 -0
- unienv_data/storages/npz_storage.py +135 -0
- unienv_data/storages/pytorch.py +16 -9
- unienv_data/storages/video_storage.py +297 -0
- unienv_data/third_party/tensordict/memmap_tensor.py +1174 -0
- unienv_data/transformations/image_compress.py +81 -18
- unienv_interface/space/space_utils/batch_utils.py +5 -1
- unienv_interface/space/spaces/dict.py +6 -0
- unienv_interface/transformations/__init__.py +3 -1
- unienv_interface/transformations/batch_and_unbatch.py +43 -4
- unienv_interface/transformations/chained_transform.py +9 -8
- unienv_interface/transformations/crop.py +69 -0
- unienv_interface/transformations/dict_transform.py +8 -2
- unienv_interface/transformations/identity.py +16 -0
- unienv_interface/transformations/rescale.py +24 -5
- unienv_interface/wrapper/backend_compat.py +1 -1
- {unienv-0.0.1b5.dist-info → unienv-0.0.1b7.dist-info}/licenses/LICENSE +0 -0
- {unienv-0.0.1b5.dist-info → unienv-0.0.1b7.dist-info}/top_level.txt +0 -0
|
@@ -6,9 +6,31 @@ from unienv_interface.backends import ComputeBackend, BArrayType, BDeviceType, B
|
|
|
6
6
|
from PIL import Image
|
|
7
7
|
import numpy as np
|
|
8
8
|
import io
|
|
9
|
+
import math
|
|
10
|
+
|
|
11
|
+
# https://stackoverflow.com/questions/3471663/jpeg-compression-ratio
|
|
12
|
+
JPEG_QUALITY_COMPRESSION_MAP = {
|
|
13
|
+
"quality": np.array([55, 60, 65, 70, 75, 80, 85, 90, 95, 100], dtype=int),
|
|
14
|
+
"compression_ratio": np.array([43.27, 36.90, 34.24, 31.50, 26.00, 25.06, 19.08, 14.30, 9.88, 5.27], dtype=float),
|
|
15
|
+
"conservative_ratio": 0.6,
|
|
16
|
+
}
|
|
17
|
+
def get_jpeg_compression_ratio(init_quality : int) -> int:
|
|
18
|
+
if init_quality <= JPEG_QUALITY_COMPRESSION_MAP['quality'][0]:
|
|
19
|
+
return math.floor(JPEG_QUALITY_COMPRESSION_MAP['compression_ratio'][0] * JPEG_QUALITY_COMPRESSION_MAP['conservative_ratio'])
|
|
20
|
+
if init_quality >= JPEG_QUALITY_COMPRESSION_MAP['quality'][-1]:
|
|
21
|
+
return math.floor(JPEG_QUALITY_COMPRESSION_MAP['compression_ratio'][-1] * JPEG_QUALITY_COMPRESSION_MAP['conservative_ratio'])
|
|
22
|
+
|
|
23
|
+
for i in range(1, len(JPEG_QUALITY_COMPRESSION_MAP['quality'])):
|
|
24
|
+
if init_quality <= JPEG_QUALITY_COMPRESSION_MAP['quality'][i]:
|
|
25
|
+
q_low = JPEG_QUALITY_COMPRESSION_MAP['quality'][i - 1]
|
|
26
|
+
q_high = JPEG_QUALITY_COMPRESSION_MAP['quality'][i]
|
|
27
|
+
r_low = JPEG_QUALITY_COMPRESSION_MAP['compression_ratio'][i - 1]
|
|
28
|
+
r_high = JPEG_QUALITY_COMPRESSION_MAP['compression_ratio'][i]
|
|
29
|
+
ratio = r_low + (r_high - r_low) * (init_quality - q_low) / (q_high - q_low)
|
|
30
|
+
return math.floor(ratio * JPEG_QUALITY_COMPRESSION_MAP['conservative_ratio'])
|
|
9
31
|
|
|
10
32
|
CONSERVATIVE_COMPRESSION_RATIOS = {
|
|
11
|
-
"JPEG":
|
|
33
|
+
"JPEG": get_jpeg_compression_ratio,
|
|
12
34
|
}
|
|
13
35
|
|
|
14
36
|
class ImageCompressTransformation(DataTransformation):
|
|
@@ -18,8 +40,10 @@ class ImageCompressTransformation(DataTransformation):
|
|
|
18
40
|
self,
|
|
19
41
|
init_quality : int = 70,
|
|
20
42
|
max_size_bytes : Optional[int] = None,
|
|
43
|
+
compression_ratio : Optional[float] = None,
|
|
21
44
|
mode : Optional[str] = None,
|
|
22
45
|
format : str = "JPEG",
|
|
46
|
+
last_channel : bool = True,
|
|
23
47
|
) -> None:
|
|
24
48
|
"""
|
|
25
49
|
Initialize JPEG compression transformation.
|
|
@@ -29,21 +53,25 @@ class ImageCompressTransformation(DataTransformation):
|
|
|
29
53
|
mode: Optional mode for PIL Image (e.g., "RGB", "L"). If None, inferred from input.
|
|
30
54
|
format: Image format to use for compression (default "JPEG"). See https://pillow.readthedocs.io/en/stable/handbook/image-file-formats.html for options.
|
|
31
55
|
"""
|
|
32
|
-
assert max_size_bytes is not None
|
|
56
|
+
assert not (max_size_bytes is not None and compression_ratio is not None), "Specify either max_size_bytes or compression_ratio, not both."
|
|
57
|
+
assert max_size_bytes is not None or compression_ratio is not None or format in CONSERVATIVE_COMPRESSION_RATIOS, "Either max_size_bytes must be specified or format must have a conservative compression ratio defined."
|
|
33
58
|
|
|
34
59
|
self.init_quality = init_quality
|
|
35
60
|
self.max_size_bytes = max_size_bytes
|
|
36
|
-
self.compression_ratio = CONSERVATIVE_COMPRESSION_RATIOS.get(format, None) if max_size_bytes is None else None
|
|
61
|
+
self.compression_ratio = compression_ratio if compression_ratio is not None else (CONSERVATIVE_COMPRESSION_RATIOS.get(format, None)(init_quality) if max_size_bytes is None else None)
|
|
37
62
|
self.mode = mode
|
|
38
63
|
self.format = format
|
|
64
|
+
self.last_channel = last_channel
|
|
39
65
|
|
|
40
|
-
|
|
41
|
-
def validate_source_space(source_space: Space[Any, BDeviceType, BDtypeType, BRNGType]) -> None:
|
|
66
|
+
def validate_source_space(self, source_space: Space[Any, BDeviceType, BDtypeType, BRNGType]) -> None:
|
|
42
67
|
assert isinstance(source_space, BoxSpace), "JPEGCompressTransformation only supports BoxSpace source spaces."
|
|
43
|
-
|
|
44
|
-
source_space.shape
|
|
45
|
-
|
|
46
|
-
|
|
68
|
+
if not self.last_channel:
|
|
69
|
+
assert len(source_space.shape) >= 2
|
|
70
|
+
else:
|
|
71
|
+
assert len(source_space.shape) >= 3 and (
|
|
72
|
+
source_space.shape[-1] == 3 or
|
|
73
|
+
source_space.shape[-1] == 1
|
|
74
|
+
), "JPEGCompressTransformation only supports BoxSpace source spaces with shape (..., H, W, 1 or 3)."
|
|
47
75
|
|
|
48
76
|
@staticmethod
|
|
49
77
|
def get_uint8_dtype(
|
|
@@ -52,14 +80,22 @@ class ImageCompressTransformation(DataTransformation):
|
|
|
52
80
|
return backend.__array_namespace_info__().dtypes()['uint8']
|
|
53
81
|
|
|
54
82
|
def _get_max_compressed_size(self, source_space : BoxSpace):
|
|
55
|
-
|
|
83
|
+
if self.last_channel:
|
|
84
|
+
H, W, C = source_space.shape[-3], source_space.shape[-2], source_space.shape[-1]
|
|
85
|
+
else:
|
|
86
|
+
H, W = source_space.shape[-2], source_space.shape[-1]
|
|
87
|
+
C = 1
|
|
56
88
|
return self.max_size_bytes if self.max_size_bytes is not None else (H * W * C // self.compression_ratio) + 1
|
|
57
89
|
|
|
58
90
|
def get_target_space_from_source(self, source_space):
|
|
59
91
|
self.validate_source_space(source_space)
|
|
60
92
|
|
|
61
93
|
max_compressed_size = self._get_max_compressed_size(source_space)
|
|
62
|
-
|
|
94
|
+
|
|
95
|
+
if not self.last_channel:
|
|
96
|
+
new_shape = source_space.shape[:-2] + (max_compressed_size,)
|
|
97
|
+
else:
|
|
98
|
+
new_shape = source_space.shape[:-3] + (max_compressed_size,)
|
|
63
99
|
|
|
64
100
|
return BoxSpace(
|
|
65
101
|
source_space.backend,
|
|
@@ -108,7 +144,11 @@ class ImageCompressTransformation(DataTransformation):
|
|
|
108
144
|
|
|
109
145
|
max_compressed_size = self._get_max_compressed_size(source_space)
|
|
110
146
|
data_numpy = source_space.backend.to_numpy(data)
|
|
111
|
-
|
|
147
|
+
if not self.last_channel:
|
|
148
|
+
flat_data_numpy = data_numpy.reshape(-1, *data_numpy.shape[-2:])
|
|
149
|
+
else:
|
|
150
|
+
flat_data_numpy = data_numpy.reshape(-1, *data_numpy.shape[-3:])
|
|
151
|
+
|
|
112
152
|
flat_compressed_data = np.zeros((flat_data_numpy.shape[0], max_compressed_size), dtype=np.uint8)
|
|
113
153
|
for i in range(flat_data_numpy.shape[0]):
|
|
114
154
|
img_array = flat_data_numpy[i]
|
|
@@ -119,16 +159,26 @@ class ImageCompressTransformation(DataTransformation):
|
|
|
119
159
|
)
|
|
120
160
|
byte_array = np.frombuffer(image_bytes, dtype=np.uint8)
|
|
121
161
|
flat_compressed_data[i, :len(byte_array)] = byte_array
|
|
122
|
-
|
|
162
|
+
|
|
163
|
+
if not self.last_channel:
|
|
164
|
+
compressed_data = flat_compressed_data.reshape(data_numpy.shape[:-2] + (max_compressed_size, ))
|
|
165
|
+
else:
|
|
166
|
+
compressed_data = flat_compressed_data.reshape(data_numpy.shape[:-3] + (max_compressed_size, ))
|
|
123
167
|
compressed_data_backend = source_space.backend.from_numpy(compressed_data, dtype=self.get_uint8_dtype(source_space.backend), device=source_space.device)
|
|
124
168
|
return compressed_data_backend
|
|
125
169
|
|
|
126
170
|
def direction_inverse(self, source_space = None):
|
|
127
171
|
assert source_space is not None, "Source space must be provided to get inverse transformation."
|
|
128
172
|
self.validate_source_space(source_space)
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
173
|
+
|
|
174
|
+
if not self.last_channel:
|
|
175
|
+
height = source_space.shape[-2]
|
|
176
|
+
width = source_space.shape[-1]
|
|
177
|
+
channels = None
|
|
178
|
+
else:
|
|
179
|
+
height = source_space.shape[-3]
|
|
180
|
+
width = source_space.shape[-2]
|
|
181
|
+
channels = source_space.shape[-1]
|
|
132
182
|
return ImageDecompressTransformation(
|
|
133
183
|
target_height=height,
|
|
134
184
|
target_width=width,
|
|
@@ -136,6 +186,12 @@ class ImageCompressTransformation(DataTransformation):
|
|
|
136
186
|
mode=self.mode,
|
|
137
187
|
format=self.format,
|
|
138
188
|
)
|
|
189
|
+
|
|
190
|
+
def __setstate__(self, state):
|
|
191
|
+
# for backward compatibility
|
|
192
|
+
self.__dict__.update(state)
|
|
193
|
+
if 'last_channel' not in state:
|
|
194
|
+
self.last_channel = True
|
|
139
195
|
|
|
140
196
|
class ImageDecompressTransformation(DataTransformation):
|
|
141
197
|
has_inverse = True
|
|
@@ -144,7 +200,7 @@ class ImageDecompressTransformation(DataTransformation):
|
|
|
144
200
|
self,
|
|
145
201
|
target_height : int,
|
|
146
202
|
target_width : int,
|
|
147
|
-
target_channels : int = 3,
|
|
203
|
+
target_channels : Optional[int] = 3,
|
|
148
204
|
mode : Optional[str] = None,
|
|
149
205
|
format : Optional[str] = None,
|
|
150
206
|
) -> None:
|
|
@@ -174,6 +230,9 @@ class ImageDecompressTransformation(DataTransformation):
|
|
|
174
230
|
def get_target_space_from_source(self, source_space):
|
|
175
231
|
self.validate_source_space(source_space)
|
|
176
232
|
new_shape = source_space.shape[:-1] + (self.target_height, self.target_width, self.target_channels)
|
|
233
|
+
if self.target_channels is None:
|
|
234
|
+
new_shape = new_shape[:-1]
|
|
235
|
+
|
|
177
236
|
return BoxSpace(
|
|
178
237
|
source_space.backend,
|
|
179
238
|
shape=new_shape,
|
|
@@ -212,7 +271,10 @@ class ImageDecompressTransformation(DataTransformation):
|
|
|
212
271
|
byte_array.tobytes(),
|
|
213
272
|
mode=self.mode
|
|
214
273
|
)
|
|
215
|
-
|
|
274
|
+
if self.target_channels is None:
|
|
275
|
+
decompressed_image = flat_decompressed_image.reshape(data_numpy.shape[:-1] + (self.target_height, self.target_width))
|
|
276
|
+
else:
|
|
277
|
+
decompressed_image = flat_decompressed_image.reshape(data_numpy.shape[:-1] + (self.target_height, self.target_width, self.target_channels))
|
|
216
278
|
decompressed_image_backend = source_space.backend.from_numpy(decompressed_image, dtype=self.get_uint8_dtype(source_space.backend), device=source_space.device)
|
|
217
279
|
return decompressed_image_backend
|
|
218
280
|
|
|
@@ -223,4 +285,5 @@ class ImageDecompressTransformation(DataTransformation):
|
|
|
223
285
|
max_size_bytes=source_space.shape[-1],
|
|
224
286
|
mode=self.mode,
|
|
225
287
|
format=self.format if self.format is not None else "JPEG",
|
|
288
|
+
last_channel=self.target_channels is not None,
|
|
226
289
|
)
|
|
@@ -334,7 +334,7 @@ def swap_batch_dims_in_data(
|
|
|
334
334
|
) -> Any:
|
|
335
335
|
if backend.is_backendarray(data):
|
|
336
336
|
return _tensor_transpose(backend, data, dim1, dim2)
|
|
337
|
-
elif isinstance(data, np.ndarray)
|
|
337
|
+
elif isinstance(data, np.ndarray):
|
|
338
338
|
return _tensor_transpose(NumpyComputeBackend, data, dim1, dim2)
|
|
339
339
|
elif isinstance(data, GraphInstance):
|
|
340
340
|
return GraphInstance(
|
|
@@ -734,6 +734,8 @@ def _get_at_tuple(space: TupleSpace, items: typing.Tuple[Any, ...], index : Arra
|
|
|
734
734
|
|
|
735
735
|
@get_at.register(BatchedSpace)
|
|
736
736
|
def _get_at_batched(space: BatchedSpace, items: np.ndarray, index: ArrayAPIGetIndex) -> typing.Union[np.ndarray, Any]:
|
|
737
|
+
if space.backend.is_backendarray(index):
|
|
738
|
+
index = space.backend.to_numpy(index)
|
|
737
739
|
return items[index]
|
|
738
740
|
|
|
739
741
|
@singledispatch
|
|
@@ -830,6 +832,8 @@ def _set_at_batched(
|
|
|
830
832
|
value: typing.Union[np.ndarray, Any],
|
|
831
833
|
) -> np.ndarray:
|
|
832
834
|
new_data = items.copy()
|
|
835
|
+
if space.backend.is_backendarray(index):
|
|
836
|
+
index = space.backend.to_numpy(index)
|
|
833
837
|
new_data[index] = value
|
|
834
838
|
return new_data
|
|
835
839
|
|
|
@@ -34,6 +34,12 @@ class DictSpace(Space[Dict[str, Any], BDeviceType, BDtypeType, BRNGType]):
|
|
|
34
34
|
new_spaces: Dict[str, Space[Any, BDeviceType, BDtypeType, BRNGType]] = {}
|
|
35
35
|
|
|
36
36
|
for key, space in spaces.items():
|
|
37
|
+
if not isinstance(space, Space) and isinstance(space, Mapping):
|
|
38
|
+
space = DictSpace(
|
|
39
|
+
backend,
|
|
40
|
+
space,
|
|
41
|
+
device
|
|
42
|
+
)
|
|
37
43
|
assert isinstance(
|
|
38
44
|
space, Space
|
|
39
45
|
), f"Dict space element is not an instance of Space: key='{key}', space={space}"
|
|
@@ -1,6 +1,8 @@
|
|
|
1
1
|
from .transformation import DataTransformation
|
|
2
|
+
from .identity import IdentityTransformation
|
|
2
3
|
from .rescale import RescaleTransformation
|
|
3
4
|
from .filter_dict import DictIncludeKeyTransformation, DictExcludeKeyTransformation
|
|
4
5
|
from .batch_and_unbatch import BatchifyTransformation, UnBatchifyTransformation
|
|
5
6
|
from .dict_transform import DictTransformation
|
|
6
|
-
from .chained_transform import ChainedTransformation
|
|
7
|
+
from .chained_transform import ChainedTransformation
|
|
8
|
+
from .crop import CropTransformation
|
|
@@ -6,26 +6,65 @@ from unienv_interface.backends import ComputeBackend, BArrayType, BDeviceType, B
|
|
|
6
6
|
|
|
7
7
|
class BatchifyTransformation(DataTransformation):
|
|
8
8
|
has_inverse = True
|
|
9
|
+
|
|
10
|
+
def __init__(
|
|
11
|
+
self,
|
|
12
|
+
axis : int = 0
|
|
13
|
+
) -> None:
|
|
14
|
+
self.axis = axis
|
|
9
15
|
|
|
10
16
|
def get_target_space_from_source(self, source_space):
|
|
11
|
-
|
|
17
|
+
ret = sbu.batch_space(source_space, 1)
|
|
18
|
+
if self.axis != 0:
|
|
19
|
+
ret = sbu.swap_batch_dims(
|
|
20
|
+
ret,
|
|
21
|
+
0,
|
|
22
|
+
self.axis
|
|
23
|
+
)
|
|
24
|
+
return ret
|
|
12
25
|
|
|
13
26
|
def transform(self, source_space, data):
|
|
14
27
|
return sbu.concatenate(
|
|
15
28
|
source_space,
|
|
16
|
-
[data]
|
|
29
|
+
[data],
|
|
30
|
+
axis=self.axis
|
|
17
31
|
)
|
|
18
32
|
|
|
19
33
|
def direction_inverse(self, source_space = None):
|
|
20
|
-
return UnBatchifyTransformation()
|
|
34
|
+
return UnBatchifyTransformation(axis=self.axis)
|
|
21
35
|
|
|
22
36
|
class UnBatchifyTransformation(DataTransformation):
|
|
23
37
|
has_inverse = True
|
|
38
|
+
|
|
39
|
+
def __init__(
|
|
40
|
+
self,
|
|
41
|
+
axis : int = 0
|
|
42
|
+
) -> None:
|
|
43
|
+
self.axis = axis
|
|
24
44
|
|
|
25
45
|
def get_target_space_from_source(self, source_space):
|
|
46
|
+
if self.axis != 0:
|
|
47
|
+
source_space = sbu.swap_batch_dims(
|
|
48
|
+
source_space,
|
|
49
|
+
0,
|
|
50
|
+
self.axis
|
|
51
|
+
)
|
|
52
|
+
assert sbu.batch_size(source_space) == 1, "Cannot unbatch space with batch size > 1"
|
|
26
53
|
return next(iter(sbu.unbatch_spaces(source_space)))
|
|
27
54
|
|
|
28
55
|
def transform(self, source_space, data):
|
|
56
|
+
if self.axis != 0:
|
|
57
|
+
source_space = sbu.swap_batch_dims(
|
|
58
|
+
source_space,
|
|
59
|
+
0,
|
|
60
|
+
self.axis
|
|
61
|
+
)
|
|
62
|
+
data = sbu.swap_batch_dims_in_data(
|
|
63
|
+
source_space.backend,
|
|
64
|
+
data,
|
|
65
|
+
0,
|
|
66
|
+
self.axis
|
|
67
|
+
)
|
|
29
68
|
return sbu.get_at(
|
|
30
69
|
source_space,
|
|
31
70
|
data,
|
|
@@ -33,4 +72,4 @@ class UnBatchifyTransformation(DataTransformation):
|
|
|
33
72
|
)
|
|
34
73
|
|
|
35
74
|
def direction_inverse(self, source_space = None):
|
|
36
|
-
return BatchifyTransformation()
|
|
75
|
+
return BatchifyTransformation(axis=self.axis)
|
|
@@ -45,17 +45,18 @@ class ChainedTransformation(DataTransformation):
|
|
|
45
45
|
) -> Optional["ChainedTransformation"]:
|
|
46
46
|
if not self.has_inverse:
|
|
47
47
|
return None
|
|
48
|
+
|
|
49
|
+
source_spaces = [source_space]
|
|
50
|
+
for transformation in self.transformations:
|
|
51
|
+
next_space = transformation.get_target_space_from_source(source_spaces[-1])
|
|
52
|
+
source_spaces.append(next_space)
|
|
48
53
|
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
}
|
|
54
|
+
inverted_transformations = []
|
|
55
|
+
for i in reversed(range(len(self.transformations))):
|
|
56
|
+
inverted_transformations.append(self.transformations[i].direction_inverse(source_spaces[i]))
|
|
53
57
|
|
|
54
58
|
return ChainedTransformation(
|
|
55
|
-
|
|
56
|
-
transformation.direction_inverse(source_space)
|
|
57
|
-
for transformation in reversed(self.transformations)
|
|
58
|
-
]
|
|
59
|
+
inverted_transformations
|
|
59
60
|
)
|
|
60
61
|
|
|
61
62
|
def close(self):
|
|
@@ -0,0 +1,69 @@
|
|
|
1
|
+
from unienv_interface.space.space_utils import batch_utils as sbu
|
|
2
|
+
from unienv_interface.transformations import DataTransformation
|
|
3
|
+
from unienv_interface.space import Space, BoxSpace, TextSpace
|
|
4
|
+
from typing import Union, Any, Optional, Tuple, List
|
|
5
|
+
from unienv_interface.backends import ComputeBackend, BArrayType, BDeviceType, BDtypeType, BRNGType
|
|
6
|
+
from .identity import IdentityTransformation
|
|
7
|
+
|
|
8
|
+
class CropTransformation(DataTransformation):
|
|
9
|
+
has_inverse = True
|
|
10
|
+
|
|
11
|
+
def __init__(
|
|
12
|
+
self,
|
|
13
|
+
crop_low : Union[int, float, BArrayType],
|
|
14
|
+
crop_high : Union[int, float, BArrayType],
|
|
15
|
+
) -> None:
|
|
16
|
+
"""
|
|
17
|
+
Initialize Crop Transformation.
|
|
18
|
+
Args:
|
|
19
|
+
crop_low: Lower bound for cropping the data.
|
|
20
|
+
crop_high: Upper bound for cropping the data.
|
|
21
|
+
"""
|
|
22
|
+
self.crop_low = crop_low
|
|
23
|
+
self.crop_high = crop_high
|
|
24
|
+
|
|
25
|
+
def validate_source_space(self, source_space : Space[Any, BDeviceType, BDtypeType, BRNGType]) -> None:
|
|
26
|
+
assert isinstance(source_space, BoxSpace), "CropTransformation only supports Box spaces"
|
|
27
|
+
|
|
28
|
+
def get_crop_range(self, source_space : BoxSpace[Any, BDeviceType, BDtypeType, BRNGType]) -> Tuple[BArrayType, BArrayType]:
|
|
29
|
+
new_low = self.crop_low
|
|
30
|
+
if source_space.backend.is_backendarray(new_low):
|
|
31
|
+
if len(new_low.shape) < len(source_space.shape):
|
|
32
|
+
new_low = source_space.backend.reshape(
|
|
33
|
+
new_low,
|
|
34
|
+
(-1,)*(len(source_space.shape) - len(new_low.shape)) + new_low.shape
|
|
35
|
+
)
|
|
36
|
+
new_low = source_space.backend.astype(new_low, source_space.dtype)
|
|
37
|
+
if source_space.device is not None:
|
|
38
|
+
new_low = source_space.backend.to_device(new_low, source_space.device)
|
|
39
|
+
new_high = self.crop_high
|
|
40
|
+
if source_space.backend.is_backendarray(new_high):
|
|
41
|
+
if len(new_high.shape) < len(source_space.shape):
|
|
42
|
+
new_high = source_space.backend.reshape(
|
|
43
|
+
new_high,
|
|
44
|
+
(-1,)*(len(source_space.shape) - len(new_high.shape)) + new_high.shape
|
|
45
|
+
)
|
|
46
|
+
new_high = source_space.backend.astype(new_high, source_space.dtype)
|
|
47
|
+
if source_space.device is not None:
|
|
48
|
+
new_high = source_space.backend.to_device(new_high, source_space.device)
|
|
49
|
+
return new_low, new_high
|
|
50
|
+
|
|
51
|
+
def get_target_space_from_source(self, source_space):
|
|
52
|
+
self.validate_source_space(source_space)
|
|
53
|
+
crop_low, crop_high = self.get_crop_range(source_space)
|
|
54
|
+
return BoxSpace(
|
|
55
|
+
backend=source_space.backend,
|
|
56
|
+
low=crop_low,
|
|
57
|
+
high=crop_high,
|
|
58
|
+
shape=source_space.shape,
|
|
59
|
+
dtype=source_space.dtype,
|
|
60
|
+
device=source_space.device,
|
|
61
|
+
)
|
|
62
|
+
|
|
63
|
+
def transform(self, source_space, data):
|
|
64
|
+
self.validate_source_space(source_space)
|
|
65
|
+
crop_low, crop_high = self.get_crop_range(source_space)
|
|
66
|
+
return source_space.backend.clip(data, crop_low, crop_high)
|
|
67
|
+
|
|
68
|
+
def direction_inverse(self, source_space = None):
|
|
69
|
+
return IdentityTransformation()
|
|
@@ -135,12 +135,18 @@ class DictTransformation(DataTransformation):
|
|
|
135
135
|
|
|
136
136
|
inverse_mapping = {}
|
|
137
137
|
for key, transformation in self.mapping.items():
|
|
138
|
-
|
|
139
|
-
|
|
138
|
+
if source_space is not None:
|
|
139
|
+
current_source = get_chained_value(
|
|
140
140
|
source_space,
|
|
141
141
|
key.split('/'),
|
|
142
142
|
ignore_missing_keys=self.ignore_missing_keys
|
|
143
143
|
)
|
|
144
|
+
if current_source is None:
|
|
145
|
+
continue
|
|
146
|
+
else:
|
|
147
|
+
current_source = None
|
|
148
|
+
inverse_mapping[key] = transformation.direction_inverse(
|
|
149
|
+
current_source
|
|
144
150
|
)
|
|
145
151
|
|
|
146
152
|
return DictTransformation(
|
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
from .transformation import DataTransformation, TargetDataT
|
|
2
|
+
from unienv_interface.space import BoxSpace
|
|
3
|
+
from typing import Union, Any, Optional
|
|
4
|
+
from unienv_interface.backends import ComputeBackend, BArrayType, BDeviceType, BDtypeType, BRNGType
|
|
5
|
+
|
|
6
|
+
class IdentityTransformation(DataTransformation):
|
|
7
|
+
has_inverse = True
|
|
8
|
+
|
|
9
|
+
def get_target_space_from_source(self, source_space):
|
|
10
|
+
return source_space
|
|
11
|
+
|
|
12
|
+
def transform(self, source_space, data):
|
|
13
|
+
return data
|
|
14
|
+
|
|
15
|
+
def direction_inverse(self, source_space = None):
|
|
16
|
+
return self
|
|
@@ -11,8 +11,8 @@ def _get_broadcastable_value(
|
|
|
11
11
|
if isinstance(value, (int, float)):
|
|
12
12
|
return value
|
|
13
13
|
else:
|
|
14
|
-
assert target_ndim
|
|
15
|
-
target_shape = tuple([1] * (target_ndim - len(value)) + list(value.shape))
|
|
14
|
+
assert target_ndim >= len(value.shape), "Target space must have at least as many dimensions as the value"
|
|
15
|
+
target_shape = tuple([1] * (target_ndim - len(value.shape)) + list(value.shape))
|
|
16
16
|
return backend.reshape(value, target_shape)
|
|
17
17
|
|
|
18
18
|
class RescaleTransformation(DataTransformation):
|
|
@@ -22,6 +22,7 @@ class RescaleTransformation(DataTransformation):
|
|
|
22
22
|
new_low : Union[BArrayType,float] = -1.0,
|
|
23
23
|
new_high : Union[BArrayType,float] = 1.0,
|
|
24
24
|
new_dtype : Optional[BDtypeType] = None,
|
|
25
|
+
nan_to : Optional[Union[float, int, BArrayType]] = None,
|
|
25
26
|
):
|
|
26
27
|
# assert isinstance(source_space, BoxSpace), "RescaleTransformation only supports Box action spaces"
|
|
27
28
|
# assert source_space.backend.dtype_is_real_floating(source_space.dtype), "RescaleTransformation only supports real-valued floating spaces"
|
|
@@ -31,6 +32,7 @@ class RescaleTransformation(DataTransformation):
|
|
|
31
32
|
self.new_high = new_high
|
|
32
33
|
self._new_span = new_high - new_low
|
|
33
34
|
self.new_dtype = new_dtype
|
|
35
|
+
self.nan_to = nan_to
|
|
34
36
|
|
|
35
37
|
def get_target_space_from_source(self, source_space):
|
|
36
38
|
assert isinstance(source_space, BoxSpace), "RescaleTransformation only supports Box spaces"
|
|
@@ -67,14 +69,26 @@ class RescaleTransformation(DataTransformation):
|
|
|
67
69
|
source_space.backend,
|
|
68
70
|
self.new_low,
|
|
69
71
|
target_ndim
|
|
70
|
-
), source_space.backend.device(data))
|
|
72
|
+
), source_space.backend.device(data)) if source_space.backend.is_backendarray(self.new_low) else self.new_low
|
|
71
73
|
target_high = source_space.backend.to_device(_get_broadcastable_value(
|
|
72
74
|
source_space.backend,
|
|
73
75
|
self.new_high,
|
|
74
76
|
target_ndim
|
|
75
|
-
), source_space.backend.device(data))
|
|
77
|
+
), source_space.backend.device(data)) if source_space.backend.is_backendarray(self.new_high) else self.new_high
|
|
76
78
|
scaling_factor = (target_high - target_low) / (source_space._high - source_space._low)
|
|
77
79
|
target_data = (data - source_space._low) * scaling_factor + target_low
|
|
80
|
+
|
|
81
|
+
if self.nan_to is not None:
|
|
82
|
+
target_data = source_space.backend.where(
|
|
83
|
+
source_space.backend.isnan(target_data),
|
|
84
|
+
self.nan_to,
|
|
85
|
+
target_data
|
|
86
|
+
)
|
|
87
|
+
|
|
88
|
+
if self.new_dtype is not None:
|
|
89
|
+
if source_space.backend.dtype_is_real_integer(self.new_dtype) and source_space.backend.dtype_is_real_floating(target_data.dtype):
|
|
90
|
+
target_data = source_space.backend.round(target_data)
|
|
91
|
+
target_data = source_space.backend.astype(target_data, self.new_dtype)
|
|
78
92
|
return target_data
|
|
79
93
|
|
|
80
94
|
def direction_inverse(self, source_space = None):
|
|
@@ -95,4 +109,9 @@ class RescaleTransformation(DataTransformation):
|
|
|
95
109
|
new_low=new_low,
|
|
96
110
|
new_high=new_high,
|
|
97
111
|
new_dtype=source_space.dtype
|
|
98
|
-
)
|
|
112
|
+
)
|
|
113
|
+
|
|
114
|
+
def __setstate__(self, state):
|
|
115
|
+
self.__dict__.update(state)
|
|
116
|
+
if not hasattr(self, "nan_to"):
|
|
117
|
+
self.nan_to = None
|
|
@@ -31,7 +31,7 @@ def data_to(
|
|
|
31
31
|
key: data_to(value, source_backend, target_backend, target_device)
|
|
32
32
|
for key, value in data.items()
|
|
33
33
|
}
|
|
34
|
-
elif isinstance(data, Sequence):
|
|
34
|
+
elif isinstance(data, Sequence) and not isinstance(data, (str, bytes)):
|
|
35
35
|
data = [
|
|
36
36
|
data_to(value, source_backend, target_backend, target_device)
|
|
37
37
|
for value in data
|
|
File without changes
|
|
File without changes
|