multi-agent-rlenv 3.3.7__py3-none-any.whl → 3.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
marlenv/models/episode.py CHANGED
@@ -1,6 +1,6 @@
1
1
  from dataclasses import dataclass
2
2
  from functools import cached_property
3
- from typing import Any, Callable, Generic, Optional, Sequence, TypeVar, overload
3
+ from typing import Any, Callable, Optional, Sequence, overload
4
4
 
5
5
  import numpy as np
6
6
  import numpy.typing as npt
@@ -14,11 +14,8 @@ from .env import MARLEnv
14
14
  from marlenv.exceptions import EnvironmentMismatchException, ReplayMismatchException
15
15
 
16
16
 
17
- A = TypeVar("A")
18
-
19
-
20
17
  @dataclass
21
- class Episode(Generic[A]):
18
+ class Episode:
22
19
  """Episode model made of observations, actions, rewards, ..."""
23
20
 
24
21
  all_observations: list[npt.NDArray[np.float32]]
@@ -55,7 +52,7 @@ class Episode(Generic[A]):
55
52
  )
56
53
 
57
54
  @staticmethod
58
- def from_transitions(transitions: Sequence[Transition[A]]) -> "Episode":
55
+ def from_transitions(transitions: Sequence[Transition]) -> "Episode":
59
56
  """Create an episode from a list of transitions"""
60
57
  episode = Episode.new(transitions[0].obs, transitions[0].state)
61
58
  for transition in transitions:
@@ -179,9 +176,9 @@ class Episode(Generic[A]):
179
176
  @cached_property
180
177
  def dones(self):
181
178
  """The done flags for each transition"""
182
- dones = np.zeros_like(self.rewards, dtype=np.float32)
179
+ dones = np.zeros_like(self.rewards, dtype=np.bool)
183
180
  if self.is_done:
184
- dones[self.episode_len - 1 :] = 1.0
181
+ dones[self.episode_len - 1 :] = True
185
182
  return dones
186
183
 
187
184
  @property
@@ -214,11 +211,11 @@ class Episode(Generic[A]):
214
211
 
215
212
  def replay(
216
213
  self,
217
- env: MARLEnv[A, Any],
214
+ env: MARLEnv,
218
215
  seed: Optional[int] = None,
219
216
  *,
220
- after_reset: Optional[Callable[[Observation, State, MARLEnv[A]], None]] = None,
221
- after_step: Optional[Callable[[int, Step, MARLEnv[A]], None]] = None,
217
+ after_reset: Optional[Callable[[Observation, State, MARLEnv], None]] = None,
218
+ after_step: Optional[Callable[[int, Step, MARLEnv], None]] = None,
222
219
  ):
