multi-agent-rlenv 3.4.0__tar.gz → 3.5.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {multi_agent_rlenv-3.4.0 → multi_agent_rlenv-3.5.1}/PKG-INFO +1 -1
- {multi_agent_rlenv-3.4.0 → multi_agent_rlenv-3.5.1}/src/marlenv/__init__.py +11 -13
- {multi_agent_rlenv-3.4.0 → multi_agent_rlenv-3.5.1}/src/marlenv/adapters/gym_adapter.py +6 -16
- {multi_agent_rlenv-3.4.0 → multi_agent_rlenv-3.5.1}/src/marlenv/adapters/overcooked_adapter.py +6 -7
- {multi_agent_rlenv-3.4.0 → multi_agent_rlenv-3.5.1}/src/marlenv/adapters/pettingzoo_adapter.py +5 -5
- {multi_agent_rlenv-3.4.0 → multi_agent_rlenv-3.5.1}/src/marlenv/adapters/pymarl_adapter.py +3 -4
- {multi_agent_rlenv-3.4.0 → multi_agent_rlenv-3.5.1}/src/marlenv/adapters/smac_adapter.py +6 -6
- {multi_agent_rlenv-3.4.0 → multi_agent_rlenv-3.5.1}/src/marlenv/env_builder.py +8 -9
- {multi_agent_rlenv-3.4.0 → multi_agent_rlenv-3.5.1}/src/marlenv/env_pool.py +5 -7
- {multi_agent_rlenv-3.4.0 → multi_agent_rlenv-3.5.1}/src/marlenv/mock_env.py +7 -7
- {multi_agent_rlenv-3.4.0 → multi_agent_rlenv-3.5.1}/src/marlenv/models/__init__.py +2 -4
- {multi_agent_rlenv-3.4.0 → multi_agent_rlenv-3.5.1}/src/marlenv/models/env.py +18 -12
- {multi_agent_rlenv-3.4.0 → multi_agent_rlenv-3.5.1}/src/marlenv/models/episode.py +15 -18
- {multi_agent_rlenv-3.4.0 → multi_agent_rlenv-3.5.1}/src/marlenv/models/spaces.py +90 -83
- {multi_agent_rlenv-3.4.0 → multi_agent_rlenv-3.5.1}/src/marlenv/models/step.py +1 -1
- {multi_agent_rlenv-3.4.0 → multi_agent_rlenv-3.5.1}/src/marlenv/models/transition.py +6 -10
- {multi_agent_rlenv-3.4.0 → multi_agent_rlenv-3.5.1}/src/marlenv/wrappers/__init__.py +2 -0
- {multi_agent_rlenv-3.4.0 → multi_agent_rlenv-3.5.1}/src/marlenv/wrappers/agent_id_wrapper.py +4 -5
- {multi_agent_rlenv-3.4.0 → multi_agent_rlenv-3.5.1}/src/marlenv/wrappers/available_actions_mask.py +6 -7
- {multi_agent_rlenv-3.4.0 → multi_agent_rlenv-3.5.1}/src/marlenv/wrappers/available_actions_wrapper.py +7 -9
- {multi_agent_rlenv-3.4.0 → multi_agent_rlenv-3.5.1}/src/marlenv/wrappers/blind_wrapper.py +5 -7
- {multi_agent_rlenv-3.4.0 → multi_agent_rlenv-3.5.1}/src/marlenv/wrappers/centralised.py +12 -14
- {multi_agent_rlenv-3.4.0 → multi_agent_rlenv-3.5.1}/src/marlenv/wrappers/delayed_rewards.py +13 -11
- {multi_agent_rlenv-3.4.0 → multi_agent_rlenv-3.5.1}/src/marlenv/wrappers/last_action_wrapper.py +10 -14
- {multi_agent_rlenv-3.4.0 → multi_agent_rlenv-3.5.1}/src/marlenv/wrappers/paddings.py +6 -8
- {multi_agent_rlenv-3.4.0 → multi_agent_rlenv-3.5.1}/src/marlenv/wrappers/penalty_wrapper.py +5 -8
- multi_agent_rlenv-3.5.1/src/marlenv/wrappers/potential_shaping.py +49 -0
- {multi_agent_rlenv-3.4.0 → multi_agent_rlenv-3.5.1}/src/marlenv/wrappers/rlenv_wrapper.py +12 -10
- {multi_agent_rlenv-3.4.0 → multi_agent_rlenv-3.5.1}/src/marlenv/wrappers/time_limit.py +3 -3
- {multi_agent_rlenv-3.4.0 → multi_agent_rlenv-3.5.1}/src/marlenv/wrappers/video_recorder.py +4 -6
- {multi_agent_rlenv-3.4.0 → multi_agent_rlenv-3.5.1}/tests/test_adapters.py +7 -7
- {multi_agent_rlenv-3.4.0 → multi_agent_rlenv-3.5.1}/tests/test_models.py +2 -2
- multi_agent_rlenv-3.5.1/tests/test_spaces.py +183 -0
- {multi_agent_rlenv-3.4.0 → multi_agent_rlenv-3.5.1}/tests/test_wrappers.py +35 -3
- {multi_agent_rlenv-3.4.0 → multi_agent_rlenv-3.5.1}/tests/utils.py +2 -2
- multi_agent_rlenv-3.4.0/tests/test_spaces.py +0 -134
- {multi_agent_rlenv-3.4.0 → multi_agent_rlenv-3.5.1}/.github/workflows/ci.yaml +0 -0
- {multi_agent_rlenv-3.4.0 → multi_agent_rlenv-3.5.1}/.github/workflows/docs.yaml +0 -0
- {multi_agent_rlenv-3.4.0 → multi_agent_rlenv-3.5.1}/.gitignore +0 -0
- {multi_agent_rlenv-3.4.0 → multi_agent_rlenv-3.5.1}/LICENSE +0 -0
- {multi_agent_rlenv-3.4.0 → multi_agent_rlenv-3.5.1}/README.md +0 -0
- {multi_agent_rlenv-3.4.0 → multi_agent_rlenv-3.5.1}/pyproject.toml +0 -0
- {multi_agent_rlenv-3.4.0 → multi_agent_rlenv-3.5.1}/src/marlenv/adapters/__init__.py +0 -0
- {multi_agent_rlenv-3.4.0 → multi_agent_rlenv-3.5.1}/src/marlenv/exceptions.py +0 -0
- {multi_agent_rlenv-3.4.0 → multi_agent_rlenv-3.5.1}/src/marlenv/models/observation.py +0 -0
- {multi_agent_rlenv-3.4.0 → multi_agent_rlenv-3.5.1}/src/marlenv/models/state.py +0 -0
- {multi_agent_rlenv-3.4.0 → multi_agent_rlenv-3.5.1}/src/marlenv/py.typed +0 -0
- {multi_agent_rlenv-3.4.0 → multi_agent_rlenv-3.5.1}/src/marlenv/utils/__init__.py +0 -0
- {multi_agent_rlenv-3.4.0 → multi_agent_rlenv-3.5.1}/src/marlenv/utils/schedule.py +0 -0
- {multi_agent_rlenv-3.4.0 → multi_agent_rlenv-3.5.1}/tests/__init__.py +0 -0
- {multi_agent_rlenv-3.4.0 → multi_agent_rlenv-3.5.1}/tests/test_episode.py +0 -0
- {multi_agent_rlenv-3.4.0 → multi_agent_rlenv-3.5.1}/tests/test_pool.py +0 -0
- {multi_agent_rlenv-3.4.0 → multi_agent_rlenv-3.5.1}/tests/test_schedules.py +0 -0
- {multi_agent_rlenv-3.4.0 → multi_agent_rlenv-3.5.1}/tests/test_serialization.py +0 -0
{multi_agent_rlenv-3.4.0 → multi_agent_rlenv-3.5.1}/PKG-INFO
RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: multi-agent-rlenv
-Version: 3.4.0
+Version: 3.5.1
 Summary: A strongly typed Multi-Agent Reinforcement Learning framework
 Project-URL: repository, https://github.com/yamoling/multi-agent-rlenv
 Author-email: Yannick Molinghen <yannick.molinghen@ulb.be>
{multi_agent_rlenv-3.4.0 → multi_agent_rlenv-3.5.1}/src/marlenv/__init__.py
RENAMED
@@ -62,16 +62,11 @@ print(env.extras_shape) # (1, )
 If you want to create a new environment, you can simply create a class that inherits from `MARLEnv`. If you want to create a wrapper around an existing `MARLEnv`, you probably want to subclass `RLEnvWrapper` which implements a default behaviour for every method.
 """
 
-__version__ = "3.4.0"
+__version__ = "3.5.1"
 
 from . import models
-from . import wrappers
-from . import adapters
-from .models import spaces
-
-
-from .env_builder import make, Builder
 from .models import (
+    spaces,
     MARLEnv,
     State,
     Step,
@@ -80,10 +75,14 @@ from .models import (
     Transition,
     DiscreteSpace,
     ContinuousSpace,
-
-
-    ContinuousActionSpace,
+    Space,
+    MultiDiscreteSpace,
 )
+
+
+from . import wrappers
+from . import adapters
+from .env_builder import make, Builder
 from .wrappers import RLEnvWrapper
 from .mock_env import DiscreteMockEnv, DiscreteMOMockEnv
 
@@ -100,12 +99,11 @@ __all__ = [
     "Observation",
     "Episode",
     "Transition",
-    "ActionSpace",
     "DiscreteSpace",
     "ContinuousSpace",
-    "DiscreteActionSpace",
-    "ContinuousActionSpace",
     "DiscreteMockEnv",
     "DiscreteMOMockEnv",
     "RLEnvWrapper",
+    "Space",
+    "MultiDiscreteSpace",
 ]
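Taken together, the export changes above remove `ActionSpace`, `DiscreteActionSpace` and `ContinuousActionSpace` from the top-level package and add `Space` and `MultiDiscreteSpace`; multi-agent discrete action spaces are now built by repeating a single-agent `DiscreteSpace`, as the adapter diffs below show. A minimal migration sketch, assuming only what appears in this diff (the old constructor form is taken from the removed SMAC line):

```python
from marlenv import DiscreteSpace, MultiDiscreteSpace

# 3.4.0 (removed API):
#   from marlenv import DiscreteActionSpace
#   action_space = DiscreteActionSpace(n_agents, n_actions)

# 3.5.1: build the per-agent space, then repeat it for each agent.
n_agents, n_actions = 2, 5
action_space: MultiDiscreteSpace = DiscreteSpace(n_actions).repeat(n_agents)
```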
{multi_agent_rlenv-3.4.0 → multi_agent_rlenv-3.5.1}/src/marlenv/adapters/gym_adapter.py
RENAMED
@@ -1,26 +1,16 @@
 import sys
-import cv2
 from dataclasses import dataclass
-from typing import Sequence
 
+import cv2
 import gymnasium as gym
 import numpy as np
-import numpy.typing as npt
 from gymnasium import Env, spaces
 
-from marlenv import (
-    ActionSpace,
-    ContinuousActionSpace,
-    DiscreteActionSpace,
-    MARLEnv,
-    Observation,
-    State,
-    Step,
-)
+from marlenv import ContinuousSpace, DiscreteSpace, MARLEnv, Observation, Space, State, Step
 
 
 @dataclass
-class Gym(MARLEnv[Sequence | npt.NDArray, ActionSpace]):
+class Gym(MARLEnv[Space]):
     """Wraps a gym envronment in an RLEnv"""
 
     def __init__(self, env: Env | str, **kwargs):
@@ -30,7 +20,7 @@ class Gym(MARLEnv[Sequence | npt.NDArray, ActionSpace]):
             raise NotImplementedError("Observation space must have a shape")
         match env.action_space:
             case spaces.Discrete() as s:
-                space =
+                space = DiscreteSpace(int(s.n), labels=[f"Action {i}" for i in range(s.n)]).repeat(1)
             case spaces.Box() as s:
                 low = s.low.astype(np.float32)
                 high = s.high.astype(np.float32)
@@ -38,10 +28,10 @@ class Gym(MARLEnv[Sequence | npt.NDArray, ActionSpace]):
                 low = np.full(s.shape, s.low, dtype=np.float32)
                 if not isinstance(high, np.ndarray):
                     high = np.full(s.shape, s.high, dtype=np.float32)
-                space =
+                space = ContinuousSpace(low, high, labels=[f"Action {i}" for i in range(s.shape[0])]).repeat(1)
             case other:
                 raise NotImplementedError(f"Action space {other} not supported")
-        super().__init__(space, env.observation_space.shape, (1,))
+        super().__init__(1, space, env.observation_space.shape, (1,))
         self._gym_env = env
         if self._gym_env.unwrapped.spec is not None:
            self.name = self._gym_env.unwrapped.spec.id
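With this change, a gymnasium environment is exposed as a one-agent `MARLEnv` (`n_agents=1` is passed to `super().__init__`) whose action space is the single-agent space repeated once. A usage sketch, assuming `Gym` is re-exported from `marlenv.adapters` and using an arbitrary gymnasium id for illustration:

```python
import gymnasium as gym
from marlenv.adapters import Gym

# "CartPole-v1" is only an example id; any gymnasium Env (or id string,
# per the __init__ signature above) should work the same way.
env = Gym(gym.make("CartPole-v1"))
print(env.n_agents)   # 1: the gym env is wrapped as a single-agent MARLEnv
print(env.n_actions)  # derived from action_space.shape[-1] in MARLEnv.__init__
```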
{multi_agent_rlenv-3.4.0 → multi_agent_rlenv-3.5.1}/src/marlenv/adapters/overcooked_adapter.py
RENAMED
@@ -7,7 +7,7 @@ import cv2
 import numpy as np
 import numpy.typing as npt
 import pygame
-from marlenv.models import ContinuousSpace,
+from marlenv.models import ContinuousSpace, DiscreteSpace, MARLEnv, Observation, State, Step, MultiDiscreteSpace
 from marlenv.utils import Schedule
 
 from overcooked_ai_py.mdp.overcooked_env import OvercookedEnv
@@ -16,7 +16,7 @@ from overcooked_ai_py.visualization.state_visualizer import StateVisualizer
 
 
 @dataclass
-class Overcooked(MARLEnv[Sequence[int] | npt.NDArray, DiscreteActionSpace]):
+class Overcooked(MARLEnv[MultiDiscreteSpace]):
     horizon: int
     shaping_factor: Schedule
 
@@ -37,10 +37,9 @@ class Overcooked(MARLEnv[Sequence[int] | npt.NDArray, DiscreteActionSpace]):
         # -1 because we extract the "urgent" layer to the extras
         shape = (int(layers - 1), int(width), int(height))
         super().__init__(
-
-
-
-            action_names=[Action.ACTION_TO_CHAR[a] for a in Action.ALL_ACTIONS],
+            n_agents=self._mdp.num_players,
+            action_space=DiscreteSpace(Action.NUM_ACTIONS, labels=[Action.ACTION_TO_CHAR[a] for a in Action.ALL_ACTIONS]).repeat(
+                self._mdp.num_players
             ),
             observation_shape=shape,
             extras_shape=(2,),
@@ -95,7 +94,7 @@ class Overcooked(MARLEnv[Sequence[int] | npt.NDArray, DiscreteActionSpace]):
             available_actions[agent_num, Action.ACTION_TO_INDEX[action]] = True
         return np.array(available_actions, dtype=np.bool)
 
-    def step(self, actions: Sequence[int] |
+    def step(self, actions: Sequence[int] | np.ndarray) -> Step:
         self.shaping_factor.update()
         actions = [Action.ALL_ACTIONS[a] for a in actions]
         _, reward, done, info = self._oenv.step(actions, display_phi=True)
{multi_agent_rlenv-3.4.0 → multi_agent_rlenv-3.5.1}/src/marlenv/adapters/pettingzoo_adapter.py
RENAMED
@@ -6,17 +6,17 @@ import numpy.typing as npt
 from gymnasium import spaces # pettingzoo uses gymnasium spaces
 from pettingzoo import ParallelEnv
 
-from marlenv.models import
+from marlenv.models import MARLEnv, Observation, State, Step, DiscreteSpace, ContinuousSpace, Space
 
 
 @dataclass
-class PettingZoo(MARLEnv[npt.NDArray, ActionSpace]):
+class PettingZoo(MARLEnv[Space]):
     def __init__(self, env: ParallelEnv):
         aspace = env.action_space(env.possible_agents[0])
         n_agents = len(env.possible_agents)
         match aspace:
             case spaces.Discrete() as s:
-                space =
+                space = DiscreteSpace.action(int(s.n)).repeat(n_agents)
 
             case spaces.Box() as s:
                 low = s.low.astype(np.float32)
@@ -25,7 +25,7 @@ class PettingZoo(MARLEnv[npt.NDArray, ActionSpace]):
                 low = np.full(s.shape, s.low, dtype=np.float32)
                 if not isinstance(high, np.ndarray):
                     high = np.full(s.shape, s.high, dtype=np.float32)
-                space =
+                space = ContinuousSpace(low, high=high).repeat(n_agents)
             case other:
                 raise NotImplementedError(f"Action space {other} not supported")
 
@@ -34,7 +34,7 @@ class PettingZoo(MARLEnv[npt.NDArray, ActionSpace]):
             raise NotImplementedError("Only discrete observation spaces are supported")
         self._pz_env = env
         env.reset()
-        super().__init__(space, obs_space.shape, self.get_state().shape)
+        super().__init__(n_agents, space, obs_space.shape, self.get_state().shape)
         self.agents = env.possible_agents
         self.last_observation = None
 
{multi_agent_rlenv-3.4.0 → multi_agent_rlenv-3.5.1}/src/marlenv/adapters/pymarl_adapter.py
RENAMED
@@ -1,10 +1,9 @@
 from dataclasses import dataclass
-from typing import Any
+from typing import Any
 
 import numpy as np
-import numpy.typing as npt
 
-from marlenv.models import
+from marlenv.models import MARLEnv, MultiDiscreteSpace
 from marlenv.wrappers import TimeLimit
 
 
@@ -15,7 +14,7 @@ class PymarlAdapter:
     with the pymarl-qplex code base.
     """
 
-    def __init__(self, env: MARLEnv[
+    def __init__(self, env: MARLEnv[MultiDiscreteSpace], episode_limit: int):
         assert env.reward_space.size == 1, "Only single objective environments are supported."
         self.env = TimeLimit(env, episode_limit, add_extra=False)
         # Required by PyMarl
{multi_agent_rlenv-3.4.0 → multi_agent_rlenv-3.5.1}/src/marlenv/adapters/smac_adapter.py
RENAMED
@@ -1,15 +1,15 @@
 from dataclasses import dataclass
-from typing import
+from typing import overload
 
 import numpy as np
 import numpy.typing as npt
 from smac.env import StarCraft2Env
 
-from marlenv.models import
+from marlenv.models import MARLEnv, Observation, State, Step, MultiDiscreteSpace, DiscreteSpace
 
 
 @dataclass
-class SMAC(MARLEnv[Sequence[int] | npt.NDArray, DiscreteActionSpace]):
+class SMAC(MARLEnv[MultiDiscreteSpace]):
     """Wrapper for the SMAC environment to work with this framework"""
 
     @overload
@@ -157,10 +157,10 @@ class SMAC(MARLEnv[Sequence[int] | npt.NDArray, DiscreteActionSpace]):
             case other:
                 raise ValueError(f"Invalid argument type: {type(other)}")
         self._env = StarCraft2Env(map_name=map_name)
-        action_space = DiscreteActionSpace(self._env.n_agents, self._env.n_actions)
         self._env_info = self._env.get_env_info()
         super().__init__(
-
+            self._env.n_agents,
+            action_space=DiscreteSpace(self._env.n_actions).repeat(self._env.n_agents),
             observation_shape=(self._env_info["obs_shape"],),
             state_shape=(self._env_info["state_shape"],),
         )
@@ -195,7 +195,7 @@ class SMAC(MARLEnv[Sequence[int] | npt.NDArray, DiscreteActionSpace]):
         )
         return step
 
-    def available_actions(self) -> npt.NDArray[np.
+    def available_actions(self) -> npt.NDArray[np.bool]:
         return np.array(self._env.get_avail_actions()) == 1
 
     def get_image(self):
{multi_agent_rlenv-3.4.0 → multi_agent_rlenv-3.5.1}/src/marlenv/env_builder.py
RENAMED
@@ -5,10 +5,9 @@ import numpy.typing as npt
 
 from . import wrappers
 from marlenv import adapters
-from .models import
+from .models import Space, MARLEnv
 
-
-AS = TypeVar("AS", bound=ActionSpace)
+AS = TypeVar("AS", bound=Space)
 
 if adapters.HAS_PETTINGZOO:
     from .adapters import PettingZoo
@@ -71,12 +70,12 @@ def make(env, **kwargs):
 
 
 @dataclass
-class Builder(Generic[A, AS]):
+class Builder(Generic[AS]):
     """Builder for environments"""
 
-    _env: MARLEnv[
+    _env: MARLEnv[AS]
 
-    def __init__(self, env: MARLEnv[
+    def __init__(self, env: MARLEnv[AS]):
         self._env = env
 
     def time_limit(self, n_steps: int, add_extra: bool = True, truncation_penalty: Optional[float] = None):
@@ -124,9 +123,9 @@ class Builder(Generic[A, AS]):
 
     def centralised(self):
         """Centralises the observations and actions"""
-        from marlenv.models import
+        from marlenv.models import MultiDiscreteSpace
 
-        assert isinstance(self._env.action_space,
+        assert isinstance(self._env.action_space, MultiDiscreteSpace)
         self._env = wrappers.Centralized(self._env) # type: ignore
         return self
 
@@ -159,6 +158,6 @@ class Builder(Generic[A, AS]):
         self._env = wrappers.TimePenalty(self._env, penalty)
         return self
 
-    def build(self) -> MARLEnv[
+    def build(self) -> MARLEnv[AS]:
         """Build and return the environment"""
         return self._env
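The builder is now generic over a single `Space` parameter instead of the old action-type/action-space pair. A minimal chaining sketch, assuming `time_limit` returns the builder like the other chainable methods in the diff and that `DiscreteMockEnv`'s remaining constructor arguments keep their defaults:

```python
from marlenv import Builder, DiscreteMockEnv

# 50 steps is an arbitrary budget chosen for illustration.
env = (
    Builder(DiscreteMockEnv(n_agents=2))
    .time_limit(50, add_extra=True)
    .build()
)
print(type(env).__name__)
```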
{multi_agent_rlenv-3.4.0 → multi_agent_rlenv-3.5.1}/src/marlenv/env_pool.py
RENAMED
@@ -1,21 +1,19 @@
 from typing import Sequence
 from dataclasses import dataclass
-import numpy.typing as npt
 from typing_extensions import TypeVar
 import random
 
 from marlenv import RLEnvWrapper, MARLEnv
-from marlenv.models import
+from marlenv.models import Space
 
-
-ActionSpaceType = TypeVar("ActionSpaceType", bound=ActionSpace, default=ActionSpace)
+ActionSpaceType = TypeVar("ActionSpaceType", bound=Space, default=Space)
 
 
 @dataclass
-class EnvPool(RLEnvWrapper[
-    envs: Sequence[MARLEnv[
+class EnvPool(RLEnvWrapper[ActionSpaceType]):
+    envs: Sequence[MARLEnv[ActionSpaceType]]
 
-    def __init__(self, envs: Sequence[MARLEnv[
+    def __init__(self, envs: Sequence[MARLEnv[ActionSpaceType]]):
         assert len(envs) > 0, "EnvPool must contain at least one environment"
         self.envs = envs
         for env in envs[1:]:
{multi_agent_rlenv-3.4.0 → multi_agent_rlenv-3.5.1}/src/marlenv/mock_env.py
RENAMED
@@ -1,12 +1,10 @@
-from typing import Sequence
 import numpy as np
-import numpy.typing as npt
 from dataclasses import dataclass
-from marlenv import MARLEnv, Observation,
+from marlenv import MARLEnv, Observation, ContinuousSpace, Step, State, DiscreteSpace, MultiDiscreteSpace
 
 
 @dataclass
-class DiscreteMockEnv(MARLEnv[Sequence[int] | npt.NDArray, DiscreteActionSpace]):
+class DiscreteMockEnv(MARLEnv[MultiDiscreteSpace]):
     def __init__(
         self,
         n_agents: int = 4,
@@ -27,7 +25,8 @@ class DiscreteMockEnv(MARLEnv[Sequence[int] | npt.NDArray, DiscreteActionSpace])
             case _:
                 raise ValueError("reward_step must be an int, float or np.ndarray")
         super().__init__(
-
+            n_agents,
+            DiscreteSpace(n_actions).repeat(n_agents),
             (obs_size,),
             (n_agents * agent_state_size,),
             extras_shape=(extras_size,),
@@ -85,7 +84,7 @@ class DiscreteMockEnv(MARLEnv[Sequence[int] | npt.NDArray, DiscreteActionSpace])
         )
 
 
-class DiscreteMOMockEnv(MARLEnv[Sequence[int] | npt.NDArray, DiscreteActionSpace]):
+class DiscreteMOMockEnv(MARLEnv[DiscreteSpace]):
     """Multi-Objective Mock Environment"""
 
     def __init__(
@@ -100,7 +99,8 @@ class DiscreteMOMockEnv(MARLEnv[Sequence[int] | npt.NDArray, DiscreteActionSpace
         extras_size: int = 0,
     ) -> None:
         super().__init__(
-
+            n_agents,
+            DiscreteSpace(n_actions),
             (obs_size,),
             (n_agents * agent_state_size,),
             extras_shape=(extras_size,),
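The mock environments illustrate the new constructor contract: `DiscreteMockEnv` is typed as `MARLEnv[MultiDiscreteSpace]` and passes `n_agents` explicitly, while `DiscreteMOMockEnv` keeps a plain `DiscreteSpace`. A small check sketch, assuming the unshown constructor parameters keep their defaults:

```python
from marlenv import DiscreteMockEnv, MultiDiscreteSpace

env = DiscreteMockEnv(n_agents=2)
print(env.n_agents)                                       # 2, stored directly by MARLEnv.__init__
print(env.n_actions)                                      # taken from action_space.shape[-1]
print(isinstance(env.action_space, MultiDiscreteSpace))   # matches the declared type parameter
# DiscreteMOMockEnv is the multi-objective counterpart, typed MARLEnv[DiscreteSpace].
```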
{multi_agent_rlenv-3.4.0 → multi_agent_rlenv-3.5.1}/src/marlenv/models/__init__.py
RENAMED
@@ -1,4 +1,4 @@
-from .spaces import
+from .spaces import DiscreteSpace, ContinuousSpace, MultiDiscreteSpace, Space
 from .observation import Observation
 from .step import Step
 from .state import State
@@ -8,7 +8,6 @@ from .episode import Episode
 
 
 __all__ = [
-    "ActionSpace",
     "Step",
     "State",
     "DiscreteSpace",
@@ -18,6 +17,5 @@ __all__ = [
     "Transition",
     "Episode",
     "MultiDiscreteSpace",
-    "
-    "ContinuousActionSpace",
+    "Space",
 ]
{multi_agent_rlenv-3.4.0 → multi_agent_rlenv-3.5.1}/src/marlenv/models/env.py
RENAMED
@@ -1,24 +1,22 @@
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
 from itertools import product
-from typing import
+from typing import Generic, Optional, Sequence, TypeVar
 
 import cv2
 import numpy as np
 import numpy.typing as npt
-from typing_extensions import TypeVar
 
 from .observation import Observation
-from .spaces import
+from .spaces import ContinuousSpace, Space, DiscreteSpace, MultiDiscreteSpace
 from .state import State
 from .step import Step
 
-
-ActionSpaceType = TypeVar("ActionSpaceType", bound=ActionSpace, default=Any)
+ActionSpaceType = TypeVar("ActionSpaceType", bound=Space)
 
 
 @dataclass
-class MARLEnv(ABC, Generic[
+class MARLEnv(ABC, Generic[ActionSpaceType]):
     """
     Multi-Agent Reinforcement Learning environment.
 
@@ -70,6 +68,7 @@ class MARLEnv(ABC, Generic[ActionType, ActionSpaceType]):
 
     def __init__(
         self,
+        n_agents: int,
         action_space: ActionSpaceType,
         observation_shape: tuple[int, ...],
         state_shape: tuple[int, ...],
@@ -81,8 +80,8 @@ class MARLEnv(ABC, Generic[ActionType, ActionSpaceType]):
         super().__init__()
         self.name = self.__class__.__name__
         self.action_space = action_space
-        self.n_actions = action_space.
-        self.n_agents =
+        self.n_actions = action_space.shape[-1]
+        self.n_agents = n_agents
         self.observation_shape = observation_shape
         self.state_shape = state_shape
         self.extras_shape = extras_shape
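This is the core API change of the release: `MARLEnv.__init__` takes `n_agents` as its first argument, the generic parameter is a single `Space` type, and `n_actions` is read from `action_space.shape[-1]` (previously both appear to have been derived from the action space; the old assignments are truncated in this rendering). A hedged before/after sketch for a custom environment, with hypothetical names (`MyEnv`, `N_AGENTS`, `N_ACTIONS`) and the abstract methods elided:

```python
from marlenv import MARLEnv, DiscreteSpace, MultiDiscreteSpace

N_AGENTS, N_ACTIONS = 3, 5  # hypothetical values for illustration


class MyEnv(MARLEnv[MultiDiscreteSpace]):  # 3.4.0 was generic over (action type, action space)
    def __init__(self):
        super().__init__(
            N_AGENTS,  # new in 3.5.x: the agent count is passed explicitly
            DiscreteSpace(N_ACTIONS).repeat(N_AGENTS),  # 3.4.0: DiscreteActionSpace(N_AGENTS, N_ACTIONS)
            observation_shape=(10,),
            state_shape=(10,),
        )
    # reset(), step(), get_state(), ... omitted for brevity
```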
@@ -113,9 +112,16 @@ class MARLEnv(ABC, Generic[ActionType, ActionSpaceType]):
         """The number of objectives in the environment."""
         return self.reward_space.size
 
-    def sample_action(self)
+    def sample_action(self):
         """Sample an available action from the action space."""
-
+        match self.action_space:
+            case MultiDiscreteSpace() as aspace:
+                return aspace.sample(mask=self.available_actions())
+            case ContinuousSpace() as aspace:
+                return aspace.sample()
+            case DiscreteSpace() as aspace:
+                return np.array([aspace.sample(mask=self.available_actions())])
+        raise NotImplementedError("Action space not supported")
 
     def available_actions(self) -> npt.NDArray[np.bool]:
         """
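`sample_action` now dispatches on the concrete space type and applies the `available_actions()` mask for discrete spaces. A usage sketch with the bundled mock environment; calling `reset()` first and reading `step.reward` are assumptions based on the surrounding code in this diff:

```python
from marlenv import DiscreteMockEnv

env = DiscreteMockEnv(n_agents=2)  # other constructor arguments assumed to keep defaults
env.reset()                        # assumed to be required before stepping
action = env.sample_action()       # masked sampling for (Multi)DiscreteSpace, per the match above
step = env.step(action)            # step() now accepts a Sequence or np.ndarray
print(step.reward)
```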
@@ -147,7 +153,7 @@ class MARLEnv(ABC, Generic[ActionType, ActionSpaceType]):
         raise NotImplementedError("Method not implemented")
 
     @abstractmethod
-    def step(self,
+    def step(self, action: Sequence | np.ndarray) -> Step:
         """Perform a step in the environment.
 
         Returns a Step object that can be unpacked as a 6-tuple containing:
@@ -180,7 +186,7 @@ class MARLEnv(ABC, Generic[ActionType, ActionSpaceType]):
         """Retrieve an image of the environment"""
         raise NotImplementedError("No image available for this environment")
 
-    def replay(self, actions: Sequence
+    def replay(self, actions: Sequence, seed: Optional[int] = None):
         """Replay a sequence of actions."""
         from .episode import Episode # Avoid circular import
 
{multi_agent_rlenv-3.4.0 → multi_agent_rlenv-3.5.1}/src/marlenv/models/episode.py
RENAMED
@@ -1,6 +1,6 @@
 from dataclasses import dataclass
 from functools import cached_property
-from typing import Any, Callable,
+from typing import Any, Callable, Optional, Sequence, overload
 
 import numpy as np
 import numpy.typing as npt
@@ -14,11 +14,8 @@ from .env import MARLEnv
 from marlenv.exceptions import EnvironmentMismatchException, ReplayMismatchException
 
 
-A = TypeVar("A")
-
-
 @dataclass
-class Episode(Generic[A]):
+class Episode:
     """Episode model made of observations, actions, rewards, ..."""
 
     all_observations: list[npt.NDArray[np.float32]]
@@ -55,7 +52,7 @@ class Episode(Generic[A]):
         )
 
     @staticmethod
-    def from_transitions(transitions: Sequence[Transition
+    def from_transitions(transitions: Sequence[Transition]) -> "Episode":
         """Create an episode from a list of transitions"""
         episode = Episode.new(transitions[0].obs, transitions[0].state)
         for transition in transitions:
@@ -214,11 +211,11 @@ class Episode(Generic[A]):
 
     def replay(
         self,
-        env: MARLEnv
+        env: MARLEnv,
         seed: Optional[int] = None,
         *,
-        after_reset: Optional[Callable[[Observation, State, MARLEnv
-        after_step: Optional[Callable[[int, Step, MARLEnv
+        after_reset: Optional[Callable[[Observation, State, MARLEnv], None]] = None,
+        after_step: Optional[Callable[[int, Step, MARLEnv], None]] = None,
     ):
         """
         Replay the episode in the environment (i.e. perform the actions) and assert that the outcomes match.
@@ -243,12 +240,12 @@ class Episode(Generic[A]):
                 raise ReplayMismatchException("observation", step.obs.data, self.next_obs[i], time_step=i)
             if not np.array_equal(step.state.data, self.next_states[i]):
                 raise ReplayMismatchException("state", step.state.data, self.next_states[i], time_step=i)
-            if not np.
+            if not np.isclose(step.reward, self.rewards[i]):
                 raise ReplayMismatchException("reward", step.reward, self.rewards[i], time_step=i)
             if after_step is not None:
                 after_step(i, step, env)
 
-    def get_images(self, env: MARLEnv
+    def get_images(self, env: MARLEnv, seed: Optional[int] = None) -> list[np.ndarray]:
         images = []
 
         def collect_image(*_, **__):
@@ -257,7 +254,7 @@ class Episode(Generic[A]):
         self.replay(env, seed, after_reset=collect_image, after_step=collect_image)
         return images
 
-    def render(self, env: MARLEnv
+    def render(self, env: MARLEnv, seed: Optional[int] = None, fps: int = 5):
         def render_callback(*_, **__):
             env.render()
             cv2.waitKey(1000 // fps)
@@ -288,10 +285,10 @@ class Episode(Generic[A]):
         return returns
 
     @overload
-    def add(self, transition: Transition
+    def add(self, transition: Transition, /): ...
 
     @overload
-    def add(self, step: Step, action:
+    def add(self, step: Step, action: np.ndarray, /): ...
 
     def add(self, *data):
         match data:
@@ -322,10 +319,10 @@ class Episode(Generic[A]):
 
     def add_data(
         self,
-        next_obs,
-        next_state,
-        action:
-        reward: np.
+        next_obs: Observation,
+        next_state: State,
+        action: np.ndarray,
+        reward: npt.NDArray[np.float32],
         others: dict[str, Any],
         done: bool,
         truncated: bool,
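`Episode` is no longer generic over the action type, and `replay`, `get_images` and `render` now take a plain `MARLEnv`. A sketch of the new callback-based replay, assuming `episode` is an existing `Episode` and `env` the environment it was collected on; the callback signatures mirror the `after_reset`/`after_step` annotations in the diff above:

```python
from marlenv import MARLEnv, Observation, State, Step


def log_reset(obs: Observation, state: State, env: MARLEnv):
    print("reset:", type(env).__name__)


def log_step(t: int, step: Step, env: MARLEnv):
    print(f"t={t} reward={step.reward}")


# With a concrete episode and environment at hand:
# episode.replay(env, seed=0, after_reset=log_reset, after_step=log_step)
# frames = episode.get_images(env, seed=0)
```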