unienv 0.0.1b1__py3-none-any.whl → 0.0.1b3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- unienv-0.0.1b3.dist-info/METADATA +74 -0
- unienv-0.0.1b3.dist-info/RECORD +92 -0
- {unienv-0.0.1b1.dist-info → unienv-0.0.1b3.dist-info}/licenses/LICENSE +1 -3
- unienv-0.0.1b3.dist-info/top_level.txt +2 -0
- unienv_data/base/__init__.py +0 -1
- unienv_data/base/common.py +95 -45
- unienv_data/base/storage.py +1 -0
- unienv_data/batches/__init__.py +2 -1
- unienv_data/batches/backend_compat.py +47 -1
- unienv_data/batches/combined_batch.py +2 -4
- unienv_data/{base → batches}/transformations.py +3 -2
- unienv_data/replay_buffer/replay_buffer.py +4 -0
- unienv_data/samplers/__init__.py +0 -1
- unienv_data/samplers/multiprocessing_sampler.py +26 -22
- unienv_data/samplers/step_sampler.py +9 -18
- unienv_data/storages/common.py +5 -0
- unienv_data/storages/hdf5.py +291 -20
- unienv_data/storages/pytorch.py +1 -0
- unienv_data/storages/transformation.py +191 -0
- unienv_data/transformations/image_compress.py +213 -0
- unienv_interface/backends/jax.py +4 -1
- unienv_interface/backends/numpy.py +4 -1
- unienv_interface/backends/pytorch.py +4 -1
- unienv_interface/env_base/__init__.py +1 -0
- unienv_interface/env_base/env.py +5 -0
- unienv_interface/env_base/funcenv.py +32 -1
- unienv_interface/env_base/funcenv_wrapper.py +2 -2
- unienv_interface/env_base/vec_env.py +474 -0
- unienv_interface/func_wrapper/__init__.py +2 -1
- unienv_interface/func_wrapper/frame_stack.py +150 -0
- unienv_interface/space/space_utils/__init__.py +1 -0
- unienv_interface/space/space_utils/batch_utils.py +83 -0
- unienv_interface/space/space_utils/construct_utils.py +216 -0
- unienv_interface/space/space_utils/serialization_utils.py +16 -1
- unienv_interface/space/spaces/__init__.py +3 -1
- unienv_interface/space/spaces/batched.py +90 -0
- unienv_interface/space/spaces/binary.py +0 -1
- unienv_interface/space/spaces/box.py +13 -24
- unienv_interface/space/spaces/text.py +1 -3
- unienv_interface/transformations/dict_transform.py +31 -5
- unienv_interface/utils/control_util.py +68 -0
- unienv_interface/utils/data_queue.py +184 -0
- unienv_interface/utils/stateclass.py +46 -0
- unienv_interface/utils/vec_util.py +15 -0
- unienv_interface/world/__init__.py +3 -1
- unienv_interface/world/combined_funcnode.py +336 -0
- unienv_interface/world/combined_node.py +232 -0
- unienv_interface/wrapper/backend_compat.py +2 -2
- unienv_interface/wrapper/frame_stack.py +19 -114
- unienv_interface/wrapper/video_record.py +11 -2
- unienv-0.0.1b1.dist-info/METADATA +0 -20
- unienv-0.0.1b1.dist-info/RECORD +0 -85
- unienv-0.0.1b1.dist-info/top_level.txt +0 -4
- unienv_data/samplers/slice_sampler.py +0 -266
- unienv_maniskill/__init__.py +0 -1
- unienv_maniskill/wrapper/maniskill_compat.py +0 -235
- unienv_mjxplayground/__init__.py +0 -1
- unienv_mjxplayground/wrapper/playground_compat.py +0 -256
- {unienv-0.0.1b1.dist-info → unienv-0.0.1b3.dist-info}/WHEEL +0 -0
|
@@ -0,0 +1,191 @@
|
|
|
1
|
+
from importlib import metadata
|
|
2
|
+
from typing import Generic, TypeVar, Generic, Optional, Any, Dict, Tuple, Sequence, Union, List, Iterable, Type
|
|
3
|
+
|
|
4
|
+
from unienv_interface.space import Space, BoxSpace
|
|
5
|
+
from unienv_interface.space.space_utils import batch_utils as sbu, flatten_utils as sfu
|
|
6
|
+
from unienv_interface.env_base.env import ContextType, ObsType, ActType
|
|
7
|
+
from unienv_interface.backends import ComputeBackend, BArrayType, BDeviceType, BDtypeType, BRNGType
|
|
8
|
+
from unienv_interface.backends.numpy import NumpyComputeBackend
|
|
9
|
+
from unienv_interface.utils.symbol_util import *
|
|
10
|
+
from unienv_interface.transformations import DataTransformation
|
|
11
|
+
|
|
12
|
+
from unienv_data.base import SpaceStorage, BatchT
|
|
13
|
+
|
|
14
|
+
import numpy as np
|
|
15
|
+
import os
|
|
16
|
+
import json
|
|
17
|
+
import pickle
|
|
18
|
+
|
|
19
|
+
class TransformedStorage(SpaceStorage[
    BatchT,
    BArrayType,
    BDeviceType,
    BDtypeType,
    BRNGType,
]):
    """Storage that transparently applies an invertible data transformation.

    Values written through :meth:`set` are transformed (e.g. compressed) and
    persisted in an inner :class:`SpaceStorage` that operates on the
    transformed (target) space; values read through :meth:`get` are
    transformed back, so callers only ever see data in the original
    ``single_instance_space``.
    """

    # ========== Class Attributes ==========
    @classmethod
    def create(
        cls,
        single_instance_space: Space[Any, BDeviceType, BDtypeType, BRNGType],
        inner_storage_cls : Type[SpaceStorage[BArrayType, BArrayType, BDeviceType, BDtypeType, BRNGType]],
        *args,
        data_transformation : DataTransformation,
        capacity : Optional[int] = None,
        cache_path : Optional[str] = None,
        **kwargs
    ) -> "TransformedStorage[BArrayType, BDeviceType, BDtypeType, BRNGType]":
        """Create a fresh transformed storage backed by ``inner_storage_cls``.

        Args:
            single_instance_space: Space of a single untransformed instance.
            inner_storage_cls: Storage class used to persist transformed data.
            data_transformation: Invertible transformation applied on write.
            capacity: Optional maximum number of stored instances.
            cache_path: Optional directory for the inner storage's cache file.

        Returns:
            A new storage instance of ``cls``.
        """
        assert data_transformation.has_inverse, "To transform storages (potentially to save space), you need to use inversible data transformations"
        transformed_space = data_transformation.get_target_space_from_source(single_instance_space)
        # Fixed sub-path so that dumps()/load_from() agree on where the inner
        # storage lives relative to the metadata files.
        inner_storage_path = "transformed_inner_storage" + (inner_storage_cls.single_file_ext or "")

        if cache_path is not None:
            os.makedirs(cache_path, exist_ok=True)

        inner_storage = inner_storage_cls.create(
            transformed_space,
            *args,
            cache_path=None if cache_path is None else os.path.join(cache_path, inner_storage_path),
            capacity=capacity,
            **kwargs
        )
        # Use cls (not a hard-coded class name) so subclasses construct
        # instances of themselves.
        return cls(
            single_instance_space,
            data_transformation,
            inner_storage,
            inner_storage_path,
        )

    @classmethod
    def load_from(
        cls,
        path : Union[str, os.PathLike],
        single_instance_space : Space[Any, BDeviceType, BDtypeType, BRNGType],
        *,
        capacity : Optional[int] = None,
        read_only : bool = True,
        **kwargs
    ) -> "TransformedStorage[BArrayType, BDeviceType, BDtypeType, BRNGType]":
        """Load a storage previously persisted with :meth:`dumps` from ``path``.

        Raises:
            AssertionError: If the metadata file is missing, records a
                different storage type than ``cls``, or the unpickled object
                is not a :class:`DataTransformation`.
        """
        metadata_path = os.path.join(path, "transformed_metadata.json")
        assert os.path.exists(metadata_path), f"Metadata file {metadata_path} does not exist"
        with open(metadata_path, "r") as f:
            metadata = json.load(f)
        assert metadata["storage_type"] == cls.__name__, \
            f"Expected storage type {cls.__name__}, but found {metadata['storage_type']}"

        # NOTE(review): unpickling executes arbitrary code — only load dumps
        # produced by a trusted writer.
        data_transform_path = os.path.join(path, "data_transformation.pkl")
        with open(data_transform_path, "rb") as f:
            data_transform = pickle.load(f)

        assert isinstance(data_transform, DataTransformation)
        transformed_space = data_transform.get_target_space_from_source(single_instance_space)

        inner_storage_cls : Type[SpaceStorage] = get_class_from_full_name(metadata["inner_storage_type"])
        inner_storage_path = metadata["inner_storage_path"]
        inner_storage = inner_storage_cls.load_from(
            os.path.join(path, inner_storage_path),
            transformed_space,
            capacity=capacity,
            read_only=read_only,
            **kwargs
        )
        # Use cls (not a hard-coded class name) so subclasses construct
        # instances of themselves.
        return cls(
            single_instance_space,
            data_transform,
            inner_storage,
            inner_storage_path,
        )

    # ========== Instance Implementations ==========
    # This storage is always a directory (metadata + inner storage),
    # never a single file.
    single_file_ext = None

    def __init__(
        self,
        single_instance_space: Space[Any, BDeviceType, BDtypeType, BRNGType],
        data_transformation : DataTransformation,
        inner_storage : SpaceStorage[
            BArrayType,
            BArrayType,
            BDeviceType,
            BDtypeType,
            BRNGType,
        ],
        inner_storage_path : Union[str, os.PathLike],
    ):
        """Wrap ``inner_storage`` (holding transformed data) behind ``data_transformation``."""
        super().__init__(single_instance_space)
        transformed_space = data_transformation.get_target_space_from_source(single_instance_space)

        # Messages reference the transformed space, which is what is actually
        # being compared (the originals misleadingly cited the source space).
        assert inner_storage.backend == transformed_space.backend, \
            f"Inner storage backend {inner_storage.backend} does not match transformed space backend {transformed_space.backend}"
        assert inner_storage.device == transformed_space.device, \
            f"Inner storage device {inner_storage.device} does not match transformed space device {transformed_space.device}"
        assert inner_storage.single_instance_space == transformed_space

        self._transformed_space = transformed_space
        # Batch-of-1 variants, reused by the batched get/set code paths.
        self._batched_transformed_space = sbu.batch_space(transformed_space, 1)
        self._batched_instance_space = sbu.batch_space(single_instance_space, 1)
        self.inner_storage = inner_storage
        self.inner_storage_path = inner_storage_path
        self.data_transformation = data_transformation
        self.inv_data_transformation = data_transformation.direction_inverse(single_instance_space)

    @property
    def capacity(self) -> Optional[int]:
        """Maximum number of instances; delegated to the inner storage."""
        return self.inner_storage.capacity

    def extend_length(self, length):
        """Grow the storage so it holds ``length`` instances."""
        self.inner_storage.extend_length(length)

    def shrink_length(self, length):
        """Shrink the storage down to ``length`` instances."""
        self.inner_storage.shrink_length(length)

    def __len__(self):
        return len(self.inner_storage)

    def get_flattened(self, index):
        """Return the inverse-transformed value(s) at ``index``, flattened."""
        dat = self.get(index)
        if isinstance(index, int):
            return sfu.flatten_data(self.single_instance_space, dat)
        else:
            # Batched read: keep the leading batch dimension intact.
            return sfu.flatten_data(self.single_instance_space, dat, start_dim=1)

    def get(self, index):
        """Read from the inner storage and undo the transformation."""
        result = self.inner_storage.get(index)
        result = self.inv_data_transformation.transform(
            self._transformed_space if isinstance(index, int) else self._batched_transformed_space,
            result
        )
        return result

    def set_flattened(self, index, value):
        """Unflatten ``value`` against the source space, then store it."""
        if isinstance(index, int):
            set_value = sfu.unflatten_data(self.single_instance_space, value)
        else:
            set_value = sfu.unflatten_data(self._batched_instance_space, value)
        self.set(index, set_value)

    def set(self, index, value):
        """Transform ``value`` and write it to the inner storage."""
        transformed_value = self.data_transformation.transform(
            self.single_instance_space if isinstance(index, int) else self._batched_instance_space,
            value
        )
        self.inner_storage.set(index, transformed_value)

    def clear(self):
        """Remove all stored data from the inner storage."""
        self.inner_storage.clear()

    def dumps(self, path):
        """Persist metadata, the pickled transformation, and the inner storage under ``path``."""
        # Ensure the target directory exists before writing into it.
        os.makedirs(path, exist_ok=True)
        metadata = {
            # type(self) (not the defining class) so a subclass's dump passes
            # the storage_type check in its own load_from().
            "storage_type": type(self).__name__,
            "inner_storage_type": get_full_class_name(type(self.inner_storage)),
            "inner_storage_path": self.inner_storage_path,
            "transformation": get_full_class_name(type(self.data_transformation))
        }
        self.inner_storage.dumps(os.path.join(path, self.inner_storage_path))
        with open(os.path.join(path, "transformed_metadata.json"), "w") as f:
            json.dump(metadata, f)
        with open(os.path.join(path, "data_transformation.pkl"), "wb") as f:
            pickle.dump(self.data_transformation, f)

    def close(self):
        """Close the inner storage."""
        self.inner_storage.close()
