unienv 0.0.1b1__py3-none-any.whl → 0.0.1b3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (59):
  1. unienv-0.0.1b3.dist-info/METADATA +74 -0
  2. unienv-0.0.1b3.dist-info/RECORD +92 -0
  3. {unienv-0.0.1b1.dist-info → unienv-0.0.1b3.dist-info}/licenses/LICENSE +1 -3
  4. unienv-0.0.1b3.dist-info/top_level.txt +2 -0
  5. unienv_data/base/__init__.py +0 -1
  6. unienv_data/base/common.py +95 -45
  7. unienv_data/base/storage.py +1 -0
  8. unienv_data/batches/__init__.py +2 -1
  9. unienv_data/batches/backend_compat.py +47 -1
  10. unienv_data/batches/combined_batch.py +2 -4
  11. unienv_data/{base → batches}/transformations.py +3 -2
  12. unienv_data/replay_buffer/replay_buffer.py +4 -0
  13. unienv_data/samplers/__init__.py +0 -1
  14. unienv_data/samplers/multiprocessing_sampler.py +26 -22
  15. unienv_data/samplers/step_sampler.py +9 -18
  16. unienv_data/storages/common.py +5 -0
  17. unienv_data/storages/hdf5.py +291 -20
  18. unienv_data/storages/pytorch.py +1 -0
  19. unienv_data/storages/transformation.py +191 -0
  20. unienv_data/transformations/image_compress.py +213 -0
  21. unienv_interface/backends/jax.py +4 -1
  22. unienv_interface/backends/numpy.py +4 -1
  23. unienv_interface/backends/pytorch.py +4 -1
  24. unienv_interface/env_base/__init__.py +1 -0
  25. unienv_interface/env_base/env.py +5 -0
  26. unienv_interface/env_base/funcenv.py +32 -1
  27. unienv_interface/env_base/funcenv_wrapper.py +2 -2
  28. unienv_interface/env_base/vec_env.py +474 -0
  29. unienv_interface/func_wrapper/__init__.py +2 -1
  30. unienv_interface/func_wrapper/frame_stack.py +150 -0
  31. unienv_interface/space/space_utils/__init__.py +1 -0
  32. unienv_interface/space/space_utils/batch_utils.py +83 -0
  33. unienv_interface/space/space_utils/construct_utils.py +216 -0
  34. unienv_interface/space/space_utils/serialization_utils.py +16 -1
  35. unienv_interface/space/spaces/__init__.py +3 -1
  36. unienv_interface/space/spaces/batched.py +90 -0
  37. unienv_interface/space/spaces/binary.py +0 -1
  38. unienv_interface/space/spaces/box.py +13 -24
  39. unienv_interface/space/spaces/text.py +1 -3
  40. unienv_interface/transformations/dict_transform.py +31 -5
  41. unienv_interface/utils/control_util.py +68 -0
  42. unienv_interface/utils/data_queue.py +184 -0
  43. unienv_interface/utils/stateclass.py +46 -0
  44. unienv_interface/utils/vec_util.py +15 -0
  45. unienv_interface/world/__init__.py +3 -1
  46. unienv_interface/world/combined_funcnode.py +336 -0
  47. unienv_interface/world/combined_node.py +232 -0
  48. unienv_interface/wrapper/backend_compat.py +2 -2
  49. unienv_interface/wrapper/frame_stack.py +19 -114
  50. unienv_interface/wrapper/video_record.py +11 -2
  51. unienv-0.0.1b1.dist-info/METADATA +0 -20
  52. unienv-0.0.1b1.dist-info/RECORD +0 -85
  53. unienv-0.0.1b1.dist-info/top_level.txt +0 -4
  54. unienv_data/samplers/slice_sampler.py +0 -266
  55. unienv_maniskill/__init__.py +0 -1
  56. unienv_maniskill/wrapper/maniskill_compat.py +0 -235
  57. unienv_mjxplayground/__init__.py +0 -1
  58. unienv_mjxplayground/wrapper/playground_compat.py +0 -256
  59. {unienv-0.0.1b1.dist-info → unienv-0.0.1b3.dist-info}/WHEEL +0 -0
