multi-agent-rlenv 3.3.7__py3-none-any.whl → 3.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,281 @@
+ from abc import abstractmethod
+ from dataclasses import dataclass
+ from typing import Callable, Optional, TypeVar
+
+ T = TypeVar("T")
+
+
+ @dataclass
+ class Schedule:
+     """
+     Schedules the value of a variable over time.
+     """
+
+     name: str
+     start_value: float
+     end_value: float
+     _t: int
+     n_steps: int
+
+     def __init__(self, start_value: float, end_value: float, n_steps: int):
+         self.start_value = start_value
+         self.end_value = end_value
+         self.n_steps = n_steps
+         self.name = self.__class__.__name__
+         self._t = 0
+         self._current_value = self.start_value
+
+     def update(self, step: Optional[int] = None):
+         """Update the value of the schedule. Force a step if given."""
+         if step is not None:
+             self._t = step
+         else:
+             self._t += 1
+         if self._t >= self.n_steps:
+             self._current_value = self.end_value
+         else:
+             self._current_value = self._compute()
+
+     @abstractmethod
+     def _compute(self) -> float:
+         """Compute the value of the schedule"""
+
+     @property
+     def value(self) -> float:
+         """Returns the current value of the schedule"""
+         return self._current_value
+
+     @staticmethod
+     def constant(value: float):
+         return ConstantSchedule(value)
+
+     @staticmethod
+     def linear(start_value: float, end_value: float, n_steps: int):
+         return LinearSchedule(start_value, end_value, n_steps)
+
+     @staticmethod
+     def exp(start_value: float, end_value: float, n_steps: int):
+         return ExpSchedule(start_value, end_value, n_steps)
+
+     @staticmethod
+     def arbitrary(func: Callable[[int], float], n_steps: Optional[int] = None):
+         if n_steps is None:
+             n_steps = 0
+         return ArbitrarySchedule(func, n_steps)
+
+     def rounded(self, n_digits: int = 0) -> "RoundedSchedule":
+         return RoundedSchedule(self, n_digits)
+
+     # Operator overloading
+     def __mul__(self, other: T) -> T:
+         return self.value * other  # type: ignore
+
+     def __rmul__(self, other: T) -> T:
+         return self.value * other  # type: ignore
+
+     def __pow__(self, exp: float) -> float:
+         return self.value**exp
+
+     def __rpow__(self, other: T) -> T:
+         return other**self.value  # type: ignore
+
+     def __add__(self, other: T) -> T:
+         return self.value + other  # type: ignore
+
+     def __radd__(self, other: T) -> T:
+         return self.value + other  # type: ignore
+
+     def __neg__(self):
+         return -self.value
+
+     def __pos__(self):
+         return +self.value
+
+     def __sub__(self, other: T) -> T:
+         return self.value - other  # type: ignore
+
+     def __rsub__(self, other: T) -> T:
+         return other - self.value  # type: ignore
+
+     def __div__(self, other: T) -> T:
+         return self.value // other  # type: ignore
+
+     def __rdiv__(self, other: T) -> T:
+         return other // self.value  # type: ignore
+
+     def __truediv__(self, other: T) -> T:
+         return self.value / other  # type: ignore
+
+     def __rtruediv__(self, other: T) -> T:
+         return other / self.value  # type: ignore
+
+     def __lt__(self, other) -> bool:
+         return self.value < other
+
+     def __le__(self, other) -> bool:
+         return self.value <= other
+
+     def __gt__(self, other) -> bool:
+         return self.value > other
+
+     def __ge__(self, other) -> bool:
+         return self.value >= other
+
+     def __eq__(self, other) -> bool:
+         if isinstance(other, Schedule):
+             if self.start_value != other.start_value:
+                 return False
+             if self.end_value != other.end_value:
+                 return False
+             if self.n_steps != other.n_steps:
+                 return False
+             if type(self) is not type(other):
+                 return False
+         return self.value == other
+
+     def __ne__(self, other) -> bool:
+         return not (self.__eq__(other))
+
+     def __float__(self):
+         return self.value
+
+     def __int__(self) -> int:
+         return int(self.value)
+
+
+ @dataclass(eq=False)
+ class LinearSchedule(Schedule):
+     a: float
+     b: float
+
+     def __init__(self, start_value: float, end_value: float, n_steps: int):
+         super().__init__(start_value, end_value, n_steps)
+         self._current_value = self.start_value
+         # y = ax + b
+         self.a = (self.end_value - self.start_value) / self.n_steps
+         self.b = self.start_value
+
+     def _compute(self):
+         return self.a * (self._t) + self.b
+
+     @property
+     def value(self) -> float:
+         return self._current_value
+
+
+ @dataclass(eq=False)
+ class ExpSchedule(Schedule):
+     """Exponential schedule. After n_steps, the value will be min_value.
+
+     Update formula is next_value = start_value * (min_value / start_value) ** (step / (n - 1))
+     """
+
+     n_steps: int
+     base: float
+     last_update_step: int
+
+     def __init__(self, start_value: float, min_value: float, n_steps: int):
+         super().__init__(start_value, min_value, n_steps)
+         self.base = self.end_value / self.start_value
+         self.last_update_step = self.n_steps - 1
+
+     def _compute(self):
+         return self.start_value * (self.base) ** (self._t / (self.n_steps - 1))
+
+     @property
+     def value(self) -> float:
+         return self._current_value
+
+
+ @dataclass(eq=False)
+ class ConstantSchedule(Schedule):
+     def __init__(self, value: float):
+         super().__init__(value, value, 0)
+         self._value = value
+
+     def update(self, step=None):
+         return
+
+     @property
+     def value(self) -> float:
+         return self._value
+
+
+ @dataclass(eq=False)
+ class RoundedSchedule(Schedule):
+     def __init__(self, schedule: Schedule, n_digits: int):
+         super().__init__(schedule.start_value, schedule.end_value, schedule.n_steps)
+         self.schedule = schedule
+         self.n_digits = n_digits
+
+     def update(self, step: int | None = None):
+         return self.schedule.update(step)
+
+     def _compute(self) -> float:
+         return self.schedule._compute()
+
+     @property
+     def value(self) -> float:
+         return round(self.schedule.value, self.n_digits)
+
+
+ @dataclass(eq=False)
+ class MultiSchedule(Schedule):
+     def __init__(self, schedules: dict[int, Schedule]):
+         ordered_schedules, ordered_steps = MultiSchedule._verify(schedules)
+         n_steps = ordered_steps[-1] + ordered_schedules[-1].n_steps
+         super().__init__(ordered_schedules[0].start_value, ordered_schedules[-1].end_value, n_steps)
+         self.schedules = iter(ordered_schedules)
+         self.current_schedule = next(self.schedules)
+         self.offset = 0
+         self.current_end = ordered_steps[1]
+
+     @staticmethod
+     def _verify(schedules: dict[int, Schedule]):
+         sorted_steps = sorted(schedules.keys())
+         sorted_schedules = [schedules[t] for t in sorted_steps]
+         if sorted_steps[0] != 0:
+             raise ValueError("First schedule must start at t=0")
+         current_step = 0
+         for i in range(len(sorted_steps)):
+             # Artificially set the end step of ConstantSchedules to the next step
+             if isinstance(sorted_schedules[i], ConstantSchedule):
+                 if i + 1 < len(sorted_steps):
+                     sorted_schedules[i].n_steps = sorted_steps[i + 1] - sorted_steps[i]
+             if sorted_steps[i] != current_step:
+                 raise ValueError(f"Schedules are not contiguous at t={current_step}")
+             current_step += sorted_schedules[i].n_steps
+         return sorted_schedules, sorted_steps
+
+     def update(self, step: int | None = None):
+         if step is not None:
+             raise NotImplementedError("Cannot update MultiSchedule with a specific step")
+         super().update(step)
+         # If we reach the end of the current schedule, update to the next one
+         # except if we are at the end.
+         if self._t == self.current_end and self._t < self.n_steps:
+             self.current_schedule = next(self.schedules)
+             self.offset = self._t
+             self.current_end += self.current_schedule.n_steps
+         self.current_schedule.update(self.relative_step)
+
+     @property
+     def relative_step(self):
+         return self._t - self.offset
+
+     def _compute(self) -> float:
+         return self.current_schedule._compute()
+
+     @property
+     def value(self):
+         return self.current_schedule.value
+
+
+ @dataclass(eq=False)
+ class ArbitrarySchedule(Schedule):
+     def __init__(self, fn: Callable[[int], float], n_steps: int):
+         super().__init__(fn(0), fn(n_steps), n_steps)
+         self._func = fn
+
+     def _compute(self) -> float:
+         return self._func(self._t)
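The schedule module above is the main addition in 3.5.0: a schedule is advanced explicitly with update() and, thanks to the operator overloads, can be used wherever a float is expected (exploration rate, learning rate, and so on). A minimal usage sketch follows; the import path marlenv.utils is an assumption, since the file's location is not visible in this diff.

# Illustrative sketch only; the import location below is assumed, not shown in the diff.
from marlenv.utils import Schedule  # hypothetical export path

epsilon = Schedule.linear(1.0, 0.05, n_steps=10_000)       # epsilon-greedy decay
lr = Schedule.exp(1e-3, 1e-5, n_steps=50_000).rounded(6)   # exponential decay, rounded to 6 digits

for t in range(10_000):
    if 0.5 * epsilon > 0.1:  # operator overloads return plain floats
        pass                 # e.g. branch on the current exploration rate
    epsilon.update()         # advance by one step
    lr.update(step=t)        # or force an explicit step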
@@ -1,19 +1,18 @@
  import numpy as np
- from marlenv.models import MARLEnv, ActionSpace
+ from marlenv.models import MARLEnv, Space
  from dataclasses import dataclass
  from .rlenv_wrapper import RLEnvWrapper

  from typing_extensions import TypeVar

- A = TypeVar("A", default=np.ndarray)
- AS = TypeVar("AS", bound=ActionSpace, default=ActionSpace)
+ AS = TypeVar("AS", bound=Space, default=Space)


  @dataclass
- class AgentId(RLEnvWrapper[A, AS]):
+ class AgentId(RLEnvWrapper[AS]):
      """RLEnv wrapper that adds a one-hot encoding of the agent id."""

-     def __init__(self, env: MARLEnv[A, AS]):
+     def __init__(self, env: MARLEnv[AS]):
          assert len(env.extras_shape) == 1, "AgentIdWrapper only works with single dimension extras"
          meanings = env.extras_meanings + [f"Agent ID-{i}" for i in range(env.n_agents)]
          super().__init__(env, extra_shape=(env.n_agents + env.extras_shape[0],), extra_meanings=meanings)
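Every wrapper in this release drops the separate action TypeVar and re-types its space parameter against Space instead of ActionSpace, so RLEnvWrapper[A, AS] becomes RLEnvWrapper[AS] throughout. A hedged migration sketch for downstream code (MyWrapper and the marlenv.wrappers import path are illustrative assumptions; only the generic parameters mirror the diff):

# 3.3.7: class MyWrapper(RLEnvWrapper[A, AS]) with A = TypeVar("A", default=np.ndarray)
# 3.5.0: the action TypeVar is gone and Space replaces ActionSpace
from dataclasses import dataclass

from typing_extensions import TypeVar

from marlenv.models import MARLEnv, Space
from marlenv.wrappers import RLEnvWrapper  # assumed import path

AS = TypeVar("AS", bound=Space, default=Space)


@dataclass
class MyWrapper(RLEnvWrapper[AS]):  # hypothetical downstream wrapper
    def __init__(self, env: MARLEnv[AS]):
        super().__init__(env)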
@@ -2,20 +2,19 @@ import numpy as np
  import numpy.typing as npt
  from typing_extensions import TypeVar
  from .rlenv_wrapper import MARLEnv, RLEnvWrapper
- from marlenv.models import ActionSpace
+ from marlenv.models import Space
  from dataclasses import dataclass

- A = TypeVar("A", default=npt.NDArray)
- AS = TypeVar("AS", bound=ActionSpace, default=ActionSpace)
+ AS = TypeVar("AS", bound=Space, default=Space)


  @dataclass
- class AvailableActionsMask(RLEnvWrapper[A, AS]):
+ class AvailableActionsMask(RLEnvWrapper[AS]):
      """Permanently masks a subset of the available actions."""

      action_mask: npt.NDArray[np.bool_]

-     def __init__(self, env: MARLEnv[A, AS], action_mask: npt.NDArray[np.bool_]):
+     def __init__(self, env: MARLEnv[AS], action_mask: npt.NDArray[np.bool_]):
          super().__init__(env)
          assert action_mask.shape == (env.n_agents, env.n_actions), "Action mask must have shape (n_agents, n_actions)."
          n_available_action_per_agent = action_mask.sum(axis=-1)
@@ -27,8 +26,8 @@ class AvailableActionsMask(RLEnvWrapper[A, AS]):
          obs.available_actions = self.available_actions()
          return obs, state

-     def step(self, actions):
-         step = self.wrapped.step(actions)
+     def step(self, action):
+         step = self.wrapped.step(action)
          step.obs.available_actions = self.available_actions()
          return step

@@ -1,21 +1,19 @@
  import numpy as np
- import numpy.typing as npt
  from typing_extensions import TypeVar
- from marlenv.models import ActionSpace, MARLEnv
+ from marlenv.models import Space, MARLEnv
  from .rlenv_wrapper import RLEnvWrapper
  from dataclasses import dataclass


- A = TypeVar("A", default=npt.NDArray)
- AS = TypeVar("AS", bound=ActionSpace, default=ActionSpace)
+ AS = TypeVar("AS", bound=Space, default=Space)


  @dataclass
- class AvailableActions(RLEnvWrapper[A, AS]):
+ class AvailableActions(RLEnvWrapper[AS]):
      """Adds the available actions (one-hot) as an extra feature to the observation."""

-     def __init__(self, env: MARLEnv[A, AS]):
-         meanings = env.extras_meanings + [f"{a} available" for a in env.action_space.action_names]
+     def __init__(self, env: MARLEnv[AS]):
+         meanings = env.extras_meanings + [f"{a} available" for a in env.action_space.labels]
          super().__init__(env, extra_shape=(env.extras_shape[0] + env.n_actions,), extra_meanings=meanings)

      def reset(self):
@@ -23,7 +21,7 @@ class AvailableActions(RLEnvWrapper[A, AS]):
          obs.add_extra(self.available_actions().astype(np.float32))
          return obs, state

-     def step(self, actions: A):
-         step = self.wrapped.step(actions)
+     def step(self, action):
+         step = self.wrapped.step(action)
          step.obs.add_extra(self.available_actions().astype(np.float32))
          return step
@@ -1,26 +1,24 @@
  import random
  from typing_extensions import TypeVar
  import numpy as np
- import numpy.typing as npt
  from dataclasses import dataclass

- from marlenv.models import MARLEnv, ActionSpace
+ from marlenv.models import MARLEnv, Space
  from .rlenv_wrapper import RLEnvWrapper


- A = TypeVar("A", default=npt.NDArray)
- AS = TypeVar("AS", bound=ActionSpace, default=ActionSpace)
+ AS = TypeVar("AS", bound=Space, default=Space)


  @dataclass
- class Blind(RLEnvWrapper[A, AS]):
+ class Blind(RLEnvWrapper[AS]):
      p: float

-     def __init__(self, env: MARLEnv[A, AS], p: float | int):
+     def __init__(self, env: MARLEnv[AS], p: float | int):
          super().__init__(env)
          self.p = float(p)

-     def step(self, actions: A):
+     def step(self, actions):
          step = super().step(actions)
          if random.random() < self.p:
              step.obs.data = np.zeros_like(step.obs.data)
@@ -4,28 +4,26 @@ from typing import Sequence

  import numpy as np
  import numpy.typing as npt
- from typing_extensions import TypeVar

- from marlenv.models import ActionSpace, DiscreteActionSpace, DiscreteSpace, MARLEnv, Observation
+ from marlenv.models import DiscreteSpace, MARLEnv, MultiDiscreteSpace, Observation

  from .rlenv_wrapper import RLEnvWrapper

- A = TypeVar("A", bound=npt.NDArray | Sequence[int] | Sequence[Sequence[float]])
-

  @dataclass
- class Centralized(RLEnvWrapper[A, DiscreteActionSpace]):
-     joint_action_space: ActionSpace
+ class Centralized(RLEnvWrapper[MultiDiscreteSpace]):
+     joint_action_space: DiscreteSpace

-     def __init__(self, env: MARLEnv[A, DiscreteActionSpace]):
-         if not isinstance(env.action_space.individual_action_space, DiscreteSpace):
+     def __init__(self, env: MARLEnv[MultiDiscreteSpace]):
+         if not isinstance(env.action_space, MultiDiscreteSpace):
              raise NotImplementedError(f"Action space {env.action_space} not supported")
          joint_observation_shape = (env.observation_shape[0] * env.n_agents, *env.observation_shape[1:])
          super().__init__(
              env,
-             joint_observation_shape,
-             env.state_shape,
-             env.extras_shape,
+             n_agents=1,
+             observation_shape=joint_observation_shape,
+             state_shape=env.state_shape,
+             state_extra_shape=env.extras_shape,
              action_space=self._make_joint_action_space(env),
          )

@@ -37,12 +35,12 @@ class Centralized(RLEnvWrapper[A, DiscreteActionSpace]):
          obs = super().get_observation()
          return self._joint_observation(obs)

-     def _make_joint_action_space(self, env: MARLEnv[A, DiscreteActionSpace]):
+     def _make_joint_action_space(self, env: MARLEnv[MultiDiscreteSpace]):
          agent_actions = list[list[str]]()
          for agent in range(env.n_agents):
-             agent_actions.append([f"{agent}-{action}" for action in env.action_space.action_names])
+             agent_actions.append([f"{agent}-{action}" for action in env.action_space.labels])
          action_names = [str(a) for a in product(*agent_actions)]
-         return DiscreteActionSpace(1, env.n_actions**env.n_agents, action_names)
+         return DiscreteSpace(env.n_actions**env.n_agents, action_names).repeat(1)

      def step(self, actions: npt.NDArray | Sequence):
          action = actions[0]
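The Centralized wrapper now exposes the environment as a single agent acting in the Cartesian product of the per-agent action spaces, built with DiscreteSpace(...).repeat(1) instead of the former DiscreteActionSpace(1, ...). A quick size check (3 agents with 4 actions each are made-up figures, purely for illustration):

# Size of the joint space built by _make_joint_action_space above; numbers are hypothetical.
n_agents, n_actions = 3, 4
n_joint_actions = n_actions**n_agents  # 4 ** 3 = 64 actions for the single centralized agent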
@@ -1,20 +1,22 @@
- from .rlenv_wrapper import RLEnvWrapper, MARLEnv
- from marlenv.models import ActionSpace
- from typing_extensions import TypeVar
- import numpy.typing as npt
- import numpy as np
- from dataclasses import dataclass
  from collections import deque
+ from dataclasses import dataclass
+
+ import numpy as np
+ import numpy.typing as npt
+ from typing_extensions import TypeVar
+
+ from marlenv.models import Space
+
+ from .rlenv_wrapper import MARLEnv, RLEnvWrapper

- A = TypeVar("A", default=npt.NDArray)
- AS = TypeVar("AS", bound=ActionSpace, default=ActionSpace)
+ AS = TypeVar("AS", bound=Space, default=Space)


  @dataclass
- class DelayedReward(RLEnvWrapper[A, AS]):
+ class DelayedReward(RLEnvWrapper[AS]):
      delay: int

-     def __init__(self, env: MARLEnv[A, AS], delay: int):
+     def __init__(self, env: MARLEnv[AS], delay: int):
          super().__init__(env)
          self.delay = delay
          self.reward_queue = deque[npt.NDArray[np.float32]](maxlen=delay + 1)
@@ -25,7 +27,7 @@ class DelayedReward(RLEnvWrapper[A, AS]):
          self.reward_queue.append(np.zeros(self.reward_space.shape, dtype=np.float32))
          return super().reset()

-     def step(self, actions: A):
+     def step(self, actions):
          step = super().step(actions)
          self.reward_queue.append(step.reward)
          # If the step is terminal, we sum all the remaining rewards
@@ -1,25 +1,21 @@
  from dataclasses import dataclass
- from typing_extensions import TypeVar
- from typing import Sequence

  import numpy as np
  import numpy.typing as npt
+ from typing_extensions import TypeVar

- from marlenv.models import State, ActionSpace, ContinuousActionSpace, DiscreteActionSpace
+ from marlenv.models import ContinuousSpace, DiscreteSpace, MultiDiscreteSpace, Space, State

  from .rlenv_wrapper import MARLEnv, RLEnvWrapper

- AS = TypeVar("AS", bound=ActionSpace, default=ActionSpace)
- DiscreteActionType = npt.NDArray[np.int64 | np.int32] | Sequence[int]
- ContinuousActionType = npt.NDArray[np.float32] | Sequence[Sequence[float]]
- A = TypeVar("A", bound=DiscreteActionType | ContinuousActionType)
+ AS = TypeVar("AS", bound=Space, default=Space)


  @dataclass
- class LastAction(RLEnvWrapper[A, AS]):
+ class LastAction(RLEnvWrapper[AS]):
      """Env wrapper that adds the last action taken by the agents to the extra features."""

-     def __init__(self, env: MARLEnv[A, AS]):
+     def __init__(self, env: MARLEnv[AS]):
          assert len(env.extras_shape) == 1, "Adding last action is only possible with 1D extras"
          super().__init__(
              env,
@@ -37,13 +33,13 @@ class LastAction(RLEnvWrapper[A, AS]):
          state.add_extra(self.last_one_hot_actions.flatten())
          return obs, state

-     def step(self, actions: A):
+     def step(self, actions):
          step = super().step(actions)
          match self.wrapped.action_space:
-             case ContinuousActionSpace():
+             case ContinuousSpace():
                  self.last_actions = actions
-             case DiscreteActionSpace():
-                 self.last_one_hot_actions = self.compute_one_hot_actions(actions)  # type: ignore
+             case DiscreteSpace() | MultiDiscreteSpace():
+                 self.last_one_hot_actions = self.compute_one_hot_actions(actions)
              case other:
                  raise NotImplementedError(f"Action space {other} not supported")
          step.obs.add_extra(self.last_one_hot_actions)
@@ -60,7 +56,7 @@ class LastAction(RLEnvWrapper[A, AS]):
          self.last_one_hot_actions = flattened_one_hots.reshape(self.n_agents, self.n_actions)
          return super().set_state(state)

-     def compute_one_hot_actions(self, actions: DiscreteActionType) -> npt.NDArray:
+     def compute_one_hot_actions(self, actions) -> npt.NDArray:
          one_hot_actions = np.zeros((self.n_agents, self.n_actions), dtype=np.float32)
          index = np.arange(self.n_agents)
          one_hot_actions[index, actions] = 1.0
@@ -1,22 +1,20 @@
  import numpy as np
- import numpy.typing as npt
  from dataclasses import dataclass
- from marlenv.models import Observation, ActionSpace
+ from marlenv.models import Observation, Space
  from .rlenv_wrapper import RLEnvWrapper, MARLEnv
  from typing_extensions import TypeVar


- A = TypeVar("A", default=npt.NDArray)
- AS = TypeVar("AS", bound=ActionSpace, default=ActionSpace)
+ AS = TypeVar("AS", bound=Space, default=Space)


  @dataclass
- class PadExtras(RLEnvWrapper[A, AS]):
+ class PadExtras(RLEnvWrapper[AS]):
      """RLEnv wrapper that adds extra zeros at the end of the observation extras."""

      n: int

-     def __init__(self, env: MARLEnv[A, AS], n_added: int):
+     def __init__(self, env: MARLEnv[AS], n_added: int):
          assert len(env.extras_shape) == 1, "PadExtras only accepts 1D extras"
          meanings = env.extras_meanings + [f"Padding-{i}" for i in range(n_added)]
          super().__init__(
@@ -42,10 +40,10 @@ class PadExtras(RLEnvWrapper[A, AS]):


  @dataclass
- class PadObservations(RLEnvWrapper[A, AS]):
+ class PadObservations(RLEnvWrapper[AS]):
      """RLEnv wrapper that adds extra zeros at the end of the observation data."""

-     def __init__(self, env: MARLEnv[A, AS], n_added: int) -> None:
+     def __init__(self, env: MARLEnv[AS], n_added: int) -> None:
          assert len(env.observation_shape) == 1, "PadObservations only accepts 1D observations"
          super().__init__(env, observation_shape=(env.observation_shape[0] + n_added,))
          self.n = n_added
@@ -1,21 +1,18 @@
  from dataclasses import dataclass
  import numpy as np
- import numpy.typing as npt
- from marlenv.models import ActionSpace
+ from marlenv.models import Space
  from .rlenv_wrapper import RLEnvWrapper, MARLEnv
- # from ..models.rl_env import MOMARLEnv

  from typing_extensions import TypeVar

- A = TypeVar("A", default=npt.NDArray)
- AS = TypeVar("AS", bound=ActionSpace, default=ActionSpace)
+ AS = TypeVar("AS", bound=Space, default=Space)


  @dataclass
- class TimePenalty(RLEnvWrapper[A, AS]):
+ class TimePenalty(RLEnvWrapper[AS]):
      penalty: float | np.ndarray

-     def __init__(self, env: MARLEnv[A, AS], penalty: float | list[float]):
+     def __init__(self, env: MARLEnv[AS], penalty: float | list[float]):
          super().__init__(env)

          if env.is_multi_objective:
@@ -26,7 +23,7 @@ class TimePenalty(RLEnvWrapper[A, AS]):
              assert isinstance(penalty, (float, int))
              self.penalty = penalty

-     def step(self, action: A):
+     def step(self, action):
          step = self.wrapped.step(action)
          step.reward = step.reward - self.penalty
          return step
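Apart from the type-parameter changes, the wrapper constructors keep the signatures shown above, so they still compose by nesting. A hedged composition sketch (marlenv.models.MARLEnv appears in the diff; the marlenv.wrappers import path and the wrap() helper are assumptions):

from marlenv.models import MARLEnv
from marlenv.wrappers import AgentId, Blind, DelayedReward, TimePenalty  # assumed import path


def wrap(env: MARLEnv):
    """Chain the wrappers touched by this release around an existing environment."""
    env = AgentId(env)                   # append a one-hot agent id to the extras
    env = Blind(env, p=0.2)              # zero the observation with probability 0.2
    env = DelayedReward(env, delay=5)    # deliver rewards 5 steps late
    env = TimePenalty(env, penalty=0.1)  # subtract 0.1 from every reward
    return env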