multi-agent-rlenv 3.5.4__py3-none-any.whl → 3.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
marlenv/__init__.py CHANGED
@@ -62,7 +62,13 @@ print(env.extras_shape) # (1, )
  If you want to create a new environment, you can simply create a class that inherits from `MARLEnv`. If you want to create a wrapper around an existing `MARLEnv`, you probably want to subclass `RLEnvWrapper` which implements a default behaviour for every method.
  """
 
- __version__ = "3.5.4"
+ from importlib.metadata import version, PackageNotFoundError
+
+ try:
+ __version__ = version("overcooked")
+ except PackageNotFoundError:
+     __version__ = "0.0.0"  # practical fallback for dev/CI
+
 
  from . import models
  from .models import (
@@ -82,16 +88,19 @@ from .models import (
 
  from . import wrappers
  from . import adapters
- from .env_builder import make, Builder
+ from .env_builder import Builder
  from .wrappers import RLEnvWrapper
  from .mock_env import DiscreteMockEnv, DiscreteMOMockEnv
+ from . import catalog
+ from .adapters import make
 
  __all__ = [
      "models",
+     "make",
+     "catalog",
      "wrappers",
      "adapters",
      "spaces",
-     "make",
      "Builder",
      "MARLEnv",
      "Step",
marlenv/adapters/__init__.py CHANGED
@@ -1,42 +1,33 @@
  from importlib.util import find_spec
  from .pymarl_adapter import PymarlAdapter
+ from marlenv.utils import DummyClass, dummy_function
 
- HAS_GYM = False
- if find_spec("gymnasium") is not None:
-     from .gym_adapter import Gym
+ HAS_GYM = find_spec("gymnasium") is not None
+ if HAS_GYM:
+     from .gym_adapter import Gym, make
+ else:
+     Gym = DummyClass("gymnasium")
+     make = dummy_function("gymnasium")
 
-     HAS_GYM = True
-
- HAS_PETTINGZOO = False
- if find_spec("pettingzoo") is not None:
+ HAS_PETTINGZOO = find_spec("pettingzoo") is not None
+ if HAS_PETTINGZOO:
      from .pettingzoo_adapter import PettingZoo
+ else:
+     PettingZoo = DummyClass("pettingzoo")
 
-     HAS_PETTINGZOO = True
-
- HAS_SMAC = False
- if find_spec("smac") is not None:
+ HAS_SMAC = find_spec("smac") is not None
+ if HAS_SMAC:
      from .smac_adapter import SMAC
-
-     HAS_SMAC = True
-
- HAS_OVERCOOKED = False
- if find_spec("overcooked_ai_py") is not None and find_spec("overcooked_ai_py.mdp") is not None:
-     import numpy
-
-     # Overcooked assumes a version of numpy <2.0 where np.Inf is available.
-     setattr(numpy, "Inf", numpy.inf)
-     from .overcooked_adapter import Overcooked
-
-     HAS_OVERCOOKED = True
+ else:
+     SMAC = DummyClass("smac", "https://github.com/oxwhirl/smac.git")
 
  __all__ = [
      "PymarlAdapter",
      "Gym",
+     "make",
      "PettingZoo",
      "SMAC",
-     "Overcooked",
      "HAS_GYM",
      "HAS_PETTINGZOO",
      "HAS_SMAC",
-     "HAS_OVERCOOKED",
  ]
marlenv/adapters/gym_adapter.py CHANGED
@@ -78,3 +78,9 @@ class Gym(MARLEnv[Space]):
 
      def seed(self, seed_value: int):
          self._gym_env.reset(seed=seed_value)
+
+
+ def make(env_id: str, **kwargs):
+     """Make a Gym-backed RLEnv from a registered gymnasium environment id (e.g. "CartPole-v1")."""
+     gym_env = gym.make(env_id, render_mode="rgb_array", **kwargs)
+     return Gym(gym_env)
marlenv/catalog/__init__.py ADDED
@@ -0,0 +1,26 @@
+ from importlib.util import find_spec
+ from ..utils.import_placeholders import DummyClass
+ from marlenv.adapters import SMAC
+ from .deepsea import DeepSea
+
+
+ HAS_LLE = find_spec("lle") is not None
+ if HAS_LLE:
+     from lle import LLE  # pyright: ignore[reportMissingImports]
+ else:
+     LLE = DummyClass("lle", "laser-learning-environment")
+
+ HAS_OVERCOOKED = find_spec("overcooked") is not None
+ if HAS_OVERCOOKED:
+     from overcooked import Overcooked  # pyright: ignore[reportMissingImports]
+ else:
+     Overcooked = DummyClass("overcooked", "overcooked")
+
+ __all__ = [
+     "Overcooked",
+     "SMAC",
+     "LLE",
+     "DeepSea",
+     "HAS_LLE",
+     "HAS_OVERCOOKED",
+ ]
marlenv/catalog/deepsea.py ADDED
@@ -0,0 +1,73 @@
+ from typing import Sequence
+ import numpy as np
+ from marlenv import MARLEnv, MultiDiscreteSpace, DiscreteSpace, Observation, State, Step
+ from dataclasses import dataclass
+
+
+ LEFT = 0
+ RIGHT = 1
+
+
+ @dataclass
+ class DeepSea(MARLEnv[MultiDiscreteSpace]):
+     """
+     Deep Sea single-agent environment to test for deep exploration. The probability of reaching the goal state under random exploration is 2^(-max_depth).
+
+     The agent explores a 2D grid where the bottom-right corner (max_depth, max_depth) is the goal and is the only state to yield a reward.
+     The agent starts in the top-left corner (0, 0).
+     The agent has two actions: left or right, and taking an action makes the agent dive one row deeper. The agent cannot go beyond the grid boundaries.
+     Going right gives a penalty of (0.01 / max_depth).
+     """
+
+     max_depth: int
+
+     def __init__(self, max_depth: int):
+         super().__init__(
+             n_agents=1,
+             action_space=DiscreteSpace(size=2, labels=["left", "right"]).repeat(1),
+             observation_shape=(2,),
+             state_shape=(2,),
+         )
+         self.max_depth = max_depth
+         self._row = 0
+         self._col = 0
+         self._step_right_penalty = -0.01 / self.max_depth
+
+     def get_observation(self) -> Observation:
+         return Observation(np.array([self._row, self._col], dtype=np.float32), self.available_actions())
+
+     def get_state(self) -> State:
+         return State(np.array([self._row, self._col], dtype=np.float32))
+
+     def reset(self):
+         self._row = 0
+         self._col = 0
+         return self.get_observation(), self.get_state()
+
+     def step(self, action: Sequence[int]):
+         self._row += 1
+         if action[0] == LEFT:
+             self._col -= 1
+         else:
+             self._col += 1
+         self._col = max(0, self._col)
+         if action[0] == RIGHT:
+             if self._row == self.max_depth:
+                 reward = 1.0
+             else:
+                 reward = self._step_right_penalty
+         else:
+             reward = 0.0
+         return Step(
+             self.get_observation(),
+             self.get_state(),
+             reward,
+             done=self._row == self.max_depth,
+         )
+
+     def set_state(self, state: State):
+         self._row, self._col = state.data
+
+     @property
+     def agent_state_size(self):
+         return 2
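
As a quick illustration of the `DeepSea` environment added above, here is a minimal random rollout; this is a sketch based only on the `reset()`/`step()` signatures shown in the diff (notably that `step()` returns a `Step` carrying `reward` and `done`).

```python
import numpy as np
from marlenv.catalog import DeepSea

env = DeepSea(max_depth=10)
obs, state = env.reset()
done = False
total_reward = 0.0
while not done:
    # One agent, two discrete actions: 0 = left, 1 = right.
    action = [np.random.randint(0, 2)]
    step = env.step(action)
    total_reward += float(np.sum(step.reward))
    done = step.done

# Under a uniform random policy, the goal reward is obtained with probability 2 ** (-max_depth).
print(total_reward)
```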
marlenv/env_builder.py CHANGED
@@ -1,73 +1,13 @@
  from dataclasses import dataclass
- from typing import Generic, Literal, Optional, TypeVar, overload
+ from typing import Generic, Literal, Optional, TypeVar
  import numpy as np
  import numpy.typing as npt
 
  from . import wrappers
- from marlenv import adapters
  from .models import Space, MARLEnv
 
  AS = TypeVar("AS", bound=Space)
 
- if adapters.HAS_PETTINGZOO:
-     from .adapters import PettingZoo
-     from pettingzoo import ParallelEnv
-
-     @overload
-     def make(env: ParallelEnv) -> PettingZoo: ...
-
-
- if adapters.HAS_GYM:
-     from .adapters import Gym
-     from gymnasium import Env
-     import gymnasium
-
-     @overload
-     def make(env: Env) -> Gym: ...
-
-     @overload
-     def make(env: str, **kwargs) -> Gym:
-         """
-         Make an RLEnv from the `gymnasium` registry (e.g: "CartPole-v1").
-         """
-
-
- if adapters.HAS_SMAC:
-     from .adapters import SMAC
-     from smac.env import StarCraft2Env
-
-     @overload
-     def make(env: StarCraft2Env) -> SMAC: ...
-
-
- if adapters.HAS_OVERCOOKED:
-     from .adapters import Overcooked
-     from overcooked_ai_py.mdp.overcooked_env import OvercookedEnv
-
-     @overload
-     def make(env: OvercookedEnv) -> Overcooked: ...
-
-
- def make(env, **kwargs):
-     """Make an RLEnv from str (Gym) or PettingZoo"""
-     match env:
-         case MARLEnv():
-             return env
-         case str(env_id):
-             if adapters.HAS_GYM:
-                 gym_env = gymnasium.make(env_id, render_mode="rgb_array", **kwargs)
-                 return Gym(gym_env)
-
-     if adapters.HAS_PETTINGZOO and isinstance(env, ParallelEnv):
-         return PettingZoo(env)  # type: ignore
-     if adapters.HAS_SMAC and isinstance(env, StarCraft2Env):
-         return SMAC(env)
-     if adapters.HAS_OVERCOOKED and isinstance(env, OvercookedEnv):
-         return Overcooked(env)  # type: ignore
-     if adapters.HAS_GYM and isinstance(env, Env):
-         return Gym(env)
-     raise ValueError(f"Unknown environment type: {type(env)}")
-
 
  @dataclass
  class Builder(Generic[AS]):
marlenv/models/episode.py CHANGED
@@ -2,20 +2,22 @@ from dataclasses import dataclass
  from functools import cached_property
  from typing import Any, Callable, Optional, Sequence, overload
 
+ import cv2
  import numpy as np
  import numpy.typing as npt
- import cv2
 
+ from marlenv.exceptions import EnvironmentMismatchException, ReplayMismatchException
+ from marlenv.utils import CachedPropertyInvalidator
+
+ from .env import MARLEnv
  from .observation import Observation
  from .state import State
  from .step import Step
  from .transition import Transition
- from .env import MARLEnv
- from marlenv.exceptions import EnvironmentMismatchException, ReplayMismatchException
 
 
  @dataclass
- class Episode:
+ class Episode(CachedPropertyInvalidator):
      """Episode model made of observations, actions, rewards, ..."""
 
      all_observations: list[npt.NDArray[np.float32]]
@@ -153,12 +155,12 @@ class Episode:
          """Get the next extra features"""
          return self.all_extras[1:]
 
-     @cached_property
+     @property
      def n_agents(self):
          """The number of agents in the episode"""
          return self.all_extras[0].shape[0]
 
-     @cached_property
+     @property
      def n_actions(self):
          """The number of actions"""
          return len(self.all_available_actions[0][0])
@@ -267,7 +269,7 @@ class Episode:
      def __len__(self):
          return self.episode_len
 
-     @cached_property
+     @property
      def score(self) -> list[float]:
          """The episode score (sum of all rewards across all objectives)"""
          score = []
marlenv/models/spaces.py CHANGED
@@ -105,7 +105,7 @@ class MultiDiscreteSpace(Space):
      def sample(self, mask: Optional[npt.NDArray[np.bool] | list[npt.NDArray[np.bool]]] = None):
          if mask is None:
              return np.array([space.sample() for space in self.spaces], dtype=np.int32)
-         return np.array([space.sample(mask=mask) for mask, space in zip(mask, self.spaces)], dtype=np.int32)
+         return np.array([space.sample(mask=mask) for mask, space in zip(mask, self.spaces)], dtype=np.int32)  # type: ignore
 
      def __eq__(self, value: object) -> bool:
          if not isinstance(value, MultiDiscreteSpace):
marlenv/utils/__init__.py CHANGED
@@ -1,5 +1,6 @@
- from .schedule import Schedule, MultiSchedule, RoundedSchedule, LinearSchedule, ExpSchedule
-
+ from .cached_property_collector import CachedPropertyCollector, CachedPropertyInvalidator
+ from .schedule import ExpSchedule, LinearSchedule, MultiSchedule, RoundedSchedule, Schedule
+ from .import_placeholders import DummyClass, dummy_function
 
  __all__ = [
      "Schedule",
@@ -7,4 +8,8 @@ __all__ = [
      "ExpSchedule",
      "MultiSchedule",
      "RoundedSchedule",
+     "CachedPropertyCollector",
+     "CachedPropertyInvalidator",
+     "DummyClass",
+     "dummy_function",
  ]
marlenv/utils/cached_property_collector.py ADDED
@@ -0,0 +1,17 @@
+ from functools import cached_property
+
+
+ class CachedPropertyCollector(type):
+     def __init__(cls, name: str, bases: tuple, namespace: dict):
+         super().__init__(name, bases, namespace)
+         cls.CACHED_PROPERTY_NAMES = [key for key, value in namespace.items() if isinstance(value, cached_property)]
+
+
+ class CachedPropertyInvalidator(metaclass=CachedPropertyCollector):
+     def __init__(self):
+         super().__init__()
+
+     def invalidate_cached_properties(self):
+         for key in self.__class__.CACHED_PROPERTY_NAMES:
+             if hasattr(self, key):
+                 delattr(self, key)
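
A small sketch of how this helper is meant to be used; the `Numbers` class below is hypothetical and only illustrates that the metaclass records `cached_property` names so their cached values can be dropped later.

```python
from functools import cached_property

from marlenv.utils import CachedPropertyInvalidator


class Numbers(CachedPropertyInvalidator):
    def __init__(self, values: list[float]):
        super().__init__()
        self.values = values

    @cached_property
    def total(self) -> float:
        # Computed once and cached on the instance until invalidated.
        return sum(self.values)


nums = Numbers([1.0, 2.0])
print(nums.total)  # 3.0, now cached
nums.values.append(4.0)
nums.invalidate_cached_properties()  # deletes the cached values collected by the metaclass
print(nums.total)  # 7.0, recomputed
```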
marlenv/utils/import_placeholders.py ADDED
@@ -0,0 +1,30 @@
+ from typing import Optional, Any
+
+
+ class DummyClass:
+     def __init__(self, module_name: str, package_name: Optional[str] = None):
+         self.module_name = module_name
+         if package_name is None:
+             self.package_name = module_name
+         else:
+             self.package_name = package_name
+
+     def _raise_error(self):
+         raise ImportError(
+             f"The optional dependency `{self.module_name}` is not installed.\nInstall the `{self.package_name}` package (e.g. pip install {self.package_name})."
+         )
+
+     def __getattr__(self, _):
+         self._raise_error()
+
+     def __call__(self, *args, **kwargs):
+         self._raise_error()
+
+
+ def dummy_function(module_name: str, package_name: Optional[str] = None):
+     dummy = DummyClass(module_name, package_name)
+
+     def fail(*args, **kwargs) -> Any:
+         dummy()
+
+     return fail
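
The intended behaviour of these placeholders, sketched below: constructing them always succeeds, and the `ImportError` is only raised when the missing optional dependency is actually used.

```python
from marlenv.utils import DummyClass, dummy_function

# Placeholders standing in for optional dependencies that are not installed.
SMAC = DummyClass("smac", "https://github.com/oxwhirl/smac.git")
make = dummy_function("gymnasium")

try:
    SMAC("3m")  # calling the placeholder class raises immediately
except ImportError as error:
    print(error)

try:
    make("CartPole-v1")  # calling the placeholder function raises as well
except ImportError as error:
    print(error)
```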
marlenv/utils/schedule.py CHANGED
@@ -1,6 +1,6 @@
  from abc import abstractmethod
  from dataclasses import dataclass
- from typing import Callable, Optional, TypeVar
+ from typing import Any, Callable, Optional, TypeVar
 
  T = TypeVar("T")
 
@@ -142,6 +142,21 @@ class Schedule:
      def __int__(self) -> int:
          return int(self.value)
 
+     @staticmethod
+     def from_json(data: dict[str, Any]):
+         """Create a Schedule from a JSON-like dictionary."""
+         classname = data.get("name")
+         if classname == "LinearSchedule":
+             return LinearSchedule(data["start_value"], data["end_value"], data["n_steps"])
+         elif classname == "ExpSchedule":
+             return ExpSchedule(data["start_value"], data["end_value"], data["n_steps"])
+         elif classname == "ConstantSchedule":
+             return ConstantSchedule(data["value"])
+         elif classname == "ArbitrarySchedule":
+             raise NotImplementedError("ArbitrarySchedule cannot be deserialized from JSON")
+         else:
+             raise ValueError(f"Unknown schedule type: {classname}")
+
 
  @dataclass(eq=False)
  class LinearSchedule(Schedule):
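
A minimal sketch of the new `Schedule.from_json` helper, using only the dictionary keys the method reads (`name`, `start_value`, `end_value`, `n_steps`):

```python
from marlenv.utils import Schedule

# A JSON-like dict describing a linear schedule from 1.0 down to 0.05 over 10_000 steps.
data = {
    "name": "LinearSchedule",
    "start_value": 1.0,
    "end_value": 0.05,
    "n_steps": 10_000,
}
schedule = Schedule.from_json(data)
print(schedule.value)  # the schedule starts at `start_value` and moves towards `end_value`
```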
marlenv/wrappers/__init__.py CHANGED
@@ -11,6 +11,7 @@ from .centralised import Centralized
  from .available_actions_mask import AvailableActionsMask
  from .delayed_rewards import DelayedReward
  from .potential_shaping import PotentialShaping
+ from .action_randomizer import ActionRandomizer
 
  __all__ = [
      "RLEnvWrapper",
@@ -28,4 +29,5 @@ __all__ = [
      "Centralized",
      "DelayedReward",
      "PotentialShaping",
+     "ActionRandomizer",
  ]
marlenv/wrappers/action_randomizer.py ADDED
@@ -0,0 +1,17 @@
+ from .rlenv_wrapper import RLEnvWrapper, AS, MARLEnv
+ import numpy as np
+
+
+ class ActionRandomizer(RLEnvWrapper[AS]):
+     def __init__(self, env: MARLEnv[AS], p: float):
+         super().__init__(env)
+         self.p = p
+
+     def step(self, action):
+         if np.random.rand() < self.p:
+             action = self.action_space.sample()
+         return super().step(action)
+
+     def seed(self, seed_value: int):
+         np.random.seed(seed_value)
+         super().seed(seed_value)
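
A brief usage sketch of the new `ActionRandomizer` wrapper, here wrapping the `DeepSea` catalog environment; it assumes the wrapper forwards `reset()` and `action_space` from the wrapped environment, as `RLEnvWrapper`'s default behaviour suggests.

```python
import numpy as np
from marlenv.catalog import DeepSea
from marlenv.wrappers import ActionRandomizer

# With probability p, the chosen action is replaced by a random sample from the
# action space before being forwarded to the wrapped environment.
env = ActionRandomizer(DeepSea(max_depth=10), p=0.1)
obs, state = env.reset()
step = env.step(np.array([1]))  # try to go right; 10% chance of a random action instead
print(step.reward, step.done)
```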
multi_agent_rlenv-3.6.0.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: multi-agent-rlenv
- Version: 3.5.4
+ Version: 3.6.0
  Summary: A strongly typed Multi-Agent Reinforcement Learning framework
  Project-URL: repository, https://github.com/yamoling/multi-agent-rlenv
  Author-email: Yannick Molinghen <yannick.molinghen@ulb.be>
@@ -13,7 +13,8 @@ Requires-Dist: opencv-python>=4.0
  Requires-Dist: typing-extensions>=4.0
  Provides-Extra: all
  Requires-Dist: gymnasium>0.29.1; extra == 'all'
- Requires-Dist: overcooked-ai; extra == 'all'
+ Requires-Dist: laser-learning-environment>=2.6.1; extra == 'all'
+ Requires-Dist: overcooked>=0.1.0; extra == 'all'
  Requires-Dist: pettingzoo>=1.20; extra == 'all'
  Requires-Dist: pymunk>=6.0; extra == 'all'
  Requires-Dist: pysc2; extra == 'all'
@@ -22,9 +23,10 @@ Requires-Dist: smac; extra == 'all'
  Requires-Dist: torch>=2.0; extra == 'all'
  Provides-Extra: gym
  Requires-Dist: gymnasium>=0.29.1; extra == 'gym'
+ Provides-Extra: lle
+ Requires-Dist: laser-learning-environment>=2.6.1; extra == 'lle'
  Provides-Extra: overcooked
- Requires-Dist: overcooked-ai>=1.1.0; extra == 'overcooked'
- Requires-Dist: scipy>=1.10; extra == 'overcooked'
+ Requires-Dist: overcooked>=0.1.0; extra == 'overcooked'
  Provides-Extra: pettingzoo
  Requires-Dist: pettingzoo>=1.20; extra == 'pettingzoo'
  Requires-Dist: pymunk>=6.0; extra == 'pettingzoo'
@@ -62,9 +64,24 @@ $ pip install marlenv[smac] # Install SMAC
  $ pip install marlenv[gym,smac] # Install Gym & smac support
  ```
 
+ ## Using the `marlenv` environment catalog
+ Some environments are registered in `marlenv` and can be easily instantiated via its catalog.
+
+ ```python
+ from marlenv import catalog
+
+ env1 = catalog.Overcooked.from_layout("scenario4")
+ env2 = catalog.LLE.level(6)
+ env3 = catalog.DeepSea(max_depth=5)
+ ```
+ Note that using the catalog requires the corresponding environment package to be installed. For instance, you need to install the `laser-learning-environment` package to use `catalog.LLE`, which can be done by selecting the corresponding extra at installation, as shown below.
+ ```bash
+ pip install multi-agent-rlenv[lle]
+ ```
+
 
  ## Using `marlenv` with existing libraries
- `marlenv` unifies multiple popular libraries under a single interface. Namely, `marlenv` supports `smac`, `gymnasium` and `pettingzoo`.
+ `marlenv` provides adapters for the most popular libraries to unify them under a single interface. Namely, `marlenv` supports `smac`, `gymnasium` and `pettingzoo`.
 
  ```python
  import marlenv
@@ -74,7 +91,7 @@ gym_env = marlenv.make("CartPole-v1", seed=25)
 
  # You can seemlessly instanciate a SMAC environment and directly pass your required arguments
  from marlenv.adapters import SMAC
- smac_env = env2 = SMAC("3m", debug=True, difficulty="9")
+ smac_env = SMAC("3m", debug=True, difficulty="9")
 
  # pettingzoo is also supported
  from pettingzoo.sisl import pursuit_v4
multi_agent_rlenv-3.6.0.dist-info/RECORD CHANGED
@@ -1,26 +1,30 @@
- marlenv/__init__.py,sha256=-XMj91_t4PoY6zLhr395u_svfIXYlvk-tMR4YqG-2VQ,3656
- marlenv/env_builder.py,sha256=RJoHJLYAUE1ausAoJiRC3fUxyxpH1WRJf7Sdm2ml-uk,5517
+ marlenv/__init__.py,sha256=MJgaW73zWYJKTNMWE8V3hTvrcMk-WEX3RaG-K_oIDD8,3886
+ marlenv/env_builder.py,sha256=RUMFvW7dAJtHMLm8-oPVpjBefDtNliZtjlHci97Xj-Q,3874
  marlenv/env_pool.py,sha256=nCEBkGQU62fcvCAANyAqY8gCFjYlVnSCg-V3Fhx00yc,933
  marlenv/exceptions.py,sha256=gJUC_2rVAvOfK_ypVFc7Myh-pIfSU3To38VBVS_0rZA,1179
  marlenv/mock_env.py,sha256=kKvTdZl4_xSTTI9V6otZ1P709sfPYrqZSbbZaTip9iI,4684
  marlenv/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- marlenv/adapters/__init__.py,sha256=rWiqQOqTx3kVL5ZkPo3rkczrlQBBhQbU55zGI26SEeY,929
- marlenv/adapters/gym_adapter.py,sha256=Rx8ZnHW0XEwZzRT53BdDP1f4CtNp2tATAYZ0BbtBbd0,2863
- marlenv/adapters/overcooked_adapter.py,sha256=0-5sfKHGjmM4eYtbydsENMXV2Qx4WVGlvJl7MFZSaj8,9123
+ marlenv/adapters/__init__.py,sha256=wJzd94WfNNFX-yojr2M1dExVAGkFqwM2fieQ-v7uH4s,783
+ marlenv/adapters/gym_adapter.py,sha256=5HZF3g0QD4n7K4GQoMis4q0zj97uFTLdzdxMYHzM_UE,3041
  marlenv/adapters/pettingzoo_adapter.py,sha256=w9Ta-X4L_6ZXdDGmREOdcU0vpLR8lGP__s49DyK3dk8,2852
  marlenv/adapters/pymarl_adapter.py,sha256=2s7EY31s1hrml3q-BBaXo_eDMXTjkebozZPvzsgrb9c,3353
  marlenv/adapters/smac_adapter.py,sha256=8uWC7YKsaSXeTS8AUhpGOKvrWMbVEQT2-pml5BaFUB0,8343
+ marlenv/catalog/__init__.py,sha256=r5iAuIqM5UZkTNNu4cMc0hmjlgZ74OXcxL3aM15dnHw,655
+ marlenv/catalog/deepsea.py,sha256=VPRV6tjPIkj-TMqp2L8U9COUK3FBlodoRW6gpsnwZ9Y,2332
  marlenv/models/__init__.py,sha256=uihmRs71Gg5z7Bvau_xtaQVg7xEtX8sTzi74bIHL5P0,443
  marlenv/models/env.py,sha256=BG1iVHxGD_p827mF0ewyOBn6wU2gtFsHLW1b4UtW-V0,7841
- marlenv/models/episode.py,sha256=DOX2FpWK-wm0BX_vC2bVD4BUMiP-JHUfRl80MGXD7kc,13472
+ marlenv/models/episode.py,sha256=zsyxsW4LIioPKyY4DZKn64A31e5ZvlwOf3HIGuRUzhs,13531
  marlenv/models/observation.py,sha256=RhvKvmys4bu3UwwVsvu7fJ7TMKt2QkKnBD1e0hw2r7s,3528
- marlenv/models/spaces.py,sha256=v7jnhPfj7vq7DFFJpSbQEYe4NGLLlj_bj2pzvvSBX7Y,7777
+ marlenv/models/spaces.py,sha256=1aPmTcoOTU9nlwlcN7aswNrORwghOYAGqCLAMpk39SA,7793
  marlenv/models/state.py,sha256=LbP--JxBzRwMFpEAaZyxCX13xKQ27xPE2fabohaq9YI,2058
  marlenv/models/step.py,sha256=00PhD_ccdCIYAY1SVJdJU91weU0Y_tNIJwK16TN_53I,3056
  marlenv/models/transition.py,sha256=UkJVRNxZoyRkjE7YmKtUf_4xA7cOEh20O60dTldbvys,5070
- marlenv/utils/__init__.py,sha256=C3qhvkVwctBP8mG3G5nkAZ5DKfErVRkdbHo7oeWVsM0,209
- marlenv/utils/schedule.py,sha256=slhtpQiBHSUNyPmSkKb2yBgiHJqPhoPxa33GxvyV8Jc,8565
- marlenv/wrappers/__init__.py,sha256=uV00m0jysZBgOW-TvRekis-gsAKPeR51P3HsuRZKxG8,880
+ marlenv/utils/__init__.py,sha256=udb1AhuX6cdcIEaGpE3E1U4Xlo7DtY6f8KCnFgdgfz0,462
+ marlenv/utils/cached_property_collector.py,sha256=IOjbr61f0DqLhcidXKrl7MhN1BOEGiTzCANIKQCxaF0,600
+ marlenv/utils/import_placeholders.py,sha256=qKp-4YZFGcqoEAdblroify_lfX8iuH_ot4AOJkfpvPg,860
+ marlenv/utils/schedule.py,sha256=BdjefYgAtGlh1wWGHENid4WNnPOU78kkNiRvR5A9GEA,9308
+ marlenv/wrappers/__init__.py,sha256=Z4_M-mxRNKQeu52tkmQ4B2m3-zrsmjfXXL5NsWQ4vu4,952
+ marlenv/wrappers/action_randomizer.py,sha256=A1kejqGOTA0sc_RQL0EOd6sMSbcIdiV5zlscjKUlzdY,474
  marlenv/wrappers/agent_id_wrapper.py,sha256=9qHV3LMQ4AjcDCSuvQhz5h9hUf7Xtrdi2sIxmNZk5NA,1126
  marlenv/wrappers/available_actions_mask.py,sha256=OMyt2KntsR8JA2RuRgvwdzqzPe-_H-KKkbUUJfe_mks,1404
  marlenv/wrappers/available_actions_wrapper.py,sha256=_HRl9zsjJgSrLgVuT-BjpnnfrfM8ic6wBUWlg67uCx4,926
@@ -34,7 +38,7 @@ marlenv/wrappers/potential_shaping.py,sha256=T_QvnmWReCgpyoInxRw2UXbmdvcBD5U-vV1
  marlenv/wrappers/rlenv_wrapper.py,sha256=S6G1VjFklTEzU6bj0AXrTDXnsTQJARq8VB4uUH6AXe4,2993
  marlenv/wrappers/time_limit.py,sha256=GxbxcbfFyuVg14ylQU2C_cjmV9q4uDAt5wepfgX_PyM,3976
  marlenv/wrappers/video_recorder.py,sha256=ucBQSNRPqDr-2mYxrTCqlrWcxSWtSJ7XlRC9-LdukBM,2535
- multi_agent_rlenv-3.5.4.dist-info/METADATA,sha256=xgYpo8yrpncxWufHZSqMIeeT3Bb8Ll9DylTRKyesdT4,5005
- multi_agent_rlenv-3.5.4.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
- multi_agent_rlenv-3.5.4.dist-info/licenses/LICENSE,sha256=_eeiGVoIJ7kYt6l1zbIvSBQppTnw0mjnYk1lQ4FxEjE,1074
- multi_agent_rlenv-3.5.4.dist-info/RECORD,,
+ multi_agent_rlenv-3.6.0.dist-info/METADATA,sha256=J9gFMccQ88-4Xh-DP80wOGehHc19l-i3BLI2zabFG9A,5751
+ multi_agent_rlenv-3.6.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+ multi_agent_rlenv-3.6.0.dist-info/licenses/LICENSE,sha256=_eeiGVoIJ7kYt6l1zbIvSBQppTnw0mjnYk1lQ4FxEjE,1074
+ multi_agent_rlenv-3.6.0.dist-info/RECORD,,
marlenv/adapters/overcooked_adapter.py DELETED
@@ -1,241 +0,0 @@
- import sys
- from dataclasses import dataclass
- from typing import Literal, Sequence, Optional
- from copy import deepcopy
-
- import cv2
- import numpy as np
- import numpy.typing as npt
- import pygame
- from marlenv.models import ContinuousSpace, DiscreteSpace, MARLEnv, Observation, State, Step, MultiDiscreteSpace
- from marlenv.utils import Schedule
-
- from overcooked_ai_py.mdp.overcooked_env import OvercookedEnv
- from overcooked_ai_py.mdp.overcooked_mdp import Action, OvercookedGridworld, OvercookedState
- from overcooked_ai_py.visualization.state_visualizer import StateVisualizer
-
-
- @dataclass
- class Overcooked(MARLEnv[MultiDiscreteSpace]):
-     horizon: int
-     shaping_factor: Schedule
-
-     def __init__(
-         self,
-         oenv: OvercookedEnv,
-         shaping_factor: float | Schedule = 1.0,
-         name_suffix: Optional[str] = None,
-     ):
-         if isinstance(shaping_factor, (int, float)):
-             shaping_factor = Schedule.constant(shaping_factor)
-         self.shaping_factor = shaping_factor
-         self._oenv = oenv
-         assert isinstance(oenv.mdp, OvercookedGridworld)
-         self._mdp = oenv.mdp
-         self._visualizer = StateVisualizer()
-         width, height, layers = tuple(self._mdp.lossless_state_encoding_shape)
-         # -1 because we extract the "urgent" layer to the extras
-         shape = (int(layers - 1), int(width), int(height))
-         super().__init__(
-             n_agents=self._mdp.num_players,
-             action_space=DiscreteSpace(Action.NUM_ACTIONS, labels=[Action.ACTION_TO_CHAR[a] for a in Action.ALL_ACTIONS]).repeat(
-                 self._mdp.num_players
-             ),
-             observation_shape=shape,
-             extras_shape=(2,),
-             extras_meanings=["timestep", "urgent"],
-             state_shape=shape,
-             state_extra_shape=(2,),
-             reward_space=ContinuousSpace.from_shape(1),
-         )
-         self.horizon = int(self._oenv.horizon)
-         if name_suffix is not None:
-             self.name = f"{self.name}-{name_suffix}"
-
-     @property
-     def state(self) -> OvercookedState:
-         """Current state of the environment"""
-         return self._oenv.state
-
-     def set_state(self, state: State):
-         raise NotImplementedError("Not yet implemented")
-
-     @property
-     def time_step(self):
-         return self.state.timestep
-
-     def _state_data(self):
-         players_layers = self._mdp.lossless_state_encoding(self.state)
-         state = np.array(players_layers, dtype=np.float32)
-         # Use axes (agents, channels, height, width) instead of (agents, height, width, channels)
-         state = np.transpose(state, (0, 3, 1, 2))
-         # The last last layer is for "urgency", put it in the extras
-         urgency = float(np.all(state[:, -1]))
-         state = state[:, :-1]
-         return state, urgency
-
-     def get_state(self):
-         data, is_urgent = self._state_data()
-         return State(data[0], np.array([self.time_step / self.horizon, is_urgent], dtype=np.float32))
-
-     def get_observation(self) -> Observation:
-         data, is_urgent = self._state_data()
-         return Observation(
-             data=data,
-             available_actions=self.available_actions(),
-             extras=np.array([[self.time_step / self.horizon, is_urgent]] * self.n_agents, dtype=np.float32),
-         )
-
-     def available_actions(self):
-         available_actions = np.full((self.n_agents, self.n_actions), False)
-         actions = self._mdp.get_actions(self._oenv.state)
-         for agent_num, agent_actions in enumerate(actions):
-             for action in agent_actions:
-                 available_actions[agent_num, Action.ACTION_TO_INDEX[action]] = True
-         return np.array(available_actions, dtype=np.bool)
-
-     def step(self, actions: Sequence[int] | np.ndarray) -> Step:
-         self.shaping_factor.update()
-         actions = [Action.ALL_ACTIONS[a] for a in actions]
-         _, reward, done, info = self._oenv.step(actions, display_phi=True)
-
-         reward += sum(info["shaped_r_by_agent"]) * self.shaping_factor
-         return Step(
-             obs=self.get_observation(),
-             state=self.get_state(),
-             reward=np.array([reward], dtype=np.float32),
-             done=done,
-             truncated=False,
-             info=info,
-         )
-
-     def reset(self):
-         self._oenv.reset()
-         return self.get_observation(), self.get_state()
-
-     def __deepcopy__(self, _):
-         """
-         Note: a specific implementation is needed because `pygame.font.Font` objects are not deep-copiable by default.
-         """
-         mdp = deepcopy(self._mdp)
-         copy = Overcooked(OvercookedEnv.from_mdp(mdp, horizon=self.horizon), deepcopy(self.shaping_factor))
-         copy.name = self.name
-         return copy
-
-     def __getstate__(self):
-         return {"horizon": self.horizon, "mdp": self._mdp, "name": self.name, "schedule": self.shaping_factor}
-
-     def __setstate__(self, state: dict):
-         from overcooked_ai_py.mdp.overcooked_mdp import Recipe
-
-         mdp = state["mdp"]
-         Recipe.configure(mdp.recipe_config)
-         self.__init__(OvercookedEnv.from_mdp(state["mdp"], horizon=state["horizon"]), shaping_factor=state["schedule"])
-         self.name = state["name"]
-
-     def get_image(self):
-         rewards_dict = {}  # dictionary of details you want rendered in the UI
-         for key, value in self._oenv.game_stats.items():
-             if key in [
-                 "cumulative_shaped_rewards_by_agent",
-                 "cumulative_sparse_rewards_by_agent",
-             ]:
-                 rewards_dict[key] = value
-
-         image = self._visualizer.render_state(
-             state=self._oenv.state,
-             grid=self._mdp.terrain_mtx,
-             hud_data=StateVisualizer.default_hud_data(self._oenv.state, **rewards_dict),
-         )
-
-         image = pygame.surfarray.array3d(image)
-         image = np.flip(np.rot90(image, 3), 1)
-         # Depending on the platform, the image may need to be converted to RGB
-         if sys.platform in ("linux", "linux2"):
-             image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
-         return image
-
-     @staticmethod
-     def from_layout(
-         layout: Literal[
-             "asymmetric_advantages",
-             "asymmetric_advantages_tomato",
-             "bonus_order_test",
-             "bottleneck",
-             "centre_objects",
-             "centre_pots",
-             "coordination_ring",
-             "corridor",
-             "counter_circuit",
-             "counter_circuit_o_1order",
-             "cramped_corridor",
-             "cramped_room",
-             "cramped_room_o_3orders",
-             "cramped_room_single",
-             "cramped_room_tomato",
-             "five_by_five",
-             "forced_coordination",
-             "forced_coordination_tomato",
-             "inverse_marshmallow_experiment",
-             "large_room",
-             "long_cook_time",
-             "marshmallow_experiment_coordination",
-             "marshmallow_experiment",
-             "mdp_test",
-             "m_shaped_s",
-             "multiplayer_schelling",
-             "pipeline",
-             "scenario1_s",
-             "scenario2",
-             "scenario2_s",
-             "scenario3",
-             "scenario4",
-             "schelling",
-             "schelling_s",
-             "simple_o",
-             "simple_o_t",
-             "simple_tomato",
-             "small_corridor",
-             "soup_coordination",
-             "tutorial_0",
-             "tutorial_1",
-             "tutorial_2",
-             "tutorial_3",
-             "unident",
-             "you_shall_not_pass",
-         ],
-         horizon: int = 400,
-         reward_shaping_factor: float | Schedule = 1.0,
-     ):
-         mdp = OvercookedGridworld.from_layout_name(layout)
-         return Overcooked(OvercookedEnv.from_mdp(mdp, horizon=horizon, info_level=0), reward_shaping_factor, layout)
-
-     @staticmethod
-     def from_grid(
-         grid: Sequence[Sequence[Literal["S", "P", "X", "O", "D", "T", "1", "2", " "] | str]],
-         horizon: int = 400,
-         shaping_factor: float | Schedule = 1.0,
-         layout_name: Optional[str] = None,
-     ):
-         """
-         Create an Overcooked environment from a grid layout where
-         - S is a serving location
-         - P is a cooking pot
-         - X is a counter
-         - O is an onion dispenser
-         - D is a dish dispenser
-         - T is a tomato dispenser
-         - 1 is a player 1 starting location
-         - 2 is a player 2 starting location
-         - ' ' is a walkable space
-
-         If provided, `custom_name` is added to the environment name.
-         """
-         # It is necessary to add an explicit layout name because Overcooked saves some files under this
-         # name. By default the name is a concatenation of the grid elements, which may include characters
-         # such as white spaces, pipes ('|') and square brackets ('[' and ']') that are invalid Windows file paths.
-         if layout_name is None:
-             layout_name = "custom-layout"
-         mdp = OvercookedGridworld.from_grid(grid, base_layout_params={"layout_name": layout_name})
-
-         return Overcooked(OvercookedEnv.from_mdp(mdp, horizon=horizon, info_level=0), shaping_factor, layout_name)
- return Overcooked(OvercookedEnv.from_mdp(mdp, horizon=horizon, info_level=0), shaping_factor, layout_name)