@@ -0,0 +1,191 @@
1
+ from importlib import metadata
2
+ from typing import Generic, TypeVar, Generic, Optional, Any, Dict, Tuple, Sequence, Union, List, Iterable, Type
3
+
4
+ from unienv_interface.space import Space, BoxSpace
5
+ from unienv_interface.space.space_utils import batch_utils as sbu, flatten_utils as sfu
6
+ from unienv_interface.env_base.env import ContextType, ObsType, ActType
7
+ from unienv_interface.backends import ComputeBackend, BArrayType, BDeviceType, BDtypeType, BRNGType
8
+ from unienv_interface.backends.numpy import NumpyComputeBackend
9
+ from unienv_interface.utils.symbol_util import *
10
+ from unienv_interface.transformations import DataTransformation
11
+
12
+ from unienv_data.base import SpaceStorage, BatchT
13
+
14
+ import numpy as np
15
+ import os
16
+ import json
17
+ import pickle
18
+
19
class TransformedStorage(SpaceStorage[
    BatchT,
    BArrayType,
    BDeviceType,
    BDtypeType,
    BRNGType,
]):
    """
    Storage that keeps its data in a transformed representation.

    Every value written through :meth:`set` is passed through ``data_transformation``
    before being handed to ``inner_storage``. Every value read through :meth:`get` is
    passed through the inverse transformation obtained from
    ``data_transformation.direction_inverse``. The transformation must therefore
    report ``has_inverse``. A typical use is saving space, e.g. storing images
    through a compression transformation.

    :meth:`dumps` writes the following into a directory::

        transformed_metadata.json   storage / inner-storage class names and inner path
        data_transformation.pkl     the pickled forward transformation
        <inner_storage_path>        whatever the inner storage's ``dumps`` writes
    """

    # ========== Class Attributes ==========
    @classmethod
    def create(
        cls,
        single_instance_space: Space[Any, BDeviceType, BDtypeType, BRNGType],
        inner_storage_cls : Type[SpaceStorage[BArrayType, BArrayType, BDeviceType, BDtypeType, BRNGType]],
        *args,
        data_transformation : DataTransformation,
        capacity : Optional[int] = None,
        cache_path : Optional[str] = None,
        **kwargs
    ) -> "TransformedStorage[BArrayType, BDeviceType, BDtypeType, BRNGType]":
        """
        Create a new transformed storage backed by a freshly created inner storage.

        Args:
            single_instance_space: Space of a single, untransformed item.
            inner_storage_cls: Storage class that will hold the transformed data.
            *args: Forwarded positionally to ``inner_storage_cls.create``.
            data_transformation: Forward transformation; must have an inverse.
            capacity: Forwarded to ``inner_storage_cls.create``.
            cache_path: Optional directory. When given, it is created and the inner
                storage is created at ``<cache_path>/<inner_storage_path>``.
            **kwargs: Forwarded to ``inner_storage_cls.create``.

        Returns:
            A new instance of ``cls`` wrapping the created inner storage.
        """
        assert data_transformation.has_inverse, "To transform storages (potentially to save space), you need to use inversible data transformations"
        transformed_space = data_transformation.get_target_space_from_source(single_instance_space)
        # The inner storage lives under a fixed relative name. It keeps the inner
        # class's single-file extension so single-file storages stay recognisable.
        inner_storage_path = "transformed_inner_storage" + (inner_storage_cls.single_file_ext or "")

        if cache_path is not None:
            os.makedirs(cache_path, exist_ok=True)

        inner_storage = inner_storage_cls.create(
            transformed_space,
            *args,
            cache_path=None if cache_path is None else os.path.join(cache_path, inner_storage_path),
            capacity=capacity,
            **kwargs
        )
        return cls(
            single_instance_space,
            data_transformation,
            inner_storage,
            inner_storage_path,
        )

    @classmethod
    def load_from(
        cls,
        path : Union[str, os.PathLike],
        single_instance_space : Space[Any, BDeviceType, BDtypeType, BRNGType],
        *,
        capacity : Optional[int] = None,
        read_only : bool = True,
        **kwargs
    ) -> "TransformedStorage[BArrayType, BDeviceType, BDtypeType, BRNGType]":
        """
        Load a transformed storage previously written by :meth:`dumps`.

        Args:
            path: Directory that :meth:`dumps` wrote into.
            single_instance_space: Space of a single, untransformed item.
            capacity: Forwarded to the inner storage's ``load_from``.
            read_only: Forwarded to the inner storage's ``load_from``.
            **kwargs: Forwarded to the inner storage's ``load_from``.

        Raises:
            AssertionError: if the metadata file is missing, was written by a
                different storage class, or the pickle is not a DataTransformation.
        """
        # Named ``meta`` (not ``metadata``) so the module-level
        # ``importlib.metadata`` import is not shadowed.
        meta_path = os.path.join(path, "transformed_metadata.json")
        assert os.path.exists(meta_path), f"Metadata file {meta_path} does not exist"
        with open(meta_path, "r") as f:
            meta = json.load(f)
        assert meta["storage_type"] == cls.__name__, \
            f"Expected storage type {cls.__name__}, but found {meta['storage_type']}"

        # SECURITY: unpickling can execute arbitrary code. Only load storages
        # from trusted locations.
        data_transform_path = os.path.join(path, "data_transformation.pkl")
        with open(data_transform_path, "rb") as f:
            data_transform = pickle.load(f)

        assert isinstance(data_transform, DataTransformation)
        transformed_space = data_transform.get_target_space_from_source(single_instance_space)

        inner_storage_cls : Type[SpaceStorage] = get_class_from_full_name(meta["inner_storage_type"])
        inner_storage_path = meta["inner_storage_path"]
        inner_storage = inner_storage_cls.load_from(
            os.path.join(path, inner_storage_path),
            transformed_space,
            capacity=capacity,
            read_only=read_only,
            **kwargs
        )
        return cls(
            single_instance_space,
            data_transform,
            inner_storage,
            inner_storage_path,
        )

    # ========== Instance Implementations ==========
    # Not a single-file storage: dumps() writes several entries into a directory.
    single_file_ext = None

    def __init__(
        self,
        single_instance_space: Space[Any, BDeviceType, BDtypeType, BRNGType],
        data_transformation : DataTransformation,
        inner_storage : SpaceStorage[
            BArrayType,
            BArrayType,
            BDeviceType,
            BDtypeType,
            BRNGType,
        ],
        inner_storage_path : Union[str, os.PathLike],
    ):
        """
        Wrap an existing inner storage.

        Args:
            single_instance_space: Space of a single, untransformed item.
            data_transformation: Forward transformation applied on ``set``.
            inner_storage: Storage whose ``single_instance_space`` equals the
                transformation's target space.
            inner_storage_path: Path of the inner storage relative to the
                directory passed to :meth:`dumps`.
        """
        super().__init__(single_instance_space)
        transformed_space = data_transformation.get_target_space_from_source(single_instance_space)

        # The inner storage must hold data of the *transformed* space, so report
        # that space's backend/device in the error messages.
        assert inner_storage.backend == transformed_space.backend, \
            f"Inner storage backend {inner_storage.backend} does not match transformed space backend {transformed_space.backend}"
        assert inner_storage.device == transformed_space.device, \
            f"Inner storage device {inner_storage.device} does not match transformed space device {transformed_space.device}"
        assert inner_storage.single_instance_space == transformed_space, \
            "Inner storage single instance space does not match the transformed space"

        self._transformed_space = transformed_space
        self._batched_transformed_space = sbu.batch_space(transformed_space, 1)
        self._batched_instance_space = sbu.batch_space(single_instance_space, 1)
        self.inner_storage = inner_storage
        self.inner_storage_path = inner_storage_path
        self.data_transformation = data_transformation
        self.inv_data_transformation = data_transformation.direction_inverse(single_instance_space)

    @staticmethod
    def _is_single_index(index) -> bool:
        """Return True when ``index`` addresses one item rather than a batch."""
        # NumPy integer scalars (e.g. from np.arange) select one item, just like int.
        return isinstance(index, (int, np.integer))

    @property
    def capacity(self) -> Optional[int]:
        """Capacity of the inner storage."""
        return self.inner_storage.capacity

    def extend_length(self, length):
        """Delegate to ``inner_storage.extend_length``."""
        self.inner_storage.extend_length(length)

    def shrink_length(self, length):
        """Delegate to ``inner_storage.shrink_length``."""
        self.inner_storage.shrink_length(length)

    def __len__(self):
        return len(self.inner_storage)

    def get_flattened(self, index):
        """Read item(s) at ``index`` (inverse-transformed) and flatten them."""
        dat = self.get(index)
        if self._is_single_index(index):
            return sfu.flatten_data(self.single_instance_space, dat)
        # Batched read: keep the leading batch dimension.
        return sfu.flatten_data(self.single_instance_space, dat, start_dim=1)

    def get(self, index):
        """Read item(s) at ``index`` from the inner storage and apply the inverse transformation."""
        stored = self.inner_storage.get(index)
        space = self._transformed_space if self._is_single_index(index) else self._batched_transformed_space
        return self.inv_data_transformation.transform(space, stored)

    def set_flattened(self, index, value):
        """Unflatten ``value`` and store it at ``index`` via :meth:`set`."""
        if self._is_single_index(index):
            set_value = sfu.unflatten_data(self.single_instance_space, value)
        else:
            set_value = sfu.unflatten_data(self._batched_instance_space, value)
        self.set(index, set_value)

    def set(self, index, value):
        """Apply the forward transformation to ``value`` and store it at ``index``."""
        space = self.single_instance_space if self._is_single_index(index) else self._batched_instance_space
        transformed_value = self.data_transformation.transform(space, value)
        self.inner_storage.set(index, transformed_value)

    def clear(self):
        """Delegate to ``inner_storage.clear``."""
        self.inner_storage.clear()

    def dumps(self, path):
        """
        Persist this storage into directory ``path``, creating it if needed.

        Writes the inner storage, ``transformed_metadata.json`` and
        ``data_transformation.pkl``. The result can be read back with :meth:`load_from`.
        """
        # The metadata/pickle writes below need the directory to exist.
        os.makedirs(path, exist_ok=True)
        meta = {
            # type(self) (not __class__) so subclasses round-trip through load_from,
            # which checks against cls.__name__.
            "storage_type": type(self).__name__,
            "inner_storage_type": get_full_class_name(type(self.inner_storage)),
            "inner_storage_path": self.inner_storage_path,
            "transformation": get_full_class_name(type(self.data_transformation))
        }
        self.inner_storage.dumps(os.path.join(path, self.inner_storage_path))
        with open(os.path.join(path, "transformed_metadata.json"), "w") as f:
            json.dump(meta, f)
        with open(os.path.join(path, "data_transformation.pkl"), "wb") as f:
            pickle.dump(self.data_transformation, f)

    def close(self):
        """Delegate to ``inner_storage.close``."""
        self.inner_storage.close()
