unienv 0.0.1b4__py3-none-any.whl → 0.0.1b6__py3-none-any.whl
This diff compares publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.
- {unienv-0.0.1b4.dist-info → unienv-0.0.1b6.dist-info}/METADATA +3 -2
- {unienv-0.0.1b4.dist-info → unienv-0.0.1b6.dist-info}/RECORD +43 -32
- {unienv-0.0.1b4.dist-info → unienv-0.0.1b6.dist-info}/WHEEL +1 -1
- unienv_data/base/common.py +25 -10
- unienv_data/base/storage.py +2 -0
- unienv_data/batches/backend_compat.py +1 -1
- unienv_data/batches/combined_batch.py +1 -1
- unienv_data/batches/slicestack_batch.py +1 -0
- unienv_data/replay_buffer/replay_buffer.py +179 -65
- unienv_data/replay_buffer/trajectory_replay_buffer.py +230 -163
- unienv_data/storages/_episode_storage.py +438 -0
- unienv_data/storages/_list_storage.py +136 -0
- unienv_data/storages/backend_compat.py +268 -0
- unienv_data/storages/dict_storage.py +39 -7
- unienv_data/storages/flattened.py +11 -4
- unienv_data/storages/hdf5.py +11 -0
- unienv_data/storages/image_storage.py +144 -0
- unienv_data/storages/npz_storage.py +135 -0
- unienv_data/storages/pytorch.py +17 -10
- unienv_data/storages/transformation.py +16 -1
- unienv_data/storages/video_storage.py +297 -0
- unienv_data/third_party/tensordict/memmap_tensor.py +1174 -0
- unienv_data/transformations/image_compress.py +97 -21
- unienv_interface/func_wrapper/frame_stack.py +1 -1
- unienv_interface/space/space_utils/batch_utils.py +5 -1
- unienv_interface/space/space_utils/flatten_utils.py +8 -2
- unienv_interface/space/spaces/dict.py +6 -0
- unienv_interface/space/spaces/tuple.py +4 -4
- unienv_interface/transformations/__init__.py +3 -1
- unienv_interface/transformations/batch_and_unbatch.py +42 -4
- unienv_interface/transformations/chained_transform.py +9 -8
- unienv_interface/transformations/crop.py +69 -0
- unienv_interface/transformations/dict_transform.py +8 -2
- unienv_interface/transformations/identity.py +16 -0
- unienv_interface/transformations/image_resize.py +106 -0
- unienv_interface/transformations/iter_transform.py +92 -0
- unienv_interface/transformations/rescale.py +24 -5
- unienv_interface/utils/symbol_util.py +7 -1
- unienv_interface/wrapper/backend_compat.py +1 -1
- unienv_interface/wrapper/frame_stack.py +1 -1
- {unienv-0.0.1b4.dist-info → unienv-0.0.1b6.dist-info}/licenses/LICENSE +0 -0
- {unienv-0.0.1b4.dist-info → unienv-0.0.1b6.dist-info}/top_level.txt +0 -0
- /unienv_interface/utils/{data_queue.py → framestack_queue.py} +0 -0
--- /dev/null
+++ unienv_data/storages/_episode_storage.py
@@ -0,0 +1,438 @@
+from importlib import metadata
+from typing import Generic, TypeVar, Generic, Optional, Any, Dict, Tuple, Sequence, Union, List, Iterable, Type
+
+from unienv_interface.space import Space
+from unienv_interface.space.space_utils import batch_utils as sbu, flatten_utils as sfu
+from unienv_interface.backends import ComputeBackend, BArrayType, BDeviceType, BDtypeType, BRNGType
+from unienv_interface.utils.symbol_util import *
+
+from unienv_data.base import SpaceStorage, BatchT, IndexableType
+
+import os
+import glob
+import shutil
+from abc import abstractmethod
+
+class EpisodeStorageBase(SpaceStorage[
+    BatchT,
+    BArrayType,
+    BDeviceType,
+    BDtypeType,
+    BRNGType,
+]):
+    """
+    Base class for episode storage implementations.
+    An episode storage stores episodes of data, where each episode can consist of multiple time steps.
+    Each episode is stored as a separate file in the specified cache directory, with a specified file extension.
+    The file naming convention is "{start_step_index}_{end_step_index}.{file_ext}", where start_index and end_index define the range of time steps in the episode.
+    Note that if the storage has a fixed capacity, the {end_step_index} can be smaller than {start_step_index} due to round-robin overwriting.
+    """
+
+    # ========== Instance Implementations ==========
+    single_file_ext = None
+
+    def __init__(
+        self,
+        single_instance_space: Space[Any, BDeviceType, BDtypeType, BRNGType],
+        file_ext : str,
+        cache_filename : Union[str, os.PathLike],
+        mutable : bool = True,
+        capacity : Optional[int] = None,
+        length : int = 0,
+    ):
+        assert cache_filename is not None, "EpisodeStorage requires a cache filename"
+        super().__init__(single_instance_space)
+        self._batched_single_space = sbu.batch_space(self.single_instance_space, 1)
+        self.file_ext = file_ext
+        self._cache_path = cache_filename
+        self.is_mutable = mutable
+        self.capacity = capacity
+        self.length = length if capacity is None else capacity
+        # Cache of file ranges: List[(start_idx, end_idx)]
+        self._file_ranges: List[Tuple[int, int]] = []
+        self._rebuild_file_range_cache()
+
+    def _make_filename(self, start_idx: int, end_idx: int) -> str:
+        """Construct a filename from start and end indices."""
+        if self.file_ext is not None:
+            return os.path.join(self._cache_path, f"{start_idx}_{end_idx}.{self.file_ext}")
+        else:
+            return os.path.join(self._cache_path, f"{start_idx}_{end_idx}")
+
+    def get_start_end_filename_iter(self) -> Iterable[Tuple[int, int, str]]:
+        """Iterate over (start_idx, end_idx, filename) tuples, constructing filenames on the fly."""
+        for start_idx, end_idx in self._file_ranges:
+            yield start_idx, end_idx, self._make_filename(start_idx, end_idx)
+
+    def _rebuild_file_range_cache(self):
+        """Rebuild the file range cache from disk."""
+        all_filenames = glob.glob(os.path.join(self._cache_path, f"*_*.{self.file_ext}" if self.file_ext is not None else "*_*"))
+        self._file_ranges = []
+        for filename in all_filenames:
+            base = os.path.basename(filename)
+            name_part = base[:-(len(self.file_ext) + 1)] if self.file_ext is not None else base
+            start_str, end_str = name_part.split("_")
+            start_idx = int(start_str)
+            end_idx = int(end_str)
+            self._file_ranges.append((start_idx, end_idx))
+        # Sort by start index for consistent iteration
+        self._file_ranges.sort(key=lambda x: x[0])
+
+    def _add_file_range(self, start_idx: int, end_idx: int):
+        """Add a file range to the cache."""
+        self._file_ranges.append((start_idx, end_idx))
+        self._file_ranges.sort(key=lambda x: x[0])
+
+    def _remove_file_range(self, start_idx: int, end_idx: int):
+        """Remove a file range from the cache."""
+        self._file_ranges = [(s, e) for s, e in self._file_ranges if not (s == start_idx and e == end_idx)]
+
+    def _get_file_length(self, filename: str) -> int:
+        """Get the length (number of time steps) stored in a given episode file."""
+        base = os.path.basename(filename)
+        name_part = base[:-(len(self.file_ext) + 1)] if self.file_ext is not None else base
+        start_str, end_str = name_part.split("_")
+        start_idx = int(start_str)
+        end_idx = int(end_str)
+        if start_idx <= end_idx:
+            return end_idx - start_idx + 1
+        else:
+            assert self.capacity is not None, "Wrap-around file length calculation requires fixed capacity"
+            return (self.capacity - start_idx) + (end_idx + 1)
+
+    @property
+    def cache_filename(self) -> Union[str, os.PathLike]:
+        return self._cache_path
+
+    @property
+    def is_multiprocessing_safe(self) -> bool:
+        return not self.is_mutable
+
+    def convert_read_index_to_filenames_and_offsets(
+        self,
+        index: Union[IndexableType, BArrayType]
+    ) -> Tuple[int, List[Tuple[Union[str, os.PathLike], Union[int, BArrayType], Union[int, BArrayType]]]]:
+        """
+        Convert an index (which can be an integer, slice, list of integers, or backend array) to a list of filenames and offsets.
+        Each filename corresponds to an episode file, and the offset indicates the position within that episode.
+        """
+        def generate_episode_ranges(
+            filename : Union[str, os.PathLike],
+            start_idx : int,
+            end_idx : int,
+            index : Union[int, BArrayType]
+        ) -> Optional[Tuple[
+            Union[str, os.PathLike], # File Path (Absolute)
+            Union[int, BArrayType], # Index in File
+            Union[int, BArrayType], # Index into Data (Batch)
+            Optional[BArrayType]] # Remaining Indexes
+        ]:
+            if isinstance(index, int):
+                if start_idx <= index <= end_idx:
+                    offset = index - start_idx
+                    return (filename, slice(offset, offset + 1), slice(0, 1), None)
+                elif self.capacity is not None and start_idx > end_idx and (
+                    index >= start_idx or index <= end_idx
+                ):
+                    # Handle wrap-around case for fixed-capacity storage
+                    if index >= start_idx:
+                        offset = index - start_idx
+                    else:
+                        offset = (self.capacity - start_idx) + index
+                    return (filename, slice(offset, offset + 1), slice(0, 1), None)
+                else:
+                    return None
+            else:
+                assert self.backend.is_backendarray(index)
+                assert self.backend.dtype_is_real_integer(index.dtype)
+
+                if start_idx <= end_idx:
+                    mask = self.backend.logical_and(
+                        index >= start_idx,
+                        index <= end_idx
+                    )
+                    if self.backend.sum(mask) == 0:
+                        return None
+                    in_range_indexes = index[mask]
+                    index_to_data = self.backend.nonzero(mask)[0]
+                    offsets = in_range_indexes - start_idx
+                    remaining_indexes = index[~mask]
+                elif self.capacity is not None and start_idx > end_idx:
+                    mask = self.backend.logical_or(
+                        index >= start_idx,
+                        index <= end_idx
+                    )
+                    if self.backend.sum(mask) == 0:
+                        return None
+                    in_range_indexes = index[mask]
+                    index_to_data = self.backend.nonzero(mask)[0]
+                    offsets = self.backend.where(
+                        in_range_indexes >= start_idx,
+                        in_range_indexes - start_idx,
+                        (self.capacity - start_idx) + in_range_indexes
+                    )
+                    remaining_indexes = index[~mask]
+                else:
+                    return None
+                if remaining_indexes.shape[0] == 0:
+                    remaining_indexes = None
+                return (filename, offsets, index_to_data, remaining_indexes)
+
+        if isinstance(index, slice):
+            if self.capacity is not None:
+                index = self.backend.arange(*index.indices(self.capacity), device=self.device)
+            else:
+                index = self.backend.arange(*index.indices(self.length), device=self.device)
+        elif index is Ellipsis:
+            if self.capacity is not None:
+                index = self.backend.arange(0, self.capacity, device=self.device)
+            else:
+                index = self.backend.arange(0, self.length, device=self.device)
+        elif self.backend.is_backendarray(index) and self.backend.dtype_is_boolean(index.dtype):
+            index = self.backend.nonzero(index)[0]
+
+        batch_size = index.shape[0] if self.backend.is_backendarray(index) else 1
+
+        remaining_index = index
+        all_results = []
+        for start_idx, end_idx, filename in self.get_start_end_filename_iter():
+            result = generate_episode_ranges(filename, start_idx, end_idx, remaining_index)
+            if result is not None:
+                filename, offsets, batch_indexes, remaining_index = result
+                all_results.append((filename, offsets, batch_indexes))
+            if remaining_index is None:
+                break
+
+        assert remaining_index is None, f"Indexes {remaining_index} were not found in any episode files."
+        return (batch_size, all_results)
+
+    def extend_length(self, length):
+        assert self.capacity is None, "Cannot extend length of a fixed-capacity storage"
+        self.length += length
+
+    def remove_index_range(self, start_index: int, end_index: int):
+        """
+        Remove data in the index range [start_index, end_index] (inclusive).
+        This handles both regular files and wrap-around files.
+        If start_index > end_index, this is a wrap-around removal covering [start_index, capacity) and [0, end_index].
+        """
+        assert self.is_mutable, "Cannot remove index range from a read-only storage"
+        if start_index > end_index:
+            # Wrap-around removal: remove [start_index, capacity) and [0, end_index]
+            assert self.capacity is not None, "Wrap-around removal requires a fixed capacity"
+            self.remove_index_range(start_index, self.capacity - 1)
+            self.remove_index_range(0, end_index)
+            return
+
+        to_save = []
+        files_to_remove = [] # List of (start_idx, end_idx, filename)
+        for file_start_idx, file_end_idx, filename in self.get_start_end_filename_iter():
+
+            if file_start_idx <= file_end_idx:
+                # Regular (non-wrapping) file
+                file_len = file_end_idx - file_start_idx + 1
+                if file_start_idx >= start_index and file_end_idx <= end_index:
+                    # Entire file is within the range to remove
+                    files_to_remove.append((file_start_idx, file_end_idx, filename))
+                elif file_start_idx <= end_index and file_end_idx >= start_index:
+                    # Partial overlap with the range to remove
+                    # Keep data before the removal range
+                    if file_start_idx < start_index:
+                        data_before = self.get_from_file(filename, slice(0, start_index - file_start_idx), file_len)
+                        to_save.append((data_before, file_start_idx, start_index - 1))
+                    # Keep data after the removal range
+                    if file_end_idx > end_index:
+                        data_after = self.get_from_file(filename, slice(end_index - file_start_idx + 1, file_end_idx - file_start_idx + 1), file_len)
+                        to_save.append((data_after, end_index + 1, file_end_idx))
+                    files_to_remove.append((file_start_idx, file_end_idx, filename))
+            else:
+                # Wrap-around file (file_start_idx > file_end_idx)
+                # File contains indices [file_start_idx, capacity) and [0, file_end_idx]
+                file_len = (self.capacity - file_start_idx) + file_end_idx + 1
+
+                # Check if removal range overlaps with the high part [file_start_idx, capacity)
+                high_overlap_start = max(start_index, file_start_idx)
+                high_overlap_end = min(end_index, self.capacity - 1)
+                high_overlaps = high_overlap_start <= high_overlap_end
+
+                # Check if removal range overlaps with the low part [0, file_end_idx]
+                low_overlap_start = max(start_index, 0)
+                low_overlap_end = min(end_index, file_end_idx)
+                low_overlaps = low_overlap_start <= low_overlap_end
+
+                if not high_overlaps and not low_overlaps:
+                    # No overlap, keep the file as is
+                    continue
+
+                # Calculate what to keep from the high part
+                high_part_len = self.capacity - file_start_idx
+                if high_overlaps:
+                    # Keep before the overlap
+                    if file_start_idx < high_overlap_start:
+                        keep_len = high_overlap_start - file_start_idx
+                        data = self.get_from_file(filename, slice(0, keep_len), file_len)
+                        to_save.append((data, file_start_idx, high_overlap_start - 1))
+                    # Keep after the overlap (still in high part)
+                    if high_overlap_end < self.capacity - 1:
+                        keep_start_offset = high_overlap_end - file_start_idx + 1
+                        keep_end_offset = high_part_len
+                        data = self.get_from_file(filename, slice(keep_start_offset, keep_end_offset), file_len)
+                        to_save.append((data, high_overlap_end + 1, self.capacity - 1))
+                else:
+                    # Keep entire high part
+                    data = self.get_from_file(filename, slice(0, high_part_len), file_len)
+                    to_save.append((data, file_start_idx, self.capacity - 1))
+
+                # Calculate what to keep from the low part
+                if low_overlaps:
+                    # Keep before the overlap
+                    if 0 < low_overlap_start:
+                        data = self.get_from_file(filename, slice(high_part_len, high_part_len + low_overlap_start), file_len)
+                        to_save.append((data, 0, low_overlap_start - 1))
+                    # Keep after the overlap
+                    if low_overlap_end < file_end_idx:
+                        keep_start_offset = high_part_len + low_overlap_end + 1
+                        keep_end_offset = file_len
+                        data = self.get_from_file(filename, slice(keep_start_offset, keep_end_offset), file_len)
+                        to_save.append((data, low_overlap_end + 1, file_end_idx))
+                else:
+                    # Keep entire low part
+                    data = self.get_from_file(filename, slice(high_part_len, file_len), file_len)
+                    to_save.append((data, 0, file_end_idx))
+
+                files_to_remove.append((file_start_idx, file_end_idx, filename))
+
+        # Remove files after reading all necessary data
+        for file_start_idx, file_end_idx, filename in files_to_remove:
+            os.remove(filename)
+            self._remove_file_range(file_start_idx, file_end_idx)
+
+        # Save the data that should be kept
+        for data, new_start_idx, new_end_idx in to_save:
+            new_filename = self._make_filename(new_start_idx, new_end_idx)
+            self.set_to_file(new_filename, data)
+            self._add_file_range(new_start_idx, new_end_idx)
+
+    def shrink_length(self, length):
+        assert self.is_mutable, "Cannot shrink length of a read-only storage"
+        assert self.capacity is None, "Cannot shrink length of a fixed-capacity storage"
+        from_len = self.length
+        to_len = max(from_len - length, 0)
+        self.remove_index_range(to_len, from_len - 1)
+        self.length = to_len
+
+    def __len__(self):
+        return self.length if self.capacity is None else self.capacity
+
+    @abstractmethod
+    def get_from_file(self, filename : str, index : Union[IndexableType, BArrayType], total_length : int) -> BatchT:
+        raise NotImplementedError
+
+    @abstractmethod
+    def set_to_file(self, filename : str, batched_value : BatchT):
+        raise NotImplementedError
+
+    def get(self, index):
+        batch_size, all_filename_offsets = self.convert_read_index_to_filenames_and_offsets(index)
+        result_space = sbu.batch_space(self.single_instance_space, batch_size)
+        result = result_space.create_empty()
+        for filename, file_offset, batch_indexes in all_filename_offsets:
+            file_len = self._get_file_length(filename)
+            data = self.get_from_file(filename, file_offset, file_len)
+            sbu.set_at(
+                result_space,
+                result,
+                batch_indexes,
+                data
+            )
+        if isinstance(index, int):
+            result = sbu.get_at(result_space, result, 0)
+        return result
+
+    def set(self, index, value):
+        assert self.is_mutable, "Storage is not mutable"
+        # Make sure the index is continuous
+        if isinstance(index, int):
+            self.remove_index_range(index, index)
+            filename = self._make_filename(index, index)
+            self.set_to_file(filename, sbu.concatenate(self._batched_single_space, [value]))
+            self._add_file_range(index, index)
+            return
+        if isinstance(index, slice):
+            if self.capacity is not None:
+                index = self.backend.arange(*index.indices(self.capacity), device=self.device)
+            else:
+                index = self.backend.arange(*index.indices(self.length), device=self.device)
+        elif index is Ellipsis:
+            if self.capacity is not None:
+                index = self.backend.arange(0, self.capacity, device=self.device)
+            else:
+                index = self.backend.arange(0, self.length, device=self.device)
+        elif self.backend.is_backendarray(index) and self.backend.dtype_is_boolean(index.dtype):
+            index = self.backend.nonzero(index)[0]
+        assert self.backend.is_backendarray(index) and self.backend.dtype_is_real_integer(index.dtype) and len(index.shape) == 1, "Index must be a 1D array of integers"
+        sorted_indexes_arg = self.backend.argsort(index)
+        sorted_indexes = index[sorted_indexes_arg]
+        diff_sorted = sorted_indexes[1:] - sorted_indexes[:-1]
+        discontinuities = self.backend.nonzero(diff_sorted > 1)[0]
+        if discontinuities.shape[0] == 0:
+            # Continuous
+            start_index = int(sorted_indexes[0])
+            end_index = int(sorted_indexes[-1])
+            self.remove_index_range(start_index, end_index)
+            filename = self._make_filename(start_index, end_index)
+            self.set_to_file(filename, sbu.get_at(
+                self._batched_single_space,
+                value,
+                sorted_indexes_arg
+            ))
+            self._add_file_range(start_index, end_index)
+            return
+        else:
+            assert discontinuities.shape[0] == 1, "Round-robin writes can only handle one discontinuity in the index"
+            assert self.capacity is not None, "Round-robin writes require a fixed capacity"
+            discontinuity_pos = int(discontinuities[0]) + 1 # Position after the gap in sorted order
+
+            # First segment (in sorted order): lower indices [0, ...]
+            first_segment_indexes = sorted_indexes[:discontinuity_pos]
+            first_start_index = int(first_segment_indexes[0])
+            first_end_index = int(first_segment_indexes[-1])
+
+            # Second segment (in sorted order): higher indices [..., capacity-1]
+            second_segment_indexes = sorted_indexes[discontinuity_pos:]
+            second_start_index = int(second_segment_indexes[0])
+            second_end_index = int(second_segment_indexes[-1])
+
+            assert first_start_index == 0 and second_end_index == self.capacity - 1, "In round-robin writes, the first segment must start at 0 and the second segment must end at capacity - 1"
+
+            # Remove old data using wrap-around removal (start > end)
+            self.remove_index_range(second_start_index, first_end_index)
+
+            # Reorder the data: high indices first, then low indices (wrap-around order)
+            # We want: [high indices data, low indices data]
+            wrap_order_arg = self.backend.concat([
+                sorted_indexes_arg[discontinuity_pos:], # high indices first
+                sorted_indexes_arg[:discontinuity_pos] # then low indices
+            ], axis=0)
+
+            # Write a single wrap-around file: start_index > end_index
+            # start_index is the first high index, end_index is the last low index
+            filename = self._make_filename(second_start_index, first_end_index)
+            self.set_to_file(filename, sbu.get_at(
+                self._batched_single_space,
+                value,
+                wrap_order_arg
+            ))
+            self._add_file_range(second_start_index, first_end_index)
+
+    def clear(self):
+        assert self.is_mutable, "Cannot clear a read-only storage"
+        if self.capacity is None:
+            self.length = 0
+        shutil.rmtree(self._cache_path)
+        os.makedirs(self._cache_path, exist_ok=True)
+        self._file_ranges = []
+
+    def close(self):
+        pass
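The subtle part of `_episode_storage.py` is the wrap-around bookkeeping: on a fixed-capacity store, a round-robin write produces a single file whose name encodes `start > end`, and reads must translate ring-buffer indices into in-file offsets. Below is a minimal pure-Python sketch of that arithmetic, mirroring `_get_file_length`, the scalar branch of `generate_episode_ranges`, and the discontinuity check in `set`; the helper names (`file_length`, `offset_in_file`, `split_round_robin`) and the capacity-100 figures are illustrative, not part of the package.

```python
# Sketch of EpisodeStorageBase's wrap-around arithmetic (plain Python).
# In a capacity-100 store, a file named "90_9.<ext>" covers ring indices
# 90..99 followed by 0..9, high part stored first.

def file_length(start_idx: int, end_idx: int, capacity: int) -> int:
    """Number of steps in a file named f"{start_idx}_{end_idx}" (inclusive)."""
    if start_idx <= end_idx:
        return end_idx - start_idx + 1
    # Wrap-around file: [start_idx, capacity) followed by [0, end_idx].
    return (capacity - start_idx) + (end_idx + 1)

def offset_in_file(index: int, start_idx: int, end_idx: int, capacity: int) -> int:
    """In-file position of ring index `index`, as in generate_episode_ranges."""
    if start_idx <= end_idx:
        assert start_idx <= index <= end_idx
        return index - start_idx
    assert index >= start_idx or index <= end_idx
    if index >= start_idx:
        return index - start_idx           # high part, stored first
    return (capacity - start_idx) + index  # low part, stored after the wrap

def split_round_robin(sorted_indexes: list, capacity: int):
    """Mirror of the discontinuity handling in EpisodeStorageBase.set: a write
    hitting [90..99, 0..9] arrives sorted as [0..9, 90..99]; the single gap
    splits it into low and high segments, stored as one wrap-around file."""
    gaps = [i + 1 for i in range(len(sorted_indexes) - 1)
            if sorted_indexes[i + 1] - sorted_indexes[i] > 1]
    assert len(gaps) == 1, "round-robin writes allow exactly one discontinuity"
    low, high = sorted_indexes[:gaps[0]], sorted_indexes[gaps[0]:]
    assert low[0] == 0 and high[-1] == capacity - 1
    return high[0], low[-1]  # filename encodes start=first high, end=last low

assert file_length(90, 9, capacity=100) == 20
assert offset_in_file(95, 90, 9, capacity=100) == 5   # still in the high part
assert offset_in_file(3, 90, 9, capacity=100) == 13   # 10 high steps, then 0..3
assert split_round_robin(list(range(10)) + list(range(90, 100)), 100) == (90, 9)
```

The `(90, 9)` result is exactly the `{start}_{end}` pair the class would embed in the wrap-around filename.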
--- /dev/null
+++ unienv_data/storages/_list_storage.py
@@ -0,0 +1,136 @@
+from importlib import metadata
+from typing import Generic, TypeVar, Generic, Optional, Any, Dict, Tuple, Sequence, Union, List, Iterable, Type
+
+from unienv_interface.space import Space
+from unienv_interface.space.space_utils import batch_utils as sbu, flatten_utils as sfu
+from unienv_interface.backends import ComputeBackend, BArrayType, BDeviceType, BDtypeType, BRNGType
+from unienv_interface.utils.symbol_util import *
+
+from unienv_data.base import SpaceStorage, BatchT, IndexableType
+
+import os
+import shutil
+from abc import abstractmethod
+
+def batched_index_to_list(
+    backend : ComputeBackend[BArrayType, BDeviceType, BDtypeType, BRNGType],
+    batched_index : Union[BArrayType, IndexableType],
+    length : int
+) -> List[int]:
+    if isinstance(batched_index, slice):
+        return list(range(*batched_index.indices(length)))
+    elif batched_index is Ellipsis or batched_index is None:
+        return list(range(length))
+    else: # backend.is_backendarray
+        assert backend.is_backendarray(batched_index)
+        assert len(batched_index.shape) == 1
+        if backend.dtype_is_boolean(batched_index.dtype):
+            return [i for i in range(batched_index.shape[0]) if batched_index[i]]
+        elif backend.dtype_is_real_integer(batched_index.dtype):
+            return [batched_index[i] for i in range(batched_index.shape[0])]
+        else:
+            raise ValueError(f"Unsupported index type {type(batched_index)}")
+
+class ListStorageBase(SpaceStorage[
+    BatchT,
+    BArrayType,
+    BDeviceType,
+    BDtypeType,
+    BRNGType,
+]):
+    # ========== Instance Implementations ==========
+    single_file_ext = None
+
+    def __init__(
+        self,
+        single_instance_space: Space[Any, BDeviceType, BDtypeType, BRNGType],
+        file_ext : str,
+        cache_filename : Union[str, os.PathLike],
+        mutable : bool = True,
+        capacity : Optional[int] = None,
+        length : int = 0,
+    ):
+        assert cache_filename is not None, "ListStorage requires a cache filename"
+        super().__init__(single_instance_space)
+        self._batched_single_space = sbu.batch_space(self.single_instance_space, 1)
+        self.file_ext = file_ext
+        self._cache_path = cache_filename
+        self.is_mutable = mutable
+        self.capacity = capacity
+        self.length = length if capacity is None else capacity
+
+    @property
+    def cache_filename(self) -> Union[str, os.PathLike]:
+        return self._cache_path
+
+    @property
+    def is_multiprocessing_safe(self) -> bool:
+        return True
+
+    def extend_length(self, length):
+        assert self.capacity is None, "Cannot extend length of a fixed-capacity storage"
+        self.length += length
+
+    def shrink_length(self, length):
+        assert self.is_mutable, "Cannot shrink length of a read-only storage"
+        assert self.capacity is None, "Cannot shrink length of a fixed-capacity storage"
+        from_len = self.length
+        to_len = max(from_len - length, 0)
+        all_files = os.listdir(self._cache_path)
+        for i in range(to_len, from_len):
+            if f"{i}.{self.file_ext}" in all_files:
+                os.remove(os.path.join(self._cache_path, f"{i}.{self.file_ext}"))
+        self.length = to_len
+
+    def __len__(self):
+        return self.length if self.capacity is None else self.capacity
+
+    @abstractmethod
+    def get_from_file(self, filename : str) -> BatchT:
+        raise NotImplementedError
+
+    @abstractmethod
+    def set_to_file(self, filename : str, value : BatchT):
+        raise NotImplementedError
+
+    def get_single(self, index : int) -> BatchT:
+        assert 0 <= index < self.length, f"Index {index} out of bounds for storage of length {self.length}"
+        filename = os.path.join(self._cache_path, f"{index}.{self.file_ext}")
+        return self.get_from_file(filename)
+
+    def set_single(self, index : int, value : BArrayType):
+        assert self.is_mutable, "Storage is not mutable"
+        assert 0 <= index < self.length, f"Index {index} out of bounds for storage of length {self.length}"
+        filename = os.path.join(self._cache_path, f"{index}.{self.file_ext}")
+        self.set_to_file(filename, value)
+
+    def get(self, index):
+        if isinstance(index, int):
+            result = self.get_single(index)
+        else:
+            result = sbu.concatenate(
+                self._batched_single_space,
+                [
+                    self.get_single(i) for i in batched_index_to_list(self.backend, index, len(self))
+                ]
+            )
+        return result
+
+    def set(self, index, value):
+        assert self.is_mutable, "Storage is not mutable"
+        if isinstance(index, int):
+            self.set_single(index, value)
+        else:
+            indexes = batched_index_to_list(self.backend, index, len(self))
+            for i, ind in enumerate(indexes):
+                self.set_single(ind, sbu.get_at(self._batched_single_space, value, i))
+
+    def clear(self):
+        assert self.is_mutable, "Cannot clear a read-only storage"
+        if self.capacity is None:
+            self.length = 0
+        shutil.rmtree(self._cache_path)
+        os.makedirs(self._cache_path, exist_ok=True)
+
+    def close(self):
+        pass
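`_list_storage.py` is the simpler counterpart: one file per index, named `{index}.{file_ext}`, with batched `get`/`set` reduced to per-index loops over `batched_index_to_list`. A toy stand-in for that normalization step, using plain Python lists where the class accepts backend arrays (an assumption made for brevity; `to_index_list` is hypothetical and not part of the package):

```python
def to_index_list(index, length):
    """Toy mirror of batched_index_to_list: normalize the index forms
    accepted by ListStorageBase.get/set into integer positions."""
    if isinstance(index, slice):
        return list(range(*index.indices(length)))
    if index is Ellipsis or index is None:
        return list(range(length))
    if all(isinstance(i, bool) for i in index):   # boolean mask
        return [i for i, keep in enumerate(index) if keep]
    return list(index)                            # integer indices

assert to_index_list(slice(1, 7, 2), 10) == [1, 3, 5]
assert to_index_list(..., 3) == [0, 1, 2]
assert to_index_list([True, False, True], 3) == [0, 2]
assert to_index_list([4, 1], 10) == [4, 1]
# Each resulting position i maps to one cache file named f"{i}.{file_ext}".
```

Because every position owns its own file, the class reports `is_multiprocessing_safe` as `True` even when mutable, presumably since concurrent writers to distinct indices never share a range file the way `EpisodeStorageBase` writers can.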