multi-agent-rlenv 3.3.7__py3-none-any.whl → 3.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- marlenv/__init__.py +11 -13
- marlenv/adapters/gym_adapter.py +6 -16
- marlenv/adapters/overcooked_adapter.py +43 -23
- marlenv/adapters/pettingzoo_adapter.py +5 -5
- marlenv/adapters/pymarl_adapter.py +3 -4
- marlenv/adapters/smac_adapter.py +6 -6
- marlenv/env_builder.py +8 -9
- marlenv/env_pool.py +5 -7
- marlenv/mock_env.py +7 -7
- marlenv/models/__init__.py +2 -4
- marlenv/models/env.py +23 -12
- marlenv/models/episode.py +17 -20
- marlenv/models/spaces.py +90 -83
- marlenv/models/transition.py +6 -10
- marlenv/utils/__init__.py +10 -0
- marlenv/utils/schedule.py +281 -0
- marlenv/wrappers/agent_id_wrapper.py +4 -5
- marlenv/wrappers/available_actions_mask.py +6 -7
- marlenv/wrappers/available_actions_wrapper.py +7 -9
- marlenv/wrappers/blind_wrapper.py +5 -7
- marlenv/wrappers/centralised.py +12 -14
- marlenv/wrappers/delayed_rewards.py +13 -11
- marlenv/wrappers/last_action_wrapper.py +10 -14
- marlenv/wrappers/paddings.py +6 -8
- marlenv/wrappers/penalty_wrapper.py +5 -8
- marlenv/wrappers/rlenv_wrapper.py +12 -9
- marlenv/wrappers/time_limit.py +3 -3
- marlenv/wrappers/video_recorder.py +4 -6
- {multi_agent_rlenv-3.3.7.dist-info → multi_agent_rlenv-3.5.0.dist-info}/METADATA +1 -1
- multi_agent_rlenv-3.5.0.dist-info/RECORD +39 -0
- multi_agent_rlenv-3.3.7.dist-info/RECORD +0 -37
- {multi_agent_rlenv-3.3.7.dist-info → multi_agent_rlenv-3.5.0.dist-info}/WHEEL +0 -0
- {multi_agent_rlenv-3.3.7.dist-info → multi_agent_rlenv-3.5.0.dist-info}/licenses/LICENSE +0 -0
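Before the per-file hunks, note the migration that repeats across every wrapper below: the separate action TypeVar `A` is dropped, wrappers become generic over a single `AS` TypeVar whose bound changes from `ActionSpace` to the new `Space`, and `step()` loses its action-type annotation. A minimal sketch of the new shape, not taken from the diff: `EchoWrapper` is hypothetical, and the absolute import paths are assumptions based on the relative imports in the hunks.

from dataclasses import dataclass

from typing_extensions import TypeVar

from marlenv.models import MARLEnv, Space
from marlenv.wrappers.rlenv_wrapper import RLEnvWrapper

AS = TypeVar("AS", bound=Space, default=Space)  # 3.3.7: bound=ActionSpace


@dataclass
class EchoWrapper(RLEnvWrapper[AS]):  # 3.3.7: RLEnvWrapper[A, AS]
    """Hypothetical pass-through wrapper in the new single-TypeVar style."""

    def __init__(self, env: MARLEnv[AS]):  # 3.3.7: MARLEnv[A, AS]
        super().__init__(env)

    def step(self, actions):  # 3.3.7: def step(self, actions: A)
        return self.wrapped.step(actions)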
marlenv/utils/schedule.py
ADDED
@@ -0,0 +1,281 @@
+from abc import abstractmethod
+from dataclasses import dataclass
+from typing import Callable, Optional, TypeVar
+
+T = TypeVar("T")
+
+
+@dataclass
+class Schedule:
+    """
+    Schedules the value of a variable over time.
+    """
+
+    name: str
+    start_value: float
+    end_value: float
+    _t: int
+    n_steps: int
+
+    def __init__(self, start_value: float, end_value: float, n_steps: int):
+        self.start_value = start_value
+        self.end_value = end_value
+        self.n_steps = n_steps
+        self.name = self.__class__.__name__
+        self._t = 0
+        self._current_value = self.start_value
+
+    def update(self, step: Optional[int] = None):
+        """Update the value of the schedule. Force a step if given."""
+        if step is not None:
+            self._t = step
+        else:
+            self._t += 1
+        if self._t >= self.n_steps:
+            self._current_value = self.end_value
+        else:
+            self._current_value = self._compute()
+
+    @abstractmethod
+    def _compute(self) -> float:
+        """Compute the value of the schedule"""
+
+    @property
+    def value(self) -> float:
+        """Returns the current value of the schedule"""
+        return self._current_value
+
+    @staticmethod
+    def constant(value: float):
+        return ConstantSchedule(value)
+
+    @staticmethod
+    def linear(start_value: float, end_value: float, n_steps: int):
+        return LinearSchedule(start_value, end_value, n_steps)
+
+    @staticmethod
+    def exp(start_value: float, end_value: float, n_steps: int):
+        return ExpSchedule(start_value, end_value, n_steps)
+
+    @staticmethod
+    def arbitrary(func: Callable[[int], float], n_steps: Optional[int] = None):
+        if n_steps is None:
+            n_steps = 0
+        return ArbitrarySchedule(func, n_steps)
+
+    def rounded(self, n_digits: int = 0) -> "RoundedSchedule":
+        return RoundedSchedule(self, n_digits)
+
+    # Operator overloading
+    def __mul__(self, other: T) -> T:
+        return self.value * other  # type: ignore
+
+    def __rmul__(self, other: T) -> T:
+        return self.value * other  # type: ignore
+
+    def __pow__(self, exp: float) -> float:
+        return self.value**exp
+
+    def __rpow__(self, other: T) -> T:
+        return other**self.value  # type: ignore
+
+    def __add__(self, other: T) -> T:
+        return self.value + other  # type: ignore
+
+    def __radd__(self, other: T) -> T:
+        return self.value + other  # type: ignore
+
+    def __neg__(self):
+        return -self.value
+
+    def __pos__(self):
+        return +self.value
+
+    def __sub__(self, other: T) -> T:
+        return self.value - other  # type: ignore
+
+    def __rsub__(self, other: T) -> T:
+        return other - self.value  # type: ignore
+
+    def __div__(self, other: T) -> T:
+        return self.value // other  # type: ignore
+
+    def __rdiv__(self, other: T) -> T:
+        return other // self.value  # type: ignore
+
+    def __truediv__(self, other: T) -> T:
+        return self.value / other  # type: ignore
+
+    def __rtruediv__(self, other: T) -> T:
+        return other / self.value  # type: ignore
+
+    def __lt__(self, other) -> bool:
+        return self.value < other
+
+    def __le__(self, other) -> bool:
+        return self.value <= other
+
+    def __gt__(self, other) -> bool:
+        return self.value > other
+
+    def __ge__(self, other) -> bool:
+        return self.value >= other
+
+    def __eq__(self, other) -> bool:
+        if isinstance(other, Schedule):
+            if self.start_value != other.start_value:
+                return False
+            if self.end_value != other.end_value:
+                return False
+            if self.n_steps != other.n_steps:
+                return False
+            if type(self) is not type(other):
+                return False
+        return self.value == other
+
+    def __ne__(self, other) -> bool:
+        return not (self.__eq__(other))
+
+    def __float__(self):
+        return self.value
+
+    def __int__(self) -> int:
+        return int(self.value)
+
+
+@dataclass(eq=False)
+class LinearSchedule(Schedule):
+    a: float
+    b: float
+
+    def __init__(self, start_value: float, end_value: float, n_steps: int):
+        super().__init__(start_value, end_value, n_steps)
+        self._current_value = self.start_value
+        # y = ax + b
+        self.a = (self.end_value - self.start_value) / self.n_steps
+        self.b = self.start_value
+
+    def _compute(self):
+        return self.a * (self._t) + self.b
+
+    @property
+    def value(self) -> float:
+        return self._current_value
+
+
+@dataclass(eq=False)
+class ExpSchedule(Schedule):
+    """Exponential schedule. After n_steps, the value will be min_value.
+
+    Update formula is next_value = start_value * (min_value / start_value) ** (step / (n - 1))
+    """
+
+    n_steps: int
+    base: float
+    last_update_step: int
+
+    def __init__(self, start_value: float, min_value: float, n_steps: int):
+        super().__init__(start_value, min_value, n_steps)
+        self.base = self.end_value / self.start_value
+        self.last_update_step = self.n_steps - 1
+
+    def _compute(self):
+        return self.start_value * (self.base) ** (self._t / (self.n_steps - 1))
+
+    @property
+    def value(self) -> float:
+        return self._current_value
+
+
+@dataclass(eq=False)
+class ConstantSchedule(Schedule):
+    def __init__(self, value: float):
+        super().__init__(value, value, 0)
+        self._value = value
+
+    def update(self, step=None):
+        return
+
+    @property
+    def value(self) -> float:
+        return self._value
+
+
+@dataclass(eq=False)
+class RoundedSchedule(Schedule):
+    def __init__(self, schedule: Schedule, n_digits: int):
+        super().__init__(schedule.start_value, schedule.end_value, schedule.n_steps)
+        self.schedule = schedule
+        self.n_digits = n_digits
+
+    def update(self, step: int | None = None):
+        return self.schedule.update(step)
+
+    def _compute(self) -> float:
+        return self.schedule._compute()
+
+    @property
+    def value(self) -> float:
+        return round(self.schedule.value, self.n_digits)
+
+
+@dataclass(eq=False)
+class MultiSchedule(Schedule):
+    def __init__(self, schedules: dict[int, Schedule]):
+        ordered_schedules, ordered_steps = MultiSchedule._verify(schedules)
+        n_steps = ordered_steps[-1] + ordered_schedules[-1].n_steps
+        super().__init__(ordered_schedules[0].start_value, ordered_schedules[-1].end_value, n_steps)
+        self.schedules = iter(ordered_schedules)
+        self.current_schedule = next(self.schedules)
+        self.offset = 0
+        self.current_end = ordered_steps[1]
+
+    @staticmethod
+    def _verify(schedules: dict[int, Schedule]):
+        sorted_steps = sorted(schedules.keys())
+        sorted_schedules = [schedules[t] for t in sorted_steps]
+        if sorted_steps[0] != 0:
+            raise ValueError("First schedule must start at t=0")
+        current_step = 0
+        for i in range(len(sorted_steps)):
+            # Artificially set the end step of ConstantSchedules to the next step
+            if isinstance(sorted_schedules[i], ConstantSchedule):
+                if i + 1 < len(sorted_steps):
+                    sorted_schedules[i].n_steps = sorted_steps[i + 1] - sorted_steps[i]
+            if sorted_steps[i] != current_step:
+                raise ValueError(f"Schedules are not contiguous at t={current_step}")
+            current_step += sorted_schedules[i].n_steps
+        return sorted_schedules, sorted_steps
+
+    def update(self, step: int | None = None):
+        if step is not None:
+            raise NotImplementedError("Cannot update MultiSchedule with a specific step")
+        super().update(step)
+        # If we reach the end of the current schedule, update to the next one
+        # except if we are at the end.
+        if self._t == self.current_end and self._t < self.n_steps:
+            self.current_schedule = next(self.schedules)
+            self.offset = self._t
+            self.current_end += self.current_schedule.n_steps
+        self.current_schedule.update(self.relative_step)
+
+    @property
+    def relative_step(self):
+        return self._t - self.offset
+
+    def _compute(self) -> float:
+        return self.current_schedule._compute()
+
+    @property
+    def value(self):
+        return self.current_schedule.value
+
+
+@dataclass(eq=False)
+class ArbitrarySchedule(Schedule):
+    def __init__(self, fn: Callable[[int], float], n_steps: int):
+        super().__init__(fn(0), fn(n_steps), n_steps)
+        self._func = fn
+
+    def _compute(self) -> float:
+        return self._func(self._t)
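A usage sketch for the new schedule module; the import path is an assumption based on the file list (marlenv/utils/schedule.py plus the new marlenv/utils/__init__.py), and the mechanics mirror the code above.

from marlenv.utils.schedule import ConstantSchedule, LinearSchedule, MultiSchedule, Schedule

eps = Schedule.linear(1.0, 0.05, n_steps=10_000)  # epsilon decays linearly
lr = Schedule.exp(1e-3, 1e-5, n_steps=50_000)     # learning rate decays exponentially

for t in range(100):
    if eps > 0.5:       # comparisons delegate to the current value
        pass
    scaled = 2.0 * eps  # arithmetic returns a plain float (__rmul__)
    eps.update()        # advance one step; eps.update(step=t) forces step t
    lr.update()

# Piecewise: hold 1.0 for 100 steps, then decay linearly to 0.05 over 900 more.
warmup = MultiSchedule({0: ConstantSchedule(1.0), 100: LinearSchedule(1.0, 0.05, 900)})
rounded = Schedule.linear(1.0, 0.0, 100).rounded(n_digits=2)  # value rounded to 2 digits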
marlenv/wrappers/agent_id_wrapper.py
CHANGED
@@ -1,19 +1,18 @@
 import numpy as np
-from marlenv.models import MARLEnv, ActionSpace
+from marlenv.models import MARLEnv, Space
 from dataclasses import dataclass
 from .rlenv_wrapper import RLEnvWrapper
 
 from typing_extensions import TypeVar
 
-
-AS = TypeVar("AS", bound=ActionSpace, default=ActionSpace)
+AS = TypeVar("AS", bound=Space, default=Space)
 
 
 @dataclass
-class AgentId(RLEnvWrapper[A, AS]):
+class AgentId(RLEnvWrapper[AS]):
     """RLEnv wrapper that adds a one-hot encoding of the agent id."""
 
-    def __init__(self, env: MARLEnv[A, AS]):
+    def __init__(self, env: MARLEnv[AS]):
         assert len(env.extras_shape) == 1, "AgentIdWrapper only works with single dimension extras"
         meanings = env.extras_meanings + [f"Agent ID-{i}" for i in range(env.n_agents)]
         super().__init__(env, extra_shape=(env.n_agents + env.extras_shape[0],), extra_meanings=meanings)
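To make the AgentId extras concrete: the wrapper widens the 1D extras by n_agents one-hot entries, matching extra_shape=(env.n_agents + env.extras_shape[0],) above. A standalone numpy sketch; the extras-then-IDs ordering is an assumption.

import numpy as np

n_agents, n_extras = 3, 2
ids = np.eye(n_agents, dtype=np.float32)                   # one-hot ID per agent
extras = np.zeros((n_agents, n_extras), dtype=np.float32)  # pre-existing extras
augmented = np.concatenate([extras, ids], axis=-1)
print(augmented.shape)  # (3, 5), i.e. extras_shape[0] + n_agents per agent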
marlenv/wrappers/available_actions_mask.py
CHANGED
@@ -2,20 +2,19 @@ import numpy as np
 import numpy.typing as npt
 from typing_extensions import TypeVar
 from .rlenv_wrapper import MARLEnv, RLEnvWrapper
-from marlenv.models import ActionSpace
+from marlenv.models import Space
 from dataclasses import dataclass
 
-
-AS = TypeVar("AS", bound=ActionSpace, default=ActionSpace)
+AS = TypeVar("AS", bound=Space, default=Space)
 
 
 @dataclass
-class AvailableActionsMask(RLEnvWrapper[A, AS]):
+class AvailableActionsMask(RLEnvWrapper[AS]):
     """Permanently masks a subset of the available actions."""
 
     action_mask: npt.NDArray[np.bool_]
 
-    def __init__(self, env: MARLEnv[A, AS], action_mask: npt.NDArray[np.bool_]):
+    def __init__(self, env: MARLEnv[AS], action_mask: npt.NDArray[np.bool_]):
         super().__init__(env)
         assert action_mask.shape == (env.n_agents, env.n_actions), "Action mask must have shape (n_agents, n_actions)."
         n_available_action_per_agent = action_mask.sum(axis=-1)
@@ -27,8 +26,8 @@ class AvailableActionsMask(RLEnvWrapper[A, AS]):
         obs.available_actions = self.available_actions()
         return obs, state
 
-    def step(self, actions: A):
-        step = self.wrapped.step(actions)
+    def step(self, action):
+        step = self.wrapped.step(action)
         step.obs.available_actions = self.available_actions()
         return step
 
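For concreteness, the mask AvailableActionsMask expects is a boolean (n_agents, n_actions) array; a standalone sketch:

import numpy as np

n_agents, n_actions = 2, 3
mask = np.ones((n_agents, n_actions), dtype=bool)
mask[1, 0] = False  # agent 1 may never take action 0
# AvailableActionsMask(env, mask) asserts mask.shape == (n_agents, n_actions);
# the hunk also computes mask.sum(axis=-1), presumably to check that each
# agent keeps at least one available action.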
marlenv/wrappers/available_actions_wrapper.py
CHANGED
@@ -1,21 +1,19 @@
 import numpy as np
-import numpy.typing as npt
 from typing_extensions import TypeVar
-from marlenv.models import ActionSpace, MARLEnv
+from marlenv.models import Space, MARLEnv
 from .rlenv_wrapper import RLEnvWrapper
 from dataclasses import dataclass
 
 
-
-AS = TypeVar("AS", bound=ActionSpace, default=ActionSpace)
+AS = TypeVar("AS", bound=Space, default=Space)
 
 
 @dataclass
-class AvailableActions(RLEnvWrapper[A, AS]):
+class AvailableActions(RLEnvWrapper[AS]):
     """Adds the available actions (one-hot) as an extra feature to the observation."""
 
-    def __init__(self, env: MARLEnv[A, AS]):
-        meanings = env.extras_meanings + [f"{a} available" for a in env.action_space.action_names]
+    def __init__(self, env: MARLEnv[AS]):
+        meanings = env.extras_meanings + [f"{a} available" for a in env.action_space.labels]
         super().__init__(env, extra_shape=(env.extras_shape[0] + env.n_actions,), extra_meanings=meanings)
 
     def reset(self):
@@ -23,7 +21,7 @@ class AvailableActions(RLEnvWrapper[A, AS]):
         obs.add_extra(self.available_actions().astype(np.float32))
         return obs, state
 
-    def step(self, actions: A):
-        step = self.wrapped.step(actions)
+    def step(self, action):
+        step = self.wrapped.step(action)
         step.obs.add_extra(self.available_actions().astype(np.float32))
         return step
marlenv/wrappers/blind_wrapper.py
CHANGED
@@ -1,26 +1,24 @@
 import random
 from typing_extensions import TypeVar
 import numpy as np
-import numpy.typing as npt
 from dataclasses import dataclass
 
-from marlenv.models import MARLEnv, ActionSpace
+from marlenv.models import MARLEnv, Space
 from .rlenv_wrapper import RLEnvWrapper
 
 
-
-AS = TypeVar("AS", bound=ActionSpace, default=ActionSpace)
+AS = TypeVar("AS", bound=Space, default=Space)
 
 
 @dataclass
-class Blind(RLEnvWrapper[A, AS]):
+class Blind(RLEnvWrapper[AS]):
     p: float
 
-    def __init__(self, env: MARLEnv[A, AS], p: float | int):
+    def __init__(self, env: MARLEnv[AS], p: float | int):
         super().__init__(env)
         self.p = float(p)
 
-    def step(self, actions: A):
+    def step(self, actions):
         step = super().step(actions)
         if random.random() < self.p:
             step.obs.data = np.zeros_like(step.obs.data)
marlenv/wrappers/centralised.py
CHANGED
@@ -4,28 +4,26 @@ from typing import Sequence
 
 import numpy as np
 import numpy.typing as npt
-from typing_extensions import TypeVar
 
-from marlenv.models import DiscreteActionSpace, MARLEnv, Observation
+from marlenv.models import DiscreteSpace, MARLEnv, MultiDiscreteSpace, Observation
 
 from .rlenv_wrapper import RLEnvWrapper
 
-A = TypeVar("A", bound=npt.NDArray | Sequence[int] | Sequence[Sequence[float]])
-
 
 @dataclass
-class Centralized(RLEnvWrapper[A, DiscreteActionSpace]):
-    joint_action_space: DiscreteActionSpace
+class Centralized(RLEnvWrapper[MultiDiscreteSpace]):
+    joint_action_space: DiscreteSpace
 
-    def __init__(self, env: MARLEnv[A, DiscreteActionSpace]):
-        if not isinstance(env.action_space, DiscreteActionSpace):
+    def __init__(self, env: MARLEnv[MultiDiscreteSpace]):
+        if not isinstance(env.action_space, MultiDiscreteSpace):
             raise NotImplementedError(f"Action space {env.action_space} not supported")
         joint_observation_shape = (env.observation_shape[0] * env.n_agents, *env.observation_shape[1:])
         super().__init__(
             env,
-            [...]
-            [...]
-            env.[...]
+            n_agents=1,
+            observation_shape=joint_observation_shape,
+            state_shape=env.state_shape,
+            state_extra_shape=env.extras_shape,
             action_space=self._make_joint_action_space(env),
         )
 
@@ -37,12 +35,12 @@ class Centralized(RLEnvWrapper[A, DiscreteActionSpace]):
         obs = super().get_observation()
         return self._joint_observation(obs)
 
-    def _make_joint_action_space(self, env: MARLEnv[A, DiscreteActionSpace]):
+    def _make_joint_action_space(self, env: MARLEnv[MultiDiscreteSpace]):
         agent_actions = list[list[str]]()
         for agent in range(env.n_agents):
-            agent_actions.append([f"{agent}-{action}" for action in env.action_space.action_names])
+            agent_actions.append([f"{agent}-{action}" for action in env.action_space.labels])
         action_names = [str(a) for a in product(*agent_actions)]
-        return [...]
+        return DiscreteSpace(env.n_actions**env.n_agents, action_names).repeat(1)
 
     def step(self, actions: npt.NDArray | Sequence):
         action = actions[0]
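The joint action space built by _make_joint_action_space grows as n_actions ** n_agents; a worked sketch of the label construction above:

from itertools import product

n_agents, labels = 2, ["noop", "left", "right"]
agent_actions = [[f"{agent}-{action}" for action in labels] for agent in range(n_agents)]
action_names = [str(a) for a in product(*agent_actions)]
print(len(action_names))  # 9 == n_actions ** n_agents == 3 ** 2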
marlenv/wrappers/delayed_rewards.py
CHANGED
@@ -1,20 +1,22 @@
-from .rlenv_wrapper import RLEnvWrapper, MARLEnv
-from marlenv.models import ActionSpace
-from typing_extensions import TypeVar
-import numpy.typing as npt
-import numpy as np
-from dataclasses import dataclass
 from collections import deque
+from dataclasses import dataclass
+
+import numpy as np
+import numpy.typing as npt
+from typing_extensions import TypeVar
+
+from marlenv.models import Space
+
+from .rlenv_wrapper import MARLEnv, RLEnvWrapper
 
-
-AS = TypeVar("AS", bound=ActionSpace, default=ActionSpace)
+AS = TypeVar("AS", bound=Space, default=Space)
 
 
 @dataclass
-class DelayedReward(RLEnvWrapper[A, AS]):
+class DelayedReward(RLEnvWrapper[AS]):
     delay: int
 
-    def __init__(self, env: MARLEnv[A, AS], delay: int):
+    def __init__(self, env: MARLEnv[AS], delay: int):
         super().__init__(env)
         self.delay = delay
         self.reward_queue = deque[npt.NDArray[np.float32]](maxlen=delay + 1)
@@ -25,7 +27,7 @@ class DelayedReward(RLEnvWrapper[A, AS]):
         self.reward_queue.append(np.zeros(self.reward_space.shape, dtype=np.float32))
         return super().reset()
 
-    def step(self, actions: A):
+    def step(self, actions):
         step = super().step(actions)
         self.reward_queue.append(step.reward)
         # If the step is terminal, we sum all the remaining rewards
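The delay mechanics implied by these hunks: a deque of maxlen=delay + 1 is seeded with zero rewards on reset, so a reward surfaces `delay` appends after it enters. A standalone sketch; the exact pop/emit logic lies outside the shown lines.

from collections import deque

import numpy as np

delay = 2
q = deque(maxlen=delay + 1)                  # as in the constructor above
for _ in range(delay + 1):
    q.append(np.zeros(1, dtype=np.float32))  # reset() seeds zero rewards
q.append(np.array([5.0], dtype=np.float32))  # a new reward evicts the oldest entry
print(q[0])  # [0.] -- the 5.0 only reaches the front after `delay` more appends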
marlenv/wrappers/last_action_wrapper.py
CHANGED
@@ -1,25 +1,21 @@
 from dataclasses import dataclass
-from typing_extensions import TypeVar
-from typing import Sequence
 
 import numpy as np
 import numpy.typing as npt
+from typing_extensions import TypeVar
 
-from marlenv.models import ActionSpace, ContinuousActionSpace, DiscreteActionSpace, State
+from marlenv.models import ContinuousSpace, DiscreteSpace, MultiDiscreteSpace, Space, State
 
 from .rlenv_wrapper import MARLEnv, RLEnvWrapper
 
-AS = TypeVar("AS", bound=ActionSpace, default=ActionSpace)
-DiscreteActionType = npt.NDArray[np.int64 | np.int32] | Sequence[int]
-ContinuousActionType = npt.NDArray[np.float32] | Sequence[Sequence[float]]
-A = TypeVar("A", bound=DiscreteActionType | ContinuousActionType)
+AS = TypeVar("AS", bound=Space, default=Space)
 
 
 @dataclass
-class LastAction(RLEnvWrapper[A, AS]):
+class LastAction(RLEnvWrapper[AS]):
     """Env wrapper that adds the last action taken by the agents to the extra features."""
 
-    def __init__(self, env: MARLEnv[A, AS]):
+    def __init__(self, env: MARLEnv[AS]):
         assert len(env.extras_shape) == 1, "Adding last action is only possible with 1D extras"
         super().__init__(
             env,
@@ -37,13 +33,13 @@ class LastAction(RLEnvWrapper[A, AS]):
         state.add_extra(self.last_one_hot_actions.flatten())
         return obs, state
 
-    def step(self, actions: A):
+    def step(self, actions):
         step = super().step(actions)
         match self.wrapped.action_space:
-            case ContinuousActionSpace():
+            case ContinuousSpace():
                 self.last_actions = actions
-            case DiscreteActionSpace():
-                self.last_one_hot_actions = self.compute_one_hot_actions(actions)
+            case DiscreteSpace() | MultiDiscreteSpace():
+                self.last_one_hot_actions = self.compute_one_hot_actions(actions)
             case other:
                 raise NotImplementedError(f"Action space {other} not supported")
         step.obs.add_extra(self.last_one_hot_actions)
@@ -60,7 +56,7 @@ class LastAction(RLEnvWrapper[A, AS]):
         self.last_one_hot_actions = flattened_one_hots.reshape(self.n_agents, self.n_actions)
         return super().set_state(state)
 
-    def compute_one_hot_actions(self, actions: A) -> npt.NDArray:
+    def compute_one_hot_actions(self, actions) -> npt.NDArray:
         one_hot_actions = np.zeros((self.n_agents, self.n_actions), dtype=np.float32)
         index = np.arange(self.n_agents)
         one_hot_actions[index, actions] = 1.0
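compute_one_hot_actions relies on numpy fancy indexing; a standalone sketch of the same pattern:

import numpy as np

n_agents, n_actions = 3, 4
actions = np.array([2, 0, 3])
one_hot = np.zeros((n_agents, n_actions), dtype=np.float32)
one_hot[np.arange(n_agents), actions] = 1.0  # row i gets a 1 at column actions[i]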
marlenv/wrappers/paddings.py
CHANGED
@@ -1,22 +1,20 @@
 import numpy as np
-import numpy.typing as npt
 from dataclasses import dataclass
-from marlenv.models import Observation, ActionSpace
+from marlenv.models import Observation, Space
 from .rlenv_wrapper import RLEnvWrapper, MARLEnv
 from typing_extensions import TypeVar
 
 
-
-AS = TypeVar("AS", bound=ActionSpace, default=ActionSpace)
+AS = TypeVar("AS", bound=Space, default=Space)
 
 
 @dataclass
-class PadExtras(RLEnvWrapper[A, AS]):
+class PadExtras(RLEnvWrapper[AS]):
     """RLEnv wrapper that adds extra zeros at the end of the observation extras."""
 
     n: int
 
-    def __init__(self, env: MARLEnv[A, AS], n_added: int):
+    def __init__(self, env: MARLEnv[AS], n_added: int):
         assert len(env.extras_shape) == 1, "PadExtras only accepts 1D extras"
         meanings = env.extras_meanings + [f"Padding-{i}" for i in range(n_added)]
         super().__init__(
@@ -42,10 +40,10 @@ class PadExtras(RLEnvWrapper[A, AS]):
 
 
 @dataclass
-class PadObservations(RLEnvWrapper[A, AS]):
+class PadObservations(RLEnvWrapper[AS]):
     """RLEnv wrapper that adds extra zeros at the end of the observation data."""
 
-    def __init__(self, env: MARLEnv[A, AS], n_added: int) -> None:
+    def __init__(self, env: MARLEnv[AS], n_added: int) -> None:
         assert len(env.observation_shape) == 1, "PadObservations only accepts 1D observations"
         super().__init__(env, observation_shape=(env.observation_shape[0] + n_added,))
         self.n = n_added
marlenv/wrappers/penalty_wrapper.py
CHANGED
@@ -1,21 +1,18 @@
 from dataclasses import dataclass
 import numpy as np
-
-from marlenv.models import ActionSpace
+from marlenv.models import Space
 from .rlenv_wrapper import RLEnvWrapper, MARLEnv
-# from ..models.rl_env import MOMARLEnv
 
 from typing_extensions import TypeVar
 
-
-AS = TypeVar("AS", bound=ActionSpace, default=ActionSpace)
+AS = TypeVar("AS", bound=Space, default=Space)
 
 
 @dataclass
-class TimePenalty(RLEnvWrapper[A, AS]):
+class TimePenalty(RLEnvWrapper[AS]):
     penalty: float | np.ndarray
 
-    def __init__(self, env: MARLEnv[A, AS], penalty: float | list[float]):
+    def __init__(self, env: MARLEnv[AS], penalty: float | list[float]):
         super().__init__(env)
 
         if env.is_multi_objective:
@@ -26,7 +23,7 @@ class TimePenalty(RLEnvWrapper[A, AS]):
         assert isinstance(penalty, (float, int))
         self.penalty = penalty
 
-    def step(self, action: A):
+    def step(self, action):
         step = self.wrapped.step(action)
         step.reward = step.reward - self.penalty
         return step