multi-agent-rlenv 3.4.0__py3-none-any.whl → 3.5.0__py3-none-any.whl

This diff shows the contents of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as published.
marlenv/models/spaces.py CHANGED
@@ -1,13 +1,11 @@
  import math
  from abc import ABC, abstractmethod
  from dataclasses import dataclass
- from typing import Any, Generic, Optional, TypeVar
+ from typing import Optional

  import numpy as np
  import numpy.typing as npt

- S = TypeVar("S", bound="Space")
-

  @dataclass
  class Space(ABC):
@@ -23,7 +21,7 @@ class Space(ABC):
          self.labels = labels

      @abstractmethod
-     def sample(self, mask: Optional[npt.NDArray[np.bool_]] = None) -> Any:
+     def sample(self, mask: Optional[npt.NDArray[np.bool_]] = None) -> npt.NDArray[np.float32]:
          """Sample a value from the space."""

      def __eq__(self, value: object) -> bool:
@@ -34,6 +32,16 @@ class Space(ABC):
      def __ne__(self, value: object) -> bool:
          return not self.__eq__(value)

+     @property
+     @abstractmethod
+     def is_discrete(self) -> bool:
+         """Whether the space is discrete."""
+
+     @property
+     def is_continuous(self) -> bool:
+         """Whether the space is continuous."""
+         return not self.is_discrete
+

  @dataclass
  class DiscreteSpace(Space):
@@ -45,8 +53,8 @@ class DiscreteSpace(Space):
          self.size = size
          self.space = np.arange(size)

-     def sample(self, mask: Optional[npt.NDArray[np.bool_]] = None) -> int:
-         space = self.space
+     def sample(self, mask: Optional[npt.NDArray[np.bool]] = None):
+         space = self.space.copy()
          if mask is not None:
              space = space[mask]
          return int(np.random.choice(space))
@@ -58,6 +66,25 @@ class DiscreteSpace(Space):
              return False
          return super().__eq__(value)

+     @property
+     def is_discrete(self) -> bool:
+         return True
+
+     @staticmethod
+     def action(size, labels: Optional[list[str]] = None):
+         """
+         Create a discrete action space where the default labels are set to "Action-n".
+         """
+         if labels is None:
+             labels = [f"Action {i}" for i in range(size)]
+         return DiscreteSpace(size, labels)
+
+     def repeat(self, n: int):
+         """
+         Repeat the discrete space n times.
+         """
+         return MultiDiscreteSpace(*([self] * n), labels=self.labels)
+

  @dataclass
  class MultiDiscreteSpace(Space):
@@ -75,10 +102,10 @@ class MultiDiscreteSpace(Space):
      def from_sizes(cls, *sizes: int):
          return cls(*(DiscreteSpace(size) for size in sizes))

-     def sample(self, masks: Optional[npt.NDArray[np.bool_] | list[npt.NDArray[np.bool_]]] = None):
-         if masks is None:
+     def sample(self, mask: Optional[npt.NDArray[np.bool] | list[npt.NDArray[np.bool]]] = None):
+         if mask is None:
              return np.array([space.sample() for space in self.spaces], dtype=np.int32)
-         return np.array([space.sample(mask) for mask, space in zip(masks, self.spaces)], dtype=np.int32)
+         return np.array([space.sample(mask=mask) for mask, space in zip(mask, self.spaces)], dtype=np.int32)

      def __eq__(self, value: object) -> bool:
          if not isinstance(value, MultiDiscreteSpace):
@@ -90,6 +117,10 @@ class MultiDiscreteSpace(Space):
              return False
          return super().__eq__(value)

+     @property
+     def is_discrete(self) -> bool:
+         return True
+

  @dataclass
  class ContinuousSpace(Space):
@@ -100,23 +131,35 @@ class ContinuousSpace(Space):
      high: npt.NDArray[np.float32]
      """Upper bound of the space for each dimension."""

-     @staticmethod
-     def from_bounds(
-         low: int | float | list | npt.NDArray[np.float32],
-         high: int | float | list | npt.NDArray[np.float32],
+     def __init__(
+         self,
+         low: int | float | list | npt.NDArray[np.float32] | None,
+         high: int | float | list | npt.NDArray[np.float32] | None,
          labels: Optional[list[str]] = None,
      ):
          match low:
+             case None:
+                 assert high is not None, "If low is None, high must be set to infer the shape."
+                 shape = ContinuousSpace.get_shape(high)
+                 low = np.full(shape, -np.inf, dtype=np.float32)
              case list():
                  low = np.array(low, dtype=np.float32)
              case float() | int():
                  low = np.array([low], dtype=np.float32)
          match high:
+             case None:
+                 assert low is not None, "If high is None, low must be set to infer the shape."
+                 shape = ContinuousSpace.get_shape(low)
+                 high = np.full(shape, np.inf, dtype=np.float32)
              case list():
                  high = np.array(high, dtype=np.float32)
              case float() | int():
                  high = np.array([high], dtype=np.float32)
-         return ContinuousSpace(low, high, labels)
+         assert low.shape == high.shape, f"Low and high must have the same shape. Low shape: {low.shape}, high shape: {high.shape}"
+         assert np.all(low <= high), "All elements in low must be less than the corresponding elements in high."
+         Space.__init__(self, low.shape, labels)
+         self.low = low
+         self.high = high

      @staticmethod
      def from_shape(
@@ -143,20 +186,24 @@ class ContinuousSpace(Space):
          high = np.array(high, dtype=np.float32)
          return ContinuousSpace(low, high, labels)

-     def __init__(
-         self,
-         low: npt.NDArray[np.float32],
-         high: npt.NDArray[np.float32],
-         labels: Optional[list[str]] = None,
-     ):
-         assert low.shape == high.shape, "Low and high must have the same shape."
-         assert np.all(low <= high), "All elements in low must be less than the corresponding elements in high."
-         Space.__init__(self, low.shape, labels)
-         self.low = low
-         self.high = high
+     def clamp(self, action: np.ndarray | list):
+         """Clamp the action to the bounds of the space."""
+         if isinstance(action, list):
+             action = np.array(action)
+         return np.clip(action, self.low, self.high)
+
+     def sample(self) -> npt.NDArray[np.float32]:
+         r = np.random.random(self.shape) * (self.high - self.low) + self.low
+         return r.astype(np.float32)

-     def sample(self, *_):
-         return np.random.random(self.shape) * (self.high - self.low) + self.low
+     @staticmethod
+     def get_shape(item: float | int | list | npt.NDArray[np.float32]) -> tuple[int, ...]:
+         """Get the shape of the item."""
+         if isinstance(item, list):
+             item = np.array(item)
+         if isinstance(item, np.ndarray):
+             return item.shape
+         return (1,)

      def __eq__(self, value: object) -> bool:
          if not isinstance(value, ContinuousSpace):
@@ -167,59 +214,19 @@ class ContinuousSpace(Space):
              return False
          return super().__eq__(value)

-
- @dataclass
- class ActionSpace(Space, Generic[S]):
-     n_agents: int
-     """Number of agents."""
-     action_names: list[str]
-     """The meaning of each action."""
-     n_actions: int
-     individual_action_space: S
-
-     def __init__(self, n_agents: int, individual_action_space: S, action_names: Optional[list] = None):
-         Space.__init__(self, (n_agents, *individual_action_space.shape), action_names)
-         self.n_agents = n_agents
-         self.individual_action_space = individual_action_space
-         self.n_actions = math.prod(individual_action_space.shape)
-         self.action_names = action_names or [f"Action {i}" for i in range(self.n_actions)]
-
-     def sample(self, mask: np.ndarray | None = None):
-         res = []
-         for i in range(self.n_agents):
-             if mask is not None:
-                 m = mask[i]
-             else:
-                 m = None
-             res.append(self.individual_action_space.sample(m))
-         return np.array(res)
-
-     def __eq__(self, value: object) -> bool:
-         if not isinstance(value, ActionSpace):
-             return False
-         if self.n_agents != value.n_agents:
-             return False
-         if self.n_actions != value.n_actions:
-             return False
-         if self.individual_action_space != value.individual_action_space:
-             return False
-         return super().__eq__(value)
-
-
- @dataclass
- class DiscreteActionSpace(ActionSpace[DiscreteSpace]):
-     def __init__(self, n_agents: int, n_actions: int, action_names: Optional[list[str]] = None):
-         individual_action_space = DiscreteSpace(n_actions, action_names)
-         super().__init__(n_agents, individual_action_space, action_names)
-
-
- @dataclass
- class MultiDiscreteActionSpace(ActionSpace[MultiDiscreteSpace]):
-     pass
-
-
- @dataclass
- class ContinuousActionSpace(ActionSpace[ContinuousSpace]):
-     def __init__(self, n_agents: int, low: np.ndarray | list, high: np.ndarray | list, action_names: list | None = None):
-         space = ContinuousSpace.from_bounds(low, high, action_names)
-         super().__init__(n_agents, space, action_names)
+     def repeat(self, n: int):
+         """
+         Repeat the continuous space n times to become of shape (n, *shape).
+         """
+         low = np.tile(self.low, (n, 1))
+         high = np.tile(self.high, (n, 1))
+         return ContinuousSpace.from_shape(
+             (n, *self.shape),
+             low=low,
+             high=high,
+             labels=self.labels,
+         )
+
+     @property
+     def is_discrete(self) -> bool:
+         return False
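
The net effect in spaces.py: the ActionSpace hierarchy is gone, and a multi-agent action space is now built by repeating a per-agent Space. A minimal usage sketch of the new API (not part of the diff; sizes and values are illustrative):

    from marlenv.models import ContinuousSpace, DiscreteSpace

    # Replaces the removed DiscreteActionSpace(n_agents=3, n_actions=5):
    joint = DiscreteSpace.action(5).repeat(3)
    assert joint.is_discrete and not joint.is_continuous
    actions = joint.sample()  # one int32 action index per agent

    # ContinuousSpace is now constructed directly; low=None infers the
    # shape from high and fills the lower bound with -inf.
    box = ContinuousSpace(low=None, high=[1.0, 2.0])
    clipped = box.clamp([0.5, 5.0])  # -> array([0.5, 2.0])
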
@@ -1,6 +1,5 @@
  from dataclasses import dataclass
- from typing import Any, Generic, Sequence
- from typing_extensions import TypeVar
+ from typing import Any, Sequence

  import numpy as np
  import numpy.typing as npt
@@ -10,16 +9,13 @@ from .state import State
  from .step import Step


- A = TypeVar("A", default=np.ndarray)
-
-
  @dataclass
- class Transition(Generic[A]):
+ class Transition:
      """Transition model"""

      obs: Observation
      state: State
-     action: A
+     action: np.ndarray
      reward: npt.NDArray[np.float32]
      done: bool
      info: dict[str, Any]
@@ -32,7 +28,7 @@ class Transition(Generic[A]):
          self,
          obs: Observation,
          state: State,
-         action: A,
+         action: np.ndarray | Sequence[float],
          reward: npt.NDArray[np.float32] | float | Sequence[float],
          done: bool,
          info: dict[str, Any],
@@ -65,14 +61,14 @@ class Transition(Generic[A]):
      def from_step(
          prev_obs: Observation,
          prev_state: State,
-         actions: A,
+         action: np.ndarray | Sequence[float],
          step: Step,
          **kwargs,
      ):
          return Transition(
              obs=prev_obs,
              state=prev_state,
-             action=actions,
+             action=action,
              reward=step.reward,
              done=step.done,
              info=step.info,
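
With Transition no longer generic over an action TypeVar, call sites simply pass an array-like action. A sketch, assuming env, obs, and state come from an existing MARLEnv (illustrative names; note the parameter rename from actions to action):

    import numpy as np
    from marlenv.models import Transition  # import path assumed

    obs, state = env.reset()
    action = np.array([1, 0, 2])  # one discrete action per agent
    step = env.step(action)
    transition = Transition.from_step(obs, state, action, step)
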
@@ -1,19 +1,18 @@
  import numpy as np
- from marlenv.models import MARLEnv, ActionSpace
+ from marlenv.models import MARLEnv, Space
  from dataclasses import dataclass
  from .rlenv_wrapper import RLEnvWrapper

  from typing_extensions import TypeVar

- A = TypeVar("A", default=np.ndarray)
- AS = TypeVar("AS", bound=ActionSpace, default=ActionSpace)
+ AS = TypeVar("AS", bound=Space, default=Space)


  @dataclass
- class AgentId(RLEnvWrapper[A, AS]):
+ class AgentId(RLEnvWrapper[AS]):
      """RLEnv wrapper that adds a one-hot encoding of the agent id."""

-     def __init__(self, env: MARLEnv[A, AS]):
+     def __init__(self, env: MARLEnv[AS]):
          assert len(env.extras_shape) == 1, "AgentIdWrapper only works with single dimension extras"
          meanings = env.extras_meanings + [f"Agent ID-{i}" for i in range(env.n_agents)]
          super().__init__(env, extra_shape=(env.n_agents + env.extras_shape[0],), extra_meanings=meanings)
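
The same mechanical rewrite (drop the action TypeVar A, re-bound AS to Space) repeats in every wrapper below. What it means for callers, sketched with an illustrative environment (the marlenv.wrappers import path is assumed):

    from marlenv.wrappers import AgentId  # import path assumed

    # 3.4.0 typed this as RLEnvWrapper[A, AS]; 3.5.0 keeps only the space
    # type, e.g. RLEnvWrapper[MultiDiscreteSpace].
    wrapped = AgentId(env)
    obs, state = wrapped.reset()  # extras end with one-hot "Agent ID-i" features
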
@@ -2,20 +2,19 @@ import numpy as np
  import numpy.typing as npt
  from typing_extensions import TypeVar
  from .rlenv_wrapper import MARLEnv, RLEnvWrapper
- from marlenv.models import ActionSpace
+ from marlenv.models import Space
  from dataclasses import dataclass

- A = TypeVar("A", default=npt.NDArray)
- AS = TypeVar("AS", bound=ActionSpace, default=ActionSpace)
+ AS = TypeVar("AS", bound=Space, default=Space)


  @dataclass
- class AvailableActionsMask(RLEnvWrapper[A, AS]):
+ class AvailableActionsMask(RLEnvWrapper[AS]):
      """Permanently masks a subset of the available actions."""

      action_mask: npt.NDArray[np.bool_]

-     def __init__(self, env: MARLEnv[A, AS], action_mask: npt.NDArray[np.bool_]):
+     def __init__(self, env: MARLEnv[AS], action_mask: npt.NDArray[np.bool_]):
          super().__init__(env)
          assert action_mask.shape == (env.n_agents, env.n_actions), "Action mask must have shape (n_agents, n_actions)."
          n_available_action_per_agent = action_mask.sum(axis=-1)
@@ -27,8 +26,8 @@ class AvailableActionsMask(RLEnvWrapper[A, AS]):
          obs.available_actions = self.available_actions()
          return obs, state

-     def step(self, actions):
-         step = self.wrapped.step(actions)
+     def step(self, action):
+         step = self.wrapped.step(action)
          step.obs.available_actions = self.available_actions()
          return step

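A usage sketch for the mask wrapper (illustrative my_env; the required shape comes from the assert above):

    import numpy as np
    from marlenv.wrappers import AvailableActionsMask  # import path assumed

    mask = np.ones((my_env.n_agents, my_env.n_actions), dtype=np.bool_)
    mask[:, 0] = False  # permanently forbid action 0 for every agent
    env = AvailableActionsMask(my_env, mask)
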
@@ -1,21 +1,19 @@
  import numpy as np
- import numpy.typing as npt
  from typing_extensions import TypeVar
- from marlenv.models import ActionSpace, MARLEnv
+ from marlenv.models import Space, MARLEnv
  from .rlenv_wrapper import RLEnvWrapper
  from dataclasses import dataclass


- A = TypeVar("A", default=npt.NDArray)
- AS = TypeVar("AS", bound=ActionSpace, default=ActionSpace)
+ AS = TypeVar("AS", bound=Space, default=Space)


  @dataclass
- class AvailableActions(RLEnvWrapper[A, AS]):
+ class AvailableActions(RLEnvWrapper[AS]):
      """Adds the available actions (one-hot) as an extra feature to the observation."""

-     def __init__(self, env: MARLEnv[A, AS]):
-         meanings = env.extras_meanings + [f"{a} available" for a in env.action_space.action_names]
+     def __init__(self, env: MARLEnv[AS]):
+         meanings = env.extras_meanings + [f"{a} available" for a in env.action_space.labels]
          super().__init__(env, extra_shape=(env.extras_shape[0] + env.n_actions,), extra_meanings=meanings)

      def reset(self):
@@ -23,7 +21,7 @@ class AvailableActions(RLEnvWrapper[A, AS]):
          obs.add_extra(self.available_actions().astype(np.float32))
          return obs, state

-     def step(self, actions: A):
-         step = self.wrapped.step(actions)
+     def step(self, action):
+         step = self.wrapped.step(action)
          step.obs.add_extra(self.available_actions().astype(np.float32))
          return step
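
Besides the generics, the only change here is the rename from action_space.action_names to action_space.labels. Sketched effect (illustrative environment):

    from marlenv.wrappers import AvailableActions  # import path assumed

    env = AvailableActions(my_env)
    # Extras grow by n_actions; each added feature is named
    # "<label> available", using action_space.labels (formerly action_names).
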
@@ -1,26 +1,24 @@
  import random
  from typing_extensions import TypeVar
  import numpy as np
- import numpy.typing as npt
  from dataclasses import dataclass

- from marlenv.models import MARLEnv, ActionSpace
+ from marlenv.models import MARLEnv, Space
  from .rlenv_wrapper import RLEnvWrapper


- A = TypeVar("A", default=npt.NDArray)
- AS = TypeVar("AS", bound=ActionSpace, default=ActionSpace)
+ AS = TypeVar("AS", bound=Space, default=Space)


  @dataclass
- class Blind(RLEnvWrapper[A, AS]):
+ class Blind(RLEnvWrapper[AS]):
      p: float

-     def __init__(self, env: MARLEnv[A, AS], p: float | int):
+     def __init__(self, env: MARLEnv[AS], p: float | int):
          super().__init__(env)
          self.p = float(p)

-     def step(self, actions: A):
+     def step(self, actions):
          step = super().step(actions)
          if random.random() < self.p:
              step.obs.data = np.zeros_like(step.obs.data)
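
A usage sketch (probability value illustrative):

    from marlenv.wrappers import Blind  # import path assumed

    env = Blind(my_env, p=0.25)  # each step, the observation is zeroed
                                 # with probability 0.25
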
@@ -4,28 +4,26 @@ from typing import Sequence

  import numpy as np
  import numpy.typing as npt
- from typing_extensions import TypeVar

- from marlenv.models import ActionSpace, DiscreteActionSpace, DiscreteSpace, MARLEnv, Observation
+ from marlenv.models import DiscreteSpace, MARLEnv, MultiDiscreteSpace, Observation

  from .rlenv_wrapper import RLEnvWrapper

- A = TypeVar("A", bound=npt.NDArray | Sequence[int] | Sequence[Sequence[float]])
-

  @dataclass
- class Centralized(RLEnvWrapper[A, DiscreteActionSpace]):
-     joint_action_space: ActionSpace
+ class Centralized(RLEnvWrapper[MultiDiscreteSpace]):
+     joint_action_space: DiscreteSpace

-     def __init__(self, env: MARLEnv[A, DiscreteActionSpace]):
-         if not isinstance(env.action_space.individual_action_space, DiscreteSpace):
+     def __init__(self, env: MARLEnv[MultiDiscreteSpace]):
+         if not isinstance(env.action_space, MultiDiscreteSpace):
              raise NotImplementedError(f"Action space {env.action_space} not supported")
          joint_observation_shape = (env.observation_shape[0] * env.n_agents, *env.observation_shape[1:])
          super().__init__(
              env,
-             joint_observation_shape,
-             env.state_shape,
-             env.extras_shape,
+             n_agents=1,
+             observation_shape=joint_observation_shape,
+             state_shape=env.state_shape,
+             state_extra_shape=env.extras_shape,
              action_space=self._make_joint_action_space(env),
          )

@@ -37,12 +35,12 @@ class Centralized(RLEnvWrapper[A, DiscreteActionSpace]):
          obs = super().get_observation()
          return self._joint_observation(obs)

-     def _make_joint_action_space(self, env: MARLEnv[A, DiscreteActionSpace]):
+     def _make_joint_action_space(self, env: MARLEnv[MultiDiscreteSpace]):
          agent_actions = list[list[str]]()
          for agent in range(env.n_agents):
-             agent_actions.append([f"{agent}-{action}" for action in env.action_space.action_names])
+             agent_actions.append([f"{agent}-{action}" for action in env.action_space.labels])
          action_names = [str(a) for a in product(*agent_actions)]
-         return DiscreteActionSpace(1, env.n_actions**env.n_agents, action_names)
+         return DiscreteSpace(env.n_actions**env.n_agents, action_names).repeat(1)

      def step(self, actions: npt.NDArray | Sequence):
          action = actions[0]
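
The joint action space is now a single-agent MultiDiscreteSpace built with repeat(1) rather than a DiscreteActionSpace; the enumeration itself is unchanged. A worked sketch of the arithmetic (action names illustrative):

    from itertools import product
    from marlenv.models import DiscreteSpace

    # 2 agents x 3 individual actions -> 3**2 = 9 joint actions.
    agent_actions = [[f"{agent}-{a}" for a in ("left", "right", "stay")] for agent in range(2)]
    action_names = [str(a) for a in product(*agent_actions)]
    joint = DiscreteSpace(len(action_names), action_names).repeat(1)
    assert joint.sample().shape == (1,)  # a single centralized agent
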
@@ -1,20 +1,22 @@
- from .rlenv_wrapper import RLEnvWrapper, MARLEnv
- from marlenv.models import ActionSpace
- from typing_extensions import TypeVar
- import numpy.typing as npt
- import numpy as np
- from dataclasses import dataclass
  from collections import deque
+ from dataclasses import dataclass
+
+ import numpy as np
+ import numpy.typing as npt
+ from typing_extensions import TypeVar
+
+ from marlenv.models import Space
+
+ from .rlenv_wrapper import MARLEnv, RLEnvWrapper

- A = TypeVar("A", default=npt.NDArray)
- AS = TypeVar("AS", bound=ActionSpace, default=ActionSpace)
+ AS = TypeVar("AS", bound=Space, default=Space)


  @dataclass
- class DelayedReward(RLEnvWrapper[A, AS]):
+ class DelayedReward(RLEnvWrapper[AS]):
      delay: int

-     def __init__(self, env: MARLEnv[A, AS], delay: int):
+     def __init__(self, env: MARLEnv[AS], delay: int):
          super().__init__(env)
          self.delay = delay
          self.reward_queue = deque[npt.NDArray[np.float32]](maxlen=delay + 1)
@@ -25,7 +27,7 @@ class DelayedReward(RLEnvWrapper[A, AS]):
          self.reward_queue.append(np.zeros(self.reward_space.shape, dtype=np.float32))
          return super().reset()

-     def step(self, actions: A):
+     def step(self, actions):
          step = super().step(actions)
          self.reward_queue.append(step.reward)
          # If the step is terminal, we sum all the remaining rewards
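
A usage sketch (delay value illustrative):

    from marlenv.wrappers import DelayedReward  # import path assumed

    env = DelayedReward(my_env, delay=5)
    # Rewards are paid out 5 steps late; on a terminal step the queued
    # rewards are summed so none are lost.
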
@@ -1,25 +1,21 @@
  from dataclasses import dataclass
- from typing_extensions import TypeVar
- from typing import Sequence

  import numpy as np
  import numpy.typing as npt
+ from typing_extensions import TypeVar

- from marlenv.models import State, ActionSpace, ContinuousActionSpace, DiscreteActionSpace
+ from marlenv.models import ContinuousSpace, DiscreteSpace, MultiDiscreteSpace, Space, State

  from .rlenv_wrapper import MARLEnv, RLEnvWrapper

- AS = TypeVar("AS", bound=ActionSpace, default=ActionSpace)
- DiscreteActionType = npt.NDArray[np.int64 | np.int32] | Sequence[int]
- ContinuousActionType = npt.NDArray[np.float32] | Sequence[Sequence[float]]
- A = TypeVar("A", bound=DiscreteActionType | ContinuousActionType)
+ AS = TypeVar("AS", bound=Space, default=Space)


  @dataclass
- class LastAction(RLEnvWrapper[A, AS]):
+ class LastAction(RLEnvWrapper[AS]):
      """Env wrapper that adds the last action taken by the agents to the extra features."""

-     def __init__(self, env: MARLEnv[A, AS]):
+     def __init__(self, env: MARLEnv[AS]):
          assert len(env.extras_shape) == 1, "Adding last action is only possible with 1D extras"
          super().__init__(
              env,
@@ -37,13 +33,13 @@ class LastAction(RLEnvWrapper[A, AS]):
          state.add_extra(self.last_one_hot_actions.flatten())
          return obs, state

-     def step(self, actions: A):
+     def step(self, actions):
          step = super().step(actions)
          match self.wrapped.action_space:
-             case ContinuousActionSpace():
+             case ContinuousSpace():
                  self.last_actions = actions
-             case DiscreteActionSpace():
-                 self.last_one_hot_actions = self.compute_one_hot_actions(actions)  # type: ignore
+             case DiscreteSpace() | MultiDiscreteSpace():
+                 self.last_one_hot_actions = self.compute_one_hot_actions(actions)
              case other:
                  raise NotImplementedError(f"Action space {other} not supported")
          step.obs.add_extra(self.last_one_hot_actions)
@@ -60,7 +56,7 @@ class LastAction(RLEnvWrapper[A, AS]):
          self.last_one_hot_actions = flattened_one_hots.reshape(self.n_agents, self.n_actions)
          return super().set_state(state)

-     def compute_one_hot_actions(self, actions: DiscreteActionType) -> npt.NDArray:
+     def compute_one_hot_actions(self, actions) -> npt.NDArray:
          one_hot_actions = np.zeros((self.n_agents, self.n_actions), dtype=np.float32)
          index = np.arange(self.n_agents)
          one_hot_actions[index, actions] = 1.0
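
For reference, the indexing trick in compute_one_hot_actions on a concrete input (values illustrative):

    import numpy as np

    actions = np.array([2, 0, 3])  # 3 agents, 4 actions
    one_hot = np.zeros((3, 4), dtype=np.float32)
    one_hot[np.arange(3), actions] = 1.0  # row i gets a 1 at column actions[i]
    # one_hot[0] == [0., 0., 1., 0.]
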
@@ -1,22 +1,20 @@
  import numpy as np
- import numpy.typing as npt
  from dataclasses import dataclass
- from marlenv.models import Observation, ActionSpace
+ from marlenv.models import Observation, Space
  from .rlenv_wrapper import RLEnvWrapper, MARLEnv
  from typing_extensions import TypeVar


- A = TypeVar("A", default=npt.NDArray)
- AS = TypeVar("AS", bound=ActionSpace, default=ActionSpace)
+ AS = TypeVar("AS", bound=Space, default=Space)


  @dataclass
- class PadExtras(RLEnvWrapper[A, AS]):
+ class PadExtras(RLEnvWrapper[AS]):
      """RLEnv wrapper that adds extra zeros at the end of the observation extras."""

      n: int

-     def __init__(self, env: MARLEnv[A, AS], n_added: int):
+     def __init__(self, env: MARLEnv[AS], n_added: int):
          assert len(env.extras_shape) == 1, "PadExtras only accepts 1D extras"
          meanings = env.extras_meanings + [f"Padding-{i}" for i in range(n_added)]
          super().__init__(
@@ -42,10 +40,10 @@ class PadExtras(RLEnvWrapper[A, AS]):


  @dataclass
- class PadObservations(RLEnvWrapper[A, AS]):
+ class PadObservations(RLEnvWrapper[AS]):
      """RLEnv wrapper that adds extra zeros at the end of the observation data."""

-     def __init__(self, env: MARLEnv[A, AS], n_added: int) -> None:
+     def __init__(self, env: MARLEnv[AS], n_added: int) -> None:
          assert len(env.observation_shape) == 1, "PadObservations only accepts 1D observations"
          super().__init__(env, observation_shape=(env.observation_shape[0] + n_added,))
          self.n = n_added
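
A usage sketch (counts illustrative):

    from marlenv.wrappers import PadExtras, PadObservations  # import paths assumed

    env = PadExtras(my_env, n_added=4)     # extras gain 4 zero features
    env = PadObservations(env, n_added=2)  # observations gain 2 trailing zeros
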
@@ -1,21 +1,18 @@
  from dataclasses import dataclass
  import numpy as np
- import numpy.typing as npt
- from marlenv.models import ActionSpace
+ from marlenv.models import Space
  from .rlenv_wrapper import RLEnvWrapper, MARLEnv
- # from ..models.rl_env import MOMARLEnv

  from typing_extensions import TypeVar

- A = TypeVar("A", default=npt.NDArray)
- AS = TypeVar("AS", bound=ActionSpace, default=ActionSpace)
+ AS = TypeVar("AS", bound=Space, default=Space)


  @dataclass
- class TimePenalty(RLEnvWrapper[A, AS]):
+ class TimePenalty(RLEnvWrapper[AS]):
      penalty: float | np.ndarray

-     def __init__(self, env: MARLEnv[A, AS], penalty: float | list[float]):
+     def __init__(self, env: MARLEnv[AS], penalty: float | list[float]):
          super().__init__(env)

          if env.is_multi_objective:
@@ -26,7 +23,7 @@ class TimePenalty(RLEnvWrapper[A, AS]):
          assert isinstance(penalty, (float, int))
          self.penalty = penalty

-     def step(self, action: A):
+     def step(self, action):
          step = self.wrapped.step(action)
          step.reward = step.reward - self.penalty
          return step
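
A usage sketch (penalty values illustrative; the list form is apparently reserved for multi-objective environments, per the is_multi_objective branch above):

    from marlenv.wrappers import TimePenalty  # import path assumed

    env = TimePenalty(my_env, penalty=0.1)  # subtracts 0.1 from every step's reward
    # env = TimePenalty(my_mo_env, penalty=[0.1, 0.0])  # one penalty per objective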