multi-agent-rlenv 3.6.3.tar.gz → 3.7.0.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/.github/workflows/ci.yaml +3 -5
- {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/PKG-INFO +2 -2
- {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/pyproject.toml +2 -2
- {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/src/marlenv/__init__.py +2 -2
- {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/src/marlenv/adapters/gym_adapter.py +3 -3
- {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/src/marlenv/adapters/pettingzoo_adapter.py +14 -14
- {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/src/marlenv/adapters/smac_adapter.py +10 -7
- {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/src/marlenv/catalog/deepsea.py +1 -1
- multi_agent_rlenv-3.7.0/src/marlenv/catalog/two_steps.py +93 -0
- {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/src/marlenv/env_pool.py +3 -3
- {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/src/marlenv/mock_env.py +2 -2
- {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/src/marlenv/models/spaces.py +7 -7
- {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/src/marlenv/utils/schedule.py +8 -10
- {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/src/marlenv/wrappers/agent_id_wrapper.py +2 -2
- {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/src/marlenv/wrappers/blind_wrapper.py +2 -2
- {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/src/marlenv/wrappers/centralised.py +3 -3
- {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/src/marlenv/wrappers/delayed_rewards.py +2 -2
- {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/src/marlenv/wrappers/last_action_wrapper.py +4 -4
- {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/src/marlenv/wrappers/paddings.py +4 -4
- {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/src/marlenv/wrappers/potential_shaping.py +2 -2
- {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/src/marlenv/wrappers/rlenv_wrapper.py +2 -2
- {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/src/marlenv/wrappers/time_limit.py +2 -2
- {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/src/marlenv/wrappers/video_recorder.py +2 -2
- {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/tests/test_adapters.py +2 -3
- {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/tests/test_models.py +7 -7
- {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/tests/test_serialization.py +1 -1
- {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/tests/test_wrappers.py +18 -18
- {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/.github/workflows/docs.yaml +0 -0
- {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/.gitignore +0 -0
- {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/LICENSE +0 -0
- {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/README.md +0 -0
- {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/src/marlenv/adapters/__init__.py +0 -0
- {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/src/marlenv/adapters/pymarl_adapter.py +0 -0
- {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/src/marlenv/catalog/__init__.py +0 -0
- {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/src/marlenv/env_builder.py +0 -0
- {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/src/marlenv/exceptions.py +0 -0
- {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/src/marlenv/models/__init__.py +0 -0
- {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/src/marlenv/models/env.py +0 -0
- {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/src/marlenv/models/episode.py +0 -0
- {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/src/marlenv/models/observation.py +0 -0
- {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/src/marlenv/models/state.py +0 -0
- {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/src/marlenv/models/step.py +0 -0
- {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/src/marlenv/models/transition.py +0 -0
- {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/src/marlenv/py.typed +0 -0
- {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/src/marlenv/utils/__init__.py +0 -0
- {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/src/marlenv/utils/cached_property_collector.py +0 -0
- {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/src/marlenv/utils/import_placeholders.py +0 -0
- {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/src/marlenv/wrappers/__init__.py +0 -0
- {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/src/marlenv/wrappers/action_randomizer.py +0 -0
- {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/src/marlenv/wrappers/available_actions_mask.py +0 -0
- {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/src/marlenv/wrappers/available_actions_wrapper.py +0 -0
- {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/src/marlenv/wrappers/penalty_wrapper.py +0 -0
- {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/tests/__init__.py +0 -0
- {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/tests/test_catalog.py +0 -0
- {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/tests/test_deepsea.py +0 -0
- {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/tests/test_episode.py +0 -0
- {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/tests/test_others.py +0 -0
- {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/tests/test_pool.py +0 -0
- {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/tests/test_schedules.py +0 -0
- {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/tests/test_spaces.py +0 -0
- {multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/tests/utils.py +0 -0
{multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/.github/workflows/ci.yaml

@@ -27,8 +27,6 @@ jobs:
           - x86_64
           - aarch64
         python-version:
-          - '3.10'
-          - '3.11'
           - '3.12'
           - '3.13'
     runs-on: ${{ matrix.os }}

@@ -43,7 +41,7 @@ jobs:
       - name: Install uv
         uses: yezz123/setup-uv@v4
         with:
-          uv-version: 0.
+          uv-version: 0.9.24
       - name: Install dependencies and run pytest
         run: |
           uv sync --extra overcooked --extra gym --extra pettingzoo --extra torch

@@ -59,11 +57,11 @@ jobs:
       - name: Set up Python
         uses: actions/setup-python@v5
         with:
-          python-version: 3.
+          python-version: 3.13
      - name: Install UV
        uses: yezz123/setup-uv@v4
        with:
-          uv-version: 0.
+          uv-version: 0.9.24
      - name: Build wheels
        run: |
          uv venv
{multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/PKG-INFO

@@ -1,13 +1,13 @@
 Metadata-Version: 2.4
 Name: multi-agent-rlenv
-Version: 3.6.3
+Version: 3.7.0
 Summary: A strongly typed Multi-Agent Reinforcement Learning framework
 Project-URL: repository, https://github.com/yamoling/multi-agent-rlenv
 Author-email: Yannick Molinghen <yannick.molinghen@ulb.be>
 License-File: LICENSE
 Classifier: Operating System :: OS Independent
 Classifier: Programming Language :: Python :: 3
-Requires-Python: <4,>=3.
+Requires-Python: <4,>=3.12
 Requires-Dist: numpy>=2.0.0
 Requires-Dist: opencv-python>=4.0
 Requires-Dist: typing-extensions>=4.0
{multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/pyproject.toml

@@ -1,12 +1,12 @@
 [project]
 name = "multi-agent-rlenv"
-version = "3.6.3"
+version = "3.7.0"
 description = "A strongly typed Multi-Agent Reinforcement Learning framework"
 authors = [
     { "name" = "Yannick Molinghen", "email" = "yannick.molinghen@ulb.be" },
 ]
 readme = "README.md"
-requires-python = ">=3.
+requires-python = ">=3.12, <4"
 urls = { "repository" = "https://github.com/yamoling/multi-agent-rlenv" }
 classifiers = [
     "Programming Language :: Python :: 3",
{multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/src/marlenv/__init__.py

@@ -65,9 +65,9 @@ If you want to create a new environment, you can simply create a class that inhe
 from importlib.metadata import version, PackageNotFoundError

 try:
-    __version__ = version("
+    __version__ = version("multi-agent-rlenv")
 except PackageNotFoundError:
-    __version__ = "0.0.0"  # fallback
+    __version__ = "0.0.0"  # fallback for CI


 from . import models
{multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/src/marlenv/adapters/gym_adapter.py

@@ -44,8 +44,8 @@ class Gym(MARLEnv[Space]):
             raise ValueError("No observation available. Call reset() first.")
         return self._last_obs

-    def step(self,
-        obs, reward, done, truncated, info = self._gym_env.step(list(
+    def step(self, action):
+        obs, reward, done, truncated, info = self._gym_env.step(list(action)[0])
         self._last_obs = Observation(
             np.array([obs], dtype=np.float32),
             self.available_actions(),

@@ -74,7 +74,7 @@ class Gym(MARLEnv[Space]):
         image = np.array(self._gym_env.render())
         if sys.platform in ("linux", "linux2"):
             image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
-        return image
+        return np.array(image, dtype=np.uint8)

     def seed(self, seed_value: int):
         self._gym_env.reset(seed=seed_value)
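In 3.7.0 the Gym adapter's step() takes the joint action explicitly and the render-to-array helper always returns a uint8 frame. A minimal usage sketch, with several assumptions that are not visible in this diff: that the adapter is exported as marlenv.adapters.Gym (mirroring the SMAC import in the tests), that it wraps a gymnasium environment instance, and that reset() follows the (Observation, State) convention used by the other adapters in this release; "CartPole-v1" is only an illustration.

import gymnasium as gym
import numpy as np
from marlenv.adapters import Gym  # assumed export path

env = Gym(gym.make("CartPole-v1", render_mode="rgb_array"))  # assumed constructor: a gymnasium env instance
obs, state = env.reset()          # assumed (Observation, State) return, as in the other adapters
step = env.step(np.array([0]))    # the adapter unwraps the single agent's action via list(action)[0]
frame = env.get_image()           # presumably the method shown above; the frame is now forced to np.uint8
print(step.reward, step.done, frame.dtype)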
{multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/src/marlenv/adapters/pettingzoo_adapter.py
RENAMED

@@ -33,39 +33,39 @@ class PettingZoo(MARLEnv[Space]):
         if obs_space.shape is None:
             raise NotImplementedError("Only discrete observation spaces are supported")
         self._pz_env = env
-
-
+        self.n_agents = n_agents
+        self.n_actions = space.shape[-1]
+        self.last_observation, state = self.reset()
+        super().__init__(n_agents, space, obs_space.shape, state.shape)
         self.agents = env.possible_agents
-        self.last_observation = None

     def get_state(self):
         try:
-            return self._pz_env.state()
+            return State(self._pz_env.state())
         except NotImplementedError:
-
+            assert self.last_observation is not None, "Cannot get the state unless there is a previous observation"
+            return State(self.last_observation.data)

-    def step(self,
-        action_dict = dict(zip(self.agents,
+    def step(self, action: npt.NDArray | Sequence):
+        action_dict = dict(zip(self.agents, action))
         obs, reward, term, trunc, info = self._pz_env.step(action_dict)
         obs_data = np.array([v for v in obs.values()])
         reward = np.sum([r for r in reward.values()], keepdims=True)
         self.last_observation = Observation(obs_data, self.available_actions())
-        state =
+        state = self.get_state()
         return Step(self.last_observation, state, reward, any(term.values()), any(trunc.values()), info)

     def reset(self):
         obs = self._pz_env.reset()[0]
         obs_data = np.array([v for v in obs.values()])
-        self.last_observation = Observation(obs_data, self.available_actions()
-        return self.last_observation
+        self.last_observation = Observation(obs_data, self.available_actions())
+        return self.last_observation, self.get_state()

     def get_observation(self):
-        if self.last_observation is None:
-            raise ValueError("No observation available. Call reset() first.")
         return self.last_observation

     def seed(self, seed_value: int):
         self._pz_env.reset(seed=seed_value)

-    def render(self
-
+    def render(self):
+        self._pz_env.render()
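The PettingZoo adapter now finishes its own initialisation by calling reset(), reset() returns an (Observation, State) pair, and get_state() always wraps its result in a State. A usage sketch under stated assumptions: the adapter is exported as marlenv.adapters.PettingZoo, its constructor takes a parallel PettingZoo environment, and pistonball is used purely as an illustration.

from pettingzoo.butterfly import pistonball_v6   # illustrative parallel PettingZoo environment
from marlenv.adapters import PettingZoo          # assumed export, mirroring the SMAC import in the tests

env = PettingZoo(pistonball_v6.parallel_env())   # assumed constructor argument
obs, state = env.reset()                         # new in 3.7.0: (Observation, State) instead of the observation alone
actions = env.action_space.sample()              # one action per agent, as in tests/test_wrappers.py
step = env.step(actions)                         # actions are zipped with env.agents into a PettingZoo action dict
print(step.reward, step.done, step.truncated)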
{multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/src/marlenv/adapters/smac_adapter.py

@@ -3,7 +3,7 @@ from typing import overload

 import numpy as np
 import numpy.typing as npt
-from smac.env import StarCraft2Env
+from smac.env import StarCraft2Env  # type:ignore[import]

 from marlenv.models import MARLEnv, Observation, State, Step, MultiDiscreteSpace, DiscreteSpace


@@ -169,17 +169,18 @@ class SMAC(MARLEnv[MultiDiscreteSpace]):

     def reset(self):
         obs, state = self._env.reset()
-        obs = Observation(np.array(obs), self.available_actions()
-
+        obs = Observation(np.array(obs), self.available_actions())
+        state = State(state)
+        return obs, state

     def get_observation(self):
-        return self._env.get_obs()
+        return Observation(np.array(self._env.get_obs()), self.available_actions())

     def get_state(self):
         return State(self._env.get_state())

-    def step(self,
-        reward, done, info = self._env.step(
+    def step(self, action):
+        reward, done, info = self._env.step(action)
         obs = Observation(
             self._env.get_obs(),  # type: ignore
             self.available_actions(),

@@ -199,7 +200,9 @@ class SMAC(MARLEnv[MultiDiscreteSpace]):
         return np.array(self._env.get_avail_actions()) == 1

     def get_image(self):
-
+        img = self._env.render(mode="rgb_array")
+        assert img is not None
+        return img

     def seed(self, seed_value: int):
         self._env = StarCraft2Env(map_name=self._env.map_name, seed=seed_value)
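The SMAC adapter now matches the rest of the library: reset() returns an (Observation, State) pair, get_observation() wraps the raw SMAC observations in an Observation, and get_image() renders an RGB frame. A condensed sketch of the updated contract, following test_smac_from_class in tests/test_adapters.py (requires the optional smac dependency and a local StarCraft II installation):

import numpy as np
from smac.env import StarCraft2Env  # type: ignore[import]
from marlenv.adapters import SMAC

env = SMAC(StarCraft2Env("3m"))      # same construction as in the tests
obs, state = env.reset()             # new in 3.7.0: (Observation, State)
mask = env.available_actions()       # boolean array of shape (n_agents, n_actions)
# Pick the first available action for every agent; purely illustrative.
actions = np.array([np.flatnonzero(m)[0] for m in mask])
step = env.step(actions)
print(step.reward, step.done)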
{multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/src/marlenv/catalog/deepsea.py

@@ -45,7 +45,7 @@ class DeepSea(MARLEnv[MultiDiscreteSpace]):
         self._col = 0
         return self.get_observation(), self.get_state()

-    def step(self, action: Sequence[int]):
+    def step(self, action: Sequence[int] | np.ndarray):
         self._row += 1
         if action[0] == LEFT:
             self._col -= 1
multi_agent_rlenv-3.7.0/src/marlenv/catalog/two_steps.py

@@ -0,0 +1,93 @@
+from enum import IntEnum
+import cv2
+import marlenv
+import numpy as np
+import numpy.typing as npt
+from typing import Sequence
+from marlenv import Observation, State, DiscreteSpace, Step
+
+PAYOFF_INITIAL = [[0, 0], [0, 0]]
+PAYOFF_2A = [[7, 7], [7, 7]]
+PAYOFF_2B = [[0, 1], [1, 8]]
+
+
+class TwoStepsState(IntEnum):
+    INITIAL = 0
+    STATE_2A = 1
+    STATE_2B = 2
+    END = 3
+
+    def one_hot(self):
+        res = np.zeros((4,), dtype=np.float32)
+        res[self.value] = 1
+        return res
+
+    @staticmethod
+    def from_one_hot(x: np.ndarray):
+        for s in TwoStepsState:
+            if x[s.value] == 1:
+                return s
+        raise ValueError()
+
+
+class TwoStepsGame(marlenv.MARLEnv):
+    """
+    Two-steps game used in QMix paper (https://arxiv.org/pdf/1803.11485.pdf, section 5)
+    to demonstrate its superior representationability compared to VDN.
+    """
+
+    def __init__(self):
+        self.state = TwoStepsState.INITIAL
+        self._identity = np.identity(2, dtype=np.float32)
+        super().__init__(
+            2,
+            DiscreteSpace(2).repeat(2),
+            observation_shape=(self.state.one_hot().shape[0] + 2,),
+            state_shape=self.state.one_hot().shape,
+        )
+
+    def reset(self):
+        self.state = TwoStepsState.INITIAL
+        return self.observation(), self.get_state()
+
+    def step(self, action: npt.NDArray[np.int32] | Sequence):
+        match self.state:
+            case TwoStepsState.INITIAL:
+                # In the initial step, only agent 0's actions have an influence on the state
+                payoffs = PAYOFF_INITIAL
+                if action[0] == 0:
+                    self.state = TwoStepsState.STATE_2A
+                elif action[0] == 1:
+                    self.state = TwoStepsState.STATE_2B
+                else:
+                    raise ValueError(f"Invalid action: {action[0]}")
+            case TwoStepsState.STATE_2A:
+                payoffs = PAYOFF_2A
+                self.state = TwoStepsState.END
+            case TwoStepsState.STATE_2B:
+                payoffs = PAYOFF_2B
+                self.state = TwoStepsState.END
+            case TwoStepsState.END:
+                raise ValueError("Episode is already over")
+        reward = payoffs[action[0]][action[1]]
+        done = self.state == TwoStepsState.END
+        return Step(self.observation(), self.get_state(), reward, done, False)
+
+    def get_state(self):
+        return State(self.state.one_hot())
+
+    def observation(self):
+        obs_data = np.array([self.state.one_hot(), self.state.one_hot()])
+        extras = self._identity
+        return Observation(obs_data, self.available_actions(), extras)
+
+    def render(self):
+        print(self.state)
+
+    def get_image(self):
+        state = self.state.one_hot()
+        img = cv2.cvtColor(state, cv2.COLOR_GRAY2BGR)
+        return np.array(img, dtype=np.uint8)
+
+    def set_state(self, state: State):
+        self.state = TwoStepsState.from_one_hot(state.data)
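The new catalog entry implements the two-step matrix game from the QMIX paper as an ordinary MARLEnv with two agents and two actions each, so it can be rolled out like any other environment. A short episode sketch, importing directly from the module (whether it is also re-exported from marlenv.catalog is not visible in this diff):

import numpy as np
from marlenv.catalog.two_steps import TwoStepsGame

env = TwoStepsGame()
obs, state = env.reset()
step = env.step(np.array([0, 0]))   # agent 0 picks action 0: the game moves to STATE_2A, reward 0
print(step.reward, step.done)       # 0, False
step = env.step(np.array([1, 1]))   # second step in STATE_2A pays PAYOFF_2A[1][1] == 7 and ends the episode
print(step.reward, step.done)       # 7, True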
{multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/src/marlenv/env_pool.py

@@ -20,10 +20,10 @@ class EnvPool(RLEnvWrapper[ActionSpaceType]):
             assert env.has_same_inouts(self.envs[0]), "All environments must have the same inputs and outputs"
         super().__init__(self.envs[0])

-    def seed(self,
-        random.seed(
+    def seed(self, seed_value: int):
+        random.seed(seed_value)
         for env in self.envs:
-            env.seed(
+            env.seed(seed_value)

     def reset(self):
         self.wrapped = random.choice(self.envs)
{multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/src/marlenv/mock_env.py

@@ -73,9 +73,9 @@ class DiscreteMockEnv(MARLEnv[MultiDiscreteSpace]):
     def render(self, mode: str = "human"):
         return

-    def step(self,
+    def step(self, action):
         self.t += 1
-        self.actions_history.append(
+        self.actions_history.append(action)
         return Step(
             self.get_observation(),
             self.get_state(),
{multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/src/marlenv/models/spaces.py

@@ -8,7 +8,7 @@ import numpy.typing as npt


 @dataclass
-class Space(ABC):
+class Space[T](ABC):
     shape: tuple[int, ...]
     size: int
     labels: list[str]

@@ -21,7 +21,7 @@ class Space(ABC):
         self.labels = labels

     @abstractmethod
-    def sample(self, mask:
+    def sample(self, mask: npt.NDArray[np.bool] | None = None) -> T:
         """Sample a value from the space."""

     def __eq__(self, value: object) -> bool:

@@ -44,7 +44,7 @@ class Space(ABC):


 @dataclass
-class DiscreteSpace(Space):
+class DiscreteSpace(Space[int]):
     size: int
     """Number of categories"""


@@ -53,7 +53,7 @@ class DiscreteSpace(Space):
         self.size = size
         self.space = np.arange(size)

-    def sample(self, mask:
+    def sample(self, mask: npt.NDArray[np.bool] | None = None):
         space = self.space.copy()
         if mask is not None:
             space = space[mask]

@@ -87,7 +87,7 @@ class DiscreteSpace(Space):


 @dataclass
-class MultiDiscreteSpace(Space):
+class MultiDiscreteSpace(Space[npt.NDArray[np.int32]]):
     n_dims: int
     spaces: tuple[DiscreteSpace, ...]


@@ -123,7 +123,7 @@ class MultiDiscreteSpace(Space):


 @dataclass
-class ContinuousSpace(Space):
+class ContinuousSpace(Space[npt.NDArray[np.float32]]):
     """A continuous space (box) in R^n."""

     low: npt.NDArray[np.float32]

@@ -192,7 +192,7 @@ class ContinuousSpace(Space):
         action = np.array(action)
         return np.clip(action, self.low, self.high)

-    def sample(self
+    def sample(self, *args, **kwargs):
         r = np.random.random(self.shape) * (self.high - self.low) + self.low
         return r.astype(np.float32)

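Space is now generic over the type returned by sample(): DiscreteSpace yields an int, MultiDiscreteSpace an int32 array, ContinuousSpace a float32 array. The class Space[T] syntax is PEP 695, which is one reason the package now requires Python >= 3.12. A small sketch of the typed API; the imports follow the catalog code (from marlenv import DiscreteSpace), and repeat() is used the same way as in two_steps.py and centralised.py:

import numpy as np
from marlenv import DiscreteSpace

single = DiscreteSpace(5)
a = single.sample()                       # typed as int
masked = single.sample(mask=np.array([True, False, True, False, False]))  # only indices 0 and 2 can be drawn

joint = DiscreteSpace(5).repeat(3)        # MultiDiscreteSpace built from three DiscreteSpace(5)
actions = joint.sample()                  # typed as npt.NDArray[np.int32], one action per agent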
{multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/src/marlenv/utils/schedule.py

@@ -145,17 +145,15 @@ class Schedule:
     @staticmethod
     def from_json(data: dict[str, Any]):
         """Create a Schedule from a JSON-like dictionary."""
-
-
-
-
-
-
-
-        elif classname == "ArbitrarySchedule":
+        candidates = [LinearSchedule, ExpSchedule, ConstantSchedule]
+        data = data.copy()
+        classname = data.pop("name")
+        for cls in candidates:
+            if cls.__name__ == classname:
+                return cls(**data)
+        if classname == "ArbitrarySchedule":
             raise NotImplementedError("ArbitrarySchedule cannot be deserialized from JSON")
-
-            raise ValueError(f"Unknown schedule type: {classname}")
+        raise ValueError(f"Unknown schedule type: {classname}")


 @dataclass(eq=False)
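Schedule.from_json now dispatches on the "name" entry: it matches it against LinearSchedule, ExpSchedule and ConstantSchedule, instantiates the matching class with the remaining keys, raises NotImplementedError for "ArbitrarySchedule" and ValueError for anything else. A hedged sketch of the round-trip; the import path and the LinearSchedule field names below are assumptions, only the "name" key and the dispatch behaviour come from the diff:

from marlenv.utils import Schedule   # assumed re-export; otherwise marlenv.utils.schedule.Schedule

# Hypothetical payload: "name" selects the class, every other key is passed to its constructor.
# The field names (start_value, end_value, n_steps) are an assumption about LinearSchedule.
data = {"name": "LinearSchedule", "start_value": 1.0, "end_value": 0.05, "n_steps": 10_000}
schedule = Schedule.from_json(data)

# {"name": "ArbitrarySchedule"} raises NotImplementedError; any other unknown name raises ValueError.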
{multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/src/marlenv/wrappers/agent_id_wrapper.py
RENAMED

@@ -18,8 +18,8 @@ class AgentId(RLEnvWrapper[AS]):
         super().__init__(env, extra_shape=(env.n_agents + env.extras_shape[0],), extra_meanings=meanings)
         self._identity = np.identity(env.n_agents, dtype=np.float32)

-    def step(self,
-        step = super().step(
+    def step(self, action):
+        step = super().step(action)
         step.obs.add_extra(self._identity)
         return step

{multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/src/marlenv/wrappers/blind_wrapper.py

@@ -18,8 +18,8 @@ class Blind(RLEnvWrapper[AS]):
         super().__init__(env)
         self.p = float(p)

-    def step(self,
-        step = super().step(
+    def step(self, action):
+        step = super().step(action)
         if random.random() < self.p:
             step.obs.data = np.zeros_like(step.obs.data)
         return step
{multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/src/marlenv/wrappers/centralised.py

@@ -42,9 +42,9 @@ class Centralized(RLEnvWrapper[MultiDiscreteSpace]):
         action_names = [str(a) for a in product(*agent_actions)]
         return DiscreteSpace(env.n_actions**env.n_agents, action_names).repeat(1)

-    def step(self,
-
-        individual_actions = self._individual_actions(
+    def step(self, action: npt.NDArray | Sequence):
+        action1 = action[0]
+        individual_actions = self._individual_actions(action1)
         individual_actions = np.array(individual_actions)
         step = self.wrapped.step(individual_actions)  # type: ignore
         step.obs = self._joint_observation(step.obs)
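Centralized flattens the multi-agent problem into a single-agent one: the action space becomes a DiscreteSpace with n_actions ** n_agents joint actions (repeated once), and step() now reads the joint-action index from the first element of the array it receives before decoding it into per-agent actions. A sketch with the mock environment used by the test suite, assuming Centralized takes the wrapped environment as its only constructor argument:

import numpy as np
from marlenv import DiscreteMockEnv
from marlenv.wrappers import Centralized

env = Centralized(DiscreteMockEnv(2))   # assumed constructor; DiscreteMockEnv(2) as in tests/test_wrappers.py
obs, state = env.reset()
step = env.step(np.array([0]))          # a length-1 array holding the joint-action index
print(step.reward, step.done)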
{multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/src/marlenv/wrappers/delayed_rewards.py

@@ -27,8 +27,8 @@ class DelayedReward(RLEnvWrapper[AS]):
         self.reward_queue.append(np.zeros(self.reward_space.shape, dtype=np.float32))
         return super().reset()

-    def step(self,
-        step = super().step(
+    def step(self, action):
+        step = super().step(action)
         self.reward_queue.append(step.reward)
         # If the step is terminal, we sum all the remaining rewards
         if step.is_terminal:
{multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/src/marlenv/wrappers/last_action_wrapper.py
RENAMED

@@ -33,13 +33,13 @@ class LastAction(RLEnvWrapper[AS]):
         state.add_extra(self.last_one_hot_actions.flatten())
         return obs, state

-    def step(self,
-        step = super().step(
+    def step(self, action):
+        step = super().step(action)
         match self.wrapped.action_space:
             case ContinuousSpace():
-                self.last_actions =
+                self.last_actions = action
             case DiscreteSpace() | MultiDiscreteSpace():
-                self.last_one_hot_actions = self.compute_one_hot_actions(
+                self.last_one_hot_actions = self.compute_one_hot_actions(action)
             case other:
                 raise NotImplementedError(f"Action space {other} not supported")
         step.obs.add_extra(self.last_one_hot_actions)
{multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/src/marlenv/wrappers/paddings.py

@@ -24,8 +24,8 @@ class PadExtras(RLEnvWrapper[AS]):
         )
         self.n = n_added

-    def step(self,
-        step = super().step(
+    def step(self, action):
+        step = super().step(action)
         step.obs = self._add_extras(step.obs)
         return step


@@ -48,8 +48,8 @@ class PadObservations(RLEnvWrapper[AS]):
         super().__init__(env, observation_shape=(env.observation_shape[0] + n_added,))
         self.n = n_added

-    def step(self,
-        step = super().step(
+    def step(self, action):
+        step = super().step(action)
         step.obs = self._add_obs(step.obs)
         return step

{multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/src/marlenv/wrappers/potential_shaping.py
RENAMED

@@ -39,9 +39,9 @@ class PotentialShaping(RLEnvWrapper[A], ABC):
         self._current_potential = self.compute_potential()
         return self.add_extras(obs), state

-    def step(self,
+    def step(self, action):
         prev_potential = self._current_potential
-        step = super().step(
+        step = super().step(action)

         self._current_potential = self.compute_potential()
         shaped_reward = self.gamma * self._current_potential - prev_potential
{multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/src/marlenv/wrappers/rlenv_wrapper.py

@@ -62,8 +62,8 @@ class RLEnvWrapper(MARLEnv[AS]):
     def agent_state_size(self):
         return self.wrapped.agent_state_size

-    def step(self,
-        return self.wrapped.step(
+    def step(self, action: np.ndarray | Sequence):
+        return self.wrapped.step(action)

     def reset(self):
         return self.wrapped.reset()
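Every wrapper touched in this release follows the pattern that RLEnvWrapper establishes here: override step(action), call super().step(action) to forward to the wrapped environment, then post-process the returned Step. A minimal custom wrapper sketch in the style of Blind and AgentId; the ScaledReward class and the reward scaling are purely illustrative, and RLEnvWrapper is assumed to be importable from marlenv.wrappers like the other wrappers:

import numpy as np
from marlenv import DiscreteMockEnv
from marlenv.wrappers import RLEnvWrapper   # assumed export

class ScaledReward(RLEnvWrapper):           # hypothetical wrapper, not part of the package
    def __init__(self, env, factor: float):
        super().__init__(env)
        self.factor = factor

    def step(self, action):
        step = super().step(action)         # delegates to the wrapped environment
        step.reward = step.reward * self.factor
        return step

env = ScaledReward(DiscreteMockEnv(2), factor=0.5)
obs, state = env.reset()
step = env.step(np.array([0, 0]))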
{multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/src/marlenv/wrappers/time_limit.py

@@ -64,9 +64,9 @@ class TimeLimit(RLEnvWrapper[AS]):
         self.add_time_extra(obs, state)
         return obs, state

-    def step(self,
+    def step(self, action):
         self._current_step += 1
-        step = super().step(
+        step = super().step(action)
         if self.add_extra:
             self.add_time_extra(step.obs, step.state)
         # If we reach the time limit
{multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/src/marlenv/wrappers/video_recorder.py

@@ -44,10 +44,10 @@ class VideoRecorder(RLEnvWrapper[AS]):
             case other:
                 raise ValueError(f"Unsupported file video encoding: {other}")

-    def step(self,
+    def step(self, action):
         if self._recorder is None:
             raise RuntimeError("VideoRecorder not initialized")
-        step = super().step(
+        step = super().step(action)
         img = self.get_image()
         self._recorder.write(img)
         if step.is_terminal:
{multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/tests/test_adapters.py

@@ -98,7 +98,7 @@ def _check_env_3m(env):
     from marlenv.adapters import SMAC

     assert isinstance(env, SMAC)
-    obs = env.reset()
+    obs, state = env.reset()
     assert isinstance(obs, Observation)
     assert env.n_agents == 3
     assert isinstance(env.action_space, MultiDiscreteSpace)

@@ -114,8 +114,7 @@ def _check_env_3m(env):

 @pytest.mark.skipif(skip_smac, reason="SMAC is not installed")
 def test_smac_from_class():
-    from smac.env import StarCraft2Env
-
+    from smac.env import StarCraft2Env  # type: ignore[import]
     from marlenv.adapters import SMAC

     env = SMAC(StarCraft2Env("3m"))
{multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/tests/test_models.py

@@ -380,8 +380,8 @@ def test_env_replay():
                 available[(agent + self._seed) % self.n_actions] = True
             return availables

-        def step(self,
-            return super().step(
+        def step(self, action):
+            return super().step(action)

         def seed(self, seed_value: int):
             np.random.seed(seed_value)

@@ -409,16 +409,16 @@ def test_wrong_extras_meanings_length():
             super().__init__(4, DiscreteSpace(5), (10,), (10,), extras_shape=(5,), extras_meanings=["a", "b", "c"])

         def get_observation(self):
-
+            raise NotImplementedError()

         def get_state(self):
-
+            raise NotImplementedError()

-        def step(self,
-
+        def step(self, action):
+            raise NotImplementedError()

         def reset(self):
-
+            raise NotImplementedError()

     try:
         TestClass()
{multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/tests/test_serialization.py

@@ -241,7 +241,7 @@ def test_serialize_schedule():
     try:
         pickle.dumps(s)
         assert False, "Should not be able to pickle arbitrary schedules because of the callable lambda"
-    except AttributeError:
+    except (pickle.PicklingError, AttributeError):
         pass

     s = Schedule.arbitrary(C())
{multi_agent_rlenv-3.6.3 → multi_agent_rlenv-3.7.0}/tests/test_wrappers.py

@@ -1,6 +1,6 @@
 import numpy as np
 from marlenv import Builder, DiscreteMOMockEnv, DiscreteMockEnv, MARLEnv
-from marlenv.wrappers import Centralized, AvailableActionsMask, TimeLimit, LastAction, DelayedReward
+from marlenv.wrappers import Centralized, AvailableActionsMask, TimeLimit, LastAction, DelayedReward
 import marlenv


@@ -55,13 +55,12 @@ def test_time_limit_wrapper():
     env = Builder(DiscreteMockEnv(1)).time_limit(MAX_T).build()
     assert env.extras_shape == (1,)
     assert env.state_extra_shape == (1,)
-
-
-    while not done:
-        step = env.step(np.array([0]))
+    t = 1
+    step = env.step(np.array([0]))
+    while not step.done:
         assert step.obs.extras.shape == (env.n_agents, 1)
         assert step.state.extras_shape == (1,)
-
+        step = env.step(np.array([0]))
         t += 1
     assert t == MAX_T
     assert step.truncated

@@ -73,12 +72,15 @@ def test_truncated_and_done():
     env = marlenv.wrappers.TimeLimit(DiscreteMockEnv(2, end_game=END_GAME), END_GAME)
     obs, state = env.reset()
     episode = marlenv.Episode.new(obs, state)
+    action = env.action_space.sample()
+    step = env.step(action)
     while not episode.is_finished:
-        action = env.action_space.sample()
-        step = env.step(action)
         episode.add(marlenv.Transition.from_step(obs, state, action, step))
         obs = step.obs
         state = step.state
+        action = env.action_space.sample()
+        step = env.step(action)
+
     assert step.done
     assert not step.truncated, (
         "The episode is done, so it does not have to be truncated even though the time limit is reached at the same time."

@@ -97,11 +99,10 @@ def test_time_limit_wrapper_with_extra():
     assert env.extras_shape == (1,)
     obs, _ = env.reset()
     assert obs.extras.shape == (5, 1)
-
-
-    while not
+    t = 1
+    step = env.step(np.array([0]))
+    while not step.is_terminal:
         step = env.step(np.array([0]))
-        stop = step.done or step.truncated
         t += 1
     assert t == MAX_T
     assert np.all(step.obs.extras == 1.0)

@@ -129,11 +130,10 @@ def test_time_limit_wrapper_with_truncation_penalty():
     assert env.extras_shape == (1,)
     obs, _ = env.reset()
     assert obs.extras.shape == (5, 1)
-
-
-    while not
+    t = 1
+    step = env.step(np.array([0]))
+    while not step.is_terminal:
         step = env.step(np.array([0]))
-        stop = step.done or step.truncated
     t += 1
     assert t == MAX_T
     assert np.all(step.obs.extras[:] == 1)

@@ -374,9 +374,9 @@ def test_potential_shaping():
         def compute_potential(self) -> float:
             return self.phi

-        def step(self,
+        def step(self, action):
             self.phi = max(0, self.phi - 1)
-            return super().step(
+            return super().step(action)

     EP_LENGTH = 20
     env = PS(DiscreteMockEnv(reward_step=0, end_game=EP_LENGTH))