unienv 0.0.1b1__py3-none-any.whl → 0.0.1b3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- unienv-0.0.1b3.dist-info/METADATA +74 -0
- unienv-0.0.1b3.dist-info/RECORD +92 -0
- {unienv-0.0.1b1.dist-info → unienv-0.0.1b3.dist-info}/licenses/LICENSE +1 -3
- unienv-0.0.1b3.dist-info/top_level.txt +2 -0
- unienv_data/base/__init__.py +0 -1
- unienv_data/base/common.py +95 -45
- unienv_data/base/storage.py +1 -0
- unienv_data/batches/__init__.py +2 -1
- unienv_data/batches/backend_compat.py +47 -1
- unienv_data/batches/combined_batch.py +2 -4
- unienv_data/{base → batches}/transformations.py +3 -2
- unienv_data/replay_buffer/replay_buffer.py +4 -0
- unienv_data/samplers/__init__.py +0 -1
- unienv_data/samplers/multiprocessing_sampler.py +26 -22
- unienv_data/samplers/step_sampler.py +9 -18
- unienv_data/storages/common.py +5 -0
- unienv_data/storages/hdf5.py +291 -20
- unienv_data/storages/pytorch.py +1 -0
- unienv_data/storages/transformation.py +191 -0
- unienv_data/transformations/image_compress.py +213 -0
- unienv_interface/backends/jax.py +4 -1
- unienv_interface/backends/numpy.py +4 -1
- unienv_interface/backends/pytorch.py +4 -1
- unienv_interface/env_base/__init__.py +1 -0
- unienv_interface/env_base/env.py +5 -0
- unienv_interface/env_base/funcenv.py +32 -1
- unienv_interface/env_base/funcenv_wrapper.py +2 -2
- unienv_interface/env_base/vec_env.py +474 -0
- unienv_interface/func_wrapper/__init__.py +2 -1
- unienv_interface/func_wrapper/frame_stack.py +150 -0
- unienv_interface/space/space_utils/__init__.py +1 -0
- unienv_interface/space/space_utils/batch_utils.py +83 -0
- unienv_interface/space/space_utils/construct_utils.py +216 -0
- unienv_interface/space/space_utils/serialization_utils.py +16 -1
- unienv_interface/space/spaces/__init__.py +3 -1
- unienv_interface/space/spaces/batched.py +90 -0
- unienv_interface/space/spaces/binary.py +0 -1
- unienv_interface/space/spaces/box.py +13 -24
- unienv_interface/space/spaces/text.py +1 -3
- unienv_interface/transformations/dict_transform.py +31 -5
- unienv_interface/utils/control_util.py +68 -0
- unienv_interface/utils/data_queue.py +184 -0
- unienv_interface/utils/stateclass.py +46 -0
- unienv_interface/utils/vec_util.py +15 -0
- unienv_interface/world/__init__.py +3 -1
- unienv_interface/world/combined_funcnode.py +336 -0
- unienv_interface/world/combined_node.py +232 -0
- unienv_interface/wrapper/backend_compat.py +2 -2
- unienv_interface/wrapper/frame_stack.py +19 -114
- unienv_interface/wrapper/video_record.py +11 -2
- unienv-0.0.1b1.dist-info/METADATA +0 -20
- unienv-0.0.1b1.dist-info/RECORD +0 -85
- unienv-0.0.1b1.dist-info/top_level.txt +0 -4
- unienv_data/samplers/slice_sampler.py +0 -266
- unienv_maniskill/__init__.py +0 -1
- unienv_maniskill/wrapper/maniskill_compat.py +0 -235
- unienv_mjxplayground/__init__.py +0 -1
- unienv_mjxplayground/wrapper/playground_compat.py +0 -256
- {unienv-0.0.1b1.dist-info → unienv-0.0.1b3.dist-info}/WHEEL +0 -0
|
@@ -0,0 +1,68 @@
|
|
|
1
|
+
from fractions import Fraction
|
|
2
|
+
from math import gcd
|
|
3
|
+
from functools import reduce
|
|
4
|
+
from typing import Iterable, Tuple, List, Union
|
|
5
|
+
|
|
6
|
+
# Scalar type accepted for the sensor periods passed to find_best_timestep.
Number = Union[int, float]
|
|
7
|
+
|
|
8
|
+
def _lcm(a: int, b: int) -> int:
    """Return the non-negative least common multiple of ``a`` and ``b``.

    The result is 0 when either argument is 0.
    """
    if not a or not b:
        return 0
    return abs(a // gcd(a, b) * b)
|
|
10
|
+
|
|
11
|
+
def _lcmm(values: Iterable[int]) -> int:
    """Return the least common multiple of all ``values``.

    The fold starts from 1, so an empty iterable yields 1.
    """
    running = 1
    for value in values:
        running = _lcm(running, value)
    return running
|
|
13
|
+
|
|
14
|
+
def _gcdm(values: Iterable[int]) -> int:
    """Return the greatest common divisor of all ``values``, folded left to right.

    A single-element input is returned unchanged. An empty input raises
    ``TypeError``, because there is no initial value to fall back on.
    """
    remaining = iter(values)
    try:
        running = next(remaining)
    except StopIteration:
        raise TypeError("reduce() of empty iterable with no initial value") from None
    for value in remaining:
        running = gcd(running, value)
    return running
|
|
16
|
+
|
|
17
|
+
def find_best_timestep(
    timesteps: Iterable[Number],
    *,
    max_denominator: int = 10_000,
    return_fraction: bool = False,
) -> Tuple[Union[float, Fraction], List[int]]:
    """
    Compute the simulation timestep dt such that every sensor period is an
    integer multiple of dt (i.e., dt is the GCD of the periods). Works with floats.

    Each period is first approximated by a rational number whose denominator
    does not exceed ``max_denominator``; dt is the exact GCD of those rationals.

    Args:
        timesteps: Iterable of sensor periods (seconds, ms, etc.). Must be > 0.
        max_denominator: Max denominator when rational-approximating floats.
            Increase if your periods are very fine-grained.
        return_fraction: If True, returns dt as a Fraction; otherwise a float.

    Returns:
        dt: The best simulation timestep (float or Fraction).
        steps_per_sensor: For each input period T_i, the integer k_i = T_i / dt
            (computed on the rational approximation of T_i).

    Raises:
        ValueError: If list is empty, contains non-positive values, or contains
            a period so small that it rounds to zero at ``max_denominator``.
    """
    # Validate and convert to Fractions
    periods = list(timesteps)
    if not periods:
        raise ValueError("timesteps must be a non-empty sequence.")
    if any(p <= 0 for p in periods):
        raise ValueError("All timesteps must be positive.")

    fracs = [Fraction(p).limit_denominator(max_denominator) for p in periods]

    # limit_denominator may round a tiny positive period down to exactly 0.
    # Such a period would contribute nothing to the GCD and silently get
    # 0 steps, so reject it instead of returning a meaningless schedule.
    if any(f <= 0 for f in fracs):
        raise ValueError(
            "A timestep rounds to zero with "
            f"max_denominator={max_denominator}; increase max_denominator."
        )

    # Find common denominator (LCM of all denominators)
    D = _lcmm([f.denominator for f in fracs])

    # Scale each fraction to that denominator → integers
    ints = [f.numerator * (D // f.denominator) for f in fracs]

    # GCD of the integerized periods → integer g; convert back via g/D
    g = _gcdm(ints)
    dt_frac = Fraction(g, D)  # exact rational dt

    # Defensive invariant: with strictly positive fractions dt is always > 0
    if dt_frac <= 0:
        raise RuntimeError("Computed non-positive timestep; check inputs.")

    steps_per_sensor = [int(f // dt_frac) for f in fracs]  # exact integer division

    if return_fraction:
        return dt_frac, steps_per_sensor
    return float(dt_frac), steps_per_sensor
|
|
@@ -0,0 +1,184 @@
|
|
|
1
|
+
from typing import Dict, Any, Optional, Tuple, Union, Generic, TypeVar
|
|
2
|
+
import numpy as np
|
|
3
|
+
import copy
|
|
4
|
+
import dataclasses
|
|
5
|
+
|
|
6
|
+
from unienv_interface.space import Space
|
|
7
|
+
from unienv_interface.space import batch_utils as sbu, flatten_utils as sfu
|
|
8
|
+
from unienv_interface.backends import ComputeBackend, BArrayType, BDeviceType, BDtypeType, BRNGType
|
|
9
|
+
|
|
10
|
+
# Type of the data described by the Space that a queue wraps.
DataT = TypeVar('DataT')
|
|
11
|
+
|
|
12
|
+
@dataclasses.dataclass(frozen=True)
class SpaceDataQueueState(Generic[DataT]):
    """Immutable holder for the stacked data buffer of a data queue."""

    # Stacked buffer, horizon axis first: (L, B, ...) or (L, ...)
    data: DataT

    def replace(self, **changes: Any) -> 'SpaceDataQueueState':
        """Return a new state with the given fields overridden."""
        updated = dataclasses.replace(self, **changes)
        return updated
|
|
18
|
+
|
|
19
|
+
class FuncSpaceDataQueue(
    Generic[DataT, BArrayType, BDeviceType, BDtypeType, BRNGType]
):
    """Functional fixed-length queue over data described by a ``Space``.

    The queue itself holds no data: every operation takes a
    ``SpaceDataQueueState`` and returns a new one. The state buffer is laid
    out horizon-first, (L, ...) or (L, B, ...), where L is ``maxlen``.
    """

    def __init__(
        self,
        space : Space[DataT, BDeviceType, BDtypeType, BRNGType],
        batch_size : Optional[int],
        maxlen: int,
    ) -> None:
        """
        Args:
            space: Space describing a single queue entry.
            batch_size: Batch size of ``space``, or None for unbatched data.
            maxlen: Number of entries kept in the queue (must be > 0).
        """
        assert maxlen > 0, "Max length must be greater than 0"
        assert batch_size is None or batch_size > 0, "Batch size must be greater than 0 if provided"
        assert batch_size is None or sbu.batch_size(space) == batch_size, "Batch size must match the space's batch size if provided"
        self.single_space = space
        self.stacked_space = sbu.batch_space(space, maxlen) # (L, ...) or (L, B, ...)
        # Batched queues expose the data batch-first, so swap the two leading dims
        self.output_space = sbu.swap_batch_dims(
            self.stacked_space, 0, 1
        ) if batch_size is not None else self.stacked_space # (B, L, ...) or (L, ...)
        self._maxlen = maxlen
        self._batch_size = batch_size

    @property
    def maxlen(self) -> int:
        """Number of entries kept in the queue."""
        return self._maxlen

    @property
    def batch_size(self) -> Optional[int]:
        """Batch size given at construction, or None when unbatched."""
        return self._batch_size

    @property
    def backend(self) -> ComputeBackend:
        """Compute backend of the single-entry space."""
        return self.single_space.backend

    @property
    def device(self) -> Optional[BDeviceType]:
        """Device of the single-entry space."""
        return self.single_space.device

    def init(
        self,
        initial_data : DataT,
    ) -> SpaceDataQueueState:
        """Create a fresh state whose every slot is set from ``initial_data``."""
        return self.reset(
            SpaceDataQueueState(self.stacked_space.create_empty()),
            initial_data
        )

    def reset(
        self,
        state : SpaceDataQueueState,
        initial_data : DataT,
        mask : Optional[BArrayType] = None,
    ) -> SpaceDataQueueState:
        """Overwrite the queue contents with ``initial_data``.

        Args:
            state: Current queue state.
            initial_data: Data written into every horizon slot.
            mask: Optional index applied on the batch axis so that only the
                selected batch entries are reset. Only valid for batched queues.

        Returns:
            The updated state.
        """
        # A mask indexes the batch axis, so it only makes sense when the queue
        # is batched. (The original condition was inverted and rejected
        # exactly the batched + mask case.)
        assert self.batch_size is not None or mask is None, \
            "Mask should not be provided if batch size is empty"
        index = (
            slice(None), mask
        ) if mask is not None else slice(None)

        # NOTE(review): with a mask, the indexed region covers only the
        # selected batch entries while expanded_data is built from the whole
        # initial_data - confirm the value shape sbu.set_at expects here.
        expanded_data = sbu.get_at( # Add a singleton horizon dimension to the data
            self.single_space,
            initial_data,
            None
        )
        return state.replace(
            data=sbu.set_at(
                self.stacked_space,
                state.data,
                index,
                expanded_data
            )
        )

    def add(self, state : SpaceDataQueueState, data : DataT) -> SpaceDataQueueState:
        """Append ``data`` as the newest entry and return the new state."""
        # Shift every array one step along the horizon axis, then overwrite
        # the last slot (which now holds the wrapped-around entry) with data.
        new_data = self.backend.map_fn_over_arrays(
            state.data,
            lambda x: self.backend.roll(x, shift=-1, axis=0),
        )
        new_data = sbu.set_at(
            self.stacked_space,
            new_data,
            -1,
            data
        )
        return state.replace(data=new_data)

    def get_output_data(self, state : SpaceDataQueueState) -> DataT:
        """Return the queue contents in the layout of ``output_space``."""
        if self.batch_size is None:
            return state.data
        else:
            return sbu.swap_batch_dims_in_data(
                self.backend,
                state.data,
                0, 1
            ) # (L, B, ...) -> (B, L, ...)
|
|
112
|
+
|
|
113
|
+
class SpaceDataQueue(
    Generic[DataT, BArrayType, BDeviceType, BDtypeType, BRNGType]
):
    """Stateful wrapper around ``FuncSpaceDataQueue``.

    Holds the queue state internally (``self.state``) and forwards every
    operation to the functional queue stored in ``self.func_queue``.
    """

    def __init__(
        self,
        space : Space[DataT, BDeviceType, BDtypeType, BRNGType],
        batch_size : Optional[int],
        maxlen: int,
    ) -> None:
        # No state exists until the first reset()
        self.state = None
        self.func_queue = FuncSpaceDataQueue(space, batch_size, maxlen)

    # ----- read-only views onto the functional queue -----

    @property
    def single_space(self) -> Space[DataT, BDeviceType, BDtypeType, BRNGType]:
        """Space of a single queue entry."""
        queue = self.func_queue
        return queue.single_space

    @property
    def stacked_space(self) -> Space[DataT, BDeviceType, BDtypeType, BRNGType]:
        """Space of the internal horizon-first buffer."""
        queue = self.func_queue
        return queue.stacked_space

    @property
    def output_space(self) -> Space[DataT, BDeviceType, BDtypeType, BRNGType]:
        """Space of the data returned by ``get_output_data``."""
        queue = self.func_queue
        return queue.output_space

    @property
    def maxlen(self) -> int:
        """Number of entries kept in the queue."""
        queue = self.func_queue
        return queue.maxlen

    @property
    def batch_size(self) -> Optional[int]:
        """Batch size of the queue, or None when unbatched."""
        queue = self.func_queue
        return queue.batch_size

    @property
    def backend(self) -> ComputeBackend:
        """Compute backend of the underlying queue."""
        queue = self.func_queue
        return queue.backend

    @property
    def device(self) -> Optional[BDeviceType]:
        """Device of the underlying queue."""
        queue = self.func_queue
        return queue.device

    # ----- stateful operations -----

    def reset(
        self,
        initial_data : DataT,
        mask : Optional[BArrayType] = None,
    ) -> None:
        """Reset the queue from ``initial_data``.

        The first call creates the state and must not receive a mask; later
        calls forward ``mask`` to the functional queue's ``reset``.
        """
        current = self.state
        if current is not None:
            self.state = self.func_queue.reset(current, initial_data, mask)
            return

        assert mask is None, "Mask should not be provided on the first reset"
        self.state = self.func_queue.init(initial_data)

    def add(self, data : DataT) -> None:
        """Push ``data`` into the queue; requires a prior ``reset``."""
        assert self.state is not None, "Data queue must be reset before adding data"
        updated = self.func_queue.add(self.state, data)
        self.state = updated

    def get_output_data(self) -> DataT:
        """Return the current queue contents; requires a prior ``reset``."""
        assert self.state is not None, "Data queue must be reset before getting output data"
        current = self.state
        return self.func_queue.get_output_data(current)
|
|
@@ -0,0 +1,46 @@
|
|
|
1
|
+
import dataclasses
|
|
2
|
+
from typing import TypeVar
|
|
3
|
+
import functools
|
|
4
|
+
|
|
5
|
+
# Names exported by `from <module> import *`.
__all__ = ['stateclass', 'field', 'StateClass']
|
|
6
|
+
|
|
7
|
+
def field(pytree_node=True, *, metadata=None, **kwargs):
    """Create a dataclass field tagged with a ``pytree_node`` metadata flag.

    Args:
        pytree_node: Value stored under the ``'pytree_node'`` metadata key.
        metadata: Optional extra metadata; a ``'pytree_node'`` key in it is
            overridden by the ``pytree_node`` argument.
        **kwargs: Forwarded unchanged to ``dataclasses.field``.

    Returns:
        The object returned by ``dataclasses.field``.
    """
    base_metadata = metadata or {}
    tagged_metadata = base_metadata | {'pytree_node': pytree_node}
    return dataclasses.field(metadata=tagged_metadata, **kwargs)
|
|
10
|
+
|
|
11
|
+
def stateclass(clz=None, /, **kwargs):
    """Turn a class into a dataclass (frozen by default) with a ``replace`` method.

    Usable bare (``@stateclass``) or with options (``@stateclass(frozen=False)``).

    Args:
        clz: The class to convert, or None when called with options only.
        **kwargs: Options forwarded to ``dataclasses.dataclass``; ``frozen``
            defaults to True when not given.

    Returns:
        The converted class, the class unchanged if it was already converted,
        or a partial of this function when ``clz`` is None.
    """
    # Called as @stateclass(...): hand back a decorator that remembers the options
    if clz is None:
        return functools.partial(stateclass, **kwargs)  # type: ignore[bad-return-type]

    # Only the class's own namespace is checked, so a subclass of a
    # stateclass is still converted in its own right
    if '_unienv_stateclass' in vars(clz):
        return clz

    kwargs.setdefault('frozen', True)
    make_dataclass = dataclasses.dataclass(**kwargs)
    converted = make_dataclass(clz)  # type: ignore

    def replace(self, **updates):
        """Returns a new object replacing the specified fields with new values."""
        return dataclasses.replace(self, **updates)

    converted.replace = replace
    # Marker distinguishing stateclasses from regular dataclasses
    converted._unienv_stateclass = True  # type: ignore[attr-defined]
    return converted  # type: ignore
|
|
35
|
+
|
|
36
|
+
# Self-type for StateClass.replace, so it is typed as returning the caller's own class.
TNode = TypeVar('TNode', bound='StateClass')
|
|
37
|
+
|
|
38
|
+
class StateClass:
    """Base class whose subclasses are passed through ``stateclass`` on definition.

    Class keyword arguments (``class Foo(StateClass, frozen=False)``) are
    forwarded to ``stateclass``. The ``__init__`` and ``replace`` defined here
    are placeholders for type checkers: ``stateclass`` applies
    ``dataclasses.dataclass`` to each subclass and attaches a working
    ``replace`` to it.
    """

    def __init_subclass__(cls, **kwargs):
        # Keep the __init_subclass__ chain cooperative for mixins. The kwargs
        # are dataclass options consumed below, so none are forwarded.
        super().__init_subclass__()
        # NOTE(review): the return value is discarded, so an option that makes
        # dataclass return a new class (e.g. slots=True) would not take effect
        # here - confirm whether that needs to be supported.
        stateclass(cls, **kwargs) # pytype: disable=wrong-arg-types

    def __init__(self, *args, **kwargs):
        """Placeholder; StateClass itself cannot be instantiated."""
        raise NotImplementedError

    def replace(self: TNode, **overrides) -> TNode:
        """Placeholder giving type checkers the signature of ``replace``."""
        raise NotImplementedError
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
import cloudpickle
|
|
2
|
+
from typing import Callable
|
|
3
|
+
|
|
4
|
+
class MultiProcessFn:
    """Callable wrapper whose pickle state is the cloudpickle dump of the wrapped function."""

    def __init__(self, fn : Callable):
        self.fn = fn

    def __call__(self, *args, **kwargs):
        """Invoke the wrapped function with the given arguments."""
        target = self.fn
        return target(*args, **kwargs)

    def __getstate__(self):
        """Return the wrapped function serialized with cloudpickle."""
        payload = cloudpickle.dumps(self.fn)
        return payload

    def __setstate__(self, state):
        """Restore the wrapped function from a cloudpickle payload."""
        # NOTE: unpickling can execute arbitrary code; only load states that
        # come from a trusted process.
        restored = cloudpickle.loads(state)
        self.fn = restored
|
|
@@ -1,4 +1,6 @@
|
|
|
1
1
|
from .world import World, RealWorld
|
|
2
2
|
from .node import WorldNode
|
|
3
|
+
from .combined_node import CombinedWorldNode
|
|
3
4
|
from .funcworld import FuncWorld
|
|
4
|
-
from .funcnode import FuncWorldNode
|
|
5
|
+
from .funcnode import FuncWorldNode
|
|
6
|
+
from .combined_funcnode import CombinedFuncWorldNode
|
|
@@ -0,0 +1,336 @@
|
|
|
1
|
+
from typing import Optional, Dict, Any, Tuple, Union, Iterable, Mapping
|
|
2
|
+
|
|
3
|
+
from unienv_interface.backends import BArrayType, BDeviceType, BDtypeType, BRNGType
|
|
4
|
+
from unienv_interface.space import Space, DictSpace
|
|
5
|
+
|
|
6
|
+
from .funcnode import FuncWorldNode
|
|
7
|
+
from .funcworld import WorldStateT
|
|
8
|
+
|
|
9
|
+
# Data exchanged with a combined node: either a dict keyed by child node name,
# or a single child's value passed through directly.
CombinedDataT = Union[Dict[str, Any], Any]
# State of a combined node: each child's node state, keyed by child node name.
CombinedNodeStateT = Dict[str, Any]
|
|
11
|
+
|
|
12
|
+
class CombinedFuncWorldNode(FuncWorldNode[
    WorldStateT, CombinedNodeStateT,
    Optional[CombinedDataT], # Context type (can be None)
    CombinedDataT, # Observation type
    CombinedDataT, # Action type
    BArrayType, BDeviceType, BDtypeType, BRNGType
]):
    """A functional counterpart to `CombinedWorldNode` that composes multiple `FuncWorldNode`s.

    It aggregates spaces (context, observation, action) and runtime data (context, observation, info, reward, termination, truncation)
    across child nodes. If only one child exposes a given interface and `direct_return=True`, the value is passed through directly.

    The combined node state is a dict mapping each child's name to that child's
    node state. Methods that update it work on a shallow copy, so the dict
    passed in is not mutated.

    NOTE(review): `self.backend` and `self.device` are used below but not
    defined in this class - presumably provided by `FuncWorldNode`; confirm.
    """

    def __init__(
        self,
        name: str,
        nodes: Iterable[FuncWorldNode[WorldStateT, Any, Any, Any, Any, BArrayType, BDeviceType, BDtypeType, BRNGType]],
        direct_return: bool = True,
    ):
        """
        Args:
            name: Name of the combined node.
            nodes: Child nodes; must be non-empty, uniquely named, and share
                the same world and control timestep.
            direct_return: If True, an interface exposed by exactly one child
                is passed through directly instead of being wrapped in a dict.

        Raises:
            ValueError: If `nodes` is empty or two nodes share a name.
        """
        nodes = list(nodes)
        if not nodes:
            raise ValueError("At least one node is required to create a CombinedFuncWorldNode.")

        first_node = nodes[0]
        # Ensure all nodes share the same world & control timestep
        for node in nodes[1:]:
            assert node.world is first_node.world, "All nodes must belong to the same world." \
                f" Mismatch between {first_node.name} and {node.name}."
            assert node.control_timestep == first_node.control_timestep, "All nodes must have the same control timestep." \
                f" Mismatch between {first_node.name} and {node.name}."

        names = [node.name for node in nodes]
        if len(names) != len(set(names)):
            raise ValueError("All nodes must have unique names.")

        self.nodes = nodes

        # Aggregate spaces similar to `CombinedWorldNode`
        _, self.context_space = self.aggregate_spaces(
            {node.name: node.context_space for node in nodes if node.context_space is not None},
            direct_return=direct_return,
        )
        _, self.observation_space = self.aggregate_spaces(
            {node.name: node.observation_space for node in nodes if node.observation_space is not None},
            direct_return=direct_return,
        )
        # Remember which child owns the action space when it is passed through
        # directly, so set_next_action can route a bare action to it.
        self._action_node_name_direct, self.action_space = self.aggregate_spaces(
            {node.name: node.action_space for node in nodes if node.action_space is not None},
            direct_return=direct_return,
        )

        self.has_reward = any(node.has_reward for node in nodes)
        self.has_termination_signal = any(node.has_termination_signal for node in nodes)
        self.has_truncation_signal = any(node.has_truncation_signal for node in nodes)

        self.name = name
        self.direct_return = direct_return

    # ========== Helper aggregation methods ==========
    @staticmethod
    def aggregate_spaces(
        spaces: Dict[str, Optional[Space[Any, BDeviceType, BDtypeType, BRNGType]]],
        direct_return: bool = True,
    ) -> Tuple[Optional[str], Optional[Space[Any, BDeviceType, BDtypeType, BRNGType]]]:
        """Combine per-node spaces into a single space.

        Args:
            spaces: Mapping from node name to that node's space; None entries
                are ignored.
            direct_return: If True and exactly one space is present, return it
                as-is instead of wrapping it in a `DictSpace`.

        Returns:
            A `(direct_name, space)` pair. `direct_name` is the owning node's
            name when a single space is passed through directly, else None.
            `space` is None when no space is present, the single space in the
            direct case, and otherwise a `DictSpace` keyed by node name.
        """
        # Drop missing spaces before counting and before reading `.backend`,
        # so a None entry can neither be returned as the direct space nor be
        # dereferenced.
        present = {name: space for name, space in spaces.items() if space is not None}
        if not present:
            return None, None
        if len(present) == 1 and direct_return:
            return next(iter(present.items()))
        backend = next(iter(present.values())).backend
        return None, DictSpace(
            backend,
            present,
        )

    @staticmethod
    def aggregate_data(
        data: Dict[str, Any],
        direct_return: bool = True,
    ) -> Optional[Union[Dict[str, Any], Any]]:
        """Combine per-node data.

        Returns None for an empty dict, the single value when there is exactly
        one entry and `direct_return` is True, and the dict itself otherwise.
        """
        if len(data) == 0:
            return None
        elif len(data) == 1 and direct_return:
            return next(iter(data.values()))
        else:
            return data

    # ========== properties ==========
    @property
    def world(self):  # type: ignore[override]
        """World of the first child (all children share the same world)."""
        return self.nodes[0].world

    @property
    def control_timestep(self):  # type: ignore[override]
        """Control timestep of the first child (all children share it)."""
        return self.nodes[0].control_timestep

    # ========== Lifecycle methods ==========
    def initial(
        self,
        world_state: WorldStateT,
        *,
        seed: Optional[int] = None,
        pernode_kwargs: Optional[Dict[str, Dict[str, Any]]] = None,
    ) -> Tuple[WorldStateT, CombinedNodeStateT]:
        """Call `initial` on every child in order, threading the world state through.

        Args:
            world_state: Current world state.
            seed: Seed passed unchanged to every child.
            pernode_kwargs: Extra keyword arguments per child, keyed by node name.

        Returns:
            The updated world state and a dict of child node states keyed by name.
        """
        # None instead of a shared mutable `{}` default
        if pernode_kwargs is None:
            pernode_kwargs = {}
        node_states: CombinedNodeStateT = {}
        for node in self.nodes:
            world_state, node_state = node.initial(world_state, seed=seed, **pernode_kwargs.get(node.name, {}))
            node_states[node.name] = node_state
        return world_state, node_states

    def reset(
        self,
        world_state: WorldStateT,
        node_state: CombinedNodeStateT,
        *,
        seed: Optional[int] = None,
        mask: Optional[BArrayType] = None,
        pernode_kwargs: Optional[Dict[str, Dict[str, Any]]] = None,
        **kwargs,
    ) -> Tuple[WorldStateT, CombinedNodeStateT]:
        """Call `reset` on every child in order, threading the world state through.

        Args:
            world_state: Current world state.
            node_state: Combined node state (child states keyed by name).
            seed: Seed passed unchanged to every child.
            mask: Mask passed unchanged to every child.
            pernode_kwargs: Extra keyword arguments per child, keyed by node name.
            **kwargs: Accepted but not forwarded to the children.
                NOTE(review): confirm that dropping these is intended.

        Returns:
            The updated world state and combined node state.
        """
        # None instead of a shared mutable `{}` default
        if pernode_kwargs is None:
            pernode_kwargs = {}
        node_state = node_state.copy()
        for node in self.nodes:
            ns = node_state[node.name]
            world_state, ns = node.reset(
                world_state,
                ns,
                seed=seed,
                mask=mask,
                **pernode_kwargs.get(node.name, {}),
            )
            node_state[node.name] = ns
        return world_state, node_state

    def after_reset(
        self,
        world_state: WorldStateT,
        node_state: CombinedNodeStateT,
        *,
        mask: Optional[BArrayType] = None,
    ) -> Tuple[
        WorldStateT,
        CombinedNodeStateT,
        Optional[CombinedDataT],
        Optional[CombinedDataT],
        Optional[Dict[str, Any]],
    ]:
        """Call `after_reset` on every child and aggregate what they return.

        Returns:
            `(world_state, node_state, context, observation, info)`. Context
            and observation follow `direct_return`; info is always a dict
            keyed by node name, or None when no child returned info.
        """
        node_state = node_state.copy()
        contexts: Dict[str, Any] = {}
        observations: Dict[str, Any] = {}
        infos: Dict[str, Any] = {}

        for node in self.nodes:
            ns = node_state[node.name]
            world_state, ns, ctx, obs, info = node.after_reset(world_state, ns, mask=mask)
            node_state[node.name] = ns
            # Only children that actually returned a value are aggregated
            if ctx is not None:
                contexts[node.name] = ctx
            if obs is not None:
                observations[node.name] = obs
            if info is not None:
                infos[node.name] = info

        return (
            world_state,
            node_state,
            self.aggregate_data(contexts, direct_return=self.direct_return),
            self.aggregate_data(observations, direct_return=self.direct_return),
            self.aggregate_data(infos, direct_return=False),
        )

    def pre_environment_step(
        self,
        world_state: WorldStateT,
        node_state: CombinedNodeStateT,
        dt: Union[float, BArrayType],
    ) -> Tuple[WorldStateT, CombinedNodeStateT]:
        """Call `pre_environment_step` on every child in order."""
        node_state = node_state.copy()
        for node in self.nodes:
            ns = node_state[node.name]
            world_state, ns = node.pre_environment_step(world_state, ns, dt)
            node_state[node.name] = ns
        return world_state, node_state

    def set_next_action(
        self,
        world_state: WorldStateT,
        node_state: CombinedNodeStateT,
        action: CombinedDataT,
    ) -> Tuple[WorldStateT, CombinedNodeStateT]:
        """Route the action to the child node(s) that accept actions.

        When the action space is passed through directly, `action` is handed
        as-is to the single owning child. Otherwise `action` must be a mapping
        holding one entry per child that has an action space.
        """
        assert self.action_space is not None, "Action space is None, cannot set action."

        node_state = node_state.copy()
        if self._action_node_name_direct is not None:
            # Only one actionable node
            for node in self.nodes:
                if node.name == self._action_node_name_direct:
                    ns = node_state[node.name]
                    world_state, ns = node.set_next_action(world_state, ns, action)  # type: ignore[arg-type]
                    node_state[node.name] = ns
                    break
        else:
            assert isinstance(action, Mapping), "Action must be a mapping when there are multiple action spaces."
            for node in self.nodes:
                if node.action_space is not None:
                    assert node.name in action, f"Action for node {node.name} is missing."
                    ns = node_state[node.name]
                    world_state, ns = node.set_next_action(world_state, ns, action[node.name])
                    node_state[node.name] = ns
        return world_state, node_state

    def post_environment_step(
        self,
        world_state: WorldStateT,
        node_state: CombinedNodeStateT,
        dt: Union[float, BArrayType],
    ) -> Tuple[WorldStateT, CombinedNodeStateT]:
        """Call `post_environment_step` on every child in order."""
        node_state = node_state.copy()
        for node in self.nodes:
            ns = node_state[node.name]
            world_state, ns = node.post_environment_step(world_state, ns, dt)
            node_state[node.name] = ns
        return world_state, node_state

    def close(self, world_state: WorldStateT, node_state: CombinedNodeStateT) -> WorldStateT:  # type: ignore[override]
        """Close every child in order and return the resulting world state."""
        for node in self.nodes:
            world_state = node.close(world_state, node_state[node.name])
        return world_state

    # ========== Data accessors ==========
    def get_observation(
        self,
        world_state: WorldStateT,
        node_state: CombinedNodeStateT,
    ) -> CombinedDataT:
        """Collect observations from every child that has an observation space."""
        assert self.observation_space is not None, "Observation space is None, cannot get observation."
        return self.aggregate_data(
            {
                node.name: node.get_observation(world_state, node_state[node.name])
                for node in self.nodes
                if node.observation_space is not None
            },
            direct_return=self.direct_return,
        )

    def get_reward(
        self,
        world_state: WorldStateT,
        node_state: CombinedNodeStateT,
    ) -> Union[float, BArrayType]:
        """Return the sum of the rewards of all children that provide one."""
        assert self.has_reward, "This node does not provide a reward."
        if self.world.batch_size is None:
            return sum(
                node.get_reward(world_state, node_state[node.name])
                for node in self.nodes
                if node.has_reward
            )
        # Batched world: accumulate into an array with one entry per batch element
        rewards = self.backend.zeros(
            (self.world.batch_size,),
            dtype=self.backend.default_floating_dtype,
            device=self.device,
        )
        for node in self.nodes:
            if node.has_reward:
                rewards = rewards + node.get_reward(world_state, node_state[node.name])
        return rewards

    def get_termination(
        self,
        world_state: WorldStateT,
        node_state: CombinedNodeStateT,
    ) -> Union[bool, BArrayType]:
        """Return the logical OR of the termination signals of all children that provide one."""
        assert self.has_termination_signal, "This node does not provide a termination signal."
        if self.world.batch_size is None:
            return any(
                node.get_termination(world_state, node_state[node.name])
                for node in self.nodes
                if node.has_termination_signal
            )
        # Batched world: OR the signals element-wise, starting from an all-zero boolean array
        terminations = self.backend.zeros(
            (self.world.batch_size,),
            dtype=self.backend.default_boolean_dtype,
            device=self.device,
        )
        for node in self.nodes:
            if node.has_termination_signal:
                terminations = self.backend.logical_or(
                    terminations, node.get_termination(world_state, node_state[node.name])
                )
        return terminations

    def get_truncation(
        self,
        world_state: WorldStateT,
        node_state: CombinedNodeStateT,
    ) -> Union[bool, BArrayType]:
        """Return the logical OR of the truncation signals of all children that provide one."""
        assert self.has_truncation_signal, "This node does not provide a truncation signal."
        if self.world.batch_size is None:
            return any(
                node.get_truncation(world_state, node_state[node.name])
                for node in self.nodes
                if node.has_truncation_signal
            )
        # Batched world: OR the signals element-wise, starting from an all-zero boolean array
        truncations = self.backend.zeros(
            (self.world.batch_size,),
            dtype=self.backend.default_boolean_dtype,
            device=self.device,
        )
        for node in self.nodes:
            if node.has_truncation_signal:
                truncations = self.backend.logical_or(
                    truncations, node.get_truncation(world_state, node_state[node.name])
                )
        return truncations

    def get_info(
        self,
        world_state: WorldStateT,
        node_state: CombinedNodeStateT,
    ) -> Optional[Dict[str, Any]]:
        """Collect the non-None infos of all children into a dict keyed by node name."""
        infos: Dict[str, Any] = {}
        for node in self.nodes:
            info = node.get_info(world_state, node_state[node.name])
            if info is not None:
                infos[node.name] = info
        return self.aggregate_data(infos, direct_return=False)  # Always dict if not empty
|