unienv 0.0.1b7__tar.gz → 0.0.1b8__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (112)
  1. {unienv-0.0.1b7/unienv.egg-info → unienv-0.0.1b8}/PKG-INFO +1 -1
  2. {unienv-0.0.1b7 → unienv-0.0.1b8}/pyproject.toml +1 -1
  3. {unienv-0.0.1b7 → unienv-0.0.1b8/unienv.egg-info}/PKG-INFO +1 -1
  4. {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv.egg-info/SOURCES.txt +1 -0
  5. {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_data/batches/combined_batch.py +110 -67
  6. unienv-0.0.1b8/unienv_data/integrations/huggingface.py +47 -0
  7. unienv-0.0.1b8/unienv_data/integrations/pytorch.py +122 -0
  8. {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_data/storages/_episode_storage.py +13 -4
  9. {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_data/storages/pytorch.py +1 -0
  10. unienv-0.0.1b8/unienv_data/storages/video_storage.py +535 -0
  11. unienv-0.0.1b7/unienv_data/integrations/pytorch.py +0 -63
  12. unienv-0.0.1b7/unienv_data/storages/video_storage.py +0 -297
  13. {unienv-0.0.1b7 → unienv-0.0.1b8}/LICENSE +0 -0
  14. {unienv-0.0.1b7 → unienv-0.0.1b8}/README.md +0 -0
  15. {unienv-0.0.1b7 → unienv-0.0.1b8}/setup.cfg +0 -0
  16. {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv.egg-info/dependency_links.txt +0 -0
  17. {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv.egg-info/requires.txt +0 -0
  18. {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv.egg-info/top_level.txt +0 -0
  19. {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_data/__init__.py +0 -0
  20. {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_data/base/__init__.py +0 -0
  21. {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_data/base/common.py +0 -0
  22. {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_data/base/storage.py +0 -0
  23. {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_data/batches/__init__.py +0 -0
  24. {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_data/batches/backend_compat.py +0 -0
  25. {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_data/batches/framestack_batch.py +0 -0
  26. {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_data/batches/slicestack_batch.py +0 -0
  27. {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_data/batches/transformations.py +0 -0
  28. {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_data/replay_buffer/__init__.py +0 -0
  29. {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_data/replay_buffer/replay_buffer.py +0 -0
  30. {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_data/replay_buffer/trajectory_replay_buffer.py +0 -0
  31. {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_data/samplers/__init__.py +0 -0
  32. {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_data/samplers/multiprocessing_sampler.py +0 -0
  33. {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_data/samplers/step_sampler.py +0 -0
  34. {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_data/storages/_list_storage.py +0 -0
  35. {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_data/storages/backend_compat.py +0 -0
  36. {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_data/storages/dict_storage.py +0 -0
  37. {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_data/storages/flattened.py +0 -0
  38. {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_data/storages/hdf5.py +0 -0
  39. {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_data/storages/image_storage.py +0 -0
  40. {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_data/storages/npz_storage.py +0 -0
  41. {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_data/storages/transformation.py +0 -0
  42. {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_data/third_party/tensordict/memmap_tensor.py +0 -0
  43. {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_data/transformations/image_compress.py +0 -0
  44. {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_interface/__init__.py +0 -0
  45. {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_interface/backends/__init__.py +0 -0
  46. {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_interface/backends/base.py +0 -0
  47. {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_interface/backends/jax.py +0 -0
  48. {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_interface/backends/numpy.py +0 -0
  49. {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_interface/backends/pytorch.py +0 -0
  50. {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_interface/backends/serialization.py +0 -0
  51. {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_interface/env_base/__init__.py +0 -0
  52. {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_interface/env_base/env.py +0 -0
  53. {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_interface/env_base/funcenv.py +0 -0
  54. {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_interface/env_base/funcenv_wrapper.py +0 -0
  55. {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_interface/env_base/vec_env.py +0 -0
  56. {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_interface/env_base/wrapper.py +0 -0
  57. {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_interface/func_wrapper/__init__.py +0 -0
  58. {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_interface/func_wrapper/frame_stack.py +0 -0
  59. {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_interface/func_wrapper/transformation.py +0 -0
  60. {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_interface/space/__init__.py +0 -0
  61. {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_interface/space/space.py +0 -0
  62. {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_interface/space/space_utils/__init__.py +0 -0
  63. {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_interface/space/space_utils/batch_utils.py +0 -0
  64. {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_interface/space/space_utils/construct_utils.py +0 -0
  65. {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_interface/space/space_utils/flatten_utils.py +0 -0
  66. {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_interface/space/space_utils/gym_utils.py +0 -0
  67. {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_interface/space/space_utils/serialization_utils.py +0 -0
  68. {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_interface/space/spaces/__init__.py +0 -0
  69. {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_interface/space/spaces/batched.py +0 -0
  70. {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_interface/space/spaces/binary.py +0 -0
  71. {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_interface/space/spaces/box.py +0 -0
  72. {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_interface/space/spaces/dict.py +0 -0
  73. {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_interface/space/spaces/dynamic_box.py +0 -0
  74. {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_interface/space/spaces/graph.py +0 -0
  75. {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_interface/space/spaces/text.py +0 -0
  76. {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_interface/space/spaces/tuple.py +0 -0
  77. {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_interface/space/spaces/union.py +0 -0
  78. {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_interface/transformations/__init__.py +0 -0
  79. {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_interface/transformations/batch_and_unbatch.py +0 -0
  80. {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_interface/transformations/chained_transform.py +0 -0
  81. {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_interface/transformations/crop.py +0 -0
  82. {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_interface/transformations/dict_transform.py +0 -0
  83. {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_interface/transformations/filter_dict.py +0 -0
  84. {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_interface/transformations/identity.py +0 -0
  85. {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_interface/transformations/image_resize.py +0 -0
  86. {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_interface/transformations/iter_transform.py +0 -0
  87. {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_interface/transformations/rescale.py +0 -0
  88. {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_interface/transformations/transformation.py +0 -0
  89. {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_interface/utils/control_util.py +0 -0
  90. {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_interface/utils/framestack_queue.py +0 -0
  91. {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_interface/utils/seed_util.py +0 -0
  92. {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_interface/utils/stateclass.py +0 -0
  93. {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_interface/utils/symbol_util.py +0 -0
  94. {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_interface/utils/vec_util.py +0 -0
  95. {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_interface/world/__init__.py +0 -0
  96. {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_interface/world/combined_funcnode.py +0 -0
  97. {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_interface/world/combined_node.py +0 -0
  98. {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_interface/world/funcnode.py +0 -0
  99. {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_interface/world/funcworld.py +0 -0
  100. {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_interface/world/node.py +0 -0
  101. {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_interface/world/world.py +0 -0
  102. {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_interface/wrapper/__init__.py +0 -0
  103. {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_interface/wrapper/action_rescale.py +0 -0
  104. {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_interface/wrapper/backend_compat.py +0 -0
  105. {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_interface/wrapper/batch_and_unbatch.py +0 -0
  106. {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_interface/wrapper/control_frequency_limit.py +0 -0
  107. {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_interface/wrapper/flatten.py +0 -0
  108. {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_interface/wrapper/frame_stack.py +0 -0
  109. {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_interface/wrapper/gym_compat.py +0 -0
  110. {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_interface/wrapper/time_limit.py +0 -0
  111. {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_interface/wrapper/transformation.py +0 -0
  112. {unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_interface/wrapper/video_record.py +0 -0
{unienv-0.0.1b7/unienv.egg-info → unienv-0.0.1b8}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: unienv
-Version: 0.0.1b7
+Version: 0.0.1b8
 Summary: Unified robot environment framework supporting multiple tensor and simulation backends
 License-Expression: MIT
 Project-URL: Homepage, https://github.com/UniEnvOrg/UniEnv
{unienv-0.0.1b7 → unienv-0.0.1b8}/pyproject.toml
@@ -3,7 +3,7 @@ name = "unienv"
 description = "Unified robot environment framework supporting multiple tensor and simulation backends"
 readme = "README.md"
 license = "MIT"
-version = "0.0.1b7"
+version = "0.0.1b8"
 requires-python = ">= 3.10"
 dependencies = [
     "numpy",
{unienv-0.0.1b7 → unienv-0.0.1b8/unienv.egg-info}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: unienv
-Version: 0.0.1b7
+Version: 0.0.1b8
 Summary: Unified robot environment framework supporting multiple tensor and simulation backends
 License-Expression: MIT
 Project-URL: Homepage, https://github.com/UniEnvOrg/UniEnv
{unienv-0.0.1b7 → unienv-0.0.1b8}/unienv.egg-info/SOURCES.txt
@@ -16,6 +16,7 @@ unienv_data/batches/combined_batch.py
 unienv_data/batches/framestack_batch.py
 unienv_data/batches/slicestack_batch.py
 unienv_data/batches/transformations.py
+unienv_data/integrations/huggingface.py
 unienv_data/integrations/pytorch.py
 unienv_data/replay_buffer/__init__.py
 unienv_data/replay_buffer/replay_buffer.py
{unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_data/batches/combined_batch.py
@@ -12,6 +12,102 @@ from unienv_interface.space.space_utils import batch_utils as sbu, flatten_utils
 
 from ..base.common import BatchBase, IndexableType, BatchT
 
+__all__ = [
+    "convert_single_index_to_batch",
+    "convert_index_to_batch",
+    "CombinedBatch",
+]
+
+def convert_single_index_to_batch(
+    backend : ComputeBackend,
+    index_starts : BArrayType, # (num_batches, )
+    total_length : int, # Total length of all batches
+    idx : int,
+    device : Optional[BDeviceType] = None,
+) -> Tuple[int, int]:
+    """
+    Convert a single index for this batch to a tuple of:
+    - The index of the batch
+    - The index to index into the batch
+    """
+    assert -total_length <= idx < total_length, f"Index {idx} out of bounds for batch of size {total_length}"
+    if idx < 0:
+        idx += total_length
+    batch_index = int(backend.sum(
+        idx >= index_starts
+    ) - 1)
+    return batch_index, idx - int(index_starts[batch_index])
+
+def convert_index_to_batch(
+    backend : ComputeBackend,
+    index_starts : BArrayType, # (num_batches, )
+    total_length : int, # Total length of all batches
+    idx : Union[IndexableType, BArrayType],
+    device : Optional[BDeviceType] = None,
+) -> Tuple[
+    int,
+    Union[List[
+        Tuple[int, BArrayType, BArrayType]
+    ], int]
+]:
+    """
+    Convert an index for this batch to a tuple of:
+    - The length of the resulting array
+    - List of tuples, each containing:
+        - The index of the batch
+        - The index to index into the batch
+        - The bool mask to index into the resulting array
+    """
+    if isinstance(idx, slice):
+        idx_array = backend.arange(
+            *idx.indices(total_length),
+            dtype=backend.default_index_dtype,
+            device=device
+        )
+    elif idx is Ellipsis:
+        idx_array = backend.arange(
+            total_length,
+            dtype=backend.default_index_dtype,
+            device=device
+        )
+    elif backend.is_backendarray(idx):
+        assert len(idx.shape) == 1, "Index must be 1D"
+        assert backend.dtype_is_real_integer(idx.dtype) or backend.dtype_is_boolean(idx.dtype), \
+            f"Index must be of integer or boolean type, got {idx.dtype}"
+        if backend.dtype_is_boolean(idx.dtype):
+            assert idx.shape[0] == total_length, f"Boolean index must have the same length as the batch, got {idx.shape[0]} vs {total_length}"
+            idx_array = backend.nonzero(idx)[0]
+        else:
+            assert backend.all(backend.logical_and(-total_length <= idx, idx < total_length)), \
+                f"Index array contains out of bounds indices for batch of size {total_length}"
+            idx_array = (idx + total_length) % total_length
+    else:
+        raise ValueError(f"Invalid index type: {type(idx)}")
+
+    assert bool(backend.all(
+        backend.logical_and(
+            -total_length <= idx_array,
+            idx_array < total_length
+        )
+    )), f"Index {idx} converted to {idx_array} is out of bounds for batch of size {total_length}"
+
+    # Convert negative indices to positive indices
+    idx_array = backend.at(idx_array)[idx_array < 0].add(total_length)
+    idx_array_bigger = idx_array[:, None] >= index_starts[None, :] # (idx_array_shape, len(self.batches))
+    idx_array_batch_idx = backend.sum(
+        idx_array_bigger,
+        axis=-1
+    ) - 1 # (idx_array_shape, )
+
+    result_batch_list = []
+    batch_indexes = backend.unique_values(idx_array_batch_idx)
+    for i in range(batch_indexes.shape[0]):
+        batch_index = int(batch_indexes[i])
+        result_mask = idx_array_batch_idx == batch_index
+        index_into_batch = idx_array[result_mask] - index_starts[batch_index]
+        result_batch_list.append((batch_index, index_into_batch, result_mask))
+    return idx_array.shape[0], result_batch_list
+
 class CombinedBatch(BatchBase[
     BatchT, BArrayType, BDeviceType, BDtypeType, BRNGType
 ]):
@@ -83,18 +179,13 @@ class CombinedBatch(BatchBase[
         return self.index_caches[-1, 1]
 
     def _convert_single_index(self, idx : int) -> Tuple[int, int]:
-        """
-        Convert a single index to a tuple containing
-        - the batch index
-        - the index within the batch
-        """
-        assert -len(self) <= idx < len(self), f"Index {idx} out of bounds for batch of size {len(self)}"
-        if idx < 0:
-            idx += len(self)
-        batch_index = int(self.backend.sum(
-            idx >= self.index_caches[:, 0]
-        ) - 1)
-        return batch_index, idx - int(self.index_caches[batch_index, 0])
+        return convert_single_index_to_batch(
+            self.backend,
+            self.index_caches[:, 0],
+            len(self),
+            idx,
+            device=self.device,
+        )
 
     def _convert_index(self, idx : Union[IndexableType, BArrayType]) -> Tuple[
         int,
@@ -102,61 +193,13 @@ class CombinedBatch(BatchBase[
             Tuple[int, BArrayType, BArrayType]
         ]
     ]:
-        """
-        Convert an index for this batch to a tuple of:
-        - The length of the resulting array
-        - List of tuples, each containing:
-            - The index of the batch
-            - The index to index into the batch
-            - The bool mask to index into the resulting array
-        """
-        if isinstance(idx, slice):
-            idx_array = self.backend.arange(
-                *idx.indices(len(self)),
-                dtype=self.backend.default_integer_dtype,
-                device=self.device
-            )
-        elif idx is Ellipsis:
-            idx_array = self.backend.arange(
-                len(self),
-                dtype=self.backend.default_integer_dtype,
-                device=self.device
-            )
-        elif self.backend.is_backendarray(idx):
-            assert len(idx.shape) == 1, "Index must be 1D"
-            assert self.backend.dtype_is_real_integer(idx.dtype) or self.backend.dtype_is_boolean(idx.dtype), \
-                f"Index must be of integer or boolean type, got {idx.dtype}"
-            if self.backend.dtype_is_boolean(idx.dtype):
-                assert idx.shape[0] == len(self), f"Boolean index must have the same length as the batch, got {idx.shape[0]} vs {len(self)}"
-                idx_array = self.backend.nonzero(idx)[0]
-            else:
-                idx_array = idx
-        else:
-            raise ValueError(f"Invalid index type: {type(idx)}")
-
-        assert bool(self.backend.all(
-            self.backend.logical_and(
-                -len(self) <= idx_array,
-                idx_array < len(self)
-            )
-        )), f"Index {idx} converted to {idx_array} is out of bounds for batch of size {len(self)}"
-
-        # Convert negative indices to positive indices
-        idx_array = self.backend.at(idx_array)[idx_array < 0].add(len(self))
-        idx_array_bigger = idx_array[:, None] >= self.index_caches[None, :, 0] # (idx_array_shape, len(self.batches))
-        idx_array_batch_idx = self.backend.sum(
-            idx_array_bigger,
-            axis=-1
-        ) - 1 # (idx_array_shape, )
-
-        result_batch_list = []
-        batch_indexes = self.backend.unique_values(idx_array_batch_idx)
-        for i in range(batch_indexes.shape[0]):
-            batch_index = int(batch_indexes[i])
-            result_mask = idx_array_batch_idx == batch_index
-            index_into_batch = idx_array[result_mask] - self.index_caches[batch_index, 0]
-            result_batch_list.append((batch_index, index_into_batch, result_mask))
-        return idx_array.shape[0], result_batch_list
+        return convert_index_to_batch(
+            self.backend,
+            self.index_caches[:, 0],
+            len(self),
+            idx,
+            device=self.device,
+        )
 
     def get_flattened_at(self, idx : Union[IndexableType, BaseException]) -> BArrayType:
         if isinstance(idx, int):
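Note: the combined_batch.py changes above lift the index-mapping logic out of CombinedBatch into the module-level helpers convert_single_index_to_batch and convert_index_to_batch, so the same arithmetic can be reused without a CombinedBatch instance. The following is an illustrative plain-NumPy sketch of that arithmetic (it does not use the package's ComputeBackend API):

    import numpy as np

    index_starts = np.array([0, 4, 6])   # three sub-batches of lengths 4, 2, 3
    total_length = 9

    def to_batch_index(idx):
        # Mirror of convert_single_index_to_batch: wrap negatives, then find
        # the last sub-batch whose start offset is <= idx.
        if idx < 0:
            idx += total_length
        batch_index = int(np.sum(idx >= index_starts)) - 1
        return batch_index, idx - int(index_starts[batch_index])

    print(to_batch_index(5))    # (1, 1): global index 5 is item 1 of the second sub-batch
    print(to_batch_index(-1))   # (2, 2): negative indices wrap to the end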
unienv-0.0.1b8/unienv_data/integrations/huggingface.py (new file)
@@ -0,0 +1,47 @@
+from typing import Optional
+from datasets import Dataset as HFDataset
+from unienv_interface.backends import ComputeBackend, BArrayType, BDeviceType, BDtypeType, BRNGType
+from unienv_interface.space.space_utils import construct_utils as scu, batch_utils as sbu
+from unienv_data.base import BatchBase, BatchT
+
+__all__ = [
+    'HFAsUniEnvDataset'
+]
+
+class HFAsUniEnvDataset(BatchBase[BatchT, BArrayType, BDeviceType, BDtypeType, BRNGType]):
+    BACKEND_TO_FORMAT_MAP = {
+        "numpy": "numpy",
+        "pytorch": "torch",
+        "jax": "jax",
+    }
+
+    is_mutable = False
+
+    def __init__(
+        self,
+        hf_dataset: HFDataset,
+        backend: ComputeBackend[BArrayType, BDeviceType, BDtypeType, BRNGType],
+        device : Optional[BDeviceType] = None,
+    ) -> None:
+        kwargs = {}
+        if backend.simplified_name != 'numpy':
+            kwargs['device'] = device
+        self.hf_dataset = hf_dataset.with_format(
+            self.BACKEND_TO_FORMAT_MAP[backend.simplified_name],
+            **kwargs
+        )
+        first_data = self.hf_dataset[0]
+        super().__init__(
+            scu.construct_space_from_data(first_data, backend)
+        )
+
+    def __len__(self) -> int:
+        return len(self.hf_dataset)
+
+    def get_at_with_metadata(self, idx):
+        if self.backend.is_backendarray(idx) and self.backend.dtype_is_boolean(idx):
+            idx = self.backend.nonzero(idx)[0]
+        return self.hf_dataset[idx], {}
+
+    def get_at(self, idx):
+        return self.get_at_with_metadata(idx)[0]
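A hypothetical usage sketch for the new Hugging Face adapter (not taken from the package; passing the backend as the class object mirrors how PyTorchComputeBackend is used in the PyTorch integration below, but the exact calling convention is an assumption):

    from datasets import Dataset as HFDataset
    from unienv_data.integrations.huggingface import HFAsUniEnvDataset
    from unienv_interface.backends.pytorch import PyTorchComputeBackend

    # Build a tiny in-memory HF dataset and expose it as a read-only UniEnv batch.
    hf_ds = HFDataset.from_dict({"obs": [[0.0, 1.0], [2.0, 3.0]], "reward": [0.0, 1.0]})
    batch = HFAsUniEnvDataset(hf_ds, backend=PyTorchComputeBackend)

    print(len(batch))       # 2
    print(batch.get_at(0))  # first row, formatted as torch tensors via datasets.with_format("torch")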
unienv-0.0.1b8/unienv_data/integrations/pytorch.py (new file)
@@ -0,0 +1,122 @@
+from typing import Optional, Union, Tuple, Dict, Any
+
+from torch.utils.data import Dataset
+
+from unienv_interface.backends import ComputeBackend, BArrayType, BDeviceType, BDtypeType, BRNGType
+from unienv_interface.backends.pytorch import PyTorchComputeBackend, PyTorchArrayType, PyTorchDeviceType, PyTorchDtypeType, PyTorchRNGType
+from unienv_interface.space.space_utils import construct_utils as scu, batch_utils as sbu
+from unienv_data.base import BatchBase, BatchT
+
+__all__ = [
+    "UniEnvAsPyTorchDataset",
+    "PyTorchAsUniEnvDataset",
+]
+
+class UniEnvAsPyTorchDataset(Dataset):
+    def __init__(
+        self,
+        batch : BatchBase[BatchT, PyTorchArrayType, PyTorchDeviceType, PyTorchDtypeType, PyTorchRNGType],
+        include_metadata : bool = False,
+    ):
+        """
+        A PyTorch Dataset wrapper for UniEnvPy batches.
+        Note that UniEnv's `BatchBase` will automatically collate data when indexed with batches, and therefore in the dataloader you can set `collate_fn=None`.
+
+        Args:
+            batch (BatchBase): The UniEnvPy batch to wrap.
+        """
+        self.batch = batch
+        self.include_metadata = include_metadata
+
+    def __len__(self) -> int:
+        """
+        Get the length of the dataset.
+
+        Returns:
+            int: The number of items in the dataset.
+        """
+        return len(self.batch)
+
+    def __getitem__(self, index) -> Union[BatchT, Tuple[BatchT, Dict[str, Any]]]:
+        """
+        Get an item from the dataset.
+
+        Args:
+            index (int): The index of the item to retrieve.
+
+        Returns:
+            Union[BatchT, Tuple[BatchT, Dict[str, Any]]]: The batch data at the specified index,
+            optionally including metadata if `include_metadata` is True.
+        """
+        assert isinstance(index, int), "Index must be an integer."
+        if self.include_metadata:
+            return self.batch.get_at_with_metadata(index)
+        return self.batch.get_at(index)
+
+    def __getitems__(self, indices: list[int]) -> list[Union[BatchT, Tuple[BatchT, Dict[str, Any]]]]:
+        """
+        Get multiple items from the dataset.
+
+        Args:
+            indices (list[int]): The indices of the items to retrieve.
+
+        Returns:
+            Union[BatchT, Tuple[BatchT, Dict[str, Any]]]: Batch data at the specified indices,
+            optionally including metadata if `include_metadata` is True.
+        """
+        indices = self.batch.backend.asarray(indices, dtype=self.batch.backend.default_integer_dtype, device=self.batch.device)
+        if self.include_metadata:
+            return self.batch.get_at_with_metadata(indices)
+        return self.batch.get_at(indices)
+
+class PyTorchAsUniEnvDataset(BatchBase[BatchT, PyTorchArrayType, PyTorchDeviceType, PyTorchDtypeType, PyTorchRNGType]):
+    is_mutable = False
+    def __init__(
+        self,
+        dataset: Dataset,
+    ):
+        """
+        A UniEnvPy BatchBase wrapper for PyTorch Datasets.
+
+        Args:
+            dataset (Dataset): The PyTorch Dataset to wrap.
+        """
+
+        assert len(dataset) >= 0, "The provided PyTorch Dataset must have a defined length."
+        self.dataset = dataset
+        tmp_data = dataset[0]
+        single_space = scu.construct_space_from_data(tmp_data, PyTorchComputeBackend)
+        super().__init__(
+            single_space,
+            None
+        )
+
+    def __len__(self) -> int:
+        return len(self.dataset)
+
+    def get_at_with_metadata(self, idx):
+        if isinstance(idx, int):
+            data = self.dataset[idx]
+            return data, {}
+        elif idx is Ellipsis:
+            idx = self.backend.arange(0, len(self))
+        elif isinstance(idx, slice):
+            idx = self.backend.arange(*idx.indices(len(self)))
+        elif self.backend.is_backendarray(idx):
+            if self.backend.dtype_is_boolean(idx.dtype):
+                idx = self.backend.nonzero(idx)[0]
+            else:
+                assert self.backend.dtype_is_real_integer(idx.dtype), "Index array must be of integer or boolean type."
+                idx = (idx + len(self)) % len(self)
+
+        idx_list = self.backend.to_numpy(idx).tolist()
+        if hasattr(self.dataset, "__getitems__"):
+            data_list = self.dataset.__getitems__(idx_list)
+        else:
+            data_list = [self.dataset[i] for i in idx_list]
+        aggregated_data = sbu.concatenate(self._batched_space, data_list)
+        return aggregated_data, {}
+
+    def get_at(self, idx):
+        data, _ = self.get_at_with_metadata(idx)
+        return data
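A hypothetical sketch of feeding a UniEnv batch to a torch DataLoader through the new wrapper (not from the package; replay_batch stands for any existing BatchBase instance). The docstring above notes that BatchBase collates itself when indexed with an index array, so an identity collate_fn is used here to pass the pre-collated result through unchanged:

    from torch.utils.data import DataLoader
    from unienv_data.integrations.pytorch import UniEnvAsPyTorchDataset

    dataset = UniEnvAsPyTorchDataset(replay_batch)   # replay_batch: any BatchBase (placeholder)
    loader = DataLoader(
        dataset,
        batch_size=32,
        shuffle=True,
        collate_fn=lambda collated: collated,        # __getitems__ already returns collated data
    )

    for minibatch in loader:
        ...  # minibatch is the structure produced by BatchBase.get_at on an index array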
{unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_data/storages/_episode_storage.py
@@ -188,8 +188,12 @@ class EpisodeStorageBase(SpaceStorage[
                 index = self.backend.arange(0, self.capacity, device=self.device)
             else:
                 index = self.backend.arange(0, self.length, device=self.device)
-        elif self.backend.is_backendarray(index) and self.backend.dtype_is_boolean(index.dtype):
-            index = self.backend.nonzero(index)[0]
+        elif self.backend.is_backendarray(index):
+            if self.backend.dtype_is_boolean(index.dtype):
+                index = self.backend.nonzero(index)[0]
+            else:
+                assert self.backend.dtype_is_real_integer(index.dtype), "Index array must be of integer or boolean type."
+                index = (index + (self.capacity if self.capacity is not None else self.length)) % (self.capacity if self.capacity is not None else self.length)
 
         batch_size = index.shape[0] if self.backend.is_backendarray(index) else 1
 
@@ -368,8 +372,13 @@ class EpisodeStorageBase(SpaceStorage[
                 index = self.backend.arange(0, self.capacity, device=self.device)
             else:
                 index = self.backend.arange(0, self.length, device=self.device)
-        elif self.backend.is_backendarray(index) and self.backend.dtype_is_boolean(index.dtype):
-            index = self.backend.nonzero(index)[0]
+        elif self.backend.is_backendarray(index):
+            if self.backend.dtype_is_boolean(index.dtype):
+                index = self.backend.nonzero(index)[0]
+            else:
+                assert self.backend.dtype_is_real_integer(index.dtype), "Index array must be of integer or boolean type."
+                index = (index + (self.capacity if self.capacity is not None else self.length)) % (self.capacity if self.capacity is not None else self.length)
+
         assert self.backend.is_backendarray(index) and self.backend.dtype_is_real_integer(index.dtype) and len(index.shape) == 1, "Index must be a 1D array of integers"
         sorted_indexes_arg = self.backend.argsort(index)
         sorted_indexes = index[sorted_indexes_arg]
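The _episode_storage.py hunks above extend index handling so that integer index arrays are normalized with a modulo before use, which makes negative indices valid. A plain-NumPy illustration of the normalization (not the package API):

    import numpy as np

    size = 8                               # capacity if set, otherwise the current length
    index = np.array([0, -1, 3, -8])
    normalized = (index + size) % size     # in-range positives are unchanged, negatives wrap
    print(normalized)                      # [0 7 3 0]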
{unienv-0.0.1b7 → unienv-0.0.1b8}/unienv_data/storages/pytorch.py
@@ -72,6 +72,7 @@ class PytorchTensorStorage(SpaceStorage[
         capacity : Optional[int] = None,
         read_only : bool = True,
         multiprocessing : bool = False,
+        **kwargs,
     ) -> "PytorchTensorStorage":
         assert single_instance_space.backend is PyTorchComputeBackend, "PytorchTensorStorage only supports PyTorch backend"
         assert capacity is not None, "Capacity must be specified when creating a new tensor"