multi-agent-rlenv 3.3.7__py3-none-any.whl → 3.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
marlenv/__init__.py CHANGED
@@ -62,16 +62,11 @@ print(env.extras_shape) # (1, )
 If you want to create a new environment, you can simply create a class that inherits from `MARLEnv`. If you want to create a wrapper around an existing `MARLEnv`, you probably want to subclass `RLEnvWrapper` which implements a default behaviour for every method.
 """
 
-__version__ = "3.3.7"
+__version__ = "3.5.0"
 
 from . import models
-from . import wrappers
-from . import adapters
-from .models import spaces
-
-
-from .env_builder import make, Builder
 from .models import (
+    spaces,
     MARLEnv,
     State,
     Step,
@@ -80,10 +75,14 @@ from .models import (
     Transition,
     DiscreteSpace,
     ContinuousSpace,
-    ActionSpace,
-    DiscreteActionSpace,
-    ContinuousActionSpace,
+    Space,
+    MultiDiscreteSpace,
 )
+
+
+from . import wrappers
+from . import adapters
+from .env_builder import make, Builder
 from .wrappers import RLEnvWrapper
 from .mock_env import DiscreteMockEnv, DiscreteMOMockEnv
 
@@ -100,12 +99,11 @@ __all__ = [
     "Observation",
     "Episode",
     "Transition",
-    "ActionSpace",
     "DiscreteSpace",
    "ContinuousSpace",
-    "DiscreteActionSpace",
-    "ContinuousActionSpace",
     "DiscreteMockEnv",
     "DiscreteMOMockEnv",
     "RLEnvWrapper",
+    "Space",
+    "MultiDiscreteSpace",
 ]
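The public API drops `ActionSpace`, `DiscreteActionSpace` and `ContinuousActionSpace` in favour of `Space`, `DiscreteSpace`, `ContinuousSpace` and `MultiDiscreteSpace`. A minimal migration sketch against the new exports listed above (the agent and action counts are placeholders, not values from the package):

import numpy as np

from marlenv import ContinuousSpace, DiscreteSpace, MultiDiscreteSpace

# 3.3.7: DiscreteActionSpace(n_agents=2, n_actions=5)
# 3.5.0: build a single-agent space and repeat it over the agents.
action_space: MultiDiscreteSpace = DiscreteSpace(5).repeat(2)

# 3.3.7: ContinuousActionSpace(n_agents, low, high)
# 3.5.0: ContinuousSpace(low, high).repeat(n_agents), as the adapters below do.
continuous = ContinuousSpace(np.zeros(3, dtype=np.float32), np.ones(3, dtype=np.float32)).repeat(2)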
@@ -1,26 +1,16 @@
 import sys
-import cv2
 from dataclasses import dataclass
-from typing import Sequence
 
+import cv2
 import gymnasium as gym
 import numpy as np
-import numpy.typing as npt
 from gymnasium import Env, spaces
 
-from marlenv.models import (
-    ActionSpace,
-    ContinuousActionSpace,
-    DiscreteActionSpace,
-    MARLEnv,
-    Observation,
-    State,
-    Step,
-)
+from marlenv import ContinuousSpace, DiscreteSpace, MARLEnv, Observation, Space, State, Step
 
 
 @dataclass
-class Gym(MARLEnv[Sequence | npt.NDArray, ActionSpace]):
+class Gym(MARLEnv[Space]):
     """Wraps a gym envronment in an RLEnv"""
 
     def __init__(self, env: Env | str, **kwargs):
@@ -30,7 +20,7 @@ class Gym(MARLEnv[Sequence | npt.NDArray, ActionSpace]):
             raise NotImplementedError("Observation space must have a shape")
         match env.action_space:
             case spaces.Discrete() as s:
-                space = DiscreteActionSpace(1, int(s.n))
+                space = DiscreteSpace(int(s.n), labels=[f"Action {i}" for i in range(s.n)]).repeat(1)
             case spaces.Box() as s:
                 low = s.low.astype(np.float32)
                 high = s.high.astype(np.float32)
@@ -38,10 +28,10 @@ class Gym(MARLEnv[Sequence | npt.NDArray, ActionSpace]):
                 low = np.full(s.shape, s.low, dtype=np.float32)
             if not isinstance(high, np.ndarray):
                 high = np.full(s.shape, s.high, dtype=np.float32)
-                space = ContinuousActionSpace(1, low, high)
+                space = ContinuousSpace(low, high, labels=[f"Action {i}" for i in range(s.shape[0])]).repeat(1)
             case other:
                 raise NotImplementedError(f"Action space {other} not supported")
-        super().__init__(space, env.observation_space.shape, (1,))
+        super().__init__(1, space, env.observation_space.shape, (1,))
         self._gym_env = env
         if self._gym_env.unwrapped.spec is not None:
             self.name = self._gym_env.unwrapped.spec.id
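The Gym adapter is now generic over a single `Space` parameter, builds its action space from `DiscreteSpace`/`ContinuousSpace` with `.repeat(1)`, and passes the agent count (1) explicitly to `MARLEnv.__init__`. A usage sketch, assuming the adapter is importable from `marlenv.adapters`; the environment id is only an example:

from marlenv.adapters import Gym  # assumed import path for the adapter shown above

env = Gym("CartPole-v1")              # a gymnasium id or an already-built Env is accepted
obs, state = env.reset()
step = env.step(env.sample_action())  # sample_action() draws from the repeated DiscreteSpace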
@@ -1,14 +1,14 @@
 import sys
 from dataclasses import dataclass
-from typing import Literal, Sequence
+from typing import Literal, Sequence, Optional
 from copy import deepcopy
-from time import time
 
 import cv2
 import numpy as np
 import numpy.typing as npt
 import pygame
-from marlenv.models import ContinuousSpace, DiscreteActionSpace, MARLEnv, Observation, State, Step
+from marlenv.models import ContinuousSpace, DiscreteSpace, MARLEnv, Observation, State, Step, MultiDiscreteSpace
+from marlenv.utils import Schedule
 
 from overcooked_ai_py.mdp.overcooked_env import OvercookedEnv
 from overcooked_ai_py.mdp.overcooked_mdp import Action, OvercookedGridworld, OvercookedState
@@ -16,12 +16,19 @@ from overcooked_ai_py.visualization.state_visualizer import StateVisualizer
 
 
 @dataclass
-class Overcooked(MARLEnv[Sequence[int] | npt.NDArray, DiscreteActionSpace]):
+class Overcooked(MARLEnv[MultiDiscreteSpace]):
     horizon: int
-    reward_shaping: bool
+    shaping_factor: Schedule
 
-    def __init__(self, oenv: OvercookedEnv, reward_shaping: bool = True):
-        self.reward_shaping = reward_shaping
+    def __init__(
+        self,
+        oenv: OvercookedEnv,
+        shaping_factor: float | Schedule = 1.0,
+        name_suffix: Optional[str] = None,
+    ):
+        if isinstance(shaping_factor, (int, float)):
+            shaping_factor = Schedule.constant(shaping_factor)
+        self.shaping_factor = shaping_factor
         self._oenv = oenv
         assert isinstance(oenv.mdp, OvercookedGridworld)
         self._mdp = oenv.mdp
@@ -30,10 +37,9 @@ class Overcooked(MARLEnv[Sequence[int] | npt.NDArray, DiscreteActionSpace]):
         # -1 because we extract the "urgent" layer to the extras
         shape = (int(layers - 1), int(width), int(height))
         super().__init__(
-            action_space=DiscreteActionSpace(
-                n_agents=self._mdp.num_players,
-                n_actions=Action.NUM_ACTIONS,
-                action_names=[Action.ACTION_TO_CHAR[a] for a in Action.ALL_ACTIONS],
+            n_agents=self._mdp.num_players,
+            action_space=DiscreteSpace(Action.NUM_ACTIONS, labels=[Action.ACTION_TO_CHAR[a] for a in Action.ALL_ACTIONS]).repeat(
+                self._mdp.num_players
             ),
             observation_shape=shape,
             extras_shape=(2,),
@@ -43,6 +49,8 @@ class Overcooked(MARLEnv[Sequence[int] | npt.NDArray, DiscreteActionSpace]):
             reward_space=ContinuousSpace.from_shape(1),
         )
         self.horizon = int(self._oenv.horizon)
+        if name_suffix is not None:
+            self.name = f"{self.name}-{name_suffix}"
 
     @property
     def state(self) -> OvercookedState:
@@ -86,11 +94,12 @@ class Overcooked(MARLEnv[Sequence[int] | npt.NDArray, DiscreteActionSpace]):
             available_actions[agent_num, Action.ACTION_TO_INDEX[action]] = True
         return np.array(available_actions, dtype=np.bool)
 
-    def step(self, actions: Sequence[int] | npt.NDArray[np.int32 | np.int64]) -> Step:
+    def step(self, actions: Sequence[int] | np.ndarray) -> Step:
+        self.shaping_factor.update()
         actions = [Action.ALL_ACTIONS[a] for a in actions]
         _, reward, done, info = self._oenv.step(actions, display_phi=True)
-        if self.reward_shaping:
-            reward += sum(info["shaped_r_by_agent"])
+
+        reward += sum(info["shaped_r_by_agent"]) * self.shaping_factor
         return Step(
             obs=self.get_observation(),
             state=self.get_state(),
@@ -104,19 +113,25 @@ class Overcooked(MARLEnv[Sequence[int] | npt.NDArray, DiscreteActionSpace]):
         self._oenv.reset()
         return self.get_observation(), self.get_state()
 
-    def __deepcopy__(self, memo: dict):
+    def __deepcopy__(self, _):
+        """
+        Note: a specific implementation is needed because `pygame.font.Font` objects are not deep-copiable by default.
+        """
         mdp = deepcopy(self._mdp)
-        return Overcooked(OvercookedEnv.from_mdp(mdp, horizon=self.horizon))
+        copy = Overcooked(OvercookedEnv.from_mdp(mdp, horizon=self.horizon), deepcopy(self.shaping_factor))
+        copy.name = self.name
+        return copy
 
     def __getstate__(self):
-        return {"horizon": self.horizon, "mdp": self._mdp}
+        return {"horizon": self.horizon, "mdp": self._mdp, "name": self.name, "schedule": self.shaping_factor}
 
     def __setstate__(self, state: dict):
         from overcooked_ai_py.mdp.overcooked_mdp import Recipe
 
         mdp = state["mdp"]
         Recipe.configure(mdp.recipe_config)
-        self.__init__(OvercookedEnv.from_mdp(state["mdp"], horizon=state["horizon"]))
+        self.__init__(OvercookedEnv.from_mdp(state["mdp"], horizon=state["horizon"]), shaping_factor=state["schedule"])
+        self.name = state["name"]
 
     def get_image(self):
         rewards_dict = {}  # dictionary of details you want rendered in the UI
@@ -190,16 +205,17 @@ class Overcooked(MARLEnv[Sequence[int] | npt.NDArray, DiscreteActionSpace]):
             "you_shall_not_pass",
         ],
         horizon: int = 400,
-        reward_shaping: bool = True,
+        reward_shaping_factor: float | Schedule = 1.0,
     ):
         mdp = OvercookedGridworld.from_layout_name(layout)
-        return Overcooked(OvercookedEnv.from_mdp(mdp, horizon=horizon), reward_shaping=reward_shaping)
+        return Overcooked(OvercookedEnv.from_mdp(mdp, horizon=horizon, info_level=0), reward_shaping_factor, layout)
 
     @staticmethod
     def from_grid(
         grid: Sequence[Sequence[Literal["S", "P", "X", "O", "D", "T", "1", "2", " "] | str]],
         horizon: int = 400,
-        reward_shaping: bool = True,
+        shaping_factor: float | Schedule = 1.0,
+        layout_name: Optional[str] = None,
     ):
         """
         Create an Overcooked environment from a grid layout where
@@ -212,10 +228,14 @@ class Overcooked(MARLEnv[Sequence[int] | npt.NDArray, DiscreteActionSpace]):
         - 1 is a player 1 starting location
         - 2 is a player 2 starting location
         - ' ' is a walkable space
+
+        If provided, `custom_name` is added to the environment name.
         """
         # It is necessary to add an explicit layout name because Overcooked saves some files under this
         # name. By default the name is a concatenation of the grid elements, which may include characters
         # such as white spaces, pipes ('|') and square brackets ('[' and ']') that are invalid Windows file paths.
-        layout_name = str(time())
+        if layout_name is None:
+            layout_name = "custom-layout"
         mdp = OvercookedGridworld.from_grid(grid, base_layout_params={"layout_name": layout_name})
-        return Overcooked(OvercookedEnv.from_mdp(mdp, horizon=horizon), reward_shaping=reward_shaping)
+
+        return Overcooked(OvercookedEnv.from_mdp(mdp, horizon=horizon, info_level=0), shaping_factor, layout_name)
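The boolean `reward_shaping` flag becomes a `shaping_factor` `Schedule` that is advanced on every `step()` and multiplies the shaped reward, and the factory methods gain an explicit layout name. A sketch using `from_grid` as documented above; the adapter import path, the toy grid and the layout name are illustrative and not taken from the package:

from marlenv.adapters import Overcooked  # assumed import path
from marlenv.utils import Schedule

grid = [
    "XXPXX",
    "O1 2S",
    "XXDXX",
]  # toy kitchen following the legend above; not guaranteed to be a solvable layout
env = Overcooked.from_grid(
    grid,
    horizon=400,
    shaping_factor=Schedule.constant(0.5),  # a plain float is wrapped the same way internally
    layout_name="tiny-kitchen",             # hypothetical name; replaces the old time()-based names
)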
@@ -6,17 +6,17 @@ import numpy.typing as npt
 from gymnasium import spaces  # pettingzoo uses gymnasium spaces
 from pettingzoo import ParallelEnv
 
-from marlenv.models import ActionSpace, ContinuousActionSpace, DiscreteActionSpace, MARLEnv, Observation, State, Step
+from marlenv.models import MARLEnv, Observation, State, Step, DiscreteSpace, ContinuousSpace, Space
 
 
 @dataclass
-class PettingZoo(MARLEnv[npt.NDArray, ActionSpace]):
+class PettingZoo(MARLEnv[Space]):
     def __init__(self, env: ParallelEnv):
         aspace = env.action_space(env.possible_agents[0])
         n_agents = len(env.possible_agents)
         match aspace:
             case spaces.Discrete() as s:
-                space = DiscreteActionSpace(n_agents, int(s.n))
+                space = DiscreteSpace.action(int(s.n)).repeat(n_agents)
 
             case spaces.Box() as s:
                 low = s.low.astype(np.float32)
@@ -25,7 +25,7 @@ class PettingZoo(MARLEnv[npt.NDArray, ActionSpace]):
                 low = np.full(s.shape, s.low, dtype=np.float32)
             if not isinstance(high, np.ndarray):
                 high = np.full(s.shape, s.high, dtype=np.float32)
-                space = ContinuousActionSpace(n_agents, low, high=high)
+                space = ContinuousSpace(low, high=high).repeat(n_agents)
             case other:
                 raise NotImplementedError(f"Action space {other} not supported")
 
@@ -34,7 +34,7 @@ class PettingZoo(MARLEnv[npt.NDArray, ActionSpace]):
             raise NotImplementedError("Only discrete observation spaces are supported")
         self._pz_env = env
         env.reset()
-        super().__init__(space, obs_space.shape, self.get_state().shape)
+        super().__init__(n_agents, space, obs_space.shape, self.get_state().shape)
         self.agents = env.possible_agents
         self.last_observation = None
 
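The PettingZoo adapter follows the same pattern: a single `Space` type parameter, `DiscreteSpace.action(...)` or `ContinuousSpace(...)` repeated over the agents, and `n_agents` forwarded explicitly to `super().__init__`. A minimal wrapping sketch; the adapter import path is an assumption:

from pettingzoo import ParallelEnv

from marlenv.adapters import PettingZoo  # assumed import path

def wrap(parallel_env: ParallelEnv) -> PettingZoo:
    # Discrete and Box action spaces are translated as in the match statement above.
    return PettingZoo(parallel_env)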
@@ -1,10 +1,9 @@
 from dataclasses import dataclass
-from typing import Any, Sequence
+from typing import Any
 
 import numpy as np
-import numpy.typing as npt
 
-from marlenv.models import DiscreteActionSpace, MARLEnv
+from marlenv.models import MARLEnv, MultiDiscreteSpace
 from marlenv.wrappers import TimeLimit
 
 
@@ -15,7 +14,7 @@ class PymarlAdapter:
     with the pymarl-qplex code base.
     """
 
-    def __init__(self, env: MARLEnv[Sequence | npt.NDArray, DiscreteActionSpace], episode_limit: int):
+    def __init__(self, env: MARLEnv[MultiDiscreteSpace], episode_limit: int):
         assert env.reward_space.size == 1, "Only single objective environments are supported."
         self.env = TimeLimit(env, episode_limit, add_extra=False)
         # Required by PyMarl
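`PymarlAdapter` now expects an environment typed over `MultiDiscreteSpace` instead of the removed `DiscreteActionSpace`. A sketch with the mock environment shipped in this package; the adapter import path is an assumption:

from marlenv import DiscreteMockEnv
from marlenv.adapters import PymarlAdapter  # assumed import path

# Single-objective MultiDiscreteSpace environment wrapped with a time limit of 50 steps.
adapter = PymarlAdapter(DiscreteMockEnv(n_agents=2), episode_limit=50)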
@@ -1,15 +1,15 @@
 from dataclasses import dataclass
-from typing import Sequence, overload
+from typing import overload
 
 import numpy as np
 import numpy.typing as npt
 from smac.env import StarCraft2Env
 
-from marlenv.models import DiscreteActionSpace, MARLEnv, Observation, State, Step
+from marlenv.models import MARLEnv, Observation, State, Step, MultiDiscreteSpace, DiscreteSpace
 
 
 @dataclass
-class SMAC(MARLEnv[Sequence[int] | npt.NDArray, DiscreteActionSpace]):
+class SMAC(MARLEnv[MultiDiscreteSpace]):
     """Wrapper for the SMAC environment to work with this framework"""
 
     @overload
@@ -157,10 +157,10 @@ class SMAC(MARLEnv[Sequence[int] | npt.NDArray, DiscreteActionSpace]):
             case other:
                 raise ValueError(f"Invalid argument type: {type(other)}")
         self._env = StarCraft2Env(map_name=map_name)
-        action_space = DiscreteActionSpace(self._env.n_agents, self._env.n_actions)
         self._env_info = self._env.get_env_info()
         super().__init__(
-            action_space=action_space,
+            self._env.n_agents,
+            action_space=DiscreteSpace(self._env.n_actions).repeat(self._env.n_agents),
             observation_shape=(self._env_info["obs_shape"],),
             state_shape=(self._env_info["state_shape"],),
         )
@@ -195,7 +195,7 @@ class SMAC(MARLEnv[Sequence[int] | npt.NDArray, DiscreteActionSpace]):
         )
         return step
 
-    def available_actions(self) -> npt.NDArray[np.bool_]:
+    def available_actions(self) -> npt.NDArray[np.bool]:
         return np.array(self._env.get_avail_actions()) == 1
 
     def get_image(self):
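SMAC likewise keeps only the `MultiDiscreteSpace` type parameter and builds its action space inline with `DiscreteSpace(...).repeat(...)`. A sketch, assuming the adapter is importable from `marlenv.adapters` and that a map name string is an accepted constructor argument (inferred from the `StarCraft2Env(map_name=map_name)` call above); SMAC and StarCraft II must be installed:

from marlenv.adapters import SMAC  # assumed import path

env = SMAC("3m")                # map name, per the overloaded constructor
mask = env.available_actions()  # boolean availability mask per agent and action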
marlenv/env_builder.py CHANGED
@@ -5,10 +5,9 @@ import numpy.typing as npt
 
 from . import wrappers
 from marlenv import adapters
-from .models import ActionSpace, MARLEnv
+from .models import Space, MARLEnv
 
-A = TypeVar("A")
-AS = TypeVar("AS", bound=ActionSpace)
+AS = TypeVar("AS", bound=Space)
 
 if adapters.HAS_PETTINGZOO:
     from .adapters import PettingZoo
@@ -71,12 +70,12 @@ def make(env, **kwargs):
 
 
 @dataclass
-class Builder(Generic[A, AS]):
+class Builder(Generic[AS]):
     """Builder for environments"""
 
-    _env: MARLEnv[A, AS]
+    _env: MARLEnv[AS]
 
-    def __init__(self, env: MARLEnv[A, AS]):
+    def __init__(self, env: MARLEnv[AS]):
         self._env = env
 
     def time_limit(self, n_steps: int, add_extra: bool = True, truncation_penalty: Optional[float] = None):
@@ -124,9 +123,9 @@ class Builder(Generic[A, AS]):
 
     def centralised(self):
         """Centralises the observations and actions"""
-        from marlenv.models import DiscreteActionSpace
+        from marlenv.models import MultiDiscreteSpace
 
-        assert isinstance(self._env.action_space, DiscreteActionSpace)
+        assert isinstance(self._env.action_space, MultiDiscreteSpace)
         self._env = wrappers.Centralized(self._env)  # type: ignore
         return self
 
@@ -159,6 +158,6 @@ class Builder(Generic[A, AS]):
         self._env = wrappers.TimePenalty(self._env, penalty)
         return self
 
-    def build(self) -> MARLEnv[A, AS]:
+    def build(self) -> MARLEnv[AS]:
         """Build and return the environment"""
         return self._env
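`Builder` loses the separate action-type parameter and is now `Generic[AS]` with `AS` bound to `Space`. A chained usage sketch with methods that appear in this diff, assuming each builder method returns the builder as `centralised()` does above:

from marlenv import Builder, DiscreteMockEnv

env = (
    Builder(DiscreteMockEnv(n_agents=2))
    .time_limit(50)
    .centralised()  # asserts a MultiDiscreteSpace action space, per the code above
    .build()
)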
marlenv/env_pool.py CHANGED
@@ -1,21 +1,19 @@
 from typing import Sequence
 from dataclasses import dataclass
-import numpy.typing as npt
 from typing_extensions import TypeVar
 import random
 
 from marlenv import RLEnvWrapper, MARLEnv
-from marlenv.models import ActionSpace
+from marlenv.models import Space
 
-ActionType = TypeVar("ActionType", default=npt.NDArray)
-ActionSpaceType = TypeVar("ActionSpaceType", bound=ActionSpace, default=ActionSpace)
+ActionSpaceType = TypeVar("ActionSpaceType", bound=Space, default=Space)
 
 
 @dataclass
-class EnvPool(RLEnvWrapper[ActionType, ActionSpaceType]):
-    envs: Sequence[MARLEnv[ActionType, ActionSpaceType]]
+class EnvPool(RLEnvWrapper[ActionSpaceType]):
+    envs: Sequence[MARLEnv[ActionSpaceType]]
 
-    def __init__(self, envs: Sequence[MARLEnv[ActionType, ActionSpaceType]]):
+    def __init__(self, envs: Sequence[MARLEnv[ActionSpaceType]]):
         assert len(envs) > 0, "EnvPool must contain at least one environment"
         self.envs = envs
         for env in envs[1:]:
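`EnvPool` keeps only the space-typed parameter. A minimal construction sketch; the module path is an assumption, and only the constructor contract shown above (at least one environment) is exercised:

from marlenv import DiscreteMockEnv
from marlenv.env_pool import EnvPool  # assumed module path

# Two interchangeable environments in one pool.
pool = EnvPool([DiscreteMockEnv(n_agents=2), DiscreteMockEnv(n_agents=2)])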
marlenv/mock_env.py CHANGED
@@ -1,12 +1,10 @@
-from typing import Sequence
 import numpy as np
-import numpy.typing as npt
 from dataclasses import dataclass
-from marlenv import MARLEnv, Observation, DiscreteActionSpace, ContinuousSpace, Step, State
+from marlenv import MARLEnv, Observation, ContinuousSpace, Step, State, DiscreteSpace, MultiDiscreteSpace
 
 
 @dataclass
-class DiscreteMockEnv(MARLEnv[Sequence[int] | npt.NDArray, DiscreteActionSpace]):
+class DiscreteMockEnv(MARLEnv[MultiDiscreteSpace]):
     def __init__(
         self,
         n_agents: int = 4,
@@ -27,7 +25,8 @@ class DiscreteMockEnv(MARLEnv[Sequence[int] | npt.NDArray, DiscreteActionSpace])
             case _:
                 raise ValueError("reward_step must be an int, float or np.ndarray")
         super().__init__(
-            DiscreteActionSpace(n_agents, n_actions),
+            n_agents,
+            DiscreteSpace(n_actions).repeat(n_agents),
             (obs_size,),
             (n_agents * agent_state_size,),
             extras_shape=(extras_size,),
@@ -85,7 +84,7 @@ class DiscreteMockEnv(MARLEnv[Sequence[int] | npt.NDArray, DiscreteActionSpace])
         )
 
 
-class DiscreteMOMockEnv(MARLEnv[Sequence[int] | npt.NDArray, DiscreteActionSpace]):
+class DiscreteMOMockEnv(MARLEnv[DiscreteSpace]):
     """Multi-Objective Mock Environment"""
 
     def __init__(
@@ -100,7 +99,8 @@ class DiscreteMOMockEnv(MARLEnv[Sequence[int] | npt.NDArray, DiscreteActionSpace
         extras_size: int = 0,
     ) -> None:
         super().__init__(
-            DiscreteActionSpace(n_agents, n_actions),
+            n_agents,
+            DiscreteSpace(n_actions),
             (obs_size,),
             (n_agents * agent_state_size,),
             extras_shape=(extras_size,),
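Both mock environments adopt the new argument order, `n_agents` first and then the space. A quick sketch of constructing and stepping the discrete mock, using only parameters visible in this diff:

from marlenv import DiscreteMockEnv

env = DiscreteMockEnv(n_agents=4)     # remaining constructor defaults as shown above
obs, state = env.reset()
step = env.step(env.sample_action())  # sample_action() masks with available_actions()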
@@ -1,4 +1,4 @@
-from .spaces import ActionSpace, DiscreteSpace, ContinuousSpace, MultiDiscreteSpace, DiscreteActionSpace, ContinuousActionSpace
+from .spaces import DiscreteSpace, ContinuousSpace, MultiDiscreteSpace, Space
 from .observation import Observation
 from .step import Step
 from .state import State
@@ -8,7 +8,6 @@ from .episode import Episode
 
 
 __all__ = [
-    "ActionSpace",
     "Step",
     "State",
     "DiscreteSpace",
@@ -18,6 +17,5 @@ __all__ = [
     "Transition",
     "Episode",
     "MultiDiscreteSpace",
-    "DiscreteActionSpace",
-    "ContinuousActionSpace",
+    "Space",
 ]
marlenv/models/env.py CHANGED
@@ -1,24 +1,22 @@
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
 from itertools import product
-from typing import Generic, Optional, Sequence
+from typing import Generic, Optional, Sequence, TypeVar
 
 import cv2
 import numpy as np
 import numpy.typing as npt
-from typing_extensions import TypeVar
 
 from .observation import Observation
-from .spaces import ActionSpace, ContinuousSpace, Space
+from .spaces import ContinuousSpace, Space, DiscreteSpace, MultiDiscreteSpace
 from .state import State
 from .step import Step
 
-ActionType = TypeVar("ActionType", default=npt.NDArray)
-ActionSpaceType = TypeVar("ActionSpaceType", bound=ActionSpace, default=ActionSpace)
+ActionSpaceType = TypeVar("ActionSpaceType", bound=Space)
 
 
 @dataclass
-class MARLEnv(ABC, Generic[ActionType, ActionSpaceType]):
+class MARLEnv(ABC, Generic[ActionSpaceType]):
     """
     Multi-Agent Reinforcement Learning environment.
 
@@ -70,6 +68,7 @@ class MARLEnv(ABC, Generic[ActionType, ActionSpaceType]):
 
     def __init__(
         self,
+        n_agents: int,
        action_space: ActionSpaceType,
        observation_shape: tuple[int, ...],
        state_shape: tuple[int, ...],
@@ -81,8 +80,8 @@ class MARLEnv(ABC, Generic[ActionType, ActionSpaceType]):
         super().__init__()
         self.name = self.__class__.__name__
         self.action_space = action_space
-        self.n_actions = action_space.n_actions
-        self.n_agents = action_space.n_agents
+        self.n_actions = action_space.shape[-1]
+        self.n_agents = n_agents
         self.observation_shape = observation_shape
         self.state_shape = state_shape
         self.extras_shape = extras_shape
@@ -108,9 +107,21 @@ class MARLEnv(ABC, Generic[ActionType, ActionSpaceType]):
         """Whether the environment is multi-objective."""
         return self.reward_space.size > 1
 
-    def sample_action(self) -> ActionType:
+    @property
+    def n_objectives(self) -> int:
+        """The number of objectives in the environment."""
+        return self.reward_space.size
+
+    def sample_action(self):
         """Sample an available action from the action space."""
-        return self.action_space.sample(self.available_actions())  # type: ignore
+        match self.action_space:
+            case MultiDiscreteSpace() as aspace:
+                return aspace.sample(mask=self.available_actions())
+            case ContinuousSpace() as aspace:
+                return aspace.sample()
+            case DiscreteSpace() as aspace:
+                return np.array([aspace.sample(mask=self.available_actions())])
+        raise NotImplementedError("Action space not supported")
 
     def available_actions(self) -> npt.NDArray[np.bool]:
         """
@@ -142,7 +153,7 @@ class MARLEnv(ABC, Generic[ActionType, ActionSpaceType]):
         raise NotImplementedError("Method not implemented")
 
     @abstractmethod
-    def step(self, actions: ActionType) -> Step:
+    def step(self, action: Sequence | np.ndarray) -> Step:
         """Perform a step in the environment.
 
         Returns a Step object that can be unpacked as a 6-tuple containing:
@@ -175,7 +186,7 @@ class MARLEnv(ABC, Generic[ActionType, ActionSpaceType]):
         """Retrieve an image of the environment"""
         raise NotImplementedError("No image available for this environment")
 
-    def replay(self, actions: Sequence[ActionType], seed: Optional[int] = None):
+    def replay(self, actions: Sequence, seed: Optional[int] = None):
         """Replay a sequence of actions."""
         from .episode import Episode  # Avoid circular import
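Putting the `MARLEnv` changes together: one generic parameter bound to `Space`, an explicit `n_agents` argument ahead of the action space, `n_actions` derived from `action_space.shape[-1]`, and `step()`/`replay()` typed against plain sequences or arrays. A skeletal subclass illustrating the new constructor; the shapes are placeholders and the abstract methods are left unimplemented:

from marlenv import MARLEnv, DiscreteSpace, MultiDiscreteSpace, Step

class MyEnv(MARLEnv[MultiDiscreteSpace]):
    def __init__(self, n_agents: int = 2, n_actions: int = 5):
        super().__init__(
            n_agents,
            DiscreteSpace(n_actions).repeat(n_agents),  # one DiscreteSpace per agent
            (8,),   # observation_shape
            (16,),  # state_shape
        )

    def step(self, action) -> Step:
        # Abstract in MARLEnv; a real environment builds a Step from its own dynamics.
        raise NotImplementedError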