@@ -0,0 +1,213 @@
1
+ from unienv_interface.space.space_utils import batch_utils as sbu
2
+ from unienv_interface.transformations import DataTransformation
3
+ from unienv_interface.space import Space, BoxSpace, TextSpace
4
+ from typing import Union, Any, Optional
5
+ from unienv_interface.backends import ComputeBackend, BArrayType, BDeviceType, BDtypeType, BRNGType
6
+ from PIL import Image
7
+ import numpy as np
8
+ import io
9
+
10
class ImageCompressTransformation(DataTransformation):
    """
    Invertible transformation that encodes images into fixed-size byte buffers.

    Each image of shape ``(..., H, W, C)`` (C = 1 or 3) is encoded with Pillow in
    ``format``. The encoded bytes are stored zero-padded in a ``uint8`` vector of
    length ``max_size_bytes``, so the target space has shape
    ``(..., max_size_bytes)``. Quality starts at ``init_quality`` and is lowered in
    steps of 5 until the encoding fits.
    """
    has_inverse = True

    def __init__(
        self,
        init_quality : int = 75,
        max_size_bytes : int = 65536,
        mode : Optional[str] = None,
        format : str = "JPEG",
    ) -> None:
        """
        Initialize the image compression transformation.

        Args:
            init_quality: Initial encoder quality setting (1-100). It is lowered in
                steps of 5 while the encoding exceeds ``max_size_bytes``.
            max_size_bytes: Maximum allowed size of one compressed image in bytes;
                also the length of the last dimension of the target space.
            mode: Optional mode for PIL Image (e.g., "RGB", "L"). If None, inferred from input.
            format: Image format to use for compression (default "JPEG"). See https://pillow.readthedocs.io/en/stable/handbook/image-file-formats.html for options.
        """

        self.init_quality = init_quality
        self.max_size_bytes = max_size_bytes
        self.mode = mode
        self.format = format

    @staticmethod
    def validate_source_space(source_space: Space[Any, BDeviceType, BDtypeType, BRNGType]) -> None:
        """Assert that ``source_space`` is a BoxSpace of shape (..., H, W, 1 or 3)."""
        assert isinstance(source_space, BoxSpace), "ImageCompressTransformation only supports BoxSpace source spaces."
        assert len(source_space.shape) >= 3 and (
            source_space.shape[-1] == 3 or
            source_space.shape[-1] == 1
        ), "ImageCompressTransformation only supports BoxSpace source spaces with shape (..., H, W, 1 or 3)."

    @staticmethod
    def get_uint8_dtype(
        backend: ComputeBackend[BArrayType, BDeviceType, BDtypeType, BRNGType],
    ) -> BDtypeType:
        """Return the backend's ``uint8`` dtype object."""
        return backend.__array_namespace_info__().dtypes()['uint8']

    def get_target_space_from_source(self, source_space):
        """Replace the trailing (H, W, C) dimensions with a single ``max_size_bytes`` byte dimension."""
        self.validate_source_space(source_space)
        new_shape = source_space.shape[:-3] + (self.max_size_bytes,)

        # NOTE(review): the bounds are +/-inf although the dtype is uint8. They are
        # kept as-is because previously created storages were built against this
        # exact space; confirm before tightening them to [0, 255].
        return BoxSpace(
            source_space.backend,
            shape=new_shape,
            low=-source_space.backend.inf,
            high=source_space.backend.inf,
            dtype=self.get_uint8_dtype(source_space.backend),
            device=source_space.device,
        )

    def encode_to_size(self, img_array, max_bytes, min_quality=20, mode=None):
        """
        Encode one image as ``self.format`` bytes, lowering quality until it fits.

        Quality starts at ``self.init_quality``. It is decreased by 5 after every
        attempt whose output is larger than ``max_bytes``. The last attempt uses the
        lowest quality that is still ``>= min_quality``. If ``init_quality`` is
        already below ``min_quality``, exactly one attempt is made at ``init_quality``.

        Args:
            img_array: np.ndarray of shape (H, W, 3) or (H, W, 1); presumably uint8.
            max_bytes: maximum allowed size of the encoded image, in bytes.
            min_quality: lowest quality to try before giving up.
            mode: optional PIL mode forwarded to ``Image.fromarray``.

        Returns:
            ``(image_bytes, quality)``: the encoded bytes and the quality that
            produced them. When no attempt fits, the last (lowest-quality) attempt
            is returned and ``len(image_bytes) > max_bytes``, so callers must check.
        """
        # Pillow expects grayscale images as (H, W), not (H, W, 1).
        if img_array.ndim == 3 and img_array.shape[-1] == 1:
            img_array = np.squeeze(img_array, axis=-1)

        img = Image.fromarray(img_array, mode=mode)
        try:
            quality = self.init_quality
            while True:
                buf = io.BytesIO()
                img.save(buf, format=self.format, quality=quality)
                image_bytes = buf.getvalue()
                # Stop when it fits, or when the next step would go below min_quality.
                if len(image_bytes) <= max_bytes or quality - 5 < min_quality:
                    return image_bytes, quality
                quality -= 5
        finally:
            img.close()

    def transform(self, source_space, data):
        """
        Compress every image in ``data`` into a zero-padded ``uint8`` byte vector.

        Raises:
            ValueError: if an image cannot be encoded within ``max_size_bytes``
                even at the minimum quality.
        """
        self.validate_source_space(source_space)
        data_numpy = source_space.backend.to_numpy(data)
        batch_shape = data_numpy.shape[:-3]
        flat_images = data_numpy.reshape(-1, *data_numpy.shape[-3:])
        flat_compressed = np.zeros((flat_images.shape[0], self.max_size_bytes), dtype=np.uint8)
        for i, img_array in enumerate(flat_images):
            image_bytes, quality = self.encode_to_size(
                img_array,
                self.max_size_bytes,
                mode=self.mode
            )
            if len(image_bytes) > self.max_size_bytes:
                raise ValueError(
                    f"Could not compress image {i} into {self.max_size_bytes} bytes "
                    f"(got {len(image_bytes)} bytes at quality {quality}); "
                    f"increase max_size_bytes."
                )
            # Remaining bytes stay zero; the decoder stops at the image's end marker.
            flat_compressed[i, :len(image_bytes)] = np.frombuffer(image_bytes, dtype=np.uint8)
        compressed = flat_compressed.reshape(batch_shape + (self.max_size_bytes,))
        return source_space.backend.from_numpy(
            compressed,
            dtype=self.get_uint8_dtype(source_space.backend),
            device=source_space.device,
        )

    def direction_inverse(self, source_space = None):
        """Return the matching ImageDecompressTransformation for images of ``source_space``'s (H, W, C)."""
        assert source_space is not None, "Source space must be provided to get inverse transformation."
        self.validate_source_space(source_space)
        height = source_space.shape[-3]
        width = source_space.shape[-2]
        channels = source_space.shape[-1]
        return ImageDecompressTransformation(
            target_height=height,
            target_width=width,
            target_channels=channels,
            mode=self.mode,
            format=self.format,
        )
