unienv 0.0.1b7__tar.gz → 0.0.1b8__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {unienv-0.0.1b7/unienv.egg-info → unienv-0.0.1b8}/PKG-INFO +1 -1
- {unienv-0.0.1b7 → unienv-0.0.1b8}/pyproject.toml +1 -1
- {unienv-0.0.1b7 → unienv-0.0.1b8/unienv.egg-info}/PKG-INFO +1 -1
- {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv.egg-info/SOURCES.txt +1 -0
- {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_data/batches/combined_batch.py +110 -67
- unienv-0.0.1b8/unienv_data/integrations/huggingface.py +47 -0
- unienv-0.0.1b8/unienv_data/integrations/pytorch.py +122 -0
- {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_data/storages/_episode_storage.py +13 -4
- {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_data/storages/pytorch.py +1 -0
- unienv-0.0.1b8/unienv_data/storages/video_storage.py +535 -0
- unienv-0.0.1b7/unienv_data/integrations/pytorch.py +0 -63
- unienv-0.0.1b7/unienv_data/storages/video_storage.py +0 -297
- {unienv-0.0.1b7 → unienv-0.0.1b8}/LICENSE +0 -0
- {unienv-0.0.1b7 → unienv-0.0.1b8}/README.md +0 -0
- {unienv-0.0.1b7 → unienv-0.0.1b8}/setup.cfg +0 -0
- {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv.egg-info/dependency_links.txt +0 -0
- {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv.egg-info/requires.txt +0 -0
- {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv.egg-info/top_level.txt +0 -0
- {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_data/__init__.py +0 -0
- {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_data/base/__init__.py +0 -0
- {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_data/base/common.py +0 -0
- {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_data/base/storage.py +0 -0
- {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_data/batches/__init__.py +0 -0
- {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_data/batches/backend_compat.py +0 -0
- {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_data/batches/framestack_batch.py +0 -0
- {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_data/batches/slicestack_batch.py +0 -0
- {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_data/batches/transformations.py +0 -0
- {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_data/replay_buffer/__init__.py +0 -0
- {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_data/replay_buffer/replay_buffer.py +0 -0
- {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_data/replay_buffer/trajectory_replay_buffer.py +0 -0
- {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_data/samplers/__init__.py +0 -0
- {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_data/samplers/multiprocessing_sampler.py +0 -0
- {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_data/samplers/step_sampler.py +0 -0
- {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_data/storages/_list_storage.py +0 -0
- {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_data/storages/backend_compat.py +0 -0
- {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_data/storages/dict_storage.py +0 -0
- {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_data/storages/flattened.py +0 -0
- {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_data/storages/hdf5.py +0 -0
- {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_data/storages/image_storage.py +0 -0
- {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_data/storages/npz_storage.py +0 -0
- {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_data/storages/transformation.py +0 -0
- {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_data/third_party/tensordict/memmap_tensor.py +0 -0
- {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_data/transformations/image_compress.py +0 -0
- {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_interface/__init__.py +0 -0
- {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_interface/backends/__init__.py +0 -0
- {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_interface/backends/base.py +0 -0
- {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_interface/backends/jax.py +0 -0
- {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_interface/backends/numpy.py +0 -0
- {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_interface/backends/pytorch.py +0 -0
- {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_interface/backends/serialization.py +0 -0
- {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_interface/env_base/__init__.py +0 -0
- {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_interface/env_base/env.py +0 -0
- {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_interface/env_base/funcenv.py +0 -0
- {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_interface/env_base/funcenv_wrapper.py +0 -0
- {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_interface/env_base/vec_env.py +0 -0
- {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_interface/env_base/wrapper.py +0 -0
- {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_interface/func_wrapper/__init__.py +0 -0
- {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_interface/func_wrapper/frame_stack.py +0 -0
- {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_interface/func_wrapper/transformation.py +0 -0
- {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_interface/space/__init__.py +0 -0
- {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_interface/space/space.py +0 -0
- {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_interface/space/space_utils/__init__.py +0 -0
- {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_interface/space/space_utils/batch_utils.py +0 -0
- {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_interface/space/space_utils/construct_utils.py +0 -0
- {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_interface/space/space_utils/flatten_utils.py +0 -0
- {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_interface/space/space_utils/gym_utils.py +0 -0
- {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_interface/space/space_utils/serialization_utils.py +0 -0
- {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_interface/space/spaces/__init__.py +0 -0
- {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_interface/space/spaces/batched.py +0 -0
- {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_interface/space/spaces/binary.py +0 -0
- {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_interface/space/spaces/box.py +0 -0
- {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_interface/space/spaces/dict.py +0 -0
- {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_interface/space/spaces/dynamic_box.py +0 -0
- {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_interface/space/spaces/graph.py +0 -0
- {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_interface/space/spaces/text.py +0 -0
- {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_interface/space/spaces/tuple.py +0 -0
- {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_interface/space/spaces/union.py +0 -0
- {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_interface/transformations/__init__.py +0 -0
- {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_interface/transformations/batch_and_unbatch.py +0 -0
- {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_interface/transformations/chained_transform.py +0 -0
- {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_interface/transformations/crop.py +0 -0
- {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_interface/transformations/dict_transform.py +0 -0
- {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_interface/transformations/filter_dict.py +0 -0
- {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_interface/transformations/identity.py +0 -0
- {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_interface/transformations/image_resize.py +0 -0
- {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_interface/transformations/iter_transform.py +0 -0
- {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_interface/transformations/rescale.py +0 -0
- {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_interface/transformations/transformation.py +0 -0
- {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_interface/utils/control_util.py +0 -0
- {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_interface/utils/framestack_queue.py +0 -0
- {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_interface/utils/seed_util.py +0 -0
- {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_interface/utils/stateclass.py +0 -0
- {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_interface/utils/symbol_util.py +0 -0
- {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_interface/utils/vec_util.py +0 -0
- {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_interface/world/__init__.py +0 -0
- {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_interface/world/combined_funcnode.py +0 -0
- {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_interface/world/combined_node.py +0 -0
- {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_interface/world/funcnode.py +0 -0
- {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_interface/world/funcworld.py +0 -0
- {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_interface/world/node.py +0 -0
- {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_interface/world/world.py +0 -0
- {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_interface/wrapper/__init__.py +0 -0
- {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_interface/wrapper/action_rescale.py +0 -0
- {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_interface/wrapper/backend_compat.py +0 -0
- {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_interface/wrapper/batch_and_unbatch.py +0 -0
- {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_interface/wrapper/control_frequency_limit.py +0 -0
- {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_interface/wrapper/flatten.py +0 -0
- {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_interface/wrapper/frame_stack.py +0 -0
- {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_interface/wrapper/gym_compat.py +0 -0
- {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_interface/wrapper/time_limit.py +0 -0
- {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_interface/wrapper/transformation.py +0 -0
- {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_interface/wrapper/video_record.py +0 -0
|
@@ -16,6 +16,7 @@ unienv_data/batches/combined_batch.py
|
|
|
16
16
|
unienv_data/batches/framestack_batch.py
|
|
17
17
|
unienv_data/batches/slicestack_batch.py
|
|
18
18
|
unienv_data/batches/transformations.py
|
|
19
|
+
unienv_data/integrations/huggingface.py
|
|
19
20
|
unienv_data/integrations/pytorch.py
|
|
20
21
|
unienv_data/replay_buffer/__init__.py
|
|
21
22
|
unienv_data/replay_buffer/replay_buffer.py
|
|
@@ -12,6 +12,102 @@ from unienv_interface.space.space_utils import batch_utils as sbu, flatten_utils
|
|
|
12
12
|
|
|
13
13
|
from ..base.common import BatchBase, IndexableType, BatchT
|
|
14
14
|
|
|
15
|
+
# Public names of this module: the index-conversion helpers and the combined batch class.
__all__ = ["convert_single_index_to_batch", "convert_index_to_batch", "CombinedBatch"]
|
|
20
|
+
|
|
21
|
+
def convert_single_index_to_batch(
    backend : ComputeBackend,
    index_starts : BArrayType, # (num_batches, )
    total_length : int, # Total length of all batches
    idx : int,
    device : Optional[BDeviceType] = None,
) -> Tuple[int, int]:
    """
    Map a single flat index of a combined batch onto the sub-batch that owns it.

    Args:
        backend: Compute backend used to reduce over ``index_starts``.
        index_starts: Start offset of each sub-batch in the combined index space
            (presumably ascending — TODO confirm with callers).
        total_length: Total number of elements across all sub-batches.
        idx: Flat index; negative values count from the end.
        device: Accepted but not used by this function.

    Returns:
        ``(batch_index, index_within_batch)``.

    Raises:
        AssertionError: if ``idx`` is outside ``[-total_length, total_length)``.
    """
    assert -total_length <= idx < total_length, f"Index {idx} out of bounds for batch of size {total_length}"
    flat_idx = idx + total_length if idx < 0 else idx
    # Count the sub-batches starting at or before flat_idx; the owner is the last of them.
    starts_not_after = backend.sum(flat_idx >= index_starts)
    owner = int(starts_not_after - 1)
    return owner, flat_idx - int(index_starts[owner])
|
|
40
|
+
|
|
41
|
+
def convert_index_to_batch(
    backend : ComputeBackend,
    index_starts : BArrayType, # (num_batches, )
    total_length : int, # Total length of all batches
    idx : Union[IndexableType, BArrayType],
    device : Optional[BDeviceType] = None,
) -> Tuple[
    int,
    Union[List[
        Tuple[int, BArrayType, BArrayType]
    ], int]
]:
    """
    Convert an index for this batch to a tuple of:
    - The length of the resulting array
    - List of tuples, each containing:
        - The index of the batch
        - The index to index into the batch
        - The bool mask to index into the resulting array

    Args:
        backend: Compute backend used for all array operations.
        index_starts: Start offset of each sub-batch in the combined index space
            (presumably ascending — TODO confirm with callers).
        total_length: Total number of elements across all sub-batches.
        idx: A ``slice``, ``Ellipsis``, or a 1-D boolean / integer backend array.
            Negative integer entries count from the end.
        device: Device on which index arrays are created for slice / Ellipsis input.

    Raises:
        AssertionError: if an array index is not 1-D, has a non-integer/non-boolean
            dtype, is a boolean mask of the wrong length, or holds out-of-range entries.
        ValueError: if ``idx`` is of an unsupported type.
    """
    if isinstance(idx, slice):
        # slice.indices() clamps start/stop, so every produced index lies in [0, total_length).
        idx_array = backend.arange(
            *idx.indices(total_length),
            dtype=backend.default_index_dtype,
            device=device
        )
    elif idx is Ellipsis:
        idx_array = backend.arange(
            total_length,
            dtype=backend.default_index_dtype,
            device=device
        )
    elif backend.is_backendarray(idx):
        assert len(idx.shape) == 1, "Index must be 1D"
        assert backend.dtype_is_real_integer(idx.dtype) or backend.dtype_is_boolean(idx.dtype), \
            f"Index must be of integer or boolean type, got {idx.dtype}"
        if backend.dtype_is_boolean(idx.dtype):
            assert idx.shape[0] == total_length, f"Boolean index must have the same length as the batch, got {idx.shape[0]} vs {total_length}"
            # Positions of True entries, necessarily within [0, total_length).
            idx_array = backend.nonzero(idx)[0]
        else:
            assert backend.all(backend.logical_and(-total_length <= idx, idx < total_length)), \
                f"Index array contains out of bounds indices for batch of size {total_length}"
            # Entries are bounds-checked above; this maps negative ones into [0, total_length).
            idx_array = (idx + total_length) % total_length
    else:
        raise ValueError(f"Invalid index type: {type(idx)}")

    # Every branch above yields indices in [0, total_length), so no further bounds
    # check or negative-index fix-up is needed before assigning sub-batches.

    # For each requested index, count the sub-batches that start at or before it;
    # that count minus one is the owning sub-batch.
    idx_array_bigger = idx_array[:, None] >= index_starts[None, :] # (idx_array_shape, num_batches)
    idx_array_batch_idx = backend.sum(
        idx_array_bigger,
        axis=-1
    ) - 1 # (idx_array_shape, )

    result_batch_list = []
    batch_indexes = backend.unique_values(idx_array_batch_idx)
    for i in range(batch_indexes.shape[0]):
        batch_index = int(batch_indexes[i])
        # Positions in the result that are served by this sub-batch.
        result_mask = idx_array_batch_idx == batch_index
        index_into_batch = idx_array[result_mask] - index_starts[batch_index]
        result_batch_list.append((batch_index, index_into_batch, result_mask))
    return idx_array.shape[0], result_batch_list
|
|
110
|
+
|
|
15
111
|
class CombinedBatch(BatchBase[
|
|
16
112
|
BatchT, BArrayType, BDeviceType, BDtypeType, BRNGType
|
|
17
113
|
]):
|
|
@@ -83,18 +179,13 @@ class CombinedBatch(BatchBase[
|
|
|
83
179
|
return self.index_caches[-1, 1]
|
|
84
180
|
|
|
85
181
|
def _convert_single_index(self, idx : int) -> Tuple[int, int]:
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
idx += len(self)
|
|
94
|
-
batch_index = int(self.backend.sum(
|
|
95
|
-
idx >= self.index_caches[:, 0]
|
|
96
|
-
) - 1)
|
|
97
|
-
return batch_index, idx - int(self.index_caches[batch_index, 0])
|
|
182
|
+
return convert_single_index_to_batch(
|
|
183
|
+
self.backend,
|
|
184
|
+
self.index_caches[:, 0],
|
|
185
|
+
len(self),
|
|
186
|
+
idx,
|
|
187
|
+
device=self.device,
|
|
188
|
+
)
|
|
98
189
|
|
|
99
190
|
def _convert_index(self, idx : Union[IndexableType, BArrayType]) -> Tuple[
|
|
100
191
|
int,
|
|
@@ -102,61 +193,13 @@ class CombinedBatch(BatchBase[
|
|
|
102
193
|
Tuple[int, BArrayType, BArrayType]
|
|
103
194
|
]
|
|
104
195
|
]:
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
"""
|
|
113
|
-
if isinstance(idx, slice):
|
|
114
|
-
idx_array = self.backend.arange(
|
|
115
|
-
*idx.indices(len(self)),
|
|
116
|
-
dtype=self.backend.default_integer_dtype,
|
|
117
|
-
device=self.device
|
|
118
|
-
)
|
|
119
|
-
elif idx is Ellipsis:
|
|
120
|
-
idx_array = self.backend.arange(
|
|
121
|
-
len(self),
|
|
122
|
-
dtype=self.backend.default_integer_dtype,
|
|
123
|
-
device=self.device
|
|
124
|
-
)
|
|
125
|
-
elif self.backend.is_backendarray(idx):
|
|
126
|
-
assert len(idx.shape) == 1, "Index must be 1D"
|
|
127
|
-
assert self.backend.dtype_is_real_integer(idx.dtype) or self.backend.dtype_is_boolean(idx.dtype), \
|
|
128
|
-
f"Index must be of integer or boolean type, got {idx.dtype}"
|
|
129
|
-
if self.backend.dtype_is_boolean(idx.dtype):
|
|
130
|
-
assert idx.shape[0] == len(self), f"Boolean index must have the same length as the batch, got {idx.shape[0]} vs {len(self)}"
|
|
131
|
-
idx_array = self.backend.nonzero(idx)[0]
|
|
132
|
-
else:
|
|
133
|
-
idx_array = idx
|
|
134
|
-
else:
|
|
135
|
-
raise ValueError(f"Invalid index type: {type(idx)}")
|
|
136
|
-
|
|
137
|
-
assert bool(self.backend.all(
|
|
138
|
-
self.backend.logical_and(
|
|
139
|
-
-len(self) <= idx_array,
|
|
140
|
-
idx_array < len(self)
|
|
141
|
-
)
|
|
142
|
-
)), f"Index {idx} converted to {idx_array} is out of bounds for batch of size {len(self)}"
|
|
143
|
-
|
|
144
|
-
# Convert negative indices to positive indices
|
|
145
|
-
idx_array = self.backend.at(idx_array)[idx_array < 0].add(len(self))
|
|
146
|
-
idx_array_bigger = idx_array[:, None] >= self.index_caches[None, :, 0] # (idx_array_shape, len(self.batches))
|
|
147
|
-
idx_array_batch_idx = self.backend.sum(
|
|
148
|
-
idx_array_bigger,
|
|
149
|
-
axis=-1
|
|
150
|
-
) - 1 # (idx_array_shape, )
|
|
151
|
-
|
|
152
|
-
result_batch_list = []
|
|
153
|
-
batch_indexes = self.backend.unique_values(idx_array_batch_idx)
|
|
154
|
-
for i in range(batch_indexes.shape[0]):
|
|
155
|
-
batch_index = int(batch_indexes[i])
|
|
156
|
-
result_mask = idx_array_batch_idx == batch_index
|
|
157
|
-
index_into_batch = idx_array[result_mask] - self.index_caches[batch_index, 0]
|
|
158
|
-
result_batch_list.append((batch_index, index_into_batch, result_mask))
|
|
159
|
-
return idx_array.shape[0], result_batch_list
|
|
196
|
+
return convert_index_to_batch(
|
|
197
|
+
self.backend,
|
|
198
|
+
self.index_caches[:, 0],
|
|
199
|
+
len(self),
|
|
200
|
+
idx,
|
|
201
|
+
device=self.device,
|
|
202
|
+
)
|
|
160
203
|
|
|
161
204
|
def get_flattened_at(self, idx : Union[IndexableType, BaseException]) -> BArrayType:
|
|
162
205
|
if isinstance(idx, int):
|
|
@@ -0,0 +1,47 @@
|
|
|
1
|
+
from typing import Optional
|
|
2
|
+
from datasets import Dataset as HFDataset
|
|
3
|
+
from unienv_interface.backends import ComputeBackend, BArrayType, BDeviceType, BDtypeType, BRNGType
|
|
4
|
+
from unienv_interface.space.space_utils import construct_utils as scu, batch_utils as sbu
|
|
5
|
+
from unienv_data.base import BatchBase, BatchT
|
|
6
|
+
|
|
7
|
+
# Public names of this module.
__all__ = ["HFAsUniEnvDataset"]
|
|
10
|
+
|
|
11
|
+
class HFAsUniEnvDataset(BatchBase[BatchT, BArrayType, BDeviceType, BDtypeType, BRNGType]):
    """
    Read-only UniEnv ``BatchBase`` view over a Hugging Face ``datasets.Dataset``.

    The wrapped dataset is re-formatted with ``with_format`` so that rows are returned
    in the array type of the chosen backend. The batch space is inferred from row 0.
    """

    # ComputeBackend.simplified_name -> Hugging Face ``with_format`` type name.
    BACKEND_TO_FORMAT_MAP = {
        "numpy": "numpy",
        "pytorch": "torch",
        "jax": "jax",
    }

    is_mutable = False

    def __init__(
        self,
        hf_dataset: HFDataset,
        backend: ComputeBackend[BArrayType, BDeviceType, BDtypeType, BRNGType],
        device : Optional[BDeviceType] = None,
    ) -> None:
        """
        Args:
            hf_dataset: Dataset to wrap. Row 0 is read to infer the space, so the
                dataset must not be empty.
            backend: Compute backend whose array type rows should be returned as.
            device: Forwarded to ``with_format`` for non-numpy backends; not passed
                for the numpy backend.

        Raises:
            KeyError: if ``backend.simplified_name`` is not in ``BACKEND_TO_FORMAT_MAP``.
        """
        kwargs = {}
        # The numpy format is not given a ``device`` argument.
        if backend.simplified_name != 'numpy':
            kwargs['device'] = device
        self.hf_dataset = hf_dataset.with_format(
            self.BACKEND_TO_FORMAT_MAP[backend.simplified_name],
            **kwargs
        )
        first_data = self.hf_dataset[0]
        super().__init__(
            scu.construct_space_from_data(first_data, backend)
        )

    def __len__(self) -> int:
        """Number of rows in the wrapped Hugging Face dataset."""
        return len(self.hf_dataset)

    def get_at_with_metadata(self, idx):
        """
        Return ``(data, metadata)`` for ``idx``; ``metadata`` is always an empty dict.

        A boolean backend array is first converted to the positions of its True
        entries. Every other index is forwarded unchanged to the Hugging Face dataset.
        """
        # The boolean check must inspect the array's dtype, not the array itself.
        if self.backend.is_backendarray(idx) and self.backend.dtype_is_boolean(idx.dtype):
            idx = self.backend.nonzero(idx)[0]
        return self.hf_dataset[idx], {}

    def get_at(self, idx):
        """Return only the data part of ``get_at_with_metadata(idx)``."""
        return self.get_at_with_metadata(idx)[0]
|
|
@@ -0,0 +1,122 @@
|
|
|
1
|
+
from typing import Optional, Union, Tuple, Dict, Any
|
|
2
|
+
|
|
3
|
+
from torch.utils.data import Dataset
|
|
4
|
+
|
|
5
|
+
from unienv_interface.backends import ComputeBackend, BArrayType, BDeviceType, BDtypeType, BRNGType
|
|
6
|
+
from unienv_interface.backends.pytorch import PyTorchComputeBackend, PyTorchArrayType, PyTorchDeviceType, PyTorchDtypeType, PyTorchRNGType
|
|
7
|
+
from unienv_interface.space.space_utils import construct_utils as scu, batch_utils as sbu
|
|
8
|
+
from unienv_data.base import BatchBase, BatchT
|
|
9
|
+
|
|
10
|
+
# Public names of this module: adapters in both directions between UniEnv and PyTorch.
__all__ = ["UniEnvAsPyTorchDataset", "PyTorchAsUniEnvDataset"]
|
|
14
|
+
|
|
15
|
+
class UniEnvAsPyTorchDataset(Dataset):
    """Expose a UniEnv ``BatchBase`` (PyTorch backend) as a map-style PyTorch ``Dataset``."""

    def __init__(
        self,
        batch : BatchBase[BatchT, PyTorchArrayType, PyTorchDeviceType, PyTorchDtypeType, PyTorchRNGType],
        include_metadata : bool = False,
    ):
        """
        A PyTorch Dataset wrapper for UniEnvPy batches.

        UniEnv's ``BatchBase`` collates data itself when indexed with an index array,
        so ``__getitems__`` already returns a batched result. With automatic batching
        (``batch_size`` set), a ``DataLoader`` passes that result to ``collate_fn``.
        Supply an identity ``collate_fn`` in that case, e.g. a module-level
        ``def identity(x): return x``. Note that ``collate_fn=None`` makes PyTorch fall
        back to its default collate function; it does not disable collation.

        Args:
            batch (BatchBase): The UniEnvPy batch to wrap.
            include_metadata (bool): If True, items are returned as ``(data, metadata)``
                from ``get_at_with_metadata`` instead of bare data.
        """
        self.batch = batch
        self.include_metadata = include_metadata

    def __len__(self) -> int:
        """
        Get the length of the dataset.

        Returns:
            int: The number of items in the dataset.
        """
        return len(self.batch)

    def __getitem__(self, index) -> Union[BatchT, Tuple[BatchT, Dict[str, Any]]]:
        """
        Get an item from the dataset.

        Args:
            index (int): The index of the item to retrieve. Integer-like scalars that
                implement ``__index__`` (e.g. ``numpy.int64``) are converted to ``int``.

        Returns:
            Union[BatchT, Tuple[BatchT, Dict[str, Any]]]: The batch data at the specified index,
                optionally including metadata if `include_metadata` is True.

        Raises:
            AssertionError: if ``index`` is not integer-like.
        """
        if not isinstance(index, int):
            import operator
            try:
                index = operator.index(index)
            except TypeError:
                pass  # not integer-like; rejected by the assertion below
        assert isinstance(index, int), "Index must be an integer."
        if self.include_metadata:
            return self.batch.get_at_with_metadata(index)
        return self.batch.get_at(index)

    def __getitems__(self, indices: list[int]) -> Union[BatchT, Tuple[BatchT, Dict[str, Any]]]:
        """
        Get multiple items from the dataset with one batched lookup.

        Args:
            indices (list[int]): The indices of the items to retrieve.

        Returns:
            Union[BatchT, Tuple[BatchT, Dict[str, Any]]]: The already-collated batch data
                for all ``indices`` (not a list of per-item results), optionally with
                metadata if `include_metadata` is True.
        """
        indices = self.batch.backend.asarray(indices, dtype=self.batch.backend.default_integer_dtype, device=self.batch.device)
        if self.include_metadata:
            return self.batch.get_at_with_metadata(indices)
        return self.batch.get_at(indices)
|
|
71
|
+
|
|
72
|
+
class PyTorchAsUniEnvDataset(BatchBase[BatchT, PyTorchArrayType, PyTorchDeviceType, PyTorchDtypeType, PyTorchRNGType]):
    """Read-only UniEnv ``BatchBase`` view over a map-style PyTorch ``Dataset``."""

    is_mutable = False

    def __init__(
        self,
        dataset: Dataset,
    ):
        """
        A UniEnvPy BatchBase wrapper for PyTorch Datasets.

        The single-instance space is inferred from ``dataset[0]``, so the dataset must
        define ``__len__`` and contain at least one item.

        Args:
            dataset (Dataset): The PyTorch Dataset to wrap.

        Raises:
            AssertionError: if ``dataset`` has no ``__len__`` or is empty.
        """
        # ``len()`` itself raises TypeError without ``__len__``, so test for it explicitly.
        assert hasattr(dataset, "__len__"), "The provided PyTorch Dataset must have a defined length."
        assert len(dataset) > 0, "The provided PyTorch Dataset must contain at least one item."
        self.dataset = dataset
        tmp_data = dataset[0]
        single_space = scu.construct_space_from_data(tmp_data, PyTorchComputeBackend)
        super().__init__(
            single_space,
            None
        )

    def __len__(self) -> int:
        """Number of items in the wrapped PyTorch dataset."""
        return len(self.dataset)

    def get_at_with_metadata(self, idx):
        """
        Return ``(data, metadata)`` for ``idx``; ``metadata`` is always an empty dict.

        Supported indices: Python ints (negative values count from the end),
        ``Ellipsis``, slices, and 1-D boolean or integer backend arrays.

        Raises:
            AssertionError: for out-of-range indices, a boolean mask of the wrong length,
                or an index array that is not 1-D or not integer/boolean typed.
            ValueError: for unsupported index types.
        """
        length = len(self)
        if isinstance(idx, int):
            assert -length <= idx < length, f"Index {idx} out of bounds for dataset of size {length}"
            # Normalize negative indices: not every Dataset implements them.
            return self.dataset[idx % length], {}
        elif idx is Ellipsis:
            idx = self.backend.arange(0, length)
        elif isinstance(idx, slice):
            idx = self.backend.arange(*idx.indices(length))
        elif self.backend.is_backendarray(idx):
            assert len(idx.shape) == 1, "Index must be 1D"
            if self.backend.dtype_is_boolean(idx.dtype):
                assert idx.shape[0] == length, f"Boolean index must have the same length as the dataset, got {idx.shape[0]} vs {length}"
                idx = self.backend.nonzero(idx)[0]
            else:
                assert self.backend.dtype_is_real_integer(idx.dtype), "Index array must be of integer or boolean type."
                # Reject out-of-range entries instead of letting the modulo wrap them silently.
                assert bool(self.backend.all(self.backend.logical_and(-length <= idx, idx < length))), \
                    f"Index array contains out of bounds indices for dataset of size {length}"
                idx = (idx + length) % length
        else:
            raise ValueError(f"Invalid index type: {type(idx)}")

        idx_list = self.backend.to_numpy(idx).tolist()
        if hasattr(self.dataset, "__getitems__"):
            data_list = self.dataset.__getitems__(idx_list)
        else:
            data_list = [self.dataset[i] for i in idx_list]
        # NOTE(review): data_list holds unbatched items; confirm that sbu.concatenate over
        # self._batched_space stacks them rather than joining along an existing axis.
        aggregated_data = sbu.concatenate(self._batched_space, data_list)
        return aggregated_data, {}

    def get_at(self, idx):
        """Return only the data part of ``get_at_with_metadata(idx)``."""
        data, _ = self.get_at_with_metadata(idx)
        return data
|
|
@@ -188,8 +188,12 @@ class EpisodeStorageBase(SpaceStorage[
|
|
|
188
188
|
index = self.backend.arange(0, self.capacity, device=self.device)
|
|
189
189
|
else:
|
|
190
190
|
index = self.backend.arange(0, self.length, device=self.device)
|
|
191
|
-
elif self.backend.is_backendarray(index)
|
|
192
|
-
|
|
191
|
+
elif self.backend.is_backendarray(index):
|
|
192
|
+
if self.backend.dtype_is_boolean(index.dtype):
|
|
193
|
+
index = self.backend.nonzero(index)[0]
|
|
194
|
+
else:
|
|
195
|
+
assert self.backend.dtype_is_real_integer(index.dtype), "Index array must be of integer or boolean type."
|
|
196
|
+
index = (index + (self.capacity if self.capacity is not None else self.length)) % (self.capacity if self.capacity is not None else self.length)
|
|
193
197
|
|
|
194
198
|
batch_size = index.shape[0] if self.backend.is_backendarray(index) else 1
|
|
195
199
|
|
|
@@ -368,8 +372,13 @@ class EpisodeStorageBase(SpaceStorage[
|
|
|
368
372
|
index = self.backend.arange(0, self.capacity, device=self.device)
|
|
369
373
|
else:
|
|
370
374
|
index = self.backend.arange(0, self.length, device=self.device)
|
|
371
|
-
elif self.backend.is_backendarray(index)
|
|
372
|
-
|
|
375
|
+
elif self.backend.is_backendarray(index):
|
|
376
|
+
if self.backend.dtype_is_boolean(index.dtype):
|
|
377
|
+
index = self.backend.nonzero(index)[0]
|
|
378
|
+
else:
|
|
379
|
+
assert self.backend.dtype_is_real_integer(index.dtype), "Index array must be of integer or boolean type."
|
|
380
|
+
index = (index + (self.capacity if self.capacity is not None else self.length)) % (self.capacity if self.capacity is not None else self.length)
|
|
381
|
+
|
|
373
382
|
assert self.backend.is_backendarray(index) and self.backend.dtype_is_real_integer(index.dtype) and len(index.shape) == 1, "Index must be a 1D array of integers"
|
|
374
383
|
sorted_indexes_arg = self.backend.argsort(index)
|
|
375
384
|
sorted_indexes = index[sorted_indexes_arg]
|
|
@@ -72,6 +72,7 @@ class PytorchTensorStorage(SpaceStorage[
|
|
|
72
72
|
capacity : Optional[int] = None,
|
|
73
73
|
read_only : bool = True,
|
|
74
74
|
multiprocessing : bool = False,
|
|
75
|
+
**kwargs,
|
|
75
76
|
) -> "PytorchTensorStorage":
|
|
76
77
|
assert single_instance_space.backend is PyTorchComputeBackend, "PytorchTensorStorage only supports PyTorch backend"
|
|
77
78
|
assert capacity is not None, "Capacity must be specified when creating a new tensor"
|