unienv 0.0.1b1__py3-none-any.whl → 0.0.1b3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- unienv-0.0.1b3.dist-info/METADATA +74 -0
- unienv-0.0.1b3.dist-info/RECORD +92 -0
- {unienv-0.0.1b1.dist-info → unienv-0.0.1b3.dist-info}/licenses/LICENSE +1 -3
- unienv-0.0.1b3.dist-info/top_level.txt +2 -0
- unienv_data/base/__init__.py +0 -1
- unienv_data/base/common.py +95 -45
- unienv_data/base/storage.py +1 -0
- unienv_data/batches/__init__.py +2 -1
- unienv_data/batches/backend_compat.py +47 -1
- unienv_data/batches/combined_batch.py +2 -4
- unienv_data/{base → batches}/transformations.py +3 -2
- unienv_data/replay_buffer/replay_buffer.py +4 -0
- unienv_data/samplers/__init__.py +0 -1
- unienv_data/samplers/multiprocessing_sampler.py +26 -22
- unienv_data/samplers/step_sampler.py +9 -18
- unienv_data/storages/common.py +5 -0
- unienv_data/storages/hdf5.py +291 -20
- unienv_data/storages/pytorch.py +1 -0
- unienv_data/storages/transformation.py +191 -0
- unienv_data/transformations/image_compress.py +213 -0
- unienv_interface/backends/jax.py +4 -1
- unienv_interface/backends/numpy.py +4 -1
- unienv_interface/backends/pytorch.py +4 -1
- unienv_interface/env_base/__init__.py +1 -0
- unienv_interface/env_base/env.py +5 -0
- unienv_interface/env_base/funcenv.py +32 -1
- unienv_interface/env_base/funcenv_wrapper.py +2 -2
- unienv_interface/env_base/vec_env.py +474 -0
- unienv_interface/func_wrapper/__init__.py +2 -1
- unienv_interface/func_wrapper/frame_stack.py +150 -0
- unienv_interface/space/space_utils/__init__.py +1 -0
- unienv_interface/space/space_utils/batch_utils.py +83 -0
- unienv_interface/space/space_utils/construct_utils.py +216 -0
- unienv_interface/space/space_utils/serialization_utils.py +16 -1
- unienv_interface/space/spaces/__init__.py +3 -1
- unienv_interface/space/spaces/batched.py +90 -0
- unienv_interface/space/spaces/binary.py +0 -1
- unienv_interface/space/spaces/box.py +13 -24
- unienv_interface/space/spaces/text.py +1 -3
- unienv_interface/transformations/dict_transform.py +31 -5
- unienv_interface/utils/control_util.py +68 -0
- unienv_interface/utils/data_queue.py +184 -0
- unienv_interface/utils/stateclass.py +46 -0
- unienv_interface/utils/vec_util.py +15 -0
- unienv_interface/world/__init__.py +3 -1
- unienv_interface/world/combined_funcnode.py +336 -0
- unienv_interface/world/combined_node.py +232 -0
- unienv_interface/wrapper/backend_compat.py +2 -2
- unienv_interface/wrapper/frame_stack.py +19 -114
- unienv_interface/wrapper/video_record.py +11 -2
- unienv-0.0.1b1.dist-info/METADATA +0 -20
- unienv-0.0.1b1.dist-info/RECORD +0 -85
- unienv-0.0.1b1.dist-info/top_level.txt +0 -4
- unienv_data/samplers/slice_sampler.py +0 -266
- unienv_maniskill/__init__.py +0 -1
- unienv_maniskill/wrapper/maniskill_compat.py +0 -235
- unienv_mjxplayground/__init__.py +0 -1
- unienv_mjxplayground/wrapper/playground_compat.py +0 -256
- {unienv-0.0.1b1.dist-info → unienv-0.0.1b3.dist-info}/WHEEL +0 -0
unienv_interface/space/space_utils/batch_utils.py

@@ -5,6 +5,7 @@ import copy
 from functools import singledispatch
 from typing import Optional, Any, Iterable, Iterator
 from unienv_interface.backends import ComputeBackend, BArrayType, BDeviceType, BDtypeType, BRNGType, ArrayAPIGetIndex, ArrayAPISetIndex
+from unienv_interface.backends.numpy import NumpyComputeBackend
 import numpy as np

 from ..spaces import *

@@ -178,6 +179,19 @@ def _reshape_batch_size_tuple(
         device=space.device,
     )

+@reshape_batch_size.register(BatchedSpace)
+def _reshape_batch_size_batched(
+    space: BatchedSpace,
+    old_batch_shape: typing.Tuple[int],
+    new_batch_shape: typing.Tuple[int]
+) -> BatchedSpace:
+    assert len(old_batch_shape) <= len(space.batch_shape) and old_batch_shape == space.batch_shape[:len(old_batch_shape)], \
+        f"Expected the old batch shape to be a prefix of the current batch shape, but got old {old_batch_shape} != current {space.batch_shape}"
+    return BatchedSpace(
+        single_space=space.single_space,
+        batch_shape=new_batch_shape + space.batch_shape[len(old_batch_shape):]
+    )
+
 def reshape_batch_size_in_data(
     backend : ComputeBackend[BArrayType, BDeviceType, BDtypeType, BRNGType],
     data : Any,

@@ -194,6 +208,10 @@ def reshape_batch_size_in_data(
         assert data.shape[:len(old_batch_shape)] == old_batch_shape, \
             f"Expected the beginning of the shape to match the old batch shape, but got {data.shape[:len(old_batch_shape)]} != {old_batch_shape}"
         data = data.reshape(new_batch_shape + data.shape[len(old_batch_shape):])
+    elif isinstance(data, np.ndarray) and data.dtype == object:
+        assert data.shape[:len(old_batch_shape)] == old_batch_shape, \
+            f"Expected the beginning of the shape to match the old batch shape, but got {data.shape[:len(old_batch_shape)]} != {old_batch_shape}"
+        data = data.reshape(new_batch_shape + data.shape[len(old_batch_shape):])
     elif isinstance(data, GraphInstance):
         assert data.n_nodes.shape[:len(old_batch_shape)] == old_batch_shape, \
             f"Expected the beginning of the n_nodes shape to match the old batch shape, but got {data.n_nodes.shape[:len(old_batch_shape)]} != {old_batch_shape}"

@@ -301,6 +319,12 @@ def _swap_batch_dims_tuple(space: TupleSpace, dim1: int, dim2: int):
         device=space.device,
     )

+@swap_batch_dims.register(BatchedSpace)
+def _swap_batch_dims_batched(space: BatchedSpace, dim1: int, dim2: int):
+    return BatchedSpace(
+        single_space=space.single_space,
+        batch_shape=_shape_transpose(space.batch_shape, dim1, dim2)
+    )

 def swap_batch_dims_in_data(
     backend : ComputeBackend[BArrayType, BDeviceType, BDtypeType, BRNGType],

@@ -310,6 +334,8 @@ def swap_batch_dims_in_data(
 ) -> Any:
     if backend.is_backendarray(data):
         return _tensor_transpose(backend, data, dim1, dim2)
+    elif isinstance(data, np.ndarray) and data.dtype != object:
+        return _tensor_transpose(NumpyComputeBackend, data, dim1, dim2)
     elif isinstance(data, GraphInstance):
         return GraphInstance(
             n_nodes=_tensor_transpose(backend, data.n_nodes, dim1, dim2),

@@ -367,6 +393,10 @@ def _batch_size_tuple(space: TupleSpace):

     return len(space.spaces)

+@batch_size.register(BatchedSpace)
+def _batch_size_batched(space: BatchedSpace):
+    return space.batch_shape[0] if len(space.batch_shape) > 0 else None
+
 def batch_size_data(data: Any) -> Optional[int]:
     if hasattr(data, "shape"):
         return data.shape[0] if len(data.shape) > 0 else None

@@ -451,6 +481,21 @@ def _batch_space_tuple(space: TupleSpace, n: int = 1):
         device=space.device,
     )

+@batch_space.register(BatchedSpace)
+def _batch_space_batched(space: BatchedSpace, n: int = 1):
+    return BatchedSpace(
+        single_space=space.single_space,
+        batch_shape=(n,) + space.batch_shape,
+    )
+
+@batch_space.register(TextSpace)
+@batch_space.register(UnionSpace)
+def _batch_space_text(space: typing.Union[TextSpace, UnionSpace], n: int = 1):
+    return BatchedSpace(
+        space,
+        batch_shape=(n,),
+    )
+
 @singledispatch
 def batch_differing_spaces(spaces: typing.Sequence[Space], device : Optional[Any] = None) -> Space:
     assert len(spaces) > 0, "Expects a non-empty list of spaces"

@@ -642,6 +687,16 @@ def _unbatch_spaces_tuple(space: TupleSpace):
         device=space.device,
     )

+@unbatch_spaces.register(BatchedSpace)
+def _unbatch_spaces_batched(space: BatchedSpace):
+    assert len(space.batch_shape) > 0, "Expected BatchedSpace to be batched, but it is not."
+    unbatched_space = space.single_space if len(space.batch_shape) == 1 else BatchedSpace(
+        single_space=space.single_space,
+        batch_shape=space.batch_shape[1:],
+    )
+    for i in range(space.batch_shape[0]):
+        yield unbatched_space
+
 def iterate(space: Space, items: Any) -> Iterator:
     for i in range(batch_size_data(items)):
         yield get_at(space, items, i)

@@ -677,6 +732,10 @@ def _get_at_dict(space: DictSpace, items: typing.Mapping[str, Any], index : Arra
 def _get_at_tuple(space: TupleSpace, items: typing.Tuple[Any, ...], index : ArrayAPIGetIndex):
     return tuple(get_at(subspace, item, index) for (subspace, item) in zip(space.spaces, items))

+@get_at.register(BatchedSpace)
+def _get_at_batched(space: BatchedSpace, items: np.ndarray, index: ArrayAPIGetIndex) -> typing.Union[np.ndarray, Any]:
+    return items[index]
+
 @singledispatch
 def set_at(
     space: Space, items: Any, index: ArrayAPISetIndex, value: Any

@@ -763,6 +822,17 @@ def _set_at_tuple(
         for i, subspace in enumerate(space.spaces)
     )

+@set_at.register(BatchedSpace)
+def _set_at_batched(
+    space: BatchedSpace,
+    items: np.ndarray,
+    index: ArrayAPISetIndex,
+    value: typing.Union[np.ndarray, Any],
+) -> np.ndarray:
+    new_data = items.copy()
+    new_data[index] = value
+    return new_data
+
 @singledispatch
 def concatenate(
     space: Space, items: Iterable[Any], axis : int = 0,

@@ -865,3 +935,16 @@ def _concatenate_tuple(
     )

     return tuple(items)
+
+@concatenate.register(BatchedSpace)
+def _concatenate_batched(
+    space: BatchedSpace, items: Iterable, axis: int = 0
+) -> Any:
+    items = list(items)
+    if len(items) == 0:
+        return np.array([], dtype=object)
+    if isinstance(items[0], np.ndarray) and items[0].dtype == object:
+        return np.concatenate(items, axis=axis)
+    else:
+        assert axis == 0, "Expected axis to be 0 when concatenating non-numpy arrays"
+        return np.asarray(items, dtype=object)
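Taken together, these registrations teach the generic batch utilities about the new BatchedSpace and route plain NumPy arrays through NumpyComputeBackend. A rough usage sketch follows; the import paths, the TextSpace constructor form, and passing NumpyComputeBackend directly as the backend are assumptions inferred from this diff, not taken from the package's documentation:

# Hypothetical sketch -- unienv's actual entry points may differ.
from unienv_interface.backends.numpy import NumpyComputeBackend
from unienv_interface.space.spaces import TextSpace, BatchedSpace
from unienv_interface.space.space_utils.batch_utils import (
    batch_space, get_at, set_at, concatenate,
)

text = TextSpace(NumpyComputeBackend, max_length=16)    # constructor form as used in construct_utils below
batched = batch_space(text, n=4)                         # TextSpace now batches into a BatchedSpace
assert isinstance(batched, BatchedSpace) and batched.batch_shape == (4,)

items = batched.create_empty()                           # object ndarray holding 4 empty strings
items = set_at(batched, items, 0, "hello")               # copy, then assign at index 0
assert get_at(batched, items, 0) == "hello"
stacked = concatenate(batched, ["a", "b", "c", "d"])     # object ndarray stacked along axis 0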
unienv_interface/space/space_utils/construct_utils.py (new file)

@@ -0,0 +1,216 @@
+import typing
+from copy import deepcopy
+from functools import singledispatch
+from typing import Optional, Any, Iterable, Iterator, Sequence, Tuple, Literal, Mapping
+import numpy as np
+from unienv_interface.backends import ComputeBackend, BArrayType, BDeviceType, BDtypeType, BRNGType
+from unienv_interface.space.spaces import *
+
+__all__ = [
+    'construct_space_from_data_stream',
+    'construct_space_from_data',
+]
+
+def construct_space_from_data_stream(
+    data : Iterable[Any],
+    backend : ComputeBackend,
+    add_bounds : bool = True
+) -> Space:
+    """Construct a space from a stream of data.
+
+    Args:
+        data (Iterable[Any]): An iterable stream of data samples.
+        backend (ComputeBackend): The compute backend to use for array operations.
+
+    Returns:
+        Space: The constructed space.
+    """
+    candidate = None
+    for d in data:
+        candidate = construct_space_from_data(d, backend, candidate, add_bounds=add_bounds)
+    return candidate
+
+def construct_space_from_data(
+    data : Any,
+    backend : ComputeBackend,
+    candidate_space : Optional[Space] = None,
+    add_bounds : bool = False
+) -> Space:
+    if backend.is_backendarray(data) and backend.dtype_is_boolean(data.dtype):
+        assert candidate_space is None or (
+            isinstance(candidate_space, BinarySpace) and
+            candidate_space.shape == data.shape and
+            (candidate_space.dtype is None or candidate_space.dtype == data.dtype) and
+            (candidate_space.device is None or candidate_space.device == backend.device(data))
+        )
+
+        return BinarySpace(
+            backend,
+            data.shape,
+            dtype=data.dtype,
+            device=backend.device(data)
+        )
+    elif backend.is_backendarray(data):
+        assert candidate_space is None or (
+            isinstance(candidate_space, (BoxSpace, DynamicBoxSpace)) and
+            candidate_space.dtype == data.dtype and
+            (candidate_space.device is None or candidate_space.device == backend.device(data))
+        )
+        if candidate_space is None:
+            return BoxSpace(
+                backend,
+                low=data if add_bounds else -backend.inf,
+                high=data if add_bounds else backend.inf,
+                shape=data.shape,
+                dtype=data.dtype,
+                device=backend.device(data)
+            )
+        elif isinstance(candidate_space, BoxSpace):
+            if candidate_space.shape != data.shape:
+                assert len(candidate_space.shape) == len(data.shape)
+            new_low_shape, new_high_shape = _dynamic_box_find_shape(
+                candidate_space.shape,
+                candidate_space.shape,
+                data.shape
+            )
+            broadcast_shape = _dynamic_box_get_broadcast_shape(new_low_shape, new_high_shape)
+
+            return DynamicBoxSpace(
+                backend,
+                *(_dynamic_box_update_bounds(
+                    backend,
+                    broadcast_shape,
+                    candidate_space._low,
+                    candidate_space._high,
+                    data
+                ) if add_bounds else (candidate_space._low, candidate_space._high)),
+                shape_low=new_low_shape,
+                shape_high=new_high_shape,
+                dtype=data.dtype,
+                device=backend.device(data)
+            )
+        else: # elif isinstance(candidate_space, DynamicBoxSpace):
+            new_low_shape, new_high_shape = _dynamic_box_find_shape(
+                candidate_space.shape_low,
+                candidate_space.shape_high,
+                data.shape
+            )
+            broadcast_shape = _dynamic_box_get_broadcast_shape(new_low_shape, new_high_shape)
+
+            return DynamicBoxSpace(
+                backend,
+                *(_dynamic_box_update_bounds(
+                    backend,
+                    broadcast_shape,
+                    candidate_space._low,
+                    candidate_space._high,
+                    data
+                ) if add_bounds else (candidate_space._low, candidate_space._high)),
+                shape_low=new_low_shape,
+                shape_high=new_high_shape,
+                dtype=data.dtype,
+                device=backend.device(data)
+            )
+    elif isinstance(data, Mapping):
+        assert candidate_space is None or (
+            isinstance(candidate_space, DictSpace)
+            and set(candidate_space.spaces.keys()) == set(data.keys())
+        )
+        if candidate_space is None:
+            spaces = {k: construct_space_from_data(v, backend) for k, v in data.items()}
+        else:
+            spaces = {
+                k: construct_space_from_data(v, backend, candidate_space.spaces[k])
+                for k, v in data.items()
+            }
+        return DictSpace(backend, spaces)
+    elif isinstance(data, str):
+        assert candidate_space is None or isinstance(candidate_space, TextSpace)
+        max_length = len(data)
+        if candidate_space is not None:
+            max_length = max(max_length, candidate_space.max_length)
+        if not add_bounds:
+            max_length = max(max_length, 4096) # Arbitrary large length if not adding bounds
+        return TextSpace(
+            backend,
+            max_length=max_length
+        )
+    elif isinstance(data, Sequence):
+        assert candidate_space is None or (
+            isinstance(candidate_space, TupleSpace)
+            and len(candidate_space.spaces) == len(data)
+        )
+        if candidate_space is None:
+            spaces = [construct_space_from_data(d, backend) for d in data]
+        else:
+            spaces = [
+                construct_space_from_data(d, backend, candidate_space.spaces[i])
+                for i, d in enumerate(data)
+            ]
+        return TupleSpace(backend, spaces)
+    else:
+        raise ValueError(f"Unsupported data type for space construction: {type(data)}")
+
+def _dynamic_box_find_shape(
+    shape_low : Sequence[int],
+    shape_high : Sequence[int],
+    data_shape : Sequence[int]
+) -> Tuple[Sequence[int], Sequence[int]]:
+    assert len(shape_low) == len(shape_high) == len(data_shape)
+    new_shape_low = list(shape_low)
+    new_shape_high = list(shape_high)
+    for i in range(len(data_shape)):
+        new_shape_low[i] = min(shape_low[i], data_shape[i])
+        new_shape_high[i] = max(shape_high[i], data_shape[i])
+    return new_shape_low, new_shape_high
+
+def _dynamic_box_get_broadcast_shape(
+    shape_low : Sequence[int],
+    shape_high : Sequence[int],
+) -> Sequence[int]:
+    assert len(shape_low) == len(shape_high)
+    broadcast_shape = []
+    for low, high in zip(shape_low, shape_high):
+        if low == high:
+            broadcast_shape.append(low)
+        else:
+            broadcast_shape.append(1) # Use 1 to indicate dynamic dimension
+    return tuple(broadcast_shape)
+
+def reshape_to_broadcastable(
+    backend : ComputeBackend,
+    target_shape : Sequence[int],
+    array : BArrayType,
+    method : Literal['min', 'max'] = 'min'
+) -> BArrayType:
+    if array.shape == target_shape:
+        return array
+    else:
+        assert len(array.shape) == len(target_shape)
+        for i, (t_dim, a_dim) in enumerate(zip(target_shape, array.shape)):
+            assert t_dim == a_dim or t_dim == 1 or a_dim == 1
+            if t_dim == a_dim or a_dim == 1:
+                continue
+            else:
+                if method == 'min':
+                    array = backend.min(array, axis=i, keepdims=True)
+                else:
+                    array = backend.max(array, axis=i, keepdims=True)
+        return array
+
+def _dynamic_box_update_bounds(
+    backend : ComputeBackend,
+    target_shape : Sequence[int],
+    current_low : BArrayType,
+    current_high : BArrayType,
+    new_data : BArrayType
+) -> Tuple[BArrayType, BArrayType]:
+    new_low = backend.minimum(
+        reshape_to_broadcastable(backend, target_shape, current_low, method='min'),
+        reshape_to_broadcastable(backend, target_shape, new_data, method='min')
+    )
+    new_high = backend.maximum(
+        reshape_to_broadcastable(backend, target_shape, current_high, method='max'),
+        reshape_to_broadcastable(backend, target_shape, new_data, method='max')
+    )
+    return new_low, new_high
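A rough sketch of how the new construct_space_from_data_stream helper might be called; the import path and the use of NumpyComputeBackend as the backend argument are assumptions based only on this diff:

# Hypothetical sketch -- infers a space from a stream of samples.
import numpy as np
from unienv_interface.backends.numpy import NumpyComputeBackend
from unienv_interface.space.space_utils.construct_utils import construct_space_from_data_stream

stream = [
    np.array([0.1, 0.2, 0.3], dtype=np.float32),
    np.array([-1.0, 0.5, 2.0], dtype=np.float32),
]
space = construct_space_from_data_stream(stream, NumpyComputeBackend, add_bounds=True)
# The first sample seeds a box-like space; each later sample widens its low/high
# bounds (and, via DynamicBoxSpace, its shape range) to cover everything observed.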
unienv_interface/space/space_utils/serialization_utils.py

@@ -160,7 +160,7 @@ def _text_space_to_json(space: TextSpace) -> typing.Dict[str, Any]:
         "type": "TextSpace",
         "min_length": space.min_length,
         "max_length": space.max_length,
-        "charset": "".join(space.charset)
+        "charset": "".join(space.charset) if space.charset is not None else None,
     }

 @json_to_space.register(TextSpace)

@@ -219,3 +219,18 @@ def _json_to_union_space(json_data: typing.Dict[str, Any], map_backend: ComputeB
     spaces = [json_to_space(s, map_backend, map_device) for s in json_data["spaces"]]
     return UnionSpace(map_backend, spaces, device=map_device)

+@space_to_json.register(BatchedSpace)
+def _batched_space_to_json(space: BatchedSpace) -> typing.Dict[str, Any]:
+    return {
+        "type": "BatchedSpace",
+        "single_space": space_to_json(space.single_space),
+        "batch_shape": space.batch_shape,
+    }
+
+@json_to_space.register(BatchedSpace)
+def _json_to_batched_space(json_data: typing.Dict[str, Any], map_backend: ComputeBackend, map_device: Optional[BDeviceType]) -> BatchedSpace:
+    single_space = json_to_space(json_data["single_space"], map_backend, map_device)
+    return BatchedSpace(
+        single_space,
+        batch_shape=json_data["batch_shape"]
+    )
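The charset change above avoids a TypeError when serializing a TextSpace whose charset is None, and the two new handlers serialize and rebuild BatchedSpace. The resulting JSON layout, as a sketch (inner_space stands for any space that space_to_json already handles):

# Sketch of the emitted structure, with keys taken from _batched_space_to_json above.
space_to_json(BatchedSpace(inner_space, batch_shape=(4,)))
# -> {"type": "BatchedSpace",
#     "single_space": <JSON for inner_space>,
#     "batch_shape": (4,)}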
unienv_interface/space/spaces/__init__.py

@@ -7,6 +7,7 @@ from .graph import GraphSpace, GraphInstance
 from .text import TextSpace
 from .tuple import TupleSpace
 from .union import UnionSpace
+from .batched import BatchedSpace

 __all__ = [
     "Space",

@@ -18,5 +19,6 @@ __all__ = [
     "GraphInstance",
     "TextSpace",
     "TupleSpace",
-    "UnionSpace"
+    "UnionSpace",
+    "BatchedSpace",
 ]
unienv_interface/space/spaces/batched.py (new file)

@@ -0,0 +1,90 @@
+"""Implementation of a space consisting of finitely many elements."""
+from typing import Any, Generic, Iterable, SupportsFloat, Mapping, Sequence, TypeVar, Optional, Tuple, Type, Literal, Union, Callable
+import numpy as np
+from ..space import Space, SpaceDataT
+from unienv_interface.backends import ComputeBackend, BArrayType, BDeviceType, BDtypeType, BRNGType
+
+class BatchedSpace(Space[np.ndarray, BDeviceType, BDtypeType, BRNGType]):
+    """
+    This space represents a batch of
+    """
+    def __init__(
+        self,
+        single_space : Space[SpaceDataT, BDeviceType, BDtypeType, BRNGType],
+        batch_shape: Sequence[int],
+    ):
+        assert len(batch_shape) > 0, "Batch shape must be a non-empty sequence"
+        batch_shape = tuple(int(dim) for dim in batch_shape) # This changes any np types to int
+        super().__init__(
+            backend=single_space.backend,
+            shape=batch_shape + ((),) if single_space.shape is None else batch_shape + single_space.shape,
+            device=single_space.device,
+            dtype=single_space.dtype,
+        )
+
+        self.batch_shape = batch_shape
+        self.single_space = single_space
+
+    def to(self, backend = None, device = None):
+        return BatchedSpace(
+            self.single_space.to(backend=backend, device=device),
+            self.batch_shape
+        )
+
+    def sample(self, rng : BRNGType) -> Tuple[
+        BRNGType, BArrayType
+    ]:
+        flat_shape = np.prod(self.batch_shape)
+        samples = []
+        for i in range(flat_shape):
+            rng, single_sample = self.single_space.sample(rng)
+            samples.append(single_sample)
+        return rng, np.asarray(samples, dtype=object).reshape(self.batch_shape)
+
+    def create_empty(self):
+        flat_shape = np.prod(self.batch_shape)
+        empties = [self.single_space.create_empty() for _ in range(flat_shape)]
+        return np.asarray(empties, dtype=object).reshape(self.batch_shape)
+
+    def is_bounded(self, manner = "both"):
+        return self.single_space.is_bounded(manner=manner)
+
+    def contains(self, x: BArrayType) -> bool:
+        def is_contained_func(x):
+            return self.single_space.contains(x)
+        for _dim in reversed(self.batch_shape):
+            def new_is_contained_func(x):
+                if (not isinstance(x, np.ndarray)) or (not x.dtype == object):
+                    return False
+                if len(x) != _dim:
+                    return False
+                return all(is_contained_func(xi) for xi in x)
+            is_contained_func = new_is_contained_func
+        return is_contained_func(x)
+
+    def get_repr(
+        self,
+        abbreviate = False,
+        include_backend = True,
+        include_device = True,
+        include_dtype = True
+    ):
+        ret = f"BatchedSpace({self.single_space}, batch_shape={self.batch_shape}"
+        if include_backend:
+            ret += f", {self.backend}"
+        if include_device:
+            ret += f", {self.device}"
+        if include_dtype:
+            ret += f", {self.dtype}"
+        ret += ")"
+        return ret
+
+    def __eq__(self, other: Any) -> bool:
+        """Check whether `other` is equivalent to this instance."""
+        return isinstance(other, BatchedSpace) and self.backend == other.backend and self.batch_shape == other.batch_shape and self.single_space == other.single_space
+
+    def data_to(self, data, backend = None, device = None):
+        if isinstance(data, np.ndarray) and data.dtype == object:
+            return tuple(self.data_to(d, backend=backend, device=device) for d in data)
+        else:
+            return self.single_space.data_to(data, backend=backend, device=device)
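A minimal sketch of the new BatchedSpace on its own; single_space and rng below are placeholders for any existing space and a backend RNG obtained elsewhere (how to create one is not shown in this diff):

# Hypothetical sketch of direct BatchedSpace use.
from unienv_interface.space.spaces import BatchedSpace

batched = BatchedSpace(single_space, batch_shape=(2, 3))
empty = batched.create_empty()        # (2, 3) object ndarray of single_space.create_empty()
rng, batch = batched.sample(rng)      # draws 6 samples from single_space, reshaped to (2, 3)
ok = batched.contains(batch)          # recursively checks each element with single_space.contains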
unienv_interface/space/spaces/binary.py

@@ -13,7 +13,6 @@ class BinarySpace(Space[BArrayType, BDeviceType, BDtypeType, BRNGType]):
         device : Optional[BDeviceType] = None,
     ):
         assert dtype is None or backend.dtype_is_boolean(dtype), f"Invalid dtype {dtype}"
-        assert len(shape) > 0, "Shape must be a non-empty sequence"

         assert all(
             np.issubdtype(type(dim), np.integer) for dim in shape
unienv_interface/space/spaces/box.py

@@ -53,31 +53,20 @@ class BoxSpace(Space[BArrayType, BDeviceType, BDtypeType, BRNGType]):
             dttype_iinfo = backend.iinfo(dtype)
             dtype_min = dttype_iinfo.min
             dtype_max = dttype_iinfo.max
-
-
-
-
-
-
-
-
-
-
-
-
-            _high = high
-            _high = backend.full([1] * len(shape), _high, dtype=dtype, device=device)
-        else:
-            _high = backend.astype(high, dtype)
+
+        if isinstance(low, (int, float)) and (low == backend.inf or low == -backend.inf):
+            low = dtype_min if low == -backend.inf else dtype_max
+        if isinstance(high, (int, float)) and (high == backend.inf or high == -backend.inf):
+            high = dtype_max if high == backend.inf else dtype_min
+
+        if isinstance(low, (int, float)):
+            _low = backend.full([1] * len(shape), low, dtype=dtype, device=device)
+        else:
+            _low = backend.astype(low, dtype)
+        if isinstance(high, (int, float)):
+            _high = backend.full([1] * len(shape), high, dtype=dtype, device=device)
         else:
-
-            _low = backend.full([1] * len(shape), low, dtype=dtype, device=device)
-        else:
-            _low = backend.astype(low, dtype)
-        if isinstance(high, (int, float)):
-            _high = backend.full([1] * len(shape), high, dtype=dtype, device=device)
-        else:
-            _high = backend.astype(high, dtype)
+            _high = backend.astype(high, dtype)

         _low = backend.abbreviate_array(
             _low,
unienv_interface/space/spaces/text.py

@@ -88,9 +88,7 @@ class TextSpace(Space[str, BDeviceType, BDtypeType, BRNGType]):
         return rng, sample

     def create_empty(self) -> str:
-
-            "TextSpace does not support create_empty method. Use an empty string instead."
-        )
+        return ""

     def is_bounded(self, manner = "both"):
         return manner == "below" or (
unienv_interface/transformations/dict_transform.py

@@ -7,6 +7,27 @@ from unienv_interface.backends import ComputeBackend, BArrayType, BDeviceType, B
 import copy
 from .transformation import DataTransformation, TargetDataT

+def get_chained_value(
+    data : Mapping[str, Any],
+    chained_key : List[str],
+    ignore_missing_keys : bool = False,
+) -> Any:
+    assert len(chained_key) >= 1, "Chained key must have at least one key"
+    if chained_key[0] not in data.keys():
+        if ignore_missing_keys:
+            return None
+        else:
+            raise KeyError(f"Key '{chained_key[0]}' not found in data")
+
+    if len(chained_key) == 1:
+        return data[chained_key[0]]
+    else:
+        return get_chained_value(
+            data[chained_key[0]],
+            chained_key[1:],
+            ignore_missing_keys=ignore_missing_keys
+        )
+
 def call_function_on_chained_dict(
     data : Mapping[str, Any],
     chained_key : List[str],

@@ -111,11 +132,16 @@ class DictTransformation(DataTransformation):
     ) -> Optional["DictTransformation"]:
         if not self.has_inverse:
             return None
-
-        inverse_mapping = {
-
-
-
+
+        inverse_mapping = {}
+        for key, transformation in self.mapping.items():
+            inverse_mapping[key] = transformation.direction_inverse(
+                None if source_space is None else get_chained_value(
+                    source_space,
+                    key.split('/'),
+                    ignore_missing_keys=self.ignore_missing_keys
+                )
+            )

         return DictTransformation(
             mapping=inverse_mapping,