multi-agent-rlenv 3.6.3__tar.gz → 3.7.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.1}/.github/workflows/ci.yaml +3 -5
- {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.1}/PKG-INFO +2 -2
- {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.1}/pyproject.toml +2 -2
- {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.1}/src/marlenv/__init__.py +2 -2
- {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.1}/src/marlenv/adapters/gym_adapter.py +3 -3
- {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.1}/src/marlenv/adapters/pettingzoo_adapter.py +14 -14
- {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.1}/src/marlenv/adapters/smac_adapter.py +10 -7
- {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.1}/src/marlenv/catalog/__init__.py +9 -6
- multi_agent_rlenv-3.7.1/src/marlenv/catalog/connectn/__init__.py +11 -0
- multi_agent_rlenv-3.7.1/src/marlenv/catalog/connectn/board.py +186 -0
- multi_agent_rlenv-3.7.1/src/marlenv/catalog/connectn/env.py +51 -0
- multi_agent_rlenv-3.7.1/src/marlenv/catalog/coordinated_grid.py +139 -0
- {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.1}/src/marlenv/catalog/deepsea.py +1 -1
- multi_agent_rlenv-3.7.1/src/marlenv/catalog/matrix_game.py +52 -0
- multi_agent_rlenv-3.7.1/src/marlenv/catalog/two_steps.py +93 -0
- {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.1}/src/marlenv/env_pool.py +3 -3
- {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.1}/src/marlenv/mock_env.py +2 -2
- {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.1}/src/marlenv/models/spaces.py +7 -7
- {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.1}/src/marlenv/utils/schedule.py +8 -10
- {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.1}/src/marlenv/wrappers/agent_id_wrapper.py +2 -2
- {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.1}/src/marlenv/wrappers/blind_wrapper.py +2 -2
- {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.1}/src/marlenv/wrappers/centralised.py +3 -3
- {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.1}/src/marlenv/wrappers/delayed_rewards.py +2 -2
- {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.1}/src/marlenv/wrappers/last_action_wrapper.py +4 -4
- {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.1}/src/marlenv/wrappers/paddings.py +4 -4
- {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.1}/src/marlenv/wrappers/potential_shaping.py +2 -2
- {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.1}/src/marlenv/wrappers/rlenv_wrapper.py +2 -2
- multi_agent_rlenv-3.7.1/src/marlenv/wrappers/state_counter.py +35 -0
- {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.1}/src/marlenv/wrappers/time_limit.py +2 -2
- {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.1}/src/marlenv/wrappers/video_recorder.py +2 -2
- {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.1}/tests/test_adapters.py +2 -3
- {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.1}/tests/test_models.py +7 -7
- {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.1}/tests/test_serialization.py +1 -1
- {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.1}/tests/test_wrappers.py +18 -18
- {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.1}/.github/workflows/docs.yaml +0 -0
- {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.1}/.gitignore +0 -0
- {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.1}/LICENSE +0 -0
- {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.1}/README.md +0 -0
- {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.1}/src/marlenv/adapters/__init__.py +0 -0
- {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.1}/src/marlenv/adapters/pymarl_adapter.py +0 -0
- {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.1}/src/marlenv/env_builder.py +0 -0
- {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.1}/src/marlenv/exceptions.py +0 -0
- {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.1}/src/marlenv/models/__init__.py +0 -0
- {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.1}/src/marlenv/models/env.py +0 -0
- {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.1}/src/marlenv/models/episode.py +0 -0
- {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.1}/src/marlenv/models/observation.py +0 -0
- {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.1}/src/marlenv/models/state.py +0 -0
- {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.1}/src/marlenv/models/step.py +0 -0
- {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.1}/src/marlenv/models/transition.py +0 -0
- {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.1}/src/marlenv/py.typed +0 -0
- {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.1}/src/marlenv/utils/__init__.py +0 -0
- {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.1}/src/marlenv/utils/cached_property_collector.py +0 -0
- {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.1}/src/marlenv/utils/import_placeholders.py +0 -0
- {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.1}/src/marlenv/wrappers/__init__.py +0 -0
- {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.1}/src/marlenv/wrappers/action_randomizer.py +0 -0
- {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.1}/src/marlenv/wrappers/available_actions_mask.py +0 -0
- {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.1}/src/marlenv/wrappers/available_actions_wrapper.py +0 -0
- {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.1}/src/marlenv/wrappers/penalty_wrapper.py +0 -0
- {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.1}/tests/__init__.py +0 -0
- {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.1}/tests/test_catalog.py +0 -0
- {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.1}/tests/test_deepsea.py +0 -0
- {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.1}/tests/test_episode.py +0 -0
- {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.1}/tests/test_others.py +0 -0
- {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.1}/tests/test_pool.py +0 -0
- {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.1}/tests/test_schedules.py +0 -0
- {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.1}/tests/test_spaces.py +0 -0
- {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.1}/tests/utils.py +0 -0
|
@@ -27,8 +27,6 @@ jobs:
|
|
|
27
27
|
- x86_64
|
|
28
28
|
- aarch64
|
|
29
29
|
python-version:
|
|
30
|
-
- '3.10'
|
|
31
|
-
- '3.11'
|
|
32
30
|
- '3.12'
|
|
33
31
|
- '3.13'
|
|
34
32
|
runs-on: ${{ matrix.os }}
|
|
@@ -43,7 +41,7 @@ jobs:
|
|
|
43
41
|
- name: Install uv
|
|
44
42
|
uses: yezz123/setup-uv@v4
|
|
45
43
|
with:
|
|
46
|
-
uv-version: 0.
|
|
44
|
+
uv-version: 0.9.24
|
|
47
45
|
- name: Install dependencies and run pytest
|
|
48
46
|
run: |
|
|
49
47
|
uv sync --extra overcooked --extra gym --extra pettingzoo --extra torch
|
|
@@ -59,11 +57,11 @@ jobs:
|
|
|
59
57
|
- name: Set up Python
|
|
60
58
|
uses: actions/setup-python@v5
|
|
61
59
|
with:
|
|
62
|
-
python-version: 3.
|
|
60
|
+
python-version: 3.13
|
|
63
61
|
- name: Install UV
|
|
64
62
|
uses: yezz123/setup-uv@v4
|
|
65
63
|
with:
|
|
66
|
-
uv-version: 0.
|
|
64
|
+
uv-version: 0.9.24
|
|
67
65
|
- name: Build wheels
|
|
68
66
|
run: |
|
|
69
67
|
uv venv
|
|
@@ -1,13 +1,13 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: multi-agent-rlenv
|
|
3
|
-
Version: 3.
|
|
3
|
+
Version: 3.7.1
|
|
4
4
|
Summary: A strongly typed Multi-Agent Reinforcement Learning framework
|
|
5
5
|
Project-URL: repository, https://github.com/yamoling/multi-agent-rlenv
|
|
6
6
|
Author-email: Yannick Molinghen <yannick.molinghen@ulb.be>
|
|
7
7
|
License-File: LICENSE
|
|
8
8
|
Classifier: Operating System :: OS Independent
|
|
9
9
|
Classifier: Programming Language :: Python :: 3
|
|
10
|
-
Requires-Python: <4,>=3.
|
|
10
|
+
Requires-Python: <4,>=3.12
|
|
11
11
|
Requires-Dist: numpy>=2.0.0
|
|
12
12
|
Requires-Dist: opencv-python>=4.0
|
|
13
13
|
Requires-Dist: typing-extensions>=4.0
|
|
@@ -1,12 +1,12 @@
|
|
|
1
1
|
[project]
|
|
2
2
|
name = "multi-agent-rlenv"
|
|
3
|
-
version = "3.
|
|
3
|
+
version = "3.7.1"
|
|
4
4
|
description = "A strongly typed Multi-Agent Reinforcement Learning framework"
|
|
5
5
|
authors = [
|
|
6
6
|
{ "name" = "Yannick Molinghen", "email" = "yannick.molinghen@ulb.be" },
|
|
7
7
|
]
|
|
8
8
|
readme = "README.md"
|
|
9
|
-
requires-python = ">=3.
|
|
9
|
+
requires-python = ">=3.12, <4"
|
|
10
10
|
urls = { "repository" = "https://github.com/yamoling/multi-agent-rlenv" }
|
|
11
11
|
classifiers = [
|
|
12
12
|
"Programming Language :: Python :: 3",
|
|
@@ -65,9 +65,9 @@ If you want to create a new environment, you can simply create a class that inhe
|
|
|
65
65
|
from importlib.metadata import version, PackageNotFoundError
|
|
66
66
|
|
|
67
67
|
try:
|
|
68
|
-
__version__ = version("
|
|
68
|
+
__version__ = version("multi-agent-rlenv")
|
|
69
69
|
except PackageNotFoundError:
|
|
70
|
-
__version__ = "0.0.0" # fallback
|
|
70
|
+
__version__ = "0.0.0" # fallback for CI
|
|
71
71
|
|
|
72
72
|
|
|
73
73
|
from . import models
|
|
@@ -44,8 +44,8 @@ class Gym(MARLEnv[Space]):
|
|
|
44
44
|
raise ValueError("No observation available. Call reset() first.")
|
|
45
45
|
return self._last_obs
|
|
46
46
|
|
|
47
|
-
def step(self,
|
|
48
|
-
obs, reward, done, truncated, info = self._gym_env.step(list(
|
|
47
|
+
def step(self, action):
|
|
48
|
+
obs, reward, done, truncated, info = self._gym_env.step(list(action)[0])
|
|
49
49
|
self._last_obs = Observation(
|
|
50
50
|
np.array([obs], dtype=np.float32),
|
|
51
51
|
self.available_actions(),
|
|
@@ -74,7 +74,7 @@ class Gym(MARLEnv[Space]):
|
|
|
74
74
|
image = np.array(self._gym_env.render())
|
|
75
75
|
if sys.platform in ("linux", "linux2"):
|
|
76
76
|
image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
|
|
77
|
-
return image
|
|
77
|
+
return np.array(image, dtype=np.uint8)
|
|
78
78
|
|
|
79
79
|
def seed(self, seed_value: int):
|
|
80
80
|
self._gym_env.reset(seed=seed_value)
|
{multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.1}/src/marlenv/adapters/pettingzoo_adapter.py
RENAMED
|
@@ -33,39 +33,39 @@ class PettingZoo(MARLEnv[Space]):
|
|
|
33
33
|
if obs_space.shape is None:
|
|
34
34
|
raise NotImplementedError("Only discrete observation spaces are supported")
|
|
35
35
|
self._pz_env = env
|
|
36
|
-
|
|
37
|
-
|
|
36
|
+
self.n_agents = n_agents
|
|
37
|
+
self.n_actions = space.shape[-1]
|
|
38
|
+
self.last_observation, state = self.reset()
|
|
39
|
+
super().__init__(n_agents, space, obs_space.shape, state.shape)
|
|
38
40
|
self.agents = env.possible_agents
|
|
39
|
-
self.last_observation = None
|
|
40
41
|
|
|
41
42
|
def get_state(self):
|
|
42
43
|
try:
|
|
43
|
-
return self._pz_env.state()
|
|
44
|
+
return State(self._pz_env.state())
|
|
44
45
|
except NotImplementedError:
|
|
45
|
-
|
|
46
|
+
assert self.last_observation is not None, "Cannot get the state unless there is a previous observation"
|
|
47
|
+
return State(self.last_observation.data)
|
|
46
48
|
|
|
47
|
-
def step(self,
|
|
48
|
-
action_dict = dict(zip(self.agents,
|
|
49
|
+
def step(self, action: npt.NDArray | Sequence):
|
|
50
|
+
action_dict = dict(zip(self.agents, action))
|
|
49
51
|
obs, reward, term, trunc, info = self._pz_env.step(action_dict)
|
|
50
52
|
obs_data = np.array([v for v in obs.values()])
|
|
51
53
|
reward = np.sum([r for r in reward.values()], keepdims=True)
|
|
52
54
|
self.last_observation = Observation(obs_data, self.available_actions())
|
|
53
|
-
state =
|
|
55
|
+
state = self.get_state()
|
|
54
56
|
return Step(self.last_observation, state, reward, any(term.values()), any(trunc.values()), info)
|
|
55
57
|
|
|
56
58
|
def reset(self):
|
|
57
59
|
obs = self._pz_env.reset()[0]
|
|
58
60
|
obs_data = np.array([v for v in obs.values()])
|
|
59
|
-
self.last_observation = Observation(obs_data, self.available_actions()
|
|
60
|
-
return self.last_observation
|
|
61
|
+
self.last_observation = Observation(obs_data, self.available_actions())
|
|
62
|
+
return self.last_observation, self.get_state()
|
|
61
63
|
|
|
62
64
|
def get_observation(self):
|
|
63
|
-
if self.last_observation is None:
|
|
64
|
-
raise ValueError("No observation available. Call reset() first.")
|
|
65
65
|
return self.last_observation
|
|
66
66
|
|
|
67
67
|
def seed(self, seed_value: int):
|
|
68
68
|
self._pz_env.reset(seed=seed_value)
|
|
69
69
|
|
|
70
|
-
def render(self
|
|
71
|
-
|
|
70
|
+
def render(self):
|
|
71
|
+
self._pz_env.render()
|
|
@@ -3,7 +3,7 @@ from typing import overload
|
|
|
3
3
|
|
|
4
4
|
import numpy as np
|
|
5
5
|
import numpy.typing as npt
|
|
6
|
-
from smac.env import StarCraft2Env
|
|
6
|
+
from smac.env import StarCraft2Env # pyright: ignore[reportMissingImports]
|
|
7
7
|
|
|
8
8
|
from marlenv.models import MARLEnv, Observation, State, Step, MultiDiscreteSpace, DiscreteSpace
|
|
9
9
|
|
|
@@ -169,17 +169,18 @@ class SMAC(MARLEnv[MultiDiscreteSpace]):
|
|
|
169
169
|
|
|
170
170
|
def reset(self):
|
|
171
171
|
obs, state = self._env.reset()
|
|
172
|
-
obs = Observation(np.array(obs), self.available_actions()
|
|
173
|
-
|
|
172
|
+
obs = Observation(np.array(obs), self.available_actions())
|
|
173
|
+
state = State(state)
|
|
174
|
+
return obs, state
|
|
174
175
|
|
|
175
176
|
def get_observation(self):
|
|
176
|
-
return self._env.get_obs()
|
|
177
|
+
return Observation(np.array(self._env.get_obs()), self.available_actions())
|
|
177
178
|
|
|
178
179
|
def get_state(self):
|
|
179
180
|
return State(self._env.get_state())
|
|
180
181
|
|
|
181
|
-
def step(self,
|
|
182
|
-
reward, done, info = self._env.step(
|
|
182
|
+
def step(self, action):
|
|
183
|
+
reward, done, info = self._env.step(action)
|
|
183
184
|
obs = Observation(
|
|
184
185
|
self._env.get_obs(), # type: ignore
|
|
185
186
|
self.available_actions(),
|
|
@@ -199,7 +200,9 @@ class SMAC(MARLEnv[MultiDiscreteSpace]):
|
|
|
199
200
|
return np.array(self._env.get_avail_actions()) == 1
|
|
200
201
|
|
|
201
202
|
def get_image(self):
|
|
202
|
-
|
|
203
|
+
img = self._env.render(mode="rgb_array")
|
|
204
|
+
assert img is not None
|
|
205
|
+
return img
|
|
203
206
|
|
|
204
207
|
def seed(self, seed_value: int):
|
|
205
208
|
self._env = StarCraft2Env(map_name=self._env.map_name, seed=seed_value)
|
|
@@ -1,13 +1,10 @@
|
|
|
1
1
|
from marlenv.adapters import SMAC
|
|
2
2
|
from .deepsea import DeepSea
|
|
3
|
+
from .matrix_game import MatrixGame
|
|
4
|
+
from .coordinated_grid import CoordinatedGrid
|
|
3
5
|
|
|
4
6
|
|
|
5
|
-
__all__ = [
|
|
6
|
-
"SMAC",
|
|
7
|
-
"DeepSea",
|
|
8
|
-
"lle",
|
|
9
|
-
"overcooked",
|
|
10
|
-
]
|
|
7
|
+
__all__ = ["SMAC", "DeepSea", "lle", "overcooked", "MatrixGame", "connect_n", "CoordinatedGrid"]
|
|
11
8
|
|
|
12
9
|
|
|
13
10
|
def lle():
|
|
@@ -20,3 +17,9 @@ def overcooked():
|
|
|
20
17
|
from overcooked import Overcooked # pyright: ignore[reportMissingImports]
|
|
21
18
|
|
|
22
19
|
return Overcooked
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def connect_n():
    """Return the ``ConnectN`` environment class, importing it only when requested."""
    # Deferred import: the connectn sub-package is loaded on first call
    # rather than when the catalog module itself is imported.
    from . import connectn as _connectn

    return _connectn.ConnectN
|
|
@@ -0,0 +1,186 @@
|
|
|
1
|
+
from enum import IntEnum
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class StepResult(IntEnum):
|
|
7
|
+
NOTHING = 0
|
|
8
|
+
TIE = 1
|
|
9
|
+
WIN = 2
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class GameBoard:
|
|
13
|
+
"""Connect4 game board class."""
|
|
14
|
+
|
|
15
|
+
def __init__(self, width: int, height: int, n: int):
|
|
16
|
+
assert width >= n or height >= height, "Impossible to win with this combination of width, height and n"
|
|
17
|
+
self.turn = 1
|
|
18
|
+
self.board = np.zeros(shape=(height, width), dtype=np.float32)
|
|
19
|
+
self.width = width
|
|
20
|
+
self.height = height
|
|
21
|
+
self.n_to_align = n
|
|
22
|
+
self.n_items_in_column = np.zeros(width, dtype=np.int32)
|
|
23
|
+
|
|
24
|
+
self.str_row = "+" + "-" * (self.width * 4 - 1) + "+"
|
|
25
|
+
self.numbers = "|" + " ".join([f" {i} " for i in range(self.width)]) + "|"
|
|
26
|
+
|
|
27
|
+
def valid_moves(self):
|
|
28
|
+
"""Get list of valid moves (i.e. not full columns)."""
|
|
29
|
+
return self.n_items_in_column < self.height
|
|
30
|
+
|
|
31
|
+
def clear(self):
|
|
32
|
+
self.board = np.zeros(shape=(self.height, self.width), dtype=np.float32)
|
|
33
|
+
self.n_items_in_column = np.zeros(self.width, dtype=np.int32)
|
|
34
|
+
self.turn = 0
|
|
35
|
+
|
|
36
|
+
def show(self):
|
|
37
|
+
"""Print out game board on console."""
|
|
38
|
+
print(self.str_row)
|
|
39
|
+
for j in range(self.height - 1, -1, -1):
|
|
40
|
+
for i in range(self.width):
|
|
41
|
+
match self.board[j, i]:
|
|
42
|
+
case 1:
|
|
43
|
+
print("| X", end=" ")
|
|
44
|
+
case -1:
|
|
45
|
+
print("| O", end=" ")
|
|
46
|
+
case _:
|
|
47
|
+
print("| ", end=" ")
|
|
48
|
+
print("|")
|
|
49
|
+
print(self.str_row)
|
|
50
|
+
print(self.numbers)
|
|
51
|
+
print(self.str_row)
|
|
52
|
+
|
|
53
|
+
def check_win(self, move_played: tuple[int, int]) -> bool:
|
|
54
|
+
if self.check_rows(move_played):
|
|
55
|
+
return True
|
|
56
|
+
if self.check_cols(move_played):
|
|
57
|
+
return True
|
|
58
|
+
if self.check_diags(move_played):
|
|
59
|
+
return True
|
|
60
|
+
return False
|
|
61
|
+
|
|
62
|
+
def check_tie(self) -> bool:
|
|
63
|
+
"""
|
|
64
|
+
Check whether the game is a tie (i.e. the board is full).
|
|
65
|
+
|
|
66
|
+
Note that it does not check for a win, so it should be called after check_win.
|
|
67
|
+
"""
|
|
68
|
+
# If the last row is full, the game is a tie
|
|
69
|
+
return bool(np.all(self.board[-1] != 0))
|
|
70
|
+
|
|
71
|
+
def check_rows(self, move_played: tuple[int, int]) -> bool:
|
|
72
|
+
row, col = move_played
|
|
73
|
+
start_index = max(0, col - self.n_to_align + 1)
|
|
74
|
+
end_index = min(self.width - self.n_to_align, col) + 1
|
|
75
|
+
for start in range(start_index, end_index):
|
|
76
|
+
slice = self.board[row, start : start + self.n_to_align]
|
|
77
|
+
if np.all(slice == self.turn):
|
|
78
|
+
return True
|
|
79
|
+
return False
|
|
80
|
+
|
|
81
|
+
def check_cols(self, move_played: tuple[int, int]) -> bool:
|
|
82
|
+
row, col = move_played
|
|
83
|
+
start_index = max(0, row - self.n_to_align + 1)
|
|
84
|
+
end_index = min(self.height - self.n_to_align, row) + 1
|
|
85
|
+
for start in range(start_index, end_index):
|
|
86
|
+
slice = self.board[start : start + self.n_to_align, col]
|
|
87
|
+
if np.all(slice == self.turn):
|
|
88
|
+
return True
|
|
89
|
+
return False
|
|
90
|
+
|
|
91
|
+
def check_diags(self, move_played: tuple[int, int]) -> bool:
|
|
92
|
+
row, col = move_played
|
|
93
|
+
# count the adjacent items in the / diagonal
|
|
94
|
+
n_adjacent = 0
|
|
95
|
+
# Top right
|
|
96
|
+
row_i, col_i = row + 1, col + 1
|
|
97
|
+
while row_i < self.height and col_i < self.width and self.board[row_i, col_i] == self.turn:
|
|
98
|
+
n_adjacent += 1
|
|
99
|
+
row_i += 1
|
|
100
|
+
col_i += 1
|
|
101
|
+
# Bottom left
|
|
102
|
+
row_i, col_i = row - 1, col - 1
|
|
103
|
+
while row_i >= 0 and col_i >= 0 and self.board[row_i, col_i] == self.turn:
|
|
104
|
+
n_adjacent += 1
|
|
105
|
+
row_i -= 1
|
|
106
|
+
col_i -= 1
|
|
107
|
+
if n_adjacent >= self.n_to_align - 1:
|
|
108
|
+
return True
|
|
109
|
+
|
|
110
|
+
# Count adjacent items in the \ diagonal
|
|
111
|
+
n_adjacent = 0
|
|
112
|
+
# Top left
|
|
113
|
+
row_i, col_i = row + 1, col - 1
|
|
114
|
+
while row_i < self.height and col_i >= 0 and self.board[row_i, col_i] == self.turn:
|
|
115
|
+
n_adjacent += 1
|
|
116
|
+
row_i += 1
|
|
117
|
+
col_i -= 1
|
|
118
|
+
# Bottom right
|
|
119
|
+
row_i, col_i = row - 1, col + 1
|
|
120
|
+
while row_i >= 0 and col_i < self.width and self.board[row_i, col_i] == self.turn:
|
|
121
|
+
n_adjacent += 1
|
|
122
|
+
row_i -= 1
|
|
123
|
+
col_i += 1
|
|
124
|
+
|
|
125
|
+
return n_adjacent >= self.n_to_align - 1
|
|
126
|
+
|
|
127
|
+
def play(self, column: int) -> StepResult:
|
|
128
|
+
"""Apply move to board.
|
|
129
|
+
|
|
130
|
+
Args:
|
|
131
|
+
column (int): Selected column index (between 0 and the number of cols - 1).
|
|
132
|
+
|
|
133
|
+
Returns:
|
|
134
|
+
bool: whether the player has won.
|
|
135
|
+
"""
|
|
136
|
+
row_index = self.n_items_in_column[column]
|
|
137
|
+
if row_index >= self.height:
|
|
138
|
+
raise ValueError(f"Column {column} is full, use `valid_moves` to check valid moves.")
|
|
139
|
+
self.n_items_in_column[column] += 1
|
|
140
|
+
self.board[row_index, column] = self.turn
|
|
141
|
+
if self.check_win((row_index, column)):
|
|
142
|
+
result = StepResult.WIN
|
|
143
|
+
elif self.check_tie():
|
|
144
|
+
result = StepResult.TIE
|
|
145
|
+
else:
|
|
146
|
+
result = StepResult.NOTHING
|
|
147
|
+
self.switch_turn()
|
|
148
|
+
return result
|
|
149
|
+
|
|
150
|
+
def switch_turn(self) -> None:
|
|
151
|
+
"""Switch turn between players."""
|
|
152
|
+
self.turn = -self.turn
|
|
153
|
+
|
|
154
|
+
|
|
155
|
+
def test_win():
    """Two adjacent items on a 4x1 board with n=2 win for the first player."""
    board = GameBoard(4, 1, 2)
    # Columns played in order (players alternate) and the result expected for each move.
    script = (
        (0, StepResult.NOTHING),
        (2, StepResult.NOTHING),
        (1, StepResult.WIN),
    )
    for column, expected in script:
        assert board.play(column) == expected
|
|
160
|
+
|
|
161
|
+
|
|
162
|
+
def test_tie():
    """Filling a 4x1 board with alternating players ends in a tie on the last move."""
    board = GameBoard(4, 1, 2)
    last_column = 3
    for column in range(last_column + 1):
        # Only the move that fills the board reports a tie.
        expected = StepResult.TIE if column == last_column else StepResult.NOTHING
        assert board.play(column) == expected
|
|
168
|
+
|
|
169
|
+
|
|
170
|
+
def test_win_diag():
    """Play two short games on a 2x2 board with n=2, clearing the board in between."""
    board = GameBoard(2, 2, 2)

    # First game: the third move lands at (row 1, col 1), diagonal to (row 0, col 0).
    first_game = (
        (0, StepResult.NOTHING),
        (1, StepResult.NOTHING),
        (1, StepResult.WIN),
    )
    for column, expected in first_game:
        assert board.play(column) == expected

    board.clear()

    # Second game, played on the cleared board.
    second_game = (
        (1, StepResult.NOTHING),
        (1, StepResult.NOTHING),
        (0, StepResult.WIN),
    )
    for column, expected in second_game:
        assert board.play(column) == expected
|
|
180
|
+
|
|
181
|
+
|
|
182
|
+
if __name__ == "__main__":
    # Run the module's self-checks in order when executed as a script.
    for check in (test_win, test_tie, test_win_diag):
        check()
    print("All tests passed!")
|
|
@@ -0,0 +1,51 @@
|
|
|
1
|
+
from typing import Sequence
|
|
2
|
+
import numpy as np
|
|
3
|
+
import numpy.typing as npt
|
|
4
|
+
from marlenv import MARLEnv, MultiDiscreteSpace, Step, State, Observation, DiscreteSpace
|
|
5
|
+
|
|
6
|
+
from .board import GameBoard, StepResult
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class ConnectN(MARLEnv[MultiDiscreteSpace]):
    """Connect-N wrapped as a `MARLEnv` with a single agent slot.

    Every call to `step` drops one item in the chosen column of the underlying
    `GameBoard`. Observations and states both expose a copy of the board grid;
    the state additionally carries the current turn as an extra.
    """

    def __init__(self, width: int = 7, height: int = 6, n: int = 4):
        self.board = GameBoard(width, height, n)
        # Observation and state share the (height, width) shape of the grid.
        grid_shape = (self.board.height, self.board.width)
        super().__init__(
            1,
            DiscreteSpace(self.board.width).repeat(1),
            grid_shape,
            grid_shape,
        )

    def reset(self):
        """Clear the board and return the initial (observation, state) pair."""
        self.board.clear()
        return self.get_observation(), self.get_state()

    def step(self, action: Sequence[int] | npt.NDArray[np.uint32]):
        """Play `action[0]` as a column index; reward is 1 on a win and 0 otherwise."""
        outcome = self.board.play(action[0])
        # Both a win and a tie end the episode; only a win is rewarded.
        done = outcome != StepResult.NOTHING
        reward = 1 if outcome == StepResult.WIN else 0
        return Step(self.get_observation(), self.get_state(), reward, done, False)

    def available_actions(self):
        """Full columns are not available."""
        # Add a leading axis of size 1 for the single agent.
        return self.board.valid_moves()[np.newaxis, ...]

    def get_observation(self):
        """Return a copy of the board along with the available-actions mask."""
        grid = self.board.board.copy()
        return Observation(grid, self.available_actions())

    def get_state(self):
        """Return a copy of the board with the current turn stored as an extra."""
        turn = np.array([self.board.turn])
        return State(self.board.board.copy(), turn)

    def set_state(self, state: State):
        """Restore the grid and the turn from `state` and recompute the column fill counts."""
        grid = state.data.copy()
        self.board.board = grid  # type: ignore Currently a type error because of the unchecked shape
        self.board.turn = int(state.extras[0])
        # Every non-empty cell of a column is one item stacked in it.
        self.board.n_items_in_column = np.count_nonzero(grid, axis=0)

    def render(self):
        """Print the board on the console."""
        self.board.show()
|
|
@@ -0,0 +1,139 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
import itertools
|
|
3
|
+
from marlenv import MARLEnv, DiscreteSpace, Observation, State, Step
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
N_ROWS = 11
|
|
7
|
+
N_COLS = 12
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class CoordinatedGrid(MARLEnv):
|
|
11
|
+
"""
|
|
12
|
+
Coordinated grid world environment used in the EMC paper to test the effectiveness of the proposed method.
|
|
13
|
+
https://proceedings.neurips.cc/paper_files/paper/2021/file/1e8ca836c962598551882e689265c1c5-Paper.pdf
|
|
14
|
+
"""
|
|
15
|
+
|
|
16
|
+
def __init__(
|
|
17
|
+
self,
|
|
18
|
+
episode_limit=30,
|
|
19
|
+
time_penalty=2,
|
|
20
|
+
):
|
|
21
|
+
super().__init__(
|
|
22
|
+
n_agents=2,
|
|
23
|
+
action_space=DiscreteSpace(5, ["SOUTH", "NORTH", "WEST", "EAST", "STAY"]).repeat(2),
|
|
24
|
+
observation_shape=(N_ROWS + N_COLS,),
|
|
25
|
+
state_shape=(N_ROWS + N_COLS,) * 2,
|
|
26
|
+
)
|
|
27
|
+
self._episode_steps = 0
|
|
28
|
+
self.episode_limit = episode_limit
|
|
29
|
+
self.center = N_COLS // 2
|
|
30
|
+
###larger gridworld
|
|
31
|
+
visible_row = [i for i in range(N_ROWS // 2 - 2, N_ROWS // 2 + 3)]
|
|
32
|
+
visible_col = [i for i in range(N_COLS // 2 - 3, N_COLS // 2 + 3)]
|
|
33
|
+
self.vision_index = [[i, j] for i, j in list(itertools.product(visible_row, visible_col))]
|
|
34
|
+
self.agents_location = [[0, 0], [N_ROWS - 1, N_COLS - 1]]
|
|
35
|
+
self.time_penalty = time_penalty
|
|
36
|
+
|
|
37
|
+
def reset(self):
|
|
38
|
+
self.agents_location = [[0, 0], [N_ROWS - 1, N_COLS - 1]]
|
|
39
|
+
self._episode_steps = 0
|
|
40
|
+
return self.get_observation(), self.get_state()
|
|
41
|
+
|
|
42
|
+
def get_observation(self):
|
|
43
|
+
obs_1 = [[0 for _ in range(N_ROWS)], [0 for _ in range(N_COLS)]]
|
|
44
|
+
# obs_2 = obs_1.copy()
|
|
45
|
+
import copy
|
|
46
|
+
|
|
47
|
+
obs_2 = copy.deepcopy(obs_1)
|
|
48
|
+
|
|
49
|
+
obs_1[0][self.agents_location[0][0]] = 1
|
|
50
|
+
obs_1[1][self.agents_location[0][1]] = 1
|
|
51
|
+
obs_1 = obs_1[0] + obs_1[1]
|
|
52
|
+
|
|
53
|
+
obs_2[0][self.agents_location[1][0]] = 1
|
|
54
|
+
obs_2[1][self.agents_location[1][1]] = 1
|
|
55
|
+
obs_2 = obs_2[0] + obs_2[1]
|
|
56
|
+
|
|
57
|
+
if self.agents_location[0] in self.vision_index and self.agents_location[1] in self.vision_index:
|
|
58
|
+
temp = obs_1.copy()
|
|
59
|
+
obs_1 += obs_2.copy()
|
|
60
|
+
obs_2 += temp.copy()
|
|
61
|
+
elif self.agents_location[0] in self.vision_index:
|
|
62
|
+
obs_2 += obs_1.copy()
|
|
63
|
+
obs_1 += [0 for _ in range(N_ROWS + N_COLS)]
|
|
64
|
+
elif self.agents_location[1] in self.vision_index:
|
|
65
|
+
obs_1 += obs_2.copy()
|
|
66
|
+
obs_2 += [0 for _ in range(N_ROWS + N_COLS)]
|
|
67
|
+
else:
|
|
68
|
+
obs_2 += [0 for _ in range(N_ROWS + N_COLS)]
|
|
69
|
+
obs_1 += [0 for _ in range(N_ROWS + N_COLS)]
|
|
70
|
+
|
|
71
|
+
obs_data = np.array([obs_1, obs_2])
|
|
72
|
+
return Observation(obs_data, self.available_actions())
|
|
73
|
+
|
|
74
|
+
def get_state(self):
|
|
75
|
+
obs = self.get_observation()
|
|
76
|
+
state_data = obs.data.reshape(-1)
|
|
77
|
+
return State(state_data)
|
|
78
|
+
|
|
79
|
+
def available_actions(self):
|
|
80
|
+
avail_actions = np.full((self.n_agents, self.n_actions), True)
|
|
81
|
+
for agent_num, (y, x) in enumerate(self.agents_location):
|
|
82
|
+
if x == 0:
|
|
83
|
+
avail_actions[agent_num, 0] = 0
|
|
84
|
+
elif x == N_ROWS - 1:
|
|
85
|
+
avail_actions[agent_num, 1] = 0
|
|
86
|
+
if y == 0:
|
|
87
|
+
avail_actions[agent_num, 2] = 0
|
|
88
|
+
# Check for center line (depends on the agent number)
|
|
89
|
+
elif y == self.center + agent_num - 1:
|
|
90
|
+
avail_actions[agent_num, 3] = 0
|
|
91
|
+
return avail_actions
|
|
92
|
+
|
|
93
|
+
def step(self, action):
|
|
94
|
+
for idx, action in enumerate(action):
|
|
95
|
+
match action:
|
|
96
|
+
case 0:
|
|
97
|
+
self.agents_location[idx][0] -= 1
|
|
98
|
+
case 1:
|
|
99
|
+
self.agents_location[idx][0] += 1
|
|
100
|
+
case 2:
|
|
101
|
+
self.agents_location[idx][1] -= 1
|
|
102
|
+
case 3:
|
|
103
|
+
self.agents_location[idx][1] += 1
|
|
104
|
+
case 4:
|
|
105
|
+
pass
|
|
106
|
+
case _:
|
|
107
|
+
raise ValueError(f"Invalid action {action} for agent {idx}!")
|
|
108
|
+
|
|
109
|
+
self._episode_steps += 1
|
|
110
|
+
terminated = self._episode_steps >= self.episode_limit
|
|
111
|
+
env_info = {"battle_won": False}
|
|
112
|
+
n_arrived = self.n_agents_arrived()
|
|
113
|
+
if n_arrived == 1:
|
|
114
|
+
reward = -self.time_penalty
|
|
115
|
+
elif n_arrived == 2:
|
|
116
|
+
reward = 10
|
|
117
|
+
terminated = True
|
|
118
|
+
env_info = {"battle_won": True}
|
|
119
|
+
else:
|
|
120
|
+
reward = 0
|
|
121
|
+
return Step(self.get_observation(), self.get_state(), reward, terminated, terminated, env_info)
|
|
122
|
+
|
|
123
|
+
def n_agents_arrived(self):
|
|
124
|
+
n = 0
|
|
125
|
+
if self.agents_location[0] == [N_ROWS // 2, self.center - 1]:
|
|
126
|
+
n += 1
|
|
127
|
+
if self.agents_location[1] == [N_ROWS // 2, self.center]:
|
|
128
|
+
n += 1
|
|
129
|
+
return n
|
|
130
|
+
|
|
131
|
+
def render(self):
|
|
132
|
+
print("Agents location: ", self.agents_location)
|
|
133
|
+
for row in range(N_ROWS):
|
|
134
|
+
for col in range(N_COLS):
|
|
135
|
+
if [row, col] in self.agents_location:
|
|
136
|
+
print("X", end=" ")
|
|
137
|
+
else:
|
|
138
|
+
print(".", end=" ")
|
|
139
|
+
print()
|
|
@@ -45,7 +45,7 @@ class DeepSea(MARLEnv[MultiDiscreteSpace]):
|
|
|
45
45
|
self._col = 0
|
|
46
46
|
return self.get_observation(), self.get_state()
|
|
47
47
|
|
|
48
|
-
def step(self, action: Sequence[int]):
|
|
48
|
+
def step(self, action: Sequence[int] | np.ndarray):
|
|
49
49
|
self._row += 1
|
|
50
50
|
if action[0] == LEFT:
|
|
51
51
|
self._col -= 1
|
|
@@ -0,0 +1,52 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
from marlenv import MARLEnv, Observation, DiscreteSpace, State, Step
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
class MatrixGame(MARLEnv):
    """Single step matrix game used in QTRAN, Qatten and QPLEX papers.

    Two agents each pick one action; the reward is read from the payoff matrix
    at `[first agent's action][second agent's action]` and the episode ends.
    """

    N_AGENTS = 2
    UNIT_DIM = 1
    OBS_SHAPE = (1,)
    STATE_SIZE = UNIT_DIM * N_AGENTS

    QPLEX_PAYOFF_MATRIX = [
        [8.0, -12.0, -12.0],
        [-12.0, 0.0, 0.0],
        [-12.0, 0.0, 0.0],
    ]

    def __init__(self, payoff_matrix: list[list[float]]):
        # The number of actions is the number of columns; actions are labelled "A", "B", "C", ...
        n_actions = len(payoff_matrix[0])
        labels = [chr(ord("A") + index) for index in range(n_actions)]
        super().__init__(
            2,
            action_space=DiscreteSpace(n_actions, labels).repeat(2),
            observation_shape=MatrixGame.OBS_SHAPE,
            state_shape=(MatrixGame.STATE_SIZE,),
        )
        self.current_step = 0
        self.payoffs = payoff_matrix

    def reset(self):
        """Reset the step counter and return the initial (observation, state)."""
        self.current_step = 0
        return self.get_observation(), self.get_state()

    def get_observation(self):
        """Return an observation in which every agent only sees the current step index."""
        step_index = np.full((MatrixGame.N_AGENTS, 1), self.current_step, dtype=np.float32)
        return Observation(step_index, self.available_actions())

    def step(self, action):
        """Play the joint action and end the episode with the matching payoff as reward."""
        chosen = list(action)
        self.current_step += 1
        observation = self.get_observation()
        state = self.get_state()
        payoff = self.payoffs[chosen[0]][chosen[1]]
        return Step(observation, state, payoff, True)

    def render(self):
        """Nothing to render."""
        return None

    def get_state(self):
        """Return a constant all-zero state of size STATE_SIZE."""
        return State(np.zeros((MatrixGame.STATE_SIZE,), np.float32))

    def seed(self, seed_value):
        """No-op: the seed value is ignored."""
        return None
|