unienv 0.0.1b5__py3-none-any.whl → 0.0.1b7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (30)
  1. {unienv-0.0.1b5.dist-info → unienv-0.0.1b7.dist-info}/METADATA +3 -2
  2. {unienv-0.0.1b5.dist-info → unienv-0.0.1b7.dist-info}/RECORD +30 -21
  3. {unienv-0.0.1b5.dist-info → unienv-0.0.1b7.dist-info}/WHEEL +1 -1
  4. unienv_data/base/common.py +25 -10
  5. unienv_data/batches/backend_compat.py +1 -1
  6. unienv_data/batches/combined_batch.py +1 -1
  7. unienv_data/replay_buffer/replay_buffer.py +51 -8
  8. unienv_data/storages/_episode_storage.py +438 -0
  9. unienv_data/storages/_list_storage.py +136 -0
  10. unienv_data/storages/backend_compat.py +268 -0
  11. unienv_data/storages/flattened.py +3 -3
  12. unienv_data/storages/hdf5.py +7 -2
  13. unienv_data/storages/image_storage.py +144 -0
  14. unienv_data/storages/npz_storage.py +135 -0
  15. unienv_data/storages/pytorch.py +16 -9
  16. unienv_data/storages/video_storage.py +297 -0
  17. unienv_data/third_party/tensordict/memmap_tensor.py +1174 -0
  18. unienv_data/transformations/image_compress.py +81 -18
  19. unienv_interface/space/space_utils/batch_utils.py +5 -1
  20. unienv_interface/space/spaces/dict.py +6 -0
  21. unienv_interface/transformations/__init__.py +3 -1
  22. unienv_interface/transformations/batch_and_unbatch.py +43 -4
  23. unienv_interface/transformations/chained_transform.py +9 -8
  24. unienv_interface/transformations/crop.py +69 -0
  25. unienv_interface/transformations/dict_transform.py +8 -2
  26. unienv_interface/transformations/identity.py +16 -0
  27. unienv_interface/transformations/rescale.py +24 -5
  28. unienv_interface/wrapper/backend_compat.py +1 -1
  29. {unienv-0.0.1b5.dist-info → unienv-0.0.1b7.dist-info}/licenses/LICENSE +0 -0
  30. {unienv-0.0.1b5.dist-info → unienv-0.0.1b7.dist-info}/top_level.txt +0 -0
unienv_data/transformations/image_compress.py

@@ -6,9 +6,31 @@ from unienv_interface.backends import ComputeBackend, BArrayType, BDeviceType, B
 from PIL import Image
 import numpy as np
 import io
+import math
+
+# https://stackoverflow.com/questions/3471663/jpeg-compression-ratio
+JPEG_QUALITY_COMPRESSION_MAP = {
+    "quality": np.array([55, 60, 65, 70, 75, 80, 85, 90, 95, 100], dtype=int),
+    "compression_ratio": np.array([43.27, 36.90, 34.24, 31.50, 26.00, 25.06, 19.08, 14.30, 9.88, 5.27], dtype=float),
+    "conservative_ratio": 0.6,
+}
+def get_jpeg_compression_ratio(init_quality : int) -> int:
+    if init_quality <= JPEG_QUALITY_COMPRESSION_MAP['quality'][0]:
+        return math.floor(JPEG_QUALITY_COMPRESSION_MAP['compression_ratio'][0] * JPEG_QUALITY_COMPRESSION_MAP['conservative_ratio'])
+    if init_quality >= JPEG_QUALITY_COMPRESSION_MAP['quality'][-1]:
+        return math.floor(JPEG_QUALITY_COMPRESSION_MAP['compression_ratio'][-1] * JPEG_QUALITY_COMPRESSION_MAP['conservative_ratio'])
+
+    for i in range(1, len(JPEG_QUALITY_COMPRESSION_MAP['quality'])):
+        if init_quality <= JPEG_QUALITY_COMPRESSION_MAP['quality'][i]:
+            q_low = JPEG_QUALITY_COMPRESSION_MAP['quality'][i - 1]
+            q_high = JPEG_QUALITY_COMPRESSION_MAP['quality'][i]
+            r_low = JPEG_QUALITY_COMPRESSION_MAP['compression_ratio'][i - 1]
+            r_high = JPEG_QUALITY_COMPRESSION_MAP['compression_ratio'][i]
+            ratio = r_low + (r_high - r_low) * (init_quality - q_low) / (q_high - q_low)
+            return math.floor(ratio * JPEG_QUALITY_COMPRESSION_MAP['conservative_ratio'])
 
 CONSERVATIVE_COMPRESSION_RATIOS = {
-    "JPEG": 10, # https://stackoverflow.com/questions/3471663/jpeg-compression-ratio
+    "JPEG": get_jpeg_compression_ratio,
 }
 
 class ImageCompressTransformation(DataTransformation):
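
Note: the new `get_jpeg_compression_ratio` replaces the old flat ratio of 10 with linear interpolation over the measured quality→ratio table, scaled by the 0.6 safety factor. A quick hand check of the arithmetic (assuming the map and function above are in scope):

    # quality 72 falls between the knots 70 -> 31.50 and 75 -> 26.00:
    #   31.50 + (26.00 - 31.50) * (72 - 70) / (75 - 70) = 29.30
    # conservative estimate: floor(29.30 * 0.6) = 17
    assert get_jpeg_compression_ratio(72) == 17
    assert get_jpeg_compression_ratio(70) == 18    # exact knot: floor(31.50 * 0.6)
    assert get_jpeg_compression_ratio(100) == 3    # clamped at the top knot: floor(5.27 * 0.6)
    assert get_jpeg_compression_ratio(40) == 25    # clamped below q=55: floor(43.27 * 0.6)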
@@ -18,8 +40,10 @@ class ImageCompressTransformation(DataTransformation):
         self,
         init_quality : int = 70,
         max_size_bytes : Optional[int] = None,
+        compression_ratio : Optional[float] = None,
         mode : Optional[str] = None,
         format : str = "JPEG",
+        last_channel : bool = True,
     ) -> None:
         """
         Initialize JPEG compression transformation.

@@ -29,21 +53,25 @@
             mode: Optional mode for PIL Image (e.g., "RGB", "L"). If None, inferred from input.
             format: Image format to use for compression (default "JPEG"). See https://pillow.readthedocs.io/en/stable/handbook/image-file-formats.html for options.
         """
-        assert max_size_bytes is not None or format in CONSERVATIVE_COMPRESSION_RATIOS, "Either max_size_bytes must be specified or format must have a conservative compression ratio defined."
+        assert not (max_size_bytes is not None and compression_ratio is not None), "Specify either max_size_bytes or compression_ratio, not both."
+        assert max_size_bytes is not None or compression_ratio is not None or format in CONSERVATIVE_COMPRESSION_RATIOS, "Either max_size_bytes must be specified or format must have a conservative compression ratio defined."
 
         self.init_quality = init_quality
         self.max_size_bytes = max_size_bytes
-        self.compression_ratio = CONSERVATIVE_COMPRESSION_RATIOS.get(format, None) if max_size_bytes is None else None
+        self.compression_ratio = compression_ratio if compression_ratio is not None else (CONSERVATIVE_COMPRESSION_RATIOS.get(format, None)(init_quality) if max_size_bytes is None else None)
         self.mode = mode
         self.format = format
+        self.last_channel = last_channel
 
-    @staticmethod
-    def validate_source_space(source_space: Space[Any, BDeviceType, BDtypeType, BRNGType]) -> None:
+    def validate_source_space(self, source_space: Space[Any, BDeviceType, BDtypeType, BRNGType]) -> None:
         assert isinstance(source_space, BoxSpace), "JPEGCompressTransformation only supports BoxSpace source spaces."
-        assert len(source_space.shape) >= 3 and (
-            source_space.shape[-1] == 3 or
-            source_space.shape[-1] == 1
-        ), "JPEGCompressTransformation only supports BoxSpace source spaces with shape (..., H, W, 1 or 3)."
+        if not self.last_channel:
+            assert len(source_space.shape) >= 2
+        else:
+            assert len(source_space.shape) >= 3 and (
+                source_space.shape[-1] == 3 or
+                source_space.shape[-1] == 1
+            ), "JPEGCompressTransformation only supports BoxSpace source spaces with shape (..., H, W, 1 or 3)."
 
     @staticmethod
     def get_uint8_dtype(

@@ -52,14 +80,22 @@
         return backend.__array_namespace_info__().dtypes()['uint8']
 
     def _get_max_compressed_size(self, source_space : BoxSpace):
-        H, W, C = source_space.shape[-3], source_space.shape[-2], source_space.shape[-1]
+        if self.last_channel:
+            H, W, C = source_space.shape[-3], source_space.shape[-2], source_space.shape[-1]
+        else:
+            H, W = source_space.shape[-2], source_space.shape[-1]
+            C = 1
         return self.max_size_bytes if self.max_size_bytes is not None else (H * W * C // self.compression_ratio) + 1
 
     def get_target_space_from_source(self, source_space):
         self.validate_source_space(source_space)
 
         max_compressed_size = self._get_max_compressed_size(source_space)
-        new_shape = source_space.shape[:-3] + (max_compressed_size,)
+
+        if not self.last_channel:
+            new_shape = source_space.shape[:-2] + (max_compressed_size,)
+        else:
+            new_shape = source_space.shape[:-3] + (max_compressed_size,)
 
         return BoxSpace(
             source_space.backend,

@@ -108,7 +144,11 @@
 
         max_compressed_size = self._get_max_compressed_size(source_space)
         data_numpy = source_space.backend.to_numpy(data)
-        flat_data_numpy = data_numpy.reshape(-1, *data_numpy.shape[-3:])
+        if not self.last_channel:
+            flat_data_numpy = data_numpy.reshape(-1, *data_numpy.shape[-2:])
+        else:
+            flat_data_numpy = data_numpy.reshape(-1, *data_numpy.shape[-3:])
+
         flat_compressed_data = np.zeros((flat_data_numpy.shape[0], max_compressed_size), dtype=np.uint8)
         for i in range(flat_data_numpy.shape[0]):
             img_array = flat_data_numpy[i]

@@ -119,16 +159,26 @@
             )
             byte_array = np.frombuffer(image_bytes, dtype=np.uint8)
             flat_compressed_data[i, :len(byte_array)] = byte_array
-        compressed_data = flat_compressed_data.reshape(data_numpy.shape[:-3] + (max_compressed_size, ))
+
+        if not self.last_channel:
+            compressed_data = flat_compressed_data.reshape(data_numpy.shape[:-2] + (max_compressed_size, ))
+        else:
+            compressed_data = flat_compressed_data.reshape(data_numpy.shape[:-3] + (max_compressed_size, ))
         compressed_data_backend = source_space.backend.from_numpy(compressed_data, dtype=self.get_uint8_dtype(source_space.backend), device=source_space.device)
         return compressed_data_backend
 
     def direction_inverse(self, source_space = None):
         assert source_space is not None, "Source space must be provided to get inverse transformation."
         self.validate_source_space(source_space)
-        height = source_space.shape[-3]
-        width = source_space.shape[-2]
-        channels = source_space.shape[-1]
+
+        if not self.last_channel:
+            height = source_space.shape[-2]
+            width = source_space.shape[-1]
+            channels = None
+        else:
+            height = source_space.shape[-3]
+            width = source_space.shape[-2]
+            channels = source_space.shape[-1]
         return ImageDecompressTransformation(
             target_height=height,
             target_width=width,

@@ -136,6 +186,12 @@
             mode=self.mode,
             format=self.format,
         )
+
+    def __setstate__(self, state):
+        # for backward compatibility
+        self.__dict__.update(state)
+        if 'last_channel' not in state:
+            self.last_channel = True
 
 class ImageDecompressTransformation(DataTransformation):
     has_inverse = True
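
Note: `__setstate__` is the standard pickle hook, added here so that objects serialized before `last_channel` existed still load with a sensible default; the same idiom appears in `rescale.py` below for `nan_to`. A minimal standalone sketch of the pattern (toy class, not the package API):

    import pickle

    class Config:
        def __init__(self):
            self.quality = 70
            self.last_channel = True

        def __setstate__(self, state):
            # pickle hands us the saved __dict__; fill in fields that
            # pickles from older versions do not contain
            self.__dict__.update(state)
            if 'last_channel' not in state:
                self.last_channel = True

    old = Config()
    del old.last_channel              # simulate a pickle from before the field existed
    restored = pickle.loads(pickle.dumps(old))
    assert restored.last_channel is True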
@@ -144,7 +200,7 @@ class ImageDecompressTransformation(DataTransformation):
         self,
         target_height : int,
         target_width : int,
-        target_channels : int = 3,
+        target_channels : Optional[int] = 3,
         mode : Optional[str] = None,
         format : Optional[str] = None,
     ) -> None:

@@ -174,6 +230,9 @@
     def get_target_space_from_source(self, source_space):
         self.validate_source_space(source_space)
         new_shape = source_space.shape[:-1] + (self.target_height, self.target_width, self.target_channels)
+        if self.target_channels is None:
+            new_shape = new_shape[:-1]
+
         return BoxSpace(
             source_space.backend,
             shape=new_shape,

@@ -212,7 +271,10 @@
                 byte_array.tobytes(),
                 mode=self.mode
             )
-        decompressed_image = flat_decompressed_image.reshape(data_numpy.shape[:-1] + (self.target_height, self.target_width, self.target_channels))
+        if self.target_channels is None:
+            decompressed_image = flat_decompressed_image.reshape(data_numpy.shape[:-1] + (self.target_height, self.target_width))
+        else:
+            decompressed_image = flat_decompressed_image.reshape(data_numpy.shape[:-1] + (self.target_height, self.target_width, self.target_channels))
         decompressed_image_backend = source_space.backend.from_numpy(decompressed_image, dtype=self.get_uint8_dtype(source_space.backend), device=source_space.device)
         return decompressed_image_backend
 

@@ -223,4 +285,5 @@
             max_size_bytes=source_space.shape[-1],
             mode=self.mode,
             format=self.format if self.format is not None else "JPEG",
+            last_channel=self.target_channels is not None,
         )
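
Note: the shape arithmetic in these hunks works because every compressed image occupies a fixed-size, zero-padded byte slot, so the trailing image dimensions collapse into a single `max_compressed_size` axis. A standalone PIL/numpy sketch of that storage scheme (the slot size is hard-coded here; the real code derives it from `max_size_bytes` or the compression ratio):

    import io
    import numpy as np
    from PIL import Image

    # smooth test image (random noise would compress poorly)
    row = np.arange(64, dtype=np.uint8)
    img = np.stack([np.tile(row, (64, 1))] * 3, axis=-1)       # (64, 64, 3)

    buf = io.BytesIO()
    Image.fromarray(img, mode="RGB").save(buf, format="JPEG", quality=70)
    stream = np.frombuffer(buf.getvalue(), dtype=np.uint8)

    slot = np.zeros(4096, dtype=np.uint8)    # fixed-size slot, generous for this sketch
    slot[:len(stream)] = stream

    # Pillow stops at the JPEG end-of-image marker, so the zero
    # padding after the stream is ignored when decoding
    decoded = np.asarray(Image.open(io.BytesIO(slot.tobytes())))
    assert decoded.shape == img.shape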
unienv_interface/space/space_utils/batch_utils.py

@@ -334,7 +334,7 @@ def swap_batch_dims_in_data(
 ) -> Any:
     if backend.is_backendarray(data):
         return _tensor_transpose(backend, data, dim1, dim2)
-    elif isinstance(data, np.ndarray) and data.dtype != object:
+    elif isinstance(data, np.ndarray):
         return _tensor_transpose(NumpyComputeBackend, data, dim1, dim2)
     elif isinstance(data, GraphInstance):
         return GraphInstance(

@@ -734,6 +734,8 @@ def _get_at_tuple(space: TupleSpace, items: typing.Tuple[Any, ...], index : ArrayAPIGetIndex
 
 @get_at.register(BatchedSpace)
 def _get_at_batched(space: BatchedSpace, items: np.ndarray, index: ArrayAPIGetIndex) -> typing.Union[np.ndarray, Any]:
+    if space.backend.is_backendarray(index):
+        index = space.backend.to_numpy(index)
     return items[index]
 
 @singledispatch

@@ -830,6 +832,8 @@ def _set_at_batched(
     value: typing.Union[np.ndarray, Any],
 ) -> np.ndarray:
     new_data = items.copy()
+    if space.backend.is_backendarray(index):
+        index = space.backend.to_numpy(index)
     new_data[index] = value
     return new_data
 
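Note: the two `to_numpy` conversions normalize backend-native index arrays before plain numpy indexing, since `BatchedSpace` keeps its items in a numpy (often object-dtype) array. An illustration of the pattern, assuming a PyTorch backend (illustrative only; CUDA tensors in particular cannot be used directly as numpy indices):

    import numpy as np
    import torch  # assumed backend for this sketch

    items = np.empty(4, dtype=object)          # batched non-array payloads
    items[:] = [{"id": i} for i in range(4)]

    idx = torch.tensor([2, 0])                 # index produced by the compute backend
    idx_np = idx.cpu().numpy()                 # what the added to_numpy() call does
    picked = items[idx_np]                     # plain numpy fancy indexing
    assert picked[0] == {"id": 2}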
unienv_interface/space/spaces/dict.py

@@ -34,6 +34,12 @@ class DictSpace(Space[Dict[str, Any], BDeviceType, BDtypeType, BRNGType]):
         new_spaces: Dict[str, Space[Any, BDeviceType, BDtypeType, BRNGType]] = {}
 
         for key, space in spaces.items():
+            if not isinstance(space, Space) and isinstance(space, Mapping):
+                space = DictSpace(
+                    backend,
+                    space,
+                    device
+                )
             assert isinstance(
                 space, Space
             ), f"Dict space element is not an instance of Space: key='{key}', space={space}"
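
Note: the effect is that nested plain dicts are promoted to `DictSpace` recursively, so callers can pass literal dicts instead of pre-wrapping every level. A toy sketch of the pattern (stand-in classes, not the package API):

    from collections.abc import Mapping

    class Space:                     # stand-in for unienv's Space base class
        pass

    class DictSpace(Space):
        def __init__(self, spaces):
            self.spaces = {}
            for key, space in spaces.items():
                if not isinstance(space, Space) and isinstance(space, Mapping):
                    space = DictSpace(space)   # promote nested plain dicts
                assert isinstance(space, Space)
                self.spaces[key] = space

    leaf = Space()
    root = DictSpace({"obs": {"rgb": leaf}})   # nested literal dict now works
    assert isinstance(root.spaces["obs"], DictSpace)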
unienv_interface/transformations/__init__.py

@@ -1,6 +1,8 @@
 from .transformation import DataTransformation
+from .identity import IdentityTransformation
 from .rescale import RescaleTransformation
 from .filter_dict import DictIncludeKeyTransformation, DictExcludeKeyTransformation
 from .batch_and_unbatch import BatchifyTransformation, UnBatchifyTransformation
 from .dict_transform import DictTransformation
-from .chained_transform import ChainedTransformation
+from .chained_transform import ChainedTransformation
+from .crop import CropTransformation
unienv_interface/transformations/batch_and_unbatch.py

@@ -6,26 +6,65 @@ from unienv_interface.backends import ComputeBackend, BArrayType, BDeviceType, B
 
 class BatchifyTransformation(DataTransformation):
     has_inverse = True
+
+    def __init__(
+        self,
+        axis : int = 0
+    ) -> None:
+        self.axis = axis
 
     def get_target_space_from_source(self, source_space):
-        return sbu.batch_space(source_space, 1)
+        ret = sbu.batch_space(source_space, 1)
+        if self.axis != 0:
+            ret = sbu.swap_batch_dims(
+                ret,
+                0,
+                self.axis
+            )
+        return ret
 
     def transform(self, source_space, data):
         return sbu.concatenate(
             source_space,
-            [data]
+            [data],
+            axis=self.axis
         )
 
     def direction_inverse(self, source_space = None):
-        return UnBatchifyTransformation()
+        return UnBatchifyTransformation(axis=self.axis)
 
 class UnBatchifyTransformation(DataTransformation):
     has_inverse = True
+
+    def __init__(
+        self,
+        axis : int = 0
+    ) -> None:
+        self.axis = axis
 
     def get_target_space_from_source(self, source_space):
+        if self.axis != 0:
+            source_space = sbu.swap_batch_dims(
+                source_space,
+                0,
+                self.axis
+            )
+        assert sbu.batch_size(source_space) == 1, "Cannot unbatch space with batch size > 1"
         return next(iter(sbu.unbatch_spaces(source_space)))
 
     def transform(self, source_space, data):
+        if self.axis != 0:
+            source_space = sbu.swap_batch_dims(
+                source_space,
+                0,
+                self.axis
+            )
+            data = sbu.swap_batch_dims_in_data(
+                source_space.backend,
+                data,
+                0,
+                self.axis
+            )
         return sbu.get_at(
             source_space,
             data,

@@ -33,4 +72,4 @@ class UnBatchifyTransformation(DataTransformation):
         )
 
     def direction_inverse(self, source_space = None):
-        return BatchifyTransformation()
+        return BatchifyTransformation(axis=self.axis)
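
Note: with the new `axis` parameter, the batch dimension is first added at position 0 and then swapped into place, and unbatching swaps it back before taking index 0. A numpy illustration of the shapes involved (the real implementation goes through the space utilities so nested and non-array data are handled too):

    import numpy as np

    x = np.zeros((4, 5))

    batched = np.swapaxes(np.expand_dims(x, 0), 0, 1)   # BatchifyTransformation(axis=1)
    assert batched.shape == (4, 1, 5)

    recovered = np.swapaxes(batched, 0, 1)[0]           # UnBatchifyTransformation(axis=1)
    assert recovered.shape == (4, 5)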
unienv_interface/transformations/chained_transform.py

@@ -45,17 +45,18 @@ class ChainedTransformation(DataTransformation):
     ) -> Optional["ChainedTransformation"]:
         if not self.has_inverse:
             return None
+
+        source_spaces = [source_space]
+        for transformation in self.transformations:
+            next_space = transformation.get_target_space_from_source(source_spaces[-1])
+            source_spaces.append(next_space)
 
-        inverse_mapping = {
-            key: transformation.direction_inverse(source_space)
-            for key, transformation in self.mapping.items()
-        }
+        inverted_transformations = []
+        for i in reversed(range(len(self.transformations))):
+            inverted_transformations.append(self.transformations[i].direction_inverse(source_spaces[i]))
 
         return ChainedTransformation(
-            transformations=[
-                transformation.direction_inverse(source_space)
-                for transformation in reversed(self.transformations)
-            ]
+            inverted_transformations
         )
 
     def close(self):
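
Note: this replaces what appears to be code copied from `DictTransformation` (`self.mapping` does not exist on a chain). The corrected logic propagates the source space forward through every step so that each transformation is inverted against its own input space, then reverses the order. A toy sketch of the idea:

    # toy transformations: each scales by a factor and knows its own inverse
    class Scale:
        def __init__(self, factor):
            self.factor = factor
        def get_target_space_from_source(self, space):
            return space              # spaces pass through unchanged in this toy
        def direction_inverse(self, source_space=None):
            return Scale(1 / self.factor)

    chain = [Scale(2.0), Scale(5.0)]

    # forward pass: record the source space seen by each step
    spaces = [None]
    for t in chain:
        spaces.append(t.get_target_space_from_source(spaces[-1]))

    # invert in reverse order, each against its own source space
    inverse = [chain[i].direction_inverse(spaces[i]) for i in reversed(range(len(chain)))]
    assert [t.factor for t in inverse] == [0.2, 0.5]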
unienv_interface/transformations/crop.py

@@ -0,0 +1,69 @@
+from unienv_interface.space.space_utils import batch_utils as sbu
+from unienv_interface.transformations import DataTransformation
+from unienv_interface.space import Space, BoxSpace, TextSpace
+from typing import Union, Any, Optional, Tuple, List
+from unienv_interface.backends import ComputeBackend, BArrayType, BDeviceType, BDtypeType, BRNGType
+from .identity import IdentityTransformation
+
+class CropTransformation(DataTransformation):
+    has_inverse = True
+
+    def __init__(
+        self,
+        crop_low : Union[int, float, BArrayType],
+        crop_high : Union[int, float, BArrayType],
+    ) -> None:
+        """
+        Initialize Crop Transformation.
+        Args:
+            crop_low: Lower bound for cropping the data.
+            crop_high: Upper bound for cropping the data.
+        """
+        self.crop_low = crop_low
+        self.crop_high = crop_high
+
+    def validate_source_space(self, source_space : Space[Any, BDeviceType, BDtypeType, BRNGType]) -> None:
+        assert isinstance(source_space, BoxSpace), "CropTransformation only supports Box spaces"
+
+    def get_crop_range(self, source_space : BoxSpace[Any, BDeviceType, BDtypeType, BRNGType]) -> Tuple[BArrayType, BArrayType]:
+        new_low = self.crop_low
+        if source_space.backend.is_backendarray(new_low):
+            if len(new_low.shape) < len(source_space.shape):
+                new_low = source_space.backend.reshape(
+                    new_low,
+                    (-1,)*(len(source_space.shape) - len(new_low.shape)) + new_low.shape
+                )
+            new_low = source_space.backend.astype(new_low, source_space.dtype)
+            if source_space.device is not None:
+                new_low = source_space.backend.to_device(new_low, source_space.device)
+        new_high = self.crop_high
+        if source_space.backend.is_backendarray(new_high):
+            if len(new_high.shape) < len(source_space.shape):
+                new_high = source_space.backend.reshape(
+                    new_high,
+                    (-1,)*(len(source_space.shape) - len(new_high.shape)) + new_high.shape
+                )
+            new_high = source_space.backend.astype(new_high, source_space.dtype)
+            if source_space.device is not None:
+                new_high = source_space.backend.to_device(new_high, source_space.device)
+        return new_low, new_high
+
+    def get_target_space_from_source(self, source_space):
+        self.validate_source_space(source_space)
+        crop_low, crop_high = self.get_crop_range(source_space)
+        return BoxSpace(
+            backend=source_space.backend,
+            low=crop_low,
+            high=crop_high,
+            shape=source_space.shape,
+            dtype=source_space.dtype,
+            device=source_space.device,
+        )
+
+    def transform(self, source_space, data):
+        self.validate_source_space(source_space)
+        crop_low, crop_high = self.get_crop_range(source_space)
+        return source_space.backend.clip(data, crop_low, crop_high)
+
+    def direction_inverse(self, source_space = None):
+        return IdentityTransformation()
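
Note: despite the name, `CropTransformation` clamps values into `[crop_low, crop_high]` rather than spatially cropping, and its inverse is the identity because clamping discards information. The array-level effect, in numpy:

    import numpy as np

    data = np.array([-2.0, 0.3, 1.7])
    clipped = np.clip(data, -1.0, 1.0)   # transform(): backend.clip(data, low, high)
    assert np.allclose(clipped, [-1.0, 0.3, 1.0])
    # not invertible: out-of-range values are lost, hence IdentityTransformation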
unienv_interface/transformations/dict_transform.py

@@ -135,12 +135,18 @@ class DictTransformation(DataTransformation):
 
         inverse_mapping = {}
         for key, transformation in self.mapping.items():
-            inverse_mapping[key] = transformation.direction_inverse(
-                None if source_space is None else get_chained_value(
+            if source_space is not None:
+                current_source = get_chained_value(
                     source_space,
                     key.split('/'),
                     ignore_missing_keys=self.ignore_missing_keys
                 )
+                if current_source is None:
+                    continue
+            else:
+                current_source = None
+            inverse_mapping[key] = transformation.direction_inverse(
+                current_source
             )
 
         return DictTransformation(
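
Note: with `ignore_missing_keys`, `get_chained_value` can return None for keys absent from the source space; the rewritten loop skips such keys entirely instead of calling `direction_inverse(None)` on them. A hypothetical standalone sketch of the new control flow:

    def invert_mapping(mapping, source_space, lookup):
        inverse = {}
        for key, transformation in mapping.items():
            if source_space is not None:
                current = lookup(source_space, key)
                if current is None:      # key missing: skip rather than invert None
                    continue
            else:
                current = None
            inverse[key] = transformation.direction_inverse(current)
        return inverse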
unienv_interface/transformations/identity.py

@@ -0,0 +1,16 @@
+from .transformation import DataTransformation, TargetDataT
+from unienv_interface.space import BoxSpace
+from typing import Union, Any, Optional
+from unienv_interface.backends import ComputeBackend, BArrayType, BDeviceType, BDtypeType, BRNGType
+
+class IdentityTransformation(DataTransformation):
+    has_inverse = True
+
+    def get_target_space_from_source(self, source_space):
+        return source_space
+
+    def transform(self, source_space, data):
+        return data
+
+    def direction_inverse(self, source_space = None):
+        return self
unienv_interface/transformations/rescale.py

@@ -11,8 +11,8 @@ def _get_broadcastable_value(
     if isinstance(value, (int, float)):
         return value
     else:
-        assert target_ndim <= len(value), "Value must have at least as many dimensions as target space"
-        target_shape = tuple([1] * (target_ndim - len(value)) + list(value.shape))
+        assert target_ndim >= len(value.shape), "Target space must have at least as many dimensions as the value"
+        target_shape = tuple([1] * (target_ndim - len(value.shape)) + list(value.shape))
         return backend.reshape(value, target_shape)
 
 class RescaleTransformation(DataTransformation):

@@ -22,6 +22,7 @@ class RescaleTransformation(DataTransformation):
         new_low : Union[BArrayType,float] = -1.0,
         new_high : Union[BArrayType,float] = 1.0,
         new_dtype : Optional[BDtypeType] = None,
+        nan_to : Optional[Union[float, int, BArrayType]] = None,
     ):
         # assert isinstance(source_space, BoxSpace), "RescaleTransformation only supports Box action spaces"
         # assert source_space.backend.dtype_is_real_floating(source_space.dtype), "RescaleTransformation only supports real-valued floating spaces"

@@ -31,6 +32,7 @@ class RescaleTransformation(DataTransformation):
         self.new_high = new_high
         self._new_span = new_high - new_low
         self.new_dtype = new_dtype
+        self.nan_to = nan_to
 
     def get_target_space_from_source(self, source_space):
         assert isinstance(source_space, BoxSpace), "RescaleTransformation only supports Box spaces"

@@ -67,14 +69,26 @@ class RescaleTransformation(DataTransformation):
             source_space.backend,
             self.new_low,
             target_ndim
-        ), source_space.backend.device(data))
+        ), source_space.backend.device(data)) if source_space.backend.is_backendarray(self.new_low) else self.new_low
         target_high = source_space.backend.to_device(_get_broadcastable_value(
             source_space.backend,
             self.new_high,
             target_ndim
-        ), source_space.backend.device(data))
+        ), source_space.backend.device(data)) if source_space.backend.is_backendarray(self.new_high) else self.new_high
         scaling_factor = (target_high - target_low) / (source_space._high - source_space._low)
         target_data = (data - source_space._low) * scaling_factor + target_low
+
+        if self.nan_to is not None:
+            target_data = source_space.backend.where(
+                source_space.backend.isnan(target_data),
+                self.nan_to,
+                target_data
+            )
+
+        if self.new_dtype is not None:
+            if source_space.backend.dtype_is_real_integer(self.new_dtype) and source_space.backend.dtype_is_real_floating(target_data.dtype):
+                target_data = source_space.backend.round(target_data)
+            target_data = source_space.backend.astype(target_data, self.new_dtype)
         return target_data
 
     def direction_inverse(self, source_space = None):

@@ -95,4 +109,9 @@ class RescaleTransformation(DataTransformation):
             new_low=new_low,
             new_high=new_high,
             new_dtype=source_space.dtype
-        )
+        )
+
+    def __setstate__(self, state):
+        self.__dict__.update(state)
+        if not hasattr(self, "nan_to"):
+            self.nan_to = None
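
Note: the new post-processing is straightforward to reproduce in numpy: rescale affinely, replace NaNs, then round before casting when the target dtype is integral (a bare `astype` would truncate toward zero):

    import numpy as np

    low, high = 0.0, 1.0                     # source space bounds
    new_low, new_high = 0.0, 255.0           # target bounds

    data = np.array([0.0, 0.5, np.nan, 1.0])
    scaled = (data - low) * (new_high - new_low) / (high - low) + new_low

    scaled = np.where(np.isnan(scaled), 0.0, scaled)    # nan_to=0.0
    result = np.round(scaled).astype(np.uint8)          # round, then cast
    assert result.tolist() == [0, 128, 0, 255]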
unienv_interface/wrapper/backend_compat.py

@@ -31,7 +31,7 @@ def data_to(
             key: data_to(value, source_backend, target_backend, target_device)
             for key, value in data.items()
         }
-    elif isinstance(data, Sequence):
+    elif isinstance(data, Sequence) and not isinstance(data, (str, bytes)):
         data = [
             data_to(value, source_backend, target_backend, target_device)
             for value in data
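
Note: this guards against a classic Python pitfall: `str` and `bytes` are themselves `Sequence`s, so the unguarded branch would recurse into strings character by character, and each character is again a length-1 `str`, so the recursion would never bottom out. A minimal demonstration:

    from collections.abc import Sequence

    assert isinstance("abc", Sequence)        # strings are sequences...
    assert isinstance("abc"[0], Sequence)     # ...and so is each character

    def is_container(data):                   # the corrected check
        return isinstance(data, Sequence) and not isinstance(data, (str, bytes))

    assert is_container([1, 2, 3])
    assert not is_container("abc")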