|
|
@@ -0,0 +1,213 @@
|
|
|
1
|
+
from unienv_interface.space.space_utils import batch_utils as sbu
|
|
2
|
+
from unienv_interface.transformations import DataTransformation
|
|
3
|
+
from unienv_interface.space import Space, BoxSpace, TextSpace
|
|
4
|
+
from typing import Union, Any, Optional
|
|
5
|
+
from unienv_interface.backends import ComputeBackend, BArrayType, BDeviceType, BDtypeType, BRNGType
|
|
6
|
+
from PIL import Image
|
|
7
|
+
import numpy as np
|
|
8
|
+
import io
|
|
9
|
+
|
|
10
|
+
class ImageCompressTransformation(DataTransformation):
    """Compress image arrays (..., H, W, C) into fixed-size byte buffers.

    Each image is encoded (JPEG by default) and zero-padded to
    ``max_size_bytes``, producing a BoxSpace of shape (..., max_size_bytes).
    Inverted by :class:`ImageDecompressTransformation`.
    """

    has_inverse = True

    def __init__(
        self,
        init_quality : int = 75,
        max_size_bytes : int = 65536,
        mode : Optional[str] = None,
        format : str = "JPEG",
    ) -> None:
        """
        Initialize image compression transformation.
        Args:
            init_quality: Initial JPEG quality setting (1-100).
            max_size_bytes: Maximum allowed size of compressed JPEG in bytes.
            mode: Optional mode for PIL Image (e.g., "RGB", "L"). If None, inferred from input.
            format: Image format to use for compression (default "JPEG"). See https://pillow.readthedocs.io/en/stable/handbook/image-file-formats.html for options.
        """

        self.init_quality = init_quality
        self.max_size_bytes = max_size_bytes
        self.mode = mode
        self.format = format

    @staticmethod
    def validate_source_space(source_space: Space[Any, BDeviceType, BDtypeType, BRNGType]) -> None:
        """Assert that ``source_space`` is an image-shaped (..., H, W, 1|3) BoxSpace."""
        assert isinstance(source_space, BoxSpace), "JPEGCompressTransformation only supports BoxSpace source spaces."
        assert len(source_space.shape) >= 3 and (
            source_space.shape[-1] == 3 or
            source_space.shape[-1] == 1
        ), "JPEGCompressTransformation only supports BoxSpace source spaces with shape (..., H, W, 1 or 3)."

    @staticmethod
    def get_uint8_dtype(
        backend: ComputeBackend[BArrayType, BDeviceType, BDtypeType, BRNGType],
    ) -> BDtypeType:
        """Return the backend's uint8 dtype via the array-API inspection namespace."""
        return backend.__array_namespace_info__().dtypes()['uint8']

    def get_target_space_from_source(self, source_space):
        """Map the (..., H, W, C) image space to a (..., max_size_bytes) byte space."""
        self.validate_source_space(source_space)
        new_shape = source_space.shape[:-3] + (self.max_size_bytes,)

        return BoxSpace(
            source_space.backend,
            shape=new_shape,
            low=-source_space.backend.inf,
            high=source_space.backend.inf,
            dtype=self.get_uint8_dtype(source_space.backend),
            device=source_space.device,
        )

    def encode_to_size(self, img_array, max_bytes, min_quality=20, mode=None):
        """
        Encode an image (H, W, 3) or (H, W, 1) as compressed bytes,
        reducing quality until <= max_bytes.

        Args:
            img_array: np.ndarray, uint8, shape (H, W, 3) RGB or (H, W, 1) grayscale
            max_bytes: maximum allowed size of the encoded image
            min_quality: minimum quality before giving up
            mode: optional PIL mode override

        Returns:
            image_bytes (bytes), final_quality (int). The result may still
            exceed max_bytes if even min_quality was not small enough.
        """
        # Handle grayscale (H, W, 1) → (H, W)
        if img_array.ndim == 3 and img_array.shape[-1] == 1:
            img_array = np.squeeze(img_array, axis=-1)

        # Create PIL Image (mode inferred automatically when mode is None)
        img = Image.fromarray(img_array, mode=mode)
        try:
            # Start from the configured quality. Previously this was
            # hard-coded to 95, silently ignoring init_quality.
            quality = self.init_quality
            while True:
                buf = io.BytesIO()
                img.save(buf, format=self.format, quality=quality)
                image_bytes = buf.getvalue()
                # Stop when the encoding fits, or when lowering quality
                # further would drop below the floor.
                if len(image_bytes) <= max_bytes or quality - 5 < min_quality:
                    break
                quality -= 5
        finally:
            # Close on every path; the original leaked the PIL image whenever
            # an attempt fit within max_bytes (it returned before close()).
            img.close()
        # May exceed max_bytes; callers must check.
        return image_bytes, quality

    def transform(self, source_space, data):
        """Compress every image in ``data`` into a zero-padded uint8 byte row.

        Raises:
            ValueError: If an image cannot be compressed below max_size_bytes
                even at the minimum quality.
        """
        self.validate_source_space(source_space)
        data_numpy = source_space.backend.to_numpy(data)
        # Collapse all leading batch dims; the trailing 3 dims are (H, W, C).
        flat_data_numpy = data_numpy.reshape(-1, *data_numpy.shape[-3:])
        flat_compressed_data = np.zeros((flat_data_numpy.shape[0], self.max_size_bytes), dtype=np.uint8)
        for i in range(flat_data_numpy.shape[0]):
            img_array = flat_data_numpy[i]
            image_bytes, _ = self.encode_to_size(
                img_array,
                self.max_size_bytes,
                mode=self.mode
            )
            byte_array = np.frombuffer(image_bytes, dtype=np.uint8)
            if len(byte_array) > self.max_size_bytes:
                # Fail loudly instead of letting the slice assignment below
                # raise an opaque broadcasting error.
                raise ValueError(
                    f"Compressed image ({len(byte_array)} bytes) exceeds "
                    f"max_size_bytes={self.max_size_bytes}; increase max_size_bytes."
                )
            flat_compressed_data[i, :len(byte_array)] = byte_array
        compressed_data = flat_compressed_data.reshape(data_numpy.shape[:-3] + (self.max_size_bytes, ))
        compressed_data_backend = source_space.backend.from_numpy(compressed_data, dtype=self.get_uint8_dtype(source_space.backend), device=source_space.device)
        return compressed_data_backend

    def direction_inverse(self, source_space = None):
        """Return the matching :class:`ImageDecompressTransformation` for ``source_space``."""
        assert source_space is not None, "Source space must be provided to get inverse transformation."
        self.validate_source_space(source_space)
        height = source_space.shape[-3]
        width = source_space.shape[-2]
        channels = source_space.shape[-1]
        return ImageDecompressTransformation(
            target_height=height,
            target_width=width,
            target_channels=channels,
            mode=self.mode,
            format=self.format,
        )