125
+
126
class ImageDecompressTransformation(DataTransformation):
    """
    Inverse of ImageCompressTransformation.

    Decodes zero-padded ``uint8`` byte vectors of shape ``(..., N)`` back into
    ``uint8`` images of shape
    ``(..., target_height, target_width, target_channels)``.
    """
    has_inverse = True

    def __init__(
        self,
        target_height : int,
        target_width : int,
        target_channels : int = 3,
        mode : Optional[str] = None,
        format : Optional[str] = None,
    ) -> None:
        """
        Initialize the image decompression transformation.

        Args:
            target_height: Height of the decompressed image.
            target_width: Width of the decompressed image.
            target_channels: Number of channels of the decompressed image (1 or 3).
            mode: Optional mode for PIL Image (e.g., "RGB", "L"). If None, inferred from input.
            format: Image format to use for decompression (default None, which will try everything). See https://pillow.readthedocs.io/en/stable/handbook/image-file-formats.html for options.
        """
        self.target_height = target_height
        self.target_width = target_width
        self.target_channels = target_channels
        self.mode = mode
        self.format = format

    @staticmethod
    def validate_source_space(source_space: Space[Any, BDeviceType, BDtypeType, BRNGType]) -> None:
        """Assert that ``source_space`` is a BoxSpace with at least one (byte) dimension."""
        assert isinstance(source_space, BoxSpace), "ImageDecompressTransformation only supports BoxSpace source spaces."
        assert len(source_space.shape) >= 1, "ImageDecompressTransformation requires source space with at least 1 dimension."

    @staticmethod
    def get_uint8_dtype(backend):
        """Return the backend's ``uint8`` dtype object."""
        return ImageCompressTransformation.get_uint8_dtype(backend)

    def get_target_space_from_source(self, source_space):
        """Replace the trailing byte dimension with (target_height, target_width, target_channels)."""
        self.validate_source_space(source_space)
        new_shape = source_space.shape[:-1] + (self.target_height, self.target_width, self.target_channels)
        return BoxSpace(
            source_space.backend,
            shape=new_shape,
            low=0,
            high=255,
            dtype=self.get_uint8_dtype(source_space.backend),
            device=source_space.device,
        )

    def decode_bytes(self, jpeg_bytes : bytes, mode=None):
        """
        Decode encoded image bytes into a numpy array.

        Args:
            jpeg_bytes: encoded image bytes; trailing zero padding is allowed.
            mode: optional PIL mode to convert the decoded image to.

        Returns:
            img_array: np.ndarray as produced by Pillow. This is (H, W, 3) for RGB
            and (H, W) for single-channel ("L") images.
        """
        formats = [self.format] if self.format is not None else None
        # Context managers release the decoder state even if conversion fails.
        with Image.open(io.BytesIO(jpeg_bytes), formats=formats) as img:
            if mode is not None:
                with img.convert(mode) as converted:
                    return np.array(converted)
            return np.array(img)

    def transform(self, source_space, data):
        """Decode every byte vector in ``data`` into a ``uint8`` image."""
        self.validate_source_space(source_space)
        data_numpy = source_space.backend.to_numpy(data)
        batch_shape = data_numpy.shape[:-1]
        image_shape = (self.target_height, self.target_width, self.target_channels)
        flat_bytes = data_numpy.reshape(-1, data_numpy.shape[-1])
        flat_images = np.zeros((flat_bytes.shape[0],) + image_shape, dtype=np.uint8)
        for i, byte_array in enumerate(flat_bytes):
            decoded = self.decode_bytes(
                byte_array.tobytes(),
                mode=self.mode
            )
            # Grayscale images decode to (H, W); restore the channel axis so they
            # fit the (H, W, C) output buffer.
            if decoded.ndim == 2:
                decoded = decoded[..., np.newaxis]
            flat_images[i] = decoded
        decompressed = flat_images.reshape(batch_shape + image_shape)
        return source_space.backend.from_numpy(
            decompressed,
            dtype=self.get_uint8_dtype(source_space.backend),
            device=source_space.device,
        )

    def direction_inverse(self, source_space = None):
        """Return an ImageCompressTransformation whose byte budget equals ``source_space``'s last dimension."""
        assert source_space is not None, "Source space must be provided to get inverse transformation."
        self.validate_source_space(source_space)
        return ImageCompressTransformation(
            init_quality=75,
            max_size_bytes=source_space.shape[-1],
            mode=self.mode,
            format=self.format if self.format is not None else "JPEG",
        )
