unienv 0.0.1b1__py3-none-any.whl → 0.0.1b3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (59) hide show
  1. unienv-0.0.1b3.dist-info/METADATA +74 -0
  2. unienv-0.0.1b3.dist-info/RECORD +92 -0
  3. {unienv-0.0.1b1.dist-info → unienv-0.0.1b3.dist-info}/licenses/LICENSE +1 -3
  4. unienv-0.0.1b3.dist-info/top_level.txt +2 -0
  5. unienv_data/base/__init__.py +0 -1
  6. unienv_data/base/common.py +95 -45
  7. unienv_data/base/storage.py +1 -0
  8. unienv_data/batches/__init__.py +2 -1
  9. unienv_data/batches/backend_compat.py +47 -1
  10. unienv_data/batches/combined_batch.py +2 -4
  11. unienv_data/{base → batches}/transformations.py +3 -2
  12. unienv_data/replay_buffer/replay_buffer.py +4 -0
  13. unienv_data/samplers/__init__.py +0 -1
  14. unienv_data/samplers/multiprocessing_sampler.py +26 -22
  15. unienv_data/samplers/step_sampler.py +9 -18
  16. unienv_data/storages/common.py +5 -0
  17. unienv_data/storages/hdf5.py +291 -20
  18. unienv_data/storages/pytorch.py +1 -0
  19. unienv_data/storages/transformation.py +191 -0
  20. unienv_data/transformations/image_compress.py +213 -0
  21. unienv_interface/backends/jax.py +4 -1
  22. unienv_interface/backends/numpy.py +4 -1
  23. unienv_interface/backends/pytorch.py +4 -1
  24. unienv_interface/env_base/__init__.py +1 -0
  25. unienv_interface/env_base/env.py +5 -0
  26. unienv_interface/env_base/funcenv.py +32 -1
  27. unienv_interface/env_base/funcenv_wrapper.py +2 -2
  28. unienv_interface/env_base/vec_env.py +474 -0
  29. unienv_interface/func_wrapper/__init__.py +2 -1
  30. unienv_interface/func_wrapper/frame_stack.py +150 -0
  31. unienv_interface/space/space_utils/__init__.py +1 -0
  32. unienv_interface/space/space_utils/batch_utils.py +83 -0
  33. unienv_interface/space/space_utils/construct_utils.py +216 -0
  34. unienv_interface/space/space_utils/serialization_utils.py +16 -1
  35. unienv_interface/space/spaces/__init__.py +3 -1
  36. unienv_interface/space/spaces/batched.py +90 -0
  37. unienv_interface/space/spaces/binary.py +0 -1
  38. unienv_interface/space/spaces/box.py +13 -24
  39. unienv_interface/space/spaces/text.py +1 -3
  40. unienv_interface/transformations/dict_transform.py +31 -5
  41. unienv_interface/utils/control_util.py +68 -0
  42. unienv_interface/utils/data_queue.py +184 -0
  43. unienv_interface/utils/stateclass.py +46 -0
  44. unienv_interface/utils/vec_util.py +15 -0
  45. unienv_interface/world/__init__.py +3 -1
  46. unienv_interface/world/combined_funcnode.py +336 -0
  47. unienv_interface/world/combined_node.py +232 -0
  48. unienv_interface/wrapper/backend_compat.py +2 -2
  49. unienv_interface/wrapper/frame_stack.py +19 -114
  50. unienv_interface/wrapper/video_record.py +11 -2
  51. unienv-0.0.1b1.dist-info/METADATA +0 -20
  52. unienv-0.0.1b1.dist-info/RECORD +0 -85
  53. unienv-0.0.1b1.dist-info/top_level.txt +0 -4
  54. unienv_data/samplers/slice_sampler.py +0 -266
  55. unienv_maniskill/__init__.py +0 -1
  56. unienv_maniskill/wrapper/maniskill_compat.py +0 -235
  57. unienv_mjxplayground/__init__.py +0 -1
  58. unienv_mjxplayground/wrapper/playground_compat.py +0 -256
  59. {unienv-0.0.1b1.dist-info → unienv-0.0.1b3.dist-info}/WHEEL +0 -0
