multi-agent-rlenv 3.6.3__tar.gz → 3.7.1__tar.gz

This diff shows the changes between two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only.
Files changed (67)
  1. {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.1}/.github/workflows/ci.yaml +3 -5
  2. {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.1}/PKG-INFO +2 -2
  3. {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.1}/pyproject.toml +2 -2
  4. {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.1}/src/marlenv/__init__.py +2 -2
  5. {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.1}/src/marlenv/adapters/gym_adapter.py +3 -3
  6. {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.1}/src/marlenv/adapters/pettingzoo_adapter.py +14 -14
  7. {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.1}/src/marlenv/adapters/smac_adapter.py +10 -7
  8. {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.1}/src/marlenv/catalog/__init__.py +9 -6
  9. multi_agent_rlenv-3.7.1/src/marlenv/catalog/connectn/__init__.py +11 -0
  10. multi_agent_rlenv-3.7.1/src/marlenv/catalog/connectn/board.py +186 -0
  11. multi_agent_rlenv-3.7.1/src/marlenv/catalog/connectn/env.py +51 -0
  12. multi_agent_rlenv-3.7.1/src/marlenv/catalog/coordinated_grid.py +139 -0
  13. {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.1}/src/marlenv/catalog/deepsea.py +1 -1
  14. multi_agent_rlenv-3.7.1/src/marlenv/catalog/matrix_game.py +52 -0
  15. multi_agent_rlenv-3.7.1/src/marlenv/catalog/two_steps.py +93 -0
  16. {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.1}/src/marlenv/env_pool.py +3 -3
  17. {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.1}/src/marlenv/mock_env.py +2 -2
  18. {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.1}/src/marlenv/models/spaces.py +7 -7
  19. {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.1}/src/marlenv/utils/schedule.py +8 -10
  20. {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.1}/src/marlenv/wrappers/agent_id_wrapper.py +2 -2
  21. {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.1}/src/marlenv/wrappers/blind_wrapper.py +2 -2
  22. {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.1}/src/marlenv/wrappers/centralised.py +3 -3
  23. {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.1}/src/marlenv/wrappers/delayed_rewards.py +2 -2
  24. {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.1}/src/marlenv/wrappers/last_action_wrapper.py +4 -4
  25. {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.1}/src/marlenv/wrappers/paddings.py +4 -4
  26. {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.1}/src/marlenv/wrappers/potential_shaping.py +2 -2
  27. {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.1}/src/marlenv/wrappers/rlenv_wrapper.py +2 -2
  28. multi_agent_rlenv-3.7.1/src/marlenv/wrappers/state_counter.py +35 -0
  29. {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.1}/src/marlenv/wrappers/time_limit.py +2 -2
  30. {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.1}/src/marlenv/wrappers/video_recorder.py +2 -2
  31. {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.1}/tests/test_adapters.py +2 -3
  32. {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.1}/tests/test_models.py +7 -7
  33. {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.1}/tests/test_serialization.py +1 -1
  34. {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.1}/tests/test_wrappers.py +18 -18
  35. {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.1}/.github/workflows/docs.yaml +0 -0
  36. {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.1}/.gitignore +0 -0
  37. {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.1}/LICENSE +0 -0
  38. {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.1}/README.md +0 -0
  39. {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.1}/src/marlenv/adapters/__init__.py +0 -0
  40. {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.1}/src/marlenv/adapters/pymarl_adapter.py +0 -0
  41. {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.1}/src/marlenv/env_builder.py +0 -0
  42. {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.1}/src/marlenv/exceptions.py +0 -0
  43. {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.1}/src/marlenv/models/__init__.py +0 -0
  44. {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.1}/src/marlenv/models/env.py +0 -0
  45. {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.1}/src/marlenv/models/episode.py +0 -0
  46. {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.1}/src/marlenv/models/observation.py +0 -0
  47. {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.1}/src/marlenv/models/state.py +0 -0
  48. {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.1}/src/marlenv/models/step.py +0 -0
  49. {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.1}/src/marlenv/models/transition.py +0 -0
  50. {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.1}/src/marlenv/py.typed +0 -0
  51. {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.1}/src/marlenv/utils/__init__.py +0 -0
  52. {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.1}/src/marlenv/utils/cached_property_collector.py +0 -0
  53. {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.1}/src/marlenv/utils/import_placeholders.py +0 -0
  54. {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.1}/src/marlenv/wrappers/__init__.py +0 -0
  55. {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.1}/src/marlenv/wrappers/action_randomizer.py +0 -0
  56. {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.1}/src/marlenv/wrappers/available_actions_mask.py +0 -0
  57. {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.1}/src/marlenv/wrappers/available_actions_wrapper.py +0 -0
  58. {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.1}/src/marlenv/wrappers/penalty_wrapper.py +0 -0
  59. {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.1}/tests/__init__.py +0 -0
  60. {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.1}/tests/test_catalog.py +0 -0
  61. {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.1}/tests/test_deepsea.py +0 -0
  62. {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.1}/tests/test_episode.py +0 -0
  63. {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.1}/tests/test_others.py +0 -0
  64. {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.1}/tests/test_pool.py +0 -0
  65. {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.1}/tests/test_schedules.py +0 -0
  66. {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.1}/tests/test_spaces.py +0 -0
  67. {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.1}/tests/utils.py +0 -0
.github/workflows/ci.yaml

@@ -27,8 +27,6 @@ jobs:
  - x86_64
  - aarch64
  python-version:
- - '3.10'
- - '3.11'
  - '3.12'
  - '3.13'
  runs-on: ${{ matrix.os }}
@@ -43,7 +41,7 @@ jobs:
  - name: Install uv
  uses: yezz123/setup-uv@v4
  with:
- uv-version: 0.6.4
+ uv-version: 0.9.24
  - name: Install dependencies and run pytest
  run: |
  uv sync --extra overcooked --extra gym --extra pettingzoo --extra torch
@@ -59,11 +57,11 @@ jobs:
  - name: Set up Python
  uses: actions/setup-python@v5
  with:
- python-version: 3.12
+ python-version: 3.13
  - name: Install UV
  uses: yezz123/setup-uv@v4
  with:
- uv-version: 0.6.4
+ uv-version: 0.9.24
  - name: Build wheels
  run: |
  uv venv
PKG-INFO

@@ -1,13 +1,13 @@
 Metadata-Version: 2.4
 Name: multi-agent-rlenv
-Version: 3.6.3
+Version: 3.7.1
 Summary: A strongly typed Multi-Agent Reinforcement Learning framework
 Project-URL: repository, https://github.com/yamoling/multi-agent-rlenv
 Author-email: Yannick Molinghen <yannick.molinghen@ulb.be>
 License-File: LICENSE
 Classifier: Operating System :: OS Independent
 Classifier: Programming Language :: Python :: 3
-Requires-Python: <4,>=3.10
+Requires-Python: <4,>=3.12
 Requires-Dist: numpy>=2.0.0
 Requires-Dist: opencv-python>=4.0
 Requires-Dist: typing-extensions>=4.0
pyproject.toml

@@ -1,12 +1,12 @@
 [project]
 name = "multi-agent-rlenv"
-version = "3.6.3"
+version = "3.7.1"
 description = "A strongly typed Multi-Agent Reinforcement Learning framework"
 authors = [
     { "name" = "Yannick Molinghen", "email" = "yannick.molinghen@ulb.be" },
 ]
 readme = "README.md"
-requires-python = ">=3.10, <4"
+requires-python = ">=3.12, <4"
 urls = { "repository" = "https://github.com/yamoling/multi-agent-rlenv" }
 classifiers = [
     "Programming Language :: Python :: 3",
src/marlenv/__init__.py

@@ -65,9 +65,9 @@ If you want to create a new environment, you can simply create a class that inhe
 from importlib.metadata import version, PackageNotFoundError

 try:
-    __version__ = version("overcooked")
+    __version__ = version("multi-agent-rlenv")
 except PackageNotFoundError:
-    __version__ = "0.0.0"  # fallback pratique en dev/CI
+    __version__ = "0.0.0"  # fallback for CI


 from . import models
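With the distribution name fixed, the fallback only triggers when package metadata is genuinely missing (for example when running from a plain source checkout). A quick check, purely illustrative:

    import marlenv

    # Prints "3.7.1" for an installed release; "0.0.0" only when no
    # distribution metadata is available (bare checkout, stripped CI image).
    print(marlenv.__version__)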
src/marlenv/adapters/gym_adapter.py

@@ -44,8 +44,8 @@ class Gym(MARLEnv[Space]):
             raise ValueError("No observation available. Call reset() first.")
         return self._last_obs

-    def step(self, actions):
-        obs, reward, done, truncated, info = self._gym_env.step(list(actions)[0])
+    def step(self, action):
+        obs, reward, done, truncated, info = self._gym_env.step(list(action)[0])
         self._last_obs = Observation(
             np.array([obs], dtype=np.float32),
             self.available_actions(),
@@ -74,7 +74,7 @@ class Gym(MARLEnv[Space]):
         image = np.array(self._gym_env.render())
         if sys.platform in ("linux", "linux2"):
             image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
-        return image
+        return np.array(image, dtype=np.uint8)

     def seed(self, seed_value: int):
         self._gym_env.reset(seed=seed_value)
src/marlenv/adapters/pettingzoo_adapter.py

@@ -33,39 +33,39 @@ class PettingZoo(MARLEnv[Space]):
         if obs_space.shape is None:
             raise NotImplementedError("Only discrete observation spaces are supported")
         self._pz_env = env
-        env.reset()
-        super().__init__(n_agents, space, obs_space.shape, self.get_state().shape)
+        self.n_agents = n_agents
+        self.n_actions = space.shape[-1]
+        self.last_observation, state = self.reset()
+        super().__init__(n_agents, space, obs_space.shape, state.shape)
         self.agents = env.possible_agents
-        self.last_observation = None

     def get_state(self):
         try:
-            return self._pz_env.state()
+            return State(self._pz_env.state())
         except NotImplementedError:
-            return np.array([0])
+            assert self.last_observation is not None, "Cannot get the state unless there is a previous observation"
+            return State(self.last_observation.data)

-    def step(self, actions: npt.NDArray | Sequence):
-        action_dict = dict(zip(self.agents, actions))
+    def step(self, action: npt.NDArray | Sequence):
+        action_dict = dict(zip(self.agents, action))
         obs, reward, term, trunc, info = self._pz_env.step(action_dict)
         obs_data = np.array([v for v in obs.values()])
         reward = np.sum([r for r in reward.values()], keepdims=True)
         self.last_observation = Observation(obs_data, self.available_actions())
-        state = State(self.get_state())
+        state = self.get_state()
         return Step(self.last_observation, state, reward, any(term.values()), any(trunc.values()), info)

     def reset(self):
         obs = self._pz_env.reset()[0]
         obs_data = np.array([v for v in obs.values()])
-        self.last_observation = Observation(obs_data, self.available_actions(), self.get_state())
-        return self.last_observation
+        self.last_observation = Observation(obs_data, self.available_actions())
+        return self.last_observation, self.get_state()

     def get_observation(self):
-        if self.last_observation is None:
-            raise ValueError("No observation available. Call reset() first.")
         return self.last_observation

     def seed(self, seed_value: int):
         self._pz_env.reset(seed=seed_value)

-    def render(self, *_):
-        return self._pz_env.render()
+    def render(self):
+        self._pz_env.render()
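Together with the SMAC and DeepSea hunks below, this moves the adapter to the 3.7 convention where reset() returns an (Observation, State) pair and the per-step state travels inside the returned Step. A consumption sketch; the constructor call and the all-zero joint action are illustrative assumptions, not code from the package:

    from marlenv.adapters import PettingZoo

    env = PettingZoo(parallel_env)            # assumed: wrap an already-constructed PettingZoo parallel env
    obs, state = env.reset()                  # 3.7.x: reset() returns both the observation and the state
    step = env.step([0] * len(env.agents))    # one action per possible agent; returns a Step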
src/marlenv/adapters/smac_adapter.py

@@ -3,7 +3,7 @@ from typing import overload

 import numpy as np
 import numpy.typing as npt
-from smac.env import StarCraft2Env
+from smac.env import StarCraft2Env  # pyright: ignore[reportMissingImports]

 from marlenv.models import MARLEnv, Observation, State, Step, MultiDiscreteSpace, DiscreteSpace

@@ -169,17 +169,18 @@ class SMAC(MARLEnv[MultiDiscreteSpace]):

     def reset(self):
         obs, state = self._env.reset()
-        obs = Observation(np.array(obs), self.available_actions(), state)
-        return obs
+        obs = Observation(np.array(obs), self.available_actions())
+        state = State(state)
+        return obs, state

     def get_observation(self):
-        return self._env.get_obs()
+        return Observation(np.array(self._env.get_obs()), self.available_actions())

     def get_state(self):
         return State(self._env.get_state())

-    def step(self, actions):
-        reward, done, info = self._env.step(actions)
+    def step(self, action):
+        reward, done, info = self._env.step(action)
         obs = Observation(
             self._env.get_obs(),  # type: ignore
             self.available_actions(),
@@ -199,7 +200,9 @@ class SMAC(MARLEnv[MultiDiscreteSpace]):
         return np.array(self._env.get_avail_actions()) == 1

     def get_image(self):
-        return self._env.render(mode="rgb_array")
+        img = self._env.render(mode="rgb_array")
+        assert img is not None
+        return img

     def seed(self, seed_value: int):
         self._env = StarCraft2Env(map_name=self._env.map_name, seed=seed_value)
src/marlenv/catalog/__init__.py

@@ -1,13 +1,10 @@
 from marlenv.adapters import SMAC
 from .deepsea import DeepSea
+from .matrix_game import MatrixGame
+from .coordinated_grid import CoordinatedGrid


-__all__ = [
-    "SMAC",
-    "DeepSea",
-    "lle",
-    "overcooked",
-]
+__all__ = ["SMAC", "DeepSea", "lle", "overcooked", "MatrixGame", "connect_n", "CoordinatedGrid"]


 def lle():
@@ -20,3 +17,9 @@ def overcooked():
     from overcooked import Overcooked  # pyright: ignore[reportMissingImports]

     return Overcooked
+
+
+def connect_n():
+    from .connectn import ConnectN
+
+    return ConnectN
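The catalog keeps optional environments behind lazy-import helpers (lle, overcooked and now connect_n), so importing marlenv stays cheap and an optional dependency is only loaded when actually requested. A minimal sketch of that pattern, assuming the subpackage is reached as marlenv.catalog:

    from marlenv import catalog

    # The ConnectN class (and anything it depends on) is only imported when the helper is called.
    ConnectN = catalog.connect_n()
    env = ConnectN()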
src/marlenv/catalog/connectn/__init__.py (new file)

@@ -0,0 +1,11 @@
+"""
+Connect-N game environment.
+
+Inspiration from: https://github.com/Gualor/connect4-montecarlo
+"""
+
+from .board import GameBoard
+from .env import ConnectN
+
+
+__all__ = ["ConnectN", "GameBoard"]
src/marlenv/catalog/connectn/board.py (new file)

@@ -0,0 +1,186 @@
+from enum import IntEnum
+
+import numpy as np
+
+
+class StepResult(IntEnum):
+    NOTHING = 0
+    TIE = 1
+    WIN = 2
+
+
+class GameBoard:
+    """Connect4 game board class."""
+
+    def __init__(self, width: int, height: int, n: int):
+        assert width >= n or height >= height, "Impossible to win with this combination of width, height and n"
+        self.turn = 1
+        self.board = np.zeros(shape=(height, width), dtype=np.float32)
+        self.width = width
+        self.height = height
+        self.n_to_align = n
+        self.n_items_in_column = np.zeros(width, dtype=np.int32)
+
+        self.str_row = "+" + "-" * (self.width * 4 - 1) + "+"
+        self.numbers = "|" + " ".join([f" {i} " for i in range(self.width)]) + "|"
+
+    def valid_moves(self):
+        """Get list of valid moves (i.e. not full columns)."""
+        return self.n_items_in_column < self.height
+
+    def clear(self):
+        self.board = np.zeros(shape=(self.height, self.width), dtype=np.float32)
+        self.n_items_in_column = np.zeros(self.width, dtype=np.int32)
+        self.turn = 0
+
+    def show(self):
+        """Print out game board on console."""
+        print(self.str_row)
+        for j in range(self.height - 1, -1, -1):
+            for i in range(self.width):
+                match self.board[j, i]:
+                    case 1:
+                        print("| X", end=" ")
+                    case -1:
+                        print("| O", end=" ")
+                    case _:
+                        print("| ", end=" ")
+            print("|")
+        print(self.str_row)
+        print(self.numbers)
+        print(self.str_row)
+
+    def check_win(self, move_played: tuple[int, int]) -> bool:
+        if self.check_rows(move_played):
+            return True
+        if self.check_cols(move_played):
+            return True
+        if self.check_diags(move_played):
+            return True
+        return False
+
+    def check_tie(self) -> bool:
+        """
+        Check whether the game is a tie (i.e. the board is full).
+
+        Note that it does not check for a win, so it should be called after check_win.
+        """
+        # If the last row is full, the game is a tie
+        return bool(np.all(self.board[-1] != 0))
+
+    def check_rows(self, move_played: tuple[int, int]) -> bool:
+        row, col = move_played
+        start_index = max(0, col - self.n_to_align + 1)
+        end_index = min(self.width - self.n_to_align, col) + 1
+        for start in range(start_index, end_index):
+            slice = self.board[row, start : start + self.n_to_align]
+            if np.all(slice == self.turn):
+                return True
+        return False
+
+    def check_cols(self, move_played: tuple[int, int]) -> bool:
+        row, col = move_played
+        start_index = max(0, row - self.n_to_align + 1)
+        end_index = min(self.height - self.n_to_align, row) + 1
+        for start in range(start_index, end_index):
+            slice = self.board[start : start + self.n_to_align, col]
+            if np.all(slice == self.turn):
+                return True
+        return False
+
+    def check_diags(self, move_played: tuple[int, int]) -> bool:
+        row, col = move_played
+        # count the adjacent items in the / diagonal
+        n_adjacent = 0
+        # Top right
+        row_i, col_i = row + 1, col + 1
+        while row_i < self.height and col_i < self.width and self.board[row_i, col_i] == self.turn:
+            n_adjacent += 1
+            row_i += 1
+            col_i += 1
+        # Bottom left
+        row_i, col_i = row - 1, col - 1
+        while row_i >= 0 and col_i >= 0 and self.board[row_i, col_i] == self.turn:
+            n_adjacent += 1
+            row_i -= 1
+            col_i -= 1
+        if n_adjacent >= self.n_to_align - 1:
+            return True
+
+        # Count adjacent items in the \ diagonal
+        n_adjacent = 0
+        # Top left
+        row_i, col_i = row + 1, col - 1
+        while row_i < self.height and col_i >= 0 and self.board[row_i, col_i] == self.turn:
+            n_adjacent += 1
+            row_i += 1
+            col_i -= 1
+        # Bottom right
+        row_i, col_i = row - 1, col + 1
+        while row_i >= 0 and col_i < self.width and self.board[row_i, col_i] == self.turn:
+            n_adjacent += 1
+            row_i -= 1
+            col_i += 1
+
+        return n_adjacent >= self.n_to_align - 1
+
+    def play(self, column: int) -> StepResult:
+        """Apply move to board.
+
+        Args:
+            column (int): Selected column index (between 0 and the number of cols - 1).
+
+        Returns:
+            bool: whether the player has won.
+        """
+        row_index = self.n_items_in_column[column]
+        if row_index >= self.height:
+            raise ValueError(f"Column {column} is full, use `valid_moves` to check valid moves.")
+        self.n_items_in_column[column] += 1
+        self.board[row_index, column] = self.turn
+        if self.check_win((row_index, column)):
+            result = StepResult.WIN
+        elif self.check_tie():
+            result = StepResult.TIE
+        else:
+            result = StepResult.NOTHING
+        self.switch_turn()
+        return result
+
+    def switch_turn(self) -> None:
+        """Switch turn between players."""
+        self.turn = -self.turn
+
+
+def test_win():
+    board = GameBoard(4, 1, 2)
+    assert board.play(0) == StepResult.NOTHING
+    assert board.play(2) == StepResult.NOTHING
+    assert board.play(1) == StepResult.WIN
+
+
+def test_tie():
+    board = GameBoard(4, 1, 2)
+    assert board.play(0) == StepResult.NOTHING
+    assert board.play(1) == StepResult.NOTHING
+    assert board.play(2) == StepResult.NOTHING
+    assert board.play(3) == StepResult.TIE
+
+
+def test_win_diag():
+    board = GameBoard(2, 2, 2)
+    assert board.play(0) == StepResult.NOTHING
+    assert board.play(1) == StepResult.NOTHING
+    assert board.play(1) == StepResult.WIN
+
+    board.clear()
+    assert board.play(1) == StepResult.NOTHING
+    assert board.play(1) == StepResult.NOTHING
+    assert board.play(0) == StepResult.WIN
+
+
+if __name__ == "__main__":
+    test_win()
+    test_tie()
+    test_win_diag()
+    print("All tests passed!")
src/marlenv/catalog/connectn/env.py (new file)

@@ -0,0 +1,51 @@
+from typing import Sequence
+import numpy as np
+import numpy.typing as npt
+from marlenv import MARLEnv, MultiDiscreteSpace, Step, State, Observation, DiscreteSpace
+
+from .board import GameBoard, StepResult
+
+
+class ConnectN(MARLEnv[MultiDiscreteSpace]):
+    def __init__(self, width: int = 7, height: int = 6, n: int = 4):
+        self.board = GameBoard(width, height, n)
+        action_space = DiscreteSpace(self.board.width).repeat(1)
+        observation_shape = (self.board.height, self.board.width)
+        state_shape = observation_shape
+        super().__init__(1, action_space, observation_shape, state_shape)
+
+    def reset(self):
+        self.board.clear()
+        return self.get_observation(), self.get_state()
+
+    def step(self, action: Sequence[int] | npt.NDArray[np.uint32]):
+        match self.board.play(action[0]):
+            case StepResult.NOTHING:
+                done = False
+                reward = 0
+            case StepResult.WIN:
+                done = True
+                reward = 1
+            case StepResult.TIE:
+                done = True
+                reward = 0
+        return Step(self.get_observation(), self.get_state(), reward, done, False)
+
+    def available_actions(self):
+        """Full columns are not available."""
+        return np.expand_dims(self.board.valid_moves(), axis=0)
+
+    def get_observation(self):
+        return Observation(self.board.board.copy(), self.available_actions())
+
+    def get_state(self):
+        return State(self.board.board.copy(), np.array([self.board.turn]))
+
+    def set_state(self, state: State):
+        self.board.board = state.data.copy()  # type: ignore Currently a type error because of the unchecked shape
+        self.board.turn = int(state.extras[0])
+        n_completed = np.count_nonzero(self.board.board, axis=0)
+        self.board.n_items_in_column = n_completed
+
+    def render(self):
+        self.board.show()
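Putting the two new files together, a short usage sketch for the single-agent ConnectN environment; the column-picking policy is illustrative only:

    import numpy as np
    from marlenv.catalog.connectn import ConnectN

    env = ConnectN(width=7, height=6, n=4)   # defaults from __init__ above
    obs, state = env.reset()
    mask = env.available_actions()           # shape (1, width); full columns are masked out
    col = int(np.flatnonzero(mask[0])[0])    # play the first legal column (illustrative policy)
    step = env.step([col])                   # single agent, so one action wrapped in a sequence
    env.render()                             # ASCII board via GameBoard.show()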
src/marlenv/catalog/coordinated_grid.py (new file)

@@ -0,0 +1,139 @@
+import numpy as np
+import itertools
+from marlenv import MARLEnv, DiscreteSpace, Observation, State, Step
+
+
+N_ROWS = 11
+N_COLS = 12
+
+
+class CoordinatedGrid(MARLEnv):
+    """
+    Coordinated grid world environment used in the EMC paper to test the effectiveness of the proposed method.
+    https://proceedings.neurips.cc/paper_files/paper/2021/file/1e8ca836c962598551882e689265c1c5-Paper.pdf
+    """
+
+    def __init__(
+        self,
+        episode_limit=30,
+        time_penalty=2,
+    ):
+        super().__init__(
+            n_agents=2,
+            action_space=DiscreteSpace(5, ["SOUTH", "NORTH", "WEST", "EAST", "STAY"]).repeat(2),
+            observation_shape=(N_ROWS + N_COLS,),
+            state_shape=(N_ROWS + N_COLS,) * 2,
+        )
+        self._episode_steps = 0
+        self.episode_limit = episode_limit
+        self.center = N_COLS // 2
+        ###larger gridworld
+        visible_row = [i for i in range(N_ROWS // 2 - 2, N_ROWS // 2 + 3)]
+        visible_col = [i for i in range(N_COLS // 2 - 3, N_COLS // 2 + 3)]
+        self.vision_index = [[i, j] for i, j in list(itertools.product(visible_row, visible_col))]
+        self.agents_location = [[0, 0], [N_ROWS - 1, N_COLS - 1]]
+        self.time_penalty = time_penalty
+
+    def reset(self):
+        self.agents_location = [[0, 0], [N_ROWS - 1, N_COLS - 1]]
+        self._episode_steps = 0
+        return self.get_observation(), self.get_state()
+
+    def get_observation(self):
+        obs_1 = [[0 for _ in range(N_ROWS)], [0 for _ in range(N_COLS)]]
+        # obs_2 = obs_1.copy()
+        import copy
+
+        obs_2 = copy.deepcopy(obs_1)
+
+        obs_1[0][self.agents_location[0][0]] = 1
+        obs_1[1][self.agents_location[0][1]] = 1
+        obs_1 = obs_1[0] + obs_1[1]
+
+        obs_2[0][self.agents_location[1][0]] = 1
+        obs_2[1][self.agents_location[1][1]] = 1
+        obs_2 = obs_2[0] + obs_2[1]
+
+        if self.agents_location[0] in self.vision_index and self.agents_location[1] in self.vision_index:
+            temp = obs_1.copy()
+            obs_1 += obs_2.copy()
+            obs_2 += temp.copy()
+        elif self.agents_location[0] in self.vision_index:
+            obs_2 += obs_1.copy()
+            obs_1 += [0 for _ in range(N_ROWS + N_COLS)]
+        elif self.agents_location[1] in self.vision_index:
+            obs_1 += obs_2.copy()
+            obs_2 += [0 for _ in range(N_ROWS + N_COLS)]
+        else:
+            obs_2 += [0 for _ in range(N_ROWS + N_COLS)]
+            obs_1 += [0 for _ in range(N_ROWS + N_COLS)]
+
+        obs_data = np.array([obs_1, obs_2])
+        return Observation(obs_data, self.available_actions())
+
+    def get_state(self):
+        obs = self.get_observation()
+        state_data = obs.data.reshape(-1)
+        return State(state_data)
+
+    def available_actions(self):
+        avail_actions = np.full((self.n_agents, self.n_actions), True)
+        for agent_num, (y, x) in enumerate(self.agents_location):
+            if x == 0:
+                avail_actions[agent_num, 0] = 0
+            elif x == N_ROWS - 1:
+                avail_actions[agent_num, 1] = 0
+            if y == 0:
+                avail_actions[agent_num, 2] = 0
+            # Check for center line (depends on the agent number)
+            elif y == self.center + agent_num - 1:
+                avail_actions[agent_num, 3] = 0
+        return avail_actions
+
+    def step(self, action):
+        for idx, action in enumerate(action):
+            match action:
+                case 0:
+                    self.agents_location[idx][0] -= 1
+                case 1:
+                    self.agents_location[idx][0] += 1
+                case 2:
+                    self.agents_location[idx][1] -= 1
+                case 3:
+                    self.agents_location[idx][1] += 1
+                case 4:
+                    pass
+                case _:
+                    raise ValueError(f"Invalid action {action} for agent {idx}!")
+
+        self._episode_steps += 1
+        terminated = self._episode_steps >= self.episode_limit
+        env_info = {"battle_won": False}
+        n_arrived = self.n_agents_arrived()
+        if n_arrived == 1:
+            reward = -self.time_penalty
+        elif n_arrived == 2:
+            reward = 10
+            terminated = True
+            env_info = {"battle_won": True}
+        else:
+            reward = 0
+        return Step(self.get_observation(), self.get_state(), reward, terminated, terminated, env_info)
+
+    def n_agents_arrived(self):
+        n = 0
+        if self.agents_location[0] == [N_ROWS // 2, self.center - 1]:
+            n += 1
+        if self.agents_location[1] == [N_ROWS // 2, self.center]:
+            n += 1
+        return n
+
+    def render(self):
+        print("Agents location: ", self.agents_location)
+        for row in range(N_ROWS):
+            for col in range(N_COLS):
+                if [row, col] in self.agents_location:
+                    print("X", end=" ")
+                else:
+                    print(".", end=" ")
+            print()
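As a rough usage sketch (the chosen joint action is illustrative; the indices follow the DiscreteSpace action names above):

    from marlenv.catalog import CoordinatedGrid

    env = CoordinatedGrid(episode_limit=30, time_penalty=2)
    obs, state = env.reset()
    # 0=SOUTH, 1=NORTH, 2=WEST, 3=EAST, 4=STAY; both agents move one column toward the centre.
    step = env.step([3, 2])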
src/marlenv/catalog/deepsea.py

@@ -45,7 +45,7 @@ class DeepSea(MARLEnv[MultiDiscreteSpace]):
         self._col = 0
         return self.get_observation(), self.get_state()

-    def step(self, action: Sequence[int]):
+    def step(self, action: Sequence[int] | np.ndarray):
         self._row += 1
         if action[0] == LEFT:
             self._col -= 1
src/marlenv/catalog/matrix_game.py (new file)

@@ -0,0 +1,52 @@
+import numpy as np
+from marlenv import MARLEnv, Observation, DiscreteSpace, State, Step
+
+
+class MatrixGame(MARLEnv):
+    """Single step matrix game used in QTRAN, Qatten and QPLEX papers."""
+
+    N_AGENTS = 2
+    UNIT_DIM = 1
+    OBS_SHAPE = (1,)
+    STATE_SIZE = UNIT_DIM * N_AGENTS
+
+    QPLEX_PAYOFF_MATRIX = [
+        [8.0, -12.0, -12.0],
+        [-12.0, 0.0, 0.0],
+        [-12.0, 0.0, 0.0],
+    ]
+
+    def __init__(self, payoff_matrix: list[list[float]]):
+        action_names = [chr(ord("A") + i) for i in range(len(payoff_matrix[0]))]
+        super().__init__(
+            2,
+            action_space=DiscreteSpace(len(payoff_matrix[0]), action_names).repeat(2),
+            observation_shape=MatrixGame.OBS_SHAPE,
+            state_shape=(MatrixGame.STATE_SIZE,),
+        )
+        self.current_step = 0
+        self.payoffs = payoff_matrix
+
+    def reset(self):
+        self.current_step = 0
+        return self.get_observation(), self.get_state()
+
+    def get_observation(self):
+        return Observation(
+            np.array([[self.current_step]] * MatrixGame.N_AGENTS, np.float32),
+            self.available_actions(),
+        )
+
+    def step(self, action):
+        action = list(action)
+        self.current_step += 1
+        return Step(self.get_observation(), self.get_state(), self.payoffs[action[0]][action[1]], True)
+
+    def render(self):
+        return
+
+    def get_state(self):
+        return State(np.zeros((MatrixGame.STATE_SIZE,), np.float32))
+
+    def seed(self, seed_value):
+        return
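A usage sketch for the matrix game, using the QPLEX payoff matrix that ships with the class (the chosen joint action is illustrative):

    from marlenv.catalog import MatrixGame

    env = MatrixGame(MatrixGame.QPLEX_PAYOFF_MATRIX)
    obs, state = env.reset()
    step = env.step([0, 0])   # both agents pick action "A": payoff 8.0, and the episode terminates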