@@ -1,4 +1,7 @@
1
- from xbarray.jax import JaxComputeBackend as XBJaxBackend
1
+ try:
2
+ from xbarray.backends.jax import JaxComputeBackend as XBJaxBackend
3
+ except ImportError:
4
+ from xbarray.jax import JaxComputeBackend as XBJaxBackend
2
5
  from xbarray import ComputeBackend
3
6
  from typing import Union
4
7
  import jax
@@ -1,4 +1,7 @@
1
- from xbarray.numpy import NumpyComputeBackend as XBNumpyBackend
1
+ try:
2
+ from xbarray.backends.numpy import NumpyComputeBackend as XBNumpyBackend
3
+ except ImportError:
4
+ from xbarray.numpy import NumpyComputeBackend as XBNumpyBackend
2
5
  from xbarray import ComputeBackend
3
6
 
4
7
  import numpy as np
@@ -1,4 +1,7 @@
1
- from xbarray.pytorch import PytorchComputeBackend as XBPytorchBackend
1
+ try:
2
+ from xbarray.backends.pytorch import PytorchComputeBackend as XBPytorchBackend
3
+ except ImportError:
4
+ from xbarray.pytorch import PytorchComputeBackend as XBPytorchBackend
2
5
  from xbarray import ComputeBackend
