multi-agent-rlenv 3.2.2__py3-none-any.whl → 3.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
marlenv/__init__.py CHANGED
@@ -1,19 +1,68 @@
1
1
  """
2
2
  `marlenv` is a strongly typed library for multi-agent and multi-objective reinforcement learning.
3
3
 
4
- It aims to
5
- - provide a simple and consistent interface for reinforcement learning environments
6
- - provide fundamental models such as `Observation`s, `Episode`s, `Transition`s, ...
7
- - work with gymnasium, pettingzoo and SMAC out of the box
8
- - work with multi-objective environments
9
- - provide helpful wrappers to add intrinsic rewards, agent ids, record videos, ...
4
+ It aims to provide a simple and consistent interface for reinforcement learning environments through abstraction models such as `Observation`s and `Episode`s. `marlenv` also ships adapters for popular libraries such as `gym` and `pettingzoo`, together with utility wrappers that add functionality such as video recording or limiting the number of steps.
10
5
 
6
+ Almost every class is a dataclass to enable seamless serialization with the `orjson` library.
11
7
 
12
- A design choice is taht almost every class is a dataclass. This makes it easy to
13
- serialize and deserialize classes, for instance to json with the `orjson` library.
8
+ # Existing environments
9
+ The `MARLEnv` class represents a multi-agent RL environment and is at the center of this library. `marlenv.adapters` provides adapted implementations of several common MARL environments (gym, pettingzoo, smac and overcooked). Note that these adapters only work if the corresponding library is installed.
10
+
11
+ ```python
12
+ from marlenv.adapters import Gym, PettingZoo, SMAC, Overcooked
13
+ import marlenv
14
+
15
+ env1 = Gym("CartPole-v1")
16
+ env2 = marlenv.make("CartPole-v1")
17
+ env3 = PettingZoo("prospector_v4")
18
+ env4 = SMAC("3m")
19
+ env5 = Overcooked.from_layout("cramped_room")
20
+ ```
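+
+ Each adapter is only defined when its backing library is installed, so `marlenv.adapters` also exposes the flags `HAS_GYM`, `HAS_PETTINGZOO`, `HAS_SMAC` and `HAS_OVERCOOKED`. A small sketch of guarding on them:
+
+ ```python
+ from marlenv import adapters
+
+ if adapters.HAS_SMAC:
+     env = adapters.SMAC("3m")
+ ```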
21
+
22
+ # Wrappers & Builder
23
+ To facilitate the creation of an environment with common wrappers, `marlenv` provides a `Builder` class that can be used to chain multiple wrappers.
24
+
25
+ ```python
26
+ from marlenv import make, Builder
27
+
28
+ env = <your env>
29
+ env = Builder(env).agent_id().time_limit(50).record("videos").build()
30
+ ```
31
+
32
+ # Using the library
33
+ A typical environment loop would look like this:
34
+
35
+ ```python
36
+ from marlenv import DiscreteMockEnv, Builder, Episode
37
+
38
+ env = Builder(DiscreteMockEnv()).agent_id().build()
39
+ obs, state = env.reset()
40
41
+ episode = Episode.new(obs, state)
42
+ while not episode.is_finished:
43
+ action = env.sample_action() # a valid random action
44
+ step = env.step(action) # Step data `step.obs`, `step.reward`, ...
45
+ episode.add(step, action) # Progressively build the episode
46
+ ```
47
+
48
+ # Extras
49
+ To cope with complex observation spaces, `marlenv` distinguishes the "main" observation data from the "extra" observation data. A typical example is the observation of a gridworld environment with a time limit: the main observation has shape (height, width), i.e. the content of the grid, while the current time is an extra observation of shape (1, ).
50
+
51
+ ```python
52
+ env = GridWorldEnv()
53
+ print(env.observation_shape) # (height, width)
54
+ print(env.extras_shape) # (0, )
55
+
56
+ env = Builder(env).time_limit(25).build()
57
+ print(env.observation_shape) # (height, width)
58
+ print(env.extras_shape) # (1, )
59
+ ```
60
+
61
+ # Creating a new environment
62
+ If you want to create a new environment, you can simply create a class that inherits from `MARLEnv`. If you want to create a wrapper around an existing `MARLEnv`, you probably want to subclass `RLEnvWrapper` which implements a default behaviour for every method.
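+
+ As a sketch of the wrapper approach, here is a purely illustrative `ScaledReward` wrapper (not part of the library), modelled on the `DelayedReward` wrapper that ships with `marlenv`:
+
+ ```python
+ from dataclasses import dataclass
+ from marlenv.wrappers import RLEnvWrapper
+
+ @dataclass
+ class ScaledReward(RLEnvWrapper):
+     scale: float
+
+     def __init__(self, env, scale: float):
+         super().__init__(env)
+         self.scale = scale
+
+     def step(self, actions):
+         # Delegate to the wrapped environment, then rescale the reward
+         step = super().step(actions)
+         step.reward = step.reward * self.scale
+         return step
+ ```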
14
63
  """
15
64
 
16
- __version__ = "3.2.2"
65
+ __version__ = "3.3.1"
17
66
 
18
67
  from . import models
19
68
  from . import wrappers
@@ -1,24 +1,42 @@
1
+ from importlib.util import find_spec
1
2
  from .pymarl_adapter import PymarlAdapter
2
- from typing import Any
3
3
 
4
- __all__ = ["PymarlAdapter"]
5
- try:
4
+ HAS_GYM = False
5
+ if find_spec("gymnasium") is not None:
6
6
  from .gym_adapter import Gym
7
7
 
8
- __all__.append("Gym")
9
- except ImportError:
10
- Gym = Any
8
+ HAS_GYM = True
11
9
 
12
- try:
10
+ HAS_PETTINGZOO = False
11
+ if find_spec("pettingzoo") is not None:
13
12
  from .pettingzoo_adapter import PettingZoo
14
13
 
15
- __all__.append("PettingZoo")
16
- except ImportError:
17
- PettingZoo = Any
14
+ HAS_PETTINGZOO = True
18
15
 
19
- try:
16
+ HAS_SMAC = False
17
+ if find_spec("smac") is not None:
20
18
  from .smac_adapter import SMAC
21
19
 
22
- __all__.append("SMAC")
23
- except ImportError:
24
- SMAC = Any
20
+ HAS_SMAC = True
21
+
22
+ HAS_OVERCOOKED = False
23
+ if find_spec("overcooked_ai_py.mdp") is not None:
24
+ import numpy
25
+
26
+ # Overcooked assumes a version of numpy <2.0 where np.Inf is available.
27
+ setattr(numpy, "Inf", numpy.inf)
28
+ from .overcooked_adapter import Overcooked
29
+
30
+ HAS_OVERCOOKED = True
31
+
32
+ __all__ = [
33
+ "PymarlAdapter",
34
+ "Gym",
35
+ "PettingZoo",
36
+ "SMAC",
37
+ "Overcooked",
38
+ "HAS_GYM",
39
+ "HAS_PETTINGZOO",
40
+ "HAS_SMAC",
41
+ "HAS_OVERCOOKED",
42
+ ]
@@ -1,3 +1,5 @@
1
+ import sys
2
+ import cv2
1
3
  from dataclasses import dataclass
2
4
  from typing import Sequence
3
5
 
@@ -79,7 +81,10 @@ class Gym(MARLEnv[Sequence | npt.NDArray, ActionSpace]):
79
81
  return self.last_obs, self.get_state()
80
82
 
81
83
  def get_image(self):
82
- return self.env.render()
84
+ image = np.array(self.env.render())
85
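+ # Depending on the platform, the colour channels may need to be swapped (RGB -> BGR)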
+ if sys.platform in ("linux", "linux2"):
86
+ image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
87
+ return image
83
88
 
84
89
  def seed(self, seed_value: int):
85
90
  self.env.reset(seed=seed_value)
@@ -0,0 +1,164 @@
1
+ import sys
2
+ from dataclasses import dataclass
3
+ from typing import Literal, Sequence
4
+
5
+ import cv2
6
+ import numpy as np
7
+ import numpy.typing as npt
8
+ import pygame
9
+ from marlenv.models import ContinuousSpace, DiscreteActionSpace, MARLEnv, Observation, State, Step
10
+
11
+ from overcooked_ai_py.mdp.overcooked_env import OvercookedEnv
12
+ from overcooked_ai_py.mdp.overcooked_mdp import Action, OvercookedGridworld, OvercookedState
13
+ from overcooked_ai_py.visualization.state_visualizer import StateVisualizer
14
+
15
+
16
+ @dataclass
17
+ class Overcooked(MARLEnv[Sequence[int] | npt.NDArray, DiscreteActionSpace]):
18
+ horizon: int
19
+
20
+ def __init__(self, oenv: OvercookedEnv):
21
+ self._oenv = oenv
22
+ assert isinstance(oenv.mdp, OvercookedGridworld)
23
+ self._mdp = oenv.mdp
24
+ self.visualizer = StateVisualizer()
25
+ shape = tuple(int(s) for s in self._mdp.get_lossless_state_encoding_shape())
26
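+ # Reorder the lossless encoding shape to channel-first: (channels, height, width)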
+ shape = (shape[2], shape[0], shape[1])
27
+ super().__init__(
28
+ action_space=DiscreteActionSpace(
29
+ n_agents=self._mdp.num_players,
30
+ n_actions=Action.NUM_ACTIONS,
31
+ action_names=[Action.ACTION_TO_CHAR[a] for a in Action.ALL_ACTIONS],
32
+ ),
33
+ observation_shape=shape,
34
+ extras_shape=(1,),
35
+ extras_meanings=["timestep"],
36
+ state_shape=shape,
37
+ state_extra_shape=(1,),
38
+ reward_space=ContinuousSpace.from_shape(1),
39
+ )
40
+ self.horizon = int(self._oenv.horizon)
41
+
42
+ @property
43
+ def state(self) -> OvercookedState:
44
+ """Current state of the environment"""
45
+ return self._oenv.state
46
+
47
+ def set_state(self, state: State):
48
+ raise NotImplementedError("Not yet implemented")
49
+
50
+ @property
51
+ def time_step(self):
52
+ return self.state.timestep
53
+
54
+ def _state_data(self):
55
+ state = np.array(self._mdp.lossless_state_encoding(self.state))
56
+ # Use axes (agents, channels, height, width) instead of (agents, height, width, channels)
57
+ state = np.transpose(state, (0, 3, 1, 2))
58
+ return state
59
+
60
+ def get_state(self):
61
+ return State(self._state_data()[0], np.array([self.time_step / self.horizon]))
62
+
63
+ def get_observation(self) -> Observation:
64
+ return Observation(
65
+ data=self._state_data(),
66
+ available_actions=self.available_actions(),
67
+ extras=np.array([[self.time_step / self.horizon]] * self.n_agents),
68
+ )
69
+
70
+ def available_actions(self):
71
+ available_actions = np.full((self.n_agents, self.n_actions), False)
72
+ actions = self._mdp.get_actions(self._oenv.state)
73
+ for agent_num, agent_actions in enumerate(actions):
74
+ for action in agent_actions:
75
+ available_actions[agent_num, Action.ACTION_TO_INDEX[action]] = True
76
+ return np.array(available_actions)
77
+
78
+ def step(self, actions: Sequence[int] | npt.NDArray[np.int32 | np.int64]) -> Step:
79
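+ # Convert the action indices into Overcooked `Action` objects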
+ actions = [Action.ALL_ACTIONS[a] for a in actions]
80
+ _, reward, done, info = self._oenv.step(actions, display_phi=True)
81
+ return Step(
82
+ obs=self.get_observation(),
83
+ state=self.get_state(),
84
+ reward=np.array([reward]),
85
+ done=done,
86
+ truncated=False,
87
+ info=info,
88
+ )
89
+
90
+ def get_image(self):
91
+ rewards_dict = {} # dictionary of details you want rendered in the UI
92
+ for key, value in self._oenv.game_stats.items():
93
+ if key in [
94
+ "cumulative_shaped_rewards_by_agent",
95
+ "cumulative_sparse_rewards_by_agent",
96
+ ]:
97
+ rewards_dict[key] = value
98
+
99
+ image = self.visualizer.render_state(
100
+ state=self._oenv.state,
101
+ grid=self._mdp.terrain_mtx,
102
+ hud_data=StateVisualizer.default_hud_data(self._oenv.state, **rewards_dict),
103
+ )
104
+
105
+ image = pygame.surfarray.array3d(image)
106
+ image = np.flip(np.rot90(image, 3), 1)
107
+ # Depending on the platform, the image may need to be converted to RGB
108
+ if sys.platform in ("linux", "linux2"):
109
+ image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
110
+ return image
111
+
112
+ @staticmethod
113
+ def from_layout(
114
+ layout: Literal[
115
+ "asymmetric_advantages",
116
+ "asymmetric_advantages_tomato",
117
+ "bonus_order_test",
118
+ "bottleneck",
119
+ "centre_objects",
120
+ "centre_pots",
121
+ "coordination_ring",
122
+ "corridor",
123
+ "counter_circuit",
124
+ "counter_circuit_o_1order",
125
+ "cramped_corridor",
126
+ "cramped_room",
127
+ "cramped_room_o_3orders",
128
+ "cramped_room_single",
129
+ "cramped_room_tomato",
130
+ "five_by_five",
131
+ "forced_coordination",
132
+ "forced_coordination_tomato",
133
+ "inverse_marshmallow_experiment",
134
+ "large_room",
135
+ "long_cook_time",
136
+ "marshmallow_experiment_coordination",
137
+ "marshmallow_experiment",
138
+ "mdp_test",
139
+ "m_shaped_s",
140
+ "multiplayer_schelling",
141
+ "pipeline",
142
+ "scenario1_s",
143
+ "scenario2",
144
+ "scenario2_s",
145
+ "scenario3",
146
+ "scenario4",
147
+ "schelling",
148
+ "schelling_s",
149
+ "simple_o",
150
+ "simple_o_t",
151
+ "simple_tomato",
152
+ "small_corridor",
153
+ "soup_coordination",
154
+ "tutorial_0",
155
+ "tutorial_1",
156
+ "tutorial_2",
157
+ "tutorial_3",
158
+ "unident",
159
+ "you_shall_not_pass",
160
+ ],
161
+ horizon: int = 400,
162
+ ):
163
+ mdp = OvercookedGridworld.from_layout_name(layout)
164
+ return Overcooked(OvercookedEnv.from_mdp(mdp, horizon=horizon))
marlenv/env_builder.py CHANGED
@@ -1,32 +1,27 @@
1
1
  from dataclasses import dataclass
2
2
  from typing import Generic, Literal, Optional, TypeVar, overload
3
-
4
3
  import numpy as np
5
4
  import numpy.typing as npt
6
5
 
7
6
  from . import wrappers
7
+ from marlenv import adapters
8
8
  from .models import ActionSpace, MARLEnv
9
- from .adapters import PettingZoo
10
9
 
11
10
  A = TypeVar("A")
12
11
  AS = TypeVar("AS", bound=ActionSpace)
13
12
 
14
- try:
13
+ if adapters.HAS_PETTINGZOO:
14
+ from .adapters import PettingZoo
15
15
  from pettingzoo import ParallelEnv
16
16
 
17
17
  @overload
18
- def make(
19
- env: ParallelEnv,
20
- ) -> PettingZoo: ...
21
-
22
- HAS_PETTINGZOO = True
23
- except ImportError:
24
- HAS_PETTINGZOO = False
18
+ def make(env: ParallelEnv) -> PettingZoo: ...
25
19
 
26
20
 
27
- try:
28
- from gymnasium import Env
21
+ if adapters.HAS_GYM:
29
22
  from .adapters import Gym
23
+ from gymnasium import Env
24
+ import gymnasium
30
25
 
31
26
  @overload
32
27
  def make(env: Env) -> Gym: ...
@@ -37,25 +32,21 @@ try:
37
32
  Make an RLEnv from the `gymnasium` registry (e.g: "CartPole-v1").
38
33
  """
39
34
 
40
- HAS_GYM = True
41
- except ImportError:
42
- HAS_GYM = False
43
35
 
44
- try:
45
- from smac.env import StarCraft2Env
36
+ if adapters.HAS_SMAC:
46
37
  from .adapters import SMAC
38
+ from smac.env import StarCraft2Env
47
39
 
48
40
  @overload
49
41
  def make(env: StarCraft2Env) -> SMAC: ...
50
42
 
51
- HAS_SMAC = True
52
- except ImportError:
53
- HAS_SMAC = False
54
43
 
44
+ if adapters.HAS_OVERCOOKED:
45
+ from .adapters import Overcooked
46
+ from overcooked_ai_py.mdp.overcooked_env import OvercookedEnv
55
47
 
56
- @overload
57
- def make(env: MARLEnv[A, AS]) -> MARLEnv[A, AS]:
58
- """Why would you do this ?"""
48
+ @overload
49
+ def make(env: OvercookedEnv) -> Overcooked: ...
59
50
 
60
51
 
61
52
  def make(env, **kwargs):
@@ -64,32 +55,18 @@ def make(env, **kwargs):
64
55
  case MARLEnv():
65
56
  return env
66
57
  case str(env_id):
67
- try:
68
- import gymnasium
69
- except ImportError:
70
- raise ImportError("Gymnasium is not installed !")
71
- from marlenv.adapters import Gym
72
-
73
- gym_env = gymnasium.make(env_id, render_mode="rgb_array", **kwargs)
74
- return Gym(gym_env)
75
-
76
- try:
77
- from marlenv.adapters import PettingZoo
78
-
79
- if isinstance(env, ParallelEnv):
80
- return PettingZoo(env)
81
- except ImportError:
82
- pass
83
- try:
84
- from smac.env import StarCraft2Env
85
-
86
- from marlenv.adapters import SMAC
87
-
88
- if isinstance(env, StarCraft2Env):
89
- return SMAC(env)
90
- except ImportError:
91
- pass
92
-
58
+ if adapters.HAS_GYM:
59
+ gym_env = gymnasium.make(env_id, render_mode="rgb_array", **kwargs)
60
+ return Gym(gym_env)
61
+
62
+ if adapters.HAS_PETTINGZOO and isinstance(env, ParallelEnv):
63
+ return PettingZoo(env) # type: ignore
64
+ if adapters.HAS_SMAC and isinstance(env, StarCraft2Env):
65
+ return SMAC(env)
66
+ if adapters.HAS_OVERCOOKED and isinstance(env, OvercookedEnv):
67
+ return Overcooked(env) # type: ignore
68
+ if adapters.HAS_GYM and isinstance(env, Env):
69
+ return Gym(env)
93
70
  raise ValueError(f"Unknown environment type: {type(env)}")
94
71
 
95
72
 
@@ -116,6 +93,11 @@ class Builder(Generic[A, AS]):
116
93
  self._env = wrappers.TimeLimit(self._env, n_steps, add_extra, truncation_penalty)
117
94
  return self
118
95
 
96
+ def delay_rewards(self, delay: int):
97
+ """Delays the rewards by `delay` steps"""
98
+ self._env = wrappers.DelayedReward(self._env, delay)
99
+ return self
100
+
119
101
  def pad(self, to_pad: Literal["obs", "extra"], n: int):
120
102
  match to_pad:
121
103
  case "obs":
marlenv/env_pool.py CHANGED
@@ -1,6 +1,5 @@
1
1
  from typing import Sequence
2
2
  from dataclasses import dataclass
3
- import numpy as np
4
3
  import numpy.typing as npt
5
4
  from typing_extensions import TypeVar
6
5
  import random
marlenv/mock_env.py CHANGED
@@ -2,7 +2,7 @@ from typing import Sequence
2
2
  import numpy as np
3
3
  import numpy.typing as npt
4
4
  from dataclasses import dataclass
5
- from marlenv import MARLEnv, Observation, DiscreteActionSpace, DiscreteSpace, Step, State
5
+ from marlenv import MARLEnv, Observation, DiscreteActionSpace, ContinuousSpace, Step, State
6
6
 
7
7
 
8
8
  @dataclass
@@ -13,21 +13,31 @@ class DiscreteMockEnv(MARLEnv[Sequence[int] | npt.NDArray, DiscreteActionSpace])
13
13
  obs_size: int = 42,
14
14
  n_actions: int = 5,
15
15
  end_game: int = 30,
16
- reward_step: int = 1,
16
+ reward_step: int | float | np.ndarray | list = 1,
17
17
  agent_state_size: int = 1,
18
18
  extras_size: int = 0,
19
19
  ) -> None:
20
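+ # Normalise reward_step to a numpy array so both scalar and multi-objective rewards work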
+ match reward_step:
21
+ case int() | float():
22
+ reward_step = np.array([reward_step])
23
+ case list():
24
+ reward_step = np.array(reward_step)
25
+ case np.ndarray():
26
+ reward_step = reward_step
27
+ case _:
28
+ raise ValueError("reward_step must be an int, float or np.ndarray")
20
29
  super().__init__(
21
30
  DiscreteActionSpace(n_agents, n_actions),
22
31
  (obs_size,),
23
32
  (n_agents * agent_state_size,),
24
33
  extras_shape=(extras_size,),
34
+ reward_space=ContinuousSpace.from_shape(reward_step.shape),
25
35
  )
36
+ self.reward_step = reward_step
26
37
  self.obs_size = obs_size
27
38
  self.extra_size = extras_size
28
39
  self._agent_state_size = agent_state_size
29
40
  self.end_game = end_game
30
- self.reward_step = reward_step
31
41
  self.t = 0
32
42
  self.actions_history = []
33
43
  self._seed = -1
@@ -70,7 +80,7 @@ class DiscreteMockEnv(MARLEnv[Sequence[int] | npt.NDArray, DiscreteActionSpace])
70
80
  return Step(
71
81
  self.get_observation(),
72
82
  self.get_state(),
73
- np.array([self.reward_step]),
83
+ self.reward_step,
74
84
  self.t >= self.end_game,
75
85
  )
76
86
 
@@ -94,7 +104,7 @@ class DiscreteMOMockEnv(MARLEnv[Sequence[int] | npt.NDArray, DiscreteActionSpace
94
104
  (obs_size,),
95
105
  (n_agents * agent_state_size,),
96
106
  extras_shape=(extras_size,),
97
- reward_space=DiscreteSpace(n_objectives),
107
+ reward_space=ContinuousSpace.from_shape(n_objectives),
98
108
  )
99
109
  self.obs_size = obs_size
100
110
  self.extra_size = extras_size
marlenv/models/env.py CHANGED
@@ -1,17 +1,17 @@
1
1
  from abc import ABC, abstractmethod
2
+ from dataclasses import dataclass
3
+ from itertools import product
2
4
  from typing import Generic, Optional, Sequence
3
- from typing_extensions import TypeVar
5
+
4
6
  import cv2
5
7
  import numpy as np
6
8
  import numpy.typing as npt
7
- from dataclasses import dataclass
8
- from itertools import product
9
-
9
+ from typing_extensions import TypeVar
10
10
 
11
- from .step import Step
12
- from .state import State
13
- from .spaces import ActionSpace, DiscreteSpace
14
11
  from .observation import Observation
12
+ from .spaces import ActionSpace, ContinuousSpace, Space
13
+ from .state import State
14
+ from .step import Step
15
15
 
16
16
  ActionType = TypeVar("ActionType", default=npt.NDArray)
17
17
  ActionSpaceType = TypeVar("ActionSpaceType", bound=ActionSpace, default=ActionSpace)
@@ -25,6 +25,34 @@ class MARLEnv(ABC, Generic[ActionType, ActionSpaceType]):
25
25
  This type is generic on
26
26
  - the action type
27
27
  - the action space
28
+
29
+ You can inherit from this class to create your own environment:
30
+ ```
31
+ import numpy as np
32
+ from marlenv import MARLEnv, DiscreteActionSpace, Observation
33
+
34
+ N_AGENTS = 3
35
+ N_ACTIONS = 5
36
+
37
+ class CustomEnv(MARLEnv[DiscreteActionSpace]):
38
+ def __init__(self, width: int, height: int):
39
+ super().__init__(
40
+ action_space=DiscreteActionSpace(N_AGENTS, N_ACTIONS),
41
+ observation_shape=(height, width),
42
+ state_shape=(1,),
43
+ )
44
+ self.time = 0
45
+
46
+ def reset(self) -> Observation:
47
+ self.time = 0
48
+ ...
49
+ return obs
50
+
51
+ def get_state(self):
52
+ return np.array([self.time])
53
+
54
+ ...
55
+ ```
28
56
  """
29
57
 
30
58
  action_space: ActionSpaceType
@@ -38,6 +66,7 @@ class MARLEnv(ABC, Generic[ActionType, ActionSpaceType]):
38
66
  n_agents: int
39
67
  n_actions: int
40
68
  name: str
69
+ reward_space: Space
41
70
 
42
71
  def __init__(
43
72
  self,
@@ -46,7 +75,7 @@ class MARLEnv(ABC, Generic[ActionType, ActionSpaceType]):
46
75
  state_shape: tuple[int, ...],
47
76
  extras_shape: tuple[int, ...] = (0,),
48
77
  state_extra_shape: tuple[int, ...] = (0,),
49
- reward_space: Optional[DiscreteSpace] = None,
78
+ reward_space: Optional[Space] = None,
50
79
  extras_meanings: Optional[list[str]] = None,
51
80
  ):
52
81
  super().__init__()
@@ -58,7 +87,9 @@ class MARLEnv(ABC, Generic[ActionType, ActionSpaceType]):
58
87
  self.state_shape = state_shape
59
88
  self.extras_shape = extras_shape
60
89
  self.state_extra_shape = state_extra_shape
61
- self.reward_space = reward_space or DiscreteSpace(1, labels=["Reward"])
90
+ if reward_space is None:
91
+ reward_space = ContinuousSpace.from_shape(1, labels=["Reward"])
92
+ self.reward_space = reward_space
62
93
  if extras_meanings is None:
63
94
  extras_meanings = [f"{self.name}-extra-{i}" for i in range(extras_shape[0])]
64
95
  elif len(extras_meanings) != extras_shape[0]:
@@ -77,9 +108,9 @@ class MARLEnv(ABC, Generic[ActionType, ActionSpaceType]):
77
108
  """Whether the environment is multi-objective."""
78
109
  return self.reward_space.size > 1
79
110
 
80
- def sample_action(self):
111
+ def sample_action(self) -> ActionType:
81
112
  """Sample an available action from the action space."""
82
- return self.action_space.sample(self.available_actions())
113
+ return self.action_space.sample(self.available_actions()) # type: ignore
83
114
 
84
115
  def available_actions(self) -> npt.NDArray[np.bool]:
85
116
  """
@@ -123,6 +154,10 @@ class MARLEnv(ABC, Generic[ActionType, ActionSpaceType]):
123
154
  - info: Extra information
124
155
  """
125
156
 
157
+ def random_step(self) -> Step:
158
+ """Perform a random step in the environment."""
159
+ return self.step(self.sample_action())
160
+
126
161
  def reset(self) -> tuple[Observation, State]:
127
162
  """Reset the environment and return the initial observation and state."""
128
163
  return self.get_observation(), self.get_state()
@@ -58,9 +58,14 @@ class Observation:
58
58
  available_actions=self.available_actions[agent_id],
59
59
  )
60
60
 
61
+ @property
62
+ def shape(self) -> tuple[int, ...]:
63
+ """The individual shape of the observation data"""
64
+ return self.data[0].shape
65
+
61
66
  @property
62
67
  def extras_shape(self) -> tuple[int, ...]:
63
- """The shape of the observation extras"""
68
+ """The individual shape of the observation extras"""
64
69
  return self.extras[0].shape
65
70
 
66
71
  def __hash__(self):
marlenv/models/spaces.py CHANGED
@@ -12,14 +12,14 @@ S = TypeVar("S", bound="Space")
12
12
  @dataclass
13
13
  class Space(ABC):
14
14
  shape: tuple[int, ...]
15
- n_dims: int
15
+ size: int
16
16
  labels: list[str]
17
17
 
18
18
  def __init__(self, shape: tuple[int, ...], labels: Optional[list[str]] = None):
19
19
  self.shape = shape
20
- self.n_dims = len(shape)
20
+ self.size = math.prod(shape)
21
21
  if labels is None:
22
- labels = [f"Dim {i}" for i in range(self.n_dims)]
22
+ labels = [f"Dim {i}" for i in range(self.size)]
23
23
  self.labels = labels
24
24
 
25
25
  @abstractmethod
@@ -100,16 +100,55 @@ class ContinuousSpace(Space):
100
100
  high: npt.NDArray[np.float32]
101
101
  """Upper bound of the space for each dimension."""
102
102
 
103
+ @staticmethod
104
+ def from_bounds(
105
+ low: int | float | list | npt.NDArray[np.float32],
106
+ high: int | float | list | npt.NDArray[np.float32],
107
+ labels: Optional[list[str]] = None,
108
+ ):
109
+ match low:
110
+ case list():
111
+ low = np.array(low, dtype=np.float32)
112
+ case float() | int():
113
+ low = np.array([low], dtype=np.float32)
114
+ match high:
115
+ case list():
116
+ high = np.array(high, dtype=np.float32)
117
+ case float() | int():
118
+ high = np.array([high], dtype=np.float32)
119
+ return ContinuousSpace(low, high, labels)
120
+
121
+ @staticmethod
122
+ def from_shape(
123
+ shape: int | tuple[int, ...],
124
+ low: Optional[int | float | list | npt.NDArray[np.float32]] = None,
125
+ high: Optional[int | float | list | npt.NDArray[np.float32]] = None,
126
+ labels: Optional[list[str]] = None,
127
+ ):
128
+ if isinstance(shape, int):
129
+ shape = (shape,)
130
+ match low:
131
+ case None:
132
+ low = np.full(shape, -np.inf, dtype=np.float32)
133
+ case float() | int():
134
+ low = np.full(shape, low, dtype=np.float32)
135
+ case list():
136
+ low = np.array(low, dtype=np.float32)
137
+ match high:
138
+ case None:
139
+ high = np.full(shape, np.inf, dtype=np.float32)
140
+ case float() | int():
141
+ high = np.full(shape, high, dtype=np.float32)
142
+ case list():
143
+ high = np.array(high, dtype=np.float32)
144
+ return ContinuousSpace(low, high, labels)
145
+
103
146
  def __init__(
104
147
  self,
105
- low: list | npt.NDArray[np.float32],
106
- high: list | npt.NDArray[np.float32],
148
+ low: npt.NDArray[np.float32],
149
+ high: npt.NDArray[np.float32],
107
150
  labels: Optional[list[str]] = None,
108
151
  ):
109
- if isinstance(low, list):
110
- low = np.array(low, dtype=np.float32)
111
- if isinstance(high, list):
112
- high = np.array(high, dtype=np.float32)
113
152
  assert low.shape == high.shape, "Low and high must have the same shape."
114
153
  assert np.all(low <= high), "All elements in low must be less than the corresponding elements in high."
115
154
  Space.__init__(self, low.shape, labels)
@@ -182,5 +221,5 @@ class MultiDiscreteActionSpace(ActionSpace[MultiDiscreteSpace]):
182
221
  @dataclass
183
222
  class ContinuousActionSpace(ActionSpace[ContinuousSpace]):
184
223
  def __init__(self, n_agents: int, low: np.ndarray | list, high: np.ndarray | list, action_names: list | None = None):
185
- space = ContinuousSpace(low, high, action_names)
224
+ space = ContinuousSpace.from_bounds(low, high, action_names)
186
225
  super().__init__(n_agents, space, action_names)
@@ -9,6 +9,7 @@ from .available_actions_wrapper import AvailableActions
9
9
  from .blind_wrapper import Blind
10
10
  from .centralised import Centralised
11
11
  from .available_actions_mask import AvailableActionsMask
12
+ from .delayed_rewards import DelayedReward
12
13
 
13
14
  __all__ = [
14
15
  "RLEnvWrapper",
@@ -24,4 +25,5 @@ __all__ = [
24
25
  "AvailableActions",
25
26
  "Blind",
26
27
  "Centralised",
28
+ "DelayedReward",
27
29
  ]
@@ -0,0 +1,36 @@
1
+ from .rlenv_wrapper import RLEnvWrapper, MARLEnv
2
+ from marlenv.models import ActionSpace
3
+ from typing_extensions import TypeVar
4
+ import numpy.typing as npt
5
+ import numpy as np
6
+ from dataclasses import dataclass
7
+ from collections import deque
8
+
9
+ A = TypeVar("A", default=npt.NDArray)
10
+ AS = TypeVar("AS", bound=ActionSpace, default=ActionSpace)
11
+
12
+
13
+ @dataclass
14
+ class DelayedReward(RLEnvWrapper[A, AS]):
15
+ delay: int
16
+
17
+ def __init__(self, env: MARLEnv[A, AS], delay: int):
18
+ super().__init__(env)
19
+ self.delay = delay
20
+ self.reward_queue = deque[npt.NDArray[np.float32]](maxlen=delay + 1)
21
+
22
+ def reset(self):
23
+ self.reward_queue.clear()
24
+ for _ in range(self.delay):
25
+ self.reward_queue.append(np.zeros(self.reward_space.shape, dtype=np.float32))
26
+ return super().reset()
27
+
28
+ def step(self, actions: A):
29
+ step = super().step(actions)
30
+ self.reward_queue.append(step.reward)
31
+ # If the step is terminal, we sum all the remaining rewards
32
+ if step.is_terminal:
33
+ step.reward = np.sum(self.reward_queue, axis=0)
34
+ else:
35
+ step.reward = self.reward_queue.popleft()
36
+ return step
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: multi-agent-rlenv
3
- Version: 3.2.2
3
+ Version: 3.3.1
4
4
  Summary: A strongly typed Multi-Agent Reinforcement Learning framework
5
5
  Project-URL: repository, https://github.com/yamoling/multi-agent-rlenv
6
6
  Author-email: Yannick Molinghen <yannick.molinghen@ulb.be>
@@ -8,16 +8,58 @@ License-File: LICENSE
8
8
  Classifier: Operating System :: OS Independent
9
9
  Classifier: Programming Language :: Python :: 3
10
10
  Requires-Python: <4,>=3.10
11
- Requires-Dist: gymnasium>=0.29.1
12
11
  Requires-Dist: numpy>=2.0.0
13
- Requires-Dist: opencv-python>=4.10.0.84
12
+ Requires-Dist: opencv-python>=4.0
13
+ Requires-Dist: typing-extensions>=4.0
14
+ Provides-Extra: all
15
+ Requires-Dist: gymnasium>0.29.1; extra == 'all'
16
+ Requires-Dist: overcooked-ai; extra == 'all'
17
+ Requires-Dist: pettingzoo>=1.20; extra == 'all'
18
+ Requires-Dist: pymunk>=6.0; extra == 'all'
19
+ Requires-Dist: pysc2; extra == 'all'
20
+ Requires-Dist: scipy>=1.10; extra == 'all'
21
+ Requires-Dist: smac; extra == 'all'
22
+ Provides-Extra: gym
23
+ Requires-Dist: gymnasium>=0.29.1; extra == 'gym'
24
+ Provides-Extra: overcooked
25
+ Requires-Dist: overcooked-ai>=1.1.0; extra == 'overcooked'
26
+ Requires-Dist: scipy>=1.10; extra == 'overcooked'
27
+ Provides-Extra: pettingzoo
28
+ Requires-Dist: pettingzoo>=1.20; extra == 'pettingzoo'
29
+ Requires-Dist: pymunk>=6.0; extra == 'pettingzoo'
30
+ Requires-Dist: scipy>=1.10; extra == 'pettingzoo'
31
+ Provides-Extra: smac
32
+ Requires-Dist: pysc2; extra == 'smac'
33
+ Requires-Dist: smac; extra == 'smac'
14
34
  Description-Content-Type: text/markdown
15
35
 
16
- # `marlenv` - A unified interface for muti-agent reinforcement learning
36
+ # `marlenv` - A unified framework for multi-agent reinforcement learning
37
+ **Documentation: [https://yamoling.github.io/multi-agent-rlenv](https://yamoling.github.io/multi-agent-rlenv)**
38
+
17
39
  The objective of `marlenv` is to provide a common (typed) interface for many different reinforcement learning environments.
18
40
 
19
41
  As such, `marlenv` provides high level abstractions of RL concepts such as `Observation`s or `Transition`s that are commonly represented as mere (confusing) lists or tuples.
20
42
 
43
+ ## Installation
44
+ Install with your preferred package manager (`uv`, `pip`, `poetry`, ...):
45
+ ```bash
46
+ $ pip install marlenv[all] # Enable all features
47
+ $ pip install marlenv # Basic installation
48
+ ```
49
+
50
+ There are multiple optional dependencies if you want to support specific libraries and environments. Available options are:
51
+ - `smac` for StarCraft II environments
52
+ - `gym` for OpenAI Gym environments
53
+ - `pettingzoo` for PettingZoo environments
54
+ - `overcooked` for Overcooked environments
55
+
56
+ Install them with:
57
+ ```bash
58
+ $ pip install marlenv[smac] # Install SMAC
59
+ $ pip install marlenv[gym,smac] # Install Gym & smac support
60
+ ```
61
+
62
+
21
63
  ## Using `marlenv` with existing libraries
22
64
  `marlenv` unifies multiple popular libraries under a single interface. Namely, `marlenv` supports `smac`, `gymnasium` and `pettingzoo`.
23
65
 
@@ -47,7 +89,7 @@ from marlenv import RLEnv, DiscreteActionSpace, Observation
47
89
  N_AGENTS = 3
48
90
  N_ACTIONS = 5
49
91
 
50
- class CustomEnv(RLEnv[DiscreteActionSpace]):
92
+ class CustomEnv(MARLEnv[DiscreteActionSpace]):
51
93
  def __init__(self, width: int, height: int):
52
94
  super().__init__(
53
95
  action_space=DiscreteActionSpace(N_AGENTS, N_ACTIONS),
@@ -1,35 +1,37 @@
1
- marlenv/__init__.py,sha256=mA5XCI_yboGYESxCtbpp5oKHx1tplA-TwnyIjazx3nE,1499
2
- marlenv/env_builder.py,sha256=G2ZkTradBmbfN8U84ux_Bflv8S6pr2zJiQVrWaP6ygc,5527
3
- marlenv/env_pool.py,sha256=TSSYwD5-g4G473Ea097wFVbp3tyQrawywLIAFFEJCJY,1089
1
+ marlenv/__init__.py,sha256=7HYiLVpZ4PQQWYjvc_78eXrnidv-FfUiY1pUkbxpm5U,3741
2
+ marlenv/env_builder.py,sha256=_rdwcWRqnHP7i4M4Oje1Y2nrEBKH9EzTpqOuw_PNUyw,5560
3
+ marlenv/env_pool.py,sha256=R3WIrnQ5Zvff4HR1ecfkDmuO2zl7v1ywQ0K2_nvWFzs,1070
4
4
  marlenv/exceptions.py,sha256=gJUC_2rVAvOfK_ypVFc7Myh-pIfSU3To38VBVS_0rZA,1179
5
- marlenv/mock_env.py,sha256=fGPTJ0GEYvHNraI9OJh3PN1My27equFZjf3o8iZnHJs,4289
5
+ marlenv/mock_env.py,sha256=qB0fYFIfbopJf7Va8kCeVI5vsOy1-2JdEYe9gdV1Ruw,4761
6
6
  marlenv/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
7
- marlenv/adapters/__init__.py,sha256=eJLY15xLDik9nO51Qa4Hp4Ak1_tLxm5C7kzFySNmUjs,425
8
- marlenv/adapters/gym_adapter.py,sha256=Nwp3gtuZQM1j9ify3h5QaA4enAJ1e9cTT_nPi7qNubE,2703
7
+ marlenv/adapters/__init__.py,sha256=NEmuHPWz4SGQcgF7QuIeA0QaXK141JoYco-7mqj9Ghk,883
8
+ marlenv/adapters/gym_adapter.py,sha256=Vx6ZrYI7kiNlJODmqyjXu9WCdbCr6trcMNot0pvYD74,2864
9
+ marlenv/adapters/overcooked_adapter.py,sha256=exNYFQhtnrvXQtG7M7FuQUkcRJK60a8vu2uwhRLc0K4,5768
9
10
  marlenv/adapters/pettingzoo_adapter.py,sha256=9rwSc_b7qV3ChtEIevOkJvtIp7WoY3CVnu6L9DxlMB4,2852
10
11
  marlenv/adapters/pymarl_adapter.py,sha256=x__E90XpFbfSWhnBHtkcD6WYkmKki1LByNbUFoDBUcg,3416
11
12
  marlenv/adapters/smac_adapter.py,sha256=fOfKo1hL4ioKtM5qQGcwtfdkdwUEACjAZqaGmkoQUcU,8373
12
13
  marlenv/models/__init__.py,sha256=9M-rnj94nsdyO4zm_VEtyYBmde3iD2_eIY4bMB-IqCo,555
13
- marlenv/models/env.py,sha256=JOT00dFXRuRMQYrv6ZYb3siwNh9eMXKnGRw_dM_mgVU,6519
14
+ marlenv/models/env.py,sha256=faezAKOIccBauOFeo9wu5sX32pFmP3AMmGyJzaTRJcM,7514
14
15
  marlenv/models/episode.py,sha256=ZGBx6lb2snrUhDgFEwHPV1dp-XvMA7k4quQVUNQxsP0,15140
15
- marlenv/models/observation.py,sha256=rTAesS_jaIyRlH4wjo2izEpWS0Hn5_UKjhbvdp0H4tA,2994
16
- marlenv/models/spaces.py,sha256=12PIoSWqgXVSfB5poqRTV4CdS-WNCrKiPmfALIGj7Mk,6226
16
+ marlenv/models/observation.py,sha256=kAmh1hIoC2TGrZlGVzV0y4TXXCSrI7gcmG0raeoncYk,3153
17
+ marlenv/models/spaces.py,sha256=pw8Sum_fHBkR-lyfTqUij4azMCNm8oBZrYZe4WVR7rA,7652
17
18
  marlenv/models/state.py,sha256=958PXTHadi3gtRnhGgcGtqBnF44R11kdcx62NN2gwxA,1717
18
19
  marlenv/models/step.py,sha256=LKGAV2Cu-k9Gz1hwrfvGx51l8axtQRqDE9WVL5r2A1Q,3037
19
20
  marlenv/models/transition.py,sha256=2vvuhSSq911weCXio9nuyfsLVh_7ORSU_znOqpLLdLg,5107
20
- marlenv/wrappers/__init__.py,sha256=E1IwrJjXGB6ZgQPiv-Huw7TGmnCSutFSkApHEi8dSQA,744
21
+ marlenv/wrappers/__init__.py,sha256=P7YCK1KYJvE6BAlH--nOW9PSlrohhuw-1wlfgCTOl9U,808
21
22
  marlenv/wrappers/agent_id_wrapper.py,sha256=oTIAYxKD1JtHfrZN43mf-3e8pxjd0nxm07vxs3BfrGY,1187
22
23
  marlenv/wrappers/available_actions_mask.py,sha256=JoCJ9eqHlkY8wfY-oaceEi8yp1Efs1iK6IO2Ibf9oZA,1468
23
24
  marlenv/wrappers/available_actions_wrapper.py,sha256=9UTwP3LXvncBITJeQnEqwiP_lj-ipULACkGs-2QbMrI,1026
24
25
  marlenv/wrappers/blind_wrapper.py,sha256=YEayRf_dclhzx6LXsasZ-IM7C71kyPb1gV0pHYYfjig,857
25
26
  marlenv/wrappers/centralised.py,sha256=J4hOMRT2fit936LifANNJtP7UbBEb_xIyF4VL9-fZGw,3226
27
+ marlenv/wrappers/delayed_rewards.py,sha256=6oGJe-L_gGI-pQMResbkjsMDvXpni2SQvnTQ6wsZqGo,1170
26
28
  marlenv/wrappers/last_action_wrapper.py,sha256=u7a3Da5sg_gMrwZ3SE7PAwt2m9xSYYDKjngQyOmcJ74,2886
27
29
  marlenv/wrappers/paddings.py,sha256=VQOF4zaP61R74tQ4XTTT-FkK2QSy31AukICnqCy6zB0,2119
28
30
  marlenv/wrappers/penalty_wrapper.py,sha256=v4_H8OEN2-yujLzRb6P7W7KwmXHtjAFsxcdp3SbnKpo,996
29
31
  marlenv/wrappers/rlenv_wrapper.py,sha256=C2XekgBIM4x3Wa2Mtsn7rihRD4ymC2hORI473Af0sfw,2962
30
32
  marlenv/wrappers/time_limit.py,sha256=CDIMMJPMyIDHSFxUJaC7nb7Kd86-07NgZeFhrpZm82o,3985
31
33
  marlenv/wrappers/video_recorder.py,sha256=d5AFu6qHqby9mOcBsYWYPxAPiK1vtnfMYdZ81AnCekI,2624
32
- multi_agent_rlenv-3.2.2.dist-info/METADATA,sha256=EX_3681ZauwzYUllRCn7QaX9ejiXihzO7ib7TZJBEiY,3357
33
- multi_agent_rlenv-3.2.2.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
34
- multi_agent_rlenv-3.2.2.dist-info/licenses/LICENSE,sha256=_eeiGVoIJ7kYt6l1zbIvSBQppTnw0mjnYk1lQ4FxEjE,1074
35
- multi_agent_rlenv-3.2.2.dist-info/RECORD,,
34
+ multi_agent_rlenv-3.3.1.dist-info/METADATA,sha256=lrE00NibOxZAoMD_DacNbfIbpEHqX7zw5wXc0xj8iiY,4897
35
+ multi_agent_rlenv-3.3.1.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
36
+ multi_agent_rlenv-3.3.1.dist-info/licenses/LICENSE,sha256=_eeiGVoIJ7kYt6l1zbIvSBQppTnw0mjnYk1lQ4FxEjE,1074
37
+ multi_agent_rlenv-3.3.1.dist-info/RECORD,,