unienv-0.0.1b5-py3-none-any.whl → unienv-0.0.1b7-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {unienv-0.0.1b5.dist-info → unienv-0.0.1b7.dist-info}/METADATA +3 -2
- {unienv-0.0.1b5.dist-info → unienv-0.0.1b7.dist-info}/RECORD +30 -21
- {unienv-0.0.1b5.dist-info → unienv-0.0.1b7.dist-info}/WHEEL +1 -1
- unienv_data/base/common.py +25 -10
- unienv_data/batches/backend_compat.py +1 -1
- unienv_data/batches/combined_batch.py +1 -1
- unienv_data/replay_buffer/replay_buffer.py +51 -8
- unienv_data/storages/_episode_storage.py +438 -0
- unienv_data/storages/_list_storage.py +136 -0
- unienv_data/storages/backend_compat.py +268 -0
- unienv_data/storages/flattened.py +3 -3
- unienv_data/storages/hdf5.py +7 -2
- unienv_data/storages/image_storage.py +144 -0
- unienv_data/storages/npz_storage.py +135 -0
- unienv_data/storages/pytorch.py +16 -9
- unienv_data/storages/video_storage.py +297 -0
- unienv_data/third_party/tensordict/memmap_tensor.py +1174 -0
- unienv_data/transformations/image_compress.py +81 -18
- unienv_interface/space/space_utils/batch_utils.py +5 -1
- unienv_interface/space/spaces/dict.py +6 -0
- unienv_interface/transformations/__init__.py +3 -1
- unienv_interface/transformations/batch_and_unbatch.py +43 -4
- unienv_interface/transformations/chained_transform.py +9 -8
- unienv_interface/transformations/crop.py +69 -0
- unienv_interface/transformations/dict_transform.py +8 -2
- unienv_interface/transformations/identity.py +16 -0
- unienv_interface/transformations/rescale.py +24 -5
- unienv_interface/wrapper/backend_compat.py +1 -1
- {unienv-0.0.1b5.dist-info → unienv-0.0.1b7.dist-info}/licenses/LICENSE +0 -0
- {unienv-0.0.1b5.dist-info → unienv-0.0.1b7.dist-info}/top_level.txt +0 -0
unienv_data/storages/backend_compat.py
ADDED

@@ -0,0 +1,268 @@
+from typing import Dict, Any, Optional, Tuple, Union, Generic, SupportsFloat, Type, Sequence, Mapping, TypeVar
+import numpy as np
+import copy
+
+from unienv_interface.space import Space, DictSpace
+from unienv_interface.space.space_utils import batch_utils as sbu, flatten_utils as sfu
+from unienv_interface.backends import ComputeBackend, BArrayType, BDeviceType, BDtypeType, BRNGType
+from unienv_interface.backends.serialization import serialize_backend, deserialize_backend
+from unienv_interface.utils.symbol_util import *
+
+from unienv_data.base import SpaceStorage, BatchT
+
+import os
+import json
+
+WrapperBatchT = TypeVar("WrapperBatchT")
+WrapperBArrayT = TypeVar("WrapperBArrayT")
+WrapperBDeviceT = TypeVar("WrapperBDeviceT")
+WrapperBDtypeT = TypeVar("WrapperBDtypeT")
+WrapperBRngT = TypeVar("WrapperBRngT")
+
+def data_to(
+    data : Any,
+    source_backend : Optional[ComputeBackend[BArrayType, BDeviceType, BDtypeType, BRNGType]] = None,
+    target_backend : Optional[ComputeBackend[WrapperBArrayT, WrapperBDeviceT, WrapperBDtypeT, WrapperBRngT]] = None,
+    target_device : Optional[WrapperBDeviceT] = None,
+):
+    if source_backend.is_backendarray(data):
+        if source_backend is not None and target_backend is not None and target_backend != source_backend:
+            data = target_backend.from_other_backend(
+                source_backend,
+                data
+            )
+        if target_device is not None:
+            data = (source_backend or target_backend).to_device(
+                data,
+                target_device
+            )
+    elif isinstance(data, Mapping):
+        data = {
+            key: data_to(value, source_backend, target_backend, target_device)
+            for key, value in data.items()
+        }
+    elif isinstance(data, Sequence) and not isinstance(data, (str, bytes)):
+        data = [
+            data_to(value, source_backend, target_backend, target_device)
+            for value in data
+        ]
+        try:
+            data = type(data)(data) # Preserve the type of the original sequence
+        except:
+            pass
+    return data
+
+class ToBackendOrDeviceStorage(
+    SpaceStorage[
+        WrapperBatchT, WrapperBArrayT, WrapperBDeviceT, WrapperBDtypeT, WrapperBRngT
+    ],
+    Generic[
+        WrapperBatchT, WrapperBArrayT, WrapperBDeviceT, WrapperBDtypeT, WrapperBRngT,
+        BatchT, BArrayType, BDeviceType, BDtypeType, BRNGType
+    ]
+):
+    # ========== Class Implementations ==========
+    @classmethod
+    def create(
+        cls,
+        single_instance_space: Space[WrapperBatchT, WrapperBDeviceT, WrapperBDtypeT, WrapperBRngT],
+        inner_storage_cls : Type[SpaceStorage[BatchT, BArrayType, BDeviceType, BDtypeType, BRNGType]],
+        *args,
+        capacity : Optional[int] = None,
+        cache_path : Optional[str] = None,
+        multiprocessing : bool = False,
+        backend : Optional[ComputeBackend[BArrayType, BDeviceType, BDtypeType, BRNGType]] = None,
+        device : Optional[BDeviceType] = None,
+        inner_storage_kwargs : Dict[str, Any] = {},
+        **kwargs
+    ) -> "ToBackendOrDeviceStorage[WrapperBatchT, WrapperBArrayT, WrapperBDeviceT, WrapperBDtypeT, WrapperBRngT, BatchT, BArrayType, BDeviceType, BDtypeType, BRNGType]":
+        inner_storage_path = "inner_storage" + (inner_storage_cls.single_file_ext or "")
+
+        if cache_path is not None:
+            os.makedirs(cache_path, exist_ok=True)
+
+        _inner_storage_kwargs = kwargs.copy()
+        _inner_storage_kwargs.update(inner_storage_kwargs)
+        inner_storage = inner_storage_cls.create(
+            single_instance_space.to(
+                backend=backend,
+                device=device
+            ),
+            *args,
+            cache_path=None if cache_path is None else os.path.join(cache_path, inner_storage_path),
+            capacity=capacity,
+            multiprocessing=multiprocessing,
+            **_inner_storage_kwargs
+        )
+        if (backend is None or backend == single_instance_space.backend) and (device is None or device == single_instance_space.device):
+            return inner_storage
+
+        return ToBackendOrDeviceStorage(
+            single_instance_space,
+            inner_storage,
+            inner_storage_path,
+            cache_filename=cache_path,
+        )
+
+    @classmethod
+    def load_from(
+        cls,
+        path : Union[str, os.PathLike],
+        single_instance_space : Space[Any, BDeviceType, BDtypeType, BRNGType],
+        *,
+        capacity : Optional[int] = None,
+        read_only : bool = True,
+        multiprocessing : bool = False,
+        **kwargs
+    ) -> Union[
+        "ToBackendOrDeviceStorage[WrapperBatchT, WrapperBArrayT, WrapperBDeviceT, WrapperBDtypeT, WrapperBRngT, BatchT, BArrayType, BDeviceType, BDtypeType, BRNGType]",
+        SpaceStorage[WrapperBatchT, WrapperBArrayT, WrapperBDeviceT, WrapperBDtypeT, WrapperBRngT]
+    ]:
+        metadata_path = os.path.join(path, "backend_metadata.json")
+        assert os.path.exists(metadata_path), f"Metadata file {metadata_path} does not exist"
+        with open(metadata_path, "r") as f:
+            metadata = json.load(f)
+        assert metadata["storage_type"] == cls.__name__, \
+            f"Expected storage type {cls.__name__}, but found {metadata['storage_type']}"
+        inner_storage_cls : Type[SpaceStorage] = get_class_from_full_name(metadata["inner_storage_type"])
+        inner_storage_path = metadata["inner_storage_path"]
+        inner_backend = deserialize_backend(metadata["inner_backend"])
+        inner_device = inner_backend.deserialize_device(metadata["inner_device"])
+
+        if inner_backend != single_instance_space.backend or inner_device != single_instance_space.device:
+            inner_space = single_instance_space.to(
+                backend=inner_backend if inner_backend != single_instance_space.backend else None,
+                device=inner_device if inner_device != single_instance_space.device else None
+            )
+        else:
+            inner_space = single_instance_space
+
+        inner_storage = inner_storage_cls.load_from(
+            os.path.join(path, inner_storage_path),
+            inner_space,
+            capacity=capacity,
+            read_only=read_only,
+            multiprocessing=multiprocessing,
+            **kwargs
+        )
+
+        if inner_backend == single_instance_space.backend and inner_device == single_instance_space.device:
+            return inner_storage
+        else:
+            return ToBackendOrDeviceStorage(
+                single_instance_space,
+                inner_storage,
+                inner_storage_path,
+                cache_filename=path,
+            )
+
+    # ======== Instance Implementations ==========
+    single_file_ext = None
+
+    def __init__(
+        self,
+        single_instance_space : Space[WrapperBatchT, WrapperBDeviceT, WrapperBDtypeT, WrapperBRngT],
+        inner_storage : SpaceStorage[
+            BatchT, BArrayType, BDeviceType, BDtypeType, BRNGType
+        ],
+        inner_storage_path : Union[str, os.PathLike],
+        cache_filename : Optional[Union[str, os.PathLike]] = None,
+    ):
+        super().__init__(single_instance_space)
+        inner_backend = None if inner_storage.backend == single_instance_space.backend else inner_storage.backend
+        inner_device = None if inner_storage.device == single_instance_space.device else inner_storage.device
+        current_backend = None if inner_storage.backend == single_instance_space.backend else single_instance_space.backend
+        current_device = None if inner_storage.device == single_instance_space.device else single_instance_space.device
+
+        self._batched_instance_space = sbu.batch_space(
+            single_instance_space,
+            1
+        )
+        self._batched_inner_space = self._batched_instance_space.to(
+            backend=inner_backend,
+            device=inner_device
+        )
+
+        self.inner_storage = inner_storage
+        self.inner_storage_path = inner_storage_path
+        self._cache_filename = cache_filename
+
+        self.inner_backend = inner_backend
+        self.inner_device = inner_device
+        self.current_backend = current_backend
+        self.current_device = current_device
+
+    @property
+    def cache_filename(self) -> Optional[Union[str, os.PathLike]]:
+        return self._cache_filename
+
+    @property
+    def is_mutable(self) -> bool:
+        return self.inner_storage.is_mutable
+
+    @property
+    def is_multiprocessing_safe(self) -> bool:
+        return self.inner_storage.is_multiprocessing_safe
+
+    @property
+    def capacity(self) -> Optional[int]:
+        return self.inner_storage.capacity
+
+    def extend_length(self, length):
+        self.inner_storage.extend_length(length)
+
+    def shrink_length(self, length):
+        self.inner_storage.shrink_length(length)
+
+    def __len__(self):
+        return len(self.inner_storage)
+
+    def get(self, index):
+        if self.backend.is_backendarray(index):
+            index = data_to(
+                index,
+                source_backend=self.backend,
+                target_backend=self.inner_backend,
+                target_device=self.inner_device
+            )
+        target_data = self.inner_storage.get(index)
+        return data_to(
+            target_data,
+            source_backend=self.inner_storage.backend,
+            target_backend=self.current_backend,
+            target_device=self.current_device
+        )
+
+    def set(self, index, value):
+        target_value = data_to(
+            value,
+            source_backend=self.backend,
+            target_backend=self.inner_backend,
+            target_device=self.inner_device
+        )
+        if self.backend.is_backendarray(index):
+            index = data_to(
+                index,
+                source_backend=self.backend,
+                target_backend=self.inner_backend,
+                target_device=self.inner_device
+            )
+        self.inner_storage.set(index, target_value)
+
+    def clear(self):
+        self.inner_storage.clear()
+
+    def dumps(self, path):
+        metadata = {
+            "storage_type": __class__.__name__,
+            "inner_storage_type": get_full_class_name(type(self.inner_storage)),
+            "inner_storage_path": self.inner_storage_path,
+            "inner_backend": serialize_backend(self.inner_storage.backend),
+            "inner_device": self.inner_storage.backend.serialize_device(self.inner_storage.device),
+        }
+        self.inner_storage.dumps(os.path.join(path, self.inner_storage_path))
+        with open(os.path.join(path, "backend_metadata.json"), "w") as f:
+            json.dump(metadata, f)
+
+    def close(self):
+        self.inner_storage.close()
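The core of the new module is data_to, which recursively rebuilds mappings and sequences while converting backend arrays only at the leaves. Below is a minimal, self-contained analogue of that walk, using a plain numpy-to-torch conversion in place of unienv's ComputeBackend API; every name in it is illustrative, not part of the package.

    # Illustrative analogue of data_to's recursive container walk (assumption:
    # numpy stands in for the source backend, torch for the target backend).
    from typing import Any, Mapping, Sequence
    import numpy as np
    import torch

    def to_torch(data: Any, device: str = "cpu") -> Any:
        if isinstance(data, np.ndarray):
            # Leaf case: convert the array and move it to the target device.
            return torch.from_numpy(data).to(device)
        if isinstance(data, Mapping):
            # Rebuild mappings key by key, as data_to's Mapping branch does.
            return {k: to_torch(v, device) for k, v in data.items()}
        if isinstance(data, Sequence) and not isinstance(data, (str, bytes)):
            converted = [to_torch(v, device) for v in data]
            try:
                return type(data)(converted)  # preserve tuple/list type, like data_to
            except TypeError:
                return converted
        return data

    sample = {"obs": np.zeros((2, 3), dtype=np.float32), "tags": ("a", np.ones(2))}
    print(to_torch(sample)["obs"].dtype)  # torch.float32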
unienv_data/storages/flattened.py
CHANGED

@@ -23,7 +23,7 @@ class FlattenedStorage(SpaceStorage[
     @classmethod
     def create(
         cls,
-        single_instance_space: Space[
+        single_instance_space: Space[BatchT, BDeviceType, BDtypeType, BRNGType],
         inner_storage_cls : Type[SpaceStorage[BArrayType, BArrayType, BDeviceType, BDtypeType, BRNGType]],
         *args,
         capacity : Optional[int] = None,
@@ -59,7 +59,7 @@ class FlattenedStorage(SpaceStorage[
     def load_from(
         cls,
         path : Union[str, os.PathLike],
-        single_instance_space : Space[
+        single_instance_space : Space[BatchT, BDeviceType, BDtypeType, BRNGType],
         *,
         capacity : Optional[int] = None,
         read_only : bool = True,
@@ -95,7 +95,7 @@ class FlattenedStorage(SpaceStorage[
 
     def __init__(
         self,
-        single_instance_space: Space[
+        single_instance_space: Space[BatchT, BDeviceType, BDtypeType, BRNGType],
         inner_storage : SpaceStorage[
             BArrayType,
             BArrayType,
unienv_data/storages/hdf5.py
CHANGED

@@ -431,6 +431,11 @@ class HDF5Storage(SpaceStorage[
         # Convert to numpy array if it's a scalar
         if isinstance(result, (int, float)):
             result = np.array(result)
+        if isinstance(single_instance_space, TextSpace):
+            if isinstance(result, bytes):
+                result = result.decode('utf-8')
+            elif isinstance(result, np.ndarray):
+                result = np.array([r.decode('utf-8') if isinstance(r, bytes) else r for r in result], dtype=object)
         return result
 
     @staticmethod
@@ -513,8 +518,8 @@ class HDF5Storage(SpaceStorage[
         reduce_io : bool = True,
         **kwargs
     ) -> "HDF5Storage":
-        assert not multiprocessing, \
-            "HDF5Storage does not support multiprocessing safe loading
+        assert read_only or not multiprocessing, \
+            "HDF5Storage does not support multiprocessing safe loading when not `read_only`. Please load the storage in the main process and then share it with child processes."
         assert os.path.exists(path), \
             f"Path {path} does not exist"
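The new TextSpace branch above accounts for h5py returning variable-length string data as bytes. A self-contained sketch of the same normalization step in plain numpy (TextSpace itself is not needed to see the behavior):

    # Sketch of the bytes-to-str normalization the TextSpace branch performs.
    import numpy as np

    def decode_text_result(result):
        if isinstance(result, bytes):
            return result.decode('utf-8')
        if isinstance(result, np.ndarray):
            return np.array(
                [r.decode('utf-8') if isinstance(r, bytes) else r for r in result],
                dtype=object,
            )
        return result

    print(decode_text_result(b"hello"))                             # hello
    print(decode_text_result(np.array([b"a", "b"], dtype=object)))  # ['a' 'b']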
unienv_data/storages/image_storage.py
ADDED

@@ -0,0 +1,144 @@
+from importlib import metadata
+from typing import Generic, TypeVar, Generic, Optional, Any, Dict, Tuple, Sequence, Union, List, Iterable, Type
+
+from unienv_interface.space import Space, BoxSpace
+from unienv_interface.space.space_utils import batch_utils as sbu, flatten_utils as sfu
+from unienv_interface.backends import ComputeBackend, BArrayType, BDeviceType, BDtypeType, BRNGType
+from unienv_interface.utils.symbol_util import *
+
+from unienv_data.base import SpaceStorage
+from ._list_storage import ListStorageBase
+
+import numpy as np
+import os
+import json
+import shutil
+
+from PIL import Image
+
+class ImageStorage(ListStorageBase[
+    BArrayType,
+    BArrayType,
+    BDeviceType,
+    BDtypeType,
+    BRNGType,
+]):
+    # ========== Class Attributes ==========
+    @classmethod
+    def create(
+        cls,
+        single_instance_space: BoxSpace[BArrayType, BDeviceType, BDtypeType, BRNGType],
+        *args,
+        format : str = "JPEG",
+        quality : int = 75,
+        capacity : Optional[int] = None,
+        cache_path : Optional[str] = None,
+        multiprocessing : bool = False,
+        **kwargs
+    ) -> "ImageStorage[BArrayType, BDeviceType, BDtypeType, BRNGType]":
+        if cache_path is None:
+            raise ValueError("cache_path must be provided for ImageStorage.create")
+        assert not os.path.exists(cache_path), f"Cache path {cache_path} already exists"
+        os.makedirs(cache_path, exist_ok=True)
+        return ImageStorage(
+            single_instance_space,
+            format=format,
+            quality=quality,
+            cache_filename=cache_path,
+            capacity=capacity,
+        )
+
+    @classmethod
+    def load_from(
+        cls,
+        path : Union[str, os.PathLike],
+        single_instance_space : BoxSpace[BArrayType, BDeviceType, BDtypeType, BRNGType],
+        *,
+        capacity : Optional[int] = None,
+        read_only : bool = True,
+        multiprocessing : bool = False,
+        **kwargs
+    ) -> "ImageStorage[BArrayType, BDeviceType, BDtypeType, BRNGType]":
+        metadata_path = os.path.join(path, "image_metadata.json")
+        assert os.path.exists(metadata_path), f"Metadata file {metadata_path} does not exist"
+        with open(metadata_path, "r") as f:
+            metadata = json.load(f)
+        assert metadata["storage_type"] == cls.__name__, \
+            f"Expected storage type {cls.__name__}, but found {metadata['storage_type']}"
+
+        format = metadata["format"]
+        quality = int(metadata["quality"])
+        if "capacity" in metadata:
+            capacity = None if metadata['capacity'] is None else int(metadata["capacity"])
+        length = None if capacity is None else metadata["length"]
+
+        return ImageStorage(
+            single_instance_space,
+            format=format,
+            quality=quality,
+            cache_filename=path,
+            mutable=not read_only,
+            capacity=capacity,
+            length=length,
+        )
+
+    # ========== Instance Implementations ==========
+    single_file_ext = None
+
+    def __init__(
+        self,
+        single_instance_space: BoxSpace[BArrayType, BDeviceType, BDtypeType, BRNGType],
+        format : str,
+        quality : int,
+        cache_filename : Union[str, os.PathLike],
+        mutable : bool = True,
+        capacity : Optional[int] = None,
+        length : int = 0,
+    ):
+        super().__init__(
+            single_instance_space,
+            file_ext=format.lower(),
+            cache_filename=cache_filename,
+            mutable=mutable,
+            capacity=capacity,
+            length=length,
+        )
+        self.format = format
+        self.quality = quality
+
+    def get_from_file(self, filename : str) -> BArrayType:
+        if not os.path.exists(filename):
+            return self.backend.zeros(
+                self.single_instance_space.shape,
+                dtype=self.single_instance_space.dtype,
+                device=self.single_instance_space.device,
+            )
+        with Image.open(filename) as img:
+            rgb_img = img.convert("RGB")
+            # Ensure a compact uint8 array from PIL.
+            np_image = np.asarray(rgb_img, dtype=np.uint8)
+        return self.backend.from_numpy(np_image, dtype=self.single_instance_space.dtype, device=self.single_instance_space.device)
+
+    def set_to_file(self, filename : str, value : BArrayType):
+        np_value = self.backend.to_numpy(value)
+        img = Image.fromarray(np_value, mode="RGB")
+        try:
+            img.save(filename, format=self.format, quality=self.quality)
+        finally:
+            img.close()
+
+    def dumps(self, path):
+        assert os.path.samefile(path, self.cache_filename), \
+            f"Dump path {path} does not match cache filename {self.cache_filename}"
+        metadata = {
+            "storage_type": __class__.__name__,
+            "format": self.format,
+            "quality": self.quality,
+            "capacity": self.capacity,
+            "length": self.length,
+        }
+        with open(os.path.join(path, "image_metadata.json"), "w") as f:
+            json.dump(metadata, f)
+
+    def close(self):
+        pass
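Per slot, ImageStorage performs a lossy PIL round trip: encode on set_to_file, decode on get_from_file. A self-contained sketch of that round trip with numpy and PIL only; the 64x64 RGB shape and the file name are placeholders, not values from the package:

    # Round-trip sketch of set_to_file/get_from_file for one slot: JPEG-encode an
    # HxWx3 uint8 frame on write, decode it back to uint8 on read.
    import numpy as np
    from PIL import Image

    frame = np.random.randint(0, 256, size=(64, 64, 3), dtype=np.uint8)

    Image.fromarray(frame, mode="RGB").save("slot_0.jpeg", format="JPEG", quality=75)

    with Image.open("slot_0.jpeg") as img:
        restored = np.asarray(img.convert("RGB"), dtype=np.uint8)

    print(restored.shape)  # (64, 64, 3); values differ slightly because JPEG is lossy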
unienv_data/storages/npz_storage.py
ADDED

@@ -0,0 +1,135 @@
+from importlib import metadata
+from typing import Generic, TypeVar, Generic, Optional, Any, Dict, Tuple, Sequence, Union, List, Iterable, Type
+
+from unienv_interface.space import Space, BoxSpace, BinarySpace
+from unienv_interface.space.space_utils import batch_utils as sbu, flatten_utils as sfu
+from unienv_interface.backends import ComputeBackend, BArrayType, BDeviceType, BDtypeType, BRNGType
+from unienv_interface.utils.symbol_util import *
+
+from unienv_data.base import SpaceStorage
+from ._list_storage import ListStorageBase
+
+import numpy as np
+import os
+import json
+import shutil
+
+from PIL import Image
+
+class NPZStorage(ListStorageBase[
+    BArrayType,
+    BArrayType,
+    BDeviceType,
+    BDtypeType,
+    BRNGType,
+]):
+    # ========== Class Attributes ==========
+    @classmethod
+    def create(
+        cls,
+        single_instance_space: BoxSpace[BArrayType, BDeviceType, BDtypeType, BRNGType],
+        *args,
+        compressed : bool = True,
+        capacity : Optional[int] = None,
+        cache_path : Optional[str] = None,
+        multiprocessing : bool = False,
+        **kwargs
+    ) -> "NPZStorage[BArrayType, BDeviceType, BDtypeType, BRNGType]":
+        assert not os.path.exists(cache_path), f"Cache path {cache_path} already exists"
+        os.makedirs(cache_path, exist_ok=True)
+        return NPZStorage(
+            single_instance_space,
+            compressed=compressed,
+            cache_filename=cache_path,
+            capacity=capacity,
+        )
+
+    @classmethod
+    def load_from(
+        cls,
+        path : Union[str, os.PathLike],
+        single_instance_space : BoxSpace[BArrayType, BDeviceType, BDtypeType, BRNGType],
+        *,
+        capacity : Optional[int] = None,
+        read_only : bool = True,
+        multiprocessing : bool = False,
+        **kwargs
+    ) -> "NPZStorage[BArrayType, BDeviceType, BDtypeType, BRNGType]":
+        metadata_path = os.path.join(path, "npz_metadata.json")
+        assert os.path.exists(metadata_path), f"Metadata file {metadata_path} does not exist"
+        with open(metadata_path, "r") as f:
+            metadata = json.load(f)
+        assert metadata["storage_type"] == cls.__name__, \
+            f"Expected storage type {cls.__name__}, but found {metadata['storage_type']}"
+
+        compressed = metadata.get("compressed", True)
+        if "capacity" in metadata:
+            capacity = None if metadata['capacity'] is None else int(metadata["capacity"])
+        length = None if capacity is None else metadata["length"]
+
+        return NPZStorage(
+            single_instance_space,
+            compressed=compressed,
+            cache_filename=path,
+            mutable=not read_only,
+            capacity=capacity,
+            length=length,
+        )
+
+    # ========== Instance Implementations ==========
+    def __init__(
+        self,
+        single_instance_space: BoxSpace[BArrayType, BDeviceType, BDtypeType, BRNGType],
+        compressed : bool,
+        cache_filename : Union[str, os.PathLike],
+        mutable : bool = True,
+        capacity : Optional[int] = None,
+        length : int = 0,
+    ):
+        assert isinstance(single_instance_space, BoxSpace) or isinstance(single_instance_space, BinarySpace), "single_instance_space must be a BoxSpace or BinarySpace"
+        super().__init__(
+            single_instance_space,
+            file_ext="npz",
+            cache_filename=cache_filename,
+            mutable=mutable,
+            capacity=capacity,
+            length=length,
+        )
+        self.compressed = compressed
+
+    def get_from_file(self, filename : str) -> BArrayType:
+        if not os.path.exists(filename):
+            return self.backend.zeros(
+                self.single_instance_space.shape,
+                dtype=self.single_instance_space.dtype,
+                device=self.single_instance_space.device,
+            )
+
+        dat = np.load(filename, allow_pickle=False)
+        return self.backend.from_numpy(
+            dat['data'],
+            dtype=self.single_instance_space.dtype,
+            device=self.single_instance_space.device
+        )
+
+    def set_to_file(self, filename : str, value : BArrayType):
+        np_value = self.backend.to_numpy(value)
+        if self.compressed:
+            np.savez_compressed(filename, data=np_value)
+        else:
+            np.savez(filename, data=np_value)
+
+    def dumps(self, path):
+        assert os.path.samefile(path, self.cache_filename), \
+            f"Dump path {path} does not match cache filename {self.cache_filename}"
+        metadata = {
+            "storage_type": __class__.__name__,
+            "compressed": self.compressed,
+            "capacity": self.capacity,
+            "length": self.length,
+        }
+        with open(os.path.join(path, "npz_metadata.json"), "w") as f:
+            json.dump(metadata, f)
+
+    def close(self):
+        pass
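NPZStorage's per-slot file format is one array stored under the key "data" in an .npz archive, compressed or not. The round trip, in plain numpy:

    # Round-trip sketch of NPZStorage's per-slot format: one array under key "data".
    import numpy as np

    value = np.arange(12, dtype=np.float32).reshape(3, 4)
    np.savez_compressed("slot_0.npz", data=value)   # the compressed=True path

    loaded = np.load("slot_0.npz", allow_pickle=False)["data"]
    assert np.array_equal(loaded, value)            # exact, unlike the JPEG path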
unienv_data/storages/pytorch.py
CHANGED

@@ -1,10 +1,10 @@
 import os
 import torch
-from unienv_interface.space import Space, BoxSpace
+from unienv_interface.space import Space, BoxSpace, BinarySpace
 from unienv_interface.backends import ComputeBackend
 from unienv_interface.backends.pytorch import PyTorchComputeBackend, PyTorchArrayType, PyTorchDeviceType, PyTorchDtypeType, PyTorchRNGType
 from unienv_data.base import SpaceStorage
-from tensordict.
+from unienv_data.third_party.tensordict.memmap_tensor import MemoryMappedTensor
 from typing import Generic, TypeVar, Generic, Optional, Any, Dict, Tuple, Sequence, Union, List, Type
 
 class PytorchTensorStorage(SpaceStorage[
@@ -20,15 +20,15 @@ class PytorchTensorStorage(SpaceStorage[
         single_instance_space : BoxSpace[PyTorchArrayType, PyTorchDeviceType, PyTorchDtypeType, PyTorchRNGType],
         *,
         capacity : Optional[int],
-        is_memmap : bool =
+        is_memmap : bool = True,
         cache_path : Optional[str] = None,
         multiprocessing : bool = False,
         memmap_existok : bool = True,
     ) -> "PytorchTensorStorage":
         assert single_instance_space.backend is PyTorchComputeBackend, \
             f"Single instance space must be of type PyTorchComputeBackend, got {single_instance_space.backend}"
-        assert isinstance(single_instance_space, BoxSpace), \
-            f"Single instance space must be a BoxSpace, got {type(single_instance_space)}"
+        assert isinstance(single_instance_space, BoxSpace) or isinstance(single_instance_space, BinarySpace), \
+            f"Single instance space must be a BoxSpace or BinarySpace, got {type(single_instance_space)}"
         assert capacity is not None, "Capacity must be specified when creating a new tensor"
 
         target_shape = (capacity, *single_instance_space.shape)
@@ -63,9 +63,12 @@ class PytorchTensorStorage(SpaceStorage[
     def load_from(
         cls,
         path: Union[str, os.PathLike],
-        single_instance_space:
+        single_instance_space: Union[
+            BoxSpace[PyTorchArrayType, PyTorchDeviceType, PyTorchDtypeType, PyTorchRNGType],
+            BinarySpace[PyTorchArrayType, PyTorchDeviceType, PyTorchDtypeType, PyTorchRNGType],
+        ],
         *,
-        is_memmap : bool =
+        is_memmap : bool = True,
         capacity : Optional[int] = None,
         read_only : bool = True,
         multiprocessing : bool = False,
@@ -81,7 +84,8 @@ class PytorchTensorStorage(SpaceStorage[
         target_data = MemoryMappedTensor.from_filename(
             path,
             dtype=single_instance_space.dtype,
-            shape=target_shape
+            shape=target_shape,
+            readonly=read_only,
         )
 
         if is_memmap:
@@ -111,7 +115,10 @@ class PytorchTensorStorage(SpaceStorage[
 
     def __init__(
         self,
-        single_instance_space :
+        single_instance_space : Union[
+            BoxSpace[PyTorchArrayType, PyTorchDeviceType, PyTorchDtypeType, PyTorchRNGType],
+            BinarySpace[PyTorchArrayType, PyTorchDeviceType, PyTorchDtypeType, PyTorchRNGType],
+        ],
         data : Union[torch.Tensor, MemoryMappedTensor],
         mutable : bool = True,
     ):
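The vendored MemoryMappedTensor that replaces the tensordict import backs a (capacity, *shape) tensor with a file on disk, which is what makes the new readonly=read_only flag meaningful. A rough, self-contained analogue using numpy.memmap; the vendored class's actual API is not shown in this diff, so everything below is illustrative:

    # Illustrative analogue only: an on-disk buffer viewed as a torch tensor, with
    # a read-only reopen standing in for load_from(..., read_only=True).
    import numpy as np
    import torch

    capacity, shape = 8, (3, 4)
    buf = np.memmap("storage.bin", dtype=np.float32, mode="w+",
                    shape=(capacity, *shape))
    data = torch.from_numpy(buf)        # zero-copy view over the mapped file
    data[0] = torch.ones(shape)         # writes land in storage.bin
    buf.flush()

    ro = np.memmap("storage.bin", dtype=np.float32, mode="r",
                   shape=(capacity, *shape))
    print(torch.from_numpy(np.array(ro[0])).sum())  # tensor(12.), via a copy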