3
6
 
4
7
  from typing import Union
@@ -1,4 +1,5 @@
1
1
  from .env import Env
2
+ from .vec_env import SyncVecEnv, AsyncVecEnv
2
3
  from .wrapper import Wrapper, ActionWrapper, ContextObservationWrapper
3
4
  from .funcenv import FuncEnv, FuncEnvBasedEnv
4
5
  from .funcenv_wrapper import FuncEnvWrapper
@@ -90,6 +90,11 @@ class Env(abc.ABC, Generic[BArrayType, ContextType, ObsType, ActType, RenderFram
90
90
  def sample_observation(self) -> ObsType:
91
91
  return self.sample_space(self.observation_space)
92
92
 
93
def sample_context(self) -> Optional[ContextType]:
    """
    Draw a random context from ``self.context_space``.

    Returns ``None`` when the environment declares no context space; otherwise
    delegates to ``self.sample_space``.
    """
    if self.context_space is not None:
        return self.sample_space(self.context_space)
    return None
97
+
93
98
  def update_observation_post_reset(
94
99
  self,
95
100
  old_obs: ObsType,
@@ -2,7 +2,7 @@ from typing import Any, Callable, Generic, TypeVar, Tuple, Dict, Optional, Suppo
2
2
  import abc
3
3
  import numpy as np
4
4
  from unienv_interface.backends import ComputeBackend, BArrayType, BDeviceType, BDtypeType, BRNGType
5
- from unienv_interface.space import Space
5
+ from unienv_interface.space import Space, batch_utils as sbu
6
6
  from dataclasses import dataclass, replace as dataclass_replace
7
7
  from .env import Env, ContextType, ObsType, ActType, RenderFrame
8
8
 
@@ -121,6 +121,37 @@ class FuncEnv(
121
121
  """Close the render state."""
122
122
  raise NotImplementedError
123
123
 
124
+ # ========== Convenience methods ==========
125
def update_observation_post_reset(
    self,
    old_obs: ObsType,
    newobs_masked: ObsType,
    mask: BArrayType
) -> ObsType:
    """
    Merge freshly reset observations into the previous batch of observations.

    Only valid for batched environments (``batch_size`` is not None). The entries
    selected by ``mask`` are taken from ``newobs_masked``; the merge itself is
    delegated to ``batch_utils.set_at`` over ``self.observation_space``.
    """
    assert self.batch_size is not None, "This method is used by batched environment after reset"
    merged_obs = sbu.set_at(self.observation_space, old_obs, mask, newobs_masked)
    return merged_obs
+ )
138
+
139
def update_context_post_reset(
    self,
    old_context: ContextType,
    new_context: ContextType,
    mask: BArrayType
) -> ContextType:
    """
    Merge freshly reset contexts into the previous batch of contexts.

    Only valid for batched environments (``batch_size`` is not None). Returns
    ``None`` when there is no context space. Otherwise the entries selected by
    ``mask`` are taken from ``new_context`` via ``batch_utils.set_at`` over
    ``self.context_space``.
    """
    assert self.batch_size is not None, "This method is used by batched environment after reset"
    if self.context_space is not None:
        return sbu.set_at(self.context_space, old_context, mask, new_context)
    return None
154
+
124
155
  # ========== Wrapper methods ==========
125
156
  @property
126
157
  def unwrapped(self) -> "FuncEnv":
@@ -110,7 +110,7 @@ class FuncEnvWrapper(
110
110
  def reset(
111
111
  self,
112
112
  state : WrapperStateT,
113
- *,
113
+ *args,
114
114
  seed : Optional[int] = None,
115
115
  mask : Optional[WrapperBArrayT] = None,
116
116
  **kwargs
@@ -120,7 +120,7 @@ class FuncEnvWrapper(
120
120
  WrapperObsT,
121
121
  Dict[str, Any]
122
122
  ]:
123
- return self.func_env.reset(state, seed=seed, mask=mask, **kwargs)
123
+ return self.func_env.reset(state, *args, seed=seed, mask=mask, **kwargs)
124
124
 
125
125
  def step(
126
126
  self,