unienv 0.0.1b1__py3-none-any.whl → 0.0.1b3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (59)
  1. unienv-0.0.1b3.dist-info/METADATA +74 -0
  2. unienv-0.0.1b3.dist-info/RECORD +92 -0
  3. {unienv-0.0.1b1.dist-info → unienv-0.0.1b3.dist-info}/licenses/LICENSE +1 -3
  4. unienv-0.0.1b3.dist-info/top_level.txt +2 -0
  5. unienv_data/base/__init__.py +0 -1
  6. unienv_data/base/common.py +95 -45
  7. unienv_data/base/storage.py +1 -0
  8. unienv_data/batches/__init__.py +2 -1
  9. unienv_data/batches/backend_compat.py +47 -1
  10. unienv_data/batches/combined_batch.py +2 -4
  11. unienv_data/{base → batches}/transformations.py +3 -2
  12. unienv_data/replay_buffer/replay_buffer.py +4 -0
  13. unienv_data/samplers/__init__.py +0 -1
  14. unienv_data/samplers/multiprocessing_sampler.py +26 -22
  15. unienv_data/samplers/step_sampler.py +9 -18
  16. unienv_data/storages/common.py +5 -0
  17. unienv_data/storages/hdf5.py +291 -20
  18. unienv_data/storages/pytorch.py +1 -0
  19. unienv_data/storages/transformation.py +191 -0
  20. unienv_data/transformations/image_compress.py +213 -0
  21. unienv_interface/backends/jax.py +4 -1
  22. unienv_interface/backends/numpy.py +4 -1
  23. unienv_interface/backends/pytorch.py +4 -1
  24. unienv_interface/env_base/__init__.py +1 -0
  25. unienv_interface/env_base/env.py +5 -0
  26. unienv_interface/env_base/funcenv.py +32 -1
  27. unienv_interface/env_base/funcenv_wrapper.py +2 -2
  28. unienv_interface/env_base/vec_env.py +474 -0
  29. unienv_interface/func_wrapper/__init__.py +2 -1
  30. unienv_interface/func_wrapper/frame_stack.py +150 -0
  31. unienv_interface/space/space_utils/__init__.py +1 -0
  32. unienv_interface/space/space_utils/batch_utils.py +83 -0
  33. unienv_interface/space/space_utils/construct_utils.py +216 -0
  34. unienv_interface/space/space_utils/serialization_utils.py +16 -1
  35. unienv_interface/space/spaces/__init__.py +3 -1
  36. unienv_interface/space/spaces/batched.py +90 -0
  37. unienv_interface/space/spaces/binary.py +0 -1
  38. unienv_interface/space/spaces/box.py +13 -24
  39. unienv_interface/space/spaces/text.py +1 -3
  40. unienv_interface/transformations/dict_transform.py +31 -5
  41. unienv_interface/utils/control_util.py +68 -0
  42. unienv_interface/utils/data_queue.py +184 -0
  43. unienv_interface/utils/stateclass.py +46 -0
  44. unienv_interface/utils/vec_util.py +15 -0
  45. unienv_interface/world/__init__.py +3 -1
  46. unienv_interface/world/combined_funcnode.py +336 -0
  47. unienv_interface/world/combined_node.py +232 -0
  48. unienv_interface/wrapper/backend_compat.py +2 -2
  49. unienv_interface/wrapper/frame_stack.py +19 -114
  50. unienv_interface/wrapper/video_record.py +11 -2
  51. unienv-0.0.1b1.dist-info/METADATA +0 -20
  52. unienv-0.0.1b1.dist-info/RECORD +0 -85
  53. unienv-0.0.1b1.dist-info/top_level.txt +0 -4
  54. unienv_data/samplers/slice_sampler.py +0 -266
  55. unienv_maniskill/__init__.py +0 -1
  56. unienv_maniskill/wrapper/maniskill_compat.py +0 -235
  57. unienv_mjxplayground/__init__.py +0 -1
  58. unienv_mjxplayground/wrapper/playground_compat.py +0 -256
  59. {unienv-0.0.1b1.dist-info → unienv-0.0.1b3.dist-info}/WHEEL +0 -0
@@ -5,6 +5,7 @@ import copy
 from functools import singledispatch
 from typing import Optional, Any, Iterable, Iterator
 from unienv_interface.backends import ComputeBackend, BArrayType, BDeviceType, BDtypeType, BRNGType, ArrayAPIGetIndex, ArrayAPISetIndex
+from unienv_interface.backends.numpy import NumpyComputeBackend
 import numpy as np

 from ..spaces import *
@@ -178,6 +179,19 @@ def _reshape_batch_size_tuple(
         device=space.device,
     )

+@reshape_batch_size.register(BatchedSpace)
+def _reshape_batch_size_batched(
+    space: BatchedSpace,
+    old_batch_shape: typing.Tuple[int],
+    new_batch_shape: typing.Tuple[int]
+) -> BatchedSpace:
+    assert len(old_batch_shape) <= len(space.batch_shape) and old_batch_shape == space.batch_shape[:len(old_batch_shape)], \
+        f"Expected the old batch shape to be a prefix of the current batch shape, but got old {old_batch_shape} != current {space.batch_shape}"
+    return BatchedSpace(
+        single_space=space.single_space,
+        batch_shape=new_batch_shape + space.batch_shape[len(old_batch_shape):]
+    )
+
 def reshape_batch_size_in_data(
     backend : ComputeBackend[BArrayType, BDeviceType, BDtypeType, BRNGType],
     data : Any,
@@ -194,6 +208,10 @@ def reshape_batch_size_in_data(
194
208
  assert data.shape[:len(old_batch_shape)] == old_batch_shape, \
195
209
  f"Expected the beginning of the shape to match the old batch shape, but got {data.shape[:len(old_batch_shape)]} != {old_batch_shape}"
196
210
  data = data.reshape(new_batch_shape + data.shape[len(old_batch_shape):])
211
+ elif isinstance(data, np.ndarray) and data.dtype == object:
212
+ assert data.shape[:len(old_batch_shape)] == old_batch_shape, \
213
+ f"Expected the beginning of the shape to match the old batch shape, but got {data.shape[:len(old_batch_shape)]} != {old_batch_shape}"
214
+ data = data.reshape(new_batch_shape + data.shape[len(old_batch_shape):])
197
215
  elif isinstance(data, GraphInstance):
198
216
  assert data.n_nodes.shape[:len(old_batch_shape)] == old_batch_shape, \
199
217
  f"Expected the beginning of the n_nodes shape to match the old batch shape, but got {data.n_nodes.shape[:len(old_batch_shape)]} != {old_batch_shape}"
@@ -301,6 +319,12 @@ def _swap_batch_dims_tuple(space: TupleSpace, dim1: int, dim2: int):
         device=space.device,
     )

+@swap_batch_dims.register(BatchedSpace)
+def _swap_batch_dims_batched(space: BatchedSpace, dim1: int, dim2: int):
+    return BatchedSpace(
+        single_space=space.single_space,
+        batch_shape=_shape_transpose(space.batch_shape, dim1, dim2)
+    )

 def swap_batch_dims_in_data(
     backend : ComputeBackend[BArrayType, BDeviceType, BDtypeType, BRNGType],
@@ -310,6 +334,8 @@ def swap_batch_dims_in_data(
310
334
  ) -> Any:
311
335
  if backend.is_backendarray(data):
312
336
  return _tensor_transpose(backend, data, dim1, dim2)
337
+ elif isinstance(data, np.ndarray) and data.dtype != object:
338
+ return _tensor_transpose(NumpyComputeBackend, data, dim1, dim2)
313
339
  elif isinstance(data, GraphInstance):
314
340
  return GraphInstance(
315
341
  n_nodes=_tensor_transpose(backend, data.n_nodes, dim1, dim2),
@@ -367,6 +393,10 @@ def _batch_size_tuple(space: TupleSpace):

     return len(space.spaces)

+@batch_size.register(BatchedSpace)
+def _batch_size_batched(space: BatchedSpace):
+    return space.batch_shape[0] if len(space.batch_shape) > 0 else None
+
 def batch_size_data(data: Any) -> Optional[int]:
     if hasattr(data, "shape"):
         return data.shape[0] if len(data.shape) > 0 else None
@@ -451,6 +481,21 @@ def _batch_space_tuple(space: TupleSpace, n: int = 1):
         device=space.device,
     )

+@batch_space.register(BatchedSpace)
+def _batch_space_batched(space: BatchedSpace, n: int = 1):
+    return BatchedSpace(
+        single_space=space.single_space,
+        batch_shape=(n,) + space.batch_shape,
+    )
+
+@batch_space.register(TextSpace)
+@batch_space.register(UnionSpace)
+def _batch_space_text(space: typing.Union[TextSpace, UnionSpace], n: int = 1):
+    return BatchedSpace(
+        space,
+        batch_shape=(n,),
+    )
+
 @singledispatch
 def batch_differing_spaces(spaces: typing.Sequence[Space], device : Optional[Any] = None) -> Space:
     assert len(spaces) > 0, "Expects a non-empty list of spaces"
@@ -642,6 +687,16 @@ def _unbatch_spaces_tuple(space: TupleSpace):
         device=space.device,
     )

+@unbatch_spaces.register(BatchedSpace)
+def _unbatch_spaces_batched(space: BatchedSpace):
+    assert len(space.batch_shape) > 0, "Expected BatchedSpace to be batched, but it is not."
+    unbatched_space = space.single_space if len(space.batch_shape) == 1 else BatchedSpace(
+        single_space=space.single_space,
+        batch_shape=space.batch_shape[1:],
+    )
+    for i in range(space.batch_shape[0]):
+        yield unbatched_space
+
 def iterate(space: Space, items: Any) -> Iterator:
     for i in range(batch_size_data(items)):
         yield get_at(space, items, i)
@@ -677,6 +732,10 @@ def _get_at_dict(space: DictSpace, items: typing.Mapping[str, Any], index : Arra
 def _get_at_tuple(space: TupleSpace, items: typing.Tuple[Any, ...], index : ArrayAPIGetIndex):
     return tuple(get_at(subspace, item, index) for (subspace, item) in zip(space.spaces, items))

+@get_at.register(BatchedSpace)
+def _get_at_batched(space: BatchedSpace, items: np.ndarray, index: ArrayAPIGetIndex) -> typing.Union[np.ndarray, Any]:
+    return items[index]
+
 @singledispatch
 def set_at(
     space: Space, items: Any, index: ArrayAPISetIndex, value: Any
@@ -763,6 +822,17 @@ def _set_at_tuple(
         for i, subspace in enumerate(space.spaces)
     )

+@set_at.register(BatchedSpace)
+def _set_at_batched(
+    space: BatchedSpace,
+    items: np.ndarray,
+    index: ArrayAPISetIndex,
+    value: typing.Union[np.ndarray, Any],
+) -> np.ndarray:
+    new_data = items.copy()
+    new_data[index] = value
+    return new_data
+
 @singledispatch
 def concatenate(
     space: Space, items: Iterable[Any], axis : int = 0,
@@ -865,3 +935,16 @@ def _concatenate_tuple(
     )

     return tuple(items)
+
+@concatenate.register(BatchedSpace)
+def _concatenate_batched(
+    space: BatchedSpace, items: Iterable, axis: int = 0
+) -> Any:
+    items = list(items)
+    if len(items) == 0:
+        return np.array([], dtype=object)
+    if isinstance(items[0], np.ndarray) and items[0].dtype == object:
+        return np.concatenate(items, axis=axis)
+    else:
+        assert axis == 0, "Expected axis to be 0 when concatenating non-numpy arrays"
+        return np.asarray(items, dtype=object)
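Note on the registrations above: they extend the existing singledispatch helpers (batch_space, get_at, set_at, concatenate) to BatchedSpace values, which are stored as object-dtype numpy arrays. A minimal usage sketch; the import locations and the use of NumpyComputeBackend directly as the backend argument are assumptions, only the call signatures come from the diff:

    import numpy as np
    # Assumed import paths; the hunks above do not show which module they live in.
    from unienv_interface.backends.numpy import NumpyComputeBackend
    from unienv_interface.space.spaces import TextSpace, BatchedSpace
    from unienv_interface.space.space_utils import batch_space, get_at, set_at, concatenate

    text_space = TextSpace(NumpyComputeBackend, max_length=32)
    batched = batch_space(text_space, n=4)         # BatchedSpace(TextSpace, batch_shape=(4,))

    # Batched text data is an object-dtype numpy array, one element per batch slot.
    items = np.asarray(["a", "b", "c", "d"], dtype=object)
    first = get_at(batched, items, 0)              # "a"
    items = set_at(batched, items, 1, "hello")     # returns a copy with index 1 replaced
    merged = concatenate(batched, [items, items])  # object array of length 8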
@@ -0,0 +1,216 @@
+import typing
+from copy import deepcopy
+from functools import singledispatch
+from typing import Optional, Any, Iterable, Iterator, Sequence, Tuple, Literal, Mapping
+import numpy as np
+from unienv_interface.backends import ComputeBackend, BArrayType, BDeviceType, BDtypeType, BRNGType
+from unienv_interface.space.spaces import *
+
+__all__ = [
+    'construct_space_from_data_stream',
+    'construct_space_from_data',
+]
+
+def construct_space_from_data_stream(
+    data : Iterable[Any],
+    backend : ComputeBackend,
+    add_bounds : bool = True
+) -> Space:
+    """Construct a space from a stream of data.
+
+    Args:
+        data (Iterable[Any]): An iterable stream of data samples.
+        backend (ComputeBackend): The compute backend to use for array operations.
+
+    Returns:
+        Space: The constructed space.
+    """
+    candidate = None
+    for d in data:
+        candidate = construct_space_from_data(d, backend, candidate, add_bounds=add_bounds)
+    return candidate
+
+def construct_space_from_data(
+    data : Any,
+    backend : ComputeBackend,
+    candidate_space : Optional[Space] = None,
+    add_bounds : bool = False
+) -> Space:
+    if backend.is_backendarray(data) and backend.dtype_is_boolean(data.dtype):
+        assert candidate_space is None or (
+            isinstance(candidate_space, BinarySpace) and
+            candidate_space.shape == data.shape and
+            (candidate_space.dtype is None or candidate_space.dtype == data.dtype) and
+            (candidate_space.device is None or candidate_space.device == backend.device(data))
+        )
+
+        return BinarySpace(
+            backend,
+            data.shape,
+            dtype=data.dtype,
+            device=backend.device(data)
+        )
+    elif backend.is_backendarray(data):
+        assert candidate_space is None or (
+            isinstance(candidate_space, (BoxSpace, DynamicBoxSpace)) and
+            candidate_space.dtype == data.dtype and
+            (candidate_space.device is None or candidate_space.device == backend.device(data))
+        )
+        if candidate_space is None:
+            return BoxSpace(
+                backend,
+                low=data if add_bounds else -backend.inf,
+                high=data if add_bounds else backend.inf,
+                shape=data.shape,
+                dtype=data.dtype,
+                device=backend.device(data)
+            )
+        elif isinstance(candidate_space, BoxSpace):
+            if candidate_space.shape != data.shape:
+                assert len(candidate_space.shape) == len(data.shape)
+            new_low_shape, new_high_shape = _dynamic_box_find_shape(
+                candidate_space.shape,
+                candidate_space.shape,
+                data.shape
+            )
+            broadcast_shape = _dynamic_box_get_broadcast_shape(new_low_shape, new_high_shape)
+
+            return DynamicBoxSpace(
+                backend,
+                *(_dynamic_box_update_bounds(
+                    backend,
+                    broadcast_shape,
+                    candidate_space._low,
+                    candidate_space._high,
+                    data
+                ) if add_bounds else (candidate_space._low, candidate_space._high)),
+                shape_low=new_low_shape,
+                shape_high=new_high_shape,
+                dtype=data.dtype,
+                device=backend.device(data)
+            )
+        else: # elif isinstance(candidate_space, DynamicBoxSpace):
+            new_low_shape, new_high_shape = _dynamic_box_find_shape(
+                candidate_space.shape_low,
+                candidate_space.shape_high,
+                data.shape
+            )
+            broadcast_shape = _dynamic_box_get_broadcast_shape(new_low_shape, new_high_shape)
+
+            return DynamicBoxSpace(
+                backend,
+                *(_dynamic_box_update_bounds(
+                    backend,
+                    broadcast_shape,
+                    candidate_space._low,
+                    candidate_space._high,
+                    data
+                ) if add_bounds else (candidate_space._low, candidate_space._high)),
+                shape_low=new_low_shape,
+                shape_high=new_high_shape,
+                dtype=data.dtype,
+                device=backend.device(data)
+            )
+    elif isinstance(data, Mapping):
+        assert candidate_space is None or (
+            isinstance(candidate_space, DictSpace)
+            and set(candidate_space.spaces.keys()) == set(data.keys())
+        )
+        if candidate_space is None:
+            spaces = {k: construct_space_from_data(v, backend) for k, v in data.items()}
+        else:
+            spaces = {
+                k: construct_space_from_data(v, backend, candidate_space.spaces[k])
+                for k, v in data.items()
+            }
+        return DictSpace(backend, spaces)
+    elif isinstance(data, str):
+        assert candidate_space is None or isinstance(candidate_space, TextSpace)
+        max_length = len(data)
+        if candidate_space is not None:
+            max_length = max(max_length, candidate_space.max_length)
+        if not add_bounds:
+            max_length = max(max_length, 4096) # Arbitrary large length if not adding bounds
+        return TextSpace(
+            backend,
+            max_length=max_length
+        )
+    elif isinstance(data, Sequence):
+        assert candidate_space is None or (
+            isinstance(candidate_space, TupleSpace)
+            and len(candidate_space.spaces) == len(data)
+        )
+        if candidate_space is None:
+            spaces = [construct_space_from_data(d, backend) for d in data]
+        else:
+            spaces = [
+                construct_space_from_data(d, backend, candidate_space.spaces[i])
+                for i, d in enumerate(data)
+            ]
+        return TupleSpace(backend, spaces)
+    else:
+        raise ValueError(f"Unsupported data type for space construction: {type(data)}")
+
+def _dynamic_box_find_shape(
+    shape_low : Sequence[int],
+    shape_high : Sequence[int],
+    data_shape : Sequence[int]
+) -> Tuple[Sequence[int], Sequence[int]]:
+    assert len(shape_low) == len(shape_high) == len(data_shape)
+    new_shape_low = list(shape_low)
+    new_shape_high = list(shape_high)
+    for i in range(len(data_shape)):
+        new_shape_low[i] = min(shape_low[i], data_shape[i])
+        new_shape_high[i] = max(shape_high[i], data_shape[i])
+    return new_shape_low, new_shape_high
+
+def _dynamic_box_get_broadcast_shape(
+    shape_low : Sequence[int],
+    shape_high : Sequence[int],
+) -> Sequence[int]:
+    assert len(shape_low) == len(shape_high)
+    broadcast_shape = []
+    for low, high in zip(shape_low, shape_high):
+        if low == high:
+            broadcast_shape.append(low)
+        else:
+            broadcast_shape.append(1) # Use 1 to indicate dynamic dimension
+    return tuple(broadcast_shape)
+
+def reshape_to_broadcastable(
+    backend : ComputeBackend,
+    target_shape : Sequence[int],
+    array : BArrayType,
+    method : Literal['min', 'max'] = 'min'
+) -> BArrayType:
+    if array.shape == target_shape:
+        return array
+    else:
+        assert len(array.shape) == len(target_shape)
+        for i, (t_dim, a_dim) in enumerate(zip(target_shape, array.shape)):
+            assert t_dim == a_dim or t_dim == 1 or a_dim == 1
+            if t_dim == a_dim or a_dim == 1:
+                continue
+            else:
+                if method == 'min':
+                    array = backend.min(array, axis=i, keepdims=True)
+                else:
+                    array = backend.max(array, axis=i, keepdims=True)
+        return array
+
+def _dynamic_box_update_bounds(
+    backend : ComputeBackend,
+    target_shape : Sequence[int],
+    current_low : BArrayType,
+    current_high : BArrayType,
+    new_data : BArrayType
+) -> Tuple[BArrayType, BArrayType]:
+    new_low = backend.minimum(
+        reshape_to_broadcastable(backend, target_shape, current_low, method='min'),
+        reshape_to_broadcastable(backend, target_shape, new_data, method='min')
+    )
+    new_high = backend.maximum(
+        reshape_to_broadcastable(backend, target_shape, current_high, method='max'),
+        reshape_to_broadcastable(backend, target_shape, new_data, method='max')
+    )
+    return new_low, new_high
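The new construct_utils.py module infers a Space from observed data: construct_space_from_data_stream folds construct_space_from_data over a stream, widening the candidate as samples arrive (for example, promoting a BoxSpace to a DynamicBoxSpace when shapes differ). A usage sketch; the module path and the use of NumpyComputeBackend directly as the backend are assumptions:

    import numpy as np
    # Assumed import path for the new module shown above.
    from unienv_interface.backends.numpy import NumpyComputeBackend
    from unienv_interface.space.space_utils.construct_utils import construct_space_from_data_stream

    # A stream of float observations whose trailing dimension varies.
    samples = [
        np.zeros((4,), dtype=np.float32),
        np.ones((6,), dtype=np.float32),
    ]

    # add_bounds=True keeps per-dimension bounds from the observed data; the
    # shape change promotes the initial BoxSpace candidate to a DynamicBoxSpace.
    space = construct_space_from_data_stream(samples, NumpyComputeBackend, add_bounds=True)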
@@ -160,7 +160,7 @@ def _text_space_to_json(space: TextSpace) -> typing.Dict[str, Any]:
         "type": "TextSpace",
         "min_length": space.min_length,
         "max_length": space.max_length,
-        "charset": "".join(space.charset)
+        "charset": "".join(space.charset) if space.charset is not None else None,
     }

 @json_to_space.register(TextSpace)
@@ -219,3 +219,18 @@ def _json_to_union_space(json_data: typing.Dict[str, Any], map_backend: ComputeB
     spaces = [json_to_space(s, map_backend, map_device) for s in json_data["spaces"]]
     return UnionSpace(map_backend, spaces, device=map_device)

+@space_to_json.register(BatchedSpace)
+def _batched_space_to_json(space: BatchedSpace) -> typing.Dict[str, Any]:
+    return {
+        "type": "BatchedSpace",
+        "single_space": space_to_json(space.single_space),
+        "batch_shape": space.batch_shape,
+    }
+
+@json_to_space.register(BatchedSpace)
+def _json_to_batched_space(json_data: typing.Dict[str, Any], map_backend: ComputeBackend, map_device: Optional[BDeviceType]) -> BatchedSpace:
+    single_space = json_to_space(json_data["single_space"], map_backend, map_device)
+    return BatchedSpace(
+        single_space,
+        batch_shape=json_data["batch_shape"]
+    )
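These two registrations make BatchedSpace round-trippable through the existing JSON helpers. A sketch of the round trip; only the space_to_json/json_to_space call signatures are taken from the diff, the import paths are assumed:

    # Assumed import paths.
    from unienv_interface.backends.numpy import NumpyComputeBackend
    from unienv_interface.space.spaces import TextSpace, BatchedSpace
    from unienv_interface.space.space_utils import space_to_json, json_to_space

    space = BatchedSpace(TextSpace(NumpyComputeBackend, max_length=16), batch_shape=(2, 3))
    payload = space_to_json(space)
    # {"type": "BatchedSpace", "single_space": {...}, "batch_shape": (2, 3)}
    restored = json_to_space(payload, NumpyComputeBackend, None)
    # restored should compare equal to space; BatchedSpace.__eq__ checks backend,
    # batch_shape and single_space (see batched.py below).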
@@ -7,6 +7,7 @@ from .graph import GraphSpace, GraphInstance
 from .text import TextSpace
 from .tuple import TupleSpace
 from .union import UnionSpace
+from .batched import BatchedSpace

 __all__ = [
     "Space",
@@ -18,5 +19,6 @@ __all__ = [
     "GraphInstance",
     "TextSpace",
     "TupleSpace",
-    "UnionSpace"
+    "UnionSpace",
+    "BatchedSpace",
 ]
@@ -0,0 +1,90 @@
+"""Implementation of a space consisting of finitely many elements."""
+from typing import Any, Generic, Iterable, SupportsFloat, Mapping, Sequence, TypeVar, Optional, Tuple, Type, Literal, Union, Callable
+import numpy as np
+from ..space import Space, SpaceDataT
+from unienv_interface.backends import ComputeBackend, BArrayType, BDeviceType, BDtypeType, BRNGType
+
+class BatchedSpace(Space[np.ndarray, BDeviceType, BDtypeType, BRNGType]):
+    """
+    This space represents a batch of
+    """
+    def __init__(
+        self,
+        single_space : Space[SpaceDataT, BDeviceType, BDtypeType, BRNGType],
+        batch_shape: Sequence[int],
+    ):
+        assert len(batch_shape) > 0, "Batch shape must be a non-empty sequence"
+        batch_shape = tuple(int(dim) for dim in batch_shape) # This changes any np types to int
+        super().__init__(
+            backend=single_space.backend,
+            shape=batch_shape + ((),) if single_space.shape is None else batch_shape + single_space.shape,
+            device=single_space.device,
+            dtype=single_space.dtype,
+        )
+
+        self.batch_shape = batch_shape
+        self.single_space = single_space
+
+    def to(self, backend = None, device = None):
+        return BatchedSpace(
+            self.single_space.to(backend=backend, device=device),
+            self.batch_shape
+        )
+
+    def sample(self, rng : BRNGType) -> Tuple[
+        BRNGType, BArrayType
+    ]:
+        flat_shape = np.prod(self.batch_shape)
+        samples = []
+        for i in range(flat_shape):
+            rng, single_sample = self.single_space.sample(rng)
+            samples.append(single_sample)
+        return rng, np.asarray(samples, dtype=object).reshape(self.batch_shape)
+
+    def create_empty(self):
+        flat_shape = np.prod(self.batch_shape)
+        empties = [self.single_space.create_empty() for _ in range(flat_shape)]
+        return np.asarray(empties, dtype=object).reshape(self.batch_shape)
+
+    def is_bounded(self, manner = "both"):
+        return self.single_space.is_bounded(manner=manner)
+
+    def contains(self, x: BArrayType) -> bool:
+        def is_contained_func(x):
+            return self.single_space.contains(x)
+        for _dim in reversed(self.batch_shape):
+            def new_is_contained_func(x):
+                if (not isinstance(x, np.ndarray)) or (not x.dtype == object):
+                    return False
+                if len(x) != _dim:
+                    return False
+                return all(is_contained_func(xi) for xi in x)
+            is_contained_func = new_is_contained_func
+        return is_contained_func(x)
+
+    def get_repr(
+        self,
+        abbreviate = False,
+        include_backend = True,
+        include_device = True,
+        include_dtype = True
+    ):
+        ret = f"BatchedSpace({self.single_space}, batch_shape={self.batch_shape}"
+        if include_backend:
+            ret += f", {self.backend}"
+        if include_device:
+            ret += f", {self.device}"
+        if include_dtype:
+            ret += f", {self.dtype}"
+        ret += ")"
+        return ret
+
+    def __eq__(self, other: Any) -> bool:
+        """Check whether `other` is equivalent to this instance."""
+        return isinstance(other, BatchedSpace) and self.backend == other.backend and self.batch_shape == other.batch_shape and self.single_space == other.single_space
+
+    def data_to(self, data, backend = None, device = None):
+        if isinstance(data, np.ndarray) and data.dtype == object:
+            return tuple(self.data_to(d, backend=backend, device=device) for d in data)
+        else:
+            return self.single_space.data_to(data, backend=backend, device=device)
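As defined above, BatchedSpace stores batch elements in object-dtype numpy arrays regardless of the wrapped space's backend, which is what allows non-array spaces such as TextSpace to be batched. A small sketch (import paths assumed; the constructor and create_empty are shown in the hunk above):

    import numpy as np
    # Assumed import paths.
    from unienv_interface.backends.numpy import NumpyComputeBackend
    from unienv_interface.space.spaces import TextSpace, BatchedSpace

    space = BatchedSpace(TextSpace(NumpyComputeBackend, max_length=8), batch_shape=(2, 3))
    print(space.batch_shape)    # (2, 3)
    print(space.single_space)   # the wrapped TextSpace

    # Batched values are object-dtype numpy arrays, one element per batch slot.
    empty = space.create_empty()
    print(empty.shape, empty.dtype)   # (2, 3) object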
@@ -13,7 +13,6 @@ class BinarySpace(Space[BArrayType, BDeviceType, BDtypeType, BRNGType]):
         device : Optional[BDeviceType] = None,
     ):
         assert dtype is None or backend.dtype_is_boolean(dtype), f"Invalid dtype {dtype}"
-        assert len(shape) > 0, "Shape must be a non-empty sequence"

         assert all(
             np.issubdtype(type(dim), np.integer) for dim in shape
@@ -53,31 +53,20 @@ class BoxSpace(Space[BArrayType, BDeviceType, BDtypeType, BRNGType]):
             dttype_iinfo = backend.iinfo(dtype)
             dtype_min = dttype_iinfo.min
             dtype_max = dttype_iinfo.max
-            if isinstance(low, int):
-                if low == backend.inf or low == -backend.inf:
-                    _low = dtype_min if low == -backend.inf else dtype_max
-                else:
-                    _low = low
-                _low = backend.full([1] * len(shape), _low, dtype=dtype, device=device)
-            else:
-                _low = backend.astype(low, dtype)
-            if isinstance(high, int):
-                if high == backend.inf or high == -backend.inf:
-                    _high = dtype_max if high == backend.inf else dtype_min
-                else:
-                    _high = high
-                _high = backend.full([1] * len(shape), _high, dtype=dtype, device=device)
-            else:
-                _high = backend.astype(high, dtype)
+
+            if isinstance(low, (int, float)) and (low == backend.inf or low == -backend.inf):
+                low = dtype_min if low == -backend.inf else dtype_max
+            if isinstance(high, (int, float)) and (high == backend.inf or high == -backend.inf):
+                high = dtype_max if high == backend.inf else dtype_min
+
+        if isinstance(low, (int, float)):
+            _low = backend.full([1] * len(shape), low, dtype=dtype, device=device)
+        else:
+            _low = backend.astype(low, dtype)
+        if isinstance(high, (int, float)):
+            _high = backend.full([1] * len(shape), high, dtype=dtype, device=device)
         else:
-            if isinstance(low, (int, float)):
-                _low = backend.full([1] * len(shape), low, dtype=dtype, device=device)
-            else:
-                _low = backend.astype(low, dtype)
-            if isinstance(high, (int, float)):
-                _high = backend.full([1] * len(shape), high, dtype=dtype, device=device)
-            else:
-                _high = backend.astype(high, dtype)
+            _high = backend.astype(high, dtype)

         _low = backend.abbreviate_array(
             _low,
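The rewrite above removes the duplicated int/float handling in BoxSpace: infinite scalar bounds are first clamped to the dtype's representable range, and a single isinstance check then materialises _low/_high via backend.full or backend.astype. An illustrative sketch; the keyword arguments mirror the BoxSpace call in construct_utils.py above, while the import path, passing the backend class directly, and acceptance of numpy dtypes are assumptions:

    import numpy as np
    # Assumed imports and dtype handling.
    from unienv_interface.backends.numpy import NumpyComputeBackend
    from unienv_interface.space.spaces import BoxSpace

    # Scalar infinite bounds with an integer dtype get clamped to the dtype's
    # min/max (via backend.iinfo) before being broadcast with backend.full.
    int_box = BoxSpace(
        NumpyComputeBackend,
        low=-NumpyComputeBackend.inf,
        high=NumpyComputeBackend.inf,
        shape=(3,),
        dtype=np.int32,
    )

    # Array bounds are converted to the target dtype with backend.astype.
    float_box = BoxSpace(
        NumpyComputeBackend,
        low=np.zeros((3,), dtype=np.float64),
        high=np.ones((3,), dtype=np.float64),
        shape=(3,),
        dtype=np.float32,
    )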
@@ -88,9 +88,7 @@ class TextSpace(Space[str, BDeviceType, BDtypeType, BRNGType]):
         return rng, sample

     def create_empty(self) -> str:
-        raise NotImplementedError(
-            "TextSpace does not support create_empty method. Use an empty string instead."
-        )
+        return ""

     def is_bounded(self, manner = "both"):
         return manner == "below" or (
@@ -7,6 +7,27 @@ from unienv_interface.backends import ComputeBackend, BArrayType, BDeviceType, B
 import copy
 from .transformation import DataTransformation, TargetDataT

+def get_chained_value(
+    data : Mapping[str, Any],
+    chained_key : List[str],
+    ignore_missing_keys : bool = False,
+) -> Any:
+    assert len(chained_key) >= 1, "Chained key must have at least one key"
+    if chained_key[0] not in data.keys():
+        if ignore_missing_keys:
+            return None
+        else:
+            raise KeyError(f"Key '{chained_key[0]}' not found in data")
+
+    if len(chained_key) == 1:
+        return data[chained_key[0]]
+    else:
+        return get_chained_value(
+            data[chained_key[0]],
+            chained_key[1:],
+            ignore_missing_keys=ignore_missing_keys
+        )
+
 def call_function_on_chained_dict(
     data : Mapping[str, Any],
     chained_key : List[str],
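get_chained_value walks a nested mapping along a list of keys, raising KeyError for a missing key unless ignore_missing_keys is set, in which case it returns None. A standalone illustration of the helper defined above (the import path follows the dict_transform.py file listed in this diff):

    from unienv_interface.transformations.dict_transform import get_chained_value

    data = {"obs": {"camera": {"rgb": 1}}}

    get_chained_value(data, ["obs", "camera", "rgb"])                    # -> 1
    get_chained_value(data, ["obs", "depth"], ignore_missing_keys=True)  # -> None
    # get_chained_value(data, ["obs", "depth"]) would raise KeyError

    # The next hunk uses it with key.split('/'), so mapping keys like
    # "obs/camera/rgb" can address nested sub-spaces when inverting a DictTransformation.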
@@ -111,11 +132,16 @@ class DictTransformation(DataTransformation):
     ) -> Optional["DictTransformation"]:
         if not self.has_inverse:
             return None
-
-        inverse_mapping = {
-            key: transformation.direction_inverse(source_space)
-            for key, transformation in self.mapping.items()
-        }
+
+        inverse_mapping = {}
+        for key, transformation in self.mapping.items():
+            inverse_mapping[key] = transformation.direction_inverse(
+                None if source_space is None else get_chained_value(
+                    source_space,
+                    key.split('/'),
+                    ignore_missing_keys=self.ignore_missing_keys
+                )
+            )

         return DictTransformation(
             mapping=inverse_mapping,