223
220
  """
224
221
  Replay the episode in the environment (i.e. perform the actions) and assert that the outcomes match.
@@ -243,12 +240,12 @@ class Episode(Generic[A]):
243
240
  raise ReplayMismatchException("observation", step.obs.data, self.next_obs[i], time_step=i)
244
241
  if not np.array_equal(step.state.data, self.next_states[i]):
245
242
  raise ReplayMismatchException("state", step.state.data, self.next_states[i], time_step=i)
246
- if not np.array_equal(step.reward, self.rewards[i]):
243
+ if not np.isclose(step.reward, self.rewards[i]):
247
244
  raise ReplayMismatchException("reward", step.reward, self.rewards[i], time_step=i)
248
245
  if after_step is not None:
249
246
  after_step(i, step, env)
250
247
 
251
- def get_images(self, env: MARLEnv[A, Any], seed: Optional[int] = None) -> list[np.ndarray]:
248
+ def get_images(self, env: MARLEnv, seed: Optional[int] = None) -> list[np.ndarray]:
252
249
  images = []
253
250
 
254
251
  def collect_image(*_, **__):
@@ -257,7 +254,7 @@ class Episode(Generic[A]):
257
254
  self.replay(env, seed, after_reset=collect_image, after_step=collect_image)
258
255
  return images
259
256
 
260
- def render(self, env: MARLEnv[A, Any], seed: Optional[int] = None, fps: int = 5):
257
+ def render(self, env: MARLEnv, seed: Optional[int] = None, fps: int = 5):
261
258
  def render_callback(*_, **__):
262
259
  env.render()
263
260
  cv2.waitKey(1000 // fps)
@@ -288,10 +285,10 @@ class Episode(Generic[A]):
288
285
  return returns
289
286
 
290
287
  @overload
291
- def add(self, transition: Transition[A], /): ...
288
+ def add(self, transition: Transition, /): ...
292
289
 
293
290
  @overload
294
- def add(self, step: Step, action: A, /): ...
291
+ def add(self, step: Step, action: np.ndarray, /): ...
295
292
 
296
293
  def add(self, *data):
297
294
  match data:
@@ -322,10 +319,10 @@ class Episode(Generic[A]):
322
319
 
323
320
  def add_data(
324
321
  self,
325
- next_obs,
326
- next_state,
327
- action: A,
328
- reward: np.ndarray,
322
+ next_obs: Observation,
323
+ next_state: State,
324
+ action: np.ndarray,
325
+ reward: npt.NDArray[np.float32],
329
326
  others: dict[str, Any],
330
327
  done: bool,
331
328
  truncated: bool,
marlenv/models/spaces.py CHANGED
@@ -1,13 +1,11 @@
1
1
  import math
2
2
  from abc import ABC, abstractmethod
3
3
  from dataclasses import dataclass
4
- from typing import Any, Generic, Optional, TypeVar
4
+ from typing import Optional
5
5
 
6
6
  import numpy as np
7
7
  import numpy.typing as npt
8
8
 
9
- S = TypeVar("S", bound="Space")
10
-
11
9
 
12
10
  @dataclass
13
11
  class Space(ABC):
@@ -23,7 +21,7 @@ class Space(ABC):
23
21
  self.labels = labels
24
22
 
25
23
  @abstractmethod
26
- def sample(self, mask: Optional[npt.NDArray[np.bool_]] = None) -> Any:
24
+ def sample(self, mask: Optional[npt.NDArray[np.bool_]] = None) -> npt.NDArray[np.float32]:
27
25
  """Sample a value from the space."""
28
26
 
29
27
  def __eq__(self, value: object) -> bool:
@@ -34,6 +32,16 @@ class Space(ABC):
34
32
  def __ne__(self, value: object) -> bool:
35
33
  return not self.__eq__(value)
36
34
 
35
+ @property
36
+ @abstractmethod
37
+ def is_discrete(self) -> bool:
38
+ """Whether the space is discrete."""
39
+
40
+ @property
41
+ def is_continuous(self) -> bool:
42
+ """Whether the space is continuous."""
43
+ return not self.is_discrete
44
+
37
45
 
38
46
  @dataclass
39
47
  class DiscreteSpace(Space):
@@ -45,8 +53,8 @@ class DiscreteSpace(Space):
45
53
  self.size = size
46
54
  self.space = np.arange(size)
47
55
 
48
- def sample(self, mask: Optional[npt.NDArray[np.bool_]] = None) -> int:
49
- space = self.space
56
+ def sample(self, mask: Optional[npt.NDArray[np.bool]] = None):
57
+ space = self.space.copy()
50
58
  if mask is not None:
51
59
  space = space[mask]
52
60
  return int(np.random.choice(space))
@@ -58,6 +66,25 @@ class DiscreteSpace(Space):
58
66
  return False
59
67
  return super().__eq__(value)
60
68
 
69
+ @property
70
+ def is_discrete(self) -> bool:
71
+ return True
72
+
73
+ @staticmethod
74
+ def action(size, labels: Optional[list[str]] = None):
75
+ """
76
+ Create a discrete action space where the default labels are set to "Action i" (i.e. "Action 0", "Action 1", ...).
77
+ """
78
+ if labels is None:
79
+ labels = [f"Action {i}" for i in range(size)]
80
+ return DiscreteSpace(size, labels)
81
+
82
+ def repeat(self, n: int):
83
+ """
84
+ Repeat the discrete space n times.
85
+ """
86
+ return MultiDiscreteSpace(*([self] * n), labels=self.labels)
87
+
61
88
 
62
89
  @dataclass
63
90
  class MultiDiscreteSpace(Space):
@@ -75,10 +102,10 @@ class MultiDiscreteSpace(Space):
75
102
  def from_sizes(cls, *sizes: int):
76
103
  return cls(*(DiscreteSpace(size) for size in sizes))
77
104
 
78
- def sample(self, masks: Optional[npt.NDArray[np.bool_] | list[npt.NDArray[np.bool_]]] = None):
79
- if masks is None:
105
+ def sample(self, mask: Optional[npt.NDArray[np.bool] | list[npt.NDArray[np.bool]]] = None):
106
+ if mask is None:
80
107
  return np.array([space.sample() for space in self.spaces], dtype=np.int32)
81
- return np.array([space.sample(mask) for mask, space in zip(masks, self.spaces)], dtype=np.int32)
108
+ return np.array([space.sample(mask=mask) for mask, space in zip(mask, self.spaces)], dtype=np.int32)
82
109
 
83
110
  def __eq__(self, value: object) -> bool:
84
111
  if not isinstance(value, MultiDiscreteSpace):
@@ -90,6 +117,10 @@ class MultiDiscreteSpace(Space):
90
117
  return False
91
118
  return super().__eq__(value)
92
119
 
120
+ @property
121
+ def is_discrete(self) -> bool:
122
+ return True
123
+
93
124
 
94
125
  @dataclass
95
126
  class ContinuousSpace(Space):
@@ -100,23 +131,35 @@ class ContinuousSpace(Space):
100
131
  high: npt.NDArray[np.float32]
101
132
  """Upper bound of the space for each dimension."""
102
133
 
103
- @staticmethod
104
- def from_bounds(
105
- low: int | float | list | npt.NDArray[np.float32],
106
- high: int | float | list | npt.NDArray[np.float32],
134
+ def __init__(
135
+ self,
136
+ low: int | float | list | npt.NDArray[np.float32] | None,
137
+ high: int | float | list | npt.NDArray[np.float32] | None,
107
138
  labels: Optional[list[str]] = None,
108
139
  ):
109
140
  match low:
141
+ case None:
142
+ assert high is not None, "If low is None, high must be set to infer the shape."
143
+ shape = ContinuousSpace.get_shape(high)
144
+ low = np.full(shape, -np.inf, dtype=np.float32)
110
145
  case list():
111
146
  low = np.array(low, dtype=np.float32)
112
147
  case float() | int():
113
148
  low = np.array([low], dtype=np.float32)
114
149
  match high:
150
+ case None:
151
+ assert low is not None, "If high is None, low must be set to infer the shape."
152
+ shape = ContinuousSpace.get_shape(low)
153
+ high = np.full(shape, np.inf, dtype=np.float32)
115
154
  case list():
116
155
  high = np.array(high, dtype=np.float32)
117
156
  case float() | int():
118
157
  high = np.array([high], dtype=np.float32)
119
- return ContinuousSpace(low, high, labels)
158
+ assert low.shape == high.shape, f"Low and high must have the same shape. Low shape: {low.shape}, high shape: {high.shape}"
159
+ assert np.all(low <= high), "All elements in low must be less than the corresponding elements in high."
160
+ Space.__init__(self, low.shape, labels)
161
+ self.low = low
162
+ self.high = high
120
163
 
121
164
  @staticmethod
122
165
  def from_shape(
@@ -143,20 +186,24 @@ class ContinuousSpace(Space):
143
186
  high = np.array(high, dtype=np.float32)
144
187
  return ContinuousSpace(low, high, labels)
145
188
 
146
- def __init__(
147
- self,
148
- low: npt.NDArray[np.float32],
149
- high: npt.NDArray[np.float32],
150
- labels: Optional[list[str]] = None,
151
- ):
152
- assert low.shape == high.shape, "Low and high must have the same shape."
153
- assert np.all(low <= high), "All elements in low must be less than the corresponding elements in high."
154
- Space.__init__(self, low.shape, labels)
155
- self.low = low
156
- self.high = high
189
+ def clamp(self, action: np.ndarray | list):
190
+ """Clamp the action to the bounds of the space."""
191
+ if isinstance(action, list):
192
+ action = np.array(action)
193
+ return np.clip(action, self.low, self.high)
194
+
195
+ def sample(self) -> npt.NDArray[np.float32]:
196
+ r = np.random.random(self.shape) * (self.high - self.low) + self.low
197
+ return r.astype(np.float32)
157
198
 
158
- def sample(self, *_):
159
- return np.random.random(self.shape) * (self.high - self.low) + self.low
199
+ @staticmethod
200
+ def get_shape(item: float | int | list | npt.NDArray[np.float32]) -> tuple[int, ...]:
201
+ """Get the shape of the item."""
202
+ if isinstance(item, list):
203
+ item = np.array(item)
204
+ if isinstance(item, np.ndarray):
205
+ return item.shape
206
+ return (1,)
160
207
 
161
208
  def __eq__(self, value: object) -> bool:
162
209
  if not isinstance(value, ContinuousSpace):
@@ -167,59 +214,19 @@ class ContinuousSpace(Space):
167
214
  return False
168
215
  return super().__eq__(value)
169
216
 
170
-
171
- @dataclass
172
- class ActionSpace(Space, Generic[S]):
173
- n_agents: int
174
- """Number of agents."""
175
- action_names: list[str]
176
- """The meaning of each action."""
177
- n_actions: int
178
- individual_action_space: S
179
-
180
- def __init__(self, n_agents: int, individual_action_space: S, action_names: Optional[list] = None):
181
- Space.__init__(self, (n_agents, *individual_action_space.shape), action_names)
182
- self.n_agents = n_agents
183
- self.individual_action_space = individual_action_space
184
- self.n_actions = math.prod(individual_action_space.shape)
185
- self.action_names = action_names or [f"Action {i}" for i in range(self.n_actions)]
186
-
187
- def sample(self, mask: np.ndarray | None = None):
188
- res = []
189
- for i in range(self.n_agents):
190
- if mask is not None:
191
- m = mask[i]
192
- else:
193
- m = None
194
- res.append(self.individual_action_space.sample(m))
195
- return np.array(res)
196
-
197
- def __eq__(self, value: object) -> bool:
198
- if not isinstance(value, ActionSpace):
199
- return False
200
- if self.n_agents != value.n_agents:
201
- return False
202
- if self.n_actions != value.n_actions:
203
- return False
204
- if self.individual_action_space != value.individual_action_space:
205
- return False
206
- return super().__eq__(value)
207
-
208
-
209
- @dataclass
210
- class DiscreteActionSpace(ActionSpace[DiscreteSpace]):
211
- def __init__(self, n_agents: int, n_actions: int, action_names: Optional[list[str]] = None):
212
- individual_action_space = DiscreteSpace(n_actions, action_names)
213
- super().__init__(n_agents, individual_action_space, action_names)
214
-
215
-
216
- @dataclass
217
- class MultiDiscreteActionSpace(ActionSpace[MultiDiscreteSpace]):
218
- pass
219
-
220
-
221
- @dataclass
222
- class ContinuousActionSpace(ActionSpace[ContinuousSpace]):
223
- def __init__(self, n_agents: int, low: np.ndarray | list, high: np.ndarray | list, action_names: list | None = None):
224
- space = ContinuousSpace.from_bounds(low, high, action_names)
225
- super().__init__(n_agents, space, action_names)
217
+ def repeat(self, n: int):
218
+ """
219
+ Repeat the continuous space n times, producing a space of shape (n, *shape).
220
+ """
221
+ low = np.tile(self.low, (n, 1))
222
+ high = np.tile(self.high, (n, 1))
223
+ return ContinuousSpace.from_shape(
224
+ (n, *self.shape),
225
+ low=low,
226
+ high=high,
227
+ labels=self.labels,
228
+ )
229
+
230
+ @property
231
+ def is_discrete(self) -> bool:
232
+ return False
@@ -1,6 +1,5 @@
1
1
  from dataclasses import dataclass
2
- from typing import Any, Generic, Sequence
3
- from typing_extensions import TypeVar
2
+ from typing import Any, Sequence
4
3
 
5
4
  import numpy as np
6
5
  import numpy.typing as npt
@@ -10,16 +9,13 @@ from .state import State
10
9
  from .step import Step
11
10
 
12
11
 
13
- A = TypeVar("A", default=np.ndarray)
14
-
15
-
16
12
  @dataclass
17
- class Transition(Generic[A]):
13
+ class Transition:
18
14
  """Transition model"""
19
15
 
20
16
  obs: Observation
21
17
  state: State
22
- action: A
18
+ action: np.ndarray
23
19
  reward: npt.NDArray[np.float32]
24
20
  done: bool
25
21
  info: dict[str, Any]
@@ -32,7 +28,7 @@ class Transition(Generic[A]):
32
28
  self,
33
29
  obs: Observation,
34
30
  state: State,
35
- action: A,
31
+ action: np.ndarray | Sequence[float],
36
32
  reward: npt.NDArray[np.float32] | float | Sequence[float],
37
33
  done: bool,
38
34
  info: dict[str, Any],
@@ -65,14 +61,14 @@ class Transition(Generic[A]):
65
61
  def from_step(
66
62
  prev_obs: Observation,
67
63
  prev_state: State,
68
- actions: A,
64
+ action: np.ndarray | Sequence[float],
69
65
  step: Step,
70
66
  **kwargs,
71
67
  ):
72
68
  return Transition(
73
69
  obs=prev_obs,
74
70
  state=prev_state,
75
- action=actions,
71
+ action=action,
76
72
  reward=step.reward,
77
73
  done=step.done,
78
74
  info=step.info,
@@ -0,0 +1,10 @@
1
+ from .schedule import Schedule, MultiSchedule, RoundedSchedule, LinearSchedule, ExpSchedule
2
+
3
+
4
+ __all__ = [
5
+ "Schedule",
6
+ "LinearSchedule",
7
+ "ExpSchedule",
8
+ "MultiSchedule",
9
+ "RoundedSchedule",
10
+ ]