multi-agent-rlenv 3.4.0__tar.gz → 3.5.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (54)
  1. {multi_agent_rlenv-3.4.0 → multi_agent_rlenv-3.5.1}/PKG-INFO +1 -1
  2. {multi_agent_rlenv-3.4.0 → multi_agent_rlenv-3.5.1}/src/marlenv/__init__.py +11 -13
  3. {multi_agent_rlenv-3.4.0 → multi_agent_rlenv-3.5.1}/src/marlenv/adapters/gym_adapter.py +6 -16
  4. {multi_agent_rlenv-3.4.0 → multi_agent_rlenv-3.5.1}/src/marlenv/adapters/overcooked_adapter.py +6 -7
  5. {multi_agent_rlenv-3.4.0 → multi_agent_rlenv-3.5.1}/src/marlenv/adapters/pettingzoo_adapter.py +5 -5
  6. {multi_agent_rlenv-3.4.0 → multi_agent_rlenv-3.5.1}/src/marlenv/adapters/pymarl_adapter.py +3 -4
  7. {multi_agent_rlenv-3.4.0 → multi_agent_rlenv-3.5.1}/src/marlenv/adapters/smac_adapter.py +6 -6
  8. {multi_agent_rlenv-3.4.0 → multi_agent_rlenv-3.5.1}/src/marlenv/env_builder.py +8 -9
  9. {multi_agent_rlenv-3.4.0 → multi_agent_rlenv-3.5.1}/src/marlenv/env_pool.py +5 -7
  10. {multi_agent_rlenv-3.4.0 → multi_agent_rlenv-3.5.1}/src/marlenv/mock_env.py +7 -7
  11. {multi_agent_rlenv-3.4.0 → multi_agent_rlenv-3.5.1}/src/marlenv/models/__init__.py +2 -4
  12. {multi_agent_rlenv-3.4.0 → multi_agent_rlenv-3.5.1}/src/marlenv/models/env.py +18 -12
  13. {multi_agent_rlenv-3.4.0 → multi_agent_rlenv-3.5.1}/src/marlenv/models/episode.py +15 -18
  14. {multi_agent_rlenv-3.4.0 → multi_agent_rlenv-3.5.1}/src/marlenv/models/spaces.py +90 -83
  15. {multi_agent_rlenv-3.4.0 → multi_agent_rlenv-3.5.1}/src/marlenv/models/step.py +1 -1
  16. {multi_agent_rlenv-3.4.0 → multi_agent_rlenv-3.5.1}/src/marlenv/models/transition.py +6 -10
  17. {multi_agent_rlenv-3.4.0 → multi_agent_rlenv-3.5.1}/src/marlenv/wrappers/__init__.py +2 -0
  18. {multi_agent_rlenv-3.4.0 → multi_agent_rlenv-3.5.1}/src/marlenv/wrappers/agent_id_wrapper.py +4 -5
  19. {multi_agent_rlenv-3.4.0 → multi_agent_rlenv-3.5.1}/src/marlenv/wrappers/available_actions_mask.py +6 -7
  20. {multi_agent_rlenv-3.4.0 → multi_agent_rlenv-3.5.1}/src/marlenv/wrappers/available_actions_wrapper.py +7 -9
  21. {multi_agent_rlenv-3.4.0 → multi_agent_rlenv-3.5.1}/src/marlenv/wrappers/blind_wrapper.py +5 -7
  22. {multi_agent_rlenv-3.4.0 → multi_agent_rlenv-3.5.1}/src/marlenv/wrappers/centralised.py +12 -14
  23. {multi_agent_rlenv-3.4.0 → multi_agent_rlenv-3.5.1}/src/marlenv/wrappers/delayed_rewards.py +13 -11
  24. {multi_agent_rlenv-3.4.0 → multi_agent_rlenv-3.5.1}/src/marlenv/wrappers/last_action_wrapper.py +10 -14
  25. {multi_agent_rlenv-3.4.0 → multi_agent_rlenv-3.5.1}/src/marlenv/wrappers/paddings.py +6 -8
  26. {multi_agent_rlenv-3.4.0 → multi_agent_rlenv-3.5.1}/src/marlenv/wrappers/penalty_wrapper.py +5 -8
  27. multi_agent_rlenv-3.5.1/src/marlenv/wrappers/potential_shaping.py +49 -0
  28. {multi_agent_rlenv-3.4.0 → multi_agent_rlenv-3.5.1}/src/marlenv/wrappers/rlenv_wrapper.py +12 -10
  29. {multi_agent_rlenv-3.4.0 → multi_agent_rlenv-3.5.1}/src/marlenv/wrappers/time_limit.py +3 -3
  30. {multi_agent_rlenv-3.4.0 → multi_agent_rlenv-3.5.1}/src/marlenv/wrappers/video_recorder.py +4 -6
  31. {multi_agent_rlenv-3.4.0 → multi_agent_rlenv-3.5.1}/tests/test_adapters.py +7 -7
  32. {multi_agent_rlenv-3.4.0 → multi_agent_rlenv-3.5.1}/tests/test_models.py +2 -2
  33. multi_agent_rlenv-3.5.1/tests/test_spaces.py +183 -0
  34. {multi_agent_rlenv-3.4.0 → multi_agent_rlenv-3.5.1}/tests/test_wrappers.py +35 -3
  35. {multi_agent_rlenv-3.4.0 → multi_agent_rlenv-3.5.1}/tests/utils.py +2 -2
  36. multi_agent_rlenv-3.4.0/tests/test_spaces.py +0 -134
  37. {multi_agent_rlenv-3.4.0 → multi_agent_rlenv-3.5.1}/.github/workflows/ci.yaml +0 -0
  38. {multi_agent_rlenv-3.4.0 → multi_agent_rlenv-3.5.1}/.github/workflows/docs.yaml +0 -0
  39. {multi_agent_rlenv-3.4.0 → multi_agent_rlenv-3.5.1}/.gitignore +0 -0
  40. {multi_agent_rlenv-3.4.0 → multi_agent_rlenv-3.5.1}/LICENSE +0 -0
  41. {multi_agent_rlenv-3.4.0 → multi_agent_rlenv-3.5.1}/README.md +0 -0
  42. {multi_agent_rlenv-3.4.0 → multi_agent_rlenv-3.5.1}/pyproject.toml +0 -0
  43. {multi_agent_rlenv-3.4.0 → multi_agent_rlenv-3.5.1}/src/marlenv/adapters/__init__.py +0 -0
  44. {multi_agent_rlenv-3.4.0 → multi_agent_rlenv-3.5.1}/src/marlenv/exceptions.py +0 -0
  45. {multi_agent_rlenv-3.4.0 → multi_agent_rlenv-3.5.1}/src/marlenv/models/observation.py +0 -0
  46. {multi_agent_rlenv-3.4.0 → multi_agent_rlenv-3.5.1}/src/marlenv/models/state.py +0 -0
  47. {multi_agent_rlenv-3.4.0 → multi_agent_rlenv-3.5.1}/src/marlenv/py.typed +0 -0
  48. {multi_agent_rlenv-3.4.0 → multi_agent_rlenv-3.5.1}/src/marlenv/utils/__init__.py +0 -0
  49. {multi_agent_rlenv-3.4.0 → multi_agent_rlenv-3.5.1}/src/marlenv/utils/schedule.py +0 -0
  50. {multi_agent_rlenv-3.4.0 → multi_agent_rlenv-3.5.1}/tests/__init__.py +0 -0
  51. {multi_agent_rlenv-3.4.0 → multi_agent_rlenv-3.5.1}/tests/test_episode.py +0 -0
  52. {multi_agent_rlenv-3.4.0 → multi_agent_rlenv-3.5.1}/tests/test_pool.py +0 -0
  53. {multi_agent_rlenv-3.4.0 → multi_agent_rlenv-3.5.1}/tests/test_schedules.py +0 -0
  54. {multi_agent_rlenv-3.4.0 → multi_agent_rlenv-3.5.1}/tests/test_serialization.py +0 -0
PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: multi-agent-rlenv
- Version: 3.4.0
+ Version: 3.5.1
  Summary: A strongly typed Multi-Agent Reinforcement Learning framework
  Project-URL: repository, https://github.com/yamoling/multi-agent-rlenv
  Author-email: Yannick Molinghen <yannick.molinghen@ulb.be>
src/marlenv/__init__.py
@@ -62,16 +62,11 @@ print(env.extras_shape) # (1, )
  If you want to create a new environment, you can simply create a class that inherits from `MARLEnv`. If you want to create a wrapper around an existing `MARLEnv`, you probably want to subclass `RLEnvWrapper` which implements a default behaviour for every method.
  """

- __version__ = "3.4.0"
+ __version__ = "3.5.1"

  from . import models
- from . import wrappers
- from . import adapters
- from .models import spaces
-
-
- from .env_builder import make, Builder
  from .models import (
+     spaces,
      MARLEnv,
      State,
      Step,
@@ -80,10 +75,14 @@ from .models import (
      Transition,
      DiscreteSpace,
      ContinuousSpace,
-     ActionSpace,
-     DiscreteActionSpace,
-     ContinuousActionSpace,
+     Space,
+     MultiDiscreteSpace,
  )
+
+
+ from . import wrappers
+ from . import adapters
+ from .env_builder import make, Builder
  from .wrappers import RLEnvWrapper
  from .mock_env import DiscreteMockEnv, DiscreteMOMockEnv

@@ -100,12 +99,11 @@ __all__ = [
      "Observation",
      "Episode",
      "Transition",
-     "ActionSpace",
      "DiscreteSpace",
      "ContinuousSpace",
-     "DiscreteActionSpace",
-     "ContinuousActionSpace",
      "DiscreteMockEnv",
      "DiscreteMOMockEnv",
      "RLEnvWrapper",
+     "Space",
+     "MultiDiscreteSpace",
  ]
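In short, the 3.5.1 top-level module stops exporting ActionSpace, DiscreteActionSpace and ContinuousActionSpace and exports Space and MultiDiscreteSpace instead (DiscreteSpace and ContinuousSpace were already public). A minimal import-migration sketch for downstream code, based only on the hunks above:

    # 3.4.0
    from marlenv import ActionSpace, DiscreteActionSpace, ContinuousActionSpace

    # 3.5.1 -- per-agent action spaces are now built from the plain space types
    from marlenv import Space, DiscreteSpace, ContinuousSpace, MultiDiscreteSpace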
src/marlenv/adapters/gym_adapter.py
@@ -1,26 +1,16 @@
  import sys
- import cv2
  from dataclasses import dataclass
- from typing import Sequence

+ import cv2
  import gymnasium as gym
  import numpy as np
- import numpy.typing as npt
  from gymnasium import Env, spaces

- from marlenv.models import (
-     ActionSpace,
-     ContinuousActionSpace,
-     DiscreteActionSpace,
-     MARLEnv,
-     Observation,
-     State,
-     Step,
- )
+ from marlenv import ContinuousSpace, DiscreteSpace, MARLEnv, Observation, Space, State, Step


  @dataclass
- class Gym(MARLEnv[Sequence | npt.NDArray, ActionSpace]):
+ class Gym(MARLEnv[Space]):
      """Wraps a gym envronment in an RLEnv"""

      def __init__(self, env: Env | str, **kwargs):
@@ -30,7 +20,7 @@ class Gym(MARLEnv[Sequence | npt.NDArray, ActionSpace]):
              raise NotImplementedError("Observation space must have a shape")
          match env.action_space:
              case spaces.Discrete() as s:
-                 space = DiscreteActionSpace(1, int(s.n))
+                 space = DiscreteSpace(int(s.n), labels=[f"Action {i}" for i in range(s.n)]).repeat(1)
              case spaces.Box() as s:
                  low = s.low.astype(np.float32)
                  high = s.high.astype(np.float32)
@@ -38,10 +28,10 @@ class Gym(MARLEnv[Sequence | npt.NDArray, ActionSpace]):
                      low = np.full(s.shape, s.low, dtype=np.float32)
                  if not isinstance(high, np.ndarray):
                      high = np.full(s.shape, s.high, dtype=np.float32)
-                 space = ContinuousActionSpace(1, low, high)
+                 space = ContinuousSpace(low, high, labels=[f"Action {i}" for i in range(s.shape[0])]).repeat(1)
              case other:
                  raise NotImplementedError(f"Action space {other} not supported")
-         super().__init__(space, env.observation_space.shape, (1,))
+         super().__init__(1, space, env.observation_space.shape, (1,))
          self._gym_env = env
          if self._gym_env.unwrapped.spec is not None:
              self.name = self._gym_env.unwrapped.spec.id
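The adapter shows the new construction idiom: instead of DiscreteActionSpace(n_agents, n_actions), a single-agent DiscreteSpace (or ContinuousSpace) is turned into the joint, per-agent space with .repeat(n_agents). A short sketch of the same pattern in isolation; that repeat() yields a MultiDiscreteSpace exposing a shape attribute is inferred from the other hunks in this diff, not stated here:

    from marlenv import DiscreteSpace

    n_agents, n_actions = 3, 5
    single = DiscreteSpace(n_actions, labels=[f"Action {i}" for i in range(n_actions)])
    joint = single.repeat(n_agents)  # presumably a MultiDiscreteSpace; env.py reads n_actions from its shape[-1]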
src/marlenv/adapters/overcooked_adapter.py
@@ -7,7 +7,7 @@ import cv2
  import numpy as np
  import numpy.typing as npt
  import pygame
- from marlenv.models import ContinuousSpace, DiscreteActionSpace, MARLEnv, Observation, State, Step
+ from marlenv.models import ContinuousSpace, DiscreteSpace, MARLEnv, Observation, State, Step, MultiDiscreteSpace
  from marlenv.utils import Schedule

  from overcooked_ai_py.mdp.overcooked_env import OvercookedEnv
@@ -16,7 +16,7 @@ from overcooked_ai_py.visualization.state_visualizer import StateVisualizer


  @dataclass
- class Overcooked(MARLEnv[Sequence[int] | npt.NDArray, DiscreteActionSpace]):
+ class Overcooked(MARLEnv[MultiDiscreteSpace]):
      horizon: int
      shaping_factor: Schedule

@@ -37,10 +37,9 @@ class Overcooked(MARLEnv[Sequence[int] | npt.NDArray, DiscreteActionSpace]):
          # -1 because we extract the "urgent" layer to the extras
          shape = (int(layers - 1), int(width), int(height))
          super().__init__(
-             action_space=DiscreteActionSpace(
-                 n_agents=self._mdp.num_players,
-                 n_actions=Action.NUM_ACTIONS,
-                 action_names=[Action.ACTION_TO_CHAR[a] for a in Action.ALL_ACTIONS],
+             n_agents=self._mdp.num_players,
+             action_space=DiscreteSpace(Action.NUM_ACTIONS, labels=[Action.ACTION_TO_CHAR[a] for a in Action.ALL_ACTIONS]).repeat(
+                 self._mdp.num_players
              ),
              observation_shape=shape,
              extras_shape=(2,),
@@ -95,7 +94,7 @@ class Overcooked(MARLEnv[Sequence[int] | npt.NDArray, DiscreteActionSpace]):
              available_actions[agent_num, Action.ACTION_TO_INDEX[action]] = True
          return np.array(available_actions, dtype=np.bool)

-     def step(self, actions: Sequence[int] | npt.NDArray[np.int32 | np.int64]) -> Step:
+     def step(self, actions: Sequence[int] | np.ndarray) -> Step:
          self.shaping_factor.update()
          actions = [Action.ALL_ACTIONS[a] for a in actions]
          _, reward, done, info = self._oenv.step(actions, display_phi=True)
src/marlenv/adapters/pettingzoo_adapter.py
@@ -6,17 +6,17 @@ import numpy.typing as npt
  from gymnasium import spaces # pettingzoo uses gymnasium spaces
  from pettingzoo import ParallelEnv

- from marlenv.models import ActionSpace, ContinuousActionSpace, DiscreteActionSpace, MARLEnv, Observation, State, Step
+ from marlenv.models import MARLEnv, Observation, State, Step, DiscreteSpace, ContinuousSpace, Space


  @dataclass
- class PettingZoo(MARLEnv[npt.NDArray, ActionSpace]):
+ class PettingZoo(MARLEnv[Space]):
      def __init__(self, env: ParallelEnv):
          aspace = env.action_space(env.possible_agents[0])
          n_agents = len(env.possible_agents)
          match aspace:
              case spaces.Discrete() as s:
-                 space = DiscreteActionSpace(n_agents, int(s.n))
+                 space = DiscreteSpace.action(int(s.n)).repeat(n_agents)

              case spaces.Box() as s:
                  low = s.low.astype(np.float32)
@@ -25,7 +25,7 @@ class PettingZoo(MARLEnv[npt.NDArray, ActionSpace]):
                      low = np.full(s.shape, s.low, dtype=np.float32)
                  if not isinstance(high, np.ndarray):
                      high = np.full(s.shape, s.high, dtype=np.float32)
-                 space = ContinuousActionSpace(n_agents, low, high=high)
+                 space = ContinuousSpace(low, high=high).repeat(n_agents)
              case other:
                  raise NotImplementedError(f"Action space {other} not supported")

@@ -34,7 +34,7 @@ class PettingZoo(MARLEnv[npt.NDArray, ActionSpace]):
              raise NotImplementedError("Only discrete observation spaces are supported")
          self._pz_env = env
          env.reset()
-         super().__init__(space, obs_space.shape, self.get_state().shape)
+         super().__init__(n_agents, space, obs_space.shape, self.get_state().shape)
          self.agents = env.possible_agents
          self.last_observation = None

src/marlenv/adapters/pymarl_adapter.py
@@ -1,10 +1,9 @@
  from dataclasses import dataclass
- from typing import Any, Sequence
+ from typing import Any

  import numpy as np
- import numpy.typing as npt

- from marlenv.models import DiscreteActionSpace, MARLEnv
+ from marlenv.models import MARLEnv, MultiDiscreteSpace
  from marlenv.wrappers import TimeLimit


@@ -15,7 +14,7 @@ class PymarlAdapter:
      with the pymarl-qplex code base.
      """

-     def __init__(self, env: MARLEnv[Sequence | npt.NDArray, DiscreteActionSpace], episode_limit: int):
+     def __init__(self, env: MARLEnv[MultiDiscreteSpace], episode_limit: int):
          assert env.reward_space.size == 1, "Only single objective environments are supported."
          self.env = TimeLimit(env, episode_limit, add_extra=False)
          # Required by PyMarl
src/marlenv/adapters/smac_adapter.py
@@ -1,15 +1,15 @@
  from dataclasses import dataclass
- from typing import Sequence, overload
+ from typing import overload

  import numpy as np
  import numpy.typing as npt
  from smac.env import StarCraft2Env

- from marlenv.models import DiscreteActionSpace, MARLEnv, Observation, State, Step
+ from marlenv.models import MARLEnv, Observation, State, Step, MultiDiscreteSpace, DiscreteSpace


  @dataclass
- class SMAC(MARLEnv[Sequence[int] | npt.NDArray, DiscreteActionSpace]):
+ class SMAC(MARLEnv[MultiDiscreteSpace]):
      """Wrapper for the SMAC environment to work with this framework"""

      @overload
@@ -157,10 +157,10 @@ class SMAC(MARLEnv[Sequence[int] | npt.NDArray, DiscreteActionSpace]):
              case other:
                  raise ValueError(f"Invalid argument type: {type(other)}")
          self._env = StarCraft2Env(map_name=map_name)
-         action_space = DiscreteActionSpace(self._env.n_agents, self._env.n_actions)
          self._env_info = self._env.get_env_info()
          super().__init__(
-             action_space=action_space,
+             self._env.n_agents,
+             action_space=DiscreteSpace(self._env.n_actions).repeat(self._env.n_agents),
              observation_shape=(self._env_info["obs_shape"],),
              state_shape=(self._env_info["state_shape"],),
          )
@@ -195,7 +195,7 @@ class SMAC(MARLEnv[Sequence[int] | npt.NDArray, DiscreteActionSpace]):
          )
          return step

-     def available_actions(self) -> npt.NDArray[np.bool_]:
+     def available_actions(self) -> npt.NDArray[np.bool]:
          return np.array(self._env.get_avail_actions()) == 1

      def get_image(self):
src/marlenv/env_builder.py
@@ -5,10 +5,9 @@ import numpy.typing as npt

  from . import wrappers
  from marlenv import adapters
- from .models import ActionSpace, MARLEnv
+ from .models import Space, MARLEnv

- A = TypeVar("A")
- AS = TypeVar("AS", bound=ActionSpace)
+ AS = TypeVar("AS", bound=Space)

  if adapters.HAS_PETTINGZOO:
      from .adapters import PettingZoo
@@ -71,12 +70,12 @@ def make(env, **kwargs):


  @dataclass
- class Builder(Generic[A, AS]):
+ class Builder(Generic[AS]):
      """Builder for environments"""

-     _env: MARLEnv[A, AS]
+     _env: MARLEnv[AS]

-     def __init__(self, env: MARLEnv[A, AS]):
+     def __init__(self, env: MARLEnv[AS]):
          self._env = env

      def time_limit(self, n_steps: int, add_extra: bool = True, truncation_penalty: Optional[float] = None):
@@ -124,9 +123,9 @@ class Builder(Generic[A, AS]):

      def centralised(self):
          """Centralises the observations and actions"""
-         from marlenv.models import DiscreteActionSpace
+         from marlenv.models import MultiDiscreteSpace

-         assert isinstance(self._env.action_space, DiscreteActionSpace)
+         assert isinstance(self._env.action_space, MultiDiscreteSpace)
          self._env = wrappers.Centralized(self._env) # type: ignore
          return self

@@ -159,6 +158,6 @@ class Builder(Generic[A, AS]):
          self._env = wrappers.TimePenalty(self._env, penalty)
          return self

-     def build(self) -> MARLEnv[A, AS]:
+     def build(self) -> MARLEnv[AS]:
          """Build and return the environment"""
          return self._env
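Builder keeps its fluent interface; only the raw-action type parameter disappears, and centralised() now asserts a MultiDiscreteSpace rather than a DiscreteActionSpace. A hedged usage sketch with the mock environment shipped in the package (method behaviour beyond the signatures visible above is assumed unchanged from 3.4.0):

    from marlenv import Builder, DiscreteMockEnv

    env = (
        Builder(DiscreteMockEnv(n_agents=2))
        .time_limit(50)   # wrap in a 50-step TimeLimit
        .build()          # returns MARLEnv[AS]; AS is now the only type parameter
    )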
src/marlenv/env_pool.py
@@ -1,21 +1,19 @@
  from typing import Sequence
  from dataclasses import dataclass
- import numpy.typing as npt
  from typing_extensions import TypeVar
  import random

  from marlenv import RLEnvWrapper, MARLEnv
- from marlenv.models import ActionSpace
+ from marlenv.models import Space

- ActionType = TypeVar("ActionType", default=npt.NDArray)
- ActionSpaceType = TypeVar("ActionSpaceType", bound=ActionSpace, default=ActionSpace)
+ ActionSpaceType = TypeVar("ActionSpaceType", bound=Space, default=Space)


  @dataclass
- class EnvPool(RLEnvWrapper[ActionType, ActionSpaceType]):
-     envs: Sequence[MARLEnv[ActionType, ActionSpaceType]]
+ class EnvPool(RLEnvWrapper[ActionSpaceType]):
+     envs: Sequence[MARLEnv[ActionSpaceType]]

-     def __init__(self, envs: Sequence[MARLEnv[ActionType, ActionSpaceType]]):
+     def __init__(self, envs: Sequence[MARLEnv[ActionSpaceType]]):
          assert len(envs) > 0, "EnvPool must contain at least one environment"
          self.envs = envs
          for env in envs[1:]:
src/marlenv/mock_env.py
@@ -1,12 +1,10 @@
- from typing import Sequence
  import numpy as np
- import numpy.typing as npt
  from dataclasses import dataclass
- from marlenv import MARLEnv, Observation, DiscreteActionSpace, ContinuousSpace, Step, State
+ from marlenv import MARLEnv, Observation, ContinuousSpace, Step, State, DiscreteSpace, MultiDiscreteSpace


  @dataclass
- class DiscreteMockEnv(MARLEnv[Sequence[int] | npt.NDArray, DiscreteActionSpace]):
+ class DiscreteMockEnv(MARLEnv[MultiDiscreteSpace]):
      def __init__(
          self,
          n_agents: int = 4,
@@ -27,7 +25,8 @@ class DiscreteMockEnv(MARLEnv[Sequence[int] | npt.NDArray, DiscreteActionSpace])
              case _:
                  raise ValueError("reward_step must be an int, float or np.ndarray")
          super().__init__(
-             DiscreteActionSpace(n_agents, n_actions),
+             n_agents,
+             DiscreteSpace(n_actions).repeat(n_agents),
              (obs_size,),
              (n_agents * agent_state_size,),
              extras_shape=(extras_size,),
@@ -85,7 +84,7 @@ class DiscreteMockEnv(MARLEnv[Sequence[int] | npt.NDArray, DiscreteActionSpace])
          )


- class DiscreteMOMockEnv(MARLEnv[Sequence[int] | npt.NDArray, DiscreteActionSpace]):
+ class DiscreteMOMockEnv(MARLEnv[DiscreteSpace]):
      """Multi-Objective Mock Environment"""

      def __init__(
@@ -100,7 +99,8 @@ class DiscreteMOMockEnv(MARLEnv[Sequence[int] | npt.NDArray, DiscreteActionSpace
          extras_size: int = 0,
      ) -> None:
          super().__init__(
-             DiscreteActionSpace(n_agents, n_actions),
+             n_agents,
+             DiscreteSpace(n_actions),
              (obs_size,),
              (n_agents * agent_state_size,),
              extras_shape=(extras_size,),
src/marlenv/models/__init__.py
@@ -1,4 +1,4 @@
- from .spaces import ActionSpace, DiscreteSpace, ContinuousSpace, MultiDiscreteSpace, DiscreteActionSpace, ContinuousActionSpace
+ from .spaces import DiscreteSpace, ContinuousSpace, MultiDiscreteSpace, Space
  from .observation import Observation
  from .step import Step
  from .state import State
@@ -8,7 +8,6 @@ from .episode import Episode


  __all__ = [
-     "ActionSpace",
      "Step",
      "State",
      "DiscreteSpace",
@@ -18,6 +17,5 @@ __all__ = [
      "Transition",
      "Episode",
      "MultiDiscreteSpace",
-     "DiscreteActionSpace",
-     "ContinuousActionSpace",
+     "Space",
  ]
src/marlenv/models/env.py
@@ -1,24 +1,22 @@
  from abc import ABC, abstractmethod
  from dataclasses import dataclass
  from itertools import product
- from typing import Any, Generic, Optional, Sequence
+ from typing import Generic, Optional, Sequence, TypeVar

  import cv2
  import numpy as np
  import numpy.typing as npt
- from typing_extensions import TypeVar

  from .observation import Observation
- from .spaces import ActionSpace, ContinuousSpace, Space
+ from .spaces import ContinuousSpace, Space, DiscreteSpace, MultiDiscreteSpace
  from .state import State
  from .step import Step

- ActionType = TypeVar("ActionType", default=Any)
- ActionSpaceType = TypeVar("ActionSpaceType", bound=ActionSpace, default=Any)
+ ActionSpaceType = TypeVar("ActionSpaceType", bound=Space)


  @dataclass
- class MARLEnv(ABC, Generic[ActionType, ActionSpaceType]):
+ class MARLEnv(ABC, Generic[ActionSpaceType]):
      """
      Multi-Agent Reinforcement Learning environment.

@@ -70,6 +68,7 @@ class MARLEnv(ABC, Generic[ActionType, ActionSpaceType]):

      def __init__(
          self,
+         n_agents: int,
          action_space: ActionSpaceType,
          observation_shape: tuple[int, ...],
          state_shape: tuple[int, ...],
@@ -81,8 +80,8 @@ class MARLEnv(ABC, Generic[ActionType, ActionSpaceType]):
          super().__init__()
          self.name = self.__class__.__name__
          self.action_space = action_space
-         self.n_actions = action_space.n_actions
-         self.n_agents = action_space.n_agents
+         self.n_actions = action_space.shape[-1]
+         self.n_agents = n_agents
          self.observation_shape = observation_shape
          self.state_shape = state_shape
          self.extras_shape = extras_shape
@@ -113,9 +112,16 @@ class MARLEnv(ABC, Generic[ActionType, ActionSpaceType]):
          """The number of objectives in the environment."""
          return self.reward_space.size

-     def sample_action(self) -> ActionType:
+     def sample_action(self):
          """Sample an available action from the action space."""
-         return self.action_space.sample(self.available_actions()) # type: ignore
+         match self.action_space:
+             case MultiDiscreteSpace() as aspace:
+                 return aspace.sample(mask=self.available_actions())
+             case ContinuousSpace() as aspace:
+                 return aspace.sample()
+             case DiscreteSpace() as aspace:
+                 return np.array([aspace.sample(mask=self.available_actions())])
+         raise NotImplementedError("Action space not supported")

      def available_actions(self) -> npt.NDArray[np.bool]:
          """
@@ -147,7 +153,7 @@ class MARLEnv(ABC, Generic[ActionType, ActionSpaceType]):
          raise NotImplementedError("Method not implemented")

      @abstractmethod
-     def step(self, actions: ActionType) -> Step:
+     def step(self, action: Sequence | np.ndarray) -> Step:
          """Perform a step in the environment.

          Returns a Step object that can be unpacked as a 6-tuple containing:
@@ -180,7 +186,7 @@ class MARLEnv(ABC, Generic[ActionType, ActionSpaceType]):
          """Retrieve an image of the environment"""
          raise NotImplementedError("No image available for this environment")

-     def replay(self, actions: Sequence[ActionType], seed: Optional[int] = None):
+     def replay(self, actions: Sequence, seed: Optional[int] = None):
          """Replay a sequence of actions."""
          from .episode import Episode # Avoid circular import

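For subclasses, the breaking change is concentrated in the constructor: MARLEnv is generic over the action space only, n_agents is passed explicitly as the first argument, and n_actions is derived from action_space.shape[-1]. A constructor-only sketch under those assumptions (TwoAgentEnv is a hypothetical class; step() and the other abstract members are omitted):

    from marlenv import MARLEnv, DiscreteSpace, MultiDiscreteSpace

    class TwoAgentEnv(MARLEnv[MultiDiscreteSpace]):
        def __init__(self):
            n_agents = 2
            super().__init__(
                n_agents,                           # 3.5.x: agent count is explicit
                DiscreteSpace(4).repeat(n_agents),  # joint action space; n_actions == shape[-1] == 4
                (10,),                              # observation_shape
                (20,),                              # state_shape
            )
        # step() now takes a Sequence or np.ndarray instead of a free ActionType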
src/marlenv/models/episode.py
@@ -1,6 +1,6 @@
  from dataclasses import dataclass
  from functools import cached_property
- from typing import Any, Callable, Generic, Optional, Sequence, TypeVar, overload
+ from typing import Any, Callable, Optional, Sequence, overload

  import numpy as np
  import numpy.typing as npt
@@ -14,11 +14,8 @@ from .env import MARLEnv
  from marlenv.exceptions import EnvironmentMismatchException, ReplayMismatchException


- A = TypeVar("A")
-
-
  @dataclass
- class Episode(Generic[A]):
+ class Episode:
      """Episode model made of observations, actions, rewards, ..."""

      all_observations: list[npt.NDArray[np.float32]]
@@ -55,7 +52,7 @@ class Episode(Generic[A]):
          )

      @staticmethod
-     def from_transitions(transitions: Sequence[Transition[A]]) -> "Episode":
+     def from_transitions(transitions: Sequence[Transition]) -> "Episode":
          """Create an episode from a list of transitions"""
          episode = Episode.new(transitions[0].obs, transitions[0].state)
          for transition in transitions:
@@ -214,11 +211,11 @@

      def replay(
          self,
-         env: MARLEnv[A, Any],
+         env: MARLEnv,
          seed: Optional[int] = None,
          *,
-         after_reset: Optional[Callable[[Observation, State, MARLEnv[A]], None]] = None,
-         after_step: Optional[Callable[[int, Step, MARLEnv[A]], None]] = None,
+         after_reset: Optional[Callable[[Observation, State, MARLEnv], None]] = None,
+         after_step: Optional[Callable[[int, Step, MARLEnv], None]] = None,
      ):
          """
          Replay the episode in the environment (i.e. perform the actions) and assert that the outcomes match.
@@ -243,12 +240,12 @@
                  raise ReplayMismatchException("observation", step.obs.data, self.next_obs[i], time_step=i)
              if not np.array_equal(step.state.data, self.next_states[i]):
                  raise ReplayMismatchException("state", step.state.data, self.next_states[i], time_step=i)
-             if not np.array_equal(step.reward, self.rewards[i]):
+             if not np.isclose(step.reward, self.rewards[i]):
                  raise ReplayMismatchException("reward", step.reward, self.rewards[i], time_step=i)
              if after_step is not None:
                  after_step(i, step, env)

-     def get_images(self, env: MARLEnv[A, Any], seed: Optional[int] = None) -> list[np.ndarray]:
+     def get_images(self, env: MARLEnv, seed: Optional[int] = None) -> list[np.ndarray]:
          images = []

          def collect_image(*_, **__):
@@ -257,7 +254,7 @@
          self.replay(env, seed, after_reset=collect_image, after_step=collect_image)
          return images

-     def render(self, env: MARLEnv[A, Any], seed: Optional[int] = None, fps: int = 5):
+     def render(self, env: MARLEnv, seed: Optional[int] = None, fps: int = 5):
          def render_callback(*_, **__):
              env.render()
              cv2.waitKey(1000 // fps)
@@ -288,10 +285,10 @@
          return returns

      @overload
-     def add(self, transition: Transition[A], /): ...
+     def add(self, transition: Transition, /): ...

      @overload
-     def add(self, step: Step, action: A, /): ...
+     def add(self, step: Step, action: np.ndarray, /): ...

      def add(self, *data):
          match data:
@@ -322,10 +319,10 @@

      def add_data(
          self,
-         next_obs,
-         next_state,
-         action: A,
-         reward: np.ndarray,
+         next_obs: Observation,
+         next_state: State,
+         action: np.ndarray,
+         reward: npt.NDArray[np.float32],
          others: dict[str, Any],
          done: bool,
          truncated: bool,
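Episode loses its Generic[A] parameter and now types actions as plain numpy arrays. A rollout sketch using the add(step, action) overload shown above; that reset() returns an (Observation, State) pair and that Episode.new(obs, state) is public are assumptions inferred from from_transitions() and the 3.4.x API, not shown in this diff:

    from marlenv import DiscreteMockEnv, Episode

    env = DiscreteMockEnv(n_agents=2)
    obs, state = env.reset()           # assumed return value; not shown in this diff
    episode = Episode.new(obs, state)  # as used inside from_transitions() above
    for _ in range(5):
        action = env.sample_action()   # np.ndarray, via the new sample_action() dispatch
        episode.add(env.step(action), action)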