|
|
125
|
+
|
|
126
|
+
class ImageDecompressTransformation(DataTransformation):
    """Decompress fixed-size byte buffers back into (..., H, W, C) images.

    Inverse of :class:`ImageCompressTransformation`.
    """

    has_inverse = True

    def __init__(
        self,
        target_height : int,
        target_width : int,
        target_channels : int = 3,
        mode : Optional[str] = None,
        format : Optional[str] = None,
    ) -> None:
        """
        Initialize image decompression transformation.
        Args:
            target_height: Height of the decompressed image.
            target_width: Width of the decompressed image.
            target_channels: Number of channels of the decompressed image.
            mode: Optional mode for PIL Image (e.g., "RGB", "L"). If None, inferred from input.
            format: Image format to use for decompression (default None, which will try everything). See https://pillow.readthedocs.io/en/stable/handbook/image-file-formats.html for options.
        """
        self.target_height = target_height
        self.target_width = target_width
        self.target_channels = target_channels
        self.mode = mode
        self.format = format

    @staticmethod
    def validate_source_space(source_space: Space[Any, BDeviceType, BDtypeType, BRNGType]) -> None:
        """Assert that ``source_space`` is a BoxSpace with a trailing byte dimension."""
        assert isinstance(source_space, BoxSpace), "JPEGDecompressTransformation only supports BoxSpace source spaces."
        assert len(source_space.shape) >= 1, "JPEGDecompressTransformation requires source space with at least 1 dimension."

    @staticmethod
    def get_uint8_dtype(backend):
        """Return the backend's uint8 dtype (delegates to the compression side)."""
        return ImageCompressTransformation.get_uint8_dtype(backend)

    def get_target_space_from_source(self, source_space):
        """Map the (..., max_bytes) byte space to a (..., H, W, C) uint8 image space."""
        self.validate_source_space(source_space)
        new_shape = source_space.shape[:-1] + (self.target_height, self.target_width, self.target_channels)
        return BoxSpace(
            source_space.backend,
            shape=new_shape,
            low=0,
            high=255,
            dtype=self.get_uint8_dtype(source_space.backend),
            device=source_space.device,
        )

    def decode_bytes(self, jpeg_bytes : bytes, mode=None):
        """
        Decode compressed bytes to an image array.

        Args:
            jpeg_bytes: bytes of the compressed image
            mode: optional PIL mode to convert to

        Returns:
            img_array: np.ndarray, uint8, shape (H, W, C). A channel axis is
            always present: PIL returns (H, W) for single-channel images,
            which previously broke the (H, W, 1) assignment in transform().
        """
        buf = io.BytesIO(jpeg_bytes)
        img = Image.open(buf, formats=[self.format] if self.format is not None else None)
        try:
            if mode is not None:
                img = img.convert(mode)
            img_array = np.array(img)
        finally:
            # Close even if decoding/conversion raises.
            img.close()
        # Restore the channel axis for grayscale decodes.
        if img_array.ndim == 2:
            img_array = img_array[..., np.newaxis]
        return img_array

    def transform(self, source_space, data):
        """Decompress every byte row in ``data`` back into an image batch."""
        self.validate_source_space(source_space)
        data_numpy = source_space.backend.to_numpy(data)
        # Collapse all leading batch dims; the last dim is the byte buffer.
        flat_data_numpy = data_numpy.reshape(-1, data_numpy.shape[-1])
        flat_decompressed_image = np.zeros((flat_data_numpy.shape[0], self.target_height, self.target_width, self.target_channels), dtype=np.uint8)
        for i in range(flat_data_numpy.shape[0]):
            byte_array : np.ndarray = flat_data_numpy[i]
            # NOTE(review): rows are zero-padded to the buffer size; the
            # decoder is assumed to stop at the format's end marker — confirm
            # for formats other than JPEG/PNG.
            flat_decompressed_image[i] = self.decode_bytes(
                byte_array.tobytes(),
                mode=self.mode
            )
        decompressed_image = flat_decompressed_image.reshape(data_numpy.shape[:-1] + (self.target_height, self.target_width, self.target_channels))
        decompressed_image_backend = source_space.backend.from_numpy(decompressed_image, dtype=self.get_uint8_dtype(source_space.backend), device=source_space.device)
        return decompressed_image_backend

    def direction_inverse(self, source_space = None):
        """Return the matching :class:`ImageCompressTransformation` for ``source_space``."""
        assert source_space is not None, "Source space must be provided to get inverse transformation."
        self.validate_source_space(source_space)
        return ImageCompressTransformation(
            init_quality=75,
            max_size_bytes=source_space.shape[-1],
            mode=self.mode,
            format=self.format if self.format is not None else "JPEG",
        )