@@ -0,0 +1,68 @@
1
+ from fractions import Fraction
2
+ from math import gcd
3
+ from functools import reduce
4
+ from typing import Iterable, Tuple, List, Union
5
+
6
Number = Union[int, float]

def _lcm(a: int, b: int) -> int:
    """Least common multiple of two ints (0 if either operand is 0)."""
    return abs(a // gcd(a, b) * b) if a and b else 0

def _lcmm(values: Iterable[int]) -> int:
    """LCM of an iterable of ints (1 for an empty iterable)."""
    return reduce(_lcm, values, 1)

def _gcdm(values: Iterable[int]) -> int:
    """GCD of a non-empty iterable of ints."""
    return reduce(gcd, values)

def find_best_timestep(
    timesteps: Iterable[Number],
    *,
    max_denominator: int = 10_000,
    return_fraction: bool = False,
) -> Tuple[Union[float, Fraction], List[int]]:
    """
    Compute the simulation timestep dt such that every sensor period is an
    integer multiple of dt (i.e., dt is the GCD of the periods). Works with floats.

    Args:
        timesteps: Iterable of sensor periods (seconds, ms, etc.). Must be > 0.
        max_denominator: Max denominator when rational-approximating floats.
            Increase if your periods are very fine-grained.
        return_fraction: If True, returns dt as a Fraction; otherwise a float.

    Returns:
        dt: The best simulation timestep (float or Fraction).
        steps_per_sensor: For each input period T_i, the integer k_i = T_i / dt.

    Raises:
        ValueError: If the list is empty, contains non-positive values, or a
            period is too small to be represented at ``max_denominator``
            precision.
    """
    # Validate and convert to Fractions
    periods = list(timesteps)
    if not periods:
        raise ValueError("timesteps must be a non-empty sequence.")
    if any(p <= 0 for p in periods):
        raise ValueError("All timesteps must be positive.")

    fracs = [Fraction(p).limit_denominator(max_denominator) for p in periods]
    # BUGFIX: a tiny positive float (e.g. 1e-9 with the default precision)
    # rounds to exactly Fraction(0) under limit_denominator, which previously
    # slipped through and yielded a silent 0 step count for that sensor.
    if any(f <= 0 for f in fracs):
        raise ValueError(
            "A timestep is too small to approximate; increase max_denominator."
        )

    # Find common denominator (LCM of all denominators)
    D = _lcmm([f.denominator for f in fracs])

    # Scale each fraction to that denominator -> integers
    ints = [f.numerator * (D // f.denominator) for f in fracs]

    # GCD of the integerized periods -> integer g; convert back via g/D
    g = _gcdm(ints)
    dt_frac = Fraction(g, D)  # exact rational dt

    # Sanity: dt must be > 0
    if dt_frac <= 0:
        raise RuntimeError("Computed non-positive timestep; check inputs.")

    steps_per_sensor = [int(f // dt_frac) for f in fracs]  # exact integer division

    if return_fraction:
        return dt_frac, steps_per_sensor
    else:
        return float(dt_frac), steps_per_sensor
@@ -0,0 +1,184 @@
1
+ from typing import Dict, Any, Optional, Tuple, Union, Generic, TypeVar
2
+ import numpy as np
3
+ import copy
4
+ import dataclasses
5
+
6
+ from unienv_interface.space import Space
7
+ from unienv_interface.space import batch_utils as sbu, flatten_utils as sfu
8
+ from unienv_interface.backends import ComputeBackend, BArrayType, BDeviceType, BDtypeType, BRNGType
9
+
10
DataT = TypeVar('DataT')

@dataclasses.dataclass(frozen=True)
class SpaceDataQueueState(Generic[DataT]):
    """Immutable queue state: the stacked data, laid out (L, B, ...) or (L, ...)."""

    data: DataT

    def replace(self, **updates: Any) -> 'SpaceDataQueueState':
        """Return a copy of this state with the given fields overridden."""
        return dataclasses.replace(self, **updates)
19
class FuncSpaceDataQueue(
    Generic[DataT, BArrayType, BDeviceType, BDtypeType, BRNGType]
):
    """Functional (stateless) fixed-length FIFO over space-structured data.

    The queue stacks ``maxlen`` entries along a leading horizon dimension,
    giving ``(L, ...)`` for unbatched spaces or ``(L, B, ...)`` for batched
    ones; :meth:`get_output_data` transposes batched data to ``(B, L, ...)``.
    All methods take and return an explicit :class:`SpaceDataQueueState`.
    """

    def __init__(
        self,
        space : Space[DataT, BDeviceType, BDtypeType, BRNGType],
        batch_size : Optional[int],
        maxlen: int,
    ) -> None:
        assert maxlen > 0, "Max length must be greater than 0"
        assert batch_size is None or batch_size > 0, "Batch size must be greater than 0 if provided"
        assert batch_size is None or sbu.batch_size(space) == batch_size, "Batch size must match the space's batch size if provided"
        self.single_space = space
        self.stacked_space = sbu.batch_space(space, maxlen) # (L, ...) or (L, B, ...)
        self.output_space = sbu.swap_batch_dims(
            self.stacked_space, 0, 1
        ) if batch_size is not None else self.stacked_space # (B, L, ...) or (L, ...)
        self._maxlen = maxlen
        self._batch_size = batch_size

    @property
    def maxlen(self) -> int:
        """Fixed queue length (the horizon dimension L)."""
        return self._maxlen

    @property
    def batch_size(self) -> Optional[int]:
        """Batch dimension B of the underlying space, or None if unbatched."""
        return self._batch_size

    @property
    def backend(self) -> ComputeBackend:
        """Compute backend of the wrapped space."""
        return self.single_space.backend

    @property
    def device(self) -> Optional[BDeviceType]:
        """Device of the wrapped space."""
        return self.single_space.device

    def init(
        self,
        initial_data : DataT,
    ) -> SpaceDataQueueState:
        """Create a fresh state with every slot filled with ``initial_data``."""
        return self.reset(
            SpaceDataQueueState(self.stacked_space.create_empty()),
            initial_data
        )

    def reset(
        self,
        state : SpaceDataQueueState,
        initial_data : DataT,
        mask : Optional[BArrayType] = None,
    ) -> SpaceDataQueueState:
        """Refill the queue with ``initial_data``; ``mask`` limits which batch
        entries are reset (batched queues only)."""
        # BUGFIX: the condition was `self.batch_size is None or mask is None`,
        # which is inverted relative to its own message — it rejected masks in
        # exactly the batched case where a per-batch-entry reset is meaningful.
        assert self.batch_size is not None or mask is None, \
            "Mask should not be provided if batch size is empty"
        index = (
            slice(None), mask
        ) if mask is not None else slice(None)

        expanded_data = sbu.get_at( # Add a singleton horizon dimension to the data
            self.single_space,
            initial_data,
            None
        )
        return state.replace(
            data=sbu.set_at(
                self.stacked_space,
                state.data,
                index,
                expanded_data
            )
        )

    def add(self, state : SpaceDataQueueState, data : DataT) -> SpaceDataQueueState:
        """Push ``data`` as the newest entry, dropping the oldest (roll left
        along the horizon axis, then overwrite the last slot)."""
        new_data = self.backend.map_fn_over_arrays(
            state.data,
            lambda x: self.backend.roll(x, shift=-1, axis=0),
        )
        new_data = sbu.set_at(
            self.stacked_space,
            new_data,
            -1,
            data
        )
        return state.replace(data=new_data)

    def get_output_data(self, state : SpaceDataQueueState) -> DataT:
        """Return the queued data in output layout (batch-major when batched)."""
        if self.batch_size is None:
            return state.data
        else:
            return sbu.swap_batch_dims_in_data(
                self.backend,
                state.data,
                0, 1
            ) # (L, B, ...) -> (B, L, ...)
112
+
113
class SpaceDataQueue(
    Generic[DataT, BArrayType, BDeviceType, BDtypeType, BRNGType]
):
    """Stateful convenience wrapper around :class:`FuncSpaceDataQueue`.

    Owns the queue state internally so callers do not have to thread it
    through ``reset`` / ``add`` / ``get_output_data`` themselves.
    """

    def __init__(
        self,
        space : Space[DataT, BDeviceType, BDtypeType, BRNGType],
        batch_size : Optional[int],
        maxlen: int,
    ) -> None:
        self.func_queue = FuncSpaceDataQueue(space, batch_size, maxlen)
        self.state = None  # populated by the first reset()

    @property
    def single_space(self) -> Space[DataT, BDeviceType, BDtypeType, BRNGType]:
        return self.func_queue.single_space

    @property
    def stacked_space(self) -> Space[DataT, BDeviceType, BDtypeType, BRNGType]:
        return self.func_queue.stacked_space

    @property
    def output_space(self) -> Space[DataT, BDeviceType, BDtypeType, BRNGType]:
        return self.func_queue.output_space

    @property
    def maxlen(self) -> int:
        return self.func_queue.maxlen

    @property
    def batch_size(self) -> Optional[int]:
        return self.func_queue.batch_size

    @property
    def backend(self) -> ComputeBackend:
        return self.func_queue.backend

    @property
    def device(self) -> Optional[BDeviceType]:
        return self.func_queue.device

    def reset(
        self,
        initial_data : DataT,
        mask : Optional[BArrayType] = None,
    ) -> None:
        """Initialise (first call) or re-initialise the queue contents."""
        if self.state is not None:
            self.state = self.func_queue.reset(self.state, initial_data, mask)
            return
        assert mask is None, "Mask should not be provided on the first reset"
        self.state = self.func_queue.init(initial_data)

    def add(self, data : DataT) -> None:
        """Append one entry; requires a prior ``reset``."""
        assert self.state is not None, "Data queue must be reset before adding data"
        self.state = self.func_queue.add(self.state, data)

    def get_output_data(self) -> DataT:
        """Return the queued data in output layout; requires a prior ``reset``."""
        assert self.state is not None, "Data queue must be reset before getting output data"
        return self.func_queue.get_output_data(self.state)
@@ -0,0 +1,46 @@
1
+ import dataclasses
2
+ from typing import TypeVar
3
+ import functools
4
+
5
+ __all__ = ['stateclass', 'field', 'StateClass']
6
+
7
def field(pytree_node=True, *, metadata=None, **kwargs):
    """``dataclasses.field`` wrapper recording ``pytree_node`` in the metadata.

    Any caller-supplied ``metadata`` is copied first, so the ``pytree_node``
    entry always wins and the caller's mapping is never mutated.
    """
    merged = dict(metadata) if metadata else {}
    merged['pytree_node'] = pytree_node
    return dataclasses.field(metadata=merged, **kwargs)
10
+
11
def stateclass(
    clz=None, /, **kwargs
):
    """Decorator turning ``clz`` into a frozen-by-default dataclass with ``.replace``.

    Usable bare (``@stateclass``) or with dataclass keyword arguments
    (``@stateclass(frozen=False)``). Applying it to a class that has already
    been processed is a no-op.
    """
    if clz is None:
        # Called with arguments only: return a decorator closing over kwargs.
        return functools.partial(stateclass, **kwargs) # type: ignore[bad-return-type]

    # check if already a stateclass
    if '_unienv_stateclass' in clz.__dict__:
        return clz

    # Freeze by default; callers may opt out with frozen=False.
    # (Was: `if 'frozen' not in kwargs.keys()` — setdefault is the idiom.)
    kwargs.setdefault('frozen', True)
    data_clz = dataclasses.dataclass(**kwargs)(clz) # type: ignore

    def replace(self, **updates):
        """Returns a new object replacing the specified fields with new values."""
        return dataclasses.replace(self, **updates)

    data_clz.replace = replace

    # add a _unienv_stateclass flag to distinguish from regular dataclasses
    data_clz._unienv_stateclass = True # type: ignore[attr-defined]

    return data_clz # type: ignore
35
+
36
TNode = TypeVar('TNode', bound='StateClass')

class StateClass:
    """Base class whose subclasses are automatically converted to stateclasses.

    Subclassing triggers ``stateclass(cls, **kwargs)`` in ``__init_subclass__``,
    so every subclass becomes a (frozen-by-default) dataclass with a working
    ``replace``. The stubs below exist only to give type checkers a signature;
    the decorator installs the real implementations on each subclass.
    """

    def __init_subclass__(cls, **kwargs):
        # Decorates the subclass in place; extra class keywords (e.g.
        # frozen=False) are forwarded to dataclasses.dataclass via stateclass.
        stateclass(cls, **kwargs) # pytype: disable=wrong-arg-types

    def __init__(self, *args, **kwargs):
        # Placeholder; replaced by the dataclass-generated __init__ on subclasses.
        raise NotImplementedError

    def replace(self: TNode, **overrides) -> TNode:
        # Placeholder; replaced by the stateclass-installed replace on subclasses.
        raise NotImplementedError
@@ -0,0 +1,15 @@
1
+ import cloudpickle
2
+ from typing import Callable
3
+
4
class MultiProcessFn:
    """Callable wrapper that survives crossing process boundaries.

    Plain ``pickle`` cannot serialize lambdas or closures; this wrapper
    delegates (de)serialization of the wrapped callable to ``cloudpickle``
    via the pickle state hooks.
    """

    def __init__(self, fn : Callable):
        self.fn = fn

    def __getstate__(self):
        # Serialize the callable itself rather than this wrapper's __dict__.
        return cloudpickle.dumps(self.fn)

    def __setstate__(self, state):
        self.fn = cloudpickle.loads(state)

    def __call__(self, *args, **kwargs):
        target = self.fn
        return target(*args, **kwargs)
@@ -1,4 +1,6 @@
1
1
  from .world import World, RealWorld
2
2
  from .node import WorldNode
3
+ from .combined_node import CombinedWorldNode
3
4
  from .funcworld import FuncWorld
4
- from .funcnode import FuncWorldNode
5
+ from .funcnode import FuncWorldNode
6
+ from .combined_funcnode import CombinedFuncWorldNode
@@ -0,0 +1,336 @@
1
+ from typing import Optional, Dict, Any, Tuple, Union, Iterable, Mapping
2
+
3
+ from unienv_interface.backends import BArrayType, BDeviceType, BDtypeType, BRNGType
4
+ from unienv_interface.space import Space, DictSpace
5
+
6
+ from .funcnode import FuncWorldNode
7
+ from .funcworld import WorldStateT
8
+
9
+ CombinedDataT = Union[Dict[str, Any], Any]
10
+ CombinedNodeStateT = Dict[str, Any]
11
+
12
class CombinedFuncWorldNode(FuncWorldNode[
    WorldStateT, CombinedNodeStateT,
    Optional[CombinedDataT], # Context type (can be None)
    CombinedDataT, # Observation type
    CombinedDataT, # Action type
    BArrayType, BDeviceType, BDtypeType, BRNGType
]):
    """A functional counterpart to `CombinedWorldNode` that composes multiple `FuncWorldNode`s.

    It aggregates spaces (context, observation, action) and runtime data (context, observation, info, reward, termination, truncation)
    across child nodes. If only one child exposes a given interface and `direct_return=True`, the value is passed through directly.

    Node state is a plain dict keyed by child node name; all lifecycle
    methods copy it shallowly before mutating so the input dict is unchanged.
    """

    def __init__(
        self,
        name: str,
        nodes: Iterable[FuncWorldNode[WorldStateT, Any, Any, Any, Any, BArrayType, BDeviceType, BDtypeType, BRNGType]],
        direct_return: bool = True,
    ):
        """
        Args:
            name: Name of this combined node.
            nodes: Child nodes; must be non-empty, share one world and one
                control timestep, and have unique names.
            direct_return: If True, a space/data exposed by exactly one child
                is passed through unwrapped instead of nested in a dict.

        Raises:
            ValueError: If ``nodes`` is empty or names are not unique.
        """
        nodes = list(nodes)
        if len(nodes) == 0:
            raise ValueError("At least one node is required to create a CombinedFuncWorldNode.")

        first_node = nodes[0]
        # Ensure all nodes share the same world & control timestep
        for node in nodes[1:]:
            assert node.world is first_node.world, "All nodes must belong to the same world." \
                f" Mismatch between {first_node.name} and {node.name}."
            assert node.control_timestep == first_node.control_timestep, "All nodes must have the same control timestep." \
                f" Mismatch between {first_node.name} and {node.name}."

        names = [node.name for node in nodes]
        if len(names) != len(set(names)):
            raise ValueError("All nodes must have unique names.")

        self.nodes = nodes

        # Aggregate spaces similar to `CombinedWorldNode`
        _, self.context_space = self.aggregate_spaces(
            {node.name: node.context_space for node in nodes if node.context_space is not None},
            direct_return=direct_return,
        )
        _, self.observation_space = self.aggregate_spaces(
            {node.name: node.observation_space for node in nodes if node.observation_space is not None},
            direct_return=direct_return,
        )
        # Remember which single node (if any) receives actions directly, so
        # set_next_action can route a non-dict action to it.
        self._action_node_name_direct, self.action_space = self.aggregate_spaces(
            {node.name: node.action_space for node in nodes if node.action_space is not None},
            direct_return=direct_return,
        )

        self.has_reward = any(node.has_reward for node in nodes)
        self.has_termination_signal = any(node.has_termination_signal for node in nodes)
        self.has_truncation_signal = any(node.has_truncation_signal for node in nodes)

        self.name = name
        self.direct_return = direct_return

    # ========== Helper aggregation methods ==========
    @staticmethod
    def aggregate_spaces(
        spaces: Dict[str, Optional[Space[Any, BDeviceType, BDtypeType, BRNGType]]],
        direct_return: bool = True,
    ) -> Tuple[Optional[str], Optional[DictSpace[BDeviceType, BDtypeType, BRNGType]]]:
        """Combine per-node spaces into one space.

        Returns ``(node_name, space)`` pass-through for a single entry with
        ``direct_return``, ``(None, None)`` for no entries, and
        ``(None, DictSpace)`` otherwise.
        """
        if len(spaces) == 0:
            return None, None
        elif len(spaces) == 1 and direct_return:
            return next(iter(spaces.items()))
        else:
            backend = next(iter(spaces.values())).backend
            return None, DictSpace(
                backend,
                {name: space for name, space in spaces.items() if space is not None},
            )

    @staticmethod
    def aggregate_data(
        data: Dict[str, Any],
        direct_return: bool = True,
    ) -> Optional[Union[Dict[str, Any], Any]]:
        """Combine per-node data mirroring :meth:`aggregate_spaces`: None for
        empty, the single value for one entry with ``direct_return``, else dict."""
        if len(data) == 0:
            return None
        elif len(data) == 1 and direct_return:
            return next(iter(data.values()))
        else:
            return data

    # ========== properties ==========
    @property
    def world(self): # type: ignore[override]
        # All children were checked to share one world in __init__.
        return self.nodes[0].world

    @property
    def control_timestep(self): # type: ignore[override]
        # All children were checked to share one control timestep in __init__.
        return self.nodes[0].control_timestep

    # ========== Lifecycle methods ==========
    def initial(
        self,
        world_state: WorldStateT,
        *,
        seed: Optional[int] = None,
        pernode_kwargs: Optional[Dict[str, Dict[str, Any]]] = None,
    ) -> Tuple[WorldStateT, CombinedNodeStateT]:
        """Initialize every child node; ``pernode_kwargs`` maps node name to
        extra keyword arguments for that child's ``initial``.

        BUGFIX: ``pernode_kwargs`` previously defaulted to a shared mutable
        ``{}``; it now defaults to None (same observable behavior).
        """
        pernode_kwargs = pernode_kwargs or {}
        node_states: CombinedNodeStateT = {}
        for node in self.nodes:
            world_state, node_state = node.initial(world_state, seed=seed, **pernode_kwargs.get(node.name, {}))
            node_states[node.name] = node_state
        return world_state, node_states

    def reset(
        self,
        world_state: WorldStateT,
        node_state: CombinedNodeStateT,
        *,
        seed: Optional[int] = None,
        mask: Optional[BArrayType] = None,
        pernode_kwargs: Optional[Dict[str, Dict[str, Any]]] = None,
        **kwargs,
    ) -> Tuple[WorldStateT, CombinedNodeStateT]:
        """Reset every child node, forwarding ``seed``/``mask`` plus any
        per-node kwargs. ``**kwargs`` is accepted for signature compatibility.

        BUGFIX: ``pernode_kwargs`` previously defaulted to a shared mutable
        ``{}``; it now defaults to None (same observable behavior).
        NOTE(review): extra ``**kwargs`` are not forwarded to children —
        confirm this is intended.
        """
        pernode_kwargs = pernode_kwargs or {}
        node_state = node_state.copy()
        for node in self.nodes:
            ns = node_state[node.name]
            world_state, ns = node.reset(
                world_state,
                ns,
                seed=seed,
                mask=mask,
                **pernode_kwargs.get(node.name, {}),
            )
            node_state[node.name] = ns
        return world_state, node_state

    def after_reset(
        self,
        world_state: WorldStateT,
        node_state: CombinedNodeStateT,
        *,
        mask: Optional[BArrayType] = None,
    ) -> Tuple[
        WorldStateT,
        CombinedNodeStateT,
        Optional[CombinedDataT],
        Optional[CombinedDataT],
        Optional[Dict[str, Any]],
    ]:
        """Run each child's ``after_reset`` and aggregate the returned
        context/observation (honoring ``direct_return``) and info (always a dict)."""
        node_state = node_state.copy()
        contexts: Dict[str, Any] = {}
        observations: Dict[str, Any] = {}
        infos: Dict[str, Any] = {}

        for node in self.nodes:
            ns = node_state[node.name]
            world_state, ns, ctx, obs, info = node.after_reset(world_state, ns, mask=mask)
            node_state[node.name] = ns
            if ctx is not None:
                contexts[node.name] = ctx
            if obs is not None:
                observations[node.name] = obs
            if info is not None:
                infos[node.name] = info

        return (
            world_state,
            node_state,
            self.aggregate_data(contexts, direct_return=self.direct_return),
            self.aggregate_data(observations, direct_return=self.direct_return),
            self.aggregate_data(infos, direct_return=False),
        )

    def pre_environment_step(
        self,
        world_state: WorldStateT,
        node_state: CombinedNodeStateT,
        dt: Union[float, BArrayType],
    ) -> Tuple[WorldStateT, CombinedNodeStateT]:
        """Fan ``pre_environment_step`` out to every child node."""
        node_state = node_state.copy()
        for node in self.nodes:
            ns = node_state[node.name]
            world_state, ns = node.pre_environment_step(world_state, ns, dt)
            node_state[node.name] = ns
        return world_state, node_state

    def set_next_action(
        self,
        world_state: WorldStateT,
        node_state: CombinedNodeStateT,
        action: CombinedDataT,
    ) -> Tuple[WorldStateT, CombinedNodeStateT]:
        """Route the action: directly to the single actionable child when one
        exists (direct_return), otherwise split a mapping by node name."""
        assert self.action_space is not None, "Action space is None, cannot set action."

        node_state = node_state.copy()
        if self._action_node_name_direct is not None:
            # Only one actionable node
            for node in self.nodes:
                if node.name == self._action_node_name_direct:
                    ns = node_state[node.name]
                    world_state, ns = node.set_next_action(world_state, ns, action) # type: ignore[arg-type]
                    node_state[node.name] = ns
                    break
        else:
            assert isinstance(action, Mapping), "Action must be a mapping when there are multiple action spaces."
            for node in self.nodes:
                if node.action_space is not None:
                    assert node.name in action, f"Action for node {node.name} is missing."
                    ns = node_state[node.name]
                    world_state, ns = node.set_next_action(world_state, ns, action[node.name])
                    node_state[node.name] = ns
        return world_state, node_state

    def post_environment_step(
        self,
        world_state: WorldStateT,
        node_state: CombinedNodeStateT,
        dt: Union[float, BArrayType],
    ) -> Tuple[WorldStateT, CombinedNodeStateT]:
        """Fan ``post_environment_step`` out to every child node."""
        node_state = node_state.copy()
        for node in self.nodes:
            ns = node_state[node.name]
            world_state, ns = node.post_environment_step(world_state, ns, dt)
            node_state[node.name] = ns
        return world_state, node_state

    def close(self, world_state: WorldStateT, node_state: CombinedNodeStateT) -> WorldStateT: # type: ignore[override]
        """Close every child node, threading the world state through."""
        for node in self.nodes:
            world_state = node.close(world_state, node_state[node.name])
        return world_state

    # ========== Data accessors ==========
    def get_observation(
        self,
        world_state: WorldStateT,
        node_state: CombinedNodeStateT,
    ) -> CombinedDataT:
        """Aggregate observations from all observing children."""
        assert self.observation_space is not None, "Observation space is None, cannot get observation."
        return self.aggregate_data(
            {
                node.name: node.get_observation(world_state, node_state[node.name])
                for node in self.nodes
                if node.observation_space is not None
            },
            direct_return=self.direct_return,
        )

    def get_reward(
        self,
        world_state: WorldStateT,
        node_state: CombinedNodeStateT,
    ) -> Union[float, BArrayType]:
        """Sum of the rewards of all reward-providing children."""
        assert self.has_reward, "This node does not provide a reward."
        if self.world.batch_size is None:
            # Unbatched: plain Python sum over scalar rewards.
            return sum(
                node.get_reward(world_state, node_state[node.name])
                for node in self.nodes
                if node.has_reward
            )
        # Batched: accumulate into a zero-initialized backend array.
        # NOTE(review): assumes `backend`/`device` come from the FuncWorldNode
        # base class — confirm.
        rewards = self.backend.zeros(
            (self.world.batch_size,),
            dtype=self.backend.default_floating_dtype,
            device=self.device,
        )
        for node in self.nodes:
            if node.has_reward:
                rewards = rewards + node.get_reward(world_state, node_state[node.name])
        return rewards

    def get_termination(
        self,
        world_state: WorldStateT,
        node_state: CombinedNodeStateT,
    ) -> Union[bool, BArrayType]:
        """Logical OR of the termination signals of all providing children."""
        assert self.has_termination_signal, "This node does not provide a termination signal."
        if self.world.batch_size is None:
            return any(
                node.get_termination(world_state, node_state[node.name])
                for node in self.nodes
                if node.has_termination_signal
            )
        terminations = self.backend.zeros(
            (self.world.batch_size,),
            dtype=self.backend.default_boolean_dtype,
            device=self.device,
        )
        for node in self.nodes:
            if node.has_termination_signal:
                terminations = self.backend.logical_or(
                    terminations, node.get_termination(world_state, node_state[node.name])
                )
        return terminations

    def get_truncation(
        self,
        world_state: WorldStateT,
        node_state: CombinedNodeStateT,
    ) -> Union[bool, BArrayType]:
        """Logical OR of the truncation signals of all providing children."""
        assert self.has_truncation_signal, "This node does not provide a truncation signal."
        if self.world.batch_size is None:
            return any(
                node.get_truncation(world_state, node_state[node.name])
                for node in self.nodes
                if node.has_truncation_signal
            )
        truncations = self.backend.zeros(
            (self.world.batch_size,),
            dtype=self.backend.default_boolean_dtype,
            device=self.device,
        )
        for node in self.nodes:
            if node.has_truncation_signal:
                truncations = self.backend.logical_or(
                    truncations, node.get_truncation(world_state, node_state[node.name])
                )
        return truncations

    def get_info(
        self,
        world_state: WorldStateT,
        node_state: CombinedNodeStateT,
    ) -> Optional[Dict[str, Any]]:
        """Collect per-node infos keyed by node name (never direct-returned)."""
        infos: Dict[str, Any] = {}
        for node in self.nodes:
            info = node.get_info(world_state, node_state[node.name])
            if info is not None:
                infos[node.name] = info
        return self.aggregate_data(infos, direct_return=False) # Always dict if not empty