unienv 0.0.1b1__py3-none-any.whl → 0.0.1b3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- unienv-0.0.1b3.dist-info/METADATA +74 -0
- unienv-0.0.1b3.dist-info/RECORD +92 -0
- {unienv-0.0.1b1.dist-info → unienv-0.0.1b3.dist-info}/licenses/LICENSE +1 -3
- unienv-0.0.1b3.dist-info/top_level.txt +2 -0
- unienv_data/base/__init__.py +0 -1
- unienv_data/base/common.py +95 -45
- unienv_data/base/storage.py +1 -0
- unienv_data/batches/__init__.py +2 -1
- unienv_data/batches/backend_compat.py +47 -1
- unienv_data/batches/combined_batch.py +2 -4
- unienv_data/{base → batches}/transformations.py +3 -2
- unienv_data/replay_buffer/replay_buffer.py +4 -0
- unienv_data/samplers/__init__.py +0 -1
- unienv_data/samplers/multiprocessing_sampler.py +26 -22
- unienv_data/samplers/step_sampler.py +9 -18
- unienv_data/storages/common.py +5 -0
- unienv_data/storages/hdf5.py +291 -20
- unienv_data/storages/pytorch.py +1 -0
- unienv_data/storages/transformation.py +191 -0
- unienv_data/transformations/image_compress.py +213 -0
- unienv_interface/backends/jax.py +4 -1
- unienv_interface/backends/numpy.py +4 -1
- unienv_interface/backends/pytorch.py +4 -1
- unienv_interface/env_base/__init__.py +1 -0
- unienv_interface/env_base/env.py +5 -0
- unienv_interface/env_base/funcenv.py +32 -1
- unienv_interface/env_base/funcenv_wrapper.py +2 -2
- unienv_interface/env_base/vec_env.py +474 -0
- unienv_interface/func_wrapper/__init__.py +2 -1
- unienv_interface/func_wrapper/frame_stack.py +150 -0
- unienv_interface/space/space_utils/__init__.py +1 -0
- unienv_interface/space/space_utils/batch_utils.py +83 -0
- unienv_interface/space/space_utils/construct_utils.py +216 -0
- unienv_interface/space/space_utils/serialization_utils.py +16 -1
- unienv_interface/space/spaces/__init__.py +3 -1
- unienv_interface/space/spaces/batched.py +90 -0
- unienv_interface/space/spaces/binary.py +0 -1
- unienv_interface/space/spaces/box.py +13 -24
- unienv_interface/space/spaces/text.py +1 -3
- unienv_interface/transformations/dict_transform.py +31 -5
- unienv_interface/utils/control_util.py +68 -0
- unienv_interface/utils/data_queue.py +184 -0
- unienv_interface/utils/stateclass.py +46 -0
- unienv_interface/utils/vec_util.py +15 -0
- unienv_interface/world/__init__.py +3 -1
- unienv_interface/world/combined_funcnode.py +336 -0
- unienv_interface/world/combined_node.py +232 -0
- unienv_interface/wrapper/backend_compat.py +2 -2
- unienv_interface/wrapper/frame_stack.py +19 -114
- unienv_interface/wrapper/video_record.py +11 -2
- unienv-0.0.1b1.dist-info/METADATA +0 -20
- unienv-0.0.1b1.dist-info/RECORD +0 -85
- unienv-0.0.1b1.dist-info/top_level.txt +0 -4
- unienv_data/samplers/slice_sampler.py +0 -266
- unienv_maniskill/__init__.py +0 -1
- unienv_maniskill/wrapper/maniskill_compat.py +0 -235
- unienv_mjxplayground/__init__.py +0 -1
- unienv_mjxplayground/wrapper/playground_compat.py +0 -256
- {unienv-0.0.1b1.dist-info → unienv-0.0.1b3.dist-info}/WHEEL +0 -0
|
@@ -0,0 +1,232 @@
|
|
|
1
|
+
from typing import Optional, Dict, Mapping, Any, Tuple, Union, Iterable
|
|
2
|
+
from unienv_interface.backends import ComputeBackend, BArrayType, BDeviceType, BDtypeType, BRNGType
|
|
3
|
+
from unienv_interface.space import Space, DictSpace
|
|
4
|
+
from unienv_interface.utils.control_util import find_best_timestep
|
|
5
|
+
|
|
6
|
+
from .world import World
|
|
7
|
+
from .node import WorldNode, ContextType, ObsType, ActType
|
|
8
|
+
|
|
9
|
+
CombinedDataT = Union[Dict[str, Any], BArrayType]
|
|
10
|
+
|
|
11
|
+
class CombinedWorldNode(WorldNode[
    Optional[CombinedDataT], CombinedDataT, CombinedDataT,
    BArrayType, BDeviceType, BDtypeType, BRNGType
]):
    """
    A WorldNode that combines multiple WorldNodes into one node, using a dictionary to store the data from each node.
    The observation, reward, termination, truncation, and info are combined from all child nodes.
    The keys in the dictionary are the names of the child nodes.
    If there is only one child node that supports value and `direct_return` is set to True, the value is returned directly instead of a dictionary.

    Combination rules implemented below:
      * context / observation / action: per-node dictionary (or the single value, see above)
      * reward: sum over the child nodes that report a reward
      * termination / truncation: logical OR over the child nodes that report the signal
      * info: always a dictionary keyed by node name (never returned directly)
    """

    def __init__(
        self,
        name : str,
        nodes : Iterable[WorldNode[Any, Any, Any, BArrayType, BDeviceType, BDtypeType, BRNGType]],
        direct_return : bool = True,
    ):
        """
        Args:
            name: Name of the combined node.
            nodes: Child nodes to combine. They must share the same world object,
                report the same control timestep and have unique names.
            direct_return: If True, a space / value contributed by exactly one child
                is exposed directly instead of being wrapped in a dictionary.

        Raises:
            ValueError: If `nodes` is empty or two child nodes share a name.
            AssertionError: If the child nodes do not share the same world or control timestep.
        """
        # Materialize first: `nodes` may be a one-shot iterable and is traversed several times below.
        nodes = list(nodes)
        if len(nodes) == 0:
            raise ValueError("At least one node is required to create a CombinedWorldNode.")

        # Check that all nodes have the same world
        first_node = nodes[0]
        for node in nodes[1:]:
            assert node.world is first_node.world, "All nodes must belong to the same world."
            assert node.control_timestep == first_node.control_timestep, "All nodes must have the same control timestep."
        # Check that all nodes have unique names
        names = [node.name for node in nodes]
        if len(names) != len(set(names)):
            raise ValueError("All nodes must have unique names.")
        self.nodes = nodes

        # Aggregate Spaces
        _, self.context_space = self.aggregate_spaces(
            {node.name: node.context_space for node in nodes if node.context_space is not None},
            direct_return=direct_return,
        )
        _, self.observation_space = self.aggregate_spaces(
            {node.name: node.observation_space for node in nodes if node.observation_space is not None},
            direct_return=direct_return,
        )
        # `_action_node_name_direct` is the name of the only child that takes an action when the
        # action space is exposed directly, and None otherwise (see `set_next_action`).
        self._action_node_name_direct, self.action_space = self.aggregate_spaces(
            {node.name: node.action_space for node in nodes if node.action_space is not None},
            direct_return=direct_return,
        )
        self.has_reward = any(node.has_reward for node in nodes)
        self.has_termination_signal = any(node.has_termination_signal for node in nodes)
        self.has_truncation_signal = any(node.has_truncation_signal for node in nodes)

        # Save attributes
        self.name = name
        self.direct_return = direct_return

    @staticmethod
    def aggregate_spaces(
        spaces : Dict[str, Optional[Space[Any, BDeviceType, BDtypeType, BRNGType]]],
        direct_return : bool = True,
    ) -> Tuple[
        Optional[str],
        Optional[Space[Any, BDeviceType, BDtypeType, BRNGType]]
    ]:
        """
        Combine the per-node spaces into a single space.

        Entries whose space is None are ignored.

        Args:
            spaces: Mapping from node name to that node's space (or None).
            direct_return: If True and exactly one space is present, return it unwrapped.

        Returns:
            A `(name, space)` tuple:
              * `(None, None)` if no space is present,
              * `(node_name, node_space)` if exactly one space is present and `direct_return` is True,
              * `(None, DictSpace)` keyed by node name otherwise.
        """
        # Drop missing spaces up-front so that neither the count nor the backend lookup
        # below can be thrown off by a None entry.
        present = {name: space for name, space in spaces.items() if space is not None}
        if len(present) == 0:
            return None, None
        if len(present) == 1 and direct_return:
            return next(iter(present.items()))
        # The backend of the first space is used for the whole dictionary space.
        backend = next(iter(present.values())).backend
        return None, DictSpace(
            backend,
            present
        )

    @staticmethod
    def aggregate_data(
        data : Dict[str, Any],
        direct_return : bool = True,
    ) -> Optional[Union[Dict[str, Any], Any]]:
        """
        Combine per-node values following the same rule as `aggregate_spaces`.

        Returns:
            None if `data` is empty, the single value if `data` has exactly one entry
            and `direct_return` is True, and `data` itself otherwise.
        """
        if len(data) == 0:
            return None
        elif len(data) == 1 and direct_return:
            return next(iter(data.values()))
        else:
            return data

    @property
    def world(self) -> World[BArrayType, BDeviceType, BDtypeType, BRNGType]:
        """The world of the first child node (`__init__` asserts all children share it)."""
        return self.nodes[0].world

    @property
    def control_timestep(self) -> Optional[float]:
        """The control timestep of the first child node (`__init__` asserts all children agree)."""
        return self.nodes[0].control_timestep

    def pre_environment_step(self, dt):
        """Forward `pre_environment_step(dt)` to every child node, in order."""
        for node in self.nodes:
            node.pre_environment_step(dt)

    def get_observation(self):
        """Collect the observation of every child node that has an observation space."""
        assert self.observation_space is not None, "Observation space is None, cannot get observation."
        return self.aggregate_data(
            {
                node.name: node.get_observation()
                for node in self.nodes
                if node.observation_space is not None
            },
            direct_return=self.direct_return,
        )

    def get_reward(self):
        """
        Sum the rewards of all child nodes that provide one.

        When the world has a batch size, the sum is accumulated into a `(batch_size,)` array.
        NOTE(review): `self.backend` / `self.device` are not defined in this class;
        presumably they are provided by `WorldNode` - confirm.
        """
        assert self.has_reward, "This node does not provide a reward."
        if self.world.batch_size is None:
            return sum(
                node.get_reward()
                for node in self.nodes
                if node.has_reward
            )
        else:
            rewards = self.backend.zeros((self.world.batch_size,), dtype=self.backend.default_floating_dtype, device=self.device)
            for node in self.nodes:
                if node.has_reward:
                    rewards = rewards + node.get_reward()
            return rewards

    def get_termination(self):
        """Logical OR of the termination signals of all child nodes that provide one."""
        assert self.has_termination_signal, "This node does not provide a termination signal."
        if self.world.batch_size is None:
            return any(
                node.get_termination()
                for node in self.nodes
                if node.has_termination_signal
            )
        else:
            terminations = self.backend.zeros((self.world.batch_size,), dtype=self.backend.default_bool_dtype, device=self.device)
            for node in self.nodes:
                if node.has_termination_signal:
                    terminations = self.backend.logical_or(terminations, node.get_termination())
            return terminations

    def get_truncation(self):
        """Logical OR of the truncation signals of all child nodes that provide one."""
        assert self.has_truncation_signal, "This node does not provide a truncation signal."
        if self.world.batch_size is None:
            return any(
                node.get_truncation()
                for node in self.nodes
                if node.has_truncation_signal
            )
        else:
            truncations = self.backend.zeros((self.world.batch_size,), dtype=self.backend.default_bool_dtype, device=self.device)
            for node in self.nodes:
                if node.has_truncation_signal:
                    truncations = self.backend.logical_or(truncations, node.get_truncation())
            return truncations

    def get_info(self) -> Optional[Dict[str, Any]]:
        """
        Collect the non-None infos of all child nodes, keyed by node name.

        Returns None when no child reports an info. Infos are never returned directly,
        regardless of `direct_return`.
        """
        infos = {}
        for node in self.nodes:
            info = node.get_info()
            if info is not None:
                infos[node.name] = info

        return self.aggregate_data(
            infos,
            direct_return=False
        )

    def set_next_action(self, action):
        """
        Dispatch `action` to the child nodes.

        If the action space is exposed directly, `action` is handed as-is to the single
        action-taking child. Otherwise `action` must be a mapping with one entry per
        child node that has an action space.
        """
        assert self.action_space is not None, "Action space is None, cannot set action."
        if self._action_node_name_direct is not None:
            for node in self.nodes:
                if node.name == self._action_node_name_direct:
                    node.set_next_action(action)
                    break
        else:
            assert isinstance(action, Mapping), "Action must be a mapping when there are multiple action spaces."
            for node in self.nodes:
                if node.action_space is not None:
                    assert node.name in action, f"Action for node {node.name} is missing."
                    node.set_next_action(action[node.name])

    def post_environment_step(self, dt):
        """Forward `post_environment_step(dt)` to every child node, in order."""
        for node in self.nodes:
            node.post_environment_step(dt)

    def reset(self, *, seed = None, mask = None, pernode_kwargs : Optional[Dict[str, Any]] = None):
        """
        Reset every child node.

        Args:
            seed: Passed unchanged to every child node.
            mask: Passed unchanged to every child node.
            pernode_kwargs: Optional mapping from node name to extra keyword arguments
                for that node's `reset`. Nodes without an entry get no extra arguments.
        """
        # Use None as the default instead of a shared mutable `{}`.
        if pernode_kwargs is None:
            pernode_kwargs = {}
        for node in self.nodes:
            node.reset(
                seed=seed,
                mask=mask,
                **pernode_kwargs.get(node.name, {})
            )

    def after_reset(self, *, mask = None):
        """
        Call `after_reset` on every child node and combine the results.

        Returns:
            A `(context, observation, info)` tuple. Context and observation follow
            `direct_return`; info is always a dictionary keyed by node name (or None).
            Only non-None per-node values are included.
        """
        contexts = {}
        observations = {}
        infos = {}

        for node in self.nodes:
            context, observation, info = node.after_reset(mask=mask)
            if context is not None:
                contexts[node.name] = context
            if observation is not None:
                observations[node.name] = observation
            if info is not None:
                infos[node.name] = info

        return self.aggregate_data(
            contexts,
            direct_return=self.direct_return,
        ), self.aggregate_data(
            observations,
            direct_return=self.direct_return,
        ), self.aggregate_data(
            infos,
            direct_return=False
        )

    def close(self):
        """Close every child node, in order."""
        for node in self.nodes:
            node.close()
