multi-agent-rlenv 3.5.5__py3-none-any.whl → 3.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
marlenv/__init__.py CHANGED
@@ -62,7 +62,13 @@ print(env.extras_shape) # (1, )
62
62
  If you want to create a new environment, you can simply create a class that inherits from `MARLEnv`. If you want to create a wrapper around an existing `MARLEnv`, you probably want to subclass `RLEnvWrapper` which implements a default behaviour for every method.
63
63
  """
64
64
 
65
- __version__ = "3.5.5"
65
+ from importlib.metadata import version, PackageNotFoundError
66
+
67
+ try:
68
+ __version__ = version("multi-agent-rlenv")
69
+ except PackageNotFoundError:
70
+ __version__ = "0.0.0" # fallback pratique en dev/CI
71
+
66
72
 
67
73
  from . import models
68
74
  from .models import (
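The docstring above points to two extension routes. Below is a minimal sketch of the wrapper route; it assumes `RLEnvWrapper`'s constructor takes the wrapped environment, that `Step.reward` can be rescaled in place, and that `DiscreteMockEnv` is default-constructible (the wrapper name and factor are purely illustrative):

```python
from marlenv import RLEnvWrapper
from marlenv.mock_env import DiscreteMockEnv


class RewardScaler(RLEnvWrapper):
    """Hypothetical wrapper that rescales the reward of every step."""

    def __init__(self, env, factor: float = 0.1):
        super().__init__(env)  # assumed: RLEnvWrapper stores the wrapped env
        self.factor = factor

    def step(self, actions):
        step = super().step(actions)  # default behaviour: delegate to the wrapped env
        step.reward = step.reward * self.factor  # assumed: Step.reward is writable
        return step


env = RewardScaler(DiscreteMockEnv())
```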
@@ -82,16 +88,19 @@ from .models import (
82
88
 
83
89
  from . import wrappers
84
90
  from . import adapters
85
- from .env_builder import make, Builder
91
+ from .env_builder import Builder
86
92
  from .wrappers import RLEnvWrapper
87
93
  from .mock_env import DiscreteMockEnv, DiscreteMOMockEnv
94
+ from . import catalog
95
+ from .adapters import make
88
96
 
89
97
  __all__ = [
90
98
  "models",
99
+ "make",
100
+ "catalog",
91
101
  "wrappers",
92
102
  "adapters",
93
103
  "spaces",
94
- "make",
95
104
  "Builder",
96
105
  "MARLEnv",
97
106
  "Step",
marlenv/adapters/__init__.py CHANGED
@@ -1,42 +1,33 @@
1
1
  from importlib.util import find_spec
2
2
  from .pymarl_adapter import PymarlAdapter
3
+ from marlenv.utils import DummyClass, dummy_function
3
4
 
4
- HAS_GYM = False
5
- if find_spec("gymnasium") is not None:
6
- from .gym_adapter import Gym
5
+ HAS_GYM = find_spec("gymnasium") is not None
6
+ if HAS_GYM:
7
+ from .gym_adapter import Gym, make
8
+ else:
9
+ Gym = DummyClass("gymnasium")
10
+ make = dummy_function("gymnasium")
7
11
 
8
- HAS_GYM = True
9
-
10
- HAS_PETTINGZOO = False
11
- if find_spec("pettingzoo") is not None:
12
+ HAS_PETTINGZOO = find_spec("pettingzoo") is not None
13
+ if HAS_PETTINGZOO:
12
14
  from .pettingzoo_adapter import PettingZoo
15
+ else:
16
+ PettingZoo = DummyClass("pettingzoo")
13
17
 
14
- HAS_PETTINGZOO = True
15
-
16
- HAS_SMAC = False
17
- if find_spec("smac") is not None:
18
+ HAS_SMAC = find_spec("smac") is not None
19
+ if HAS_SMAC:
18
20
  from .smac_adapter import SMAC
19
-
20
- HAS_SMAC = True
21
-
22
- HAS_OVERCOOKED = False
23
- if find_spec("overcooked_ai_py") is not None and find_spec("overcooked_ai_py.mdp") is not None:
24
- import numpy
25
-
26
- # Overcooked assumes a version of numpy <2.0 where np.Inf is available.
27
- setattr(numpy, "Inf", numpy.inf)
28
- from .overcooked_adapter import Overcooked
29
-
30
- HAS_OVERCOOKED = True
21
+ else:
22
+ SMAC = DummyClass("smac", "https://github.com/oxwhirl/smac.git")
31
23
 
32
24
  __all__ = [
33
25
  "PymarlAdapter",
34
26
  "Gym",
27
+ "make",
35
28
  "PettingZoo",
36
29
  "SMAC",
37
- "Overcooked",
38
30
  "HAS_GYM",
39
31
  "HAS_PETTINGZOO",
40
32
  "HAS_SMAC",
41
- "HAS_OVERCOOKED",
42
33
  ]
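With the placeholder pattern above, importing `marlenv.adapters` no longer fails when an optional backend is missing; the `ImportError` is only raised when the placeholder is actually used. A small sketch of that behaviour (the `"3m"` map name is only illustrative):

```python
from marlenv.adapters import SMAC, HAS_SMAC

if not HAS_SMAC:
    try:
        SMAC("3m")  # the DummyClass placeholder raises only when called or accessed
    except ImportError as err:
        print(err)  # tells the user which package to install
```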
marlenv/adapters/gym_adapter.py CHANGED
@@ -78,3 +78,9 @@ class Gym(MARLEnv[Space]):
78
78
 
79
79
  def seed(self, seed_value: int):
80
80
  self._gym_env.reset(seed=seed_value)
81
+
82
+
83
+ def make(env_id: str, **kwargs):
84
+ """Make an RLEnv from str (Gym) or PettingZoo"""
85
+ gym_env = gym.make(env_id, render_mode="rgb_array", **kwargs)
86
+ return Gym(gym_env)
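For reference, a hedged usage sketch of this new helper; it assumes `gymnasium` is installed, that `"CartPole-v1"` is registered, and that `reset` returns an observation/state pair as elsewhere in `marlenv`:

```python
from marlenv.adapters import make

env = make("CartPole-v1")  # wraps the gymnasium env in the Gym adapter
obs, state = env.reset()
```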
marlenv/catalog/__init__.py ADDED
@@ -0,0 +1,26 @@
1
+ from importlib.util import find_spec
2
+ from ..utils.import_placeholders import DummyClass
3
+ from marlenv.adapters import SMAC
4
+ from .deepsea import DeepSea
5
+
6
+
7
+ HAS_LLE = find_spec("lle") is not None
8
+ if HAS_LLE:
9
+ from lle import LLE # pyright: ignore[reportMissingImports]
10
+ else:
11
+ LLE = DummyClass("lle", "laser-learning-environment")
12
+
13
+ HAS_OVERCOOKED = find_spec("overcooked") is not None
14
+ if HAS_OVERCOOKED:
15
+ from overcooked import Overcooked # pyright: ignore[reportMissingImports]
16
+ else:
17
+ Overcooked = DummyClass("overcooked", "overcooked")
18
+
19
+ __all__ = [
20
+ "Overcooked",
21
+ "SMAC",
22
+ "LLE",
23
+ "DeepSea",
24
+ "HAS_LLE",
25
+ "HAS_OVERCOOKED",
26
+ ]
marlenv/catalog/deepsea.py ADDED
@@ -0,0 +1,73 @@
1
+ from typing import Sequence
2
+ import numpy as np
3
+ from marlenv import MARLEnv, MultiDiscreteSpace, DiscreteSpace, Observation, State, Step
4
+ from dataclasses import dataclass
5
+
6
+
7
+ LEFT = 0
8
+ RIGHT = 1
9
+
10
+
11
+ @dataclass
12
+ class DeepSea(MARLEnv[MultiDiscreteSpace]):
13
+ """
14
+ Deep Sea single-agent environment to test for deep exploration. The probability of reaching the goal state under random exploration is 2^(-max_depth).
15
+
16
+ The agent explores a 2D grid where the bottom-right corner (max_depth, max_depth) is the goal and is the only state to yield a reward.
17
+ The agent starts in the top-left corner (0, 0).
18
+ The agent has two actions, left and right, and taking an action makes the agent dive one row deeper. The agent cannot go beyond the grid boundaries.
19
+ Going right gives a penalty of (0.01 / max_depth).
20
+ """
21
+
22
+ max_depth: int
23
+
24
+ def __init__(self, max_depth: int):
25
+ super().__init__(
26
+ n_agents=1,
27
+ action_space=DiscreteSpace(size=2, labels=["left", "right"]).repeat(1),
28
+ observation_shape=(2,),
29
+ state_shape=(2,),
30
+ )
31
+ self.max_depth = max_depth
32
+ self._row = 0
33
+ self._col = 0
34
+ self._step_right_penalty = -0.01 / self.max_depth
35
+
36
+ def get_observation(self) -> Observation:
37
+ return Observation(np.array([self._row, self._col], dtype=np.float32), self.available_actions())
38
+
39
+ def get_state(self) -> State:
40
+ return State(np.array([self._row, self._col], dtype=np.float32))
41
+
42
+ def reset(self):
43
+ self._row = 0
44
+ self._col = 0
45
+ return self.get_observation(), self.get_state()
46
+
47
+ def step(self, action: Sequence[int]):
48
+ self._row += 1
49
+ if action[0] == LEFT:
50
+ self._col -= 1
51
+ else:
52
+ self._col += 1
53
+ self._col = max(0, self._col)
54
+ if action[0] == RIGHT:
55
+ if self._row == self.max_depth:
56
+ reward = 1.0
57
+ else:
58
+ reward = self._step_right_penalty
59
+ else:
60
+ reward = 0.0
61
+ return Step(
62
+ self.get_observation(),
63
+ self.get_state(),
64
+ reward,
65
+ done=self._row == self.max_depth,
66
+ )
67
+
68
+ def set_state(self, state: State):
69
+ self._row, self._col = state.data
70
+
71
+ @property
72
+ def agent_state_size(self):
73
+ return 2
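A short usage sketch of `DeepSea` under a random policy; it assumes `Step` exposes `reward` and `done` attributes matching the constructor arguments used above:

```python
import numpy as np
from marlenv.catalog import DeepSea

env = DeepSea(max_depth=10)
obs, state = env.reset()
done = False
total_reward = 0.0
while not done:
    step = env.step([np.random.randint(2)])  # random LEFT/RIGHT action
    total_reward += step.reward
    done = step.done
# A random policy only reaches the +1 goal with probability 2**-max_depth.
print(total_reward)
```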
marlenv/env_builder.py CHANGED
@@ -1,73 +1,13 @@
1
1
  from dataclasses import dataclass
2
- from typing import Generic, Literal, Optional, TypeVar, overload
2
+ from typing import Generic, Literal, Optional, TypeVar
3
3
  import numpy as np
4
4
  import numpy.typing as npt
5
5
 
6
6
  from . import wrappers
7
- from marlenv import adapters
8
7
  from .models import Space, MARLEnv
9
8
 
10
9
  AS = TypeVar("AS", bound=Space)
11
10
 
12
- if adapters.HAS_PETTINGZOO:
13
- from .adapters import PettingZoo
14
- from pettingzoo import ParallelEnv
15
-
16
- @overload
17
- def make(env: ParallelEnv) -> PettingZoo: ...
18
-
19
-
20
- if adapters.HAS_GYM:
21
- from .adapters import Gym
22
- from gymnasium import Env
23
- import gymnasium
24
-
25
- @overload
26
- def make(env: Env) -> Gym: ...
27
-
28
- @overload
29
- def make(env: str, **kwargs) -> Gym:
30
- """
31
- Make an RLEnv from the `gymnasium` registry (e.g: "CartPole-v1").
32
- """
33
-
34
-
35
- if adapters.HAS_SMAC:
36
- from .adapters import SMAC
37
- from smac.env import StarCraft2Env
38
-
39
- @overload
40
- def make(env: StarCraft2Env) -> SMAC: ...
41
-
42
-
43
- if adapters.HAS_OVERCOOKED:
44
- from .adapters import Overcooked
45
- from overcooked_ai_py.mdp.overcooked_env import OvercookedEnv
46
-
47
- @overload
48
- def make(env: OvercookedEnv) -> Overcooked: ...
49
-
50
-
51
- def make(env, **kwargs):
52
- """Make an RLEnv from str (Gym) or PettingZoo"""
53
- match env:
54
- case MARLEnv():
55
- return env
56
- case str(env_id):
57
- if adapters.HAS_GYM:
58
- gym_env = gymnasium.make(env_id, render_mode="rgb_array", **kwargs)
59
- return Gym(gym_env)
60
-
61
- if adapters.HAS_PETTINGZOO and isinstance(env, ParallelEnv):
62
- return PettingZoo(env) # type: ignore
63
- if adapters.HAS_SMAC and isinstance(env, StarCraft2Env):
64
- return SMAC(env)
65
- if adapters.HAS_OVERCOOKED and isinstance(env, OvercookedEnv):
66
- return Overcooked(env) # type: ignore
67
- if adapters.HAS_GYM and isinstance(env, Env):
68
- return Gym(env)
69
- raise ValueError(f"Unknown environment type: {type(env)}")
70
-
71
11
 
72
12
  @dataclass
73
13
  class Builder(Generic[AS]):
marlenv/models/spaces.py CHANGED
@@ -105,7 +105,7 @@ class MultiDiscreteSpace(Space):
105
105
  def sample(self, mask: Optional[npt.NDArray[np.bool] | list[npt.NDArray[np.bool]]] = None):
106
106
  if mask is None:
107
107
  return np.array([space.sample() for space in self.spaces], dtype=np.int32)
108
- return np.array([space.sample(mask=mask) for mask, space in zip(mask, self.spaces)], dtype=np.int32)
108
+ return np.array([space.sample(mask=mask) for mask, space in zip(mask, self.spaces)], dtype=np.int32) # type: ignore
109
109
 
110
110
  def __eq__(self, value: object) -> bool:
111
111
  if not isinstance(value, MultiDiscreteSpace):
marlenv/utils/__init__.py CHANGED
@@ -1,5 +1,6 @@
1
1
  from .cached_property_collector import CachedPropertyCollector, CachedPropertyInvalidator
2
2
  from .schedule import ExpSchedule, LinearSchedule, MultiSchedule, RoundedSchedule, Schedule
3
+ from .import_placeholders import DummyClass, dummy_function
3
4
 
4
5
  __all__ = [
5
6
  "Schedule",
@@ -9,4 +10,6 @@ __all__ = [
9
10
  "RoundedSchedule",
10
11
  "CachedPropertyCollector",
11
12
  "CachedPropertyInvalidator",
13
+ "DummyClass",
14
+ "dummy_function",
12
15
  ]
marlenv/utils/import_placeholders.py ADDED
@@ -0,0 +1,30 @@
1
+ from typing import Optional, Any
2
+
3
+
4
+ class DummyClass:
5
+ def __init__(self, module_name: str, package_name: Optional[str] = None):
6
+ self.module_name = module_name
7
+ if package_name is None:
8
+ self.package_name = module_name
9
+ else:
10
+ self.package_name = package_name
11
+
12
+ def _raise_error(self):
13
+ raise ImportError(
14
+ f"The optional dependency `{self.module_name}` is not installed.\nInstall the `{self.package_name}` package (e.g. pip install {self.package_name})."
15
+ )
16
+
17
+ def __getattr__(self, _):
18
+ self._raise_error()
19
+
20
+ def __call__(self, *args, **kwargs):
21
+ self._raise_error()
22
+
23
+
24
+ def dummy_function(module_name: str, package_name: Optional[str] = None):
25
+ dummy = DummyClass(module_name, package_name)
26
+
27
+ def fail(*args, **kwargs) -> Any:
28
+ dummy()
29
+
30
+ return fail
marlenv/utils/schedule.py CHANGED
@@ -1,6 +1,6 @@
1
1
  from abc import abstractmethod
2
2
  from dataclasses import dataclass
3
- from typing import Callable, Optional, TypeVar
3
+ from typing import Any, Callable, Optional, TypeVar
4
4
 
5
5
  T = TypeVar("T")
6
6
 
@@ -142,6 +142,21 @@ class Schedule:
142
142
  def __int__(self) -> int:
143
143
  return int(self.value)
144
144
 
145
+ @staticmethod
146
+ def from_json(data: dict[str, Any]):
147
+ """Create a Schedule from a JSON-like dictionary."""
148
+ classname = data.get("name")
149
+ if classname == "LinearSchedule":
150
+ return LinearSchedule(data["start_value"], data["end_value"], data["n_steps"])
151
+ elif classname == "ExpSchedule":
152
+ return ExpSchedule(data["start_value"], data["end_value"], data["n_steps"])
153
+ elif classname == "ConstantSchedule":
154
+ return ConstantSchedule(data["value"])
155
+ elif classname == "ArbitrarySchedule":
156
+ raise NotImplementedError("ArbitrarySchedule cannot be deserialized from JSON")
157
+ else:
158
+ raise ValueError(f"Unknown schedule type: {classname}")
159
+
145
160
 
146
161
  @dataclass(eq=False)
147
162
  class LinearSchedule(Schedule):
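A small sketch of deserializing a schedule with the new `from_json` helper; the dictionary keys mirror the branches above and the concrete values are only illustrative:

```python
from marlenv.utils import Schedule

data = {"name": "LinearSchedule", "start_value": 1.0, "end_value": 0.05, "n_steps": 50_000}
epsilon = Schedule.from_json(data)  # returns a LinearSchedule
```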
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: multi-agent-rlenv
3
- Version: 3.5.5
3
+ Version: 3.6.0
4
4
  Summary: A strongly typed Multi-Agent Reinforcement Learning framework
5
5
  Project-URL: repository, https://github.com/yamoling/multi-agent-rlenv
6
6
  Author-email: Yannick Molinghen <yannick.molinghen@ulb.be>
@@ -13,7 +13,8 @@ Requires-Dist: opencv-python>=4.0
13
13
  Requires-Dist: typing-extensions>=4.0
14
14
  Provides-Extra: all
15
15
  Requires-Dist: gymnasium>0.29.1; extra == 'all'
16
- Requires-Dist: overcooked-ai; extra == 'all'
16
+ Requires-Dist: laser-learning-environment>=2.6.1; extra == 'all'
17
+ Requires-Dist: overcooked>=0.1.0; extra == 'all'
17
18
  Requires-Dist: pettingzoo>=1.20; extra == 'all'
18
19
  Requires-Dist: pymunk>=6.0; extra == 'all'
19
20
  Requires-Dist: pysc2; extra == 'all'
@@ -22,9 +23,10 @@ Requires-Dist: smac; extra == 'all'
22
23
  Requires-Dist: torch>=2.0; extra == 'all'
23
24
  Provides-Extra: gym
24
25
  Requires-Dist: gymnasium>=0.29.1; extra == 'gym'
26
+ Provides-Extra: lle
27
+ Requires-Dist: laser-learning-environment>=2.6.1; extra == 'lle'
25
28
  Provides-Extra: overcooked
26
- Requires-Dist: overcooked-ai>=1.1.0; extra == 'overcooked'
27
- Requires-Dist: scipy>=1.10; extra == 'overcooked'
29
+ Requires-Dist: overcooked>=0.1.0; extra == 'overcooked'
28
30
  Provides-Extra: pettingzoo
29
31
  Requires-Dist: pettingzoo>=1.20; extra == 'pettingzoo'
30
32
  Requires-Dist: pymunk>=6.0; extra == 'pettingzoo'
@@ -62,9 +64,24 @@ $ pip install marlenv[smac] # Install SMAC
62
64
  $ pip install marlenv[gym,smac] # Install Gym & smac support
63
65
  ```
64
66
 
67
+ ## Using the `marlenv` environment catalog
68
+ Some environments are registered in `marlenv` and can easily be instantiated via its catalog.
69
+
70
+ ```python
71
+ from marlenv import catalog
72
+
73
+ env1 = catalog.Overcooked.from_layout("scenario4")
74
+ env2 = catalog.LLE.level(6)
75
+ env3 = catalog.DeepSea(max_depth=5)
76
+ ```
77
+ Note that using the catalog requires the corresponding environment package to be installed. For instance, you need to install the `laser-learning-environment` package to use `catalog.LLE`, which can be done with the corresponding extra at installation time, as shown below.
78
+ ```bash
79
+ pip install multi-agent-rlenv[lle]
80
+ ```
81
+
65
82
 
66
83
  ## Using `marlenv` with existing libraries
67
- `marlenv` unifies multiple popular libraries under a single interface. Namely, `marlenv` supports `smac`, `gymnasium` and `pettingzoo`.
84
+ `marlenv` provides adapters for the most popular libraries to unify them under a single interface. Namely, `marlenv` supports `smac`, `gymnasium` and `pettingzoo`.
68
85
 
69
86
  ```python
70
87
  import marlenv
@@ -74,7 +91,7 @@ gym_env = marlenv.make("CartPole-v1", seed=25)
74
91
 
75
92
  # You can seamlessly instantiate a SMAC environment and directly pass your required arguments
76
93
  from marlenv.adapters import SMAC
77
- smac_env = env2 = SMAC("3m", debug=True, difficulty="9")
94
+ smac_env = SMAC("3m", debug=True, difficulty="9")
78
95
 
79
96
  # pettingzoo is also supported
80
97
  from pettingzoo.sisl import pursuit_v4
@@ -1,26 +1,28 @@
1
- marlenv/__init__.py,sha256=bX76JknjwfVJ6IOKql_y4rIqvvx9raepD7u2lB9CgGo,3656
2
- marlenv/env_builder.py,sha256=RJoHJLYAUE1ausAoJiRC3fUxyxpH1WRJf7Sdm2ml-uk,5517
1
+ marlenv/__init__.py,sha256=MJgaW73zWYJKTNMWE8V3hTvrcMk-WEX3RaG-K_oIDD8,3886
2
+ marlenv/env_builder.py,sha256=RUMFvW7dAJtHMLm8-oPVpjBefDtNliZtjlHci97Xj-Q,3874
3
3
  marlenv/env_pool.py,sha256=nCEBkGQU62fcvCAANyAqY8gCFjYlVnSCg-V3Fhx00yc,933
4
4
  marlenv/exceptions.py,sha256=gJUC_2rVAvOfK_ypVFc7Myh-pIfSU3To38VBVS_0rZA,1179
5
5
  marlenv/mock_env.py,sha256=kKvTdZl4_xSTTI9V6otZ1P709sfPYrqZSbbZaTip9iI,4684
6
6
  marlenv/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
7
- marlenv/adapters/__init__.py,sha256=rWiqQOqTx3kVL5ZkPo3rkczrlQBBhQbU55zGI26SEeY,929
8
- marlenv/adapters/gym_adapter.py,sha256=Rx8ZnHW0XEwZzRT53BdDP1f4CtNp2tATAYZ0BbtBbd0,2863
9
- marlenv/adapters/overcooked_adapter.py,sha256=0-5sfKHGjmM4eYtbydsENMXV2Qx4WVGlvJl7MFZSaj8,9123
7
+ marlenv/adapters/__init__.py,sha256=wJzd94WfNNFX-yojr2M1dExVAGkFqwM2fieQ-v7uH4s,783
8
+ marlenv/adapters/gym_adapter.py,sha256=5HZF3g0QD4n7K4GQoMis4q0zj97uFTLdzdxMYHzM_UE,3041
10
9
  marlenv/adapters/pettingzoo_adapter.py,sha256=w9Ta-X4L_6ZXdDGmREOdcU0vpLR8lGP__s49DyK3dk8,2852
11
10
  marlenv/adapters/pymarl_adapter.py,sha256=2s7EY31s1hrml3q-BBaXo_eDMXTjkebozZPvzsgrb9c,3353
12
11
  marlenv/adapters/smac_adapter.py,sha256=8uWC7YKsaSXeTS8AUhpGOKvrWMbVEQT2-pml5BaFUB0,8343
12
+ marlenv/catalog/__init__.py,sha256=r5iAuIqM5UZkTNNu4cMc0hmjlgZ74OXcxL3aM15dnHw,655
13
+ marlenv/catalog/deepsea.py,sha256=VPRV6tjPIkj-TMqp2L8U9COUK3FBlodoRW6gpsnwZ9Y,2332
13
14
  marlenv/models/__init__.py,sha256=uihmRs71Gg5z7Bvau_xtaQVg7xEtX8sTzi74bIHL5P0,443
14
15
  marlenv/models/env.py,sha256=BG1iVHxGD_p827mF0ewyOBn6wU2gtFsHLW1b4UtW-V0,7841
15
16
  marlenv/models/episode.py,sha256=zsyxsW4LIioPKyY4DZKn64A31e5ZvlwOf3HIGuRUzhs,13531
16
17
  marlenv/models/observation.py,sha256=RhvKvmys4bu3UwwVsvu7fJ7TMKt2QkKnBD1e0hw2r7s,3528
17
- marlenv/models/spaces.py,sha256=v7jnhPfj7vq7DFFJpSbQEYe4NGLLlj_bj2pzvvSBX7Y,7777
18
+ marlenv/models/spaces.py,sha256=1aPmTcoOTU9nlwlcN7aswNrORwghOYAGqCLAMpk39SA,7793
18
19
  marlenv/models/state.py,sha256=LbP--JxBzRwMFpEAaZyxCX13xKQ27xPE2fabohaq9YI,2058
19
20
  marlenv/models/step.py,sha256=00PhD_ccdCIYAY1SVJdJU91weU0Y_tNIJwK16TN_53I,3056
20
21
  marlenv/models/transition.py,sha256=UkJVRNxZoyRkjE7YmKtUf_4xA7cOEh20O60dTldbvys,5070
21
- marlenv/utils/__init__.py,sha256=36pNw0r4V3xsqPZ5ljM29o96dfPAFq8WvMwggyv41fI,362
22
+ marlenv/utils/__init__.py,sha256=udb1AhuX6cdcIEaGpE3E1U4Xlo7DtY6f8KCnFgdgfz0,462
22
23
  marlenv/utils/cached_property_collector.py,sha256=IOjbr61f0DqLhcidXKrl7MhN1BOEGiTzCANIKQCxaF0,600
23
- marlenv/utils/schedule.py,sha256=slhtpQiBHSUNyPmSkKb2yBgiHJqPhoPxa33GxvyV8Jc,8565
24
+ marlenv/utils/import_placeholders.py,sha256=qKp-4YZFGcqoEAdblroify_lfX8iuH_ot4AOJkfpvPg,860
25
+ marlenv/utils/schedule.py,sha256=BdjefYgAtGlh1wWGHENid4WNnPOU78kkNiRvR5A9GEA,9308
24
26
  marlenv/wrappers/__init__.py,sha256=Z4_M-mxRNKQeu52tkmQ4B2m3-zrsmjfXXL5NsWQ4vu4,952
25
27
  marlenv/wrappers/action_randomizer.py,sha256=A1kejqGOTA0sc_RQL0EOd6sMSbcIdiV5zlscjKUlzdY,474
26
28
  marlenv/wrappers/agent_id_wrapper.py,sha256=9qHV3LMQ4AjcDCSuvQhz5h9hUf7Xtrdi2sIxmNZk5NA,1126
@@ -36,7 +38,7 @@ marlenv/wrappers/potential_shaping.py,sha256=T_QvnmWReCgpyoInxRw2UXbmdvcBD5U-vV1
36
38
  marlenv/wrappers/rlenv_wrapper.py,sha256=S6G1VjFklTEzU6bj0AXrTDXnsTQJARq8VB4uUH6AXe4,2993
37
39
  marlenv/wrappers/time_limit.py,sha256=GxbxcbfFyuVg14ylQU2C_cjmV9q4uDAt5wepfgX_PyM,3976
38
40
  marlenv/wrappers/video_recorder.py,sha256=ucBQSNRPqDr-2mYxrTCqlrWcxSWtSJ7XlRC9-LdukBM,2535
39
- multi_agent_rlenv-3.5.5.dist-info/METADATA,sha256=WKf56Bb7PqZFrw2B6Sx8zulM-h7aZRqXMTvYHrSxEtQ,5005
40
- multi_agent_rlenv-3.5.5.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
41
- multi_agent_rlenv-3.5.5.dist-info/licenses/LICENSE,sha256=_eeiGVoIJ7kYt6l1zbIvSBQppTnw0mjnYk1lQ4FxEjE,1074
42
- multi_agent_rlenv-3.5.5.dist-info/RECORD,,
41
+ multi_agent_rlenv-3.6.0.dist-info/METADATA,sha256=J9gFMccQ88-4Xh-DP80wOGehHc19l-i3BLI2zabFG9A,5751
42
+ multi_agent_rlenv-3.6.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
43
+ multi_agent_rlenv-3.6.0.dist-info/licenses/LICENSE,sha256=_eeiGVoIJ7kYt6l1zbIvSBQppTnw0mjnYk1lQ4FxEjE,1074
44
+ multi_agent_rlenv-3.6.0.dist-info/RECORD,,
marlenv/adapters/overcooked_adapter.py REMOVED
@@ -1,241 +0,0 @@
1
- import sys
2
- from dataclasses import dataclass
3
- from typing import Literal, Sequence, Optional
4
- from copy import deepcopy
5
-
6
- import cv2
7
- import numpy as np
8
- import numpy.typing as npt
9
- import pygame
10
- from marlenv.models import ContinuousSpace, DiscreteSpace, MARLEnv, Observation, State, Step, MultiDiscreteSpace
11
- from marlenv.utils import Schedule
12
-
13
- from overcooked_ai_py.mdp.overcooked_env import OvercookedEnv
14
- from overcooked_ai_py.mdp.overcooked_mdp import Action, OvercookedGridworld, OvercookedState
15
- from overcooked_ai_py.visualization.state_visualizer import StateVisualizer
16
-
17
-
18
- @dataclass
19
- class Overcooked(MARLEnv[MultiDiscreteSpace]):
20
- horizon: int
21
- shaping_factor: Schedule
22
-
23
- def __init__(
24
- self,
25
- oenv: OvercookedEnv,
26
- shaping_factor: float | Schedule = 1.0,
27
- name_suffix: Optional[str] = None,
28
- ):
29
- if isinstance(shaping_factor, (int, float)):
30
- shaping_factor = Schedule.constant(shaping_factor)
31
- self.shaping_factor = shaping_factor
32
- self._oenv = oenv
33
- assert isinstance(oenv.mdp, OvercookedGridworld)
34
- self._mdp = oenv.mdp
35
- self._visualizer = StateVisualizer()
36
- width, height, layers = tuple(self._mdp.lossless_state_encoding_shape)
37
- # -1 because we extract the "urgent" layer to the extras
38
- shape = (int(layers - 1), int(width), int(height))
39
- super().__init__(
40
- n_agents=self._mdp.num_players,
41
- action_space=DiscreteSpace(Action.NUM_ACTIONS, labels=[Action.ACTION_TO_CHAR[a] for a in Action.ALL_ACTIONS]).repeat(
42
- self._mdp.num_players
43
- ),
44
- observation_shape=shape,
45
- extras_shape=(2,),
46
- extras_meanings=["timestep", "urgent"],
47
- state_shape=shape,
48
- state_extra_shape=(2,),
49
- reward_space=ContinuousSpace.from_shape(1),
50
- )
51
- self.horizon = int(self._oenv.horizon)
52
- if name_suffix is not None:
53
- self.name = f"{self.name}-{name_suffix}"
54
-
55
- @property
56
- def state(self) -> OvercookedState:
57
- """Current state of the environment"""
58
- return self._oenv.state
59
-
60
- def set_state(self, state: State):
61
- raise NotImplementedError("Not yet implemented")
62
-
63
- @property
64
- def time_step(self):
65
- return self.state.timestep
66
-
67
- def _state_data(self):
68
- players_layers = self._mdp.lossless_state_encoding(self.state)
69
- state = np.array(players_layers, dtype=np.float32)
70
- # Use axes (agents, channels, height, width) instead of (agents, height, width, channels)
71
- state = np.transpose(state, (0, 3, 1, 2))
72
- # The last layer is for "urgency", put it in the extras
73
- urgency = float(np.all(state[:, -1]))
74
- state = state[:, :-1]
75
- return state, urgency
76
-
77
- def get_state(self):
78
- data, is_urgent = self._state_data()
79
- return State(data[0], np.array([self.time_step / self.horizon, is_urgent], dtype=np.float32))
80
-
81
- def get_observation(self) -> Observation:
82
- data, is_urgent = self._state_data()
83
- return Observation(
84
- data=data,
85
- available_actions=self.available_actions(),
86
- extras=np.array([[self.time_step / self.horizon, is_urgent]] * self.n_agents, dtype=np.float32),
87
- )
88
-
89
- def available_actions(self):
90
- available_actions = np.full((self.n_agents, self.n_actions), False)
91
- actions = self._mdp.get_actions(self._oenv.state)
92
- for agent_num, agent_actions in enumerate(actions):
93
- for action in agent_actions:
94
- available_actions[agent_num, Action.ACTION_TO_INDEX[action]] = True
95
- return np.array(available_actions, dtype=np.bool)
96
-
97
- def step(self, actions: Sequence[int] | np.ndarray) -> Step:
98
- self.shaping_factor.update()
99
- actions = [Action.ALL_ACTIONS[a] for a in actions]
100
- _, reward, done, info = self._oenv.step(actions, display_phi=True)
101
-
102
- reward += sum(info["shaped_r_by_agent"]) * self.shaping_factor
103
- return Step(
104
- obs=self.get_observation(),
105
- state=self.get_state(),
106
- reward=np.array([reward], dtype=np.float32),
107
- done=done,
108
- truncated=False,
109
- info=info,
110
- )
111
-
112
- def reset(self):
113
- self._oenv.reset()
114
- return self.get_observation(), self.get_state()
115
-
116
- def __deepcopy__(self, _):
117
- """
118
- Note: a specific implementation is needed because `pygame.font.Font` objects are not deep-copiable by default.
119
- """
120
- mdp = deepcopy(self._mdp)
121
- copy = Overcooked(OvercookedEnv.from_mdp(mdp, horizon=self.horizon), deepcopy(self.shaping_factor))
122
- copy.name = self.name
123
- return copy
124
-
125
- def __getstate__(self):
126
- return {"horizon": self.horizon, "mdp": self._mdp, "name": self.name, "schedule": self.shaping_factor}
127
-
128
- def __setstate__(self, state: dict):
129
- from overcooked_ai_py.mdp.overcooked_mdp import Recipe
130
-
131
- mdp = state["mdp"]
132
- Recipe.configure(mdp.recipe_config)
133
- self.__init__(OvercookedEnv.from_mdp(state["mdp"], horizon=state["horizon"]), shaping_factor=state["schedule"])
134
- self.name = state["name"]
135
-
136
- def get_image(self):
137
- rewards_dict = {} # dictionary of details you want rendered in the UI
138
- for key, value in self._oenv.game_stats.items():
139
- if key in [
140
- "cumulative_shaped_rewards_by_agent",
141
- "cumulative_sparse_rewards_by_agent",
142
- ]:
143
- rewards_dict[key] = value
144
-
145
- image = self._visualizer.render_state(
146
- state=self._oenv.state,
147
- grid=self._mdp.terrain_mtx,
148
- hud_data=StateVisualizer.default_hud_data(self._oenv.state, **rewards_dict),
149
- )
150
-
151
- image = pygame.surfarray.array3d(image)
152
- image = np.flip(np.rot90(image, 3), 1)
153
- # Depending on the platform, the image may need to be converted to RGB
154
- if sys.platform in ("linux", "linux2"):
155
- image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
156
- return image
157
-
158
- @staticmethod
159
- def from_layout(
160
- layout: Literal[
161
- "asymmetric_advantages",
162
- "asymmetric_advantages_tomato",
163
- "bonus_order_test",
164
- "bottleneck",
165
- "centre_objects",
166
- "centre_pots",
167
- "coordination_ring",
168
- "corridor",
169
- "counter_circuit",
170
- "counter_circuit_o_1order",
171
- "cramped_corridor",
172
- "cramped_room",
173
- "cramped_room_o_3orders",
174
- "cramped_room_single",
175
- "cramped_room_tomato",
176
- "five_by_five",
177
- "forced_coordination",
178
- "forced_coordination_tomato",
179
- "inverse_marshmallow_experiment",
180
- "large_room",
181
- "long_cook_time",
182
- "marshmallow_experiment_coordination",
183
- "marshmallow_experiment",
184
- "mdp_test",
185
- "m_shaped_s",
186
- "multiplayer_schelling",
187
- "pipeline",
188
- "scenario1_s",
189
- "scenario2",
190
- "scenario2_s",
191
- "scenario3",
192
- "scenario4",
193
- "schelling",
194
- "schelling_s",
195
- "simple_o",
196
- "simple_o_t",
197
- "simple_tomato",
198
- "small_corridor",
199
- "soup_coordination",
200
- "tutorial_0",
201
- "tutorial_1",
202
- "tutorial_2",
203
- "tutorial_3",
204
- "unident",
205
- "you_shall_not_pass",
206
- ],
207
- horizon: int = 400,
208
- reward_shaping_factor: float | Schedule = 1.0,
209
- ):
210
- mdp = OvercookedGridworld.from_layout_name(layout)
211
- return Overcooked(OvercookedEnv.from_mdp(mdp, horizon=horizon, info_level=0), reward_shaping_factor, layout)
212
-
213
- @staticmethod
214
- def from_grid(
215
- grid: Sequence[Sequence[Literal["S", "P", "X", "O", "D", "T", "1", "2", " "] | str]],
216
- horizon: int = 400,
217
- shaping_factor: float | Schedule = 1.0,
218
- layout_name: Optional[str] = None,
219
- ):
220
- """
221
- Create an Overcooked environment from a grid layout where
222
- - S is a serving location
223
- - P is a cooking pot
224
- - X is a counter
225
- - O is an onion dispenser
226
- - D is a dish dispenser
227
- - T is a tomato dispenser
228
- - 1 is a player 1 starting location
229
- - 2 is a player 2 starting location
230
- - ' ' is a walkable space
231
-
232
- If provided, `custom_name` is added to the environment name.
233
- """
234
- # It is necessary to add an explicit layout name because Overcooked saves some files under this
235
- # name. By default the name is a concatenation of the grid elements, which may include characters
236
- # such as white spaces, pipes ('|') and square brackets ('[' and ']') that are invalid Windows file paths.
237
- if layout_name is None:
238
- layout_name = "custom-layout"
239
- mdp = OvercookedGridworld.from_grid(grid, base_layout_params={"layout_name": layout_name})
240
-
241
- return Overcooked(OvercookedEnv.from_mdp(mdp, horizon=horizon, info_level=0), shaping_factor, layout_name)