multi-agent-rlenv 3.3.7__py3-none-any.whl → 3.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- marlenv/__init__.py +11 -13
- marlenv/adapters/gym_adapter.py +6 -16
- marlenv/adapters/overcooked_adapter.py +43 -23
- marlenv/adapters/pettingzoo_adapter.py +5 -5
- marlenv/adapters/pymarl_adapter.py +3 -4
- marlenv/adapters/smac_adapter.py +6 -6
- marlenv/env_builder.py +8 -9
- marlenv/env_pool.py +5 -7
- marlenv/mock_env.py +7 -7
- marlenv/models/__init__.py +2 -4
- marlenv/models/env.py +23 -12
- marlenv/models/episode.py +17 -20
- marlenv/models/spaces.py +90 -83
- marlenv/models/transition.py +6 -10
- marlenv/utils/__init__.py +10 -0
- marlenv/utils/schedule.py +281 -0
- marlenv/wrappers/agent_id_wrapper.py +4 -5
- marlenv/wrappers/available_actions_mask.py +6 -7
- marlenv/wrappers/available_actions_wrapper.py +7 -9
- marlenv/wrappers/blind_wrapper.py +5 -7
- marlenv/wrappers/centralised.py +12 -14
- marlenv/wrappers/delayed_rewards.py +13 -11
- marlenv/wrappers/last_action_wrapper.py +10 -14
- marlenv/wrappers/paddings.py +6 -8
- marlenv/wrappers/penalty_wrapper.py +5 -8
- marlenv/wrappers/rlenv_wrapper.py +12 -9
- marlenv/wrappers/time_limit.py +3 -3
- marlenv/wrappers/video_recorder.py +4 -6
- {multi_agent_rlenv-3.3.7.dist-info → multi_agent_rlenv-3.5.0.dist-info}/METADATA +1 -1
- multi_agent_rlenv-3.5.0.dist-info/RECORD +39 -0
- multi_agent_rlenv-3.3.7.dist-info/RECORD +0 -37
- {multi_agent_rlenv-3.3.7.dist-info → multi_agent_rlenv-3.5.0.dist-info}/WHEEL +0 -0
- {multi_agent_rlenv-3.3.7.dist-info → multi_agent_rlenv-3.5.0.dist-info}/licenses/LICENSE +0 -0
marlenv/models/episode.py
CHANGED
@@ -1,6 +1,6 @@
 from dataclasses import dataclass
 from functools import cached_property
-from typing import Any, Callable,
+from typing import Any, Callable, Optional, Sequence, overload
 
 import numpy as np
 import numpy.typing as npt
@@ -14,11 +14,8 @@ from .env import MARLEnv
 from marlenv.exceptions import EnvironmentMismatchException, ReplayMismatchException
 
 
-A = TypeVar("A")
-
-
 @dataclass
-class Episode(Generic[A]):
+class Episode:
     """Episode model made of observations, actions, rewards, ..."""
 
     all_observations: list[npt.NDArray[np.float32]]
@@ -55,7 +52,7 @@ class Episode(Generic[A]):
         )
 
     @staticmethod
-    def from_transitions(transitions: Sequence[Transition
+    def from_transitions(transitions: Sequence[Transition]) -> "Episode":
         """Create an episode from a list of transitions"""
         episode = Episode.new(transitions[0].obs, transitions[0].state)
         for transition in transitions:
@@ -179,9 +176,9 @@ class Episode(Generic[A]):
     @cached_property
     def dones(self):
         """The done flags for each transition"""
-        dones = np.zeros_like(self.rewards, dtype=np.
+        dones = np.zeros_like(self.rewards, dtype=np.bool)
         if self.is_done:
-            dones[self.episode_len - 1 :] =
+            dones[self.episode_len - 1 :] = True
         return dones
 
     @property
@@ -214,11 +211,11 @@ class Episode(Generic[A]):
 
     def replay(
         self,
-        env: MARLEnv
+        env: MARLEnv,
         seed: Optional[int] = None,
         *,
-        after_reset: Optional[Callable[[Observation, State, MARLEnv
-        after_step: Optional[Callable[[int, Step, MARLEnv
+        after_reset: Optional[Callable[[Observation, State, MARLEnv], None]] = None,
+        after_step: Optional[Callable[[int, Step, MARLEnv], None]] = None,
     ):
         """
         Replay the episode in the environment (i.e. perform the actions) and assert that the outcomes match.
@@ -243,12 +240,12 @@ class Episode(Generic[A]):
                 raise ReplayMismatchException("observation", step.obs.data, self.next_obs[i], time_step=i)
             if not np.array_equal(step.state.data, self.next_states[i]):
                 raise ReplayMismatchException("state", step.state.data, self.next_states[i], time_step=i)
-            if not np.
+            if not np.isclose(step.reward, self.rewards[i]):
                 raise ReplayMismatchException("reward", step.reward, self.rewards[i], time_step=i)
             if after_step is not None:
                 after_step(i, step, env)
 
-    def get_images(self, env: MARLEnv
+    def get_images(self, env: MARLEnv, seed: Optional[int] = None) -> list[np.ndarray]:
        images = []
 
        def collect_image(*_, **__):
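The replay hooks now receive the non-generic MARLEnv directly and are keyword-only after the `*` marker. A minimal usage sketch, where `env` and `episode` are hypothetical placeholders for an environment and a previously recorded episode (not part of the diff):

    # `env` and `episode` are hypothetical placeholders, not defined in this diff.
    def log_reward(t, step, env):
        # Called after every replayed env.step(); `step` is the Step it returned.
        print(f"t={t}: reward={step.reward}")

    # Raises ReplayMismatchException if any observation, state or reward diverges.
    episode.replay(env, seed=42, after_step=log_reward)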
@@ -257,7 +254,7 @@ class Episode(Generic[A]):
         self.replay(env, seed, after_reset=collect_image, after_step=collect_image)
         return images
 
-    def render(self, env: MARLEnv
+    def render(self, env: MARLEnv, seed: Optional[int] = None, fps: int = 5):
         def render_callback(*_, **__):
             env.render()
             cv2.waitKey(1000 // fps)
@@ -288,10 +285,10 @@ class Episode(Generic[A]):
         return returns
 
     @overload
-    def add(self, transition: Transition
+    def add(self, transition: Transition, /): ...
 
     @overload
-    def add(self, step: Step, action:
+    def add(self, step: Step, action: np.ndarray, /): ...
 
     def add(self, *data):
         match data:
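The two `add` overloads are now positional-only (`/`) and the action type is pinned to `np.ndarray`. Both call shapes, with `episode`, `transition`, `step` and `action` standing in for values from a rollout loop (hypothetical names):

    episode.add(transition)    # 1-argument form: a complete Transition
    episode.add(step, action)  # 2-argument form: a Step plus the np.ndarray action taken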
@@ -322,10 +319,10 @@ class Episode(Generic[A]):
 
     def add_data(
         self,
-        next_obs,
-        next_state,
-        action:
-        reward: np.
+        next_obs: Observation,
+        next_state: State,
+        action: np.ndarray,
+        reward: npt.NDArray[np.float32],
         others: dict[str, Any],
         done: bool,
         truncated: bool,
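With the `TypeVar` removed, `Episode.from_transitions` returns a plain `Episode`, and `dones` is now a boolean array. A sketch under that reading, where `transitions` is a hypothetical list of recorded `Transition` objects:

    # `transitions` is a hypothetical list of Transition objects.
    episode = Episode.from_transitions(transitions)
    assert episode.dones.dtype == np.bool              # done flags are boolean now
    assert bool(episode.dones[-1]) == episode.is_done  # last flag mirrors is_done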
marlenv/models/spaces.py
CHANGED
@@ -1,13 +1,11 @@
 import math
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
-from typing import
+from typing import Optional
 
 import numpy as np
 import numpy.typing as npt
 
-S = TypeVar("S", bound="Space")
-
 
 @dataclass
 class Space(ABC):
@@ -23,7 +21,7 @@ class Space(ABC):
         self.labels = labels
 
     @abstractmethod
-    def sample(self, mask: Optional[npt.NDArray[np.bool_]] = None) ->
+    def sample(self, mask: Optional[npt.NDArray[np.bool_]] = None) -> npt.NDArray[np.float32]:
         """Sample a value from the space."""
 
     def __eq__(self, value: object) -> bool:
@@ -34,6 +32,16 @@
     def __ne__(self, value: object) -> bool:
         return not self.__eq__(value)
 
+    @property
+    @abstractmethod
+    def is_discrete(self) -> bool:
+        """Whether the space is discrete."""
+
+    @property
+    def is_continuous(self) -> bool:
+        """Whether the space is continuous."""
+        return not self.is_discrete
+
 
 @dataclass
 class DiscreteSpace(Space):
@@ -45,8 +53,8 @@ class DiscreteSpace(Space):
         self.size = size
         self.space = np.arange(size)
 
-    def sample(self, mask: Optional[npt.NDArray[np.
-        space = self.space
+    def sample(self, mask: Optional[npt.NDArray[np.bool]] = None):
+        space = self.space.copy()
         if mask is not None:
             space = space[mask]
         return int(np.random.choice(space))
@@ -58,6 +66,25 @@
             return False
         return super().__eq__(value)
 
+    @property
+    def is_discrete(self) -> bool:
+        return True
+
+    @staticmethod
+    def action(size, labels: Optional[list[str]] = None):
+        """
+        Create a discrete action space where the default labels are set to "Action-n".
+        """
+        if labels is None:
+            labels = [f"Action {i}" for i in range(size)]
+        return DiscreteSpace(size, labels)
+
+    def repeat(self, n: int):
+        """
+        Repeat the discrete space n times.
+        """
+        return MultiDiscreteSpace(*([self] * n), labels=self.labels)
+
 
 @dataclass
 class MultiDiscreteSpace(Space):
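The new `DiscreteSpace.action` factory and `repeat` helper appear to cover part of what the removed `*ActionSpace` classes provided (see the deletions further down). A self-contained sketch, assuming the import path matches the file layout:

    from marlenv.models.spaces import DiscreteSpace

    move = DiscreteSpace.action(4)   # default labels "Action 0" ... "Action 3"
    assert move.is_discrete and not move.is_continuous
    joint = move.repeat(3)           # MultiDiscreteSpace built from 3 copies of `move`
    actions = joint.sample()         # np.int32 array with one entry per copy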
@@ -75,10 +102,10 @@ class MultiDiscreteSpace(Space):
     def from_sizes(cls, *sizes: int):
         return cls(*(DiscreteSpace(size) for size in sizes))
 
-    def sample(self,
-        if
+    def sample(self, mask: Optional[npt.NDArray[np.bool] | list[npt.NDArray[np.bool]]] = None):
+        if mask is None:
             return np.array([space.sample() for space in self.spaces], dtype=np.int32)
-        return np.array([space.sample(mask) for mask, space in zip(
+        return np.array([space.sample(mask=mask) for mask, space in zip(mask, self.spaces)], dtype=np.int32)
 
     def __eq__(self, value: object) -> bool:
         if not isinstance(value, MultiDiscreteSpace):
@@ -90,6 +117,10 @@ class MultiDiscreteSpace(Space):
             return False
         return super().__eq__(value)
 
+    @property
+    def is_discrete(self) -> bool:
+        return True
+
 
 @dataclass
 class ContinuousSpace(Space):
@@ -100,23 +131,35 @@ class ContinuousSpace(Space):
     high: npt.NDArray[np.float32]
     """Upper bound of the space for each dimension."""
 
-    def __init__(
-        self,
-        low: int | float | list | npt.NDArray[np.float32],
-        high: int | float | list | npt.NDArray[np.float32],
+    def __init__(
+        self,
+        low: int | float | list | npt.NDArray[np.float32] | None,
+        high: int | float | list | npt.NDArray[np.float32] | None,
         labels: Optional[list[str]] = None,
     ):
         match low:
+            case None:
+                assert high is not None, "If low is None, high must be set to infer the shape."
+                shape = ContinuousSpace.get_shape(high)
+                low = np.full(shape, -np.inf, dtype=np.float32)
             case list():
                 low = np.array(low, dtype=np.float32)
             case float() | int():
                 low = np.array([low], dtype=np.float32)
         match high:
+            case None:
+                assert low is not None, "If high is None, low must be set to infer the shape."
+                shape = ContinuousSpace.get_shape(low)
+                high = np.full(shape, np.inf, dtype=np.float32)
             case list():
                 high = np.array(high, dtype=np.float32)
             case float() | int():
                 high = np.array([high], dtype=np.float32)
-
+        assert low.shape == high.shape, f"Low and high must have the same shape. Low shape: {low.shape}, high shape: {high.shape}"
+        assert np.all(low <= high), "All elements in low must be less than the corresponding elements in high."
+        Space.__init__(self, low.shape, labels)
+        self.low = low
+        self.high = high
 
     @staticmethod
     def from_shape(
@@ -143,20 +186,24 @@ class ContinuousSpace(Space):
         high = np.array(high, dtype=np.float32)
         return ContinuousSpace(low, high, labels)
 
-    def
-
-
-
-
-
-
-
-        self.low = low
-        self.high = high
+    def clamp(self, action: np.ndarray | list):
+        """Clamp the action to the bounds of the space."""
+        if isinstance(action, list):
+            action = np.array(action)
+        return np.clip(action, self.low, self.high)
+
+    def sample(self) -> npt.NDArray[np.float32]:
+        r = np.random.random(self.shape) * (self.high - self.low) + self.low
+        return r.astype(np.float32)
 
-
-
+    @staticmethod
+    def get_shape(item: float | int | list | npt.NDArray[np.float32]) -> tuple[int, ...]:
+        """Get the shape of the item."""
+        if isinstance(item, list):
+            item = np.array(item)
+        if isinstance(item, np.ndarray):
+            return item.shape
+        return (1,)
 
     def __eq__(self, value: object) -> bool:
         if not isinstance(value, ContinuousSpace):
@@ -167,59 +214,19 @@ class ContinuousSpace(Space):
             return False
         return super().__eq__(value)
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-    def sample(self, mask: np.ndarray | None = None):
-        res = []
-        for i in range(self.n_agents):
-            if mask is not None:
-                m = mask[i]
-            else:
-                m = None
-            res.append(self.individual_action_space.sample(m))
-        return np.array(res)
-
-    def __eq__(self, value: object) -> bool:
-        if not isinstance(value, ActionSpace):
-            return False
-        if self.n_agents != value.n_agents:
-            return False
-        if self.n_actions != value.n_actions:
-            return False
-        if self.individual_action_space != value.individual_action_space:
-            return False
-        return super().__eq__(value)
-
-
-@dataclass
-class DiscreteActionSpace(ActionSpace[DiscreteSpace]):
-    def __init__(self, n_agents: int, n_actions: int, action_names: Optional[list[str]] = None):
-        individual_action_space = DiscreteSpace(n_actions, action_names)
-        super().__init__(n_agents, individual_action_space, action_names)
-
-
-@dataclass
-class MultiDiscreteActionSpace(ActionSpace[MultiDiscreteSpace]):
-    pass
-
-
-@dataclass
-class ContinuousActionSpace(ActionSpace[ContinuousSpace]):
-    def __init__(self, n_agents: int, low: np.ndarray | list, high: np.ndarray | list, action_names: list | None = None):
-        space = ContinuousSpace.from_bounds(low, high, action_names)
-        super().__init__(n_agents, space, action_names)
+    def repeat(self, n: int):
+        """
+        Repeat the continuous space n times to become of shape (n, *shape).
+        """
+        low = np.tile(self.low, (n, 1))
+        high = np.tile(self.high, (n, 1))
+        return ContinuousSpace.from_shape(
+            (n, *self.shape),
+            low=low,
+            high=high,
+            labels=self.labels,
+        )
+
+    @property
+    def is_discrete(self) -> bool:
+        return False
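The reworked `ContinuousSpace.__init__` now accepts `None` for either bound, filling it with ±inf at the other bound's shape; `clamp` and the mask-less `sample` are new. A sketch under that reading of the diff (illustrative values, import path assumed from the file layout):

    from marlenv.models.spaces import ContinuousSpace

    box = ContinuousSpace(low=[0.0, -1.0], high=[1.0, 1.0])
    a = box.sample()                        # float32, uniform in [low, high)
    clamped = box.clamp([2.0, -5.0])        # -> [1.0, -1.0], clipped to the bounds
    half_open = ContinuousSpace(0.0, None)  # high filled with +inf at low's shape (1,)
    assert box.is_continuous and not box.is_discrete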
marlenv/models/transition.py
CHANGED
@@ -1,6 +1,5 @@
 from dataclasses import dataclass
-from typing import Any,
-from typing_extensions import TypeVar
+from typing import Any, Sequence
 
 import numpy as np
 import numpy.typing as npt
@@ -10,16 +9,13 @@ from .state import State
 from .step import Step
 
 
-A = TypeVar("A", default=np.ndarray)
-
-
 @dataclass
-class Transition(Generic[A]):
+class Transition:
     """Transition model"""
 
     obs: Observation
     state: State
-    action: A
+    action: np.ndarray
     reward: npt.NDArray[np.float32]
     done: bool
     info: dict[str, Any]
@@ -32,7 +28,7 @@ class Transition(Generic[A]):
         self,
         obs: Observation,
         state: State,
-        action:
+        action: np.ndarray | Sequence[float],
         reward: npt.NDArray[np.float32] | float | Sequence[float],
         done: bool,
         info: dict[str, Any],
@@ -65,14 +61,14 @@ class Transition(Generic[A]):
     def from_step(
         prev_obs: Observation,
         prev_state: State,
-
+        action: np.ndarray | Sequence[float],
         step: Step,
         **kwargs,
     ):
         return Transition(
             obs=prev_obs,
             state=prev_state,
-            action=
+            action=action,
             reward=step.reward,
             done=step.done,
             info=step.info,
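Since `Transition` is no longer generic, the action travels as a plain `np.ndarray` (sequences of floats are also accepted, per the annotations). A sketch of the `from_step` path, where `obs`, `state` and `step` stand in for values produced by an environment (hypothetical names):

    import numpy as np

    # `obs`, `state` and `step` are hypothetical values from an earlier
    # env.reset() / env.step() call; they are not defined in this diff.
    t = Transition.from_step(
        prev_obs=obs,                # Observation before the step
        prev_state=state,            # State before the step
        action=np.array([1, 0, 2]),  # one action per agent
        step=step,                   # Step returned by env.step(action)
    )
    assert isinstance(t.action, np.ndarray)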