|
|
@@ -60,7 +60,7 @@ class ToBackendOrDeviceWrapper(
|
|
|
60
60
|
|
|
61
61
|
# Set new rng compatible with the new backend and device
|
|
62
62
|
env.rng, seed = seed_util.next_seed_rng(env.rng, env.backend)
|
|
63
|
-
self._rng = backend.random.random_number_generator(
|
|
63
|
+
self._rng = (backend or env.backend).random.random_number_generator(
|
|
64
64
|
seed=seed,
|
|
65
65
|
device=device
|
|
66
66
|
)
|
|
@@ -80,7 +80,7 @@ class ToBackendOrDeviceWrapper(
|
|
|
80
80
|
|
|
81
81
|
@property
|
|
82
82
|
def backend(self) -> ComputeBackend[Any, WrapperBDeviceT, Any, WrapperBRngT]:
|
|
83
|
-
return self._backend
|
|
83
|
+
return self._backend or self.env.backend
|
|
84
84
|
|
|
85
85
|
@property
|
|
86
86
|
def device(self) -> Optional[WrapperBDeviceT]:
|
|
@@ -1,103 +1,17 @@
|
|
|
1
|
-
from typing import Dict
|
|
1
|
+
from typing import Dict, Any, Optional, Tuple, Union, SupportsFloat
|
|
2
2
|
import numpy as np
|
|
3
3
|
import copy
|
|
4
4
|
|
|
5
5
|
from unienv_interface.backends import ComputeBackend, BArrayType, BDeviceType, BDtypeType, BRNGType
|
|
6
|
-
|
|
7
|
-
from unienv_interface.space.space_utils import batch_utils as sbu, flatten_utils as sfu
|
|
8
|
-
from unienv_interface.utils import seed_util
|
|
6
|
+
from unienv_interface.space.space_utils import batch_utils as sbu
|
|
9
7
|
from unienv_interface.env_base.env import Env, ContextType, ObsType, ActType, RenderFrame, BArrayType, BDeviceType, BDtypeType, BRNGType
|
|
10
8
|
from unienv_interface.env_base.wrapper import ContextObservationWrapper, ActionWrapper, WrapperContextT, WrapperObsT, WrapperActT
|
|
11
9
|
from unienv_interface.space import Space, DictSpace
|
|
12
|
-
from
|
|
13
|
-
|
|
14
|
-
DataT = TypeVar('DataT')
|
|
15
|
-
class SpaceDataQueue(
    Generic[DataT, BArrayType, BDeviceType, BDtypeType, BRNGType]
):
    """
    Fixed-length rolling history of samples drawn from a space.

    The history is stored with the history axis first, i.e. `(L, ...)` for an
    unbatched space or `(L, B, ...)` for a batched one, and is exposed through
    `get_output_data` with the batch axis first (`(B, L, ...)`) when batched.
    """

    def __init__(
        self,
        space : Space[DataT, BDeviceType, BDtypeType, BRNGType],
        batch_size : Optional[int],
        maxlen: int,
    ) -> None:
        """
        Args:
            space: Space of a single entry pushed into the queue.
            batch_size: Batch size of `space`, or None if the space is not batched.
            maxlen: Number of entries kept in the history.

        Raises:
            AssertionError: If `maxlen` or `batch_size` is not positive, or if
                `batch_size` does not match the batch size of `space`.
        """
        assert maxlen > 0, "Max length must be greater than 0"
        assert batch_size is None or batch_size > 0, "Batch size must be greater than 0 if provided"
        assert batch_size is None or sbu.batch_size(space) == batch_size, "Batch size must match the space's batch size if provided"
        self.space = space
        self.single_space = space
        self.stacked_space = sbu.batch_space(space, maxlen) # (H, ...) or (L, B, ...)
        self.output_space = sbu.swap_batch_dims(
            self.stacked_space, 0, 1
        ) if batch_size is not None else self.stacked_space # (B, L, ...) or (H, ...)
        self.data = self.stacked_space.create_empty()
        self._maxlen = maxlen
        self._batch_size = batch_size

    @property
    def maxlen(self) -> int:
        """Number of entries kept in the history."""
        return self._maxlen

    @property
    def batch_size(self) -> Optional[int]:
        """Batch size given at construction, or None when unbatched."""
        return self._batch_size

    @property
    def backend(self) -> ComputeBackend:
        """Compute backend of the underlying space."""
        return self.space.backend

    @property
    def device(self) -> Optional[BDeviceType]:
        """Device of the underlying space."""
        return self.space.device

    def reset(
        self,
        initial_data : DataT,
        mask : Optional[BArrayType] = None,
    ) -> None:
        """
        Fill the history with `initial_data`.

        Args:
            initial_data: Entry written to every history slot.
            mask: Optional selector over the batch axis; only the selected batch
                entries are overwritten. Only valid for a batched queue.

        Raises:
            AssertionError: If `mask` is given but the queue has no batch size.
        """
        # A mask indexes the batch axis (axis 1 of the stored data), so it is only
        # meaningful when the queue is batched. The original condition was inverted
        # and rejected a mask exactly in the batched case.
        assert self.batch_size is not None or mask is None, \
            "Mask should not be provided if batch size is empty"
        index = (
            slice(None), mask
        ) if mask is not None else slice(None)

        expanded_data = sbu.get_at( # Add a singleton horizon dimension to the data
            self.space,
            initial_data,
            None
        )
        self.data = sbu.set_at(
            self.stacked_space,
            self.data,
            index,
            expanded_data
        )

    def append(self, data : DataT) -> None:
        """Shift the history by one slot along axis 0 and write `data` into the last slot."""
        self.data = self.backend.map_fn_over_arrays(
            self.data,
            lambda x: self.backend.roll(x, shift=-1, axis=0),
        )
        # After the roll the former first entry sits in the last slot; overwrite it.
        self.data = sbu.set_at(
            self.stacked_space,
            self.data,
            -1,
            data
        )

    def get_output_data(self) -> DataT:
        """Return the stored history, with the batch axis moved first when batched."""
        if self.batch_size is None:
            return self.data
        else:
            return sbu.swap_batch_dims_in_data(
                self.backend,
                self.data,
                0, 1
            ) # (L, B, ...) -> (B, L, ...)
|
|
10
|
+
from unienv_interface.utils.data_queue import SpaceDataQueue
|
|
97
11
|
|
|
98
12
|
class FrameStackWrapper(
|
|
99
13
|
ContextObservationWrapper[
|
|
100
|
-
ContextType, Union[
|
|
14
|
+
ContextType, Union[Dict[str, Any], Any],
|
|
101
15
|
BArrayType, ContextType, ObsType, ActType, RenderFrame, BDeviceType, BDtypeType, BRNGType
|
|
102
16
|
]
|
|
103
17
|
):
|
|
@@ -136,32 +50,24 @@ class FrameStackWrapper(
|
|
|
136
50
|
env.batch_size,
|
|
137
51
|
obs_stack_size + 1
|
|
138
52
|
)
|
|
139
|
-
|
|
140
53
|
if action_stack_size > 0:
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
self.observation_space =
|
|
144
|
-
env.backend,
|
|
145
|
-
new_obs_spaces,
|
|
146
|
-
device=env.observation_space.device
|
|
147
|
-
)
|
|
54
|
+
new_obs_space = copy.copy(self.obs_deque.output_space)
|
|
55
|
+
new_obs_space['past_actions'] = self.action_deque.output_space
|
|
56
|
+
self.observation_space = new_obs_space
|
|
148
57
|
else:
|
|
149
58
|
self.observation_space = self.obs_deque.output_space
|
|
150
59
|
else:
|
|
151
60
|
if action_stack_size > 0:
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
self.observation_space =
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
device=env.observation_space.device
|
|
158
|
-
)
|
|
159
|
-
self.obs_deque = None
|
|
61
|
+
new_obs_space = copy.copy(env.observation_space)
|
|
62
|
+
new_obs_space['past_actions'] = self.action_deque.output_space
|
|
63
|
+
self.observation_space = new_obs_space
|
|
64
|
+
else:
|
|
65
|
+
raise ValueError("At least one of observation stack size or action stack size must be greater than 0")
|
|
160
66
|
|
|
161
67
|
def reverse_map_context(self, context: ContextType) -> ContextType:
|
|
162
68
|
return context
|
|
163
69
|
|
|
164
|
-
def map_observation(self, observation: ObsType) -> Union[
|
|
70
|
+
def map_observation(self, observation: ObsType) -> Union[Dict[str, Any], Any]:
|
|
165
71
|
if self.obs_deque is not None:
|
|
166
72
|
observation = self.obs_deque.get_output_data()
|
|
167
73
|
|
|
@@ -170,7 +76,7 @@ class FrameStackWrapper(
|
|
|
170
76
|
observation['past_actions'] = stacked_action
|
|
171
77
|
return observation
|
|
172
78
|
|
|
173
|
-
def reverse_map_observation(self, observation: Union[
|
|
79
|
+
def reverse_map_observation(self, observation: Union[Dict[str, Any], Any]) -> ObsType:
|
|
174
80
|
if isinstance(observation, dict):
|
|
175
81
|
stacked_obs = observation.copy()
|
|
176
82
|
stacked_obs.pop('past_actions', None)
|
|
@@ -197,7 +103,7 @@ class FrameStackWrapper(
|
|
|
197
103
|
mask: Optional[BArrayType] = None,
|
|
198
104
|
seed: Optional[int] = None,
|
|
199
105
|
**kwargs
|
|
200
|
-
) -> Tuple[ContextType, Union[
|
|
106
|
+
) -> Tuple[ContextType, Union[Dict[str, Any], Any], Dict[str, Any]]:
|
|
201
107
|
# TODO: If a mask is provided, we should only reset the stack for the masked indices
|
|
202
108
|
context, obs, info = self.env.reset(
|
|
203
109
|
*args,
|
|
@@ -227,16 +133,15 @@ class FrameStackWrapper(
|
|
|
227
133
|
self,
|
|
228
134
|
action: ActType
|
|
229
135
|
) -> Tuple[
|
|
230
|
-
Union[
|
|
136
|
+
Union[Dict[str, Any], Any],
|
|
231
137
|
Union[SupportsFloat, BArrayType],
|
|
232
138
|
Union[bool, BArrayType],
|
|
233
139
|
Union[bool, BArrayType],
|
|
234
|
-
|
|
140
|
+
Dict[str, Any]
|
|
235
141
|
]:
|
|
236
142
|
obs, rew, terminated, truncated, info = self.env.step(action)
|
|
237
143
|
if self.action_deque is not None:
|
|
238
|
-
self.action_deque.
|
|
144
|
+
self.action_deque.add(action)
|
|
239
145
|
if self.obs_deque is not None:
|
|
240
|
-
self.obs_deque.
|
|
241
|
-
|
|
146
|
+
self.obs_deque.add(obs)
|
|
242
147
|
return self.map_observation(obs), rew, terminated, truncated, info
|
|
@@ -164,7 +164,11 @@ class EpisodeVideoWrapper(
|
|
|
164
164
|
|
|
165
165
|
frames = []
|
|
166
166
|
for frame in self.episodic_frames:
|
|
167
|
-
|
|
167
|
+
if self.env.backend.is_backendarray(frame):
|
|
168
|
+
frame_np = self.env.backend.to_numpy(frame)
|
|
169
|
+
else:
|
|
170
|
+
assert isinstance(frame, np.ndarray)
|
|
171
|
+
frame_np = frame
|
|
168
172
|
assert frame_np.shape[2] == 3
|
|
169
173
|
frames.append(frame_np)
|
|
170
174
|
clip = ImageSequenceClip(frames, fps=self.env.render_fps or 30)
|
|
@@ -213,7 +217,12 @@ class EpisodeWandbVideoWrapper(
|
|
|
213
217
|
*self.episodic_frames[0].shape
|
|
214
218
|
))
|
|
215
219
|
for i, frame in enumerate(self.episodic_frames):
|
|
216
|
-
|
|
220
|
+
if self.env.backend.is_backendarray(frame):
|
|
221
|
+
frame_np = self.env.backend.to_numpy(frame)
|
|
222
|
+
else:
|
|
223
|
+
assert isinstance(frame, np.ndarray)
|
|
224
|
+
frame_np = frame
|
|
225
|
+
assert frame_np.shape[2] == 3
|
|
217
226
|
frames[i] = frame_np
|
|
218
227
|
clip = self.wandb.Video(
|
|
219
228
|
frames.transpose(0, 3, 1, 2),
|
|
@@ -1,20 +0,0 @@
|
|
|
1
|
-
Metadata-Version: 2.4
|
|
2
|
-
Name: unienv
|
|
3
|
-
Version: 0.0.1b1
|
|
4
|
-
Requires-Python: >=3.10
|
|
5
|
-
License-File: LICENSE
|
|
6
|
-
Requires-Dist: numpy
|
|
7
|
-
Requires-Dist: xbarray>=0.0.1a8
|
|
8
|
-
Requires-Dist: pillow
|
|
9
|
-
Requires-Dist: h5py
|
|
10
|
-
Provides-Extra: dev
|
|
11
|
-
Requires-Dist: pytest; extra == "dev"
|
|
12
|
-
Provides-Extra: gymnasium
|
|
13
|
-
Requires-Dist: gymnasium>=0.29.0; extra == "gymnasium"
|
|
14
|
-
Provides-Extra: video
|
|
15
|
-
Requires-Dist: moviepy>=2.1; extra == "video"
|
|
16
|
-
Provides-Extra: mjx
|
|
17
|
-
Requires-Dist: playground; extra == "mjx"
|
|
18
|
-
Provides-Extra: maniskill
|
|
19
|
-
Requires-Dist: mani_skill>=3.0.0b12; extra == "maniskill"
|
|
20
|
-
Dynamic: license-file
|
unienv-0.0.1b1.dist-info/RECORD
DELETED
|
@@ -1,85 +0,0 @@
|
|
|
1
|
-
unienv-0.0.1b1.dist-info/licenses/LICENSE,sha256=VHeh-ceoc7FrtFmwszuGP2ZgOQGJNXIIWggxpk3B54E,1130
|
|
2
|
-
unienv_data/__init__.py,sha256=zFxbe7aM5JvYXIK0FGnOPwWQJMN-8l_l8prB85CkcA8,95
|
|
3
|
-
unienv_data/base/__init__.py,sha256=4RP84bYe8lpDFF3Ow1Ch4P91HGgE0Y_nmfEj1Ah5nxE,232
|
|
4
|
-
unienv_data/base/common.py,sha256=SBYd2i-ZZq7ws_vphn8yx-q07jAbmMqUDvq4rdVlADE,10179
|
|
5
|
-
unienv_data/base/storage.py,sha256=W_ZaNPnoXoarR83Ap56vV3_7BarZlhXVo2L_RDeARD8,5102
|
|
6
|
-
unienv_data/base/transformations.py,sha256=dlZ7oAABz7Ihm9T6rJnDYw_5JQgBp2UsVRgy5-W-AdA,5596
|
|
7
|
-
unienv_data/batches/__init__.py,sha256=431RMsGjpAIpvwaHBzCFaG5755is0W89KlPTVXESMGw,188
|
|
8
|
-
unienv_data/batches/backend_compat.py,sha256=8L3jeWTWomAI87C7VQDwkbSDSncMS8-UgqGHTe5rmkw,6408
|
|
9
|
-
unienv_data/batches/combined_batch.py,sha256=XyQeaJgKC6uEScjQUD1DrkowTW0R_DTIcH_I6kUnRYA,15395
|
|
10
|
-
unienv_data/batches/framestack_batch.py,sha256=pdURqZeksOlbf21Nhx8kkm0gtFt6rjt2OiNWgZPdFCM,2312
|
|
11
|
-
unienv_data/batches/slicestack_batch.py,sha256=J2EhARcPA-zz6EBnV7OLzm4yyvnZ06vrdUoPD5RkJ-o,16672
|
|
12
|
-
unienv_data/integrations/pytorch.py,sha256=pW5rXBXagfzwJjM_VGgg8CPXEs3e2fKgg4nY7M3dpOc,2350
|
|
13
|
-
unienv_data/replay_buffer/__init__.py,sha256=uVebYruIYlj8OjTYVi8UYI4gWp3S3XIdgFlHbwO260o,100
|
|
14
|
-
unienv_data/replay_buffer/replay_buffer.py,sha256=GVfTyjkFAQ5I6hHtsPk2iftaIOwY9Aj0SN_ir0Kyb4M,10192
|
|
15
|
-
unienv_data/replay_buffer/trajectory_replay_buffer.py,sha256=fxV6FIqAHhN8opYs2WjAJMPqNRWD3iIku-4WlaydyG4,20737
|
|
16
|
-
unienv_data/samplers/__init__.py,sha256=pkoGnWKAR7dqKdDtQn-5y8pY8fqFPt7wFMnGapmFXP0,137
|
|
17
|
-
unienv_data/samplers/multiprocessing_sampler.py,sha256=8GLnGiqEYozTif1Hg_rAjMMseWDkaq5D5IIsTTY67gw,13031
|
|
18
|
-
unienv_data/samplers/slice_sampler.py,sha256=mbVKmlT0uij1p0emaPJI0T1hc5LS5B5D2chEO7yehIM,11521
|
|
19
|
-
unienv_data/samplers/step_sampler.py,sha256=l6gm6F089iaytAd81mXZPgtCOQ1ckUwPuaVOgCIHLxs,3197
|
|
20
|
-
unienv_data/storages/common.py,sha256=-Kb1SY-b4D2GvmLKOwRYjtrXiUROIVHLhs4Xh86S7UQ,6103
|
|
21
|
-
unienv_data/storages/hdf5.py,sha256=3zxt-JWJOBVcC-IGk9LvIQBaJJO9f-2UriVNhMrtvk4,14494
|
|
22
|
-
unienv_data/storages/pytorch.py,sha256=1kAU5ZPMkGLNry4aubnSZQjEP0H8dO8xrH2Yg8shpGw,6213
|
|
23
|
-
unienv_interface/__init__.py,sha256=pAWqfm4l7NAssuyXCugIjekSIh05aBbOjNhwsNXcJbE,100
|
|
24
|
-
unienv_interface/backends/__init__.py,sha256=L7CFwCChHVL-2Dpz34pTGC37WgodfJEeDQwXscyM7FM,198
|
|
25
|
-
unienv_interface/backends/base.py,sha256=1_hji1qwNAhcEtFQdAuzaNey9g5bWYj38t1sQxjnggc,132
|
|
26
|
-
unienv_interface/backends/jax.py,sha256=Rus-aBp1kz3XMOwDzoi5aI1hO0Ha7g2V5U1WcRz2bf0,433
|
|
27
|
-
unienv_interface/backends/numpy.py,sha256=NnNoUGUsvXnTYU86sNmuf44Z2jSxBTk4qKul7eaNM-A,436
|
|
28
|
-
unienv_interface/backends/pytorch.py,sha256=xLobPQFnM4WGppYx9N1Emly9w_UXEx8IHtK1TPd89dw,480
|
|
29
|
-
unienv_interface/backends/serialization.py,sha256=0TZlpfbP1DRB4FkM8ysDVQmn6RlYtIPisyeHjvHr7bE,2289
|
|
30
|
-
unienv_interface/env_base/__init__.py,sha256=UT5pBnf4kaXT450o8bmffH7kkSZXpU6AICrsZLxY2Yg,181
|
|
31
|
-
unienv_interface/env_base/env.py,sha256=Wy2_kYyP-t-2zoJk5SLRhhwV-rMKGdQVWUgXOOCtlHU,4661
|
|
32
|
-
unienv_interface/env_base/funcenv.py,sha256=naOfnff4yw-D3hh6covgVP-ZAszyq-3N96AOslVtASI,9678
|
|
33
|
-
unienv_interface/env_base/funcenv_wrapper.py,sha256=UcTdE7vKiN8AQMwN4tk-AQyHbHuZjtWhEIpiVP1H-xY,7746
|
|
34
|
-
unienv_interface/env_base/wrapper.py,sha256=7hf4Rr2wouS0igPoahhvb2tzYY3bCaWL0NlgwpYZwQs,9734
|
|
35
|
-
unienv_interface/func_wrapper/__init__.py,sha256=X0BwdAFsrhmHy4DkhHgoAyCF5xIHtAwzijzCwegFq10,48
|
|
36
|
-
unienv_interface/func_wrapper/transformation.py,sha256=7mdzcpjLjqtpbtXoqbkGtTMPQxoMmMsqzDWHcZLbrhs,5939
|
|
37
|
-
unienv_interface/space/__init__.py,sha256=6-wLoD9mKDAfz7IuQs_Rn9DMDfDwTZ0tEhQ924libpg,99
|
|
38
|
-
unienv_interface/space/space.py,sha256=mFlCcDvMgEPTXlwo_iwBlm6Eg4Bn2rrecgsfIVstdq0,4067
|
|
39
|
-
unienv_interface/space/space_utils/__init__.py,sha256=uz53bqNEizFxJcyQj-Q3yt65ZovK_ItXyjkdapXqpmg,90
|
|
40
|
-
unienv_interface/space/space_utils/batch_utils.py,sha256=3PXgrMrv9ecw0nwxYSvI91UUqwuxyYhe4INeN1--NIU,33616
|
|
41
|
-
unienv_interface/space/space_utils/flatten_utils.py,sha256=kkHkjrsk43NDbg3Q5VAhVoIXStuRayYFO-7knsDzx4A,12289
|
|
42
|
-
unienv_interface/space/space_utils/gym_utils.py,sha256=nH8EKruOKCXNrIMPUd9F4XGKCfFkhxsTmx4I1BeSgn0,15079
|
|
43
|
-
unienv_interface/space/space_utils/serialization_utils.py,sha256=jfnHowqIAVXgT1WR4OU8VWp4b52mOWDXH25QdclDV18,8863
|
|
44
|
-
unienv_interface/space/spaces/__init__.py,sha256=8EW6XKqTn9NQ9B1jYNSM9ouQhXRFjY-TsH1GAaBRVIU,480
|
|
45
|
-
unienv_interface/space/spaces/binary.py,sha256=zg1fOxsugiDBKja29Q8J3OKZGWnt8jq8AnEEe-hvcKE,3638
|
|
46
|
-
unienv_interface/space/spaces/box.py,sha256=qN-SF80T4UhWZskSd5hNMxSzkT1zEIqOh18LV4IFCeE,13792
|
|
47
|
-
unienv_interface/space/spaces/dict.py,sha256=G5_iYC1Bj5DqeJ7aFlq6eRJbnpATbIRIyRu1jF_UUvk,7022
|
|
48
|
-
unienv_interface/space/spaces/dynamic_box.py,sha256=HvMNgzfYwIVc5VVgCtq-8lQbNI1V1dZMI-w60AwYss4,19591
|
|
49
|
-
unienv_interface/space/spaces/graph.py,sha256=KocRFLtYP5VWYpwbP6HybXH5R4jTIYJdNePKb6vhnYE,15163
|
|
50
|
-
unienv_interface/space/spaces/text.py,sha256=S9u4fd-X6k-xQMxayxYECeasrXosCLprfLz4SgVHVew,4561
|
|
51
|
-
unienv_interface/space/spaces/tuple.py,sha256=rgZQz3EB3CLbIk9UlHBQbM6w9gssbA1izm-Qq-_Chqs,4267
|
|
52
|
-
unienv_interface/space/spaces/union.py,sha256=Qisd-DdmPcGRmdhZFGiQw8_AOjYWqkuQ4Hwd-I8tdSI,4375
|
|
53
|
-
unienv_interface/transformations/__init__.py,sha256=g19uGnDHMywvDAXRaqFgoWAF1vCPrbJENEpaEgtIrOw,353
|
|
54
|
-
unienv_interface/transformations/batch_and_unbatch.py,sha256=ELCnNtwmgA5wpTBJZasfNSHmtf4vzydzLPmO6IGbT9o,1164
|
|
55
|
-
unienv_interface/transformations/chained_transform.py,sha256=TDnUvxUKK6bXGc_sfr6ZCvvVWw7P5KX2sA9i7i2lx14,2075
|
|
56
|
-
unienv_interface/transformations/dict_transform.py,sha256=nv9BFDmoDyv2hTGgHAOcLOwIJD2i5N-L3qqE5AnIR1g,4504
|
|
57
|
-
unienv_interface/transformations/filter_dict.py,sha256=DzR-hgHoHJObTipxwB2UrKVlTxbfIrJohaOgqdAICLY,5871
|
|
58
|
-
unienv_interface/transformations/rescale.py,sha256=fM5ukWUvNvPeDO48_PRU0KyyvGhBIDxaN9XZyQ1VaQQ,4364
|
|
59
|
-
unienv_interface/transformations/transformation.py,sha256=u4_9H1tvophhgG0p0F3xfkMMsRuaKY2TQmVeGoeQsJ0,1652
|
|
60
|
-
unienv_interface/utils/seed_util.py,sha256=Up3nBXj7L8w-S9W5Q1U2d9accMhMf0TmHPaN6JXDVWs,677
|
|
61
|
-
unienv_interface/utils/symbol_util.py,sha256=NAERK-D_2MaTg2eYW-L75tbzPQN5YJIiKtM9zuQ89Sw,383
|
|
62
|
-
unienv_interface/world/__init__.py,sha256=FDhYhxAGnUmR-_3eCjN4LjiaYXXaVaU3QNRPhx-Umbw,132
|
|
63
|
-
unienv_interface/world/funcnode.py,sha256=mUpVQ_j7dVF6V7Dc435RNtMQB2LsuWRRkp-hBLOCwzc,7829
|
|
64
|
-
unienv_interface/world/funcworld.py,sha256=GLp8nS0Ym1gaj7FWvD5FPkQElCgZMbpyuLsIMU0w-sw,2020
|
|
65
|
-
unienv_interface/world/node.py,sha256=Qn8rErvhkRp2U0s_m_0OqDLY723w9E5W8tGdmKcP-mY,5996
|
|
66
|
-
unienv_interface/world/world.py,sha256=Kl7wbNbs2YR3CjFrCLFhDB3DQUAWM6LjBwSADQtBTII,5740
|
|
67
|
-
unienv_interface/wrapper/__init__.py,sha256=ZNqr-WjVRqgvIxkLkeABxpYZ6tRgJNZOzmluDeJ6W_w,614
|
|
68
|
-
unienv_interface/wrapper/action_rescale.py,sha256=rTJlEHvWSuwGVX83cjfLWvszBk7B2iExX_K37vH8Wic,1231
|
|
69
|
-
unienv_interface/wrapper/backend_compat.py,sha256=-nq4XQHBLlwNu3ku65lVPv3-h_A0_sYZciQgGBG6RXs,7044
|
|
70
|
-
unienv_interface/wrapper/batch_and_unbatch.py,sha256=HpmnppgOKmshNlfmJYkGQYtEU7_U7q3mEdY5n4UaqEY,3457
|
|
71
|
-
unienv_interface/wrapper/control_frequency_limit.py,sha256=B0E2aUbaUr2p2yIN6wT3q4rAbPYsVroioqma2qKMoC0,2322
|
|
72
|
-
unienv_interface/wrapper/flatten.py,sha256=NWA5xne5j_L34oq_wT85wGvp6iHwdCSeGsk1DMugvRw,5837
|
|
73
|
-
unienv_interface/wrapper/frame_stack.py,sha256=ss7Z7AEMwuFDkmHbcpr9srZwt4a2i9V9DPykSF9LAQ8,8925
|
|
74
|
-
unienv_interface/wrapper/gym_compat.py,sha256=JhLxDsO1NsJnKzKhO0MqMw9i5_1FLxoxKilWaQQyBkw,9789
|
|
75
|
-
unienv_interface/wrapper/time_limit.py,sha256=VRvB00BK7deI2QtdGatqwDWmPgjgjg1E7MTvEyaW5rg,2904
|
|
76
|
-
unienv_interface/wrapper/transformation.py,sha256=pQ-_YVU8WWDqSk2sONUUgQY1iigOD092KNcp1DYxoxk,10043
|
|
77
|
-
unienv_interface/wrapper/video_record.py,sha256=dAwPQuxhj7Pgn4773I5Wlr1agxcLVVdXnw7dT-d0wtM,8357
|
|
78
|
-
unienv_maniskill/__init__.py,sha256=GJ6Fe5XosHqVxXSQ_cZT6cd9GWmfb2_NIYefCFD7rpU,54
|
|
79
|
-
unienv_maniskill/wrapper/maniskill_compat.py,sha256=yp9kX3Sn5Riszai47VasQcbGD6MFs67ITl0NuFqxp7Y,7435
|
|
80
|
-
unienv_mjxplayground/__init__.py,sha256=EaiB9FV7um1oZhhUNPx-hf4j3ruQ51Q1OSiS7Njls1M,59
|
|
81
|
-
unienv_mjxplayground/wrapper/playground_compat.py,sha256=xbkJ_7kO1VQjYsrtxT9Er6NNls4H5o5RZAe0ttgWX5w,7945
|
|
82
|
-
unienv-0.0.1b1.dist-info/METADATA,sha256=dwzzbRrlByBym1cKaOJPa907gGNqCLyDaiJ--FNCq5c,568
|
|
83
|
-
unienv-0.0.1b1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
84
|
-
unienv-0.0.1b1.dist-info/top_level.txt,sha256=lL1n-OMi2oZ5e4kuFcvbelQ-9DsdEgr7w-dQ-faiCD8,67
|
|
85
|
-
unienv-0.0.1b1.dist-info/RECORD,,
|