unienv 0.0.1b5__py3-none-any.whl → 0.0.1b6__py3-none-any.whl

This diff shows the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registry.
Files changed (30)
  1. {unienv-0.0.1b5.dist-info → unienv-0.0.1b6.dist-info}/METADATA +3 -2
  2. {unienv-0.0.1b5.dist-info → unienv-0.0.1b6.dist-info}/RECORD +30 -21
  3. {unienv-0.0.1b5.dist-info → unienv-0.0.1b6.dist-info}/WHEEL +1 -1
  4. unienv_data/base/common.py +25 -10
  5. unienv_data/batches/backend_compat.py +1 -1
  6. unienv_data/batches/combined_batch.py +1 -1
  7. unienv_data/replay_buffer/replay_buffer.py +51 -8
  8. unienv_data/storages/_episode_storage.py +438 -0
  9. unienv_data/storages/_list_storage.py +136 -0
  10. unienv_data/storages/backend_compat.py +268 -0
  11. unienv_data/storages/flattened.py +3 -3
  12. unienv_data/storages/hdf5.py +7 -2
  13. unienv_data/storages/image_storage.py +144 -0
  14. unienv_data/storages/npz_storage.py +135 -0
  15. unienv_data/storages/pytorch.py +16 -9
  16. unienv_data/storages/video_storage.py +297 -0
  17. unienv_data/third_party/tensordict/memmap_tensor.py +1174 -0
  18. unienv_data/transformations/image_compress.py +81 -18
  19. unienv_interface/space/space_utils/batch_utils.py +5 -1
  20. unienv_interface/space/spaces/dict.py +6 -0
  21. unienv_interface/transformations/__init__.py +3 -1
  22. unienv_interface/transformations/batch_and_unbatch.py +42 -4
  23. unienv_interface/transformations/chained_transform.py +9 -8
  24. unienv_interface/transformations/crop.py +69 -0
  25. unienv_interface/transformations/dict_transform.py +8 -2
  26. unienv_interface/transformations/identity.py +16 -0
  27. unienv_interface/transformations/rescale.py +24 -5
  28. unienv_interface/wrapper/backend_compat.py +1 -1
  29. {unienv-0.0.1b5.dist-info → unienv-0.0.1b6.dist-info}/licenses/LICENSE +0 -0
  30. {unienv-0.0.1b5.dist-info → unienv-0.0.1b6.dist-info}/top_level.txt +0 -0
unienv_data/storages/backend_compat.py (new file)
@@ -0,0 +1,268 @@
+ from typing import Dict, Any, Optional, Tuple, Union, Generic, SupportsFloat, Type, Sequence, Mapping, TypeVar
+ import numpy as np
+ import copy
+
+ from unienv_interface.space import Space, DictSpace
+ from unienv_interface.space.space_utils import batch_utils as sbu, flatten_utils as sfu
+ from unienv_interface.backends import ComputeBackend, BArrayType, BDeviceType, BDtypeType, BRNGType
+ from unienv_interface.backends.serialization import serialize_backend, deserialize_backend
+ from unienv_interface.utils.symbol_util import *
+
+ from unienv_data.base import SpaceStorage, BatchT
+
+ import os
+ import json
+
+ WrapperBatchT = TypeVar("WrapperBatchT")
+ WrapperBArrayT = TypeVar("WrapperBArrayT")
+ WrapperBDeviceT = TypeVar("WrapperBDeviceT")
+ WrapperBDtypeT = TypeVar("WrapperBDtypeT")
+ WrapperBRngT = TypeVar("WrapperBRngT")
+
+ def data_to(
+     data : Any,
+     source_backend : Optional[ComputeBackend[BArrayType, BDeviceType, BDtypeType, BRNGType]] = None,
+     target_backend : Optional[ComputeBackend[WrapperBArrayT, WrapperBDeviceT, WrapperBDtypeT, WrapperBRngT]] = None,
+     target_device : Optional[WrapperBDeviceT] = None,
+ ):
+     # Check source_backend for None before dereferencing it (it defaults to None).
+     if source_backend is not None and source_backend.is_backendarray(data):
+         if target_backend is not None and target_backend != source_backend:
+             data = target_backend.from_other_backend(
+                 source_backend,
+                 data
+             )
+         if target_device is not None:
+             # After a conversion the array belongs to the target backend, so
+             # prefer it for the device move.
+             data = (target_backend or source_backend).to_device(
+                 data,
+                 target_device
+             )
+     elif isinstance(data, Mapping):
+         data = {
+             key: data_to(value, source_backend, target_backend, target_device)
+             for key, value in data.items()
+         }
+     elif isinstance(data, Sequence) and not isinstance(data, (str, bytes)):
+         converted = [
+             data_to(value, source_backend, target_backend, target_device)
+             for value in data
+         ]
+         try:
+             data = type(data)(converted)  # Preserve the type of the original sequence
+         except TypeError:
+             data = converted
+     return data
+
+ class ToBackendOrDeviceStorage(
+     SpaceStorage[
+         WrapperBatchT, WrapperBArrayT, WrapperBDeviceT, WrapperBDtypeT, WrapperBRngT
+     ],
+     Generic[
+         WrapperBatchT, WrapperBArrayT, WrapperBDeviceT, WrapperBDtypeT, WrapperBRngT,
+         BatchT, BArrayType, BDeviceType, BDtypeType, BRNGType
+     ]
+ ):
+     # ========== Class Implementations ==========
+     @classmethod
+     def create(
+         cls,
+         single_instance_space: Space[WrapperBatchT, WrapperBDeviceT, WrapperBDtypeT, WrapperBRngT],
+         inner_storage_cls : Type[SpaceStorage[BatchT, BArrayType, BDeviceType, BDtypeType, BRNGType]],
+         *args,
+         capacity : Optional[int] = None,
+         cache_path : Optional[str] = None,
+         multiprocessing : bool = False,
+         backend : Optional[ComputeBackend[BArrayType, BDeviceType, BDtypeType, BRNGType]] = None,
+         device : Optional[BDeviceType] = None,
+         inner_storage_kwargs : Optional[Dict[str, Any]] = None,  # Optional instead of a mutable {} default
+         **kwargs
+     ) -> "ToBackendOrDeviceStorage[WrapperBatchT, WrapperBArrayT, WrapperBDeviceT, WrapperBDtypeT, WrapperBRngT, BatchT, BArrayType, BDeviceType, BDtypeType, BRNGType]":
+         inner_storage_path = "inner_storage" + (inner_storage_cls.single_file_ext or "")
+
+         if cache_path is not None:
+             os.makedirs(cache_path, exist_ok=True)
+
+         _inner_storage_kwargs = kwargs.copy()
+         _inner_storage_kwargs.update(inner_storage_kwargs or {})
+         inner_storage = inner_storage_cls.create(
+             single_instance_space.to(
+                 backend=backend,
+                 device=device
+             ),
+             *args,
+             cache_path=None if cache_path is None else os.path.join(cache_path, inner_storage_path),
+             capacity=capacity,
+             multiprocessing=multiprocessing,
+             **_inner_storage_kwargs
+         )
+         # If neither the backend nor the device actually changes, no wrapper is needed.
+         if (backend is None or backend == single_instance_space.backend) and (device is None or device == single_instance_space.device):
+             return inner_storage
+
+         return ToBackendOrDeviceStorage(
+             single_instance_space,
+             inner_storage,
+             inner_storage_path,
+             cache_filename=cache_path,
+         )
+
+     @classmethod
+     def load_from(
+         cls,
+         path : Union[str, os.PathLike],
+         single_instance_space : Space[Any, BDeviceType, BDtypeType, BRNGType],
+         *,
+         capacity : Optional[int] = None,
+         read_only : bool = True,
+         multiprocessing : bool = False,
+         **kwargs
+     ) -> Union[
+         "ToBackendOrDeviceStorage[WrapperBatchT, WrapperBArrayT, WrapperBDeviceT, WrapperBDtypeT, WrapperBRngT, BatchT, BArrayType, BDeviceType, BDtypeType, BRNGType]",
+         SpaceStorage[WrapperBatchT, WrapperBArrayT, WrapperBDeviceT, WrapperBDtypeT, WrapperBRngT]
+     ]:
+         metadata_path = os.path.join(path, "backend_metadata.json")
+         assert os.path.exists(metadata_path), f"Metadata file {metadata_path} does not exist"
+         with open(metadata_path, "r") as f:
+             metadata = json.load(f)
+         assert metadata["storage_type"] == cls.__name__, \
+             f"Expected storage type {cls.__name__}, but found {metadata['storage_type']}"
+         inner_storage_cls : Type[SpaceStorage] = get_class_from_full_name(metadata["inner_storage_type"])
+         inner_storage_path = metadata["inner_storage_path"]
+         inner_backend = deserialize_backend(metadata["inner_backend"])
+         inner_device = inner_backend.deserialize_device(metadata["inner_device"])
+
+         if inner_backend != single_instance_space.backend or inner_device != single_instance_space.device:
+             inner_space = single_instance_space.to(
+                 backend=inner_backend if inner_backend != single_instance_space.backend else None,
+                 device=inner_device if inner_device != single_instance_space.device else None
+             )
+         else:
+             inner_space = single_instance_space
+
+         inner_storage = inner_storage_cls.load_from(
+             os.path.join(path, inner_storage_path),
+             inner_space,
+             capacity=capacity,
+             read_only=read_only,
+             multiprocessing=multiprocessing,
+             **kwargs
+         )
+
+         if inner_backend == single_instance_space.backend and inner_device == single_instance_space.device:
+             return inner_storage
+         else:
+             return ToBackendOrDeviceStorage(
+                 single_instance_space,
+                 inner_storage,
+                 inner_storage_path,
+                 cache_filename=path,
+             )
+
+     # ======== Instance Implementations ==========
+     single_file_ext = None
+
+     def __init__(
+         self,
+         single_instance_space : Space[WrapperBatchT, WrapperBDeviceT, WrapperBDtypeT, WrapperBRngT],
+         inner_storage : SpaceStorage[
+             BatchT, BArrayType, BDeviceType, BDtypeType, BRNGType
+         ],
+         inner_storage_path : Union[str, os.PathLike],
+         cache_filename : Optional[Union[str, os.PathLike]] = None,
+     ):
+         super().__init__(single_instance_space)
+         inner_backend = None if inner_storage.backend == single_instance_space.backend else inner_storage.backend
+         inner_device = None if inner_storage.device == single_instance_space.device else inner_storage.device
+         current_backend = None if inner_storage.backend == single_instance_space.backend else single_instance_space.backend
+         current_device = None if inner_storage.device == single_instance_space.device else single_instance_space.device
+
+         self._batched_instance_space = sbu.batch_space(
+             single_instance_space,
+             1
+         )
+         self._batched_inner_space = self._batched_instance_space.to(
+             backend=inner_backend,
+             device=inner_device
+         )
+
+         self.inner_storage = inner_storage
+         self.inner_storage_path = inner_storage_path
+         self._cache_filename = cache_filename
+
+         self.inner_backend = inner_backend
+         self.inner_device = inner_device
+         self.current_backend = current_backend
+         self.current_device = current_device
+
+     @property
+     def cache_filename(self) -> Optional[Union[str, os.PathLike]]:
+         return self._cache_filename
+
+     @property
+     def is_mutable(self) -> bool:
+         return self.inner_storage.is_mutable
+
+     @property
+     def is_multiprocessing_safe(self) -> bool:
+         return self.inner_storage.is_multiprocessing_safe
+
+     @property
+     def capacity(self) -> Optional[int]:
+         return self.inner_storage.capacity
+
+     def extend_length(self, length):
+         self.inner_storage.extend_length(length)
+
+     def shrink_length(self, length):
+         self.inner_storage.shrink_length(length)
+
+     def __len__(self):
+         return len(self.inner_storage)
+
+     def get(self, index):
+         if self.backend.is_backendarray(index):
+             index = data_to(
+                 index,
+                 source_backend=self.backend,
+                 target_backend=self.inner_backend,
+                 target_device=self.inner_device
+             )
+         target_data = self.inner_storage.get(index)
+         return data_to(
+             target_data,
+             source_backend=self.inner_storage.backend,
+             target_backend=self.current_backend,
+             target_device=self.current_device
+         )
+
+     def set(self, index, value):
+         target_value = data_to(
+             value,
+             source_backend=self.backend,
+             target_backend=self.inner_backend,
+             target_device=self.inner_device
+         )
+         if self.backend.is_backendarray(index):
+             index = data_to(
+                 index,
+                 source_backend=self.backend,
+                 target_backend=self.inner_backend,
+                 target_device=self.inner_device
+             )
+         self.inner_storage.set(index, target_value)
+
+     def clear(self):
+         self.inner_storage.clear()
+
+     def dumps(self, path):
+         metadata = {
+             "storage_type": __class__.__name__,
+             "inner_storage_type": get_full_class_name(type(self.inner_storage)),
+             "inner_storage_path": self.inner_storage_path,
+             "inner_backend": serialize_backend(self.inner_storage.backend),
+             "inner_device": self.inner_storage.backend.serialize_device(self.inner_storage.device),
+         }
+         self.inner_storage.dumps(os.path.join(path, self.inner_storage_path))
+         with open(os.path.join(path, "backend_metadata.json"), "w") as f:
+             json.dump(metadata, f)
+
+     def close(self):
+         self.inner_storage.close()
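Taken together: create builds the inner storage in the requested backend/device and only wraps it when a conversion is actually needed, while get and set route indices and values through data_to in both directions. A minimal usage sketch, assuming a hypothetical torch_space (a PyTorch-backed space) and numpy_backend; the names and exact Space/backend setup are illustrative, not taken from the package docs:

    # Hypothetical sketch: persist batches via a numpy-backed NPZStorage
    # (also new in this release), but read/write them as torch tensors.
    storage = ToBackendOrDeviceStorage.create(
        torch_space,                  # wrapper-side space (torch backend)
        NPZStorage,                   # inner storage class
        capacity=1000,
        cache_path="npz_cache/",
        backend=numpy_backend,        # inner backend differs -> wrapper returned
    )
    storage.set(0, torch_batch)       # converted torch -> numpy before writing
    sample = storage.get(0)           # converted numpy -> torch on the way out
    storage.dumps("npz_cache/")       # writes backend_metadata.json + inner dump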
unienv_data/storages/flattened.py
@@ -23,7 +23,7 @@ class FlattenedStorage(SpaceStorage[
     @classmethod
     def create(
         cls,
-        single_instance_space: Space[Any, BDeviceType, BDtypeType, BRNGType],
+        single_instance_space: Space[BatchT, BDeviceType, BDtypeType, BRNGType],
         inner_storage_cls : Type[SpaceStorage[BArrayType, BArrayType, BDeviceType, BDtypeType, BRNGType]],
         *args,
         capacity : Optional[int] = None,
@@ -59,7 +59,7 @@ class FlattenedStorage(SpaceStorage[
     def load_from(
         cls,
         path : Union[str, os.PathLike],
-        single_instance_space : Space[Any, BDeviceType, BDtypeType, BRNGType],
+        single_instance_space : Space[BatchT, BDeviceType, BDtypeType, BRNGType],
         *,
         capacity : Optional[int] = None,
         read_only : bool = True,
@@ -95,7 +95,7 @@ class FlattenedStorage(SpaceStorage[

     def __init__(
         self,
-        single_instance_space: Space[Any, BDeviceType, BDtypeType, BRNGType],
+        single_instance_space: Space[BatchT, BDeviceType, BDtypeType, BRNGType],
         inner_storage : SpaceStorage[
             BArrayType,
             BArrayType,
unienv_data/storages/hdf5.py
@@ -431,6 +431,11 @@ class HDF5Storage(SpaceStorage[
         # Convert to numpy array if it's a scalar
         if isinstance(result, (int, float)):
             result = np.array(result)
+        if isinstance(single_instance_space, TextSpace):
+            # h5py returns stored strings as bytes; decode scalars and arrays back to str.
+            if isinstance(result, bytes):
+                result = result.decode('utf-8')
+            elif isinstance(result, np.ndarray):
+                result = np.array([r.decode('utf-8') if isinstance(r, bytes) else r for r in result], dtype=object)
         return result

     @staticmethod
@@ -513,8 +518,8 @@ class HDF5Storage(SpaceStorage[
         reduce_io : bool = True,
         **kwargs
     ) -> "HDF5Storage":
-        assert not multiprocessing, \
-            "HDF5Storage does not support multiprocessing safe loading. Please load the storage in the main process and then share it with child processes."
+        assert read_only or not multiprocessing, \
+            "HDF5Storage does not support multiprocessing safe loading when not `read_only`. Please load the storage in the main process and then share it with child processes."
         assert os.path.exists(path), \
             f"Path {path} does not exist"

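The relaxed assertion means a read-only HDF5 store can now be opened directly in worker processes. A minimal sketch, assuming a hypothetical obs_space matching the stored data (keyword names come from the load_from signature above):

    # Hypothetical sketch: read-only loads may now pass multiprocessing=True.
    storage = HDF5Storage.load_from(
        "dataset.h5",
        obs_space,
        read_only=True,        # writable + multiprocessing still asserts
        multiprocessing=True,
    )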
unienv_data/storages/image_storage.py (new file)
@@ -0,0 +1,144 @@
+ from typing import Generic, TypeVar, Optional, Any, Dict, Tuple, Sequence, Union, List, Iterable, Type
+
+ from unienv_interface.space import Space, BoxSpace
+ from unienv_interface.space.space_utils import batch_utils as sbu, flatten_utils as sfu
+ from unienv_interface.backends import ComputeBackend, BArrayType, BDeviceType, BDtypeType, BRNGType
+ from unienv_interface.utils.symbol_util import *
+
+ from unienv_data.base import SpaceStorage
+ from ._list_storage import ListStorageBase
+
+ import numpy as np
+ import os
+ import json
+ import shutil
+
+ from PIL import Image
+
+ class ImageStorage(ListStorageBase[
+     BArrayType,
+     BArrayType,
+     BDeviceType,
+     BDtypeType,
+     BRNGType,
+ ]):
+     # ========== Class Attributes ==========
+     @classmethod
+     def create(
+         cls,
+         single_instance_space: BoxSpace[BArrayType, BDeviceType, BDtypeType, BRNGType],
+         *args,
+         format : str = "JPEG",
+         quality : int = 75,
+         capacity : Optional[int] = None,
+         cache_path : Optional[str] = None,
+         multiprocessing : bool = False,
+         **kwargs
+     ) -> "ImageStorage[BArrayType, BDeviceType, BDtypeType, BRNGType]":
+         if cache_path is None:
+             raise ValueError("cache_path must be provided for ImageStorage.create")
+         assert not os.path.exists(cache_path), f"Cache path {cache_path} already exists"
+         os.makedirs(cache_path, exist_ok=True)
+         return ImageStorage(
+             single_instance_space,
+             format=format,
+             quality=quality,
+             cache_filename=cache_path,
+             capacity=capacity,
+         )
+
+     @classmethod
+     def load_from(
+         cls,
+         path : Union[str, os.PathLike],
+         single_instance_space : BoxSpace[BArrayType, BDeviceType, BDtypeType, BRNGType],
+         *,
+         capacity : Optional[int] = None,
+         read_only : bool = True,
+         multiprocessing : bool = False,
+         **kwargs
+     ) -> "ImageStorage[BArrayType, BDeviceType, BDtypeType, BRNGType]":
+         metadata_path = os.path.join(path, "image_metadata.json")
+         assert os.path.exists(metadata_path), f"Metadata file {metadata_path} does not exist"
+         with open(metadata_path, "r") as f:
+             metadata = json.load(f)
+         assert metadata["storage_type"] == cls.__name__, \
+             f"Expected storage type {cls.__name__}, but found {metadata['storage_type']}"
+
+         format = metadata["format"]
+         quality = int(metadata["quality"])
+         if "capacity" in metadata:
+             capacity = None if metadata['capacity'] is None else int(metadata["capacity"])
+         length = None if capacity is None else metadata["length"]
+
+         return ImageStorage(
+             single_instance_space,
+             format=format,
+             quality=quality,
+             cache_filename=path,
+             mutable=not read_only,
+             capacity=capacity,
+             length=length,
+         )
+
+     # ========== Instance Implementations ==========
+     single_file_ext = None
+
+     def __init__(
+         self,
+         single_instance_space: BoxSpace[BArrayType, BDeviceType, BDtypeType, BRNGType],
+         format : str,
+         quality : int,
+         cache_filename : Union[str, os.PathLike],
+         mutable : bool = True,
+         capacity : Optional[int] = None,
+         length : int = 0,
+     ):
+         super().__init__(
+             single_instance_space,
+             file_ext=format.lower(),
+             cache_filename=cache_filename,
+             mutable=mutable,
+             capacity=capacity,
+             length=length,
+         )
+         self.format = format
+         self.quality = quality
+
+     def get_from_file(self, filename : str) -> BArrayType:
+         if not os.path.exists(filename):
+             # Missing entries decode to zeros matching the space.
+             return self.backend.zeros(
+                 self.single_instance_space.shape,
+                 dtype=self.single_instance_space.dtype,
+                 device=self.single_instance_space.device,
+             )
+         with Image.open(filename) as img:
+             rgb_img = img.convert("RGB")
+             # Ensure a compact uint8 array from PIL.
+             np_image = np.asarray(rgb_img, dtype=np.uint8)
+         return self.backend.from_numpy(np_image, dtype=self.single_instance_space.dtype, device=self.single_instance_space.device)
+
+     def set_to_file(self, filename : str, value : BArrayType):
+         np_value = self.backend.to_numpy(value)
+         img = Image.fromarray(np_value, mode="RGB")
+         try:
+             img.save(filename, format=self.format, quality=self.quality)
+         finally:
+             img.close()
+
+     def dumps(self, path):
+         assert os.path.samefile(path, self.cache_filename), \
+             f"Dump path {path} does not match cache filename {self.cache_filename}"
+         metadata = {
+             "storage_type": __class__.__name__,
+             "format": self.format,
+             "quality": self.quality,
+             "capacity": self.capacity,
+             "length": self.length,
+         }
+         with open(os.path.join(path, "image_metadata.json"), "w") as f:
+             json.dump(metadata, f)
+
+     def close(self):
+         pass
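ImageStorage keeps each entry as a separate encoded image under cache_path (per-index file naming is handled by ListStorageBase), decoding back through PIL on read. A minimal sketch, assuming a hypothetical rgb_space (a uint8 HxWx3 BoxSpace, which the RGB round-trip here requires):

    # Hypothetical sketch: JPEG-compress frames to disk, one file per index.
    storage = ImageStorage.create(
        rgb_space,
        format="JPEG",
        quality=75,
        capacity=10_000,
        cache_path="frames/",     # must not already exist
    )
    storage.set(0, frame)         # encodes the frame to a .jpeg under frames/
    restored = storage.get(0)     # decodes it back (lossy for JPEG)
    storage.dumps("frames/")      # writes image_metadata.json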
unienv_data/storages/npz_storage.py (new file)
@@ -0,0 +1,135 @@
+ from typing import Generic, TypeVar, Optional, Any, Dict, Tuple, Sequence, Union, List, Iterable, Type
+
+ from unienv_interface.space import Space, BoxSpace, BinarySpace
+ from unienv_interface.space.space_utils import batch_utils as sbu, flatten_utils as sfu
+ from unienv_interface.backends import ComputeBackend, BArrayType, BDeviceType, BDtypeType, BRNGType
+ from unienv_interface.utils.symbol_util import *
+
+ from unienv_data.base import SpaceStorage
+ from ._list_storage import ListStorageBase
+
+ import numpy as np
+ import os
+ import json
+ import shutil
+
+ class NPZStorage(ListStorageBase[
+     BArrayType,
+     BArrayType,
+     BDeviceType,
+     BDtypeType,
+     BRNGType,
+ ]):
+     # ========== Class Attributes ==========
+     @classmethod
+     def create(
+         cls,
+         single_instance_space: BoxSpace[BArrayType, BDeviceType, BDtypeType, BRNGType],
+         *args,
+         compressed : bool = True,
+         capacity : Optional[int] = None,
+         cache_path : Optional[str] = None,
+         multiprocessing : bool = False,
+         **kwargs
+     ) -> "NPZStorage[BArrayType, BDeviceType, BDtypeType, BRNGType]":
+         # Fail clearly when no cache path is given (os.path.exists(None) raises TypeError).
+         if cache_path is None:
+             raise ValueError("cache_path must be provided for NPZStorage.create")
+         assert not os.path.exists(cache_path), f"Cache path {cache_path} already exists"
+         os.makedirs(cache_path, exist_ok=True)
+         return NPZStorage(
+             single_instance_space,
+             compressed=compressed,
+             cache_filename=cache_path,
+             capacity=capacity,
+         )
+
+     @classmethod
+     def load_from(
+         cls,
+         path : Union[str, os.PathLike],
+         single_instance_space : BoxSpace[BArrayType, BDeviceType, BDtypeType, BRNGType],
+         *,
+         capacity : Optional[int] = None,
+         read_only : bool = True,
+         multiprocessing : bool = False,
+         **kwargs
+     ) -> "NPZStorage[BArrayType, BDeviceType, BDtypeType, BRNGType]":
+         metadata_path = os.path.join(path, "npz_metadata.json")
+         assert os.path.exists(metadata_path), f"Metadata file {metadata_path} does not exist"
+         with open(metadata_path, "r") as f:
+             metadata = json.load(f)
+         assert metadata["storage_type"] == cls.__name__, \
+             f"Expected storage type {cls.__name__}, but found {metadata['storage_type']}"
+
+         compressed = metadata.get("compressed", True)
+         if "capacity" in metadata:
+             capacity = None if metadata['capacity'] is None else int(metadata["capacity"])
+         length = None if capacity is None else metadata["length"]
+
+         return NPZStorage(
+             single_instance_space,
+             compressed=compressed,
+             cache_filename=path,
+             mutable=not read_only,
+             capacity=capacity,
+             length=length,
+         )
+
+     # ========== Instance Implementations ==========
+     def __init__(
+         self,
+         single_instance_space: BoxSpace[BArrayType, BDeviceType, BDtypeType, BRNGType],
+         compressed : bool,
+         cache_filename : Union[str, os.PathLike],
+         mutable : bool = True,
+         capacity : Optional[int] = None,
+         length : int = 0,
+     ):
+         assert isinstance(single_instance_space, BoxSpace) or isinstance(single_instance_space, BinarySpace), "single_instance_space must be a BoxSpace or BinarySpace"
+         super().__init__(
+             single_instance_space,
+             file_ext="npz",
+             cache_filename=cache_filename,
+             mutable=mutable,
+             capacity=capacity,
+             length=length,
+         )
+         self.compressed = compressed
+
+     def get_from_file(self, filename : str) -> BArrayType:
+         if not os.path.exists(filename):
+             # Missing entries decode to zeros matching the space.
+             return self.backend.zeros(
+                 self.single_instance_space.shape,
+                 dtype=self.single_instance_space.dtype,
+                 device=self.single_instance_space.device,
+             )
+
+         dat = np.load(filename, allow_pickle=False)
+         return self.backend.from_numpy(
+             dat['data'],
+             dtype=self.single_instance_space.dtype,
+             device=self.single_instance_space.device
+         )
+
+     def set_to_file(self, filename : str, value : BArrayType):
+         np_value = self.backend.to_numpy(value)
+         if self.compressed:
+             np.savez_compressed(filename, data=np_value)
+         else:
+             np.savez(filename, data=np_value)
+
+     def dumps(self, path):
+         assert os.path.samefile(path, self.cache_filename), \
+             f"Dump path {path} does not match cache filename {self.cache_filename}"
+         metadata = {
+             "storage_type": __class__.__name__,
+             "compressed": self.compressed,
+             "capacity": self.capacity,
+             "length": self.length,
+         }
+         with open(os.path.join(path, "npz_metadata.json"), "w") as f:
+             json.dump(metadata, f)
+
+     def close(self):
+         pass
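NPZStorage is the lossless counterpart to ImageStorage: each entry round-trips through np.savez/np.load with an exact dtype match instead of image encoding. A minimal sketch, assuming a hypothetical box_space (any BoxSpace or BinarySpace):

    # Hypothetical sketch: one .npz file per stored entry, under key "data".
    storage = NPZStorage.create(
        box_space,
        compressed=True,          # np.savez_compressed vs np.savez
        capacity=1000,
        cache_path="arrays/",     # must not already exist
    )
    storage.set(3, value)         # writes a per-index .npz under arrays/
    exact = storage.get(3)        # lossless; zeros if the file is missing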
unienv_data/storages/pytorch.py
@@ -1,10 +1,10 @@
 import os
 import torch
-from unienv_interface.space import Space, BoxSpace
+from unienv_interface.space import Space, BoxSpace, BinarySpace
 from unienv_interface.backends import ComputeBackend
 from unienv_interface.backends.pytorch import PyTorchComputeBackend, PyTorchArrayType, PyTorchDeviceType, PyTorchDtypeType, PyTorchRNGType
 from unienv_data.base import SpaceStorage
-from tensordict.memmap import MemoryMappedTensor
+from unienv_data.third_party.tensordict.memmap_tensor import MemoryMappedTensor
 from typing import Generic, TypeVar, Generic, Optional, Any, Dict, Tuple, Sequence, Union, List, Type

 class PytorchTensorStorage(SpaceStorage[
@@ -20,15 +20,15 @@ class PytorchTensorStorage(SpaceStorage[
         single_instance_space : BoxSpace[PyTorchArrayType, PyTorchDeviceType, PyTorchDtypeType, PyTorchRNGType],
         *,
         capacity : Optional[int],
-        is_memmap : bool = False,
+        is_memmap : bool = True,
         cache_path : Optional[str] = None,
         multiprocessing : bool = False,
         memmap_existok : bool = True,
     ) -> "PytorchTensorStorage":
         assert single_instance_space.backend is PyTorchComputeBackend, \
             f"Single instance space must be of type PyTorchComputeBackend, got {single_instance_space.backend}"
-        assert isinstance(single_instance_space, BoxSpace), \
-            f"Single instance space must be a BoxSpace, got {type(single_instance_space)}"
+        assert isinstance(single_instance_space, BoxSpace) or isinstance(single_instance_space, BinarySpace), \
+            f"Single instance space must be a BoxSpace or BinarySpace, got {type(single_instance_space)}"
         assert capacity is not None, "Capacity must be specified when creating a new tensor"

         target_shape = (capacity, *single_instance_space.shape)
@@ -63,9 +63,12 @@ class PytorchTensorStorage(SpaceStorage[
     def load_from(
         cls,
         path: Union[str, os.PathLike],
-        single_instance_space: BoxSpace[PyTorchArrayType, PyTorchDeviceType, PyTorchDtypeType, PyTorchRNGType],
+        single_instance_space: Union[
+            BoxSpace[PyTorchArrayType, PyTorchDeviceType, PyTorchDtypeType, PyTorchRNGType],
+            BinarySpace[PyTorchArrayType, PyTorchDeviceType, PyTorchDtypeType, PyTorchRNGType],
+        ],
         *,
-        is_memmap : bool = False,
+        is_memmap : bool = True,
         capacity : Optional[int] = None,
         read_only : bool = True,
         multiprocessing : bool = False,
@@ -81,7 +84,8 @@ class PytorchTensorStorage(SpaceStorage[
         target_data = MemoryMappedTensor.from_filename(
             path,
             dtype=single_instance_space.dtype,
-            shape=target_shape
+            shape=target_shape,
+            readonly=read_only,
         )

         if is_memmap:
@@ -111,7 +115,10 @@ class PytorchTensorStorage(SpaceStorage[

     def __init__(
         self,
-        single_instance_space : BoxSpace[PyTorchArrayType, PyTorchDeviceType, PyTorchDtypeType, PyTorchRNGType],
+        single_instance_space : Union[
+            BoxSpace[PyTorchArrayType, PyTorchDeviceType, PyTorchDtypeType, PyTorchRNGType],
+            BinarySpace[PyTorchArrayType, PyTorchDeviceType, PyTorchDtypeType, PyTorchRNGType],
+        ],
         data : Union[torch.Tensor, MemoryMappedTensor],
         mutable : bool = True,
     ):
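With is_memmap now defaulting to True and read_only forwarded to the vendored MemoryMappedTensor, a tensor buffer written by create can be reopened without copying it into RAM. A minimal sketch, assuming a hypothetical torch_box_space (a PyTorch-backed BoxSpace; BinarySpace is now accepted too) and that cache_path names the memmap file:

    # Hypothetical sketch: memory-mapped tensor storage on disk.
    storage = PytorchTensorStorage.create(
        torch_box_space,
        capacity=100_000,            # required; sets the leading dimension
        cache_path="buffer.memmap",
        memmap_existok=True,
    )
    ...
    reopened = PytorchTensorStorage.load_from(
        "buffer.memmap",
        torch_box_space,
        capacity=100_000,
        read_only=True,              # forwarded to MemoryMappedTensor.from_filename
    )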