unienv 0.0.1b4__py3-none-any.whl → 0.0.1b5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -12,6 +12,24 @@ import numpy as np
12
12
  import os
13
13
  import json
14
14
 
15
+
16
+ def _merge_nested_mappings(
17
+ primary: Mapping[str, Any],
18
+ secondary: Mapping[str, Any],
19
+ ) -> Mapping[str, Any]:
20
+ """Merge secondary into primary without clobbering explicitly matched keys."""
21
+ merged: Dict[str, Any] = dict(primary)
22
+ for merge_key, merge_value in secondary.items():
23
+ if (
24
+ merge_key in merged
25
+ and isinstance(merged[merge_key], Mapping)
26
+ and isinstance(merge_value, Mapping)
27
+ ):
28
+ merged[merge_key] = _merge_nested_mappings(merged[merge_key], merge_value)
29
+ elif merge_key not in merged:
30
+ merged[merge_key] = merge_value
31
+ return merged
32
+
15
33
  def map_transform(
16
34
  data : Dict[str, Any],
17
35
  value_map : Dict[str, Any],
@@ -44,7 +62,10 @@ def map_transform(
44
62
  residual_transformed = fn(prefix + "*", residual_data, value_map[prefix + "*"])
45
63
  if isinstance(residual_transformed, Mapping) or isinstance(residual_transformed, DictSpace):
46
64
  for key, value in residual_transformed.items():
47
- transformed_data[key] = value
65
+ if key in transformed_data and isinstance(transformed_data[key], Mapping) and isinstance(value, Mapping):
66
+ transformed_data[key] = _merge_nested_mappings(transformed_data[key], value)
67
+ elif key not in transformed_data:
68
+ transformed_data[key] = value
48
69
  residual_data = {}
49
70
  return transformed_data, residual_data
50
71
 
@@ -52,7 +73,7 @@ def get_chained_residual_space(
52
73
  space : DictSpace[BDeviceType, BDtypeType, BRNGType],
53
74
  all_keys : List[str],
54
75
  prefix : str = "",
55
- ) -> DictSpace[BDeviceType, BDtypeType, BRNGType]:
76
+ ) -> Optional[DictSpace[BDeviceType, BDtypeType, BRNGType]]:
56
77
  residual_spaces = {}
57
78
 
58
79
  if len(residual_spaces) > 0 and (prefix + "*") in all_keys:
@@ -72,10 +93,13 @@ def get_chained_residual_space(
72
93
  all_keys,
73
94
  prefix=full_key + "/",
74
95
  )
75
- if len(sub_residual.spaces) > 0:
96
+ if sub_residual is not None and len(sub_residual.spaces) > 0:
76
97
  residual_spaces[key] = sub_residual
77
98
  else:
78
99
  residual_spaces[key] = subspace
100
+
101
+ if len(residual_spaces) == 0:
102
+ return None
79
103
 
80
104
  return DictSpace(
81
105
  space.backend,
@@ -87,7 +111,7 @@ def get_chained_space(
87
111
  space : DictSpace[BDeviceType, BDtypeType, BRNGType],
88
112
  key_chain : str,
89
113
  all_keys : List[str],
90
- ) -> Space[Any, BDeviceType, BDtypeType, BRNGType]:
114
+ ) -> Optional[Space[Any, BDeviceType, BDtypeType, BRNGType]]:
91
115
  if key_chain.endswith("*"):
92
116
  prefix = key_chain[:-1]
93
117
  subspace = get_chained_residual_space(
@@ -106,8 +130,8 @@ def get_chained_space(
106
130
  for key in key_chain:
107
131
  if len(key) == 0:
108
132
  continue
109
- assert isinstance(current_space, DictSpace), \
110
- f"Expected DictSpace while traversing key chain, but got {type(current_space)}"
133
+ if not isinstance(current_space, DictSpace) or key not in current_space.spaces:
134
+ return None
111
135
  current_space = current_space.spaces[key]
112
136
  return current_space
113
137
 
@@ -130,6 +154,7 @@ class DictStorage(SpaceStorage[
130
154
  *args,
131
155
  capacity : Optional[int] = None,
132
156
  cache_path : Optional[str] = None,
157
+ multiprocessing : bool = False,
133
158
  key_kwargs : Dict[str, Any] = {},
134
159
  type_kwargs : Dict[Type[SpaceStorage[Any, BArrayType, BDeviceType, BDtypeType, BRNGType]], Dict[str, Any]] = {},
135
160
  **kwargs
@@ -142,6 +167,8 @@ class DictStorage(SpaceStorage[
142
167
  for key, sub_storage_cls in storage_cls_map.items():
143
168
  sub_storage_path = key.replace("/", ".").replace("*", "_default") + (sub_storage_cls.single_file_ext or "")
144
169
  subspace = get_chained_space(single_instance_space, key, all_keys)
170
+ if subspace is None:
171
+ continue
145
172
  sub_kwargs = kwargs.copy()
146
173
  if sub_storage_cls in type_kwargs:
147
174
  sub_kwargs.update(type_kwargs[sub_storage_cls])
@@ -152,6 +179,7 @@ class DictStorage(SpaceStorage[
152
179
  *args,
153
180
  cache_path=None if cache_path is None else os.path.join(cache_path, sub_storage_path),
154
181
  capacity=capacity,
182
+ multiprocessing=multiprocessing,
155
183
  **sub_kwargs
156
184
  )
157
185
 
@@ -169,6 +197,7 @@ class DictStorage(SpaceStorage[
169
197
  *,
170
198
  capacity : Optional[int] = None,
171
199
  read_only : bool = True,
200
+ multiprocessing : bool = False,
172
201
  key_kwargs : Dict[str, Any] = {},
173
202
  type_kwargs : Dict[Type[SpaceStorage[Any, BArrayType, BDeviceType, BDtypeType, BRNGType]], Dict[str, Any]] = {},
174
203
  **kwargs
@@ -189,7 +218,9 @@ class DictStorage(SpaceStorage[
189
218
  storage_path = storage_meta["path"]
190
219
 
191
220
  subspace = get_chained_space(single_instance_space, key, all_keys)
192
-
221
+ if subspace is None:
222
+ continue
223
+
193
224
  sub_kwargs = kwargs.copy()
194
225
  if storage_cls in type_kwargs:
195
226
  sub_kwargs.update(type_kwargs[storage_cls])
@@ -200,6 +231,7 @@ class DictStorage(SpaceStorage[
200
231
  subspace,
201
232
  capacity=capacity,
202
233
  read_only=read_only,
234
+ multiprocessing=multiprocessing,
203
235
  **sub_kwargs
204
236
  )
205
237
 
@@ -28,6 +28,8 @@ class FlattenedStorage(SpaceStorage[
28
28
  *args,
29
29
  capacity : Optional[int] = None,
30
30
  cache_path : Optional[str] = None,
31
+ multiprocessing : bool = False,
32
+ inner_storage_kwargs : Dict[str, Any] = {},
31
33
  **kwargs
32
34
  ) -> "FlattenedStorage[BatchT, BArrayType, BDeviceType, BDtypeType, BRNGType]":
33
35
  flattened_space = sfu.flatten_space(single_instance_space)
@@ -36,12 +38,15 @@ class FlattenedStorage(SpaceStorage[
36
38
  if cache_path is not None:
37
39
  os.makedirs(cache_path, exist_ok=True)
38
40
 
41
+ _inner_storage_kwargs = kwargs.copy()
42
+ _inner_storage_kwargs.update(inner_storage_kwargs)
39
43
  inner_storage = inner_storage_cls.create(
40
44
  flattened_space,
41
45
  *args,
42
46
  cache_path=None if cache_path is None else os.path.join(cache_path, inner_storage_path),
43
47
  capacity=capacity,
44
- **kwargs
48
+ multiprocessing=multiprocessing,
49
+ **_inner_storage_kwargs
45
50
  )
46
51
  return FlattenedStorage(
47
52
  single_instance_space,
@@ -58,6 +63,7 @@ class FlattenedStorage(SpaceStorage[
58
63
  *,
59
64
  capacity : Optional[int] = None,
60
65
  read_only : bool = True,
66
+ multiprocessing : bool = False,
61
67
  **kwargs
62
68
  ) -> "FlattenedStorage[BatchT, BArrayType, BDeviceType, BDtypeType, BRNGType]":
63
69
  metadata_path = os.path.join(path, "flattened_metadata.json")
@@ -74,6 +80,7 @@ class FlattenedStorage(SpaceStorage[
74
80
  flattened_space,
75
81
  capacity=capacity,
76
82
  read_only=read_only,
83
+ multiprocessing=multiprocessing,
77
84
  **kwargs
78
85
  )
79
86
  return FlattenedStorage(
@@ -458,6 +458,7 @@ class HDF5Storage(SpaceStorage[
458
458
  single_instance_space,
459
459
  capacity,
460
460
  cache_path = None,
461
+ multiprocessing : bool = False,
461
462
  initial_capacity : Optional[int] = None,
462
463
  compression : Union[
463
464
  Dict[str, Any],
@@ -476,6 +477,8 @@ class HDF5Storage(SpaceStorage[
476
477
  ) -> "HDF5Storage":
477
478
  assert cache_path is not None, \
478
479
  "cache_path must be provided for HDF5Storage"
480
+ assert not multiprocessing, \
481
+ "HDF5Storage does not support multiprocessing safe creation. Please create the storage in the main process and then load it in child processes."
479
482
  root = h5py.File(
480
483
  cache_path,
481
484
  "w",
@@ -506,9 +509,12 @@ class HDF5Storage(SpaceStorage[
506
509
  *,
507
510
  capacity = None,
508
511
  read_only = True,
512
+ multiprocessing : bool = False,
509
513
  reduce_io : bool = True,
510
514
  **kwargs
511
515
  ) -> "HDF5Storage":
516
+ assert not multiprocessing, \
517
+ "HDF5Storage does not support multiprocessing safe loading. Please load the storage in the main process and then share it with child processes."
512
518
  assert os.path.exists(path), \
513
519
  f"Path {path} does not exist"
514
520
 
@@ -22,8 +22,8 @@ class PytorchTensorStorage(SpaceStorage[
22
22
  capacity : Optional[int],
23
23
  is_memmap : bool = False,
24
24
  cache_path : Optional[str] = None,
25
- memmap_existok : bool = True,
26
25
  multiprocessing : bool = False,
26
+ memmap_existok : bool = True,
27
27
  ) -> "PytorchTensorStorage":
28
28
  assert single_instance_space.backend is PyTorchComputeBackend, \
29
29
  f"Single instance space must be of type PyTorchComputeBackend, got {single_instance_space.backend}"
@@ -31,6 +31,8 @@ class TransformedStorage(SpaceStorage[
31
31
  data_transformation : DataTransformation,
32
32
  capacity : Optional[int] = None,
33
33
  cache_path : Optional[str] = None,
34
+ multiprocessing : bool = False,
35
+ inner_storage_kwargs : Dict[str, Any] = {},
34
36
  **kwargs
35
37
  ) -> "TransformedStorage[BArrayType, BDeviceType, BDtypeType, BRNGType]":
36
38
  assert data_transformation.has_inverse, "To transform storages (potentially to save space), you need to use inversible data transformations"
@@ -40,12 +42,15 @@ class TransformedStorage(SpaceStorage[
40
42
  if cache_path is not None:
41
43
  os.makedirs(cache_path, exist_ok=True)
42
44
 
45
+ _inner_storage_kwargs = kwargs.copy()
46
+ _inner_storage_kwargs.update(inner_storage_kwargs)
43
47
  inner_storage = inner_storage_cls.create(
44
48
  transformed_space,
45
49
  *args,
46
50
  cache_path=None if cache_path is None else os.path.join(cache_path, inner_storage_path),
47
51
  capacity=capacity,
48
- **kwargs
52
+ multiprocessing=multiprocessing,
53
+ **_inner_storage_kwargs
49
54
  )
50
55
  return TransformedStorage(
51
56
  single_instance_space,
@@ -62,6 +67,7 @@ class TransformedStorage(SpaceStorage[
62
67
  *,
63
68
  capacity : Optional[int] = None,
64
69
  read_only : bool = True,
70
+ multiprocessing : bool = False,
65
71
  **kwargs
66
72
  ) -> "TransformedStorage[BArrayType, BDeviceType, BDtypeType, BRNGType]":
67
73
  metadata_path = os.path.join(path, "transformed_metadata.json")
@@ -85,6 +91,7 @@ class TransformedStorage(SpaceStorage[
85
91
  transformed_space,
86
92
  capacity=capacity,
87
93
  read_only=read_only,
94
+ multiprocessing=multiprocessing,
88
95
  **kwargs
89
96
  )
90
97
  return TransformedStorage(
@@ -140,6 +147,14 @@ class TransformedStorage(SpaceStorage[
140
147
  def __len__(self):
141
148
  return len(self.inner_storage)
142
149
 
150
+ @property
151
+ def is_mutable(self) -> bool:
152
+ return self.inner_storage.is_mutable
153
+
154
+ @property
155
+ def is_multiprocessing_safe(self) -> bool:
156
+ return self.inner_storage.is_multiprocessing_safe
157
+
143
158
  def get_flattened(self, index):
144
159
  dat = self.get(index)
145
160
  if isinstance(index, int):
@@ -7,13 +7,17 @@ from PIL import Image
7
7
  import numpy as np
8
8
  import io
9
9
 
10
+ CONSERVATIVE_COMPRESSION_RATIOS = {
11
+ "JPEG": 10, # https://stackoverflow.com/questions/3471663/jpeg-compression-ratio
12
+ }
13
+
10
14
  class ImageCompressTransformation(DataTransformation):
11
15
  has_inverse = True
12
16
 
13
17
  def __init__(
14
18
  self,
15
- init_quality : int = 75,
16
- max_size_bytes : int = 65536,
19
+ init_quality : int = 70,
20
+ max_size_bytes : Optional[int] = None,
17
21
  mode : Optional[str] = None,
18
22
  format : str = "JPEG",
19
23
  ) -> None:
@@ -25,9 +29,11 @@ class ImageCompressTransformation(DataTransformation):
25
29
  mode: Optional mode for PIL Image (e.g., "RGB", "L"). If None, inferred from input.
26
30
  format: Image format to use for compression (default "JPEG"). See https://pillow.readthedocs.io/en/stable/handbook/image-file-formats.html for options.
27
31
  """
32
+ assert max_size_bytes is not None or format in CONSERVATIVE_COMPRESSION_RATIOS, "Either max_size_bytes must be specified or format must have a conservative compression ratio defined."
28
33
 
29
34
  self.init_quality = init_quality
30
35
  self.max_size_bytes = max_size_bytes
36
+ self.compression_ratio = CONSERVATIVE_COMPRESSION_RATIOS.get(format, None) if max_size_bytes is None else None
31
37
  self.mode = mode
32
38
  self.format = format
33
39
 
@@ -45,9 +51,15 @@ class ImageCompressTransformation(DataTransformation):
45
51
  ) -> BDtypeType:
46
52
  return backend.__array_namespace_info__().dtypes()['uint8']
47
53
 
54
+ def _get_max_compressed_size(self, source_space : BoxSpace):
55
+ H, W, C = source_space.shape[-3], source_space.shape[-2], source_space.shape[-1]
56
+ return self.max_size_bytes if self.max_size_bytes is not None else (H * W * C // self.compression_ratio) + 1
57
+
48
58
  def get_target_space_from_source(self, source_space):
49
59
  self.validate_source_space(source_space)
50
- new_shape = source_space.shape[:-3] + (self.max_size_bytes,)
60
+
61
+ max_compressed_size = self._get_max_compressed_size(source_space)
62
+ new_shape = source_space.shape[:-3] + (max_compressed_size,)
51
63
 
52
64
  return BoxSpace(
53
65
  source_space.backend,
@@ -78,14 +90,14 @@ class ImageCompressTransformation(DataTransformation):
78
90
  # Create PIL Image (mode inferred automatically)
79
91
  img = Image.fromarray(img_array, mode=mode)
80
92
 
81
- quality = 95
93
+ quality = self.init_quality
82
94
  while quality >= min_quality:
83
95
  buf = io.BytesIO()
84
96
  img.save(buf, format=self.format, quality=quality)
85
97
  image_bytes = buf.getvalue()
86
98
  if len(image_bytes) <= max_bytes:
87
99
  return image_bytes, quality
88
- quality -= 5
100
+ quality -= 10
89
101
 
90
102
  img.close()
91
103
  # Return lowest quality attempt if still too large
@@ -93,19 +105,21 @@ class ImageCompressTransformation(DataTransformation):
93
105
 
94
106
  def transform(self, source_space, data):
95
107
  self.validate_source_space(source_space)
108
+
109
+ max_compressed_size = self._get_max_compressed_size(source_space)
96
110
  data_numpy = source_space.backend.to_numpy(data)
97
111
  flat_data_numpy = data_numpy.reshape(-1, *data_numpy.shape[-3:])
98
- flat_compressed_data = np.zeros((flat_data_numpy.shape[0], self.max_size_bytes), dtype=np.uint8)
112
+ flat_compressed_data = np.zeros((flat_data_numpy.shape[0], max_compressed_size), dtype=np.uint8)
99
113
  for i in range(flat_data_numpy.shape[0]):
100
114
  img_array = flat_data_numpy[i]
101
115
  image_bytes, _ = self.encode_to_size(
102
116
  img_array,
103
- self.max_size_bytes,
117
+ max_compressed_size,
104
118
  mode=self.mode
105
119
  )
106
120
  byte_array = np.frombuffer(image_bytes, dtype=np.uint8)
107
121
  flat_compressed_data[i, :len(byte_array)] = byte_array
108
- compressed_data = flat_compressed_data.reshape(data_numpy.shape[:-3] + (self.max_size_bytes, ))
122
+ compressed_data = flat_compressed_data.reshape(data_numpy.shape[:-3] + (max_compressed_size, ))
109
123
  compressed_data_backend = source_space.backend.from_numpy(compressed_data, dtype=self.get_uint8_dtype(source_space.backend), device=source_space.device)
110
124
  return compressed_data_backend
111
125
 
@@ -206,7 +220,6 @@ class ImageDecompressTransformation(DataTransformation):
206
220
  assert source_space is not None, "Source space must be provided to get inverse transformation."
207
221
  self.validate_source_space(source_space)
208
222
  return ImageCompressTransformation(
209
- init_quality=75,
210
223
  max_size_bytes=source_space.shape[-1],
211
224
  mode=self.mode,
212
225
  format=self.format if self.format is not None else "JPEG",
@@ -8,7 +8,7 @@ from unienv_interface.utils import seed_util
8
8
  from unienv_interface.env_base.funcenv import FuncEnv, ContextType, ObsType, ActType, RenderFrame, StateType, RenderStateType
9
9
  from unienv_interface.env_base.funcenv_wrapper import *
10
10
  from unienv_interface.space import Space
11
- from unienv_interface.utils.data_queue import FuncSpaceDataQueue, SpaceDataQueueState
11
+ from unienv_interface.utils.framestack_queue import FuncSpaceDataQueue, SpaceDataQueueState
12
12
  from unienv_interface.utils.stateclass import StateClass, field
13
13
 
14
14
  class FuncFrameStackWrapperState(
@@ -192,14 +192,20 @@ def unflatten_data(space : Space, data : BArrayType, start_dim : int = 0) -> Any
192
192
  @flatten_data.register(BinarySpace)
193
193
  def _flatten_data_common(space: typing.Union[BoxSpace, BinarySpace], data: BArrayType, start_dim : int = 0) -> BArrayType:
194
194
  assert -len(space.shape) <= start_dim <= len(space.shape)
195
- return space.backend.reshape(data, data.shape[:start_dim] + (-1,))
195
+ dat = space.backend.reshape(data, data.shape[:start_dim] + (-1,))
196
+ if isinstance(space, BinarySpace):
197
+ dat = space.backend.astype(dat, space.backend.default_integer_dtype)
198
+ return dat
196
199
 
197
200
# Shared unflatten implementation registered for both BoxSpace and
# BinarySpace: the reshape logic is identical, only the dtype cast differs.
@unflatten_data.register(BoxSpace)
@unflatten_data.register(BinarySpace)
def _unflatten_data_common(space: typing.Union[BoxSpace, BinarySpace], data: Any, start_dim : int = 0) -> BArrayType:
    """Restore the flattened trailing dimensions of ``data`` to ``space.shape``.

    Args:
        space: Box or Binary space describing the target (unflattened) layout.
        data: Array whose dimensions from ``start_dim`` onward were flattened.
        start_dim: First dimension to unflatten; may be negative.

    Returns:
        The reshaped array, cast back to the space's dtype.
    """
    assert -len(space.shape) <= start_dim <= len(space.shape)
    # Keep the leading (batch) dims, re-expand everything after start_dim.
    unflat_dat = space.backend.reshape(data, data.shape[:start_dim] + space.shape[start_dim:])
    if isinstance(space, BinarySpace):
        # BinarySpace may carry dtype=None; fall back to the backend's
        # default boolean dtype in that case — TODO confirm dtype can
        # legitimately be None for BinarySpace.
        unflat_dat = space.backend.astype(unflat_dat, space.dtype if space.dtype is not None else space.backend.default_boolean_dtype)
    else:
        unflat_dat = space.backend.astype(unflat_dat, space.dtype)
    return unflat_dat
204
210
 
205
211
  @flatten_data.register(DynamicBoxSpace)
@@ -40,7 +40,7 @@ class TupleSpace(Space[Tuple[Any, ...], BDeviceType, BDtypeType, BRNGType]):
40
40
  return self
41
41
 
42
42
  new_device = device if backend is not None else (device or self.device)
43
- return Tuple(
43
+ return TupleSpace(
44
44
  backend=backend or self.backend,
45
45
  spaces=[space.to(backend, new_device) for space in self.spaces],
46
46
  device=new_device
@@ -93,11 +93,11 @@ class TupleSpace(Space[Tuple[Any, ...], BDeviceType, BDtypeType, BRNGType]):
93
93
 
94
94
  def __eq__(self, other: Any) -> bool:
95
95
  """Check whether ``other`` is equivalent to this instance."""
96
- return isinstance(other, Tuple) and self.spaces == other.spaces
96
+ return isinstance(other, TupleSpace) and self.spaces == other.spaces
97
97
 
98
- def __copy__(self) -> "Tuple[BDeviceType, BDtypeType, BRNGType]":
98
+ def __copy__(self) -> "TupleSpace[BDeviceType, BDtypeType, BRNGType]":
99
99
  """Create a shallow copy of the Dict space."""
100
- return Tuple(
100
+ return TupleSpace(
101
101
  backend=self.backend,
102
102
  spaces=copy.copy(self.spaces),
103
103
  device=self.device
@@ -0,0 +1,106 @@
1
+ from unienv_interface.space.space_utils import batch_utils as sbu
2
+ from .transformation import DataTransformation, TargetDataT
3
+ from unienv_interface.space import Space, BoxSpace
4
+ from typing import Union, Any, Optional
5
+ from unienv_interface.backends import ComputeBackend, BArrayType, BDeviceType, BDtypeType, BRNGType
6
+
7
class ImageResizeTransformation(DataTransformation):
    """Resize the trailing (H, W, C) image dimensions of a BoxSpace.

    Bilinear resampling is used on every supported backend (jax, pytorch,
    numpy/cv2). The inverse transformation resizes back to the original
    spatial shape; the round trip is lossy.
    """

    # Resizing back to the source resolution is always possible (lossy).
    has_inverse = True

    def __init__(
        self,
        new_height: int,
        new_width: int
    ):
        # Target spatial size applied to the H (dim -3) and W (dim -2)
        # axes; the channel axis (dim -1) is left untouched.
        self.new_height = new_height
        self.new_width = new_width

    def _validate_source_space(self, source_space : Space[Any, BDeviceType, BDtypeType, BRNGType]) -> BoxSpace[BArrayType, BDeviceType, BDtypeType, BRNGType]:
        """Assert the space is an image-shaped BoxSpace and return it."""
        assert isinstance(source_space, BoxSpace), \
            f"ImageResizeTransformation only supports BoxSpace, got {type(source_space)}"
        assert len(source_space.shape) >= 3, \
            f"ImageResizeTransformation only supports spaces with at least 3 dimensions (H, W, C), got shape {source_space.shape}"
        assert source_space.shape[-3] > 0 and source_space.shape[-2] > 0, \
            f"ImageResizeTransformation requires positive height and width, got shape {source_space.shape}"
        return source_space

    def get_target_space_from_source(self, source_space):
        """Build the resized BoxSpace corresponding to ``source_space``."""
        source_space = self._validate_source_space(source_space)

        backend = source_space.backend
        new_shape = (
            *source_space.shape[:-3],
            self.new_height,
            self.new_width,
            source_space.shape[-1]
        )
        # Collapse per-pixel bounds over the spatial axes so the bounds
        # remain valid regardless of how pixels are resampled. Using min
        # for low / max for high keeps the bounds conservative.
        new_low = backend.min(source_space.low, axis=(-3, -2), keepdims=True)
        new_high = backend.max(source_space.high, axis=(-3, -2), keepdims=True)

        return BoxSpace(
            source_space.backend,
            new_low,
            new_high,
            dtype=source_space.dtype,
            device=source_space.device,
            shape=new_shape
        )

    def transform(self, source_space, data):
        """Resize ``data`` to (new_height, new_width) using the space's backend."""
        source_space = self._validate_source_space(source_space)
        backend = source_space.backend
        if backend.simplified_name == "jax":
            target_shape = (
                *data.shape[:-3],
                self.new_height,
                self.new_width,
                source_space.shape[-1]
            )
            # Imported lazily so jax is only required when actually used.
            import jax.image
            resized_data = jax.image.resize(
                data,
                shape=target_shape,
                method='bilinear',
                antialias=True
            )
        elif backend.simplified_name == "pytorch":
            import torch.nn.functional as F
            # PyTorch expects (B, C, H, W)
            # NOTE(review): F.interpolate with mode='bilinear' requires 4D
            # input — confirm callers always supply exactly one leading
            # batch dimension here.
            data_permuted = backend.permute_dims(data, (*range(len(data.shape[:-3])), -1, -3, -2))
            resized_data_permuted = F.interpolate(
                data_permuted,
                size=(self.new_height, self.new_width),
                mode='bilinear',
                align_corners=False,
                antialias=True
            )
            # Permute back to original shape
            resized_data = backend.permute_dims(resized_data_permuted, (*range(len(resized_data_permuted.shape[:-3])), -2, -1, -3))
        elif backend.simplified_name == "numpy":
            import cv2
            # cv2.resize handles one image at a time, so flatten all
            # leading batch dims and loop over individual images.
            flat_data = backend.reshape(data, (-1, *source_space.shape[-3:]))
            resized_flat_data = []
            for i in range(flat_data.shape[0]):
                img = flat_data[i]
                # cv2 takes size as (width, height), not (height, width).
                resized_img = cv2.resize(
                    img,
                    (self.new_width, self.new_height),
                    interpolation=cv2.INTER_LINEAR
                )
                resized_flat_data.append(resized_img)
            resized_flat_data = backend.stack(resized_flat_data, axis=0)
            # NOTE(review): for C == 1, cv2.resize returns (H, W); the
            # reshape below still succeeds because element counts match —
            # confirm this is the intended path for single-channel images.
            resized_data = backend.reshape(
                resized_flat_data,
                (*data.shape[:-3], self.new_height, self.new_width, source_space.shape[-1])
            )
        else:
            raise ValueError(f"Unsupported backend: {backend.simplified_name}")
        return resized_data

    def direction_inverse(self, source_space = None):
        """Return a resize back to the source's original (H, W).

        Raises:
            AssertionError: if ``source_space`` is not provided or invalid.
        """
        assert source_space is not None, "Inverse transformation requires source_space"
        source_space = self._validate_source_space(source_space)
        return ImageResizeTransformation(
            new_height=source_space.shape[-3],
            new_width=source_space.shape[-2]
        )
@@ -0,0 +1,92 @@
1
+ from typing import Union, Any, Optional, Mapping, List, Callable, Dict
2
+
3
+ from unienv_interface.space.space_utils import batch_utils as sbu
4
+ from unienv_interface.space import Space, DictSpace, TupleSpace
5
+ from unienv_interface.backends import ComputeBackend, BArrayType, BDeviceType, BDtypeType, BRNGType
6
+
7
+ import copy
8
+ from .transformation import DataTransformation, TargetDataT
9
+
10
def default_is_leaf_fn(space : Space[Any, BDeviceType, BDtypeType, BRNGType]):
    """Default leaf predicate: any space that is not a Dict/Tuple container."""
    is_container = isinstance(space, (DictSpace, TupleSpace))
    return not is_container
12
+
13
class IterativeTransformation(DataTransformation):
    """Applies a single leaf-level transformation across a nested space tree.

    Container spaces (Dict/Tuple by default) are traversed recursively; the
    wrapped transformation is applied at every node the leaf predicate
    accepts.
    """

    def __init__(
        self,
        transformation: DataTransformation,
        is_leaf_node_fn: Callable[[Space[Any, BDeviceType, BDtypeType, BRNGType]], bool] = default_is_leaf_fn,
        inv_is_leaf_node_fn: Callable[[Space[Any, BDeviceType, BDtypeType, BRNGType]], bool] = default_is_leaf_fn
    ):
        # Leaf predicates for the forward and the inverse direction; the
        # two are swapped when building the inverse transformation.
        self.transformation = transformation
        self.is_leaf_node_fn = is_leaf_node_fn
        self.inv_is_leaf_node_fn = inv_is_leaf_node_fn
        # Invertibility is inherited from the wrapped transformation.
        self.has_inverse = transformation.has_inverse

    def get_target_space_from_source(
        self,
        source_space : Space[Any, BDeviceType, BDtypeType, BRNGType]
    ):
        """Recursively map the source space tree onto its transformed layout."""
        if self.is_leaf_node_fn(source_space):
            return self.transformation.get_target_space_from_source(source_space)
        if isinstance(source_space, DictSpace):
            converted = {
                name: self.get_target_space_from_source(child)
                for name, child in source_space.spaces.items()
            }
            # Inherit backend/device from the first converted child when
            # one exists; otherwise keep the container's own settings.
            if converted:
                probe = next(iter(converted.values()))
                backend, device = probe.backend, probe.device
            else:
                backend, device = source_space.backend, source_space.device
            return DictSpace(
                backend,
                converted,
                device=device
            )
        if isinstance(source_space, TupleSpace):
            converted = tuple(
                self.get_target_space_from_source(child)
                for child in source_space.spaces
            )
            if converted:
                backend, device = converted[0].backend, converted[0].device
            else:
                backend, device = source_space.backend, source_space.device
            return TupleSpace(
                backend,
                converted,
                device=device
            )
        raise ValueError(f"Unsupported space type: {type(source_space)}")

    def transform(
        self,
        source_space: Space,
        data: Union[Mapping[str, Any], BArrayType]
    ) -> Union[Mapping[str, Any], BArrayType]:
        """Apply the wrapped transformation to every leaf of ``data``."""
        if self.is_leaf_node_fn(source_space):
            return self.transformation.transform(source_space, data)
        if isinstance(source_space, DictSpace):
            return {
                name: self.transform(child, data[name])
                for name, child in source_space.spaces.items()
            }
        if isinstance(source_space, TupleSpace):
            return tuple(
                self.transform(child, data[idx])
                for idx, child in enumerate(source_space.spaces)
            )
        raise ValueError(f"Unsupported space type: {type(source_space)}")

    def direction_inverse(
        self,
        source_space = None,
    ) -> Optional["IterativeTransformation"]:
        """Return the leaf-wise inverse, or None when no inverse exists."""
        if not self.has_inverse:
            return None
        # NOTE(review): source_space is not forwarded to the wrapped
        # transformation's direction_inverse — confirm leaf inverses never
        # require it on this code path.
        return IterativeTransformation(
            self.transformation.direction_inverse(),
            is_leaf_node_fn=self.inv_is_leaf_node_fn,
            inv_is_leaf_node_fn=self.is_leaf_node_fn
        )

    def close(self):
        """Release any resources held by the wrapped transformation."""
        self.transformation.close()
@@ -5,9 +5,15 @@ __all__ = [
5
5
  "get_class_from_full_name",
6
6
  ]
7
7
 
8
# Fully-qualified class names that moved between releases are remapped here
# so metadata serialized by older versions keeps resolving.
REMAP = {
    "unienv_data.storages.common.FlattenedStorage": "unienv_data.storages.flattened.FlattenedStorage",
}

def get_full_class_name(cls : Type) -> str:
    """Return the fully qualified ``module.QualName`` identifier for ``cls``."""
    return f"{cls.__module__}.{cls.__qualname__}"

def get_class_from_full_name(full_name : str) -> Type:
    """Resolve a fully qualified class name back to the class object.

    Names listed in ``REMAP`` are translated first, so identifiers written
    by older releases still resolve after a module move.

    Raises:
        ImportError: if the module cannot be imported.
        AttributeError: if the class is missing from the module.
    """
    # Local import keeps this utility module free of extra top-level deps.
    import importlib
    full_name = REMAP.get(full_name, full_name)
    module_name, class_name = full_name.rsplit(".", 1)
    # importlib.import_module is the documented replacement for calling
    # __import__ directly (per the __import__ docs).
    module = importlib.import_module(module_name)
    return getattr(module, class_name)
@@ -7,7 +7,7 @@ from unienv_interface.space.space_utils import batch_utils as sbu
7
7
  from unienv_interface.env_base.env import Env, ContextType, ObsType, ActType, RenderFrame, BArrayType, BDeviceType, BDtypeType, BRNGType
8
8
  from unienv_interface.env_base.wrapper import ContextObservationWrapper, ActionWrapper, WrapperContextT, WrapperObsT, WrapperActT
9
9
  from unienv_interface.space import Space, DictSpace
10
- from unienv_interface.utils.data_queue import SpaceDataQueue
10
+ from unienv_interface.utils.framestack_queue import SpaceDataQueue
11
11
 
12
12
  class FrameStackWrapper(
13
13
  ContextObservationWrapper[