|
unienv_interface/backends/jax.py
CHANGED
|
@@ -1,4 +1,7 @@
|
|
|
1
|
-
|
|
1
|
+
try:
|
|
2
|
+
from xbarray.backends.jax import JaxComputeBackend as XBJaxBackend
|
|
3
|
+
except ImportError:
|
|
4
|
+
from xbarray.jax import JaxComputeBackend as XBJaxBackend
|
|
2
5
|
from xbarray import ComputeBackend
|
|
3
6
|
from typing import Union
|
|
4
7
|
import jax
|
|
@@ -1,4 +1,7 @@
|
|
|
1
|
-
|
|
1
|
+
try:
|
|
2
|
+
from xbarray.backends.numpy import NumpyComputeBackend as XBNumpyBackend
|
|
3
|
+
except ImportError:
|
|
4
|
+
from xbarray.numpy import NumpyComputeBackend as XBNumpyBackend
|
|
2
5
|
from xbarray import ComputeBackend
|
|
3
6
|
|
|
4
7
|
import numpy as np
|
|
@@ -1,4 +1,7 @@
|
|
|
1
|
-
|
|
1
|
+
try:
|
|
2
|
+
from xbarray.backends.pytorch import PytorchComputeBackend as XBPytorchBackend
|
|
3
|
+
except ImportError:
|
|
4
|
+
from xbarray.pytorch import PytorchComputeBackend as XBPytorchBackend
|
|
2
5
|
from xbarray import ComputeBackend
|
|
3
6
|
|
|
4
7
|
from typing import Union
|
unienv_interface/env_base/env.py
CHANGED
|
@@ -90,6 +90,11 @@ class Env(abc.ABC, Generic[BArrayType, ContextType, ObsType, ActType, RenderFram
|
|
|
90
90
|
def sample_observation(self) -> ObsType:
|
|
91
91
|
return self.sample_space(self.observation_space)
|
|
92
92
|
|
|
93
|
+
def sample_context(self) -> Optional[ContextType]:
    """Sample a context value, or return None when the env has no context space."""
    space = self.context_space
    return None if space is None else self.sample_space(space)
|
|
97
|
+
|
|
93
98
|
def update_observation_post_reset(
|
|
94
99
|
self,
|
|
95
100
|
old_obs: ObsType,
|
|
@@ -2,7 +2,7 @@ from typing import Any, Callable, Generic, TypeVar, Tuple, Dict, Optional, Suppo
|
|
|
2
2
|
import abc
|
|
3
3
|
import numpy as np
|
|
4
4
|
from unienv_interface.backends import ComputeBackend, BArrayType, BDeviceType, BDtypeType, BRNGType
|
|
5
|
-
from unienv_interface.space import Space
|
|
5
|
+
from unienv_interface.space import Space, batch_utils as sbu
|
|
6
6
|
from dataclasses import dataclass, replace as dataclass_replace
|
|
7
7
|
from .env import Env, ContextType, ObsType, ActType, RenderFrame
|
|
8
8
|
|
|
@@ -121,6 +121,37 @@ class FuncEnv(
|
|
|
121
121
|
"""Close the render state."""
|
|
122
122
|
raise NotImplementedError
|
|
123
123
|
|
|
124
|
+
# ========== Convenience methods ==========
|
|
125
|
+
def update_observation_post_reset(
    self,
    old_obs: ObsType,
    newobs_masked: ObsType,
    mask: BArrayType
) -> ObsType:
    """Write freshly-reset observations into the masked slots of the old batch.

    Only the entries selected by ``mask`` are replaced with values from
    ``newobs_masked``; the remaining entries of ``old_obs`` are kept.
    """
    assert self.batch_size is not None, "This method is used by batched environment after reset"
    obs_space = self.observation_space
    merged = sbu.set_at(obs_space, old_obs, mask, newobs_masked)
    return merged
|
|
138
|
+
|
|
139
|
+
def update_context_post_reset(
    self,
    old_context: ContextType,
    new_context: ContextType,
    mask: BArrayType
) -> ContextType:
    """Merge reset-time contexts into the previous batch at the masked slots.

    Returns None unchanged when the environment declares no context space.
    """
    assert self.batch_size is not None, "This method is used by batched environment after reset"
    ctx_space = self.context_space
    if ctx_space is None:
        # Context-free environments carry no context at all.
        return None
    return sbu.set_at(ctx_space, old_context, mask, new_context)
|
|
154
|
+
|
|
124
155
|
# ========== Wrapper methods ==========
|
|
125
156
|
@property
|
|
126
157
|
def unwrapped(self) -> "FuncEnv":
|
|
@@ -110,7 +110,7 @@ class FuncEnvWrapper(
|
|
|
110
110
|
def reset(
|
|
111
111
|
self,
|
|
112
112
|
state : WrapperStateT,
|
|
113
|
-
|
|
113
|
+
*args,
|
|
114
114
|
seed : Optional[int] = None,
|
|
115
115
|
mask : Optional[WrapperBArrayT] = None,
|
|
116
116
|
**kwargs
|
|
@@ -120,7 +120,7 @@ class FuncEnvWrapper(
|
|
|
120
120
|
WrapperObsT,
|
|
121
121
|
Dict[str, Any]
|
|
122
122
|
]:
|
|
123
|
-
return self.func_env.reset(state, seed=seed, mask=mask, **kwargs)
|
|
123
|
+
return self.func_env.reset(state, *args, seed=seed, mask=mask, **kwargs)
|
|
124
124
|
|
|
125
125
|
def step(
|
|
126
126
|
self,
|