unienv 0.0.1b4__py3-none-any.whl → 0.0.1b6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43)
  1. {unienv-0.0.1b4.dist-info → unienv-0.0.1b6.dist-info}/METADATA +3 -2
  2. {unienv-0.0.1b4.dist-info → unienv-0.0.1b6.dist-info}/RECORD +43 -32
  3. {unienv-0.0.1b4.dist-info → unienv-0.0.1b6.dist-info}/WHEEL +1 -1
  4. unienv_data/base/common.py +25 -10
  5. unienv_data/base/storage.py +2 -0
  6. unienv_data/batches/backend_compat.py +1 -1
  7. unienv_data/batches/combined_batch.py +1 -1
  8. unienv_data/batches/slicestack_batch.py +1 -0
  9. unienv_data/replay_buffer/replay_buffer.py +179 -65
  10. unienv_data/replay_buffer/trajectory_replay_buffer.py +230 -163
  11. unienv_data/storages/_episode_storage.py +438 -0
  12. unienv_data/storages/_list_storage.py +136 -0
  13. unienv_data/storages/backend_compat.py +268 -0
  14. unienv_data/storages/dict_storage.py +39 -7
  15. unienv_data/storages/flattened.py +11 -4
  16. unienv_data/storages/hdf5.py +11 -0
  17. unienv_data/storages/image_storage.py +144 -0
  18. unienv_data/storages/npz_storage.py +135 -0
  19. unienv_data/storages/pytorch.py +17 -10
  20. unienv_data/storages/transformation.py +16 -1
  21. unienv_data/storages/video_storage.py +297 -0
  22. unienv_data/third_party/tensordict/memmap_tensor.py +1174 -0
  23. unienv_data/transformations/image_compress.py +97 -21
  24. unienv_interface/func_wrapper/frame_stack.py +1 -1
  25. unienv_interface/space/space_utils/batch_utils.py +5 -1
  26. unienv_interface/space/space_utils/flatten_utils.py +8 -2
  27. unienv_interface/space/spaces/dict.py +6 -0
  28. unienv_interface/space/spaces/tuple.py +4 -4
  29. unienv_interface/transformations/__init__.py +3 -1
  30. unienv_interface/transformations/batch_and_unbatch.py +42 -4
  31. unienv_interface/transformations/chained_transform.py +9 -8
  32. unienv_interface/transformations/crop.py +69 -0
  33. unienv_interface/transformations/dict_transform.py +8 -2
  34. unienv_interface/transformations/identity.py +16 -0
  35. unienv_interface/transformations/image_resize.py +106 -0
  36. unienv_interface/transformations/iter_transform.py +92 -0
  37. unienv_interface/transformations/rescale.py +24 -5
  38. unienv_interface/utils/symbol_util.py +7 -1
  39. unienv_interface/wrapper/backend_compat.py +1 -1
  40. unienv_interface/wrapper/frame_stack.py +1 -1
  41. {unienv-0.0.1b4.dist-info → unienv-0.0.1b6.dist-info}/licenses/LICENSE +0 -0
  42. {unienv-0.0.1b4.dist-info → unienv-0.0.1b6.dist-info}/top_level.txt +0 -0
  43. /unienv_interface/utils/{data_queue.py → framestack_queue.py} +0 -0
unienv_data/storages/_episode_storage.py
@@ -0,0 +1,438 @@
+ from importlib import metadata
+ from typing import Generic, TypeVar, Optional, Any, Dict, Tuple, Sequence, Union, List, Iterable, Type
+
+ from unienv_interface.space import Space
+ from unienv_interface.space.space_utils import batch_utils as sbu, flatten_utils as sfu
+ from unienv_interface.backends import ComputeBackend, BArrayType, BDeviceType, BDtypeType, BRNGType
+ from unienv_interface.utils.symbol_util import *
+
+ from unienv_data.base import SpaceStorage, BatchT, IndexableType
+
+ import os
+ import glob
+ import shutil
+ from abc import abstractmethod
+
+ class EpisodeStorageBase(SpaceStorage[
+     BatchT,
+     BArrayType,
+     BDeviceType,
+     BDtypeType,
+     BRNGType,
+ ]):
+     """
+     Base class for episode storage implementations.
+     An episode storage stores episodes of data, where each episode can consist of multiple time steps.
+     Each episode is stored as a separate file in the specified cache directory, with a specified file extension.
+     The file naming convention is "{start_step_index}_{end_step_index}.{file_ext}", where {start_step_index} and {end_step_index} define the inclusive range of time steps in the episode.
+     Note that if the storage has a fixed capacity, {end_step_index} can be smaller than {start_step_index} due to round-robin overwriting.
+     """
+
+     # ========== Instance Implementations ==========
+     single_file_ext = None
+
+     def __init__(
+         self,
+         single_instance_space: Space[Any, BDeviceType, BDtypeType, BRNGType],
+         file_ext : str,
+         cache_filename : Union[str, os.PathLike],
+         mutable : bool = True,
+         capacity : Optional[int] = None,
+         length : int = 0,
+     ):
+         assert cache_filename is not None, "EpisodeStorage requires a cache filename"
+         super().__init__(single_instance_space)
+         self._batched_single_space = sbu.batch_space(self.single_instance_space, 1)
+         self.file_ext = file_ext
+         self._cache_path = cache_filename
+         self.is_mutable = mutable
+         self.capacity = capacity
+         self.length = length if capacity is None else capacity
+         # Cache of file ranges: List[(start_idx, end_idx)]
+         self._file_ranges: List[Tuple[int, int]] = []
+         self._rebuild_file_range_cache()
+
+     def _make_filename(self, start_idx: int, end_idx: int) -> str:
+         """Construct a filename from start and end indices."""
+         if self.file_ext is not None:
+             return os.path.join(self._cache_path, f"{start_idx}_{end_idx}.{self.file_ext}")
+         else:
+             return os.path.join(self._cache_path, f"{start_idx}_{end_idx}")
+
+     def get_start_end_filename_iter(self) -> Iterable[Tuple[int, int, str]]:
+         """Iterate over (start_idx, end_idx, filename) tuples, constructing filenames on the fly."""
+         for start_idx, end_idx in self._file_ranges:
+             yield start_idx, end_idx, self._make_filename(start_idx, end_idx)
+
+     def _rebuild_file_range_cache(self):
+         """Rebuild the file range cache from disk."""
+         all_filenames = glob.glob(os.path.join(self._cache_path, f"*_*.{self.file_ext}" if self.file_ext is not None else "*_*"))
+         self._file_ranges = []
+         for filename in all_filenames:
+             base = os.path.basename(filename)
+             name_part = base[:-(len(self.file_ext) + 1)] if self.file_ext is not None else base
+             start_str, end_str = name_part.split("_")
+             start_idx = int(start_str)
+             end_idx = int(end_str)
+             self._file_ranges.append((start_idx, end_idx))
+         # Sort by start index for consistent iteration
+         self._file_ranges.sort(key=lambda x: x[0])
+
+     def _add_file_range(self, start_idx: int, end_idx: int):
+         """Add a file range to the cache."""
+         self._file_ranges.append((start_idx, end_idx))
+         self._file_ranges.sort(key=lambda x: x[0])
+
+     def _remove_file_range(self, start_idx: int, end_idx: int):
+         """Remove a file range from the cache."""
+         self._file_ranges = [(s, e) for s, e in self._file_ranges if not (s == start_idx and e == end_idx)]
+
+     def _get_file_length(self, filename: str) -> int:
+         """Get the length (number of time steps) stored in a given episode file."""
+         base = os.path.basename(filename)
+         name_part = base[:-(len(self.file_ext) + 1)] if self.file_ext is not None else base
+         start_str, end_str = name_part.split("_")
+         start_idx = int(start_str)
+         end_idx = int(end_str)
+         if start_idx <= end_idx:
+             return end_idx - start_idx + 1
+         else:
+             assert self.capacity is not None, "Wrap-around file length calculation requires fixed capacity"
+             return (self.capacity - start_idx) + (end_idx + 1)
+
+     @property
+     def cache_filename(self) -> Union[str, os.PathLike]:
+         return self._cache_path
+
+     @property
+     def is_multiprocessing_safe(self) -> bool:
+         return not self.is_mutable
+
+     def convert_read_index_to_filenames_and_offsets(
+         self,
+         index: Union[IndexableType, BArrayType]
+     ) -> Tuple[int, List[Tuple[Union[str, os.PathLike], Union[int, BArrayType], Union[int, BArrayType]]]]:
+         """
+         Convert an index (which can be an integer, slice, list of integers, or backend array) to a list of filenames and offsets.
+         Each filename corresponds to an episode file, and the offset indicates the position within that episode.
+         """
+         def generate_episode_ranges(
+             filename : Union[str, os.PathLike],
+             start_idx : int,
+             end_idx : int,
+             index : Union[int, BArrayType]
+         ) -> Optional[Tuple[
+             Union[str, os.PathLike], # File Path (Absolute)
+             Union[int, BArrayType], # Index in File
+             Union[int, BArrayType], # Index into Data (Batch)
+             Optional[BArrayType]] # Remaining Indexes
+         ]:
+             if isinstance(index, int):
+                 if start_idx <= index <= end_idx:
+                     offset = index - start_idx
+                     return (filename, slice(offset, offset + 1), slice(0, 1), None)
+                 elif self.capacity is not None and start_idx > end_idx and (
+                     index >= start_idx or index <= end_idx
+                 ):
+                     # Handle wrap-around case for fixed-capacity storage
+                     if index >= start_idx:
+                         offset = index - start_idx
+                     else:
+                         offset = (self.capacity - start_idx) + index
+                     return (filename, slice(offset, offset + 1), slice(0, 1), None)
+                 else:
+                     return None
+             else:
+                 assert self.backend.is_backendarray(index)
+                 assert self.backend.dtype_is_real_integer(index.dtype)
+
+                 if start_idx <= end_idx:
+                     mask = self.backend.logical_and(
+                         index >= start_idx,
+                         index <= end_idx
+                     )
+                     if self.backend.sum(mask) == 0:
+                         return None
+                     in_range_indexes = index[mask]
+                     index_to_data = self.backend.nonzero(mask)[0]
+                     offsets = in_range_indexes - start_idx
+                     remaining_indexes = index[~mask]
+                 elif self.capacity is not None and start_idx > end_idx:
+                     mask = self.backend.logical_or(
+                         index >= start_idx,
+                         index <= end_idx
+                     )
+                     if self.backend.sum(mask) == 0:
+                         return None
+                     in_range_indexes = index[mask]
+                     index_to_data = self.backend.nonzero(mask)[0]
+                     offsets = self.backend.where(
+                         in_range_indexes >= start_idx,
+                         in_range_indexes - start_idx,
+                         (self.capacity - start_idx) + in_range_indexes
+                     )
+                     remaining_indexes = index[~mask]
+                 else:
+                     return None
+                 if remaining_indexes.shape[0] == 0:
+                     remaining_indexes = None
+                 return (filename, offsets, index_to_data, remaining_indexes)
+
+         if isinstance(index, slice):
+             if self.capacity is not None:
+                 index = self.backend.arange(*index.indices(self.capacity), device=self.device)
+             else:
+                 index = self.backend.arange(*index.indices(self.length), device=self.device)
+         elif index is Ellipsis:
+             if self.capacity is not None:
+                 index = self.backend.arange(0, self.capacity, device=self.device)
+             else:
+                 index = self.backend.arange(0, self.length, device=self.device)
+         elif self.backend.is_backendarray(index) and self.backend.dtype_is_boolean(index.dtype):
+             index = self.backend.nonzero(index)[0]
+
+         batch_size = index.shape[0] if self.backend.is_backendarray(index) else 1
+
+         remaining_index = index
+         all_results = []
+         for start_idx, end_idx, filename in self.get_start_end_filename_iter():
+             result = generate_episode_ranges(filename, start_idx, end_idx, remaining_index)
+             if result is not None:
+                 filename, offsets, batch_indexes, remaining_index = result
+                 all_results.append((filename, offsets, batch_indexes))
+             if remaining_index is None:
+                 break
+
+         assert remaining_index is None, f"Indexes {remaining_index} were not found in any episode files."
+         return (batch_size, all_results)
+
+     def extend_length(self, length):
+         assert self.capacity is None, "Cannot extend length of a fixed-capacity storage"
+         self.length += length
+
+     def remove_index_range(self, start_index: int, end_index: int):
+         """
+         Remove data in the index range [start_index, end_index] (inclusive).
+         This handles both regular files and wrap-around files.
+         If start_index > end_index, this is a wrap-around removal covering [start_index, capacity) and [0, end_index].
+         """
+         assert self.is_mutable, "Cannot remove index range from a read-only storage"
+         if start_index > end_index:
+             # Wrap-around removal: remove [start_index, capacity) and [0, end_index]
+             assert self.capacity is not None, "Wrap-around removal requires a fixed capacity"
+             self.remove_index_range(start_index, self.capacity - 1)
+             self.remove_index_range(0, end_index)
+             return
+
+         to_save = []
+         files_to_remove = [] # List of (start_idx, end_idx, filename)
+         for file_start_idx, file_end_idx, filename in self.get_start_end_filename_iter():
+
+             if file_start_idx <= file_end_idx:
+                 # Regular (non-wrapping) file
+                 file_len = file_end_idx - file_start_idx + 1
+                 if file_start_idx >= start_index and file_end_idx <= end_index:
+                     # Entire file is within the range to remove
+                     files_to_remove.append((file_start_idx, file_end_idx, filename))
+                 elif file_start_idx <= end_index and file_end_idx >= start_index:
+                     # Partial overlap with the range to remove
+                     # Keep data before the removal range
+                     if file_start_idx < start_index:
+                         data_before = self.get_from_file(filename, slice(0, start_index - file_start_idx), file_len)
+                         to_save.append((data_before, file_start_idx, start_index - 1))
+                     # Keep data after the removal range
+                     if file_end_idx > end_index:
+                         data_after = self.get_from_file(filename, slice(end_index - file_start_idx + 1, file_end_idx - file_start_idx + 1), file_len)
+                         to_save.append((data_after, end_index + 1, file_end_idx))
+                     files_to_remove.append((file_start_idx, file_end_idx, filename))
+             else:
+                 # Wrap-around file (file_start_idx > file_end_idx)
+                 # File contains indices [file_start_idx, capacity) and [0, file_end_idx]
+                 file_len = (self.capacity - file_start_idx) + file_end_idx + 1
+
+                 # Check if removal range overlaps with the high part [file_start_idx, capacity)
+                 high_overlap_start = max(start_index, file_start_idx)
+                 high_overlap_end = min(end_index, self.capacity - 1)
+                 high_overlaps = high_overlap_start <= high_overlap_end
+
+                 # Check if removal range overlaps with the low part [0, file_end_idx]
+                 low_overlap_start = max(start_index, 0)
+                 low_overlap_end = min(end_index, file_end_idx)
+                 low_overlaps = low_overlap_start <= low_overlap_end
+
+                 if not high_overlaps and not low_overlaps:
+                     # No overlap, keep the file as is
+                     continue
+
+                 # Calculate what to keep from the high part
+                 high_part_len = self.capacity - file_start_idx
+                 if high_overlaps:
+                     # Keep before the overlap
+                     if file_start_idx < high_overlap_start:
+                         keep_len = high_overlap_start - file_start_idx
+                         data = self.get_from_file(filename, slice(0, keep_len), file_len)
+                         to_save.append((data, file_start_idx, high_overlap_start - 1))
+                     # Keep after the overlap (still in high part)
+                     if high_overlap_end < self.capacity - 1:
+                         keep_start_offset = high_overlap_end - file_start_idx + 1
+                         keep_end_offset = high_part_len
+                         data = self.get_from_file(filename, slice(keep_start_offset, keep_end_offset), file_len)
+                         to_save.append((data, high_overlap_end + 1, self.capacity - 1))
+                 else:
+                     # Keep entire high part
+                     data = self.get_from_file(filename, slice(0, high_part_len), file_len)
+                     to_save.append((data, file_start_idx, self.capacity - 1))
+
+                 # Calculate what to keep from the low part
+                 if low_overlaps:
+                     # Keep before the overlap
+                     if 0 < low_overlap_start:
+                         data = self.get_from_file(filename, slice(high_part_len, high_part_len + low_overlap_start), file_len)
+                         to_save.append((data, 0, low_overlap_start - 1))
+                     # Keep after the overlap
+                     if low_overlap_end < file_end_idx:
+                         keep_start_offset = high_part_len + low_overlap_end + 1
+                         keep_end_offset = file_len
+                         data = self.get_from_file(filename, slice(keep_start_offset, keep_end_offset), file_len)
+                         to_save.append((data, low_overlap_end + 1, file_end_idx))
+                 else:
+                     # Keep entire low part
+                     data = self.get_from_file(filename, slice(high_part_len, file_len), file_len)
+                     to_save.append((data, 0, file_end_idx))
+
+                 files_to_remove.append((file_start_idx, file_end_idx, filename))
+
+         # Remove files after reading all necessary data
+         for file_start_idx, file_end_idx, filename in files_to_remove:
+             os.remove(filename)
+             self._remove_file_range(file_start_idx, file_end_idx)
+
+         # Save the data that should be kept
+         for data, new_start_idx, new_end_idx in to_save:
+             new_filename = self._make_filename(new_start_idx, new_end_idx)
+             self.set_to_file(new_filename, data)
+             self._add_file_range(new_start_idx, new_end_idx)
+
+     def shrink_length(self, length):
+         assert self.is_mutable, "Cannot shrink length of a read-only storage"
+         assert self.capacity is None, "Cannot shrink length of a fixed-capacity storage"
+         from_len = self.length
+         to_len = max(from_len - length, 0)
+         self.remove_index_range(to_len, from_len - 1)
+         self.length = to_len
+
+     def __len__(self):
+         return self.length if self.capacity is None else self.capacity
+
+     @abstractmethod
+     def get_from_file(self, filename : str, index : Union[IndexableType, BArrayType], total_length : int) -> BatchT:
+         raise NotImplementedError
+
+     @abstractmethod
+     def set_to_file(self, filename : str, batched_value : BatchT):
+         raise NotImplementedError
+
+     def get(self, index):
+         batch_size, all_filename_offsets = self.convert_read_index_to_filenames_and_offsets(index)
+         result_space = sbu.batch_space(self.single_instance_space, batch_size)
+         result = result_space.create_empty()
+         for filename, file_offset, batch_indexes in all_filename_offsets:
+             file_len = self._get_file_length(filename)
+             data = self.get_from_file(filename, file_offset, file_len)
+             sbu.set_at(
+                 result_space,
+                 result,
+                 batch_indexes,
+                 data
+             )
+         if isinstance(index, int):
+             result = sbu.get_at(result_space, result, 0)
+         return result
+
+     def set(self, index, value):
+         assert self.is_mutable, "Storage is not mutable"
+         # Make sure the index is continuous
+         if isinstance(index, int):
+             self.remove_index_range(index, index)
+             filename = self._make_filename(index, index)
+             self.set_to_file(filename, sbu.concatenate(self._batched_single_space, [value]))
+             self._add_file_range(index, index)
+             return
+         if isinstance(index, slice):
+             if self.capacity is not None:
+                 index = self.backend.arange(*index.indices(self.capacity), device=self.device)
+             else:
+                 index = self.backend.arange(*index.indices(self.length), device=self.device)
+         elif index is Ellipsis:
+             if self.capacity is not None:
+                 index = self.backend.arange(0, self.capacity, device=self.device)
+             else:
+                 index = self.backend.arange(0, self.length, device=self.device)
+         elif self.backend.is_backendarray(index) and self.backend.dtype_is_boolean(index.dtype):
+             index = self.backend.nonzero(index)[0]
+         assert self.backend.is_backendarray(index) and self.backend.dtype_is_real_integer(index.dtype) and len(index.shape) == 1, "Index must be a 1D array of integers"
+         sorted_indexes_arg = self.backend.argsort(index)
+         sorted_indexes = index[sorted_indexes_arg]
+         diff_sorted = sorted_indexes[1:] - sorted_indexes[:-1]
+         discontinuities = self.backend.nonzero(diff_sorted > 1)[0]
+         if discontinuities.shape[0] == 0:
+             # Continuous
+             start_index = int(sorted_indexes[0])
+             end_index = int(sorted_indexes[-1])
+             self.remove_index_range(start_index, end_index)
+             filename = self._make_filename(start_index, end_index)
+             self.set_to_file(filename, sbu.get_at(
+                 self._batched_single_space,
+                 value,
+                 sorted_indexes_arg
+             ))
+             self._add_file_range(start_index, end_index)
+             return
+         else:
+             assert discontinuities.shape[0] == 1, "Round-robin writes can only handle one discontinuity in the index"
+             assert self.capacity is not None, "Round-robin writes require a fixed capacity"
+             discontinuity_pos = int(discontinuities[0]) + 1 # Position after the gap in sorted order
+
+             # First segment (in sorted order): lower indices [0, ...]
+             first_segment_indexes = sorted_indexes[:discontinuity_pos]
+             first_start_index = int(first_segment_indexes[0])
+             first_end_index = int(first_segment_indexes[-1])
+
+             # Second segment (in sorted order): higher indices [..., capacity-1]
+             second_segment_indexes = sorted_indexes[discontinuity_pos:]
+             second_start_index = int(second_segment_indexes[0])
+             second_end_index = int(second_segment_indexes[-1])
+
+             assert first_start_index == 0 and second_end_index == self.capacity - 1, "In round-robin writes, the first segment must start at 0 and the second segment must end at capacity - 1"
+
+             # Remove old data using wrap-around removal (start > end)
+             self.remove_index_range(second_start_index, first_end_index)
+
+             # Reorder the data: high indices first, then low indices (wrap-around order)
+             # We want: [high indices data, low indices data]
+             wrap_order_arg = self.backend.concat([
+                 sorted_indexes_arg[discontinuity_pos:], # high indices first
+                 sorted_indexes_arg[:discontinuity_pos] # then low indices
+             ], axis=0)
+
+             # Write a single wrap-around file: start_index > end_index
+             # start_index is the first high index, end_index is the last low index
+             filename = self._make_filename(second_start_index, first_end_index)
+             self.set_to_file(filename, sbu.get_at(
+                 self._batched_single_space,
+                 value,
+                 wrap_order_arg
+             ))
+             self._add_file_range(second_start_index, first_end_index)
+
+
+     def clear(self):
+         assert self.is_mutable, "Cannot clear a read-only storage"
+         if self.capacity is None:
+             self.length = 0
+         shutil.rmtree(self._cache_path)
+         os.makedirs(self._cache_path, exist_ok=True)
+         self._file_ranges = []
+
+     def close(self):
+         pass
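
As a quick, standalone illustration of the "{start_step_index}_{end_step_index}.{file_ext}" naming convention and the wrap-around length arithmetic that EpisodeStorageBase relies on, here is a minimal sketch. It deliberately avoids the unienv API; the cache directory, extension, capacity, and index ranges are hypothetical examples.

import os
from typing import Optional

def make_filename(cache_dir: str, start_idx: int, end_idx: int, file_ext: str = "npz") -> str:
    # Same convention as EpisodeStorageBase._make_filename: "{start}_{end}.{ext}"
    return os.path.join(cache_dir, f"{start_idx}_{end_idx}.{file_ext}")

def file_length(start_idx: int, end_idx: int, capacity: Optional[int]) -> int:
    # Same arithmetic as EpisodeStorageBase._get_file_length.
    if start_idx <= end_idx:
        return end_idx - start_idx + 1                # regular (non-wrapping) file
    assert capacity is not None, "wrap-around requires a fixed capacity"
    return (capacity - start_idx) + (end_idx + 1)     # [start, capacity) plus [0, end]

# A regular episode covering steps 10..19 of a growable storage:
assert file_length(10, 19, capacity=None) == 10
assert make_filename("/tmp/cache", 10, 19).endswith("10_19.npz")

# A wrap-around episode in a capacity-100 ring buffer, covering steps 95..99
# and then 0..4; it is stored as a single "95_4.npz" file (start > end signals the wrap):
assert file_length(95, 4, capacity=100) == 10

This start > end encoding is exactly what set() produces for round-robin writes and what remove_index_range() unpacks into the two sub-ranges [start, capacity) and [0, end]. The second new file, unienv_data/storages/_list_storage.py, follows below with the simpler one-file-per-index ListStorageBase.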
unienv_data/storages/_list_storage.py
@@ -0,0 +1,136 @@
+ from importlib import metadata
+ from typing import Generic, TypeVar, Optional, Any, Dict, Tuple, Sequence, Union, List, Iterable, Type
+
+ from unienv_interface.space import Space
+ from unienv_interface.space.space_utils import batch_utils as sbu, flatten_utils as sfu
+ from unienv_interface.backends import ComputeBackend, BArrayType, BDeviceType, BDtypeType, BRNGType
+ from unienv_interface.utils.symbol_util import *
+
+ from unienv_data.base import SpaceStorage, BatchT, IndexableType
+
+ import os
+ import shutil
+ from abc import abstractmethod
+
+ def batched_index_to_list(
+     backend : ComputeBackend[BArrayType, BDeviceType, BDtypeType, BRNGType],
+     batched_index : Union[BArrayType, IndexableType],
+     length : int
+ ) -> List[int]:
+     if isinstance(batched_index, slice):
+         return list(range(*batched_index.indices(length)))
+     elif batched_index is Ellipsis or batched_index is None:
+         return list(range(length))
+     else: # backend.is_backendarray
+         assert backend.is_backendarray(batched_index)
+         assert len(batched_index.shape) == 1
+         if backend.dtype_is_boolean(batched_index.dtype):
+             return [i for i in range(batched_index.shape[0]) if batched_index[i]]
+         elif backend.dtype_is_real_integer(batched_index.dtype):
+             return [batched_index[i] for i in range(batched_index.shape[0])]
+         else:
+             raise ValueError(f"Unsupported index type {type(batched_index)}")
+
+ class ListStorageBase(SpaceStorage[
+     BatchT,
+     BArrayType,
+     BDeviceType,
+     BDtypeType,
+     BRNGType,
+ ]):
+     # ========== Instance Implementations ==========
+     single_file_ext = None
+
+     def __init__(
+         self,
+         single_instance_space: Space[Any, BDeviceType, BDtypeType, BRNGType],
+         file_ext : str,
+         cache_filename : Union[str, os.PathLike],
+         mutable : bool = True,
+         capacity : Optional[int] = None,
+         length : int = 0,
+     ):
+         assert cache_filename is not None, "ListStorage requires a cache filename"
+         super().__init__(single_instance_space)
+         self._batched_single_space = sbu.batch_space(self.single_instance_space, 1)
+         self.file_ext = file_ext
+         self._cache_path = cache_filename
+         self.is_mutable = mutable
+         self.capacity = capacity
+         self.length = length if capacity is None else capacity
+
+     @property
+     def cache_filename(self) -> Union[str, os.PathLike]:
+         return self._cache_path
+
+     @property
+     def is_multiprocessing_safe(self) -> bool:
+         return True
+
+     def extend_length(self, length):
+         assert self.capacity is None, "Cannot extend length of a fixed-capacity storage"
+         self.length += length
+
+     def shrink_length(self, length):
+         assert self.is_mutable, "Cannot shrink length of a read-only storage"
+         assert self.capacity is None, "Cannot shrink length of a fixed-capacity storage"
+         from_len = self.length
+         to_len = max(from_len - length, 0)
+         all_files = os.listdir(self._cache_path)
+         for i in range(to_len, from_len):
+             if f"{i}.{self.file_ext}" in all_files:
+                 os.remove(os.path.join(self._cache_path, f"{i}.{self.file_ext}"))
+         self.length = to_len
+
+     def __len__(self):
+         return self.length if self.capacity is None else self.capacity
+
+     @abstractmethod
+     def get_from_file(self, filename : str) -> BatchT:
+         raise NotImplementedError
+
+     @abstractmethod
+     def set_to_file(self, filename : str, value : BatchT):
+         raise NotImplementedError
+
+     def get_single(self, index : int) -> BatchT:
+         assert 0 <= index < self.length, f"Index {index} out of bounds for storage of length {self.length}"
+         filename = os.path.join(self._cache_path, f"{index}.{self.file_ext}")
+         return self.get_from_file(filename)
+
+     def set_single(self, index : int, value : BArrayType):
+         assert self.is_mutable, "Storage is not mutable"
+         assert 0 <= index < self.length, f"Index {index} out of bounds for storage of length {self.length}"
+         filename = os.path.join(self._cache_path, f"{index}.{self.file_ext}")
+         self.set_to_file(filename, value)
+
+     def get(self, index):
+         if isinstance(index, int):
+             result = self.get_single(index)
+         else:
+             result = sbu.concatenate(
+                 self._batched_single_space,
+                 [
+                     self.get_single(i) for i in batched_index_to_list(self.backend, index, len(self))
+                 ]
+             )
+         return result
+
+     def set(self, index, value):
+         assert self.is_mutable, "Storage is not mutable"
+         if isinstance(index, int):
+             self.set_single(index, value)
+         else:
+             indexes = batched_index_to_list(self.backend, index, len(self))
+             for i, ind in enumerate(indexes):
+                 self.set_single(ind, sbu.get_at(self._batched_single_space, value, i))
+
+     def clear(self):
+         assert self.is_mutable, "Cannot clear a read-only storage"
+         if self.capacity is None:
+             self.length = 0
+         shutil.rmtree(self._cache_path)
+         os.makedirs(self._cache_path, exist_ok=True)
+
+     def close(self):
+         pass
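
ListStorageBase keeps one file per index, named "{index}.{file_ext}", and leaves serialization to the abstract get_from_file/set_to_file pair. A minimal sketch of what a concrete subclass's file I/O could look like, assuming numpy and a dict-of-arrays batch; these helpers are illustrative only and are not the API of the bundled npz_storage.py:

import numpy as np
from typing import Dict

def get_from_file(filename: str) -> Dict[str, np.ndarray]:
    # Load one stored instance (a batch of size 1) from an .npz file.
    with np.load(filename) as data:
        return {key: data[key] for key in data.files}

def set_to_file(filename: str, value: Dict[str, np.ndarray]) -> None:
    # Persist one instance; np.savez appends ".npz" only if the name lacks it.
    np.savez(filename, **value)

Because each index lives in its own file, is_multiprocessing_safe can return True here unconditionally, whereas EpisodeStorageBase is only multiprocessing-safe when immutable, since its set() and remove_index_range() rewrite range-named files in place.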